diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 68aff793ae6aa..76f6d7aeca0d8 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,11 @@ import os import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB -# Note that we have 400 MiB quota, please use it wisely. -# See https://github.com/pypi/support/issues/3792 . +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Note that we have 800 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/6326 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) def print_top_10_largest_files(zip_file): diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index b39f9899a8f28..e6f5c8b60f459 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -141,7 +141,7 @@ When run, benchmark script generates results under `benchmark/results` folder, a `compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT. If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead. -Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps. +Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output length, max concurrency and qps. `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio | diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d383e..2ef36089b6afb 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -8,7 +8,7 @@ This benchmark aims to: Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end. -Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) +Latest reproduction guide: [github issue link](https://github.com/vllm-project/vllm/issues/8176) ## Setup @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. 
- Hardware - 8x Nvidia A100 GPUs diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 50431d0cd4c5e..5ea5a50a258a4 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -218,7 +218,7 @@ if __name__ == "__main__": "--xaxis", type=str, default="# of max concurrency.", - help="column name to use as X Axis in comparision graph", + help="column name to use as X Axis in comparison graph", ) args = parser.parse_args() diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 77047636bb951..a655a650cb325 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -368,7 +368,7 @@ if __name__ == "__main__": # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}" + lambda x: f"{len(x.splitlines())}x{x.splitlines()[0]}" ) # get markdown tables diff --git a/.buildkite/nightly-benchmarks/scripts/launch-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-server.sh index fb5063db86942..ebacdcbd6821b 100644 --- a/.buildkite/nightly-benchmarks/scripts/launch-server.sh +++ b/.buildkite/nightly-benchmarks/scripts/launch-server.sh @@ -181,18 +181,14 @@ launch_vllm_server() { if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ + server_command="vllm serve $model \ -tp $tp \ - --model $model \ --port $port \ $server_args" else echo "Key 'fp8' does not exist in common params." 
- server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ + server_command="vllm serve $model \ -tp $tp \ - --model $model \ --port $port \ $server_args" fi diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index b1b7d2d77a44d..c64e5638029e7 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -365,8 +365,7 @@ run_serving_tests() { continue fi - server_command="$server_envs python3 \ - -m vllm.entrypoints.openai.api_server \ + server_command="$server_envs vllm serve \ $server_args" # run the server @@ -455,11 +454,6 @@ main() { fi check_hf_token - # Set to v1 to run v1 benchmark - if [[ "${ENGINE_VERSION:-v0}" == "v1" ]]; then - export VLLM_USE_V1=1 - fi - # dependencies (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json index 2d88a0b30c4f8..f758097e098e4 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json @@ -1,6 +1,6 @@ [ { - "test_name": "serving_llama8B_tp1_sharegpt", + "test_name": "serving_llama8B_bf16_tp1_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -32,7 +32,7 @@ } }, { - "test_name": "serving_llama8B_tp2_sharegpt", + "test_name": "serving_llama8B_bf16_tp2_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -64,7 +64,7 @@ } }, { - "test_name": "serving_llama8B_tp4_sharegpt", + "test_name": "serving_llama8B_bf16_tp4_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -96,7 +96,7 @@ } }, { - "test_name": "serving_llama8B_tp1_random_128_128", + "test_name": "serving_llama8B_bf16_tp1_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -131,7 +131,7 @@ } }, { - "test_name": "serving_llama8B_tp2_random_128_128", + "test_name": "serving_llama8B_bf16_tp2_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -166,7 +166,7 @@ } }, { - "test_name": "serving_llama8B_tp4_random_128_128", + "test_name": "serving_llama8B_bf16_tp4_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -198,5 +198,413 @@ "random-output-len": 128, "num_prompts": 1000 } + }, + { + "test_name": "serving_llama8B_int8_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 
2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + 
}, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": 
"hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } } ] diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json index 823abbaa99f86..ce396d6e54f27 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json @@ -1,6 +1,6 @@ [ { - "test_name": "serving_llama8B_pp1_sharegpt", + "test_name": "serving_llama8B_bf16_pp1_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -32,7 +32,39 @@ } }, { - "test_name": "serving_llama8B_pp3_sharegpt", + "test_name": "serving_llama8B_bf16_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -64,7 +96,7 @@ } }, { - "test_name": "serving_llama8B_tp2pp3_sharegpt", + "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -97,7 +129,7 @@ } }, { - "test_name": "serving_llama8B_pp1_random_128_128", + "test_name": "serving_llama8B_bf16_pp1_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -132,7 +164,42 @@ } }, { - "test_name": "serving_llama8B_pp3_random_128_128", + "test_name": "serving_llama8B_bf16_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + 
"enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -167,7 +234,7 @@ } }, { - "test_name": "serving_llama8B_tp2pp3_random_128_128", + "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -201,5 +268,553 @@ "ignore-eos": "", "num_prompts": 1000 } + }, + { + "test_name": "serving_llama8B_int8_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + 
"backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": 
"RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + 
"quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } } ] diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml deleted file mode 100644 index d5cad1c73c6f8..0000000000000 --- a/.buildkite/pyproject.toml +++ /dev/null @@ -1,46 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. 
-# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.format] -docstring-code-format = true diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index f96c38bf57db7..505323bc2b654 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,21 +1,22 @@ steps: - # aarch64 + CUDA builds - - label: "Build arm64 wheel - CUDA 12.8" - id: build-wheel-arm64-cuda-12-8 + # aarch64 + CUDA builds. PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build arm64 wheel - CUDA 12.9" + depends_on: ~ + id: build-wheel-arm64-cuda-12-9 agents: queue: arm64_cpu_queue_postmerge commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - # x86 + CUDA builds - label: "Build wheel - CUDA 12.8" + depends_on: ~ id: build-wheel-cuda-12-8 agents: queue: cpu_queue_postmerge @@ -27,12 +28,8 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build CUDA 12.6 wheel" - key: block-build-cu126-wheel - depends_on: ~ - - label: "Build wheel - CUDA 12.6" - depends_on: block-build-cu126-wheel + depends_on: ~ id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -44,44 +41,61 @@ steps: env: DOCKER_BUILDKIT: "1" - # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working. - # However, this block can be uncommented to save some compute hours. 
- # - block: "Build CUDA 11.8 wheel" - # key: block-build-cu118-wheel - - - label: "Build wheel - CUDA 11.8" - # depends_on: block-build-cu118-wheel - id: build-wheel-cuda-11-8 + # x86 + CUDA builds + - label: "Build wheel - CUDA 12.9" + depends_on: ~ + id: build-wheel-cuda-12-9 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - - block: "Build release image" + - label: "Build release image (x86)" depends_on: ~ - key: block-release-image-build - - - label: "Build release image" - depends_on: block-release-image-build - id: build-release-image + id: build-release-image-x86 agents: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + # re-tag to default image tag and push, just in case arm64 build fails + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + # PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build release image (arm64)" + depends_on: ~ + id: build-release-image-arm64 + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." 
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" + + # Add job to create multi-arch manifest + - label: "Create multi-arch manifest" + depends_on: + - build-release-image-x86 + - build-release-image-arm64 + id: create-multi-arch-manifest + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker manifest create public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 --amend" + - "docker manifest push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Annotate release workflow" depends_on: - - build-release-image + - create-multi-arch-manifest - build-wheel-cuda-12-8 - - build-wheel-cuda-12-6 - - build-wheel-cuda-11-8 id: annotate-release-workflow agents: queue: cpu_queue_postmerge @@ -128,18 +142,30 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build Neuron release image" - key: block-neuron-release-image-build - depends_on: ~ - - - label: "Build and publish Neuron release image" - depends_on: block-neuron-release-image-build + - label: "Build and publish nightly multi-arch image to DockerHub" + depends_on: + - create-multi-arch-manifest + if: build.env("NIGHTLY") == "1" agents: - queue: neuron-postmerge + queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." 
- - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" - - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" + - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64" + - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64" + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 vllm/vllm-openai:nightly-x86_64" + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 vllm/vllm-openai:nightly-aarch64" + - "docker push vllm/vllm-openai:nightly-x86_64" + - "docker push vllm/vllm-openai:nightly-aarch64" + - "docker manifest create vllm/vllm-openai:nightly vllm/vllm-openai:nightly-x86_64 vllm/vllm-openai:nightly-aarch64 --amend" + - "docker manifest create vllm/vllm-openai:nightly-$BUILDKITE_COMMIT vllm/vllm-openai:nightly-x86_64 vllm/vllm-openai:nightly-aarch64 --amend" + - "docker manifest push vllm/vllm-openai:nightly" + - "docker manifest push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT" + # Clean up old nightly builds (keep only last 14) + - "bash .buildkite/scripts/cleanup-nightly-builds.sh" + plugins: + - docker-login#v3.0.0: + username: vllmbot + password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" + DOCKERHUB_USERNAME: "vllmbot" diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh index 94e0ac2398f34..fde48603ad3cd 100755 --- a/.buildkite/scripts/annotate-release.sh +++ b/.buildkite/scripts/annotate-release.sh @@ -14,18 +14,33 @@ buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF To download the wheel: \`\`\` aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl . + aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl . -aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl . 
\`\`\` To download and upload the image: \`\`\` -docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} -docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai -docker tag vllm/vllm-openai vllm/vllm-openai:latest -docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION} -docker push vllm/vllm-openai:latest -docker push vllm/vllm-openai:v${RELEASE_VERSION} +docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 +docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 + +docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 vllm/vllm-openai:x86_64 +docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:latest-x86_64 +docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 +docker push vllm/vllm-openai:latest-x86_64 +docker push vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 + +docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 vllm/vllm-openai:aarch64 +docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:latest-aarch64 +docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 +docker push vllm/vllm-openai:latest-aarch64 +docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 + +docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend +docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend +docker manifest push vllm/vllm-openai:latest +docker manifest push vllm/vllm-openai:v${RELEASE_VERSION} \`\`\` EOF \ No newline at end of file diff --git a/.buildkite/scripts/cleanup-nightly-builds.sh b/.buildkite/scripts/cleanup-nightly-builds.sh new file mode 100755 index 0000000000000..f02a128c67726 --- /dev/null +++ b/.buildkite/scripts/cleanup-nightly-builds.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +set -ex + +# Clean up old nightly builds from DockerHub, keeping only the last 14 builds +# This script uses DockerHub API to list and delete old tags with "nightly-" prefix + +# DockerHub API endpoint for vllm/vllm-openai repository +REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags" + +# Get DockerHub credentials from environment +if [ -z "$DOCKERHUB_TOKEN" ]; then + echo "Error: DOCKERHUB_TOKEN environment variable is not set" + exit 1 +fi + +if [ -z "$DOCKERHUB_USERNAME" ]; then + echo "Error: DOCKERHUB_USERNAME environment variable is not set" + exit 1 +fi + +# Get DockerHub bearer token +echo "Getting DockerHub bearer token..." 
+set +x +BEARER_TOKEN=$(curl -s -X POST \ + -H "Content-Type: application/json" \ + -d "{\"username\": \"$DOCKERHUB_USERNAME\", \"password\": \"$DOCKERHUB_TOKEN\"}" \ + "https://hub.docker.com/v2/users/login" | jq -r '.token') +set -x + +if [ -z "$BEARER_TOKEN" ] || [ "$BEARER_TOKEN" = "null" ]; then + echo "Error: Failed to get DockerHub bearer token" + exit 1 +fi + +# Function to get all tags from DockerHub +get_all_tags() { + local page=1 + local all_tags="" + + while true; do + set +x + local response=$(curl -s -H "Authorization: Bearer $BEARER_TOKEN" \ + "$REPO_API_URL?page=$page&page_size=100") + set -x + + # Get both last_updated timestamp and tag name, separated by | + local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"') + + if [ -z "$tags" ]; then + break + fi + + all_tags="$all_tags$tags"$'\n' + page=$((page + 1)) + done + + # Sort by timestamp (newest first) and extract just the tag names + echo "$all_tags" | sort -r | cut -d'|' -f2 +} + +delete_tag() { + local tag_name="$1" + echo "Deleting tag: $tag_name" + + local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name" + set +x + local response=$(curl -s -X DELETE -H "Authorization: Bearer $BEARER_TOKEN" "$delete_url") + set -x + + if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then + echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')" + else + echo "Successfully deleted tag: $tag_name" + fi +} + +# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first) +echo "Fetching all tags from DockerHub..." +all_tags=$(get_all_tags) + +if [ -z "$all_tags" ]; then + echo "No tags found to clean up" + exit 0 +fi + +# Count total tags +total_tags=$(echo "$all_tags" | wc -l) +echo "Found $total_tags tags" + +# Keep only the last 14 builds (including the current one) +tags_to_keep=14 +tags_to_delete=$((total_tags - tags_to_keep)) + +if [ $tags_to_delete -le 0 ]; then + echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)" + exit 0 +fi + +echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep" + +# Get tags to delete (skip the first $tags_to_keep tags) +tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1))) + +if [ -z "$tags_to_delete_list" ]; then + echo "No tags to delete" + exit 0 +fi + +# Delete old tags +echo "Deleting old tags..." 
+while IFS= read -r tag; do + if [ -n "$tag" ]; then + delete_tag "$tag" + # Add a small delay to avoid rate limiting + sleep 1 + fi +done <<< "$tags_to_delete_list" + +echo "Cleanup completed successfully" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index df0bae0c9cbff..aa4cc7b35a543 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -86,10 +86,6 @@ if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} fi -if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then - commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} -fi - if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} fi @@ -164,16 +160,9 @@ if [[ $commands == *" entrypoints/llm "* ]]; then --ignore=entrypoints/llm/test_chat.py \ --ignore=entrypoints/llm/test_accuracy.py \ --ignore=entrypoints/llm/test_init.py \ - --ignore=entrypoints/llm/test_generate_multiple_loras.py \ --ignore=entrypoints/llm/test_prompt_validation.py "} fi -#Obsolete currently -##ignore certain Entrypoints/llm tests -#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then -# commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "} -#fi - # --ignore=entrypoints/openai/test_encoder_decoder.py \ # --ignore=entrypoints/openai/test_embedding.py \ # --ignore=entrypoints/openai/test_oot_registration.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 36bcb015d308e..39ea180173081 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -25,25 +25,28 @@ function cpu_tests() { # offline inference podman exec -it "$container_id" bash -c " - set -e - python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + set -xve + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log # Run basic model test podman exec -it "$container_id" bash -c " - set -e + set -evx pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + + # Note: disable Bart until supports V1 + # pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] pytest -v -s 
tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" + # TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being. + # pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log } # All of CPU tests are expected to be finished less than 40 mins. export container_id export -f cpu_tests -timeout 40m bash -c cpu_tests +timeout 120m bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 9dec9f8e9eb32..7512cb1bbed01 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -49,57 +49,69 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -v -s tests/kernels/test_onednn.py" + pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e # Note: disable until supports V1 - # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model - # Note: disable Bart until supports V1 - pytest -v -s tests/models/language/generation -m cpu_model \ - --ignore=tests/models/language/generation/test_bart.py - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ - --ignore=tests/models/language/generation/test_bart.py + pytest -x -v -s 
tests/models/language/generation -m cpu_model + VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model - pytest -v -s tests/models/language/pooling -m cpu_model - pytest -v -s tests/models/multimodal/generation \ - --ignore=tests/models/multimodal/generation/test_mllama.py \ + pytest -x -v -s tests/models/language/pooling -m cpu_model + pytest -x -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_pixtral.py \ -m cpu_model" # Run compressed-tensor test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" # Note: disable it until supports V1 # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -s -v \ + # VLLM_USE_V1=0 pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + pytest -x -s -v \ tests/lora/test_qwen2vl.py" - # online serving + # online serving: tp+pp docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions' + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' + + # online serving: tp+dp + docker exec cpu-test-"$NUMA_NODE" bash -c ' + set -e + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & + server_pid=$! + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + vllm bench serve \ + --backend vllm \ + --dataset-name random \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --num-prompts 20 \ + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh deleted file mode 100644 index a397457c83261..0000000000000 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - -# This script build the Neuron docker image and run the API server inside the container. -# It serves a sanity check for compilation and basic model usage. -set -e -set -v - -image_name="neuron/vllm-ci" -container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" - -HF_CACHE="$(realpath ~)/huggingface" -mkdir -p "${HF_CACHE}" -HF_MOUNT="/root/.cache/huggingface" -HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) - -NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" -mkdir -p "${NEURON_COMPILE_CACHE_URL}" -NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" - -# Try building the docker image -aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws - -# prune old image and containers to save disk space, and only once a day -# by using a timestamp file in tmp. 
-if [ -f /tmp/neuron-docker-build-timestamp ]; then
-    last_build=$(cat /tmp/neuron-docker-build-timestamp)
-    current_time=$(date +%s)
-    if [ $((current_time - last_build)) -gt 86400 ]; then
-        # Remove dangling images (those that are not tagged and not used by any container)
-        docker image prune -f
-        # Remove unused volumes / force the system prune for old images as well.
-        docker volume prune -f && docker system prune -f
-        echo "$current_time" > /tmp/neuron-docker-build-timestamp
-    fi
-else
-    date "+%s" > /tmp/neuron-docker-build-timestamp
-fi
-
-docker build -t "${image_name}" -f docker/Dockerfile.neuron .
-
-# Setup cleanup
-remove_docker_container() {
-    docker image rm -f "${image_name}" || true;
-}
-trap remove_docker_container EXIT
-
-# Run the image
-docker run --rm -it --device=/dev/neuron0 --network bridge \
-    -v "${HF_CACHE}:${HF_MOUNT}" \
-    -e "HF_HOME=${HF_MOUNT}" \
-    -e "HF_TOKEN=${HF_TOKEN}" \
-    -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
-    -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
-    --name "${container_name}" \
-    ${image_name} \
-    /bin/bash -c "
-    set -e; # Exit on first error
-    python3 /workspace/vllm/examples/offline_inference/neuron.py;
-    python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
-    for f in /workspace/vllm/tests/neuron/2_core/*.py; do
-        echo \"Running test file: \$f\";
-        python3 -m pytest \$f -v --capture=tee-sys;
-    done
-    "
\ No newline at end of file
diff --git a/.buildkite/scripts/hardware_ci/run-npu-test.sh b/.buildkite/scripts/hardware_ci/run-npu-test.sh
new file mode 100644
index 0000000000000..29c8f5ed5a91a
--- /dev/null
+++ b/.buildkite/scripts/hardware_ci/run-npu-test.sh
@@ -0,0 +1,191 @@
+#!/bin/bash
+
+# This script builds the Ascend NPU docker image and runs offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# Base ubuntu image with basic ascend development libraries and python installed
+VLLM_ASCEND_REPO="https://github.com/vllm-project/vllm-ascend.git"
+CONFIG_FILE_REMOTE_PATH="tests/e2e/vllm_interface/vllm_test.cfg"
+TEST_RUN_CONFIG_FILE="vllm_test.cfg"
+VLLM_ASCEND_TMP_DIR=
+# Get the test run configuration file from the vllm-ascend repository
+fetch_vllm_test_cfg() {
+    VLLM_ASCEND_TMP_DIR=$(mktemp -d)
+    # Ensure that the temporary directory is cleaned up when an exception occurs during configuration file retrieval
+    cleanup() {
+        rm -rf "${VLLM_ASCEND_TMP_DIR}"
+    }
+    trap cleanup EXIT
+
+    GIT_TRACE=1 git clone -v --depth 1 "${VLLM_ASCEND_REPO}" "${VLLM_ASCEND_TMP_DIR}"
+    if [ ! -f "${VLLM_ASCEND_TMP_DIR}/${CONFIG_FILE_REMOTE_PATH}" ]; then
+        echo "Error: file '${CONFIG_FILE_REMOTE_PATH}' does not exist in the vllm-ascend repository" >&2
+        exit 1
+    fi
+
+    # If the file already exists locally, just overwrite it
+    cp "${VLLM_ASCEND_TMP_DIR}/${CONFIG_FILE_REMOTE_PATH}" "${TEST_RUN_CONFIG_FILE}"
+    echo "Copied ${CONFIG_FILE_REMOTE_PATH} to ${TEST_RUN_CONFIG_FILE}"
+
+    # The trap set above is overwritten later. Its cleanup duty is already done at this point,
+    # so delete the temporary directory manually and clear the trap.
+    rm -rf "${VLLM_ASCEND_TMP_DIR}"
+    trap - EXIT
+}
+
+# Loads the test run configuration fetched above
+# into the current script environment.
+get_config() {
+    if [ ! -f "${TEST_RUN_CONFIG_FILE}" ]; then
+        echo "Error: file '${TEST_RUN_CONFIG_FILE}' does not exist" >&2
+        exit 1
+    fi
+    source "${TEST_RUN_CONFIG_FILE}"
+    echo "Base docker image name obtained from the configuration: ${BASE_IMAGE_NAME}"
+    return 0
+}
+
+# Get the test run configuration.
+fetch_vllm_test_cfg
+get_config
+# Check if the function call was successful. If not, exit the script.
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+image_name="npu/vllm-ci:${BUILDKITE_COMMIT}_${EPOCHSECONDS}"
+container_name="npu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
+
+# BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards
+agent_idx=$(echo "${BUILDKITE_AGENT_NAME}" | awk -F'-' '{print $(NF-1)}')
+echo "agent_idx: ${agent_idx}"
+builder_name="cachebuilder${agent_idx}"
+builder_cache_dir="/mnt/docker-cache${agent_idx}"
+mkdir -p ${builder_cache_dir}
+
+# Try building the docker image
+cat <<EOF | docker build -t "${image_name}" -f - .
+FROM ${BASE_IMAGE_NAME}
+
+RUN pip install pytest>=6.0 modelscope
+
+WORKDIR /workspace/vllm
+
+# Install vLLM dependencies in advance so that, as long as common.txt remains unchanged, the docker cache layer stays valid.
+COPY requirements/common.txt /workspace/vllm/requirements/common.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements/common.txt
+
+COPY . .
+
+# Install vLLM
+RUN --mount=type=cache,target=/root/.cache/pip \
+    VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \
+    python3 -m pip uninstall -y triton
+
+# Install vllm-ascend
+WORKDIR /workspace
+ARG VLLM_ASCEND_REPO=https://github.com/vllm-project/vllm-ascend.git
+ARG VLLM_ASCEND_TAG=main
+RUN git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf "https://github.com/" && \
+    git clone --depth 1 \$VLLM_ASCEND_REPO --branch \$VLLM_ASCEND_TAG /workspace/vllm-ascend
+
+# Install vllm-ascend dependencies in advance so that, as long as its requirements.txt remains unchanged, the docker cache layer stays valid.
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r /workspace/vllm-ascend/requirements.txt
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+    export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \
+    source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
+    source /usr/local/Ascend/nnal/atb/set_env.sh && \
+    export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
+    python3 -m pip install -v -e /workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/
+
+ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
+ENV VLLM_USE_MODELSCOPE=True
+
+WORKDIR /workspace/vllm-ascend
+
+CMD ["/bin/bash"]
+
+EOF
+
+# Setup cleanup
+remove_docker_container() {
+    docker rm -f "${container_name}" || true;
+    docker image rm -f "${image_name}" || true;
+    docker system prune -f || true;
+}
+trap remove_docker_container EXIT
+
+# Generate corresponding --device args based on BUILDKITE_AGENT_NAME
+# Ascend NPU BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards, and agent_idx starts from 1.
+# e.g. atlas-a2-001-1-2cards means this is the 1st agent on the atlas-a2-001 host, and it has 2 NPU cards.
+# returns --device /dev/davinci0 --device /dev/davinci1 +parse_and_gen_devices() { + local input="$1" + local index cards_num + if [[ "$input" =~ ([0-9]+)-([0-9]+)cards$ ]]; then + index="${BASH_REMATCH[1]}" + cards_num="${BASH_REMATCH[2]}" + else + echo "parse error" >&2 + return 1 + fi + + local devices="" + local i=0 + while (( i < cards_num )); do + local dev_idx=$(((index - 1)*cards_num + i )) + devices="$devices --device /dev/davinci${dev_idx}" + ((i++)) + done + + # trim leading space + devices="${devices#"${devices%%[![:space:]]*}"}" + # Output devices: assigned to the caller variable + printf '%s' "$devices" +} + +devices=$(parse_and_gen_devices "${BUILDKITE_AGENT_NAME}") || exit 1 + +# Run the image and execute the Out-Of-Tree (OOT) platform interface test case on Ascend NPU hardware. +# This test checks whether the OOT platform interface is functioning properly in conjunction with +# the hardware plugin vllm-ascend. +model_cache_dir=/mnt/modelscope${agent_idx} +mkdir -p ${model_cache_dir} +docker run \ + ${devices} \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v ${model_cache_dir}:/root/.cache/modelscope \ + --entrypoint="" \ + --name "${container_name}" \ + "${image_name}" \ + bash -c ' + set -e + pytest -v -s tests/e2e/vllm_interface/ +' diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh index b571618f48c2b..cbb2527a4ff0a 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh @@ -61,13 +61,12 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ - && python3 -m pip install --progress-bar off hf-transfer + && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ + && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 echo "--- Python dependencies installed ---" -export VLLM_USE_V1=1 + export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CACHE_PATH= -echo "Using VLLM V1" echo "--- Hardware Information ---" # tpu-info diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index d55a786e41e8b..f022fa3672eeb 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -61,13 +61,12 @@ echo "Results will be stored in: $RESULTS_DIR" echo "--- Installing Python dependencies ---" python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ - && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ - && python3 -m pip install --progress-bar off hf-transfer + && python3 -m pip install --progress-bar off "lm-eval @ 
git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ + && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 echo "--- Python dependencies installed ---" -export VLLM_USE_V1=1 + export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CACHE_PATH= -echo "Using VLLM V1" echo "--- Hardware Information ---" # tpu-info diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 445cd2735c190..250a64fdd071c 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -30,18 +30,19 @@ docker run \ bash -c ' set -e echo $ZE_AFFINITY_MASK - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + pip install tblib==3.1.0 + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager cd tests pytest -v -s v1/core pytest -v -s v1/engine pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py + pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_metrics_reader.py ' diff --git a/.buildkite/scripts/run-benchmarks.sh b/.buildkite/scripts/run-benchmarks.sh index 72812218cb668..51536b36b808d 100644 --- a/.buildkite/scripts/run-benchmarks.sh +++ b/.buildkite/scripts/run-benchmarks.sh @@ -18,7 +18,7 @@ vllm bench throughput --input-len 256 --output-len 256 --output-json throughput_ bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite -python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf & +vllm serve meta-llama/Llama-2-7b-chat-hf & server_pid=$! 
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json diff --git a/.buildkite/scripts/run-prime-rl-test.sh b/.buildkite/scripts/run-prime-rl-test.sh new file mode 100755 index 0000000000000..5b25c358fc4aa --- /dev/null +++ b/.buildkite/scripts/run-prime-rl-test.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Setup script for Prime-RL integration tests +# This script prepares the environment for running Prime-RL tests with nightly vLLM + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git" +PRIME_RL_DIR="${REPO_ROOT}/prime-rl" + +echo "Setting up Prime-RL integration test environment..." + +# Clean up any existing Prime-RL directory +if [ -d "${PRIME_RL_DIR}" ]; then + echo "Removing existing Prime-RL directory..." + rm -rf "${PRIME_RL_DIR}" +fi + +# Install UV if not available +if ! command -v uv &> /dev/null; then + echo "Installing UV package manager..." + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +fi + +# Clone Prime-RL repository at specific branch for reproducible tests +PRIME_RL_BRANCH="integ-vllm-main" +echo "Cloning Prime-RL repository at branch: ${PRIME_RL_BRANCH}..." +git clone --branch "${PRIME_RL_BRANCH}" --single-branch "${PRIME_RL_REPO}" "${PRIME_RL_DIR}" +cd "${PRIME_RL_DIR}" + +echo "Setting up UV project environment..." +export UV_PROJECT_ENVIRONMENT=/usr/local +ln -s /usr/bin/python3 /usr/local/bin/python + +# Remove vllm pin from pyproject.toml +echo "Removing vllm pin from pyproject.toml..." +sed -i '/vllm==/d' pyproject.toml + +# Sync Prime-RL dependencies +echo "Installing Prime-RL dependencies..." +uv sync --inexact && uv sync --inexact --all-extras + +# Verify installation +echo "Verifying installations..." +uv run python -c "import vllm; print(f'vLLM version: {vllm.__version__}')" +uv run python -c "import prime_rl; print('Prime-RL imported successfully')" + +echo "Prime-RL integration test environment setup complete!" + +echo "Running Prime-RL integration tests..." +export WANDB_MODE=offline # this makes this test not require a WANDB_API_KEY +uv run pytest -vs tests/integration/test_rl.py -m gpu + +echo "Prime-RL integration tests completed!" diff --git a/.buildkite/scripts/tpu/quantized_v6e_1.env b/.buildkite/scripts/tpu/quantized_v6e_1.env index bd25c803081a6..ecb98d4516bd5 100644 --- a/.buildkite/scripts/tpu/quantized_v6e_1.env +++ b/.buildkite/scripts/tpu/quantized_v6e_1.env @@ -9,6 +9,6 @@ MAX_NUM_BATCHED_TOKENS=1024 TENSOR_PARALLEL_SIZE=1 MAX_MODEL_LEN=2048 DOWNLOAD_DIR=/mnt/disks/persist -EXPECTED_THROUGHPUT=10.0 +EXPECTED_THROUGHPUT=8.7 INPUT_LEN=1800 OUTPUT_LEN=128 diff --git a/.buildkite/scripts/tpu/run_bm.sh b/.buildkite/scripts/tpu/run_bm.sh index b1e17b438578d..3364fce8e1fdc 100755 --- a/.buildkite/scripts/tpu/run_bm.sh +++ b/.buildkite/scripts/tpu/run_bm.sh @@ -42,7 +42,7 @@ echo "lanching vllm..." 
echo "logging to $VLLM_LOG" echo -VLLM_USE_V1=1 vllm serve $MODEL \ +vllm serve $MODEL \ --seed 42 \ --max-num-seqs $MAX_NUM_SEQS \ --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 745f285c008ad..43aa8c47be299 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -58,14 +58,15 @@ python3 .buildkite/generate_index.py --wheel "$normal_wheel" aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + # only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" fi @@ -74,14 +75,15 @@ fi aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + # only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 20f3ce1adb46d..ebe0602a1b5db 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -6,24 +6,28 @@ # to generate the final pipeline yaml file. # Documentation -# label(str): the name of the test. emoji allowed. -# fast_check(bool): whether to run this on each commit on fastcheck pipeline. -# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline. -# fast_check_only(bool): run this test on fastcheck pipeline only -# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. +# label(str): the name of the test. emojis allowed. +# fast_check(bool): whether to run this on each commit on the fastcheck pipeline. +# torch_nightly(bool): whether to run this on vllm against the torch nightly pipeline. +# fast_check_only(bool): run this test on the fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's a scheduled nightly run. +# soft_fail(bool): allow this step to fail without failing the entire pipeline (useful for flaky or experimental tests). 
# command(str): the single command to run for tests. incompatible with commands. -# commands(list): the list of commands to run for test. incompatbile with command. -# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] -# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 -# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. -# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, -# in this case, commands must be specified. the first command runs on first host, the second +# commands(list): the list of commands to run for the test. incompatible with command. +# mirror_hardwares(list): the list of hardware to run the test on as well. currently only supports [amdexperimental] +# gpu(str): override the GPU selection for the test. default is L4 GPUs. supports a100, b200, h200 +# num_gpus(int): override the number of GPUs for the test. defaults to 1 GPU. currently supports 2,4. +# num_nodes(int): whether to simulate multi-node setup by launching multiple containers on one host, +# in this case, commands must be specified. the first command runs on the first host, the second # command runs on the second host. -# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests -# source_file_dependencies(list): the list of prefix to opt-in the test for, if empty, the test will always run. +# timeout_in_minutes(int): sets a timeout for the step in minutes. if not specified, uses the default timeout. +# parallelism(int): number of parallel jobs to run for this step. enables test sharding using $$BUILDKITE_PARALLEL_JOB +# and $$BUILDKITE_PARALLEL_JOB_COUNT environment variables. +# working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefixes to opt-in the test for, if empty, the test will always run. # When adding a test -# - If the test belong to an existing group, add it there +# - If the test belongs to an existing group, add it there # - If the test is short, add to any existing step # - If the test takes more than 10min, then it is okay to create a new step. # Note that all steps execute in parallel. 
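For illustration, a step that combines the keys documented above could look like the following sketch; the label, file paths, and command are hypothetical and not part of this pipeline:

- label: Example Feature Test # 10min
  timeout_in_minutes: 20
  mirror_hardwares: [amdexperimental]
  gpu: a100
  num_gpus: 2
  optional: true
  soft_fail: true
  working_dir: "/vllm-workspace/tests"
  source_file_dependencies:
  - vllm/example_feature/
  - tests/example_feature
  commands:
  - pytest -v -s example_feature

Steps without special hardware needs can omit gpu and num_gpus, and steps that require no GPU at all can set no_gpu: true, as several CPU-only steps below do.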
@@ -41,29 +45,36 @@ steps: commands: - bash standalone_tests/pytorch_nightly_dependency.sh -- label: Async Engine, Inputs, Utils, Worker Test # 24min +- label: Async Engine, Inputs, Utils, Worker Test # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - - tests/mq_llm_engine - - tests/async_engine + - tests/multimodal + - tests/utils_ + commands: + - pytest -v -s -m 'not cpu_test' multimodal + - pytest -v -s utils_ + +- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins + timeout_in_minutes: 10 + source_file_dependencies: + - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - - tests/utils_ - - tests/worker - tests/standalone_tests/lazy_imports.py + - tests/transformers_utils + no_gpu: true commands: - python3 standalone_tests/lazy_imports.py - - pytest -v -s mq_llm_engine # MQLLMEngine - - pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - - pytest -v -s multimodal - - pytest -v -s utils_ # Utils - - pytest -v -s worker # Worker + - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s transformers_utils -- label: Python-only Installation Test +- label: Python-only Installation Test # 10min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh @@ -71,7 +82,8 @@ steps: commands: - bash standalone_tests/python_only_compile.sh -- label: Basic Correctness Test # 30min +- label: Basic Correctness Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] fast_check: true torch_nightly: true @@ -79,26 +91,26 @@ steps: - vllm/ - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload - - tests/basic_correctness/test_preemption - tests/basic_correctness/test_cumem.py commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Core Test # 10min - mirror_hardwares: [amdexperimental] +- label: Entrypoints Unit Tests # 5min + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" fast_check: true source_file_dependencies: - - vllm/core - - vllm/distributed - - tests/core + - vllm/entrypoints + - tests/entrypoints/ commands: - - pytest -v -s core + - pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling -- label: Entrypoints Test (LLM) # 40min +- label: Entrypoints Integration Test (LLM) # 30min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true @@ -109,13 +121,12 @@ steps: - tests/entrypoints/offline_mode commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py - - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_generate.py # it needs 
a clean process - - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests -- label: Entrypoints Test (API Server) # 40min +- label: Entrypoints Integration Test (API Server) # 100min + timeout_in_minutes: 130 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true @@ -127,16 +138,29 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ - pytest -v -s entrypoints/test_chat_utils.py -- label: Distributed Tests (4 GPUs) # 10min +- label: Entrypoints Integration Test (Pooling) + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + +- label: Distributed Tests (4 GPUs) # 35min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: - vllm/distributed/ - - vllm/core/ - tests/distributed/test_utils - tests/distributed/test_pynccl - tests/distributed/test_events @@ -144,28 +168,34 @@ steps: - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - - tests/v1/test_async_llm_dp.py - - tests/v1/test_external_lb_dp.py - - tests/v1/test_internal_lb_dp.py - - tests/v1/test_hybrid_lb_dp.py + - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py + - tests/distributed/test_symm_mem_allreduce.py commands: - # test with tp=2 and external_dp=2 - - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - # test with tp=2 and pp=2 + # test with torchrun tp=2 and pp=2 - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=4 and dp=1 + - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2, pp=2 and dp=1 + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=1 and dp=4 with ep + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2 and dp=2 with ep + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py # test with internal dp - python3 
../examples/offline_inference/data_parallel.py --enforce-eager - - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py - - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_events.py + - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference @@ -173,7 +203,8 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd -- label: EPLB Algorithm Test +- label: EPLB Algorithm Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: - vllm/distributed/eplb @@ -182,6 +213,7 @@ steps: - pytest -v -s distributed/test_eplb_algo.py - label: EPLB Execution Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -190,26 +222,26 @@ steps: commands: - pytest -v -s distributed/test_eplb_execute.py -- label: Metrics, Tracing Test # 10min +- label: Metrics, Tracing Test # 12min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] num_gpus: 2 source_file_dependencies: - vllm/ - - tests/metrics - - tests/tracing + - tests/v1/tracing commands: - - pytest -v -s metrics - "pip install \ 'opentelemetry-sdk>=1.26.0' \ 'opentelemetry-api>=1.26.0' \ 'opentelemetry-exporter-otlp>=1.26.0' \ 'opentelemetry-semantic-conventions-ai>=0.4.1'" - - pytest -v -s tracing + - pytest -v -s v1/tracing ##### fast check tests ##### ##### 1 GPU test ##### -- label: Regression Test # 5min +- label: Regression Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -219,7 +251,8 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 10min +- label: Engine Test # 25min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -234,36 +267,66 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: V1 Test +- label: V1 Test e2e + engine # 30min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. 
+ - pytest -v -s v1/e2e + - pytest -v -s v1/engine + +- label: V1 Test entrypoints # 35min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1/entrypoints + +- label: V1 Test others # 42min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 commands: # split the test to avoid interference - - pytest -v -s v1/core - - pytest -v -s v1/engine - - pytest -v -s v1/entrypoints + - pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s v1/executor + - pytest -v -s v1/kv_offload - pytest -v -s v1/sample - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode - - pytest -v -s v1/kv_connector/unit - - pytest -v -s v1/metrics - - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_utils.py + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - - pytest -v -s v1/test_metrics_reader.py - # TODO: accuracy does not match, whether setting - # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - - pytest -v -s v1/e2e + - pytest -v -s v1/test_request.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine -- label: Examples Test # 25min +- label: V1 Test others (CPU) # 5 mins + source_file_dependencies: + - vllm/ + - tests/v1 + no_gpu: true + commands: + # split the test to avoid interference + - pytest -v -s -m 'cpu_test' v1/core + - pytest -v -s v1/structured_output + - pytest -v -s v1/test_serial_utils.py + - pytest -v -s -m 'cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'cpu_test' v1/metrics + + +- label: Examples Test # 30min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" source_file_dependencies: @@ -280,15 +343,16 @@ steps: - python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py - - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 
--dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 -- label: Platform Tests (CUDA) +- label: Platform Tests (CUDA) # 4min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -296,7 +360,8 @@ steps: commands: - pytest -v -s cuda/test_cuda_context.py -- label: Samplers Test # 36min +- label: Samplers Test # 56min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers @@ -307,15 +372,23 @@ steps: - pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: LoRA Test %N # 15min each +- label: LoRA Test %N # 20min each + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + commands: + - pytest -v -s lora \ + --shard-id=$$BUILDKITE_PARALLEL_JOB \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --ignore=lora/test_chatglm3_tp.py \ + --ignore=lora/test_llama_tp.py \ + --ignore=lora/test_llm_with_multi_loras.py parallelism: 4 -- label: PyTorch Compilation Unit Tests +- label: PyTorch Compilation Unit Tests # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -325,13 +398,15 @@ steps: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py - pytest -v -s compile/test_fusion_attn.py + - pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - - pytest -v -s compile/test_sequence_parallelism.py - - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py + - pytest -v -s compile/test_noop_elimination.py + - pytest -v -s compile/test_aot_compile.py -- label: PyTorch Fullgraph Smoke Test # 9min +- label: PyTorch Fullgraph Smoke Test # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -339,13 +414,10 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py - # these tests need to be separated, cannot combine - - pytest -v -s compile/piecewise/test_simple.py - - pytest -v -s compile/piecewise/test_toy_llama.py - - pytest -v -s compile/piecewise/test_full_cudagraph.py - - pytest -v -s compile/piecewise/test_multiple_graphs.py + - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 18min +- label: PyTorch Fullgraph Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -354,15 +426,18 @@ steps: commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Core Operation Test +- label: Kernels Core Operation Test # 48min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - tests/kernels/core + - tests/kernels/test_top_k_per_row.py commands: - - pytest -v -s kernels/core + - pytest -v -s kernels/core kernels/test_top_k_per_row.py -- label: Kernels Attention Test %N +- label: Kernels Attention Test %N # 23min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/attention/ @@ -373,7 +448,8 @@ steps: - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB 
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Quantization Test %N +- label: Kernels Quantization Test %N # 64min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/quantization/ @@ -383,48 +459,44 @@ steps: - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels MoE Test %N +- label: Kernels MoE Test %N # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/quantization/cutlass_w8a8/moe/ - csrc/moe/ - tests/kernels/moe - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ commands: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Mamba Test +- label: Kernels Mamba Test # 31min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba + - vllm/model_executor/layers/mamba/ops commands: - pytest -v -s kernels/mamba -- label: Tensorizer Test # 11min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/model_executor/model_loader - - tests/tensorizer_loader - - tests/entrypoints/openai/test_tensorizer_entrypoint.py - commands: - - apt-get update && apt-get install -y curl libsodium23 - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s tensorizer_loader - - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py - -- label: Model Executor Test +- label: Model Executor Test # 23min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor - tests/model_executor + - tests/entrypoints/openai/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py -- label: Benchmarks # 9min +- label: Benchmarks # 11min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite" source_file_dependencies: @@ -432,7 +504,8 @@ steps: commands: - bash scripts/run-benchmarks.sh -- label: Benchmarks CLI Test # 10min +- label: Benchmarks CLI Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -440,7 +513,8 @@ steps: commands: - pytest -v -s benchmarks/ -- label: Quantization Test +- label: Quantization Test # 70min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -448,11 +522,16 @@ steps: - tests/quantization commands: # temporary install here since we need nightly, will move to requirements/test.in - # after torchao 0.12 release - - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 - - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization + # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved + - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ - label: LM Eval Small Models # 53min + timeout_in_minutes: 75 mirror_hardwares: 
[amdexperimental] source_file_dependencies: - csrc/ @@ -460,7 +539,8 @@ steps: commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 -- label: OpenAI API correctness +- label: OpenAI API correctness # 22min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -469,51 +549,109 @@ steps: commands: # LMEval+Transcription WER check - pytest -s entrypoints/openai/correctness/ -- label: Encoder Decoder tests # 5min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/encoder_decoder - commands: - - pytest -v -s encoder_decoder - -- label: OpenAI-Compatible Tool Use # 20 min +- label: OpenAI-Compatible Tool Use # 23 min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] fast_check: false source_file_dependencies: - vllm/ - tests/tool_use - - tests/mistral_tool_use commands: - - pytest -v -s tool_use - - pytest -v -s mistral_tool_use + - pytest -v -s -m 'not cpu_test' tool_use + +- label: OpenAI-Compatible Tool Use (CPU) # 5 mins + timeout_in_minutes: 10 + source_file_dependencies: + - vllm/ + - tests/tool_use + no_gpu: true + commands: + - pytest -v -s -m 'cpu_test' tool_use ##### models test ##### -- label: Basic Models Test # 24min +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - - tests/models + - tests/models/test_initialization.py commands: - - pytest -v -s models/test_transformers.py - - pytest -v -s models/test_registry.py - - pytest -v -s models/test_utils.py - - pytest -v -s models/test_vision.py - - pytest -v -s models/test_initialization.py + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset -- label: Language Models Test (Standard) +- label: Basic Models Tests (Extra Initialization) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) 
Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py \ + -k 'not test_can_initialize_small_subset' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + commands: + - pytest -v -s models/test_transformers.py models/test_registry.py + +- label: Basic Models Test (Other CPU) # 5min + timeout_in_minutes: 10 + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_utils.py + - tests/models/test_vision.py + no_gpu: true + commands: + - pytest -v -s models/test_utils.py models/test_vision.py + +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: + # Test standard language models, excluding a subset of slow tests - pip freeze | grep -E 'torch' - - pytest -v -s models/language -m core_model + - pytest -v -s models/language -m 'core_model and (not slow_test)' -- label: Language Models Test (Hybrid) # 35 min +- label: Language Models Tests (Extra Standard) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -524,9 +662,15 @@ steps: # Note: also needed to run plamo2 model in vLLM - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - - pytest -v -s models/language/generation -m hybrid_model + # Shard hybrid language model tests + - pytest -v -s models/language/generation \ + -m hybrid_model \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 -- label: Language Models Test (Extended Generation) # 1hr20min +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: @@ -537,7 +681,18 @@ steps: - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + - label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] optional: true 
source_file_dependencies: @@ -546,16 +701,27 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' -- label: Multi-Modal Processor Test +- label: Language Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test + +- label: Multi-Modal Processor Test # 44min + timeout_in_minutes: 60 source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py - - pytest -v -s models/multimodal/processing/test_tensor_schema.py + - pytest -v -s models/multimodal/processing -- label: Multi-Modal Models Test (Standard) +- label: Multi-Modal Models Test (Standard) # 60min + timeout_in_minutes: 80 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -565,7 +731,7 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 mirror_hardwares: [amdexperimental] @@ -597,7 +763,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' -- label: Quantized Models Test +- label: Quantized Models Test # 45 min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers/quantization @@ -621,13 +788,16 @@ steps: commands: - pip install --upgrade git+https://github.com/huggingface/transformers - pytest -v -s tests/models/test_initialization.py + - pytest -v -s tests/models/test_transformers.py - pytest -v -s tests/models/multimodal/processing/ - pytest -v -s tests/models/multimodal/test_mapping.py - python3 examples/offline_inference/basic/chat.py - - python3 examples/offline_inference/audio_language.py --model-type whisper - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl + # Whisper needs spawn method to avoid deadlock + - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test +- label: Blackwell Test # 38 min + timeout_in_minutes: 60 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -649,23 +819,71 @@ steps: # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - - pytest -v -s tests/kernels/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s 
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py + - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py # Fusion - pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + +- label: Blackwell GPT-OSS Eval + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true # run on nightlies + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + +- label: Blackwell Quantized MoE Test + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - tests/quantization/test_blackwell_moe.py + - vllm/model_executor/models/deepseek_v2.py + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/models/llama4.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization/compressed_tensors + - vllm/model_executor/layers/quantization/modelopt.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - pytest -s -v tests/quantization/test_blackwell_moe.py + +- label: Blackwell LM Eval Small Models + timeout_in_minutes: 120 + gpu: b200 + optional: true # run on nightlies + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 ##### 1 GPU test ##### ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -675,8 +893,11 @@ steps: commands: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py - label: 2 Node Tests (4 GPUs in total) # 16min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -700,47 +921,61 @@ steps: - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code -- label: Distributed Tests (2 GPUs) # 40min +- label: Distributed Tests (2 GPUs) # 68min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: + - vllm/compilation/ - vllm/distributed/ - vllm/engine/ - 
vllm/executor/ - - vllm/model_executor/models/ - - tests/distributed/ - - vllm/compilation - vllm/worker/worker_base.py - - vllm/worker/worker.py - - vllm/worker/model_runner.py - - entrypoints/llm/test_collective_rpc.py - - tests/v1/test_async_llm_dp.py - - tests/v1/test_external_lb_dp.py - - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ + - vllm/v1/worker/ + - tests/compile/test_basic_correctness.py + - tests/compile/test_wrapper.py + - tests/distributed/ + - tests/entrypoints/llm/test_collective_rpc.py + - tests/v1/distributed + - tests/v1/entrypoints/openai/test_multi_api_servers.py + - tests/v1/shutdown + - tests/v1/worker/test_worker_memory_snapshot.py commands: - - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - pytest -v -s distributed/test_sequence_parallel.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown + - pytest -v -s v1/worker/test_worker_memory_snapshot.py + +- label: Distributed Model Tests (2 GPUs) # 37min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/model_executor/model_loader/sharded_state_loader.py + - vllm/model_executor/models/ + - tests/basic_correctness/ + - tests/model_executor/model_loader/test_sharded_state_loader.py + - tests/models/ + commands: - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py # Avoid importing model tests that cause CUDA reinitialization error - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)' - - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' - # test sequence parallel - - pytest -v -s distributed/test_sequence_parallel.py - # this test fails consistently. 
- # TODO: investigate and fix - - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - - pytest -v -s models/multimodal/generation/test_maverick.py + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' - label: Plugin Tests (2 GPUs) # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -753,6 +988,11 @@ steps: - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model @@ -761,7 +1001,8 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins -- label: Pipeline Parallelism Test # 45min +- label: Pipeline + Context Parallelism Test # 45min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -775,7 +1016,8 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py -- label: LoRA TP Test (Distributed) +- label: LoRA TP Test (Distributed) # 17 min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] num_gpus: 4 source_file_dependencies: @@ -789,13 +1031,15 @@ steps: # requires multi-GPU testing for validation. 
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - - pytest -v -s -x lora/test_multi_loras_with_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py - label: Weight Loading Multiple GPU Test # 33min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 + optional: true source_file_dependencies: - vllm/ - tests/weight_loading @@ -844,9 +1088,36 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 -- label: Qwen MoE EP Test # optional +##### H200 test ##### +- label: Distributed Tests (H200) # optional gpu: h200 optional: true + working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/compile/test_async_tp.py + - pytest -v -s tests/compile/test_sequence_parallelism.py + - pytest -v -s tests/distributed/test_context_parallel.py + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + +##### B200 test ##### +- label: Distributed Tests (B200) # optional + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + +##### RL Integration Tests ##### +- label: Prime-RL Integration Test # 15min + timeout_in_minutes: 30 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + source_file_dependencies: + - vllm/ + - .buildkite/scripts/run-prime-rl-test.sh + commands: + - bash .buildkite/scripts/run-prime-rl-test.sh diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000000..bc6342956109b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,32 @@ +[run] +source = vllm +omit = + */tests/* + */test_* + */__pycache__/* + */build/* + */dist/* + */vllm.egg-info/* + */third_party/* + */examples/* + */benchmarks/* + */docs/* + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + if settings.DEBUG + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + class .*\bProtocol\): + @(abc\.)?abstractmethod + +[html] +directory = htmlcov + +[xml] +output = coverage.xml diff --git a/.github/.bc-linter.yml b/.github/.bc-linter.yml new file mode 100644 index 0000000000000..443dfa45af22c --- /dev/null +++ b/.github/.bc-linter.yml @@ -0,0 +1,24 @@ +# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md +version: 1 +paths: +# We temporarily disable globally, and will only enable with `annotations.include` +# include: +# - "vllm/v1/attention/*.py" +# - "vllm/v1/core/*.py" +exclude: + - "**/*.py" + +scan: + functions: true # check free functions and methods + classes: true # check classes/dataclasses + public_only: true # ignore names starting with "_" at any level + +annotations: + include: # decorators that force‑include a symbol + - name: "bc_linter_include" # matched by simple name or dotted suffix + propagate_to_members: false # for classes, include methods/inner classes + exclude: # 
decorators that force‑exclude a symbol + - name: "bc_linter_skip" # matched by simple name or dotted suffix + propagate_to_members: true # for classes, exclude methods/inner classes + +excluded_violations: [] # e.g. ["ParameterRenamed", "FieldTypeChanged"] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ce9590f02ce71..dbcad3aa308f5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,64 +2,88 @@ # for more info about CODEOWNERS file # This lists cover the "core" components of vLLM that require careful review +/vllm/attention @LucasWilkinson /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn +/vllm/model_executor/layers/fused_moe @mgoin +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 /vllm/model_executor/layers/mamba @tdoublep -/vllm/multimodal @DarkLight1337 @ywang96 +/vllm/model_executor/model_loader @22quinn +/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee -/vllm/reasoning @aarnphm -/vllm/entrypoints @aarnphm +/vllm/reasoning @aarnphm @chaunceyjiang +/vllm/entrypoints @aarnphm @chaunceyjiang /vllm/compilation @zou3519 @youkaichao @ProExpertProg +/vllm/distributed/kv_transfer @NickLucche @ApostaC CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, # so spam a lot of people /vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg +/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat -/vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/attention @LucasWilkinson +/vllm/v1/attention/backends/flashinfer.py @mgoin /vllm/v1/attention/backends/triton_attn.py @tdoublep +/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC +/vllm/v1/sample @22quinn @houseroad @njhill +/vllm/v1/spec_decode @benchislett @luccafong +/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett +/vllm/v1/kv_cache_interface.py @heheda12345 +/vllm/v1/offloading @ApostaC # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo -/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm -/tests/kernels 
@tlrmchlsmth @WoosukKwon @yewentao256 +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche +/tests/evals @mgoin +/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 -/tests/multimodal @DarkLight1337 @ywang96 -/tests/prefix_caching @comaniac @KuntaiDu +/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm +/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee /tests/models/language/generation/test_hybrid.py @tdoublep +/tests/v1/kv_connector/nixl_integration @NickLucche +/tests/v1/kv_connector @ApostaC +/tests/v1/offloading @ApostaC + +# Transformers backend +/vllm/model_executor/models/transformers.py @hmellor +/tests/models/test_transformers.py @hmellor # Docs -/docs @hmellor +/docs/mkdocs @hmellor +/docs/**/*.yml @hmellor +/requirements/docs.txt @hmellor +.readthedocs.yaml @hmellor mkdocs.yaml @hmellor +# Linting +.markdownlint.yaml @hmellor +.pre-commit-config.yaml @hmellor +/tools/pre_commit @hmellor + # CPU -/vllm/v1/worker/^cpu @bigPYJ1151 +/vllm/v1/worker/cpu* @bigPYJ1151 /csrc/cpu @bigPYJ1151 /vllm/platforms/cpu.py @bigPYJ1151 /cmake/cpu_extension.cmake @bigPYJ1151 /docker/Dockerfile.cpu @bigPYJ1151 # Intel GPU -/vllm/v1/worker/^xpu @jikunshang +/vllm/v1/worker/xpu* @jikunshang /vllm/platforms/xpu.py @jikunshang /docker/Dockerfile.xpu @jikunshang @@ -67,6 +91,9 @@ mkdocs.yaml @hmellor /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow /vllm/model_executor/models/qwen* @sighingnow +# MTP-specific files +/vllm/model_executor/models/deepseek_mtp.py @luccafong + # Mistral-specific files /vllm/model_executor/models/mistral*.py @patrickvonplaten /vllm/model_executor/models/mixtral*.py @patrickvonplaten @@ -79,4 +106,18 @@ mkdocs.yaml @hmellor /vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep /vllm/attention/ops/triton_unified_attention.py @tdoublep +# ROCm related: specify owner with write access to notify AMD folks for careful code review +/docker/Dockerfile.rocm* @gshtras +/vllm/v1/attention/backends/rocm*.py @gshtras +/vllm/v1/attention/backends/mla/rocm*.py @gshtras +/vllm/attention/ops/rocm*.py @gshtras +/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras +# TPU +/vllm/v1/worker/tpu* @NickLucche +/vllm/platforms/tpu.py @NickLucche +/vllm/v1/sample/tpu @NickLucche +/vllm/tests/v1/tpu @NickLucche + +# KVConnector installation files +/requirements/kv_connectors.txt @NickLucche diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml index 7ee57c42895ca..c0e009855964a 100644 --- a/.github/ISSUE_TEMPLATE/750-RFC.yml +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -43,10 +43,6 @@ body: Any other things you would like to mention. validations: required: false -- type: markdown - attributes: - value: > - Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit). 
- type: checkboxes id: askllm attributes: diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1b30c1292df85..8043df65d5585 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT ## Test Result -## (Optional) Documentation Update - ---
Essential Elements of an Effective PR Description Checklist @@ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT - [ ] The test plan, such as providing test command. - [ ] The test results, such as pasting the results comparison before and after, or e2e results - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. +- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/mergify.yml b/.github/mergify.yml index 495d207d44260..de1a8314a4ecd 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -2,6 +2,7 @@ pull_request_rules: - name: label-documentation description: Automatically apply documentation label conditions: + - label != stale - or: - files~=^[^/]+\.md$ - files~=^docs/ @@ -10,10 +11,13 @@ pull_request_rules: label: add: - documentation + comment: + message: "Documentation preview: https://vllm--{{number}}.org.readthedocs.build/en/{{number}}/" - name: label-ci-build description: Automatically apply ci/build label conditions: + - label != stale - or: - files~=^\.github/ - files~=\.buildkite/ @@ -30,6 +34,7 @@ pull_request_rules: - name: label-deepseek description: Automatically apply deepseek label conditions: + - label != stale - or: - files~=^examples/.*deepseek.*\.py - files~=^tests/.*deepseek.*\.py @@ -46,6 +51,7 @@ pull_request_rules: - name: label-frontend description: Automatically apply frontend label conditions: + - label != stale - files~=^vllm/entrypoints/ actions: label: @@ -55,6 +61,7 @@ pull_request_rules: - name: label-llama description: Automatically apply llama label conditions: + - label != stale - or: - files~=^examples/.*llama.*\.py - files~=^tests/.*llama.*\.py @@ -70,6 +77,7 @@ pull_request_rules: - name: label-multi-modality description: Automatically apply multi-modality label conditions: + - label != stale - or: - files~=^vllm/multimodal/ - files~=^tests/multimodal/ @@ -83,6 +91,7 @@ pull_request_rules: - name: label-new-model description: Automatically apply new-model label conditions: + - label != stale - and: - files~=^vllm/model_executor/models/ - files=vllm/model_executor/models/registry.py @@ -94,6 +103,7 @@ pull_request_rules: - name: label-performance description: Automatically apply performance label conditions: + - label != stale - or: - files~=^benchmarks/ - files~=^vllm/benchmarks/ @@ -107,6 +117,7 @@ pull_request_rules: - name: label-qwen description: Automatically apply qwen label conditions: + - label != stale - or: - files~=^examples/.*qwen.*\.py - files~=^tests/.*qwen.*\.py @@ -121,12 +132,20 @@ pull_request_rules: - name: label-gpt-oss description: Automatically apply gpt-oss label conditions: + - label != stale - or: - files~=^examples/.*gpt[-_]?oss.*\.py - files~=^tests/.*gpt[-_]?oss.*\.py + - files~=^tests/entrypoints/openai/test_response_api_with_harmony.py + - files~=^tests/entrypoints/test_context.py - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py + - files~=^vllm/entrypoints/harmony_utils.py + - files~=^vllm/entrypoints/tool_server.py + - files~=^vllm/entrypoints/tool.py + - files~=^vllm/entrypoints/context.py - title~=(?i)gpt[-_]?oss + - title~=(?i)harmony actions: label: add: @@ -135,6 +154,7 @@ pull_request_rules: - name: label-rocm description: Automatically apply rocm label conditions: + - label != stale - or: - files~=^csrc/rocm/ - files~=^docker/Dockerfile.rocm @@ -155,6 +175,7 @@ pull_request_rules: - name: label-structured-output description: Automatically apply structured-output label conditions: + - label != stale - or: - files~=^benchmarks/structured_schemas/ - files=benchmarks/benchmark_serving_structured_output.py @@ -164,7 +185,7 @@ pull_request_rules: - files=examples/online_serving/openai_chat_completion_structured_outputs.py - 
files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py - files~=^tests/v1/structured_output/ - - files=tests/v1/entrypoints/llm/test_guided_generate.py + - files=tests/v1/entrypoints/llm/test_struct_output_generate.py - files~=^vllm/v1/structured_output/ actions: label: @@ -174,6 +195,7 @@ pull_request_rules: - name: label-speculative-decoding description: Automatically apply speculative-decoding label conditions: + - label != stale - or: - files~=^vllm/v1/spec_decode/ - files~=^tests/v1/spec_decode/ @@ -189,6 +211,7 @@ pull_request_rules: - name: label-v1 description: Automatically apply v1 label conditions: + - label != stale - or: - files~=^vllm/v1/ - files~=^tests/v1/ @@ -201,6 +224,7 @@ pull_request_rules: description: Automatically apply tpu label # Keep this list in sync with `label-tpu-remove` conditions conditions: + - label != stale - or: - files~=tpu.py - files~=_tpu @@ -216,6 +240,7 @@ pull_request_rules: description: Automatically remove tpu label # Keep this list in sync with `label-tpu` conditions conditions: + - label != stale - and: - -files~=tpu.py - -files~=_tpu @@ -230,9 +255,9 @@ pull_request_rules: - name: label-tool-calling description: Automatically add tool-calling label conditions: + - label != stale - or: - files~=^tests/tool_use/ - - files~=^tests/mistral_tool_use/ - files~=^tests/entrypoints/openai/tool_parsers/ - files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py - files~=^vllm/entrypoints/openai/tool_parsers/ @@ -249,8 +274,9 @@ pull_request_rules: - name: ping author on conflicts and add 'needs-rebase' label conditions: - - conflict - - -closed + - label != stale + - conflict + - -closed actions: label: add: @@ -264,20 +290,55 @@ pull_request_rules: - name: assign reviewer for tensorizer changes conditions: + - label != stale + - or: - files~=^vllm/model_executor/model_loader/tensorizer.py - files~=^vllm/model_executor/model_loader/tensorizer_loader.py - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py - - files~=^tests/tensorizer_loader/ + - files~=^tests/model_executor/model_loader/tensorizer_loader/ actions: assign: users: - "sangstar" +- name: assign reviewer for modelopt changes + conditions: + - label != stale + - or: + - files~=^vllm/model_executor/layers/quantization/modelopt\.py$ + - files~=^vllm/model_executor/layers/quantization/__init__\.py$ + - files~=^tests/models/quantization/test_modelopt\.py$ + - files~=^tests/quantization/test_modelopt\.py$ + - files~=^tests/models/quantization/test_nvfp4\.py$ + - files~=^docs/features/quantization/modelopt\.md$ + actions: + assign: + users: + - "Edwardf0t1" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - - -conflict - - -closed + - -conflict + - -closed actions: label: remove: - needs-rebase + +- name: label-kv-connector + description: Automatically apply kv-connector label + conditions: + - label != stale + - or: + - files~=^examples/online_serving/disaggregated[^/]*/.* + - files~=^examples/offline_inference/disaggregated[^/]*/.* + - files~=^examples/others/lmcache/ + - files~=^tests/v1/kv_connector/ + - files~=^vllm/distributed/kv_transfer/ + - title~=(?i)\bP/?D\b + - title~=(?i)NIXL + - title~=(?i)LMCache + actions: + label: + add: + - kv-connector \ No newline at end of file diff --git a/.github/scale-config.yml b/.github/scale-config.yml new file mode 100644 index 0000000000000..c41a3ee3eb196 --- /dev/null +++ b/.github/scale-config.yml @@ -0,0 +1,21 @@ +# scale-config.yml: +# Powers what instance types are 
available for GHA auto-scaled +# runners. Runners listed here will be available as self hosted +# runners, configuration is directly pulled from the main branch. +# runner_types: +# runner_label: +# instance_type: m4.large +# os: linux +# # min_available defaults to the global cfg in the ALI Terraform +# min_available: undefined +# # when max_available value is not defined, no max runners is enforced +# max_available: undefined +# disk_size: 50 +# is_ephemeral: true + +runner_types: + linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: true + os: linux diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index 315042fbf5cf4..d8bbedef3174b 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add label - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | github.rest.issues.addLabels({ diff --git a/.github/workflows/bc-lint.yml b/.github/workflows/bc-lint.yml new file mode 100644 index 0000000000000..823695a921321 --- /dev/null +++ b/.github/workflows/bc-lint.yml @@ -0,0 +1,29 @@ +name: BC Lint + +on: + pull_request: + types: + - opened + - synchronize + - reopened + - labeled + - unlabeled + +jobs: + bc_lint: + if: github.repository_owner == 'vllm-project' + runs-on: ubuntu-latest + steps: + - name: Run BC Lint Action + uses: pytorch/test-infra/.github/actions/bc-lint@main + with: + repo: ${{ github.event.pull_request.head.repo.full_name }} + base_sha: ${{ github.event.pull_request.base.sha }} + head_sha: ${{ github.event.pull_request.head.sha }} + suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }} + docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter' + config_dir: .github + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index d5c6b8d43a6ef..c3e132a536a42 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.12' diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml new file mode 100644 index 0000000000000..c2b17abe811cd --- /dev/null +++ b/.github/workflows/issue_autolabel.yml @@ -0,0 +1,309 @@ +name: Label issues based on keywords +on: + issues: + types: [opened, edited, reopened] +permissions: + issues: write # needed so the workflow can add labels + contents: read +concurrency: + group: issue-labeler-${{ github.event.issue.number }} + cancel-in-progress: true +jobs: + add-labels: + runs-on: ubuntu-latest + steps: + - name: Label issues based on keywords + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + script: | + // Configuration: Add new labels and keywords here + const labelConfig = { + rocm: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "composable kernel", + searchIn: "both" + }, + { + term: "rccl", 
+ searchIn: "body" // only search in body + }, + { + term: "migraphx", + searchIn: "title" // only search in title + }, + { + term: "hipgraph", + searchIn: "both" + }, + { + term: "ROCm System Management Interface", + searchIn: "body" + }, + ], + + // Substring search - matches anywhere in text (partial matches) + substrings: [ + { + term: "VLLM_ROCM_", + searchIn: "both" + }, + { + term: "aiter", + searchIn: "title" + }, + { + term: "rocm", + searchIn: "title" + }, + { + term: "amd", + searchIn: "title" + }, + { + term: "hip-", + searchIn: "both" + }, + { + term: "gfx", + searchIn: "both" + }, + { + term: "cdna", + searchIn: "both" + }, + { + term: "rdna", + searchIn: "both" + }, + { + term: "torch_hip", + searchIn: "body" // only in body + }, + { + term: "_hip", + searchIn: "both" + }, + { + term: "hip_", + searchIn: "both" + }, + + // ROCm tools and libraries + { + term: "hipify", + searchIn: "both" + }, + ], + + // Regex patterns - for complex pattern matching + regexPatterns: [ + { + pattern: "\\bmi\\d{3}[a-z]*\\b", + description: "AMD GPU names (mi + 3 digits + optional letters)", + flags: "gi", + searchIn: "both" // "title", "body", or "both" + } + ], + }, + }; + + // Helper function to create regex based on search type + function createSearchRegex(term, type) { + // Escape special regex characters in the term + const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + + switch (type) { + case 'keyword': + // Word boundary search - matches whole words only + return new RegExp(`\\b${escapedTerm}\\b`, "gi"); + case 'substring': + // Substring search - matches anywhere in the text + return new RegExp(escapedTerm, "gi"); + default: + throw new Error(`Unknown search type: ${type}`); + } + } + + // Helper function to find matching terms in text with line information + function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { + const matches = []; + const lines = text.split('\n'); + + for (const termConfig of searchTerms) { + let regex; + let term, searchIn, pattern, description, flags; + + // Handle different input formats (string or object) + if (typeof termConfig === 'string') { + term = termConfig; + searchIn = 'both'; // default + } else { + term = termConfig.term; + searchIn = termConfig.searchIn || 'both'; + pattern = termConfig.pattern; + description = termConfig.description; + flags = termConfig.flags; + } + + // Skip if this term shouldn't be searched in the current location + if (searchIn !== 'both' && searchIn !== searchLocation) { + continue; + } + + // Create appropriate regex + if (searchType === 'regex') { + regex = new RegExp(pattern, flags || "gi"); + } else { + regex = createSearchRegex(term, searchType); + } + + const termMatches = []; + + // Check each line for matches + lines.forEach((line, lineIndex) => { + const lineMatches = line.match(regex); + if (lineMatches) { + lineMatches.forEach(match => { + termMatches.push({ + match: match, + lineNumber: lineIndex + 1, + lineContent: line.trim(), + searchType: searchType, + searchLocation: searchLocation, + originalTerm: term || pattern, + description: description, + // Show context around the match in the line + context: line.length > 100 ? + line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), + line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' 
+ : line.trim() + }); + }); + } + }); + + if (termMatches.length > 0) { + matches.push({ + term: term || (description || pattern), + searchType: searchType, + searchLocation: searchLocation, + searchIn: searchIn, + pattern: pattern, + matches: termMatches, + count: termMatches.length + }); + } + } + + return matches; + } + + // Helper function to check if label should be added + async function processLabel(labelName, config) { + const body = context.payload.issue.body || ""; + const title = context.payload.issue.title || ""; + + core.notice(`Processing label: ${labelName}`); + core.notice(`Issue Title: "${title}"`); + core.notice(`Issue Body length: ${body.length} characters`); + + let shouldAddLabel = false; + let allMatches = []; + let reason = ''; + + const keywords = config.keywords || []; + const substrings = config.substrings || []; + const regexPatterns = config.regexPatterns || []; + + core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); + + // Search in title + if (title.trim()) { + core.notice(`Searching in title: "${title}"`); + + const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); + const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); + const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); + + allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); + } + + // Search in body + if (body.trim()) { + core.notice(`Searching in body (${body.length} characters)`); + + const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); + const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); + const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); + + allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); + } + + if (allMatches.length > 0) { + core.notice(`Found ${allMatches.length} matching term(s):`); + + for (const termMatch of allMatches) { + const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; + const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; + + if (termMatch.searchType === 'regex') { + core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } else { + core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); + } + + // Show details for each match + termMatch.matches.forEach((match, index) => { + core.notice(` ${index + 1}. 
Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); + if (match.description) { + core.notice(` Description: ${match.description}`); + } + core.notice(` Context: ${match.context}`); + if (match.lineContent !== match.context) { + core.notice(` Full line: ${match.lineContent}`); + } + }); + } + + shouldAddLabel = true; + const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); + const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); + const bodyMatches = allMatches.filter(t => t.searchLocation === 'body').reduce((sum, t) => sum + t.count, 0); + const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); + const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); + const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); + + reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; + } + + core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); + core.notice(`Reason: ${reason || 'No matching terms found'}`); + + if (shouldAddLabel) { + const existingLabels = context.payload.issue.labels.map(l => l.name); + if (!existingLabels.includes(labelName)) { + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: [labelName], + }); + core.notice(`Label "${labelName}" added. ${reason}`); + return true; + } + core.notice(`Label "${labelName}" already present.`); + return false; + } + + core.notice(`No matching terms found for label "${labelName}".`); + return false; + } + + // Process all configured labels + const processLabels = Object.entries(labelConfig) + .map(([labelName, config]) => processLabel(labelName, config)); + const labelsAdded = await Promise.all(processLabels); + const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); + core.notice(`Processing complete. 
${numLabelsAdded} label(s) added.`); \ No newline at end of file diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 195579f206a2f..e21d13b8161f3 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 1ee605dc7bb0d..8884359fa0ce4 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Remind to run full CI on PR - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | try { diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 656f3d3fa7bc4..dca3089f496c9 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 + - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/.gitignore b/.gitignore index 465935d488f84..b1df673e83ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* -# triton jit +# triton jit .triton # Byte-compiled / optimized / DLL files @@ -177,6 +177,14 @@ cython_debug/ # VSCode .vscode/ +# Claude +CLAUDE.md +.claude/ + +# Codex +AGENTS.md +.codex/ + # DS Store .DS_Store @@ -209,4 +217,4 @@ shellcheck*/ csrc/moe/marlin_moe_wna16/kernel_* # Ignore ep_kernels_workspace folder -ep_kernels_workspace/ \ No newline at end of file +ep_kernels_workspace/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 612b290e88d46..832c3edcdc7fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,30 +6,18 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.14.0 hooks: - - id: ruff + - id: ruff-check args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos - rev: v1.34.0 + rev: v1.38.1 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v20.1.3 + rev: v21.1.2 hooks: - id: clang-format exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' @@ -46,10 +34,10 @@ repos: hooks: - id: actionlint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.6.17 + rev: 0.9.1 
hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] files: ^requirements/test\.(in|txt)$ - repo: local hooks: @@ -60,38 +48,32 @@ repos: files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy for local Python installation - entry: tools/mypy.sh 0 "local" - language: python - types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] + entry: python tools/pre_commit/mypy.py 0 "local" stages: [pre-commit] # Don't run in CI - - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.9 - entry: tools/mypy.sh 1 "3.9" - language: python - types: [python] - additional_dependencies: *mypy_deps - stages: [manual] # Only run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 - entry: tools/mypy.sh 1 "3.10" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 - entry: tools/mypy.sh 1 "3.11" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 - entry: tools/mypy.sh 1 "3.12" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.13 + entry: python tools/pre_commit/mypy.py 1 "3.13" + <<: *mypy_common stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts @@ -155,18 +137,15 @@ repos: additional_dependencies: [regex] - id: check-pickle-imports name: Prevent new pickle/cloudpickle imports - entry: python tools/check_pickle_imports.py + entry: python tools/pre_commit/check_pickle_imports.py language: python types: [python] - pass_filenames: false - additional_dependencies: [pathspec, regex] + additional_dependencies: [regex] - id: validate-config name: Validate configuration has default values and that each field has a docstring entry: python tools/validate_config.py language: python - types: [python] - pass_filenames: true - files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py + additional_dependencies: [regex] # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 4329750090683..d83d6df35ed9a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -13,6 +13,7 @@ build: mkdocs: configuration: mkdocs.yaml + 
fail_on_warning: true # Optionally declare the Python requirements required to build your docs python: diff --git a/.yapfignore b/.yapfignore index 2d6dcf8380cac..38158259032a6 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,2 @@ collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a1deefb07f09c..005590445361a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,10 @@ cmake_minimum_required(VERSION 3.26) # cmake --install . --component _C project(vllm_extensions LANGUAGES CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + + # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") @@ -30,10 +34,10 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12", "3.13") +set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") # # Supported/expected torch versions for CUDA/ROCm. @@ -45,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1") -set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") # # Try to find python package with an executable that exactly matches @@ -82,6 +86,9 @@ find_package(Torch REQUIRED) # Supported NVIDIA architectures. # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0") +elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") else() @@ -171,6 +178,25 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set compression mode for CUDA >=13.x. +# +if(VLLM_GPU_LANG STREQUAL "CUDA" AND + DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + list(APPEND VLLM_GPU_FLAGS "--compress-mode=size") +endif() + +# +# Set CUDA include flags for CXX compiler. +# +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include") + if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl") + endif() +endif() + # # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. 
@@ -243,8 +269,8 @@ set(VLLM_EXT_SRC "csrc/sampler.cu" "csrc/cuda_view.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" - "csrc/quantization/fp8/common.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" @@ -256,7 +282,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -288,14 +314,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" - "csrc/attention/mla/cutlass_mla_entry.cu" - "csrc/quantization/fp8/per_token_group_quant.cu") + "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/quantization/w8a8/int8/per_token_group_quant.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -399,11 +424,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -427,12 +452,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. 
CUTLASS 3.x) require # CUDA 12.8 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -457,12 +486,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) # require CUDA 12.8 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -493,7 +526,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") @@ -537,10 +570,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # CUDA 12.8 or later - cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -555,10 +593,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # FP4 Archs and flags - cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS 
"csrc/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") @@ -576,10 +619,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # CUTLASS MLA Archs and flags - cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS - "csrc/attention/mla/cutlass_mla_kernels.cu" "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -603,7 +649,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -621,9 +667,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -642,9 +692,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # moe_data.cu is used by all CUTLASS MoE kernels. 
- cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") @@ -661,9 +715,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -750,6 +808,44 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + + # Hadacore kernels + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}") + if(HADACORE_ARCHS) + set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") + endif() + # if CUDA endif endif() @@ -790,7 +886,9 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -909,6 +1007,7 @@ endif() # For CUDA we also build and ship some external projects. 
if (VLLM_GPU_LANG STREQUAL "CUDA") include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/qutlass.cmake) # vllm-flash-attn should be last as it overwrites some CMake functions include(cmake/external_projects/vllm_flash_attn.cmake) diff --git a/MANIFEST.in b/MANIFEST.in index 82fd22b845f09..fb3cccbb4a9c1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,6 @@ include LICENSE include requirements/common.txt include requirements/cuda.txt include requirements/rocm.txt -include requirements/neuron.txt include requirements/cpu.txt include CMakeLists.txt diff --git a/README.md b/README.md index fd8b02ac1f781..3dcdd7dc00942 100644 --- a/README.md +++ b/README.md @@ -14,18 +14,26 @@ Easy, fast, and cheap LLM serving for everyone | Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |

+--- +Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year! + --- *Latest News* 🔥 -- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). -- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing). +- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA). +- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing). +- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). +- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). @@ -74,7 +82,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support @@ -141,6 +149,7 @@ Compute Resources: - Trainy - UC Berkeley - UC San Diego +- Volcengine Slack Sponsor: Anyscale diff --git a/SECURITY.md b/SECURITY.md index 414669fb3712e..d6319cdb1ac27 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma * If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. +* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications + * Substantial internal deployment leveraging the upstream vLLM project. + * Established internal security teams and comprehensive compliance measures. + * Active and consistent contributions to the upstream vLLM project. + * We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. 
diff --git a/benchmarks/README.md b/benchmarks/README.md index 176b40212978f..269a4d51ec2ef 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,725 +1,20 @@ -# Benchmarking vLLM +# Benchmarks -This README guides you through running benchmark tests with the extensive -datasets supported on vLLM. It’s a living document, updated as new features and datasets -become available. +This directory used to contain vLLM's benchmark scripts and utilities for performance testing and evaluation. -## Dataset Overview +## Contents - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Dataset | Data Path |
|---------|-----------|
| ShareGPT | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` |
| ShareGPT4V (Image) | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json` (images must be downloaded separately; for example, COCO's 2017 Train images: `wget http://images.cocodataset.org/zips/train2017.zip`) |
| ShareGPT4Video (Video) | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` |
| BurstGPT | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` |
| Sonnet (deprecated) | Local file: `benchmarks/sonnet.txt` |
| Random | synthetic |
| Prefix Repetition | synthetic |
| HuggingFace-VisionArena | `lmarena-ai/VisionArena-Chat` |
| HuggingFace-InstructCoder | `likaixin/InstructCoder` |
| HuggingFace-AIMO | `AI-MO/aimo-validation-aime`, `AI-MO/NuminaMath-1.5`, `AI-MO/NuminaMath-CoT` |
| HuggingFace-Other | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` |
| Custom | Local file: `data.jsonl` |
+- **Serving benchmarks**: Scripts for testing online inference performance (latency, throughput) +- **Throughput benchmarks**: Scripts for testing offline batch inference performance +- **Specialized benchmarks**: Tools for testing specific features like structured output, prefix caching, long document QA, request prioritization, and multi-modal inference +- **Dataset utilities**: Framework for loading and sampling from various benchmark datasets (ShareGPT, HuggingFace datasets, synthetic data, etc.) -✅: supported +## Usage -🟡: Partial support +For detailed usage instructions, examples, and dataset information, see the [Benchmark CLI documentation](https://docs.vllm.ai/en/latest/contributing/benchmarks.html#benchmark-cli). -🚧: to be supported +For full CLI reference see: -**Note**: HuggingFace dataset's `dataset-name` should be set to `hf` - -## 🚀 Example - Online Benchmark - -
-Show more - -
- -First start serving your model - -```bash -vllm serve NousResearch/Hermes-3-Llama-3.1-8B -``` - -Then run the benchmarking script - -```bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -vllm bench serve \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --endpoint /v1/completions \ - --dataset-name sharegpt \ - --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --num-prompts 10 -``` - -If successful, you will see the following output - -```text -============ Serving Benchmark Result ============ -Successful requests: 10 -Benchmark duration (s): 5.78 -Total input tokens: 1369 -Total generated tokens: 2212 -Request throughput (req/s): 1.73 -Output token throughput (tok/s): 382.89 -Total Token throughput (tok/s): 619.85 ----------------Time to First Token---------------- -Mean TTFT (ms): 71.54 -Median TTFT (ms): 73.88 -P99 TTFT (ms): 79.49 ------Time per Output Token (excl. 1st token)------ -Mean TPOT (ms): 7.91 -Median TPOT (ms): 7.96 -P99 TPOT (ms): 8.03 ----------------Inter-token Latency---------------- -Mean ITL (ms): 7.74 -Median ITL (ms): 7.70 -P99 ITL (ms): 8.39 -================================================== -``` - -### Custom Dataset - -If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl - -```json -{"prompt": "What is the capital of India?"} -{"prompt": "What is the capital of Iran?"} -{"prompt": "What is the capital of China?"} -``` - -```bash -# start server -VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct -``` - -```bash -# run benchmarking script -vllm bench serve --port 9001 --save-result --save-detailed \ - --backend vllm \ - --model meta-llama/Llama-3.1-8B-Instruct \ - --endpoint /v1/completions \ - --dataset-name custom \ - --dataset-path \ - --custom-skip-chat-template \ - --num-prompts 80 \ - --max-concurrency 1 \ - --temperature=0.3 \ - --top-p=0.75 \ - --result-dir "./log/" -``` - -You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. 
- -### VisionArena Benchmark for Vision Language Models - -```bash -# need a model with vision capability here -vllm serve Qwen/Qwen2-VL-7B-Instruct -``` - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path lmarena-ai/VisionArena-Chat \ - --hf-split train \ - --num-prompts 1000 -``` - -### InstructCoder Benchmark with Speculative Decoding - -``` bash -VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ - --speculative-config $'{"method": "ngram", - "num_speculative_tokens": 5, "prompt_lookup_max": 5, - "prompt_lookup_min": 2}' -``` - -``` bash -vllm bench serve \ - --model meta-llama/Meta-Llama-3-8B-Instruct \ - --dataset-name hf \ - --dataset-path likaixin/InstructCoder \ - --num-prompts 2048 -``` - -### Other HuggingFaceDataset Examples - -```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct -``` - -`lmms-lab/LLaVA-OneVision-Data`: - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path lmms-lab/LLaVA-OneVision-Data \ - --hf-split train \ - --hf-subset "chart2text(cauldron)" \ - --num-prompts 10 -``` - -`Aeala/ShareGPT_Vicuna_unfiltered`: - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ - --hf-split train \ - --num-prompts 10 -``` - -`AI-MO/aimo-validation-aime`: - -``` bash -vllm bench serve \ - --model Qwen/QwQ-32B \ - --dataset-name hf \ - --dataset-path AI-MO/aimo-validation-aime \ - --num-prompts 10 \ - --seed 42 -``` - -`philschmid/mt-bench`: - -``` bash -vllm bench serve \ - --model Qwen/QwQ-32B \ - --dataset-name hf \ - --dataset-path philschmid/mt-bench \ - --num-prompts 80 -``` - -### Running With Sampling Parameters - -When using OpenAI-compatible backends such as `vllm`, optional sampling -parameters can be specified. Example client command: - -```bash -vllm bench serve \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --endpoint /v1/completions \ - --dataset-name sharegpt \ - --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --top-k 10 \ - --top-p 0.9 \ - --temperature 0.5 \ - --num-prompts 10 -``` - -### Running With Ramp-Up Request Rate - -The benchmark tool also supports ramping up the request rate over the -duration of the benchmark run. This can be useful for stress testing the -server or finding the maximum throughput that it can handle, given some latency budget. - -Two ramp-up strategies are supported: - -- `linear`: Increases the request rate linearly from a start value to an end value. -- `exponential`: Increases the request rate exponentially. - -The following arguments can be used to control the ramp-up: - -- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). -- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. -- `--ramp-up-end-rps`: The request rate at the end of the benchmark. - -
- -## 📈 Example - Offline Throughput Benchmark - -
-Show more - -
- -```bash -vllm bench throughput \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset-name sonnet \ - --dataset-path vllm/benchmarks/sonnet.txt \ - --num-prompts 10 -``` - -If successful, you will see the following output - -```text -Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s -Total num prompt tokens: 5014 -Total num output tokens: 1500 -``` - -### VisionArena Benchmark for Vision Language Models - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path lmarena-ai/VisionArena-Chat \ - --num-prompts 1000 \ - --hf-split train -``` - -The `num prompt tokens` now includes image token counts - -```text -Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s -Total num prompt tokens: 14527 -Total num output tokens: 1280 -``` - -### InstructCoder Benchmark with Speculative Decoding - -``` bash -VLLM_WORKER_MULTIPROC_METHOD=spawn \ -VLLM_USE_V1=1 \ -vllm bench throughput \ - --dataset-name=hf \ - --dataset-path=likaixin/InstructCoder \ - --model=meta-llama/Meta-Llama-3-8B-Instruct \ - --input-len=1000 \ - --output-len=100 \ - --num-prompts=2048 \ - --async-engine \ - --speculative-config $'{"method": "ngram", - "num_speculative_tokens": 5, "prompt_lookup_max": 5, - "prompt_lookup_min": 2}' -``` - -```text -Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s -Total num prompt tokens: 261136 -Total num output tokens: 204800 -``` - -### Other HuggingFaceDataset Examples - -`lmms-lab/LLaVA-OneVision-Data`: - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path lmms-lab/LLaVA-OneVision-Data \ - --hf-split train \ - --hf-subset "chart2text(cauldron)" \ - --num-prompts 10 -``` - -`Aeala/ShareGPT_Vicuna_unfiltered`: - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ - --hf-split train \ - --num-prompts 10 -``` - -`AI-MO/aimo-validation-aime`: - -```bash -vllm bench throughput \ - --model Qwen/QwQ-32B \ - --backend vllm \ - --dataset-name hf \ - --dataset-path AI-MO/aimo-validation-aime \ - --hf-split train \ - --num-prompts 10 -``` - -Benchmark with LoRA adapters: - -``` bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -vllm bench throughput \ - --model meta-llama/Llama-2-7b-hf \ - --backend vllm \ - --dataset_path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --dataset_name sharegpt \ - --num-prompts 10 \ - --max-loras 2 \ - --max-lora-rank 8 \ - --enable-lora \ - --lora-path yard1/llama-2-7b-sql-lora-test - ``` - -
- -## 🛠️ Example - Structured Output Benchmark - -
-Show more - -
- -Benchmark the performance of structured output generation (JSON, grammar, regex). - -### Server Setup - -```bash -vllm serve NousResearch/Hermes-3-Llama-3.1-8B -``` - -### JSON Schema Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset json \ - --structured-output-ratio 1.0 \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Grammar-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset grammar \ - --structure-type grammar \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Regex-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset regex \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Choice-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset choice \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### XGrammar Benchmark Dataset - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset xgrammar_bench \ - --request-rate 10 \ - --num-prompts 1000 -``` - -
- -## 📚 Example - Long Document QA Benchmark - -
-Show more - -
- -Benchmark the performance of long document question-answering with prefix caching. - -### Basic Long Document QA Test - -```bash -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 16 \ - --document-length 2000 \ - --output-len 50 \ - --repeat-count 5 -``` - -### Different Repeat Modes - -```bash -# Random mode (default) - shuffle prompts randomly -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode random - -# Tile mode - repeat entire prompt list in sequence -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode tile - -# Interleave mode - repeat each prompt consecutively -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode interleave -``` - -
- -## 🗂️ Example - Prefix Caching Benchmark - -
-Show more - -
- -Benchmark the efficiency of automatic prefix caching. - -### Fixed Prompt with Prefix Caching - -```bash -python3 benchmarks/benchmark_prefix_caching.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-prompts 1 \ - --repeat-count 100 \ - --input-length-range 128:256 -``` - -### ShareGPT Dataset with Prefix Caching - -```bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - -python3 benchmarks/benchmark_prefix_caching.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ - --enable-prefix-caching \ - --num-prompts 20 \ - --repeat-count 5 \ - --input-length-range 128:256 -``` - -### Prefix Repetition Dataset - -```bash -vllm bench serve \ - --backend openai \ - --model meta-llama/Llama-2-7b-chat-hf \ - --dataset-name prefix_repetition \ - --num-prompts 100 \ - --prefix-repetition-prefix-len 512 \ - --prefix-repetition-suffix-len 128 \ - --prefix-repetition-num-prefixes 5 \ - --prefix-repetition-output-len 128 -``` - -
- -## ⚡ Example - Request Prioritization Benchmark - -
-Show more - -
- -Benchmark the performance of request prioritization in vLLM. - -### Basic Prioritization Test - -```bash -python3 benchmarks/benchmark_prioritization.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --input-len 128 \ - --output-len 64 \ - --num-prompts 100 \ - --scheduling-policy priority -``` - -### Multiple Sequences per Prompt - -```bash -python3 benchmarks/benchmark_prioritization.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --input-len 128 \ - --output-len 64 \ - --num-prompts 100 \ - --scheduling-policy priority \ - --n 2 -``` - -
- -## 👁️ Example - Multi-Modal Benchmark - -
-Show more - -
- -Benchmark the performance of multi-modal requests in vLLM. - -### Images (ShareGPT4V) - -Start vLLM: - -```bash -python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dtype bfloat16 \ - --limit-mm-per-prompt '{"image": 1}' \ - --allowed-local-media-path /path/to/sharegpt4v/images -``` - -Send requests with images: - -```bash -python benchmarks/benchmark_serving.py \ - --backend openai-chat \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dataset-name sharegpt \ - --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ - --num-prompts 100 \ - --save-result \ - --result-dir ~/vllm_benchmark_results \ - --save-detailed \ - --endpoint /v1/chat/completion -``` - -### Videos (ShareGPT4Video) - -Start vLLM: - -```bash -python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dtype bfloat16 \ - --limit-mm-per-prompt '{"video": 1}' \ - --allowed-local-media-path /path/to/sharegpt4video/videos -``` - -Send requests with videos: - -```bash -python benchmarks/benchmark_serving.py \ - --backend openai-chat \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dataset-name sharegpt \ - --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ - --num-prompts 100 \ - --save-result \ - --result-dir ~/vllm_benchmark_results \ - --save-detailed \ - --endpoint /v1/chat/completion -``` - -
+- +- +- diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index 9aad51df6e003..d1bdb4c43f10b 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -31,6 +31,12 @@ cd vllm You must set the following variables at the top of the script before execution. + Note: You can also override the default values below via environment variables when running the script. + +```bash +MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LEN=128 OUTPUT_LEN=2048 MAX_MODEL_LEN=2300 MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 NUM_SEQS_LIST="128 256" NUM_BATCHED_TOKENS_LIST="1024 2048 4096" VLLM_LOGGING_LEVEL=DEBUG bash auto_tune.sh +``` + | Variable | Description | Example Value | | --- | --- | --- | | `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` | @@ -143,3 +149,70 @@ The script follows a systematic process to find the optimal parameters: 4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far. 5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard. + +## Batched `auto_tune` + +The `batch_auto_tune.sh` script allows you to run multiple `auto_tune.sh` experiments sequentially from a single configuration file. It iterates through a list of parameter sets, executes `auto_tune.sh` for each, and records the results back into the input file. + +### Prerequisites + +- **jq**: This script requires `jq` to parse the JSON configuration file. +- **gcloud**: If you plan to upload results to Google Cloud Storage, the `gcloud` CLI must be installed and authenticated. + +### How to Run + +1. **Create a JSON configuration file**: Create a file (e.g., `runs_config.json`) containing an array of JSON objects. Each object defines the parameters for a single `auto_tune.sh` run. + +2. **Execute the script**: + + ```bash + bash batch_auto_tune.sh [gcs_upload_path] + ``` + + - ``: **Required.** Path to your JSON configuration file. + - `[gcs_upload_path]`: **Optional.** A GCS path (e.g., `gs://my-bucket/benchmark-results`) where the detailed results and profiles for each run will be uploaded. If this is empty, the results will be available on the local filesystem (see the log for `RESULT_FILE=/path/to/results/file.txt`). + +### Configuration File + +The JSON configuration file should contain an array of objects. Each object's keys correspond to the configuration variables for `auto_tune.sh` (see the [Configuration table above](#configuration)). These keys will be converted to uppercase environment variables for each run. 
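As an illustration of that conversion, here is a minimal sketch, assuming `jq` is installed and a config file named `runs_config.json` like the example below; `batch_auto_tune.sh` applies the same lowercase-key-to-uppercase-variable transformation internally:

```bash
#!/bin/bash
# Sketch only: print the uppercase ENV_VAR=value pairs derived from the first
# entry of a runs_config.json (file name is illustrative).
CONFIG_FILE="runs_config.json"

first_entry=$(jq '.[0]' "$CONFIG_FILE")
env_vars=()
for key in $(echo "$first_entry" | jq -r 'keys_unsorted[]'); do
  value=$(echo "$first_entry" | jq -r ".$key")
  # lowercase JSON keys -> uppercase env var names, e.g. "input_len" -> "INPUT_LEN"
  var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_')
  env_vars+=("${var_name}=${value}")
done

echo "${env_vars[@]}"   # e.g. MODEL=meta-llama/Llama-3.1-8B-Instruct TP=8 ...
```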
+ +Here is an example `runs_config.json` with two benchmark configurations: + +```json +[ + { + "base": "/home/user", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "system": "TPU", # OR GPU + "tp": 8, + "input_len": 128, + "output_len": 2048, + "max_model_len": 2300, + "num_seqs_list": "128 256", + "num_batched_tokens_list": "8192 16384" + }, + { + "base": "/home/user", + "model": "meta-llama/Llama-3.1-70B-Instruct", + "system": "TPU", # OR GPU + "tp": 8, + "input_len": 4000, + "output_len": 16, + "max_model_len": 4096, + "num_seqs_list": "64 128", + "num_batched_tokens_list": "4096 8192", + "max_latency_allowed_ms": 500 + } +] +``` + +### Output + +The script modifies the input JSON file in place, adding the results of each run to the corresponding object. The following fields are added: + +- `run_id`: A unique identifier for the run, derived from the timestamp. +- `status`: The outcome of the run (`SUCCESS`, `FAILURE`, or `WARNING_NO_RESULT_FILE`). +- `results`: The content of the `result.txt` file from the `auto_tune.sh` run. +- `gcs_results`: The GCS URL where the run's artifacts are stored (if a GCS path was provided). + +A summary of successful and failed runs is also printed to the console upon completion. diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 82c20ffa6554c..56b721cbb4021 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -5,25 +5,41 @@ TAG=$(date +"%Y_%m_%d_%H_%M") SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -BASE="$SCRIPT_DIR/../../.." -MODEL="meta-llama/Llama-3.1-8B-Instruct" -SYSTEM="TPU" -TP=1 -DOWNLOAD_DIR="" -INPUT_LEN=4000 -OUTPUT_LEN=16 -MAX_MODEL_LEN=4096 -MIN_CACHE_HIT_PCT=0 -MAX_LATENCY_ALLOWED_MS=100000000000 -NUM_SEQS_LIST="128 256" -NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" +VLLM_LOGGING_LEVEL=${VLLM_LOGGING_LEVEL:-INFO} +BASE=${BASE:-"$SCRIPT_DIR/../../.."} +MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"} +SYSTEM=${SYSTEM:-"TPU"} +TP=${TP:-1} +DOWNLOAD_DIR=${DOWNLOAD_DIR:-""} +INPUT_LEN=${INPUT_LEN:-4000} +OUTPUT_LEN=${OUTPUT_LEN:-16} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-4096} +MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} +MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} +NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} +NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" PROFILE_PATH="$LOG_FOLDER/profile" -echo "result file: $RESULT" -echo "model: $MODEL" +echo "====================== AUTO TUNE PARAMETERS ====================" +echo "SCRIPT_DIR=$SCRIPT_DIR" +echo "BASE=$BASE" +echo "MODEL=$MODEL" +echo "SYSTEM=$SYSTEM" +echo "TP=$TP" +echo "DOWNLOAD_DIR=$DOWNLOAD_DIR" +echo "INPUT_LEN=$INPUT_LEN" +echo "OUTPUT_LEN=$OUTPUT_LEN" +echo "MAX_MODEL_LEN=$MAX_MODEL_LEN" +echo "MIN_CACHE_HIT_PCT=$MIN_CACHE_HIT_PCT" +echo "MAX_LATENCY_ALLOWED_MS=$MAX_LATENCY_ALLOWED_MS" +echo "NUM_SEQS_LIST=$NUM_SEQS_LIST" +echo "NUM_BATCHED_TOKENS_LIST=$NUM_BATCHED_TOKENS_LIST" +echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL" +echo "RESULT_FILE=$RESULT" +echo "====================== AUTO TUNEPARAMETERS ====================" rm -rf $LOG_FOLDER rm -rf $PROFILE_PATH @@ -58,7 +74,7 @@ start_server() { local vllm_log=$4 local profile_dir=$5 - pkill -if vllm + pkill -if "vllm serve" || true # Define the common arguments as a bash array. # Each argument and its value are separate elements. 
@@ -80,17 +96,22 @@ start_server() { # This correctly passes each element as a separate argument. if [[ -n "$profile_dir" ]]; then # Start server with profiling enabled - VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ + VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & else # Start server without profiling - VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \ + VLLM_SERVER_DEV_MODE=1 \ vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & fi + local server_pid=$! # wait for 10 minutes... server_started=0 for i in {1..60}; do + # This line checks whether the server is still alive or not, + # since that we should always have permission to send signal to the server process. + kill -0 $server_pid 2> /dev/null || break + RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) if [[ "$STATUS_CODE" -eq 200 ]]; then @@ -102,7 +123,7 @@ start_server() { done if (( ! server_started )); then - echo "server did not start within 10 minutes. Please check server log at $vllm_log". + echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log". return 1 else return 0 @@ -118,7 +139,7 @@ run_benchmark() { echo "vllm_log: $vllm_log" echo rm -f $vllm_log - pkill -if vllm + pkill -if "vllm serve" || true echo "starting server..." # Call start_server without a profile_dir to avoid profiling overhead @@ -211,9 +232,9 @@ run_benchmark() { echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" - pkill -if vllm + pkill -if "vllm serve" || true sleep 10 - printf '=%.0s' $(seq 1 20) + echo "====================" return 0 } @@ -287,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then else echo "No configuration met the latency requirements. Skipping final profiling run." fi -pkill -if vllm +pkill -if "vllm serve" || true echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" diff --git a/benchmarks/auto_tune/batch_auto_tune.sh b/benchmarks/auto_tune/batch_auto_tune.sh new file mode 100755 index 0000000000000..57ef20daf6b71 --- /dev/null +++ b/benchmarks/auto_tune/batch_auto_tune.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +INPUT_JSON="$1" +GCS_PATH="$2" # Optional GCS path for uploading results for each run + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) +AUTOTUNE_SCRIPT="$SCRIPT_DIR/auto_tune.sh" + +if [[ -z "$INPUT_JSON" ]]; then + echo "Error: Input JSON file not provided." + echo "Usage: $0 [gcs_upload_path]" + exit 1 +fi + +if [[ ! -f "$INPUT_JSON" ]]; then + echo "Error: File not found at '$INPUT_JSON'" + exit 1 +fi + +if ! command -v jq &> /dev/null; then + echo "Error: 'jq' command not found. Please install jq to process the JSON input." + exit 1 +fi + +if [[ -n "$GCS_PATH" ]] && ! command -v gcloud &> /dev/null; then + echo "Error: 'gcloud' command not found, but a GCS_PATH was provided." + exit 1 +fi + +SUCCESS_COUNT=0 +FAILURE_COUNT=0 +FAILED_RUNS=() +SCRIPT_START_TIME=$(date +%s) + +json_content=$(cat "$INPUT_JSON") +if ! num_runs=$(echo "$json_content" | jq 'length'); then + echo "Error: Invalid JSON in $INPUT_JSON. 
'jq' failed to get array length." >&2 + exit 1 +fi + +echo "Found $num_runs benchmark configurations in $INPUT_JSON." +echo "Starting benchmark runs..." +echo "--------------------------------------------------" + +for i in $(seq 0 $(($num_runs - 1))); do + run_object=$(echo "$json_content" | jq ".[$i]") + + RUN_START_TIME=$(date +%s) + ENV_VARS_ARRAY=() + # Dynamically create env vars from the JSON object's keys + for key in $(echo "$run_object" | jq -r 'keys_unsorted[]'); do + value=$(echo "$run_object" | jq -r ".$key") + var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_') + ENV_VARS_ARRAY+=("${var_name}=${value}") + done + + echo "Executing run #$((i+1))/$num_runs with parameters: ${ENV_VARS_ARRAY[*]}" + + # Execute auto_tune.sh and capture output + RUN_OUTPUT_FILE=$(mktemp) + if env "${ENV_VARS_ARRAY[@]}" bash "$AUTOTUNE_SCRIPT" > >(tee -a "$RUN_OUTPUT_FILE") 2>&1; then + STATUS="SUCCESS" + ((SUCCESS_COUNT++)) + else + STATUS="FAILURE" + ((FAILURE_COUNT++)) + FAILED_RUNS+=("Run #$((i+1)): $(echo $run_object | jq -c .)") + fi + + RUN_OUTPUT=$(<"$RUN_OUTPUT_FILE") + rm "$RUN_OUTPUT_FILE" + + # Parse results and optionally upload them to GCS + RUN_ID="" + RESULTS="" + GCS_RESULTS_URL="" + if [[ "$STATUS" == "SUCCESS" ]]; then + RESULT_FILE_PATH=$(echo "$RUN_OUTPUT" | grep 'RESULT_FILE=' | tail -n 1 | cut -d'=' -f2 | tr -s '/' || true) + + if [[ -n "$RESULT_FILE_PATH" && -f "$RESULT_FILE_PATH" ]]; then + RUN_ID=$(basename "$(dirname "$RESULT_FILE_PATH")") + RESULT_DIR=$(dirname "$RESULT_FILE_PATH") + RESULTS=$(cat "$RESULT_FILE_PATH") + + if [[ -n "$GCS_PATH" ]]; then + GCS_RESULTS_URL="${GCS_PATH}/${RUN_ID}" + echo "Uploading results to GCS..." + if gcloud storage rsync --recursive "$RESULT_DIR/" "$GCS_RESULTS_URL"; then + echo "GCS upload successful." + else + echo "Warning: GCS upload failed for RUN_ID $RUN_ID." + fi + fi + else + echo "Warning: Could not find result file for a successful run." + STATUS="WARNING_NO_RESULT_FILE" + fi + fi + + # Add the results back into the JSON object for this run + json_content=$(echo "$json_content" | jq --argjson i "$i" --arg run_id "$RUN_ID" --arg status "$STATUS" --arg results "$RESULTS" --arg gcs_results "$GCS_RESULTS_URL" \ + '.[$i] += {run_id: $run_id, status: $status, results: $results, gcs_results: $gcs_results}') + + RUN_END_TIME=$(date +%s) + echo "Run finished in $((RUN_END_TIME - RUN_START_TIME)) seconds. Status: $STATUS" + echo "--------------------------------------------------" + + # Save intermediate progress back to the file + echo "$json_content" > "$INPUT_JSON.tmp" && mv "$INPUT_JSON.tmp" "$INPUT_JSON" + +done + +SCRIPT_END_TIME=$(date +%s) +echo "All benchmark runs completed in $((SCRIPT_END_TIME - SCRIPT_START_TIME)) seconds." +echo +echo "====================== SUMMARY ======================" +echo "Successful runs: $SUCCESS_COUNT" +echo "Failed runs: $FAILURE_COUNT" +echo "===================================================" + +if [[ $FAILURE_COUNT -gt 0 ]]; then + echo "Details of failed runs (see JSON file for full parameters):" + for failed in "${FAILED_RUNS[@]}"; do + echo " - $failed" + done +fi + +echo "Updated results have been saved to '$INPUT_JSON'." 
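A quick, optional way to check the `run_id`, `status`, and `gcs_results` fields that `batch_auto_tune.sh` writes back into the input file is a `jq` one-liner such as the following; the file name `runs_config.json` is the illustrative one from the README above:

```bash
# Print one line per run: run_id, status, and GCS location (if any).
jq -r '.[] | "\(.run_id // "n/a")\t\(.status // "PENDING")\t\(.gcs_results // "-")"' runs_config.json
```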
diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py index fd363c2ad0514..5434f8b6a4e44 100644 --- a/benchmarks/benchmark_block_pool.py +++ b/benchmarks/benchmark_block_pool.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector from vllm.utils import FlexibleArgumentParser from vllm.v1.core.block_pool import BlockPool @@ -57,7 +57,7 @@ def invoke_main() -> None: "--num-iteration", type=int, default=1000, - help="Number of iterations to run to stablize final data readings", + help="Number of iterations to run to stabilize final data readings", ) parser.add_argument( "--allocate-blocks", diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py deleted file mode 100644 index 2ea4f9ccaff2b..0000000000000 --- a/benchmarks/benchmark_dataset.py +++ /dev/null @@ -1,1288 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This module defines a framework for sampling benchmark requests from various -datasets. Each dataset subclass of BenchmarkDataset must implement sample -generation. Supported dataset types include: - - ShareGPT - - Random (synthetic) - - Sonnet - - BurstGPT - - HuggingFace - - VisionArena -""" - -import base64 -import io -import json -import logging -import random -from abc import ABC, abstractmethod -from collections.abc import Mapping -from copy import deepcopy -from dataclasses import dataclass -from functools import cache -from io import BytesIO -from typing import Any, Callable, Optional, Union - -import numpy as np -import pandas as pd -from datasets import load_dataset -from PIL import Image -from transformers import PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path -from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.image import convert_image_mode -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer - -logger = logging.getLogger(__name__) - -# ----------------------------------------------------------------------------- -# Data Classes -# ----------------------------------------------------------------------------- - - -@dataclass -class SampleRequest: - """ - Represents a single inference request for benchmarking. - """ - - prompt: Union[str, Any] - prompt_len: int - expected_output_len: int - multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None - lora_request: Optional[LoRARequest] = None - request_id: Optional[str] = None - - -# ----------------------------------------------------------------------------- -# Benchmark Dataset Base Class -# ----------------------------------------------------------------------------- - - -class BenchmarkDataset(ABC): - DEFAULT_SEED = 0 - IS_MULTIMODAL = False - - def __init__( - self, - dataset_path: Optional[str] = None, - random_seed: int = DEFAULT_SEED, - ) -> None: - """ - Initialize the BenchmarkDataset with an optional dataset path and random - seed. Args: - dataset_path (Optional[str]): Path to the dataset. If None, it - indicates that a default or random dataset might be used. - random_seed (int): Seed value for reproducible shuffling or - sampling. Defaults to DEFAULT_SEED. - """ - self.dataset_path = dataset_path - # Set the random seed, ensuring that a None value is replaced with the - # default seed. 
- self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED - self.data = None - - def apply_multimodal_chat_transformation( - self, prompt: str, mm_content: Optional[MultiModalDataDict] = None - ) -> list[dict]: - """ - Transform a prompt and optional multimodal content into a chat format. - This method is used for chat models that expect a specific conversation - format. - """ - content = [{"text": prompt, "type": "text"}] - if mm_content is not None: - content.append(mm_content) - return [{"role": "user", "content": content}] - - def load_data(self) -> None: - """ - Load data from the dataset path into self.data. - - This method must be overridden by subclasses since the method to load - data will vary depending on the dataset format and source. - - Raises: - NotImplementedError: If a subclass does not implement this method. - """ - # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError("load_data must be implemented in subclasses.") - - def get_random_lora_request( - self, - tokenizer: PreTrainedTokenizerBase, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, - ) -> tuple[Optional[LoRARequest], AnyTokenizer]: - """ - Optionally select a random LoRA request and return its associated - tokenizer. - - This method is used when LoRA parameters are provided. It randomly - selects a LoRA based on max_loras and retrieves a cached tokenizer for - that LoRA if available. Otherwise, it returns the base tokenizer. - - Args: - tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no - LoRA is selected. max_loras (Optional[int]): The maximum number of - LoRAs available. If None, LoRA is not used. lora_path - (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA - is not used. - - Returns: - tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first - element is a LoRARequest (or None if not applicable) and the second - element is the tokenizer associated with the LoRA request (or the - base tokenizer). - """ - if max_loras is None or lora_path is None: - return None, tokenizer - - # Generate a random LoRA ID in the range [1, max_loras]. - lora_id = random.randint(1, max_loras) - lora_request = LoRARequest( - lora_name=str(lora_id), - lora_int_id=lora_id, - lora_path=lora_path_on_disk(lora_path), - ) - if lora_id not in lora_tokenizer_cache: - lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) - # Return lora_request and the cached tokenizer if available; otherwise, - # return the base tokenizer - return lora_request, lora_tokenizer_cache[lora_id] or tokenizer - - @abstractmethod - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "", - ) -> list[SampleRequest]: - """ - Abstract method to generate sample requests from the dataset. - - Subclasses must override this method to implement dataset-specific logic - for generating a list of SampleRequest objects. - - Args: - tokenizer (PreTrainedTokenizerBase): The tokenizer to be used - for processing the dataset's text. - num_requests (int): The number of sample requests to generate. - request_id_prefix (str) The prefix of request_id. - - Returns: - list[SampleRequest]: A list of sample requests generated from the - dataset. 
- """ - raise NotImplementedError("sample must be implemented in subclasses.") - - def maybe_oversample_requests( - self, - requests: list[SampleRequest], - num_requests: int, - request_id_prefix: str = "", - ) -> None: - """ - Oversamples the list of requests if its size is less than the desired - number. - - Args: - requests (List[SampleRequest]): The current list of sampled - requests. - num_requests (int): The target number of requests. - request_id_prefix (str) The prefix of the request ids. - """ - if len(requests) < num_requests: - random.seed(self.random_seed) - additional = deepcopy( - random.choices(requests, k=num_requests - len(requests)) - ) - for i in range(len(additional)): - req = additional[i] - req.request_id = request_id_prefix + str(len(requests) + i) - requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", num_requests) - - -# ----------------------------------------------------------------------------- -# Utility Functions and Global Caches -# ----------------------------------------------------------------------------- - - -def is_valid_sequence( - prompt_len: int, - output_len: int, - min_len: int = 4, - max_prompt_len: int = 1024, - max_total_len: int = 2048, - skip_min_output_len_check: bool = False, -) -> bool: - """ - Validate a sequence based on prompt and output lengths. - - Default pruning criteria are copied from the original `sample_hf_requests` - and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as - from `sample_requests` in benchmark_throughput.py. - """ - # Check for invalid conditions - prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len < min_len) - prompt_too_long = prompt_len > max_prompt_len - combined_too_long = (prompt_len + output_len) > max_total_len - - # Return True if none of the invalid conditions are met - return not ( - prompt_too_short or output_too_short or prompt_too_long or combined_too_long - ) - - -@cache -def lora_path_on_disk(lora_path: str) -> str: - return get_adapter_absolute_path(lora_path) - - -# Global cache for LoRA tokenizers. -lora_tokenizer_cache: dict[int, AnyTokenizer] = {} - - -def process_image(image: Any) -> Mapping[str, Any]: - """ - Process a single image input and return a multimedia content dictionary. - - Supports three input types: - - 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key - containing raw image data. - Loads the bytes as a PIL.Image.Image. - - 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as - a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns - a dictionary with the image as a base64 data URL. - - 3. String input: - Treats the string as a URL or local file path. - - Prepends "file://" if the string doesn't start with "http://" or - "file://". - Returns a dictionary with the image URL. - - Raises: - ValueError: If the input is not a supported type. 
- """ - if isinstance(image, dict) and "bytes" in image: - image = Image.open(BytesIO(image["bytes"])) - if isinstance(image, Image.Image): - image = convert_image_mode(image, "RGB") - with io.BytesIO() as image_data: - image.save(image_data, format="JPEG") - image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") - return { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, - } - - if isinstance(image, str): - image_url = ( - image if image.startswith(("http://", "file://")) else f"file://{image}" - ) - return {"type": "image_url", "image_url": {"url": image_url}} - - raise ValueError( - f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes." - ) - - -def process_video(video: Any) -> Mapping[str, Any]: - """ - Process a single video input and return a multimedia content dictionary. - - Supports the following input types: - - 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key - containing raw video data. - - 2. String input: - Treats the string as a URL or local file path. - - Prepends "file://" if the string doesn't start with "http://" or - "file://". - Returns a dictionary with the image URL. - - Raises: - ValueError: If the input is not a supported type. - """ - if isinstance(video, dict) and "bytes" in video: - video_bytes = video["bytes"] - video_base64 = base64.b64encode(video_bytes).decode("utf-8") - return { - "type": "video_url", - "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, - } - - if isinstance(video, str): - video_url = ( - video if video.startswith(("http://", "file://")) else f"file://{video}" - ) - return {"type": "video_url", "video_url": {"url": video_url}} - - raise ValueError( - f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 - ) - - -# ----------------------------------------------------------------------------- -# Random Dataset Implementation (Synthetic Data) -# ----------------------------------------------------------------------------- - - -class RandomDataset(BenchmarkDataset): - # Default values copied from benchmark_serving.py for the random dataset. - DEFAULT_PREFIX_LEN = 0 - DEFAULT_RANGE_RATIO = 0.0 - DEFAULT_INPUT_LEN = 1024 - DEFAULT_OUTPUT_LEN = 128 - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - prefix_len: int = DEFAULT_PREFIX_LEN, - range_ratio: float = DEFAULT_RANGE_RATIO, - input_len: int = DEFAULT_INPUT_LEN, - output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", - **kwargs, - ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" - ) - - vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = ( - np.random.randint(0, vocab_size, size=prefix_len).tolist() - if prefix_len > 0 - else [] - ) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - # Ensure the lower bound for output length is at least 1 to prevent - # sampling 0 tokens, which can cause request failures. 
- output_low = max(output_low, 1) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, output_high) - - input_lens = np.random.randint(input_low, input_high + 1, size=num_requests) - output_lens = np.random.randint(output_low, output_high + 1, size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) - - requests = [] - for i in range(num_requests): - inner_seq = ( - (offsets[i] + i + np.arange(input_lens[i])) % vocab_size - ).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ - :total_input_len - ] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) - requests.append( - SampleRequest( - prompt=prompt, - prompt_len=total_input_len, - expected_output_len=int(output_lens[i]), - request_id=request_id_prefix + str(i), - ) - ) - - return requests - - -# ----------------------------------------------------------------------------- -# ShareGPT Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ShareGPTDataset(BenchmarkDataset): - """ - Implements the ShareGPT dataset. Loads data from a JSON file and generates - sample requests based on conversation turns. - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - with open(self.dataset_path, encoding="utf-8") as f: - self.data = json.load(f) - # Filter entries with at least two conversation turns. 
- self.data = [ - entry - for entry in self.data - if "conversations" in entry and len(entry["conversations"]) >= 2 - ] - random.seed(self.random_seed) - random.shuffle(self.data) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - samples: list = [] - ind = 0 - for entry in self.data: - if len(samples) >= num_requests: - break - prompt, completion = ( - entry["conversations"][0]["value"], - entry["conversations"][1]["value"], - ) - - lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path - ) - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - new_output_len = len(completion_ids) if output_len is None else output_len - if not is_valid_sequence( - prompt_len, - new_output_len, - skip_min_output_len_check=output_len is not None, - ): - continue - if image_path := entry.get("image"): - mm_content = process_image(image_path) - elif video_path := entry.get("video"): - mm_content = process_video(video_path) - else: - mm_content = None - if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - samples.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=new_output_len, - lora_request=lora_request, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) - return samples - - -# ----------------------------------------------------------------------------- -# Custom Dataset Implementation -# ----------------------------------------------------------------------------- - - -class CustomDataset(BenchmarkDataset): - """ - Implements the Custom dataset. Loads data from a JSONL file and generates - sample requests based on conversation turns. E.g., - ``` - {"prompt": "What is the capital of India?"} - {"prompt": "What is the capital of Iran?"} - {"prompt": "What is the capital of China?"} - ``` - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - # self.data will be a list of dictionaries - # e.g., [{"prompt": "What is the capital of India?"}, ...] - # This will be the standardized format which load_data() - # has to convert into depending on the filetype of dataset_path. - # sample() will assume this standardized format of self.data - self.data = [] - - # Load the JSONL file - if self.dataset_path.endswith(".jsonl"): - jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) - - # check if the JSONL file has a 'prompt' column - if "prompt" not in jsonl_data.columns: - raise ValueError("JSONL file must contain a 'prompt' column.") - - # Convert each row to a dictionary and append to self.data - # This will convert the DataFrame to a list of dictionaries - # where each dictionary corresponds to a row in the DataFrame. - # This is the standardized format we want for self.data - for _, row in jsonl_data.iterrows(): - self.data.append(row.to_dict()) - else: - raise NotImplementedError( - "Only JSONL format is supported for CustomDataset." 
- ) - - random.seed(self.random_seed) - random.shuffle(self.data) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - skip_chat_template: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = item["prompt"] - - # apply template - if not skip_chat_template: - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Sonnet Dataset Implementation -# ----------------------------------------------------------------------------- - - -class SonnetDataset(BenchmarkDataset): - """ - Simplified implementation of the Sonnet dataset. Loads poem lines from a - text file and generates sample requests. Default values here copied from - `benchmark_serving.py` for the sonnet dataset. - """ - - DEFAULT_PREFIX_LEN = 200 - DEFAULT_INPUT_LEN = 550 - DEFAULT_OUTPUT_LEN = 150 - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if not self.dataset_path: - raise ValueError("dataset_path must be provided.") - with open(self.dataset_path, encoding="utf-8") as f: - self.data = f.readlines() - - def sample( - self, - tokenizer, - num_requests: int, - prefix_len: int = DEFAULT_PREFIX_LEN, - input_len: int = DEFAULT_INPUT_LEN, - output_len: int = DEFAULT_OUTPUT_LEN, - return_prompt_formatted: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - # Calculate average token length for a poem line. - tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) - - # Build the base prompt. - base_prompt = "Pick as many lines as you can from these poem lines:\n" - base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template( - base_msg, add_generation_prompt=True, tokenize=False - ) - base_offset = len(tokenizer(base_fmt).input_ids) - if input_len <= base_offset: - raise ValueError( - f"'input_len' must be higher than the base prompt length " - f"({base_offset})." - ) - - # Determine how many poem lines to use. 
- num_input_lines = round((input_len - base_offset) / avg_len) - num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) - prefix_lines = self.data[:num_prefix_lines] - - samples = [] - ind = 0 - while len(samples) < num_requests: - extra_lines = random.choices( - self.data, k=num_input_lines - num_prefix_lines - ) - prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" - msg = [{"role": "user", "content": prompt}] - prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False - ) - prompt_len = len(tokenizer(prompt_formatted).input_ids) - - if prompt_len <= input_len: - samples.append( - SampleRequest( - prompt=prompt_formatted if return_prompt_formatted else prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - return samples - - -# ----------------------------------------------------------------------------- -# BurstGPT Dataset Implementation -# ----------------------------------------------------------------------------- - - -class BurstGPTDataset(BenchmarkDataset): - """ - Implements the BurstGPT dataset. Loads data from a CSV file and generates - sample requests based on synthetic prompt generation. Only rows with Model - "GPT-4" and positive response tokens are used. - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data( - self, - ): - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - df = pd.read_csv(self.dataset_path) - # Filter to keep only GPT-4 rows. - gpt4_df = df[df["Model"] == "GPT-4"] - # Remove failed requests (where Response tokens is 0 or less). - gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] - # Sample the desired number of rows. - self.data = gpt4_df - - def _sample_loaded_data(self, num_requests: int) -> list: - if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, random_state=self.random_seed) - else: - data = self.data.sample( - n=num_requests, - random_state=self.random_seed, - replace=True, - ) - # Convert the dataframe to a list of lists. - return data.values.tolist() - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list[SampleRequest]: - samples = [] - data = self._sample_loaded_data(num_requests=num_requests) - for i in range(num_requests): - input_len = int(data[i][2]) - output_len = int(data[i][3]) - lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path - ) - vocab_size = tokenizer.vocab_size - # Generate a synthetic prompt: a list of token IDs computed as (i + - # j) modulo vocab_size. 
- token_ids = [(i + j) % vocab_size for j in range(input_len)] - prompt = tokenizer.decode(token_ids) - samples.append( - SampleRequest( - prompt=prompt, - prompt_len=input_len, - expected_output_len=output_len, - lora_request=lora_req, - request_id=request_id_prefix + str(i), - ) - ) - return samples - - -# ----------------------------------------------------------------------------- -# HuggingFace Dataset Base Implementation -# ----------------------------------------------------------------------------- -class HuggingFaceDataset(BenchmarkDataset): - """Base class for datasets hosted on HuggingFace.""" - - SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() - - def __init__( - self, - dataset_path: str, - dataset_split: str, - no_stream: bool = False, - dataset_subset: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__(dataset_path=dataset_path, **kwargs) - - self.dataset_split = dataset_split - self.dataset_subset = dataset_subset - self.load_stream = not no_stream - self.load_data() - - def load_data(self) -> None: - """Load data from HuggingFace datasets.""" - self.data = load_dataset( - self.dataset_path, - name=self.dataset_subset, - split=self.dataset_split, - streaming=self.load_stream, - ) - self.data = self.data.shuffle(seed=self.random_seed) - - -# ----------------------------------------------------------------------------- -# Conversation Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ConversationDataset(HuggingFaceDataset): - """Dataset for conversation data with multimodal support.""" - - SUPPORTED_DATASET_PATHS = { - "lmms-lab/LLaVA-OneVision-Data", - "Aeala/ShareGPT_Vicuna_unfiltered", - } - IS_MULTIMODAL = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - # Filter examples with at least 2 conversations - filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) - sampled_requests = [] - dynamic_output = output_len is None - ind = 0 - - for item in filtered_data: - if len(sampled_requests) >= num_requests: - break - conv = item["conversations"] - prompt, completion = conv[0]["value"], conv[1]["value"] - - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, completion_len): - continue - mm_content = process_image(item["image"]) if "image" in item else None - if enable_multimodal_chat: - # Note: when chat is enabled the request prompt_len is no longer - # accurate and we will be using request output to count the - # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Vision Arena Dataset Implementation -# 
----------------------------------------------------------------------------- - - -class VisionArenaDataset(HuggingFaceDataset): - """ - Vision Arena Dataset. - """ - - DEFAULT_OUTPUT_LEN = 128 - SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], - } - IS_MULTIMODAL = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) - if parser_fn is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") - prompt = parser_fn(item) - mm_content = process_image(item["images"][0]) - prompt_len = len(tokenizer(prompt).input_ids) - if enable_multimodal_chat: - # Note: when chat is enabled the request prompt_len is no longer - # accurate and we will be using request output to count the - # actual prompt len - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Instruct Coder Dataset Implementation -# ----------------------------------------------------------------------------- - - -class InstructCoderDataset(HuggingFaceDataset): - """ - InstructCoder Dataset. - https://huggingface.co/datasets/likaixin/InstructCoder - - InstructCoder is the dataset designed for general code editing. It consists - of 114,239 instruction-input-output triplets, and covers multiple distinct - code editing scenario. - """ - - DEFAULT_OUTPUT_LEN = 200 # this is the average default output length - SUPPORTED_DATASET_PATHS = { - "likaixin/InstructCoder", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = ( - f"{item['input']}\n\n{item['instruction']} Just output " - "the code, do not include any explanation." 
- ) - - # apply template - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# MT-Bench Dataset Implementation -# ----------------------------------------------------------------------------- - - -class MTBenchDataset(HuggingFaceDataset): - """ - MT-Bench Dataset. - https://huggingface.co/datasets/philschmid/mt-bench - - We create a single turn dataset for MT-Bench. - This is similar to Spec decoding benchmark setup in vLLM - https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 - """ # noqa: E501 - - DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM - SUPPORTED_DATASET_PATHS = { - "philschmid/mt-bench", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = item["turns"][0] - - # apply template - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# AIMO Dataset Implementation -# ----------------------------------------------------------------------------- - - -class AIMODataset(HuggingFaceDataset): - """ - Dataset class for processing a AIMO dataset with reasoning questions. 
- """ - - SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", - "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list: - sampled_requests = [] - dynamic_output = output_len is None - ind = 0 - - for item in self.data: - if len(sampled_requests) >= num_requests: - break - prompt, completion = item["problem"], item["solution"] - - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 - ): - continue - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=None, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Next Edit Prediction Dataset Implementation -# ----------------------------------------------------------------------------- - - -zeta_prompt = """### Instruction: -You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - -### User Edits: - -{} - -### User Excerpt: - -{} - -### Response: - -""" # noqa: E501 - - -def _format_zeta_prompt( - sample: dict, original_start_marker: str = "<|editable_region_start|>" -) -> dict: - """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. - - This function formats examples from the NEP dataset - into prompts and expected outputs. It could be - further extended to support more NEP datasets. - - Args: - sample: The dataset sample containing events, - inputs, and outputs. - original_start_marker: The marker indicating the - start of the editable region. Defaults to - "<|editable_region_start|>". - - Returns: - A dictionary with the formatted prompts and expected outputs. - """ - events = sample["events"] - input = sample["input"] - output = sample["output"] - prompt = zeta_prompt.format(events, input) - - # following the original implementation, extract the focused region - # from the raw output - output_start_index = output.find(original_start_marker) - output_focused_region = output[output_start_index:] - expected_output = output_focused_region - - return {"prompt": prompt, "expected_output": expected_output} - - -class NextEditPredictionDataset(HuggingFaceDataset): - """ - Dataset class for processing a Next Edit Prediction dataset. 
- """ - - SUPPORTED_DATASET_PATHS = { - "zed-industries/zeta", - } - MAPPING_PROMPT_FUNCS = { - "zed-industries/zeta": _format_zeta_prompt, - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "", - **kwargs, - ): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) - if formatting_prompt_func is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") - samples = [] - for i, sample in enumerate(self.data): - sample = formatting_prompt_func(sample) - samples.append( - SampleRequest( - prompt=sample["prompt"], - prompt_len=len(tokenizer(sample["prompt"]).input_ids), - expected_output_len=len( - tokenizer(sample["expected_output"]).input_ids - ), - request_id=request_id_prefix + str(i), - ) - ) - if len(samples) >= num_requests: - break - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) - return samples - - -# ----------------------------------------------------------------------------- -# ASR Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ASRDataset(HuggingFaceDataset): - """ - Dataset class for processing a ASR dataset for transcription. - Tested on the following set: - - +----------------+----------------------------------------+--------------------------+-----------------------------+ - | Dataset | Domain | Speaking Style | hf-subset | - +----------------+----------------------------------------+--------------------------+-----------------------------+ - | TED-LIUM | TED talks | Oratory | release1, release2, release3| - | | | | release3-speaker-adaptation | - | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | - | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | - | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | - | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | - | AMI | Meetings | Spontaneous | ihm, sdm | - +----------------+----------------------------------------+--------------------------+-----------------------------+ - - """ # noqa: E501 - - SUPPORTED_DATASET_PATHS = { - "openslr/librispeech_asr", - "facebook/voxpopuli", - "LIUM/tedlium", - "edinburghcstr/ami", - "speechcolab/gigaspeech", - "kensho/spgispeech", - } - - DEFAULT_OUTPUT_LEN = 128 - IS_MULTIMODAL = True - - # TODO Whisper-specific. Abstract interface when more models are supported. 
- TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" - skip_long_audios: bool = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list: - import librosa - - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - prompt = ASRDataset.TRANSCRIPTION_PREAMBLE - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests = [] - skipped = 0 - ind = 0 - for item in self.data: - if len(sampled_requests) >= num_requests: - break - audio = item["audio"] - y, sr = audio["array"], audio["sampling_rate"] - duration_s = librosa.get_duration(y=y, sr=sr) - # Whisper max supported duration - if self.skip_long_audios and duration_s > 30: - skipped += 1 - continue - - mm_content = {"audio": (y, sr)} - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - if skipped: - logger.warning( - "%d samples discarded from dataset due to" - " their length being greater than" - " what Whisper supports.", - skipped, - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index d8b960edaa468..a7892f3f71243 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,191 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark the latency of processing a single batch of requests.""" - -import argparse -import dataclasses -import json -import os -import time -from typing import Any, Optional - -import numpy as np -from tqdm import tqdm -from typing_extensions import deprecated - -import vllm.envs as envs -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptType -from vllm.sampling_params import BeamSearchParams -from vllm.utils import FlexibleArgumentParser - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any] -) -> None: - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={"latency": results["latencies"]}, - extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, - ) - if pt_records: - pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -@deprecated( - "benchmark_latency.py is deprecated and will be removed in a " - "future version. Please use 'vllm bench latency' instead.", -) -def main(args: argparse.Namespace): - print(args) - - engine_args = EngineArgs.from_cli_args(args) - - # NOTE(woosuk): If the request cannot be processed in a single batch, - # the engine will automatically process the request in multiple batches. - llm = LLM(**dataclasses.asdict(engine_args)) - assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + args.output_len - ), ( - "Please ensure that max_model_len is greater than" - " the sum of input_len and output_len." 
- ) - - sampling_params = SamplingParams( - n=args.n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize, - ) - print(sampling_params) - dummy_prompt_token_ids = np.random.randint( - 10000, size=(args.batch_size, args.input_len) - ) - dummy_prompts: list[PromptType] = [ - {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() - ] - - def llm_generate(): - if not args.use_beam_search: - llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) - else: - llm.beam_search( - dummy_prompts, - BeamSearchParams( - beam_width=args.n, - max_tokens=args.output_len, - ignore_eos=True, - ), - ) - - def run_to_completion(profile_dir: Optional[str] = None): - if profile_dir: - llm.start_profile() - llm_generate() - llm.stop_profile() - else: - start_time = time.perf_counter() - llm_generate() - end_time = time.perf_counter() - latency = end_time - start_time - return latency - - print("Warming up...") - for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): - run_to_completion(profile_dir=None) - - if args.profile: - profile_dir = envs.VLLM_TORCH_PROFILER_DIR - print(f"Profiling (results will be saved to '{profile_dir}')...") - run_to_completion(profile_dir=profile_dir) - return - - # Benchmark. - latencies = [] - for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): - latencies.append(run_to_completion(profile_dir=None)) - latencies = np.array(latencies) - percentages = [10, 25, 50, 75, 90, 99] - percentiles = np.percentile(latencies, percentages) - print(f"Avg latency: {np.mean(latencies)} seconds") - for percentage, percentile in zip(percentages, percentiles): - print(f"{percentage}% percentile latency: {percentile} seconds") - - # Output JSON results if specified - if args.output_json: - results = { - "avg_latency": np.mean(latencies), - "latencies": latencies.tolist(), - "percentiles": dict(zip(percentages, percentiles.tolist())), - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the latency of processing a single batch of " - "requests till completion." - ) - parser.add_argument("--input-len", type=int, default=32) - parser.add_argument("--output-len", type=int, default=128) - parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument( - "--n", - type=int, - default=1, - help="Number of generated sequences per prompt.", - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-iters-warmup", - type=int, - default=10, - help="Number of iterations to run for warmup.", - ) - parser.add_argument( - "--num-iters", type=int, default=30, help="Number of iterations to run." - ) - parser.add_argument( - "--profile", - action="store_true", - help="profile the generation process of a single batch", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the latency results in JSON format.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)" - ), - ) - - parser = EngineArgs.add_cli_args(parser) - # V1 enables prefix caching by default which skews the latency - # numbers. We need to disable prefix caching by default. 
- parser.set_defaults(enable_prefix_caching=False) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: - raise OSError( - "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler." - ) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench latency + +For help with the new command, run: + vllm bench latency --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench latency --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index c60040d05ab7a..626b150ee4ce0 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -1,17 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import time +from unittest import mock import numpy as np +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector -from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm.v1.worker.gpu_model_runner import GPUModelRunner -def main(args): +def benchmark_propose(args): rows = [] for max_ngram in args.max_ngram: collector = TimeCollector(TimeCollector.US) @@ -69,15 +83,93 @@ def main(args): ) +def benchmark_batched_propose(args): + NUM_SPECULATIVE_TOKENS_NGRAM = 10 + PROMPT_LOOKUP_MIN = 5 + PROMPT_LOOKUP_MAX = 15 + MAX_MODEL_LEN = int(1e7) + DEVICE = current_platform.device_type + + model_config = ModelConfig(model="facebook/opt-125m", runner="generate") + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + method="ngram", + num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM, + prompt_lookup_max=PROMPT_LOOKUP_MAX, + prompt_lookup_min=PROMPT_LOOKUP_MIN, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig(), + ) + + # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + with mock.patch( + "vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group + ): + runner = GPUModelRunner(vllm_config, DEVICE) + + # hack max model len + runner.max_model_len = MAX_MODEL_LEN + runner.drafter.max_model_len = MAX_MODEL_LEN + + dummy_input_batch = InputBatch( + max_num_reqs=args.num_req, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=args.num_req * args.num_token, + device=DEVICE, + pin_memory=False, + vocab_size=256000, + block_sizes=[16], + ) + dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req)) + dummy_input_batch.spec_decode_unsupported_reqs = () + 
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req + dummy_input_batch.token_ids_cpu = np.random.randint( + 0, 20, (args.num_req, args.num_token) + ) + + runner.input_batch = dummy_input_batch + + sampled_token_ids = [[0]] * args.num_req + + print("Starting benchmark") + # first run is warmup so ignore it + for _ in range(args.num_iteration): + start = time.time() + runner.drafter.propose( + sampled_token_ids, + dummy_input_batch.req_ids, + dummy_input_batch.num_tokens_no_spec, + dummy_input_batch.token_ids_cpu, + dummy_input_batch.spec_decode_unsupported_reqs, + ) + end = time.time() + print(f"Iteration time (s): {end - start}") + + def invoke_main() -> None: parser = FlexibleArgumentParser( description="Benchmark the performance of N-gram speculative decode drafting" ) + parser.add_argument( + "--batched", action="store_true", help="consider time to prepare batch" + ) parser.add_argument( "--num-iteration", type=int, default=100, - help="Number of iterations to run to stablize final data readings", + help="Number of iterations to run to stabilize final data readings", ) parser.add_argument( "--num-req", type=int, default=128, help="Number of requests in the batch" @@ -105,8 +197,17 @@ def invoke_main() -> None: help="Number of speculative tokens to generate", ) args = parser.parse_args() - main(args) + + if not args.batched: + benchmark_propose(args) + else: + benchmark_batched_propose(args) +""" +# Example command lines: +# time python3 benchmarks/benchmark_ngram_proposer.py +# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128 +""" # noqa: E501 if __name__ == "__main__": invoke_main() # pragma: no cover diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 02f5f585c0c16..76cf51498020b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,1324 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -r"""Benchmark online serving throughput. - -On the server side, run one of the following commands: - vLLM OpenAI API server - vllm serve \ - --swap-space 16 - -On the client side, run: - python benchmarks/benchmark_serving.py \ - --backend \ - --model \ - --dataset-name sharegpt \ - --dataset-path \ - --request-rate \ # By default is inf - --num-prompts # By default is 1000 - - when using tgi backend, add - --endpoint /generate_stream - to the end of the command above. 
-""" - -import argparse -import asyncio -import gc -import json -import os -import random -import time -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Literal, Optional - -import numpy as np -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase -from typing_extensions import deprecated - -from backend_request_func import ( - ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, - RequestFuncInput, - RequestFuncOutput, -) - -try: - from vllm.transformers_utils.tokenizer import get_tokenizer -except ImportError: - from backend_request_func import get_tokenizer - -try: - from vllm.utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - -from benchmark_dataset import ( - AIMODataset, - ASRDataset, - BurstGPTDataset, - ConversationDataset, - CustomDataset, - HuggingFaceDataset, - InstructCoderDataset, - MTBenchDataset, - NextEditPredictionDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.benchmarks.serve import get_request - -MILLISECONDS_TO_SECONDS_CONVERSION = 1000 - - -@dataclass -class BenchmarkMetrics: - completed: int - total_input: int - total_output: int - request_throughput: float - request_goodput: float - output_throughput: float - total_token_throughput: float - mean_ttft_ms: float - median_ttft_ms: float - std_ttft_ms: float - percentiles_ttft_ms: list[tuple[float, float]] - mean_tpot_ms: float - median_tpot_ms: float - std_tpot_ms: float - percentiles_tpot_ms: list[tuple[float, float]] - mean_itl_ms: float - median_itl_ms: float - std_itl_ms: float - percentiles_itl_ms: list[tuple[float, float]] - # E2EL stands for end-to-end latency per request. - # It is the time taken on the client side from sending - # a request to receiving a complete response. 
- mean_e2el_ms: float - median_e2el_ms: float - std_e2el_ms: float - percentiles_e2el_ms: list[tuple[float, float]] - - -def calculate_metrics( - input_requests: list[SampleRequest], - outputs: list[RequestFuncOutput], - dur_s: float, - tokenizer: PreTrainedTokenizerBase, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - goodput_config_dict: dict[str, float], -) -> tuple[BenchmarkMetrics, list[int]]: - actual_output_lens: list[int] = [] - total_input = 0 - completed = 0 - good_completed = 0 - itls: list[float] = [] - tpots: list[float] = [] - all_tpots: list[float] = [] - ttfts: list[float] = [] - e2els: list[float] = [] - for i in range(len(outputs)): - if outputs[i].success: - output_len = outputs[i].output_tokens - - if not output_len: - # We use the tokenizer to count the number of output tokens - # for some serving backends instead of looking at - # len(outputs[i].itl) since multiple output tokens may be - # bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer( - outputs[i].generated_text, add_special_tokens=False - ).input_ids - ) - actual_output_lens.append(output_len) - total_input += input_requests[i].prompt_len - tpot = 0 - if output_len > 1: - latency_minus_ttft = outputs[i].latency - outputs[i].ttft - tpot = latency_minus_ttft / (output_len - 1) - tpots.append(tpot) - # Note: if output_len <= 1, we regard tpot as 0 for goodput - all_tpots.append(tpot) - itls += outputs[i].itl - ttfts.append(outputs[i].ttft) - e2els.append(outputs[i].latency) - completed += 1 - else: - actual_output_lens.append(0) - - if goodput_config_dict: - valid_metrics = [] - slo_values = [] - - if "ttft" in goodput_config_dict: - valid_metrics.append(ttfts) - slo_values.append( - goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "tpot" in goodput_config_dict: - valid_metrics.append(all_tpots) - slo_values.append( - goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "e2el" in goodput_config_dict: - valid_metrics.append(e2els) - slo_values.append( - goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - - for req_metric in zip(*valid_metrics): - is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) - if is_good_req: - good_completed += 1 - - if completed == 0: - warnings.warn( - "All requests failed. 
This is likely due to a misconfiguration " - "on the benchmark arguments.", - stacklevel=2, - ) - metrics = BenchmarkMetrics( - completed=completed, - total_input=total_input, - total_output=sum(actual_output_lens), - request_throughput=completed / dur_s, - request_goodput=good_completed / dur_s, - output_throughput=sum(actual_output_lens) / dur_s, - total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) - * 1000, # ttfts is empty if streaming is not supported by backend - std_ttft_ms=np.std(ttfts or 0) * 1000, - median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[ - (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles - ], - mean_tpot_ms=np.mean(tpots or 0) * 1000, - std_tpot_ms=np.std(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[ - (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles - ], - mean_itl_ms=np.mean(itls or 0) * 1000, - std_itl_ms=np.std(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[ - (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles - ], - mean_e2el_ms=np.mean(e2els or 0) * 1000, - std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[ - (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles - ], - ) - - return metrics, actual_output_lens - - -async def benchmark( - backend: str, - api_url: str, - base_url: str, - model_id: str, - model_name: str, - tokenizer: PreTrainedTokenizerBase, - input_requests: list[SampleRequest], - logprobs: Optional[int], - request_rate: float, - burstiness: float, - disable_tqdm: bool, - profile: bool, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - ignore_eos: bool, - goodput_config_dict: dict[str, float], - max_concurrency: Optional[int], - lora_modules: Optional[Iterable[str]], - extra_body: Optional[dict], - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, -): - if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[backend] - else: - raise ValueError(f"Unknown backend: {backend}") - - print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0].prompt, - input_requests[0].prompt_len, - input_requests[0].expected_output_len, - input_requests[0].multi_modal_data, - ) - - assert ( - test_mm_content is None - or isinstance(test_mm_content, dict) - or ( - isinstance(test_mm_content, list) - and all(isinstance(item, dict) for item in test_mm_content) - ) - ), "multi_modal_data must be a dict or list[dict]" - test_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - - test_output = await request_func(request_func_input=test_input) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}" - ) - else: - print("Initial test run completed. Starting main benchmark run...") - - if lora_modules: - # For each input request, choose a LoRA module at random. 
- lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))] - ) - - if profile: - print("Starting profiler...") - profile_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler started") - - distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" - - if ramp_up_strategy is not None: - print( - f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " - f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " - "the duration of the benchmark." - ) - else: - print(f"Traffic request rate: {request_rate} RPS.") - - print(f"Burstiness factor: {burstiness} ({distribution})") - print(f"Maximum request concurrency: {max_concurrency}") - - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. - # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - - async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) - async with semaphore: - return await request_func(request_func_input=request_func_input, pbar=pbar) - - benchmark_start_time = time.perf_counter() - tasks: list[asyncio.Task] = [] - - rps_change_events = [] - last_int_rps = -1 - if ramp_up_strategy is not None and ramp_up_start_rps is not None: - last_int_rps = ramp_up_start_rps - rps_change_events.append( - { - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - } - ) - - async for request, current_request_rate in get_request( - input_requests, - request_rate, - burstiness, - ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - ): - if ramp_up_strategy is not None: - current_int_rps = int(current_request_rate) - if current_int_rps > last_int_rps: - timestamp = datetime.now().isoformat() - for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) - last_int_rps = current_int_rps - - prompt, prompt_len, output_len, mm_content, request_id = ( - request.prompt, - request.prompt_len, - request.expected_output_len, - request.multi_modal_data, - request.request_id, - ) - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - - request_func_input = RequestFuncInput( - model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - request_id=request_id, - ) - task = limited_request_func(request_func_input=request_func_input, pbar=pbar) - tasks.append(asyncio.create_task(task)) - outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) - - if pbar is not None: - pbar.close() - - benchmark_duration = time.perf_counter() - benchmark_start_time - - metrics, 
actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentile_metrics=selected_percentile_metrics, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) - - print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) - print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - if max_concurrency is not None: - print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) - if request_rate != float("inf"): - print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) - print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) - print( - "{:<40} {:<10.2f}".format( - "Request throughput (req/s):", metrics.request_throughput - ) - ) - if goodput_config_dict: - print( - "{:<40} {:<10.2f}".format( - "Request goodput (req/s):", metrics.request_goodput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", metrics.output_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput - ) - ) - - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } - - if rps_change_events: - result["rps_change_events"] = rps_change_events - - def process_one_metric( - # E.g., "ttft" - metric_attribute_name: str, - # E.g., "TTFT" - metric_name: str, - # E.g., "Time to First Token" - metric_header: str, - ): - # This function prints and adds statistics of the specified - # metric. - if metric_attribute_name not in selected_percentile_metrics: - return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) - print( - "{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"), - ) - ) - print( - "{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"), - ) - ) - result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms" - ) - result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms" - ) - result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms" - ) - for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): - p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) - result[f"p{p_word}_{metric_attribute_name}_ms"] = value - - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 
1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") - process_one_metric("e2el", "E2EL", "End-to-end Latency") - - print("=" * 50) - - if profile: - print("Stopping profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/stop_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler stopped") - - return result - - -def check_goodput_args(args): - # Check and parse goodput arguments - goodput_config_dict = {} - VALID_NAMES = ["ttft", "tpot", "e2el"] - if args.goodput: - goodput_config_dict = parse_goodput(args.goodput) - for slo_name, slo_val in goodput_config_dict.items(): - if slo_name not in VALID_NAMES: - raise ValueError( - f"Invalid metric name found, {slo_name}: {slo_val}. " - "The service level objective name should be one of " - f"{str(VALID_NAMES)}. " - ) - if slo_val < 0: - raise ValueError( - f"Invalid value found, {slo_name}: {slo_val}. " - "The service level objective value should be " - "non-negative." - ) - return goodput_config_dict - - -def parse_goodput(slo_pairs): - goodput_config_dict = {} - try: - for slo_pair in slo_pairs: - slo_name, slo_val = slo_pair.split(":") - goodput_config_dict[slo_name] = float(slo_val) - except ValueError as err: - raise argparse.ArgumentTypeError( - "Invalid format found for service level objectives. " - 'Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is a " - "number in milliseconds." - ) from err - return goodput_config_dict - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any], file_name: str -) -> None: - metrics = [ - "median_ttft_ms", - "mean_ttft_ms", - "std_ttft_ms", - "p99_ttft_ms", - "mean_tpot_ms", - "median_tpot_ms", - "std_tpot_ms", - "p99_tpot_ms", - "median_itl_ms", - "mean_itl_ms", - "std_itl_ms", - "p99_itl_ms", - ] - # These raw data might be useful, but they are rather big. They can be added - # later if needed - ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={k: [results[k]] for k in metrics}, - extra_info={ - k: results[k] - for k in results - if k not in metrics and k not in ignored_metrics - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -@deprecated( - "benchmark_serving.py is deprecated and will be removed in a future " - "version. Please use 'vllm bench serve' instead.", -) -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - np.random.seed(args.seed) - - backend = args.backend - model_id = args.model - model_name = args.served_model_name - tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model - tokenizer_mode = args.tokenizer_mode - - # Validate ramp-up arguments - if args.ramp_up_strategy is not None: - if args.request_rate != float("inf"): - raise ValueError( - "When using ramp-up, do not specify --request-rate. " - "The request rate will be controlled by ramp-up parameters. " - "Please remove the --request-rate argument." 
- ) - if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: - raise ValueError( - "When using --ramp-up-strategy, both --ramp-up-start-rps and " - "--ramp-up-end-rps must be specified" - ) - if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: - raise ValueError("Ramp-up start and end RPS must be non-negative") - if args.ramp_up_start_rps > args.ramp_up_end_rps: - raise ValueError("Ramp-up start RPS must be less than end RPS") - if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: - raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") - - if args.base_url is not None: - api_url = f"{args.base_url}{args.endpoint}" - base_url = f"{args.base_url}" - else: - api_url = f"http://{args.host}:{args.port}{args.endpoint}" - base_url = f"http://{args.host}:{args.port}" - - tokenizer = get_tokenizer( - tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code, - ) - - if args.dataset_name is None: - raise ValueError( - "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required." - ) - - if args.dataset_name == "custom": - dataset = CustomDataset(dataset_path=args.dataset_path) - input_requests = dataset.sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.custom_output_len, - skip_chat_template=args.custom_skip_chat_template, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) - # For the "sonnet" dataset, formatting depends on the backend. - if args.backend == "openai-chat": - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False, - request_id_prefix=args.request_id_prefix, - ) - else: - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." 
- ) - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "hf": - # all following datasets are implemented from the - # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_class = VisionArenaDataset - args.hf_split = "train" - args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_class = InstructCoderDataset - args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: - dataset_class = MTBenchDataset - args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_class = AIMODataset - args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 - dataset_class = NextEditPredictionDataset - args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ASRDataset - args.hf_split = "train" - else: - supported_datasets = set( - [ - dataset_name - for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ] - ) - raise ValueError( - f"Unsupported dataset path: {args.dataset_path}. " - "Huggingface dataset only supports dataset_path" - f" from one of following: {supported_datasets}. " - "Please consider contributing if you would " - "like to add support for additional dataset formats." - ) - - if dataset_class.IS_MULTIMODAL and backend not in [ - "openai-chat", - "openai-audio", - ]: - # multi-modal benchmark is only available on OpenAI Chat backend. - raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend." - ) - input_requests = dataset_class( - dataset_path=args.dataset_path, - dataset_subset=args.hf_subset, - dataset_split=args.hf_split, - random_seed=args.seed, - no_stream=args.no_stream, - ).sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.hf_output_len, - request_id_prefix=args.request_id_prefix, - ) - - else: - # For datasets that follow a similar structure, use a mapping. - dataset_mapping = { - "sharegpt": lambda: ShareGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - request_id_prefix=args.request_id_prefix, - ), - "burstgpt": lambda: BurstGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - request_id_prefix=args.request_id_prefix, - ), - "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - range_ratio=args.random_range_ratio, - request_id_prefix=args.request_id_prefix, - ), - } - - try: - input_requests = dataset_mapping[args.dataset_name]() - except KeyError as err: - raise ValueError(f"Unknown dataset: {args.dataset_name}") from err - goodput_config_dict = check_goodput_args(args) - - # Collect the sampling parameters. 
- sampling_params = { - k: v - for k, v in { - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "temperature": args.temperature, - }.items() - if v is not None - } - - # Sampling parameters are only supported by openai-compatible backend. - if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: - raise ValueError( - "Sampling parameters are only supported by openai-compatible backends." - ) - - if "temperature" not in sampling_params: - sampling_params["temperature"] = 0.0 # Default to greedy decoding. - - if args.backend == "llama.cpp": - # Disable prompt caching in llama.cpp backend - sampling_params["cache_prompt"] = False - - # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() - - benchmark_result = asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ) - ) - - # Save config and results to json - if args.save_result or args.append_result: - result_json: dict[str, Any] = {} - - # Setup - current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") - result_json["date"] = current_dt - result_json["backend"] = backend - result_json["model_id"] = model_id - result_json["tokenizer_id"] = tokenizer_id - result_json["num_prompts"] = args.num_prompts - - # Metadata - if args.metadata: - for item in args.metadata: - if "=" in item: - kvstring = item.split("=") - result_json[kvstring[0].strip()] = kvstring[1].strip() - else: - raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." 
- ) - # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf" - ) - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency - - if args.ramp_up_strategy is not None: - result_json["ramp_up_strategy"] = args.ramp_up_strategy - result_json["ramp_up_start_rps"] = args.ramp_up_start_rps - result_json["ramp_up_end_rps"] = args.ramp_up_end_rps - - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} - - if not args.save_detailed: - # Remove fields with too many data points - for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", - ]: - if field in result_json: - del result_json[field] - if field in benchmark_result: - del benchmark_result[field] - - # Save to file - base_model_id = model_id.split("/")[-1] - max_concurrency_str = ( - f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None - else "" - ) - if args.ramp_up_strategy is not None: - file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - else: - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - if args.result_filename: - file_name = args.result_filename - if args.result_dir: - os.makedirs(args.result_dir, exist_ok=True) - file_name = os.path.join(args.result_dir, file_name) - with open( - file_name, mode="a+" if args.append_result else "w", encoding="utf-8" - ) as outfile: - # Append a newline. - if args.append_result and outfile.tell() != 0: - outfile.write("\n") - json.dump(result_json, outfile) - save_to_pytorch_benchmark_format(args, result_json, file_name) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput." - ) - parser.add_argument( - "--backend", - type=str, - default="vllm", - choices=list(ASYNC_REQUEST_FUNCS.keys()), - ) - parser.add_argument( - "--base-url", - type=str, - default=None, - help="Server or API base url if not using http host and port.", - ) - # Use 127.0.0.1 here instead of localhost to force the use of ipv4 - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--endpoint", - type=str, - default="/v1/completions", - help="API endpoint.", - ) - parser.add_argument( - "--dataset-name", - type=str, - default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], - help="Name of the dataset to benchmark on.", - ) - parser.add_argument( - "--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--max-concurrency", - type=int, - default=None, - help="Maximum number of concurrent requests. This can be used " - "to help simulate an environment where a higher level component " - "is enforcing a maximum number of concurrent requests. While the " - "--request-rate argument controls the rate at which requests are " - "initiated, this argument will control how many are actually allowed " - "to execute at a time. 
This means that when used in combination, the " - "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.", - ) - - parser.add_argument( - "--model", - type=str, - required=True, - help="Name of the model.", - ) - parser.add_argument( - "--tokenizer", - type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.", - ) - parser.add_argument( - "--logprobs", - type=int, - default=None, - help=( - "Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed" - ), - ) - parser.add_argument( - "--request-rate", - type=float, - default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process or gamma distribution " - "to synthesize the request arrival times.", - ) - parser.add_argument( - "--burstiness", - type=float, - default=1.0, - help="Burstiness factor of the request generation. " - "Only take effect when request_rate is not inf. " - "Default value is 1, which follows Poisson process. " - "Otherwise, the request intervals follow a gamma distribution. " - "A lower burstiness value (0 < burstiness < 1) results in more " - "bursty requests. A higher burstiness value (burstiness > 1) " - "results in a more uniform arrival of requests.", - ) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Trust remote code from huggingface", - ) - parser.add_argument( - "--disable-tqdm", - action="store_true", - help="Specify to disable tqdm progress bar.", - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", - ) - parser.add_argument( - "--save-result", - action="store_true", - help="Specify to save benchmark results to a json file", - ) - parser.add_argument( - "--save-detailed", - action="store_true", - help="When saving the results, whether to include per request " - "information such as response, error, ttfs, tpots, etc.", - ) - parser.add_argument( - "--append-result", - action="store_true", - help="Append the benchmark result to the existing json file.", - ) - parser.add_argument( - "--metadata", - metavar="KEY=VALUE", - nargs="*", - help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " - "for metadata of this run to be saved in the result JSON file " - "for record keeping purposes.", - ) - parser.add_argument( - "--result-dir", - type=str, - default=None, - help="Specify directory to save benchmark json results." - "If not specified, results are saved in the current directory.", - ) - parser.add_argument( - "--result-filename", - type=str, - default=None, - help="Specify the filename to save benchmark json results." 
- "If not specified, results will be saved in " - "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" - " format.", - ) - parser.add_argument( - "--ignore-eos", - action="store_true", - help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", - ) - parser.add_argument( - "--percentile-metrics", - type=str, - default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " - "This argument specifies the metrics to report percentiles. " - 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' - 'Default value is "ttft,tpot,itl".', - ) - parser.add_argument( - "--metric-percentiles", - type=str, - default="99", - help="Comma-separated list of percentiles for selected metrics. " - 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' - 'Default value is "99". ' - 'Use "--percentile-metrics" to select metrics.', - ) - parser.add_argument( - "--goodput", - nargs="+", - required=False, - help='Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is in " - 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' - "separated by spaces. Allowed request level metric names are " - '"ttft", "tpot", "e2el". For more context on the definition of ' - "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve", - ) - parser.add_argument( - "--request-id-prefix", - type=str, - required=False, - default="benchmark-serving", - help="Specify the prefix of request id.", - ) - - # group for dataset specific arguments - custom_group = parser.add_argument_group("custom dataset options") - custom_group.add_argument( - "--custom-output-len", - type=int, - default=256, - help="Number of output tokens per request, used only for custom dataset.", - ) - custom_group.add_argument( - "--custom-skip-chat-template", - action="store_true", - help="Skip applying chat template to prompt, used only for custom dataset.", - ) - - sonnet_group = parser.add_argument_group("sonnet dataset options") - sonnet_group.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help="Number of input tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help="Number of output tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help="Number of prefix tokens per request, used only for sonnet dataset.", - ) - - sharegpt_group = parser.add_argument_group("sharegpt dataset options") - sharegpt_group.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.", - ) - - random_group = parser.add_argument_group("random dataset options") - random_group.add_argument( - "--random-input-len", - type=int, - default=1024, - help="Number of input tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-output-len", - type=int, - default=128, - help="Number of output tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-range-ratio", - type=float, - default=0.0, - help="Range ratio for sampling input/output length, " - "used only for random sampling. 
Must be in the range [0, 1) to define " - "a symmetric sampling range" - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - random_group.add_argument( - "--random-prefix-len", - type=int, - default=0, - help=( - "Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]." - ), - ) - - hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - hf_group.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - hf_group.add_argument( - "--hf-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output lengths " - "from the sampled HF dataset.", - ) - - sampling_group = parser.add_argument_group("sampling parameters") - sampling_group.add_argument( - "--top-p", - type=float, - default=None, - help="Top-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--top-k", - type=int, - default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--min-p", - type=float, - default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--temperature", - type=float, - default=None, - help="Temperature sampling parameter. Only has effect on " - "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).", - ) - - parser.add_argument( - "--tokenizer-mode", - type=str, - default="auto", - choices=["auto", "slow", "mistral", "custom"], - help='The tokenizer mode.\n\n* "auto" will use the ' - 'fast tokenizer if available.\n* "slow" will ' - "always use the slow tokenizer. \n* " - '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.', - ) - - parser.add_argument( - "--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ", - ) - - parser.add_argument( - "--lora-modules", - nargs="+", - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.", - ) - - parser.add_argument( - "--ramp-up-strategy", - type=str, - default=None, - choices=["linear", "exponential"], - help="The ramp-up strategy. This would be used to " - "ramp up the request rate from initial RPS to final " - "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " - "over the duration of the benchmark.", - ) - parser.add_argument( - "--ramp-up-start-rps", - type=int, - default=None, - help="The starting request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - parser.add_argument( - "--ramp-up-end-rps", - type=int, - default=None, - help="The ending request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. 
+ +Please use the following command instead: + vllm bench serve + +For help with the new command, run: + vllm bench serve --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench serve --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index ca6843a72aa36..58b9767d09390 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -37,14 +37,13 @@ from typing import Optional import datasets import numpy as np import pandas as pd -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase - from backend_request_func import ( ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput, ) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase try: from vllm.transformers_utils.tokenizer import get_tokenizer @@ -449,7 +448,8 @@ async def benchmark( def prepare_extra_body(request) -> dict: extra_body = {} # Add the schema to the extra_body - extra_body[request.structure_type] = request.schema + extra_body["structured_outputs"] = {} + extra_body["structured_outputs"][request.structure_type] = request.schema return extra_body print("Starting initial single prompt test run...") @@ -696,11 +696,11 @@ def evaluate(ret, args): return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): - if args.structure_type == "guided_json": + if args.structure_type == "json": return _eval_correctness_json(expected, actual) - elif args.structure_type == "guided_regex": + elif args.structure_type == "regex": return _eval_correctness_regex(expected, actual) - elif args.structure_type == "guided_choice": + elif args.structure_type == "choice": return _eval_correctness_choice(expected, actual) else: return None @@ -780,18 +780,18 @@ def main(args: argparse.Namespace): ) if args.dataset == "grammar": - args.structure_type = "guided_grammar" + args.structure_type = "grammar" elif args.dataset == "regex": - args.structure_type = "guided_regex" + args.structure_type = "regex" elif args.dataset == "choice": - args.structure_type = "guided_choice" + args.structure_type = "choice" else: - args.structure_type = "guided_json" + args.structure_type = "json" if args.no_structured_output: args.structured_output_ratio = 0 if args.save_results: - result_file_name = f"{args.structured_output_ratio}guided" + result_file_name = f"{args.structured_output_ratio}so" result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -909,13 +909,13 @@ def create_argument_parser(): parser.add_argument( "--tokenizer", type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", ) parser.add_argument( "--tokenizer-mode", type=str, default="auto", - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", ) parser.add_argument( "--num-prompts", @@ -998,7 +998,7 @@ def create_argument_parser(): "--percentile-metrics", type=str, default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " + help="Comma-separated list of selected metrics to report percentiles. " "This argument specifies the metrics to report percentiles. 
" 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' 'Default value is "ttft,tpot,itl".', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c7f290e1eb88e..b6dc0918fd4d1 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,742 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark offline inference throughput.""" - -import argparse -import dataclasses -import json -import os -import random -import time -import warnings -from typing import Any, Optional, Union - -import torch -import uvloop -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase -from typing_extensions import deprecated - -from benchmark_dataset import ( - AIMODataset, - BurstGPTDataset, - ConversationDataset, - InstructCoderDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, -) -from vllm.inputs import TextPrompt, TokensPrompt -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import BeamSearchParams -from vllm.utils import FlexibleArgumentParser, merge_async_iterators - - -def run_vllm( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, Optional[list[RequestOutput]]]: - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - # Add the requests to the engine. - prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests: Optional[list[LoRARequest]] = None - if engine_args.enable_lora: - lora_requests = [request.lora_request for request in requests] - - use_beam_search = False - - outputs = None - if not use_beam_search: - start = time.perf_counter() - outputs = llm.generate( - prompts, sampling_params, lora_request=lora_requests, use_tqdm=True - ) - end = time.perf_counter() - else: - assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] - # output_len should be the same for all requests. 
- output_len = requests[0].expected_output_len - for request in requests: - assert request.expected_output_len == output_len - start = time.perf_counter() - llm.beam_search( - prompts, - BeamSearchParams( - beam_width=n, - max_tokens=output_len, - ignore_eos=True, - ), - ) - end = time.perf_counter() - return end - start, outputs - - -def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, list[RequestOutput]]: - """ - Run vLLM chat benchmark. This function is recommended ONLY for benchmarking - multimodal models as it properly handles multimodal inputs and chat - formatting. For non-multimodal models, use run_vllm() instead. - """ - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests." - ) - - prompts = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - start = time.perf_counter() - outputs = llm.chat(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() - return end - start, outputs - - -async def run_vllm_async( - requests: list[SampleRequest], - n: int, - engine_args: AsyncEngineArgs, - disable_frontend_multiprocessing: bool = False, - disable_detokenize: bool = False, -) -> float: - from vllm import SamplingParams - - async with build_async_engine_client_from_engine_args( - engine_args, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, - ) as llm: - model_config = await llm.get_model_config() - assert all( - model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - - # Add the requests to the engine. 
- prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - lora_requests: list[Optional[LoRARequest]] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests.append(request.lora_request) - - generators = [] - start = time.perf_counter() - for i, (prompt, sp, lr) in enumerate( - zip(prompts, sampling_params, lora_requests) - ): - generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - end = time.perf_counter() - return end - start - - -def run_hf( - requests: list[SampleRequest], - model: str, - tokenizer: PreTrainedTokenizerBase, - n: int, - max_batch_size: int, - trust_remote_code: bool, - disable_detokenize: bool = False, -) -> float: - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code - ) - if llm.config.model_type == "llama": - # To enable padding in the HF backend. - tokenizer.pad_token = tokenizer.eos_token - llm = llm.cuda() - - pbar = tqdm(total=len(requests)) - start = time.perf_counter() - batch: list[str] = [] - max_prompt_len = 0 - max_output_len = 0 - for i in range(len(requests)): - prompt = requests[i].prompt - prompt_len = requests[i].prompt_len - output_len = requests[i].expected_output_len - # Add the prompt to the batch. - batch.append(prompt) - max_prompt_len = max(max_prompt_len, prompt_len) - max_output_len = max(max_output_len, output_len) - if len(batch) < max_batch_size and i != len(requests) - 1: - # Check if we can add more requests to the batch. - next_prompt_len = requests[i + 1].prompt_len - next_output_len = requests[i + 1].expected_output_len - if ( - max(max_prompt_len, next_prompt_len) - + max(max_output_len, next_output_len) - ) <= 2048: - # We can add more requests to the batch. - continue - - # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids - llm_outputs = llm.generate( - input_ids=input_ids.cuda(), - do_sample=True, - num_return_sequences=n, - temperature=1.0, - top_p=1.0, - use_cache=True, - max_new_tokens=max_output_len, - ) - if not disable_detokenize: - # Include the decoding time. - tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) - pbar.update(len(batch)) - - # Clear the batch. 
- batch = [] - max_prompt_len = 0 - max_output_len = 0 - end = time.perf_counter() - return end - start - - -def run_mii( - requests: list[SampleRequest], - model: str, - tensor_parallel_size: int, - output_len: int, -) -> float: - from mii import client, serve - - llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [request.prompt for request in requests] - - start = time.perf_counter() - llm.generate(prompts, max_new_tokens=output_len) - end = time.perf_counter() - client = client(model) - client.terminate_server() - return end - start - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any] -) -> None: - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={ - "requests_per_second": [results["requests_per_second"]], - "tokens_per_second": [results["tokens_per_second"]], - }, - extra_info={ - k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -def get_requests(args, tokenizer): - # Common parameters for all dataset types. - common_kwargs = { - "dataset_path": args.dataset_path, - "random_seed": args.seed, - } - sample_kwargs = { - "tokenizer": tokenizer, - "lora_path": args.lora_path, - "max_loras": args.max_loras, - "num_requests": args.num_prompts, - "input_len": args.input_len, - "output_len": args.output_len, - } - - if args.dataset_path is None or args.dataset_name == "random": - sample_kwargs["range_ratio"] = args.random_range_ratio - sample_kwargs["prefix_len"] = args.prefix_len - dataset_cls = RandomDataset - elif args.dataset_name == "sharegpt": - dataset_cls = ShareGPTDataset - if args.backend == "vllm-chat": - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_name == "sonnet": - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." - ) - dataset_cls = SonnetDataset - sample_kwargs["prefix_len"] = args.prefix_len - sample_kwargs["return_prompt_formatted"] = True - elif args.dataset_name == "burstgpt": - dataset_cls = BurstGPTDataset - elif args.dataset_name == "hf": - common_kwargs["no_stream"] = args.no_stream - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = VisionArenaDataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = InstructCoderDataset - common_kwargs["dataset_split"] = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = ConversationDataset - common_kwargs["dataset_subset"] = args.hf_subset - common_kwargs["dataset_split"] = args.hf_split - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_cls = AIMODataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - else: - raise ValueError(f"Unknown dataset name: {args.dataset_name}") - # Remove None values - sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} - return dataset_cls(**common_kwargs).sample(**sample_kwargs) - - -@deprecated( - "benchmark_throughput.py is deprecated and will be removed in a " - "future version. 
Please use 'vllm bench throughput' instead.", -) -def main(args: argparse.Namespace): - if args.seed is None: - args.seed = 0 - print(args) - random.seed(args.seed) - # Sample the requests. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code - ) - requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None for request in requests) - request_outputs: Optional[list[RequestOutput]] = None - if args.backend == "vllm": - if args.async_engine: - elapsed_time = uvloop.run( - run_vllm_async( - requests, - args.n, - AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, - ) - ) - else: - elapsed_time, request_outputs = run_vllm( - requests, - args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize, - ) - elif args.backend == "hf": - assert args.tensor_parallel_size == 1 - elapsed_time = run_hf( - requests, - args.model, - tokenizer, - args.n, - args.hf_max_batch_size, - args.trust_remote_code, - args.disable_detokenize, - ) - elif args.backend == "mii": - elapsed_time = run_mii( - requests, args.model, args.tensor_parallel_size, args.output_len - ) - elif args.backend == "vllm-chat": - elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize - ) - else: - raise ValueError(f"Unknown backend: {args.backend}") - - if request_outputs: - # Note: with the vllm and vllm-chat backends, - # we have request_outputs, which we use to count tokens. - total_prompt_tokens = 0 - total_output_tokens = 0 - for ro in request_outputs: - if not isinstance(ro, RequestOutput): - continue - total_prompt_tokens += ( - len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 - ) - total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) - total_num_tokens = total_prompt_tokens + total_output_tokens - else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) - total_output_tokens = sum(r.expected_output_len for r in requests) - total_prompt_tokens = total_num_tokens - total_output_tokens - - if is_multi_modal and args.backend != "vllm-chat": - print( - "\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details." - ) - # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. - # vllm-chat backend counts the image tokens now - - print( - f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s" - ) - print(f"Total num prompt tokens: {total_prompt_tokens}") - print(f"Total num output tokens: {total_output_tokens}") - - # Output JSON results if specified - if args.output_json: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": total_num_tokens / elapsed_time, - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def validate_args(args): - """ - Validate command-line arguments. - """ - - # === Deprecation and Defaulting === - if args.dataset is not None: - warnings.warn( - "The '--dataset' argument will be deprecated in the next release. 
" - "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2, - ) - args.dataset_path = args.dataset - - if not getattr(args, "tokenizer", None): - args.tokenizer = args.model - - # === Backend Validation === - valid_backends = {"vllm", "hf", "mii", "vllm-chat"} - if args.backend not in valid_backends: - raise ValueError(f"Unsupported backend: {args.backend}") - - # === Dataset Configuration === - if not args.dataset and not args.dataset_path: - print("When dataset path is not set, it will default to random dataset") - args.dataset_name = "random" - if args.input_len is None: - raise ValueError("input_len must be provided for a random dataset") - - # === Dataset Name Specific Checks === - # --hf-subset and --hf-split: only used - # when dataset_name is 'hf' - if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None - ): - warnings.warn( - "--hf-subset and --hf-split will be ignored \ - since --dataset-name is not 'hf'.", - stacklevel=2, - ) - elif args.dataset_name == "hf": - if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm-chat", ( - f"{args.dataset_path} needs to use vllm-chat as the backend." - ) # noqa: E501 - elif args.dataset_path in ( - InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm", ( - f"{args.dataset_path} needs to use vllm as the backend." - ) # noqa: E501 - else: - raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") - - # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != "random" and args.random_range_ratio is not None: - warnings.warn( - "--random-range-ratio will be ignored since \ - --dataset-name is not 'random'.", - stacklevel=2, - ) - - # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not - # set. - if ( - args.dataset_name not in {"random", "sonnet", None} - and args.prefix_len is not None - ): - warnings.warn( - "--prefix-len will be ignored since --dataset-name\ - is not 'random', 'sonnet', or not set.", - stacklevel=2, - ) - - # === LoRA Settings === - if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError("LoRA benchmarking is only supported for vLLM backend") - if getattr(args, "enable_lora", False) and args.lora_path is None: - raise ValueError("LoRA path must be provided when enable_lora is True") - - # === Backend-specific Validations === - if args.backend == "hf" and args.hf_max_batch_size is None: - raise ValueError("HF max batch size is required for HF backend") - if args.backend != "hf" and args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - - if ( - args.backend in {"hf", "mii"} - and getattr(args, "quantization", None) is not None - ): - raise ValueError("Quantization is only for vLLM backend.") - - if args.backend == "mii" and args.dtype != "auto": - raise ValueError("dtype must be auto for MII backend.") - if args.backend == "mii" and args.n != 1: - raise ValueError("n must be 1 for MII backend.") - if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError("Tokenizer must be the same as the model for MII backend.") - - # --data-parallel is not supported currently. 
- # https://github.com/vllm-project/vllm/issues/16222 - if args.data_parallel_size > 1: - raise ValueError( - "Data parallel is not supported in offline benchmark, " - "please use benchmark serving instead" - ) - - -def create_argument_parser(): - parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument( - "--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm", - ) - parser.add_argument( - "--dataset-name", - type=str, - choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], - help="Name of the dataset to benchmark on.", - default="sharegpt", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help="Path to the ShareGPT dataset, will be deprecated in\ - the next release. The dataset is expected to " - "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]", - ) - parser.add_argument( - "--dataset-path", type=str, default=None, help="Path to the dataset" - ) - parser.add_argument( - "--input-len", - type=int, - default=None, - help="Input prompt length for each request", - ) - parser.add_argument( - "--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.", - ) - parser.add_argument( - "--n", type=int, default=1, help="Number of generated sequences per prompt." - ) - parser.add_argument( - "--num-prompts", type=int, default=1000, help="Number of prompts to process." - ) - parser.add_argument( - "--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the throughput results in JSON format.", - ) - parser.add_argument( - "--async-engine", - action="store_true", - default=False, - help="Use vLLM async engine rather than LLM class.", - ) - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - default=False, - help="Disable decoupled async engine frontend.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)" - ), - ) - # LoRA - parser.add_argument( - "--lora-path", - type=str, - default=None, - help="Path to the LoRA adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.", - ) - parser.add_argument( - "--prefix-len", - type=int, - default=None, - help=f"Number of prefix tokens to be used in RandomDataset " - "and SonnetDataset. For RandomDataset, the total input " - "length is the sum of prefix-len (default: " - f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length " - "sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]. For SonnetDataset, " - f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " - "controls how much of the input is fixed lines versus " - "random lines, but the total input length remains approximately " - "input_len tokens.", - ) - # random dataset - parser.add_argument( - "--random-range-ratio", - type=float, - default=None, - help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) " - "for sampling input/output length, " - "used only for RandomDataset. 
Must be in the range [0, 1) to " - "define a symmetric sampling range " - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - - # hf dtaset - parser.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - parser.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - - parser = AsyncEngineArgs.add_cli_args(parser) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.tokenizer is None: - args.tokenizer = args.model - validate_args(args) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench throughput + +For help with the new command, run: + vllm bench throughput --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench throughput --help +""") + sys.exit(1) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index a5a5b52f60397..02f8c593392c4 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 92f97ffabea2a..d683835db96a4 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -55,24 +55,20 @@ benchmark() { output_len=$2 - CUDA_VISIBLE_DEVICES=0 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & - CUDA_VISIBLE_DEVICES=1 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index af2bcba3ea57a..35c86cc845221 100644 --- 
a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -38,16 +38,12 @@ wait_for_server() { launch_chunked_prefill() { model="meta-llama/Meta-Llama-3.1-8B-Instruct" # disagg prefill - CUDA_VISIBLE_DEVICES=0 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ --port 8100 \ --max-model-len 10000 \ --enable-chunked-prefill \ --gpu-memory-utilization 0.6 & - CUDA_VISIBLE_DEVICES=1 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ --port 8200 \ --max-model-len 10000 \ --enable-chunked-prefill \ @@ -62,23 +58,19 @@ launch_chunked_prefill() { launch_disagg_prefill() { model="meta-llama/Meta-Llama-3.1-8B-Instruct" # disagg prefill - CUDA_VISIBLE_DEVICES=0 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=0 vllm serve $model \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & - CUDA_VISIBLE_DEVICES=1 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ + CUDA_VISIBLE_DEVICES=1 vllm serve $model \ --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py new file mode 100644 index 0000000000000..f1e504499eaf6 --- /dev/null +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton as vllm_triton + +assert current_platform.is_cuda(), ( + "Only support benchmarking w8a8 block fp8 kernel on CUDA device." 
+) + +# DeepSeek-V3 weight shapes +DEEPSEEK_V3_SHAPES = [ + (512 + 64, 7168), + (2112, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + (18432 * 2, 7168), + (24576, 1536), + (12288, 7168), + (4096, 7168), + (7168, 2048), +] + + +def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): + """Build runner function for w8a8 block fp8 matmul.""" + factor_for_scale = 1e-2 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + # Create random FP8 tensors + A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + + B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Create scales + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) + * factor_for_scale + ) + + # SM90 CUTLASS requires row-major format for scales + if use_cutlass and current_platform.is_device_capability(90): + Bs = Bs.T.contiguous() + + def run(): + if use_cutlass: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True + ) + else: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False + ) + + return run + + +# Determine available providers +available_providers = ["torch-bf16", "w8a8-block-fp8-triton"] +plot_title = "BF16 vs W8A8 Block FP8 GEMMs" + +if CUTLASS_BLOCK_FP8_SUPPORTED: + available_providers.append("w8a8-block-fp8-cutlass") + + +@vllm_triton.testing.perf_report( + vllm_triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=available_providers, + line_names=available_providers, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs W8A8 Block FP8 GEMMs", + args={}, + ) +) +def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): + M = batch_size + device = "cuda" + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + a = torch.randn((M, K), device=device, dtype=torch.bfloat16) + b = torch.randn((N, K), device=device, dtype=torch.bfloat16) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-triton": + run_w8a8_triton = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_triton(), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-cutlass": + run_w8a8_cutlass = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=True + ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_cutlass(), quantiles=quantiles + ) + else: + raise ValueError(f"Unknown provider: {provider}") + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +if __name__ == "__main__": + block_size = (128, 128) + + for N, K in DEEPSEEK_V3_SHAPES: + print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}") + + print(f"TFLOP/s comparison (block_size={block_size}):") + benchmark_tflops.run( + print_data=True, + # show_plots=False, + # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}", + 
N=N, + K=K, + block_size=block_size, + ) + + print("\nBenchmark finished!") diff --git a/benchmarks/kernels/bench_mxfp4_qutlass.py b/benchmarks/kernels/bench_mxfp4_qutlass.py new file mode 100644 index 0000000000000..dfc7721876a17 --- /dev/null +++ b/benchmarks/kernels/bench_mxfp4_qutlass.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "mxfp4": dict(no_a_quant=False, enabled=True), + "mxfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_mxfp4( + b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx( + b, forward_hadamard_matrix, method="abs_max" + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton") + return weight_hf_e2m1, weight_hf_scale_block + + +def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4( + b, forward_hadamard_matrix, device + ) + alpha = torch.tensor([1.0], device="cuda") + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + + def run(): + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs MXFP4 
GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_mxfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_mxfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_nvfp4_gemm.py b/benchmarks/kernels/bench_nvfp4_gemm.py index 9e832c9faa8e8..6b19eb113f3e7 100644 --- a/benchmarks/kernels/bench_nvfp4_gemm.py +++ b/benchmarks/kernels/bench_nvfp4_gemm.py @@ -3,6 +3,7 @@ import argparse import copy import itertools +import os import torch from weight_shapes import WEIGHT_SHAPES @@ -23,21 +24,45 @@ PROVIDER_CFGS = { "torch-bf16": dict(enabled=True), "nvfp4": dict(no_a_quant=False, enabled=True), "nvfp4-noquant": dict(no_a_quant=True, enabled=True), + "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True), + "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True), } +_needs_fbgemm = any( + v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False) +) +if _needs_fbgemm: + try: + from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import ( + triton_scale_nvfp4_quant, + ) + except ImportError: + print( + "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. " + "These providers will be skipped. Please install fbgemm_gpu with: " + "'pip install fbgemm-gpu-genai' to run them." + ) + # Disable FBGEMM providers so the benchmark can run. 
+ for cfg in PROVIDER_CFGS.values(): + if cfg.get("fbgemm"): + cfg["enabled"] = False + _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] -def _quant_weight_nvfp4(b: torch.Tensor, device: str): +def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg): # Compute global scale for weight b_amax = torch.abs(b).max().to(torch.float32) b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax - b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale) + if "fbgemm" in cfg and cfg["fbgemm"]: + b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale) + else: + b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale) return b_fp4, scale_b_fp4, b_global_scale def build_nvfp4_runner(cfg, a, b, dtype, device): - b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device) + b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg) # Compute global scale for activation # NOTE: This is generally provided ahead-of-time by the model checkpoint. @@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device): # Alpha for the GEMM operation alpha = 1.0 / (a_global_scale * b_global_scale) + if "fbgemm" in cfg and cfg["fbgemm"]: + if cfg["no_a_quant"]: + a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale) + + def run(): + return torch.ops.fbgemm.f4f4bf16( + a_fp4, + b_fp4, + scale_a_fp4, + scale_b_fp4, + global_scale=alpha, + use_mx=False, + ) + + return run + else: + + def run(): + a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale) + return torch.ops.fbgemm.f4f4bf16( + a_fp4, + b_fp4, + scale_a_fp4, + scale_b_fp4, + global_scale=alpha, + use_mx=False, + ) + + return run if cfg["no_a_quant"]: # Pre-quantize activation @@ -130,10 +184,13 @@ if __name__ == "__main__": for K, N, model in prepare_shapes(args): print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:") + save_dir = f"bench_nvfp4_res_n{N}_k{K}" + os.makedirs(save_dir, exist_ok=True) + benchmark.run( print_data=True, show_plots=True, - save_path=f"bench_nvfp4_res_n{N}_k{K}", + save_path=save_dir, N=N, K=K, ) diff --git a/benchmarks/kernels/bench_nvfp4_qutlass.py b/benchmarks/kernels/bench_nvfp4_qutlass.py new file mode 100644 index 0000000000000..6fecc816f9466 --- /dev/null +++ b/benchmarks/kernels/bench_nvfp4_qutlass.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_nvfp4( + b: torch.Tensor, + forward_hadamard_matrix: torch.Tensor, + global_scale: torch.Tensor, + device: str, + M: int, + N: int, + K: int, +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv( + b, forward_hadamard_matrix, global_scale + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return weight_hf_e2m1, weight_hf_scale_block + + +def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K): + alpha = torch.tensor([1.0], device="cuda") + global_scale = torch.tensor([1.0], device="cuda") + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4( + b, forward_hadamard_matrix, global_scale, device, M, N, K + ) + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + + def run(): + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles 
+ ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [16, 32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_nvfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 923d678f1f2db..e08e5680c191e 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -2,14 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Callable +from unittest.mock import patch +import pandas as pd import torch -from vllm import _custom_ops as ops -from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +def with_triton_mode(fn): + """Temporarily force the Triton fallback path""" + + def wrapped(*args, **kwargs): + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + return fn(*args, **kwargs) + + return wrapped # TODO(luka): use standalone_compile utility @@ -21,78 +32,238 @@ def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): return inner -torch._dynamo.config.recompile_limit = 8888 -compilation_config = CompilationConfig(custom_ops=["none"]) -with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): - torch_per_token_quant_fp8 = torch.compile( - QuantFP8(False, GroupShape.PER_TOKEN), - fullgraph=True, - dynamic=False, # recompile for different shapes - ) +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) # First dim is explicitly dynamic to simulate vLLM usage - torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + return with_dyn_arg(fwd, 0, 0) -def cuda_per_token_quant_fp8( - input: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(input) +torch._dynamo.config.recompile_limit = 8888 -def calculate_diff(batch_size: int, seq_len: int): - """Calculate difference between Triton and CUDA implementations.""" +def calculate_diff( + batch_size: int, + hidden_size: int, + group_shape: GroupShape, + dtype: torch.dtype, +): + """Calculate the difference between Inductor and CUDA implementations.""" device = torch.device("cuda") - x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + x = torch.randn((batch_size, hidden_size), dtype=dtype, 
device=device) - torch_out, torch_scale = torch_per_token_quant_fp8(x) - cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) - if torch.allclose( - cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x) + torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) + cuda_out, cuda_scale = quant_fp8.forward_cuda(x) + + try: + torch.testing.assert_close( + cuda_out.to(torch.float32), + torch_out.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5) + torch.testing.assert_close( + cuda_out.to(torch.float32), + torch_eager_out.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5) print("✅ All implementations match") - else: + except AssertionError as e: print("❌ Implementations differ") + print(e) -batch_size_range = [1, 16, 32, 64, 128] -seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] - -configs = list(itertools.product(batch_size_range, seq_len_range)) +configs = [] -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], - x_vals=configs, - line_arg="provider", - line_vals=["torch", "cuda"], - line_names=["Torch", "CUDA"], - styles=[("blue", "-"), ("green", "-")], - ylabel="us", - plot_name="per-token-dynamic-quant-fp8-performance", - args={}, - ) -) -def benchmark_quantization(batch_size, seq_len, provider): - dtype = torch.float16 +def benchmark_quantization( + batch_size, + hidden_size, + provider, + group_shape: GroupShape, + col_major: bool, + dtype: torch.dtype, +): device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) if provider == "torch": - fn = lambda: torch_per_token_quant_fp8(x.clone()) + fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone()) elif provider == "cuda": - fn = lambda: cuda_per_token_quant_fp8(x.clone()) + fn = lambda: quant_fp8.forward_cuda(x.clone()) + elif provider == "triton": + if not group_shape.is_per_group(): + # Triton only supported for per-group + return 0, 0, 0 + + fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone()) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms +# TODO(luka) extract to utils +def compute_geomean_speedups( + df: pd.DataFrame, + baseline_col: str, + speedup_cols: list[str], + groupby_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compute geometric mean speedups over a baseline column. + + Args: + df: Input dataframe + baseline_col: Column to use as baseline + speedup_cols: Columns to compute speedups for + groupby_cols: Columns to group by. If None, compute over entire df. 
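+
+    Example (illustrative numbers, not real benchmark output):
+        compute_geomean_speedups(
+            pd.DataFrame({"Torch (Compiled)": [2.0, 4.0], "CUDA": [1.0, 2.0]}),
+            baseline_col="Torch (Compiled)",
+            speedup_cols=["CUDA"],
+        )  # -> one-row DataFrame with CUDA == 2.0 (a 2x geometric-mean speedup)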
+ + Returns: + pd.DataFrame with geometric mean speedups + """ + from scipy.stats import gmean + + def geo_speedup(group: pd.DataFrame) -> pd.Series: + ratios = { + col: (group[baseline_col] / group[col]).values for col in speedup_cols + } + return pd.Series({col: gmean(vals) for col, vals in ratios.items()}) + + if groupby_cols is None: + result = geo_speedup(df).to_frame().T + else: + result = ( + df.groupby(groupby_cols) + .apply(geo_speedup, include_groups=False) + .reset_index() + ) + + return result + + if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=4096) - benchmark_quantization.run(print_data=True) + parser = FlexibleArgumentParser( + description="Benchmark the various implementations of QuantFP8 (dynamic-only)" + ) + parser.add_argument("-c", "--check", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + default=[896, 1024, 2048, 4096, 7168], + help="Hidden sizes to benchmark", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=[1, 16, 128, 512, 1024], + help="Batch sizes to benchmark", + ) + parser.add_argument( + "--group-sizes", + type=int, + nargs="+", + default=None, + help="Group sizes for GroupShape(1,N) to benchmark. " + "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Disable column-major scales testing", + ) + + args = parser.parse_args() + assert args + + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + hidden_sizes = args.hidden_sizes + batch_sizes = args.batch_sizes + + if args.group_sizes is not None: + group_shapes = [] + for size in args.group_sizes: + if size == 0: + group_shapes.append(GroupShape.PER_TENSOR) + elif size == -1: + group_shapes.append(GroupShape.PER_TOKEN) + else: + group_shapes.append(GroupShape(1, size)) + else: + group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), + ] + + column_major_scales = [False] if args.no_column_major else [True, False] + + config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, + ) + + # filter out column-major scales for non-group, reverse order + configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) + + print(f"Running {len(configs)} configurations:") + print(f" Hidden sizes: {hidden_sizes}") + print(f" Batch sizes: {batch_sizes}") + print(f" Group shapes: {[str(g) for g in group_shapes]}") + print(f" Column major scales: {column_major_scales}") + print() + + if args.check: + for group_shape in group_shapes: + group_size = group_shape[1] + print(f"{group_size=}") + calculate_diff( + batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype + ) + + benchmark = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], + ylabel="us", + plot_name="QuantFP8 performance", + args={}, + ) + )(benchmark_quantization) + + df = benchmark.run(print_data=True, dtype=dtype, return_df=True) + + # Print geomean speedups + geo_table_grouped = compute_geomean_speedups( + df, + baseline_col="Torch (Compiled)", + speedup_cols=["CUDA", "Triton"], + 
groupby_cols=["col_major", "group_shape"], + ) + + print("Speedup over Torch (Compiled)") + print(geo_table_grouped.to_string(index=False)) diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py new file mode 100644 index 0000000000000..93edbcc9391fc --- /dev/null +++ b/benchmarks/kernels/benchmark_activation.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# benchmark custom activation op performance +import itertools + +import torch + +import vllm.model_executor.layers.activation # noqa F401 +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +intermediate_size = [3072, 9728, 12288] +configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) + + +def benchmark_activation( + batch_size: int, + seq_len: int, + intermediate_size: int, + provider: str, + func_name: str, + dtype: torch.dtype, +): + device = "cuda" + num_tokens = batch_size * seq_len + dim = intermediate_size + current_platform.seed_everything(42) + torch.set_default_device(device) + + if func_name == "gelu_and_mul": + layer = CustomOp.op_registry[func_name](approximate="none") + elif func_name == "gelu_and_mul_tanh": + layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh") + elif func_name == "fatrelu_and_mul": + threshold = 0.5 + layer = CustomOp.op_registry[func_name](threshold) + else: + layer = CustomOp.op_registry[func_name]() + + x = torch.randn(num_tokens, dim, dtype=dtype, device=device) + compiled_layer = torch.compile(layer.forward_native) + + if provider == "custom": + fn = lambda: layer(x) + elif provider == "compiled": + fn = lambda: compiled_layer(x) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the custom activation op.") + parser.add_argument( + "--func-name", + type=str, + choices=[ + "mul_and_silu", + "silu_and_mul", + "gelu_and_mul", + "gelu_and_mul_tanh", + "fatrelu_and_mul", + "swigluoai_and_mul", + "gelu_new", + "gelu_fast", + "quick_gelu", + ], + default="silu_and_mul", + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + args = parser.parse_args() + assert args + + func_name = args.func_name + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + perf_report = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "intermediate_size"], + x_vals=configs, + line_arg="provider", + line_vals=["custom", "compiled"], + line_names=["Custom OP", "Compiled"], + styles=[("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"{func_name}-op-performance", + args={}, + ) + ) + + perf_report( + lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation( + batch_size, seq_len, intermediate_size, provider, func_name, dtype + ) + ).run(print_data=True) diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py index 35c20ee41b9a9..726a2a371d109 100644 --- a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -13,6 
+13,10 @@ import torch.utils.benchmark as benchmark from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.scalar_type import scalar_types @@ -140,6 +144,12 @@ def bench_run( a_fp8_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + for _ in range(num_repeats): fused_experts( a, @@ -147,10 +157,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def run_cutlass_moe_fp4( @@ -172,25 +179,27 @@ def bench_run( device: torch.device, num_repeats: int, ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) for _ in range(num_repeats): with nvtx.annotate("cutlass_moe_fp4", color="green"): cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, - a2_gscale=a2_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_gs, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ -211,26 +220,29 @@ def bench_run( e: int, device: torch.device, ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): return cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_alphas, - a2_gscale=a2_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_triton_from_graph( @@ -246,16 +258,18 @@ def bench_run( with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) return fused_experts( a, w1, w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py new file mode 100644 index 0000000000000..b419b2fa0e3eb --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe +kernel. Both kernels take in fp8 quantized weights and 16-bit activations, +but use different quantization strategies and backends. 
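+
+Example (flag names follow the argparse options defined at the bottom of this
+file; adjust the model/shape selection as needed):
+    python benchmarks/kernels/benchmark_cutlass_moe_fp8.py \
+        --models mixtral-8x7b --tp-sizes 1 --batch-sizes 32 128 512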
+""" + +import nvtx +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +# Weight shapes for different models: [num_experts, topk, hidden_size, +# intermediate_size] +WEIGHT_SHAPES_MOE = { + "mixtral-8x7b": [ + [8, 2, 4096, 14336], + ], + "deepseek-v2": [ + [160, 6, 5120, 12288], + ], + "custom-small": [ + [8, 2, 2048, 7168], + ], + "glm45-fp8": [ + [128, 8, 4096, 1408], + ], + "Llama-4-Maverick-17B-128E-Instruct-FP8": [ + [128, 1, 5120, 8192], + ], +} + +DEFAULT_MODELS = [ + "mixtral-8x7b", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] + +FP8_DTYPE = current_platform.fp8_dtype() + + +def bench_run( + results: list, + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + + # Create input activations + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + + # Create weights + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + # Create FP8 quantized weights and scales for both kernels + w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE) + + # Create scales based on quantization strategy + if per_out_ch: + # Per-channel quantization + w1_scale = torch.empty( + (num_experts, 2 * n, 1), device=device, dtype=torch.float32 + ) + w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32) + else: + # Per-tensor quantization + w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + # Quantize weights + for expert in range(num_experts): + if per_out_ch: + # Per-channel quantization - not yet implemented properly + # For now, fall back to per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Expand scalar scales to the expected per-channel shape + w1_scale[expert] = w1_scale_temp.expand(2 * n, 1) + w2_scale[expert] = w2_scale_temp.expand(k, 1) + else: + # Per-tensor quantization + w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert]) + # Store scalar scales in [1, 1] tensors + w1_scale[expert, 0, 0] = w1_scale_temp + w2_scale[expert, 0, 0] = w2_scale_temp + + # Prepare weights for CUTLASS (no transpose needed) + w1_fp8q_cutlass = w1_fp8q # Keep original [E, 2N, K] + w2_fp8q_cutlass = w2_fp8q # Keep original [E, K, N] + + # Create router scores and get topk + score = torch.randn((m, num_experts), device=device, dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization + # Force per-tensor quantization for all cases to match working e2e setup + a1_scale = torch.full((), 1e-2, device=device, 
dtype=torch.float32) + a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32) + + # Force per-tensor quantization for all cases + per_act_token = False + + # Create stride tensors for CUTLASS + ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device) + ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device) + c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device) + c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + + def run_cutlass_moe_fp8( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + num_repeats: int, + ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp8", color="blue"): + cutlass_moe_fp8( + a=a, + w1_q=w1, + w2_q=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + quant_config=quant_config, + activation="silu", + global_num_experts=num_experts, + ) + + # Pre-create quantization config to avoid creating it inside CUDA graph + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + ) + + # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly) + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + cutlass_moe_fp8( + a=a, + w1_q=w1_fp8q_cutlass, + w2_q=w2_fp8q_cutlass, + topk_weights=topk_weights, + topk_ids=topk_ids, + ab_strides1=ab_strides1, + ab_strides2=ab_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + quant_config=quant_config, + activation="silu", + global_num_experts=num_experts, + ) + torch.cuda.synchronize() + + # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly) + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + # Capture 10 invocations like benchmark_moe.py + for _ in range(10): + fused_experts( + a, + w1_fp8q, + w2_fp8q, + topk_weights, + topk_ids, + quant_config=quant_config, + ) + torch.cuda.synchronize() + + def bench_cuda_graph(graph, num_warmup=5, num_iters=100): + """Benchmark CUDA graph using events like benchmark_moe.py""" + # Warmup 
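+        # Each graph.replay() reruns the 10 captured MoE invocations, so the
+        # timing below divides by num_iters * 10 to report per-call latency.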
+ for _ in range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + # Timing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + + # Divide by 10 since graph contains 10 calls + return sum(latencies) / (num_iters * 10) + + # Benchmark parameters + num_warmup = 5 + num_iters = 100 + + # Benchmark only CUDA graphs (more reliable and faster) + # Benchmark Triton MoE with CUDA graphs + triton_graph_time = bench_cuda_graph( + triton_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Benchmark CUTLASS MoE with CUDA graphs + cutlass_graph_time = bench_cuda_graph( + cutlass_graph, num_warmup=num_warmup, num_iters=num_iters + ) + + # Convert ms to us and return results + triton_time_us = triton_graph_time * 1000 + cutlass_time_us = cutlass_graph_time * 1000 + + return { + "batch_size": m, + "triton_time_us": triton_time_us, + "cutlass_time_us": cutlass_time_us, + } + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + all_results = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in args.per_act_token_opts: + for per_out_ch in args.per_out_ch_opts: + print( + f"\n=== {model}, experts={num_experts}, topk={topk}," + f"per_act={per_act_token}, per_out_ch={per_out_ch} ===" + ) + + config_results = [] + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + result = bench_run( + [], # Not used anymore + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + if result: + config_results.append(result) + + # Print results table for this configuration + if config_results: + print( + f"\n{'Batch Size':<12}" + f"{'Triton (us)':<15}" + f"{'CUTLASS (us)':<15}" + ) + print("-" * 45) + for result in config_results: + print( + f"{result['batch_size']:<12}" + f"{result['triton_time_us']:<15.2f}" + f"{result['cutlass_time_us']:<15.2f}" + ) + + all_results.extend(config_results) + + print(f"\nTotal benchmarks completed: {len(all_results)}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE + across specified models/shapes/batches + + Example usage: + python benchmark_cutlass_moe_fp8.py \ + --model "Llama-4-Maverick-17B-128E-Instruct-FP8" \ + --tp-sizes 8 \ + --batch-size 2 4 8 \ + --per-act-token-opts false \ + --per-out-ch-opts false + + """ + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument( + "--per-act-token-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-activation token quantization 
options (true/false)", + ) + parser.add_argument( + "--per-out-ch-opts", + nargs="+", + type=lambda x: x.lower() == "true", + default=[False, True], + help="Per-output channel quantization options (true/false)", + ) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py new file mode 100644 index 0000000000000..4cbdde5a5b2ca --- /dev/null +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -0,0 +1,508 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark script for device communicators: +CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, +and SymmMemCommunicator (multimem, two-shot). + +for NCCL symmetric memory you need to set the environment variables +NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does +not use fast NVLS implementation for all reduce. + +Usage: + torchrun --nproc_per_node= benchmark_device_communicators.py [options] + +Example: + torchrun --nproc_per_node=2 benchmark_device_communicators.py + --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100 +""" + +import json +import os +import time +from contextlib import nullcontext +from typing import Callable, Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce +from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + register_nccl_symmetric_ops, +) +from vllm.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) +from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + +# Default sequence lengths to benchmark +DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192] + +# Fixed hidden size and dtype for all benchmarks +HIDDEN_SIZE = 8192 +BENCHMARK_DTYPE = torch.bfloat16 + +# CUDA graph settings +CUDA_GRAPH_CAPTURE_CYCLES = 10 + + +class CommunicatorBenchmark: + """Benchmark class for testing device communicators.""" + + def __init__( + self, + rank: int, + world_size: int, + device: torch.device, + cpu_group: ProcessGroup, + sequence_lengths: list[int], + ): + self.rank = rank + self.world_size = world_size + self.device = device + self.cpu_group = cpu_group + + # Calculate max_size_override based on largest sequence length + max_seq_len = max(sequence_lengths) + max_tensor_elements = max_seq_len * HIDDEN_SIZE + self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1 + + # Initialize communicators + self.custom_allreduce = None + self.pynccl_comm = None + self.symm_mem_comm = None + self.symm_mem_comm_multimem = None + self.symm_mem_comm_two_shot = None + + self._init_communicators() + + def _init_communicators(self): + """Initialize all available communicators.""" + try: + self.custom_allreduce = CustomAllreduce( + group=self.cpu_group, + device=self.device, + max_size=self.max_size_override, + ) + if not self.custom_allreduce.disabled: + logger.info("Rank %s: CustomAllreduce initialized", self.rank) + else: + logger.info("Rank %s: CustomAllreduce disabled", self.rank) + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e + ) + self.custom_allreduce = 
None + + try: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, device=self.device + ) + if not self.pynccl_comm.disabled: + logger.info("Rank %s: PyNcclCommunicator initialized", self.rank) + register_nccl_symmetric_ops(self.pynccl_comm) + else: + logger.info("Rank %s: PyNcclCommunicator disabled", self.rank) + self.pynccl_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e + ) + self.pynccl_comm = None + + # Initialize variants for SymmMemCommunicator + try: + self.symm_mem_comm_multimem = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + force_multimem=True, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_multimem.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank + ) + else: + self.symm_mem_comm_multimem = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s", + self.rank, + e, + ) + self.symm_mem_comm_multimem = None + + try: + self.symm_mem_comm_two_shot = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + force_multimem=False, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_two_shot.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank + ) + else: + self.symm_mem_comm_two_shot = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s", + self.rank, + e, + ) + self.symm_mem_comm_two_shot = None + + def benchmark_allreduce( + self, sequence_length: int, num_warmup: int, num_trials: int + ) -> dict[str, float]: + """Benchmark allreduce operations for all available communicators.""" + + results = {} + + # Define communicators with their benchmark functions + communicators = [] + + if self.custom_allreduce is not None: + comm = self.custom_allreduce + # CustomAllreduce one-shot + communicators.append( + ( + "ca_1stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "1stage", # env variable value + ) + ) + # CustomAllreduce two-shot + communicators.append( + ( + "ca_2stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "2stage", # env variable value + ) + ) + + if self.pynccl_comm is not None: + comm = self.pynccl_comm + communicators.append( + ( + "pynccl", + lambda t, c=comm: c.all_reduce(t), + lambda t: True, # Always available if initialized + nullcontext(), + None, # no env variable needed + ) + ) + communicators.append( + ( + "pynccl-symm", + lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t), + lambda t: True, # Always available if initialized + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_multimem is not None: + comm = self.symm_mem_comm_multimem + communicators.append( + ( + "symm_mem_multimem", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_two_shot is not None: + comm = self.symm_mem_comm_two_shot + communicators.append( + ( + "symm_mem_two_shot", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + # Benchmark each communicator + for name, allreduce_fn, should_use_fn, context, env_var in communicators: + # Set environment 
variable if needed + if env_var is not None: + os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var + else: + # Clear the environment variable to avoid interference + os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None) + + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + + return results + + def benchmark_allreduce_single( + self, + sequence_length: int, + allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]], + should_use_fn: Callable[[torch.Tensor], bool], + context, + num_warmup: int, + num_trials: int, + ) -> Optional[float]: + """Benchmark method with CUDA graph optimization.""" + try: + # Create test tensor (2D: sequence_length x hidden_size) + tensor = torch.randn( + sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device + ) + if not should_use_fn(tensor): + return None + + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + graph_input = tensor.clone() + + # Warmup before capture + for _ in range(3): + allreduce_fn(graph_input) + + # Capture the graph using context manager + with context: + graph = torch.cuda.CUDAGraph() + graph_pool = torch.cuda.graph_pool_handle() + set_graph_pool_id(graph_pool) + with torch.cuda.graph(graph, pool=graph_pool): + for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): + allreduce_fn(graph_input) + + torch.cuda.synchronize() + for _ in range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(num_trials): + graph.replay() + torch.cuda.synchronize() + + end_time = time.perf_counter() + + # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES + return ( + (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000 + ) + + except Exception as e: + logger.error("CUDA graph benchmark failed: %s", e) + raise RuntimeError( + f"CUDA graph benchmark failed for communicator: {e}" + ) from e + + +def _calculate_speedup_info(comm_results: dict[str, float]) -> str: + """Calculate speedup information for a single tensor size.""" + if not comm_results: + return "N/A" + + # Find the fastest communicator + fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k]) + fastest_time = comm_results[fastest_comm] + + # Calculate speedup vs PyNccl if available + if "pynccl" in comm_results: + pynccl_time = comm_results["pynccl"] + speedup = pynccl_time / fastest_time + return f"{fastest_comm} ({speedup:.2f}x)" + else: + return f"{fastest_comm} (N/A)" + + +def print_results( + results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int +): + """Print benchmark results in a formatted table.""" + + print(f"\n{'=' * 130}") + print("Device Communicator Benchmark Results") + print( + f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, " + f"Hidden Size: {HIDDEN_SIZE}" + ) + print(f"{'=' * 130}") + + # Get all communicator names + all_comms = set() + for size_results in results.values(): + all_comms.update(size_results.keys()) + + all_comms = sorted(list(all_comms)) + + # Print header + header = f"{'Tensor Shape':<20}{'Tensor Size':<15}" + for comm in all_comms: + header += f"{comm:<20}" + header += f"{'Best (Speedup vs PyNccl)':<30}" + print(header) + print("-" * len(header)) + + # Print results for each sequence length + for seq_len in sequence_lengths: + if seq_len in results: + # Calculate tensor size in elements and bytes + tensor_elements = 
seq_len * HIDDEN_SIZE + tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize + + # Format tensor size (MB) + tensor_size_mb = tensor_bytes / (1024 * 1024) + tensor_size_str = f"{tensor_size_mb:.2f} MB" + + # Format tensor shape + tensor_shape = f"({seq_len}, {HIDDEN_SIZE})" + + row = f"{tensor_shape:<20}{tensor_size_str:<15}" + for comm in all_comms: + if comm in results[seq_len]: + row += f"{results[seq_len][comm]:<20.3f}" + else: + row += f"{'N/A':<20}" + + # Calculate speedup information + speedup_info = _calculate_speedup_info(results[seq_len]) + row += f"{speedup_info:<30}" + + print(row) + + print(f"{'=' * 130}") + print("All times are in milliseconds (ms) per allreduce operation") + print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)") + + +def main(): + parser = FlexibleArgumentParser(description="Benchmark device communicators") + + parser.add_argument( + "--sequence-lengths", + type=int, + nargs="+", + default=DEFAULT_SEQUENCE_LENGTHS, + help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)", + ) + + parser.add_argument( + "--num-warmup", type=int, default=5, help="Number of warmup iterations" + ) + + parser.add_argument( + "--num-trials", type=int, default=50, help="Number of benchmark trials" + ) + + parser.add_argument("--output-json", type=str, help="Output results to JSON file") + + args = parser.parse_args() + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Set device + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + # Get CPU process group + cpu_group = dist.new_group(backend="gloo") + + # Disable USE_SYMM_MEM to avoid affecting the max_sizes + # in symm_mem and custom_all_reduce for benchmark + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + # Initialize benchmark + benchmark = CommunicatorBenchmark( + rank, world_size, device, cpu_group, args.sequence_lengths + ) + + # Run benchmarks + all_results = {} + + for seq_len in args.sequence_lengths: + if rank == 0: + logger.info( + "Benchmarking sequence length: %s (tensor shape: %s x %s)", + seq_len, + seq_len, + HIDDEN_SIZE, + ) + + results = benchmark.benchmark_allreduce( + sequence_length=seq_len, + num_warmup=args.num_warmup, + num_trials=args.num_trials, + ) + + all_results[seq_len] = results + + # Synchronize between ranks + dist.barrier() + + # Print results (only rank 0) + if rank == 0: + print_results(all_results, args.sequence_lengths, world_size) + + # Save to JSON if requested + if args.output_json: + # Add speedup information to results + enhanced_results = {} + for seq_len, comm_results in all_results.items(): + enhanced_results[seq_len] = { + "timings": comm_results, + "speedup_info": _calculate_speedup_info(comm_results), + } + + output_data = { + "world_size": world_size, + "dtype": str(BENCHMARK_DTYPE), + "hidden_size": HIDDEN_SIZE, + "sequence_lengths": args.sequence_lengths, + "num_warmup": args.num_warmup, + "num_trials": args.num_trials, + "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES, + "results": enhanced_results, + } + + with open(args.output_json, "w") as f: + json.dump(output_data, f, indent=2) + + logger.info("Results saved to %s", args.output_json) + + # Cleanup + if cpu_group != dist.group.WORLD: + dist.destroy_process_group(cpu_group) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 
a6b42406b5cb0..14330ae6f03c5 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, @@ -96,6 +97,11 @@ def bench_run( a_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) for _ in range(num_repeats): fused_experts( a, @@ -103,10 +109,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def run_cutlass_moe( @@ -125,6 +128,12 @@ def bench_run( per_act_token: bool, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + for _ in range(num_repeats): cutlass_moe_fp8( a, @@ -132,14 +141,11 @@ def bench_run( w2, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ -156,6 +162,12 @@ def bench_run( topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -165,14 +177,11 @@ def bench_run( w2_q, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_triton_from_graph( @@ -185,6 +194,11 @@ def bench_run( w2_scale: torch.Tensor, a_scale: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -194,10 +208,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3d38d4b3534e8..799b16999873f 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -79,9 +79,9 @@ def make_rand_lora_weight_tensor( def make_rand_tensors( - a_shape: tuple[int], - b_shape: tuple[int], - c_shape: tuple[int], + a_shape: tuple[int, ...], + b_shape: tuple[int, ...], + c_shape: tuple[int, ...], a_dtype: torch.dtype, b_dtype: torch.dtype, c_dtype: torch.dtype, @@ -243,7 +243,7 @@ class OpType(Enum): lora_rank: int, num_loras: int, num_slices: int, - ) -> tuple[tuple[int], tuple[int], tuple[int]]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Given num_slices, return the shapes of the A, B, and C matrices in A x B = C, for the op_type @@ -464,7 +464,11 @@ class BenchmarkTensors: for field_name in LoRAKernelMeta.__dataclass_fields__: field = getattr(self.lora_kernel_meta, field_name) assert isinstance(field, 
torch.Tensor) - setattr(self.lora_kernel_meta, field_name, to_device(field)) + setattr( + self.lora_kernel_meta, + field_name, + to_device(field) if field_name != "no_lora_flag_cpu" else field, + ) def metadata(self) -> tuple[int, int, int]: """ @@ -512,6 +516,7 @@ class BenchmarkTensors: "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, "lora_ids": self.lora_kernel_meta.active_lora_ids, "scaling": 1.0, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: @@ -552,6 +557,7 @@ class BenchmarkTensors: "lora_ids": self.lora_kernel_meta.active_lora_ids, "offset_start": 0, "add_inputs": add_inputs, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } def bench_fn_kwargs( @@ -637,7 +643,7 @@ def bench_optype( # Clear LoRA optimization hash-maps. _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() - # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup + # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) torch.cuda.synchronize() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a9c4d30d9b189..1b1c3b321cce4 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -284,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -385,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 752c2d0082167..d3040e9738f7b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -14,6 +14,10 @@ import ray import torch from ray.experimental.tqdm_ray import tqdm +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, +) from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform from vllm.transformers_utils.config import get_config @@ -134,43 +138,36 @@ def benchmark_config( def run(): from vllm.model_executor.layers.fused_moe import override_config + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = FusedMoEQuantConfig.make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) + with override_config(config): - if use_deep_gemm: - topk_weights, topk_ids, token_expert_indices = fused_topk( - x, input_gating, topk, False - ) - return fused_experts( - x, - w1, - w2, - topk_weights, - 
topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - allow_deep_gemm=True, - ) - else: - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - ) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) # JIT compilation & warmup run() @@ -414,13 +411,15 @@ class BenchmarkWorker: use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. + block_n = block_quant_shape[0] if block_quant_shape else None + block_k = block_quant_shape[1] if block_quant_shape else None op_config = get_moe_configs( - num_experts, shard_intermediate_size // 2, dtype_str + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k ) if op_config is None: config = get_default_config( @@ -430,6 +429,7 @@ class BenchmarkWorker: hidden_size, topk, dtype_str, + block_quant_shape, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] @@ -544,7 +544,7 @@ def save_configs( block_quant_shape: list[int], save_dir: str, ) -> None: - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) @@ -557,7 +557,7 @@ def save_configs( filename = os.path.join(save_dir, filename) print(f"Writing best config to {filename}...") with open(filename, "w") as f: - json.dump(configs, f, indent=4) + json.dump({"triton_version": triton.__version__, **configs}, f, indent=4) f.write("\n") @@ -579,26 +579,42 @@ def main(args: argparse.Namespace): E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size + hidden_size = config.hidden_size elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size elif config.architectures[0] in ( - "DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", "Glm4MoeForCausalLM", ): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size - elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): + hidden_size = config.hidden_size + elif config.architectures[0] in ( + "Qwen2MoeForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration": + text_config = config.get_text_config() + E = text_config.num_experts + topk = text_config.num_experts_per_tok + intermediate_size = text_config.moe_intermediate_size + hidden_size = 
text_config.hidden_size elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] + hidden_size = config.hidden_size else: # Support for llama4 config = config.get_text_config() @@ -606,6 +622,7 @@ def main(args: argparse.Namespace): E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size enable_ep = bool(args.enable_expert_parallel) if enable_ep: ensure_divisibility(E, args.tp_size, "Number of experts") @@ -614,7 +631,6 @@ def main(args: argparse.Namespace): else: ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") shard_intermediate_size = 2 * intermediate_size // args.tp_size - hidden_size = config.hidden_size dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" @@ -675,7 +691,11 @@ def main(args: argparse.Namespace): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") - + if use_deep_gemm: + raise ValueError( + "Tuning with --use-deep-gemm is not supported as it only tunes Triton " + "kernels. Please remove the flag." + ) start = time.time() configs = _distribute( "tune", diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py new file mode 100644 index 0000000000000..9ac8f5e6594e4 --- /dev/null +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton + + +def polynorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + def norm(x, eps: float): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + x = x.float() + return ( + ( + weight[0] * norm(x**3, eps) + + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + + bias + ) + .to(weight.dtype) + .view(orig_shape) + ) + + +def polynorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + out = torch.empty_like(x) + vllm_ops.poly_norm(out, x, weight, bias, eps) + output = out + + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_dim): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + output_naive = polynorm_naive(x, weight, bias) + output_vllm = polynorm_vllm(x, weight, bias) + + if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +dim_range = [2048, 4096] +configs = list(itertools.product(dim_range, batch_size_range, seq_length_range)) + + +def get_benchmark(): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["dim", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], 
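+            # "provider" selects the implementation under test: the naive
+            # PyTorch reference or the vllm_ops.poly_norm kernel (see line_vals).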
+ line_arg="provider", + line_vals=["naive", "vllm"], + line_names=["Naive", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="polynorm-perf", + args={}, + ) + ) + def benchmark(dim, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_dim = dim * 4 + + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_naive(x, weight, bias), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_vllm(x, weight, bias), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=8192, + help="Intermediate size of MLP", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/polnorm/", + help="Path to save polnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_dim=args.hidden_dim, + ) + + benchmark = get_benchmark() + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py new file mode 100644 index 0000000000000..af9841daadf24 --- /dev/null +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import random +import time + +import torch +from tabulate import tabulate + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + FlexibleArgumentParser, + create_kv_caches_with_random, +) + +logger = init_logger(__name__) + + +@torch.inference_mode() +def run_benchmark( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + kv_cache_dtype: str, + num_iters: int, + benchmark_mode: str, + device: str = "cuda", +) -> float: + """Return latency (seconds) for given num_tokens.""" + + if kv_cache_dtype == "fp8" and head_size % 16: + raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") + + current_platform.seed_everything(42) + torch.set_default_device(device) + + # create random key / value tensors [T, H, D]. + key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device) + value = torch.randn_like(key) + + # prepare the slot mapping. + # each token is assigned a unique slot in the KV-cache. 
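+    # slot indices are drawn without replacement below, so every token in the
+    # batch writes to a distinct KV-cache slot (no aliasing between tokens)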
+ num_slots = block_size * num_blocks + if num_tokens > num_slots: + raise ValueError("num_tokens cannot exceed the total number of cache slots") + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + key_caches, value_caches = create_kv_caches_with_random( + num_blocks, + block_size, + 1, # num_layers + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + # to free unused memory + del key_caches, value_caches + + # compute per-kernel scaling factors for fp8 conversion (if used). + k_scale = (key.amax() / 64.0).to(torch.float32) + v_scale = (value.amax() / 64.0).to(torch.float32) + + function_under_test = lambda: ops.reshape_and_cache( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + + if benchmark_mode == "cudagraph": + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + function_under_test() + torch.cuda.synchronize() + function_under_test = lambda: g.replay() + + def run_cuda_benchmark(n_iters: int) -> float: + nonlocal key, value, key_cache, value_cache, slot_mapping + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(n_iters): + function_under_test() + torch.cuda.synchronize() + end = time.perf_counter() + return (end - start) / n_iters + + # warm-up + run_cuda_benchmark(3) + + lat = run_cuda_benchmark(num_iters) + + # free tensors to mitigate OOM when sweeping + del key, value, key_cache, value_cache, slot_mapping + torch.cuda.empty_cache() + + return lat + + +def main(args): + rows = [] + for exp in range(1, 17): + n_tok = 2**exp + lat = run_benchmark( + num_tokens=n_tok, + num_heads=args.num_heads, + head_size=args.head_size, + block_size=args.block_size, + num_blocks=args.num_blocks, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + kv_cache_dtype=args.kv_cache_dtype, + num_iters=args.iters, + benchmark_mode=args.mode, + device="cuda", + ) + rows.append([n_tok, lat * 1e6]) # convert to microseconds + + print(f"Benchmark results for implementation cuda (measuring with {args.mode}):") + print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f")) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + + parser.add_argument("--num-heads", type=int, default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--num-blocks", type=int, default=128 * 128) + + parser.add_argument( + "--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="bfloat16", + ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8"], + default="auto", + ) + + parser.add_argument("--iters", type=int, default=200) + + parser.add_argument( + "--mode", + type=str, + choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index d4648c18f31d5..0aace571064a0 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -9,6 +9,9 @@ import torch from tabulate import tabulate from vllm import _custom_ops as ops 
+from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import ( @@ -31,6 +34,8 @@ def run_benchmark( kv_cache_dtype: str, kv_cache_layout: str, num_iters: int, + implementation: str, + benchmark_mode: str, device: str = "cuda", ) -> float: """Return latency (seconds) for given num_tokens.""" @@ -38,6 +43,14 @@ def run_benchmark( if kv_cache_dtype == "fp8" and head_size % 16: raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") + if implementation not in ("cuda", "triton"): + raise ValueError( + f"Unsupported implementation: {implementation}. " + "Only 'cuda' and 'triton' are supported." + ) + if implementation == "triton" and kv_cache_layout == "HND": + return float("nan") # Triton does not support HND layout yet. + current_platform.seed_everything(42) torch.set_default_device(device) @@ -65,27 +78,49 @@ def run_benchmark( cache_layout=kv_cache_layout, ) key_cache, value_cache = key_caches[0], value_caches[0] + # to free unused memory + del key_caches, value_caches # compute per-kernel scaling factors for fp8 conversion (if used). k_scale = (key.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32) + if implementation == "cuda": + function_under_test = lambda: ops.reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + function_under_test = lambda: triton_reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + if benchmark_mode == "cudagraph": + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + function_under_test() + torch.cuda.synchronize() + function_under_test = lambda: g.replay() + def run_cuda_benchmark(n_iters: int) -> float: nonlocal key, value, key_cache, value_cache, slot_mapping torch.cuda.synchronize() start = time.perf_counter() for _ in range(n_iters): - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) - torch.cuda.synchronize() + function_under_test() + torch.cuda.synchronize() end = time.perf_counter() return (end - start) / n_iters @@ -116,10 +151,16 @@ def main(args): kv_cache_dtype=args.kv_cache_dtype, kv_cache_layout=layout, num_iters=args.iters, + implementation=args.implementation, + benchmark_mode=args.mode, device="cuda", ) rows.append([n_tok, layout, f"{lat * 1e6:.3f}"]) + print( + f"Benchmark results for implementation {args.implementation}" + f" (measuring with {args.mode}):" + ) print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"])) @@ -151,6 +192,21 @@ if __name__ == "__main__": ) parser.add_argument("--iters", type=int, default=100) + + parser.add_argument( + "--implementation", + type=str, + choices=["cuda", "triton"], + default="cuda", + ) + + parser.add_argument( + "--mode", + type=str, + choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index 0650cbf3cc18e..a5887aafd30d6 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,77 +1,720 @@ 
-#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time +""" +Comprehensive 3-way SiLU Benchmark Suite + +This benchmark compares three SiLU implementations: +1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation +2. Triton Kernel - Triton-based implementation + +The suite generates detailed performance comparisons including: +- Memory bandwidth utilization +- Speedup ratios (baseline vs optimized implementations) +- Performance across different expert configurations and token distributions +""" + +from collections.abc import Callable + +import matplotlib.pyplot as plt +import numpy as np import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm, + persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -def benchmark(E, T, H, G=128, runs=50): - current_platform.seed_everything(42) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") - tokens_per_expert = torch.randint( - T // 2, T, size=(E,), dtype=torch.int32, device="cuda" +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + # Stride for counts (elements) + stride_counts_e, + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, 
fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm_triton( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens, + group_size: int = 128, + eps: float = 1e-10, + expert_offsets: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = (H + group_size - 1) // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, ( + "tokens_per_expert must be shape (E,)" + ) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, ) + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. 
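+    # (one Triton program instance per (expert, H-group) pair, i.e. E * G programs)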
+ # A loop inside the kernel handles the token dim + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +# Parse generation strategies +strategies = ["random_imbalanced", "uniform", "max_t"] + + +def benchmark( + kernel: Callable, + E: int, + T: int, + H: int, + total_tokens: int, + num_parallel_tokens: int = 64, + G: int = 128, + runs: int = 200, + num_warmups: int = 20, + gen_strategy: str = "default", + iterations_per_run: int = 20, +): + def generate_data(seed_offset=0): + """Generate input data with given seed offset""" + current_platform.seed_everything(42 + seed_offset) + y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() + + if gen_strategy == "random_imbalanced": + + def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"): + mean = total_tokens // n_e + min_max = mean // ratio + e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean + e[0] = min_max + r = torch.rand(size=(E - 1,)) + r /= r.sum() + r *= total_tokens - min_max + r = r.round().long() + e[1:] = r.to(device=device) + return e + + tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda") + elif gen_strategy == "uniform": + r = torch.rand(size=(E,)) + r /= r.sum() + r *= total_tokens + r = r.round().long() + tokens_per_expert = r + elif gen_strategy == "max_t": + tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert.fill_(total_tokens / E) + elif gen_strategy == "first_t": + tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert[0] = min(T, total_tokens) + else: + raise ValueError(f"Unknown generation strategy: {gen_strategy}") + return y, tokens_per_expert + + dataset_count = 4 + # Pre-generate different input matrices for each iteration to avoid cache effects + data_sets = [generate_data(i) for i in range(dataset_count)] + # Warmup - for _ in range(10): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + y, tokens_per_expert = data_sets[0] + for _ in range(num_warmups): + kernel( + y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G + ) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) # Benchmark - torch.cuda.synchronize() - start = time.perf_counter() + latencies: list[float] = [] for _ in range(runs): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + torch.cuda.synchronize() - avg_time = (time.perf_counter() - start) / runs * 1000 + start_event.record() + for i in range(iterations_per_run): + y, tokens_per_expert = data_sets[i % dataset_count] + kernel( + y, + tokens_per_expert, + num_parallel_tokens=num_parallel_tokens, + group_size=G, + ) + end_event.record() + end_event.synchronize() - # Calculate actual work done (only count valid tokens) + total_time_ms = start_event.elapsed_time(end_event) + per_iter_time_ms = total_time_ms / iterations_per_run + latencies.append(per_iter_time_ms) + + # Use median instead of average for better outlier handling + 
median_time_ms = np.median(latencies) + median_time_s = median_time_ms / 1000 + + # Calculate actual work done (using first dataset for consistency) + _, tokens_per_expert = data_sets[0] actual_tokens = tokens_per_expert.sum().item() actual_elements = actual_tokens * H # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops ops_per_element = 8 total_ops = actual_elements * ops_per_element - gflops = total_ops / (avg_time / 1000) / 1e9 + gflops = total_ops / median_time_s / 1e9 # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs output_bytes = actual_tokens * H * 1 # H fp8 outputs scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 total_bytes = input_bytes + output_bytes + scale_bytes - memory_bw = total_bytes / (avg_time / 1000) / 1e9 + memory_bw = total_bytes / median_time_s / 1e9 - return avg_time, gflops, memory_bw + HOPPER_BANDWIDTH_TBPS = 3.35 + return ( + median_time_ms, + gflops, + memory_bw, + (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100, + ) +def create_comparison_plot( + ratios, silu_v2_times, triton_times, config_labels, strategy_name, id +): + fig, ax = plt.subplots(1, 1, figsize=(18, 6)) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.25 + + # Execution Time plot (lower is better) + ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue") + ax.bar( + x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green" + ) + + # Add speedup labels over each bar trio + for i in range(len(x)): + triton_v2_speedup = ratios[i][1] # triton/v2 + max_height = max(silu_v2_times[i], triton_times[i]) + + # Triton/V2 speedup + ax.text( + x[i] + width / 2, + max_height + max_height * 0.02, + f"{triton_v2_speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=8, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig, ax + + +def create_combined_plot(all_results): + num_strategies = len(all_results) + fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies)) + + if num_strategies == 1: + axes = [axes] + + for idx, ( + strategy_name, + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) in enumerate(all_results): + ax = axes[idx] + + # Flatten the nested results to get bandwidth percentages for plotting + silu_v2_bandwidths = [] + triton_bandwidths = [] + flat_ratios = [] + + for config_results in all_silu_v2_results: + for result in config_results: + silu_v2_bandwidths.append(result[3]) # bandwidth percentage + + for config_results in all_triton_results: + for result in config_results: + triton_bandwidths.append(result[3]) # bandwidth percentage + + for config_ratios in all_ratios: + for ratio in config_ratios: + flat_ratios.append(ratio) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.25 + + # Bandwidth utilization plot (higher is better) + ax.bar( + x, + silu_v2_bandwidths, + width, + label="SiLU V2 (CUDA)", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width, + triton_bandwidths, + width, + label="Triton Kernel", + alpha=0.8, + color="green", + ) + + # Add speedup labels over each bar trio + for i in 
range(len(x)): + triton_v2_speedup = flat_ratios[i] # triton/v2 + max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i]) + + # Triton/V2 speedup + ax.text( + x[i] + width / 2, + max_height + max_height * 0.02, + f"{triton_v2_speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=8, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + filename = "silu_benchmark_combined_3way.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +outer_dim = 7168 configs = [ - (8, 32, 1024), - (16, 64, 2048), - (32, 128, 4096), # DeepSeekV3 Configs - (256, 16, 7168), - (256, 32, 7168), - (256, 64, 7168), - (256, 128, 7168), - (256, 256, 7168), - (256, 512, 7168), + # (1, 56, 7168), + (8, 1024, 7168), + # (32, 56, 7168), + # DeepSeekV3 Configs + (32, 1024, 7168), + # DeepSeekV3 Configs (256, 1024, 7168), ] -print(f"GPU: {torch.cuda.get_device_name()}") -print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") -print("-" * 50) +runs = 100 +num_warmups = 20 -for E, T, H in configs: - try: - time_ms, gflops, gbps = benchmark(E, T, H) - print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") - except Exception: - print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") +strategy_descriptions = { + "uniform": "Uniform Random", + "random_imbalanced": "Imbalanced Random", + "max_t": "Even Assignment", + "first_t": "experts[0] = T, experts[1:] = 0", +} + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"Testing strategies: {', '.join(strategies)}") +print(f"Configurations: {len(configs)} configs") + +all_results = [] + +# Run benchmarks for each strategy +for id, strategy in enumerate(strategies): + print(f"\n{'=' * 60}") + print(f"Testing strategy: {strategy_descriptions[strategy]}") + print(f"{'=' * 60}") + + # Collect benchmark data for all three algorithms + config_labels = [] + config_x_axis = [] + all_silu_v2_results = [] + all_triton_results = [] + all_ratios = [] + + for E, T, H in configs: + total_tokens_config = [] + for i in [8, 16, 32, 64, 128, 256, 512]: + if i <= T: + total_tokens_config.append(i * E) + config_x_axis.append(total_tokens_config) + + silu_v2_results = [] + triton_results = [] + ratios = [] + + for total_tokens in total_tokens_config: + config_label = f"E={E},T={T},H={H},TT={total_tokens}" + config_labels.append(config_label) + + # SiLU V2 (CUDA kernel) results + time_ms_silu_v2, gflops, gbps, perc = benchmark( + persistent_masked_m_silu_mul_quant, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc)) + + # Triton kernel results + time_ms_triton, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_triton, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + triton_results.append((time_ms_triton, gflops, gbps, perc)) + + # Calculate speedup ratios (triton baseline / implementation) + triton_v2_ratio = time_ms_triton / time_ms_silu_v2 + ratios.append(triton_v2_ratio) + + print( + f"Completed: {config_label}:" + f" V2: {time_ms_silu_v2:.3f}ms," + f" Triton: {time_ms_triton:.3f}ms" + ) + + all_silu_v2_results.append(silu_v2_results) + 
all_triton_results.append(triton_results) + all_ratios.append(ratios) + + # Store results for combined plotting + all_results.append( + ( + strategy_descriptions[strategy], + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) + ) + + # Print summary table for this strategy + print(f"\nSummary Table - {strategy_descriptions[strategy]}:") + print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}") + print("-" * 90) + + for i, (E, T, H) in enumerate(configs): + # Get the first result for each config (simplifying for summary) + v2_time = silu_v2_results[i][0] + triton_time = triton_results[i][0] + triton_v2_speedup = triton_time / v2_time + config_label = f"E={E:3d},T={T:4d},H={H:4d}" + print( + f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} " + f"{triton_v2_speedup:8.2f}x" + ) + + +def create_total_tokens_plot(all_results): + num_strategies = len(all_results) + num_configs = len(configs) + + fig, axs = plt.subplots( + num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies) + ) + + # Add main title to the entire figure + fig.suptitle( + "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)", + fontsize=18, + fontweight="bold", + y=0.98, + ) + + # Handle single strategy case + if num_strategies == 1: + axs = axs.reshape(1, -1) + + # Handle single config case + if num_configs == 1: + axs = axs.reshape(-1, 2) + + for strategy_idx, result in enumerate(all_results): + ( + strategy_name, + all_ratios, + all_silu_v2_results, + all_triton_results, + config_labels, + config_x_axis, + ) = result + + for config_idx in range(num_configs): + # Speedup plot (left column) + ax_speedup = axs[strategy_idx, config_idx * 2] + # Bandwidth plot (right column) + ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1] + + E, T, H = configs[config_idx] + ratios = all_ratios[config_idx] + total_tokens_values = config_x_axis[config_idx] + + # Extract speedup ratios + triton_v2_ratios = [ratio for ratio in ratios] + + # Extract bandwidth percentages for all implementations + v2_bandwidth_percentages = [ + result[3] for result in all_silu_v2_results[config_idx] + ] + triton_bandwidth_percentages = [ + result[3] for result in all_triton_results[config_idx] + ] + + # Plot speedup ratios vs total tokens (left plot) + ax_speedup.plot( + total_tokens_values, + triton_v2_ratios, + "go-", + linewidth=3, + markersize=8, + label="Triton/V2 Speedup", + ) + ax_speedup.set_title( + f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.legend(prop={"weight": "bold"}) + ax_speedup.grid(True, alpha=0.3) + + # Plot bandwidth utilization (right plot) + ax_bandwidth.plot( + total_tokens_values, + v2_bandwidth_percentages, + "o-", + linewidth=3, + markersize=8, + label="SiLU V2", + color="blue", + ) + ax_bandwidth.plot( + total_tokens_values, + triton_bandwidth_percentages, + "o-", + linewidth=3, + markersize=8, + label="Triton", + color="green", + ) + ax_bandwidth.set_title( + f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_bandwidth.set_ylabel( + "% of Peak Bandwidth", fontweight="bold", fontsize=11 + ) + ax_bandwidth.legend(prop={"weight": "bold"}) + ax_bandwidth.grid(True, alpha=0.3) + 
+ # Format x-axis labels for both plots + for ax in [ax_speedup, ax_bandwidth]: + ax.set_xticks(total_tokens_values) + ax.set_xticklabels( + [ + f"{tt // 1000}K" if tt >= 1000 else str(tt) + for tt in total_tokens_values + ], + fontweight="bold", + ) + # Make tick labels bold + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight("bold") + + # Add value labels on Triton/V2 speedup points + for x, y in zip(total_tokens_values, triton_v2_ratios): + ax_speedup.annotate( + f"{y:.2f}x", + (x, y), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3), + ) + + plt.tight_layout() + plt.subplots_adjust(top=0.93) # Make room for main title + filename = "silu_benchmark_total_tokens_3way.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +# Create comprehensive 3-way comparison plots +combined_plot_filename = create_combined_plot(all_results) +total_tokens_plot_filename = create_total_tokens_plot(all_results) + +print(f"\n{'=' * 80}") +print("3-Way Benchmark Suite Complete!") +print(f"Generated combined comparison plot: {combined_plot_filename}") +print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}") +print("Compared: SiLU V2 (CUDA), and Triton implementations") +print(f"{'=' * 80}") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 603ce5ecf0d2c..6ddab46214577 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -259,6 +259,7 @@ if __name__ == "__main__": # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 40903c6c3444f..131df74c7de1b 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -274,6 +274,7 @@ if __name__ == "__main__": quant_dtypes = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65ecd..602fad1810748 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,13 +11,13 @@ from datetime import datetime from typing import Any import torch -import tqdm -import triton +from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul, + _w8a8_triton_block_scaled_mm, ) from vllm.platforms import current_platform +from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser mp.set_start_method("spawn", force=True) @@ -56,7 +56,7 @@ def w8a8_block_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. 
@@ -83,7 +83,7 @@ def w8a8_block_matmul( ) if A.dtype == torch.float8_e4m3fn: - kernel = _w8a8_block_fp8_matmul + kernel = _w8a8_triton_block_scaled_mm else: raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") @@ -141,6 +141,7 @@ def get_weight_shapes(tp_size): # cannot TP total = [ (512 + 64, 7168), + (2112, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (7168, 16384), diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index b99c2099f2c38..ba31bc5638298 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# fmt: off # ruff: noqa: E501 import time @@ -8,27 +7,33 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.triton_utils import triton -from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8 +from vllm.utils.deep_gemm import ( + calc_diff, + fp8_gemm_nt, + get_col_major_tma_aligned_tensor, + per_block_cast_to_fp8, +) -def benchmark_shape(m: int, - n: int, - k: int, - warmup: int = 100, - repeat: int = 10000, - verbose: bool = False) -> dict: +def benchmark_shape( + m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False, +) -> dict: """Benchmark all implementations for a specific (m, n, k) shape.""" if verbose: print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") # Create test tensors - A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) # Reference result in BF16 torch.cuda.synchronize() @@ -45,34 +50,39 @@ def benchmark_shape(m: int, # Pre-quantize A for all implementations A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + A, block_size[1], column_major_scales=True + ) # === DeepGEMM Implementation === def deepgemm_gemm(): - fp8_gemm_nt((A_deepgemm, A_scale_deepgemm), - (B_deepgemm, B_scale_deepgemm), - C_deepgemm) + fp8_gemm_nt( + (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm + ) return C_deepgemm # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_block_fp8_matmul(A_vllm, - B_vllm, - A_scale_vllm, - B_scale_vllm, - block_size, - output_dtype=torch.bfloat16) + return w8a8_triton_block_scaled_mm( + A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16, + ) # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): - return ops.cutlass_scaled_mm(A_vllm_cutlass, - B_vllm.T, - scale_a=A_scale_vllm_cutlass, - scale_b=B_scale_vllm.T, - out_dtype=torch.bfloat16) + return ops.cutlass_scaled_mm( 
+ A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16, + ) # Run correctness check first if verbose: @@ -89,26 +99,23 @@ def benchmark_shape(m: int, print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + print( + "vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}" + ) + print( + "vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}" + ) # Benchmark implementations implementations = { "DeepGEMM": deepgemm_gemm, "vLLM Triton": vllm_triton_gemm, - "vLLM CUTLASS": vllm_cutlass_gemm + "vLLM CUTLASS": vllm_cutlass_gemm, } - benchmark_results = { - "shape": { - "m": m, - "n": n, - "k": k - }, - "implementations": {} - } + benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}} for name, func in implementations.items(): # Warmup @@ -136,38 +143,36 @@ def benchmark_shape(m: int, "tflops": tflops, "gb_s": gb_s, "diff": { - "DeepGEMM": - 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), - "Reference": - deepgemm_diff if name == "DeepGEMM" else - (vllm_triton_diff - if name == "vLLM Triton" else vllm_cutlass_diff) - } + "DeepGEMM": 0.0 + if name == "DeepGEMM" + else calc_diff(func(), C_deepgemm), + "Reference": deepgemm_diff + if name == "DeepGEMM" + else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff), + }, } if verbose: - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" - ) + print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s") # Calculate speedups baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] for name, data in benchmark_results["implementations"].items(): if name != "DeepGEMM": speedup = baseline / data["time_ms"] - benchmark_results["implementations"][name][ - "speedup_vs_deepgemm"] = speedup + benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup if verbose: - print(f"DeepGEMM is {1/speedup:.2f}x " - f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + print( + f"DeepGEMM is {1 / speedup:.2f}x " + f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}" + ) - vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ - "time_ms"] - vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"] cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - benchmark_results["implementations"]["vLLM CUTLASS"][ - "speedup_vs_triton"] = cutlass_vs_triton + benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = ( + cutlass_vs_triton + ) if verbose: print( f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " @@ -179,8 +184,7 @@ def benchmark_shape(m: int, def format_table_row(values, widths): """Format a row with specified column widths.""" - return "| " + " | ".join(f"{val:{w}}" - for val, w in zip(values, widths)) + " |" + return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |" def print_table(headers, rows, title=None): @@ -288,38 
+292,50 @@ def run_benchmarks(verbose: bool = False): for result in all_results: shape = result["shape"] impl_data = result["implementations"]["DeepGEMM"] - deepgemm_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" - ]) + deepgemm_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + ] + ) - print_table(deepgemm_headers, - deepgemm_rows, - title="DeepGEMM Implementation:") + print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:") # Print vLLM Triton table - triton_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" - ] + triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"] triton_rows = [] for result in all_results: shape = result["shape"] impl_data = result["implementations"]["vLLM Triton"] speedup = impl_data.get("speedup_vs_deepgemm", 1.0) - triton_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(speedup) - ]) + triton_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(speedup), + ] + ) - print_table(triton_headers, - triton_rows, - title="vLLM Triton Implementation:") + print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:") # Print vLLM CUTLASS table cutlass_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", - "vs Triton" + "m", + "n", + "k", + "Time (μs)", + "TFLOPS", + "GB/s", + "vs DeepGEMM", + "vs Triton", ] cutlass_rows = [] for result in all_results: @@ -327,28 +343,27 @@ def run_benchmarks(verbose: bool = False): impl_data = result["implementations"]["vLLM CUTLASS"] vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0) - cutlass_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(vs_deepgemm), - format_speedup(vs_triton) - ]) + cutlass_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton), + ] + ) - print_table(cutlass_headers, - cutlass_rows, - title="vLLM CUTLASS Implementation:") + print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:") # Calculate and print averages print("\n===== AVERAGE PERFORMANCE =====") implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] avg_metrics = { - impl: { - "tflops": 0, - "gb_s": 0, - "time_ms": 0 - } - for impl in implementations + impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations } for result in all_results: @@ -366,9 +381,9 @@ def run_benchmarks(verbose: bool = False): avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes - avg_rows.append([ - impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" - ]) + avg_rows.append( + [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"] + ) print_table(avg_headers, avg_rows) @@ -376,21 +391,19 @@ def run_benchmarks(verbose: bool = False): avg_speedups = { "DeepGEMM vs 
vLLM Triton": 0, "DeepGEMM vs vLLM CUTLASS": 0, - "vLLM CUTLASS vs vLLM Triton": 0 + "vLLM CUTLASS vs vLLM Triton": 0, } for result in all_results: deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] - vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"] - avg_speedups[ - "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time - avg_speedups[ - "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time - avg_speedups[ - "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups["vLLM CUTLASS vs vLLM Triton"] += ( + vllm_triton_time / vllm_cutlass_time + ) print("\n===== AVERAGE SPEEDUPS =====") speedup_headers = ["Comparison", "Speedup"] @@ -408,8 +421,7 @@ def run_benchmarks(verbose: bool = False): for result in all_results: for impl in implementations: - avg_diff[impl] += result["implementations"][impl]["diff"][ - "Reference"] + avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_rows = [] diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394afbd..9a057990bda5f 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ WEIGHT_SHAPES = { ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md index 7adf97bcf5622..f5b5c6c97d484 100644 --- a/benchmarks/multi_turn/README.md +++ b/benchmarks/multi_turn/README.md @@ -55,6 +55,107 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 ---------------------------------------------------------------------------------------------------- ``` +### JSON configuration file for synthetic conversations generation + +The input flag `--input-file` is used to determine the input conversations for the benchmark.
+When the input is a JSON file with the field `"filetype": "generate_conversations"`, the tool will generate synthetic multi-turn (question and answer) conversations.
+
+The file `generate_multi_turn.json` is an example file.
+
+The file must contain the sections `prompt_input` and `prompt_output`.
+
+The `prompt_input` section must contain `num_turns`, `prefix_num_tokens` and `num_tokens`:
+
+* `num_turns` - Number of total turns in the conversation (both user & assistant).
+The final value will always be rounded to an even number so that each user turn has a reply.
+* `prefix_num_tokens` - Number of tokens added at the start of the **first user turn** only (the prefix is unique per conversation).
+* `num_tokens` - Total token length of each **user** message (one turn).
+
+The `prompt_output` section must contain `num_tokens`:
+
+* `num_tokens` - Total token length of each **assistant** message (one turn).
+
+### Random distributions for synthetic conversations generation
+
+When creating an input JSON file (such as `generate_multi_turn.json`),
+every numeric field (such as `num_turns` or `num_tokens`) requires a distribution.
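+For example, a minimal input file could attach a distribution to each numeric field along these lines (the numeric values and chosen distributions are purely illustrative; the bundled `generate_multi_turn.json` remains the reference and may define additional fields):
+
+```json
+{
+    "filetype": "generate_conversations",
+    "prompt_input": {
+        "num_turns": {
+            "distribution": "uniform",
+            "min": 12,
+            "max": 18
+        },
+        "prefix_num_tokens": {
+            "distribution": "lognormal",
+            "average": 1000,
+            "max": 5000
+        },
+        "num_tokens": {
+            "distribution": "uniform",
+            "min": 120,
+            "max": 160
+        }
+    },
+    "prompt_output": {
+        "num_tokens": {
+            "distribution": "constant",
+            "value": 128
+        }
+    }
+}
+```
+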
+The distribution determines how to randomly sample values for the field. + +The available distributions are listed below. + +**Note:** The optional `max` field (for lognormal, zipf, and poisson) can be used to cap sampled values at an upper bound.
+Can be used to make sure that the total number of tokens in every request does not exceed `--max-model-len`. + +#### constant + +```json +{ + "distribution": "constant", + "value": 500 +} +``` + +* `value` - the fixed integer value (always returns the same number). + +#### uniform + +```json +{ + "distribution": "uniform", + "min": 12, + "max": 18 +} +``` + +* `min` - minimum value (inclusive). +* `max` - maximum value (inclusive), should be equal or larger than min. + +#### lognormal + +```json +{ + "distribution": "lognormal", + "average": 1000, + "max": 5000 +} +``` + +You can parameterize the lognormal distribution in one of two ways: + +Using the average and optional median ratio: + +* `average` - target average value of the distribution. +* `median_ratio` - the ratio of the median to the average; controls the skewness. Must be in the range (0, 1). + +Using the parameters of the underlying normal distribution: + +* `mean` - mean of the underlying normal distribution. +* `sigma` - standard deviation of the underlying normal distribution. + +#### zipf + +```json +{ + "distribution": "zipf", + "alpha": 1.2, + "max": 100 +} +``` + +* `alpha` - skew parameter (> 1). Larger values produce stronger skew toward smaller integers. + +#### poisson + +```json +{ + "distribution": "poisson", + "alpha": 10, + "max": 50 +} +``` + +* `alpha` - expected value (λ). Also the variance of the distribution. + ## ShareGPT Conversations To run with the ShareGPT data, download the following ShareGPT dataset: diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 411b89dd23dc6..67b937930d58c 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -99,21 +99,105 @@ class PoissonDistribution(Distribution): class LognormalDistribution(Distribution): def __init__( - self, mean: float, sigma: float, max_val: Optional[int] = None + self, + mean: Optional[float] = None, + sigma: Optional[float] = None, + average: Optional[int] = None, + median_ratio: Optional[float] = None, + max_val: Optional[int] = None, ) -> None: + self.average = average + self.median_ratio = median_ratio + self.max_val = max_val + + if average is not None: + if average < 1: + raise ValueError("Lognormal average must be positive") + + if mean or sigma: + raise ValueError( + "When using lognormal average, you can't provide mean/sigma" + ) + + if self.median_ratio is None: + # Default value that provides relatively wide range of values + self.median_ratio = 0.85 + + # Calculate mean/sigma of np.random.lognormal based on the average + mean, sigma = self._generate_lognormal_by_median( + target_average=self.average, median_ratio=self.median_ratio + ) + else: + if mean is None or sigma is None: + raise ValueError( + "Must provide both mean and sigma if average is not used" + ) + + if mean <= 0 or sigma < 0: + raise ValueError( + "Lognormal mean must be positive and sigma must be non-negative" + ) + + # Mean and standard deviation of the underlying normal distribution + # Based on numpy.random.lognormal self.mean = mean self.sigma = sigma - self.max_val = max_val + + @staticmethod + def _generate_lognormal_by_median( + target_average: int, median_ratio: float + ) -> tuple[float, float]: + """ + Compute (mu, sigma) for a lognormal distribution given: + - a target average (mean of the distribution) + - a ratio of median / mean (controls skewness), assume mean > median + + Background: + If Z ~ Normal(mu, sigma^2), then X = exp(Z) ~ LogNormal(mu, sigma). 
+ * mean(X) = exp(mu + sigma^2 / 2) + * median(X) = exp(mu) + + So: + median / mean = exp(mu) / exp(mu + sigma^2 / 2) + = exp(-sigma^2 / 2) + + Rearranging: + sigma^2 = 2 * ln(mean / median) + mu = ln(median) + + This gives a unique (mu, sigma) for any valid mean and median. + """ + # Check input validity: median must be smaller than mean + if median_ratio <= 0 or median_ratio >= 1: + raise ValueError("median_ratio must be in range (0, 1)") + + target_median = target_average * median_ratio + + # Solve sigma^2 = 2 * ln(mean / median) + sigma = np.sqrt(2 * np.log(target_average / target_median)) + mu = np.log(target_median) + + return mu, sigma def sample(self, size: int = 1) -> np.ndarray: samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size) + + if self.average is not None: + # Scale to average + samples *= self.average / samples.mean() + if self.max_val: samples = np.minimum(samples, self.max_val) return np.round(samples).astype(int) def __repr__(self) -> str: - return f"LognormalDistribution[{self.mean}, {self.sigma}]" + if self.average: + return ( + f"LognormalDistribution[{self.average}, " + f"{self.median_ratio}, {self.max_val}]" + ) + return f"LognormalDistribution[{self.mean}, {self.sigma}, {self.max_val}]" class GenConvArgs(NamedTuple): @@ -173,10 +257,21 @@ def get_random_distribution( return PoissonDistribution(conf["alpha"], max_val=max_val) elif distribution == "lognormal": + max_val = conf.get("max", None) + + if "average" in conf: + # Infer lognormal mean/sigma (numpy) from input average + median_ratio = conf.get("median_ratio", None) + return LognormalDistribution( + average=conf["average"], median_ratio=median_ratio, max_val=max_val + ) + + # Use mean/sigma directly (for full control over the distribution) verify_field_exists(conf, "mean", section, subsection) verify_field_exists(conf, "sigma", section, subsection) - max_val = conf.get("max", None) - return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val) + return LognormalDistribution( + mean=conf["mean"], sigma=conf["sigma"], max_val=max_val + ) elif distribution == "uniform": verify_field_exists(conf, "min", section, subsection) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index d23b7b6e4571d..233ed460fc8d5 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -13,7 +13,7 @@ from datetime import datetime from enum import Enum from http import HTTPStatus from statistics import mean -from typing import NamedTuple, Optional, Union +from typing import NamedTuple, Union import aiohttp # type: ignore import numpy as np # type: ignore @@ -46,9 +46,9 @@ class ConversationSampling(str, Enum): class ClientArgs(NamedTuple): seed: int - max_num_requests: Optional[int] + max_num_requests: int | None skip_first_turn: bool - max_turns: Optional[int] + max_turns: int | None max_active_conversations: int verbose: bool print_content: bool @@ -109,9 +109,9 @@ class RequestStats(NamedTuple): class MetricStats: def __init__(self) -> None: - self.min: Optional[float] = None - self.max: Optional[float] = None - self.avg: Optional[float] = None + self.min: float | None = None + self.max: float | None = None + self.avg: float | None = None self.sum = 0.0 self.count = 0 @@ -143,7 +143,7 @@ class MovingAverage: self.index = 0 self.sum = 0.0 self.count = 0 - self.avg: Optional[float] = None + self.avg: float | None = None def update(self, new_value: float) 
-> None: if self.count < self.window_size: @@ -198,14 +198,6 @@ class DebugStats: self.logger.info("-" * 50) -# Must support Python 3.8, we can't use str.removeprefix(prefix) -# introduced in Python 3.9 -def remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix) :] - return text - - def nanosec_to_millisec(value: float) -> float: return value / 1000000.0 @@ -220,8 +212,8 @@ async def send_request( chat_url: str, model: str, stream: bool = True, - min_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, + min_tokens: int | None = None, + max_tokens: int | None = None, ) -> ServerResponse: payload = { "model": model, @@ -250,9 +242,9 @@ async def send_request( timeout = aiohttp.ClientTimeout(total=timeout_sec) valid_response = True - ttft: Optional[float] = None + ttft: float | None = None chunk_delay: list[int] = [] - latency: Optional[float] = None + latency: float | None = None first_chunk = "" generated_text = "" @@ -269,7 +261,7 @@ async def send_request( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk == "[DONE]": # End of stream latency = time.perf_counter_ns() - start_time @@ -364,7 +356,7 @@ async def send_turn( req_args: RequestArgs, verbose: bool, verify_output: bool, -) -> Optional[RequestStats]: +) -> RequestStats | None: assert messages_to_use > 0 assert messages_to_use <= len(conversation_messages) @@ -769,7 +761,7 @@ def get_client_config( "Number of conversations must be equal or larger than the number of clients" ) - max_req_per_client: Optional[int] = None + max_req_per_client: int | None = None if args.max_num_requests is not None: # Max number of requests per client req_per_client = args.max_num_requests // args.num_clients @@ -962,7 +954,7 @@ async def main_mp( # At this point all the clients finished, # collect results (TTFT, TPOT, etc.) from all the clients. - # This needs to happens before calling join on the clients + # This needs to happen before calling join on the clients # (result_queue should be emptied). while not result_queue.empty(): client_metrics.append(result_queue.get()) @@ -1032,7 +1024,7 @@ def process_statistics( warmup_percentages: list[float], test_params: dict, verbose: bool, - gen_conv_args: Optional[GenConvArgs] = None, + gen_conv_args: GenConvArgs | None = None, excel_output: bool = False, ) -> None: if len(client_metrics) == 0: diff --git a/benchmarks/multi_turn/generate_multi_turn.json b/benchmarks/multi_turn/generate_multi_turn.json index 274d03c2bdb2b..03cfc7d63e8aa 100644 --- a/benchmarks/multi_turn/generate_multi_turn.json +++ b/benchmarks/multi_turn/generate_multi_turn.json @@ -15,9 +15,8 @@ }, "prefix_num_tokens": { "distribution": "lognormal", - "mean": 6, - "sigma": 4, - "max": 1500 + "average": 1000, + "max": 5000 }, "num_tokens": { "distribution": "uniform", diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml deleted file mode 100644 index 65b1e09a247e2..0000000000000 --- a/benchmarks/pyproject.toml +++ /dev/null @@ -1,49 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. 
-# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.lint.isort] -known-first-party = ["vllm"] - -[tool.ruff.format] -docstring-code-format = true \ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index cc38cd41a5b24..9bac5ea41c8d4 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,6 +1,7 @@ include(FetchContent) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -87,6 +88,7 @@ is_avx512_disabled(AVX512_DISABLED) if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") message(STATUS "Apple Silicon Detected") + set(APPLE_SILICON_FOUND TRUE) set(ENABLE_NUMA OFF) check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) @@ -99,6 +101,7 @@ else() find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support find_isa(${CPUINFO} "S390" S390_FOUND) + find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support endif() if (AVX512_FOUND AND NOT AVX512_DISABLED) @@ -175,8 +178,14 @@ elseif (S390_FOUND) "-mzvector" "-march=native" "-mtune=native") +elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64") + if(RVV_FOUND) + message(FAIL_ERROR "Can't support rvv now.") + else() + list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc") + endif() else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") endif() # @@ -188,14 +197,25 @@ else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.9 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE - ) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) + set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.") + + if(FETCHCONTENT_SOURCE_DIR_ONEDNN) + message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}") + FetchContent_Declare( + oneDNN + SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN} + ) + else() + message(STATUS "Downloading oneDNN from GitHub") + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.9 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + endif() if(USE_ACL) find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS 
$ENV{ACL_ROOT_DIR}/build/) @@ -204,6 +224,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POW endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") + add_compile_definitions(VLLM_USE_ACL) endif() set(ONEDNN_LIBRARY_TYPE "STATIC") @@ -256,7 +277,8 @@ set(VLLM_EXT_SRC "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp") + "csrc/cpu/torch_bindings.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC @@ -298,4 +320,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") +message(STATUS "Enabling C extension.") \ No newline at end of file diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 02224cfe3ee81..c9e7aec880b99 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR) else() FetchContent_Declare( flashmla - GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA + GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. # Only build FlashMLA kernels if we are building for something compatible with # sm90a -cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + +set(SUPPORT_ARCHS) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) + list(APPEND SUPPORT_ARCHS 9.0a) +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8) + list(APPEND SUPPORT_ARCHS 10.0a) +endif() + + +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") +if(FLASH_MLA_ARCHS) + set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) + list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") + set(FlashMLA_SOURCES - ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu - ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) + ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu + ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu + ) + + set(FlashMLA_Extension_SOURCES + ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ) set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 
${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc) + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) + + set(FlashMLA_Extension_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/ + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" CUDA_ARCHS "${FLASH_MLA_ARCHS}") + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_Extension_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + define_gpu_extension_target( _flashmla_C DESTINATION vllm @@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} USE_SABI 3 WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. + target_compile_options(_flashmla_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) + + define_gpu_extension_target( + _flashmla_extension_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_Extension_SOURCES} + COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES} + USE_SABI 3 + WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. + target_compile_options(_flashmla_extension_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) else() - # Create an empty target for setup.py when not targeting sm90a systems + # Create empty targets for setup.py when not targeting sm90a systems add_custom_target(_flashmla_C) + add_custom_target(_flashmla_extension_C) endif() diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake new file mode 100644 index 0000000000000..9aace7693077a --- /dev/null +++ b/cmake/external_projects/qutlass.cmake @@ -0,0 +1,97 @@ +include(FetchContent) + +set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory") + +if(DEFINED ENV{QUTLASS_SRC_DIR}) + set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR}) +endif() + +if(QUTLASS_SRC_DIR) + FetchContent_Declare( + qutlass + SOURCE_DIR ${QUTLASS_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + qutlass + GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git + GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) + FetchContent_Populate(qutlass) + set(qutlass_SOURCE_DIR "${qutlass_SOURCE_DIR}") +endif() + +if(NOT qutlass_SOURCE_DIR) + message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.") +endif() +message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") + +cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS) + + if(QUTLASS_ARCHS MATCHES "10\\.0a") + set(QUTLASS_TARGET_CC 100) + elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + set(QUTLASS_TARGET_CC 120) + else() + message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") + endif() + + set(QUTLASS_SOURCES + ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu + 
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu + ) + + set(QUTLASS_INCLUDES + ${qutlass_SOURCE_DIR} + ${qutlass_SOURCE_DIR}/qutlass + ${qutlass_SOURCE_DIR}/qutlass/csrc/include + ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions + ) + + if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}") + elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include") + message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).") + else() + message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. " + "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include") + endif() + + set_gencode_flags_for_srcs( + SRCS "${QUTLASS_SOURCES}" + CUDA_ARCHS "${QUTLASS_ARCHS}" + ) + + target_sources(_C PRIVATE ${QUTLASS_SOURCES}) + target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES}) + target_compile_definitions(_C PRIVATE + QUTLASS_DISABLE_PYBIND=1 + TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC} + ) + + set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS + $<$:--expt-relaxed-constexpr --use_fast_math -O3> + ) + +else() + if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8") + message(STATUS + "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") + else() + message(STATUS + "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "CUDA_ARCHS='${CUDA_ARCHS}'.") + endif() +endif() diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 49defccbb1fa4..d4908772c69ec 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f + GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/hipify.py b/cmake/hipify.py index 55d378f5b1113..8504f9defee96 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -16,7 +16,7 @@ import shutil from torch.utils.hipify.hipify_python import hipify -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Project directory where all the source + include files live. @@ -34,15 +34,14 @@ if __name__ == '__main__': ) # Source files to convert. - parser.add_argument("sources", - help="Source files to hipify.", - nargs="*", - default=[]) + parser.add_argument( + "sources", help="Source files to hipify.", nargs="*", default=[] + ) args = parser.parse_args() # Limit include scope to project_dir only - includes = [os.path.join(args.project_dir, '*')] + includes = [os.path.join(args.project_dir, "*")] # Get absolute path for all source files. extra_files = [os.path.abspath(s) for s in args.sources] @@ -51,25 +50,31 @@ if __name__ == '__main__': # The directory might already exist to hold object files so we ignore that. 
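Stepping back to the new cmake/external_projects/qutlass.cmake above: the build gate reduces to "CUDA newer than 12.8 plus a 10.0a or 12.0a arch, with 10.0a checked first when deriving TARGET_CUDA_ARCH". A rough Python rendering of that decision, for illustration only (the authoritative logic is the CMake shown above):

def qutlass_target_cc(cuda_version: tuple[int, int], cuda_archs: set[str]) -> int | None:
    """Return the TARGET_CUDA_ARCH value QuTLASS would build for, or None if skipped."""
    if not cuda_version > (12, 8):   # mirrors the VERSION_GREATER 12.8 check
        return None                  # skipped: CUDA too old
    if "10.0a" in cuda_archs:        # 10.0a is matched before 12.0a in the CMake branch
        return 100
    if "12.0a" in cuda_archs:
        return 120
    return None                      # skipped: no supported arch in CUDA_ARCHS

print(qutlass_target_cc((12, 9), {"12.0a"}))  # 120
print(qutlass_target_cc((12, 6), {"10.0a"}))  # None (CUDA too old)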
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) - hipify_result = hipify(project_directory=args.project_dir, - output_directory=args.output_dir, - header_include_dirs=[], - includes=includes, - extra_files=extra_files, - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True) + hipify_result = hipify( + project_directory=args.project_dir, + output_directory=args.output_dir, + header_include_dirs=[], + includes=includes, + extra_files=extra_files, + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) hipified_sources = [] for source in args.sources: s_abs = os.path.abspath(source) - hipified_s_abs = (hipify_result[s_abs].hipified_path if - (s_abs in hipify_result - and hipify_result[s_abs].hipified_path is not None) - else s_abs) + hipified_s_abs = ( + hipify_result[s_abs].hipified_path + if ( + s_abs in hipify_result + and hipify_result[s_abs].hipified_path is not None + ) + else s_abs + ) hipified_sources.append(hipified_s_abs) - assert (len(hipified_sources) == len(args.sources)) + assert len(hipified_sources) == len(args.sources) # Print hipified source files. print("\n".join(hipified_sources)) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 9c0ed1d09572e..f6a0d2b75be1a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -310,13 +310,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) - # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should - # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS + # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should + # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS set(_CUDA_ARCHS) foreach(_arch ${_SRC_CUDA_ARCHS}) - if(_arch MATCHES "\\a$") + if(_arch MATCHES "[af]$") list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") - string(REPLACE "a" "" _base "${_arch}") + string(REGEX REPLACE "[af]$" "" _base "${_arch}") if ("${_base}" IN_LIST TGT_CUDA_ARCHS) list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(APPEND _CUDA_ARCHS "${_arch}") @@ -480,7 +480,6 @@ function (define_gpu_extension_target GPU_MOD_NAME) ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") endif() - set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17) target_compile_options(${GPU_MOD_NAME} PRIVATE $<$:${GPU_COMPILE_FLAGS}>) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 57382c1ddc65b..052ff168cec4f 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -28,10 +28,10 @@ #ifdef USE_ROCM #include - #include "../quantization/fp8/amd/quant_utils.cuh" + #include "../quantization/w8a8/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; #else - #include "../quantization/fp8/nvidia/quant_utils.cuh" + #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh" #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu deleted file mode 100644 index 0319d1daf302f..0000000000000 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
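For the cuda_archs_loose_intersection change in cmake/utils.cmake above, a rough Python model of the new "[af]$" suffix handling may help: a "9.0a" or "10.0f" entry in the source list is now selected whenever its base version ("9.0", "10.0") appears in the target list. This is an approximation for illustration, not a port of the CMake function:

def keep_suffixed_archs(src_archs: list[str], tgt_archs: list[str]) -> list[str]:
    tgt = set(tgt_archs)
    kept = []
    for arch in src_archs:
        if arch and arch[-1] in ("a", "f"):  # e.g. "9.0a" or "10.0f"
            base = arch[:-1]                 # strip the feature-specific suffix
            if base in tgt:                  # base version was requested by the user
                kept.append(arch)
                tgt.discard(base)
        elif arch in tgt:
            kept.append(arch)
    return kept

print(keep_suffixed_archs(["9.0a", "10.0f", "8.0"], ["8.0", "9.0", "10.0"]))
# ['9.0a', '10.0f', '8.0']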
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale); -#endif - -void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA - return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale); -#endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); -} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu deleted file mode 100644 index 9d05d910dd81f..0000000000000 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include -#include - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.h" - -#include "cutlass_extensions/common.hpp" - -#include "device/sm100_mla.hpp" -#include "kernel/sm100_mla_tile_scheduler.hpp" - -using namespace cute; -using namespace cutlass::fmha::kernel; - -template -struct MlaSm100 { - using Element = T; - using ElementAcc = float; - using ElementOut = T; - - using TileShape = Shape<_128, _128, Shape<_512, _64>>; - using TileShapeH = cute::tuple_element_t<0, TileShape>; - using TileShapeD = cute::tuple_element_t<2, TileShape>; - - // H K (D_latent D_rope) B - using ProblemShape = cute::tuple; - - using StrideQ = cute::tuple; // H D B - using StrideK = cute::tuple; // K D B - using StrideO = StrideK; // H D B - using StrideLSE = cute::tuple<_1, int>; // H B - - using TileScheduler = - std::conditional_t; - - using FmhaKernel = - cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< - TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, - /*kIsCpAsync=*/true>; - using Fmha = cutlass::fmha::device::MLA; -}; - -template -typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table, double scale) { - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; - int max_seq_len = page_size * page_count_per_seq; - using TileShapeH = typename T::TileShapeH; - using TileShapeD = typename T::TileShapeD; - auto problem_shape = - cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); - - auto [H, K, D, B] = problem_shape; - auto [D_latent, D_rope] = D; - - using StrideQ = typename T::StrideQ; - using StrideK = typename T::StrideK; - using StrideO = typename T::StrideO; - using StrideLSE = typename T::StrideLSE; - - StrideQ stride_Q_latent = cute::make_tuple( - static_cast(D_latent), _1{}, static_cast(H * D_latent)); - StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, - static_cast(H * D_rope)); - StrideK stride_C = - cute::make_tuple(static_cast(D_latent + D_rope), _1{}, - static_cast(page_size * (D_latent + D_rope))); - StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); - StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); - StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, - static_cast(H * D_latent)); - - using Element = typename T::Element; - using ElementOut = typename T::ElementOut; - using ElementAcc = typename T::ElementAcc; - auto Q_latent_ptr = static_cast(q_nope.data_ptr()); - auto Q_rope_ptr = static_cast(q_pe.data_ptr()); - auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); - auto scale_f = static_cast(scale); - typename T::Fmha::Arguments arguments{ - problem_shape, - {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, - stride_C, C_ptr + D_latent, stride_C, - static_cast(seq_lens.data_ptr()), - static_cast(page_table.data_ptr()), stride_PT, page_count_total, - page_size}, - {static_cast(out.data_ptr()), stride_O, - static_cast(nullptr), stride_LSE}, - hw_info, - 1, // split_kv - nullptr, // 
is_var_split_kv - }; - // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute - // split_kv automatically based on batch size and sequence length to balance - // workload across available SMs. Consider using var_split_kv for manual - // control if needed. - T::Fmha::set_split_kv(arguments); - return arguments; -} - -template -void runMla(at::Tensor const& out, at::Tensor const& q_nope, - at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, at::Tensor const& page_table, - float scale, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; - typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); - size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - CUTLASS_CHECK(fmha.can_implement(arguments)); - - CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); - - CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); -} - -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { - TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); - TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); - TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, - "kv_c_and_k_pe_cache must be a 3D tensor"); - TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); - TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); - TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); - - auto B_q_nope = q_nope.size(0); - auto H_q_nope = q_nope.size(1); - auto D_q_nope = q_nope.size(2); - auto B_q_pe = q_pe.size(0); - auto H_q_pe = q_pe.size(1); - auto D_q_pe = q_pe.size(2); - auto B_pt = page_table.size(0); - auto PAGE_NUM = page_table.size(1); - auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); - auto D_ckv = kv_c_and_k_pe_cache.size(2); - auto B_o = out.size(0); - auto H_o = out.size(1); - auto D_o = out.size(2); - - TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); - TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); - TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); - TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, - "H_q_nope, H_q_pe, and H_o must be equal to 128"); - TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, - "PAGE_SIZE must be a power of 2"); - TORCH_CHECK( - B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, - "Batch dims must be same for page_table, q_nope and q_pe, and out"); - TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, - "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); - TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); - - TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || - q_nope.dtype() == at::ScalarType::BFloat16 || - q_nope.dtype() == at::ScalarType::Float8_e4m3fn, - "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && - q_nope.dtype() == q_pe.dtype(), - "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); - TORCH_CHECK(seq_lens.dtype() == torch::kInt32, - "seq_lens must be a 32-bit integer 
tensor"); - TORCH_CHECK(page_table.dtype() == torch::kInt32, - "page_table must be a 32-bit integer tensor"); - - auto in_dtype = q_nope.dtype(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(q_nope.get_device()); - if (in_dtype == at::ScalarType::Half) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, scale, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); - } -} diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 95e32559cd540..297d94dcc0631 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -133,6 +133,14 @@ public: // printf(" sm_count = %d\n", sm_count); int max_splits = ceil_div(K, 128); max_splits = min(16, max_splits); + + // TODO: This avoids a hang when the batch size larger than 1 and + // there is more than 1 kv_splits. + // Discuss with NVIDIA how this can be fixed. + if (B > 1) { + max_splits = min(1, max_splits); + } + // printf(" max_splits = %d\n", max_splits); int sms_per_batch = max(1, sm_count / B); // printf(" sms_per_batch = %d\n", sms_per_batch); diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 2cbc2379579eb..1f62c37ba4b7f 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) - continue; + if (local_split_kv <= get<3>(blk_coord)) + continue; load_page_table( blk_coord, problem_shape, params.mainloop, shared_storage.tensors, pipeline_page_table, pipeline_pt_producer_state, - local_split_kv + local_split_kv ); } } @@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_cpasync( 
blk_coord, @@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { params.mainloop_params, shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv, + local_split_kv, /* must be shared pipe */ pipeline_page_table, pipeline_pt_consumer_state ); @@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { - local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( blk_coord, @@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv + local_split_kv ); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); } @@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( blk_coord, @@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv + local_split_kv ); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); } @@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; mma(blk_coord, problem_shape, @@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_mma_s, pipeline_mma_s_producer_state, pipeline_p_mma, pipeline_p_mma_consumer_state, pipeline_mma_o, pipeline_mma_o_producer_state, - local_split_kv + local_split_kv ); } } @@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { 
auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto split_kv = params.split_kv; - auto local_split_kv = split_kv; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; compute( blk_coord, @@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_mma_s, pipeline_mma_s_consumer_state, pipeline_p_mma, pipeline_p_mma_producer_state, pipeline_mma_o, pipeline_mma_o_consumer_state, - local_split_kv + local_split_kv ); } @@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { cutlass::arch::NamedBarrier( (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue - ).arrive(); + ).arrive_and_wait(); return; } diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 6dd6f269f3dc9..d1874515cc8fd 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -36,12 +36,14 @@ limitations under the License. #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, torch::Tensor const& workspace, + double sm_scale, int64_t num_kv_splits) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } @@ -64,11 +66,11 @@ struct IsPersistent { static const bool value = v; }; -template > +template > struct MlaSm100 { using Element = T; using ElementAcc = float; - using ElementOut = T; + using ElementOut = TOut; using TileShape = Shape<_128, _128, Shape<_512, _64>>; using TileShapeH = cute::tuple_element_t<0, TileShape>; @@ -99,6 +101,7 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -162,7 +165,10 @@ typename T::Fmha::Arguments args_from_options( stride_PT, page_count_total, page_size}, - {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, hw_info, // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. 
Split_kv=1 will @@ -178,9 +184,10 @@ typename T::Fmha::Arguments args_from_options( return arguments; } -template +template void runMla( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -190,9 +197,9 @@ void runMla( double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -214,6 +221,7 @@ void runMla( void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -233,14 +241,14 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } @@ -253,7 +261,7 @@ void sm100_cutlass_mla_decode( int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) // which are float, so Element type here doesn't matter. - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; // Get split kv. Requires problem shape and sm_count only. typename MlaSm100Type::Fmha::Arguments arguments; diff --git a/csrc/cache.h b/csrc/cache.h index fb0c353b96137..b162a4a2bc31f 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -47,4 +47,28 @@ void gather_and_maybe_dequant_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, const std::string& kv_cache_dtype, torch::Tensor const& scale, - std::optional seq_starts = std::nullopt); \ No newline at end of file + std::optional seq_starts = std::nullopt); + +// TODO(hc): cp_gather_cache need support scaled kvcahe in the future. +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); + +// Indexer K quantization and cache function +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt); + +// Extract function to gather quantized K cache +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b3a985c2d5bbb..0aa0dc14c7480 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -8,15 +9,14 @@ #include "quantization/vectorization_utils.cuh" #ifdef USE_ROCM - #include "quantization/fp8/amd/quant_utils.cuh" + #include "quantization/w8a8/fp8/amd/quant_utils.cuh" #else - #include "quantization/fp8/nvidia/quant_utils.cuh" + #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh" #endif #include #include -#include -#include +#include #ifdef USE_ROCM #include @@ -208,6 +208,20 @@ void copy_blocks_mla(std::vector const& kv_caches, namespace vllm { +// Used to copy/convert one element +template +struct CopyWithScaleOp { + float scale; + + __device__ __forceinline__ void operator()(OutT& dst, const InT src) const { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst = static_cast(src); + } else { + dst = fp8::scaled_convert(src, scale); + } + } +}; + template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] @@ -223,59 +237,51 @@ __global__ void reshape_and_cache_kernel( const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { - // Padding token that should be ignored. 
return; } const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; + const int h_block_count = head_size / x; // head_size//x - const int n = num_heads * head_size; - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int64_t src_key_idx = token_idx * key_stride + i; - const int64_t src_value_idx = token_idx * value_stride + i; + const int h_block_idx = threadIdx.x; + if (h_block_idx >= num_heads * h_block_count) { + return; + } - const int head_idx = i / head_size; - const int head_offset = i % head_size; - const int x_idx = head_offset / x; - const int x_offset = head_offset % x; + const int head_idx = h_block_idx / h_block_count; + const int h_block = h_block_idx % h_block_count; - const int64_t tgt_key_idx = - block_idx * num_heads * (head_size / x) * block_size * x + - head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + - block_offset * x + x_offset; - const int64_t tgt_value_idx = - block_idx * num_heads * head_size * block_size + - head_idx * head_size * block_size + head_offset * block_size + - block_offset; - scalar_t tgt_key = key[src_key_idx]; - scalar_t tgt_value = value[src_value_idx]; - if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { - key_cache[tgt_key_idx] = tgt_key; - value_cache[tgt_value_idx] = tgt_value; - } else { - key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, *k_scale); - value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, *v_scale); - } + const scalar_t* __restrict__ key_src = + key + token_idx * key_stride + head_idx * head_size + h_block * x; + const int64_t src_value_start = + token_idx * value_stride + head_idx * head_size + h_block * x; + + cache_t* __restrict__ key_dst = + key_cache + block_idx * num_heads * h_block_count * block_size * x + + head_idx * h_block_count * block_size * x + h_block * block_size * x + + block_offset * x; + const int64_t tgt_value_start = + block_idx * num_heads * h_block_count * x * block_size + + head_idx * h_block_count * x * block_size + h_block * x * block_size + + block_offset; + + constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4; + float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale; + CopyWithScaleOp k_op{k_scale_val}; + float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 
0.f : *v_scale; + CopyWithScaleOp v_op{v_scale_val}; + + vectorize_with_alignment(key_src, key_dst, x, 0, 1, k_op); + + const scalar_t* __restrict__ value_src = value + src_value_start; + cache_t* __restrict__ value_dst = value_cache + tgt_value_start; +#pragma unroll + for (int i = 0; i < x; i++) { + v_op(value_dst[i * block_size], value_src[i]); } } -// Used by vectorization_utils to copy/convert one element -template -struct CopyWithScaleOp { - float scale; - - __device__ __forceinline__ void operator()(OutT& dst, const InT src) const { - if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { - dst = static_cast(src); - } else { - dst = fp8::scaled_convert(src, scale); - } - } -}; - template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] @@ -395,6 +401,241 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } +template +__global__ void concat_and_cache_ds_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int64_t dst_idx_start = + block_idx * block_stride + block_offset * entry_stride; + + // For the NoPE part, each tile of 128 elements is handled by half of one warp + // (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads). + // So in total, we use 3 warps (96 threads) per block. 
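As a plain-Python illustration (not kernel code) of the 96-thread split described in the comment above, the mapping from threadIdx.x to its role in concat_and_cache_ds_mla_kernel works out as follows; the tile and element indices mirror the expressions used later in the kernel:

def thread_role(tid: int) -> str:
    assert 0 <= tid < 96                       # 3 warps per block
    if tid >= 64:                              # third warp: RoPE, 2 elements per thread
        pe_start = (tid - 64) * 2
        return f"RoPE elements [{pe_start}, {pe_start + 1}]"
    warp_idx, lane_idx = tid >> 5, tid & 31
    tile_idx = warp_idx * 2 + (lane_idx >> 4)  # 4 NoPE tiles of 128 elements
    elem_start = tid * 8                       # 8 NoPE elements per thread
    role = f"NoPE tile {tile_idx}, elements [{elem_start}, {elem_start + 7}]"
    if lane_idx in (0, 16):
        role += " (also writes this tile's scale)"
    return role

for tid in (0, 16, 17, 63, 64, 95):
    print(tid, "->", thread_role(tid))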
+ + // Cast kv_cache to 16_bit for RoPE values + scalar_t* kv_cache_16bit = + reinterpret_cast(&kv_cache[dst_idx_start]); + + // The last warp handles the RoPE part + if (threadIdx.x >= 64) { + // Each thread handles two elements of RoPE + const int8_t pe_idx_start = (threadIdx.x - 64) * 2; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start; + // Vectorized load of two 16-bit values, performed as one 32-bit load + const int32_t vals = *reinterpret_cast(&k_pe[src_idx]); + // RoPE values start after the packed 8-bit NoPE values and the + // 32-bit scales + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start; + // Vectorized store of two 16-bit values, performed as one 32-bit store + *reinterpret_cast(&kv_cache_16bit[dst_idx]) = vals; + return; + } + + // The first two warps handle the NoPE part + const int8_t warp_idx = threadIdx.x >> 5; + const int8_t lane_idx = threadIdx.x & 31; + const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4); + + // Each thread handles 8 elements of NoPE + // Load the NoPE elements for this thread into registers + const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8); + // Vectorized load of eight 16-bit values, performed as an int4 load + const int4 vals_i4 = *reinterpret_cast(&kv_c[src_idx_start]); + const scalar_t* vals = reinterpret_cast(&vals_i4); + + // Max absolute value of this thread's elements + float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])), + fmaxf(fabsf(vals[2]), fabsf(vals[3]))), + fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])), + fmaxf(fabsf(vals[6]), fabsf(vals[7])))); + + // Warp-level reduction to find the max absolute value in each half-warp +#pragma unroll + for (int offset = 8; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16)); + } + + // Compute the scale for the tile + float tile_scale = max_abs / 448.f; + tile_scale = fmaxf(tile_scale, FLT_MIN); + + // The first lane of each half-warp writes the scale to kv_cache + if ((lane_idx == 0) || (lane_idx == 16)) { + float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); + const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; + kv_cache_32bit[dst_idx] = tile_scale; + } + + // Now all threads in the block scale and write their elements + // NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes) + const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8); + + uint8_t result[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = + fp8::scaled_convert( + vals[i], tile_scale); + } + + // Store as aligned 64-bit writes + *reinterpret_cast(&kv_cache[dst_idx_base]) = + *reinterpret_cast(result); +} + +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int head_dim, // dimension of each head + const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache + const bool use_ue8m0 // use ue8m0 scale format +) { + constexpr int VEC_SIZE = 4; + const int64_t token_idx = blockIdx.x; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x) * + VEC_SIZE; + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = 
slot_idx % cache_block_size; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0 || (head_dim_idx >= head_dim)) { + return; + } + + float2 k_val = (reinterpret_cast( + k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + scalar_t* k_val_ptr = reinterpret_cast(&k_val); + float amax = 0.0f; + for (int i = 0; i < VEC_SIZE; i++) { + amax = fmaxf(amax, fabsf(float(k_val_ptr[i]))); + } +#ifndef USE_ROCM + __syncwarp(); +#endif + + // Reduced amax + for (int mask = 16; mask > 0; mask /= 2) { +#ifdef USE_ROCM + amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask)); +#else + amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); +#endif + } +#ifndef USE_ROCM + __syncwarp(); +#endif + float scale = fmaxf(amax, 1e-4) / 448.0f; + if (use_ue8m0) { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + + block_offset * head_dim + head_dim_idx; + for (int i = 0; i < VEC_SIZE; i++) { + kv_cache[dst_offset + i] = + fp8::scaled_convert(k_val_ptr[i], scale); + } + if (threadIdx.x == 0) { + const int64_t dst_scale_idx = + block_idx * cache_block_size * cache_stride + + cache_block_size * head_dim + + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; + } +} + +template +__global__ void cp_gather_indexer_k_quant_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size * + // 4] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, // batch size + const int64_t token_stride, // stride for each token in dst_k + const int64_t head_dim, // dimension of each head + const int64_t block_stride, // stride for each block in kv_cache + const int64_t cache_token_stride, // stride for each token in kv_cache + const int64_t cache_block_size, // num_tokens for each block in kv_cache + const int num_blocks, // number of blocks + const int num_tokens, // number of tokens + const int quant_block_size // quantization block size +) { + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); + iter++) { + int tid = iter * blockDim.x + threadIdx.x; + if (tid < batch_size) { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = tid; + } + } + } + +#ifndef USE_ROCM + __syncwarp(); +#endif + + if (head_idx >= head_dim || token_idx >= num_tokens) { + return; + } + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_idx * block_stride; + const int64_t cache_inblock_offset = + (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + 
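Relatedly, the per-block scale written by indexer_k_quant_and_cache_kernel above can be illustrated with a small NumPy sketch. This reproduces only the scale computation (amax / 448, optionally rounded up to a power of two for the ue8m0 format); the element conversion itself is done by fp8::scaled_convert in the kernel, and the 128-element block below is an assumption matching quant_block_size:

import numpy as np

def indexer_scale(block: np.ndarray, use_ue8m0: bool) -> float:
    amax = float(np.max(np.abs(block)))
    scale = max(amax, 1e-4) / 448.0                    # 448 = largest finite fp8 e4m3 value
    if use_ue8m0:
        scale = float(2.0 ** np.ceil(np.log2(scale)))  # round the scale up to a power of two
    return scale

block = np.linspace(-2.0, 2.0, 128, dtype=np.float32)
print(indexer_scale(block, use_ue8m0=False))  # 2.0 / 448 ~= 0.00446
print(indexer_scale(block, use_ue8m0=True))   # next power of two: 0.0078125 (1/128)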
reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / VEC_SIZE]; + ; + if (threadIdx.x == 0) { + const int64_t src_scale_offset = + src_block_offset + cache_block_size * head_dim + + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(dst_scale)[dst_inblock_offset / quant_block_size] = + reinterpret_cast(kv_cache)[src_scale_offset / 4]; + } +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -430,14 +671,15 @@ void reshape_and_cache( int key_stride = key.stride(0); int value_stride = value.stride(0); + int head_div_x = head_size / x; dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); + dim3 block(std::min(num_heads * head_div_x, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, - CALL_RESHAPE_AND_CACHE) + CALL_RESHAPE_AND_CACHE); } // KV_T is the data type of key and value tensors. @@ -508,6 +750,18 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_ds_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -530,20 +784,43 @@ void concat_and_cache_mla( int pe_dim = k_pe.size(1); int block_size = kv_cache.size(1); - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + if (kv_cache_dtype == "fp8_ds_mla") { + TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"); + TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"); + TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(), + "kv_cache.size(2) must be 656 bytes for fp8_ds_mla"); + TORCH_CHECK(kv_c.itemsize() == 2, + "kv_c.itemsize() must be 2 for fp8_ds_mla"); + TORCH_CHECK(k_pe.itemsize() == 2, + "k_pe.itemsize() must be 2 for fp8_ds_mla"); + } else { + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + } int kv_c_stride = kv_c.stride(0); int k_pe_stride = k_pe.stride(0); int block_stride = kv_cache.stride(0); int entry_stride = kv_cache.stride(1); - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CONCAT_AND_CACHE_MLA); + if (kv_cache_dtype == "fp8_ds_mla") { + dim3 grid(num_tokens); + // For the NoPE part, each tile of 128 elements is handled by half of one + // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 + // threads). So in total, we use 3 warps (96 threads) per block. 
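The 656-byte fp8_ds_mla entry asserted by the TORCH_CHECKs above breaks down as 512 fp8 NoPE bytes, four float32 tile scales, and 64 16-bit RoPE values. A back-of-the-envelope check, for illustration only:

KV_LORA_RANK, PE_DIM, TILE = 512, 64, 128

nope_bytes = KV_LORA_RANK * 1                # one fp8 byte per NoPE value
scale_bytes = (KV_LORA_RANK // TILE) * 4     # 4 tiles, one float32 scale each
rope_bytes = PE_DIM * 2                      # 16-bit (fp16/bf16) RoPE values
entry_bytes = nope_bytes + scale_bytes + rope_bytes
assert entry_bytes == 656                    # matches kv_cache.size(2) == 656 / itemsize

# Offsets used by the kernel, in units of each reinterpreted view:
scale_offset_fp32 = KV_LORA_RANK // 4        # 128: scales start right after the packed NoPE bytes
rope_offset_16bit = KV_LORA_RANK // 2 + 8    # 264: RoPE starts after the NoPE bytes + 4 scales
print(entry_bytes, scale_offset_fp32, rope_offset_16bit)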
+ dim3 block(96); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_DS_MLA); + } else { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); + } } namespace vllm { @@ -779,3 +1056,240 @@ void gather_and_maybe_dequant_cache( DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } + +namespace vllm { +template +// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by +// block_size. +__global__ void cp_gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRY_SIZE] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts // Optional: starting offsets per + // batch +) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on it + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + if (seq_starts != nullptr) { + offset += seq_starts[bid]; + } + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * dst_entry_stride; + copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr); + offset += 1; + // bump to next block + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. +#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \ + vllm::cp_gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. 
+// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting slot index by +// seq_starts[bid] +void cp_gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_CP_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_CP_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_CP_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} + +// Macro to dispatch the kernel based on the data type. 
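+// KV_T is the data type of the key tensor.
+// CACHE_T is the stored data type of the quantized indexer k cache;
+// indexer_k_quant_and_cache() below always dispatches this macro with an
+// "fp8_e4m3" cache dtype.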
+#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::indexer_k_quant_and_cache_kernel \ + <<>>( \ + reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), head_dim, quant_block_size, \ + cache_block_size, cache_stride, use_ue8m0); + +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt) { + int num_tokens = k.size(0); + int head_dim = k.size(1); + int cache_block_size = kv_cache.size(1); + int cache_stride = kv_cache.size(2); + bool use_ue8m0 = scale_fmt == "ue8m0"; + + TORCH_CHECK(k.device() == kv_cache.device(), + "k and kv_cache must be on the same device"); + TORCH_CHECK(k.device() == slot_mapping.device(), + "k and slot_mapping must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 4; + dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / + (quant_block_size * vec_size)); + dim3 block(32, vec_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", + CALL_INDEXER_K_QUANT_AND_CACHE); +} + +// Macro to dispatch the kernel based on the data amount. +#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ + vllm::cp_gather_indexer_k_quant_cache_kernel \ + <<>>( \ + reinterpret_cast(kv_cache.data_ptr()), \ + reinterpret_cast(dst_k.data_ptr()), \ + reinterpret_cast(dst_scale.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \ + kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \ + num_tokens, quant_block_size); + +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens // [batch_size + 1] +) { + int batch_size = block_table.size(0); + int num_tokens = dst_k.size(0); + int head_dim = dst_k.size(1); + int quant_block_size = head_dim * 4 / dst_scale.size(1); + + TORCH_CHECK(kv_cache.device() == dst_k.device(), + "kv_cache and dst_k must be on the same device"); + TORCH_CHECK(kv_cache.device() == dst_scale.device(), + "kv_cache and dst_scale must be on the same device"); + TORCH_CHECK(kv_cache.device() == block_table.device(), + "kv_cache and block_table must be on the same device"); + TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(), + "kv_cache and cu_seq_lens must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 16; + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens < 32) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1); + } else if (num_tokens < 64) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2); + } else if (num_tokens < 128) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4); + } else if (num_tokens < 256) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8); + } else if (num_tokens < 512) 
{ + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16); + } else { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); + } +} diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp new file mode 100644 index 0000000000000..e769e1a25ac0e --- /dev/null +++ b/csrc/core/batch_invariant.hpp @@ -0,0 +1,19 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +// vllm_kernel_override_batch_invariant(); returns true +// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 +inline bool vllm_kernel_override_batch_invariant() { + static bool cached = []() { + std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 1 : 0; + }(); + return cached; +} + +} // namespace vllm diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 17bbe04eef94a..9cdcd2edacfdb 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -14,7 +14,12 @@ // arm implementation #include "cpu_types_arm.hpp" #else - #warning "unsupported vLLM cpu implementation" + #warning "unsupported vLLM cpu implementation, vLLM will compile with scalar" + #include "cpu_types_scalar.hpp" +#endif + +#ifdef _OPENMP + #include #endif #endif \ No newline at end of file diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp new file mode 100644 index 0000000000000..1a9278bc662e5 --- /dev/null +++ b/csrc/cpu/cpu_types_scalar.hpp @@ -0,0 +1,513 @@ +#include +#include +#include +#include +#include "float_convert.hpp" + +namespace vec_op { + +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD + #define CPU_KERNEL_GUARD_IN(NAME) + #define CPU_KERNEL_GUARD_OUT(NAME) +#else + #define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; + #define CPU_KERNEL_GUARD_OUT(NAME) \ + std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +#define __max(a, b) ((a) > (b) ? (a) : (b)) +#define __min(a, b) ((a) < (b) ? (a) : (b)) +#define __abs(a) ((a) < (0) ? 
(0 - a) : (a)) + +typedef struct f16x8_t { + uint16_t val[8]; +} f16x8_t; + +typedef struct f16x16_t { + uint16_t val[16]; +} f16x16_t; + +typedef struct f16x32_t { + uint16_t val[32]; +} f16x32_t; + +typedef struct f32x4_t { + float val[4]; +} f32x4_t; + +typedef struct f32x8_t { + float val[8]; +} f32x8_t; + +typedef struct f32x16_t { + float val[16]; +} f32x16_t; + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...); +}; +}; // namespace + +template > > +constexpr void unroll_loop(F&& f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template +struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + f16x8_t reg; + + explicit FP16Vec8(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit FP16Vec8(const FP32Vec8&); + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + f16x16_t reg; + + explicit FP16Vec16(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit FP16Vec16(const FP32Vec16&); + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + int num = __min(elem_num, VEC_ELEM_NUM); + std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); + } +}; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + f16x8_t reg; + + explicit BF16Vec8(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec8(const FP32Vec8&); + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + f16x16_t reg; + + explicit BF16Vec16(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec16(const FP32Vec16&); + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + int num = __min(elem_num, VEC_ELEM_NUM); + std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); + } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + f16x32_t reg; + + explicit BF16Vec32(const void* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec32(f16x32_t data) : reg(data) {}; + + explicit BF16Vec32(BF16Vec8& vec8_data) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM]; + } + } + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + + f32x4_t reg; + + explicit FP32Vec4(float v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = v; + } + } + + explicit FP32Vec4() { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = 0.0f; + } + } + + explicit FP32Vec4(const float* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit FP32Vec4(f32x4_t data) : reg(data) {}; + + explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}; +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + f32x8_t reg; + + explicit FP32Vec8(float v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = v; + } + } + + explicit FP32Vec8() { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = 0.0f; + } + } + + explicit FP32Vec8(const float* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit FP32Vec8(f32x8_t data) : 
reg(data) {}; + + explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; + + explicit FP32Vec8(const FP16Vec8& v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = fp16_to_float(v.reg.val[i]); + } + } + + FP32Vec8(const BF16Vec8& v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = bf16_to_float(v.reg.val[i]); + } + } + + float reduce_sum() const { + float result = 0; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result += reg.val[i]; + } + return result; + } + + FP32Vec8 exp() const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = expf(reg.val[i]); + } + return FP32Vec8(ret); + } + + FP32Vec8 tanh() const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = tanhf(reg.val[i]); + } + return FP32Vec8(ret); + } + + FP32Vec8 er() const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = erf(reg.val[i]); + } + return FP32Vec8(ret); + } + + FP32Vec8 operator*(const FP32Vec8& b) const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = reg.val[i] * b.reg.val[i]; + } + return FP32Vec8(ret); + } + + FP32Vec8 operator+(const FP32Vec8& b) const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = reg.val[i] + b.reg.val[i]; + } + return FP32Vec8(ret); + } + + FP32Vec8 operator-(const FP32Vec8& b) const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = reg.val[i] - b.reg.val[i]; + } + return FP32Vec8(ret); + } + + FP32Vec8 operator/(const FP32Vec8& b) const { + f32x8_t ret; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + ret.val[i] = reg.val[i] / b.reg.val[i]; + } + return FP32Vec8(ret); + } + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + f32x16_t reg; + + explicit FP32Vec16(float v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = v; + } + } + + explicit FP32Vec16() { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = 0.0f; + } + } + + explicit FP32Vec16(const float* ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit FP32Vec16(f32x16_t data) : reg(data) {}; + + FP32Vec16(const FP32Vec4& data) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM]; + } + } + + FP32Vec16(const FP32Vec8& data) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM]; + } + } + + FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}; + + explicit FP32Vec16(const FP16Vec16& v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = fp16_to_float(v.reg.val[i]); + } + } + + explicit FP32Vec16(const BF16Vec16& v) { + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + reg.val[i] = bf16_to_float(v.reg.val[i]); + } + } + + explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; + + FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; + + FP32Vec16 operator*(const FP32Vec16& b) const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = reg.val[i] * b.reg.val[i]; + } + return result; + } + + FP32Vec16 operator+(const FP32Vec16& b) const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = reg.val[i] + b.reg.val[i]; + } + return result; + } + + FP32Vec16 operator-(const FP32Vec16& b) const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = reg.val[i] - b.reg.val[i]; + } + return result; + } + + FP32Vec16 operator/(const FP32Vec16& b) 
const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = reg.val[i] / b.reg.val[i]; + } + return result; + } + + FP32Vec16 max(const FP32Vec16& b) const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = __max(reg.val[i], b.reg.val[i]); + } + return result; + } + + FP32Vec16 min(const FP32Vec16& b) const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = __min(reg.val[i], b.reg.val[i]); + } + return result; + } + + FP32Vec16 abs() const { + FP32Vec16 result(0.0f); + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result.reg.val[i] = __abs(reg.val[i]); + } + return result; + } + + float reduce_sum() const { + float result = 0.0f; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result += reg.val[i]; + } + return result; + } + + float reduce_max() const { + float result = reg.val[0]; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result = __max(reg.val[i], result); + } + return result; + } + + float reduce_min() const { + float result = reg.val[0]; + for (int i = 0; i < VEC_ELEM_NUM; ++i) { + result = __min(reg.val[i], result); + } + return result; + } + + template + float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + float sum = 0.0; + int start = idx * group_size; + int end = (idx + 1) * group_size; + + for (; (start < VEC_ELEM_NUM) && (start < end); ++start) { + sum += reg.val[start]; + } + + return sum; + } + + void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +template +struct VecType { + using vec_type = void; +}; + +template +using vec_t = typename VecType::vec_type; + +template <> +struct VecType { + using vec_type = FP32Vec8; +}; + +template <> +struct VecType { + using vec_type = FP16Vec8; +}; + +template <> +struct VecType { + using vec_type = BF16Vec8; +}; + +template +void storeFP32(float v, T* ptr) { + *ptr = v; +} + +/* +template <> inline void storeFP32(float v, c10::Half *ptr) { + c10::Half __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} +*/ + +template <> +inline void storeFP32(float v, c10::Half* ptr) { + uint16_t fp16 = float_to_fp16(v); + *reinterpret_cast(ptr) = fp16; +} + +template <> +inline void storeFP32(float v, c10::BFloat16* ptr) { + c10::BFloat16 __attribute__((__may_alias__))* v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +inline FP16Vec16::FP16Vec16(const FP32Vec16& v) { + int i = 0; + for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) { + reg.val[i] = float_to_fp16(v.reg.val[i]); + } +} + +inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) { + int i = 0; + for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) { + reg.val[i] = float_to_fp16(v.reg.val[i]); + } +} + +inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { + acc = acc + a * b; +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { + int i = 0; + for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) { + reg.val[i] = float_to_bf16(v.reg.val[i]); + } +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { + int i = 0; + for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) { + reg.val[i] = float_to_bf16(v.reg.val[i]); + } +} + +inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); } + +}; // namespace vec_op diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp index ab8cbbbf4ec4f..51bca37e699b9 100644 --- a/csrc/cpu/cpu_types_vxe.hpp +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -12,7 +12,7 @@ namespace vec_op { #define vec_sub(a, b) ((a) - (b)) #define vec_mul(a, b) ((a) * (b)) #define 
vec_div(a, b) ((a) / (b)) -#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic #define vec_sl(a, b) ((a) << (b)) // Vector Shift Left // FIXME: FP16 is not fully supported in Torch-CPU diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index f3f00edb36068..0f0cc34602b34 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) { delete ptr; } +DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) { + this->realloc(allocation_unit * 128); +} + +void DNNLScratchPadManager::realloc(size_t new_size) { + new_size = round(new_size); + if (new_size > size_) { + ptr_ = std::aligned_alloc(64, new_size); + size_ = new_size; + } +} + +DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() { + static DNNLScratchPadManager manager; + return &manager; +} + template class DNNLPrimitiveCache { public: @@ -120,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( } void DNNLMatMulPrimitiveHandler::prepack_weight( - void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { - dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, - {b_k_stride_, b_n_stride_}); + void* original_b_ptr, dnnl::memory::desc original_b_md, + dnnl::memory::desc b_target_mem_desc) { dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); dnnl::memory packed_weight(b_target_mem_desc, default_engine()); { @@ -166,6 +182,23 @@ struct hash { hash()(static_cast(val.bias_type)); } }; + +template <> +struct hash { + size_t operator()( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size); + } +}; + +template <> +struct hash { + size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ + hash()(val.a_m_stride) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; } // namespace std bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, @@ -181,6 +214,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, l.bias_type == r.bias_type; } +bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; +} + +bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, + const MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride && + l.use_bias == r.use_bias && l.bias_type == r.bias_type; +} + static std::shared_ptr get_w8a8_class_primitive_cache( const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, @@ -205,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) if (a_qs_ == QuantizationStrategy::PER_TOKEN) { assert(!use_azp_); }; - prepack_weight(args.b_ptr, + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, .use_bias = false, @@ -239,6 +285,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { } dnnl::matmul matmul = get_matmul_cache(args); + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); 
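+  // The primitive is created with dnnl::scratchpad_mode::user, so oneDNN does
+  // not allocate its own scratchpad. get_matmul_cache() grows the shared
+  // DNNLScratchPadManager buffer to the primitive's
+  // scratchpad_desc().get_size(), rounded up to a multiple of allocation_unit
+  // (4 * 1024 * 1024 bytes), and every execute() call points
+  // DNNL_ARG_SCRATCHPAD at that shared buffer here before running the matmul.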
+ matmul.execute(default_stream(), memory_cache_); default_stream().wait(); } @@ -257,6 +308,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); return dnnl::matmul(desc); }); } @@ -300,6 +353,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); + + memory_cache_[DNNL_ARG_SCRATCHPAD] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); } dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( @@ -319,6 +377,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( dnnl::memory::format_tag::ab); dnnl::primitive_attr attr; + + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + // For PER_TOKEN, scales will be applied in outside epilogue if (a_qs_ == QuantizationStrategy::PER_TENSOR) { attr.set_scales_mask(DNNL_ARG_SRC, 0); @@ -344,3 +405,177 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( attr); } } + +MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), args.ab_type), + m_size_cache_(nullptr) { + assert(ab_type_ == dnnl::memory::data_type::f32 || + ab_type_ == dnnl::memory::data_type::bf16 || + ab_type_ == dnnl::memory::data_type::f16); + + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + + prepack_weight(args.b_ptr, original_b_md, + create_primitive_desc( + MSizeCacheKey{ +#ifdef VLLM_USE_ACL + // Arm Compute Library (ACL) backend for oneDNN does + // not support runtime + // dimensions, so we set M to a default value + .a_m_size = 128, + .a_m_stride = b_k_size_, +#else + .a_m_size = DNNL_RUNTIME_DIM_VAL, + .a_m_stride = DNNL_RUNTIME_DIM_VAL, +#endif + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +static std::shared_ptr +get_matul_class_primitive_cache( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +void MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + +#ifndef VLLM_USE_ACL + // We do not support in ACL backend of oneDNN, we handle bias by: + // 1. copying it into the result tensor + // 2. 
attaching a fused-sum post-op to the matmul primitive + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2); + bias_storage->set_data_handle((void*)args.bias_ptr); + } +#endif + dnnl::matmul matmul = get_matmul_cache(args); + +// With ACL backend of oneDNN, the required memory format might change when the +// source tensor dims change. This does not really happen in practice, so isn't +// a performance hit, but we need to support it because the API allows for it. +#ifdef VLLM_USE_ACL + auto new_expected_wei_desc = + dnnl::matmul::primitive_desc( + const_cast(matmul.get_primitive_desc())) + .weights_desc(); + if (new_expected_wei_desc != b_target_mem_desc_) { + prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(), + b_target_mem_desc_, new_expected_wei_desc); + } +#endif + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + + matmul.execute(default_stream(), memory_cache_); + default_stream().wait(); +} + +dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; + m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + } + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); + return dnnl::matmul(desc); + }); +} + +dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md; + dnnl::memory::desc b_md; + if (first_time) { + a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_, + dnnl::memory::format_tag::ab); + b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_, + dnnl::memory::format_tag::any); + } else { + a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_, + {key.a_m_stride, 1}); +#ifdef VLLM_USE_ACL + // ACL's backend of oneDNN always expects the weight format to be "any" + b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_, + dnnl::memory::format_tag::any); +#else + b_md = b_target_mem_desc_; +#endif + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + if (key.use_bias) { + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); +// Since ACL's matmuls don't support passing a bias_md, we apply the bias +// through a fused-sum post-op +#ifdef VLLM_USE_ACL + dnnl::post_ops post_ops; + post_ops.append_sum(); + attr.set_post_ops(post_ops); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); +#else + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); +#endif + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} + +void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory( + {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, 
dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + +// ACL matmuls don't support bias_md, so we don't need these +#ifndef VLLM_USE_ACL + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get()); +#endif + memory_cache_[DNNL_ARG_SCRATCHPAD] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); +} + +bool is_onednn_acl_supported() { +#ifdef VLLM_USE_ACL + return true; +#else + return false; +#endif +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index 54ceefced9e98..f0cb197d81a35 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() { return DNNLType>::type; } +class DNNLScratchPadManager { + public: + static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB + + static DNNLScratchPadManager* get_dnnl_scratchpad_manager(); + + DNNLScratchPadManager(); + + template + T* get_data() { + return reinterpret_cast(ptr_); + } + + static size_t round(size_t size) { + return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit; + } + + void realloc(size_t new_size); + + private: + size_t size_; + void* ptr_; +}; + class DNNLMatMulPrimitiveHandler { public: virtual ~DNNLMatMulPrimitiveHandler() = default; @@ -77,7 +101,7 @@ class DNNLMatMulPrimitiveHandler { protected: DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); - void prepack_weight(void* original_b_ptr, + void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md, dnnl::memory::desc b_target_mem_desc); void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); @@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { std::shared_ptr m_size_cache_; }; +class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + struct Args : public DNNLMatMulPrimitiveHandler::Args { + dnnl::memory::data_type ab_type; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + dnnl_dim_t a_m_stride; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const void* a_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + MatMulPrimitiveHandler(const Args& args); + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + std::shared_ptr m_size_cache_; +}; + #endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp index acc3b9ecde143..6d062c71e7674 100644 --- a/csrc/cpu/dnnl_kernels.cpp +++ b/csrc/cpu/dnnl_kernels.cpp @@ -145,7 +145,8 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, } } - float scale_val, azp_val; + float scale_val; + float azp_val = 0.0f; if constexpr 
(AZP) { float max_scalar = max_value.reduce_max(); float min_scalar = min_value.reduce_min(); @@ -379,6 +380,7 @@ void onednn_scaled_mm( exec_args.a_ptr = a.data_ptr(); exec_args.a_m_size = a.size(0); exec_args.bias_ptr = nullptr; + exec_args.bias_type = get_dnnl_type(); exec_args.use_bias = false; exec_args.a_scales_ptr = nullptr; exec_args.a_zero_points_ptr = nullptr; @@ -492,3 +494,77 @@ void dynamic_scaled_int8_quant( } }); } + +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + + MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler", + [&] { + args.c_type = get_dnnl_type(); + args.ab_type = get_dnnl_type(); + }); + + return reinterpret_cast(new MatMulPrimitiveHandler(args)); +} + +void onednn_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const std::optional& bias, int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.stride(-1) == 1); + TORCH_CHECK(c.stride(-1) == 1); + MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + +// ACL matmuls expect contiguous source tensors +#ifdef VLLM_USE_ACL + torch::Tensor a_contig = a.contiguous(); +#endif + + MatMulPrimitiveHandler::ExecArgs exec_args; + +#ifdef VLLM_USE_ACL + exec_args.a_m_size = a_contig.size(0); + exec_args.a_m_stride = a_contig.stride(0); +#else + exec_args.a_m_size = a.size(0); + exec_args.a_m_stride = a.stride(0); +#endif + VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] { + if (bias.has_value()) { + exec_args.use_bias = true; + exec_args.bias_type = get_dnnl_type(); +#ifdef VLLM_USE_ACL + // ACL matmuls in oneDNN do not support a bias. + // We handle a matmul with bias by doing: c = bias; c += matmul(a, b) + c.copy_(bias.value()); +#else + exec_args.bias_ptr = bias->data_ptr(); +#endif + } else { + exec_args.use_bias = false; + exec_args.bias_type = get_dnnl_type(); + exec_args.bias_ptr = nullptr; + } +#ifdef VLLM_USE_ACL + exec_args.a_ptr = a_contig.data_ptr(); +#else + exec_args.a_ptr = a.data_ptr(); + +#endif + exec_args.c_ptr = c.data_ptr(); + + ptr->execute(exec_args); + }); +} diff --git a/csrc/cpu/float_convert.hpp b/csrc/cpu/float_convert.hpp new file mode 100644 index 0000000000000..c792bf131ccdc --- /dev/null +++ b/csrc/cpu/float_convert.hpp @@ -0,0 +1,106 @@ + +static float bf16_to_float(uint16_t bf16) { + uint32_t bits = static_cast(bf16) << 16; + float fp32; + std::memcpy(&fp32, &bits, sizeof(fp32)); + return fp32; +} + +static uint16_t float_to_bf16(float fp32) { + uint32_t bits; + std::memcpy(&bits, &fp32, sizeof(fp32)); + return static_cast(bits >> 16); +} + +/************************************************ + * Copyright (c) 2015 Princeton Vision Group + * Licensed under the MIT license. + * Codes below copied from + * https://github.com/PrincetonVision/marvin/tree/master/tools/tensorIO_matlab + *************************************************/ +static uint16_t float_to_fp16(float fp32) { + uint16_t fp16; + + unsigned x; + unsigned u, remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + std::memcpy(&x, &fp32, sizeof(fp32)); + u = (x & 0x7fffffff); + + // Get rid of +NaN/-NaN case first. 
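+  // With the sign masked off, u > 0x7f800000 means the exponent field is all
+  // ones and the mantissa is nonzero, i.e. the input is NaN; it is mapped to
+  // the canonical fp16 NaN pattern 0x7fff. u == 0x7f800000 (infinity) falls
+  // through to the overflow check below and becomes sign | 0x7c00.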
+ if (u > 0x7f800000) { + fp16 = 0x7fffU; + return fp16; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + fp16 = sign | 0x7c00U; + return fp16; + } + if (u < 0x33000001) { + fp16 = (sign | 0x0000); + return fp16; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + fp16 = (sign | (exponent << 10) | mantissa); + + return fp16; +} + +static float fp16_to_float(uint16_t fp16) { + unsigned sign = ((fp16 >> 15) & 1); + unsigned exponent = ((fp16 >> 10) & 0x1f); + unsigned mantissa = ((fp16 & 0x3ff) << 13); + int temp; + float fp32; + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + temp = ((sign << 31) | (exponent << 23) | mantissa); + std::memcpy(&fp32, &temp, sizeof(temp)); + return fp32; +} diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp index beeccff783ea0..94b24c2f13a06 100644 --- a/csrc/cpu/sgl-kernels/moe.cpp +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -215,7 +215,7 @@ int moe_align_block_size( offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); } }); - // TODO: do we need to vecterize this ? + // TODO: do we need to vectorize this ? for (int mb = 0; mb < num_token_blocks; ++mb) { offsets[mb + 1] += offsets[mb]; } diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index c9f426bdf618a..9df19d1ac3928 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -21,6 +21,14 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const std::optional& bias, int64_t handler); +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size); + +void onednn_mm(torch::Tensor& c, const torch::Tensor& a, + const std::optional& bias, int64_t handler); + +bool is_onednn_acl_supported(); + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -82,8 +90,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + ops.def( + "dynamic_4bit_int_moe(" + "Tensor x, Tensor topk_ids, Tensor topk_weights," + "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2," + "int group_size, bool apply_router_weight_on_input, int activation_kind" + ") -> Tensor"); + + ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu); + // PagedAttention V2. 
ops.def( "paged_attention_v2(" @@ -153,6 +171,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("release_dnnl_matmul_handler(int handler) -> ()", &release_dnnl_matmul_handler); + // Create oneDNN GEMM handler + ops.def( + "create_onednn_mm_handler(Tensor b, int " + "primitive_cache_size) -> int", + &create_onednn_mm_handler); + + // oneDNN GEMM + ops.def( + "onednn_mm(Tensor! c, Tensor a, Tensor? bias, " + "int handler) -> ()"); + ops.impl("onednn_mm", torch::kCPU, &onednn_mm); + + // Check if oneDNN was built with ACL backend + ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported); + // Create oneDNN W8A8 handler ops.def( "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h new file mode 100644 index 0000000000000..18e4e343ad8b7 --- /dev/null +++ b/csrc/cub_helpers.h @@ -0,0 +1,18 @@ +#pragma once + +#ifndef USE_ROCM + #include + #if CUB_VERSION >= 200800 + #include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; + #else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; + #endif // CUB_VERSION +#else + #include +namespace cub = hipcub; +using CubAddOp = hipcub::Sum; +using CubMaxOp = hipcub::Max; +#endif // USE_ROCM diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b4597765..58926f6429dd3 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16; #include #include #include +#include +#include namespace vllm { #define CUDACHECK(cmd) \ @@ -555,22 +557,47 @@ class CustomAllreduce { size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = std::min(block_limit, (size + threads - 1) / threads); + + // Check environment variable once + const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO"); + bool force_1stage = false; + bool force_2stage = false; + if (env_algo != nullptr) { + if (std::strcmp(env_algo, "1stage") == 0 || + std::strcmp(env_algo, "oneshot") == 0) { + force_1stage = true; + } else if (std::strcmp(env_algo, "2stage") == 0 || + std::strcmp(env_algo, "twoshot") == 0) { + force_2stage = true; + } else { + throw std::runtime_error( + "Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) + + ". 
Valid values: 1stage, oneshot, 2stage, twoshot"); + } + } + #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); -#define REDUCE_CASE(ngpus) \ - case ngpus: { \ - if (world_size_ == 2) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else if (fully_connected_) { \ - if ((world_size_ <= 4 && bytes < 512 * 1024) || \ - (world_size_ <= 8 && bytes < 256 * 1024)) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else { \ - KL(ngpus, cross_device_reduce_2stage); \ - } \ - } \ - break; \ +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (force_1stage) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (force_2stage) { \ + KL(ngpus, cross_device_reduce_2stage); \ + } else { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (fully_connected_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + } \ + break; \ } switch (world_size_) { diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp deleted file mode 100644 index ec75c29e54f4d..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp +++ /dev/null @@ -1,123 +0,0 @@ -// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl -// clang-format off -#pragma once - -#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" - -#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_SS (BlockScaled Builders) -template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - int ScaleGranularityM -> -struct CollectiveBuilder< - arch::Sm90, - arch::OpClassTensorOp, - ElementA, - GmemLayoutATag, - AlignmentA, - ElementB, - GmemLayoutBTag, - AlignmentB, - ElementAccumulator, - TileShape_MNK, - ClusterShape_MNK, - StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, - cute::enable_if_t< - not detail::is_use_rmem_A()> -> { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; - - static_assert(is_static::value); - static_assert(is_static::value); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - static_assert(detail::is_aligned(), - "Should meet TMA alignment requirement\n"); - - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); - static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); - - // For fp32 types, map to tf32 MMA value type - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = 
detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsCooperative = cute::is_any_of_v>; - using AtomLayoutMNK = cute::conditional_t>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector< - GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector< - GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; - static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA, - TagToStrideA_t, - ElementB, - TagToStrideB_t, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - cute::identity, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - cute::identity - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp deleted file mode 100644 index 13b90e998625e..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp +++ /dev/null @@ -1,183 +0,0 @@ -// clang-format off -// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cute/algorithm/clear.hpp" -#include "cute/tensor.hpp" - -////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////FP8 Accumulation/////////////////////////// -////////////////////////////////////////////////////////////////////////////// -/// This class provides API to promote (add) or scale (multiply_add) the results -/// from the tensor core accumulators to the main accumulators when the number -/// of MMAs reaches the max number of MMA interval specified by user, after that -/// the tensor core accumulators are zeroed. -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -template < - class EngineAccum, - class LayoutAccum> -struct GmmaFP8AccumulationWithScale { - using TensorAccum = cute::Tensor; - using ElementAccumulator = typename EngineAccum::value_type; - - static_assert(is_static::value, "Accumulator Layout should be static"); - static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); - -private: - TensorAccum& accum_; - TensorAccum accum_temp_; - - uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. - uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop - uint32_t mma_count_; // current executed MMAs - uint32_t reset_accum_flag_; // accum needs to be zeroed or not. - - // promote or `add` the partial accumulators to main accumulator (FADD). - CUTLASS_DEVICE - void promote_core() { - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i); - } - } - - // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
- template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { - using TensorScale = cute::Tensor; - - static_assert(is_static::value, "Scale Layout should be static"); - static_assert(is_rmem::value , "Scale tensor must be rmem resident."); - - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); - - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i) * scale(i); - } - } - -public: - CUTLASS_DEVICE - GmmaFP8AccumulationWithScale( - TensorAccum &accum, - uint32_t accum_promotion_interval, - uint32_t mma_count_per_mainloop_iteration) - : accum_(accum), - accum_promotion_interval_(accum_promotion_interval), - mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), - mma_count_(0), - reset_accum_flag_(0) - { - accum_temp_ = cute::make_fragment_like(accum); - } - - // - // Methods (Common) - // - - CUTLASS_DEVICE - TensorAccum& operator()() { - return accum_temp_; - } - - /// prepare the MMA accumulators when initialization or zeroing is required. - CUTLASS_DEVICE - bool prepare_if_needed() { - return reset_accum_flag_; - } - - // - // Methods (for FADD version) - // - - /// promote (add) the results from the MMA accumulators to main accumulator if needed. - CUTLASS_DEVICE - void promote_if_needed() { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - promote_core(); - mma_count_ = 0; - } - } - - /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. - CUTLASS_DEVICE - void promote_residue_if_needed() { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - promote_core(); - } - } - - // - // Methods (for FFMA version) - // - - /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - scale_core(scale); - mma_count_ = 0; - } - } - - /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scale); - } - } -}; - -} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp deleted file mode 100644 index ce7f47cf72337..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ /dev/null @@ -1,729 +0,0 @@ -// clang-format off -// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/trace.h" -#include "cutlass/numeric_types.h" - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm80.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cute/algorithm/functional.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" - -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template < - int Stages, - class ClusterShape, - class KernelSchedule, - int ScaleGranularityM_, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator 
= typename TiledMma::ValTypeC; - using ElementBlockScale = ElementAccumulator; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - using PipelineParams = typename MainloopPipeline::Params; - - // Two threads per CTA are producers (1 for operand tile and 32 for scales) - static constexpr int NumProducerThreadEvents = 33; - - static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); - - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - - // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; - - // Block scaling smem layout - using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
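The ScaleGranularityM / ScaleMsPerTile constants defined above fix how many per-row A scales are staged in shared memory per k-tile. A small worked instance, assuming a tile with M = 128 (the actual TileShape comes from the kernel instantiation):

    // ScaleGranularityM_ = 64  ->  ScaleGranularityM = 64,
    //                              ScaleMsPerTile   = 128 / 64 = 2
    //   (two A scales per k-stage in smem_scale_A, one B scale in smem_scale_B)
    // ScaleGranularityM_ = 0   ->  ScaleGranularityM = 128 (falls back to tile M),
    //                              ScaleMsPerTile   = 1
    //   (a single A scale covers the whole M extent of the tile)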
- - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v, - "ElementAccumulator and ElementBlockScale should be same datatype"); - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; // mxk - cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // Device side kernel params - struct Params { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy_A_sm90( - GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), - TileShape{}, - ClusterShape{})); - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy_B_sm90( - GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), - TileShape{}, - ClusterShape{})); - TMA_A tma_load_a; - TMA_B tma_load_b; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; - // Block scaling factors for A and B - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - uint32_t transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t transaction_bytes_nk = TmaTransactionBytesNK; - uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; - - return { - tma_load_a, - tma_load_b, - transaction_bytes, - transaction_bytes_mk, - transaction_bytes_nk, - args.ptr_scale_A, - args.ptr_scale_B - }; - } - - 
template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytesMK = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytesNK = - cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - template - CUTLASS_DEVICE auto - load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { - using X = Underscore; - // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - constexpr auto scales_m = Int{}; - auto tM = get<2>(gA_mkl.shape()); - auto tN = get<2>(gB_nkl.shape()); - auto tK = get<3>(gA_mkl.shape()); - - // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) - auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); - auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) - auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); - - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and - // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) - - return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template < - class TensorA, class TensorB, - class TensorScaleA, class TensorScaleB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( - Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - int lane_predicate = cute::elect_one_sync(); - - // Blockscaling: Tma loads for load_input and CpAsync for load_scale - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - - // Block scaling: load_scale has scaling tensors in global memory which are not tiled - Tensor mScaleA_mkl = get<2>(load_inputs); - Tensor mScaleB_nkl = get<3>(load_inputs); - auto scales_m = get<0>(mScaleA_mkl.shape()); - - Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); - - Tensor gScaleA = local_tile( - mScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) - Tensor cScaleA = local_tile( - cScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) - - // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, - Layout>{}, Layout>{}); // (1,1,1) - TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, - Layout>{}, Layout>{}); // (1,1,1) - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); - - Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); - Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); - Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); - - Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); - Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - - // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } - } - - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } - } - - // Allocate predicate tensors for a_scales (since we can't guarantee that - // all scales are valid, since we could have a partial tiles along M) - Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); - #pragma unroll - for (int i = 0; i < size(tApA_ScaleA); ++i) { - tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - // Copy operands A and B from global memory to shared memory - if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - - 
// Copy scale tensors from global memory to shared memory - copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); - - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template < - class FrgTensorC - > - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_read, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // Block scaling - Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), - Layout< - Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, - Stride, _0, Int> - >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Define C accumulators and A/B partitioning - // - - // Layout of warp group to thread mapping - - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and - stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); - - constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, - Int{}); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
- - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Per block scale values for operand A and B - - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. - using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above - - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) - ElementBlockScale scale_b; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers. - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accumulation()); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accumulation()); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - accumulation.scale_residue_if_needed(tCrScaleAViewAsC); - - warpgroup_fence_operand(accumulation()); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp deleted file mode 100644 index df809e27a3efe..0000000000000 --- a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "cutlass/gemm/dispatch_policy.hpp" - -namespace cutlass::gemm { - -////////////////////////////////////////////////////////////////////////////// - -// FP8 related policies (including Blocked Scaled Accumulation) -// `ScaleGranularityM` specifies scaling granularity along M, while zero-value -// `ScaleGranularityM` indicates that scaling granularity is -// `size<0>(TileShape_MNK{})` along M. -template -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum - : KernelTmaWarpSpecializedCooperative {}; - -// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp -// specialized dynamic schedule For FP8 kernels with Block Scaling -template , - class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = - 0 // `ScaleGranularityM` specifies scaling granularity along M, - // while zero-value `ScaleGranularityM` indicates that scaling - // granularity is `size<0>(TileShape_MNK{})` along M. 
- > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 - : MainloopSm90TmaGmmaWarpSpecialized { - static_assert( - cute::is_same_v< - KernelSchedule, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - ScaleGranularityM>>, - "KernelSchedule must be one of the warp specialized policies"); -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh index e7fbba4cd4b0d..085ee1290031f 100644 --- a/csrc/cutlass_extensions/vllm_collective_builder.cuh +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" namespace cutlass::gemm::collective { using namespace cute; diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 1dd7101acc27d..5e742d0b02932 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = { **{ VLLMDataType.u4b8: "u4b8", VLLMDataType.u8b128: "u8b128", - } + }, } VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", - } + }, } VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { @@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { **{ VLLMDataType.u4b8: 4, VLLMDataType.u8b128: 8, - } + }, } VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[Union[ - MixedInputKernelScheduleType, KernelScheduleType], str] = { - **KernelScheduleTag, # type: ignore - **{ - MixedInputKernelScheduleType.TmaWarpSpecialized: - "cutlass::gemm::KernelTmaWarpSpecialized", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: - "cutlass::gemm::KernelTmaWarpSpecializedPingpong", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: - "cutlass::gemm::KernelTmaWarpSpecializedCooperative", - } - } +VLLMKernelScheduleTag: dict[ + Union[MixedInputKernelScheduleType, KernelScheduleType], str +] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", # noqa: E501 + }, +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index f7b75c48373f6..995374a50b037 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,6 +19,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_HALF_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // ROCm devices might use either fn or fnuz, so set up dispatch table for both. // A host-based check at runtime will create a preferred FP8 type for ROCm // such that the correct kernel is dispatched. diff --git a/csrc/launch_bounds_utils.h b/csrc/launch_bounds_utils.h new file mode 100644 index 0000000000000..92d7ef802f97f --- /dev/null +++ b/csrc/launch_bounds_utils.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +// maximum blocks per SM cap +#ifndef VLLM_LAUNCH_BLOCKS_CAP + #define VLLM_LAUNCH_BLOCKS_CAP 4 +#endif + +// Compile-time estimate of max threads per SM for launch bounds. +// Families: 1024, 1536, 2048 threads/SM. +#ifndef VLLM_MAX_THREADS_PER_SM + #ifdef __CUDA_ARCH__ + + /* 1024 thr/SM: Turing (sm_75) */ + #if (__CUDA_ARCH__ == 750) + #define VLLM_MAX_THREADS_PER_SM 1024 + + /* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89), + GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */ + #elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \ + (__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \ + (__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \ + (__CUDA_ARCH__ == 1210) + #define VLLM_MAX_THREADS_PER_SM 1536 + + /* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80), + Hopper (sm_90), Blackwell (sm_100/103) */ + #elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \ + (__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \ + (__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030) + #define VLLM_MAX_THREADS_PER_SM 2048 + + /* Fallback: use 2048 for unknown future CCs */ + #else + #define VLLM_MAX_THREADS_PER_SM 2048 + #endif + + #else + /* Host pass (no __CUDA_ARCH__): neutral default */ + #define VLLM_MAX_THREADS_PER_SM 2048 + #endif +#endif + +// compute the number of blocks per SM to request in __launch_bounds__ +#define VLLM_BLOCKS_DIV(VAL) (VLLM_MAX_THREADS_PER_SM / (VAL)) +#define VLLM_CLAMP_BLOCKS_PER_SM(VAL) \ + (((VAL) <= 0) \ + ? 1 \ + : (((VAL) < VLLM_LAUNCH_BLOCKS_CAP) ? (VAL) : VLLM_LAUNCH_BLOCKS_CAP)) +#define VLLM_BLOCKS_PER_SM(BLOCK_THREADS) \ + VLLM_CLAMP_BLOCKS_PER_SM(VLLM_BLOCKS_DIV(BLOCK_THREADS)) + +// runtime-time helper to compute blocks/SM +static inline int vllm_runtime_blocks_per_sm(int block_threads) { + int device = -1; + cudaGetDevice(&device); + int max_threads_per_sm = VLLM_MAX_THREADS_PER_SM; + cudaDeviceGetAttribute(&max_threads_per_sm, + cudaDevAttrMaxThreadsPerMultiProcessor, device); + int blocks = (block_threads > 0) ? (max_threads_per_sm / block_threads) : 1; + return VLLM_CLAMP_BLOCKS_PER_SM(blocks); +} diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f051eb0702228..6c3685f6f7cdc 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,15 +1,11 @@ #include "type_convert.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. 
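A hypothetical use of the new VLLM_DISPATCH_HALF_TYPES macro added to dispatch_utils.h above; the kernel name and surrounding variables are invented for illustration:

    // Dispatches only Half and BFloat16; any other dtype raises the usual
    // AT_DISPATCH "not implemented" error.
    VLLM_DISPATCH_HALF_TYPES(
        input.scalar_type(), "example_half_kernel", [&] {
          example_half_kernel<scalar_t><<<grid, block, 0, stream>>>(
              out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), n);
        });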
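And a minimal sketch of how a kernel might consume the new launch_bounds_utils.h helpers above; everything except the VLLM_* macros and vllm_runtime_blocks_per_sm is made up:

    #include <cuda_runtime.h>
    #include "launch_bounds_utils.h"

    static constexpr int kBlockThreads = 256;

    // Cap blocks/SM at compile time from the per-arch thread budget.
    __global__ void __launch_bounds__(kBlockThreads,
                                      VLLM_BLOCKS_PER_SM(kBlockThreads))
    example_scale_kernel(const float* in, float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = 2.0f * in[i];
    }

    void launch_example_scale(const float* in, float* out, int n,
                              cudaStream_t stream) {
      // Runtime counterpart, e.g. for sizing a persistent grid.
      int blocks_per_sm = vllm_runtime_blocks_per_sm(kBlockThreads);
      (void)blocks_per_sm;
      int grid = (n + kBlockThreads - 1) / kBlockThreads;
      example_scale_kernel<<<grid, kBlockThreads, 0, stream>>>(in, out, n);
    }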
@@ -30,7 +26,7 @@ __global__ void rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -85,7 +81,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -126,7 +122,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -140,6 +136,211 @@ fused_add_rms_norm_kernel( } } +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. + + _f16VecPN struct extends _f16Vec to add operations specifically required for + polynomial normalization (poly norm). + The original _f16Vec does not include the sum-of-powers computation or + in-place polynomial normalization logic. */ +template +struct alignas(16) _f16VecPN : _f16Vec { + using Base = _f16Vec; + using Converter = typename Base::Converter; + using T1 = typename Base::T1; + using T2 = typename Base::T2; + using Base::data; + + __device__ auto sum_pows() const { + float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f; + +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + float x2 = z.x * z.x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + float y2 = z.y * z.y; + float y4 = y2 * y2; + float y6 = y4 * y2; + + s2 += x2 + y2; + s4 += x4 + y4; + s6 += x6 + y6; + } + return std::make_tuple(s2, s4, s6); + } + + __device__ void poly_norm_inplace(const float w2_inv_std, + const float w1_inv_std2, + const float w0_inv_std3, const float bias) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + + float x2 = z.x * z.x; + float x3 = x2 * z.x; + z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias; + + float y2 = z.y * z.y; + float y3 = y2 * z.y; + z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias; + + auto out = Converter::convert(z); + data[i] = out.x; + data[i + 1] = out.y; + } + } +}; + +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16VecPN>); + static_assert(sizeof(_f16VecPN) == sizeof(scalar_t) * width); + + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast*>(input); + const int vec_hidden_size = hidden_size / width; + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + auto [x2, x4, x6] = temp.sum_pows(); + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + auto* __restrict__ out_v = reinterpret_cast<_f16VecPN*>(out); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias); + out_v[id] = temp; + } +} + +/* Generic poly_norm_kernel + The width field is not used here but necessary for other specializations. 
+ */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x3 = x2 * x; + + out[blockIdx.x * hidden_size + idx] = + (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + + s_bias); + } +} + } // namespace vllm void rms_norm(torch::Tensor& out, // [..., hidden_size] @@ -213,9 +414,58 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] wt_ptr % req_alignment_bytes == 0; bool offsets_are_multiple_of_vector_width = hidden_size % vector_width == 0 && input_stride % vector_width == 0; - if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && offsets_are_multiple_of_vector_width && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); } } + +#define LAUNCH_FUSED_POLY_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ + vllm::poly_norm_kernel<<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), bias.data_ptr(), epsilon, \ + hidden_size); \ + }); + +void poly_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [3] + torch::Tensor& bias, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.data_ptr() != input.data_ptr()); + + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. 
*/ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto out_ptr = reinterpret_cast(out.data_ptr()); + bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) { + LAUNCH_FUSED_POLY_NORM(8); + } else { + LAUNCH_FUSED_POLY_NORM(0); + } +} diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fd5849d9626c..0fc462194fcde 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -6,18 +6,14 @@ */ #include "type_convert.cuh" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. @@ -39,7 +35,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -100,7 +96,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -149,7 +145,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -245,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant( auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index c4ddbc142791f..d534e138d26d6 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -27,11 +27,12 @@ template + bool kHasZ_, bool kVarlen_, 
typename input_t_, typename weight_t_, typename state_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; + using state_t = state_t_; static constexpr int kNThreads = kNThreads_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; @@ -132,7 +133,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + + typename Ktraits::state_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + cache_index * params.ssm_states_batch_stride + dim_id * kNRows * params.ssm_states_dim_stride; @@ -261,7 +262,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; if (chunk == n_chunks - 1) { - ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y); + ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y); } } #pragma unroll @@ -310,7 +311,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } -template +template void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // processing 1 row. 
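The practical effect of threading state_t through the kernel traits above and the dispatch macro below is that the SSM state cache no longer has to share a dtype with the activations. A paraphrase of the accepted combinations (see the relaxed TORCH_CHECK on ssm_states further down):

    // input_t = at::Half      -> state_t in { at::Half,     float }
    // input_t = at::BFloat16  -> state_t in { at::BFloat16, float }
    // input_t = float         -> state_t = float
    // i.e. the state may be kept in float32 for extra precision while the
    // activations stay in half/bfloat16; any other pairing trips AT_ERROR.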
@@ -321,7 +322,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; + using Ktraits = Selective_Scan_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; @@ -341,59 +342,78 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { }); } -template +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { #ifndef USE_ROCM if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #else if (params.seqlen <= 256) { - selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream); } #endif } -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...) 
\ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ using weight_t = float; \ - __VA_ARGS__(); \ + if (STYPE == at::ScalarType::Half) { \ + using state_t = at::Half; \ + __VA_ARGS__(); \ + } else if (STYPE == at::ScalarType::Float) { \ + using state_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \ + } \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ using weight_t = float; \ - __VA_ARGS__(); \ + if (STYPE == at::ScalarType::BFloat16) { \ + using state_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (STYPE == at::ScalarType::Float) { \ + using state_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \ + } \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ using weight_t = float; \ + using state_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ } -template +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); void set_ssm_params_fwd(SSMParamsBase ¶ms, @@ -648,7 +668,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; - TORCH_CHECK(ssm_states.scalar_type() == input_type); + // ssm_states can now be either the same as input_type or float32 + auto state_type = ssm_states.scalar_type(); + TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); @@ -670,7 +692,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); }); } diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp new file mode 100644 index 0000000000000..1d06fc6b5b0a0 --- /dev/null +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -0,0 +1,156 @@ +#include +#include +#include + +// _dyn_quant_matmul_4bit is only available on AArch64. 
+#if defined(__aarch64__) + #include +#endif + +inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w, + int64_t group_size_eff, int64_t in_features, + int64_t out_features) { +#if defined(__aarch64__) + return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff, + in_features, out_features); +#else + TORCH_CHECK(false, + "dynamic 4-bit int MoE path requires AArch64 (ARM64); " + "_dyn_quant_matmul_4bit is unavailable on this architecture"); + return {}; +#endif +} + +enum ActivationKind : int64_t { + SwiGLU_Gu = 0, // act = SiLU(g) * u + SwiGLUOAI = 1, // act = SiLU(u) * g + SiLU = 2 // SiLU +}; + +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind) { + TORCH_CHECK(x.dim() == 2, "x must be 2D"); + TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2, + "topk tensors must be [T, K]"); + TORCH_CHECK( + w13_packed.size(0) == w2_packed.size(0), + "w13_packed and w2_packed must have same number of experts in dim 0"); + TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I"); + + const int64_t T = x.size(0); + const int64_t K = topk_ids.size(1); + const int64_t E = w13_packed.size(0); + const int64_t N = T * K; + + auto x_c = x.contiguous(); + auto ids_c = topk_ids.contiguous(); + auto gates_c = topk_weights.to(at::kFloat).contiguous(); + + // bucketing tokens -> experts + c10::SmallVector counts( + E, 0); // Small vector uses stack allocation + { + const auto* ids_ptr = ids_c.data_ptr(); + for (int64_t i = 0; i < N; ++i) { + const int64_t e_id = ids_ptr[i]; + TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range"); + counts[e_id]++; + } + } + c10::SmallVector offsets(E + 1, 0); // ( E +1 ) + for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e]; + + auto expert_tokens = at::empty({offsets[E]}, ids_c.options()); + auto expert_gates = at::empty({offsets[E]}, gates_c.options()); + { + c10::SmallVector cursor(E, 0); + const auto* ids_ptr = ids_c.data_ptr(); + const auto* gts_ptr = gates_c.data_ptr(); + auto* tok_ptr = expert_tokens.data_ptr(); + auto* gate_ptr = expert_gates.data_ptr(); + + for (int64_t t = 0; t < T; ++t) { + const int64_t base = t * K; + for (int64_t k = 0; k < K; ++k) { + const int64_t idx = base + k; + const int64_t e = ids_ptr[idx]; + const int64_t p = offsets[e] + (cursor[e]++); + tok_ptr[p] = t; + gate_ptr[p] = gts_ptr[idx]; + } + } + } + + const int64_t g_eff_13 = (group_size != -1) ? group_size : H; + const int64_t g_eff_2 = (group_size != -1) ? 
group_size : I; + + // Per-expert outputs filled in parallel + std::vector y_list(E); + y_list.resize(E); + + at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + for (int64_t e = e_begin; e < e_end; ++e) { + const int64_t te = counts[e]; + if (te == 0) { + y_list[e] = at::empty({0, H}, x_c.options()); + continue; + } + + const int64_t start = offsets[e]; + + auto sel_tokens = + expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + auto gates_e = + expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + + auto x_e = x_c.index_select(/*dim=*/0, sel_tokens); + + if (apply_router_weight_on_input) { + x_e = x_e.mul(gates_e.unsqueeze(1)); + } + + auto w13_e = w13_packed.select(/*dim=*/0, e); + auto w2_e = w2_packed.select(/*dim=*/0, e); + + // W13 + auto y13 = + mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2); + + auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I); + auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I); + + torch::Tensor act; + if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI + constexpr double kAlpha = 1.702; // GPT-OSS default + constexpr double kLimit = 7.0; // GPT-OSS default + auto gate_c = at::clamp_max(g_part, kLimit); + auto up_c = at::clamp(u_part, -kLimit, kLimit); + auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha))); + act = up_c.add(1.0).mul(glu); + } else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul() + act = at::silu(g_part).mul(u_part); + } + + // W2 + auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H); + + if (!apply_router_weight_on_input) { + y = y.mul(gates_e.unsqueeze(1)); + } + + // Store per-expert result + y_list[e] = y; + } + }); + + // Concatenate all expert outputs to match expert_tokens order + auto Y_all = at::cat(y_list, /*dim=*/0); + auto out = at::zeros({T, H}, x.options()); + out = + at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all); + + return out; +} diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu new file mode 100644 index 0000000000000..c93f9d54d780c --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cu @@ -0,0 +1,770 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include +namespace cg = cooperative_groups; + +namespace vllm { +namespace moe { + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != 
+ (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + 
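+    // Shared-memory layout: a value region of num_of_warp * WARP_SIZE
+    // elements of T (its byte size rounded up to a multiple of 256),
+    // followed by a matching index region; each warp addresses its own
+    // WARP_SIZE-element slice of both. calc_smem_size_for_block_wide()
+    // above sizes the dynamic allocation to cover exactly this.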
idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ inline T neg_inf() { + // cuda::std::numeric_limits::infinity() returns `0` for [T=bf16 or fp16] + // so we need to cast from fp32 + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + +template +__device__ inline bool is_finite(const T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + +template +__device__ void topk_with_k2(T* output, T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = neg_inf(); + T second_largest = neg_inf(); + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } + } + } else { + for (int i 
= lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, + T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + int64_t const topk_group, int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = neg_inf(); + T topk_group_value = neg_inf(); + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + // The check is necessary to avoid abnormal input + if (lane_id < n_group && is_finite(group_scores[lane_id])) { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, 
cg::greater()); + if (value == topk_group_value) { + value = neg_inf(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = + __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, neg_inf()); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = topk_group_value != neg_inf(); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = (i < num_experts_per_group) && + is_finite(scores_with_bias[offset + i]) + ? scores_with_bias[offset + i] + : neg_inf(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = + i < topk + ? scores[s_topk_idx[i]] + : cuda_cast(0.0f); // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + topk_indices[i] = i; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. 
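+    // Summary of the two branches above: when a valid top-k group was found,
+    // expert i receives
+    //   scores[idx_i] / (sum of selected scores + 1e-20) * routed_scaling_factor
+    // if renormalize is set, and scores[idx_i] * routed_scaling_factor
+    // otherwise; in the fallback branch the first `topk` experts are selected
+    // with a uniform weight of 1.0f / topk.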
+ } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, + IdxT* topk_indices, T* scores_with_bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, bool enable_pdl = false, + cudaStream_t const stream = 0) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ + T * scores_with_bias, int64_t const num_tokens, \ + int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, bool const renormalize, \ + double const routed_scaling_factor, bool enable_pdl, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, int32_t); +INSTANTIATE_NOAUX_TC(half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor) { + auto data_type = scores_with_bias.scalar_type(); + auto input_size = scores_with_bias.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + + torch::Tensor group_scores = torch::empty( + {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + 
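+  // Indices are materialized as int32, matching the IdxT template argument
+  // of the invokeNoAuxTc calls below.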
torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + + switch (data_type) { + case torch::kFloat16: + // Handle Float16 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kFloat32: + // Handle Float32 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kBFloat16: + // Handle BFloat16 + vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), + num_tokens, num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + default: + // Handle other data types + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } + return {topk_values, topk_indices}; +} diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 698deb107cc06..be5b68cc53e6f 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -17,25 +17,30 @@ FILE_HEAD = """ namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] @@ -58,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 8bbcf5a673fd3..629348bf88764 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel( for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); @@ -95,12 +98,15 @@ template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel) { + size_t numel, int32_t num_experts) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } @@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sort_kernel<<>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + cumsum_buffer.data_ptr(), topk_ids.numel(), num_experts); } }); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 661730c96867e..92fc280b362b9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 99c52ef17d08b..eca021f1c1863 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -20,17 +20,8 @@ #include #include #include "../cuda_compat.h" - -#ifndef USE_ROCM - #include - #include - #include - using AddOp = cuda::std::plus; -#else - #include - #include - using AddOp = cub::Sum; -#endif +#include "../cub_helpers.h" +#include "../core/batch_invariant.hpp" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -79,7 +70,7 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -94,7 +85,7 @@ __launch_bounds__(TPB) __global__ threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp()); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); if (threadIdx.x == 0) { @@ -415,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); @@ -573,7 +565,7 @@ void topk_softmax( stream); } else { - assert(topk_indices.scalar_type() == at::ScalarType::Int64); + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7e49f68f62438..8f33d6cd666fa 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "output_tensor) -> ()"); m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + // Apply grouped topk routing to select experts. 
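+  // scores and scores_with_bias are [num_tokens, num_experts]; the op returns
+  // (topk_values, topk_indices) of shape [num_tokens, topk], with values in
+  // the score dtype and indices as int32. The current kernels require
+  // num_experts % n_group == 0, n_group <= 32 and topk <= 32.
+  // Illustrative call from Python (the op namespace depends on
+  // TORCH_EXTENSION_NAME, e.g. torch.ops._moe_C in a typical build):
+  //   values, indices = torch.ops._moe_C.grouped_topk(
+  //       scores, scores_with_bias, n_group, topk_group, topk,
+  //       renormalize, routed_scaling_factor)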
+ m.def( + "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "topk_group, int topk, bool renormalize, float " + "routed_scaling_factor) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 86fe848e2fd5a..2a9214e7fb03d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,11 +92,19 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& bias, double epsilon); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, const torch::Tensor& repetition_penalties); +void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + torch::Tensor& values, int64_t numRows, int64_t stride0, + int64_t stride1); + void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); @@ -119,17 +127,24 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); -void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - std::optional key, - int64_t head_size, torch::Tensor& cos_sin_cache, - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets); - void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +#ifndef USE_ROCM +void silu_and_mul_nvfp4_quant(torch::Tensor& out, + torch::Tensor& output_block_scale, + torch::Tensor& input, + torch::Tensor& input_global_scale); +#endif +void persistent_masked_m_silu_mul_quant( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + bool use_ue8m0); + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -318,6 +333,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind); + using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, @@ -337,6 +358,8 @@ std::tuple allocate_shared_buffer_and_handle( int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); +torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace); + #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -346,4 +369,4 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif diff --git a/csrc/pos_encoding_kernels.cu 
b/csrc/pos_encoding_kernels.cu index 266f2a0667a24..b5645b33b9073 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel( token_idx, query_stride, key_stride, head_stride); } -template -__global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ key, // nullptr or - // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int64_t head_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = - cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride, head_stride); -} - } // namespace vllm void rotary_embedding( @@ -211,96 +182,3 @@ void rotary_embedding( } }); } - -/* -Batched version of rotary embedding, pack multiple LoRAs together -and process in batched manner. -*/ -void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - std::optional - key, // null or - // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] -) { - // num_tokens = batch_size * seq_len - int64_t num_tokens = cos_sin_cache_offsets.size(0); - TORCH_CHECK( - positions.size(0) == num_tokens || positions.numel() == num_tokens, - "positions must have the same num_tokens or batch_size as " - "cos_sin_cache_offsets"); - - int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key - TORCH_CHECK( - positions_ndim == 1 || positions_ndim == 2, - "positions must have shape [num_tokens] or [batch_size, seq_len]"); - if (positions_ndim == 1) { - TORCH_CHECK(query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)), - "query, key and positions must have the same number of tokens"); - } - if (positions_ndim == 2) { - TORCH_CHECK( - query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)) && - query.size(1) == positions.size(1) && - (!key.has_value() || key->size(1) == positions.size(1)), - "query, key and positions must have the same batch_size and seq_len"); - } - - // Make sure head_size is valid for query and key - int query_hidden_size = query.numel() / num_tokens; - int 
key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); - - // Make sure query and key have concistent number of heads - int num_heads = query_hidden_size / head_size; - int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; - TORCH_CHECK(num_heads % num_kv_heads == 0); - - int seq_dim_idx = positions_ndim - 1; - int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; - // Determine head stride: for [*, heads, head_size] use stride of last dim; - // for flat [*, heads*head_size], heads blocks are contiguous of size - // head_size - int query_ndim = query.dim(); - int64_t head_stride = - (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } - }); -} diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8bc2b9bff3d5a..6fcd246f63c50 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -7,8 +7,33 @@ #include "../cuda_compat.h" #include "dispatch_utils.h" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" +#include + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16_raw __nv_bfloat16_raw; + #if defined(HIP_FP8_TYPE_OCP) +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; + #else +// ROCm 6.2 fallback: only *_fnuz types exist +typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3; + #endif +#endif + +#include "core/registration.h" namespace vllm { template @@ -87,6 +112,429 @@ __global__ void act_and_mul_quant_kernel( } } } + +__device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + expf(-x))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) { +#ifndef USE_ROCM + return make_bfloat162(__float2bfloat16_rn(silu(x.x)), + __float2bfloat16_rn(silu(x.y))); +#else + return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y))); +#endif +} + +#ifndef USE_ROCM +__device__ __forceinline__ float warp_max(float v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +__device__ __forceinline__ 
__nv_bfloat16 warp_max(__nv_bfloat16 v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} +#endif + +template +__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + auto smem_ptr = reinterpret_cast(_smem_ptr); + auto glob_ptr = reinterpret_cast(_glob_ptr); + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +#else + _smem_ptr[0] = _glob_ptr[0]; +#endif +} + +__device__ __forceinline__ void cp_async_fence() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#else +#endif +} + +template +__device__ __forceinline__ void cp_async_wait() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else +#endif +} + +template <> +__device__ __forceinline__ void cp_async_wait<0>() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#else +#endif +} + +__device__ __forceinline__ float clip(float v, float mmin, float mmax) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + return fminf(mmax, fmaxf(v, mmin)); +#else +#endif +} + +__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v, + __nv_bfloat16 mmin, + __nv_bfloat16 mmax) { + return __hmin(mmax, __hmax(v, mmin)); +} + +__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v, + __nv_bfloat162 mmin, + __nv_bfloat162 mmax) { + return __hmin2(mmax, __hmax2(v, mmin)); +} + +// We use the following values for fp8 min/max: +// __nv_fp8_e4m3 = (-448, +448) +// __nv_fp8_e4m3uz = (-240.0, +240.0) +// It is currently assumed that only +template +constexpr __nv_bfloat16 get_fp8_max() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264}); + } +} + +template +constexpr __nv_bfloat16 get_fp8_min() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); + } +} + +template +__device__ __forceinline__ int warp_expert_search( + int idx, int n, const Idx_t* __restrict__ input, Idx_t val) { + const Idx_t* input_ptr = input + idx; + int base_offset = 0; + + for (;;) { + bool move_on = (idx < n && *input_ptr <= val); + + unsigned mask = __ballot_sync(0xffffffff, move_on); + + if (mask != 0xffffffffu) { + int last_lane = 31 - __clz(mask); + return base_offset + last_lane; + } + + input_ptr += 32; + base_offset += 32; + idx += 32; + } +} + +template +__device__ __forceinline__ void token_bounds(int32_t n_tokens, + int32_t worker_id, + int32_t& n_tokens_lower, + int32_t& n_tokens_upper) { + if (n_tokens < num_parallel_tokens && worker_id < n_tokens) { + if (worker_id >= num_parallel_tokens) return; + n_tokens_lower = worker_id; + n_tokens_upper = worker_id + 1; + } else { + int32_t chunk_size = n_tokens / num_parallel_tokens; + int32_t residual = n_tokens - chunk_size * num_parallel_tokens; + auto calc_id = [&](int32_t id) { + if (id < residual) + return min(n_tokens, id * (chunk_size + 1)); + else + return 
min(n_tokens, id * chunk_size + residual); + }; + n_tokens_lower = calc_id(worker_id); + n_tokens_upper = calc_id(worker_id + 1); + } +} + +template +__global__ void silu_mul_fp8_quant_deep_gemm_kernel( + const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, + float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, + // sizes + Idx_t E, Idx_t T, Idx_t H, + // strides (in elements) + Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, + Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, + Idx_t stride_ys_g, Idx_t stride_counts_e) { +#ifndef USE_ROCM + static constexpr int NUM_WARPS = THREADS / WARP_SIZE; + + static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8; + static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE; + + static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4; + static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES; + + extern __shared__ __align__(16) __int128_t smem_128[]; + + int* s_expert_offsets = + reinterpret_cast(smem_128 + (SMEM_SIZE_BYTES_Y / 16)); + + static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); + static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); + // We assign EPS with it's 16-bit unsigned counterpart to allow constexpr. + static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); + int tid = threadIdx.x; + int warp_id = tid >> 5; + int lane_id = tid & 0x1f; + + int running_sum{}; + if (!warp_id) { + for (int i = 0; i < E; i += WARP_SIZE) { + bool valid = (i + threadIdx.x) < E; + int value = + (valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) + + (!lane_id ? running_sum : 0); + + for (int offset = 1; offset < 32; offset *= 2) { + int n = __shfl_up_sync(0xFFFFFFFFu, value, offset); + if (lane_id >= offset) value += n; + } + + if (valid) { + s_expert_offsets[i + threadIdx.x + 1] = value; + } + + running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1); + } + + if (!lane_id) { + s_expert_offsets[0] = 0; + } + } + + __syncthreads(); + + int32_t total_tokens = s_expert_offsets[E]; + + const int warp_position_yq = warp_id * (H / NUM_WARPS); + const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS)); + + // A single block will handle tokens_per_block tokens. + // Each block i iterates over tokens of a slice of n_tokens = + // expert_counts[i], with the size of chunk being + // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of + // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. + + // Each warp will get space to store its hidden dim for gate and up. 
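+  // Per stage that is 2 * GROUP_SIZE bf16 values (gate followed by up), i.e.
+  // (2 * 128 / 8) = 32 __int128_t vectors (the hard-coded 128 is GROUP_SIZE),
+  // and each warp owns NUM_STAGES consecutive stages of smem_128.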
+ __int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES); + __int128_t* smem_load_ptr = s_hidden_load + lane_id; + + const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max); + + int32_t compute_pipeline_offset_64 = 0; + int32_t load_stage_offset{}; + const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f); + + __int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) + + warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) + + lane_id; + __int64_t* s_gate64_ptr = smem_compute_ptr; + __int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4; + + int tokens_lower, tokens_upper; + + token_bounds(total_tokens, blockIdx.x, tokens_lower, + tokens_upper); + + Idx_t expert_id{}, expert_offset{}, next_expert_offset{}; + int token_id = tokens_lower; + int32_t t_load{}; + + if (token_id < tokens_upper) { + expert_id = warp_expert_search(lane_id, E, s_expert_offsets, token_id); + expert_offset = s_expert_offsets[expert_id]; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } else { + // This thread block has no work to do. + return; + } + + int t_load_bound = H / (GROUP_SIZE * NUM_WARPS); + + Idx_t base_i = ((expert_id * stride_i_e) / 8) + + (token_id - expert_offset) * stride_i_t / 8; + const Idx_t gate_warp_offset = + warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111); + + const __int128_t* input_128_ptr = + reinterpret_cast(_input) + gate_warp_offset + + ((lane_id < 16) ? 0 : ((H * stride_i_h) / 8)); + __int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto token_offset = token_id - expert_offset; + + auto load_and_advance_y_pred = [&] { + if (t_load < t_load_bound) { + // Here we are simply continuing to load data + // from the current token. + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; + + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; + ++t_load; + } else if (token_id + 1 < tokens_upper) { + // We loaded everything from the current token, let's move on + // to the next one, and we checked that we have more tokens to load. + ++token_id; + t_load = 0; + if (token_id >= next_expert_offset) { + // We need to find the next expert. + do { + // This is a loop because it's possible + // that some experts are assigned 0 tokens. + // NOTE: We are guaranteed that there's at least + // one more token left so we don't have to check for + // expert_id bounds. + ++expert_id; + // This skips 1 memory read. + expert_offset = next_expert_offset; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } while (next_expert_offset == expert_offset); + + base_i = expert_id * (stride_i_e / 8); + token_offset = 0; + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + } else { + // We remain within the same expert, so just + // move by H/4 __int128_t (2 * H/8). + base_i += stride_yq_t / 4; + token_offset++; + } + + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; + + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; + ++t_load; + } + // We fence even if there is nothing to load to simplify pipelining. 
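+    // Committing a (possibly empty) cp.async group on every call keeps the
+    // number of outstanding groups identical on each iteration, so the
+    // cp_async_wait in the compute loop never waits on a group that was
+    // skipped.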
+ cp_async_fence(); + }; + + // We need to warm-up the pipeline. + #pragma unroll + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_and_advance_y_pred(); + } + + __nv_fp8x4_e4m3* y_q_base_ptr = + reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; + auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g; + + for (auto j = tokens_lower; j < tokens_upper; j++) { + const Idx_t base_ys = expert_id * stride_ys_e; + auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t; + __nv_fp8x4_e4m3* y_q_ptr = + y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t + + warp_position_yq * stride_yq_h) / + 4; + const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS); + + for (int i = 0; i < COMPUTE_LIMIT; i++) { + cp_async_wait(); + __syncthreads(); + load_and_advance_y_pred(); + + __int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64; + __int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64; + + // COMPUTE_STAGE_SIZE/MOD must also be constexpr! + compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE; + compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD; + + __int64_t gate64 = *gate64_ptr; + __int64_t up64 = *up64_ptr; + + // Compute + __nv_bfloat162 res[2]; + __nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64); + __nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64); + + #pragma unroll + for (int32_t k = 0; k < 2; ++k) { + __nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k])); + res[k] = __hmul2(gate, s_up_comp[k]); + } + + auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1])); + + _y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS); + + __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv); + + if constexpr (USE_UE8M0) { + y_s = hexp2(hceil(hlog2(y_s))); + } + + __nv_bfloat16 inv_y = __hdiv(one_bf16, y_s); + + auto y_s2 = make_bfloat162(inv_y, inv_y); + + #pragma unroll + for (int32_t k = 0; k < 2; ++k) { + res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min), + __bfloat162bfloat162(fp8_max)); + } + + *y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]); + y_q_ptr += WARP_SIZE * stride_yq_h; + + if (!lane_id) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_g; + } + } + } +#endif +} + } // namespace vllm // Launch activation, gating, and quantize kernel. @@ -119,3 +567,86 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] TORCH_CHECK(input.size(-1) % 2 == 0); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } + +void persistent_masked_m_silu_mul_quant( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& tokens_per_expert, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + bool use_ue8m0) { +#ifndef USE_ROCM + + // This kernel currently only supports H % 128 == 0 and assumes a + // fixed GROUP_SIZE of 128. 
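+  // input.size(-1) is 2*H (gate and up concatenated), so the `% 256 == 0`
+  // check below is exactly the H % 128 == 0 requirement stated above.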
+ TORCH_CHECK(input.dtype() == torch::kBFloat16); + TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || + y_q.dtype() == torch::kFloat8_e4m3fnuz); + TORCH_CHECK(y_s.dtype() == torch::kFloat32); + TORCH_CHECK(input.size(-1) % 256 == 0); + + using Idx_t = int64_t; + + Idx_t E = input.size(0); + Idx_t T = input.size(1); + Idx_t H = input.size(2) / 2; + Idx_t stride_i_e = input.stride(0); + Idx_t stride_i_t = input.stride(1); + Idx_t stride_i_h = input.stride(2); + Idx_t stride_yq_e = y_q.stride(0); + Idx_t stride_yq_t = y_q.stride(1); + Idx_t stride_yq_h = y_q.stride(2); + Idx_t stride_ys_e = y_s.stride(0); + Idx_t stride_ys_t = y_s.stride(1); + Idx_t stride_ys_g = y_s.stride(2); + + Idx_t stride_counts_e = tokens_per_expert.stride(0); + + static constexpr int GROUP_SIZE = 128; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ + static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ + int sms = SILU_V2_BLOCK_COUNT; \ + static constexpr int max_shared_mem_bytes = \ + GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \ + dim3 grid(sms), block(THREAD_COUNT); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + VLLM_DISPATCH_FP8_TYPES( \ + y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ + BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \ + USE_UE8M0, GROUP_SIZE, STAGES> \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(tokens_per_expert.data_ptr()), E, \ + T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ + stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \ + stride_ys_g, stride_counts_e); \ + }); + + static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + + if (!use_ue8m0) { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); + } + } else { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); + } + } + +#endif +} diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 0000000000000..2d1568b08651c --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,494 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +#include + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// 
------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using 
TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // safely cast group_size to int + TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits::max(), + "group_size out of supported range for int: ", group_size); + int const group_size_int = static_cast(group_size); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid hardcode the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size_int); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D 
= + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size_int}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return 
Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +/* + GPU-accelerated implementation of cutlass::unified_encode_int4b. + Constructs a lookup table in constant memory to map 8 bits + (two 4-bit values) at a time. Assumes memory is contiguous + and pointers are 16-byte aligned. 
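+  Each byte (two nibbles) is remapped through the table built by
+  upload_lut(): nibble values 1..7 map to (8 - v), while 0 and 8..15
+  are kept unchanged.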
+*/ +__constant__ uint8_t kNibbleLUT[256]; + +__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, + size_t nbytes) { + constexpr size_t V = sizeof(uint4); // 16 bytes + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t nthreads = size_t(gridDim.x) * blockDim.x; + const size_t nvec = nbytes / V; + + // 1-D grid-stride loop over 16-byte chunks + for (size_t vec = tid; vec < nvec; vec += nthreads) { + uint4 v = reinterpret_cast(in)[vec]; + uint8_t* b = reinterpret_cast(&v); +#pragma unroll + for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; + reinterpret_cast(out)[vec] = v; + } +} + +static bool upload_lut() { + std::array lut{}; + auto map_nib = [](uint8_t v) -> uint8_t { + // 1..7 -> (8 - v); keep 0 and 8..15 + return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v); + }; + for (int b = 0; b < 256; ++b) { + uint8_t lo = b & 0xF; + uint8_t hi = (b >> 4) & 0xF; + lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); + } + cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), + /*offset=*/0, cudaMemcpyHostToDevice); + + return (e == cudaSuccess); +} + +static bool unified_encode_int4b(cutlass::int4b_t const* in, + cutlass::int4b_t* out, size_t num_int4_elems) { + // Build/upload LUT + if (!upload_lut()) return false; + + static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, + "int4 storage must be 1 byte"); + const size_t nbytes = num_int4_elems >> 1; + + auto* in_bytes = reinterpret_cast(in); + auto* out_bytes = reinterpret_cast(out); + + // kernel launch params + constexpr int block = 256; + const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors + int grid = int((nvec + block - 1) / block); + if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel + + unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); + cudaError_t err = cudaGetLastError(); + return (err == cudaSuccess); +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks"); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + bool ok = + vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + TORCH_CHECK(ok, "unified_encode_int4b failed"); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh deleted file mode 100644 index e089c3d4be2cc..0000000000000 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ /dev/null @@ -1,194 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -#include "cute/tensor.hpp" -#include 
"cutlass/tensor_ref.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - -#include "cutlass_gemm_caller.cuh" - -namespace vllm { - -using namespace cute; - -template > -struct cutlass_3x_gemm_fp8_blockwise { - using GroupSizeM = Int; - using GroupSizeN = Int; - using GroupSizeK = Int; - using TileSizeM = Int; - - static_assert(TileSizeM_ % GroupSizeM_ == 0, - "TileSizeM must be a multiple of GroupSizeM"); - - using ElementAB = cutlass::float_e4m3_t; - - using ElementA = ElementAB; - using LayoutA = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - - using ElementB = ElementAB; - using LayoutB = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - - using ElementD = OutType; - using StrideD = Stride, Int<0>>; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - - using ElementC = void; - using StrideC = StrideD; - static constexpr int AlignmentC = AlignmentD; - - using ElementAccumulator = float; - using ElementBlockScale = float; - using ElementCompute = float; - using ArchTag = cutlass::arch::Sm90; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = Shape; - - using KernelSchedule = cutlass::gemm:: - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - GroupSizeM_>; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; - - using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< - cutlass::epilogue::fusion::Sm90AccFetch>; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, - ElementD, StrideD, AlignmentD, EpilogueSchedule, - StoreEpilogueCompute>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, - LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - SchedulerType>>; - - struct GemmKernel : public KernelType {}; - - using StrideA = typename GemmKernel::StrideA; - using StrideB = typename GemmKernel::StrideB; -}; - -template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - using GemmKernel = typename Gemm::GemmKernel; - - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - auto prob_shape = c3x::get_problem_shape(a, b); - int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), - k = get<2>(prob_shape); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = 
out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); - - // Check is the t is contiguous and is 1D or 2D with one of the dimensions - // being 1 (i.e. a row or column vector) - auto is_contiguous_vector = [](const torch::Tensor& t) { - auto t_sizes = t.sizes(); - return t.is_contiguous() && - (t.dim() == 1 || - (t.dim() == 2 && - *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); - }; - - // TODO(lucas): lets clean-up the kernel so that we pass in Strides so - // we don't have to deal with enforcing implicit layouts - TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); - TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), - "a_scales must be M major"); - TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); - TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), - "b_scales must be K major"); - typename GemmKernel::MainloopArguments mainloop_args{ - a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - {}, c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::TileSchedulerArguments scheduler; - - static constexpr bool UsesStreamKScheduler = - cute::is_same_v; - - if constexpr (UsesStreamKScheduler) { - using DecompositionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using ReductionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::ReductionMode; - - scheduler.decomposition_mode = DecompositionMode::StreamK; - scheduler.reduction_mode = ReductionMode::Nondeterministic; - } - - c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, - epilogue_args, scheduler); -} - -template -void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto k = a.size(1); - auto n = b.size(1); - - if (k > 3 * n) { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } else { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } -} - -} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu new file mode 100644 index 0000000000000..7539f836ecf37 --- /dev/null +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" +#include "launch_bounds_utils.h" +#include "nvfp4_utils.cuh" + +namespace vllm { + +// silu in float32 +__device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + __expf(-x))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +template +__inline__ __device__ PackedVec compute_silu_mul(PackedVec& vec, + PackedVec& vec2) { + PackedVec result; + using packed_type = typename TypeConverter::Type; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + // silu_mul in float32 + if constexpr (std::is_same_v) { + float2 silu_vec = silu2(__half22float2(vec.elts[i])); + result.elts[i] = + __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i]))); + } else { + float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i])); + result.elts[i] = __float22bfloat162_rn( + __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i]))); + } + } + return result; +} + +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, + uint32_t* SFout) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + PackedVec in_vec2 = reinterpret_cast(in)[inOffset2]; + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. 
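+      // (Each thread writes one uint32_t holding 8 packed e2m1 values, so a
+      // row of `out` spans numCols / CVT_FP4_ELTS_PER_THREAD words.)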
+ int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + auto& out_pos = out[outOffset]; + + // Compute silu and mul + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numCols, SFout); + + out_pos = cvt_warp_fp16_to_fp4(out_silu_mul, SFScaleVal, + sf_out); + } + } +} + +} // namespace vllm + +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::silu_mul_cvt_fp16_to_fp4<<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); +} diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 03db5cc196d59..5b007e5ea3283 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -1,3 +1,21 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "core/registration.h" + #include #include @@ -402,3 +420,7 @@ void cutlass_fp4_group_mm( "12.8 or above."); #endif } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm); +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 190d66f318a83..6d385e0dd94e7 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -1,247 +1,43 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include +#include +#include + #include #include -#include #include +#include "dispatch_utils.h" -template -struct TypeConverter { - using Type = half2; -}; // keep for generality +#include "nvfp4_utils.cuh" +#include "launch_bounds_utils.h" -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). 
- int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. 
- uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -299,8 +95,8 @@ cvt_fp16_to_fp4( &input_offset_by_experts[chunk_start + 12])); local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - // Check against the 16 loaded offsets - #pragma unroll +// Check against the 16 loaded offsets +#pragma unroll for (int i = 0; i < 16; i++) { if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { rowIdx_in_expert = rowIdx - local_offsets[i]; @@ -330,21 +126,15 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -425,7 +215,6 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } template @@ -445,8 +234,9 @@ void quant_impl(void* output, void* output_scale, void* input, int const workSizePerRow = k / ELTS_PER_THREAD; int const totalWorkSize = m_topk * workSizePerRow; dim3 block(std::min(workSizePerRow, 512)); - // Get number of blocks per SM (assume we can fully utilize the SM). 
- int const numBlocksPerSM = 2048 / block.x; + // Get number of blocks per SM + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM)); while (grid.x <= multiProcessorCount && block.x > 64) { @@ -501,6 +291,8 @@ void quant_impl(void* output, void* output_scale, void* input, } } +} // namespace vllm + /*Quantization entry for fp4 experts quantization*/ #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ @@ -560,23 +352,17 @@ void scaled_fp4_experts_quant_sm100a( // 4 means 4 fp8 values are packed into one int32 TORCH_CHECK(output_scale.size(1) * 4 == padded_k); - auto in_dtype = input.dtype(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); - if (in_dtype == at::ScalarType::Half) { - quant_impl(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, - n_experts, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, - k, n_experts, stream); - } else { - TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); - } + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 1b61bd4519fc3..c2b39e5438805 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& output_scale_offset_by_experts); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, + torch::Tensor& output_sf, + torch::Tensor& input, + torch::Tensor& input_sf); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -54,3 +62,13 @@ void scaled_fp4_experts_quant( TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, + torch::Tensor& input, torch::Tensor& input_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 4e080de151648..5575ee8e4197e 100644 --- 
a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -23,245 +23,19 @@ #include #include +#include "dispatch_utils.h" #include "cuda_utils.h" +#include "launch_bounds_utils.h" +#include "nvfp4_utils.cuh" -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. 
- int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -293,7 +67,6 @@ cvt_fp16_to_fp4( cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } } -#endif } template @@ -303,8 +76,9 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, // Grid, Block size. // Each thread converts 8 values. dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - // Get number of blocks per SM (assume we can fully utilize the SM). - int const numBlocksPerSM = 2048 / block.x; + // Get number of blocks per SM + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); // Launch the cvt kernel. @@ -332,6 +106,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, int multiProcessorCount, cudaStream_t stream); +} // namespace vllm + void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, @@ -340,6 +116,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, int32_t n = input.size(1); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -353,24 +132,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, // We don't support e8m0 scales at this moment. bool useUE8M0 = false; - switch (input.scalar_type()) { - case torch::kHalf: { - auto input_ptr = reinterpret_cast(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - case torch::kBFloat16: { - auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - default: { - std::cerr << "Observing: " << input.scalar_type() - << " for the input datatype which is invalid"; - throw std::runtime_error( - "Unsupported input data type for quantize_to_fp4."); - } - } + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, + sf_out, useUE8M0, multiProcessorCount, stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh new file mode 100644 index 0000000000000..48e4959de9793 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +// Convert PyTorch cpp type to CUDA type +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. 
+ // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } + return nullptr; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. 
+ uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +} + +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 3f188872d80d3..2d2fd771205c7 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -8,11 +8,7 @@ #include "quantization/utils.cuh" #include "quant_conversions.cuh" -#ifndef USE_ROCM - #include -#else - #include -#endif +#include "../../cub_helpers.h" namespace vllm { @@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { @@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index 4e6118e52e8d6..2b1eb1d568e4e 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -6,7 +6,7 @@ #include "quantization/vectorization.cuh" // TODO(luka/varun):refactor common.cuh to use this file instead -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" namespace vllm { diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 7576e0548abe9..42d3b456096ee 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -17,28 +17,32 @@ FILE_HEAD = """ namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " 
+ "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] -THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), - (128, 64, 128)] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] # group_blocks: @@ -59,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: @@ -93,8 +98,7 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and \ - group_blocks == 4: + if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: # HQQ (is_zp_float = true) only supports # 4bit quantization and fp16 is_zp_float_list.append(True) diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu new file mode 100644 index 0000000000000..5369d409f9b21 --- /dev/null +++ b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -0,0 +1,817 @@ +// clang-format off +// Adapted from: https://github.com/meta-pytorch/applied-ai/blob/main/kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu + +/*********** +Copyright 2024 Meta + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+***********/ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/registration.h" +#include "dispatch_utils.h" + +namespace hadacore { + +#ifndef __CUDACC__ +#define __launch_bounds__(x,y) +#endif + +#define MAX_WARPS_PER_SM 48 + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +using b16 = uint16_t; +using b32 = uint32_t; + +constexpr int launch_configs_big[7][3] = { + // default + {2, 1, 24}, + {2, 2, 16}, + {2, 4, 8}, + {2, 8, 4}, + {2, 16, 3}, + {4, 16, 2}, + {8, 16, 1} + // // extra coalescing + // {2, 1, 24}, + // {2, 2, 16}, + // {2, 4, 8}, + // {2, 8, 4}, + // {4, 8, 3}, + // {8, 8, 2}, + // {16, 8, 1} + // // less coalescing + // {2, 1, 24}, + // {2, 2, 16}, + // {2, 4, 8}, + // {2, 8, 4}, + // {1, 32, 1}, + // {2, 32, 1}, + // {4, 32, 1} +}; + +// a 4x2, b 2x2, c 2x2 +template +__device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){ + static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16); + // d, a, b, c + b32 zero = 0; + if constexpr(dtype == torch::ScalarType::Half) { + asm ( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t" + : "=r"(c0), "=r"(c1) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero) + ); + } else { + b32 temp0, temp1, temp2, temp3; + asm ( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero), "r"(zero), "r"(zero) + ); + asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + } +} + +// a 4x2, b 4x2, c 4x2 +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){ + mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b0, b1, c0, c1); + mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b2, b3, c2, c3); +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) { + asm ( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) : "r"(a0) + ); +} + +#define p_p(i) ((val_1p[i] & 0x0000FFFF) | val_1p[i] << 16) +#define p_n(i) ((val_1p[i] & 0x0000FFFF) | val_1n[i] << 16) +#define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16) +#define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16) + +template +__global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm) +// a is column major, b is row major +hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { + static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + + b32 b_frag_all[num_chunks][4]; // for all chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads) + + int64_t blockid = blockIdx.x * warps_per_block + threadIdx.x / 32; + int64_t threadid = threadIdx.x % 32; + extern __shared__ b32 bfrag_arr[]; // num_chunks * warps_per_block * 128 + int64_t real_num_chunks = ((blockid + 1) * num_chunks) > total_num_chunks ? 
(total_num_chunks - (blockid * num_chunks)) : num_chunks; + int64_t diff_num_chunks = real_num_chunks - num_chunks; + + b32* a_start_ptr = (b32*) (a + blockid * num_chunks * 256); // offset a to where this warp starts + b32* out_start_ptr = (b32*) (out + blockid * num_chunks * 256); + b32* a_ptr = a_start_ptr + threadid * 4; + b32* b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128 + threadid * 4; + + #if (__CUDA_ARCH__ < 900) // SM80, SM89 + uint64_t cache_policy; + asm volatile( + "createpolicy.fractional.L2::evict_first.b64 %0, 1.0;\n" + : "=l"(cache_policy) + ); + #endif + + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + size_t shared_ptr = __cvta_generic_to_shared(b_frag_ptr); + #if (__CUDA_ARCH__ >= 900) // SM90 + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;\n" + "cp.async.commit_group;\n" + :: "l"(shared_ptr), "l"(a_ptr) + ); + #else // SM80, SM89 + asm volatile( + "cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2;\n" + "cp.async.commit_group;\n" + :: "l"(shared_ptr), "l"(a_ptr), "l"(cache_policy) + ); + #endif + + a_ptr += 128; + b_frag_ptr += 128; + } + + // generate hadamard 16x16 (up to 2 of them) + constexpr b16 fp16_1p[4] = {0b0011100110101000, 0b0011100000000000, 0b0011010110101000, 0b0011010000000000}; + constexpr b16 fp16_1n[4] = {0b1011100110101000, 0b1011100000000000, 0b1011010110101000, 0b1011010000000000}; + constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000}; + constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000}; + + #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) + #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? 
(fp16_1n[i]) : (bf16_1n[i])) + constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)}; + constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)}; + + constexpr b32 p_p[4] = {p_p(0), p_p(1), p_p(2), p_p(3)}; + constexpr b32 p_n[4] = {p_n(0), p_n(1), p_n(2), p_n(3)}; + constexpr b32 n_p[4] = {n_p(0), n_p(1), n_p(2), n_p(3)}; + constexpr b32 n_n[4] = {n_n(0), n_n(1), n_n(2), n_n(3)}; + const b32 had_16_p1[4][4] = { + { + 0b10001000010001000010001000010001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10001000010001000010001000010001 + }, + { + 0b11001100100010000011001100100010, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11001100100010000011001100100010 + }, + { + 0b11111111101010101100110010011001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11111111101010101100110010011001 + }, + { + 0b11111111101010101100110010011001, + 0b11111111101010101100110010011001, + 0b11111111101010101100110010011001, + 0b00000000010101010011001101100110 + } + }; + const b32 had_16_p2[4][4] = { + { + 0b10000000010000000010000000010000, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10000000010000000010000000010000 + }, + { + 0b11000000100001000011000000100001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11000000100001000011000000100001 + }, + { + 0b11110000101001011100001110010110, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11110000101001011100001110010110 + }, + { + 0b11110000101001011100001110010110, + 0b11110000101001011100001110010110, + 0b11110000101001011100001110010110, + 0b00001111010110100011110001101001 + } + }; + const b32 had_16_mask[3][4] = { + { + 0b10001000010001000010001000010001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10001000010001000010001000010001 + }, + { + 0b11001100110011000011001100110011, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11001100110011000011001100110011 + }, + { + 0b11111111111111111111111111111111, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11111111111111111111111111111111 + } + }; + b32 had_frag[8]; + #pragma unroll + for (int64_t i = 0; i < 2; i++) { + int64_t c_log_h = (i == 0) ? MIN(4, log_had_size) : log_had_size % 4; + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + if (c_log_h < 4) { + bool mask = had_16_mask[c_log_h - 1][j] & (1 << (31 - threadid)); + if (!mask) { + had_frag[i * 4 + j] = 0; + continue; + } + } + bool pred1 = had_16_p1[c_log_h - 1][j] & (1 << (31 - threadid)); + bool pred2 = had_16_p2[c_log_h - 1][j] & (1 << (31 - threadid)); + b32 val = pred1 ? (pred2 ? p_p[c_log_h - 1] : p_n[c_log_h - 1]) : (pred2 ? 
n_p[c_log_h - 1] : n_n[c_log_h - 1]); + had_frag[i * 4 + j] = val; + } + if constexpr(log_had_size <= 4 || log_had_size % 4 == 0) break; + } + + // log had size above 8, only used for above 2^8 = 256 size + constexpr int64_t part8_log_had_size = log_had_size - 8; + + b32* a_chunk_ptr = a_start_ptr; // first chunk starts at this warp's data starts + b32* out_chunk_ptr = out_start_ptr; + + #pragma unroll + for (int64_t l = 0; l < 2; l++) { + if constexpr(log_had_size <= 8) { // l == 0 guaranteed, redundant simplified version of else body, to help compiler warnings + b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128; + } else { + b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * (l == 0 ? 128 : (128 >> part8_log_had_size)); + } + + if (l == 1) { + if constexpr(log_had_size > 8) { + __syncthreads(); // sync between first and second iterations if above size 256 + + if constexpr(log_had_size >= 12) { + // sizes 4k and above + + // a + threadblock offset + warp offset + // can then index into all chunks owned by this warp + b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + // here, j represents register, and k represents 8-offset/chunk + uint64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data + + int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # + int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) + int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) + int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads + int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register + int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index + + // fix idx for majorness + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + + // store[rowidx * 128 + colidx] = data; + b32 data = store[rowidx * 128 + colidx]; + + // compiler generates excessive instructions, so we manually do the if statement + #pragma unroll + for (uint64_t i = 0; i < num_chunks; i++) { + asm volatile ( + "{\n\t" + " .reg .pred p0;\n\t" + " setp.eq.s64 p0, %1, %2;\n\t" + " @p0 mov.b32 %0, %3;\n\t" + "}\n\t" + : "+r"(b_frag_all[i][j]) // Output operand %0 + : "l"(real_chunk_num), "l"(i), "r"(data) // Input operands %1, %2, %3 + ); + } + } + } + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 1; k < num_chunks; k++) { + int64_t threadid_contig = threadid % num_chunks; + int64_t threadid_mul = threadid / num_chunks; + int64_t threadid2 = (threadid_contig + num_chunks - k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to + b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); + } + } + } + } + } + + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + if constexpr(enable_mask) { + if (k >= real_num_chunks) + break; + } + if (l == 0) { + // bad fix for k not being recognized as a constexpr by compiler + // asm("cp.async.wait_group %0;\n" :: "n"(num_chunks - k - 1)); + #define 
SWITCH_WAIT_ASYNC_LOAD_GROUP(i) case i: asm volatile("cp.async.wait_group %0;\n" :: "n"(num_chunks - i - 1)); break; + if constexpr(enable_mask) { + switch(k + diff_num_chunks) { + SWITCH_WAIT_ASYNC_LOAD_GROUP(0) + SWITCH_WAIT_ASYNC_LOAD_GROUP(1) + SWITCH_WAIT_ASYNC_LOAD_GROUP(2) + SWITCH_WAIT_ASYNC_LOAD_GROUP(3) + SWITCH_WAIT_ASYNC_LOAD_GROUP(4) + SWITCH_WAIT_ASYNC_LOAD_GROUP(5) + SWITCH_WAIT_ASYNC_LOAD_GROUP(6) + SWITCH_WAIT_ASYNC_LOAD_GROUP(7) + SWITCH_WAIT_ASYNC_LOAD_GROUP(8) + SWITCH_WAIT_ASYNC_LOAD_GROUP(9) + SWITCH_WAIT_ASYNC_LOAD_GROUP(10) + SWITCH_WAIT_ASYNC_LOAD_GROUP(11) + SWITCH_WAIT_ASYNC_LOAD_GROUP(12) + SWITCH_WAIT_ASYNC_LOAD_GROUP(13) + SWITCH_WAIT_ASYNC_LOAD_GROUP(14) + SWITCH_WAIT_ASYNC_LOAD_GROUP(15) + SWITCH_WAIT_ASYNC_LOAD_GROUP(16) + SWITCH_WAIT_ASYNC_LOAD_GROUP(17) + SWITCH_WAIT_ASYNC_LOAD_GROUP(18) + SWITCH_WAIT_ASYNC_LOAD_GROUP(19) + SWITCH_WAIT_ASYNC_LOAD_GROUP(20) + SWITCH_WAIT_ASYNC_LOAD_GROUP(21) + SWITCH_WAIT_ASYNC_LOAD_GROUP(22) + SWITCH_WAIT_ASYNC_LOAD_GROUP(23) + SWITCH_WAIT_ASYNC_LOAD_GROUP(24) + SWITCH_WAIT_ASYNC_LOAD_GROUP(25) + SWITCH_WAIT_ASYNC_LOAD_GROUP(26) + SWITCH_WAIT_ASYNC_LOAD_GROUP(27) + SWITCH_WAIT_ASYNC_LOAD_GROUP(28) + SWITCH_WAIT_ASYNC_LOAD_GROUP(29) + SWITCH_WAIT_ASYNC_LOAD_GROUP(30) + SWITCH_WAIT_ASYNC_LOAD_GROUP(31) + } + } else { + switch(k) { + SWITCH_WAIT_ASYNC_LOAD_GROUP(0) + SWITCH_WAIT_ASYNC_LOAD_GROUP(1) + SWITCH_WAIT_ASYNC_LOAD_GROUP(2) + SWITCH_WAIT_ASYNC_LOAD_GROUP(3) + SWITCH_WAIT_ASYNC_LOAD_GROUP(4) + SWITCH_WAIT_ASYNC_LOAD_GROUP(5) + SWITCH_WAIT_ASYNC_LOAD_GROUP(6) + SWITCH_WAIT_ASYNC_LOAD_GROUP(7) + SWITCH_WAIT_ASYNC_LOAD_GROUP(8) + SWITCH_WAIT_ASYNC_LOAD_GROUP(9) + SWITCH_WAIT_ASYNC_LOAD_GROUP(10) + SWITCH_WAIT_ASYNC_LOAD_GROUP(11) + SWITCH_WAIT_ASYNC_LOAD_GROUP(12) + SWITCH_WAIT_ASYNC_LOAD_GROUP(13) + SWITCH_WAIT_ASYNC_LOAD_GROUP(14) + SWITCH_WAIT_ASYNC_LOAD_GROUP(15) + SWITCH_WAIT_ASYNC_LOAD_GROUP(16) + SWITCH_WAIT_ASYNC_LOAD_GROUP(17) + SWITCH_WAIT_ASYNC_LOAD_GROUP(18) + SWITCH_WAIT_ASYNC_LOAD_GROUP(19) + SWITCH_WAIT_ASYNC_LOAD_GROUP(20) + SWITCH_WAIT_ASYNC_LOAD_GROUP(21) + SWITCH_WAIT_ASYNC_LOAD_GROUP(22) + SWITCH_WAIT_ASYNC_LOAD_GROUP(23) + SWITCH_WAIT_ASYNC_LOAD_GROUP(24) + SWITCH_WAIT_ASYNC_LOAD_GROUP(25) + SWITCH_WAIT_ASYNC_LOAD_GROUP(26) + SWITCH_WAIT_ASYNC_LOAD_GROUP(27) + SWITCH_WAIT_ASYNC_LOAD_GROUP(28) + SWITCH_WAIT_ASYNC_LOAD_GROUP(29) + SWITCH_WAIT_ASYNC_LOAD_GROUP(30) + SWITCH_WAIT_ASYNC_LOAD_GROUP(31) + } + } + } + + if (l == 0) { + // loading for the first iteration + + // thread 0 loads [t0r0, t16r1, t0r2, t16r3] + // thread 16 loads [t0r1, t16r0, t0r3, t16r2] + // allows full coalescing, same for t1/t17, t2/t18, etc. + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); + int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16); + int64_t real_row = real_thread_id % 4; + int64_t real_col = real_thread_id / 4; + b_frag_all[k][j] = b_frag_ptr[(real_row + (reg % 2) * 4) + (real_col + (j / 2) * 8) * 8]; + } + + // for t16 swap r0/r1 and r2/r3 to have [t16r0, t0r1, t16r2, t0r3] + // so registers are in right order, same for t17, t18, etc. + if ((threadid & 16) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][1]; + b_frag_all[k][1] = temp; + + temp = b_frag_all[k][2]; + b_frag_all[k][2] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + // t0 and t16 swap r1 and r3 to have their own data, + // same for t1/t17, t2/18, etc. 
+ #pragma unroll + for (int64_t j = 1; j < 4; j += 2) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); + } + } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings + if constexpr(log_had_size < 12) { + // sizes 512, 1k, and 2k + + // for 512: + // thread 0 loads [t0r0, t0r1, t16r2, t16r3] + // thread 16 loads [t0r2, t0r3, t16r0, t16r1] + // same for t1/t17, t2/t18, etc. + // for 1k and 2k: + // thread 0 loads [t0r0, t0r1, t1r2, t1r3] + // thread 1 loads [t0r2, t0r3, t1r0, t1r1] + // same for t2/t3, t4/t5, etc. + // allows full coalescing for 512 and 1k, 16x coalescing for 2k + constexpr int64_t xor_val = log_had_size == 9 ? 16 : 1; + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; + int64_t real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val); + int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + b_frag_all[k][j] = b_frag_ptr[rowidx * 128 + colidx]; + } + + if ((threadid & xor_val) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + + temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + #pragma unroll + for (int64_t j = 2; j < 4; j++) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); + } + } + } + + if (l == 1) { + // for second iteration, we load 2 consecutive b16s (1 b32) per register, + // but tensor core register layout requires 2 b16s that are in the + // same column/consecutive rows to be in the same register, so do the swap + b32 f0 = ((b_frag_all[k][1] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); + b32 f1 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][2] & 0xFFFF); + b32 f2 = (b_frag_all[k][1] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); + b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][2] >> 16); + b_frag_all[k][0] = f0; + b_frag_all[k][1] = f1; + b_frag_all[k][2] = f2; + b_frag_all[k][3] = f3; + } + + #pragma unroll + for(int64_t i = 0, remaining_log_had_size = log_had_size - l * 8; i < 2 && remaining_log_had_size > 0; i++) { + int64_t had_off = ((remaining_log_had_size < 4) && !(log_had_size <= 4 || log_had_size % 4 == 0)) ? 
4 : 0; + mma_m16_n16_k16_b16_b16_b16_noacc(had_frag[had_off + 0], had_frag[had_off + 1], had_frag[had_off + 2], had_frag[had_off + 3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3]); + + remaining_log_had_size -= 4; + if (remaining_log_had_size <= 0 && i == 0) { + // TODO: consider different storing so no need for transpose + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][0]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][1]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][2]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][3]); + } else { + // swap and use output directly as b_frag for next iteration as an actually free transpose + b32 temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + } + } + + if (l == 1) { + // invert swap from above for second iteration + b32 f0 = ((b_frag_all[k][2] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); + b32 f1 = (b_frag_all[k][2] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); + b32 f2 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][1] & 0xFFFF); + b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][1] >> 16); + b_frag_all[k][0] = f0; + b_frag_all[k][1] = f1; + b_frag_all[k][2] = f2; + b_frag_all[k][3] = f3; + } + + if (l == 0) { + // inverse of coalesced load for first iteration to store result + #pragma unroll + for (int64_t j = 1; j < 4; j += 2) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); + } + + if ((threadid & 16) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][1]; + b_frag_all[k][1] = temp; + + temp = b_frag_all[k][2]; + b_frag_all[k][2] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + // if only going up to 256 size, store directly back to global memory, + // otherwise store back to shared memory for next iteration + b32* store = (log_had_size <= 8) ? out_chunk_ptr : b_frag_ptr; + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); + int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16); + int64_t real_row = real_thread_id % 4; + int64_t real_col = real_thread_id / 4; + store[(real_row + (reg % 2) * 4) + (real_col + (reg / 2) * 8) * 8] = b_frag_all[k][j]; + } + } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings + if (log_had_size < 12) { + // inverse of coalesced load for sizes 512, 1k and 2k to store result + constexpr int xor_val = log_had_size == 9 ? 16 : 1; + #pragma unroll + for (int64_t j = 2; j < 4; j++) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); + } + + if ((threadid & xor_val) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + + temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + b32* store = (b32*)(out + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 256 + (256 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block) + k)); + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; + b32 data = b_frag_all[k][j]; + int64_t real_thread_id = reg < 2 ? 
threadid : (threadid ^ xor_val); + int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + store[rowidx * 128 + colidx] = data; + } + } + // for size 4k and above, wait to process all chunks so a final store can be performed coalesced + } + + a_chunk_ptr += 128; // (only affects first 256 size) move on to next chunk by skipping 256 elements in b16 (= 128 in b32) + out_chunk_ptr += 128; + if constexpr(log_had_size > 8) { + b_frag_ptr += (l == 0 ? 128 : (128 >> part8_log_had_size)); + } else { // else is redundant, simplified version of if body, to help compiler warnings + b_frag_ptr += 128; + } + } + if (log_had_size <= 8) + break; + } + + if constexpr(log_had_size >= 12) { + // for sizes 4k and above, perform final coalesced store after processing all chunks + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 1; k < num_chunks; k++) { + int64_t threadid_contig = threadid % num_chunks; + int64_t threadid_mul = threadid / num_chunks; + int64_t threadid2 = (threadid_contig + k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to + b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); + } + } + + // a + threadblock offset + warp offset + // can then index into all chunks owned by this warp + b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + // here, j represents register, and k represents 8-offset/chunk + int64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data + + // b32 data = b_frag_all[real_chunk_num][j]; // target thread data + b32 data; + #pragma unroll + for (int64_t i = 0; i < num_chunks; i++) { + if (real_chunk_num == i) data = b_frag_all[i][j]; + } + + int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # + int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) + int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) + int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads + int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register + int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index + + // fix idx for majorness + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + + store[rowidx * 128 + colidx] = data; + } + } + + __syncthreads(); + store = ((b32*) out) + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 128; + int4* store4 = (int4*) store; + int4* bfrag_arr4 = (int4*) bfrag_arr; + // flush smem, simply linearly write to store + // always divisible by 128*32b, so (32*4)*32b is ok + #pragma unroll + for (int64_t warp_off = 0; warp_off < (num_chunks * warps_per_block * 128 / 4); warp_off += 32 * warps_per_block) { + int64_t total_off = warp_off + threadid + (blockid % warps_per_block) * 32; + store4[total_off] = bfrag_arr4[total_off]; + } + } + +} + +constexpr int64_t ceil_div(int64_t a, int64_t b) { + return (a + b - 1) / 
b; +} + +template +void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaStream_t stream) { + int64_t shared_size = chunks_per_warp * warps_per_block * 128 * 4; + dim3 block_size = 32 * warps_per_block; + + #define CHECK_SHARED_LIM() { \ + if (shared_size > 48 * 1024) { \ + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ + } \ + } \ + + if constexpr(check_masking) { + if (num_chunks % (chunks_per_warp * warps_per_block) != 0) { + dim3 grid_size = ceil_div(ceil_div(num_chunks, chunks_per_warp), warps_per_block); + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } else { + dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } + } else { + dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream) { + int64_t num_chunks = numel / 256; // caller required to ensure divisible by 256 + // for size 256, use (2, 1) + // for size 32k use (8, 16) + constexpr int64_t chunks_per_warp_small = 1;// 8; + constexpr int64_t warps_per_block_small = 1;//2;//16; + constexpr int64_t blocks_per_sm_small = 24; + constexpr int64_t chunks_per_warp_large = 2; + constexpr int64_t warps_per_block_large = 1; + constexpr int64_t blocks_per_sm_large = 24; + + b16* a_mat = (b16*) a_mat_ptr; + b16* out = (b16*) out_ptr; + + if (numel <= 256) { + switch (had_size) { + case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<8): run_kernel(a_mat, out, num_chunks, stream); break; + } + } else { + switch (had_size) { + case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<8): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<9): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<10): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<11): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<12): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<13): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<14): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<15): run_kernel(a_mat, out, num_chunks, stream); break; + } + } +} + +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t 
stream); + +} // namespace hadacore + +constexpr bool is_power_of_two(int x) { return x && !(x & (x - 1)); } + +torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { + auto dtype = x.scalar_type(); + TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + TORCH_CHECK(x.is_cuda()); + + const int had_size = x.size(-1); + TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), + "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); + + const auto res_shape = x.sizes(); + x = x.reshape({-1, had_size}); + + auto numel = x.numel(); + if (numel % 256 != 0) { + x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); + } + + if (x.stride(-1) != 1) { + x = x.contiguous(); + } + torch::Tensor out = inplace ? x : torch::empty_like(x); + + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + VLLM_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { + auto constexpr SCALAR_TYPE = c10::CppTypeToScalarType::value; + hadacore::run_fht(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream); + }); + + if (numel % 256 != 0) { + out = out.index({torch::indexing::Slice(0, numel / had_size)}); + } + + if (inplace && out.data_ptr() != x.data_ptr()) { + x.copy_(out.view(res_shape)); + return x; + } + return out.reshape(res_shape); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("hadacore_transform", &hadacore_transform); +} diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 0d14ba15937c6..d29a199c5d32f 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -12,20 +12,21 @@ from functools import reduce from typing import Optional, Union import jinja2 -# yapf conflicts with isort for this block -# yapf: disable -from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, - EpilogueScheduleType, - MixedInputKernelScheduleType, - TileSchedulerTag, - TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, - VLLMDataTypeSize, VLLMDataTypeTag, - VLLMDataTypeTorchDataTypeTag, - VLLMDataTypeVLLMScalarTypeTag, - VLLMKernelScheduleTag) - -# yapf: enable +from vllm_cutlass_library_extension import ( + DataType, + EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, + VLLMDataType, + VLLMDataTypeNames, + VLLMDataTypeSize, + VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, + VLLMKernelScheduleTag, +) # # Generator templating @@ -286,18 +287,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) - cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + - f"x{schedule_config.cluster_shape_mnk[1]}" + - f"x{schedule_config.cluster_shape_mnk[2]}") - kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ - .split("::")[-1] - epilogue_schedule = EpilogueScheduleTag[ - schedule_config.epilogue_schedule].split("::")[-1] - tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ - .split("::")[-1] + cluster_shape = ( + f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}" + ) + kernel_schedule = 
VLLMKernelScheduleTag[schedule_config.kernel_schedule].split( + "::" + )[-1] + epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split( + "::" + )[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1] - return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + - f"_{epilogue_schedule}_{tile_scheduler}") + return ( + f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}" + ) # mostly unique shorter sch_sig @@ -316,18 +322,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: # unique type_name def generate_type_signature(kernel_types: TypeConfig): - return str("".join([ - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ])) + return str( + "".join( + [ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + ) def generate_type_option_name(kernel_types: TypeConfig): - return ", ".join([ - f"{field.name.replace('b_', 'with_')+'_type'}=" + - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ]) + return ", ".join( + [ + f"{field.name.replace('b_', 'with_') + '_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) def is_power_of_two(n): @@ -335,7 +347,6 @@ def is_power_of_two(n): def to_cute_constant(value: list[int]): - def _to_cute_constant(value: int): if is_power_of_two(value): return f"_{value}" @@ -350,11 +361,11 @@ def to_cute_constant(value: list[int]): def unique_schedules(impl_configs: list[ImplConfig]): # Use dict over set for deterministic ordering - return list({ - sch: None - for impl_config in impl_configs - for sch in impl_config.schedules - }.keys()) + return list( + { + sch: None for impl_config in impl_configs for sch in impl_config.schedules + }.keys() + ) def unsigned_type_with_bitwidth(num_bits): @@ -380,7 +391,7 @@ template_globals = { "gen_type_sig": generate_type_signature, "unique_schedules": unique_schedules, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, - "gen_type_option_name": generate_type_option_name + "gen_type_option_name": generate_type_option_name, } @@ -398,26 +409,31 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE) def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): sources = [] - sources.append(( - "machete_mm_dispatch", - mm_dispatch_template.render(impl_configs=impl_configs), - )) + sources.append( + ( + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), + ) + ) prepack_types = [] for impl_config in impl_configs: - convert_type = impl_config.types.a \ - if impl_config.types.b_group_scale == DataType.void \ - else impl_config.types.b_group_scale + convert_type = ( + impl_config.types.a + if impl_config.types.b_group_scale == DataType.void + else impl_config.types.b_group_scale + ) prepack_types.append( PrepackTypeConfig( a=impl_config.types.a, b_num_bits=VLLMDataTypeSize[impl_config.types.b], convert=convert_type, accumulator=impl_config.types.accumulator, - )) + ) + ) def prepacked_type_key(prepack_type: PrepackTypeConfig): - # For now we we can just use the first accumulator type seen since + # For now, we can just use the first accumulator type seen since # the tensor core shapes/layouts don't vary based on accumulator # type so we can generate less code this way return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) @@ -430,10 +446,14 @@ def 
create_sources(impl_configs: list[ImplConfig], num_impl_files=8): unique_prepack_types.append(prepack_type) prepack_types_seen.add(key) - sources.append(( - "machete_prepack", - prepack_dispatch_template.render(types=unique_prepack_types, ), - )) + sources.append( + ( + "machete_prepack", + prepack_dispatch_template.render( + types=unique_prepack_types, + ), + ) + ) # Split up impls across files num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) @@ -466,10 +486,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): curr_impl_in_file += len(files_impls[-1][-1].schedules) for part, file_impls in enumerate(files_impls): - sources.append(( - f"machete_mm_impl_part{part+1}", - mm_impl_template.render(impl_configs=file_impls), - )) + sources.append( + ( + f"machete_mm_impl_part{part + 1}", + mm_impl_template.render(impl_configs=file_impls), + ) + ) return sources @@ -514,8 +536,7 @@ def generate(): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore for cond, tile_config in default_tile_heuristic_config.items() ] @@ -541,14 +562,18 @@ def generate(): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(GPTQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] AWQ_kernel_type_configs = list( @@ -561,14 +586,18 @@ def generate(): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (DataType.u4, DataType.u8) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(AWQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] # TODO: Support W4A8 when ready diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/w8a8/cutlass/Epilogues.md similarity index 100% rename from csrc/quantization/cutlass_w8a8/Epilogues.md rename to csrc/quantization/w8a8/cutlass/Epilogues.md diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh rename to csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu similarity index 100% rename from 
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh similarity index 88% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index c841125dbb734..e7bb061ba0244 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { @@ -149,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -169,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - auto mainloop_args = [&](){ - // layout_SFA and layout_SFB cannot be swapped since they are deduced. - if (swap_ab) { - return typename GemmKernel::MainloopArguments{ - b_ptr, b_stride, a_ptr, a_stride, - b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB - }; - } - else { - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - } - }(); + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.layout_SFB = layout_SFB; + if (swap_ab) { + mainloop_args.ptr_A = b_ptr; + mainloop_args.dA = b_stride; + mainloop_args.ptr_B = a_ptr; + mainloop_args.dB = a_stride; + mainloop_args.ptr_SFA = b_scales_ptr; + mainloop_args.ptr_SFB = a_scales_ptr; + } else { + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.ptr_SFB = b_scales_ptr; + } auto prob_shape = swap_ab ? 
cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); @@ -230,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( out, a, b, a_scales, b_scales); } @@ -244,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( out, a, b, a_scales, b_scales); } @@ -258,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, + Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( out, a, b, a_scales, b_scales); } @@ -270,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, // TMA epilogue isn't compatible with Swap A/B cutlass_gemm_caller_blockwise, Int, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( out, a, b, a_scales, b_scales); } } -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh similarity index 90% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index d50a83ae1cd48..811741aee58b3 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { @@ -128,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -146,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto 
b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - auto mainloop_args = [&](){ - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - }(); + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh new file mode 100644 index 0000000000000..147eb8efc0778 --- /dev/null +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +// clang-format off +template +struct cutlass_3x_gemm_fp8_blockwise { + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + 
ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + + TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; + auto prob_shape = cute::make_shape(m, n, k, 1); + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + // TODO: better heuristics + cutlass_gemm_caller_blockwise, + Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>( + out, a, b, a_scales, b_scales); +} + +} // namespace vllm \ No newline at end of file diff --git 
a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp similarity index 57% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index 2ee6a19407f92..2204a49257b08 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -25,14 +25,17 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, if constexpr (!std::is_same_v) { int8_func(c, a, b, a_scales, b_scales, bias); } else { - TORCH_CHECK(false, "Int8 not supported for this architecture"); + int32_t version_num = get_sm_version_num(); + TORCH_CHECK( + false, "Int8 not supported on SM", version_num, + ". Use FP8 quantization instead, or run on older arch (SM < 100)."); } } } else { TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); int32_t version_num = get_sm_version_num(); - if (version_num >= 100) { + if (version_num >= 90) { TORCH_CHECK( a.size(0) == a_scales.size(0) && cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), @@ -41,32 +44,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), "b_scale_group_shape must be [128, 128]."); - } else { - // TODO: Remove this after using cutlass sm90 blockwise scaling gemm - // kernel, or introducing ceil_div to the load_init() of mainloop. - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. 
Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); } TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh similarity index 99% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 24564efbd21be..f876b7d9acd87 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -133,4 +133,4 @@ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, } } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu rename to 
csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
rename to csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/moe_data.cu
rename to csrc/quantization/w8a8/cutlass/moe/moe_data.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
diff --git
a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu similarity index 98% rename from csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 84843ee6e0949..1001af05ff003 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif -#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ - defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ + defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -253,7 +254,7 @@ void cutlass_moe_mm( bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 - if (version_num >= 100) { + if (version_num >= 100 && version_num < 110) { cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); @@ -261,7 +262,7 @@ void cutlass_moe_mm( } #endif #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - if (version_num >= 90) { + if (version_num >= 90 && version_num < 100) { cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh similarity index 99% rename from csrc/quantization/fp8/amd/quant_utils.cuh rename to csrc/quantization/w8a8/fp8/amd/quant_utils.cuh index e51a4e14e518f..81f5cb83f3e18 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh @@ -5,7 +5,7 @@ #include #include -#include "../../../attention/attention_dtypes.h" +#include "../../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu similarity index 98% rename from csrc/quantization/fp8/common.cu rename to csrc/quantization/w8a8/fp8/common.cu index 5fe5dd04bd891..7a822fb8fb8aa 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/w8a8/fp8/common.cu @@ -1,15 +1,10 @@ #include "common.cuh" #include "dispatch_utils.h" -#include "../vectorization_utils.cuh" +#include "cub_helpers.h" +#include "quantization/vectorization_utils.cuh" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { template @@ -116,7 +111,7 @@ __global__ void 
dynamic_per_token_scaled_fp8_quant_kernel_strided( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; const float block_max = - BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x); __shared__ float token_scale; if (tid == 0) { diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh similarity index 86% rename from csrc/quantization/fp8/common.cuh rename to csrc/quantization/w8a8/fp8/common.cuh index 1aad6330c44b8..7838f211c59db 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/w8a8/fp8/common.cuh @@ -5,7 +5,9 @@ #include -#ifdef USE_ROCM +#ifndef USE_ROCM + #include "nvidia/quant_utils.cuh" +#else #include "amd/quant_utils.cuh" #endif @@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, float r = fmaxf(-quant_type_max_v, fminf(x, quant_type_max_v)); #ifndef USE_ROCM - return static_cast(r); + // Use hardware cvt instruction for fp8 on nvidia + // Currently only support fp8_type = c10::Float8_e4m3fn + return fp8::vec_conversion(r); #else // Use hardware cvt instruction for fp8 on rocm return fp8::cvt_c10(r); diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh similarity index 92% rename from csrc/quantization/fp8/nvidia/quant_utils.cuh rename to csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh index f8cd1dcba4ab3..421e8092474bd 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh @@ -1,6 +1,6 @@ #pragma once -#include "../../../attention/attention_dtypes.h" +#include "../../../../attention/attention_dtypes.h" #include #include #include @@ -12,13 +12,26 @@ namespace vllm { namespace fp8 { #ifdef ENABLE_FP8 - #if 0 // Disable the following code to reduce the binary size. template -__inline__ __device__ Tout -vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { +__inline__ __device__ Tout vec_conversion( + const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) { return x; } +// float -> c10::Float8_e4m3fn +template <> +__inline__ __device__ c10::Float8_e4m3fn +vec_conversion( + const float& a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return static_cast(a); + #else + return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type), + c10::Float8_e4m3fn::from_bits()); + #endif +} + + #if 0 // Disable the following code to reduce the binary size. 
// fp8 -> half template <> __inline__ __device__ uint16_t vec_conversion( @@ -563,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { TORCH_CHECK(false, \ "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ + } else if (KV_DTYPE == "fp8_ds_mla") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu similarity index 96% rename from csrc/quantization/fp8/per_token_group_quant.cu rename to csrc/quantization/w8a8/fp8/per_token_group_quant.cu index f5b40e35b6e5a..e3ab0676b254e 100644 --- a/csrc/quantization/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -1,6 +1,6 @@ #include -#include "../per_token_group_quant_8bit.h" +#include "quantization/w8a8/per_token_group_quant_8bit.h" #include @@ -8,12 +8,12 @@ #include -#include "../vectorization.cuh" -#include "../vectorization_utils.cuh" -#include "../../dispatch_utils.h" +#include "quantization/vectorization.cuh" +#include "quantization/vectorization_utils.cuh" +#include "dispatch_utils.h" -__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { - unsigned mask = 0xffff; +__device__ __forceinline__ float GroupReduceMax(float val) { + unsigned mask = threadIdx.x % 32 >= 16 ? 
0xffff0000 : 0x0000ffff; val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( threads_per_group, // stride in group scalar_op_cache); // scalar handler - local_absmax = GroupReduceMax(local_absmax, lane_id); + local_absmax = GroupReduceMax(local_absmax); float y_s = local_absmax / max_8bit; if constexpr (SCALE_UE8M0) { @@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input, double fp8_max, bool scale_ue8m0) { per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0); -} +} \ No newline at end of file diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu new file mode 100644 index 0000000000000..9d808a176f538 --- /dev/null +++ b/csrc/quantization/w8a8/int8/per_token_group_quant.cu @@ -0,0 +1,12 @@ +#include +#include + +#include "quantization/w8a8/per_token_group_quant_8bit.h" + +void per_token_group_quant_int8(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s, int64_t group_size, + double eps, double int8_min, double int8_max) { + per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, + int8_min, int8_max); +} \ No newline at end of file diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu similarity index 92% rename from csrc/quantization/compressed_tensors/int8_quant_kernels.cu rename to csrc/quantization/w8a8/int8/scaled_quant.cu index d8369108d0bd3..7fe9e96bfb017 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/w8a8/int8/scaled_quant.cu @@ -1,22 +1,11 @@ #include #include -#ifndef USE_ROCM - #include "../per_token_group_quant_8bit.h" -#endif - #include -#include "../../dispatch_utils.h" -#include "../vectorization_utils.cuh" - -#ifndef USE_ROCM - #include - #include -#else - #include - #include -#endif +#include "dispatch_utils.h" +#include "quantization/vectorization_utils.cuh" +#include "cub_helpers.h" static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -32,7 +21,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - // See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/llvm/llvm-project/issues/95183 // hip-clang std::clamp __glibcxx_assert_fail host function when building on @@ -91,7 +79,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - // See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/llvm/llvm-project/issues/95183 // hip-clang std::clamp __glibcxx_assert_fail host function when building on @@ -173,7 +160,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; - float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); + float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x); __shared__ float absmax; if (tid == 0) { absmax = block_max; @@ -183,7 +170,6 @@ __global__ void dynamic_scaled_int8_quant_kernel( float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; - // 2. 
quantize vectorize_with_alignment<16>( row_in, row_out, hidden_size, tid, stride, [=] __device__(int8_t& dst, const scalar_t& src) { @@ -201,7 +187,6 @@ struct MinMax { __host__ __device__ explicit MinMax(float v) : min(v), max(v) {} - // add a value to the MinMax __host__ __device__ MinMax& operator+=(float v) { min = fminf(min, v); max = fmaxf(max, v); @@ -235,7 +220,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; - // 1. calculate min & max MinMax thread_mm; vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) { @@ -268,7 +252,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( const float inv_s = 1.f / scale_sh; const azp_t azp = azp_sh; - // 2. quantize vectorize_with_alignment<16>( row_in, row_out, hidden_size, tid, stride, [=] __device__(int8_t& dst, const scalar_t& src) { @@ -339,14 +322,4 @@ void dynamic_scaled_int8_quant( hidden_size); } }); -} - -#ifndef USE_ROCM -void per_token_group_quant_int8(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double int8_min, double int8_max) { - per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, - int8_min, int8_max); -} -#endif +} \ No newline at end of file diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h similarity index 84% rename from csrc/quantization/per_token_group_quant_8bit.h rename to csrc/quantization/w8a8/per_token_group_quant_8bit.h index 537b61bc4303f..25d4ecd1131a1 100644 --- a/csrc/quantization/per_token_group_quant_8bit.h +++ b/csrc/quantization/w8a8/per_token_group_quant_8bit.h @@ -1,7 +1,6 @@ #pragma once #include -// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders // 8-bit per-token-group quantization helper used by both FP8 and INT8 void per_token_group_quant_8bit(const torch::Tensor& input, torch::Tensor& output_q, diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index e3a0e15f5304f..a339c5641bb4a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -23,14 +23,25 @@ #include #include "../attention/dtype_fp8.cuh" -#include "../quantization/fp8/amd/quant_utils.cuh" +#include "../quantization/w8a8/fp8/amd/quant_utils.cuh" + +// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent +#if !defined(HIP_FP8_TYPE_OCP) +using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz; +using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz; +#endif #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) #define __HIP__GFX9__ #endif -#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) +#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__FP8MFMA__ +#endif + +#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1150__) || defined(__gfx1151__)) #define __HIP__GFX11__ #endif @@ -51,6 +62,12 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +enum class MFMAType { + F16 = 0, + Fp8 = 1, + Fp4 = 2, +}; + #if defined(__HIP__GFX9__) #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 @@ -112,6 +129,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA, + const long& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz, + cbid, blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz, + cbid, blgp); + } else { + static_assert(false, "unsupported 8b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -256,12 +288,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { return ret; } +typedef union u64_cvt { + half f16x4[4]; + int16_t b16x4[4]; + _B8x8 b8x8; + _B16x4 b64; + int64_t i64; +} _T8x8; + +__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input, + _T8x8& Mtemp) { + _T8x8 Qtmp8x8; + + for (int i = 0; i < 2; i++) { + floatx4 q_out = {0, 0, 0, 0}; + q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i], + q_out); + Qtmp8x8.b16x4[i * 2] = + __builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false); + Qtmp8x8.b16x4[i * 2 + 1] = + __builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false); + } + return Qtmp8x8.b8x8; +} + +__device__ float warpReduceMax(float val) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val = max( + val, __shfl_down(val, offset, WARP_SIZE)); // Using max() for reduction + } + return val; +} + // grid (num_seqs, num_partitions,num_kv_heads) // block (256) // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -367,6 +431,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; int kphysical_block_number[TLOOP]; + #if defined(__HIP__FP8MFMA__) + float q_max = 0; + float q_scale = 1.0; + #endif // fetch k physical block numbers for (int token_depth = 0; token_depth < TLOOP; token_depth++) { @@ -416,6 +484,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] [2 * qkratio + i]; + #if defined(__HIP__FP8MFMA__) + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto && + MFMA_TYPE == MFMAType::Fp8) { + scalar_t* qptr = + reinterpret_cast(&Qlocal[qkhe_depth][qkratio].xy[i]); + for (int k = 0; k < 4; k++) + q_max = fmax(fabs(to_float(qptr[k])), q_max); + } + #endif } } } @@ -515,6 +592,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { // multiply by k_scale if fp8 kv cache scale2 *= *k_scale; + #if defined(__HIP__FP8MFMA__) + q_max = warpReduceMax(q_max); + constexpr float FP8_E4M3_SCALE_TARGET = 224.0f; + if constexpr (MFMA_TYPE == MFMAType::Fp8) { + q_scale = q_max > 0 ? 
FP8_E4M3_SCALE_TARGET / q_max : 1.0f; + scale2 /= q_scale; + } + #endif } floatx4 d_out[TLOOP]; @@ -534,12 +619,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( auto Ktmp = Klocal[token_depth][qkhe_depth]; _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; - _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); - for (int i = 0; i < 2; i++) { - d_out[token_depth] = gcn_mfma16x16x16_instr( - Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], - d_out[token_depth]); + if constexpr (MFMA_TYPE == MFMAType::F16) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + d_out[token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], + d_out[token_depth]); + } + } else { + #if defined(__HIP__FP8MFMA__) + _T8x8 Ktmp8x8, Qtmp8x8; + Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio]; + + for (int n = 0; n < 2; n++) { + scalar_t* qptr = reinterpret_cast( + &Qlocal[qkhe_depth][qkratio].xy[n]); + + Qtmp8x8.b16x4[n * 2] = + vllm::fp8::scaled_vec_conversion( + make_float2(to_float(qptr[0]), + to_float(qptr[1])), + q_scale); + Qtmp8x8.b16x4[n * 2 + 1] = + vllm::fp8::scaled_vec_conversion( + make_float2(to_float(qptr[2]), + to_float(qptr[3])), + q_scale); + } + + d_out[token_depth] = + gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>( + Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]); + #else + UNREACHABLE_CODE + #endif } } } @@ -629,17 +743,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // disable rtz conversion due to its impact on accuracy. constexpr bool LOGITS_RTZ_CONVERSION = false; + #if defined(__HIP__FP8MFMA__) + int rowid_8x8 = rowid / 2; + int offset = rowid % 2; + #endif + // write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { d_out[token_depth] *= inv_sum_scale; - if constexpr (LOGITS_RTZ_CONVERSION) { - // use rtz conversion for better performance, with negligible impact on - // accuracy - shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4_rtz(d_out[token_depth]); + if constexpr (MFMA_TYPE != MFMAType::Fp8) { + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(d_out[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(d_out[token_depth]); + } } else { - shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4(d_out[token_depth]); + #if defined(__HIP__FP8MFMA__) + // cast _B16x4* to _B8x8* + _T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>( + &shared_logits[warpid][token_depth][lane16id][rowid_8x8]); + logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32( + d_out[token_depth][0], d_out[token_depth][1], 0, false); + logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32( + d_out[token_depth][2], d_out[token_depth][3], 0, false); + #else + UNREACHABLE_CODE + #endif } } @@ -692,19 +825,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; - _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); - for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { - const int offset = - rowid * ELEMS16_ELEMS8_RATIO * 
ELEMS8_ELEMS4_RATIO + - j * ELEMS8_ELEMS4_RATIO + i; - const int offset1 = offset % ROWS_PER_WARP; - const int offset2 = offset / ROWS_PER_WARP; - // output format is 16 qheads across 16 lanes, 16 head elems - // spread across 4 rows - tmp_out = gcn_mfma16x16x16_instr( - Vlocaltmp.xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); + if constexpr (MFMA_TYPE == MFMAType::F16) { + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } else { + #if defined(__HIP__FP8MFMA__) + for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = (offset % ROWS_PER_WARP) / 2; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>( + reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64, + reinterpret_cast<_T8x8*>( + &shared_logits[vtoken_depth][offset2][lane16id] + [offset1]) + ->i64, + tmp_out); + } + #else + UNREACHABLE_CODE + #endif } } } @@ -1570,7 +1726,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, + MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2337,7 +2494,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, + MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2969,7 +3127,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template + int GQA_RATIO, MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -3041,7 +3199,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ paged_attention_ll4mi_QKV_mfma16_kernel \ + GQA_RATIO, MFMA_TYPE> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ @@ -3069,7 +3227,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template + bool ALIBI_ENABLED, MFMAType MFMA_TYPE> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -3225,7 +3383,7 @@ void paged_attention_custom_launcher( template + bool ALIBI_ENABLED, MFMAType MFMA_TYPE> void paged_attention_custom_launcher_navi( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& 
query, torch::Tensor& key_cache, @@ -3397,74 +3555,77 @@ void paged_attention_custom_launcher_navi( } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE, ALIBI_ENABLED) \ + PSIZE, ALIBI_ENABLED, MFMA_TYPE) \ if (!is_navi) { \ paged_attention_custom_launcher( \ + OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ } else { \ - paged_attention_custom_launcher_navi< \ - T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ + paged_attention_custom_launcher_navi( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ max_seq_len, alibi_slopes, k_scale, v_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - OUTT, PSIZE) \ + OUTT, PSIZE, MFMA_TYPE) \ if (alibi_slopes) { \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ - true); \ + true, MFMA_TYPE); \ } else { \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ - false); \ + false, MFMA_TYPE); \ } #if defined(__HIPCC__) && defined(__gfx90a__) - #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + MFMA_TYPE) \ if (fp8_out_scale) { \ TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ } else { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ - 256); \ + 256, MFMA_TYPE); \ } #else - #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + MFMA_TYPE) \ if (fp8_out_scale) { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - uint8_t, 256); \ + uint8_t, 256, MFMA_TYPE); \ } else { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ - 256); \ + 256, MFMA_TYPE); \ } #endif -#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ - switch (block_size) { \ - case 16: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ - break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ - switch (head_size) { \ - case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ - break; \ - case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported head size: ", head_size); \ - break; \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ } bool is_navi_gpu() { @@ 
-3503,28 +3664,43 @@ void paged_attention( const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, - const std::optional& fp8_out_scale) { + const std::optional& fp8_out_scale, + const std::string& mfma_type) { // clang-format on bool is_navi = is_navi_gpu(); - const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, - vllm::Fp8KVCacheDataType::kAuto); + CALL_CUSTOM_LAUNCHER_BLK_HEAD( + _Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16); } else if (query.dtype() == at::ScalarType::BFloat16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, - vllm::Fp8KVCacheDataType::kAuto); + vllm::Fp8KVCacheDataType::kAuto, + MFMAType::F16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); + if (mfma_type == "fp8") { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::Fp8); + } else { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::F16); + } } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); + if (mfma_type == "fp8") { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::Fp8); + } else { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::F16); + } } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 34dcc9401aae8..8b80362583eec 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -5,11 +5,14 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, const int64_t rows_per_block); -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, const int64_t CuCount); -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, + const int64_t CuCount); void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, @@ -19,4 +22,5 @@ void paged_attention( const std::optional& query_start_loc, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const std::optional& fp8_out_scale); + torch::Tensor& v_scale, const std::optional& fp8_out_scale, + const std::string& mfma_type); diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index eb47139208c91..2ef579a1b7537 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -11,7 +11,7 @@ #include "../cuda_compat.h" #include "dispatch_utils.h" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) @@ -292,8 +292,9 @@ torch::Tensor 
LLMM1(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); } } @@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); } } @@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void 
wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) { return rtn; } -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, const int64_t CuCount) { auto M_in = in_a.size(0); auto K_in = in_a.size(1); auto N_in = in_b.size(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? 
in_bias->size(0) + : 1; TORCH_CHECK(in_a.dtype() == in_b.dtype()); TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); @@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else if (K_in * N_in <= max_lds_len * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitK_hf_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSplitK_hf_big_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } \ } @@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, using fptype = typename scalar::type; fptype* af4 = reinterpret_cast(in_a.data_ptr()); const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + const fptype* biasf4 = + (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); switch (N_in) { case 1: @@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + if (y + m >= M) break; // To avoid mem access fault. 
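// Note on the bias indexing used in the epilogues below: Bx and By carry the
// bias tensor's extent along M and N respectively (both are 1 when no bias is
// passed, see the host-side Bx_in/By_in computation), so BIAS[(m + y) % Bx +
// (n % By) * M] broadcasts a 1-D bias of length M across the N dimension,
// while a full [N, M] bias is added element-wise.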
+ sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); // * sA * sB); } } } @@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE; @@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); } } } @@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, - scalar_t* C, const float* __restrict__ s_A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, const int64_t CuCount) { static c10::ScalarType kFp8Type = is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn @@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, auto K_in = in_a.size(1); auto N_in = in_b.size(0); auto Kp_in = in_a.stride(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? 
in_bias->size(0) + : 1; + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); TORCH_CHECK(out_c.dtype() == torch::kFloat16 || @@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitKQ_hf_sml_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitKQ_hf_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } \ } @@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { auto a_ptr = in_a.data_ptr(); auto b_ptr = in_b.data_ptr(); + auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; switch (N_in) { case 1: WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 66bdc448da3ca..518486b1ca5de 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> " "Tensor"); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); // wvSplitK for fp8 rocm_ops.def( - "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, " + "Tensor scale_a, " " Tensor scale_b, int CuCount) -> ()"); rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); @@ -48,7 +49,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor? alibi_slopes," " str kv_cache_dtype," " Tensor k_scale, Tensor v_scale," - " Tensor? fp8_out_scale) -> ()"); + " Tensor? fp8_out_scale," + " str mfma_type) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/sampler.cu b/csrc/sampler.cu index b0cce2e98d221..bc589d99d04bf 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel( } } +static inline __device__ uint16_t extractBinIdx(float x) { + union { + __half h; + uint16_t u16; + } tmp; + tmp.h = __float2half_rn(x); + tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000); + return 511 - (tmp.u16 >> 7); +} + +template +static __global__ void topKPerRow(const float* logits, const int* rowStarts, + const int* rowEnds, int* outIndices, + float* outLogits, int stride0, int stride1) { + // The number of bins in the histogram. + static constexpr int kNumBins = 512; + + // The top-k width. + static constexpr int kTopK = 2048; + // The number of elements per thread for the final top-k sort. + static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; + // The class to sort the elements during the final top-k sort. + using TopKSort = cub::BlockRadixSort; + + // The number of slots for the final pass. 
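  // Annotation (not part of the patch): the kernel selects the per-row
  // top-k with a single-pass histogram rather than a full sort.
  // extractBinIdx() maps each logit to one of the 512 bins using the top
  // nine bits of an order-preserving fp16 remapping, so larger logits land
  // in lower bins (e.g. extractBinIdx(3.5f) == 121 while
  // extractBinIdx(-3.5f) == 390). After a block-wide prefix sum over the
  // histogram, every bin strictly below the threshold bin (the bin where
  // the running count crosses kTopK) is committed directly, and only the
  // candidates inside the threshold bin, capped at the kNumFinalItems
  // slots declared below, need an exact radix sort before the final
  // top-k sort.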
+ static constexpr int kNumFinalItems = 3072; + // The number of elements per thread for the final sort. + static constexpr int kNumFinalItemsPerThread = + kNumFinalItems / kNumThreadsPerBlock; + // The class to sort the elements during the final pass. + using FinalSort = cub::BlockRadixSort; + + // The class to compute the inclusive prefix-sum over the histogram. + using Scan = cub::BlockScan; + + // Shared memory to compute the block scan. + __shared__ typename Scan::TempStorage smemScan; + + // The structure to store the final items (for the final pass). + struct FinalItems { + // Shared memory to store the indices for the final pass. + int indices[kNumFinalItems]; + // Shared memory to store the logits for the final pass. + float logits[kNumFinalItems]; + }; + + // Shared memory to compute the block sort. + __shared__ union { + FinalItems items; + typename FinalSort::TempStorage finalSort; + typename TopKSort::TempStorage topKSort; + } smemFinal; + + // Shared memory to store the histogram. + __shared__ int smemHistogram[kNumBins]; + // Shared memory to store the selected indices. + __shared__ int smemIndices[kTopK]; + // Shared memory to store the selected logits. + __shared__ float smemLogits[kTopK]; + // Shared memory to store the threshold bin. + __shared__ int smemThresholdBinIdx[1]; + // Shared memory counter to register the candidates for the final phase. + __shared__ int smemFinalDstIdx[1]; + + // The row computed by this block. + int rowIdx = blockIdx.x; + // The range of logits within the row. + int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx]; + // The length of the row. + int rowLen = rowEnd - rowStart; + + // Shortcut if the length of the row is smaller than Top-K. Indices are not + // sorted by their corresponding logit. + if (rowLen <= kTopK) { + for (int rowIt = threadIdx.x; rowIt < rowLen; + rowIt += kNumThreadsPerBlock) { + int idx = rowStart + rowIt; + outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; + outLogits[rowIdx * kTopK + rowIt] = + logits[rowIdx * stride0 + idx * stride1]; + } + for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; + rowIt += kNumThreadsPerBlock) { + outIndices[rowIdx * kTopK + rowIt] = -1; + outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX; + } + return; + } + + // Clear the histogram. + if (threadIdx.x < kNumBins) { + smemHistogram[threadIdx.x] = 0; + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Fetch elements one-by-one. + for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; + rowIt += kNumThreadsPerBlock) { + uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]); + atomicAdd(&smemHistogram[idx], 1); + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Read the values from SMEM. + int binCount{0}; + if (threadIdx.x < kNumBins) { + binCount = smemHistogram[threadIdx.x]; + } + + // Make sure each thread has read its value. + __syncthreads(); + + // Compute the prefix sum. + int prefixSum{0}, totalSum{0}; + Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum); + + // Update the histogram with the prefix sums. + if (threadIdx.x < kNumBins) { + smemHistogram[threadIdx.x] = prefixSum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Find the last valid bin. + if (threadIdx.x < kNumBins) { + int nextPrefixSum = + threadIdx.x == kNumBins - 1 ? 
totalSum : smemHistogram[threadIdx.x + 1]; + if (prefixSum < kTopK && nextPrefixSum >= kTopK) { + smemThresholdBinIdx[0] = threadIdx.x; + } + } + + // Clear the counter to store the items for the final phase. + if (threadIdx.x == 0) { + smemFinalDstIdx[0] = 0; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The threshold bin. + int thresholdBinIdx = smemThresholdBinIdx[0]; + + // Fetch elements one-by-one and populate the shared memory buffers. + for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; + rowIt += kNumThreadsPerBlock) { + float logit = logits[rowIdx * stride0 + rowIt * stride1]; + uint16_t idx = extractBinIdx(logit); + if (idx < thresholdBinIdx) { + int dstIdx = atomicAdd(&smemHistogram[idx], 1); + smemLogits[dstIdx] = logit; + smemIndices[dstIdx] = rowIt; + } else if (idx == thresholdBinIdx) { + int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); + if (dstIdx < kNumFinalItems) { + smemFinal.items.logits[dstIdx] = logit; + smemFinal.items.indices[dstIdx] = rowIt; + } + } + } + + // Make sure the elements are in shared memory. + __syncthreads(); + + // The logits of the elements to be sorted in the final pass. + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; + +// Init. +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + finalLogits[ii] = -FLT_MAX; + } + +// Read the elements from SMEM. +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + if (srcIdx < smemFinalDstIdx[0]) { + finalLogits[ii] = smemFinal.items.logits[srcIdx]; + finalIndices[ii] = smemFinal.items.indices[srcIdx]; + } + } + + // Make sure the shared memory has been read. + __syncthreads(); + + // Sort the elements. + FinalSort(smemFinal.finalSort) + .SortDescendingBlockedToStriped(finalLogits, finalIndices); + + // Copy the data back to the shared memory storage. + int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0; +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + int dstIdx = baseIdx + srcIdx; + if (dstIdx < kTopK) { + smemLogits[dstIdx] = finalLogits[ii]; + smemIndices[dstIdx] = finalIndices[ii]; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The topK logits. + float topKLogits[kNumTopKItemsPerThread]; + // The topK indices. + int topKIndices[kNumTopKItemsPerThread]; + +// Load from shared memory. +#pragma unroll + for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { + topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x]; + topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x]; + } + + // Sort the elements. + TopKSort(smemFinal.topKSort) + .SortDescendingBlockedToStriped(topKLogits, topKIndices); + +// Store to global memory. 
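// Annotation (not part of the patch): SortDescendingBlockedToStriped leaves
// the results striped, so the store below places the i-th largest logit of
// the row at offset i. Output layout sketch (with kTopK == 2048 as defined
// above): row r owns outLogits[r * 2048 .. r * 2048 + 2047] in descending
// order, with outIndices holding the matching positions relative to
// rowStarts[r]; rows shorter than kTopK are padded with -1 / -FLT_MAX by
// the early-exit path above.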
+#pragma unroll + for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { + int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; + outIndices[offset] = topKIndices[ii] - rowStart; + outLogits[offset] = topKLogits[ii]; + } +} + } // namespace vllm void apply_repetition_penalties_( @@ -85,4 +324,20 @@ void apply_repetition_penalties_( repetition_penalties.data_ptr(), num_seqs, vocab_size, tile_size); }); -} \ No newline at end of file +} + +void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + torch::Tensor& values, int64_t numRows, int64_t stride0, + int64_t stride1) { + // Compute the results on the device. + constexpr int kNumThreadsPerBlock = 512; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + vllm::topKPerRow + <<>>( + logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + values.data_ptr(), static_cast(stride0), + static_cast(stride1)); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4edb7af50f102..a4a9f87b28f14 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #define stride_tag #endif + ops.def( + "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! y_s," + "bool use_ue8m0) -> ()"); + ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA, + &persistent_masked_m_silu_mul_quant); + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); @@ -115,6 +122,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); +#ifndef USE_ROCM + ops.def( + "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); +#endif + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -161,6 +175,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Polynomial Normalization. + ops.def( + "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float " + "epsilon) -> ()"); + ops.impl("poly_norm", torch::kCUDA, &poly_norm); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -168,6 +188,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("apply_repetition_penalties_", torch::kCUDA, &apply_repetition_penalties_); + // Optimized top-k per row operation + ops.def( + "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " + "Tensor! indices, Tensor! values, int numRows, int stride0, " + "int stride1) -> ()"); + ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); + // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -201,16 +228,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key - // (supports multiple loras). - ops.def( - "batched_rotary_embedding(Tensor positions, Tensor! 
query," - " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox," - " int rot_dim," - " Tensor cos_sin_cache_offsets) -> ()"); - ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); - // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AWQ. @@ -309,6 +326,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -367,7 +404,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", {stride_tag}); - ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + // conditionally compiled so impl registration is in source file // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias @@ -480,19 +517,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); - // CUTLASS MLA decode - ops.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, float scale) -> ()"); - ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // SM100 CUTLASS MLA decode ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, Tensor workspace, float " - "scale," + "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," " int num_kv_splits) -> ()"); // conditionally compiled so impl in source file @@ -583,6 +613,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + // Hadamard transforms + ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); + #ifndef USE_ROCM // Compute per-token-group FP8 quantized tensor and scaling factor. ops.def( @@ -682,6 +715,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale, Tensor? seq_starts) -> ()"); cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, &gather_and_maybe_dequant_cache); + + cache_ops.def( + "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); + cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + + cache_ops.def( + "indexer_k_quant_and_cache(Tensor k, Tensor! 
kv_cache, Tensor " + "slot_mapping, " + "int quant_block_size, str kv_cache_dtype) -> ()"); + cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, + &indexer_k_quant_and_cache); + + cache_ops.def( + "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! " + "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); + cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA, + &cp_gather_indexer_k_quant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/docker/Dockerfile b/docker/Dockerfile index 839ac501dbaf0..3a0db3cc49f61 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12 # # Example: # docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 + +# Important: We build with an old version of Ubuntu to maintain broad +# compatibility with other Linux OSes. The main reason for this is that the +# glibc version is baked into the distro, and binaries built with one glibc +# version are not backwards compatible with OSes that use an earlier version. ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 # TODO: Restore to base image after FlashInfer AOT wheel fixed ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 @@ -75,34 +80,19 @@ ARG TARGETPLATFORM ARG INSTALL_KV_CONNECTORS=false ENV DEBIAN_FRONTEND=noninteractive -ARG DEADSNAKES_MIRROR_URL -ARG DEADSNAKES_GPGKEY_URL ARG GET_PIP_URL -# Install Python and other dependencies +# Install system dependencies and uv, then create Python virtual environment RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common git curl sudo \ - && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \ - if [ ! 
-z "${DEADSNAKES_GPGKEY_URL}" ] ; then \ - mkdir -p -m 0755 /etc/apt/keyrings ; \ - curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \ - sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \ - echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \ - fi ; \ - else \ - for i in 1 2 3; do \ - add-apt-repository -y ppa:deadsnakes/ppa && break || \ - { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ - done ; \ - fi \ - && apt-get update -y \ - && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ - && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ - && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ - && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \ + && apt-get install -y ccache software-properties-common git curl sudo python3-pip \ + && curl -LsSf https://astral.sh/uv/install.sh | sh \ + && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \ + && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \ + && ln -s /opt/venv/bin/python3 /usr/bin/python3 \ + && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \ + && ln -s /opt/venv/bin/pip /usr/bin/pip \ && python3 --version && python3 -m pip --version ARG PIP_INDEX_URL UV_INDEX_URL @@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER -# Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/uv \ - python3 -m pip install uv +# Activate virtual environment and add uv to PATH +ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH" +ENV VIRTUAL_ENV="/opt/venv" # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 @@ -142,7 +132,7 @@ WORKDIR /workspace COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/cuda.txt \ + uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # cuda arch list used by torch @@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" ENV UV_LINK_MODE=copy RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/build.txt \ + uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') COPY . . 
@@ -196,6 +186,7 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0 # Flag to control whether to use pre-built vLLM wheels ARG VLLM_USE_PRECOMPILED="" +ARG VLLM_MAIN_CUDA_VERSION="" # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/uv \ @@ -213,6 +204,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ + && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ && export VLLM_DOCKER_BUILD_CONTEXT=1 \ && sccache --show-stats \ && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ @@ -237,7 +229,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=400 +ARG VLLM_MAX_SIZE_MB=450 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ @@ -261,11 +253,13 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" # Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy +# Install libnuma-dev, required by fastsafetensors (fixes #20384) +RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/dev.txt \ + uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') #################### DEV IMAGE #################### @@ -279,6 +273,10 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM +ARG GDRCOPY_CUDA_VERSION=12.8 +# Keep in line with FINAL_BASE_IMAGE +ARG GDRCOPY_OS_VERSION=Ubuntu22_04 + SHELL ["/bin/bash", "-c"] ARG DEADSNAKES_MIRROR_URL @@ -358,62 +356,14 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -# If we need to build FlashInfer wheel before its release: -# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ -# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' -# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive -# $ cd flashinfer -# $ git checkout v0.2.6.post1 -# $ python -m flashinfer.aot -# $ python -m build --no-isolation --wheel -# $ ls -la dist -# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl -# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl +# Install FlashInfer pre-compiled kernel cache and binaries +# https://docs.flashinfer.ai/installation.html +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-cubin==0.4.0 \ + && uv pip install --system flashinfer-jit-cache==0.4.0 \ + --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ + && flashinfer show-config -# Install FlashInfer from source -ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with "flashinfer" extra in setup.py -ARG FLASHINFER_GIT_REF="v0.2.12" -# Flag to control whether to compile FlashInfer AOT kernels -# Set to "true" to enable AOT compilation: -# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... -ARG FLASHINFER_AOT_COMPILE=false -RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . /etc/environment - git clone --depth 1 --recursive --shallow-submodules \ - --branch ${FLASHINFER_GIT_REF} \ - ${FLASHINFER_GIT_REPO} flashinfer - pushd flashinfer - if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - # Build AOT kernels - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - # Install with no-build-isolation since we already built AOT kernels - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation . \ - --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') - # Download pre-compiled cubins - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." - else - echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" - uv pip install --system . \ - --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') - fi - popd - rm -rf flashinfer -BASH COPY examples examples COPY benchmarks benchmarks COPY ./vllm/collect_env.py . @@ -432,19 +382,32 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') # Install DeepGEMM from source -ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" +ARG DEEPGEMM_GIT_REF COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh RUN --mount=type=cache,target=/root/.cache/uv \ - VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ - && rm /tmp/install_deepgemm.sh + VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} -# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/install_gdrcopy.sh install_gdrcopy.sh +RUN set -eux; \ + case "${TARGETPLATFORM}" in \ + linux/arm64) UUARCH="aarch64" ;; \ + linux/amd64) UUARCH="x64" ;; \ + *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ + esac; \ + ./install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"; \ + rm ./install_gdrcopy.sh + +# Install EP kernels(pplx-kernels and DeepEP) COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh -COPY tools/install_nixl.sh install_nixl.sh ENV CUDA_HOME=/usr/local/cuda -RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ - && bash install_python_libraries.sh \ - && bash install_nixl.sh --force +RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a 10.0a+PTX}" \ + && bash install_python_libraries.sh + +# CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will +# return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers +# consistently from the host (see https://github.com/vllm-project/vllm/issues/18859). +# Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override. 
+ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} #################### vLLM installation IMAGE #################### @@ -518,7 +481,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3] + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.14.0' ENV VLLM_USAGE_SOURCE production-docker-image @@ -531,5 +494,5 @@ ENTRYPOINT ["./sagemaker-entrypoint.sh"] FROM vllm-openai-base AS vllm-openai -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["vllm", "serve"] #################### OPENAI API SERVER #################### diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 1a0981f8ea6d6..2aed1872ee85a 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -13,7 +13,7 @@ # vllm-dev: used for development # # Build arguments: -# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9 +# PYTHON_VERSION=3.13|3.12 (default)|3.11|3.10 # VLLM_CPU_DISABLE_AVX512=false (default)|true # VLLM_CPU_AVX512BF16=false (default)|true # VLLM_CPU_AVX512VNNI=false (default)|true @@ -47,7 +47,7 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" ENV UV_HTTP_TIMEOUT=500 -# Install Python dependencies +# Install Python dependencies ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} ENV UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} ENV UV_INDEX_STRATEGY="unsafe-best-match" @@ -104,7 +104,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/workspace/vllm/.deps,sharing=locked \ --mount=type=bind,source=.git,target=.git \ - VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel ######################### TEST DEPS ######################### FROM base AS vllm-test-deps @@ -114,13 +114,10 @@ WORKDIR /workspace/vllm RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ cp requirements/test.in requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ - sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \ - sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ - sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install -r requirements/cpu-test.txt + uv pip install -r requirements/cpu-test.txt ######################### DEV IMAGE ######################### FROM vllm-build AS vllm-dev @@ -133,12 +130,12 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install -e tests/vllm_test_utils + uv pip install -e tests/vllm_test_utils RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/ccache \ --mount=type=bind,source=.git,target=.git \ - VLLM_TARGET_DEVICE=cpu python3 setup.py develop + VLLM_TARGET_DEVICE=cpu python3 setup.py develop COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt @@ -163,11 +160,12 @@ ADD ./benchmarks/ ./benchmarks/ ADD ./vllm/collect_env.py . 
ADD ./.buildkite/ ./.buildkite/ +# Create symlink for vllm-workspace to maintain CI compatibility +RUN ln -sf /workspace /vllm-workspace + # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install -e tests/vllm_test_utils - -ENTRYPOINT ["bash"] + uv pip install -e tests/vllm_test_utils ######################### RELEASE IMAGE ######################### FROM base AS vllm-openai @@ -179,4 +177,4 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \ uv pip install dist/*.whl -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["vllm", "serve"] diff --git a/docker/Dockerfile.neuron b/docker/Dockerfile.neuron deleted file mode 100644 index 8bc23554718dc..0000000000000 --- a/docker/Dockerfile.neuron +++ /dev/null @@ -1,56 +0,0 @@ -# default base image -# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04" - -FROM $BASE_IMAGE - -RUN echo "Base image is $BASE_IMAGE" - -# Install some basic utilities -RUN apt-get update && \ - apt-get install -y \ - git \ - python3 \ - python3-pip \ - ffmpeg libsm6 libxext6 libgl1 - -### Mount Point ### -# When launching the container, mount the code directory to /workspace -ARG APP_MOUNT=/workspace -VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT}/vllm - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity -RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install pytest - -# uninstall transformers-neuronx package explicitly to avoid version conflict -RUN python3 -m pip uninstall -y transformers-neuronx - -COPY . . -ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi - -RUN python3 -m pip install -U \ - 'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ - -r requirements/neuron.txt - -ENV VLLM_TARGET_DEVICE neuron -RUN --mount=type=bind,source=.git,target=.git \ - pip install --no-build-isolation -v -e . - -# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils - -# install transformers-neuronx package as an optional dependencies (for V0) -# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict -RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps - -RUN python3 -m pip install sentencepiece transformers==4.48.0 -U - -# overwrite entrypoint to run bash script -RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py - -CMD ["/bin/bash"] diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index e147b97f0e056..165256a9bd513 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0 # #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base ARG CUDA_VERSION=12.8.0 ARG PYTHON_VERSION=3.12 ARG TARGETPLATFORM @@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. 
# build flashinfer for torch nightly from source around 10 mins -# release version: v0.2.2.post1 +# release version: v0.4.0 # todo(elainewy): cache flashinfer build result for faster build ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ @@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ echo "git clone flashinfer..." \ && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ && cd flashinfer \ - && git checkout v0.2.2.post1 \ + && git checkout v0.4.0 \ && git submodule update --init --recursive \ && echo "finish git clone flashinfer..." \ && rm -rf build \ diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le index aaff240388f2c..ad9eae94b83dd 100644 --- a/docker/Dockerfile.ppc64le +++ b/docker/Dockerfile.ppc64le @@ -1,4 +1,4 @@ -ARG BASE_UBI_IMAGE_TAG=9.5-1741850109 +ARG BASE_UBI_IMAGE_TAG=9.6-1754584681 ############################################################### # Stage to build openblas @@ -7,7 +7,7 @@ ARG BASE_UBI_IMAGE_TAG=9.5-1741850109 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS openblas-builder ARG MAX_JOBS -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 RUN microdnf install -y dnf && dnf install -y gcc-toolset-13 make wget unzip \ && source /opt/rh/gcc-toolset-13/enable \ && wget https://github.com/OpenMathLib/OpenBLAS/releases/download/v$OPENBLAS_VERSION/OpenBLAS-$OPENBLAS_VERSION.zip \ @@ -38,7 +38,7 @@ RUN dnf install -y openjpeg2-devel lcms2-devel tcl-devel tk-devel fribidi-devel FROM centos-deps-builder AS base-builder ARG PYTHON_VERSION=3.12 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 # Set Environment Variables for venv, cargo & openblas ENV VIRTUAL_ENV=/opt/vllm @@ -61,7 +61,7 @@ RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/, pkgconfig xsimd zeromq-devel kmod findutils protobuf* \ libtiff-devel libjpeg-devel zlib-devel freetype-devel libwebp-devel \ harfbuzz-devel libraqm-devel libimagequant-devel libxcb-devel \ - python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip clang-devel \ && dnf clean all \ && PREFIX=/usr/local make -C /openblas install \ && ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \ @@ -79,9 +79,9 @@ RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/, FROM base-builder AS torch-builder ARG MAX_JOBS -ARG TORCH_VERSION=2.6.0 +ARG TORCH_VERSION=2.7.0 ARG _GLIBCXX_USE_CXX11_ABI=1 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 RUN --mount=type=cache,target=/root/.cache/uv \ source /opt/rh/gcc-toolset-13/enable && \ @@ -93,7 +93,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ MAX_JOBS=${MAX_JOBS:-$(nproc)} \ PYTORCH_BUILD_VERSION=${TORCH_VERSION} PYTORCH_BUILD_NUMBER=1 uv build --wheel --out-dir /torchwheels/ -ARG TORCHVISION_VERSION=0.21.0 +ARG TORCHVISION_VERSION=0.22.0 ARG TORCHVISION_USE_NVJPEG=0 ARG TORCHVISION_USE_FFMPEG=0 RUN --mount=type=cache,target=/root/.cache/uv \ @@ -104,7 +104,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ BUILD_VERSION=${TORCHVISION_VERSION} \ uv build --wheel --out-dir /torchwheels/ --no-build-isolation -ARG TORCHAUDIO_VERSION=2.6.0 +ARG TORCHAUDIO_VERSION=2.7.0 ARG BUILD_SOX=1 ARG BUILD_KALDI=1 ARG BUILD_RNNT=1 @@ -128,7 +128,7 @@ FROM base-builder AS arrow-builder ARG MAX_JOBS ARG PYARROW_PARALLEL -ARG PYARROW_VERSION=19.0.1 +ARG PYARROW_VERSION=21.0.0 RUN --mount=type=cache,target=/root/.cache/uv \ source 
/opt/rh/gcc-toolset-13/enable && \ git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \ @@ -145,7 +145,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ make install -j ${MAX_JOBS:-$(nproc)} && \ cd ../../python/ && \ uv pip install -v -r requirements-build.txt && uv pip install numpy==2.1.3 && \ - pip show numpy && ls -lrt /opt/vllm/lib/python3.12/site-packages/numpy && \ PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \ python setup.py build_ext \ --build-type=release --bundle-arrow-cpp \ @@ -187,6 +186,23 @@ RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_V && make -j ${MAX_JOBS:-$(nproc)} +############################################################### +# Stage to build numba +############################################################### + +FROM base-builder AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +# Clone all required dependencies +RUN dnf install ninja-build llvm15 llvm15-devel -y && source /opt/rh/gcc-toolset-13/enable && export PATH=$PATH:/usr/lib64/llvm15/bin && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd ./numba && \ + if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python -m build --wheel --installer=uv --outdir /numbawheels/ + ############################################################### # Stage to build vllm - this stage builds and installs # vllm, tensorizer and vllm-tgis-adapter and builds uv cache @@ -199,6 +215,7 @@ COPY --from=torch-builder /tmp/control /dev/null COPY --from=arrow-builder /tmp/control /dev/null COPY --from=cv-builder /tmp/control /dev/null COPY --from=numa-builder /tmp/control /dev/null +COPY --from=numba-builder /tmp/control /dev/null ARG VLLM_TARGET_DEVICE=cpu ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 @@ -206,6 +223,8 @@ ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 # this step installs vllm and populates uv cache # with all the transitive dependencies RUN --mount=type=cache,target=/root/.cache/uv \ + dnf install llvm15 llvm15-devel -y && \ + rpm -ivh --nodeps https://mirror.stream.centos.org/9-stream/CRB/ppc64le/os/Packages/protobuf-lite-devel-3.14.0-16.el9.ppc64le.rpm && \ source /opt/rh/gcc-toolset-13/enable && \ git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \ uv pip install maturin && \ @@ -215,15 +234,18 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \ --mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \ --mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \ + --mount=type=bind,from=numba-builder,source=/numbawheels/,target=/numbawheels/,ro \ --mount=type=bind,src=.,dst=/src/,rw \ source /opt/rh/gcc-toolset-13/enable && \ - uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \ + export PATH=$PATH:/usr/lib64/llvm15/bin && \ + uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /numbawheels/*.whl && \ sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \ - uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \ + sed -i -e 's/.*sentencepiece.*//g' /src/pyproject.toml /src/requirements/*.txt && \ + uv pip install sentencepiece==0.2.0 pandas pythran nanobind pybind11 /hf_wheels/*.whl && \ make -C /numactl install && \ # sentencepiece.pc 
is in some pkgconfig inside uv cache export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \ - uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \ + nanobind_DIR=$(uv pip show nanobind | grep Location | sed 's/^Location: //;s/$/\/nanobind\/cmake/') && uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \ cd /src/ && \ uv build --wheel --out-dir /vllmwheel/ --no-build-isolation && \ uv pip install /vllmwheel/*.whl @@ -250,7 +272,7 @@ RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${L FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS vllm-openai ARG PYTHON_VERSION=3.12 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 # Set Environment Variables for venv & openblas ENV VIRTUAL_ENV=/opt/vllm @@ -268,6 +290,7 @@ COPY --from=vllmcache-builder /tmp/control /dev/null COPY --from=numa-builder /tmp/control /dev/null COPY --from=lapack-builder /tmp/control /dev/null COPY --from=openblas-builder /tmp/control /dev/null +COPY --from=numba-builder /tmp/control /dev/null # install gcc-11, python, openblas, numactl, lapack RUN --mount=type=cache,target=/root/.cache/uv \ @@ -276,13 +299,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \ rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \ microdnf install --nodocs -y \ - tar findutils openssl \ + libomp tar findutils openssl llvm15 llvm15-devel \ pkgconfig xsimd g++ gcc-fortran libsndfile \ libtiff libjpeg openjpeg2 zlib zeromq \ freetype lcms2 libwebp tcl tk utf8proc \ - harfbuzz fribidi libraqm libimagequant libxcb \ + harfbuzz fribidi libraqm libimagequant libxcb util-linux \ python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \ - && microdnf clean all \ + && export PATH=$PATH:/usr/lib64/llvm15/bin && microdnf clean all \ && python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \ && python -m pip install -U pip uv --no-cache \ && make -C /numactl install \ @@ -298,7 +321,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \ --mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \ --mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \ - HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl + --mount=type=bind,from=numba-builder,source=/numbawheels/,target=/numbawheels/,ro \ + export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && uv pip install sentencepiece==0.2.0 && \ + HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /numbawheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl + COPY ./ /workspace/vllm WORKDIR /workspace/vllm @@ -314,4 +340,4 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks -ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["vllm", "serve"] \ No newline at end of file diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f164857325043..c8900212e5a1b 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -29,7 +29,10 @@ ARG VLLM_BRANCH="main" ONBUILD RUN git clone 
${VLLM_REPO} \ && cd vllm \ && git fetch -v --prune -- origin ${VLLM_BRANCH} \ - && git checkout FETCH_HEAD + && git checkout FETCH_HEAD \ + && if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \ + git remote add upstream "https://github.com/vllm-project/vllm.git" \ + && git fetch upstream ; fi FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm # ----------------------- @@ -47,6 +50,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite # ----------------------- @@ -71,7 +75,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ + && python3 -m pip install lm-eval[api]==0.4.4 \ && python3 -m pip install pytest-shard # ----------------------- @@ -100,8 +104,10 @@ ARG COMMON_WORKDIR # Copy over the benchmark scripts as well COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples +COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 +ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false # ENV that can improve safe tensor loading, and end-to-end time diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 3414c0aa845cb..873c2fbcd4d30 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,27 +1,23 @@ -ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete -ARG HIPBLASLT_BRANCH="db8e93b4" -ARG HIPBLAS_COMMON_BRANCH="7c1566b" -ARG LEGACY_HIPBLASLT_OPTION= -ARG RCCL_BRANCH="648a58d" -ARG RCCL_REPO="https://github.com/ROCm/rccl" -ARG TRITON_BRANCH="e5be006" -ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="295f2ed4" -ARG PYTORCH_VISION_BRANCH="v0.21.0" -ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete +ARG TRITON_BRANCH="f9e5bf54" +ARG TRITON_REPO="https://github.com/ROCm/triton.git" +ARG PYTORCH_BRANCH="b2fb6885" +ARG PYTORCH_VISION_BRANCH="v0.23.0" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" -ARG FA_BRANCH="1a7f4dfa" +ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="916bf3c" +ARG AITER_BRANCH="2ab9f4cd" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base -ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV ROCM_PATH=/opt/rocm ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: -ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201 +ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151 ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ENV AITER_ROCM_ARCH=gfx942;gfx950 ARG PYTHON_VERSION=3.12 @@ -45,38 +41,7 @@ RUN apt-get update -y \ && curl -sS https://bootstrap.pypa.io/get-pip.py | 
python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version -RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython - -FROM base AS build_hipblaslt -ARG HIPBLASLT_BRANCH -ARG HIPBLAS_COMMON_BRANCH -# Set to "--legacy_hipblas_direct" for ROCm<=6.2 -ARG LEGACY_HIPBLASLT_OPTION -RUN git clone https://github.com/ROCm/hipBLAS-common.git -RUN cd hipBLAS-common \ - && git checkout ${HIPBLAS_COMMON_BRANCH} \ - && mkdir build \ - && cd build \ - && cmake .. \ - && make package \ - && dpkg -i ./*.deb -RUN git clone https://github.com/ROCm/hipBLASLt -RUN cd hipBLASLt \ - && git checkout ${HIPBLASLT_BRANCH} \ - && apt-get install -y llvm-dev \ - && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ - && cd build/release \ - && make package -RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install - -FROM base AS build_rccl -ARG RCCL_BRANCH -ARG RCCL_REPO -RUN git clone ${RCCL_REPO} -RUN cd rccl \ - && git checkout ${RCCL_BRANCH} \ - && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} -RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install +RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython FROM base AS build_triton ARG TRITON_BRANCH @@ -84,9 +49,11 @@ ARG TRITON_REPO RUN git clone ${TRITON_REPO} RUN cd triton \ && git checkout ${TRITON_BRANCH} \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=dist -RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + && if [ ! -f setup.py ]; then cd python; fi \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && mkdir -p /app/install && cp dist/*.whl /app/install +RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \ + && python3 -m build --wheel && cp dist/*.whl /app/install; fi FROM base AS build_amdsmi RUN cd /opt/rocm/share/amd_smi \ @@ -98,8 +65,6 @@ ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO -ARG FA_BRANCH -ARG FA_REPO RUN git clone ${PYTORCH_REPO} pytorch RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ pip install -r requirements.txt && git submodule update --init --recursive \ @@ -110,14 +75,20 @@ RUN git clone ${PYTORCH_VISION_REPO} vision RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ && python3 setup.py bdist_wheel --dist-dir=dist \ && pip install dist/*.whl +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install + +FROM base AS build_fa +ARG FA_BRANCH +ARG FA_REPO +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl RUN git clone ${FA_REPO} RUN cd flash-attention \ && git checkout ${FA_BRANCH} \ && git submodule update --init \ && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist -RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ - && cp /app/vision/dist/*.whl /app/install \ - && cp /app/flash-attention/dist/*.whl /app/install +RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install FROM base AS build_aiter ARG AITER_BRANCH @@ -129,33 +100,27 @@ RUN cd aiter \ && git checkout ${AITER_BRANCH} \ && git submodule update --init --recursive \ && pip install -r requirements.txt -RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls 
/app/aiter/dist/*.whl +RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install -FROM base AS final -RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ - dpkg -i /install/*deb \ - && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ - && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status -RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ - dpkg -i /install/*deb \ - && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ - && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status +FROM base AS debs +RUN mkdir /app/debs RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ - pip install /install/*.whl + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ - pip install /install/*.whl + cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ - pip install /install/*.whl + cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs + +FROM base AS final +RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \ pip install /install/*.whl ARG BASE_IMAGE -ARG HIPBLAS_COMMON_BRANCH -ARG HIPBLASLT_BRANCH -ARG LEGACY_HIPBLASLT_OPTION -ARG RCCL_BRANCH -ARG RCCL_REPO ARG TRITON_BRANCH ARG TRITON_REPO ARG PYTORCH_BRANCH @@ -167,11 +132,6 @@ ARG FA_REPO ARG AITER_BRANCH ARG AITER_REPO RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ - && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ - && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ - && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ - && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ - && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ @@ -179,5 +139,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 9270b48c54d4b..7fd7598b8bd93 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,8 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile \ + clang llvm-devel llvm-static clang-devel && \ microdnf clean all # Python 
Installation @@ -191,7 +192,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ -DCOMPILER_RT_BUILD_ORC=OFF \ -DCOMPILER_RT_INCLUDE_TESTS=OFF \ ${CMAKE_ARGS} -GNinja ../llvm \ - && ninja install . && \ # build llvmlite cd ../../llvmlite && python setup.py bdist_wheel && \ @@ -200,6 +200,45 @@ RUN --mount=type=cache,target=/root/.cache/uv \ sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ fi && python setup.py bdist_wheel +# Edit aws-lc-sys to support s390x +FROM python-install AS aws-lc-sys-editor +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG AWS_LC_VERSION=v0.30.0 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone --recursive https://github.com/aws/aws-lc-rs.git && \ + cd aws-lc-rs && \ + git checkout tags/aws-lc-sys/${AWS_LC_VERSION} && \ + git submodule sync && \ + git submodule update --init --recursive && \ + cd aws-lc-sys && \ + sed -i '682 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '712 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '747 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c + +# Build Outlines Core +FROM python-install AS outlines-core-builder +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG OUTLINES_CORE_VERSION=0.2.10 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=aws-lc-sys-editor,source=/tmp/aws-lc-rs/aws-lc-sys,target=/tmp/aws-lc-sys,rw \ + git clone https://github.com/dottxt-ai/outlines-core.git && \ + cd outlines-core && \ + git checkout tags/${OUTLINES_CORE_VERSION} && \ + sed -i "s/version = \"0.0.0\"/version = \"${OUTLINES_CORE_VERSION}\"/" Cargo.toml && \ + echo '[patch.crates-io]' >> Cargo.toml && \ + echo 'aws-lc-sys = { path = "/tmp/aws-lc-sys" }' >> Cargo.toml && \ + uv pip install maturin && \ + python -m maturin build --release --out dist # Final build stage FROM python-install AS vllm-cpu @@ -230,6 +269,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ + --mount=type=bind,from=outlines-core-builder,source=/tmp/outlines-core/dist,target=/tmp/outlines-core/dist/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ @@ -237,6 +277,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ + OUTLINES_CORE_WHL_FILE=$(ls /tmp/outlines-core/dist/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ @@ -244,6 +285,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ $TORCH_WHL_FILE \ $LLVM_WHL_FILE \ $NUMBA_WHL_FILE \ + 
$OUTLINES_CORE_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ -r requirements/cpu.txt @@ -267,4 +309,4 @@ USER 2000 WORKDIR /home/vllm # Set the default entrypoint -ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["vllm", "serve"] diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 65d2e5036b783..49ea39cad5128 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -1,12 +1,10 @@ FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base -RUN rm /etc/apt/sources.list.d/intel-graphics.list +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ + add-apt-repository -y ppa:kobuk-team/intel-graphics RUN apt clean && apt-get update -y && \ - apt-get install -y software-properties-common && \ - add-apt-repository ppa:deadsnakes/ppa && \ - apt-get install -y python3.10 python3.10-distutils && \ - curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \ apt-get install -y --no-install-recommends --fix-missing \ curl \ ffmpeg \ @@ -17,17 +15,29 @@ RUN apt clean && apt-get update -y && \ libgl1 \ lsb-release \ numactl \ - python3.10-dev \ - wget + wget \ + vim \ + python3.12 \ + python3.12-dev \ + python3-pip +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 -RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 -RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 +RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing + +RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh +RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc +SHELL ["bash", "-c"] +CMD ["bash", "-c", "source /root/.bashrc && exec bash"] WORKDIR /workspace/vllm COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt COPY requirements/common.txt /workspace/vllm/requirements/common.txt +# suppress the python externally managed environment error +RUN python3 -m pip config set global.break-system-packages true + RUN --mount=type=cache,target=/root/.cache/pip \ pip install --no-cache-dir \ -r requirements/xpu.txt @@ -54,8 +64,14 @@ FROM vllm-base AS vllm-openai RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -ENV VLLM_USAGE_SOURCE production-docker-image \ - TRITON_XPU_PROFILE 1 +RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y + # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] + +# install nixl from source code +RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/" + +ENTRYPOINT ["vllm", "serve"] diff --git a/docs/.nav.yml b/docs/.nav.yml index dbac0e12f1bf2..c103ed476d76d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -32,10 +32,7 
@@ nav: - models/pooling_models.md - models/extensions - Hardware Supported Models: models/hardware_supported_models - - Features: - - features/compatibility_matrix.md - - features/* - - features/quantization + - Features: features - Developer Guide: - contributing/README.md - General: @@ -47,11 +44,12 @@ nav: - contributing/model/registration.md - contributing/model/tests.md - contributing/model/multimodal.md + - contributing/model/transcription.md - CI: contributing/ci - Design Documents: design - API Reference: - api/README.md - - api/vllm/* + - api/vllm - CLI Reference: cli - Community: - community/* diff --git a/docs/README.md b/docs/README.md index 683e1d37563f5..ae95717def4cd 100644 --- a/docs/README.md +++ b/docs/README.md @@ -56,7 +56,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/docs/api/README.md b/docs/api/README.md index 57142e8f5625d..86e310f567dd3 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -14,7 +14,7 @@ API documentation for vLLM's configuration classes. - [vllm.config.LoRAConfig][] - [vllm.config.MultiModalConfig][] - [vllm.config.PoolerConfig][] -- [vllm.config.DecodingConfig][] +- [vllm.config.StructuredOutputsConfig][] - [vllm.config.ObservabilityConfig][] - [vllm.config.KVTransferConfig][] - [vllm.config.CompilationConfig][] @@ -46,7 +46,6 @@ Engine classes for offline and online inference. Inference parameters for vLLM APIs. 
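The inference-parameter entries touched by this hunk (`vllm.SamplingParams`, `vllm.PoolingParams`) are easiest to read next to a usage sketch. This is not part of the patch; the model name and sampling values below are arbitrary examples:

```python
from vllm import LLM, SamplingParams

# SamplingParams carries the per-request decoding settings for generative models.
sampling = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=64)

llm = LLM(model="facebook/opt-125m")  # any small generative model works for a smoke test
outputs = llm.generate(["The capital of France is"], sampling)
print(outputs[0].outputs[0].text)
```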
[](){ #sampling-params } -[](){ #pooling-params } - [vllm.SamplingParams][] - [vllm.PoolingParams][] diff --git a/docs/api/vllm/.meta.yml b/docs/api/vllm/.meta.yml index c15adfec644cf..d105540fee792 100644 --- a/docs/api/vllm/.meta.yml +++ b/docs/api/vllm/.meta.yml @@ -1,2 +1,2 @@ search: - boost: 0.5 + exclude: true diff --git a/docs/assets/deployment/hf-inference-endpoints-catalog.png b/docs/assets/deployment/hf-inference-endpoints-catalog.png new file mode 100644 index 0000000000000..a26681eec7b33 Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-catalog.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-choose-infra.png b/docs/assets/deployment/hf-inference-endpoints-choose-infra.png new file mode 100644 index 0000000000000..09e92ad3fc7a0 Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-choose-infra.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-click-deploy-button.png b/docs/assets/deployment/hf-inference-endpoints-click-deploy-button.png new file mode 100644 index 0000000000000..687db6e03212f Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-click-deploy-button.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-configure-container.png b/docs/assets/deployment/hf-inference-endpoints-configure-container.png new file mode 100644 index 0000000000000..834d0dda65acc Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-configure-container.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-create-endpoint.png b/docs/assets/deployment/hf-inference-endpoints-create-endpoint.png new file mode 100644 index 0000000000000..e1b0d12d1caf0 Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-create-endpoint.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-locate-deploy-button.png b/docs/assets/deployment/hf-inference-endpoints-locate-deploy-button.png new file mode 100644 index 0000000000000..4fc6fe8eebefd Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-locate-deploy-button.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-new-endpoint.png b/docs/assets/deployment/hf-inference-endpoints-new-endpoint.png new file mode 100644 index 0000000000000..2ce2e6ad8d78b Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-new-endpoint.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-select-hardware.png b/docs/assets/deployment/hf-inference-endpoints-select-hardware.png new file mode 100644 index 0000000000000..444863b17c1c0 Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-select-hardware.png differ diff --git a/docs/assets/deployment/hf-inference-endpoints-select-model.png b/docs/assets/deployment/hf-inference-endpoints-select-model.png new file mode 100644 index 0000000000000..44f66520fd12d Binary files /dev/null and b/docs/assets/deployment/hf-inference-endpoints-select-model.png differ diff --git a/docs/assets/design/cuda_graphs/current_design.png b/docs/assets/design/cuda_graphs/current_design.png new file mode 100644 index 0000000000000..045b8bbd6bfd4 Binary files /dev/null and b/docs/assets/design/cuda_graphs/current_design.png differ diff --git a/docs/assets/design/cuda_graphs/executor_runtime.png b/docs/assets/design/cuda_graphs/executor_runtime.png new file mode 100644 index 0000000000000..f8d8abe43aac1 Binary files /dev/null and 
b/docs/assets/design/cuda_graphs/executor_runtime.png differ diff --git a/docs/assets/design/cuda_graphs/previous_design.png b/docs/assets/design/cuda_graphs/previous_design.png new file mode 100644 index 0000000000000..db1432288a2fe Binary files /dev/null and b/docs/assets/design/cuda_graphs/previous_design.png differ diff --git a/docs/assets/design/cuda_graphs/wrapper_flow.png b/docs/assets/design/cuda_graphs/wrapper_flow.png new file mode 100644 index 0000000000000..749dc7f8bc5cc Binary files /dev/null and b/docs/assets/design/cuda_graphs/wrapper_flow.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png new file mode 100644 index 0000000000000..185f61e6a3ede Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/basic_grouping_example.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/full_attn.png b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png new file mode 100644 index 0000000000000..30eade5c7051c Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/full_attn.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png new file mode 100644 index 0000000000000..bcffc27a71649 Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/memory_layout.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/overview.png b/docs/assets/design/hybrid_kv_cache_manager/overview.png new file mode 100644 index 0000000000000..ac80581f491da Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/overview.png differ diff --git a/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png new file mode 100644 index 0000000000000..10aa6146dc7ab Binary files /dev/null and b/docs/assets/design/hybrid_kv_cache_manager/sw_attn.png differ diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 36232e6ad96cc..e821e2ac81149 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,11 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing) +- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA) +- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. [[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing) +- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. 
[[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index 6ad3a66252664..8abb07caaab62 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -34,6 +34,7 @@ Compute Resources: - Trainy - UC Berkeley - UC San Diego +- Volcengine Slack Sponsor: Anyscale diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0b1e..26b95ad053337 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits @@ -122,6 +122,46 @@ llm = LLM(model="google/gemma-3-27b-it", limit_mm_per_prompt={"image": 0}) ``` +### Configurable options + +`limit_mm_per_prompt` also accepts configurable options per modality. In the configurable form, you still specify `count`, and you may optionally provide size hints that control how vLLM profiles and reserves memory for your multi‑modal inputs. This helps you tune memory for the actual media you expect, instead of the model’s absolute maxima. + +Configurable options by modality: + +- `image`: `{"count": int, "width": int, "height": int}` +- `video`: `{"count": int, "num_frames": int, "width": int, "height": int}` +- `audio`: `{"count": int, "length": int}` + +Details could be found in [`ImageDummyOptions`][vllm.config.multimodal.ImageDummyOptions], [`VideoDummyOptions`][vllm.config.multimodal.VideoDummyOptions], and [`AudioDummyOptions`][vllm.config.multimodal.AudioDummyOptions]. + +Examples: + +```python +from vllm import LLM + +# Up to 5 images per prompt, profile with 512x512. +# Up to 1 video per prompt, profile with 32 frames at 640x640. +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={ + "image": {"count": 5, "width": 512, "height": 512}, + "video": {"count": 1, "num_frames": 32, "width": 640, "height": 640}, + }, +) +``` + +For backward compatibility, passing an integer works as before and is interpreted as `{"count": }`. For example: + +- `limit_mm_per_prompt={"image": 5}` is equivalent to `limit_mm_per_prompt={"image": {"count": 5}}` +- You can mix formats: `limit_mm_per_prompt={"image": 5, "video": {"count": 1, "num_frames": 32, "width": 640, "height": 640}}` + +!!! note + - The size hints affect memory profiling only. They shape the dummy inputs used to compute reserved activation sizes. 
They do not change how inputs are actually processed at inference time. + - If a hint exceeds what the model can accept, vLLM clamps it to the model's effective maximum and may log a warning. + +!!! warning + These size hints currently only affect activation memory profiling. Encoder cache size is determined by the actual inputs at runtime and is not limited by these hints. + ## Multi-modal processor arguments For certain models, you can adjust the multi-modal processor arguments to diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 69d4de9d2f644..5c74610ebd290 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -139,9 +139,9 @@ there is relatively little gain from TP. On the other hand, TP incurs significan overhead because of all-reduce being performed after every layer. Given this, it may be advantageous to instead shard the batched input data using TP, essentially -performing batch-level DP. This has been shown to improve the throughput by around 10% for +performing batch-level DP. This has been shown to improve the throughput and TTFT by around 10% for `tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, -batch-level DP can provide another 40% increase to throughput compared to regular TP. +batch-level DP can provide another 40% improvement compared to regular TP. Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. @@ -164,16 +164,23 @@ llm = LLM( ) ``` -!! important +!!! important Batch-level DP is not to be confused with API request-level DP (which is instead controlled by `data_parallel_size`). -The availablilty of batch-level DP is based on model implementation. -Currently, the following models support `mm_encoder_tp_mode="data"`: +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. +Known supported models (with corresponding benchmarks): + +- dots_ocr () +- GLM-4.1V or above () +- InternVL () +- Kimi-VL () - Llama4 () -- MiniCPM-V-4 () -- Qwen2.5-VL () +- MiniCPM-V-2.5 or above (, ) +- Qwen2-VL or above (, , ) - Step3 () ## Input Processing @@ -196,21 +203,55 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 !!! note API server scale-out is only available for online inference. +!!! warning + By default, 8 CPU threads are used in each API server to load media items (e.g. images) + from request data. + + If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` + to avoid CPU resource exhaustion. + !!! note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled - because it requires a one-to-one correspondance between API and engine core processes. + API server scale-out disables [multi-modal IPC caching](#ipc-caching) + because it requires a one-to-one correspondence between API and engine core processes. + + This does not impact [multi-modal processor caching](#processor-caching). 
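Returning to the batch-level DP setting described earlier in this hunk (not part of the patch): a minimal engine-argument sketch, using one of the listed supported model families and an arbitrary TP size:

```python
from vllm import LLM

# Shard the batched multi-modal inputs across TP ranks for the vision encoder
# ("batch-level DP") while the language model still runs with regular TP.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    tensor_parallel_size=8,
    mm_encoder_tp_mode="data",
)
```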
## Multi-Modal Caching -### Processor Cache - -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +#### Key-Replicated Cache + +By default, IPC caching uses a **key-replicated cache**, where cache keys exist +in both the API (`P0`) and engine core (`P1`) processes, but the actual cache +data resides only in `P1`. + +#### Shared Memory Cache + +When multiple worker processes are involved (e.g., when TP > 1), a +**shared-memory cache** is more efficient. This can be enabled by setting +`mm_processor_cache_type="shm"`. In this mode, cache keys are stored +on `P0`, while the cache data itself lives in shared memory accessible by all +processes. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. Examples: @@ -219,7 +260,27 @@ Examples: llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=8) +# Use a shared-memory based IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size=2, + mm_processor_cache_type="shm", + mm_processor_cache_gb=8) + # Disable the cache llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: + +| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------|-------------| +| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` | +| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | +| N/A | Disabled | N/A | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index a93435ed71b50..e456077e04958 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,32 +45,32 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. 
most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. -The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. #### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. +For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. -The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. -However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compilaed graphs may lead to HBM OOM. 
Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. +However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. #### Quantization diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 5a2a70d57e85f..b52bdf7f02e40 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -26,113 +26,123 @@ See . ## Developing ---8<-- "docs/getting_started/installation/python_env_setup.inc.md" - -Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. -Check out the [building from source][build-from-source] documentation for details. - -For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. - -### Building the docs with MkDocs - -#### Introduction to MkDocs - -[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file. - -#### Install MkDocs and Plugins - -Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies: - -```bash -uv pip install -r requirements/docs.txt -``` - -!!! note - Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+) - -#### Verify Installation - -Confirm that MkDocs is correctly installed: - -```bash -mkdocs --version -``` - -Example output: - -```console -mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10) -``` - -#### Clone the `vLLM` repository +The first step of contributing to vLLM is to clone the GitHub repository: ```bash git clone https://github.com/vllm-project/vllm.git cd vllm ``` -#### Start the Development Server +Then, configure your Python virtual environment. -MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yml` configuration file, and then start the server by running the `mkdocs serve` command: +--8<-- "docs/getting_started/installation/python_env_setup.inc.md" + +If you are only developing vLLM's Python code, install vLLM using: ```bash -mkdocs serve +VLLM_USE_PRECOMPILED=1 uv pip install -e . ``` -Example output: +If you are developing vLLM's Python and CUDA/C++ code, install vLLM using: -```console -INFO - Documentation built in 106.83 seconds -INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml' -INFO - [22:02:02] Serving on http://127.0.0.1:8000/ +```bash +uv pip install -e . ``` -#### View in Your Browser +For more details about installing from source and installing for other hardware, check out the [installation instructions](../getting_started/installation/README.md) for your hardware and head to the "Build wheel from source" section. -Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview:. 
- -#### Learn More - -For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/). - -## Testing - -??? console "Commands" - - ```bash - # These commands are only for Nvidia CUDA platforms. - uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto - - # Linting, formatting and static type checking - pre-commit install - - # You can manually run pre-commit with - pre-commit run --all-files --show-diff-on-failure - - # To manually run something from CI that does not run - # locally by default, you can run: - pre-commit run mypy-3.9 --hook-stage manual --all-files - - # Unit tests - pytest tests/ - - # Run tests for a single test file with detailed output - pytest -s -v tests/test_logger.py - ``` +For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. !!! tip - Since the ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12. + vLLM is compatible with Python versions 3.10 to 3.13. However, vLLM's default [Dockerfile](gh-file:docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12. Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment. -!!! note "Install python3-dev if Python.h is missing" +### Linting + +vLLM uses `pre-commit` to lint and format the codebase. See if `pre-commit` is new to you. Setting up `pre-commit` is as easy as: + +```bash +uv pip install pre-commit +pre-commit install +``` + +vLLM's `pre-commit` hooks will now run automatically every time you commit. + +!!! tip "Tips" + You can manually run the `pre-commit` hooks using: + + ```bash + pre-commit run # runs on staged files + pre-commit run -a # runs on all files (short for --all-files) + ``` + + --- + + Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with: + + ```bash + pre-commit run --hook-stage manual markdownlint + pre-commit run --hook-stage manual mypy-3.10 + ``` + +### Documentation + +MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, . + +Get started with: + +```bash +uv pip install -r requirements/docs.txt +``` + +!!! tip + Ensure that your Python version is compatible with the plugins + (e.g., `mkdocs-awesome-nav` requires Python 3.10+) + +MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. +From the root of the repository, run: + +```bash +mkdocs serve # with API ref (~10 minutes) +API_AUTONAV_EXCLUDE=vllm mkdocs serve # API ref off (~15 seconds) +``` + +Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready! +Open in your browser to see it. + +For additional features and advanced configurations, refer to the: + +- [MkDocs documentation](https://www.mkdocs.org/) +- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use) + +### Testing + +vLLM uses `pytest` to test the codebase. 
+ +```bash +# Install the test dependencies used in CI (CUDA only) +uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto + +# Install some common test dependencies (hardware agnostic) +uv pip install pytest pytest-asyncio + +# Run all tests +pytest tests/ + +# Run tests for a single test file with detailed output +pytest -s -v tests/test_logger.py +``` + +!!! tip "Install python3-dev if Python.h is missing" If any of the above commands fails with `Python.h: No such file or directory`, install `python3-dev` with `sudo apt install python3-dev`. -!!! note +!!! warning "Warnings" Currently, the repository is not fully checked by `mypy`. -!!! note + --- + Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU platform to run unit tests locally, rely on the continuous integration system to run the tests for now. @@ -194,8 +204,7 @@ appropriately to indicate the type of change. Please use one of the following: The PR needs to meet the following code quality standards: - We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html). -- Pass all linter checks. Please use `pre-commit` to format your code. See - if `pre-commit` is new to you. +- Pass all linter checks. - The code needs to be well-documented to ensure future contributors can easily understand the code. - Include sufficient tests to ensure the project stays correct and robust. This diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 2bbed778f3c6a..6b1eabf3d67fa 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -1,9 +1,882 @@ +--- +toc_depth: 4 +--- + # Benchmark Suites -vLLM contains two sets of benchmarks: +vLLM provides comprehensive benchmarking tools for performance testing and evaluation: -- [Performance benchmarks][performance-benchmarks] -- [Nightly benchmarks][nightly-benchmarks] +- **[Benchmark CLI]**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing +- **[Performance benchmarks][performance-benchmarks]**: Automated CI benchmarks for development +- **[Nightly benchmarks][nightly-benchmarks]**: Comparative benchmarks against alternatives + +[Benchmark CLI]: #benchmark-cli + +## Benchmark CLI + +This section guides you through running benchmark tests with the extensive +datasets supported on vLLM. It's a living document, updated as new features and datasets +become available. + +### Dataset Overview + + + +| Dataset | Online | Offline | Data Path | +|---------|--------|---------|-----------| +| ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` | +| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
`wget http://images.cocodataset.org/zips/train2017.zip` | +| ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` | +| BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` | +| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` | +| Random | ✅ | ✅ | `synthetic` | +| RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` | +| Prefix Repetition | ✅ | ✅ | `synthetic` | +| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` | +| HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` | +| HuggingFace-InstructCoder | ✅ | ✅ | `likaixin/InstructCoder` | +| HuggingFace-AIMO | ✅ | ✅ | `AI-MO/aimo-validation-aime`, `AI-MO/NuminaMath-1.5`, `AI-MO/NuminaMath-CoT` | +| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` | +| HuggingFace-MTBench | ✅ | ✅ | `philschmid/mt-bench` | +| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` | +| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` | +| Custom | ✅ | ✅ | Local file: `data.jsonl` | + +Legend: + +- ✅ - supported +- 🟡 - Partial support +- 🚧 - to be supported + +!!! note + HuggingFace dataset's `dataset-name` should be set to `hf`. + For local `dataset-path`, please set `hf-name` to its Hugging Face ID like + + ```bash + --dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat + ``` + +### Examples + +#### 🚀 Online Benchmark + +
+Show more + +First start serving your model: + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B +``` + +Then run the benchmarking script: + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --endpoint /v1/completions \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --num-prompts 10 +``` + +If successful, you will see the following output: + +```text +============ Serving Benchmark Result ============ +Successful requests: 10 +Benchmark duration (s): 5.78 +Total input tokens: 1369 +Total generated tokens: 2212 +Request throughput (req/s): 1.73 +Output token throughput (tok/s): 382.89 +Total Token throughput (tok/s): 619.85 +---------------Time to First Token---------------- +Mean TTFT (ms): 71.54 +Median TTFT (ms): 73.88 +P99 TTFT (ms): 79.49 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 7.91 +Median TPOT (ms): 7.96 +P99 TPOT (ms): 8.03 +---------------Inter-token Latency---------------- +Mean ITL (ms): 7.74 +Median ITL (ms): 7.70 +P99 ITL (ms): 8.39 +================================================== +``` + +##### Custom Dataset + +If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl + +```json +{"prompt": "What is the capital of India?"} +{"prompt": "What is the capital of Iran?"} +{"prompt": "What is the capital of China?"} +``` + +```bash +# start server +vllm serve meta-llama/Llama-3.1-8B-Instruct +``` + +```bash +# run benchmarking script +vllm bench serve --port 9001 --save-result --save-detailed \ + --backend vllm \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --endpoint /v1/completions \ + --dataset-name custom \ + --dataset-path \ + --custom-skip-chat-template \ + --num-prompts 80 \ + --max-concurrency 1 \ + --temperature=0.3 \ + --top-p=0.75 \ + --result-dir "./log/" +``` + +You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. 
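Not part of the patch, but if you prefer to generate the `data.jsonl` above programmatically, the expected format is simply one JSON object per line with a `prompt` field:

```python
import json

prompts = [
    "What is the capital of India?",
    "What is the capital of Iran?",
    "What is the capital of China?",
]

# One JSON object per line, each with a "prompt" field, as expected by CustomDataset.
with open("data.jsonl", "w") as f:
    for p in prompts:
        f.write(json.dumps({"prompt": p}) + "\n")
```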
+ +##### VisionArena Benchmark for Vision Language Models + +```bash +# need a model with vision capability here +vllm serve Qwen/Qwen2-VL-7B-Instruct +``` + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --hf-split train \ + --num-prompts 1000 +``` + +##### InstructCoder Benchmark with Speculative Decoding + +``` bash +vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +``` bash +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name hf \ + --dataset-path likaixin/InstructCoder \ + --num-prompts 2048 +``` + +##### Spec Bench Benchmark with Speculative Decoding + +``` bash +vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +[SpecBench dataset](https://github.com/hemingkx/Spec-Bench) + +Run all categories: + +``` bash +# Download the dataset using: +# wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name spec_bench \ + --dataset-path "/data/spec_bench/question.jsonl" \ + --num-prompts -1 +``` + +Available categories include `[writing, roleplay, reasoning, math, coding, extraction, stem, humanities, translation, summarization, qa, math_reasoning, rag]`. + +Run only a specific category like "summarization": + +``` bash +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name spec_bench \ + --dataset-path "/data/spec_bench/question.jsonl" \ + --num-prompts -1 + --spec-bench-category "summarization" +``` + +##### Other HuggingFaceDataset Examples + +```bash +vllm serve Qwen/Qwen2-VL-7B-Instruct +``` + +`lmms-lab/LLaVA-OneVision-Data`: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmms-lab/LLaVA-OneVision-Data \ + --hf-split train \ + --hf-subset "chart2text(cauldron)" \ + --num-prompts 10 +``` + +`Aeala/ShareGPT_Vicuna_unfiltered`: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ + --hf-split train \ + --num-prompts 10 +``` + +`AI-MO/aimo-validation-aime`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path AI-MO/aimo-validation-aime \ + --num-prompts 10 \ + --seed 42 +``` + +`philschmid/mt-bench`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num-prompts 80 +``` + +`vdaita/edit_5k_char` or `vdaita/edit_10k_char`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path vdaita/edit_5k_char \ + --num-prompts 90 \ + --blazedit-min-distance 0.01 \ + --blazedit-max-distance 0.99 +``` + +##### Running With Sampling Parameters + +When using OpenAI-compatible backends such as `vllm`, optional sampling +parameters can be specified. 
Example client command: + +```bash +vllm bench serve \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --endpoint /v1/completions \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --top-k 10 \ + --top-p 0.9 \ + --temperature 0.5 \ + --num-prompts 10 +``` + +##### Running With Ramp-Up Request Rate + +The benchmark tool also supports ramping up the request rate over the +duration of the benchmark run. This can be useful for stress testing the +server or finding the maximum throughput that it can handle, given some latency budget. + +Two ramp-up strategies are supported: + +- `linear`: Increases the request rate linearly from a start value to an end value. +- `exponential`: Increases the request rate exponentially. + +The following arguments can be used to control the ramp-up: + +- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). +- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. +- `--ramp-up-end-rps`: The request rate at the end of the benchmark. + +
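To make the two ramp-up strategies easier to picture, here is a purely illustrative sketch of how a start/end RPS pair could map to a request rate over the course of the run; the exact schedule the benchmark tool computes internally may differ:

```python
def ramp_up_rps(progress: float, start_rps: float, end_rps: float, strategy: str) -> float:
    """Illustrative request rate at a given fraction of the run (progress in [0, 1])."""
    if strategy == "linear":
        return start_rps + (end_rps - start_rps) * progress
    if strategy == "exponential":
        # Multiplicative growth from start_rps to end_rps.
        return start_rps * (end_rps / start_rps) ** progress
    raise ValueError(f"unknown ramp-up strategy: {strategy}")

print(ramp_up_rps(0.5, 1, 20, "linear"))       # 10.5
print(ramp_up_rps(0.5, 1, 20, "exponential"))  # ~4.47
```

In other words, `linear` spreads the increase evenly over the run, while `exponential` spends most of the run at lower rates and climbs steeply near the end.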
+ +#### 📈 Offline Throughput Benchmark + +
+Show more + +```bash +vllm bench throughput \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset-name sonnet \ + --dataset-path vllm/benchmarks/sonnet.txt \ + --num-prompts 10 +``` + +If successful, you will see the following output + +```text +Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s +Total num prompt tokens: 5014 +Total num output tokens: 1500 +``` + +##### VisionArena Benchmark for Vision Language Models + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --num-prompts 1000 \ + --hf-split train +``` + +The `num prompt tokens` now includes image token counts + +```text +Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s +Total num prompt tokens: 14527 +Total num output tokens: 1280 +``` + +##### InstructCoder Benchmark with Speculative Decoding + +``` bash +VLLM_WORKER_MULTIPROC_METHOD=spawn \ +vllm bench throughput \ + --dataset-name=hf \ + --dataset-path=likaixin/InstructCoder \ + --model=meta-llama/Meta-Llama-3-8B-Instruct \ + --input-len=1000 \ + --output-len=100 \ + --num-prompts=2048 \ + --async-engine \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +```text +Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s +Total num prompt tokens: 261136 +Total num output tokens: 204800 +``` + +##### Other HuggingFaceDataset Examples + +`lmms-lab/LLaVA-OneVision-Data`: + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path lmms-lab/LLaVA-OneVision-Data \ + --hf-split train \ + --hf-subset "chart2text(cauldron)" \ + --num-prompts 10 +``` + +`Aeala/ShareGPT_Vicuna_unfiltered`: + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ + --hf-split train \ + --num-prompts 10 +``` + +`AI-MO/aimo-validation-aime`: + +```bash +vllm bench throughput \ + --model Qwen/QwQ-32B \ + --backend vllm \ + --dataset-name hf \ + --dataset-path AI-MO/aimo-validation-aime \ + --hf-split train \ + --num-prompts 10 +``` + +Benchmark with LoRA adapters: + +``` bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench throughput \ + --model meta-llama/Llama-2-7b-hf \ + --backend vllm \ + --dataset_path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --dataset_name sharegpt \ + --num-prompts 10 \ + --max-loras 2 \ + --max-lora-rank 8 \ + --enable-lora \ + --lora-path yard1/llama-2-7b-sql-lora-test +``` + +
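For reference (not part of the patch), the LoRA throughput run above has a rough offline Python counterpart. The adapter path is the one used in the command; treat the rest as an assumption-laden sketch rather than a prescribed API:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=8,
)

outputs = llm.generate(
    ["Write a SQL query that counts the rows of the users table."],
    SamplingParams(max_tokens=64),
    # (name, id, path) identifies the adapter to apply to these prompts.
    lora_request=LoRARequest("sql-lora", 1, "yard1/llama-2-7b-sql-lora-test"),
)
print(outputs[0].outputs[0].text)
```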
+ +#### 🛠️ Structured Output Benchmark + +
+Show more + +Benchmark the performance of structured output generation (JSON, grammar, regex). + +##### Server Setup + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B +``` + +##### JSON Schema Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset json \ + --structured-output-ratio 1.0 \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Grammar-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset grammar \ + --structure-type grammar \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Regex-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset regex \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Choice-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset choice \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### XGrammar Benchmark Dataset + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset xgrammar_bench \ + --request-rate 10 \ + --num-prompts 1000 +``` + +
+ +#### 📚 Long Document QA Benchmark + +
+Show more + +Benchmark the performance of long document question-answering with prefix caching. + +##### Basic Long Document QA Test + +```bash +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 16 \ + --document-length 2000 \ + --output-len 50 \ + --repeat-count 5 +``` + +##### Different Repeat Modes + +```bash +# Random mode (default) - shuffle prompts randomly +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode random + +# Tile mode - repeat entire prompt list in sequence +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode tile + +# Interleave mode - repeat each prompt consecutively +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode interleave +``` + +
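The three repeat modes above differ only in how the repeated prompt list is ordered. A tiny illustration of the intended semantics (the script's actual implementation may differ in details):

```python
import random

prompts = ["doc1", "doc2", "doc3"]
repeat_count = 2

# tile: repeat the entire prompt list in sequence
tile = prompts * repeat_count                                   # doc1 doc2 doc3 doc1 doc2 doc3

# interleave: repeat each prompt consecutively
interleave = [p for p in prompts for _ in range(repeat_count)]  # doc1 doc1 doc2 doc2 doc3 doc3

# random: shuffle the repeated list
shuffled = tile[:]
random.shuffle(shuffled)
```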
+ +#### 🗂️ Prefix Caching Benchmark + +
+Show more + +Benchmark the efficiency of automatic prefix caching. + +##### Fixed Prompt with Prefix Caching + +```bash +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 +``` + +##### ShareGPT Dataset with Prefix Caching + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +``` + +##### Prefix Repetition Dataset + +```bash +vllm bench serve \ + --backend openai \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-name prefix_repetition \ + --num-prompts 100 \ + --prefix-repetition-prefix-len 512 \ + --prefix-repetition-suffix-len 128 \ + --prefix-repetition-num-prefixes 5 \ + --prefix-repetition-output-len 128 +``` + +
+ +#### ⚡ Request Prioritization Benchmark + +
+Show more + +Benchmark the performance of request prioritization in vLLM. + +##### Basic Prioritization Test + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority +``` + +##### Multiple Sequences per Prompt + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority \ + --n 2 +``` + +
+ +#### 👁️ Multi-Modal Benchmark + +
+Show more + +Benchmark the performance of multi-modal requests in vLLM. + +##### Images (ShareGPT4V) + +Start vLLM: + +```bash +vllm serve Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"image": 1}' \ + --allowed-local-media-path /path/to/sharegpt4v/images +``` + +Send requests with images: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completions +``` + +##### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +vllm serve Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completions +``` + +##### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. +- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. 
Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
+
+Behavioral notes:
+
+- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
+
+How sampling works:
+
+- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
+- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
+- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
+This should be seen as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, although that may in turn cause errors against the engine config `--limit-mm-per-prompt`.
+- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
+
+
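+For intuition, the following minimal sketch mirrors the sampling procedure described above. It is illustrative only: the flag values are copied from the examples above, and the actual benchmark implementation may differ in detail.
+
+```python
+# Illustrative sketch of random-mm item sampling: pick a per-request item
+# count k, then draw buckets while respecting per-modality limits and
+# renormalizing the remaining bucket probabilities.
+import math
+import random
+
+base_items = 2                                        # --random-mm-base-items-per-request
+range_ratio = 0.5                                     # --random-mm-num-mm-items-range-ratio
+limits = {"image": 4, "video": 0}                     # --random-mm-limit-mm-per-prompt
+buckets = {(256, 256, 1): 0.7, (720, 1280, 1): 0.3}   # --random-mm-bucket-config
+
+
+def modality(bucket: tuple[int, int, int]) -> str:
+    # T == 1 means image; any T > 1 would mean video.
+    return "image" if bucket[2] == 1 else "video"
+
+
+def sample_mm_items() -> list[tuple[int, int, int]]:
+    # Item count k: uniform over [floor(n*(1-r)), ceil(n*(1+r))], clamped to the limits.
+    lo = math.floor(base_items * (1 - range_ratio))
+    hi = math.ceil(base_items * (1 + range_ratio))
+    k = min(random.randint(lo, hi), sum(limits.values()))
+
+    counts = {m: 0 for m in limits}
+    chosen = []
+    for _ in range(k):
+        # Exclude buckets whose modality already hit its limit, then renormalize.
+        allowed = {b: p for b, p in buckets.items()
+                   if counts[modality(b)] < limits[modality(b)]}
+        if not allowed:
+            break
+        total = sum(allowed.values())
+        bucket = random.choices(list(allowed),
+                                weights=[p / total for p in allowed.values()])[0]
+        counts[modality(bucket)] += 1
+        chosen.append(bucket)
+    return chosen
+
+
+print(sample_mm_items())
+```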
+ +#### Embedding Benchmark + +Benchmark the performance of embedding requests in vLLM. + +
+Show more + +##### Text Embeddings + +Unlike generative models which use Completions API or Chat Completions API, +you should set `--backend openai-embeddings` and `--endpoint /v1/embeddings` to use the Embeddings API. + +You can use any text dataset to benchmark the model, such as ShareGPT. + +Start the server: + +```bash +vllm serve jinaai/jina-embeddings-v3 --trust-remote-code +``` + +Run the benchmark: + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --model jinaai/jina-embeddings-v3 \ + --backend openai-embeddings \ + --endpoint /v1/embeddings \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json +``` + +##### Multi-modal Embeddings + +Unlike generative models which use Completions API or Chat Completions API, +you should set `--endpoint /v1/embeddings` to use the Embeddings API. The backend to use depends on the model: + +- CLIP: `--backend openai-embeddings-clip` +- VLM2Vec: `--backend openai-embeddings-vlm2vec` + +For other models, please add your own implementation inside to match the expected instruction format. + +You can use any text or multi-modal dataset to benchmark the model, as long as the model supports it. +For example, you can use ShareGPT and VisionArena to benchmark vision-language embeddings. + +Serve and benchmark CLIP: + +```bash +# Run this in another process +vllm serve openai/clip-vit-base-patch32 + +# Run these one by one after the server is up +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --model openai/clip-vit-base-patch32 \ + --backend openai-embeddings-clip \ + --endpoint /v1/embeddings \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json + +vllm bench serve \ + --model openai/clip-vit-base-patch32 \ + --backend openai-embeddings-clip \ + --endpoint /v1/embeddings \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat +``` + +Serve and benchmark VLM2Vec: + +```bash +# Run this in another process +vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling \ + --trust-remote-code \ + --chat-template examples/template_vlm2vec_phi3v.jinja + +# Run these one by one after the server is up +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --model TIGER-Lab/VLM2Vec-Full \ + --backend openai-embeddings-vlm2vec \ + --endpoint /v1/embeddings \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json + +vllm bench serve \ + --model TIGER-Lab/VLM2Vec-Full \ + --backend openai-embeddings-vlm2vec \ + --endpoint /v1/embeddings \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat +``` + +
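+As a quick sanity check before running any of the embedding benchmarks above, you can send a single request to the Embeddings API with the OpenAI Python client. This is a minimal sketch, assuming the `jinaai/jina-embeddings-v3` server from the text-embeddings example is running locally on port 8000 (no API key is required unless the server was started with `--api-key`):
+
+```python
+# Minimal sanity check against a locally running vLLM Embeddings API.
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8000/v1",  # vLLM's OpenAI-compatible endpoint
+    api_key="EMPTY",                      # placeholder; no key is enforced by default
+)
+
+response = client.embeddings.create(
+    model="jinaai/jina-embeddings-v3",
+    input=["vLLM is a fast and easy-to-use library for LLM inference."],
+)
+
+print(len(response.data[0].embedding))  # embedding dimensionality
+```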
[](){ #performance-benchmarks }
@@ -11,9 +884,63 @@ vLLM contains two sets of benchmarks:
 The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
 
+### Manually Trigger the Benchmark
+
+Use the [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with the vLLM benchmark suite.
+For CPU environments, please use the image with the "-cpu" suffix.
+
+Here is an example `docker run` command for CPU:
+
+```bash
+docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu
+```
+
+Then, run the command below inside the Docker container:
+
+```bash
+bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
+```
+
+When run, the benchmark script generates results under the **benchmark/results** folder, along with `benchmark_results.md` and `benchmark_results.json`.
+
+#### Runtime environment variables
+
+- `ON_CPU`: Set the value to '1' on Intel® Xeon® Processors. Default value is 0.
+- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
+- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
+- `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file).
+- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
+- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
+
+For more ways to visualize the results, check [visualizing the results](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md#visualizing-the-results).
+
 The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm).
 
-More information on the performance benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md).
+More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md).
+
+### Continuous Benchmarking
+
+Continuous benchmarking provides automated performance monitoring for vLLM across different models and GPU devices. This helps track vLLM's performance characteristics over time and identify any performance regressions or improvements.
+
+#### How It Works
+
+Continuous benchmarking is triggered via a [GitHub CI workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) in the PyTorch infrastructure repository, which runs automatically every 4 hours.
The workflow executes three types of performance tests: + +- **Serving tests**: Measure request handling and API performance +- **Throughput tests**: Evaluate token generation rates +- **Latency tests**: Assess response time characteristics + +#### Benchmark Configuration + +The benchmarking currently runs on a predefined set of models configured in the [vllm-benchmarks directory](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks). To add new models for benchmarking: + +1. Navigate to the appropriate GPU directory in the benchmarks configuration +2. Add your model specifications to the corresponding configuration files +3. The new models will be included in the next scheduled benchmark run + +#### Viewing Results + +All continuous benchmarking results are automatically published to the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). [](){ #nightly-benchmarks } diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 7ef22d6f8c3f5..3dae62dd5d944 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -90,7 +90,7 @@ address the long build time at its source, the current workaround is to set `VLL to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: -1. Increase the timeout limit to 10 hours so that the build doesn't timeout. +1. Increase the timeout limit to 10 hours so that the build doesn't time out. 2. Allow the compiled artifacts to be written to the vLLM sccache S3 bucket to warm it up so that future builds are faster. diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 0e34e69245afb..cc01a60ce1e7f 100644 --- a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -40,6 +40,16 @@ python tools/generate_cmake_presets.py The script will prompt you if it cannot automatically determine certain paths (e.g., `nvcc` or a specific Python executable for your vLLM development environment). Follow the on-screen prompts. If an existing `CMakeUserPresets.json` is found, the script will ask for confirmation before overwriting it. +**Force overwrite existing file:** + +To automatically overwrite an existing `CMakeUserPresets.json` without prompting, use the `--force-overwrite` flag: + +```console +python tools/generate_cmake_presets.py --force-overwrite +``` + +This is particularly useful in automated scripts or CI/CD environments where interactive prompts are not desired. + After running the script, a `CMakeUserPresets.json` file will be created in the root of your vLLM repository. ### Example `CMakeUserPresets.json` diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 0ca77fa499db7..36068bc14876b 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -3,7 +3,7 @@ !!! important Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve ` works first! -vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance. 
+vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/README.md#compatibility-matrix) to optimize their performance. The complexity of integrating a model into vLLM depends heavily on the model's architecture. The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. @@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide: - [Registering a Model](registration.md) - [Unit Testing](tests.md) - [Multi-Modal Support](multimodal.md) +- [Speech-to-Text Support](transcription.md) !!! tip If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index 21b1f21d60a35..aafdb1058e03c 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -121,3 +121,31 @@ To support a model with interleaving sliding windows, we need to take care of th - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171). With these two steps, interleave sliding windows should work with the model. + +### How to support models that use Mamba? + +We consider 3 different scenarios: + +1. Models that use Mamba layers (either Mamba-1 or Mamba-2) but do not use attention layers. +2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers. +3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers. + +For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. +The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config. +For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. +Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations. +V0-only classes and code will be removed in the very near future. +The model should also be added to the `MODELS_CONFIG_MAP` dictionary in to ensure that the runtime defaults are optimized. + +For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). +These models should follow the same instructions as case (1), but they should inherit protocol `IsHybrid` (instead of `IsAttentionFree`) and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol). 
+ +For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. +Please follow the same guidelines as case (2) for implementing these models. +We use "mamba-like" to refer to layers that posses a state that is updated in-place, rather than being appended-to (like KV cache for attention). +For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`. +It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers. +Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:v1/attention/backends/short_conv_attn.py) for examples of this. +Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it. +Please see the calls to `direct_register_custom_op` in or for examples of this. +The new custom op should then be added to the list `_attention_ops` in to ensure that piecewise CUDA graphs works as intended. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 76d0f067fd452..724dc2284e282 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -66,35 +66,12 @@ Further update the model as follows: !!! important The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request. -- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. +!!! note + By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in + [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing. + This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings]. - ??? code - - ```python - from .utils import merge_multimodal_embeddings - - class YourModelForImage2Seq(nn.Module): - ... - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # model as one of the requirements of basic vLLM model implementation. 
- inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index) - - return inputs_embeds - ``` + You may override this method if additional logic is required for your model when merging embeddings. - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model. @@ -281,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = \ self.info.get_image_size_with_most_features() + image_overrides = mm_options.get("image") if mm_options else None + return { "image": self._get_dummy_images(width=target_width, height=target_height, - num_images=num_images) + num_images=num_images, + overrides=image_overrides) } ``` @@ -461,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { "image": self._get_dummy_images(width=target_width, height=target_height, - num_images=num_images) + num_images=num_images, + overrides=image_overrides) } ``` @@ -840,7 +825,6 @@ Some HF processors directly insert feature tokens without replacing anything in Examples: - BLIP-2 (insert at start of prompt): -- Florence2 (insert at start of prompt): - Molmo (insert after `<|endoftext|>` token): ### Handling prompt updates unrelated to multi-modal data @@ -855,7 +839,7 @@ Examples: ### Custom HF processor -Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. +Some models don't define an HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to [_call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]. Examples: diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md new file mode 100644 index 0000000000000..62e58e5c6ac58 --- /dev/null +++ b/docs/contributing/model/transcription.md @@ -0,0 +1,276 @@ +# Speech-to-Text (Transcription/Translation) Support + +This document walks you through the steps to add support for speech-to-text (ASR) models to vLLM’s transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription]. +Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance. + +## Update the base vLLM model + +It is assumed you have already implemented your model in vLLM according to the basic model guide. 
Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods. + +### `supported_languages` and `supports_transcription_only` + +Declare supported languages and capabilities: + +- The `supported_languages` mapping is validated at init time. +- Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper). + +??? code "supported_languages and supports_transcription_only" + ```python + from typing import ClassVar, Mapping, Optional, Literal + import numpy as np + import torch + from torch import nn + + from vllm.config import ModelConfig, SpeechToTextConfig + from vllm.inputs.data import PromptType + from vllm.model_executor.models.interfaces import SupportsTranscription + + class YourASRModel(nn.Module, SupportsTranscription): + # Map of ISO 639-1 language codes to language names + supported_languages: ClassVar[Mapping[str, str]] = { + "en": "English", + "it": "Italian", + # ... add more as needed + } + + # If your model only supports audio-conditioned generation + # (no text-only generation), enable this flag. + supports_transcription_only: ClassVar[bool] = True + ``` + +Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config]. + +This is for controlling general behavior of the API when serving your model: + +??? code "get_speech_to_text_config()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_speech_to_text_config( + cls, + model_config: ModelConfig, + task_type: Literal["transcribe", "translate"], + ) -> SpeechToTextConfig: + return SpeechToTextConfig( + sample_rate=16_000, + max_audio_clip_s=30, + # Set to None to disable server-side chunking if your + # model/processor handles it already + min_energy_split_window_size=None, + ) + ``` + +See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. + +Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: + +#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) + +Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`: + +??? code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + # Example with a free-form instruction prompt + task_word = "Transcribe" if task_type == "transcribe" else "Translate" + prompt = ( + "user\n" + f"{task_word} this audio: " + "\nmodel\n" + ) + + return { + "multi_modal_data": {"audio": (audio, stt_config.sample_rate)}, + "prompt": prompt, + } + ``` + + For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md). + +#### Encoder–decoder audio-only (e.g., Whisper) + +Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: + +??? 
code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + if language is None: + raise ValueError("Language must be specified") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + }, + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), + } + return cast(PromptType, prompt) + ``` + +### `validate_language` (optional) + +Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language] + +If your model requires a language and you want a default, override this method (see Whisper): + +??? code "validate_language()" + ```python + @classmethod + def validate_language(cls, language: Optional[str]) -> Optional[str]: + if language is None: + logger.warning( + "Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.") + language = "en" + return super().validate_language(language) + ``` + +### `get_num_audio_tokens` (optional) + +Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens] + +Provide a fast duration→token estimate to improve streaming usage statistics: + +??? code "get_num_audio_tokens()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: + # Return None if unknown; otherwise return an estimate. + return int(audio_duration_s * stt_config.sample_rate // 320) # example + ``` + +## Audio preprocessing and chunking + +The API server takes care of basic audio I/O and optional chunking before building prompts: + +- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`. +- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`. +- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words. + +Relevant server logic: + +??? code "_preprocess_speech_to_text()" + ```python + # vllm/entrypoints/openai/speech_to_text.py + async def _preprocess_speech_to_text(...): + language = self.model_cls.validate_language(request.language) + ... 
+ y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + duration = librosa.get_duration(y=y, sr=sr) + do_split_audio = (self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s) + chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = self.model_cls.get_generation_prompt( + audio=chunk, + stt_config=self.asr_config, + model_config=self.model_config, + language=language, + task_type=self.task_type, + request_prompt=request.prompt, + to_language=to_language, + ) + prompts.append(prompt) + return prompts, duration + ``` + +## Exposing tasks automatically + +vLLM automatically advertises transcription support if your model implements the interface: + +```python +if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + supported_tasks.append("transcription") +``` + +When enabled, the server initializes the transcription and translation handlers: + +```python +state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None +state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None +``` + +No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`. + +## Examples in-tree + +- Whisper encoder–decoder (audio-only): +- Voxtral decoder-only (audio embeddings + LLM): +- Gemma3n decoder-only with fixed instruction prompt: + +## Test with the API + +Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI): + +- Transcription (ASR): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/transcriptions + ``` + +- Translation (source → English unless otherwise supported): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/translations + ``` + +Or check out more examples in . + +!!! note + - If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking. + - Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass. + - For multilingual behavior, keep `supported_languages` aligned with actual model capabilities. diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 74627e9062167..f6a73e99546ee 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -19,7 +19,7 @@ When using `vllm bench serve`, you can enable profiling by passing the `--profil Traces can be visualized using . !!! tip -You can directly call bench module without installing vllm using `python -m vllm.entrypoints.cli.main bench`. + You can directly call bench module without installing vLLM using `python -m vllm.entrypoints.cli.main bench`. !!! tip Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. 
@@ -39,8 +39,7 @@ Refer to for an example ```bash VLLM_TORCH_PROFILER_DIR=./vllm_profile \ - python -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3-70B + vllm serve meta-llama/Meta-Llama-3-70B ``` vllm bench command: @@ -73,6 +72,8 @@ apt install nsight-systems-cli ### Example commands and usage +When profiling with `nsys`, it is advisable to set the environment variable `VLLM_WORKER_MULTIPROC_METHOD=spawn`. The default is to use the `fork` method instead of `spawn`. More information on the topic can be found in the [Nsight Systems release notes](https://docs.nvidia.com/nsight-systems/ReleaseNotes/index.html#general-issues). + #### Offline Inference For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node` before any existing script you would run for offline inference. @@ -158,6 +159,22 @@ GUI example: Screenshot 2025-03-05 at 11 48 42 AM +## Continuous Profiling + +There is a [GitHub CI workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-profiling.yml) in the PyTorch infrastructure repository that provides continuous profiling for different models on vLLM. This automated profiling helps track performance characteristics over time and across different model configurations. + +### How It Works + +The workflow currently runs weekly profiling sessions for selected models, generating detailed performance traces that can be analyzed using different tools to identify performance regressions or optimization opportunities. But, it can be triggered manually as well, using the Github Action tool. + +### Adding New Models + +To extend the continuous profiling to additional models, you can modify the [profiling-tests.json](https://github.com/pytorch/pytorch-integration-testing/blob/main/vllm-profiling/cuda/profiling-tests.json) configuration file in the PyTorch integration testing repository. Simply add your model specifications to this file to include them in the automated profiling runs. + +### Viewing Profiling Results + +The profiling traces generated by the continuous profiling workflow are publicly available on the [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). Look for the **Profiling traces** table to access and download the traces for different models and runs. + ## Profiling vLLM Python Code The Python standard library includes @@ -206,3 +223,11 @@ One example is [snakeviz](https://jiffyclub.github.io/snakeviz/). pip install snakeviz snakeviz expensive_function.prof ``` + +### Analyzing Garbage Collection Costs + +Leverage VLLM_GC_DEBUG environment variable to debug GC costs. + +- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times +- VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5 + collected objects for each gc.collect diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b2085ca..40a463a8a596c 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -1,41 +1,53 @@ -# Anything LLM +# AnythingLLM -[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. 
+[AnythingLLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with a supported chat-completion model, for example: -```bash -vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 -``` + ```bash + vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 + ``` -- Download and install [Anything LLM desktop](https://anythingllm.com/desktop). +1. Download and install [AnythingLLM Desktop](https://anythingllm.com/desktop). -- On the bottom left of open settings, AI Prooviders --> LLM: - - LLM Provider: Generic OpenAI - - Base URL: http://{vllm server host}:{vllm server port}/v1 - - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` +1. Configure the AI provider: -![](../../assets/deployment/anything-llm-provider.png) + - At the bottom, click the 🔧 wrench icon -> **Open settings** -> **AI Providers** -> **LLM**. + - Enter the following values: + - LLM Provider: Generic OpenAI + - Base URL: `http://{vllm server host}:{vllm server port}/v1` + - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` -- Back to home page, New Workspace --> create `vllm` workspace, and start to chat: + ![set AI providers](../../assets/deployment/anything-llm-provider.png) -![](../../assets/deployment/anything-llm-chat-without-doc.png) +1. Create a workspace: -- Click the upload button: - - upload the doc - - select the doc and move to the workspace - - save and embed + 1. At the bottom, click the ↺ back icon and back to workspaces. + 1. Create a workspace (e.g., `vllm`) and start chatting. -![](../../assets/deployment/anything-llm-upload-doc.png) + ![create a workspace](../../assets/deployment/anything-llm-chat-without-doc.png) -- Chat again: +1. Add a document. -![](../../assets/deployment/anything-llm-chat-with-doc.png) + 1. Click the 📎 attachment icon. + 1. Upload a document. + 1. Select and move the document into your workspace. + 1. Save and embed it. + + ![add a document](../../assets/deployment/anything-llm-upload-doc.png) + +1. Chat using your document as context. + + ![chat with your context](../../assets/deployment/anything-llm-chat-with-doc.png) diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md index c255a85d38401..5790087ed5c27 100644 --- a/docs/deployment/frameworks/autogen.md +++ b/docs/deployment/frameworks/autogen.md @@ -4,9 +4,7 @@ ## Prerequisites -- Setup vLLM environment - -- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment +Set up the vLLM and [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment: ```bash pip install vllm @@ -18,14 +16,13 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]" ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -python -m vllm.entrypoints.openai.api_server \ - --model mistralai/Mistral-7B-Instruct-v0.2 -``` + ```bash + vllm serve mistralai/Mistral-7B-Instruct-v0.2 + ``` -- Call it with AutoGen: +1. Call it with AutoGen: ??? 
code diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index cbca6e6282fc6..002935da56009 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -6,27 +6,31 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Download and install [Chatbox desktop](https://chatboxai.app/en#download). +1. Download and install [Chatbox desktop](https://chatboxai.app/en#download). -- On the bottom left of settings, Add Custom Provider +1. On the bottom left of settings, Add Custom Provider - API Mode: `OpenAI API Compatible` - Name: vllm - API Host: `http://{vllm server host}:{vllm server port}/v1` - API Path: `/chat/completions` - Model: `qwen/Qwen1.5-0.5B-Chat` -![](../../assets/deployment/chatbox-settings.png) + ![](../../assets/deployment/chatbox-settings.png) -- Go to `Just chat`, and start to chat: +1. Go to `Just chat`, and start to chat: -![](../../assets/deployment/chatbox-chat.png) + ![](../../assets/deployment/chatbox-chat.png) diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index 35f02c33cb02b..820ef0cbed9fa 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -8,44 +8,50 @@ This guide walks you through deploying Dify using a vLLM backend. ## Prerequisites -- Setup vLLM environment -- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) +Set up the vLLM environment: + +```bash +pip install vllm +``` + +And install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/). ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve Qwen/Qwen1.5-7B-Chat -``` + ```bash + vllm serve Qwen/Qwen1.5-7B-Chat + ``` -- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): +1. Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): -```bash -git clone https://github.com/langgenius/dify.git -cd dify -cd docker -cp .env.example .env -docker compose up -d -``` + ```bash + git clone https://github.com/langgenius/dify.git + cd dify + cd docker + cp .env.example .env + docker compose up -d + ``` -- Open the browser to access `http://localhost/install`, config the basic login information and login. +1. Open the browser to access `http://localhost/install`, config the basic login information and login. -- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. +1. In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +1. 
Fill in the model provider details as follows: -- Fill in the model provider details as follows: - **Model Type**: `LLM` - **Model Name**: `Qwen/Qwen1.5-7B-Chat` - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` - **Completion Mode**: `Completion` -![](../../assets/deployment/dify-settings.png) + ![](../../assets/deployment/dify-settings.png) -- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: +1. To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: -![](../../assets/deployment/dify-create-chatbot.png) + ![](../../assets/deployment/dify-create-chatbot.png) -- Click the chatbot you just created to open the chat interface and start interacting with the model: +1. Click the chatbot you just created to open the chat interface and start interacting with the model: -![](../../assets/deployment/dify-chat.png) + ![](../../assets/deployment/dify-chat.png) diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 70b4b48d4543e..836305cf15c42 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -6,7 +6,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM and Haystack environment +Set up the vLLM and Haystack environment: ```bash pip install vllm haystack-ai @@ -14,13 +14,13 @@ pip install vllm haystack-ai ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve mistralai/Mistral-7B-Instruct-v0.1 -``` + ```bash + vllm serve mistralai/Mistral-7B-Instruct-v0.1 + ``` -- Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. +1. Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. ??? code diff --git a/docs/deployment/frameworks/hf_inference_endpoints.md b/docs/deployment/frameworks/hf_inference_endpoints.md new file mode 100644 index 0000000000000..75a234bdf1422 --- /dev/null +++ b/docs/deployment/frameworks/hf_inference_endpoints.md @@ -0,0 +1,170 @@ +# Hugging Face Inference Endpoints + +## Overview + +Models compatible with vLLM can be deployed on Hugging Face Inference Endpoints, either starting from the [Hugging Face Hub](https://huggingface.co) or directly from the [Inference Endpoints](https://endpoints.huggingface.co/) interface. This allows you to serve models in a fully managed environment with GPU acceleration, auto-scaling, and monitoring, without managing the infrastructure manually. + +For advanced details on vLLM integration and deployment options, see [Advanced Deployment Details](#advanced-deployment-details). + +## Deployment Methods + +- [**Method 1: Deploy from the Catalog.**](#method-1-deploy-from-the-catalog) One-click deploy models from the Hugging Face Hub with ready-made optimized configurations. +- [**Method 2: Guided Deployment (Transformers Models).**](#method-2-guided-deployment-transformers-models) Instantly deploy models tagged with `transformers` from the Hub UI using the **Deploy** button. +- [**Method 3: Manual Deployment (Advanced Models).**](#method-3-manual-deployment-advanced-models) For models that either use custom code with the `transformers` tag, or don’t run with standard `transformers` but are supported by vLLM. 
This method requires manual configuration. + +### Method 1: Deploy from the Catalog + +This is the easiest way to get started with vLLM on Hugging Face Inference Endpoints. You can browse a catalog of models with verified and optimized deployment configuration at [Inference Endpoints](https://endpoints.huggingface.co/catalog) to maximize performance. + +1. Go to [Endpoints Catalog](https://endpoints.huggingface.co/catalog) and in the **Inference Server** options, select `vLLM`.This will display the current list of models with optimized preconfigured options. + + ![Endpoints Catalog](../../assets/deployment/hf-inference-endpoints-catalog.png) + +1. Select the desired model and click **Create Endpoint**. + + ![Create Endpoint](../../assets/deployment/hf-inference-endpoints-create-endpoint.png) + +1. Once the deployment is ready, you can use the endpoint. Update the `DEPLOYMENT_URL` with the URL provided in the console, remembering to append `/v1` as required. + + ```python + # pip install openai + from openai import OpenAI + import os + + client = OpenAI( + base_url = DEPLOYMENT_URL, + api_key = os.environ["HF_TOKEN"] # https://huggingface.co/settings/tokens + ) + + chat_completion = client.chat.completions.create( + model = "HuggingFaceTB/SmolLM3-3B", + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Give me a brief explanation of gravity in simple terms." + } + ] + } + ], + stream = True + ) + + for message in chat_completion: + print(message.choices[0].delta.content, end = "") + ``` + +!!! note + The catalog provides models optimized for vLLM, including GPU settings and inference engine configurations. You can monitor the endpoint and update the **container or its configuration** from the Inference Endpoints UI. + +### Method 2: Guided Deployment (Transformers Models) + +This method applies to models with the [`transformers` library tag](https://huggingface.co/models?library=transformers) in their metadata. It allows you to deploy a model directly from the Hub UI without manual configuration. + +1. Navigate to a model on [Hugging Face Hub](https://huggingface.co/models). + For this example we will use the [`ibm-granite/granite-docling-258M`](https://huggingface.co/ibm-granite/granite-docling-258M) model. You can verify that the model is compatible by checking the front matter in the [README](https://huggingface.co/ibm-granite/granite-docling-258M/blob/main/README.md), where the library is tagged as `library: transformers`. + +2. Locate the **Deploy** button. The button appears for models tagged with `transformers` at the top right of the [model card](https://huggingface.co/ibm-granite/granite-docling-258M). + + ![Locate deploy button](../../assets/deployment/hf-inference-endpoints-locate-deploy-button.png) + +3. Click to **Deploy** button > **HF Inference Endpoints**. You will be taken to the Inference Endpoints interface to configure the deployment. + + ![Click deploy button](../../assets/deployment/hf-inference-endpoints-click-deploy-button.png) + +4. Select the Hardware (we choose AWS>GPU>T4 for the example) and Container Configuration. Choose `vLLM` as the container type and finalize the deployment pressing **Create Endpoint**. + + ![Select Hardware](../../assets/deployment/hf-inference-endpoints-select-hardware.png) + +5. Use the deployed endpoint. Update the `DEPLOYMENT_URL` with the URL provided in the console (remember to add `/v1` needed). You can then use your endpoint programmatically or via the SDK. 
+ + ```python + # pip install openai + from openai import OpenAI + import os + + client = OpenAI( + base_url = DEPLOYMENT_URL, + api_key = os.environ["HF_TOKEN"] # https://huggingface.co/settings/tokens + ) + + chat_completion = client.chat.completions.create( + model = "ibm-granite/granite-docling-258M", + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png" + } + }, + { + "type": "text", + "text": "Convert this page to docling." + } + ] + } + ], + stream = True + ) + + for message in chat_completion: + print(message.choices[0].delta.content, end = "") + ``` + +!!! note + This method uses best-guess defaults. You may need to adjust the configuration to fit your specific requirements. + +### Method 3: Manual Deployment (Advanced Models) + +Some models require manual deployment because they: + +- Use custom code with the `transformers` tag +- Don't run with standard `transformers` but are supported by `vLLM` + +These models cannot be deployed using the **Deploy** button on the model card. + +In this guide, we demonstrate manual deployment using the [`rednote-hilab/dots.ocr`](https://huggingface.co/rednote-hilab/dots.ocr) model, an OCR model integrated with vLLM (see vLLM [PR](https://github.com/vllm-project/vllm/pull/24645)). + +1. Start a new deployment. Go to [Inference Endpoints](https://endpoints.huggingface.co/) and click `New`. + + ![New Endpoint](../../assets/deployment/hf-inference-endpoints-new-endpoint.png) + +2. Search the model in the Hub. In the dialog, switch to **Hub** and search for the desired model. + + ![Select model](../../assets/deployment/hf-inference-endpoints-select-model.png) + +3. Choosing infrastructure. On the configuration page, select the cloud provider and hardware from the available options. + For this demo, we choose AWS and L4 GPU. Adjust according to your hardware needs. + + ![Choose Infra](../../assets/deployment/hf-inference-endpoints-choose-infra.png) + +4. Configure the container. Scroll to the **Container Configuration** and select `vLLM` as the container type. + + ![Configure Container](../../assets/deployment/hf-inference-endpoints-configure-container.png) + +5. Create the endpoint. Click **Create Endpoint** to deploy the model. + + Once the endpoint is ready, you can use it with the OpenAI Completion API, cURL, or other SDKs. Remember to append `/v1` to the deployment URL if needed. + +!!! note + You can adjust the **container settings** (Container URI, Container Arguments) from the Inference Endpoints UI and press **Update Endpoint**. This redeploys the endpoint with the updated container configuration. Changes to the model itself require creating a new endpoint or redeploying with a different model. For example, for this demo, you may need to update the Container URI to the nightly image (`vllm/vllm-openai:nightly`) and add the `--trust-remote-code` flag in the container arguments. + +## Advanced Deployment Details + +With the [transformers backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html), vLLM now offers Day 0 support for any model compatible with `transformers`. This means you can deploy such models immediately, leveraging vLLM’s optimized inference without additional backend modifications. + +Hugging Face Inference Endpoints provides a fully managed environment for serving models via vLLM. 
You can deploy models without configuring servers, installing dependencies, or managing clusters. Endpoints also support deployment across multiple cloud providers (AWS, Azure, GCP) without the need for separate accounts. + +The platform integrates seamlessly with the Hugging Face Hub, allowing you to deploy any vLLM- or `transformers`-compatible model, track usage, and update the inference engine directly. The vLLM engine comes preconfigured, enabling optimized inference and easy switching between models or engines without modifying your code. This setup simplifies production deployment: endpoints are ready in minutes, include monitoring and logging, and let you focus on serving models rather than maintaining infrastructure. + +## Next Steps + +- Explore the [Inference Endpoints](https://endpoints.huggingface.co/catalog) model catalog +- Read the Inference Endpoints [documentation](https://huggingface.co/docs/inference-endpoints/en/index) +- Learn about [Inference Endpoints engines](https://huggingface.co/docs/inference-endpoints/en/engines/vllm) +- Understand the [transformers backend integration](https://blog.vllm.ai/2025/04/11/transformers-backend.html) diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index c7e514f2276e0..0d6c3729911ad 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -13,7 +13,7 @@ And LiteLLM supports all models on VLLM. ## Prerequisites -- Setup vLLM and litellm environment +Set up the vLLM and litellm environment: ```bash pip install vllm litellm @@ -23,13 +23,13 @@ pip install vllm litellm ### Chat completion -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Call it with litellm: +1. Call it with litellm: ??? code @@ -51,13 +51,13 @@ vllm serve qwen/Qwen1.5-0.5B-Chat ### Embeddings -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -vllm serve BAAI/bge-base-en-v1.5 -``` + ```bash + vllm serve BAAI/bge-base-en-v1.5 + ``` -- Call it with litellm: +1. Call it with litellm: ```python from litellm import embedding diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md index e3e7dbe6e1e80..8ecd1484eab06 100644 --- a/docs/deployment/frameworks/lobe-chat.md +++ b/docs/deployment/frameworks/lobe-chat.md @@ -6,6 +6,6 @@ Supports speech-synthesis, multi-modal, and extensible (function call) plugin sy One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. -It supports vLLM as a AI model provider to efficiently serve large language models. +It supports vLLM as an AI model provider to efficiently serve large language models. For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). 
diff --git a/docs/deployment/frameworks/lws.md b/docs/deployment/frameworks/lws.md index 3319dc6c90e1e..3b9fa3ea43d64 100644 --- a/docs/deployment/frameworks/lws.md +++ b/docs/deployment/frameworks/lws.md @@ -22,7 +22,7 @@ Deploy the following yaml file `lws.yaml` metadata: name: vllm spec: - replicas: 2 + replicas: 1 leaderWorkerTemplate: size: 2 restartPolicy: RecreateGroupOnPodRestart @@ -41,7 +41,7 @@ Deploy the following yaml file `lws.yaml` - sh - -c - "bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=$(LWS_GROUP_SIZE); - python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2" + vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2" resources: limits: nvidia.com/gpu: "8" @@ -126,8 +126,6 @@ Should get an output similar to this: NAME READY STATUS RESTARTS AGE vllm-0 1/1 Running 0 2s vllm-0-1 1/1 Running 0 2s -vllm-1 1/1 Running 0 2s -vllm-1-1 1/1 Running 0 2s ``` Verify that the distributed tensor-parallel inference works: diff --git a/docs/deployment/frameworks/open-webui.md b/docs/deployment/frameworks/open-webui.md index eaa51bb613287..505c129613dea 100644 --- a/docs/deployment/frameworks/open-webui.md +++ b/docs/deployment/frameworks/open-webui.md @@ -20,7 +20,7 @@ To get started with Open WebUI using vLLM, follow these steps: For example: ```console - python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 + vllm serve --host 0.0.0.0 --port 8000 ``` 3. Start the Open WebUI Docker container: diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index d5f2ec302b6cd..d86ab1600f126 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -11,7 +11,7 @@ Here are the integrations: ### Prerequisites -- Setup vLLM and langchain environment +Set up the vLLM and langchain environment: ```bash pip install -U vllm \ @@ -22,33 +22,33 @@ pip install -U vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. Run the script -```python -python retrieval_augmented_generation_with_langchain.py -``` + ```python + python retrieval_augmented_generation_with_langchain.py + ``` ## vLLM + llamaindex ### Prerequisites -- Setup vLLM and llamaindex environment +Set up the vLLM and llamaindex environment: ```bash pip install vllm \ @@ -60,24 +60,24 @@ pip install vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. 
-```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. Run the script: -```python -python retrieval_augmented_generation_with_llamaindex.py -``` + ```python + python retrieval_augmented_generation_with_llamaindex.py + ``` diff --git a/docs/deployment/frameworks/skypilot.md b/docs/deployment/frameworks/skypilot.md index 06e2fed38f056..f4a984a6433e2 100644 --- a/docs/deployment/frameworks/skypilot.md +++ b/docs/deployment/frameworks/skypilot.md @@ -32,6 +32,7 @@ See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypil ports: 8081 # Expose to internet traffic. envs: + PYTHONUNBUFFERED: 1 MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct HF_TOKEN: # Change to your own huggingface token, or use --env to pass. @@ -47,9 +48,8 @@ See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypil run: | conda activate vllm echo 'Starting vllm api server...' - python -u -m vllm.entrypoints.openai.api_server \ + vllm serve $MODEL_NAME \ --port 8081 \ - --model $MODEL_NAME \ --trust-remote-code \ --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ 2>&1 | tee api_server.log & @@ -131,6 +131,7 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut ports: 8081 # Expose to internet traffic. envs: + PYTHONUNBUFFERED: 1 MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct HF_TOKEN: # Change to your own huggingface token, or use --env to pass. @@ -146,9 +147,8 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut run: | conda activate vllm echo 'Starting vllm api server...' - python -u -m vllm.entrypoints.openai.api_server \ + vllm serve $MODEL_NAME \ --port 8081 \ - --model $MODEL_NAME \ --trust-remote-code \ --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ 2>&1 | tee api_server.log @@ -243,6 +243,7 @@ This will scale the service up to when the QPS exceeds 2 for each replica. ports: 8081 # Expose to internet traffic. envs: + PYTHONUNBUFFERED: 1 MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct HF_TOKEN: # Change to your own huggingface token, or use --env to pass. @@ -258,9 +259,8 @@ This will scale the service up to when the QPS exceeds 2 for each replica. run: | conda activate vllm echo 'Starting vllm api server...' 
- python -u -m vllm.entrypoints.openai.api_server \ + vllm serve $MODEL_NAME \ --port 8081 \ - --model $MODEL_NAME \ --trust-remote-code \ --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ 2>&1 | tee api_server.log diff --git a/docs/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md index af0f0690c68e2..c119878f137a4 100644 --- a/docs/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -6,35 +6,33 @@ It can be quickly integrated with vLLM as a backend API server, enabling powerfu ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment by installing all required packages: + +```bash +pip install vllm streamlit openai +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with a supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve Qwen/Qwen1.5-0.5B-Chat + ``` -- Install streamlit and openai: +1. Use the script: -```bash -pip install streamlit openai -``` +1. Start the streamlit web UI and start to chat: -- Use the script: - -- Start the streamlit web UI and start to chat: - -```bash -streamlit run streamlit_openai_chatbot_webserver.py - -# or specify the VLLM_API_BASE or VLLM_API_KEY -VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" \ + ```bash streamlit run streamlit_openai_chatbot_webserver.py -# start with debug mode to view more details -streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug -``` + # or specify the VLLM_API_BASE or VLLM_API_KEY + VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" \ + streamlit run streamlit_openai_chatbot_webserver.py -![](../../assets/deployment/streamlit-chat.png) + # start with debug mode to view more details + streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug + ``` + + ![Chat with vLLM assistant in Streamlit](../../assets/deployment/streamlit-chat.png) diff --git a/docs/deployment/integrations/kaito.md b/docs/deployment/integrations/kaito.md new file mode 100644 index 0000000000000..ff050d3eeaf47 --- /dev/null +++ b/docs/deployment/integrations/kaito.md @@ -0,0 +1,5 @@ +# KAITO + +[KAITO](https://kaito-project.github.io/kaito/docs/) is a Kubernetes operator that supports deploying and serving LLMs with vLLM. It offers managing large models via container images with built-in OpenAI-compatible inference, auto-provisioning GPU nodes and curated model presets. + +Please refer to [quick start](https://kaito-project.github.io/kaito/docs/quick-start) for more details. diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 28031f01f85e8..8eb7f8d81275d 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,6 +1,6 @@ # Llama Stack -vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . +vLLM is also available via [Llama Stack](https://github.com/llamastack/llama-stack). 
To install Llama Stack, run @@ -8,9 +8,9 @@ To install Llama Stack, run pip install llama-stack -q ``` -## Inference using OpenAI Compatible API +## Inference using OpenAI-Compatible API -Then start Llama Stack server pointing to your vLLM server with the following configuration: +Then start the Llama Stack server and configure it to point to your vLLM server with the following settings: ```yaml inference: @@ -20,15 +20,15 @@ inference: url: http://127.0.0.1:8000 ``` -Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) for more details on this remote vLLM provider. +Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/providers/inference/remote_vllm.html) for more details on this remote vLLM provider. -## Inference via Embedded vLLM +## Inference using Embedded vLLM -An [inline vLLM provider](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/inference/vllm) +An [inline provider](https://github.com/llamastack/llama-stack/tree/main/llama_stack/providers/inline/inference) is also available. This is a sample of configuration using that method: ```yaml -inference +inference: - provider_type: vllm config: model: Llama3.1-8B-Instruct diff --git a/docs/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md index fae392589c060..2f1894ccf0022 100644 --- a/docs/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -55,7 +55,7 @@ sudo kubectl port-forward svc/vllm-router-service 30080:80 And then you can send out a query to the OpenAI-compatible API to check the available models: ```bash -curl -o- http://localhost:30080/models +curl -o- http://localhost:30080/v1/models ``` ??? console "Output" @@ -78,7 +78,7 @@ curl -o- http://localhost:30080/models To send an actual chatting request, you can issue a curl request to the OpenAI `/completion` endpoint: ```bash -curl -X POST http://localhost:30080/completions \ +curl -X POST http://localhost:30080/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "facebook/opt-125m", diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index cad801a4312cc..d3fda7eb6fb6e 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: - [Helm](frameworks/helm.md) - [InftyAI/llmaz](integrations/llmaz.md) +- [KAITO](integrations/kaito.md) - [KServe](integrations/kserve.md) - [KubeRay](integrations/kuberay.md) - [kubernetes-sigs/lws](frameworks/lws.md) @@ -380,7 +381,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ### Startup Probe or Readiness Probe Failure, container log contains "KeyboardInterrupt: terminated" -If the startup or readiness probe failureThreshold is too low for the time needed to startup the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: +If the startup or readiness probe failureThreshold is too low for the time needed to start up the server, Kubernetes scheduler will kill the container. A couple of indications that this has happened: 1. container log contains "KeyboardInterrupt: terminated" 2. 
`kubectl get events` shows message `Container $NAME failed startup probe, will be restarted` diff --git a/docs/design/arch_overview.md b/docs/design/arch_overview.md index 6b70867760259..f1300a73c26c2 100644 --- a/docs/design/arch_overview.md +++ b/docs/design/arch_overview.md @@ -69,6 +69,11 @@ Sometimes you may see the API server entrypoint used directly instead of via the python -m vllm.entrypoints.openai.api_server --model ``` +!!! warning + + `python -m vllm.entrypoints.openai.api_server` is deprecated + and may become unsupported in a future release. + That code can be found in . More details on the API server can be found in the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) document. diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md new file mode 100644 index 0000000000000..f88a29f6eadd8 --- /dev/null +++ b/docs/design/cuda_graphs.md @@ -0,0 +1,241 @@ +# CUDA Graphs + +This write-up introduces the new CUDA Graphs modes in vLLM v1 beyond previous [torch.compile integration](torch_compile.md). To summarize, we: + +1. Added flexible `cudagraph_mode` configuration +2. Made full CUDA Graphs support orthogonal to compilation +3. Introduced a CUDA Graphs dispatcher as a central controller that picks the desired runtime mode and CUDA Graphs per batch automatically + +In this document we will discuss the: + +* [Motivation](#motivation) +* [CUDA Graphs modes](#cudagraphmodes) +* [Detailed design](#detailed-design) +* [Example usage of the different CUDA Graphs modes](#usage-guide) + +!!! note + In this document, we refer to pure decode (`max_query_len=1`) or speculative decode (`max_query_len =1+num_spec_tokens`) as **uniform decode** batches, and the opposite would be **non-uniform** batches (i.e., prefill or mixed prefill-decode batches). + +!!! note + The following contents are mostly based on the last commit of . + +## Motivation + +Initial piecewise compilation was built to allow piecewise cudagraph capture, excluding cudagraph-unsupported operations (mainly attention). This allowed some speedup from cudagraphs while maintaining compatibility with all attention backends. We later added support for "full cudagraphs" by not compiling piecewise, so that we could further reduce the latency in cases where attention supported cudagraphs. However, this tight coupling between compilation and cudagraph capture led to an all-or-nothing experience with little flexibility. Many attention backends also weren’t ready for unified "full" CUDA Graphs capture (e.g., only FlashAttention 3 supports it currently) or only support CUDA Graphs for pure decode batches (e.g., Flashinfer, FlashMLA, and Mamba, etc.). That led to confusing performance/compatibility tradeoffs, inconsistent CUDA Graphs support, and increasingly complex code structure. + +This led us to seek a more fine-grained CUDA Graphs solution with the following features: + +* Explicitly aware of CUDA Graphs for prefill/mixed or (uniform-)decode batch and capture them separately. +* Separate CUDAGraph capture logic from compilation (as much as feasible) for feature orthogonality, which suggest: + * Capturing piecewise and full cudagraphs using the same compiled graph, and + * Full cudagraph capture without compilation. +* Dispatch between full and piecewise cudagraph at runtime depending on batch composition. +* Centralized control of CUDAGraph behavior for reduced code complexity and allowed more extendibility. 
+
+These features allow the most flexibility for cudagraph capture and compilation for all kinds of startup/performance tradeoffs and feature support.
+
+## `CudagraphModes`
+
+[CUDAGraphMode][vllm.config.compilation.CUDAGraphMode] is the single knob you tune in `CompilationConfig.cudagraph_mode`:
+
+* `NONE` — turn CUDA Graphs off. Good for debugging.
+* `PIECEWISE` — a single-mode strategy (and past default). It is the most flexible: attention or other CUDA Graphs-incompatible operations stay eager, everything else goes into CUDA Graphs. Requires piecewise compilation.
+* `FULL` — a single-mode strategy, which only captures full CUDA Graphs for non-uniform batches; uniform-decode batches reuse the CUDA Graph of the non-uniform batch with the same batch_size, since they are compatible. This can be good for small models or workloads with small prompts.
+* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no CUDA Graphs for prefill/mixed batches; suitable for decode instances in a P/D setup where prefill is not as important, since this way we can save the memory needed for `PIECEWISE` CUDA Graphs.
+* `FULL_AND_PIECEWISE` — (default mode) full CUDA Graph for uniform decode, piecewise CUDA Graphs for others; generally the most performant setting, especially for low latency with small models or MoEs, but it also requires the most memory and takes the longest to capture.
+
+Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_PIECEWISE` for better performance (for pooling models, it's still `PIECEWISE`). Otherwise, e.g. if piecewise compilation is unavailable, we default to `NONE`.
+
+While `NONE`, `PIECEWISE`, and `FULL` are single-mode configurations and simply equivalent to the past implementations of eager execution, piecewise CUDA Graphs, and full CUDA Graphs respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly added dual-mode configurations, which require dispatching to switch between concrete runtime modes dynamically according to the runtime batches.
+
+!!! note
+    Here, the single modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If a dual-mode is used, the dispatcher will always dispatch to one of its member modes (plus a potential `NONE` if no suitable CUDA Graph is available), depending on the batch composition.
+
+Although cascade attention itself is not cudagraph compatible, it now works with all possible cudagraph mode configurations. If a batch uses cascade attention, it always gets dispatched to `PIECEWISE` mode if available (otherwise `NONE`).
+
+!!! note
+    Not all CUDA Graph modes are compatible with every attention backend. We automatically "downgrade" modes to the closest supported mode. For example, if a backend only supports CUDA Graphs for pure decode/uniform batches, we convert `FULL` to `FULL_AND_PIECEWISE` if piecewise compilation is enabled, and `FULL_DECODE_ONLY` otherwise.
+
+## Detailed Design
+
+### Overview
+
+The new CUDA Graphs logic is built on top of piecewise compilation and supports dual CUDA Graphs runtime mode switching. The system contains the following core components:
+
+* [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper]: wrapper that handles CUDAGraph capture & replay on the wrapped callable.
+* [CudagraphDispatcher][vllm.v1.cudagraph_dispatcher.CudagraphDispatcher]: the central controller that contains the single source of truth about CUDA Graphs and handles dispatching between them.
+* [CUDAGraphMode][vllm.config.compilation.CUDAGraphMode]: enum describing the supported and runtime modes (introduced above).
+* [BatchDescriptor][vllm.forward_context.BatchDescriptor]: a unique representation of the runtime batch used for dispatching.
+
+See the following figures for a quick comparison between the previous and current design patterns of CUDA Graphs with inductor compilation. Previously, the CUDA Graphs logic and compilation logic were tightly coupled into the vllm `PiecewiseBackend`, and CUDA Graphs were implicitly dispatched by `batch_size` alone. Now the CUDA Graphs logic is separated into the `CUDAGraphWrapper` class, which is responsible for both full and piecewise CUDA Graphs abilities, and dispatching is **explicitly** done by the `CudagraphDispatcher` via the **runtime mode** plus the `BatchDescriptor` as the **dispatch key**.
+
+**Before:**
+
+![previous_design](../assets/design/cuda_graphs/previous_design.png)
+
+**After:**
+
+![new_design](../assets/design/cuda_graphs/current_design.png)
+
+### `BatchDescriptor`
+
+[BatchDescriptor][vllm.forward_context.BatchDescriptor] is a component within `ForwardContext`, alongside the CUDA Graphs runtime modes, serving as the core structure for dispatching keys at runtime. The prototype is:
+
+```python
+class BatchDescriptor(NamedTuple):
+    num_tokens: int
+    uniform_decode: bool = False
+```
+
+where `num_tokens` can be the padded token length, and `uniform_decode` is determined by whether the `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform decode, and whether the num_scheduled_tokens is divisible by that desired `max_query_len`.
+
+The goal of this structure is to uniquely identify a (padded) batch with the minimal possible items corresponding to a CUDA Graphs item. It is safe to exclude items like `uniform_query_len` because it is currently a constant at runtime for a given setup. For example, it should be either `1` for a common pure decode or `1+num_spec_tokens` for the validation phase of speculative decode.
+
+!!! note
+    The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., to include more items, like `uniform_query_len` to support multiple different uniform decode length settings, or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
+
+### `CudagraphDispatcher`
+
+The [CudagraphDispatcher][vllm.v1.cudagraph_dispatcher.CudagraphDispatcher] takes responsibility for maintaining two sets of valid dispatching keys, one set for `FULL` runtime mode and one set for `PIECEWISE` runtime mode, and dispatches the correct runtime mode and dispatching key before executing the model's forward. It takes in the initial key (a rough batch_descriptor for the padded input), returns the selected runtime mode and the final batch_descriptor, and then tells the `CUDAGraphWrapper` instances about that decision through the forward context. Notice that `CudagraphDispatcher` is the only source of truth for available CUDA Graph keys, and `CUDAGraphWrapper` instances can blindly trust the forward context on what CUDA Graphs to dispatch to. This lets us simplify the wrapper code and centralize the logic in the dispatcher.
+
+The dispatching keys are initialized through the dispatcher's `initialize_cudagraph_keys` method, which is called by the gpu_model_runner after all possible attention backends are initialized.
This is where we can get much fancier in the future and “prepare” all kinds of CUDA Graphs combinations. For now, we just append available keys based on the valid combos of `decode_mode`/`mixed_mode` of `cudagraph_mode` and `cudagraph_capture_sizes` in the compilation config.
+
+The dispatch code looks like:
+
+```python
+batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=...)
+runtime_mode, batch_descriptor = cudagraphdispatcher.dispatch(batch_descriptor)
+# execution
+with set_forward_context(...,
+        cudagraph_runtime_mode=runtime_mode,
+        batch_descriptor=batch_descriptor):
+    output = self.model(...)
+```
+
+Inside the `dispatch()` method, the dispatcher searches the proper CUDA Graphs runtime mode and the existing dispatching keys for a return. We basically search the existing keys following the priority `FULL` > `PIECEWISE` > `NONE`. If the dispatching key does not exist, we default to returning `NONE` mode for eager execution. The implementation can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/v1/cudagraph_dispatcher.py#L91).
+
+Here is a simplified illustration of the workflow at runtime in the model executor:
+![executor_runtime](../assets/design/cuda_graphs/executor_runtime.png)
+
+### `CUDAGraphWrapper`
+
+A [CUDAGraphWrapper][vllm.compilation.cuda_graph.CUDAGraphWrapper] instance wraps a runnable and simply mimics the runnable with appended CUDA Graphs abilities. Each wrapper instance is bound to a specific `runtime_mode`, which is restricted to `PIECEWISE` or `FULL` mode, and takes responsibility for capturing/replaying and passing through (directly calling) the runnable. At runtime, each wrapper would:
+
+1. Inspect the runtime_mode and batch_descriptor (dispatching key) from the global forward context.
+2. If runtime_mode is `NONE` or runtime_mode does not match the mode of the wrapper, just call the runnable directly.
+3. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, the wrapper will perform CUDA Graphs capture (if the key does not exist, create a new entry and cache it) or replay (if the key exists in the cache).
+
+The above steps are based on the assumption that the CUDA Graphs wrapper directly trusts what’s in the forward context (controlled by the dispatcher). This lets us simplify and centralize the logic, reducing the complexity as well as the risk of mismatched state between the wrappers and the dispatcher. It also allows reusing the wrapper class for both `FULL` and `PIECEWISE` runtime modes. See the implementation [here](https://github.com/vllm-project/vllm/blob/f751e50b7a2aae3110d83ed0d88202fc91b3e78a/vllm/compilation/cuda_graph.py#L106).
+
+#### Nested Wrapper design
+
+The core mechanism that makes full CUDA Graphs and piecewise CUDA Graphs coexist and stay compatible is the nested CUDA Graphs wrapper design, built on top of piecewise compilation with only a single piecewise FX graph. We wrap a `FULL` mode wrapper outside the entire model for the full CUDA Graphs functionality; meanwhile, each piecewise backend is wrapped via a `PIECEWISE` mode wrapper inside the compilation.
+
+The flow chart below should clearly describe how it works.
+![wrapper_flow](../assets/design/cuda_graphs/wrapper_flow.png)
+
+Therefore, for a `FULL` runtime mode, it is safe to capture/replay a full CUDA Graph since the piecewise wrappers are not activated. The situation is similar for `PIECEWISE` mode, as there are no conflicts between the `FULL` mode wrapper and `PIECEWISE` mode wrappers.
For the `NONE` runtime mode, both `FULL` and `PIECEWISE` wrappers would not be activated, so we simply fall through to eager execution. + +### Full CUDA Graph capturing & warm-up + +The CUDA Graphs capturing happens when the runner first calls the model forward (using `_dummy_run`) with a non-`NONE` runtime mode. For full CUDA Graph capture, we explicitly capture different cases (i.e., prefill/mixed batch or uniform_decode batch) by properly setting attention metadata to make sure the underlying attention backends launch the desired kernel routines. To distinguish prefill/mixed batch or uniform_decode batch, the most important property is the `max_query_len` in attn_metadata (true for most attention backends). We set it to the desired `uniform_query_len` for uniform_decode otherwise we make it just the `num_tokens` for a non-uniform_decode batch. + +The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process is now controlled directly by the GPU model runner, where the `NONE` runtime mode is assigned to play an eager execution for warm-up. When warming up for a full CUDA Graph, it is also important to explicitly run attention during the warmup `dummy_run` call. + +## CUDA Graphs Compatibility of Attention Backends + +To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`. + +```python +class AttentionCGSupport(enum.Enum): + """ Constants for the CUDA Graphs support of the attention backend + Here we do not consider the cascade attention, as currently + it is never CUDA Graphs supported.""" + + ALWAYS = 3 + """CUDA Graphs always supported; supports mixed-prefill-decode""" + UNIFORM_BATCH = 2 + """CUDA Graphs supported for batches the only contain query lengths that are + the same, this can be used for spec-decode + i.e. "decodes" are 1 + num_speculative_tokens""" + UNIFORM_SINGLE_TOKEN_DECODE = 1 + """CUDA Graphs supported for batches the only contain query_len==1 decodes""" + NEVER = 0 + """NO CUDA Graphs support""" +``` + +Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture]. + +The following table lists backends that support full CUDA Graphs at the time of writing. 
+ +| Attention Backend | cudagraph_support | Comments | +|:---|:---|:---| +| FlashAttention v2 | `UNIFORM_BATCH` | Actually `ALWAYS` but workaround to fallback to `FULL_AND_PIECEWISE` for performance reason | +| FlashAttention v3 | `ALWAYS` | has unified routine for both batches, so `FULL` mode is good | +| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches | +| AITER FlashAttention | `UNIFORM_BATCH`| | +| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | | +| FlashMLA | `UNIFORM_BATCH` | | +| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | | +| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | | +| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | | + +Unlisted backends are all declared as `NEVER`. + +## Usage guide + +Now the CLI is directly using the uppercase string of cudagraph_mode for compilation_config: `--compilation-config '{"cudagraph_mode": "..."}'`, where `...` should be one of `NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, and `FULL_AND_PIECEWISE`. Note that all `PIECEWISE` related modes require piecewise compilation, and all `FULL` related modes need CUDA Graphs support of attention backends. For example: + +```bash +vllm serve --model meta-llama/Llama-3.1-8B-Instruct --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' +``` + +### Python examples + +```python +import os +os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG") + +import vllm +from vllm.config import CUDAGraphMode + +compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"} +model = vllm.LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + dtype='auto', + compilation_config = compilation_config, + ) +sampling_params = vllm.SamplingParams( + temperature=0, # greedy decoding + max_tokens=1024, +) +outputs = model.generate( + ["My name is John and"], + sampling_params=sampling_params, +) +``` + +### Migration from legacy flags + +Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`: + +* `use_cudagraph=False` → `NONE`. +* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`. +* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy. + +As they are deprecated and will be removed in the next major or minor release, i.e., v0.11.0 or v1.0.0, we recommend using cudagraph_mode instead. + +### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism) + +Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs. + +Long term, we've added the ability to partition the graph in Inductor instead of right after Dynamo. It can be enabled with `CompilationConfig.use_inductor_graph_partition=True` but is currently experimental and only available with `torch>=2.9`. This also increases compilation time as it has to compile the whole graph and cannot reuse piecewise compilation artifacts. Once vLLM supports 2.9, we plan to make this the default approach as it will also speed up piecewise cudagraph capture. 
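+
+As a rough sketch of what that looks like in practice (illustrative only; it assumes a vLLM build where `use_inductor_graph_partition` is exposed through `compilation_config` and a `torch>=2.9` install, as noted above):
+
+```python
+from vllm import LLM
+
+# Hypothetical configuration sketch: let Inductor partition the graph so that
+# full-graph custom passes (e.g. attention fusion) can coexist with
+# CUDA Graphs, while still using a FULL cudagraph mode.
+llm = LLM(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    compilation_config={
+        "cudagraph_mode": "FULL",
+        "use_inductor_graph_partition": True,  # experimental, torch>=2.9 only
+    },
+)
+```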
+ +## About the Performance + +See the following links for examples: + +* [20059#issuecomment-3160858458](https://github.com/vllm-project/vllm/pull/20059#issuecomment-3160858458) +* [20059#issuecomment-3188735226](https://github.com/vllm-project/vllm/pull/20059#issuecomment-3188735226) +* [20059#issuecomment-3219888738](https://github.com/vllm-project/vllm/pull/20059#issuecomment-3219888738) diff --git a/docs/design/dbo.md b/docs/design/dbo.md new file mode 100644 index 0000000000000..d92c47c80f951 --- /dev/null +++ b/docs/design/dbo.md @@ -0,0 +1,88 @@ +# Dual Batch Overlap + +## Motivation + +The core motivation of the DBO system in vLLM is to overlap the sparse all-to-all communication in the MoE layer with the surrounding computation. This system currently only targets DP+EP deployments. + +## Introduction + +The Dual Batch Overlap system works by splitting the batch in the model runner, creating two worker threads, and then running the model on each of these worker threads. When DBO is enabled, yield points within the `FusedMoEModularKernel` allow the two CPU worker threads (also called UBatch threads) to ping-pong between each other so that when one is running compute, the other is waiting on communication. Throughout the code, ubatch may be used as a short form of microbatch; this is an ASCII-friendly version of the short form µ-batch. + +The DBO system includes modifications to `GpuModelRunner` and `ModularKernel`, and defines two utility classes: `UBatchWrapper` and `UBatchContext`. `UBatchWrapper` manages thread lifecycle and CUDA graph execution of the model. `UBatchContext` wraps `ForwardContext` to coordinate synchronization between the two UBatch threads. + +Below is the overlap schedule that is currently implemented in vLLM. + +```python +# Schedule notation legend: +# S = Shared expert +# A0 = MLA qkv proj, +# A1 = Core attn + out proj + MoE gate +# D = Dispatch +# C = Combine + +# Comp: |-A0₀-A1₀-||-MLP₁-||-S₁-MLP₀-||-S₀-A0₁-A1₁-| +# Comm: |----D₁---||--D₀--||----C₁---||-----C₀-----| +# Order: D₁ send, A0₀, A1₀, D₁ recv, D₀ send, MLP₁, D₀ recv, +# C₁ send, S₁, MLP₀, C₁ recv, C₀ send, S₀, A0₁, A1₁, C₀ recv. +# MLP_SHARED_OVERLAP = "mlp_shared_overlap" +``` + +## Running with DBO + +To enable the DBO system pass in the `--enable-dbo` argument to your vllm serve command. This must be run in conjunction with `--data-parallel-size N` where N is greater than 1 and `--enable-expert-parallel`. Additionally, there are two configuration knobs. + +* `--dbo-decode-token-threshold` the minimum number of tokens in a decode-only batch required to enable DBO for that batch +* `--dbo-prefill-token-threshold` the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch + +Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests. + +Below is a command that will spin up a two DP rank server with expert parallelism and DBO enabled. 
+EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo` + +Note that there must be at least two GPUs visible in `CUDA_VISIBLE_DEVICES` + +## DBO Components + +* GPUModelRunner +* UBatchWrapper +* UBatchContext + +### GPU Model Runner + +The batch is split into microbatches by the `GPUModelRunner` class. This is accomplished in two steps. First, coordination across all DP ranks is performed to determine whether microbatching will be applied. Microbatching must be uniform across all DP ranks. If microbatching is not feasible for any DP rank, it is disabled for all ranks. If all DP ranks are going to microbatch, the total number of tokens is padded up to the max number of tokens amongst all ranks. If any rank would end up with an empty second microbatch after the padding is applied, microbatching will be aborted and no ranks will microbatch. Once microbatching has been initiated by all ranks, the second step is performed. The `CommonAttentionMetadata` is sliced in half by the `GPUModelRunner` so that there is one attention metadata per-microbatch. + +### UBatchWrapper + +gpu_ubatch_wrapper + +The `UBatchWrapper` class is a model wrapper that's responsible for all of the thread, UBatchContext, and CUDA graph management for DBO. It's designed to be relatively transparent to the GPU Model Runner. + +The implementation runs the model twice, once for each microbatch. Each model invocation occurs within a UBatch thread. These threads are launched in parallel and are synchronized using the `UBatchContext`. Each thread is provided with a sliced version of the attention metadata that is used to run its half of the batch. + +CUDA graphs for DBO are entirely managed by the `UBatchWrapper`. Because of this, DBO only supports running with Full CUDA graphs. However, once a DBO CUDA graph has been captured, it can be replayed without any multithreading or CPU synchronization. + +#### Interfaces + +The `__init__` method takes in the model, VllmConfig, CUDAGraphMode, and device. + +The `forward` method exclusively takes in model arguments. It determines whether or not to run with DBO based on whether a `ubatch_slices` object is present in the `forward_context`. Otherwise, the model is run without DBO. + +### UBatchContext + +ubatch_context + +The `UBatchContext` class is a `ForwardContext` wrapper class that is used by the `UBatchWrapper` class to synchronize the two UBatch threads. It should only be instantiated by using `make_ubatch_contexts`. + +When one of the UBatch threads reaches a `dbo_yield` call, it pauses, and starts the other thread which will run until it reaches the same `dbo_yield` call. This "ping-pong" dynamic continues, with threads swapping at each `dbo_yield call`, until the model's execution is complete. + +The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` calls in the `FusedMoEModularKernel.forward` method. + +#### Interfaces + +The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization. + +The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalize` class in the other UBatch thread’s `UBatchContext`. 
The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel. + +The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists. + +The `dbo_yield` method puts the current thread to sleep and wakes up the other UBatch thread. diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 4b917ab408eec..ee5701989265b 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts ### FusedMoEPrepareAndFinalize -The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions. -The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) +The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions. +The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section) ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks") @@ -133,12 +133,12 @@ class FusedMoEModularKernel: Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, * PplxPrepareAndFinalize type is backed by Pplx All2All kernels, -* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughtput All2All kernels, and +* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to setup the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). 
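+
+As a quick illustration of why `prepare_no_receive` is useful, here is a minimal sketch (not vLLM's actual interfaces; the real `prepare`/`finalize` signatures take many more arguments) of how the returned "receiver" callback lets shared-expert work overlap with the All2All Dispatch:
+
+```python
+# Illustrative pseudo-flow only: signatures are heavily simplified.
+def moe_forward(prepare_finalize, fused_experts, shared_experts, hidden_states):
+    if prepare_finalize.has_prepare_no_receive():
+        # Launch quantization + All2All Dispatch without waiting for results.
+        receiver = prepare_finalize.prepare_no_receive(hidden_states)
+        shared_out = shared_experts(hidden_states)  # overlapped with comms
+        dispatched = receiver()                     # wait for dispatch results
+    else:
+        dispatched = prepare_finalize.prepare(hidden_states)
+        shared_out = shared_experts(hidden_states)
+
+    expert_out = fused_experts(dispatched)
+    # finalize() performs the All2All Combine (and possibly the TopK weight
+    # application and reduction).
+    return prepare_finalize.finalize(expert_out) + shared_out
+```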
#### Step 2: Add a FusedMoEPrepareAndFinalize Type @@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked. +`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False. + +`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked. + `FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked. `FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise. @@ -183,7 +187,7 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking #### maybe_make_prepare_finalize -The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. +The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. Please refer to the implementations in, * `ModelOptNvFp4FusedMoE` @@ -198,7 +202,7 @@ Please refer to the implementations in, * `CompressedTensorsW8A8Fp8MoECutlassMethod` * `Fp8MoEMethod` * `ModelOptNvFp4FusedMoE` -dervied classes. +derived classes. #### init_prepare_finalize @@ -226,7 +230,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. 
### How To Profile @@ -238,30 +242,8 @@ Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kern ## FusedMoEPrepareAndFinalize Implementations -The following table lists the `FusedMoEPrepareAndFinalize` implementations at the time of writing, - -| Implementation | Type | Comments | -| :--- | :--- | :--- | -| DeepEPHTPrepareAndFinalize | Contiguous / Non-Batched | Uses the DeepEP High-Throughput all2all kernels. | -| DeepEPLLPrepareAndFinalize | Batched | Uses the DeepEP Low-Latency all2all kernels. | -| PplxPrepareAndFinalize | Batched | Uses the Perplexity all2all kernels. | -| FlashInferCutlassMoEPrepareAndFinalize | Contiguous | | -| MoEPrepareAndFinalizeNoEP | Contiguous | This implementation is used when there is no EP. i.e. no all2all kernels are invoked. | -| BatchedPrepareAndFinalize | Batched | A reference prepare/finalize class that reorganizes the tokens into expert batched format, i.e. E x max_num_tokens x K. (Doesn’t use any all2all kernels. This is primarily used in unit testing) | +See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses. ## FusedMoEPermuteExpertsUnpermute -The following table lists the `FusedMoEPermuteExpertsUnpermute` implementations at the time of writing, - -| Implementation | Type | Comment | -| :--- | :--- | :--- | -| BatchedDeepGemmExperts | Batched | Uses the DeepGemm’s Masked Grouped Gemm kernels for the fused_moe operation. | -| BatchedTritonExperts | Batched | Uses a Triton Kernel for the Batched matmuls. | -| BatchedTritonOrDeepGemmExperts | Batched | Chooses either the `BatchedDeepGemmExperts` or `BatchedTritonExperts` based on environment settings. | -| DeepGemmExperts | Contiguous / Non-Batched | Uses DeepGemm’s Grouped Gemm kernels for fused_moe operation. | -| TritonExperts | Contiguous / Non-Batched | Uses a Triton Kernel for fused_moe matmuls. | -| TritonOrDeepGemmExperts | Contiguous / Non-Batched | Chooses either the `DeepGemmExperts` or `TritonExperts` based on fused_moe inputs. | -| CutlassExpertsFP8 | Supports both Batched and Contiguous formats | Uses Cutlass Grouped Gemm implementations for the fp8 matmuls. | -| CutlassExpertsFP4 | Supports both Batched and Contiguous formats | Uses Cutlass Grouped Gemm implementations for the fp4 matmuls. | -| FlashInferExperts | Contiguous | Uses fused_moe operation from FlashInfer | -| NaiveBatchedExperts | Batched | Reference Batched Experts implementation. Primarily used in unit tests. | +See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts. diff --git a/docs/design/huggingface_integration.md b/docs/design/huggingface_integration.md index 5a7582c86d49f..412ce658b92a2 100644 --- a/docs/design/huggingface_integration.md +++ b/docs/design/huggingface_integration.md @@ -1,31 +1,31 @@ # Integration with Hugging Face -This document describes how vLLM integrates with HuggingFace libraries. We will explain step by step what happens under the hood when we run `vllm serve`. +This document describes how vLLM integrates with Hugging Face libraries. We will explain step by step what happens under the hood when we run `vllm serve`. -Let's say we want to serve the popular QWen model by running `vllm serve Qwen/Qwen2-7B`. +Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qwen2-7B`. 1. The `model` argument is `Qwen/Qwen2-7B`. 
vLLM determines whether this model exists by checking for the corresponding config file `config.json`. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L162-L182) for the implementation. Within this process: - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path. - - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works. - - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. + - If the `model` argument is a Hugging Face model ID consisting of a username and model name, vLLM will first try to use the config file from the Hugging Face local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the Hugging Face cache works. + - If the `model` argument is a Hugging Face model ID but it is not found in the cache, vLLM will download the config file from the Hugging Face model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. 2. After confirming the existence of the model, vLLM loads its config file and converts it into a dictionary. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L185-L186) for the implementation. 3. Next, vLLM [inspects](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L189) the `model_type` field in the config dictionary to [generate](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L190-L216) the config object to use. There are some `model_type` values that vLLM directly supports; see [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L48) for the list. 
If the `model_type` is not in the list, vLLM will use [AutoConfig.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained) to load the config class, with `model`, `--revision`, and `--trust_remote_code` as the arguments. Please note that: - - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. - - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. + - Hugging Face also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, Hugging Face will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. + - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, Hugging Face will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. 4. Subsequently, vLLM applies some historical patches to the config object. These are mostly related to RoPE configuration; see [here](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/config.py#L244) for the implementation. 5. Finally, vLLM can reach the model class we want to initialize. vLLM uses the `architectures` field in the config object to determine the model class to initialize, as it maintains the mapping from architecture name to model class in [its registry](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/registry.py#L80). If the architecture name is not found in the registry, it means this model architecture is not supported by vLLM. For `Qwen/Qwen2-7B`, the `architectures` field is `["Qwen2ForCausalLM"]`, which corresponds to the `Qwen2ForCausalLM` class in [vLLM's code](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/qwen2.py#L364). This class will initialize itself depending on various configs. -Beyond that, there are two more things vLLM depends on HuggingFace for. +Beyond that, there are two more things vLLM depends on Hugging Face for. -1. **Tokenizer**: vLLM uses the tokenizer from HuggingFace to tokenize the input text. 
The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check HuggingFace's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). +1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). -2. **Model weight**: vLLM downloads the model weight from the HuggingFace model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. +2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). 
Please note that: -This completes the integration between vLLM and HuggingFace. +This completes the integration between vLLM and Hugging Face. -In summary, vLLM reads the config file `config.json`, tokenizer, and model weight from the HuggingFace model hub or a local directory. It uses the config class from either vLLM, HuggingFace transformers, or loads the config class from the model's repository. +In summary, vLLM reads the config file `config.json`, tokenizer, and model weight from the Hugging Face model hub or a local directory. It uses the config class from either vLLM, Hugging Face transformers, or loads the config class from the model's repository. diff --git a/docs/design/hybrid_kv_cache_manager.md b/docs/design/hybrid_kv_cache_manager.md new file mode 100644 index 0000000000000..8f17b473adc08 --- /dev/null +++ b/docs/design/hybrid_kv_cache_manager.md @@ -0,0 +1,245 @@ +# Hybrid KV Cache Manager + +!!! warning + This document was written based on commit [458e74](https://github.com/vllm-project/vllm/commit/458e74eb907f96069e6d8a4f3c9f457001fef2ea). This feature is still in its early stage and things may change. + +## What is a hybrid model? + +Many recent "hybrid" LLMs combine multiple attention types within one model. For example: + +1. Sliding window attention (sw) + full attention (full): gpt-oss, Gemma 2/3, Ministral, cohere, etc. +2. Mamba + full: Bamba, Jamba, Minimax, etc. +3. Local chunked attention + full: Llama4 + +To serve these models efficiently, our [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] must: + +1. Allocate different slots to different layer types, for example: + - Full attention layers: reserve slots for **all** tokens. + - Sliding window layers: reserve slots only for the most recent **`sliding_window_size`** tokens. +2. Support layer-specific prefix-cache rules, for example: + - Full attention: a cache hit prefix requires **all** tokens to remain in the KV cache. + - Sliding window: a cache hit prefix only requires the last **`sliding_window_size`** tokens to remain in the KV cache. + +## Definitions + +1. **kv hidden size**: The number of bytes to store one token's KV cache for a single layer. +2. **block**: the memory reserved for kv cache is divided into multiple *blocks* with the same *page size* (defined below) +3. **block size**: number of tokens inside a block +4. **page size**: the physical memory size of a block, defined as: + + $$ + \text{num_layers} \times \text{block_size} \times \text{kv_hidden_size} + $$ + + `num_layers` doesn't mean the total number of layers in the model. The exact number depends on the context in this doc. + + !!! note + This is different from `KVCacheSpec.page_size_bytes` in the code, which is defined as: + + $$ + \text{block_size} \times \text{kv_hidden_size} + $$ + +## Allocation + +### High level idea + +We use a single memory pool for all layer types. The memory pool is split into multiple blocks with the same page size. [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates different numbers of blocks to different layers according to each layer's attention type. + +The core challenge is ensuring every layer type uses the same **page size**. For full-attention-only models, the page size is straightforward, defined as: + +$$ +\text{page_size} = \text{block_size} \times \text{num_hidden_layers} \times \text{kv_hidden_size} +$$ + +However, in hybrid models, `num_hidden_layers` varies by attention type, which would normally produce mismatched page sizes. The cases below show how we unify them.
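+
+To make the page-size arithmetic concrete, the short sketch below (plain Python for illustration only, not vLLM code; the layer counts and `kv_hidden_size` are made-up numbers) shows the mismatch and how grouping layers into equal-sized groups restores a single page size, as worked out in the cases below.
+
+```python
+# Illustration only: why hybrid models need a unified page size.
+KV_HIDDEN_SIZE = 1024  # bytes per token per layer (assumed value)
+BLOCK_SIZE = 16        # tokens per block
+
+def page_size(num_layers: int) -> int:
+    # page_size = num_layers * block_size * kv_hidden_size
+    return num_layers * BLOCK_SIZE * KV_HIDDEN_SIZE
+
+# Full-attention-only model: every layer shares one page size.
+print(page_size(32))                 # 524288
+
+# Hybrid model with 10 full-attention and 20 sliding-window layers:
+# computing a page size per attention type gives mismatched values.
+print(page_size(10), page_size(20))  # 163840 vs. 327680
+
+# Grouping the layers into KV cache groups of 10 layers each
+# (1 full group + 2 sliding-window groups) gives every group the
+# same page size, so a single memory pool can serve all of them.
+print(page_size(10))                 # 163840 for each of the 3 groups
+```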
+ +### Case 1: toy model + +Let's start with a toy example: a model has 1 full attention layer and 3 sliding window attention layers. All layers have the same `kv_hidden_size`. + +We let each block hold `block_size` tokens for one layer, so: + +$$ +\text{page_size} = \text{kv_hidden_size} \times \text{block_size} +$$ + +[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates a different number of blocks to each layer. + +This case is only a toy example. For real models, please refer to the following cases. + +### Case 2: same `kv_hidden_size` and a regular pattern + +Now consider a model with more layers, e.g., 20 sliding window attention layers and 10 full attention layers with the same `kv_hidden_size`. Calling the allocator once per layer (30 calls) works but becomes inefficient. As a solution, we group the allocation of layers that need the same number of blocks to reduce the number of calls. + +The grouping is feasible because there is usually a simple ratio between the number of different types of layers. For example: + +- Gemma-2: 1 sw : 1 full +- Llama 4: 3 local : 1 full + +Our example can be regarded as 2 sw : 1 full. We can allocate blocks as if there are 2 sw and 1 full in the model, and repeat the result 10 times to generate the `block_ids` for the 30 layers. The page size becomes: + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +Assume `block_size` is 16, the sliding window size is 32, and the request length is 112; then for the above example model, we need to allocate 11 blocks (0-6 for full, 7-8 for sw group 1, 9-10 for sw group 2). + +![Allocation Result](../assets/design/hybrid_kv_cache_manager/basic_grouping_example.png) + +Here, "/" denotes no block needed (sliding‑window layers don't need slots for early tokens). + +See the formal definition below. The layers are divided into multiple *KV Cache Groups* so that there is: + +1. **Identical attention type inside each group**: Each group only contains layers with the same attention type, which therefore need the same number of blocks for a given request. This enables layers in the same group to share the same block ids without memory waste. +2. **Identical page size across groups**: Because our memory pool only has one page size. + +Our example model is divided into 3 KV cache groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +Obviously, it satisfies rule 1. For rule 2, all 3 groups have + +$$ +10 \times \text{kv_hidden_size} \times \text{block_size} +$$ + +as their page size. + +### Case 3: same `kv_hidden_size` and no regular pattern + +Unfortunately, not all models have such a simple ratio, and the approach in Case 2 would produce too many small groups. For example, Gemma-3-27b has 52 sliding window attention layers and 10 full attention layers. With the constraints in case 2, it would be 26 sliding window groups and 5 full attention groups, each containing 2 layers. The allocation is still inefficient. To reduce the number of kv cache groups, we group layers using the smallest layer count among all attention types. For example, min(52, 10)=10 layers per group in Gemma-3-27b. Then the grouping result is: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) +- ...
+- Group 6: 10 sliding window attention layers (sw.40 - sw.49) +- Group 7: 2 sliding window attention layers (sw.50 - sw.51) and 8 padding layers + +We will update this algorithm if this heuristic leads to a bad result when a new model comes out (e.g., for 20 full + 30 sw, the group size should be 10 instead of 20). + +This case happens in Gemma-3 series models, and in models from case 2 that use eagle speculative decoding, which introduces one full attention layer. The solution has some memory waste and is not perfect. Please report any cases where padding overhead becomes unacceptable so we can refine the algorithm. + +### Case 4: different `kv_hidden_size` (mainly hybrid mamba models) + +Some architectures (e.g., Bamba, Jamba, Minimax) interleave standard attention layers with Mamba layers, where each Mamba layer's state size per token can be much larger than the attention layers' `kv_hidden_size`. Because we only support a single page size across all groups, we must reconcile these differing hidden sizes. + +The current algorithm is: + +1. Increase the `block_size` of attention layers until + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \ge \text{state_size}_{\text{mamba}} + $$ +2. Pad the mamba state per layer to + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} + $$ +3. Apply the grouping strategy in case 3. + +!!! note + This can lead to a `block_size` of more than 400 for attention layers, which is too large. Another padding strategy is to increase `block_size` until + + $$ + \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \times \text{num_attn_layers} \ge \text{state_size}_{\text{mamba}} + $$ + + This padding strategy is still a work in progress. + +### Case 5: KV sharing + +KV sharing refers to a layer using the KV cache of another layer, e.g., gemma-3n. +In these models, [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] ignores all layers with kv sharing and only allocates KV cache for layers that need kv cache, and some patches are made in the model runner to apply the allocation result to kv sharing layers. + +## Prefix caching + +For simplicity, we assume `block_size=1` in this section. + +### High level idea + +The block pool uses a dict similar to `tuple(block_hash, group_id) -> block` to cache the full blocks. That means the same tokens of different groups are cached and evicted independently. + +When a new request comes in, we check the cache hit prefix of each group, and return the intersection of these groups as the cached prefix of the request. See below for the detailed algorithm for checking the cache hit of one group & performing the intersection. + +### Case 0: full attention only models + +For full attention layers, blocks are allocated for all tokens in the request. For details on the underlying design, see [Prefix Caching](prefix_caching.md). + +To find the longest cache hit prefix of a request, we enumerate from left (the first block) to right (the last block), checking whether each block is cached, and exit at the first cache miss. For example, we will return the first 7 tokens (0-6) as the cache hit prefix in the example below (blue blocks are cached): + +![Prefix Caching of Full Attention](../assets/design/hybrid_kv_cache_manager/full_attn.png) + +### Case 1: sliding window attention only models + +For sliding window attention layers, a naive implementation for memory allocation is to allocate `sliding_window_size` blocks and fill in the blocks in a round-robin way.
But this naive implementation is not compatible with prefix caching, so we didn't pick this design. In vLLM, we allocate different blocks for different tokens and free blocks that are outside the sliding window. + +For a new request, the cache hit prefix only requires the last `sliding_window_size - 1` tokens to be cached. +Let's say `sliding_window_size = 4` and `block_size = 1`, and the request is a 15-token prompt (blue blocks are cached): + +![Prefix Caching of Sliding Window Attention](../assets/design/hybrid_kv_cache_manager/sw_attn.png) + +There are 3 possible cache hit prefixes: + +- cache hit length 5, compute prefill with [2, 3, 4] → [5, 6, …, 14] +- cache hit length 6, compute prefill with [3, 4, 5] → [6, 7, …, 14] +- cache hit length 14, compute prefill with [11, 12, 13] → [14] (most efficient) + +We can check the cache hit from right to left, and exit early when we find a match. This is the opposite of full attention, where we check from left to right and exit early when the match fails. One potential downside (compared to full attention) is that we end up iterating over the entire list of tokens when there's no match, which is a common case. This could potentially cause non-negligible overhead, but it is fine for full + sliding window models, as discussed below. + +### Case 2: sliding window attention + full attention models + +The first problem is how to find the cache hit prefix. We need to "intersect" the cache hits of global and sliding window attention layers by: + +1. Getting the longest cache hit for full attention (scanning from left to right). +2. Getting the longest cache hit for sliding window attention that is within that length, implemented by checking cache hits from right to left starting from the cache hit length of full attention. + +This guarantees that the resulting cache hit of the sliding window attention layers is also a cache hit of the full attention layers. It is more efficient than finding all possible prefixes of each group and doing the intersection, because our approach can exit early if there is no cache hit. + +The algorithm applies to models with exactly two attention types, full attention + X, where X can be an arbitrary efficient attention algorithm like sliding window, llama 4 local attention, or mamba. It doesn't support models without full attention layers, or models with more than 2 types of attention. This is enough for most hybrid models at the time of writing this doc. + +The second question is the cache eviction policy. For now, we use one LRU queue for all kv cache groups. The blocks are added to the LRU queue when freed, either because the request is finished or because the block is out of the sliding window. + +### Case 3: mamba models + +Prefix caching support for mamba models is a work in progress. Once implemented, models with mamba layers + full attention layers can be supported via the full attention + X algorithm in case 2. + +## Implementation + +### Overview + +![Overview of Hybrid KV Cache Manager](../assets/design/hybrid_kv_cache_manager/overview.png) + +The `KVCacheManager` is organized into 3 layers: + +- **[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager]**: The interface between the scheduler and the kv cache management system. +- **[KVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinator]**: Coordinates per-group SingleTypeKVCacheManagers to generate the allocation result of a request.
Depending on the model's configuration, one of these coordinators is chosen: + - **[KVCacheCoordinatorNoPrefixCache][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinatorNoPrefixCache]**: Used when prefix caching is disabled. + - **[UnitaryKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.UnitaryKVCacheCoordinator]**: Used when there is only one KV cache group. The prefix caching logic is simplified as no intersection is needed. + - **[HybridKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.HybridKVCacheCoordinator]**: Handles exactly two KV cache groups (must include one full‑attention group plus one other efficient‑attention group). Other cases are not implemented. You can disable prefix caching to use the KVCacheCoordinatorNoPrefixCache. +- **[SingleTypeKVCacheManager][vllm.v1.core.single_type_kv_cache_manager.SingleTypeKVCacheManager]**: Each instance manages allocation and prefix caching for one KV cache group, implementing the attention‑type–specific logic (e.g., full attention, sliding window, Mamba). + +The blue box in the above figure shows the case with 10 full attention layers and 20 sliding window attention layers. This configuration therefore: + +- uses `HybridKVCacheCoordinator` +- uses 1 `FullAttentionManager` and 2 `SlidingWindowManager`s for the 3 `KVCacheGroup`s. + +### Memory Layout + +For a model with n `KVCacheGroup`s, each with m layers, we allocate m buffers. Each buffer is shared by n layers, one from each group. + +The following figure is for a model with 10 full attention layers (full.0 - full.9) and 20 sliding window attention layers (sw.0-sw.19). It follows "case 2" in the "Allocation" section and is divided into 3 groups: + +- Group 0: 10 full attention layers (full.0 - full.9) +- Group 1: 10 sliding window attention layers (sw.0 - sw.9) +- Group 2: 10 sliding window attention layers (sw.10 - sw.19) + +And for a request, we allocate 11 blocks with `block_id` 0-6 to group 0, 7-8 to group 1, and 9-10 to group 2. + +In this example, the physical memory is divided into 10 buffers (`KVCacheTensor` 0 - `KVCacheTensor` 9). Each buffer is shared by 3 layers (e.g., `KVCacheTensor` 0 is shared by full.0 from group 0, sw.0 from group 1, and sw.10 from group 2) and is divided into pieces with size `block_size * kv_hidden_size`. The KV cache of these 3 attention layers is saved to different pieces of the buffer based on the allocated `block_ids`: + +![Example Memory Layout](../assets/design/hybrid_kv_cache_manager/memory_layout.png) + +!!! note + One logical "block" is mapped to 10 pieces in the 10 buffers of the physical memory. diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md new file mode 100644 index 0000000000000..e70ee4a076e54 --- /dev/null +++ b/docs/design/io_processor_plugins.md @@ -0,0 +1,78 @@ +# IO Processor Plugins + +IO Processor plugins are a feature that allows pre- and post-processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model's `encode` method. One potential use-case of such plugins is using vLLM to generate multi-modal data. For example, a user could feed an image to vLLM and get an image back as output. + +When performing inference with IO Processor plugins, the prompt type is defined by the plugin, and the same holds for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user.
As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint. + +## Writing an IO Processor Plugin + +IO Processor plugins implement the `IOProcessor` interface (): + +```python +IOProcessorInput = TypeVar('IOProcessorInput') +IOProcessorOutput = TypeVar('IOProcessorOutput') + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process(self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + collected_output = [item async for i, item in model_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + raise NotImplementedError +``` + +The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. +The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. +The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. + +The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is available here . + +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online () and offline () inference examples. + +## Using an IO Processor plugin + +IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded: + +1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode. +2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json). + +The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json). 
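+
+For illustration, a minimal plugin compatible with the interface above could look like the sketch below. This is not a complete or tested plugin: the import path, the `MyInput`/`MyOutput` types, and the way the pooled outputs are consumed are assumptions, and the online-serving conversion is left unimplemented.
+
+```python
+# Sketch of a hypothetical IO Processor plugin (names and the import path
+# are assumptions, not part of vLLM's documented API).
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from vllm.plugins.io_processors.interface import IOProcessor  # assumed path
+
+
+@dataclass
+class MyInput:
+    texts: list[str]
+
+
+@dataclass
+class MyOutput:
+    num_prompts: int
+
+
+class MyIOProcessor(IOProcessor[MyInput, MyOutput]):
+
+    def parse_request(self, request: Any) -> MyInput:
+        # Validate the raw request and convert it into the plugin input type.
+        return MyInput(texts=list(request["texts"]))
+
+    def pre_process(self, prompt: MyInput, request_id: Optional[str] = None,
+                    **kwargs):
+        # Expand the custom input into one text prompt per entry; these are
+        # the prompts the pooling model will actually encode.
+        return list(prompt.texts)
+
+    def post_process(self, model_output, request_id: Optional[str] = None,
+                     **kwargs) -> MyOutput:
+        # A real plugin would read the pooled tensors from each
+        # PoolingRequestOutput here; this sketch only counts them.
+        return MyOutput(num_prompts=len(model_output))
+
+    def output_to_response(self, plugin_output: MyOutput):
+        # Only needed for online serving; constructing the IOProcessorResponse
+        # is omitted in this sketch.
+        raise NotImplementedError
+```
+
+Such a plugin would then be selected by name through one of the two mechanisms described above (the `io_processor_plugin` engine argument or the model's config.json).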
diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md new file mode 100644 index 0000000000000..20d78ca3aae2c --- /dev/null +++ b/docs/design/logits_processors.md @@ -0,0 +1,559 @@ +# Logits Processors + +!!! important + Some logits processors design changes are still in progress and the API may + change in the near future. We hope to stabilize this part of the API soon + +This document describes how the vLLM engine interacts with logits processors, and the programming model which vLLM supports for implementing logits processors. + +## Logits Processors Background + +A logits processor adjusts the next-token probability distribution, usually with the intention of steering the model towards a desired type of behavior. + +In vLLM, logits processors operate at batch granularity. During a given engine step, the logits processor consumes a `(num_requests) x (vocab_size)` tensor of raw logits output by the model. For all requests which enable the logits processor, the logits processor applies a transformation to the corresponding row of the logits tensor, while leaving other rows unmodified. The transformed logits tensor is then passed to softmax. + +## Logits Processors in the vLLM engine + +The vLLM engine's persistent batch data structure maintains a list of loaded logits processors. + +In order to operate on the entire batch at once, each logits processor may maintain metadata about the requests in the batch (i.e. each request's logits-processor-specific configuration settings). Therefore, logits processors are stateful. + +In each engine step, the vLLM engine will (1) update each logits processor's internal state and (2) apply logits processors to the model output logits. + +### Updating Logits Processor Internal State + +At the beginning of each engine step, the persistent batch may add, discard and/or reorder requests in response to the scheduler output. After the persistent batch has reorganized, the vLLM engine invokes each logits processor's `update_state()` method. This is necessary to ensure that logits processors' internal states are reorganized to match the new persistent batch state at the beginning of the engine step. + +The pseudocode below shows the process by which the vLLM persistent batch notifies each logits processor of changes in batch state: + +??? code "Model Runner Updates Logits Processor States" + + ``` python + # gpu_model_runner.py + + class GPUModelRunner(...): + + ... + + def execute_model(self, scheduler_output, ...): + self._update_states(scheduler_output) + + ... + + def _update_states(...): + + ... + + # ...update persistent batch to reflect new/finished requests & reordering + # of requests within batch... + + ... + + self.input_batch.refresh_metadata() + + + # gpu_input_batch.py + + class InputBatch: + + ... + + def refresh_metadata(self): + + ... + + # Update each logits processor's state to reflect persistent batch state + batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) + for logit_proc in self.logitsprocs.all: + logit_proc.update_state(batch_update) + + ... 
+ + + # vllm/v1/sample/logits_processor/interface.py + + @dataclass(frozen=True) + class BatchUpdate: + # Batch state-change data structure which is passed to logits processors' + # update_state() methods + + batch_size: int + + removed: Sequence[RemovedRequest] + added: Sequence[AddedRequest] + moved: Sequence[MovedRequest] + + ``` + +### Applying Logits Processors to the Model Output Logits + +After updating persistent batch state, the vLLM model runner performs model inference to obtain logits. Then, the model runner invokes the sampler against the logits. In turn, part of the sampler's operation is to invoke the logits processors' `apply()` methods against the model output logit processors, yielding transformed logits (the `apply()` methods may modify the logits in-place or out-of-place, although in-place is more memory-efficient). This process is shown in the pseudocode below. + +Note that the sampler will access the logits processors via `SamplingMetadata.logitsprocs`. When the vLLM engine constructs `SamplingMetadata` (not shown in the code below), the reference to the list of logits processors is passed from the persistent batch data structure to `SamplingMetadata`. + +??? code "Apply logits processors to model output logits" + + ``` python + # gpu_model_runner.py + + class GPUModelRunner(...): + + ... + + def execute_model(self, scheduler_output, ...): + # (discussed in previous section) + self._update_states(scheduler_output) + + ... + + # ...run model inference to obtain logits... + + ... + + # Invoke sampler, which applies logits processors + sampler_output = self.sampler(logits=logits, + sampling_metadata=sampling_metadata) + + ... + + + # sampler.py + + class Sampler(nn.Module): + + ... + + def forward(self, logits, sampling_metadata): + + ... + + # Apply non-argmax-invariant logits processors to model output logits + for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): + logits = processor.apply(logits) + + sampled = self.sample(logits, sampling_metadata) + + ... + + # ...return sampler output data structure... + + + def sample(self, logits, sampling_metadta) + + ... + + # ...exit early if all requests are greedy-sampling... + + ... + + # Apply argmax-invariant logits processors + for processor in sampling_metadata.logitsprocs.argmax_invariant: + logits = processor.apply(logits) + + ... + + # ...perform sampling and return sampling result... + ``` + +At sampling time, the sampler checks whether all requests in the persistent batch employ greedy sampling. If that is the case, the sampler saves compute by skipping "argmax-invariant" logits processors. Here, "argmax" is shorthand for the token ID with the highest logit value in a given row of the logits tensor (i.e. the token which the model weighted the highest for a given request). + +* An **argmax-invariant logits processor** is a logits processor (such as Min-P) which does not modify the argmax. For example, a logits processor which masks out the lowest-probability tokens will not change which token ID has the max logit. Greedy sampling always picks the highest-logit-value token ID, and so conceptually an argmax-invariant logits processor can be skipped for greedy sampling requests. + +* A **non-argmax-invariant logits processor** is a logits processor which may modify the argmax. For example, a logits processor which masks all tokens except for EOS after a certain number of steps in order to force decoding to terminate might end up masking the max-logit-value token and therefore change the argmax. 
Conceptually, these logits processors cannot be skipped for greedy sampling requests. + +The vLLM logits processor abstraction requires the engine to apply logits processors at batch granularity; therefore in practice the argmax-invariant logits processors can only be skipped when the entire batch uses greedy sampling. + +## Logits Processor Programming Model + +The previous sections alluded to the interfaces which vLLM logits processors must support. This section introduces in full the programming model for implementing logits processors that are compatible with the vLLM engine, including the `LogitsProcessor` base class and its interface methods as well as the `BatchUpdate` data structure for representing persistent batch state changes, both of which are shown in the code below: + +??? code "`LogitsProcessor` base class and `BatchUpdate` data structure" + + ``` python + from abc import ABC, abstractmethod + from collections.abc import Sequence + from dataclasses import dataclass + from enum import Enum, auto + from typing import TYPE_CHECKING, Optional + + import torch + + from vllm import SamplingParams + + if TYPE_CHECKING: + from vllm.config import VllmConfig + + + class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = auto() + # Two-way i1<->i2 req swap within batch + SWAP = auto() + + + # (index, params, prompt_tok_ids, output_tok_ids) tuples for new + # requests added to the batch. + AddedRequest = tuple[int, SamplingParams, list[int], list[int]] + + # (index 1, index 2, directionality) tuples representing + # one-way moves or two-way swaps of requests in batch + MovedRequest = tuple[int, int, MoveDirectionality] + + # Batch indices of any removed requests. + RemovedRequest = int + + + @dataclass(frozen=True) + class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Key assumption: the `output_tok_ids` list (which is an element of each + # tuple in `added`) is a reference to the request's running output tokens + # list; via this reference, the logits processors always see the latest + # list of generated output tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + + class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. + NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update is non-None iff there have been + changes to the batch makeup. 
+ """ + raise NotImplementedError + + ``` + +A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods: + +* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)` + * `vllm_config`: engine configuration data structure + * `device`: hardware accelerator device info + * `is_pin_memory`: flag indicating whether pin memory is available to support logits processor implementation + +* `apply(self, logits: torch.Tensor) -> torch.Tensor`: + * Consume a `(num_requests) x (vocab_size)` logits tensor (`logits`) + * Apply logits processor transformation at batch granularity + * Return a transformed `(num_requests) x (vocab_size)` logits tensor + * You can modify the input logits processors in-place or out-of-place; in-place is more memory-efficient + +* `is_argmax_invariant(self) -> bool`: + * Return `True` if the logits processor is argmax invariant (never changes what is the highest-logit-value token ID for a given request), `False` if the logits processor may modify argmax + * `is_argmax_invariant()` is evaluated once at startup; if `True`, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling + +* `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`: + * Consume a `BatchUpdate` data structure representing persistent batch state changes at the beginning of the current engine step + * Use the `BatchUpdate` members to update logits processor internal state + * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. + +### `BatchUpdate` data structure + +The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`): + +* **Remove:** remove (without replacement) request at index `i` + + * A Remove is represented in `Batchupdate.removed` by an `int` (representing `i`) + + * Effect of remove-at-index on batch: + + ``` text + Batch: [A,B,C] + Remove @ i: 1 + + => + + New Batch: [A,x,C] # Discard B and leave an empty slot + ``` + +* **Add:** add (or replace existing request with) a new request at index `i`. If a request is replaced, its associated state should be discarded. + + * An Add is represented in `Batchupdate.added` as a tuple of + + ``` text + (index, new request SamplingParams, prompt token ids, output token ids) + ``` + + * `prompt token ids` and `output token ids` are references to the request's prompt token ids and output token ids lists, respectively. Note that the output token ids list grows with each engine step, and this growth is visible to the logits processor because output token ids are passed by reference. **This is important for LogitsProcessors that take into account the tokens generated so far**. + + * The implementation of the particular logits processor subclass determines whether or how the fields in the added request tuple are digested into an internal representation. 
For example, a logits processor that does not utilize prompt or output token ids may only need to utilize `index` and `SamplingParams` and discard the other tuple fields + + * If index `i` currently holds a request, a replacement occurs: + + ``` text + Batch: [A,B,C] + New request to be added @ i: D @ 1 + + => + + New Batch: [A,D,C] # Add D, discard B + ``` + + * If index `i` does not currently hold a request (because `i` is out of bounds of the current batch size): + + ``` text + Batch: [A,B,C] + New request to be added @ i: D @ 3 + + => + + New Batch: [A,B,C,D] # Add D, extending batch + ``` + +* **Move:** move request at index `s` to index `d` OR swap requests at indices `s` and `d` + + * A Move is represented in `Batchupdate.moved` as a tuple of + + ``` text + (s, d, UNIDIRECTIONAL or SWAP) + ``` + + * If the Move specifies `UNIDRECTIONAL`: + + * The request at index `s` is moved to index `d`; index `s` becomes an empty slot + + ``` text + Batch: [A,x,C,D] + Unidirectionally Move s -> d: 3 -> 1 + + => + + New Batch: [A,D,C,x] # Move D to 1, leaving empty slot at 3 + ``` + + * If another request already resided at index `d`, it is replaced and discarded + + ``` text + Batch: [A,B,C,D] + Unidirectionally Move s -> d: 3 -> 1 + + => + + New Batch: [A,D,C,x] # Move D to 1, discarding B and leaving empty slot at 3 + ``` + + * If the Move specifies `SWAP`, the requests at `s` and `d` exchange indices + + ``` text + Batch: [A,B,C,D] + Swap Move s <-> d: 3 <-> 1 + + => + + New Batch: [A,D,C,B] # Swap B and D + ``` + +Additionally, the `BatchUpdate` data structure includes a representation (`batch_size`) of the size of the persistent batch at the beginning of the engine step. + +### How the vLLM engine builds the `BatchUpdate` data structure + +Logits processor `update_state()` implementations should assume the following model for how the model runner updates persistent batch state (expressed here in terms of the `BatchUpdate` abstraction): + +1. Identify indices of requests which finished in the current engine step + +2. Identify new requests introduced in the current step + +3. Use Add operations to replace as many finished requests with new requests, in order of increasing index of the replaced request starting with the lowest index + +4. Based on the relative number of new and finished requests: + + 1. If the numbers of new and finished requests are the same, proceed to next step + + 2. *If there are more new requests than finished requests:* apply Add operations to extend the batch with the remaining new requests which did not replace finished requests. Assign consecutive indices to these new requests, starting with `current_max_batch_index + 1` + + 3. *If there are fewer new requests than finished requests:* + + * Apply Remove operations to finished requests which were not replaced with new requests. These removed request indices will necessarily be greater than the greatest index of the finished requests which were replaced in the previous step. The Removes may leave the batch in a non-contiguous state + + * **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. 
Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous + + * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots + +5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch + +Notes: + +* A logits processor `update_state()` method must process batch update operations in the following order: removes, adds, moves + +* The index argument for Add operations refers to the index *at the time the Add occurred*, i.e. before any Move operations + * Example: if a request is Added at index 5 and then swapped with index 3, the Add operation in `BatchUpdate.added` will be associated with index 5 not 3 + * In other words Move operations can be assumed to be applied after Adds and Removes + +* Move operations can be assumed to be applied in the order in which they appear in `BatchUpdate.moved` + +* If there are no new/finished requests and there is no batch reordering, then the batch update for the logits processors will be `None` + +#### Example: Batch Update with Fewer New Requests Than Finished Requests + +The following example models an engine step where 1 new request is introduced and 2 finished requests are eliminated, additionally the attention backend performs a swap to optimize the batch ordering. + +``` text +Batch state (beginning of engine step): [A,B,C,D] +Batch size: 4 + +New requests: E + +Finished requests: A, C + +Processing steps (using BatchUpdate abstraction): + +1. Add E at index 0 + +[E,B,C,D] # Discard A +Batch size: 4 + +2. Remove at index 2 + +[E,B,x,D] # Discard C, empty slot at index 2 +Batch size: 4 + +3. Condense batch with a Unidirectional Move 3 -> 2 operation and shrink batch + +[E,B,D] x # Empty slot is now outside batch +Batch size: 3 + +4. Attention backend optimization: reorder batch with Swap 0 <-> 1 + +[B,E,D] +Batch size: 3 + +``` + +The resulting `BatchUpdate` data structure will look like + +``` text +BatchUpdate instance +* added: [(0,E's SamplingParams,E's prompt tokens ref,E's output tokens ref)] +* removed: [2] # request C was removed without replacement +* moved: [(3,2,UNIDIRECTIONAL),(0,1,SWAP)] +``` + +#### Example: Batch Update with More New Requests Than Finished Requests + +The following example models an engine step where 2 new requests are introduced and 1 finished request is eliminated, additionally the attention backend performs a swap to optimize the batch ordering. + +``` text +Batch state (beginning of engine step): [A,B,C,D] +Batch size: 4 + +New requests: E,F + +Finished requests: C + +Processing steps (using BatchUpdate abstraction): + +1. Add E at index 2 + +[A,B,E,D] # Discard C +Batch size: 4 + +2. Add F at index 4 (current max batch index + 1) + +[A,B,E,D,F] # Extend batch by 1 +Batch size: 5 + +4. Attention backend optimization: reorder batch with Swap 0 <-> 1 + +[B,A,E,D,F] +Batch size: 5 + +``` + +Note that batch condensation is skipped because there are no empty slots left behind by Remove operations. 
+ +The resulting `BatchUpdate` data structure will look like + +``` text +BatchUpdate instance +* added: [(2,E's SamplingParams,E's prompt tokens ref,E's output tokens ref),(4,F's SamplingParams,F's prompt tokens ref,F's output tokens ref)] +* removed: [] # no requests were removed without replacement +* moved: [(0,1,SWAP)] +``` + +## How to Introduce a New Logits Processor to vLLM + +### Best Practices for Writing Built-In Logits Processors + +* Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity + * For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()` + * However, if you think that a logits processor may be used infrequently, it may be appropriate to use a "sparse" representation of request state i.e. the class can represent request configuration using a dictionary which only stores metadata about requests that enable the logits processor + +* It is up to the logits processor author to determine: + + 1. **The per-request attributes which configure the logits processor's behavior against that request.** For example, if you are writing a new built-in logits processor for vLLM, you may or may not need to add additional fields to `SamplingParams` and the vLLM REST API + + 2. **The conditions under which the logits processor is or is not enabled on a per-request basis.** Unless your intention is for the built-in logits processor to act on all requests all the time, you should write your logits processor in such a way that it is possible to disable the logits processor for a given request, i.e. by defaulting an argument to `None` or by passing in a specific do-nothing argument value i.e. `0.0`. Try to save compute and memory for requests which disable the logits processor + + 3. **The conditions under which the logits processor is short-circuited at the batch level.** Even if you have defined a way to disable the built-in logits processor at the request level, it may be difficult to translate this into compute savings i.e. if your `update_state()` and `apply()` implementations use efficient vectorized implementations that operate on the whole persistent batch in a single command. For example, you cannot skip an entire vectorized operation in `apply()` just because one request disabled the logits processor. To save compute in the edge-case where no running requests utilize the built-in logits processor, we recommend designing `apply()` to return the unmodified input tensor if all requests have the logits processor disabled. Similarly, consider whether steps can be skipped in `update_state()` if no requests enable the logits processor + + * Additionally, an easy way to save compute in `update_state()` is to exit early when the batch_update is `None` + +* Ensure that the logits processor `update_state` method discards information about finished requests (i.e. requests which are replaced by an Add or which are subject to a Remove) + +* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method + +### Built-In Logits Processors + +Built-in logits processors are always loaded when the vLLM engine starts. 
See the existing vLLM built-in logits processors in `vllm/v1/sample/logits_processor/builtin.py` for examples of how to write a new built-in vLLM logits processor. It makes sense to write a PR to introduce a new logits processor as a built-in if it is likely to be useful to a wide audience. vLLM currently employs the following built-in logits processors based on the programming model described above: + +* Min-P + +* Logit bias + +* Min-tokens + +Review these logits processor implementations for guidance on writing built-in logits processors. + +Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforementioned logits processor programming model. + +* Allowed token IDs + +* Bad words + +* Repetition penalty + +* Frequency penalty + +* Presence penalty + +* Temperature + +* Top-K + +* Top-P + +### Custom Logits Processors + +vLLM can be augmented with [user-provided custom logits processors](../features/custom_logitsprocs.md). diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b01838883f31e..90b2fd32f2979 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -99,11 +99,11 @@ http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201 ### Multi-process Mode -In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See . +In v0, metrics are collected in the engine core process and we use multiprocess mode to make them available in the API server process. See . ### Built in Python/Process Metrics -The following metrics are supported by default by `prometheus_client`, but they are not exposed when multi-process mode is used: +The following metrics are supported by default by `prometheus_client`, but they are not exposed when multiprocess mode is used: - `python_gc_objects_collected_total` - `python_gc_objects_uncollectable_total` @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! note diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md new file mode 100644 index 0000000000000..0831c5bc790dc --- /dev/null +++ b/docs/design/moe_kernel_features.md @@ -0,0 +1,120 @@ +# Fused MoE Kernel features + +The purpose of this document is to provide an overview of the various MoE kernels (both modular and non-modular) so it will be easier to select an appropriate set of kernels for any particular situation. This includes information about the all2all backends used by modular kernels. + +## Fused MoE Modular All2All backends + +There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` sub-classes provide an interface for each all2all backend. + +The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support. + +The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass; the finalize step requires the same format.
All the backend `prepare` methods expect activations in standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document. + +The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only the block-quantized fp8 format; any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 w/per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16. + +Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step). + +Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass; for non-modular kernels, it is up to the experts function to deal with this flag. + +Unless otherwise specified, backends are controlled via `VLLM_ALL2ALL_BACKEND`. All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP w/o EP. + + + +| Backend | Output act. format | Quant. types | Quant. 
format | Async | Apply Weight On Input | Sub-class | +|---------------------------------------|--------------------|-----------------|------------------------|-------|-----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------| +| naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE.forward_impl] | +| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] | +| deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | +| deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | +| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] | +| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | +| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | +| MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | +| BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | + +!!! info "Table key" + 1. All types: mxfp4, nvfp4, int4, int8, fp8 + 2. A,T quantization occurs after dispatch. + 3. All quantization happens after dispatch. + 4. Controlled by different env vars (`VLLM_FLASHINFER_MOE_BACKEND` "throughput" or "latency") + 5. This is a no-op dispatcher that can be used to pair with any modular experts to produce a modular kernel that runs w/o dispatch or combine. These cannot be selected via environment variable. These are generally use for testing or adapting an expert subclass to the `fused_experts` API. + 6. This depends on the experts implementation. + + --- + + - G - Grouped + - G(N) - Grouped w/block size N + - A - Per activation token + - T - Per tensor + +Modular kernels are supported by the following `FusedMoEMethodBase` classes. 
+ +- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] +- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] +- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod] +- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] +- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] +- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] + +## Fused MoE Experts Kernels + +There are a number of MoE experts kernel implementations for different quantization types and architectures. Most follow the general API of the base Triton [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts] function. Many have modular kernel adapters so they can be used with compatible all2all backends. This table lists each experts kernel and its particular properties. + +Each kernel must be provided with one of the supported input activation formats. Some flavors of kernels support both standard and batched formats through different entry points, e.g. `TritonExperts` and `BatchedTritonExperts`. Batched format kernels are currently only needed for matching with certain all2all backends, e.g. `pplx`, `DeepEPLLPrepareAndFinalize`. + +Similar to the backend kernels, each experts kernel only supports certain quantization formats. For non-modular experts, the activations will be in the original type and quantized internally by the kernel. Modular experts will expect the activations to already be in the quantized format. Both types of experts will yield outputs in the original activation type. + +Each experts kernel supports one or more activation functions, e.g. silu and gelu, that are applied to the intermediate results. + +As with the backends, some experts support applying topk weights on the input activations. The entries in that column of this table only apply to the non-modular experts. + +Most expert flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`. + +To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels must have compatible activation formats, quantization types and quantization formats. + +| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source | +|------------------------------|-----------------------|------------------|---------------|-------------------------------------------------------------|-----------------------|---------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| triton | standard | all1 | G,A,T | silu, gelu,
swigluoai,
silu_no_mul,
gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],
[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] | +| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] | +| deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | +| cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],
[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | +| cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],
[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
+| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | +| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | +| deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | +| marlin | standard | 3 | 3 | silu,
swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] | +| marlin experts | standard | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] | +| trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | +| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | +| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | +| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | +| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | +| naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | + +!!! info "Table key" + 1. All types: mxfp4, nvfp4, int4, int8, fp8 + 2. A dispatcher wrapper around triton and deep gemm experts. Will select based on type + shape + quantization params + 3. uint4, uint8, fp8, fp4 + 4. This is a naive implementation of experts that supports batched format. Mainly used for testing. + 5. The `activation` parameter is ignored and SwiGlu is used by default instead. + 6. Only handled by or supported when used with modular kernels. + +## Modular Kernel "families" + +The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts. + +| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | +|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------| +| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| +| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 247072d1cb275..6e92b20d267b4 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -8,7 +8,7 @@ page for information on known issues and how to solve them. ## Introduction !!! important - The source code references are to the state of the code at the time of writing in December, 2024. + The source code references are to the state of the code at the time of writing in December 2024. The use of Python multiprocessing in vLLM is complicated by: diff --git a/docs/design/p2p_nccl_connector.md b/docs/design/p2p_nccl_connector.md index adf838306bc77..4674bef8d2b64 100644 --- a/docs/design/p2p_nccl_connector.md +++ b/docs/design/p2p_nccl_connector.md @@ -97,7 +97,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20001 \ --tensor-parallel-size 1 \ @@ -118,7 +118,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20002 \ --tensor-parallel-size 1 \ @@ -139,7 +139,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20003 \ --tensor-parallel-size 1 \ @@ -160,7 +160,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20004 \ --tensor-parallel-size 1 \ @@ -190,7 +190,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20001 \ --tensor-parallel-size 1 \ @@ -211,7 +211,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20002 \ --tensor-parallel-size 1 \ @@ -232,7 +232,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20003 \ --tensor-parallel-size 1 \ @@ -253,7 +253,7 @@ python3 disagg_proxy_p2p_nccl_xpyd.py & ??? console "Command" ```shell - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ + CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ --host 0.0.0.0 \ --port 20004 \ --tensor-parallel-size 1 \ diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35caf30..d87b2a639df12 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. 
And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index ca1c2c2305d91..a384c6289f4ff 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -49,6 +49,8 @@ Every plugin has three parts: - **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported. +- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name. + ## Guidelines for Writing Plugins - **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md index 47ac4958dbf7f..32a4efef71fb0 100644 --- a/docs/design/torch_compile.md +++ b/docs/design/torch_compile.md @@ -2,7 +2,10 @@ In vLLM's V1 architecture, `torch.compile` is enabled by default and is a critical part of the framework. This document gives a simple walk-through example to show how to understand the `torch.compile` usage. -Throughout the example, we will run a common Llama model using v1, and turn on debug level logging to show all the details. The command to be used is `VLLM_USE_V1=1 VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B`. +Throughout the example, we will run a common Llama model, and turn on debug level logging to show all the details. The command to be used is `VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B`. + +!!! note + For more information and the latest progress of `torch.compile` integration, see this [Blog Post](https://blog.vllm.ai/2025/08/20/torch-compile.html). ## Compilation Cache @@ -16,7 +19,7 @@ vLLM will take all the available factors into consideration, and decide a direct The factors considered include: -- All the related configs (see the `compute_hash` functions in the [config.py](gh-file:vllm/config.py)) +- All the related configs (see the `compute_hash` functions in their respective configs in the [config folder](gh-file:vllm/config)) - PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](gh-file:vllm/compilation/compiler_interface.py)) - The model's forward function and the relevant functions called by the forward function (see below) @@ -133,7 +136,7 @@ Unfortunately, because auto-tuning takes quite a long time (from seconds to minu ## Cudagraph Capture -vLLM's V1 architecture uses piecewise cudagraph. The full computation graph is split as mentioned above, and we only capture the cudagraph for the piece of graph between attention operations (including the first graph before any attention operation, and the last graph after all the attention operation). 
This is based on a common observation: computation between attentions are usually token-wise and easy to deal with for cudagraph; while the attention operation is non-trivial to be cudagraph compatible. Thus, by running the attention operation in eager mode while the rest operations in cudagraph, we keep the flexibility of the attention operation. +vLLM's V1 architecture uses piecewise cudagraph that aligns with the piecewise compilation. The full computation graph is split as mentioned above, and we only capture the cudagraph for the piece of graph between attention operations (including the first graph before any attention operation, and the last graph after all the attention operation). This is based on a common observation: computation between attentions are usually token-wise and easy to deal with for cudagraph; while the attention operation is non-trivial to be cudagraph compatible. Thus, by running the attention operation in eager mode while the rest operations in cudagraph, we keep the flexibility of the attention operation. The piecewise cudagraph also has fine-grained memory management. The purpose is to only exclude the attention kernel from cudagraph, while keeping all the rest modules and the memory allocation operations in the cudagraph. This is why the attention operation in V1 has the output tensor as the input of the attention. @@ -150,6 +153,4 @@ Then it will only capture cudagraph for the specified sizes. It can be useful to ### Full Cudagraph capture -It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`. - -Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models or MOEs. See [CUDA Graphs](cuda_graphs.md) for more details. diff --git a/docs/examples/README.md b/docs/examples/README.md index 34e4dfd408a20..94f5efc92f386 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -2,6 +2,6 @@ vLLM's examples are split into three categories: -- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/) -- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/) -- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/) +- If you are using vLLM from within Python code, see the *Offline Inference* section. +- If you are using vLLM from an HTTP application or client, see the *Online Serving* section. +- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see the *Others* section. diff --git a/docs/features/compatibility_matrix.md b/docs/features/README.md similarity index 83% rename from docs/features/compatibility_matrix.md rename to docs/features/README.md index 5b08b3810776c..05ce0b57a9fc8 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/README.md @@ -1,4 +1,6 @@ -# Compatibility Matrix +# Features + +## Compatibility Matrix The tables below show mutually exclusive features and the support on some hardware. 
@@ -12,7 +14,7 @@ The symbols used have the following meanings: !!! note Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination. -## Feature x Feature +### Feature x Feature -| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | [pooling](../models/pooling_models.md) | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | -|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| -| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | -| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | -| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | -| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | | -| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | -| [pooling](../models/pooling_models.md) | 🟠\* | 🟠\* | ✅ | ❌ | ✅ | ✅ | | | | | | | | | -| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | -| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | -| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | -| async output | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | -| multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | -| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | -| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | -| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | +| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | [pooling](../models/pooling_models.md) | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | [prompt-embeds](prompt_embeds.md) | +|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| +| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | +| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | | +| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | | +| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | | | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | +| [pooling](../models/pooling_models.md) | 🟠\* | 🟠\* | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | +| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | | +| async output | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | | +| multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | | +| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | +| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | | +| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | | +| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | \* Chunked prefill and prefix caching are only applicable to last-token pooling. ^ LoRA is only applicable to the language backbone of multimodal models. 
[](){ #feature-x-hardware } -## Feature x Hardware +### Feature x Hardware | Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU | |-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----| @@ -74,6 +77,4 @@ th:not(:first-child) { | multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ | | best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | - -!!! note - Please refer to [Feature support through NxD Inference backend][feature-support-through-nxd-inference-backend] for features supported on AWS Neuron hardware +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](gh-issue:25097) | diff --git a/docs/features/custom_arguments.md b/docs/features/custom_arguments.md new file mode 100644 index 0000000000000..74ed40835b4d4 --- /dev/null +++ b/docs/features/custom_arguments.md @@ -0,0 +1,46 @@ +# Custom Arguments + +You can use vLLM *custom arguments* to pass in arguments which are not part of the vLLM `SamplingParams` and REST API specifications. Adding or removing a vLLM custom argument does not require recompiling vLLM, since the custom arguments are passed in as a dictionary. + +Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code. + +## Offline Custom Arguments + +Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`: + +``` python +SamplingParams(extra_args={"your_custom_arg_name": 67}) +``` + +This allows arguments which are not already part of `SamplingParams` to be passed into `LLM` as part of a request. + +## Online Custom Arguments + +The vLLM REST API allows custom arguments to be passed to the vLLM server via `vllm_xargs`. The example below integrates custom arguments into a vLLM REST API request: + +``` bash +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-1.5B-Instruct", + ... + "vllm_xargs": {"your_custom_arg": 67} + }' +``` + +Furthermore, OpenAI SDK users can access `vllm_xargs` via the `extra_body` argument: + +``` python +batch = await client.completions.create( + model="Qwen/Qwen2.5-1.5B-Instruct", + ..., + extra_body={ + "vllm_xargs": { + "your_custom_arg": 67 + } + } +) +``` + +!!! note + `vllm_xargs` is assigned to `SamplingParams.extra_args` under the hood, so code which uses `SamplingParams.extra_args` is compatible with both offline and online scenarios. diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md new file mode 100644 index 0000000000000..201b340c5972c --- /dev/null +++ b/docs/features/custom_logitsprocs.md @@ -0,0 +1,445 @@ +# Custom Logits Processors + +!!! important + Some logits processors design changes are still in progress and the API may + change in the near future. We hope to stabilize this part of the API soon + +A "custom" logits processor is written by a user of vLLM and is loaded into vLLM at initialization without needing to modify or recompile the vLLM source code. It is the opposite of a built-in logits processor. + +This document shows how to write, load and use a custom logits processor. + +## Logits Processors Background + +A logits processor adjusts the next-token probability distribution, usually with the intention of steering the model towards a desired type of behavior. 
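+
+For intuition, the snippet below is a minimal, self-contained sketch (plain PyTorch, not the vLLM logits processor API) of what such an adjustment looks like: masking a single token's logit before softmax removes that token from the next-token distribution.
+
+``` python
+import torch
+
+def ban_token(logits: torch.Tensor, token_id: int) -> torch.Tensor:
+    """Toy transformation: make one token unsampleable by setting its logit to -inf."""
+    adjusted = logits.clone()
+    adjusted[token_id] = float("-inf")
+    return adjusted
+
+raw = torch.randn(32_000)                           # fake logits over a 32k-token vocabulary
+probs = torch.softmax(ban_token(raw, 42), dim=-1)   # token 42 now has probability 0
+```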
+ +In vLLM, logits processors operate at batch granularity. During a given engine step, the logits processor consumes a `(num_requests) x (vocab_size)` tensor of raw logits output by the model. For all requests which enable the logits processor, the logits processor applies a transformation to the corresponding row of the logits tensor, while leaving other rows unmodified. The transformed logits tensor is then passed to softmax. + +## Creating a Custom Logits Processor + +Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods: + +* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)` + * `vllm_config`: engine configuration data structure + * `device`: hardware accelerator device info + * `is_pin_memory`: flag indicating whether pin memory is available to support logits processor implementation + +* `apply(self, logits: torch.Tensor) -> torch.Tensor`: + * Consume a `(num_requests) x (vocab_size)` logits tensor (`logits`) + * Apply logits processor transformation at batch granularity + * Return a transformed `(num_requests) x (vocab_size)` logits tensor + * You can modify the input logits processors in-place or out-of-place; in-place is more memory-efficient + +* `is_argmax_invariant(self) -> bool`: + * Return `True` if the logits processor is argmax invariant (never changes what is the highest-logit-value token ID for a given request), `False` if the logits processor may modify argmax + * `is_argmax_invariant()` is evaluated once at startup; if `True`, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling + +* `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`: + * Consume a `BatchUpdate` data structure representing persistent batch state changes at the beginning of the current engine step + * Use the `BatchUpdate` members to update logits processor internal state + * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. + +### How the vLLM engine builds the `BatchUpdate` data structure + +!!! important + Some logits processors design changes are still in progress. We expect + that in the future you will not need to account for batch state changes + when implementing a logits processor, and the information in this section + will become irrelevant. + +Logits processor `update_state()` implementations should assume the following model for how the model runner updates persistent batch state (expressed here in terms of the `BatchUpdate` abstraction): + +1. Identify indices of requests which finished in the current engine step + +2. Identify new requests introduced in the current step + +3. Use Add operations to replace as many finished requests with new requests, in order of increasing index of the replaced request starting with the lowest index + +4. Based on the relative number of new and finished requests: + + 1. If the numbers of new and finished requests are the same, proceed to next step + + 2. *If there are more new requests than finished requests:* apply Add operations to extend the batch with the remaining new requests which did not replace finished requests. Assign consecutive indices to these new requests, starting with `current_max_batch_index + 1` + + 3. 
*If there are fewer new requests than finished requests:* + + * Apply Remove operations to finished requests which were not replaced with new requests. These removed request indices will necessarily be greater than the greatest index of the finished requests which were replaced in the previous step. The Removes may leave the batch in a non-contiguous state + + * **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot. Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous + + * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots + +5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch + +Notes: + +* A logits processor `update_state()` method must process batch update operations in the following order: removes, adds, moves + +* The index argument for Add operations refers to the index *at the time the Add occurred*, i.e. before any Move operations + * Example: if a request is Added at index 5 and then swapped with index 3, the Add operation in `BatchUpdate.added` will be associated with index 5 not 3 + * In other words Move operations can be assumed to be applied after Adds and Removes + +* Move operations can be assumed to be applied in the order in which they appear in `BatchUpdate.moved` + +* If there are no new/finished requests and there is no batch reordering, then the batch update for the logits processors will be `None` + +### Passing Custom Argument to a Custom Logits Processor + +Unlike built-in logits processors, custom logits processors may require configuration arguments that are not hard-coded into `SamplingParams` or the vLLM server REST API. To solve this problem, custom logits processors may leverage vLLM [custom arguments](./custom_arguments.md) support to receive configuration settings from the user (although you are also free to design a custom logits processor which utilizes the pre-existing fields in `SamplingParams`.) + +### Example Custom Logits Processor Implementation + +The contrived example below implements a custom logits processor which consumes a `(num\_requests) \times (vocab\_size)` logits tensor and masks out all tokens except for one (`target_token`) with `float(-inf)`. The logits processor is disabled for any request that does not specify `target_token`. To determine whether the logits processor is enabled and which token to leave unmasked, the logits processor checks `SamplingParams.extra_args` for a `target_token` custom argument associated with each request: + +??? 
code "Example custom logits processor definition" + + ``` python + from typing import Optional + import torch + from vllm.config import VllmConfig + from vllm.sampling_params import SamplingParams + from vllm.v1.sample.logits_processor import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) + + class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + self.req_info: dict[int, int] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + for index, params, _, _ in batch_update.added: + assert params is not None + if params.extra_args and (target_token := + params.extra_args.get("target_token")): + self.req_info[index] = target_token + else: + self.req_info.pop(index, None) + + if self.req_info: + # Process removed requests. + for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional move (a->b) and swap + # (a<->b) + for adx, bdx, direct in batch_update.moved: + a_val = self.req_info.pop(adx, None) + b_val = self.req_info.pop(bdx, None) + if a_val is not None: + self.req_info[bdx] = a_val + if direct == MoveDirectionality.SWAP and b_val is not None: + self.req_info[adx] = b_val + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + cols = torch.tensor( + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device + ) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float('-inf') + logits[rows, cols] = values_to_keep + + return logits + ``` + +In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor. + +The `DummyLogitsProcessor.update_state()` implementation maintains a "sparse" representation of the batched requests in the `self.req_info` dictionary: only those requests which specify a `target_token` value have a key in the dictionary. `update_state()` adjusts the stored request indices and `target_token` values (keys and values respectively in `self.req_info`) in response to Add, Remove and Move operations against the persistent batch. + +### Wrapping an Existing Request-Level Logits Processor + +Although the vLLM engine applies logits processors at batch granularity, some users may want to use vLLM with a "request-level" logits processor implementation - an implementation which operates on individual requests. 
This will be especially true if your logits processor was developed for vLLM version 0, which required it to be a `Callable` (as described [here](https://docs.vllm.ai/en/v0.10.1.1/api/vllm/logits_process.html)) conforming to the following type annotation: + +``` python +RequestLogitsProcessor = Union[ + + # (output token ids, logits tensor) -> logits tensor + Callable[[list[int], Tensor], Tensor], + + # (prompt token ids, output token ids, logits tensor) -> logits tensor + Callable[[list[int], list[int], Tensor], Tensor], +] +``` + +While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above. + +You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance: + +??? code "Example of Wrapping a Request-Level Logits Processor" + + ``` python + ... + + from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, # Wrapper base-class + RequestLogitsProcessor, # Request-level logitsproc type annotation + ) + + ... + + # Stand-in for your request-level logits processor: + class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + ... + + # Example of wrapping the request-level logits processor: + class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. 
+ + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + ``` + +!!! note + Your `new_req_logits_processor()` override can return `None` to signal that the wrapped logits processor should not be applied to the request in question. + +Once you have created a custom subclass (like `WrappedPerReqLogitsProcessor`) which wraps your request level logits processor, you can pass the custom subclass to vLLM via any of the methods described in the following section. + +## Ways to Load Your Custom Logits Processor in vLLM + +Logits processors are loaded at initialization. Critically, the set of loaded logits processors cannot be modified after the vLLM engine finishes loading, and new logits logits processors cannot be loaded on-demand for individual requests. + +This section details different ways of making your logits processor visible to vLLM and triggering vLLM to load your logits processor. + +### Method 1: Pass the Custom Logits Processor Fully-Qualified Class Name (FQCN) to vLLM at Initialization Time + +This method is supported in both offline and online vLLM usage scenarios. The custom logits processor's FQCN (in the form of `dotted.path.to.module:ClassName`) can be passed as an argument to the `LLM` and `AsyncLLM` Python constructors, or as a CLI argument to `vllm serve` with the following syntax + +``` bash +vllm serve ... --logits_processors ... +``` + +The only requirements on the FQCN are + +1. Python's `importlib.import_module()` must be able to resolve the dotted path portion of the FQCN and load it as a module + +2. The class-name portion of the FQCN must be possible to import from the loaded module + +3. The object pointed to by the FQCN must be a subclass of `LogitsProcessor` + +See examples below: + +??? code "Passing custom logits processor FQCN to `LLM` in Python" + + ``` python + # Pass in FQCN + llm = LLM( + model="facebook/opt-125m", + logits_processors=["your.module.path:DummyLogitsProcessor"], + ) + ``` + +??? code "Passing custom logits processor FQCN to `AsyncLLM` in Python" + + ``` python + # Pass in FQCN + engine_args = AsyncEngineArgs(model="facebook/opt-125m", + logits_processors=["your.module.path:DummyLogitsProcessor"]) + async_llm = AsyncLLM.from_engine_args(engine_args) + ``` + +??? code "Passing custom logits processor FQCN to vLLM server via CLI" + + ```bash + vllm serve facebook/opt-125m --logits_processors your.module.path:DummyLogitsProcessor + ``` + +### Method 2: Automatically Detect Custom Logits Processors Installed in Your Python Environment As Entry Points + +[`setuptools`](https://setuptools.pypa.io/en/latest/userguide/entry_point.html) can enable installed packages to make themselves available as plugins to other Python programs, via pieces of metadata known as "entry points". + +During initialization, vLLM automatically scans the `vllm.logits_processors` entry point group and loads any installed logits processors which it finds. + +Suppose that you have developed a Python package that holds your custom logits processors. 
You can expose each logits processor to vLLM by adding a unique entrypoint for each logits processor to your logits processor Python package. The example below shows how to add an entrypoint to your project's `pyproject.toml` file: + +??? code "Exposing a custom logits processor as a Python entrypoint" + + ``` toml + [project.entry-points."vllm.logits_processors"] + dummy_logits_processor = "your.module.path:DummyLogitsProcessor" + ``` + +Once your package is installed, your custom logits processor will be loaded automatically whenever vLLM is initialized. You do *not* need to pass the custom logits processor to the `LLM` or `AsyncLLM` constructors or to the vLLM server explicitly at initialization time if your logits processor is exposed as an entry point. + +!!! note + vLLM will *always* load *all* logits processors which are exposed via entrypoints under the `vllm.logits_processors` grouping. + +### Method 3 (Offline-only): Pass a Python Class Object to the vLLM Constructor + +You can pass one or more custom logits processor class objects to the `LLM` and `AsyncLLM` constructors. This option is very flexible, as the logits processor classes may either be (1) defined locally within the same Python source file where `LLM` or `AsyncLLM` is instantiated, or (2) imported from a Python package. + +??? code "Passing custom logits processor class object to `LLM` or `AsyncLLM` in Python" + + ``` python + # Import custom logits processor + from some.module import DummyLogitsProcessor + + # ...or... + + # Define custom logits processor locally + from vllm.v1.sample.logits_processor import LogitsProcessor + + class DummyLogitsProcessor(LogitsProcessor): + # See DummyLogitsProcessor implementation above + ... + + # Pass class object to LLM constructor + llm = LLM( + model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor], + ) + + # Pass class object to AsyncLLM constructor + engine_args = AsyncEngineArgs(model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor]) + async_llm = AsyncLLM.from_engine_args(engine_args) + ``` + +## Invoking a Custom Logits Processor Against a Request + +The design of the custom logits processor determines whether the logits processor must be enabled/disabled for a given request, and what arguments must be provided to configure the logits processor. + +The examples below show how a user would pass a custom argument (`target_token`) to `DummyLogitsProcessor` in order to (1) enable the logits processor for that particular request and (2) control the logits processor's behavior. + +??? code "vLLM REST API: configure custom logits processor for a request" + + ``` bash + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-1.5B-Instruct", + ... + "vllm_xargs": {"target_token": 67} + }' + ``` + +??? code "OpenAI SDK: configure custom logits processor for a request" + + ``` python + batch = await client.completions.create( + model="Qwen/Qwen2.5-1.5B-Instruct", + ..., + extra_body={ + "vllm_xargs": { + "target_token": 67 + } + } + ) + ``` + +??? code "Offline: configure custom logits processor for an `LLM` request" + + ``` python + outputs_logitproc = llm.generate("your prompt", + SamplingParams(..., + extra_args={"target_token": 67})) + ``` + +??? 
code "Offline: configure custom logits processor for an `AsyncLLM` request" + + ``` python + async for out in engine.generate(request_id="your request id", + prompt="your prompt", + sampling_params=SamplingParams(..., + extra_args={"target_token": 67})): + + # Process async request outputs + ... + ``` + +## Best Practices for Writing Custom Logits Processors + +Once vLLM loads a logits processor during initialization, then vLLM will invoke `update_state()` and `apply()` against that logits processor in every engine step. Both methods operate on all requests which currently reside in the vLLM persistent batch. Thus it is important to implement these methods efficiently. + +* Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity + * For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()` + * However, if you think that a logits processor may be used infrequently, it may be appropriate to use a "sparse" representation of request state i.e. the class can represent request configuration using a dictionary which only stores metadata about requests that enable the logits processor + * **Note:** wrapped request-level logits processors do not need to implement `apply()` and `update_state()`; the default `AdapterLogitsProcessor.update_state()` implementation maintains a sparse representation of request state, wherein requests for which `new_req_logits_processor()` returns `None` are not represented in the base-class state dictionary. The default implementation of `AdapterLogitsProcessor.apply()` applies the request-level logits processor to each row of input logits sequentially and assembles the output logits tensor. If the performance of this `AdapterLogitsProcessor` default implementation is insufficient, then avoid wrapping your request-level logits processor and instead re-implement it as a `LogitsProcessor` subclass with optimized `apply()` and `update_state()` implementations that operate at batch granularity + +* It is up to the logits processor author to determine: + + 1. **The per-request attributes which configure the logits processor's behavior against that request.** Your custom logits processor's `update_state()` override determines how `SamplingParams` fields are mapped into logits processor state + + * **Note:** for wrapped request-level logits processors, `new_req_logits_processor()` determines how `SamplingParams` fields are used to initialize a request-level logits processor instance. + + 2. **The conditions under which the logits processor is or is not enabled on a per-request basis.** Unless your intention is for the custom logits processor to act on all requests all the time, you should write your logits processor in such a way that it is possible to disable the logits processor for a given request, i.e. by defaulting an argument to `None` or by passing in a specific do-nothing argument value i.e. `0.0`. Try to save compute and memory for requests which disable the logits processor + + * **Note:** for wrapped per-request logits processors, the default `AdapterLogitsProcessor.update_state()` implementation ensures that the request-level logits processor is disabled when `new_req_logits_processor()` returns `None` for that request + + 3. 
**The conditions under which the logits processor is short-circuited at the batch level.** Even if you have defined a way to disable the custom logits processor at the request level, it may be difficult to translate this into compute savings i.e. if your `update_state()` and `apply()` implementations use efficient vectorized implementations that operate on the whole persistent batch in a single command. For example, you cannot skip an entire vectorized operation in `apply()` just because one request disabled the logits processor. To save compute in the edge-case where no running requests utilize the custom logits processor, we recommend designing `apply()` to return the unmodified input tensor if all requests have the logits processor disabled. Similarly, consider whether steps can be skipped in `update_state()` if no requests enable the logits processor + + * Additionally, an easy way to save compute in `update_state()` is to exit early when the `batch_update` is `None` + + * **Note:** for wrapped per-request logits processors, the `AdapterLogitsProcessor` base-class implements the above optimizations by default + +* Ensure that the logits processor `update_state` method discards information about finished requests (i.e. requests which are replaced by an Add or which are subject to a Remove) + + * **Note:** for wrapped per-request logits processors, the `AdapterLogitsProcessor` base-class handles this by default + +* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However the argmax invariance may also be determined programmatically (i.e. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index 996ef00a6b960..fe065b52268a6 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -23,7 +23,7 @@ Now supports 5 types of connectors: - **SharedStorageConnector**: refer to for the example usage of SharedStorageConnector disaggregated prefilling. - **LMCacheConnectorV1**: refer to for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission. -- **NixlConnector**: refer to for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. +- **NixlConnector**: refer to for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). - **P2pNcclConnector**: refer to for the example usage of P2pNcclConnector disaggregated prefilling. - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: @@ -31,6 +31,18 @@ Now supports 5 types of connectors: --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}' ``` +For NixlConnector, you may also specify one or multiple NIXL_Backend. 
Such as: + + ```bash + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}' + ``` + +- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker): + + ```bash + --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}' + ``` + ## Benchmarks Please refer to for disaggregated prefilling benchmarks. diff --git a/docs/features/lora.md b/docs/features/lora.md index 668460a368a77..db794b2ebd71d 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -52,7 +52,7 @@ Check out for an exa ## Serving LoRA Adapters LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use -`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kickoff the server: +`--lora-modules {name}={path} {name}={path}` to specify each LoRA module when we kick off the server: ```bash vllm serve meta-llama/Llama-2-7b-hf \ diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 9d51f9cf52f50..dcc5ea3b90964 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -6,6 +6,13 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. +!!! tip + When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com` + + Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions. + + This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks. + ## Offline Inference To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: @@ -13,6 +20,67 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: - `prompt`: The prompt should follow the format that is documented on HuggingFace. - `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. +### Stable UUIDs for Caching (multi_modal_uuids) + +When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content. + +??? 
code + + ```python + from vllm import LLM + from PIL import Image + + # Qwen2.5-VL example with two images + llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct") + + prompt = "USER: \nDescribe the differences.\nASSISTANT:" + img_a = Image.open("/path/to/a.jpg") + img_b = Image.open("/path/to/b.jpg") + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": {"image": [img_a, img_b]}, + # Provide stable IDs for caching. + # Requirements (matched by this example): + # - Include every modality present in multi_modal_data. + # - For lists, provide the same number of entries. + # - Use None to fall back to content hashing for that item. + "multi_modal_uuids": {"image": ["sku-1234-a", None]}, + }) + + for o in outputs: + print(o.outputs[0].text) + ``` + +Using UUIDs, you can also skip sending media data entirely if you expect cache hits for respective items. Note that the request will fail if the skipped media doesn't have a corresponding UUID, or if the UUID fails to hit the cache. + +??? code + + ```python + from vllm import LLM + from PIL import Image + + # Qwen2.5-VL example with two images + llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct") + + prompt = "USER: \nDescribe the differences.\nASSISTANT:" + img_b = Image.open("/path/to/b.jpg") + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": {"image": [None, img_b]}, + # Since img_a is expected to be cached, we can skip sending the actual + # image entirely. + "multi_modal_uuids": {"image": ["sku-1234-a", None]}, + }) + + for o in outputs: + print(o.outputs[0].text) + ``` + +!!! warning + If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored. + ### Image Inputs You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: @@ -180,19 +248,19 @@ When loading RGBA images (images with transparency), vLLM converts them to RGB f ```python from vllm import LLM - + # Default white background (no configuration needed) llm = LLM(model="llava-hf/llava-1.5-7b-hf") - + # Custom black background for dark theme llm = LLM( model="llava-hf/llava-1.5-7b-hf", media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}} ) - + # Custom brand color background (e.g., blue) llm = LLM( - model="llava-hf/llava-1.5-7b-hf", + model="llava-hf/llava-1.5-7b-hf", media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}} ) ``` @@ -353,7 +421,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ## Online Serving -Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). +Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs users can provide to uniquely identify each media, which is used to cache the media results across requests. !!! important A chat template is **required** to use Chat Completions API. @@ -363,7 +431,7 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. For certain models, we provide alternative chat templates inside . - For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. + For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. 
### Image Inputs @@ -403,7 +471,13 @@ Then, you can use the OpenAI client as follows: # NOTE: The prompt formatting with the image token `` is not needed # since the prompt will be processed automatically by the API server. {"type": "text", "text": "What’s in this image?"}, - {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_url", + "image_url": { + url": image_url + }, + "uuid": image_url # Optional + }, ], }], ) @@ -419,8 +493,20 @@ Then, you can use the OpenAI client as follows: "role": "user", "content": [ {"type": "text", "text": "What are the animals in these images?"}, - {"type": "image_url", "image_url": {"url": image_url_duck}}, - {"type": "image_url", "image_url": {"url": image_url_lion}}, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + "uuid": image_url_duck # Optional + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + "uuid": image_url_lion # Optional + }, ], }], ) @@ -487,6 +573,7 @@ Then, you can use the OpenAI client as follows: "video_url": { "url": video_url }, + "uuid": video_url # Optional }, ], }], @@ -578,6 +665,7 @@ Then, you can use the OpenAI client as follows: "data": audio_base64, "format": "wav" }, + "uuid": audio_url # Optional }, ], }], @@ -607,6 +695,7 @@ Alternatively, you can pass `audio_url`, which is the audio counterpart of `imag "audio_url": { "url": audio_url }, + "uuid": audio_url # Optional }, ], }], @@ -660,7 +749,8 @@ The following example demonstrates how to pass image embeddings to the OpenAI se model = "llava-hf/llava-1.5-7b-hf" embeds = { "type": "image_embeds", - "image_embeds": f"{base64_image_embedding}" + "image_embeds": f"{base64_image_embedding}", + "uuid": image_url # Optional } # Pass additional parameters (available to Qwen2-VL and MiniCPM-V) @@ -671,6 +761,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se "image_embeds": f"{base64_image_embedding}" , # Required "image_grid_thw": f"{base64_image_grid_thw}" # Required by Qwen/Qwen2-VL-2B-Instruct }, + "uuid": image_url # Optional } model = "openbmb/MiniCPM-V-2_6" embeds = { @@ -679,6 +770,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se "image_embeds": f"{base64_image_embedding}" , # Required "image_sizes": f"{base64_image_sizes}" # Required by openbmb/MiniCPM-V-2_6 }, + "uuid": image_url # Optional } chat_completion = client.chat.completions.create( messages=[ @@ -696,6 +788,39 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ) ``` +For Online Serving, you can also skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this: + + ```python + # Image/video/audio URL: + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + + # image_embeds + { + "type": "image_embeds", + "image_embeds": None, + "uuid": image_uuid + }, + + # input_audio: + { + "type": "input_audio", + "input_audio": None, + "uuid": audio_uuid + }, + + # PIL Image: + { + "type": "image_pil", + "image_pil": None + "uuid": image_uuid + } + + ``` + !!! note Only one message can contain `{"type": "image_embeds"}`. If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc. 
diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md new file mode 100644 index 0000000000000..795b0c77d610e --- /dev/null +++ b/docs/features/nixl_connector_usage.md @@ -0,0 +1,165 @@ +# NixlConnector Usage Guide + +NixlConnector is a high-performance KV cache transfer connector for vLLM's disaggregated prefilling feature. It provides fully asynchronous send/receive operations using the NIXL library for efficient cross-process KV cache transfer. + +## Prerequisites + +### Installation + +As a quick start, install the NIXL library with `uv pip install nixl`. + +- Refer to [NIXL official repository](https://github.com/ai-dynamo/nixl) for more installation instructions +- The required NIXL version can be found in [requirements/kv_connectors.txt](gh-file:requirements/kv_connectors.txt) and other relevant config files + +For non-CUDA platforms, please build and install NIXL with UCX from source, as shown below. + +```bash +python tools/install_nixl_from_source_ubuntu.py +``` + +### Transport Configuration + +NixlConnector uses the NIXL library for the underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables: + +```bash +# Example UCX configuration, adjust according to your environment +export UCX_TLS=all # or specify particular transports like "rc,ud,sm,^cuda_ipc", etc. +export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1" +``` + +!!! tip + When using UCX as the transport backend, NCCL environment variables (like `NCCL_IB_HCA`, `NCCL_SOCKET_IFNAME`) are not applicable to NixlConnector, so configure UCX-specific environment variables instead of NCCL variables. + +## Basic Usage (on the same host) + +### Producer (Prefiller) Configuration + +Start a prefiller instance that produces KV caches: + +```bash +# 1st GPU as prefiller +CUDA_VISIBLE_DEVICES=0 \ +UCX_NET_DEVICES=all \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +vllm serve Qwen/Qwen3-0.6B \ + --port 8100 \ + --enforce-eager \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' +``` + +### Consumer (Decoder) Configuration + +Start a decoder instance that consumes KV caches: + +```bash +# 2nd GPU as decoder +CUDA_VISIBLE_DEVICES=1 \ +UCX_NET_DEVICES=all \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \ +vllm serve Qwen/Qwen3-0.6B \ + --port 8200 \ + --enforce-eager \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' +``` + +### Proxy Server + +Use a proxy server to route requests between prefiller and decoder: + +```bash +python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ + --port 8192 \ + --prefiller-hosts localhost \ + --prefiller-ports 8100 \ + --decoder-hosts localhost \ + --decoder-ports 8200 +``` + +## Environment Variables + +- `VLLM_NIXL_SIDE_CHANNEL_PORT`: Port for NIXL handshake communication + - Default: 5600 + - **Required for both prefiller and decoder instances** + - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine + - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
+ - Used for the initial NIXL handshake between the prefiller and the decoder + +- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication + - Default: "localhost" + - Set when prefiller and decoder are on different machines + - Connection info is passed via KVTransferParams from prefiller to decoder for handshake + +- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional) + - Default: 480 + - If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely. + +## Multi-Instance Setup + +### Multiple Prefiller Instances on Different Machines + +```bash +# Prefiller 1 on Machine A (example IP: ${IP1}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP1} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' + +# Prefiller 2 on Machine B (example IP: ${IP2}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP2} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' +``` + +### Multiple Decoder Instances on Different Machines + +```bash +# Decoder 1 on Machine C (example IP: ${IP3}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP3} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' + +# Decoder 2 on Machine D (example IP: ${IP4}) +VLLM_NIXL_SIDE_CHANNEL_HOST=${IP4} \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ +UCX_NET_DEVICES=all \ +vllm serve Qwen/Qwen3-0.6B --port 8000 \ + --tensor-parallel-size 8 \ + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' +``` + +### Proxy for Multiple Instances + +```bash +python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ + --port 8192 \ + --prefiller-hosts ${IP1} ${IP2} \ + --prefiller-ports 8000 8000 \ + --decoder-hosts ${IP3} ${IP4} \ + --decoder-ports 8000 8000 +``` + +### KV Role Options + +- **kv_producer**: For prefiller instances that generate KV caches +- **kv_consumer**: For decoder instances that consume KV caches from prefiller +- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. + +!!! tip + NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). + Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior. 
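Once the prefiller, decoder, and proxy are running, clients only talk to the proxy. Below is a minimal sketch of issuing a request through the toy proxy from the examples above; the port 8192 and the model name match the earlier commands, so adjust them to your deployment:

```python
from openai import OpenAI

# The toy proxy exposes an OpenAI-compatible endpoint and routes the
# prefill and decode phases to the configured prefiller/decoder instances.
client = OpenAI(base_url="http://localhost:8192/v1", api_key="EMPTY")

completion = client.completions.create(
    model="Qwen/Qwen3-0.6B",
    prompt="Briefly explain disaggregated prefilling:",
    max_tokens=64,
)
print(completion.choices[0].text)
```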
+ +## Example Scripts/Code + +Refer to these example scripts in the vLLM repository: + +- [run_accuracy_test.sh](gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) +- [toy_proxy_server.py](gh-file:tests/v1/kv_connector/nixl_integration/toy_proxy_server.py) +- [test_accuracy.py](gh-file:tests/v1/kv_connector/nixl_integration/test_accuracy.py) diff --git a/docs/features/prompt_embeds.md b/docs/features/prompt_embeds.md index 83993bd0140fa..f9d3c1fb6c23d 100644 --- a/docs/features/prompt_embeds.md +++ b/docs/features/prompt_embeds.md @@ -6,9 +6,6 @@ This page teaches you how to pass prompt embedding inputs to vLLM. The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. -!!! note - Prompt embeddings are currently only supported in the v0 engine. - ## Offline Inference To input multi-modal data, follow this schema in [vllm.inputs.EmbedsPrompt][]: diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index e18c128f30fc9..4c8377871e141 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -4,7 +4,6 @@ Quantization trades off model precision for smaller memory footprint, allowing l Contents: -- [Supported Hardware](supported_hardware.md) - [AutoAWQ](auto_awq.md) - [AutoRound](auto_round.md) - [BitsAndBytes](bnb.md) @@ -19,3 +18,50 @@ Contents: - [AMD Quark](quark.md) - [Quantized KV Cache](quantized_kvcache.md) - [TorchAO](torchao.md) + +## Supported Hardware + +The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: + + + +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | + +- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. +- ✅︎ indicates that the quantization method is supported on the specified hardware. +- ❌ indicates that the quantization method is not supported on the specified hardware. + +!!! note + This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. + + For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. 
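For orientation, pre-quantized checkpoints are usually detected automatically from the model's config, and a method from the table above can also be requested explicitly through the `quantization` argument. A minimal sketch, where the checkpoint name is a placeholder for any AWQ-quantized model supported on your hardware:

```python
from vllm import LLM

# The quantization method is normally inferred from the checkpoint's config;
# passing `quantization` makes the choice explicit. The model name below is
# a placeholder for a real AWQ-quantized checkpoint.
llm = LLM(model="your-org/your-model-AWQ", quantization="awq")

outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```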
diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 6f53a448ee364..53b689ad53ff6 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -5,7 +5,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic !!! note Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. - For details see [supported hardware](supported_hardware.md). + For details see [supported hardware](README.md#supported-hardware). Below are the steps to utilize BitBLAS with vLLM. diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index 13b151bc7f380..5e86e9388f328 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 247d0cbdd3f14..af3650e701ad0 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -6,7 +6,11 @@ This quantization method is particularly useful for reducing model size while ma Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415). !!! note - INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell). + INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper). + +!!! warning + **Blackwell GPU Limitation**: INT8 is not supported on compute capability >= 100 (e.g., RTX 6000 Blackwell). + Use [FP8 quantization](fp8.md) instead, or run on Hopper/Ada/Ampere architectures. 
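If you are unsure which case applies to your GPU, here is a small sketch of checking the compute capability at runtime before deciding between INT8 and FP8 (this is plain PyTorch, not a vLLM API):

```python
import torch

# Compute capability as (major, minor), e.g. (9, 0) for Hopper,
# (10, 0) or higher for Blackwell-class GPUs.
major, minor = torch.cuda.get_device_capability()

if major >= 10:
    print("Blackwell-class GPU: INT8 W8A8 is unsupported, prefer FP8 quantization.")
elif (major, minor) >= (7, 5):
    print("INT8 W8A8 computation is supported on this GPU.")
else:
    print("This GPU predates Turing; INT8 W8A8 kernels are not available.")
```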
## Prerequisites diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 047cc8382445b..85b7d8ec84ed3 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -231,9 +231,9 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \ --tasks gsm8k ``` -## Using MXFP4 models +## Using OCP MX (MXFP4, MXFP6) models -vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). +vLLM supports loading MXFP4 and MXFP6 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). The scheme currently only supports dynamic quantization for activations. @@ -241,17 +241,21 @@ Example usage, after installing the latest AMD Quark release: ```bash vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1 +# or, for a model using fp6 activations and fp4 weights: +vllm serve fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3 --tensor-parallel-size 1 ``` -A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16). +A simulation of the matrix multiplication execution in MXFP4/MXFP6 can be run on devices that do not support OCP MX operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from FP4/FP6 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate FP4/FP6 models using vLLM, or alternatively to benefit from the ~2.5-4x memory savings (compared to float16 and bfloat16). To generate offline models quantized using MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), as an example: ```bash python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \ - --quant_scheme w_mxfp4_a_mxfp4_sym \ + --quant_scheme w_mxfp4_a_mxfp4 \ --output_dir qwen_1.5-moe-a2.7b-mxfp4 \ --skip_evaluation \ --model_export hf_format \ --group_size 32 ``` + +The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights. 
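For offline inference, the same OCP MX checkpoints can be loaded with the `LLM` class; below is a minimal sketch mirroring the `vllm serve` command above. The quantization settings are expected to be picked up from the checkpoint's Quark config, so no extra flag is passed:

```python
from vllm import LLM, SamplingParams

# MXFP4 checkpoint from the serving example above; quantization settings
# are read from the model's Quark config.
llm = LLM(model="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tensor_parallel_size=1)

params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is OCP microscaling (MX) quantization?"], params)
print(outputs[0].outputs[0].text)
```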
diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md deleted file mode 100644 index 06264d08b56aa..0000000000000 --- a/docs/features/quantization/supported_hardware.md +++ /dev/null @@ -1,32 +0,0 @@ -# Supported Hardware - -The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: - - - -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | - -- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. -- ✅︎ indicates that the quantization method is supported on the specified hardware. -- ❌ indicates that the quantization method is not supported on the specified hardware. - -!!! note - This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - - For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. 
diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 04b943efbbbb4..85681669dfb22 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -10,11 +10,12 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | Tool Calling | |--------------|-------------|------------------|-------------| -| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | -| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | +| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ | +| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | -| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | -| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ | +| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ | +| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ | +| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ | !!! note IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. @@ -143,7 +144,7 @@ OpenAI Python client library does not officially support `reasoning_content` att print(content, end="", flush=True) ``` -Remember to check whether the `reasoning_content` exists in the response before accessing it. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). +Remember to check whether the `reasoning_content` exists in the response before accessing it. You could check out the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py). ## Tool Calling diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index 5749b02d26f45..e7dd9fee12d37 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -64,8 +64,7 @@ To enable sleep mode in a vLLM server you need to initialize it with the flag `V When using the flag `VLLM_SERVER_DEV_MODE=1` you enable development endpoints, and these endpoints should not be exposed to users. 
```bash -VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen3-0.6B \ +VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ --enable-sleep-mode \ --port 8000 ``` diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 597a8e8644278..25c308a6ff206 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -48,10 +48,9 @@ The following code configures vLLM in an offline mode to use speculative decodin To perform the same with an online mode launch the server: ```bash -python -m vllm.entrypoints.openai.api_server \ +vllm serve facebook/opt-6.7b \ --host 0.0.0.0 \ --port 8000 \ - --model facebook/opt-6.7b \ --seed 42 \ -tp 1 \ --gpu_memory_utilization 0.8 \ diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 8a934d406f382..901d87e7ed3d9 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -6,29 +6,40 @@ vLLM supports the generation of structured outputs using This document shows you some examples of the different options that are available to generate structured outputs. +!!! warning + If you are still using the following deprecated API fields, please update your code to use `structured_outputs` as demonstrated in the rest of this document: + + - `guided_json` -> `{"structured_outputs": {"json": ...}}` or `StructuredOutputsParams(json=...)` + - `guided_regex` -> `{"structured_outputs": {"regex": ...}}` or `StructuredOutputsParams(regex=...)` + - `guided_choice` -> `{"structured_outputs": {"choice": ...}}` or `StructuredOutputsParams(choice=...)` + - `guided_grammar` -> `{"structured_outputs": {"grammar": ...}}` or `StructuredOutputsParams(grammar=...)` + - `guided_whitespace_pattern` -> `{"structured_outputs": {"whitespace_pattern": ...}}` or `StructuredOutputsParams(whitespace_pattern=...)` + - `structural_tag` -> `{"structured_outputs": {"structural_tag": ...}}` or `StructuredOutputsParams(structural_tag=...)` + - `guided_decoding_backend` -> Remove this field from your request + ## Online Serving (OpenAI API) You can generate structured outputs using the OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API. The following parameters are supported, which must be added as extra parameters: -- `guided_choice`: the output will be exactly one of the choices. -- `guided_regex`: the output will follow the regex pattern. -- `guided_json`: the output will follow the JSON schema. -- `guided_grammar`: the output will follow the context free grammar. +- `choice`: the output will be exactly one of the choices. +- `regex`: the output will follow the regex pattern. +- `json`: the output will follow the JSON schema. +- `grammar`: the output will follow the context free grammar. - `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. You can see the complete list of supported parameters on the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) page. Structured outputs are supported by default in the OpenAI-Compatible Server. You may choose to specify the backend to use by setting the -`--guided-decoding-backend` flag to `vllm serve`. The default backend is `auto`, +`--structured-outputs-config.backend` flag to `vllm serve`. The default backend is `auto`, which will try to choose an appropriate backend based on the details of the request. 
You may also choose a specific backend, along with some options. A full set of options is available in the `vllm serve --help` text. -Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one: +Now let´s see an example for each of the cases, starting with the `choice`, as it´s the easiest one: ??? code @@ -45,12 +56,12 @@ Now let´s see an example for each of the cases, starting with the `guided_choic messages=[ {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], - extra_body={"guided_choice": ["positive", "negative"]}, + extra_body={"structured_outputs": {"choice": ["positive", "negative"]}}, ) print(completion.choices[0].message.content) ``` -The next example shows how to use the `guided_regex`. The idea is to generate an email address, given a simple regex template: +The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template: ??? code @@ -63,18 +74,18 @@ The next example shows how to use the `guided_regex`. The idea is to generate an "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", } ], - extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]}, + extra_body={"structured_outputs": {"regex": r"\w+@\w+\.com\n"}, "stop": ["\n"]}, ) print(completion.choices[0].message.content) ``` One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats. -For this we can use the `guided_json` parameter in two different ways: +For this we can use the `json` parameter in two different ways: - Using directly a [JSON Schema](https://json-schema.org/) - Defining a [Pydantic model](https://docs.pydantic.dev/latest/) and then extracting the JSON Schema from it (which is normally an easier option). -The next example shows how to use the `guided_json` parameter with a Pydantic model: +The next example shows how to use the `response_format` parameter with a Pydantic model: ??? code @@ -119,7 +130,7 @@ The next example shows how to use the `guided_json` parameter with a Pydantic mo JSON schema and how the fields should be populated. This can improve the results notably in most cases. -Finally we have the `guided_grammar` option, which is probably the most +Finally we have the `grammar` option, which is probably the most difficult to use, but it´s really powerful. It allows us to define complete languages like SQL queries. It works by using a context free EBNF grammar. As an example, we can use to define a specific format of simplified SQL queries: @@ -149,7 +160,7 @@ As an example, we can use to define a specific format of simplified SQL queries: "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", } ], - extra_body={"guided_grammar": simplified_sql_grammar}, + extra_body={"structured_outputs": {"grammar": simplified_sql_grammar}}, ) print(completion.choices[0].message.content) ``` @@ -205,7 +216,7 @@ This section covers the OpenAI beta wrapper over the `client.chat.completions.cr At the time of writing (`openai==1.54.4`), this is a "beta" feature in the OpenAI client library. Code reference can be found [here](https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/resources/beta/chat/completions.py#L100-L104). 
-For the following examples, vLLM was setup using `vllm serve meta-llama/Llama-3.1-8B-Instruct` +For the following examples, vLLM was set up using `vllm serve meta-llama/Llama-3.1-8B-Instruct` Here is a simple example demonstrating how to get structured output using Pydantic models: @@ -292,8 +303,8 @@ An example of using `structural_tag` can be found here: - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. -Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` +Recommended flags: + +1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend: + + `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral` + +2. To use the default Transformers tokenization backend: + `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` ### Llama Models (`llama3_json`) @@ -169,7 +175,7 @@ All Llama 3.1, 3.2 and 4 models should be supported. The tool calling that is supported is the [JSON-based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. -Other tool calling formats like the built in python tool calling or custom tool calling are not supported. +Other tool calling formats like the built-in python tool calling or custom tool calling are not supported. Known issues: @@ -192,10 +198,14 @@ VLLM also provides a pythonic and JSON-based chat template for Llama 4, but pyth For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. -#### IBM Granite +### IBM Granite Supported models: +* `ibm-granite/granite-4.0-h-small` and other Granite 4.0 models + + Recommended flags: `--tool-call-parser hermes` + * `ibm-granite/granite-3.0-8b-instruct` Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` @@ -311,6 +321,35 @@ Flags: * For non-reasoning: `--tool-call-parser hunyuan_a13b` * For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning` +### LongCat-Flash-Chat Models (`longcat`) + +Supported models: + +* `meituan-longcat/LongCat-Flash-Chat` +* `meituan-longcat/LongCat-Flash-Chat-FP8` + +Flags: `--tool-call-parser longcat` + +### GLM-4.5 Models (`glm45`) + +Supported models: + +* `zai-org/GLM-4.5` +* `zai-org/GLM-4.5-Air` +* `zai-org/GLM-4.6` +* `zai-org/GLM-4.6-Air` + +Flags: `--tool-call-parser glm45` + +### Qwen3-Coder Models (`qwen3_xml`) + +Supported models: + +* `Qwen/Qwen3-480B-A35B-Instruct` +* `Qwen/Qwen3-Coder-30B-A3B-Instruct` + +Flags: `--tool-call-parser qwen3_xml` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. 
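Whichever parser a model uses, the client-side request shape is the same OpenAI `tools` schema. A minimal sketch against a server launched with `--enable-auto-tool-choice` and one of the parser flags above (the model name and function definition are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder tool-calling model
    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
    tools=tools,
    tool_choice="auto",
)
# The server-side parser converts the model's raw output (JSON, XML, or a
# pythonic call list) into structured tool_calls.
print(response.choices[0].message.tool_calls)
```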
diff --git a/docs/getting_started/installation/.nav.yml b/docs/getting_started/installation/.nav.yml index d4a727c926406..ba1f8099a6456 100644 --- a/docs/getting_started/installation/.nav.yml +++ b/docs/getting_started/installation/.nav.yml @@ -3,5 +3,3 @@ nav: - gpu.md - cpu.md - google_tpu.md - - intel_gaudi.md - - aws_neuron.md diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index 0ee680f5c688c..a4e63e426b9ba 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -12,8 +12,6 @@ vLLM supports the following hardware platforms: - [Apple silicon](cpu.md#apple-silicon) - [IBM Z (S390X)](cpu.md#ibm-z-s390x) - [Google TPU](google_tpu.md) -- [Intel Gaudi](intel_gaudi.md) -- [AWS Neuron](aws_neuron.md) ## Hardware Plugins @@ -27,3 +25,4 @@ The backends below live **outside** the main `vllm` repository and follow the | MetaX MACA GPU | N/A, install from source | | | Rebellions ATOM / REBEL NPU | `vllm-rbln` | | | IBM Spyre AIU | `vllm-spyre` | | +| Cambricon MLU | `vllm-mlu` | | diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md deleted file mode 100644 index b8bd76bd5bcbe..0000000000000 --- a/docs/getting_started/installation/aws_neuron.md +++ /dev/null @@ -1,147 +0,0 @@ -# AWS Neuron - -[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and -generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2, -and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores. -This describes how to set up your environment to run vLLM on Neuron. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - -## Requirements - -- OS: Linux -- Python: 3.9 or newer -- Pytorch 2.5/2.6 -- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips) -- AWS Neuron SDK 2.23 - -## Configure a new environment - -### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies - -The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this -[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image). - -- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance -- Once inside your instance, activate the pre-installed virtual environment for inference by running - -```bash -source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate -``` - -Refer to the [NxD Inference Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/nxdi-setup.html) -for alternative setup instructions including using Docker and manually installing dependencies. - -!!! note - NxD Inference is the default recommended backend to run inference on Neuron. 
If you are looking to use the legacy [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) - library, refer to [Transformers NeuronX Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/setup/index.html). - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Neuron wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -U -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install -e . -``` - -AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at -, which contains several features in addition to what's -available on vLLM V0. Please utilize the AWS Fork for the following features: - -- Llama-3.2 multi-modal support -- Multi-node distributed inference - -Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) - for more details and usage examples. - -To install the AWS Neuron fork, run the following: - -```bash -git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git -cd upstreaming-to-vllm -pip install -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install -e . -``` - -Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested. - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Neuron images. - -### Build image from source - -See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. - -Make sure to use in place of the default Dockerfile. - -## Extra information - -[](){ #feature-support-through-nxd-inference-backend } - -### Feature support through NxD Inference backend - -The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend -to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most -[features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration. - -To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override -as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include - -```python -override_neuron_config={ - "enable_bucketing":False, -} -``` - -or when launching vLLM from the CLI, pass - -```bash ---override-neuron-config "{\"enable_bucketing\":false}" -``` - -Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts -(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads. - -### Known limitations - -- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. 
Refer to this - [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility) - for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI. -- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this - [Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html) - to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM. -- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at - runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) -- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed - to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature. -- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. Refer - to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node) - to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main. -- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches - max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt - to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support - for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is - implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic. - -### Environment variables - -- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid - compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the - artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, - but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts - under this specified path. -- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). -- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e494..f290836f944cc 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -20,7 +20,7 @@ vLLM is a Python library that supports the following CPU variants. Select your C ## Requirements -- Python: 3.9 -- 3.12 +- Python: 3.10 -- 3.13 === "Intel/AMD x86" @@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels. 
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). @@ -170,7 +171,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +180,7 @@ Inference batch size is a important parameter for the performance. 
Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +191,38 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. + +### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? + +In some container environments (like Docker), NUMA-related syscalls used by vLLM (e.g., `get_mempolicy`, `migrate_pages`) are blocked/denied in the runtime's default seccomp/capabilities settings. This may lead to warnings like `get_mempolicy: Operation not permitted`. Functionality is not affected, but NUMA memory binding/migration optimizations may not take effect and performance can be suboptimal. + +To enable these optimizations inside Docker with the least privilege, you can follow below tips: + +```bash +docker run ... --cap-add SYS_NICE --security-opt seccomp=unconfined ... + +# 1) `--cap-add SYS_NICE` is to address `get_mempolicy` EPERM issue. + +# 2) `--security-opt seccomp=unconfined` is to enable `migrate_pages` for `numa_migrate_pages()`. +# Actually, `seccomp=unconfined` bypasses the seccomp for container, +# if it's unacceptable, you can customize your own seccomp profile, +# based on docker/runtime default.json and add `migrate_pages` to `SCMP_ACT_ALLOW` list. + +# reference : https://docs.docker.com/engine/security/seccomp/ +``` + +Alternatively, running with `--privileged=true` also works but is broader and not generally recommended. + +In K8S, the following configuration can be added to workload yaml to achieve the same effect as above: + +```yaml +securityContext: + seccompProfile: + type: Unconfined + capabilities: + add: + - SYS_NICE +``` diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 2828173a76a9a..7e2ed55008a57 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM has experimental support for macOS with Apple silicon. For now, users must build from source to natively run on macOS. +vLLM has experimental support for macOS with Apple Silicon. For now, users must build from source to natively run on macOS. Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. @@ -52,6 +52,24 @@ uv pip install -e . 1 error generated. 
``` + --- + + If the build fails with C++11/C++17 compatibility errors like the following, the issue is that the build system is defaulting to an older C++ standard: + + ```text + [...] error: 'constexpr' is not a type + [...] error: expected ';' before 'constexpr' + [...] error: 'constexpr' does not name a type + ``` + + **Solution**: Your compiler might be using an older C++ standard. Edit `cmake/cpu_extension.cmake` and add `set(CMAKE_CXX_STANDARD 17)` before `set(CMAKE_CXX_STANDARD_REQUIRED ON)`. + + To check your compiler's C++ standard support: + ```bash + clang++ -std=c++17 -pedantic -dM -E -x c++ /dev/null | grep __cplusplus + ``` + On Apple Clang 16 you should see: `#define __cplusplus 201703L` + # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index cac578eefb1d7..e45baa0aa4938 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -48,6 +48,10 @@ docker run --rm \ --dtype=bfloat16 \ other vLLM OpenAI server arguments ``` + +!!! tip + An alternative of `--privileged=true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md index 57a09e674a821..4bd4d39a6f80b 100644 --- a/docs/getting_started/installation/cpu/build.inc.md +++ b/docs/getting_started/installation/cpu/build.inc.md @@ -16,8 +16,8 @@ cd vllm_source Third, install required dependencies: ```bash -uv pip install -r requirements/cpu-build.txt --torch-backend auto -uv pip install -r requirements/cpu.txt --torch-backend auto +uv pip install -r requirements/cpu-build.txt --torch-backend cpu +uv pip install -r requirements/cpu.txt --torch-backend cpu ``` ??? console "pip" diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index c1917267ce91b..442c2b4ec64e8 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -46,22 +46,22 @@ Execute the following commands to build and install vLLM from source. Please build the following dependencies, `torchvision`, `pyarrow` from source before building vLLM. ```bash - sed -i '/^torch/d' requirements-build.txt # remove torch from requirements-build.txt since we use nightly builds + sed -i '/^torch/d' requirements/build.txt # remove torch from requirements/build.txt since we use nightly builds uv pip install -v \ --torch-backend auto \ - -r requirements-build.txt \ - -r requirements-cpu.txt \ + -r requirements/build.txt \ + -r requirements/cpu.txt \ VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ uv pip install dist/*.whl ``` ??? 
console "pip" ```bash - sed -i '/^torch/d' requirements-build.txt # remove torch from requirements-build.txt since we use nightly builds + sed -i '/^torch/d' requirements/build.txt # remove torch from requirements/build.txt since we use nightly builds pip install -v \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ - -r requirements-build.txt \ - -r requirements-cpu.txt \ + -r requirements/build.txt \ + -r requirements/cpu.txt \ VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ pip install dist/*.whl ``` @@ -89,6 +89,9 @@ docker run --rm \ other vLLM OpenAI server arguments ``` +!!! tip + An alternative of `--privileged true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 6dc6f94249c34..00f3b726b1a0e 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -20,7 +20,80 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] ---8<-- "docs/getting_started/installation/cpu/build.inc.md" +Install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: + +```bash +sudo apt-get update -y +sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev +sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +``` + +Clone the vLLM project: + +```bash +git clone https://github.com/vllm-project/vllm.git vllm_source +cd vllm_source +``` + +Install the required dependencies: + +```bash +uv pip install -r requirements/cpu-build.txt --torch-backend cpu +uv pip install -r requirements/cpu.txt --torch-backend cpu +``` + +??? console "pip" + ```bash + pip install --upgrade pip + pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + ``` + +Build and install vLLM: + +```bash +VLLM_TARGET_DEVICE=cpu uv pip install . --no-build-isolation +``` + +If you want to develop vLLM, install it in editable mode instead. + +```bash +VLLM_TARGET_DEVICE=cpu uv pip install -e . --no-build-isolation +``` + +Optionally, build a portable wheel which you can then install elsewhere: + +```bash +VLLM_TARGET_DEVICE=cpu uv build --wheel +``` + +```bash +uv pip install dist/*.whl +``` + +??? console "pip" + ```bash + VLLM_TARGET_DEVICE=cpu python -m build --wheel --no-isolation + ``` + + ```bash + pip install dist/*.whl + ``` + +!!! example "Troubleshooting" + - **NumPy ≥2.0 error**: Downgrade using `pip install "numpy<2.0"`. + - **CMake picks up CUDA**: Add `CMAKE_DISABLE_FIND_PACKAGE_CUDA=ON` to prevent CUDA detection during CPU builds, even if CUDA is installed. + - `AMD` requies at least 4th gen processors (Zen 4/Genoa) or higher to support [AVX512](https://www.phoronix.com/review/amd-zen4-avx512) to run vLLM on CPU. + - If you receive an error such as: `Could not find a version that satisfies the requirement torch==X.Y.Z+cpu+cpu`, consider updating [pyproject.toml](https://github.com/vllm-project/vllm/blob/main/pyproject.toml) to help pip resolve the dependency. 
+ ```toml title="pyproject.toml" + [build-system] + requires = [ + "cmake>=3.26.1", + ... + "torch==X.Y.Z+cpu" # <------- + ] + ``` + - If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM. # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] @@ -43,7 +116,8 @@ docker build -f docker/Dockerfile.cpu \ # Launching OpenAI server docker run --rm \ - --privileged=true \ + --security-opt seccomp=unconfined \ + --cap-add SYS_NICE \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ @@ -56,4 +130,4 @@ docker run --rm \ # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] -# --8<-- [end:extra-information] +# --8<-- [end:extra-information] \ No newline at end of file diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md index e688cefea0763..45162b86e2f2f 100644 --- a/docs/getting_started/installation/gpu.md +++ b/docs/getting_started/installation/gpu.md @@ -17,7 +17,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G ## Requirements - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.10 -- 3.13 !!! note vLLM does not support Windows natively. To run vLLM on Windows, you can use the Windows Subsystem for Linux (WSL) with a compatible Linux distribution, or use some community-maintained forks, e.g. [https://github.com/SystemPanic/vllm-windows](https://github.com/SystemPanic/vllm-windows). diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 69a9842e4719b..9e64c6f2540af 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -48,7 +48,7 @@ uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VE #### Install the latest code -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on a x86 platform with CUDA 12 for every commit since `v0.5.3`. +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. ```bash uv pip install -U vllm \ @@ -168,6 +168,7 @@ There are scenarios where the PyTorch dependency cannot be easily installed with To build vLLM using an existing PyTorch installation: ```bash +# install PyTorch first, either from PyPI or from source git clone https://github.com/vllm-project/vllm.git cd vllm python use_existing_torch.py @@ -175,6 +176,17 @@ uv pip install -r requirements/build.txt uv pip install --no-build-isolation -e . ``` +Alternatively: if you are exclusively using `uv` to create and manage virtual environments, it has [a unique mechanism](https://docs.astral.sh/uv/concepts/projects/config/#disabling-build-isolation) +for disabling build isolation for specific packages. 
vLLM can leverage this mechanism to specify `torch` as the package to disable build isolation for: + +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +# pip install -e . does not work directly, only uv can do this +uv pip install -e . +``` + ##### Use the local cutlass for compilation Currently, before starting the build process, vLLM fetches cutlass code from GitHub. However, there may be scenarios where you want to use a local version of cutlass instead. diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 560883d3caf9e..37c6647929b51 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM supports AMD GPUs with ROCm 6.3. +vLLM supports AMD GPUs with ROCm 6.3 or above. !!! tip [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. @@ -11,8 +11,9 @@ vLLM supports AMD GPUs with ROCm 6.3. # --8<-- [end:installation] # --8<-- [start:requirements] -- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201) -- ROCm 6.3 +- GPU: MI200s (gfx90a), MI300 (gfx942), MI350 (gfx950), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201) +- ROCm 6.3 or above + - MI350 requires ROCm 7.0 or above # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] @@ -32,35 +33,35 @@ Currently, there are no pre-built ROCm wheels. - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) - [PyTorch](https://pytorch.org/) - For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. + For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: ```bash # Install PyTorch pip uninstall torch -y - pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 + pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4 ``` -1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) +1. Install [Triton for ROCm](https://github.com/triton-lang/triton) - Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md) + Install ROCm's Triton (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md) ```bash python3 -m pip install ninja cmake wheel pybind11 pip uninstall -y triton - git clone https://github.com/OpenAI/triton.git + git clone https://github.com/triton-lang/triton.git cd triton git checkout e5be006 - cd python - pip3 install . + if [ ! -f setup.py ]; then cd python; fi + python3 setup.py install cd ../.. ``` !!! note If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. -2. 
Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention) +2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/Dao-AILab/flash-attention) Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. @@ -68,9 +69,9 @@ Currently, there are no pre-built ROCm wheels. For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. ```bash - git clone https://github.com/ROCm/flash-attention.git + git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention - git checkout b7d29fb + git checkout 1a7f4dfa git submodule update --init GPU_ARCHS="gfx90a" python3 setup.py install cd .. @@ -119,7 +120,7 @@ Currently, there are no pre-built ROCm wheels. This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation. !!! tip - - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers. - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. - The ROCm version of PyTorch, ideally, should match the ROCm driver version. @@ -149,7 +150,7 @@ Build a docker image from which setup ROCm **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** If you choose to build this rocm_base image yourself, the steps are as follows. -It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```json { @@ -170,7 +171,7 @@ DOCKER_BUILDKIT=1 docker build \ #### Build an image with vLLM First, build a docker image from and launch a docker container from the image. -It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: +It is important that the user kicks off the docker build using buildkit. 
Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
```bash
{
@@ -194,16 +195,6 @@ To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm .
```
-To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
-
-```bash
-DOCKER_BUILDKIT=1 docker build \
-    --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" \
-    -f docker/Dockerfile.rocm \
-    -t vllm-rocm \
-    .
-```
-
To run the above docker image `vllm-rocm`, use the below command:
??? console "Command"
@@ -218,8 +209,7 @@ To run the above docker image `vllm-rocm`, use the below command:
        --device /dev/kfd \
        --device /dev/dri \
        -v :/app/model \
-        vllm-rocm \
-        bash
+        vllm-rocm
```
Where the `` is the location where the model is stored, for example, the weights for llama2 or llama3 models.
diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md
index b77c4e00cf0c4..2e73ac1825694 100644
--- a/docs/getting_started/installation/gpu/xpu.inc.md
+++ b/docs/getting_started/installation/gpu/xpu.inc.md
@@ -3,13 +3,16 @@
vLLM initially supports basic model inference and serving on Intel GPU platform.
!!! warning
-    There are no pre-built wheels or images for this device, so you must build vLLM from source.
+    There are no pre-built wheels for this device, so you need to build vLLM from source. Alternatively, you can use the pre-built images, which are based on released vLLM versions.
# --8<-- [end:installation]
# --8<-- [start:requirements]
- Supported Hardware: Intel Data Center GPU, Intel ARC GPU
-- OneAPI requirements: oneAPI 2025.0
+- OneAPI requirements: oneAPI 2025.1
+- Python: 3.12
+!!! warning
+    The provided IPEX wheel is built specifically for Python 3.12, so this Python version is required.
# --8<-- [end:requirements]
# --8<-- [start:set-up-using-python]
@@ -24,7 +27,7 @@ Currently, there are no pre-built XPU wheels.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
-- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.0 or later.
+- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later.
- Second, install Python packages for vLLM XPU backend building:
```bash
@@ -40,14 +43,10 @@ pip install -v -r requirements/xpu.txt
VLLM_TARGET_DEVICE=xpu python setup.py install
```
-!!! note
-    - FP16 is the default data type in the current XPU backend. The BF16 data
-      type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet.
-
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
-Currently, there are no pre-built XPU images.
+Currently, we release prebuilt XPU images on [Docker Hub](https://hub.docker.com/r/intel/vllm/tags) for released vLLM versions. For more information, please refer to the release [notes](https://github.com/intel/ai-containers/blob/main/vllm).
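As a hedged sketch of how one of these prebuilt images might be used: the tag, device flag, and container command below are assumptions for illustration only, not taken from the Intel release notes; check the Docker Hub tags page for real tags.

```bash
# Hypothetical tag and flags, shown for illustration only.
docker pull intel/vllm:<tag>
docker run -it --rm \
    --device /dev/dri \
    --ipc=host \
    -p 8000:8000 \
    intel/vllm:<tag> \
    vllm serve facebook/opt-13b --dtype=bfloat16
```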
# --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] @@ -65,14 +64,13 @@ docker run -it \ # --8<-- [end:build-image-from-source] # --8<-- [start:supported-features] -XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. We require Ray as the distributed runtime backend. For example, a reference execution like following: +XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. For **pipeline parallel**, we support it on single node with mp as the backend. For example, a reference execution like following: ```bash -python -m vllm.entrypoints.openai.api_server \ - --model=facebook/opt-13b \ +vllm serve facebook/opt-13b \ --dtype=bfloat16 \ --max_model_len=1024 \ - --distributed-executor-backend=ray \ + --distributed-executor-backend=mp \ --pipeline-parallel-size=2 \ -tp=8 ``` diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md deleted file mode 100644 index 61b2b02aa10ba..0000000000000 --- a/docs/getting_started/installation/intel_gaudi.md +++ /dev/null @@ -1,388 +0,0 @@ -# Intel Gaudi - -This page provides instructions on running vLLM with Intel Gaudi devices. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - -## Requirements - -- OS: Ubuntu 22.04 LTS -- Python: 3.10 -- Intel Gaudi accelerator -- Intel Gaudi software version 1.18.0 - -Please follow the instructions provided in the -[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html) -to set up the execution environment. To achieve the best performance, -please follow the methods outlined in the -[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html). - -## Configure a new environment - -### Environment verification - -To verify that the Intel Gaudi software was correctly installed, run: - -```bash -hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible -apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed -pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed -pip list | grep neural # verify that neural_compressor_pt is installed -``` - -Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) -for more details. - -### Run Docker Image - -It is highly recommended to use the latest Docker image from Intel Gaudi -vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers) -for more details. 
- -Use the following commands to run a Docker image: - -```bash -docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --ipc=host \ - vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest -``` - -## Set up using Python - -### Pre-built wheels - -Currently, there are no pre-built Intel Gaudi wheels. - -### Build wheel from source - -To build and install vLLM from source, run: - -```bash -git clone https://github.com/vllm-project/vllm.git -cd vllm -pip install -r requirements/hpu.txt -python setup.py develop -``` - -Currently, the latest features and performance optimizations are developed in Gaudi's [vLLM-fork](https://github.com/HabanaAI/vllm-fork) and we periodically upstream them to vLLM main repo. To install latest [HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the following: - -```bash -git clone https://github.com/HabanaAI/vllm-fork.git -cd vllm-fork -git checkout habana_main -pip install -r requirements/hpu.txt -python setup.py develop -``` - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Intel Gaudi images. - -### Build image from source - -```bash -docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . -docker run \ - -it \ - --runtime=habana \ - -e HABANA_VISIBLE_DEVICES=all \ - -e OMPI_MCA_btl_vader_single_copy_mechanism=none \ - --cap-add=sys_nice \ - --net=host \ - --rm vllm-hpu-env -``` - -!!! tip - If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. - -## Extra information - -### Supported features - -- [Offline inference](../../serving/offline_inference.md) -- Online serving via [OpenAI-Compatible Server](../../serving/openai_compatible_server.md) -- HPU autodetection - no need to manually select device within vLLM -- Paged KV cache with algorithms enabled for Intel Gaudi accelerators -- Custom Intel Gaudi implementations of Paged Attention, KV cache ops, - prefill attention, Root Mean Square Layer Normalization, Rotary - Positional Encoding -- Tensor parallelism support for multi-card inference -- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) - for accelerating low-batch latency and throughput -- Attention with Linear Biases (ALiBi) -- INC quantization - -### Unsupported features - -- Beam search -- LoRA adapters -- AWQ quantization -- Prefill chunking (mixed-batch inferencing) - -### Supported configurations - -The following configurations have been validated to function with -Gaudi2 devices. Configurations that are not listed may or may not work. 
- -| Model | TP Size| dtype | Sampling | -|-------|--------|--------|----------| -| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy | -| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy | -| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy | - -## Performance tuning - -### Execution modes - -Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via `PT_HPU_LAZY_MODE` environment variable), and `--enforce-eager` flag. - -| `PT_HPU_LAZY_MODE` | `enforce_eager` | execution mode | -|----------------------|-------------------|--------------------| -| 0 | 0 | torch.compile | -| 0 | 1 | PyTorch eager mode | -| 1 | 0 | HPU Graphs | - -!!! warning - In 1.18.0, all modes utilizing `PT_HPU_LAZY_MODE=0` are highly experimental and should be only used for validating functional correctness. Their performance will be improved in the next releases. For obtaining the best performance in 1.18.0, please use HPU Graphs, or PyTorch lazy mode. - -[](){ #gaudi-bucketing-mechanism } - -### Bucketing mechanism - -Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. [Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime) is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. -In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - `batch_size` and `sequence_length`. - -!!! 
note - Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase. - -Bucketing ranges are determined with 3 parameters - `min`, `step` and `max`. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: - -```text -INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -``` - -| Parameter | Description | -|----------------|-----------------------------------------------------------------------------| -| `min` | Determines the lowest value of the bucket. | -| `step` | Determines the interval between buckets. | -| `max` | Determines the upper bound of the bucket. | -| Ramp-up phase | A special handling phase applied between `min` and `step`:
- `min` is multiplied by consecutive powers of two until `step` is reached.
- Minimizes resource wastage for small batch sizes.
- Allows larger padding for larger batches. | - -Example (with ramp-up): - -```text -min = 2, step = 32, max = 64 -=> ramp_up = (2, 4, 8, 16) -=> stable = (32, 64) -=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64) -``` - -Example (without ramp-up): - -```text -min = 128, step = 128, max = 512 -=> ramp_up = () -=> stable = (128, 256, 384, 512) -=> buckets = ramp_up + stable => (128, 256, 384, 512) -``` - -In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket. - -!!! warning - If a request exceeds maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such scenario. - -As an example, if a request of 3 sequences, with max sequence length of 412 comes in to an idle vLLM server, it will be padded executed as `(4, 512)` prefill bucket, as `batch_size` (number of sequences) will be padded to 4 (closest batch_size dimension higher than 3), and max sequence length will be padded to 512 (closest sequence length dimension higher than 412). After prefill stage, it will be executed as `(4, 512)` decode bucket and will continue as that bucket until either batch dimension changes (due to request being finished) - in which case it will become a `(2, 512)` bucket, or context length increases above 512 tokens, in which case it will become `(4, 640)` bucket. - -!!! note - Bucketing is transparent to a client -- padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. - -### Warmup - -Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: - -??? console "Logs" - - ```text - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB - INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB - INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB - ... - INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB - INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB - INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB - ... 
- INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - ``` - -This example uses the same buckets as in the [Bucketing Mechanism][gaudi-bucketing-mechanism] section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. - -!!! tip - Compiling all the buckets might take some time and can be turned off with `VLLM_SKIP_WARMUP=true` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. - -### HPU Graph capture - -[HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. - -When HPU Graphs are being used, they share the common memory pool ("usable memory") as KV cache, determined by `gpu_memory_utilization` flag (`0.9` by default). -Before KV cache gets allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data, to estimate memory usage. -Only after that, `gpu_memory_utilization` flag is utilized - at its default value, will mark 90% of free device memory at that point as usable. -Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured. -Environment variable `VLLM_GRAPH_RESERVED_MEM` defines the ratio of memory reserved for HPU Graphs capture. -With its default value (`VLLM_GRAPH_RESERVED_MEM=0.1`), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache. -Environment variable `VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (`VLLM_GRAPH_PROMPT_RATIO=0.3`), both stages have equal memory constraints. -Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs. - -!!! note - `gpu_memory_utilization` does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, `gpu_memory_utilization` at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory. - -User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. 
There are two strategies implemented: - -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt - -When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. - -!!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. - -Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): - -??? 
console "Logs" - - ```text - INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache - INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 - INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB - ... - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) - INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - ... - INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB - INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB - ... 
- INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB - INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB - INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] - INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory - INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) - ``` - -### Recommended vLLM Parameters - -- We recommend running inference on Gaudi 2 with `block_size` of 128 - for BF16 data type. Using default values (16, 32) might lead to - sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). -- For max throughput on Llama 7B, we recommend running with batch size - of 128 or 256 and max context length of 2048 with HPU Graphs enabled. - If you encounter out-of-memory issues, see troubleshooting section. - -### Environment variables - -**Diagnostic and profiling knobs:** - -- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default. 
- -**Performance tuning knobs:** - -- `VLLM_SKIP_WARMUP`: if `true`, warmup will be skipped, `false` by default - -- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for HPUGraph capture, `0.1` by default - -- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory dedicated for prompt graphs, `0.3` by default - -- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt graph capture, `min_tokens` or `max_bs`, `min_tokens` by default - -- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode graph capture, `min_tokens` or `max_bs`, `max_bs` by default - -- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment variables configuring ranges of bucketing mechanism - - - `{phase}` is either `PROMPT` or `DECODE` - - - `{dim}` is either `BS`, `SEQ` or `BLOCK` - - - `{param}` is either `MIN`, `STEP` or `MAX` - - - Default values: - -| `{phase}` | Parameter | Env Variable | Value Expression | -|-----------|-----------|--------------|------------------| -| Prompt | Batch size min | `VLLM_PROMPT_BS_BUCKET_MIN` | `1` | -| Prompt | Batch size step | `VLLM_PROMPT_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Prompt | Batch size max | `VLLM_PROMPT_BS_BUCKET_MAX` | `min(max_num_seqs, 64)` | -| Prompt | Sequence length min | `VLLM_PROMPT_SEQ_BUCKET_MIN` | `block_size` | -| Prompt | Sequence length step | `VLLM_PROMPT_SEQ_BUCKET_STEP` | `block_size` | -| Prompt | Sequence length max | `VLLM_PROMPT_SEQ_BUCKET_MAX` | `max_model_len` | -| Decode | Batch size min | `VLLM_DECODE_BS_BUCKET_MIN` | `1` | -| Decode | Batch size step | `VLLM_DECODE_BS_BUCKET_STEP` | `min(max_num_seqs, 32)` | -| Decode | Batch size max | `VLLM_DECODE_BS_BUCKET_MAX` | `max_num_seqs` | -| Decode | Sequence length min | `VLLM_DECODE_BLOCK_BUCKET_MIN` | `block_size` | -| Decode | Sequence length step | `VLLM_DECODE_BLOCK_BUCKET_STEP` | `block_size` | -| Decode | Sequence length max | `VLLM_DECODE_BLOCK_BUCKET_MAX` | `max(128, (max_num_seqs*max_model_len)/block_size)` | - -Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution: - -- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default. -- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs - -## Troubleshooting: tweaking HPU graphs - -If you experience device out-of-memory issues or want to attempt -inference at higher batch sizes, try tweaking HPU Graphs by following -the below: - -- Tweak `gpu_memory_utilization` knob. It will decrease the - allocation of KV cache, leaving some headroom for capturing graphs - with larger batch size. By default `gpu_memory_utilization` is set - to 0.9. It attempts to allocate ~90% of HBM left for KV cache after - short profiling run. Note that decreasing reduces the number of KV - cache blocks you have available, and therefore reduces the effective - maximum number of tokens you can handle at a given time. -- If this method is not efficient, you can disable `HPUGraph` - completely. With HPU Graphs disabled, you are trading latency and - throughput at lower batches for potentially higher throughput on - higher batches. You can do that by adding `--enforce-eager` flag to - server (for online serving), or by passing `enforce_eager=True` - argument to LLM constructor (for offline inference). 
diff --git a/docs/getting_started/installation/python_env_setup.inc.md b/docs/getting_started/installation/python_env_setup.inc.md index 423bf9b00d07f..06794f8d3120e 100644 --- a/docs/getting_started/installation/python_env_setup.inc.md +++ b/docs/getting_started/installation/python_env_setup.inc.md @@ -1,4 +1,4 @@ -It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: +It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands: ```bash uv venv --python 3.12 --seed diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 2af26626d207d..49e1f6fac7151 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform: ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.13 +- Python: 3.10 -- 3.13 ## Installation diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ed5d3b0092ae7..ecd71ee1f3f66 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import logging import sys from argparse import SUPPRESS, HelpFormatter @@ -7,34 +8,64 @@ from pathlib import Path from typing import Literal from unittest.mock import MagicMock, patch +from pydantic_core import core_schema + +logger = logging.getLogger("mkdocs") + ROOT_DIR = Path(__file__).parent.parent.parent.parent ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" sys.path.insert(0, str(ROOT_DIR)) -sys.modules["aiohttp"] = MagicMock() -sys.modules["blake3"] = MagicMock() sys.modules["vllm._C"] = MagicMock() -from vllm.benchmarks import latency # noqa: E402 -from vllm.benchmarks import serve # noqa: E402 -from vllm.benchmarks import throughput # noqa: E402 -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 -from vllm.entrypoints.cli.openai import ChatCommand # noqa: E402 -from vllm.entrypoints.cli.openai import CompleteCommand # noqa: E402 -from vllm.entrypoints.openai import cli_args # noqa: E402 -from vllm.entrypoints.openai import run_batch # noqa: E402 -from vllm.utils import FlexibleArgumentParser # noqa: E402 -logger = logging.getLogger("mkdocs") +class PydanticMagicMock(MagicMock): + """`MagicMock` that's able to generate pydantic-core schemas.""" + + def __get_pydantic_core_schema__(self, source_type, handler): + return core_schema.any_schema() + + +def auto_mock(module, attr, max_mocks=50): + """Function that automatically mocks missing modules during imports.""" + logger.info("Importing %s from %s", attr, module) + for _ in range(max_mocks): + try: + # First treat attr as an attr, then as a submodule + with patch("importlib.metadata.version", return_value="0.0.0"): + return getattr( + importlib.import_module(module), + attr, + importlib.import_module(f"{module}.{attr}"), + ) + except 
importlib.metadata.PackageNotFoundError as e: + raise e + except ModuleNotFoundError as e: + logger.info("Mocking %s for argparse doc generation", e.name) + sys.modules[e.name] = PydanticMagicMock() + + raise ImportError( + f"Failed to import {module}.{attr} after mocking {max_mocks} imports" + ) + + +latency = auto_mock("vllm.benchmarks", "latency") +serve = auto_mock("vllm.benchmarks", "serve") +throughput = auto_mock("vllm.benchmarks", "throughput") +AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") +EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") +ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") +CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") +cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") +run_batch = auto_mock("vllm.entrypoints.openai", "run_batch") +FlexibleArgumentParser = auto_mock("vllm.utils", "FlexibleArgumentParser") class MarkdownFormatter(HelpFormatter): """Custom formatter that generates markdown for argument groups.""" def __init__(self, prog, starting_heading_level=3): - super().__init__(prog, - max_help_position=float('inf'), - width=float('inf')) + super().__init__(prog, max_help_position=float("inf"), width=float("inf")) self._section_heading_prefix = "#" * starting_heading_level self._argument_heading_prefix = "#" * (starting_heading_level + 1) self._markdown_output = [] @@ -56,23 +87,19 @@ class MarkdownFormatter(HelpFormatter): def add_arguments(self, actions): for action in actions: - if (len(action.option_strings) == 0 - or "--help" in action.option_strings): + if len(action.option_strings) == 0 or "--help" in action.option_strings: continue - option_strings = f'`{"`, `".join(action.option_strings)}`' + option_strings = f"`{'`, `'.join(action.option_strings)}`" heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n" self._markdown_output.append(heading_md) if choices := action.choices: - choices = f'`{"`, `".join(str(c) for c in choices)}`' - self._markdown_output.append( - f"Possible choices: {choices}\n\n") - elif ((metavar := action.metavar) - and isinstance(metavar, (list, tuple))): - metavar = f'`{"`, `".join(str(m) for m in metavar)}`' - self._markdown_output.append( - f"Possible choices: {metavar}\n\n") + choices = f"`{'`, `'.join(str(c) for c in choices)}`" + self._markdown_output.append(f"Possible choices: {choices}\n\n") + elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)): + metavar = f"`{'`, `'.join(str(m) for m in metavar)}`" + self._markdown_output.append(f"Possible choices: {metavar}\n\n") if action.help: self._markdown_output.append(f"{action.help}\n\n") @@ -87,7 +114,7 @@ class MarkdownFormatter(HelpFormatter): def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: """Create a parser for the given class with markdown formatting. - + Args: cls: The class to create a parser for **kwargs: Additional keyword arguments to pass to `cls.add_cli_args`. 
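For readers skimming this hook, the `auto_mock` helper above boils down to a retry-on-missing-import loop. A stripped-down sketch of that pattern (omitting the pydantic-aware mock and the `importlib.metadata.version` patch the real hook uses) might look like:

```python
import importlib
import sys
from unittest.mock import MagicMock


def import_with_mocks(module: str, attr: str, max_mocks: int = 50):
    """Import `attr` from `module`, stubbing out missing third-party deps along the way."""
    for _ in range(max_mocks):
        try:
            mod = importlib.import_module(module)
            if hasattr(mod, attr):
                return getattr(mod, attr)
            # Not a plain attribute: treat it as a submodule instead.
            return importlib.import_module(f"{module}.{attr}")
        except ModuleNotFoundError as e:
            # Mock whichever dependency was missing and retry the import.
            sys.modules[e.name] = MagicMock()
    raise ImportError(f"Failed to import {module}.{attr} after {max_mocks} mocks")
```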
@@ -114,29 +141,23 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # Create parsers to document parsers = { - "engine_args": - create_parser(EngineArgs.add_cli_args), - "async_engine_args": - create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True), - "serve": - create_parser(cli_args.make_arg_parser), - "chat": - create_parser(ChatCommand.add_cli_args), - "complete": - create_parser(CompleteCommand.add_cli_args), - "bench_latency": - create_parser(latency.add_cli_args), - "bench_throughput": - create_parser(throughput.add_cli_args), - "bench_serve": - create_parser(serve.add_cli_args), - "run-batch": - create_parser(run_batch.make_arg_parser), + "engine_args": create_parser(EngineArgs.add_cli_args), + "async_engine_args": create_parser( + AsyncEngineArgs.add_cli_args, async_args_only=True + ), + "serve": create_parser(cli_args.make_arg_parser), + "chat": create_parser(ChatCommand.add_cli_args), + "complete": create_parser(CompleteCommand.add_cli_args), + "bench_latency": create_parser(latency.add_cli_args), + "bench_throughput": create_parser(throughput.add_cli_args), + "bench_serve": create_parser(serve.add_cli_args), + "run-batch": create_parser(run_batch.make_arg_parser), } # Generate documentation for each parser for stem, parser in parsers.items(): doc_path = ARGPARSE_DOC_DIR / f"{stem}.md" - with open(doc_path, "w") as f: - f.write(parser.format_help()) + # Specify encoding for building on Windows + with open(doc_path, "w", encoding="utf-8") as f: + f.write(super(type(parser), parser).format_help()) logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 1e8b848db46d8..ed8277f628d4b 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -11,7 +11,7 @@ import regex as re logger = logging.getLogger("mkdocs") ROOT_DIR = Path(__file__).parent.parent.parent.parent -ROOT_DIR_RELATIVE = '../../../../..' +ROOT_DIR_RELATIVE = "../../../../.." EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" @@ -36,7 +36,7 @@ def fix_case(text: str) -> str: r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 } for pattern, repl in subs.items(): - text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) + text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE) return text @@ -58,7 +58,8 @@ class Example: determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. determine_title() -> str: Determines the title of the document. generate() -> str: Generates the documentation content. - """ # noqa: E501 + """ # noqa: E501 + path: Path category: str = None main_file: Path = field(init=False) @@ -70,6 +71,10 @@ class Example: self.other_files = self.determine_other_files() self.title = self.determine_title() + @property + def is_code(self) -> bool: + return self.main_file.suffix != ".md" + def determine_main_file(self) -> Path: """ Determines the main file in the given path. @@ -80,9 +85,8 @@ class Example: Markdown file found in the directory. Raises: IndexError: If no Markdown files are found in the directory. 
- """ # noqa: E501 - return self.path if self.path.is_file() else list( - self.path.glob("*.md")).pop() + """ # noqa: E501 + return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop() def determine_other_files(self) -> list[Path]: """ @@ -94,15 +98,49 @@ class Example: Returns: list[Path]: A list of Path objects representing the other files in the directory. - """ # noqa: E501 + """ # noqa: E501 if self.path.is_file(): return [] is_other_file = lambda file: file.is_file() and file != self.main_file return [file for file in self.path.rglob("*") if is_other_file(file)] def determine_title(self) -> str: + if not self.is_code: + # Specify encoding for building on Windows + with open(self.main_file, encoding="utf-8") as f: + first_line = f.readline().strip() + match = re.match(r"^#\s+(?P.+)$", first_line) + if match: + return match.group("title") return fix_case(self.path.stem.replace("_", " ").title()) + def fix_relative_links(self, content: str) -> str: + """ + Fix relative links in markdown content by converting them to gh-file + format. + + Args: + content (str): The markdown content to process + + Returns: + str: Content with relative links converted to gh-file format + """ + # Regex to match markdown links [text](relative_path) + # This matches links that don't start with http, https, ftp, or # + link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)" + + def replace_link(match): + link_text = match.group(1) + relative_path = match.group(2) + + # Make relative to repo root + gh_file = (self.main_file.parent / relative_path).resolve() + gh_file = gh_file.relative_to(ROOT_DIR) + + return f"[{link_text}](gh-file:{gh_file})" + + return re.sub(link_pattern, replace_link, content) + def generate(self) -> str: content = f"# {self.title}\n\n" content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n" @@ -110,12 +148,18 @@ class Example: # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - is_code = self.main_file.suffix != ".md" - if is_code: - content += f"{code_fence}{self.main_file.suffix[1:]}\n" - content += f'--8<-- "{self.main_file}"\n' - if is_code: - content += f"{code_fence}\n" + + if self.is_code: + content += ( + f"{code_fence}{self.main_file.suffix[1:]}\n" + f'--8<-- "{self.main_file}"\n' + f"{code_fence}\n" + ) + else: + with open(self.main_file) as f: + # Skip the title from md snippets as it's been included above + main_content = f.readlines()[1:] + content += self.fix_relative_links("".join(main_content)) content += "\n" if not self.other_files: @@ -162,6 +206,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): doc_path = EXAMPLE_DOC_DIR / example.category / example_name if not doc_path.parent.exists(): doc_path.parent.mkdir(parents=True) - with open(doc_path, "w+") as f: + # Specify encoding for building on Windows + with open(doc_path, "w+", encoding="utf-8") as f: f.write(example.generate()) logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/remove_announcement.py b/docs/mkdocs/hooks/remove_announcement.py index 1a84039abc14f..12db2265b9f82 100644 --- a/docs/mkdocs/hooks/remove_announcement.py +++ b/docs/mkdocs/hooks/remove_announcement.py @@ -7,7 +7,7 @@ from typing import Literal def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa - if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag": + if 
os.getenv("READTHEDOCS_VERSION_TYPE") == "tag": # remove the warning banner if the version is a tagged release mkdocs_dir = Path(__file__).parent.parent announcement_path = mkdocs_dir / "overrides/main.html" diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index 6fce6bd8130e0..53b1fbca26b9d 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -25,8 +25,9 @@ from mkdocs.structure.files import Files from mkdocs.structure.pages import Page -def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, - files: Files) -> str: +def on_page_markdown( + markdown: str, *, page: Page, config: MkDocsConfig, files: Files +) -> str: """ Custom MkDocs plugin hook to rewrite special GitHub reference links in Markdown. @@ -35,7 +36,7 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, GitHub shorthand links, such as: - `[Link text](gh-issue:123)` - `<gh-pr:456>` - + And rewrites them into fully-qualified GitHub URLs with GitHub icons: - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` @@ -88,21 +89,21 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, """ Replaces a matched inline-style GitHub shorthand link with a full Markdown link. - + Example: [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) """ - url = f'{urls[match.group("type")]}/{match.group("path")}' + url = f"{urls[match.group('type')]}/{match.group('path')}" if fragment := match.group("fragment"): url += f"#{fragment}" - return f'[{gh_icon} {match.group("title")}]({url})' + return f"[{gh_icon} {match.group('title')}]({url})" def replace_auto_link(match: re.Match) -> str: """ Replaces a matched autolink-style GitHub shorthand with a full Markdown link. 
- + Example: <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) """ diff --git a/docs/mkdocs/javascript/mathjax.js b/docs/mkdocs/javascript/mathjax.js new file mode 100644 index 0000000000000..5da0d443578c4 --- /dev/null +++ b/docs/mkdocs/javascript/mathjax.js @@ -0,0 +1,20 @@ +// Enables MathJax rendering +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/docs/models/extensions/runai_model_streamer.md b/docs/models/extensions/runai_model_streamer.md index 992dddf385d0d..8a97a49825a41 100644 --- a/docs/models/extensions/runai_model_streamer.md +++ b/docs/models/extensions/runai_model_streamer.md @@ -24,6 +24,13 @@ vllm serve s3://core-llm/Llama-3-8b \ --load-format runai_streamer ``` +To run model from Google Cloud Storage run: + +```bash +vllm serve gs://core-llm/Llama-3-8b \ + --load-format runai_streamer +``` + To run model from a S3 compatible object store run: ```bash diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index a64ecd31ebaef..05f8d16cc4ca7 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -4,7 +4,7 @@ vLLM provides first-class support for generative models, which covers most of LL In vLLM, generative models implement the[VllmModelForTextGeneration][vllm.model_executor.models.VllmModelForTextGeneration] interface. Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, -which are then passed through [Sampler][vllm.model_executor.layers.sampler.Sampler] to obtain the final text. +which are then passed through [Sampler][vllm.v1.sample.sampler.Sampler] to obtain the final text. ## Configuration @@ -19,7 +19,7 @@ Run a model in generation mode via the option `--runner generate`. ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.generate` diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 39f209d0eb7ed..50982d3d0d0f3 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -59,7 +59,7 @@ enabling the corresponding APIs: #### Predefined models If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`, -you can override some of its attributes via the `--override-pooler-config` option. +you can override some of its attributes via the `--pooler-config` option. #### Converted models @@ -75,13 +75,13 @@ the pooler assigned to each task has the following attributes by default: When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults. -You can further customize this via the `--override-pooler-config` option, +You can further customize this via the `--pooler-config` option, which takes priority over both the model's and Sentence Transformers's defaults. 
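For illustration, a hedged sketch of what such an override could look like on the command line; the model name and the JSON fields shown (`pooling_type`, `normalize`) are assumptions for this example rather than a definitive reference, since the accepted fields depend on the Pooler defined by the model:

```bash
# Illustrative only: override pooler attributes at serve time.
# The exact fields accepted by --pooler-config depend on the model's Pooler.
vllm serve sentence-transformers/all-MiniLM-L6-v2 \
    --pooler-config '{"pooling_type": "MEAN", "normalize": true}'
```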
## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. -See [configuration](../api/summary.md#configuration) for a list of options when initializing the model. +See [configuration](../api/README.md#configuration) for a list of options when initializing the model. ### `LLM.embed` @@ -205,12 +205,12 @@ Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json,` it is allowed to change the output to arbitrary dimensions. Using `matryoshka_dimensions` can control the allowed output dimensions. -For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). +For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}`, `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline) or `--hf-overrides '{"is_matryoshka": true}'`, `--hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'`(online). Here is an example to serve a model with Matryoshka Embeddings enabled. ```text -vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}' ``` ### Offline Inference @@ -228,7 +228,7 @@ outputs = llm.embed(["Follow the white rabbit."], print(outputs[0].outputs) ``` -A code example can be found here: <gh-file:examples/offline_inference/embed_matryoshka_fy.py> +A code example can be found here: <gh-file:examples/offline_inference/pooling/embed_matryoshka_fy.py> ### Online Inference @@ -258,4 +258,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -A openai client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> +An OpenAI client example can be found here: <gh-file:examples/online_serving/pooling/openai_embedding_matryoshka_fy.py> diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 297d98142b5f2..157fa8d68de5d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -17,9 +17,26 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! 
Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases. +vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". -To check if the modeling backend is Transformers, you can simply do this: +Currently, the Transformers backend works for the following: + +- Modalities: embedding models, language models and vision-language models* +- Architectures: encoder-only, decoder-only, mixture-of-experts +- Attention types: full attention and/or sliding attention + +_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._ + +If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM: + +- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature) +- Any combination of the following vLLM parallelisation schemes: + - Data parallel + - Tensor parallel + - Expert parallel + - Pipeline parallel + +Checking if the modeling backend is Transformers is as simple as: ```python from vllm import LLM @@ -27,20 +44,16 @@ llm = LLM(model=...) # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers! +If the printed type starts with `Transformers...` then it's using the Transformers model implementation! -!!! tip - You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md). +If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md). !!! note - vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. - -!!! note - In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. #### Custom models -If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! +If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM! 
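+
+As a rough sketch of what this looks like in practice (the model id below is hypothetical, and
+the checkpoint must meet the requirements listed next):
+
+```python
+from vllm import LLM
+
+# Hypothetical remote-code model served via the Transformers backend.
+llm = LLM(
+    model="your-org/my-custom-model",
+    model_impl="transformers",  # force the Transformers backend
+    trust_remote_code=True,     # load the modeling code shipped with the checkpoint
+)
+
+print(llm.generate("Hello, my name is")[0].outputs[0].text)
+```
+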
For a model to be compatible with the Transformers backend for vLLM it must: @@ -66,10 +79,11 @@ This section details the necessary modifications to make to a Transformers compa To make your model compatible with the Transformers backend, it needs: 1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. + 1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`. 2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. 3. `MyModel` must contain `_supports_attention_backend = True`. -<details> +<details class="code"> <summary>modeling_my_model.py</summary> ```python @@ -78,6 +92,7 @@ from transformers import PreTrainedModel from torch import nn class MyAttention(nn.Module): + is_causal = False # Only do this for encoder-only models def forward(self, hidden_states, **kwargs): ... @@ -101,13 +116,13 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into one of the Transformers backend classes in <gh-file:vllm/model_executor/models/transformers.py> which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class: -<details> +<details class="code"> <summary>configuration_my_model.py</summary> ```python @@ -322,23 +337,24 @@ th { | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | ✅︎ | | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | -| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | -| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. 
| | ✅︎ | ✅︎ | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | | `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | +| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -347,26 +363,27 @@ th { | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | | `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | +| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ | | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -374,6 +391,7 @@ th { | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | @@ -387,20 +405,21 @@ th { | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | -| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | @@ -411,6 +430,7 @@ th { | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ | | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ | | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | +| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ |✅︎ | ✅︎ | Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! @@ -421,9 +441,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. -!!! 
note - Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture. - ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. @@ -440,6 +457,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | @@ -456,7 +474,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. - You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. + You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`. !!! note For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded. @@ -497,6 +515,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | @@ -513,6 +532,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' ``` +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture. + !!! note Load the official original `mxbai-rerank-v2` by using the following command. @@ -521,7 +543,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ``` !!! note - Load the official original `Qwen3 Reranker` by using the following command. 
More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>. + Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/pooling/qwen3_reranker.py>. ```bash vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' @@ -547,7 +569,19 @@ If your model is not in the above list, we will try to automatically convert the !!! important For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, - e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + +#### Token Classification + +These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API. + +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| +| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ | +| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ | + +!!! note + Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>. [](){ #supported-mm-models } @@ -597,7 +631,29 @@ See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inp For hybrid-only models such as Llama-4, Step3 and Mistral-3, a text-only mode can be enabled by setting all supported multimodal modalities to 0 (e.g, `--limit-mm-per-prompt '{"image":0}`) so that their multimodal modules will not be loaded to free up more GPU memory for KV cache. !!! note - vLLM currently only supports adding LoRA to the language backbone of multimodal models. + vLLM currently only supports dynamic LoRA adapters on the language backbone of multimodal models. + If you wish to use a model with LoRA in the multi-modal encoder, + please merge the weights into the base model first before running it in vLLM like a regular model. + + ```python + from peft import PeftConfig, PeftModel + from transformers import AutoModelForImageTextToText, AutoProcessor + + def merge_and_save(model_id: str, output_dir: str): + base_model = AutoModelForImageTextToText.from_pretrained(model_id) + lora_model = PeftModel.from_pretrained( + base_model, + model_id, + config=PeftConfig.from_pretrained(model_id), + ) + model = lora_model.merge_and_unload().to(dtype=base_model.dtype) + model._hf_peft_config_loaded = False # Needed to save the merged model + + processor = AutoProcessor.from_pretrained(model_id) + + model.save_pretrained(output_dir) + processor.save_pretrained(output_dir) + ``` ### Generative Models @@ -615,7 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. 
| | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | -| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | @@ -625,9 +681,11 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | -| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | +| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ | +| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. 
| | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | @@ -635,11 +693,11 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | @@ -653,7 +711,10 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ | | `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | @@ -700,7 +761,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. !!! note - Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. @@ -744,8 +805,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. !!! note - For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) - is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. + For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. #### Transcription @@ -753,8 +813,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | -| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ | +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | ✅︎ | +| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | ### Pooling Models @@ -771,8 +832,9 @@ The following table lists those that are tested in vLLM. 
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| -| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | -| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | +| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ | +| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ | +| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ | | `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion)) diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 280b3322b11c3..93ed383395f27 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -10,7 +10,7 @@ Before using EP, you need to install the necessary dependencies. We are actively 1. **Install DeepEP and pplx-kernels**: Set up host environment following vLLM's guide for EP kernels [here](gh-file:tools/ep_kernels). 2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation). -3. **For disaggregated serving**: Install UCX and NIXL following the [script](gh-file:tools/install_nixl.sh). +3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](gh-file:tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). ### Backend Selection Guide @@ -123,18 +123,46 @@ When enabled, vLLM collects load statistics with every forward pass and periodic ### EPLB Parameters +Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. The available keys and their descriptions are: + | Parameter | Description | Default | |-----------|-------------|---------| -| `--eplb-window-size` | Number of engine steps to track for rebalancing decisions | - | -| `--eplb-step-interval` | Frequency of rebalancing (every N engine steps) | - | -| `--eplb-log-balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` | -| `--num-redundant-experts` | Additional global experts per EP rank beyond equal distribution | `0` | +| `window_size`| Number of engine steps to track for rebalancing decisions | 1000 | +| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 | +| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` | +| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` | + +For example: + +```bash +vllm serve Qwen/Qwen3-30B-A3B \ + --enable-eplb \ + --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}' +``` + +??? tip "Prefer individual arguments instead of JSON?" 
+ + ```bash + vllm serve Qwen/Qwen3-30B-A3B \ + --enable-eplb \ + --eplb-config.window_size 1000 \ + --eplb-config.step_interval 3000 \ + --eplb-config.num_redundant_experts 2 \ + --eplb-config.log_balancedness true + ``` ### Expert Distribution Formula - **Default**: Each EP rank has `NUM_TOTAL_EXPERTS ÷ NUM_EP_RANKS` experts - **With redundancy**: Each EP rank has `(NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS` experts +### Memory Footprint Overhead + +EPLB uses redundant experts that need to fit in GPU memory. This means that EPLB may not be a good fit for memory constrained environments or when KV cache space is at a premium. + +This overhead equals `NUM_MOE_LAYERS * BYTES_PER_EXPERT * (NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS`. +For DeepSeekV3, this is approximately `2.4 GB` for one redundant expert per EP rank. + ### Example Command Single node deployment with EPLB enabled: @@ -146,12 +174,10 @@ VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V --data-parallel-size 8 \ # Data parallelism --enable-expert-parallel \ # Enable EP --enable-eplb \ # Enable load balancer - --eplb-log-balancedness \ # Log balancing metrics - --eplb-window-size 1000 \ # Track last 1000 engine steps - --eplb-step-interval 3000 # Rebalance every 3000 steps + --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}' ``` -For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--num-redundant-experts` to 32 in large scale use cases so the most popular experts are always available. +For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` to 32 in large scale use cases so the most popular experts are always available. ## Disaggregated Serving (Prefill/Decode Split) @@ -165,9 +191,9 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok ### Setup Steps -1. **Install KV Connector**: Install NIXL using the [installation script](gh-file:tools/install_nixl.sh) +1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. For non-cuda platform to install nixl with non-cuda UCX build, run the [install_nixl_from_source_ubuntu.py](gh-file:tools/install_nixl_from_source_ubuntu.py) script. -2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}` +2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'` 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions. 
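+
+As a small aside for step 2: if you generate launch commands from Python, building the connector
+JSON with `json.dumps` helps avoid shell-quoting mistakes. A minimal sketch (the `backends` entry
+is optional and shown only as an example):
+
+```python
+import json
+
+# JSON value for --kv-transfer-config, as used in step 2 above.
+kv_transfer_config = {
+    "kv_connector": "NixlConnector",
+    "kv_role": "kv_both",
+    "kv_connector_extra_config": {"backends": ["UCX"]},
+}
+
+print(f"--kv-transfer-config '{json.dumps(kv_transfer_config)}'")
+```
+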
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index dfed15d4ace97..fe0e1e3df378b 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -133,7 +133,7 @@ completion = client.chat.completions.create( {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ - "guided_choice": ["positive", "negative"] + "structured_outputs": {"choice": ["positive", "negative"]} } ) ``` @@ -236,10 +236,32 @@ The following extra parameters are supported: Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. -If the model has a [chat template][chat-template], you can replace `inputs` with a list of `messages` (same schema as [Chat API][chat-api]) -which will be treated as a single prompt to the model. +Code example: <gh-file:examples/online_serving/pooling/openai_embedding_client.py> -Code example: <gh-file:examples/online_serving/openai_embedding_client.py> +If the model has a [chat template][chat-template], you can replace `inputs` with a list of `messages` (same schema as [Chat API][chat-api]) +which will be treated as a single prompt to the model. Here is a convenience function for calling the API while retaining OpenAI's type annotations: + +??? code + + ```python + from openai import OpenAI + from openai._types import NOT_GIVEN, NotGiven + from openai.types.chat import ChatCompletionMessageParam + from openai.types.create_embedding_response import CreateEmbeddingResponse + + def create_chat_embeddings( + client: OpenAI, + *, + messages: list[ChatCompletionMessageParam], + model: str, + encoding_format: Union[Literal["base64", "float"], NotGiven] = NOT_GIVEN, + ) -> CreateEmbeddingResponse: + return client.post( + "/embeddings", + cast_to=CreateEmbeddingResponse, + body={"messages": messages, "model": model, "encoding_format": encoding_format}, + ) + ``` #### Multi-modal inputs @@ -254,7 +276,7 @@ and passing a list of `messages` in the request. Refer to the examples below for vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling \ --trust-remote-code \ --max-model-len 4096 \ - --chat-template examples/template_vlm2vec.jinja + --chat-template examples/template_vlm2vec_phi3v.jinja ``` !!! important @@ -262,34 +284,36 @@ and passing a list of `messages` in the request. Refer to the examples below for to run this model in embedding mode instead of text generation mode. The custom chat template is completely different from the original one for this model, - and can be found here: <gh-file:examples/template_vlm2vec.jinja> + and can be found here: <gh-file:examples/template_vlm2vec_phi3v.jinja> Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: ??? 
code ```python - import requests - + from openai import OpenAI + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="EMPTY", + ) image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - response = requests.post( - "http://localhost:8000/v1/embeddings", - json={ - "model": "TIGER-Lab/VLM2Vec-Full", - "messages": [{ + response = create_chat_embeddings( + client, + model="TIGER-Lab/VLM2Vec-Full", + messages=[ + { "role": "user", "content": [ {"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": "Represent the given image."}, ], - }], - "encoding_format": "float", - }, + } + ], + encoding_format="float", ) - response.raise_for_status() - response_json = response.json() - print("Embedding output:", response_json["data"][0]["embedding"]) + + print("Image embedding output:", response.data[0].embedding) ``` === "DSE-Qwen2-MRL" @@ -313,14 +337,15 @@ and passing a list of `messages` in the request. Refer to the examples below for `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code example below for details. -Full example: <gh-file:examples/online_serving/openai_chat_embedding_client_for_multimodal.py> +Full example: <gh-file:examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py> #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:embedding-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:embedding-pooling-params" ``` The following extra parameters are supported by default: @@ -350,13 +375,92 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. Code example: <gh-file:examples/online_serving/openai_transcription_client.py> -<!-- TODO: api enforced limits + uploading audios --> #### API Enforced Limits Set the maximum audio file size (in MB) that VLLM will accept, via the `VLLM_MAX_AUDIO_CLIP_FILESIZE_MB` environment variable. Default is 25 MB. +#### Uploading Audio Files + +The Transcriptions API supports uploading audio files in various formats including FLAC, MP3, MP4, MPEG, MPGA, M4A, OGG, WAV, and WEBM. + +**Using OpenAI Python Client:** + +??? code + + ```python + from openai import OpenAI + + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + + # Upload audio file from disk + with open("audio.mp3", "rb") as audio_file: + transcription = client.audio.transcriptions.create( + model="openai/whisper-large-v3-turbo", + file=audio_file, + language="en", + response_format="verbose_json" + ) + + print(transcription.text) + ``` + +**Using curl with multipart/form-data:** + +??? 
code + + ```bash + curl -X POST "http://localhost:8000/v1/audio/transcriptions" \ + -H "Authorization: Bearer token-abc123" \ + -F "file=@audio.mp3" \ + -F "model=openai/whisper-large-v3-turbo" \ + -F "language=en" \ + -F "response_format=verbose_json" + ``` + +**Supported Parameters:** + +- `file`: The audio file to transcribe (required) +- `model`: The model to use for transcription (required) +- `language`: The language code (e.g., "en", "zh") (optional) +- `prompt`: Optional text to guide the transcription style (optional) +- `response_format`: Format of the response ("json", "text") (optional) +- `temperature`: Sampling temperature between 0 and 1 (optional) + +For the complete list of supported parameters including sampling parameters and vLLM extensions, see the [protocol definitions](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L2182). + +**Response Format:** + +For `verbose_json` response format: + +??? code + + ```json + { + "text": "Hello, this is a transcription of the audio file.", + "language": "en", + "duration": 5.42, + "segments": [ + { + "id": 0, + "seek": 0, + "start": 0.0, + "end": 2.5, + "text": "Hello, this is a transcription", + "tokens": [50364, 938, 428, 307, 275, 28347], + "temperature": 0.0, + "avg_logprob": -0.245, + "compression_ratio": 1.235, + "no_speech_prob": 0.012 + } + ] + } + ``` + #### Extra Parameters The following [sampling parameters][sampling-params] are supported. @@ -374,7 +478,7 @@ The following extra parameters are supported: ```python --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" ``` - + [](){ #translations-api } ### Translations API @@ -421,7 +525,7 @@ Our Pooling API encodes input prompts using a [pooling model](../models/pooling_ The input format is the same as [Embeddings API][embeddings-api], but the output data can contain an arbitrary nested list, not just a 1-D list of floats. -Code example: <gh-file:examples/online_serving/openai_pooling_client.py> +Code example: <gh-file:examples/online_serving/pooling/openai_pooling_client.py> [](){ #classification-api } @@ -431,7 +535,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. -Code example: <gh-file:examples/online_serving/openai_classification_client.py> +Code example: <gh-file:examples/online_serving/pooling/openai_classification_client.py> #### Example Requests @@ -527,10 +631,11 @@ curl -v "http://127.0.0.1:8000/classify" \ #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:classification-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: @@ -733,10 +838,11 @@ Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_mu #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. 
```python ---8<-- "vllm/entrypoints/openai/protocol.py:score-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: @@ -760,7 +866,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. -Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py> +Code example: <gh-file:examples/online_serving/pooling/jinaai_rerank_client.py> #### Example Request @@ -815,10 +921,11 @@ Result documents will be sorted by relevance, and the `index` property can be us #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:rerank-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md index fa7fc1b290d50..cef1127fc5c15 100644 --- a/docs/serving/parallelism_scaling.md +++ b/docs/serving/parallelism_scaling.md @@ -66,7 +66,7 @@ Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens. -Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm/serving-llms.html) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. +Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. For details, see the [Ray documentation](https://docs.ray.io/en/latest/index.html). @@ -104,7 +104,7 @@ Note that `VLLM_HOST_IP` is unique for each worker. Keep the shells running thes From any node, enter a container and run `ray status` and `ray list nodes` to verify that Ray finds the expected number of nodes and GPUs. !!! tip - Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/vllm-rayservice.html). + Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/rayserve-llm-example.html). ### Running vLLM on a Ray cluster diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index f608a630ab7a5..b207c9ed373b8 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -1,8 +1,19 @@ # Reinforcement Learning from Human Feedback -Reinforcement Learning from Human Feedback (RLHF) is a technique that fine-tunes language models using human-generated preference data to align model outputs with desired behaviors. 
+Reinforcement Learning from Human Feedback (RLHF) is a technique that fine-tunes language models using human-generated preference data to align model outputs with desired behaviors. vLLM can be used to generate the completions for RLHF. -vLLM can be used to generate the completions for RLHF. Some ways to do this include using libraries like [TRL](https://github.com/huggingface/trl), [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF), [verl](https://github.com/volcengine/verl) and [unsloth](https://github.com/unslothai/unsloth). +The following open-source RL libraries use vLLM for fast rollouts (sorted alphabetically and non-exhaustive): + +- [Cosmos-RL](https://github.com/nvidia-cosmos/cosmos-rl) +- [NeMo-RL](https://github.com/NVIDIA-NeMo/RL) +- [Open Instruct](https://github.com/allenai/open-instruct) +- [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) +- [PipelineRL](https://github.com/ServiceNow/PipelineRL) +- [Prime-RL](https://github.com/PrimeIntellect-ai/prime-rl) +- [SkyRL](https://github.com/NovaSky-AI/SkyRL) +- [TRL](https://github.com/huggingface/trl) +- [Unsloth](https://github.com/unslothai/unsloth) +- [verl](https://github.com/volcengine/verl) See the following basic examples to get started if you don't want to use an existing library: @@ -12,4 +23,5 @@ See the following basic examples to get started if you don't want to use an exis See the following notebooks showing how to use vLLM for GRPO: +- [Efficient Online Training with GRPO and vLLM in TRL](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) - [Qwen-3 4B GRPO using Unsloth + vLLM](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) diff --git a/docs/training/trl.md b/docs/training/trl.md index c7c1a5a3bbd1e..acf48cc4ecb33 100644 --- a/docs/training/trl.md +++ b/docs/training/trl.md @@ -1,12 +1,54 @@ # Transformers Reinforcement Learning -Transformers Reinforcement Learning (TRL) is a full stack library that provides a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more. The library is integrated with 🤗 transformers. +[Transformers Reinforcement Learning](https://huggingface.co/docs/trl) (TRL) is a full stack library that provides a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more. The library is integrated with 🤗 transformers. Online methods such as GRPO or Online DPO require the model to generate completions. vLLM can be used to generate these completions! -See the guide [vLLM for fast generation in online methods](https://huggingface.co/docs/trl/main/en/speeding_up_training#vllm-for-fast-generation-in-online-methods) in the TRL documentation for more information. +See the [vLLM integration guide](https://huggingface.co/docs/trl/main/en/vllm_integration) in the TRL documentation for more information. 
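As a rough illustration of the flow described above, a minimal GRPO run with vLLM-backed generation might look like the sketch below (the dataset, reward function, and model are placeholders, and argument names follow recent TRL releases; in server mode the generation process is typically started separately, e.g. with `trl vllm-serve`, before launching training):

```python
# Illustrative sketch only; dataset, reward function, and model are placeholders.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters.
    return [-abs(50 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="grpo-vllm-demo",
    use_vllm=True,          # generate rollouts with vLLM
    vllm_mode="colocate",   # or "server", with a separate vLLM process
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```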
+ +TRL currently supports the following online trainers with vLLM: + +- [GRPO](https://huggingface.co/docs/trl/main/en/grpo_trainer) +- [Online DPO](https://huggingface.co/docs/trl/main/en/online_dpo_trainer) +- [RLOO](https://huggingface.co/docs/trl/main/en/rloo_trainer) +- [Nash-MD](https://huggingface.co/docs/trl/main/en/nash_md_trainer) +- [XPO](https://huggingface.co/docs/trl/main/en/xpo_trainer) + +To enable vLLM in TRL, set the `use_vllm` flag in the trainer configuration to `True`. + +## Modes of Using vLLM During Training + +TRL supports **two modes** for integrating vLLM during training: **server mode** and **colocate mode**. You can control how vLLM operates during training with the `vllm_mode` parameter. + +### Server mode + +In **server mode**, vLLM runs as an independent process on dedicated GPUs and communicates with the trainer through HTTP requests. This configuration is ideal when you have separate GPUs for inference, as it isolates generation workloads from training, ensuring stable performance and easier scaling. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + +### Colocate mode + +In **colocate mode**, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + +Some trainers also support **vLLM sleep mode**, which offloads parameters and caches to GPU RAM during training, helping reduce memory usage. Learn more in the [memory optimization docs](https://huggingface.co/docs/trl/main/en/reducing_memory_usage#vllm-sleep-mode). !!! info - For more information on the `use_vllm` flag you can provide to the configs of these online methods, see: - - [`trl.GRPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/grpo_trainer#trl.GRPOConfig.use_vllm) - - [`trl.OnlineDPOConfig.use_vllm`](https://huggingface.co/docs/trl/main/en/online_dpo_trainer#trl.OnlineDPOConfig.use_vllm) + For detailed configuration options and flags, refer to the documentation of the specific trainer you are using. diff --git a/docs/usage/README.md b/docs/usage/README.md index 83aea121819f8..0c63d01f0f99f 100644 --- a/docs/usage/README.md +++ b/docs/usage/README.md @@ -1,6 +1,6 @@ # Using vLLM -First, vLLM must be [installed](../getting_started/installation) for your chosen device in either a Python or Docker environment. +First, vLLM must be [installed](../getting_started/installation/) for your chosen device in either a Python or Docker environment. Then, vLLM supports the following usage patterns: diff --git a/docs/usage/security.md b/docs/usage/security.md index d54e2bb37ec07..9d10b66a5a97f 100644 --- a/docs/usage/security.md +++ b/docs/usage/security.md @@ -60,6 +60,15 @@ Key points from the PyTorch security guide: - Implement proper authentication and authorization for management interfaces - Follow the principle of least privilege for all system components +### 4. **Restrict Domains Access for Media URLs:** + +Restrict domains that vLLM can access for media URLs by setting +`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks. +(e.g. 
`--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`) + +Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP +redirects from being followed to bypass domain restrictions. + ## Security and Firewalls: Protecting Exposed vLLM Systems While vLLM is designed to allow unsafe network services to be isolated to diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index b92c6cef4a3fa..6e700d1faaa9c 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -40,6 +40,34 @@ If other strategies don't solve the problem, it's likely that the vLLM instance - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. - `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. +## Breakpoints + +Setting normal `pdb` breakpoints may not work in vLLM's codebase if they are executed in a subprocess. You will experience something like: + +``` text + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 100, in trace_dispatch + return self.dispatch_line(frame) + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 125, in dispatch_line + if self.quitting: raise BdbQuit + ^^^^^^^^^^^^^ +bdb.BdbQuit +``` + +One solution is using [forked-pdb](https://github.com/Lightning-AI/forked-pdb). Install with `pip install fpdb` and set a breakpoint with something like: + +``` python +__import__('fpdb').ForkedPdb().set_trace() +``` + +Another option is to disable multiprocessing entirely, with the `VLLM_ENABLE_V1_MULTIPROCESSING` environment variable. +This keeps the scheduler in the same process, so you can use stock `pdb` breakpoints: + +``` python +import os +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +``` + ## Incorrect network setup The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl` and the IP address should be the correct one. @@ -295,4 +323,5 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). -- To circumvent a NCCL [bug](https://github.com/NVIDIA/nccl/issues/1234) , all vLLM processes will set an environment variable `NCCL_CUMEM_ENABLE=0` to disable NCCL's `cuMem` allocator. It does not affect performance but only gives memory benefits. When external processes want to set up a NCCL connection with vLLM's processes, they should also set this environment variable, otherwise, inconsistent environment setup will cause NCCL to hang or crash, as observed in the [RLHF integration](https://github.com/OpenRLHF/OpenRLHF/pull/604) and the [discussion](gh-issue:5723#issuecomment-2554389656) . +- To address a memory overhead issue in older NCCL versions (see [bug](https://github.com/NVIDIA/nccl/issues/1234)), vLLM versions `>= 0.4.3, <= 0.10.1.1` would set the environment variable `NCCL_CUMEM_ENABLE=0`. 
External processes connecting to vLLM also needed to set this variable to prevent hangs or crashes. Since the underlying NCCL bug was fixed in NCCL 2.22.3, this override was removed in newer vLLM versions to allow for NCCL performance optimizations. +- In some PCIe machines (e.g. machines without NVLink), if you see an error like `transport/shm.cc:590 NCCL WARN Cuda failure 217 'peer access is not supported between these two devices'`, it's likely caused by a driver bug. See [this issue](https://github.com/NVIDIA/nccl/issues/1838) for more details. In that case, you can try to set `NCCL_CUMEM_HOST_ENABLE=0` to disable the feature, or upgrade your driver to the latest version. diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index e78c67522f61b..4c7a7ff019e8c 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -51,7 +51,7 @@ tail ~/.config/vllm/usage_stats.json ## Opting out -You can opt-out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: +You can opt out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file: ```bash # Any of the following methods can disable usage stats collection diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 7fc615d4c042f..340aaf54bb720 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | Model Type | Status | |-----------------------------|------------------------------------------------------------------------------------| | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> | -| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> | +| **Encoder-Decoder Models** | <nobr>🟢 Whisper only</nobr> | | **Embedding Models** | <nobr>🟢 Functional</nobr> | | **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | @@ -107,20 +107,20 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported. -Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that -these models currently require disabling prefix caching and using the FlashInfer attention backend in V1. +Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`). -Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`). 
-Please note that these models currently require disabling prefix caching, enforcing eager mode, and using the FlashInfer -attention backend in V1. +Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`). + +Please note that prefix caching is not yet supported for any of the above models. #### Encoder-Decoder Models -Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`) -are not yet supported. +Whisper is supported. Other models requiring cross-attention between separate +encoder and decoder (e.g., `BartForConditionalGeneration`, +`MllamaForConditionalGeneration`) are not supported. ### Features diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 22cb8b057dac7..a36664e470450 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): # Voxtral def run_voxtral(question: str, audio_count: int) -> ModelRequestData: from mistral_common.audio import Audio - from mistral_common.protocol.instruct.messages import ( + from mistral_common.protocol.instruct.chunk import ( AudioChunk, RawAudio, TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( UserMessage, ) from mistral_common.protocol.instruct.request import ChatCompletionRequest @@ -117,7 +119,7 @@ def run_gemma3n(question: str, audio_count: int) -> ModelRequestData: # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: - # NOTE - the setting in this example are somehat different than what is + # NOTE - the setting in this example are somewhat different from what is # optimal for granite speech, and it is generally recommended to use beam # search. Check the model README for suggested settings. # https://huggingface.co/ibm-granite/granite-speech-3.3-8b @@ -146,6 +148,36 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ) +# MiDashengLM +def run_midashenglm(question: str, audio_count: int): + model_name = "mispeech/midashenglm-7b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)] + ) + + default_system = "You are a helpful language and speech assistant." 
+ + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -352,6 +384,7 @@ model_example_map = { "voxtral": run_voxtral, "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, + "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "phi4_multimodal": run_phi4_multimodal, diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index d078c517d00e7..9e7036fea6134 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -87,6 +87,7 @@ def main(args: dict): use_tqdm=False, chat_template=chat_template, ) + print_outputs(outputs) if __name__ == "__main__": diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 6e56e24f2092c..3a95b1fdfbabc 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -143,5 +143,5 @@ outputs = llm.chat(messages, sampling_params, tools=tools) print(outputs[0].outputs[0].text.strip()) # yields -# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'The weather in Dallas, TX is 85 degrees Fahrenheit. ' # 'It is partly cloudly, with highs in the 90's.' diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dd7559451c4c6..0076d4d30ee8e 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -87,10 +87,27 @@ def parse_args(): default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--enable-dbo", + action="store_true", + help=("Enable microbatched execution"), + ) + parser.add_argument( + "--compilation-config", + type=int, + help=("Compilation optimization (O) level 0-3."), + ) parser.add_argument( "--quantization", type=str, ) + parser.add_argument( + "--disable-expert-parallel", + dest="enable_expert_parallel", + action="store_false", + help="Disable expert parallel (default: enabled).", + ) + parser.set_defaults(enable_expert_parallel=True) return parser.parse_args() @@ -103,10 +120,13 @@ def main( dp_master_port, GPUs_per_dp_rank, enforce_eager, + enable_expert_parallel, trust_remote_code, max_num_seqs, max_model_len, + compilation_config, gpu_memory_utilization, + enable_dbo, quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -156,12 +176,14 @@ def main( model=model, tensor_parallel_size=GPUs_per_dp_rank, enforce_eager=enforce_eager, - enable_expert_parallel=True, + enable_expert_parallel=enable_expert_parallel, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + enable_dbo=enable_dbo, quantization=quantization, + compilation_config=compilation_config, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
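Taken together, the options added in this change (`--enable-dbo`, `--compilation-config`, `--disable-expert-parallel`) can be exercised with an invocation along the following lines; this is a sketch only, and the script's pre-existing flags for the model and parallel sizes are omitted here and supplied as before:

```bash
# Illustrative invocation; only the newly added flags are shown. Combine them
# with the script's existing model/parallelism options as needed.
python examples/offline_inference/data_parallel.py \
    --enable-dbo \
    --compilation-config 3 \
    --disable-expert-parallel
```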
@@ -215,10 +237,13 @@ if __name__ == "__main__": dp_master_port, tp_size, args.enforce_eager, + args.enable_expert_parallel, args.trust_remote_code, args.max_num_seqs, args.max_model_len, + args.compilation_config, args.gpu_memory_utilization, + args.enable_dbo, args.quantization, ), ) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 05a361fee0717..f619fa584f801 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -30,12 +30,12 @@ def run_prefill(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_producer", kv_rank=0, kv_parallel_size=2, @@ -74,12 +74,12 @@ def run_decode(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_consumer", kv_rank=1, kv_parallel_size=2, diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py deleted file mode 100644 index df6c1eaf4a21e..0000000000000 --- a/examples/offline_inference/encoder_decoder.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART and mBART. - -This script is refactored to allow model selection via command-line arguments. -""" - -import argparse -from typing import NamedTuple, Optional - -from vllm import LLM, SamplingParams -from vllm.inputs import ( - ExplicitEncoderDecoderPrompt, - TextPrompt, - TokensPrompt, - zip_enc_dec_prompts, -) - - -class ModelRequestData(NamedTuple): - """ - Holds the configuration for a specific model, including its - HuggingFace ID and the prompts to use for the demo. - """ - - model_id: str - encoder_prompts: list - decoder_prompts: list - hf_overrides: Optional[dict] = None - - -def get_bart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/bart-large-cnn. - This uses the exact test cases from the original script. - """ - encoder_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "An encoder prompt", - ] - decoder_prompts = [ - "A decoder prompt", - "Another decoder prompt", - ] - return ModelRequestData( - model_id="facebook/bart-large-cnn", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - ) - - -def get_mbart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/mbart-large-en-ro. - This uses prompts suitable for an English-to-Romanian translation task. 
- """ - encoder_prompts = [ - "The quick brown fox jumps over the lazy dog.", - "How are you today?", - ] - decoder_prompts = ["", ""] - hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} - return ModelRequestData( - model_id="facebook/mbart-large-en-ro", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - hf_overrides=hf_overrides, - ) - - -MODEL_GETTERS = { - "bart": get_bart_config, - "mbart": get_mbart_config, -} - - -def create_all_prompt_types( - encoder_prompts_raw: list, - decoder_prompts_raw: list, - tokenizer, -) -> list: - """ - Generates a list of diverse prompt types for demonstration. - This function is generic and uses the provided raw prompts - to create various vLLM input objects. - """ - text_prompt_raw = encoder_prompts_raw[0] - text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) - tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode( - encoder_prompts_raw[2 % len(encoder_prompts_raw)] - ) - ) - - decoder_tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) - ) - single_prompt_examples = [ - text_prompt_raw, - text_prompt, - tokens_prompt, - ] - explicit_pair_examples = [ - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt_raw, - decoder_prompt=decoder_tokens_prompt, - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt, - decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=tokens_prompt, - decoder_prompt=text_prompt, - ), - ] - zipped_prompt_list = zip_enc_dec_prompts( - encoder_prompts_raw, - decoder_prompts_raw, - ) - return single_prompt_examples + explicit_pair_examples + zipped_prompt_list - - -def create_sampling_params() -> SamplingParams: - """Create a sampling params object.""" - return SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=30, - ) - - -def print_outputs(outputs: list): - """Formats and prints the generation outputs.""" - print("-" * 80) - for i, output in enumerate(outputs): - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Output {i + 1}:") - print(f"Encoder Prompt: {encoder_prompt!r}") - print(f"Decoder Prompt: {prompt!r}") - print(f"Generated Text: {generated_text!r}") - print("-" * 80) - - -def main(args): - """Main execution function.""" - model_key = args.model - if model_key not in MODEL_GETTERS: - raise ValueError( - f"Unknown model: {model_key}. " - f"Available models: {list(MODEL_GETTERS.keys())}" - ) - config_getter = MODEL_GETTERS[model_key] - model_config = config_getter() - - print(f"🚀 Running demo for model: {model_config.model_id}") - llm = LLM( - model=model_config.model_id, - dtype="float", - hf_overrides=model_config.hf_overrides, - ) - tokenizer = llm.llm_engine.get_tokenizer_group() - prompts = create_all_prompt_types( - encoder_prompts_raw=model_config.encoder_prompts, - decoder_prompts_raw=model_config.decoder_prompts, - tokenizer=tokenizer, - ) - sampling_params = create_sampling_params() - outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A flexible demo for vLLM encoder-decoder models." 
- ) - parser.add_argument( - "--model", - "-m", - type=str, - default="bart", - choices=MODEL_GETTERS.keys(), - help="The short name of the model to run.", - ) - args = parser.parse_args() - main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e7..4a1b0c40604b2 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with the explicit/implicit prompt format on enc-dec LMMs for text generation. """ +import os import time from collections.abc import Sequence from dataclasses import asdict @@ -12,7 +13,6 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset -from vllm.assets.image import ImageAsset from vllm.utils import FlexibleArgumentParser @@ -21,70 +21,9 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] -def run_florence2(): - engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", - max_num_seqs=8, - trust_remote_code=True, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # implicit prompt with task token - "prompt": "<DETAILED_CAPTION>", - "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, - }, - { # explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "Describe in detail what is shown in the image.", - "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image}, - }, - "decoder_prompt": "", - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - -def run_mllama(): - engine_args = EngineArgs( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # Implicit prompt - "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - { # Explicit prompt - "encoder_prompt": { - "prompt": "<|image|>", - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - def run_whisper(): + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, @@ -118,8 +57,6 @@ def run_whisper(): model_example_map = { - "florence2": run_florence2, - "mllama": run_mllama, "whisper": run_whisper, } @@ -133,7 +70,7 @@ def parse_args(): "--model-type", "-m", type=str, - default="mllama", + default="whisper", choices=model_example_map.keys(), help='Huggingface "model_type".', ) diff --git a/examples/offline_inference/kv_load_failure_recovery/README.md b/examples/offline_inference/kv_load_failure_recovery/README.md new file mode 100644 index 0000000000000..230a16812b25e --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/README.md @@ -0,0 +1,30 @@ +# KV Load Failure Recovery Test + +This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`. + +It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. 
The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output. + +## Files + +- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`). +- `decode_example.py` – performs the decode stage. Accepts: + - `--simulate-failure`: simulates KV load failure using a custom connector. + - `--async-load`: enables asynchronous KV loading mode. +- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request. +- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages: + 1. Normal decode (baseline). + 2. Decode with simulated sync KV load failure. + 3. Decode with simulated async KV load failure. + + Finally, it compares the output of the baseline with the recovered outputs to verify correctness. + +## How It Works + +- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector. +- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode. +- If recovery fails, the script prints a unified diff of the output mismatch and exits with error. + +## Usage + +```bash +./run.sh diff --git a/examples/offline_inference/kv_load_failure_recovery/decode_example.py b/examples/offline_inference/kv_load_failure_recovery/decode_example.py new file mode 100644 index 0000000000000..69523f56eace3 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/decode_example.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def read_prompts(): + """Read prompts from prefill_output.txt""" + prompts = [] + try: + with open("prefill_output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from prefill_output.txt") + return prompts + except FileNotFoundError: + print("Error: prefill_output.txt file not found") + exit(-1) + + +def main(): + prompts = read_prompts() + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--simulate-failure", action="store_true", help="Simulate KV load failure." 
+ ) + parser.add_argument( + "--async-load", action="store_true", help="Simulate async KV load" + ) + args = parser.parse_args() + + if args.simulate_failure: + ktc = KVTransferConfig( + kv_connector="RogueSharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage", + "async_load": args.async_load, + }, + kv_connector_module_path="rogue_shared_storage_connector", + ) + out_file = ( + "async_decode_recovered_output.txt" + if args.async_load + else "sync_decode_recovered_output.txt" + ) + else: + ktc = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage", + }, + ) + out_file = "decode_output.txt" + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=ktc, + ) + + outputs = llm.generate(prompts, sampling_params) + + sep_str = "-" * 30 + with open(out_file, "w", encoding="utf-8") as f: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}" + print(out_str) + print(sep_str) + f.write(out_str) + f.write(sep_str) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/kv_load_failure_recovery/prefill_example.py b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py new file mode 100644 index 0000000000000..047b81c82df53 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def read_prompts(): + context = "Hi " * 1000 + context2 = "Hey " * 500 + return [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", + ] + + +def main(): + prompts = read_prompts() + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ), + ) # , max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + # Write new_prompts to prefill_output.txt + with open("prefill_output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to prefill_output.txt") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py new file mode 100644 index 0000000000000..0abe7d1612610 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py @@ -0,0 +1,145 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + SharedStorageConnector, + SharedStorageConnectorMetadata, +) +from vllm.forward_context import ForwardContext +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.request import Request + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO) + + +@dataclass +class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata): + req_to_block_ids: dict[str, set[int]] = field(default_factory=dict) + + @classmethod + def from_base(cls, base: SharedStorageConnectorMetadata): + return cls(requests=base.requests) + + +class RogueSharedStorageConnector(SharedStorageConnector): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._async_load = vllm_config.kv_transfer_config.get_from_extra_config( + "async_load", False + ) + self._invalid_block_ids: set = None + self._seen_requests: set = set() + self._req_to_block_ids: dict[str, list[int]] = dict() + + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata) + index, failed_request = next( + ( + (i, x) + for i, x in enumerate(connector_metadata.requests) + if not x.is_store + ), + (None, None), + ) + if index is not None: + del connector_metadata.requests[index] + self._invalid_block_ids = set( + ( + failed_request.slot_mapping[:: self._block_size] // self._block_size + ).tolist() + ) + logger.info( + "Simulating failure to load all KV blocks for the " + "first load request. 
Total blocks: %d", + len(self._invalid_block_ids), + ) + super().bind_connector_metadata(connector_metadata) + + def clear_connector_metadata(self) -> None: + self._invalid_block_ids = None + super().clear_connector_metadata() + + def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None: + if self._async_load and forward_context.attn_metadata is None: + # Bypass sanity check in super().start_load_kv + forward_context.attn_metadata = "None" + + super().start_load_kv(forward_context, **kwargs) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if self._async_load: + meta = self._get_connector_metadata() + assert isinstance(meta, RogueSharedStorageConnectorMetadata) + if meta.req_to_block_ids: + return None, set(meta.req_to_block_ids) + + return None, None + + def get_block_ids_with_load_errors(self) -> set[int]: + return self._invalid_block_ids + + def get_num_new_matched_tokens( + self, + request: Request, + num_computed_tokens: int, + ) -> tuple[int, bool]: + if request.request_id in self._seen_requests: + return 0, False + + self._seen_requests.add(request.request_id) + + num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens) + return num_tokens, self._async_load and num_tokens > 0 + + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + super().update_state_after_alloc(request, blocks, num_external_tokens) + + if num_external_tokens > 0: + self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0] + + def build_connector_meta( + self, + scheduler_output: "SchedulerOutput", + ) -> KVConnectorMetadata: + if not self._async_load: + base = super().build_connector_meta(scheduler_output) + meta = RogueSharedStorageConnectorMetadata.from_base(base) + else: + meta = RogueSharedStorageConnectorMetadata() + if self._requests_need_load: + for req_id, request in self._requests_need_load.items(): + meta.add_request( + token_ids=request.prompt_token_ids, + block_ids=self._req_to_block_ids[req_id], + block_size=self._block_size, + is_store=False, + mm_hashes=[], + ) + # Clear state + self._requests_need_load.clear() + meta.req_to_block_ids = self._req_to_block_ids + self._req_to_block_ids = dict() + return meta diff --git a/examples/offline_inference/kv_load_failure_recovery/run.sh b/examples/offline_inference/kv_load_failure_recovery/run.sh new file mode 100755 index 0000000000000..53fe2385d46d1 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Constants +SHARED_STORAGE_DIR="local_storage" +PREFILL_OUTPUT="prefill_output.txt" +DECODE_OUTPUT="decode_output.txt" +SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt" +ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt" + +# Cleanup +rm -rf "$SHARED_STORAGE_DIR" +rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT" + +# Run inference examples +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 
decode_example.py --simulate-failure --async-load + +# Compare outputs +if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then + echo "❌ Outputs differ: sync recovery failed." + diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" + exit 1 +fi + +if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then + echo "❌ Outputs differ: async recovery failed." + diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT" + exit 1 +fi + +echo "✅ Outputs match: recovery successful." diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor/custom.py similarity index 73% rename from examples/offline_inference/logits_processor.py rename to examples/offline_inference/logits_processor/custom.py index 7ef20efa7d28c..4112a498f37ab 100644 --- a/examples/offline_inference/logits_processor.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -42,8 +42,8 @@ from vllm.config import VllmConfig from vllm.v1.sample.logits_processor import ( BatchUpdate, LogitsProcessor, - MoveDirectionality, ) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates # Hypothetical custom logits processor @@ -53,51 +53,33 @@ class DummyLogitsProcessor(LogitsProcessor): def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): - self.req_info: dict[int, SamplingParams] = {} + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: - """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests. - for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and ( - target_token := params.extra_args.get("target_token") - ): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + process_dict_updates( + self.req_info, + batch_update, + # This function returns the LP's per-request state based on the + # request details, or None if this LP does not apply to the + # request. 
+ lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: return logits # Save target values before modification - rows_list = list(self.req_info.keys()) cols = torch.tensor( - [self.req_info[i] for i in rows_list], - dtype=torch.long, - device=logits.device, + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device ) - rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py new file mode 100644 index 0000000000000..4c19bb4ce2bae --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates wrapping a request-level logits processor to be +compatible with vLLM's batch-level logits processing + +For demo purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. This logits processor can be +applied to a vector of logits associated with a single decode step for a single +request. The logits processor cannot be applied to a request which does not +pass in a `target_token` custom argument. + +The request-level dummy logits processor is wrapped to create a batch-level +logits processor, which can apply the logits processor to output logits from +all requests in the persistent batch in a given decode step. For requests which +do not provide a `target_token` argument, the corresponding row of `logits` +will not be modified. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. 
He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Any, Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py new file mode 100644 index 0000000000000..62947d122e01c --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates a special case of wrapping a request-level logits +processor, namely the case where it is necessary to utilize engine config or +environment info passed to the constructor. The subclass must override the +wrapper base class `__init__()` method to access the engine config, the device +identifier, or the flag which indicates whether pinned memory is available. + +For demo purposes, a request-level dummy logits processor is employed which +causes the same token (`target_token`) to be decoded in each step. The +request-level dummy logits processor is wrapped to create a batch-level logits +processor, which can apply the logits processor to output logits from all +requests in the persistent batch in a given decode step. + +The wrapped dummy logits processor below models a scenario where we must +disable the logits processor on non-"cuda" platforms. The wrapper base class +`__init__()` is overridden in order to check this condition and set a flag. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect that on a "cuda" device the output will look something like: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ + +which indicates that the logits processor is running. However, on a non-"cuda" +device, the first and third requests would not repeat the same token. 
+""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + self.is_cuda = device.type == "cuda" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value, and the device + must be "cuda"-type + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + if ( + not self.is_cuda + or ( + target_token := params.extra_args + and params.extra_args.get("target_token") + ) + is None + ): + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index f0c00bcaaeb11..6040683c68bcd 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -23,7 +23,7 @@ def create_test_prompts( 2 requests for base model, 4 requests for the LoRA. We define 2 different LoRA adapters (using the same model for demo purposes). Since we also set `max_loras=1`, the expectation is that the requests - with the second LoRA adapter will be ran after all requests with the + with the second LoRA adapter will be run after all requests with the first adapter have finished. """ return [ diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py deleted file mode 100644 index 7826629a36d01..0000000000000 --- a/examples/offline_inference/neuron.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=1024, - block_size=1024, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py deleted file mode 100644 index 8b1d235ff9742..0000000000000 --- a/examples/offline_inference/neuron_eagle.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with an EAGLE speculative -decoding model on neuron. To use EAGLE speculative decoding, you must use -a draft model that is specifically fine-tuned for EAGLE speculation. -Additionally, to use EAGLE with NxD Inference, the draft model must include -the LM head weights from the target model. These weights are shared between -the draft and target model. 
-""" - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "What is annapurna labs?", -] - - -def main(): - # Create a sampling params object. - sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) - - # Create an LLM. - llm = LLM( - model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_config={ - "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", - "num_speculative_tokens": 5, - "max_model_len": 2048, - }, - max_num_seqs=4, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in neuronx-distributed-inference. - max_model_len=2048, - block_size=2048, - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - }, - ) - - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}") - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py deleted file mode 100644 index c0ecfac508996..0000000000000 --- a/examples/offline_inference/neuron_int8_quantization.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -from vllm import LLM, SamplingParams - -# creates XLA hlo graphs for all the context length buckets. -os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" -# creates XLA hlo graphs for all the token gen buckets. -os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" -# Quantizes neuron model weight to int8 , -# The default config for quantization is int8 dtype. -os.environ["NEURON_QUANT_DTYPE"] = "s8" - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=2048, - block_size=2048, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - quantization="neuron_quant", - override_neuron_config={ - "cast_logits_dtype": "bfloat16", - }, - tensor_parallel_size=2, - ) - # Generate texts from the prompts. 
The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py deleted file mode 100644 index 26f7505f2fa53..0000000000000 --- a/examples/offline_inference/neuron_multimodal.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import requests -import torch -from neuronx_distributed_inference.models.mllama.utils import add_instruct -from PIL import Image - -from vllm import LLM, SamplingParams, TextPrompt - - -def get_image(image_url): - image = Image.open(requests.get(image_url, stream=True).raw) - return image - - -# Model Inputs -PROMPTS = [ - "What is in this image? Tell me a story", - "What is the recipe of mayonnaise in two sentences?", - "Describe this image", - "What is the capital of Italy famous for?", -] -IMAGES = [ - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, -] -SAMPLING_PARAMS = [ - dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16) - for _ in range(len(PROMPTS)) -] - - -def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params): - # Prepare all inputs for mllama generation, including: - # 1. put text prompt into instruct chat template - # 2. compose single text and single image prompt into Vllm's prompt class - # 3. prepare sampling parameters - input_image = single_image - has_image = torch.tensor([1]) - if isinstance(single_image, torch.Tensor) and single_image.numel() == 0: - has_image = torch.tensor([0]) - - instruct_prompt = add_instruct(prompt, has_image) - inputs = TextPrompt(prompt=instruct_prompt) - - if input_image is not None: - inputs["multi_modal_data"] = {"image": input_image} - - sampling_params = SamplingParams(**sampling_params) - return inputs, sampling_params - - -def print_outputs(outputs): - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - assert ( - len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) - ), f"""Text, image prompts and sampling parameters should have the - same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, - and {len(SAMPLING_PARAMS)}""" - - # Create an LLM. 
- llm = LLM( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_num_seqs=1, - max_model_len=4096, - block_size=4096, - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "save_sharded_checkpoint": True, - "on_device_sampling_config": { - "global_topk": 1, - "dynamic": False, - "deterministic": False, - }, - }, - ) - - batched_inputs = [] - batched_sample_params = [] - for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS): - inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params) - # test batch-size = 1 - outputs = llm.generate(inputs, sampling_params) - print_outputs(outputs) - batched_inputs.append(inputs) - batched_sample_params.append(sampling_params) - - # test batch-size = 4 - outputs = llm.generate(batched_inputs, batched_sample_params) - print_outputs(outputs) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py deleted file mode 100644 index 7fc22caee742d..0000000000000 --- a/examples/offline_inference/neuron_speculation.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with a speculative -decoding model on neuron. -""" - -import os - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, I am a language model and I can help", - "The president of the United States is", - "The capital of France is", -] - - -def config_buckets(): - """Configure context length and token gen buckets.""" - # creates XLA hlo graphs for all the context length buckets. - os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" - # creates XLA hlo graphs for all the token gen buckets. - os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" - - -def initialize_llm(): - """Create an LLM with speculative decoding.""" - return LLM( - model="openlm-research/open_llama_7b", - speculative_config={ - "model": "openlm-research/open_llama_3b", - "num_speculative_tokens": 4, - "max_model_len": 2048, - }, - max_num_seqs=4, - max_model_len=2048, - block_size=2048, - device="neuron", - tensor_parallel_size=32, - ) - - -def process_requests(llm: LLM, sampling_params: SamplingParams): - """Generate texts from prompts and print them.""" - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - """Main function that sets up the llm and processes prompts.""" - config_buckets() - llm = initialize_llm() - # Create a sampling params object. 
- sampling_params = SamplingParams(max_tokens=100, top_k=1) - process_requests(llm, sampling_params) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md new file mode 100644 index 0000000000000..79afbd9cfac47 --- /dev/null +++ b/examples/offline_inference/pooling/README.md @@ -0,0 +1,39 @@ +# Pooling models + +## Convert llm model to seq cls + +```bash +# for BAAI/bge-reranker-v2-gemma +# Caution: "Yes" and "yes" are two different tokens +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# for mxbai-rerank-v2 +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls +# for Qwen3-Reranker +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls +``` + +## Embed jina_embeddings_v3 usage + +Only text matching task is supported for now. See <gh-pr:16120> + +```bash +python examples/offline_inference/pooling/embed_jina_embeddings_v3.py +``` + +## Embed matryoshka dimensions usage + +```bash +python examples/offline_inference/pooling/embed_matryoshka_fy.py +``` + +## Named Entity Recognition (NER) usage + +```bash +python examples/offline_inference/pooling/ner.py +``` + +## Qwen3 reranker usage + +```bash +python examples/offline_inference/pooling/qwen3_reranker.py +``` diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/pooling/convert_model_to_seq_cls.py similarity index 100% rename from examples/offline_inference/convert_model_to_seq_cls.py rename to examples/offline_inference/pooling/convert_model_to_seq_cls.py diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/pooling/embed_jina_embeddings_v3.py similarity index 100% rename from examples/offline_inference/embed_jina_embeddings_v3.py rename to examples/offline_inference/pooling/embed_jina_embeddings_v3.py diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/pooling/embed_matryoshka_fy.py similarity index 100% rename from examples/offline_inference/embed_matryoshka_fy.py rename to examples/offline_inference/pooling/embed_matryoshka_fy.py diff --git a/examples/offline_inference/pooling/ner.py b/examples/offline_inference/pooling/ner.py new file mode 100644 index 0000000000000..f18742fac0d54 --- /dev/null +++ b/examples/offline_inference/pooling/ner.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="boltuix/NeuroBERT-NER", + runner="pooling", + enforce_eager=True, + trust_remote_code=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. 
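+ # The prompt below mixes person, organization, location, and date mentions,
+ # so the token-level predictions printed later should cover several entity
+ # classes (the exact tag names come from the model's id2label mapping).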
+ prompts = [ + "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + ] + + # Create an LLM. + llm = LLM(**vars(args)) + tokenizer = llm.get_tokenizer() + label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label + + # Run inference + outputs = llm.encode(prompts) + + for prompt, output in zip(prompts, outputs): + logits = output.outputs.data + predictions = logits.argmax(dim=-1) + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids) + labels = [label_map[p.item()] for p in predictions] + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/pooling/qwen3_reranker.py similarity index 100% rename from examples/offline_inference/qwen3_reranker.py rename to examples/offline_inference/pooling/qwen3_reranker.py diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index b6007b9f46301..1a5879a6d35f5 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -45,7 +45,11 @@ datamodule_config = { class PrithviMAE: def __init__(self, model): self.model = LLM( - model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True + model=model, + skip_tokenizer_init=True, + dtype="float16", + enforce_eager=True, + model_impl="terratorch", ) def run(self, input_data, location_coords): diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py new file mode 100644 index 0000000000000..418c40645f9f2 --- /dev/null +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import os + +import torch + +from vllm import LLM +from vllm.pooling_params import PoolingParams + +# This example shows how to perform an offline inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Requirement - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin + + +def main(): + torch.set_default_dtype(torch.float16) + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + llm = LLM( + model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM. 
+ # The maximum number depends on the available GPU memory + max_num_seqs=32, + io_processor_plugin="prithvi_to_tiff", + model_impl="terratorch", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + pooler_output = llm.encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + print(output) + decoded_data = base64.b64decode(output.data) + + file_path = os.path.join(os.getcwd(), "offline_prediction.tiff") + with open(file_path, "wb") as f: + f.write(decoded_data) + + print(f"Output file path: {file_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py deleted file mode 100644 index 392fba8fc5ead..0000000000000 --- a/examples/offline_inference/profiling.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import inspect -import json -import os -import sys -from argparse import RawTextHelpFormatter -from collections.abc import Generator -from dataclasses import asdict, dataclass -from typing import Any, Optional, TypeAlias - -import torch -import tqdm - -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.profiler.layerwise_profile import layerwise_profile -from vllm.utils import FlexibleArgumentParser - -BATCH_SIZE_DEFAULT = 1 -PROMPT_LEN_DEFAULT = 256 - - -@dataclass -class ProfileContext: - engine_args: EngineArgs - prompt_len: int - batch_size: int - - # The profiler can run in 2 modes, - # 1. Run profiler for user specified num_steps - num_steps: Optional[int] = None - # 2. Run profiler until all requests complete - complete_num_requests_per_step: Optional[int] = None - - save_chrome_traces_folder: Optional[str] = None - - -def get_dtype(dtype: str): - if dtype == "torch.float": - return torch.float - else: - return dtype - - -OutputLen_NumReqs_Map: TypeAlias = dict[int, int] - - -def compute_request_output_lengths( - batch_size: int, step_requests: list[int] -) -> OutputLen_NumReqs_Map: - """ - Given the number of requests, batch_size, and the number of requests - that each engine-step should process, step_requests, determine the - output lengths of the requests such that step_request is honoured. - - Example: - if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] - then return, - {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, - 32 requests should have output length 2, - 32 requests should have output length 3, - 32 requests should have output length 4, - 31 requests should have output length 5, - 1 request should have output length 6. - - Args: - batch_size (int): Number of requests submitted for profile. This is - args.batch_size. - step_requests (list[int]): step_requests[i] is the number of requests - that the ith engine step should process. - - Returns: - OutputLen_NumReqs_Map : A dictionary with output-length as keys and the - number of requests required to have that output-length as values. - """ - ol_nr: OutputLen_NumReqs_Map = {} - - # Number of request that are assigned an output-length - num_reqs_assigned: int = 0 - num_steps: int = len(step_requests) - - # sanity check. The first step (prefill-step), must process all requests. - assert step_requests[0] == batch_size - - # Begin assignments from the last step. 
- output_length: int = num_steps - for num_requests_at_step in reversed(step_requests): - if num_reqs_assigned == batch_size: - break - - assert num_reqs_assigned < batch_size - - # Remove the number of requests that have been determined - # to participate in this step and beyond. - num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned - assert num_reqs_unassigned_at_step >= 0 - - if num_reqs_unassigned_at_step > 0: - ol_nr[output_length] = num_reqs_unassigned_at_step - num_reqs_assigned += num_reqs_unassigned_at_step - - output_length -= 1 - - # sanity checks. - assert sum(ol_nr.values()) == batch_size, ( - "Number of requests in output-length assignment does not match " - f"batch-size.\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - # Check that the output-length is in [1, num-steps]. Output length must be - # at least 1 as all requests must participate in the prefill-step. - assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), ( - "Output lengths of requests should be in range " - f"[1, num-engine-steps].\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - return ol_nr - - -def determine_requests_per_step(context: ProfileContext) -> list[int]: - """ - Determine number of requests each engine step should process. - If context.num_steps is set, then all engine steps process the - same number of requests and the output list is of length - context.num_steps. - - If context.complete_num_requests_per_step is set, then each decode step - processes fewer and fewer requests until there are no requests to process. - In this case, the output list is as big as the number of steps - required to process all requests. - - Args: - context: ProfileContext object. - - Returns: - list[int]: Number of requests to process for all engine-steps. - output[i], contains the number of requests that the ith step - should process. - """ - if context.num_steps: - # All requests must run until num_engine_steps. This implies - # that their output lengths must be equal to num_engine_steps. - return [context.batch_size] * context.num_steps - - assert ( - context.complete_num_requests_per_step - and context.complete_num_requests_per_step > 0 - ), ( - f"Expected a positive complete_num_requests_per_step argument." - f"Instead got {context.complete_num_requests_per_step}" - ) - - # We start dropping after the first decode step. - step_requests = [ - context.batch_size, # prefill - context.batch_size, # decode - ] - - num_running_requests = context.batch_size - num_running_requests -= context.complete_num_requests_per_step - while num_running_requests > 0: - step_requests.append(num_running_requests) - num_running_requests -= context.complete_num_requests_per_step - - if step_requests[-1] != 1: - # have 1 request running at the last step. 
This is often - # useful - step_requests.append(1) - - return step_requests - - -def run_profile( - context: ProfileContext, csv_output: Optional[str], json_output: Optional[str] -): - print("Run profile with:") - for key, value in asdict(context).items(): - print(f" {key} = {value}") - - requests_per_step: list[int] = determine_requests_per_step(context) - - ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( - context.batch_size, requests_per_step - ) - - num_steps_to_profile: int = len(requests_per_step) - max_output_len: int = max(ol_nr.keys()) - assert max_output_len >= 1 - - # Create sampling params - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - # max_tokens is set on a per-request basis. - max_tokens=None, - ignore_eos=True, - ) - - # Create LLM - llm = LLM(**asdict(context.engine_args)) - batch_size = context.batch_size - prompt_len = context.prompt_len - - scheduler_config = llm.llm_engine.vllm_config.scheduler_config - max_model_len = llm.llm_engine.model_config.max_model_len - max_num_batched_tokens = scheduler_config.max_num_batched_tokens - max_num_seqs = scheduler_config.max_num_seqs - - if batch_size * prompt_len > max_num_batched_tokens: - print( - f"ERROR: chosen batch_size * prompt_len " - f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " - f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " - f"and therefore cannot be run in a single profile step, please " - f"choose a smaller batch size or prompt length, or increase " - f"--max-num-batched-tokens" - ) - sys.exit(-1) - if batch_size > max_num_seqs: - print( - f"ERROR: chosen batch_size ({batch_size}) is larger than " - f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " - f"single profile step, please choose a smaller batch size" - ) - sys.exit(-1) - print( - "llm.llm_engine.model_config.max_model_len: ", - llm.llm_engine.model_config.max_model_len, - ) - if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: - print( - f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " - f"{max_output_len} = {prompt_len + max_output_len}) is larger " - f"than the model's max_model_len ({max_model_len}), please " - f"choose a smaller prompt_len or max_output_len, or increase " - f"--max-model-len" - ) - sys.exit(-1) - - def add_requests(): - def get_output_len_generator() -> Generator[int, Any, Any]: - for output_len, num_reqs in ol_nr.items(): - for _ in range(num_reqs): - yield output_len - - output_len_generator = get_output_len_generator() - for i in range(batch_size): - sampling_params.max_tokens = next(output_len_generator) - assert isinstance(sampling_params.max_tokens, int) - - prompt_token_ids = torch.randint( - llm.get_tokenizer().vocab_size, size=(prompt_len,) - ).tolist() - - llm.llm_engine.add_request( - request_id=f"seq{i}", - prompt={"prompt_token_ids": prompt_token_ids}, - params=sampling_params, - ) - - def abort_requests(): - for i in range(batch_size): - llm.llm_engine.abort_request(f"seq{i}") - - # Warm up run - print("Warm up run ...") - add_requests() - llm.llm_engine.step() # Prefill - llm.llm_engine.step() # Decode - abort_requests() - - print("Profile run ...") - add_requests() - - with layerwise_profile() as prefill_prof: - llm.llm_engine.step() # First step is prefill - - decode_profs = [] - for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): - num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups() - with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof: - 
llm.llm_engine.step() - decode_profs.append(decode_prof) - - decode_results_list = [prof.results for prof in decode_profs] - prefill_results = prefill_prof.results - has_decode = len(decode_results_list) > 0 - - LINE_WIDTH = 80 - print("=" * LINE_WIDTH) - print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_model_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_model_table() - - print() - print("=" * LINE_WIDTH) - print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_summary_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_summary_table() - - if csv_output: - csv_filename_base = ( - csv_output[:-4] if csv_output.endswith(".csv") else csv_output - ) - prefill_results.export_model_stats_table_csv( - csv_filename_base + "_prefill_model_table.csv" - ) - prefill_results.export_summary_stats_table_csv( - csv_filename_base + "_prefill_summary_table.csv" - ) - - if has_decode: - decode_results_list[0].export_model_stats_table_csv( - csv_filename_base + "_decode_model_table.csv" - ) - decode_results_list[0].export_summary_stats_table_csv( - csv_filename_base + "_decode_summary_table.csv" - ) - - if json_output: - cuda_devices = [ - torch.cuda.get_device_properties(dev_idx) - for dev_idx in range(torch.cuda.device_count()) - ] - - json_dict = { - "context": { - "python_version": f"{sys.version}", - "torch_version": f"{torch.__version__}", - "torch_cuda_version": f"{torch.version.cuda}", - "cuda_devices": f"{cuda_devices}", - **asdict(context), - }, - "prefill": prefill_results.convert_stats_to_dict(), - } - - if has_decode: - for idx, dr in enumerate(decode_results_list): - json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - - # Add .json to json_output filename if it doesn't exist already. - json_output_file = ( - json_output if json_output.endswith(".json") else json_output + ".json" - ) - with open(json_output_file, "w+") as f: - json.dump(json_dict, f, indent=2) - pass - - if context.save_chrome_traces_folder is not None: - os.makedirs(context.save_chrome_traces_folder, exist_ok=True) - prefill_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + "/prefill.json" - ) - for idx, decode_prof in enumerate(decode_profs): - decode_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + f"/decode_{idx + 1}.json" - ) - print( - "Traces saved as prefill.json and decode_1.json, etc." 
- f" in folder {context.save_chrome_traces_folder}" - ) - - -def parse_args(): - parser = FlexibleArgumentParser( - description=""" -Profile a model - - example: - ``` - python examples/offline_inference/profiling.py \\ - --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ - --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ - --enforce-eager run_num_steps -n 2 - ``` - - then you can use various tools to analyze the json output - terminal ascii tables: - ``` - python tools/profiler/print_layerwise_table.py \\ - --json-trace Llama31-8b-FP8.json --phase prefill --table summary - ``` - or create matplotlib stacked bar charts: - ``` - python tools/profiler/visualize_layerwise_profile.py \\ - --json-trace Llama31-8b-FP8.json \\ - --output-directory profile_breakdown --plot-metric pct_cuda_time - ``` -""", - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--csv", - type=str, - default=None, - help="Export the results as multiple csv file. This should be the root " - "filename, will create <filename>_prefill_model_table.csv, " - "<filename>_prefill_summary_table.csv, " - "<filename>_decode_model_table.csv, and " - "<filename>_decode_summary_table.csv", - ) - parser.add_argument( - "--json", - type=str, - default=None, - help="Export the results as a json file. This should be the filename", - ) - parser.add_argument( - "--save-chrome-traces-folder", - type=str, - help="Save chrome traces for the prefill and decode " - "will save traces as prefill.json and decode_1.json, " - "etc. inside this folder", - ) - parser.add_argument( - "--prompt-len", - type=int, - default=PROMPT_LEN_DEFAULT, - help=f"Length of the random prompt to use when profiling, all batched " - f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}", - ) - parser.add_argument( - "--batch-size", - type=int, - default=BATCH_SIZE_DEFAULT, - help=f"Number of requests to run as a single batch, " - f"default={BATCH_SIZE_DEFAULT}", - ) - - subparsers = parser.add_subparsers(dest="cmd") - - run_num_steps_parser = subparsers.add_parser( - "run_num_steps", help="This variation profiles n engine.step() invocations." - ) - run_num_steps_parser.add_argument( - "-n", - "--num-steps", - type=int, - help="Number of engine steps to profile.\n" - "Setting it to 1, profiles only the prefill step.\n" - "Setting it to 2, profiles the prefill and first decode step\n" - "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" - "and so on ...", - ) - - run_to_completion_parser = subparsers.add_parser( - "run_to_completion", - help="This variation profiles all the engine.step() invocations" - "until the engine exhausts all submitted requests.", - ) - run_to_completion_parser.add_argument( - "-n", - "--complete-num-requests-per-step", - type=int, - help="Complete complete_num_requests_per_step requests every decode step." 
- "For e.g., with batch_size 128 and complete_num_requests_per_step 32," - "the profiler is run for 6 engine steps, with the steps processing, " - "128, 128, 96, 64, 32, 1 requests respectively.\n" - "Note that we tack-on a one-request step at the end as it is often " - "useful.", - ) - - EngineArgs.add_cli_args(parser) - - return parser.parse_args() - - -def main(args): - context = ProfileContext( - engine_args=EngineArgs.from_cli_args(args), - **{ - k: v - for k, v in vars(args).items() - if k in inspect.signature(ProfileContext).parameters - }, - ) - run_profile(context, csv_output=args.csv, json_output=args.json) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py index d8d61667f688b..c8d0d91ce7b5c 100644 --- a/examples/offline_inference/qwen_1m.py +++ b/examples/offline_inference/qwen_1m.py @@ -5,7 +5,6 @@ from urllib.request import urlopen from vllm import LLM, SamplingParams -os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 65621023ab6ce..360fd79b55aad 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -28,12 +28,15 @@ Learn more about Ray placement groups: https://docs.ray.io/en/latest/placement-groups.html """ +import gc import os import ray import torch +import zmq from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.multiprocessing.reductions import reduce_tensor from vllm import LLM @@ -86,20 +89,72 @@ class RayTrainingActor: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(0) + self.zmq_context = zmq.Context() + self.zmq_address_counter = 0 + self.zmq_handle = None def report_device_id(self) -> str: return self.device_uuid - def get_weight_ipc_handles(self): - from torch.multiprocessing.reductions import reduce_tensor + def get_zmq_handles(self) -> dict[str, str]: + suffix = f"{self.device_uuid}-{self.zmq_address_counter}" + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" + self.zmq_address_counter += 1 + return {self.device_uuid: self.zmq_handle} - data = {} - for name, p in self.model.named_parameters(): - # A training actor might hold only a subset of the weights and may - # need to gather weights from other actors. For demonstration - # purposes, each training actor owns the full weight set. 
- data[name] = reduce_tensor(p.detach()) - return {self.device_uuid: data} + def update_weights(self): + # align size to avoid misaligned address + align_size = 256 + + def get_size(p: torch.Tensor) -> int: + return (p.nbytes + align_size - 1) // align_size * align_size + + named_parameters: dict[str, torch.nn.Parameter] = dict( + self.model.named_parameters() + ) + max_tensor_size = max(get_size(p) for p in named_parameters.values()) + # use max_tensor_size * 2 as buffer size + buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") + s = self.zmq_context.socket(zmq.REQ) + s.bind(self.zmq_handle) + handle = reduce_tensor(buffer) + + offset = 0 + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] + named_tensors: list[dict] = [] + real_tensors: list[torch.Tensor] = [] + for name, p in named_parameters.items(): + size = get_size(p) + if offset + size > buffer.numel(): + buckets.append((named_tensors, real_tensors)) + named_tensors, real_tensors = [], [] + offset = 0 + # assume tensors are contiguous + named_tensors.append( + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} + ) + real_tensors.append(p) + offset += size + if named_tensors: + buckets.append((named_tensors, real_tensors)) + s.send_pyobj(handle) + s.recv() + for named_tensors, real_tensors in buckets: + offset = 0 + for p in real_tensors: + buffer[offset : offset + p.nbytes].data.copy_( + p.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + offset += get_size(p) + torch.cuda.synchronize() + s.send_pyobj(named_tensors) + s.recv() + s.send_pyobj(None) + s.recv() + s.close() + del buffer + gc.collect() + torch.cuda.empty_cache() # Ray manages four GPUs. @@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0] # the second inference engine. 
assert training_actor_device_ids[2:] == inference_engine_device_ids[1] -print("Gather all the IPC handles from the training actors.") -ipc_handles = {} +print("Gather all the ZMQ handles from the training actors.") +zmq_handles = {} for actor in training_actors: - ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) + +print(f"ZMQ handles: {zmq_handles}") print("Update the weights of the inference engines.") -for llm in inference_engines: - ray.get( - llm.collective_rpc.remote( - "update_weights_from_ipc_handles", args=(ipc_handles,) - ) - ) +ray.get( + [actor.update_weights.remote() for actor in training_actors] + + [ + llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) + for llm in inference_engines + ] +) + print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index d2a8419ffabcd..c0e60b9793407 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from typing import Callable, Optional, TypedDict + import torch +import zmq def stateless_init_process_group(master_address, master_port, rank, world_size, device): @@ -66,6 +70,27 @@ class WorkerExtension: return weights_updated +def rebuild_ipc( + handle: tuple[Callable, tuple], device_id: Optional[int] = None +) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # specify the start offset of this tensor in shared ipc_buffer tensor + offset: int + + class ColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. @@ -76,27 +101,62 @@ class ColocateWorkerExtension: should pass the full qualified name as `worker_extension_cls` argument. """ + def update_weights_from_ipc(self, zmq_handles: dict[str, str]): + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.REP) + socket.connect(zmq_handles[self.report_device_id()]) + buffer: Optional[torch.Tensor] = None + while True: + payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( + socket.recv_pyobj() + ) + if payload is None: + # means the update is done + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + torch.cuda.synchronize() + socket.send(b"") + break + if isinstance(payload, tuple): + # an ipc handle that vLLM can use `func, args = handle` + # and `func(*args)` to rebuild GPU tensor. 
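+ # The trainer sends this handle exactly once, up front, so both
+ # processes share one pre-allocated staging buffer; subsequent
+ # messages are lists of FlattenedTensorMetadata describing which
+ # byte ranges of the buffer hold which weights, and a final None
+ # payload marks the end of the update.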
+ buffer = rebuild_ipc(payload, self.device.index) + assert buffer.dtype == torch.uint8 + socket.send(b"") + continue + assert isinstance(payload, list) + assert buffer is not None + weights = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + self.model_runner.model.load_weights(weights=weights) + del weights + torch.cuda.synchronize() + socket.send(b"") + + socket.close() + del buffer + gc.collect() + torch.cuda.empty_cache() + def report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid - def update_weights_from_ipc_handles(self, ipc_handles): - handles = ipc_handles[self.device_uuid] - device_id = self.device.index - weights = [] - for name, handle in handles.items(): - func, args = handle - list_args = list(args) - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - self.model_runner.model.load_weights(weights=weights) - torch.cuda.synchronize() - def check_weights_changed(self): """ Check if the weights are updated to 0. diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index c4972f02d0f8e..af65b6d38e02c 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts): def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") parser.add_argument( "--method", type=str, @@ -61,6 +62,7 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -72,8 +74,7 @@ def parse_args(): return parser.parse_args() -def main(): - args = parse_args() +def main(args): args.endpoint_type = "openai-chat" model_dir = args.model_dir @@ -118,6 +119,11 @@ def main(): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } else: raise ValueError(f"unknown method: {args.method}") @@ -130,7 +136,7 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, - max_model_len=16384, + max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, ) @@ -138,7 +144,7 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - TokensPrompt(prompt_token_ids=prompt_ids), + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], sampling_params=sampling_params, ) else: @@ -194,6 +200,39 @@ def main(): acceptance_rate = acceptance_counts[i] 
/ num_drafts if num_drafts > 0 else 0 print(f"acceptance at token {i}: {acceptance_rate:.2f}") + return acceptance_length + if __name__ == "__main__": - main() + args = parse_args() + acceptance_length = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index 88d87beb4874d..6b6099f71b120 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -This file demonstrates the example usage of guided decoding -to generate structured outputs using vLLM. It shows how to apply -different guided decoding techniques such as Choice, Regex, JSON schema, -and Grammar to produce structured and formatted results -based on specific prompts. +This file demonstrates the example usage of structured outputs +in vLLM. It shows how to apply different constraints such as choice, +regex, json schema, and grammar to produce structured and formatted +results based on specific prompts. """ from enum import Enum @@ -13,19 +12,23 @@ from enum import Enum from pydantic import BaseModel from vllm import LLM, SamplingParams -from vllm.sampling_params import GuidedDecodingParams +from vllm.sampling_params import StructuredOutputsParams MAX_TOKENS = 50 -# Guided decoding by Choice (list of possible options) -guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) -sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) +# Structured outputs by Choice (list of possible options) +structured_outputs_params_choice = StructuredOutputsParams( + choice=["Positive", "Negative"] +) +sampling_params_choice = SamplingParams( + structured_outputs=structured_outputs_params_choice +) prompt_choice = "Classify this sentiment: vLLM is wonderful!" 
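+# With the choice constraint above, the generated answer should be exactly one
+# of the listed options ("Positive" or "Negative").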
-# Guided decoding by Regex -guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") +# Structured outputs by Regex +structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, + structured_outputs=structured_outputs_params_regex, stop=["\n"], max_tokens=MAX_TOKENS, ) @@ -36,7 +39,7 @@ prompt_regex = ( ) -# Guided decoding by JSON using Pydantic schema +# Structured outputs by JSON using Pydantic schema class CarType(str, Enum): sedan = "sedan" suv = "SUV" @@ -51,17 +54,16 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() -guided_decoding_params_json = GuidedDecodingParams(json=json_schema) +structured_outputs_params_json = StructuredOutputsParams(json=json_schema) sampling_params_json = SamplingParams( - guided_decoding=guided_decoding_params_json, - max_tokens=MAX_TOKENS, + structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS ) prompt_json = ( - "Generate a JSON with the brand, model and car_type of" + "Generate a JSON with the brand, model and car_type of " "the most iconic car from the 90's" ) -# Guided decoding by Grammar +# Structured outputs by Grammar simplified_sql_grammar = """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -70,13 +72,15 @@ table ::= "table_1 " | "table_2 " condition ::= column "= " number number ::= "1 " | "2 " """ -guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) +structured_outputs_params_grammar = StructuredOutputsParams( + grammar=simplified_sql_grammar +) sampling_params_grammar = SamplingParams( - guided_decoding=guided_decoding_params_grammar, + structured_outputs=structured_outputs_params_grammar, max_tokens=MAX_TOKENS, ) prompt_grammar = ( - "Generate an SQL query to show the 'username' and 'email'from the 'users' table." + "Generate an SQL query to show the 'username' and 'email' from the 'users' table." ) @@ -93,16 +97,16 @@ def main(): llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) choice_output = generate_output(prompt_choice, sampling_params_choice, llm) - format_output("Guided decoding by Choice", choice_output) + format_output("Structured outputs by Choice", choice_output) regex_output = generate_output(prompt_regex, sampling_params_regex, llm) - format_output("Guided decoding by Regex", regex_output) + format_output("Structured outputs by Regex", regex_output) json_output = generate_output(prompt_json, sampling_params_json, llm) - format_output("Guided decoding by JSON", json_output) + format_output("Structured outputs by JSON", json_output) grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm) - format_output("Guided decoding by Grammar", grammar_output) + format_output("Structured outputs by Grammar", grammar_output) if __name__ == "__main__": diff --git a/examples/offline_inference/torchrun_dp_example.py b/examples/offline_inference/torchrun_dp_example.py new file mode 100644 index 0000000000000..295d1637528cd --- /dev/null +++ b/examples/offline_inference/torchrun_dp_example.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +experimental support for data-parallel inference with torchrun +Note the data load balancing and distribution is done out of the vllm engine, +no internal lb supported in external_launcher mode. 
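+That is, each torchrun rank creates its own engine instance and generates
+completions only for the prompts assigned to it (see the dp_rank-based
+slicing of `prompts` below).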
+ +To run this example: +```bash +$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py +``` +""" + +from vllm import LLM, SamplingParams + +# Create prompts, the same across all ranks +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create sampling parameters, the same across all ranks +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Use `distributed_executor_backend="external_launcher"` so that +# this llm engine/instance only creates one worker. +# it is important to set an explicit seed to make sure that +# all ranks have the same random seed, so that sampling can be +# deterministic across ranks. +llm = LLM( + model="microsoft/Phi-mini-MoE-instruct", + tensor_parallel_size=1, + data_parallel_size=2, + pipeline_parallel_size=1, + enable_expert_parallel=False, + distributed_executor_backend="external_launcher", + max_model_len=4096, + gpu_memory_utilization=0.6, + seed=1, +) + +dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank +dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size + +prompts = [ + f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank +] + +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n" + ) + +""" +Further tips: + +1. to communicate control messages across all ranks, use the cpu group, +a PyTorch ProcessGroup with GLOO backend. + +```python +from vllm.distributed.parallel_state import get_world_group +cpu_group = get_world_group().cpu_group +torch_rank = dist.get_rank(group=cpu_group) +if torch_rank == 0: + # do something for rank 0, e.g. saving the results to disk. +``` + +2. to communicate data across all ranks, use the model's device group, +a PyTorch ProcessGroup with NCCL backend. +```python +from vllm.distributed.parallel_state import get_world_group +device_group = get_world_group().device_group +``` + +3. to access the model directly in every rank, use the following code: +```python +llm.llm_engine.model_executor.driver_worker.worker.model_runner.model +``` +""" diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 9776f4fe322b9..0093b63b0b1f3 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -42,7 +42,7 @@ def main(): llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct" # Set `enforce_eager=True` to avoid ahead-of-time compilation. - # In real workloads, `enforace_eager` should be `False`. + # In real workloads, `enforce_eager` should be `False`. 
llm = LLM(**llm_args) outputs = llm.generate(prompts, sampling_params) print("-" * 50) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 8d97ba2668263..9fd9da3b0855e 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -173,21 +190,30 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) -# Florence2 -def run_florence2(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", + model=model_name, max_model_len=4096, - max_num_seqs=2, - trust_remote_code=True, - dtype="bfloat16", + max_num_seqs=5, limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, ) - prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions] + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: <think></think>" + ) + for question in questions + ] return ModelRequestData( engine_args=engine_args, @@ -550,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: # Intern-S1 def run_interns1(questions: list[str], modality: str) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, @@ -652,6 +678,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# Keye-VL-1.5 +def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1.5-8B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + trust_remote_code=True, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Kimi-VL def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -946,44 +1003,6 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: ) -# LLama 3.2 -def run_mllama(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" - - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - - # Note: The default setting 
of max_num_seqs (256) and - # max_model_len (131072) for this model may cause OOM. - # You may lower either to run this example on lower-end GPUs. - - # The configuration below has been confirmed to launch on a single L40 GPU. - engine_args = EngineArgs( - model=model_name, - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={modality: 1}, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - messages = [ - [ - { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": question}], - } - ] - for question in questions - ] - prompts = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False - ) - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - # Molmo def run_molmo(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1121,14 +1140,10 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: elif modality == "video": placeholder = "<video>" - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - messages = [ - [{"role": "user", "content": f"{placeholder}\n{question}"}] + prompts = [ + f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n" for question in questions ] - prompts = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) return ModelRequestData( engine_args=engine_args, @@ -1435,6 +1450,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# Qwen3-VL-Dense +def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-4B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Qwen3-VL-MOE +def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # R-4B def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1600,9 +1689,10 @@ model_example_map = { "aya_vision": run_aya_vision, "blip-2": run_blip2, "chameleon": run_chameleon, + "dots_ocr": run_dots_ocr, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, - "florence2": run_florence2, + "ernie45_vl": 
run_ernie45_vl, "fuyu": run_fuyu, "gemma3": run_gemma3, "gemma3n": run_gemma3n, @@ -1616,6 +1706,7 @@ model_example_map = { "interns1": run_interns1, "internvl_chat": run_internvl, "keye_vl": run_keye_vl, + "keye_vl1_5": run_keye_vl1_5, "kimi_vl": run_kimi_vl, "llama4": run_llama4, "llava": run_llava, @@ -1627,7 +1718,6 @@ model_example_map = { "minicpmv": run_minicpmv, "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, - "mllama": run_mllama, "molmo": run_molmo, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, @@ -1643,6 +1733,8 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "qwen3_vl": run_qwen3_vl, + "qwen3_vl_moe": run_qwen3_vl_moe, "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, @@ -1652,6 +1744,15 @@ model_example_map = { } +MODELS_NEED_VIDEO_METADATA = [ + "glm4_1v", + "glm4_5v", + "glm4_5v_fp8", + "qwen3_vl", + "qwen3_vl_moe", +] + + def get_multi_modal_input(args): """ return { @@ -1676,12 +1777,13 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question + needs_metadata = args.model_type in MODELS_NEED_VIDEO_METADATA video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, + "data": ([(video, metadata)] if needs_metadata else video), "questions": vid_questions, } @@ -1700,6 +1802,7 @@ def apply_image_repeat( probs = [1.0 - image_repeat_prob, image_repeat_prob] inputs = [] + inputs_with_empty_media = [] cur_image = data for i in range(num_prompts): if image_repeat_prob is not None: @@ -1710,14 +1813,25 @@ def apply_image_repeat( new_val = (i // 256 // 256, i // 256, i % 256) cur_image.putpixel((0, 0), new_val) + uuid = "uuid_{}".format(i) + inputs.append( { "prompt": prompts[i % len(prompts)], "multi_modal_data": {modality: cur_image}, + "multi_modal_uuids": {modality: uuid}, } ) - return inputs + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) + + return inputs, inputs_with_empty_media @contextmanager @@ -1796,6 +1910,13 @@ def parse_args(): help="If True, then use different prompt (with the same multi-modal " "data) for each request.", ) + + parser.add_argument( + "--verify-mm-cache-hit-with-uuids", + action="store_true", + help="If True, will send all requests in a second batch with empty mm " + "data to verify cache hits with UUIDs.", + ) return parser.parse_args() @@ -1839,26 +1960,48 @@ def main(args): assert args.num_prompts > 0 if args.num_prompts == 1: # Single inference + uuid = "uuid_0" inputs = { "prompt": prompts[0], "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + inputs_with_empty_media = { + "prompt": prompts[0], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, } else: # Batch inference if args.image_repeat_prob is not None: # Repeat images with specified probability of "image_repeat_prob" - inputs = apply_image_repeat( - args.image_repeat_prob, args.num_prompts, data, prompts, modality + inputs, inputs_with_empty_media = apply_image_repeat( + args.image_repeat_prob, + args.num_prompts, + data, + prompts, + modality, ) else: # Use the same image for all prompts - inputs = [ - { - "prompt": prompts[i % 
len(prompts)], - "multi_modal_data": {modality: data}, - } - for i in range(args.num_prompts) - ] + inputs = [] + inputs_with_empty_media = [] + for i in range(args.num_prompts): + uuid = "uuid_{}".format(i) + inputs.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + ) + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) # Add LoRA request if applicable lora_request = ( @@ -1878,6 +2021,26 @@ def main(args): print(generated_text) print("-" * 50) + if args.verify_mm_cache_hit_with_uuids: + try: + # Verify cache hits with UUIDs + print( + "Sending a second batch of requests with empty media" + " and matching UUIDs." + ) + outputs = llm.generate( + inputs_with_empty_media, + sampling_params=sampling_params, + lora_request=lora_request, + ) + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + print("-" * 50) + except Exception as e: + print(f"Failed to verify cache hits with UUIDs. Error: {e}") + if __name__ == "__main__": args = parse_args() diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index d9242efa85470..c37d40a23ac20 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, @@ -371,6 +371,115 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-8B-Preview" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + +def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1_5-8B" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=32768, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + 
image_data=image_data, + ) + + +def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "moonshotai/Kimi-VL-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=4, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" @@ -505,78 +614,6 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa ) -def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "Kwai-Keye/Keye-VL-8B-Preview" - - engine_args = EngineArgs( - model=model_name, - trust_remote_code=True, - max_model_len=8192, - max_num_seqs=5, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - - placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [ - { - "role": "user", - "content": [ - *placeholders, - {"type": "text", "text": question}, - ], - }, - ] - - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) - - prompt = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - image_data = [fetch_image(url) for url in image_urls] - - return ModelRequestData( - engine_args=engine_args, - prompt=prompt, - image_data=image_data, - ) - - -def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "moonshotai/Kimi-VL-A3B-Instruct" - - engine_args = EngineArgs( - model=model_name, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=4, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - - placeholders = [{"type": "image", "image": url} for url in image_urls] - messages = [ - { - "role": "user", - "content": [ - *placeholders, - {"type": "text", "text": question}, - ], - } - ] - - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) - - prompt = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - return ModelRequestData( - engine_args=engine_args, - prompt=prompt, - image_data=[fetch_image(url) for url in image_urls], - ) - - def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -600,26 +637,6 @@ def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: ) -def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - - # The configuration below has been confirmed to launch on a single L40 GPU. - engine_args = EngineArgs( - model=model_name, - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - - img_prompt = "Given the first image <|image|> and the second image<|image|>" - prompt = f"<|begin_of_text|>{img_prompt}, {question}?" 
- return ModelRequestData( - engine_args=engine_args, - prompt=prompt, - image_data=[fetch_image(url) for url in image_urls], - ) - - def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "nvidia/NVLM-D-72B" @@ -696,11 +713,9 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: placeholders = "\n".join( f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) ) - messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + prompt = ( + f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n" + "<|im_start|>assistant\n" ) return ModelRequestData( @@ -1209,13 +1224,13 @@ model_example_map = { "interns1": load_interns1, "internvl_chat": load_internvl, "keye_vl": load_keye_vl, + "keye_vl1_5": load_keye_vl1_5, "kimi_vl": load_kimi_vl, "llama4": load_llama4, "llava": load_llava, "llava-next": load_llava_next, "llava-onevision": load_llava_onevision, "mistral3": load_mistral3, - "mllama": load_mllama, "NVLM_D": load_nvlm_d, "ovis": load_ovis, "ovis2_5": load_ovis2_5, diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py index 0cc0c1e708b12..33ffb59014d8f 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -10,6 +10,7 @@ on HuggingFace model repository. from argparse import Namespace from dataclasses import asdict +from pathlib import Path from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args from PIL.Image import Image @@ -19,6 +20,9 @@ from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser +ROOT_DIR = Path(__file__).parent.parent.parent +EXAMPLES_DIR = ROOT_DIR / "examples" + class TextQuery(TypedDict): modality: Literal["text"] @@ -54,6 +58,30 @@ class ModelRequestData(NamedTuple): documents: Optional[ScoreMultiModalParam] = None +def run_clip(query: Query) -> ModelRequestData: + if query["modality"] == "text": + prompt = query["text"] + image = None + elif query["modality"] == "image": + prompt = "" # For image input, make sure that the prompt text is empty + image = query["image"] + else: + modality = query["modality"] + raise ValueError(f"Unsupported query modality: '{modality}'") + + engine_args = EngineArgs( + model="openai/clip-vit-base-patch32", + runner="pooling", + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def run_e5_v(query: Query) -> ModelRequestData: llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 @@ -82,23 +110,27 @@ def run_e5_v(query: Query) -> ModelRequestData: ) -def run_vlm2vec(query: Query) -> ModelRequestData: +def _get_vlm2vec_prompt_image(query: Query, image_token: str): if query["modality"] == "text": text = query["text"] - prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501 + prompt = f"Find me an everyday image that matches the given caption: {text}" image = None elif query["modality"] == "image": - prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." 
# noqa: E501 + prompt = f"{image_token} Find a day-to-day image that looks similar to the provided image." # noqa: E501 image = query["image"] elif query["modality"] == "text+image": text = query["text"] - prompt = ( - f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 - ) + prompt = f"{image_token} Represent the given image with the following question: {text}" # noqa: E501 image = query["image"] else: modality = query["modality"] - raise ValueError(f"Unsupported query modality: '{modality}'") + raise ValueError(f"Unsupported query modality: {modality!r}") + + return prompt, image + + +def run_vlm2vec_phi3v(query: Query) -> ModelRequestData: + prompt, image = _get_vlm2vec_prompt_image(query, "<|image_1|>") engine_args = EngineArgs( model="TIGER-Lab/VLM2Vec-Full", @@ -116,6 +148,69 @@ def run_vlm2vec(query: Query) -> ModelRequestData: ) +def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: + # vLLM does not support LoRA adapters on multi-modal encoder, + # so we merge the weights first + from huggingface_hub.constants import HF_HUB_CACHE + from peft import PeftConfig, PeftModel + from transformers import AutoModelForImageTextToText, AutoProcessor + + from vllm.entrypoints.chat_utils import load_chat_template + + model_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B" + + base_model = AutoModelForImageTextToText.from_pretrained(model_id) + lora_model = PeftModel.from_pretrained( + base_model, + model_id, + config=PeftConfig.from_pretrained(model_id), + ) + model = lora_model.merge_and_unload().to(dtype=base_model.dtype) + model._hf_peft_config_loaded = False # Needed to save the merged model + + processor = AutoProcessor.from_pretrained( + model_id, + # `min_pixels` and `max_pixels` are deprecated for + # transformers `preprocessor_config.json` + size={"shortest_edge": 3136, "longest_edge": 12845056}, + ) + processor.chat_template = load_chat_template( + # The original chat template is not correct + EXAMPLES_DIR / "template_vlm2vec_qwen2vl.jinja", + ) + + merged_path = str( + Path(HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--") + "-vllm") + ) + print(f"Saving merged model to {merged_path}...") + print( + "NOTE: This directory is not tracked by `huggingface_hub` " + "so you have to delete this manually if you don't want it anymore." 
+ ) + model.save_pretrained(merged_path) + processor.save_pretrained(merged_path) + print("Done!") + + prompt, image = _get_vlm2vec_prompt_image(query, "<|image_pad|>") + + engine_args = EngineArgs( + model=merged_path, + runner="pooling", + max_model_len=4096, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 12845056, + }, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def run_jinavl_reranker(query: Query) -> ModelRequestData: if query["modality"] != "text+images": raise ValueError(f"Unsupported query modality: '{query['modality']}'") @@ -231,8 +326,10 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): model_example_map = { + "clip": run_clip, "e5_v": run_e5_v, - "vlm2vec": run_vlm2vec, + "vlm2vec_phi3v": run_vlm2vec_phi3v, + "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, "jinavl_reranker": run_jinavl_reranker, } @@ -246,7 +343,7 @@ def parse_args(): "--model-name", "-m", type=str, - default="vlm2vec", + default="vlm2vec_phi3v", choices=model_example_map.keys(), help="The name of the embedding model.", ) diff --git a/examples/online_serving/dashboards/README.md b/examples/online_serving/dashboards/README.md new file mode 100644 index 0000000000000..30cea6b24d57e --- /dev/null +++ b/examples/online_serving/dashboards/README.md @@ -0,0 +1,87 @@ +# Monitoring Dashboards + +This directory contains monitoring dashboard configurations for vLLM, providing +comprehensive observability for your vLLM deployments. + +## Dashboard Platforms + +We provide dashboards for two popular observability platforms: + +- **[Grafana](https://grafana.com)** +- **[Perses](https://perses.dev)** + +## Dashboard Format Approach + +All dashboards are provided in **native formats** that work across different +deployment methods: + +### Grafana (JSON) + +- ✅ Works with any Grafana instance (cloud, self-hosted, Docker) +- ✅ Direct import via Grafana UI or API +- ✅ Can be wrapped in Kubernetes operators when needed +- ✅ No vendor lock-in or deployment dependencies + +### Perses (YAML) + +- ✅ Works with standalone Perses instances +- ✅ Compatible with Perses API and CLI +- ✅ Supports Dashboard-as-Code workflows +- ✅ Can be wrapped in Kubernetes operators when needed + +## Dashboard Contents + +Both platforms provide equivalent monitoring capabilities: + +| Dashboard | Description | +|-----------|-------------| +| **Performance Statistics** | Tracks latency, throughput, and performance metrics | +| **Query Statistics** | Monitors request volume, query performance, and KPIs | + +## Quick Start + +First, navigate to this example's directory: + +```bash +cd examples/online_serving/dashboards +``` + +### Grafana + +Import the JSON directly into the Grafana UI, or use the API: + +```bash +curl -X POST http://grafana/api/dashboards/db \ + -H "Content-Type: application/json" \ + -d @grafana/performance_statistics.json +``` + +### Perses + +Import via the Perses CLI: + +```bash +percli apply -f perses/performance_statistics.yaml +``` + +## Requirements + +- **Prometheus** metrics from your vLLM deployment +- **Data source** configured in your monitoring platform +- **vLLM metrics** enabled and accessible + +## Platform-Specific Documentation + +For detailed deployment instructions and platform-specific options, see: + +- **[Grafana Documentation](./grafana)** - JSON dashboards, operator usage, manual import +- **[Perses Documentation](./perses)** - YAML specs, CLI usage, operator wrapping + +## Contributing + 
+When adding new dashboards, please: + +1. Provide native formats (JSON for Grafana, YAML specs for Perses) +2. Update platform-specific README files +3. Ensure dashboards work across deployment methods +4. Test with the latest platform versions diff --git a/examples/online_serving/dashboards/grafana/README.md b/examples/online_serving/dashboards/grafana/README.md new file mode 100644 index 0000000000000..abe5f8cf23677 --- /dev/null +++ b/examples/online_serving/dashboards/grafana/README.md @@ -0,0 +1,59 @@ +# Grafana Dashboards for vLLM Monitoring + +This directory contains Grafana dashboard configurations (as JSON) designed to monitor +vLLM performance and metrics. + +## Requirements + +- Grafana 8.0+ +- Prometheus data source configured in Grafana +- vLLM deployment with Prometheus metrics enabled + +## Dashboard Descriptions + +- **performance_statistics.json**: Tracks performance metrics including latency and + throughput for your vLLM service. +- **query_statistics.json**: Tracks query performance, request volume, and key + performance indicators for your vLLM service. + +## Deployment Options + +### Manual Import (Recommended) + +The easiest way to use these dashboards is to manually import the JSON configurations +directly into your Grafana instance: + +1. Navigate to your Grafana instance +2. Click the '+' icon in the sidebar +3. Select 'Import' +4. Copy and paste the JSON content from the dashboard files, or upload the JSON files + directly + +### Grafana Operator + +If you're using the [Grafana Operator](https://github.com/grafana-operator/grafana-operator) +in Kubernetes, you can wrap these JSON configurations in a `GrafanaDashboard` custom +resource: + +```yaml +# Note: Adjust the instanceSelector to match your Grafana instance's labels +# You can check with: kubectl get grafana -o yaml +apiVersion: grafana.integreatly.org/v1beta1 +kind: GrafanaDashboard +metadata: + name: vllm-performance-dashboard +spec: + instanceSelector: + matchLabels: + dashboards: grafana # Adjust to match your Grafana instance labels + folder: "vLLM Monitoring" + json: | + # Replace this comment with the complete JSON content from + # performance_statistics.json - The JSON should start with { and end with } +``` + +Then apply to your cluster: + +```bash +kubectl apply -f your-dashboard.yaml -n <namespace> +``` diff --git a/examples/online_serving/dashboards/grafana/performance_statistics.json b/examples/online_serving/dashboards/grafana/performance_statistics.json new file mode 100644 index 0000000000000..390d3dd6d2594 --- /dev/null +++ b/examples/online_serving/dashboards/grafana/performance_statistics.json @@ -0,0 +1,1405 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 26, + "links": [], + "panels": [ + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 9, + "panels": [], + "title": "Graph: E2E latency over time ", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "End-to-End latency of requests, showing average and key percentiles over time.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + 
"axisColorMode": "text", + "axisLabel": "Latency", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 18, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": true, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 1 + }, + "id": 1, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:e2e_request_latency_seconds_sum[$__interval]) / rate(vllm:e2e_request_latency_seconds_count[$__interval])", + "format": "table", + "legendFormat": "E2E Latency", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P99", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 1 + }, + "id": 5, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P90", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 1 + }, + "id": 4, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true 
+ }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P90)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "Average", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 5 + }, + "id": 2, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:e2e_request_latency_seconds_sum[$__range])) / sum(increase(vllm:e2e_request_latency_seconds_count[$__range])))", + "legendFormat": "Average E2E Latency", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P50", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 5 + }, + "id": 3, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P50)", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 9 + }, + "id": 8, + "panels": [], + "title": "Graph: TTFT(Time To First Token) over time ", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Time to first token (TTFT) latency, showing average and key percentiles over time.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Latency", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 18, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + 
"insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 10 + }, + "id": 10, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(vllm:time_to_first_token_seconds_sum[$__interval]) / rate(vllm:time_to_first_token_seconds_count[$__interval])", + "format": "table", + "legendFormat": "TTFT (Avg)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P99", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 10 + }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))", + "legendFormat": "TTFT (p99)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P90", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 10 + }, + "id": 13, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))", + "legendFormat": "TTFT (p90)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (P90)", + "type": "stat" + }, + { + "datasource": { + "type": 
"prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "Average", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 14 + }, + "id": 11, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:time_to_first_token_seconds_sum[$__range])) / sum(increase(vllm:time_to_first_token_seconds_count[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "displayName": "P50", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 14 + }, + "id": 12, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orietitletChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (P50)", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 18 + }, + "id": 7, + "panels": [], + "title": "ITL (Iteration Latency / Time Per Output Token) over time.", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Iteration latency, or average time taken to generate a single output token, with percentiles.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Latency", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 17, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + 
{ + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 19 + }, + "id": 15, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(vllm:time_per_output_token_seconds_sum[$__interval]) / rate(vllm:time_per_output_token_seconds_count[$__interval])", + "legendFormat": "ITL (Avg)", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))", + "hide": false, + "instant": false, + "legendFormat": "ITL (p50)", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))", + "hide": false, + "instant": false, + "legendFormat": "ITL (p90)", + "range": true, + "refId": "C" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))", + "hide": false, + "instant": false, + "legendFormat": "ITL (p99)", + "range": true, + "refId": "D" + } + ], + "title": "ITL (Time Per Output Token) Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of Iteration Latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 19 + }, + "id": 18, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P90)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of Iteration Latency over the selected time range.\n\n", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 19 + }, + "id": 19, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": 
"standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average Iteration Latency (time per output token) over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 23 + }, + "id": 16, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:time_per_output_token_seconds_sum[$__range])) / sum(increase(vllm:time_per_output_token_seconds_count[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of Iteration Latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 23 + }, + "id": 17, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P50)", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 27 + }, + "id": 6, + "panels": [], + "title": "TPS (Tokens Per Second)", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Rate of tokens processed per second, including prompt and generation phases.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + 
"hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "tokens/sec (tps)" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 28 + }, + "id": 20, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(vllm:generation_tokens_total[$__interval])", + "legendFormat": "Generation TPS", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:prompt_tokens_total[$__interval])", + "hide": false, + "instant": false, + "legendFormat": "Prompt TPS", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:iteration_tokens_total_count[$__interval])", + "hide": false, + "instant": false, + "legendFormat": "Overall Iteration TPS", + "range": true, + "refId": "C" + } + ], + "title": "TPS (Tokens Per Second) Over Time", + "type": "timeseries" + } + ], + "preload": false, + "schemaVersion": 40, + "tags": [], + "templating": { + "list": [ + { + "name": "DS_PROMETHEUS", + "type": "datasource", + "label": "datasource", + "query": "prometheus", + "refresh": 1, + "current": { + "text": "Prometheus", + "value": "prometheus" + } + }, + { + "current": { + "text": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "value": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)" + }, + "label": "Aggregation", + "name": "agg_method", + "options": [ + { + "selected": true, + "text": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "value": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)" + } + ], + "query": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "type": "custom" + }, + { + "current": { + "text": [ + "granite-33-2b-instruct" + ], + "value": [ + "granite-33-2b-instruct" + ] + }, + "definition": "label_values(vllm:generation_tokens_total,model_name)", + "includeAll": true, + "label": "Deployment_ID", + "multi": true, + "name": "Deployment_id", + "options": [], + "query": { + "qryType": 1, + "query": "label_values(vllm:generation_tokens_total,model_name)", + "refId": "PrometheusVariableQueryEditor-VariableQuery" + }, + "refresh": 1, + "regex": "", + "type": "query" + } + ] + }, + "time": { + "from": "now-12h", + "to": "now" + }, + "timezone": "browser", + "uid": "performance-statistics", + "title": "Performance Statistics", + "version": 40, + "weekStart": "" +} \ No newline at end of file diff --git a/examples/online_serving/dashboards/grafana/query_statistics.json b/examples/online_serving/dashboards/grafana/query_statistics.json new file mode 100644 index 0000000000000..880f6c5d71764 --- /dev/null +++ b/examples/online_serving/dashboards/grafana/query_statistics.json @@ 
-0,0 +1,760 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "High-level overview of VLLM model deployment behavior and key performance indicators. Designed for Data Scientists and Product Managers to monitor request volume, token throughput, and latency", + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 47, + "links": [], + "panels": [ + { + "collapsed": true, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 }, + "id": 20, + "panels": [], + "title": "Request Over Time", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "req/s" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 1 }, + "id": 1, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "editorMode": "code", + "expr": "sum by (model_name) (\n rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval])\n)", + "interval": "1", + "legendFormat": "{{model_name}}", + "range": true, + "refId": "A" + } + ], + "title": "Successful Requests Over Time", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "req/s" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 1 }, + "id": 2, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["mean"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Requests Avg Rate", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calcultaions": { 
"index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 1 }, + "id": 3, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "p50 Latency", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 4 }, + "id": 4, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "p90 Latency", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 4 }, + "id": 5, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "p99 Latency", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 7 }, + "id": 19, + "panels": [], + "title": "Size Distribution", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": 
"palette-classic" }, + "custom": { + "fillOpacity": 80, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "lineWidth": 1, + "stacking": { "group": "A", "mode": "none" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 8 }, + "id": 6, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}} le={{le}}", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size Distribution", + "type": "histogram" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "calculation ": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 8 }, + "id": 9, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p90", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calcultion": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 8 }, + "id": 8, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p50", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calcultaion": { "index": 0, "text": "mean" } }, "type": "value" } + ], + 
"thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 11 }, + "id": 7, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))\n/\nsum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size Avg", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 11 }, + "id": 10, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p99", + "type": "stat" + }, + { + "collapsed": true, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 14 }, + "id": 18, + "panels": [], + "title": "Input Token Over Time", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 15 }, + "id": 11, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (model_name) 
(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}}", + "range": true, + "refId": "A" + } + ], + "title": "Input Tokens Over Time", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "mean" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 15 }, + "id": 12, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Tokens/Sec Avg", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 21 }, + "id": 17, + "panels": [], + "title": "Output Token Over Time", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 22 }, + "id": 13, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}}", + "range": true, + "refId": "A" + } + ], + "title": "Output Tokens Over Time", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "mean" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 22 }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", 
+ "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Output Tokens/Sec Avg", + "type": "stat" + } + ], + "preload": false, + "schemaVersion": 40, + "tags": [], + "templating": { + "list": [ + { + "current": { "text": "Prometheus", "value": "4184fc20-68a7-483a-8d9b-7caa59c680dd" }, + "label": "datasource", + "name": "DS_PROMETHEUS", + "options": [], + "query": "prometheus", + "refresh": 1, + "type": "datasource" + }, + { + "current": { "text": ["All"], "value": ["$__all"] }, + "definition": "label_values(vllm:request_success_total,model_name)", + "includeAll": true, + "label": "Deployment_ID", + "multi": true, + "name": "Deployment_id", + "options": [], + "query": { + "qryType": 1, + "query": "label_values(vllm:request_success_total,model_name)", + "refId": "PrometheusVariableQueryEditor-VariableQuery" + }, + "refresh": 1, + "regex": "", + "sort": 1, + "type": "query" + }, + { + "current": { "text": "All hours", "value": "All hours" }, + "hide": 2, + "label": "Rush Hours Only", + "name": "rush_hours", + "options": [ + { "selected": true, "text": "false", "value": "All hours" }, + { "selected": false, "text": "true", "value": "Rush hours" } + ], + "query": "false : All hours, true : Rush hours", + "type": "custom" + }, + { + "current": { "text": "All", "value": "All" }, + "hide": 2, + "label": "Rush Hours Type", + "name": "rush_hours_type", + "options": [ + { "selected": true, "text": "^All__.*$", "value": "All" }, + { "selected": false, "text": "^Static__.*$", "value": "Static" }, + { "selected": false, "text": "^Dynamic__.*$", "value": "Dynamic" } + ], + "query": "^All__.*$ : All, ^Static__.*$ : Static, ^Dynamic__.*$ : Dynamic", + "type": "custom" + }, + { + "current": { "text": "", "value": "" }, + "hide": 2, + "name": "query0", + "options": [], + "query": "", + "refresh": 1, + "regex": "", + "type": "query" + } + ] + }, + "time": { "from": "now-12h", "to": "now" }, + "timepicker": {}, + "timezone": "browser", + "title": "Query Statistics_New4", + "uid": "query-statistics4", + "version": 2, + "weekStart": "" +} + diff --git a/examples/online_serving/dashboards/perses/README.md b/examples/online_serving/dashboards/perses/README.md new file mode 100644 index 0000000000000..780a6ef13a3e8 --- /dev/null +++ b/examples/online_serving/dashboards/perses/README.md @@ -0,0 +1,48 @@ +# Perses Dashboards for vLLM Monitoring + +This directory contains Perses dashboard configurations designed to monitor vLLM +performance and metrics. 
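+
+These dashboards read vLLM's built-in Prometheus metrics (for example
+`vllm:num_requests_running`, `vllm:e2e_request_latency_seconds_*` and
+`vllm:prompt_tokens_total`). As a minimal sketch, assuming a vLLM server is
+already listening on `localhost:8000`, you can confirm those metrics are being
+exported before importing the dashboards:
+
+```bash
+# Assumes a running vLLM OpenAI-compatible server on localhost:8000
+curl -s http://localhost:8000/metrics | grep -E '^vllm:(num_requests_running|e2e_request_latency_seconds|prompt_tokens_total)'
+```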
+ +## Requirements + +- Perses instance (standalone or via operator) +- Prometheus data source configured in Perses +- vLLM deployment with Prometheus metrics enabled + +## Dashboard Format + +We provide dashboards in the **native Perses YAML format** that works across all +deployment methods: + +- **Files**: `*.yaml` (native Perses dashboard specifications) +- **Format**: Pure dashboard specifications that work everywhere +- **Usage**: Works with standalone Perses, API imports, CLI, and file provisioning +- **Kubernetes**: Directly compatible with Perses Operator + +## Dashboard Descriptions + +- **performance_statistics.yaml**: Performance metrics with aggregated latency + statistics +- **query_statistics.yaml**: Query performance and deployment metrics + +## Deployment Options + +### Direct Import to Perses + +Import the dashboard specifications via Perses API or CLI: + +```bash +percli apply -f performance_statistics.yaml +``` + +### Perses Operator (Kubernetes) + +The native YAML format works directly with the Perses Operator: + +```bash +kubectl apply -f performance_statistics.yaml -n <namespace> +``` + +### File Provisioning + +Place the YAML files in a Perses provisioning folder for automatic loading. diff --git a/examples/online_serving/dashboards/perses/performance_statistics.yaml b/examples/online_serving/dashboards/perses/performance_statistics.yaml new file mode 100644 index 0000000000000..2e8d24c3324b9 --- /dev/null +++ b/examples/online_serving/dashboards/perses/performance_statistics.yaml @@ -0,0 +1,764 @@ +kind: PersesDashboard +metadata: + name: performance-statistics + createdAt: 0001-01-01T00:00:00Z + updatedAt: 0001-01-01T00:00:00Z + version: 0 + project: "" +spec: + display: + name: Performance Statistics + + variables: + - kind: ListVariable + spec: + display: + name: Deployment_ID + hidden: false + name: Deployment_id + allowAllValue: true + allowMultiple: true + defaultValue: + - $__all + sort: alphabetical-asc + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + labelName: model_name + matchers: + # Any one vllm metric that always carries model_name + - vllm:generation_tokens_total{} + + panels: + "1": + kind: Panel + spec: + display: + name: E2E Latency over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # avg latency by model = sum(rate(sum)) / sum(rate(count)) + query: > + sum by (model_name) (rate(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + + "2": + kind: Panel + spec: + display: + name: E2E Latency (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + + "3": + kind: Panel + spec: + display: + name: E2E Latency 
(P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "4": + kind: Panel + spec: + display: + name: E2E Latency (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "5": + kind: Panel + spec: + display: + name: E2E Latency (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "6": + kind: Panel + spec: + display: + name: TTFT over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + + "7": + kind: Panel + spec: + display: + name: TTFT (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + + "8": + kind: Panel + spec: + display: + name: TTFT (P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "9": + kind: Panel + spec: + display: + name: TTFT (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "10": + kind: Panel + spec: + 
display: + name: TTFT (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "11": + kind: Panel + spec: + display: + name: ITL (Time per Output Token) over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:time_per_output_token_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:time_per_output_token_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p50' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p90' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p99' + + "12": + kind: Panel + spec: + display: + name: ITL (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:time_per_output_token_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:time_per_output_token_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + + "13": + kind: Panel + spec: + display: + name: ITL (P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "14": + kind: Panel + spec: + display: + name: ITL (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + 
name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "15": + kind: Panel + spec: + display: + name: ITL (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "16": + kind: Panel + spec: + display: + name: TPS (Tokens/sec) over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}} generation' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:prompt_tokens_total{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}} prompt' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # overall iteration tokens/sec if exposed + query: > + rate(vllm:iteration_tokens_total_count[$__interval]) + seriesNameFormat: 'iteration overall' + + "17": + kind: Panel + spec: + display: + name: KV Cache Usage (avg %) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # Multiply by 100 so we can read it as a percentage without setting a unit (avoids CUE unit conflicts) + query: > + 100 * avg(vllm:gpu_cache_usage_perc) + + "18": + kind: Panel + spec: + display: + name: Running Requests by Pod + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (pod) (vllm:num_requests_running) + seriesNameFormat: '{{pod}}' + + "19": + kind: Panel + spec: + display: + name: Waiting Requests by Pod + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (pod) (vllm:num_requests_waiting) + seriesNameFormat: '{{pod}}' + + "20": + kind: Panel + spec: + display: + name: Running Requests (sum) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: 
sum(vllm:num_requests_running) + + "21": + kind: Panel + spec: + display: + name: Waiting Requests (sum) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: sum(vllm:num_requests_waiting) + + layouts: + - kind: Grid + spec: + display: + title: Overview + items: + - x: 0 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/17' } # KV cache % + - x: 6 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/20' } # running sum + - x: 12 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/21' } # waiting sum + + - kind: Grid + spec: + display: + title: E2E Latency + items: + - x: 0 + y: 1 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/1' } + - x: 10 + y: 1 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/2' } + - x: 17 + y: 1 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/3' } + - x: 10 + y: 4 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/4' } + - x: 17 + y: 4 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/5' } + + - kind: Grid + spec: + display: + title: TTFT + items: + - x: 0 + y: 8 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/6' } + - x: 10 + y: 8 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/7' } + - x: 17 + y: 8 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/8' } + - x: 10 + y: 11 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/9' } + - x: 17 + y: 11 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/10' } + + - kind: Grid + spec: + display: + title: ITL (Time per Output Token) + items: + - x: 0 + y: 15 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/11' } + - x: 10 + y: 15 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/12' } + - x: 17 + y: 15 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/13' } + - x: 10 + y: 18 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/14' } + - x: 17 + y: 18 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/15' } + + - kind: Grid + spec: + display: + title: TPS (Prompt / Generation / Iteration) + items: + - x: 0 + y: 22 + width: 14 + height: 6 + content: { $ref: '#/spec/panels/16' } + + - kind: Grid + spec: + display: + title: Per-Pod Request State + items: + - x: 0 + y: 28 + width: 12 + height: 6 + content: { $ref: '#/spec/panels/18' } + - x: 12 + y: 28 + width: 12 + height: 6 + content: { $ref: '#/spec/panels/19' } + diff --git a/examples/online_serving/dashboards/perses/query_statistics.yaml b/examples/online_serving/dashboards/perses/query_statistics.yaml new file mode 100644 index 0000000000000..28109aae81511 --- /dev/null +++ b/examples/online_serving/dashboards/perses/query_statistics.yaml @@ -0,0 +1,392 @@ +kind: PersesDashboard +metadata: + name: query-statistics + createdAt: 0001-01-01T00:00:00Z + updatedAt: 0001-01-01T00:00:00Z + version: 0 + project: "" +spec: + display: + name: Query Statistics_New + + variables: + - kind: ListVariable + spec: + name: NS + display: { name: Namespace } + allowMultiple: false + defaultValue: llm-d + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: namespace + matchers: + - up{service=~".*vllm.*"} + + - kind: ListVariable + spec: + name: SVC + display: { name: Service } + allowMultiple: false + defaultValue: vllm-qwen2-0-5b-sim + plugin: 
+ kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: service + matchers: + - up{namespace="$NS",service=~".*vllm.*"} + + - kind: ListVariable + spec: + name: MODEL + display: { name: Model (real vLLM) } + allowAllValue: true + allowMultiple: true + defaultValue: ["$__all"] + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: model_name + matchers: + - vllm:request_success_total{namespace="$NS",service="$SVC"} + + panels: + + # --- Core (works on Simulator & Real) --- + core_running_now: + kind: Panel + spec: + display: { name: Running Requests (now) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum(vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_waiting_now: + kind: Panel + spec: + display: { name: Waiting Requests (now) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum(vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_kv_usage_now: + kind: Panel + spec: + display: { name: KV Cache Usage (0–1) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_running_ts: + kind: Panel + spec: + display: { name: Running Over Time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (service) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_waiting_ts: + kind: Panel + spec: + display: { name: Waiting Over Time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (service) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_targets_up: + kind: Panel + spec: + display: { name: Scrape Targets Up } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: count(up{namespace="$NS",service="$SVC"} == 1) or vector(0) + minStep: "15s" + + # --- KV Cache as Percent (works on Simulator & Real) --- + 
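+    # Note: as with the core panels above, the queries below append `or vector(0)`
+    # so panels render 0 instead of "No data" when a metric is absent
+    # (e.g. when scraping the vLLM simulator instead of a real deployment).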
core_kv_usage_pct_now: + kind: Panel + spec: + display: { name: KV Cache Usage (%) – now } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + # multiply by 100 to present percentage; omit format.unit to avoid schema conflicts + query: (avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + core_kv_usage_pct_ts: + kind: Panel + spec: + display: { name: KV Cache Usage (%) – over time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: (avg by (service) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + # --- Per-Pod breakdowns (works on Simulator & Real) --- + per_pod_running_ts: + kind: Panel + spec: + display: { name: Running by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (pod) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + per_pod_waiting_ts: + kind: Panel + spec: + display: { name: Waiting by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (pod) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + per_pod_kv_pct_ts: + kind: Panel + spec: + display: { name: KV Cache (%) by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + # if your exporter labels kv metric with pod (the sim does), this works; otherwise it will just return empty + query: (avg by (pod) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + # --- Real vLLM only (zeros on simulator) --- + real_req_rate_ts: + kind: Panel + spec: + display: { name: Request Rate (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) (rate(vllm:request_success_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + real_p50: + kind: Panel + spec: + display: { name: 
p50 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.50, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_p90: + kind: Panel + spec: + display: { name: p90 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.90, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_p99: + kind: Panel + spec: + display: { name: p99 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.99, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_input_tokens_ts: + kind: Panel + spec: + display: { name: Input Tokens / sec (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) (rate(vllm:prompt_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + real_output_tokens_ts: + kind: Panel + spec: + display: { name: Output Tokens / sec (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) (rate(vllm:generation_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + layouts: + - kind: Grid + spec: + display: { title: Core (Sim & Real) } + items: + - { x: 0, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_running_now' } } + - { x: 6, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_waiting_now' } } + - { x: 12, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_kv_usage_now' } } + - { x: 18, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_targets_up' } } + - { x: 0, y: 3, width: 12, height: 6, content: { $ref: '#/spec/panels/core_running_ts' } } + - { x: 12, y: 3, width: 12, height: 6, content: { $ref: '#/spec/panels/core_waiting_ts' } } + + - kind: Grid + spec: + display: { title: KV Cache (%) } + items: + - { x: 0, y: 9, width: 6, height: 3, content: { $ref: 
'#/spec/panels/core_kv_usage_pct_now' } } + - { x: 6, y: 9, width: 18, height: 6, content: { $ref: '#/spec/panels/core_kv_usage_pct_ts' } } + + - kind: Grid + spec: + display: { title: Per-Pod breakdowns } + items: + - { x: 0, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_running_ts' } } + - { x: 12, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_waiting_ts' } } + - { x: 0, y: 21, width: 24, height: 6, content: { $ref: '#/spec/panels/per_pod_kv_pct_ts' } } + + - kind: Grid + spec: + display: { title: Real vLLM only (shows 0 on simulator) } + items: + - { x: 0, y: 27, width: 12, height: 6, content: { $ref: '#/spec/panels/real_req_rate_ts' } } + - { x: 12, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p50' } } + - { x: 16, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p90' } } + - { x: 20, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p99' } } + - { x: 0, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_input_tokens_ts' } } + - { x: 12, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_output_tokens_ts' } } + diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh index 6925dc8af07e9..d434e22b1ae88 100644 --- a/examples/online_serving/disaggregated_prefill.sh +++ b/examples/online_serving/disaggregated_prefill.sh @@ -53,7 +53,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ @@ -62,7 +62,7 @@ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index d39edb0b9d15c..1df11d9d84957 100644 --- a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -203,9 +203,9 @@ class Proxy: async with session.post( url=url, json=data, headers=headers ) as response: - if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 + if 200 <= response.status < 300 or 400 <= response.status < 500: if use_chunked: - async for chunk_bytes in response.content.iter_chunked( # noqa: E501 + async for chunk_bytes in response.content.iter_chunked( 1024 ): yield chunk_bytes diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh index 7b0b12bb34d25..1e7acccb4ff94 100644 --- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh @@ -166,7 +166,7 @@ main() { local kv_port=$((21001 + i)) echo " Prefill server 
$((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port" - CUDA_VISIBLE_DEVICES=$gpu_id VLLM_USE_V1=1 vllm serve $MODEL \ + CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ --enforce-eager \ --host 0.0.0.0 \ --port $port \ @@ -194,7 +194,7 @@ main() { local kv_port=$((22001 + i)) echo " Decode server $((i+1)): GPU $gpu_id, Port $port, KV Port $kv_port" - VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ --enforce-eager \ --host 0.0.0.0 \ --port $port \ diff --git a/examples/online_serving/elastic_ep/serve_deepseek_v2.sh b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh index 1234ebba4d818..6845545b6fd17 100644 --- a/examples/online_serving/elastic_ep/serve_deepseek_v2.sh +++ b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh @@ -55,7 +55,6 @@ done echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS" export RAY_DEDUP_LOGS=0 -export VLLM_USE_V1=1 export VLLM_ALL2ALL_BACKEND="pplx" export VLLM_USE_DEEP_GEMM=1 diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index 584db53db4e40..f4b79b5e13020 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -6,6 +6,8 @@ import msgspec import zmq from msgspec.msgpack import Decoder +from vllm.v1.core.kv_cache_utils import ExternalBlockHash + # # Types copied from vllm.distributed.kv_events @@ -22,15 +24,17 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[ExternalBlockHash] + parent_block_hash: Optional[ExternalBlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[ExternalBlockHash] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): diff --git a/examples/online_serving/multi-node-serving.sh b/examples/online_serving/multi-node-serving.sh index e8ad8d3de5f41..3fc5502fb9bc2 100644 --- a/examples/online_serving/multi-node-serving.sh +++ b/examples/online_serving/multi-node-serving.sh @@ -11,7 +11,7 @@ # Example usage: # On the head node machine, start the Ray head node process and run a vLLM server. # ./multi-node-serving.sh leader --ray_port=6379 --ray_cluster_size=<SIZE> [<extra ray args>] && \ -# python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2 +# vllm serve meta-llama/Meta-Llama-3.1-405B-Instruct --port 8080 --tensor-parallel-size 8 --pipeline_parallel_size 2 # # On each worker node, start the Ray worker node process. 
# ./multi-node-serving.sh worker --ray_address=<HEAD_NODE_IP> --ray_port=6379 [<extra ray args>] diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index ac5f79b56e49f..5d515fbfb6716 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -38,11 +38,13 @@ client = OpenAI( base_url=openai_api_base, ) +headers = {"User-Agent": "vLLM Example Client"} + def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" - with requests.get(content_url) as response: + with requests.get(content_url, headers=headers) as response: response.raise_for_status() result = base64.b64encode(response.content).decode("utf-8") @@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str: # Text-only inference -def run_text_only(model: str) -> None: +def run_text_only(model: str, max_completion_tokens: int) -> None: chat_completion = client.chat.completions.create( messages=[{"role": "user", "content": "What's the capital of France?"}], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion.choices[0].message.content - print("Chat completion output:", result) + print("Chat completion output:\n", result) # Single-image input inference -def run_single_image(model: str) -> None: +def run_single_image(model: str, max_completion_tokens: int) -> None: ## Use image url in the payload image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" chat_completion_from_url = client.chat.completions.create( @@ -79,11 +81,11 @@ def run_single_image(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from image url:", result) + print("Chat completion output from image url:\n", result) ## Use base64 encoded image in the payload image_base64 = encode_base64_content_from_url(image_url) @@ -101,7 +103,7 @@ def run_single_image(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content @@ -109,7 +111,7 @@ def run_single_image(model: str) -> None: # Multi-image input inference -def run_multi_image(model: str) -> None: +def run_multi_image(model: str, max_completion_tokens: int) -> None: image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" chat_completion_from_url = client.chat.completions.create( @@ -130,15 +132,15 @@ def run_multi_image(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output:", result) + print("Chat completion output:\n", result) # Video input inference -def run_video(model: str) -> None: +def run_video(model: str, max_completion_tokens: int) -> None: video_url = 
"http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4" video_base64 = encode_base64_content_from_url(video_url) @@ -157,11 +159,11 @@ def run_video(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from image url:", result) + print("Chat completion output from video url:\n", result) ## Use base64 encoded video in the payload chat_completion_from_base64 = client.chat.completions.create( @@ -178,15 +180,15 @@ def run_video(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from base64 encoded image:", result) + print("Chat completion output from base64 encoded video:\n", result) # Audio input inference -def run_audio(model: str) -> None: +def run_audio(model: str, max_completion_tokens: int) -> None: from vllm.assets.audio import AudioAsset audio_url = AudioAsset("winning_call").url @@ -211,11 +213,11 @@ def run_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from input audio:", result) + print("Chat completion output from input audio:\n", result) # HTTP URL chat_completion_from_url = client.chat.completions.create( @@ -235,11 +237,11 @@ def run_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_url.choices[0].message.content - print("Chat completion output from audio url:", result) + print("Chat completion output from audio url:\n", result) # base64 URL chat_completion_from_base64 = client.chat.completions.create( @@ -259,17 +261,59 @@ def run_audio(model: str) -> None: } ], model=model, - max_completion_tokens=64, + max_completion_tokens=max_completion_tokens, ) result = chat_completion_from_base64.choices[0].message.content - print("Chat completion output from base64 encoded audio:", result) + print("Chat completion output from base64 encoded audio:\n", result) + + +def run_multi_audio(model: str, max_completion_tokens: int) -> None: + from vllm.assets.audio import AudioAsset + + # Two different audios to showcase batched inference. 
+ audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + audio_url2 = AudioAsset("azacinto_foscolo").url + audio_base64_2 = encode_base64_content_from_url(audio_url2) + + # OpenAI-compatible schema (`input_audio`) + chat_completion_from_base64 = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Are these two audios the same?"}, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav", + }, + }, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64_2, + "format": "wav", + }, + }, + ], + } + ], + model=model, + max_completion_tokens=max_completion_tokens, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:\n", result) example_function_map = { "text-only": run_text_only, "single-image": run_single_image, "multi-image": run_multi_image, + "multi-audio": run_multi_audio, "video": run_video, "audio": run_audio, } @@ -288,13 +332,20 @@ def parse_args(): choices=list(example_function_map.keys()), help="Conversation type with multimodal data.", ) + parser.add_argument( + "--max-completion-tokens", + "-n", + type=int, + default=128, + help="Maximum number of tokens to generate for each completion.", + ) return parser.parse_args() def main(args) -> None: chat_type = args.chat_type model = get_first_model(client) - example_function_map[chat_type](model) + example_function_map[chat_type](model, args.max_completion_tokens) if __name__ == "__main__": diff --git a/examples/online_serving/openai_chat_completion_client_with_tools_required.py b/examples/online_serving/openai_chat_completion_client_with_tools_required.py index 7eb8668213eef..c00d712b351d7 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools_required.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -5,8 +5,8 @@ To run this example, you can start the vLLM server without any specific flags: ```bash -VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \ - --guided-decoding-backend outlines +vllm serve unsloth/Llama-3.2-1B-Instruct \ + --structured-outputs-config.backend outlines ``` This example demonstrates how to generate chat completions diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py deleted file mode 100644 index 771ad8511e972..0000000000000 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import base64 -import io - -import requests -from PIL import Image - -image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - - -def vlm2vec(): - response = requests.post( - "http://localhost:8000/v1/embeddings", - json={ - "model": "TIGER-Lab/VLM2Vec-Full", - "messages": [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "Represent the given image."}, - ], - } - ], - "encoding_format": "float", - }, - ) - response.raise_for_status() - response_json = response.json() - - print("Embedding output:", response_json["data"][0]["embedding"]) - - -def dse_qwen2_vl(inp: dict): - # Embedding 
an Image - if inp["type"] == "image": - messages = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": inp["image_url"], - }, - }, - {"type": "text", "text": "What is shown in this image?"}, - ], - } - ] - # Embedding a Text Query - else: - # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image - # of the minimum input size - buffer = io.BytesIO() - image_placeholder = Image.new("RGB", (56, 56)) - image_placeholder.save(buffer, "png") - buffer.seek(0) - image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") - messages = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_placeholder}", - }, - }, - {"type": "text", "text": f"Query: {inp['content']}"}, - ], - } - ] - - response = requests.post( - "http://localhost:8000/v1/embeddings", - json={ - "model": "MrLight/dse-qwen2-2b-mrl-v1", - "messages": messages, - "encoding_format": "float", - }, - ) - response.raise_for_status() - response_json = response.json() - - print("Embedding output:", response_json["data"][0]["embedding"]) - - -def parse_args(): - parser = argparse.ArgumentParser( - "Script to call a specified VLM through the API. Make sure to serve " - "the model with `--runner pooling` before running this." - ) - parser.add_argument( - "--model", - type=str, - choices=["vlm2vec", "dse_qwen2_vl"], - required=True, - help="Which model to call.", - ) - return parser.parse_args() - - -def main(args): - if args.model == "vlm2vec": - vlm2vec() - elif args.model == "dse_qwen2_vl": - dse_qwen2_vl( - { - "type": "image", - "image_url": image_url, - } - ) - dse_qwen2_vl( - { - "type": "text", - "content": "What is the weather like today?", - } - ) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/online_serving/openai_embedding_long_text/README.md index 04edc4680ea0b..00d3ded3e41c1 100644 --- a/examples/online_serving/openai_embedding_long_text/README.md +++ b/examples/online_serving/openai_embedding_long_text/README.md @@ -42,7 +42,7 @@ python client.py ### Server Configuration -The key parameters for chunked processing are in the `--override-pooler-config`: +The key parameters for chunked processing are in the `--pooler-config`: ```json { diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/online_serving/openai_embedding_long_text/client.py index 6e9838ac6d8db..4a3674bb3f2a8 100644 --- a/examples/online_serving/openai_embedding_long_text/client.py +++ b/examples/online_serving/openai_embedding_long_text/client.py @@ -13,7 +13,7 @@ Prerequisites: # MEAN pooling (processes all chunks, recommended for complete coverage) vllm serve intfloat/multilingual-e5-large \ - --override-pooler-config \ + --pooler-config \ '{"pooling_type": "MEAN", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 3072000}' \ --served-model-name multilingual-e5-large \ @@ -23,7 +23,7 @@ Prerequisites: # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks) vllm serve BAAI/bge-large-en-v1.5 \ - --override-pooler-config \ + --pooler-config \ '{"pooling_type": "CLS", "normalize": true, ' \ '"enable_chunked_processing": true, "max_embed_len": 1048576}' \ --served-model-name bge-large-en-v1.5 \ diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh index f356d7d4529ea..1577de85f7ff2 100644 
--- a/examples/online_serving/openai_embedding_long_text/service.sh +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab vllm serve "$MODEL_NAME" \ --tensor-parallel-size "$GPU_COUNT" \ --enforce-eager \ - --override-pooler-config "$POOLER_CONFIG" \ + --pooler-config "$POOLER_CONFIG" \ --served-model-name ${MODEL_CODE} \ --api-key "$API_KEY" \ --trust-remote-code \ @@ -120,7 +120,7 @@ echo " - API Key: $API_KEY" echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN" echo "" echo "🧪 Test the server with:" -echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo " python examples/online_serving/openai_embedding_long_text/client.py" echo "" echo "📚 Enhanced features enabled:" echo " ✅ Intelligent native pooling type detection" diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md new file mode 100644 index 0000000000000..2c271b6a32bc2 --- /dev/null +++ b/examples/online_serving/pooling/README.md @@ -0,0 +1,49 @@ +# Pooling models + +## Cohere rerank usage + +```bash +python examples/online_serving/pooling/cohere_rerank_client.py +``` + +## Jinaai rerank usage + +```bash +python examples/online_serving/pooling/jinaai_rerank_client.py +``` + +## Named Entity Recognition (NER) usage + +```bash +python examples/online_serving/pooling/ner.py +``` + +## Openai chat embedding for multimodal usage + +```bash +python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +``` + +## Openai classification usage + +```bash +python examples/online_serving/pooling/openai_classification_client.py +``` + +## Openai embedding usage + +```bash +python examples/online_serving/pooling/openai_embedding_client.py +``` + +## Openai embedding matryoshka dimensions usage + +```bash +python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py +``` + +## Openai pooling usage + +```bash +python examples/online_serving/pooling/openai_pooling_client.py +``` diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/pooling/cohere_rerank_client.py similarity index 100% rename from examples/online_serving/cohere_rerank_client.py rename to examples/online_serving/pooling/cohere_rerank_client.py diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/pooling/jinaai_rerank_client.py similarity index 100% rename from examples/online_serving/jinaai_rerank_client.py rename to examples/online_serving/pooling/jinaai_rerank_client.py diff --git a/examples/online_serving/pooling/ner.py b/examples/online_serving/pooling/ner.py new file mode 100644 index 0000000000000..9ec2bd45a0fe5 --- /dev/null +++ b/examples/online_serving/pooling/ner.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +""" +Example online usage of Pooling API for Named Entity Recognition (NER). + +Run `vllm serve <model> --runner pooling` +to start up the server in vLLM. e.g. 
+ +vllm serve boltuix/NeuroBERT-NER +""" + +import argparse + +import requests +import torch + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER") + + return parser.parse_args() + + +def main(args): + from transformers import AutoConfig, AutoTokenizer + + api_url = f"http://{args.host}:{args.port}/pooling" + model_name = args.model + + # Load tokenizer and config + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + label_map = config.id2label + + # Input text + text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + prompt = {"model": model_name, "input": text} + + pooling_response = post_http_request(prompt=prompt, api_url=api_url) + + # Run inference + output = pooling_response.json()["data"][0] + logits = torch.tensor(output["data"]) + predictions = logits.argmax(dim=-1) + inputs = tokenizer(text, return_tensors="pt") + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + labels = [label_map[p.item()] for p in predictions] + assert len(tokens) == len(predictions) + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py new file mode 100644 index 0000000000000..16ac4378c6863 --- /dev/null +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +"""Example Python client for multimodal embedding API using vLLM API server. + +Refer to each `run_*` function for the command to run the server for that model. +""" + +import argparse +import base64 +import io +from typing import Literal, Union + +from openai import OpenAI +from openai._types import NOT_GIVEN, NotGiven +from openai.types.chat import ChatCompletionMessageParam +from openai.types.create_embedding_response import CreateEmbeddingResponse +from PIL import Image + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + + +def create_chat_embeddings( + client: OpenAI, + *, + messages: list[ChatCompletionMessageParam], + model: str, + encoding_format: Union[Literal["base64", "float"], NotGiven] = NOT_GIVEN, +) -> CreateEmbeddingResponse: + """ + Convenience function for accessing vLLM's Chat Embeddings API, + which is an extension of OpenAI's existing Embeddings API. 
+ """ + return client.post( + "/embeddings", + cast_to=CreateEmbeddingResponse, + body={"messages": messages, "model": model, "encoding_format": encoding_format}, + ) + + +def run_clip(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve openai/clip-vit-base-patch32 \ + --runner pooling + """ + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "a photo of a cat"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + +def run_vlm2vec(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve TIGER-Lab/VLM2Vec-Full \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 4096 \ + --chat-template examples/template_vlm2vec_phi3v.jinja + """ + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "Represent the given image with the following question: What is in the image.", + }, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image+Text embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "A cat and a dog"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + +def run_dse_qwen2_vl(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja + """ + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + {"type": "text", "text": "What is shown in this image?"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image + # of the minimum input size + buffer = io.BytesIO() + image_placeholder = Image.new("RGB", (56, 56)) + image_placeholder.save(buffer, "png") + buffer.seek(0) + image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_placeholder}", + }, + }, + {"type": "text", "text": "Query: What is the weather like today?"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + +model_example_map = { + "clip": run_clip, + "vlm2vec": run_vlm2vec, + "dse_qwen2_vl": 
run_dse_qwen2_vl, +} + + +def parse_args(): + parser = argparse.ArgumentParser( + "Script to call a specified VLM through the API. Make sure to serve " + "the model with `--runner pooling` before running this." + ) + parser.add_argument( + "--model", + type=str, + choices=model_example_map.keys(), + required=True, + help="The name of the embedding model.", + ) + return parser.parse_args() + + +def main(args): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model_id = models.data[0].id + + model_example_map[args.model](client, model_id) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_classification_client.py b/examples/online_serving/pooling/openai_classification_client.py similarity index 86% rename from examples/online_serving/openai_classification_client.py rename to examples/online_serving/pooling/openai_classification_client.py index b10e7acbd26c1..d8dc2ef001112 100644 --- a/examples/online_serving/openai_classification_client.py +++ b/examples/online_serving/pooling/openai_classification_client.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for classification API using vLLM API server +NOTE: + start a supported classification model server with `vllm serve`, e.g. + vllm serve jason9693/Qwen2.5-1.5B-apeach +""" import argparse import pprint diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/pooling/openai_embedding_client.py similarity index 82% rename from examples/online_serving/openai_embedding_client.py rename to examples/online_serving/pooling/openai_embedding_client.py index 6bc390861e2ee..f5f6820d07d73 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/pooling/openai_embedding_client.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for embedding API using vLLM API server +NOTE: + start a supported embeddings model server with `vllm serve`, e.g. + vllm serve intfloat/e5-small +""" from openai import OpenAI diff --git a/examples/online_serving/openai_embedding_matryoshka_fy.py b/examples/online_serving/pooling/openai_embedding_matryoshka_fy.py similarity index 100% rename from examples/online_serving/openai_embedding_matryoshka_fy.py rename to examples/online_serving/pooling/openai_embedding_matryoshka_fy.py diff --git a/examples/online_serving/openai_pooling_client.py b/examples/online_serving/pooling/openai_pooling_client.py similarity index 89% rename from examples/online_serving/openai_pooling_client.py rename to examples/online_serving/pooling/openai_pooling_client.py index 95555d41cbea5..569015746b128 100644 --- a/examples/online_serving/openai_pooling_client.py +++ b/examples/online_serving/pooling/openai_pooling_client.py @@ -4,7 +4,9 @@ Example online usage of Pooling API. Run `vllm serve <model> --runner pooling` -to start up the server in vLLM. +to start up the server in vLLM. e.g. 
+ +vllm serve internlm/internlm2-1_8b-reward --trust-remote-code """ import argparse @@ -23,7 +25,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach") + parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward") return parser.parse_args() diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py new file mode 100644 index 0000000000000..611a7cbc89fa2 --- /dev/null +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import os + +import requests + +# This example shows how to perform an online inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Requirements : +# - install plugin at: +# https://github.com/christian-pinto/prithvi_io_processor_plugin +# - start vllm in serving mode with the below args +# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' +# --model-impl terratorch +# --task embed --trust-remote-code +# --skip-tokenizer-init --enforce-eager +# --io-processor-plugin prithvi_to_tiff + + +def main(): + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + server_endpoint = "http://localhost:8000/pooling" + + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + "softmax": False, + } + + ret = requests.post(server_endpoint, json=request_payload_url) + + print(f"response.status_code: {ret.status_code}") + print(f"response.reason:{ret.reason}") + + response = ret.json() + + decoded_image = base64.b64decode(response["data"]["data"]) + + out_path = os.path.join(os.getcwd(), "online_prediction.tiff") + + with open(out_path, "wb") as f: + f.write(decoded_image) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/prometheus_grafana/grafana.json b/examples/online_serving/prometheus_grafana/grafana.json index 3488956a5b24c..37abc9de926fd 100644 --- a/examples/online_serving/prometheus_grafana/grafana.json +++ b/examples/online_serving/prometheus_grafana/grafana.json @@ -402,7 +402,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "includeNullMetadata": false, "instant": false, @@ -418,7 +418,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -435,7 +435,7 @@ }, 
"disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -452,7 +452,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -468,7 +468,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", + "expr": "rate(vllm:inter_token_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:inter_token_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", "hide": false, "instant": false, "legendFormat": "Mean", @@ -476,7 +476,7 @@ "refId": "E" } ], - "title": "Time Per Output Token Latency", + "title": "Inter Token Latency", "type": "timeseries" }, { diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py index d24b553df27c7..af53443b9101a 100644 --- a/examples/online_serving/ray_serve_deepseek.py +++ b/examples/online_serving/ray_serve_deepseek.py @@ -36,7 +36,6 @@ llm_config = LLMConfig( }, # Set to the node's accelerator type. accelerator_type="H100", - runtime_env={"env_vars": {"VLLM_USE_V1": "1"}}, # Customize engine arguments as required (for example, vLLM engine kwargs). 
engine_kwargs={ "tensor_parallel_size": 8, diff --git a/examples/online_serving/sagemaker-entrypoint.sh b/examples/online_serving/sagemaker-entrypoint.sh index 75a99ffc1f155..1a6b6780ef2a3 100644 --- a/examples/online_serving/sagemaker-entrypoint.sh +++ b/examples/online_serving/sagemaker-entrypoint.sh @@ -21,4 +21,4 @@ while IFS='=' read -r key value; do done < <(env | grep "^${PREFIX}") # Pass the collected arguments to the main entrypoint -exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}" \ No newline at end of file +exec vllm serve "${ARGS[@]}" \ No newline at end of file diff --git a/examples/online_serving/structured_outputs/pyproject.toml b/examples/online_serving/structured_outputs/pyproject.toml index 8f31405ff584a..5e366ab0a03d3 100644 --- a/examples/online_serving/structured_outputs/pyproject.toml +++ b/examples/online_serving/structured_outputs/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "examples-online-structured-outputs" -requires-python = ">=3.9, <3.13" +requires-python = ">=3.10, <3.14" dependencies = ["openai==1.78.1", "pydantic==2.11.4"] version = "0.0.0" diff --git a/examples/online_serving/structured_outputs/structured_outputs.py b/examples/online_serving/structured_outputs/structured_outputs.py index 2a8f4637260c2..3ea6c73e90e8f 100644 --- a/examples/online_serving/structured_outputs/structured_outputs.py +++ b/examples/online_serving/structured_outputs/structured_outputs.py @@ -86,7 +86,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { "content": "Classify this sentiment: vLLM is wonderful!", } ], - "extra_body": {"guided_choice": ["positive", "negative"]}, + "extra_body": {"structured_outputs": {"choice": ["positive", "negative"]}}, }, "regex": { "messages": [ @@ -96,7 +96,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { } ], "extra_body": { - "guided_regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n", + "structured_outputs": {"regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n"}, }, }, "json": { @@ -122,7 +122,8 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { } ], "extra_body": { - "guided_grammar": """ + "structured_outputs": { + "grammar": """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -135,6 +136,7 @@ condition ::= column "= " number number ::= "1 " | "2 " """, + } }, }, "structural_tag": { diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index 559c7c493aca2..acbfd8cda489a 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import dataclasses import json import logging import os @@ -23,8 +21,6 @@ from vllm.utils import FlexibleArgumentParser logger = logging.getLogger() -# yapf conflicts with isort for this docstring -# yapf: disable """ tensorize_vllm_model.py is a script that can be used to serialize and deserialize vLLM models. These models can be loaded using tensorizer @@ -134,7 +130,8 @@ def get_parser(): "can be loaded using tensorizer directly to the GPU " "extremely quickly. Tensor encryption and decryption is " "also supported, although libsodium must be installed to " - "use it.") + "use it." 
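For reference, a minimal client sketch of the new `structured_outputs` request body used above, with the old `guided_choice` form kept in a comment for comparison. It assumes an OpenAI-compatible vLLM server on localhost:8000 and simply uses whichever model the server reports; both are assumptions to adjust.

```python
# Sketch of the structured_outputs extra_body format that replaces the old
# guided_* fields shown above. Server address and model are placeholders.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

completion = client.chat.completions.create(
    model=model,
    messages=[
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
    ],
    # Previously: extra_body={"guided_choice": ["positive", "negative"]}
    extra_body={"structured_outputs": {"choice": ["positive", "negative"]}},
)
print(completion.choices[0].message.content)
```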
+ ) parser = EngineArgs.add_cli_args(parser) parser.add_argument( @@ -146,13 +143,14 @@ def get_parser(): "along with the model by instantiating a TensorizerConfig object, " "creating a dict from it with TensorizerConfig.to_serializable(), " "and passing it to LoRARequest's initializer with the kwarg " - "tensorizer_config_dict." + "tensorizer_config_dict.", ) - subparsers = parser.add_subparsers(dest='command', required=True) + subparsers = parser.add_subparsers(dest="command", required=True) serialize_parser = subparsers.add_parser( - 'serialize', help="Serialize a model to `--serialized-directory`") + "serialize", help="Serialize a model to `--serialized-directory`" + ) serialize_parser.add_argument( "--suffix", @@ -165,7 +163,9 @@ def get_parser(): "`--suffix` is `v1`, the serialized model tensors will be " "saved to " "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " - "If none is provided, a random UUID will be used.")) + "If none is provided, a random UUID will be used." + ), + ) serialize_parser.add_argument( "--serialized-directory", type=str, @@ -177,108 +177,127 @@ def get_parser(): "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " "where `suffix` is given by `--suffix` or a random UUID if not " - "provided.") + "provided.", + ) serialize_parser.add_argument( "--serialization-kwargs", type=tensorizer_kwargs_arg, required=False, - help=("A JSON string containing additional keyword arguments to " - "pass to Tensorizer's TensorSerializer during " - "serialization.")) + help=( + "A JSON string containing additional keyword arguments to " + "pass to Tensorizer's TensorSerializer during " + "serialization." + ), + ) serialize_parser.add_argument( "--keyfile", type=str, required=False, - help=("Encrypt the model weights with a randomly-generated binary key," - " and save the key at this path")) + help=( + "Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path" + ), + ) deserialize_parser = subparsers.add_parser( - 'deserialize', - help=("Deserialize a model from `--path-to-tensors`" - " to verify it can be loaded and used.")) + "deserialize", + help=( + "Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used." + ), + ) deserialize_parser.add_argument( "--path-to-tensors", type=str, required=False, - help="The local path or S3 URI to the model tensors to deserialize. ") + help="The local path or S3 URI to the model tensors to deserialize. ", + ) deserialize_parser.add_argument( "--serialized-directory", type=str, required=False, help="Directory with model artifacts for loading. Assumes a " - "model.tensors file exists therein. Can supersede " - "--path-to-tensors.") + "model.tensors file exists therein. 
Can supersede " + "--path-to-tensors.", + ) deserialize_parser.add_argument( "--keyfile", type=str, required=False, - help=("Path to a binary key to use to decrypt the model weights," - " if the model was serialized with encryption")) + help=( + "Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption" + ), + ) deserialize_parser.add_argument( "--deserialization-kwargs", type=tensorizer_kwargs_arg, required=False, - help=("A JSON string containing additional keyword arguments to " - "pass to Tensorizer's `TensorDeserializer` during " - "deserialization.")) + help=( + "A JSON string containing additional keyword arguments to " + "pass to Tensorizer's `TensorDeserializer` during " + "deserialization." + ), + ) TensorizerArgs.add_cli_args(deserialize_parser) return parser -def merge_extra_config_with_tensorizer_config(extra_cfg: dict, - cfg: TensorizerConfig): + +def merge_extra_config_with_tensorizer_config(extra_cfg: dict, cfg: TensorizerConfig): for k, v in extra_cfg.items(): if hasattr(cfg, k): setattr(cfg, k, v) logger.info( "Updating TensorizerConfig with %s from " - "--model-loader-extra-config provided", k + "--model-loader-extra-config provided", + k, ) + def deserialize(args, tensorizer_config): if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir - llm = LLM(model=args.model, - load_format="tensorizer", - tensor_parallel_size=args.tensor_parallel_size, - model_loader_extra_config=tensorizer_config, - enable_lora=True, + llm = LLM( + model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config, + enable_lora=True, ) sampling_params = SamplingParams( - temperature=0, - max_tokens=256, - stop=["[/assistant]"] + temperature=0, max_tokens=256, stop=["[/assistant]"] ) # Truncating this as the extra text isn't necessary - prompts = [ - "[user] Write a SQL query to answer the question based on ..." 
- ] + prompts = ["[user] Write a SQL query to answer the question based on ..."] # Test LoRA load print( llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest("sql-lora", - 1, - args.lora_path, - tensorizer_config_dict = tensorizer_config - .to_serializable()) + prompts, + sampling_params, + lora_request=LoRARequest( + "sql-lora", + 1, + args.lora_path, + tensorizer_config_dict=tensorizer_config.to_serializable(), + ), ) ) else: - llm = LLM(model=args.model, - load_format="tensorizer", - tensor_parallel_size=args.tensor_parallel_size, - model_loader_extra_config=tensorizer_config + llm = LLM( + model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config, ) return llm @@ -287,17 +306,20 @@ def main(): parser = get_parser() args = parser.parse_args() - s3_access_key_id = (getattr(args, 's3_access_key_id', None) - or os.environ.get("S3_ACCESS_KEY_ID", None)) - s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) - or os.environ.get("S3_SECRET_ACCESS_KEY", None)) - s3_endpoint = (getattr(args, 's3_endpoint', None) - or os.environ.get("S3_ENDPOINT_URL", None)) + s3_access_key_id = getattr(args, "s3_access_key_id", None) or os.environ.get( + "S3_ACCESS_KEY_ID", None + ) + s3_secret_access_key = getattr( + args, "s3_secret_access_key", None + ) or os.environ.get("S3_SECRET_ACCESS_KEY", None) + s3_endpoint = getattr(args, "s3_endpoint", None) or os.environ.get( + "S3_ENDPOINT_URL", None + ) credentials = { "s3_access_key_id": s3_access_key_id, "s3_secret_access_key": s3_secret_access_key, - "s3_endpoint": s3_endpoint + "s3_endpoint": s3_endpoint, } model_ref = args.model @@ -311,30 +333,25 @@ def main(): if args.model_loader_extra_config: extra_config = json.loads(args.model_loader_extra_config) - - tensorizer_dir = (args.serialized_directory or - extra_config.get("tensorizer_dir")) - tensorizer_uri = (getattr(args, "path_to_tensors", None) - or extra_config.get("tensorizer_uri")) + tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir") + tensorizer_uri = getattr(args, "path_to_tensors", None) or extra_config.get( + "tensorizer_uri" + ) if tensorizer_dir and tensorizer_uri: - parser.error("--serialized-directory and --path-to-tensors " - "cannot both be provided") - - if not tensorizer_dir and not tensorizer_uri: - parser.error("Either --serialized-directory or --path-to-tensors " - "must be provided") - - - if args.command == "serialize": - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - - engine_args = EngineArgs.from_cli_args( - argparse.Namespace(**eng_args_dict) + parser.error( + "--serialized-directory and --path-to-tensors cannot both be provided" ) - input_dir = tensorizer_dir.rstrip('/') + if not tensorizer_dir and not tensorizer_uri: + parser.error( + "Either --serialized-directory or --path-to-tensors must be provided" + ) + + if args.command == "serialize": + engine_args = EngineArgs.from_cli_args(args) + + input_dir = tensorizer_dir.rstrip("/") suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" if engine_args.tensor_parallel_size > 1: @@ -346,15 +363,14 @@ def main(): tensorizer_uri=model_path, encryption_keyfile=keyfile, serialization_kwargs=args.serialization_kwargs or {}, - **credentials + **credentials, ) if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir tensorize_lora_adapter(args.lora_path, tensorizer_config) 
- merge_extra_config_with_tensorizer_config(extra_config, - tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) tensorize_vllm_model(engine_args, tensorizer_config) elif args.command == "deserialize": @@ -363,11 +379,10 @@ def main(): tensorizer_dir=args.serialized_directory, encryption_keyfile=keyfile, deserialization_kwargs=args.deserialization_kwargs or {}, - **credentials + **credentials, ) - merge_extra_config_with_tensorizer_config(extra_config, - tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) deserialize(args, tensorizer_config) else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/examples/pyproject.toml b/examples/pyproject.toml deleted file mode 100644 index f825cb203269c..0000000000000 --- a/examples/pyproject.toml +++ /dev/null @@ -1,54 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 -exclude = [ - # External file, leaving license intact - "examples/other/fp8/quantizer/quantize.py", - "vllm/vllm_flash_attn/flash_attn_interface.pyi" -] - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.lint.isort] -known-first-party = ["vllm"] - -[tool.ruff.format] -docstring-code-format = true \ No newline at end of file diff --git a/examples/template_vlm2vec.jinja b/examples/template_vlm2vec_phi3v.jinja similarity index 100% rename from examples/template_vlm2vec.jinja rename to examples/template_vlm2vec_phi3v.jinja diff --git a/examples/template_vlm2vec_qwen2vl.jinja b/examples/template_vlm2vec_qwen2vl.jinja new file mode 100644 index 0000000000000..3ab099d8f546d --- /dev/null +++ b/examples/template_vlm2vec_qwen2vl.jinja @@ -0,0 +1,15 @@ +{%- if messages | length > 1 -%} + {{ raise_exception('Embedding models should only embed one message at a time') }} +{%- endif -%} + +{% set vars = namespace(parts=[]) %} +{%- for message in messages -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set vars.parts = vars.parts + [content['text']] %} + {%- elif content['type'] == 'image' -%} + {%- set vars.parts = vars.parts + ['<|image_pad|>'] %} + {%- endif -%} + {%- endfor -%} +{%- endfor -%} +{{ vars.parts | join(' ') }} diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 83886762c2893..6f40c38c20644 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -9,7 +9,7 @@ <|system|> {{ system_message }} {%- if tools %} -In addition to plain text responses, you can chose to call one or more of the provided functions. +In addition to plain text responses, you can choose to call one or more of the provided functions. 
Use the following rule to decide when to call a function: * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so @@ -19,7 +19,7 @@ If you decide to call functions: * prefix function calls with functools marker (no closing marker required) * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0 * make sure you pick the right functions that match the user intent diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 0000000000000..49b0e8d0ee7e6 --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "<tools>" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }} + {%- if tool.description is defined %} + {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }} + {%- endif %} + {{- '\n<parameters>' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n<parameter>' }} + {{- '\n<name>' ~ param_name ~ '</name>' }} + {%- if param_fields.type is defined %} + {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n</parameter>' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n</parameters>' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n</function>' }} + {%- endfor %} + {{- "\n</tools>" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '<parameter=' + args_name + '>\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n</parameter>\n' 
}} + {%- endfor %} + {%- endif %} + {{- '</function>\n</tool_call>' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/find_cuda_init.py b/find_cuda_init.py deleted file mode 100644 index 308fc6fc2d61c..0000000000000 --- a/find_cuda_init.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import importlib -import traceback -from typing import Callable -from unittest.mock import patch - - -def find_cuda_init(fn: Callable[[], object]) -> None: - """ - Helper function to debug CUDA re-initialization errors. - - If `fn` initializes CUDA, prints the stack trace of how this happens. - """ - from torch.cuda import _lazy_init - - stack = None - - def wrapper(): - nonlocal stack - stack = traceback.extract_stack() - return _lazy_init() - - with patch("torch.cuda._lazy_init", wrapper): - fn() - - if stack is not None: - print("==== CUDA Initialized ====") - print("".join(traceback.format_list(stack)).strip()) - print("==========================") - - -if __name__ == "__main__": - find_cuda_init( - lambda: importlib.import_module("vllm.model_executor.models.llava")) diff --git a/mkdocs.yaml b/mkdocs.yaml index 47fe1ebce9712..6f2be65a18af8 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -79,6 +79,7 @@ plugins: - "re:vllm\\._.*" # Internal modules - "vllm.third_party" - "vllm.vllm_flash_attn" + - !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default - mkdocstrings: handlers: python: @@ -129,15 +130,16 @@ markdown_extensions: - toc: permalink: true # For math rendering - - mdx_math: - enable_dollar_delimiter: true + - pymdownx.arithmatex: + generic: true extra_css: - mkdocs/stylesheets/extra.css extra_javascript: - mkdocs/javascript/run_llm_widget.js - - https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML + - mkdocs/javascript/mathjax.js + - https://unpkg.com/mathjax@3.2.2/es5/tex-mml-chtml.js - mkdocs/javascript/edit_and_feedback.js - mkdocs/javascript/slack_and_forum.js diff --git a/pyproject.toml b/pyproject.toml index 013f2a6cd59e4..49a7a0b8b1210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.7.1", + "torch == 2.8.0", "wheel", "jinja2", ] @@ -20,7 +20,6 @@ license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -31,7 +30,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: 
Information Analysis", ] -requires-python = ">=3.9,<3.14" +requires-python = ">=3.10,<3.14" dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] @@ -52,28 +51,10 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi where = ["."] include = ["vllm*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**", -] - -[tool.ruff] -# Allow lines to be as long as 80. -line-length = 80 - [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing - skip V0 code -"vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/core/**/*.py" = ["UP006", "UP035"] -"vllm/engine/**/*.py" = ["UP006", "UP035"] -"vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/worker/**/*.py" = ["UP006", "UP035"] [tool.ruff.lint] select = [ @@ -88,7 +69,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", # flake8-logging-format "G", ] @@ -97,58 +78,37 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", + # zip without `strict=` + "B905", # Loop control variable not used within loop body "B007", # f-string format "UP032", # Can remove once 3.10+ is the minimum Python version "UP007", + "UP027", + "UP035", + "UP038", + "UP045", ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -# After fixing type errors resulting from follow_imports: "skip" -> "silent", -# move the directory here and remove it from tools/mypy.sh -files = [ - "vllm/*.py", - "vllm/adapter_commons", - "vllm/assets", - "vllm/entrypoints", - "vllm/core", - "vllm/inputs", - "vllm/logging_utils", - "vllm/multimodal", - "vllm/platforms", - "vllm/transformers_utils", - "vllm/triton_utils", - "vllm/usage", -] -# TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = [ - "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", - # Ignore triton kernels in ops. - 'vllm/attention/ops/.*\.py$' -] - -[tool.isort] -skip_glob = [ - ".buildkite/*", - "benchmarks/*", - "examples/*", -] -use_parentheses = true -skip_gitignore = true - [tool.pytest.ini_options] markers = [ + "slow_test", "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)", "cpu_model: enable this model test in CPU tests", + "cpu_test: mark test as CPU-only test", "split: run this test as part of a split", "distributed: run this test only in distributed GPU tests", "skip_v1: do not run this test with v1", @@ -228,6 +188,8 @@ fo = "fo" ba = "ba" [tool.typos.type.py.extend-words] +ba = "ba" +nd = "nd" [tool.typos.type.cpp] extend-glob = ["*.cu"] @@ -344,3 +306,6 @@ extend-ignore-re = [] windo = "windo" [tool.typos.type.vimscript.extend-words] + +[tool.uv] +no-build-isolation-package = ["torch"] diff --git a/requirements/build.txt b/requirements/build.txt index dd644d621efc1..5f826a1afa144 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,8 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.7.1 +torch==2.8.0 wheel jinja2>=3.1.6 regex +build diff --git a/requirements/common.txt b/requirements/common.txt index 8acf634526ff1..d5fa1e92bd7eb 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -13,19 +13,18 @@ protobuf # Required by LlamaTokenizer. 
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.11.7 +pydantic >= 2.12.0 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.11, < 0.11 +lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines_core == 0.2.10 ; platform_machine != "s390x" -outlines == 0.1.11 ; platform_machine == "s390x" +outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.21; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs @@ -33,7 +32,7 @@ pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata; python_version < '3.10' -mistral_common[image,audio] >= 1.8.2 +mistral_common[image,audio] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt index 37f072202bd71..b511b0f5d31b3 100644 --- a/requirements/cpu-build.txt +++ b/requirements/cpu-build.txt @@ -1,12 +1,11 @@ -# Temporarily used for x86 CPU backend to avoid performance regression of torch>2.6.0+cpu, -# see https://github.com/pytorch/pytorch/pull/151218 cmake>=3.26.1 ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 --extra-index-url https://download.pytorch.org/whl/cpu -torch==2.6.0+cpu +torch==2.8.0+cpu; platform_machine == "x86_64" +torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin" wheel jinja2>=3.1.6 regex diff --git a/requirements/cpu.txt b/requirements/cpu.txt index f4b95b72898cc..d53ab3649308a 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,30 +1,28 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' and platform_machine != "s390x" # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x" +numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding # Dependencies for CPUs packaging>=24.2 setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu -torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 -torch==2.7.0; platform_system == "Darwin" -torch==2.7.0; platform_machine == "ppc64le" -torch==2.6.0; platform_machine == "aarch64" # for arm64 CPUs, torch 2.7.0 has a issue: https://github.com/vllm-project/vllm/issues/17960 +torch==2.8.0+cpu; platform_machine == "x86_64" +torch==2.8.0; platform_system == "Darwin" +torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" -torchaudio==2.7.0; platform_machine == "ppc64le" +torchaudio==2.8.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" -torchvision==0.22.0; platform_machine == "ppc64le" +torchvision==0.23.0; platform_machine == "ppc64le" datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs intel-openmp==2024.2.1; platform_machine == "x86_64" -intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 +intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64" triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile. # Use this to gather CPU info and optimize based on ARM Neoverse cores diff --git a/requirements/cuda.txt b/requirements/cuda.txt index fb30e493f80b3..06956415d072e 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -1,14 +1,15 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.7.1 -torchaudio==2.7.1 +torch==2.8.0 +torchaudio==2.8.0 # These must be updated alongside torch -torchvision==0.22.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31 -xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 \ No newline at end of file +torchvision==0.23.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 +xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 +# FlashInfer should be updated together with the Dockerfile +flashinfer-python==0.4.0 \ No newline at end of file diff --git a/requirements/docs.txt b/requirements/docs.txt index a24b9c7e924bf..d1c546398780a 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,27 +7,12 @@ mkdocs-awesome-nav mkdocs-glightbox mkdocs-git-revision-date-localized-plugin mkdocs-minify-plugin -python-markdown-math regex ruff # Required for argparse hook only -f https://download.pytorch.org/whl/cpu cachetools -cbor2 -cloudpickle -fastapi msgspec -openai -openai-harmony -partial-json-parser -pillow -psutil -pybase64 pydantic -setproctitle torch -transformers -zmq -uvloop -prometheus-client diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt index 262675a231206..b1f3269cd3813 100644 --- a/requirements/kv_connectors.txt +++ b/requirements/kv_connectors.txt @@ -1 +1,2 @@ -lmcache \ No newline at end of file +lmcache +nixl >= 0.6.0 # Required for disaggregated prefill diff --git a/requirements/neuron.txt b/requirements/neuron.txt deleted file mode 100644 index 7df478eddde3f..0000000000000 --- a/requirements/neuron.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Common dependencies --r common.txt - -# Dependencies for Neuron devices -packaging>=24.2 -setuptools>=77.0.3,<80.0.0 -torch-neuronx >= 2.5.0 -neuronx-cc>=2.0.0a0 -torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index a529bf4504e40..dea1926bbd695 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,14 +23,14 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test -transformers==4.52.4 -tokenizers==0.21.1 +transformers==4.56.2 +tokenizers==0.22.0 schemathesis>=3.39.15 # Required for openai schema test. # quantization bitsandbytes>=0.46.1 @@ -40,10 +40,8 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer[s3,gcs]==0.14.0 fastsafetensors>=0.1.10 -pydantic>=2.10 # 2.9 leads to error on python 3.10 +pydantic>=2.12 # 2.11 leads to error on python 3.13 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index cbae9bbb8a9b3..a86a8ab6df149 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,10 +1,10 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.2.4 -torch==2.7.0 -torchvision==0.22.0 -torchaudio==2.7.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.3 +torch==2.8.0 +torchvision==0.23.0 +torchaudio==2.8.0 triton==3.3.0 cmake>=3.26.1,<4 @@ -14,3 +14,4 @@ setuptools-scm>=8 wheel jinja2>=3.1.6 amdsmi==6.2.4 +timm>=1.0.17 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 25f950a99eceb..869fb28c3d85c 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,5 +1,6 @@ # Common dependencies -r common.txt +tblib==3.1.0 # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai diff --git a/requirements/rocm.txt b/requirements/rocm.txt index c3bb65b70a0b8..d9743f0446438 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -1,20 +1,17 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for AMD GPUs -boto3 -botocore datasets -ray>=2.10.0,<2.45.0 +ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. 
peft pytest-asyncio tensorizer==2.10.1 packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 \ No newline at end of file +runai-model-streamer[s3,gcs]==0.14.0 +conch-triton-kernels==1.2.1 +timm>=1.0.17 diff --git a/requirements/test.in b/requirements/test.in index 098a9242bc3af..f0941d3c59183 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -6,6 +6,7 @@ pytest-asyncio pytest-rerunfailures pytest-shard pytest-timeout +pytest-cov # testing utils backoff # required for phi4mm test @@ -21,13 +22,14 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests +tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.7.1 -torchaudio==2.7.1 -torchvision==0.22.1 +torch==2.8.0 +torchaudio==2.8.0 +torchvision==0.23.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test @@ -35,8 +37,8 @@ datamodel_code_generator # required for minicpm3 test # TODO: Use lm-eval[api]==0.4.10 once released lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.55.2 -tokenizers==0.21.1 +transformers==4.56.2 +tokenizers==0.22.0 schemathesis>=3.39.15 # Required for openai schema test. # quantization bitsandbytes==0.46.1 @@ -46,11 +48,11 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer[s3,gcs]==0.14.0 fastsafetensors>=0.1.10 -pydantic>=2.10 # 2.9 leads to error on python 3.10 -terratorch==1.1rc2 # required for PrithviMAE test +pydantic>=2.12 # 2.11 leads to error on python 3.13 +decord==0.6.0 +terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test +gpt-oss >= 0.0.7; python_version > '3.11' \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index 8b872752d875c..03fbdcc8d453b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -10,18 +10,19 @@ aenum==3.1.16 # via lightly affine==2.4.0 # via rasterio -aiohappyeyeballs==2.4.3 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.10.11 +aiohttp==3.13.0 # via # aiohttp-cors # datasets # fsspec + # gpt-oss # lm-eval # ray aiohttp-cors==0.8.1 # via ray -aiosignal==1.3.1 +aiosignal==1.4.0 # via aiohttp albucore==0.0.16 # via terratorch @@ -72,7 +73,9 @@ blobfile==3.0.0 bm25s==0.2.13 # via mteb boto3==1.35.57 - # via tensorizer + # via + # runai-model-streamer-s3 + # tensorizer botocore==1.35.57 # via # boto3 @@ -101,6 +104,8 @@ chardet==5.2.0 # via mbstrdecoder charset-normalizer==3.4.0 # via requests +chz==0.3.0 + # via gpt-oss click==8.1.7 # via # black @@ -135,9 +140,11 @@ colorful==0.5.6 # via ray contourpy==1.3.0 # via matplotlib +coverage==7.10.6 + # via pytest-cov cramjam==2.9.0 # via fastparquet -cupy-cuda12x==13.3.0 +cupy-cuda12x==13.6.0 # via ray cycler==0.12.1 # via matplotlib @@ -156,6 +163,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -167,7 +176,9 @@ distlib==0.3.9 dnspython==2.7.0 # via email-validator docker==7.1.0 - # via mlflow + # via + # gpt-oss + # mlflow docopt==0.6.2 # via num2words docstring-parser==0.17.0 @@ -193,7 +204,9 @@ eval-type-backport==0.2.2 evaluate==0.4.3 # via lm-eval fastapi==0.116.1 - # via mlflow-skinny + # via + # gpt-oss + # mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -245,13 +258,31 @@ gitdb==4.0.12 gitpython==3.1.44 # via mlflow-skinny google-api-core==2.24.2 - # via opencensus + # via + # google-cloud-core + # google-cloud-storage + # opencensus google-auth==2.40.2 # via # databricks-sdk # google-api-core + # google-cloud-core + # google-cloud-storage + # runai-model-streamer-gcs +google-cloud-core==2.4.3 + # via google-cloud-storage +google-cloud-storage==3.4.0 + # via runai-model-streamer-gcs +google-crc32c==1.7.1 + # via + # google-cloud-storage + # google-resumable-media +google-resumable-media==2.7.2 + # via google-cloud-storage googleapis-common-protos==1.70.0 # via google-api-core +gpt-oss==0.0.8 + # via -r requirements/test.in graphene==3.4.3 # via mlflow graphql-core==3.2.6 @@ -279,6 +310,8 @@ hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer +html2text==2025.4.15 + # via gpt-oss httpcore==1.0.6 # via httpx httpx==0.27.2 
@@ -413,6 +446,7 @@ lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b772215 lxml==5.3.0 # via # blobfile + # gpt-oss # sacrebleu mako==1.3.10 # via alembic @@ -440,7 +474,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.8.2 +mistral-common==1.8.5 # via -r requirements/test.in mlflow==2.22.0 # via terratorch @@ -493,6 +527,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate @@ -538,42 +573,42 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.3.14 +nvidia-cublas-cu12==12.8.4.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.57 +nvidia-cuda-cupti-cu12==12.8.90 # via torch -nvidia-cuda-nvrtc-cu12==12.8.61 +nvidia-cuda-nvrtc-cu12==12.8.93 # via torch -nvidia-cuda-runtime-cu12==12.8.57 +nvidia-cuda-runtime-cu12==12.8.90 # via torch -nvidia-cudnn-cu12==9.7.1.26 +nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.41 +nvidia-cufft-cu12==11.3.3.83 # via torch -nvidia-cufile-cu12==1.13.0.11 +nvidia-cufile-cu12==1.13.1.3 # via torch -nvidia-curand-cu12==10.3.9.55 +nvidia-curand-cu12==10.3.9.90 # via torch -nvidia-cusolver-cu12==11.7.2.55 +nvidia-cusolver-cu12==11.7.3.90 # via torch -nvidia-cusparse-cu12==12.5.7.53 +nvidia-cusparse-cu12==12.5.8.93 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.3 +nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.26.2 +nvidia-nccl-cu12==2.27.3 # via torch -nvidia-nvjitlink-cu12==12.8.61 +nvidia-nvjitlink-cu12==12.8.93 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.55 +nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 # via @@ -581,6 +616,8 @@ omegaconf==2.3.0 # lightning open-clip-torch==2.32.0 # via -r requirements/test.in +openai-harmony==0.0.4 + # via gpt-oss opencensus==0.11.4 # via ray opencensus-context==0.1.3 @@ -683,7 +720,9 @@ platformdirs==4.3.6 plotly==5.24.1 # via genai-perf pluggy==1.5.0 - # via pytest + # via + # pytest + # pytest-cov polars==1.29.0 # via mteb pooch==1.8.2 @@ -699,7 +738,9 @@ prometheus-client==0.22.0 # opentelemetry-exporter-prometheus # ray propcache==0.2.0 - # via yarl + # via + # aiohttp + # yarl proto-plus==1.26.1 # via google-api-core protobuf==5.28.3 @@ -742,19 +783,21 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.7 +pydantic==2.12.0 # via # -r requirements/test.in # albumentations # datamodel-code-generator # fastapi + # gpt-oss # lightly # mistral-common # mlflow-skinny # mteb + # openai-harmony # pydantic-extra-types # ray -pydantic-core==2.33.2 +pydantic-core==2.41.1 # via pydantic pydantic-extra-types==2.10.5 # via mistral-common @@ -783,6 +826,7 @@ pytest==8.3.5 # buildkite-test-collector # genai-perf # pytest-asyncio + # pytest-cov # pytest-forked # pytest-mock # pytest-rerunfailures @@ -793,6 +837,8 @@ pytest==8.3.5 # terratorch pytest-asyncio==0.24.0 # via -r requirements/test.in +pytest-cov==6.3.0 + # via -r requirements/test.in pytest-forked==1.6.0 # via -r requirements/test.in pytest-mock==3.14.0 @@ -878,6 +924,8 @@ requests==2.32.3 # docker # evaluate # google-api-core + # google-cloud-storage + # gpt-oss # huggingface-hub # lightly # lm-eval @@ -915,10 +963,12 @@ rsa==4.9.1 # via google-auth rtree==1.4.0 # via torchgeo -runai-model-streamer==0.11.0 - # via -r requirements/test.in -runai-model-streamer-s3==0.11.0 +runai-model-streamer==0.14.0 # via -r requirements/test.in +runai-model-streamer-gcs==0.14.0 + 
# via runai-model-streamer +runai-model-streamer-s3==0.14.0 + # via runai-model-streamer s3transfer==0.10.3 # via boto3 sacrebleu==2.4.3 @@ -962,8 +1012,6 @@ sentence-transformers==3.2.1 # via # -r requirements/test.in # mteb -sentencepiece==0.2.0 - # via mistral-common setuptools==77.0.3 # via # lightning-utilities @@ -1021,6 +1069,8 @@ starlette-testclient==0.4.1 # via schemathesis statsmodels==0.14.4 # via genai-perf +structlog==25.4.0 + # via gpt-oss sympy==1.13.3 # via # einx @@ -1029,17 +1079,22 @@ tabledata==1.3.3 # via pytablewriter tabulate==0.9.0 # via sacrebleu +tblib==3.1.0 + # via -r requirements/test.in tcolorpy==0.1.6 # via pytablewriter -tenacity==9.0.0 +tenacity==9.1.2 # via + # gpt-oss # lm-eval # plotly tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch==1.1rc2 +termcolor==3.1.0 + # via gpt-oss +terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn @@ -1047,8 +1102,9 @@ tifffile==2025.3.30 # via # scikit-image # terratorch -tiktoken==0.7.0 +tiktoken==0.12.0 # via + # gpt-oss # lm-eval # mistral-common timm==1.0.17 @@ -1058,7 +1114,7 @@ timm==1.0.17 # segmentation-models-pytorch # terratorch # torchgeo -tokenizers==0.21.1 +tokenizers==0.22.0 # via # -r requirements/test.in # transformers @@ -1066,7 +1122,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.1+cu128 +torch==2.8.0+cu128 # via # -r requirements/test.in # accelerate @@ -1095,7 +1151,7 @@ torch==2.7.1+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.1+cu128 +torchaudio==2.8.0+cu128 # via # -r requirements/test.in # encodec @@ -1108,7 +1164,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.22.1+cu128 +torchvision==0.23.0+cu128 # via # -r requirements/test.in # lightly @@ -1139,7 +1195,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.55.2 +transformers==4.56.2 # via # -r requirements/test.in # genai-perf @@ -1149,7 +1205,7 @@ transformers==4.55.2 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.3.1 +triton==3.4.0 # via torch tritonclient==2.51.0 # via @@ -1166,10 +1222,12 @@ types-python-dateutil==2.9.0.20241206 # via arrow typeshed-client==2.8.2 # via jsonargparse -typing-extensions==4.12.2 +typing-extensions==4.15.0 # via + # aiosignal # albumentations # alembic + # chz # fastapi # graphene # huggingface-hub @@ -1193,7 +1251,7 @@ typing-extensions==4.12.2 # typer # typeshed-client # typing-inspection -typing-inspection==0.4.1 +typing-inspection==0.4.2 # via pydantic tzdata==2024.2 # via pandas @@ -1209,7 +1267,9 @@ urllib3==2.2.3 # responses # tritonclient uvicorn==0.35.0 - # via mlflow-skinny + # via + # gpt-oss + # mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 7ea239b48ea26..4241cbb2b0333 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -14,14 +14,4 @@ nixl==0.3.0 tpu_info==0.4.0 # Install torch_xla ---pre ---extra-index-url https://download.pytorch.org/whl/nightly/cpu ---find-links https://storage.googleapis.com/libtpu-wheels/index.html ---find-links https://storage.googleapis.com/libtpu-releases/index.html ---find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ---find-links 
https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.9.0.dev20250730 -torchvision==0.24.0.dev20250730 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" - +torch_xla[tpu, pallas]==2.8.0 \ No newline at end of file diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 4607c3efdf14c..d14b631aa9364 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -9,11 +9,10 @@ setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts -numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding ---extra-index-url=https://download.pytorch.org/whl/xpu +numba == 0.61.2 # Required for N-gram speculative decoding torch==2.8.0+xpu torchaudio torchvision -pytorch-triton-xpu ---extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu +--extra-index-url=https://download.pytorch.org/whl/xpu + +intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl diff --git a/setup.py b/setup.py index ca6e0a8592cc2..60dde120d5004 100644 --- a/setup.py +++ b/setup.py @@ -34,34 +34,36 @@ logger = logging.getLogger(__name__) # cannot import envs directly because it depends on vllm, # which is not installed yet -envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) +envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py")) VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu": - logger.warning( - "VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS") + logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS") VLLM_TARGET_DEVICE = "cpu" -elif not (sys.platform.startswith("linux") - or sys.platform.startswith("darwin")): +elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")): logger.warning( "vLLM only supports Linux platform (including WSL) and MacOS." 
"Building on %s, " - "so vLLM may not be able to run correctly", sys.platform) + "so vLLM may not be able to run correctly", + sys.platform, + ) VLLM_TARGET_DEVICE = "empty" -elif (sys.platform.startswith("linux") and torch.version.cuda is None - and os.getenv("VLLM_TARGET_DEVICE") is None - and torch.version.hip is None): +elif ( + sys.platform.startswith("linux") + and torch.version.cuda is None + and os.getenv("VLLM_TARGET_DEVICE") is None + and torch.version.hip is None +): # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu VLLM_TARGET_DEVICE = "cpu" -MAIN_CUDA_VERSION = "12.8" - def is_sccache_available() -> bool: - return which("sccache") is not None and \ - not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))) + return which("sccache") is not None and not bool( + int(os.getenv("VLLM_DISABLE_SCCACHE", "0")) + ) def is_ccache_available() -> bool: @@ -85,8 +87,7 @@ def is_url_available(url: str) -> bool: class CMakeExtension(Extension): - - def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: + def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) @@ -123,8 +124,8 @@ class cmake_build_ext(build_ext): if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) logger.info( - "Using NVCC_THREADS=%d as the number of nvcc threads.", - nvcc_threads) + "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads + ) else: nvcc_threads = 1 num_jobs = max(1, num_jobs // nvcc_threads) @@ -148,36 +149,36 @@ class cmake_build_ext(build_ext): cfg = envs.CMAKE_BUILD_TYPE or default_cfg cmake_args = [ - '-DCMAKE_BUILD_TYPE={}'.format(cfg), - '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + "-DCMAKE_BUILD_TYPE={}".format(cfg), + "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE), ] verbose = envs.VERBOSE if verbose: - cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] + cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"] if is_sccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache', + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", ] elif is_ccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache', + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", ] # Pass the python executable to cmake so it can find an exact # match. - cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] + cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)] # Pass the python path to cmake so it can reuse the build dependencies # on subsequent calls to python. - cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] + cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))] # Override the base directory for FetchContent downloads to $ROOT/.deps # This allows sharing dependencies between profiles, @@ -185,7 +186,7 @@ class cmake_build_ext(build_ext): # To override this, set the FETCHCONTENT_BASE_DIR environment variable. 
fc_base_dir = os.path.join(ROOT_DIR, ".deps") fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir) - cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)] + cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)] # # Setup parallelism and build tool @@ -193,30 +194,36 @@ class cmake_build_ext(build_ext): num_jobs, nvcc_threads = self.compute_num_jobs() if nvcc_threads: - cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)] + cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)] if is_ninja_available(): - build_tool = ['-G', 'Ninja'] + build_tool = ["-G", "Ninja"] cmake_args += [ - '-DCMAKE_JOB_POOL_COMPILE:STRING=compile', - '-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs), + "-DCMAKE_JOB_POOL_COMPILE:STRING=compile", + "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs), ] else: # Default build tool to whatever cmake picks. build_tool = [] # Make sure we use the nvcc from CUDA_HOME if _is_cuda(): - cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc'] + cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"] + + other_cmake_args = os.environ.get("CMAKE_ARGS") + if other_cmake_args: + cmake_args += other_cmake_args.split() + subprocess.check_call( - ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], - cwd=self.build_temp) + ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args], + cwd=self.build_temp, + ) def build_extensions(self) -> None: # Ensure that CMake is present and working try: - subprocess.check_output(['cmake', '--version']) + subprocess.check_output(["cmake", "--version"]) except OSError as e: - raise RuntimeError('Cannot find CMake executable') from e + raise RuntimeError("Cannot find CMake executable") from e # Create build directory if it does not exist. if not os.path.exists(self.build_temp): @@ -255,13 +262,18 @@ class cmake_build_ext(build_ext): # CMake appends the extension prefix to the install path, # and outdir already contains that prefix, so we need to remove it. 
prefix = outdir - for _ in range(ext.name.count('.')): + for _ in range(ext.name.count(".")): prefix = prefix.parent # prefix here should actually be the same for all components install_args = [ - "cmake", "--install", ".", "--prefix", prefix, "--component", - target_name(ext.name) + "cmake", + "--install", + ".", + "--prefix", + prefix, + "--component", + target_name(ext.name), ] subprocess.check_call(install_args, cwd=self.build_temp) @@ -272,12 +284,15 @@ class cmake_build_ext(build_ext): # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current # directory so that they can be included in the editable build import glob - files = glob.glob(os.path.join(self.build_lib, "vllm", - "vllm_flash_attn", "**", "*.py"), - recursive=True) + + files = glob.glob( + os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"), + recursive=True, + ) for file in files: - dst_file = os.path.join("vllm/vllm_flash_attn", - file.split("vllm/vllm_flash_attn/")[-1]) + dst_file = os.path.join( + "vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1] + ) print(f"Copying {file} to {dst_file}") os.makedirs(os.path.dirname(dst_file), exist_ok=True) self.copy_file(file, dst_file) @@ -287,8 +302,7 @@ class precompiled_build_ext(build_ext): """Disables extension building when using precompiled binaries.""" def run(self) -> None: - assert _is_cuda( - ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" def build_extensions(self) -> None: print("Skipping build_ext: using precompiled extensions.") @@ -309,9 +323,9 @@ class precompiled_wheel_utils: wheel_filename = wheel_url_or_path.split("/")[-1] temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") wheel_path = os.path.join(temp_dir, wheel_filename) - print(f"Downloading wheel from {wheel_url_or_path} " - f"to {wheel_path}") + print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}") from urllib.request import urlretrieve + urlretrieve(wheel_url_or_path, filename=wheel_path) else: wheel_path = wheel_url_or_path @@ -324,31 +338,37 @@ class precompiled_wheel_utils: "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", + "vllm/_flashmla_extension_C.abi3.so", + "vllm/_sparse_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/cumem_allocator.abi3.so", ] compiled_regex = re.compile( - r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" + ) file_members = list( - filter(lambda x: x.filename in files_to_copy, - wheel.filelist)) + filter(lambda x: x.filename in files_to_copy, wheel.filelist) + ) file_members += list( - filter(lambda x: compiled_regex.match(x.filename), - wheel.filelist)) + filter(lambda x: compiled_regex.match(x.filename), wheel.filelist) + ) for file in file_members: print(f"[extract] {file.filename}") target_path = os.path.join(".", file.filename) os.makedirs(os.path.dirname(target_path), exist_ok=True) - with wheel.open(file.filename) as src, open( - target_path, "wb") as dst: + with ( + wheel.open(file.filename) as src, + open(target_path, "wb") as dst, + ): shutil.copyfileobj(src, dst) pkg = os.path.dirname(file.filename).replace("/", ".") package_data_patch.setdefault(pkg, []).append( - os.path.basename(file.filename)) + os.path.basename(file.filename) + ) return package_data_patch finally: @@ -364,10 +384,13 @@ class precompiled_wheel_utils: try: # Get the latest commit hash of the upstream main 
branch. - resp_json = subprocess.check_output([ - "curl", "-s", - "https://api.github.com/repos/vllm-project/vllm/commits/main" - ]).decode("utf-8") + resp_json = subprocess.check_output( + [ + "curl", + "-s", + "https://api.github.com/repos/vllm-project/vllm/commits/main", + ] + ).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] # In Docker build context, .git may be immutable or missing. @@ -377,25 +400,32 @@ class precompiled_wheel_utils: # Check if the upstream_main_commit exists in the local repo try: subprocess.check_output( - ["git", "cat-file", "-e", f"{upstream_main_commit}"]) + ["git", "cat-file", "-e", f"{upstream_main_commit}"] + ) except subprocess.CalledProcessError: # If not present, fetch it from the remote repository. # Note that this does not update any local branches, # but ensures that this commit ref and its history are # available in our local repo. - subprocess.check_call([ - "git", "fetch", "https://github.com/vllm-project/vllm", - "main" - ]) + subprocess.check_call( + ["git", "fetch", "https://github.com/vllm-project/vllm", "main"] + ) # Then get the commit hash of the current branch that is the same as # the upstream main commit. - current_branch = subprocess.check_output( - ["git", "branch", "--show-current"]).decode("utf-8").strip() + current_branch = ( + subprocess.check_output(["git", "branch", "--show-current"]) + .decode("utf-8") + .strip() + ) - base_commit = subprocess.check_output([ - "git", "merge-base", f"{upstream_main_commit}", current_branch - ]).decode("utf-8").strip() + base_commit = ( + subprocess.check_output( + ["git", "merge-base", f"{upstream_main_commit}", current_branch] + ) + .decode("utf-8") + .strip() + ) return base_commit except ValueError as err: raise ValueError(err) from None @@ -403,7 +433,9 @@ class precompiled_wheel_utils: logger.warning( "Failed to get the base commit in the main branch. " "Using the nightly wheel. 
The libraries in this " - "wheel may not be compatible with your dev branch: %s", err) + "wheel may not be compatible with your dev branch: %s", + err, + ) return "nightly" @@ -413,17 +445,13 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda - and not (_is_neuron() or _is_tpu())) + return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu() def _is_hip() -> bool: - return (VLLM_TARGET_DEVICE == "cuda" - or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None - - -def _is_neuron() -> bool: - return VLLM_TARGET_DEVICE == "neuron" + return ( + VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm" + ) and torch.version.hip is not None def _is_tpu() -> bool: @@ -462,41 +490,27 @@ def get_rocm_version(): minor = ctypes.c_uint32() patch = ctypes.c_uint32() - if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), - ctypes.byref(patch)) == 0): + if ( + get_rocm_core_version( + ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch) + ) + == 0 + ): return f"{major.value}.{minor.value}.{patch.value}" return None except Exception: return None -def get_neuronxcc_version(): - import sysconfig - site_dir = sysconfig.get_paths()["purelib"] - version_file = os.path.join(site_dir, "neuronxcc", "version", - "__init__.py") - - # Check if the command was executed successfully - with open(version_file) as fp: - content = fp.read() - - # Extract the version using a regular expression - match = re.search(r"__version__ = '(\S+)'", content) - if match: - # Return the version string - return match.group(1) - else: - raise RuntimeError("Could not find Neuron version in the output") - - def get_nvcc_cuda_version() -> Version: """Get the CUDA version from nvcc. Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ assert CUDA_HOME is not None, "CUDA_HOME is not set" - nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], - universal_newlines=True) + nvcc_output = subprocess.check_output( + [CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True + ) output = nvcc_output.split() release_idx = output.index("release") + 1 nvcc_cuda_version = parse(output[release_idx].split(",")[0]) @@ -508,14 +522,20 @@ def get_gaudi_sw_version(): Returns the driver version. 
""" # Enable console printing for `hl-smi` check - output = subprocess.run("hl-smi", - shell=True, - text=True, - capture_output=True, - env={"ENABLE_CONSOLE": "true"}) + output = subprocess.run( + "hl-smi", + shell=True, + text=True, + capture_output=True, + env={"ENABLE_CONSOLE": "true"}, + ) if output.returncode == 0 and output.stdout: - return output.stdout.split("\n")[2].replace( - " ", "").split(":")[1][:-1].split("-")[0] + return ( + output.stdout.split("\n")[2] + .replace(" ", "") + .split(":")[1][:-1] + .split("-")[0] + ) return "0.0.0" # when hl-smi is not available @@ -531,7 +551,7 @@ def get_vllm_version() -> str: version += f"{sep}precompiled" else: cuda_version = str(get_nvcc_cuda_version()) - if cuda_version != MAIN_CUDA_VERSION: + if cuda_version != envs.VLLM_MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] # skip this for source tarball, required for pypi if "sdist" not in sys.argv: @@ -539,14 +559,8 @@ def get_vllm_version() -> str: elif _is_hip(): # Get the Rocm Version rocm_version = get_rocm_version() or torch.version.hip - if rocm_version and rocm_version != MAIN_CUDA_VERSION: + if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION: version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" - elif _is_neuron(): - # Get the Neuron version - neuron_version = str(get_neuronxcc_version()) - if neuron_version != MAIN_CUDA_VERSION: - neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"{sep}neuron{neuron_version_str}" elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): @@ -571,8 +585,11 @@ def get_requirements() -> list[str]: for line in requirements: if line.startswith("-r "): resolved_requirements += _read_requirements(line.split()[1]) - elif not line.startswith("--") and not line.startswith( - "#") and line.strip() != "": + elif ( + not line.startswith("--") + and not line.startswith("#") + and line.strip() != "" + ): resolved_requirements.append(line) return resolved_requirements @@ -583,7 +600,7 @@ def get_requirements() -> list[str]: cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if ("vllm-flash-attn" in req and cuda_major != "12"): + if "vllm-flash-attn" in req and cuda_major != "12": # vllm-flash-attn is built only for CUDA 12.x. # Skip for other versions. 
continue @@ -591,8 +608,6 @@ def get_requirements() -> list[str]: requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("rocm.txt") - elif _is_neuron(): - requirements = _read_requirements("neuron.txt") elif _is_tpu(): requirements = _read_requirements("tpu.txt") elif _is_cpu(): @@ -600,8 +615,7 @@ def get_requirements() -> list[str]: elif _is_xpu(): requirements = _read_requirements("xpu.txt") else: - raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") + raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.") return requirements @@ -617,12 +631,13 @@ if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): # FA3 requires CUDA 12.3 or later - ext_modules.append( - CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) # Optional since this doesn't get built (produce an .so file) when # not targeting a hopper system + ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append( - CMakeExtension(name="vllm._flashmla_C", optional=True)) + CMakeExtension(name="vllm._flashmla_extension_C", optional=True) + ) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): @@ -644,6 +659,7 @@ if envs.VLLM_USE_PRECOMPILED: wheel_url = wheel_location else: import platform + arch = platform.machine() if arch == "x86_64": wheel_tag = "manylinux1_x86_64" @@ -653,8 +669,11 @@ if envs.VLLM_USE_PRECOMPILED: raise ValueError(f"Unsupported architecture: {arch}") base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" - nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + nightly_wheel_url = ( + f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + ) from urllib.request import urlopen + try: with urlopen(wheel_url) as resp: if resp.status != 200: @@ -663,8 +682,7 @@ if envs.VLLM_USE_PRECOMPILED: print(f"[warn] Falling back to nightly wheel: {e}") wheel_url = nightly_wheel_url - patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( - wheel_url) + patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url) for pkg, files in patch.items(): package_data.setdefault(pkg, []).extend(files) @@ -675,8 +693,9 @@ if not ext_modules: cmdclass = {} else: cmdclass = { - "build_ext": - precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext + "build_ext": precompiled_build_ext + if envs.VLLM_USE_PRECOMPILED + else cmake_build_ext } setup( @@ -688,13 +707,14 @@ setup( "bench": ["pandas", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": - ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"], - "audio": ["librosa", "soundfile", - "mistral_common[audio]"], # Required for audio processing + "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"], + "audio": [ + "librosa", + "soundfile", + "mistral_common[audio]", + ], # Required for audio processing "video": [], # Kept for backwards compatibility - # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.12"], + "flashinfer": [], # Kept for backwards compatibility # Optional deps for AMD FP4 quantization 
support "petit-kernel": ["petit-kernel"], }, diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py deleted file mode 100644 index ec6b20f5e04b9..0000000000000 --- a/tests/async_engine/api_server_async_engine.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""vllm.entrypoints.api_server with some extra logging for testing.""" -from collections.abc import Iterable -from typing import Any - -import uvicorn -from fastapi.responses import JSONResponse, Response - -import vllm.entrypoints.api_server -import vllm.envs as envs -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.utils import FlexibleArgumentParser - -app = vllm.entrypoints.api_server.app - - -class AsyncLLMEngineWithStats(AsyncLLMEngine): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._num_aborts = 0 - - async def _engine_abort(self, request_ids: Iterable[str]): - ids = list(request_ids) - self._num_aborts += len(ids) - await super()._engine_abort(ids) - - def testing_stats(self) -> dict[str, Any]: - return {"num_aborted_requests": self._num_aborts} - - -@app.get("/stats") -def stats() -> Response: - """Get the statistics of the engine.""" - return JSONResponse(engine.testing_stats()) - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) - parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngineWithStats.from_engine_args(engine_args) - vllm.entrypoints.api_server.engine = engine - uvicorn.run(app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE) diff --git a/tests/async_engine/conftest.py b/tests/async_engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/async_engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py deleted file mode 100644 index 76c94bdf80ca8..0000000000000 --- a/tests/async_engine/test_api_server.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import subprocess -import sys -import time -from multiprocessing import Pool -from pathlib import Path - -import pytest -import requests - - -def _query_server(prompt: str, max_tokens: int = 5) -> dict: - response = requests.post("http://localhost:8000/generate", - json={ - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": 0, - "ignore_eos": True - }) - response.raise_for_status() - return response.json() - - -def _query_server_long(prompt: str) -> dict: - return _query_server(prompt, max_tokens=500) - - -@pytest.fixture -def api_server(distributed_executor_backend: str): - script_path = Path(__file__).parent.joinpath( - "api_server_async_engine.py").absolute() - commands = [ - sys.executable, - "-u", - str(script_path), - "--model", - "facebook/opt-125m", - "--host", - "127.0.0.1", - "--distributed-executor-backend", - distributed_executor_backend, - ] - - # API Server Test Requires V0. - my_env = os.environ.copy() - my_env["VLLM_USE_V1"] = "0" - uvicorn_process = subprocess.Popen(commands, env=my_env) - yield - uvicorn_process.terminate() - - -@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"]) -def test_api_server(api_server, distributed_executor_backend: str): - """ - Run the API server and test it. - - We run both the server and requests in separate processes. - - We test that the server can handle incoming requests, including - multiple requests at the same time, and that it can handle requests - being cancelled without crashing. 
- """ - with Pool(32) as pool: - # Wait until the server is ready - prompts = ["warm up"] * 1 - result = None - while not result: - try: - for r in pool.map(_query_server, prompts): - result = r - break - except requests.exceptions.ConnectionError: - time.sleep(1) - - # Actual tests start here - # Try with 1 prompt - for result in pool.map(_query_server, prompts): - assert result - - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] - assert num_aborted_requests == 0 - - # Try with 100 prompts - prompts = ["test prompt"] * 100 - for result in pool.map(_query_server, prompts): - assert result - - with Pool(32) as pool: - # Cancel requests - prompts = ["canceled requests"] * 100 - pool.map_async(_query_server_long, prompts) - time.sleep(0.01) - pool.terminate() - pool.join() - - # check cancellation stats - # give it some times to update the stats - time.sleep(1) - - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] - assert num_aborted_requests > 0 - - # check that server still runs after cancellations - with Pool(32) as pool: - # Try with 100 prompts - prompts = ["test prompt after canceled"] * 100 - for result in pool.map(_query_server, prompts): - assert result diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py deleted file mode 100644 index 1851eeeda7905..0000000000000 --- a/tests/async_engine/test_request_tracker.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.engine.async_llm_engine import RequestTracker -from vllm.outputs import RequestOutput - - -@pytest.mark.asyncio -async def test_request_tracker(): - tracker = RequestTracker() - stream_1 = tracker.add_request("1") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert len(new) == 1 - assert new[0]["request_id"] == "1" - assert not aborted - assert not stream_1.finished - - stream_2 = tracker.add_request("2") - stream_3 = tracker.add_request("3") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert len(new) == 2 - assert new[0]["request_id"] == "2" - assert new[1]["request_id"] == "3" - assert not aborted - assert not stream_2.finished - assert not stream_3.finished - - # request_ids must be unique - with pytest.raises(KeyError): - tracker.add_request("1") - assert not tracker.new_requests_event.is_set() - - tracker.abort_request("1") - new, aborted = tracker.get_new_and_aborted_requests() - assert len(aborted) == 1 - assert "1" in aborted - assert not new - assert stream_1.finished - - stream_4 = tracker.add_request("4") - tracker.abort_request("4") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - # aborted new requests will cancel each other out - - # there's no need for them to propagate into the - # engine - assert not aborted - assert not new - assert stream_4.finished - - stream_5 = tracker.add_request("5") - assert tracker.new_requests_event.is_set() - tracker.process_request_output( - RequestOutput("2", "output", [], [], [], finished=True)) - await 
tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert not aborted - assert len(new) == 1 - assert new[0]["request_id"] == "5" - assert stream_2.finished - assert not stream_5.finished diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index a3b09cc817917..9b9d8cfea7fad 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,7 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ + import os import weakref from unittest.mock import Mock @@ -11,8 +12,8 @@ from unittest.mock import Mock import pytest import torch -from vllm import LLM, envs -from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 +from vllm import LLM +from vllm.v1.engine.llm_engine import LLMEngine from ..conftest import HfRunner, VllmRunner from ..models.utils import check_outputs_equal @@ -26,14 +27,6 @@ MODELS = [ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" llm = LLM("distilbert/distilgpt2") @@ -45,16 +38,21 @@ def test_vllm_gc_ed(): def _fix_prompt_embed_outputs( - vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, - example_prompts: list[str]) -> list[tuple[list[int], str]]: + vllm_outputs: list[tuple[list[int], str]], + hf_model: HfRunner, + example_prompts: list[str], +) -> list[tuple[list[int], str]]: fixed_vllm_outputs = [] for vllm_output, hf_input, prompt in zip( - vllm_outputs, hf_model.get_inputs(example_prompts), - example_prompts): + vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts + ): hf_input_ids = hf_input["input_ids"].tolist()[0] fixed_vllm_outputs.append( - (hf_input_ids + vllm_output[0][len(hf_input_ids):], - prompt + vllm_output[1])) + ( + hf_input_ids + vllm_output[0][len(hf_input_ids) :], + prompt + vllm_output[1], + ) + ) return fixed_vllm_outputs @@ -62,6 +60,8 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("async_scheduling", [True, False]) +@pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( monkeypatch: pytest.MonkeyPatch, @@ -70,16 +70,12 @@ def test_models( backend: str, max_tokens: int, enforce_eager: bool, + async_scheduling: bool, + model_executor: str, enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if backend == "XFORMERS" and model == "google/gemma-2-2b-it": - pytest.skip( - f"{backend} does not support gemma2 with full context length.") + pytest.skip(f"{backend} does not support gemma2 with full context length.") with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", backend) @@ -87,30 +83,35 @@ def test_models( # 5042 tokens for gemma2 # gemma2 has alternating sliding window size of 4096 # we need a prompt with more than 4096 tokens to test the sliding window - prompt = "The following numbers of the sequence 
" + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) - with VllmRunner(model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7) as vllm_model: + with VllmRunner( + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + async_scheduling=async_scheduling, + distributed_executor_backend=model_executor, + ) as vllm_model: if enable_prompt_embeds: - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, max_tokens) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) + vllm_outputs, hf_model, example_prompts + ) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -122,21 +123,18 @@ def test_models( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, " - "test_suite, extra_env", [ + "model, distributed_executor_backend, attention_backend, test_suite, extra_env", + [ ("distilbert/distilgpt2", "ray", "", "L4", {}), ("distilbert/distilgpt2", "mp", "", "L4", {}), - ("distilbert/distilgpt2", "ray", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), - ("distilbert/distilgpt2", "mp", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), + ("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}), - ]) + ], +) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( monkeypatch: pytest.MonkeyPatch, @@ -150,20 +148,18 @@ def test_models_distributed( extra_env: dict[str, str], enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") with monkeypatch.context() as monkeypatch_context: - if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + if ( + model == "meta-llama/Llama-3.2-1B-Instruct" + and distributed_executor_backend == "ray" + and attention_backend == "" + and test_suite == "L4" + ): # noqa if enable_prompt_embeds: - pytest.skip( - "enable_prompt_embeds does not work with ray compiled dag." 
- ) + pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -185,30 +181,26 @@ def test_models_distributed( # will hurt multiprocessing backend with fork method # (the default method). with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7, + model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, ) as vllm_model: if enable_prompt_embeds: with hf_runner(model, dtype=dtype) as hf_model: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, max_tokens) + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs, hf_model, example_prompts + ) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -219,27 +211,18 @@ def test_models_distributed( def test_failed_model_execution(vllm_runner, monkeypatch) -> None: - - from vllm.envs import VLLM_USE_V1 - - if not VLLM_USE_V1: - pytest.skip("Skipping V0 test, dump input not supported") - # Needed to mock an error in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: - if isinstance(vllm_model.llm.llm_engine, LLMEngineV1): + with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model: + if isinstance(vllm_model.llm.llm_engine, LLMEngine): v1_test_failed_model_execution(vllm_model) def v1_test_failed_model_execution(vllm_model): - engine = vllm_model.llm.llm_engine - mocked_execute_model = Mock( - side_effect=RuntimeError("Mocked Critical Error")) - engine.engine_core.engine_core.model_executor.execute_model =\ - mocked_execute_model + mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error")) + engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model with pytest.raises(RuntimeError) as exc_info: prompts = [ diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 28bfe9e7c8020..3c1e01d072b9e 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -5,5 +5,6 @@ from ..utils import compare_two_settings def test_cpu_offload(): - compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [], - ["--cpu-offload-gb", "1"]) + compare_two_settings( + "meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"] + ) diff --git a/tests/basic_correctness/test_cumem.py 
b/tests/basic_correctness/test_cumem.py index 34f9389c82a9b..f1b0f7b2de891 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -23,13 +23,13 @@ def test_python_error(): tensors = [] with allocator.use_memory_pool(): # allocate 70% of the total memory - x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") tensors.append(x) # release the memory allocator.sleep() # allocate more memory than the total memory - y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") tensors.append(y) with pytest.raises(RuntimeError): # when the allocator is woken up, it should raise an error @@ -41,17 +41,17 @@ def test_python_error(): def test_basic_cumem(): # some tensors from default memory pool shape = (1024, 1024) - x = torch.empty(shape, device='cuda') + x = torch.empty(shape, device="cuda") x.zero_() # some tensors from custom memory pool allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(): # custom memory pool - y = torch.empty(shape, device='cuda') + y = torch.empty(shape, device="cuda") y.zero_() y += 1 - z = torch.empty(shape, device='cuda') + z = torch.empty(shape, device="cuda") z.zero_() z += 2 @@ -74,16 +74,16 @@ def test_basic_cumem(): def test_cumem_with_cudagraph(): allocator = CuMemAllocator.get_instance() with allocator.use_memory_pool(): - weight = torch.eye(1024, device='cuda') + weight = torch.eye(1024, device="cuda") with allocator.use_memory_pool(tag="discard"): - cache = torch.empty(1024, 1024, device='cuda') + cache = torch.empty(1024, 1024, device="cuda") def model(x): out = x @ weight - cache[:out.size(0)].copy_(out) + cache[: out.size(0)].copy_(out) return out + 1 - x = torch.empty(128, 1024, device='cuda') + x = torch.empty(128, 1024, device="cuda") # warmup model(x) @@ -109,7 +109,7 @@ def test_cumem_with_cudagraph(): model_graph.replay() # cache content is as expected - assert torch.allclose(x, cache[:x.size(0)]) + assert torch.allclose(x, cache[: x.size(0)]) # output content is as expected assert torch.allclose(y, x + 1) @@ -117,63 +117,87 @@ def test_cumem_with_cudagraph(): @create_new_process_for_each_test() @pytest.mark.parametrize( - "model, use_v1", + "model", [ # sleep mode with safetensors - ("meta-llama/Llama-3.2-1B", True), + "meta-llama/Llama-3.2-1B", # sleep mode with pytorch checkpoint - ("facebook/opt-125m", False), - ]) -def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - free, total = torch.cuda.mem_get_info() - used_bytes_baseline = total - free # in case other process is running - llm = LLM(model, enable_sleep_mode=True) - prompt = "How are you?" - sampling_params = SamplingParams(temperature=0, max_tokens=10) - output = llm.generate(prompt, sampling_params) + "facebook/opt-125m", + ], +) +def test_end_to_end(model: str): + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM(model, enable_sleep_mode=True) + prompt = "How are you?" + sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) - # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, - # which is difficult to measure in the test. therefore, we only - # test sleep level 1 here. 
- llm.sleep(level=1) + # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage, + # which is difficult to measure in the test. therefore, we only + # test sleep level 1 here. + llm.sleep(level=1) - free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() - used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline - # now the memory usage is mostly cudagraph memory pool, - # and it should be less than the model weights (1B model, 2GiB weights) + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + # now the memory usage is mostly cudagraph memory pool, + # and it should be less than the model weights (1B model, 2GiB weights) - # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size) - # is captured but cannot be releasesd from PyTorch due to a known bug, - # therefore high memory usage after `llm.sleep` is called is expected. - # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode - # in V1. - if use_v1: - assert used_bytes < 7 * GiB_bytes - else: - assert used_bytes < 2 * GiB_bytes + # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size) + # is captured but cannot be releasesd from PyTorch due to a known bug, + # therefore high memory usage after `llm.sleep` is called is expected. + # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode + # in V1. + assert used_bytes < 7 * GiB_bytes - llm.wake_up() - output2 = llm.generate(prompt, sampling_params) - # cmp output - assert output[0].outputs[0].text == output2[0].outputs[0].text + llm.wake_up() + output2 = llm.generate(prompt, sampling_params) + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text - llm.sleep(level=1) - llm.wake_up(tags=["weights"]) + llm.sleep(level=1) + llm.wake_up(tags=["weights"]) - free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() - used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline - # should just reallocate memory for weights (1B model, ~2GiB weights) - if use_v1: - assert used_bytes < 10 * GiB_bytes - else: - assert used_bytes < 6 * GiB_bytes + # should just reallocate memory for weights (1B model, ~2GiB weights) + assert used_bytes < 10 * GiB_bytes - # now allocate kv cache memory - llm.wake_up(tags=["kv_cache"]) - output3 = llm.generate(prompt, sampling_params) + # now allocate kv cache memory + llm.wake_up(tags=["kv_cache"]) + output3 = llm.generate(prompt, sampling_params) - # cmp output - assert output[0].outputs[0].text == output3[0].outputs[0].text + # cmp output + assert output[0].outputs[0].text == output3[0].outputs[0].text + + +@create_new_process_for_each_test() +def test_deep_sleep(): + model = "Qwen/Qwen3-0.6B" + free, total = torch.cuda.mem_get_info() + used_bytes_baseline = total - free # in case other process is running + llm = LLM(model, enable_sleep_mode=True) + prompt = "How are you?" 
+ sampling_params = SamplingParams(temperature=0, max_tokens=10) + output = llm.generate(prompt, sampling_params) + + # Put the engine to deep sleep + llm.sleep(level=2) + + free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline + assert used_bytes < 3 * GiB_bytes + + llm.wake_up(tags=["weights"]) + llm.collective_rpc("reload_weights") + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + assert used_bytes < 4 * GiB_bytes + + # now allocate kv cache and cuda graph memory + llm.wake_up(tags=["kv_cache"]) + output2 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output2[0].outputs[0].text diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py deleted file mode 100644 index db2fa2f6bef6f..0000000000000 --- a/tests/basic_correctness/test_preemption.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. - -Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 -pytest tests/basic_correctness/test_preemption.py`. -""" -import pytest -from prometheus_client import REGISTRY - -import vllm.envs as envs -from vllm import SamplingParams -from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, - ENABLE_ARTIFICIAL_PREEMPT) - -from ..models.utils import check_outputs_equal - -MODELS = [ - "distilbert/distilgpt2", -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - We should enable this for V1, but VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT, - so use VLLM_USE_V1=0 for all tests in the file. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -@pytest.fixture(scope="module", autouse=True) -def check_settings(): - assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1." - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " - "pytest tests/basic_correctness/test_preemption.py`") - - -@pytest.fixture -def distributed_executor_backend() -> str: - # When SPMD worker is used, use distributed_executor_backend="ray" - # to test delta input optimization works with preemption. 
- return "ray" if envs.VLLM_USE_RAY_SPMD_WORKER else "mp" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [96]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) -def test_chunked_prefill_recompute( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - distributed_executor_backend: str, -) -> None: - """Ensure that chunked prefill works with preemption.""" - max_num_seqs = min(chunked_prefill_token_size, 256) - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, - max_num_seqs=max_num_seqs, - distributed_executor_backend=distributed_executor_backend, - disable_log_stats=False, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_preemption( - caplog_vllm, - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - distributed_executor_backend: str, -) -> None: - """By default, recompute preemption is enabled""" - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.llm.llm_engine.scheduler[0].num_cumulative_preemption) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " - "is not enough KV cache space." 
in caplog_vllm.text) - # Ensure the count bucket of request-level histogram metrics matches - # the number of requests as a simple sanity check to ensure metrics are - # generated - preemption_metrics = None - for m in REGISTRY.collect(): - if m.name == "vllm:num_preemptions": - preemption_metrics = m - assert preemption_metrics is not None - total_recorded_preemption = 0 - for sample in preemption_metrics.samples: - total_recorded_preemption += sample.value - assert total_preemption == total_recorded_preemption - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_preemption_infeasible( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - distributed_executor_backend: str, -) -> None: - """Verify infeasible preemption request will be ignored.""" - BLOCK_SIZE = 16 - prefill_blocks = 2 - decode_blocks = max_tokens // BLOCK_SIZE - with vllm_runner( - model, - dtype=dtype, - block_size=BLOCK_SIZE, - # Not enough gpu blocks to complete a single sequence. - # preemption should happen, and the sequence should be - # ignored instead of hanging forever. - num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, - max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True) - req_outputs = vllm_model.llm.generate( - example_prompts, - sampling_params=sampling_params, - ) - - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - - # Verify the request is ignored and not hang. - for req_output in req_outputs: - outputs = req_output.outputs - assert len(outputs) == 1 - assert outputs[0].finish_reason == "length" diff --git a/tests/benchmarks/test_latency_cli.py b/tests/benchmarks/test_latency_cli.py index 2279c846e01cd..54075a3a15e63 100644 --- a/tests/benchmarks/test_latency_cli.py +++ b/tests/benchmarks/test_latency_cli.py @@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @pytest.mark.benchmark def test_bench_latency(): command = [ - "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", - "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "latency", + "--model", + MODEL_NAME, + "--input-len", + "32", + "--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 0000000000000..90527dbeae28c --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import ( + RandomDataset, + RandomMultiModalDataset, + SampleRequest, +) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + output_len: int + + 
+@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params( + num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20 + ) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + """Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples( + dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20, +) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params +) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. + """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples( + dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples( + dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) + assert a == b + + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params +) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples( + dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples( + dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. 
+ + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return ( + req.prompt, + req.prompt_len, + req.expected_output_len, + len(items), + item_prefixes, + ) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). 
+ with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 + for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" diff 
--git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index bfcf274727e27..90d685c966d3e 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @pytest.fixture(scope="module") def server(): - args = [ - "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" - ] + args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -45,3 +43,35 @@ def test_bench_serve(server): print(result.stderr) assert result.returncode == 0, f"Benchmark failed: {result.stderr}" + + +@pytest.mark.benchmark +def test_bench_serve_chat(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--dataset-name", + "random", + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + "--endpoint", + "/v1/chat/completions", + "--backend", + "openai-chat", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py index b61e51db4fbe4..a579b59e8af46 100644 --- a/tests/benchmarks/test_throughput_cli.py +++ b/tests/benchmarks/test_throughput_cli.py @@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @pytest.mark.benchmark def test_bench_throughput(): command = [ - "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", - "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "throughput", + "--model", + MODEL_NAME, + "--input-len", + "32", + "--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/build_cython.py b/tests/build_cython.py deleted file mode 100644 index 444434e8f0a79..0000000000000 --- a/tests/build_cython.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import Cython.Compiler.Options -from Cython.Build import cythonize -from setuptools import setup - -Cython.Compiler.Options.annotate = True - -infiles = [] - -infiles += [ - "vllm/engine/llm_engine.py", - "vllm/transformers_utils/detokenizer.py", - "vllm/engine/output_processor/single_step.py", - "vllm/outputs.py", - "vllm/engine/output_processor/stop_checker.py", -] - -infiles += [ - "vllm/core/scheduler.py", - "vllm/sequence.py", - "vllm/core/block_manager.py", -] - -infiles += [ - "vllm/model_executor/layers/sampler.py", - "vllm/sampling_params.py", - "vllm/utils/__init__.py", -] - -setup(ext_modules=cythonize(infiles, - annotate=False, - force=True, - compiler_directives={ - 'language_level': "3", - 'infer_types': True - })) - -# example usage: python3 build_cython.py build_ext --inplace diff --git a/tests/ci_envs.py b/tests/ci_envs.py new file mode 100644 index 0000000000000..d16ecce1ef8dd --- /dev/null +++ b/tests/ci_envs.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +These envs only work for a small part of the tests, fix what you need! 
+""" + +import os +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + VLLM_CI_NO_SKIP: bool = False + VLLM_CI_DTYPE: Optional[str] = None + VLLM_CI_HEAD_DTYPE: Optional[str] = None + VLLM_CI_HF_DTYPE: Optional[str] = None + +environment_variables: dict[str, Callable[[], Any]] = { + # A model family has many models with the same architecture. + # By default, a model family tests only one model. + # Through this flag, all models can be tested. + "VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))), + # Allow changing the dtype used by vllm in tests + "VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None), + # Allow changing the head dtype used by vllm in tests + "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None), + # Allow changing the head dtype used by transformers in tests + "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None), +} + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def is_set(name: str): + """Check if an environment variable is explicitly set.""" + if name in environment_variables: + return name in os.environ + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ace4d25534cdd..36bc832a1329e 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import weakref from collections.abc import Sequence from copy import deepcopy from typing import Callable, Union @@ -10,7 +11,25 @@ from torch._ops import OpOverload from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass -from vllm.config import get_current_vllm_config +from vllm.compilation.pass_manager import with_pattern_match_debug +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig, get_current_vllm_config + + +class LazyInitPass(InductorPass): + """ + If there's a pass that we want to initialize lazily in a test, + we can wrap it in LazyInitPass, which will initialize the pass when invoked + and then immediately invoke it. + """ + + def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig): + self.pass_cls = pass_cls + self.vllm_config = weakref.proxy(vllm_config) # avoid cycle + + def __call__(self, graph: fx.Graph) -> None: + self.pass_ = self.pass_cls(self.vllm_config) + self.pass_(graph) class TestBackend: @@ -25,25 +44,29 @@ class TestBackend: Inductor config is default-initialized from VllmConfig.CompilationConfig. 
""" - def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], - None]]): + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) compile_config = get_current_vllm_config().compilation_config self.inductor_config = compile_config.inductor_compile_config - self.inductor_config['force_disable_caches'] = True - self.inductor_config['post_grad_custom_post_pass'] = self.post_pass + self.inductor_config["force_disable_caches"] = True + self.inductor_config["post_grad_custom_post_pass"] = self.post_pass def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, - example_inputs, - config_patches=self.inductor_config) + return compile_fx(graph, example_inputs, config_patches=self.inductor_config) + + @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + + VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: pass_(graph) + VllmInductorPass.dump_prefix += 1 + + VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) # assign by reference, will reflect the final state of the graph @@ -56,12 +79,15 @@ class TestBackend: assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" assert num_pre > num_post, f"All nodes remain for op {op.name()}" if fully_replaced: - assert num_post == 0, \ - f"Unexpected op {op.name()} in post-pass graph" + assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" def check_after_ops(self, ops: Sequence[OpOverload]): for op in ops: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass))) assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" - assert num_post > 0, f"Op {op.name()} not found in post-pass graph" \ No newline at end of file + assert num_post > 0, f"Op {op.name()} not found in post-pass graph" + + def op_count(self, op: OpOverload, before=False) -> int: + graph = self.graph_pre_pass if before else self.graph_post_pass + return len(list(find_op_nodes(op, graph))) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 97140a9db7af6..84194f3ed01e8 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -3,12 +3,11 @@ import contextlib import os import weakref -from dataclasses import dataclass -from typing import Optional import pytest from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform @@ -33,80 +32,14 @@ def temporary_environ(env_vars): os.environ[k] = v -@dataclass -class BackendConfig: - name: str - env_vars: dict - comp_config: dict - specific_gpu_arch: Optional[tuple] = None - - -# Define all backend configurations of full cudagraph to be tested -backend_configs = { - # FA3 on Hopper - "FA3": - BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, - comp_config={ - "cudagraph_mode": "FULL", - }, - specific_gpu_arch=(9, 0)), - # FlashMLA on Hopper - "FlashMLA": - BackendConfig(name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }, - specific_gpu_arch=(9, 0)), - # Cutlass MLA on 
Blackwell - "CutlassMLA": - BackendConfig( - name="CutlassMLA", - env_vars={ - "VLLM_USE_V1": "1", - "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", - "FORCE_NUM_KV_SPLITS": - "1", # TODO: remove this when hang issue is fixed - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512], - }, - specific_gpu_arch=(10, 0)), - # FA2 - "FA2": - BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, - comp_config={ - "cudagraph_mode": "FULL", - }), - # Triton Attention - "TritonAttn": - BackendConfig(name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, - comp_config={ - "cudagraph_mode": "FULL", - }), - # FlashInfer - "FlashInfer": - BackendConfig(name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), -} - test_params_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA -MLA_backends = ["FlashMLA", "CutlassMLA"] +MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: test_params_full_cudagraph.append( - pytest.param( - ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) + pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])) + ) # Qwen/Qwen2-1.5B-Instruct with other backends other_backend_configs = [ @@ -114,7 +47,8 @@ other_backend_configs = [ ] for backend_config in other_backend_configs: test_params_full_cudagraph.append( - pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))) + pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)) + ) @pytest.fixture(scope="class") @@ -122,15 +56,16 @@ def llm_pair(request): model, backend_config = request.param # Dynamically skip test if GPU capability is not met - if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ - != current_platform.get_device_capability(): + if ( + backend_config.specific_gpu_arch + and backend_config.specific_gpu_arch != current_platform.get_device_capability() + ): if backend_config.specific_gpu_arch == (9, 0): pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") elif backend_config.specific_gpu_arch == (10, 0): pytest.skip("Only Blackwell GPUs support Cutlass MLA") env_vars = { - "VLLM_USE_V1": "1", # Force native sampler to avoid potential nondeterminism in FlashInfer # when per-request generators are not used in V1. "VLLM_USE_FLASHINFER_SAMPLER": "0", @@ -143,8 +78,7 @@ def llm_pair(request): trust_remote_code=True, max_model_len=1024, max_num_seqs=128, - compilation_config=\ - CompilationConfig(**backend_config.comp_config), + compilation_config=CompilationConfig(**backend_config.comp_config), generation_config="vllm", seed=42, ) @@ -180,20 +114,22 @@ class TestFullCUDAGraph: meaning there would be multiple LLM instances hogging memory simultaneously. """ - @pytest.mark.parametrize(("batch_size", "max_tokens"), [ - (1, 10), - (7, 10), - (16, 10), - (25, 10), - (32, 10), - (45, 10), - (64, 10), - (123, 10), - (8, 5), - (8, 30), - ]) - def test_full_cudagraph(self, batch_size, max_tokens, - llm_pair: tuple[LLM, LLM]): + @pytest.mark.parametrize( + ("batch_size", "max_tokens"), + [ + (1, 10), + (7, 10), + (16, 10), + (25, 10), + (32, 10), + (45, 10), + (64, 10), + (123, 10), + (8, 5), + (8, 30), + ], + ) + def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]): """ Test various batch sizes and max_tokens to ensure that the full cudagraph compilation works for padded cases too. 
@@ -204,26 +140,33 @@ class TestFullCUDAGraph: prompts = ["the quick brown fox"] * batch_size # Use purely greedy decoding to avoid top-p truncation sensitivity # that can amplify tiny numeric differences across runtimes. - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - top_p=1.0) + sampling_params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, top_p=1.0 + ) piecewise_responses = piecewise_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params) # Check that all responses are the same - for piecewise_res, full_res in zip(piecewise_responses, - full_responses): - assert piecewise_res.outputs[0].text.lower() == \ - full_res.outputs[0].text.lower() + for piecewise_res, full_res in zip(piecewise_responses, full_responses): + assert ( + piecewise_res.outputs[0].text.lower() + == full_res.outputs[0].text.lower() + ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" - # Flex_Attention is not supported with full cuda graph - }), pytest.raises(RuntimeError): - LLM(model="Qwen/Qwen2-1.5B-Instruct", - compilation_config=CompilationConfig(cudagraph_mode="FULL")) + with ( + temporary_environ( + { + "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION", + # Flex_Attention is not supported with full cuda graph + } + ), + pytest.raises(RuntimeError), + ): + LLM( + model="Qwen/Qwen2-1.5B-Instruct", + compilation_config=CompilationConfig(cudagraph_mode="FULL"), + ) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index f5e2d9ddb7528..d88645e3bfd62 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -4,21 +4,24 @@ Test (piecewise) compilation with a simple model where multiple submodules are compiled and graph captured separately. """ + import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter -from vllm.compilation.decorators import (ignore_torch_compile, - support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from .. 
import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 @@ -26,35 +29,9 @@ HIDDEN_SIZE = 1024 RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -62,7 +39,6 @@ class ParentModel(nn.Module): class Attention(nn.Module): - def __init__(self, mlp_size: int, hidden_size: int) -> None: super().__init__() self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) @@ -73,17 +49,21 @@ class Attention(nn.Module): nn.init.xavier_normal_( self.pre_attn.weight.data, generator=torch.Generator().manual_seed(RANDOM_SEED), - gain=0.001) + gain=0.001, + ) nn.init.xavier_normal_( self.post_attn.weight.data, generator=torch.Generator().manual_seed(RANDOM_SEED), - gain=0.001) + gain=0.001, + ) def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: x_f32 = x.float() - return (x_f32 * torch.rsqrt( - torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * - self.rms_norm_weight).to(x.dtype) + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * self.rms_norm_weight + ).to(x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pre_attn(x) @@ -98,14 +78,15 @@ class Attention(nn.Module): @support_torch_compile class CompiledAttention(nn.Module): - - def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.attn = Attention(mlp_size, hidden_size) @@ -115,26 +96,26 @@ class CompiledAttention(nn.Module): @support_torch_compile class CompiledAttentionTwo(CompiledAttention): - def forward(self, x: torch.Tensor) -> torch.Tensor: return self.attn(x) + x @ignore_torch_compile class SimpleModelWithTwoGraphs(ParentModel): - - def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) # Test will fail without set_model_tag here with error: # "ValueError: too many values to unpack (expected 3)" # This is because CompiledAttention and CompiledAttentionTwo - # have different implmentations but the same torch.compile + # have different implementations but the same torch.compile # cache dir will be used as default prefix is 'model_tag' with set_model_tag("attn_one"): self.attn_one = CompiledAttention( @@ -164,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel): @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, - cudagraph_runtime_mode: CUDAGraphMode): +def run_model( + vllm_config: VllmConfig, + model: nn.Module, + inputs: 
torch.Tensor, + cudagraph_runtime_mode: CUDAGraphMode, +): with set_forward_context({}, vllm_config=vllm_config): # warmup for the model with cudagraph_mode NONE model(inputs) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(inputs[:2]) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(inputs[:1]) # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(inputs[:2]) output = output.cpu() @@ -200,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal(): outputs = [] # piecewise compile - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly::attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) # Pre-allocate memory for CUDAGraph which expects # static tensor addresses inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() with compilation_counter.expect( - num_graphs_seen=2, # two graphs for the model - num_piecewise_graphs_seen=6, - # attn_one, attn_two each has 3 piecewise graphs - # (pre attn, post attn, silly_attention) each - num_piecewise_capturable_graphs_seen=4, - # attn_one, attn_two has pre attn and post attn each, total=4 - num_backend_compilations=4, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=6, + # attn_one, attn_two each has 3 piecewise graphs + # (pre attn, post attn, silly_attention) each + num_piecewise_capturable_graphs_seen=4, + # attn_one, attn_two has pre attn and post attn each, total=4 + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph - vllm_config = VllmConfig(compilation_config=CompilationConfig( - 
level=CompilationLevel.NO_COMPILATION, )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.NO_COMPILATION, + ) + ) cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # piecewise compile without CUDA graph - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=False, - splitting_ops=["silly.attention"], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=False, + splitting_ops=["silly::attention"], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) with compilation_counter.expect( - num_graphs_seen=2, - num_piecewise_graphs_seen=6, - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=0, # no cudagraph captured + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2d1a72d44ec70..bc65e3da0ae74 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,63 +4,36 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. 
""" + import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) -from vllm.envs import VLLM_USE_V1 -from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op - -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, ) +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import is_torch_equal_or_newer + +# This import automatically registers `torch.ops.silly.attention` +from ..silly_attention import get_global_counter, reset_global_counter @support_torch_compile class SillyModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overall effect: - x += 1 - x[0] += 2 + x = 3 * x + 19 global_counter += 2 """ x = x + 1 @@ -77,57 +50,118 @@ class SillyModel(nn.Module): return x -@pytest.mark.parametrize("use_inductor", [True, False]) -def test_simple_piecewise_compile(use_inductor): - assert VLLM_USE_V1 - - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - use_inductor=use_inductor, - splitting_ops=["silly.attention"], - cudagraph_copy_inputs=True, - cudagraph_capture_sizes=[1, 2], - )) +def _run_simple_model( + splitting_ops, + use_inductor_graph_partition, + use_inductor, + expected_num_piecewise_graphs_seen, + expected_num_piecewise_capturable_graphs_seen, + expected_num_backend_compilations, + expected_num_cudagraph_captured, +): + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + use_inductor=use_inductor, + splitting_ops=splitting_ops, + use_inductor_graph_partition=use_inductor_graph_partition, + cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], + ) + ) with set_current_vllm_config(vllm_config): - model = SillyModel(vllm_config=vllm_config, prefix='') + model = SillyModel(vllm_config=vllm_config, prefix="") inputs = torch.randn(100).cuda() - with compilation_counter.expect( + with ( + compilation_counter.expect( num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=5, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - num_backend_compilations=3, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ), set_forward_context(None, - vllm_config=vllm_config): # background context 
+ num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, + ), + set_forward_context(None, vllm_config=vllm_config), + ): # background context # warm up with background context model(inputs) # capturing/replaying should under context of cudagraph dispatching with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=2, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(torch.randn(2).cuda()) with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=1, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=2, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) + + +@pytest.mark.parametrize("use_inductor", [True, False]) +@torch.inference_mode() +def test_simple_piecewise_compile(use_inductor): + _run_simple_model( + splitting_ops=["silly::attention"], + use_inductor_graph_partition=False, + use_inductor=use_inductor, + # 2 * num_layers + 1 + expected_num_piecewise_graphs_seen=5, + # 1 + num_layers + expected_num_piecewise_capturable_graphs_seen=3, + # num_piecewise_capturable_graphs_seen + expected_num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + expected_num_cudagraph_captured=6, + ) + + +@torch.inference_mode() +@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []]) +def test_simple_inductor_graph_partition(splitting_ops, monkeypatch): + if not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + + # disable compile cache so that we run separately for different splitting_ops + # and get the expected number of cudagraphs captured. 
+ monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + _run_simple_model( + # Inductor graph partition automatically resets splitting_ops to an empty list + splitting_ops=splitting_ops, + use_inductor_graph_partition=True, + use_inductor=True, + # Since not splitting at fx graph level + expected_num_piecewise_graphs_seen=1, + # Since not splitting at fx graph level + expected_num_piecewise_capturable_graphs_seen=1, + # Since not splitting at fx graph level + expected_num_backend_compilations=1, + # Inductor graph partition still captures 6 graph, same as fx graph partition + expected_num_cudagraph_captured=6, + ) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bcfd0d834c5db..08f59283a6db5 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,44 +8,27 @@ This is a tractable model, the weights and computation are specially designed if the config `tractable_init` is set to True. Otherwise, the weights are initialized randomly with a fixed seed. """ + from dataclasses import dataclass from typing import Any, Optional import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) -from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, ) +from vllm.forward_context import BatchDescriptor, set_forward_context + +# This import automatically registers `torch.ops.silly.attention` +from .. 
import silly_attention # noqa: F401 @dataclass @@ -66,15 +49,14 @@ class LlamaConfig: factors.append((k, v)) factors.sort() import hashlib - return hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + + return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() def __post_init__(self): assert self.mlp_size >= self.hidden_size class LlamaMLP(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.gate_up_projection = nn.Linear( @@ -89,31 +71,31 @@ class LlamaMLP(nn.Module): ) if config.tractable_init: - nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size]) - nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:]) + nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size]) + nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :]) nn.init.eye_(self.down_projection.weight.data) else: - nn.init.xavier_normal_(self.gate_up_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.down_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.gate_up_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.down_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward(self, x): # for tractable_init and positive input, this is # essentially an elementwise-square x = self.gate_up_projection(x) - x = x[:, :x.size(1) // 2] * torch.nn.functional.relu( - x[:, x.size(1) // 2:]) + x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :]) x = self.down_projection(x) return x class LlamaAttention(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.qkv_projection = nn.Linear( @@ -129,21 +111,25 @@ class LlamaAttention(nn.Module): ) if config.tractable_init: - nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size]) - nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 * - config.hidden_size]) - nn.init.eye_(self.qkv_projection.weight.data[2 * - config.hidden_size:]) + nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size]) + nn.init.eye_( + self.qkv_projection.weight.data[ + config.hidden_size : 2 * config.hidden_size + ] + ) + nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :]) nn.init.eye_(self.output_projection.weight.data) else: - nn.init.xavier_normal_(self.qkv_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.output_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.qkv_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.output_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward( self, @@ -167,7 +153,6 @@ class LlamaAttention(nn.Module): class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.self_attention = LlamaAttention(config) @@ -187,7 +172,7 @@ class LlamaDecoderLayer(nn.Module): - if residual is not None, the outputs are: - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + 
residual = (hidden_states + residual) * 4 + positions * 2 + 3 - hidden_states = (residual + 1) ** 2 - """ # noqa + """ # noqa if residual is None: residual = hidden_states hidden_states = hidden_states + 1 @@ -196,8 +181,9 @@ class LlamaDecoderLayer(nn.Module): residual = hidden_states hidden_states = hidden_states + 1 - hidden_states = self.self_attention(positions=positions, - hidden_states=hidden_states) + hidden_states = self.self_attention( + positions=positions, hidden_states=hidden_states + ) hidden_states = hidden_states + residual residual = hidden_states @@ -209,20 +195,22 @@ class LlamaDecoderLayer(nn.Module): @support_torch_compile class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - config: LlamaConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + vllm_config: VllmConfig, + config: LlamaConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.embedding_tokens = nn.Embedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, ) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_layers)]) + [LlamaDecoderLayer(config) for _ in range(config.num_layers)] + ) # this is the initial value of the hidden states self.embedding_tokens.weight.data.fill_(config.init_value) @@ -239,34 +227,39 @@ class LlamaModel(nn.Module): return hidden_states -def tractable_computation(input_ids: torch.Tensor, - positions: torch.Tensor, - config: LlamaConfig, - init_value: float = 1.0) -> torch.Tensor: - hidden_states = torch.ones(input_ids.size(0), - config.hidden_size, - device=input_ids.device, - dtype=input_ids.dtype) * init_value +def tractable_computation( + input_ids: torch.Tensor, + positions: torch.Tensor, + config: LlamaConfig, + init_value: float = 1.0, +) -> torch.Tensor: + hidden_states = ( + torch.ones( + input_ids.size(0), + config.hidden_size, + device=input_ids.device, + dtype=input_ids.dtype, + ) + * init_value + ) # first layer residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 # following layers for _ in range(config.num_layers - 1): hidden_states = hidden_states + residual residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 return hidden_states @torch.inference_mode -def run_model(llama_config, - use_compile: bool, - use_inductor: bool, - split_attn: bool = False) -> torch.Tensor: - +def run_model( + llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False +) -> torch.Tensor: if use_compile: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, @@ -275,58 +268,70 @@ def run_model(llama_config, cudagraph_capture_sizes=[1, 2], ) if split_attn: - compilation_config.splitting_ops = ["silly.attention"] + compilation_config.splitting_ops = ["silly::attention"] cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: compilation_config = CompilationConfig( - level=CompilationLevel.NO_COMPILATION, ) + level=CompilationLevel.NO_COMPILATION, + ) cudagraph_runtime_mode = CUDAGraphMode.NONE - vllm_config = VllmConfig(compilation_config=compilation_config, - additional_config=llama_config) + vllm_config = VllmConfig( + compilation_config=compilation_config, additional_config=llama_config + ) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda() + model = ( + 
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + ) - with set_forward_context({}, - vllm_config=vllm_config): # background context + with set_forward_context({}, vllm_config=vllm_config): # background context B = 16 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() positions = torch.arange(B).cuda() # warmup for the model with cudagraph_mode NONE model(input_ids, positions) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(input_ids[:2], positions[:2]) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(input_ids[:1], positions[:1]) input_ids[:2].zero_() # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(input_ids[:2], positions[:2]) output = output.cpu() if llama_config.tractable_init: - expected_output = tractable_computation(input_ids[:2], - positions[:2], - llama_config).cpu() + expected_output = tractable_computation( + input_ids[:2], positions[:2], llama_config + ).cpu() assert torch.allclose(output, expected_output) else: @@ -337,27 +342,23 @@ def run_model(llama_config, def test_toy_llama(use_inductor: bool): # compare output with and without piecewise compilation - llama_config = LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=12) + llama_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12 + ) - tractable_config = LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=2, - tractable_init=True) + tractable_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True + ) outputs = [] with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): - outputs.append( - run_model(llama_config, use_inductor=False, use_compile=False)) + outputs.append(run_model(llama_config, use_inductor=False, use_compile=False)) run_model(tractable_config, use_inductor=False, use_compile=False) if use_inductor: @@ -366,41 +367,44 @@ def test_toy_llama(use_inductor: bool): kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=1, - num_piecewise_capturable_graphs_seen=1, - num_backend_compilations=1, # 
num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - **kwargs, + # One graph for the model + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_piecewise_capturable_graphs_seen=1, + # num_piecewise_capturable_graphs_seen + num_backend_compilations=1, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2, + **kwargs, ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True)) + run_model(llama_config, use_inductor=use_inductor, use_compile=True) + ) run_model(tractable_config, use_inductor=use_inductor, use_compile=True) with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=2 * llama_config.num_layers + - 1, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=1 + - llama_config.num_layers, # 1 + num_layers - num_backend_compilations=1 + - llama_config.num_layers, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=2 * - (1 + llama_config.num_layers - ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=1 + + llama_config.num_layers, # 1 + num_layers + num_backend_compilations=1 + + llama_config.num_layers, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2 + * ( + 1 + llama_config.num_layers + ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True)) - run_model(tractable_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True) + run_model( + llama_config, + use_inductor=use_inductor, + use_compile=True, + split_attn=True, + ) + ) + run_model( + tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True + ) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) @@ -411,17 +415,15 @@ def benchmark(): from triton.testing import do_bench # similar to llama 3.1-8B - llama_config = LlamaConfig(hidden_size=4096, - mlp_size=14336, - vocab_size=128 * 1024, - num_layers=32) + llama_config = LlamaConfig( + hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32 + ) # a tiny model to measure the overhead # of piecewise cudagraph - llama_config = LlamaConfig(hidden_size=40, - mlp_size=80, - vocab_size=128, - num_layers=2) + llama_config = LlamaConfig( + hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2 + ) cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] @@ -436,7 +438,7 @@ def benchmark(): compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=cudagraph_sizes, ) else: @@ -447,12 +449,15 @@ def benchmark(): vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda().to(torch.bfloat16) + model = ( + LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + .to(torch.bfloat16) + ) B = 256 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() positions = 
torch.arange(B).cuda().to(torch.bfloat16) graphs = {} @@ -474,21 +479,26 @@ def benchmark(): # and use it later, because it will look up the name `b` in the # enclosing scope, and the value of `b` will always be 256. # it is fine here, because we only use the lambda function once. - runtime = do_bench(lambda: graphs[b][0] # noqa - (input_ids[:b], positions[:b])) # noqa + runtime = do_bench( + lambda: graphs[b][0]( # noqa + input_ids[:b], # noqa + positions[:b], # noqa + ) + ) piecewise_cudagraph_time[b] = runtime else: runtime = do_bench(lambda: graphs[b][0].replay()) # noqa - eager_runtime = do_bench( - lambda: model(input_ids[:b], positions[:b])) # noqa + eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa full_cudagraph_time[b] = runtime eager_time[b] = eager_runtime # print in tabular format print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") for b in cudagraph_sizes: - print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" - f"\t{piecewise_cudagraph_time[b]:.3f}") + print( + f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" + f"\t{piecewise_cudagraph_time[b]:.3f}" + ) if __name__ == "__main__": diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py new file mode 100644 index 0000000000000..c0d3f908149f6 --- /dev/null +++ b/tests/compile/silly_attention.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared PyTorch custom silly attention for compilation tests. +Centralizes custom operation definitions to avoid duplicate registrations. +""" + +import torch +from torch.library import Library + +from vllm.utils import direct_register_custom_op + +# Shared library for all compilation test operations +# Using "silly" namespace to match existing test expectations +# import this file will automatically register +# torch ops for testing (like silly.attention) +silly_lib = Library("silly", "FRAGMENT") + +# Global counter that counts the number of times attention is invoked +_global_counter = 0 + + +def get_global_counter(): + """Get the current global counter value""" + return _global_counter + + +def reset_global_counter(): + """Reset the global counter to 0""" + global _global_counter + _global_counter = 0 + + +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: + """ + Unified attention implementation that depends on + all inputs and affects the output. + Always increments a global counter that tests can use or ignore. 
+ """ + global _global_counter + + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) + + +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, + tags=(torch._C.Tag.cudagraph_unsafe,), +) diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py new file mode 100644 index 0000000000000..08f79d90cd367 --- /dev/null +++ b/tests/compile/test_aot_compile.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile +from contextlib import contextmanager + +import pytest +import torch + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) +from vllm.forward_context import set_forward_context +from vllm.utils import is_torch_equal_or_newer + + +def reference_fn(x: torch.Tensor): + assert x.shape[0] <= 42 + assert x.shape[0] % 2 == 0 + for _ in range(3000): + x = x + x.shape[0] + return x + + +@support_torch_compile +class CompiledMod(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + return reference_fn(x) + + +def make_vllm_config() -> VllmConfig: + return VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + ) + ) + + +@contextmanager +def use_vllm_config(vllm_config: VllmConfig): + with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config): + yield + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + vllm_config = make_vllm_config() + args = (torch.randn(10, 10),) + expected = reference_fn(*args) + with use_vllm_config(vllm_config): + m.setenv("VLLM_USE_AOT_COMPILE", "0") + with ( + pytest.raises(RuntimeError, match="Detected recompile"), + torch.compiler.set_stance("fail_on_recompile"), + ): + CompiledMod(vllm_config=vllm_config)(*args) + + m.setenv("VLLM_USE_AOT_COMPILE", "1") + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + actual = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(actual, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): + with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: + args = (torch.randn(10, 10),) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError): + CompiledMod(vllm_config=vllm_config)(*args) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_save_and_load(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", 
tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + expected = CompiledMod(vllm_config=vllm_config)(*args) + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + ret = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(ret, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_shape_env(monkeypatch: pytest.MonkeyPatch): + """ + Test that the shape environment is correctly serialized and preserved + when loading from cache. + """ + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 9a51e6b3514f4..d396d3940f67f 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -8,18 +8,30 @@ import torch import vllm.envs as envs from vllm.compilation.collective_fusion import AsyncTPPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter, +) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import (compare_two_settings, create_new_process_for_each_test, - multi_gpu_test) +from ..utils import ( + compare_two_settings, + create_new_process_for_each_test, + multi_gpu_test, +) from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -33,21 +45,20 @@ prompts = [ class TestMMRSModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.gate_proj = torch.nn.Parameter(torch.empty( - (self.hidden_size * 2, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False + ) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) def forward(self, hidden_states): """ Forward pass implementing the mm + reduce scatter in the FX graph - + """ # Reshape input 
view = hidden_states.reshape(-1, self.hidden_size) @@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module): class TestAGMMModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.weight = torch.nn.Parameter(torch.empty( - (hidden_size, hidden_size)), - requires_grad=False) + self.weight = torch.nn.Parameter( + torch.empty((hidden_size, hidden_size)), requires_grad=False + ) # Initialize weights torch.nn.init.normal_(self.weight, std=0.02) @@ -96,32 +106,35 @@ class TestAGMMModel(torch.nn.Module): class _BaseScaledMMModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\ - .contiguous().transpose(0, 1) + self.weight = ( + torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) # Initialize scale_b for _scaled_mm. self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32) class TestScaledMMRSModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the scaled_mm + reduce scatter in the FX graph - + """ fp8_input = input.to(FP8_DTYPE) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) - scaled_mm = torch._scaled_mm(fp8_input, - self.weight, - scale_a=scale_a, - scale_b=self.scale_b, - out_dtype=self.dtype) + scaled_mm = torch._scaled_mm( + fp8_input, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype, + ) reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0) return reduce_scatter @@ -129,11 +142,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel): return [torch.ops.vllm.reduce_scatter.default] def ops_in_model_after(self): - return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default] class TestAGScaledMMModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the all gather + scaled_mm in the FX graph @@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel): all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) - scaled_mm = torch._scaled_mm(all_gather, - self.weight, - scale_a=scale_a, - scale_b=self.scale_b, - out_dtype=self.dtype) + scaled_mm = torch._scaled_mm( + all_gather, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype, + ) return scaled_mm def ops_in_model_before(self): @@ -158,20 +172,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel): class TestCutlassScaledMMRSModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the cutlass_scaled_mm + reduce scatter in the FX graph - + """ fp8_input = input.to(FP8_DTYPE) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) - mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]), - dtype=self.dtype, - device=input.device) - torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a, - self.scale_b, None) + mm_out = torch.empty( + (fp8_input.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=input.device, + ) + torch.ops._C.cutlass_scaled_mm( + mm_out, fp8_input, self.weight, scale_a, self.scale_b, None + ) reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0) return reduce_scatter @@ 
-179,14 +195,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel): return [torch.ops.vllm.reduce_scatter.default] def ops_in_model_after(self): - return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default] class TestAGCutlassScaledMMModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ - Forward pass implementing the all gather + cutlass_scaled_mm + Forward pass implementing the all gather + cutlass_scaled_mm in the FX graph """ # Reshape input @@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel): scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) - mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]), - dtype=self.dtype, - device=all_gather.device) - torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight, - scale_a, self.scale_b, None) + mm_out = torch.empty( + (all_gather.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=all_gather.device, + ) + torch.ops._C.cutlass_scaled_mm( + mm_out, all_gather, self.weight, scale_a, self.scale_b, None + ) return mm_out def ops_in_model_before(self): @@ -210,23 +228,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel): @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model", [ - TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel, - TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel -]) +@pytest.mark.parametrize( + "test_model", + [ + TestMMRSModel, + TestAGMMModel, + TestScaledMMRSModel, + TestAGScaledMMModel, + TestCutlassScaledMMRSModel, + TestAGCutlassScaledMMModel, + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): - if test_model in (TestScaledMMRSModel, TestAGScaledMMModel, - TestCutlassScaledMMRSModel, - TestAGCutlassScaledMMModel) and dtype == torch.float16: +@pytest.mark.parametrize("dynamic", [True, False]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_async_tp_pass_replace( + test_model: str, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + dynamic: bool, +): + if ( + test_model + in ( + TestScaledMMRSModel, + TestAGScaledMMModel, + TestCutlassScaledMMRSModel, + TestAGCutlassScaledMMModel, + ) + and dtype == torch.float16 + ): pytest.skip( - "Only bf16 high precision output types are supported for " \ + "Only bf16 high precision output types are supported for " "per-token (row-wise) scaling" ) @@ -235,19 +273,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + dynamic, + ), + nprocs=nprocs, + ) run_torch_spawn(async_tp_pass_on_test_model, num_processes) -def async_tp_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: 
int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def async_tp_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + dynamic: bool, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -255,13 +307,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -269,31 +323,37 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_async_tp=True, ), ) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_async_tp=True, + ), + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) - model = test_model_cls(hidden_size, - dtype) # Pass dtype to model constructor + model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype, - requires_grad=False) + hidden_states = torch.randn( + (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False + ) + + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) + assert async_tp_pass.matched_count == 1 + # In pre-nodes, all gather or reduce scatter should exist, # fused_matmul_reduce_scatter or fused_all_gather_matmul should not backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) @@ -304,10 +364,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, @create_new_process_for_each_test() -@pytest.mark.parametrize("model_id", [ - "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" -]) +@pytest.mark.parametrize( + "model_id", + ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"], +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("distributed_backend", ["mp"]) @@ -340,16 +400,10 @@ def test_async_tp_pass_correctness( common_args.append("--enforce-eager") compilation_config = { - 'level': 3, - 'compile_sizes': [2, 4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_async_tp': async_tp_enabled - }, - } - - async_tp_env = tp_env = { - "VLLM_USE_V1": "1", + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": {"enable_async_tp": 
async_tp_enabled}, } async_tp_args = [ @@ -370,9 +424,4 @@ def test_async_tp_pass_correctness( "mp", ] - compare_two_settings(model_id, - async_tp_args, - tp_args, - async_tp_env, - tp_env, - method="generate") + compare_two_settings(model_id, async_tp_args, tp_args, method="generate") diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 422cb94b036ca..4bcefb30b2e6e 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -20,10 +20,9 @@ class TestSetting: tp_size: int attn_backend: str method: str - fullgraph: bool -# we cannot afford testing the full Catesian product +# we cannot afford testing the full Cartesian product # of all models and all levels @pytest.mark.parametrize( "test_setting", @@ -36,7 +35,6 @@ class TestSetting: tp_size=2, attn_backend="FLASH_ATTN", method="generate", - fullgraph=True, ), # llama model with quantization TestSetting( @@ -46,7 +44,6 @@ class TestSetting: tp_size=1, attn_backend="FLASH_ATTN", method="generate", - fullgraph=True, ), # MoE model TestSetting( @@ -56,32 +53,31 @@ class TestSetting: tp_size=2, attn_backend="FLASH_ATTN", method="generate", - fullgraph=True, ), # embedding model TestSetting( model="BAAI/bge-multilingual-gemma2", model_args=[ - "--runner", "pooling", "--dtype", "bfloat16", - "--max-model-len", "2048" + "--runner", + "pooling", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, attn_backend="FLASH_ATTN", method="encode", - fullgraph=True, ), - # TODO: bert models are not supported in V1 yet - # # encoder-based embedding model (BERT) - # TestSetting( - # model="BAAI/bge-base-en-v1.5", - # model_args=["--runner", "pooling"], - # pp_size=1, - # tp_size=1, - # attn_backend="XFORMERS", - # method="encode", - # fullgraph=True, - # ), + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--runner", "pooling"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="encode", + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -90,9 +86,9 @@ class TestSetting: tp_size=1, attn_backend="FLASH_ATTN", method="generate_with_image", - fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, test_setting: TestSetting, @@ -106,25 +102,29 @@ def test_compile_correctness( tp_size = test_setting.tp_size attn_backend = test_setting.attn_backend method = test_setting.method - fullgraph = test_setting.fullgraph - if cuda_device_count_stateless() != pp_size * tp_size: - pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got " - f"{cuda_device_count_stateless()}") + if cuda_device_count_stateless() < pp_size * tp_size: + pytest.skip( + f"Need at least {pp_size}*{tp_size} CUDA gpus but got " + f"{cuda_device_count_stateless()}" + ) with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) final_args = [ - "--enforce-eager", *model_args, "-pp", - str(pp_size), "-tp", - str(tp_size) + "--enforce-eager", + *model_args, + "-pp", + str(pp_size), + "-tp", + str(tp_size), ] all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.PIECEWISE, + CompilationLevel.NO_COMPILATION, + CompilationLevel.PIECEWISE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) @@ -135,20 +135,17 @@ def test_compile_correctness( model, all_args, all_envs, - method=method if method != "generate" else "generate_close") + method=method if 
method != "generate" else "generate_close", + ) all_envs.clear() all_args.clear() for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, + CompilationLevel.NO_COMPILATION, + CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) - if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: - # "DYNAMO_ONCE" will always use fullgraph - all_envs[-1][ - "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 90e8e0ff95858..ae8b0b226c313 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -2,28 +2,35 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -import vllm from vllm.compilation.counter import compilation_counter -from vllm.config import VllmConfig -from vllm.utils import _is_torch_equal_or_newer +from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.compilation import CompilationLevel +from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): - assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev') - assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev') + # Test the version comparison logic using the private function + assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev") + assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") -def test_use_cudagraphs_dynamic(monkeypatch): - assert vllm.envs.VLLM_USE_V1 +def test_use_cudagraphs_dynamic(): vllm_config = VllmConfig() + # Default V1 configuration now starts without cudagraphs enabled; the + # engine decides when to capture based on runtime settings instead of a + # blanket default. assert vllm_config.compilation_config.use_cudagraph - monkeypatch.setenv('VLLM_USE_V1', '0') - vllm_config = VllmConfig() - assert not vllm_config.compilation_config.use_cudagraph + +def test_custom_op(): + # proper syntax + _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) + + with pytest.raises(ValueError, match="Invalid syntax '"): + _ = CompilationConfig(custom_ops=["quant_fp8"]) # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @@ -33,22 +40,24 @@ def test_use_cudagraphs_dynamic(monkeypatch): # may be influenced by other tests. 
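The tests in this hunk lean on `compilation_counter.expect(...)` as a context manager that asserts how much each counter moved while the `with` body ran, and the counter is process-global, which is why the tests are forked so other tests cannot perturb it. Below is a minimal, self-contained sketch of that snapshot-and-compare pattern; `ToyCompilationCounter` is a hypothetical stand-in and not vLLM's actual `compilation_counter` implementation.

```python
# Hypothetical sketch of a delta-checking counter, for illustration only.
import copy
import dataclasses
from contextlib import contextmanager


@dataclasses.dataclass
class ToyCompilationCounter:
    num_graphs_seen: int = 0
    num_cudagraph_captured: int = 0

    @contextmanager
    def expect(self, **expected_deltas: int):
        # Snapshot the counters on entry, compare the per-field deltas on exit.
        before = copy.deepcopy(self)
        yield
        for field, delta in expected_deltas.items():
            actual = getattr(self, field) - getattr(before, field)
            assert actual == delta, f"{field}: expected +{delta}, got +{actual}"


counter = ToyCompilationCounter()
with counter.expect(num_graphs_seen=1, num_cudagraph_captured=0):
    counter.num_graphs_seen += 1  # stands in for "loading the model compiled one graph"
```

Because such a counter only sees deltas inside the `with` block, running each test in its own forked process keeps the expectations independent of whatever other tests compiled earlier.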
@pytest.mark.parametrize("val", ["1"]) def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): - assert vllm.envs.VLLM_USE_V1 - # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') - monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val) compilation_config = { "use_cudagraph": False, # speed things up a bit } with ( - compilation_counter.expect(num_cache_entries_updated=0, - num_compiled_artifacts_saved=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_cache_entries_updated=0, num_compiled_artifacts_saved=0 + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -56,25 +65,26 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): @pytest.mark.forked @pytest.mark.parametrize("enabled", [True, False]) def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): - assert vllm.envs.VLLM_USE_V1 - # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = { "cudagraph_capture_sizes": [100], "use_cudagraph": enabled, } with ( - compilation_counter.expect( - num_graphs_seen=1, - num_gpu_runner_capture_triggers=1 if enabled else 0, - num_cudagraph_captured=13 if enabled else 0, - ), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_graphs_seen=1, + num_gpu_runner_capture_triggers=1 if enabled else 0, + num_cudagraph_captured=13 if enabled else 0, + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -82,14 +92,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): @pytest.mark.forked def test_dynamo_as_is(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(dynamo_as_is_count=1), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config={"level": 1}, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(dynamo_as_is_count=1), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config={"level": 1}, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -97,15 +110,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch): @pytest.mark.forked def test_no_compilation(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') - + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, - dynamo_as_is_count=0), - # loading the model causes 
compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config={"level": 0}, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config={"level": 0}, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -113,13 +127,92 @@ def test_no_compilation(vllm_runner, monkeypatch): @pytest.mark.forked def test_enforce_eager(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, - dynamo_as_is_count=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - enforce_eager=True, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4 + ) as _, + ): pass + + +def test_splitting_ops_dynamic(): + # Default config + config = VllmConfig() + # Default V1 config leaves cudagraph mode unset; splitting ops are only + # populated when the engine decides to use piecewise compilation. + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert not config.compilation_config.splitting_ops_contain_attention() + + # When use_inductor_graph_partition=True + if is_torch_equal_or_newer("2.9.0.dev"): + config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_inductor_graph_partition=True, + splitting_ops=["vllm::unified_attention"], + ) + ) + # with inductor partition we use splitting_ops directly for + # partition rules + assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] + + # When attn_fusion pass enabled, splitting_ops now default to attention ops. + config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ) + ) + # With the new simplified logic, attention fusion works with splitting_ops + assert config.compilation_config.splitting_ops_contain_attention() + # cudagraph mode remains PIECEWISE + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + + # When both use_inductor_graph_partition and attn_fusion pass enabled. + if is_torch_equal_or_newer("2.9.0.dev"): + config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_inductor_graph_partition=True, + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ) + ) + # With inductor graph partition, attn_fusion and splitting_ops + # work together. Default splitting_ops include attention ops. + assert config.compilation_config.splitting_ops_contain_attention() + # enable_attn_fusion is directly supported under + # use_inductor_graph_partition=True, and cudagraph_mode + # is unchanged. 
+ assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + + +def test_resolve_operator_overload(): + import torch + + from vllm.compilation.partition_rules import resolve_defined_ops + + # Test valid operator names + resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"]) + assert len(resolved) == 2 + assert resolved[0] is torch.ops.aten.mm.default + assert resolved[1] is torch.ops.aten.addmm.default + + # Test that invalid operators are skipped (not raising exceptions) + resolved = resolve_defined_ops( + [ + "aten::mm.default", + "aten::nonexistent_op.default", # This should be skipped + "aten::addmm.default", + ] + ) + assert len(resolved) == 2 # Only 2 valid ops + assert resolved[0] is torch.ops.aten.mm.default + assert resolved[1] is torch.ops.aten.addmm.default diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d56..6b050207ec41b 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,71 +2,63 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter -from vllm.compilation.decorators import (ignore_torch_compile, - support_torch_compile) -from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, - CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from . 
import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, - cudagraph_runtime_mode: CUDAGraphMode): +def run_model( + vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode +): with set_forward_context({}, vllm_config=vllm_config): # warmup for the model with cudagraph_mode NONE model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(torch.randn(2, MLP_SIZE).cuda()) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(torch.randn(1, MLP_SIZE).cuda()) # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(torch.randn(2, MLP_SIZE).cuda()) output = output.cpu() @@ -75,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, def test_ignore_torch_compile_decorator(): # piecewise - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly::attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @support_torch_compile class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs + ) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -102,66 +93,60 @@ def test_ignore_torch_compile_decorator(): return x @ignore_torch_compile - class B(A): - ... + class B(A): ... @support_torch_compile - class C(B): - ... + class C(B): ... 
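The three classes above exercise decorator precedence: `A` opts in with `@support_torch_compile`, `B(A)` opts out with `@ignore_torch_compile`, and `C(B)` opts back in, so the decorator applied closest to each class is the one that takes effect. A minimal sketch of why that works, using hypothetical stand-in decorators that only record a class attribute (vLLM's real decorators additionally hook the module's `forward` into the compilation path, which is what the counter expectations that follow are checking):

```python
# Hypothetical toggle decorators; ordinary attribute shadowing gives
# "closest decorator wins" across the inheritance chain.
def enable_compile(cls):
    cls._compile_enabled = True
    return cls


def disable_compile(cls):
    cls._compile_enabled = False
    return cls


@enable_compile
class Base:
    pass


@disable_compile
class Child(Base):  # overrides Base's setting, like B overriding A above
    pass


@enable_compile
class GrandChild(Child):  # re-enables, like C overriding B
    pass


assert Base._compile_enabled
assert not Child._compile_enabled
assert GrandChild._compile_enabled
```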
with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda() # B's ignore_torch_compile should override A's support_torch_compile with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): run_model(vllm_config, mod_B, cudagraph_runtime_mode) with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda() # C's support_torch_compile should override B's ignore_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_C, cudagraph_runtime_mode) -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=True -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. - kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) class B(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -173,17 +158,13 @@ class B(nn.Module): return x -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=False -@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
- cache_config.kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) @@ -198,54 +179,60 @@ class A(nn.Module): def test_conditional_compile_enable_if(): - vllm_config = VllmConfig(cache_config=CacheConfig( - kv_sharing_fast_prefill=True, ), - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + cache_config=CacheConfig( + kv_sharing_fast_prefill=True, + ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly::attention"], + cudagraph_capture_sizes=[1, 2], + ), + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() # A has support_torch_compile but enable_if fn returns False # enalbe_if will be True for B, so we expect mod1 and mod2 # to be compiled with compilation_counter.expect( - num_graphs_seen=2, - num_piecewise_graphs_seen=6, - # 3 piecewise graphs per instance of B() - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + # 3 piecewise graphs per instance of B() + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) # Set kv_sharing_fast_prefill=False # which will cause A to be compiled and B to not be compiled - vllm_config = VllmConfig(cache_config=CacheConfig( - kv_sharing_fast_prefill=False, ), - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + cache_config=CacheConfig( + kv_sharing_fast_prefill=False, + ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly::attention"], + cudagraph_capture_sizes=[1, 2], + ), + ) with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=7, - # 3 attn ops and 4 non-attn ops - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=7, + # 3 attn ops and 4 non-attn ops + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 
84178344a5f36..8ccae4cfb9df2 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,60 +3,75 @@ from __future__ import annotations +import logging import tempfile -from typing import Any, Optional, Union +from typing import Any, Union import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel, PassConfig +from vllm.attention.backends.registry import _Backend +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test -def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): +def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { - "dtype": torch.float16, - }), - ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { - "dtype": torch.float16, - }), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + { + "dtype": torch.float16, + }, + ), + ( + "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", + { + "dtype": torch.float16, + }, + ), ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: - # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { - "quantization": "gguf" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"}) + ) if is_quant_method_supported("gptq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { - "quantization": "gptq" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"}) + ) if is_quant_method_supported("gptq_marlin"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { - "quantization": "gptq_marlin" - })) + TEST_MODELS.append( + ( + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + {"quantization": "gptq_marlin"}, + ) + ) if is_quant_method_supported("gptq_marlin_24"): - TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { - "quantization": "gptq_marlin_24" - })) + TEST_MODELS.append( + ( + "alexm-nm/tinyllama-24-marlin24-4bit-g128", + {"quantization": "gptq_marlin_24"}, + ) + ) if not current_platform.is_rocm() and is_quant_method_supported("awq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { - "quantization": "AWQ" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"}) + ) if keywords is None: return TEST_MODELS @@ -79,9 +94,7 @@ def test_full_graph( ): model, model_kwargs = model_info - with monkeypatch.context() as m: - # make sure these models can be captured in full graph mode - m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") + with monkeypatch.context(): print(f"MODEL={model}") run_model(optimization_level, model, model_kwargs) @@ -92,35 +105,122 @@ def test_full_graph( "compilation_config, model_info", [ # additional compile sizes, only some of the models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - compile_sizes=[1, 2]), model) + ( + CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 
2]), + model, + ) for model in models_list(all=False) - ] + [ + ] + + [ # RMSNorm + quant fusion, only 8-bit quant models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - pass_config=PassConfig(enable_fusion=True, - enable_noop=True)), model) + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), + model, + ) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) - ] + [ + ] + + [ # Test depyf integration works - (CompilationConfig(level=CompilationLevel.PIECEWISE, - debug_dump_path=tempfile.gettempdir()), - ("facebook/opt-125m", {})), - ]) + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir() + ), + ("facebook/opt-125m", {}), + ), + ] + + [ + # graph inductor partition + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, + # inductor graph partition uses + # torch._C.Tag.cudagraph_unsafe to specify splitting ops + use_inductor_graph_partition=True, + cudagraph_mode=CUDAGraphMode.PIECEWISE, + compile_sizes=[1, 2], + ), + model, + ) + for model in models_list(all=False) + if is_torch_equal_or_newer("2.9.0.dev") + ], +) # only test some of the models @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, model_info: tuple[str, dict[str, Any]], ): + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( + "2.9.0.dev" + ): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + model, model_kwargs = model_info print(f"MODEL={model}") run_model(compilation_config, model, model_kwargs) -def run_model(compile_config: Union[int, CompilationConfig], model: str, - model_kwargs: dict[str, Any]): +@pytest.mark.parametrize( + "optimization_level", + [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], +) +def test_fp8_kv_scale_compile(optimization_level: int): + model = "Qwen/Qwen2-0.5B" + model_kwargs = { + "quantization": "fp8", + "kv_cache_dtype": "fp8_e4m3", + "calculate_kv_scales": True, + "max_model_len": 512, + } + run_model(optimization_level, model, model_kwargs) + + +def test_inductor_graph_partition_attn_fusion(caplog_vllm): + if not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + + model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_inductor_graph_partition=True, + cudagraph_mode=CUDAGraphMode.PIECEWISE, + custom_ops=["+quant_fp8"], + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + ) + model_kwargs = { + "kv_cache_dtype": "fp8", + "max_model_len": 1024, + } + with ( + caplog_vllm.at_level(logging.DEBUG), + global_force_attn_backend_context_manager(_Backend.FLASHINFER), + ): + run_model(compilation_config, model, model_kwargs) + + try: + assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( + caplog_vllm.text + ) + except AssertionError: + # Note: this message is only triggered when the compilation goes + # through the custom pass. Due to multiple layers of cache on + # PyTorch side, the compilation of a graph may be cached such + # that custom pass directly goes through cache. In this case, + # we go through this branch and assert that the pass is not + # triggered. 
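# Hedged aside (an assumption, not something this test does): clearing Dynamo's
# in-memory compilation state between parametrized runs, e.g.
#
#     torch._dynamo.reset()
#
# makes it less likely that the cached-compilation branch above is taken; it is
# not guaranteed to bypass Inductor's on-disk cache.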
+ assert "Fused quantization" not in caplog_vllm.text + + +def run_model( + compile_config: Union[int, CompilationConfig], + model: str, + model_kwargs: dict[str, Any], +): prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 0c7e6fbccf20c..ae17bc67b1fb6 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -5,112 +5,254 @@ import pytest import torch import vllm.envs as envs -from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FUSED_OPS, FusionPass +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.platforms import current_platform from .backend import TestBackend -OPS_IN_MODEL = [ - torch.ops._C.rotary_embedding.default, - torch.ops._C.fused_add_rms_norm.default, -] +TEST_FP8 = current_platform.supports_fp8() +FP8_DTYPE = current_platform.fp8_dtype() -RMS_OP = torch.ops._C.rms_norm.default -RMS_QUANT_OPS = { - "static_fp8": [ - torch.ops._C.rms_norm_static_fp8_quant.default, - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - ], -} +class TestSiluMul(torch.nn.Module): + def __init__(self, hidden_size: int = 128): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.wscale = torch.rand(1, dtype=torch.float32) + self.scale = torch.rand(1, dtype=torch.float32) -SILU_MUL_OP = torch.ops._C.silu_and_mul.default + if TEST_FP8: + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) -SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + def forward(self, x): + y = self.silu_and_mul(x) + if TEST_FP8: + x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) + return x2 + else: + return y + + def example_inputs(self, num_tokens=32, hidden_size=128): + dtype = torch.float16 if TEST_FP8 else torch.float32 + return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) + + def ops_in_model(self, do_fusion): + if TEST_FP8 and do_fusion: + return [torch.ops._C.silu_and_mul_quant.default] + else: + return [torch.ops._C.silu_and_mul.default] + + def ops_not_in_model(self): + return [] + + +class TestFusedAddRMSNorm(torch.nn.Module): + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + dtype = 
torch.float16 if TEST_FP8 else torch.float32 + + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size), dtype=dtype) + ) + self.norm = RMSNorm(intermediate_size, 1e-05) + self.norm.weight = torch.nn.Parameter( + torch.ones(intermediate_size, dtype=dtype) + ) + + torch.nn.init.normal_(self.gate_proj, std=0.02) + + if TEST_FP8: + self.fp8_linear = Fp8LinearOp(act_quant_static=True) + + self.scale = torch.rand(1, dtype=torch.float32) + self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() + self.wscale = torch.rand(1, dtype=torch.float32) + + def forward(self, hidden_states, residual): + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + # matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # layer normalization + norm_output, residual_output = self.norm(mm, residual) + + if TEST_FP8: + # scaled_mm with static input quantization + fp8_linear_result = self.fp8_linear.apply( + norm_output, + self.w, + self.wscale, + input_scale=self.scale.to(norm_output.device), + ) + + return fp8_linear_result, residual_output + + else: + return norm_output, residual_output + + def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): + dtype = torch.float16 if TEST_FP8 else torch.float32 + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + return (hidden_states, residual) + + def ops_in_model(self, do_fusion): + if TEST_FP8 and do_fusion: + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] + else: + return [torch.ops._C.fused_add_rms_norm.default] + + def ops_not_in_model(self): + return [] + + +class TestRotaryEmbedding(torch.nn.Module): + def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000): + super().__init__() + self.head_dim = head_dim + self.rotary_dim = rotary_dim or head_dim + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=base, + ) + + def forward(self, positions, q, k): + q_rotated, k_rotated = self.rotary_emb(positions, q, k) + return q_rotated, k_rotated + + def example_inputs(self, num_tokens=32, head_dim=64): + dtype = torch.float16 + positions = torch.arange(num_tokens, dtype=torch.long) + q = torch.randn(num_tokens, head_dim, dtype=dtype) + k = torch.randn(num_tokens, head_dim, dtype=dtype) + return (positions, q, k) + + def ops_in_model(self, do_fusion): + return [torch.ops._C.rotary_embedding.default] + + def ops_not_in_model(self): + return [] + + +class TestRotaryEmbeddingSliceScatter(torch.nn.Module): + def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): + super().__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.hidden_size = head_dim * num_heads + + self.qkv_proj = torch.nn.Linear( + self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=base, + ) + + def forward(self, positions, hidden_states): + # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding + # -> slice_scatter -> split_with_sizes + + qkv = self.qkv_proj(hidden_states) + split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size] + q, k, v = torch.split(qkv, split_sizes, dim=-1) + + q_rotated, k_rotated = self.rotary_emb(positions, q, k) + + qkv_updated = torch.cat([q_rotated, k_rotated, 
v], dim=-1) + return qkv_updated + + def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): + dtype = torch.float16 + hidden_size = head_dim * num_heads + positions = torch.arange(num_tokens, dtype=torch.long) + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + return (positions, hidden_states) + + def ops_in_model(self, do_fusion): + return [torch.ops._C.rotary_embedding.default] + + def ops_not_in_model(self): + return [torch.ops.aten.slice_scatter.default] + + +MODELS = [ + TestSiluMul, + TestFusedAddRMSNorm, + TestRotaryEmbedding, + TestRotaryEmbeddingSliceScatter, ] -@pytest.mark.parametrize( - "model, quant_key", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", - kFp8DynamicTokenSym)]) +@pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", - reason="Only test on CUDA") -def test_fix_functionalization(model: str, quant_key: QuantKey, - do_fusion: bool): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") +def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): torch.set_default_device("cuda") vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + ) noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass, act_quant_fusion_pass - ] if do_fusion else [noop_pass] + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) func_pass = FixFunctionalizationPass(vllm_config) + backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) - # instantiate a full engine and manually compile the model 2x - # (with and without FixFunctionalizationPass) - llm = LLM(model=model, enforce_eager=True) - model_runner = llm.llm_engine.model_executor.driver_worker.model_runner - orig_model = model_runner.model - # TODO mark inputs dynamic? (currently torch.compile is triggered 4x) - # Can only do that by using the decorator but then we'd have to instantiate - # 2 LLM instances. + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_func) - gen_func = llm.generate(prompts, sampling_params) - - model_runner.model = torch.compile(orig_model, - fullgraph=True, - backend=backend_no_func) - - gen_no_func = llm.generate(prompts, sampling_params) - - for output_func, output_no_func in zip(gen_func, gen_no_func): - assert output_func.outputs[0].text == output_no_func.outputs[0].text - - # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, - # and replaced by fused quantized ops in RMS_QUANT_OPS. 
- rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] - ] if do_fusion else [RMS_OP] - silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \ - quant_key == kFp8StaticTensorSym else [ - SILU_MUL_OP - ] - - ops = OPS_IN_MODEL + rms_ops + silu_mul_ops - - for op in ops: + # check if the functionalization pass is applied + for op in model.ops_in_model(do_fusion): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # make sure the ops were all de-functionalized found = dict() for node in backend_func.graph_post_pass.nodes: - for op in ops: + for op in model.ops_in_model(do_fusion): if is_func(node, op): found[op] = True - assert all(found[op] for op in ops) + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 5cfad935a0fb1..7c22336432299 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -4,31 +4,47 @@ import pytest import torch -import vllm.envs as envs import vllm.plugins -from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass) +from vllm.compilation.fusion import ( + FUSED_OPS, + QUANT_OPS, + FusedRMSQuantKey, + RMSNormQuantFusionPass, +) from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc) + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) + Fp8LinearOp, + cutlass_fp8_supported, + maybe_create_device_identity, +) from vllm.platforms import current_platform +from ..utils import override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, eps: float, static: bool, - cutlass_fp8_enabled: bool, *args, **kwargs): + def __init__( + self, + hidden_size: int, + eps: float, + static: bool, + cuda_force_torch: bool, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) - self.cutlass_fp8_enabled = cutlass_fp8_enabled + self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN @@ -42,27 +58,26 @@ class TestModel(torch.nn.Module): torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) ] - self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, - act_quant_static=static, - act_quant_group_shape=group_shape, - ) + + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) y = self.norm[0](x) - x2 = 
self.fp8_linear.apply(y, - self.w[0], - self.wscale[0], - input_scale=self.scale[0]) + x2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply(y2, - self.w[1], - self.wscale[1], - input_scale=self.scale[1]) + x3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -72,38 +87,46 @@ class TestModel(torch.nn.Module): def ops_in_model_after(self): return [ FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)] + FUSED_OPS[FusedRMSQuantKey(self.key, True)], ] @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) -@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], - reason="Only test on CUDA and ROCm") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cutlass_fp8_enabled): +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. +@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" +) +def test_fusion_rmsnorm_quant( + dtype, hidden_size, num_tokens, eps, static, cuda_force_torch +): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ) + ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(noop_pass, fusion_pass) - model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) @@ -124,6 +147,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + assert fusion_pass.matched_count == 2 + # In pre-nodes, fp8 quant should be there and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index dd31e0db1f59f..7e5c460db1744 100644 --- 
a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -9,14 +9,25 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, - ModelConfig, PassConfig, VllmConfig) +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.config import ( + CompilationConfig, + CompilationLevel, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - GroupShape, QuantFP8) + GroupShape, + QuantFP8, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -25,7 +36,6 @@ from .backend import TestBackend class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -46,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -67,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) + self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), - dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant(self.output, - norm_output.contiguous(), - self.scale) + torch.ops._C.static_scaled_fp8_quant( + self.output, norm_output.contiguous(), self.scale + ) return self.output, residual_output def ops_in_model_after(self): @@ -94,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default + torch.ops._C.static_scaled_fp8_quant.default, ] class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = RMSNorm(hidden_size, eps) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), - dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) round_up = lambda x, y: (x + y - 1) // y * y rounded_m = 
round_up(token_num, 128) scale_n = hidden_size // 16 rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), - dtype=torch.int32) + self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) norm_output = norm_output.reshape(-1, norm_output.shape[-1]) - torch.ops._C.scaled_fp4_quant(self.output, norm_output, - self.output_scale, self.scale) + torch.ops._C.scaled_fp4_quant( + self.output, norm_output, self.output_scale, self.scale + ) return self.output, residual_output, self.output_scale def ops_in_model_after(self): @@ -131,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.scaled_fp4_quant.default + torch.ops._C.scaled_fp4_quant.default, ] @@ -144,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): TestAllReduceFusedAddRMSNormStaticQuantFP8Model, # TODO: Enable with torch==2.8.0 # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, - ]) + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), reason="flashinfer is not found or flashinfer " - "is not compiled with trtllm_allreduce_fusion") -def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): + "is not compiled with trtllm_allreduce_fusion", +) +def test_all_reduce_fusion_pass_replace( + test_model: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): num_processes = 2 - if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model - and not current_platform.has_device_capability(100)): - pytest.skip("Skip as nvfp4 is only supported on " - "devices with compute capability 10.0 (Blackwell)") + if ( + test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model + and not current_platform.has_device_capability(100) + ): + pytest.skip( + "Skip as nvfp4 is only supported on " + "devices with compute capability 10.0 (Blackwell)" + ) def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + nprocs=nprocs, + ) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) -def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def all_reduce_fusion_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -186,37 
+204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"])) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + ) + ) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True, enable_noop=True) + enable_fi_allreduce_fusion=True, enable_noop=True + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) token_num = batch_size * seq_len model = test_model_cls(hidden_size, token_num) @@ -227,6 +250,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states, residual) + assert all_reduce_fusion_pass.matched_count == 1 backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dba668cfa16a6..0f2e3bffbd311 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -6,27 +6,35 @@ from typing import Optional import pytest import torch._dynamo -from tests.compile.backend import TestBackend -from tests.models.utils import check_outputs_equal -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata) -from vllm import LLM, SamplingParams +from tests.compile.backend import LazyInitPass, TestBackend +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.attention import Attention +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass -from 
vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, - ModelConfig, PassConfig, SchedulerConfig, VllmConfig, - set_current_vllm_config) +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationLevel, + ModelConfig, + PassConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -37,124 +45,19 @@ backend: Optional[TestBackend] = None backend_unfused: Optional[TestBackend] = None -@pytest.mark.parametrize( - "model, quant_key", - [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) -@pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) -@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): - # Clean Dynamo cache to avoid reusing other test cases - # (for some reason the reset at the end is not enough) - torch._dynamo.reset() - - # Use global backends - global backend, backend_unfused - - use_v1 = False # can be made a param once V1 support added - monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1))) - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa))) - - # Prompt 4 seems too open-ended, differs between fused and unfused - # (both outputs look reasonable though) - prompts = example_prompts[:4] + example_prompts[5:] - - compile_config = CompilationConfig( - # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation - # DYNAMO_ONCE does not properly propagate shapes. - level=CompilationLevel.DYNAMO_AS_IS, - backend="tests.compile.test_fusion_attn.backend_unfused", - custom_ops=["+quant_fp8"], - ) - vllm_config = VllmConfig(compilation_config=compile_config) - backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) - - llm = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.9, - max_model_len=2048) - - sampling_params = SamplingParams(temperature=0.0, - max_tokens=10, - top_p=0.95) - - unfused_output = llm.generate(prompts, sampling_params) - backend_unfused = None # Reset backend to make sure llm gets released - del llm - - compile_config = CompilationConfig( - # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation - # DYNAMO_ONCE does not properly propagate shapes. - level=CompilationLevel.DYNAMO_AS_IS, - backend="tests.compile.test_fusion_attn.backend", - custom_ops=["+quant_fp8"], - ) - vllm_config = VllmConfig(compilation_config=compile_config) - - # AttnFusionPass needs attention layers to be registered in config upon init - # so we initialize it during compilation. 
- attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) - backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) - llm2 = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.9, - max_model_len=2048) - - # check support - attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key) - for key, layer in compile_config.static_forward_context.items() - ] - - print(f"{attn_fusion_supported=}") - if any(attn_fusion_supported): - # Check quant ops - backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False) - - # attention ops present in both, just output_scale param changes - attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass)) - attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass)) - assert len(attn_nodes_pre) == len(attn_nodes_post) - - for i in range(len(attn_nodes_pre)): - assert attn_nodes_pre[i].kwargs["output_scale"] is None - fused = attn_nodes_post[i].kwargs["output_scale"] is not None - assert fused == attn_fusion_supported[i], \ - f"Node {i} {'' if fused else 'not '} expected " \ - f"to have fused output quant" - - # check outputs - fused_output = llm2.generate(prompts, sampling_params) - - # transform outputs to format expected by check_outputs_equal - sample_outs = lambda s: (list(s.token_ids), s.text) - outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros] - - check_outputs_equal( - outputs_0_lst=outs_lst(unfused_output), - outputs_1_lst=outs_lst(fused_output), - name_0="unfused", - name_1="fused", - ) - - # Clean Dynamo cache to avoid polluting other case(s) - torch._dynamo.reset() - - # Reset backend to make sure llm2 gets released - backend = None - - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" - def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype, device: torch.device, - vllm_config: VllmConfig, **kwargs): + def __init__( + self, + num_qo_heads: int, + num_kv_heads: int, + head_size: int, + kv_cache_dtype: torch.dtype, + device: torch.device, + vllm_config: VllmConfig, + **kwargs, + ): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -171,6 +74,8 @@ class AttentionQuantPatternModel(torch.nn.Module): cache_config=vllm_config.cache_config, prefix="model.layers.0.self_attn.attn", ) + self.attn._k_scale = self.attn._k_scale.to(device) + self.attn._v_scale = self.attn._v_scale.to(device) self.block_size = 16 @@ -181,47 +86,80 @@ class AttentionQuantPatternModel(torch.nn.Module): num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_dtype, - use_mla=False, ), layer_names=[self.attn.layer_name], vllm_config=self.vllm_config, device=self.device, ) - def build_attn_metadata(self, batch_size: int): + def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: """Initialize attention metadata.""" # Create common attn metadata - batch_spec = BatchSpec(seq_lens=[1] * batch_size, - query_lens=[1] * batch_size) + batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) common_attn_metadata = create_common_attn_metadata( - batch_spec, - self.block_size, - self.device, - arange_block_indices=True) + batch_spec, self.block_size, self.device, arange_block_indices=True + ) - max_blocks = (max(batch_spec.seq_lens) + self.block_size - - 1) // self.block_size + max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size num_blocks 
= batch_size * max_blocks +        backend = self.attn.backend -        # Create dummy KV cache for FlashInfer TRTLLM -        # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] -        # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] -        # Create kv_cache in HND layout and permute to NHD layout -        # (later will be permuted back to HND layout in forward pass) -        kv_cache = torch.zeros(num_blocks, -                               2, -                               self.num_kv_heads, -                               self.block_size, -                               self.head_size, -                               dtype=self.kv_cache_dtype, -                               device=self.device) -        kv_cache = kv_cache.permute(0, 1, 3, 2, 4) +        # Create dummy KV cache for the selected backend +        if backend == _Backend.ROCM_ATTN: +            # k/v as 1st dimension +            # HND: [num_blocks, num_kv_heads, block_size, head_size] +            kv_cache = torch.zeros( +                2, +                num_blocks, +                self.num_kv_heads, +                self.block_size, +                self.head_size, +                dtype=self.kv_cache_dtype, +                device=self.device, +            ) +        elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: +            # k/v as 1st dimension +            # NHD: [num_blocks, block_size, num_kv_heads, head_size] +            kv_cache = torch.zeros( +                2, +                num_blocks, +                self.block_size, +                self.num_kv_heads, +                self.head_size, +                dtype=self.kv_cache_dtype, +                device=self.device, +            ) +        elif backend == _Backend.TRITON_ATTN: +            # k/v as 2nd dimension +            # NHD: [num_blocks, block_size, num_kv_heads, head_size] +            kv_cache = torch.zeros( +                num_blocks, +                2, +                self.num_kv_heads, +                self.block_size, +                self.head_size, +                dtype=self.kv_cache_dtype, +                device=self.device, +            ) +        elif backend == _Backend.FLASHINFER: +            kv_cache = torch.zeros( +                num_blocks, +                2, +                self.num_kv_heads, +                self.block_size, +                self.head_size, +                dtype=self.kv_cache_dtype, +                device=self.device, +            ).permute(0, 1, 3, 2, 4) +        else: +            raise ValueError(f"Unsupported backend: {backend}") self.attn.kv_cache = [kv_cache] # Build attn metadata self.attn_metadata = self.builder.build( -            common_prefix_len=0, common_attn_metadata=common_attn_metadata) +            common_prefix_len=0, common_attn_metadata=common_attn_metadata +        ) return self.attn_metadata @@ -236,27 +174,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): self.fp8_linear = Fp8LinearOp( act_quant_static=self.quant_key.scale.static, -            act_quant_group_shape=self.quant_key.scale.group_shape) +            act_quant_group_shape=self.quant_key.scale.group_shape, +        ) hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( -            "w", { -                "weight": -                torch.randn(hidden_size, hidden_size).to( -                    dtype=FP8_DTYPE, device=self.device).t(), -                "wscale": -                torch.tensor([1.0], dtype=torch.float32, device=self.device), -                "scale": -                torch.tensor([1.0], dtype=torch.float32, device=self.device), -            }) +            "w", +            { +                "weight": torch.randn(hidden_size, hidden_size) +                .to(dtype=FP8_DTYPE, device=self.device) +                .t(), +                "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device), +                "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), +            }, +        ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) -        return self.fp8_linear.apply(input=attn_output, -                                     weight=self.w["weight"], -                                     weight_scale=self.w["wscale"], -                                     input_scale=self.w["scale"]) +        return self.fp8_linear.apply( +            input=attn_output, +            weight=self.w["weight"], +            weight_scale=self.w["wscale"], +            input_scale=self.w["scale"], +        ) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): @@ -269,55 +210,106 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): hidden_size = self.num_qo_heads * self.head_size 
self.w = kwargs.get( - "w", { - "weight": - torch.randint(256, (hidden_size, hidden_size // 2), - dtype=FP4_DTYPE, - device=self.device), - "wscale_swizzled": - torch.randn(hidden_size, hidden_size // 16).to( - dtype=FP8_DTYPE, device=self.device), - "wscale": - torch.tensor([500], dtype=torch.float32, device=self.device), - "scale": - torch.tensor([0.002], dtype=torch.float32, device=self.device), - }) + "w", + { + "weight": torch.randint( + 256, + (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device, + ), + "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device + ), + "wscale": torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device), + }, + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) quant_output, output_block_scale = scaled_fp4_quant( - attn_output, 1 / self.w["scale"]) - return cutlass_scaled_fp4_mm(a=quant_output, - b=self.w["weight"], - block_scale_a=output_block_scale, - block_scale_b=self.w["wscale_swizzled"], - alpha=self.w["scale"] * self.w["wscale"], - out_dtype=attn_output.dtype) + attn_output, 1 / self.w["scale"] + ) + return cutlass_scaled_fp4_mm( + a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype, + ) -@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +if current_platform.is_cuda(): + MODELS = [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel, + ), + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ), + ] + HEADS = [(64, 8), (40, 8)] +elif current_platform.is_rocm(): + MODELS = [ + ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) + ] + HEADS = [(32, 8), (40, 8)] +else: + MODELS = [] + HEADS = [] + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) -@pytest.mark.parametrize("batch_size", [7, 256, 533]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("model_name, model_class", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - TestAttentionFp8StaticQuantPatternModel), - ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel)]) -@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) -@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.parametrize( + "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("model_name, model_class", MODELS) +@pytest.mark.parametrize( + "backend", + [_Backend.FLASHINFER] + if current_platform.is_cuda() + else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], +) +# TODO(boyuan): test inductor graph partition on rocm +@pytest.mark.parametrize( + "use_inductor_graph_partition", + [False] if current_platform.is_rocm() else [False, True], +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" +) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), - reason="Only test on 
SM100(Blackwell)") -def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, - head_size: int, batch_size: int, - dtype: torch.dtype, model_name: str, - model_class: type[AttentionQuantPatternModel], - backend: _Backend, monkeypatch, dist_init): +@pytest.mark.skipif( + current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), + reason="On CUDA only test on SM100(Blackwell)", +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" +) +def test_attention_quant_pattern( + num_qo_heads: int, + num_kv_heads: int, + head_size: int, + batch_size: int, + dtype: torch.dtype, + model_name: str, + model_class: type[AttentionQuantPatternModel], + backend: _Backend, + use_inductor_graph_partition: bool, + dist_init, + caplog_vllm, +): """Test AttentionStaticQuantPattern fusion pass""" - monkeypatch.setenv("VLLM_USE_V1", "1") + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") device = torch.device("cuda:0") torch.manual_seed(42) @@ -326,27 +318,21 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_config=ModelConfig( model=model_name, max_model_len=2048, + dtype=dtype, ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+quant_fp8"], + use_inductor_graph_partition=use_inductor_graph_partition, ), - cache_config=CacheConfig(cache_dtype="fp8")) + cache_config=CacheConfig(cache_dtype="fp8"), + ) # Create test inputs - q = torch.randn(batch_size, - num_qo_heads * head_size, - dtype=dtype, - device=device) - k = torch.randn(batch_size, - num_kv_heads * head_size, - dtype=dtype, - device=device) - v = torch.randn(batch_size, - num_kv_heads * head_size, - dtype=dtype, - device=device) + q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device) + k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device) + v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device) # Mark first dimension as dynamic for realistic testing torch._dynamo.mark_dynamic(q, 0) @@ -355,37 +341,45 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # Run model directly without compilation and fusion vllm_config_unfused = copy.deepcopy(vllm_config) - with set_current_vllm_config(vllm_config_unfused), set_forward_context( - attn_metadata=None, vllm_config=vllm_config_unfused - ), global_force_attn_backend_context_manager(backend): - model_unfused = model_class(num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - kv_cache_dtype=FP8_DTYPE, - device=device, - vllm_config=vllm_config_unfused) + with ( + set_current_vllm_config(vllm_config_unfused), + set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), + global_force_attn_backend_context_manager(backend), + ): + model_unfused = model_class( + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused, + ) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() - forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size) + forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) # Run model with attn fusion enabled 
vllm_config.compilation_config.pass_config = PassConfig( - enable_attn_fusion=True, enable_noop=True) - with set_current_vllm_config(vllm_config), set_forward_context( - attn_metadata=None, vllm_config=vllm_config - ), global_force_attn_backend_context_manager(backend): - model_fused = model_class(num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - kv_cache_dtype=FP8_DTYPE, - device=device, - vllm_config=vllm_config, - w=model_unfused.w) + enable_attn_fusion=True, enable_noop=True + ) + with ( + set_current_vllm_config(vllm_config), + set_forward_context(attn_metadata=None, vllm_config=vllm_config), + global_force_attn_backend_context_manager(backend), + ): + model_fused = model_class( + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w, + ) model_fused = model_fused.to(device) forward_ctx = get_forward_context() @@ -393,63 +387,72 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) - attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw - ) - test_backend = TestBackend(noop_pass, attn_pass) + attn_pass = LazyInitPass(AttnFusionPass, vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) # Compile model with fusion enabled - model_compiled = torch.compile(model_fused, - backend=test_backend, - fullgraph=True) + model_compiled = torch.compile( + model_fused, backend=test_backend, fullgraph=True + ) assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v) - # After the 1st round of the forward pass, output quant scale should be - # loaded into the attn layer's _o_scale_float, the 2nd round should - # reuse the loaded _o_scale_float - assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v) - assert model_compiled.attn._o_scale_float is not None + if backend == _Backend.FLASHINFER: + # With the Flashinfer backend after the 1st round of the forward + # pass, output quant scale should be loaded into the attn layer's + # _o_scale_float, the 2nd round should reuse the loaded + # _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + + assert model_compiled.attn._o_scale_float is not None + + torch.testing.assert_close( + result_unfused, result_fused_2, atol=1e-2, rtol=1e-2 + ) # Check attn fusion support quant_key = model_class.quant_key attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key) for key, layer in - vllm_config.compilation_config.static_forward_context.items() + layer.impl.fused_output_quant_supported(quant_key) + for key, layer in vllm_config.compilation_config.static_forward_context.items() ] if any(attn_fusion_supported): # Check quantization ops in the graph before and after fusion - test_backend.check_before_ops([QUANT_OPS[quant_key]], - fully_replaced=True) + test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + + # access the underlying `AttnFusionPass` on the `LazyInitPass` + assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) - attn_nodes_post = list(find_op_nodes(ATTN_OP, - 
test_backend.graph_post_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass)) assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" - assert len(attn_nodes_pre) == len(attn_nodes_post), \ + assert len(attn_nodes_pre) == len(attn_nodes_post), ( "Should have same number of attention nodes before and after fusion" - assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + ) + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, ( "Attention should not have output_scale before fusion" - assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + ) + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, ( "Attention should have output_scale after fusion" + ) - assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale before fusion" + ) if quant_key.dtype == FP8_DTYPE: - assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale after FP8 fusion" + ) elif quant_key.dtype == FP4_DTYPE: - assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ - "Attention should have output_block_scale after FP4 fusion" # noqa: E501 + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, ( + "Attention should have output_block_scale after FP4 fusion" + ) - # Check that results are closed - torch.testing.assert_close(result_unfused, - result_fused_1, - atol=1e-2, - rtol=1e-2) - torch.testing.assert_close(result_unfused, - result_fused_2, - atol=1e-2, - rtol=1e-2) + # Check that results are close + torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py new file mode 100644 index 0000000000000..188f4514dda5f --- /dev/null +++ b/tests/compile/test_noop_elimination.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig + +from .backend import TestBackend + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +# Important edge case is when `num_tokens == buffer_size` +@pytest.mark.parametrize( + ("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)] +) +@pytest.mark.parametrize("hidden_size", [64, 4096]) +def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype) + + def forward(self, x): + x += self.pos_embed[: x.shape[0]] + # Chain of reshapes + y = x.reshape(-1, 128, 32) + z = y.reshape(-1, 4096) + # No-op reshape + a = z.reshape(-1, 4096) + # Final reshape that should remain + b = a.reshape(-1, 128, 32) + # No-op slice + c = b[0 : b.shape[0]] + # The pass should replace the result of this op with `c`. 
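# Hedged illustration in plain PyTorch (independent of vLLM's pass): a
# slice_scatter whose slice spans the whole dimension simply reproduces the
# source tensor, e.g.
#
#     torch.slice_scatter(torch.zeros(4), torch.ones(4), 0, 0, 4)
#
# equals torch.ones(4), which is why the call below is a candidate for
# elimination.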
+ d = torch.slice_scatter( + torch.ones_like(c), # Dummy tensor to be scattered into + c, # Source tensor + 0, # dim + 0, # start + c.shape[0], # end + ) + return d + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + ) + ) + with vllm.config.set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + + backend = TestBackend(noop_pass) + + model = Model() + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + ATOL, RTOL = (2e-3, 2e-3) + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + + # The no-op reshape and slice should be eliminated. + # The initial slice on the positional embedding should remain. + # The chain of reshapes should be fused into a single reshape. + assert backend.op_count(torch.ops.aten.reshape.default) == 1 + assert backend.op_count(torch.ops.aten.slice.Tensor) == 1 + assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0 + + +def test_non_noop_slice_preserved(): + """Ensure that a slice with end=-1 (dropping last row) is NOT eliminated. + + Regression test for a bug where end=-1 was treated like an inferred + dimension (reshape semantics) leading to incorrect elimination. + """ + torch.set_default_device("cuda") + x = torch.randn(16, 16) + + class SliceModel(torch.nn.Module): + def forward(self, x): + base = x.clone() + src = torch.ones(15, 16) + y = torch.slice_scatter(base, src, dim=0, start=0, end=-1) + return x[0:-1, :], y + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + ) + ) + with vllm.config.set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + backend = TestBackend(noop_pass) + model = SliceModel() + ref = model(x) + compiled = torch.compile(model, backend=backend) + out = compiled(x) + torch.testing.assert_close(ref, out) + # The slice should remain (not a no-op). 
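# Clarifying note: aten.slice treats end=-1 as "stop before the last element",
# so x[0:-1, :] keeps 15 of the 16 rows here; unlike a full-extent slice, it
# changes the shape and therefore must survive no-op elimination.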
+ assert backend.op_count(torch.ops.aten.slice.Tensor) == 1 + assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1 diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 251cc46e9e989..ac561d2e8f84a 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -28,7 +28,6 @@ def test_bad_callable(): # Pass that inherits from InductorPass class ProperPass(InductorPass): - def __call__(self, graph: torch.fx.graph.Graph) -> None: pass @@ -39,8 +38,7 @@ class ProperPass(InductorPass): ProperPass(), # Can also wrap callables in CallableInductorPass for compliance CallableInductorPass(simple_callable), - CallableInductorPass(simple_callable, - InductorPass.hash_source(__file__)) + CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)), ], ) def test_pass_manager_uuid(callable): @@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.compilation_config.pass_config.enable_fusion = not \ - config2.compilation_config.pass_config.enable_fusion + config2.compilation_config.pass_config.enable_fusion = ( + not config2.compilation_config.pass_config.enable_fusion + ) pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index a6baa97fe6990..afb31cb95be09 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -6,18 +6,26 @@ import torch import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import FusionPass +from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -34,16 +42,15 @@ prompts = [ class TestModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size))) + torch.empty((intermediate_size, hidden_size)) + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights 
torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -51,18 +58,18 @@ class TestModel(torch.nn.Module): def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -80,7 +87,7 @@ class TestModel(torch.nn.Module): def ops_in_model_after(self): return [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] def ops_in_model(self): @@ -88,47 +95,43 @@ class TestModel(torch.nn.Module): class TestQuantModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.vllm_config = vllm_config - self.gate_proj = torch.nn.Parameter(torch.empty( - (intermediate_size, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size)), requires_grad=False + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, - use_per_token_if_dynamic=False) + self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. 
- self.w = torch.rand(hidden_size, - intermediate_size).to(dtype=FP8_DTYPE).t() + self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -138,47 +141,52 @@ class TestQuantModel(torch.nn.Module): # layer normalization norm_output, residual_output = self.norm(all_reduce, residual) - # for static input quantization - # self.fp8_linear is initialized with use_per_token_if_dynamic=False - fp8_linear_result = self.fp8_linear.apply(norm_output, - self.w, - self.wscale, - input_scale=self.scale.to( - norm_output.device)) + # scaled_mm with static input quantization + fp8_linear_result = self.fp8_linear.apply( + norm_output, + self.w, + self.wscale, + input_scale=self.scale.to(norm_output.device), + ) return fp8_linear_result, residual_output def ops_in_model_before(self): - ops_to_remove = [torch.ops.vllm.all_reduce.default - ] # Always removed by SP + ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # The following are only removed if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_remove.extend([ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ]) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_remove.extend( + [ + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.static_scaled_fp8_quant.default, + ] + ) return ops_to_remove def ops_in_model_after(self): ops_to_add = [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] # The following is only added if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_add.append( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) return ops_to_add def ops_in_model(self): - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): # If fusion happens, the fused op is the one # we check for (de)functionalization - return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - ] # noqa: E501 + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] else: # If no fusion, the original ops are checked return [ @@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module): @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, - enable_fusion: bool): 
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_sequence_parallelism_pass( + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model_cls, - batch_size, seq_len, hidden_size, - dtype, enable_fusion), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + test_model_cls, + batch_size, + seq_len, + hidden_size, + dtype, + enable_fusion, + ), + nprocs=nprocs, + ) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) def sequence_parallelism_pass_on_test_model( - local_rank: int, world_size: int, - test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, enable_fusion: bool): + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -240,39 +267,41 @@ def sequence_parallelism_pass_on_test_model( # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_sequence_parallelism=True, - enable_fusion=enable_fusion, - enable_noop=True)) # NoOp needed for fusion + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_sequence_parallelism=True, + enable_fusion=enable_fusion, + enable_noop=True, + ) + ) # NoOp needed for fusion vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) - passes_for_backend = [noop_pass, sequence_parallelism_pass] + passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] if enable_fusion: - fusion_pass = FusionPass.instance(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) passes_for_backend.append(fusion_pass) + passes_for_backend.append(cleanup_pass) + backend_no_func = TestBackend(*passes_for_backend) backend_func = TestBackend(*passes_for_backend, func_pass) - model = test_model_cls(hidden_size, - hidden_size * 2, - vllm_config=vllm_config) + model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) compiled_model_no_func = torch.compile(model, backend=backend_no_func) @@ -280,6 +309,8 @@ def sequence_parallelism_pass_on_test_model( compiled_model_func = torch.compile(model, backend=backend_func) compiled_model_func(hidden_states, residual) + assert sequence_parallelism_pass.matched_count == 1 + # In pre-nodes, all reduce should be there, # reduce scatter and all gather should not backend_no_func.check_before_ops(model.ops_in_model_before()) @@ -291,8 +322,7 @@ def sequence_parallelism_pass_on_test_model( # check if the functionalization pass is applied for op in model.ops_in_model(): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # make sure the ops were all de-functionalized found = dict() diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 5351a3cf35ba5..16a4271655efa 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -1,73 +1,155 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + import pytest import torch import vllm.envs as envs -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.compilation.activation_quant_fusion import ( + FUSED_OPS, + SILU_MUL_OP, + ActivationQuantFusionPass, +) +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, 
+ kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_FP8_SUPPORTED, Fp8LinearOp) + Fp8LinearOp, + cutlass_fp8_supported, +) from vllm.platforms import current_platform +from ..utils import override_cutlass_fp8_supported from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, - **kwargs): - super().__init__(*args, **kwargs) +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +class TestSiluMulFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): + super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) - self.w = (torch.rand( - hidden_size, - hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_enabled, - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) def forward(self, x): y = self.silu_and_mul(x) - x2 = self.fp8_linear.apply(y, - self.w, - self.wscale, - input_scale=self.wscale) + x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return x2 + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] + + def ops_in_model_after(self): + return [FUSED_OPS[kFp8StaticTensorSym]] + + +class TestSiluMulNvfp4QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): + super().__init__() + from vllm.compilation.activation_quant_fusion import ( + silu_and_mul_nvfp4_quant_supported, + ) + + assert silu_and_mul_nvfp4_quant_supported + + self.silu_and_mul = SiluAndMul() + + # create nvfp4 weight + w = torch.rand((hidden_size, hidden_size)) + self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w) + + # get global scale offline + _, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x)) + + self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale) + + def forward(self, x): + y = self.silu_and_mul(x) + y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) + out = cutlass_scaled_fp4_mm( + a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y.dtype, + ) + return out + + def ops_in_model_before(self): + return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + + def ops_in_model_after(self): + return [FUSED_OPS[kNvfp4Quant]] + + +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize( + "model_class", + cast( + list[type], + [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() + else [TestSiluMulFp8QuantModel], + ), +) +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. 
+@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" +) +def test_fusion_silu_and_mul_quant( + num_tokens, hidden_size, dtype, model_class, cuda_force_torch +): + if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: + pytest.skip("Duplicate tests for NVFP4") -@pytest.mark.parametrize("num_tokens", [256]) -@pytest.mark.parametrize("hidden_size", [64]) -@pytest.mark.parametrize("cutlass_fp8_enabled", - [True, False] if CUTLASS_FP8_SUPPORTED else [False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], - reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, - cutlass_fp8_enabled): torch.set_default_device("cuda") - torch.set_default_dtype(torch.float16) + torch.set_default_dtype(dtype) + + x = torch.rand(num_tokens, hidden_size * 2) # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=True, enable_noop=True)) + pass_config=PassConfig(enable_fusion=True, enable_noop=True) + ) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(NoOpEliminationPass(config), fusion_pass) - model = TestModel(hidden_size, cutlass_fp8_enabled) + passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] + backend = TestBackend(*passes) + model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) @@ -76,22 +158,19 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, result2 = model2(x) # Check that it gives the same answer - torch.testing.assert_close(result[0].to(dtype=torch.float16), - result2[0].to(dtype=torch.float16), - atol=1e-3, - rtol=1e-3) + if model_class == TestSiluMulFp8QuantModel: + atol, rtol = 1e-3, 1e-3 + elif model_class == TestSiluMulNvfp4QuantModel: + atol, rtol = 1e-1, 1e-1 - # Check substitution worked - pre_nodes = backend.graph_pre_pass.nodes - post_nodes = backend.graph_post_pass.nodes + torch.testing.assert_close( + result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol + ) - silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + assert fusion_pass.matched_count == 1 - # In pre-nodes, fp8 quant should be present and fused kernels should not - assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None - find_auto_fn(pre_nodes, fp8_quant) + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - # In post-nodes, fused kernels should be present and fp8 quant should not - find_auto_fn(post_nodes, silu_and_mul_quant) - assert find_auto_fn_maybe(post_nodes, fp8_quant) is None + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 5e39f6821d16c..34db5a999cbd8 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -10,7 +10,6 @@ from vllm.config import CompilationLevel class MyMod(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): if cache is not None: return x + cache @@ -18,12 
+17,12 @@ class MyMod(torch.nn.Module): class MyWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable, - compilation_level=CompilationLevel.DYNAMO_ONCE) + super().__init__( + compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE + ) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled @@ -54,10 +53,8 @@ def test_torch_compile_wrapper(): # for new input, dispatch to the compiled code directly new_x = torch.tensor([3]) - assert wrapper(new_x, - None).item() == 6 # dispatch to the first compiled code - assert wrapper( - new_x, cache).item() == 5 # dispatch to the second compiled code + assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code + assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code for wrapper in wrappers: # make sure they have independent compiled codes diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py index e37b6b95941e9..61c3df0a23483 100644 --- a/tests/config/test_config_generation.py +++ b/tests/config/test_config_generation.py @@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch): """ def create_config(): - engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True) + engine_args = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ) return engine_args.create_engine_config() # Create config with CUDA_VISIBLE_DEVICES set normally @@ -34,16 +35,18 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch): empty_config_dict.pop("instance_id", None) assert deep_compare(normal_config_dict, empty_config_dict), ( - "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" - " should be equivalent") + 'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""' + " should be equivalent" + ) def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): # In testing, this method needs to be nested inside as ray does not # see the test module. 
def create_config(): - engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True) + engine_args = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ) return engine_args.create_engine_config() config = create_config() @@ -51,6 +54,7 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): assert parallel_config.ray_runtime_env is None import ray + ray.init() runtime_env = { @@ -59,13 +63,13 @@ def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): }, } - config_ref = ray.remote(create_config).options( - runtime_env=runtime_env).remote() + config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote() config = ray.get(config_ref) parallel_config = config.parallel_config assert parallel_config.ray_runtime_env is not None - assert parallel_config.ray_runtime_env.env_vars().get( - "TEST_ENV_VAR") == "test_value" + assert ( + parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value" + ) ray.shutdown() diff --git a/tests/config/test_mp_reducer.py b/tests/config/test_mp_reducer.py index d4d4be293280b..56dc542f1c76d 100644 --- a/tests/config/test_mp_reducer.py +++ b/tests/config/test_mp_reducer.py @@ -8,21 +8,18 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.v1.engine.async_llm import AsyncLLM -def test_mp_reducer(monkeypatch): +def test_mp_reducer(): """ Test that _reduce_config reducer is registered when AsyncLLM is instantiated without transformers_modules. This is a regression test for https://github.com/vllm-project/vllm/pull/18640. """ - # Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value - monkeypatch.setenv('VLLM_USE_V1', '1') - # Ensure transformers_modules is not in sys.modules - if 'transformers_modules' in sys.modules: - del sys.modules['transformers_modules'] + if "transformers_modules" in sys.modules: + del sys.modules["transformers_modules"] - with patch('multiprocessing.reducer.register') as mock_register: + with patch("multiprocessing.reducer.register") as mock_register: engine_args = AsyncEngineArgs( model="facebook/opt-125m", max_model_len=32, @@ -36,7 +33,8 @@ def test_mp_reducer(monkeypatch): ) assert mock_register.called, ( - "multiprocessing.reducer.register should have been called") + "multiprocessing.reducer.register should have been called" + ) vllm_config_registered = False for call_args in mock_register.call_args_list: @@ -45,8 +43,7 @@ def test_mp_reducer(monkeypatch): vllm_config_registered = True reducer_func = call_args[0][1] - assert callable( - reducer_func), "Reducer function should be callable" + assert callable(reducer_func), "Reducer function should be callable" break assert vllm_config_registered, ( diff --git a/tests/conftest.py b/tests/conftest.py index 2bf88abb0f6c2..4713e12385965 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa + +from tblib import pickling_support + +# Install support for pickling exceptions so that we can nicely propagate +# failures from tests running in a subprocess. +# This should be run before any custom exception subclasses are defined. 
+pickling_support.install() + +import http.server import json +import math +import mimetypes import os +import socket import tempfile +import threading +from collections.abc import Generator +from contextlib import nullcontext from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast import numpy as np import pytest @@ -13,27 +30,34 @@ import torch.nn as nn import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BatchEncoding, BatchFeature) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BatchEncoding, + BatchFeature, +) from transformers.models.auto.auto_factory import _BaseAutoModelClass -from tests.models.utils import (TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs) +from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype +from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype from vllm.connections import global_http_connection -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) from vllm.logger import init_logger +from vllm.logprobs import Logprob +from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.utils import is_list_of, set_default_torch_num_threads logger = init_logger(__name__) @@ -63,12 +87,13 @@ class ImageAssetPrompts(TypedDict): class ImageTestAssets(list[ImageAsset]): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ @@ -85,11 +110,12 @@ class VideoAssetPrompts(TypedDict): class VideoTestAssets(list[VideoAsset]): - def __init__(self) -> None: - super().__init__([ - VideoAsset("baby_reading"), - ]) + super().__init__( + [ + VideoAsset("baby_reading"), + ] + ) def prompts(self, prompts: VideoAssetPrompts) -> list[str]: return [prompts["baby_reading"]] @@ -101,12 +127,13 @@ class AudioAssetPrompts(TypedDict): class AudioTestAssets(list[AudioAsset]): - def __init__(self) -> None: - super().__init__([ - AudioAsset("mary_had_lamb"), - AudioAsset("winning_call"), - ]) + super().__init__( + [ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ] + ) def prompts(self, prompts: AudioAssetPrompts) -> list[str]: return [prompts["mary_had_lamb"], prompts["winning_call"]] @@ -140,26 +167,6 @@ def cleanup_VLLM_USE_V1(monkeypatch): monkeypatch.delenv("VLLM_USE_V1") -@pytest.fixture(params=[True, False]) -def run_with_both_engines(request, monkeypatch): - # Automatically runs tests twice, once with V1 and once without - use_v1 = request.param - # 
Tests decorated with `@skip_v1` are only run without v1 - skip_v0 = request.node.get_closest_marker("skip_v0") - skip_v1 = request.node.get_closest_marker("skip_v1") - - if use_v1: - if skip_v1: - pytest.skip("Skipping test on vllm V1") - monkeypatch.setenv('VLLM_USE_V1', '1') - else: - if skip_v0: - pytest.skip("Skipping test on vllm V0") - monkeypatch.setenv('VLLM_USE_V1', '0') - - yield - - @pytest.fixture(autouse=True) def init_test_http_connection(): # pytest_asyncio may use a different event loop per test @@ -221,44 +228,12 @@ def example_system_message() -> str: class DecoderPromptType(Enum): """For encoder/decoder models only.""" + CUSTOM = 1 NONE = 2 EMPTY_STR = 3 -@pytest.fixture -def example_encoder_decoder_prompts( -) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]: - ''' - Returns an encoder prompt list and a decoder prompt list, wherein each pair - of same-index entries in both lists corresponds to an (encoder prompt, - decoder prompt) tuple. - - Returns: - - * Encoder prompt list - * Decoder prompt list (reverse of encoder prompt list) - ''' - - encoder_prompts = [] - for filename in _TEST_PROMPTS: - encoder_prompts += _read_prompts(filename) - - custom_decoder_prompts = encoder_prompts[::-1] - empty_str_decoder_prompts = [""] * len(encoder_prompts) - none_decoder_prompts = [None] * len(encoder_prompts) - - # NONE decoder prompt type - return { - DecoderPromptType.NONE: - zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts), - DecoderPromptType.EMPTY_STR: - zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts), - DecoderPromptType.CUSTOM: - zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts), - } - - @pytest.fixture def example_long_prompts() -> list[str]: prompts = [] @@ -287,15 +262,13 @@ _R = TypeVar("_R") class HfRunner: - def get_default_device(self): from vllm.platforms import current_platform - return ("cpu" - if current_platform.is_cpu() else current_platform.device_type) + return "cpu" if current_platform.is_cpu() else current_platform.device_type def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: - if x is None or isinstance(x, (bool, )): + if x is None or isinstance(x, (bool,)): return x if device is None: @@ -320,6 +293,38 @@ class HfRunner: is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, + ) -> None: + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) + ) + + with init_ctx: + self._init( + model_name=model_name, + dtype=dtype, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + is_sentence_transformer=is_sentence_transformer, + is_cross_encoder=is_cross_encoder, + skip_tokenizer_init=skip_tokenizer_init, + auto_cls=auto_cls, + ) + + def _init( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, ) -> None: model_name = maybe_model_redirect(model_name) self.model_name = model_name @@ -367,14 +372,15 @@ class HfRunner: ) # in case some unquantized custom models are not in same dtype - if (getattr(model, "quantization_method", None) is None - and any(p.dtype != self.dtype - for p in 
model.parameters())): + if getattr(model, "quantization_method", None) is None and any( + p.dtype != self.dtype for p in model.parameters() + ): model = model.to(dtype=self.dtype) - if (getattr(model, "quantization_method", None) != "bitsandbytes" - and len({p.device - for p in model.parameters()}) < 2): + if ( + getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device for p in model.parameters()}) < 2 + ): model = model.to(device=self.device) self.model = model @@ -389,6 +395,7 @@ class HfRunner: # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, @@ -399,11 +406,11 @@ class HfRunner: def get_inputs( self, - prompts: list[str], + prompts: Union[list[str], list[list[int]]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> list[Union[BatchFeature, BatchEncoding]]: + ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]: if images is not None: assert len(prompts) == len(images) @@ -413,31 +420,48 @@ class HfRunner: if audios is not None: assert len(prompts) == len(audios) - all_inputs: list[Union[BatchFeature, BatchEncoding]] = [] + all_inputs: list[ + Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]] + ] = [] for i, prompt in enumerate(prompts): - processor_kwargs: dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } - if images is not None and (image := images[i]) is not None: - processor_kwargs["images"] = image - if videos is not None and (video := videos[i]) is not None: - processor_kwargs["videos"] = video - if audios is not None and (audio_inputs := audios[i]) is not None: - # HACK - not all processors take sampling_rate; we should - # clean this up in the future. - if len(audio_inputs) == 2: - audio, sr = audio_inputs - processor_kwargs["audio"] = audio - processor_kwargs["sampling_rate"] = sr - else: - processor_kwargs["audio"] = audio_inputs + if isinstance(prompt, str): + processor_kwargs: dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and (image := images[i]) is not None: + processor_kwargs["images"] = image + if videos is not None and (video := videos[i]) is not None: + processor_kwargs["videos"] = video + if audios is not None and (audio_inputs := audios[i]) is not None: + # HACK - not all processors take sampling_rate; we should + # clean this up in the future. + if len(audio_inputs) == 2: + audio, sr = audio_inputs + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + else: + processor_kwargs["audio"] = audio_inputs - inputs = self.processor(**processor_kwargs) - if isinstance(inputs, BatchFeature): - inputs = inputs.to(dtype=self.dtype) - - all_inputs.append(inputs) + inputs = self.processor(**processor_kwargs) + if isinstance(inputs, BatchFeature): + inputs = inputs.to(dtype=self.dtype) + all_inputs.append(inputs) + else: + # check that prompt is (batched) list of integers (token ids) + if not is_list_of(prompt, typ=int, check="all"): + raise ValueError( + "Prompt must be a list of ints corresponding to the prompt token ids." + ) + # check that no multimodal input is provided + if images or videos or audios: + raise ValueError( + "When providing prompt token ids multimodal inputs are not supported." 
+ ) + input_dict = { + "input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0), + } + all_inputs.append(input_dict) return all_inputs @@ -454,11 +478,10 @@ class HfRunner: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] + problem_type = getattr(self.config, "problem_type", "") + for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) - - problem_type = getattr(self.config, "problem_type", "") - if problem_type == "regression": logits = output.logits[0].tolist() elif problem_type == "multi_label_classification": @@ -471,16 +494,15 @@ class HfRunner: def generate( self, - prompts: list[str], + prompts: Union[list[str], list[list[int]]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: @@ -500,23 +522,24 @@ class HfRunner: def generate_greedy( self, - prompts: list[str], + prompts: Union[list[str], list[list[int]]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - images=images, - videos=videos, - audios=audios, - **kwargs) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_beam_search( self, @@ -527,21 +550,22 @@ class HfRunner: videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - num_beams=beam_width, - num_return_sequences=beam_width, - images=images, - videos=videos, - audios=audios) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios, + ) for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): output_ids[j] = [ - x for x in output_ids[j] - if x != self.tokenizer.pad_token_id + x for x in output_ids[j] if x != self.tokenizer.pad_token_id ] outputs[i] = (output_ids, output_str) return outputs @@ -555,10 +579,9 @@ class HfRunner: audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[list[torch.Tensor]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: @@ -571,8 +594,7 @@ class HfRunner: return_dict_in_generate=True, **kwargs, ) - seq_logprobs = self._hidden_states_to_seq_logprobs( - output.hidden_states) + seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states) all_logprobs.append(seq_logprobs) return all_logprobs @@ -602,7 +624,7 @@ class HfRunner: def 
_hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: int, + num_logprobs: Optional[int], ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -630,16 +652,15 @@ class HfRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[dict[int, float]]] = [] all_output_ids: list[list[int]] = [] @@ -659,8 +680,7 @@ class HfRunner: ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.hidden_states, - num_logprobs) + ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] @@ -670,81 +690,16 @@ class HfRunner: all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] - def generate_encoder_decoder_greedy_logprobs_limit( - self, - encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], - max_tokens: int, - num_logprobs: int, - images: Optional[PromptImageInput] = None, - **kwargs: Any, - ) -> list[TokensTextLogprobs]: - ''' - Greedy logprobs generation for vLLM encoder/decoder models - ''' - - all_logprobs: list[list[dict[int, float]]] = [] - all_output_ids: list[list[int]] = [] - all_output_strs: list[str] = [] - - for i, (encoder_prompt, decoder_prompt) in enumerate( - to_enc_dec_tuple_list(encoder_decoder_prompts)): - processor_kwargs: dict[str, Any] = { - "text": encoder_prompt, - "return_tensors": "pt", - } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - - encoder_inputs = self.processor(**processor_kwargs) - encoder_inputs = self.wrap_device(encoder_inputs) - - if decoder_prompt is None: - decoder_input_ids = None - else: - decoder_inputs = self.tokenizer(decoder_prompt, - return_tensors="pt") - decoder_input_ids = self.wrap_device(decoder_inputs.input_ids) - - output = self.model.generate( - decoder_input_ids=decoder_input_ids, - use_cache=True, - do_sample=False, - max_new_tokens=max_tokens, - output_hidden_states=True, - return_dict_in_generate=True, - **encoder_inputs, - **kwargs, - ) - - ( - seq_logprobs_lst, - output_len, - ) = self._hidden_states_to_logprobs(output.decoder_hidden_states, - num_logprobs) - - all_logprobs.append(seq_logprobs_lst) - seq_ids = output.sequences[0] - output_ids = seq_ids[-output_len:] - all_output_ids.append(output_ids.tolist()) - all_output_strs.append(self.tokenizer.decode(output_ids)) - - outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - - def encode(self, prompts: list[str], *args, - **kwargs) -> list[list[torch.Tensor]]: + def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: return 
self.model.encode(prompts, *args, **kwargs) - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - return self.model.predict(prompts, - *args, - convert_to_tensor=True, - **kwargs) + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs) def __enter__(self): return self @@ -791,44 +746,68 @@ class VllmRunner: enable_chunked_prefill: Optional[bool] = False, swap_space: int = 4, enforce_eager: Optional[bool] = False, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, **kwargs, ) -> None: - self.llm = LLM( - model=model_name, - runner=runner, - convert=convert, - tokenizer=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - dtype=dtype, - seed=seed, - swap_space=swap_space, - enforce_eager=enforce_eager, - disable_log_stats=disable_log_stats, - tensor_parallel_size=tensor_parallel_size, - max_model_len=max_model_len, - block_size=block_size, - enable_chunked_prefill=enable_chunked_prefill, - **kwargs, + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) ) + if not kwargs.get("compilation_config", None): + # Note(@tdoublep): This is set to 4 because some tests (e.g., hybrid + # model tests) may set max_num_seqs=4. If min cudagraph_capture_size is + # set to larger than max_num_seqs, then it will lead to *no* graphs + # being captured which can trigger edge cases that we don't handle yet. + kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]} + + with init_ctx: + self.llm = LLM( + model=model_name, + runner=runner, + convert=convert, + tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + dtype=dtype, + seed=seed, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + def get_inputs( self, - prompts: Union[list[str], list[torch.Tensor], list[int]], + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> list[TextPrompt]: - - if any(x is not None and len(x) != len(prompts) - for x in [images, videos, audios]): + ) -> list[dict[str, Any]]: + if any( + x is not None and len(x) != len(prompts) for x in [images, videos, audios] + ): raise ValueError( - "All non-None multimodal inputs must have the same length as " - "prompts") + "All non-None multimodal inputs must have the same length as prompts" + ) - inputs = [] + inputs = list[dict[str, Any]]() for i, prompt in enumerate(prompts): - multi_modal_data = {} + prompt_dict = dict[str, Any]() + if isinstance(prompt, str): + prompt_dict["prompt"] = prompt + elif isinstance(prompt, list): + prompt_dict["prompt_token_ids"] = prompt + else: + prompt_dict["prompt_embeds"] = prompt + + multi_modal_data = dict[str, Any]() if images is not None and (image := images[i]) is not None: multi_modal_data["image"] = image if videos is not None and (video := videos[i]) is not None: @@ -836,37 +815,27 @@ class VllmRunner: if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - text_prompt_kwargs: dict[str, Any] = { - "multi_modal_data": 
multi_modal_data or None - } - if isinstance(prompt, str): - text_prompt_kwargs["prompt"] = prompt - elif isinstance(prompt, list): - text_prompt_kwargs["prompt_token_ids"] = prompt - else: - text_prompt_kwargs["prompt_embeds"] = prompt + if multi_modal_data: + prompt_dict["multi_modal_data"] = multi_modal_data - inputs.append(TextPrompt(**text_prompt_kwargs)) + inputs.append(prompt_dict) return inputs def generate( self, - prompts: Union[list[str], list[torch.Tensor]], + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.llm.generate(inputs, - sampling_params=sampling_params, - **kwargs) + req_outputs = self.llm.generate( + inputs, sampling_params=sampling_params, **kwargs + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for req_output in req_outputs: @@ -893,8 +862,9 @@ class VllmRunner: output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs, - req_output.prompt_logprobs)) + outputs.append( + (output_ids, output_str, output_logprobs, req_output.prompt_logprobs) + ) return outputs def generate_w_logprobs( @@ -905,47 +875,26 @@ class VllmRunner: audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.llm.generate(inputs, - sampling_params=sampling_params, - **kwargs) + req_outputs = self.llm.generate( + inputs, sampling_params=sampling_params, **kwargs + ) - toks_str_logsprobs_prompt_logprobs = ( - self._final_steps_generate_w_logprobs(req_outputs)) + toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs( + req_outputs + ) # Omit prompt logprobs if not required by sampling params - return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.prompt_logprobs is None else - toks_str_logsprobs_prompt_logprobs) - - def generate_encoder_decoder_w_logprobs( - self, - encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], - sampling_params: SamplingParams, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - ''' - Logprobs generation for vLLM encoder/decoder models - ''' - - assert sampling_params.logprobs is not None - req_outputs = self.llm.generate(encoder_decoder_prompts, - sampling_params=sampling_params) - toks_str_logsprobs_prompt_logprobs = ( - self._final_steps_generate_w_logprobs(req_outputs)) - # Omit prompt logprobs if not required by sampling params - return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.prompt_logprobs is None else - toks_str_logsprobs_prompt_logprobs) + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) def generate_greedy( self, - 
prompts: Union[list[str], list[torch.Tensor]], + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, @@ -953,20 +902,21 @@ class VllmRunner: **kwargs: Any, ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, - greedy_params, - images=images, - videos=videos, - audios=audios, - **kwargs) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + outputs = self.generate( + prompts, + greedy_params, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, @@ -974,45 +924,52 @@ class VllmRunner: stop_token_ids: Optional[list[int]] = None, stop: Optional[list[str]] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids, - stop=stop) - - return self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos, - **kwargs) - - def generate_encoder_decoder_greedy_logprobs( - self, - encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], - max_tokens: int, - num_logprobs: int, - num_prompt_logprobs: Optional[int] = None, - skip_special_tokens: bool = True, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - greedy_logprobs_params = SamplingParams( - temperature=0.0, - max_tokens=max_tokens, - logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), - skip_special_tokens=skip_special_tokens, + stop=stop, ) - ''' - Greedy logprobs generation for vLLM encoder/decoder models - ''' - return self.generate_encoder_decoder_w_logprobs( - encoder_decoder_prompts, greedy_logprobs_params) + return self.generate_w_logprobs( + prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos, + **kwargs, + ) + + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs( + prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0 + ) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities def generate_beam_search( self, @@ -1022,15 +979,15 @@ class VllmRunner: images: Optional[PromptImageInput] = None, videos: 
Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) outputs = self.llm.beam_search( inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens), + concurrency_limit=concurrency_limit, + ) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] @@ -1042,17 +999,16 @@ class VllmRunner: req_outputs = self.llm.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] - def embed(self, - prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - *args, - **kwargs) -> list[list[float]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + def embed( + self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs, + ) -> list[list[float]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] @@ -1076,17 +1032,10 @@ class VllmRunner: return [req_output.outputs.score for req_output in req_outputs] def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - if hasattr(self.llm.llm_engine, "model_executor"): - # This works either in V0 or in V1 with - # VLLM_ENABLE_V1_MULTIPROCESSING=0 - executor = self.llm.llm_engine.model_executor - return executor.apply_model(func) + return self.llm.apply_model(func) - # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1 - def _apply_model(self): - return func(self.get_model()) - - return self.llm.llm_engine.collective_rpc(_apply_model) + def get_llm(self) -> LLM: + return self.llm def __enter__(self): return self @@ -1104,6 +1053,7 @@ def vllm_runner(): @pytest.fixture() def temporary_enable_log_propagate(): import logging + logger = logging.getLogger("vllm") logger.propagate = True yield @@ -1123,6 +1073,7 @@ def num_gpus_available(): in current process.""" from vllm.platforms import current_platform + return current_platform.device_count() @@ -1136,12 +1087,11 @@ _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding") def dummy_opt_path(): json_path = os.path.join(_dummy_opt_path, "config.json") if not os.path.exists(_dummy_opt_path): - snapshot_download(repo_id="facebook/opt-125m", - local_dir=_dummy_opt_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="facebook/opt-125m", + local_dir=_dummy_opt_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1155,12 +1105,18 @@ def dummy_opt_path(): def dummy_llava_path(): json_path = os.path.join(_dummy_llava_path, "config.json") if not os.path.exists(_dummy_llava_path): - snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf", - local_dir=_dummy_llava_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + 
repo_id="llava-hf/llava-1.5-7b-hf", + local_dir=_dummy_llava_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1174,12 +1130,18 @@ def dummy_llava_path(): def dummy_gemma2_embedding_path(): json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") if not os.path.exists(_dummy_gemma2_embedding_path): - snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", - local_dir=_dummy_gemma2_embedding_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1192,10 +1154,9 @@ def dummy_gemma2_embedding_path(): # Add the flag `--optional` to allow run tests # that are marked with @pytest.mark.optional def pytest_addoption(parser): - parser.addoption("--optional", - action="store_true", - default=False, - help="run optional test") + parser.addoption( + "--optional", action="store_true", default=False, help="run optional test" + ) def pytest_collection_modifyitems(config, items): @@ -1218,3 +1179,118 @@ def cli_config_file(): def cli_config_file_with_model(): """Return the path to the CLI config file with model.""" return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml") + + +class AssetHandler(http.server.BaseHTTPRequestHandler): + # _IMAGE_CACHE : Dict[str, bytes] = {} + + def log_message(self, *args, **kwargs): + pass + + def do_GET(self): + # Accepts paths like: /1280px-Venn_diagram_rgb.jpg + filename = self.path.lstrip("/") + if not filename or "." 
not in filename: + self.send_error(404, "Missing filename (expected /<name>.<ext>)") + return + + base, ext = filename.rsplit(".", 1) + ext = ext.lower() + + if ext not in ["jpg", "png"]: + self.send_error(404, f"Unsupported extension: .{ext}") + return + + try: + data = ImageAsset(base).read_bytes(ext=ext) + except Exception as e: + self.send_error(500, f"Failed to load asset: {ext} {base} {e} ") + return + + ctype, _ = mimetypes.guess_type(filename) + if ctype is None: + ctype = {"jpg": "image/jpeg", "png": "image/png"}[ext] + self.send_response(200) + self.send_header("Content-Type", ctype) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + +def _find_free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +class LocalAssetServer: + address: str + port: int + server: Optional[http.server.ThreadingHTTPServer] + thread: Optional[threading.Thread] + + def __init__(self, address: str = "127.0.0.1") -> None: + self.address = address + self.port = -1 + self.server = None + self.thread = None + + def __enter__(self): + self.port = _find_free_port() + self.server = http.server.ThreadingHTTPServer( + (self.address, self.port), AssetHandler + ) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.server: + self.server.shutdown() + del self.server + + if self.thread: + self.thread.join() + del self.thread + + if exc_type is None: + return None + + return False + + @property + def base_url(self) -> str: + assert self.port is not None + return f"http://{self.address}:{self.port}" + + def url_for(self, name: str) -> str: + """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'""" + return f"{self.base_url}/{name}" + + def get_image_asset(self, name: str) -> Image.Image: + return fetch_image(self.url_for(name)) + + +@pytest.fixture(scope="session") +def local_asset_server() -> Generator[LocalAssetServer, None, None]: + """ + Starts a thread-based HTTP server bound to 127.0.0.1 on a random free port. + The server currently serves images at: + http://127.0.0.1:<port>/<name>.<ext> + """ + with LocalAssetServer() as srv: + yield srv + + +@pytest.fixture +def image_url(request, local_asset_server) -> str: + # request.param is one of the IMAGE_ASSETS filenames + name = request.param + return local_asset_server.url_for(name) + + +@pytest.fixture +def image_urls(request, local_asset_server) -> list[str]: + """Indirect fixture: takes a list of names, returns list of full URLs.""" + names: list[str] = request.param + return [local_asset_server.url_for(name) for name in names] diff --git a/tests/core/block/conftest.py b/tests/core/block/conftest.py deleted file mode 100644 index 6afe98d78ce81..0000000000000 --- a/tests/core/block/conftest.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - - -@pytest.fixture() -def should_do_global_cleanup_after_test() -> bool: - """Disable the global cleanup fixture for tests in this directory. This - provides a ~10x speedup for unit tests that don't load a model to GPU. - - This requires that tests in this directory clean up after themselves if they - use the GPU.
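Editorial aside, not part of the patch: a sketch of how a test could consume the `local_asset_server` and `image_url` fixtures defined above through pytest's indirect parametrization. The asset name "cherry_blossom.jpg" is assumed for illustration only; any name resolvable by ImageAsset with a .jpg/.png extension would work the same way.

import pytest


# "cherry_blossom.jpg" is a hypothetical asset name used only to illustrate
# the /<name>.<ext> URL scheme served by LocalAssetServer.
@pytest.mark.parametrize("image_url", ["cherry_blossom.jpg"], indirect=True)
def test_image_is_served(image_url: str, local_asset_server):
    # The indirect fixture turns the bare filename into a full local URL.
    assert image_url.endswith("/cherry_blossom.jpg")
    # Fetching through the helper returns a PIL image with non-zero size.
    image = local_asset_server.get_image_asset("cherry_blossom.jpg")
    assert image.width > 0 and image.height > 0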
- """ - return False diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py deleted file mode 100644 index e2c6c66b259c8..0000000000000 --- a/tests/core/block/e2e/conftest.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Iterable -from typing import Callable, Optional - -import pytest - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.model_executor.utils import set_random_seed - - -@pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed) - - -@pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) - - -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): - kwargs = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **distinct_llm_kwargs, - } - - def generator_inner(): - llm = LLM(**kwargs) - - set_random_seed(seed) - - yield llm - del llm - cleanup_dist_env_and_memory() - - for llm in generator_inner(): - yield llm - del llm - - -def get_text_from_llm_generator(llm_generator: Iterable[LLM], - prompts, - sampling_params, - llm_cb: Optional[Callable[[LLM], - None]] = None): - for llm in llm_generator: - if llm_cb: - llm_cb(llm) - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - text = [output.outputs[0].text for output in outputs] - del llm - - return text - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py deleted file mode 100644 index 93222b564ebe7..0000000000000 --- a/tests/core/block/e2e/test_correctness.py +++ /dev/null @@ -1,479 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from itertools import cycle - -import pytest - -from vllm import SamplingParams - -from .conftest import get_token_ids_from_llm_generator - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_block_manager_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager produces same outputs even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. 
- - If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted. - - NOTE: We want a significant number of generated tokens so that any incorrect - KV mapping has time to build up error. - - NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we - keep this test. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # Our prompts will generate 128 tokens; since the prompts themselves are - # small, we don't need much KV space beyond 128. - "max_model_len": 160, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - "block_size": 16, - - # Allow only 2 sequences of ~128 tokens in worst case. - # Note 8 = 128/block_size - "num_gpu_blocks_override": 2 * (8 + 1), - }, - { - "block_size": 8, - - # Allow only 2 sequences of ~128 tokens in worst case. - # Note 16 = 128/block_size - "num_gpu_blocks_override": 2 * (16 + 2), - } - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "num_lookahead_slots": 0, -}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - # We run one test with block_size < lookahead_slots, one test with - # block_size > lookahead_slots - "num_lookahead_slots": 10, - "preemption_mode": "swap", - }, - { - "num_lookahead_slots": 10, - "preemption_mode": "recompute", - } - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size): - """Verify vLLM produces the same output with greedy sampling, when lookahead - scheduling is used vs. not. - - Lookahead scheduling is not expected to modify the output, as it simply - allocates empty slots ahead of the known token ids in a sliding fashion. - - This test constrains the total number of blocks to force preemption. It also - varies the block size so that the lookahead size is less than and greater - than the block size. 
- """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids without lookahead scheduling') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with lookahead scheduling') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [ - { - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "enable_chunked_prefill": True, - }, - ]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", - [{ - "block_size": 16, - "max_num_batched_tokens": 2, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 3, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 256, - "max_num_seqs": 10, - }]) -@pytest.mark.parametrize("baseline_llm_kwargs", [ - {}, -]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "num_lookahead_slots": 0, - }, - { - "num_lookahead_slots": 5, - }, -]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify that chunked prefill works with SelfAttnBlockSpaceManager, - with and without lookahead scheduling. - """ - output_len = 32 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - ("1 + " * 50) + " 1 = ", # Longer prompt. - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with BlockManager') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with BlockManager, with lookahead slots.') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. 
- "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - - # Enable prefill cache - "enable_prefix_caching": True, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_block_manager_prefix_caching_enabled_with_preemption( - baseline_llm_generator, test_llm_generator, batch_size): - """Verify block manager produces same outputs even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. - - If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted. - - NOTE: We want a significant number of generated tokens so that any incorrect - KV mapping has time to build up error. - - NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we - keep this test. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids from block manager') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids from block manager, with preemption') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, - "preemption_mode": "swap" -}, { - "enable_prefix_caching": True, - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager v2 with auto prefix caching enabled produces same - outputs as auto prefix caching disabled, even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. 
The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. - - If the output token ids are equivalent, then we have confidence that auto - prefix caching itself at least don't cause result error. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with APC disabled') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # we keep the blocks small, so that hit eviction quickly - "max_model_len": 48, - "block_size": 16, - "num_gpu_blocks_override": 3, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, -}]) -@pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator, - test_llm_generator): - """Verify block manager v2 with auto prefix caching could works normal - even when eviction started. - With APC enabled, all blocks are held by native block at the beginning. - Then blocks are managed by evictor instead. If cache hit at the evitor's - block, then it could be reused, or we need to recompute its kv cache. - """ - output_len = 10 - temperature = 0.0 - - prompts = [ - "You are a helpful assistant. Please answer truthfully and write " - "out your thinking step by step to be sure you get the right answer. " - "If you make a mistake, attempt to correct it. who are you?", - "You are a helpful assistant. Please answer truthfully and write out " - "your thinking step by step to be sure you get the right answer. You " - "are helpful and harmless and you follow ethical guidelines. " - "who are you?" 
- ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with APC disabled') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py deleted file mode 100644 index 27fe27a880e3d..0000000000000 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random - -import pytest - -from tests.kernels.utils import override_backend_env_variable -from vllm import LLM, SamplingParams -from vllm.platforms import current_platform - -from .conftest import get_text_from_llm_generator - -# relatively small model with 4k sliding window -MODEL = "bigcode/starcoder2-3b" -BLOCK_SIZE = 16 - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [5]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) -def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, - batch_size, seed, backend, monkeypatch): - """ - The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then - asks for value of one of them (which is outside the sliding window). - If we tell it upfront which we are going to be looking for, then - it answers correctly (mostly). - - Additionally, we compare the results of the v1 and v2 managers. 
- """ - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - - override_backend_env_variable(monkeypatch, backend) - - sampling_params = SamplingParams( - max_tokens=1024, - ignore_eos=True, - temperature=0.0, - ) - - prompts, answer, indices = prep_prompts(batch_size) - - baseline_texts = get_text_from_llm_generator(baseline_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) - - check_answers(indices, answer, baseline_texts) - - print('Getting token ids from block manager v2') - test_texts = get_text_from_llm_generator(test_llm_generator, prompts, - sampling_params) - check_answers(indices, answer, test_texts) - - cmp = [ - expected_text == actual_text - for expected_text, actual_text in zip(baseline_texts, test_texts) - ] - print(cmp) - # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 - # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 - # states that xformers and flash_attn have different ideas about the window - # size anyways - assert sum(cmp) > 0.7 * len(cmp) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) -@pytest.mark.parametrize("batch_size", [5]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, - backend, monkeypatch): - """ - This is similar to test_sliding_window_retrieval, however, it doesn't - compare against the v1 block manager since v1 doesn't support - chunked prefill with sliding window. - - The results with and without chunked prefill are not the same due to - numerical instabilities. - """ - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - override_backend_env_variable(monkeypatch, backend) - - sampling_params = SamplingParams( - max_tokens=10, - ignore_eos=True, - temperature=0.0, - ) - - prompts, answer, indices = prep_prompts(batch_size) - - # We don't compare with the baseline model here, since the results - # slightly different due to different tailing in attention. - test_texts = get_text_from_llm_generator(test_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) - check_answers(indices, answer, test_texts) - - -def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): - """ - Generate prompts which a bunch of assignments, - then asking for the value of one of them. - The prompt is just under 10k tokens; sliding window is 4k - so the answer is outside sliding window, but should still be correct. 
- - Args: - batch_size: number of prompts to generate - ln_range: an argument to control the length of the prompt - """ - prompts: list[str] = [] - answer: list[int] = [] - indices: list[int] = [] - random.seed(1) - for _ in range(batch_size): - idx = random.randint(30, 90) - indices.append(idx) - prompt = "```python\n# We set a number of variables, " + \ - f"x{idx} will be important later\n" - ln = random.randint(*ln_range) - for k in range(30, ln): - v = random.randint(10, 99) - if k == idx: - answer.append(v) - prompt += f"x{k} = {v}\n" - prompt += f"# Now, we check the value of x{idx}:\n" - prompt += f"assert x{idx} == " - prompts.append(prompt) - return prompts, answer, indices - - -def check_answers(indices: list[int], - answer: list[int], - outputs: list[str], - accept_rate: float = 0.7): - answer2 = [int(text[0:2].strip()) for text in outputs] - print(list(zip(indices, zip(answer, answer2)))) - numok = 0 - for a1, a2 in zip(answer, answer2): - if a1 == a2: - numok += 1 - frac_ok = numok / len(answer) - print(f"Num OK: {numok}/{len(answer)} {frac_ok}") - assert frac_ok >= accept_rate - - -def check_window(prompts: list[str]): - - def inner(llm: LLM): - sliding_window = llm.llm_engine.model_config.get_sliding_window() - assert sliding_window and sliding_window > 0 - assert any( - len(llm.get_tokenizer().tokenize(prompt)) > sliding_window - for prompt in prompts) - - return inner diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py deleted file mode 100644 index 9eed264fd7d43..0000000000000 --- a/tests/core/block/test_block_manager.py +++ /dev/null @@ -1,494 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager import SelfAttnBlockSpaceManager -from vllm.core.interfaces import AllocStatus -from vllm.sequence import Logprob, SequenceStatus -from vllm.utils import chunk_list - -from ..utils import (create_dummy_prompt, create_seq_group, - create_seq_group_encoder_decoder) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. 
- num_output_blocks = num_output_blocks_per_seq - - for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): - seq_group = create_seq_group( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - ) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + num_output_blocks - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group_encoder_decoder(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. - num_output_blocks = num_output_blocks_per_seq - - for bdx, num_prompt_blocks in enumerate( - range(1, num_gpu_blocks - num_output_blocks)): - num_cross_blocks_per_seq = num_prompt_blocks - - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id=str(bdx)) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + \ - num_output_blocks + \ - num_cross_blocks_per_seq - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16]) -@pytest.mark.parametrize("num_seqs_per_group", [1]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): - ''' - SWA short for Sliding Window Attention. - - At time of writing block manager does not support SWA. - - However even when SWA is implemented for block manager, - there will still most likely be a separate workstream required - to enable SWA for encoder/decoder models. - - Therefore this test enforces that one of the following cases - hold true: - 1. Block manager does not support SWA at all (true at time of writing) - 2. Block manager fails with NotImplementError when SWA is enabled - AND a SequenceGroup with an encoder sequence (i.e. 
in support of an - encoder/decoder model) is passed into can_allocate() as an argument - - The setup for this test is stripped down version of - test_can_allocate_seq_group_encoder_decoder() - ''' - - with pytest.raises((NotImplementedError, AssertionError)) as exc_info: - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - sliding_window=5 # SWA - ) - - num_output_blocks_per_seq = 1 - num_prompt_blocks = 1 - num_output_blocks = num_output_blocks_per_seq - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id="0") - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - block_manager.can_allocate(seq_group) - - # Assert that either - # 1. Block manager constructor fails with assertion that sliding window - # is not yet supported (most likely near-term outcome at time of - # writing), or - # 2. can_allocate() fails with NotImplementedError due to combination of - # encoder/decoder and sliding window attention - if isinstance(exc_info.value, NotImplementedError): - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - elif isinstance(exc_info.value, AssertionError): - assert str(exc_info.value) == "Sliding window not yet supported" - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16]) -@pytest.mark.parametrize("num_seqs_per_group", [1]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_encoder_decoder_fails_with_prefix_cache( - block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, - watermark: float): - - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - enable_caching=True # Prefix cache - ) - - num_output_blocks_per_seq = 1 - num_prompt_blocks = 1 - num_output_blocks = num_output_blocks_per_seq - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id="0") - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - # Assert that either can_allocate() fails with NotImplementedError - # due to combination of encoder/decoder and prefix cache - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("prompt_len", [1, 7, 8]) -@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) -@pytest.mark.parametrize("num_lookahead_slots", [0, 10]) -def test_append_slots(block_size, prompt_len, num_slots_to_append, - num_lookahead_slots): - """Verify append_slots consumes the correct number of blocks from the block - table. 
- """ - - num_gpu_blocks = 1024 - watermark = 0.1 - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - watermark=watermark, - ) - - seq_group = create_seq_group( - seq_prompt_len=prompt_len, - seq_output_lens=[0], - ) - - # Allocate seq - assert block_manager.can_allocate(seq_group) - block_manager.allocate(seq_group) - - # Seq seq to RUNNING - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - # Append tokens to the sequeqnce - for token_id in range(num_slots_to_append): - seq.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Append slots for new tokens and lookahead slots. - free_blocks_before_append = block_manager.get_num_free_gpu_blocks() - block_manager.append_slots(seq, num_lookahead_slots) - num_consumed_blocks = (free_blocks_before_append - - block_manager.get_num_free_gpu_blocks()) - - # Expect consumed blocks to be new blocks required to support the new slots. - expected_consumed_blocks = len( - list( - chunk_list( - list( - range(prompt_len + num_slots_to_append + - num_lookahead_slots)), - block_size))) - len( - list(chunk_list(list(range(prompt_len)), block_size))) - assert num_consumed_blocks == expected_consumed_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("num_cpu_blocks", [4]) -@pytest.mark.parametrize("num_gpu_blocks", [4]) -@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) -@pytest.mark.parametrize("enable_caching", [False, True]) -def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, - enable_caching): - """Verify blocks number on src/desc device is correct after swapping in/out - sequence group (not missing or extra blocks). - """ - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. - gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. 
- assert block_manager.can_swap_in(seq_group, num_lookahead_slots) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - cpu_blocks = block_manager.get_block_table(prompt) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == [cpu_blocks[0]] - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("num_gpu_blocks", [4]) -@pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10]) -@pytest.mark.parametrize("enable_caching", [True, False]) -def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, - enable_caching): - """ Verify the block manager can correctly determine if a sequence group - can be swapped in/out. - """ - num_cpu_blocks = num_gpu_blocks - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - prompt.status = SequenceStatus.RUNNING - - # Swap seq group from GPU -> CPU. - gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # At this moment, we still have enough free blocks to swap in the seq group. - if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - # During Swapped out, 2 cached blocks were evicted from the GPU, - # so the prompt1 can't be swapped in - prompt2_len = 2 * block_size - 1 - prompt2, seq_group2 = create_dummy_prompt( - "2", - prompt_length=prompt2_len, - prompt_tokens=[10000 + i for i in range(prompt2_len)]) - prompt2.status = SequenceStatus.WAITING - block_manager.allocate(seq_group2) - - # Swap seq group from CPU -> GPU. - if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.LATER - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - -@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) -@pytest.mark.parametrize("enable_caching", [False, True]) -def test_swap_in_infeasible(num_lookahead_slots, enable_caching): - """Verifies that swapping fails if there is not enough free blocks - to account for unseen tokens and lookahead_slots. 
- """ - block_size = 8 - num_cpu_blocks = 1 - num_gpu_blocks = 1 - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt_length = block_size - 3 - assert prompt_length > 0 - prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. - assert block_manager.can_swap_out(seq_group) - block_manager.swap_out(seq_group) - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. - # The number of unseen tokens is 1. If the number of existing - # tokens plus the unseen ones and number of lookahead slots exceeds - # the total number of available GPU blocks then the swap - # should fail. - num_unseen_tokens = 1 - if (num_lookahead_slots + num_unseen_tokens + - prompt_length) <= (block_size * num_gpu_blocks): - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - -# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. - - -@pytest.mark.parametrize("block_size", [8, 16]) -@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) -@pytest.mark.parametrize("num_slots_to_append", [50]) -@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) -def test_sliding_window(block_size, prompt_len, num_slots_to_append, - sliding_window): - """Verify append_slots consumes the correct number of blocks from the block - table. 
- """ - - num_gpu_blocks = 1024 - watermark = 0.1 - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - watermark=watermark, - sliding_window=sliding_window, - ) - - def check_used(min_n, max_n=None): - if max_n is None: - max_n = min_n - used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() - assert min_n <= used - assert used <= max_n - - def num_blocks(num_tokens): - return (num_tokens + block_size - 1) // block_size - - check_used(0) - - seq_group = create_seq_group( - seq_prompt_len=prompt_len, - seq_output_lens=[0], - ) - - check_used(0) - - # Allocate seq - assert block_manager.can_allocate(seq_group) - block_manager.allocate(seq_group) - - check_used(num_blocks(prompt_len)) - - # Seq seq to RUNNING - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - seq.data.update_num_computed_tokens(prompt_len) - check_used(num_blocks(prompt_len)) - - # this is how we compute it in SelfAttnBlockSpaceManager.__init__ - sliding_blocks = (sliding_window // block_size) + 2 - # plus one block for null block - sliding_blocks += 1 - - # Append tokens to the sequeqnce - for token_id in range(num_slots_to_append): - seq.append_token_id(token_id, {token_id: Logprob(0.0)}) - seq.data.update_num_computed_tokens(1) - block_manager.append_slots(seq, num_lookahead_slots=0) - if prompt_len < sliding_window + 10: - check_used(0, sliding_blocks + 1) - else: - check_used(sliding_blocks, sliding_blocks + 1) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py deleted file mode 100644 index ba085001136be..0000000000000 --- a/tests/core/block/test_block_table.py +++ /dev/null @@ -1,577 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, cdiv, chunk_list - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -def test_allocate_naive(block_size: int, sequence_len: int): - """Test the allocation of blocks using the naive allocator. - - This test creates a CpuGpuBlockAllocator with the specified block size and - number of blocks. It then allocates multiple BlockTables with varying - sequence lengths and verifies that the number of free blocks decreases as - expected after each allocation. - """ - assert block_size > 1 - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type="naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) - - block_tables: list[BlockTable] = [] - for i in range(5): - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc - - block_tables.append( - BlockTable( - block_size=block_size, - block_allocator=allocator, - )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -def test_allocate_prefix_caching(block_size: int, sequence_len: int): - """Test the allocation of blocks using the prefix caching allocator. 
- - This test creates a CpuGpuBlockAllocator with the specified block size and - number of blocks, using the prefix caching allocator. It then allocates - multiple BlockTables with varying sequence lengths and verifies that the - number of free blocks decreases as expected after each allocation. - - The test expects all sequences to share allocations, except for their last - block, which may be mutable. It calculates the expected number of immutable - and mutable blocks per allocation based on the sequence length and block - size. - """ - assert block_size > 1 - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - chunked_tokens = list(chunk_list(token_ids, block_size)) - num_mutable_blocks_per_alloc = 0 if len( - chunked_tokens[-1]) == block_size else 1 - num_immutable_blocks_per_alloc = len( - chunked_tokens) - num_mutable_blocks_per_alloc - - block_tables: list[BlockTable] = [] - for alloc_i in range(1, 6): - - block_tables.append( - BlockTable( - block_size=block_size, - block_allocator=allocator, - )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) - - # Expect all sequences to share allocations, except for their last block - # (which may be mutable). - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - ( - num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc * - (alloc_i)) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -@pytest.mark.parametrize("device", ["cpu", "gpu"]) -def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, - device: str): - """Test the allocation and freeing of blocks using different allocators and - devices. - - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, allocator type, and device. It then allocates a BlockTable - multiple times with the same sequence and verifies that the number of free - blocks remains consistent after each allocation and freeing. - """ - device = Device[device.upper()] - - num_device_blocks = 1024 - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_device_blocks, - num_cpu_blocks=num_device_blocks, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - for i in range(5): - block_table.allocate(token_ids=token_ids, device=device) - assert allocator.get_num_free_blocks( - device) == num_device_blocks - num_blocks_per_alloc - assert all(block_id is not None - for block_id in block_table.physical_block_ids) - - block_table.free() - assert allocator.get_num_free_blocks(device) == num_device_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_allocation(block_size: int, sequence_len: int, - append_len: int, allocator_type: str): - """Test the allocation behavior when appending token IDs to a BlockTable. 
- - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, and allocator type. It then allocates a BlockTable with an - initial sequence and appends additional token IDs to it. The test verifies - that the number of allocated blocks before and after appending matches the - expected values. - """ - - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + token_ids_to_append, - block_size))) - num_expected_blocks_before_append - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append - block_table.append_token_ids(token_ids_to_append) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("num_empty_slots", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, - num_empty_slots: int, - allocator_type: str): - """Test the allocation behavior when ensuring a certain number of empty - slots in a BlockTable. - - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, and allocator type. It then allocates a BlockTable with an - initial sequence and ensures a certain number of empty slots. The test - verifies that the number of allocated blocks before and after ensuring empty - slots matches the expected values. It also checks that filling up the empty - slots does not consume additional blocks. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + [-1] * num_empty_slots, - block_size))) - num_expected_blocks_before_append - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Assert that the empty slots consume the expected number of additional - # blocks. - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append - block_table.ensure_num_empty_slots(num_empty_slots) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks - - # Now, ensure no additional blocks consumed as we fill up the empty slots. 
- num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) - block_table.append_token_ids(token_ids=list(range(num_empty_slots))) - assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU) - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 9]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("append_size", [1, 4, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_correct_content(block_size: int, sequence_len: int, - append_len: int, allocator_type: str, - append_size: int): - """Verify token ids are correctly appended. Appends various amounts of - token ids in various append sizes, and verifies the final sequence is - correct. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - appended_so_far: list[int] = [] - for append in chunk_list(token_ids_to_append, append_size): - block_table.append_token_ids(append) - appended_so_far.extend(append) - - assert block_table._get_all_token_ids() == token_ids + appended_so_far - - assert block_table._get_all_token_ids() == token_ids + token_ids_to_append - - -@pytest.mark.parametrize("seq_len", [1, 9, 129]) -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_fork(seq_len: int, block_size: int, allocator_type: str): - """Create a sequence using the specified allocator. - 1. Assert that after forking the sequence, the free block count is the - same. - 2. Assert that the forked sequence has the same physical mappings. - 3. Then free the original sequence; verify that the free block count is - the same. - 4. Finally, free the forked sequence and verify that the free block - count drops to zero. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(seq_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - block_table.allocate(token_ids) - - num_free_blocks_before_fork = allocator.get_num_free_blocks( - device=Device.GPU) - - forked_block_table = block_table.fork() - - # Expect physical_block_ids and token_ids to match. - assert (block_table.physical_block_ids == - forked_block_table.physical_block_ids) - assert block_table._get_all_token_ids( - ) == forked_block_table._get_all_token_ids() - - # Do not expect any additional allocations. - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork - - # Free the original blocks. Assert num free blocks does not change, since - # refcount is nonzero. - block_table.free() - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork - - # Expect the forked block table to be unaffected by the free. - assert all(block_id is not None - for block_id in forked_block_table.physical_block_ids) - - # Free the forked blocks. Assert num free blocks does change, since - # refcount is now zero. 
- forked_block_table.free() - assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("appender", ["forked", "original"]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow(block_size: int, sequence_len: int, append_len: int, - allocator_type: str, appender: str): - """Fork a sequence; append to the forked sequence; verify there's a CoW. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - original_block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_non_cow_blocks = cdiv(sequence_len, block_size) - num_expected_cow_blocks = cdiv(sequence_len + append_len, - block_size) - (sequence_len // block_size) - - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) - original_block_ids = original_block_table.physical_block_ids[:] - - print("original_block_ids = {}".format(original_block_ids)) - forked_block_table = original_block_table.fork() - - # Expect no additional allocation (copy on _write_). - assert allocator.get_num_free_blocks( - Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks) - - if appender == "forked": - appender_block_table = forked_block_table - static_block_table = original_block_table - elif appender == "original": - appender_block_table = original_block_table - static_block_table = forked_block_table - else: - raise ValueError(f"unknown test config {appender=}") - - # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) - - # Expect the non-appending block table to have no change. - assert static_block_table.physical_block_ids == original_block_ids - assert appender_block_table.physical_block_ids != original_block_ids - - # Expect the blocks changed during append to have a CoW. - assert allocator.get_num_free_blocks( - Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks + - num_expected_cow_blocks) - - cows = allocator.clear_copy_on_writes() - if sequence_len % block_size > 0: - # If the last block in the sequence is not full, then when appending we - # expect a CoW. - assert cows - - cow_block_id = sequence_len // block_size - expected_src = static_block_table.physical_block_ids[cow_block_id] - expected_dst = appender_block_table.physical_block_ids[cow_block_id] - - assert (expected_src, expected_dst) in cows - else: - # Otherwise, there should be no copy-on-write. - assert not cows - - static_block_table.free() - appender_block_table.free() - - # After free, expect all blocks to be freed. - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("lookahead_slots", [1, 16, 129]) -@pytest.mark.parametrize("appender", ["forked", "original"]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow_lookahead_simple(block_size: int, sequence_len: int, - append_len: int, lookahead_slots: int, - allocator_type: str, appender: str): - """Similar to test_cow, except with lookahead allocation. 
The assertions are - less rigorous due to the complexity of the property under test. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - original_block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Allocate lookahead slots. - original_block_table.ensure_num_empty_slots(lookahead_slots) - original_block_ids = original_block_table.physical_block_ids[:] - - forked_block_table = original_block_table.fork() - - if appender == "forked": - appender_block_table = forked_block_table - static_block_table = original_block_table - elif appender == "original": - appender_block_table = original_block_table - static_block_table = forked_block_table - else: - raise ValueError(f"unknown test config {appender=}") - - # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) - - # Expect the non-appending block table to have no change. - assert static_block_table.physical_block_ids == original_block_ids - assert appender_block_table.physical_block_ids != original_block_ids - - cows = allocator.clear_copy_on_writes() - - # Always expect copy-on-write - assert cows - - if sequence_len % block_size > 0: - # If the last block in the sequence is not full, then when appending we - # expect a CoW. - assert cows - - cow_block_id = sequence_len // block_size - expected_src = static_block_table.physical_block_ids[cow_block_id] - expected_dst = appender_block_table.physical_block_ids[cow_block_id] - - assert (expected_src, expected_dst) in cows - - static_block_table.free() - appender_block_table.free() - - # After free, expect all blocks to be freed. - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("num_new_tokens", [1, 16, 129]) -@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, - num_new_tokens: int, - num_lookahead_slots: int, - allocator_type: str): - """Verify correct calculation of get_num_blocks_touched_by_append_slots. - - This is done by using copy-on-write, which requires any modified block to - be copied before write if the refcount > 1. We set the refcount>1 by forking - a sequence, then measure the free blocks before and after an append. If the - number of consumed blocks equals what `get_num_blocks_touched_by_append_ - slots` returns, then the calculation is correct. - """ - - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(num_new_tokens)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Add lookahead before fork so both sequences have the same lookahead - # blocks. - block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots) - - # Fork sequence so that every block has refcount > 1. 
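The copy-on-write tests above hinge on one rule: a block that is shared (refcount > 1) must be copied before it is written, and the allocator records the (src, dst) pairs that `clear_copy_on_writes` later returns. A toy, self-contained model of that rule; all names here are hypothetical and this is not vLLM's allocator:

```python
class ToyCowAllocator:
    """Toy model of copy-on-write bookkeeping (not vLLM's allocator)."""

    def __init__(self, num_blocks: int) -> None:
        self.free_ids = list(range(num_blocks))
        self.refcounts: dict[int, int] = {}
        self.cows: list[tuple[int, int]] = []   # (src, dst) pairs

    def allocate(self) -> int:
        block_id = self.free_ids.pop(0)
        self.refcounts[block_id] = 1
        return block_id

    def share(self, block_id: int) -> None:
        """Simulate fork(): another sequence now references this block."""
        self.refcounts[block_id] += 1

    def write(self, block_id: int) -> int:
        """Return a writable block id, copying the block first if it is shared."""
        if self.refcounts[block_id] == 1:
            return block_id                     # sole owner: write in place
        self.refcounts[block_id] -= 1           # drop this writer's shared reference
        dst = self.allocate()                   # private copy for the writer
        self.cows.append((block_id, dst))
        return dst


alloc = ToyCowAllocator(num_blocks=4)
shared = alloc.allocate()
alloc.share(shared)                             # refcount == 2, as after fork()
writable = alloc.write(shared)
assert writable != shared
assert alloc.cows == [(shared, writable)]       # exactly one CoW was recorded
```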
- _ = block_table.fork() - - # Determine how many blocks should be touched. - expected_num_touched_blocks = ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=token_ids_to_append, - num_lookahead_slots=num_lookahead_slots)) - - # Measure how many blocks are touched by measuring num_free_blocks before - # and after the append. - # - # We expect append_token_ids to CoW all mutated blocks that have refcount>1. - num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) - block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) - num_consumed_blocks = (num_free_blocks_before_append - - allocator.get_num_free_blocks(Device.GPU)) - - # TODO(cade) ensure equality when num_lookahead_slots > 0. - # The reason we have < is because lookahead blocks are not copied eagerly; - # they are copied on first write. This will cause issues for beam search + - # speculative decoding. This is acceptable for now as it is a large effort - # to combine the two. To fix this, we can ensure single sequence ownership - # of lookahead blocks by appending empty slots to each block, which will - # trigger the CoW. - # - # Until then, we can accept that the consumed tokens are <= the expected - # tokens when appending with lookahead. - if num_lookahead_slots > 0: - assert num_consumed_blocks <= expected_num_touched_blocks - else: - assert num_consumed_blocks == expected_num_touched_blocks diff --git a/tests/core/block/test_common.py b/tests/core/block/test_common.py deleted file mode 100644 index 65400899b811c..0000000000000 --- a/tests/core/block/test_common.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random - -import pytest - -from vllm.core.block.common import RefCounter - - -@pytest.mark.parametrize("seed", list(range(20))) -@pytest.mark.parametrize("num_incrs", [1, 100]) -@pytest.mark.parametrize("num_blocks", [1024]) -def test_incr(seed: int, num_incrs: int, num_blocks: int): - random.seed(seed) - - all_block_indices = list(range(num_blocks)) - counter = RefCounter(all_block_indices=all_block_indices) - - block_id = random.randint(0, num_blocks - 1) - for i in range(num_incrs): - value = counter.incr(block_id) - assert value == i + 1 - - -@pytest.mark.parametrize("seed", list(range(20))) -@pytest.mark.parametrize("num_incrs", [1, 100]) -@pytest.mark.parametrize("num_blocks", [1024]) -def test_incr_decr(seed: int, num_incrs: int, num_blocks: int): - random.seed(seed) - - all_block_indices = list(range(num_blocks)) - counter = RefCounter(all_block_indices=all_block_indices) - - block_id = random.randint(0, num_blocks - 1) - for i in range(num_incrs): - value = counter.incr(block_id) - assert value == i + 1 - - for i in range(num_incrs): - value = counter.decr(block_id) - assert value == num_incrs - (i + 1) - - with pytest.raises(AssertionError): - counter.decr(block_id) diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py deleted file mode 100644 index 795eef6743fd1..0000000000000 --- a/tests/core/block/test_cpu_gpu_block_allocator.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, chunk_list - - -@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) -@pytest.mark.parametrize("num_gpu_blocks", 
[1024]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - cpu_blocks = [ - allocator.allocate_mutable_block(prev_block=None, device=Device.CPU) - for _ in range(num_cpu_blocks) - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - gpu_blocks = [ - allocator.allocate_mutable_block(prev_block=None, device=Device.GPU) - for _ in range(num_gpu_blocks) - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in cpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in gpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) -@pytest.mark.parametrize("num_gpu_blocks", [1024]) -@pytest.mark.parametrize("block_size", [2]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - unique_token_ids = list( - range((num_cpu_blocks + num_gpu_blocks) * block_size)) - gpu_token_ids = list( - chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) - cpu_token_ids = list( - chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size)) - - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - cpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.CPU) - for token_ids in cpu_token_ids - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - gpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.GPU) - for token_ids in gpu_token_ids - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in cpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in gpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py deleted file mode 100644 index a31d1c46b37f0..0000000000000 --- a/tests/core/block/test_naive_block.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest - -from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator - - -class TestNaiveBlockAllocator: - - @staticmethod - def create_allocate_lambda(allocate_type: str, - allocator: NaiveBlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): - if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) - elif allocate_type == "mutable": - allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) - else: - raise ValueError() - - return allocate_block - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_ooms(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - [allocate_block() for _ in range(num_blocks)] - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_free_prevents_oom(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - blocks = [allocate_block() for _ in range(num_blocks)] - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - block_to_free = blocks.pop() - - for _ in range(100): - block_id = block_to_free.block_id - allocator.free(block_to_free) - assert block_to_free.block_id is None - - new_block = allocate_block() - assert new_block.block_id == block_id - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - block_to_free = new_block - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - def test_get_num_free_blocks(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - assert allocator.get_num_free_blocks() == num_blocks - - blocks = [allocate_block() for _ in range(num_blocks)] - - for i, block in enumerate(blocks): - assert allocator.get_num_free_blocks() == i - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [4]) - @pytest.mark.parametrize("block_size", [8]) - def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): - """ Verify the allocator can correctly return the number of - full blocks touched. 
- """ - allocator_src = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - - # Create a chain of cacheable blocks in the dst - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - "immutable", - allocator_src, - prev_block=None, - token_ids=list(range(block_size))) - src_blocks = [allocate_block() for _ in range(num_blocks - 1)] - - # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 - - # Insert one non-full block in the src - allocate_non_full_block = \ - TestNaiveBlockAllocator.create_allocate_lambda( - "mutable", allocator_src, - prev_block=src_blocks[-1],token_ids=[] - ) - src_blocks.append(allocate_non_full_block()) - src_blocks[-1].append_token_ids([0]) - - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 - # Fill up the last source block and then invoke - # get_num_blocks_touched - src_blocks[-1].append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py deleted file mode 100644 index 46e224c6f53b2..0000000000000 --- a/tests/core/block/test_prefix_caching_block.py +++ /dev/null @@ -1,1035 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -import random -from typing import Optional -from unittest.mock import MagicMock - -import pytest - -from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - PrefixCachingBlock, - PrefixCachingBlockAllocator) -from vllm.sequence import Logprob -from vllm.utils import Device - - -class TestPrefixCachingBlock: - - @staticmethod - @pytest.mark.parametrize("seed", list(range(10))) - @pytest.mark.parametrize("block_size", [1, 16]) - @pytest.mark.parametrize("is_curr_block_full", [True, False]) - def test_first_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool): - """Verify a block which is first in the sequence has the correct hash. - """ - random.seed(seed) - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) - token_ids = list(range(num_to_fill)) - mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - block_with_prev = PrefixCachingBlock(prev_block=None, - token_ids=token_ids, - block_size=block_size, - allocator=mock_allocator) - - if is_curr_block_full: - # Expect hash since block is full. - assert block_with_prev.content_hash == ( - PrefixCachingBlock.hash_block_tokens( - is_first_block=True, - prev_block_hash=None, - cur_block_token_ids=token_ids)) - else: - # Do not expect hash since block is not full. 
- assert block_with_prev.content_hash is None - - @staticmethod - @pytest.mark.parametrize("seed", list(range(10))) - @pytest.mark.parametrize("block_size", [1, 16]) - @pytest.mark.parametrize("is_curr_block_full", [True, False]) - @pytest.mark.parametrize("prev_block_has_hash", [True, False]) - def test_nth_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool, - prev_block_has_hash: bool): - """Verify a block which is not first in the sequence has the correct - hash. - """ - - random.seed(seed) - - previous_block = MagicMock(spec=PrefixCachingBlock) - prev_block_hash = random.randint(0, 1000) - previous_block.content_hash = (prev_block_hash if prev_block_has_hash - else hash('None')) - - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) - token_ids = list(range(num_to_fill)) - mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - block_with_prev = PrefixCachingBlock( - prev_block=previous_block, - token_ids=token_ids, - block_size=block_size, - allocator=mock_allocator, - ) - - if is_curr_block_full and prev_block_has_hash: - # Expect hash since block is full and previous block has hash. - assert (block_with_prev.content_hash == - PrefixCachingBlock.hash_block_tokens( - is_first_block=False, - prev_block_hash=prev_block_hash, - cur_block_token_ids=token_ids)) - else: - # Do not expect hash since block is not full or the previous block - # does not have a hash. - assert block_with_prev.content_hash is None - - @staticmethod - @pytest.mark.parametrize("block_size", [1, 2, 16]) - @pytest.mark.parametrize("num_tokens", list(range(3))) - @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10]) - def test_blocks_have_correct_hash_in_chain(block_size: int, - num_tokens: int, - num_empty_trailing_blocks: int): - """Create two chains of logical blocks with the same contents. - Assert the hashes are equal. - """ - random.seed(0) - - token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)] - - first_chain, second_chain = (TestPrefixCachingBlock.create_chain( - block_size=block_size, - token_ids=token_ids, - num_empty_trailing_blocks=num_empty_trailing_blocks) - for _ in range(2)) - - for first_chain_block, second_chain_block in zip( - first_chain, second_chain): - assert (first_chain_block.content_hash == - second_chain_block.content_hash) - - if not first_chain or not second_chain: - assert first_chain == second_chain - assert num_tokens == 0 - - @staticmethod - def create_chain(block_size: int, - token_ids: list[int], - num_empty_trailing_blocks=0) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. 
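The hash tests above rest on the chaining property: a full block's content hash is derived from its own token ids plus the previous block's hash, while a partially filled block has no hash yet. A self-contained illustration with a hypothetical helper, not vLLM's actual `hash_block_tokens`:

```python
from typing import Optional


def chain_hashes(token_ids: list[int], block_size: int) -> list[Optional[int]]:
    """Per-block content hashes; partial trailing blocks get None."""
    hashes: list[Optional[int]] = []
    prev_hash: Optional[int] = None
    for start in range(0, len(token_ids), block_size):
        block = token_ids[start:start + block_size]
        if len(block) == block_size:            # only full blocks are hashable
            prev_hash = hash((prev_hash, tuple(block)))
            hashes.append(prev_hash)
        else:
            hashes.append(None)                 # partial tail block: no hash yet
    return hashes


first = chain_hashes(list(range(20)), block_size=8)    # blocks of 8, 8, 4 tokens
second = chain_hashes(list(range(20)), block_size=8)
assert first == second                                  # identical chains hash identically
assert first[-1] is None and None not in first[:-1]     # only the partial tail lacks a hash
```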
- """ - blocks: list[PrefixCachingBlock] = [] - num_blocks = math.ceil( - len(token_ids) / block_size) + num_empty_trailing_blocks - - if num_blocks == 0: - return [] - - allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - prev_block = None - for block_number in range(0, num_blocks): - prev_block = PrefixCachingBlock( - prev_block=prev_block, - token_ids=[], - block_size=block_size, - allocator=allocator, - ) - - tokens_to_append = token_ids[block_number * - block_size:(block_number + 1) * - block_size] - if tokens_to_append: - prev_block.append_token_ids(tokens_to_append) - - blocks.append(prev_block) - - return blocks - - -class TestPrefixCachingBlockAllocator: - - @staticmethod - def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): - if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) - elif allocate_type == "mutable": - allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) - else: - raise ValueError() - - return allocate_block - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_mutable_ooms(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( - allocate_type="mutable", - allocator=allocator, - prev_block=None, - token_ids=list(range(block_size)), - ) - - [allocate_block() for _ in range(num_blocks)] - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_immutable_does_not_oom_single_hash( - num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( - allocate_type="immutable", - allocator=allocator, - prev_block=None, - token_ids=list(range(block_size)), - ) - - blocks = [allocate_block() for _ in range(num_blocks)] - - # Expect no OOM. If these were mutable blocks, this would OOM. - non_oom_block = allocate_block() - - # Expect all blocks to have same physical block index. - for block in blocks: - assert (block.block_id == non_oom_block.block_id) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_immutable_ooms_many_hash(num_blocks: int, - block_size: int): - """Consume all blocks using many different hashes/block content. - - Do this by creating a sequence that is very long. - Expect next block to OOM. - """ - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect allocation with unseen hash to fail. - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_immutable_block(prev_block=chain[-1], - token_ids=list( - range(block_size))) - - # Expect mutable allocation to fail. 
- with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=chain[-1]) - - # Expect allocation of exact same chain to pass. - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect physical block indices to be the same in both chains. - assert chain and second_chain - for first_chain_block, second_chain_block in zip(chain, second_chain): - assert (first_chain_block.block_id == second_chain_block.block_id) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_free_prevents_oom(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect mutable allocation to fail. - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=None) - - block_to_free = chain[-1] - - # Expect free/allocate loop to succeed many times. - for i in range(100): - block_id = block_to_free.block_id - allocator.free(block_to_free) - assert block_to_free.block_id is None, i - - new_block = allocator.allocate_mutable_block(prev_block=None) - assert new_block.block_id == block_id, i - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=None) - - block_to_free = new_block - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in chain, assert num free blocks includes new free - # block. - for i, block in enumerate(chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [4]) - @pytest.mark.parametrize("block_size", [8]) - def test_prefix_caching_block_get_num_full_blocks_touched( - num_blocks, block_size): - """ Verify the allocator can correctly return the number of - blocks touched, when there are cached prefixes. 
- """ - allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocator_dst = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks except the last - token_ids = list(range((num_blocks - 1) * block_size)) - - # Create a chain of cacheable blocks in the dst - cached_blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator_dst, - ) - - # Create a chain of the same blocks in the src - blocks_to_swap_in = \ - TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator_src, - ) - # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 0 - - # Free the first block in the dst - allocator_dst.free(cached_blocks[0]) - - # Now the first block becomes dangling, the swapped blocks need - # to reclaim the first block in the dst - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 - - # Insert one non-full block in the src - non_full_block = allocator_src.allocate_mutable_block( - blocks_to_swap_in[-1]) - non_full_block.append_token_ids([0]) - blocks_to_swap_in.append(non_full_block) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 - # Fill up the last mutable block and invoke get_num_blocks_touched. - # Note: The last block is not cached so it will be touched. - non_full_block.append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 2 - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, - seed: int): - """Verify sharing occurs by allocating two sequences that share prefixes - and incrementally freeing blocks. - """ - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in the first chain. Since all blocks are shared, the - # free count should stay constant. - for i, block in enumerate(first_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume) - allocator.free(block) - - # Free each block in the second chain. Since the refcount is now zero, - # the free count should increment with each free. 
- for i, block in enumerate(second_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_common_computed_block_ids(num_blocks: int, block_size: int, - seed: int): - """Verify get_common_computed_block_ids could get correct result - by create two immutable chain sharing prefix at specified pos, - and compare whether we also could get right result - from get_common_computed_block_ids. - """ - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # After zero_point, second_chain's token_ids would be set -1, which - # make it different from here comparing with first_chain - zero_point = random.randint(1, len(token_ids) - 1) - zero_point_blocks = zero_point // block_size - token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) - - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - first_computed_ids = [ - first_chain[i].block_id for i in range(num_blocks_to_consume) - ] - second_computed_ids = [ - second_chain[i].block_id for i in range(num_blocks_to_consume) - ] - res = allocator.get_common_computed_block_ids( - [first_computed_ids, second_computed_ids]) - - assert (len(res) == zero_point_blocks) - - # Test case that assume those prompted block after first immutable would - # be freed into hashless allocator, while first immutable block get ref - # increased. 
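`get_common_computed_block_ids`, exercised in the test above, essentially returns the shared, already-computed block-id prefix of the given sequences. A simplified, self-contained version; the helper name is hypothetical and it ignores the computed/not-computed distinction:

```python
from itertools import takewhile


def common_computed_prefix(per_seq_block_ids: list[list[int]]) -> list[int]:
    """Block ids shared by every sequence, stopping at the first divergence."""
    shared = takewhile(lambda ids: len(set(ids)) == 1, zip(*per_seq_block_ids))
    return [ids[0] for ids in shared]


seq_a = [10, 11, 12, 13]
seq_b = [10, 11, 99, 98]                       # diverges after two shared blocks
assert common_computed_prefix([seq_a, seq_b]) == [10, 11]
```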
- @staticmethod - @pytest.mark.parametrize("num_blocks", [3]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(10))) - def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(block_size)) - - block = allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - - assert allocator._refcounter.get(block.block_id) == 1 - m = allocator.allocate_mutable_block(prev_block=None) - - block_id = m.block_id - for i in range(block_size): - m.append_token_ids([i]) - - # After block get promoted to immutable from mutable, if there is - # already same content hash block, then it shall be released into - # hashless_allocator - # And first immutable block's ref get increased by 1 - assert m.block_id == block.block_id - assert block_id in allocator._hashless_allocator._free_block_indices - assert allocator._refcounter.get(block.block_id) == 2 - - # Test case when eviction and allocation are mixed, - # make sure they work as expected - @staticmethod - @pytest.mark.parametrize("num_blocks", [3]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(10))) - def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - - all_blocks_list = [i for i in range(num_blocks)] - zero_ref = {i: 0 for i in range(num_blocks)} - one_ref = {i: 1 for i in range(num_blocks)} - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(num_blocks * block_size)) - - # Verify initial/pre-alloc state - - # Ensure all blocks are free inside hashless allocator - assert list(allocator._hashless_allocator._free_block_indices - ) == all_blocks_list - # Ensure no tracked blocks - assert len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert not allocator._block_tracker[block_id].active - # Ensure no cached blocks - assert len(allocator._cached_blocks.values()) == 0 - # Ensure no evicted blocks - assert len(allocator.evictor.free_table.keys()) == 0 - # Ensure 0s ref counts for all blocks - assert allocator._refcounter._refcounts == zero_ref - - # Allocate immutable chains with only one block residuled in - new_block = [] - for i in range(num_blocks): - block = allocator.allocate_immutable_block( - prev_block=None, - token_ids=token_ids[block_size * i:block_size * (i + 1)]) - new_block.append(block) - - # Verify post-alloc state - - # Ensure no blocks are free inside hashless allocator - assert (len(allocator._hashless_allocator._free_block_indices) == 0) - # Ensure all blocks are tracked - assert len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert allocator._block_tracker[block_id].active - # Ensure all blocks are cached (all promoted) - assert len(allocator._cached_blocks.values()) == num_blocks - # Ensure no evicted blocks - assert len(allocator.evictor.free_table.keys()) == 0 - # Ensure 1s ref counts for all blocks - assert allocator._refcounter._refcounts == one_ref - - # Free all blocks, and now all blocks shall be in the evictor - # there shall be no tracking data left in _block_tracker - # all blocks shall be tracked in _cached_blocks - # all blocks' ref shall be zero - for block in new_block: - allocator.free(block) - - # Verify post-free state - - # Ensure no tracked blocks - assert 
len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert not allocator._block_tracker[block_id].active - # Ensure no blocks in hashless allocator (all promoted) - assert len(allocator._hashless_allocator._free_block_indices) == 0 - # Ensure all blocks are cached - assert list(allocator._cached_blocks.values()) == all_blocks_list - # Ensure all blocks are inside the evictor - assert list(allocator.evictor.free_table.keys()) == all_blocks_list - # Ensure 0s refcounts - assert allocator._refcounter._refcounts == zero_ref - - # Allocate a mutable block, and the first block shall be evicted - # and set its content hash into None, ref to 1 - mutable = allocator.allocate_mutable_block(prev_block=None) - - assert mutable.block_id == 0 - assert mutable.content_hash is None - assert allocator._block_tracker[0].active - assert allocator._refcounter.get(0) == 1 - assert 0 not in allocator._cached_blocks - assert 0 not in allocator.evictor - - # Since this mutable block has no hash yet, it shall be released into - # hashless allocator - allocator.free(mutable) - - assert not allocator._block_tracker[0].active - assert allocator._refcounter._refcounts == zero_ref - assert 0 not in allocator._cached_blocks - assert 0 not in allocator.evictor - assert 0 in allocator._hashless_allocator._free_block_indices - - # When allocate immutable with first block_size tokens, we - # shall get free block from hashless allocator, thus no block left - # in hashless - block = allocator.allocate_immutable_block( - prev_block=None, token_ids=token_ids[:block_size]) - - assert block.block_id == 0 - assert len(allocator._hashless_allocator._free_block_indices) == 0 - assert allocator._block_tracker[0].active - assert 0 in allocator._cached_blocks.values() - assert allocator._refcounter.get(0) == 1 - assert 0 not in allocator.evictor - - # allocate mutable block again, it shall be popped from evictor - mutable = allocator.allocate_mutable_block(prev_block=None) - assert len(allocator._hashless_allocator._free_block_indices) == 0 - assert mutable.block_id not in allocator.evictor.free_table - assert allocator._refcounter.get(mutable.block_id) == 1 - - # Test case where two last accessed times are equal - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_eviction_order(num_blocks: int, block_size: int, seed: int): - """This test case simulate the two chain created and free in order, - and together they would exhaust the initial freed blocks. - - So the next block created after those two chain shall use the block - from the first chain as that block has long access time. - While first chain has two blocks, it shall pick up the last one, as - it has larger token number. 
- """ - - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = num_blocks + 1 - - token_ids = list(range(num_blocks_to_consume * block_size)) - - num_blocks_in_first_chain = 2 - num_tokens_in_first_chain = block_size * num_blocks_in_first_chain - # First chain takes the first block - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[:num_tokens_in_first_chain], - allocator=allocator, - ) - # There should only be one block allocated at this point - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_in_first_chain) - - # Set the last accessed time of the first block to 1 - blocks_ids = [block.block_id for block in first_chain] - allocator.mark_blocks_as_accessed(blocks_ids, 1) - - # Second chain takes the rest of the blocks - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[num_tokens_in_first_chain:-block_size], - allocator=allocator, - ) - - # There shouldn't be any blocks left at this point - assert allocator.get_num_free_blocks() == (0) - - assert len(first_chain) == num_blocks_in_first_chain - last_block_id = first_chain[-1].block_id - # Free each block in the first chain. - for i, block in enumerate(first_chain): - allocator.free(block) - - # Set the last accessed time on all of the blocks in the second chain - # to 2 - blocks_ids = [block.block_id for block in second_chain] - allocator.mark_blocks_as_accessed(blocks_ids, 2) - - # Free each block in the second chain. - for i, block in enumerate(second_chain): - allocator.free(block) - - # Allocate a new block and check that it's the least recently used block - # from the first chain. - new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[-block_size:], - allocator=allocator, - ) - - assert new_block[0].block_id == last_block_id - - # Test case for cache mertics - @staticmethod - def test_metric(): - block_size = 16 - allocator = PrefixCachingBlockAllocator(num_blocks=4, - block_size=block_size) - # Test when no query (0/0) - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - token_ids = list(range(block_size)) - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - # Test 0/1 hit rate - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - # Test 1/2 hit rate - assert allocator.get_prefix_cache_hit_rate() == 0.5 - - # Test more than one block - for _ in range(2, 1005): - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - assert allocator.get_prefix_cache_hit_rate() > 0.99 - - # Test case for marking cache hit blocks as computed right after - # a batch of prefill sequences are scheduled. - @staticmethod - def test_touch_block(): - block_size = 16 - common_blocks = 4 - allocator = PrefixCachingBlockAllocator(num_blocks=8, - block_size=block_size) - - common_token_ids = list(range(block_size * common_blocks)) - - # Mimic the behavior of allocating the same block chain - # (i.e., common prefix) for a batch of 3 different prefill sequences. 
- for _ in range(3): - blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=common_token_ids, - allocator=allocator, - ) - block_hashes = [block.content_hash for block in blocks] - # The allocated blocks should be marked as touched - # but not computed. - computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes) - assert len(computed_block_ids) == 0 - - allocator.mark_blocks_as_computed([]) - computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes) - assert len(computed_block_ids) == common_blocks - - @staticmethod - def test_find_cached_blocks_prefix(): - """ - This test verifies the behavior of find_cached_blocks_prefix. - """ - block_size = 4 - num_blocks = 8 - total_test_blocks = 12 - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - token_ids = list(range(total_test_blocks * block_size)) - block_tokens_seq1 = token_ids[:num_blocks * block_size] - blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=block_tokens_seq1, - allocator=allocator, - ) - block_hashes_seq1 = [block.content_hash for block in blocks_seq1] - allocator.mark_blocks_as_computed([]) - - # All blocks should be cached. - cached_blocks_seq1 = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks_seq1) == num_blocks - - # Free the first sequence. - for block in blocks_seq1: - allocator.free(block) - - # All blocks should be still be cached if not required to be allocated. - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks) == num_blocks - - block_tokens_seq2 = token_ids[num_blocks * block_size:] - blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=block_tokens_seq2, - allocator=allocator, - ) - block_hashes_seq2 = [block.content_hash for block in blocks_seq2] - allocator.mark_blocks_as_computed([]) - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq2) - assert len(cached_blocks) == len(blocks_seq2) - - # Half of the blocks from seq1 should still be cached. - num_evicted_blocks = len(blocks_seq2) - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks - - # Test reset prefix cache - @staticmethod - @pytest.mark.parametrize("num_blocks", [10]) - @pytest.mark.parametrize("block_size", [16]) - def test_reset_prefix_cache(num_blocks: int, block_size: int): - """This test case simulates the case of resetting the prefix cache.""" - - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(3 * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in the first chain. - for block in first_chain: - allocator.free(block) - - # Failed to reset prefix cache because some blocks are not freed yet. - assert not allocator.reset_prefix_cache() - assert allocator.get_prefix_cache_hit_rate() > 0.0 - - # Free each block in the second chain. - for block in second_chain: - allocator.free(block) - - # Reset prefix cache. 
- assert allocator.reset_prefix_cache() - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - @staticmethod - def create_immutable_chain( - block_size: int, - token_ids: list[int], - allocator: PrefixCachingBlockAllocator, - extra_hash: Optional[int] = None, - ) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. - """ - blocks: list[Block] = [] - num_blocks = math.ceil(len(token_ids) / block_size) - - if num_blocks == 0: - return [] - - prev_block = None - for block_number in range(0, num_blocks): - block_token_ids = token_ids[block_number * - block_size:(block_number + 1) * - block_size] - prev_block = allocator.allocate_immutable_block( - prev_block=prev_block, - token_ids=block_token_ids, - extra_hash=extra_hash) - blocks.append(prev_block) - - return blocks - - -class TestComputedBlocksTracker: - - @staticmethod - def _get_mock_allocator(): - return MagicMock(spec=PrefixCachingBlockAllocator) - - @staticmethod - def test_get_num_cached_tokens(): - """ - Test it correctly computes the number of cached tokens for a given - sequence: - - - The cache token count is derived from the number of cached blocks. - - The cache token count is updated when the allocator is updated. - - When a sequence is removed, the cache token count should be updated - accordingly. - - # TODO(rickyx): This behaviour for prefill sequence is a hack until - we fix the computed blocks tracking. - - The cache token count for prefill sequence doesn't change while - the sequence is in continuous prefill (chunked prefill). - """ - block_size = 4 - mock_allocator = TestComputedBlocksTracker._get_mock_allocator() - tracker = ComputedBlocksTracker( - allocator=mock_allocator, - block_size=block_size, - enable_caching=True, - ) - - # Not yet allocated. - tokens = [0, 1, 2, 3, 4, 5] - seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - mock_allocator.find_cached_blocks_prefix.return_value = [] - assert tracker.get_num_cached_tokens(seq1) == 0 - - mock_allocator.find_cached_blocks_prefix.return_value = [ - None - ] # 1 block cached. - # Result is cached for prefill sequence. - assert tracker.get_num_cached_tokens(seq1) == 0 - - # Mark the sequence as non-prefill. - seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed. - assert not seq1.is_prefill() - - # Recomputes for decoding sequence. - assert tracker.get_num_cached_tokens(seq1) == 4 - - # Append new tokens to the sequence. - num_new_tokens = 3 - for i in range(num_new_tokens): - seq1.append_token_id(i, {i: Logprob(logprob=0.0)}) - - assert tracker.get_num_cached_tokens(seq1) == 4 - - # Update the allocator. - mock_allocator.find_cached_blocks_prefix.return_value = [ - None - ] * 2 # 2 blocks cached. - assert tracker.get_num_cached_tokens(seq1) == 8 - - # Remove the sequence. - tracker.remove_seq(seq1.seq_id) - - # Re-create the sequence with the same request id to simulate recompute. - seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - mock_allocator.find_cached_blocks_prefix.return_value = [ - ] # no cached block - assert tracker.get_num_cached_tokens(seq1) == 0 - - @staticmethod - def test_correct_block_hash(): - """ - Test that the block hash is correctly computed for a sequence (should - match the underlying block allocator's block hash). So the number of - cached tokens is correctly retrieved. 
- """ - block_size = 4 - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=16, - num_cpu_blocks=16, - block_size=block_size, - ) - gpu_allocator = allocator._allocators[Device.GPU] - - tracker = ComputedBlocksTracker( - allocator=allocator, - block_size=block_size, - enable_caching=True, - ) - - tokens = list(range(block_size * 4)) # 4 blocks. - seq = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - _ = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=tokens, - allocator=gpu_allocator, - ) - allocator.mark_blocks_as_computed([]) - - assert tracker.get_num_cached_tokens(seq) == len(tokens) - - @staticmethod - def test_correct_extra_hash(): - """ - Test that the block hash is correctly computed based on the extra hash, - ensuring it matches the allocator's block hash, specifically for the - LoRA case, and that the correct number of cached tokens is retrieved. - """ - block_size = 4 - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=16, - num_cpu_blocks=16, - block_size=block_size, - ) - gpu_allocator = allocator._allocators[Device.GPU] - - tracker = ComputedBlocksTracker( - allocator=allocator, - block_size=block_size, - enable_caching=True, - ) - - tokens = list(range(block_size * 4)) - - # Create a dummy LoRA sequence with a specific LoRA ID. - lora_seq = create_dummy_lora_sequence(request_id=0, - token_ids=tokens, - block_size=block_size, - lora_int_id=1) - - _ = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=tokens, - allocator=gpu_allocator, - extra_hash=lora_seq.extra_hash(), - ) - - allocator.mark_blocks_as_computed([]) - - # Create different dummy sequences that have the same token IDs - # but different LoRA IDs. - seq = create_dummy_sequence(request_id=1, - token_ids=tokens, - block_size=block_size) - - different_lora_seq = create_dummy_lora_sequence(request_id=2, - token_ids=tokens, - block_size=block_size, - lora_int_id=2) - - # Due to the different LoRA IDs, corresponding blocks are not cached. - assert tracker.get_num_cached_tokens(seq) == 0 - assert tracker.get_num_cached_tokens(different_lora_seq) == 0 - - # The number of cached tokens matches the length of the tokens - # for the cached LoRA sequence. - assert tracker.get_num_cached_tokens(lora_seq) == len(tokens) diff --git a/tests/core/conftest.py b/tests/core/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/core/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py deleted file mode 100644 index ce1fe189b3ca1..0000000000000 --- a/tests/core/test_chunked_prefill_scheduler.py +++ /dev/null @@ -1,858 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest # noqa - -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, SequenceGroup - -from .utils import create_dummy_prompt - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(seq_group: SequenceGroup, token_id: int): - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def schedule_and_update_computed_tokens(scheduler): - metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) - return metas, out - - -def test_simple(): - """Verify basic scheduling works.""" - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. - num_tokens = block_size * num_seq_group - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - for s in running: - append_new_token(s, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - - -def test_chunk(): - """Verify prefills are chunked properly.""" - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. 
- for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify the second request is chunked. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - print() - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 60 - # Verify it is chunked. - assert seq_group_meta[1].token_chunk_size == 4 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # One chunked prefill, and one decoding. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # The first one is prefill. Scheduler guarantees ordering. - assert seq_group_meta[0].token_chunk_size == 56 - # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 1 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 57 - - -def test_concurrent_chunking(): - """Verify prefills are chunked properly when - --max-num-partial-prefills is > 1""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify both requests are chunked with half of max_num_batched_tokens each - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 32 - assert seq_group_meta[1].token_chunk_size == 32 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # After one iteration, both should have 60 - 32 = 28 tokens left to prefill - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 28 - assert seq_group_meta[1].token_chunk_size == 28 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 56 - - -def test_concurrent_chunking_large_requests(): - """Verify large prefill requests are run one at a time""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. 
- for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # Verify only a single request is chunked, and it gets all 64 tokens - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 64 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - -def test_short_prompts_jump_long_prompts_in_queue(): - """Verify large prefill requests are punted behind smaller ones if - another large prefill request is already running""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - long_seqs: list[SequenceGroup] = [] - short_seqs: list[SequenceGroup] = [] - - # Add 2 large seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - long_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Add 2 small seq groups behind them - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i + 2), - prompt_length=40, # Very small prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - short_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Verify one large req and 1 small req chunked - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens - assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens - - # all 4 are prefilling - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # First short and first long sequences have been scheduled - assert long_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 0 - - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # in the second iteration, - # the first small request had only 8 tokens left - # so it went to decode - # The other small req is scheduled - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # the new small req got 64 - (32+8) tokens - assert seq_group_meta[0].token_chunk_size == 24 - assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 - # the other small request had only 8 tokens left - assert seq_group_meta[2].token_chunk_size == 8 # 40-32 - - # The first small request got to decode now - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # Both small requests have started in front of the second long request - assert long_seqs[0].first_seq.get_num_computed_tokens() == 64 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 
0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 40 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 24 - - assert out.num_prefill_groups == 3 - assert out.num_batched_tokens == 64 - # the first small seq group has a new token appended. - append_new_token(short_seqs[0], 1) - - # in the third iteration, - # the first small request is already decoding - # the second small request only has 16 tokens left and will enter decoding - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 - # small req finished prefilling 40-24=16 tokens - assert seq_group_meta[1].token_chunk_size == 16 - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 49 # (32+16+1 decode) - - # both small requests have now reached decode - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert not short_seqs[1].is_prefill() - assert long_seqs[0].first_seq.get_num_computed_tokens() == 96 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 41 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 40 - - # both the small seq groups have a new token appended - append_new_token(short_seqs[0], 1) - append_new_token(short_seqs[1], 1) - - # in the fourth iteration, both small requests are decoding - # so large request gets all the budget - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - - # large req gets 62 tokens (minus 2 for decode) - assert seq_group_meta[0].token_chunk_size == 62 - assert seq_group_meta[1].token_chunk_size == 1 # decode - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - assert long_seqs[0].first_seq.get_num_computed_tokens() == 158 - - # assert long_seqs[0].is_prefill() - # assert long_seqs[1].is_prefill() - # assert not short_seqs[0].is_prefill() - # assert not short_seqs[1].is_prefill() - - # # both the small seq groups have a new token appended - # append_new_token(short_seqs[0], 1) - # append_new_token(short_seqs[1], 1) - - # # in the fifth iteration, large request gets all the budget - # # while both small requests are decoding - # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # assert seq_group_meta[0].token_chunk_size == 62 - # assert seq_group_meta[1].token_chunk_size == 1 # decode - # assert seq_group_meta[2].token_chunk_size == 1 # decode - # assert out.num_prefill_groups == 1 - # assert out.num_batched_tokens == 64 - - -def test_complex(): - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 64 - cache_config.num_gpu_blocks = 64 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # Verify the second request is chunked. 
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-
-    assert set(get_sequence_groups(out)) == set(running)
-    assert seq_group_meta[0].token_chunk_size == 60
-    # Verify it is chunked.
-    assert seq_group_meta[1].token_chunk_size == 4
-    assert not running[0].is_prefill()
-    assert running[1].is_prefill()
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 64
-    # Only the first seq group has a new token appended.
-    append_new_token(running[0], 1)
-
-    # Add 2 more requests.
-    for i in range(2, 4):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           prompt_length=60,
-                                           block_size=block_size)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert len(get_sequence_groups(out)) == 3
-    # The first one is the first chunked prefill.
-    assert seq_group_meta[0].token_chunk_size == 7
-    # The second one is the second new chunked prefill.
-    assert seq_group_meta[1].token_chunk_size == 56
-    # The last one is decode.
-    assert seq_group_meta[2].token_chunk_size == 1
-    # Two of them are in chunked prefill.
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 64
-    # The first 2 requests are now in the decoding phase.
-    append_new_token(running[0], 1)
-    assert not running[0].is_prefill()
-    append_new_token(running[1], 1)
-    assert not running[1].is_prefill()
-    # The third request is still in prefill stage.
-    assert running[2].is_prefill()
-
-
-def test_maximal_decoding():
-    """Verify decoding requests are prioritized."""
-    block_size = 4
-    max_seqs = 2
-    max_model_len = 8
-    max_num_batched_tokens = 2
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 8
-    cache_config.num_gpu_blocks = 8
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: list[SequenceGroup] = []
-
-    # Add seq groups to scheduler.
-    for i in range(2):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           prompt_length=2,
-                                           block_size=block_size)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-        assert seq_group.is_prefill()
-
-    # The first prefill is scheduled.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert len(get_sequence_groups(out)) == 1
-    assert seq_group_meta[0].token_chunk_size == 2
-    assert not running[0].is_prefill()
-    assert running[1].is_prefill()
-    assert out.num_prefill_groups == 1
-    assert out.num_batched_tokens == 2
-    # Only the first seq group has a new token appended.
-    append_new_token(running[0], 1)
-
-    # Create one more seq_group.
-    _, seq_group = create_dummy_prompt("3",
-                                       prompt_length=2,
-                                       block_size=block_size)
-    scheduler.add_seq_group(seq_group)
-    running.append(seq_group)
-    assert seq_group.is_prefill()
-    # The first decoding + second chunk is scheduled.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert len(get_sequence_groups(out)) == 2
-    assert seq_group_meta[0].token_chunk_size == 1
-    assert seq_group_meta[1].token_chunk_size == 1
-    assert not running[0].is_prefill()
-    assert running[1].is_prefill()
-    assert running[2].is_prefill()
-    assert out.num_prefill_groups == 1
-    assert out.num_batched_tokens == 2
-    append_new_token(running[0], 1)
-
-    # Decoding + running prefill is prioritized.
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # Only decoding is prioritized. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 0 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # After aborting the decoding request, the fcfs new prefill is prioritized. - scheduler.abort_seq_group(running[0].request_id) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[1].is_prefill() - assert running[2].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - - -def test_prompt_limit(): - """Verify max_num_batched_tokens < max_model_len is possible.""" - block_size = 4 - max_seqs = 32 - max_model_len = 64 - max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("1", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # The prompt length > max_num_batched_tokens should be still scheduled. 
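# Illustrative sketch (plain Python, not the vLLM scheduler API) of the behaviour
# checked below: a prompt longer than max_num_batched_tokens is still served, one
# budget-sized chunk of prefill per step, under the assumption that the whole
# budget is available to this single request at every step.
def prefill_chunk_sizes(prompt_len: int, token_budget: int) -> list[int]:
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        step = min(remaining, token_budget)
        chunks.append(step)
        remaining -= step
    return chunks


# Mirrors the 48-token prompt and 32-token budget used in this test.
assert prefill_chunk_sizes(48, 32) == [32, 16]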
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 32 - assert running[0].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 32 - - -def test_prompt_limit_exceed(): - block_size = 4 - max_seqs = 64 - max_model_len = 32 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("2", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.ignored_seq_groups) == 1 - assert out.ignored_seq_groups[0] == seq_group - - -def test_chunked_prefill_preempt(): - """Verify preempt works with chunked prefill requests""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The request should be preempted. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group1(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group1) - - # The running prefill is now preempted. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == [] - assert out.blocks_to_swap_in == [] - - # Make sure we can reschedule preempted request. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - assert seq_group.get_num_uncomputed_tokens() == 30 - - # We should be able to run prefill twice as it is chunked. 
-    def cannot_append_second_group2(seq_group, num_lookahead_slots):
-        return True
-
-    scheduler.block_manager.can_append_slots.side_effect = (
-        cannot_append_second_group2)
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    assert len(out.scheduled_seq_groups) == 1
-    assert out.num_prefill_groups == 1
-    assert not seq_group.is_prefill()
-    assert out.num_batched_tokens == max_num_batched_tokens
-
-
-def test_chunked_prefill_spec_prefill():
-    """Verify that the num_lookahead_slots is set appropriately for an all
-    prefill batch."""
-    block_size = 4
-    max_seqs = 30
-    max_model_len = 200
-    max_num_batched_tokens = 30
-    num_lookahead_slots = 4
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-        num_lookahead_slots=num_lookahead_slots,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 16
-    cache_config.num_gpu_blocks = 16
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-
-    _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=30,
-                                       block_size=block_size)
-    scheduler.add_seq_group(seq_group)
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    # The request is chunked.
-    # prefill scheduled now.
-    assert len(out.scheduled_seq_groups) == 1
-    assert out.num_prefill_groups == 1
-    assert out.num_batched_tokens == max_num_batched_tokens
-    print(out.num_lookahead_slots)
-    assert out.num_lookahead_slots == 0
-
-
-def test_chunked_prefill_max_seqs():
-    block_size = 4
-    max_seqs = 2
-    max_model_len = 80
-    max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens,
-        max_seqs,
-        max_model_len,
-        enable_chunked_prefill=True,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 128
-    cache_config.num_gpu_blocks = 128
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: list[SequenceGroup] = []
-
-    _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=65,
-                                       block_size=block_size)
-    scheduler.add_seq_group(seq_group)
-    running.append(seq_group)
-    # The first prefill is chunked.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
-    assert len(get_sequence_groups(out)) == 1
-
-    # Add new requests.
-    for i in range(4):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           prompt_length=65,
-                                           block_size=block_size)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    # Make sure only 2 requests are scheduled.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert out.num_batched_tokens == max_num_batched_tokens
-    assert len(get_sequence_groups(out)) == 2
-    assert not running[0].is_prefill()
-    assert running[1].is_prefill()
-    append_new_token(running[0], 1)
-
-    # Although we have enough token budget, we can only schedule max_seqs.
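# Minimal sketch of the admission check described in the comment above
# (hypothetical helper, not the real SchedulingBudget): a new chunk is admitted
# only if it fits both the token budget and the max_seqs cap, so leftover token
# budget alone is not enough to schedule a third sequence.
def fits_budget(new_tokens: int, new_seqs: int, used_tokens: int, used_seqs: int,
                token_budget: int, max_seqs: int) -> bool:
    return (used_tokens + new_tokens <= token_budget
            and used_seqs + new_seqs <= max_seqs)


# Mirrors the step below: 3 tokens and both sequence slots are already used, so
# another 61-token chunk is rejected even though 61 tokens of budget remain.
assert not fits_budget(61, 1, used_tokens=3, used_seqs=2,
                       token_budget=64, max_seqs=2)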
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 2 - assert seq_group_meta[1].token_chunk_size == 1 - assert out.num_batched_tokens == 3 - assert len(get_sequence_groups(out)) == max_seqs - assert not running[0].is_prefill() - assert not running[1].is_prefill() - - -def test_prefix_caching(): - """Verify allocating full blocks when prefix caching is enabled.""" - block_size = 4 - max_seqs = 10 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 50 - # Verify it is chunked. Note that although the budget is 64-50=14, - # we only allocate full blocks for prefix caching, so only 4*(14//4)=12 - # tokens are allocated. - assert seq_group_meta[1].token_chunk_size == 12 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 62 - - -def test_prefix_caching_with_concurrent_partial_prefills(): - """Verify allocating full blocks when prefix caching is enabled with - --max-num-partial-prefills > 1.""" - block_size = 4 - max_seqs = 10 - max_model_len = 8000 - max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # To partially prefill both sequences, both can chunk up to 30 tokens - # But the next lowest multiple of the block size (4) is 28 - assert seq_group_meta[0].token_chunk_size == 28 - assert seq_group_meta[1].token_chunk_size == 28 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 56 - - # On the next iteration, both sequences should finish prefill - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # Both sequences have 50 - 28 = 22 tokens left to prefill. 
- # This is not a multiple of the block size, but we don't care since we don't - # cache the final partial block of prefix sequences - assert seq_group_meta[0].token_chunk_size == 22 - assert seq_group_meta[1].token_chunk_size == 22 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 44 - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) -def test_chunked_prefill_with_actual_engine(model: str, - max_num_partial_prefills: int): - """Make sure the model can actually sample with concurrent - partial prefills - """ - - prompt = "hello" * 40 - - engine_args = EngineArgs( - model=model, - max_num_partial_prefills=max_num_partial_prefills, - max_num_batched_tokens=40, - max_num_seqs=8, - enable_chunked_prefill=True, - gpu_memory_utilization=0.8, - ) - - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(temperature=0) - - for req_num in range(max_num_partial_prefills): - engine.add_request(f"{req_num}", prompt, sampling_params) - # first step - request_outputs = engine.step() - # means all are prefilling - assert len(request_outputs) == 0 - assert len(engine.scheduler[0].running) == max_num_partial_prefills diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py deleted file mode 100644 index 131a7b3a6299b..0000000000000 --- a/tests/core/test_num_computed_tokens_update.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from tests.conftest import VllmRunner -from tests.core.utils import create_dummy_prompt -from vllm.engine.llm_engine import LLMEngine -from vllm.sequence import SequenceGroup - -MODEL = "JackFram/llama-160m" - - -def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): - scheduler = engine.scheduler[0] - scheduler.add_seq_group(seq_group) - - -@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -def test_num_computed_tokens_update(enable_chunked_prefill: bool, - enforce_eager: bool): - - # Make a vllm engine - runner = VllmRunner(model_name=MODEL, - gpu_memory_utilization=0.7, - enable_chunked_prefill=enable_chunked_prefill, - enforce_eager=enforce_eager) - engine: LLMEngine = runner.llm.llm_engine - - num_prompt_steps = 1 - - num_output_tokens_list = [4, 8, 12, 15, 16, 17] - - # Create sequence and add to engine - prompt_len = 10 - - for req_idx, num_output_tokens in enumerate(num_output_tokens_list): - seq, seq_group = create_dummy_prompt(request_id=str(req_idx), - prompt_length=prompt_len, - min_tokens=num_output_tokens, - max_tokens=num_output_tokens) - add_seq_group_to_engine(engine, seq_group) - - assert seq.data.get_num_computed_tokens() == 0 - - for _ in range(num_prompt_steps): - # prompt steps - engine.step() - - if not seq.is_finished(): - prompt_num_computed_tokens = seq.data.get_num_computed_tokens() - # Test correctness of num_computed_tokens after the prompt steps - assert prompt_num_computed_tokens == \ - prompt_len + num_prompt_steps - 1 - - decode_step_counter = 0 - while not seq.is_finished(): - # Test correctness of num_computed_tokens after the decode steps - assert seq.data.get_num_computed_tokens( - ) == prompt_num_computed_tokens + decode_step_counter - engine.step() - decode_step_counter += 1 - - # Test correctness of num_computed_tokens after the sequence finish. 
- assert seq.data.get_num_computed_tokens( - ) == prompt_len + num_output_tokens - 1 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py deleted file mode 100644 index 591e1780c11c6..0000000000000 --- a/tests/core/test_scheduler.py +++ /dev/null @@ -1,1337 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from collections import deque -from typing import Optional -from unittest.mock import MagicMock - -import pytest # noqa -import torch -from torch import Use # noqa - -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus -from vllm.core.scheduler import Scheduler, SchedulingBudget -from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup, SequenceStatus - -from .utils import (append_new_token, append_new_token_seq, - append_new_token_seq_group, create_dummy_prompt, - get_sequence_groups, schedule_and_update_computed_tokens) - - -def test_scheduler_add_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=1, - ) - cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq group to scheduler. - num_seq_group = 4 - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - assert scheduler.get_num_unfinished_seq_groups() == i + 1 - - -def test_scheduler_abort_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=1, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add multiple seq groups to scheduler. - num_seq_group = 4 - request_ids: set[str] = set() - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) - scheduler.add_seq_group(seq_group) - request_ids.add(str(i)) - - # Abort all added seq groups. - assert scheduler.get_num_unfinished_seq_groups() == num_seq_group - scheduler.abort_seq_group(request_ids) - assert scheduler.get_num_unfinished_seq_groups() == 0 - - -def test_scheduler_schedule_simple(): - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=num_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. 
-    num_tokens = block_size * num_seq_group
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    assert out.num_batched_tokens == num_tokens
-    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
-            and not out.blocks_to_swap_out)
-    assert len(seq_group_meta) == num_seq_group
-    append_new_token(out, 1)
-
-    # Schedule seq groups generation.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    assert out.num_batched_tokens == num_seq_group
-    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
-            and not out.blocks_to_swap_out)
-    assert len(seq_group_meta) == num_seq_group
-    append_new_token(out, 1)
-
-
-def test_scheduler_prefill_prioritized():
-    """Verify running batched tokens are not applied to prefill requests."""
-    block_size = 4
-    max_model_len = 30
-    max_batched_num_tokens = 30
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens=max_batched_num_tokens,
-        max_num_seqs=2,
-        max_model_len=max_model_len,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 16
-    cache_config.num_gpu_blocks = 16
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-
-    # Add seq groups to scheduler.
-    _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size)
-    scheduler.add_seq_group(seq_group_a)
-
-    # Schedule seq groups prompts.
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    assert get_sequence_groups(out) == [seq_group_a]
-
-    # Add a new prefill request B.
-    _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size)
-    scheduler.add_seq_group(seq_group_b)
-
-    # Verify prefill requests are prioritized. The new prefill request B has
-    # to be scheduled before the running decode of request A.
-    _, out = schedule_and_update_computed_tokens(scheduler)
-    assert get_sequence_groups(out) == [seq_group_b]
-
-
-def test_scheduler_schedule_preempt_abort():
-    block_size = 4
-    max_model_len = 16
-    scheduler_config = SchedulerConfig(
-        "generate",
-        max_num_batched_tokens=64,
-        max_num_seqs=2,
-        max_model_len=max_model_len,
-    )
-    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
-    cache_config.num_cpu_blocks = 2
-    cache_config.num_gpu_blocks = 2
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-
-    # Add seq groups to scheduler.
-    seq_a, seq_group_a = create_dummy_prompt("1",
-                                             block_size,
-                                             block_size=block_size)
-    seq_b, seq_group_b = create_dummy_prompt("2",
-                                             block_size,
-                                             block_size=block_size)
-    scheduler.add_seq_group(seq_group_a)
-    scheduler.add_seq_group(seq_group_b)
-
-    # Schedule seq groups prompts.
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
-    assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
-    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
-            and not out.blocks_to_swap_out)
-    assert len(seq_group_meta) == 2
-    assert scheduler.get_num_unfinished_seq_groups() == 2
-
-    # Append "generated" tokens, allowing the sequence to mark prompt tokens as
-    # processed.
-    append_new_token(out, 1)
-
-    # Schedule seq groups generation and preempt seq group b.
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_a] - assert out.num_batched_tokens == 1 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 2 - assert out.preempted == 1 - - # Abort seq group a. Re-schedule seq group b prompt with recomputation. - scheduler.abort_seq_group("1") - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_b] - assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 1 - - -def test_scheduler_max_seqs(): - block_size = 4 - num_seq_group = 4 - max_seq_group = 2 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=max_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - all_seq_groups: list[SequenceGroup] = [] - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - all_seq_groups.append(seq_group) - - # Append 1 seq group - scheduler.add_seq_group(all_seq_groups[0]) - - # Schedule seq groups prompts. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) - append_new_token(out, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) - append_new_token(out, 1) - - # Append 2 more seq group - scheduler.add_seq_group(all_seq_groups[1]) - scheduler.add_seq_group(all_seq_groups[2]) - - # Schedule seq groups prompts. - # Only 1 seq group should be scheduled since max_seq_group is 2 - # and one is prompting. 
- _, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) - - -def test_scheduler_delay_factor(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=16, - delay_factor=0.5, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # schedule first prompt - seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '0' - append_new_token(out, 1) - - # wait for a second before scheduling next prompt - time.sleep(1) - seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # second prompt should *not* be scheduled - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups == 0 - assert seq_group_meta[0].request_id == '0' - append_new_token(out, 1) - - # wait for more than 0.5 second and try again - time.sleep(0.6) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '1' - append_new_token(out, 1) - - -def initialize_scheduler( - *, - max_num_seqs=1000, - max_token_budget=1000, - max_model_len=1000, - lora_config=None, - block_size=4, - num_cpu_blocks=8, - num_gpu_blocks=8, - enable_prefix_caching=False, - enable_chunked_prefill=False, -): - block_size = block_size - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=max_token_budget, - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - enable_chunked_prefill=enable_chunked_prefill, - ) - cache_config = CacheConfig( - block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=enable_prefix_caching, - ) - cache_config.num_cpu_blocks = num_cpu_blocks - cache_config.num_gpu_blocks = num_gpu_blocks - scheduler = Scheduler(scheduler_config, cache_config, lora_config) - return scheduler - - -def create_token_budget(token_budget: int = 10000, - max_num_seqs: int = 10000) -> SchedulingBudget: - return SchedulingBudget( - token_budget=token_budget, - max_num_seqs=max_num_seqs, - ) - - -def add_token_budget(budget: SchedulingBudget, - num_batched_tokens: int = 0, - num_curr_seqs: int = 0): - mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1] - budget.add_num_batched_tokens(mock_seq_group.request_id, - num_batched_tokens) - budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) - - -def test_prefill_schedule_max_prompt_len(): - """ - Test prompt longer than max_prompt_len is aborted. 
- """ - block_size = 4 - scheduler = initialize_scheduler(max_model_len=30, block_size=block_size) - _, seq_group = create_dummy_prompt("0", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - budget = create_token_budget() - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 1 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 0 - - -def test_prefill_schedule_token_budget(): - """ - Test token budget respected. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(token_budget=0) - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # 0 token budget == nothing is scheduled. - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 2 - - # 60 token budget == 1 request scheduled. - budget = create_token_budget(token_budget=60) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 1 - assert budget.num_batched_tokens == 60 - assert budget.num_curr_seqs == 1 - assert len(remaining_waiting) == 1 - - # Test when current_batched_tokens respected. - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16) - budget = create_token_budget(token_budget=60) - add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - # Cannot schedule a prompt that doesn't fit the budget. - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 30 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 1 - budget = create_token_budget(token_budget=90) - add_token_budget(budget, 30, 0) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.seq_groups) == 1 - assert budget.num_batched_tokens == 90 - assert budget.num_curr_seqs == 1 - assert len(remaining_waiting) == 0 - - -def test_prefill_schedule_max_seqs(): - """ - Test max seq respected. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(max_num_seqs=2) - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 2 - assert budget.num_batched_tokens == 120 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 1 - - # Verify curr_num_seqs respected. 
- scheduler.waiting = deque() - budget = create_token_budget(max_num_seqs=2) - add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 1 - - -def test_prefill_schedule_max_lora(): - """ - Test max lora is respected and prioritized. - """ - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(token_budget=120) - curr_loras: set[int] = set() - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler.add_seq_group(seq_group) - # Add two more requests to verify lora is prioritized. - # 0: LoRA, 1: LoRA, 2: regular, 3: regular - # In the first iteration, index 0, 2 is scheduled. - # If a request is not scheduled because it hits max lora, it is - # prioritized. Verify that. - for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - # Schedule 2 requests (0 and 2) - output = scheduler._schedule_prefills(budget, curr_loras) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 2 - assert budget.num_batched_tokens == 120 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 2 - assert len(curr_loras) == 1 - # The second lora request is scheduled next as FCFS policy. - # Reset curr_loras so that it can be scheduled. - curr_loras = set() - budget = create_token_budget(token_budget=60) - output = scheduler._schedule_prefills(budget, curr_loras) - remaining_waiting = scheduler.waiting - assert len(output.seq_groups) == 1 - assert output.seq_groups[0].seq_group.request_id == "1" - assert len(remaining_waiting) == 1 - assert len(curr_loras) == 1 - assert budget.num_batched_tokens == 60 - - -def test_prefill_schedule_no_block_manager_capacity(): - """ - Test sequence cannot be scheduled due to block manager has no capacity. 
- """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_gpu_blocks=128, - num_cpu_blocks=128) - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - scheduler.block_manager.can_allocate = MagicMock() - scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 3 - - scheduler = initialize_scheduler() - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - scheduler.block_manager.can_allocate = MagicMock() - scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 3 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 0 - - -def test_decode_schedule_preempted(): - """ - Test decodes cannot be scheduled and preempted. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - curr_loras = None - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._add_seq_group_to_running(seq_group) - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # 1 cannot be scheduled, and the lowest priority (request 2) - # should be preempted. 1 will also be preempted. - budget = create_token_budget() - output = scheduler._schedule_running(budget, curr_loras) - remaining_running = scheduler.running - assert len(remaining_running) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert output.decode_seq_groups[0].seq_group.request_id == "0" - assert len(output.preempted) == 2 - # Verify budgets are updated. - assert budget.num_batched_tokens == 1 - # NOTE: When enable_chunk is False, num_seqs budget is not updated. - # assert budget.num_curr_seqs == 1 - # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == [] - # Nothing is copied. - assert output.blocks_to_copy == [] - - -def test_schedule_decode_blocks_to_copy_update(): - """ - Verify blocks_to_copy is updated. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=4, - num_cpu_blocks=16, - num_gpu_blocks=16) - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - curr_loras = None - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._add_seq_group_to_running(seq_group) - - # The last request should be swapped out. 
- scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = [(2, 3)] - - budget = create_token_budget() - output = scheduler._schedule_running(budget, curr_loras) - remaining_running = scheduler.running - assert len(remaining_running) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert len(output.preempted) == 0 - assert len(output.swapped_out) == 0 - # Nothing is preempted. - assert output.blocks_to_swap_out == [] - # Since append_slot returns the source -> dist mapping, it should - # applied. - assert output.blocks_to_copy == [(2, 3)] - - -def test_schedule_swapped_max_loras(): - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras: set[int] = set() - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 1 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert len(curr_loras) == 1 - - -def test_schedule_swapped_cannot_swap_in(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. - scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - # Since we cannot swap in, none of the requests are swapped in. - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -def test_infeasible_swap(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. 
- scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER - # Since we cannot swap in, none of the requests are swapped in. - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert len(output.infeasible_seq_groups) == 2 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -def test_schedule_swapped_blocks_to_copy(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out: list[tuple[int, int]] = [] - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. - scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = [(2, 3)] - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == [(2, 3)] - - -def test_scheduling_budget(): - TOKEN_BUDGET = 4 - MAX_SEQS = 4 - budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS) - assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1) - assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4) - assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5) - assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1) - assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5) - assert budget.remaining_token_budget() == TOKEN_BUDGET - - # Verify add/subtract num batched tokens. - _, seq_group = create_dummy_prompt("1", 3) - budget.add_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 2 - assert budget.num_batched_tokens == 2 - assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1) - assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1) - # Verify adding another seq group is no-op. - budget.add_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 2 - assert budget.num_batched_tokens == 2 - budget.subtract_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 4 - assert budget.num_batched_tokens == 0 - budget.subtract_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 4 - assert budget.num_batched_tokens == 0 - - # Verify add/subtract max seqs. - _, seq_group = create_dummy_prompt("1", 3) - budget.add_num_seqs(seq_group.request_id, 2) - assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2) - assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3) - assert budget.num_curr_seqs == 2 - # Verify adding another seq group is no-op. 
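# Simplified sketch of the bookkeeping exercised here (assumed names, not the
# real SchedulingBudget implementation): counts are keyed by request id, so a
# repeated add for the same request id is a no-op, as the assertions expect.
class TinyBudget:

    def __init__(self, max_num_seqs: int) -> None:
        self.max_num_seqs = max_num_seqs
        self._seqs_by_request: dict[str, int] = {}

    def add_num_seqs(self, request_id: str, num_seqs: int) -> None:
        # A second add for the same request id does not double count.
        self._seqs_by_request.setdefault(request_id, num_seqs)

    def subtract_num_seqs(self, request_id: str, num_seqs: int) -> None:
        # Dropping the per-request entry undoes whatever was added for it.
        self._seqs_by_request.pop(request_id, None)

    @property
    def num_curr_seqs(self) -> int:
        return sum(self._seqs_by_request.values())


tiny = TinyBudget(max_num_seqs=4)
tiny.add_num_seqs("1", 2)
tiny.add_num_seqs("1", 2)  # no-op, still counts 2 seqs
assert tiny.num_curr_seqs == 2
tiny.subtract_num_seqs("1", 2)
assert tiny.num_curr_seqs == 0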
- budget.add_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 2 - budget.subtract_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 0 - budget.subtract_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 0 - - -@pytest.mark.parametrize("enable_prefix_caching", [True, False]) -def test_prefix_caching_aware_prefills(enable_prefix_caching): - """ - Test the below scenario: - - For 3 sequences, seqA, seqB, seqC, share the first block as prefix. - - The test verifies the below scenarios: - 1. SeqA is first scheduled. - 2. SeqB and SeqC can be prefilled together in a single schedule round - even though there are not enough token budgets to prefill both without - considering prefix caching. - """ - - block_size = 4 - max_num_batched_tokens = 12 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_token_budget=max_num_batched_tokens, - max_num_seqs=max_seq_group, - max_model_len=max_num_batched_tokens, - enable_prefix_caching=enable_prefix_caching, - ) - - seqA_tokens = list(range(8)) - num_shared_tokens = 4 - seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 12, 16)) # Shared prefix first 4. - seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 16, 20)) # Shared prefix first 4. - - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) - - # Schedule seqA prefill. - scheduler.add_seq_group(seqA_group) - metas, out, _ = scheduler.schedule() - assert (len(out.scheduled_seq_groups) == 1 - and out.scheduled_seq_groups[0].seq_group == seqA_group) - assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) - - # Schedule seqA decode. - append_new_token_seq_group(len(seqA_tokens), seqA_group, 999) - metas, out, _ = scheduler.schedule() - - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 1 - - # Schedule seqB and seqC prefills should work with prefix caching. - scheduler.add_seq_group(seqB_group) - scheduler.add_seq_group(seqC_group) - metas, out, _ = scheduler.schedule() - - if enable_prefix_caching: - assert len(out.scheduled_seq_groups) == 2 - assert set([ - out.scheduled_seq_groups[0].seq_group, - out.scheduled_seq_groups[1].seq_group, - ]) == set([seqB_group, seqC_group]) - assert len(metas) == 2 - for meta in metas: - assert meta.token_chunk_size == 8 - assert (len(meta.computed_block_nums) == num_shared_tokens // - block_size) # 1 Block for the 8 tokens. - else: - assert len(out.scheduled_seq_groups) == 1 - assert len(metas) == 1 - assert metas[0].token_chunk_size == 8 - assert len(metas[0].computed_block_nums) == 0 # No blocks computed. - - -def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( -): - """ - This test verifies that we don't schedule new prefills if there's already - a continuous prefill in progress even though the new prefills with shared - prefix can fit in the token budget: - - - SeqA is being chunked prefill. - - SeqB with the same prompt shouldn't be scheduled for prefill even though - there's enough token budget to prefill the cached tokens. - - Neither should seqC be scheduled. 
- - - When seqA is in decoding phase, seqB and seqC can be scheduled. - - Entire seqB should be prefilled since it's a full prefix cache hit. - - SeqC would be partially prefilled with the prefix shared, and the - remaining unique tokens would be prefilled (rounded down to be - block-size aligned). - """ - - block_size = 2 - max_num_batched_tokens = 4 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_token_budget=max_num_batched_tokens, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - enable_chunked_prefill=True, - ) - - seqA_tokens = list(range(8)) - seqB_tokens = seqA_tokens - seqC_shared_prefix_len = 4 - seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) - - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) - - # Chunked prefill seqA. - scheduler.add_seq_group(seqA_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 4 - - # seqB should not be scheduled with ongoing prefills. - scheduler.add_seq_group(seqB_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 4 - - # both seqB and seqC can now be scheduled with seqA is over. - # seqA is in decoding phase. - append_new_token_seq(seqA, 999) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) - scheduler.add_seq_group(seqC_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 3 - - metas = {meta.request_id: meta for meta in metas} - assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode - assert (metas[seqB_group.request_id].token_chunk_size == 8 - ) # Fully cached prefill - assert ( - metas[seqC_group.request_id].token_chunk_size == 6 - ), "A partial prefix of C (4 tokens) should be prefilled, with the " - "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " - "then be rounded down to 2 tokens on block size, thus 6 tokens in total." - - -def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): - """ - Test that the scheduler does not schedule batches with prompt tokens and - prompt embeddings co-mingled. 
- """ - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - ) - - # the odd indexed inputs should be passed in via embeddings, - # evens via token_ids - seq_length = 7 - embedding_size = 5 - num_seqs = 11 - seq_tokens: list[list[int]] = [] - seq_embeds: list[Optional[torch.Tensor]] = [] - for i in range(num_seqs): - if i % 2: - seq_tokens.append(list(range(seq_length))) - seq_embeds.append(None) - else: - seq_tokens.append([0] * seq_length) - seq_embeds.append(torch.rand(embedding_size)) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) - for i in range(len(seq_tokens)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): - unfinished_seq_groups = [ - seq_group for _, seq_group in seq_and_seq_groups - if not seq_group.is_finished() - ] - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) > 0 - batch_is_prompt_embeds = out.scheduled_seq_groups[ - 0].seq_group.uses_prompt_embeds() - expected_scheduled_seq_groups = [ - seq_group for seq_group in unfinished_seq_groups - if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds - ] - - # We should have as many scheduled groups as possible, without mixing - assert len(out.scheduled_seq_groups) == min( - max_seq_group, len(expected_scheduled_seq_groups)) - assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == - batch_is_prompt_embeds - for scheduled_seq_group in out.scheduled_seq_groups) - - # Finish the scheduled groups - for scheduled_seq_group in out.scheduled_seq_groups: - for seq in scheduled_seq_group.seq_group.seqs: - seq.status = SequenceStatus.FINISHED_STOPPED - scheduler.free_finished_seq_groups() - - -def test_remove_seq_from_computed_blocks_tracker(): - """ - Test that computed_blocks_tracker correctly removes stale sequences - during scheduling. - - The test covers 9 scheduling branches where stale seqs are removed: - - 1 in _schedule_swapped - - 1 in _schedule_priority_preemption - - 7 in _schedule_prefill - - Each branch is tested to ensure proper cleanup of - _seq_id_to_num_tokens_computed. - """ - # Budget can not schedule in swapped - block_size = 2 - max_seq_group = 3 - seq_tokens_with_swapped: list[list[int]] = [] - blocks_to_swap_out: list[tuple[int, int]] = [] - curr_loras: set[int] = set() - - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - enable_prefix_caching=True, - ) - budget = create_token_budget(token_budget=15) - - seq_length = 16 - num_seqs = 3 - for i in range(num_seqs): - seq_tokens_with_swapped.append([i] * seq_length) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_with_swapped[i], - block_size=block_size) - for i in range(len(seq_tokens_with_swapped)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler._allocate_and_set_running(seq_group) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - scheduler._schedule_swapped(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill schedule don't have a space for another LoRA, so - # we ignore this request for now. - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64, - enable_prefix_caching=True) - budget = create_token_budget(token_budget=120) - num_seqs = 2 - for i in range(num_seqs): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=seq_length, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler.add_seq_group(seq_group) - - scheduler._schedule_prefills(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Priority preemption schedule - scheduler._schedule_priority_preemption(budget) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill scheduler does not schedule batches with prompt tokens and - # prompt embeddings co-mingled. - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - ) - seq_length = 7 - embedding_size = 5 - seq_tokens_with_embedding: list[list[int]] = [] - seq_embeds: list[Optional[torch.Tensor]] = [] - - seq_tokens_with_embedding.append(list(range(seq_length))) - seq_embeds.append(None) - seq_tokens_with_embedding.append([0] * seq_length) - seq_embeds.append(torch.rand(embedding_size)) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_with_embedding[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) - for i in range(len(seq_tokens_with_embedding)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill scheduler budget num_batched_tokens - # >= scheduler_config max_num_batched_tokens - block_size = 2 - max_seq_group = 3 - seq_tokens_prefill_budget: list[list[int]] = [] - - scheduler = initialize_scheduler( - block_size=block_size, - max_token_budget=8, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=5, - enable_prefix_caching=True, - ) - seq_length = 4 - num_seqs = 3 - for i in range(num_seqs): - seq_tokens_prefill_budget.append([i] * seq_length) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget[i], - block_size=block_size) - for i in range(len(seq_tokens_prefill_budget)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(2)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not schedule in waiting - block_size = 2 - max_seq_group = 3 - - scheduler = initialize_scheduler( - block_size=block_size, - max_token_budget=30, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=30, - enable_prefix_caching=True, - ) - seq_length = 16 - num_seqs = 3 - seq_tokens_prefill_budget_waiting: list[list[int]] = [] - - for i in range(num_seqs): - seq_tokens_prefill_budget_waiting.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget_waiting[i], - block_size=block_size) - for i in range(len(seq_tokens_prefill_budget_waiting)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Sequence num_new_tokens > prompt_limit marked FINISHED_IGNORED - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=30, - enable_prefix_caching=True, - ) - - seq_length = 31 - seq_tokens_prompt_limit: list[list[int]] = [] - seq_tokens_prompt_limit.append(list(range(seq_length))) - seq_and_seq_groups = [ - create_dummy_prompt("0", - prompt_tokens=seq_tokens_prompt_limit[0], - block_size=block_size) - ] - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(0)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not allocate, AllocStatus is NEVER marked FINISHED_IGNORED - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=160, - num_gpu_blocks=160, - max_num_seqs=max_seq_group, - max_model_len=320, - enable_prefix_caching=True, - ) - - seq_length = 320 - num_seqs = 1 - seq_tokens_never: list[list[int]] = [] - for i in range(num_seqs): - seq_tokens_never.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_never[i], - block_size=block_size) - for i in range(len(seq_tokens_never)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(0)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not allocate, AllocStatus is LATER - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=160, - num_gpu_blocks=160, - max_num_seqs=max_seq_group, - max_model_len=320, - enable_prefix_caching=True, - ) - - seq_length = 160 - num_seqs = 2 - seq_tokens_later: list[list[int]] = [] - for i in range(num_seqs): - seq_tokens_later.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_later[i], - block_size=block_size) - for i in range(len(seq_tokens_later)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py deleted file mode 100644 index 20cc083ec8db4..0000000000000 --- a/tests/core/test_scheduler_encoder_decoder.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest # noqa - -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.sequence import SequenceGroup - -from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, - get_sequence_groups, schedule_and_update_computed_tokens) - - -def test_scheduler_schedule_simple_encoder_decoder(): - ''' - Test basic scheduler functionality in the context - of an encoder/decoder model. Focus on testing - enc/dec-specific functionality sense tests already - exist for decoder-only functionality - - Test behavior: - * Construct Scheduler - * Construct dummy encoder/decoder sequence groups - * Add dummy seq groups to scheduler backlog - * Schedule the next seq group & validate: - * Cross-attn block tables - * Updated states of seq groups - * Number of batched tokens - * Number of blocks to copy/swap-in/swap-out - * Number of scheduled seq groups - * Repeat for both prefill- and decode-phase - * Abort scheduled seq groups - * Assert that aborted seq groups no longer appear in - cross-attention block table - ''' - - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=num_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group - cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - req_id_list = [] - for i in range(num_seq_group): - req_id = str(i) - req_id_list.append(req_id) - _, _, seq_group = create_dummy_prompt_encoder_decoder( - req_id, block_size, block_size, block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prefill. 
- num_tokens = block_size * num_seq_group - seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) - # - Verify that sequence group cross-attention block tables are - # registered with the block manager - assert all([(req_id in scheduler.block_manager.cross_block_tables) - for req_id in req_id_list]) - # - Validate sequence-group status - assert set(get_sequence_groups(out)) == set(running) - # - Validate number of batched tokens - assert out.num_batched_tokens == num_tokens - # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - # - Validate all seq groups were scheduled - assert len(seq_group_meta_list) == num_seq_group - append_new_token(out, 1) - - # Schedule seq groups decode. - seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) - # - Verify that sequence group metadata includes encoder attention - # and cross-attention metadata - assert all([ - not ((seq_group_meta.encoder_seq_data is None) or - (seq_group_meta.cross_block_table is None)) - for seq_group_meta in seq_group_meta_list - ]) - # - Validate sequence-group status - assert set(get_sequence_groups(out)) == set(running) - # - Validate there is one batched token per seq group - assert out.num_batched_tokens == num_seq_group - # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - # - Validate that all seq groups were scheduled - assert len(seq_group_meta_list) == num_seq_group - append_new_token(out, 1) - - # Abort sequences - for req_id in req_id_list: - scheduler.abort_seq_group(req_id) - # - Verify that sequence group cross-attention block tables are - # NO LONGER registered with the block manager - assert req_id not in scheduler.block_manager.cross_block_tables diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py deleted file mode 100644 index ee9ac2129f2db..0000000000000 --- a/tests/core/test_serialization.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import msgspec - -from vllm.executor.msgspec_utils import decode_hook, encode_hook -from vllm.sequence import ExecuteModelRequest - -from .utils import create_batch - - -def test_msgspec_serialization(): - num_lookahead_slots = 4 - seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=num_lookahead_slots, - running_queue_size=4) - - encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) - req = decoder.decode(encoder.encode(execute_model_req)) - expected = execute_model_req.seq_group_metadata_list - actual = req.seq_group_metadata_list - assert (len(expected) == len(actual)) - expected = expected[0] - actual = actual[0] - - assert expected.block_tables == actual.block_tables - assert expected.is_prompt == actual.is_prompt - assert expected.request_id == actual.request_id - assert (expected.seq_data[0].prompt_token_ids == - actual.seq_data[0].prompt_token_ids) - assert (expected.seq_data[0].output_token_ids == - actual.seq_data[0].output_token_ids) diff --git a/tests/core/utils.py b/tests/core/utils.py deleted file mode 100644 index 033fffd2c4e24..0000000000000 --- a/tests/core/utils.py +++ /dev/null @@ -1,392 
+0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from collections import defaultdict -from collections.abc import Sequence as GenericSequence -from itertools import count -from typing import Any, Optional, Union - -import torch - -from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int = -1, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_tokens: Optional[list[int]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - min_tokens: int = 0, - max_tokens: int = 16, -) -> tuple[Sequence, SequenceGroup]: - if not block_size: - block_size = prompt_length - - if prompt_tokens is None: - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - - prompt_str = " ".join([str(t) for t in prompt_tokens]) - inputs = token_inputs( - prompt_token_ids=prompt_tokens, - prompt=prompt_str) if prompt_embeds is None else embeds_inputs( - prompt_embeds=prompt_embeds) - prompt = Sequence( - int(request_id), - inputs=inputs, - block_size=block_size, - ) - seq_group = SequenceGroup( - request_id=request_id, - seqs=[prompt], - arrival_time=time.time(), - sampling_params=SamplingParams(max_tokens=max_tokens, - min_tokens=min_tokens), - lora_request=lora_request, - ) - - return prompt, seq_group - - -def create_dummy_lora_sequence(request_id: int, token_ids: list[int], - block_size: int, lora_int_id: int) -> Sequence: - return Sequence(seq_id=request_id, - inputs=token_inputs(token_ids), - block_size=block_size, - lora_request=LoRARequest(lora_name="dummy", - lora_path="/dummy", - lora_int_id=lora_int_id)) - - -def create_dummy_sequence(request_id: int, token_ids: list[int], - block_size: int) -> Sequence: - return Sequence( - seq_id=request_id, - inputs=token_inputs(token_ids), - block_size=block_size, - ) - - -def create_dummy_prompt_encoder_decoder( - request_id: str, - decoder_prompt_length: int, - encoder_prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, -) -> tuple[Sequence, Sequence, SequenceGroup]: - if not block_size: - block_size = decoder_prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". 
Note that the prompt string - # doesn't actually match the tokens - decoder_prompt_tokens = list(range(decoder_prompt_length)) - decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) - encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - - inputs: EncoderDecoderInputs = { - "decoder": token_inputs(decoder_prompt_tokens, - prompt=decoder_prompt_str), - "encoder": token_inputs(encoder_prompt_tokens, - prompt=encoder_prompt_str), - } - - decoder_prompt = Sequence(int(request_id), - inputs=inputs["decoder"], - block_size=block_size) - - encoder_prompt = Sequence(int(request_id), - inputs=inputs["encoder"], - block_size=block_size) - - seq_group = SequenceGroup(request_id=request_id, - seqs=[decoder_prompt], - arrival_time=time.time(), - lora_request=lora_request, - encoder_seq=encoder_prompt) - - return decoder_prompt, encoder_prompt, seq_group - - -def create_seq_group( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - - assert len(seq_output_lens) > 0 - - if sampling_params is None: - sampling_params = SamplingParams() - - prompt_token_ids = [0] * seq_prompt_len - - seqs: list[Sequence] = [] - for seq_id_offset, output_len in enumerate(seq_output_lens): - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - for i in range(output_len): - seq.append_token_id( - token_id=i, - logprobs={i: Logprob(0.0)}, - ) - seqs.append(seq) - - seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - ) - - return seq_group - - -def create_seq_group_encoder_decoder( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - - assert len(seq_output_lens) > 0 - - if sampling_params is None: - sampling_params = SamplingParams() - - prompt_token_ids = [0] * seq_prompt_len - - inputs: EncoderDecoderInputs = { - "decoder": token_inputs(prompt_token_ids), - "encoder": token_inputs(prompt_token_ids), - } - - seqs = [] - for seq_id_offset, output_len in enumerate(seq_output_lens): - # Construct decoder input sequences - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs=inputs["decoder"], - block_size=16, - ) - - for i in range(output_len): - seq.append_token_id( - token_id=i, - logprobs={i: Logprob(0.0)}, - ) - seqs.append(seq) - - # Encoder input sequence - encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs["encoder"], - block_size=16, - ) - - return SequenceGroup(request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq) - - -def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size - - -# Helper functions for scheduler tests - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(out, token_id: int): - seq_groups = get_sequence_groups(out) - for seq_group in seq_groups: - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def 
schedule_and_update_computed_tokens(scheduler): - metas, out, _ = scheduler.schedule() - for s in out.scheduled_seq_groups: - s.seq_group.update_num_computed_tokens(s.token_chunk_size) - return metas, out - - -def append_new_token_seq(seq: Sequence, token_id: int): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): - seq_group.update_num_computed_tokens(token_chunk_size) - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -class SchedulerProxy: - """ - A proxy class to forward calls to the scheduler. - """ - - def __init__(self, scheduler: Scheduler): - self.scheduler_ = scheduler - self.call_history: dict[str, list[Any]] = defaultdict(list) - - def __getattr__(self, name: str) -> Any: - - def wrapper(*args, **kwargs): - result = getattr(self.scheduler_, name)(*args, **kwargs) - self.call_history[name].append((args, kwargs, result)) - return result - - return wrapper - - def last_schedule_ret( - self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: - _, _, ret = self.call_history["schedule"][-1] - return ret - - -def create_seq_group_metadata_from_prompts( - prompts: list[list[int]], - num_gpu_blocks: int, - block_size: int, - final_prompt_lens: list[int], - continuations: Optional[list[list[int]]] = None, - seq_ids: Optional[list[int]] = None, -) -> list[SequenceGroupMetadata]: - - if continuations is None: - continuations = [[] for _ in prompts] - - if seq_ids is None: - seq_ids = list(i for i, _ in enumerate(prompts)) - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = { - i: [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(final_len, block_size)) - ] - for i, final_len in enumerate(final_prompt_lens) - } - - seq_grou_metadata_list = [] - for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)): - data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) - data.update_num_computed_tokens( - len(prompt_token_ids) + len(cont_token_ids) - 1) - seq_data = {i: data} - seq_grou_metadata_list.append( - SequenceGroupMetadata( - request_id=str(i), - is_prompt=len(cont_token_ids) == 0, - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations[i][:]}, - )) - return seq_grou_metadata_list - - -def create_chunked_seq_group_metadata_from_prompt( - prompt: list[int], - num_gpu_blocks: int, - chunk_size: int, - block_size: int, - seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]: - - if seq_id is None: - seq_id = 0 - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(len(prompt), block_size)) - ] - - seq_group_metadata_list = [] - for i, idx in enumerate(range(0, len(prompt), chunk_size)): - chunk_ids = prompt[idx:idx + chunk_size] - data = SequenceData.from_seqs(prompt) - data.update_num_computed_tokens(idx) - seq_data = {i: data} - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=str(seq_id), - is_prompt=True, - do_sample=idx + chunk_size >= len(prompt), # terminal chunk - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations}, - token_chunk_size=len(chunk_ids))) - return seq_group_metadata_list - - -def create_batch(batch_size, - k, - prompt_len: Union[int, list[int]] = 10, - prev_output_token_len: int = 10, - seq_ids: Optional[list[int]] = None, - 
num_gpu_blocks: Optional[int] = None, - block_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None): - if block_size is None: - block_size = 8 - - if num_gpu_blocks is None: - num_gpu_blocks = 2048 // block_size - - iterator = count() - - if isinstance(prompt_len, int): - prompt_lens = [prompt_len for _ in range(batch_size)] - else: - prompt_lens = prompt_len - - prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] - - if prefill_chunk_size: - # Create a batch of chunked prompts. - if not seq_ids: - seq_ids = list(range(len(prompts))) - seq_group_metadata_list = [] - for p, sid in zip(prompts, seq_ids): - seq_group_metadata_list += \ - create_chunked_seq_group_metadata_from_prompt( - p, num_gpu_blocks, prefill_chunk_size, block_size, sid) - seq_group_metadata_list = seq_group_metadata_list[:batch_size] - prev_output_tokens = [] - else: - prev_output_tokens = [[ - next(iterator) for _ in range(prev_output_token_len) - ] for _ in range(batch_size)] - final_prompt_lens = [ - len(prompt) + len(prev_output_token) + k + 1 - for prompt, prev_output_token in zip(prompts, prev_output_tokens) - ] - - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, final_prompt_lens, - prev_output_tokens, seq_ids) - return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/cuda/test_cuda_context.py b/tests/cuda/test_cuda_context.py index f973b284b87e1..6336f2112c66e 100644 --- a/tests/cuda/test_cuda_context.py +++ b/tests/cuda/test_cuda_context.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform def check_cuda_context(): """Check CUDA driver context status""" try: - cuda = ctypes.CDLL('libcuda.so') + cuda = ctypes.CDLL("libcuda.so") device = ctypes.c_int() result = cuda.cuCtxGetDevice(ctypes.byref(device)) return (True, device.value) if result == 0 else (False, None) @@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id): # New thread should have no CUDA context initially valid_before, device_before = check_cuda_context() if valid_before: - return False, \ - "CUDA context should not exist in new thread, " \ - f"got device {device_before}" + return ( + False, + "CUDA context should not exist in new thread, " + f"got device {device_before}", + ) # Test setting CUDA context current_platform.set_device(device_input) @@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id): if not valid_after: return False, "CUDA context should be valid after set_cuda_context" if device_id != expected_device_id: - return False, \ - f"Expected device {expected_device_id}, got {device_id}" + return False, f"Expected device {expected_device_id}, got {device_id}" return True, "Success" except Exception as e: @@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id): class TestSetCudaContext: """Test suite for the set_cuda_context function.""" - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") - @pytest.mark.parametrize(argnames="device_input,expected_device_id", - argvalues=[ - (0, 0), - (torch.device('cuda:0'), 0), - ('cuda:0', 0), - ], - ids=["int", "torch_device", "string"]) - def test_set_cuda_context_parametrized(self, device_input, - expected_device_id): + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") + @pytest.mark.parametrize( + argnames="device_input,expected_device_id", + argvalues=[ + (0, 0), + (torch.device("cuda:0"), 0), + ("cuda:0", 0), + ], + ids=["int", 
"torch_device", "string"], + ) + def test_set_cuda_context_parametrized(self, device_input, expected_device_id): """Test setting CUDA context in isolated threads.""" with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_cuda_test_in_thread, device_input, - expected_device_id) + future = executor.submit( + run_cuda_test_in_thread, device_input, expected_device_id + ) success, message = future.result(timeout=30) assert success, message - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") def test_set_cuda_context_invalid_device_type(self): """Test error handling for invalid device type.""" with pytest.raises(ValueError, match="Expected a cuda device"): - current_platform.set_device(torch.device('cpu')) + current_platform.set_device(torch.device("cpu")) if __name__ == "__main__": diff --git a/tests/detokenizer/conftest.py b/tests/detokenizer/conftest.py deleted file mode 100644 index f2c125355c83c..0000000000000 --- a/tests/detokenizer/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass diff --git a/tests/detokenizer/test_disable_detokenization.py b/tests/detokenizer/test_disable_detokenization.py index ae06a985c7ecd..a77626df5dc78 100644 --- a/tests/detokenizer/test_disable_detokenization.py +++ b/tests/detokenizer/test_disable_detokenization.py @@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str): prompt = ( "You are a helpful assistant. How do I build a car from cardboard and " "paper clips? Is there an easy to follow video tutorial available " - "online for free?") + "online for free?" 
+ ) llm = LLM(model=model) - sampling_params = SamplingParams(max_tokens=10, - temperature=0.0, - detokenize=False) + sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False) - outputs_no_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] sampling_params.detokenize = True - outputs_with_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] - assert outputs_no_detokenization.text == '' - assert outputs_with_detokenization.text != '' - assert outputs_no_detokenization.token_ids == \ - outputs_with_detokenization.token_ids + assert outputs_no_detokenization.text == "" + assert outputs_with_detokenization.text != "" + assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py index 887e83342536e..1f8e944695bdc 100644 --- a/tests/detokenizer/test_min_tokens.py +++ b/tests/detokenizer/test_min_tokens.py @@ -8,15 +8,17 @@ from vllm import SamplingParams from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer -PROMPT = "Hello, my name is Lee, and I'm a student in the " + \ - "college of engineering" +PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering" -@pytest.mark.parametrize("min_tokens,stop,truth", [ - (0, None, " is Lee, and I'm a student in the college of engineering"), - (0, "e", " is L"), - (5, "e", " is Lee, and I'm a stud"), -]) +@pytest.mark.parametrize( + "min_tokens,stop,truth", + [ + (0, None, " is Lee, and I'm a student in the college of engineering"), + (0, "e", " is L"), + (5, "e", " is Lee, and I'm a stud"), + ], +) def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): """Test for a specific min_tokens and stop. @@ -31,18 +33,18 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): stop=stop, min_tokens=min_tokens, ) - request = EngineCoreRequest("", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, - cache_salt=None, - data_parallel_rank=None) + request = EngineCoreRequest( + request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) detokenizer = FastIncrementalDetokenizer(tokenizer, request) diff --git a/tests/detokenizer/test_stop_checker.py b/tests/detokenizer/test_stop_checker.py deleted file mode 100644 index bd221977224f9..0000000000000 --- a/tests/detokenizer/test_stop_checker.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest -from transformers import PreTrainedTokenizer - -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.inputs import token_inputs -from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, Sequence, SequenceStatus - - -def sequence_with_eos(text: str, eos_token: str, - eos_token_id: int) -> Sequence: - """ - Create a Sequence that ends with an EOS token. 
- """ - seq = Sequence( - seq_id=0, - inputs=token_inputs([]), - block_size=16, - eos_token_id=eos_token_id, - ) - seq.output_text = text + eos_token - - offset = eos_token_id + 1 - for i in range(offset, len(text) + offset): - seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) - seq.append_token_id(token_id=eos_token_id, - logprobs={eos_token_id: Logprob(0.0)}) - - seq.status = SequenceStatus.RUNNING - - return seq - - -@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ - ("This text ends with EOS token", "</s>", 2), -]) -@pytest.mark.parametrize("ignore_eos", [True, False]) -@pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -@pytest.mark.skip_global_cleanup -def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, - ignore_eos: bool, include_stop_str_in_output: bool): - """ - Test the behavior of the StopChecker's maybe_stop_sequence method - when an EOS token is encountered. - - This test covers: - - When the EOS token should stop the sequence and be removed from the output - - When the EOS token should stop the sequence and be included in the output - - When the EOS token should be ignored, and the sequence continues - """ - - tokenizer = MagicMock(spec=PreTrainedTokenizer) - get_tokenizer_for_seq = MagicMock(return_value=tokenizer) - stop_checker = StopChecker(max_model_len=1024, - get_tokenizer_for_seq=get_tokenizer_for_seq) - - seq = sequence_with_eos( - text=text_wo_eos, - eos_token=eos_token, - eos_token_id=eos_token_id, - ) - new_char_count = len(eos_token) - - # Note that `stop` and `stop_token_ids` are not specified - sampling_params = SamplingParams( - min_tokens=1, - ignore_eos=ignore_eos, - include_stop_str_in_output=include_stop_str_in_output) - - stop_checker.maybe_stop_sequence( - seq=seq, - new_char_count=new_char_count, - sampling_params=sampling_params, - ) - - if ignore_eos: - assert seq.status == SequenceStatus.RUNNING - assert seq.output_text == text_wo_eos + eos_token - elif include_stop_str_in_output: - assert seq.status == SequenceStatus.FINISHED_STOPPED - assert seq.output_text == text_wo_eos + eos_token - else: - assert seq.status == SequenceStatus.FINISHED_STOPPED - assert seq.output_text == text_wo_eos diff --git a/tests/detokenizer/test_stop_reason.py b/tests/detokenizer/test_stop_reason.py index 1ff679789c959..6565949cc50fc 100644 --- a/tests/detokenizer/test_stop_reason.py +++ b/tests/detokenizer/test_stop_reason.py @@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts): llm = vllm_model.llm # test stop token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop_token_ids=[stop_token_id])) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop_token_ids=[stop_token_id], + ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == stop_token_id # test stop string - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop=".")) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="." 
+ ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == STOP_STR # test EOS token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS)) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "length" or ( - output.finish_reason == "stop" and output.stop_reason is None) + output.finish_reason == "stop" and output.stop_reason is None + ) diff --git a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py new file mode 100644 index 0000000000000..5624332ef71d6 --- /dev/null +++ b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import BaseIncrementalDetokenizer + + +@pytest.fixture(params=[True, False]) +def include_stop_str_in_output(request): + return request.param + + +class _DummyDetokenizer(BaseIncrementalDetokenizer): + def __init__(self, request: EngineCoreRequest): + super().__init__(request) + + def decode_next(self, next_token_id: int) -> str: + # Map token id to single ASCII character for deterministic testing. + return chr(next_token_id) + + +def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0): + params = SamplingParams( + stop=stop, + include_stop_str_in_output=include_stop_str_in_output, + min_tokens=min_tokens, + ) + # Keep other fields minimal for unit test purposes. + req = EngineCoreRequest( + request_id="test", + prompt_token_ids=[], + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) + return req + + +def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool): + """ + This test verifies that the detokenizer correctly handles the case where + the generated token sequence contains both: + - a stop token + - an <eos> token + + The detokenizer should respect the stop string and truncate the output + accordingly. + + Imagine the following sequence: + - "abcdeZ" is generated, where "Z" is the <eos> token. + - "cd" is the stop string. + + If include_stop_str_in_output=False, the detokenizer should truncate the + output to "ab" because the stop string "cd" is excluded. + If include_stop_str_in_output=True, the detokenizer should include the stop + string "cd" in the output, resulting in "abcd". + + + This verifies the behavioral change introduced in BaseIncrementalDetokenizer + where stop-string evaluation occurs before the early-return on + stop_terminated. + """ + + # Generate text "abcdeZ" and tokenize it. + generated_text = "abcde" + eos_token = "Z" + stop_string = "cd" + generated_text = generated_text + eos_token + token_ids = [ord(c) for c in generated_text] + + # Create a request with the stop string and initialize the detokenizer. + req = _make_request( + stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output + ) + detok = _DummyDetokenizer(req) + + # Simulate that the last token ('Z') is a stop token (stop_terminated=True). 
+    result = detok.update(new_token_ids=token_ids, stop_terminated=True)
+
+    # The update should report the matched stop string.
+    assert result == stop_string
+
+    # Output text should reflect stop-string handling:
+    # - include_stop_str_in_output=False => exclude "cd" => "ab"
+    # - include_stop_str_in_output=True => include "cd" => "abcd"
+    expected_text = "abcd" if include_stop_str_in_output else "ab"
+    assert detok.output_text == expected_text
+
+    # The skipped final token should still be recorded in token_ids.
+    assert detok.output_token_ids == token_ids
+
+    # get_next_output_text should return the full text when finished=True.
+    # (Buffering only applies during streaming when finished=False.)
+    assert detok.get_next_output_text(finished=True, delta=False) == expected_text
diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py
index cb87c44cc3999..d59b394393e34 100644
--- a/tests/detokenizer/test_stop_strings.py
+++ b/tests/detokenizer/test_stop_strings.py
@@ -5,18 +5,20 @@ from typing import Any, Optional
 
 import pytest
 
-from vllm import LLM, SamplingParams, envs
+from vllm import LLM, SamplingParams
 
 MODEL = "meta-llama/llama-2-7b-hf"
 MAX_TOKENS = 200
 
 
-def _test_stopping(llm: LLM,
-                   expected_output: str,
-                   expected_reason: Any,
-                   stop: Optional[list[str]] = None,
-                   stop_token_ids: Optional[list[int]] = None,
-                   include_in_output: bool = False) -> None:
+def _test_stopping(
+    llm: LLM,
+    expected_output: str,
+    expected_reason: Any,
+    stop: Optional[list[str]] = None,
+    stop_token_ids: Optional[list[int]] = None,
+    include_in_output: bool = False,
+) -> None:
     output = llm.generate(
         "A story about vLLM:\n",
         SamplingParams(
@@ -25,29 +27,30 @@ def _test_stopping(llm: LLM,
             stop=stop,
             stop_token_ids=stop_token_ids,
             include_stop_str_in_output=include_in_output,
-        ))[0].outputs[0]
+        ),
+    )[0].outputs[0]
 
     assert output is not None
     assert output.text == expected_output
     assert output.stop_reason == expected_reason
 
 
-def _set_async_mode(llm, is_async):
-    llm.llm_engine.scheduler[0].use_async_output_proc = is_async
-
-
 def _stop_basic(llm):
-    _test_stopping(llm,
-                   stop=["."],
-                   include_in_output=False,
-                   expected_output="VLLM is a 100% volunteer organization",
-                   expected_reason=".")
+    _test_stopping(
+        llm,
+        stop=["."],
+        include_in_output=False,
+        expected_output="VLLM is a 100% volunteer organization",
+        expected_reason=".",
+    )
 
-    _test_stopping(llm,
-                   stop=["."],
-                   include_in_output=True,
-                   expected_output="VLLM is a 100% volunteer organization.",
-                   expected_reason=".")
+    _test_stopping(
+        llm,
+        stop=["."],
+        include_in_output=True,
+        expected_output="VLLM is a 100% volunteer organization.",
+        expected_reason=".",
+    )
 
 
 def _stop_multi_tokens(llm):
@@ -56,87 +59,62 @@ def _stop_multi_tokens(llm):
         stop=["group of peo", "short"],
         include_in_output=False,
         expected_output="VLLM is a 100% volunteer organization. We are a ",
-        expected_reason="group of peo")
+        expected_reason="group of peo",
+    )
 
     _test_stopping(
         llm,
         stop=["group of peo", "short"],
         include_in_output=True,
-        expected_output=
-        "VLLM is a 100% volunteer organization. We are a group of peo",
-        expected_reason="group of peo")
+        expected_output="VLLM is a 100% volunteer organization. 
We are a group of peo", + expected_reason="group of peo", + ) def _stop_partial_token(llm): - _test_stopping(llm, - stop=["gani"], - include_in_output=False, - expected_output="VLLM is a 100% volunteer or", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani", + ) - _test_stopping(llm, - stop=["gani"], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organi", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani", + ) def _stop_token_id(llm): # token id 13013 => " organization" - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=False, - expected_output="VLLM is a 100% volunteer", - expected_reason=13013) + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013, + ) - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=13013) + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013, + ) @pytest.mark.skip_global_cleanup def test_stop_strings(): - # If V0, must set enforce_eager=False since we use - # async output processing below. - llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) + llm = LLM(MODEL, enforce_eager=True) - if envs.VLLM_USE_V1: - _stop_basic(llm) - else: - _set_async_mode(llm, True) - _stop_basic(llm) - - _set_async_mode(llm, False) - _stop_basic(llm) - - if envs.VLLM_USE_V1: - _stop_multi_tokens(llm) - else: - _set_async_mode(llm, True) - _stop_multi_tokens(llm) - - _set_async_mode(llm, False) - _stop_multi_tokens(llm) - - if envs.VLLM_USE_V1: - _stop_partial_token(llm) - else: - _set_async_mode(llm, True) - _stop_partial_token(llm) - - _set_async_mode(llm, False) - _stop_partial_token(llm) - - if envs.VLLM_USE_V1: - # FIXME: this does not respect include_in_output=False - # _stop_token_id(llm) - pass - else: - _set_async_mode(llm, True) - _stop_token_id(llm) - - _set_async_mode(llm, False) - _stop_token_id(llm) + _stop_basic(llm) + _stop_multi_tokens(llm) + _stop_partial_token(llm) + # FIXME: this does not respect include_in_output=False + # _stop_token_id(llm) diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 666a715cc0da1..47ceb45057c97 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -8,7 +8,7 @@ import msgspec.msgpack import pytest import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.distributed.kv_events import EventPublisherFactory from .test_events import SampleBatch @@ -111,8 +111,7 @@ class MockSubscriber: self.last_seq = -1 self.decoder = msgspec.msgpack.Decoder(type=decode_type) - def receive_one(self, - timeout=1000) -> Union[tuple[int, SampleBatch], None]: + def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]: """Receive a single message with timeout""" if not self.sub.poll(timeout): return None @@ -135,8 +134,7 @@ class MockSubscriber: self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big")) - def receive_replay(self, - socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: + def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: 
"""Receive replayed messages from a specific replay socket""" if not self.replay_sockets: raise ValueError("Replay sockets not initialized") diff --git a/tests/distributed/test_ca_buffer_sharing.py b/tests/distributed/test_ca_buffer_sharing.py index e2de462612b47..1ddce64f8e614 100644 --- a/tests/distributed/test_ca_buffer_sharing.py +++ b/tests/distributed/test_ca_buffer_sharing.py @@ -12,7 +12,8 @@ import torch.distributed as dist from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa - CustomAllreduce) + CustomAllreduce, +) # create a cpu process group for communicating metadata (ipc handle) dist.init_process_group(backend="gloo") @@ -52,7 +53,8 @@ for p in pointers: assert ord(host_data[i]) == byte_value, ( f"Rank {rank} failed" f" to verify buffer {p}. Expected {byte_value}, " - f"got {ord(host_data[i])}") + f"got {ord(host_data[i])}" + ) print(f"Rank {rank} verified all buffers") diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index e2cb579e22dc4..c61c4584d8376 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -13,12 +13,19 @@ import pytest import ray import torch -from vllm.distributed import (broadcast_tensor_dict, get_pp_group, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) +from vllm.distributed import ( + broadcast_tensor_dict, + get_pp_group, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter, +) -from ..utils import init_test_distributed_environment, multi_process_parallel +from ..utils import ( + init_test_distributed_environment, + multi_gpu_test, + multi_process_parallel, +) @ray.remote(num_gpus=1, max_calls=1) @@ -36,12 +43,11 @@ def all_reduce_test_worker( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank % tp_size] @@ -50,28 +56,31 @@ def all_reduce_test_worker( @ray.remote(num_gpus=1, max_calls=1) -def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, - pp_size: int, rank: int, - distributed_init_port: str): +def reduce_scatter_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, +): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] index = 
rank % tp_size partition_size = num_elements // tp_size all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - expected = all_reduce[index * partition_size:(index + 1) * partition_size] + expected = all_reduce[index * partition_size : (index + 1) * partition_size] t = all_tensors[index] t = tensor_model_parallel_reduce_scatter(t, 0) torch.testing.assert_close(t, expected) @@ -91,8 +100,7 @@ def all_gather_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) total_size = 1 @@ -100,8 +108,10 @@ def all_gather_test_worker( total_size *= s for all_gather_dimension in range(num_dimensions): all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="cuda").reshape(tensor_size) * (r + 1) + torch.arange(total_size, dtype=torch.float32, device="cuda").reshape( + tensor_size + ) + * (r + 1) for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) @@ -124,8 +134,7 @@ def broadcast_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), @@ -133,10 +142,7 @@ def broadcast_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -165,8 +171,7 @@ def send_recv_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor @@ -175,10 +180,7 @@ def send_recv_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -210,8 +212,7 @@ def send_recv_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) size = 64 test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") @@ -226,13 +227,12 @@ def send_recv_test_worker( torch.testing.assert_close(test_tensor, recv_tensor) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", [ - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [all_reduce_test_worker, all_gather_test_worker, 
broadcast_tensor_dict_test_worker], +) def test_multi_process_tensor_parallel( monkeypatch: pytest.MonkeyPatch, tp_size: int, @@ -241,11 +241,11 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( - "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) + "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker] +) def test_multi_process_pipeline_parallel( monkeypatch: pytest.MonkeyPatch, pp_size: int, @@ -254,15 +254,19 @@ def test_multi_process_pipeline_parallel( multi_process_parallel(monkeypatch, 1, pp_size, test_target) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) -@pytest.mark.parametrize("test_target", [ - send_recv_test_worker, send_recv_tensor_dict_test_worker, - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [ + send_recv_test_worker, + send_recv_tensor_dict_test_worker, + all_reduce_test_worker, + all_gather_test_worker, + broadcast_tensor_dict_test_worker, + ], +) def test_multi_process_tensor_parallel_pipeline_parallel( tp_size: int, pp_size: int, diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py new file mode 100644 index 0000000000000..89c2c9f8badeb --- /dev/null +++ b/tests/distributed/test_context_parallel.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. 
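+
+Here, "context parallel" means decode context parallel (DCP): each case below
+starts a server with --decode-context-parallel-size and compares it against an
+otherwise identical TP/PP configuration via compare_two_settings.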
+""" + +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config.model import RunnerOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_context_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + dcp_size: int + eager_mode: bool + chunked_prefill: bool + + +class CPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class CPTestSettings: + parallel_setups: list[ParallelSetup] + distributed_backends: list[str] + runner: RunnerOption + test_options: CPTestOptions + + @staticmethod + def detailed( + *, + tp_base: int = 4, + pp_base: int = 1, + dcp_base: int = 1, + multi_node_only: bool = False, + runner: RunnerOption = "auto", + load_format: Optional[str] = None, + ): + parallel_setups = [] + for eager_mode_val in [False]: + for pp_multiplier in [1]: + for dcp_multiplier in [0.5, 1]: + for chunked_prefill_val in [True]: + parallel_setups.append( + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + dcp_size=int(dcp_multiplier * tp_base), + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) + return CPTestSettings( + parallel_setups=parallel_setups, + distributed_backends=["mp"], + runner=runner, + test_options=CPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend in self.distributed_backends: + yield ( + model_id, + parallel_setup, + backend, + self.runner, + opts, + ) + + +def _compare_cp_with_tp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], + is_multimodal: bool, +): + ( + tp_size, + pp_size, + dcp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + 
common_args.append("--enforce-eager") + if runner != "auto": + common_args.extend(["--runner", runner]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + cp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--decode-context-parallel-size", + str(dcp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + compare_two_settings( + model_id, + cp_args, + tp_args, + method=method, + max_wait_seconds=720, + ) + + +CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] + "deepseek-ai/DeepSeek-V2-Lite-Chat": [ + CPTestSettings.detailed(), + CPTestSettings.detailed(tp_base=2), + ], +} + +CP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "deepseek-ai/DeepSeek-V2-Lite-Chat", +] + + +@pytest.mark.parametrize( + ( + "model_id", + "parallel_setup", + "distributed_backend", + "runner", + "test_options", + ), + [ + params + for model_id, settings in CP_TEXT_GENERATION_MODELS.items() + for setting in settings + for params in setting.iter_params(model_id) + if model_id in CP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_cp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available, +): + _compare_cp_with_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 9212c04deec90..f6e274be93847 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -8,12 +8,14 @@ import ray import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] @@ -33,8 +35,7 @@ def graph_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tp_group().device_group @@ -60,18 +61,15 @@ def graph_allreduce( for dtype in [torch.float32, torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: # use integers so result matches NCCL exactly - inp1 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - 
inp2 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for i in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) # the input buffer is immediately modified to test @@ -96,8 +94,7 @@ def eager_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # we use the first group to communicate once # and the second group to communicate twice @@ -132,5 +129,4 @@ def test_custom_allreduce( world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_distributed_oot.py b/tests/distributed/test_distributed_oot.py index b93696e4be0e1..ea7a88abda245 100644 --- a/tests/distributed/test_distributed_oot.py +++ b/tests/distributed/test_distributed_oot.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from ..entrypoints.openai.test_oot_registration import ( - run_and_test_dummy_opt_api_server) +from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server def test_distributed_oot(dummy_opt_path: str): diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index e47ccba99c81d..79805a7cce53b 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -10,10 +10,12 @@ from vllm.distributed.eplb.rebalance_algo import rebalance_experts def test_basic_rebalance(): """Test basic rebalancing functionality""" # Example from https://github.com/deepseek-ai/eplb - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ] + ) num_layers = weight.shape[0] num_replicas = 16 @@ -21,45 +23,49 @@ def test_basic_rebalance(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify output shapes assert phy2log.shape == ( 2, 16, ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" - assert (log2phy.shape[0] == 2 - ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" - assert ( - log2phy.shape[1] == 12 - ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert log2phy.shape[0] == 2, ( + f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + ) + assert log2phy.shape[1] == 12, ( + f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + ) assert 
logcnt.shape == ( 2, 12, ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" # Verify physical to logical expert mapping range is correct - assert torch.all(phy2log >= 0) and torch.all( - phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), ( + "Physical to logical mapping should be in range [0, 12)" + ) # Verify expert count reasonableness - assert torch.all( - logcnt >= 1), "Each logical expert should have at least 1 replica" - assert ( - torch.sum(logcnt, dim=1).sum() == num_replicas * - num_layers), f"Total replicas should be {num_replicas * num_layers}" + assert torch.all(logcnt >= 1), "Each logical expert should have at least 1 replica" + assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, ( + f"Total replicas should be {num_replicas * num_layers}" + ) # Verify expected output - expected_phy2log = torch.tensor([ - [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], - [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], - ]) + expected_phy2log = torch.tensor( + [ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ] + ) assert torch.all(phy2log == expected_phy2log) - expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], - [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + expected_logcnt = torch.tensor( + [[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]] + ) assert torch.all(logcnt == expected_logcnt) @@ -71,9 +77,9 @@ def test_single_gpu_case(): num_nodes = 1 num_gpus = 1 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 4) @@ -93,19 +99,19 @@ def test_equal_weights(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 8) assert logcnt.shape == (1, 8) # With equal weights, each expert should have exactly one replica - assert torch.all( - logcnt == 1 - ), "With equal weights and no replication, " \ - "each expert should have exactly 1 replica" + assert torch.all(logcnt == 1), ( + "With equal weights and no replication, " + "each expert should have exactly 1 replica" + ) def test_extreme_weight_imbalance(): @@ -116,35 +122,37 @@ def test_extreme_weight_imbalance(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 12) assert logcnt.shape == (1, 8) # Expert with highest weight (index 0) should have more replicas - assert ( - logcnt[0, 0] - > logcnt[0, 1]), "Expert with highest weight should have more replicas" + assert logcnt[0, 0] > logcnt[0, 1], ( + "Expert with highest weight should have more replicas" + ) def test_multiple_layers(): """Test multiple layers case""" - weight = torch.tensor([ - [10, 20, 30, 40, 50, 60], # First layer - [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) - [25, 25, 25, 25, 25, 25], # Third layer (equal weights) - ]) + weight = torch.tensor( + [ + [10, 20, 30, 40, 50, 60], # First 
layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ] + ) num_replicas = 8 num_groups = 2 num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (3, 8) @@ -152,12 +160,12 @@ def test_multiple_layers(): # Verify expert allocation is reasonable for each layer for layer in range(3): - assert torch.all(phy2log[layer] >= 0) and torch.all( - phy2log[layer] < 6 - ), f"Layer {layer} physical to logical mapping" \ - "should be in range [0, 6)" - assert (torch.sum(logcnt[layer]) == num_replicas - ), f"Layer {layer} total replicas should be {num_replicas}" + assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), ( + f"Layer {layer} physical to logical mappingshould be in range [0, 6)" + ) + assert torch.sum(logcnt[layer]) == num_replicas, ( + f"Layer {layer} total replicas should be {num_replicas}" + ) def test_parameter_validation(): @@ -179,17 +187,19 @@ def test_parameter_validation(): def test_small_scale_hierarchical(): """Test small-scale hierarchical load balancing""" - weight = torch.tensor([ - [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts - ]) + weight = torch.tensor( + [ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ] + ) num_replicas = 12 num_groups = 4 # 4 groups, 2 experts each num_nodes = 2 # 2 nodes num_gpus = 4 # 4 GPUs - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify basic constraints assert phy2log.shape == (1, 12) @@ -199,8 +209,9 @@ def test_small_scale_hierarchical(): # Expert with highest weight should have more replicas max_weight_expert = torch.argmax(weight[0]) - assert (logcnt[0, max_weight_expert] - >= 2), "Highest weight expert should have multiple replicas" + assert logcnt[0, max_weight_expert] >= 2, ( + "Highest weight expert should have multiple replicas" + ) def test_global_load_balance_fallback(): @@ -213,9 +224,9 @@ def test_global_load_balance_fallback(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Should work normally, just using global load balancing strategy assert phy2log.shape == (1, 8) @@ -235,9 +246,9 @@ def test_device_compatibility(device): num_nodes = 1 num_gpus = 2 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Function will convert to CPU internally, but should handle different # device inputs normally @@ -250,7 +261,8 @@ def test_additional_cases(): # Test case 1: Large-scale distributed setup weight1 = torch.tensor( - [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] + ) phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) assert phy2log1.shape == (1, 24) @@ -258,10 +270,12 @@ def test_additional_cases(): assert torch.sum(logcnt1) == 24 # Test case 2: Different weight 
distributions - weight2 = torch.tensor([ - [200, 150, 100, 50, 25, 12], # Decreasing weights - [12, 25, 50, 100, 150, 200], # Increasing weights - ]) + weight2 = torch.tensor( + [ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ] + ) phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) assert phy2log2.shape == (2, 10) @@ -274,19 +288,21 @@ def test_additional_cases(): if __name__ == "__main__": - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ] + ) num_replicas = 16 num_groups = 4 num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) print(phy2log) test_basic_rebalance() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index de9ed1eabbac6..7ca3d3d27b562 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -9,11 +9,12 @@ import pytest import torch import torch.distributed -from vllm.distributed.eplb.rebalance_execute import ( - rearrange_expert_weights_inplace) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_tp_group, - init_distributed_environment) +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -22,13 +23,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -45,7 +46,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -60,20 +61,20 @@ def worker_fn_wrapper(fn): def create_expert_indices_with_redundancy( - num_layers: int, - num_logical_experts: int, - total_physical_experts: int, - redundancy_config: list[int], # redundancy for each logical expert + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert ) -> torch.Tensor: """ Create expert indices with redundancy. 
- + Args: num_layers: number of layers num_logical_experts: number of logical experts total_physical_experts: total number of physical experts redundancy_config: redundancy for each logical expert - + Returns: indices: Shape (num_layers, total_physical_experts) """ @@ -106,11 +107,11 @@ def create_expert_weights( ) -> list[list[torch.Tensor]]: """ Create fake expert weights tensor for testing. - + Use `arange` to generate predictable weights values, based on logical expert ID. All replicas of the same logical expert should have the same weights. - + Args: physical_to_logical_mapping: Shape (num_layers, num_local_experts) mapping[layer, physical_pos] = logical_expert_id @@ -120,27 +121,27 @@ def create_expert_weights( for layer in range(num_layers): layer_weights = [] for weight_idx, hidden_size in enumerate(hidden_sizes): - weight_tensor = torch.zeros(num_local_experts, - hidden_size, - device=device, - dtype=torch.float32) + weight_tensor = torch.zeros( + num_local_experts, hidden_size, device=device, dtype=torch.float32 + ) for local_expert in range(num_local_experts): # Get the logical expert ID for this physical expert global_pos = rank * num_local_experts + local_expert logical_expert_id = physical_to_logical_mapping[ - layer, global_pos].item() + layer, global_pos + ].item() # Generate weights based on logical expert ID # (so that all replicas of the same logical expert have the # same weights) - base_value = (logical_expert_id * 1000 + layer * 100 + - weight_idx * 10) - weight_tensor[local_expert] = torch.arange(base_value, - base_value + - hidden_size, - device=device, - dtype=torch.float32) + base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10 + weight_tensor[local_expert] = torch.arange( + base_value, + base_value + hidden_size, + device=device, + dtype=torch.float32, + ) layer_weights.append(weight_tensor) expert_weights.append(layer_weights) @@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle( # Check if the weights are correct actual_weights = weight_tensor[local_expert] - expected_base = (expected_logical_expert * 1000 + layer * 100 + - weight_idx * 10) - expected_weights = torch.arange(expected_base, - expected_base + hidden_size, - device=actual_weights.device, - dtype=actual_weights.dtype) + expected_base = ( + expected_logical_expert * 1000 + layer * 100 + weight_idx * 10 + ) + expected_weights = torch.arange( + expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype, + ) torch.testing.assert_close( actual_weights, @@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle( msg=f"Layer {layer}, weight {weight_idx}," f"local expert {local_expert}: " f"weights do not match. 
" - f"Expected logical expert {expected_logical_expert}") + f"Expected logical expert {expected_logical_expert}", + ) def verify_redundant_experts_have_same_weights( @@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights( total_physical_experts, hidden_size, device=expert_weights[layer][weight_idx].device, - dtype=expert_weights[layer][weight_idx].dtype) + dtype=expert_weights[layer][weight_idx].dtype, + ) # Use all_gather to collect expert weights from current node # expert_weights[layer][weight_idx] shape: # [num_local_experts, hidden_size] local_weights = expert_weights[layer][ - weight_idx] # [num_local_experts, hidden_size] + weight_idx + ] # [num_local_experts, hidden_size] # Split tensor along dim 0 into a list for all_gather - gathered_weights_list = torch.chunk(gathered_weights, - world_size, - dim=0) + gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0) torch.distributed.all_gather( # Output list: each element corresponds to one rank's weights list(gathered_weights_list), - local_weights # Input: current rank's local weights + local_weights, # Input: current rank's local weights ) all_weights.append(gathered_weights) @@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights( msg=f"Layer {layer}, weight {weight_idx}," f"logical expert {logical_expert_id}: " f"Physical expert {physical_pos} has different weights" - f"than expected") + f"than expected", + ) @pytest.mark.parametrize( @@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights( # 4 GPU, 8 experts per GPU # 16 logical experts, 32 physical experts, 16 redundant experts (4, 8, 8, 16), - ]) -def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, - num_local_experts, - num_logical_experts): + ], +) +def test_rearrange_expert_weights_with_redundancy( + world_size, num_layers, num_local_experts, num_logical_experts +): """Test the functionality of rearranging expert weights with redundancy.""" if torch.cuda.device_count() < world_size: @@ -304,8 +311,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -316,8 +323,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, hidden_sizes = [32, 64] # Two different weight matrices # Create old expert indices (with redundancy) - redundancy_config = create_redundancy_config(num_logical_experts, - total_physical_experts) + redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( num_layers, @@ -328,7 +336,8 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, # Create new expert indices (with redundancy) new_redundancy_config = create_redundancy_config( - num_logical_experts, total_physical_experts) + num_logical_experts, total_physical_experts + ) new_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, @@ -337,9 +346,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, ) # Create expert weights - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = 
create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Execute weight rearrangement rearrange_expert_weights_inplace( @@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size): @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -401,12 +410,12 @@ def test_rearrange_expert_weights_no_change(world_size): # Same indices - no change indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - redundancy_config) + num_layers, num_logical_experts, total_physical_experts, redundancy_config + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices + ) # Save original weights original_weights = [] @@ -422,7 +431,8 @@ def test_rearrange_expert_weights_no_change(world_size): indices, # Same indices expert_weights, ep_group, - is_profile=False) + is_profile=False, + ) # Verify that the weights have not changed for layer in range(num_layers): @@ -430,8 +440,8 @@ def test_rearrange_expert_weights_no_change(world_size): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg=f"Layer {layer}, weight {weight_idx} should remain " - f"unchanged") + msg=f"Layer {layer}, weight {weight_idx} should remain unchanged", + ) distributed_run(worker_fn, world_size) @@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size): @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -460,21 +470,23 @@ def test_rearrange_expert_weights_profile_mode(world_size): hidden_sizes = [32] # Create different index distributions - old_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) - new_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) + old_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) + new_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - old_redundancy) + num_layers, num_logical_experts, total_physical_experts, old_redundancy + ) new_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - new_redundancy) + num_layers, num_logical_experts, total_physical_experts, new_redundancy + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Save original weights original_weights = [] @@ -490,7 +502,7 @@ def test_rearrange_expert_weights_profile_mode(world_size): new_indices, expert_weights, ep_group, - is_profile=True # Profile mode + is_profile=True, # Profile mode ) # In profile mode, the 
weights should remain unchanged @@ -499,6 +511,7 @@ def test_rearrange_expert_weights_profile_mode(world_size): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg="In profile mode, the weights should remain unchanged") + msg="In profile mode, the weights should remain unchanged", + ) distributed_run(worker_fn, world_size) diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index 8be9ee0a1889d..f06f6771a4a0b 100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -6,24 +6,29 @@ import time import msgspec import pytest -from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, - NullEventPublisher) +from vllm.distributed.kv_events import ( + EventBatch, + EventPublisherFactory, + NullEventPublisher, +) DP_RANK = 0 class EventSample( - msgspec.Struct, - tag=True, # type: ignore - array_like=True # type: ignore + msgspec.Struct, + tag=True, # type: ignore + array_like=True, # type: ignore ): """Test event for publisher testing""" + id: int value: str class SampleBatch(EventBatch): """Test event batch for publisher testing""" + events: list[EventSample] @@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber): seq, received = result assert seq == 0, "Sequence number mismatch" - assert received.ts == pytest.approx(test_batch.ts, - abs=0.1), ("Timestamp mismatch") - assert len(received.events) == len( - test_batch.events), ("Number of events mismatch") + assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch" + assert len(received.events) == len(test_batch.events), "Number of events mismatch" for i, event in enumerate(received.events): assert event.id == i, "Event id mismatch" @@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber): assert len(replayed) > 0, "No replayed messages received" seqs = [seq for seq, _ in replayed] assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" - assert seqs == list(range(min(seqs), - max(seqs) + - 1)), ("Replayed messages not consecutive") + assert seqs == list(range(min(seqs), max(seqs) + 1)), ( + "Replayed messages not consecutive" + ) def test_buffer_limit(publisher, subscriber, publisher_config): @@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config): pub = EventPublisherFactory.create(publisher_config, DP_RANK) from .conftest import MockSubscriber + sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") @@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config): foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] assert all(msg is not None for msg in foo_received), ( - "Subscriber with matching topic should receive messages") + "Subscriber with matching topic should receive messages" + ) bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] assert all(msg is None for msg in bar_received), ( - "Subscriber with non-matching topic should receive no messages") + "Subscriber with non-matching topic should receive no messages" + ) finally: pub.shutdown() sub_foo.close() @@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber): publisher_thread.join() - assert len(received) >= num_batches * 0.9, ( - "We should have received most messages") + assert len(received) >= num_batches * 0.9, "We should have received most messages" seqs = [seq for seq, _ in received] assert sorted(seqs) == seqs, "Sequence numbers should be in order" @@ -209,13 +214,15 @@ 
def test_data_parallel_rank_tagging(publisher_config): # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558 expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port expected_endpoint_1 = base_endpoint.replace( - ":5557", ":5558") # rank 1 gets port + 1 + ":5557", ":5558" + ) # rank 1 gets port + 1 else: # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1 expected_endpoint_0 = base_endpoint # rank 0 gets base expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1 from .conftest import MockSubscriber + sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic) sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic) @@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config): # Verify DP rank tagging assert received_0.data_parallel_rank == 0, ( - f"Expected DP rank 0, got {received_0.data_parallel_rank}") + f"Expected DP rank 0, got {received_0.data_parallel_rank}" + ) assert received_1.data_parallel_rank == 1, ( - f"Expected DP rank 1, got {received_1.data_parallel_rank}") + f"Expected DP rank 1, got {received_1.data_parallel_rank}" + ) # Verify event content is correct - assert len( - received_0.events) == 2, "Wrong number of events from rank 0" - assert len( - received_1.events) == 3, "Wrong number of events from rank 1" + assert len(received_0.events) == 2, "Wrong number of events from rank 0" + assert len(received_1.events) == 3, "Wrong number of events from rank 1" finally: pub_0.shutdown() diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index f273f302e72e8..8a9ddcd58cfce 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -6,7 +6,7 @@ from typing import Literal, NamedTuple, Optional import pytest -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.logger import init_logger from ..utils import compare_two_settings, create_new_process_for_each_test @@ -46,28 +46,24 @@ class EPTestSettings: ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True), + ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True + ), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False + ), ], distributed_backends=["mp", "ray"], runner=runner, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) @staticmethod @@ -82,16 +78,16 @@ class EPTestSettings: ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=True, 
chunked_prefill=False), ], distributed_backends=["mp"], runner=runner, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) def iter_params(self, model_name: str): @@ -99,17 +95,20 @@ class EPTestSettings: for parallel_setup in self.parallel_setups: for distributed_backend in self.distributed_backends: - yield (model_name, parallel_setup, distributed_backend, - self.runner, opts) + yield ( + model_name, + parallel_setup, + distributed_backend, + self.runner, + opts, + ) # NOTE: You can adjust tp_base locally to fit the model in GPU # The values displayed here are only a rough indicator of the size of the model -# yapf: disable TEST_MODELS = { - "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast( - trust_remote_code=True), + "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(trust_remote_code=True), "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4), } @@ -191,22 +190,24 @@ def _compare_tp( ] try: - compare_two_settings(model_name, - ep_args, - tp_args, - ep_env, - tp_env, - method=method, - max_wait_seconds=360) + compare_two_settings( + model_name, + ep_args, + tp_args, + ep_env, + tp_env, + method=method, + max_wait_seconds=360, + ) except Exception: raise @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", "runner", - "test_options"), + ("model_name", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_name, settings in TEST_MODELS.items() + params + for model_name, settings in TEST_MODELS.items() for params in settings.iter_params(model_name) ], ) @@ -219,10 +220,12 @@ def test_ep( test_options: EPTestOptions, num_gpus_available, ): - _compare_tp(model_name, - parallel_setup, - distributed_backend, - runner, - test_options, - num_gpus_available, - method="generate") + _compare_tp( + model_name, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + ) diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py new file mode 100644 index 0000000000000..cb9c8f5074049 --- /dev/null +++ b/tests/distributed/test_expert_placement.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map + + +def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts): + """Verify that the expert map follows the round_robin pattern.""" + # Calculate expected local experts (supporting non-divisible cases) + base_experts = global_num_experts // ep_size + remainder = global_num_experts % ep_size + + local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts + + # Expected expert IDs for this rank in round_robin pattern + # For non-divisible cases, ranks with extra experts start earlier + expected_expert_ids = [] + for expert_idx in range(local_num_experts): + global_expert_id = ep_rank + expert_idx * ep_size + expected_expert_ids.append(global_expert_id) + + # Check that only expected experts are mapped to this rank + for global_expert_id in range(global_num_experts): + if global_expert_id in expected_expert_ids: + local_expert_id = 
expert_map[global_expert_id] + expected_local_id = expected_expert_ids.index(global_expert_id) + assert local_expert_id == expected_local_id, ( + f"Global expert {global_expert_id} should map to local expert " + f"{expected_local_id}, got {local_expert_id}" + ) + else: + assert expert_map[global_expert_id] == -1, ( + f"Global expert {global_expert_id} should not be mapped to this rank" + ) + + # Verify that all local expert IDs are consecutive starting from 0 + local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids] + expected_local_ids = list(range(local_num_experts)) + assert local_expert_ids == expected_local_ids, ( + f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}" + ) + + +@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) +@pytest.mark.parametrize("world_size", [2, 4]) +def test_expert_placement_various_sizes(expert_placement_strategy, world_size): + """Test round_robin expert placement with various expert counts.""" + + # Test with different global_num_experts values + # Include both divisible and non-divisible cases + if world_size == 2: + test_cases = [ + (4, 2), # 4 experts (divisible) + (8, 2), # 8 experts (divisible) + (9, 2), # 9 experts (non-divisible) + (16, 2), # 16 experts (divisible) + (17, 2), # 17 experts (non-divisible) + ] + elif world_size == 4: + test_cases = [ + (8, 4), # 8 experts (divisible) + (16, 4), # 16 experts (divisible) + (18, 4), # 18 experts (non-divisible) + (32, 4), # 32 experts (divisible) + (33, 4), # 33 experts (non-divisible) + ] + else: + test_cases = [] + + for test_global_experts, test_ep_size in test_cases: + # Ensure ep_size matches world_size + assert test_ep_size == world_size, ( + f"ep_size {test_ep_size} must equal world_size {world_size}" + ) + + # Test each rank + for ep_rank in range(world_size): + # Calculate expected local experts + base_experts = test_global_experts // test_ep_size + remainder = test_global_experts % test_ep_size + if ep_rank < remainder: + expected_test_local = base_experts + 1 + else: + expected_test_local = base_experts + + test_local_experts, test_expert_map = determine_expert_map( + ep_size=test_ep_size, + ep_rank=ep_rank, + global_num_experts=test_global_experts, + expert_placement_strategy=expert_placement_strategy, + ) + + assert test_local_experts == expected_test_local, ( + f"For {test_global_experts} experts on {test_ep_size} ranks, " + f"rank {ep_rank}: expected {expected_test_local} local" + f"experts, got {test_local_experts}" + ) + + if test_expert_map is not None: + assert test_expert_map.shape == (test_global_experts,), ( + f"Expected expert map shape ({test_global_experts},), " + f"got {test_expert_map.shape}" + ) + + # Verify round_robin pattern for this test case + verify_round_robin_pattern( + test_expert_map, ep_rank, test_ep_size, test_global_experts + ) + + +@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) +@pytest.mark.parametrize("world_size", [2, 4]) +def test_expert_placement_edge_cases(expert_placement_strategy, world_size): + """Test edge cases for round_robin expert placement.""" + + # Test case 1: ep_size = 1 (should return None for expert_map) + local_num_experts, expert_map = determine_expert_map( + ep_size=1, + ep_rank=0, + global_num_experts=8, + expert_placement_strategy=expert_placement_strategy, + ) + assert local_num_experts == 8, "For ep_size=1, should get all experts" + assert expert_map is None, "For ep_size=1, expert_map should be None" + + # Test case 2: ep_size = 0 (should raise 
assertion) + with pytest.raises(AssertionError): + determine_expert_map( + ep_size=0, + ep_rank=0, + global_num_experts=8, + expert_placement_strategy=expert_placement_strategy, + ) + + +def test_determine_expert_map_comprehensive(): + """Test of determine_expert_map function with various configurations.""" + + # Test cases: (ep_size, ep_rank, global_num_experts, + # expert_placement_strategy, expected_local, expected_map_pattern) + test_cases = [ + # Round robin placement tests + ( + 2, + 0, + 8, + "round_robin", + 4, + [0, -1, 1, -1, 2, -1, 3, -1], + ), # rank 0 gets even experts + ( + 2, + 1, + 8, + "round_robin", + 4, + [-1, 0, -1, 1, -1, 2, -1, 3], + ), # rank 1 gets odd experts + ( + 2, + 0, + 9, + "round_robin", + 5, + [0, -1, 1, -1, 2, -1, 3, -1, 4], + ), # rank 0 gets 5 experts (even + last) + ( + 2, + 1, + 9, + "round_robin", + 4, + [-1, 0, -1, 1, -1, 2, -1, 3, -1], + ), # rank 1 gets 4 experts (odd) + # 4-rank tests + ( + 4, + 0, + 8, + "round_robin", + 2, + [0, -1, -1, -1, 1, -1, -1, -1], + ), # rank 0 gets experts 0, 4 + ( + 4, + 1, + 8, + "round_robin", + 2, + [-1, 0, -1, -1, -1, 1, -1, -1], + ), # rank 1 gets experts 1, 5 + ( + 4, + 2, + 8, + "round_robin", + 2, + [-1, -1, 0, -1, -1, -1, 1, -1], + ), # rank 2 gets experts 2, 6 + ( + 4, + 3, + 8, + "round_robin", + 2, + [-1, -1, -1, 0, -1, -1, -1, 1], + ), # rank 3 gets experts 3, 7 + ] + + for ( + ep_size, + ep_rank, + global_num_experts, + expert_placement_strategy, + expected_local, + expected_map_pattern, + ) in test_cases: + local_num_experts, expert_map = determine_expert_map( + ep_size=ep_size, + ep_rank=ep_rank, + global_num_experts=global_num_experts, + expert_placement_strategy=expert_placement_strategy, + ) + + assert local_num_experts == expected_local, ( + f"ep_size={ep_size}, ep_rank={ep_rank}, " + f"global_num_experts={global_num_experts}, " + f"expert_placement_strategy={expert_placement_strategy}: " + f"expected {expected_local} local experts, got {local_num_experts}" + ) + + if expected_map_pattern is None: + assert expert_map is None, "Expected expert_map to be None" + else: + assert expert_map is not None, "Expected expert_map to not be None" + actual_map = expert_map.tolist() + assert actual_map == expected_map_pattern, ( + f"ep_size={ep_size}, ep_rank={ep_rank}, " + f"global_num_experts={global_num_experts}, " + f"expert_placement_strategy={expert_placement_strategy}: " + f"expected map {expected_map_pattern}, got {actual_map}" + ) diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py index d447876f6cc7c..b190b2820451b 100644 --- a/tests/distributed/test_kvlayout.py +++ b/tests/distributed/test_kvlayout.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig, - VllmConfig, set_current_vllm_config) +from vllm.config import ( + DeviceConfig, + KVTransferConfig, + ModelConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger logger = init_logger("test_expert_parallel") @@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector(): kv_connector="LMCacheConnectorV1", kv_role="kv_both", ) - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), 
kv_transfer_config=kv_transfer_config + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() @@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector(): kv_role="kv_both", ) model_config = ModelConfig() - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - model_config=model_config, - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), + model_config=model_config, + kv_transfer_config=kv_transfer_config, + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() @@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector(): def test_get_kv_connector_cache_layout_with_multi_connector(): - kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [{ - "kv_connector": - "SharedStorageConnector", - "kv_role": - "kv_both" - }, { - "kv_connector": - "NixlConnector", - "kv_role": - "kv_both" - }] - }) + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"}, + {"kv_connector": "NixlConnector", "kv_role": "kv_both"}, + ] + }, + ) model_config = ModelConfig() - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - model_config=model_config, - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), + model_config=model_config, + kv_transfer_config=kv_transfer_config, + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() diff --git a/tests/distributed/test_multi_node_assignment.py b/tests/distributed/test_multi_node_assignment.py index ef17a51fff0e1..8d818edbb3bd7 100644 --- a/tests/distributed/test_multi_node_assignment.py +++ b/tests/distributed/test_multi_node_assignment.py @@ -24,14 +24,13 @@ from vllm.utils import get_ip VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.skipif(not VLLM_MULTI_NODE, - reason="Need at least 2 nodes to run the test.") +@pytest.mark.skipif( + not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test." +) def test_multi_node_assignment() -> None: - # NOTE: important to keep this class definition here # to let ray use cloudpickle to serialize it. 
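+    # Rough sketch of the pattern exercised below (hypothetical usage; the
+    # real call sites follow): each GPU bundle in the placement group gets
+    # one Actor pinned to it, e.g.
+    #   strategy = PlacementGroupSchedulingStrategy(
+    #       placement_group=config.placement_group,
+    #       placement_group_bundle_index=bundle_id)
+    #   worker = ray.remote(num_gpus=1,
+    #                       scheduling_strategy=strategy)(Actor).remote()
+    # get_ip() is then used to tell which node each worker landed on.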
class Actor: - def get_ip(self): return get_ip() @@ -41,8 +40,7 @@ def test_multi_node_assignment() -> None: current_ip = get_ip() workers = [] - for bundle_id, bundle in enumerate( - config.placement_group.bundle_specs): + for bundle_id, bundle in enumerate(config.placement_group.bundle_specs): if not bundle.get("GPU", 0): continue scheduling_strategy = PlacementGroupSchedulingStrategy( diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py new file mode 100644 index 0000000000000..40dcf7567c92f --- /dev/null +++ b/tests/distributed/test_nccl_symm_mem_allreduce.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops +from vllm.distributed.device_communicators.pynccl_allocator import ( + get_nccl_mem_pool, + is_symmetric_memory_enabled, +) +from vllm.distributed.parallel_state import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast( + CudaCommunicator, get_tp_group().device_communicator + ) + pynccl_comm = cuda_communicator.pynccl_comm + if get_nccl_mem_pool() is None: + pytest.skip( + "NCCL allocator compilation failed (probably missing NCCL headers)." 
+ ) + if not is_symmetric_memory_enabled(): + pytest.skip("NCCL symmetric memory allreduce is disabled.") + + register_nccl_symmetric_ops(pynccl_comm) + input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device) + input_clone = input.clone() + output = torch.ops.vllm.all_reduce_symmetric_with_copy(input) + assert output is not None + + group = get_tp_group().device_group + dist.all_reduce(input_clone, group=group) + torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="NCCLSymmMemAllreduce is only available for CUDA platforms.", +) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size): + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1") + monkeypatch.setenv("NCCL_NVLS_ENABLE", "1") + monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1") + + mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py index e3c36ef5ef379..b48c025aa1a23 100644 --- a/tests/distributed/test_node_count.py +++ b/tests/distributed/test_node_count.py @@ -32,12 +32,15 @@ if __name__ == "__main__": # Expected node count based on environment variable) expected = int(os.environ.get("NUM_NODES", "1")) - assert test_result == expected, \ - f"Expected {expected} nodes, got {test_result}" + assert test_result == expected, f"Expected {expected} nodes, got {test_result}" if pg == dist.group.WORLD: - print(f"Node count test passed! Got {test_result} nodes " - f"when using torch distributed!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using torch distributed!" + ) else: - print(f"Node count test passed! Got {test_result} nodes " - f"when using StatelessProcessGroup!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using StatelessProcessGroup!" + ) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 28150d7682378..43f0c9dd1a85a 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node all workers in a node other than the head node, which can cause the test to fail. """ + import json import os from dataclasses import dataclass @@ -14,7 +15,7 @@ from typing import Literal, NamedTuple, Optional import pytest -from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption +from vllm.config.model import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption from vllm.logger import init_logger from vllm.transformers_utils.config import get_config @@ -26,23 +27,10 @@ logger = init_logger("test_pipeline_parallel") VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - For PP, we fall back to V0 by default. This means - that the TP baseline runs with V1 while the PP engine - runs with V0. This gives divergent results with dummy - weights. Once we enable V1 by default for PP, we can - remove this. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - class ParallelSetup(NamedTuple): tp_size: int pp_size: int eager_mode: bool - chunked_prefill: bool class PPTestOptions(NamedTuple): @@ -53,23 +41,10 @@ class PPTestOptions(NamedTuple): @dataclass class PPTestSettings: parallel_setups: list[ParallelSetup] - # NOTE: the length of distributed_backends and - # vllm_major_versions should be the same, and they - # are first zipped together to iterate over all - # test settings. distributed_backends: list[str] - # vllm major version: "0" for V0, "1" for V1 - vllm_major_versions: list[str] runner: RunnerOption test_options: PPTestOptions - def __post_init__(self): - if len(self.distributed_backends) != len(self.vllm_major_versions): - raise ValueError( - f"Length mismatch: distributed_backends " - f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") - @staticmethod def detailed( *, @@ -81,32 +56,17 @@ class PPTestSettings: ): return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False), + ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False), + ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True), + ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False), + ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True), ], - distributed_backends=["mp", "mp", "ray", "ray"], - vllm_major_versions=["0", "1", "0", "1"], + distributed_backends=["mp", "ray"], runner=runner, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -120,37 +80,31 @@ class PPTestSettings: ): return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True), ], distributed_backends=["mp"], - vllm_major_versions=["0"], runner=runner, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.runner, opts) + for backend in self.distributed_backends: + yield (model_id, parallel_setup, backend, self.runner, opts) # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU # The values displayed here are only a rough indicator of the size of the model -# yapf: disable TEXT_GENERATION_MODELS = { # [Decoder-only] # Uses Llama # "BAAI/AquilaChat-7B": PPTestSettings.fast(), - "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + 
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), "baichuan-inc/Baichuan-7B": PPTestSettings.fast(), "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(), @@ -184,7 +138,7 @@ TEXT_GENERATION_MODELS = { # Uses Llama # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), "state-spaces/mamba-130m-hf": PPTestSettings.fast(), - "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), @@ -195,13 +149,15 @@ TEXT_GENERATION_MODELS = { "adept/persimmon-8b-chat": PPTestSettings.fast(), "microsoft/phi-2": PPTestSettings.fast(), "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(), - "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501 + "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed( + multi_node_only=True, load_format="dummy" + ), "Qwen/Qwen-7B-Chat": PPTestSettings.fast(), "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(), "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), - "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # FIXME: Cannot load tokenizer in latest transformers version. # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(), @@ -240,11 +196,7 @@ MULTIMODAL_MODELS = { "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(), - # [Encoder-decoder] - # TODO: Implement PP - # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), } -# yapf: enable # NOTE: You can update this on your local machine to run specific tests TEST_MODELS = [ @@ -270,7 +222,6 @@ def _compare_tp( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available: int, @@ -282,7 +233,6 @@ def _compare_tp( tp_size, pp_size, eager_mode, - chunked_prefill, ) = parallel_setup multi_node_only, load_format = test_options @@ -294,6 +244,8 @@ def _compare_tp( tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides hf_config = get_config(model_id, trust_remote_code) + skip_tokenizer_init = model_info.skip_tokenizer_init + max_num_seqs = model_info.max_num_seqs dtype = "float16" if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS: @@ -319,8 +271,10 @@ def _compare_tp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -333,8 +287,6 @@ def _compare_tp( "--max-num-seqs", "8", ] - if chunked_prefill: - common_args.append("--enable-chunked-prefill") if eager_mode: 
common_args.append("--enforce-eager") if runner != "auto": @@ -347,15 +299,14 @@ def _compare_tp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") + if max_num_seqs: + common_args.extend(["--max-num-seqs", f"{max_num_seqs}"]) - specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill - testing_ray_compiled_graph = False - if distributed_backend == "ray" and (vllm_major_version == "1" - or specific_case): + if distributed_backend == "ray": # For V1, test Ray Compiled Graph for all the tests - # For V0, test Ray Compiled Graph for a subset of the tests pp_env = { - "VLLM_USE_V1": vllm_major_version, "VLLM_USE_RAY_COMPILED_DAG": "1", "VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", @@ -363,18 +314,12 @@ def _compare_tp( # Temporary. Currently when zeromq + SPMD is used, it does not properly # terminate because of a Ray Compiled Graph issue. common_args.append("--disable-frontend-multiprocessing") - testing_ray_compiled_graph = True elif distributed_backend == "mp": - # Both V0/V1 of multiprocessing executor support PP - pp_env = { - "VLLM_USE_V1": vllm_major_version, - } + pp_env = None else: pp_env = None - tp_env = { - "VLLM_USE_V1": vllm_major_version, - } + tp_env = None pp_args = [ *common_args, @@ -399,28 +344,16 @@ def _compare_tp( "mp", ] - try: - compare_two_settings(model_id, - pp_args, - tp_args, - pp_env, - tp_env, - method=method) - except Exception: - if testing_ray_compiled_graph and vllm_major_version == "0": - # Ray Compiled Graph tests are flaky for V0, - # so we don't want to fail the test - logger.exception("Ray Compiled Graph tests failed") - else: - raise + compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in TEXT_GENERATION_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -428,28 +361,29 @@ def test_tp_language_generation( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in EMBEDDING_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in EMBEDDING_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -457,28 +391,29 @@ def 
test_tp_language_embedding( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="encode", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="encode", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in MULTIMODAL_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in MULTIMODAL_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -486,17 +421,17 @@ def test_tp_multimodal_generation( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=True) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=True, + ) diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py index 69ceedd345a89..4df6f43970d70 100644 --- a/tests/distributed/test_pipeline_partition.py +++ b/tests/distributed/test_pipeline_partition.py @@ -9,7 +9,6 @@ from vllm.distributed.utils import get_pp_indices def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: def _verify(partition_str, num_layers, pp_size, goldens): @@ -57,7 +56,8 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch): (5, 3, 0, (0, 2)), (5, 3, 1, (2, 4)), (5, 3, 2, (4, 5)), - ]) + ], +) def test_uneven_auto_partition( num_hidden_layers: int, pp_size: int, diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 5ca65a0e8d2c9..2c9f474640088 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -12,12 +12,18 @@ if TYPE_CHECKING: from typing_extensions import LiteralString -@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ - (2, "JackFram/llama-160m"), -]) -@pytest.mark.parametrize("ATTN_BACKEND", [ - "FLASH_ATTN", -]) +@pytest.mark.parametrize( + "PP_SIZE, MODEL_NAME", + [ + (2, "JackFram/llama-160m"), + ], +) +@pytest.mark.parametrize( + "ATTN_BACKEND", + [ + "FLASH_ATTN", + ], +) @create_new_process_for_each_test() def test_pp_cudagraph( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index abfad9ebfe7d0..4bab709fb5892 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,13 +9,15 @@ import pytest import torch import torch.distributed -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from 
vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_world_group, graph_capture, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_world_group, + graph_capture, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -24,13 +26,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -47,7 +49,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -58,17 +60,18 @@ def worker_fn_wrapper(fn): @worker_fn_wrapper def worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) tensor = pynccl_comm.all_reduce(tensor) torch.cuda.synchronize() assert torch.all(tensor == pynccl_comm.world_size).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl(): distributed_run(worker_fn, 2) @@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 1], backend="gloo"), - torch.distributed.new_group(ranks=[2, 3], backend="gloo") + torch.distributed.new_group(ranks=[2, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) @@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce(): # this tests pynccl for multiple tp groups, in a standalone way # i.e. 
call `pynccl_comm.all_reduce` directly @@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce_with_vllm(): # this tests pynccl for multiple tp groups, together with vllm # i.e. call `tensor_model_parallel_all_reduce` @@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') + a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}") torch.cuda.synchronize() with torch.cuda.graph(graph): a_out = pynccl_comm.all_reduce(a) @@ -148,84 +154,90 @@ def worker_fn_with_cudagraph(): @worker_fn_wrapper def all_gather_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - result = torch.zeros(num_elems * world_size, - dtype=torch.float32, - device=device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device) - expected = torch.cat([ - torch.arange(num_elems, dtype=torch.float32) + r * num_elems - for r in range(world_size) - ]).to(device) + expected = torch.cat( + [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gather(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_all_gather(): distributed_run(all_gather_worker_fn, 2) @worker_fn_wrapper def all_gatherv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sizes[rank] - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) - expected = torch.cat([ - torch.arange(sizes[r], dtype=torch.float32) + r * 100 - for r in range(world_size) - ]).to(device) + expected = torch.cat( + [ + torch.arange(sizes[r], dtype=torch.float32) + r * 100 + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gatherv(result, tensor, sizes=sizes) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_all_gatherv(): distributed_run(all_gatherv_worker_fn, 2) @worker_fn_wrapper def reduce_scatter_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - assert (num_elems % world_size == 0) - result = torch.zeros(num_elems // world_size, - dtype=torch.float32, - device=device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + assert num_elems % world_size == 0 + result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk scattered_size = num_elems // world_size @@ -233,34 +245,37 @@ def reduce_scatter_worker_fn(): torch.arange(num_elems, dtype=torch.float32) + r * num_elems for r in range(world_size) ] - expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] - for tensor in all_tensors).to(device) + expected = sum( + tensor[rank * scattered_size : (rank + 1) * scattered_size] + for tensor in all_tensors + ).to(device) pynccl_comm.reduce_scatter(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_reduce_scatter(): distributed_run(reduce_scatter_worker_fn, 2) @worker_fn_wrapper def reduce_scatterv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sum(sizes) - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk @@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn(): torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_reduce_scatterv(): distributed_run(reduce_scatterv_worker_fn, 2) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) @worker_fn_wrapper def send_recv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) if pynccl_comm.rank == 0: - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) else: - tensor = torch.empty(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) if pynccl_comm.rank == 0: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() assert torch.all(tensor == 1).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_send_recv(): distributed_run(send_recv_worker_fn, 2) @@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 2], backend="gloo"), - torch.distributed.new_group(ranks=[1, 3], backend="gloo") + torch.distributed.new_group(ranks=[1, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) if torch.distributed.get_rank() == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) elif torch.distributed.get_rank() == 1: - tensor = 2 * torch.ones( - 16, 1024, 1024, dtype=torch.float32, device=device) + tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) else: - tensor = torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=device) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device) if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() if torch.distributed.get_rank() in [0, 2]: assert torch.all(tensor == 1).cpu().item() @@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_broadcast(): distributed_run(broadcast_worker_fn, 4) @@ -366,19 +376,17 @@ def test_pynccl_broadcast(): def broadcast_worker_fn(): # Test broadcast for every root rank. # Essentially this is an all-gather operation. 
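# A minimal sketch of the equivalence the comment above relies on: broadcasting
# each rank's buffer from every root in turn leaves every rank holding every
# rank's data, which is exactly the result of an all-gather. The helper name
# and the use of plain torch.distributed (rather than PyNcclCommunicator) are
# illustrative assumptions; it presumes an already-initialized process group.
def broadcast_as_all_gather(local: torch.Tensor, rank: int, world_size: int) -> list[torch.Tensor]:
    # One receive buffer per rank; our own slot starts out holding our data.
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    gathered[rank] = local.clone()
    # Broadcast every slot from its owning rank, filling in all remote slots.
    for src in range(world_size):
        torch.distributed.broadcast(gathered[src], src=src)
    # gathered now matches what torch.distributed.all_gather would produce.
    return gathered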
- pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) recv_tensors = [ - torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=pynccl_comm.device) + torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) for i in range(pynccl_comm.world_size) ] - recv_tensors[pynccl_comm.rank] = torch.ones( - 16, 1024, 1024, dtype=torch.float32, - device=pynccl_comm.device) * pynccl_comm.rank + recv_tensors[pynccl_comm.rank] = ( + torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) + * pynccl_comm.rank + ) for i in range(pynccl_comm.world_size): pynccl_comm.broadcast(recv_tensors[i], src=i) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 6245ccbeca877..2df88377345dd 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -8,20 +8,20 @@ import ray import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.platforms import current_platform -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) torch.manual_seed(42) random.seed(44) # Size over 8MB is sufficient for custom quick allreduce. -test_sizes = [ - random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8) -] +test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)] for i, v in enumerate(test_sizes): test_sizes[i] -= v % 8 @@ -38,8 +38,7 @@ def graph_quickreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tp_group().device_group @@ -64,18 +63,15 @@ def graph_quickreduce( for sz in test_sizes: for dtype in [torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: - inp1 = torch.randint(1, - 23, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(-23, - 1, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for _ in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) dist.all_reduce(inp1, group=group) @@ -99,39 +95,42 @@ def eager_quickreduce( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # Size over 8MB is sufficient for custom 
quick allreduce. sz = 16 * 1024 * 1024 fa = get_tp_group().device_communicator.qr_comm - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.float16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.bfloat16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test quick allreduce for rocm") +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="only test quick allreduce for rocm" +) @pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) -def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size, test_target, - quant_mode): +def test_custom_quick_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pipeline_parallel_size, + test_target, + quant_mode, +): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 94ad8f4f1213a..baf75fd48c636 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -22,15 +22,13 @@ if __name__ == "__main__": dist.broadcast_object_list(recv, src=0) ip, port = recv - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: test_result = all(in_the_same_node_as(pg, source_rank=0)) expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" - assert test_result == expected, \ - f"Expected {expected}, got {test_result}" + assert test_result == expected, f"Expected {expected}, got {test_result}" if pg == dist.group.WORLD: print("Same node test passed! when using torch distributed!") else: diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 49b8eddecb4a9..0847687cf2f9a 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -7,6 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node all workers in a node other than the head node, which can cause the test to fail. 
""" + import json import os from dataclasses import dataclass @@ -14,7 +15,7 @@ from typing import Literal, NamedTuple, Optional import pytest -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.logger import init_logger from ..models.registry import HF_EXAMPLE_MODELS @@ -41,23 +42,10 @@ class SPTestOptions(NamedTuple): @dataclass class SPTestSettings: parallel_setups: list[ParallelSetup] - # NOTE: the length of distributed_backends and - # vllm_major_versions should be the same, and they - # are first zipped together to iterate over all - # test settings. distributed_backends: list[str] - # vllm major version: "0" for V0, "1" for V1 - vllm_major_versions: list[str] runner: RunnerOption test_options: SPTestOptions - def __post_init__(self): - if len(self.distributed_backends) != len(self.vllm_major_versions): - raise ValueError( - f"Length mismatch: distributed_backends " - f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") - @staticmethod def detailed( *, @@ -72,18 +60,21 @@ class SPTestSettings: for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], - vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -100,18 +91,21 @@ class SPTestSettings: for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], - vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -126,35 +120,41 @@ class SPTestSettings: parallel_setups = [] for fusion_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - enable_fusion=fusion_val, - eager_mode=True, - chunked_prefill=False)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_base, + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], - vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in 
zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.runner, opts) + for backend in self.distributed_backends: + yield ( + model_id, + parallel_setup, + backend, + self.runner, + opts, + ) def _compare_sp( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: SPTestOptions, num_gpus_available: int, @@ -178,6 +178,7 @@ def _compare_sp( trust_remote_code = model_info.trust_remote_code tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides + skip_tokenizer_init = model_info.skip_tokenizer_init if load_format == "dummy": # Avoid OOM @@ -199,8 +200,10 @@ def _compare_sp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -227,36 +230,32 @@ def _compare_sp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") compilation_config = { - 'level': 3, - 'custom_ops': ["+rms_norm"], - 'compile_sizes': [4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_sequence_parallelism': True, - 'enable_fusion': enable_fusion, - 'enable_noop': True, + "level": 3, + "custom_ops": ["+rms_norm"], + "compile_sizes": [4, 8], + "pass_config": { + "enable_sequence_parallelism": True, + "enable_fusion": enable_fusion, + "enable_noop": True, }, } - tp_sp_env = tp_env = { - "VLLM_USE_V1": vllm_major_version, - } - tp_sp_args = [ *common_args, "--tensor-parallel-size", str(tp_size), + "--pipeline-parallel-size", + str(pp_size), "--distributed-executor-backend", distributed_backend, "--compilation_config", json.dumps(compilation_config), ] - tp_env = { - "VLLM_USE_V1": vllm_major_version, - } tp_args = [ *common_args, "--tensor-parallel-size", @@ -265,21 +264,7 @@ def _compare_sp( "mp", ] - try: - compare_two_settings(model_id, - tp_sp_args, - tp_args, - tp_sp_env, - tp_env, - method=method) - except Exception: - testing_ray_compiled_graph = tp_sp_env is not None - if testing_ray_compiled_graph and vllm_major_version == "0": - # Ray Compiled Graph tests are flaky for V0, - # so we don't want to fail the test - logger.exception("Ray Compiled Graph tests failed") - else: - raise + compare_two_settings(model_id, tp_sp_args, tp_args, method=method) SP_TEXT_GENERATION_MODELS = { @@ -292,15 +277,21 @@ SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "runner", + "test_options", + ), [ - params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + params + for model_id, settings in SP_TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_id) if model_id in SP_TEST_MODELS ], @@ -310,17 +301,17 @@ def test_tp_sp_generation( model_id: 
str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: SPTestOptions, num_gpus_available, ): - _compare_sp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_sp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index e1357b4a34e99..cdea1bfe8f281 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -26,13 +26,13 @@ def distributed_run(fn, world_size): processes = [] for i in range(number_of_processes): env = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -57,25 +57,23 @@ def worker_fn_wrapper(fn): @worker_fn_wrapper def worker_fn(): - rank = dist.get_rank() if rank == 0: port = get_open_port() - ip = '127.0.0.1' + ip = "127.0.0.1" dist.broadcast_object_list([ip, port], src=0) else: recv = [None, None] dist.broadcast_object_list(recv, src=0) ip, port = recv # type: ignore - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: - writer_rank = 2 broadcaster = MessageQueue.create_from_process_group( - pg, 40 * 1024, 2, writer_rank) + pg, 40 * 1024, 2, writer_rank + ) if rank == writer_rank: seed = random.randint(0, 1000) dist.broadcast_object_list([seed], writer_rank) diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py new file mode 100644 index 0000000000000..c6ceab181ff55 --- /dev/null +++ b/tests/distributed/test_shm_buffer.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import traceback +import unittest + +from vllm.distributed.device_communicators.shm_object_storage import ( + SingleWriterShmRingBuffer, +) + + +class TestSingleWriterShmRingBuffer(unittest.TestCase): + """Test suite for the ring buffer implementation""" + + def setUp(self): + """Set up test fixtures""" + self.buffer_size = 4096 + self.ring_buffer = None + + def tearDown(self): + """Clean up after tests""" + if self.ring_buffer: + del self.ring_buffer + + def test_buffer_opening(self): + """Test opening an existing buffer""" + # First create a buffer + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True + ) + + # Then open it with another instance + reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle()) + self.assertFalse(reader_buffer.is_writer) + self.assertEqual( + reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name + ) + + def test_buffer_access(self): + """Test accessing allocated buffers""" + self.ring_buffer = 
SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True + ) + + size = 100 + address, monotonic_id = self.ring_buffer.allocate_buf(size) + + # Write some test data + test_data = b"Hello, World!" * 7 # 91 bytes + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0 : len(test_data)] = test_data + + # Read it back + with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): + read_data = bytes(data_buf2[0 : len(test_data)]) + read_id = metadata2[0] + + self.assertEqual(read_data, test_data) + self.assertEqual(read_id, monotonic_id) + + def test_memory_error_on_full_buffer(self): + """Test that MemoryError is raised when buffer is full""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True + ) + + # Fill up the buffer + self.ring_buffer.allocate_buf(100) + self.ring_buffer.allocate_buf(80) # Total: 196 bytes used + + # This should fail + with self.assertRaises(MemoryError): + self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity + + def test_allocation_and_free(self): + """Test allocation and freeing of buffers""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True + ) + + size = 80 + # Write some data + test_data = b"Repeated test data" + for i in range(5): + address, monotonic_id = self.ring_buffer.allocate_buf(size) + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use + data_buf[4 : len(test_data) + 4] = test_data + print(self.ring_buffer.metadata) + freed_ids = self.ring_buffer.free_buf(lambda *args: True) + print(f" Freed IDs: {freed_ids}") + self.assertEqual(freed_ids[0], i) + + def test_clear_buffer(self): + """Test clearing the buffer""" + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True + ) + + # Allocate some buffers + for _ in range(3): + self.ring_buffer.allocate_buf(100) + + # Clear the buffer + self.ring_buffer.clear() + + # Check that metadata is empty and IDs reset + self.assertEqual(len(self.ring_buffer.metadata), 0) + self.assertEqual(self.ring_buffer.monotonic_id_start, 0) + self.assertEqual(self.ring_buffer.monotonic_id_end, 0) + self.assertEqual(self.ring_buffer.data_buffer_start, 0) + self.assertEqual(self.ring_buffer.data_buffer_end, 0) + + +def main(): + """Main function demonstrating usage and running tests""" + print("=== SingleWriterShmRingBuffer Test Suite ===\n") + + # Run unit tests + print("Running unit tests...") + unittest.main(argv=[""], exit=False, verbosity=2) + + print("\n" + "=" * 50) + print("=== Manual Demo ===\n") + + # Manual demonstration + try: + print("Creating ring buffer...") + writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True) + reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) + + print(f"Buffer created with name: {writer_buffer.shared_memory.name}") + + # Allocate some buffers + print("\nAllocating buffers...") + address_array = [] + for i in range(3): + size = 100 + i * 50 + try: + writer_buffer.free_buf(lambda *args: True) + address, monotonic_id = writer_buffer.allocate_buf(size) + address_array.append((address, size, monotonic_id)) + + # Write some test data + with writer_buffer.access_buf(address) as (data_buf, metadata): + test_message = f"Test message {i}".encode() + data_buf[0 : len(test_message)] = test_message + + except MemoryError as e: + print(f" Failed to 
allocate {size} bytes: {e}") + + print("\nBuffer state:") + print(f" Data buffer start: {writer_buffer.data_buffer_start}") + print(f" Data buffer end: {writer_buffer.data_buffer_end}") + print(f" Monotonic ID start: {writer_buffer.monotonic_id_start}") + print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}") + print(f" Metadata entries: {len(writer_buffer.metadata)}") + + # Try to read back the data + print("\nReading back data...") + for address, size, monotonic_id in address_array: + with reader_buffer.access_buf(address) as (data_buf, metadata): + # Find null terminator or read first 50 chars + data_bytes = bytes(data_buf[0:size]) + message = data_bytes.decode() + print(f" ID {monotonic_id}: '{message}'") + + except Exception as e: + print(f"Demo error: {e}") + traceback.print_exc() + + print("\n=== Demo Complete ===") + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py new file mode 100644 index 0000000000000..b9a5c22447fd8 --- /dev/null +++ b/tests/distributed/test_shm_storage.py @@ -0,0 +1,327 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import random +import time +import traceback +import unittest +from multiprocessing import Lock + +import torch + +# Assuming these are imported from your module +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, + SingleWriterShmObjectStorage, + SingleWriterShmRingBuffer, +) +from vllm.multimodal.inputs import ( + MultiModalFieldElem, + MultiModalKwargsItem, + MultiModalSharedField, +) + + +def _dummy_elem(modality: str, key: str, size: int): + return MultiModalFieldElem( + modality=modality, + key=key, + data=torch.empty((size,), dtype=torch.int8), + field=MultiModalSharedField(1), + ) + + +def _dummy_item(modality: str, size_by_key: dict[str, int]): + return MultiModalKwargsItem.from_elems( + [_dummy_elem(modality, key, size) for key, size in size_by_key.items()] + ) + + +class TestSingleWriterShmObjectStorage(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + self.storage = SingleWriterShmObjectStorage( + max_object_size=1024 * 10, # 10KB max object + n_readers=2, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + def tearDown(self): + """Clean up after each test.""" + if self.storage: + del self.storage + + def test_minimal_put_get_cycle(self): + """Test basic put and get operations.""" + key = "test_key" + value = _dummy_item("text", {"field1": 10, "field2": 20}) + + # Put operation + address, monotonic_id = self.storage.put(key, value) + + # Verify key is in index + self.assertIn(key, self.storage.key_index) + self.assertEqual(self.storage.key_index[key], (address, monotonic_id)) + self.assertEqual(self.storage.id_index[monotonic_id], key) + + # Get operation + result = self.storage.get(address, monotonic_id) + + # Verify result + self.assertEqual(result, value) + + def test_put_same_key_twice(self): + """Test behavior when putting the same key multiple times.""" + key = "duplicate_key" + value1 = "first value" + value2 = "second value" + + # First put + address1, id1 = self.storage.put(key, value1) + retrieved1 = self.storage.get(address1, id1) + self.assertEqual(retrieved1, value1) + + # should raise an error on second put + with 
self.assertRaises(ValueError) as context: + self.storage.put(key, value2) + + self.assertIn("already exists in the storage", str(context.exception)) + + def test_large_object_rejection(self): + """Test that objects exceeding max_object_size are rejected.""" + # Create an object larger than max_object_size + large_data = "x" * (self.storage.max_object_size + 100) + + with self.assertRaises(ValueError) as context: + self.storage.put("large_key", large_data) + + self.assertIn("exceeds max object size", str(context.exception)) + + def test_buffer_overflow_and_cleanup(self): + """Test behavior when buffer fills up and needs cleanup.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessible + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items)) + + try: + for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessible + for key, original_value, address, monotonic_id in stored_items: + try: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + except ValueError as e: + print(f"Error retrieving {key}: {e}") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(accessible_count, len(stored_items)) + + def test_blocking_unread_object(self): + """Test that an unread object blocks buffer reclamation until it is read.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read all items except the first one + # to simulate a blocking situation + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items[1:]: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items) - 1) + + try: + key = f"item_{len(stored_items)}" + value = f"data_{len(stored_items)}" * 100 + address, monotonic_id = self.storage.put(key, value) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read the first item + for i in range(self.storage.n_readers): + key, original_value, address, monotonic_id = stored_items[0] + retrieved = self.storage.get(address, monotonic_id) + self.assertEqual(retrieved, original_value) + + try: +
for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(len(stored_items), accessible_count + 10) + + def test_invalid_get_operations(self): + """Test various invalid get operations.""" + # Test with non-existent address + with self.assertRaises(ValueError): # Could be various exceptions + self.storage.get(99999, 1) + + # Store something first + address, monotonic_id = self.storage.put("test", "value") + + # Test with wrong monotonic_id + with self.assertRaises(ValueError) as context: + self.storage.get(address, monotonic_id + 100) + + self.assertIn("has been modified or is invalid", str(context.exception)) + + def test_clear_storage(self): + """Test clearing the storage.""" + # Store some items + for i in range(5): + self.storage.put(f"item_{i}", f"value_{i}") + + # Clear the storage + self.storage.clear() + + # Verify that all indices are empty + self.assertEqual(len(self.storage.key_index), 0) + self.assertEqual(len(self.storage.id_index), 0) + self.assertEqual(len(self.storage.ring_buffer.metadata), 0) + + # Verify that new items can be added after clearing + address, monotonic_id = self.storage.put("new_item", "new_value") + self.assertIn("new_item", self.storage.key_index) + self.assertEqual((address, monotonic_id), (0, 0)) + + +# Reader process function +def reader_process(process_id, storage_handle, items_to_read): + """Reader process that connects to existing shared memory and reads data.""" + reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle) + + print(f"Reader {process_id} started") + + errors = [] + + for key, original_value, address, monotonic_id in items_to_read: + time.sleep(random.random() / 100) + try: + # Read data from shared memory + retrieved_value = reader_storage.get(address, monotonic_id) + + # Verify data integrity + assert retrieved_value == original_value + print(f"Reader {process_id} retrieved {key}: {retrieved_value}") + except Exception as e: + errors.append((key, str(e), type(e).__name__)) + + +def run_multiprocess_example(): + """Run a minimal working example with real shared memory.""" + print("=== Minimal Object Storage Example ===") + + try: + # Create storage instance + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + storage = SingleWriterShmObjectStorage( + max_object_size=1024, + n_readers=3, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + print(f"Created storage (writer: {storage.is_writer})") + + # Test basic data types + test_data = [ + ("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}), + ("simple_string", "Hello, World!"), + ("number", 42), + ("list_data", [1, 2, 3, "four", 5.0]), + ] + + stored_items = [] + + # Store all data + for key, value in test_data: + print(f"Storing {key}: {value}") + address, monotonic_id = storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + print(f" -> Stored at address {address}, ID {monotonic_id}") + + print("\n--- Retrieving Data ---") + processes = [] + handle = storage.handle() + # initialize lock for reader processes + handle.reader_lock = Lock() + for i in 
range(storage.n_readers): + p = multiprocessing.Process( + target=reader_process, args=(i, handle, stored_items) + ) + processes.append(p) + p.start() + + for p in processes: + p.join(timeout=10) + if p.is_alive(): + p.terminate() + p.join() + + except Exception as e: + print(f"Error in minimal example: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + # Run the minimal example first + run_multiprocess_example() + print("\n" + "=" * 50 + "\n") + + # Run the test suite + print("Running comprehensive test suite...") + unittest.main(verbosity=2, exit=False) diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py index 5a804a389123b..e669b81b04f08 100644 --- a/tests/distributed/test_symm_mem_allreduce.py +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import queue import random import typing @@ -10,99 +11,130 @@ import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, - init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine from vllm.platforms import current_platform from vllm.utils import update_environment_variables torch.manual_seed(42) random.seed(44) -test_size_elements = 4 * 1024 * 1024 +test_size_elements = 1024 * 1024 -def symm_mem_allreduce_worker(local_rank: int, world_size: int): +def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): monkeypatch = pytest.MonkeyPatch() - with monkeypatch.context() as m: + config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size)) + + with monkeypatch.context() as m, set_current_vllm_config(config): m.delenv("CUDA_VISIBLE_DEVICES", raising=False) dtype = torch.bfloat16 device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - cuda_communicator = typing.cast(CudaCommunicator, - get_tp_group().device_communicator) + cuda_communicator = typing.cast( + CudaCommunicator, get_tp_group().device_communicator + ) symm_mem_comm = cuda_communicator.symm_mem_comm if symm_mem_comm is None or symm_mem_comm.disabled: - pytest.skip("SymmMemCommunicator is not available or disabled.") + # can't use skip under multiprocessing + q.put("SymmMemCommunicator is 
not available or disabled.") + return - inp_direct_symm_mem = torch.randint(1, - 23, (test_size_elements, ), - dtype=dtype, - device=device) + inp_direct_symm_mem = torch.randint( + 1, 23, (test_size_elements,), dtype=dtype, device=device + ) if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): - pytest.skip( - "SymmMemCommunicator isn't used for this world and input size." - ) + # can't use skip under multiprocessing + q.put("SymmMemCommunicator isn't used for this world and input size.") + return original_inp_direct_symm_mem = inp_direct_symm_mem.clone() out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) assert out_direct_symm_mem is not None - group = get_tensor_model_parallel_group().device_group + group = get_tp_group().device_group dist.all_reduce(original_inp_direct_symm_mem, group=group) - torch.testing.assert_close(out_direct_symm_mem, - original_inp_direct_symm_mem, - atol=2.5, - rtol=0.1) + torch.testing.assert_close( + out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1 + ) # Test tensor_model_parallel_all_reduce which should use symm_mem - inp_tensor_parallel = torch.randint(-23, - 1, (test_size_elements, ), - dtype=dtype, - device=device) + inp_tensor_parallel = torch.randint( + -23, 1, (test_size_elements,), dtype=dtype, device=device + ) original_inp_tensor_parallel = inp_tensor_parallel.clone() - out_tensor_parallel = tensor_model_parallel_all_reduce( - inp_tensor_parallel) + out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel) dist.all_reduce(original_inp_tensor_parallel, group=group) - torch.testing.assert_close(out_tensor_parallel, - original_inp_tensor_parallel, - atol=2.5, - rtol=0.1) + torch.testing.assert_close( + out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1 + ) @pytest.mark.skipif( not current_platform.is_cuda(), - reason="SymmMemAllreduce is only available for CUDA platforms.") + reason="SymmMemAllreduce is only available for CUDA platforms.", +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_symm_mem_allreduce( + monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size +): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") + q = mp.get_context("spawn").Queue() + mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size) + try: + val = q.get(timeout=1) + except queue.Empty: + val = None + finally: + cleanup_dist_env_and_memory() + if val is not None: + pytest.skip(val) - # Enable SymmMemCommunicator - monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") - mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) - cleanup_dist_env_and_memory() +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.", +) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch): + world_size = 4 + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + # Verify that the DataParallel runs without error + engine_args = 
EngineArgs( + model="distilbert/distilgpt2", + enforce_eager=True, + enable_prefix_caching=True, + data_parallel_size=2, + tensor_parallel_size=2, + data_parallel_backend="mp", + ) + LLMEngine.from_engine_args(engine_args) diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 9f2c3eaec3597..f415409d7b377 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -24,13 +24,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # set different `gpu_memory_utilization` and `swap_space` for different ranks, # to test if all ranks agree on the same kv cache configuration. -llm = LLM(model="facebook/opt-125m", - tensor_parallel_size=2, - pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), - distributed_executor_backend="external_launcher", - gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), - seed=0) +llm = LLM( + model="facebook/opt-125m", + tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0, +) outputs = llm.generate(prompts, sampling_params) @@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj): assert container[0] == obj -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) # make sure we can access the model parameters from the calling process # of the `LLM` instance. -params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. 
- model.parameters()) +params = list( + llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() +) test_consistent_across_ranks(len(params)) # all ranks should have the same outputs @@ -65,5 +66,4 @@ for output in outputs: generated_text = output.outputs[0].text test_consistent_across_ranks(prompt) test_consistent_across_ranks(generated_text) - print(f"Rank {torch_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py new file mode 100644 index 0000000000000..1aa7f17935704 --- /dev/null +++ b/tests/distributed/test_torchrun_example_moe.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# unit test for `examples/offline_inference/torchrun_example.py` +import os +import random + +import torch.distributed as dist + +from vllm import LLM, SamplingParams +from vllm.distributed.parallel_state import get_tp_group, get_world_group + +dist.init_process_group(backend="gloo") + +# Create prompts +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] * 10 +dp_size = int(os.getenv("DP_SIZE", "1")) +dp_rank = int(os.getenv("DP_RANK", "0")) + +if dp_size > 1: + # distribute the prompts across the data parallel ranks + prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank] + +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# set different `gpu_memory_utilization` and `swap_space` for different ranks, +# to test if all ranks agree on the same kv cache configuration. +llm = LLM( + model="microsoft/Phi-mini-MoE-instruct", + tensor_parallel_size=int(os.getenv("TP_SIZE", "1")), + pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")), + enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0, +) + +outputs = llm.generate(prompts, sampling_params) + +group = get_world_group() if dp_size == 1 else get_tp_group() +cpu_group = group.cpu_group +group_rank = dist.get_rank(group=cpu_group) + + +def test_consistent_across_ranks(obj): + if group_rank == 0: + dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group) + else: + container = [None] + dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group) + assert container[0] == obj + + +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) + +# make sure we can access the model parameters from the calling process +# of the `LLM` instance. 
+params = list( + llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() +) +test_consistent_across_ranks(len(params)) + +# all ranks should have the same outputs +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + test_consistent_across_ranks(prompt) + test_consistent_across_ranks(generated_text) + print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 0287ad94e3886..2a6936fcd4c2e 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -10,21 +10,22 @@ import torch import vllm.envs as envs from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import (cuda_device_count_stateless, get_open_port, - update_environment_variables) +from vllm.utils import ( + cuda_device_count_stateless, + get_open_port, + update_environment_variables, +) from ..utils import multi_gpu_test @ray.remote class _CUDADeviceCountStatelessTestActor: - def get_count(self): return cuda_device_count_stateless() def set_cuda_visible_devices(self, cuda_visible_devices: str): - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) def get_cuda_visible_devices(self): return envs.CUDA_VISIBLE_DEVICES @@ -34,10 +35,9 @@ def test_cuda_device_count_stateless(): """Test that cuda_device_count_stateless changes return value if CUDA_VISIBLE_DEVICES is changed.""" actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore - num_gpus=2).remote() - assert len( - sorted(ray.get( - actor.get_cuda_visible_devices.remote()).split(","))) == 2 + num_gpus=2 + ).remote() + assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2 assert ray.get(actor.get_count.remote()) == 2 ray.get(actor.set_cuda_visible_devices.remote("0")) assert ray.get(actor.get_count.remote()) == 1 @@ -46,15 +46,13 @@ def test_cuda_device_count_stateless(): def cpu_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) data = torch.tensor([rank]) data = pg1.broadcast_obj(data, src=2) assert data.item() == 2 @@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2): torch.cuda.set_device(rank) - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) pynccl1 = PyNcclCommunicator(pg1, device=rank) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) pynccl2 = PyNcclCommunicator(pg2, device=rank) data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) @@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, 
port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank == 2: pg1.broadcast_obj("secret", src=2) else: @@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) data = pg1.all_gather_obj(rank) assert data == list(range(WORLD_SIZE)) pg1.barrier() @@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2): @pytest.mark.skip(reason="This test is flaky and prone to hang.") @multi_gpu_test(num_gpus=4) @pytest.mark.parametrize( - "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) + "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker] +) def test_stateless_process_group(worker): port1 = get_open_port() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -129,12 +124,14 @@ def test_stateless_process_group(worker): port2 = get_open_port() WORLD_SIZE = 4 from multiprocessing import get_context + ctx = get_context("fork") processes = [] for i in range(WORLD_SIZE): rank = i processes.append( - ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))) + ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)) + ) for p in processes: p.start() for p in processes: diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py deleted file mode 100644 index 8b99d9d6e21fb..0000000000000 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""E2E tests to verify the correctness of the encoder-decoder framework - -Run `pytest tests/encoder_decoder/test_e2e_correctness.py`. -""" -from typing import Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs - -from ..conftest import DecoderPromptType -from ..models.utils import check_logprobs_close - -LIST_ENC_DEC_SUPPORTED_BACKENDS = [ - _Backend.XFORMERS, _Backend.FLASH_ATTN, None -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "</s>" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "<s>" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Fixture to clear backend cache before each test.""" - _cached_get_attn_backend.cache_clear() # Clear the cache - yield # This allows the test to run - - -@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -@pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.skipif( - current_platform.is_cpu(), - reason="CPU backend is not currently supported with encoder/decoder models" -) -def test_encoder_decoder_e2e( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, - decoder_prompt_type: DecoderPromptType, - enforce_eager: bool, - attn_backend: _Backend, -) -> None: - ''' - End-to-End (E2E) test for the encoder-decoder framework. - This test evaluates the encoder-decoder functionality using the BART - model. We compare the outputs of the Hugging Face and vLLM - implementations to ensure that both implementations produce consistent - and correct results. - ''' - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - test_case_prompts = example_encoder_decoder_prompts[ - decoder_prompt_type] - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) - - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE - else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 93ac18dfcc7b4..9d367349fc2e5 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -10,22 +10,30 @@ from typing import Annotated, Literal, Optional, Union import pytest from vllm.config import CompilationConfig, config -from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, - get_type, get_type_hints, is_not_builtin, - is_type, literal_to_kwargs, optional_type, - parse_type) +from vllm.engine.arg_utils import ( + EngineArgs, + contains_type, + get_kwargs, + get_type, + get_type_hints, + is_not_builtin, + is_type, + literal_to_kwargs, + optional_type, + parse_type, +) from vllm.utils import FlexibleArgumentParser -@pytest.mark.parametrize(("type", "value", "expected"), [ - (int, "42", 42), - (float, "3.14", 3.14), - (str, "Hello World!", "Hello World!"), - (json.loads, '{"foo":1,"bar":2}', { - "foo": 1, - "bar": 2 - }), -]) +@pytest.mark.parametrize( + ("type", "value", "expected"), + [ + (int, "42", 42), + (float, "3.14", 3.14), + (str, "Hello World!", "Hello World!"), + (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}), + ], +) def test_parse_type(type, value, expected): parse_type_func = parse_type(type) assert parse_type_func(value) == expected @@ -37,47 +45,56 @@ def test_optional_type(): assert optional_type_func("42") == 42 -@pytest.mark.parametrize(("type_hint", "type", "expected"), [ - (int, int, True), - (int, float, False), - (list[int], list, True), - (list[int], tuple, False), - (Literal[0, 1], Literal, True), -]) +@pytest.mark.parametrize( + ("type_hint", "type", "expected"), + [ + (int, int, True), + (int, float, False), + (list[int], list, True), + (list[int], tuple, False), + (Literal[0, 1], Literal, True), + ], +) def test_is_type(type_hint, type, expected): assert is_type(type_hint, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({float, int}, int, True), - ({int, tuple[int]}, int, True), - ({int, tuple[int]}, float, False), - ({str, Literal["x", "y"]}, Literal, True), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({float, int}, int, True), + ({int, tuple}, int, True), + ({int, tuple[int]}, int, True), + ({int, tuple[int, ...]}, int, True), + ({int, tuple[int]}, float, False), + ({int, tuple[int, ...]}, float, False), + ({str, Literal["x", "y"]}, Literal, True), + ], +) def test_contains_type(type_hints, type, expected): assert contains_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({int, float}, int, int), - ({int, float}, str, None), - ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({int, float}, int, int), + ({int, float}, str, None), + ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), + ], +) def test_get_type(type_hints, type, expected): assert get_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "expected"), [ - ({Literal[1, 2]}, { - "type": int, - "choices": [1, 2] - }), - ({str, Literal["x", "y"]}, { - "type": str, - "metavar": ["x", "y"] - }), - ({Literal[1, "a"]}, Exception), -]) +@pytest.mark.parametrize( + ("type_hints", "expected"), + [ + ({Literal[1, 2]}, {"type": int, "choices": [1, 2]}), + ({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}), + ({Literal[1, "a"]}, Exception), + ], +) def test_literal_to_kwargs(type_hints, 
expected): context = nullcontext() if expected is Exception: @@ -120,22 +137,27 @@ class DummyConfig: """Nested config""" -@pytest.mark.parametrize(("type_hint", "expected"), [ - (int, False), - (DummyConfig, True), -]) +@pytest.mark.parametrize( + ("type_hint", "expected"), + [ + (int, False), + (DummyConfig, True), + ], +) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected @pytest.mark.parametrize( - ("type_hint", "expected"), [ + ("type_hint", "expected"), + [ (Annotated[int, "annotation"], {int}), (Optional[int], {int, type(None)}), (Annotated[Optional[int], "annotation"], {int, type(None)}), (Optional[Annotated[int, "annotation"]], {int, type(None)}), ], - ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"]) + ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"], +) def test_get_type_hints(type_hint, expected): assert get_type_hints(type_hint) == expected @@ -167,7 +189,7 @@ def test_get_kwargs(): # dict should have json tip in help json_tip = "Should either be a valid JSON string or JSON keys" assert json_tip in kwargs["json_tip"]["help"] - # nested config should should construct the nested config + # nested config should construct the nested config assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) @@ -175,24 +197,16 @@ def test_get_kwargs(): ("arg", "expected"), [ (None, dict()), - ('{"video": {"num_frames": 123} }', { - "video": { - "num_frames": 123 - } - }), + ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}), ( '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa { - "video": { - "num_frames": 123, - "fps": 1.0, - "foo": "bar" - }, - "image": { - "foo": "bar" - } - }), - ]) + "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, + "image": {"foo": "bar"}, + }, + ), + ], +) def test_media_io_kwargs_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: @@ -227,24 +241,32 @@ def test_compilation_config(): assert args.compilation_config.level == 3 # set to string form of a dict - args = parser.parse_args([ - "-O", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": false}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and not args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": false}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor + ) # set to string form of a dict - args = parser.parse_args([ - "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": true}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "--compilation-config=" + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor + ) def test_prefix_cache_default(): @@ -252,8 +274,7 @@ def test_prefix_cache_default(): args = parser.parse_args([]) engine_args = EngineArgs.from_cli_args(args=args) - assert 
(not engine_args.enable_prefix_caching - ), "prefix caching defaults to off." + assert not engine_args.enable_prefix_caching, "prefix caching defaults to off." # with flag to turn it on. args = parser.parse_args(["--enable-prefix-caching"]) @@ -266,38 +287,15 @@ def test_prefix_cache_default(): assert not engine_args.enable_prefix_caching -# yapf: disable -@pytest.mark.parametrize(("arg", "expected", "option"), [ - (None, None, "mm-processor-kwargs"), - ("{}", {}, "mm-processor-kwargs"), - ( - '{"num_crops": 4}', - { - "num_crops": 4 - }, - "mm-processor-kwargs" - ), - ( - '{"foo": {"bar": "baz"}}', - { - "foo": - { - "bar": "baz" - } - }, - "mm-processor-kwargs" - ), - ( - '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', - { - "cast_logits_dtype": "bfloat16", - "sequence_parallel_norm": True, - "sequence_parallel_norm_threshold": 2048, - }, - "override-neuron-config" - ), -]) -# yapf: enable +@pytest.mark.parametrize( + ("arg", "expected", "option"), + [ + (None, None, "mm-processor-kwargs"), + ("{}", {}, "mm-processor-kwargs"), + ('{"num_crops": 4}', {"num_crops": 4}, "mm-processor-kwargs"), + ('{"foo": {"bar": "baz"}}', {"foo": {"bar": "baz"}}, "mm-processor-kwargs"), + ], +) def test_composite_arg_parser(arg, expected, option): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: @@ -309,8 +307,7 @@ def test_composite_arg_parser(arg, expected, option): def test_human_readable_model_len(): # `exit_on_error` disabled to test invalid values below - parser = EngineArgs.add_cli_args( - FlexibleArgumentParser(exit_on_error=False)) + parser = EngineArgs.add_cli_args(FlexibleArgumentParser(exit_on_error=False)) args = parser.parse_args([]) assert args.max_model_len is None diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py deleted file mode 100644 index ac5a1f957dfe4..0000000000000 --- a/tests/engine/test_computed_prefix_blocks.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -@pytest.mark.parametrize("block_size", [16]) -def test_computed_prefix_blocks(model: str, block_size: int): - # This test checks if we are able to run the engine to completion - # without triggering asserts. - # We are in a scenario where all blocks from the second request's prompt - # are full and already computed when the second request arrives. - prompt = ( - "You are a helpful assistant. How do I build a car from cardboard and " - "paper clips? 
Is there an easy to follow video tutorial available " - "online for free?") - prompt2 = ( - " Please recommend to me some resources where I can learn not only to " - "handle technical difficulties of building a car, but also " - "decoration.") - - engine_args = EngineArgs(model=model, - block_size=block_size, - enable_prefix_caching=True) - - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams() - - engine.add_request("0", prompt + prompt2, sampling_params) - engine.step() - engine.add_request("1", prompt, sampling_params) - engine.step() diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py deleted file mode 100644 index 15c7a97b50e1f..0000000000000 --- a/tests/engine/test_executor.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -from typing import Any, Callable, Optional, Union - -import pytest - -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.llm_engine import LLMEngine -from vllm.executor.uniproc_executor import UniProcExecutor -from vllm.sampling_params import SamplingParams - - -class Mock: - ... - - -class CustomUniExecutor(UniProcExecutor): - - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: - # Drop marker to show that this was ran - with open(".marker", "w"): - ... - return super().collective_rpc(method, timeout, args, kwargs) - - -CustomUniExecutorAsync = CustomUniExecutor - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor_type_checking(model): - with pytest.raises(ValueError): - engine_args = EngineArgs(model=model, - distributed_executor_backend=Mock) - LLMEngine.from_engine_args(engine_args) - with pytest.raises(ValueError): - engine_args = AsyncEngineArgs(model=model, - distributed_executor_backend=Mock) - AsyncLLMEngine.from_engine_args(engine_args) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor(model, tmp_path): - cwd = os.path.abspath(".") - os.chdir(tmp_path) - try: - assert not os.path.exists(".marker") - - engine_args = EngineArgs( - model=model, - distributed_executor_backend=CustomUniExecutor, - enforce_eager=True, # reduce test time - ) - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(max_tokens=1) - - engine.add_request("0", "foo", sampling_params) - engine.step() - - assert os.path.exists(".marker") - finally: - os.chdir(cwd) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor_async(model, tmp_path): - cwd = os.path.abspath(".") - os.chdir(tmp_path) - try: - assert not os.path.exists(".marker") - - engine_args = AsyncEngineArgs( - model=model, - distributed_executor_backend=CustomUniExecutorAsync, - enforce_eager=True, # reduce test time - ) - engine = AsyncLLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(max_tokens=1) - - async def t(): - stream = await engine.add_request("0", "foo", sampling_params) - async for x in stream: - ... - - asyncio.run(t()) - - assert os.path.exists(".marker") - finally: - os.chdir(cwd) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_respect_ray(model): - # even for TP=1 and PP=1, - # if users specify ray, we should use ray. 
- # users might do this if they want to manage the - # resources using ray. - engine_args = EngineArgs( - model=model, - distributed_executor_backend="ray", - enforce_eager=True, # reduce test time - ) - engine = LLMEngine.from_engine_args(engine_args) - assert engine.model_executor.uses_ray diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py deleted file mode 100644 index b5381b61a020a..0000000000000 --- a/tests/engine/test_multiproc_workers.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from time import sleep -from typing import Any - -import pytest - -from vllm.config import VllmConfig -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) -from vllm.worker.worker_base import WorkerWrapperBase - - -class DummyWorkerWrapper(WorkerWrapperBase): - """Dummy version of vllm.worker.worker.Worker""" - - def worker_method(self, worker_input: Any) -> tuple[int, Any]: - sleep(0.05) - - if isinstance(worker_input, Exception): - # simulate error case - raise worker_input - - return self.rpc_rank, input - - -def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]: - result_handler = ResultHandler() - vllm_config = VllmConfig() - workers = [ - ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config, - rank) for rank in range(8) - ] - - worker_monitor = WorkerMonitor(workers, result_handler) - assert not worker_monitor.is_alive() - - result_handler.start() - worker_monitor.start() - assert worker_monitor.is_alive() - - return workers, worker_monitor - - -def test_local_workers() -> None: - """Test workers with sync task submission""" - - workers, worker_monitor = _start_workers() - - def execute_workers(worker_input: str) -> None: - worker_outputs = [ - worker.execute_method("worker_method", worker_input) - for worker in workers - ] - - for rank, output in enumerate(worker_outputs): - assert output.get() == (rank, input) - - executor = ThreadPoolExecutor(max_workers=4) - - # Test concurrent submission from different threads - futures = [ - executor.submit(partial(execute_workers, f"thread {thread_num}")) - for thread_num in range(4) - ] - - for future in futures: - future.result() - - # Test error case - exception = ValueError("fake error") - result = workers[0].execute_method("worker_method", exception) - try: - result.get() - pytest.fail("task should have failed") - except Exception as e: - assert isinstance(e, ValueError) - assert str(e) == "fake error" - - # Test cleanup when a worker fails - assert worker_monitor.is_alive() - workers[3].process.kill() - - # Other workers should get shut down here - worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = workers[0].execute_method("worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) - - -def test_local_workers_clean_shutdown() -> None: - """Test clean shutdown""" - - workers, worker_monitor = _start_workers() - - assert worker_monitor.is_alive() - assert all(worker.process.is_alive() for worker in workers) - - # Clean shutdown - worker_monitor.close() - - 
worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = workers[0].execute_method("worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) - - -@pytest.mark.asyncio -async def test_local_workers_async() -> None: - """Test local workers with async task submission""" - - workers, worker_monitor = _start_workers() - - async def execute_workers(worker_input: str) -> None: - worker_coros = [ - worker.execute_method_async("worker_method", worker_input) - for worker in workers - ] - - results = await asyncio.gather(*worker_coros) - for rank, result in enumerate(results): - assert result == (rank, input) - - tasks = [ - asyncio.create_task(execute_workers(f"task {task_num}")) - for task_num in range(4) - ] - - for task in tasks: - await task - - # Test error case - exception = ValueError("fake error") - try: - _result = await workers[0].execute_method_async( - "worker_method", exception) - pytest.fail("task should have failed") - except Exception as e: - assert isinstance(e, ValueError) - assert str(e) == "fake error" - - # Test cleanup when a worker fails - assert worker_monitor.is_alive() - workers[3].process.kill() - - # Other workers should get shut down here - worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = await workers[0].execute_method_async( - "worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py deleted file mode 100644 index 42e88e84770ab..0000000000000 --- a/tests/engine/test_options.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from contextlib import nullcontext - -import pytest - -from vllm.entrypoints.llm import LLM -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_skip_tokenizer_initialization(model: str): - # This test checks if the flag skip_tokenizer_init skips the initialization - # of tokenizer and detokenizer. The generated output is expected to contain - # token ids. 
- llm = LLM( - model=model, - skip_tokenizer_init=True, - enforce_eager=True, - ) - sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - - with pytest.raises(ValueError, match="cannot pass text prompts when"): - llm.generate("abc", sampling_params) - - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) - assert len(outputs) > 0 - completions = outputs[0].outputs - assert len(completions) > 0 - assert completions[0].text == "" - assert completions[0].token_ids - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) -def test_enable_prompt_embeds(hf_runner, model: str, - enable_prompt_embeds: bool): - prompt = "abc" - - with hf_runner(model) as hf_model: - token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids - token_ids = token_ids.to(hf_model.model.device) - - embed_layer = hf_model.model.get_input_embeddings() - prompt_embeds = embed_layer(token_ids).squeeze(0) - - ctx = (nullcontext() if enable_prompt_embeds else pytest.raises( - ValueError, match="set `--enable-prompt-embeds`")) - - llm = LLM( - model=model, - enable_prompt_embeds=enable_prompt_embeds, - enforce_eager=True, - ) - - with ctx: - llm.generate({"prompt_embeds": prompt_embeds}) diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py index 9c62761d78afb..54a88586d8edd 100644 --- a/tests/engine/test_short_mm_context.py +++ b/tests/engine/test_short_mm_context.py @@ -5,12 +5,12 @@ import pytest from ..conftest import IMAGE_ASSETS -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: <image>\nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: <image>\nWhat is the season?\nASSISTANT:", -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:", + "cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:", + } +) models = ["llava-hf/llava-1.5-7b-hf"] @@ -19,15 +19,15 @@ models = ["llava-hf/llava-1.5-7b-hf"] def test_context_length_too_short(vllm_runner, image_assets, model): images = [asset.pil_image for asset in image_assets] - with pytest.raises(ValueError, - match="longer than the maximum model length"): + with pytest.raises(ValueError, match="longer than the maximum model length"): vllm_model = vllm_runner( model, max_model_len=128, # LLaVA has a feature size of 576 enforce_eager=True, + load_format="dummy", ) with vllm_model: - vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], - max_tokens=1, - images=[images[0]]) + vllm_model.generate_greedy( + [HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]] + ) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index a7c533ec24198..a52e1cb7df33d 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -26,8 +26,10 @@ def sample_token_ids(): @pytest.fixture def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) @pytest.fixture @@ -35,40 +37,27 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + 
"minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -80,65 +69,54 @@ def sample_complex_json_schema(): "score": { "type": "integer", "minimum": 0, - "maximum": 100 # Numeric range + "maximum": 100, # Numeric range }, "grade": { "type": "string", - "pattern": "^[A-D]$" # Regex pattern + "pattern": "^[A-D]$", # Regex pattern }, "email": { "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "tags": { "type": "array", "items": { "type": "string", - "pattern": - "^[a-z]{1,10}$" # Combining length and pattern restrictions - } - } + # Combining length and pattern restrictions + "pattern": "^[a-z]{1,10}$", + }, + }, }, - "required": ["score", "grade", "email", "tags"] + "required": ["score", "grade", "email", "tags"], } @pytest.fixture def sample_definition_json_schema(): return { - '$defs': { - 'Step': { - 'properties': { - 'explanation': { - 'title': 'Explanation', - 'type': 'string' - }, - 'output': { - 'title': 'Output', - 'type': 'string' - } + "$defs": { + "Step": { + "properties": { + "explanation": {"title": "Explanation", "type": "string"}, + "output": {"title": "Output", "type": "string"}, }, - 'required': ['explanation', 'output'], - 'title': 'Step', - 'type': 'object' + "required": ["explanation", "output"], + "title": "Step", + "type": "object", } }, - 'properties': { - 'steps': { - 'items': { - '$ref': '#/$defs/Step' - }, - 'title': 'Steps', - 'type': 'array' + "properties": { + "steps": { + "items": {"$ref": "#/$defs/Step"}, + "title": "Steps", + "type": "array", }, - 'final_answer': { - 'title': 'Final Answer', - 'type': 'string' - } + "final_answer": {"title": "Final Answer", "type": "string"}, }, - 'required': ['steps', 'final_answer'], - 'title': 'MathReasoning', - 'type': 'object' + "required": ["steps", "final_answer"], + "title": "MathReasoning", + "type": "object", } @@ -149,55 +127,77 @@ def sample_enum_json_schema(): "properties": { "status": { "type": "string", - "enum": ["active", "inactive", - "pending"] # Literal values using enum + "enum": ["active", "inactive", "pending"], # Literal values using enum }, "priority": { "type": "string", - "enum": ["low", "medium", "high", "critical"] + "enum": ["low", "medium", "high", "critical"], }, "category": { "type": "object", "properties": { "type": { "type": "string", - "enum": ["bug", "feature", "improvement"] + "enum": ["bug", "feature", "improvement"], }, "severity": { "type": "integer", - "enum": [1, 2, 3, 4, - 5] # Enum can also contain numbers - } + "enum": [1, 2, 3, 4, 5], # Enum can also contain numbers + }, }, - "required": ["type", "severity"] + "required": ["type", "severity"], }, "flags": { "type": "array", "items": { "type": "string", - "enum": ["urgent", "blocked", "needs_review", "approved"] - } - } + "enum": ["urgent", "blocked", "needs_review", "approved"], + }, + }, }, - "required": ["status", "priority", "category", "flags"] + "required": ["status", "priority", "category", "flags"], } @pytest.fixture -def sample_guided_choice(): +def 
sample_structured_outputs_choices(): return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" + "Python", + "Java", + "JavaScript", + "C++", + "C#", + "PHP", + "TypeScript", + "Ruby", + "Swift", + "Kotlin", ] @pytest.fixture def sample_sql_statements(): - return (""" + return """ start: select_statement select_statement: "SELECT" column "from" table "where" condition column: "col_1" | "col_2" table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""" + + +@pytest.fixture(scope="session") +def zephyr_lora_files(): + """Download zephyr LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") + + +@pytest.fixture(scope="session") +def opt125_lora_files() -> str: + """Download opt-125m LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora") diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index 5d605e906e81b..af607720c8b0e 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -48,58 +48,47 @@ def run_test(model_name, more_args=None): measured_value = results["results"][TASK][FILTER] assert model_name in EXPECTED_VALUES, ( - f"Cannot find the expected value for the model {model_name=}") + f"Cannot find the expected value for the model {model_name=}" + ) expected_value = EXPECTED_VALUES[model_name] - assert (measured_value - RTOL < expected_value - and measured_value + RTOL > expected_value - ), f"Expected: {expected_value} | Measured: {measured_value}" + assert ( + measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" # TODO: [AlexM] Fix it with new CI/CD tests -TPU_TP_TEST_STR = "" #"tensor_parallel_size=4" +TPU_TP_TEST_STR = "" # "tensor_parallel_size=4" -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 is currently only supported on CUDA and TPU") @pytest.mark.parametrize("model", MODEL_NAMES) -def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): +def test_lm_eval_accuracy_v1_engine(model): """Run with the V1 Engine.""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + more_args = None + if current_platform.is_tpu(): + # Limit compilation time for TPU V1 - more_args = None - if current_platform.is_tpu(): - # Limit compilation time for TPU V1 + more_args = "max_model_len=2048,max_num_seqs=64" - more_args = "max_model_len=2048,max_num_seqs=64" + # Add TP test (if provided) + if TPU_TP_TEST_STR: + more_args += ",{}".format(TPU_TP_TEST_STR) - # Add TP test (if provided) - if TPU_TP_TEST_STR: - more_args += ",{}".format(TPU_TP_TEST_STR) - - run_test(model, more_args) + run_test(model, more_args) -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 is currently only supported on CUDA and TPU") @pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES) -def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( - model, monkeypatch: pytest.MonkeyPatch): +def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(model): """Run with the V1 Engine.""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + more_args = None + if current_platform.is_tpu(): + # Limit compilation time for TPU V1 + more_args = 
"max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8" - more_args = None - if current_platform.is_tpu(): - # Limit compilation time for TPU V1 - more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8" + # Add TP test (if provided) + if TPU_TP_TEST_STR: + more_args += ",{}".format(TPU_TP_TEST_STR) - # Add TP test (if provided) - if TPU_TP_TEST_STR: - more_args += ",{}".format(TPU_TP_TEST_STR) - - run_test(model, more_args) + run_test(model, more_args) diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 2cbfed98a577a..b2a958a992a62 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -7,16 +7,14 @@ import pytest from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory -from ..openai.test_vision import TEST_IMAGE_URLS +from ..openai.test_vision import TEST_IMAGE_ASSETS @pytest.fixture(scope="function") def text_llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - seed=0) + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0) yield weakref.proxy(llm) @@ -28,14 +26,8 @@ def text_llm(): def test_chat(text_llm): prompt1 = "Explain the concept of entropy." messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -46,25 +38,13 @@ def test_multi_chat(text_llm): prompt2 = "Explain what among us is." conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt2}, ] messages = [conversation1, conversation2] @@ -94,25 +74,22 @@ def vision_llm(): cleanup_dist_env_and_memory() -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +@pytest.mark.parametrize( + "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True +) def test_chat_multi_image(vision_llm, image_urls: list[str]): - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] outputs = vision_llm.chat(messages) assert len(outputs) >= 0 @@ -123,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm): Check we get a single BOS token for llama chat. """ messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello!" 
- }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello!"}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -166,14 +137,8 @@ def thinking_llm(): @pytest.mark.parametrize("enable_thinking", [True, False]) def test_chat_extra_kwargs(thinking_llm, enable_thinking): messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "What is 1+1?" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What is 1+1?"}, ] outputs = thinking_llm.chat( diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 3a13f8c979f23..937aa5c132461 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -23,9 +23,11 @@ def test_collective_rpc(tp_size, backend, monkeypatch): return self.rank monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=tp_size, - distributed_executor_backend=backend) + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=tp_size, + distributed_executor_backend=backend, + ) assert llm.collective_rpc(echo_rank) == list(range(tp_size)) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 3bbbcc755d134..e9993fd840619 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -25,21 +25,17 @@ TOKEN_IDS = [ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True, + ) yield weakref.proxy(llm) @@ -87,8 +83,22 @@ def test_max_model_len(): outputs = llm.generate(PROMPTS, sampling_params) for output in outputs: num_total_tokens = len(output.prompt_token_ids) + len( - output.outputs[0].token_ids) + output.outputs[0].token_ids + ) # Total tokens must not exceed max_model_len. # It can be less if generation finishes due to other reasons (e.g., EOS) # before reaching the absolute model length limit. 
assert num_total_tokens <= max_model_len + + +def test_log_stats(): + llm = LLM( + model=MODEL_NAME, + disable_log_stats=False, + gpu_memory_utilization=0.10, + enforce_eager=True, # reduce test time + ) + outputs = llm.generate(PROMPTS, sampling_params=None) + + # disable_log_stats is False, every output should have metrics + assert all(output.metrics is not None for output in outputs) diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py deleted file mode 100644 index a04f195692e9b..0000000000000 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref - -import pytest -# downloading lora to test lora requests -from huggingface_hub import snapshot_download - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" - -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -LORA_NAME = "typeof/zephyr-7b-beta-lora" - - -@pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def llm(request, monkeypatch_module): - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') - - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - tensor_parallel_size=1, - max_model_len=8192, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - max_num_seqs=128, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm - - cleanup_dist_env_and_memory() - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.mark.skip_global_cleanup -def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): - lora_request = [ - LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) - for idx in range(len(PROMPTS)) - ] - # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(PROMPTS, lora_request=lora_request) - assert len(PROMPTS) == len(outputs) - - # Exception raised, if the size of params does not match the size of prompts - with pytest.raises(ValueError): - outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) - - # Single LoRARequest should be applied to every prompt - single_lora_request = lora_request[0] - outputs = llm.generate(PROMPTS, lora_request=single_lora_request) - assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py index 533da9e6d6eac..896091533ad29 100644 --- a/tests/entrypoints/llm/test_gpu_utilization.py +++ b/tests/entrypoints/llm/test_gpu_utilization.py @@ -16,9 +16,8 @@ def test_gpu_memory_utilization(): # makes sure gpu_memory_utilization is per-instance limit, # not a global limit llms = [ - LLM(model="facebook/opt-125m", - gpu_memory_utilization=0.3, - enforce_eager=True) for i in range(3) + LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True) + for i in range(3) ] for llm in llms: outputs = llm.generate(prompts, sampling_params) diff --git a/tests/entrypoints/llm/test_lazy_outlines.py 
b/tests/entrypoints/llm/test_lazy_outlines.py deleted file mode 100644 index ac0b7e134c55a..0000000000000 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import sys -from contextlib import nullcontext - -from vllm_test_utils import BlameResult, blame - -from vllm import LLM, SamplingParams -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.sampling_params import GuidedDecodingParams - - -def run_normal(): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - # Create an LLM without guided decoding as a baseline. - llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - gpu_memory_utilization=0.3) - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - # Destroy the LLM object and free up the GPU memory. - del llm - cleanup_dist_env_and_memory() - - -def run_xgrammar(sample_regex): - # Create an LLM with guided decoding enabled. - llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - guided_decoding_backend="xgrammar", - gpu_memory_utilization=0.3) - prompt = f"Give an example IPv4 address with this regex: {sample_regex}" - guided_decoding = GuidedDecodingParams(regex=sample_regex) - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=guided_decoding) - outputs = llm.generate( - prompts=[prompt] * 2, - sampling_params=sampling_params, - use_tqdm=True, - ) - - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def test_lazy_outlines(sample_regex): - """If users don't use guided decoding, outlines should not be imported. - """ - # make sure outlines is not imported - module_name = "outlines" - # In CI, we only check finally if the module is imported. - # If it is indeed imported, we can rerun the test with `use_blame=True`, - # which will trace every function call to find the first import location, - # and help find the root cause. - # We don't run it in CI by default because it is slow. - use_blame = False - context = blame( - lambda: module_name in sys.modules) if use_blame else nullcontext() - with context as result: - run_normal() - run_xgrammar(sample_regex) - if use_blame: - assert isinstance(result, BlameResult) - print(f"the first import location is:\n{result.trace_stack}") - assert module_name not in sys.modules, ( - f"Module {module_name} is imported. 
To see the first" - f" import location, run the test with `use_blame=True`.") diff --git a/tests/entrypoints/llm/test_mm_cache_stats.py b/tests/entrypoints/llm/test_mm_cache_stats.py new file mode 100644 index 0000000000000..e5ee99124409d --- /dev/null +++ b/tests/entrypoints/llm/test_mm_cache_stats.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging + +import pytest +import regex as re + +from vllm import LLM +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.v1.metrics import loggers as stat_loggers +from vllm.v1.metrics.reader import Counter, Metric + +from ..openai.test_vision import TEST_IMAGE_ASSETS + + +def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]: + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + } + ] + + +def _get_counter_value(metrics: list[Metric], name: str): + metric = next(m for m in metrics if m.name == name) + assert isinstance(metric, Counter) + return metric.value + + +def _get_mm_cache_stats(metrics: list[Metric]): + mm_cache_queries = _get_counter_value(metrics, "vllm:mm_cache_queries") + mm_cache_hits = _get_counter_value(metrics, "vllm:mm_cache_hits") + + return mm_cache_queries, mm_cache_hits + + +def _get_mm_cache_log(llm: LLM, caplog_vllm: pytest.LogCaptureFixture) -> float: + caplog_vllm.clear() + with caplog_vllm.at_level(logging.INFO, logger=stat_loggers.__name__): + llm.llm_engine.do_log_stats() + + assert len(caplog_vllm.records) == 1 + msg = caplog_vllm.records[0].getMessage() + + assert "MM cache hit rate" in msg + match = re.search(r"MM cache hit rate: ([0-9.]+)%", msg) + assert match is not None + return float(match.group(1)) + + +@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True) +@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"]) +def test_mm_cache_stats( + num_gpus_available, + image_urls, + mm_processor_cache_type, + caplog_vllm, +): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + mm_processor_cache_type=mm_processor_cache_type, + disable_log_stats=False, + limit_mm_per_prompt={"image": 2}, + ) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (1, 0) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[1])) + assert _get_mm_cache_stats(llm.get_metrics()) == (2, 0) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (3, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(33.3) + + # NOTE: This only resets hit rate stats in CachingMetrics + # The raw queries and hits counts remain unaffected + llm.reset_mm_cache() + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (4, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[1])) + assert _get_mm_cache_stats(llm.get_metrics()) == (5, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 1b7be15d5d691..81126a4f16f98 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -6,22 
+6,14 @@ import pytest from vllm import LLM -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - def test_empty_prompt(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='decoder prompt cannot be empty'): + with pytest.raises(ValueError, match="decoder prompt cannot be empty"): llm.generate([""]) @pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='out of vocabulary'): + with pytest.raises(ValueError, match="out of vocabulary"): llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py deleted file mode 100644 index de82cf8d40380..0000000000000 --- a/tests/entrypoints/llm/test_reward.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref - -import pytest -import torch - -from vllm import LLM, PoolingParams -from vllm.distributed import cleanup_dist_env_and_memory - -from ...models.utils import softmax - -MODEL_NAME = "internlm/internlm2-1_8b-reward" - -prompts = ["The chef prepared a delicious meal."] - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - -@pytest.fixture(scope="module") -def llm(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - trust_remote_code=True, - seed=0) - - yield weakref.proxy(llm) - - del llm - - cleanup_dist_env_and_memory() - - -@pytest.mark.skip_global_cleanup -def test_pooling_params(llm: LLM): - - def get_outputs(softmax): - outputs = llm.reward(prompts, - pooling_params=PoolingParams(softmax=softmax), - use_tqdm=False) - return torch.cat([x.outputs.data for x in outputs]) - - default = get_outputs(softmax=None) - w_softmax = get_outputs(softmax=True) - wo_softmax = get_outputs(softmax=False) - - assert torch.allclose(default, w_softmax, - atol=1e-2), "Default should use softmax." - assert not torch.allclose(w_softmax, wo_softmax, - atol=1e-2), "wo_softmax should not use softmax." - assert torch.allclose( - softmax(wo_softmax), w_softmax, - atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." 
diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index dd8d63ad319ac..25e663f3af0eb 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" + import dataclasses import importlib import sys @@ -32,15 +33,16 @@ MODEL_CONFIGS = [ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, - { - "model": "sentence-transformers/all-MiniLM-L12-v2", - "enforce_eager": True, - "gpu_memory_utilization": 0.20, - "max_model_len": 64, - "max_num_batched_tokens": 64, - "max_num_seqs": 64, - "tensor_parallel_size": 1, - }, + # TODO: re-enable once these tests are run with V1 + # { + # "model": "sentence-transformers/all-MiniLM-L12-v2", + # "enforce_eager": True, + # "gpu_memory_utilization": 0.20, + # "max_model_len": 64, + # "max_num_batched_tokens": 64, + # "max_num_seqs": 64, + # "tensor_parallel_size": 1, + # }, ] @@ -78,7 +80,7 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch): ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() # Cached model files should be used in offline mode for model_config in MODEL_CONFIGS: @@ -90,12 +92,11 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch): def _re_import_modules(): - hf_hub_module_names = [ - k for k in sys.modules if k.startswith("huggingface_hub") - ] + hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")] transformers_module_names = [ - k for k in sys.modules if k.startswith("transformers") - and not k.startswith("transformers_modules") + k + for k in sys.modules + if k.startswith("transformers") and not k.startswith("transformers_modules") ] reload_exception = None @@ -135,7 +136,7 @@ def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): disable_connect, ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() engine_args = EngineArgs(model="facebook/opt-125m") LLM(**dataclasses.asdict(engine_args)) diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/openai/conftest.py new file mode 100644 index 0000000000000..b40079d8dc3d5 --- /dev/null +++ b/tests/entrypoints/openai/conftest.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.assets.audio import AudioAsset + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset("mary_had_lamb").get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset("winning_call").get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def foscolo(): + # Test translation it->en + path = AudioAsset("azacinto_foscolo").get_local_path() + with open(str(path), "rb") as f: + yield f diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index 684407cd6ee97..5b23b42390279 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -10,7 +10,6 @@ AsyncLLMEngine are working correctly. 
""" import lm_eval -import pytest from vllm.platforms import current_platform @@ -44,14 +43,15 @@ def run_test(more_args): print(f"Running with: {args}") with RemoteOpenAIServer( - MODEL_NAME, args, - max_wait_seconds=MAX_WAIT_SECONDS) as remote_server: + MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS + ) as remote_server: url = f"{remote_server.url_for('v1')}/completions" model_args = ( f"model={MODEL_NAME}," f"base_url={url}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -60,34 +60,19 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu() - and not current_platform.is_xpu(), - reason="V1 currently only supported on CUDA, XPU and TPU") -def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): +def test_lm_eval_accuracy_v1_engine(): """Run with the V1 Engine.""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - more_args = [] + more_args = [] - # Limit compilation time for V1 - if current_platform.is_tpu(): - more_args = ["--max-num-seqs", "64"] + # Limit compilation time for V1 + if current_platform.is_tpu(): + more_args = ["--max-num-seqs", "64"] - run_test(more_args) - - -@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) -def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch, - more_args): - """Run with the V0 Engine.""" - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - run_test(more_args) + run_test(more_args) diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd351..7821ade63ac38 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -7,6 +7,7 @@ a baseline. This simulates real work usage of the API and makes sure that the frontend and AsyncLLMEngine are working correctly. """ + import asyncio import io import time @@ -32,7 +33,7 @@ def to_bytes(y, sr): async def transcribe_audio(client, tokenizer, y, sr): # Send loaded audio directly instead of loading from disk, - # dont account for that time though + # don't account for that time though with to_bytes(y, sr) as f: start_time = time.perf_counter() transcription = await client.audio.transcriptions.create( @@ -45,12 +46,12 @@ async def transcribe_audio(client, tokenizer, y, sr): # NOTE there's no streaming in transcriptions, can't measure ttft latency = end_time - start_time num_output_tokens = len( - tokenizer(transcription.text, add_special_tokens=False).input_ids) + tokenizer(transcription.text, add_special_tokens=False).input_ids + ) return latency, num_output_tokens, transcription.text -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) +async def bound_transcribe(sem, client, tokenizer, audio, reference): # Use semaphore to limit concurrent requests. 
async with sem: result = await transcribe_audio(client, tokenizer, *audio) @@ -63,15 +64,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) + # Load tokenizer once outside the loop + tokenizer = AutoTokenizer.from_pretrained(model) + # Warmup call as the first `librosa.load` server-side is quite slow. audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] - _ = await bound_transcribe(model, sem, client, (audio, sr), "") + _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "") tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"]) + ) tasks.append(task) return await asyncio.gather(*tasks) @@ -95,34 +100,35 @@ def print_performance_metrics(results, total_time): def add_duration(sample): - y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] - sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000 return sample -def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): +def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs): ## Load and filter the dataset dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) - if 'duration_ms' not in dataset[0]: + if "duration_ms" not in dataset[0]: # compute duration to filter dataset = dataset.map(add_duration) # Whisper max supported duration - dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + dataset = dataset.filter(lambda example: example["duration_ms"] < 30000) return dataset -def run_evaluation(model: str, - client, - dataset, - max_concurrent_reqs: int, - n_examples: int = -1, - print_metrics: bool = True): +def run_evaluation( + model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True, +): if n_examples > 0: dataset = dataset.select(range(n_examples)) start = time.perf_counter() - results = asyncio.run( - process_dataset(model, client, dataset, max_concurrent_reqs)) + results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs)) end = time.perf_counter() total_time = end - start print(f"Total Test Time: {total_time:.4f} seconds") @@ -132,8 +138,7 @@ def run_evaluation(model: str, predictions = [res[2] for res in results] references = [res[3] for res in results] wer = load("wer") - wer_score = 100 * wer.compute(references=references, - predictions=predictions) + wer_score = 100 * wer.compute(references=references, predictions=predictions) print("WER:", wer_score) return wer_score @@ -142,26 +147,25 @@ def run_evaluation(model: str, @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) # Original dataset is 20GB+ in size, hence we use a pre-filtered slice. @pytest.mark.parametrize( - "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"] +) # NOTE: Expected WER measured with equivalent hf.transformers args: # whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. 
@pytest.mark.parametrize("expected_wer", [12.744980]) -def test_wer_correctness(model_name, - dataset_repo, - expected_wer, - n_examples=-1, - max_concurrent_request=None): +def test_wer_correctness( + model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None +): # TODO refactor to use `ASRDataset` - with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server: dataset = load_hf_dataset(dataset_repo) if not max_concurrent_request: # No max concurrency - max_concurrent_request = n_examples if n_examples > 0\ - else len(dataset) + max_concurrent_request = n_examples if n_examples > 0 else len(dataset) client = remote_server.get_async_client() - wer = run_evaluation(model_name, client, dataset, - max_concurrent_request, n_examples) + wer = run_evaluation( + model_name, client, dataset, max_concurrent_request, n_examples + ) if expected_wer: torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py index 80261597b11a8..5df859df42da7 100644 --- a/tests/entrypoints/openai/test_async_tokenization.py +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -44,15 +44,11 @@ async def client(server): ids=["completion", "chat"], argnames=["create_func_gen", "content_body"], argvalues=[ - (lambda x: x.completions.create, { - "prompt": " ".join(['A'] * 10_000) - }), - (lambda x: x.chat.completions.create, { - "messages": [{ - "role": "user", - "content": " ".join(['A'] * 10_000) - }] - }), + (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}), + ( + lambda x: x.chat.completions.create, + {"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]}, + ), ], ) async def test_with_and_without_truncate( @@ -65,15 +61,15 @@ async def test_with_and_without_truncate( body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} num_requests = 10 - truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * - (num_requests - num_requests // 2)) + truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * ( + num_requests - num_requests // 2 + ) random.shuffle(truncate_prompt_tokens) - bodies = [{ - **body, "extra_body": { - 'truncate_prompt_tokens': t - } - } for t in truncate_prompt_tokens] + bodies = [ + {**body, "extra_body": {"truncate_prompt_tokens": t}} + for t in truncate_prompt_tokens + ] async def get_status_code(**kwargs): try: diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 2d33d3c3a6b54..a96f0134c2ffb 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_single_chat_session_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_single_chat_session_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI, - model_name: str, - audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": audio_url - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_error_on_invalid_audio_url_type( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": audio_url}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # audio_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_audio_base64encoded( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": - f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_input_audio( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio( messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_chat_streaming_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str, - base64_encoded_audio: dict[str, - str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_chat_streaming_input_audio( + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]) -async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, - audio_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "audio_url", - "audio_url": { - "url": audio_url - } - } for audio_url in audio_urls), - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] + "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]] +) +async def test_multi_audio_input( + client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "audio_url", "audio_url": {"url": audio_url}} + for audio_url in audio_urls + ), + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] if len(audio_urls) > MAXIMUM_AUDIOS: with pytest.raises(openai.BadRequestError): # test multi-audio input diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index a55941976cd82..50ec87b4464f6 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def server_args(request: pytest.FixtureRequest) -> list[str]: - """ Provide extra arguments to the server via indirect parametrization + """Provide extra arguments to the server via indirect parametrization Usage: @@ -80,8 +80,10 @@ async def client(server): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer): @pytest.mark.parametrize( "server_args", [ - pytest.param(["--max-model-len", "10100"], - id="default-frontend-multiprocessing"), + pytest.param( + ["--max-model-len", "10100"], id="default-frontend-multiprocessing" + ), pytest.param( ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], - id="disable-frontend-multiprocessing") + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # Request about 2 million tokens for _ in range(200): task = asyncio.create_task( - client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10000, - extra_body={"min_tokens": 10000})) + client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000}, + ) + ) tasks.append(task) - done, pending = await asyncio.wait(tasks, - return_when=asyncio.ALL_COMPLETED) + done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) # Make sure all requests were sent to the server and timed out # (We don't want to hide other errors like 400s that would invalidate this @@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # If the server had not cancelled all the other requests, then it would not # be able to respond to this one within the timeout client = server.get_async_client(timeout=5) - response = await client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10) + response = await client.chat.completions.create( + messages=chat_input, model=MODEL_NAME, 
max_tokens=10 + ) assert len(response.choices) == 1 @pytest.mark.asyncio async def test_request_wrong_content_type(server: RemoteOpenAIServer): - chat_input = [{"role": "user", "content": "Write a long story"}] client = server.get_async_client() @@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer): messages=chat_input, model=MODEL_NAME, max_tokens=10000, - extra_headers={ - "Content-Type": "application/x-www-form-urlencoded" - }) + extra_headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) @pytest.mark.parametrize( "server_args", - [ - pytest.param(["--enable-server-load-tracking"], - id="enable-server-load-tracking") - ], + [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")], indirect=True, ) @pytest.mark.asyncio @@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer): # Start the completion request in a background thread. completion_future = asyncio.create_task( - asyncio.to_thread(make_long_completion_request)) + asyncio.to_thread(make_long_completion_request) + ) # Give a short delay to ensure the request has started. await asyncio.sleep(0.1) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 5ad29d70f10df..14181c6b8b16b 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests +# imports for structured outputs tests import json from typing import Optional @@ -12,34 +12,16 @@ import pytest_asyncio import regex as re import requests import torch -from openai import BadRequestError, OpenAI +from openai import BadRequestError from ...utils import RemoteOpenAIServer -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def server( - request, - monkeypatch_module, - zephyr_lora_files, #noqa: F811 - zephyr_lora_added_tokens_files): # noqa: F811 - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') - +def server(zephyr_lora_files): # noqa: F811 args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -51,7 +33,6 @@ def server( "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -64,13 +45,6 @@ def server( yield remote_server -@pytest.fixture -def is_v1_server(server): - import os - assert os.environ['VLLM_USE_V1'] in ['0', '1'] - return os.environ['VLLM_USE_V1'] == '1' - - @pytest_asyncio.fixture async def client(server): async with server.get_async_client() as async_client: @@ -81,23 +55,21 @@ async def client(server): @pytest.mark.parametrize( # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora"], ) async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what 
is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_completion_tokens=5, temperature=0.0, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.logprobs is None @@ -110,13 +82,10 @@ async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -124,7 +93,8 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=0) + top_logprobs=0, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -138,13 +108,10 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -152,7 +119,8 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -165,41 +133,39 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] +async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=21, - stream=True) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=21, + stream=True, + ) async for chunk in stream: ... 
with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=30, - stream=False) + await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=30, + stream=False, + ) # the server should still work afterwards chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=10, - stream=False) + model=model_name, messages=messages, max_completion_tokens=10, stream=False + ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 @@ -209,27 +175,20 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, "model_name, prompt_logprobs", [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], ) -async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +async def test_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, } if prompt_logprobs is not None: @@ -252,29 +211,21 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str): +async def test_more_than_one_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name, - "extra_body": { - "prompt_logprobs": 1 - } + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, + "extra_body": {"prompt_logprobs": 1}, } completion_1 = await client.chat.completions.create(**params) @@ -291,15 +242,11 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_single_chat_session(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] +async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -307,14 +254,16 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=37, total_tokens=47) + completion_tokens=10, prompt_tokens=37, total_tokens=47 + ) message = choice.message assert message.content is not None and len(message.content) >= 10 @@ -339,13 +288,10 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, [MODEL_NAME, "zephyr-lora"], ) async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -387,15 +333,13 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], ) -async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" 
- }] +async def test_chat_completion_stream_options( + client: openai.AsyncOpenAI, model_name: str +): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] # Test stream=True, stream_options={"include_usage": False} stream = await client.chat.completions.create( @@ -404,36 +348,34 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=True, - stream_options={"include_usage": False}) + stream_options={"include_usage": False}, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": False}} - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": - True, - "continuous_usage_stats": - False - }) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + stream=True, + stream_options={"include_usage": True, "continuous_usage_stats": False}, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: assert chunk.usage is None else: assert chunk.usage is None - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options={"include_usage": None} @@ -444,7 +386,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": None}) + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options={"include_usage": True} with pytest.raises(BadRequestError): @@ -454,7 +397,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": True}) + stream_options={"include_usage": True}, + ) # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": True} @@ -473,96 +417,96 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, last_completion_tokens = 0 async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 - assert last_completion_tokens == 0 or \ - chunk.usage.completion_tokens > last_completion_tokens or \ - ( - not chunk.choices and - chunk.usage.completion_tokens == last_completion_tokens - ) - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert ( + last_completion_tokens == 0 + or chunk.usage.completion_tokens > last_completion_tokens + or ( + not chunk.choices + and chunk.usage.completion_tokens == last_completion_tokens + ) + ) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) last_completion_tokens = chunk.usage.completion_tokens assert last_completion_tokens == 10 @pytest.mark.asyncio -async def test_guided_choice_chat(client: openai.AsyncOpenAI, - sample_guided_choice, is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided 
decoding is only supported in v1 engine") - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] +async def test_structured_outputs_choice_chat( + client: openai.AsyncOpenAI, + sample_structured_outputs_choices, +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) choice1 = chat_completion.choices[0].message.content - assert choice1 in sample_guided_choice + assert choice1 in sample_structured_outputs_choices messages.append({"role": "assistant", "content": choice1}) - messages.append({ - "role": "user", - "content": "I disagree, pick another one" - }) + messages.append({"role": "user", "content": "I disagree, pick another one"}) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) choice2 = chat_completion.choices[0].message.content - assert choice2 in sample_guided_choice + assert choice2 in sample_structured_outputs_choices assert choice1 != choice2 @pytest.mark.asyncio -async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema, - is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] +async def test_structured_outputs_json_chat( + client: openai.AsyncOpenAI, + sample_json_schema, +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(structured_outputs={"json": sample_json_schema}), + ) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": message.content}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(structured_outputs={"json": sample_json_schema}), + ) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -572,25 +516,23 @@ async def test_guided_json_chat(client: 
openai.AsyncOpenAI, sample_json_schema, @pytest.mark.asyncio -async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, - is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example IP address with this regex: {sample_regex}" - }] +async def test_structured_outputs_regex_chat( + client: openai.AsyncOpenAI, + sample_regex, +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example IP address with this regex: {sample_regex}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(structured_outputs={"regex": sample_regex}), + ) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(sample_regex, ip1) is not None @@ -601,7 +543,8 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(structured_outputs={"regex": sample_regex}), + ) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(sample_regex, ip2) is not None @@ -609,46 +552,44 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, @pytest.mark.asyncio -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] +async def test_structured_outputs_type_error(client: openai.AsyncOpenAI): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - extra_body=dict(guided_regex={ - 1: "Python", - 2: "C++" - })) + _ = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body=dict(structured_outputs={"regex": {1: "Python", 2: "C++"}}), + ) @pytest.mark.asyncio -async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, - sample_guided_choice): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] +async def test_structured_outputs_choice_chat_logprobs( + client: openai.AsyncOpenAI, sample_structured_outputs_choices +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.content is not None @@ -660,20 +601,30 @@ async def 
test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio -async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, - is_v1_server: bool): - if not is_v1_server: - pytest.skip("Tool use is only supported in v1 engine") - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] +async def test_named_tool_use( + client: openai.AsyncOpenAI, + sample_json_schema, +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": ( + "Give an example JSON for an employee profile using the specified tool." + ), + }, + ] + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}} # non-streaming @@ -681,20 +632,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - "name": "dummy_function_name" - } - }, + tools=tools, + tool_choice=tool_choice, ) message = chat_completion.choices[0].message assert len(message.content) == 0 @@ -703,12 +642,9 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": json_string}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) # streaming @@ -716,21 +652,10 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - "name": "dummy_function_name" - } - }, - stream=True) + tools=tools, + tool_choice=tool_choice, + stream=True, + ) output = [] finish_reason_count = 0 @@ -752,64 +677,66 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, @pytest.mark.asyncio -async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, - sample_json_schema): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] - - with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tool_choice={ - "type": "function", - "function": { - "name": - "dummy_function_name" - } - }) +async def test_inconsistent_tool_choice_and_tools( + client: openai.AsyncOpenAI, sample_json_schema +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON 
for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], tool_choice={ "type": "function", - "function": { - "name": "nondefined_function_name" - } - }) + "function": {"name": "dummy_function_name"}, + }, + ) + with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }], - tool_choice={}) + ], + tool_choice={ + "type": "function", + "function": {"name": "nondefined_function_name"}, + }, + ) + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ], + tool_choice={}, + ) @pytest.mark.asyncio @@ -817,13 +744,17 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') - }], - response_format={"type": "json_object"}) + messages=[ + { + "role": "user", + "content": ( + "what is 1+1? please respond with a JSON object, " + 'the format is {"result": 2}' + ), + } + ], + response_format={"type": "json_object"}, + ) content = resp.choices[0].message.content assert content is not None @@ -833,20 +764,13 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): @pytest.mark.asyncio -async def test_response_format_json_schema(client: openai.AsyncOpenAI, - is_v1_server: bool): - if not is_v1_server: - pytest.skip( - "JSON schema response format is only supported in v1 engine") +async def test_response_format_json_schema(client: openai.AsyncOpenAI): prompt = 'what is 1+1? 
The format is "result": 2' # Check that this prompt cannot lead to a valid JSON without json_schema for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], ) content = resp.choices[0].message.content assert content is not None @@ -857,10 +781,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI, for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], response_format={ "type": "json_schema", "json_schema": { @@ -868,13 +789,12 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI, "schema": { "type": "object", "properties": { - "result": { - "type": "integer" - }, + "result": {"type": "integer"}, }, }, - } - }) + }, + }, + ) content = resp.choices[0].message.content assert content is not None @@ -887,13 +807,16 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI, async def test_extra_fields_allowed(client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?", - "extra_field": "0", - }], # type: ignore + messages=[ + { + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content is not None @@ -901,20 +824,23 @@ async def test_extra_fields_allowed(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_complex_message_content(client: openai.AsyncOpenAI): + content = [ + { + "type": "text", + "text": "what is 1+1? please provide the result without any other text.", + } + ] resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": [{ - "type": - "text", - "text": - "what is 1+1? please provide the result without any other text." - }] - }], + messages=[ + { + "role": "user", + "content": content, + } + ], temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content == "2" @@ -926,24 +852,27 @@ async def test_custom_role(client: openai.AsyncOpenAI): resp1 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": "what is 1+1?", - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": "what is 1+1?", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) resp2 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": [{ - "type": "text", - "text": "what is 1+1?" 
- }] - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": [{"type": "text", "text": "what is 1+1?"}], + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content1 = resp1.choices[0].message.content content2 = resp2.choices[0].message.content @@ -952,87 +881,32 @@ async def test_custom_role(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_long_seed(client: openai.AsyncOpenAI): - for seed in [ - torch.iinfo(torch.long).min - 1, - torch.iinfo(torch.long).max + 1 - ]: + for seed in [torch.iinfo(torch.long).min - 1, torch.iinfo(torch.long).max + 1]: with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - }], + messages=[ + { + "role": "system", + "content": "You are a helpful assistant.", + } + ], temperature=0, - seed=seed) + seed=seed, + ) - assert ("greater_than_equal" in exc_info.value.message - or "less_than_equal" in exc_info.value.message) + assert ( + "greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message + ) @pytest.mark.asyncio -async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): - url = f"http://localhost:{server.port}/v1/chat/completions" - headers = { - "Content-Type": "application/json", - } - data = { - # model_name is avoided here. - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "what is 1+1?" - }], - "max_tokens": - 5 - } - - response = requests.post(url, headers=headers, json=data) - response_data = response.json() - print(response_data) - assert response_data.get("model") == MODEL_NAME - choice = response_data.get("choices")[0] - message = choice.get("message") - assert message is not None - content = message.get("content") - assert content is not None - assert len(content) > 0 - - -@pytest.mark.asyncio -async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): - openai_api_key = "EMPTY" - openai_api_base = f"http://localhost:{server.port}/v1" - - client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, - ) +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): messages = [ - { - "role": "user", - "content": "Hello, vLLM!" - }, + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, ] - response = client.chat.completions.create( - model="", # empty string - messages=messages, - ) - assert response.model == MODEL_NAME - - -@pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] request_args = { "model": MODEL_NAME, @@ -1044,8 +918,9 @@ async def test_invocations(server: RemoteOpenAIServer, chat_completion = await client.chat.completions.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_completion.model_dump() diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py index de63f4ed218b6..b3b8b700336db 100644 --- a/tests/entrypoints/openai/test_chat_echo.py +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -7,12 +7,23 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +from vllm.config import ModelConfig + from ...utils import RemoteOpenAIServer # # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +def get_vocab_size(model_name): + config = ModelConfig( + model=model_name, + seed=0, + dtype="float16", + ) + return config.get_vocab_size() + + @pytest.fixture(scope="module") def server(): args = [ @@ -22,6 +33,8 @@ def server(): "--enforce-eager", "--max-model-len", "4080", + "--max-logprobs", # test prompt_logprobs equal to -1 + "151936", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -44,27 +57,26 @@ class TestCase(NamedTuple): "test_case", [ TestCase(model_name=MODEL_NAME, echo=True), - TestCase(model_name=MODEL_NAME, echo=False) + TestCase(model_name=MODEL_NAME, echo=False), ], ) async def test_chat_session_with_echo_and_continue_final_message( - client: openai.AsyncOpenAI, test_case: TestCase): + client: openai.AsyncOpenAI, test_case: TestCase +): saying: str = "Here is a common saying about apple. 
An apple a day, keeps" # test echo with continue_final_message parameter chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], extra_body={ "echo": test_case.echo, "continue_final_message": True, - "add_generation_prompt": False - }) + "add_generation_prompt": False, + }, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 @@ -77,3 +89,44 @@ async def test_chat_session_with_echo_and_continue_final_message( else: assert message.content is not None and saying not in message.content assert message.role == "assistant" + + +@pytest.mark.asyncio +async def test_prompt_logprobs(client: openai.AsyncOpenAI): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Beijing is the capital of which country?"}, + ] + + completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body={"prompt_logprobs": -1}, + ) + + assert completion.prompt_logprobs is not None + assert len(completion.prompt_logprobs) > 0 + + +@pytest.mark.asyncio +async def test_top_logprobs(client: openai.AsyncOpenAI): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Beijing is the capital of which country?"}, + ] + + completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1, + extra_body={ + "top_logprobs": -1, + "logprobs": "true", + }, + ) + assert completion.choices[0].logprobs is not None + assert completion.choices[0].logprobs.content is not None + assert len(completion.choices[0].logprobs.content) > 0 + assert len( + completion.choices[0].logprobs.content[0].top_logprobs + ) == get_vocab_size(MODEL_NAME) diff --git a/tests/entrypoints/openai/test_chat_logit_bias_validation.py b/tests/entrypoints/openai/test_chat_logit_bias_validation.py index 9fa7ab83555af..6539613ed17b9 100644 --- a/tests/entrypoints/openai/test_chat_logit_bias_validation.py +++ b/tests/entrypoints/openai/test_chat_logit_bias_validation.py @@ -49,10 +49,7 @@ async def test_chat_logit_bias_valid(client): completion = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing valid logit bias" - }], + messages=[{"role": "user", "content": "Testing valid logit bias"}], max_tokens=5, logit_bias={str(valid_token_id): 1.0}, ) @@ -69,10 +66,7 @@ async def test_chat_logit_bias_invalid(client): with pytest.raises(openai.BadRequestError) as excinfo: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing invalid logit bias" - }], + messages=[{"role": "user", "content": "Testing invalid logit bias"}], max_tokens=5, logit_bias={str(invalid_token_id): 1.0}, ) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 5b6e2a4146b1f..d1202a59752bf 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -4,8 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (apply_hf_chat_template, - load_chat_template) +from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template from 
vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer @@ -17,48 +16,54 @@ assert chatml_jinja_path.exists() # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATION_OUTPUT = [ - ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user + ( + "facebook/opt-125m", + chatml_jinja_path, + True, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -"""), - ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user +""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user -What is the capital of"""), - ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user +What is the capital of""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + True, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -The capital of"""), +The capital of""", + ), ] TEST_MESSAGES = [ - { - 'role': 'user', - 'content': 'Hello' - }, - { - 'role': 'assistant', - 'content': 'Hi there!' - }, - { - 'role': 'user', - 'content': 'What is the capital of' - }, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "What is the capital of"}, ] -ASSISTANT_MESSAGE_TO_CONTINUE = { - 'role': 'assistant', - 'content': 'The capital of' -} +ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"} def test_load_chat_template(): @@ -68,8 +73,11 @@ def test_load_chat_template(): # Test assertions assert template_content is not None # Hard coded value for template_chatml.jinja - assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} + assert ( + template_content + == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 + ) def test_no_load_chat_template_filelike(): @@ -91,9 +99,11 @@ def test_no_load_chat_template_literallike(): @pytest.mark.parametrize( "model,template,add_generation_prompt,continue_final_message,expected_output", - MODEL_TEMPLATE_GENERATION_OUTPUT) -def test_get_gen_prompt(model, template, add_generation_prompt, - continue_final_message, expected_output): + MODEL_TEMPLATE_GENERATION_OUTPUT, +) +def test_get_gen_prompt( + model, template, add_generation_prompt, continue_final_message, expected_output +): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -104,6 +114,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt, trust_remote_code=model_info.trust_remote_code, revision=model_info.revision, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) # Initialize the tokenizer @@ -117,7 +130,8 @@ def test_get_gen_prompt(model, 
template, add_generation_prompt, mock_request = ChatCompletionRequest( model=model, messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] - if continue_final_message else TEST_MESSAGES, + if continue_final_message + else TEST_MESSAGES, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, ) @@ -136,4 +150,5 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Test assertion assert result == expected_output, ( f"The generated prompt does not match the expected output for " - f"model {model} and template {template}") + f"model {model} and template {template}" + ) diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py index 03730b67283c4..e452b578ba22b 100644 --- a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -14,9 +14,14 @@ MODEL_NAME = "Qwen/QwQ-32B" @pytest.fixture(scope="module") def server(): # noqa: F811 args = [ - "--max-model-len", "8192", "--enforce-eager", "--reasoning-parser", - "deepseek_r1", "--enable-auto-tool-choice", "--tool-call-parser", - "hermes" + "--max-model-len", + "8192", + "--enforce-eager", + "--reasoning-parser", + "deepseek_r1", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -29,50 +34,46 @@ async def client(server): yield async_client -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. " + "'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that " + "the city is in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] +] -MESSAGES = [{ - "role": "user", - "content": "Hi! How are you doing today?" -}, { - "role": "assistant", - "content": "I'm doing well! How can I help you?" -}, { - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" -}] +MESSAGES = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! 
How can I help you?"}, + { + "role": "user", + "content": "Can you tell me what the temperate will be in Dallas, " + "in fahrenheit?", + }, +] FUNC_NAME = "get_current_weather" FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}""" @@ -105,9 +106,7 @@ def extract_reasoning_and_calls(chunks: list): # test streaming @pytest.mark.asyncio -async def test_chat_streaming_of_tool_and_reasoning( - client: openai.AsyncOpenAI): - +async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -120,8 +119,7 @@ async def test_chat_streaming_of_tool_and_reasoning( async for chunk in stream: chunks.append(chunk) - reasoning_content, arguments, function_names = extract_reasoning_and_calls( - chunks) + reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks) assert len(reasoning_content) > 0 assert len(function_names) > 0 and function_names[0] == FUNC_NAME assert len(arguments) > 0 and arguments[0] == FUNC_ARGS @@ -130,7 +128,6 @@ async def test_chat_streaming_of_tool_and_reasoning( # test full generate @pytest.mark.asyncio async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): - tool_calls = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -140,7 +137,5 @@ async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): ) assert len(tool_calls.choices[0].message.reasoning_content) > 0 - assert tool_calls.choices[0].message.tool_calls[0].function.name \ - == FUNC_NAME - assert tool_calls.choices[0].message.tool_calls[0].function.arguments \ - == FUNC_ARGS + assert tool_calls.choices[0].message.tool_calls[0].function.name == FUNC_NAME + assert tool_calls.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS diff --git a/tests/entrypoints/openai/test_chunked_prompt.py b/tests/entrypoints/openai/test_chunked_prompt.py index c8160c5f2d0e3..608e509e59e8a 100644 --- a/tests/entrypoints/openai/test_chunked_prompt.py +++ b/tests/entrypoints/openai/test_chunked_prompt.py @@ -40,7 +40,8 @@ async def client(server): @pytest.mark.asyncio async def test_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt prompt = "What is the capital of France?" * 400 @@ -62,8 +63,9 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: tokens_received += 1 assert chunk.choices[0].text @@ -77,15 +79,13 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( @pytest.mark.asyncio async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" * 400 - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?" 
* 400}, + ] stream = await client.chat.completions.create( model=MODEL_NAME, messages=messages, @@ -106,8 +106,9 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: if chunk.choices[0].delta.content == "": diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index b20838956d721..0b9d171aa4818 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -5,8 +5,7 @@ import json import pytest -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.utils import FlexibleArgumentParser @@ -15,7 +14,7 @@ from ...utils import VLLM_PATH LORA_MODULE = { "name": "module2", "path": "/path/to/module2", - "base_model_name": "llama" + "base_model_name": "llama", } CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja" assert CHATML_JINJA_PATH.exists() @@ -27,27 +26,55 @@ def serve_parser(): return make_arg_parser(parser) +### Test config parsing +def test_config_arg_parsing(serve_parser, cli_config_file): + args = serve_parser.parse_args([]) + assert args.port == 8000 + args = serve_parser.parse_args(["--config", cli_config_file]) + assert args.port == 12312 + args = serve_parser.parse_args( + [ + "--config", + cli_config_file, + "--port", + "9000", + ] + ) + assert args.port == 9000 + args = serve_parser.parse_args( + [ + "--port", + "9000", + "--config", + cli_config_file, + ] + ) + assert args.port == 9000 + + ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - ]) - expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + ] + ) + expected = [LoRAModulePath(name="module1", path="/path/to/module1")] assert args.lora_modules == expected def test_valid_json_format(serve_parser): # Test valid JSON format input - args = serve_parser.parse_args([ - '--lora-modules', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module2", path="/path/to/module2", base_model_name="llama") ] assert args.lora_modules == expected @@ -55,47 +82,53 @@ def test_valid_json_format(serve_parser): def test_invalid_json_format(serve_parser): # Test invalid JSON format input, missing closing brace with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', '{"name": "module3", "path": "/path/to/module3"' - ]) + serve_parser.parse_args( + ["--lora-modules", '{"name": "module3", "path": "/path/to/module3"'] + ) def test_invalid_type_error(serve_parser): # Test type error when values are not JSON or key=value with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - 'invalid_format' # This is not JSON or 
key=value format - ]) + serve_parser.parse_args( + [ + "--lora-modules", + "invalid_format", # This is not JSON or key=value format + ] + ) def test_invalid_json_field(serve_parser): # Test valid JSON format but missing required fields with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - '{"name": "module4"}' # Missing required 'path' field - ]) + serve_parser.parse_args( + [ + "--lora-modules", + '{"name": "module4"}', # Missing required 'path' field + ] + ) def test_empty_values(serve_parser): # Test when no LoRA modules are provided - args = serve_parser.parse_args(['--lora-modules', '']) + args = serve_parser.parse_args(["--lora-modules", ""]) assert args.lora_modules == [] def test_multiple_valid_inputs(serve_parser): # Test multiple valid inputs (both old and JSON format) - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module1', path='/path/to/module1'), - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module1", path="/path/to/module1"), + LoRAModulePath( + name="module2", path="/path/to/module2", base_model_name="llama" + ), ] assert args.lora_modules == expected @@ -111,40 +144,46 @@ def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser): def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser): """Ensure validation passes with tool choice enabled with a call parser""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--tool-call-parser", - "mistral", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--tool-call-parser", + "mistral", + ] + ) validate_parsed_serve_args(args) def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser): """Ensure validation fails if reasoning is enabled with auto tool choice""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--reasoning-parser", + "deepseek_r1", + ] + ) with pytest.raises(TypeError): validate_parsed_serve_args(args) def test_passes_with_reasoning_parser(serve_parser): - """Ensure validation passes if reasoning is enabled + """Ensure validation passes if reasoning is enabled with a reasoning parser""" - args = serve_parser.parse_args(args=[ - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--reasoning-parser", + "deepseek_r1", + ] + ) validate_parsed_serve_args(args) def test_chat_template_validation_for_happy_paths(serve_parser): """Ensure validation passes if the chat template exists""" args = serve_parser.parse_args( - args=["--chat-template", - CHATML_JINJA_PATH.absolute().as_posix()]) + args=["--chat-template", CHATML_JINJA_PATH.absolute().as_posix()] + ) validate_parsed_serve_args(args) @@ -157,8 +196,14 @@ def test_chat_template_validation_for_sad_paths(serve_parser): @pytest.mark.parametrize( "cli_args, expected_middleware", - [(["--middleware", "middleware1", "--middleware", "middleware2" - ], ["middleware1", "middleware2"]), ([], [])]) + [ + ( + ["--middleware", "middleware1", "--middleware", "middleware2"], + ["middleware1", "middleware2"], + ), + ([], []), + ], +) def test_middleware(serve_parser, cli_args, expected_middleware): 
"""Ensure multiple middleware args are parsed properly""" args = serve_parser.parse_args(args=cli_args) diff --git a/tests/entrypoints/openai/test_collective_rpc.py b/tests/entrypoints/openai/test_collective_rpc.py index 37c0b7a900ac4..cbd6b02f05dce 100644 --- a/tests/entrypoints/openai/test_collective_rpc.py +++ b/tests/entrypoints/openai/test_collective_rpc.py @@ -12,7 +12,6 @@ MODEL_NAME = "Qwen/Qwen3-0.6B" class TestWorkerExtension: - def get_model_name(self) -> str: """Test non-pydantic return type.""" return MODEL_NAME @@ -41,20 +40,18 @@ def server(): "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", ] with RemoteOpenAIServer( - MODEL_NAME, - args, - env_dict={ - "VLLM_SERVER_DEV_MODE": "1", - "CUDA_VISIBLE_DEVICES": "0" - }, + MODEL_NAME, + args, + env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"}, ) as remote_server: yield remote_server def test_get_model_name(server): """Test basic response""" - response = requests.post(server.url_for("collective_rpc"), - json={"method": "get_model_name"}) + response = requests.post( + server.url_for("collective_rpc"), json={"method": "get_model_name"} + ) assert response.status_code == 200 results = response.json() assert "results" in results @@ -63,8 +60,9 @@ def test_get_model_name(server): def test_return_none(server): """Test return none""" - response = requests.post(server.url_for("collective_rpc"), - json={"method": "return_none"}) + response = requests.post( + server.url_for("collective_rpc"), json={"method": "return_none"} + ) assert response.status_code == 200 results = response.json() assert results["results"] == [None] @@ -74,12 +72,10 @@ def test_echo_args_kwargs(server): """Test args, kwargs, and dict response""" args = ["arg1", "arg2"] kwargs = {"key1": "value1", "key2": "value2"} - response = requests.post(server.url_for("collective_rpc"), - json={ - "method": "echo_args_kwargs", - "args": args, - "kwargs": kwargs - }) + response = requests.post( + server.url_for("collective_rpc"), + json={"method": "echo_args_kwargs", "args": args, "kwargs": kwargs}, + ) assert response.status_code == 200 results = response.json() result = results["results"][0] diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py deleted file mode 100644 index 74ef6deeea16b..0000000000000 --- a/tests/entrypoints/openai/test_completion.py +++ /dev/null @@ -1,872 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests -import json -import os -import shutil -from tempfile import TemporaryDirectory -from typing import Optional - -import jsonschema -import openai # use the official client for correctness check -import pytest -import pytest_asyncio -import regex as re -import requests -# downloading lora to test lora requests -from huggingface_hub import snapshot_download -from openai import BadRequestError -from transformers import AutoTokenizer - -from vllm.transformers_utils.tokenizer import get_tokenizer - -from ...utils import RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically these adapters use a different base model, -# but we're not testing generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" - -GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"] - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - 
-@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - -@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): - return [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--max-num-seqs", - "128", - "--enforce-eager", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - ] - - -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) -def server(default_server_args, request): - if request.param: - default_server_args.append(request.param) - - original_value = os.environ.get('VLLM_USE_V1') - os.environ['VLLM_USE_V1'] = '0' - try: - with RemoteOpenAIServer(MODEL_NAME, - default_server_args) as remote_server: - yield remote_server - finally: - # Restore original env value - if original_value is None: - os.environ.pop('VLLM_USE_V1', None) - else: - os.environ['VLLM_USE_V1'] = original_value - - -@pytest.fixture -def is_v1_server(server): - import os - - # For completion tests, we assume v0 since there's no explicit v1 setup - return os.environ.get('VLLM_USE_V1', '0') == '1' - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], -) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 1 - assert completion.choices[0].prompt_logprobs is None - - -@pytest.mark.asyncio -async def test_added_lora_tokens(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model="zephyr-lora2", - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should appear in tokenized prompt - assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3") - - -@pytest.mark.asyncio -async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): - # test using token IDs - with pytest.raises(openai.BadRequestError, match="out of vocabulary"): - # Added tokens should be rejected by the base model - await 
client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], -) -async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=None, - ) - choice = completion.choices[0] - assert choice.logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # just test 1 lora - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=0, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert len(choice.logprobs.top_logprobs[0]) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=5, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str): - - with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=21, - ) - ... - with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - stream = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=30, - stream=True, - ) - async for chunk in stream: - ... 
- - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), - (MODEL_NAME, 0), - (MODEL_NAME, 1), - (MODEL_NAME, None)]) -async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): - params: dict = { - "prompt": ["A robot may not injure another robot", "My name is"], - "model": model_name, - } - if prompt_logprobs is not None: - params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - - if prompt_logprobs is not None and prompt_logprobs < 0: - with pytest.raises(BadRequestError): - await client.completions.create(**params) - else: - completion = await client.completions.create(**params) - if prompt_logprobs is not None: - assert completion.choices[0].prompt_logprobs is not None - assert len(completion.choices[0].prompt_logprobs) > 0 - - assert completion.choices[1].prompt_logprobs is not None - assert len(completion.choices[1].prompt_logprobs) > 0 - - else: - assert completion.choices[0].prompt_logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is an LLM?" - - single_completion = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - ) - single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) - chunks: list[str] = [] - finish_reason_count = 0 - async for chunk in stream: - chunks.append(chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - # finish reason should only return in last block - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert "".join(chunks) == single_output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): - """Streaming for parallel sampling. - The tokens from multiple samples, are flattened into a single stream, - with an index to indicate which sample the token belongs to. - """ - - prompt = "What is an LLM?" - n = 3 - max_tokens = 5 - - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - stream=True) - chunks: list[list[str]] = [[] for i in range(n)] - finish_reason_count = 0 - async for chunk in stream: - index = chunk.choices[0].index - text = chunk.choices[0].text - chunks[index].append(text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert finish_reason_count == n - for chunk in chunks: - assert len(chunk) == max_tokens - print("".join(chunk)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is the capital of France?" 
- - # Test stream=True, stream_options= - # {"include_usage": False, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - False, - }) - - async for chunk in stream: - assert chunk.usage is None - - # Test stream=True, stream_options= - # {"include_usage": False, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - True, - }) - async for chunk in stream: - assert chunk.usage is None - - # Test stream=True, stream_options= - # {"include_usage": True, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - False, - }) - async for chunk in stream: - if chunk.choices[0].finish_reason is None: - assert chunk.usage is None - else: - assert chunk.usage is None - final_chunk = await stream.__anext__() - assert final_chunk.usage is not None - assert final_chunk.usage.prompt_tokens > 0 - assert final_chunk.usage.completion_tokens > 0 - assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) - assert final_chunk.choices == [] - - # Test stream=True, stream_options= - # {"include_usage": True, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - True, - }) - async for chunk in stream: - assert chunk.usage is not None - assert chunk.usage.prompt_tokens > 0 - assert chunk.usage.completion_tokens > 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) - if chunk.choices[0].finish_reason is not None: - final_chunk = await stream.__anext__() - assert final_chunk.usage is not None - assert final_chunk.usage.prompt_tokens > 0 - assert final_chunk.usage.completion_tokens > 0 - assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) - assert final_chunk.choices == [] - - # Test stream=False, stream_options= - # {"include_usage": None} - with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}) - - # Test stream=False, stream_options= - # {"include_usage": True} - with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}) - - # Test stream=False, stream_options= - # {"continuous_usage_stats": None} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"continuous_usage_stats": None}) - - # Test stream=False, stream_options= - # {"continuous_usage_stats": True} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - 
stream_options={"continuous_usage_stats": True}) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora"], -) -async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): - # test both text and token IDs - for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): - # test simple list - batch = await client.completions.create( - model=model_name, - prompt=prompts, - max_tokens=5, - temperature=0.0, - ) - assert len(batch.choices) == 2 - assert batch.choices[0].text == batch.choices[1].text - - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=prompts, - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but - # not necessary for official client. - use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" - - # test streaming - batch = await client.completions.create( - model=model_name, - prompt=prompts, - max_tokens=5, - temperature=0.0, - stream=True, - ) - texts = [""] * 2 - async for chunk in batch: - assert len(chunk.choices) == 1 - choice = chunk.choices[0] - texts[choice.index] += choice.text - assert texts[0] == texts[1] - - -@pytest.mark.asyncio -async def test_logits_bias(client: openai.AsyncOpenAI): - prompt = "Hello, my name is" - max_tokens = 5 - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - - # Test exclusive selection - token_id = 1000 - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token_id): 100}, - seed=42, - ) - assert len(completion.choices[0].text) >= 5 - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), - add_special_tokens=False)["input_ids"] - assert all([ - response == expected - for response, expected in zip(response_tokens, expected_tokens) - ]) - - # Test ban - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - first_response = completion.choices[0].text - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token): -100 - for token in response_tokens}, - ) - assert first_response != completion.choices[0].text - - -@pytest.mark.asyncio -async def test_allowed_token_ids(client: openai.AsyncOpenAI): - prompt = "Hello, my name is" - max_tokens = 1 - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - - # Test exclusive selection - allowed_ids = [21555, 21557, 21558] - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - seed=42, - extra_body=dict(allowed_token_ids=allowed_ids), - logprobs=1, - ) - response_tokens = completion.choices[0].logprobs.tokens - assert len(response_tokens) == 1 - assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids - - -@pytest.mark.asyncio 
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}", - n=3, - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - output_json = json.loads(completion.choices[i].text) - jsonschema.validate(instance=output_json, schema=sample_json_schema) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex, is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example IPv4 address with this regex: {sample_regex}", - n=3, - temperature=1.0, - max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - assert re.fullmatch(sample_regex, - completion.choices[i].text) is not None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_choice_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice, - is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - completion = await client.completions.create( - model=MODEL_NAME, - prompt="The best language for type-safe systems programming is ", - n=2, - temperature=1.0, - max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 2 - for i in range(2): - assert completion.choices[i].text in sample_guided_choice - - -@pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements, is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided grammar is only supported in v1 engine") - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) - - content = completion.choices[0].text - - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_statements) - parser.parse(content) - - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") - - assert content.strip() == ground_truth - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], -) -@pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: 
openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - # test using text and token IDs - for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) - - prompt_text = tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt - assert re.search(r"^" + prompt_text, completion.choices[0].text) - logprobs = completion.choices[0].logprobs - assert logprobs is not None - assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) - for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) > 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex, - is_v1_server: bool): - if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name,stream,echo", - [ - (MODEL_NAME, False, False), - (MODEL_NAME, False, True), - (MODEL_NAME, True, False), - (MODEL_NAME, True, True) # should not raise BadRequestError error - ], -) -async def test_echo_stream_completion(client: openai.AsyncOpenAI, - model_name: str, stream: bool, - echo: bool): - saying: str = "Hello, my name is" - result = await client.completions.create(model=model_name, - prompt=saying, - max_tokens=10, - temperature=0.0, - echo=echo, - stream=stream) - - stop_reason = "length" - - if not stream: - completion = result - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == stop_reason - - if echo: - assert choice.text is not None and saying in choice.text - else: - assert choice.text is not None and saying not in choice.text - - else: - chunks: list[str] = [] - final_finish_reason = None - async for chunk in result: - if chunk.choices and chunk.choices[0].text: - chunks.append(chunk.choices[0].text) - if chunk.choices and chunk.choices[0].finish_reason: - final_finish_reason = chunk.choices[0].finish_reason - - assert final_finish_reason == stop_reason - content = "".join(chunks) - if echo: - assert content is not None and saying in content - else: - assert content is not None and saying not in content - - -@pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): - request_args = { - "model": MODEL_NAME, - "prompt": "Hello, my name is", - "max_tokens": 5, - "temperature": 0.0, - "logprobs": None, - } - - completion = await 
client.completions.create(**request_args) - - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) - invocation_response.raise_for_status() - - completion_output = completion.model_dump() - invocation_output = invocation_response.json() - - assert completion_output.keys() == invocation_output.keys() - assert completion_output["choices"] == invocation_output["choices"] diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 4ef5d4e8a699a..e64f68cad7c83 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import datetime from typing import Union import openai # use the official client for correctness check @@ -24,15 +25,14 @@ tools = [ "properties": { "city": { "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", + "description": "The city to find the weather for, e.g. " + "'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. " + "'Austria'", }, "unit": { "type": "string", @@ -61,8 +61,7 @@ tools = [ "include_forecast": { "type": "boolean", "default": False, - "description": - "Whether to include a 24-hour forecast", + "description": "Whether to include a 24-hour forecast", "title": "Include Forecast", }, "language": { @@ -88,21 +87,18 @@ tools = [ "properties": { "city": { "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", + "description": "The city to get the forecast for, e.g. " + "'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. " + "'Austria'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, "unit": { "type": "string", @@ -117,19 +113,11 @@ tools = [ ] messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, { "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" 
- }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ + "content": "Can you tell me what the current weather is in Berlin and the " "forecast for the next 5 days, in fahrenheit?", }, ] @@ -142,14 +130,14 @@ def server(): # noqa: F811 "--dtype", "half", "--enable-auto-tool-choice", - "--guided-decoding-backend", + "--structured-outputs-config.backend", "xgrammar", "--tool-call-parser", "hermes", "--reasoning-parser", "qwen3", "--gpu-memory-utilization", - "0.4" + "0.4", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -165,18 +153,22 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.parametrize("tool_choice", [ - "auto", "required", { - "type": "function", - "function": { - "name": "get_current_weather" - } - } -]) +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + "required", + {"type": "function", "function": {"name": "get_current_weather"}}, + ], +) @pytest.mark.parametrize("enable_thinking", [True, False]) -async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, - stream: bool, tool_choice: Union[str, dict], - enable_thinking: bool): +async def test_function_tool_use( + client: openai.AsyncOpenAI, + model_name: str, + stream: bool, + tool_choice: Union[str, dict], + enable_thinking: bool, +): if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -184,16 +176,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, model=model_name, tools=tools, tool_choice=tool_choice, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) if enable_thinking: - assert chat_completion.choices[0].message.\ - reasoning_content is not None - assert chat_completion.choices[0].message.\ - reasoning_content != "" + assert chat_completion.choices[0].message.reasoning_content is not None + assert chat_completion.choices[0].message.reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: @@ -204,11 +191,8 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, tools=tools, tool_choice=tool_choice, stream=True, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) output = [] async for chunk in output_stream: @@ -225,7 +209,7 @@ def k2_server(): # noqa: F811 "--dtype", "half", "--enable-auto-tool-choice", - "--guided-decoding-backend", + "--structured-outputs-config.backend", "xgrammar", "--tool-call-parser", "hermes", @@ -236,12 +220,11 @@ def k2_server(): # noqa: F811 ] # hack to test kimi_k2 tool use tool_id format. 
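# A minimal sketch of how a client might interpret the kimi_k2 tool-call id
# format exercised below ("functions.<tool_name>:<index>"). This helper is
# hypothetical and not part of vLLM or of this test file; it only assumes the
# id shape that test_tool_id_kimi_k2 asserts ("functions.get_current_weather:0").
def parse_kimi_k2_tool_call_id(tool_call_id: str) -> tuple[str, int]:
    namespaced_name, _, index = tool_call_id.rpartition(":")
    _prefix, _, function_name = namespaced_name.partition(".")
    return function_name, int(index)


assert parse_kimi_k2_tool_call_id("functions.get_current_weather:0") == (
    "get_current_weather",
    0,
)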
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null - with RemoteOpenAIServer(MODEL_NAME, - args, - override_hf_configs={ - "model_type": 'kimi_k2', - 'kv_lora_rank': None - }) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, + args, + override_hf_configs={"model_type": "kimi_k2", "kv_lora_rank": None}, + ) as remote_server: yield remote_server @@ -255,20 +238,20 @@ async def k2_client(k2_server): @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("tool_choice", ["required"]) -async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, - stream: bool, tool_choice: str): - +async def test_tool_id_kimi_k2( + k2_client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: str +): if not stream: # Non-streaming test chat_completion = await k2_client.chat.completions.create( - messages=messages, - model=model_name, - tools=tools, - tool_choice=tool_choice) + messages=messages, model=model_name, tools=tools, tool_choice=tool_choice + ) assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 - assert chat_completion.choices[0].message.tool_calls[ - 0].id == 'functions.get_current_weather:0' + assert ( + chat_completion.choices[0].message.tool_calls[0].id + == "functions.get_current_weather:0" + ) else: # Streaming test output_stream = await k2_client.chat.completions.create( @@ -276,11 +259,75 @@ async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, model=model_name, tools=tools, tool_choice=tool_choice, - stream=True) + stream=True, + ) output = [] async for chunk in output_stream: if chunk.choices and chunk.choices[0].delta.tool_calls: output.extend(chunk.choices[0].delta.tool_calls) for o in output: - assert o.id is None or o.id == 'functions.get_current_weather:0' + assert o.id is None or o.id == "functions.get_current_weather:0" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("arguments", ["{}", ""]) +async def test_no_args_tool_call( + client: openai.AsyncOpenAI, model_name: str, arguments: str +): + # Step 1: Define a tool that requires no parameters + tools = [ + { + "type": "function", + "function": { + "name": "get_current_time", + "description": "Get the current date and time. 
No parameters needed.", + "parameters": { + "type": "object", + "properties": {}, # No parameters + "required": [], # No required fields + }, + }, + } + ] + messages = [{"role": "user", "content": "What time is it now?"}] + # Step 2: Send user message and let model decide whether to call the tool + response = await client.chat.completions.create( + model=model_name, + messages=messages, + tools=tools, + tool_choice="auto", # Let model choose automatically + ) + + # Step 3: Check if model wants to call a tool + message = response.choices[0].message + if message.tool_calls: + # Get the first tool call + tool_call = message.tool_calls[0] + tool_name = tool_call.function.name + # Step 4: Execute the tool locally (no parameters) + if tool_name == "get_current_time": + # Test both empty string and "{}" for no-arg tool calls + tool_call.function.arguments = arguments + messages.append(message) + current_time = datetime.datetime.now() + result = current_time.isoformat() + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + # Step 5: Send tool result back to model to continue conversation + final_response = await client.chat.completions.create( + model=model_name, + messages=messages, + ) + # Output final natural language response + assert final_response.choices[0].message.content is not None + + else: + # No tool called — just print model's direct reply + assert message.content is not None diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 00d3ffb61ee9f..3ed98ffe0e399 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -3,70 +3,87 @@ import base64 import io -import shutil -from tempfile import TemporaryDirectory +import json import openai # use the official client for correctness check import pytest import pytest_asyncio import torch + # downloading lora to test lora requests -from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig from ...utils import RemoteOpenAIServer # any model with a chat template should work here -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -LORA_NAME = "typeof/zephyr-7b-beta-lora" +MODEL_NAME = "facebook/opt-125m" +LORA_SERVING_MODEL_NAME = "opt125m-lora" CONFIG = AutoConfig.from_pretrained(MODEL_NAME) -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", params=["use-lora"]) def default_server_args( - zephyr_lora_files, - zephyr_lora_added_tokens_files, + request: pytest.FixtureRequest, opt125_lora_files: str ) -> list[str]: - return [ + args = [ # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", "--max-model-len", - "8192", + 
"2048", "--max-num-seqs", "128", "--enforce-eager", # Prompt Embeds server args "--enable-prompt-embeds", - "--no-enable-chunked-prefill", ] + if request.param == "use-lora": + lora_module_1 = { + "name": LORA_SERVING_MODEL_NAME, + "path": opt125_lora_files, + "base_model_name": MODEL_NAME, + } -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) + args.extend( + [ + "--enable-lora", + "--lora-module", + json.dumps(lora_module_1), + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + ] + ) + + return args + + +EXAMPLE_PROMPTS = [ + "Hello, my name is", + "What is an LLM?", +] + + +def _encode_embeds(embeds: torch.Tensor): + buffer = io.BytesIO() + torch.save(embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +@pytest.fixture(scope="module") +def example_prompt_embeds(hf_runner): + """Create example embeddings and return them as base64 encoded string.""" + with hf_runner(MODEL_NAME) as hf_model: + example_embeddings = hf_model.get_prompt_embeddings(EXAMPLE_PROMPTS) + + return [_encode_embeds(item) for item in example_embeddings] + + +@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) def server_with_prompt_embeds(default_server_args, request): if request.param: default_server_args.append(request.param) @@ -81,49 +98,46 @@ async def client_with_prompt_embeds(server_with_prompt_embeds): yield async_client -def create_dummy_embeds(num_tokens: int = 5) -> str: - """Create dummy embeddings and return them as base64 encoded string.""" - dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) - buffer = io.BytesIO() - torch.save(dummy_embeds, buffer) - return base64.b64encode(buffer.getvalue()).decode('utf-8') - - @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME]) async def test_completions_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, + model_name: str, +): + encoded_embeds, encoded_embeds2 = example_prompt_embeds + # Test case: Single prompt embeds input - encoded_embeds = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices[0].text) >= 1 assert completion.choices[0].prompt_logprobs is None # Test case: batch completion with prompt_embeds - encoded_embeds2 = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 assert len(completion.choices[0].text) >= 1 assert len(completion.choices[1].text) >= 1 # Test case: streaming with prompt_embeds - encoded_embeds = create_dummy_embeds() single_completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) single_output = single_completion.choices[0].text stream = await 
client_with_prompt_embeds.completions.create( @@ -132,7 +146,8 @@ async def test_completions_with_prompt_embeds( max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) chunks = [] finish_reason_count = 0 async for chunk in stream: @@ -145,19 +160,18 @@ async def test_completions_with_prompt_embeds( assert "".join(chunks) == single_output # Test case: batch streaming with prompt_embeds - encoded_embeds2 = create_dummy_embeds() stream = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) chunks_stream_embeds: list[list[str]] = [[], []] finish_reason_count = 0 async for chunk in stream: - chunks_stream_embeds[chunk.choices[0].index].append( - chunk.choices[0].text) + chunks_stream_embeds[chunk.choices[0].index].append(chunk.choices[0].text) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 assert finish_reason_count == 2 @@ -167,13 +181,13 @@ async def test_completions_with_prompt_embeds( assert len(chunks_stream_embeds[1]) > 0 # Test case: mixed text and prompt_embeds - encoded_embeds = create_dummy_embeds() completion_mixed = await client_with_prompt_embeds.completions.create( model=model_name, prompt="This is a prompt", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices) == 2 completion_text_only = await client_with_prompt_embeds.completions.create( model=model_name, @@ -186,18 +200,18 @@ async def test_completions_with_prompt_embeds( prompt="", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) # Embeddings responses should be handled first - assert completion_mixed.choices[0].text == completion_embeds_only.choices[ - 0].text - assert completion_mixed.choices[1].text == completion_text_only.choices[ - 0].text + assert completion_mixed.choices[0].text == completion_embeds_only.choices[0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[0].text @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME]) async def test_completions_errors_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str +): # Test error case: invalid prompt_embeds with pytest.raises(BadRequestError): await client_with_prompt_embeds.completions.create( @@ -205,17 +219,22 @@ async def test_completions_errors_with_prompt_embeds( model=model_name, max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": "invalid_base64"}) + extra_body={"prompt_embeds": "invalid_base64"}, + ) @pytest.mark.asyncio @pytest.mark.parametrize("logprobs_arg", [1, 0]) -@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME]) async def test_completions_with_logprobs_and_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, - model_name: str): + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, + logprobs_arg: int, + model_name: str, +): + encoded_embeds, 
encoded_embeds2 = example_prompt_embeds + # Test case: Logprobs using prompt_embeds - encoded_embeds = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter @@ -223,7 +242,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) logprobs = completion.choices[0].logprobs assert logprobs is not None @@ -235,7 +255,6 @@ async def test_completions_with_logprobs_and_prompt_embeds( assert len(logprobs.tokens) == 5 # Test case: Log probs with batch completion and prompt_embeds - encoded_embeds2 = create_dummy_embeds() completion = await client_with_prompt_embeds.completions.create( model=model_name, prompt="", # Add empty prompt as required parameter @@ -243,7 +262,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 for choice in completion.choices: @@ -253,6 +273,22 @@ async def test_completions_with_logprobs_and_prompt_embeds( assert len(logprobs.token_logprobs) == 5 assert len(logprobs.top_logprobs) == 5 for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) == 5 + + +@pytest.mark.asyncio +async def test_prompt_logprobs_raises_error( + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, +): + encoded_embeds, _ = example_prompt_embeds + + with pytest.raises(BadRequestError, match="not compatible"): + await client_with_prompt_embeds.completions.create( + model=MODEL_NAME, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, + ) diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py index b9c466a6fbeb6..336bda81a9ef2 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/test_default_mm_loras.py @@ -16,8 +16,7 @@ from ...utils import RemoteOpenAIServer # need a multimodal model for these tests. # Contains a modality specific lora alongside the base model -MULTIMODAL_MODEL_NAME = snapshot_download( - "microsoft/Phi-4-multimodal-instruct") +MULTIMODAL_MODEL_NAME = snapshot_download("microsoft/Phi-4-multimodal-instruct") AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora") ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." 
# noqa: E501 @@ -25,7 +24,6 @@ ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original @pytest.fixture(scope="module") def multimodal_server(): # noqa: F811 - args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -45,11 +43,12 @@ def multimodal_server(): # noqa: F811 "--gpu-memory-utilization", "0.8", "--default-mm-loras", - f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}", + f'{{"audio": "{AUDIO_LORA_PATH}"}}', ] - with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args, - max_wait_seconds=480) as remote_server: + with RemoteOpenAIServer( + MULTIMODAL_MODEL_NAME, args, max_wait_seconds=480 + ) as remote_server: yield remote_server @@ -70,25 +69,25 @@ async def test_default_mm_lora_chat_completions( multi_modal_client: openai.AsyncOpenAI, audio_assets: AudioTestAssets, ): - messages = [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "Can you transcribe this audio?", - }, { - "type": "audio_url", - "audio_url": { - "url": audio_assets[0].url - }, - }] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you transcribe this audio?", + }, + { + "type": "audio_url", + "audio_url": {"url": audio_assets[0].url}, + }, + ], + } + ] chat_completion = await multi_modal_client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=128, - temperature=0.0) + model=model_name, messages=messages, max_completion_tokens=128, temperature=0.0 + ) assert len(chat_completion.choices) > 0 diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py deleted file mode 100644 index 9c2aef23e8772..0000000000000 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import openai -import pytest -import pytest_asyncio - -from ...utils import RemoteOpenAIServer - -MODEL_NAME = "facebook/bart-base" - - -@pytest.fixture(scope="module") -def server(): - args = [ - "--dtype", - "bfloat16", - "--enforce-eager", - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=2, total_tokens=7) - - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 1 diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index bcdeaaacedea0..c74f805961bc8 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -9,8 +9,6 @@ from contextlib import suppress import openai # use the official client for 
correctness check import pytest import pytest_asyncio -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from ...utils import RemoteOpenAIServer @@ -18,67 +16,29 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" BADREQUEST_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - { - "bias": "all" - }, - "Adapter bias cannot be used without bias_enabled", - ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module", params=[False, True]) -def server_with_lora_modules_json(request, monkeypatch_module, - zephyr_lora_files): - - use_v1 = request.param - monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') - +@pytest.fixture(scope="module", params=[True]) +def server_with_lora_modules_json(request, zephyr_lora_files): # Define the json format LoRA module configurations lora_module_1 = { "name": "zephyr-lora", "path": zephyr_lora_files, - "base_model_name": MODEL_NAME - } - - lora_module_2 = { - "name": "zephyr-lora2", - "path": zephyr_lora_files, - "base_model_name": MODEL_NAME + "base_model_name": MODEL_NAME, } args = [ @@ -92,7 +52,6 @@ def server_with_lora_modules_json(request, monkeypatch_module, "--enable-lora", "--lora-modules", json.dumps(lora_module_1), - json.dumps(lora_module_2), "--max-lora-rank", "64", "--max-cpu-loras", @@ -110,14 +69,12 @@ def server_with_lora_modules_json(request, monkeypatch_module, @pytest_asyncio.fixture async def client(server_with_lora_modules_json): - async with server_with_lora_modules_json.get_async_client( - ) as async_client: + async with server_with_lora_modules_json.get_async_client() as async_client: yield async_client @pytest.mark.asyncio -async def test_static_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): +async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = models.data served_model = models[0] @@ -125,23 +82,18 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, assert served_model.id == MODEL_NAME assert served_model.root == MODEL_NAME assert served_model.parent is None - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" - assert lora_models[1].id == "zephyr-lora2" @pytest.mark.asyncio -async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): - - response = await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "zephyr-lora-3", - "lora_path": zephyr_lora_files - }) +async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): + 
response = await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files}, + ) # Ensure adapter loads before querying /models assert "success" in response @@ -156,37 +108,37 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, @pytest.mark.asyncio async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI): with pytest.raises(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) @pytest.mark.asyncio -async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, - tmp_path): +async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") with pytest.raises(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid-json", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid-json", "lora_path": str(invalid_files)}, + ) @pytest.mark.asyncio -@pytest.mark.parametrize("test_name,config_change,expected_error", - BADREQUEST_CASES) -async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files, test_name: str, - config_change: dict, - expected_error: str): +@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES) +async def test_dynamic_lora_badrequests( + client: openai.AsyncOpenAI, + tmp_path, + zephyr_lora_files, + test_name: str, + config_change: dict, + expected_error: str, +): # Create test directory test_dir = tmp_path / test_name @@ -206,29 +158,28 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, # Test loading the adapter with pytest.raises(openai.BadRequestError, match=expected_error): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": test_name, - "lora_path": str(test_dir) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": test_name, "lora_path": str(test_dir)}, + ) @pytest.mark.asyncio -async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files): - """Validate that many loras can be dynamically registered and inferenced +async def test_multiple_lora_adapters( + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): + """Validate that many loras can be dynamically registered and inferenced with concurrently""" # This test file configures the server with --max-cpu-loras=2 and this test # will concurrently load 10 adapters, so it should flex the LRU cache async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -238,8 +189,7 @@ async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, lora_tasks = [] for i in range(10): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + 
lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) @@ -249,8 +199,8 @@ async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, @pytest.mark.asyncio async def test_loading_invalid_adapters_does_not_break_others( - client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files): - + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") @@ -281,20 +231,18 @@ async def test_loading_invalid_adapters_does_not_break_others( # Run a bunch of bad adapter loads for _ in range(25): with suppress(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) for _ in range(25): with suppress(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid", "lora_path": str(invalid_files)}, + ) # Ensure all the running requests with lora adapters succeeded stop_good_requests_event.set() @@ -303,12 +251,11 @@ async def test_loading_invalid_adapters_does_not_break_others( assert not isinstance(r, Exception), f"Got exception {r}" # Ensure we can load another adapter and run it - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "valid", - "lora_path": zephyr_lora_files - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "valid", "lora_path": zephyr_lora_files}, + ) await client.completions.create( model="valid", prompt=["Hello there", "Foo bar bazz buzz"], @@ -325,12 +272,11 @@ async def test_beam_search_with_lora_adapters( """Validate that async beam search can be used with lora.""" async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -341,8 +287,7 @@ async def test_beam_search_with_lora_adapters( lora_tasks = [] for i in range(3): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index f4801172580c6..aa4ee603647e4 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -5,19 +5,18 @@ from contextlib import suppress from dataclasses import dataclass, field from http import HTTPStatus from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest -from vllm.config import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse from 
vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @@ -33,37 +32,43 @@ class MockHFConfig: @dataclass class MockModelConfig: """Minimal mock ModelConfig for testing.""" + model: str = MODEL_NAME tokenizer: str = MODEL_NAME trust_remote_code: bool = False tokenizer_mode: str = "auto" max_model_len: int = 100 tokenizer_revision: Optional[str] = None - multimodal_config: MultiModalConfig = field( - default_factory=MultiModalConfig) + multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig) logits_processor_pattern: Optional[str] = None diff_sampling_param: Optional[dict] = None allowed_local_media_path: str = "" + allowed_media_domains: Optional[list[str]] = None encoder_config = None generation_config: str = "auto" + skip_tokenizer_init: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} class MockLoRAResolver(LoRAResolver): - - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test-lora": - return LoRARequest(lora_name="test-lora", - lora_int_id=1, - lora_local_path="/fake/path/test-lora") + return LoRARequest( + lora_name="test-lora", + lora_int_id=1, + lora_local_path="/fake/path/test-lora", + ) elif lora_name == "invalid-lora": - return LoRARequest(lora_name="invalid-lora", - lora_int_id=2, - lora_local_path="/fake/path/invalid-lora") + return LoRARequest( + lora_name="invalid-lora", + lora_int_id=2, + lora_local_path="/fake/path/invalid-lora", + ) return None @@ -81,40 +86,55 @@ def register_mock_resolver(): @pytest.fixture def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" - mock_engine = MagicMock(spec=MQLLMEngineClient) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.errored = False - def mock_add_lora_side_effect(lora_request: LoRARequest): + tokenizer = get_tokenizer(MODEL_NAME) + mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer) + + async def mock_add_lora_side_effect(lora_request: LoRARequest): """Simulate engine behavior when adding LoRAs.""" if lora_request.lora_name == "test-lora": # Simulate successful addition - return - elif lora_request.lora_name == "invalid-lora": + return True + if lora_request.lora_name == "invalid-lora": # Simulate failure during addition (e.g. 
invalid format) - raise ValueError(f"Simulated failure adding LoRA: " - f"{lora_request.lora_name}") + raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}") + return True + + mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect) + + async def mock_generate(*args, **kwargs): + for _ in []: + yield _ + + mock_engine.generate = MagicMock(spec=AsyncLLM.generate, side_effect=mock_generate) - mock_engine.add_lora.side_effect = mock_add_lora_side_effect mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() - mock_model_config = MockModelConfig() - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - serving_completion = OpenAIServingCompletion(mock_engine, - mock_model_config, - models, - request_logger=None) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + ) + + serving_completion = OpenAIServingCompletion( + mock_engine, models, request_logger=None + ) + + serving_completion._process_inputs = AsyncMock( + return_value=(MagicMock(name="engine_request"), {}) + ) return mock_engine, serving_completion @pytest.mark.asyncio -async def test_serving_completion_with_lora_resolver(mock_serving_setup, - monkeypatch): +async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -130,20 +150,19 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup, with suppress(Exception): await serving_completion.create_completion(req_found) - mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name mock_engine.generate.assert_called_once() - called_lora_request = mock_engine.generate.call_args[1]['lora_request'] + called_lora_request = mock_engine.generate.call_args[1]["lora_request"] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name @pytest.mark.asyncio -async def test_serving_completion_resolver_not_found(mock_serving_setup, - monkeypatch): +async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -156,7 +175,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, response = await serving_completion.create_completion(req) - mock_engine.add_lora.assert_not_called() + mock_engine.add_lora.assert_not_awaited() mock_engine.generate.assert_not_called() assert isinstance(response, ErrorResponse) @@ -166,7 +185,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, @pytest.mark.asyncio async def test_serving_completion_resolver_add_lora_fails( - mock_serving_setup, monkeypatch): + mock_serving_setup, monkeypatch +): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -180,7 +200,7 @@ async def test_serving_completion_resolver_add_lora_fails( response = await serving_completion.create_completion(req) # Assert add_lora was called before the failure - 
mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == invalid_model diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index ff2e7004ff9f8..dbcec9d31fc9b 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -18,25 +18,15 @@ from vllm import version from ...utils import RemoteOpenAIServer -MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODELS = { + "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct", +} PREV_MINOR_VERSION = version._prev_minor_version() -@pytest.fixture(scope="module", params=[True, False]) -def use_v1(request): - # Module-scoped variant of run_with_both_engines - # - # Use this fixture to run a test with both v0 and v1, and - # also to conditionalize the test logic e.g. - # - # def test_metrics_exist(use_v1, server, client): - # ... - # expected = EXPECTED_V1_METRICS if use_v1 else EXPECTED_METRICS - # for metric in expected: - # assert metric in response.text - # - # @skip_v1 wouldn't work here because this is a module-level - # fixture - per-function decorators would have no effect +@pytest.fixture(scope="module", params=list(MODELS.keys())) +def model_key(request): yield request.param @@ -54,19 +44,21 @@ def default_server_args(): ] -@pytest.fixture(scope="module", - params=[ - "", - "--enable-chunked-prefill", - "--disable-frontend-multiprocessing", - f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", - ]) -def server(use_v1, default_server_args, request): +@pytest.fixture( + scope="module", + params=[ + "", + "--enable-chunked-prefill", + "--disable-frontend-multiprocessing", + f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", + ], +) +def server(model_key, default_server_args, request): if request.param: default_server_args.append(request.param) - env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0') - with RemoteOpenAIServer(MODEL_NAME, default_server_args, - env_dict=env_dict) as remote_server: + + model_name = MODELS[model_key] + with RemoteOpenAIServer(model_name, default_server_args) as remote_server: yield remote_server @@ -77,66 +69,83 @@ async def client(server): _PROMPT = "Hello my name is Robert and I love magic" -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"] +_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" -_NUM_REQUESTS = 10 -_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT) -_NUM_GENERATION_TOKENS_PER_REQUEST = 10 -# {metric_family: [(suffix, expected_value)]} -EXPECTED_VALUES = { - "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": - [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], - "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": - [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - 
"vllm:request_generation_tokens": - [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - "vllm:request_params_n": [("_count", _NUM_REQUESTS)], - "vllm:request_params_max_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS) - ], - "vllm:iteration_tokens_total": - [("_sum", _NUM_REQUESTS * - (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)], - "vllm:prompt_tokens": [("_total", - _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], - "vllm:generation_tokens": [ - ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) - ], - "vllm:request_success": [("_total", _NUM_REQUESTS)], -} +def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int): + num_prompt_tokens = len(prompt_ids) + + # {metric_family: [(suffix, expected_value)]} + return { + "vllm:time_to_first_token_seconds": [("_count", num_requests)], + "vllm:time_per_output_token_seconds": [ + ("_count", num_requests * (max_tokens - 1)) + ], + "vllm:e2e_request_latency_seconds": [("_count", num_requests)], + "vllm:request_queue_time_seconds": [("_count", num_requests)], + "vllm:request_inference_time_seconds": [("_count", num_requests)], + "vllm:request_prefill_time_seconds": [("_count", num_requests)], + "vllm:request_decode_time_seconds": [("_count", num_requests)], + "vllm:request_prompt_tokens": [ + ("_sum", num_requests * num_prompt_tokens), + ("_count", num_requests), + ], + "vllm:request_generation_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:request_params_n": [("_count", num_requests)], + "vllm:request_params_max_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:iteration_tokens_total": [ + ( + "_sum", + num_requests * (num_prompt_tokens + max_tokens), + ), + ("_count", num_requests * max_tokens), + ], + "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)], + "vllm:generation_tokens": [("_total", num_requests * max_tokens)], + "vllm:request_success": [("_total", num_requests)], + } @pytest.mark.asyncio -async def test_metrics_counts(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): - for _ in range(_NUM_REQUESTS): +async def test_metrics_counts( + server: RemoteOpenAIServer, + client: openai.AsyncClient, + model_key: str, +): + if model_key == "multimodal": + pytest.skip("Unnecessary test") + + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) + num_requests = 10 + max_tokens = 10 + + for _ in range(num_requests): # sending a request triggers the metrics to be logged. 
await client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, - max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) + model=model_name, + prompt=prompt_ids, + max_tokens=max_tokens, + ) response = requests.get(server.url_for("metrics")) print(response.text) assert response.status_code == HTTPStatus.OK # Loop over all expected metric_families - for metric_family, suffix_values_list in EXPECTED_VALUES.items(): - if ((use_v1 and metric_family not in EXPECTED_METRICS_V1) - or (not server.show_hidden_metrics - and metric_family in HIDDEN_DEPRECATED_METRICS)): + expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens) + for metric_family, suffix_values_list in expected_values.items(): + if metric_family not in EXPECTED_METRICS_V1 or ( + not server.show_hidden_metrics + and metric_family in HIDDEN_DEPRECATED_METRICS + ): continue found_metric = False @@ -160,78 +169,26 @@ async def test_metrics_counts(server: RemoteOpenAIServer, assert sample.value == expected_value, ( f"{metric_name_w_suffix} expected value of " f"{expected_value} did not match found value " - f"{sample.value}") + f"{sample.value}" + ) break assert found_suffix, ( f"Did not find {metric_name_w_suffix} in prom endpoint" ) break - assert found_metric, (f"Did not find {metric_family} in prom endpoint") + assert found_metric, f"Did not find {metric_family} in prom endpoint" -EXPECTED_METRICS = [ - "vllm:num_requests_running", - "vllm:num_requests_waiting", - "vllm:gpu_cache_usage_perc", - "vllm:time_to_first_token_seconds_sum", - "vllm:time_to_first_token_seconds_bucket", - "vllm:time_to_first_token_seconds_count", - "vllm:time_per_output_token_seconds_sum", - "vllm:time_per_output_token_seconds_bucket", - "vllm:time_per_output_token_seconds_count", - "vllm:e2e_request_latency_seconds_sum", - "vllm:e2e_request_latency_seconds_bucket", - "vllm:e2e_request_latency_seconds_count", - "vllm:request_queue_time_seconds_sum", - "vllm:request_queue_time_seconds_bucket", - "vllm:request_queue_time_seconds_count", - "vllm:request_inference_time_seconds_sum", - "vllm:request_inference_time_seconds_bucket", - "vllm:request_inference_time_seconds_count", - "vllm:request_prefill_time_seconds_sum", - "vllm:request_prefill_time_seconds_bucket", - "vllm:request_prefill_time_seconds_count", - "vllm:request_decode_time_seconds_sum", - "vllm:request_decode_time_seconds_bucket", - "vllm:request_decode_time_seconds_count", - "vllm:request_prompt_tokens_sum", - "vllm:request_prompt_tokens_bucket", - "vllm:request_prompt_tokens_count", - "vllm:request_generation_tokens_sum", - "vllm:request_generation_tokens_bucket", - "vllm:request_generation_tokens_count", - "vllm:request_params_n_sum", - "vllm:request_params_n_bucket", - "vllm:request_params_n_count", - "vllm:request_params_max_tokens_sum", - "vllm:request_params_max_tokens_bucket", - "vllm:request_params_max_tokens_count", - "vllm:iteration_tokens_total", - "vllm:num_preemptions_total", - "vllm:prompt_tokens_total", - "vllm:generation_tokens_total", - "vllm:request_success_total", - "vllm:cache_config_info", - # labels in cache_config_info - "block_size", - "cache_dtype", - "cpu_offload_gb", - "enable_prefix_caching", - "gpu_memory_utilization", - "num_cpu_blocks", - "num_gpu_blocks", - "num_gpu_blocks_override", - "sliding_window", - "swap_space_bytes", -] - EXPECTED_METRICS_V1 = [ "vllm:num_requests_running", "vllm:num_requests_waiting", "vllm:gpu_cache_usage_perc", "vllm:gpu_prefix_cache_queries", "vllm:gpu_prefix_cache_hits", + "vllm:kv_cache_usage_perc", + 
"vllm:prefix_cache_queries", + "vllm:prefix_cache_hits", "vllm:num_preemptions_total", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", @@ -250,12 +207,15 @@ EXPECTED_METRICS_V1 = [ "vllm:request_params_max_tokens_sum", "vllm:request_params_max_tokens_bucket", "vllm:request_params_max_tokens_count", - "vllm:time_to_first_token_seconds_sum", - "vllm:time_to_first_token_seconds_bucket", - "vllm:time_to_first_token_seconds_count", "vllm:time_per_output_token_seconds_sum", "vllm:time_per_output_token_seconds_bucket", "vllm:time_per_output_token_seconds_count", + "vllm:time_to_first_token_seconds_sum", + "vllm:time_to_first_token_seconds_bucket", + "vllm:time_to_first_token_seconds_count", + "vllm:inter_token_latency_seconds_sum", + "vllm:inter_token_latency_seconds_bucket", + "vllm:inter_token_latency_seconds_count", "vllm:e2e_request_latency_seconds_sum", "vllm:e2e_request_latency_seconds_bucket", "vllm:e2e_request_latency_seconds_count", @@ -273,33 +233,80 @@ EXPECTED_METRICS_V1 = [ "vllm:request_decode_time_seconds_count", ] -HIDDEN_DEPRECATED_METRICS: list[str] = [] +EXPECTED_METRICS_MM = [ + "vllm:mm_cache_queries", + "vllm:mm_cache_hits", +] + +HIDDEN_DEPRECATED_METRICS: list[str] = [ + "vllm:gpu_cache_usage_perc", + "vllm:gpu_prefix_cache_queries", + "vllm:gpu_prefix_cache_hits", + "vllm:time_per_output_token_seconds_sum", + "vllm:time_per_output_token_seconds_bucket", + "vllm:time_per_output_token_seconds_count", +] @pytest.mark.asyncio -async def test_metrics_exist(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_exist( + server: RemoteOpenAIServer, + client: openai.AsyncClient, + model_key: str, +): + model_name = MODELS[model_key] + # sending a request triggers the metrics to be logged. 
- await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + if model_key == "text": + await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0, + ) + else: + await client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": _IMAGE_URL}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + max_tokens=5, + temperature=0.0, + ) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): - if (not server.show_hidden_metrics - and metric not in HIDDEN_DEPRECATED_METRICS): - assert metric in response.text + expected_metrics = EXPECTED_METRICS_V1 + if model_key == "multimodal": + # NOTE: Don't use in-place assignment + expected_metrics = expected_metrics + EXPECTED_METRICS_MM + + for metric in expected_metrics: + if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics: + continue + assert metric in response.text @pytest.mark.asyncio -async def test_abort_metrics_reset(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_abort_metrics_reset( + server: RemoteOpenAIServer, + client: openai.AsyncClient, + model_key: str, +): + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) - running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server)) + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( + server, + ) # Expect no running requests or kvcache usage assert running_requests == 0 @@ -311,18 +318,21 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, for _ in range(3): task = asyncio.create_task( client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, + model=model_name, + prompt=prompt_ids, max_tokens=100, # Long generation to give time to abort - temperature=0.0)) + temperature=0.0, + ) + ) tasks.append(task) # Wait a bit for requests to start processing await asyncio.sleep(0.5) # Check that we have running requests - running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server)) + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( + server, + ) # Expect running requests and kvcache usage assert running_requests > 0 @@ -341,17 +351,18 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, # Verify running and waiting requests counts and KV cache usage are zero running_requests_after, waiting_requests_after, kv_cache_usage_after = ( - _get_running_metrics_from_api(server)) + _get_running_metrics_from_api(server) + ) - assert running_requests_after == 0,\ - (f"Expected 0 running requests after abort, got " - f"{running_requests_after}") - assert waiting_requests_after == 0,\ - (f"Expected 0 waiting requests after abort, got " - f"{waiting_requests_after}") - assert kv_cache_usage_after == 0,\ - (f"Expected 0% KV cache usage after abort, got " - f"{kv_cache_usage_after}") + assert running_requests_after == 0, ( + f"Expected 0 running requests after abort, got {running_requests_after}" + ) + assert waiting_requests_after == 0, ( + f"Expected 0 waiting requests after abort, got {waiting_requests_after}" + ) + assert kv_cache_usage_after == 0, ( + f"Expected 0% KV cache usage 
after abort, got {kv_cache_usage_after}" + ) def _get_running_metrics_from_api(server: RemoteOpenAIServer): @@ -363,6 +374,8 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer): # Verify running and waiting requests counts and KV cache usage are zero running_requests, waiting_requests, kv_cache_usage = None, None, None + kv_cache_usage_metric = "vllm:kv_cache_usage_perc" + for family in text_string_to_metric_families(response.text): if family.name == "vllm:num_requests_running": for sample in family.samples: @@ -374,9 +387,9 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer): if sample.name == "vllm:num_requests_waiting": waiting_requests = sample.value break - elif family.name == "vllm:gpu_cache_usage_perc": + elif family.name == kv_cache_usage_metric: for sample in family.samples: - if sample.name == "vllm:gpu_cache_usage_perc": + if sample.name == kv_cache_usage_metric: kv_cache_usage = sample.value break @@ -387,35 +400,37 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer): return running_requests, waiting_requests, kv_cache_usage -def test_metrics_exist_run_batch(use_v1: bool): +def test_metrics_exist_run_batch(): input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501 base_url = "0.0.0.0" port = "8001" server_url = f"http://{base_url}:{port}" - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - sys.executable, - "-m", - "vllm.entrypoints.openai.run_batch", - "-i", - input_file.name, - "-o", - output_file.name, - "--model", - "intfloat/multilingual-e5-small", - "--enable-metrics", - "--url", - base_url, - "--port", - port, - ], - env={"VLLM_USE_V1": "1" if use_v1 else "0"}) + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "vllm.entrypoints.openai.run_batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + "--enable-metrics", + "--url", + base_url, + "--port", + port, + ], + ) def is_server_up(url): try: diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 1980daa80db9e..7d2968d965066 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -4,8 +4,6 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from ...utils import RemoteOpenAIServer @@ -13,12 +11,6 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) @pytest.fixture(scope="module") @@ -34,7 +26,6 @@ def server(zephyr_lora_files): "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -61,7 +52,5 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): lora_models = models[1:] assert served_model.id == 
MODEL_NAME assert served_model.root == MODEL_NAME - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" - assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index f0ce50debe494..ba463be1d5cd7 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -25,13 +25,10 @@ def run_and_test_dummy_opt_api_server(model, tp=1): client = server.get_client() completion = client.chat.completions.create( model=model, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Hello!" - }], + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], temperature=0, ) generated_text = completion.choices[0].message.content diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 11ed1c4a9ee4b..64fdaf08893ad 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -75,10 +75,11 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): http://localhost:8000/v1/chat/completions """ # noqa: E501 if hasattr(case, "body") and isinstance(case.body, dict): - if ("messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): - + if ( + "messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0 + ): for message in case.body["messages"]: if not isinstance(message, dict): continue @@ -86,10 +87,11 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): # Check for invalid file type in tokenize endpoint if op.method.lower() == "post" and op.path == "/tokenize": content = message.get("content", []) - if (isinstance(content, list) and len(content) > 0 - and any( - item.get("type") == "file" - for item in content)): + if ( + isinstance(content, list) + and len(content) > 0 + and any(item.get("type") == "file" for item in content) + ): return False # Check for invalid tool_calls with non-function types @@ -102,12 +104,17 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): if "custom" in tool_call: return False - # Sometimes guided_grammar is generated to be empty + # Sometimes structured_outputs.grammar is generated to be empty # Causing a server error in EBNF grammar parsing # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 - guided_grammar = case.body.get("guided_grammar") + structured_outputs = case.body.get("structured_outputs", {}) + grammar = ( + structured_outputs.get("grammar") + if isinstance(structured_outputs, dict) + else None + ) - if guided_grammar == '': + if grammar == "": # Allow None (will be handled as no grammar) # But skip empty strings return False @@ -131,9 +138,8 @@ def test_openapi_stateless(case: schemathesis.Case): timeout = { # requires a longer timeout - ("POST", "/v1/chat/completions"): - LONG_TIMEOUT_SECONDS, + ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, }.get(key, DEFAULT_TIMEOUT_SECONDS) - #No need to verify SSL certificate for localhost + # No need to verify SSL certificate for localhost case.call_and_validate(verify=False, timeout=timeout) 
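Side note on the schemathesis hook change in test_openai_schema.py above: the filter now reads a nested `structured_outputs.grammar` field instead of the old top-level `guided_grammar`. A minimal sketch of that lookup, assuming a request body shaped like the ones the fuzzer generates (the body dict and its empty grammar string are hypothetical placeholders, not values from the test suite):

    # Hypothetical fuzzed request body; an empty grammar string should make the hook skip the case.
    body = {"structured_outputs": {"grammar": ""}}

    structured_outputs = body.get("structured_outputs", {})
    grammar = (
        structured_outputs.get("grammar")
        if isinstance(structured_outputs, dict)
        else None
    )
    skip_case = grammar == ""  # None is allowed (treated as "no grammar"); only empty strings are filtered out
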
diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py index eb387998c2cc4..b67d6147937d1 100644 --- a/tests/entrypoints/openai/test_optional_middleware.py +++ b/tests/entrypoints/openai/test_optional_middleware.py @@ -37,7 +37,7 @@ def server(request: pytest.FixtureRequest): "--enforce-eager", "--max-num-seqs", "2", - *passed_params + *passed_params, ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -73,8 +73,9 @@ async def test_missing_api_token(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_passed_api_token(server: RemoteOpenAIServer): - response = requests.get(server.url_for("v1/models"), - headers={"Authorization": "Bearer test"}) + response = requests.get( + server.url_for("v1/models"), headers={"Authorization": "Bearer test"} + ) assert response.status_code == HTTPStatus.OK @@ -110,7 +111,8 @@ async def test_enable_request_id_header(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_custom_request_id_header(server: RemoteOpenAIServer): - response = requests.get(server.url_for("health"), - headers={"X-Request-Id": "Custom"}) + response = requests.get( + server.url_for("health"), headers={"X-Request-Id": "Custom"} + ) assert "X-Request-Id" in response.headers assert response.headers.get("X-Request-Id") == "Custom" diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 4197583074dfe..3d0885414b24b 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -3,23 +3,18 @@ import io -# imports for guided decoding tests +# imports for structured outputs tests import openai import pybase64 import pytest import regex as re import torch -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.renderer import BaseRenderer from ...utils import RemoteOpenAIServer -@pytest.fixture(scope="function", autouse=True) -def use_v1_only(monkeypatch): - monkeypatch.setenv('VLLM_USE_V1', '1') - - @pytest.mark.asyncio async def test_empty_prompt(): model_name = "gpt2" @@ -27,12 +22,17 @@ async def test_empty_prompt(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match="decoder prompt cannot be empty"): - await client.completions.create(model=model_name, - prompt="", - max_tokens=5, - temperature=0.0) + with pytest.raises( + openai.BadRequestError, + match="Either prompt or prompt_embeds must be provided and non-empty.", + ): + await client.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": []}, + ) @pytest.mark.asyncio @@ -42,23 +42,23 @@ async def test_out_of_vocab_token_ids(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match=re.compile('.*out of vocabulary.*').pattern): - await client.completions.create(model=model_name, - prompt=[999999], - max_tokens=5, - temperature=0.0) + with pytest.raises( + openai.BadRequestError, match=re.compile(".*out of vocabulary.*").pattern + ): + await client.completions.create( + model=model_name, prompt=[999999], max_tokens=5, temperature=0.0 + ) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32, 
torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "layout", - [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]) + "layout", [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr] +) @pytest.mark.parametrize("seq_len", [2, 10]) @pytest.mark.parametrize("hidden_size", [2, 10]) -def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, - seq_len: int, hidden_size: int): +def test_load_prompt_embeds( + dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int +): # construct arbitrary tensors of various dtypes, layouts, and sizes. # We need to check against different layouts to make sure that if a user # uses sparse tensors to reduce the transmission size of prompt embeddings, @@ -83,11 +83,11 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor) + loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) assert len(loaded_prompt_embeds) == 1 loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" assert loaded_tensor.layout == torch.strided - torch.testing.assert_close(loaded_tensor, - tensor.to("cpu").to_dense(), - equal_nan=True) + torch.testing.assert_close( + loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True + ) diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py new file mode 100644 index 0000000000000..653d44f20b440 --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_mcp_tools.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import pytest_asyncio +from openai import OpenAI + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch): + args = ["--enforce-eager", "--tool-server", "demo"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="function") +def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch): + args = ["--enforce-eager", "--tool-server", "demo"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") + m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def mcp_disabled_client(mcp_disabled_server): + async with mcp_disabled_server.get_async_client() as async_client: + yield async_client + + +@pytest_asyncio.fixture +async def mcp_enabled_client(mcp_enabled_server): + async with mcp_enabled_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") +async def 
test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str): + response = await mcp_enabled_client.responses.create( + model=model_name, + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." + ), + tools=[ + { + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888", + } + ], + ) + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") +async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str): + response = await mcp_disabled_client.responses.create( + model=model_name, + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." + ), + tools=[ + { + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888", + } + ], + ) + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens == 0 diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 1ca52599c519d..57d88f84d2519 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -8,20 +8,24 @@ import pytest import pytest_asyncio import requests from openai import BadRequestError, NotFoundError, OpenAI +from openai_harmony import ( + Message, +) from ...utils import RemoteOpenAIServer -pytest.skip(allow_module_level=True, reason="gpt-oss can't run on CI yet.") - MODEL_NAME = "openai/gpt-oss-20b" -DTYPE = "bfloat16" @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--tool-server", "demo"] + env_dict = dict( + VLLM_ENABLE_RESPONSES_API_STORE="1", + PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + ) - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: yield remote_server @@ -67,28 +71,30 @@ async def test_basic_with_reasoning_effort(client: OpenAI, model_name: str): assert response.status == "completed" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_max_tokens(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is the first paragraph of Moby Dick?", + reasoning={"effort": "low"}, + max_output_tokens=30, + ) + assert response is not None + assert response.status == "incomplete" + assert response.incomplete_details.reason == "max_output_tokens" + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chat(client: OpenAI, model_name: str): response = 
await client.responses.create( model=model_name, input=[ - { - "role": "system", - "content": "Respond in Korean." - }, - { - "role": "user", - "content": "Hello!" - }, - { - "role": "assistant", - "content": "Hello! How can I help you today?" - }, - { - "role": "user", - "content": "What is 13 * 24? Explain your answer." - }, + {"role": "system", "content": "Respond in Korean."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hello! How can I help you today?"}, + {"role": "user", "content": "What is 13 * 24? Explain your answer."}, ], ) assert response is not None @@ -103,10 +109,7 @@ async def test_chat_with_input_type(client: OpenAI, model_name: str): input=[ { "role": "user", - "content": [{ - "type": "input_text", - "text": "What is 13*24?" - }], + "content": [{"type": "input_text", "text": "What is 13*24?"}], }, ], ) @@ -120,14 +123,10 @@ async def test_structured_output(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, input=[ - { - "role": "system", - "content": "Extract the event information." - }, + {"role": "system", "content": "Extract the event information."}, { "role": "user", - "content": - "Alice and Bob are going to a science fair on Friday.", + "content": "Alice and Bob are going to a science fair on Friday.", }, ], text={ @@ -137,18 +136,9 @@ async def test_structured_output(client: OpenAI, model_name: str): "schema": { "type": "object", "properties": { - "name": { - "type": "string" - }, - "date": { - "type": "string" - }, - "participants": { - "type": "array", - "items": { - "type": "string" - } - }, + "name": {"type": "string"}, + "date": {"type": "string"}, + "participants": {"type": "array", "items": {"type": "string"}}, }, "required": ["name", "date", "participants"], "additionalProperties": False, @@ -268,11 +258,62 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_streaming(client: OpenAI, model_name: str): +async def test_streaming_types(client: OpenAI, model_name: str): + prompts = [ + "tell me a story about a cat in 20 words", + ] + + # this links the "done" type with the "start" type + # so every "done" type should have a corresponding "start" type + # and every open block should be closed by the end of the stream + pairs_of_event_types = { + "response.completed": "response.created", + "response.output_item.done": "response.output_item.added", + "response.content_part.done": "response.content_part.added", + "response.output_text.done": "response.output_text.delta", + "response.web_search_call.done": "response.web_search_call.added", + "response.reasoning_text.done": "response.reasoning_text.delta", + "response.reasoning_part.done": "response.reasoning_part.added", + } + + for prompt in prompts: + response = await client.responses.create( + model=model_name, + input=prompt, + reasoning={"effort": "low"}, + tools=[], + stream=True, + background=False, + ) + + stack_of_event_types = [] + async for event in response: + if event.type == "response.created": + stack_of_event_types.append(event.type) + elif event.type == "response.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + if event.type.endswith("added"): + stack_of_event_types.append(event.type) + elif event.type.endswith("delta"): + if stack_of_event_types[-1] == event.type: + continue + stack_of_event_types.append(event.type) + elif event.type.endswith("done"): + 
assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + assert len(stack_of_event_types) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("background", [True, False]) +async def test_streaming(client: OpenAI, model_name: str, background: bool): + # TODO: Add back when web search and code interpreter are available in CI prompts = [ "tell me a story about a cat in 20 words", "What is 13 * 24? Use python to calculate the result.", - "When did Jensen found NVIDIA? Search it and answer the year only.", + # "When did Jensen found NVIDIA? Search it and answer the year only.", ] for prompt in prompts: @@ -281,49 +322,111 @@ async def test_streaming(client: OpenAI, model_name: str): input=prompt, reasoning={"effort": "low"}, tools=[ - { - "type": "web_search_preview" - }, - { - "type": "code_interpreter", - "container": { - "type": "auto" - } - }, + # { + # "type": "web_search_preview" + # }, + {"type": "code_interpreter", "container": {"type": "auto"}}, ], stream=True, + background=background, + extra_body={"enable_response_messages": True}, ) + current_item_id = "" + current_content_index = -1 + events = [] current_event_mode = None + resp_id = None + checked_response_completed = False async for event in response: + if event.type == "response.created": + resp_id = event.response.id + + # test vllm custom types are in the response + if event.type in [ + "response.completed", + "response.in_progress", + "response.created", + ]: + assert "input_messages" in event.response.model_extra + assert "output_messages" in event.response.model_extra + if event.type == "response.completed": + # make sure the serialization of content works + for msg in event.response.model_extra["output_messages"]: + # make sure we can convert the messages back into harmony + Message.from_dict(msg) + + for msg in event.response.model_extra["input_messages"]: + # make sure we can convert the messages back into harmony + Message.from_dict(msg) + checked_response_completed = True + if current_event_mode != event.type: current_event_mode = event.type print(f"\n[{event.type}] ", end="", flush=True) + # verify current_item_id is correct + if event.type == "response.output_item.added": + assert event.item.id != current_item_id + current_item_id = event.item.id + elif event.type in [ + "response.output_text.delta", + "response.reasoning_text.delta", + ]: + assert event.item_id == current_item_id + + # verify content_index_id is correct + if event.type in [ + "response.content_part.added", + "response.reasoning_part.added", + ]: + assert event.content_index != current_content_index + current_content_index = event.content_index + elif event.type in [ + "response.output_text.delta", + "response.reasoning_text.delta", + ]: + assert event.content_index == current_content_index + if "text.delta" in event.type: print(event.delta, end="", flush=True) elif "reasoning_text.delta" in event.type: print(f"{event.delta}", end="", flush=True) elif "response.code_interpreter_call_code.done" in event.type: print(f"Code: {event.code}", end="", flush=True) - elif ("response.output_item.added" in event.type - and event.item.type == "web_search_call"): + elif ( + "response.output_item.added" in event.type + and event.item.type == "web_search_call" + ): print(f"Web search: {event.item.action}", end="", flush=True) events.append(event) assert len(events) > 0 + response_completed_event = events[-1] + assert len(response_completed_event.response.output) > 0 
+ assert checked_response_completed + + if background: + starting_after = 5 + async with await client.responses.retrieve( + response_id=resp_id, stream=True, starting_after=starting_after + ) as stream: + counter = starting_after + async for event in stream: + counter += 1 + assert event == events[counter] + assert counter == len(events) - 1 @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="Web search tool is not available in CI yet.") async def test_web_search(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, input="Who is the president of South Korea as of now?", - tools=[{ - "type": "web_search_preview" - }], + tools=[{"type": "web_search_preview"}], ) assert response is not None assert response.status == "completed" @@ -334,16 +437,26 @@ async def test_web_search(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, - input="Multiply 64548*15151 using builtin python interpreter.", - tools=[{ - "type": "code_interpreter", - "container": { - "type": "auto" - } - }], + # TODO: Ideally should be able to set max tool calls + # to prevent multi-turn, but it is not currently supported + # would speed up the test + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." + ), + tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + temperature=0.0, # More deterministic output in response ) assert response is not None assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 + for item in response.output: + if item.type == "message": + output_string = item.content[0].text + print("output_string: ", output_string, flush=True) + assert "5846" in output_string def get_weather(latitude, longitude): @@ -370,31 +483,30 @@ def call_function(name, args): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling(client: OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, + "required": ["latitude", "longitude"], + "additionalProperties": False, }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] + "strict": True, + } + ] response = await client.responses.create( model=model_name, input="What's the weather like in Paris today?", tools=tools, + temperature=0.0, + extra_body={"request_id": "test_function_calling_non_resp"}, ) assert response is not None assert response.status == "completed" @@ -410,11 +522,13 @@ async def test_function_calling(client: OpenAI, model_name: str): response_2 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": 
"function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response.id, ) @@ -436,6 +550,7 @@ async def test_function_calling(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.flaky(reruns=5) async def test_function_calling_multi_turn(client: OpenAI, model_name: str): tools = [ { @@ -453,17 +568,12 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): { "type": "function", "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa + "description": "Get current temperature for provided coordinates in celsius.", # noqa "parameters": { "type": "object", "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" - }, + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, "required": ["latitude", "longitude"], "additionalProperties": False, @@ -474,8 +584,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, - input= - "Help me plan a trip to a random place. And tell me the weather there.", + input="Help me plan a trip to a random place. And tell me the weather there.", tools=tools, ) assert response is not None @@ -492,11 +601,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response_2 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response.id, ) @@ -514,11 +625,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response_3 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response_2.id, ) @@ -530,26 +643,23 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_required(client: OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, + "required": ["latitude", "longitude"], + "additionalProperties": False, }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] + "strict": True, + } + ] with pytest.raises(BadRequestError): await client.responses.create( @@ -562,32 +672,44 @@ async def test_function_calling_required(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_function_calling_full_history(client: 
OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" - }, - }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] +async def test_system_message_with_tools(client: OpenAI, model_name: str): + from vllm.entrypoints.harmony_utils import get_system_message - input_messages = [{ - "role": "user", - "content": "What's the weather like in Paris today?" - }] + # Test with custom tools enabled - commentary channel should be available + sys_msg = get_system_message(with_custom_tools=True) + valid_channels = sys_msg.content[0].channel_config.valid_channels + assert "commentary" in valid_channels + + # Test with custom tools disabled - commentary channel should be removed + sys_msg = get_system_message(with_custom_tools=False) + valid_channels = sys_msg.content[0].channel_config.valid_channels + assert "commentary" not in valid_channels + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_full_history(client: OpenAI, model_name: str): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + input_messages = [ + {"role": "user", "content": "What's the weather like in Paris today?"} + ] response = await client.responses.create( model=model_name, @@ -604,8 +726,7 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): result = call_function(name, args) - input_messages.extend( - response.output) # append model's function call message + input_messages.extend(response.output) # append model's function call message input_messages.append( { # append result message "type": "function_call_output", @@ -622,3 +743,18 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): assert response_2 is not None assert response_2.status == "completed" assert response_2.output_text is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_output_messages_enabled(client: OpenAI, model_name: str, server): + response = await client.responses.create( + model=model_name, + input="What is the capital of South Korea?", + extra_body={"enable_response_messages": True}, + ) + + assert response is not None + assert response.status == "completed" + assert len(response.input_messages) > 0 + assert len(response.output_messages) > 0 diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py index 6addcb41c4098..60a80210fb768 100644 --- a/tests/entrypoints/openai/test_return_token_ids.py +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -50,13 +50,16 @@ async def test_basic_completion_with_emoji(server): # Check against the expected prompt token IDs tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) encoded_tokens = tokenizer.encode( - "Complete this sentence with emojis: I love coding 🚀") + "Complete this sentence with emojis: I love coding 🚀" + ) # Check that encoded_tokens is a 
subsequence of prompt_token_ids - assert any(completion.choices[0].prompt_token_ids[i:i + - len(encoded_tokens)] - == encoded_tokens for i in range( - len(completion.choices[0].prompt_token_ids) - - len(encoded_tokens) + 1)) + assert any( + completion.choices[0].prompt_token_ids[i : i + len(encoded_tokens)] + == encoded_tokens + for i in range( + len(completion.choices[0].prompt_token_ids) - len(encoded_tokens) + 1 + ) + ) # Verify token_ids field is present in the choice assert completion.choices[0].token_ids is not None @@ -86,44 +89,38 @@ async def test_basic_completion_with_emoji(server): @pytest.mark.asyncio async def test_chat_completion_with_tool_use(server): """Test chat completion with tool use (get_weather function).""" - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": - "string", - "description": - "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The unit of temperature", + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature", + }, }, + "required": ["location"], }, - "required": ["location"], }, - }, - }] + } + ] async with server.get_async_client() as client: # Test with return_token_ids enabled response = await client.chat.completions.create( model=MODEL_NAME, messages=[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "What's the weather like in Paris?" - }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in Paris?"}, ], tools=tools, tool_choice="auto", @@ -145,10 +142,11 @@ async def test_chat_completion_with_tool_use(server): tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) prompt_text = tokenizer.decode(response.prompt_token_ids) assert prompt_text.startswith( - "<|im_start|>system\nYou are a helpful assistant.") + "<|im_start|>system\nYou are a helpful assistant." + ) assert prompt_text.endswith( - "What's the weather like in Paris?<|im_end|>\n" - "<|im_start|>assistant\n") + "What's the weather like in Paris?<|im_end|>\n<|im_start|>assistant\n" + ) response_text = tokenizer.decode(response.choices[0].token_ids) assert response_text.startswith('<tool_call>\n{"name": "get_weather"') @@ -164,14 +162,8 @@ async def test_chat_completion_with_tool_use(server): response_without = await client.chat.completions.create( model=MODEL_NAME, messages=[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "What's the weather like in Paris?" 
- }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in Paris?"}, ], tools=tools, tool_choice="auto", @@ -203,7 +195,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): extra_body={ "return_token_ids": True, "return_tokens_as_token_ids": True, - "prompt_logprobs": 1 + "prompt_logprobs": 1, }, ) @@ -224,20 +216,21 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): logprobs_token_ids.append(token_id) # When echo=True, the logprobs include both prompt and response tokens - # The token_ids field should match the the suffix of response portion + # The token_ids field should match the suffix of response portion # The prompt_token_ids should match the prompt portion assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) response_token_ids_length = len(completion.choices[0].token_ids) - assert logprobs_token_ids[-response_token_ids_length:] == \ - completion.choices[0].token_ids + assert ( + logprobs_token_ids[-response_token_ids_length:] + == completion.choices[0].token_ids + ) # Verify tokenizer consistency tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) # Decode prompt tokens if completion.choices[0].prompt_token_ids: - prompt_text = tokenizer.decode( - completion.choices[0].prompt_token_ids) + prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids) # The decoded prompt should match or close to original prompt assert "Hello, world" in prompt_text @@ -255,10 +248,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): stream=True, echo=False, logprobs=1, - extra_body={ - "return_token_ids": True, - "return_tokens_as_token_ids": True - }, + extra_body={"return_token_ids": True, "return_tokens_as_token_ids": True}, ) # Collect streamed tokens @@ -287,14 +277,8 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): async def test_chat_completion_with_emoji_and_token_ids(server): """Test chat completion with emojis to verify token_ids handling.""" chat_messages = [ - { - "role": "system", - "content": "You like to use emojis in your responses." - }, - { - "role": "user", - "content": "Repeat after me: I love cats 🐱" - }, + {"role": "system", "content": "You like to use emojis in your responses."}, + {"role": "user", "content": "Repeat after me: I love cats 🐱"}, ] async with server.get_async_client() as client: response = await client.chat.completions.create( @@ -319,15 +303,16 @@ async def test_chat_completion_with_emoji_and_token_ids(server): decoded_prompt = tokenizer.decode(response.prompt_token_ids) assert decoded_prompt.startswith( - "<|im_start|>system\nYou like to use emojis in your responses.") + "<|im_start|>system\nYou like to use emojis in your responses." 
+ ) assert decoded_prompt.endswith( - "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n") + "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n" + ) decoded_response = tokenizer.decode(response.choices[0].token_ids) # The content should match the response text # except the ending <|im_end|> - assert decoded_response == response.choices[ - 0].message.content + "<|im_end|>" + assert decoded_response == response.choices[0].message.content + "<|im_end|>" # Test with streaming stream = await client.chat.completions.create( @@ -348,14 +333,14 @@ async def test_chat_completion_with_emoji_and_token_ids(server): assert chunk.prompt_token_ids is not None assert isinstance(chunk.prompt_token_ids, list) # Check the prompt_token_ids match the initial prompt - decoded_prompt_stream = tokenizer.decode( - chunk.prompt_token_ids) + decoded_prompt_stream = tokenizer.decode(chunk.prompt_token_ids) assert decoded_prompt_stream == decoded_prompt first_chunk = False else: chunk_dump = chunk.model_dump() - assert "prompt_token_ids" not in chunk_dump, \ + assert "prompt_token_ids" not in chunk_dump, ( "Subsequent chunks should not have prompt_token_ids" + ) if chunk.choices: if chunk.choices[0].delta.content: diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index af58fbd4b3640..adbcc1f2430c4 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -10,10 +10,30 @@ import pytest from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -from .test_completion import default_server_args # noqa: F401 -from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 -from .test_completion import zephyr_lora_files # noqa: F401 -from .test_completion import MODEL_NAME + +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + ] @pytest.fixture(scope="module") @@ -24,22 +44,19 @@ def server_fixture(request, default_server_args): # noqa: F811 with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server: yield (remote_server, True) else: - with RemoteOpenAIServer(MODEL_NAME, - default_server_args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: yield (remote_server, False) @pytest.mark.asyncio @pytest.mark.parametrize("server_fixture", [True, False], indirect=True) -async def test_completion_return_tokens_as_token_ids_completion( - server_fixture): +async def test_completion_return_tokens_as_token_ids_completion(server_fixture): server, use_server_flag = server_fixture request_args = {} if not use_server_flag: request_args["return_tokens_as_token_ids"] = True async with server.get_async_client() as client: - completion = await client.completions.create( model=MODEL_NAME, # Include Unicode characters to test for dividing a single @@ -50,7 +67,8 @@ async def test_completion_return_tokens_as_token_ids_completion( temperature=0, max_tokens=10, logprobs=1, - extra_body=request_args) + extra_body=request_args, + ) text = completion.choices[0].text token_strs = 
completion.choices[0].logprobs.tokens @@ -84,22 +102,22 @@ async def test_chat_return_tokens_as_token_ids_completion(server_fixture): # Include Unicode characters to test for dividing a single # character across multiple tokens: 🎉 is [28705, 31862] for the # Zephyr tokenizer - messages=[{ - "role": "system", - "content": "You like to respond in only emojis, like 🎉" - }, { - "role": "user", - "content": "Please write some emojis: 🐱🐶🎉" - }], + messages=[ + { + "role": "system", + "content": "You like to respond in only emojis, like 🎉", + }, + {"role": "user", "content": "Please write some emojis: 🐱🐶🎉"}, + ], temperature=0, max_tokens=8, logprobs=True, - extra_body=request_args) + extra_body=request_args, + ) text = response.choices[0].message.content tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) token_ids = [] for logprob_content in response.choices[0].logprobs.content: - token_ids.append( - int(logprob_content.token.removeprefix("token_id:"))) + token_ids.append(int(logprob_content.token.removeprefix("token_id:"))) assert tokenizer.decode(token_ids, skip_special_tokens=True) == text diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py index 7b4966848b9de..6bcb80878f07a 100644 --- a/tests/entrypoints/openai/test_root_path.py +++ b/tests/entrypoints/openai/test_root_path.py @@ -51,26 +51,31 @@ class TestCase(NamedTuple): model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), ], ) -async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, - test_case: TestCase): +async def test_chat_session_root_path_with_api_key( + server: RemoteOpenAIServer, test_case: TestCase +): saying: str = "Here is a common saying about apple. 
An apple a day, keeps" ctx = contextlib.nullcontext() if test_case.expected_error is not None: @@ -79,20 +84,16 @@ async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, client = openai.AsyncOpenAI( api_key=test_case.api_key, base_url=server.url_for(*test_case.base_url), - max_retries=0) + max_retries=0, + ) chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], - extra_body={ - "continue_final_message": True, - "add_generation_prompt": False - }) + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], + extra_body={"continue_final_message": True, "add_generation_prompt": False}, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index e23f41e983b0d..e17f25afe4c91 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -9,22 +9,28 @@ import pytest from vllm.entrypoints.openai.protocol import BatchRequestOutput +MODEL_NAME = "Qwen/Qwen3-0.6B" + # ruff: noqa: E501 -INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +INPUT_BATCH = ( + '{{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "NonExistModel", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {{"stream": "True", "model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' +).format(MODEL_NAME) -{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": 
{"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {"stream": "True", "model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" +INVALID_INPUT_BATCH = ( + '{{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' +).format(MODEL_NAME) -INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" - -INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}} - -{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}} -{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}""" +INPUT_EMBEDDING_BATCH = ( + '{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}\n' + '{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}}\n' + '{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}\n' + '{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}' +) INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" @@ -35,15 +41,24 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", 
"url": "/re def test_empty_file(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write("") input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -53,15 +68,24 @@ def test_empty_file(): def test_completions(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + MODEL_NAME, + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -77,30 +101,48 @@ def test_completions_invalid_input(): """ Ensure that we fail when the input doesn't conform to the openai api. """ - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INVALID_INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + MODEL_NAME, + ], + ) proc.communicate() proc.wait() assert proc.returncode != 0, f"{proc=}" def test_embeddings(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_EMBEDDING_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -112,24 +154,26 @@ def test_embeddings(): BatchRequestOutput.model_validate_json(line) -@pytest.mark.parametrize("input_batch", - [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) +@pytest.mark.parametrize("input_batch", [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) def test_score(input_batch): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - "vllm", - "run-batch", - "-i", - input_file.name, - "-o", - 
output_file.name, - "--model", - "BAAI/bge-reranker-v2-m3", - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "BAAI/bge-reranker-v2-m3", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 10879f0be83c8..10224dee0efe8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,25 +1,345 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Optional -from unittest.mock import MagicMock +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock import pytest +import pytest_asyncio -from vllm.config import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +from ...utils import RemoteOpenAIServer + +if TYPE_CHECKING: + from openai import OpenAI + +GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture( + scope="module", + params=[True, False], + ids=["with_tool_parser", "without_tool_parser"], +) +def with_tool_parser(request) -> bool: + return request.param + + +@pytest.fixture(scope="module") +def default_server_args(with_tool_parser: bool): + args = [ + # use half precision for speed and memory savings in CI environment + "--enforce-eager", + "--max-model-len", + "4096", + "--reasoning-parser", + "openai_gptoss", + "--gpu-memory-utilization", + "0.8", + ] + if with_tool_parser: + args.extend( + [ + "--tool-call-parser", + "openai", + "--enable-auto-tool-choice", + ] + ) + return args + + +@pytest.fixture(scope="module") +def gptoss_server( + monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] +): + with monkeypatch_module.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + with RemoteOpenAIServer( + GPT_OSS_MODEL_NAME, default_server_args + ) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def gptoss_client(gptoss_server): + async with gptoss_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_gpt_oss_chat_tool_call_streaming( + gptoss_client: OpenAI, with_tool_parser: bool +): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + } + ] + + messages = [ + 
{"role": "user", "content": "What is the weather in Dallas, TX?"}, + ] + + stream = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools if with_tool_parser else None, + stream=True, + ) + + name = None + args_buf = "" + content_buf = "" + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.tool_calls: + tc = delta.tool_calls[0] + if tc.function and tc.function.name: + name = tc.function.name + if tc.function and tc.function.arguments: + args_buf += tc.function.arguments + if getattr(delta, "content", None): + content_buf += delta.content + if with_tool_parser: + assert name is not None + assert len(args_buf) > 0 + else: + assert name is None + assert len(args_buf) == 0 + assert len(content_buf) > 0 + + +@pytest.mark.asyncio +async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool): + if not with_tool_parser: + pytest.skip("skip non-tool for multi-turn tests") + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + } + ] + + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"}, + ] + + first = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + first_msg = first.choices[0].message + assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0 + tc = first_msg.tool_calls[0] + assert tc.function is not None and tc.function.name == "get_current_weather" + args1 = tc.function.arguments + assert args1 is not None and len(args1) > 0 + assert not first_msg.content + + messages.append({"role": "assistant", "content": args1}) + messages.append( + {"role": "user", "content": "Now convert to celsius and return JSON only"} + ) + + second = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + second_msg = second.choices[0].message + assert (second_msg.content is not None and len(second_msg.content) > 0) or ( + second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0 + ) + + +@pytest.mark.asyncio +async def test_gpt_oss_tool_message_array_content( + gptoss_client: OpenAI, with_tool_parser: bool +): + """Test that tool messages support both string and array content formats.""" + if not with_tool_parser: + pytest.skip("skip non-tool for array content tests") + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + }, + "required": ["city", "state"], + }, + }, + } + ] + + # Test 1: Tool message with string content + messages_string = [ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris", "state": "TX"}', + }, + } + ], + }, + {"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"}, + ] + + 
response_string = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_string, + tools=tools, + temperature=0.0, + ) + + assert response_string is not None + assert response_string.choices[0].message is not None + + # Test 2: Tool message with array content + messages_array = [ + {"role": "user", "content": "What's the weather in Dallas?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Dallas", "state": "TX"}', + }, + } + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"} + ], + }, + ] + + response_array = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_array, + tools=tools, + temperature=0.0, + ) + + assert response_array is not None + assert response_array.choices[0].message is not None + + # Test 3: Tool message with multiple array content items + messages_multi_array = [ + {"role": "user", "content": "Search for information"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_789", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Austin", "state": "TX"}', + }, + } + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "Weather data: "}, + {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"}, + {"type": "text", "text": " with 60% humidity"}, + ], + }, + ] + + response_multi_array = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_multi_array, + tools=tools, + temperature=0.0, + ) + + assert response_multi_array is not None + assert response_multi_array.choices[0].message is not None + MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" -BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), +] @dataclass @@ -30,6 +350,7 @@ class MockHFConfig: @dataclass class MockModelConfig: task = "generate" + runner_type = "generate" tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" @@ -38,35 +359,66 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() logits_processor_pattern = None - diff_sampling_param: Optional[dict] = None + diff_sampling_param: dict | None = None allowed_local_media_path: str = "" + allowed_media_domains: list[str] | None = None encoder_config = None generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + skip_tokenizer_init = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} +def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + ) + serving_chat = OpenAIServingChat( + engine, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): + return dict(engine_prompt), {} + + serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) + return serving_chat + + @dataclass class MockEngine: - - async def 
get_model_config(self): - return MockModelConfig() + model_config: MockModelConfig = field(default_factory=MockModelConfig) + processor: MagicMock = field(default_factory=MagicMock) + io_processor: MagicMock = field(default_factory=MagicMock) async def _async_serving_chat_init(): engine = MockEngine() - model_config = await engine.get_model_config() - models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) - serving_completion = OpenAIServingChat(engine, - model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + models = OpenAIServingModels(engine, BASE_MODEL_PATHS) + serving_completion = OpenAIServingChat( + engine, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) return serving_completion @@ -76,29 +428,49 @@ def test_async_serving_chat_init(): @pytest.mark.asyncio -async def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=MQLLMEngineClient) +async def test_serving_chat_returns_correct_model_name(): + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=MockModelConfig()) - serving_chat = OpenAIServingChat(mock_engine, - MockModelConfig(), - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_chat = _build_serving_chat(mock_engine) + messages = [{"role": "user", "content": "what is 1+1?"}] + + async def return_model_name(*args): + return args[3] + + serving_chat.chat_completion_full_generator = return_model_name + + # Test that full name is returned when short name is requested + req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when empty string is specified + req = ChatCompletionRequest(model="", messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when no model is specified + req = ChatCompletionRequest(messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + +@pytest.mark.asyncio +async def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_chat = _build_serving_chat(mock_engine) req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], - guided_decoding_backend="outlines", + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -120,30 +492,20 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_chat = _build_serving_chat(mock_engine) # Test Case 1: No max_tokens specified in request req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], - guided_decoding_backend="outlines", + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -175,30 +537,20 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_chat = _build_serving_chat(mock_engine) # Test case 1: No max_tokens specified, defaults to context_window req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], - guided_decoding_backend="outlines", + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -225,36 +577,25 @@ async def test_serving_chat_should_set_correct_max_tokens(): @pytest.mark.asyncio async def test_serving_chat_could_load_correct_generation_config(): - mock_model_config = MockModelConfig() mock_model_config.diff_sampling_param = { "temperature": 0.5, - "repetition_penalty": 1.05 + "repetition_penalty": 1.05, } - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_chat = _build_serving_chat(mock_engine) req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], - guided_decoding_backend="outlines", + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -288,38 +629,30 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() mock_model_config.hf_config.model_type = model_type - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - # Initialize the serving chat - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) - serving_chat = OpenAIServingChat(mock_engine, - mock_model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_chat = _build_serving_chat(mock_engine) # Test cache_salt req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) - # By default cache_salt in the engine prompt is not set + # By default, cache_salt in the engine prompt is not set with suppress(Exception): await serving_chat.create_chat_completion(req) - assert "cache_salt" not in mock_engine.generate.call_args.args[0] + engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1] + assert "cache_salt" not in engine_prompt # Test with certain cache_salt req.cache_salt = "test_salt" with suppress(Exception): await serving_chat.create_chat_completion(req) - assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" + engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1] + assert engine_prompt.get("cache_salt") == "test_salt" diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py new file mode 100644 index 0000000000000..46d8871441a75 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from unittest.mock import Mock + +import pytest + +from vllm.config import ModelConfig +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + + +@pytest.fixture() +def serving() -> OpenAIServing: + """Create a minimal OpenAIServing instance for testing.""" + + # Create minimal mocks + engine_client = Mock() + model_config = Mock(spec=ModelConfig) + model_config.max_model_len = 32768 + models = Mock(spec=OpenAIServingModels) + models.model_config = model_config + models.processor = Mock() + models.io_processor = Mock() + + serving = OpenAIServing( + engine_client=engine_client, + models=models, + request_logger=None, + ) + return serving + + +@pytest.mark.asyncio +async def test_async_mistral_tokenizer_does_not_block_event_loop( + serving: OpenAIServing, +): + expected_tokens = [1, 2, 3] + + # Mock the blocking version to sleep + def mocked_apply_chat_template(*_args, **_kwargs): + time.sleep(2) + return expected_tokens + + mock_tokenizer = Mock(spec=MistralTokenizer) + mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template + + task = serving._apply_mistral_chat_template_async( + tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[] + ) + + # Ensure the event loop is not blocked + blocked_count = 0 + for _i in range(20): # Check over ~2 seconds + start = time.perf_counter() + await asyncio.sleep(0) + elapsed = time.perf_counter() - start + + # an overly generous elapsed time for slow machines + if elapsed >= 0.5: + blocked_count += 1 + + await asyncio.sleep(0.1) + + # Ensure task completes + tokens = await task + assert tokens == expected_tokens, "Mocked blocking tokenizer was not called" + assert blocked_count == 0, "Event loop blocked during tokenization" diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index bc6a0341f59f6..df5bf07a8bd41 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -8,31 +8,36 @@ import pytest from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import (ErrorResponse, - LoadLoRAAdapterRequest, - UnloadLoRAAdapterRequest) -from 
vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] -LORA_LOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully." LORA_UNLOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' removed successfully.") + "Success: LoRA adapter '{lora_name}' removed successfully." +) async def _async_serving_models_init() -> OpenAIServingModels: - mock_model_config = MagicMock(spec=ModelConfig) mock_engine_client = MagicMock(spec=EngineClient) # Set the max_model_len attribute to avoid missing attribute + mock_model_config = MagicMock(spec=ModelConfig) mock_model_config.max_model_len = 2048 + mock_engine_client.model_config = mock_model_config + mock_engine_client.processor = MagicMock() + mock_engine_client.io_processor = MagicMock() - serving_models = OpenAIServingModels(engine_client=mock_engine_client, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config, - lora_modules=None) + serving_models = OpenAIServingModels( + engine_client=mock_engine_client, + base_model_paths=BASE_MODEL_PATHS, + lora_modules=None, + ) await serving_models.init_static_loras() return serving_models @@ -42,19 +47,18 @@ async def _async_serving_models_init() -> OpenAIServingModels: async def test_serving_model_name(): serving_models = await _async_serving_models_init() assert serving_models.model_name(None) == MODEL_NAME - request = LoRARequest(lora_name="adapter", - lora_path="/path/to/adapter2", - lora_int_id=1) + request = LoRARequest( + lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1 + ) assert serving_models.model_name(request) == request.lora_name @pytest.mark.asyncio async def test_load_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter", - lora_path="/path/to/adapter2") + request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter") assert len(serving_models.lora_requests) == 1 assert "adapter" in serving_models.lora_requests assert serving_models.lora_requests["adapter"].lora_name == "adapter" @@ -73,15 +77,16 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 1 - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + 
lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.error.type == "InvalidUserInput" @@ -92,15 +97,15 @@ async def test_load_lora_adapter_duplicate(): @pytest.mark.asyncio async def test_unload_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert len(serving_models.lora_requests) == 1 request = UnloadLoRAAdapterRequest(lora_name="adapter1") response = await serving_models.unload_lora_adapter(request) - assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 0 diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py new file mode 100644 index 0000000000000..263b076db1835 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import AsyncExitStack +from unittest.mock import MagicMock + +import pytest +import pytest_asyncio + +from vllm.entrypoints.context import ConversationContext +from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest +from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses +from vllm.entrypoints.tool_server import ToolServer +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt + + +class MockConversationContext(ConversationContext): + """Mock conversation context for testing""" + + def __init__(self): + self.init_tool_sessions_called = False + self.init_tool_sessions_args = None + self.init_tool_sessions_kwargs = None + + def append_output(self, output) -> None: + pass + + async def call_tool(self): + return [] + + def need_builtin_tool_call(self) -> bool: + return False + + def render_for_completion(self): + return [] + + async def init_tool_sessions(self, tool_server, exit_stack, request_id, mcp_tools): + self.init_tool_sessions_called = True + self.init_tool_sessions_args = (tool_server, exit_stack, request_id, mcp_tools) + + async def cleanup_session(self) -> None: + pass + + +@pytest.fixture +def mock_serving_responses(): + """Create a mock OpenAIServingResponses instance""" + serving_responses = MagicMock(spec=OpenAIServingResponses) + serving_responses.tool_server = MagicMock(spec=ToolServer) + return serving_responses + + +@pytest.fixture +def mock_context(): + """Create a mock conversation context""" + return MockConversationContext() + + +@pytest.fixture +def mock_exit_stack(): + """Create a mock async exit stack""" + return MagicMock(spec=AsyncExitStack) + + +class TestInitializeToolSessions: + """Test class for _initialize_tool_sessions method""" + + @pytest_asyncio.fixture + async def serving_responses_instance(self): + """Create a real OpenAIServingResponses instance for testing""" + # Create minimal mocks for required dependencies + engine_client = MagicMock() + + model_config = MagicMock() + model_config.hf_config.model_type = "test" + model_config.get_diff_sampling_param.return_value = {} + engine_client.model_config = model_config + + 
engine_client.processor = MagicMock()
+        engine_client.io_processor = MagicMock()
+
+        models = MagicMock()
+
+        tool_server = MagicMock(spec=ToolServer)
+
+        # Create the actual instance
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+            tool_server=tool_server,
+        )
+
+        return instance
+
+    @pytest.mark.asyncio
+    async def test_initialize_tool_sessions(
+        self, serving_responses_instance, mock_context, mock_exit_stack
+    ):
+        """Test that tool sessions are initialized only when tools are provided"""
+
+        request = ResponsesRequest(input="test input", tools=[])
+
+        # Call the method
+        await serving_responses_instance._initialize_tool_sessions(
+            request, mock_context, mock_exit_stack
+        )
+        assert mock_context.init_tool_sessions_called is False
+
+        # Now pass tools so that tool sessions are initialized
+        tools = [
+            {"type": "web_search_preview"},
+            {"type": "code_interpreter", "container": {"type": "auto"}},
+        ]
+
+        request = ResponsesRequest(input="test input", tools=tools)
+
+        # Call the method
+        await serving_responses_instance._initialize_tool_sessions(
+            request, mock_context, mock_exit_stack
+        )
+
+        # Verify that init_tool_sessions was called
+        assert mock_context.init_tool_sessions_called
+
+
+class TestValidateGeneratorInput:
+    """Test class for _validate_generator_input method"""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing"""
+        # Create minimal mocks for required dependencies
+        engine_client = MagicMock()
+
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+        engine_client.model_config = model_config
+
+        engine_client.processor = MagicMock()
+        engine_client.io_processor = MagicMock()
+
+        models = MagicMock()
+
+        # Create the actual instance
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+        )
+
+        # Set max_model_len for testing
+        instance.max_model_len = 100
+
+        return instance
+
+    def test_validate_generator_input(self, serving_responses_instance):
+        """Test _validate_generator_input with valid and invalid prompt lengths"""
+        # Create an engine prompt with valid length (less than max_model_len)
+        valid_prompt_token_ids = list(range(5))  # 5 tokens < 100 max_model_len
+        engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(engine_prompt)
+
+        # Should return None for valid input
+        assert result is None
+
+        # Create an engine prompt that exceeds max_model_len
+        invalid_prompt_token_ids = list(range(200))  # 200 tokens >= 100 max_model_len
+        engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
+
+        # Call the method
+        result = serving_responses_instance._validate_generator_input(engine_prompt)
+
+        # Should return an ErrorResponse
+        assert result is not None
+        assert isinstance(result, ErrorResponse)
diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py
index 29a94c852bba6..ff46df81d0fff 100644
--- a/tests/entrypoints/openai/test_shutdown.py
+++ b/tests/entrypoints/openai/test_shutdown.py
@@ -24,16 +24,13 @@ async def test_shutdown_on_engine_failure():
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         async with remote_server.get_async_client() as client:
-
-            with pytest.raises(
-                
(openai.APIConnectionError, openai.InternalServerError)): + with pytest.raises((openai.APIConnectionError, openai.InternalServerError)): # Asking for lots of prompt logprobs will currently crash the # engine. This may change in the future when that bug is fixed prompt = "Hello " * 4000 await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - extra_body={"prompt_logprobs": 10}) + model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10} + ) # Now the server should shut down return_code = remote_server.proc.wait(timeout=8) diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index 0bb42ed8aa7fb..6998566c03d02 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -11,18 +11,10 @@ import torch from ...utils import RemoteOpenAIServer -MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" DTYPE = "float16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = [ @@ -35,7 +27,9 @@ def server(): "--trust-remote-code", "--skip-tokenizer-init", "--max-num-seqs", - "32" + "32", + "--model-impl", + "terratorch", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -45,7 +39,6 @@ def server(): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_request(server: RemoteOpenAIServer, model_name: str): - pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) @@ -53,40 +46,39 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): torch.save(pixel_values, buffer_tiff) buffer_tiff.seek(0) binary_data = buffer_tiff.read() - base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8") buffer_coord = io.BytesIO() torch.save(location_coords, buffer_coord) buffer_coord.seek(0) binary_data = buffer_coord.read() - base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") prompt = { - "model": - model_name, - "additional_data": { - "prompt_token_ids": [1] - }, - "encoding_format": - "base64", - "messages": [{ - "role": - "user", - "content": [{ - "type": "image_embeds", - "image_embeds": { - "pixel_values": base64_tensor_embedding, - "location_coords": base64_coord_embedding, - }, - }], - }] + "model": model_name, + "additional_data": {"prompt_token_ids": [1]}, + "encoding_format": "base64", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "pixel_values": base64_tensor_embedding, + "location_coords": base64_coord_embedding, + }, + } + ], + } + ], } # test single pooling response = requests.post(server.url_for("pooling"), json=prompt) response.raise_for_status() - output = response.json()["data"][0]['data'] + output = response.json()["data"][0]["data"] np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py index 0dd6af17ef227..e07436f89d2d2 100644 --- a/tests/entrypoints/openai/test_sleep.py +++ 
b/tests/entrypoints/openai/test_sleep.py @@ -20,14 +20,12 @@ def test_sleep_mode(): "--enable-sleep-mode", ] - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={ - "VLLM_SERVER_DEV_MODE": "1", - "CUDA_VISIBLE_DEVICES": "0" - }) as remote_server: - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"}, + ) as remote_server: + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) assert response.status_code == 200 @@ -40,12 +38,12 @@ def test_sleep_mode(): assert response.json().get("is_sleeping") is False # test wake up with tags - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["weights"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["weights"]} + ) assert response.status_code == 200 # is sleeping should be false after waking up any part of the engine @@ -53,8 +51,9 @@ def test_sleep_mode(): assert response.status_code == 200 assert response.json().get("is_sleeping") is True - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["kv_cache"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["kv_cache"]} + ) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 058e96f203c38..80b7cd9f4cbc9 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -11,7 +11,10 @@ import torch.cuda from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model) + TensorizerConfig, + tensorize_lora_adapter, + tensorize_vllm_model, +) from ...utils import RemoteOpenAIServer @@ -29,21 +32,20 @@ def cleanup(): _cleanup() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tmp_dir(): with tempfile.TemporaryDirectory() as path: yield path -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def model_uri(tmp_dir): yield f"{tmp_dir}/model.tensors" @pytest.fixture(scope="module") def tensorize_model_and_lora(tmp_dir, model_uri): - tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, - lora_dir=tmp_dir) + tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir) args = EngineArgs(model=MODEL_NAME) tensorize_lora_adapter(LORA_PATH, tensorizer_config) @@ -66,8 +68,11 @@ def server(model_uri, tensorize_model_and_lora): ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--served-model-name", MODEL_NAME, - "--enable-lora" + "--load-format", + "tensorizer", + "--served-model-name", + MODEL_NAME, + "--enable-lora", ] model_dir = os.path.dirname(model_uri) @@ -85,10 +90,9 @@ async def client(server): @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): _cleanup() - completion = await client.completions.create(model=model_name, - 
prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -97,4 +101,5 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert len(completion.choices[0].text) >= 5 assert completion.choices[0].finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py new file mode 100644 index 0000000000000..25eb5882be89c --- /dev/null +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import tempfile + +import pytest + +from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b") + + +@pytest.fixture(scope="module") +def server(): + global MODEL_PATH + MODEL_PATH = download_weights_from_hf( + MODEL_NAME, + allow_patterns=["*"], + cache_dir=MODEL_PATH, + ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"], + ) + args = [ + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + "--skip-tokenizer-init", + "--load-format", + "dummy", + ] + with RemoteOpenAIServer(MODEL_PATH, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +async def test_token_in_token_out_and_logprobs(server): + """ + Test token-in-token-out and token_ids align with prompt_logprobs + & logprobs when return_tokens_as_token_ids is enabled. + """ + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + text = "Hello, world! How are you today?" 
+    token_ids = tokenizer.encode(text)
+    async with server.get_async_client() as client:
+        # Test with return_token_ids enabled in extra_body
+        completion = await client.completions.create(
+            model=MODEL_PATH,
+            prompt=token_ids,
+            max_tokens=20,
+            temperature=0,
+            echo=True,
+            extra_body={
+                "return_token_ids": True,
+            },
+        )
+
+        # Verify all fields are present
+        assert (
+            completion.choices[0].token_ids is not None
+            and 0 < len(completion.choices[0].token_ids) <= 20
+        )
+        assert completion.choices[0].prompt_token_ids is not None
+
+        # Decode prompt tokens
+        if completion.choices[0].prompt_token_ids:
+            prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids)
+            # The decoded prompt should match the original prompt text
+            assert prompt_text == text
diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py
index 0dbbdfbfd24ad..7fd32e1c7be1d 100644
--- a/tests/entrypoints/openai/test_tokenization.py
+++ b/tests/entrypoints/openai/test_tokenization.py
@@ -8,15 +8,13 @@ import requests
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 from ...utils import RemoteOpenAIServer
-from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
-from .test_completion import zephyr_lora_files  # noqa: F401
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 
 
 @pytest.fixture(scope="module")
-def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
+def server():
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -26,12 +24,6 @@ def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
         "--enforce-eager",
         "--max-num-seqs",
         "128",
-        # lora config
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
-        "--max-lora-rank",
-        "64",
         "--enable-tokenizer-info-endpoint",
     ]
 
@@ -40,10 +32,8 @@ def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
 
 
 @pytest.fixture(scope="module")
-def tokenizer_name(model_name: str,
-                   zephyr_lora_added_tokens_files: str):  # noqa: F811
-    return zephyr_lora_added_tokens_files if (
-        model_name == "zephyr-lora2") else model_name
+def tokenizer_name(model_name: str):
+    return model_name
 
 
 @pytest_asyncio.fixture
@@ -55,7 +45,7 @@ async def client(server):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name,tokenizer_name",
-    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    [(MODEL_NAME, MODEL_NAME)],
    indirect=["tokenizer_name"],
 )
 async def test_tokenize_completions(
@@ -63,19 +53,20 @@ async def test_tokenize_completions(
     model_name: str,
     tokenizer_name: str,
 ):
-    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
-                              tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
 
     for add_special in [False, True]:
         prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt, + }, + ) response.raise_for_status() result = response.json() @@ -88,7 +79,7 @@ async def test_tokenize_completions( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_chat( @@ -96,48 +87,39 @@ async def test_tokenize_chat( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": "user", - "content": "Hi there!" - }, { - "role": "assistant", - "content": "Nice to meet you!" - }, { - "role": "user", - "content": "Can I ask a question? vllm1" - }] + conversation = [ + {"role": "user", "content": "Hi there!"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "Can I ask a question? vllm1"}, + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, continue_final_message=continue_final, conversation=conversation, - tokenize=False) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) + tokenize=False, + ) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_generation_prompt": - add_generation, - "continue_final_message": - continue_final, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + }, + ) response.raise_for_status() result = response.json() @@ -150,7 +132,7 @@ async def test_tokenize_chat( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_chat_with_tools( @@ -158,41 +140,35 @@ async def test_tokenize_chat_with_tools( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": - "user", - "content": - "What's the weather like in Paris today?", - }] + conversation = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string" - } + 
tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, }, }, - }, - }] + } + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, @@ -201,8 +177,7 @@ async def test_tokenize_chat_with_tools( tools=tools, tokenize=False, ) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) response = requests.post( server.url_for("tokenize"), @@ -227,7 +202,7 @@ async def test_tokenize_chat_with_tools( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name, tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_with_return_token_strs( @@ -235,17 +210,12 @@ async def test_tokenize_with_return_token_strs( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a token_strs test prompt! vllm1" response = requests.post( server.url_for("tokenize"), - json={ - "prompt": prompt, - "model": model_name, - "return_token_strs": True - }, + json={"prompt": prompt, "model": model_name, "return_token_strs": True}, ) response.raise_for_status() @@ -262,7 +232,7 @@ async def test_tokenize_with_return_token_strs( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_detokenize( @@ -270,17 +240,14 @@ async def test_detokenize( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a test prompt. 
vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post(server.url_for("detokenize"), - json={ - "model": model_name, - "tokens": tokens - }) + response = requests.post( + server.url_for("detokenize"), json={"model": model_name, "tokens": tokens} + ) response.raise_for_status() assert response.json() == {"prompt": prompt} @@ -289,7 +256,7 @@ async def test_detokenize( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenizer_info_basic( @@ -329,14 +296,15 @@ async def test_tokenizer_info_schema(server: RemoteOpenAIServer): } for field, expected_type in field_types.items(): if field in result and result[field] is not None: - assert isinstance( - result[field], - expected_type), (f"{field} should be {expected_type.__name__}") + assert isinstance(result[field], expected_type), ( + f"{field} should be {expected_type.__name__}" + ) @pytest.mark.asyncio async def test_tokenizer_info_added_tokens_structure( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test added_tokens_decoder structure if present.""" response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() @@ -347,25 +315,23 @@ async def test_tokenizer_info_added_tokens_structure( assert isinstance(token_id, str), "Token IDs should be strings" assert isinstance(token_info, dict), "Token info should be a dict" assert "content" in token_info, "Token info should have content" - assert "special" in token_info, ( - "Token info should have special flag") - assert isinstance(token_info["special"], - bool), ("Special flag should be boolean") + assert "special" in token_info, "Token info should have special flag" + assert isinstance(token_info["special"], bool), ( + "Special flag should be boolean" + ) @pytest.mark.asyncio async def test_tokenizer_info_consistency_with_tokenize( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test that tokenizer info is consistent with tokenization endpoint.""" info_response = requests.get(server.url_for("tokenizer_info")) info_response.raise_for_status() info = info_response.json() tokenize_response = requests.post( server.url_for("tokenize"), - json={ - "model": MODEL_NAME, - "prompt": "Hello world!" 
- }, + json={"model": MODEL_NAME, "prompt": "Hello world!"}, ) tokenize_response.raise_for_status() tokenize_result = tokenize_response.json() @@ -373,7 +339,8 @@ async def test_tokenizer_info_consistency_with_tokenize( tokenize_max_len = tokenize_result.get("max_model_len") if info_max_len and tokenize_max_len: assert info_max_len >= tokenize_max_len, ( - "Info max length should be >= tokenize max length") + "Info max length should be >= tokenize max length" + ) @pytest.mark.asyncio @@ -384,6 +351,5 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): result = response.json() chat_template = result.get("chat_template") if chat_template: - assert isinstance(chat_template, - str), ("Chat template should be a string") - assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file + assert isinstance(chat_template, str), "Chat template should be a string" + assert chat_template.strip(), "Chat template should not be empty" diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 93239f41a4aeb..6ef932392d095 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests +# imports for structured outputs tests import io import json @@ -12,32 +12,20 @@ import pytest import pytest_asyncio import soundfile as sf -from vllm.assets.audio import AudioAsset - from ...utils import RemoteOpenAIServer MODEL_NAME = "openai/whisper-large-v3-turbo" SERVER_ARGS = ["--enforce-eager"] MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] -@pytest.fixture -def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_local_path() - with open(str(path), "rb") as f: - yield f - - -@pytest.fixture -def winning_call(): - path = AudioAsset('winning_call').get_local_path() - with open(str(path), "rb") as f: - yield f - - @pytest.fixture(scope="module") def server(): with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: @@ -52,8 +40,8 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]) + "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"] +) async def test_basic_audio(mary_had_lamb, model_name): server_args = ["--enforce-eager"] @@ -68,9 +56,33 @@ async def test_basic_audio(mary_had_lamb, model_name): file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - assert "Mary had a little lamb," in out + temperature=0.0, + ) + out = json.loads(transcription) + out_text = out["text"] + out_usage = out["usage"] + assert "Mary had a little lamb," in out_text + assert out_usage["seconds"] == 16, out_usage["seconds"] + + +@pytest.mark.asyncio +async def test_basic_audio_gemma(foscolo): + # Gemma accuracy on some of the audio samples we use is particularly bad, + # hence we use a different one here. WER is evaluated separately. 
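+    # This sample is transcribed in Italian (language="it") and checked
+    # against an expected Italian phrase.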
+ model_name = "google/gemma-3n-E2B-it" + server_args = ["--enforce-eager"] + + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=foscolo, + language="it", + response_format="text", + temperature=0.0, + ) + out = json.loads(transcription)["text"] + assert "da cui vergine nacque Venere" in out @pytest.mark.asyncio @@ -79,24 +91,21 @@ async def test_non_asr_model(winning_call): model_name = "JackFram/llama-68m" with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() - res = await client.audio.transcriptions.create(model=model_name, - file=winning_call, - language="en", - temperature=0.0) + res = await client.audio.transcriptions.create( + model=model_name, file=winning_call, language="en", temperature=0.0 + ) err = res.error assert err["code"] == 400 and not res.text - assert err[ - "message"] == "The model does not support Transcriptions API" + assert err["message"] == "The model does not support Transcriptions API" @pytest.mark.asyncio async def test_bad_requests(mary_had_lamb, client): # invalid language with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=MODEL_NAME, - file=mary_had_lamb, - language="hh", - temperature=0.0) + await client.audio.transcriptions.create( + model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0 + ) @pytest.mark.asyncio @@ -108,17 +117,21 @@ async def test_long_audio_request(mary_had_lamb, client): repeated_audio = np.tile(audio, 10) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=buffer, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] - counts = out.count("Mary had a little lamb") + temperature=0.0, + ) + out = json.loads(transcription) + out_text = out["text"] + out_usage = out["usage"] + counts = out_text.count("Mary had a little lamb") assert counts == 10, counts + assert out_usage["seconds"] == 161, out_usage["seconds"] @pytest.mark.asyncio @@ -126,10 +139,8 @@ async def test_completion_endpoints(client): # text to text model res = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." 
- }]) + messages=[{"role": "system", "content": "You are a helpful assistant."}], + ) err = res.error assert err["code"] == 400 assert err["message"] == "The model does not support Chat Completions API" @@ -148,16 +159,19 @@ async def test_streaming_response(winning_call, client): file=winning_call, response_format="json", language="en", - temperature=0.0) - res = await client.audio.transcriptions.create(model=MODEL_NAME, - file=winning_call, - language="en", - temperature=0.0, - stream=True, - timeout=30) + temperature=0.0, + ) + res = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30, + ) # Reconstruct from chunks and validate async for chunk in res: - text = chunk.choices[0]['delta']['content'] + text = chunk.choices[0]["delta"]["content"] transcription += text assert transcription == res_no_stream.text @@ -171,9 +185,9 @@ async def test_stream_options(winning_call, client): language="en", temperature=0.0, stream=True, - extra_body=dict(stream_include_usage=True, - stream_continuous_usage_stats=True), - timeout=30) + extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True), + timeout=30, + ) final = False continuous = True async for chunk in res: @@ -181,7 +195,7 @@ async def test_stream_options(winning_call, client): # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') + continuous = continuous and hasattr(chunk, "usage") assert final and continuous @@ -189,27 +203,31 @@ async def test_stream_options(winning_call, client): async def test_sampling_params(mary_had_lamb, client): """ Compare sampling with params and greedy sampling to assert results - are different when extreme sampling parameters values are picked. + are different when extreme sampling parameters values are picked. """ transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", temperature=0.8, - extra_body=dict(seed=42, - repetition_penalty=1.9, - top_k=12, - top_p=0.4, - min_p=0.5, - frequency_penalty=1.8, - presence_penalty=2.0)) + extra_body=dict( + seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0, + ), + ) greedy_transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", temperature=0.0, - extra_body=dict(seed=42)) + extra_body=dict(seed=42), + ) assert greedy_transcription.text != transcription.text @@ -217,15 +235,16 @@ async def test_sampling_params(mary_had_lamb, client): @pytest.mark.asyncio async def test_audio_prompt(mary_had_lamb, client): prompt = "This is a speech, recorded in a phonograph." - #Prompts should not omit the part of original prompt while transcribing. + # Prompts should not omit the part of original prompt while transcribing. 
prefix = "The first words I spoke in the original phonograph" transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert prefix in out transcription_wprompt = await client.audio.transcriptions.create( model=MODEL_NAME, @@ -233,6 +252,7 @@ async def test_audio_prompt(mary_had_lamb, client): language="en", response_format="text", prompt=prompt, - temperature=0.0) - out_prompt = json.loads(transcription_wprompt)['text'] + temperature=0.0, + ) + out_prompt = json.loads(transcription_wprompt)["text"] assert prefix in out_prompt diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index f4f5c66f2deeb..f35742e166fe0 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io -# imports for guided decoding tests + +# imports for structured outputs tests import json import httpx @@ -12,32 +13,25 @@ import pytest import pytest_asyncio import soundfile as sf -from vllm.assets.audio import AudioAsset - from ...utils import RemoteOpenAIServer -MODEL_NAME = "openai/whisper-small" SERVER_ARGS = ["--enforce-eager"] -@pytest.fixture -def foscolo(): - # Test translation it->en - path = AudioAsset('azacinto_foscolo').get_local_path() - with open(str(path), "rb") as f: - yield f - - -@pytest.fixture(scope="module") -def server(): - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: - yield remote_server +@pytest.fixture( + scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"] +) +def server(request): + # Parametrize over model name + with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server: + yield remote_server, request.param @pytest_asyncio.fixture -async def client(server): +async def client_and_model(server): + server, model_name = server async with server.get_async_client() as async_client: - yield async_client + yield async_client, model_name @pytest.mark.asyncio @@ -46,9 +40,9 @@ async def test_non_asr_model(foscolo): model_name = "JackFram/llama-68m" with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() - res = await client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0) + res = await client.audio.translations.create( + model=model_name, file=foscolo, temperature=0.0 + ) err = res.error assert err["code"] == 400 and not res.text assert err["message"] == "The model does not support Translations API" @@ -56,81 +50,98 @@ async def test_non_asr_model(foscolo): # NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! 
@pytest.mark.asyncio -async def test_basic_audio(foscolo, client): +async def test_basic_audio(foscolo, client_and_model): + client, model_name = client_and_model translation = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, response_format="text", - # TODO remove once language detection is implemented - extra_body=dict(language="it"), - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + # TODO remove `language="it"` once language detection is implemented + extra_body=dict(language="it", to_language="en"), + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert "greek sea" in out @pytest.mark.asyncio -async def test_audio_prompt(foscolo, client): +async def test_audio_prompt(foscolo, client_and_model): + client, model_name = client_and_model # Condition whisper on starting text prompt = "Nor have I ever" transcription = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, prompt=prompt, - extra_body=dict(language="it"), + extra_body=dict(language="it", to_language="en"), response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert "Nor will I ever touch the sacred" not in out assert prompt not in out @pytest.mark.asyncio -async def test_streaming_response(foscolo, client, server): +async def test_streaming_response(foscolo, client_and_model, server): + client, model_name = client_and_model translation = "" res_no_stream = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, response_format="json", - extra_body=dict(language="it"), - temperature=0.0) + extra_body=dict(language="it", to_language="en", seed=42), + temperature=0.0, + ) + # Stream via HTTPX since OpenAI translation client doesn't expose streaming + server, model_name = server url = server.url_for("v1/audio/translations") headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} data = { - "model": MODEL_NAME, + "model": model_name, "language": "it", + "to_language": "en", "stream": True, "temperature": 0.0, + "seed": 42, } foscolo.seek(0) async with httpx.AsyncClient() as http_client: files = {"file": foscolo} - async with http_client.stream("POST", - url, - headers=headers, - data=data, - files=files) as response: + async with http_client.stream( + "POST", url, headers=headers, data=data, files=files + ) as response: async for line in response.aiter_lines(): if not line: continue if line.startswith("data: "): - line = line[len("data: "):] + line = line[len("data: ") :] if line.strip() == "[DONE]": break chunk = json.loads(line) text = chunk["choices"][0].get("delta", {}).get("content") translation += text or "" - assert translation == res_no_stream.text + res_stream = translation.split() + # NOTE There's a small non-deterministic issue here, likely in the attn + # computation, which will cause a few tokens to be different, while still + # being very close semantically. 
+ assert ( + sum([x == y for x, y in zip(res_stream, res_no_stream.text.split())]) + >= len(res_stream) * 0.9 + ) @pytest.mark.asyncio -async def test_stream_options(foscolo, client, server): +async def test_stream_options(foscolo, server): + server, model_name = server url = server.url_for("v1/audio/translations") headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} data = { - "model": MODEL_NAME, + "model": model_name, "language": "it", + "to_language": "en", "stream": True, "stream_include_usage": True, "stream_continuous_usage_stats": True, @@ -141,16 +152,14 @@ async def test_stream_options(foscolo, client, server): continuous = True async with httpx.AsyncClient() as http_client: files = {"file": foscolo} - async with http_client.stream("POST", - url, - headers=headers, - data=data, - files=files) as response: + async with http_client.stream( + "POST", url, headers=headers, data=data, files=files + ) as response: async for line in response.aiter_lines(): if not line: continue if line.startswith("data: "): - line = line[len("data: "):] + line = line[len("data: ") :] if line.strip() == "[DONE]": break chunk = json.loads(line) @@ -164,19 +173,23 @@ async def test_stream_options(foscolo, client, server): @pytest.mark.asyncio -async def test_long_audio_request(foscolo, client): +async def test_long_audio_request(foscolo, client_and_model): + client, model_name = client_and_model + if model_name == "google/gemma-3n-E2B-it": + pytest.skip("Gemma3n does not support long audio requests") foscolo.seek(0) audio, sr = librosa.load(foscolo) repeated_audio = np.tile(audio, 2) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) translation = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=buffer, - extra_body=dict(language="it"), + extra_body=dict(language="it", to_language="en"), response_format="text", - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index ad4dff00daaa4..4c7d1c14ca17b 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -58,24 +58,18 @@ def base64_encoded_video() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] +async def test_single_chat_session_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -84,13 +78,15 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -112,54 +108,44 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": video_url - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_error_on_invalid_video_url_type( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": video_url}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # video_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] +async def test_single_chat_session_video_beamsearch( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -168,36 +154,38 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -206,13 +194,15 @@ async def test_single_chat_session_video_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -236,58 +226,54 @@ async def test_single_chat_session_video_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_chat_streaming_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_chat_streaming_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -327,27 +313,23 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "video_urls", - [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]) -async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str, - video_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "video_url", - "video_url": { - "url": video_url - } - } for video_url in video_urls), - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + "video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))] +) +async def test_multi_video_input( + client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "video_url", "video_url": {"url": video_url}} + for video_url in video_urls + ), + {"type": "text", "text": "What's in this video?"}, + ], + } + ] if len(video_urls) > MAXIMUM_VIDEOS: with pytest.raises(openai.BadRequestError): # test multi-video input diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 8259a81d7b6a1..5a15a352f45cc 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,8 +6,6 @@ import json import openai import pytest import pytest_asyncio -import requests -from PIL import Image from transformers import AutoProcessor from vllm.multimodal.utils import encode_image_base64, fetch_image @@ -18,11 +16,11 @@ MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] EXPECTED_MM_BEAM_SEARCH_RES = [ @@ -36,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [ ], [ "The image shows a Venn diagram with three over", - "The image shows a Venn diagram with three intersect", + "This image shows a Venn diagram with three over", ], [ "This image displays a gradient of colors ranging from", - "The image displays a gradient of colors ranging from", + "This image displays a gradient of colors forming a spectrum", ], ] @@ -71,27 +69,32 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: encode_image_base64( + local_asset_server.get_image_asset(image_asset) + ) + for image_asset in TEST_IMAGE_ASSETS } def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|>\n" - messages = [{ - "role": "user", - "content": f"{placeholder}{content}", - }] - images = [Image.open(requests.get(image_url, stream=True).raw)] + messages = [ + { + 
"role": "user", + "content": f"{placeholder}{content}", + } + ] + images = [fetch_image(image_url)] prompt = processor.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True) + messages, tokenize=False, add_generation_prompt=True + ) inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] @@ -99,26 +102,20 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_single_chat_session_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_single_chat_session_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -128,17 +125,18 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -159,56 +157,46 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_error_on_invalid_image_url_type( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" 
- messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": image_url - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": image_url}, + {"type": "text", "text": content_text}, + ], + } + ] # image_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_single_chat_session_image_beamsearch( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -217,37 +205,41 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_base64encoded( - client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: dict[str, str]): - + client: openai.AsyncOpenAI, + model_name: str, + raw_image_url: str, + image_url: str, + base64_encoded_image: dict[str, str], +): content_text = "What's in this image?" 
- messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -257,17 +249,18 @@ async def test_single_chat_session_image_base64encoded( max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -289,38 +282,39 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS)))) +@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_ASSETS)))) async def test_single_chat_session_image_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, image_idx: int, - base64_encoded_image: dict[str, str]): + client: openai.AsyncOpenAI, + model_name: str, + image_idx: int, + base64_encoded_image: dict[str, str], +): # NOTE: This test also validates that we pass MM data through beam search - image_url = TEST_IMAGE_URLS[image_idx] + raw_image_url = TEST_IMAGE_ASSETS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, temperature=0.0, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 for actual, expected_str in zip(chat_completion.choices, expected_res): assert actual.message.content == expected_str @@ -328,25 +322,19 @@ async def test_single_chat_session_image_base64encoded_beamsearch( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_chat_streaming_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_chat_streaming_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -387,26 +375,24 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) -async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True, +) +async def test_multi_image_input( + client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input @@ -435,3 +421,175 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True, +) +async def test_completions_with_image( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True, +) +async def test_completions_with_image_with_uuid( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_url, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert 
len(chat_completion.choices[0].message.content) > 0 + + # Second request, with empty image but the same uuid. + chat_completion_with_empty_image = await client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + {"type": "image_url", "image_url": {}, "uuid": image_url}, + ], + }, + ], + model=model_name, + ) + assert chat_completion_with_empty_image.choices[0].message.content is not None + assert isinstance( + chat_completion_with_empty_image.choices[0].message.content, str + ) + assert len(chat_completion_with_empty_image.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_empty_image_with_uuid_without_cache_hit( + client: openai.AsyncOpenAI, + model_name: str, +): + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": {}, + "uuid": "uuid_not_previously_seen", + }, + ], + }, + ], + model=model_name, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True, +) +async def test_completions_with_image_with_incorrect_uuid_format( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + "incorrect_uuid_key": image_url, + }, + "also_incorrect_uuid_key": image_url, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 28b1f8358d80b..38008dafe32b2 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -5,6 +5,10 @@ import json import pytest +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +from vllm.transformers_utils.tokenizer import AnyTokenizer + from ....utils import RemoteOpenAIServer MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -18,33 +22,69 @@ SERVER_ARGS = [ "--enable-lora", "--lora-modules", f"{LORA_MODEL}={LORA_MODEL}", + "--tokenizer", + f"{LORA_MODEL}", ] -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": - "The city and state, e.g. 
San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, + "required": ["location"], }, - "required": ["location"], }, - }, -}] + } +] + +PRODUCT_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_product_info", + "description": "Get detailed information of a product based on its " + "product ID.", + "parameters": { + "type": "object", + "properties": { + "inserted": { + "type": "boolean", + "description": "inserted.", + }, + "product_id": { + "type": "integer", + "description": "The product ID of the product.", + }, + }, + "required": ["product_id", "inserted"], + }, + }, + } +] MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] +PRODUCT_MESSAGES = [ + { + "role": "user", + "content": "Hi! Do you have any detailed information about the product id " + "7355608 and inserted true?", + } +] + @pytest.mark.asyncio async def test_non_streaming_tool_call(): @@ -111,8 +151,9 @@ async def test_streaming_tool_call(): if tool_chunk.function.name: tool_call_chunks[index]["name"] += tool_chunk.function.name if tool_chunk.function.arguments: - tool_call_chunks[index][ - "arguments"] += tool_chunk.function.arguments + tool_call_chunks[index]["arguments"] += ( + tool_chunk.function.arguments + ) assert len(tool_call_chunks) == 1 reconstructed_tool_call = tool_call_chunks[0] @@ -125,3 +166,295 @@ async def test_streaming_tool_call(): print("\n[Streaming Test Passed]") print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") print(f"Reconstructed Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_non_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_product_info" + + arguments = json.loads(tool_call.function.arguments) + assert "product_id" in arguments + assert "inserted" in arguments + + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Non-Streaming Product Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + 
temperature=0.66, + stream=True, + ) + + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index]["arguments"] += ( + tool_chunk.function.arguments + ) + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_product_info" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "product_id" in arguments + assert "inserted" in arguments + + # Handle type coercion for streaming test as well + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Streaming Product Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") + + +@pytest.fixture +def qwen_tokenizer() -> AnyTokenizer: + from vllm.transformers_utils.tokenizer import get_tokenizer + + return get_tokenizer("Qwen/Qwen3-32B") + + +@pytest.fixture +def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser: + return Hermes2ProToolParser(qwen_tokenizer) + + +@pytest.fixture +def any_chat_request() -> ChatCompletionRequest: + return ChatCompletionRequest( + seed=42, + model="Qwen/Qwen3-32B", + messages=[], + ) + + +def test_hermes_parser_streaming_just_forward_text( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """This is some prior text that has nothing to do with tool calling.""" + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + delta_text = qwen_tokenizer.decode([token]) + current_text = previous_text + delta_text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + delta_messages.append(delta) + + for delta in delta_messages: + assert delta is not None + assert not delta.tool_calls + + print(delta_messages) + assert "".join([delta.content for delta in delta_messages]) == text + + +def test_hermes_parser_streaming_failure_case_bug_19056( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """<tool_call> +{"name": "final_answer", "arguments": {"trigger": true}} +</tool_call>""" + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + text = qwen_tokenizer.decode([token]) + current_text = previous_text + text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + if delta is not None: + delta_messages.append(delta) + + assert 
delta_messages[0].tool_calls[0].function.name == "final_answer" + tool_call_args = "".join( + delta.tool_calls[0].function.arguments or "" for delta in delta_messages + ) + assert tool_call_args == '{"trigger": true}' + + +def test_hermes_parser_streaming( + qwen_tokenizer: AnyTokenizer, + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = '<tool_call>\ +{"name": "get_current_temperature",\ +"arguments": {"location":\ +"San Francisco, California, United States", "unit": "celsius"}}\ +</tool_call>' + + tokens = qwen_tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for token in tokens: + text = qwen_tokenizer.decode([token]) + current_text = previous_text + text + delta = hermes_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + previous_text = current_text + if delta is not None: + delta_messages.append(delta) + print(delta_messages) + assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature" + tool_call_args = "".join( + delta.tool_calls[0].function.arguments or "" for delta in delta_messages + ) + assert tool_call_args == ( + '{"location":"San Francisco, California, United States", "unit": "celsius"}' + ) + + +def test_hermes_parser_non_streaming_no_tool_call( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """This is not a tool call.""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert not tool_call.tools_called + + +def test_hermes_parser_non_streaming_tool_call_between_tags( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """<tool_call> +{"name": "final_answer", "arguments": {"trigger": true}} +</tool_call>""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert tool_call.tools_called + assert tool_call.tool_calls[0].function.name == "final_answer" + assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}' + + +def test_hermes_parser_non_streaming_tool_call_until_eos( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + text = """<tool_call> +{"name": "final_answer", "arguments": {"trigger": true}}""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert tool_call.tools_called + assert tool_call.tool_calls[0].function.name == "final_answer" + assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}' + + +def test_hermes_parser_non_streaming_tool_call_invalid_json( + hermes_parser: Hermes2ProToolParser, + any_chat_request: ChatCompletionRequest, +) -> None: + # Missing closing brace to trigger exception + text = """<tool_call> +{"name": "final_answer", "arguments": {"trigger": true}""" + tool_call = hermes_parser.extract_tool_calls( + model_output=text, + request=any_chat_request, + ) + + assert tool_call is not None + assert not tool_call.tools_called diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bd8e06513e13e..bdd5344652c4b 100644 --- 
a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -8,15 +8,18 @@ from unittest.mock import MagicMock import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager def make_tool_call(name, arguments): - return ToolCall(type="function", - function=FunctionCall(name=name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=name, arguments=json.dumps(arguments)), + ) # TODO: add reason prefix and suffix. @@ -29,70 +32,68 @@ def make_tool_call(name, arguments): ("How can I help you today?", [], "How can I help you today?"), # Single tool call, no content ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501 - [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) - ], - None), - # Multiple tool calls - ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501 - [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }), - make_tool_call( - "register_user", { - "name": "John Doe", - "age": 37, - "address": { - "city": "San Francisco", - "state": "CA" - }, - "role": None, - "passed_test": True, - "aliases": ["John", "Johnny"] - }) - ], - None), - # Content before tool call - ( - "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501 - [make_tool_call("get_weather", {"city": "Boston"})], - "I will call the tool now. 
"), - # Content after tool call (should be stripped) - ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501 - [make_tool_call("get_weather", {"city": "Seattle"})], - None), - ( - "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>", + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501 [ make_tool_call( - "complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) ], None, ), - ]) -def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, - expected_content): + # Multiple tool calls + ( + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501 + [ + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ), + make_tool_call( + "register_user", + { + "name": "John Doe", + "age": 37, + "address": {"city": "San Francisco", "state": "CA"}, + "role": None, + "passed_test": True, + "aliases": ["John", "Johnny"], + }, + ), + ], + None, + ), + # Content before tool call + ( + 'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501 + [make_tool_call("get_weather", {"city": "Boston"})], + "I will call the tool now. ", + ), + # Content after tool call (should be stripped) + ( + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501 + [make_tool_call("get_weather", {"city": "Seattle"})], + None, + ), + ( + '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>', + [ + make_tool_call( + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) + ], + None, + ), + ], +) +def test_hunyuan_a13b_tool_parser_extract( + model_output, expected_tool_calls, expected_content +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=False) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=False + ) # align the random id. 
for idx in range(len(tool_calls)): @@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, # Streaming test: simulate incremental output -@pytest.mark.parametrize("model_deltas,expected_tool_calls", [ - ([ - "<tool_calls>[{\"name\": \"get_weather\", ", - "\"arguments\": {\"city\": \"San Francisco\", ", - "\"metric\": \"celsius\"}}]", "</tool_calls>" - ], [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) - ]), - ([ - "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "</tool_calls>" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - ([ - "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - pytest.param([ - "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ", - " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}", - "]</tool_calls>" - ], [ - make_tool_call("complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) +@pytest.mark.parametrize( + "model_deltas,expected_tool_calls", + [ + ( + [ + '<tool_calls>[{"name": "get_weather", ', + '"arguments": {"city": "San Francisco", ', + '"metric": "celsius"}}]', + "</tool_calls>", + ], + [ + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) + ], + ), + ( + [ + '<tool_calls>[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "</tool_calls>", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + ( + [ + "", + '<tool_calls>[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "</tool_calls>", + "\n</answer>", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + pytest.param( + [ + '<tool_calls>[{"name": "complex_tool",', + ' "arguments": ', + ' {"level1": {"level2": ', + '{"level3": {"value": 123}}}}}', + "]</tool_calls>", + ], + [ + make_tool_call( + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) + ], + marks=pytest.mark.xfail( + reason="stream parsing not support nested json yet." + ), + ), ], - marks=pytest.mark.xfail( - reason="stream parsing not support nested json yet.")), -]) +) def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) reconstructor = run_tool_extraction_streaming( - tool_parser, model_deltas, assert_one_tool_per_delta=False) + tool_parser, model_deltas, assert_one_tool_per_delta=False + ) # align the random id. 
for idx in range(len(reconstructor.tool_calls)): diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py index 09726c7e3e5b5..c7a8ef83cf71d 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py @@ -5,8 +5,7 @@ import pytest from transformers import AutoTokenizer from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation -from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import ( - Llama3JsonToolParser) +from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser @pytest.fixture @@ -18,8 +17,10 @@ def parser(): def test_extract_tool_calls_simple(parser): # Test with a simple tool call - model_output = ('Here is the result: {"name": "getOpenIncidentsTool", ' - '"parameters": {}} Would you like to know more?') + model_output = ( + 'Here is the result: {"name": "getOpenIncidentsTool", ' + '"parameters": {}} Would you like to know more?' + ) result = parser.extract_tool_calls(model_output, None) assert isinstance(result, ExtractedToolCallInformation) @@ -34,8 +35,8 @@ def test_extract_tool_calls_simple(parser): def test_extract_tool_calls_with_arguments(parser): # Test with a tool call that has arguments model_output = ( - '{"name": "searchTool", "parameters": {"query": "test query", ' - '"limit": 10}}') + '{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -81,7 +82,8 @@ def test_extract_tool_calls_multiple_json(parser): model_output = ( '{"name": "searchTool", "parameters": {"query": "test1"}}; ' '{"name": "getOpenIncidentsTool", "parameters": {}}; ' - '{"name": "searchTool", "parameters": {"query": "test2"}}') + '{"name": "searchTool", "parameters": {"query": "test2"}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -105,7 +107,8 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser): model_output = ( '{"name": "searchTool", "parameters": {"query": "test1"}} ; ' '{"name": "getOpenIncidentsTool", "parameters": {}} ; ' - '{"name": "searchTool", "parameters": {"query": "test2"}}') + '{"name": "searchTool", "parameters": {"query": "test2"}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -118,11 +121,12 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser): def test_extract_tool_calls_multiple_json_with_surrounding_text(parser): # Test with multiple JSONs and surrounding text model_output = ( - 'Here are the results: ' + "Here are the results: " '{"name": "searchTool", "parameters": {"query": "test1"}}; ' '{"name": "getOpenIncidentsTool", "parameters": {}}; ' '{"name": "searchTool", "parameters": {"query": "test2"}} ' - 'Would you like to know more?') + "Would you like to know more?" 
+ ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8c86b4889e15b..94277980f229f 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -16,12 +18,14 @@ SIMPLE_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "LA", "metric": "C"}', ) -MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " - "age=9, " - "address={'city': 'LA', 'state': 'CA'}, " - "role=None, " - "passed_test=True, " - "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_OUTPUT = ( + "[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "Doe", ' @@ -34,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall( PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -47,25 +51,28 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall( arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', ) PYTHON_TAG_FUNCTION_OUTPUT = ( - "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>" +) @pytest.mark.parametrize("streaming", [True, False]) def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output = "How can I help you today?" 
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 @@ -75,98 +82,139 @@ test_str = "<|python_start|>" test_str += "[get_weather(city='LA', metric='C')," test_str += "register_user(name='Doe', age=9)]" TEST_CASES = [ - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming" + ), + pytest.param( + True, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming", + ), pytest.param( True, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - 
id="parallel_calls_streaming"), + id="parallel_calls_streaming", + ), pytest.param( False, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - id="parallel_calls_nonstreaming"), - pytest.param(True, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_streaming"), - pytest.param(False, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_nonstreaming"), - pytest.param(True, - test_str, [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_streaming"), - pytest.param(False, - "<|python_start|>[get_weather(city='LA', metric='C'), " + - "register_user(name='Doe', age=9)]", [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_nonstreaming"), + id="parallel_calls_nonstreaming", + ), + pytest.param( + True, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming", + ), + pytest.param( + False, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming", + ), + pytest.param( + True, + test_str, + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_streaming", + ), + pytest.param( + False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert len(tool_calls) == len(expected_tool_calls) for actual, expected in zip(tool_calls, expected_tool_calls): @@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output_deltas = [ "<|python_start|>[get_weather(city='LA', metric='C'), " "get_weather(), " @@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps(): 
def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index d83137472598e..d7b4051ea572a 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -22,7 +24,8 @@ MORE_TYPES_FUNCTION_OUTPUT = ( "address={'city': 'San Francisco', 'state': 'CA'}, " "role=None, " "passed_test=True, " - "aliases=['John', 'Johnny'])") + "aliases=['John', 'Johnny'])" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "John Doe", ' @@ -35,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall( PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -48,7 +51,8 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall( arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')") + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', @@ -59,80 +63,118 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall( def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output = "How can I help you today?" 
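For reference, the fixtures above pair each pythonic tool-call string with the JSON `arguments` payload the parser is expected to produce. That mapping can be sketched with the standard-library `ast` module; this is only an illustration of the expected translation (the helper name here is invented), not the parser implementation under test:

    import ast
    import json

    def pythonic_call_to_function_call(call: ast.Call) -> dict:
        # e.g. get_weather(city='LA', metric='C')
        #   -> {"name": "get_weather", "arguments": '{"city": "LA", "metric": "C"}'}
        arguments = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        return {"name": call.func.id, "arguments": json.dumps(arguments)}

    model_output = "[get_weather(city='LA', metric='C'), register_user(name='Doe', age=9)]"
    calls = ast.parse(model_output, mode="eval").body.elts
    print([pythonic_call_to_function_call(c) for c in calls])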
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 TEST_CASES = [ - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_nonstreaming"), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming", + ), + pytest.param( + True, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], 
+ id="escaped_string_nonstreaming", + ), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content is None assert len(tool_calls) == len(expected_tool_calls) @@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output_deltas = [ "[get_weather(city='San", " Francisco', metric='celsius'), " @@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index e1b41f45f5548..cfa4d3584e709 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -4,15 +4,17 @@ from collections.abc import Iterable from typing import Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + 
ToolCall, +) from vllm.entrypoints.openai.tool_parsers import ToolParser class StreamingToolReconstructor: - def __init__(self, assert_one_tool_per_delta: bool = True): self.tool_calls: list[ToolCall] = [] self.other_content: str = "" @@ -23,49 +25,60 @@ class StreamingToolReconstructor: self.other_content += delta.content else: assert delta.tool_calls, ( - "Streaming results should have either content or tool calls " - "(or both)") + "Streaming results should have either content or tool calls (or both)" + ) if self._assert_one_tool_per_delta: # Note: This isn't strictly required by the API and may not be # possible to adhere to depending on the token space and number of # tokens per streamed response from the model, but it is required # by tool_use tests, so we enforce it here by default also. assert len(delta.tool_calls) < 2, ( - "Streaming should include only one tool call per update.") + "Streaming should include only one tool call per update." + ) for call_delta in delta.tool_calls: assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " - f"{call_delta.type}") - current_tool_call = self.tool_calls[ - call_delta.index] if call_delta.index < len( - self.tool_calls) else None + f"{call_delta.type}" + ) + current_tool_call = ( + self.tool_calls[call_delta.index] + if call_delta.index < len(self.tool_calls) + else None + ) if current_tool_call: - assert (not call_delta.function.name), ( + assert not call_delta.function.name, ( "Streaming tool calls should emit the full function name " - f"exactly once. Got {call_delta.function.name}") - assert (not call_delta.id), ( + f"exactly once. Got {call_delta.function.name}" + ) + assert not call_delta.id, ( "Streaming tool calls must emit function id only once. Got " - f"{call_delta.id}") - assert (call_delta.index == len(self.tool_calls) - 1), ( + f"{call_delta.id}" + ) + assert call_delta.index == len(self.tool_calls) - 1, ( f"Incorrect index for tool delta. Got {call_delta.index}, " - f"expected {len(self.tool_calls) - 1}") - current_tool_call.function.arguments += ( - call_delta.function.arguments) + f"expected {len(self.tool_calls) - 1}" + ) + current_tool_call.function.arguments += call_delta.function.arguments else: assert call_delta.id is not None, ( - "Streaming tool calls must have an id on first appearance") + "Streaming tool calls must have an id on first appearance" + ) assert call_delta.function.name is not None, ( - "Streaming tool calls must have a function name on first " - "appearance") + "Streaming tool calls must have a function name on first appearance" + ) assert call_delta.index == len(self.tool_calls), ( f"Incorrect index for tool delta. 
Got {call_delta.index}, " - f"expected {len(self.tool_calls)}") + f"expected {len(self.tool_calls)}" + ) self.tool_calls.append( - ToolCall(id=call_delta.id, - function=FunctionCall( - name=call_delta.function.name, - arguments=call_delta.function.arguments - or ""))) + ToolCall( + id=call_delta.id, + function=FunctionCall( + name=call_delta.function.name, + arguments=call_delta.function.arguments or "", + ), + ) + ) def run_tool_extraction( @@ -80,11 +93,11 @@ def run_tool_extraction( tool_parser, model_output, request, - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta, + ) return reconstructor.other_content or None, reconstructor.tool_calls else: - extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, - request) + extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request) assert extracted.tools_called == bool(extracted.tool_calls) return extracted.content, extracted.tool_calls @@ -92,7 +105,7 @@ def run_tool_extraction( def run_tool_extraction_nonstreaming( tool_parser: ToolParser, model_output: str, - request: Union[ChatCompletionRequest, None] = None + request: Union[ChatCompletionRequest, None] = None, ) -> ExtractedToolCallInformation: request = request or ChatCompletionRequest(messages=[], model="test-model") return tool_parser.extract_tool_calls(model_output, request) @@ -106,7 +119,8 @@ def run_tool_extraction_streaming( ) -> StreamingToolReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingToolReconstructor( - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta + ) previous_text = "" previous_tokens: list[int] = [] for delta in model_deltas: @@ -118,8 +132,14 @@ def run_tool_extraction_streaming( current_text = previous_text + delta current_tokens = previous_tokens + token_delta delta_message = tool_parser.extract_tool_calls_streaming( - previous_text, current_text, delta, previous_tokens, - current_tokens, token_delta, request) + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + request, + ) if delta_message is not None: reconstructor.append_delta(delta_message) previous_text = current_text diff --git a/tests/async_engine/__init__.py b/tests/entrypoints/pooling/__init__.py similarity index 100% rename from tests/async_engine/__init__.py rename to tests/entrypoints/pooling/__init__.py diff --git a/tests/core/__init__.py b/tests/entrypoints/pooling/correctness/__init__.py similarity index 100% rename from tests/core/__init__.py rename to tests/entrypoints/pooling/correctness/__init__.py diff --git a/tests/entrypoints/openai/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/correctness/test_mteb_embed.py similarity index 62% rename from tests/entrypoints/openai/correctness/test_mteb_embed.py rename to tests/entrypoints/pooling/correctness/test_mteb_embed.py index 783f7d3e0d5aa..7f16638e51e2c 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_embed.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_embed.py @@ -4,10 +4,12 @@ import os import pytest -from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, - MTEB_EMBED_TOL, - OpenAIClientMtebEncoder, - run_mteb_embed_task) +from tests.models.language.pooling_mteb_test.mteb_utils import ( + MTEB_EMBED_TASKS, + MTEB_EMBED_TOL, + OpenAIClientMtebEncoder, + run_mteb_embed_task, +) from tests.utils import RemoteOpenAIServer 
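The `StreamingToolReconstructor.append_delta` assertions reformatted above encode a simple streaming contract: the first delta for a tool call carries its `id` and function name at the next free index, and every later delta for that index only appends a fragment of the argument string. A minimal standalone accumulator, using plain dicts rather than vLLM's protocol classes and with invented field names, would look roughly like this:

    def accumulate_tool_call_deltas(deltas: list[dict]) -> list[dict]:
        tool_calls: list[dict] = []
        for delta in deltas:
            index = delta["index"]
            if index == len(tool_calls):
                # First fragment of a new call: must carry id and name.
                tool_calls.append({
                    "id": delta["id"],
                    "name": delta["name"],
                    "arguments": delta.get("arguments", ""),
                })
            else:
                # Later fragments only extend the argument string.
                tool_calls[index]["arguments"] += delta["arguments"]
        return tool_calls

    print(accumulate_tool_call_deltas([
        {"index": 0, "id": "call-0", "name": "get_weather", "arguments": '{"city": '},
        {"index": 0, "arguments": '"Boston"}'},
    ]))
    # [{'id': 'call-0', 'name': 'get_weather', 'arguments': '{"city": "Boston"}'}]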
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" @@ -18,10 +20,7 @@ MAIN_SCORE = 0.7422994752439667 @pytest.fixture(scope="module") def server(): - args = [ - "--runner", "pooling", "--enforce-eager", - "--disable-uvicorn-access-log" - ] + args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -37,4 +36,6 @@ def test_mteb_embed(server): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < MTEB_EMBED_TOL diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/pooling/correctness/test_mteb_score.py similarity index 50% rename from tests/entrypoints/openai/correctness/test_mteb_score.py rename to tests/entrypoints/pooling/correctness/test_mteb_score.py index cfb865815c9b2..1afe68b189db8 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_score.py @@ -4,60 +4,53 @@ import os import pytest -# yapf conflicts with isort for this block -# yapf: disable -from tests.models.language.pooling.mteb_utils import ( - MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, - RerankClientMtebEncoder, ScoreClientMtebEncoder, - mteb_test_rerank_models_hf, run_mteb_rerank) -# yapf: enable +from tests.models.language.pooling_mteb_test.mteb_utils import ( + MTEB_RERANK_LANGS, + MTEB_RERANK_TASKS, + MTEB_RERANK_TOL, + RerankClientMtebEncoder, + ScoreClientMtebEncoder, + run_mteb_rerank, +) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" +st_main_score = 0.33457 @pytest.fixture(scope="module") def server(): - args = [ - "--runner", "pooling", "--enforce-eager", - "--disable-uvicorn-access-log" - ] + args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server -@pytest.fixture(scope="module") -def st_main_score(hf_runner): - # The main score related to the version of the dependency. - # So we need to recalculate every time. - main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME) - return main_score - - -def test_mteb_score(server, st_main_score): +def test_mteb_score(server): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < MTEB_RERANK_TOL -def test_mteb_rerank(server, st_main_score): +def test_mteb_rerank(server): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < MTEB_RERANK_TOL diff --git a/tests/core/block/__init__.py b/tests/entrypoints/pooling/llm/__init__.py similarity index 100% rename from tests/core/block/__init__.py rename to tests/entrypoints/pooling/llm/__init__.py diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py similarity index 52% rename from tests/entrypoints/llm/test_classify.py rename to tests/entrypoints/pooling/llm/test_classify.py index 57705ff669075..ae216c464a5b4 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -6,34 +6,27 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -44,29 +37,34 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(activation): outputs = llm.classify( - prompts, - pooling_params=PoolingParams(activation=activation), - use_tqdm=False) + prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False + ) return torch.tensor([x.outputs.probs for x in outputs]) default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - softmax(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." 
+ ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) def test_encode_api(llm: LLM): err_msg = "pooling_task must be one of.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompts, use_tqdm=False) + + +def test_score_api(llm: LLM): + err_msg = "Score API is only enabled for num_labels == 1." + with pytest.raises(ValueError, match=err_msg): + llm.score("ping", "pong", use_tqdm=False) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py similarity index 53% rename from tests/entrypoints/llm/test_embedding.py rename to tests/entrypoints/pooling/llm/test_embedding.py index 485f04ed6d849..aa24a70fd18b8 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/pooling/llm/test_embedding.py @@ -19,12 +19,14 @@ prompts = ["The chef prepared a delicious meal."] def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -35,21 +37,20 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(normalize): - outputs = llm.embed(prompts, - pooling_params=PoolingParams(normalize=normalize), - use_tqdm=False) + outputs = llm.embed( + prompts, pooling_params=PoolingParams(normalize=normalize), use_tqdm=False + ) return torch.tensor([x.outputs.embedding for x in outputs]) default = get_outputs(normalize=None) w_normal = get_outputs(normalize=True) wo_normal = get_outputs(normalize=False) - assert torch.allclose(default, w_normal, - atol=1e-2), "Default should use normal." - assert not torch.allclose(w_normal, wo_normal, - atol=1e-2), "wo_normal should not use normal." - assert torch.allclose( - w_normal, F.normalize(wo_normal, p=2, dim=-1), - atol=1e-2), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, atol=1e-2), ( + "wo_normal should not use normal." + ) + assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." 
+ ) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py similarity index 82% rename from tests/entrypoints/llm/test_encode.py rename to tests/entrypoints/pooling/llm/test_encode.py index cb54b16b0b044..d6aae99944f8f 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/pooling/llm/test_encode.py @@ -27,24 +27,18 @@ TOKEN_IDS = [ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py new file mode 100644 index 0000000000000..8312ff180b36f --- /dev/null +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from tests.models.utils import softmax +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "internlm/internlm2-1_8b-reward" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + trust_remote_code=True, + seed=0, + ) + + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + def get_outputs(softmax): + outputs = llm.reward( + prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False + ) + return torch.cat([x.outputs.data for x in outputs]) + + default = get_outputs(softmax=None) + w_softmax = get_outputs(softmax=True) + wo_softmax = get_outputs(softmax=False) + + assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." + assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( + "wo_softmax should not use softmax." + ) + assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( + "w_softmax should be close to softmax(wo_softmax)." 
+ ) diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py similarity index 54% rename from tests/entrypoints/llm/test_score.py rename to tests/entrypoints/pooling/llm/test_score.py index 5a1339b2addf4..9bf74fce906b0 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -6,32 +6,25 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -42,7 +35,6 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." @@ -51,18 +43,20 @@ def test_pooling_params(llm: LLM): text_1, text_2, pooling_params=PoolingParams(activation=activation), - use_tqdm=False) + use_tqdm=False, + ) return torch.tensor([x.outputs.score for x in outputs]) default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - softmax(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." 
+ ) diff --git a/tests/core/block/e2e/__init__.py b/tests/entrypoints/pooling/openai/__init__.py similarity index 100% rename from tests/core/block/e2e/__init__.py rename to tests/entrypoints/pooling/openai/__init__.py diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py similarity index 65% rename from tests/entrypoints/openai/test_classification.py rename to tests/entrypoints/pooling/openai/test_classification.py index 30078fe90257a..92d40efad21cb 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -6,10 +6,9 @@ import requests import torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ClassificationResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" DTYPE = "float32" # Use float32 to avoid NaN issue @@ -29,21 +28,16 @@ def server(): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_single_input_classification(server: RemoteOpenAIServer, - model_name: str): +def test_single_input_classification(server: RemoteOpenAIServer, model_name: str): input_text = "This product was excellent and exceeded my expectations" classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_text - }, + json={"model": model_name, "input": input_text}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert output.model == MODEL_NAME @@ -53,8 +47,7 @@ def test_single_input_classification(server: RemoteOpenAIServer, @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_multiple_inputs_classification(server: RemoteOpenAIServer, - model_name: str): +def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str): input_texts = [ "The product arrived on time and works perfectly", "I'm very satisfied with my purchase, would buy again", @@ -66,13 +59,9 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_texts - }, + json={"model": model_name, "input": input_texts}, ) - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == len(input_texts) for i, item in enumerate(output.data): @@ -89,16 +78,11 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": long_text, - "truncate_prompt_tokens": 5 - }, + json={"model": model_name, "input": long_text, "truncate_prompt_tokens": 5}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == 1 assert output.data[0].index == 0 @@ -108,15 +92,12 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, - model_name: 
str): +def test_invalid_truncate_prompt_tokens_error( + server: RemoteOpenAIServer, model_name: str +): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": "test", - "truncate_prompt_tokens": 513 - }, + json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513}, ) error = classification_response.json() @@ -128,10 +109,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": "" - }, + json={"model": model_name, "input": ""}, ) error = classification_response.json() @@ -140,18 +118,13 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_batch_classification_empty_list(server: RemoteOpenAIServer, - model_name: str): +def test_batch_classification_empty_list(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": [] - }, + json={"model": model_name, "input": []}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert isinstance(output.data, list) @@ -162,15 +135,17 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, async def test_invocations(server: RemoteOpenAIServer): request_args = { "model": MODEL_NAME, - "input": "This product was excellent and exceeded my expectations" + "input": "This product was excellent and exceeded my expectations", } - classification_response = requests.post(server.url_for("classify"), - json=request_args) + classification_response = requests.post( + server.url_for("classify"), json=request_args + ) classification_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() classification_output = classification_response.json() @@ -178,10 +153,12 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_output.keys() == invocation_output.keys() for classification_data, invocation_data in zip( - classification_output["data"], invocation_output["data"]): + classification_output["data"], invocation_output["data"] + ): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( - invocation_data["probs"], rel=0.01) + invocation_data["probs"], rel=0.01 + ) @pytest.mark.asyncio @@ -190,27 +167,26 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): input_text = ["This product was excellent and exceeded my expectations"] async def get_outputs(activation): - response = requests.post(server.url_for("classify"), - json={ - "model": model_name, - "input": input_text, - "activation": activation - }) + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "input": input_text, "activation": activation}, + ) outputs = response.json() - return torch.tensor([x['probs'] for x in outputs["data"]]) + return torch.tensor([x["probs"] for x in outputs["data"]]) default = await 
get_outputs(activation=None) w_activation = await get_outputs(activation=True) wo_activation = await get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) @pytest.mark.asyncio @@ -219,10 +195,36 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): # pooling api uses ALL pooling, which does not support chunked prefill. response = requests.post( server.url_for("pooling"), + json={"model": model_name, "input": "test", "encoding_format": "float"}, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_score(server: RemoteOpenAIServer, model_name: str): + # score api is only enabled for num_labels == 1. + response = requests.post( + server.url_for("score"), json={ "model": model_name, - "input": "test", - "encoding_format": "float" + "text_1": "ping", + "text_2": "pong", + }, + ) + assert response.json()["error"]["type"] == "BadRequestError" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_rerank(server: RemoteOpenAIServer, model_name: str): + # rerank api is only enabled for num_labels == 1. 
+ response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": "ping", + "documents": ["pong"], }, ) assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py similarity index 61% rename from tests/entrypoints/openai/test_embedding.py rename to tests/entrypoints/pooling/openai/test_embedding.py index cf2442a569388..6f6559a961a18 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -11,27 +11,17 @@ import requests import torch import torch.nn.functional as F +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test +from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.language.pooling.embed_utils import ( - run_embedding_correctness_test) -from ...models.utils import check_embeddings_close -from ...utils import RemoteOpenAIServer - MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = [ @@ -59,15 +49,13 @@ async def client(server): @pytest.fixture(scope="module") def hf_model(hf_runner): - with hf_runner(MODEL_NAME, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner(MODEL_NAME, dtype=DTYPE, is_sentence_transformer=True) as hf_model: yield hf_model @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): input_texts = [ "The chef prepared a delicious meal.", ] @@ -79,7 +67,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -99,7 +88,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -111,12 +101,12 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." 
+ "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] embedding_response = await client.embeddings.create( model=model_name, @@ -124,7 +114,8 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -137,15 +128,20 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] embedding_response = await client.embeddings.create( model=model_name, input=input_tokens, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 4 @@ -157,19 +153,23 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_embedding(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_embedding( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("v1/embeddings"), @@ -198,64 +198,66 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, extra_body={"add_special_tokens": False}, ) completion_embeddings = EmbeddingResponse.model_validate( - completion_response.model_dump(mode="json")) + completion_response.model_dump(mode="json") + ) assert chat_embeddings.id is not None assert completion_embeddings.id is not None assert chat_embeddings.created <= completion_embeddings.created - assert chat_embeddings.model_dump( - exclude={"id", "created"}) == (completion_embeddings.model_dump( - exclude={"id", "created"})) + assert chat_embeddings.model_dump(exclude={"id", "created"}) == ( + completion_embeddings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_batch_base64_embedding( + hf_model, client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] - responses_float = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="float") + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) 
float_data = [d.embedding for d in responses_float.data] run_embedding_correctness_test(hf_model, input_texts, float_data) - responses_base64 = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="base64") + responses_base64 = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="base64" + ) base64_data = [] for data in responses_base64.data: base64_data.append( - np.frombuffer(base64.b64decode(data.embedding), - dtype="float32").tolist()) + np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist() + ) run_embedding_correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client - responses_default = await client.embeddings.create(input=input_texts, - model=model_name) + responses_default = await client.embeddings.create( + input=input_texts, model=model_name + ) default_data = [d.embedding for d in responses_default.data] run_embedding_correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation(client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] # test single embedding embedding_response = await client.embeddings.create( - model=model_name, - input=input_texts, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -265,15 +267,34 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 10 input_tokens = [ - 1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728, - 9901, 340, 2229, 385, 340, 315, 28741, 28804, 2 + 1, + 24428, + 289, + 18341, + 26165, + 285, + 19323, + 283, + 289, + 26789, + 3871, + 28728, + 9901, + 340, + 2229, + 385, + 340, + 315, + 28741, + 28804, + 2, ] embedding_response = await client.embeddings.create( - model=model_name, - input=input_tokens, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -285,8 +306,9 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding_truncation_invalid( + client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] @@ -295,15 +317,17 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, response = await client.embeddings.create( model=model_name, input=input_texts, - extra_body={"truncate_prompt_tokens": 8193}) + extra_body={"truncate_prompt_tokens": 8193}, + ) assert "error" in response.object - assert "truncate_prompt_tokens value is greater than max_model_len. 
"\ - "Please, select a smaller truncation size." in response.message + assert ( + "truncate_prompt_tokens value is greater than max_model_len. " + "Please, select a smaller truncation size." in response.message + ) @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): input_texts = [ "The chef prepared a delicious meal.", ] @@ -316,35 +340,43 @@ async def test_invocations(server: RemoteOpenAIServer, completion_response = await client.embeddings.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.model_dump() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[completion_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -352,25 +384,28 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): "encoding_format": "float", } - chat_response = requests.post(server.url_for("v1/embeddings"), - json=request_args) + chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[chat_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="chat", + name_1="invocation", + ) @pytest.mark.asyncio @@ -383,23 +418,22 @@ async def test_normalize(server: 
RemoteOpenAIServer, model_name: str): "model": MODEL_NAME, "input": input_text, "encoding_format": "float", - "normalize": normalize + "normalize": normalize, } - response = requests.post(server.url_for("v1/embeddings"), - json=request_args) + response = requests.post(server.url_for("v1/embeddings"), json=request_args) outputs = response.json() - return torch.tensor([x['embedding'] for x in outputs["data"]]) + return torch.tensor([x["embedding"] for x in outputs["data"]]) default = await get_outputs(normalize=None) w_normal = await get_outputs(normalize=True) wo_normal = await get_outputs(normalize=False) - assert torch.allclose(default, w_normal, - atol=1e-2), "Default should use normal." - assert not torch.allclose(w_normal, wo_normal, - atol=1e-2), "wo_normal should not use normal." - assert torch.allclose( - w_normal, F.normalize(wo_normal, p=2, dim=-1), - atol=1e-2), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, atol=1e-2), ( + "wo_normal should not use normal." + ) + assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." + ) diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py similarity index 78% rename from tests/entrypoints/openai/test_embedding_dimensions.py rename to tests/entrypoints/pooling/openai/test_embedding_dimensions.py index 91e91699b92ca..92df43d7dbdcf 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py @@ -9,19 +9,19 @@ from typing import Optional import openai import pytest +from tests.conftest import HfRunner +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test +from tests.models.utils import EmbedModelInfo +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse -from ...conftest import HfRunner -from ...models.language.pooling.embed_utils import ( - run_embedding_correctness_test) -from ...models.utils import EmbedModelInfo -from ...utils import RemoteOpenAIServer - MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - matryoshka_dimensions=[256]), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256], + ), ] input_texts = [ @@ -49,15 +49,14 @@ def server(model_info, dtype: str): dtype, "--enforce-eager", "--max-model-len", - "512" + "512", ] if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5": # Manually enable Matryoshka Embeddings - args.extend([ - "--trust_remote_code", "--hf_overrides", - '{"matryoshka_dimensions":[256]}' - ]) + args.extend( + ["--trust_remote_code", "--hf_overrides", '{"matryoshka_dimensions":[256]}'] + ) with RemoteOpenAIServer(model_info.name, args) as remote_server: yield remote_server @@ -65,14 +64,16 @@ def server(model_info, dtype: str): @pytest.fixture(scope="module") def hf_model(hf_runner, model_info, dtype: str): - with hf_runner(model_info.name, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner( + model_info.name, dtype=dtype, is_sentence_transformer=True + ) as hf_model: yield hf_model @pytest.mark.asyncio -async def test_matryoshka(model_info: EmbedModelInfo, - server: 
RemoteOpenAIServer, hf_model: HfRunner): +async def test_matryoshka( + model_info: EmbedModelInfo, server: RemoteOpenAIServer, hf_model: HfRunner +): client = server.get_async_client() async def make_request_and_correctness_test(dimensions): @@ -85,7 +86,8 @@ async def test_matryoshka(model_info: EmbedModelInfo, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -98,8 +100,7 @@ async def test_matryoshka(model_info: EmbedModelInfo, assert len(embeddings.data[0].embedding) == dimensions vllm_outputs = [d.embedding for d in embeddings.data] - run_embedding_correctness_test(hf_model, prompts, vllm_outputs, - dimensions) + run_embedding_correctness_test(hf_model, prompts, vllm_outputs, dimensions) if model_info.is_matryoshka: valid_dimensions: list[Optional[int]] = [None] diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/openai/test_embedding_long_text.py similarity index 83% rename from tests/entrypoints/openai/test_embedding_long_text.py rename to tests/entrypoints/pooling/openai/test_embedding_long_text.py index 86bd34abb97e0..f977c81a9084e 100644 --- a/tests/entrypoints/openai/test_embedding_long_text.py +++ b/tests/entrypoints/pooling/openai/test_embedding_long_text.py @@ -14,10 +14,9 @@ import openai import pytest import pytest_asyncio +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse -from ...utils import RemoteOpenAIServer - def _generate_random_text(word_count: int) -> str: """Generate random text with approximately the specified word count.""" @@ -32,7 +31,6 @@ def _generate_random_text(word_count: int) -> str: "that", "these", "those", - # Action verbs "create", "build", @@ -81,7 +79,6 @@ def _generate_random_text(word_count: int) -> str: "finish", "deliver", "provide", - # Technology and science nouns "system", "application", @@ -133,7 +130,6 @@ def _generate_random_text(word_count: int) -> str: "optimization", "performance", "efficiency", - # General nouns "project", "team", @@ -176,7 +172,7 @@ def _generate_random_text(word_count: int) -> str: "session", "meeting", "discussion", - "decision" + "decision", ] words = [] @@ -190,7 +186,7 @@ def _generate_random_text(word_count: int) -> str: result = [] for i, word in enumerate(words_list): result.append(word) - if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1): + if (i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1: result[-1] += "." 
return " ".join(result) @@ -217,9 +213,11 @@ def server_with_chunked_processing(): "--enforce-eager", "--max-model-len", "512", # Set smaller max_model_len to trigger chunking mechanism - '--override-pooler-config', - ('{"pooling_type": "MEAN", "normalize": true, ' - '"enable_chunked_processing": true, "max_embed_len": 10000}'), + "--pooler-config", + ( + '{"pooling_type": "MEAN", "normalize": true, ' + '"enable_chunked_processing": true, "max_embed_len": 10000}' + ), "--gpu-memory-utilization", "0.8", ] @@ -231,23 +229,22 @@ def server_with_chunked_processing(): @pytest_asyncio.fixture async def client_with_chunked_processing(server_with_chunked_processing): """Create async client with chunking processing support.""" - async with server_with_chunked_processing.get_async_client( - ) as async_client: + async with server_with_chunked_processing.get_async_client() as async_client: yield async_client @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_long_text_embedding_1500_chars( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): - """Test embedding processing for ~1500 character long text + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): + """Test embedding processing for ~1500 character long text (~1028 tokens, exceeding 512 token limit).""" # Verify text length # Verify text has sufficient word count (approximately 1500 words) word_count = len(LONG_TEXT_1500_WORDS.split()) - assert word_count >= 1400, ( - f"Test text word count insufficient: {word_count} words") + assert word_count >= 1400, f"Test text word count insufficient: {word_count} words" # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -258,12 +255,14 @@ async def test_long_text_embedding_1500_chars( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding - ) == 384 # multilingual-e5-small embedding dimension + assert ( + len(embeddings.data[0].embedding) == 384 + ) # multilingual-e5-small embedding dimension assert embeddings.usage.completion_tokens == 0 # Due to chunked processing, token count should # reflect actual processed tokens @@ -275,26 +274,26 @@ async def test_long_text_embedding_1500_chars( # Verify embedding vector validity embedding_vector = embeddings.data[0].embedding - assert all( - isinstance(x, float) - for x in embedding_vector), "Embedding vector should contain floats" - assert not all( - x == 0 - for x in embedding_vector), "Embedding vector should not be all zeros" + assert all(isinstance(x, float) for x in embedding_vector), ( + "Embedding vector should contain floats" + ) + assert not all(x == 0 for x in embedding_vector), ( + "Embedding vector should not be all zeros" + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_long_text_embedding_2500_chars( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test embedding processing for ~2500 character long text (~2048 tokens, requiring multiple chunks).""" # Verify text length # Verify text has sufficient word count (approximately 2500 words) word_count = len(LONG_TEXT_2500_WORDS.split()) - assert word_count >= 2300, ( - f"Test text word count insufficient: 
{word_count} words") + assert word_count >= 2300, f"Test text word count insufficient: {word_count} words" # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -305,12 +304,14 @@ async def test_long_text_embedding_2500_chars( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding - ) == 384 # multilingual-e5-small embedding dimension + assert ( + len(embeddings.data[0].embedding) == 384 + ) # multilingual-e5-small embedding dimension assert embeddings.usage.completion_tokens == 0 # Due to chunked processing, token count should # reflect actual processed tokens @@ -322,18 +323,19 @@ async def test_long_text_embedding_2500_chars( # Verify embedding vector validity embedding_vector = embeddings.data[0].embedding - assert all( - isinstance(x, float) - for x in embedding_vector), "Embedding vector should contain floats" - assert not all( - x == 0 - for x in embedding_vector), "Embedding vector should not be all zeros" + assert all(isinstance(x, float) for x in embedding_vector), ( + "Embedding vector should contain floats" + ) + assert not all(x == 0 for x in embedding_vector), ( + "Embedding vector should not be all zeros" + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_batch_long_text_embedding( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test batch long text embedding processing.""" input_texts = [ @@ -351,7 +353,8 @@ async def test_batch_long_text_embedding( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 # Three input texts @@ -376,13 +379,16 @@ async def test_batch_long_text_embedding( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chunked_vs_normal_consistency( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test consistency between chunked and normal processing (using short text).""" # Use a short text within the 512 token limit - short_text = ("Artificial intelligence technology is changing our world, " - "bringing unprecedented opportunities and challenges.") + short_text = ( + "Artificial intelligence technology is changing our world, " + "bringing unprecedented opportunities and challenges." 
+ ) # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -393,7 +399,8 @@ async def test_chunked_vs_normal_consistency( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -412,7 +419,8 @@ async def test_chunked_vs_normal_consistency( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chunked_processing_response_format( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test response format and structure during chunked processing.""" # Test with long text to trigger chunking @@ -424,7 +432,8 @@ async def test_chunked_processing_response_format( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -434,8 +443,10 @@ async def test_chunked_processing_response_format( # Verify embedding vector properties embedding_vector = embeddings.data[0].embedding import math + vector_norm = math.sqrt(sum(x * x for x in embedding_vector)) # Check that the vector is normalized # (default behavior for most embedding models) assert 0.8 < vector_norm < 1.2, ( - f"Vector norm should be reasonable, actual: {vector_norm}") + f"Vector norm should be reasonable, actual: {vector_norm}" + ) diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/pooling/openai/test_pooling.py similarity index 67% rename from tests/entrypoints/openai/test_pooling.py rename to tests/entrypoints/pooling/openai/test_pooling.py index 63f4205e0a42b..3439c556ccc40 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/pooling/openai/test_pooling.py @@ -8,11 +8,10 @@ import pytest import requests from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer - MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @@ -47,11 +46,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): # test single pooling response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -67,11 +62,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): input_tokens = [1, 1, 1, 1, 1] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -89,16 +80,13 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): async def test_batch_pooling(server: RemoteOpenAIServer, model_name: 
str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." + "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -111,15 +99,15 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.usage.total_tokens == 29 # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -134,18 +122,21 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_pooling(server: RemoteOpenAIServer, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("pooling"), @@ -181,24 +172,22 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, }, ) completions_response.raise_for_status() - completion_poolings = PoolingResponse.model_validate( - completions_response.json()) + completion_poolings = PoolingResponse.model_validate(completions_response.json()) assert chat_poolings.id is not None assert completion_poolings.id is not None assert chat_poolings.created <= completion_poolings.created - assert chat_poolings.model_dump( - exclude={"id", "created"}) == (completion_poolings.model_dump( - exclude={"id", "created"})) + assert chat_poolings.model_dump(exclude={"id", "created"}) == ( + completion_poolings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_pooling(server: RemoteOpenAIServer, - model_name: str): +async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] float_response = requests.post( @@ -211,9 +200,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) float_response.raise_for_status() responses_float = PoolingResponse.model_validate(float_response.json()) - float_data = [ - np.array(d.data).squeeze(-1).tolist() for d in 
responses_float.data - ] + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] base64_response = requests.post( server.url_for("pooling"), @@ -229,13 +216,15 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, decoded_responses_base64_data = [] for data in responses_base64.data: decoded_responses_base64_data.append( - np.frombuffer(base64.b64decode(data.data), - dtype="float32").tolist()) + np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist() + ) - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=decoded_responses_base64_data, - name_0="float32", - name_1="base64") + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64", + ) # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -251,10 +240,12 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.array(d.data).squeeze(-1).tolist() for d in responses_default.data ] - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=default_data, - name_0="float32", - name_1="default") + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=default_data, + name_0="float32", + name_1="default", + ) @pytest.mark.asyncio @@ -269,39 +260,46 @@ async def test_invocations(server: RemoteOpenAIServer): "encoding_format": "float", } - completion_response = requests.post(server.url_for("pooling"), - json=request_args) + completion_response = requests.post(server.url_for("pooling"), json=request_args) completion_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.json() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=completion_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=completion_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -312,18 +310,22 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): chat_response = requests.post(server.url_for("pooling"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response 
= requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=chat_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=chat_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="chat", + name_1="invocation", + ) diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py similarity index 50% rename from tests/entrypoints/openai/test_rerank.py rename to tests/entrypoints/pooling/openai/test_rerank.py index 73364294cbcdc..9980fcff16c15 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -6,22 +6,13 @@ import requests import torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import RerankResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] @@ -34,15 +25,18 @@ def server(): def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + }, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -58,16 +52,14 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", "Cross-encoder models are neat" + "The capital of France is Paris.", + "Cross-encoder models are neat", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - "top_n": 2 - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents, "top_n": 2}, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -80,28 +72,26 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): - query = "What is the capital of France?" * 100 documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." 
+ "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents}, + ) assert rerank_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - rerank_response.text + assert "Please reduce the length of the input." in rerank_response.text def test_invocations(server: RemoteOpenAIServer): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] request_args = { @@ -110,23 +100,25 @@ def test_invocations(server: RemoteOpenAIServer): "documents": documents, } - rerank_response = requests.post(server.url_for("rerank"), - json=request_args) + rerank_response = requests.post(server.url_for("rerank"), json=request_args) rerank_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() rerank_output = rerank_response.json() invocation_output = invocation_response.json() assert rerank_output.keys() == invocation_output.keys() - for rerank_result, invocations_result in zip(rerank_output["results"], - invocation_output["results"]): + for rerank_result, invocations_result in zip( + rerank_output["results"], invocation_output["results"] + ): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.05) + invocations_result["relevance_score"], rel=0.05 + ) # TODO: reset this tolerance to 0.01 once we find # an alternative to flash_attn with bfloat16 @@ -134,34 +126,36 @@ def test_invocations(server: RemoteOpenAIServer): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_activation(server: RemoteOpenAIServer, model_name: str): - async def get_outputs(activation): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - "activation": activation - }) + response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "activation": activation, + }, + ) outputs = response.json() - return torch.tensor([x['relevance_score'] for x in outputs["results"]]) + return torch.tensor([x["relevance_score"] for x in outputs["results"]]) default = await get_outputs(activation=None) w_activation = await get_outputs(activation=True) wo_activation = await get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.sigmoid(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." 
+ assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/pooling/openai/test_score.py similarity index 50% rename from tests/entrypoints/openai/test_score.py rename to tests/entrypoints/pooling/openai/test_score.py index cb6ec795ae969..ef213ab0ea18b 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/pooling/openai/test_score.py @@ -8,28 +8,12 @@ import torch import torch.nn.functional as F from torch import tensor +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ScoreResponse -from ...utils import RemoteOpenAIServer - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - MODELS = [ - { - "name": "BAAI/bge-reranker-v2-m3", - "is_cross_encoder": True - }, - { - "name": "BAAI/bge-base-en-v1.5", - "is_cross_encoder": False - }, + {"name": "BAAI/bge-reranker-v2-m3", "is_cross_encoder": True}, + {"name": "BAAI/bge-base-en-v1.5", "is_cross_encoder": False}, ] DTYPE = "half" @@ -38,9 +22,7 @@ def run_transformers(hf_model, model, text_pairs): if model["is_cross_encoder"]: return hf_model.predict(text_pairs).tolist() else: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] return [ F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0) for pair in hf_embeddings @@ -64,8 +46,9 @@ def server(model: dict[str, Any]): def runner(model: dict[str, Any], hf_runner): kwargs = { "dtype": DTYPE, - "is_cross_encoder" if model["is_cross_encoder"]\ - else "is_sentence_transformer": True + "is_cross_encoder" + if model["is_cross_encoder"] + else "is_sentence_transformer": True, } with hf_runner(model["name"], **kwargs) as hf_model: @@ -73,21 +56,23 @@ def runner(model: dict[str, Any], hf_runner): class TestModel: - - def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -103,23 +88,26 @@ class TestModel: for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_list_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = [ "What is the capital of the United States?", - "What is the capital of France?" 
+ "What is the capital of France?", ] text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -135,17 +123,20 @@ class TestModel: for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_str( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -161,40 +152,41 @@ class TestModel: for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_score_max_model_len(self, server: RemoteOpenAIServer, - model: dict[str, Any]): - + def test_score_max_model_len( + self, server: RemoteOpenAIServer, model: dict[str, Any] + ): text_1 = "What is the capital of France?" * 20 text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) assert score_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - score_response.text + assert "Please reduce the length of the input." in score_response.text # Test truncation - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - "truncate_prompt_tokens": 101 - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "truncate_prompt_tokens": 101, + }, + ) assert score_response.status_code == 400 - assert "Please, select a smaller truncation size." in \ - score_response.text + assert "Please, select a smaller truncation size." in score_response.text - def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, - Any]): + def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." 
@@ -204,59 +196,61 @@ class TestModel: "text_2": text_2, } - score_response = requests.post(server.url_for("score"), - json=request_args) + score_response = requests.post(server.url_for("score"), json=request_args) score_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() score_output = score_response.json() invocation_output = invocation_response.json() assert score_output.keys() == invocation_output.keys() - for score_data, invocation_data in zip(score_output["data"], - invocation_output["data"]): + for score_data, invocation_data in zip( + score_output["data"], invocation_output["data"] + ): assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( - invocation_data["score"], rel=0.05) + invocation_data["score"], rel=0.05 + ) # TODO: reset this tolerance to 0.01 once we find # an alternative to flash_attn with bfloat16 - def test_activation(self, server: RemoteOpenAIServer, model: dict[str, - Any]): - + def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): def get_outputs(activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." - response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - "activation": activation - }) + response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "activation": activation, + }, + ) if response.status_code != 200: return response outputs = response.json() - return torch.tensor([x['score'] for x in outputs["data"]]) + return torch.tensor([x["score"] for x in outputs["data"]]) if model["is_cross_encoder"]: - default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.sigmoid(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." 
+ ) else: get_outputs(activation=None) diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/pooling/openai/test_truncation.py similarity index 75% rename from tests/entrypoints/openai/test_truncation.py rename to tests/entrypoints/pooling/openai/test_truncation.py index 18ddc493c9283..6889628dc9145 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/pooling/openai/test_truncation.py @@ -54,12 +54,24 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == truncation_size + + +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size, + } + + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == truncation_size @@ -70,7 +82,7 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } with pytest.raises(openai.BadRequestError) as err: @@ -79,9 +91,11 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): assert err.value.status_code == 400 error_details = err.value.response.json()["error"] assert error_details["type"] == "BadRequestError" - expected_message = ("truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + expected_message = ( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." 
+ ) assert error_details["message"] == expected_message @@ -91,11 +105,9 @@ async def test_max_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == max_model_len diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/pooling/openai/test_vision_embedding.py similarity index 52% rename from tests/entrypoints/openai/test_vision_embedding.py rename to tests/entrypoints/pooling/openai/test_vision_embedding.py index 4e6a21058658b..944392d66fa5f 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/pooling/openai/test_vision_embedding.py @@ -5,26 +5,24 @@ import json import pytest import requests -from PIL import Image from transformers import AutoProcessor +from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer - MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MAXIMUM_IMAGES = 2 -vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja" +vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec_phi3v.jinja" assert vlm2vec_jinja_path.exists() # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -50,61 +48,50 @@ def server(): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" - images = [Image.open(requests.get(image_url, stream=True).raw)] + images = [fetch_image(image_url)] 
inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, - image_url: str): +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_image_embedding( + server: RemoteOpenAIServer, model_name: str, image_url: str +): content_text = "Represent the given image." - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] response = requests.post( server.url_for("v1/embeddings"), - json={ - "model": model_name, - "messages": messages, - "encoding_format": "float" - }, + json={"model": model_name, "messages": messages, "encoding_format": "float"}, ) response.raise_for_status() embeddings = EmbeddingResponse.model_validate(response.json()) - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert embeddings.id is not None assert len(embeddings.data) == 1 diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index e4af60a782651..e548f52e1e94d 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -10,8 +10,7 @@ from unittest.mock import patch import pytest -from vllm.v1.utils import (APIServerProcessManager, - wait_for_completion_or_failure) +from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure # Global variables to control worker behavior WORKER_RUNTIME_SECONDS = 0.5 @@ -30,26 +29,22 @@ def api_server_args(): """Fixture to provide arguments for APIServerProcessManager.""" sock = socket.socket() return { - "target_server_fn": - mock_run_api_server_worker, - "listen_address": - "localhost:8000", - "sock": - sock, - "args": - "test_args", # Simple string to avoid pickling issues - "num_servers": - 3, + "target_server_fn": mock_run_api_server_worker, + "listen_address": "localhost:8000", + "sock": sock, + "args": "test_args", # Simple string to avoid pickling issues + "num_servers": 3, "input_addresses": [ - "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", - "tcp://127.0.0.1:5003" + "tcp://127.0.0.1:5001", + "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003", ], "output_addresses": [ - "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", - "tcp://127.0.0.1:6003" + "tcp://127.0.0.1:6001", + "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003", ], - "stats_update_address": - "tcp://127.0.0.1:7000", + "stats_update_address": "tcp://127.0.0.1:7000", } @@ -60,7 +55,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): global WORKER_RUNTIME_SECONDS WORKER_RUNTIME_SECONDS = 0.5 - # Copy the args to avoid mutating the + # Copy the args to avoid mutating them args = api_server_args.copy() if not with_stats_update: @@ -95,8 +90,9 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): assert not proc.is_alive() -@patch("vllm.entrypoints.cli.serve.run_api_server_worker", - mock_run_api_server_worker) +@patch( + 
"vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker +) def test_wait_for_completion_or_failure(api_server_args): """Test that wait_for_completion_or_failure works with failures.""" global WORKER_RUNTIME_SECONDS @@ -118,8 +114,7 @@ def test_wait_for_completion_or_failure(api_server_args): result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Let all processes run for a short time @@ -174,8 +169,7 @@ def test_normal_completion(api_server_args): # Verify all processes have terminated for i, proc in enumerate(manager.processes): - assert not proc.is_alive( - ), f"Process {i} still alive after terminate()" + assert not proc.is_alive(), f"Process {i} still alive after terminate()" # Now call wait_for_completion_or_failure # since all processes have already @@ -198,13 +192,13 @@ def test_external_process_monitoring(api_server_args): # Create and start the external process # (simulates local_engine_manager or coordinator) spawn_context = multiprocessing.get_context("spawn") - external_proc = spawn_context.Process(target=mock_run_api_server_worker, - name="MockExternalProcess") + external_proc = spawn_context.Process( + target=mock_run_api_server_worker, name="MockExternalProcess" + ) external_proc.start() # Create the class to simulate a coordinator class MockCoordinator: - def __init__(self, proc): self.proc = proc @@ -228,14 +222,14 @@ def test_external_process_monitoring(api_server_args): def run_with_exception_capture(): try: - wait_for_completion_or_failure(api_server_manager=manager, - coordinator=mock_coordinator) + wait_for_completion_or_failure( + api_server_manager=manager, coordinator=mock_coordinator + ) except Exception as e: result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Terminate the external process to trigger a failure @@ -246,21 +240,23 @@ def test_external_process_monitoring(api_server_args): wait_thread.join(timeout=1.0) # The wait thread should have completed - assert not wait_thread.is_alive( - ), "wait_for_completion_or_failure thread still running" + assert not wait_thread.is_alive(), ( + "wait_for_completion_or_failure thread still running" + ) # Verify that an exception was raised with appropriate error message assert result["exception"] is not None, "No exception was raised" error_message = str(result["exception"]) - assert "died with exit code" in error_message, \ + assert "died with exit code" in error_message, ( f"Unexpected error message: {error_message}" - assert "MockExternalProcess" in error_message, \ + ) + assert "MockExternalProcess" in error_message, ( f"Error doesn't mention external process: {error_message}" + ) # Verify that all API server processes were terminated as a result for i, proc in enumerate(manager.processes): - assert not proc.is_alive( - ), f"API server process {i} was not terminated" + assert not proc.is_alive(), f"API server process {i} was not terminated" finally: # Clean up diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 647f1c7b7f34f..dcd196ebdd772 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py 
@@ -6,25 +6,29 @@ from collections.abc import Mapping from typing import Literal, Optional import pytest -from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy, - SpecialTokens) -from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, - Tekkenizer) +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, - parse_chat_messages, - parse_chat_messages_futures, - resolve_chat_template_content_format, - resolve_hf_chat_template) -from vllm.entrypoints.llm import apply_hf_chat_template -from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, - encode_video_base64) -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.entrypoints.chat_utils import ( + _try_extract_ast, + apply_mistral_chat_template, + load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + resolve_chat_template_content_format, + resolve_chat_template_kwargs, + resolve_hf_chat_template, +) +from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict +from vllm.multimodal.utils import ( + encode_audio_base64, + encode_image_base64, + encode_video_base64, +) +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from ..models.registry import HF_EXAMPLE_MODELS @@ -38,7 +42,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" -MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" +QWEN3_MODEL_ID = "Qwen/Qwen3-8B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -46,112 +50,103 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @pytest.fixture(scope="function") def phi3v_model_config(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="function") def phi3v_model_config_mm_interleaved(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") def phi3v_tokenizer(): - return TokenizerGroup( - tokenizer_id=PHI3V_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, + return get_tokenizer(PHI3V_MODEL_ID) + + +@pytest.fixture(scope="function") +def qwen2_audio_model_config(): + return ModelConfig( + QWEN2AUDIO_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "audio": 1, + }, ) +@pytest.fixture(scope="module") +def qwen2_audio_tokenizer(): + return get_tokenizer(QWEN2AUDIO_MODEL_ID) + + @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): - return 
ModelConfig(QWEN25OMNI_MODEL_ID, - runner="generate", - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - "audio": 1, - "video": 1, - }) + return ModelConfig( + QWEN25OMNI_MODEL_ID, + runner="generate", + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }, + ) @pytest.fixture(scope="module") def qwen25omni_tokenizer(): - return TokenizerGroup( - tokenizer_id=QWEN25OMNI_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - ) - - -@pytest.fixture(scope="module") -def mllama_model_config(): - return ModelConfig(MLLAMA_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - -@pytest.fixture(scope="module") -def mllama_tokenizer(): - return TokenizerGroup( - MLLAMA_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - ) + return get_tokenizer(QWEN25OMNI_MODEL_ID) @pytest.fixture(scope="function") def mistral_model_config(): - return ModelConfig(MISTRAL_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - -@pytest.fixture(scope="module") -def mistral_tokenizer(): - return TokenizerGroup( - tokenizer_id=MISTRAL_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, + return ModelConfig( + MISTRAL_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, ) +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer(MISTRAL_MODEL_ID) + + @pytest.fixture(scope="module") def image_url(): - image = ImageAsset('cherry_blossom') + image = ImageAsset("cherry_blossom") base64 = encode_image_base64(image.pil_image) return f"data:image/jpeg;base64,{base64}" @pytest.fixture(scope="module") def video_url(): - video = VideoAsset('baby_reading', 1) + video = VideoAsset("baby_reading", 1) base64 = encode_video_base64(video.np_ndarrays) return f"data:video/jpeg;base64,{base64}" @pytest.fixture(scope="module") def audio_url(): - audio = AudioAsset('mary_had_lamb') + audio = AudioAsset("mary_had_lamb") base64 = encode_audio_base64(*audio.audio_and_sample_rate) return f"data:audio/ogg;base64,{base64}" @@ -159,6 +154,7 @@ def audio_url(): def _assert_mm_data_is_image_input( mm_data: Optional[MultiModalDataDict], image_count: int, + skipped_image_indices: Optional[list] = None, ) -> None: assert mm_data is not None assert set(mm_data.keys()) == {"image"} @@ -167,6 +163,29 @@ def _assert_mm_data_is_image_input( assert image_data is not None assert isinstance(image_data, list) and len(image_data) == image_count + if skipped_image_indices is not None: + for i in skipped_image_indices: + assert image_data[i] is None + + +def _assert_mm_uuids( + mm_uuids: Optional[MultiModalUUIDDict], + media_count: int, + expected_uuids: list[Optional[str]], + modality: str = "image", +) -> None: + if len(expected_uuids) > 0: + assert mm_uuids is not None + assert modality in mm_uuids + + image_uuids = mm_uuids.get(modality) + assert image_uuids is not None + + assert isinstance(image_uuids, list) and len(image_uuids) == media_count + + assert image_uuids == expected_uuids + else: + assert mm_uuids is None ModalityType = Literal["image", "video", "audio"] @@ -176,6 +195,7 @@ MultiModalDataCounts = Mapping[ModalityType, int] def _assert_mm_data_inputs( mm_data: Optional[MultiModalDataDict], data_count: MultiModalDataCounts, + skipped_media_indices: Optional[dict[str, list]] = None, # modality -> list[int] ) -> None: assert mm_data is not None assert set(data_count.keys()) == (set(mm_data.keys())) @@ -185,36 +205,464 @@ def 
_assert_mm_data_inputs( assert modality_data is not None assert isinstance(modality_data, list) and len(modality_data) == n + if skipped_media_indices is not None: + skipped_media_indices_for_modality = skipped_media_indices.get(modality) + assert skipped_media_indices_for_modality is not None + for i in skipped_media_indices_for_modality: + assert modality_data[i] is None + def test_parse_chat_messages_single_image( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_single_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_empty_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_image_with_bad_uuid_format( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + "uuid": image_uuid, + }, + "bad_uuid_key": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = 
"my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_multiple_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_mixed_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + 
], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(await mm_future, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] + _assert_mm_data_is_image_input(await mm_future, 2, skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2]) def test_parse_chat_messages_empty_system( @@ -222,59 +670,40 @@ def test_parse_chat_messages_empty_system( mistral_tokenizer, ): # Test string format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" 
- }] - }], + conversation, _, _ = parse_chat_messages( + [ + {"role": "system", "content": ""}, + { + "role": "user", + "content": [{"type": "text", "text": "Who are you?"}], + }, + ], mistral_model_config, mistral_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": "Who are you?" - }] + assert conversation == [ + {"role": "system", "content": ""}, + {"role": "user", "content": "Who are you?"}, + ] # Test openai format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], + conversation, _, _ = parse_chat_messages( + [ + {"role": "system", "content": ""}, + { + "role": "user", + "content": [{"type": "text", "text": "Who are you?"}], + }, + ], mistral_model_config, mistral_tokenizer, content_format="openai", ) - assert conversation == [{ - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }] + assert conversation == [ + {"role": "system", "content": [{"type": "text", "text": ""}]}, + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ] @pytest.mark.asyncio @@ -283,30 +712,26 @@ async def test_parse_chat_messages_single_image_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) def test_parse_chat_messages_multiple_images( @@ -314,35 +739,129 @@ def test_parse_chat_messages_multiple_images( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" 
- }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_empty_pil_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_pil", "image_pil": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +def test_parse_chat_messages_empty_image_embeds_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + mm_data = await mm_future + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @pytest.mark.asyncio @@ -351,155 +870,20 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] - }], - phi3v_model_config, - phi3v_tokenizer, - content_format="string", - ) - - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" - }] - _assert_mm_data_is_image_input(await mm_future, 2) - - -def test_parse_chat_messages_placeholder_already_in_prompt( - phi3v_model_config, - phi3v_tokenizer, - image_url, -): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" 
- }] - }], - phi3v_model_config, - phi3v_tokenizer, - content_format="string", - ) - assert conversation == [{ - "role": - "user", - "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - _assert_mm_data_is_image_input(mm_data, 2) - - -def test_parse_chat_messages_placeholder_one_already_in_prompt( - phi3v_model_config, - phi3v_tokenizer, - image_url, -): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 - } - ] - }], - phi3v_model_config, - phi3v_tokenizer, - content_format="string", - ) - - assert conversation == [{ - "role": - "user", - "content": - "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?" - }] - _assert_mm_data_is_image_input(mm_data, 2) - - -def test_parse_chat_messages_multiple_images_across_messages( - phi3v_model_config, - phi3v_tokenizer, - image_url, -): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about this one?" - }] - }], + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -508,38 +892,178 @@ def test_parse_chat_messages_multiple_images_across_messages( assert conversation == [ { "role": "user", - "content": "<|image_1|>\nWhat's in this image?" - }, - { - "role": "assistant", - "content": "Some stuff." - }, + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_placeholder_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 + }, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + assert conversation == [ { "role": "user", - "content": "<|image_2|>\nWhat about this one?" 
- }, + "content": "What's in <|image_1|> and how does it compare to <|image_2|>?", + } ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to " + "the other one?", + }, + ], + } + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to " + "the other one?", + } + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What about this one?"}, + ], + }, + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "<|image_2|>\nWhat about this one?"}, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in this image?"}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "What about this one?"}, + ], + }, + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "<|image_2|>\nWhat about this one?"}, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What about this one?" 
- }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [{"type": "text", "text": "What's in this text?"}], + }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What about this one?"}, + ], phi3v_model_config, phi3v_tokenizer, content_format="openai", @@ -548,26 +1072,19 @@ def test_parse_chat_messages_context_text_format( assert conversation == [ { "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] + "content": [{"type": "text", "text": "What's in this text?"}], }, { "role": "assistant", - "content": [{ - "type": "text", - "text": "Some stuff." - }] + "content": [{"type": "text", "text": "Some stuff."}], }, { "role": "user", - "content": [{ - "type": "text", - "text": "What about this one?" - }] + "content": [{"type": "text", "text": "What about this one?"}], }, ] + assert mm_data is None + assert mm_uuids is None def test_parse_chat_messages_rejects_too_many_images_in_one_message( @@ -578,32 +1095,30 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -618,42 +1133,37 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" 
- }] - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + {"type": "text", "text": "What's in this image?"}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + {"type": "text", "text": "What about these two?"}, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -665,30 +1175,30 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - "What's in these images?", { - "image_url": image_url - }, { - "image_url": image_url - } - ] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + "What's in these images?", + {"image_url": image_url}, + {"image_url": image_url}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_interleave( @@ -696,44 +1206,36 @@ def test_parse_chat_messages_multiple_images_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @pytest.mark.asyncio @@ -742,44 +1244,83 @@ async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" 
- }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "and this one"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] + _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_images_multiple_messages_interleave( @@ -787,138 +1328,354 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Be accurate." - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Be accurate."}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|image_1|>\nBe accurate." - }, { - "role": "assistant", - "content": "Some stuff." 
- }, { - "role": "user", - "content": "What's on this image?\n<|image_2|>" - }] + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What's on this image?\n<|image_2|>"}, + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "Be accurate."}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What's on this image?\n<|image_2|>"}, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( - qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, - image_url, video_url, audio_url): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "And what's in the video?" 
- }, { - "type": "video_url", - "video_url": { - "url": video_url - } - }] - }], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Now listen to this audio"}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "And what's in the video?"}, + {"type": "video_url", "video_url": {"url": video_url}}, + ], + }, + ], qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" - }] + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, 2, modality="image", expected_uuids=[None, None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=[None]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": "image_123", + }, + {"type": "text", "text": "Now listen to this audio"}, + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + "uuid": "audio_123", + }, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": "image_123", + }, + {"type": "text", "text": "And what's in the video?"}, + { + "type": "video_url", + "video_url": {"url": video_url}, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": "What's on this 
image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids( + mm_uuids, 2, modality="image", expected_uuids=["image_123", "image_123"] + ) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + {"type": "text", "text": "Now listen to this audio"}, + { + "type": "audio_url", + "audio_url": None, + "uuid": "audio_123", + }, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + {"type": "text", "text": "And what's in the video?"}, + { + "type": "video_url", + "video_url": None, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs( + mm_data, + {"image": 2, "video": 1, "audio": 1}, + skipped_media_indices={"image": [0, 1], "video": [0], "audio": [0]}, + ) + _assert_mm_uuids( + mm_uuids, 2, modality="image", expected_uuids=["image_123", "image_123"] + ) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": "image_123", + }, + {"type": "text", "text": "Now listen to this audio"}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, + ], + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "And what's in the video?"}, + { + "type": "video_url", + "video_url": {"url": video_url}, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + 
"\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", + }, + {"role": "assistant", "content": "Some stuff."}, + { + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, 2, modality="image", expected_uuids=["image_123", None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) def test_parse_chat_messages_multiple_images_interleave_with_placeholders( @@ -927,206 +1684,38 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( image_url, ): with pytest.raises( - ValueError, - match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items."): + ValueError, + match=r"Found more '<|image_1|>' placeholders in input prompt " + "than actual multimodal data items.", + ): parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" - }, - ] - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + }, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) -### Mllama currently wraps images / texts as interleaved dictionaries -def test_mllama_single_image( - mllama_model_config, - mllama_tokenizer, - image_url, -): - """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], - mllama_model_config, - mllama_tokenizer, - content_format="openai", - ) - _assert_mm_data_is_image_input(mm_data, 1) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - 'type': 'image' - }] - }] - - -def test_mllama_interleaved_images( - mllama_model_config, - mllama_tokenizer, - image_url, -): - """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], - mllama_model_config, - mllama_tokenizer, - content_format="openai", - ) - _assert_mm_data_is_image_input(mm_data, 2) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of the first image is:' - }, { - 'type': 'image' - }, { - 'type': 'text', - 'text': 'The content of the second image is:' - }, { - 'type': 'image' - }] - }] - - -@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID]) -def test_multimodal_image_parsing_matches_hf(model, image_url): - 
"""Checks end to end hf alignment for multimodal [image] parsing.""" - - def get_conversation(is_hf: bool): - img_part = {"type": "image_url", "image_url": {"url": image_url}} - if is_hf: - img_part = {'type': 'image'} - return [{ - 'role': - 'user', - 'content': [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'What animal is in the first image?' - }, - ] - }] - - # Build a config for the model - model_config = ModelConfig(model, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - # Build the tokenizer group and grab the underlying tokenizer - tokenizer_group = TokenizerGroup( - model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - trust_remote_code=model_config.trust_remote_code, - ) - tokenizer = tokenizer_group.tokenizer - - # Build and parse a conversation with {"type": "image"} using the tokenizer - hf_conversation = get_conversation(is_hf=True) - hf_result = tokenizer.apply_chat_template( - hf_conversation, - tokenize=False, - add_generation_prompt=True, - ) - - # Now parse with vLLMs chat utils & apply the template - vllm_conversation = get_conversation(is_hf=False) - conversation, _ = parse_chat_messages( - vllm_conversation, - model_config, - tokenizer_group, - content_format="openai", - ) - - vllm_result = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - chat_template=None, - model_config=model_config, - tools=None, - add_generation_prompt=True, - ) - - assert hf_result == vllm_result - - @pytest.mark.parametrize( "model", [ QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ]) + ], +) @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" @@ -1140,26 +1729,31 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) - # Build the tokenizer group and grab the underlying tokenizer - tokenizer_group = TokenizerGroup( + # Build the tokenizer + tokenizer = get_tokenizer( model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer - tools = [{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] if use_tools else None + tools = ( + [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + if use_tools + else None + ) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1171,20 +1765,108 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): assert isinstance(chat_template, str) +@pytest.mark.parametrize( + "model, expected_kwargs", + [ + ( + QWEN2VL_MODEL_ID, + { + "add_vision_id", + "add_generation_prompt", + "continue_final_message", + "tools", + }, + ), + ( + QWEN3_MODEL_ID, + { + "enable_thinking", + "add_generation_prompt", + 
"continue_final_message", + "tools", + }, + ), + ], +) +def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwargs): + """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + + chat_template_kwargs = { + # both unused + "unsed_kwargs_1": 123, + "unsed_kwargs_2": "abc", + # should not appear + "chat_template": "{% Hello world! %}", + # used by tokenizer + "continue_final_message": True, + "tools": tools, + # both used by Qwen2-VL and Qwen3 + "add_generation_prompt": True, + # only used by Qwen2-VL + "add_vision_id": True, + # only used by Qwen3 + "enable_thinking": True, + } + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + revision=model_info.revision, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, + ) + + # Build the tokenizer + tokenizer = get_tokenizer( + model, + trust_remote_code=model_config.trust_remote_code, + ) + + # Test detecting the tokenizer's chat_template + chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=tools, + model_config=model_config, + ) + resolved_chat_template_kwargs = resolve_chat_template_kwargs( + tokenizer, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, + ) + assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs + + # NOTE: Qwen2-Audio default chat template is specially defined inside # processor class instead of using `tokenizer_config.json` -# yapf: disable @pytest.mark.parametrize( ("model", "expected_format"), - [(PHI3V_MODEL_ID, "string"), - (QWEN2VL_MODEL_ID, "openai"), - (QWEN25VL_MODEL_ID, "openai"), - (ULTRAVOX_MODEL_ID, "string"), - (QWEN2AUDIO_MODEL_ID, "openai"), - (MLLAMA_MODEL_ID, "openai"), - (LLAMA_GUARD_MODEL_ID, "openai")], + [ + (PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (QWEN25VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (QWEN2AUDIO_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai"), + ], ) -# yapf: enable def test_resolve_content_format_hf_defined(model, expected_format): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -1196,16 +1878,15 @@ def test_resolve_content_format_hf_defined(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) - tokenizer_group = TokenizerGroup( + tokenizer = get_tokenizer( model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1232,19 +1913,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): assert resolved_format == expected_format -# yapf: disable @pytest.mark.parametrize( ("model", "expected_format"), - [("Salesforce/blip2-opt-2.7b", "string"), - 
("facebook/chameleon-7b", "string"), - ("deepseek-ai/deepseek-vl2-tiny", "string"), - ("microsoft/Florence-2-base", "string"), - ("adept/fuyu-8b", "string"), - ("google/paligemma-3b-mix-224", "string"), - ("Qwen/Qwen-VL", "string"), - ("Qwen/Qwen-VL-Chat", "string")], + [ + ("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string"), + ], ) -# yapf: enable def test_resolve_content_format_fallbacks(model, expected_format): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -1256,16 +1936,15 @@ def test_resolve_content_format_fallbacks(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) - tokenizer_group = TokenizerGroup( + tokenizer = get_tokenizer( model_config.tokenizer, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1292,29 +1971,30 @@ def test_resolve_content_format_fallbacks(model, expected_format): assert resolved_format == expected_format -# yapf: disable @pytest.mark.parametrize( ("template_path", "expected_format"), - [("template_alpaca.jinja", "string"), - ("template_baichuan.jinja", "string"), - ("template_chatglm.jinja", "string"), - ("template_chatglm2.jinja", "string"), - ("template_chatml.jinja", "string"), - ("template_dse_qwen2_vl.jinja", "openai"), - ("template_falcon_180b.jinja", "string"), - ("template_falcon.jinja", "string"), - ("template_inkbot.jinja", "string"), - ("template_teleflm.jinja", "string"), - ("template_vlm2vec.jinja", "openai"), - ("tool_chat_template_granite_20b_fc.jinja", "string"), - ("tool_chat_template_hermes.jinja", "string"), - ("tool_chat_template_internlm2_tool.jinja", "string"), - ("tool_chat_template_llama3.1_json.jinja", "openai"), - ("tool_chat_template_llama3.2_json.jinja", "openai"), - ("tool_chat_template_mistral_parallel.jinja", "string"), - ("tool_chat_template_mistral.jinja", "string")], + [ + ("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_dse_qwen2_vl.jinja", "openai"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_teleflm.jinja", "string"), + ("template_vlm2vec_phi3v.jinja", "openai"), + ("template_vlm2vec_qwen2vl.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "openai"), + ("tool_chat_template_llama3.2_json.jinja", "openai"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string"), + ], ) -# yapf: enable def test_resolve_content_format_examples(template_path, expected_format): model_config = ModelConfig( PHI3V_MODEL_ID, # Dummy @@ -1322,14 +2002,10 @@ def 
test_resolve_content_format_examples(template_path, expected_format): trust_remote_code=True, ) - tokenizer_group = TokenizerGroup( + dummy_tokenizer = get_tokenizer( PHI3V_MODEL_ID, # Dummy - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None chat_template = load_chat_template(EXAMPLES_DIR / template_path) @@ -1351,163 +2027,186 @@ def test_resolve_content_format_examples(template_path, expected_format): assert resolved_format == expected_format -def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, - mistral_tokenizer): - messages = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": - "thinking", - "closed": - True, - "thinking": - "Only return the answer when you are confident." - }] - }, { - "role": "user", - "content": "What is 2+2?" - }, { - "role": - "assistant", - "content": [{ - "type": "text", - "text": "Let me think about it." - }, { - "type": "thinking", - "closed": True, - "thinking": "2+2 = 4" - }, { - "type": "text", - "text": "The answer is 4.", - }], - }] +def test_parse_chat_messages_include_thinking_chunk( + mistral_model_config, mistral_tokenizer +): + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "thinking", + "closed": True, + "thinking": "Only return the answer when you are confident.", + }, + ], + }, + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, + { + "type": "text", + "text": "The answer is 4.", + }, + ], + }, + ] - conversation_with_thinking, _ = parse_chat_messages( + conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, mistral_tokenizer, content_format="openai", ) - expected_conversation = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": "text", - "text": "Only return the answer when you are confident." - }], - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What is 2+2?" - }], - }, { - "role": - "assistant", - "content": [ - { - "type": "text", - "text": "Let me think about it." - }, - { - "type": "text", - "text": "2+2 = 4" - }, - { - "type": "text", - "text": "The answer is 4." - }, - ] - }] + expected_conversation = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "text", + "text": "Only return the answer when you are confident.", + }, + ], + }, + { + "role": "user", + "content": [{"type": "text", "text": "What is 2+2?"}], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "text", "text": "2+2 = 4"}, + {"type": "text", "text": "The answer is 4."}, + ], + }, + ] assert conversation_with_thinking == expected_conversation def test_apply_mistral_chat_template_thinking_chunk(): - # Moved import here to avoid yapf and isort conflicts - from vllm.entrypoints.chat_utils import apply_mistral_chat_template - messages = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": - "thinking", - "closed": - True, - "thinking": - "Only return the answer when you are confident." 
- }] - }, { - "role": "user", - "content": "What is 2+2?" - }, { - "role": - "assistant", - "content": [{ - "type": "text", - "text": "Let me think about it." - }, { - "type": "thinking", - "closed": True, - "thinking": "2+2 = 4" - }, { - "type": "text", - "text": "The answer is 4.", - }], - }, { - "role": "user", - "content": "Thanks, what is 3+3?" - }] - - # TODO(Julien): upon model release change to a tokenizer already configured. - # ================================================================= + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "thinking", + "closed": True, + "thinking": "Only return the answer when you are confident.", + }, + ], + }, + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, + { + "type": "text", + "text": "The answer is 4.", + }, + ], + }, + {"role": "user", "content": "Thanks, what is 3+3?"}, + ] mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507") - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= + "mistralai/Magistral-Small-2509" + ) - tokens_ids = apply_mistral_chat_template(mistral_tokenizer, - messages, - chat_template=None, - tools=None) + tokens_ids = apply_mistral_chat_template( + mistral_tokenizer, messages, chat_template=None, tools=None + ) string_tokens = mistral_tokenizer.mistral.decode( - tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP) + tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP + ) expected_tokens = ( r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" r"[INST]What is 2+2?[/INST]" r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>" - r"[INST]Thanks, what is 3+3?[/INST]") + r"[INST]Thanks, what is 3+3?[/INST]" + ) assert string_tokens == expected_tokens + + +def test_parse_chat_messages_single_empty_audio_with_uuid( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + {"type": "text", "text": "What does the audio say?"}, + ], + } + ], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the " + "audio say?", + } + ] + 
_assert_mm_data_inputs(mm_data, {"audio": 1}) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[audio_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_empty_audio_with_uuid_async( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + {"type": "text", "text": "What does the audio say?"}, + ], + } + ], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the " + "audio say?", + } + ] + _assert_mm_data_inputs(await mm_future, {"audio": 1}) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[audio_uuid]) diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py new file mode 100644 index 0000000000000..b0faa870a9272 --- /dev/null +++ b/tests/entrypoints/test_context.py @@ -0,0 +1,524 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest +from openai_harmony import Author, Message, Role, StreamState, TextContent + +from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext +from vllm.outputs import CompletionOutput, RequestOutput + + +def create_mock_request_output( + prompt_token_ids=None, + output_token_ids=None, + num_cached_tokens=0, + finished=True, +): + """Helper function to create a mock RequestOutput object for testing.""" + outputs = [] + token_ids = output_token_ids if output_token_ids is not None else [] + outputs = [ + CompletionOutput( + index=0, + text="Test output", + token_ids=token_ids, + cumulative_logprob=0.0, + logprobs=None, + finish_reason=None, + stop_reason=None, + ) + ] + + return RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + outputs=outputs, + finished=finished, + num_cached_tokens=num_cached_tokens, + ) + + +async def generate_mock_outputs( + num_turns, prompt_token_counts, output_token_counts, cached_token_counts=None +): + """Generate a sequence of mock RequestOutput objects to simulate multiple + turns.""" + if cached_token_counts is None: + cached_token_counts = [0] * num_turns + + for i in range(num_turns): + # Create mock prompt token IDs and output token IDs + prompt_token_ids = list(range(1, prompt_token_counts[i] + 1)) + output_token_ids = list(range(1, output_token_counts[i] + 1)) + + # Create and yield the RequestOutput + yield create_mock_request_output( + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + num_cached_tokens=cached_token_counts[i], + ) + + +@pytest.fixture +def mock_parser(): + """Set up a mock parser for tests.""" + with patch( + "vllm.entrypoints.context.get_streamable_parser_for_assistant" + ) as mock_parser_factory: + # Create a mock parser object + parser = MagicMock() + parser.messages = [] + parser.current_channel = None + parser.state = StreamState.EXPECT_START + mock_parser_factory.return_value = parser + yield parser + + +def test_single_turn_token_counting(): + """Test token counting behavior for a single turn.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a mock RequestOutput with specific token counts 
+ mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens + output_token_ids=[6, 7, 8], # 3 output tokens + num_cached_tokens=2, # 2 cached tokens + ) + + # Append the output to the context + context.append_output(mock_output) + + # Verify the token counts + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_cached_tokens == 2 + assert context.num_tool_output_tokens == 0 # No tool tokens in first turn + + # Verify internal state tracking + assert not context.is_first_turn + assert context.previous_turn.input_tokens == 5 + assert context.previous_turn.output_tokens == 3 + + +@pytest.mark.asyncio +async def test_multi_turn_token_counting(): + """Test token counting behavior across multiple turns with tool output.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate a conversation with 3 turns + # Turn 1: prefill 5, decode 3, tool 7 + # Turn 2: prefill 15, cached 5, decode 4, tool 1 + # Turn 3: prefill 20, cached 15, decode 5 + prompt_token_counts = [5, 15, 20] + output_token_counts = [3, 4, 5] + cached_token_counts = [0, 5, 15] + mock_generator = generate_mock_outputs( + 3, prompt_token_counts, output_token_counts, cached_token_counts + ) + + # First turn - initial prompt and response + mock_output1 = await anext(mock_generator) + context.append_output(mock_output1) + + # At this point, we should have 5 prompt tokens and 3 output tokens + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_tool_output_tokens == 0 + + # Second turn - after tool output + mock_output2 = await anext(mock_generator) + context.append_output(mock_output2) + # Current prompt tokens (15) - last_turn_input_tokens (5) - + # last_turn_output_tokens (3) = 7 + expected_tool_output = 7 + + assert context.num_prompt_tokens == 5 + 15 + assert context.num_output_tokens == 3 + 4 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + + # Third turn - final response + mock_output3 = await anext(mock_generator) + context.append_output(mock_output3) + # Additional tool output tokens from third turn: + # Current prompt (20) - last_turn_input_tokens (15) - + # last_turn_output_tokens (4) = 1 + expected_tool_output = 7 + 1 + + assert context.num_prompt_tokens == 5 + 15 + 20 + assert context.num_output_tokens == 3 + 4 + 5 + assert context.num_tool_output_tokens == expected_tool_output + assert context.num_cached_tokens == 5 + 15 + + +def test_empty_output_tokens(): + """Test behavior when RequestOutput has empty output tokens.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a RequestOutput with empty output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[], # Empty output tokens list + num_cached_tokens=1, + ) + + context.append_output(mock_output) + + # Should handle empty outputs gracefully + assert context.num_prompt_tokens == 3 + assert context.num_output_tokens == 0 # No output tokens + assert context.num_cached_tokens == 1 + assert context.num_tool_output_tokens == 0 + + +def test_missing_prompt_token_ids(): + """Test behavior when RequestOutput has None prompt_token_ids.""" + context = HarmonyContext(messages=[], available_tools=[]) + + mock_output = create_mock_request_output( + prompt_token_ids=None, # No prompt token IDs + output_token_ids=[1, 2], # 2 output tokens + num_cached_tokens=0, + ) 
+ + # Logger.error will be called, but we don't need to check for warnings + # here Just ensure it doesn't raise an exception + context.append_output(mock_output) + + # Should handle missing prompt tokens gracefully + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 2 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + + +def test_reasoning_tokens_counting(mock_parser): + """Test that reasoning tokens are counted correctly.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Mock parser to simulate reasoning channel + mock_parser.current_channel = "analysis" # Reasoning channel + + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], + output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All output tokens should be counted as reasoning + assert context.num_reasoning_tokens == 4 + assert context.num_output_tokens == 4 + + +def test_zero_tokens_edge_case(): + """Test behavior with all zero token counts.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a request with empty lists (not None) for both prompt and + # output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[], # Empty prompt tokens + output_token_ids=[], # Empty output tokens + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All counts should be zero + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 0 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + assert context.num_reasoning_tokens == 0 + + +@pytest.mark.asyncio +async def test_single_turn_no_tool_output(): + """Test that first turn never generates tool output tokens.""" + context = HarmonyContext( + messages=[], + available_tools=["browser"], # Tools available + ) + + # Even with large prompt in first turn, no tool tokens should be counted + mock_output = create_mock_request_output( + prompt_token_ids=list(range(100)), # 100 tokens + output_token_ids=[1, 2, 3], + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # First turn should never have tool output tokens + assert context.num_tool_output_tokens == 0 + assert context.is_first_turn is False # Should be updated after first turn + + +@pytest.mark.asyncio +async def test_negative_tool_tokens_edge_case(): + """Test edge case where calculation could result in negative tool + tokens. 
We should log an error and clamp the value to 0.""" + # Use patch to check if logger.error was called + with patch("vllm.entrypoints.context.logger.error") as mock_log: + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # First turn + mock_output1 = create_mock_request_output( + prompt_token_ids=list(range(10)), # 10 tokens + output_token_ids=[1, 2, 3, 4, 5], # 5 tokens + ) + context.append_output(mock_output1) + + # Second turn with fewer new tokens than previous output + # This could happen in edge cases with aggressive caching + mock_output2 = create_mock_request_output( + prompt_token_ids=list(range(12)), # 12 tokens (only 2 new) + output_token_ids=[6, 7], # 2 tokens + ) + context.append_output(mock_output2) + + # Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped + # to 0 and an error should be logged + assert context.num_tool_output_tokens == 0 + assert context.num_prompt_tokens == 10 + 12 + assert context.num_output_tokens == 5 + 2 + + # Verify the error was logged properly + mock_log.assert_called_once() + + # Extract the actual log message and arguments from the call + args, _ = mock_log.call_args + log_message = args[0] + + # Check for key parts of the message + assert "Negative tool output tokens calculated" in log_message + assert "-3" in str(args) # Check that -3 is in the arguments + + +@pytest.mark.asyncio +async def test_streaming_multi_turn_token_counting(mock_parser): + """Test token counting for streaming multi-turn conversations. + + This test focuses on how StreamingHarmonyContext counts tokens in a + multi-turn conversation with streaming (token-by-token) outputs and + message boundaries. + """ + # Create a streaming context + context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate three turns of conversation: + # Turn 1: stream tokens one by one, then finish the message + # Turn 2: new prompt, stream more tokens with a reasoning segment + # Turn 3: new prompt with tool output and cached tokens + + # First turn: 3 tokens streamed one by one + # First token of first turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[101], # Single token + num_cached_tokens=0, + finished=False, # Not end of message yet + ) + ) + + # Second token of first turn + context.append_output( + create_mock_request_output( + output_token_ids=[102], + finished=False, + ) + ) + + # Last token of first turn (finished=True signals end of message) + context.append_output( + create_mock_request_output( + output_token_ids=[103], + finished=True, # End of message + ) + ) + + # Check token counts after first turn + assert context.num_prompt_tokens == 3 # Initial prompt tokens + assert context.num_output_tokens == 3 # Three output tokens + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 # No tool output in first turn + assert context.first_tok_of_message is True # Ready for next message + + # Second turn: reasoning tokens in analysis channel + mock_parser.current_channel = "analysis" # Set to reasoning channel + + # First token of second turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[ + 1, + 2, + 3, + 101, + 102, + 103, + 4, + 5, + ], # 8 tokens (includes previous) + output_token_ids=[201], + num_cached_tokens=3, # Some tokens cached + finished=False, + ) + ) + + # More tokens in reasoning channel + context.append_output( + create_mock_request_output( + output_token_ids=[202], + 
finished=False, + ) + ) + + context.append_output( + create_mock_request_output( + output_token_ids=[203], + finished=True, # End of reasoning message + ) + ) + + # Check counts after second turn (reasoning message) + assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt + assert context.num_output_tokens == 3 + 3 # First turn + second turn + assert context.num_reasoning_tokens == 3 # All tokens in analysis channel + assert context.num_cached_tokens == 3 # Cached tokens from second turn + + # Formula: this turn prompt tokens - last turn prompt - last turn output + expected_tool_tokens = 8 - 3 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens + + # Third turn: regular output channel + mock_parser.current_channel = "final" # Switch back to regular channel + + # Third turn (with more cached tokens) + context.append_output( + create_mock_request_output( + prompt_token_ids=[ + 1, + 2, + 3, + 101, + 102, + 103, + 4, + 5, + 201, + 202, + 203, + 6, + 7, + ], # 13 tokens + output_token_ids=[301], + num_cached_tokens=8, # More cached tokens + finished=False, + ) + ) + + context.append_output( + create_mock_request_output( + output_token_ids=[302], + finished=True, + ) + ) + + # Final token counts check + assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts + assert context.num_output_tokens == 3 + 3 + 2 # All outputs + assert context.num_reasoning_tokens == 3 # Unchanged from second turn + assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens + + # Additional tool tokens from third turn + # Formula: this turn prompt - last turn prompt - last turn output + additional_tool_tokens = 13 - 8 - 3 # = 2 + assert ( + context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens + ) + + +@pytest.mark.asyncio +async def test_streaming_message_synchronization(mock_parser): + """Test message synchronization logic from lines 413-417 in context.py. + + This test verifies that when parser.messages contains more messages than + the context's _messages (minus initial messages), the context properly + extends its message list with the new parser messages. 
+ """ + + # Create a streaming context with some initial messages + initial_messages = [ + Message( + author=Author(role=Role.USER, name="user"), + content=[TextContent(text="Hello")], + recipient=Role.ASSISTANT, + ) + ] + context = StreamingHarmonyContext(messages=initial_messages, available_tools=[]) + + # Verify initial state + assert len(context._messages) == 1 + assert context.num_init_messages == 1 + + # Mock parser to have more messages than context + # Simulate parser having processed 3 new messages + mock_parser.messages = [ + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 1")], + recipient=Role.USER, + ), + ] + + # This should trigger the message synchronization logic + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3], output_token_ids=[101], finished=False + ) + ) + + # Verify that messages were synchronized + assert len(context._messages) == 2 + + # Verify the new messages were added correctly + assert context._messages[1].content[0].text == "Response 1" + + # Test the specific condition from line 413-414: + # len(self._messages) - self.num_init_messages < len(self.parser.messages) + messages_minus_init = len(context._messages) - context.num_init_messages + parser_messages_count = len(mock_parser.messages) + + # After synchronization, they should be equal (no longer less than) + assert messages_minus_init == parser_messages_count + + # Test edge case: add one more parser message + mock_parser.messages.append( + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 4")], + recipient=Role.USER, + ) + ) + + # Create another output to trigger synchronization again + mock_output2 = create_mock_request_output( + prompt_token_ids=[1, 2, 3], output_token_ids=[102], finished=True + ) + + context.append_output(mock_output2) + + # Verify the fourth message was added, num_init_messages is still 1 + assert len(context._messages) == 3 + assert context.num_init_messages == 1 + assert context._messages[2].content[0].text == "Response 4" diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py new file mode 100644 index 0000000000000..f93978c3e6e72 --- /dev/null +++ b/tests/entrypoints/test_renderer.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +from dataclasses import dataclass +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +import pybase64 +import pytest +import torch + +from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig +from vllm.inputs.data import is_embeds_prompt + + +@dataclass +class MockModelConfig: + max_model_len: int = 100 + encoder_config: Optional[dict] = None + + +class MockTokenizerResult: + def __init__(self, input_ids): + self.input_ids = input_ids + + +@pytest.fixture +def mock_model_config(): + return MockModelConfig() + + +@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + return tokenizer + + +@pytest.fixture +def mock_async_tokenizer(): + async_tokenizer = AsyncMock() + return async_tokenizer + + +@pytest.fixture +def renderer(mock_model_config, mock_tokenizer): + return CompletionRenderer( + model_config=mock_model_config, + tokenizer=mock_tokenizer, + async_tokenizer_pool={}, + ) + + +class TestRenderPrompt: + """Test Category A: Basic Functionality Tests""" + + @pytest.mark.asyncio + async def test_token_input(self, renderer): + tokens = [101, 
7592, 2088] + results = await renderer.render_prompt( + prompt_or_prompts=tokens, config=RenderConfig(max_length=100) + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + + @pytest.mark.asyncio + async def test_token_list_input(self, renderer): + token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] + results = await renderer.render_prompt( + prompt_or_prompts=token_lists, config=RenderConfig(max_length=100) + ) + + assert len(results) == 3 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] + assert results[2]["prompt_token_ids"] == [103, 4567] + + @pytest.mark.asyncio + async def test_text_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + mock_async_tokenizer.assert_called_once() + + @pytest.mark.asyncio + async def test_text_list_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + text_list_input = ["Hello world", "How are you?", "Good morning"] + results = await renderer.render_prompt( + prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100) + ) + + assert len(results) == 3 + for result in results: + assert result["prompt_token_ids"] == [101, 7592, 2088] + assert mock_async_tokenizer.call_count == 3 + + @pytest.mark.asyncio + async def test_no_truncation(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert ( + "truncation" not in call_args.kwargs + or call_args.kwargs["truncation"] is False + ) + + @pytest.mark.asyncio + async def test_truncation_positive(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088] + ) # Truncated + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100, truncate_prompt_tokens=50), + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 50 + + @pytest.mark.asyncio + async def test_truncation_negative(self, renderer, mock_async_tokenizer): + # Test that negative truncation uses model's max_model_len + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088] + ) # Truncated to max_model_len + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=200, truncate_prompt_tokens=-1), + ) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert 
call_args.kwargs["max_length"] == 100 # model's max_model_len + + @pytest.mark.asyncio + async def test_token_truncation_last_elements(self, renderer): + # Test that token truncation keeps the last N elements + long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens + results = await renderer.render_prompt( + prompt_or_prompts=long_tokens, + config=RenderConfig(max_length=100, truncate_prompt_tokens=5), + ) + + assert len(results) == 1 + # Should keep the last 5 tokens: [105, 106, 107, 108, 109] + assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] + + @pytest.mark.asyncio + async def test_max_length_exceeded(self, renderer): + long_tokens = list(range(150)) # Exceeds max_model_len=100 + + with pytest.raises(ValueError, match="maximum context length"): + await renderer.render_prompt( + prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100) + ) + + @pytest.mark.asyncio + async def test_no_tokenizer_for_text(self, mock_model_config): + renderer_no_tokenizer = CompletionRenderer( + model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={} + ) + + with pytest.raises(ValueError, match="No tokenizer available"): + await renderer_no_tokenizer.render_prompt( + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) + + @pytest.mark.asyncio + async def test_token_input_with_needs_detokenization( + self, renderer, mock_async_tokenizer + ): + # When needs_detokenization=True for token inputs, renderer should + # use the async tokenizer to decode and include the original text + # in the returned prompt object. + mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + tokens = [1, 2, 3, 4] + results = await renderer.render_prompt( + prompt_or_prompts=tokens, + config=RenderConfig(needs_detokenization=True), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + assert results[0]["prompt"] == "decoded text" + mock_async_tokenizer.decode.assert_awaited_once() + + +class TestRenderEmbedPrompt: + def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: + """Helper to create base64-encoded tensor bytes""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return pybase64.b64encode(buffer.read()) + + @pytest.mark.asyncio + async def test_single_prompt_embed(self, renderer): + # Create a test tensor + test_tensor = torch.randn(10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(cache_salt="test_salt"), + ) + + assert len(results) == 1 + assert is_embeds_prompt(results[0]) + assert torch.allclose(results[0]["prompt_embeds"], test_tensor) + assert results[0]["cache_salt"] == "test_salt" + + @pytest.mark.asyncio + async def test_multiple_prompt_embeds(self, renderer): + # Create multiple test tensors + test_tensors = [ + torch.randn(8, 512, dtype=torch.float32), + torch.randn(12, 512, dtype=torch.float32), + ] + embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors] + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes_list, + config=RenderConfig(), + ) + + assert len(results) == 2 + for i, result in enumerate(results): + assert is_embeds_prompt(result) + assert torch.allclose(result["prompt_embeds"], test_tensors[i]) + + @pytest.mark.asyncio + async def test_prompt_embed_truncation(self, 
renderer): + # Create tensor with more tokens than truncation limit + test_tensor = torch.randn(20, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(truncate_prompt_tokens=10), + ) + + assert len(results) == 1 + # Should keep last 10 tokens + expected = test_tensor[-10:] + assert torch.allclose(results[0]["prompt_embeds"], expected) + + @pytest.mark.asyncio + async def test_prompt_embed_different_dtypes(self, renderer): + # Test different supported dtypes + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + for dtype in dtypes: + test_tensor = torch.randn(5, 256, dtype=dtype) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + assert results[0]["prompt_embeds"].dtype == dtype + + @pytest.mark.asyncio + async def test_prompt_embed_squeeze_batch_dim(self, renderer): + # Test tensor with batch dimension gets squeezed + test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + # Should be squeezed to 2D + assert results[0]["prompt_embeds"].shape == (10, 768) + + @pytest.mark.asyncio + async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): + # Set up text tokenization + mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer + + # Create embed + test_tensor = torch.randn(5, 256, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_or_prompts="Hello world", + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 2 + # First should be embed prompt + assert is_embeds_prompt(results[0]) + # Second should be tokens prompt + assert "prompt_token_ids" in results[1] + assert results[1]["prompt_token_ids"] == [101, 102, 103] diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py index 33ad2cfd3a33a..b56fbd9fee7e0 100644 --- a/tests/entrypoints/test_ssl_cert_refresher.py +++ b/tests/entrypoints/test_ssl_cert_refresher.py @@ -11,7 +11,6 @@ from vllm.entrypoints.ssl import SSLCertRefresher class MockSSLContext(SSLContext): - def __init__(self): self.load_cert_chain_count = 0 self.load_ca_count = 0 @@ -34,7 +33,7 @@ class MockSSLContext(SSLContext): def create_file() -> str: - with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f: + with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as f: return f.name diff --git a/tests/evals/gpt_oss/__init__.py b/tests/evals/gpt_oss/__init__.py new file mode 100644 index 0000000000000..208f01a7cb5ee --- /dev/null +++ b/tests/evals/gpt_oss/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/evals/gpt_oss/conftest.py b/tests/evals/gpt_oss/conftest.py new file mode 100644 index 0000000000000..2f140ae2c8e9b --- /dev/null +++ b/tests/evals/gpt_oss/conftest.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project +""" +Pytest configuration for GPT-OSS evaluation tests. +""" + + +def pytest_addoption(parser): + """Add command line options for pytest.""" + parser.addoption("--model", action="store", help="Model name to evaluate") + parser.addoption( + "--metric", action="store", type=float, help="Expected metric threshold" + ) + parser.addoption( + "--server-args", action="store", default="", help="Additional server arguments" + ) diff --git a/tests/evals/gpt_oss/test_gpqa_correctness.py b/tests/evals/gpt_oss/test_gpqa_correctness.py new file mode 100644 index 0000000000000..151deaa059f0d --- /dev/null +++ b/tests/evals/gpt_oss/test_gpqa_correctness.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPQA evaluation using vLLM server and GPT-OSS evaluation package. + +Usage: +pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \ + --model openai/gpt-oss-20b \ + --metric 0.58 \ + --server-args "--tensor-parallel-size 2" +""" + +import subprocess +import sys + +import regex as re + +from tests.utils import RemoteOpenAIServer + +TOL = 0.05 # Absolute tolerance for accuracy comparison + + +def run_gpqa_eval(model_name: str, base_url: str) -> float: + """Run GPQA evaluation using the gpt-oss evaluation package.""" + + # Build the command to run the evaluation + cmd = [ + sys.executable, + "-m", + "gpt_oss.evals", + "--eval", + "gpqa", + "--model", + model_name, + "--reasoning-effort", + "low", + "--base-url", + base_url, + "--n-threads", + "200", + ] + + try: + # Run the evaluation + result = subprocess.run( + cmd, + text=True, + capture_output=True, + timeout=1800, # 30 minute timeout + env={"OPENAI_API_KEY": "dummy"}, + ) + + print("Evaluation process output:\n", result.stdout) + + # Parse the output to extract the score + match = re.search(r"'metric':\s*([\d.]+)", result.stdout) + if match: + return float(match.group(1)) + + # If we still can't find it, raise an error + raise ValueError( + f"Could not parse score from evaluation output:\n{result.stdout}" + ) + + except subprocess.TimeoutExpired as e: + raise RuntimeError("Evaluation timed out") from e + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Evaluation failed with exit code {e.returncode}:\n" + f"stdout: {e.stdout}\nstderr: {e.stderr}" + ) from e + + +def test_gpqa_correctness(request): + """Test GPQA correctness for GPT-OSS model.""" + + # Get command line arguments + model_name = request.config.getoption("--model") + expected_metric = request.config.getoption("--metric") + server_args_str = request.config.getoption("--server-args") + + # Parse server arguments + server_args = [] + if server_args_str: + server_args = server_args_str.split() + + # Add standard server arguments + server_args.extend( + [ + "--trust-remote-code", + ] + ) + + print(f"Starting GPQA evaluation for model: {model_name}") + print(f"Expected metric threshold: {expected_metric}") + print(f"Server args: {' '.join(server_args)}") + + # Launch server and run evaluation + with RemoteOpenAIServer( + model_name, server_args, max_wait_seconds=1800 + ) as remote_server: + base_url = remote_server.url_for("v1") + print(f"Server started at: {base_url}") + + measured_metric = run_gpqa_eval(model_name, base_url) + + print(f"GPQA Results for {model_name}:") + print(f" Measured metric: {measured_metric:.4f}") + print(f" Expected metric: {expected_metric:.4f}") + print(f" Tolerance: {TOL:.4f}") + + # Verify metric is within tolerance + assert measured_metric >= 
expected_metric - TOL, ( + f"GPQA metric too low: {measured_metric:.4f} < " + f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" + ) + + print(f"✅ GPQA test passed for {model_name}") diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md index 58572c3a6fbc1..29c5199e1e87a 100644 --- a/tests/evals/gsm8k/README.md +++ b/tests/evals/gsm8k/README.md @@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 # Run evaluation -python tests/gsm8k/gsm8k_eval.py --port 8000 +python tests/evals/gsm8k/gsm8k_eval.py --port 8000 ``` ## Configuration Format diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py index 0fec1fe5bcdfd..208f01a7cb5ee 100644 --- a/tests/evals/gsm8k/__init__.py +++ b/tests/evals/gsm8k/__init__.py @@ -1,2 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml new file mode 100644 index 0000000000000..7ec6a1e0be27f --- /dev/null +++ b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml @@ -0,0 +1,6 @@ +model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" +accuracy_threshold: 0.72 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 + diff --git a/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml new file mode 100644 index 0000000000000..6b7bdd1e65bb3 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml @@ -0,0 +1,6 @@ +model_name: "nvidia/Qwen3-30B-A3B-FP4" +accuracy_threshold: 0.89 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 + diff --git a/tests/evals/gsm8k/configs/models-blackwell.txt b/tests/evals/gsm8k/configs/models-blackwell.txt new file mode 100644 index 0000000000000..3c9b1084de7bc --- /dev/null +++ b/tests/evals/gsm8k/configs/models-blackwell.txt @@ -0,0 +1,5 @@ +Qwen3-0.6B-FP8.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-CT.yaml +DeepSeek-V2-Lite-Instruct-FP8.yaml +Qwen3-30B-A3B-NVFP4.yaml diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt index afd1065b9191b..7bce3f0004f7d 100644 --- a/tests/evals/gsm8k/configs/models-small.txt +++ b/tests/evals/gsm8k/configs/models-small.txt @@ -3,3 +3,4 @@ Llama-3.2-1B-Instruct-INT8-CT.yaml Llama-3-8B-Instruct-nonuniform-CT.yaml Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-CT.yaml +DeepSeek-V2-Lite-Instruct-FP8.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py index d96b0a66ede2b..1932a13cdfc63 100644 --- a/tests/evals/gsm8k/conftest.py +++ b/tests/evals/gsm8k/conftest.py @@ -6,13 +6,12 @@ from pathlib import Path def pytest_addoption(parser): """Add custom command line options.""" - parser.addoption("--config-list-file", - default="configs/models-small.txt", - help="File containing list of config files to test") - parser.addoption("--tp-size", - default=1, - type=int, - help="Tensor parallel size") + parser.addoption( + "--config-list-file", + default="configs/models-small.txt", + help="File containing list of config files to test", + ) + parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size") def pytest_generate_tests(metafunc): @@ -55,12 +54,10 @@ def pytest_generate_tests(metafunc): # 
Generate test parameters if config_files: - metafunc.parametrize(["config_filename", "tp_size"], - [(config_file, int(tp_size)) - for config_file in config_files], - ids=[ - f"{config_file.stem}-tp{tp_size}" - for config_file in config_files - ]) + metafunc.parametrize( + ["config_filename", "tp_size"], + [(config_file, int(tp_size)) for config_file in config_files], + ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], + ) else: print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 7d0ce25f75dd4..9edec7a78ca23 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -76,13 +76,15 @@ def get_answer_value(answer_str: str) -> int: return INVALID -async def call_vllm_api(session: aiohttp.ClientSession, - prompt: str, - temperature: float, - max_tokens: int, - stop: Optional[list[str]] = None, - url: Optional[str] = None, - seed: Optional[int] = None) -> str: +async def call_vllm_api( + session: aiohttp.ClientSession, + prompt: str, + temperature: float, + max_tokens: int, + stop: Optional[list[str]] = None, + url: Optional[str] = None, + seed: Optional[int] = None, +) -> str: """Call vLLM's OpenAI-compatible completions endpoint.""" data = { "prompt": prompt, @@ -94,8 +96,7 @@ async def call_vllm_api(session: aiohttp.ClientSession, data["seed"] = seed try: - async with session.post(f"{url}/v1/completions", - json=data) as response: + async with session.post(f"{url}/v1/completions", json=data) as response: response.raise_for_status() result = await response.json() return result["choices"][0]["text"] @@ -104,16 +105,18 @@ async def call_vllm_api(session: aiohttp.ClientSession, return "" -def evaluate_gsm8k(num_questions: int = 1319, - num_shots: int = 5, - max_tokens: int = 256, - host: str = "http://127.0.0.1", - port: int = 8000, - temperature: float = 0.0, - seed: Optional[int] = 42) -> dict[str, Union[float, int]]: +def evaluate_gsm8k( + num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + host: str = "http://127.0.0.1", + port: int = 8000, + temperature: float = 0.0, + seed: Optional[int] = 42, +) -> dict[str, Union[float, int]]: """ Evaluate GSM8K accuracy using vLLM serve endpoint. - + Returns dict with accuracy, invalid_rate, latency, etc. 
""" base_url = f"{host}:{port}" @@ -127,8 +130,10 @@ def evaluate_gsm8k(num_questions: int = 1319, # Build few-shot examples from train split (like lm-eval does) few_shot_examples = "" for i in range(num_shots): - few_shot_examples += (f"Question: {train_data[i]['question']}\n" - f"Answer: {train_data[i]['answer']}\n\n") + few_shot_examples += ( + f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n" + ) # Prepare test questions and labels from test split questions = [] @@ -157,15 +162,15 @@ def evaluate_gsm8k(num_questions: int = 1319, states[i] = answer return answer - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( - total=600)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=600) + ) as session: tasks = [get_answer(session, i) for i in range(num_questions)] await tqdm.gather(*tasks, desc="Evaluating") return states - print(f"Running GSM8K evaluation: {num_questions} questions, " - f"{num_shots}-shot") + print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot") tic = time.perf_counter() states = asyncio.run(run_async_evaluation()) @@ -191,36 +196,28 @@ def evaluate_gsm8k(num_questions: int = 1319, def main() -> None: - parser = argparse.ArgumentParser( - description="GSM8K evaluation for vLLM serve") - parser.add_argument("--num-shots", - type=int, - default=5, - help="Number of few-shot examples") - parser.add_argument("--num-questions", - type=int, - default=1319, - help="Number of questions to evaluate") - parser.add_argument("--max-tokens", - type=int, - default=256, - help="Max tokens for generation") - parser.add_argument("--host", - type=str, - default="http://127.0.0.1", - help="Host URL") + parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve") + parser.add_argument( + "--num-shots", type=int, default=5, help="Number of few-shot examples" + ) + parser.add_argument( + "--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate", + ) + parser.add_argument( + "--max-tokens", type=int, default=256, help="Max tokens for generation" + ) + parser.add_argument("--host", type=str, default="http://127.0.0.1", help="Host URL") parser.add_argument("--port", type=int, default=8000, help="Port number") - parser.add_argument("--temperature", - type=float, - default=0.0, - help="Temperature for generation") - parser.add_argument("--seed", - type=int, - default=42, - help="Random seed for reproducibility") - parser.add_argument("--save-results", - type=str, - help="Save results to JSON file") + parser.add_argument( + "--temperature", type=float, default=0.0, help="Temperature for generation" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + parser.add_argument("--save-results", type=str, help="Save results to JSON file") args = parser.parse_args() diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index a12dd49dbea6d..ce3ab8096b45c 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -63,9 +63,9 @@ def test_gsm8k_correctness_param(config_filename, tp_size): ] # Launch server and run evaluation - with RemoteOpenAIServer(eval_config["model_name"], - server_args, - max_wait_seconds=480) as remote_server: + with RemoteOpenAIServer( + eval_config["model_name"], server_args, max_wait_seconds=480 + ) as remote_server: server_url = remote_server.url_for("v1") results = 
launch_gsm8k_eval(eval_config, server_url, tp_size) @@ -85,6 +85,7 @@ def test_gsm8k_correctness_param(config_filename, tp_size): # Verify accuracy is within tolerance assert measured_accuracy >= expected_accuracy - RTOL, ( f"Accuracy too low: {measured_accuracy:.3f} < " - f"{expected_accuracy:.3f} - {RTOL:.3f}") + f"{expected_accuracy:.3f} - {RTOL:.3f}" + ) print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/allclose_default.py b/tests/kernels/allclose_default.py index 9d65159bf64fe..6561e9556fa7a 100644 --- a/tests/kernels/allclose_default.py +++ b/tests/kernels/allclose_default.py @@ -6,11 +6,7 @@ import torch # Reference default values of atol and rtol are from # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} -default_rtol = { - torch.float16: 1e-3, - torch.bfloat16: 1.6e-2, - torch.float: 1.3e-6 -} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} def get_default_atol(output) -> float: diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index 88a2fb62b2540..b080a71bd54e6 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,8 +3,7 @@ import pytest -from vllm.utils import (create_kv_caches_with_random, - create_kv_caches_with_random_flash) +from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash @pytest.fixture() diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index 2d882bdf4066f..88b21a9b84d64 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -39,7 +39,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -57,10 +57,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -74,11 +77,10 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Only ROCm is supported") -@pytest.mark.parametrize("seq_lens", - [[(10, 1328), (5, 18), - (129, 463)], [(8, 523), (24, 37), (3, 2011)]]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only ROCm is supported") +@pytest.mark.parametrize( + "seq_lens", [[(10, 1328), (5, 18), (129, 463)], [(8, 523), (24, 37), (3, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -109,34 +111,27 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if 
sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) - cu_seq_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) output = torch.empty_like(query) @@ -187,5 +182,7 @@ def test_varlen_with_paged_kv( atol, rtol = 2e-2, 2e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 7083661575ef2..16e544eb3cf9f 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -18,7 +18,7 @@ if not current_platform.is_rocm(): from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - from vllm.attention.backends.xformers import _make_alibi_bias + from tests.kernels.utils import make_alibi_bias FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. @@ -42,9 +42,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( @@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention( # Create the ALiBi bias used in the paged attention kernel. 
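# Aside, a hedged standalone sketch of the ALiBi bias built just below (toy
# sizes and slopes, not part of the diff): each head adds a linear penalty of
# slope * (key_position - last_position) to the attention scores, so the
# current token gets 0 and older tokens get an increasingly negative bias.
import torch

seq_len = 6
alibi_slopes = torch.tensor([1.0, 0.5, 0.25, 0.125])      # one slope per head
position_ids = torch.arange(seq_len).int()
bias = (position_ids - seq_len + 1).float()                # [-5, -4, ..., 0]
bias = alibi_slopes.view(-1, 1, 1) * bias.view(1, 1, -1)   # (heads, 1, seq_len)
print(bias[0, 0])   # head 0: tensor([-5., -4., -3., -2., -1., 0.])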
position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) @@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"] +) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -143,13 +140,18 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if ((kv_cache_dtype == "fp8" and head_size % 16) - or (version == "rocm" and head_size not in (64, 128))): + if (kv_cache_dtype == "fp8" and head_size % 16) or ( + version == "rocm" and head_size not in (64, 128) + ): pytest.skip() - if (version == "rocm" and current_platform.is_navi() - and (kv_cache_dtype == "fp8" or head_size != 128 - or block_size != 16 or use_alibi)): + if ( + version == "rocm" + and current_platform.is_navi() + and ( + kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi + ) + ): pytest.skip() global PARTITION_SIZE @@ -177,18 +179,24 @@ def test_paged_attention( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) block_tables = torch.tensor(block_tables_lst, dtype=torch.int) # Create the KV caches. 
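As a hedged aside on the block tables built above and the paged caches created below: a block table maps a sequence's logical block index to a physical block of the cache, so gathering that sequence's keys is a div/mod walk over token positions. A standalone toy sketch, using illustrative sizes and the flash-style (blocks, tokens, heads, dim) layout seen elsewhere in this diff rather than the packed layout this legacy kernel uses:

import torch

block_size, num_blocks, num_kv_heads, head_size = 16, 32, 2, 8
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
block_table = torch.randint(0, num_blocks, (6,))   # one sequence, 6 logical blocks

def gather_keys(kv_len: int) -> torch.Tensor:
    """Collect the first kv_len keys of the sequence from the paged cache."""
    keys = []
    for pos in range(kv_len):
        block_number = int(block_table[pos // block_size])
        block_offset = pos % block_size
        keys.append(key_cache[block_number, block_offset])
    return torch.stack(keys)   # (kv_len, num_kv_heads, head_size)

print(gather_keys(40).shape)   # torch.Size([40, 2, 8])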
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -214,18 +222,37 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v1, - (output, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v1, + ( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) elif version in ("v2", "rocm"): if current_platform.is_rocm() and version == "rocm": PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -258,13 +285,34 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: ops.paged_attention_rocm( @@ -288,13 +336,30 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._rocm_C.paged_attention, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, None, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._rocm_C.paged_attention, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: raise AssertionError(f"Unknown version: {version}") @@ -303,18 +368,17 @@ def test_paged_attention( if kv_cache_dtype == "fp8": # Convert cache data back to dtype. 
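A hedged sketch of the fp8 round trip performed below when kv_cache_dtype is "fp8": ops.convert_fp8 quantizes or dequantizes depending on the destination dtype (uint8 holds the fp8 payload), and the loose tolerances used by these tests reflect the lossy storage. Illustrative shapes; this needs a CUDA build of vLLM.

import torch
from vllm import _custom_ops as ops

x = torch.randn(4, 16, 8, dtype=torch.float16, device="cuda")
scale = 0.1

quantized = torch.empty_like(x, dtype=torch.uint8)
ops.convert_fp8(quantized, x, scale, kv_dtype="fp8")         # fp16 -> fp8 bytes

restored = torch.empty_like(x, dtype=torch.float16)
ops.convert_fp8(restored, quantized, scale, kv_dtype="fp8")  # fp8 bytes -> fp16

# Lossy storage, so only loose tolerances are expected (the same atol/rtol the
# fp8 comparisons in this file use).
torch.testing.assert_close(restored, x, atol=0.001, rtol=0.1)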
x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + dequantized_key_cache = torch.empty( + size=key_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) + dequantized_value_cache = torch.empty( + size=value_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache @@ -367,8 +431,9 @@ def ref_multi_query_kv_attention( if alibi_bias: attn_mask = alibi_bias[i] else: - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1 + ) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype) @@ -390,8 +455,9 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -413,13 +479,11 @@ def test_multi_query_kv_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) + qkv = torch.empty( + num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype + ) qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) num_queries_per_kv = num_query_heads // num_kv_heads if num_queries_per_kv > 1: @@ -429,8 +493,7 @@ def test_multi_query_kv_attention( alibi_bias = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, - seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output = torch.empty_like(query) start = 0 # Dynamic sequence length not supported with custom attn_bias. @@ -442,7 +505,8 @@ def test_multi_query_kv_attention( value[None, start:end], attn_bias=attn_bias[i], p=0.0, - scale=scale) + scale=scale, + ) output[start:end].copy_(out.view_as(query[start:end])) start += seq_len # xformers.AttentionBias to Tensor for use in reference impl. @@ -485,8 +549,9 @@ def test_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." 
+) @torch.inference_mode() def test_multi_query_kv_attention_with_alibi( num_seqs: int, diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index aea166da3af2f..48a42ce6ffab5 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -15,21 +15,26 @@ from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA"], + "cuda": [ + "TRITON_MLA", + "FLASHMLA", + "FLASHINFER_MLA", + "FLASH_ATTN_MLA", + "CUTLASS_MLA", + ], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } DEVICE_REGULAR_ATTN_BACKENDS = { - "cuda": ["XFORMERS", "FLASHINFER"], - "hip": ["ROCM_FLASH"], + "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"], + "hip": ["ROCM_ATTN"], "cpu": ["TORCH_SDPA"], } @@ -37,7 +42,7 @@ DEVICE_MLA_BLOCK_SIZES = { "cuda": [16, 64], # CUDA supports both standard and extended block sizes "hip": [16, 1], # HIP requires special handling for block_size=1 # "cpu": [16] # CPU uses fixed block size from test cases - "cpu": [] # FIXME(woosuk): Temporarily disable CPU tests + "cpu": [], # FIXME(woosuk): Temporarily disable CPU tests } @@ -45,12 +50,13 @@ def generate_params(): params = [] for use_mla in [True, False]: for device in ["cuda", "hip", "cpu"]: - backends = DEVICE_MLA_BACKENDS[ - device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device] + backends = ( + DEVICE_MLA_BACKENDS[device] + if use_mla + else DEVICE_REGULAR_ATTN_BACKENDS[device] + ) for name in backends: - block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ - 16 - ] + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16] for block_size in block_sizes: params.append( pytest.param( @@ -58,236 +64,224 @@ def generate_params(): name, use_mla, block_size, - id= - f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" - )) + id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}", + ) + ) return params -@pytest.mark.parametrize("device, name, use_mla, block_size", - generate_params()) -@pytest.mark.parametrize("use_v1", [True, False]) +@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) def test_env( device: str, name: str, use_mla: bool, block_size: int, - use_v1: bool, monkeypatch: pytest.MonkeyPatch, ): """Test attention backend selection with valid device-backend pairs.""" with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") - if name == "FLASHINFER" and not use_v1: - pytest.skip("FlashInfer backend is only available on V1 engine") - if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - block_size, False) - assert backend.get_name() == "TORCH_SDPA_VLLM_V1" + with patch("vllm.platforms.current_platform", CpuPlatform()): + backend = get_attn_backend(16, torch.float16, None, block_size) + assert backend.get_name() == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform", - 
RocmPlatform()): + with patch("vllm.platforms.current_platform", RocmPlatform()): if use_mla: - # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) + # ROCm MLA backend logic: + # - TRITON_MLA: supported when block_size != 1 + # - ROCM_AITER_MLA: supported when block_size == 1 + # If backend is forced but doesn't match block_size, + # should raise ValueError - if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name - assert backend.get_name() == expected - else: + if name == "TRITON_MLA" and block_size == 1: + # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: - get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - assert f"The selected backend, {name}" in str( - exc_info.value) + get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + assert f"The selected backend, {name}" in str(exc_info.value) + elif name == "ROCM_AITER_MLA" and block_size != 1: + # ROCM_AITER_MLA only supports block_size == 1 + with pytest.raises(ValueError) as exc_info: + get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + assert f"The selected backend, {name}" in str(exc_info.value) + else: + # Valid backend-block_size combination + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = name + assert backend.get_name() == expected else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "ROCM_ATTN" assert backend.get_name() == expected elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): + with patch("vllm.platforms.current_platform", CudaPlatform()): if use_mla: - if name == "FLASHMLA" and block_size == 64: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) + # CUDA MLA backend logic: + # - CUTLASS_MLA: only supported with block_size == 128 + # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHINFER_MLA: only supported on Blackwell GPUs + # (SM 10.0+), V1 only + # - FLASHMLA: only supported with block_size == 64 + # - FLASH_ATTN_MLA: V1 only + # - TRITON_MLA: fallback for other cases - # only on cuda platforms with specific capability. - is_supported, _ = is_flashmla_supported() - - if not is_supported: - # if platform is not supported then skip this case. 
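A hedged usage sketch of the selection flow this rewritten test exercises; the import path and the printed result are assumptions based on this test file, and the actual outcome depends on the GPU, the build, and the MLA block-size rules encoded above.

import os

import torch

from vllm.attention.selector import get_attn_backend
from vllm.utils import STR_BACKEND_ENV_VAR

# Force a backend by name, then let the selector resolve it. In a long-lived
# process the selection is cached, which is why the clear_cache fixture above
# calls _cached_get_attn_backend.cache_clear() between cases.
os.environ[STR_BACKEND_ENV_VAR] = "FLASHMLA"
backend = get_attn_backend(
    16,              # head size, as used throughout this test
    torch.float16,   # model dtype
    None,            # kv cache dtype ("auto")
    64,              # block size; FlashMLA is only picked for 64 per the test
    use_mla=True,
)
print(backend.get_name())  # expected "FLASHMLA" where the dense kernel is supported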
- pytest.skip() + if name == "CUTLASS_MLA": + if block_size != 128: + # CUTLASS_MLA only supports block_size == 128 + pytest.skip("CUTLASS_MLA only supports block_size 128") else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "CUTLASS_MLA" assert backend.get_name() == expected + elif name == "FLASHINFER_MLA": + if block_size not in [32, 64]: + # FlashInfer MLA only supports block_size 32 or 64 + pytest.skip( + "FlashInfer MLA only supports block_size 32 or 64" + ) + else: + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected + elif name == "FLASHMLA": + if block_size != 64: + # FlashMLA only supports block_size == 64 + pytest.skip("FlashMLA only supports block_size 64") + else: + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_dense_supported, + ) + + is_supported, _ = is_flashmla_dense_supported() + if not is_supported: + pytest.skip("FlashMLA not supported on this platform") + else: + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASH_ATTN_MLA" + assert backend.get_name() == expected else: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = ("TRITON_MLA_VLLM_V1" - if use_v1 else "TRITON_MLA") + # TRITON_MLA or other fallback + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = "FLASHINFER_VLLM_V1" if use_v1 else name + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER" assert backend.get_name() == expected - else: - backend = get_attn_backend(32, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name + elif name == "XFORMERS": + backend = get_attn_backend( + 32, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "XFORMERS" + assert backend.get_name() == expected + elif name == "FLASH_ATTN": + backend = get_attn_backend( + 32, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASH_ATTN" assert backend.get_name() == expected - - if use_v1: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - assert backend.get_name() == "FLEX_ATTENTION", ( - "Should fallback to FlexAttention if head size is " - "not supported by FlashAttention") @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("use_v1", [True, False]) -def test_fp32_fallback( - device: str, - use_v1: bool, - monkeypatch: pytest.MonkeyPatch, -): +def test_fp32_fallback(device: str): """Test attention backend selection with fp32.""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + if device == "cpu": + with patch("vllm.platforms.current_platform", CpuPlatform()): + backend 
= get_attn_backend(16, torch.float32, None, 16) + assert backend.get_name() == "TORCH_SDPA" - if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) - assert backend.get_name() == "TORCH_SDPA_VLLM_V1" - - elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) - assert (backend.get_name() == "FLEX_ATTENTION" - if use_v1 else "XFORMERS") + elif device == "cuda": + with patch("vllm.platforms.current_platform", CudaPlatform()): + backend = get_attn_backend(16, torch.float32, None, 16) + assert backend.get_name() == "FLEX_ATTENTION" def test_flash_attn(monkeypatch: pytest.MonkeyPatch): """Test FlashAttn validation.""" - # TODO: When testing for v1, pipe in `use_v1` as an argument to - # get_attn_backend + pytest.skip( + "Skipping as current backend selector does not " + "handle fallbacks when a backend is set via env var." + ) with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, - "get_device_capability", - lambda _=None: (7, 5)) - backend = get_attn_backend(16, torch.float16, None, 16, False) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) + backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Reset the monkeypatch for subsequent tests monkeypatch.undo() # Unsupported data type - backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) + backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = get_attn_backend(16, torch.float16, "fp8", 16, False) + backend = get_attn_backend(16, torch.float16, "fp8", 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Unsupported block size - backend = get_attn_backend(16, torch.float16, None, 8, False) + backend = get_attn_backend(16, torch.float16, None, 8) assert backend.get_name() != STR_FLASH_ATTN_VAL # flash-attn is not installed import sys - original_module = sys.modules.get('vllm_flash_attn') - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) - backend = get_attn_backend(16, torch.float16, None, 16, False) + + original_module = sys.modules.get("vllm_flash_attn") + monkeypatch.setitem(sys.modules, "vllm_flash_attn", None) + backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Restore the original module if it existed if original_module is not None: - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', - original_module) + monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module) else: - monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) + monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False) # Unsupported head size - backend = get_attn_backend(17, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) + backend = get_attn_backend(17, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL -@pytest.mark.parametrize("use_v1", [True, False]) -def test_invalid_env(use_v1: bool, monkeypatch: 
pytest.MonkeyPatch): +def test_invalid_env(monkeypatch: pytest.MonkeyPatch): """Test that invalid attention backend names raise ValueError.""" - with monkeypatch.context() as m, patch( - "vllm.attention.selector.current_platform", CudaPlatform()): - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + with ( + monkeypatch.context() as m, + patch("vllm.platforms.current_platform", CudaPlatform()), + ): m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) # Should raise ValueError for invalid backend with pytest.raises(ValueError) as exc_info: - get_attn_backend(32, torch.float16, None, 16, False) - assert "Invalid attention backend: 'INVALID'" in str(exc_info.value) + get_attn_backend(32, torch.float16, None, 16) + assert "Invalid value 'INVALID'" in str(exc_info.value) diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index cbf11da63cab9..f33a27d1fd85a 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] +COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] DTYPES = [torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing @@ -32,13 +32,13 @@ NUM_BLOCKS = [1024, 10000] NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] +RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"] + @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_layers", NUM_LAYERS) @@ -83,24 +83,33 @@ def test_copy_blocks( block_mapping.append((src, dst2)) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, - num_layers, num_heads, - head_size, kv_cache_dtype, - dtype, seed, device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) - opcheck(torch.ops._C_cache_ops.copy_blocks, - (key_caches, value_caches, block_mapping_tensor), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(head_size == HEAD_SIZES[0])) + opcheck( + torch.ops._C_cache_ops.copy_blocks, + (key_caches, value_caches, block_mapping_tensor), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + cond=(head_size == HEAD_SIZES[0]), + ) ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. @@ -113,8 +122,7 @@ def test_copy_blocks( # Compare the results. 
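For clarity, a hedged standalone sketch (toy sizes) of the semantics the comparison below verifies: copy_blocks copies block src over block dst for every pair in the mapping, in every layer's key and value cache, and the custom op does the same work in a single launch.

import torch

num_layers, num_blocks, block_numel = 2, 8, 16
key_caches = [torch.randn(num_blocks, block_numel) for _ in range(num_layers)]
value_caches = [torch.randn(num_blocks, block_numel) for _ in range(num_layers)]
block_mapping = [(0, 3), (0, 5), (2, 7)]   # one src block may feed several dsts

for src, dst in block_mapping:
    for key_cache, value_cache in zip(key_caches, value_caches):
        key_cache[dst].copy_(key_cache[src])
        value_cache[dst].copy_(value_cache[src])

assert torch.equal(key_caches[1][5], key_caches[1][0])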
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): torch.testing.assert_close(key_cache, cloned_key_cache) - for value_cache, cloned_value_cache in zip(value_caches, - cloned_value_caches): + for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): torch.testing.assert_close(value_cache, cloned_value_cache) @@ -153,10 +161,17 @@ def test_reshape_and_cache( _, key, value = qkv.unbind(dim=1) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -174,12 +189,30 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. - opcheck(torch.ops._C_cache_ops.reshape_and_cache, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + opcheck( + torch.ops._C_cache_ops.reshape_and_cache, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) @@ -200,14 +233,12 @@ def test_reshape_and_cache( cloned_value_cache[block_idx, :, :, block_offset] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache, cloned_key_cache) torch.testing.assert_close(value_cache, cloned_value_cache) @@ -223,6 +254,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) +@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -236,9 +268,13 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, kv_cache_layout: str, + implementation: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) + assert implementation in ["cuda", "triton"] + if implementation == "triton" and kv_cache_layout == "HND": + pytest.skip("Triton implementation only supports NHD layout.") # fp8 conversion requires continugous memory buffer. Reduce the number of # blocks and tokens to consume less memory. @@ -247,15 +283,8 @@ def test_reshape_and_cache_flash( # Create a random slot mapping. 
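A small standalone illustration (illustrative numbers) of the slot convention used below: a slot is a flat index into the paged cache, sampling slots without replacement gives every token its own cache position, and the physical block plus in-block offset fall out of a div/mod.

import torch

block_size, num_blocks, num_tokens = 16, 4, 6
num_slots = block_size * num_blocks
slot_mapping = torch.randperm(num_slots)[:num_tokens]   # unique slot per token

block_idx = slot_mapping // block_size                  # which physical block
block_offset = slot_mapping % block_size                # position inside it
print(list(zip(block_idx.tolist(), block_offset.tolist())))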
num_slots = block_size * num_blocks slot_mapping_lst = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device) _, key, value = qkv.unbind(dim=1) # Create the KV caches. @@ -286,40 +315,73 @@ def test_reshape_and_cache_flash( # Clone the KV caches. if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache_compact, - v_scale.item(), kv_cache_dtype) + cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype + ) + cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype + ) else: cloned_key_cache = key_cache_compact.clone() cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. - opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, v_scale) + if implementation == "cuda": + opcheck( + torch.ops._C_cache_ops.reshape_and_cache_flash, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + elif implementation == "triton": + from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, + ) + + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_key_cache, - key_cache_compact, - k_scale.item(), - kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_value_cache, - value_cache_compact, - v_scale.item(), - kv_dtype=kv_cache_dtype) + result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype + ) + result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_value_cache, + value_cache_compact, + v_scale.item(), + kv_dtype=kv_cache_dtype, + ) # Run the reference implementation. 
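# Aside (hedged) on the two implementations exercised above: the Triton kernel
# is tested as a drop-in replacement for the CUDA op, taking (key, value,
# key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale) in
# the same order, so a caller could dispatch with something like:
#
#     write_kv = (ops.reshape_and_cache_flash if implementation == "cuda"
#                 else triton_reshape_and_cache_flash)
#     write_kv(key, value, key_cache, value_cache, slot_mapping,
#              kv_cache_dtype, k_scale, v_scale)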
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") @@ -337,14 +399,12 @@ def test_reshape_and_cache_flash( cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache_compact, cloned_key_cache) torch.testing.assert_close(value_cache_compact, cloned_value_cache) @@ -381,8 +441,8 @@ def test_swap_blocks( current_platform.seed_everything(seed) - src_device = device if direction[0] == "cuda" else 'cpu' - dst_device = device if direction[1] == "cuda" else 'cpu' + src_device = device if direction[0] == "cuda" else "cpu" + dst_device = device if direction[1] == "cuda" else "cpu" src_blocks = random.sample(range(num_blocks), num_mappings) # For the same device, mapping must not overlap @@ -393,42 +453,62 @@ def test_swap_blocks( dst_blocks = random.sample(range(num_blocks), num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device="cpu" + ).view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, src_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + src_device, + ) # Create the KV caches on the second device. dist_key_caches, dist_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, dst_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + dst_device, + ) src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. 
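A hedged standalone sketch (toy sizes) of what the swap_blocks call below is checked against: block src of the source cache ends up at block dst of the destination cache, and in this test the two caches may live on different devices (the cuda/cpu pairs in COPYING_DIRECTION).

import torch

num_blocks, block_numel = 8, 32
src_cache = torch.randn(num_blocks, block_numel)    # e.g. the GPU cache
dst_cache = torch.zeros(num_blocks, block_numel)    # e.g. the CPU swap space
block_mapping = [(0, 4), (2, 6)]                    # non-overlapping dst blocks

for src, dst in block_mapping:
    dst_cache[dst].copy_(src_cache[src])

assert torch.equal(dst_cache[4], src_cache[0])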
- do_opcheck = (head_size == HEAD_SIZES[0]) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), - cond=do_opcheck) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), - cond=do_opcheck) + do_opcheck = head_size == HEAD_SIZES[0] + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) - ops.swap_blocks(src_key_caches[0], dist_key_caches[0], - block_mapping_tensor) - ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping_tensor) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor) for src, dst in block_mapping: - torch.testing.assert_close(src_key_caches_clone[src].cpu(), - dist_key_caches[0][dst].cpu()) - torch.testing.assert_close(src_value_caches_clone[src].cpu(), - dist_value_caches[0][dst].cpu()) + torch.testing.assert_close( + src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu() + ) + torch.testing.assert_close( + src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu() + ) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -474,11 +554,9 @@ def _create_mla_cache( device: str, ) -> torch.Tensor: cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype - return torch.zeros(num_blocks, - block_size, - entry_size, - dtype=cache_dtype, - device=device) + return torch.zeros( + num_blocks, block_size, entry_size, dtype=cache_dtype, device=device + ) def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): @@ -518,20 +596,16 @@ def test_concat_and_cache_mla( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -543,10 +617,7 @@ def test_concat_and_cache_mla( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = ref_temp @@ -556,28 +627,135 @@ def test_concat_and_cache_mla( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) if kv_cache_dtype == "fp8": result_temp = torch.empty_like(kv_cache, dtype=torch.float16) - ops.convert_fp8(result_temp, - kv_cache.contiguous(), 
- scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8( + result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype + ) expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16) - ops.convert_fp8(expected_temp, - ref_kv_cache, - scale.item(), - kv_dtype=kv_cache_dtype) - torch.testing.assert_close(result_temp, - expected_temp, - atol=0.001, - rtol=0.1) + ops.convert_fp8( + expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype + ) + torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1) else: torch.testing.assert_close(kv_cache, ref_kv_cache) +@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) +@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) +@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_concat_and_cache_ds_mla( + kv_lora_rank: int, + qk_rope_head_dim: int, + num_tokens: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + if dtype.itemsize != 2: + pytest.skip("ds_mla only supports 16-bit input") + kv_cache_dtype = "fp8_ds_mla" + current_platform.seed_everything(seed) + torch.set_default_device(device) + + total_slots = num_blocks * block_size + slot_mapping_lst = random.sample(range(total_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) + entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) + + scale = torch.tensor(1.0, dtype=torch.float32, device=device) + kv_cache = _create_mla_cache( + num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + kv_cache_dtype=kv_cache_dtype, + device=device, + ) + + ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) + tile_data = torch.zeros(128, dtype=dtype, device=device) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + + ref_cache_slice = ref_cache[block_idx, block_offset] + ref_cache_16bit = ref_cache_slice.view(dtype) + ref_cache_32bit = ref_cache_slice.view(torch.float32) + + kv_c_data = kv_c[i] + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = (tile_idx + 1) * 128 + tile_data[:] = kv_c_data[tile_start:tile_end] + + # tile_scale = tile_data.amax().to(torch.float32) / 448. + # NOTE: Using torch's amax() gives different results, + # so this must be manually computed. 
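# Aside, a hedged reading of the reference construction around here: with the
# "fp8_ds_mla" cache dtype each token's kv_c vector is quantized per
# 128-value tile with scale = max(|tile|) / 448.0 (448 is the largest finite
# float8_e4m3 value), while k_pe stays in the 16-bit input dtype. For the
# defaults in this file (kv_lora_rank=512, qk_rope_head_dim=64) one cache
# entry is laid out in bytes as:
#   [0, 512)    four fp8 tiles holding kv_c
#   [512, 528)  four float32 per-tile scales (4 * 4 bytes)
#   [528, 656)  k_pe as 64 16-bit values
# matching entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim).
assert 512 + (4 * 4) + (2 * 64) == 656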
+ tile_data_float = tile_data.to(torch.float32) + manual_max = abs(tile_data_float[0]) + for j in range(1, 128): + manual_max = max(manual_max, abs(tile_data_float[j])) + tile_scale = manual_max / 448.0 + + ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale + + ops.convert_fp8( + ref_cache_slice[tile_start:tile_end], + tile_data, + tile_scale.item(), + kv_dtype="fp8", + ) + + for j in range(qk_rope_head_dim): + ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] + + opcheck( + torch.ops._C_cache_ops.concat_and_cache_mla, + (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + kv_cache_slice = kv_cache[block_idx, block_offset] + ref_cache_slice = ref_cache[block_idx, block_offset] + + kv_nope = kv_cache_slice[:kv_lora_rank] + ref_nope = ref_cache_slice[:kv_lora_rank] + kv_scales = kv_cache_slice.view(torch.float32)[ + kv_lora_rank // 4 : kv_lora_rank // 4 + 4 + ] + ref_scales = ref_cache_slice.view(torch.float32)[ + kv_lora_rank // 4 : kv_lora_rank // 4 + 4 + ] + kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :] + ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :] + + torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) + + @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) @@ -606,8 +784,9 @@ def test_copy_blocks_mla( kv_caches = [] for _ in range(num_layers): - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) kv_caches.append(kv_cache) @@ -624,9 +803,9 @@ def test_copy_blocks_mla( dst2 = dst_blocks[2 * i + 1] block_mapping.append((src, dst1)) block_mapping.append((src, dst2)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) for src, dst in block_mapping: for ref_cache in ref_caches: @@ -667,10 +846,12 @@ def test_swap_blocks_mla( entry_size = kv_lora_rank + qk_rope_head_dim - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) - dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) + dst_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype) @@ -682,9 +863,9 @@ def test_swap_blocks_mla( remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remaining_blocks, num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, 
dtype=torch.int64, device="cpu" + ).view(-1, 2) opcheck( torch.ops._C_cache_ops.swap_blocks, @@ -699,7 +880,8 @@ def test_swap_blocks_mla( src_cache_clone[src].cpu(), dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " - f"{dst} in dst_cache.") + f"{dst} in dst_cache.", + ) @pytest.mark.parametrize("kv_lora_rank", [512]) @@ -712,32 +894,36 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, - block_size, num_blocks, - max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_and_maybe_dequant_cache_mla( + kv_lora_rank, + qk_rope_head_dim, + block_size, + num_blocks, + max_seq_len, + batch_size, + dtype, + kv_cache_dtype, + device, +): entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) - seq_len_tensor = torch.randint(0, - max_seq_len + 1, (batch_size, ), - device=device) + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device) total_tokens = seq_len_tensor.sum() - cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=device) + cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) print("seq_len_tensor", seq_len_tensor) tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size - block_table = torch.empty((batch_size, num_blocks), - dtype=torch.int32, - device=device) + block_table = torch.empty( + (batch_size, num_blocks), dtype=torch.int32, device=device + ) for b in range(batch_size): perm = torch.randperm(num_blocks, device=device) @@ -765,10 +951,8 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, remaining = s - (tot - 1) * block_size last_block_data = src_cache[blocks[-1], :remaining, :] if kv_cache_dtype == "fp8": - dequantized_last_block = torch.empty_like(last_block_data, - dtype=dtype) - ops.convert_fp8(dequantized_last_block, last_block_data, - scale.item()) + dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, scale.item()) gathered_rows.append(dequantized_last_block) else: gathered_rows.append(last_block_data) @@ -779,14 +963,105 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, opcheck( torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, - scale, None), + ( + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + None, + ), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, kv_cache_dtype, - scale, None) + ops.gather_and_maybe_dequant_cache( + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + None, + ) + torch.testing.assert_close(dst, expected) + + +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) 
+@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize( + "kv_cache_dtype", ["auto"] +) # You can also test "fp8" if needed. +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_cp_gather_cache_mla( + kv_lora_rank, + qk_rope_head_dim, + block_size, + num_blocks, + max_seq_len, + batch_size, + dtype, + kv_cache_dtype, + device, +): + entry_size = kv_lora_rank + qk_rope_head_dim + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty( + (batch_size, num_blocks), dtype=torch.int32, device=device + ) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + gathered_rows.append(src_cache[blocks[i]]) + remaining = s - (tot - 1) * block_size + gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.cp_gather_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) torch.testing.assert_close(dst, expected) @@ -816,20 +1091,16 @@ def test_concat_and_cache_mla_cpu( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -841,10 +1112,7 @@ def test_concat_and_cache_mla_cpu( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = 
ref_temp @@ -854,6 +1122,5 @@ def test_concat_and_cache_mla_cpu( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) torch.testing.assert_close(kv_cache, ref_kv_cache) diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py index 1e7e7e0a7f84b..58e8bd592ba43 100755 --- a/tests/kernels/attention/test_cascade_flash_attn.py +++ b/tests/kernels/attention/test_cascade_flash_attn.py @@ -7,11 +7,12 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import (cascade_attention, - merge_attn_states) -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) +from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 192, 256] @@ -37,21 +38,14 @@ def test_merge_kernel( assert num_query_heads % num_kv_heads == 0 # Prepare inputs. - prefix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) - suffix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) + prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) + suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) # Run the kernel. output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype) - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) # Reference implementation. 
max_lse = torch.maximum(prefix_lse, suffix_lse) @@ -97,8 +91,10 @@ def test_cascade( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) current_platform.seed_everything(0) @@ -107,11 +103,9 @@ def test_cascade( num_query_heads = num_heads[0] num_kv_heads = num_heads[1] assert num_query_heads % num_kv_heads == 0 - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) seq_lens, common_prefix_len = seq_lens_and_common_prefix @@ -122,26 +116,21 @@ def test_cascade( max_kv_len = max(kv_lens) total_num_query_tokens = sum(query_lens) - query = torch.randn(total_num_query_tokens, - num_query_heads, - head_size, - dtype=dtype) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) assert common_prefix_len > 0 assert common_prefix_len % block_size == 0 num_common_kv_blocks = common_prefix_len // block_size # Make sure the first `num_common_kv_blocks` blocks are the same. - block_tables[:, :num_common_kv_blocks] = \ - block_tables[0, :num_common_kv_blocks] + block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks] # Run the regular attention. ref_output = flash_attn_varlen_func( @@ -161,8 +150,7 @@ def test_cascade( # Run cascade attention. 
assert all(common_prefix_len < kv_len for kv_len in kv_lens) - cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], - dtype=torch.int32) + cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32) suffix_kv_lens = kv_lens_tensor - common_prefix_len output = torch.empty_like(query) diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py new file mode 100644 index 0000000000000..dad1510ce532b --- /dev/null +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import random +from typing import Optional + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.triton_utils import triton + + +def cal_diff( + x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False, + diff_threshold: Optional[float] = None, +) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + if diff_threshold is not None: + # directly compare the cos_diff with the threshold + assert cos_diff < diff_threshold + else: + # use the default threshold + if use_fp8: + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 + + +CUTLASS_MLA_UNSUPPORTED_REASON = ( + "Cutlass MLA Requires compute capability of 10 or above." + if not current_platform.is_device_capability(100) + else "Cutlass MLA is supported" +) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(100), + reason=CUTLASS_MLA_UNSUPPORTED_REASON, +) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize( + "torch_dtype", + [ + torch.bfloat16, + # fp8 can have occasional precision-related failures. 
+ pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)), + ], +) +@torch.inference_mode() +def test_cutlass_mla_decode( + b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype +): + device = torch.device("cuda:0") + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype + torch.set_default_dtype(init_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(42) + random.seed(42) + + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" + ) + + use_fp8 = torch_dtype == torch.float8_e4m3fn + scale = math.sqrt(d) ** (-1) + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + blocked_v = blocked_k[..., :dv] + + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + + def cutlass_mla(): + MAX_HEADS = 128 + + q_reshaped = q.squeeze(1) + q_nope = q_reshaped[:, :, :dv].clone() + q_pe = q_reshaped[:, :, dv:].clone() + + if h_q < MAX_HEADS: + q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv)) + q_nope_padded[:, :h_q] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv)) + q_pe_padded[:, :h_q] = q_pe + q_pe = q_pe_padded + + kv_cache_flat = blocked_k.squeeze(2) + device_properties = torch.cuda.get_device_properties(torch.device("cuda:0")) + sm_count = device_properties.multi_processor_count + workspace_size = ops.sm100_cutlass_mla_get_workspace_size( + max_seqlen * block_size, b, sm_count, num_kv_splits=1 + ) + workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) + + out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) + output_lse = torch.empty( + (b, MAX_HEADS), dtype=torch.float32, device=q_nope.device + ) + ops.sm100_cutlass_mla_decode( + out_ans, + output_lse, + q_nope, + q_pe, + kv_cache_flat, + cache_seqlens, + block_table, + workspace, + scale, + 1, + ) + return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + q_ = (q.to(torch.float) * 
descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = ( + (blocked_k.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_k + ) + blocked_v_ = ( + (blocked_v.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_v + ) + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i + return out, lse + + out_cutlass, lse_cutlass = cutlass_mla() + out_torch, lse_torch = ref_mla() + # Extract the single token (s_q=1) slice to match cutlass output shape + out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + lse_torch_slice = lse_torch[:, 0, :] # [b, h_q] + cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + # lse has larger numerical error, so use a larger threshold + cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3) + + t = triton.testing.do_bench(cutlass_mla) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * ( + torch.finfo(torch_dtype).bits // 8 + ) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print( + f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s" + ) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py new file mode 100644 index 0000000000000..f4b4fac840151 --- /dev/null +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.utils import cdiv, has_deep_gemm +from vllm.utils.deep_gemm import ( + _ceil_to_ue8m0, + calc_diff, + fp8_mqa_logits, + fp8_paged_mqa_logits, + get_num_sms, + get_paged_mqa_logits_metadata, +) + + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + # x: (num_blocks, block_size, 1, head_dim) + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty( + (num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(dtype=torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + dtype=torch.uint8 + ) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def per_custom_dims_cast_to_fp8( + x: torch.Tensor, dims: tuple, use_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _generate_cp_test_data(seq_len: int, seq_len_kv: int): + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // 
seq_len + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.zeros(seq_len, dtype=torch.int, device="cuda") + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + +def _ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + seq_len_kv = kv.shape[0] + + k = kv + q = q.float() + k = k.float() + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) +def test_deepgemm_fp8_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + num_heads, head_dim = 32, 128 + for seq_len in (512,): + for seq_len_kv in (1024,): + for disable_cp in (False, True): + q = torch.randn( + seq_len, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + kv = torch.randn( + seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 + ) + weights = torch.randn( + seq_len, num_heads, device="cuda", dtype=torch.float32 + ) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + ( + seq_len_kv - seq_len + ) + else: + ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) + logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + ref_logits = _ref_fp8_mqa_logits( + q=q, + kv=kv, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke, + ) + + ref_neginf_mask = ref_logits == float("-inf") + neginf_mask = logits == float("-inf") + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + +def _ref_fp8_paged_mqa_logits( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +): + batch_size, next_n, _, _ = q.size() + _, block_size, _, _ = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens_list = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens_list[i] + q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() + ) + for block_rk in range(cdiv(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange( + block_rk * block_size, + (block_rk + 1) * block_size, + device="cuda", + ) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) + s = torch.where( + mask[None, :, :], + (qx.transpose(0, 1) 
@ kx.transpose(0, 1).transpose(1, 2)).to( + logits.dtype + ), + float("-inf"), + ) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[ + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) +def test_deepgemm_fp8_paged_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + for batch_size, next_n in [(4, 1), (2, 2)]: + for heads, index_dim in [(32, 128)]: + for avg_kv in (2048,): + num_blocks, blocksize = max_model_len * 2, 64 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), + device="cuda", + dtype=torch.float32, + ) + + context_lens = ( + torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,)) + .cuda() + .to(torch.int32) + ) + max_block_len = ( + (context_lens.max().item() + blocksize - 1) // blocksize * blocksize + ) + block_tables = torch.zeros( + (batch_size, max_block_len), + device="cuda", + dtype=torch.int32, + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + logits = fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + ) + + ref_logits = _ref_fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + ) + + positions = ( + torch.arange(max_model_len, device="cuda") + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n + next_n_offset = ( + torch.arange(batch_size * next_n, device="cuda") % next_n + ) + mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) + + logits = logits.masked_fill(~mask, 0) + ref_logits = ref_logits.masked_fill(~mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" diff --git a/tests/kernels/attention/test_encoder_decoder_attn.py b/tests/kernels/attention/test_encoder_decoder_attn.py deleted file mode 100644 index a2e6986460904..0000000000000 --- a/tests/kernels/attention/test_encoder_decoder_attn.py +++ /dev/null @@ -1,1105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests: - -* E2E test of Encoder attention + Decoder self-attention + - Encoder/decoder cross-attention (collectively - "encoder/decoder attention") - -""" - -from typing import NamedTuple, Optional - -import pytest -import torch - -from tests.kernels.utils import * -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.backends.utils import 
STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.forward_context import set_forward_context -from vllm.platforms import current_platform - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Encoder-decoder is only supported on V0, so set - VLLM_USE_V1=0 for all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -# List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] -HEAD_SIZES = [64, 256] - -NUM_HEADS = [1, 16] - -BATCH_SIZES = [1, 16] -BLOCK_SIZES = [16] -CUDA_DEVICE = "cuda:0" - -MAX_DEC_SEQ_LENS = [128] -MAX_ENC_SEQ_LENS = [128] - -# Narrow test-cases for unsupported-scenario -# tests -HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] - - -class TestPoint(NamedTuple): - """ - Encapsulates the attributes which define a single invocation - of the test_e2e_enc_dec_attn() test - - Attributes: - num_heads: The number of heads in the model. - head_size: Head dimension - backend_name: Name of the backend framework used. - batch_size: Number of samples per batch. - block_size: Size of each block of data processed. - max_dec_seq_len: Maximum sequence length for the decoder. - max_enc_seq_len: Maximum sequence length for the encoder. - num_blocks: Number of blocks in the model. - """ - - num_heads: int - head_size: int - backend_name: str - batch_size: int - block_size: int - max_dec_seq_len: int - max_enc_seq_len: int - num_blocks: int - attn_type: AttentionType - - -class TestResources(NamedTuple): - ''' - Encapsulates key components for performing an - encoder/decoder attention test - - Note that - (1) attn automatically selects an attention backend - based on platform info & a set of canned - heuristics - (2) attn_backend is thus *not the same backend - instance* used by attn, but rather it is - intended to be a - *different instance* of the *same backend class*; - it is assumed that the user of TestResources - will leverage attn_backend for the purpose of - constructing backend-compatible attention - metadata instances - - Attributes: - - * scale: 1/sqrt(d) scale factor for attn - * attn_backend: implementations of abstraction - attention interface using - a particular kernel library - i.e. XFormers - * attn: Attention layer instance - * kv_cache: shared key/value cache for all attention - ''' - - scale: float - attn: Attention - kv_cache: torch.Tensor - - -def _make_test_resources(test_pt: TestPoint, ) -> TestResources: - ''' - Build key components for performing encoder/decoder attention test. - - Note that - (1) The Attention instance constructed here, automatically selects - an attention backend class based on platform info & a set of canned - heuristics, so - (2) The attention backend instance constructed here is thus *not - the same backend instance* used by attn, but rather it is - intended to be a *different instance* of the *same backend class*; - therefore, - (3) This function requires that test_pt.backend_name matches the backend - class that Attention will automatically select when it is constructed. - - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: num_heads, head_size, num_blocks, - block_size, backend_name - - Returns: - - * TestResources data structure. 
- ''' - - scale = float(1.0 / (test_pt.head_size**0.5)) - attn = Attention( - test_pt.num_heads, - test_pt.head_size, - scale=scale, - prefix=f"{test_pt.attn_type}", - attn_type=test_pt.attn_type, - ) - if test_pt.num_blocks is None or test_pt.num_heads is None: - # Caller does not require a KV cache - return TestResources( - scale, attn, - torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) - - # Construct KV cache - if test_pt.attn_type in (AttentionType.DECODER, - AttentionType.ENCODER_DECODER): - kv_cache = make_kv_cache(test_pt.num_blocks, - test_pt.num_heads, - test_pt.head_size, - test_pt.block_size, - device=CUDA_DEVICE, - backend=test_pt.backend_name) - else: - kv_cache = torch.tensor([]) - - attn.kv_cache = [kv_cache] - return TestResources(scale, attn, kv_cache) - - -def _encoder_attn_setup( - test_pt: TestPoint, - test_rsrcs: TestResources, -) -> PhaseTestParameters: - ''' - Set up test vectors & data structures for encoder attention test. - - A triplet of synthetic query/key/value tensors are constructed. - Given this is an encoder attention test, the key & value - sequences will have the same length as the corresponding queries. - - The query/key/value tensors are passed to an ideal reference - self-attention implementation to generate an ideal output tensor. - - Encoder inference does not populate the KV cache, therefore - no KV cache memory mapping is constructed - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - - - Returns: - - * PhaseTestParameters data structure comprising (1) packed query/key/value - tensors, (2) the ideal output of attention computed using a naive - implementation, and (3) KVCache field set to None - ''' - - ( - num_heads, - head_size, - _, - batch_size, - _, - _, - max_q_seq_len, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - max_kv_seq_len = max_q_seq_len - - # Make test tensors - - qkv_in, _, _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # Compute correct answer using naive non-causal attention - # implementation - - ideal_output = ref_masked_attention(qkv_in.query, - qkv_in.key, - qkv_in.value, - scale=scale, - q_seq_lens=qkv_in.q_seq_lens, - kv_seq_lens=qkv_in.kv_seq_lens) - - packed_ideal_output, _ = pack_tensor(ideal_output, - qkv_in.q_seq_lens, - device=CUDA_DEVICE) - - packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - - return PhaseTestParameters( - PackedQKVO(packed_qkv, packed_ideal_output), - None # No KV cache - ) - - -def _decoder_attn_setup( - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int = 0, -) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: - ''' - Set up test vectors & data structures for self-attention test. - - A triplet of synthetic query/key/value tensors are constructed ("baseline" - query/key/value). Given this is a self-attention test, the key & value - sequences will have the same length as the corresponding queries. - - "Prefill" query/key/value tensors are derived by masking out the last value - in each baseline query/key/value. These tensors are used to test prefill & - populate KV cache for a subsequent decode test. - - "Decode" query/key/value tensors are derived by extracting *only* the last - value from each baseline query/key/value (i.e. 
complement of the prefill - tensors.) These tensors are used to test decode, conditional on the kv cache - being populated during the prefill test. - - The baseline query/key/value tensors are passed to an ideal reference - self-attention implementation to generate a "Baseline" ideal output tensor. - This tensor is split into the "Prefill" ideal output tensor (all but the - last element of each output sequence) and the "Decode" ideal output tensor - (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode - test results, respectively. - - This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - * block_base_addr: decoder self-attention block-table base address - - Returns: - * qkv: Unpacked (batch_size x padded_seq_len x num_heads x - head_size) query/key/value tensors - * Prefill-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) - query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data - structures appropriate for prefill phase. - * Decode-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) - query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data - structures appropriate for decode phase. 
- * max_block_idx: max physical address in decoder self-attention block-table - (intended to be used as the base address for the encoder/ - decoder cross-attention block-table, which is not - constructed in this function) - ''' - - ( - num_heads, - head_size, - _, - batch_size, - block_size, - max_q_seq_len, - _, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - max_kv_seq_len = max_q_seq_len - - # Build test tensors - - ( - qkv, - prefill_qkv, - decode_qkv, - ) = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) - - # Compute correct answer using naive attention implementation - # with causal attention mask - - causal_mask = make_causal_mask(max_q_seq_len, - max_kv_seq_len).to(CUDA_DEVICE) - - ideal_output = ref_masked_attention(qkv.query, - qkv.key, - qkv.value, - scale=scale, - custom_mask=causal_mask, - q_seq_lens=qkv.q_seq_lens, - kv_seq_lens=qkv.kv_seq_lens) - - # Split out the prefill- & decode-phase ideal answers & pack them - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_qkv.q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) - - # Build prefill- & decode-phase data structures - # for decoder self-attention. Block tables and - # slot mapping must be in a format compatible - # with KV caching & attention kernels - # - # Prefill-phase: - # - # * Empty block-tables tensor - # * Slot-mapping with entries for prompt tokens - # - # Decode-phase: - # * Block-tables tensor with minimum number of blocks - # required by total num. 
tokens in the entirety of all sequences - # (including both prefill & decode) - # * Slot-mapping with entries for tokens that will be decoded in the - # current decode iteration - # - # Note: the format described above is simply mirroring what ModelRunner - # produces - - prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - - ( - decode_block_tables, - slot_mapping_list, - max_block_idx, - ) = make_block_tables_slot_mapping(block_size, - qkv.q_seq_lens, - device=CUDA_DEVICE, - block_base_addr=block_base_addr) - - ( - prefill_slot_mapping, - decode_slot_mapping, - ) = split_slot_mapping(slot_mapping_list, - qkv.q_seq_lens, - device=CUDA_DEVICE) - - prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) - - decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) - - return ( - qkv, - PhaseTestParameters( # Prefill test params - PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), - PhaseTestParameters( # Decode test params - PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping)), - max_block_idx) - - -def _enc_dec_cross_attn_setup_reuses_query( - decoder_qkv: QKVInputs, - encoder_test_params: PhaseTestParameters, - prefill_decoder_phase_test_params: PhaseTestParameters, - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int = 0, -) -> tuple[PhaseTestParameters, PhaseTestParameters]: - ''' - Set up test vectors & data structures for cross-attention test. - - A triplet of synthetic cross-attention key/value tensors are constructed - ("baseline" key/value). Given this is a cross-attention test, we assume - query tensors were already synthesized for a prior self-attention test and - will be reused for cross-attention. The key & value sequences generated here - may have a different length than the corresponding queries (as is often - the case for cross-attention between decoder and encoder sequences.) - - Cross attention key & value tensors do not grow during autoregressive - inference; thus this function obtains a single key/value pair suitable for - both prefill and decode. - - The "baseline" query tensor is received as an argument. The "baseline" - query/key/value tensors are passed to an ideal reference cross-attention - implementation to generate a "baseline" ideal output tensor. This tensor is - split into the "Prefill" ideal output tensor (all but the last element of - each output sequence) and the "Decode" ideal output tensor (*only* the last - element of each output sequence); the "Prefill" and "Decode" ideal output - tensors can be used to validate the prefill and decode test results, - respectively. - - This function also constructs the cross-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr. 
- - Arguments: - - * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x - num_heads x head_size) decoder self-attention inputs; - this function relies on the query and q_seq_lens - fields - * encoder_test_params: PhaseTestParameters data structure which was - used for encoder inference; KV cache field - is not used by this function - * prefill_decoder_phase_test_params: PhaseTestParameters data structure - used for prefill-phase decoder - self-attention; all fields - including KV cache required - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - * block_base_addr: decoder self-attention block-table base address - - Returns: - - * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data - structure, including (1) packed - (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a - naive implementation, and (3) memory-mapping data structures appropriate - for prefill phase. - * Decode-phase encoder/decoder cross-attention PhaseTestParameters data - structure, including (1) packed - (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a - naive implementation, and (3) memory-mapping data structures appropriate - for decode phase. - ''' - - assert encoder_test_params.packed_qkvo.packed_qkv is not None - assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None - - ( - num_heads, - head_size, - _, - batch_size, - block_size, - max_decoder_seq_len, - max_encoder_seq_len, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - decoder_query = decoder_qkv.query - decoder_seq_lens = decoder_qkv.q_seq_lens - encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = ( - prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) - - assert prefill_q_seq_lens is not None - - ( - cross_kv, - _, - _, - ) = make_qkv(batch_size, - max_decoder_seq_len, - max_encoder_seq_len, - num_heads, - head_size, - force_kv_seq_lens=encoder_seq_lens, - attn_type=AttentionType.ENCODER_DECODER, - device=CUDA_DEVICE) - - ideal_output = ref_masked_attention(decoder_query, - cross_kv.key, - cross_kv.value, - scale=scale, - q_seq_lens=decoder_seq_lens, - kv_seq_lens=cross_kv.kv_seq_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) - - # Build prefill- & decode-phase data structures - # for encoder/decoder cross-attention. 
Block tables and - # slot mapping must be in a format compatible - # with KV caching & attention kernels - # - # Whereas decoder self-attention extracts relationships between - # equal-length Q/K/V sequences, which mutually grow in length - # with each decoded token, cross-attention relates the Q sequence - # - which grows with each new decoded token - to fixed-length - # K and V sequences derived from the encoder hidden states. - # - # Prefill-phase: - # - # * Empty block-tables tensor - # * Slot-mapping with as many entries as there are tokens in the encoder - # prompt. - # - # Decode-phase: - # * Block-tables tensor with minimum number of blocks to - # accommodate K & V tensors which are equal in lnegth - # to the encoder prompt length - # * Empty slot-mapping tensor (since K & V are fixed in size, - # new decoded tokens are not KV-cached and require no slot- - # mapping) - # - # Note: the format above is simply an extension of what ModelRunner - # produces for decoder-only models - - prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) - - ( - decode_block_tables, - prefill_slot_mapping_list, - _, - ) = make_block_tables_slot_mapping(block_size, - cross_kv.kv_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) - - prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, - device=CUDA_DEVICE) - - # Packed key/value (query is already provided) - packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - - return ( - PhaseTestParameters( # Prefill-phase test params - PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), - PhaseTestParameters( # Decode-phase test params - PackedQKVO(None, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping))) - - -def _run_encoder_attention_test( - attn: Attention, - encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run encoder attention. - - attn.forward() is passed attn_type=AttentionType.ENCODER in order - to configure the kernel invocation for encoder attention - - Requires attn_metadata.num_decode_tokens == 0 - (There is no encoder execution in the decode-phase) - - Arguments: - - * attn: Attention wrapper instance - * encoder_test_params: encoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query/key/value fields - * attn_metadata: attention metadata for encoder/decoder-self attention - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. - - Returns: - * Attention.forward() applied to packed {query,key,value} and - & attn_metadata - ''' - assert attn_metadata.num_decode_tokens == 0 - packed_qkv = encoder_test_params.packed_qkvo.packed_qkv - assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expect the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
- reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) - - -def _run_decoder_self_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run decoder self-attention test. - - attn.forward() is passed attn_type=AttentionType.DECODER - in order to configure the kernel invocation for decoder self-attention. - - Arguments: - - * test_rsrcs: TestResources instance; this function relies on the kv_cache - and attn (Attention wrapper instance) fields - * decoder_test_params: decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query/key/value fields - * attn_metadata: attention metadata for decoder-self attention - (contains KV cache memory-mapping) - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. - - Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache - & attn_metadata - ''' - attn = test_rsrcs.attn - packed_qkv = decoder_test_params.packed_qkvo.packed_qkv - assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expect the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape. - reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) - - -def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - cross_test_params: Optional[PhaseTestParameters], - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run encoder/decoder cross-attention test. - - Via PhaseTestParameters data structures, consumes the same query utilized - for decoder self-attention, plus a key/value specific to cross-attention. - - if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv - is None, this reflects that in decode-phase cross attention there - is no growth in the key and value tensors. - - attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER - in order to configure the kernel invocation for encoder/decoder cross- - attention. - - Arguments: - - * test_rsrcs: TestResources instance; this function relies on the kv_cache - and attn (Attention wrapper instance) fields - * decoder_test_params: decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query field - * cross_test_params: encoder/decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - key/value fields - * attn_metadata: attention metadata for encoder/decoder-self attention - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. 
- - Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache - & attn_metadata - ''' - assert decoder_test_params.packed_qkvo.packed_qkv is not None - - attn = test_rsrcs.attn - if cross_test_params is None: - key = None - value = None - else: - cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) - value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expect the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape. - reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, key, value) - - -@pytest.fixture(autouse=True) -def set_reset_environment(attn_backend): - # Set the default torch datatype to bfloat16 to enable - # testing of the Flash Attention backend. Also clear the - # cached value of the backend. - default_dtype = torch.get_default_dtype() - if attn_backend.name == 'FLASH_ATTN': - torch.set_default_dtype(torch.bfloat16) - _cached_get_attn_backend.cache_clear() - yield - # Reset the torch datatype to what it was before the test - # so as not to impact the remaining tests. - torch.set_default_dtype(default_dtype) - - -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_encoder_only( - num_heads: int, - head_size: int, - attn_backend: _Backend, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, -): - ''' - End-to-end encoder-only attention test: - - * Construct fake test vectors for (1) encoder attention - * Construct (1) attention metadata structure with prefill-phase - encoder attention, and (2) an analogous attention metadata - structure but for decode-phase - * Test & validate encoder attention against ideal output - - No KV cache is required for encoder-only attention. - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if - current_platform.is_rocm(). - - This test globally forces an override of the usual backend - auto-selection process, forcing the specific backend-under-test - to be utilized. 
- - Arguments: - - * num_heads - * head_size, - * attn_backend: The attention backend to employ for testing - * batch_size - * block_size: KV cache block size - * max_dec_seq_len: max length of decoder input sequences - * max_enc_seq_len: max length of encoder input sequences - ''' - # Force Attention wrapper backend - with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - test_rsrcs = _make_test_resources(test_pt) - - # Construct encoder attention test params (only used - # during prefill) - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - None, - decoder_test_params=None, - encoder_test_params=enc_test_params, - cross_test_params=None, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata, - test_pt=test_pt, - vllm_config=vllm_config)) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) - - -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, - head_size: int, - attn_backend: _Backend, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, -) -> None: - ''' - End-to-end encoder/decoder test: - - * Construct fake test vectors for (1) encoder attention, - (2) decoder self-attention, and (3) encoder/decoder cross-attention - * Construct (1) attention metadata structure with self- and cross-attention - attributes for prefill-phase, and (2) an analogous attention metadata - structure but for decode-phase - * Test attention steps in the following order - - * Encoder attention - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * Besides being reflective of realistic use-cases, this order would - exacerbate any accidental overlap in the self-/cross-attention - block tables, which one hopes to avoid - - - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. 
- - This test globally forces an override of the usual backend - auto-selection process, forcing the specific backend-under-test - to be utilized. - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if - current_platform.is_rocm(). - - Note on metadata: there is a single attention metadata structure shared by - all prefill-phase attention operations (encoder, decoder, enc/dec cross), - and a single one shared by all decode-phase attention operations - (decoder & enc/dec cross.) This is intended to reflect the behavior - of EncoderDecoderModelRunner, which constructs a single attention metadata - structure for each prefill or decode run. A realistic scenario would rely - on the attention backend to utilize the appropriate attention metadata - fields according to the value of attn_metadata.attention_type. Thus, - this test is organized so as to confirm that the backend-under-test can - handle a shared prefill attention metadata structure & a shared decode\ - attention metadata structure. - - Arguments: - - * num_heads - * head_size, - * attn_backend: The attention backend to employ for testing - * batch_size - * block_size: KV cache block size - * max_dec_seq_len: max length of decoder input sequences - * max_enc_seq_len: max length of encoder input sequences - ''' - # Force Attention wrapper backend - with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) - enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, - AttentionType.ENCODER_DECODER) - dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.DECODER) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - enc_test_rsrcs = _make_test_resources(enc_test_pt) - enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt) - dec_test_rsrcs = _make_test_resources(dec_test_pt) - - # Construct encoder attention test params (only used - # during prefill) - - enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs) - - # Construct Decoder self-attention prefill-phase & decode-phase - # test params, including query/key/value tensors, decoder self-attention - # memory-mapping. cross_block_base_addr is the uppermost address in the - # decoder self-attention block-table, i.e. a base address which the - # encoder/decoder cross-attention block-table may build downward toward. 
- - ( - dec_qkv, - prephase_dec_test_params, - decphase_dec_test_params, - cross_block_base_addr, - ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs) - - # Construct encoder/decoder cross-attention prefill-phase - # & decode-phase test params, including key/value tensors, - # cross-attention memory-mapping - - ( - prephase_cross_test_params, - decphase_cross_test_params, - ) = _enc_dec_cross_attn_setup_reuses_query( - dec_qkv, - enc_test_params, - prephase_dec_test_params, - enc_dec_test_pt, - enc_dec_test_rsrcs, - block_base_addr=cross_block_base_addr) - - # Shared prefill metadata structure - assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata, - test_pt=enc_test_pt, - vllm_config=vllm_config) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) - - # PREFILL: decoder self-attention test - - prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - dec_test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - test_pt=dec_test_pt, - vllm_config=vllm_config) - - # - Is prefill decoder self-attention correct? - assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out, - attn_backend.name) - - # PREFILL: encoder/decoder cross-attention test - - prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - enc_dec_test_rsrcs, - prephase_dec_test_params, - prephase_cross_test_params, - prephase_attn_metadata, - test_pt=enc_dec_test_pt, - vllm_config=vllm_config) - - # - Is prefill encoder/decoder cross-attention correct? - assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out, - attn_backend.name) - - # DECODE: build decode-phase attention metadata - - decphase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - dec_qkv.q_seq_lens, - decoder_test_params=decphase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=decphase_cross_test_params, - device=CUDA_DEVICE) - - # DECODE: decoder self-attention test - - decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - dec_test_rsrcs, - decphase_dec_test_params, - decphase_attn_metadata, - test_pt=dec_test_pt, - vllm_config=vllm_config) - - # - Is decode-phase decoder self-attention correct? - assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out, - attn_backend.name) - - # DECODE: encoder/decoder cross-attention test - - decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - enc_dec_test_rsrcs, - decphase_dec_test_params, - None, - decphase_attn_metadata, - test_pt=enc_dec_test_pt, - vllm_config=vllm_config) - - # - Is decode-phase encoder/decoder cross-attention correct? 
- assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out, - attn_backend.name) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index 2544703f8bf91..d39f0a593ed41 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -7,10 +7,12 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - flash_attn_with_kvcache, - is_fa_version_supported) +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + flash_attn_with_kvcache, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] @@ -44,7 +46,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -62,10 +64,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -106,11 +111,15 @@ def test_flash_attn_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(kv_lens) @@ -119,23 +128,19 @@ def test_flash_attn_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) q = query.unsqueeze(1) out = torch.empty_like(q) if use_out else None @@ -180,23 +185,27 @@ def test_flash_attn_with_paged_kv( if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - ref_output = 
ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -222,11 +231,15 @@ def test_varlen_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -236,30 +249,23 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) out = torch.empty_like(query) if use_out else None @@ -315,5 +321,7 @@ def test_varlen_with_paged_kv( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - 
ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index a821a74aba93d..52cd10fdc5be0 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -38,7 +38,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -56,10 +56,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -101,20 +104,16 @@ def test_flashinfer_decode_with_paged_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -135,9 +134,9 @@ def test_flashinfer_decode_with_paged_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=True) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=True + ) wrapper.plan( kv_indptr, kv_indices, @@ -155,17 +154,21 @@ def test_flashinfer_decode_with_paged_kv( output = wrapper.run(query, key_value_cache) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) @@ -196,16 +199,10 @@ def test_flashinfer_prefill_with_paged_kv( max_kv_len = max(kv_lens) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - 
dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) @@ -215,10 +212,9 @@ def test_flashinfer_prefill_with_paged_kv( value_cache /= head_size**0.5 max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -242,8 +238,7 @@ def test_flashinfer_prefill_with_paged_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -264,17 +259,21 @@ def test_flashinfer_prefill_with_paged_kv( key_value_cache, ) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) @@ -284,9 +283,13 @@ def test_flashinfer_prefill_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", SOFT_CAPS) def test_flashinfer_prefill_with_paged_fp8_kv( - seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], - head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: pytest.skip("TODO: fix the accuracy issue") torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -301,17 +304,11 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_cache_dtype = torch.float8_e4m3fn - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -319,15 +316,15 @@ def test_flashinfer_prefill_with_paged_fp8_kv( k_scale = key_cache.amax().item() / 448.0 v_scale = value_cache.amax().item() / 448.0 - kv_cache_fp8 = 
torch.cat([key_cache / k_scale, value_cache / v_scale], - dim=1).to(kv_cache_dtype) + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to( + kv_cache_dtype + ) - assert (kv_cache_fp8.shape == key_value_cache.shape) + assert kv_cache_fp8.shape == key_value_cache.shape max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -351,8 +348,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -369,19 +365,23 @@ def test_flashinfer_prefill_with_paged_fp8_kv( output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache.squeeze(1), - value_cache=value_cache.squeeze(1), - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) del query del block_tables # verify prefill fp8 - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @@ -414,12 +414,9 @@ def test_flashinfer_decode_with_paged_fp8_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -429,14 +426,13 @@ def test_flashinfer_decode_with_paged_fp8_kv( key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) - assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1 kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -457,32 +453,38 @@ def test_flashinfer_decode_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - 
use_tensor_cores=use_tensor_cores) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py new file mode 100644 index 0000000000000..0350136677c6b --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_mla_decode.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla +from torch import Tensor + +from vllm.platforms import current_platform + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="FlashInfer MLA Requires compute capability of 10 or above.", + allow_module_level=True, + ) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("bs", [1, 2, 4, 16]) +@pytest.mark.parametrize("block_size", [32, 64]) +def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): + torch.set_default_device("cuda") + torch.manual_seed(42) + + # Deepseek R1 config + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + qk_head_dim = kv_lora_rank + qk_rope_head_dim + scale = 
(qk_nope_head_dim + qk_rope_head_dim) ** -0.5 + + MAX_SEQ_LEN = 1024 + + seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1,)).item() for _ in range(bs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + # Generate block tables with random but unique block IDs + # From https://github.com/flashinfer-ai/flashinfer/pull/1222 + blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size + max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4) + total_blocks_needed = sum(blocks_per_seq) + # Get random unique IDs for all blocks + all_block_ids = torch.randperm(total_blocks_needed) + + block_id = 0 + block_tables = torch.zeros( + (bs, max_num_blocks_per_seq), + dtype=torch.int32, + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(bs): + num_blocks_needed = blocks_per_seq[i] + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] + block_id += num_blocks_needed + + kv_cache = torch.randn(block_tables.numel(), block_size, qk_head_dim).to(dtype) + q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) + + out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) + ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor) + + workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=q.device, + ) + # Flashinfer MLA expects the query to be of shape + # (bs, q_len_per_request, num_heads, qk_head_dim), + # where q_len_per_request is the MTP query length (=1 without MTP) + q = q.unsqueeze(1) + + out_ans = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale, + ) + out_ans = out_ans.squeeze(1) + torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 8d0a11d8eb8ab..61157429ec9cc 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,15 +6,17 @@ import flashinfer import pytest import torch -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + dequantize_nvfp4_to_dtype, + get_nvfp4_global_scale, +) from vllm.platforms import current_platform from vllm.utils import round_up if not current_platform.is_device_capability(100): - pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", - allow_module_level=True) + pytest.skip( + "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True + ) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = current_platform.fp8_dtype() @@ -35,6 +37,7 @@ QUANT_DTYPES = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] @@ -44,7 +47,9 @@ NUM_HEADS = [(64, 8), (40, 8)] HEAD_SIZE = [128] KV_LAYOUT = ["HND"] # currently only HND is supported BLOCK_SIZE = [16] +WINDOW_LEFT = [-1, 127] SOFT_CAP = [None, 50.0] +HAS_SINKS = [True, False] NUM_BLOCKS = 32768 # Large enough to test 
overflow in index calculation. @@ -57,22 +62,27 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", SOFT_CAP) +@pytest.mark.parametrize("has_sinks", HAS_SINKS) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], - Optional[torch.dtype]], + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], + has_sinks: bool, ) -> None: torch.set_default_device("cuda") - current_platform.seed_everything(0) + current_platform.seed_everything(42) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype = q_quant_dtype or dtype @@ -94,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + # max_q_len = 1 + q_lens = torch.ones((batch_size,), dtype=torch.int32) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ] + ) + + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: query, q_scale = to_float8(query) ref_query = query.to(dtype) * q_scale @@ -102,10 +121,10 @@ def test_flashinfer_trtllm_decode_with_baseline( q_scale = 1.0 ref_query = query - kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len - seq_lens = kv_lens + seq_lens = kv_lens + q_lens max_seq_len = torch.max(seq_lens).item() kv_cache = torch.randn(kv_cache_shape, dtype=dtype) @@ -118,10 +137,9 @@ def test_flashinfer_trtllm_decode_with_baseline( k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (batch_size, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] @@ -142,39 +160,55 @@ def test_flashinfer_trtllm_decode_with_baseline( workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) # Baseline Decode - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) + if has_sinks: + sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5 + wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + else: + sinks = None + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + wrapper.plan( + qo_indptr=q_indptr, + paged_kv_indptr=kv_indptr, + paged_kv_indices=kv_indices, + 
paged_kv_last_page_len=kv_last_page_lens, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_size, + page_size=block_size, + causal=True, + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=soft_cap, + q_data_type=dtype, + kv_data_type=dtype, + ) output = torch.empty(ref_query.shape, dtype=dtype) - wrapper.run(ref_query, ref_kv_cache, out=output) + wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output) + o_scale = 1.0 - o_sf_scale = None + o_sf_scale_float = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(output.flatten(), dim=-1)).to(torch.float32) + o_sf_scale = get_nvfp4_global_scale(output) + o_sf_scale_float = o_sf_scale.item() # TRTLLM Decode if o_quant_dtype == FP4_DTYPE: output_trtllm = flashinfer.utils.FP4Tensor( - torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), - dtype=torch.uint8), - torch.empty((round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // 16, 4)), - dtype=torch.float8_e4m3fn), + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), ) else: output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) @@ -188,30 +222,35 @@ def test_flashinfer_trtllm_decode_with_baseline( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, - o_sf_scale=o_sf_scale, + window_left=window_left, + sinks=sinks, + o_sf_scale=o_sf_scale_float, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale elif o_quant_dtype == FP4_DTYPE: output_trtllm.data = output_trtllm.data.reshape( - -1, query.shape[1] * query.shape[2] // 2) - output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, - output_trtllm.scale, - o_sf_scale, dtype, - query.device) - output_trtllm = output_trtllm.reshape(-1, query.shape[1], - query.shape[2]) + -1, query.shape[1] * query.shape[2] // 2 + ) + output_trtllm = dequantize_nvfp4_to_dtype( + output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device + ) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 3e-1, 1e0 + rtol, atol = 7e-2, 9e-2 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - rtol, atol = 5e-2, 7e-2 - else: + rtol, atol = 2e-2, 4e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: rtol, atol = 1e-2, 2e-2 + else: + rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - output_trtllm))}" + ( + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - output_trtllm))}", + ) @pytest.mark.parametrize("dtype", DTYPE) @@ -222,22 +261,27 @@ def test_flashinfer_trtllm_decode_with_baseline( @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("has_sinks", HAS_SINKS) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], - 
Optional[torch.dtype]], + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], + has_sinks: bool, ) -> None: torch.set_default_device("cuda") - current_platform.seed_everything(0) + current_platform.seed_everything(42) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype = q_quant_dtype or dtype @@ -262,17 +306,16 @@ def test_flashinfer_trtllm_prefill_with_baseline( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) q_lens[-1] = max_q_len - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ] + ) - query = torch.randn(torch.sum(q_lens).item(), - num_qo_heads, - head_size, - dtype=dtype) + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: query, q_scale = to_float8(query) ref_query = query.to(dtype) * q_scale @@ -280,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( q_scale = 1.0 ref_query = query - kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len seq_lens = kv_lens + q_lens @@ -296,10 +339,9 @@ def test_flashinfer_trtllm_prefill_with_baseline( k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (batch_size, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] @@ -320,40 +362,55 @@ def test_flashinfer_trtllm_prefill_with_baseline( workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) # Baseline Prefill - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout) - wrapper.plan(q_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, - causal=True, - sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) + if has_sinks: + sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5 + wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + else: + sinks = None + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + wrapper.plan( + qo_indptr=q_indptr, + paged_kv_indptr=kv_indptr, + paged_kv_indices=kv_indices, + paged_kv_last_page_len=kv_last_page_lens, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_size, + page_size=block_size, + causal=True, + sm_scale=sm_scale, + window_left=window_left, + logits_soft_cap=soft_cap, + q_data_type=dtype, + kv_data_type=dtype, + ) output = torch.empty(ref_query.shape, dtype=dtype) - wrapper.run(ref_query, ref_kv_cache, out=output) + wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output) + o_scale = 
1.0 - o_sf_scale = None + o_sf_scale_float = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(output.flatten(), dim=-1)).to(torch.float32) + o_sf_scale = get_nvfp4_global_scale(output) + o_sf_scale_float = o_sf_scale.item() # TRTLLM Prefill if o_quant_dtype == FP4_DTYPE: output_trtllm = flashinfer.utils.FP4Tensor( - torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), - dtype=torch.uint8), - torch.empty((round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // 16, 4)), - dtype=torch.float8_e4m3fn), + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), ) else: output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) @@ -371,27 +428,32 @@ def test_flashinfer_trtllm_prefill_with_baseline( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, - o_sf_scale=o_sf_scale, + window_left=window_left, + sinks=sinks, + o_sf_scale=o_sf_scale_float, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale elif o_quant_dtype == FP4_DTYPE: output_trtllm.data = output_trtllm.data.reshape( - -1, query.shape[1] * query.shape[2] // 2) - output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, - output_trtllm.scale, - o_sf_scale, dtype, - query.device) - output_trtllm = output_trtllm.reshape(-1, query.shape[1], - query.shape[2]) + -1, query.shape[1] * query.shape[2] // 2 + ) + output_trtllm = dequantize_nvfp4_to_dtype( + output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device + ) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 4e-1, 1e0 + rtol, atol = 1e-1, 2e-1 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - rtol, atol = 5e-2, 7e-2 + rtol, atol = 4e-2, 6e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: + rtol, atol = 2e-2, 3e-2 else: rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - output_trtllm))}" + ( + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - output_trtllm))}", + ) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index abcfe828d5aca..2151933a610d8 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -7,30 +7,35 @@ import random import pytest import torch -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_dense_supported, +) from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, - y: torch.Tensor, - name: str, - use_fp8: bool = False) -> None: +def cal_diff( + x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False +) -> None: x, y = x.double(), y.double() - cos_diff = 1 - 2 * (x * y).sum().item() / max( - (x * x + y * y).sum().item(), 1e-12) - if (use_fp8): + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + if use_fp8: assert cos_diff < 1e-4 else: assert cos_diff < 1e-5 -FLASH_MLA_UNSUPPORTED_REASON = 
is_flashmla_supported()[1] \ - if not is_flashmla_supported()[0] else "FlashMLA is supported" + +FLASH_MLA_UNSUPPORTED_REASON = ( + is_flashmla_dense_supported()[1] + if not is_flashmla_dense_supported()[0] + else "FlashMLA is supported" +) -@pytest.mark.skipif(not is_flashmla_supported()[0], - reason=FLASH_MLA_UNSUPPORTED_REASON) +@pytest.mark.skipif( + not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON +) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @@ -41,47 +46,49 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("torch_dtype", - [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn] +) @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen, torch_dtype): +def test_flash_mla( + b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype +): device = torch.device("cuda:0") - if torch_dtype == torch.float8_e4m3fn: - init_dtype = torch.bfloat16 - else: - init_dtype = torch_dtype + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" + ) use_fp8 = torch_dtype == torch.float8_e4m3fn - cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), - s_q) + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) total_seqlens = cache_seqlens.sum().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 q = torch.randn(b, s_q, h_q, d) - block_table = torch.arange(b * max_seqlen_pad // block_size, - dtype=torch.int32).view( - b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, - d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + float("nan") + ) blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv) + cache_seqlens, s_q * h_q // h_kv, h_kv + ) init_dtype = q.dtype if use_fp8: @@ -121,8 +128,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, - dtype=torch.bool).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -132,10 +138,16 @@ def test_flash_mla(b, 
s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, def ref_mla(): q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q - blocked_k_ = (blocked_k.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_k - blocked_v_ = (blocked_v.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_v + blocked_k_ = ( + (blocked_k.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_k + ) + blocked_v_ = ( + (blocked_v.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_v + ) out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): @@ -158,8 +170,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + - b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( - b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", - f"{bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * ( + torch.finfo(torch_dtype).bits // 8 + ) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print( + f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s" + ) diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py new file mode 100644 index 0000000000000..7ee6f4b07b4a9 --- /dev/null +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + + +def test_sparse_flashmla_metadata_smoke(): + import vllm.attention.ops.flashmla as fm + + ok, reason = fm.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 128 + num_heads_k = 1 + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + topk = 128 + + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + + tile_md, num_splits = fm.get_mla_metadata( + cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True, + ) + assert tile_md.dtype == torch.int32 + assert num_splits.dtype == torch.int32 + + +def test_sparse_flashmla_decode_smoke(): + import vllm.attention.ops.flashmla as fm + + ok, reason = fm.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 1 + head_dim_k = 576 + head_dim_v = 512 + num_heads_k = 1 + page_block_size = 64 + bytes_per_token = 656 + topk = 128 + + # Metadata + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + # q_heads_per_hk = num_heads_q // num_heads_k + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + tile_md, num_splits = fm.get_mla_metadata( + cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True, + ) + + # Inputs + q = torch.zeros( + (batch_size, seqlen_q, num_heads_q, head_dim_k), + dtype=torch.bfloat16, + device=device, + ) + k_cache = torch.zeros( + (1, page_block_size, num_heads_k, bytes_per_token), + dtype=torch.uint8, + device=device, + ) + indices = torch.zeros( + (batch_size, seqlen_q, topk), dtype=torch.int32, device=device + ) + + block_table = torch.zeros((batch_size, 128), 
dtype=torch.int32, device=device) + out, lse = fm.flash_mla_with_kvcache( + q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_md, + num_splits, + indices=indices, + is_fp8_kvcache=True, + ) + assert out.shape[0] == batch_size + assert out.shape[-1] == head_dim_v + assert lse.shape[0] == batch_size + + +def test_sparse_flashmla_prefill_smoke(): + import vllm.attention.ops.flashmla as fm + + ok, reason = fm.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) + + device = torch.device("cuda") + s_q = 1 + s_kv = 1 + h_q = 64 # kernel expects multiple of 64 + h_kv = 1 + d_qk = 576 + d_v = 512 + topk = 128 + + q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device) + kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device) + indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device) + + out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v) + assert out.shape == (s_q, h_q, d_v) + assert max_logits.shape == (s_q, h_q) + assert lse.shape == (s_q, h_q) diff --git a/tests/kernels/attention/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py index de45ee1ed5cca..ec938caff2c6d 100644 --- a/tests/kernels/attention/test_lightning_attn.py +++ b/tests/kernels/attention/test_lightning_attn.py @@ -4,8 +4,7 @@ import pytest import torch -from vllm.model_executor.layers.lightning_attn import ( - linear_decode_forward_triton) +from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton from vllm.platforms import current_platform NUM_HEADS = [4, 8] @@ -17,8 +16,8 @@ DTYPES = [torch.float32] def reference_lightning_attention(q, k, v, ed, block_size, kv_history): """Reference implementation of lightning attention core algorithm - - The difference from the main implementation is that this processes + + The difference from the main implementation is that this processes each step sequentially, instead of using parallelized triton kernels """ B, H, S, D = q.shape @@ -34,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): # More efficient implementation # Convert decay factors to matrix form - if ed.dim() == 1: - decay = torch.exp(-ed).view(1, -1, 1, 1) - else: - decay = torch.exp(-ed) + decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed) for b in range(B): for step in range(S): @@ -62,8 +58,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): # The actual implementation returns a tensor of shape [B, H, 2, D, E] # where dimension 2 contains both KV and KV history kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E] - final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], - dim=2) # [B, H, 2, D, E] + final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E] return output, final_kv_cache @@ -109,7 +104,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx): out_h = torch.matmul(q_bh, kv_new) # Update output and cache - output[b, h * D:(h + 1) * D] = out_h + output[b, h * D : (h + 1) * D] = out_h kv_caches[b, h] = kv_new return output @@ -135,12 +130,9 @@ def test_linear_decode_forward_triton( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" 
+ ) kv_caches_copy = kv_caches.clone() @@ -150,15 +142,14 @@ def test_linear_decode_forward_triton( slot_idx = torch.arange(batch_size, device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) - torch.testing.assert_close(triton_output, - reference_output, - rtol=1e-1, - atol=1e-1) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) + torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1) torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -184,12 +175,9 @@ def test_linear_decode_forward_triton_with_padding( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_caches_copy = kv_caches.clone() @@ -199,14 +187,15 @@ def test_linear_decode_forward_triton_with_padding( slot_idx = torch.tensor([0, 1, -1, 2], device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) - padding_mask = (slot_idx - != -1).unsqueeze(1).expand(-1, num_heads * head_size) + padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size) triton_masked = triton_output[padding_mask] reference_masked = reference_output[padding_mask] @@ -217,15 +206,11 @@ def test_linear_decode_forward_triton_with_padding( for i in range(batch_size): if valid_indices[i] > 0: - torch.testing.assert_close(kv_caches[i], - kv_caches_copy[i], - rtol=rtol, - atol=atol) + torch.testing.assert_close( + kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol + ) - torch.testing.assert_close(triton_masked, - reference_masked, - rtol=rtol, - atol=atol) + torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -249,39 +234,33 @@ def test_lightning_attention_reference( current_platform.seed_everything(42) base = 0.01 - q = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - k = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - v = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) + q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) ed = torch.zeros(num_heads, device="cuda") for h in range(num_heads): ed[h] = 0.1 * (h + 1) - kv_history = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_history = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, 
device="cuda" + ) kv_history_clone = kv_history.clone() ref_output, ref_kv_cache = reference_lightning_attention( - q, k, v, ed, 256, kv_history) + q, k, v, ed, 256, kv_history + ) from vllm.model_executor.layers.lightning_attn import lightning_attention + actual_output, actual_kv_cache = lightning_attention( - q, k, v, ed, 256, kv_history_clone) + q, k, v, ed, 256, kv_history_clone + ) atol, rtol = 1.5e-1, 1.5e-1 torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol) - torch.testing.assert_close(ref_kv_cache, - actual_kv_cache, - rtol=rtol, - atol=atol) + torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol) assert ref_output.shape == (batch_size, num_heads, seq_len, head_size) assert ref_kv_cache.shape == actual_kv_cache.shape diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index 9d1a301ebe304..eb9204dfaf158 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -7,19 +7,20 @@ import torch from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states as merge_attn_states_triton) + merge_attn_states as merge_attn_states_triton, +) from vllm.platforms import current_platform # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # can be used to combine partial attention results (in the split-KV case) def merge_attn_states_torch( - output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] + output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] ): p_lse = prefix_lse s_lse = suffix_lse @@ -32,15 +33,13 @@ def merge_attn_states_torch( s_lse = s_lse - max_lse p_lse_exp = torch.exp(p_lse) s_lse_exp = torch.exp(s_lse) - out_se = (p_lse_exp + s_lse_exp) + out_se = p_lse_exp + s_lse_exp if output_lse is not None: output_lse = torch.log(out_se) + max_lse p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] - p_scale = torch.transpose(p_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - s_scale = torch.transpose(s_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] output = prefix_output * p_scale + suffix_output * s_scale return output, output_lse @@ -55,8 +54,10 @@ all_case_info: list[tuple] = [] def generate_markdown_table(): global all_case_info - table_header = ("| tokens | heads | headsize | dtype " - "| device | torch | triton | cuda | speedup |") + table_header = ( + "| tokens | heads | headsize | dtype " + "| device | torch | triton | cuda | speedup |" + ) table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |" def 
shortly_dtype(dtype: torch.dtype) -> str: @@ -68,16 +69,26 @@ def generate_markdown_table(): print(table_header) print(table_separator) for info in all_case_info: - (num_tokens, num_heads, head_size, dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved) = info + ( + num_tokens, + num_heads, + head_size, + dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) = info dtype = shortly_dtype(dtype) device = shortly_device(device) - print(f"| {num_tokens} | {num_heads} | {head_size} " - f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " - f"| {avg_time_triton_kernel:.5f}ms " - f"| {avg_time_cuda_kernel:.5f}ms " - f"| {performance_improved:.4f}x |") + print( + f"| {num_tokens} | {num_heads} | {head_size} " + f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " + f"| {avg_time_triton_kernel:.5f}ms " + f"| {avg_time_cuda_kernel:.5f}ms " + f"| {performance_improved:.4f}x |" + ) @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @@ -85,29 +96,28 @@ def generate_markdown_table(): @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("output_dtype", DTYPES) @torch.inference_mode() -def test_merge_attn_states(num_tokens: int, num_query_heads: int, - head_size: int, output_dtype: torch.dtype): +def test_merge_attn_states( + num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype +): if not current_platform.is_cuda(): - pytest.skip('Currently only support compare triton merge_attn_states ' - 'with custom cuda merge_attn_states kernel') + pytest.skip( + "Currently only support compare triton merge_attn_states " + "with custom cuda merge_attn_states kernel" + ) NUM_TOKENS = num_tokens NUM_HEADS = num_query_heads HEAD_SIZE = head_size - print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " - f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " - f"Device: {current_platform.get_device_name()}") + print( + f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " + f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"Device: {current_platform.get_device_name()}" + ) # prefix_lse and suffix_lse contain inf and normal values - prefix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") - suffix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") + prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") # Generate boolean masks mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 @@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) mask_suffix = torch.logical_and(mask_suffix, ~combined_mask) - prefix_lse[mask_prefix] = float('inf') - suffix_lse[mask_suffix] = float('inf') + prefix_lse[mask_prefix] = float("inf") + suffix_lse[mask_suffix] = float("inf") # Other input tensors (need to be initialized but # no actual calculation needed) - output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS), - dtype=torch.float32, - device="cuda") - prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") + output = torch.zeros( + (NUM_TOKENS, 
NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + output_lse = torch.zeros( + (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" + ) + prefix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + suffix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) warmup_times = 2 repeat_times = 20 @@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, suffix_lse_torch = suffix_lse.clone() for _ in range(warmup_times): output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) end.record() torch.cuda.synchronize() total_time_torch_kernel += start.elapsed_time(end) @@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, end = torch.cuda.Event(enable_timing=True) for _ in range(warmup_times): - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) end.record() torch.cuda.synchronize() total_time_triton_kernel += start.elapsed_time(end) @@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, output_lse_cuda = output_lse.clone() for _ in range(warmup_times): - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) end.record() torch.cuda.synchronize() total_time_cuda_kernel += start.elapsed_time(end) @@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel print(f" Torch time: {avg_time_torch_kernel:.6f}ms") print(f"Triton time: {avg_time_triton_kernel:.6f}ms") - print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " - f"Performance: {performance_improved:.5f}x") + print( + f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " + f"Performance: {performance_improved:.5f}x" + ) print("-" * 100) # 4. Correctness compare @@ -232,35 +276,45 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, # states operation. 
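# Annotation (not part of the diff): the identity all three merge kernels are
# checked against here, i.e. section 2.2 of the paper cited above. Given
# per-head log-sum-exps lse_p, lse_s and partial outputs o_p, o_s:
#   m   = max(lse_p, lse_s)
#   w_p = exp(lse_p - m);  w_s = exp(lse_s - m)
#   out = (w_p * o_p + w_s * o_s) / (w_p + w_s)
#   lse = log(w_p + w_s) + m
# which is what merge_attn_states_torch above computes before the
# transpose/unsqueeze needed to broadcast over [NUM_TOKENS, NUM_HEADS, 1].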
output_ref = output_ref_triton output_lse_ref = output_lse_ref_triton - torch.testing.assert_close(output_cuda.float(), - output_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol + ) print("Output all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}") print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}") print("-" * 100) - torch.testing.assert_close(output_lse_cuda.float(), - output_lse_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + ) print("Output LSE all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}") print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}") print("-" * 100) - print("All output values test passed! All inf values " - "are correctly replaced with -inf.") + print( + "All output values test passed! All inf values " + "are correctly replaced with -inf." + ) print("-" * 100) device = current_platform.get_device_name() all_case_info.append( - (NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved)) - if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * - len(NUM_QUERY_HEADS) * len(DTYPES)): + ( + NUM_TOKENS, + NUM_HEADS, + HEAD_SIZE, + output_dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) + ) + if len(all_case_info) == ( + len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES) + ): generate_markdown_table() diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 53c37554b15a3..14d1618bca3c5 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -5,13 +5,15 @@ Test: * Tests for MultiHeadAttention layer """ + from unittest.mock import patch import pytest import torch +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import MultiHeadAttention -from vllm.attention.selector import _Backend, _cached_get_attn_backend +from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cuda import CudaPlatform @@ -20,9 +22,12 @@ from vllm.platforms.rocm import RocmPlatform @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. 
- """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() + # Clear xformers availability cache + import vllm.attention.layer as layer_module + + layer_module.USE_XFORMERS_OPS = None @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) @@ -33,22 +38,66 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": - with patch("vllm.attention.selector.current_platform", CpuPlatform()): + with ( + patch("vllm.attention.layer.current_platform", CpuPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA elif device == "hip": - with patch("vllm.attention.selector.current_platform", RocmPlatform()): + with ( + patch("vllm.attention.layer.current_platform", RocmPlatform()), + patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA else: - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=64 (divisible by 32) + # - should use vLLM's FlashAttention + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == _Backend.FLASH_ATTN - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA not available + # - should use xformers + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + patch( + "vllm.attention.layer.check_upstream_fa_availability", + return_value=False, + ), + ): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA available + # - should use upstream FA + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + patch( + "vllm.attention.layer.check_upstream_fa_availability", return_value=True + ), + patch.dict( + "sys.modules", + { + "flash_attn": type( + "MockFlashAttn", + (), + {"flash_attn_varlen_func": lambda *args, **kwargs: None}, + )() + }, + ), + ): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + def ref_attention( query: torch.Tensor, @@ -74,9 +123,11 @@ NUM_HEADS = [1, 16] NUM_KV_HEADS = [1] HEAD_SIZES = [64, 80] # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = ( + [torch.half, torch.bfloat16, torch.float] + if not current_platform.is_rocm() + else [torch.half, torch.bfloat16] +) CUDA_DEVICES = ["cuda"] @@ -104,10 +155,9 @@ def test_mha_attn_forward( k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads) + attn = MultiHeadAttention( + num_heads, head_size, scale=scale, 
num_kv_heads=num_kv_heads + ) output = attn(q, k, v) assert num_heads % num_kv_heads == 0 diff --git a/tests/kernels/attention/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py index f8b307c595dea..44f3e42e8714a 100644 --- a/tests/kernels/attention/test_mla_decode_cpu.py +++ b/tests/kernels/attention/test_mla_decode_cpu.py @@ -11,30 +11,24 @@ from vllm.utils import cdiv def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) ): bs, num_heads, v_head_dim = out.shape head_dim = query.shape[2] for i in range(bs): # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) v = kv[:, :, :v_head_dim] q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) out[i] = o.view(num_heads, v_head_dim) return out @@ -63,18 +57,17 @@ def test_mla_decode_cpu( torch.set_default_dtype(dtype) torch.manual_seed(0) - scale = d**(-0.5) + scale = d ** (-0.5) if varlen: seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) seq_lens = seq_lens.clip(2).to(torch.int32) else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32) max_seq_len = seq_lens.max().item() seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary? 
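# Annotation (not part of the diff): `cdiv` from vllm.utils is, as far as we
# can tell, ceiling division, so seqlen_pad rounds max_seq_len up to a
# multiple of 256, e.g. cdiv(300, 256) * 256 == 512. A minimal equivalent,
# stated as an assumption rather than the actual implementation:
#     def cdiv(a: int, b: int) -> int:
#         return -(a // -b)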
q = torch.randn(bs, h_q, d) - block_table = torch.arange(bs * seqlen_pad // block_size, - dtype=torch.int32) + block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32) block_table = block_table.view(bs, seqlen_pad // block_size) kv_cache = torch.randn(block_table.numel(), block_size, d) @@ -82,8 +75,7 @@ def test_mla_decode_cpu( kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan") out_mla = q.new_zeros(bs, h_q, dv) - ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, - seq_lens) + ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens) out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py new file mode 100644 index 0000000000000..d2aa14738d9d9 --- /dev/null +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -0,0 +1,234 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.testing import assert_close + +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton + + +def test_pack_seq_basic_fp8(): + """Test basic functionality of pack_seq_triton with fp8 and 3D tensors.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors (N, H, D) + test_cases = [ + (6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4) + (10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8) + (20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32) + ] + + for N, H, D, B, lengths_list in test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + + # Pack the data + packed = pack_seq_triton(x, lengths) + + # Check output shape and properties + expected_shape = (B, max(lengths_list), H, D) + assert packed.shape == expected_shape + assert packed.dtype == dtype + assert packed.device == x.device + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = sum(lengths_list[:b]) + seq_len = lengths_list[b] + + expected_data = x[start_idx : start_idx + seq_len].to(torch.float32) + actual_data = packed[b, :seq_len].to(torch.float32) + + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_seq_custom_padding_fp8(): + """Test pack_seq_triton with custom padding values for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test with different padding values + for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]: + result = pack_seq_triton(x, lengths, pad_value=pad_value) + + # Check valid data + for b in range(B): + start_idx = b * 10 + expected_data = x[start_idx : start_idx + 10].to(torch.float32) + actual_data = result[b, :10].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + # Check padding (fp8 has limited range, so check for large values) + padded_data = result[:, 10:].to(torch.float32) + if pad_value < 0: + assert torch.all(padded_data < -50) # Large negative values + elif pad_value > 0: + assert torch.all(padded_data > 50) # Large positive values + else: + assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) + 
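# Annotation (not part of the diff): a plain-PyTorch sketch of what these
# tests expect pack_seq_triton to produce -- concatenated (N, H, D) sequences
# packed into a padded (B, max_len, H, D) batch. pack_seq_reference is a
# hypothetical helper for illustration only, written for regular float dtypes
# (the tests above and below upcast fp8 to float32 before comparing).
def pack_seq_reference(x, lengths, pad_value=float("-inf")):
    # x: (N, H, D) with sequences concatenated along dim 0; lengths sums to N.
    B, max_len = lengths.numel(), int(lengths.max())
    out = x.new_full((B, max_len, *x.shape[1:]), pad_value)
    start = 0
    for b, n in enumerate(lengths.tolist()):
        out[b, :n] = x[start:start + n]  # copy the valid tokens
        start += n
    return out  # padded positions keep pad_value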
+ +def test_pack_seq_default_negative_inf_padding_fp8(): + """Test that pack_seq_triton uses -inf padding by default for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + # B = 2 + N, H, D = 20, 8, 16 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + result = pack_seq_triton(x, lengths) + + # Check that padding is large negative values (fp8 representation of -inf) + padded_data = result[:, 10:].to(torch.float32) + assert torch.all( + padded_data < -100 + ) # fp8 -inf is represented as large negative number + + +def test_pack_seq_edge_cases_fp8(): + """Test pack_seq_triton with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (1, 10, 8, 16) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 1, 4, 8) + + # Test with different sequence lengths + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 7, 8, 16) + + +def test_pack_seq_different_block_sizes_fp8(): + """Test pack_seq_triton with different block sizes for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 100, 16, 32, 4 + lengths = torch.tensor([25, 25, 25, 25], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test different block sizes + for block_t, block_d in [(32, 32), (64, 64), (128, 128)]: + result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d) + + assert result.shape == (B, 25, H, D) + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = b * 25 + expected_data = x[start_idx : start_idx + 25].to(torch.float32) + actual_data = result[b, :25].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_seq_shape_consistency(): + """Test that pack_seq_triton maintains shape consistency.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + result = pack_seq_triton(x, lengths) + + # Check shape consistency + assert result.shape[0] == B # Batch dimension + assert result.shape[1] == lengths.max().item() # Max sequence length + assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved + + +def test_pack_unpack_roundtrip_fp8(): + """Test that pack -> unpack gives us back the original data for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors + test_cases = [ + (6, 8, 4, 2, [3, 3]), + (10, 4, 8, 3, [2, 4, 4]), + (20, 16, 32, 4, [5, 5, 5, 5]), + (15, 8, 16, 3, [7, 5, 3]), + ] + + for N, H, D, B, lengths_list in test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + 
+ # Pack the data + packed = pack_seq_triton(x, lengths) + + # Unpack the data + unpacked = unpack_seq_triton(packed, lengths) + + # Check that we get back the original data (within fp8 precision) + assert unpacked.shape == x.shape + x_f32 = x.to(torch.float32) + unpacked_f32 = unpacked.to(torch.float32) + assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3) + + # Unpack without explicit start locations (computed in kernel) + unpacked_with_loc = unpack_seq_triton(packed, lengths) + assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2) + + +def test_unpack_seq_triton_edge_cases_fp8(): + """Test unpack function with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + # Only compare the first 3 elements that were actually packed + assert_close( + x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2 + ) + + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 8544eab3acccd..5ff2624cd7a49 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -11,9 +11,8 @@ import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -from vllm.attention.backends.xformers import _make_alibi_bias -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from tests.kernels.utils import make_alibi_bias +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -22,9 +21,7 @@ NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] HEAD_SIZES = [24, 128] DTYPES = [torch.float16] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] SLIDING_WINDOW = [0, 16, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] @@ -50,12 +47,10 @@ def test_contexted_kv_attention( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) 
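# Annotation (not part of the diff): compute capability 8.9 (Ada) is the
# first CUDA architecture with native FP8 support, which is why Triton's
# fp8e4nv type -- and hence the fp8 kv-cache variants of this test -- is
# gated on has_device_capability(89) above.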
current_platform.seed_everything(0) torch.set_default_device(device) @@ -93,38 +88,29 @@ def test_contexted_kv_attention( cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -135,61 +121,71 @@ def test_contexted_kv_attention( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - 
b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) @@ -201,22 +197,24 @@ def test_contexted_kv_attention( # heads. # # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + query = query.view( + query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1] + ) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) + query_lens, seq_lens + ) if sliding_window > 0: - attn_bias = attn_bias.make_local_attention_from_bottomright( - sliding_window) + attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, @@ -239,7 +237,7 @@ def test_contexted_kv_attention( ) torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") output_ref = output_ref.reshape(output.shape) atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -262,12 +260,10 @@ def test_contexted_kv_attention_alibi( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) current_platform.seed_everything(0) torch.set_default_device(device) @@ -280,9 +276,9 @@ def test_contexted_kv_attention_alibi( def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 
+ closest_power_of_2, dtype=torch.int32) @@ -290,17 +286,16 @@ def test_contexted_kv_attention_alibi( if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes alibi_slopes = _get_alibi_slopes(num_heads).to(device) @@ -328,38 +323,29 @@ def test_contexted_kv_attention_alibi( cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -370,82 +356,90 @@ def test_contexted_kv_attention_alibi( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, 
head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) # NOTE(DefTruth): In order to reuse _make_alibi_bias function, # we have to pad query tensor before MQA/GQA expanding. if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) + query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) query_pad.uniform_(-1e-3, 1e-3) seq_start = 0 query_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) + query_pad[seq_start:seq_end, ...] = torch.cat( + [ + torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...], + ], + dim=0, + ) seq_start += seq_len query_start += query_len query = query_pad @@ -456,11 +450,12 @@ def test_contexted_kv_attention_alibi( # heads. # # see also: vllm/model_executor/layers/attention.py - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) # [seq, num_kv_heads, num_queries_per_kv, dk]=> # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the # codebase. We save some time reshaping alibi matrix at runtime. 
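# Annotation (not part of the diff): a worked example of the slope formula in
# _get_alibi_slopes earlier in this test. For total_num_heads = 8 (a power of
# two):
#   closest_power_of_2 = 8
#   base   = 2 ** (-(2 ** -(log2(8) - 3))) = 2 ** -1 = 0.5
#   slopes = 0.5 ** [1, 2, ..., 8] = [2^-1, 2^-2, ..., 2^-8]
# For head counts that are not a power of two, the extra_base/extra_powers
# branch appends the odd-indexed slopes of the sequence one would get for
# 2 * closest_power_of_2.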
@@ -470,7 +465,7 @@ def test_contexted_kv_attention_alibi( key = key.unsqueeze(0) value = value.unsqueeze(0) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) seq_start = 0 query_start = 0 @@ -479,28 +474,27 @@ def test_contexted_kv_attention_alibi( # FIXME(DefTruth): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 + # modified from: vllm/v1/attention/backends/xformers.py#L343 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward(query[:, - seq_start:seq_end], - key[:, - seq_start:seq_end], - value[:, - seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale) + out = xops.memory_efficient_attention_forward( + query[:, seq_start:seq_end], + key[:, seq_start:seq_end], + value[:, seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale, + ) out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, - ...]) + seq_len, num_heads, head_size + ) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...]) seq_start += seq_len query_start += query_len torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -532,9 +526,16 @@ def test_contexted_kv_attention_f32( device: str, op: Callable, ) -> None: - test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, - sliding_window, dtype, kv_cache_dtype, device, - op) + test_contexted_kv_attention( + num_heads, + num_queries_per_kv, + head_size, + sliding_window, + dtype, + kv_cache_dtype, + device, + op, + ) @pytest.mark.optional @@ -555,5 +556,6 @@ def test_contexted_kv_attention_alibi_f32( device: str, op: Callable, ) -> None: - test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, - dtype, kv_cache_dtype, device, op) + test_contexted_kv_attention_alibi( + num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op + ) diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index d56d3f4638f1c..9b7fb664956c6 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -11,60 +11,40 @@ from vllm.utils import STR_BACKEND_ENV_VAR @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() +@pytest.mark.skip(reason="Skipped for now. 
Should be revisited.") def test_selector(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN") # Set the current platform to ROCm using monkeypatch - monkeypatch.setattr("vllm.attention.selector.current_platform", - RocmPlatform()) + monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) # Test standard ROCm attention backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) - assert (backend.get_name() == "ROCM_FLASH" - or backend.get_name() == "TRITON_ATTN_VLLM_V1") + assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN" # MLA test for deepseek related # change the attention backend to triton MLA m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") - backend = get_attn_backend(576, - torch.bfloat16, - "auto", - 16, - False, - use_mla=True) - assert (backend.get_name() == "TRITON_MLA" - or backend.get_name() == "TRITON_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) + assert backend.get_name() == "TRITON_MLA" # If attention backend is None # If use_mla is true # The selected backend is triton MLA m.setenv(STR_BACKEND_ENV_VAR, None) - backend = get_attn_backend(576, - torch.bfloat16, - "auto", - 16, - False, - use_mla=True) - assert (backend.get_name() == "TRITON_MLA" - or backend.get_name() == "TRITON_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) + assert backend.get_name() == "TRITON_MLA" # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") - backend = get_attn_backend(576, - torch.bfloat16, - "auto", - 1, - False, - use_mla=True) - assert (backend.get_name() == "ROCM_AITER_MLA" - or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) + assert backend.get_name() == "ROCM_AITER_MLA" # If attention backend is None # If use_mla is true @@ -72,11 +52,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # The selected backend is ROCM_AITER_MLA m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv("VLLM_ROCM_USE_AITER", "1") - backend = get_attn_backend(576, - torch.bfloat16, - "auto", - 1, - False, - use_mla=True) - assert (backend.get_name() == "ROCM_AITER_MLA" - or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) + assert backend.get_name() == "ROCM_AITER_MLA" diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py index 2dca720fe3301..01ba0951b8254 100644 --- a/tests/kernels/attention/test_triton_decode_attention.py +++ b/tests/kernels/attention/test_triton_decode_attention.py @@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): num_kv_splits = 8 num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) - req_to_page = torch.randint(0, - CACHE_SIZE // PAGE_SIZE, - (B, num_pages_per_batch, 1), - device="cuda") + req_to_page = torch.randint( + 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" + ) req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) - req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( - 1, 1, -1) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) req_to_token = req_to_token.view(B, -1) 
req_to_token = req_to_token[:, :seq_len].contiguous() @@ -46,7 +44,9 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): # o will have the same shape as q o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - b_seq_len = torch.full((B, ), seq_len, device="cuda") + lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda") + + b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( (B, H_Q, num_kv_splits, D_V + 1), @@ -60,6 +60,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, @@ -72,12 +73,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) o1 = torch.zeros_like(o) + lse1 = torch.zeros_like(lse) decode_attention_fwd( q, k_buffer, v_buffer, o1, + lse1, req_to_page, b_seq_len, attn_logits, diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 4b97d51e6ed21..fba82cfdadbdf 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -14,9 +14,11 @@ HEAD_SIZES = [128, 256] BLOCK_SIZES = [16] DTYPES = [torch.bfloat16] -QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ - None, torch.float8_e4m3fnuz -] +QDTYPES = ( + [None, torch.float8_e4m3fn] + if not current_platform.is_rocm() + else [None, torch.float8_e4m3fnuz] +) # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] @@ -42,7 +44,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -60,10 +62,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None and soft_cap > 0: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -77,13 +82,13 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @@ -102,9 +107,6 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") - if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: - pytest.skip("block size must be at least 32 for fp8") - current_platform.seed_everything(0) num_seqs = len(seq_lens) 
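# Annotation (not part of the diff): a small worked example of the mask
# diagonals used in ref_paged_attn above. With query_len = 2, kv_len = 6,
# sliding_window = 3:
#   causal: triu(ones(2, 6), diagonal=6 - 2 + 1) masks only entry (0, 5),
#     i.e. the first query token (absolute position 4) cannot attend to
#     kv slot 5, while the last token sees everything up to itself;
#   window: ~triu(ones(2, 6), diagonal=6 - (2 + 3) + 1) masks everything
#     older than the last `sliding_window` slots, so position 4 attends to
#     slots 2..4 and position 5 to slots 3..5.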
query_lens = [x[0] for x in seq_lens] @@ -114,30 +116,23 @@ def test_triton_unified_attn( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) output = torch.empty_like(query) @@ -191,5 +186,7 @@ def test_triton_unified_attn( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index ec5c60fd7b0e2..e8777ec4f59e8 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -8,19 +8,23 @@ import torch from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck -from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, - GeluAndMul, MulAndSilu, - NewGELU, QuickGELU, - SiluAndMul, SwigluOAIAndMul) +from vllm.model_executor.layers.activation import ( + FastGELU, + FatreluAndMul, + GeluAndMul, + MulAndSilu, + NewGELU, + QuickGELU, + SiluAndMul, + SwigluOAIAndMul, +) from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize( @@ -73,24 +77,19 @@ def test_act_and_mul( out = layer(x) ref_out = layer.forward_native(x) if activation == "swigluoai_and_mul": - rtol = { - #For fp16, change the relative tolerance from 1e-3 to 2e-3 - torch.float16: - 2e-3, - torch.bfloat16: - 2e-2, - torch.float: - 1.3e-6 + # For fp16, change the relative tolerance from 1e-3 to 2e-3 + torch.float16: 2e-3, + torch.bfloat16: 2e-2, + torch.float: 1.3e-6, } def _get_rtol(output) -> float: return rtol[output.dtype] - torch.testing.assert_close(out, - ref_out, - atol=get_default_atol(out), - rtol=_get_rtol(out)) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=_get_rtol(out) + ) else: # The SiluAndMul, 
MulAndSilu, GELU and FatReLU implementations are # equivalent to the native PyTorch implementations, so we can do exact @@ -98,7 +97,7 @@ def test_act_and_mul( torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "fatrelu": opcheck(fn, (out, x, threshold)) @@ -108,9 +107,14 @@ def test_act_and_mul( opcheck(fn, (out, x)) -@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast), - (NewGELU, torch.ops._C.gelu_new), - (QuickGELU, torch.ops._C.gelu_quick)]) +@pytest.mark.parametrize( + "activation", + [ + (FastGELU, torch.ops._C.gelu_fast), + (NewGELU, torch.ops._C.gelu_new), + (QuickGELU, torch.ops._C.gelu_quick), + ], +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -132,10 +136,9 @@ def test_activation( fn = activation[1] out = layer(x) ref_out = layer.forward_native(x) - torch.testing.assert_close(out, - ref_out, - atol=get_default_atol(out), - rtol=get_default_rtol(out)) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) + ) out = torch.empty_like(x) opcheck(fn, (out, x)) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 19703b8a2f978..52133ec53d1d7 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -24,9 +24,7 @@ NUM_TOKENS_HIDDEN_SIZES = [ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] EPS = 1e-6 @@ -34,13 +32,12 @@ EPS = 1e-6 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_rms_norm(rms_norm_layer: RMSNorm, - x: torch.Tensor, - residual: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, Optional[torch.Tensor]]: +def ref_rms_norm( + rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor] +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -50,12 +47,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm, return out, residual -def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ref_dynamic_per_token_quant( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -64,9 +62,9 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, # Quant if quant_dtype == torch.float8_e4m3fn: - torch_out, scales = ops.scaled_fp8_quant(torch_out, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + torch_out, scales = ops.scaled_fp8_quant( + torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) else: assert quant_dtype == torch.int8 
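# Annotation (not part of the diff): "dynamic per-token" quantization here
# means each token row gets its own scale, roughly scale_i = max(|x_i|) / qmax
# (clamped by scale_ub for fp8 when provided), and the row is divided by that
# scale before being cast to fp8 or int8; this is our reading of the ops
# above, not a statement of their exact implementation.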
torch_out, scales = ops.scaled_int8_quant(torch_out) @@ -74,38 +72,41 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, return torch_out, scales, residual -def ref_impl(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, - residual, scale_ub) +def ref_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ref_dynamic_per_token_quant( + rms_norm_layer, x, quant_dtype, residual, scale_ub + ) -def ops_dynamic_per_token_quant(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ops_dynamic_per_token_quant( + weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() - out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, - quant_dtype, scale_ub, - residual) + out, scales = ops.rms_norm_dynamic_per_token_quant( + x, weight, EPS, quant_dtype, scale_ub, residual + ) return out, scales, residual -def ops_impl(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, - scale_ub) +def ops_impl( + weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @@ -146,12 +147,14 @@ def test_rms_norm( residual = torch.randn_like(x) * scale if add_residual else None if scale_ub is not None: rms_x, _ = ref_rms_norm(layer, x, residual) - scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda') + scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda") - ref_out, ref_scales, ref_residual = \ - ref_impl(layer, x, quant_dtype, residual, scale_ub) - ops_out, ops_scales, ops_residual = \ - ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) + ref_out, ref_scales, ref_residual = ref_impl( + layer, x, quant_dtype, residual, scale_ub + ) + ops_out, ops_scales, ops_residual = ops_impl( + layer.weight, x, quant_dtype, residual, scale_ub + ) assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype @@ -160,15 +163,18 @@ def test_rms_norm( # big atol to account for round-off errors. 
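# Annotation (not part of the diff): an absolute tolerance of 1 lets the
# reference and fused paths round a borderline value to adjacent quantized
# codes; one code step is the smallest tolerance robust to that, hence the
# "big atol" relative to the float comparison in the other branch.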
assert torch.allclose(ref_out, ops_out, atol=1) else: - assert torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + assert torch.allclose( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) if add_residual: assert torch.allclose(ref_residual, ops_residual) output = torch.empty_like(x, dtype=quant_dtype) - scales = torch.empty((x.numel() // x.shape[-1], 1), - device=x.device, - dtype=torch.float32) + scales = torch.empty( + (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 + ) - opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant, - (output, x, layer.weight, scales, 1e-5, scale_ub, residual)) + opcheck( + torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual), + ) diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 02316ceaac735..7553d45e00576 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,18 +6,27 @@ import torch from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, - 8199] # Arbitrary values for testing +HIDDEN_SIZES = [ + 8, + 768, + 769, + 770, + 771, + 5120, + 5124, + 5125, + 5126, + 8192, + 8199, +] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -63,11 +72,46 @@ def test_rms_norm( torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) if residual is not None: - opcheck(torch.ops._C.fused_add_rms_norm, - (x, residual, layer.weight.data, layer.variance_epsilon)) + opcheck( + torch.ops._C.fused_add_rms_norm, + (x, residual, layer.weight.data, layer.variance_epsilon), + ) else: - opcheck(torch.ops._C.rms_norm, - (out, x, layer.weight.data, layer.variance_epsilon)) + opcheck( + torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon) + ) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_poly_norm( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + layer = PolyNorm().to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + layer.bias.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + + ref_out = layer.forward_native(x) + out = layer(x) + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + opcheck( + torch.ops._C.poly_norm, + (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon), + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -113,7 +157,8 @@ def test_fused_rms_norm_quant( if add_residual: 
torch.ops._C.fused_add_rms_norm_static_fp8_quant( - out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6) + out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6 + ) # Unfused kernel is in-place so it goes second # Also use a separate clone of x to avoid modifying the input @@ -121,29 +166,32 @@ def test_fused_rms_norm_quant( x_unfused = x_unfused_base[..., :hidden_size] assert x_unfused.is_contiguous() != strided_input torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(), - quant_scale_t) + torch.ops._C.static_scaled_fp8_quant( + out_quant, x_unfused.contiguous(), quant_scale_t + ) torch.cuda.synchronize() - torch.testing.assert_close(residual_fused, - residual, - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) opcheck( torch.ops._C.fused_add_rms_norm_static_fp8_quant, - (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6), + ) else: - torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, - quant_scale_t, 1e-6) + torch.ops._C.rms_norm_static_fp8_quant( + out_quant_fused, x, weight, quant_scale_t, 1e-6 + ) torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, - quant_scale_t) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t) - opcheck(torch.ops._C.rms_norm_static_fp8_quant, - (out_quant_fused, x, weight, quant_scale_t, 1e-6)) + opcheck( + torch.ops._C.rms_norm_static_fp8_quant, + (out_quant_fused, x, weight, quant_scale_t, 1e-6), + ) - torch.testing.assert_close(out_quant.to(dtype=torch.float32), - out_quant_fused.to(dtype=torch.float32), - atol=1e-3, - rtol=1e-3) + torch.testing.assert_close( + out_quant.to(dtype=torch.float32), + out_quant_fused.to(dtype=torch.float32), + atol=1e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 3f2f330f6dc3b..02b795721f46e 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple import pytest import torch +from packaging.version import Version from transformers import AutoConfig +from transformers import __version__ as TRANSFORMERS_VERSION from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -11,65 +14,103 @@ from vllm.platforms import current_platform device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, - head_size: int, max_position_embeddings: int, - dtype: torch.dtype, device: torch.device): +def generate_test_data( + num_tokens: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + max_position_embeddings: int, + dtype: torch.dtype, + device: torch.device, +): """Generate test data for given configuration.""" + current_platform.seed_everything(42) # Create 2D positions (3, num_tokens) for multimodal case - positions = torch.randint(0, - max_position_embeddings // 4, (3, num_tokens), - device=device) + positions = torch.randint( + 0, max_position_embeddings // 4, (3, num_tokens), device=device + ) # Create query and key tensors - query = torch.randn(num_tokens, - num_q_heads * head_size, - dtype=dtype, - device=device) - 
key = torch.randn(num_tokens, - num_kv_heads * head_size, - dtype=dtype, - device=device) + query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device) + key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device) return positions, query, key -def unroll_model_tp_dict(model_tp_dict): - return [(model_name, tp_size) - for model_name, tp_sizes in model_tp_dict.items() - for tp_size in tp_sizes] +class MRoPETestInfo(NamedTuple): + model_name: str + # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 + atol: float = 1e-2 + rtol: float = 1.6e-2 + marks: list[pytest.MarkDecorator] = [] -model_tp_dict = { - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "Qwen/Qwen2-VL-72B-Instruct": [1, 2], - "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2], -} +TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version -# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 -dtype_atol_rtol_list = [ - [torch.bfloat16, 1e-2, 1.6e-2], +MODELS_TO_TEST = [ + MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-4B-Instruct", + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ], + ), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ], + ), ] num_tokens_list = [11, 8192] -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize("model_name, tp_size", - unroll_model_tp_dict(model_tp_dict)) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." 
+) +@pytest.mark.parametrize( + "model_info, model_name", + [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST + ], +) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("num_tokens", num_tokens_list) -def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): +def test_mrope( + model_name: str, + model_info: MRoPETestInfo, + tp_size: int, + dtype: torch.dtype, + num_tokens: int, +): + atol = model_info.atol + rtol = model_info.rtol config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // total_num_heads + ) is_neox_style = True rope_theta = config.rope_theta @@ -89,9 +130,9 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): # create q k v input tensors # create rotary pos emb input tensors - positions, query, key = generate_test_data(num_tokens, num_heads, - num_kv_heads, head_dim, - max_position, dtype, device) + positions, query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) query_native, key_native = mrope_helper_class.forward_native( positions, @@ -109,26 +150,42 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Skipping CUDA/ROCm only tests.") +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." 
+) @pytest.mark.parametrize( - "model_name, tp_size", - unroll_model_tp_dict({ - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2] - })) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) -@pytest.mark.parametrize("num_tokens", [4]) -def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, - num_tokens): + "model_info, model_name", + [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST + ], +) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_tokens", num_tokens_list) +def test_mrope_torch_compile_tracing( + model_name: str, + model_info: MRoPETestInfo, + tp_size: int, + dtype: torch.dtype, + num_tokens: int, +): + atol = model_info.atol + rtol = model_info.rtol + config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // total_num_heads + ) is_neox_style = True rope_theta = config.rope_theta max_position = config.max_position_embeddings @@ -146,16 +203,16 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, ).to(device=device) # Generate test data - positions, query, key = generate_test_data(num_tokens, num_heads, - num_kv_heads, head_dim, - max_position, dtype, device) + positions, query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) # Create a wrapper that makes the in-place function appear functional def functional_forward_cuda(pos, q, k): """Wrapper that converts in-place operation to functional style CUDA Graph does not support in-place operations. - This wrapper creates working copies of the + This wrapper creates working copies of the input tensors and modifies them. 
""" q_work = q.clone() # Create working copies @@ -172,11 +229,13 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, ) try: - compiled_forward_cuda = torch.compile(functional_forward_cuda, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) + compiled_forward_cuda = torch.compile( + functional_forward_cuda, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) # Run compiled version query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda( @@ -191,25 +250,16 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda) # Verify results - torch.testing.assert_close(query_compiled_cuda, - query_cuda, - atol=atol, - rtol=rtol) - torch.testing.assert_close(key_compiled_cuda, - key_cuda, - atol=atol, - rtol=rtol) - torch.testing.assert_close(query_compiled_cuda, - query_native, - atol=atol, - rtol=rtol) - torch.testing.assert_close(key_compiled_cuda, - key_native, - atol=atol, - rtol=rtol) + torch.testing.assert_close( + query_compiled_cuda, query_cuda, atol=atol, rtol=rtol + ) + torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol) + torch.testing.assert_close( + query_compiled_cuda, query_native, atol=atol, rtol=rtol + ) + torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol) print("✓ forward_cuda successfully traced with torch.compile inductor") except Exception as e: - pytest.fail( - f"forward_cuda failed to trace with torch.compile inductor: {e}") + pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}") diff --git a/tests/kernels/core/test_permute_cols.py b/tests/kernels/core/test_permute_cols.py index e18f6230dbcea..1e264735cb3c2 100644 --- a/tests/kernels/core/test_permute_cols.py +++ b/tests/kernels/core/test_permute_cols.py @@ -8,11 +8,11 @@ from tests.kernels.utils import opcheck from vllm._custom_ops import permute_cols -@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) -@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_permute_cols(shape, dtype): x = torch.randn(shape, dtype=dtype).cuda() perm = torch.randperm(x.shape[1]).to(torch.int).cuda() opcheck(torch.ops._C.permute_cols, (x, perm)) y = permute_cols(x, perm) - torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file + torch.testing.assert_close(y, x[:, perm]) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index ab6f1ccf881fd..799e0a3f2a2bd 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate, product +from itertools import product from typing import Callable, Optional import pytest @@ -19,30 +19,33 @@ NUM_HEADS = [17] # Arbitrary values for testing BATCH_SIZES = [5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] USE_KEY = [True, False] -def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: 
int, - head_size: int) -> tuple[int, ...]: +def _get_flat_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads * head_size) # For testing sliced tensors -def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_padded_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size + 64) -def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_batch_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size) TENSORS_SHAPES_FN = [ - _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape + _get_batch_tensor_shape, + _get_flat_tensor_shape, + _get_padded_tensor_shape, ] @@ -60,7 +63,7 @@ TENSORS_SHAPES_FN = [ @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, - tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], + tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]], batch_size: int, seq_len: int, num_heads: int, @@ -97,186 +100,63 @@ def test_rotary_embedding( ref_query, ref_key = rope.forward_native(positions, query, key) out_query, out_key = rope.forward(positions, query, key) # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) + torch.testing.assert_close( + out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query), + ) if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + torch.testing.assert_close( + out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key), + ) else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding( - is_neox_style: bool, - tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": (1, ) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) - query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) 
if use_key else None - - # slice tensor if required, noop otherwise - query = query[..., :head_size] - key = key[..., :head_size] if use_key else None - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key) - out_query, out_key = rope.forward(positions, - query, - key, - offsets=torch.zeros(batch_size * seq_len, - dtype=torch.long, - device=device)) - # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - scaling_factors: list[int] = [1, 2, 4] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) if use_key else None - - offset_map = torch.tensor( - list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) - query_offsets = offset_map[query_types] - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key, - query_offsets) - out_query, out_key = rope.forward(positions, query, key, - query_offsets.flatten()) - # Compare the results. 
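# NOTE (illustrative sketch, not part of this diff): the rope tests here compare
# an in-place CUDA kernel against RotaryEmbedding.forward_native. Under the usual
# neox-style ("rotate half") definition, applying RoPE to a (num_tokens, head_dim)
# slice looks roughly like the hypothetical helper below.
import torch

def apply_neox_rope_sketch(x: torch.Tensor, positions: torch.Tensor,
                           base: float = 10000.0) -> torch.Tensor:
    # x: (num_tokens, head_dim) with even head_dim; positions: (num_tokens,)
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32,
                                            device=x.device) / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]  # (tokens, head_dim/2)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    rotated = torch.cat([-x2, x1], dim=-1)  # "rotate half"
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)

# Usage: q_rot = apply_neox_rope_sketch(torch.randn(11, 64), torch.arange(11))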
- torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" + assert ref_key is None and out_key is None, "expected returned key to be None" @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] BASES = [10000, 1000000] - ROPE_SCALINGS = (None, { - "rope_type": "linear", - "factor": (1, ) - }, { - "rope_type": "dynamic", - "factor": 1 - }) - settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, - ROPE_SCALINGS, DTYPES) + ROPE_SCALINGS = ( + None, + {"rope_type": "linear", "factor": (1,)}, + {"rope_type": "dynamic", "factor": 1}, + ) + settings = ( + HEAD_SIZES, + ROTARY_DIMS, + MAX_POSITIONS, + BASES, + IS_NEOX_STYLE, + ROPE_SCALINGS, + DTYPES, + ) rope_setting_id_map: dict[str, int] = {} for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # different settings cannot share the same rope module assert id(rope) not in rope_setting_id_map.values() assert all(x.dtype == dtype for x in rope.buffers()) @@ -284,11 +164,25 @@ def test_rope_module_cache(): rope_setting_id_map[str(setting)] = id(rope) for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # check if cache take effect assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index d1fd960bf115c..0a292a3e2ae70 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -13,23 +13,20 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -def rotary_embedding_opcheck(rot, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None): +def rotary_embedding_opcheck( + rot, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, +): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. 
- if offsets is not None: - opcheck(torch.ops._C.batched_rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style, rot.rotary_dim, offsets)) - else: - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + # ops.rotary_embedding() is a in-place operation + # that updates the query and key tensors. + opcheck( + torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, rot.is_neox_style), + ) @pytest.mark.parametrize("device", ["cuda"]) @@ -40,39 +37,42 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False]) -def test_rotary_embedding_opcheck(dist_init, device, max_position, - is_neox_style, rotary_dim, head_size, - seq_len, use_key, head_stride_is_contiguous): +def test_rotary_embedding_opcheck( + dist_init, + device, + max_position, + is_neox_style, + rotary_dim, + head_size, + seq_len, + use_key, + head_stride_is_contiguous, +): batch_size = 1 base = 10000 num_heads = 7 - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) + rot = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, torch.float32 + ) - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device=device) + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) head_stride = head_size + (64 if head_stride_is_contiguous else 0) - query = torch.randn(batch_size, - seq_len, - num_heads, - head_stride, - dtype=torch.float32, - device=device) + query = torch.randn( + batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device + ) key = torch.randn_like(query) if use_key else None query = query[..., :head_size] key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) - offsets = torch.zeros(batch_size * seq_len, - device=device, - dtype=torch.long) - rotary_embedding_opcheck(rot, positions, query, key, offsets) # if we have a contiguous head stride, test the alternate # [..., num_heads * head_dim] shape/layout if head_stride_is_contiguous: rotary_embedding_opcheck( - rot, positions, query.flatten(start_dim=-2), - key.flatten(start_dim=-2) if use_key else None) + rot, + positions, + query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None, + ) diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index c71215e4c646b..73738175e5c76 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -5,20 +5,14 @@ import torch from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.") @pytest.mark.parametrize("device", CUDA_DEVICES) def test_cpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor) assert cuda_view.device.type == "cuda" @@ -40,11 +34,7 @@ def test_cpu_write(device): @pytest.mark.parametrize("device", CUDA_DEVICES) def 
test_gpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor) assert cuda_view.device.type == "cuda" @@ -59,4 +49,4 @@ def test_gpu_write(device): assert cpu_tensor[0, 0] == 2 assert cpu_tensor[2, 3] == 4 - assert cpu_tensor[4, 5] == -2 \ No newline at end of file + assert cpu_tensor[4, 5] == -2 diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index 411bd9e904b04..fea6b94481b60 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -10,7 +10,9 @@ from einops import rearrange from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.platforms import current_platform @@ -39,18 +41,15 @@ def causal_conv1d_ref( seqlen = x.shape[-1] dim, width = weight.shape if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) + dtype_in + ) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: @@ -59,12 +58,9 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None, - cache_seqlens=None): +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -91,24 +87,25 @@ def causal_conv1d_update_ref(x, assert weight.shape == (dim, width) if cache_seqlens is None: x_new = torch.cat([conv_state, x], dim=-1).to( - weight.dtype) # (batch, dim, state_len + seqlen) + weight.dtype + ) # (batch, dim, state_len + seqlen) conv_state.copy_(x_new[:, :, -state_len:]) else: width_idx = torch.arange( - -(width - 1), 0, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand( - -1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], - dim=-1).to(weight.dtype) - copy_idx = torch.arange( - seqlen, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, - state_len).unsqueeze(1).expand(-1, dim, -1) + -(width - 1), 0, dtype=torch.long, device=x.device + ).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = ( + torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + ) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze( + 0 + ) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, - 
groups=dim)[:, :, -seqlen:] + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] if unsqueeze: out = out.squeeze(-1) return (out if activation is None else F.silu(out)).to(dtype=dtype_in) @@ -117,15 +114,17 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID): +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -150,8 +149,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("seqlen", [1]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, - itype): +def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -167,23 +165,26 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation) - out_ref = causal_conv1d_update_ref(x_ref, - conv_state_ref, - weight, - bias, - activation=activation) + + conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device) + + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices, + ) + out_ref = causal_conv1d_update_ref( + x_ref, conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("seqlen", [1, 3]) @@ -192,9 +193,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("batch_size", [3]) -def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, - width, seqlen, has_bias, - silu_activation, itype): +def test_causal_conv1d_update_with_batch_gather( + batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else 
(3e-3, 5e-3) if itype == torch.bfloat16: @@ -209,31 +210,30 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, total_entries = 10 * batch_size # x will be (batch, dim, seqlen) with contiguous along dim-axis - x = torch.randn(padded_batch_size, seqlen, dim, device=device, - dtype=itype).transpose(1, 2) + x = torch.randn( + padded_batch_size, seqlen, dim, device=device, dtype=itype + ).transpose(1, 2) x_ref = x.clone() conv_state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[conv_state_indices] = False - padded_state_indices = torch.concat([ - conv_state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) # conv_state will be (cache_lines, dim, state_len) # with contiguous along dim-axis - conv_state = torch.randn(total_entries, - width - 1, - dim, - device=device, - dtype=itype).transpose(1, 2) + conv_state = torch.randn( + total_entries, width - 1, dim, device=device, dtype=itype + ).transpose(1, 2) conv_state_for_padding_test = conv_state.clone() @@ -242,22 +242,23 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation, - conv_state_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID) - out_ref = causal_conv1d_update_ref(x_ref[:batch_size], - conv_state_ref, - weight, - bias, - activation=activation) + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref( + x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.equal(conv_state[unused_states_bool], - conv_state_for_padding_test[unused_states_bool]) + assert torch.equal( + conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool] + ) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @@ -265,12 +266,13 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096]) -@pytest.mark.parametrize('dim', [64, 4096]) -@pytest.mark.parametrize('with_padding', [True, False]) -@pytest.mark.parametrize('batch', [4, 10]) -def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, - has_bias, silu_activation, itype): +@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("dim", [64, 4096]) +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch", [4, 10]) +def test_causal_conv1d_varlen( + batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype +): device = "cuda" torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) 
if itype == torch.float32 else (3e-3, 5e-3) @@ -288,19 +290,19 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) x = rearrange( torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype), - "b s d -> b d s")[:, 4096:4096 + dim, :] + "b s d -> b d s", + )[:, 4096 : 4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) @@ -309,34 +311,34 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(total_entries, - width - 1, - dim, - device=x.device, - dtype=x.dtype).transpose(1, 2) + final_states = torch.randn( + total_entries, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) final_states_ref = final_states.clone() - has_initial_states = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=x.device) - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=x.device)[:batch_size] - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) - out = causal_conv1d_fn(x.squeeze(0), - weight, - bias=bias, - conv_states=final_states, - query_start_loc=cumsum.cuda(), - cache_indices=padded_state_indices, - has_initial_state=has_initial_states, - activation=activation, - pad_slot_id=PAD_SLOT_ID) + has_initial_states = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ + :batch_size + ] + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + out = causal_conv1d_fn( + x.squeeze(0), + weight, + bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID, + ) out_ref = [] out_ref_b = [] @@ -353,16 +355,20 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[ - padded_state_indices[i]].unsqueeze(0), - initial_states=final_states_ref[padded_state_indices[i]]. 
- unsqueeze(0) if has_initial_states[i] else None)) + final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None, + ) + ) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref_tensor = torch.cat(out_ref, dim=0) - assert torch.allclose(final_states[state_indices], - final_states_ref[state_indices], - rtol=rtol, - atol=atol) - unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose( + final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + atol=atol, + ) + unpadded_out = out[:, : out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index 16c310726ad16..d23daefa7b436 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -7,8 +7,10 @@ import pytest import torch from tests.utils import multi_gpu_test -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -24,14 +26,15 @@ from vllm.utils import update_environment_variables (64, 2), (64, 4), # hidden_size be divisible by num_gpus (100, 5), # and n_groups must divide hidden_size - ]) + ], +) @pytest.mark.parametrize("dtype", [torch.float16]) def test_mixer2_gated_norm_multi_gpu( batch_size: int, seq_len: int, hidden_size_n_groups: tuple[int, int], dtype: torch.dtype, - device: str = 'cuda', + device: str = "cuda", ): hidden_size, n_groups = hidden_size_n_groups num_processes = 2 @@ -39,17 +42,19 @@ def test_mixer2_gated_norm_multi_gpu( def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=( - num_processes, - batch_size, - seq_len, - hidden_size, - n_groups, - dtype, - device, - ), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs, + ) run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) @@ -71,20 +76,22 @@ def mixer2_gated_norm_tensor_parallel( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) # create random weights an inputs - weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + weight = torch.rand((hidden_size,), dtype=dtype, device=device) hidden_states = torch.randn(batch_size, seq_len, hidden_size) gate_states = torch.randn(batch_size, seq_len, hidden_size) @@ -97,14 +104,18 @@ def mixer2_gated_norm_tensor_parallel( # create gated-norm without TP to compute reference # - utilize mock patching to 
disable TP when - with (unittest.mock.patch( + with ( + unittest.mock.patch( "vllm.model_executor.layers.mamba.mamba_mixer2." "get_tensor_model_parallel_world_size", - return_value=1), - unittest.mock.patch( - "vllm.model_executor.layers.mamba.mamba_mixer2." - "get_tensor_model_parallel_rank", - return_value=0)): + return_value=1, + ), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_rank", + return_value=0, + ), + ): mixer_single_gpu = Mixer2RMSNormGated( full_hidden_size=hidden_size, full_n_groups=n_groups, @@ -115,12 +126,13 @@ def mixer2_gated_norm_tensor_parallel( # generate and compare N = hidden_size // world_size output = mixer( - hidden_states[..., local_rank * N:(local_rank + 1) * N], - gate_states[..., local_rank * N:(local_rank + 1) * N], + hidden_states[..., local_rank * N : (local_rank + 1) * N], + gate_states[..., local_rank * N : (local_rank + 1) * N], ) ref_output = mixer_single_gpu(hidden_states, gate_states) - torch.testing.assert_close(output, - ref_output[..., - local_rank * N:(local_rank + 1) * N], - atol=5e-3, - rtol=1e-3) + torch.testing.assert_close( + output, + ref_output[..., local_rank * N : (local_rank + 1) * N], + atol=5e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 4c32ae81b34c5..9a6137239ebfc 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -10,20 +10,15 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_scan_fn, + selective_state_update, +) from vllm.platforms import current_platform -def selective_state_update_ref(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False): +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -73,16 +68,17 @@ def selective_state_update_ref(state, assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * - A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) + dA = torch.exp( + rearrange(dt, "b h d -> b h d 1") * A + ) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) dB = rearrange(dt, "b h d -> b h d 1") * rearrange( - B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_(state * dA + - dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + B, "b h n -> b h 1 n" + ) # (batch, nheads, dim, dstate) + state.copy_( + state * dA + dB * rearrange(x, "b h d -> b h d 1") + ) # (batch, dim, dstate out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) @@ -92,18 +88,20 @@ def selective_state_update_ref(state, return out -def selective_scan_ref(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - return_last_state=False, - prev_state=None, - final_state_out=None): +def selective_scan_ref( + 
u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + prev_state=None, + final_state_out=None, +): """ u: r(B D L) delta: r(B D L) @@ -132,26 +130,26 @@ def selective_scan_ref(u, C = C.float() x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) else: if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) + y = torch.einsum("bdn,dn->bd", x, C) else: if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) if i == u.shape[2] - 1: if final_state_out is None: final_state_out = x @@ -166,20 +164,22 @@ def selective_scan_ref(u, return out if not return_last_state else (out, final_state_out) -def selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - cu_seq_len=None, - cache_indices=None, - has_initial_state=None, - ssm_states=None, - pad_slot_id=PAD_SLOT_ID): +def selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + cu_seq_len=None, + cache_indices=None, + has_initial_state=None, + ssm_states=None, + pad_slot_id=PAD_SLOT_ID, +): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -206,30 +206,55 @@ def selective_scan_opcheck_fn(u, # Disable test_autograd_registration for now as it seems to trigger # a bogus error. 
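# NOTE (illustrative sketch, not part of this diff): the opcheck(...) calls in
# these tests go through a thin test helper that is assumed to forward to
# torch.library.opcheck, which validates a custom op's schema, fake-tensor
# kernel and dispatch behaviour. A self-contained toy example (PyTorch >= 2.4
# assumed; the "demo::scale" op exists only for this illustration):
import torch

@torch.library.custom_op("demo::scale", mutates_args=())
def demo_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@demo_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake kernel: only metadata (shape/dtype/device) matters here.
    return torch.empty_like(x)

# Restricting test_utils, as the opcheck call below does, skips checks that are
# known to be problematic for a given op (here: test_autograd_registration).
torch.library.opcheck(demo_scale, (torch.randn(4), 2.0),
                      test_utils=["test_schema", "test_faketensor"])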
- opcheck(torch.ops._C.selective_scan_fwd, - (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states, pad_slot_id), - test_utils=["test_schema", "test_faketensor"]) + opcheck( + torch.ops._C.selective_scan_fwd, + ( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cu_seq_len, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ), + test_utils=["test_schema", "test_faketensor"], + ) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize('has_delta_bias', [True]) -@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) @pytest.mark.parametrize("scan_chunks", [1, 2, 3]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, - has_z, has_delta_bias, delta_softplus, seqlen, itype, - wtype, scan_chunks): +def test_selective_scan( + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + seqlen, + itype, + wtype, + scan_chunks, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -242,7 +267,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, batch_size = 1 dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() if not is_variable_B: B_shape = [dim, dstate] @@ -250,9 +275,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B_shape = [batch_size, dstate, seqlen] else: B_shape = [batch_size, varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() if not is_variable_C: C_shape = [dim, dstate] @@ -260,27 +283,27 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, C_shape = [batch_size, dstate, seqlen] else: C_shape = [batch_size, varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() - z = torch.randn(batch_size, dim, seqlen, device=device, - dtype=itype) if has_z else None + z = ( + torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + if has_z + else None + ) z_ref = z.clone() if has_z else None 
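# NOTE (illustrative sketch, not part of this diff): selective_scan_ref above
# implements the discretized SSM recurrence
#     state_t = exp(delta_t * A) * state_{t-1} + (delta_t * u_t) * B_t
#     y_t     = C_t . state_t + D * u_t
# A stripped-down version for the variable-B/C, ungrouped case (hypothetical
# helper; z-gating and initial state omitted; shapes as in this test:
# u, delta (batch, dim, L), A (dim, dstate), B, C (batch, dstate, L), D (dim,)):
import torch

def selective_scan_sketch(u, delta, A, B, C, D):
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    state = u.new_zeros(batch, dim, dstate)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)  # (batch, dim, dstate)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        state = dA * state + dBu  # recurrence step
        y = torch.einsum("bdn,bn->bd", state, C[:, :, t]) + D * u[:, :, t]
        ys.append(y)
    return torch.stack(ys, dim=-1)  # (batch, dim, L)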
- delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * - torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() state_shape = (batch_size, u.shape[1], int(A.shape[1])) - state = torch.randn(state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False) state_ref = state.clone() out = None out_ref = None @@ -312,9 +335,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=_z, delta_bias=delta_bias, delta_softplus=delta_softplus, - has_initial_state=torch.ones(batch_size, - device=u.device, - dtype=torch.bool) if c > 0 else None) + has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) + if c > 0 + else None, + ) outs.append(out) if len(outs) > 1: out = torch.cat(outs, dim=-1) @@ -329,27 +353,29 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=z_ref, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=True) + return_last_state=True, + ) assert out is not None and out_ref is not None assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert state is not None and state_ref is not None assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D, - z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - ssm_states=state) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + ssm_states=state, + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @@ -374,52 +400,47 @@ def test_selective_state_update(dim, dstate, has_z, itype): D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None state_ref = state.detach().clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - out=out) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32]) -@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32]) +@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096]) @pytest.mark.parametrize("return_last_state", [True]) -@pytest.mark.parametrize('has_delta_bias', [True]) 
-@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [False, True]) -def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, - varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, - itype, wtype): +def test_selective_scan_varlen( + with_padding, + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + return_last_state, + seqlen, + itype, + wtype, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -443,72 +464,79 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0).cuda() + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() B_shape = [varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() C_shape = [varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() z = torch.randn(dim, seqlen, device=device, dtype=itype) z_ref = z.clone() - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() out = None out_ref = None prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) - prev_state = torch.randn(prev_state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + prev_state = torch.randn( + prev_state_shape, device=u.device, dtype=itype, requires_grad=False + ) prev_state_ref = 
prev_state.clone() - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=u.device)[:batch_size] - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[ + :batch_size + ] + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) - has_initial_state = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=u.device) - out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state) + has_initial_state = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device + ) + out = selective_scan_fn( + u, + prev_state, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + ) outs_ref = [] splits = [ torch.split(var, seqlens[0], dim=-1) @@ -530,33 +558,46 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, delta_softplus=delta_softplus, return_last_state=return_last_state, prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) - if has_initial_state[i] else None, - final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( - 0)) + if has_initial_state[i] + else None, + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0), + ) outs_ref.append(out_ref_s) out_ref = torch.cat(outs_ref, dim=-1)[0] - unpadded_out = out[:, :out_ref[0].shape[-1]] + unpadded_out = out[:, : out_ref[0].shape[-1]] print("Output diff max", (unpadded_out - out_ref).max()) print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state, prev_state) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + prev_state, + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) -def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, - has_z, itype): +def test_selective_state_update_with_batch_indices( + with_padding, dim, dstate, has_z, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -571,17 +612,17 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, total_entries = 10 * batch_size state = 
torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) out = torch.empty_like(x) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) @@ -593,61 +634,60 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].clone() state_before = state.clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID, - out=out) - out_ref = selective_state_update_ref(state_ref, - x[:batch_size], - dt[:batch_size], - A, - B[:batch_size], - C[:batch_size], - D=D, - z=z[:batch_size], - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, + ) + out_ref = selective_state_update_ref( + state_ref, + x[:batch_size], + dt[:batch_size], + A, + B[:batch_size], + C[:batch_size], + D=D, + z=z[:batch_size], + dt_bias=dt_bias, + dt_softplus=True, + ) print("Output diff max", (out[:batch_size] - out_ref).max()) print("Output diff mean", (out[:batch_size] - out_ref).mean()) print("Output state diff max", (state[state_indices, :] - state_ref).max()) - print("Output state diff mean", - (state[state_indices, :] - state_ref).mean()) + print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) # test padded entries stay the same if with_padding: - assert torch.equal(state_before[unused_states_bool], - state[unused_states_bool]) - assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) - assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) - assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) - assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + assert torch.equal(state_before[unused_states_bool], state[unused_states_bool]) + assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :]) + assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :]) + assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :]) + assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :]) # test "real" entries - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("tie_hdim", [False, True]) @pytest.mark.parametrize("ngroups", [1, 2, 4]) 
@pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 4096]) def test_selective_state_update_with_heads_with_batch_indices( - dim, dstate, ngroups, has_z, tie_hdim, itype): + dim, dstate, ngroups, has_z, tie_hdim, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) if itype == torch.bfloat16: @@ -659,71 +699,55 @@ def test_selective_state_update_with_heads_with_batch_indices( nheads = dim // headdim total_entries = 10 * batch_size - state = torch.randn(total_entries, - nheads, - headdim, - dstate, - dtype=itype, - device=device) + state = torch.randn( + total_entries, nheads, headdim, dstate, dtype=itype, device=device + ) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) + dtype=torch.int32, device=device + ) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) out = torch.empty_like(x) if not tie_hdim: - dt = torch.randn(batch_size, - nheads, - headdim, - device=device, - dtype=itype) + dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 D = torch.randn(nheads, headdim, device=device) else: - dt = repeat(torch.randn(batch_size, nheads, device=device, - dtype=itype), - "b h -> b h p", - p=headdim) - dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, - "h -> h p", - p=headdim) - A = repeat(-torch.rand(nheads, device=device) - 1.0, - "h -> h p n", - p=headdim, - n=dstate) + dt = repeat( + torch.randn(batch_size, nheads, device=device, dtype=itype), + "b h -> b h p", + p=headdim, + ) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) + A = repeat( + -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate + ) D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) B = torch.randn(batch_size, ngroups, dstate, device=device) C = torch.randn(batch_size, ngroups, dstate, device=device) z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].detach().clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices, - pad_slot_id=PAD_SLOT_ID, - out=out) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 2c554baaff76c..57dcb789e97ba 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -7,10 +7,10 @@ import torch.nn.functional as F from einops import rearrange, repeat from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - 
mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba2_attn import ( - _query_start_loc_to_chunk_indices_offsets) +from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata # Added by the IBM Team, 2024 @@ -22,12 +22,10 @@ def segsum(x): """Calculates segment sum.""" T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=-1) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=0) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum @@ -46,8 +44,9 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): assert X.shape[1] % block_len == 0 # Rearrange into blocks/chunks - X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) - for x in (X, A, B, C)) + X, A, B, C = ( + rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C) + ) A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) @@ -74,7 +73,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): # 4. Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) - Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) # Add output of intra-chunk and inter-chunk terms # (diagonal and off-diagonal blocks) @@ -82,61 +81,53 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): return Y, final_state -def generate_random_inputs(batch_size, - seqlen, - n_heads, - d_head, - itype, - device='cuda'): - +def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"): current_platform.seed_everything(0) - A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device)) dt = F.softplus( - torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - - 4) - X = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - B = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - C = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4 + ) + X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) return A, dt, X, B, C -def generate_continuous_batched_examples(example_lens_by_batch, - num_examples, - full_length, - last_taken, - exhausted, - n_heads, - d_head, - itype, - device='cuda'): - +def generate_continuous_batched_examples( + example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device="cuda", + return_naive_ref=True, +): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed - # them in continuous batches to the kernels + # them in 
continuous batches to the kernels. + # If if return_naive_ref=True, the naive torch implementation + # ssd_minimal_discrete will be used to compute and return + # reference output. # generate the full-length example - A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs( + num_examples, full_length, n_heads, d_head, itype + ) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // 4) + if return_naive_ref: + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4 + ) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch # e.g., example_lens=(8, 4) means take 8 samples from first eg, # 4 examples from second eg, etc def get_continuous_batch(example_lens: tuple[int, ...]): - indices = [] for i, x in enumerate(example_lens): c = last_taken.get(i, 0) @@ -144,8 +135,10 @@ def generate_continuous_batched_examples(example_lens_by_batch, last_taken[i] = (c + x) % full_length exhausted[i] = last_taken[i] == 0 - return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) - ]).unsqueeze(0) for x in (dt, X, B, C)) + return ( + torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0) + for x in (dt, X, B, C) + ) # internal function that maps "n" to the appropriate right boundary # value when forming continuous batches from examples of length given @@ -157,19 +150,20 @@ def generate_continuous_batched_examples(example_lens_by_batch, IND_E = None for spec in example_lens_by_batch: - # get the (maybe partial) example seen in this cont batch dt2, X2, B2, C2 = get_continuous_batch(spec) # get the metadata - cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - seq_idx = torch.zeros(cu_seqlens[-1], - dtype=torch.int32, - device=cu_seqlens.device) - for i, (srt, end) in enumerate(zip( + cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0) + seq_idx = torch.zeros( + cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device + ) + for i, (srt, end) in enumerate( + zip( cu_seqlens, cu_seqlens[1:], - )): + ) + ): seq_idx[srt:end] = i # for cont batch @@ -179,19 +173,27 @@ def generate_continuous_batched_examples(example_lens_by_batch, IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] - yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], - cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + # varlen has implicit batch=1 + dt2 = dt2.squeeze(0) + X2 = X2.squeeze(0) + B2 = B2.squeeze(0) + C2 = C2.squeeze(0) + yield ( + [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)] + if return_naive_ref + else None, + cu_seqlens, + seq_idx, + (A, dt2, X2, B2, C2), + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) @pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) -def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, - itype): - - # this tests the kernels on a single example (no batching) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): + # this tests the kernels on a single example (bs=1) # TODO: the bfloat16 case requires 
higher thresholds. To be investigated @@ -207,31 +209,49 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, # it is not an operational limitation. seqlen, chunk_size = seq_len_chunk_size - A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, - B, C, chunk_size) + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, chunk_size + ) + + cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0) + cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) + # varlen has implicit batch=1 + X = X.squeeze(0) + dt = dt.squeeze(0) + A = A.squeeze(0) + B = B.squeeze(0) + C = C.squeeze(0) Y = torch.empty_like(X) - final_state = mamba_chunk_scan_combined(X, - dt, - A, - B, - C, - chunk_size, - D=None, - return_final_states=True, - out=Y) + final_state = mamba_chunk_scan_combined_varlen( + X, + dt, + A, + B, + C, + chunk_size, + cu_seqlens=cu_seqlens.to(torch.int32), + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunks, + out=Y, + D=None, + ) # just test the last in sequence - torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol) + torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol) # just test the last head # NOTE, in the kernel we always cast states to fp32 - torch.testing.assert_close(final_state[:, -1], - final_state_min[:, -1].to(torch.float32), - atol=atol, - rtol=rtol) + torch.testing.assert_close( + final_state[:, -1].to(torch.float32), + final_state_min[:, -1].to(torch.float32), + atol=atol, + rtol=rtol, + ) @pytest.mark.parametrize("itype", [torch.float32, torch.float16]) @@ -240,32 +260,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, @pytest.mark.parametrize( "seq_len_chunk_size_cases", [ - # small-ish chunk_size (8) (64, 8, 2, [(64, 32), (64, 32)]), (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary - (64, 8, 2, [(4, 4), (4, 4), (4, 4), - (4, 4)]), # chunk_size larger than cont batches - (64, 8, 5, [ - (64, 32, 16, 8, 8), - (8, 16, 32, 16, 8), - (8, 8, 16, 32, 16), - ]), # mode examples with varied lengths - + ( + 64, + 8, + 2, + [(4, 4), (4, 4), (4, 4), (4, 4)], + ), # chunk_size larger than cont batches + ( + 64, + 8, + 5, + [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ], + ), # mode examples with varied lengths # large-ish chunk_size (256) - (64, 256, 1, [(5, ), (1, ), (1, ), - (1, )]), # irregular sizes with small sequences - (64, 256, 2, [(5, 30), (1, 2), (1, 2), - (1, 2)]), # irregular sizes with small sequences - + (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences + ( + 64, + 256, + 2, + [(5, 30), (1, 2), (1, 2), (1, 2)], + ), # irregular sizes with small sequences # we also need to test some large seqlen # to catch errors with init states decay (768, 128, 2, [(138, 225), (138, 225)]), - ]) -def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, - itype): - + ], +) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype): # this test with multiple examples in a continuous batch # (i.e. 
chunked prefill) @@ -283,38 +311,40 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, exhausted: dict = {} # map: eg -> boolean indicating example is exhausted states = None - for Y_min, cu_seqlens, seq_idx, ( - A, dt, X, B, C) in generate_continuous_batched_examples( - cases, num_examples, seqlen, last_taken, exhausted, n_heads, - d_head, itype): - - chunk_indices, chunk_offsets = \ - _query_start_loc_to_chunk_indices_offsets( - cu_seqlens, chunk_size, cu_seqlens[-1]) + for Y_min, cu_seqlens, _token_seq_idx, ( + A, + dt, + X, + B, + C, + ) in generate_continuous_batched_examples( + cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype + ): + cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) Y = torch.empty_like(X) - new_states = mamba_chunk_scan_combined( + new_states = mamba_chunk_scan_combined_varlen( X, dt, A, B, C, chunk_size, - D=None, - cu_seqlens=cu_seqlens, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - return_varlen_states=True, - initial_states=states, + cu_seqlens=cu_seqlens.to(torch.int32), + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunks, out=Y, + D=None, + initial_states=states, ) # just test the last in sequence for i in range(num_examples): - # just test one dim and dstate - Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0] Y_min_eg = Y_min[i][:, 0, 0] torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) @@ -322,5 +352,232 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, states = new_states for i, clear in exhausted.items(): if clear: - states[i].fill_(0.) + states[i].fill_(0.0) exhausted[i] = False + + +@pytest.mark.parametrize("chunk_size", [8, 256]) +@pytest.mark.parametrize( + "seqlens", + [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), + ], +) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): + # This test verifies the correctness of the chunked prefill implementation + # in the mamba2 ssd kernels, by comparing concatenation (in the sequence + # dimension) of chunked results with the full sequence result. + # It is different from test_mamba_chunk_scan_cont_batch by: + # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get + # reference outputs. Instead, it compares chunked kernel outputs to full + # sequence kernel outputs. This is the most straightforward way to + # assert chunked prefill correctness. + # 2. It focuses on cases where sequences change in the middle of mamba + # chunks, and not necessarily on chunk boundaries. 
+ + max_seqlen = max(seqlens) + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + num_sequences = len(seqlens) + n_heads = 16 + d_head = 64 + itype = torch.float32 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples( + [seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False, + ) + ) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device + + ## full seqlen computation + cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined_varlen( + X, + dt, + A, + B, + C, + chunk_size, + cu_seqlens=cu_seqlens.to(torch.int32), + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunks, + out=Y_ref, + D=None, + initial_states=None, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat( + [torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0 + ) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...] + for i in range(num_sequences): + chunk_f = lambda x, i: x[ + cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ... + ] + + X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + X, i + ) + dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + dt, i + ) + B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + B, i + ) + C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + C, i + ) + + cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( + compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size) + ) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined_varlen( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + cu_seqlens=chunked_cu_seqlens.to(torch.int32), + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunks, + out=Y_partial, + D=None, + initial_states=None, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0), + ], + dim=0, + ) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] + remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] + remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] + remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[ + cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ... 
+ ] + + remaining_X_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(X, i) + remaining_dt_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(dt, i) + remaining_B_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(B, i) + remaining_C_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(C, i) + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat( + [ + pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...], + pt2[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], + ..., + ], + ], + dim=0, + ) + concat_batch_f = lambda pt1, pt2: torch.cat( + [concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0 + ) + + assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + + cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( + compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size) + ) + + Y_chunked = torch.empty_like(remaining_X_chunked) + state_chunked = mamba_chunk_scan_combined_varlen( + remaining_X_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size, + cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32), + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunks, + out=Y_chunked, + D=None, + initial_states=partial_state, + ) + Y = concat_batch_f(Y_partial, Y_chunked) + + # kernel chunked is same as kernel overall + for i in range(num_sequences): + Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...] 
+ torch.testing.assert_close( + Y_seq[: chunked_seqlens[i], ...], + Y_ref_seq[: chunked_seqlens[i], ...], + atol=atol, + rtol=rtol, + msg=lambda x, i=i: f"seq{i} output part1 " + x, + ) + torch.testing.assert_close( + Y_seq[chunked_seqlens[i] :, ...], + Y_ref_seq[chunked_seqlens[i] :, ...], + atol=atol, + rtol=rtol, + msg=lambda x, i=i: f"seq{i} output part2 " + x, + ) + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close( + state_seq, + state_seq_ref, + atol=atol, + rtol=rtol, + msg=lambda x, i=i: f"seq{i} state " + x, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py index b95d87cd04f57..d46847fbf6a3c 100644 --- a/tests/kernels/moe/modular_kernel_tools/cli_args.py +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -9,18 +9,19 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from .common import Config -from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +from .mk_objects import ( + MK_ALL_PREPARE_FINALIZE_TYPES, + MK_FUSED_EXPERT_TYPES, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, +) def make_config_arg_parser(description: str): - def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: for pf in MK_ALL_PREPARE_FINALIZE_TYPES: if pf.__name__ == s: return pf - raise ValueError( - f"Cannot find a PrepareFinalize type that matches {s}") + raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}") def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: for fe in MK_FUSED_EXPERT_TYPES: @@ -45,15 +46,18 @@ def make_config_arg_parser(description: str): "--pf-type", type=to_pf_class_type, required=True, - help=("Choose a PrepareFinalize Type : " - f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"), + help=( + "Choose a PrepareFinalize Type : " + f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}" + ), ) parser.add_argument( "--experts-type", type=to_experts_class_type, required=True, - help=(f"Choose a FusedExpert type : " - f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"), + help=( + f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}" + ), ) parser.add_argument( "-m", @@ -74,66 +78,65 @@ def make_config_arg_parser(description: str): default=1024, help="N dimension of the first fused-moe matmul", ) - parser.add_argument("--num-experts", - type=int, - default=32, - help="Global num experts") - parser.add_argument("--topk", - nargs="+", - type=int, - default=[4, 1], - help="num topk") + parser.add_argument( + "--num-experts", type=int, default=32, help="Global num experts" + ) + parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk") parser.add_argument( "--fused-moe-chunk-size", type=int, - help="Fused moe chunk size used for the non-batched fused experts impl." 
+ help="Fused moe chunk size used for the non-batched fused experts impl.", ) # Quant args - parser.add_argument("--quant-dtype", - type=to_quant_torch_dtype, - help="Quant datatype") - parser.add_argument("--per-token-quantized-activations", - action='store_true', - help=("The input activations must be per-token " - "quantized")) - parser.add_argument("--per-channel-quantized-weights", - action="store_true", - help="The weights must be per-channel quantized.") - parser.add_argument("--block-shape", - nargs="+", - type=int, - help="Quantization block shape") + parser.add_argument( + "--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype" + ) + parser.add_argument( + "--per-token-quantized-activations", + action="store_true", + help=("The input activations must be per-token quantized"), + ) + parser.add_argument( + "--per-channel-quantized-weights", + action="store_true", + help="The weights must be per-channel quantized.", + ) + parser.add_argument( + "--block-shape", nargs="+", type=int, help="Quantization block shape" + ) # Torch trace profile generation args - parser.add_argument("--torch-trace-dir-path", - type=str, - default=None, - help="Get torch trace for single execution") + parser.add_argument( + "--torch-trace-dir-path", + type=str, + default=None, + help="Get torch trace for single execution", + ) return parser def _validate_args(args: argparse.Namespace): - if args.quant_dtype is not None: assert args.quant_dtype == torch.float8_e4m3fn if args.block_shape is not None: assert len(args.block_shape) == 2, ( - f"block shape must have 2 elements. got {args.block_shape}") + f"block shape must have 2 elements. got {args.block_shape}" + ) if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: - assert args.world_size == 1, ( - "Single GPU objects need world size set to 1") + assert args.world_size == 1, "Single GPU objects need world size set to 1" if args.torch_trace_dir_path is not None: from pathlib import Path + assert Path(args.torch_trace_dir_path).is_dir(), ( - f"Please create {args.torch_trace_dir_path}") + f"Please create {args.torch_trace_dir_path}" + ) def make_config(args: argparse.Namespace) -> Config: - _validate_args(args) quant_config = None @@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config: quant_dtype=args.quant_dtype, per_act_token_quant=args.per_token_quantized_activations, per_out_ch_quant=args.per_channel_quantized_weights, - block_shape=args.block_shape) + block_shape=args.block_shape, + ) return Config( Ms=args.m, @@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config: fused_experts_type=args.experts_type, fused_moe_chunk_size=args.fused_moe_chunk_size, world_size=args.world_size, - torch_trace_dir_path=args.torch_trace_dir_path) + torch_trace_dir_path=args.torch_trace_dir_path, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index a10666b6ec9a7..ff12d1fb9a805 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -8,20 +8,30 @@ import torch import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8 -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils 
import torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx -from .mk_objects import (expert_info, make_fused_experts, - make_prepare_finalize, prepare_finalize_info) +from .mk_objects import ( + TestMoEQuantConfig, + expert_info, + make_fused_experts, + make_prepare_finalize, + prepare_finalize_info, +) from .parallel_utils import ProcessGroupInfo @@ -40,7 +50,7 @@ class Config: E: int topks: Union[list[int], int] dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] + quant_config: Optional[TestMoEQuantConfig] prepare_finalize_type: mk.FusedMoEPrepareAndFinalize fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute @@ -52,7 +62,7 @@ class Config: def __post_init__(self): if self.quant_config is None: - self.quant_config = FusedMoEQuantConfig() + self.quant_config = TestMoEQuantConfig(None, False, False, None) def describe(self) -> str: s = "" @@ -94,8 +104,7 @@ class Config: @property def is_per_tensor_act_quant(self) -> bool: - return (not self.is_per_act_token_quant - and self.quant_block_shape is None) + return not self.is_per_act_token_quant and self.quant_block_shape is None @property def is_per_out_ch_quant(self) -> bool: @@ -134,23 +143,24 @@ class Config: if self.fused_moe_chunk_size is not None: env_dict.update( - {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)} + ) return vllm_config, env_dict def is_fp8_block_quantized(self): - return (self.quant_dtype == torch.float8_e4m3fn - and self.quant_block_shape is not None) + return ( + self.quant_dtype == torch.float8_e4m3fn + and self.quant_block_shape is not None + ) def is_batched_prepare_finalize(self): info = prepare_finalize_info(self.prepare_finalize_type) - return (mk.FusedMoEActivationFormat.BatchedExperts == - info.activation_format) + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format def is_batched_fused_experts(self): info = expert_info(self.fused_experts_type) - return (mk.FusedMoEActivationFormat.BatchedExperts == - info.activation_format) + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format def is_standard_fused_experts(self): info = expert_info(self.fused_experts_type) @@ -190,63 +200,80 @@ class Config: def needs_deep_ep(self): info = prepare_finalize_info(self.prepare_finalize_type) - return (info.backend == "deepep_high_throughput" - or info.backend == "deepep_low_latency") + return ( + info.backend == "deepep_high_throughput" + or info.backend == "deepep_low_latency" + ) def all2all_backend(self): info = prepare_finalize_info(self.prepare_finalize_type) return info.backend - def is_valid(self): + def is_valid(self) -> tuple[bool, Optional[str]]: # Check prepare-finalize and fused-experts compatibility if self.is_batched_prepare_finalize(): if not self.is_batched_fused_experts(): - return False + return False, "Mismatched format." else: if not self.is_standard_fused_experts(): - return False + return False, "Mismatched format." 
use_chunking = self.fused_moe_chunk_size is not None if use_chunking and not self.is_fe_supports_chunking(): - return False + return False, "Chunking not supported." # Check quantization sanity - if (int(self.is_per_act_token_quant) + - int(self.is_per_tensor_act_quant) + - int(self.quant_block_shape is not None)) > 1: + if ( + int(self.is_per_act_token_quant) + + int(self.is_per_tensor_act_quant) + + int(self.quant_block_shape is not None) + ) > 1: # invalid quant config - return False + return False, f"Bad quant_config {self.quant_config}." # check type support if self.quant_dtype is None: - if (self.dtype not in self.pf_supported_types() - or self.dtype not in self.fe_supported_types()): - return False + if ( + self.dtype not in self.pf_supported_types() + or self.dtype not in self.fe_supported_types() + ): + return False, ( + f"Unsupported type {self.dtype} not in " + f"{self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) else: - if (self.quant_dtype not in self.pf_supported_types() - or self.quant_dtype not in self.fe_supported_types()): - return False + if ( + self.quant_dtype not in self.pf_supported_types() + or self.quant_dtype not in self.fe_supported_types() + ): + return False, ( + f"Unsupported quant type {self.quant_dtype} " + f"not in {self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) # Check block quanization support is_block_quatized = self.quant_block_shape is not None if is_block_quatized and self.quant_dtype is None: - return False + return False, "No block quantization support." + if is_block_quatized and not self.is_block_quant_supported(): - return False + return False, "Mismatched block quantization support." # deep_gemm only works with block-quantized if self.needs_deep_gemm() and not is_block_quatized: - return False + return False, "Needs DeepGEMM but not block quantized." # Check dependencies (turn into asserts?) if self.needs_deep_ep() and not has_deep_ep(): - return False + return False, "Needs DeepEP, but DeepEP not available." if self.needs_deep_gemm() and not has_deep_gemm(): - return False + return False, "Needs DeepGEMM, but DeepGEMM not available." if self.needs_pplx() and not has_pplx(): # noqa: SIM103 - return False + return False, "Needs PPLX, but PPLX not available." - return True + return True, None @dataclass @@ -261,56 +288,46 @@ class WeightTensors: def describe(self): s = "" s += "== Weight Tensors: \n" - s += f' - {_describe_tensor(self.w1, "w1")} \n' - s += f' - {_describe_tensor(self.w2, "w2")} \n' - s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' - s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' - s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n' - s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n' + s += f" - {_describe_tensor(self.w1, 'w1')} \n" + s += f" - {_describe_tensor(self.w2, 'w2')} \n" + s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n" + s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n" + s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n" + s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n" return s def is_quantized(self) -> bool: # or w1_scale is not None? 
- return (self.w1.dtype == torch.float8_e4m3fn - or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) + return ( + self.w1.dtype == torch.float8_e4m3fn + or self.w1.dtype == torch.uint8 + or self.w1.dtype == torch.int8 + ) def to_current_device(self): - self.w1 = self.w1.to(device=torch.cuda.current_device()) - self.w2 = self.w2.to(device=torch.cuda.current_device()) + device = torch.cuda.current_device() + self.w1 = self.w1.to(device=device) + self.w2 = self.w2.to(device=device) - if self.is_quantized(): - assert self.w1_scale is not None - assert self.w2_scale is not None - self.w1_scale = self.w1_scale.to( - device=torch.cuda.current_device()) - self.w2_scale = self.w2_scale.to( - device=torch.cuda.current_device()) + if self.w1_scale is not None: + self.w1_scale = self.w1_scale.to(device=device) + if self.w2_scale is not None: + self.w2_scale = self.w2_scale.to(device=device) if self.w1_gs is not None: - assert self.w2_gs is not None - self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) - self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + self.w1_gs = self.w1_gs.to(device=device) + if self.w2_gs is not None: + self.w2_gs = self.w2_gs.to(device=device) - def slice_weights(self, rank: int, - num_local_experts: int) -> "WeightTensors": + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - - w1_scale, w2_scale = (None, None) - if self.is_quantized(): - assert self.w1_scale is not None - assert self.w2_scale is not None - w1_scale = self.w1_scale[s:e, :, :] - w2_scale = self.w2_scale[s:e, :, :] - - w1_gs = self.w1_gs - w2_gs = self.w2_gs - if w1_gs is not None: - assert w2_gs is not None - w1_gs = w1_gs[s:e] - w2_gs = w2_gs[s:e] + w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None + w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None + w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None + w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @@ -323,14 +340,12 @@ class WeightTensors: in_dtype=config.dtype, quant_dtype=config.quant_dtype, block_shape=config.quant_block_shape, - per_act_token_quant=config.is_per_out_ch_quant, + # or config.is_per_out_ch_quant + per_out_ch_quant=config.is_per_act_token_quant, + ) + return WeightTensors( + w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_gs=w1_gs, - w2_gs=w2_gs) @dataclass @@ -342,27 +357,25 @@ class RankTensors: topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] - quant_config: Optional[FusedMoEQuantConfig] - def describe(self): s = "" s += "== Rank Tensors: \n" - s += f' - {_describe_tensor(self.hidden_states, "HS")} \n' - s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n' - s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n' - s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n' - s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n' + s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n" + s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n" + s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n" + s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n" + s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n" return s @staticmethod 
def make_hidden_states( - config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + config: Config, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Return hidden_states """ m, k, dtype = (config.M, config.K, config.dtype) - a = (torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0) + a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0 if config.quant_dtype is None: return a, None @@ -373,36 +386,29 @@ class RankTensors: # first - so further quantize and dequantize will yield the same # values. if config.is_per_tensor_act_quant: - a_q, a_scales = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=False) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False) return a_q.float().mul(a_scales).to(dtype), a_scales if config.is_per_act_token_quant: - a_q, a_scales = ops.scaled_fp8_quant(a, - use_per_token_if_dynamic=True) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True) return a_q.float().mul(a_scales).to(dtype), None assert config.quant_block_shape is not None block_k = config.quant_block_shape[1] a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k) - return a_q.float().view( - (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None + return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to( + dtype + ), None @staticmethod def make(config: Config, pgi: ProcessGroupInfo): - dtype = config.dtype topk, m, _ = (config.topk, config.M, config.K) - hidden_states, hidden_states_scale = RankTensors.make_hidden_states( - config) + hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config) - num_local_experts, global_num_experts = (config.num_local_experts, - config.E) - score = torch.randn((m, global_num_experts), - device="cuda", - dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, - False) + num_local_experts, global_num_experts = (config.num_local_experts, config.E) + score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) # distribute topk_ids evenly for mi in range(m): @@ -411,14 +417,15 @@ class RankTensors: expert_map = None if config.world_size > 1 and config.supports_expert_map(): - expert_map = torch.full((global_num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full( + (global_num_experts,), fill_value=-1, dtype=torch.int32 + ) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - expert_map = expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + expert_map = expert_map.to( + device=torch.cuda.current_device(), dtype=torch.int32 + ) return RankTensors( hidden_states=hidden_states, @@ -426,13 +433,12 @@ class RankTensors: topk_weights=topk_weights, topk_ids=topk_ids, expert_map=expert_map, - quant_config=config.quant_config, ) -def reference_moe_impl(config: Config, weights: WeightTensors, - rank_tensors: RankTensors) -> torch.Tensor: - +def reference_moe_impl( + config: Config, weights: WeightTensors, rank_tensors: RankTensors +) -> torch.Tensor: if config.quant_dtype == "nvfp4": quant_blocksize = 16 dtype = config.dtype @@ -445,8 +451,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors, w2_blockscale = weights.w2_scale w2_gs = weights.w2_gs - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax( - rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32) 
+ a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) + / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1) + ).to(torch.float32) assert w1_gs is not None assert w2_gs is not None @@ -459,14 +467,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors, assert w2_blockscale.shape[2] % 4 == 0 a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( - rank_tensors.hidden_states, a_global_scale) + rank_tensors.hidden_states, a_global_scale + ) - a = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=dtype, - device=a_fp4.device, - block_size=quant_blocksize) + a = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=dtype, + device=a_fp4.device, + block_size=quant_blocksize, + ) e = w1_q.shape[0] n = w1_q.shape[1] // 2 @@ -476,18 +487,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors, w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) a_scale = None w1_scale = None w2_scale = None @@ -505,34 +520,42 @@ def reference_moe_impl(config: Config, weights: WeightTensors, per_act_token_quant = config.is_per_act_token_quant block_shape = config.quant_block_shape - return torch_experts(a=a, - w1=w1, - w2=w2, - topk_weight=rank_tensors.topk_weights, - topk_ids=rank_tensors.topk_ids, - global_num_experts=config.E, - expert_map=None, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - apply_router_weights_on_input=config.topk == 1 - and config.supports_apply_weight_on_input()) + return torch_experts( + a=a, + w1=w1, + w2=w2, + topk_weight=rank_tensors.topk_weights, + topk_ids=rank_tensors.topk_ids, + global_num_experts=config.E, + expert_map=None, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + apply_router_weights_on_input=config.topk == 1 + and config.supports_apply_weight_on_input(), + ) + + +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones( + (num_experts,), device=torch.cuda.current_device(), dtype=torch.float32 + ) def make_modular_kernel( config: Config, vllm_config: VllmConfig, - weights: WeightTensors, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: - def next_power_of_2(x): import math + if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return 2 ** math.ceil(math.log2(x)) # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( @@ -548,24 +571,25 @@ def make_modular_kernel( num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, - quant_config=config.quant_config, max_num_tokens=next_power_of_2(config.M), ) # make modular kernel - prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, - 
config.all2all_backend(), moe) + prepare_finalize = make_prepare_finalize( + config.prepare_finalize_type, config.all2all_backend(), moe, quant_config + ) fused_experts = make_fused_experts( config.fused_experts_type, moe, + quant_config, prepare_finalize.num_dispatchers(), - weights.w1_gs, - weights.w2_gs, + config.N, ) modular_kernel = mk.FusedMoEModularKernel( - prepare_finalize=prepare_finalize, fused_experts=fused_experts) + prepare_finalize=prepare_finalize, fused_experts=fused_experts + ) return modular_kernel @@ -583,44 +607,54 @@ def run_modular_kernel( # weights for rank rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config, weights) + if config.quant_dtype == "nvfp4": + gscale = _make_gscale(config.num_local_experts) + else: + gscale = None + + quant_config = FusedMoEQuantConfig.make( + config.quant_dtype, + w1_scale=rank_weights.w1_scale, + w2_scale=rank_weights.w2_scale, + a1_scale=rank_tensors.hidden_states_scale, + g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None, + g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None, + a1_gscale=gscale, + a2_gscale=gscale, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_act_token_quant, + per_out_ch_quant=config.is_per_out_ch_quant, + ) + + mk = make_modular_kernel(config, vllm_config, quant_config) + + # impls might update the tensor in place + hidden_states = rank_tensors.hidden_states.clone() + + topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()) mk_kwargs = { - "hidden_states": - rank_tensors.hidden_states.clone( - ), # impls might update the tensor in place - "w1": - rank_weights.w1, - "w2": - rank_weights.w2, - "topk_weights": - rank_tensors.topk_weights, - "topk_ids": - rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), - "expert_map": - rank_tensors.expert_map, - "w1_scale": - rank_weights.w1_scale, - "w2_scale": - rank_weights.w2_scale, - "a1_scale": - rank_tensors.hidden_states_scale, - "global_num_experts": - config.E, - "apply_router_weight_on_input": - config.topk == 1 and config.supports_apply_weight_on_input(), + "hidden_states": hidden_states, + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": topk_ids, + "expert_map": rank_tensors.expert_map, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1 + and config.supports_apply_weight_on_input(), } num_tokens = rank_tensors.hidden_states.shape[0] - num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size, - device="cuda", - dtype=torch.int) + num_tokens_across_dp = torch.tensor( + [num_tokens] * config.world_size, device="cuda", dtype=torch.int + ) with set_forward_context( - None, - vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, + None, + vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, ): out = mk.forward(**mk_kwargs) diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 5dbfdfc153f9f..7d555202afe6a 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -10,13 +10,21 @@ import torch from tqdm import tqdm from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig 
+from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG from vllm.platforms import current_platform -from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, - run_modular_kernel) -from .mk_objects import (MK_FUSED_EXPERT_TYPES, - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS) +from .common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) +from .mk_objects import ( + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, +) from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config @@ -37,8 +45,9 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -59,8 +68,7 @@ def rank_worker( rank_tensors = RankTensors.make(cfgx, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) with set_current_vllm_config(vllm_config): ref_out = reference_moe_impl(cfgx, weights, rank_tensors) @@ -69,28 +77,27 @@ def rank_worker( def make_feature_matrix(csv_file_path: str): - from dataclasses import asdict import pandas as pd - def add_to_results(config: Config, - success: Result, - results_df: Optional[pd.DataFrame] = None): + def add_to_results( + config: Config, success: Result, results_df: Optional[pd.DataFrame] = None + ): config_dict = asdict(config) - config_dict['prepare_finalize_type'] = config_dict[ - 'prepare_finalize_type'].__name__ - config_dict['fused_experts_type'] = config_dict[ - 'fused_experts_type'].__name__ - config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant - quant_config_dict = config_dict['quant_config'] - del config_dict['quant_config'] + config_dict["prepare_finalize_type"] = config_dict[ + "prepare_finalize_type" + ].__name__ + config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__ + config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant + quant_config_dict = config_dict["quant_config"] + del config_dict["quant_config"] if quant_config_dict is None: - quant_config = FusedMoEQuantConfig(None) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config_dict = asdict(quant_config) config_dict |= quant_config_dict - result_dict = config_dict | {'success': success.name} + result_dict = config_dict | {"success": success.name} result_df = pd.DataFrame([result_dict]) if results_df is None: @@ -111,32 +118,41 @@ def make_feature_matrix(csv_file_path: str): Q_TYPES = MK_QUANT_CONFIGS combinations = list( - product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)) + product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES) + ) results_df: Optional[pd.DataFrame] = None for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm( - combinations): #noqa: E501 - config = Config(Ms=[m], - K=k, - N=n, - E=e, - topks=topks, - dtype=dtype, - prepare_finalize_type=pf_type, - fused_experts_type=experts_type, - quant_config=quant_config, - world_size=2, - fused_moe_chunk_size=None) + combinations + ): + config = Config( + Ms=[m], + K=k, + N=n, + E=e, + topks=topks, + dtype=dtype, + prepare_finalize_type=pf_type, + fused_experts_type=experts_type, + quant_config=quant_config, + world_size=2, + 
fused_moe_chunk_size=None, + ) success = None - if config.is_valid(): + if config.is_valid()[0]: print(f"Running config : {config.describe()} ...") try: weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, - vllm_config, env_dict, config, - weights) + parallel_launch_with_config( + config.world_size, + rank_worker, + vllm_config, + env_dict, + config, + weights, + ) success = Result.PASS except Exception as _: success = Result.FAIL @@ -149,25 +165,33 @@ def make_feature_matrix(csv_file_path: str): results_df.to_csv(f"{csv_file_path}") -if __name__ == '__main__': +if __name__ == "__main__": import argparse from pathlib import Path - parser = argparse.ArgumentParser(description=( - "Make ModularKernel feature matrix \n" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501 - "-f ./feature_matrices/feature_matrix.csv")) - parser.add_argument("-f", - "--feature-matrix-csv-file-path", - type=str, - required=True, - help="File name to Generate a .csv file") + parser = argparse.ArgumentParser( + description=( + "Make ModularKernel feature matrix \n" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501 + "-f ./feature_matrices/feature_matrix.csv" + ) + ) + + parser.add_argument( + "-f", + "--feature-matrix-csv-file-path", + type=str, + required=True, + help="File name to Generate a .csv file", + ) args = parser.parse_args() csv_path = args.feature_matrix_csv_file_path - assert csv_path.endswith( - 'csv'), f"Need a file path ending with .csv, got {csv_path}" - assert Path(csv_path).parent.is_dir( - ), f"Cannot find parent directory for {Path(csv_path).parent}" + assert csv_path.endswith("csv"), ( + f"Need a file path ending with .csv, got {csv_path}" + ) + assert Path(csv_path).parent.is_dir(), ( + f"Cannot find parent directory for {Path(csv_path).parent}" + ) make_feature_matrix(args.feature_matrix_csv_file_path) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index aecffae36ae5e..174b2d1781ae0 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -8,30 +8,47 @@ import torch # Fused experts and PrepareFinalize imports import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) + BatchedDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( + BatchedTritonOrDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - TritonExperts) + BatchedTritonExperts, + NaiveBatchedExperts, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - 
MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported) + cutlass_fp8_supported, +) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +@dataclass +class TestMoEQuantConfig: + quant_dtype: Union[torch.dtype, str, None] + per_out_ch_quant: bool + per_act_token_quant: bool + block_shape: Optional[list[int]] + + @dataclass class PrepareFinalizeInfo: activation_format: mk.FusedMoEActivationFormat @@ -52,8 +69,7 @@ class ExpertInfo: needs_deep_gemm: bool = False -PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, - PrepareFinalizeInfo] = {} +PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {} EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] @@ -63,10 +79,13 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = [] standard_format = mk.FusedMoEActivationFormat.Standard batched_format = mk.FusedMoEActivationFormat.BatchedExperts common_float_types: list[Union[torch.dtype, str]] = [ - torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 + torch.float8_e4m3fn, + torch.bfloat16, + torch.float16, + torch.float32, ] common_float_and_int_types = common_float_types + [torch.int8] -nv_fp4_types = ["nvfp4"] +nvfp4_types = ["nvfp4"] fp8_types = [torch.float8_e4m3fn] @@ -177,10 +196,12 @@ register_experts( # Disable on blackwell for now if has_deep_ep() and not current_platform.has_device_capability(100): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) register_prepare_and_finalize( DeepEPHTPrepareAndFinalize, @@ -200,7 +221,9 @@ if has_deep_ep() and not current_platform.has_device_capability(100): if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) + register_prepare_and_finalize( PplxPrepareAndFinalize, batched_format, @@ -209,17 +232,19 @@ if has_pplx(): backend="pplx", ) -if (has_flashinfer_cutlass_fused_moe() - and current_platform.has_device_capability(100)): - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - FlashInferExperts) +if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + FlashInferCutlassMoEPrepareAndFinalize, + 
create_flashinfer_prepare_finalize, + ) register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, standard_format, - nv_fp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, backend=None, force_multigpu=True, @@ -229,7 +254,7 @@ if (has_flashinfer_cutlass_fused_moe() register_experts( FlashInferExperts, standard_format, - nv_fp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, supports_chunking=True, # Note: this is a hack to get it to run for now @@ -258,7 +283,7 @@ if has_deep_gemm() and is_deep_gemm_supported(): supports_expert_map=True, needs_matching_quant=False, needs_deep_gemm=True, - ), + ) register_experts( BatchedTritonOrDeepGemmExperts, batched_format, @@ -281,8 +306,11 @@ if has_deep_gemm() and is_deep_gemm_supported(): ) if cutlass_fp8_supported(): - from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8, - CutlassExpertsFp8) + from vllm.model_executor.layers.fused_moe import ( + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + ) + register_experts( CutlassExpertsFp8, standard_format, @@ -301,44 +329,54 @@ if cutlass_fp8_supported(): ) if cutlass_fp4_supported(): - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp4) + from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4 + register_experts( CutlassExpertsFp4, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, supports_chunking=True, supports_expert_map=False, ) -MK_QUANT_CONFIGS = [ +MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [ None, # per-channel / per-column weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None, + ), # per-channel / per-column weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None, + ), # per-tensor weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None, + ), # per-tensor weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None, + ), # block-quantized weights and 128 block per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=[128, 128]), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128], + ), # TODO (varun) : Should we test the following combinations ? 
# block-quantized weights and per-token activations # block-quantized weights and per-tensor activations @@ -346,32 +384,30 @@ MK_QUANT_CONFIGS = [ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): MK_QUANT_CONFIGS += [ - FusedMoEQuantConfig(quant_dtype="nvfp4", - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None, + ), ] -def _make_gscale(num_experts: int) -> torch.Tensor: - return torch.ones((num_experts, ), - device=torch.cuda.current_device(), - dtype=torch.float32) - - def make_prepare_finalize( prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, backend: Optional[str], moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( + moe, quant_config + ) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: - return FlashInferCutlassMoEPrepareAndFinalize( - use_dp=moe.moe_parallel_config.dp_size > 1, - a1_gscale=_make_gscale(moe.num_local_experts), + return create_flashinfer_prepare_finalize( + use_dp=moe.moe_parallel_config.dp_size > 1 ) else: return MoEPrepareAndFinalizeNoEP() @@ -383,34 +419,38 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: return t[s:e] +def make_cutlass_strides( + e: int, + n: int, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) + return ab_strides1, ab_strides2, c_strides1, c_strides2 + + def make_fused_experts( fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, num_dispatchers: int, - w1_gs: Optional[torch.Tensor], - w2_gs: Optional[torch.Tensor], + N: int, ) -> mk.FusedMoEPermuteExpertsUnpermute: - - use_fp8 = moe.quant_dtype == torch.float8_e4m3fn batch_kwargs = { "max_num_tokens": moe.max_num_tokens, "num_dispatchers": num_dispatchers, } quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - "use_int4_w4a16": False, - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, + "quant_config": quant_config, } deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000) + if fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, - } + kwargs = batch_kwargs | quant_kwargs print(f"Making BatchedDeepGemmExperts {kwargs} ...") experts = BatchedDeepGemmExperts(**kwargs) elif fused_experts_type == BatchedTritonExperts: @@ -422,8 +462,8 @@ def make_fused_experts( print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() + print(f"Making DeepGemmExperts {quant_config} ...") + experts = DeepGemmExperts(quant_config) elif 
fused_experts_type == TritonExperts: kwargs = quant_kwargs print(f"Making TritonExperts {kwargs} ...") @@ -437,62 +477,50 @@ def make_fused_experts( print(f"Making NaiveBatchedExperts {kwargs} ...") experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == CutlassExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassExpertsFp8 {kwargs} ...") experts = CutlassExpertsFp8(**kwargs) elif fused_experts_type == CutlassBatchedExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") experts = CutlassBatchedExpertsFp8(**kwargs) elif fused_experts_type == CutlassExpertsFp4: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": _make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, + "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, - } + "out_dtype": moe.in_dtype, + } | quant_kwargs print(f"Making CutlassExpertsFp4 {kwargs} ...") experts = CutlassExpertsFp4(**kwargs) elif fused_experts_type == FlashInferExperts: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": _make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), "out_dtype": moe.in_dtype, - "quant_dtype": "nvfp4", "ep_rank": moe.ep_rank, "ep_size": moe.ep_size, "tp_rank": moe.tp_rank, "tp_size": moe.tp_size, - } + } | quant_kwargs print(f"Making FlashInferExperts {kwargs} ...") experts = FlashInferExperts(**kwargs) else: raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80) + return experts diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 459b785e6504e..7802129d3d48f 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -6,13 +6,11 @@ import traceback from typing import Any, Callable, Optional import torch -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import 
Concatenate, ParamSpec from vllm.config import VllmConfig, set_current_vllm_config -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed import init_distributed_environment, initialize_model_parallel from vllm.utils import get_open_port ## Parallel Processes Utils @@ -30,10 +28,11 @@ class ProcessGroupInfo: device: torch.device -def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, - local_rank: int): - +def _set_vllm_config( + vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int +): import tempfile + temp_file = tempfile.mkstemp()[1] with set_current_vllm_config(vllm_config): @@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, ) initialize_model_parallel( - tensor_model_parallel_size=vllm_config.parallel_config. - tensor_parallel_size, - pipeline_model_parallel_size=vllm_config.parallel_config. - pipeline_parallel_size, + tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, ) - cpu_group = torch.distributed.new_group(list(range(world_size)), - backend="gloo") + cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") return cpu_group @@ -62,8 +58,7 @@ def _worker_parallel_launch( world_local_size: int, node_rank: int, init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, - P], None], + worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None], vllm_config: Optional[VllmConfig], env_dict: Optional[dict], *args: P.args, @@ -131,7 +126,8 @@ def parallel_launch_with_config( worker, vllm_config, env_dict, - ) + args, + ) + + args, nprocs=world_size, join=True, ) diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index 0da6ee3543521..48e5c4659b49a 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -14,28 +14,31 @@ from .common import Config, RankTensors, WeightTensors, make_modular_kernel from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config -def do_profile(fn: Callable, - fn_kwargs: dict[Any, Any], - pgi: ProcessGroupInfo, - config: Config, - num_warmups: int = 5): +def do_profile( + fn: Callable, + fn_kwargs: dict[Any, Any], + pgi: ProcessGroupInfo, + config: Config, + num_warmups: int = 5, +): for _ in range(num_warmups): fn(**fn_kwargs) with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - record_shapes=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=True, ) as tprof: fn(**fn_kwargs) torch.cuda.synchronize(torch.cuda.current_device()) # TODO (varun): Add a descriptive trace file name tprof.export_chrome_trace( - f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json") + f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json" + ) def profile_modular_kernel( @@ -82,6 +85,7 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -108,20 +112,25 @@ def rank_worker( def run(config: Config): weights: WeightTensors = 
WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights + ) -if __name__ == '__main__': +if __name__ == "__main__": from .cli_args import make_config, make_config_arg_parser - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() assert args.torch_trace_dir_path is not None, ( - "Please pass in a directory to store torch traces") + "Please pass in a directory to store torch traces" + ) config = make_config(args) run(config) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index 1ad361ae07333..fb9e5df281f1d 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -3,6 +3,7 @@ """ DeepEP test utilities """ + import dataclasses import os import traceback @@ -10,17 +11,18 @@ from typing import Callable, Optional import torch from torch.distributed import ProcessGroup -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec from vllm.utils import get_open_port, has_deep_ep if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) ## Parallel Processes Utils @@ -96,7 +98,8 @@ def parallel_launch( 0, f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", worker, - ) + args, + ) + + args, nprocs=world_size, join=True, ) @@ -118,48 +121,57 @@ class DeepEPLLArgs: use_fp8_dispatch: bool -def make_deepep_ht_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - ht_args: DeepEPHTArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - +def make_deepep_ht_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + ht_args: DeepEPHTArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # high throughput a2a num_nvl_bytes = 1024 * 1024 * 1024 # 1GB num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 - buffer = deep_ep.Buffer(group=pg, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=num_qps_per_rank) - return DeepEPHTPrepareAndFinalize(buffer=buffer, - num_dispatchers=pgi.world_size, - dp_size=dp_size, - rank_expert_offset=pgi.rank * - ht_args.num_local_experts) + 
buffer = deep_ep.Buffer( + group=pg, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank, + ) + return DeepEPHTPrepareAndFinalize( + buffer=buffer, + num_dispatchers=pgi.world_size, + dp_size=dp_size, + rank_expert_offset=pgi.rank * ht_args.num_local_experts, + ) -def make_deepep_ll_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - deepep_ll_args: DeepEPLLArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - +def make_deepep_ll_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + deepep_ll_args: DeepEPLLArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # low-latency a2a num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, - pgi.world_size, deepep_ll_args.num_experts) + deepep_ll_args.max_tokens_per_rank, + deepep_ll_args.hidden_size, + pgi.world_size, + deepep_ll_args.num_experts, + ) - buffer = deep_ep.Buffer(group=pg, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=deepep_ll_args.num_experts // - pgi.world_size) + buffer = deep_ep.Buffer( + group=pg, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size, + ) return DeepEPLLPrepareAndFinalize( buffer=buffer, @@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup, ) -def make_deepep_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - deepep_ht_args: Optional[DeepEPHTArgs], - deepep_ll_args: Optional[DeepEPLLArgs], - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): +def make_deepep_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ht_args: Optional[DeepEPHTArgs], + deepep_ll_args: Optional[DeepEPLLArgs], + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): if deepep_ht_args is not None: assert deepep_ll_args is None - return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, - block_shape) + return make_deepep_ht_a2a( + pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape + ) assert deepep_ll_args is not None return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape) diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 018d4c224f75e..59cecd60d3d61 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -5,11 +5,14 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) + BatchedPrepareAndFinalize, + BatchedTritonExperts, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from .test_deepgemm import make_block_quant_fp8_weights @@ -17,15 +20,15 @@ from .test_deepgemm import make_block_quant_fp8_weights BLOCK_SIZE = [128, 128] -@pytest.mark.skipif(not is_deep_gemm_supported(), - reason="Requires deep_gemm kernels") +@pytest.mark.skipif(not 
is_deep_gemm_supported(), reason="Requires deep_gemm kernels") @pytest.mark.parametrize("E", [16, 32]) # number of experts @pytest.mark.parametrize("T", [256, 512]) # tokens per expert @pytest.mark.parametrize("K", [128, 256]) # hidden dim @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("topk", [2, 4]) -def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, - monkeypatch): +def test_batched_deepgemm_vs_triton( + E: int, T: int, K: int, N: int, topk: int, monkeypatch +): """Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") @@ -56,13 +59,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, rank=0, ) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + per_act_token_quant=False, + block_shape=BLOCK_SIZE, + ) + # triton (reference) triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=True, - per_act_token_quant=False, - block_shape=BLOCK_SIZE, + quant_config=quant_config, ) mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts) @@ -73,8 +81,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) @@ -82,8 +88,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, deepgemm_experts = BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - block_shape=BLOCK_SIZE, - per_act_token_quant=False, + quant_config=quant_config, ) mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts) @@ -94,8 +99,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 00b2d780e66f5..09cede3fbcc77 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -7,14 +7,18 @@ from typing import Optional import pytest import torch -from tests.kernels.moe.utils import (batched_moe, - make_quantized_test_activations, - make_test_weights, naive_batched_moe) +from tests.kernels.moe.utils import ( + batched_moe, + make_quantized_test_activations, + make_test_weights, + naive_batched_moe, +) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel) + invoke_moe_batched_triton_kernel, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform from vllm.triton_utils import tl @@ -68,23 +72,32 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - A = torch.randn( - (config.num_experts, config.max_tokens_per_expert, config.K), + A = ( + torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.in_dtype, + ) + / 10 + ) + B = torch.randn( + (config.num_experts, config.N, config.K), device="cuda", - dtype=config.in_dtype) / 10 - B = torch.randn((config.num_experts, config.N, config.K), - device="cuda", - dtype=config.in_dtype) + dtype=config.in_dtype, + ) 
C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.out_dtype) + dtype=config.out_dtype, + ) - num_expert_tokens = torch.randint(low=0, - high=config.max_tokens_per_expert, - size=(config.num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts,), + device="cuda", + dtype=torch.int32, + ) return BatchedMMTensors(A, B, C, num_expert_tokens) @@ -96,10 +109,15 @@ class BatchedMMTensors: @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) -def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, - N: int, dtype: torch.dtype, - block_shape: Optional[list[int]], - per_act_token_quant: bool): +def test_batched_mm( + num_experts: int, + max_tokens_per_expert: int, + K: int, + N: int, + dtype: torch.dtype, + block_shape: Optional[list[int]], + per_act_token_quant: bool, +): current_platform.seed_everything(7) use_fp8_w8a8 = dtype == torch.float8_e4m3fn @@ -117,11 +135,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, act_dtype = dtype quant_dtype = None - num_expert_tokens = torch.randint(low=0, - high=max_tokens_per_expert, - size=(num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=max_tokens_per_expert, + size=(num_experts,), + device="cuda", + dtype=torch.int32, + ) A, A_q, A_scale = make_quantized_test_activations( num_experts, @@ -140,7 +160,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype=act_dtype, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) out_shape = (num_experts, max_tokens_per_expert, N) @@ -151,7 +171,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, compute_tl_dtype = { torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, - torch.float32: tl.float32 + torch.float32: tl.float32, }[test_output.dtype] assert A_q.dtype == B_q.dtype @@ -173,7 +193,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, config={ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 + "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32, }, per_act_token_quant=per_act_token_quant, block_shape=block_shape, @@ -186,11 +206,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, num_expert_tokens, ) - q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, - num_expert_tokens, - A_scale, B_scale, - block_shape, - per_act_token_quant) + q_ref_output = native_batched_masked_quant_matmul( + A_q, + B_q, + q_ref_output, + num_expert_tokens, + A_scale, + B_scale, + block_shape, + per_act_token_quant, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -250,7 +275,7 @@ def test_fused_moe_batched_experts( block_shape=block_shape, in_dtype=act_dtype, quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) if input_scales and quant_dtype is not None: @@ -308,12 +333,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(batched_output, - baseline_output, - atol=3e-2, - rtol=2e-2) + torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) - 
torch.testing.assert_close(triton_output, - batched_output, - atol=2e-2, - rtol=2e-2) + torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 9e4eaf221f245..b8cd3cb9200c9 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -4,19 +4,25 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) +from tests.kernels.moe.utils import make_test_quant_config, make_test_weights +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) + _valid_deep_gemm_shape, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -24,8 +30,7 @@ if dg_available: from deep_gemm import get_m_alignment_for_contiguous_layout if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -97,8 +102,7 @@ TOP_KS = [1, 2, 6] SEEDS = [0] -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, - block_shape): +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape topk = topk_ids.size(1) @@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) # Skip all tests if CUDA is not available @@ -149,8 +147,9 @@ def setup_cuda(): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, - monkeypatch): +def 
test_w8a8_block_fp8_fused_moe( + M, N, K, E, topk, block_size, dtype, seed, monkeypatch +): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") @@ -161,22 +160,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size, + ) - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=block_size) + m_fused_moe = modular_triton_fused_moe(quant_config) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) @@ -186,37 +180,21 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a, w1, w2, - w1_s, - w2_s, + quant_config.w1_scale, + quant_config.w2_scale, topk_weights, topk_ids, block_size, ) out = fused_experts( - a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, + a, w1, w2, topk_weights, topk_ids, quant_config=quant_config ) - m_out = m_fused_moe( - a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - ) + m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) - # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0] - tol = 0.035 if M < 40000 else 0.039 + # 0.039 only needed for M >= 8192 + tol = 0.035 if M < 8192 else 0.039 torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol) torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol) @@ -226,11 +204,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, - monkeypatch): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") @@ -249,50 +225,53 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_out_ch_quant=False, + block_shape=block_size, + ) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. 
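
A minimal sketch of the calling convention used in the test_w8a8_block_fp8_fused_moe hunk above, where fused_experts() takes a single quant_config instead of individual use_fp8_w8a8 / w1_scale / w2_scale / block_shape arguments. It assumes the make_test_quant_config helper behaves as it does in that test; the problem sizes are placeholders.

    import torch
    from tests.kernels.moe.utils import make_test_quant_config
    from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk

    M, N, K, E, topk = 32, 512, 256, 8, 2
    a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((M, E), device="cuda", dtype=torch.bfloat16)
    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # w1/w2 come back already quantized; their scales, block shape and
    # activation-quantization flags travel inside quant_config.
    w1, w2, quant_config = make_test_quant_config(
        E,
        N,
        K,
        torch.bfloat16,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=[128, 128],
    )
    out = fused_experts(a, w1, w2, topk_weights, topk_ids, quant_config=quant_config)
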
use_compile = False - use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = ( + chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike() + ) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids, block_size) + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size + ) if use_compile: - deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, - backend="inductor", - fullgraph=True) + deep_gemm_moe_fp8_fn = torch.compile( + deep_gemm_moe_fp8, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(topk_weights, 0) torch._dynamo.mark_dynamic(topk_ids, 0) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) if use_cudagraph: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 5e4a93963f8e8..74cc943714dd9 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -4,17 +4,18 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quant_utils import (native_per_token_group_quant_int8, - native_w8a8_block_matmul) +from tests.kernels.moe.utils import make_test_quant_config +from tests.kernels.quant_utils import ( + native_per_token_group_quant_int8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -50,7 +51,7 @@ MNK_FACTORS = [ (2048, 128, 128), (2048, 1024, 7168), (2048, 4096, 512), - (2048, 4096, 7168), + (2048, 4096, 4096), ] E = [8, 24] @@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_int8( - act_out, block_k) + act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k) act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - 
block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -117,32 +112,33 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.int8, + per_act_token_quant=False, + block_shape=block_size, + ) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - out = fused_moe( + out = fused_experts( + a, w1, w2, topk_weights, topk_ids, quant_config=quant_config + ) + ref_out = torch_w8a8_block_int8_moe( a, w1, w2, + quant_config.w1_scale, + quant_config.w2_scale, score, topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, + block_size, ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) # Check results torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065) diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 1768baaf1ca71..996a4538d1054 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens @dataclasses.dataclass class TestTensors: - topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] = None @@ -25,32 +24,31 @@ class TestTensors: self.expert_map = self.expert_map.to(device=device) @staticmethod - def make(num_tokens: int, num_topk: int, num_experts: int, device: str, - topk_ids_dtype: torch.dtype) -> "TestTensors": - + def make( + num_tokens: int, + num_topk: int, + num_experts: int, + device: str, + topk_ids_dtype: torch.dtype, + ) -> "TestTensors": # make topk ids - topk_ids = torch.empty((num_tokens, num_topk), - device=device, - dtype=torch.int64) + topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64) for x in range(num_tokens): topk_ids[x] = torch.randperm(num_experts)[:num_topk] topk_ids = topk_ids.to(dtype=torch.int64) return TestTensors(topk_ids=topk_ids) - def with_ep_rank(self, ep_rank: int, num_global_experts: int, - num_local_experts: int, device: str): + def with_ep_rank( + self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str + ): # make an expert map - expert_map = torch.empty((num_global_experts), - device=device, - dtype=torch.int32) + expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32) expert_map.fill_(-1) s = ep_rank * num_local_experts e = s + num_local_experts - expert_map[s:e] = torch.tensor(list(range(num_local_experts)), - device=device) + expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device) - return TestTensors(topk_ids=self.topk_ids.clone(), - expert_map=expert_map) + return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map) def 
ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): @@ -68,49 +66,49 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): expert_num_tokens[eid] += count -def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - +def do_test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): assert num_topk <= num_experts - tt = TestTensors.make(num_tokens, - num_topk, - num_experts, - topk_ids_dtype=topk_ids_dtype, - device="cpu") + tt = TestTensors.make( + num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu" + ) num_global_experts = num_experts assert num_global_experts % ep_size == 0 num_local_experts = num_global_experts // ep_size for ep_rank in range(ep_size): - tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, - num_local_experts, "cpu") + tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu") - ref_expert_num_tokens = torch.zeros((num_local_experts), - device="cpu", - dtype=torch.int32) + ref_expert_num_tokens = torch.zeros( + (num_local_experts), device="cpu", dtype=torch.int32 + ) ref_impl(tt_rank, ref_expert_num_tokens) ref_expert_num_tokens = ref_expert_num_tokens.to("cuda") tt_rank.to_device("cuda") # Test with expert_map triton_expert_num_tokens_w_emap = count_expert_num_tokens( - tt_rank.topk_ids, num_local_experts, tt_rank.expert_map) + tt_rank.topk_ids, num_local_experts, tt_rank.expert_map + ) # Test without expert map topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype) triton_expert_num_tokens_wo_emap = count_expert_num_tokens( - topk_ids, num_local_experts, expert_map=None) + topk_ids, num_local_experts, expert_map=None + ) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_w_emap, - atol=0, - rtol=0) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_wo_emap, - atol=0, - rtol=0) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0 + ) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0 + ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317]) @@ -118,22 +116,29 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, @pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts, - ep_size, topk_ids_dtype) +def test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): + do_test_compute_expert_num_tokens( + num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype + ) @pytest.mark.parametrize("numel", list(range(1, 8192, 111))) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("ep_size", [2]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int, - ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens=numel, - num_topk=1, - num_experts=num_experts, - ep_size=ep_size, - topk_ids_dtype=topk_ids_dtype) +def 
test_compute_expert_num_tokens_from_numel( + numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype +): + do_test_compute_expert_num_tokens( + num_tokens=numel, + num_topk=1, + num_experts=num_experts, + ep_size=ep_size, + topk_ids_dtype=topk_ids_dtype, + ) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 3b1618dacac7b..4c60241bdb01c 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -17,19 +17,24 @@ from vllm.utils import cdiv from vllm.utils.deep_gemm import per_block_cast_to_fp8 -@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - (32, 1024, 7168, 4096), - (32, 1024, 2048, 7168), -]) +@pytest.mark.parametrize( + "num_groups, expected_m_per_group, k, n", + [ + (4, 8192, 7168, 4096), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + (32, 1024, 7168, 4096), + (32, 1024, 2048, 7168), + ], +) @pytest.mark.parametrize("out_dtype", [torch.float16]) @pytest.mark.skipif( (lambda x: x is None or x.to_int() != 100)( - current_platform.get_device_capability()), - reason="Block Scaled Grouped GEMM is only supported on SM100.") + current_platform.get_device_capability() + ), + reason="Block Scaled Grouped GEMM is only supported on SM100.", +) def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int, @@ -40,8 +45,7 @@ def test_cutlass_grouped_gemm( device = "cuda" alignment = 128 group_ms = [ - int(expected_m_per_group * random.uniform(0.7, 1.3)) - for _ in range(num_groups) + int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups) ] m = sum([cdiv(m, alignment) * alignment for m in group_ms]) @@ -58,20 +62,22 @@ def test_cutlass_grouped_gemm( expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, cdiv(n, 128), k // 128), - device=device, - dtype=torch.float)) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float + ), + ) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) for i in range(num_groups): - a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] - a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]] + a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]] + a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]] b = y_fp8[0][i].t() b_scale = y_fp8[1][i].t() baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) - ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline + ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline ops.cutlass_blockwise_scaled_grouped_mm( out, diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index c84f66383b902..b82cea61bd4ea 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import dataclasses from math import prod from typing import Optional @@ -9,12 +10,16 @@ import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + 
FUSED_MOE_UNQUANTIZED_CONFIG, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, run_cutlass_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + cutlass_moe_fp8, + run_cutlass_moe_fp8, +) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform NUM_EXPERTS = [40, 64] @@ -36,12 +41,11 @@ MNK_FACTORS = [ (224, 3072, 1536), (32768, 1024, 1024), # These sizes trigger wrong answers. - #(7232, 2048, 5120), - #(40000, 2048, 5120), + # (7232, 2048, 5120), + # (40000, 2048, 5120), ] -vllm_config = VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1)) +vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -57,22 +61,25 @@ class MOETensors: c_strides2: torch.Tensor @staticmethod - def make_moe_tensors(m: int, k: int, n: int, e: int, - dtype: torch.dtype) -> "MOETensors": + def make_moe_tensors( + m: int, k: int, n: int, e: int, dtype: torch.dtype + ) -> "MOETensors": a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - return MOETensors(a=a, - w1=w1, - w2=w2, - ab_strides1=ab_strides1, - c_strides1=c_strides1, - ab_strides2=ab_strides2, - c_strides2=c_strides2) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) + return MOETensors( + a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2, + ) @dataclasses.dataclass @@ -90,9 +97,9 @@ class MOETensors8Bit(MOETensors): w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d @staticmethod - def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, - per_act_token: bool, - per_out_channel: bool) -> "MOETensors8Bit": + def make_moe_tensors_8bit( + m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool + ) -> "MOETensors8Bit": dtype = torch.half q_dtype = torch.float8_e4m3fn @@ -103,24 +110,21 @@ class MOETensors8Bit(MOETensors): k_b_scales = k if per_out_channel else 1 # Get the right scale for tests. 
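# --- Illustrative aside (not part of the patch): the n_b_scales / k_b_scales
# bookkeeping above mirrors how ops.scaled_fp8_quant shapes the scale it
# returns. A minimal sketch, assuming a CUDA device; the input tensor is a
# stand-in, not one of the test tensors.
import torch
from vllm import _custom_ops as ops

x = torch.randn((4, 128), device="cuda", dtype=torch.half)

# Dynamic per-tensor quantization: a single scale for the whole tensor.
xq_tensor, scale_tensor = ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=False)
# Dynamic per-token (per-row) quantization: one scale per row.
xq_token, scale_token = ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)

print(scale_tensor.shape)  # expected: a single per-tensor scale
print(scale_token.shape)   # expected: one scale per row, e.g. (4, 1)
# --- end aside ---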
a_q, a_scale = ops.scaled_fp8_quant( - moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token) + moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token + ) w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w1[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w2[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel + ) # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d a_d = a_q.float().mul(a_scale).to(dtype) @@ -130,31 +134,37 @@ class MOETensors8Bit(MOETensors): w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() - return MOETensors8Bit(a=moe_tensors_fp16.a, - w1=moe_tensors_fp16.w1, - w2=moe_tensors_fp16.w2, - ab_strides1=moe_tensors_fp16.ab_strides1, - c_strides1=moe_tensors_fp16.c_strides1, - ab_strides2=moe_tensors_fp16.ab_strides2, - c_strides2=moe_tensors_fp16.c_strides2, - a_q=a_q, - w1_q=w1_q, - w2_q=w2_q, - a_scale=a_scale, - w1_scale=w1_scale, - w2_scale=w2_scale, - a_d=a_d, - w1_d=w1_d, - w2_d=w2_d) + return MOETensors8Bit( + a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d, + ) -def run_with_expert_maps(num_experts: int, num_local_experts: int, - **cutlass_moe_kwargs): - +def run_with_expert_maps( + num_experts: int, num_local_experts: int, **cutlass_moe_kwargs +): def slice_experts(): slice_params = [ - "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", - "c_strides2", "w1_scale", "w2_scale" + "w1_q", + "w2_q", + "ab_strides1", + "ab_strides2", + "c_strides1", + "c_strides2", ] full_tensors = { k: v @@ -162,15 +172,15 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, if k in slice_params and k in cutlass_moe_kwargs } + quant_config = cutlass_moe_kwargs["quant_config"] + for i in range(0, num_experts, num_local_experts): s, e = i, i + num_local_experts # make expert map expert_map = [-1] * num_experts expert_map[s:e] = list(range(num_local_experts)) - expert_map = torch.tensor(expert_map, - dtype=torch.int32, - device="cuda") + expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") # update cutlass moe arg with expert_map cutlass_moe_kwargs["expert_map"] = expert_map @@ -178,6 +188,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, for k, t in full_tensors.items(): cutlass_moe_kwargs[k] = t[s:e] + new_quant_config = copy.deepcopy(quant_config) + new_quant_config._w1.scale = quant_config.w1_scale[s:e] + new_quant_config._w2.scale = quant_config.w2_scale[s:e] + + 
cutlass_moe_kwargs["quant_config"] = new_quant_config + yield cutlass_moe_kwargs out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) @@ -187,32 +203,48 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, return out_tensor -def run_8_bit(moe_tensors: MOETensors8Bit, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - per_act_token: bool, - num_local_experts: Optional[int] = None) -> torch.Tensor: - assert not any([ - t is None for t in [ - moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, - moe_tensors.w2_scale, moe_tensors.a_scale +def run_8_bit( + moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + per_act_token: bool, + per_out_ch: bool, + num_local_experts: Optional[int] = None, +) -> torch.Tensor: + assert not any( + [ + t is None + for t in [ + moe_tensors.w1_q, + moe_tensors.w2_q, + moe_tensors.w1_scale, + moe_tensors.w2_scale, + moe_tensors.a_scale, + ] ] - ]) + ) + + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=moe_tensors.w1_scale, + w2_scale=moe_tensors.w2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + # Set to moe_tensors.a_scale iff static scales + per tensor. + # This is not currently being tested. + a1_scale=None, + ) kwargs = { - 'a': moe_tensors.a, - 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] - 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] - 'topk_weights': topk_weights, - 'topk_ids': topk_ids, - 'w1_scale': moe_tensors.w1_scale, - 'w2_scale': moe_tensors.w2_scale, - 'ab_strides1': moe_tensors.ab_strides1, - 'ab_strides2': moe_tensors.ab_strides2, - 'c_strides1': moe_tensors.c_strides1, - 'c_strides2': moe_tensors.c_strides2, - 'per_act_token': per_act_token, - 'a1_scale': None #moe_tensors.a_scale + "a": moe_tensors.a, + "w1_q": moe_tensors.w1_q, # type: ignore[union-attr] + "w2_q": moe_tensors.w2_q, # type: ignore[union-attr] + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "ab_strides1": moe_tensors.ab_strides1, + "ab_strides2": moe_tensors.ab_strides2, + "c_strides1": moe_tensors.c_strides1, + "c_strides2": moe_tensors.c_strides2, + "quant_config": quant_config, } num_experts = moe_tensors.w1.size(0) @@ -224,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, return run_with_expert_maps( num_experts, num_local_experts, # type: ignore[arg-type] - **kwargs) + **kwargs, + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -234,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_no_graph( m: int, n: int, @@ -250,34 +285,34 @@ def test_cutlass_moe_8_bit_no_graph( current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. 
# Using a, w1 and w2 directly results in minor output differences. - triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts( + mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config + ) if ep_size is not None: assert e % ep_size == 0, "Cannot distribute experts evenly" number_local_experts = e // ep_size else: number_local_experts = None - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, - number_local_experts) + + cutlass_output = run_8_bit( + mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts + ) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close( + triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2 + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -287,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_cuda_graph( m: int, n: int, @@ -304,34 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph( with set_current_vllm_config(vllm_config): dtype = torch.half - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. 
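# --- Illustrative aside (not part of the patch): the hunks below switch
# fused_experts from per-kwarg quantization arguments to a single quant_config.
# Sketch of the resulting call pattern; a_d, w1_d, w2_d, topk_weights, topk_ids
# and the scale tensors are placeholders for the test tensors used in this file.
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

# Unquantized Triton reference: pass the prebuilt unquantized config.
triton_output = fused_experts(
    a_d, w1_d, w2_d, topk_weights, topk_ids, quant_config=FUSED_MOE_UNQUANTIZED_CONFIG
)

# FP8 CUTLASS path: the per-expert weight scales now travel inside the config
# object, built the same way run_8_bit does above.
fp8_quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    per_act_token_quant=per_act_token,
    per_out_ch_quant=per_out_ch,
)
# --- end aside ---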
- triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts( + mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config + ) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, - per_act_token) + cutlass_output = run_8_bit( + mt, topk_weights, topk_ids, per_act_token, per_out_ch + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) + torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2) @pytest.mark.parametrize("m", [64]) @@ -344,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph( @pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP( m: int, n: int, @@ -357,8 +392,9 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, monkeypatch, ): - test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + ) LARGE_MNK_FACTORS = [ @@ -375,8 +411,10 @@ LARGE_MNK_FACTORS = [ @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP_large( m: int, n: int, @@ -388,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large( ep_size: int, monkeypatch, ): - test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + ) @pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)]) @@ -399,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large( @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_run_cutlass_moe_fp8( m: int, n: int, @@ -413,14 +454,12 @@ def test_run_cutlass_moe_fp8( ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_channel) + mt = MOETensors8Bit.make_moe_tensors_8bit( + m, k, n, e, per_act_token, per_out_channel + ) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # we want to make sure there is at least one token that's generated in # this expert shard and at least one token that's NOT generated in this # expert 
shard @@ -431,12 +470,12 @@ def test_run_cutlass_moe_fp8( workspace2_shape = (m * topk, max(n, k)) output_shape = (m, k) - workspace13 = torch.empty(prod(workspace13_shape), - device="cuda", - dtype=mt.a.dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device="cuda", - dtype=mt.a.dtype) + workspace13 = torch.empty( + prod(workspace13_shape), device="cuda", dtype=mt.a.dtype + ) + workspace2 = torch.empty( + prod(workspace2_shape), device="cuda", dtype=mt.a.dtype + ) num_local_experts = e // ep_size start, end = 0, num_local_experts @@ -444,36 +483,55 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) - a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, - torch.float8_e4m3fn, - per_act_token) + a1q, a1q_scale = moe_kernel_quantize_input( + mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token + ) global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0) func = lambda output: run_cutlass_moe_fp8( - output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, - global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, - workspace13, workspace2, None, mt.a.dtype, per_act_token, - per_out_channel, False, topk_weights) + output, + a1q, + mt.w1_q, + mt.w2_q, + topk_ids, + activation, + global_num_experts, + expert_map, + mt.w1_scale, + mt.w2_scale, + a1q_scale, + None, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + workspace13, + workspace2, + None, + mt.a.dtype, + per_act_token, + per_out_channel, + False, + topk_weights, + ) workspace13.random_() - output_random_workspace = torch.empty(output_shape, - device="cuda", - dtype=mt.a.dtype) + output_random_workspace = torch.empty( + output_shape, device="cuda", dtype=mt.a.dtype + ) func(output_random_workspace) workspace13.fill_(0) - output_zero_workspace = torch.zeros(output_shape, - device="cuda", - dtype=mt.a.dtype) + output_zero_workspace = torch.zeros( + output_shape, device="cuda", dtype=mt.a.dtype + ) func(output_zero_workspace) - torch.testing.assert_close(output_random_workspace, - output_zero_workspace, - atol=5e-3, - rtol=1e-3) + torch.testing.assert_close( + output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3 + ) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 6f95581a5e60d..e68c5bfa5946f 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -15,31 +15,35 @@ from torch.distributed import ProcessGroup from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.fused_moe import 
fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm(): - from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) - from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts) + BatchedDeepGemmExperts, + ) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts requires_deep_ep = pytest.mark.skipif( not has_deep_ep(), @@ -56,9 +60,10 @@ P = ParamSpec("P") def next_power_of_2(x): import math + if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return 2 ** math.ceil(math.log2(x)) def make_block_quant_fp8_weights( @@ -70,10 +75,9 @@ def make_block_quant_fp8_weights( """ Return weights w1q, w2q, w1_scale, w2_scale """ - (_, w1q, w1_scale, _), (_, w2q, w2_scale, - _) = make_test_weights(e, n, k, torch.bfloat16, - torch.float8_e4m3fn, - block_size) + (_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights( + e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size + ) return w1q, w2q, w1_scale, w2_scale @@ -101,15 +105,15 @@ class TestTensors: @staticmethod def make(config: TestConfig, rank) -> "TestTensors": - dtype = torch.bfloat16 topk, m, k = (config.topk, config.m, config.k) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - rank_tokens = torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + rank_tokens = ( + torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + ) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_token_scales = None @@ -117,24 +121,32 @@ class TestTensors: low=0, high=config.num_experts, size=(m, topk), - device=torch.cuda.current_device()).to(dtype=torch.int64) + device=torch.cuda.current_device(), + ).to(dtype=torch.int64) - topk_weights = torch.randn(topk_ids.shape, - dtype=torch.float32, - device=torch.cuda.current_device()) + topk_weights = torch.randn( + topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device() + ) - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk_ids, - topk_weights=topk_weights, - config=config) + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk_ids, + topk_weights=topk_weights, + config=config, + ) -def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - max_tokens_per_rank: int, dp_size: int, - 
hidden_size: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: - +def make_ll_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + max_tokens_per_rank: int, + dp_size: int, + hidden_size: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -147,25 +159,30 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank=max_tokens_per_rank, hidden_size=hidden_size, num_experts=test_config.num_experts, - use_fp8_dispatch=test_config.use_fp8_dispatch), + use_fp8_dispatch=test_config.use_fp8_dispatch, + ), q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, - block_shape=test_config.block_size, - per_act_token_quant=test_config.per_act_token_quant) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + quant_config=quant_config, + ) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: - +def make_ht_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -176,62 +193,84 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts), deepep_ll_args=None, q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) - fused_experts = DeepGemmExperts() - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + fused_experts = DeepGemmExperts(quant_config) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, - test_tensors: TestTensors) -> FusedMoEModularKernel: - +def make_modular_kernel( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + test_tensors: TestTensors, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config mk: FusedMoEModularKernel # Make modular kernel if test_config.low_latency: - max_tokens_per_rank = max( - 64, next_power_of_2(test_tensors.rank_tokens.size(0))) + max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0))) hidden_size = test_tensors.rank_tokens.size(-1) - mk = make_ll_modular_kernel(pg=pg, - pgi=pgi, - max_tokens_per_rank=max_tokens_per_rank, - dp_size=dp_size, - hidden_size=hidden_size, - q_dtype=q_dtype, - test_config=test_config) + mk = make_ll_modular_kernel( + pg=pg, + pgi=pgi, + max_tokens_per_rank=max_tokens_per_rank, + dp_size=dp_size, + hidden_size=hidden_size, + q_dtype=q_dtype, + test_config=test_config, + quant_config=quant_config, + ) else: - mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, - q_dtype, test_config) + mk = 
make_ht_modular_kernel( + pg, + pgi, + dp_size, + num_local_experts, + q_dtype, + test_config, + quant_config=quant_config, + ) return mk -def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, test_tensors: TestTensors, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor]) -> torch.Tensor: - +def deepep_deepgemm_moe_impl( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + test_tensors: TestTensors, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], +) -> torch.Tensor: test_config = test_tensors.config num_experts = test_config.num_experts num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) + + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + # Low-Latency kernels can't dispatch scales. + a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales), + block_shape=test_config.block_size, + ) # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( @@ -239,35 +278,42 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, pgi=pgi, dp_size=dp_size, num_local_experts=num_local_experts, - test_tensors=test_tensors) + test_tensors=test_tensors, + quant_config=quant_config, + ) - # Low-Latency kernels can't dispatch scales. 
- a1_scale = (None - if test_config.low_latency else test_tensors.rank_token_scales) - - out = mk.forward(hidden_states=test_tensors.rank_tokens, - w1=w1, - w2=w2, - topk_weights=test_tensors.topk_weights, - topk_ids=test_tensors.topk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=a1_scale, - a2_scale=None, - apply_router_weight_on_input=False) + out = mk.forward( + hidden_states=test_tensors.rank_tokens, + w1=w1, + w2=w2, + topk_weights=test_tensors.topk_weights, + topk_ids=test_tensors.topk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + apply_router_weight_on_input=False, + ) return out -def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, - topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a1_scale: torch.Tensor, block_shape: list[int]): +def triton_impl( + a: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: list[int], +): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + block_shape=block_shape, + ) return fused_experts( hidden_states=a, @@ -276,14 +322,11 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - block_shape=block_shape, + quant_config=quant_config, # Make sure this is set to False so we - # dont end up comparing the same implementation. - allow_deep_gemm=False) + # don't end up comparing the same implementation. + allow_deep_gemm=False, + ) def _test_deepep_deepgemm_moe( @@ -304,22 +347,21 @@ def _test_deepep_deepgemm_moe( pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, pgi.rank) - block_shape = [ - w1.size(1) // w1_scale.size(1), - w1.size(2) // w1_scale.size(2) - ] + block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] with set_current_vllm_config(VllmConfig()): # Reference - triton_moe = triton_impl(a=test_tensors.rank_tokens, - topk_ids=test_tensors.topk, - topk_weights=test_tensors.topk_weights, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=test_tensors.rank_token_scales, - block_shape=block_shape) + triton_moe = triton_impl( + a=test_tensors.rank_tokens, + topk_ids=test_tensors.topk, + topk_weights=test_tensors.topk_weights, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=test_tensors.rank_token_scales, + block_shape=block_shape, + ) # Slice experts for this rank. 
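# --- Illustrative aside (not part of the patch): the per-rank expert slicing
# done below pairs with the expert_map built in build_expert_map above. A small
# self-contained sketch, runnable on CPU, with made-up sizes for illustration.
import torch

num_experts, world_size, rank = 32, 2, 1
num_local_experts = num_experts // world_size
s, e = rank * num_local_experts, (rank + 1) * num_local_experts

# Each rank keeps only its slice of the expert weights ...
w1 = torch.randn(num_experts, 8, 4)
w1_rank = w1[s:e]

# ... plus an expert_map that sends its global expert ids to local ids,
# with -1 marking experts owned by other ranks.
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[s:e] = torch.arange(num_local_experts, dtype=torch.int32)
# --- end aside ---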
num_local_experts = config.num_experts // pgi.world_size @@ -370,12 +412,18 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Skipping test for Blackwell DeepGEMM") -def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, - topk: int, world_dp_size: tuple[int, int]): +@pytest.mark.skipif( + is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" +) +def test_ht_deepep_deepgemm_moe( + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], +): """ Tests for High-Throughput DeepEP + DeepGemm integration. """ @@ -391,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, block_size = [block_m, block_m] world_size, dp_size = world_dp_size - config = TestConfig(topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts, - per_act_token_quant=False, - block_size=block_size, - low_latency=False, - use_fp8_dispatch=None) + config = TestConfig( + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts, + per_act_token_quant=False, + block_size=block_size, + low_latency=False, + use_fp8_dispatch=None, + ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) MNKs = [ @@ -427,10 +486,12 @@ USE_FP8_DISPATCH = [False] @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @pytest.mark.parametrize("block_size", [[128, 128]]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Skipping test for Blackwell DeepGEMM") +@pytest.mark.skipif( + is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" +) def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, @@ -463,7 +524,16 @@ def test_ll_deepep_deepgemm_moe( ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 43804c410b6c2..a1dabea1f0c7d 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -15,22 +15,25 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts +from 
vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.platforms import current_platform from vllm.utils import has_deep_ep +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a @@ -43,7 +46,7 @@ MAX_TOKENS_PER_RANK = 64 def make_weights( - e, n, k, dtype + e, n, k, dtype ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Return weights w1, w2, w1_scale, w2_scale @@ -62,17 +65,15 @@ def make_weights( k_b_scales = k w1_q = torch.empty_like(w1, dtype=dtype) w2_q = torch.empty_like(w2, dtype=dtype) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=True) + w1[expert], use_per_token_if_dynamic=True + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=True) + w2[expert], use_per_token_if_dynamic=True + ) return w1_q, w2_q, w1_scale, w2_scale @@ -98,24 +99,25 @@ class TestTensors: def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": # TODO (varun) - check that float16 works ? 
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn] - token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn - else config.dtype) - rank_tokens = torch.randn( - (config.m, config.k), device="cuda", dtype=token_dtype) / 10 + token_dtype = ( + torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype + ) + rank_tokens = ( + torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10 + ) rank_token_scales = None - topk = torch.randint(low=0, - high=config.num_experts, - size=(config.m, config.topk), - device="cuda").to(dtype=torch.int64) - topk_weights = torch.randn(topk.shape, - dtype=torch.float32, - device="cuda") - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk, - topk_weights=topk_weights, - config=config) + topk = torch.randint( + low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda" + ).to(dtype=torch.int64) + topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda") + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk, + topk_weights=topk_weights, + config=config, + ) def make_modular_kernel( @@ -128,57 +130,49 @@ def make_modular_kernel( num_local_experts: int, q_dtype: Optional[torch.dtype], use_fp8_dispatch: bool, - per_act_token_quant: bool, + quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: - - is_quantized = q_dtype is not None - ht_args: Optional[DeepEPHTArgs] = None ll_args: Optional[DeepEPLLArgs] = None if low_latency_mode: - ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK, - hidden_size=hidden_size, - num_experts=num_experts, - use_fp8_dispatch=use_fp8_dispatch) + ll_args = DeepEPLLArgs( + max_tokens_per_rank=MAX_TOKENS_PER_RANK, + hidden_size=hidden_size, + num_experts=num_experts, + use_fp8_dispatch=use_fp8_dispatch, + ) else: assert not use_fp8_dispatch, ( - "FP8 Dispatch is valid only for low-latency kernels") + "FP8 Dispatch is valid only for low-latency kernels" + ) ht_args = DeepEPHTArgs(num_local_experts=num_local_experts) - a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \ - make_deepep_a2a(pg = pg, - pgi = pgi, - dp_size = dp_size, - q_dtype = q_dtype, - block_shape = None, - deepep_ht_args = ht_args, - deepep_ll_args = ll_args) + a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = ( + make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + q_dtype=q_dtype, + block_shape=None, + deepep_ht_args=ht_args, + deepep_ll_args=ll_args, + ) + ) num_dispatchers = pgi.world_size // dp_size if low_latency_mode: - assert not per_act_token_quant, "not supported in ll mode" + assert not quant_config.per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, num_dispatchers=num_dispatchers, - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=False, + quant_config=quant_config, ) else: - fused_experts = TritonExperts( - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=per_act_token_quant, - ) + fused_experts = TritonExperts(quant_config=quant_config) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -196,19 +190,15 @@ def deep_ep_moe_impl( use_fp8_dispatch: bool, per_act_token_quant: bool, ) -> torch.Tensor: - 
num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) hidden_size = test_tensors.rank_tokens.size(1) is_quantized = w1.dtype == torch.float8_e4m3fn @@ -216,11 +206,6 @@ def deep_ep_moe_impl( if is_quantized: q_dtype = torch.float8_e4m3fn - # Make modular kernel - mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, - num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant) - out_hidden_states = torch.empty_like(test_tensors.rank_tokens) total_num_tokens = test_tensors.rank_tokens.size(0) @@ -229,35 +214,54 @@ def deep_ep_moe_impl( topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end] topk_chunk = test_tensors.topk[chunk_start:chunk_end] rank_token_scales_chunk = test_tensors.rank_token_scales - if rank_token_scales_chunk is not None and rank_token_scales_chunk.size( - 0) == total_num_tokens: + if ( + rank_token_scales_chunk is not None + and rank_token_scales_chunk.size(0) == total_num_tokens + ): # per act token - rank_token_scales_chunk = rank_token_scales_chunk[ - chunk_start:chunk_end] + rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end] - out = mk.forward(hidden_states=rank_tokens_chunk, - w1=w1, - w2=w2, - topk_weights=topk_weights_chunk, - topk_ids=topk_chunk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=rank_token_scales_chunk, - a2_scale=None, - apply_router_weight_on_input=False) + quant_config = FusedMoEQuantConfig.make( + q_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token_quant, + a1_scale=rank_token_scales_chunk, + ) + + # Make modular kernel + mk: FusedMoEModularKernel = make_modular_kernel( + pg, + pgi, + low_latency_mode, + hidden_size, + dp_size, + num_experts, + num_local_experts, + q_dtype, + use_fp8_dispatch, + quant_config, + ) + + out = mk.forward( + hidden_states=rank_tokens_chunk, + w1=w1, + w2=w2, + topk_weights=topk_weights_chunk, + topk_ids=topk_chunk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + apply_router_weight_on_input=False, + ) if not skip_result_store: - out_hidden_states[chunk_start:chunk_end, :].copy_( - out, non_blocking=True) + out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True) - max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK - if low_latency_mode else total_num_tokens) + max_num_tokens_per_dp = ( + MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens + ) for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp): chunk_start = chunk_start_ @@ -266,9 +270,9 @@ def deep_ep_moe_impl( chunk_start = min(chunk_start, total_num_tokens - 1) chunk_end = min(chunk_end, total_num_tokens) - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= total_num_tokens) + process_chunk( + chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens + ) return out_hidden_states @@ -282,9 +286,11 
@@ def torch_moe_impl( using_fp8_dispatch: bool, per_act_token_quant: bool, ): - - a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, - test_tensors.topk_weights) + a, topk_ids, topk_weights = ( + test_tensors.rank_tokens, + test_tensors.topk, + test_tensors.topk_weights, + ) if using_fp8_dispatch: # The DeepEP implementation is requested to dispatch using FP8. # For numerical stability for testing, emulate the fp8 dispatch by @@ -292,8 +298,11 @@ def torch_moe_impl( assert not per_act_token_quant a = test_tensors.rank_tokens aq, aq_scale = per_token_group_quant_fp8(a, 128) - a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( - a.shape).to(a.dtype) + a = ( + (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) + .view(a.shape) + .to(a.dtype) + ) is_quantized = w1.dtype == torch.float8_e4m3fn a_dtype = a.dtype @@ -314,8 +323,9 @@ def torch_moe_impl( e_w = topk_weights[i][j] w1_e = w1[e] w2_e = w2[e] - o_i += (SiluAndMul() - (a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w + o_i += ( + SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1) + ) * e_w if is_quantized: out = out.to(dtype=a_dtype) @@ -335,28 +345,36 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): - if not low_latency_mode: assert not use_fp8_dispatch, ( - "FP8 dispatch interface is available only in low-latency mode") + "FP8 dispatch interface is available only in low-latency mode" + ) is_quantized = w1.dtype == torch.float8_e4m3fn w1 = w1.to(device=torch.cuda.current_device()) w2 = w2.to(device=torch.cuda.current_device()) if is_quantized: w1_scale = w1_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) w2_scale = w2_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, low_latency_mode) with set_current_vllm_config(VllmConfig()): # Reference - torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, - w2_scale, use_fp8_dispatch, - per_act_token_quant) + torch_combined = torch_moe_impl( + test_tensors, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) # Splice experts for this rank. 
num_local_experts = config.num_experts // pgi.world_size @@ -406,15 +424,18 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) +@multi_gpu_test(num_gpus=2) @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, - mnk: tuple[int, int, int], + m: int, + n: int, + k: int, num_experts: int, topk: int, world_dp_size: tuple[int, int], @@ -422,22 +443,26 @@ def test_deep_ep_moe( ): low_latency_mode = False use_fp8_dispatch = False - m, n, k = mnk current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - per_act_token_quant) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) MNKs = [ @@ -454,22 +479,26 @@ USE_FP8_DISPATCH = [True, False] @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@multi_gpu_test(num_gpus=2) @requires_deep_ep -def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], - num_experts: int, topk: int, - world_dp_size: tuple[int, int], - use_fp8_dispatch: bool): - +def test_low_latency_deep_ep_moe( + dtype: torch.dtype, + m: int, + n: int, + k: int, + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], + use_fp8_dispatch: bool, +): low_latency_mode = True - m, n, k = mnk - if (low_latency_mode - and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): + if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES: pytest.skip( f"Skipping test as hidden size {k} is not in list of supported " f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}" @@ -477,15 +506,20 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - False) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + False, + ) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 4472f34a6291a..cad0085d5ba6e 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -11,12 +11,18 @@ import math import pytest import torch 
+from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config + # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported, - per_block_cast_to_fp8) + per_token_group_quant_fp8, +) +from vllm.utils.deep_gemm import ( + calc_diff, + is_deep_gemm_supported, + per_block_cast_to_fp8, +) BLOCK_SIZE = [128, 128] @@ -35,8 +41,10 @@ def make_block_quant_fp8_weights( w2 shape: (E, K, N) """ dtype = torch.bfloat16 - fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( - torch.float8_e4m3fn).min + fp8_max, fp8_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) # bf16 reference weights w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 @@ -52,24 +60,16 @@ def make_block_quant_fp8_weights( w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.empty(e, - n_tiles_w1, - k_tiles_w1, - device="cuda", - dtype=torch.float32) - w2_s = torch.empty(e, - n_tiles_w2, - k_tiles_w2, - device="cuda", - dtype=torch.float32) + w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32) + w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32) for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size=block_size, - use_ue8m0=True) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size=block_size, - use_ue8m0=True) + w1[i], w1_s[i] = per_block_cast_to_fp8( + w1_bf16[i], block_size=block_size, use_ue8m0=True + ) + w2[i], w2_s[i] = per_block_cast_to_fp8( + w2_bf16[i], block_size=block_size, use_ue8m0=True + ) return w1, w2, w1_s, w2_s @@ -79,21 +79,27 @@ def run_single_case(m, n, k, topk, num_experts, block_size): Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == Triton baseline within tolerance. 
""" - tokens_bf16 = torch.randn( - m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) + tokens_bf16 = ( + torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + .clamp_min_(-1) + .clamp_max_(1) + ) _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) # expert weight tensors - w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, - block_size) + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size) - router_logits = torch.randn(m, - num_experts, - device="cuda", - dtype=torch.float32) + router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32) topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + ) + # triton reference out_triton = fused_experts( hidden_states=tokens_bf16, @@ -102,11 +108,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=False, ) @@ -118,19 +120,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=True, ) diff = calc_diff(out_deepgemm, out_triton) assert diff < 0.001, f"Diff exceeded 1%: {diff}" -# Note: W1 has shape (E, 2N, K), so N = 512 -# can trigger the deepgemm path. +# Note: N <= 512 will disable the deepgemm path due to performance issues. MNKs = [ (1024, 768, 128), (1024, 768, 512), @@ -144,18 +141,17 @@ TOPKS = [2, 6] NUM_EXPERTS = [32] -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize(("m", "n", "k"), MNKs) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.skipif(not is_deep_gemm_supported(), - reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_DEEP_GEMM", "1") +@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): + with monkeypatch.context() as mp: + mp.setenv("VLLM_USE_DEEP_GEMM", "1") _fused_moe_mod = importlib.import_module( - "vllm.model_executor.layers.fused_moe.fused_moe") + "vllm.model_executor.layers.fused_moe.fused_moe" + ) call_counter = {"cnt": 0} @@ -165,10 +161,7 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): call_counter["cnt"] += 1 return orig_fn(*args, **kwargs) - monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", - _spy_deep_gemm_moe_fp8) - - m, n, k = mnk + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) if topk > num_experts: pytest.skip(f"topk={topk} > num_experts={num_experts}") @@ -183,6 +176,7 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): ) # ensure that the DeepGEMM path was indeed taken. - assert call_counter["cnt"] == 1, \ - f"DeepGEMM path was not executed during the test. " \ + assert call_counter["cnt"] == 1, ( + f"DeepGEMM path was not executed during the test. 
" f"Call counter: {call_counter['cnt']}" + ) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 52a3d2ca3b422..0780232a82640 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -6,22 +6,28 @@ import pytest import torch from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - input_to_float8) + apply_flashinfer_per_tensor_scale_fp8, + flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.models.llama4 import Llama4MoE from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if not has_flashinfer_cutlass_fused_moe( -) or not current_platform.has_device_capability(100): - pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", - allow_module_level=True) +if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( + 100 +): + pytest.skip( + "Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True, + ) NUM_EXPERTS = [16] TOP_KS = [1] @@ -37,8 +43,7 @@ MNK_FACTORS = [ (1, 4096, 5120), ] -vllm_config = VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1)) +vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -72,18 +77,17 @@ class TestData: layer: torch.nn.Module @staticmethod - def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, - reorder: bool) -> "TestData": - hidden_states = torch.randn( - (m, k), device="cuda", dtype=torch.bfloat16) / 10 + def make_moe_tensors_8bit( + m: int, k: int, n: int, e: int, reorder: bool + ) -> "TestData": + hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) # Scale to fp8 _, a1_scale = input_to_float8(hidden_states) a1_scale = 1.0 / a1_scale - a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( - dtype=torch.float32) + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32) w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) @@ -100,8 +104,7 @@ class TestData: # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if reorder: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) layer.custom_routing_function = Llama4MoE.custom_routing_function layer.intermediate_size_per_partition = n layer.ep_rank = 0 @@ -136,14 +139,23 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( td = 
TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=td.hidden_states, router_logits=score, use_grouped_topk=False, top_k=topk, renormalize=False, custom_routing_function=Llama4MoE.custom_routing_function, - scoring_func="softmax") + scoring_func="softmax", + ) + + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) output = fused_experts( td.hidden_states, @@ -153,15 +165,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( @@ -173,12 +180,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( top_k=topk, num_expert_group=None, topk_group=None, - apply_router_weight_on_input=True) + apply_router_weight_on_input=True, + ) - torch.testing.assert_close(output, - flashinfer_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2) @pytest.mark.skip( @@ -201,14 +206,23 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=td.hidden_states, router_logits=score, use_grouped_topk=False, top_k=topk, renormalize=False, custom_routing_function=Llama4MoE.custom_routing_function, - scoring_func="softmax") + scoring_func="softmax", + ) + + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) output = fused_experts( td.hidden_states, @@ -218,15 +232,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) td.layer.dp_size = 1 @@ -242,7 +251,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( apply_router_weight_on_input=True, ) - torch.testing.assert_close(output, - flashinfer_cutlass_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close( + output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 + ) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1c14df2b914aa..18cfd4f79092d 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -3,27 +3,34 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.moe.utils import make_test_quant_config +from tests.kernels.quantization.nvfp4_utils import ( + 
FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) + FlashInferExperts, + is_valid_flashinfer_cutlass_fused_moe, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if not has_flashinfer_cutlass_fused_moe( -) or not current_platform.has_device_capability(100): - pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", - allow_module_level=True) +if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( + 100 +): + pytest.skip( + "Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True, + ) MNK_FACTORS = [ (2, 1024, 1024), @@ -41,106 +48,89 @@ MNK_FACTORS = [ @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", [40, 64, 256]) -#@pytest.mark.parametrize("e", [128, 256]) @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() -def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype): +def test_flashinfer_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): current_platform.seed_everything(7) with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 - (_, w1_q, w1_blockscale, - w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( - e, - n, - k, - in_dtype=dtype, - quant_dtype="nvfp4", - block_shape=None, # use quant_blocksize? 
- per_act_token_quant=False, - ) + w1_q, w2_q, quant_config = make_test_quant_config( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, + per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) - - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) - assert w1_gs is not None - assert w2_gs is not None - assert w1_blockscale is not None - assert w2_blockscale is not None - flashinfer_experts = FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - FlashInferExperts( - a1_gscale=a1_gs, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, - g2_alphas=(1 / w2_gs), - out_dtype=dtype, - quant_dtype="nvfp4", - )) + FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + ) flashinfer_output = flashinfer_experts( hidden_states=a, w1=w1_q, - w1_scale=w1_blockscale, w2=w2_q, - w2_scale=w2_blockscale, - a1_scale=a1_gs, - a2_scale=a2_gs, topk_weights=topk_weights, topk_ids=topk_ids, ) # Reference check: - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) _, m_k = a_fp4.shape - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=a.dtype, - device=a.device, - block_size=quant_blocksize) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + quant_config.w1_scale[idx], + (1 / quant_config.g1_alphas[idx]), + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + quant_config.w2_scale[idx], + (1 / quant_config.g2_alphas[idx]), + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) - torch.testing.assert_close(torch_output, - flashinfer_output, - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close( + torch_output, flashinfer_output, atol=1e-1, rtol=1e-1 + ) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 54f2351bf6d9b..f78596d220bfa 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -17,19 +17,21 @@ if not has_triton_kernels(): import triton_kernels.swiglu from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.numerics import InFlexData -from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp, - upcast_from_mxfp) +from 
triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize) + BatchedPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - BatchedOAITritonExperts, triton_kernel_moe_forward) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) + BatchedOAITritonExperts, + triton_kernel_moe_forward, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.utils import shuffle_weight from vllm.utils import round_up @@ -45,13 +47,11 @@ def deshuffle(w: torch.Tensor): def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): randbits = [torch.randperm(E) for _ in range(M)] x_list = [ - (-1)**i * - ((16384 + - ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) + (-1) ** i + * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) for i, bits in enumerate(randbits) ] - exp_data = torch.stack(x_list).to( - device="cuda") # simulating gate_output (M, E) + exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E) # create input tensor x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") @@ -119,20 +119,21 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): value=0, ) - w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0), - mode="constant", - value=0) - w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0), - mode="constant", - value=0) + w1_bias_tri = F.pad( + w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0 + ) + w2_bias_tri = F.pad( + w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0 + ) x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0) - w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout( - mx_axis=1) + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w_scale_layout, w_scale_layout_opts = ( layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=1, num_warps=num_warps)) + mx_axis=1, num_warps=num_warps + ) + ) w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1) w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1) @@ -140,29 +141,33 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1) w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1) - w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, - **w_layout_opts) + w1_tri = convert_layout( + wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts + ) w1_scale_tri = convert_layout( wrap_torch_tensor(w1_scale_tri), w_scale_layout, **w_scale_layout_opts, ) - w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, - **w_layout_opts) + w2_tri = convert_layout( + wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts + ) w2_scale_tri = convert_layout( wrap_torch_tensor(w2_scale_tri), w_scale_layout, **w_scale_layout_opts, ) - pc1 = PrecisionConfig(weight_scale=w1_scale_tri, - 
flex_ctx=FlexCtx(rhs_data=InFlexData())) - pc2 = PrecisionConfig(weight_scale=w2_scale_tri, - flex_ctx=FlexCtx(rhs_data=InFlexData())) + pc1 = PrecisionConfig( + weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) + ) + pc2 = PrecisionConfig( + weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) + ) # tucuate so the rest can run properly - w1 = w1[..., :K, :2 * N] + w1 = w1[..., :K, : 2 * N] w2 = w2[..., :N, :K] w1 = deshuffle(w1) @@ -260,7 +265,8 @@ class Case: @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + tuple(getattr(case, f.name) for f in fields(Case)) + for case in [ # Case(a_dtype="bf16", w_dtype="bf16"), # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), Case(a_dtype="bf16", w_dtype="mx4") @@ -293,6 +299,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): pc2, ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) + quant_config = FusedMoEQuantConfig.make( + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2, + ) + out_triton_monolithic = triton_kernel_moe_forward( hidden_states=x_tri, w1=w1_tri, @@ -300,10 +313,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): gating_output=exp_data_tri, topk=topk, renormalize=True, - w1_bias=w1_bias_tri, - w2_bias=w2_bias_tri, - w1_precision=pc1, - w2_precision=pc2, + quant_config=quant_config, ) out_triton_monolithic = out_triton_monolithic[..., :K] @@ -316,10 +326,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): gating_output=exp_data, topk=topk, ) - assert_close(ref=out_ref, - tri=out_triton_monolithic, - maxtol=0.025, - rmstol=0.005) + assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005) def batched_moe( @@ -336,6 +343,13 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + w1_precision=w1_precision, + w2_precision=w2_precision, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize( max_num_tokens, @@ -344,19 +358,12 @@ def batched_moe( rank=0, ), BatchedOAITritonExperts( - None, max_num_tokens=max_num_tokens, num_dispatchers=1, - w1_precision=w1_precision, - w2_precision=w2_precision, + quant_config=quant_config, ), ) - extra_expert_args = { - "w1_bias": w1_bias, - "w2_bias": w2_bias, - } - topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) return fused_experts( @@ -365,14 +372,14 @@ def batched_moe( w2, topk_weight, topk_ids, - extra_expert_args=extra_expert_args, ) @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + tuple(getattr(case, f.name) for f in fields(Case)) + for case in [ # Case(a_dtype="bf16", w_dtype="bf16"), # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), Case(a_dtype="bf16", w_dtype="mx4") diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py new file mode 100644 index 0000000000000..3f4f142be7674 --- /dev/null +++ b/tests/kernels/moe/test_grouped_topk.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE grouped topk kernel + +Run `pytest tests/kernels/moe/test_grouped_topk.py`. 
+""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_grouped_topk, + grouped_topk, +) +from vllm.platforms import current_platform + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_hidden", [1024, 2048]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("num_expert_group", [8]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk( + monkeypatch: pytest.MonkeyPatch, + n_token: int, + n_hidden: int, + n_expert: int, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str, + routed_scaling_factor: float, + dtype: torch.dtype, +): + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") + gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda") + e_score_correction_bias = torch.randn( + (n_expert,), dtype=torch.float32, device="cuda" + ) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + baseline_topk_weights, baseline_topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + test_topk_weights, test_topk_ids = fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + if renormalize: + torch.testing.assert_close( + baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0 + ) + torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index d45982384eb3b..b028e676f086f 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -5,28 +5,40 @@ import copy import textwrap import traceback from itertools import product -from typing import Optional +from typing import Any, Optional import pytest import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, - reference_moe_impl, - run_modular_kernel) +from 
.modular_kernel_tools.common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) from .modular_kernel_tools.mk_objects import ( - MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) -from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, - parallel_launch_with_config) + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, + TestMoEQuantConfig, + expert_info, +) +from .modular_kernel_tools.parallel_utils import ( + ProcessGroupInfo, + parallel_launch_with_config, +) -has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx() - or has_flashinfer_cutlass_fused_moe()) +has_any_multi_gpu_package = ( + has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe() +) meets_multi_gpu_requirements = pytest.mark.skipif( not has_any_multi_gpu_package, @@ -54,7 +66,7 @@ def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, - config: Config, + base_config: Config, weights: WeightTensors, verbose: bool, ): @@ -62,42 +74,43 @@ def rank_worker( # sanity check from vllm import envs - if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + if base_config.fused_moe_chunk_size is not None: + assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() - Ms = config.Ms + Ms = base_config.Ms assert isinstance(Ms, list) - TOPKs = config.topks + TOPKs = base_config.topks assert isinstance(TOPKs, list) exceptions = [] count = 0 for m, topk in product(Ms, TOPKs): + # override m and topk + config = copy.deepcopy(base_config) + config.Ms = m + config.topks = topk + try: print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") count = count + 1 - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + rank_tensors = RankTensors.make(config, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors) with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + ref_out = reference_moe_impl(config, weights, rank_tensors) if config.quant_dtype == "nvfp4": - atol = 1e-1 - rtol = 1e-1 + atol = 1e-1 if config.K < 4096 else 2e-1 + rtol = 1e-1 if config.K < 4096 else 2e-1 else: atol = 3e-2 rtol = 3e-2 @@ -111,27 +124,29 @@ def rank_worker( if len(exceptions) > 0: raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " - f"rank={pgi.rank}.") + f"rank={pgi.rank}." 
+ ) else: - print(f"{count} of {count} tests passed in child process, " - f"rank={pgi.rank}.") + print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.") def run(config: Config, verbose: bool): - assert config.is_valid() + assert config.is_valid()[0] + assert not is_nyi_config(config) weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights, verbose) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose + ) Ms = [32, 64] # hidden sizes, making this too large will cause fp4 tests to fail. # Also needs to be a multiple of 1024 for deep_gemm. Ks = [2048] -Ns = [2048] +Ns = [1024] TOPKs = [4, 1] Es = [32] DTYPEs = [torch.bfloat16] @@ -145,30 +160,104 @@ def is_nyi_config(config: Config) -> bool: if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. - unsupported_quant_config = ((config.is_per_act_token_quant + - config.is_per_out_ch_quant) == 1) + unsupported_quant_config = ( + config.is_per_act_token_quant + config.is_per_out_ch_quant + ) == 1 return unsupported_quant_config return not info.supports_expert_map -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +def generate_valid_test_cases( + world_size: int, prepare_finalize_types +) -> list[tuple[Any, ...]]: + cases = [] + total = 0 + + for k, n, e, dtype, quant_config, combination, chunk_size in product( + Ks, + Ns, + Es, + DTYPEs, + MK_QUANT_CONFIGS, + product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES), + FUSED_MOE_CHUNK_SIZEs, + ): + total = total + 1 + + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=chunk_size, + world_size=world_size, + ) + + # TODO(bnell): figure out how to get verbose flag here. 
+ verbose = False # pytestconfig.getoption('verbose') > 0
+
+ valid, reason = config.is_valid()
+
+ if not valid:
+ if verbose:
+ print(f"Test config {config} is not valid: {reason}")
+ continue
+
+ if is_nyi_config(config):
+ if verbose:
+ print(f"Test config {config} is nyi.")
+ continue
+
+ cases.append(
+ (
+ k,
+ n,
+ e,
+ dtype,
+ quant_config,
+ combination[0],
+ combination[1],
+ chunk_size,
+ world_size,
+ )
+ )
+
+ print(f"{len(cases)} of {total} valid configs generated.")
+
+ return cases
+
+
 @pytest.mark.parametrize(
- "combination",
- product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [2])
+ "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+ generate_valid_test_cases(
+ world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+ ),
+)
 @meets_multi_gpu_requirements
 def test_modular_kernel_combinations_multigpu(
- k: int, n: int, e: int, dtype: torch.dtype,
- quant_config: Optional[FusedMoEQuantConfig],
- combination: tuple[mk.FusedMoEPrepareAndFinalize,
- mk.FusedMoEPermuteExpertsUnpermute],
- fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
+ k: int,
+ n: int,
+ e: int,
+ dtype: torch.dtype,
+ quant_config: Optional[TestMoEQuantConfig],
+ prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+ fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+ chunk_size: Optional[int],
+ world_size: int,
+ pytestconfig,
+):
+ if cuda_device_count_stateless() < world_size:
+ pytest.skip(
+ f"Not enough GPUs available to run, got "
+ f"{cuda_device_count_stateless()} expected "
+ f"{world_size}."
+ )

 config = Config(
 Ms=Ms,
@@ -178,38 +267,33 @@ def test_modular_kernel_combinations_multigpu(
 topks=TOPKs,
 dtype=dtype,
 quant_config=quant_config,
- prepare_finalize_type=combination[0],
- fused_experts_type=combination[1],
- fused_moe_chunk_size=fused_moe_chunk_size,
+ prepare_finalize_type=prepare_finalize_type,
+ fused_experts_type=fused_experts_type,
+ fused_moe_chunk_size=chunk_size,
 world_size=world_size,
 )
-
- if not config.is_valid():
- pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
- if is_nyi_config(config):
- pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - - verbosity = pytestconfig.getoption('verbose') + verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", - product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) -@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) -@pytest.mark.parametrize("world_size", [1]) + "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", + generate_valid_test_cases( + world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + ), +) def test_modular_kernel_combinations_singlegpu( - k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[FusedMoEQuantConfig], - combination: tuple[mk.FusedMoEPrepareAndFinalize, - mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): + k: int, + n: int, + e: int, + dtype: torch.dtype, + quant_config: Optional[TestMoEQuantConfig], + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + chunk_size: Optional[int], + world_size: int, + pytestconfig, +): config = Config( Ms=Ms, K=k, @@ -218,31 +302,27 @@ def test_modular_kernel_combinations_singlegpu( topks=TOPKs, dtype=dtype, quant_config=quant_config, - prepare_finalize_type=combination[0], - fused_experts_type=combination[1], - fused_moe_chunk_size=fused_moe_chunk_size, + prepare_finalize_type=prepare_finalize_type, + fused_experts_type=fused_experts_type, + fused_moe_chunk_size=chunk_size, world_size=world_size, ) - if not config.is_valid(): - pytest.skip(f"Tests config {config} is not valid. Skipping ...") - - if is_nyi_config(config): - pytest.skip(f"Tests config {config} is nyi. Skipping ...") - - verbosity = pytestconfig.getoption('verbose') + verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) -if __name__ == '__main__': +if __name__ == "__main__": # Ability to test individual PrepareAndFinalize and FusedExperts combination - from .modular_kernel_tools.cli_args import (make_config, - make_config_arg_parser) - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() config = make_config(args) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0ea9667914fd5..f357d149bd071 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_moe.py`. 
""" + import functools from typing import Callable, Optional, Union @@ -15,25 +16,38 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) + fused_moe as iterative_moe, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_permute_bias) + marlin_permute_bias, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like) + rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize, marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + awq_marlin_quantize, + marlin_quantize, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -84,13 +98,15 @@ def run_moe_test( if isinstance(baseline, torch.Tensor): baseline_output = baseline else: - baseline_output = baseline(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + baseline_output = baseline( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) # Pad the weight if moe padding is enabled if padding: @@ -102,34 +118,35 @@ def run_moe_test( torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(score, 0) - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) if use_cudagraph: test_output.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(test_output, - baseline_output, - atol=atol, - rtol=rtol) + torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) return baseline_output @@ -173,11 +190,8 @@ def test_fused_moe( if ep_size > 1: local_e = e // ep_size - e_ids = 
torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -187,14 +201,9 @@ def test_fused_moe( # # Setup test functions # + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=None) + m_fused_moe_fn = modular_triton_fused_moe(quant_config) def m_fused_moe( a: torch.Tensor, @@ -206,13 +215,15 @@ def test_fused_moe( expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - return m_fused_moe_fn(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map) + return m_fused_moe_fn( + a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) @@ -236,19 +247,22 @@ def test_fused_moe( # setup code in case we are able to revisit this later. use_compile = False - use_cudagraph = (n >= 1024 and k >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) - runner(baseline_output, - fused_moe_fn, - use_compile=use_compile, - use_cudagraph=use_cudagraph) - runner(baseline_output, - m_fused_moe, - use_compile=use_compile, - use_cudagraph=use_cudagraph) + runner( + baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) + runner( + baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) @pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @@ -259,9 +273,18 @@ def test_fused_moe( @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) -def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, - ep_size: int, dtype: torch.dtype, group_size: int, - has_zp: bool, weight_bits: int): +def test_fused_moe_wn16( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + group_size: int, + has_zp: bool, + weight_bits: int, +): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -276,35 +299,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_ref = w1.clone() w2_ref = w2.clone() - w1_qweight = torch.empty((e, 2 * n, k // pack_factor), - device="cuda", - dtype=torch.uint8) - w2_qweight = torch.empty((e, k, n // pack_factor), - device="cuda", - dtype=torch.uint8) - w1_scales = torch.empty((e, 2 * n, k // group_size), - device="cuda", - dtype=dtype) - w2_scales = torch.empty((e, k, n // group_size), - device="cuda", - dtype=dtype) - w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), - device="cuda", - dtype=torch.uint8) - w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), - device="cuda", - dtype=torch.uint8) + 
w1_qweight = torch.empty( + (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + ) + w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w1_qzeros = torch.empty( + (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + ) + w2_qzeros = torch.empty( + (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + ) for i in range(e * 2): expert_id = i % e if i // e == 0: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w1, w1_ref, w1_qweight, w1_scales, w1_qzeros + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w1, + w1_ref, + w1_qweight, + w1_scales, + w1_qzeros, + ) else: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w2, w2_ref, w2_qweight, w2_scales, w2_qzeros + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w2, + w2_ref, + w2_qweight, + w2_scales, + w2_qzeros, + ) weight, qweight, scales, qzeros = quantize_weights( - w[expert_id].T, quant_type, group_size, has_zp, False) + w[expert_id].T, quant_type, group_size, has_zp, False + ) weight = weight.T qweight = qweight.T.contiguous().to(torch.uint8) scales = scales.T @@ -323,11 +351,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1_ref = w1_ref[e_ids] w2_ref = w2_ref[e_ids] @@ -340,28 +365,33 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None + if weight_bits == 4: + quant_config_builder = int4_w4a16_moe_quant_config + else: + assert weight_bits == 8 + quant_config_builder = int8_w8a16_moe_quant_config + + quant_config = quant_config_builder( + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size], + ) + with set_current_vllm_config(vllm_config): - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, - w1_ref, - w2_ref, - score, - topk, - expert_map=e_map) + triton_output = fused_moe( + a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + global_num_experts=e, + expert_map=e_map, + quant_config=quant_config, + ) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -369,16 +399,20 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) @torch.inference_mode() -def test_mixtral_moe(dtype: torch.dtype, padding: 
bool, use_rocm_aiter: bool, - monkeypatch): +def test_mixtral_moe( + dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch +): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" # clear the cache before every test from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, + ) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -386,17 +420,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") - monkeypatch.setenv('RANK', "0") - monkeypatch.setenv('LOCAL_RANK', "0") - monkeypatch.setenv('WORLD_SIZE', "1") - monkeypatch.setenv('MASTER_ADDR', 'localhost') - monkeypatch.setenv('MASTER_PORT', '12345') + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "12345") init_distributed_environment() # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() - with (set_current_vllm_config(vllm_config), - set_forward_context(None, vllm_config)): + with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config): config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") vllm_moe = MixtralMoE( @@ -412,27 +445,30 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) + weights = ( + hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data, + ) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn( - (1, 64, config.hidden_size)).to(dtype).to("cuda") + hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) # Pad the weight if moe padding is enabled if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) + vllm_moe.experts.w13_weight = Parameter( + F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[ + ..., 0:-128 + ], + requires_grad=False, + ) + vllm_moe.experts.w2_weight = Parameter( + F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], + requires_grad=False, + ) torch.cuda.synchronize() torch.cuda.empty_cache() @@ -447,21 +483,23 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, } if use_rocm_aiter: - # The values of rtol and atol are set based on the tests in ROCM AITER package. 
# noqa: E501 - # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501 - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=0.01, - atol=100) + # The values of rtol and atol are set based on the tests in ROCM AITER package. + # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 + torch.testing.assert_close( + hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100 + ) else: - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=mixtral_moe_tol[dtype], - atol=mixtral_moe_tol[dtype]) + torch.testing.assert_close( + hf_states.flatten(0, 1), + vllm_states, + rtol=mixtral_moe_tol[dtype], + atol=mixtral_moe_tol[dtype], + ) def marlin_moe_generate_valid_test_cases(): import itertools + m_list = [1, 123, 666] n_list = [128, 1024] k_list = [256, 2048] @@ -480,16 +518,24 @@ def marlin_moe_generate_valid_test_cases(): ] is_k_full_list = [True, False] - all_combinations = itertools.product(m_list, n_list, k_list, e_list, - topk_list, ep_size_list, dtype_list, - group_size_list, act_order_list, - quant_type_list, is_k_full_list) + all_combinations = itertools.product( + m_list, + n_list, + k_list, + e_list, + topk_list, + ep_size_list, + dtype_list, + group_size_list, + act_order_list, + quant_type_list, + is_k_full_list, + ) - def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, - quant_type, is_k_full): - - if quant_type == scalar_types.float8_e4m3fn and \ - group_size not in [-1, 128]: + def is_invalid( + m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full + ): + if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: return False if quant_type == scalar_types.float4_e2m1f: if group_size not in [16, 32]: @@ -518,9 +564,10 @@ def marlin_moe_generate_valid_test_cases(): @pytest.mark.flaky(reruns=2) -@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," - "act_order, quant_type, is_k_full"), - marlin_moe_generate_valid_test_cases()) +@pytest.mark.parametrize( + ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases(), +) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -545,7 +592,7 @@ def test_fused_marlin_moe( if ep_size > 1: local_e = e // ep_size e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -563,11 +610,13 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: if group_size == 16: - w_ref1, qweight1, scales1, global_scale1 = \ + w_ref1, qweight1, scales1, global_scale1 = ( rand_marlin_weight_nvfp4_like(w1[i], group_size) + ) else: - w_ref1, qweight1, scales1 = \ - rand_marlin_weight_mxfp4_like(w1[i], group_size) + w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like( + w1[i], group_size + ) global_scale1 = None w_ref1_l.append(w_ref1.T) @@ -576,14 +625,14 @@ def test_fused_marlin_moe( if global_scale1 is not None: global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) + w_ref1, qweight1, scales1 = 
marlin_quant_fp8_torch(w1[i], group_size) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size) + w1[i].transpose(1, 0), quant_type, group_size + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -591,9 +640,9 @@ def test_fused_marlin_moe( zeros1_l.append(zeros1) else: test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ - marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -620,11 +669,13 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: if group_size == 16: - w_ref2, qweight2, scales2, global_scale2 = \ + w_ref2, qweight2, scales2, global_scale2 = ( rand_marlin_weight_nvfp4_like(w2[i], group_size) + ) else: - w_ref2, qweight2, scales2 = \ - rand_marlin_weight_mxfp4_like(w2[i], group_size) + w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like( + w2[i], group_size + ) global_scale2 = None w_ref2_l.append(w_ref2.T) @@ -633,14 +684,14 @@ def test_fused_marlin_moe( if global_scale2 is not None: global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size) + w2[i].transpose(1, 0), quant_type, group_size + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -648,9 +699,9 @@ def test_fused_marlin_moe( zeros2_l.append(zeros2) else: test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ - marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -671,12 +722,7 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, - w_ref1, - w_ref2, - score, - topk, - expert_map=e_map) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -700,7 +746,8 @@ def test_fused_marlin_moe( w1_zeros=zeros1, w2_zeros=zeros2, quant_type_id=quant_type.id, - is_k_full=is_k_full) + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -734,9 +781,9 @@ def test_fused_marlin_moe_with_bias(m): for i in range(w1.shape[0]): test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ - marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -763,9 +810,9 @@ def test_fused_marlin_moe_with_bias(m): for i in range(w2.shape[0]): test_perm 
= torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ - marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -788,8 +835,7 @@ def test_fused_marlin_moe_with_bias(m): topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, - b_bias2) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -813,7 +859,8 @@ def test_fused_marlin_moe_with_bias(m): w1_zeros=zeros1, w2_zeros=zeros2, quant_type_id=quant_type.id, - is_k_full=is_k_full) + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -821,34 +868,36 @@ def test_fused_marlin_moe_with_bias(m): def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 - topk_ids = torch.randint(0, - num_experts, (3, 4), - dtype=torch.int32, - device='cuda') + topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - opcheck(torch.ops._moe_C.moe_align_block_size, - (topk_ids, num_experts, block_size, sorted_ids, expert_ids, - num_tokens_post_pad)) + opcheck( + torch.ops._moe_C.moe_align_block_size, + ( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): input = torch.randn((m, topk, k), device="cuda", dtype=dtype) @@ -860,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) opcheck(torch.ops._moe_C.moe_sum, (input, actual)) + + +@pytest.mark.parametrize("m", [1, 33]) +@pytest.mark.parametrize("n,k", [(128, 128)]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("with_bias", [False, True]) +@pytest.mark.parametrize("activation", ["silu"]) +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test") +def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation): + from 
vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE + + device = "cpu" + torch.manual_seed(7) + + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + router_logits = torch.randn((m, e), device=device, dtype=dtype) + + b1 = b2 = None + if with_bias: + b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10 + b2 = torch.randn((e, k), device=device, dtype=dtype) / 10 + + ref = ( + torch_moe(a, w13, w2, router_logits, topk, b1, b2) + if with_bias + else torch_moe(a, w13, w2, router_logits, topk) + ) + + class _Dummy(torch.nn.Module): + def __init__(self, w13, w2, b1=None, b2=None): + super().__init__() + self.w13_weight = torch.nn.Parameter(w13, requires_grad=False) + self.w2_weight = torch.nn.Parameter(w2, requires_grad=False) + if b1 is not None: + self.w13_bias = torch.nn.Parameter(b1, requires_grad=False) + if b2 is not None: + self.w2_bias = torch.nn.Parameter(b2, requires_grad=False) + + layer = _Dummy(w13, w2, b1, b2).to(dtype) + fused = CPUFusedMOE(layer) + out = fused( + layer=layer, + x=a, + use_grouped_topk=False, + top_k=topk, + router_logits=router_logits, + renormalize=False, + global_num_experts=e, + expert_map=None, + custom_routing_function=None, + scoring_func="softmax", + routed_scaling_factor=1.0, + e_score_correction_bias=None, + apply_router_weight_on_input=False, + activation=activation, + ) + + # Tolerances: fp32 tight; bf16 looser (esp. with bias) + if dtype == torch.float32: + atol = 1e-3 + elif with_bias: + atol = 8e-2 + else: + atol = 5e-2 + torch.testing.assert_close(out, ref, atol=atol, rtol=0) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 5dfc8d9fab32b..f92526e749557 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -11,7 +11,8 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.platforms import current_platform from vllm.utils import round_up @@ -60,30 +61,33 @@ def _verify_expert_level_sorting( in topk_ids in the final sorted_ids however this does not impact quality. 
""" # Group tokens by expert from the golden implementation - golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids, - expert_ids, block_size, - valid_length, total_tokens) + golden_expert_tokens = _group_tokens_by_expert( + golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens + ) - actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids, - expert_ids, block_size, - valid_length, total_tokens) + actual_expert_tokens = _group_tokens_by_expert( + actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens + ) - assert set(golden_expert_tokens.keys()) == set( - actual_expert_tokens.keys()), ( - f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, " - f"actual={set(actual_expert_tokens.keys())}") + assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), ( + f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, " + f"actual={set(actual_expert_tokens.keys())}" + ) for expert_id in golden_expert_tokens: - golden_tokens = torch.tensor(golden_expert_tokens[expert_id], - device=actual_sorted_ids.device) - actual_tokens = torch.tensor(actual_expert_tokens[expert_id], - device=actual_sorted_ids.device) + golden_tokens = torch.tensor( + golden_expert_tokens[expert_id], device=actual_sorted_ids.device + ) + actual_tokens = torch.tensor( + actual_expert_tokens[expert_id], device=actual_sorted_ids.device + ) assert torch.equal( - torch.sort(golden_tokens)[0], - torch.sort(actual_tokens)[0]), ( - f"Expert {expert_id} token mismatch: " - f"golden={golden_expert_tokens[expert_id]}, " - f"actual={actual_expert_tokens[expert_id]}") + torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0] + ), ( + f"Expert {expert_id} token mismatch: " + f"golden={golden_expert_tokens[expert_id]}, " + f"actual={actual_expert_tokens[expert_id]}" + ) def torch_moe_align_block_size( @@ -104,40 +108,38 @@ def torch_moe_align_block_size( if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - flattened_token_indices = torch.arange(topk_ids.numel(), - device=topk_ids.device, - dtype=torch.int32) + flattened_token_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.int32 + ) flattened_expert_ids = topk_ids.flatten() - sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, - stable=True) + sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True) sorted_token_indices = flattened_token_indices[sort_indices] - expert_token_counts = torch.zeros(num_experts, - dtype=torch.int64, - device=topk_ids.device) + expert_token_counts = torch.zeros( + num_experts, dtype=torch.int64, device=topk_ids.device + ) for expert_id in range(num_experts): mask = sorted_expert_ids == expert_id expert_token_counts[expert_id] = mask.sum() - expert_padded_counts = torch.zeros(num_experts, - dtype=torch.int64, - device=topk_ids.device) + expert_padded_counts = torch.zeros( + num_experts, dtype=torch.int64, device=topk_ids.device + ) for expert_id in range(num_experts): original_count = expert_token_counts[expert_id] if original_count > 0: expert_padded_counts[expert_id] = ( - (original_count + block_size - 1) // block_size) * block_size + (original_count + block_size - 1) // block_size + ) * block_size sorted_token_ids = torch.full( - (max_num_tokens_padded, ), + (max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32, device=topk_ids.device, ) max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size - expert_ids = torch.zeros(max_num_blocks, - 
dtype=torch.int32, - device=topk_ids.device) + expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device) current_pos = 0 current_block = 0 @@ -147,20 +149,20 @@ def torch_moe_align_block_size( num_expert_tokens = expert_tokens.shape[0] if num_expert_tokens > 0: - sorted_token_ids[current_pos:current_pos + - num_expert_tokens] = (expert_tokens) + sorted_token_ids[current_pos : current_pos + num_expert_tokens] = ( + expert_tokens + ) expert_blocks_needed = expert_padded_counts[expert_id] // block_size - expert_ids[current_block:current_block + - expert_blocks_needed] = (expert_id) + expert_ids[current_block : current_block + expert_blocks_needed] = expert_id current_pos += expert_padded_counts[expert_id] current_block += expert_blocks_needed total_padded_tokens = expert_padded_counts.sum() - num_tokens_post_pad = torch.tensor([total_padded_tokens], - dtype=torch.int32, - device=topk_ids.device) + num_tokens_post_pad = torch.tensor( + [total_padded_tokens], dtype=torch.int32, device=topk_ids.device + ) if expert_map is not None: expert_ids = expert_map[expert_ids] @@ -173,37 +175,32 @@ def torch_moe_align_block_size( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("pad_sorted_ids", [False, True]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_moe_align_block_size(m: int, topk: int, num_experts: int, - block_size: int, pad_sorted_ids: bool): +def test_moe_align_block_size( + m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool +): """Test moe_align_block_size without expert mapping""" topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) for i in range(m): experts = torch.randperm(num_experts, device="cuda")[:topk] topk_ids[i] = experts - actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( - moe_align_block_size( - topk_ids=topk_ids, - block_size=block_size, - num_experts=num_experts, - pad_sorted_ids=pad_sorted_ids, - )) + actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + ) golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( torch_moe_align_block_size( topk_ids=topk_ids, block_size=block_size, num_experts=num_experts, pad_sorted_ids=pad_sorted_ids, - )) + ) + ) - torch.testing.assert_close(actual_num_tokens, - golden_num_tokens, - atol=0, - rtol=0) - torch.testing.assert_close(actual_expert_ids, - golden_expert_ids, - atol=0, - rtol=0) + torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0) # For sorted_token_ids, verify block-level correctness rather than exact # order Tokens within each expert's blocks can be in any order, but expert @@ -219,16 +216,18 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, total_tokens = m * topk assert actual_num_tokens.item() % block_size == 0, ( - "num_tokens_post_pad should be divisible by block_size") + "num_tokens_post_pad should be divisible by block_size" + ) assert actual_num_tokens.item() >= total_tokens, ( - "num_tokens_post_pad should be at least total_tokens") + "num_tokens_post_pad should be at least total_tokens" + ) valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens] assert len(valid_tokens) == total_tokens, ( - f"Should have exactly {total_tokens} valid tokens, " - f"got {len(valid_tokens)}") - assert (actual_expert_ids >= 
0).all() and ( - actual_expert_ids - < num_experts).all(), "expert_ids should contain valid expert indices" + f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}" + ) + assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), ( + "expert_ids should contain valid expert indices" + ) @pytest.mark.parametrize("m", [16, 32]) @@ -236,46 +235,37 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, @pytest.mark.parametrize("num_experts", [8]) @pytest.mark.parametrize("block_size", [64]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_moe_align_block_size_with_expert_map(m: int, topk: int, - num_experts: int, - block_size: int): +def test_moe_align_block_size_with_expert_map( + m: int, topk: int, num_experts: int, block_size: int +): """Test moe_align_block_size with expert mapping (EP scenario)""" topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) for i in range(m): experts = torch.randperm(num_experts, device="cuda")[:topk] topk_ids[i] = experts - expert_map = torch.full((num_experts, ), - -1, - device="cuda", - dtype=torch.int32) + expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) local_experts = list(range(0, num_experts, 2)) for i, expert_id in enumerate(local_experts): expert_map[expert_id] = i - actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( - moe_align_block_size( - topk_ids=topk_ids, - block_size=block_size, - num_experts=num_experts, - expert_map=expert_map, - )) + actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + ) golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( torch_moe_align_block_size( topk_ids=topk_ids, block_size=block_size, num_experts=num_experts, expert_map=expert_map, - )) + ) + ) - torch.testing.assert_close(actual_num_tokens, - golden_num_tokens, - atol=0, - rtol=0) - torch.testing.assert_close(actual_expert_ids, - golden_expert_ids, - atol=0, - rtol=0) + torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0) _verify_expert_level_sorting( actual_sorted_ids, golden_sorted_ids, @@ -290,26 +280,25 @@ def test_moe_align_block_size_deterministic(): m, topk, num_experts, block_size = 128, 2, 32, 64 torch.manual_seed(42) - topk_ids = torch.randint(0, - num_experts, (m, topk), - device="cuda", - dtype=torch.int32) + topk_ids = torch.randint( + 0, num_experts, (m, topk), device="cuda", dtype=torch.int32 + ) # expect the results to be reproducible results = [] for _ in range(5): sorted_ids, expert_ids, num_tokens = moe_align_block_size( - topk_ids=topk_ids, block_size=block_size, num_experts=num_experts) - results.append( - (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) + topk_ids=topk_ids, block_size=block_size, num_experts=num_experts + ) + results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) for i in range(1, len(results)): - assert torch.equal( - results[0][0], - results[i][0]), ("sorted_ids should be deterministic") - assert torch.equal( - results[0][1], - results[i][1]), ("expert_ids should be deterministic") - assert torch.equal( - results[0][2], - results[i][2]), ("num_tokens should be deterministic") + assert torch.equal(results[0][0], results[i][0]), ( + "sorted_ids should be deterministic" + ) + assert torch.equal(results[0][1], 
results[i][1]), ( + "expert_ids should be deterministic" + ) + assert torch.equal(results[0][2], results[i][2]), ( + "num_tokens should be deterministic" + ) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index d71664d94b9c8..a6214437d404a 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -14,7 +14,10 @@ import torch from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_permute_unpermute_supported, moe_unpermute) + moe_permute, + moe_permute_unpermute_supported, + moe_unpermute, +) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64, 256] @@ -24,35 +27,34 @@ current_platform.seed_everything(0) def torch_permute( - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - # token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + # token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1, +) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: - is_local_expert = (expert_map[topk_ids] != -1) - not_local_expert = (expert_map[topk_ids] == -1) - topk_ids = is_local_expert * ( - topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) - token_expert_indices = torch.arange(0, - n_token * topk, - dtype=torch.int32, - device=hidden_states.device).reshape( - (n_token, topk)) + is_local_expert = expert_map[topk_ids] != -1 + not_local_expert = expert_map[topk_ids] == -1 + topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * ( + topk_ids + n_expert + ) + token_expert_indices = torch.arange( + 0, n_token * topk, dtype=torch.int32, device=hidden_states.device + ).reshape((n_token, topk)) - sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), - stable=True) + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] - expert_first_token_offset = torch.zeros(n_local_expert + 1, - dtype=torch.int64, - device="cuda") + expert_first_token_offset = torch.zeros( + n_local_expert + 1, dtype=torch.int64, device="cuda" + ) idx = 0 for i in range(0, n_local_expert): cnt = 0 @@ -64,116 +66,133 @@ def torch_permute( _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) valid_row_idx = [] if align_block_size is None: - - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // - topk, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...] 
permuted_row_size = permuted_hidden_states.shape[0] - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] m_indices[first_token_offset:last_token_offset] = i - 1 src_row_id2dst_row_id_map = torch.arange( - 0, n_token * topk, device="cuda", - dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + 0, n_token * topk, device="cuda", dtype=torch.int32 + )[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] - dst_row_id2src_row_id_map[ - expert_first_token_offset[-1]:] = n_token * topk + dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk return [ - permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices, - valid_row_idx + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, + dst_row_id2src_row_id_map, + m_indices, + valid_row_idx, ] else: - permuted_row_size = (topk * n_token + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size - permuted_idx = torch.full((permuted_row_size, ), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device) - permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), - device="cuda", - dtype=hidden_states.dtype) - align_src_row_id2dst_row_id = torch.empty(n_token * topk, - device="cuda", - dtype=torch.int32) - align_expert_first_token_offset = torch.zeros_like( - expert_first_token_offset) - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + permuted_row_size = ( + (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) + // align_block_size + * align_block_size + ) + permuted_idx = torch.full( + (permuted_row_size,), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device, + ) + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype + ) + align_src_row_id2dst_row_id = torch.empty( + n_token * topk, device="cuda", dtype=torch.int32 + ) + align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) # get align_permuted_hidden_states, # valid row_idx and align_expert_first_token_offset for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] n_token_in_expert = last_token_offset - first_token_offset - align_expert_first_token_offset[ - i] = align_expert_first_token_offset[ - i - 1] + (n_token_in_expert + align_block_size - - 1) // align_block_size * align_block_size + align_expert_first_token_offset[i] = ( + align_expert_first_token_offset[i - 1] + + (n_token_in_expert + align_block_size - 1) + // align_block_size + * align_block_size + ) align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + n_token_in_expert] + first_token_offset : first_token_offset + n_token_in_expert + ] # store token in 
current expert with align_first_token_offset - permuted_hidden_states[align_first_token_offset:\ - align_first_token_offset+n_token_in_expert,\ - ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert // topk,\ - ...] - permuted_idx[align_first_token_offset:\ - align_first_token_offset+\ - n_token_in_expert] = dst_row_id2src_row_id_in_expert + permuted_hidden_states[ + align_first_token_offset : align_first_token_offset + n_token_in_expert, + ..., + ] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...] + permuted_idx[ + align_first_token_offset : align_first_token_offset + n_token_in_expert + ] = dst_row_id2src_row_id_in_expert # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ - i for i in range(align_first_token_offset, - align_first_token_offset + n_token_in_expert) + i + for i in range( + align_first_token_offset, + align_first_token_offset + n_token_in_expert, + ) ] # get align_src_row_id2dst_row_id for i in range(n_token * topk): eid = sorted_topk_ids[i] - if (eid >= n_local_expert): + if eid >= n_local_expert: # check token not in local expert - align_src_row_id2dst_row_id[ - i] = align_expert_first_token_offset[-1] + align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1] continue first_token_offset = expert_first_token_offset[eid] align_first_token_offset = align_expert_first_token_offset[eid] token_offset = i - first_token_offset - align_src_row_id2dst_row_id[ - i] = align_first_token_offset + token_offset - align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ - src2dst_idx].reshape((n_token, topk)) + align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape( + (n_token, topk) + ) return [ - permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx + permuted_hidden_states, + align_expert_first_token_offset, + align_src_row_id2dst_row_id, + permuted_idx, + m_indices, + valid_row_idx, ] -def torch_unpermute(permuted_hidden_states: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - valid_row_idx: torch.Tensor, topk: int, - n_expert: int) -> torch.Tensor: +def torch_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, + topk: int, + n_expert: int, +) -> torch.Tensor: # ignore invalid row n_hidden = permuted_hidden_states.shape[1] - mask = torch.zeros(permuted_hidden_states.shape[0], - dtype=bool, - device="cuda") + mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 permuted_hidden_states = permuted_hidden_states[ - src_row_id2dst_row_id_map.flatten(), ...] + src_row_id2dst_row_id_map.flatten(), ... 
+ ] permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden) - output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to( - permuted_hidden_states.dtype) + output = ( + (permuted_hidden_states * topk_weights.unsqueeze(2)) + .sum(1) + .to(permuted_hidden_states.dtype) + ) return output @@ -184,59 +203,76 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("align_block_size", [None, 128]) -def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, - n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: Optional[int]): +def test_moe_permute_unpermute( + n_token: int, + n_hidden: int, + topk: int, + n_expert: int, + ep_size: int, + dtype: torch.dtype, + align_block_size: Optional[int], +): if not moe_permute_unpermute_supported(): pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None n_local_expert = n_expert - if (ep_size != 1): - n_local_expert, expert_map = determine_expert_map( - ep_size, ep_rank, n_expert) + if ep_size != 1: + n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert) expert_map = expert_map.cuda() start_expert = n_local_expert * ep_rank current_platform.seed_everything(0) hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, False) - (gold_permuted_hidden_states, gold_expert_first_token_offset, - gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices, - valid_row_idx) = torch_permute( - hidden_states, - topk_ids, - # token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) + hidden_states, gating_output, topk, False + ) + ( + gold_permuted_hidden_states, + gold_expert_first_token_offset, + gold_inv_permuted_idx, + gold_permuted_idx, + gold_m_indices, + valid_row_idx, + ) = torch_permute( + hidden_states, + topk_ids, + # token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert, + ) - (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx, - m_indices) = moe_permute(hidden_states=hidden_states, - a1q_scale=None, - topk_ids=topk_ids, - n_expert=n_expert, - n_local_expert=n_local_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) + ( + permuted_hidden_states, + _, + expert_first_token_offset, + inv_permuted_idx, + m_indices, + ) = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + n_local_expert=n_local_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert, + ) # check expert_first_token_offset - torch.testing.assert_close(gold_expert_first_token_offset, - expert_first_token_offset, - atol=0, - rtol=0) + torch.testing.assert_close( + gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0 + ) # check src_row_id2dst_row_id_map - torch.testing.assert_close(gold_inv_permuted_idx.flatten(), - inv_permuted_idx, - atol=0, - rtol=0) + 
torch.testing.assert_close( + gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0 + ) # check mindice # current kernel usage assumes deepgemm requires align_block_size # when it's not provided then we don't compute m_indices (for cutlass) @@ -244,19 +280,28 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], - permuted_hidden_states[valid_row_idx], - atol=0, - rtol=0) + torch.testing.assert_close( + gold_permuted_hidden_states[valid_row_idx], + permuted_hidden_states[valid_row_idx], + atol=0, + rtol=0, + ) # add a random tensor to simulate group gemm - result0 = 0.5 * permuted_hidden_states + torch.randn_like( - permuted_hidden_states) + result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states) result4 = torch.empty_like(hidden_states) - moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, - expert_first_token_offset) + moe_unpermute( + result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset + ) - gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, inv_permuted_idx, - valid_row_idx, topk, n_local_expert) + gold4 = torch_unpermute( + result0, + topk_weights, + topk_ids, + token_expert_indices, + inv_permuted_idx, + valid_row_idx, + topk, + n_local_expert, + ) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py deleted file mode 100644 index 7bd1ffce58e96..0000000000000 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ /dev/null @@ -1,475 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import importlib -import importlib.metadata -from dataclasses import dataclass -from typing import Optional - -import pytest -import torch -from packaging import version - -from vllm.platforms import current_platform - -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( - "quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') - -TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( -) and current_platform.is_device_capability(100) - -if TRTLLM_GEN_MXFP4_AVAILABLE: - from flashinfer import (fp4_quantize, mxfp8_quantize, - next_positive_power_of_2, - reorder_rows_for_gated_act_gemm, shuffle_matrix_a, - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) - - -@dataclass -class ModelCase: - model_id: str - tp: int - - -@pytest.mark.parametrize('model_case', [ - ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), - ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), - ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1) -]) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") -def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): - if torch.cuda.device_count() < model_case.tp: - pytest.skip(f"This test requires >={model_case.tp} gpus, got only " - f"{torch.cuda.device_count()}") - - with vllm_runner(model_case.model_id, - tensor_parallel_size=model_case.tp, - load_format="dummy") as llm: - - # TODO: llm.apply_model(check_model) currently relies on V0 internals. - # Re-enable this later. 
- # def check_model(model): - # layer = model.model.layers[0] - - # qkv_proj = layer.self_attn.qkv_proj - - # assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) - # assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) - - # assert isinstance(layer.mlp.experts.quant_method, - # QuarkW4A4MXFp4MoEMethod) - - # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": - # llm.apply_model(check_model) - - output = llm.generate_greedy("Today I am in the French Alps and", - max_tokens=20) - assert output - - -def swiglu(x, - alpha: float = 1.702, - beta: float = 1.0, - limit: Optional[float] = None): - # Note we add an extra bias of 1 to the linear layer - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - if limit is not None: - x_glu = x_glu.clamp(max=limit) - x_linear = x_linear.clamp(min=-limit, max=limit) - out_glu = x_glu * torch.sigmoid(alpha * x_glu) - return out_glu * (x_linear + beta) - - -fp4_lookup_table = [ - 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 -] - - -def mxfp4_dequantize(x, scale): - assert x.dtype == torch.uint8 - x = x.view(torch.uint8).to(torch.int32) - x_unpacked = torch.zeros(*x.shape[:-1], - x.shape[-1] * 2, - dtype=torch.int32, - device=x.device) - x_unpacked[..., 0::2].copy_(x & 0xF) - x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) - - x_float = torch.zeros(x_unpacked.shape, - dtype=torch.float32, - device=x.device) - for i, val in enumerate(fp4_lookup_table): - x_float[x_unpacked == i] = val - - scale = scale.view(torch.uint8).to(torch.int32) - scale = (scale << 23).view(torch.float32) - scale = scale.reshape(*x.shape[:-1], -1) - scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) - - return x_float * scale - - -def mxfp8_dequantize(x, scale): - assert x.dtype == torch.float8_e4m3fn - x_float = x.to(torch.float32) - - scale = scale.view(torch.uint8).to(torch.int32) - scale = (scale << 23).view(torch.float32) - scale = scale.reshape(*x.shape[:-1], -1) - scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape) - - return x_float * scale - - -def reference_moe( - roouting_logits, - topk, - num_experts, - hidden_states, - w13, - bias13, - w2, - bias2, - alpha, - beta, - limit, - act_type, -): - # renormalize routing - experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) - expert_weights = torch.nn.functional.softmax(experts.values, dim=1) - expert_indices = experts.indices - t = hidden_states.clone() - # MLP #1 - mlp1_weight = w13[expert_indices, ...] - mlp1_bias = bias13[expert_indices, ...] - t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias - t = swiglu(t, alpha=alpha, beta=beta, limit=limit) - - if act_type == 'mxfp8': - t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), - is_sf_swizzled_layout=False) - t = mxfp8_dequantize(t_quantized, t_scale) - # MLP #2 - mlp2_weight = w2[expert_indices, ...] - mlp2_bias = bias2[expert_indices, ...] - t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias - # Weighted sum of experts - t = torch.einsum("bec,be->bc", t, expert_weights) - assert t.shape == hidden_states.shape - return t.to(torch.bfloat16) - - -def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int): - # Number of tokens in the input tensor. - num_tokens = x.shape[0] - # Factor to account for the imbalance of the experts. - # factor equals to the - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert - # - 1.0 means perfect expert distribution. - # - > 1.0 means some experts have more - # tokens than the perfect distribution. 
- # - < 1.0 does not make sense. - imbalance_factor = 1.3 - # Calculate the number of tokens per expert - # assuming perfect distribution. - num_tokens_per_expert = (num_tokens * top_k) // num_experts - # Apply the imbalance factor. - num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) - # And pad the number to the next power of 2. - tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile - # as it's the range supported by the kernel. - tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - return tile_tokens_dim - - -def tg_mxfp4_moe( - router_logits, - topk, - num_experts, - intermediate_size, - hidden_size, - hidden_states, - hidden_states_scale, - w13_weight, - w13_weight_scale, - w13_bias, - w2_weight, - w2_weight_scale, - w2_bias, - act_type, - alpha, - beta, - limit, -) -> torch.Tensor: - sf_block_size = 32 - assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts - and w13_weight.shape[1] == intermediate_size * 2 - and w13_weight.shape[2] == hidden_size // 2) - assert (w13_weight_scale.dim() == 3 - and w13_weight_scale.shape[0] == num_experts - and w13_weight_scale.shape[1] == intermediate_size * 2 - and w13_weight_scale.shape[2] == hidden_size // sf_block_size) - assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts - and w2_weight.shape[1] == hidden_size - and w2_weight.shape[2] == intermediate_size // 2) - assert (w2_weight_scale.dim() == 3 - and w2_weight_scale.shape[1] == hidden_size - and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) - assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts - and w13_bias.shape[1] == intermediate_size * 2) - assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts - and w2_bias.shape[1] == hidden_size) - - # Swap w1 and w3 as the defenition of - # swiglu is different in the trtllm-gen - w13_weight_scale_ = w13_weight_scale.clone() - w13_weight_ = w13_weight.clone() - w13_bias_ = w13_bias.clone() - w13_weight[:, :intermediate_size, :].copy_( - w13_weight_[:, intermediate_size:, :]) - w13_weight[:, intermediate_size:, :].copy_( - w13_weight_[:, :intermediate_size, :]) - w13_weight_scale[:, :intermediate_size, :].copy_( - w13_weight_scale_[:, intermediate_size:, :]) - w13_weight_scale[:, intermediate_size:, :].copy_( - w13_weight_scale_[:, :intermediate_size, :]) - w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) - w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) - - # Interleave the weights and scaling factors for activation - w13_weight_interleaved = [] - w13_weight_scale_interleaved = [] - w13_bias_interleaved = [] - for i in range(num_experts): - w13_weight_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) - w13_weight_scale_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) - w13_bias_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, - 1))) - w13_weight = torch.stack(w13_weight_interleaved).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2) - w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( - num_experts, 2 * intermediate_size, hidden_size // 32) - w13_bias = torch.stack(w13_bias_interleaved).reshape( - num_experts, 2 * intermediate_size) - - # Shuffle weights and scaling factors for transposed mma output - gemm1_weights_shuffled = [] - gemm1_scales_shuffled = [] - gemm2_weights_shuffled = [] - gemm2_scales_shuffled = [] - gemm1_bias_shuffled = [] 
- gemm2_bias_shuffled = [] - epilogue_tile_m = 128 # FIXME: this depends on the kernel internals - for i in range(num_experts): - gemm1_weights_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm1_scales_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - - gemm2_weights_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm2_scales_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) - gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) - - w13_weight = torch.stack(gemm1_weights_shuffled) - w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( - num_experts, 2 * intermediate_size, - hidden_size // sf_block_size).view(torch.float8_e4m3fn) - w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) - - w2_weight = torch.stack(gemm2_weights_shuffled) - w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( - num_experts, hidden_size, - intermediate_size // sf_block_size).view(torch.float8_e4m3fn) - w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) - - tg_result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits.to(torch.bfloat16), - routing_bias=None, - hidden_states=hidden_states, - hidden_states_scale=hidden_states_scale, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale, - gemm1_bias=w13_bias, - gemm1_alpha=alpha, - gemm1_beta=beta, - gemm1_clamp_limit=limit, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale, - gemm2_bias=w2_bias, - output1_scale_scalar=None, - output1_scale_gate_scalar=None, - output2_scale_scalar=None, - num_experts=num_experts, - top_k=topk, - n_group=None, - topk_group=None, - intermediate_size=intermediate_size, - local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=None, - tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), - routing_method_type=1, # renormalize - do_finalize=True)[0] - return tg_result - - -def check_accuracy(a, b, atol, rtol, percent): - """Allow a mismatch percentage of 1 - percent.""" - if torch.any(torch.isnan(a)): - raise Exception("NaN in reference output") - if torch.any(torch.isnan(b)): - raise Exception("NaN in actual output") - if torch.any(torch.isinf(a)): - raise Exception("Inf in reference output") - if torch.any(torch.isinf(b)): - raise Exception("Inf in actual output") - assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" - - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() - if mismatch_percent > 1 - percent: - raise Exception( - f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " - f"(threshold: {1-percent:.4f})") - - -@pytest.mark.parametrize("topk", [1, 4]) -@pytest.mark.parametrize("num_experts", [32, 128]) -@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) -@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) -@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), - (1.702, 1.0, 7.0)]) -@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) -@pytest.mark.skipif( - not TRTLLM_GEN_MXFP4_AVAILABLE, - reason="nvidia gpu and compute capability sm100 is required for this test") -def test_trtllm_gen_mxfp4_fused_moe( - topk: int, - num_experts: int, - 
num_tokens: int, - intermediate_size: int, - hidden_size: int, - alpha: float, - beta: float, - limit: Optional[float], - act_type: str, -): - seed = 42 - torch.manual_seed(seed) - hidden_states = torch.randn(num_tokens, - hidden_size, - device="cuda:0", - dtype=torch.bfloat16) - w13 = (torch.randn(num_experts, - intermediate_size * 2, - hidden_size, - device="cuda:0", - dtype=torch.bfloat16)) - w2 = (torch.randn(num_experts, - hidden_size, - intermediate_size, - device="cuda:0", - dtype=torch.bfloat16)) - bias13 = torch.randn(num_experts, intermediate_size * 2, - device="cuda:0") * 10 - bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 - router_logits = torch.rand(num_tokens, num_experts, - dtype=torch.float32).cuda() - - w13, w13_scale = fp4_quantize(w13, - torch.tensor(1.0, device="cuda:0"), - 32, - sf_use_ue8m0=True, - is_sf_swizzled_layout=False) - w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( - num_experts, intermediate_size * 2, hidden_size // 32) - w2, w2_scale = fp4_quantize(w2, - torch.tensor(1.0, device="cuda:0"), - 32, - sf_use_ue8m0=True, - is_sf_swizzled_layout=False) - w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( - num_experts, hidden_size, intermediate_size // 32) - if act_type == 'mxfp8': - hidden_states, hidden_states_scale = mxfp8_quantize( - hidden_states, is_sf_swizzled_layout=False) - hidden_states_scale = hidden_states_scale.view( - torch.float8_e4m3fn).reshape(-1) - else: - hidden_states_scale = None - - # reference result - ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) - w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) - w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) - bias13_ref = bias13 - bias2_ref = bias2 - if act_type == 'mxfp8': - hidden_states_ref = mxfp8_dequantize( - hidden_states, hidden_states_scale).to(torch.float32) - else: - hidden_states_ref = hidden_states.to(torch.float32) - # Process tokens in chunks of 32 to reduce memory usage - chunk_size = 32 - num_chunks = (num_tokens + chunk_size - 1) // chunk_size - for i in range(num_chunks): - start_idx = i * chunk_size - end_idx = min(start_idx + chunk_size, num_tokens) - chunk_result = reference_moe( - router_logits[start_idx:end_idx].to(torch.float32), - topk, - num_experts, - hidden_states_ref[start_idx:end_idx], - w13_ref, - bias13_ref, - w2_ref, - bias2_ref, - alpha, - beta, - limit, - act_type, - ) - ref_result[start_idx:end_idx].copy_(chunk_result) - - # trtllm-gen result - if alpha is not None: - alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) - if limit is not None: - limit = torch.full((num_experts, ), limit, device=hidden_states.device) - if beta is not None: - beta = torch.full((num_experts, ), beta, device=hidden_states.device) - tg_result = tg_mxfp4_moe(router_logits, - topk, - num_experts, - intermediate_size, - hidden_size, - hidden_states, - hidden_states_scale, - w13, - w13_scale, - bias13, - w2, - w2_scale, - bias2, - act_type, - alpha=alpha, - beta=beta, - limit=limit) - # relatively loose check since the mxfp4 quantization is less accurate - check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 30388ef9375d4..dae19c0b2b31b 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -4,19 +4,23 @@ import pytest import torch from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, 
- FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip("Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + "Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True + ) MNK_FACTORS = [ (2, 1024, 1024), @@ -37,54 +41,56 @@ MNK_FACTORS = [ @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() -def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype): +def test_cutlass_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): current_platform.seed_everything(7) with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): quant_blocksize = 16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - (_, w1_q, w1_blockscale, - w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( - e, - n, - k, - in_dtype=dtype, - quant_dtype="nvfp4", - block_shape=None, # use quant_blocksize? - per_act_token_quant=False, - ) + (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = ( + make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? 
+ per_out_ch_quant=False, + ) + ) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) assert w1_gs is not None assert w2_gs is not None assert w1_blockscale is not None assert w2_blockscale is not None + quant_config = nvfp4_moe_quant_config( + g1_alphas=(1 / w1_gs), + g2_alphas=(1 / w2_gs), + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + cutlass_output = cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_q, - w1_blockscale=w1_blockscale, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, w2_fp4=w2_q, - w2_blockscale=w2_blockscale, - g2_alphas=(1 / w2_gs), topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=quant_config, m=m, n=n, k=k, @@ -92,40 +98,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ) # Reference check: - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=a.dtype, - device=a.device, - block_size=quant_blocksize) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) - torch.testing.assert_close(torch_output, - cutlass_output, - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py new file mode 100644 index 0000000000000..dceed34f35125 --- /dev/null +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -0,0 +1,994 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib.metadata +from dataclasses import dataclass +from importlib.util import find_spec +from typing import Optional + +import pytest +import torch +from packaging import version + +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer + +QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and 
version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") + +TRTLLM_GEN_MXFP4_AVAILABLE = ( + current_platform.is_cuda() and current_platform.is_device_capability(100) +) + +HOPPER_MXFP4_BF16_AVAILABLE = ( + current_platform.is_cuda() + and current_platform.is_device_capability(90) + and has_flashinfer() +) + +if TRTLLM_GEN_MXFP4_AVAILABLE: + from flashinfer import ( + fp4_quantize, + mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + trtllm_fp4_block_scale_moe, + ) + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@pytest.fixture(scope="function", autouse=True) +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + +@pytest.mark.parametrize( + "model_case", + [ + ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2), + ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), + ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1), + ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1), + ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4), + ], +) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): + if torch.cuda.device_count() < model_case.tp: + pytest.skip( + f"This test requires >={model_case.tp} gpus, got only " + f"{torch.cuda.device_count()}" + ) + + # `cuda_graph_sizes=[16]` to reduce load time. + with vllm_runner( + model_case.model_id, + tensor_parallel_size=model_case.tp, + load_format="dummy", + cuda_graph_sizes=[16], + ) as llm: + # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562 + # def check_model(model): + # from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 + # QuarkLinearMethod) + # from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX # noqa: E501 + # from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + # QuarkOCP_MX_MoEMethod) + + # layer = model.model.layers[0] + + # qkv_proj = layer.self_attn.qkv_proj + + # assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + # assert isinstance(qkv_proj.scheme, QuarkOCP_MX) + + # assert isinstance(layer.mlp.experts.quant_method, + # QuarkOCP_MX_MoEMethod) + + # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": + # llm.apply_model(check_model) + + output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) + assert output + + +def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu * (x_linear + beta) + + +fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6] + + +def mxfp4_dequantize(x, scale): + assert x.dtype == torch.uint8 + x = x.view(torch.uint8).to(torch.int32) + x_unpacked = torch.zeros( + *x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device + ) + x_unpacked[..., 
0::2].copy_(x & 0xF)
+    x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
+
+    x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
+    for i, val in enumerate(fp4_lookup_table):
+        x_float[x_unpacked == i] = val
+
+    scale = scale.view(torch.uint8).to(torch.int32)
+    scale = (scale << 23).view(torch.float32)
+    scale = scale.reshape(*x.shape[:-1], -1)
+    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
+
+    return x_float * scale
+
+
+def mxfp8_dequantize(x, scale):
+    assert x.dtype == torch.float8_e4m3fn
+    x_float = x.to(torch.float32)
+
+    scale = scale.view(torch.uint8).to(torch.int32)
+    scale = (scale << 23).view(torch.float32)
+    scale = scale.reshape(*x.shape[:-1], -1)
+    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
+
+    return x_float * scale
+
+
+def reference_moe(
+    routing_logits,
+    topk,
+    num_experts,
+    hidden_states,
+    w13,
+    bias13,
+    w2,
+    bias2,
+    alpha,
+    beta,
+    limit,
+    act_type,
+):
+    # renormalize routing
+    experts = torch.topk(routing_logits, k=topk, dim=-1, sorted=True)
+    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
+    expert_indices = experts.indices
+    t = hidden_states.clone()
+    # MLP #1
+    mlp1_weight = w13[expert_indices, ...]
+    mlp1_bias = bias13[expert_indices, ...]
+    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
+    t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
+
+    if act_type == "mxfp8":
+        t_quantized, t_scale = mxfp8_quantize(
+            t.to(torch.bfloat16), is_sf_swizzled_layout=False
+        )
+        t = mxfp8_dequantize(t_quantized, t_scale)
+    # MLP #2
+    mlp2_weight = w2[expert_indices, ...]
+    mlp2_bias = bias2[expert_indices, ...]
+    t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
+    # Weighted sum of experts
+    t = torch.einsum("bec,be->bc", t, expert_weights)
+    assert t.shape == hidden_states.shape
+    return t.to(torch.bfloat16)
+
+
+def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
+    # Number of tokens in the input tensor.
+    num_tokens = x.shape[0]
+    # Factor to account for the imbalance of the experts.
+    # The factor equals
+    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
+    # - 1.0 means perfect expert distribution.
+    # - > 1.0 means some experts have more
+    #   tokens than the perfect distribution.
+    # - < 1.0 does not make sense.
+    imbalance_factor = 1.3
+    # Calculate the number of tokens per expert
+    # assuming perfect distribution.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # Apply the imbalance factor.
+    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile
+    # as it's the range supported by the kernel.
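+    # Worked example (illustrative values): num_tokens=128, top_k=4,
+    # num_experts=32 -> 128 * 4 // 32 = 16 tokens per expert, * 1.3 -> 20,
+    # next power of 2 -> 32, clamped to [8, 64] -> tile_tokens_dim = 32.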
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13_weight, + w13_weight_scale, + w13_bias, + w2_weight, + w2_weight_scale, + w2_bias, + act_type, + alpha, + beta, + limit, + transpose_optimized: bool = False, +) -> torch.Tensor: + sf_block_size = 32 + assert ( + w13_weight.dim() == 3 + and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2 + ) + assert ( + w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size + ) + assert ( + w2_weight.dim() == 3 + and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2 + ) + assert ( + w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size + ) + assert ( + w13_bias.dim() == 2 + and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2 + ) + assert ( + w2_bias.dim() == 2 + and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size + ) + + # Swap w1 and w3 as the definition of + # swiglu is different in the trtllm-gen + w13_weight_scale_ = w13_weight_scale.clone() + w13_weight_ = w13_weight.clone() + w13_bias_ = w13_bias.clone() + w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :]) + w13_weight_scale[:, :intermediate_size, :].copy_( + w13_weight_scale_[:, intermediate_size:, :] + ) + w13_weight_scale[:, intermediate_size:, :].copy_( + w13_weight_scale_[:, :intermediate_size, :] + ) + w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) + w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) + + # Interleave the weights and scaling factors for activation + w13_weight_interleaved = [] + w13_weight_scale_interleaved = [] + w13_bias_interleaved = [] + for i in range(num_experts): + w13_weight_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight[i].clone()) + ) + w13_weight_scale_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()) + ) + w13_bias_interleaved.append( + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1)) + ) + w13_weight = torch.stack(w13_weight_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2 + ) + w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size // 32 + ) + w13_bias = torch.stack(w13_bias_interleaved).reshape( + num_experts, 2 * intermediate_size + ) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + if transpose_optimized: + for i in range(num_experts): + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + 
gemm1_weights_shuffled.append( + w13_weight[i] + .view(torch.uint8)[permute_indices.to(w13_weight.device)] + .contiguous() + ) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm1_scales_shuffled.append( + nvfp4_block_scale_interleave( + w13_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)] + .contiguous() + ) + ) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append( + w13_bias[i] + .clone() + .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] + .contiguous() + ) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_shuffled.append( + w2_weight[i] + .view(torch.uint8)[permute_indices.to(w2_weight.device)] + .contiguous() + ) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + nvfp4_block_scale_interleave( + w2_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)] + .contiguous() + ) + ) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append( + w2_bias[i] + .clone() + .reshape(-1, 1)[permute_indices.to(w2_bias.device)] + .contiguous() + ) + + else: + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m) + ) + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a( + w13_weight_scale[i].view(torch.uint8), epilogue_tile_m + ) + ) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m) + ) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a( + w2_weight_scale[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m) + ) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m) + ) + + w13_weight = torch.stack(gemm1_weights_shuffled) + w13_weight_scale = ( + torch.stack(gemm1_scales_shuffled) + .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) + w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) + + w2_weight = torch.stack(gemm2_weights_shuffled) + w2_weight_scale = ( + torch.stack(gemm2_scales_shuffled) + .reshape(num_experts, hidden_size, intermediate_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) + w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) + + tg_result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale, + gemm1_bias=w13_bias, + gemm1_alpha=alpha, + gemm1_beta=beta, + gemm1_clamp_limit=limit, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale, + gemm2_bias=w2_bias, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + 
output2_scale_scalar=None, + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), + routing_method_type=1, # renormalize + do_finalize=True, + )[0] + return tg_result + + +def check_accuracy(a, b, atol, rtol, percent): + """Allow a mismatch percentage of 1 - percent.""" + if torch.any(torch.isnan(a)): + raise Exception("NaN in reference output") + if torch.any(torch.isnan(b)): + raise Exception("NaN in actual output") + if torch.any(torch.isinf(a)): + raise Exception("Inf in reference output") + if torch.any(torch.isinf(b)): + raise Exception("Inf in actual output") + assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" + + left = torch.abs(a - b) + right = atol + rtol * torch.abs(b) + count = torch.sum(left > right) + mismatch_percent = count / a.numel() + if mismatch_percent > 1 - percent: + raise Exception( + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " + f"(threshold: {1 - percent:.4f})" + ) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32, 128]) +@pytest.mark.parametrize("num_tokens", [1, 128, 1024]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"]) +@pytest.mark.parametrize("transpose_optimized", [False, True]) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP4_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test", +) +def test_trtllm_gen_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], + act_type: str, + transpose_optimized: bool, +): + seed = 42 + torch.manual_seed(seed) + hidden_states = torch.randn( + num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16 + ) + w13 = torch.randn( + num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16, + ) + w2 = torch.randn( + num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16, + ) + bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10 + bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 + router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize( + w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False, + ) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, hidden_size // 32 + ) + w2, w2_scale = fp4_quantize( + w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False, + ) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 32 + ) + if act_type == "mxfp8": + hidden_states, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False + ) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1) + else: + hidden_states_scale = None + + # reference result + ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16) + w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone()) + w2_ref = mxfp4_dequantize(w2.clone(), 
w2_scale.clone()) + bias13_ref = bias13 + bias2_ref = bias2 + if act_type == "mxfp8": + hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to( + torch.float32 + ) + else: + hidden_states_ref = hidden_states.to(torch.float32) + # Process tokens in chunks of 32 to reduce memory usage + chunk_size = 32 + num_chunks = (num_tokens + chunk_size - 1) // chunk_size + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_result = reference_moe( + router_logits[start_idx:end_idx].to(torch.float32), + topk, + num_experts, + hidden_states_ref[start_idx:end_idx], + w13_ref, + bias13_ref, + w2_ref, + bias2_ref, + alpha, + beta, + limit, + act_type, + ) + ref_result[start_idx:end_idx].copy_(chunk_result) + + # trtllm-gen result + if alpha is not None: + alpha = torch.full((num_experts,), alpha, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts,), limit, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts,), beta, device=hidden_states.device) + tg_result = tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, + hidden_states_scale, + w13, + w13_scale, + bias13, + w2, + w2_scale, + bias2, + act_type, + alpha=alpha, + beta=beta, + limit=limit, + transpose_optimized=transpose_optimized, + ) + # relatively loose check since the mxfp4 quantization is less accurate + check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) + + +def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: + """Interleave scales on the last dimension by groups of 4, matching + the transformation in mxfp4.py's BF16 (Hopper) path.""" + s = scales.to(torch.uint8) + s_shape = s.shape + assert s_shape[-1] % 4 == 0 + s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4) + # Move the 4-group dimension before the row dimension + permuted = s.permute(0, 2, 1, 3) + # Merge the row dim with the 4-group dim + return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not HOPPER_MXFP4_BF16_AVAILABLE, + reason="nvidia gpu sm90 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn( + num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] + w13_q = torch.randint( + 0, + 256, + (num_experts, 2 * intermediate_size, hidden_size // 2), + device=device, + dtype=torch.uint8, + ) + w13_scale = torch.randint( + 118, + 123, + (num_experts, 2 * intermediate_size, hidden_size // 32), + device=device, + dtype=torch.uint8, + ) + + w2_q = torch.randint( + 0, + 256, + (num_experts, hidden_size, intermediate_size // 2), + device=device, + dtype=torch.uint8, + ) + w2_scale = torch.randint( + 118, + 123, + (num_experts, hidden_size, intermediate_size // 32), + device=device, + dtype=torch.uint8, + ) + # Bias contiguous [b1; b3] + bias13 = ( + 
torch.randn( + num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16 + ) + * 10 + ) + bias2 = ( + torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10 + ) + router_logits = torch.rand( + num_tokens, num_experts, dtype=torch.float32, device=device + ) + + w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( + num_experts, 2 * intermediate_size, hidden_size + ) + w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( + num_experts, hidden_size, intermediate_size + ) + ref = reference_moe( + router_logits.to(torch.float32), + topk, + num_experts, + hidden_states.to(torch.float32), + w13_ref, + bias13.to(torch.float32), + w2_ref, + bias2.to(torch.float32), + alpha, + beta, + limit, + "bf16", + ) + + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1) + w13_s = torch.cat([w3_s, w1_s], dim=1) + w13_s_inter = _interleave_scales_lastdim_by4(w13_s) + w2_s_inter = _interleave_scales_lastdim_by4(w2_scale) + + routing_weights = torch.nn.functional.softmax( + router_logits, dim=1, dtype=torch.float32 + ) + token_final_scales, token_selected_experts = torch.topk( + routing_weights, topk, dim=-1 + ) + token_final_scales = token_final_scales / token_final_scales.sum( + dim=-1, keepdim=True + ) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha = torch.full((num_experts,), alpha, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts,), beta, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts,), limit, device=hidden_states.device) + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped, + fc2_expert_weights=w2_q, + output_dtype=torch.bfloat16, + output=out, + quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)], + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha, + swiglu_beta=beta, + swiglu_limit=limit, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_w4_group_scaling=True, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and has_flashinfer() + ), + reason="NVIDIA GPU sm100 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: Optional[float], + beta: Optional[float], + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn( + num_tokens, 
hidden_size, device=device, dtype=torch.bfloat16 + ) + # Float weights in w13 format [w1; w3] + w13 = ( + torch.randn( + num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + / 10 + ) + w2 = ( + torch.randn( + num_experts, + hidden_size, + intermediate_size, + device=device, + dtype=torch.bfloat16, + ) + / 10 + ) + # Bias contiguous [b1; b3] + bias13 = ( + torch.randn( + num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16 + ) + * 10 + ) + bias2 = ( + torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10 + ) + router_logits = torch.rand( + num_tokens, num_experts, dtype=torch.float32, device=device + ) + + # Quantize weights to MXFP4 per expert (SM100 path) + from flashinfer import mxfp4_quantize + + def quant_mxfp4_batches(a: torch.Tensor, e: int): + qs, sfs = [], [] + for i in range(e): + q, sf = mxfp4_quantize(a[i].cuda()) + qs.append(q) + sfs.append(sf) + return torch.stack(qs), torch.stack(sfs) + + def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor): + num_batches = mat_fp4.size(0) + scale_tensor = scale_tensor.view(num_batches, -1) + from flashinfer import mxfp4_dequantize + + return torch.stack( + [ + mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) + for b in range(num_batches) + ] + ) + + w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts) + w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts) + + # Reference result using dequantized tensors and reference_moe + w13_ref = ( + dequant_mxfp4_batches( + w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1) + ) + .to(torch.float32) + .reshape(num_experts, 2 * intermediate_size, hidden_size) + .to(device) + ) + w2_ref = ( + dequant_mxfp4_batches( + w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1) + ) + .to(torch.float32) + .reshape(num_experts, hidden_size, intermediate_size) + .to(device) + ) + + # Quantize activations for SM100 path and dequantize for reference + hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32) + # Reference uses BF16 input but quantizes intermediate activation to MXFP8 + ref = reference_moe( + router_logits.to(torch.float32), + topk, + num_experts, + hidden_states.to(torch.float32), + w13_ref, + bias13.to(torch.float32), + w2_ref, + bias2.to(torch.float32), + alpha, + beta, + limit, + "mxfp8", + ) + + # Prepare inputs for FlashInfer CUTLASS fused MoE + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + # Swap scales halves to match swapped weights + s1, s3 = torch.chunk(w13_scale, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + # Build routing for kernel + routing_weights = torch.nn.functional.softmax( + router_logits, dim=1, dtype=torch.float32 + ) + token_final_scales, token_selected_experts = torch.topk( + routing_weights, topk, dim=-1 + ) + token_final_scales = token_final_scales / token_final_scales.sum( + dim=-1, keepdim=True + ) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device) + else: + alpha_t = None + if beta is not None: + beta_t = 
torch.full((num_experts,), beta, device=hidden_states.device) + else: + beta_t = None + if limit is not None: + limit_t = torch.full((num_experts,), limit, device=hidden_states.device) + else: + limit_t = None + + # Quant scales for SM100 MXFP8+MXFP4 path + fake_input_scale = torch.ones(num_experts, device=device) + quant_scales = [ + w13_scale_swapped.view(torch.int32), + fake_input_scale, + w2_scale.view(torch.int32), + fake_input_scale, + ] + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states_q, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long), + fc2_expert_weights=w2_q.contiguous().view(torch.long), + output_dtype=torch.bfloat16, + output=out, + quant_scales=quant_scales, + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha_t, + swiglu_beta=beta_t, + swiglu_limit=limit_t, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_mxfp8_act_scaling=True, + input_sf=hidden_states_sf, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 98908f2714707..4c7c6c6a4f529 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,21 +9,25 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassBatchedExpertsFp8) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils import cdiv +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, + ) + has_pplx = True except ImportError: has_pplx = False @@ -47,12 +51,12 @@ def chunk_by_rank(t, r, w): chunk = rank_chunk(num, r, w) rem = num % w if rem == 0 or r < rem: - return t[(r * chunk):(r + 1) * chunk].contiguous() + return t[(r * chunk) : (r + 1) * chunk].contiguous() else: long_chunks = (num // w + 1) * rem short_chunks = (r - rem) * chunk start = long_chunks + short_chunks - return t[start:start + chunk].contiguous() + return t[start : start + chunk].contiguous() def pplx_cutlass_moe( @@ -72,7 +76,9 @@ def pplx_cutlass_moe( group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) + assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape @@ -123,29 +129,40 @@ def pplx_cutlass_moe( ata, max_num_tokens=max_num_tokens, num_local_experts=num_local_experts, - num_dispatchers=num_dispatchers) + num_dispatchers=num_dispatchers, + ) - 
ab_strides1 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_local_experts, ), - intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_local_experts, ), - 2 * intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) + ab_strides1 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) + ab_strides2 = torch.full( + (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides1 = torch.full( + (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides2 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) - experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, - out_dtype, per_act_token, per_out_ch, - ab_strides1, ab_strides2, c_strides1, - c_strides2) + experts = CutlassBatchedExpertsFp8( + num_local_experts, + num_dispatchers, + out_dtype, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, + fp8_w8a8_moe_quant_config( + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + w1_scale=chunk_by_rank(w1_scale, rank, world_size), + w2_scale=chunk_by_rank(w2_scale, rank, world_size), + a1_scale=chunk_by_rank(a1_scale, rank, world_size) + if per_act_token + else a1_scale[rank], + ), + ) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -153,10 +170,10 @@ def pplx_cutlass_moe( ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - chunk_topk_weight = chunk_by_rank(topk_weights, rank, - world_size).to(device) - chunk_topk_ids = chunk_by_rank(topk_ids, rank, - world_size).to(torch.uint32).to(device) + chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device) + chunk_topk_ids = ( + chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device) + ) out = fused_cutlass_experts( a_chunk, @@ -165,11 +182,8 @@ def pplx_cutlass_moe( chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts, - expert_map=None, #TODO - w1_scale=chunk_by_rank(w1_scale, rank, world_size), - w2_scale=chunk_by_rank(w2_scale, rank, world_size), - a1_scale=chunk_by_rank(a1_scale, rank, world_size) - if per_act_token else a1_scale[rank]) + expert_map=None, # TODO + ) torch.cuda.synchronize() @@ -204,35 +218,48 @@ def _pplx_moe( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_experts(a_full, w1_full, w2_full, - topk_weights, topk_ids) - pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, - w2_scale, topk_weights, topk_ids, - a1_scale, out_dtype, per_act_token, - per_out_ch, group_name) + torch_output = torch_experts( + a_full, w1_full, w2_full, topk_weights, topk_ids + ) + pplx_output = pplx_cutlass_moe( + pgi, + dp_size, + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a1_scale, + out_dtype, + per_act_token, + per_out_ch, + group_name, + ) - 
torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pplx_output.device + ) # Uncomment if more debugging is needed # print("PPLX OUT:", pplx_output) # print("TORCH OUT:", torch_output) - torch.testing.assert_close(pplx_output, - torch_output, - atol=0.05, - rtol=0) + torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) finally: if use_internode: nvshmem_finalize() @@ -245,12 +272,15 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) +@multi_gpu_test(num_gpus=2) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) @requires_pplx def test_cutlass_moe_pplx( m: int, @@ -266,7 +296,6 @@ def test_cutlass_moe_pplx( current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - dtype = torch.half a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0 @@ -276,22 +305,18 @@ def test_cutlass_moe_pplx( n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn) w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) + w1[expert], use_per_token_if_dynamic=per_out_ch + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) + w2[expert], use_per_token_if_dynamic=per_out_ch + ) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -300,19 +325,35 @@ def test_cutlass_moe_pplx( w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) world_size, dp_size = world_dp_size - a_scale1 = torch.randn( - (m if per_act_token else 1, 1), device="cuda", - dtype=torch.float32) / 10.0 + a_scale1 = ( + torch.randn( + (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32 + ) + / 10.0 + ) if not per_act_token: a_scale1 = a_scale1.repeat(world_size, 1) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, - w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, - dtype, a, w1_d, w2_d, per_act_token, per_out_ch, - use_internode) + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a_scale1, + dtype, + 
a, + w1_d, + w2_d, + per_act_token, + per_out_ch, + use_internode, + ) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c2064de97358f..223f095c0b553 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,39 +4,50 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ + +import copy import itertools import textwrap import traceback -from typing import Callable, Optional +from typing import Callable, Optional, Union import pytest import torch try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, + ) + has_pplx = True except ImportError: has_pplx = False -from tests.kernels.moe.utils import make_test_weights, naive_batched_moe +from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config +from tests.kernels.moe.utils import ( + make_shared_experts, + make_test_weights, + naive_batched_moe, +) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.platforms import current_platform from vllm.utils import round_up +from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch requires_pplx = pytest.mark.skipif( @@ -53,8 +64,8 @@ BATCHED_MOE_MNK_FACTORS = [ ] PPLX_COMBOS = [ - # TODO: figure out why this fails, seems to be test problem - #(1, 128, 128), + # TODO(bnell): figure out why this fails, seems to be test problem + # (1, 128, 128), (2, 128, 512), (3, 1024, 2048), (4, 128, 128), @@ -86,17 +97,16 @@ def torch_prepare( num_tokens, hidden_dim = a.shape topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) assert tokens_per_expert.numel() == num_experts if max_num_tokens is None: max_num_tokens = int(tokens_per_expert.max().item()) - b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), - dtype=a.dtype, - device=a.device) + b_a = torch.zeros( + (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device + ) token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) @@ -104,28 +114,29 @@ def torch_prepare( for j in range(topk): expert_id = topk_ids[token, j] idx = token_counts[expert_id] - b_a[expert_id, idx:idx + 1, :] = a[token, :] + b_a[expert_id, idx : idx + 1, :] = a[token, :] token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert -def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, - 
topk_ids: torch.Tensor) -> torch.Tensor: +def torch_finalize( + b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor +) -> torch.Tensor: num_tokens = topk_ids.shape[0] num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, - dtype=torch.int, - device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx + - 1, :] * topk_weight[token, i] + out[token, :] = ( + out[token, :] + + b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i] + ) expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -144,17 +155,18 @@ def torch_batched_moe( num_tokens, topk = topk_ids.shape _, max_num_tokens, K = b_a.shape assert num_experts == b_a.shape[0] and w2.shape[1] == K - out = torch.zeros((num_experts, max_num_tokens, K), - dtype=b_a.dtype, - device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), - dtype=b_a.dtype, - device=b_a.device) + out = torch.zeros( + (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device + ) + tmp = torch.empty( + (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device + ) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: torch.ops._C.silu_and_mul( - tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1) + ) out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) return torch_finalize(out, topk_weight, topk_ids) @@ -181,20 +193,16 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts(a, w1, w2, topk_weight, - topk_ids) # only for baseline + baseline_output = torch_experts( + a, w1, w2, topk_weight, topk_ids + ) # only for baseline torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = naive_batched_moe( - a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this + a, w1, w2, topk_weight, topk_ids + ) # pick torch_experts or this - torch.testing.assert_close(baseline_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(baseline_output, - batched_output, - atol=2e-2, - rtol=0) + torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) def create_pplx_prepare_finalize( @@ -212,7 +220,9 @@ def create_pplx_prepare_finalize( group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) num_local_experts = rank_chunk(num_experts, 0, world_size) @@ -261,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int: def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] -def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def maybe_chunk_by_rank( + t: Optional[torch.Tensor], r: int, 
w: int +) -> Optional[torch.Tensor]: if t is not None: return chunk_by_rank(t, r, w) else: return t -def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def chunk_scales_by_rank( + t: Optional[torch.Tensor], r: int, w: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] else: return t -def chunk_scales(t: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + t: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: return t[start:end] else: @@ -345,8 +358,7 @@ def pplx_prepare_finalize( device=device, ) - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -355,23 +367,22 @@ def pplx_prepare_finalize( b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, - a1_scale, - a2_scale, chunk_topk_weight, chunk_topk_ids, num_experts, None, False, - FusedMoEQuantConfig( + FusedMoEQuantConfig.make( quant_dtype, - per_act_token_quant, - False, - block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=block_shape, + a1_scale=a1_scale, + a2_scale=a2_scale, ), ) - b_a = dummy_work( - dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -405,15 +416,17 @@ def _pplx_prepare_finalize( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -421,22 +434,28 @@ def _pplx_prepare_finalize( a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) - torch_output = (a_rep.view(m, topk, k) * - topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum( - dim=1) + torch_output = ( + a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype) + ).sum(dim=1) - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, - topk_ids, num_experts, quant_dtype, - block_shape, per_act_token_quant, - group_name) + pplx_output = pplx_prepare_finalize( + pgi, + dp_size, + a, + topk_weight, + topk_ids, + num_experts, + quant_dtype, + block_shape, + per_act_token_quant, + group_name, + ) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pgi.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pgi.device + ) - torch.testing.assert_close(pplx_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) finally: if use_internode: nvshmem_finalize() @@ -452,6 +471,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("use_internode", 
[False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize_slow( mnk: tuple[int, int, int], e: int, @@ -485,9 +505,19 @@ def test_pplx_prepare_finalize_slow( a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, - topk, e, quant_dtype, block_shape, per_act_token_quant, - use_internode) + parallel_launch( + world_size, + _pplx_prepare_finalize, + dp_size, + a, + score, + topk, + e, + quant_dtype, + block_shape, + per_act_token_quant, + use_internode, + ) def pplx_moe( @@ -509,8 +539,8 @@ def pplx_moe( block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, -) -> torch.Tensor: - + shared_experts: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] topk = topk_ids.shape[1] @@ -533,19 +563,6 @@ def pplx_moe( topk_ids = topk_ids.to(dtype=torch.uint32) - experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - ) - - fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - # Note: workers with the same dp_rank must use the exact same inputs. a_chunk = chunk_by_rank(a, rank, world_size) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size) @@ -559,45 +576,67 @@ def pplx_moe( a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, + ) + + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=quant_config, + ) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts, + ) + # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. 
if use_compile: - _fused_experts = torch.compile(fused_experts, - backend='inductor', - fullgraph=True) + _fused_experts = torch.compile( + fused_experts, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a_chunk, 0) torch._dynamo.mark_dynamic(chunk_topk_weight, 0) torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts, + ) if use_cudagraphs: - out.fill_(0) + if isinstance(out, tuple): + out[0].fill_(0) + out[1].fill_(0) + else: + out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts, + ) torch.cuda.synchronize() graph.replay() @@ -624,18 +663,21 @@ def _pplx_moe( per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, use_internode: bool = False, + shared_experts: Optional[torch.nn.Module] = None, ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name m, k = a.shape @@ -653,8 +695,7 @@ def _pplx_moe( w1_s = w1_s.to(device) if w1_s is not None else None w2_s = w2_s.to(device) if w2_s is not None else None - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -664,6 +705,8 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + shared_output = shared_experts(a) if shared_experts is not None else None + torch_output = torch_experts( a, w1, @@ -694,7 +737,7 @@ def _pplx_moe( block_shape=block_shape, ) - pplx_output = pplx_moe( + pplx_outputs = pplx_moe( group_name, rank, world_size, @@ -711,20 +754,41 @@ def _pplx_moe( quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, + shared_experts=shared_experts, ) + if shared_experts is None: + pplx_shared_output = None + pplx_output = pplx_outputs + assert isinstance(pplx_output, torch.Tensor) + else: + pplx_shared_output, pplx_output = pplx_outputs + + if shared_output is not None: + assert pplx_shared_output is not None + chunked_shared_output = chunk_by_rank( + shared_output, pgi.rank, pgi.world_size + ).to(pplx_shared_output.device) + else: + 
chunked_shared_output = None + chunked_batch_output = chunk_by_rank( - batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) + batched_output, pgi.rank, pgi.world_size + ).to(pplx_output.device) - torch.testing.assert_close(batched_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) + + torch.testing.assert_close( + pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2 + ) + + if shared_experts is not None: + assert chunked_shared_output is not None + assert pplx_shared_output is not None + torch.testing.assert_close( + pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2 + ) - torch.testing.assert_close(pplx_output, - chunked_batch_output, - atol=3e-2, - rtol=3e-2) finally: if use_internode: nvshmem_finalize() @@ -740,6 +804,7 @@ def _pplx_moe( @pytest.mark.parametrize("use_internode", [False]) @pytest.mark.optional @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe_slow( mnk: tuple[int, int, int], e: int, @@ -776,17 +841,36 @@ def test_pplx_moe_slow( k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, - w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, - use_internode) + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1, + w2, + score, + topk, + e, + w1_s, + w2_s, + quant_dtype, + per_act_token_quant, + block_shape, + use_internode, + ) -def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - make_weights: bool, test_fn: Callable): - +def _pplx_test_loop( + pgi: ProcessGroupInfo, + dp_size: int, + use_internode: bool, + use_shared_experts: bool, + make_weights: bool, + test_fn: Callable, +): def format_result(msg, ex=None): if ex is not None: x = str(ex) @@ -800,9 +884,17 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, else: print(f"PASSED {msg}") + if use_shared_experts: + # Note: this config is only needed for the non-naive shared experts. + new_vllm_config = copy.deepcopy(vllm_config) + new_vllm_config.parallel_config.data_parallel_size = pgi.world_size + new_vllm_config.parallel_config.enable_expert_parallel = True + _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank) + current_platform.seed_everything(7) - combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - [False, True], [None, [128, 128]]) + combos = itertools.product( + PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]] + ) exceptions = [] count = 0 for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: @@ -816,15 +908,15 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, use_fp8_w8a8 = False quant_dtype = None - test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " - f"dtype={dtype}, per_act_token={per_act_token_quant}, " - f"block_shape={block_shape}") + test_desc = ( + f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}, use_internode={use_internode}, " + f"use_shared_experts={use_shared_experts}" + ) - if not use_fp8_w8a8 and (per_act_token_quant - or block_shape is not None): - print( - f"{test_desc} - Skip quantization test for non-quantized type." 
- ) + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): + print(f"{test_desc} - Skip quantization test for non-quantized type.") continue if per_act_token_quant and block_shape is not None: @@ -842,13 +934,21 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) args["w1"] = w1 args["w2"] = w2 args["w1_s"] = w1_s args["w2_s"] = w2_s + if use_shared_experts: + args["shared_experts"] = make_shared_experts( + n, + k, + in_dtype=a.dtype, + quant_dtype=quant_dtype, + ) + try: test_fn( pgi=pgi, @@ -871,33 +971,51 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, if len(exceptions) > 0: raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " - f"rank={pgi.rank}.") + f"rank={pgi.rank}." + ) else: - print(f"{count} of {count} tests passed in child process, " - f"rank={pgi.rank}.") + print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.") @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_prepare_finalize( world_dp_size: tuple[int, int], use_internode: bool, ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, _pplx_prepare_finalize) + parallel_launch( + world_size * dp_size, + _pplx_test_loop, + dp_size, + use_internode, + False, + False, + _pplx_prepare_finalize, + ) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.parametrize("use_shared_experts", [False, True]) @requires_pplx +@multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, + use_shared_experts: bool, ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, - _pplx_moe) + parallel_launch( + world_size, + _pplx_test_loop, + dp_size, + use_internode, + use_shared_experts, + True, + _pplx_moe, + ) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py index 1c51c530c193c..d4724d749fc98 100644 --- a/tests/kernels/moe/test_rocm_aiter_topk.py +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -24,13 +24,14 @@ aiter_available = importlib.util.find_spec("aiter") is not None pytestmark = pytest.mark.skipif( not (current_platform.is_rocm() and aiter_available), - reason="AITER ops are only available on ROCm with aiter package installed") + reason="AITER ops are only available on ROCm with aiter package installed", +) def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk') + assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) @@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): def test_rocm_aiter_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk') + assert 
hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_grouped_topk) @@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): renormalize = True scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") - e_score_correction_bias = torch.randn((expert, ), - dtype=torch.bfloat16, - device="cuda") + gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") + e_score_correction_bias = torch.randn( + (expert,), dtype=torch.bfloat16, device="cuda" + ) device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function that uses the op - def biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights, topk_ids): + def biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights, topk_ids + ): return torch.ops.vllm.rocm_aiter_biased_grouped_topk( - gating_output, e_score_correction_bias, topk_weights, topk_ids, - num_expert_group, topk_group, renormalize, scale_factor) + gating_output, + e_score_correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scale_factor, + ) # Verify the op's fake implementation torch.library.opcheck( @@ -84,51 +89,49 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): "num_expert_group": num_expert_group, "topk_group": topk_group, "need_renorm": renormalize, - "routed_scaling_factor": scale_factor + "routed_scaling_factor": scale_factor, }, - test_utils=("test_faketensor")) + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(biased_grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) + compiled_fn = torch.compile( + biased_grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights_original, topk_ids_original) - compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled, - topk_ids_compiled) + biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original + ) + compiled_fn( + gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled + ) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - 
topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) @@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility(): scoring_func = "softmax" scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") + gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function that uses the op def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func): return torch.ops.vllm.rocm_aiter_grouped_topk( - gating_output, topk_weights, topk_ids, num_expert_group, - topk_group, renormalize, scoring_func, scale_factor) + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scoring_func, + scale_factor, + ) # Verify the op's fake implementation - torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk, - (gating_output, topk_weights, topk_ids), - kwargs={ - "num_expert_group": num_expert_group, - "topk_group": topk_group, - "need_renorm": renormalize, - "scoring_func": scoring_func, - "routed_scaling_factor": scale_factor - }, - test_utils=("test_faketensor")) + torch.library.opcheck( + torch.ops.vllm.rocm_aiter_grouped_topk, + (gating_output, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "scoring_func": scoring_func, + "routed_scaling_factor": scale_factor, + }, + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) + compiled_fn = torch.compile( + grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original, - scoring_func) - 
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, - scoring_func) + grouped_topk_fn( + gating_output, topk_weights_original, topk_ids_original, scoring_func + ) + compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5a0379dfb4475..8b3bebb391f2f 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,79 +5,121 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + persistent_masked_m_silu_mul_quant, +) from vllm.platforms import current_platform +from vllm.utils import cdiv + +fp8_dtype = torch.float8_e4m3fn -# (E, T, H, group_size, seed) CASES = [ - (1, 1, 128, 64, 0), - (1, 4, 128, 128, 0), - (2, 4, 256, 128, 0), - (32, 64, 256, 128, 0), - (17, 31, 768, 128, 0), + (1, 1, 128, fp8_dtype), + (1, 4, 128, fp8_dtype), + (2, 4, 256, fp8_dtype), + (32, 64, 256, fp8_dtype), + (17, 31, 768, fp8_dtype), + (1, 1, 128 * 1, fp8_dtype), + (1, 1, 128 * 2, fp8_dtype), + (1, 1, 128 * 3, fp8_dtype), + (1, 1, 128 * 4, fp8_dtype), + (8, 16, 128 * 1, fp8_dtype), + (8, 16, 128 * 2, fp8_dtype), + (8, 16, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 64, 7168, fp8_dtype), + (8, 128, 7168, fp8_dtype), + (8, 256, 7168, fp8_dtype), + (8, 512, 7168, fp8_dtype), + (8, 1024, 7168, fp8_dtype), + (256, 8, 7168, fp8_dtype), + (256, 16, 7168, fp8_dtype), + (256, 32, 7168, fp8_dtype), + (256, 64, 7168, fp8_dtype), + # Only add a few fnuz tests to help with long CI times. 
+ (8, 512, 7168, torch.float8_e4m3fnuz), + (8, 1024, 7168, torch.float8_e4m3fnuz), ] -@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): - current_platform.seed_everything(seed) +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): + group_size = 128 + current_platform.seed_everything(42) # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, - size=(E, ), + size=(E,), dtype=torch.int32, device="cuda", ) - # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) + # Run the SiLU V2 kernel + y_q, y_s = persistent_masked_m_silu_mul_quant( + y, tokens_per_expert, group_size=group_size + ) - # Reference implementation - fp8_info = torch.finfo(torch.float8_e4m3fn) + torch.cuda.synchronize() + fp8_info = torch.finfo(fp8_dtype) fp8_max = fp8_info.max fp8_min = fp8_info.min eps = 1e-10 - # Compute silu activation and elementwise multiplication - y1 = y[..., :H] + y1 = y[..., :H].float() y2 = y[..., H:] silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), - dtype=torch.float32, - device="cuda") - ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") + ref_s = torch.empty( + (T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" + ) + ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") + for t in range(nt): - data = merged[e, t] - data_grp = data.view(H // group_size, group_size) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max + data = merged[e, t].float() + ref_q_row = torch.empty_like(data) - scaled = data / scale.repeat_interleave(group_size) - clamped = scaled.clamp(fp8_min, fp8_max) - q = clamped.to(torch.float8_e4m3fn) + # process full groups + n_full_groups = H // group_size + if n_full_groups > 0: + data_grp = data[: n_full_groups * group_size].view( + n_full_groups, group_size + ) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + scaled = data[: n_full_groups * group_size] / scale.repeat_interleave( + group_size + ) + ref_q_row[: n_full_groups * group_size] = scaled.clamp( + fp8_min, fp8_max + ).to(fp8_dtype) + ref_s[t, :n_full_groups] = scale - ref_s[t] = scale - ref_q[t] = q + # process remainder group + rem = H % group_size + if rem > 0: + data_rem = data[-rem:] + amax = data_rem.abs().amax().clamp(min=eps) + scale = amax / fp8_max + scaled = data_rem / scale + ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, -1] = scale - y_se = y_s[e] - y_qe = y_q[e] + ref_q[t] = ref_q_row + + y_se = y_s[e].float() + y_qe = y_q[e].float() - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), atol=2, rtol=2e-1, ) + + torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index dfd0f35c8da3d..933cd9dbdeaa0 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,15 +7,15 @@ import itertools import pytest import torch +from 
tests.kernels.moe.utils import fused_moe from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.platforms import current_platform if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -29,14 +29,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -86,17 +85,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = ops.scaled_fp8_quant( - act_out, use_per_token_if_dynamic=True) + act_out, use_per_token_if_dynamic=True + ) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -114,8 +113,10 @@ TOP_KS = [2, 6] SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -131,12 +132,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): # Generate int8 weights w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 - w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 - w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) # Generate scale for each column (per-column quantization) w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale @@ -152,15 +151,16 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): score, topk, renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization + quant_config=fp8_w8a8_moe_quant_config( + per_act_token_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + 
block_shape=None, # Not using block quantization + ), ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 82960bd57345d..9466dacb0c111 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -6,15 +6,17 @@ import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import per_block_cast_to_int8 -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX) -from vllm.model_executor.layers.fused_moe import fused_experts +from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + BatchedPrepareAndFinalize, + BatchedTritonExperts, + NaiveBatchedExperts, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import round_up from vllm.utils.deep_gemm import per_block_cast_to_fp8 @@ -33,18 +35,17 @@ def triton_moe( per_act_token_quant=False, block_shape: Optional[list[int]] = None, ) -> torch.Tensor: - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_channel_quant=per_act_token_quant, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config) def batched_moe( @@ -63,29 +64,28 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) def naive_batched_moe( @@ -104,33 +104,33 @@ def 
naive_batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), NaiveBatchedExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) -def chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + scales: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: return scales @@ -153,13 +153,15 @@ def make_quantized_test_activations( a_scale = None if quant_dtype is not None: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" + assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, ( + "only fp8/int8 supported" + ) a_q = torch.zeros_like(a, dtype=quant_dtype) a_scale_l = [None] * E for e in range(E): a_q[e], a_scale_l[e] = moe_kernel_quantize_input( - a[e], None, quant_dtype, per_act_token_quant, block_shape) + a[e], None, quant_dtype, per_act_token_quant, block_shape + ) a_scale = torch.stack(a_scale_l) if not per_act_token_quant and block_shape is None: @@ -175,8 +177,11 @@ def moe_quantize_weights( per_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 - or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported" + assert ( + quant_dtype == torch.float8_e4m3fn + or quant_dtype == torch.int8 + or quant_dtype == "nvfp4" + ), "only fp8/int8/nvfp4 supported" w_gs = None @@ -193,10 +198,12 @@ def moe_quantize_weights( else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( - w, w_s, use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) elif quant_dtype == torch.float8_e4m3fn: w, w_s = ops.scaled_fp8_quant( - w, w_s, use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) elif quant_dtype == "nvfp4": assert not per_token_quant w_amax = torch.abs(w).max().to(torch.float32) @@ -215,9 +222,8 @@ def make_test_weight( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + per_out_ch_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_gs = None @@ -227,7 +233,8 @@ def make_test_weight( w_gs_l = [None] * e for idx in range(e): w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( - w_16[idx], 
None, quant_dtype, per_act_token_quant, block_shape) + w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape + ) w = torch.stack(w_l) w_s = torch.stack(w_s_l) @@ -257,28 +264,250 @@ def make_test_weights( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, -) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]], - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]]: + per_out_ch_quant: bool = False, +) -> tuple[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], +]: return ( - make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), - make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + make_test_weight( + e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant + ), + make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant), ) def per_token_cast_to_fp8( - x: torch.Tensor, - block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: int = 128 +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x x_view = x.view(m, -1, block_size) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def make_test_quant_config( + e: int, + n: int, + k: int, + in_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]: + (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype, + quant_dtype, + per_out_ch_quant=per_act_token_quant, + block_shape=block_shape, + ) + + # Hacky/trivial scales for nvfp4. 
+ a1_gscale: Optional[torch.Tensor] = None + a2_gscale: Optional[torch.Tensor] = None + if quant_dtype == "nvfp4": + a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) + a1_scale = a1_gscale + a2_scale = a2_gscale + else: + a1_scale = None + a2_scale = None + + return ( + w1, + w2, + FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_s, + w2_scale=w2_s, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + a1_scale=a1_scale, + a2_scale=a2_scale, + # TODO: make sure this is handled properly + g1_alphas=(1 / w1_gs) if w1_gs is not None else None, + g2_alphas=(1 / w2_gs) if w2_gs is not None else None, + ), + ) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + renormalize: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk( + hidden_states, score.float(), topk, renormalize + ) + return fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=quant_config, + ) + + +# CustomOp? +class BaselineMM(torch.nn.Module): + def __init__( + self, + b: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.b = b.to(dtype=torch.float32) + self.out_dtype = out_dtype + + def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None + + +class TestMLP(torch.nn.Module): + def __init__( + self, + w1: torch.Tensor, + w2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.gate_up_proj = BaselineMM(w1, out_dtype) + self.down_proj = BaselineMM(w2, out_dtype) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def make_naive_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, +) -> torch.nn.Module: + w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15 + w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15 + return TestMLP(w1, w2, out_dtype=in_dtype) + + +class RealMLP(torch.nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + w1: torch.Tensor, + w2: torch.Tensor, + hidden_act: str = "silu", + quant_config=None, + reduce_results: bool = True, + prefix: str = "", + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + ) -> None: + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, + ) + + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.gate_up_proj.register_parameter( + "weight", torch.nn.Parameter(w1, requires_grad=False) + ) + self.gate_up_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False) + ) + self.gate_up_proj.register_parameter( + "input_scale", None + ) # torch.nn.Parameter(None, requires_grad=False)) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + 
prefix=f"{prefix}.down_proj", + ) + self.down_proj.register_parameter( + "weight", torch.nn.Parameter(w2, requires_grad=False) + ) + self.down_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False) + ) + self.down_proj.register_parameter( + "input_scale", None + ) # torch.nn.Parameter(None, requires_grad=False)) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def make_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Union[torch.dtype, str, None] = None, +) -> torch.nn.Module: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + 1, + N, + K, + in_dtype=in_dtype, + quant_dtype=quant_dtype, + ) + old_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(in_dtype) + if quant_dtype == torch.float8_e4m3fn: + w1 = w1[0].transpose(0, 1) + w2 = w2[0].transpose(0, 1) + w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None + w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None + quant_config = Fp8Config(True) + else: + w1 = w1[0] + w2 = w2[0] + w1_s = None + w2_s = None + quant_config = None + + return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s) + finally: + torch.set_default_dtype(old_dtype) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 01a1ad2e7a0a5..d892f2a5acc09 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -5,8 +5,7 @@ from typing import Optional, Union import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.platforms import current_platform from vllm.utils import round_up @@ -17,25 +16,31 @@ FP8_DTYPE = current_platform.fp8_dtype() def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_dynamic_per_token_quant(x: torch.tensor, - quant_dtype: torch.dtype, - scale_ub: Optional[torch.tensor] = None) \ - -> tuple[torch.tensor, torch.tensor]: +def ref_dynamic_per_token_quant( + x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None +) -> tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, FP8_DTYPE] if scale_ub is not None: assert quant_dtype == FP8_DTYPE - qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ - else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.max - qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.min + qtype_traits = ( + torch.iinfo(quant_dtype) + if quant_dtype == torch.int8 + else torch.finfo(quant_dtype) + ) + qtype_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.max + ) + qtype_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.min + ) qtype_max = as_float32_tensor(qtype_traits_max) s_1 = 
as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -56,15 +61,13 @@ def ref_dynamic_per_token_quant(x: torch.tensor, iscales = as_float32_tensor(s_1 / scales) torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) else: assert quant_dtype == FP8_DTYPE min_scaling_factor = s_1 / (qtype_max * s_512) scales = scales.clamp(min=min_scaling_factor) torch_out = as_float32_tensor(x) / scales - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) return torch_out, scales @@ -72,16 +75,20 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # The int8 version is very similar. Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel -def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ - -> tuple[torch.tensor, torch.tensor]: - +def ref_dynamic_per_tensor_fp8_quant( + x: torch.tensor, +) -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.max - fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.min + fp8_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.max + ) + fp8_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.min + ) fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) @@ -92,9 +99,12 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ x_max = as_float32_tensor(x.abs().max()) ref_scale = x_max / fp8_max ref_iscale = one / ref_scale - ref_out = (as_float32_tensor(x) * ref_iscale).clamp( - fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) - return ref_out, ref_scale.view((1, )) + ref_out = ( + (as_float32_tensor(x) * ref_iscale) + .clamp(fp8_traits_min, fp8_traits_max) + .to(FP8_DTYPE) + ) + return ref_out, ref_scale.view((1,)) def native_w8a8_block_matmul( @@ -126,7 +136,7 @@ def native_w8a8_block_matmul( M = A.numel() // A.shape[-1] N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) + origin_C_shape = A.shape[:-1] + (N,) A = A.reshape(M, A.shape[-1]) As = As.reshape(M, As.shape[-1]) n_tiles = (N + block_n - 1) // block_n @@ -137,19 +147,19 @@ def native_w8a8_block_matmul( C_shape = (M, N) C = torch.zeros(C_shape, dtype=compute_type, device=A.device) - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] for i in range(k_tiles): for j in 
range(n_tiles): @@ -163,14 +173,14 @@ def native_w8a8_block_matmul( return C -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): """Function to perform per-token-group quantization on an input tensor `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must " - "be divisible by `group_size`") + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" finfo = torch.finfo(dtype) @@ -178,28 +188,25 @@ def native_per_token_group_quant_fp8(x, fp8_max = finfo.max x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / fp8_max x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s -def native_per_token_group_quant_int8(x, - group_size, - eps=1e-10, - dtype=torch.int8): +def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8): """Function to perform per-token-group quantization on an input tensor `x` using native torch. It converts the tensor values into int8 values and returns the quantized tensor along with the scaling factor used for quantization. """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` must be divisible by `group_size`" + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" iinfo = torch.iinfo(dtype) @@ -208,13 +215,13 @@ def native_per_token_group_quant_int8(x, x_ = x.reshape(x.numel() // group_size, group_size) # Use float32 for scale calculation for stability - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / int8_max - x_q = (x_.to(torch.float32) / x_s).round().clamp( - min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = ( + (x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype) + ) # Round before clamping x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s @@ -229,9 +236,9 @@ def per_block_cast_to_int8( block_m, block_n = block_shape assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -269,8 +276,9 @@ def batched_dequant( assert t.shape[0] == scale.shape[0] out = torch.empty_like(t, dtype=out_dtype) for e in range(t.shape[0]): - out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, - out_dtype) + out[e] = dequant( + t[e], scale[e], block_shape, per_act_token_quant, out_dtype + ) return out return t.to(out_dtype) @@ -294,15 +302,17 @@ def 
native_batched_masked_quant_matmul( num_tokens = num_expert_tokens_cpu[e] if A.dtype.itemsize == 1 and block_shape is not None: assert A_scale is not None and B_scale is not None - tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e], - block_shape, C.dtype) + tmp = native_w8a8_block_matmul( + A[e], B[e], A_scale[e], B_scale[e], block_shape, C.dtype + ) C[e, :num_tokens, :] = tmp[:num_tokens, :] elif A.dtype.itemsize == 1 and block_shape is None: assert A_scale is not None and B_scale is not None A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant) B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant) - C[e, :num_tokens, :] = ( - A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype) + C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to( + C.dtype + ) else: assert A_scale is None assert B_scale is None diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 1095975ab2b41..5e6d54c42e89b 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -2,13 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import scaled_fp4_quant from vllm.scalar_type import scalar_types FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): @@ -21,12 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def dequantize_nvfp4_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): +def dequantize_nvfp4_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
assert tensor_fp4.dtype == torch.uint8 @@ -65,3 +64,13 @@ def break_fp4_bytes(a, dtype): # Reshape to final form return values.reshape(m, n * 2).to(dtype=dtype) + + +def get_nvfp4_global_scale(a: torch.Tensor): + return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32) + + +def quant_nvfp4_tensor(a: torch.Tensor): + a_global_scale = get_nvfp4_global_scale(a) + a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) + return a_quant, a_block_scale, a_global_scale diff --git a/tests/kernels/quantization/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py index 3de9cb3644684..e5f056f04f8c0 100644 --- a/tests/kernels/quantization/test_allspark_gemm.py +++ b/tests/kernels/quantization/test_allspark_gemm.py @@ -6,24 +6,25 @@ import torch from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - ALLSPARK_AMPERE_N_ALIGN) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + ALLSPARK_AMPERE_K_ALIGN, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_AMPERE_N_ALIGN, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -def is_gptq_allspark_supported(min_capability: int, - max_capability: int) -> bool: +def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool: if not current_platform.is_cuda(): return False capability = current_platform.get_device_capability() assert capability is not None - return capability.to_int() >= min_capability \ - and capability.to_int() <= max_capability + return ( + capability.to_int() >= min_capability and capability.to_int() <= max_capability + ) MNK_FACTORS = [ @@ -43,7 +44,8 @@ HAS_ZP_OPTS = [False, True] def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): @@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.skipif( not is_gptq_allspark_supported(80, 89), - reason="AllSpark Ampere kernel is not supported on this GPU type.") + reason="AllSpark Ampere kernel is not supported on this GPU type.", +) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("group_size", [-1]) @pytest.mark.parametrize("has_zp", HAS_ZP_OPTS) @@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): weight = rand_data((k, n), dtype=dtype) # Quantize (and apply act_order if provided) - w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128, - group_size, has_zp) + w_ref, qw, s, zp = quantize_weights( + weight, scalar_types.uint8b128, group_size, has_zp + ) qw = qw.to(torch.uint8) if has_zp: @@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): n_32align = (n + 32 - 1) // 32 * 32 - qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( - qw, s, zp, has_zp) - opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, - (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, - n_32align)) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp) + opcheck( + torch.ops._C.rearrange_kn_weight_as_n32k16_order, + (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, 
n, n_32align), + ) - opcheck(torch.ops._C.allspark_w8a16_gemm, - (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, - sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) - output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, - n, group_size, sm_count, sm_version, - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - has_zp, True) + opcheck( + torch.ops._C.allspark_w8a16_gemm, + ( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + output = ops.allspark_w8a16_gemm( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ) output_ref = torch.matmul(input, w_ref) torch.cuda.synchronize() diff --git a/tests/kernels/quantization/test_awq.py b/tests/kernels/quantization/test_awq.py index bc0868123d82a..efb62ca3799a9 100644 --- a/tests/kernels/quantization/test_awq.py +++ b/tests/kernels/quantization/test_awq.py @@ -8,40 +8,42 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not hasattr(torch.ops._C, "awq_dequantize"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) - zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16) + zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32) split_k_iters = 0 thx = 0 thy = 0 - opcheck(torch.ops._C.awq_dequantize, - (qweight, scales, zeros, split_k_iters, thx, thy)) + opcheck( + torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy), + ) @pytest.mark.skip(reason="Not working; needs investigation.") -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not hasattr(torch.ops._C, "awq_gemm"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.randint(-2000000000, - 2000000000, (64, 256), - device='cuda', - dtype=torch.int32) - qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + input = torch.rand((2, 8192), device="cuda", dtype=torch.float16) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.randint( + -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 + ) + qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16) split_k_iters = 8 - opcheck(torch.ops._C.awq_gemm, - (input, qweight, qzeros, scales, split_k_iters)) + 
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 96797e85bd125..069bd74355348 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -2,13 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the AWQ Triton kernel. -Run `pytest tests/kernels/test_awq_triton.py`. +Run `pytest tests/kernels/quantization/test_awq_triton.py`. """ + import pytest import torch from vllm.model_executor.layers.quantization.awq_triton import ( - AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) + AWQ_TRITON_SUPPORTED_GROUP_SIZES, + awq_dequantize_triton, + awq_gemm_triton, +) from vllm.platforms import current_platform device = "cuda" @@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor): # qweights - [R , C // 8], int32 # scales - [R // G, C ], float16 # zeros - [R // G, C // 8], int32 -def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, - qzeros: torch.Tensor, - group_size: int) -> torch.Tensor: - +def awq_dequantize_torch( + qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int +) -> torch.Tensor: if group_size == -1: group_size = qweight.shape[0] bits = 4 shifts = torch.arange(0, 32, bits, device=qzeros.device) - iweights = torch.bitwise_right_shift(qweight[:, :, None], - shifts[None, None, :]).to(torch.int8) + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) iweights = iweights.view(iweights.shape[0], -1) - zeros = torch.bitwise_right_shift(qzeros[:, :, None], - shifts[None, None, :]).to(torch.int8) + zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) zeros = zeros.view(qzeros.shape[0], -1) zeros = reverse_awq_order(zeros) @@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, @pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) def test_dequantize(qweight_rows, qweight_cols, group_size): - if group_size == -1: group_size = qweight_rows @@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): current_platform.seed_everything(0) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - dtype=qweight_dtype, - device=device) - scales = torch.rand(scales_rows, - scales_cols, - dtype=scales_dtype, - device=device) - zeros = torch.randint(0, - torch.iinfo(torch.int32).max, - (zeros_rows, zeros_cols), - dtype=zeros_dtype, - device=device) + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device, + ) + scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device) + zeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device, + ) iweights_triton = awq_dequantize_triton(qweight, scales, zeros) - assert (not torch.any(torch.isinf(iweights_triton)) - and not torch.any(torch.isnan(iweights_triton))) + assert not torch.any(torch.isinf(iweights_triton)) and not torch.any( + torch.isnan(iweights_triton) + ) iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) @@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): 
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("splitK", [1, 8]) def test_gemm(N, K, M, splitK, group_size): - if group_size == -1: group_size = K @@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size): current_platform.seed_everything(0) - input = torch.rand((input_rows, input_cols), - dtype=input_dtype, - device=device) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - device=device) - qzeros = torch.randint(0, - torch.iinfo(torch.int32).max, - (qzeros_rows, qzeros_cols), - device=device) - scales = torch.rand((scales_rows, scales_cols), - dtype=scales_dtype, - device=device) + input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device) + qweight = torch.randint( + 0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device + ) + qzeros = torch.randint( + 0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device + ) + scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device) - output_triton = awq_gemm_triton(input, qweight, scales, qzeros, - split_k_iters) + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) - assert (not torch.any(torch.isinf(output_triton)) - and not torch.any(torch.isnan(output_triton))) + assert not torch.any(torch.isinf(output_triton)) and not torch.any( + torch.isnan(output_triton) + ) dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) output_torch = torch.matmul(input, dequantized_weights) - assert (not torch.any(torch.isinf(output_torch)) - and not torch.any(torch.isnan(output_torch))) + assert not torch.any(torch.isinf(output_torch)) and not torch.any( + torch.isnan(output_torch) + ) - torch.testing.assert_close(output_triton.cpu(), - output_torch.cpu(), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close( + output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1 + ) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index d9154d3fd7f33..a6dfb5428c52e 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -7,19 +7,26 @@ import itertools import pytest import torch -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul) + cutlass_scaled_mm, + per_token_group_quant_fp8, + w8a8_triton_block_scaled_mm, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + get_col_major_tma_aligned_tensor, + per_block_cast_to_fp8, +) if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -50,7 +57,8 @@ def setup_cuda(): @pytest.mark.parametrize( "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), +) 
@torch.inference_mode() def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): torch.manual_seed(seed) @@ -59,15 +67,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) out, scale = per_token_group_quant_fp8(x, group_size) - assert torch.allclose(out.to(torch.float32), - ref_out.to(torch.float32), - rtol=0.15) + assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) assert torch.allclose(scale, ref_scale) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -88,21 +95,68 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) + assert rel_diff < 0.001 + + +@torch.inference_mode() +def test_w8a8_block_fp8_cutlass_matmul(): + # Test simple case where weight.shape % 128 != 0, + # like in DSV3 kv_a_proj_with_mqa + M = 32 + N = 576 + K = 7168 + block_size = [128, 128] + out_dtype = torch.bfloat16 + seed = 0 + + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + # Hopper requires row-major format for scales + Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs + + A_fp8, As = per_token_group_quant_fp8( + A_fp32, block_size[1], column_major_scales=False + ) + # CUTLASS uses column-major format for scales + A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8( + A_fp32, block_size[1], column_major_scales=True + ) + + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = cutlass_scaled_mm( + A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype + ) + + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@pytest.mark.skipif(not has_deep_gemm(), - reason="DeepGemm kernels not available.") + itertools.product(M, N, K, 
BLOCK_SIZE, OUT_DTYPES, SEEDS), +) +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -122,20 +176,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): As = As_fp8.to(torch.float32) Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) + out = torch.zeros((M, N), device="cuda", dtype=out_dtype) - assert As_fp8.shape == (M, (K + 127) // - 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + assert As_fp8.shape == (M, (K + 127) // 128), ( + f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + ) fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index fac82cf9c8b5e..dabc10a122f7a 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -10,12 +10,12 @@ import torch from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( - w8a8_block_int8_matmul) + w8a8_block_int8_matmul, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -36,8 +36,10 @@ def setup_cuda(): torch.set_default_device("cuda") -@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_cutlass_2of4_sparse.py 
b/tests/kernels/quantization/test_cutlass_2of4_sparse.py index 878f66647e19e..cfdb3658028a6 100644 --- a/tests/kernels/quantization/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for sparse cutlass kernels -Run `pytest tests/kernels/test_semi_structured.py`. +Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`. """ import pytest @@ -11,12 +11,11 @@ import torch from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) + sparse_cutlass_supported, +) from vllm.platforms import current_platform -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -40,9 +39,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -55,32 +52,31 @@ def prune_to_2_4(tensor): # This function checks that applying an identity matrix multiplication # to the compressed weights yields the original uncompressed weights. -def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, - b_compressed: torch.Tensor, - b_metadata: torch.Tensor): - +def check_compress_decompress_invariance( + dtype: torch.dtype, + b: torch.Tensor, + b_compressed: torch.Tensor, + b_metadata: torch.Tensor, +): # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the # same dtype as its inputs. This line addresses that constraint while # arbitrarily using bfloat16 for the int8/fp8 cases. 
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16 - eye = torch.eye(b.shape[0], device='cuda', dtype=dtype) - eye_scale = torch.ones(1, device='cuda', dtype=torch.float32) - b_decomp = ops.cutlass_scaled_sparse_mm(eye, - b_compressed, - b_metadata, - eye_scale, - eye_scale, - out_dtype=out_dtype) + eye = torch.eye(b.shape[0], device="cuda", dtype=dtype) + eye_scale = torch.ones(1, device="cuda", dtype=torch.float32) + b_decomp = ops.cutlass_scaled_sparse_mm( + eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype + ) torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp) def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int + dtype: torch.dtype, m: int, n: int, k: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') - b = torch.randn((n, k), device='cuda').t() + a = torch.randn((m, k), device="cuda") + b = torch.randn((n, k), device="cuda").t() if dtype == torch.int8: # ensure A and B aren't all zeros after rounding @@ -107,32 +103,25 @@ def make_rand_sparse_tensors( return b_compressed, e, a, b -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) # Test working with a subset of A and B for sparse matmul def test_cutlass_sparse_subset(): - big_m = 1024 m, n, k = 512, 512, 512 # Create tensors - b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, - big_m, n, k) + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k) a = whole_a[0:m, 0:k] scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16 + ) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) @@ -161,105 +150,87 @@ MNK_FACTORS = [ # Test working with a subset of A and B for sparse matmul -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype], - use_bias: bool): - +def test_cutlass_sparse_gemm( + m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32) scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32) - bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, 
out_dtype=dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, k, n", MNK_FACTORS) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("use_bias", [True, False]) def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool): - # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + baseline = baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m,k,n", MNK_FACTORS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - +def test_cutlass_sparse_int8_gemm( + m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + baseline = 
baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index a15decdf6f827..835c067e2f72f 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for cutlass kernels -Run `pytest tests/kernels/test_cutlass.py`. +Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`. """ + import random import pytest @@ -36,9 +37,7 @@ MNK_FACTORS = [ (512, 24576, 128), ] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # -1 means full extent in that dimension TENSORWISE_GROUP_SHAPE = (-1, -1) @@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape): def scale_shape(shape, group_shape): assert len(shape) == len(group_shape) group_shape = group_scale_helper(shape, group_shape) - return tuple( - cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) + return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) -def cutlass_fp8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): +def cutlass_fp8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
a = to_fp8(torch.randn((m, k), device=device)) @@ -80,36 +80,34 @@ def cutlass_fp8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) # make scales M-major for blockwise quant, doesn't affect 1D scales scale_a = scale_a.t().contiguous().t() # make scales K-major for blockwise quant, doesn't affect 1D scales scale_b = scale_b.t().contiguous().t() - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) -def cutlass_int8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): +def cutlass_int8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. 
a = to_int8(torch.randn((m, k), device=device) * 5) @@ -118,158 +116,202 @@ def cutlass_int8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, - a_scale_group_shape, - b_scale_group_shape, use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0: return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return if m % 4 != 0 and current_platform.has_device_capability(100): return - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) 
-@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) +def test_cutlass_int8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_int8_gemm_helper( + m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) +def test_cutlass_int8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) 
@pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape, - b_scale_group_shape, use_bias, torch.bfloat16, - device) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + torch.bfloat16, + device, + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=torch.bfloat16, - device=device) +def test_cutlass_int8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=torch.bfloat16, + device=device, + ) # For the following two tests: @@ -277,32 +319,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, # of a large power of two. In any case, the kernel will have a naive fallback # when N and K are not divisible by 16. But M is the number of tokens and the # kernel must handle any M thrown at it. 
-@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_fp8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +def test_cutlass_int8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_int8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) @pytest.mark.parametrize("m", [32, 64, 128]) @@ -310,8 +362,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, @pytest.mark.parametrize("k", [64, 128, 256]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.skip -def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, - out_dtype: torch.dtype): +def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype): # Currently, the test is failing because folding azp into # 16-bit bias loses too much precision scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 @@ -328,7 +379,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, b_dq = scale_b * bq_f32 - azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding @@ -340,18 +391,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, J = torch.ones((1, k), device="cuda", dtype=torch.float32) azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype) assert azp_bias.shape == (1, n) - assert azp_bias[0, :].shape == (n, ) + assert azp_bias[0, :].shape == (n,) - baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * ( - (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to( - dtype=out_dtype, device='cuda') + baseline_q = ( + 
scale_a.to(device="cpu") + * scale_b.to(device="cpu") + * ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu")) + ).to(dtype=out_dtype, device="cuda") - out = ops.cutlass_scaled_mm(aq_i8, - bq_i8, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=azp_bias[0, :]) + out = ops.cutlass_scaled_mm( + aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :] + ) torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0) torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0) @@ -362,8 +412,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("azp_per_token", [True, False]) -def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, - use_bias: bool, azp_per_token: bool): +def test_cutlass_int8_azp( + m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool +): m_azp = m if azp_per_token else 1 scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10 @@ -377,16 +428,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, bq_f32 = bq_i8.to(dtype=torch.float32) b_dq = scale_b * bq_f32 - azp_a = torch.rand( - (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32) - torch.testing.assert_close(a_dq, - scale_a * aq_f32 - azp_a, - rtol=1e-4, - atol=1e-3) + torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3) if use_bias: bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5 @@ -396,8 +443,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype) # int32 mm not supported on CUDA - a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu') - cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda') + a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu") + cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda") baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype) # Hadamard is just the sum of the cols @@ -406,14 +453,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, func_bias = bias if use_bias else None if azp_per_token: - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_adj_i32, azp_i32, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias + ) else: azp_with_adj_i32 = azp_i32 * azp_adj_i32 - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_with_adj_i32, None, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias + ) # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4% # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05% @@ -423,13 +470,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol) if azp_per_token: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, 
azp_i32, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias), + ) else: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias), + ) # Test working with a subset of A and B @@ -445,23 +494,14 @@ def test_cutlass_subset(): scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) # Test to make sure cuda graphs work class CutlassLayer(torch.nn.Module): - def __init__(self, b, scale_a, scale_b, out_dtype): super().__init__() self.b = b @@ -470,8 +510,9 @@ class CutlassLayer(torch.nn.Module): self.out_dtype = out_dtype def forward(self, a): - return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b, - self.out_dtype) + return ops.cutlass_scaled_mm( + a, self.b, self.scale_a, self.scale_b, self.out_dtype + ) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -485,10 +526,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): m_a_scales = m if per_act_token else 1 n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn( - (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) - scale_b = (torch.randn( - (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) + scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10 # Construct a trivial model with a single layer that calls a CUTLASS kernel model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) @@ -502,13 +541,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): out.zero_() g.replay() - baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + baseline = torch.mm( + scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32) + ).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) def test_cutlass_support_opcheck(): - opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,)) @pytest.mark.parametrize("num_experts", [8, 64]) @@ -517,11 +557,13 @@ def test_cutlass_support_opcheck(): @pytest.mark.parametrize("use_bias", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) +def test_cutlass_fp8_group_gemm( + num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Device and dtype setup device = "cuda" out_dtype = torch.half @@ -533,13 +575,9 @@ def 
test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_scales_tensors = [] baseline_tensors = [] - expert_offsets = torch.zeros((num_experts + 1), - device=device, - dtype=torch.int64) + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64) - problem_sizes = torch.zeros((num_experts, 3), - device=device, - dtype=torch.int32) + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) if not per_act_token: one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32) @@ -566,75 +604,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_tensors.append(b_g) # Set up A/B scales - scale_b = torch.randn((1, n_b_scales), - device=device, - dtype=torch.float32) + scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32) b_scales_tensors.append(scale_b) if per_act_token: - scale_a = torch.randn((m_a_scales, 1), - device=device, - dtype=torch.float32) + scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32) a_scales_tensors.append(scale_a) else: scale_a = one_scale_a # Compute baseline result for this group - baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, - None) + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) baseline_tensors.append(baseline_g) - a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), - device=device, - dtype=torch.float8_e4m3fn) - b_tensors_stacked = torch.empty((num_experts, n_g, k_g), - device=device, - dtype=torch.float8_e4m3fn) + a_tensors_stacked = torch.empty( + (expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_tensors_stacked = torch.empty( + (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) for g in range(num_experts): - a_tensors_stacked[expert_offsets[g]:expert_offsets[g + - 1]] = a_tensors[g] + a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] b_tensors_stacked[g] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.transpose(1, 2) if per_act_token: a_scales_tensors_stacked = torch.empty( - (expert_offsets[num_experts], 1), - device=device, - dtype=torch.float32) + (expert_offsets[num_experts], 1), device=device, dtype=torch.float32 + ) for g in range(num_experts): - a_scales_tensors_stacked[ - expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] + a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = ( + a_scales_tensors[g] + ) else: a_scales_tensors_stacked = one_scale_a - b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales), - device=device, - dtype=torch.float32) + b_scales_tensors_stacked = torch.empty( + (num_experts, n_b_scales), device=device, dtype=torch.float32 + ) for g in range(num_experts): b_scales_tensors_stacked[g] = b_scales_tensors[g] - out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), - device=device, - dtype=out_dtype) + out_tensors_stacked = torch.zeros( + (expert_offsets[num_experts], n_g), device=device, dtype=out_dtype + ) - ab_strides = torch.full((num_experts, ), - a_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) - c_strides = torch.full((num_experts, ), - out_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) + ab_strides = torch.full( + (num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) + c_strides = torch.full( + (num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) - ops.cutlass_moe_mm(out_tensors_stacked, 
a_tensors_stacked, - b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, expert_offsets[:-1], - problem_sizes, ab_strides, ab_strides, c_strides, - per_act_token, per_out_ch) + ops.cutlass_moe_mm( + out_tensors_stacked, + a_tensors_stacked, + b_tensors_stacked, + a_scales_tensors_stacked, + b_scales_tensors_stacked, + expert_offsets[:-1], + problem_sizes, + ab_strides, + ab_strides, + c_strides, + per_act_token, + per_out_ch, + ) # Validate each group's result against the baseline for g in range(num_experts): baseline = baseline_tensors[g] - c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 0000000000000..a3d524fe90ed0 --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. + +Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, + quantize_weights, +) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods, +# an assumption which is breaking down as quantization methods can have +# multiple kernels and some kernels support multiple quantization methods.
+IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +# TODO(czhu): get supported schedules from fn +SCHEDULES = [ + "128x16_1x1x1", + "256x16_1x1x1", + "128x32_1x1x1", + "256x32_1x1x1", + "128x64_1x1x1", + "256x64_1x1x1", + "128x128_1x1x1", + "256x128_1x1x1", + "128x256_1x1x1", + "128x256_2x1x1", +] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: torch.Tensor + w_ch_s: torch.Tensor + w_tok_s: torch.Tensor + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +TestTypeTuple = tuple[ + list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool +] +TEST_TYPES = [ + *( + TypeConfig( + act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32, + ) + for w_type in [scalar_types.int4] + # TODO(czhu): fp16 out type + for o_type in [torch.bfloat16] + ), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods, +# an assumption which is breaking down as quantization methods can have +# multiple kernels and some kernels support multiple quantization methods.
+IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, wtype, group_size=group_size, zero_points=zero_points + ) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ( + (w_q).to(torch.float32) + * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0) + ).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors( + shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] +) -> Tensors: + m, n, k = shape + + print( + "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size + ) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, False + ) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type) + # weights are already per-group quantized, use placeholder here + w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type) + + return Tensors( + w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) + + +def mm_test_helper( + types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None, +): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True, + ) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close( + output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." 
+) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." +) +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points + ) + + w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32) + w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index 131086a5f7034..1e5c7dafb0f5a 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -2,8 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch -from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, - convert_swizzled_to_linear, dequantize_nvfp4_to_dtype) +from nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + convert_swizzled_to_linear, + dequantize_nvfp4_to_dtype, +) from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -41,18 +45,12 @@ def get_ref_results( _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert m_k == n_k - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + ) + b_in_dtype = dequantize_nvfp4_to_dtype( + b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size + ) return torch.matmul(a_in_dtype, b_in_dtype.t()) @@ -72,8 +70,7 @@ def test_flashinfer_nvfp4_gemm( autotune: bool, ) -> None: if backend == "trtllm" and dtype == torch.float16: - pytest.skip( - "Only 
torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") current_platform.seed_everything(seed) m, n, packed_k = shape @@ -82,10 +79,12 @@ def test_flashinfer_nvfp4_gemm( a_dtype = torch.randn((m, k), dtype=dtype, device=device) b_dtype = torch.randn((n, k), dtype=dtype, device=device) - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) - b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. @@ -113,14 +112,18 @@ def test_flashinfer_nvfp4_gemm( if backend == "trtllm": epilogue_tile_m = 128 - b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), - epilogue_tile_m) + b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m) b_scale_interleaved = convert_swizzled_to_linear( - b_scale_interleaved, n, k, block_size) - b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a( - b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape( - b_scale_interleaved.shape).view(torch.float8_e4m3fn)) + b_scale_interleaved, n, k, block_size + ) + b_scale_interleaved = ( + flashinfer.shuffle_matrix_sf_a( + b_scale_interleaved.view(torch.uint8), epilogue_tile_m + ) + .reshape(b_scale_interleaved.shape) + .view(torch.float8_e4m3fn) + ) with flashinfer.autotune(autotune): out = flashinfer_scaled_fp4_mm( @@ -133,7 +136,4 @@ def test_flashinfer_nvfp4_gemm( backend=backend, ) - torch.testing.assert_close(out, - expected_out.to(dtype=dtype), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py new file mode 100644 index 0000000000000..b30821b6895bc --- /dev/null +++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.", + allow_module_level=True, + ) + +DTYPES = [torch.float16, torch.bfloat16] +# m, n, k +SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] +PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] +SHAPES.extend(PAD_SHAPES) + +SEEDS = [42] +CUDA_DEVICES = ["cuda:0"] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_fp8_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + use_bias: bool, + seed: int, + device: str, + autotune: bool, +) -> None: + 
current_platform.seed_everything(seed) + m, n, k = shape + a = torch.randn((m, k), dtype=dtype, device=device) + b = torch.randn((n, k), dtype=dtype, device=device) / k + + a_fp8, a_scale = ops.scaled_fp8_quant(a) + b_fp8, b_scale = ops.scaled_fp8_quant(b) + + expected_out = torch.mm( + a_scale * a_fp8.to(dtype=torch.float32), + b_scale * b_fp8.to(dtype=torch.float32).t(), + ).to(dtype=dtype) + + if use_bias: + bias = torch.randn((n,), dtype=dtype, device=device) + expected_out = expected_out + bias + else: + bias = None + + import flashinfer + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp8_mm( + a_fp8, + b_fp8.t(), + a_scale, + b_scale, + dtype, + bias=bias, + ) + + torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index c2e70ffb8d343..19aa21b96a573 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -5,9 +5,11 @@ import pytest import torch import vllm._custom_ops as ops -from tests.kernels.quant_utils import (FP8_DTYPE, - ref_dynamic_per_tensor_fp8_quant, - ref_dynamic_per_token_quant) +from tests.kernels.quant_utils import ( + FP8_DTYPE, + ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant, +) from tests.kernels.utils import opcheck from vllm.platforms import current_platform @@ -18,23 +20,25 @@ SCALE_UBS = [True, False] SEEDS = [0] -def opcheck_fp8_quant(output, - input, - scale=None, - scale_ub=None, - use_per_token_if_dynamic=False): +def opcheck_fp8_quant( + output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False +): if scale is not None: opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) elif use_per_token_if_dynamic: - scale = torch.empty((input.shape[0], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, - (output, input, scale, scale_ub)) + scale = torch.empty( + (input.shape[0], 1), device=input.device, dtype=torch.float32 + ) + opcheck( + torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub), + ) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32, + ) opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) @@ -44,30 +48,29 @@ def opcheck_fp8_quant(output, @pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, scale_ub: bool, - seed: int) -> None: +def test_dynamic_per_token_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int +) -> None: current_platform.seed_everything(seed) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") + 1e-6 # avoid nans + x = ( + torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 + ) # avoid nans - scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ - if scale_ub else None + scale_ub = ( + torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None + ) ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub) - ops_out, ops_scales = ops.scaled_fp8_quant(x, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + ops_out, ops_scales = ops.scaled_fp8_quant( + x, 
scale_ub=scale_ub, use_per_token_if_dynamic=True + ) torch.testing.assert_close(ref_scales, ops_scales) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) - opcheck_fp8_quant(ops_out, - x, - None, - scale_ub, - use_per_token_if_dynamic=True) + opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -75,8 +78,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_per_tensor_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -85,8 +89,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, ops_out, ops_scale = ops.scaled_fp8_quant(x) torch.testing.assert_close(ref_scale, ops_scale) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck_fp8_quant(ops_out, x) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py new file mode 100644 index 0000000000000..6628ac650fd5f --- /dev/null +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for QuantFP8 Group Quantization implementation.""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform + + +@pytest.mark.parametrize( + "batch_size,hidden_dim,group_size", + [ + (16, 256, 32), # Small + (64, 1024, 64), # Medium + (128, 2048, 128), # Large + (8, 513, 64), # Non-divisible (native only) + ], +) +@pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) +@torch.inference_mode() +def test_quantfp8_group_functionality( + batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool +) -> None: + """Test QuantFP8 group quantization with various configurations. + + Tests both CUDA and native implementations, column-major scales, + and verifies consistency between implementations. + """ + current_platform.seed_everything(seed) + + x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + expected_num_groups = (hidden_dim + group_size - 1) // group_size + is_divisible = hidden_dim % group_size == 0 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=False, + use_ue8m0=use_ue8m0, + ) + + # 1. Test native implementation (always available) + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) + + # 2. 
Test column-major scales configuration + quant_op_col = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=True, + use_ue8m0=use_ue8m0, + ) + _, scales_col = quant_op_col.forward_native(x.clone()) + assert scales_col.shape == (batch_size, expected_num_groups) + assert scales_col.stride(0) == 1 + assert scales_col.stride(1) == batch_size + + # Test column-major scales consistency + assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) + + # 3. Test CUDA implementation (only for divisible dimensions) + if is_divisible: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + + # Verify CUDA/native consistency + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + + # Quantized values should mostly match + diff_count = (x_quant_cuda != x_quant_native).sum().item() + diff_ratio = diff_count / x_quant_cuda.numel() + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" + + +@pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) +@torch.inference_mode() +def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: + current_platform.seed_everything(seed) + + group_size = 64 + + # Test with 3D input + batch1, batch2, hidden_dim = 4, 8, 1024 + x_3d = ( + torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") + * 8 + ) + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=False, + use_ue8m0=use_ue8m0, + ) + + x_quant, scales = quant_op.forward_native(x_3d.clone()) + assert x_quant.shape == x_3d.shape + assert scales.shape == (batch1, batch2, hidden_dim // group_size) + + # Test column_major_scales with multi-dim + quant_op_col = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=True, + use_ue8m0=use_ue8m0, + ) + _, scales_col = quant_op_col.forward_native(x_3d.clone()) + assert scales_col.shape == (batch1, batch2, hidden_dim // group_size) + + # Test with 4D input + batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 + x_4d = ( + torch.randn( + (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda" + ) + * 8 + ) + + x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) + assert x_quant_4d.shape == x_4d.shape + assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size) + + _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3) + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_edge_cases(seed: int) -> None: + current_platform.seed_everything(seed) + + batch_size = 16 + group_size = 64 + + # Test with single group (group_size >= hidden_dim) + x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8( + static=False, group_shape=group_shape, column_major_scales=False + ) + + x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) + assert x_quant_small.shape == x_small.shape + assert scales_small.shape == (batch_size, 1) + + # Test with zero inputs + x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda") + x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) + assert x_quant_zero.shape == x_zero.shape + assert (scales_zero > 
0).all(), "Scales should be clamped to minimum" + + # Test very large values + x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda") + x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) + assert x_quant_large.shape == x_large.shape + # FP8 max is typically 448 or 224, so scales should be > 1 + assert (scales_large > 1.0).all(), "Large values should have scales > 1" diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index 07651fef39bf4..0dc24187f2b34 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -13,33 +13,42 @@ from vllm import _custom_ops as ops # noqa: F401 def test_ggml_opcheck(quant_type): block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] shape = [256, 1152] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) m = qweight.shape[0] n = qweight.shape[1] // type_size * block_size - opcheck(torch.ops._C.ggml_dequantize, - (qweight, quant_type, m, n, torch.float16)) + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16)) - x = torch.rand((m, 512), device='cuda', dtype=torch.float16) - opcheck(torch.ops._C.ggml_mul_mat_a8, - (qweight, x, quant_type, qweight.shape[0])) - opcheck(torch.ops._C.ggml_mul_mat_vec_a8, - (qweight, x, quant_type, qweight.shape[0])) + x = torch.rand((m, 512), device="cuda", dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0])) + opcheck( + torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0]) + ) shape = [256, 1024, 336] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) - x = torch.rand((1, 1024), device='cuda', dtype=torch.float16) - sorted_token_ids = torch.arange(776, device='cuda') - expert_ids = torch.randint(0, 256, (194, ), device='cuda') - num_tokens_post_padded = torch.tensor([1], - dtype=torch.int64, - device='cuda') + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) + x = torch.rand((1, 1024), device="cuda", dtype=torch.float16) + sorted_token_ids = torch.arange(776, device="cuda") + expert_ids = torch.randint(0, 256, (194,), device="cuda") + num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda") - opcheck(torch.ops._C.ggml_moe_a8, - (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, - quant_type, qweight.shape[0], 1, x.shape[0])) + opcheck( + torch.ops._C.ggml_moe_a8, + ( + x, + qweight, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + quant_type, + qweight.shape[0], + 1, + x.shape[0], + ), + ) - topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32) opcheck( torch.ops._C.ggml_moe_a8_vec, - (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]), + ) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 436d5cb640219..0988ba01759f2 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -18,8 +18,8 @@ GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample") def get_gguf_sample_tensors( - hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: 
sample_dir = GGUF_SAMPLE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -27,8 +27,8 @@ def get_gguf_sample_tensors( def get_gguf_MoE_tensors( - hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE_MOE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -68,17 +68,20 @@ QUANT_TYPES = [ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_dequantize(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_dequantize( + hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType +): tensors = get_gguf_sample_tensors(hidden_size, quant_type) for tensor in tensors: shape_str = tensor.name.split("_")[-1] shape = map(int, shape_str.split("x")) - ref_output = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) - output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), - quant_type, *list(shape), dtype) + ref_output = torch.tensor( + dequantize(tensor.data, quant_type), device="cuda" + ).to(dtype) + output = ops.ggml_dequantize( + torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype + ) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) @@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_mmvq(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") - output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, - qweight.shape[0]).to(dtype) + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to( + dtype + ) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) @@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, GGMLQuantizationType.Q4_0, GGMLQuantizationType.Q5_0, GGMLQuantizationType.Q8_0, - ]) + ], +) @torch.inference_mode() -def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmq( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, +): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") @@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, # bfloat16 tends to accumulate and can greatly inflate rtol # since outputs are 
also very close to 0 rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1} - torch.testing.assert_close(output, - ref_output, - atol=atols[dtype], - rtol=rtols[dtype]) + torch.testing.assert_close( + output, ref_output, atol=atols[dtype], rtol=rtols[dtype] + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType, top_k: int): +def test_moe( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, + top_k: int, +): current_platform.seed_everything(0) H, E = 1024, 256 x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, - E, (num_tokens, top_k), - device="cuda", - dtype=torch.int32) + topk_ids = torch.randint( + 0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32 + ) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) w13 = tensors[0] w2 = tensors[1] - w13_dequant = torch.tensor(dequantize(w13.data, quant_type), - device="cuda").to(dtype) + w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to( + dtype + ) - w2_dequant = torch.tensor(dequantize(w2.data, quant_type), - device="cuda").to(dtype) + w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype) - output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"), - torch.tensor(w2.data, - device="cuda"), topk_weights, - topk_ids, quant_type, quant_type, "silu") + output = _fused_moe_gguf( + x, + torch.tensor(w13.data, device="cuda"), + torch.tensor(w2.data, device="cuda"), + topk_weights, + topk_ids, + quant_type, + quant_type, + "silu", + ) - ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights, - topk_ids).reshape(output.shape) + ref_output = fused_experts( + x, w13_dequant, w2_dequant, topk_weights, topk_ids + ).reshape(output.shape) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_gptq.py b/tests/kernels/quantization/test_gptq.py index 7fb57a1576bd8..72e4194c13276 100644 --- a/tests/kernels/quantization/test_gptq.py +++ b/tests/kernels/quantization/test_gptq.py @@ -8,25 +8,22 @@ from vllm import _custom_ops as ops # noqa: F401 def test_gptq_shuffle_opcheck(): - weight = torch.randint(-2000000, - 2000000, (1792, 4096), - device='cuda', - dtype=torch.int32) - perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + weight = torch.randint( + -2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32 + ) + perm = torch.empty((0,), device="cuda", dtype=torch.int32) bit = 4 opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) def test_gptq_gemm_opcheck(): - a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) - weight = torch.randint(-2000000, - 2000000, (512, 6144), - device='cuda', - dtype=torch.int32) - zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) - scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) - idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + a = torch.rand((240, 4096), device="cuda", dtype=torch.float16) + weight = torch.randint( + -2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32 + ) + zeros = torch.zeros((32, 768), device="cuda", 
dtype=torch.int32) + scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16) + idx = torch.empty((0,), device="cuda", dtype=torch.int32) use_exllama = True bit = 4 - opcheck(torch.ops._C.gptq_gemm, - (a, weight, zeros, scales, idx, use_exllama, bit)) + opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit)) diff --git a/tests/kernels/quantization/test_hadacore.py b/tests/kernels/quantization/test_hadacore.py new file mode 100644 index 0000000000000..3ccee9db048cf --- /dev/null +++ b/tests/kernels/quantization/test_hadacore.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import pytest +import torch +from compressed_tensors.transform import deterministic_hadamard_matrix + +from vllm import _custom_ops as ops + + +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)]) +def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"): + x = torch.eye(hidden_dim, dtype=dtype, device=device) + hadamard = deterministic_hadamard_matrix( + hidden_dim, dtype=torch.float64, device="cuda" + ) / math.sqrt(hidden_dim) + + y = ops.hadacore_transform(x.clone()) + y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype) + assert torch.allclose(y, y_true) + + y = ops.hadacore_transform(y) + assert torch.allclose(y, x) diff --git a/tests/kernels/quantization/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py index dc5fecbf4ccc8..0e31e9aabea85 100644 --- a/tests/kernels/quantization/test_int8_kernel.py +++ b/tests/kernels/quantization/test_int8_kernel.py @@ -8,14 +8,15 @@ import pytest import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_quant_int8) + per_token_quant_int8, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): @@ -25,14 +26,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -42,7 +42,7 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): return C.reshape(origin_C_shape).to(output_dtype) -def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids): """This function performs fused moe with per-column int8 quantization using native torch.""" @@ -57,8 +57,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): out = torch.zeros(B * topk, w2.shape[1], 
dtype=a.dtype, device=a.device) # Calculate routing - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) # Process each expert @@ -66,25 +64,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): mask = topk_ids == i if mask.sum(): # First MLP layer: note that a_s is now per-token - inter_out = native_w8a8_per_token_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - output_dtype=a.dtype) + inter_out = native_w8a8_per_token_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype + ) # Activation function act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = per_token_quant_int8(act_out) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -102,8 +97,10 @@ TOP_KS = [2, 6] SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -127,24 +124,32 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(score, topk) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( + ref_out = torch_w8a8_per_column_moe( + a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids + ) + + quant_config = FusedMoEQuantConfig.make( + torch.int8, + per_act_token_quant=True, + block_shape=None, + w1_scale=w1_s, + w2_scale=w2_s, + ) + + out = fused_experts( a, w1, w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, # Using int8-w8a8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization + topk_weights, + topk_ids, + quant_config=quant_config, ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py index c1c9bf191d5b5..48e947db5fa78 100644 --- a/tests/kernels/quantization/test_int8_quant.py +++ b/tests/kernels/quantization/test_int8_quant.py @@ -18,26 +18,24 @@ SCALE = [0.1, 2.1] def opcheck_int8_quant_static(output, input, scale, azp=None): if azp is None: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, 
None)) else: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, azp)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp)) def opcheck_int8_quant_dynamic(output, input, symmetric=True): - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) if symmetric: - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None)) else: - azp = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.int32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, azp)) + azp = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32, + ) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,30 +67,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_azp_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) - azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( - torch.int32) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32) - torch_out = ((x / scales).round() + azps).clamp( - int8_traits.min, int8_traits.max).to(torch.int8) - assert torch_out.min() >= int8_traits.min and torch_out.max( - ) <= int8_traits.max + torch_out = ( + ((x / scales).round() + azps) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) + assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False) - if (not torch.allclose(scales_out, scales)): + if not torch.allclose(scales_out, scales): print(torch.argmax(torch.abs(scales_out - scales))) torch.testing.assert_close(scales_out, scales) # big atol to account for rounding errors @@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("scale", SCALE) 
@torch.inference_mode() -def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float) -> None: +def test_static_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") - out1 = (x / scale_arg).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + (x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) + ) out2, scale2, _ = scaled_int8_quant(x, scale_arg) assert scale2 is scale_arg @@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("scale", SCALE) @pytest.mark.parametrize("azp", [-255, 54]) @torch.inference_mode() -def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float, azp: int) -> None: +def test_static_scaled_int8_azp_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + scale: float, + azp: int, +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 - out1 = ((x / scale).round() + azp).clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + ((x / scale).round() + azp) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") - out2, scale2, azp2 = scaled_int8_quant(x, - scale_arg, - azp_arg, - symmetric=False) + out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False) assert scale2 is scale_arg assert azp2 is azp_arg @@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: int32_traits = torch.iinfo(torch.int32) val = float(int32_traits.max if is_max else int32_traits.min) - x_vals = [[ - nextafter(val, inf), val + 1, val, val - 1, - nextafter(val, -inf) - ]] + x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]] x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index 0e09661c955e4..b32523bb85d9a 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the machete kernel. -Run `pytest tests/kernels/test_machete_mm.py`. +Run `pytest tests/kernels/quantization/test_machete_mm.py`. 
""" import math @@ -15,15 +15,16 @@ import torch from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.machete_utils import ( - query_machete_supported_group_sizes) + query_machete_supported_group_sizes, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # unit tests to a common utility function. Currently the use of @@ -72,29 +73,38 @@ class Tensors: # Ch Scales Type, Tok Scales Type) # NOTE: None "Scale Type" means the act type is floating point # None "Output Type" means the output type is the same as the act type -TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], - Optional[torch.dtype], bool] +TestTypeTuple = tuple[ + list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool +] TEST_TYPES = [ # GPTQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16] + ), # AWQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=a_type, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4, scalar_types.uint8] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16] + ), # # QQQ style # *(TypeConfig(act_type=torch.int8, # weight_type=scalar_types.uint4b8, @@ -133,17 +143,18 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): return zps if zps is None else -1 * s * (zps.to(s.dtype)) -def group_size_valid(shape: tuple[int, int, int], - group_size: Optional[int]) -> bool: +def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: return group_size is None or group_size == -1 or shape[2] % group_size == 0 -def machete_quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def machete_quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -152,7 +163,8 @@ def machete_quantize_and_pack(atype: torch.dtype, group_size=group_size, 
zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) w_q = w_q.t().contiguous().t() # convert to col major @@ -163,15 +175,18 @@ def machete_quantize_and_pack(atype: torch.dtype, return w_ref, w_q_machete, w_s, w_zp -def create_test_tensors(shape: tuple[int, int, int], - types: TypeConfig, - group_size: Optional[int], - subset_stride_factor: Optional[int] = None) -> Tensors: +def create_test_tensors( + shape: tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None, +) -> Tensors: m, n, k = shape factor = subset_stride_factor or 1 - print("create_test_tensors, shape:", shape, "types:", types, "group_size:", - group_size) + print( + "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size + ) a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) @@ -186,8 +201,13 @@ def create_test_tensors(shape: tuple[int, int, int], w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -196,35 +216,47 @@ def create_test_tensors(shape: tuple[int, int, int], a_ref = a.to(torch.float32) w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) - return Tensors(w_ref=w_ref, - a_ref=a_ref, - a=a, - w_q=w_q_packed, - w_g_s=w_s, - w_g_zp=maybe_convert_zeropoints(w_zp, w_s), - w_ch_s=w_ch_s, - w_tok_s=w_tok_s) + return Tensors( + w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) # None stype means scales use the same dtype as a -def machete_mm_test_helper(types: TypeConfig, - tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = None): +def machete_mm_test_helper( + types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None, +): output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) output_ref_type = output_ref.dtype if tensors.w_ch_s is not None: - output_ref = (output_ref.to(tensors.w_ch_s.dtype) * - tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0) + ).to(output_ref_type) if tensors.w_tok_s is not None: - output_ref = (output_ref.to(tensors.w_tok_s.dtype) * - tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1) + ).to(output_ref_type) output = ops.machete_mm( a=tensors.a, @@ -245,23 +277,23 @@ def machete_mm_test_helper(types: TypeConfig, # Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol when we have zeropoints since the way 
machete applies # zeropoints (after scales) causes noise around 0 - atol = 1 if tensors.w_g_zp is not None\ + atol = ( + 1 + if tensors.w_g_zp is not None else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + ) rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 - torch.testing.assert_close(output, - output_ref.to(output.dtype), - rtol=rtol, - atol=atol) + torch.testing.assert_close( + output, output_ref.to(output.dtype), rtol=rtol, atol=atol + ) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_all_schedules(shape, types: TypeConfig): - group_sizes: list[Optional[int]] = [] if types.group_scale_type is None: group_sizes = [None] @@ -275,20 +307,20 @@ def test_machete_all_schedules(shape, types: TypeConfig): tensors = create_test_tensors(shape, types, group_size) print(f"MNK = {shape}") for schedule in ops.machete_supported_schedules( - types.act_type, - types.weight_type, - group_scales_type=types.group_scale_type, - group_zeros_type=types.group_scale_type, - out_type=types.output_type): + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type, + ): print(f"Testing schedule {schedule}") machete_mm_test_helper(types, tensors, group_size, schedule) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_heuristic(shape, types: TypeConfig): group_sizes: list[Optional[int]] = [] @@ -306,19 +338,22 @@ def test_machete_heuristic(shape, types: TypeConfig): # Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_machete_devices(device: str): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) @@ -331,29 +366,30 @@ def test_machete_devices(device: str): # Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." 
+) def test_machete_subset(): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) - tensors = create_test_tensors((512, 4096, 4096), - type_config, - group_size, - subset_stride_factor=2) + tensors = create_test_tensors( + (512, 4096, 4096), type_config, group_size, subset_stride_factor=2 + ) machete_mm_test_helper(type_config, tensors, group_size) # Test to make sure cuda graphs work class MacheteLayer(torch.nn.Module): - def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs @@ -362,8 +398,9 @@ class MacheteLayer(torch.nn.Module): return ops.machete_mm(a=a, **self.kwargs) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) def test_machete_cuda_graph(): m, n, k = 512, 4096, 4096 @@ -375,7 +412,8 @@ def test_machete_cuda_graph(): zero_points = False w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, b, wtype, stype, group_size, zero_points) + a.dtype, b, wtype, stype, group_size, zero_points + ) # Construct a trivial model with a single layer that calls a machete kernel model = MacheteLayer( diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index ad077e0b94732..0833115fcf301 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the marlin kernel. -Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. 
""" + import pytest import torch @@ -11,24 +12,44 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, - query_marlin_supported_quant_types) + MARLIN_SUPPORTED_GROUP_SIZES, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, - rand_marlin_weight_nvfp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, - marlin_weights) + MarlinWorkspace, + awq_marlin_quantize, + get_weight_perm, + marlin_quantize, + marlin_weights, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + awq_pack, + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] @@ -56,24 +77,27 @@ DTYPES = [torch.float16, torch.bfloat16] def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False, False)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - act_order, mnk_factors): +def test_gptq_marlin_repack( + k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors +): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Quantize (and apply act_order if provided) w_ref, 
q_w, s, g_idx, rand_perm = gptq_quantize_weights( - b_weight, quant_type, group_size, act_order) + b_weight, quant_type, group_size, act_order + ) # Pack to GPTQ format q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) @@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.gptq_marlin_repack, - (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.gptq_marlin_repack, + (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights(b_weight, - quant_type, - group_size, - zero_points=True) + w_ref, q_w, s, zp = quantize_weights( + b_weight, quant_type, group_size, zero_points=True + ) # Pack to AWQ format q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.awq_marlin_repack, - (q_w_awq, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( @@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) @pytest.mark.parametrize( - "group_size", - set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) + "group_size", 
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) +) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) @pytest.mark.parametrize("dtype", DTYPES) -def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, - mnk_factors, act_order, is_k_full, use_atomic_add, - use_fp32_reduce, dtype): +def test_gptq_marlin_gemm( + k_chunk, + n_chunk, + quant_type, + group_size, + mnk_factors, + act_order, + is_k_full, + use_atomic_add, + use_fp32_reduce, + dtype, +): m_factor, n_factor, k_factor = mnk_factors has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] @@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, return if group_size == 16: - w_ref, marlin_q_w, marlin_s, marlin_s2 = \ - rand_marlin_weight_nvfp4_like(b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like( + b_weight.T, group_size + ) else: - w_ref, marlin_q_w, marlin_s = \ - rand_marlin_weight_mxfp4_like(b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like( + b_weight.T, group_size + ) marlin_s2 = None g_idx = None @@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, return if act_order: return - w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( - b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size) g_idx = None sort_indices = None marlin_zp = None @@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, if group_size == 16: return w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) + b_weight, quant_type, group_size + ) g_idx = None sort_indices = None marlin_s2 = None @@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, if group_size == 16: return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order) + b_weight, quant_type, group_size, act_order + ) marlin_zp = None marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp, - g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, - use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_gemm, + ( + a_input, + None, + marlin_q_w, + None, + marlin_s, + marlin_s2, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + use_atomic_add, + use_fp32_reduce, + False, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = ops.gptq_marlin_gemm( a_input, @@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, # TODO: find better way to test this? 
@torch.compile(fullgraph=True) -def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, size_n, - size_k): - return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, - size_n, size_k) +def marlin_24_gemm_tester( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, +): + return ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, + ) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize( + b_weight, quant_type, group_size + ) - workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + workspace_24 = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) output_ref = torch.matmul(a_input, w_24_ref) - opcheck(torch.ops._C.gptq_marlin_24_gemm, - (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, - workspace_24.scratch, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1]), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_24_gemm, + ( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = marlin_24_gemm_tester( a_input, @@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES) @@ -386,22 +478,22 @@ def test_hqq_marlin_gemm( a_input = rand_data((size_m, size_k)) dev = a_input.device - b_weight = torch.randint(0, - 10, (size_n, size_k), - dtype=torch.uint8, - device=dev) + b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev) scale = rand_data((size_n, size_k // group_size)) zero = rand_data((size_n, size_k // group_size)) 
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n) sort_indices = torch.empty(0, dtype=torch.int, device=dev) - marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, - 4).to(dev) - marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n, - group_size).to(dev) - marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n, - group_size).to(dev) + marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to( + dev + ) + marlin_s = marlin_permute_scales( + scale.transpose(1, 0), size_k, size_n, group_size + ).to(dev) + marlin_zp = marlin_permute_scales( + zero.transpose(1, 0), size_k, size_n, group_size + ).to(dev) g_idx = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -433,8 +525,7 @@ def test_hqq_marlin_gemm( s_flat = scale.reshape(-1, 1) dequant = (b_flat - zp_flat) * s_flat - output_ref = torch.matmul(a_input, - dequant.reshape(b_weight.shape).transpose(1, 0)) + output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0)) torch.cuda.synchronize() @@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input(): big_m = size_m * 2 big_k = size_k * 2 - a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8] + a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8] b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, False) + b_weight, quant_type, group_size, False + ) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) @@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m): size_k, size_n = 1024, 2048 a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - b_bias = rand_data((size_n, )) * 10 + b_bias = rand_data((size_n,)) * 10 marlin_bias = marlin_permute_bias(b_bias) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, False) + b_weight, quant_type, group_size, False + ) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) diff --git a/tests/kernels/quantization/test_mxfp4_qutlass.py b/tests/kernels/quantization/test_mxfp4_qutlass.py new file mode 100644 index 0000000000000..0bacbef2046b4 --- /dev/null +++ b/tests/kernels/quantization/test_mxfp4_qutlass.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + scales_dq = x_e8m0.to(torch.float64) + + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha + return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref( + x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True +): + device = x.device + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + if quest: + scales_ref64_ = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0) + * (2.92247856 / 6.0) + + 1e-8 + ) + else: + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu) + scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64) + + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + if not quest: + xh_scaled_ref64 
*= 3 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4( + xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0 + ) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") + +ROT_SIZES = [32, 64, 128] +SEEDS = [0] +BATCHES = [1, 16] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_absmax(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 1, 504, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max") + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_quest(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 504, 504, 2048 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + 
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("had_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(had_size, dtype, device) + + a = torch.rand(m, k, dtype=dtype, device=device) * 25.0 + b = torch.rand(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 3a8f4c17598c2..e9b091d06697e 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -8,15 +8,27 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] -PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), - (90, 128), (150, 128), (150, 48), (90, 80)] +PAD_SHAPES = [ + (90, 64), + (150, 64), + (128, 48), + (128, 80), + (150, 80), + (90, 48), + (90, 128), + (150, 128), + (150, 48), + (90, 80), +] SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max @@ -31,7 +43,22 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 0001 -> 0.5 # 0000 -> 0 E2M1_TO_FLOAT32 = [ - 0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6. 
+ 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] BLOCK_SIZE = 16 @@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale): assert x.ndim == 2 m, n = x.shape x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) @@ -131,7 +157,7 @@ def test_quantize_to_fp4( def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: dtype = torch.float16 current_platform.seed_everything(42) - torch.set_default_device('cuda:0') + torch.set_default_device("cuda:0") m, n = pad_shape diff --git a/tests/kernels/quantization/test_nvfp4_qutlass.py b/tests/kernels/quantization/test_nvfp4_qutlass.py new file mode 100644 index 0000000000000..3824a080f5047 --- /dev/null +++ b/tests/kernels/quantization/test_nvfp4_qutlass.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
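# A minimal sketch contrasting the two block-scale formats exercised by these
# qutlass tests: the MXFP4 helpers above scale 32-element groups with power-of-two
# E8M0 values, while the NVFP4 helpers below scale 16-element groups with E4M3
# values. This mirrors the reference helpers; the input is made up, and the
# float8 dtypes require a recent PyTorch build, as these tests already assume.
import torch

x = torch.randn(4, 64, dtype=torch.float64)

# MXFP4-style scale: abs-max per 32-wide group, rounded down to a power of two.
mx_absmax = x.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1) + 1e-8
mx_scale = mx_absmax.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu)

# NVFP4-style scale: abs-max per 16-wide group, cast directly to FP8 E4M3.
nv_absmax = x.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1) + 1e-8
nv_scale = nv_absmax.to(dtype=torch.float8_e4m3fn)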
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + + scales_dq = x_e4m3.to(torch.float64) + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha # * (4. / 3.) 
+ return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int): + device = x.device + + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn) + scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64) + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + + xh_scaled_ref64 *= 6.0 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") +ROT_SIZES = [16, 32, 64, 128] +GLOBAL_SCALES = [6.0] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES) +@torch.inference_mode() +def test_fused_quantization(rot_size: int, global_scale_value: float): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + global_scale = torch.tensor([global_scale_value], device=device) + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size) + xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale) + xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1 + + m, n, k = 504, 4096 * 2, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + 
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(rot_size, dtype, device) + + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + global_scale = torch.tensor([1.0], device=device) + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 67e041f2b71c4..434564737c889 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch -from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype from vllm import _custom_ops as ops from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] # m, n, k @@ -19,26 +20,31 @@ PAD_SHAPES = [(150, 128, 64), (128, 128, 96)] SHAPES.extend(PAD_SHAPES) SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] -def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, - m, n, dtype, block_size, device): +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): _, m_k = a_fp4.shape _, n_k = b_fp4.shape - assert (m_k == n_k) - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + ) + b_in_dtype = dequantize_nvfp4_to_dtype( + b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size + ) return torch.matmul(a_in_dtype, b_in_dtype.t()) @@ -60,25 +66,34 @@ 
def test_nvfp4_gemm( a_dtype = torch.randn((m, k), dtype=dtype, device=device) b_dtype = torch.randn((n, k), dtype=dtype, device=device) - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) - b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) - alpha = 1. / (a_global_scale * b_global_scale) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) # get_ref_results unswizzles the scales internally. - expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, a_global_scale, - b_global_scale, m, n, dtype, block_size, - device) - out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, alpha, dtype) + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + out = ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype + ) - torch.testing.assert_close(out, - expected_out.to(dtype=dtype), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_per_token_group_quant.py b/tests/kernels/quantization/test_per_token_group_quant.py index 07f17d1efe641..7a65004545306 100644 --- a/tests/kernels/quantization/test_per_token_group_quant.py +++ b/tests/kernels/quantization/test_per_token_group_quant.py @@ -13,15 +13,15 @@ from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils @pytest.mark.parametrize("scale_ue8m0", [False, True]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_per_token_group_quant_fp8(shape, column_major: bool, - scale_ue8m0: bool, group_size: int): +def test_per_token_group_quant_fp8( + shape, column_major: bool, scale_ue8m0: bool, group_size: int +): device = "cuda" torch.manual_seed(42) num_tokens, hidden_dim = shape - x = (torch.randn( - (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8) + x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8 # cuda path out_q, scale = fp8_utils.per_token_group_quant_fp8( @@ -53,8 +53,7 @@ def test_per_token_group_quant_int8(shape, group_size: int): torch.manual_seed(42) num_tokens, hidden_dim = shape - x = (torch.randn( - (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8) + x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8 # cuda path out_q, scale = int8_utils.per_token_group_quant_int8( diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 03d5d98739c50..dc6557b93f050 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: 
Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + import pytest import torch @@ -47,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [ (2, 512, 512), (3, 2048, 2048), (4, 4096, 4096), + (4, 16400, 2048), # Extended FP8 dimensions not covered by WVSPLITK (1, 14336, 1024), (2, 24576, 2048), @@ -60,11 +63,13 @@ SEEDS = [0] @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @torch.inference_mode() def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): torch.manual_seed(seed) + # TODO: Zero-centering the inputs causes errors for LLMM1! + # Without that the numbers quickly saturate, and may + # be giving false matches. A = torch.rand(n, k, dtype=dtype, device="cuda") B = torch.rand(m, k, dtype=dtype, device="cuda") @@ -77,17 +82,54 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() - A = torch.rand(n, k, dtype=dtype, device="cuda") - B = torch.rand(m, k, dtype=dtype, device="cuda") + A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 - ref_out = torch.matmul(A, B.t()) - out = ops.wvSplitK(B, A, cu_count) + ref_out = torch.nn.functional.linear(A, B) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") +def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") +def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier + BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) assert torch.allclose(out, ref_out, 
rtol=0.01) @@ -97,22 +139,48 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), - reason="only test for rocm fp8") + reason="only test for rocm fp8", +) def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - A = torch.rand(n, k, device="cuda") - B = torch.rand(m, k, device="cuda") + A = torch.rand(n, k, device="cuda") - 0.5 + B = torch.rand(m, k, device="cuda") - 0.5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - ref_out = torch._scaled_mm(A, - B.t(), - out_dtype=dtype, - scale_a=scale_a, - scale_b=scale_b) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, - current_platform.get_cu_count()) + ref_out = torch._scaled_mm( + A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b + ) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count()) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="only test for rocm fp8", +) +def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, device="cuda") - 0.5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 + + A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) + B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + + ref_out = torch._scaled_mm( + A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS + ) + out = ops.wvSplitKQ( + B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS + ) assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py new file mode 100644 index 0000000000000..4617464a39788 --- /dev/null +++ b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) +from vllm._custom_ops import scaled_fp4_quant +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +FP4_DTYPE = torch.uint8 +FP8_DTYPE = current_platform.fp8_dtype() + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)] +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@torch.inference_mode() +def test_silu_mul_nvfp4_quant( + dtype: torch.dtype, + shape: tuple[int, int], +) -> None: + current_platform.seed_everything(42) + device = "cuda:0" + torch.set_default_device(device) + + x = torch.randn(shape, dtype=dtype) + + # ref op + ref_output = SiluAndMul().forward_native(x) + ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + ref_output + 
).max().to(torch.float32) + ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale) + + # fused op + fused_output_quant = torch.empty_like(ref_output_quant) + fused_block_scale = torch.empty_like(ref_block_scale) + torch.ops._C.silu_and_mul_nvfp4_quant( + fused_output_quant, fused_block_scale, x, ref_global_scale + ) + + # check dtype + assert ref_output_quant.dtype == FP4_DTYPE + assert fused_output_quant.dtype == FP4_DTYPE + assert ref_output_quant.shape == fused_output_quant.shape + + assert ref_block_scale.dtype == FP8_DTYPE + assert fused_block_scale.dtype == FP8_DTYPE + assert ref_block_scale.shape == fused_block_scale.shape + + # check dequantized output + ref_output_dequant = dequantize_nvfp4_to_dtype( + ref_output_quant, ref_block_scale, ref_global_scale, dtype, device + ) + fused_output_dequant = dequantize_nvfp4_to_dtype( + fused_output_quant, fused_block_scale, ref_global_scale, dtype, device + ) + + atol, rtol = 3e-1, 3e-1 + torch.testing.assert_close( + ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol + ) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 24245663fb1d6..1026332d99f89 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the triton_scaled_mm kernel -Run `pytest tests/kernels/test_triton_scaled_mm.py`. +Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. """ + import importlib from typing import Optional @@ -15,17 +16,19 @@ from vllm.platforms import current_platform device = "cuda" triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") + "vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm" +) triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm -def torch_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = scale_a * out out = scale_b.T * out @@ -44,20 +47,22 @@ def get_8bit_types(): # This test is to check regressions for int8 support on ROCm. 
-@pytest.mark.parametrize("model_path", [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", -]) +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Should only run on ROCm") -def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, - max_tokens, num_logprobs): +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm") +def test_rocm_compressed_tensors_w8a8( + vllm_runner, example_prompts, model_path, max_tokens, num_logprobs +): dtype = "bfloat16" with vllm_runner(model_path, dtype=dtype) as vllm_model: - vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, - num_logprobs) + vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) MNK_FACTORS = [ @@ -76,10 +81,10 @@ MNK_FACTORS = [ @pytest.mark.parametrize("use_scalar_scale_a", [True, False]) @pytest.mark.parametrize("use_scalar_scale_b", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, - use_scalar_scale_b, use_bias): - is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t - ).is_floating_point() +def test_scaled_mm( + M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias +): + is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point() current_platform.seed_everything(0) @@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, # # So, the values here are kept small enough to avoid this situation. if is_floating_point_type(in_dtype): - a = (0.25 * torch.rand( - (M, K), dtype=torch.float32, device=device)).to(in_dtype) - b = (0.25 * torch.rand( - (K, N), dtype=torch.float32, device=device)).to(in_dtype) + a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype) + b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype) else: a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device) b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device) @@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, bias = None if use_bias: - bias = torch.rand((N, ), device=device, dtype=out_dtype) + bias = torch.rand((N,), device=device, dtype=out_dtype) c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) diff --git a/tests/kernels/test_apply_repetition_penalties.py b/tests/kernels/test_apply_repetition_penalties.py index 90380b872d6c2..a4619f5846b16 100644 --- a/tests/kernels/test_apply_repetition_penalties.py +++ b/tests/kernels/test_apply_repetition_penalties.py @@ -4,8 +4,10 @@ import pytest import torch from tests.kernels.utils import opcheck -from vllm._custom_ops import (apply_repetition_penalties_cuda, - apply_repetition_penalties_torch) +from vllm._custom_ops import ( + apply_repetition_penalties_cuda, + apply_repetition_penalties_torch, +) from vllm.platforms import current_platform NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025] @@ -21,8 +23,9 @@ DTYPES = [torch.float32, torch.float16] @pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not 
current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties( num_seqs: int, @@ -32,7 +35,7 @@ def test_apply_repetition_penalties( seed: int, ) -> None: """ - Test the apply_repetition_penalties custom op + Test the apply_repetition_penalties custom op against a reference implementation. """ current_platform.seed_everything(seed) @@ -46,39 +49,40 @@ def test_apply_repetition_penalties( output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool) # Mark some tokens as repeated in prompt and output - prompt_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) - output_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) + prompt_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) + output_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) for i in range(num_seqs): prompt_mask[i, prompt_indices[i]] = True output_mask[i, output_indices[i]] = True # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties_zero_seqs() -> None: """ @@ -104,22 +108,24 @@ def test_apply_repetition_penalties_zero_seqs() -> None: # No tokens to mark as repeated since num_seqs=0 # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + 
opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py deleted file mode 100644 index 2b745b84dae6c..0000000000000 --- a/tests/kernels/test_cutlass_mla_decode.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -import torch -import torch.nn.functional as F -from torch import Tensor - -import vllm._custom_ops as ops -from vllm.platforms import current_platform - -if not current_platform.has_device_capability(100): - pytest.skip( - reason="Cutlass MLA Requires compute capability of 10 or above.", - allow_module_level=True) - - -def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) -): - bs, num_heads, v_head_dim = out.shape - head_dim = query.shape[2] - - for i in range(bs): - # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) - v = kv[:, :, :v_head_dim] - - q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) - out[i] = o.view(num_heads, v_head_dim) - - return out - - -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) -@pytest.mark.parametrize("bs", [1, 2, 4]) -@pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("block_size", [16, 64, 128]) -def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, - varlen: bool, block_size: int): - torch.set_default_dtype(dtype) - torch.set_default_device('cuda') - torch.manual_seed(42) - - d = 576 - h_q = 128 - dv = 512 - - q_nope_dim = 128 - q_pe_dim = 64 - scale = (q_nope_dim + q_pe_dim)**(-0.5) - if varlen: - seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) - seq_lens = seq_lens.clip(2).to(torch.int32) - else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) - max_seq_len = seq_lens.max().item() - block_num = (max_seq_len + block_size - 1) // block_size - - # Pad block_num so that small blocks can be packed into full 128-sized - # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small - # blocks. - pack_factor = 128 // block_size - block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - - # Amplify input values to ensure test coverage of edge cases where CUTLASS - # kernel errors occur with split_k settings. 
- q = torch.randn(bs, h_q, d) * 100 - block_table = torch.randint(0, - bs * block_num, (bs, block_num), - dtype=torch.int32) - - kv_cache = torch.randn(block_table.numel(), block_size, d) - - out_ref = q.new_zeros(bs, h_q, dv) - ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out_ans = torch.zeros_like(out_ref) - q_nope = q[:, :, :dv].clone() - q_pe = q[:, :, dv:].clone() - ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, - block_table, scale) - - torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index f76bd192460c9..ae33f422d3732 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -9,12 +9,19 @@ import pytest import torch from packaging import version -from vllm import SamplingParams +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, +) +from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder -from ..models.utils import check_embeddings_close +from ..models.utils import check_embeddings_close, check_logprobs_close TORCH_VERSION = version.parse(torch.__version__) MINIMUM_TORCH_VERSION = version.parse("2.7.0") +DIRECT_BUILD_VERSION = version.parse("2.9.dev0") def set_seed(seed): @@ -34,57 +41,55 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with - the default backend, ensuring they are identical when using the same seed. + the default backend, ensuring they are similar when using the same seed. """ model_name = "Qwen/Qwen2.5-1.5B-Instruct" seed = 42 max_tokens = 24 + num_logprobs = 5 prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) - # Run with flex attention with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") set_seed(seed) - with vllm_runner(model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True) as llm_flex: - output_flex = llm_flex.generate(prompts, sampling_params) + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + ) as llm_flex: + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) # Run with default backend with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") set_seed(seed) - with vllm_runner(model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True) as llm_default: - output_default = llm_default.generate(prompts, sampling_params) + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + gpu_memory_utilization=0.85, + ) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) - # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): - prompt = prompts[i] - flex_text = flex_result[1][0] - default_text = default_result[1][0] - - assert flex_text == default_text, ( - 
f"FlexAttention output doesn't match default for: {prompt!r}\n" - f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + check_logprobs_close( + outputs_0_lst=output_flex, + outputs_1_lst=output_default, + name_0="flex", + name_1="default", + ) @@ -106,26 +111,30 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): # Run with flex attention with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") - with vllm_runner(model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True) as llm_flex: + with vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + ) as llm_flex: flex_outputs = llm_flex.embed(prompts) # Run with default backend - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - with vllm_runner(model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True) as llm_default: - default_outputs = llm_default.embed(prompts) + with ( + monkeypatch.context() as m, + vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + ) as llm_default, + ): + default_outputs = llm_default.embed(prompts) check_embeddings_close( embeddings_0_lst=flex_outputs, @@ -136,5 +145,72 @@ ) +@pytest.mark.skipif( + not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION, + reason="CUDA not available or PyTorch version < 2.9", +) +def test_block_mask_direct_vs_slow_path(): + """Test that direct path block mask is a superset of slow path. + + The direct path may include extra blocks for performance (over-estimation), + but must include all blocks that the slow path determines are necessary.
+ """ + device = torch.device("cuda") + + vllm_config = create_vllm_config( + model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024 + ) + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # Use a mixed batch that will create groups spanning multiple sequences + batch_spec = BatchSpec( + seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch" + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device + ) + + builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device) + + metadata_direct = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + builder.direct_build = False + metadata_slow = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + assert metadata_direct.block_mask is not None + assert metadata_slow.block_mask is not None + + # Extract block indices for comparison, B, H are the same + direct_indices = metadata_direct.block_mask.kv_indices[0, 0] + slow_indices = metadata_slow.block_mask.kv_indices[0, 0] + direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0] + slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0] + + # main test: every block needed by slow path must be in direct path + num_groups = direct_num.shape[0] + all_contained = True + missing_details = [] + + for group_idx in range(num_groups): + direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist()) + slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist()) + + missing_blocks = slow_blocks - direct_blocks + if missing_blocks: + all_contained = False + missing_details.append( + f"Group {group_idx}: missing {sorted(missing_blocks)}" + ) + + assert all_contained, ( + "Direct path is missing blocks required by slow path:\n" + + "\n".join(missing_details) + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index 803453a20d81d..c79e6105e69fa 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -13,13 +13,12 @@ QUANT_DTYPES = [current_platform.fp8_dtype()] NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] -def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, - scale: torch.Tensor) -> torch.Tensor: +def ref_impl( + silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor +) -> torch.Tensor: silu_and_mul_out = silu_and_mul.forward_native(x) out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) return out @@ -27,9 +26,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: out_shape = (x.shape[0], x.shape[1] // 2) - out = torch.empty(out_shape, - dtype=current_platform.fp8_dtype(), - device=x.device) + out = torch.empty(out_shape, dtype=current_platform.fp8_dtype(), device=x.device) torch.ops._C.silu_and_mul_quant(out, x, scale) return out @@ -57,7 +54,7 @@ def test_silu_and_mul( layer = SiluAndMul() # Make inputs - scale = (torch.randn((1), device=device, dtype=torch.float32)) + scale = torch.randn((1), device=device, dtype=torch.float32) x = 
torch.randn(num_tokens, hidden_size, dtype=dtype) ref_out = ref_impl(layer, x, scale) @@ -66,6 +63,7 @@ def test_silu_and_mul( assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype assert ref_out.shape == ops_out.shape - assert torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + assert torch.allclose( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py index 17692384ac9a9..9f78c177a81f0 100644 --- a/tests/kernels/test_onednn.py +++ b/tests/kernels/test_onednn.py @@ -44,24 +44,27 @@ def ref_int8_scaled_mm( ): if azp is not None: a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32) - output = torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))) + output = torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ) if bias is not None: output += bias.float() return output.to(dtype=output_type) -def onednn_int8_gemm_test_helper(primitive_cache_size: int, - m: int, - n: int, - k: int, - per_tensor_a_quant: bool, - per_tensor_b_quant: bool, - use_azp: bool, - use_bias: bool, - out_dtype: torch.dtype = torch.bfloat16, - device: str = "cpu"): +def onednn_int8_gemm_test_helper( + primitive_cache_size: int, + m: int, + n: int, + k: int, + per_tensor_a_quant: bool, + per_tensor_b_quant: bool, + use_azp: bool, + use_bias: bool, + out_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu", +): # Test for a oneDNN kernel with per-tensor / per-token activation # quantization and per-tensor / per-output channel weight quantization. a = to_int8(torch.randn((m, k), device=device) * 5) @@ -70,8 +73,8 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1) b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) if use_azp: azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5 @@ -81,10 +84,7 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, azp = None azp_adj = None - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None handler = ops.create_onednn_scaled_mm( b, @@ -105,12 +105,58 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, # To test runtime bias setting out = torch.zeros((m, n), dtype=out_dtype) ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None) - baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, - out_dtype) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, out_dtype) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) +def onednn_gemm_test_helper( + primitive_cache_size: int, + m: int, + n: int, + k: int, + use_bias: bool, + use_stride: bool, + dtype: torch.dtype = torch.bfloat16, + device: str = "cpu", +): + if use_stride: + a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5 + a = a[:, :k] + else: + a = torch.rand((m, k), dtype=dtype, device=device) * 1.5 + + b = torch.rand((n, k), dtype=dtype, device=device) * 1.5 + + 
if use_bias: + bias = torch.rand((n,), device=device, dtype=dtype) * 5 + bias_f32 = bias.float() + else: + bias = None + bias_f32 = None + + handler = ops.create_onednn_mm( + b.t(), + primitive_cache_size, + ) + + out = ops.onednn_mm(handler, a, bias) + baseline = torch.nn.functional.linear(a.float(), b.float(), bias_f32).to( + dtype=a.dtype + ) + + torch.testing.assert_close(out, baseline) + + if use_bias: + # To test runtime bias setting + out = ops.onednn_mm(handler, a, None) + baseline = torch.nn.functional.linear(a.float(), b.float(), None).to( + dtype=a.dtype + ) + + torch.testing.assert_close(out, baseline) + + @pytest.mark.parametrize("n,k", NK_FACTORS) @pytest.mark.parametrize("m_list", M_FACTORS) @pytest.mark.parametrize("per_tensor_a_scale", [True, False]) @@ -122,7 +168,7 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, def test_onednn_int8_scaled_gemm( n: int, k: int, - m_list: tuple[int], + m_list: tuple[int, ...], per_tensor_a_scale: bool, per_tensor_b_scale: bool, use_bias: bool, @@ -142,3 +188,30 @@ def test_onednn_int8_scaled_gemm( use_azp=use_azp, out_dtype=output_type, ) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_stride", [True, False]) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_gemm( + n: int, + k: int, + m_list: tuple[int, ...], + use_bias: bool, + use_stride: bool, + dtype: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + use_bias=use_bias, + use_stride=use_stride, + dtype=dtype, + ) diff --git a/tests/kernels/test_shuffle_rows.py b/tests/kernels/test_shuffle_rows.py index 7d02e1764e7d4..c7de64066e87b 100644 --- a/tests/kernels/test_shuffle_rows.py +++ b/tests/kernels/test_shuffle_rows.py @@ -14,20 +14,15 @@ from vllm.platforms import current_platform @pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024]) @pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) -def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, - dtype: torch.dtype): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, dtype: torch.dtype): """Test basic functionality of shuffle_rows with various tensor sizes and dtypes.""" if not current_platform.is_cuda(): pytest.skip("shuffle_rows requires CUDA") # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a simple permutation map (identity mapping) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) @@ -47,24 +42,18 @@ def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("num_tokens", [16, 64, 128]) @pytest.mark.parametrize("hidden_size", [128, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int, - dtype: torch.dtype): +def test_shuffle_rows_permutation( + num_tokens: int, hidden_size: int, dtype: torch.dtype +): """Test shuffle_rows with actual permutation.""" if not current_platform.is_cuda(): 
pytest.skip("shuffle_rows requires CUDA") # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a reverse permutation map - dst2src_map = torch.arange(num_tokens - 1, - -1, - -1, - device="cuda", - dtype=torch.int32) + dst2src_map = torch.arange(num_tokens - 1, -1, -1, device="cuda", dtype=torch.int32) # Test shuffle_rows output = shuffle_rows(input_tensor, dst2src_map) @@ -90,17 +79,13 @@ def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int): dtype = torch.float16 # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a mapping that duplicates some tokens (expansion) expanded_size = num_tokens * 2 - dst2src_map = torch.randint(0, - num_tokens, (expanded_size, ), - device="cuda", - dtype=torch.int32) + dst2src_map = torch.randint( + 0, num_tokens, (expanded_size,), device="cuda", dtype=torch.int32 + ) # Test shuffle_rows output = shuffle_rows(input_tensor, dst2src_map) @@ -113,10 +98,9 @@ def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int): # Verify that each output row matches the corresponding input row for i in range(expanded_size): src_idx = dst2src_map[i].item() - torch.testing.assert_close(output[i], - input_tensor[src_idx], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5 + ) @pytest.mark.parametrize("num_tokens", [16, 64]) @@ -132,10 +116,7 @@ def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int): torch.manual_seed(42) # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a random permutation map dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32) @@ -151,10 +132,9 @@ def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int): # Verify that each output row matches the corresponding input row for i in range(num_tokens): src_idx = dst2src_map[i].item() - torch.testing.assert_close(output[i], - input_tensor[src_idx], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5 + ) def test_shuffle_rows_edge_cases(): @@ -188,10 +168,7 @@ def test_shuffle_rows_moe_like_scenario(): topk = 2 # Simulate input tokens - input_tensor = torch.randn(batch_size, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) # Simulate expert assignment (each token goes to topk experts) # This creates a mapping where tokens are duplicated for multiple experts @@ -215,14 +192,12 @@ def test_shuffle_rows_moe_like_scenario(): for i in range(batch_size): for k in range(topk): output_idx = i * topk + k - torch.testing.assert_close(output[output_idx], - input_tensor[i], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[output_idx], input_tensor[i], atol=1e-6, rtol=1e-5 + ) -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_shuffle_rows_dtype_consistency(dtype: torch.dtype): """Test that shuffle_rows preserves dtype correctly.""" if not current_platform.is_cuda(): @@ 
-232,10 +207,7 @@ def test_shuffle_rows_dtype_consistency(dtype: torch.dtype): hidden_size = 512 # Create input tensor with specific dtype - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows @@ -257,10 +229,7 @@ def test_shuffle_rows_device_consistency(): dtype = torch.float16 # Create input tensor on CUDA - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows @@ -281,10 +250,7 @@ def test_shuffle_rows_contiguous_output(): dtype = torch.float16 # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py new file mode 100644 index 0000000000000..ccef9d7123640 --- /dev/null +++ b/tests/kernels/test_top_k_per_row.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import pytest +import torch + +from vllm.platforms import current_platform + +# Test parameters +NUM_ROWS = [1, 32, 2050] +TOP_K_VALUES = [2048] + + +def create_random_logits( + row_starts: torch.Tensor, + row_ends: torch.Tensor, + vocab_size: int, + dtype: torch.dtype, + seed: int, +) -> torch.Tensor: + """Create random logits tensor for testing.""" + torch.manual_seed(seed) + np.random.seed(seed) + # Generate logits with some structure to make testing more meaningful + logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda") + for i, end in enumerate(row_ends): + logits[i, end:] = float("-inf") + return logits + + +def create_row_boundaries( + seq_len: int, vocab_size: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Create row start and end indices for testing.""" + row_starts = torch.zeros(seq_len, dtype=torch.int32, device="cuda") + row_ends = torch.arange(1, seq_len + 1, device="cuda", dtype=torch.int32) + return row_starts, row_ends + + +def compare_top_k_results( + cuda_indices: torch.Tensor, + cuda_values: torch.Tensor, + torch_indices: torch.Tensor, + torch_values: torch.Tensor, + row_starts: torch.Tensor, + row_ends: torch.Tensor, + top_k: int, + tolerance: float = 1e-5, +) -> bool: + """ + Compare results from CUDA top_k_per_row with torch.topk. + Both results should be sorted and contain the same top-k elements. 
+ """ + num_rows = cuda_indices.shape[0] + + for row_idx in range(num_rows): + # Get valid elements using row boundaries + row_start = row_starts[row_idx].item() + row_end = row_ends[row_idx].item() + row_length = row_end - row_start + num_valid = min(top_k, row_length) + cuda_row_indices = cuda_indices[row_idx][:num_valid].cpu() + torch_row_indices = torch_indices[row_idx][:num_valid].cpu() + + # Compare the sets of indices first + cuda_set = set(cuda_row_indices.tolist()) + torch_set = set(torch_row_indices.tolist()) + if cuda_set == torch_set: + continue + + # Any difference in elements, compare the values + cuda_row_values = cuda_values[row_idx][:num_valid].cpu() + torch_row_values = torch_values[row_idx][:num_valid].cpu() + + cuda_only_values, torch_only_values = [], [] + for idx in cuda_set - torch_set: + cuda_pos = (cuda_row_indices == idx).nonzero(as_tuple=True)[0] + cuda_only_values.append(cuda_row_values[cuda_pos[0]]) + + for idx in torch_set - cuda_set: + torch_pos = (torch_row_indices == idx).nonzero(as_tuple=True)[0] + torch_only_values.append(torch_row_values[torch_pos[0]]) + + if len(cuda_only_values) != len(torch_only_values): + return False + if not torch.allclose( + torch.tensor(cuda_only_values), + torch.tensor(torch_only_values), + rtol=tolerance, + atol=tolerance, + ): + return False + + return True + + +@pytest.mark.parametrize("num_rows", NUM_ROWS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row( + num_rows: int, + top_k: int, +) -> None: + """ + Test top_k_per_row. + """ + torch.set_default_device("cuda:0") + + # Create test data + vocab_size = 20000 + row_starts, row_ends = create_row_boundaries(num_rows, vocab_size) + logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) + + # Create output tensors + indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda") + values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda") + + # Run CUDA implementation + torch.ops._C.top_k_per_row( + logits, + row_starts, + row_ends, + indices, + values, + num_rows, + logits.stride(0), + logits.stride(1), + ) + + # Run reference implementation + torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1) + mask_lo = torch_indices >= 0 + mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 + mask = mask_lo & mask_hi + torch_indices = torch_indices.masked_fill(~mask, -1) + + # Compare results + assert compare_top_k_results( + indices, values, torch_indices, torch_values, row_starts, row_ends, top_k + ), "CUDA top_k_per_row results don't match torch.topk" diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py index 1c31cfb25e5ac..4b0bbb992d2ee 100644 --- a/tests/kernels/test_triton_flash_attention.py +++ b/tests/kernels/test_triton_flash_attention.py @@ -4,21 +4,24 @@ Run `pytest tests/kernels/test_triton_flash_attention.py`. 
""" + import pytest import torch -from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, - MetaData, - compute_alibi_tensor, - scale_fp8, - triton_attention_rocm) +from vllm.attention.ops.triton_flash_attention import ( + SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm, +) from vllm.platforms import current_platform class ReferenceAttention: - - def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, - input_metadata): + def __init__( + self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ): self.Z = Z self.HQ = HQ self.HK = HK @@ -30,21 +33,23 @@ class ReferenceAttention: self.input_metadata = input_metadata def fwd(self, q, k, v): - scores = torch.einsum('bhqd,bhkd->bhqk', q, - k).float() * self.input_metadata.sm_scale + scores = ( + torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale + ) if self.input_metadata.causal: - mask = torch.tril(torch.ones(self.N_CTX_Q, - self.N_CTX_K, - device="cuda"), - diagonal=self.N_CTX_K - self.N_CTX_Q) + mask = torch.tril( + torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q, + ) scores[:, :, mask == 0] = float("-inf") if self.input_metadata.bias is not None: scores += self.input_metadata.bias if self.use_alibi: - scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, - self.N_CTX_Q, self.N_CTX_K) + scores += compute_alibi_tensor( + self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K + ) p = torch.softmax(scores, dim=-1) if self.input_metadata.causal: @@ -54,31 +59,38 @@ class ReferenceAttention: # should be out of the softmax. nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v) # compare - if self.input_metadata.layout == 'bshd': + if self.input_metadata.layout == "bshd": ref_out = ref_out.transpose(1, 2).clone() return ref_out def fwd_fp8(self, q_quantized, k_quantized, v_quantized): q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( - self.dtype) + self.dtype + ) k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( - self.dtype) + self.dtype + ) v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( - self.dtype) + self.dtype + ) result = self.fwd(q, k, v) if self.input_metadata.o_scale is not None: result, _ = scale_fp8(result, self.input_metadata.o_scale) return result def fwd_fp8_kv(self, q, k_quantized, v_quantized): - k_descale, v_descale = (self.input_metadata.k_descale, - self.input_metadata.v_descale) - k_dequantized = (k_quantized.to(torch.float32) * - k_descale.to(torch.float32)).to(self.dtype) - v_dequantized = (v_quantized.to(torch.float32) * - v_descale.to(torch.float32)).to(self.dtype) + k_descale, v_descale = ( + self.input_metadata.k_descale, + self.input_metadata.v_descale, + ) + k_dequantized = ( + k_quantized.to(torch.float32) * k_descale.to(torch.float32) + ).to(self.dtype) + v_dequantized = ( + v_quantized.to(torch.float32) * v_descale.to(torch.float32) + ).to(self.dtype) return self.fwd(q, k_dequantized, v_dequantized) def varlen_fwd(self, q, k, v, is_mqa=False): @@ -86,29 +98,33 @@ class ReferenceAttention: if is_mqa: # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so # the size aligns with Q. 
- k_ref = k.view(k.shape[0], k.shape[1], 1, - k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) - v_ref = v.view(v.shape[0], v.shape[1], 1, - v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) else: k_ref = k v_ref = v for i in range(0, self.input_metadata.num_contexts): - start_q, start_k = self.input_metadata.cu_seqlens_q[ - i], self.input_metadata.cu_seqlens_k[i] - end_q, end_k = self.input_metadata.cu_seqlens_q[ - i + 1], self.input_metadata.cu_seqlens_k[i + 1] + start_q, start_k = ( + self.input_metadata.cu_seqlens_q[i], + self.input_metadata.cu_seqlens_k[i], + ) + end_q, end_k = ( + self.input_metadata.cu_seqlens_q[i + 1], + self.input_metadata.cu_seqlens_k[i + 1], + ) k_curr = k_ref[start_k:end_k] v_curr = v_ref[start_k:end_k] if is_mqa: k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], - k_curr).float() - p = torch.softmax(scores * self.input_metadata.sm_scale, - dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr) return ref_out @@ -123,8 +139,7 @@ def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): # model. p_scale = None - o_scale = torch.rand(1, device="cuda", - requires_grad=False) if use_o_scale else None + o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale @@ -150,10 +165,10 @@ def input_helper( current_platform.seed_everything(0) # Initialize q, k, v - if layout == 'bhsd': + if layout == "bhsd": q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == 'bshd': + elif layout == "bshd": q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) @@ -161,69 +176,54 @@ def input_helper( # for n heads the set of slopes is the geometric sequence that starts # 2^(-8/n) alibi_slopes = torch.tensor( - [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + [2 ** (-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) + device="cuda", + ).repeat(Z, 1) else: alibi_slopes = None if use_bias: - bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), - dtype=dtype, - device="cuda", - requires_grad=False) + bias = torch.randn( + (1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False + ) else: bias = None - q = torch.randn(q_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - k = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - v = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) if is_fp8: - (q, k, v, q_descale, k_descale, v_descale, p_scale, - o_scale) = quantize_input(q, - k, - v, - use_o_scale=use_o_scale, - fp8_kv=fp8_kv) + (q, k, v, q_descale, k_descale, v_descale, 
p_scale, o_scale) = quantize_input( + q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv + ) else: q_descale = k_descale = v_descale = p_scale = o_scale = None - input_metadata = MetaData(sm_scale=D_HEAD**-0.5, - max_seqlens_q=N_CTX_Q, - max_seqlens_k=N_CTX_K, - layout=layout, - alibi_slopes=alibi_slopes, - alibi_batch=Z, - alibi_nheads=HQ, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=o_scale, - bias=bias, - seqlen_q=N_CTX_Q, - seqlen_k=N_CTX_K) + input_metadata = MetaData( + sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K, + ) return q, k, v, input_metadata -def varlen_input_helper(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - equal_seqlens=False): +def varlen_input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False +): current_platform.seed_everything(0) # Random sequence lengths. Using N_CTX as kind of max of sum of individual @@ -231,66 +231,72 @@ def varlen_input_helper(Z, if not equal_seqlens: max_seqlens_q = N_CTX_Q // Z max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, - max_seqlens_q + 1, (Z, ), - dtype=torch.int32) - seqlens_k = torch.randint(1, - max_seqlens_k + 1, (Z, ), - dtype=torch.int32) + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) else: - seqlens_q = torch.full((Z, ), N_CTX_Q // Z) - seqlens_k = torch.full((Z, ), N_CTX_K // Z) + seqlens_q = torch.full((Z,), N_CTX_Q // Z) + seqlens_k = torch.full((Z,), N_CTX_K // Z) # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_q.cumsum(dim=0, dtype=torch.int32) - ]) - cu_seqlens_k = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_k.cumsum(dim=0, dtype=torch.int32) - ]) + cu_seqlens_q = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32), + ] + ) + cu_seqlens_k = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32), + ] + ) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() total_k = cu_seqlens_k[-1].item() - q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() + q = ( + torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 48, 12, 1, 1, 64), - (4, 4, 4, 128, 128, 
65), - (16, 48, 48, 1, 1, 128), - (64, 48, 24, 3, 3, 128), - (4, 4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_alibi', [True, False]) -@pytest.mark.parametrize('layout', ['bshd']) -def test_op_fwd(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - use_alibi, - layout, - dtype=torch.float16): +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_alibi", [True, False]) +@pytest.mark.parametrize("layout", ["bshd"]) +def test_op_fwd( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16 +): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - dtype, layout, use_alibi, causal) + q, k, v, input_metadata = input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal + ) o = torch.empty_like(q) @@ -299,48 +305,50 @@ def test_op_fwd(Z, # Transpose here if layout is bshd so we have same reference code for all # layouts - if layout == 'bshd': + if layout == "bshd": q = q.transpose(1, 2).clone() k = k.transpose(1, 2).clone() v = v.transpose(1, 2).clone() # Replicate K and V if using MQA/GQA if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(k.shape[0], -1, k.shape[2], - k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(v.shape[0], -1, v.shape[2], - v.shape[3]) + k = ( + k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + ) + v = ( + v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + ) - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - use_alibi, dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -@pytest.mark.parametrize('use_o_scale', [True, False]) -@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), - reason="Triton FP8 requires CUDA 9.0 or higher") -def test_op_fwd_fp8(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - use_o_scale, - dtype=torch.float32): +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +@pytest.mark.parametrize("use_o_scale", [True, False]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher", +) +def test_op_fwd_fp8( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32 +): current_platform.seed_everything(0) # Disable grad to save memory it won't run into OOM 
on CI machine. @@ -358,95 +366,103 @@ def test_op_fwd_fp8(Z, causal=causal, layout=layout, is_fp8=True, - use_o_scale=use_o_scale) + use_o_scale=use_o_scale, + ) o = torch.empty_like(q_quantized) if use_o_scale else None - tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, - o, input_metadata) + tri_out, _ = triton_attention_rocm( + q_quantized, k_quantized, v_quantized, o, input_metadata + ) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) # compare - torch.testing.assert_close(ref_out.to(torch.float32), - tri_out.to(torch.float32), - atol=7e-2, - rtol=2e-1) + torch.testing.assert_close( + ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1 + ) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -def test_op_fwd_fp8_kv(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - dtype=torch.float32): +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +def test_op_fwd_fp8_kv( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32 +): current_platform.seed_everything(0) - q, k_quantized, v_quantized, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - causal=causal, - layout=layout, - is_fp8=True, - fp8_kv=True) + q, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True, + ) o = torch.empty_like(q) - tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, - input_metadata) + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_bias', [True]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_bias", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - layout='bhsd', - causal=causal, - use_bias=use_bias) + q, k, v, input_metadata = input_helper( + Z, + H, + H, + 
N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout="bhsd", + causal=causal, + use_bias=use_bias, + ) o = torch.empty_like(q) # triton implementation tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) # compare @@ -454,47 +470,47 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), - (4, 48, 512, 64), - (16, 48, 512, 64), - (64, 48, 128, 128)]) -@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize( + "Z, H, N_CTX, D_HEAD", + [(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)], +) +@pytest.mark.parametrize("causal", [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, - D_HEAD, dtype) + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, - input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), - (4, 48, 12, 256, 64), - (4, 48, 4, 512, 64), - (4, 64, 16, 128, 128)]) -@pytest.mark.parametrize('causal', [False]) -def test_op_varlen_mqa_fwd(Z, - HQ, - HK, - N_CTX, - D_HEAD, - causal, - dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, - D_HEAD, dtype) +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX, D_HEAD", + [ + (2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128), + ], +) +@pytest.mark.parametrize("causal", [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype + ) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fa4125840a010..015424d9ee0f7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,12 +15,15 @@ from torch._prims_common import TensorLikeType from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +from vllm.attention.backends.registry import _Backend from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) -from vllm.platforms.interface import _Backend -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_XFORMERS_ATTN_VAL, 
make_tensor_with_pad) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils import ( + STR_BACKEND_ENV_VAR, + STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad, +) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -39,7 +42,7 @@ ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = ( class QKVInputs(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, query/key/values and their sequence lengths. @@ -49,7 +52,7 @@ class QKVInputs(NamedTuple): num_heads x head_size) attention inputs * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -59,7 +62,7 @@ class QKVInputs(NamedTuple): class QKVO(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, alongside unpacked known-correct attention output @@ -69,14 +72,14 @@ class QKVO(NamedTuple): num_heads x head_size) attention inputs * ideal_output: unpacked (batch_size x padded_seq_len x num_heads x head_size) known-correct attention output - ''' + """ qkv: QKVInputs ideal_output: torch.Tensor class PackedQKVInputs(NamedTuple): - ''' + """ Data structure for representing packed attention inputs Attributes: @@ -88,7 +91,7 @@ class PackedQKVInputs(NamedTuple): packed tensor * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -100,7 +103,7 @@ class PackedQKVInputs(NamedTuple): class PackedQKVO(NamedTuple): - ''' + """ Data structure for representing packed attention inputs, alongside packed known-correct attention output @@ -110,28 +113,28 @@ class PackedQKVO(NamedTuple): x head_size) attention inputs * ideal_output: packed (number_of_tokens x num_heads x head_size) known-correct attention output - ''' + """ packed_qkv: Optional[PackedQKVInputs] ideal_output: torch.Tensor class KVMemoryMap(NamedTuple): - ''' + """ Data structure for encapsulating KV cache memory mapping. 
Attributes: * block_tables: KV cache block tables * slot_mapping: mapping of sequence offset to physical address - ''' + """ block_tables: torch.Tensor slot_mapping: torch.Tensor class PhaseTestParameters(NamedTuple): - ''' + """ Data structure for encapsulating the test parameters for a given test "phase" (prefill or decode phase) and attention scenario (encoder, decoder-self, encoder/decoder-cross) @@ -143,7 +146,7 @@ class PhaseTestParameters(NamedTuple): output * kv_mmap: KV cache memory mapping, specific to this test phase & attention scenario - ''' + """ packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] @@ -153,41 +156,43 @@ def maybe_make_int_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D int torch.Tensor on `device` Returns: * If _list is not None: 1D int torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.int, device=device) + ) def maybe_make_long_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D long torch.Tensor on `device` Returns: * If _list is not None: 1D long torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.long, device=device) + ) def maybe_max(_list: Optional[list]) -> Optional[Number]: - ''' + """ Returns: * If _list is not None: max(_list) * None otherwise - ''' + """ return None if _list is None else max(_list) @@ -195,7 +200,7 @@ def make_causal_mask( q_max_seq_len: int, kv_max_seq_len: int, ) -> torch.Tensor: - ''' + """ Create a q_max_seq_len x kv_max_seq_len causal mask Arguments: @@ -206,19 +211,19 @@ def make_causal_mask( Returns: * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' + """ # Create a matrix where entry (i, j) is True if i >= j mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) + mask = mask.masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) return mask -def override_backend_env_variable(mpatch: pytest.MonkeyPatch, - backend_name: str) -> None: - ''' +def override_backend_env_variable( + mpatch: pytest.MonkeyPatch, backend_name: str +) -> None: + """ Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits. @@ -227,18 +232,20 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * mpatch: pytest monkeypatch instance * backend_name: attention backend name to force - ''' + """ mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) -def ref_masked_attention(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[list] = None, - kv_seq_lens: Optional[list] = None) -> torch.Tensor: - ''' +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[list] = None, + kv_seq_lens: Optional[list] = None, +) -> torch.Tensor: + """ "Golden" masked attention reference. 
Supports two types of masking: * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out @@ -260,14 +267,14 @@ def ref_masked_attention(query: torch.Tensor, Returns: * Attention result, batch_size x q_padded_seq_len x num_heads x head_size - ''' + """ assert q_seq_lens is not None assert kv_seq_lens is not None batch_size = query.shape[0] - assert (len(q_seq_lens) == batch_size) - assert (len(kv_seq_lens) == batch_size) + assert len(q_seq_lens) == batch_size + assert len(kv_seq_lens) == batch_size attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() @@ -303,7 +310,7 @@ def make_qkv( attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple[QKVInputs, QKVInputs, QKVInputs]: - ''' + """ Construct QKV test tensors for self- and cross-attention. Generates three query/key/value triplets: @@ -340,14 +347,12 @@ def make_qkv( * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) * Prefill QKVInputs structure (containing all but the last sequence offset) * Decode QKVInputs structure (containing all only the last sequence offset) - ''' + """ if force_max_len: q_seq_lens = [max_q_seq_len for _ in range(batch_size)] else: - q_seq_lens = [ - random.randint(2, max_q_seq_len) for _ in range(batch_size) - ] + q_seq_lens = [random.randint(2, max_q_seq_len) for _ in range(batch_size)] kv_seq_lens = None if force_kv_seq_lens is not None: kv_seq_lens = force_kv_seq_lens @@ -360,50 +365,44 @@ def make_qkv( if force_max_len: kv_seq_lens = [max_kv_seq_len] * batch_size else: - kv_seq_lens = [ - random.randint(2, max_kv_seq_len) for _ in range(batch_size) - ] + kv_seq_lens = [random.randint(2, max_kv_seq_len) for _ in range(batch_size)] - query = torch.rand( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - key = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - value = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + query = torch.rand((batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + prefill_query = torch.zeros((batch_size, max_q_seq_len, num_heads, head_size)).to( + device + ) + prefill_key = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) + prefill_value = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) - decode_query = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + decode_query = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - decode_value = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, - kv_seq_lens)): + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): query[bdx, q_seq_len:, :, :] = 0 key[bdx, kv_seq_len:, :, :] = 0 value[bdx, kv_seq_len:, :, :] = 0 - prefill_query[bdx, - 0:(q_seq_len - 1), :, :] = query[bdx, - 0:(q_seq_len - 1), 
:, :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :, :] = key[bdx, - 0:(kv_seq_len - 1), :, :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + prefill_query[bdx, 0 : (q_seq_len - 1), :, :] = query[ + bdx, 0 : (q_seq_len - 1), :, : + ] + prefill_key[bdx, 0 : (kv_seq_len - 1), :, :] = key[ + bdx, 0 : (kv_seq_len - 1), :, : + ] + prefill_value[bdx, 0 : (kv_seq_len - 1), :, :] = value[ + bdx, 0 : (kv_seq_len - 1), :, : + ] - decode_query[bdx, :, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :, :] - decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] - decode_value[bdx, :, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :, :] + decode_query[bdx, :, :, :] = query[bdx, (q_seq_len - 1) : q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -417,25 +416,29 @@ def make_qkv( key, value, q_seq_lens, - kv_seq_lens), + kv_seq_lens, + ), QKVInputs( prefill_query, # Prefill subset of QKV sequences prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens), + prefill_kv_seq_lens, + ), QKVInputs( decode_query, # Decode subset of KV sequences decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens)) + decode_kv_seq_lens, + ), + ) def pack_tensor( - unpacked_tensor: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]: - ''' + unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str] +) -> tuple[torch.Tensor, list[int]]: + """ Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where number_of_tokens = sum(seq_lens) @@ -451,7 +454,7 @@ def pack_tensor( * packed_tensor: number_of_tokens x num_heads x head_size * start_loc_list: start idx of each batch elt in packed_tensor; [0] + list(itertools.accumulate(seq_lens)) - ''' + """ num_tok = sum(seq_lens) num_heads = unpacked_tensor.shape[-2] @@ -460,16 +463,15 @@ def pack_tensor( packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): - - packed_tensor[start_loc:( - start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + packed_tensor[start_loc : (start_loc + seq_len), :, :] = unpacked_tensor[ + bdx, :seq_len, :, : + ] return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, device: Union[torch.device, - str]) -> PackedQKVInputs: - ''' +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs: + """ Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x num_heads x head_size tensors. 
@@ -488,35 +490,33 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, * Packed (number_of_tokens x num_heads x head_size) QKV inputs derived from unpacked inputs - ''' + """ if qkv.query is None: packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(qkv.query, - qkv.q_seq_lens, - device=device) - packed_key, kv_start_loc_list = pack_tensor(qkv.key, - qkv.kv_seq_lens, - device=device) + packed_query, q_start_loc_list = pack_tensor( + qkv.query, qkv.q_seq_lens, device=device + ) + packed_key, kv_start_loc_list = pack_tensor(qkv.key, qkv.kv_seq_lens, device=device) packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) return PackedQKVInputs( - packed_query, packed_key, packed_value, q_start_loc_list, + packed_query, + packed_key, + packed_value, + q_start_loc_list, kv_start_loc_list, (None if q_start_loc_list is None else qkv.q_seq_lens), - qkv.kv_seq_lens) + qkv.kv_seq_lens, + ) def make_backend(backend_name: str) -> AttentionBackend: - ''' + """ Construct the backend instance determined by the backend_name string argument. - "XFORMERS" -> construct xformers backend - - TODO: other backends - Note: at time of writing the Attention wrapper automatically selects its own backend for Attention.forward(); so the backend instance which you generate with this function is not meant to be used for *running* @@ -527,17 +527,70 @@ def make_backend(backend_name: str) -> AttentionBackend: Returns: * Backend instance - ''' + """ if backend_name == STR_XFORMERS_ATTN_VAL: - # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. - from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() - elif backend_name == STR_FLASH_ATTN_VAL: - from vllm.attention.backends.flash_attn import FlashAttentionBackend - return FlashAttentionBackend() + from vllm.v1.attention.backends.xformers import XFormersAttentionBackend - raise AssertionError( - f"Unrecognized backend_name {backend_name} for unit test") + return XFormersAttentionBackend() + if backend_name == STR_FLASH_ATTN_VAL: + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + return FlashAttentionBackend() + if backend_name == "TRITON_ATTN": + from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend + + return TritonAttentionBackend() + if backend_name == "FLEX_ATTENTION": + from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend + + return FlexAttentionBackend() + if backend_name == "TORCH_SDPA": + from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend + + return TorchSDPABackend() + if backend_name == "FLASHINFER": + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + + return FlashInferBackend() + + raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test") + + +def make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: list[int], +) -> list[Any]: + """Create ALiBi biases compatible with xFormers attention tests.""" + from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias + + if alibi_slopes is None: + return [None for _ in seq_lens] + + attn_biases: list[Any] = [] + num_heads = alibi_slopes.shape[0] + assert num_heads >= num_kv_heads, ( + "ALiBi slopes expect at least as many heads as KV heads" + ) + + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + 
bias_tensor = torch.empty( + 1, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias_tensor.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor)) + + return attn_biases def _make_metadata_tensors( @@ -545,9 +598,17 @@ def _make_metadata_tensors( context_lens: Optional[list[int]], encoder_seq_lens: Optional[list[int]], device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], - torch.Tensor, torch.Tensor, Optional[int]]: - ''' +) -> tuple[ + torch.Tensor, + torch.Tensor, + Any, + Any, + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + Optional[int], +]: + """ Build scalar & tensor values required to build attention metadata structure. Arguments: @@ -567,48 +628,61 @@ def _make_metadata_tensors( * encoder_seq_lens_tensor: encoder seq_lens list, as tensor * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor - ''' + """ seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) context_lens_tensor = maybe_make_int_tensor(context_lens, device) max_context_len = maybe_max(context_lens) max_seq_len = maybe_max(seq_lens) encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) - max_encoder_seq_len = (None if encoder_seq_lens is None else - max(encoder_seq_lens)) + max_encoder_seq_len = None if encoder_seq_lens is None else max(encoder_seq_lens) seq_start_loc = None if seq_lens_tensor is not None: - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=seq_lens_tensor.device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) + seq_start_loc = torch.zeros( + seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device, + ) + torch.cumsum( + seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:] + ) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=encoder_seq_lens_tensor.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) + encoder_seq_start_loc = torch.zeros( + encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device, + ) + torch.cumsum( + encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:], + ) - return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, - max_encoder_seq_len) + return ( + seq_lens_tensor, + context_lens_tensor, + max_context_len, + max_seq_len, + seq_start_loc, + encoder_seq_lens_tensor, + encoder_seq_start_loc, + max_encoder_seq_len, + ) -def make_kv_cache(num_blocks: int, - num_heads: int, - head_size: int, - block_size: int, - device: Union[torch.device, str], - backend: str, - default_val: float = 0.0) -> torch.Tensor: - ''' +def make_kv_cache( + num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + backend: str, + default_val: float = 0.0, +) -> torch.Tensor: + """ Create a fake KV cache. 
Arguments: @@ -626,27 +700,29 @@ def make_kv_cache(num_blocks: int, * for backend 'XFORMERS' * kv_cache: 2 x num_blocks x block_size x num_heads x head_size * for backend 'FLASH_ATTN' - ''' - if backend == 'XFORMERS': - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) - elif backend == 'FLASH_ATTN': - kv_cache = torch.rand( - (2, num_blocks, block_size, num_heads, head_size)).to(device) + """ + if backend == "XFORMERS": + kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to( + device + ) + elif backend == "FLASH_ATTN": + kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to( + device + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." + ) if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: - ''' + """ Compute the minimum number of blocks required to hold num_tokens tokens, given block_size - ''' + """ return (num_tokens + block_size) // block_size @@ -658,9 +734,12 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]): - ''' +def split_slot_mapping( + slot_mapping_list: torch.Tensor, + seq_lens: list[int], + device: Union[torch.device, str], +): + """ Split a slot mapping into valid prefill- and decode-phase slot mappings. Context: @@ -698,28 +777,32 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], reflecting all N prefill prompts * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting all N decoded tokens - ''' + """ prefill_slot_mapping = [] decode_slot_mapping = [] base_idx = 0 for seq_len in seq_lens: - prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + - seq_len - 1)]) + prefill_slot_mapping.extend( + slot_mapping_list[base_idx : (base_idx + seq_len - 1)] + ) decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len - return (maybe_make_long_tensor(prefill_slot_mapping, device), - maybe_make_long_tensor(decode_slot_mapping, device)) + return ( + maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device), + ) def make_block_tables_slot_mapping( - block_size: int, - seq_lens: list[int], - device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]: - ''' + block_size: int, + seq_lens: list[int], + device: Union[torch.device, str], + block_base_addr: int = 0, +) -> tuple[torch.Tensor, list[int], int]: + """ Construct fake block tables & slot mappings. 
For a sequence with num_tokens tokens the minimum number @@ -756,12 +839,11 @@ def make_block_tables_slot_mapping( * block_tables_tensor: block table for sequence * slot_mapping_list: slot mapping for sequence * max_block_idx: the highest block address within this block table - ''' + """ # Provision minimum number of KV cache blocks num_blocks_list = [ - _num_tokens_to_min_blocks(num_tokens, block_size) - for num_tokens in seq_lens + _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -774,11 +856,11 @@ def make_block_tables_slot_mapping( max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] - block_table = list( - range(block_base_idx, block_base_idx - num_blocks, -1)) + block_table = list(range(block_base_idx, block_base_idx - num_blocks, -1)) for idx in range(num_tokens): - mapping_value = ( - idx % block_size) + block_table[idx // block_size] * block_size + mapping_value = (idx % block_size) + block_table[ + idx // block_size + ] * block_size slot_mapping_list.append(mapping_value) block_base_idx -= num_blocks @@ -802,9 +884,9 @@ def make_test_metadata( decoder_test_params: Optional[PhaseTestParameters], device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, - cross_test_params: Optional[PhaseTestParameters] = None + cross_test_params: Optional[PhaseTestParameters] = None, ) -> AttentionMetadata: - ''' + """ Construct fake attention metadata for a given test phase (prefill-phase or decode-phase). @@ -841,13 +923,12 @@ def make_test_metadata( Return: * AttentionMetadata structure - ''' + """ # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None - kv_mmap = (None - if decoder_test_params is None else decoder_test_params.kv_mmap) + kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap # This function constructs metadata assuming no chunked prefill, # i.e. 
100% prefill tokens or 100% decode tokens @@ -860,10 +941,11 @@ def make_test_metadata( # seq_lens is None signals encoder-only # scenario, in which case num_prefills_or_decodes and # num_prefill_or_decode_tokens are unused - num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) + num_prefills_or_decodes = None if seq_lens is None else len(seq_lens) - num_prefill_or_decode_tokens = (None if seq_lens is None else ( - sum(seq_lens) if is_prompt else len(seq_lens))) + num_prefill_or_decode_tokens = ( + None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens)) + ) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -877,16 +959,13 @@ def make_test_metadata( # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - num_encoder_tokens = (None if encoder_seq_lens is None else - (sum(encoder_seq_lens))) + num_encoder_tokens = ( + None if encoder_seq_lens is None else (sum(encoder_seq_lens)) + ) - if cross_test_params is None: - cross_kv_mmap = None - else: - # Encoder/decoder or encoder-only models only: - # * Extract *cross-attention* slot_mapping and block table - # (kv_mmap) - cross_kv_mmap = cross_test_params.kv_mmap + # For encoder/decoder or encoder-only models only, extract *cross-attention* + # slot_mapping and block table (kv_mmap) + cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap attn_backend_obj = make_backend(attn_backend.name) @@ -906,14 +985,12 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -930,10 +1007,13 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) else: # not is_prompt # Decode-phase scenario @@ -955,15 +1035,13 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -981,16 +1059,19 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - 
cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) -def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor, - backend: str) -> None: - ''' +def assert_actual_matches_ideal( + test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str +) -> None: + """ Assert that observed output matches the ideal output contained in the test parameters data structure. @@ -998,24 +1079,24 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * test_params: Test parameters including packed ideal output * output_under_test: actually observed output value - ''' + """ ideal_output = test_params.packed_qkvo.ideal_output - if backend == 'XFORMERS': - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == "XFORMERS": + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output) + ) - elif backend == 'FLASH_ATTN': + elif backend == "FLASH_ATTN": # For FlashAttention override the accuracy thresholds to non default # values since we notice a higher difference between the ideal and # actual output. - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output), - atol=0.01, - rtol=0.016) + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016 + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." 
+ ) # Copied/modified from torch._refs.__init__.py @@ -1029,19 +1110,15 @@ def fp8_allclose( """ Reference implementation of torch.allclose """ - torch._refs._check_close_args(name="torch.allclose", - a=a, - b=b, - rtol=rtol, - atol=atol) + torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) return bool( torch.all( - torch.isclose(a.double(), - b.double(), - rtol=rtol, - atol=atol, - equal_nan=equal_nan)).item()) + torch.isclose( + a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan + ) + ).item() + ) # Marlin MoE test utils @@ -1054,7 +1131,8 @@ def stack_and_dev(tensors: list[torch.Tensor]): def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def torch_experts( @@ -1076,10 +1154,11 @@ def torch_experts( block_shape: Optional[list[int]] = None, apply_router_weights_on_input: bool = False, ) -> torch.Tensor: - assert (global_num_experts == -1 - or (global_num_experts == w1.shape[0] and expert_map is None) - or (expert_map is not None - and global_num_experts == expert_map.shape[0])) + assert ( + global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None and global_num_experts == expert_map.shape[0]) + ) M, K = a.shape topk = topk_ids.shape[1] @@ -1094,8 +1173,9 @@ def torch_experts( if a1_scale: assert not per_act_token_quant and block_shape is None - a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype, - per_act_token_quant, block_shape) + a, a_scale = moe_kernel_quantize_input( + a, a1_scale, quant_dtype, per_act_token_quant, block_shape + ) num_experts = w1.shape[0] @@ -1115,31 +1195,35 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - tmp1.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) elif block_shape is not None: # block quantized - assert (a_scale is not None and w1_scale is not None - and w2_scale is not None) - tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], - w1_scale[i], block_shape, - out.dtype) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) + tmp1 = native_w8a8_block_matmul( + a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype + ) if b_bias1 is not None: tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape + ) - out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, - w2_scale[i], block_shape, - out.dtype) + out[mask] = native_w8a8_block_matmul( + tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype + ) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - tmp1.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) else: - assert (a_scale is not None and w1_scale is not None - and w2_scale is not None) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) scales = a_scale if a_scale.numel() == 1 else a_scale[mask] tmp1 = a[mask].to(f32) * scales @@ -1151,37 +1235,50 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1).to(out.dtype) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, 
quant_dtype, per_act_token_quant, block_shape + ) assert b_scale is not None tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - out.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(out.dtype) if apply_router_weights_on_input: return out else: - return (out.view(M, -1, w2.shape[1]).to(f32) * - topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) + return ( + (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1)) + .sum(dim=1) + .to(out.dtype) + ) -def torch_moe(a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - b_bias1: Optional[torch.Tensor] = None, - b_bias2: Optional[torch.Tensor] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) - return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - b_bias1, b_bias2, expert_map) + return torch_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts, + b_bias1, + b_bias2, + expert_map, + ) def torch_moe_single(a, w, score, topk): @@ -1200,43 +1297,51 @@ def torch_moe_single(a, w, score, topk): # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. 
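# A minimal usage sketch for the fp8-aware comparison defined above
# (illustrative only: the helper name and tensors below are placeholders, and
# any op handle passed to opcheck() would have to be a real registered op).
def _fp8_allclose_example() -> bool:
    a = torch.randn(16, 16).to(torch.float8_e4m3fn)
    b = a.clone()
    # fp8 inputs are upcast to float64 inside fp8_allclose before torch.isclose,
    # which is why opcheck() below patches torch.allclose with it when checking
    # custom ops that produce float8 outputs.
    return fp8_allclose(a, b)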
-def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, - torch._library.custom_ops.CustomOpDef], - args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, - *, - test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, - raise_exception: bool = True, - cond: bool = True) -> dict[str, str]: - with unittest.mock.patch('torch.allclose', new=fp8_allclose): - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} +def opcheck( + op: Union[ + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + torch._library.custom_ops.CustomOpDef, + ], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, + raise_exception: bool = True, + cond: bool = True, +) -> dict[str, str]: + with unittest.mock.patch("torch.allclose", new=fp8_allclose): + return ( + torch.library.opcheck( + op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception + ) + if cond + else {} + ) # For testing quantized linear kernels def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor): return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) -def baseline_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - +def baseline_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting - # in numpy simply stretches dimensions with an extent of 1 to match the + # in numpy simply stretches dimensions with an extent of 1 to match # the target shape by repeating the data along that dimension (broadcasting) # , we extend these semantics to say if the extent of a dimension in the # source shape is not 1 and does not match the target shape we repeat each @@ -1247,22 +1352,25 @@ def baseline_scaled_mm(a: torch.Tensor, # then we would expand a to: # a = [[1, 1, 2, 2], # [3, 3, 4, 4]] - # NOTE this function this function does not explicitly broadcast dimensions + # NOTE this function does not explicitly broadcast dimensions # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: assert s % t.shape[i] == 0 - t = t.unsqueeze(i + 1)\ - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ - .flatten(i, i + 1) + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) + .flatten(i, i + 1) + ) return t scale_a = group_broadcast(scale_a, a.shape) scale_b = group_broadcast(scale_b, b.shape) - output = torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))).to(out_dtype) + output = torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ).to(out_dtype) if bias is not None: output = output + bias diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 352ab63552de7..a61ccef700624 100644 --- 
a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -8,8 +8,7 @@ import torch from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe # TODO: the test depends on a lot of fields in the current implementation. @@ -17,7 +16,6 @@ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe def test_run(my_rank, buffer, device): - # buffer should be empty in the beginning if my_rank == 0: assert buffer.buffer_size == 0 @@ -27,7 +25,7 @@ def test_run(my_rank, buffer, device): # insert tokens = torch.tensor([1, 2, 3]).to(device) - roi = (tokens > 0) + roi = tokens > 0 if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(device) value = 3.0 * torch.ones([5, 6]).to(device) @@ -55,7 +53,6 @@ def test_run(my_rank, buffer, device): def stress_test(my_rank, buf, device): - torch.distributed.barrier() torch.manual_seed(100) @@ -66,7 +63,8 @@ def stress_test(my_rank, buf, device): torch.rand(100).to(device), # key torch.rand(100).to(device), # value torch.rand(100).to(device), # hidden - ) for i in tqdm(range(200)) + ) + for i in tqdm(range(200)) ] random.seed(my_rank) @@ -115,12 +113,11 @@ def stress_test(my_rank, buf, device): if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) @@ -128,8 +125,8 @@ if __name__ == "__main__": print(f"initialized! 
My rank is {my_rank}") config = KVTransferConfig( - kv_connector='PyNcclConnector', - kv_buffer_device='cuda', + kv_connector="P2pNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test @@ -160,4 +157,4 @@ if __name__ == "__main__": buffer.close() data_pipe.close() cpu_pipe.close() - print('Done') + print("Done") diff --git a/tests/kv_transfer/test_module.py b/tests/kv_transfer/test_module.py index 7a04174870daf..b9a28e4bceb7c 100644 --- a/tests/kv_transfer/test_module.py +++ b/tests/kv_transfer/test_module.py @@ -9,21 +9,19 @@ import torch def run_python_script(script_name, timeout): - script_name = f'kv_transfer/{script_name}' + script_name = f"kv_transfer/{script_name}" try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "0"}, # Set the RANK environment variable for process 0 + env={"RANK": "0"}, # Set the RANK environment variable for process 0 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) process1 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "1"}, # Set the RANK environment variable for process 1 + env={"RANK": "1"}, # Set the RANK environment variable for process 1 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) @@ -34,11 +32,9 @@ def run_python_script(script_name, timeout): # Check the return status of both processes if process0.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=0, {process0.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}") if process1.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=1, {process1.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}") except subprocess.TimeoutExpired: # If either process times out, terminate both and fail the test @@ -53,15 +49,14 @@ def run_python_script(script_name, timeout): @pytest.mark.parametrize( "script_name,timeout", [ - ("test_lookup_buffer.py", - 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120) # First test case with a 120-second timeout - ]) + ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120), # First test case with a 120-second timeout + ], +) def test_run_python_script(script_name, timeout): # Check the number of GPUs if torch.cuda.device_count() < 2: - pytest.skip( - f"Skipping test {script_name} because <2 GPUs are available") + pytest.skip(f"Skipping test {script_name} because <2 GPUs are available") # Run the test if there are at least 2 GPUs run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 32116608a2177..5762224eff76d 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -15,7 +15,7 @@ def test_run(my_rank, pipe): print(f"rank {my_rank} test_run starts....") # test run x = torch.tensor([1]).to(pipe.device) - y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device) if my_rank == 0: pipe.send_tensor(x) print(f"rank {my_rank} sent tensor x") @@ -53,9 +53,8 @@ def stress_test(my_rank, pipe): for i in tqdm(range(500)): mean = torch.rand(1).item() * 100 std = torch.rand(1).item() * 100 - size = torch.randint(900, 1000, (2, )) - x = torch.normal(mean * 
1.0, std * 1.0, - size=size.tolist()).to(pipe.device) + size = torch.randint(900, 1000, (2,)) + x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) # 5% probability of sending a None if torch.rand(1).item() < 0.05: @@ -96,20 +95,16 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() for i in tqdm(range(500)): - tensors = [] if my_rank == 0: # create tensor - tensors = [ - torch.rand(nelement).to(pipe.device) for _ in range(ntensor) - ] + tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] torch.distributed.barrier() if my_rank == 0: - t = torch.tensor([time.time()], - dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -121,24 +116,23 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() - print('Latency test passed.') - print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + print("Latency test passed.") + print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms") if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) config = KVTransferConfig( - kv_connector='PyNcclConnector', - kv_buffer_device='cuda', + kv_connector="P2pNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index cba573b63c045..f805a74a4dba8 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -10,14 +10,17 @@ import torch import torch.nn as nn from huggingface_hub import snapshot_download -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -48,11 +51,13 @@ def dist_init(): if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" - init_distributed_environment(world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend=backend) + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend, + ) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -67,10 +72,9 @@ def dist_init_torch_only(): backend = "gloo" temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group(world_size=1, - rank=0, - init_method=f"file://{temp_file}", - backend=backend) + torch.distributed.init_process_group( + world_size=1, rank=0, init_method=f"file://{temp_file}", backend=backend + 
) class DummyLoRAModel(nn.Sequential, SupportsLoRA): @@ -80,25 +84,30 @@ class DummyLoRAModel(nn.Sequential, SupportsLoRA): @pytest.fixture def dummy_model() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("output", ColumnParallelLinear(50, 10)), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + ] + ) + ) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} model.unpadded_vocab_size = 32000 @@ -108,25 +117,30 @@ def dummy_model() -> nn.Module: @pytest.fixture def dummy_model_gate_up() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + ] + ) + ) model.config = MagicMock() model.packed_modules_mapping = { "gate_up_proj": [ @@ -216,11 +230,6 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") -@pytest.fixture(scope="session") -def phi2_lora_files(): - return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 44755c603f281..2f28253bce536 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -7,7 +7,8 @@ import pytest from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -26,14 +27,10 @@ def get_lora_requests(lora_path) -> list[LoRARequest]: return lora_requests -async def requests_processing_time(llm, - lora_requests: 
list[LoRARequest]) -> float: - - sampling_params = SamplingParams(n=1, - temperature=0.0, - top_p=1.0, - ignore_eos=True, - max_tokens=1) +async def requests_processing_time(llm, lora_requests: list[LoRARequest]) -> float: + sampling_params = SamplingParams( + n=1, temperature=0.0, top_p=1.0, ignore_eos=True, max_tokens=1 + ) generators = [] start = time.perf_counter() @@ -41,11 +38,11 @@ async def requests_processing_time(llm, for lora_request in lora_requests: lora_int_id = lora_request.lora_int_id generator = llm.generate( - prompt=TextPrompt(prompt=f"hello {lora_int_id}", - multi_modal_data=None), # type: ignore + prompt=TextPrompt(prompt=f"hello {lora_int_id}", multi_modal_data=None), # type: ignore sampling_params=sampling_params, lora_request=lora_request, - request_id=f"test{lora_int_id}") + request_id=f"test{lora_int_id}", + ) generators.append(generator) all_gens = merge_async_iterators(*generators) @@ -58,13 +55,13 @@ async def requests_processing_time(llm, @pytest.mark.asyncio async def test_add_lora(chatglm3_lora_files): - """ - The add_lora function is used to pre-load some LoRA adapters into the + """ + The add_lora function is used to preload some LoRA adapters into the engine in anticipation of future requests using these adapters. To test this functionality, we use the async engine to process some requests - We - do it twice, once with add_lora() pre-loading and once without. + do it twice, once with add_lora() preloading and once without. - We measure the request processing time in both cases and expect the time + We measure the request processing time in both cases and expect the time to be lesser in the case with add_lora() calls. """ lora_requests: list[LoRARequest] = get_lora_requests(chatglm3_lora_files) @@ -78,18 +75,18 @@ async def test_add_lora(chatglm3_lora_files): max_loras=max_loras, max_lora_rank=LORA_RANK, max_model_len=128, - gpu_memory_utilization=0.8, #avoid OOM + gpu_memory_utilization=0.8, # avoid OOM trust_remote_code=True, - enforce_eager=True) + enforce_eager=True, + ) # split lora_requests into 3 parts part_size = len(lora_requests) // 3 dummy_run_requests = lora_requests[:part_size] - warmup_run_requests = lora_requests[part_size:part_size * 2] - cold_run_requests = lora_requests[part_size * 2:] + warmup_run_requests = lora_requests[part_size : part_size * 2] + cold_run_requests = lora_requests[part_size * 2 :] async with build_async_engine_client_from_engine_args(engine_args) as llm: - # Dummy run - So any 1-time functionality like triton kernel compilation # is complete here. await requests_processing_time(llm, dummy_run_requests) @@ -101,18 +98,16 @@ async def test_add_lora(chatglm3_lora_files): # Test that all all_lora calls are successful. 
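        # add_lora_results holds one truthy value per warmup adapter; requiring
        # every call to succeed ensures the "hot-start" timing below really
        # measures preloaded adapters rather than a silent fall-back to
        # on-demand loading.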
assert all(add_lora_results) - time_with_add_lora = await requests_processing_time( - llm, warmup_run_requests) + time_with_add_lora = await requests_processing_time(llm, warmup_run_requests) # Run without any warmup - time_cold_start = await requests_processing_time( - llm, cold_run_requests) + time_cold_start = await requests_processing_time(llm, cold_run_requests) - print(f"time hot-start {time_with_add_lora} vs " - f"time cold-start {time_cold_start} ") + print(f"time hot-start {time_with_add_lora} vs time cold-start {time_cold_start} ") assert time_with_add_lora < time_cold_start, ( f"time_with_add_lora={time_with_add_lora}, " f"time_cold_start={time_cold_start}" "The engine request processing time with LoRA pre-loading " - "must be less than the version that does on-demand LoRA loading.") + "must be less than the version that does on-demand LoRA loading." + ) diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py deleted file mode 100644 index 774ebb9db2106..0000000000000 --- a/tests/lora/test_baichuan.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "baichuan-inc/Baichuan-7B" - -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format(query="How many singers do we have?"), - PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 - ), - ] - print(prompts) - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. 
- generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) - - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age ASC", - ] - - output1 = do_sample(llm, baichuan_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] - output2 = do_sample(llm, baichuan_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] - - -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): - if num_gpus_available < 4: - pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index fb00e7b65b04a..d8058c5f87a81 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -12,7 +12,7 @@ PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example EXPECTED_LORA_OUTPUT = [ "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "SELECT name , country , age FROM singer ORDER BY age", ] @@ -21,20 +21,24 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query=( + "What is the average, minimum, and maximum " + "age of all singers from France?" + ) ), PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + query=( + "Show name, country, age for all singers ordered " + "by age from the oldest to the youngest." 
+ ) ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -47,13 +51,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @create_new_process_for_each_test() def test_chatglm3_lora(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -66,15 +72,17 @@ def test_chatglm3_lora(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -87,15 +95,21 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True) + # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for + # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use + # more GPU memory causing vLLM to OOM + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + gpu_memory_utilization=0.85, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py index f615ceda76b56..1a5b9ba3641d3 100644 --- a/tests/lora/test_default_mm_loras.py +++ b/tests/lora/test_default_mm_loras.py @@ -32,15 +32,12 @@ VLLM_RUNNER_BASE_KWARGS = { "max_lora_rank": 320, "max_model_len": 12800, "gpu_memory_utilization": 0.8, - "limit_mm_per_prompt": { - "audio": 1 - }, + "limit_mm_per_prompt": {"audio": 1}, "enforce_eager": True, } -def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, - **kwargs): +def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, **kwargs): inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])] # Apply any additional kwargs as overrides to the base kwargs @@ -53,11 +50,11 @@ def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, 
max_tokens=128, audios=audios, lora_request=lora_request, - ) for prompts, audios in inputs + ) + for prompts, audios in inputs ] - assert vllm_outputs_with_default_lora[-1][-1][-1].endswith( - expected_suffix) + assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(expected_suffix) def test_active_default_mm_lora( diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc28..695e06e7c1d63 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -11,33 +11,39 @@ import pytest import torch import torch.nn.functional as F -from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( +from vllm.config.lora import LoRAConfig +from vllm.lora.layers import ( + BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + LogitsProcessorWithLoRA, + LoRAMapping, + MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LogitsProcessorWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) -# yapf: enable + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, +) from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) + ParallelLMHead, + VocabParallelEmbedding, + get_masked_input_and_mask, +) from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -51,25 +57,28 @@ TOLERANCES = { pytestmark = pytest.mark.skipif( not (current_platform.is_cuda_alike() or current_platform.is_cpu()), - reason="Backend not supported") + reason="Backend not supported", +) -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) # prefill stage(True) or decode stage(False) STAGES = [True, False] -NUM_RANDOM_SEEDS = 6 +NUM_RANDOM_SEEDS = 2 -VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 +VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2 @pytest.fixture(autouse=True) def clean_cache_reset_device(reset_default_device): # Release any memory we might be holding on to. CI runs OOMs otherwise. 
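    # _LORA_A_PTR_DICT / _LORA_B_PTR_DICT are module-level caches used by the
    # triton LoRA ops; clearing them drops any cached references so the memory
    # they pin can actually be reclaimed between parametrized tests.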
- from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT + _LORA_B_PTR_DICT.clear() _LORA_A_PTR_DICT.clear() @@ -79,13 +88,14 @@ def clean_cache_reset_device(reset_default_device): @pytest.fixture(autouse=True) def skip_cuda_with_stage_false(request): """ - On cuda-like platforms, we use the same kernels for prefill and decode + On cuda-like platforms, we use the same kernels for prefill and decode stage, and 'stage' is generally ignored, so we only need to test once. """ if current_platform.is_cuda_alike(): try: if hasattr(request.node, "callspec") and hasattr( - request.node.callspec, "params"): + request.node.callspec, "params" + ): params = request.node.callspec.params if "stage" in params and params["stage"] is False: pytest.skip("Skip test when stage=False") @@ -94,9 +104,9 @@ def skip_cuda_with_stage_false(request): yield -def get_random_id_to_index(num_loras: int, - num_slots: int, - log: bool = True) -> list[Optional[int]]: +def get_random_id_to_index( + num_loras: int, num_slots: int, log: bool = True +) -> list[Optional[int]]: """Creates a random lora_id_to_index mapping. Args: @@ -109,7 +119,8 @@ def get_random_id_to_index(num_loras: int, if num_loras > num_slots: raise ValueError( f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " - "num_loras must be less than or equal to num_slots.") + "num_loras must be less than or equal to num_slots." + ) slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() @@ -158,19 +169,18 @@ def populate_loras( subloras: list[LoRALayerWeights] = [] sublora_len = layer_weights.shape[0] // repeats for i in range(repeats): - sublora = DummyLoRAManager( - layer_weights.device).init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] + sublora = DummyLoRAManager(layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[ + (sublora_len * i) : (sublora_len * (i + 1)), : + ] sublora.optimize() subloras.append(sublora) - lora = PackedLoRALayerWeights.pack( - subloras) if repeats > 1 else subloras[0] + lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0] layer.set_lora( slot_idx, @@ -191,7 +201,7 @@ def create_random_inputs( input_size: tuple[int, ...], input_range: tuple[float, float], input_type: torch.dtype = torch.int, - device: torch.device = "cuda" + device: torch.device = "cuda", ) -> tuple[list[torch.Tensor], list[int], list[int]]: """Creates random inputs. 
@@ -213,14 +223,15 @@ def create_random_inputs( for _ in range(num_inputs): if input_type == torch.int: inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device=device)) + torch.randint( + low=int(low), high=int(high), size=input_size, device=device + ) + ) else: inputs.append( - torch.rand(size=input_size, dtype=input_type, device=device) * - high + low) + torch.rand(size=input_size, dtype=input_type, device=device) * high + + low + ) lora_id = random.choice(active_lora_ids) index_mapping += [lora_id] * input_size[0] @@ -243,7 +254,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -258,9 +269,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -286,15 +297,18 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(inputs)) @@ -304,17 +318,14 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: result = embedding(input_) after_a = F.embedding( input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += after_a @ lora.lora_b.T expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -324,36 +335,36 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, 
+ ) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size, stage) -> None: - +def test_embeddings_with_new_embeddings( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -361,9 +372,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -373,12 +384,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, expanded_embedding = VocabParallelEmbedding( vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=vocab_size) + org_num_embeddings=vocab_size, + ) expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place - lora_embedding = VocabParallelEmbeddingWithLoRA( - deepcopy(expanded_embedding)) + lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding)) lora_embedding.create_lora_weights(max_loras, lora_config) return expanded_embedding, lora_embedding @@ -392,7 +403,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size) + ), generate_embeddings_tensor=256, ) @@ -410,52 +422,53 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range # to guarantee that their behavior is tested. 
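        # Concretely, lora slot k (embedding_id = lora_id - 1) owns the id range
        # [vocab_size + k * embeddings_tensor_len,
        #  vocab_size + (k + 1) * embeddings_tensor_len); e.g. with vocab_size=512
        # and embeddings_tensor_len=256, slot 0 covers ids 512..767 and slot 1
        # covers 768..1023, which is what input_[-1] / input_[-2] are pinned to
        # below.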
- for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): embedding_id = lora_id - 1 input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) original_input_[-1] = vocab_size - input_[-2] = vocab_size + ( - (embedding_id + 1) * embeddings_tensor_len - 1) + input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - expanded_embedding.weight[vocab_size:vocab_size + - (embeddings_tensor_len * - max_loras)] = torch.cat(embeddings_tensors) + expanded_embedding.weight[ + vocab_size : vocab_size + (embeddings_tensor_len * max_loras) + ] = torch.cat(embeddings_tensors) lora_result = lora_embedding(torch.cat(original_inputs)) expected_results: list[torch.Tensor] = [] - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): lora = lora_dict[lora_id] result = expanded_embedding(input_) after_a = F.embedding( original_input_, - lora.lora_a, + lora.lora_a.T, ) - result += (after_a @ lora.lora_b) + result += after_a @ lora.lora_b.T expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -465,34 +478,34 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) + device=device, + ) original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) -def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, - stage) -> None: - +def test_lm_head_logits_processor( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -500,22 +513,25 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - 
lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, - params_dtype=torch.float16) + linear = ParallelLMHead( + vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16, + ) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size + ) lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, linear.weight.device, - None) + logits_processor, 1024, linear.weight.dtype, linear.weight.device, None + ) lora_logits_processor.create_lora_weights(max_loras, lora_config) return linear, logits_processor, lora_logits_processor @@ -542,10 +558,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -556,26 +571,25 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=linear, - embedding_bias=None) + hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None + ) original_lm_head = deepcopy(linear) - linear.weight[logits_processor. 
- org_vocab_size:logits_processor.org_vocab_size + - embeddings_tensor_len] = embeddings_tensor + linear.weight[ + logits_processor.org_vocab_size : logits_processor.org_vocab_size + + embeddings_tensor_len + ] = embeddings_tensor - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) + logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = logits_processor._get_logits(hidden_states=input_, - lm_head=linear, - embedding_bias=None) - result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result = logits_processor._get_logits( + hidden_states=input_, lm_head=linear, embedding_bias=None + ) + result[:, vocab_size + embeddings_tensor_len :] = float("-inf") + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) logits_processor.org_vocab_size = vocab_size @@ -591,10 +605,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -606,27 +619,28 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None)[:, :vocab_size] + embedding_bias=None, + )[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None) + embedding_bias=None, + ) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: - +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,27 +648,24 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): - - linear = ReplicatedLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) - assert 
(lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -676,10 +687,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -694,15 +704,12 @@ def test_linear_replicated(dist_init, num_loras, device, stage, for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -715,34 +722,30 @@ def test_linear_replicated(dist_init, num_loras, device, stage, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: - +def test_linear_parallel( + dist_init, num_loras, orientation, fully_shard, device, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,37 +753,42 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + 
fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": - linear = RowParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = RowParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard - else RowParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + RowParallelLinearWithLoRA(linear) + if not fully_shard + else RowParallelLinearWithShardedLoRA(linear) + ) else: - linear = ColumnParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ColumnParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (ColumnParallelLinearWithLoRA(linear) - if not fully_shard else - ColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + ColumnParallelLinearWithLoRA(linear) + if not fully_shard + else ColumnParallelLinearWithShardedLoRA(linear) + ) lora_linear.create_lora_weights(max_loras, lora_config) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -802,10 +810,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -820,15 +827,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -841,34 +845,30 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, 
expected_result, rtol=rtol, atol=atol) @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: - +def test_column_parallel_packed( + dist_init, num_loras, repeats, fully_shard, device, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,41 +876,44 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: - linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False, - params_dtype=torch.float16) + linear = MergedColumnParallelLinear( + 4096, [4096] * repeats, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedColumnParallelLinearWithLoRA(linear) - if not fully_shard else - MergedColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard + else MergedColumnParallelLinearWithShardedLoRA(linear) + ) elif repeats == 3: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLoRA(linear) - if not fully_shard else - MergedQKVParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedQKVParallelLinearWithLoRA(linear) + if not fully_shard + else MergedQKVParallelLinearWithShardedLoRA(linear) + ) else: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLoRA( - linear - ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear) + lora_linear = ( + QKVParallelLinearWithLoRA(linear) + if not fully_shard + else QKVParallelLinearWithShardedLoRA(linear) + ) @dataclass class FakeConfig: @@ -919,15 +922,16 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, num_attention_heads = 32 n_slices = repeats - lora_linear.create_lora_weights(max_loras, - lora_config, - model_config=FakeConfig()) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + lora_linear.create_lora_weights( + max_loras, lora_config, model_config=FakeConfig() + ) + assert ( + lora_linear.n_slices + == 
len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == n_slices + ) + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -951,10 +955,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, @@ -971,17 +974,14 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) + result[ + :, sublora.lora_b.shape[0] * i : sublora.lora_b.shape[0] * (i + 1) + ] += input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) for slot_idx in range(max_loras): lora_linear.reset_lora(slot_idx) @@ -992,10 +992,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, @@ -1009,15 +1008,13 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize( - "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))) + "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)) +) def test_vocab_parallel_embedding_indices(tp_size, seed): random.seed(seed) vocab_size = random.randint(4000, 64000) @@ -1035,20 +1032,24 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): token_ids: list[int] = [] for tp_rank in range(tp_size): - with patch( + with ( + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank - ), patch( + return_value=tp_rank, + ), + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size): + return_value=tp_size, + ), + ): vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size, 1, org_num_embeddings=org_vocab_size + ) vocab_size_padded = vocab_embedding.num_embeddings_padded shard_indices = vocab_embedding.shard_indices # Assert that the ranges are contiguous assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert (shard_indices.added_vocab_start_index == - last_added_vocab_end_index) + assert shard_indices.added_vocab_start_index == last_added_vocab_end_index # Ensure that we are not exceeding the 
vocab size computed_vocab_size += shard_indices.num_elements_padded @@ -1057,22 +1058,39 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): # Ensure that the ranges are not overlapping all_org_tokens.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) all_added_tokens.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) token_ids.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_org_elements_padded - - shard_indices.num_org_elements)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) token_ids.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_added_elements_padded - - shard_indices.num_added_elements)) + [-1] + * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements) + ) + token_ids.extend( + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) + token_ids.extend( + [-1] + * ( + shard_indices.num_added_elements_padded + - shard_indices.num_added_elements + ) + ) last_org_vocab_end_index = shard_indices.org_vocab_end_index last_added_vocab_end_index = shard_indices.added_vocab_end_index @@ -1100,130 +1118,165 @@ def test_get_masked_input_and_mask(): x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) # base tp 1 case, no padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=0) + modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0, + ) assert torch.equal(x, modified_x) # tp 2 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]) + ) # tp 4 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = 
get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0, + ) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=0) + num_org_vocab_padding=0, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]) + ) # base tp 1 case, with padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) + modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]) + ) # tp 2 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]) + ) # tp 4 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - 
num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=2, + ) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=2) + num_org_vocab_padding=2, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]) + ) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 06196cc697cec..0d9431bd7aaea 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -13,41 +13,34 @@ from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "meta-llama/Llama-2-7b-hf" -EXPECTED_NO_LORA_OUTPUT = [ - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 - "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? 
[/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 -] EXPECTED_LORA_OUTPUT = [ " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 - " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501 ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - tensorizer_config_dict: Union[dict, None] = None) -> list[str]: +def do_sample( + llm: vllm.LLM, + lora_path: str, + lora_id: int, + tensorizer_config_dict: Union[dict, None] = None, +) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question 
based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501 ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=256, - skip_special_tokens=False, - stop=["[/assistant]"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"] + ) if tensorizer_config_dict is not None: outputs = llm.generate( @@ -57,14 +50,19 @@ def do_sample(llm: vllm.LLM, str(lora_id), lora_id, lora_path, - tensorizer_config_dict=tensorizer_config_dict) - if lora_id else None) + tensorizer_config_dict=tensorizer_config_dict, + ) + if lora_id + else None, + ) else: outputs = llm.generate( prompts, sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + if lora_id + else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -75,54 +73,54 @@ def do_sample(llm: vllm.LLM, return generated_texts -def generate_and_test(llm, - sql_lora_files, - tensorizer_config_dict: Union[dict, None] = None): +def generate_and_test( + llm, sql_lora_files, tensorizer_config_dict: Union[dict, None] = None +): print("lora adapter created") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 1") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT - - print("no lora") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=1, + ) + == EXPECTED_LORA_OUTPUT + ) print("lora 2") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=2) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=2, + ) + == EXPECTED_LORA_OUTPUT + ) print("removing lora") @create_new_process_for_each_test() def test_llama_lora(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, - max_loras=4) + max_loras=4, + ) generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, max_num_seqs=16, max_loras=4, @@ -134,9 +132,9 @@ def test_llama_lora_tp4(sql_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, max_num_seqs=16, max_loras=4, @@ -148,9 +146,9 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @multi_gpu_test(num_gpus=2) @create_new_process_for_each_test() -def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, - sql_lora_huggingface_id): - +def test_tp2_serialize_and_deserialize_lora( + tmp_path, sql_lora_files, sql_lora_huggingface_id +): # Run the tensorizing of the LoRA adapter and the model in a subprocess # to guarantee cleanup @@ -161,17 +159,28 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, lora_path = sql_lora_huggingface_id suffix = "test" try: - result = subprocess.run([ - sys.executable, - f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", - MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", - str(tp_size), "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix, "--serialization-kwargs", - '{"limit_cpu_concurrency": 4}' - ], - check=True, - capture_output=True, - text=True) + result = subprocess.run( + [ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", + "--model", + MODEL_PATH, + "--lora-path", + lora_path, + "--tensor-parallel-size", + str(tp_size), + "serialize", + "--serialized-directory", + str(tmp_path), + "--suffix", + suffix, + "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}', + ], + check=True, + capture_output=True, + text=True, + ) except subprocess.CalledProcessError as e: print("Tensorizing failed.") print("STDOUT:\n", e.stdout) @@ -183,25 +192,25 @@ def 
test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, model_uri = tmp_path / "vllm" / model_ref / suffix / model_name tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) - loaded_llm = LLM(model=model_ref, - load_format="tensorizer", - enable_lora=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config, - max_num_seqs=13, - tensor_parallel_size=2, - max_loras=2) + loaded_llm = LLM( + model=model_ref, + tokenizer=sql_lora_files, + load_format="tensorizer", + enable_lora=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + max_num_seqs=13, + tensor_parallel_size=2, + max_loras=2, + ) tc_as_dict = tensorizer_config.to_serializable() print("lora adapter created") - assert do_sample(loaded_llm, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 1") - assert do_sample(loaded_llm, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1 + ) + == EXPECTED_LORA_OUTPUT + ) diff --git a/tests/lora/test_multi_loras_with_tp.py b/tests/lora/test_llm_with_multi_loras.py similarity index 74% rename from tests/lora/test_multi_loras_with_tp.py rename to tests/lora/test_llm_with_multi_loras.py index fe9bd3f269515..269a1ade7734f 100644 --- a/tests/lora/test_multi_loras_with_tp.py +++ b/tests/lora/test_llm_with_multi_loras.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Script to test multi loras service with tp >= 2 +This script contains: +1. test multi loras service with tp >= 2 +2. test multi loras request """ + +import pytest + from tests.utils import multi_gpu_test from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest @@ -21,20 +26,14 @@ LORA_RANK = 8 LORA_TEST_PROMPTS = ["What is GitHub?", "Hi, tell me about you"] LORA_TEST_EXPECTED = [ "GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.", # noqa: E501 - "I am Alice, an AI assistant developed by GitHub/Charent.", # noqa: E501 + "I am Alice, an AI assistant developed by GitHub/Charent.", ] def format_chatml_messages(prompt: str): return [ - { - "role": "system", - "content": "You are a helpful assistant." 
- }, - { - "role": "user", - "content": prompt - }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, ] @@ -53,7 +52,6 @@ def make_add_lora_request(name: str, path: str): @multi_gpu_test(num_gpus=2) def test_multi_loras_with_tp_sync(): - llm = LLM( model=MODEL_PATH, enable_lora=True, @@ -112,15 +110,17 @@ def test_multi_loras_with_tp_sync(): def reload_lora(name: str): """ - reload a lora to simulate the case: - setting `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true` + reload a lora to simulate the case: + setting `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true` for dynamic lora loading and unloading """ remove_lora_response = llm.llm_engine.remove_lora( - lora_id=LORA_NAME_ID_MAP[name]) + lora_id=LORA_NAME_ID_MAP[name] + ) add_lora_response = llm.llm_engine.add_lora( - make_add_lora_request(name, LORA_NAME_PATH_MAP[name])) + make_add_lora_request(name, LORA_NAME_PATH_MAP[name]) + ) print(f"{remove_lora_response=}, {add_lora_response=}") @@ -130,7 +130,6 @@ def test_multi_loras_with_tp_sync(): assert outputs == expected for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED): - output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) @@ -156,3 +155,33 @@ def test_multi_loras_with_tp_sync(): output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) + + +def test_multiple_lora_requests(): + llm = LLM( + model=MODEL_PATH, + enable_lora=True, + max_loras=4, + max_lora_rank=LORA_RANK, + max_model_len=512, + gpu_memory_utilization=0.5, + enforce_eager=True, + ) + PROMPTS = ["Hello, my name is"] * 2 + LORA_NAME = "Alice" + lora_request = [ + LoRARequest(LORA_NAME + str(idx), idx + 1, LORA_NAME_PATH_MAP[LORA_NAME]) + for idx in range(len(PROMPTS)) + ] + # Multiple SamplingParams should be matched with each prompt + outputs = llm.generate(PROMPTS, lora_request=lora_request) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.generate(PROMPTS, lora_request=lora_request[:1]) + + # Single LoRARequest should be applied to every prompt + single_lora_request = lora_request[0] + outputs = llm.generate(PROMPTS, lora_request=single_lora_request) + assert len(PROMPTS) == len(outputs) diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py deleted file mode 100644 index 01bc102bd112b..0000000000000 --- a/tests/lora/test_lora_allowed_token_ids.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - VllmConfig) -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.engine.processor import Processor - - -def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, - sql_lora_files): - """ - Test that we properly resolve the range of allowed token ids for lora - adapters that define additional tokens. - """ - - # Setup a base model compatible with the sql_lora_files adapter and - # a known number of tokens in the base model. 
- model_config = ModelConfig( - model=llama_2_7b_base_huggingface_id, - tokenizer=llama_2_7b_base_huggingface_id, - tokenizer_mode="auto", - ) - - vllm_config = VllmConfig( - model_config=model_config, - cache_config=CacheConfig(), - device_config=DeviceConfig(), - lora_config=LoRAConfig(), - ) - - tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) - processor = Processor(vllm_config, tokenizer) - - lora_request = LoRARequest("1", 1, str(sql_lora_files)) - request_id = "1" - prompt = "a prompt" - - # tokens added in the lora adapter should not raise an error - lora_token_ids = [32000, 32001, 32002, 32003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=lora_token_ids), - lora_request=lora_request) - - # tokens in the base model should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) - - # tokens not in the lora adapter should raise an error - invalid_token_ids = [35000, 35001, 35002, 35003] - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) - - # tokens in the lora adapter with no lora request should raise an error - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=lora_token_ids), - ) - - -def test_allowed_token_ids_with_lora_adapter_no_vocab( - qwen25vl_base_huggingface_id, qwen25vl_lora_files): - """ - Test that we properly resolve the range of allowed token ids for lora - adapters that do not define additional tokens. - """ - - # Setup a base model compatible with the qwen25vl_lora_files adapter and - # a known number of tokens in the base model. 
- model_config = ModelConfig( - model=qwen25vl_base_huggingface_id, - tokenizer=qwen25vl_base_huggingface_id, - tokenizer_mode="auto", - ) - - vllm_config = VllmConfig( - model_config=model_config, - cache_config=CacheConfig(), - device_config=DeviceConfig(), - lora_config=LoRAConfig(), - ) - - tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) - processor = Processor(vllm_config, tokenizer) - - lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files)) - request_id = "1" - prompt = "a prompt" - - # tokens in the base model should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) - - # tokens in the base model with no lora request should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - ) - - # tokens not in the base model should raise an error - invalid_token_ids = [200000, 200001, 200002, 200003] - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index ebc0f26378d27..2219d470e91a1 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -8,9 +8,7 @@ from vllm.lora.peft_helper import PEFTHelper from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.utils import WeightsMapper -lora_lst = [ - "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" -] +lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"] BAICHUAN_LORA_MODULES = [ "W_pack", "o_proj", @@ -37,8 +35,9 @@ def test_load_checkpoints( else: expected_lora_modules.append(module) if lora_name == "baichuan7B": - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) # For the baichuan7B model, load it's LoRA, # and the test should pass. LoRAModel.from_local_checkpoint( @@ -48,13 +47,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero": # Test that the target_modules contain prefix # such as "model.layers.0.self_atten.W_pack", and # the test should pass. - peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_zero_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_zero_lora_files, expected_lora_modules, @@ -62,12 +63,14 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero-regex": # Test that the `target_modules` in the form of regular expressions, # such as `model\\..*(W_pack|o_proj)`, and the test should pass. 
- peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_regex_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_regex_lora_files, expected_lora_modules, @@ -75,13 +78,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 - peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + chatglm3_lora_files, max_position_embeddings=4096 + ) with pytest.raises(ValueError, match=expected_error): LoRAModel.from_local_checkpoint( chatglm3_lora_files, @@ -90,11 +95,11 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) def test_lora_weights_mapping(baichuan_lora_files): - packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules @@ -113,8 +118,9 @@ def test_lora_weights_mapping(baichuan_lora_files): ".layers.": ".baichuan_layers.", }, ) - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_local_checkpoint( baichuan_lora_files, expected_lora_modules, diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 50c60341f0d88..e914393fee8aa 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -3,13 +3,15 @@ """ Script to test add_lora, remove_lora, pin_lora, list_loras functions. """ + import pytest from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.lora.request import LoRARequest +from vllm.v1.engine.llm_engine import LLMEngine MODEL_PATH = "meta-llama/Llama-2-7b-hf" LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test" @@ -17,23 +19,24 @@ LORA_RANK = 8 def make_lora_request(lora_id: int): - return LoRARequest(lora_name=f"{lora_id}", - lora_int_id=lora_id, - lora_path=LORA_MODULE_PATH) + return LoRARequest( + lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=LORA_MODULE_PATH + ) def test_lora_functions_sync(): - max_loras = 4 # Create engine in eager-mode. Due to high max_loras, the CI can # OOM during cuda-graph capture. 
- engine_args = EngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = EngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) llm = LLMEngine.from_engine_args(engine_args) @@ -70,15 +73,16 @@ def test_lora_functions_sync(): @pytest.mark.asyncio async def test_lora_functions_async(): - max_loras = 4 - engine_args = AsyncEngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = AsyncEngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) async def run_check(fn, args, expected: list): await fn(args) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index b46d81f1651a6..7d20faef541aa 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -11,8 +11,12 @@ from vllm.model_executor.models.llama import LlamaForCausalLM # Provide absolute path and huggingface lora ids lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] LLAMA_LORA_MODULES = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", ] @@ -40,7 +44,8 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) # Assertions to ensure the model is loaded correctly assert lora_model is not None, "LoRAModel is not loaded correctly" diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index c9ab32edc7f32..e7816031142e3 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -8,17 +8,23 @@ import torch from safetensors.torch import load_file from torch import nn -from vllm.config import LoRAConfig -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - RowParallelLinearWithLoRA) -from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager) +from vllm.config import ModelConfig, VllmConfig +from vllm.config.lora import LoRAConfig +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, +) +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.models import ( + LoRAMapping, + LoRAModel, + LoRAModelManager, + LRUCacheLoRAModelManager, +) from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, - WorkerLoRAManager) +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager from vllm.platforms import current_platform from .utils import create_peft_lora @@ -30,22 +36,25 @@ EMBEDDING_MODULES = { EMBEDDING_PADDING_MODULES = ["lm_head"] -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if 
current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) DEFAULT_DTYPE = torch.get_default_dtype() @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): - tensors = load_file( - os.path.join(sql_lora_files, "adapter_model.safetensors")) + tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors")) new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors")) + os.path.join(sql_lora_files, "new_embeddings.safetensors") + ) - peft_helper = PEFTHelper.from_local_dir(sql_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + sql_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_lora_tensors( 1, tensors, @@ -53,7 +62,8 @@ def test_from_lora_tensors(sql_lora_files, device): device=device, embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, - embedding_padding_modules=EMBEDDING_PADDING_MODULES) + embedding_padding_modules=EMBEDDING_PADDING_MODULES, + ) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -62,22 +72,27 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.lora_b is not None assert lora.lora_a.device == torch.device(device) assert lora.lora_b.device == torch.device(device) - assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] - ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" - assert lora.lora_a.shape[1] == 8 + assert lora.lora_a.shape[0] == lora.lora_b.shape[1], ( + f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + ) + assert lora.lora_a.shape[0] == 8 embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None) + (k for k in EMBEDDING_MODULES if k in module_name), None + ) if embeddings_module: assert torch.equal( lora.embeddings_tensor, new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device)) + device=lora.embeddings_tensor.device + ), + ) else: assert lora.embeddings_tensor is None -def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], - device: torch.device) -> LoRAModel: +def create_lora( + lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device +) -> LoRAModel: loras: dict[str, LoRALayerWeights] = {} for name in sub_modules: w = model.get_submodule(name).weight @@ -85,8 +100,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0]], device=device), + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0], 8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -108,9 +123,8 @@ def create_packed_lora( replaced_module_name, 8, 16, - torch.rand([w.shape[1], 8], device=device), - torch.rand([8, w.shape[0] // len(replaced_module_names)], - device=device), + torch.rand([8, w.shape[1]], device=device), + torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -118,42 +132,42 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model manager = LoRAModelManager( - model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=8, - max_loras=8, - lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0])) + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, 
lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) model = manager.model - assert isinstance(model.get_submodule("dense1"), - ColumnParallelLinearWithLoRA) - assert isinstance(model.get_submodule("layer1.dense1"), - ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA + ) assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA) - assert isinstance(model.get_submodule("layer1.dense2"), - RowParallelLinearWithLoRA) + assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -203,24 +217,21 @@ def test_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -296,27 +307,22 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora4 
= create_lora(4, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -420,12 +426,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, - tmp_path): - lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path): + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) dummy_lora_files = f"{tmp_path}/lora_adapter" os.makedirs(dummy_lora_files, exist_ok=True) @@ -435,59 +439,80 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, target_modules=["layer1.dense1", "dense2"], lora_dtype=DEFAULT_DTYPE, ) + + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config) + + vllm_config.scheduler_config.max_num_seqs = 4 + vllm_config.scheduler_config.max_num_batched_tokens = 2 worker_adapter_manager = LRUCacheWorkerLoRAManager( - 4, 2, - dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size, - lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES + ) + + worker_adapter_manager.max_num_seqs = 4 + worker_adapter_manager.max_num_batched_tokens = 2 + worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("3", 3, dummy_lora_files), - LoRARequest("4", 4, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - 
worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files), - LoRARequest("5", 5, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, dummy_lora_files), - LoRARequest("7", 7, dummy_lora_files), - LoRARequest("8", 8, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 @@ -496,31 +521,40 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, dummy_lora_files), - LoRARequest("11", 11, dummy_lora_files), - LoRARequest("12", 12, dummy_lora_files), - LoRARequest("13", 13, dummy_lora_files), - LoRARequest("14", 14, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, - tmp_path): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path): # Should remove every LoRA not specified in the request. 
- lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) + + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config) + + vllm_config.scheduler_config.max_num_seqs = 4 + vllm_config.scheduler_config.max_num_batched_tokens = 2 + worker_adapter_manager = WorkerLoRAManager( - 4, 2, dummy_model_gate_up.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, device, - EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES + ) + worker_adapter_manager.vocab_size = ( + dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size + ) worker_adapter_manager.create_lora_manager(dummy_model_gate_up) dummy_lora_files = f"{tmp_path}/lora_adapter" @@ -533,49 +567,61 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, ) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("3", 3, dummy_lora_files), - LoRARequest("4", 4, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files), - LoRARequest("5", 5, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - 
worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, dummy_lora_files), - LoRARequest("7", 7, dummy_lora_files), - LoRARequest("8", 8, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 @@ -583,17 +629,19 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, dummy_lora_files), - LoRARequest("11", 11, dummy_lora_files), - LoRARequest("12", 12, dummy_lora_files), - LoRARequest("13", 13, dummy_lora_files), - LoRARequest("14", 14, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) @@ -604,7 +652,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): model, module_name="gate_up_proj", replaced_module_names=["gate_proj", "up_proj"], - device=device) + device=device, + ) model_lora1 = create_packed_lora( 2, model, @@ -614,19 +663,21 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): empty_replaced_module_name="gate_proj", ) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) model = manager.model - assert isinstance(model.get_submodule("gate_up_proj"), - MergedColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA + ) # Verify packed lora is correct model_lora_clone = model_lora.clone(1) model_lora_clone1 = model_lora1.clone(1) @@ -639,21 +690,27 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) - torch.testing.assert_close(packed_lora.lora_a[0], - model_lora_clone.get_lora("gate_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[0], - model_lora_clone.get_lora("gate_proj").lora_b) - torch.testing.assert_close(packed_lora.lora_a[1], - model_lora_clone.get_lora("up_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[1], - model_lora_clone.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b + ) + torch.testing.assert_close( + packed_lora.lora_a[1], model_lora_clone.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[1], 
model_lora_clone.get_lora("up_proj").lora_b + ) packed_lora1 = model_lora1.get_lora("gate_up_proj") assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None - torch.testing.assert_close(packed_lora1.lora_a[1], - model_lora_clone1.get_lora("up_proj").lora_a) - torch.testing.assert_close(packed_lora1.lora_b[1], - model_lora_clone1.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b + ) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 99fe951bbf070..ce98fe2f86137 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -15,7 +15,8 @@ MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" PROMPT_TEMPLATE = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" "(<image>./</image>)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) IMAGE_ASSETS = [ ImageAsset("stop_sign"), @@ -34,18 +35,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: stop_token_ids=[128001, 128009], # eos_id, eot_id ) - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] + inputs = [ + { + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in IMAGE_ASSETS + ] outputs = llm.generate( inputs, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, ) # Print the outputs. 
generated_texts: list[str] = [] @@ -58,7 +59,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -68,10 +70,7 @@ def test_minicpmv_lora(minicpmv_lora_files): max_lora_rank=8, enforce_eager=True, max_model_len=2048, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -82,11 +81,13 @@ def test_minicpmv_lora(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output2[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -96,10 +97,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -107,11 +105,13 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -122,10 +122,7 @@ def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, - limit_mm_per_prompt={ - "image": 1, - "video": 0 - }, + limit_mm_per_prompt={"image": 1, "video": 0}, fully_sharded_loras=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 03e5d8d5d6728..868ca51b33314 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -11,15 +11,15 @@ from vllm.platforms import current_platform MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: list[str]) -> list[str]: - +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, prompts: list[str] +) -> list[str]: sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -33,8 +33,11 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count( - ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): + if ( + torch.cuda.device_count() < tp_size + and tp_size > 1 + and current_platform.is_cuda_alike() + ): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ @@ -57,7 +60,11 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501 "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501 ] - assert do_sample(llm, mixtral_lora_files, lora_id=1, - prompts=prompts) == expected_lora_output - assert do_sample(llm, mixtral_lora_files, lora_id=2, - prompts=prompts) == expected_lora_output + assert ( + do_sample(llm, mixtral_lora_files, lora_id=1, prompts=prompts) + == expected_lora_output + ) + assert ( + do_sample(llm, mixtral_lora_files, lora_id=2, prompts=prompts) + == expected_lora_output + ) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index df8696cf58e0f..9c55c623d444b 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -7,40 +7,28 @@ import shutil import pytest -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.lora.peft_helper import PEFTHelper ERROR_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - { - "bias": "all" - }, - "Adapter bias cannot be used without bias_enabled", - ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] def test_peft_helper_pass(sql_lora_files, tmp_path): - peft_helper = PEFTHelper.from_local_dir(sql_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + sql_lora_files, max_position_embeddings=4096 + ) lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2) peft_helper.validate_legal(lora_config) assert peft_helper.r == 8 @@ -74,8 +62,7 @@ def test_peft_helper_pass(sql_lora_files, tmp_path): with open(config_path, "w") as f: json.dump(adapter_config, f) - peft_helper = PEFTHelper.from_local_dir(test_dir, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096) peft_helper.validate_legal(lora_config) scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r) assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3 @@ -106,4 +93,5 @@ def test_peft_helper_error( # Test loading the adapter with pytest.raises(ValueError, match=expected_error): PEFTHelper.from_local_dir( - test_dir, max_position_embeddings=4096).validate_legal(lora_config) + test_dir, max_position_embeddings=4096 + ).validate_legal(lora_config) diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py deleted file mode 100644 index 3090941e63679..0000000000000 --- a/tests/lora/test_phi.py +++ 
/dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "microsoft/phi-2" - -PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_phi2_lora(phi2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM. - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, phi2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, phi2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 14fa79ae5b446..e4df9751077d2 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -21,11 +21,18 @@ def reset_device(reset_default_device): # Utility shrink and expand operations used as reference implementations. 
 def sgmv_shrink_for_nslices(
-        nslices: int, inputs_tensor: torch.Tensor,
-        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
-        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
-        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
-        num_tokens: int, scaling: float):
+    nslices: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    scaling: float,
+):
     """
     Wrapper around torch_ops.sgmv_shrink that handles any nslices.
     """
@@ -44,15 +51,20 @@ def sgmv_shrink_for_nslices(
     )


-def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
-                            inputs_tensor: torch.Tensor,
-                            lora_weights_lst: list[torch.Tensor],
-                            out_tensor: torch.Tensor,
-                            b_seq_start_loc: torch.Tensor,
-                            seq_len_tensor: torch.Tensor,
-                            prompt_lora_mapping: torch.Tensor, batches: int,
-                            max_seq_length: int, num_tokens: int,
-                            add_inputs: bool) -> None:
+def sgmv_expand_for_nslices(
+    nslices: int,
+    hidden_size: int,
+    inputs_tensor: torch.Tensor,
+    lora_weights_lst: list[torch.Tensor],
+    out_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    prompt_lora_mapping: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    num_tokens: int,
+    add_inputs: bool,
+) -> None:
     """
     Wrapper around torch_ops.sgmv_expand that handles any nslices.
     """
@@ -94,10 +106,17 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
 _dict_lock = Lock()


-def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             scaling: float):
+def check_lora_shrink_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    scaling: float,
+):
     """
     Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
     kernels.
@@ -116,14 +135,19 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     max_seq_length, token_nums = data.meta()

     # Setup metadata information for SGMV and reference kernels
-    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
-                      data.prompt_lora_mapping, batches, max_seq_length,
-                      token_nums)
+    sgmv_meta_args = (
+        data.b_seq_start_loc,
+        data.seq_len_tensor,
+        data.prompt_lora_mapping,
+        batches,
+        max_seq_length,
+        token_nums,
+    )

     # Setup metadata information for the LoRA kernel.
-    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
-                                    max_num_tokens=token_nums,
-                                    device='cuda')
+    lora_meta = LoRAKernelMeta.make(
+        max_loras=num_loras, max_num_tokens=token_nums, device="cuda"
+    )
     lora_meta.prepare_tensors(data.token_lora_mapping)

     ref_out_tensor = data.ref_out_tensor
@@ -154,10 +178,17 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
     assert_close(out_tensor, ref_out_tensor)


-def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
-                             hidden_size: int, nslices: int,
-                             dtype: torch.dtype, device: str, seq_length: int,
-                             add_inputs: bool):
+def check_lora_expand_kernel(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seq_length: int,
+    add_inputs: bool,
+):
     """
     Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
     kernels.
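# A minimal plain-PyTorch sketch of the per-slice LoRA "shrink"/"expand" math
# that the reference wrappers above exercise, for readers skimming this test
# refactor. ref_shrink/ref_expand and the single-adapter shapes are
# illustrative assumptions, not the vLLM torch_ops/triton_ops API.
import torch


def ref_shrink(x: torch.Tensor, lora_a: torch.Tensor, scaling: float) -> torch.Tensor:
    # x: (num_tokens, hidden_size), lora_a: (rank, hidden_size) -> (num_tokens, rank)
    return scaling * (x @ lora_a.T)


def ref_expand(
    y: torch.Tensor, lora_b: torch.Tensor, out: torch.Tensor, add_inputs: bool
) -> torch.Tensor:
    # y: (num_tokens, rank), lora_b: (hidden_size, rank) -> (num_tokens, hidden_size)
    update = y @ lora_b.T
    return out + update if add_inputs else update


x = torch.randn(4, 16)
lora_a = torch.randn(8, 16)  # rank 8
lora_b = torch.randn(16, 8)
out = ref_expand(ref_shrink(x, lora_a, scaling=0.5), lora_b, torch.zeros(4, 16), True)
assert out.shape == (4, 16)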
@@ -177,14 +208,19 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() # Setup metadata information for SGMV and reference kernels - sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, - data.prompt_lora_mapping, batches, max_seq_length, - token_nums) + sgmv_meta_args = ( + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + ) # Setup metadata information for the LoRA kernel. - lora_meta = LoRAKernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=token_nums, device="cuda" + ) lora_meta.prepare_tensors(data.token_lora_mapping) # Setup output tensors @@ -194,21 +230,25 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, with _dict_lock: # lora_expand kernel _LORA_B_PTR_DICT.clear() - triton_ops.lora_expand(data.inputs_tensor, - data.lora_weights, - out_tensor, - *lora_meta.meta_args(token_nums=token_nums), - offset_start=0, - add_inputs=add_inputs) + triton_ops.lora_expand( + data.inputs_tensor, + data.lora_weights, + out_tensor, + *lora_meta.meta_args(token_nums=token_nums), + offset_start=0, + add_inputs=add_inputs, + ) # Reference - sgmv_expand_for_nslices(nslices, - hidden_size, - data.inputs_tensor, - data.lora_weights, - ref_out_tensor, - *sgmv_meta_args, - add_inputs=add_inputs) + sgmv_expand_for_nslices( + nslices, + hidden_size, + data.inputs_tensor, + data.lora_weights, + ref_out_tensor, + *sgmv_meta_args, + add_inputs=add_inputs, + ) assert_close(out_tensor, ref_out_tensor) @@ -299,7 +339,7 @@ HIDDEN_SIZES = [ 128000, 128256, ] -#The size of TP +# The size of TP divisibility = [1, 2, 8, 16, 64] all_hidden_size = [] @@ -331,10 +371,10 @@ DEVICES = [f"cuda:{0}"] SEED = [0] -@pytest.mark.parametrize("batches", test_params['batches']) -@pytest.mark.parametrize("num_loras", test_params['num_loras']) -@pytest.mark.parametrize("rank", test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("batches", test_params["batches"]) +@pytest.mark.parametrize("num_loras", test_params["num_loras"]) +@pytest.mark.parametrize("rank", test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"]) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -358,31 +398,35 @@ def test_kernels( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) -@pytest.mark.parametrize("batches", hs_test_params['batches']) -@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) 
-@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("batches", hs_test_params["batches"]) +@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"]) +@pytest.mark.parametrize("rank", hs_test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"]) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -406,22 +450,26 @@ def test_kernels_hidden_size( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index caa31fdb0e73e..06e1b22ab56e5 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -20,28 +20,27 @@ class ModelWithQuantization: MODELS: list[ModelWithQuantization] -#AWQ quantization is currently not supported in ROCm. +# AWQ quantization is currently not supported in ROCm. if current_platform.is_rocm(): MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] else: MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", - quantization="awq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", quantization="awq" + ), ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - max_tokens: int = 256) -> list[str]: +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, max_tokens: int = 256 +) -> list[str]: raw_prompts = [ "Give me an orange-ish brown color", "Give me a neon pink color", @@ -52,14 +51,14 @@ def do_sample(llm: vllm.LLM, prompts = [format_prompt_tuples(p) for p in raw_prompts] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=max_tokens, - stop=["<|im_end|>"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=max_tokens, stop=["<|im_end|>"] + ) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -72,41 +71,30 @@ def do_sample(llm: vllm.LLM, @pytest.mark.parametrize("model", MODELS) def test_quant_model_lora(tinyllama_lora_files, model): - llm = vllm.LLM( model=model.model_path, enable_lora=True, max_num_seqs=16, max_loras=4, max_model_len=400, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + tokenizer=tinyllama_lora_files, + ) if model.quantization is None: - expected_no_lora_output = [ - "Here are some examples of orange-brown colors", - "I'm sorry, I don't have" - ] expected_lora_output = [ "#ff8050", "#ff8080", ] elif model.quantization == "awq": - expected_no_lora_output = [ - "I'm sorry, I don't understand", - "I'm sorry, I don't understand", - ] expected_lora_output = [ "#f07700: A v", "#f00000: A v", ] elif model.quantization == "gptq": - expected_no_lora_output = [ - "I'm sorry, I don't have", - "I'm sorry, I don't have", - ] expected_lora_output = [ "#f08800: This is", "#f07788 \n#", @@ -115,43 +103,23 @@ def test_quant_model_lora(tinyllama_lora_files, model): def expect_match(output, expected_output): # HACK: GPTQ lora outputs are just incredibly unstable. # Assert that the outputs changed. - if (model.quantization == "gptq" - and expected_output is expected_lora_output): - assert output != expected_no_lora_output + if model.quantization == "gptq" and expected_output is expected_lora_output: for i, o in enumerate(output): - assert o.startswith( - '#'), f"Expected example {i} to start with # but got {o}" + assert o.startswith("#"), ( + f"Expected example {i} to start with # but got {o}" + ) return assert output == expected_output max_tokens = 10 print("lora adapter created") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) - expect_match(output, expected_no_lora_output) - print("lora 1") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=1, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=1, max_tokens=max_tokens) expect_match(output, expected_lora_output) - print("no lora") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) - expect_match(output, expected_no_lora_output) - print("lora 2") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=2, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=2, max_tokens=max_tokens) expect_match(output, expected_lora_output) print("removing lora") @@ -161,8 +129,7 @@ def test_quant_model_lora(tinyllama_lora_files, model): @pytest.mark.parametrize("model", MODELS) -def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, - model): +def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, model): if num_gpus_available < 2: pytest.skip(f"Not enough GPUs for tensor parallelism {2}") if model.quantization == "gptq": @@ -172,10 +139,11 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, enable_lora=True, max_num_seqs=16, max_loras=4, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 @@ -187,9 +155,10 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, 
max_num_seqs=16, max_loras=4, tensor_parallel_size=2, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) del llm_tp2 diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 76f3bc0ebf89f..894263bd0ba38 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -37,7 +37,8 @@ class Qwen2VLTester: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") + "<|im_start|>assistant\n" + ) def __init__(self, config: TestConfig): self.config = config @@ -56,68 +57,68 @@ class Qwen2VLTester: max_model_len=self.config.max_model_len, ) - def run_test(self, - images: list[ImageAsset], - expected_outputs: list[str], - lora_id: Optional[int] = None, - temperature: float = 0, - max_tokens: int = 5): - + def run_test( + self, + images: list[ImageAsset], + expected_outputs: list[str], + lora_id: Optional[int] = None, + temperature: float = 0, + max_tokens: int = 5, + ): sampling_params = vllm.SamplingParams( temperature=temperature, max_tokens=max_tokens, ) - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.generate(inputs, - sampling_params, - lora_request=lora_request) - generated_texts = [ - output.outputs[0].text.strip() for output in outputs + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images ] + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request) + generated_texts = [output.outputs[0].text.strip() for output in outputs] + # Validate outputs for generated, expected in zip(generated_texts, expected_outputs): - assert expected.startswith( - generated), f"Generated text {generated} doesn't " + assert expected.startswith(generated), ( + f"Generated text {generated} doesn't " + ) f"match expected pattern {expected}" - def run_beam_search_test(self, - images: list[ImageAsset], - expected_outputs: list[list[str]], - lora_id: Optional[int] = None, - temperature: float = 0, - beam_width: int = 2, - max_tokens: int = 5): + def run_beam_search_test( + self, + images: list[ImageAsset], + expected_outputs: list[list[str]], + lora_id: Optional[int] = None, + temperature: float = 0, + beam_width: int = 2, + max_tokens: int = 5, + ): + beam_search_params = BeamSearchParams( + beam_width=beam_width, max_tokens=max_tokens, temperature=temperature + ) - beam_search_params = BeamSearchParams(beam_width=beam_width, - max_tokens=max_tokens, - temperature=temperature) + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images + ] - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.beam_search(inputs, - beam_search_params, - lora_request=lora_request) + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.beam_search( + inputs, 
beam_search_params, lora_request=lora_request + ) for output_obj, expected_outs in zip(outputs, expected_outputs): output_texts = [seq.text for seq in output_obj.sequences] - assert output_texts == expected_outs, \ - f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501 + assert output_texts == expected_outs, ( + f"Generated texts {output_texts} do not match expected {expected_outs}" + ) # noqa: E501 TEST_IMAGES = [ @@ -144,27 +145,25 @@ QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct" @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA""" - config = TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA through beam search.""" - config = TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs @@ -176,7 +175,8 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): tester.run_beam_search_test( [ImageAsset("cherry_blossom")], expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS, - lora_id=lora_id) + lora_id=lora_id, + ) @pytest.mark.xfail( @@ -185,12 +185,9 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): ) def test_qwen25vl_lora(qwen25vl_lora_files): """Test Qwen 2.5 VL model with LoRA""" - config = TestConfig(model_path=QWEN25VL_MODEL_PATH, - lora_path=qwen25vl_lora_files) + config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py index 6c93e577611f8..c70e58a375c78 100644 --- a/tests/lora/test_resolver.py +++ b/tests/lora/test_resolver.py @@ -12,13 +12,15 @@ from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry class DummyLoRAResolver(LoRAResolver): """A dummy LoRA resolver for testing.""" - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test_lora": return LoRARequest( lora_name=lora_name, lora_path=f"/dummy/path/{base_model_name}/{lora_name}", - lora_int_id=abs(hash(lora_name))) + lora_int_id=abs(hash(lora_name)), + ) return None @@ -70,6 +72,5 @@ async def test_dummy_resolver_resolve(): assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}" # Test failed resolution - result = await dummy_resolver.resolve_lora(base_model_name, - 
"nonexistent_lora") + result = await dummy_resolver.resolve_lora(base_model_name, "nonexistent_lora") assert result is None diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py deleted file mode 100644 index 6cfdaf50d33c4..0000000000000 --- a/tests/lora/test_tokenizer_group.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import get_lora_tokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) -async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): - reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=True, - max_num_seqs=1, - max_loras=1, - max_input_length=None, - ) - lora_request = LoRARequest("1", 1, sql_lora_files) - assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( - prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer_group.encode_async( - prompt="prompt", lora_request=lora_request) - assert isinstance(tokenizer_group.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - None) == await tokenizer_group.get_lora_tokenizer_async(None) - - assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - lora_request) != tokenizer_group.get_lora_tokenizer(None) - assert tokenizer_group.get_lora_tokenizer( - lora_request) == await tokenizer_group.get_lora_tokenizer_async( - lora_request) - - -def test_get_lora_tokenizer(sql_lora_files, tmp_path): - lora_request = None - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - lora_request = LoRARequest("1", 1, sql_lora_files) - tokenizer = get_lora_tokenizer(lora_request) - assert tokenizer.get_added_vocab() - - lora_request = LoRARequest("1", 1, str(tmp_path)) - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - -@pytest.mark.parametrize("enable_lora", [True, False]) -@pytest.mark.parametrize("max_num_seqs", [1, 2]) -@pytest.mark.parametrize("max_loras", [1, 2]) -def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=enable_lora, - max_num_seqs=max_num_seqs, - max_loras=max_loras, - max_input_length=None, - ) - if enable_lora: - assert tokenizer_group.lora_tokenizers.capacity == max( - max_num_seqs, max_loras) - else: - assert tokenizer_group.lora_tokenizers.capacity == 0 diff --git a/tests/lora/test_transformers_model.py b/tests/lora/test_transformers_model.py index 723f7a54778fe..ea1f5f9c32c3f 100644 --- a/tests/lora/test_transformers_model.py +++ b/tests/lora/test_transformers_model.py @@ -24,20 +24,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query="What is the average, minimum, and maximum age of all singers from France?" 
# noqa: E501 ), PROMPT_TEMPLATE.format( - query= - "What are all distinct countries where singers above age 20 are from?" # noqa: E501 + query="What are all distinct countries where singers above age 20 are from?" # noqa: E501 ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -49,13 +47,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_ilama_lora(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -65,20 +65,23 @@ def test_ilama_lora(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -88,20 +91,23 @@ def test_ilama_lora_tp4(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index b343bef0a920b..c861a52d68721 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -9,8 +9,11 @@ import pytest from huggingface_hub.utils import HfHubHTTPError from torch import nn -from vllm.lora.utils import (get_adapter_absolute_path, - parse_fine_tuned_lora_name, replace_submodule) +from 
vllm.lora.utils import ( + get_adapter_absolute_path, + parse_fine_tuned_lora_name, + replace_submodule, +) from vllm.model_executor.models.utils import WeightsMapper @@ -18,89 +21,85 @@ class LoRANameParserTestConfig(NamedTuple): name: str module_name: str is_lora_a: bool - is_bias: bool weights_mapper: Optional[WeightsMapper] = None def test_parse_fine_tuned_lora_name_valid(): fixture = [ - LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", - "lm_head", True, False), - LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", - "lm_head", False, False), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_A.weight", "lm_head", True, False + ), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_B.weight", "lm_head", False, False + ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, - False, ), # Test with WeightsMapper LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), ] - for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: - assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) + for name, module_name, is_lora_a, weights_mapper in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name( + name, weights_mapper + ) def test_parse_fine_tuned_lora_name_invalid(): @@ -115,22 +114,28 @@ def test_parse_fine_tuned_lora_name_invalid(): def test_replace_submodule(): model = nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(764, 100)), - ("act1", nn.ReLU()), - ("dense2", nn.Linear(100, 50)), - ( - "seq1", - nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(100, 10)), - ("dense2", nn.Linear(10, 50)), - ])), - ), - ("act2", 
nn.ReLU()), - ("output", nn.Linear(50, 10)), - ("outact", nn.Sigmoid()), - ])) + OrderedDict( + [ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict( + [ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ] + ) + ) sigmoid = nn.Sigmoid() @@ -143,52 +148,51 @@ def test_replace_submodule(): # Unit tests for get_adapter_absolute_path -@patch('os.path.isabs') +@patch("os.path.isabs") def test_get_adapter_absolute_path_absolute(mock_isabs): - path = '/absolute/path/to/lora' + path = "/absolute/path/to/lora" mock_isabs.return_value = True assert get_adapter_absolute_path(path) == path -@patch('os.path.expanduser') +@patch("os.path.expanduser") def test_get_adapter_absolute_path_expanduser(mock_expanduser): # Path with ~ that needs to be expanded - path = '~/relative/path/to/lora' - absolute_path = '/home/user/relative/path/to/lora' + path = "~/relative/path/to/lora" + absolute_path = "/home/user/relative/path/to/lora" mock_expanduser.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('os.path.exists') -@patch('os.path.abspath') +@patch("os.path.exists") +@patch("os.path.abspath") def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist): # Relative path that exists locally - path = 'relative/path/to/lora' - absolute_path = '/absolute/path/to/lora' + path = "relative/path/to/lora" + absolute_path = "/absolute/path/to/lora" mock_exist.return_value = True mock_abspath.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download): # Hugging Face model identifier - path = 'org/repo' - absolute_path = '/mock/snapshot/path' + path = "org/repo" + absolute_path = "/mock/snapshot/path" mock_exist.return_value = False mock_snapshot_download.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface_error(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface_error( + mock_exist, mock_snapshot_download +): # Hugging Face model identifier with download error - path = 'org/repo' + path = "org/repo" mock_exist.return_value = False - mock_snapshot_download.side_effect = HfHubHTTPError( - "failed to query model info") + mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") assert get_adapter_absolute_path(path) == path diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index a836ff94ba3ed..c97f8debd1b9a 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -6,9 +6,16 @@ import random import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.config.load 
import LoadConfig +from vllm.config.lora import LoRAConfig from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker @@ -18,12 +25,12 @@ NUM_LORAS = 16 @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) + lora_requests, lora_mapping + ) vllm_config = VllmConfig( model_config=ModelConfig( @@ -48,9 +55,9 @@ def test_worker_apply_lora(sql_lora_files): swap_space=0, cache_dtype="auto", ), - lora_config=LoRAConfig(max_lora_rank=8, - max_cpu_loras=NUM_LORAS, - max_loras=NUM_LORAS), + lora_config=LoRAConfig( + max_lora_rank=8, max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS + ), ) worker = Worker( vllm_config=vllm_config, @@ -66,23 +73,22 @@ def test_worker_apply_lora(sql_lora_files): assert worker.list_loras() == set() lora_requests = [ - LoRARequest(str(i + 1), i + 1, sql_lora_files) - for i in range(NUM_LORAS) + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(NUM_LORAS) ] set_active_loras(worker, lora_requests) assert worker.list_loras() == { - lora_request.lora_int_id - for lora_request in lora_requests + lora_request.lora_int_id for lora_request in lora_requests } for i in range(NUM_LORAS): random.seed(i) - iter_lora_requests = random.choices(lora_requests, - k=random.randint(1, NUM_LORAS)) + iter_lora_requests = random.choices( + lora_requests, k=random.randint(1, NUM_LORAS) + ) random.shuffle(iter_lora_requests) - iter_lora_requests = iter_lora_requests[:-random.randint(0, NUM_LORAS)] + iter_lora_requests = iter_lora_requests[: -random.randint(0, NUM_LORAS)] set_active_loras(worker, lora_requests) assert worker.list_loras().issuperset( - {lora_request.lora_int_id - for lora_request in iter_lora_requests}) + {lora_request.lora_int_id for lora_request in iter_lora_requests} + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7cda90787b6f1..b522aa6b08743 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -9,11 +9,10 @@ from typing import Optional, Union import torch from safetensors.torch import save_file -from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights class DummyLoRAManager: - def __init__(self, device: torch.device = "cuda:0"): super().__init__() self._loras: dict[str, LoRALayerWeights] = {} @@ -36,12 +35,12 @@ class DummyLoRAManager: module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([weight.shape[1], rank], - dtype=weight.dtype, - device=self._device), - lora_b=torch.rand([rank, weight.shape[0]], - dtype=weight.dtype, - device=self._device), + lora_a=torch.rand( + [rank, weight.shape[1]], dtype=weight.dtype, device=self._device + ), + lora_b=torch.rand( + [weight.shape[0], rank], dtype=weight.dtype, device=self._device + ), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand( @@ -67,8 +66,8 @@ class DummyLoRAManager: module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([rank, input_dim], device="cuda"), + lora_b=torch.rand([output_dim, input_dim], device="cuda"), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) @@ -146,27 +145,26 @@ def generate_data( op_type, device, ) -> PunicaTensors: - 
seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, ).to(device) total_tokens = seq_len_tensor.sum() if op_type == "shrink": - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) lora_weights = torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, ).to(device) # shrink op need atomic_add, so output is initinized by 0 - ref_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=dtype, - device=inputs_tensor.device) + ref_out_tensor = torch.zeros( + (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device + ) # NOTE shrink kernel using torch.float32 as output type - our_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=torch.float32).to(device) + our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32).to( + device + ) else: inputs_tensor = torch.rand( (total_tokens, max_rank), @@ -184,15 +182,16 @@ def generate_data( ).to(device) # Ensure the same input. our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )).to(device) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ).to(device) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]].copy_(lora_index) + indices[current_offset : current_offset + seq_len_tensor[b_id]].copy_( + lora_index + ) current_offset += seq_len_tensor[b_id].item() return PunicaTensors( @@ -217,8 +216,7 @@ def generate_data_for_expand_nslices( nslices, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -234,22 +232,25 @@ def generate_data_for_expand_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype).to( + device + ) # Ensure the same input. 
our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) @@ -276,8 +277,7 @@ def generate_data_for_nslices( op_type, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -286,9 +286,7 @@ def generate_data_for_nslices( lora_weights_lst = [] if op_type == "shrink": - - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) for _ in range(nslices): if op_type == "shrink": @@ -296,7 +294,8 @@ def generate_data_for_nslices( torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # NOTE shrink kernel using torch.float32 as output type # shrink op need atomic_add, so output is initinized by 0 our_out_tensor = torch.zeros( @@ -313,23 +312,26 @@ def generate_data_for_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - our_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + our_out_tensor = torch.rand( + (total_tokens, hidden_size * nslices), dtype=dtype + ).to(device) # Ensure the same input. 
ref_out_tensor = our_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) @@ -379,24 +381,20 @@ def create_peft_lora( } for module_name in target_modules: - module = model for attr in module_name.split("."): module = getattr(module, attr) if hasattr(module, "input_size") and hasattr(module, "output_size"): - in_features = module.input_size out_features = module.output_size - elif hasattr(module, "embedding_dim") and hasattr( - module, "num_embeddings"): + elif hasattr(module, "embedding_dim") and hasattr(module, "num_embeddings"): # ParallelLMHead in_features = module.embedding_dim out_features = module.num_embeddings else: - raise ValueError( - f"Unable to determine dimensions for module {module_name}") + raise ValueError(f"Unable to determine dimensions for module {module_name}") lora_A = torch.randn(rank, in_features, dtype=lora_dtype) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py deleted file mode 100644 index dbd9c518e0200..0000000000000 --- a/tests/metrics/test_metrics.py +++ /dev/null @@ -1,268 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import ray -from prometheus_client import REGISTRY - -import vllm.envs as envs -from vllm import EngineArgs, LLMEngine -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.metrics import RayPrometheusStatLogger -from vllm.sampling_params import SamplingParams -from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -MODELS = [ - "distilbert/distilgpt2", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [128]) -def test_metric_counter_prompt_tokens( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: - tokenizer = vllm_model.llm.get_tokenizer() - prompt_token_counts = [ - len(tokenizer.encode(p)) for p in example_prompts - ] - # This test needs at least 2 prompts in a batch of different lengths to - # verify their token count is correct despite padding. 
- assert len(example_prompts) > 1, "at least 2 prompts are required" - assert prompt_token_counts[0] != prompt_token_counts[1], ( - "prompts of different lengths are required") - vllm_prompt_token_count = sum(prompt_token_counts) - - _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metric_count = stat_logger.metrics.counter_prompt_tokens.labels( - **stat_logger.labels)._value.get() - - assert vllm_prompt_token_count == metric_count, ( - f"prompt token count: {vllm_prompt_token_count!r}\n" - f"metric: {metric_count!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [128]) -def test_metric_counter_generation_tokens( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.llm.get_tokenizer() - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() - vllm_generation_count = 0 - for i in range(len(example_prompts)): - vllm_output_ids, vllm_output_str = vllm_outputs[i] - prompt_ids = tokenizer.encode(example_prompts[i]) - # vllm_output_ids contains both prompt tokens and generation tokens. - # We're interested only in the count of the generation tokens. - vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) - - assert vllm_generation_count == metric_count, ( - f"generation token count: {vllm_generation_count!r}\n" - f"metric: {metric_count!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize( - "served_model_name", - [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) -def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, - served_model_name: list[str]) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.3, - served_model_name=served_model_name) as vllm_model: - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metrics_tag_content = stat_logger.labels["model_name"] - - if envs.VLLM_CI_USE_S3: - model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" - if served_model_name is None or served_model_name == []: - assert metrics_tag_content == model, ( - f"Metrics tag model_name is wrong! expect: {model!r}\n" - f"actual: {metrics_tag_content!r}") - else: - assert metrics_tag_content == served_model_name[0], ( - f"Metrics tag model_name is wrong! 
expect: " - f"{served_model_name[0]!r}\n" - f"actual: {metrics_tag_content!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("disable_log_stats", [True, False]) -@pytest.mark.asyncio -async def test_async_engine_log_metrics_regression( - example_prompts, - model: str, - dtype: str, - max_tokens: int, - disable_log_stats: bool, -) -> None: - """ - Regression test ensuring async engine generates metrics - when disable_log_stats=False - (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) - """ - engine_args = AsyncEngineArgs( - model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - ) - async_engine = AsyncLLMEngine.from_engine_args(engine_args) - for i, prompt in enumerate(example_prompts): - results = async_engine.generate( - prompt, - SamplingParams(max_tokens=max_tokens), - f"request-id-{i}", - ) - # Exhaust the async iterator to make the async engine work - async for _ in results: - pass - - assert_metrics(model, async_engine.engine, disable_log_stats, - len(example_prompts)) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("disable_log_stats", [True, False]) -def test_engine_log_metrics_regression( - example_prompts, - model: str, - dtype: str, - max_tokens: int, - disable_log_stats: bool, -) -> None: - engine_args = EngineArgs( - model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - ) - engine = LLMEngine.from_engine_args(engine_args) - for i, prompt in enumerate(example_prompts): - engine.add_request( - f"request-id-{i}", - prompt, - SamplingParams(max_tokens=max_tokens), - ) - while engine.has_unfinished_requests(): - engine.step() - - if envs.VLLM_CI_USE_S3: - model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" - assert_metrics(model, engine, disable_log_stats, len(example_prompts)) - - -def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, - num_requests: int) -> None: - if disable_log_stats: - with pytest.raises(AttributeError): - _ = engine.stat_loggers - else: - assert (engine.stat_loggers - is not None), "engine.stat_loggers should be set" - # Ensure the count bucket of request-level histogram metrics matches - # the number of requests as a simple sanity check to ensure metrics are - # generated - labels = {'model_name': model} - request_histogram_metrics = [ - "vllm:e2e_request_latency_seconds", - "vllm:request_prompt_tokens", - "vllm:request_generation_tokens", - "vllm:request_params_n", - "vllm:request_params_max_tokens", - ] - for metric_name in request_histogram_metrics: - metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", - labels) - assert ( - metric_value == num_requests), "Metrics should be collected" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [16]) -def test_engine_log_metrics_ray( - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is quite weak - it only checks that we can use - # RayPrometheusStatLogger without exceptions. - # Checking whether the metrics are actually emitted is unfortunately - # non-trivial. 
- - # We have to run in a Ray task for Ray metrics to be emitted correctly - @ray.remote(num_gpus=1) - def _inner(): - - class _RayPrometheusStatLogger(RayPrometheusStatLogger): - - def __init__(self, *args, **kwargs): - self._i = 0 - super().__init__(*args, **kwargs) - - def log(self, *args, **kwargs): - self._i += 1 - return super().log(*args, **kwargs) - - engine_args = EngineArgs( - model=model, - dtype=dtype, - disable_log_stats=False, - ) - engine = LLMEngine.from_engine_args(engine_args) - logger = _RayPrometheusStatLogger( - local_interval=0.5, - labels=dict(model_name=engine.model_config.served_model_name), - vllm_config=engine.vllm_config) - engine.add_logger("ray", logger) - for i, prompt in enumerate(example_prompts): - engine.add_request( - f"request-id-{i}", - prompt, - SamplingParams(max_tokens=max_tokens), - ) - while engine.has_unfinished_requests(): - engine.step() - assert logger._i > 0, ".log must be called at least once" - - ray.get(_inner.remote()) diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py deleted file mode 100644 index c6d89d849e9f9..0000000000000 --- a/tests/model_executor/conftest.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - - -@pytest.fixture -def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") - - -@pytest.fixture -def sample_json_schema(): - return { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "skills": { - "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 - }, - "work_history": { - "type": "array", - "items": { - "type": "object", - "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } - }, - "required": ["company", "position"] - } - } - }, - "required": ["name", "age", "skills", "work_history"] - } diff --git a/tests/encoder_decoder/__init__.py b/tests/model_executor/model_loader/fastsafetensors_loader/__init__.py similarity index 100% rename from tests/encoder_decoder/__init__.py rename to tests/model_executor/model_loader/fastsafetensors_loader/__init__.py diff --git a/tests/fastsafetensors_loader/test_fastsafetensors_loader.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py similarity index 100% rename from tests/fastsafetensors_loader/test_fastsafetensors_loader.py rename to tests/model_executor/model_loader/fastsafetensors_loader/test_fastsafetensors_loader.py diff --git a/tests/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py similarity index 64% rename from tests/fastsafetensors_loader/test_weight_utils.py rename to tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py index 78d23acfec7c5..cc899b77b5e9a 100644 --- a/tests/fastsafetensors_loader/test_weight_utils.py +++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py @@ -8,24 +8,25 @@ import huggingface_hub.constants import torch from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, fastsafetensors_weights_iterator, - safetensors_weights_iterator) + download_weights_from_hf, + fastsafetensors_weights_iterator, + safetensors_weights_iterator, +) def test_fastsafetensors_model_loader(): with 
tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 fastsafetensors_tensors = {} hf_safetensors_tensors = {} - for name, tensor in fastsafetensors_weights_iterator( - safetensors, True): + for name, tensor in fastsafetensors_weights_iterator(safetensors, True): fastsafetensors_tensors[name] = tensor for name, tensor in safetensors_weights_iterator(safetensors, True): @@ -34,13 +35,10 @@ def test_fastsafetensors_model_loader(): assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors) for name, fastsafetensors_tensor in fastsafetensors_tensors.items(): - fastsafetensors_tensor = fastsafetensors_tensor.to('cpu') - assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[ - name].dtype - assert fastsafetensors_tensor.shape == hf_safetensors_tensors[ - name].shape - assert torch.all( - fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) + fastsafetensors_tensor = fastsafetensors_tensor.to("cpu") + assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[name].dtype + assert fastsafetensors_tensor.shape == hf_safetensors_tensors[name].shape + assert torch.all(fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) if __name__ == "__main__": diff --git a/tests/fastsafetensors_loader/__init__.py b/tests/model_executor/model_loader/runai_model_streamer/__init__.py similarity index 100% rename from tests/fastsafetensors_loader/__init__.py rename to tests/model_executor/model_loader/runai_model_streamer/__init__.py diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py similarity index 96% rename from tests/runai_model_streamer_test/test_runai_model_streamer_loader.py rename to tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py index 84c615b6b8dbc..22bdb3b44eb03 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import SamplingParams -from vllm.config import LoadConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import get_model_loader load_format = "runai_streamer" diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py new file mode 100644 index 0000000000000..3ad7308eeba24 --- /dev/null +++ b/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import glob +import hashlib +import os +import tempfile + +import huggingface_hub.constants + +from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf +from vllm.transformers_utils.runai_utils import ( + ObjectStorageModel, + is_runai_obj_uri, + list_safetensors, +) + + +def test_is_runai_obj_uri(): + assert is_runai_obj_uri("gs://some-gcs-bucket/path") + assert 
is_runai_obj_uri("s3://some-s3-bucket/path") + assert not is_runai_obj_uri("nfs://some-nfs-path") + + +def test_runai_list_safetensors_local(): + with tempfile.TemporaryDirectory() as tmpdir: + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf( + "openai-community/gpt2", + allow_patterns=["*.safetensors", "*.json"], + cache_dir=tmpdir, + ) + safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) + assert len(safetensors) > 0 + parentdir = [os.path.dirname(safetensor) for safetensor in safetensors][0] + files = list_safetensors(parentdir) + assert len(safetensors) == len(files) + + +def test_runai_pull_files_gcs(monkeypatch): + monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true") + # Bypass default project lookup by setting GOOGLE_CLOUD_PROJECT + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project") + filename = "LT08_L1GT_074061_20130309_20170505_01_T2_MTL.txt" + gcs_bucket = "gs://gcp-public-data-landsat/LT08/01/074/061/LT08_L1GT_074061_20130309_20170505_01_T2/" + gcs_url = f"{gcs_bucket}/{filename}" + model = ObjectStorageModel(gcs_url) + model.pull_files(gcs_bucket, allow_pattern=[f"*{filename}"]) + # To re-generate / change URLs: + # gsutil ls -L gs://<gcs-url> | grep "Hash (md5)" | tr -d ' ' \ + # | cut -d":" -f2 | base64 -d | xxd -p + expected_checksum = "f60dea775da1392434275b311b31a431" + hasher = hashlib.new("md5") + with open(os.path.join(model.dir, filename), "rb") as f: + # Read the file in chunks to handle large files efficiently + for chunk in iter(lambda: f.read(4096), b""): + hasher.update(chunk) + actual_checksum = hasher.hexdigest() + assert actual_checksum == expected_checksum diff --git a/tests/runai_model_streamer_test/test_weight_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py similarity index 76% rename from tests/runai_model_streamer_test/test_weight_utils.py rename to tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py index ee448c2ccb213..03691b4a472f1 100644 --- a/tests/runai_model_streamer_test/test_weight_utils.py +++ b/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py @@ -8,24 +8,25 @@ import huggingface_hub.constants import torch from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, runai_safetensors_weights_iterator, - safetensors_weights_iterator) + download_weights_from_hf, + runai_safetensors_weights_iterator, + safetensors_weights_iterator, +) def test_runai_model_loader(): with tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 runai_model_streamer_tensors = {} hf_safetensors_tensors = {} - for name, tensor in runai_safetensors_weights_iterator( - safetensors, True): + for name, tensor in runai_safetensors_weights_iterator(safetensors, True): runai_model_streamer_tensors[name] = tensor for name, tensor in safetensors_weights_iterator(safetensors, True): diff --git a/tests/metrics/__init__.py b/tests/model_executor/model_loader/tensorizer_loader/__init__.py similarity index 100% rename from tests/metrics/__init__.py rename to tests/model_executor/model_loader/tensorizer_loader/__init__.py diff --git 
a/tests/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py similarity index 82% rename from tests/tensorizer_loader/conftest.py rename to tests/model_executor/model_loader/tensorizer_loader/conftest.py index 18aa4c88c0338..add6d3742ff53 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py @@ -10,7 +10,7 @@ from vllm.model_executor.model_loader import tensorizer as tensorizer_mod from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import UniProcExecutor -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.worker.worker_base import WorkerWrapperBase MODEL_REF = "facebook/opt-125m" @@ -32,7 +32,6 @@ def cleanup(): @pytest.fixture() def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path): - def noop(*args, **kwargs): return None @@ -56,8 +55,7 @@ def model_path(model_ref, tmp_path): yield tmp_path / model_ref / "model.tensors" -def assert_from_collective_rpc(engine: LLM, closure: Callable, - closure_kwargs: dict): +def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: dict): res = engine.collective_rpc(method=closure, kwargs=closure_kwargs) return all(res) @@ -67,18 +65,13 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, # method. It's purely used as a dummy utility to run methods that test # Tensorizer functionality class DummyExecutor(UniProcExecutor): - def _init_executor(self) -> None: - """Initialize the worker and load the model. - """ - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + """Initialize the worker and load the model.""" + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) local_rank = 0 # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") + device_info = self.vllm_config.device_config.device.__str__().split(":") if len(device_info) > 1: local_rank = int(device_info[1]) rank = 0 @@ -90,7 +83,8 @@ class DummyExecutor(UniProcExecutor): distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) - self.collective_rpc("init_worker", args=([kwargs], )) + self.mm_receiver_cache = None + self.collective_rpc("init_worker", args=([kwargs],)) self.collective_rpc("init_device") @property @@ -98,5 +92,5 @@ class DummyExecutor(UniProcExecutor): return 2 def shutdown(self): - if hasattr(self, 'thread_pool'): + if hasattr(self, "thread_pool"): self.thread_pool.shutdown(wait=False) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py similarity index 69% rename from tests/tensorizer_loader/test_tensorizer.py rename to tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py index 0fb142a1b6e56..57db1f98baed0 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py @@ -14,20 +14,21 @@ import pytest import torch import vllm.model_executor.model_loader.tensorizer +from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -# 
yapf: disable -from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, - TensorSerializer, - is_vllm_tensorized, - open_stream, - tensorize_vllm_model) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + open_stream, + tensorize_vllm_model, +) from vllm.model_executor.model_loader.tensorizer_loader import ( - BLACKLISTED_TENSORIZER_ARGS) -# yapf: enable + BLACKLISTED_TENSORIZER_ARGS, +) from vllm.utils import PlaceholderModule -from ..utils import VLLM_PATH, RemoteOpenAIServer from .conftest import DummyExecutor, assert_from_collective_rpc try: @@ -44,7 +45,7 @@ class TensorizerCaughtError(Exception): EXAMPLES_PATH = VLLM_PATH / "examples" -pytest_plugins = "pytest_asyncio", +pytest_plugins = ("pytest_asyncio",) prompts = [ "Hello, my name is", @@ -56,8 +57,7 @@ prompts = [ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) -def patch_init_and_catch_error(self, obj, method_name, - expected_error: type[Exception]): +def patch_init_and_catch_error(self, obj, method_name, expected_error: type[Exception]): original = getattr(obj, method_name, None) if original is None: raise ValueError("Method '{}' not found.".format(method_name)) @@ -80,17 +80,19 @@ def assert_specific_tensorizer_error_is_raised( expected_error: type[Exception], ): with pytest.raises(TensorizerCaughtError): - executor.collective_rpc(patch_init_and_catch_error, - args=( - obj, - method_name, - expected_error, - )) + executor.collective_rpc( + patch_init_and_catch_error, + args=( + obj, + method_name, + expected_error, + ), + ) def is_curl_installed(): try: - subprocess.check_call(['curl', '--version']) + subprocess.check_call(["curl", "--version"]) return True except (subprocess.CalledProcessError, FileNotFoundError): return False @@ -99,13 +101,14 @@ def is_curl_installed(): def write_keyfile(keyfile_path: str): encryption_params = EncryptionParams.random() pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True) - with open(keyfile_path, 'wb') as f: + with open(keyfile_path, "wb") as f: f.write(encryption_params.key) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_deserialized_encrypted_vllm_model_has_same_outputs( - model_ref, vllm_runner, tmp_path, model_path): + model_ref, vllm_runner, tmp_path, model_path +): args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: key_path = tmp_path / model_ref / "model.key" @@ -113,29 +116,30 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( outputs = vllm_model.generate(prompts, sampling_params) - config_for_serializing = TensorizerConfig(tensorizer_uri=str(model_path), - encryption_keyfile=str(key_path)) + config_for_serializing = TensorizerConfig( + tensorizer_uri=str(model_path), encryption_keyfile=str(key_path) + ) tensorize_vllm_model(args, config_for_serializing) config_for_deserializing = TensorizerConfig( - tensorizer_uri=str(model_path), encryption_keyfile=str(key_path)) + tensorizer_uri=str(model_path), encryption_keyfile=str(key_path) + ) - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=config_for_deserializing - ) as loaded_vllm_model: # noqa: E501 - - deserialized_outputs = loaded_vllm_model.generate( - prompts, sampling_params) + with vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing, + ) as loaded_vllm_model: # noqa: E501 + deserialized_outputs = loaded_vllm_model.generate(prompts, 
sampling_params) # noqa: E501 assert outputs == deserialized_outputs -def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, - tmp_path, model_ref, - model_path): +def test_deserialized_hf_model_has_same_outputs( + hf_runner, vllm_runner, tmp_path, model_ref, model_path +): with hf_runner(model_ref) as hf_model: max_tokens = 50 outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) @@ -143,14 +147,17 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, serializer = TensorSerializer(stream) serializer.write_module(hf_model.model) - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=str(model_path), - num_readers=1, - )) as loaded_hf_model: + with vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=str(model_path), + num_readers=1, + ), + ) as loaded_hf_model: deserialized_outputs = loaded_hf_model.generate_greedy( - prompts, max_tokens=max_tokens) + prompts, max_tokens=max_tokens + ) assert outputs == deserialized_outputs @@ -159,34 +166,37 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( - model_ref, - model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + model_ref, model_loader_extra_config=TensorizerConfig(tensorizer_uri="test") + ) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config " - "is not supported for load " - "format auto") in combined_output + assert ( + "ValueError: Unexpected extra config keys for load format auto" + ) in combined_output finally: del model gc.collect() torch.cuda.empty_cache() -def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, - model_ref): +def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( model_ref, load_format="safetensors", - model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"), + ) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config is not supported " - "for load format safetensors") in combined_output + assert ( + "ValueError: Unexpected extra config keys for load format safetensors" + ) in combined_output finally: del model gc.collect() @@ -213,21 +223,24 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: For a sharded model, tensorizer_uri " - "should include a string format template like '%04d' " - "to be formatted with the rank " - "of the shard") in combined_output + assert ( + "ValueError: For a sharded model, tensorizer_uri " + "should include a string format template like '%04d' " + "to be formatted with the rank " + "of the shard" + ) in combined_output @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( - vllm_runner, tmp_path): + vllm_runner, tmp_path +): model_ref = "EleutherAI/pythia-1.4b" # record outputs from un-sharded un-tensorized model with vllm_runner( - model_ref, - disable_custom_all_reduce=True, - 
enforce_eager=True, + model_ref, + disable_custom_all_reduce=True, + enforce_eager=True, ) as base_model: outputs = base_model.generate(prompts, sampling_params) @@ -253,21 +266,22 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( assert os.path.isfile(model_path % 1), "Serialization subprocess failed" with vllm_runner( - model_ref, - tensor_parallel_size=2, - load_format="tensorizer", - disable_custom_all_reduce=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config) as loaded_vllm_model: - deserialized_outputs = loaded_vllm_model.generate( - prompts, sampling_params) + model_ref, + tensor_parallel_size=2, + load_format="tensorizer", + disable_custom_all_reduce=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + ) as loaded_vllm_model: + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) assert outputs == deserialized_outputs @pytest.mark.flaky(reruns=3) -def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, - tmp_path, model_path): +def test_vllm_tensorized_model_has_same_outputs( + model_ref, vllm_runner, tmp_path, model_path +): gc.collect() torch.cuda.empty_cache() config = TensorizerConfig(tensorizer_uri=str(model_path)) @@ -279,11 +293,10 @@ def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, tensorize_vllm_model(args, config) assert is_vllm_tensorized(config) - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=config) as loaded_vllm_model: - deserialized_outputs = loaded_vllm_model.generate( - prompts, sampling_params) + with vllm_runner( + model_ref, load_format="tensorizer", model_loader_extra_config=config + ) as loaded_vllm_model: + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 assert outputs == deserialized_outputs @@ -313,15 +326,17 @@ def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref): def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): - serialization_params = { "limit_cpu_concurrency": 2, } model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) - llm = LLM(model=model_ref, ) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) + llm = LLM( + model=model_ref, + ) def serialization_test(self, *args, **kwargs): # This is performed in the ephemeral worker process, so monkey-patching @@ -339,10 +354,13 @@ def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): return original(self, *args, **kwargs) tensorizer.serialization.TensorSerializer.__init__ = ( - tensorizer_serializer_wrapper) + tensorizer_serializer_wrapper + ) tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"]) - self.save_tensorized_model(tensorizer_config=tensorizer_config, ) + self.save_tensorized_model( + tensorizer_config=tensorizer_config, + ) return to_compare | original_dict == to_compare kwargs = {"tensorizer_config": config.to_serializable()} @@ -350,9 +368,7 @@ def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): assert assert_from_collective_rpc(llm, serialization_test, kwargs) -def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( - tmp_path, capfd): - +def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): deserialization_kwargs = { "num_readers": "bar", # 
illegal value } @@ -363,8 +379,9 @@ def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -392,7 +409,6 @@ def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): - deserialization_kwargs = { "num_readers": 1, } @@ -403,8 +419,9 @@ def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -440,16 +457,24 @@ async def test_serialize_and_serve_entrypoints(tmp_path): suffix = "test" try: - result = subprocess.run([ - sys.executable, - f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", - model_ref, "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix, "--serialization-kwargs", - '{"limit_cpu_concurrency": 4}' - ], - check=True, - capture_output=True, - text=True) + result = subprocess.run( + [ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", + "--model", + model_ref, + "serialize", + "--serialized-directory", + str(tmp_path), + "--suffix", + suffix, + "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}', + ], + check=True, + capture_output=True, + text=True, + ) except subprocess.CalledProcessError as e: print("Tensorizing failed.") print("STDOUT:\n", e.stdout) @@ -469,14 +494,20 @@ async def test_serialize_and_serve_entrypoints(tmp_path): "deserialization_kwargs": { "verify_hash": True, "num_readers": 8, - } + }, } cmd = [ - "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost", - "--load-format", "tensorizer", model_ref, + "-m", + "vllm.entrypoints.cli.main", + "serve", + "--host", + "localhost", + "--load-format", + "tensorizer", + model_ref, "--model-loader-extra-config", - json.dumps(model_loader_extra_config, indent=2) + json.dumps(model_loader_extra_config, indent=2), ] proc = await asyncio.create_subprocess_exec( @@ -499,17 +530,16 @@ async def test_serialize_and_serve_entrypoints(tmp_path): @pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS) -def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, - illegal_value): - +def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, illegal_value): serialization_params = { "limit_cpu_concurrency": 2, } model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -525,5 +555,6 @@ def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert (f"ValueError: 
{illegal_value} is not an allowed " - f"Tensorizer argument.") in combined_output + assert ( + f"ValueError: {illegal_value} is not an allowed Tensorizer argument." + ) in combined_output diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 93a3e34835b5a..020988ccac13c 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -4,23 +4,21 @@ import pytest from torch import nn -from vllm.config import LoadConfig, ModelConfig -from vllm.model_executor.model_loader import (get_model_loader, - register_model_loader) +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig +from vllm.model_executor.model_loader import get_model_loader, register_model_loader from vllm.model_executor.model_loader.base_loader import BaseModelLoader @register_model_loader("custom_load_format") class CustomModelLoader(BaseModelLoader): - def __init__(self, load_config: LoadConfig) -> None: super().__init__(load_config) def download_model(self, model_config: ModelConfig) -> None: pass - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: pass diff --git a/tests/test_sharded_state_loader.py b/tests/model_executor/model_loader/test_sharded_state_loader.py similarity index 61% rename from tests/test_sharded_state_loader.py rename to tests/model_executor/model_loader/test_sharded_state_loader.py index 42afdfa3c7468..5bb841bf2fa0e 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/model_executor/model_loader/test_sharded_state_loader.py @@ -35,11 +35,13 @@ def test_filter_subtensors(): "b": torch.empty((2, 4)), "c": torch.empty((2, 4, 8)), } - state_dict.update({ - "x": state_dict["b"], - "y": state_dict["c"][1, 2, :], - "z": state_dict["c"][1, :, 4], - }) + state_dict.update( + { + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + } + ) filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): @@ -49,24 +51,34 @@ def test_filter_subtensors(): @pytest.fixture(scope="module") def llama_3p2_1b_files(): - input_dir = snapshot_download("meta-llama/Llama-3.2-1B-Instruct", - ignore_patterns=["*.bin*", "original/*"]) + input_dir = snapshot_download( + "meta-llama/Llama-3.2-1B-Instruct", ignore_patterns=["*.bin*", "original/*"] + ) yield input_dir def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): llm_sharded_writer = LLM(model=input_dir, **kwargs) - + # Check which engine version is being used + is_v1_engine = hasattr(llm_sharded_writer.llm_engine, "engine_core") # Dump worker states to output directory - llm_sharded_writer.llm_engine.model_executor.save_sharded_state( - path=output_dir) + if is_v1_engine: + # For V1 engine, we need to use engine_core.save_sharded_state + print("Using V1 engine save path") + llm_sharded_writer.llm_engine.engine_core.save_sharded_state(path=output_dir) + else: + # For V0 engine + print("Using V0 engine save path") + model_executor = llm_sharded_writer.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) # Copy metadata files to output directory for file in os.listdir(input_dir): if os.path.isdir(os.path.join(input_dir, file)): - shutil.copytree(os.path.join(input_dir, file), - os.path.join(output_dir, file)) + shutil.copytree( + 
os.path.join(input_dir, file), os.path.join(output_dir, file) + ) elif not any(fnmatch.fnmatch(file, ext) for ext in weights_patterns): shutil.copy(os.path.join(input_dir, file), output_dir) @@ -81,42 +93,42 @@ def _run_generate(input_dir, queue: mp.Queue, **kwargs): @pytest.mark.parametrize("enable_lora", [False, True]) @pytest.mark.parametrize("tp_size", [1, 2]) -def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, - llama_3p2_1b_files, - monkeypatch: pytest.MonkeyPatch): +def test_sharded_state_loader( + enable_lora, tp_size, num_gpus_available, llama_3p2_1b_files +): if num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - weights_patterns = ("*.safetensors", ) + weights_patterns = ("*.safetensors",) gpu_memory_utilization = 0.8 input_dir = llama_3p2_1b_files ctx = mp.get_context("spawn") - # The interface in v1 engine has changed, run in v1 engine will hang. - monkeypatch.setenv("VLLM_USE_V1", "0") # Run in separate processes for memory & CUDA isolation with TemporaryDirectory() as output_dir: - p = ctx.Process(target=_run_writer, - args=(input_dir, output_dir, weights_patterns), - kwargs=dict( - tensor_parallel_size=tp_size, - distributed_executor_backend="mp", - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=True, - )) + p = ctx.Process( + target=_run_writer, + args=(input_dir, output_dir, weights_patterns), + kwargs=dict( + tensor_parallel_size=tp_size, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=True, + ), + ) p.start() p.join() queue = ctx.Queue() - p = ctx.Process(target=_run_generate, - args=(input_dir, queue), - kwargs=dict( - distributed_executor_backend="mp", - enable_lora=enable_lora, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tp_size, - )) + p = ctx.Process( + target=_run_generate, + args=(input_dir, queue), + kwargs=dict( + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + ), + ) p.start() # Call queue.get() before p.join() to prevent deadlock: # If p.join() is called before queue.get() and the queue is full, @@ -130,15 +142,16 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, queue = ctx.Queue() - p = ctx.Process(target=_run_generate, - args=(output_dir, queue), - kwargs=dict( - distributed_executor_backend="mp", - enable_lora=enable_lora, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tp_size, - load_format="sharded_state", - )) + p = ctx.Process( + target=_run_generate, + args=(output_dir, queue), + kwargs=dict( + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + load_format="sharded_state", + ), + ) p.start() # Call queue.get() before p.join() to prevent deadlock: # If p.join() is called before queue.get() and the queue is full, diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 140f00294765d..12aad4cb8da0f 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,25 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional import pytest import torch from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.activation import (GeluAndMul, - ReLUSquaredActivation, - 
SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, - vllm_topk_softmax) +from vllm.model_executor.layers.activation import ( + GeluAndMul, + ReLUSquaredActivation, + SiluAndMul, +) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + dispatch_topk_func, + vllm_topk_softmax, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, +) from vllm.model_executor.layers.layernorm import ( - RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, - rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) + RMSNorm, + dispatch_rocm_rmsnorm_func, + fused_add_rms_norm, + rms_norm, +) from vllm.platforms import current_platform +RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # Registered subclass for test @CustomOp.register("relu3") @@ -32,15 +41,15 @@ class Relu3(ReLUSquaredActivation): [ # Default values based on compile level # - All by default (no Inductor compilation) - ("", 0, False, [True] * 4, True), - ("", 1, True, [True] * 4, True), - ("", 2, False, [True] * 4, True), + (None, 0, False, [True] * 4, True), + (None, 1, True, [True] * 4, True), + (None, 2, False, [True] * 4, True), # - None by default (with Inductor) - ("", 3, True, [False] * 4, False), - ("", 4, True, [False] * 4, False), + (None, 3, True, [False] * 4, False), + (None, 4, True, [False] * 4, False), # - All by default (without Inductor) - ("", 3, False, [True] * 4, True), - ("", 4, False, [True] * 4, True), + (None, 3, False, [True] * 4, True), + (None, 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all @@ -52,7 +61,7 @@ class Relu3(ReLUSquaredActivation): # All but SiluAndMul ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,relu2", 3, False, [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True), # RMSNorm and SiluAndMul ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm @@ -64,13 +73,21 @@ class Relu3(ReLUSquaredActivation): ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False), # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), - ]) -def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, - ops_enabled: list[int], default_on: bool): + ], +) +def test_enabled_ops( + env: Optional[str], + torch_level: int, + use_inductor: bool, + ops_enabled: list[int], + default_on: bool, +): + custom_ops = env.split(",") if env else [] vllm_config = VllmConfig( - compilation_config=CompilationConfig(use_inductor=bool(use_inductor), - level=torch_level, - custom_ops=env.split(","))) + compilation_config=CompilationConfig( + use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops + ) + ) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on @@ -98,43 +115,17 @@ def test_enabled_ops(env: str, torch_level: int, use_inductor: bool, @pytest.mark.parametrize( - "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) + "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"] +) def test_enabled_ops_invalid(env: str): with pytest.raises(Exception): # noqa - vllm_config = VllmConfig(compilation_config=CompilationConfig( - custom_ops=env.split(","))) + vllm_config = VllmConfig( + 
compilation_config=CompilationConfig(custom_ops=env.split(",")) + ) with set_current_vllm_config(vllm_config): RMSNorm(1024).enabled() -@pytest.mark.skipif( - not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") -@pytest.mark.parametrize("use_cutlass", [True, False]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) - - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) - block_scale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) - if use_cutlass: - assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) - else: - assert block_scale_func == w8a8_block_fp8_matmul - - @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) @@ -142,31 +133,44 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): is_rocm_aiter_moe_enabled.cache_clear() if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax) + rocm_aiter_topk_softmax, + ) + assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="AITER is a feature exclusive for ROCm") -def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str, - use_rocm_aiter_norm: str, monkeypatch): +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" +) +def test_rms_norm_dispatch( + add_residual: bool, + dtype: torch.dtype, + use_rocm_aiter: str, + use_rocm_aiter_norm: str, + monkeypatch, +): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) - if not add_residual: - if current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == rocm_aiter_rms_norm - else: - assert rms_norm_func == rms_norm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == rocm_aiter_fused_add_rms_norm - else: + should_use_rocm_aiter = ( + current_platform.is_rocm() + and int(use_rocm_aiter) + and int(use_rocm_aiter_norm) + and dtype in RMS_NORM_SUPPORTED_DTYPES + ) + + if add_residual and should_use_rocm_aiter: + assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + elif should_use_rocm_aiter: + assert rms_norm_func == 
torch.ops.vllm.rocm_aiter_rms_norm + elif add_residual: assert rms_norm_func == fused_add_rms_norm + else: + assert rms_norm_func == rms_norm diff --git a/tests/model_executor/test_logits_processor.py b/tests/model_executor/test_logits_processor.py deleted file mode 100644 index 532ebba038d38..0000000000000 --- a/tests/model_executor/test_logits_processor.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import patch - -import pytest -import torch - -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available - - -class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size, scale=scale) - self.fake_logits = fake_logits.clone() - - def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) - logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - return input_tensor, fake_logits, logits_processor - - -RANDOM_SEEDS = list(range(128)) -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_logits_processors(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - seq_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - - assert torch.isinf(logits_processor_output[:, 0]).all() - - fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 0ade75b7e6228..489ac1e6475b9 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -5,8 +5,12 @@ import os import pytest -from vllm.model_executor.layers.pooler import (CLSPool, DispatchPooler, - MeanPool, PoolingType) +from vllm.model_executor.layers.pooler import ( + CLSPool, + DispatchPooler, + MeanPool, + PoolingType, +) from vllm.model_executor.models.bert import BertEmbeddingModel from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.platforms import current_platform @@ -15,25 +19,28 @@ MAX_MODEL_LEN = 128 MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") REVISION = os.environ.get("REVISION", "main") -MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", - "intfloat/multilingual-e5-base") +MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", "intfloat/multilingual-e5-base") REVISION_ROBERTA = os.environ.get("REVISION", "main") -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. 
""" # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_name=MODEL_NAME, - revision=REVISION, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME, + revision=REVISION, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.llm.llm_engine.model_config model_tokenizer = vllm_model.llm.llm_engine.tokenizer @@ -47,8 +54,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): assert model_config.pooler_config.normalize # asserts on the tokenizer loaded - assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" - assert model_tokenizer.tokenizer.model_max_length == 512 + assert model_config.tokenizer == "BAAI/bge-base-en-v1.5" + assert model_tokenizer.model_max_length == 512 def check_model(model): assert isinstance(model, BertEmbeddingModel) @@ -60,20 +67,24 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. """ # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_name=MODEL_NAME_ROBERTA, - revision=REVISION_ROBERTA, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME_ROBERTA, + revision=REVISION_ROBERTA, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.llm.llm_engine.model_config model_tokenizer = vllm_model.llm.llm_engine.tokenizer @@ -87,22 +98,22 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): assert model_config.pooler_config.normalize # asserts on the tokenizer loaded - assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-base" - assert model_tokenizer.tokenizer.model_max_length == 512 + assert model_config.tokenizer == "intfloat/multilingual-e5-base" + assert model_tokenizer.model_max_length == 512 def check_model(model): assert isinstance(model, RobertaEmbeddingModel) assert isinstance(pooler := model.pooler, DispatchPooler) - assert isinstance(pooler.poolers_by_task["embed"].pooling, - MeanPool) + assert isinstance(pooler.poolers_by_task["embed"].pooling, MeanPool) vllm_model.apply_model(check_model) assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test loading roberta-base model with no lm_head. 
@@ -110,14 +121,14 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") model_name = "FacebookAI/roberta-base" - with vllm_runner(model_name=model_name, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=model_name, dtype="float16", max_model_len=MAX_MODEL_LEN + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) - model_tokenizer = vllm_model.llm.llm_engine.tokenizer - assert model_tokenizer.tokenizer_id == model_name + assert vllm_model.llm.llm_engine.model_config.tokenizer == model_name def check_model(model): assert isinstance(model, RobertaEmbeddingModel) diff --git a/tests/model_executor/test_weight_utils.py b/tests/model_executor/test_weight_utils.py index df625b8d60049..6dc120ddbac9a 100644 --- a/tests/model_executor/test_weight_utils.py +++ b/tests/model_executor/test_weight_utils.py @@ -9,23 +9,24 @@ import pytest from huggingface_hub.utils import LocalEntryNotFoundError from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, enable_hf_transfer) + download_weights_from_hf, + enable_hf_transfer, +) def test_hf_transfer_auto_activation(): if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: # in case it is already set, we can't test the auto activation - pytest.skip( - "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") + pytest.skip("HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") enable_hf_transfer() try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + HF_TRANSFER_ACTIVE = True except ImportError: HF_TRANSFER_ACTIVE = False - assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANSFER_ACTIVE) + assert huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == HF_TRANSFER_ACTIVE def test_download_weights_from_hf(): @@ -34,22 +35,30 @@ def test_download_weights_from_hf(): # if offline is set and model is not cached huggingface_hub.constants.HF_HUB_OFFLINE = True with pytest.raises(LocalEntryNotFoundError): - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # download the model huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # now it should work offline huggingface_hub.constants.HF_HUB_OFFLINE = True - assert download_weights_from_hf( - "facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) is not None + assert ( + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) + is not None + ) if __name__ == "__main__": diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py deleted file mode 100644 index b4c771840196c..0000000000000 --- a/tests/models/language/generation/test_bart.py +++ /dev/null @@ -1,220 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project -from typing import Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs - -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) -from ....utils import multi_gpu_test -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "</s>" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "<s>" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts: list[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. - - Arguments: - - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test - - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be <BOS> if the prompt does not already contain - <BOS> (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append <decoder-start-token> to the beginning, yielding - [<decoder-start-token>], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force <BOS> to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder- - start-token to the beginning, yielding [<decoder-start-token><BOS>], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial <BOS> than the vLLM generated tokens, - because vLLM's <BOS> token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. 
- - One option is to disable the logit processor feature that forces the - <BOS> token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. - ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-existing - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) - - -@pytest.mark.parametrize( - "model", - [ - pytest.param("facebook/bart-base", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("facebook/bart-large-cnn"), - ], -) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["float"]) 
-@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) -def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 57382914bfea8..b161cc7153b8f 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from typing import Optional import pytest @@ -13,10 +12,11 @@ from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close # These have unsupported head_dim for FA. We do not -# not have a clean way to fall back, so we fail with +# have a clean way to fall back, so we fail with # a clear msg when it happens. # https://github.com/vllm-project/vllm/issues/14524 -REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] +# NOTE(woosuk): Skipping these tests until V1 supports them. +# REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] # This list contains the model that are using AITER kernel. # Skip model that are not using AITER tests. @@ -39,7 +39,7 @@ AITER_MODEL_LIST = [ [ pytest.param( "bigscience/bloom-560m", # bloom - testing alibi slopes - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "openai-community/gpt2", # gpt2 @@ -50,7 +50,11 @@ AITER_MODEL_LIST = [ pytest.param("EleutherAI/pythia-70m"), # gpt_neox pytest.param( "google/gemma-1.1-2b-it", # gemma - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, + ], ), pytest.param( "zai-org/chatglm3-6b", # chatglm (text-only) @@ -62,8 +66,7 @@ AITER_MODEL_LIST = [ pytest.param( "openbmb/MiniCPM3-4B", # fused_moe not supported on CPU - marks=[pytest.mark.core_model, - large_gpu_mark(min_gb=32)], + marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)], ), pytest.param( "facebook/opt-125m", # opt @@ -71,14 +74,18 @@ AITER_MODEL_LIST = [ ), pytest.param( "microsoft/phi-2", # phi - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "Qwen/Qwen-7B-Chat", # qwen (text-only) ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, + ], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -92,23 +99,31 @@ AITER_MODEL_LIST = [ pytest.param( "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], - ) - ]) + ), + pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_models(hf_runner, vllm_runner, 
example_prompts, model: str, - max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, - monkeypatch) -> None: - + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + num_logprobs: int, + use_rocm_aiter: bool, + use_prompt_embeds: bool, + monkeypatch, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - if model in REQUIRES_V0: - monkeypatch.setenv("VLLM_USE_V1", "0") - if use_rocm_aiter and (model in AITER_MODEL_LIST): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") elif use_rocm_aiter and model not in AITER_MODEL_LIST: @@ -118,38 +133,39 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0" - with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds - else None) + prompt_embeds: Optional[list[torch.Tensor]] = [] if use_prompt_embeds else None prompt_token_ids = [] for prompt in example_prompts: - token_ids = hf_model.tokenizer(prompt, - return_tensors="pt").input_ids.to( - hf_model.model.device) + token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids.to( + hf_model.model.device + ) prompt_token_ids.append(token_ids) if prompt_embeds is not None: - prompt_embeds.append(hf_model.model.get_input_embeddings()( - token_ids).squeeze(0)) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0) + ) with vllm_runner( - model, - tokenizer_name=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - trust_remote_code=model_info.trust_remote_code, - max_num_seqs=2, - enable_prompt_embeds=use_prompt_embeds, + model, + tokenizer_name=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + max_num_seqs=2, + enable_prompt_embeds=use_prompt_embeds, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) if prompt_embeds is not None: vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( - prompt_embeds, max_tokens, num_logprobs) + prompt_embeds, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 60a4bc14be882..246b893be315d 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -11,17 +11,17 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner( - model, - load_format="dummy", + model, + load_format="dummy", ) as llm: if model == "google/gemma-3-4b-it": normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.language_model.model. 
- normalizer.cpu().item()) + lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501 + ) config = llm.llm.llm_engine.model_config.hf_config.text_config else: normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.model.normalizer.cpu( - ).item()) + lambda self: self.model_runner.model.model.normalizer.cpu().item() + ) config = llm.llm.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/language/generation/test_granite.py b/tests/models/language/generation/test_granite.py index 2a39f78a708ee..e569e75ff3a82 100644 --- a/tests/models/language/generation/test_granite.py +++ b/tests/models/language/generation/test_granite.py @@ -26,11 +26,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2055c44c83cda..abedd15b0d7eb 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable + import pytest from tests.models.registry import HF_EXAMPLE_MODELS @@ -20,49 +22,31 @@ pytestmark = pytest.mark.hybrid_model SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - "yujiepan/mamba2-codestral-v0.1-tiny-random", + # mamba2-codestral in transformers is broken pending: + # https://github.com/huggingface/transformers/pull/40861 + # "yujiepan/mamba2-codestral-v0.1-tiny-random", ] HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # skipping until vLLM implementation issues are resolved - # "pfnet/plamo-2-1b", - "Zyphra/Zamba2-1.2B-instruct", - "hmellor/tiny-random-BambaForCausalLM", - "ibm-granite/granite-4.0-tiny-preview", - "tiiuae/Falcon-H1-0.5B-Base", - "LiquidAI/LFM2-1.2B", -] - -HF_UNSUPPORTED_MODELS = [ - # The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test - # doesn't compare vLLM output with HF output. 
- # See https://github.com/huggingface/transformers/pull/35943 - "yujiepan/mamba2-codestral-v0.1-tiny-random", - # transformers 4.55 is still producing garbage for this model - # TODO(tdoublep): follow-up on transformers side - "ibm-granite/granite-4.0-tiny-preview" -] - -V1_SUPPORTED_MODELS = [ - "state-spaces/mamba-130m-hf", - "ai21labs/Jamba-tiny-dev", - "yujiepan/mamba2-codestral-v0.1-tiny-random", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", "tiiuae/Falcon-H1-0.5B-Base", "LiquidAI/LFM2-1.2B", + "tiny-random/qwen3-next-moe", ] FULL_CUDA_GRAPH_MODELS = [ "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", ] -V0_UNSUPPORTED_MODELS = [ - "LiquidAI/LFM2-1.2B", +FP32_STATE_MODELS = [ + "state-spaces/mamba-130m-hf", + "Zyphra/Zamba2-1.2B-instruct", ] # Avoid OOM @@ -81,66 +65,32 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") - hf_version_check = model_info.check_transformers_version( - on_fail="return") + model_info.check_transformers_version(on_fail="skip") except ValueError: - hf_version_check = None - - if hf_version_check is not None: - print(f"Skipping transformers comparison because: {hf_version_check}") + pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None - - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None - - if model in V1_SUPPORTED_MODELS: - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v1_outputs = None - - if hf_outputs is not None and vllm_v0_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs ) - if model in V1_SUPPORTED_MODELS: - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs - assert ref_outputs is not None - check_logprobs_close( - outputs_0_lst=ref_outputs, - outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", - name_1="vllm-v1", + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) -@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) + +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_batching( @@ -150,10 +100,6 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: - if model in 
V0_UNSUPPORTED_MODELS: - pytest.skip( - f"Unsupported V0 Engine. Skipping `test_batching` on {model}.") - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -164,13 +110,14 @@ def test_batching( for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: - single_output, = vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs) + (single_output,) = vllm_model.generate_greedy_logprobs( + [prompt], max_tokens, num_logprobs + ) for_loop_outputs.append(single_output) batched_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=for_loop_outputs, @@ -180,42 +127,6 @@ def test_batching( ) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -def test_chunked_prefill( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, - num_logprobs: int, - chunked_prefill_token_size: int, -) -> None: - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, num_logprobs) - - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: - non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - check_logprobs_close( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [10]) def test_chunked_prefill_with_parallel_sampling( @@ -225,8 +136,8 @@ def test_chunked_prefill_with_parallel_sampling( max_tokens: int, ) -> None: """ - Tests chunked prefill in conjunction with n > 1. - + Tests chunked prefill in conjunction with n > 1. + In this case, prefill is populated with decoding tokens and we test that it doesn't fail. @@ -234,16 +145,13 @@ def test_chunked_prefill_with_parallel_sampling( decoding steps inside a chunked prefill forward pass (where we have both prefill and decode together) """ - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) + sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens) with vllm_runner( - model, - enable_chunked_prefill=True, - # forces prefill chunks with decoding - max_num_batched_tokens=MAX_NUM_SEQS * 3, - max_num_seqs=MAX_NUM_SEQS, + model, + enable_chunked_prefill=True, + # forces prefill chunks with decoding + max_num_batched_tokens=MAX_NUM_SEQS * 3, + max_num_seqs=MAX_NUM_SEQS, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -261,10 +169,8 @@ def test_mamba_cache_cg_padding( batch size. If it's not, a torch RuntimeError will be raised because tensor dimensions aren't compatible. 
""" - vllm_config = EngineArgs(model=model, - trust_remote_code=True).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph( - len(example_prompts)): + vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() + while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): example_prompts.append(example_prompts[0]) try: @@ -274,35 +180,8 @@ def test_mamba_cache_cg_padding( pytest.fail( "Couldn't run batch size which is not equal to a Cuda Graph " "captured batch size. " - "Could be related to mamba cache not padded correctly") - - -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, -) -> None: - """ - Tests that outputs are identical with and w/o preemptions (recompute). - """ - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.llm.llm_engine.scheduler[0] - scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - scheduler.ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) + "Could be related to mamba cache not padded correctly" + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -317,15 +196,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( finished_requests_ids is larger than the maximum mamba block capacity. This could generally happen due to the fact that hybrid does support - statelessness mechanism where it can cleanup new incoming requests in + statelessness mechanism where it can clean up new incoming requests in a single step. """ try: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") + pytest.fail( + "Hybrid inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily " + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -334,19 +215,21 @@ def test_state_cleanup( example_prompts, model: str, ) -> None: - """ + """ This test is for verifying that the Hybrid state is cleaned up between steps. - - If its not cleaned, an error would be expected. + + If it's not cleaned, an error would be expected. 
""" try: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") + pytest.fail( + "Hybrid inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids" + ) @multi_gpu_test(num_gpus=2) @@ -360,15 +243,19 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=2) as vllm_model: + with vllm_runner( + model, tensor_parallel_size=1, max_num_seqs=MAX_NUM_SEQS + ) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=2) as vllm_model: + with vllm_runner( + model, tensor_parallel_size=2, max_num_seqs=MAX_NUM_SEQS + ) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=vllm_outputs_tp_1, @@ -390,7 +277,6 @@ def test_full_cuda_graph( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -399,53 +285,30 @@ def test_full_cuda_graph( pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None - - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - compilation_config={'full_cuda_graph': True}, - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - if hf_outputs is not None and vllm_v0_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs + ) + + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs ) - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs - assert ref_outputs is not None check_logprobs_close( - outputs_0_lst=ref_outputs, - outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", - name_1="vllm-v1", + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) -@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("model", FP32_STATE_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_fp32_state( +@pytest.mark.parametrize( + "cache_dtype_param", ["mamba_ssm_cache_dtype", 
"mamba_cache_dtype"] +) +def test_fp32_cache_state( hf_runner, vllm_runner, example_prompts, @@ -453,8 +316,8 @@ def test_fp32_state( model: str, max_tokens: int, num_logprobs: int, + cache_dtype_param: str, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -463,42 +326,434 @@ def test_fp32_state( pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None - - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32") as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32", - enable_prefix_caching=False) as vllm_model: - vllm_v1_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - if hf_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs + ) + + with vllm_runner( + model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"} + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs ) - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs check_logprobs_close( - outputs_0_lst=ref_outputs, - outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", - name_1="vllm-v1", + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) + + +# Helper functions for the APC tests +def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): + return { + "model_name": model, + "enable_prefix_caching": False, + "max_model_len": max_model_len, + "tensor_parallel_size": tensor_parallel_size, + "gpu_memory_utilization": 0.4, + } + + +def _get_vLLM_output( + vllm_runner, + kwargs, + prompts, + max_tokens, + num_logprobs, + num_repetitions=1, + vllm_model=None, +): + outs = [] + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + if num_logprobs < 0: + vllm_output = vllm_model.generate_greedy(prompts, max_tokens) + else: + vllm_output = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) + outs.append(vllm_output) + + return outs, vllm_model + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + 
model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. + generated_prompts = [MULTIPLE * example_prompts[0]] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
This custom prompt is used, as it causes the most issues + generated_prompts = ["The president of the United States is " * MULTIPLE] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]: + vllm_runner_kwargs["max_num_batched_tokens"] = ( + mamba_block_size_multiplier * mamba_block_size - offsets + ) + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_all_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
This custom prompt is used, as it causes the most issues + prompt_text = "The president of the United States is " + prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] + generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]: + vllm_runner_kwargs["max_num_batched_tokens"] = ( + mamba_block_size_multiplier * mamba_block_size - offsets + ) + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_partial_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + # Cache only part of all the prompts + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_partial_cache, vllm_model = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs + ) + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0][:3], + outputs_1_lst=vllm_outputs_partial_cache[0], + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + vllm_model=vllm_model, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) diff --git a/tests/models/language/generation/test_mbart.py b/tests/models/language/generation/test_mbart.py deleted file mode 100644 index 854a72713943b..0000000000000 --- a/tests/models/language/generation/test_mbart.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs - -from ....conftest import DecoderPromptType, HfRunner, VllmRunner -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - hf_output_str = output_str + "</s>" - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts: list[dict[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM mBART model by validating it against HuggingFace (HF). 
- (Docstring content is omitted for brevity) - ''' - - vllm_prompts = prompts - if decoder_prompt_type == DecoderPromptType.NONE: - vllm_prompts = [{ - "encoder_prompt": p['encoder_prompt'], - "decoder_prompt": "" - } for p in prompts] - - vllm_kwargs = { - "hf_overrides": { - "architectures": ["MBartForConditionalGeneration"] - } - } - - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - **vllm_kwargs) as vllm_model: # type: ignore - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - vllm_prompts, max_tokens, num_logprobs) - - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_kwargs["decoder_start_token_id"] = ( - hf_model.tokenizer.lang_code_to_id["ro_RO"]) - - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, # HF runner still uses the original prompts - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = 0 - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) - - -@pytest.mark.parametrize( - "model", - [pytest.param("facebook/mbart-large-en-ro")], -) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index af51a60edfd62..0ae83ec16020a 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -6,7 +6,9 @@ import json import pytest from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( - MistralToolCall, MistralToolParser) + MistralToolCall, + MistralToolParser, +) from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -20,7 +22,7 @@ MISTRAL_FORMAT_MODELS = [ "mistralai/Mistral-7B-Instruct-v0.3", # uses the v3-Tekken tokenizer "mistralai/Ministral-8B-Instruct-2410", - # Mistral-Nemo is to big for CI, but passes locally + # Mistral-Nemo is too big for CI, but passes locally # "mistralai/Mistral-Nemo-Instruct-2407" ] @@ -33,136 +35,118 @@ SYMBOLIC_LANG_PROMPTS = [ ] # for function calling -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 
'San Francisco'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. " + "'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that " + "the city is in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, }, -}, { - "type": "function", - "function": { - "name": "rewrite", - "description": "Rewrites text", - "parameters": { - "type": "object", - "required": [], - "properties": { - "text": { - "type": "string", - "description": "The input text to rewrite." - } - } - } - } -}] + { + "type": "function", + "function": { + "name": "rewrite", + "description": "Rewrites text", + "parameters": { + "type": "object", + "required": [], + "properties": { + "text": { + "type": "string", + "description": "The input text to rewrite.", + } + }, + }, + }, + }, +] MSGS = [ + {"role": "system", "content": "You are an assistant."}, { - "role": "system", - "content": "You are an assistant." - }, - { - "role": - "user", - "content": - "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa - }, - { - "role": - "assistant", - "content": - "", - "tool_calls": [{ - "id": "bbc5b7ede", - "type": "function", - "function": { - "name": - "rewrite", - "arguments": - '{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa - } - }] - }, - { - "role": "tool", - "content": - "{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa - "tool_call_id": "bbc5b7ede", - "name": "rewrite" + "role": "user", + "content": "Could you please rewrite the below article? \n\n My English needs " + "improvving, maybe I make errors.", }, { "role": "assistant", - "content": "---\n\nMy English needs improving, maybe I make errors" + "content": "", + "tool_calls": [ + { + "id": "bbc5b7ede", + "type": "function", + "function": { + "name": "rewrite", + "arguments": '{"text":"My English needs improvving, maybe ' + 'I make errors."}', + }, + } + ], }, { - "role": - "user", - "content": ("Can you tell me what the temperate" - " will be in Dallas, in fahrenheit?") - } + "role": "tool", + "content": '{"action":"rewrite","outcome":"My English needs improving, maybe ' + 'I make errors."}', + "tool_call_id": "bbc5b7ede", + "name": "rewrite", + }, + { + "role": "assistant", + "content": "---\n\nMy English needs improving, maybe I make errors", + }, + { + "role": "user", + "content": ( + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
+ ), + }, ] SAMPLE_JSON_SCHEMA = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + "minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -170,17 +154,25 @@ SAMPLE_JSON_SCHEMA = { @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: + with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -194,27 +186,35 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", ) as mistral_format_model: mistral_format_outputs = mistral_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="auto", - load_format="safetensors", - config_format="hf", + model, + dtype=dtype, + tokenizer_mode="auto", + load_format="safetensors", + config_format="hf", ) as hf_format_model: hf_format_outputs = hf_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_format_outputs, @@ -226,34 +226,35 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def 
test_mistral_symbolic_languages(vllm_runner, model: str, - dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - max_model_len=8192, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: +def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None: + with vllm_runner( + model, + dtype=dtype, + max_model_len=8192, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: for prompt in SYMBOLIC_LANG_PROMPTS: msg = {"role": "user", "content": prompt} - outputs = vllm_model.llm.chat([msg], - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat([msg], sampling_params=SAMPLING_PARAMS) assert "�" not in outputs[0].outputs[0].text.strip() @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: - + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: msgs = copy.deepcopy(MSGS) - outputs = vllm_model.llm.chat(msgs, - tools=TOOLS, - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat( + msgs, tools=TOOLS, sampling_params=SAMPLING_PARAMS + ) tokenizer = vllm_model.llm.get_tokenizer() tool_parser = MistralToolParser(tokenizer) @@ -265,15 +266,16 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: assert parsed_message.tools_called assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id) - assert parsed_message.tool_calls[ - 0].function.name == "get_current_weather" - assert parsed_message.tool_calls[ - 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa + assert parsed_message.tool_calls[0].function.name == "get_current_weather" + assert ( + parsed_message.tool_calls[0].function.arguments + == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' + ) # noqa assert parsed_message.content is None def test_mistral_function_call_nested_json(): - """Ensure that the function-name regex captures the entire outer-most + """Ensure that the function-name regex captures the entire outermost JSON block, including nested braces.""" # Create a minimal stub tokenizer that provides the few attributes the @@ -297,17 +299,10 @@ def test_mistral_function_call_nested_json(): "city": "Dallas", "state": "TX", "unit": "fahrenheit", - "sub_dict": { - "foo": "bar", - "inner": { - "x": 1, - "y": 2 - } - }, + "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}}, } - model_output = ( - f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}") + model_output = f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}" parsed = parser.extract_tool_calls(model_output, None) diff --git a/tests/models/language/generation/test_phimoe.py b/tests/models/language/generation/test_phimoe.py index 6c9cc2821c30f..e640655784ccb 100644 --- a/tests/models/language/generation/test_phimoe.py +++ b/tests/models/language/generation/test_phimoe.py @@ -15,62 +15,56 @@ MODELS = [ def test_phimoe_routing_function(): from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { 0: { - "hidden_states": - torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - "gating_output": - torch.tensor([0.1, 0.2, 0.3, 0.4], 
- dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.1, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, }, 1: { - "hidden_states": - torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - "gating_output": - torch.tensor([0.4, 0.2, 0.3, 0.4], - dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, - } + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.4, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, + }, } ground_truth = { 0: { - "topk_weights": - torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + "topk_weights": torch.tensor( + [1.0, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([3, 2], dtype=torch.long, requires_grad=False), }, 1: { - "topk_weights": - torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([0, 3], dtype=torch.long, requires_grad=False), - } + "topk_weights": torch.tensor( + [0.5, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + }, } for test_id in test_case: topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) - assert torch.allclose(topk_weights, - ground_truth[test_id]["topk_weights"]) + assert torch.allclose(topk_weights, ground_truth[test_id]["topk_weights"]) assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="This test takes a lot time to run on CPU, " - "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif( + condition=current_platform.is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.", +) @large_gpu_test(min_gb=80) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -87,11 +81,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/mistral_tool_use/__init__.py b/tests/models/language/generation_ppl_test/__init__.py similarity index 100% rename from tests/mistral_tool_use/__init__.py rename to tests/models/language/generation_ppl_test/__init__.py diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py new file mode 100644 index 0000000000000..43f6066b1c85e --- /dev/null +++ b/tests/models/language/generation_ppl_test/ppl_utils.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from 
https://huggingface.co/docs/transformers/perplexity +from typing import Optional, cast + +import pytest +import torch +from datasets import load_dataset + +import tests.ci_envs as ci_envs +from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs +from vllm.logprobs import Logprob + +# See #24485 +PPL_TOL = 0.01 +MAX_LENGTH = 1024 + + +@torch.inference_mode +def wikitext_ppl_test( + hf_runner, + vllm_runner, + model_info: GenerateModelInfo, + max_length=MAX_LENGTH, + vllm_extra_kwargs=None, + atol=PPL_TOL, +): + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: + pytest.skip("Skipping test.") + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + gpu_memory_utilization=0.7, + max_model_len=max_length, + max_num_seqs=1, + **vllm_extra_kwargs, + ) as vllm_model: + # Use max_num_seqs=1 to avoid OOM, + # and avoid batch different requests together. + + model_config = vllm_model.llm.llm_engine.model_config + + # Confirm whether vllm is using the correct architecture + if model_info.architecture: + assert model_info.architecture in model_config.architectures + + max_length = min(model_config.max_model_len - 1, max_length) + stride = max_length + + tokenizer = vllm_model.llm.get_tokenizer() + tokens = tokenizer.encode("\n\n".join(dataset["text"])) + n_tokens = len(tokens) + + chunks = [] + for begin_loc in range(0, n_tokens, stride): + end_loc = min(begin_loc + max_length, n_tokens) + chunks.append(tokens[begin_loc:end_loc]) + + outputs = vllm_model.generate_greedy_logprobs( + prompts=chunks, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0, + use_tqdm=False, + ) + nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu") + n_tokens = 0 + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + neg_log_likelihood = -torch.tensor( + token_log_probs, dtype=torch.float32, device="cpu" + ).sum() + nll_sum += neg_log_likelihood + n_tokens += len(token_log_probs) + vllm_ppl = float(torch.exp(nll_sum / n_tokens)) + vllm_dtype = model_config.dtype + head_dtype = model_config.head_dtype + + # Accelerate ppl test by setting Transformers ppl score to a constant + if model_info.hf_ppl is None: + with hf_runner( + model_info.name, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + ) as hf_model: + nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu") + n_tokens = 0 + for chunk in chunks: + inputs = hf_model.wrap_device({"input_ids": 
torch.tensor([chunk])}) + input_ids = inputs["input_ids"] + outputs = hf_model.model(input_ids, labels=input_ids) + neg_log_likelihood = outputs.loss + + neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu() + + num_loss_tokens = len(chunk) - 1 + nll_sum += neg_log_likelihood * num_loss_tokens + n_tokens += num_loss_tokens + + hf_ppl = float(torch.exp(nll_sum / n_tokens)) + hf_dtype = next(hf_model.model.parameters()).dtype + else: + hf_ppl = model_info.hf_ppl + hf_dtype = "Constant" + + differ = (vllm_ppl - hf_ppl) / hf_ppl + print("Model:", model_info.name) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl) + print("Transformers:", hf_dtype, hf_ppl) + print("Difference (%):", differ * 100) + + # PPL the smaller, the better + # We are not concerned that the vllm PPL is less than Transformers, + # so we only perform one-sided testing. + assert differ < atol diff --git a/tests/models/language/generation_ppl_test/test_gemma.py b/tests/models/language/generation_ppl_test/test_gemma.py new file mode 100644 index 0000000000000..5324de143d674 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_gemma.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [ + GenerateModelInfo("google/gemma-2b"), + GenerateModelInfo("google/gemma-2-2b"), + GenerateModelInfo("google/gemma-3-4b-it"), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/generation_ppl_test/test_gpt.py b/tests/models/language/generation_ppl_test/test_gpt.py new file mode 100644 index 0000000000000..f3f9e55a24234 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_gpt.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [GenerateModelInfo("openai-community/gpt2-large")] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/generation_ppl_test/test_qwen.py b/tests/models/language/generation_ppl_test/test_qwen.py new file mode 100644 index 0000000000000..0d3127cbaac47 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_qwen.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [ + GenerateModelInfo("Qwen/Qwen3-0.6B"), + GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"), + # transformers: + # Loading a GPTQ quantized model requires optimum, gptqmodel + # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 61c5fcab4f8a4..261ab80ae86bc 100644 --- 
a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -6,8 +6,7 @@ from typing import Optional import pytest from tests.conftest import HfRunner -from tests.models.utils import (EmbedModelInfo, check_embeddings_close, - matryoshka_fy) +from tests.models.utils import EmbedModelInfo, check_embeddings_close, matryoshka_fy def run_embedding_correctness_test( @@ -29,16 +28,15 @@ def run_embedding_correctness_test( ) -def correctness_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - example_prompts, - vllm_extra_kwargs=None, - hf_model_callback=None): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. - pytest.skip("Skipping test.") +def correctness_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + example_prompts, + vllm_extra_kwargs=None, + hf_model_callback=None, +): + pytest.skip("Debug only, ci prefers to use mteb test.") # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" @@ -51,18 +49,19 @@ def correctness_test_embed_models(hf_runner, vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs["dtype"] = model_info.dtype - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( - model_info.name, - dtype="float32", - is_sentence_transformer=True, + model_info.name, + dtype=model_info.hf_dtype, + is_sentence_transformer=True, ) as hf_model: - if hf_model_callback is not None: hf_model_callback(hf_model) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py deleted file mode 100644 index 4a1f8a53d024c..0000000000000 --- a/tests/models/language/pooling/mteb_utils.py +++ /dev/null @@ -1,315 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import tempfile -from collections.abc import Sequence -from typing import Optional - -import mteb -import numpy as np -import pytest -import requests - -from tests.models.utils import EmbedModelInfo, RerankModelInfo - -# Most embedding models on the STS12 task (See #17175): -# - Model implementation and minor changes in tensor dtype -# results in differences less than 1e-4 -# - Different model results in differences more than 1e-3 -# 1e-4 is a good tolerance threshold -MTEB_EMBED_TASKS = ["STS12"] -MTEB_EMBED_TOL = 0.02 - -# See #19344 -MTEB_RERANK_TASKS = ["NFCorpus"] -MTEB_RERANK_LANGS = ["en"] -MTEB_RERANK_TOL = 2e-3 - - -class VllmMtebEncoder(mteb.Encoder): - - def __init__(self, vllm_model): - super().__init__() - self.llm = vllm_model - self.rng = np.random.default_rng(seed=42) - - def encode( - self, - sentences: Sequence[str], - *args, - **kwargs, - ) -> np.ndarray: - # Hoping to discover potential scheduling - # issues by randomizing the order. 
- r = self.rng.permutation(len(sentences)) - sentences = [sentences[i] for i in r] - outputs = self.llm.embed(sentences, use_tqdm=False) - embeds = np.array(outputs) - embeds = embeds[np.argsort(r)] - return embeds - - def predict( - self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt - *args, - **kwargs, - ) -> np.ndarray: - r = self.rng.permutation(len(sentences)) - sentences = [sentences[i] for i in r] - - queries = [s[0] for s in sentences] - corpus = [s[1] for s in sentences] - - outputs = self.llm.score(queries, - corpus, - truncate_prompt_tokens=-1, - use_tqdm=False) - scores = np.array(outputs) - scores = scores[np.argsort(r)] - return scores - - -class OpenAIClientMtebEncoder(mteb.Encoder): - - def __init__(self, model_name: str, client): - super().__init__() - self.model_name = model_name - self.client = client - self.rng = np.random.default_rng(seed=42) - - def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: - # Hoping to discover potential scheduling - # issues by randomizing the order. - r = self.rng.permutation(len(sentences)) - sentences = [sentences[i] for i in r] - - embeddings = self.client.embeddings.create(model=self.model_name, - input=sentences) - outputs = [d.embedding for d in embeddings.data] - embeds = np.array(outputs) - embeds = embeds[np.argsort(r)] - return embeds - - -class ScoreClientMtebEncoder(mteb.Encoder): - - def __init__(self, model_name: str, url): - super().__init__() - self.model_name = model_name - self.url = url - self.rng = np.random.default_rng(seed=42) - - def predict( - self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt - *args, - **kwargs, - ) -> np.ndarray: - r = self.rng.permutation(len(sentences)) - sentences = [sentences[i] for i in r] - - outputs = [] - for query, corpus, prompt in sentences: - outputs.append(self.get_score(query, corpus)) - - scores = np.array(outputs) - scores = scores[np.argsort(r)] - return scores - - def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "text_1": query, - "text_2": corpus, - "truncate_prompt_tokens": -1, - }).json() - return response['data'][0]["score"] - - -class RerankClientMtebEncoder(ScoreClientMtebEncoder): - - def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "query": query, - "documents": [corpus], - "truncate_prompt_tokens": -1, - }).json() - return response['results'][0]["relevance_score"] - - -def run_mteb_embed_task(encoder, tasks): - tasks = mteb.get_tasks(tasks=tasks) - evaluation = mteb.MTEB(tasks=tasks) - results = evaluation.run( - encoder, - verbosity=0, - output_folder=None, - encode_kwargs={ - "show_progress_bar": False, - }, - ) - - main_score = results[0].scores["test"][0]["main_score"] - return main_score - - -def mteb_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None, - atol=MTEB_RERANK_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. 
- pytest.skip("Skipping test.") - - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype - - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - enforce_eager=True, - **vllm_extra_kwargs) as vllm_model: - - model_config = vllm_model.llm.llm_engine.model_config - - if model_info.architecture: - assert model_info.architecture in model_config.architectures - assert (model_config._model_info.default_pooling_type == - model_info.default_pooling_type) - - vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), - MTEB_EMBED_TASKS) - vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype - - with hf_runner(model_info.name, - is_sentence_transformer=True, - dtype="float32") as hf_model: - - if hf_model_callback is not None: - hf_model_callback(hf_model) - - st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - st_dtype = next(hf_model.model.parameters()).dtype - - print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) - print("SentenceTransformers:", st_dtype, st_main_score) - print("Difference:", st_main_score - vllm_main_score) - - assert st_main_score == pytest.approx(vllm_main_score, abs=atol) - - -def run_mteb_rerank(cross_encoder, tasks, languages): - with tempfile.TemporaryDirectory() as results_folder: - bm25s = mteb.get_model("bm25s") - tasks = mteb.get_tasks(tasks=tasks, languages=languages) - - subset = "default" - eval_splits = ["test"] - - evaluation = mteb.MTEB(tasks=tasks) - evaluation.run( - bm25s, - verbosity=0, - eval_splits=eval_splits, - save_predictions=True, - output_folder=f"{results_folder}/stage1", - encode_kwargs={"show_progress_bar": False}, - ) - - results = evaluation.run( - cross_encoder, - verbosity=0, - eval_splits=eval_splits, - top_k=10, - save_predictions=True, - output_folder=f"{results_folder}/stage2", - previous_results= - f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", - encode_kwargs={"show_progress_bar": False}, - ) - main_score = results[0].scores["test"][0]["main_score"] - return main_score - - -def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): - with hf_runner(model_name, is_cross_encoder=True, - dtype="float32") as hf_model: - - original_predict = hf_model.predict - - def _predict( - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt - *args, - **kwargs, - ): - # vllm and st both remove the prompt, fair comparison. - prompts = [(s[0], s[1]) for s in sentences] - return original_predict(prompts, *args, **kwargs, batch_size=8) - - hf_model.predict = _predict - hf_model.original_predict = original_predict - - if hf_model_callback is not None: - hf_model_callback(hf_model) - - st_main_score = run_mteb_rerank(hf_model, - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) - st_dtype = next(hf_model.model.model.parameters()).dtype - return st_main_score, st_dtype - - -def mteb_test_rerank_models(hf_runner, - vllm_runner, - model_info: RerankModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None, - vllm_mteb_encoder=VllmMtebEncoder, - atol=MTEB_RERANK_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. 
- pytest.skip("Skipping test.") - - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype - - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - max_num_seqs=8, - enforce_eager=True, - **vllm_extra_kwargs) as vllm_model: - - model_config = vllm_model.llm.llm_engine.model_config - - if model_info.architecture: - assert (model_info.architecture in model_config.architectures) - assert model_config.hf_config.num_labels == 1 - assert (model_config._model_info.default_pooling_type == - model_info.default_pooling_type) - - vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) - vllm_dtype = model_config.dtype - - st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, hf_model_callback) - - print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) - print("SentenceTransformers:", st_dtype, st_main_score) - print("Difference:", st_main_score - vllm_main_score) - - assert st_main_score == pytest.approx(vllm_main_score, abs=atol) diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index 15e24c59d1dd9..e95119df95c71 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -4,8 +4,7 @@ import pytest import torch from transformers import AutoModelForSequenceClassification -from tests.models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test @pytest.mark.parametrize( @@ -20,28 +19,27 @@ def test_classify_models( model: str, dtype: str, ) -> None: - example_prompts = example_prompts * 2 - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - enable_prefix_caching=True) as vllm_model: + with vllm_runner( + model, max_model_len=512, dtype=dtype, enable_prefix_caching=True + ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) @pytest.mark.parametrize( @@ -59,18 +57,18 @@ def test_embed_models( example_prompts = [str(s).strip() for s in example_prompts] * 2 with vllm_runner( - model, - runner="pooling", - max_model_len=None, - enable_prefix_caching=True, + model, + runner="pooling", + max_model_len=None, + enable_prefix_caching=True, ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( - model, - is_sentence_transformer=True, + model, + is_sentence_transformer=True, ) as hf_model: run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) @@ -81,13 +79,14 @@ def test_embed_models( "intfloat/e5-small", "Alibaba-NLP/gte-Qwen2-1.5B-instruct", 
# is_causal == False "papluca/xlm-roberta-base-language-detection", - ]) + ], +) @pytest.mark.parametrize("dtype", ["half"]) -def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str) -> None: - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - enable_prefix_caching=True) as vllm_model: +def test_non_causal_models( + hf_runner, vllm_runner, example_prompts, model: str, dtype: str +) -> None: + with vllm_runner( + model, max_model_len=512, dtype=dtype, enable_prefix_caching=True + ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert not cache_config.enable_prefix_caching diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py deleted file mode 100644 index 6fbe0e82d7f8a..0000000000000 --- a/tests/models/language/pooling/test_baai.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo) -from .embed_utils import correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models - -MODELS = [ - ########## BertModel - CLSPoolingEmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), - ########## XLMRobertaModel - CLSPoolingEmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - enable_test=True), - ########## Qwen2Model - LASTPoolingEmbedModelInfo("BAAI/bge-code-v1", - architecture="Qwen2Model", - dtype="float32", - enable_test=True), -] - -RERANK_MODELS = [ - ########## XLMRobertaForSequenceClassification - CLSPoolingRerankModelInfo( - "BAAI/bge-reranker-base", - architecture="XLMRobertaForSequenceClassification", - enable_test=True), - CLSPoolingRerankModelInfo( - "BAAI/bge-reranker-large", - architecture="XLMRobertaForSequenceClassification", - enable_test=False), - CLSPoolingRerankModelInfo( - "BAAI/bge-reranker-v2-m3", - architecture="XLMRobertaForSequenceClassification", - enable_test=False) -] - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, 
vllm_runner, model_info) - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: - mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index c71fa96275335..471826f214d0c 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -10,12 +10,17 @@ from vllm.platforms import current_platform @pytest.mark.parametrize( "model", [ - pytest.param("jason9693/Qwen2.5-1.5B-apeach", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "jason9693/Qwen2.5-1.5B-apeach", + marks=[ + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, + ], + ), ], ) -@pytest.mark.parametrize("dtype", - ["half"] if current_platform.is_rocm() else ["float"]) +@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"]) def test_models( hf_runner, vllm_runner, @@ -32,9 +37,9 @@ def test_models( with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) # check logits difference @@ -45,5 +50,6 @@ def test_models( # the tolerance value of 1e-2 is selected based on the # half datatype tests in # tests/models/language/pooling/test_embedding.py - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling/test_cross_encoder.py deleted file mode 100644 index 8c1bc5779b8a1..0000000000000 --- a/tests/models/language/pooling/test_cross_encoder.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - -from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo, - RerankModelInfo) -from .mteb_utils import mteb_test_rerank_models - -RERANK_MODELS = [ - CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - architecture="BertForSequenceClassification"), - LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", - architecture="Qwen3ForSequenceClassification") -] - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: - mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 2dd35c4151580..c9574dca498ee 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -7,15 +7,7 @@ import pytest from vllm.config import PoolerConfig from vllm.platforms import current_platform -from ...utils import 
check_embeddings_close, check_transformers_version - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass +from ...utils import check_embeddings_close @pytest.mark.parametrize( @@ -26,27 +18,34 @@ def v1(run_with_both_engines): # case won't pass because gte-Qwen2-1.5B-instruct will cache custom # model code with bidirectional attention. # [Decoder-only] - pytest.param("BAAI/bge-multilingual-gemma2", - marks=[pytest.mark.core_model]), + pytest.param( + "BAAI/bge-multilingual-gemma2", + marks=[pytest.mark.core_model, pytest.mark.slow_test], + ), pytest.param( "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window - marks=[pytest.mark.core_model]), - # the qwen models interfere with each other (see PR - # https://github.com/vllm-project/vllm/pull/18720). - # To avoid this problem, for now we skip v0 since it will be - # deprecated anyway. - pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", - marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), + marks=[pytest.mark.core_model], + ), + pytest.param( + "ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model] + ), # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param( + "BAAI/bge-base-en-v1.5", + marks=[ + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, + ], + ), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - marks=[pytest.mark.skip_v1]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2", - marks=[pytest.mark.skip_v1]), + pytest.param( + "sentence-transformers/stsb-roberta-base-v2", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) def test_models( @@ -56,9 +55,6 @@ def test_models( model, monkeypatch, ) -> None: - if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model, max_transformers_version="4.53.2") - if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend @@ -66,13 +62,14 @@ def test_models( vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": - vllm_extra_kwargs["override_pooler_config"] = \ - PoolerConfig(pooling_type="MEAN", normalize=False) + vllm_extra_kwargs["pooler_config"] = PoolerConfig( + pooling_type="MEAN", normalize=False + ) max_model_len: Optional[int] = 512 if model in [ - "sentence-transformers/all-MiniLM-L12-v2", - "sentence-transformers/stsb-roberta-base-v2" + "sentence-transformers/all-MiniLM-L12-v2", + "sentence-transformers/stsb-roberta-base-v2", ]: max_model_len = None @@ -87,10 +84,9 @@ def test_models( with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, - runner="pooling", - max_model_len=max_model_len, - **vllm_extra_kwargs) as vllm_model: + with vllm_runner( + model, runner="pooling", max_model_len=max_model_len, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 17a55d916b1ff..14308ac06c03e 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ 
b/tests/models/language/pooling/test_gritlm.py @@ -70,8 +70,9 @@ async def run_client_embeddings( def gritlm_instruction(instruction): - return ("<|user|>\n" + instruction + - "\n<|embed|>\n" if instruction else "<|embed|>\n") + return ( + "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" + ) def get_test_data(): @@ -80,7 +81,8 @@ def get_test_data(): README.md in https://github.com/ContextualAI/gritlm """ q_instruction = gritlm_instruction( - "Given a scientific paper title, retrieve the paper's abstract", ) + "Given a scientific paper title, retrieve the paper's abstract", + ) queries = [ "Bitcoin: A Peer-to-Peer Electronic Cash System", "Generative Representational Instruction Tuning", @@ -114,9 +116,9 @@ def test_gritlm_offline_embedding(vllm_runner): queries, q_instruction, documents, d_instruction = get_test_data() with vllm_runner( - MODEL_NAME, - runner="pooling", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + runner="pooling", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.llm @@ -161,9 +163,9 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" with vllm_runner( - MODEL_NAME, - runner="generate", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + runner="generate", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.llm diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py deleted file mode 100644 index f805a64103c06..0000000000000 --- a/tests/models/language/pooling/test_gte.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any - -import pytest - -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo, check_transformers_version) -from .embed_utils import correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models - -MODELS = [ - ########## BertModel - CLSPoolingEmbedModelInfo("thenlper/gte-large", - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), - ########### NewModel - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - enable_test=True), - ########### Qwen2ForCausalLM - LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - architecture="Qwen2ForCausalLM", - enable_test=True), - ########## ModernBertModel - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - architecture="ModernBertModel", - enable_test=True), - ########## Qwen3ForCausalLM - LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=True), - 
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=False), -] - -RERANK_MODELS = [ - # classifier_pooling: mean - CLSPoolingRerankModelInfo( - "Alibaba-NLP/gte-reranker-modernbert-base", - architecture="ModernBertForSequenceClassification", - enable_test=True), -] - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model_info.name, - max_transformers_version="4.53.2") - - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - mteb_test_embed_models(hf_runner, vllm_runner, model_info, - vllm_extra_kwargs) - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model_info.name, - max_transformers_version="4.53.2") - - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "GteNewModel": - vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} - - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts, vllm_extra_kwargs) - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: - mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling/test_intfloat.py deleted file mode 100644 index 6cae53a660ad8..0000000000000 --- a/tests/models/language/pooling/test_intfloat.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models - -MODELS = [ - ########## BertModel - CLSPoolingEmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), - ########## XLMRobertaModel - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - enable_test=True), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), -] - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, vllm_runner, model_info, atol=0.02) - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) diff --git 
a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py new file mode 100644 index 0000000000000..91be6cd09d33e --- /dev/null +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.config.pooler import PoolerConfig +from vllm.platforms import current_platform + + +def test_idefics_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + with vllm_runner( + model_name="HuggingFaceM4/Idefics3-8B-Llama3", + runner="pooling", + task="classify", + convert="classify", + load_format="dummy", + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16", + ) as vllm_model: + llm = vllm_model.get_llm() + outputs = llm.classify(prompts) + for output in outputs: + assert len(output.outputs.probs) == 2 + + +def update_config(config): + config.text_config.update( + { + "architectures": ["Gemma3ForSequenceClassification"], + "classifier_from_token": ["A", "B", "C", "D", "E"], + "method": "no_post_processing", + "id2label": { + "A": "Chair", + "B": "Couch", + "C": "Table", + "D": "Bed", + "E": "Cupboard", + }, + } + ) + return config + + +def test_gemma_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + messages = [ + { + "role": "system", + "content": """ + You are a helpful assistant. You will be given a product description + which may also include an image. 
Classify the following product into + one of the categories: + + A = chair + B = couch + C = table + D = bed + E = cupboard + + You'll answer with exactly one letter (A, B, C, D, or E).""", + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" + }, + }, + {"type": "text", "text": "A fine 19th century piece of furniture."}, + ], + }, + ] + + with vllm_runner( + model_name="google/gemma-3-4b-it", + runner="pooling", + task="classify", + convert="classify", + load_format="auto", + hf_overrides=update_config, + pooler_config=PoolerConfig(pooling_type="LAST"), + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16", + ) as vllm_model: + llm = vllm_model.get_llm() + prompts = llm.preprocess_chat(messages) + + result = llm.classify(prompts) + assert result[0].outputs.probs[0] > 0.95 + assert all(c < 0.05 for c in result[0].outputs.probs[1:]) diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py index 45366f2094144..472fee71711a6 100644 --- a/tests/models/language/pooling/test_multilabel_classification_support.py +++ b/tests/models/language/pooling/test_multilabel_classification_support.py @@ -20,14 +20,15 @@ def test_classify_models( with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py deleted file mode 100644 index 480bd5e4567cb..0000000000000 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any - -import pytest -import torch - -from tests.conftest import HfRunner - -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo -from .mteb_utils import mteb_test_rerank_models - -RERANK_MODELS = [ - LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=True), - LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", - architecture="Qwen2ForSequenceClassification", - enable_test=False) -] - - -class MxbaiRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: - from transformers import AutoModelForCausalLM, AutoTokenizer - super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') - self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") - self.no_loc = self.tokenizer.convert_tokens_to_ids("0") - - def predict(self, prompts: list[list[str]], *args, 
- **kwargs) -> torch.Tensor: - - def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") - for key in inputs: - inputs[key] = inputs[key].to(self.model.device) - return inputs - - @torch.no_grad() - def compute_logits(inputs): - logits = self.model(**inputs).logits[:, -1, :] - yes_logits = logits[:, self.yes_loc] - no_logits = logits[:, self.no_loc] - logits = yes_logits - no_logits - scores = logits.float().sigmoid() - return scores - - scores = [] - for prompt in prompts: - inputs = process_inputs([prompt]) - score = compute_logits(inputs) - scores.append(score[0].item()) - return torch.Tensor(scores) - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - vllm_extra_kwargs: dict[str, Any] = {} - if model_info.architecture == "Qwen2ForSequenceClassification": - vllm_extra_kwargs["hf_overrides"] = { - "architectures": ["Qwen2ForSequenceClassification"], - "classifier_from_token": ["0", "1"], - "method": "from_2_way_softmax", - } - - mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py deleted file mode 100644 index 2d05958e9bcda..0000000000000 --- a/tests/models/language/pooling/test_nomic.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models - -MODELS = [ - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - enable_test=True) -] - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, vllm_runner, model_info) - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) diff --git a/tests/models/language/pooling/test_nomic_max_model_len.py b/tests/models/language/pooling/test_nomic_max_model_len.py index c34c36fd98150..88f088c603276 100644 --- a/tests/models/language/pooling/test_nomic_max_model_len.py +++ b/tests/models/language/pooling/test_nomic_max_model_len.py @@ -7,10 +7,10 @@ from ...utils import EmbedModelInfo MODELS = [ EmbedModelInfo("nomic-ai/nomic-embed-text-v1"), - #EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), - #EmbedModelInfo("nomic-ai/CodeRankEmbed"), + # EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), + # EmbedModelInfo("nomic-ai/CodeRankEmbed"), EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe"), - 
#EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), + # EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), ] rope_theta = 1000 @@ -21,23 +21,24 @@ max_model_len = int(original_max_position_embeddings * factor) @pytest.mark.parametrize("model_info", MODELS) def test_default(model_info, vllm_runner): - with vllm_runner(model_info.name, runner="pooling", - max_model_len=None) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config if model_info.name == "nomic-ai/nomic-embed-text-v2-moe": # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. assert model_config.max_model_len == 512 else: - assert ( - model_config.max_model_len == original_max_position_embeddings) + assert model_config.max_model_len == original_max_position_embeddings @pytest.mark.parametrize("model_info", MODELS) def test_set_max_model_len_legal(model_info, vllm_runner): # set max_model_len <= 512 - with vllm_runner(model_info.name, runner="pooling", - max_model_len=256) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=256 + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 256 @@ -46,13 +47,12 @@ def test_set_max_model_len_legal(model_info, vllm_runner): # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=1024): + with vllm_runner(model_info.name, runner="pooling", max_model_len=1024): pass else: - with vllm_runner(model_info.name, runner="pooling", - max_model_len=1024) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=1024 + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 1024 @@ -61,17 +61,18 @@ def test_set_max_model_len_legal(model_info, vllm_runner): def test_set_max_model_len_illegal(model_info, vllm_runner): # set max_model_len > 2048 with pytest.raises(ValueError): - with vllm_runner(model_info.name, runner="pooling", - max_model_len=4096): + with vllm_runner(model_info.name, runner="pooling", max_model_len=4096): pass # set max_model_len > 2048 by hf_overrides hf_overrides = {"max_model_len": 4096} with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + hf_overrides=hf_overrides, + ): pass @@ -82,16 +83,14 @@ def test_use_rope_scaling_legal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + "original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + "max_model_len": max_model_len, } - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None, hf_overrides=hf_overrides + ): pass @@ -102,16 +101,17 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings - } + "original_max_position_embeddings": original_max_position_embeddings, + }, } # illegal max_model_len with 
pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=max_model_len + 1, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=max_model_len + 1, + hf_overrides=hf_overrides, + ): pass hf_overrides = { @@ -119,15 +119,16 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + "original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + 1 + "max_model_len": max_model_len + 1, } # illegal max_model_len by hf_overrides with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + hf_overrides=hf_overrides, + ): pass diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py deleted file mode 100644 index 2b1c74652e76f..0000000000000 --- a/tests/models/language/pooling/test_override_pooler_config.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -import torch -import torch.nn.functional as F - -from tests.models.utils import softmax -from vllm.config import PoolerConfig - - -@pytest.mark.parametrize( - "model", - [ - "jason9693/Qwen2.5-1.5B-apeach", - "papluca/xlm-roberta-base-language-detection" - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -def test_classify_models_using_activation( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, -) -> None: - - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - activation=False)) as vllm_model: - wo_activation_out = vllm_model.classify(example_prompts) - - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - activation=True)) as vllm_model: - w_activation_out = vllm_model.classify(example_prompts) - - for wo_activation, w_activation in zip(wo_activation_out, - w_activation_out): - wo_activation = torch.tensor(wo_activation) - w_activation = torch.tensor(w_activation) - - assert not torch.allclose( - wo_activation, w_activation, - atol=1e-2), "override_pooler_config is not working" - assert torch.allclose(softmax(wo_activation), w_activation, - 1e-3 if dtype == "float" else 1e-2) - - -@pytest.mark.parametrize( - "model", - [ - "intfloat/multilingual-e5-small", - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -def test_embed_models_using_normalize( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, -) -> None: - - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig( - normalize=False)) as vllm_model: - wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) - - with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig(normalize=True)) as vllm_model: - w_normalize = torch.tensor(vllm_model.embed(example_prompts)) - - assert not torch.allclose( - wo_normalize, w_normalize, - atol=1e-2), "override_pooler_config normalize is not working" - assert torch.allclose( - F.normalize(wo_normalize, p=2, dim=-1), w_normalize, - atol=1e-2), "w_normal should be close to normal(wo_normal)." 
- - -@pytest.mark.parametrize( - "model", - [ - "internlm/internlm2-1_8b-reward", - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -def test_reward_models_using_softmax( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, -) -> None: - - with vllm_runner( - model, - max_model_len=1024, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: - wo_softmax = vllm_model.encode(example_prompts) - - with vllm_runner( - model, - max_model_len=1024, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: - w_softmax = vllm_model.encode(example_prompts) - - for wo, w in zip(wo_softmax, w_softmax): - wo = torch.tensor(wo) - w = torch.tensor(w) - - assert not torch.allclose( - wo, w, atol=1e-2), "override_pooler_config softmax is not working" - assert torch.allclose( - softmax(wo), w, - atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py new file mode 100644 index 0000000000000..674bf02b7b98b --- /dev/null +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F + +from tests.models.utils import softmax +from vllm.config import PoolerConfig + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach", "papluca/xlm-roberta-base-language-detection"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models_using_activation( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=False), + ) as vllm_model: + wo_activation_out = vllm_model.classify(example_prompts) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=True), + ) as vllm_model: + w_activation_out = vllm_model.classify(example_prompts) + + for wo_activation, w_activation in zip(wo_activation_out, w_activation_out): + wo_activation = torch.tensor(wo_activation) + w_activation = torch.tensor(w_activation) + + assert not torch.allclose(wo_activation, w_activation, atol=1e-2), ( + "pooler_config is not working" + ) + assert torch.allclose( + softmax(wo_activation), w_activation, 1e-3 if dtype == "float" else 1e-2 + ) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embed_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=False), + ) as vllm_model: + wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=True), + ) as vllm_model: + w_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + assert not torch.allclose(wo_normalize, w_normalize, atol=1e-2), ( + "pooler_config normalize is not working" + ) + assert torch.allclose( + F.normalize(wo_normalize, p=2, dim=-1), w_normalize, atol=1e-2 + ), "w_normal should be close to normal(wo_normal)." 
+ + +@pytest.mark.parametrize( + "model", + [ + "internlm/internlm2-1_8b-reward", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_reward_models_using_softmax( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(softmax=False), + ) as vllm_model: + wo_softmax = vllm_model.encode(example_prompts) + + with vllm_runner( + model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) + ) as vllm_model: + w_softmax = vllm_model.encode(example_prompts) + + for wo, w in zip(wo_softmax, w_softmax): + wo = torch.tensor(wo) + w = torch.tensor(w) + + assert not torch.allclose(wo, w, atol=1e-2), ( + "pooler_config softmax is not working" + ) + assert torch.allclose(softmax(wo), w, atol=1e-2), ( + "w_softmax should be close to softmax(wo_softmax)." + ) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py deleted file mode 100644 index 37f5566a330d0..0000000000000 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any - -import pytest -import torch - -from tests.conftest import HfRunner -from tests.utils import multi_gpu_test - -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo -from .mteb_utils import mteb_test_rerank_models - -RERANK_MODELS = [ - LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", - architecture="Qwen3ForSequenceClassification", - enable_test=True), - LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", - architecture="Qwen3ForSequenceClassification", - enable_test=False) -] - - -class Qwen3RerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: - from transformers import AutoModelForCausalLM, AutoTokenizer - super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') - self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") - self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") - - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - - def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") - for key in inputs: - inputs[key] = inputs[key].to(self.model.device) - return inputs - - @torch.no_grad() - def compute_logits(inputs): - batch_scores = self.model(**inputs).logits[:, -1, :] - true_vector = batch_scores[:, self.token_true_id] - false_vector = batch_scores[:, self.token_false_id] - batch_scores = torch.stack([false_vector, true_vector], dim=1) - batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) - scores = batch_scores[:, 1].exp() - return scores - - scores = [] - for prompt in prompts: - inputs = process_inputs([prompt]) - score = compute_logits(inputs) - scores.append(score[0].item()) - return torch.Tensor(scores) - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - - assert model_info.architecture == 
"Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - } - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) - - -@pytest.mark.parametrize("model_info", RERANK_MODELS) -@multi_gpu_test(num_gpus=2) -def test_rerank_models_mteb_tp(vllm_runner, - model_info: RerankModelInfo) -> None: - - assert model_info.architecture == "Qwen3ForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, - "tensor_parallel_size": 2, - } - - mteb_test_rerank_models(Qwen3RerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - atol=1.2e-2) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index beafa0aed9862..46504d025c265 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import pytest import torch @@ -13,22 +12,12 @@ from ....conftest import HfRunner from ...utils import check_transformers_version -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture def math_step_prompts(): # ruff: noqa: E501 data = { - "system": - "Please reason step by step, and put your final answer within \\boxed{}. ", - "query": - "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", + "system": "Please reason step by step, and put your final answer within \\boxed{}. ", + "query": "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", "response": [ "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.", "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. 
Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.", @@ -36,16 +25,16 @@ def math_step_prompts(): "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).", ], } - answer = "<extra_0>".join(data['response']) + "<extra_0>" + answer = "<extra_0>".join(data["response"]) + "<extra_0>" prompt = f"<im_start>system\n{data['system']}<im_end>\n<im_start>user\n{data['query']}<im_end>\n<im_start>assistant\n{answer}<im_end><|endoftext|>" return [prompt] def step_reward_patch_hf_model(hf_model: HfRunner): - # Patch the hf_runner to use the step reward function - def make_step_rewards(logits: torch.Tensor, - token_masks: torch.Tensor) -> list[list[float]]: + def make_step_rewards( + logits: torch.Tensor, token_masks: torch.Tensor + ) -> list[list[float]]: probabilities = F.softmax(logits, dim=-1) probabilities = probabilities * token_masks.unsqueeze(-1) @@ -63,7 +52,7 @@ def step_reward_patch_hf_model(hf_model: HfRunner): outputs = hf_model.model(input_ids=input_ids) step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0] - token_masks = (input_ids == step_sep_id) + token_masks = input_ids == step_sep_id return make_step_rewards(outputs[0], token_masks) hf_model.reward = reward # type: ignore[attr-defined] @@ -74,8 +63,10 @@ def step_reward_patch_hf_model(hf_model: HfRunner): @pytest.mark.parametrize( "model", [ - pytest.param("Qwen/Qwen2.5-Math-PRM-7B", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "Qwen/Qwen2.5-Math-PRM-7B", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) @pytest.mark.parametrize("dtype", ["half"]) @@ -87,10 +78,11 @@ def test_prm_models( dtype: str, monkeypatch, ) -> None: - check_transformers_version("Qwen/Qwen2.5-Math-PRM-7B", - max_transformers_version="4.53.2") + check_transformers_version( + "Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2" + ) - if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0": + if current_platform.is_cpu(): pytest.skip("CPU only supports V1") if current_platform.is_rocm(): diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index 6b5ff70681459..416a43070f0e0 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,15 +23,6 @@ TEXTS_2 = [ "The capital of Germany is Berlin.", ] - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - DTYPE = "half" @@ -46,10 +37,9 @@ def test_cross_encoder_1_to_1(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict([text_pair]).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -67,10 +57,9 @@ def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: 
hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -89,10 +78,9 @@ def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 @@ -110,17 +98,15 @@ def emb_model_name(request): def test_embedding_1_to_1(vllm_runner, hf_runner, emb_model_name): text_pair = [TEXTS_1[0], TEXTS_2[0]] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: hf_embeddings = hf_model.encode(text_pair) - hf_outputs = [ - F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0) - ] + hf_outputs = [F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0)] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -135,20 +121,18 @@ def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[0], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -164,20 +148,18 @@ def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[1], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py deleted file mode 100644 
index c22c78592e535..0000000000000 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models -from .mteb_utils import mteb_test_embed_models - -MODELS = [ - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - is_matryoshka=True, - architecture="GteModel", - enable_test=True), -] - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, vllm_runner, model_info, atol=0.02) - - -@pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py new file mode 100644 index 0000000000000..784d9fc312679 --- /dev/null +++ b/tests/models/language/pooling/test_token_classification.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForTokenClassification + +from tests.models.utils import softmax + + +@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"]) +# The float32 is required for this tiny model to pass the test. 
+@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_bert_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForTokenClassification + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, 1e-2) + + +@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"]) +@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_modernbert_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForTokenClassification + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, atol=1e-2) diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index c6ef899958a07..f1870ddbee510 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -20,51 +20,57 @@ calculus, each contributing unique perspectives that would shape this new field.""" -def test_smaller_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_smaller_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = 10 - with vllm_runner(model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == truncate_prompt_tokens -def test_max_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): +def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input_str): truncate_prompt_tokens = -1 - with vllm_runner(model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, 
truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == max_model_len -def test_bigger_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_bigger_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = max_model_len + 1 - with pytest.raises(ValueError), vllm_runner( - model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: - + with ( + pytest.raises(ValueError), + vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model, + ): llm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) - assert llm_output == f"""truncate_prompt_tokens value + assert ( + llm_output + == f"""truncate_prompt_tokens value ({truncate_prompt_tokens}) is greater than max_model_len ({max_model_len}). Please, select a smaller truncation size.""" + ) diff --git a/tests/mq_llm_engine/__init__.py b/tests/models/language/pooling_mteb_test/__init__.py similarity index 100% rename from tests/mq_llm_engine/__init__.py rename to tests/models/language/pooling_mteb_test/__init__.py diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py new file mode 100644 index 0000000000000..d96dc90416855 --- /dev/null +++ b/tests/models/language/pooling_mteb_test/mteb_utils.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile +from collections.abc import Sequence +from typing import Optional + +import mteb +import numpy as np +import pytest +import requests +import torch + +import tests.ci_envs as ci_envs +from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close + +# Most embedding models on the STS12 task (See #17175): +# - Model implementation and minor changes in tensor dtype +# results in differences less than 1e-4 +# - Different model results in differences more than 1e-3 +# 1e-4 is a good tolerance threshold +MTEB_EMBED_TASKS = ["STS12"] +MTEB_EMBED_TOL = 1e-4 + +# See #19344 +MTEB_RERANK_TASKS = ["NFCorpus"] +MTEB_RERANK_LANGS = ["en"] +MTEB_RERANK_TOL = 2e-3 + + +class VllmMtebEncoder(mteb.Encoder): + def __init__(self, vllm_model): + super().__init__() + self.llm = vllm_model + self.rng = np.random.default_rng(seed=42) + + def encode( + self, + sentences: Sequence[str], + *args, + **kwargs, + ) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. 
+ r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + outputs = self.llm.embed(sentences, use_tqdm=False) + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + def predict( + self, + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + queries = [s[0] for s in sentences] + corpus = [s[1] for s in sentences] + + outputs = self.llm.score( + queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False + ) + scores = np.array(outputs) + scores = scores[np.argsort(r)] + return scores + + +class OpenAIClientMtebEncoder(mteb.Encoder): + def __init__(self, model_name: str, client): + super().__init__() + self.model_name = model_name + self.client = client + self.rng = np.random.default_rng(seed=42) + + def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + embeddings = self.client.embeddings.create( + model=self.model_name, input=sentences + ) + outputs = [d.embedding for d in embeddings.data] + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + +class ScoreClientMtebEncoder(mteb.Encoder): + def __init__(self, model_name: str, url): + super().__init__() + self.model_name = model_name + self.url = url + self.rng = np.random.default_rng(seed=42) + + def predict( + self, + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + outputs = [] + for query, corpus, prompt in sentences: + outputs.append(self.get_score(query, corpus)) + + scores = np.array(outputs) + scores = scores[np.argsort(r)] + return scores + + def get_score(self, query, corpus): + response = requests.post( + self.url, + json={ + "model": self.model_name, + "text_1": query, + "text_2": corpus, + "truncate_prompt_tokens": -1, + }, + ).json() + return response["data"][0]["score"] + + +class RerankClientMtebEncoder(ScoreClientMtebEncoder): + def get_score(self, query, corpus): + response = requests.post( + self.url, + json={ + "model": self.model_name, + "query": query, + "documents": [corpus], + "truncate_prompt_tokens": -1, + }, + ).json() + return response["results"][0]["relevance_score"] + + +def run_mteb_embed_task(encoder, tasks): + tasks = mteb.get_tasks(tasks=tasks) + evaluation = mteb.MTEB(tasks=tasks) + results = evaluation.run( + encoder, + verbosity=0, + output_folder=None, + encode_kwargs={ + "show_progress_bar": False, + }, + ) + + main_score = results[0].scores["test"][0]["main_score"] + return main_score + + +def mteb_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, + atol=MTEB_EMBED_TOL, +): + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: + pytest.skip("Skipping test.") + + # Test embed_dims, isnan and whether to use normalize + example_prompts = ["The chef prepared a delicious meal." 
* 1000] + + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + **vllm_extra_kwargs, + ) as vllm_model: + model_config = vllm_model.llm.llm_engine.model_config + + # Confirm whether vllm is using the correct architecture + if model_info.architecture: + assert model_info.architecture in model_config.architectures + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled + assert ( + model_config._model_info.default_pooling_type + == model_info.default_pooling_type + ) + + vllm_main_score = run_mteb_embed_task( + VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS + ) + vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype + head_dtype = model_config.head_dtype + + # Test embed_dims, isnan and whether to use normalize + vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1) + assert not torch.any(torch.isnan(torch.tensor(vllm_outputs))) + + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant + if model_info.mteb_score is None: + with hf_runner( + model_info.name, + is_sentence_transformer=True, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + ) as hf_model: + # e.g. setting default parameters for the encode method of hf_runner + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) + st_dtype = next(hf_model.model.parameters()).dtype + + # Test embed_dims and whether to use normalize + hf_outputs = hf_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" + + print("Model:", model_info.name) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_main_score) + print("SentenceTransformers:", st_dtype, st_main_score) + print("Difference:", st_main_score - vllm_main_score) + + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < atol + + +def run_mteb_rerank(cross_encoder, tasks, languages): + with tempfile.TemporaryDirectory() as results_folder: + bm25s = mteb.get_model("bm25s") + tasks = mteb.get_tasks(tasks=tasks, languages=languages) + + subset = "default" + eval_splits = ["test"] + + evaluation = mteb.MTEB(tasks=tasks) + evaluation.run( + bm25s, + verbosity=0, + eval_splits=eval_splits, + save_predictions=True, + output_folder=f"{results_folder}/stage1", + encode_kwargs={"show_progress_bar": False}, + ) + + results = evaluation.run( + cross_encoder, + verbosity=0, + eval_splits=eval_splits, + top_k=10, + save_predictions=True, + output_folder=f"{results_folder}/stage2", + previous_results=f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", + encode_kwargs={"show_progress_bar": False}, + ) + main_score = results[0].scores["test"][0]["main_score"] + return main_score + + +def mteb_test_rerank_models_hf( + hf_runner, model_name, hf_dtype="float32", hf_model_callback=None +): + with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model: + original_predict = hf_model.predict + + def _predict( + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ): + # vllm and st both remove the prompt, fair comparison. + prompts = [(s[0], s[1]) for s in sentences] + return original_predict(prompts, *args, **kwargs, batch_size=8) + + hf_model.predict = _predict + hf_model.original_predict = original_predict + + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_rerank( + hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS + ) + st_dtype = next(hf_model.model.model.parameters()).dtype + return st_main_score, st_dtype + + +def mteb_test_rerank_models( + hf_runner, + vllm_runner, + model_info: RerankModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, + vllm_mteb_encoder=VllmMtebEncoder, + atol=MTEB_RERANK_TOL, +): + # A model family has many models with the same architecture, + # and we don't need to test each one. 
+ if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: + pytest.skip("Skipping test.") + + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + max_num_seqs=8, + **vllm_extra_kwargs, + ) as vllm_model: + model_config = vllm_model.llm.llm_engine.model_config + + # Confirm whether vllm is using the correct architecture + if model_info.architecture: + assert model_info.architecture in model_config.architectures + + # Score API is only enabled for num_labels == 1 + assert model_config.hf_config.num_labels == 1 + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled + assert ( + model_config._model_info.default_pooling_type + == model_info.default_pooling_type + ) + + vllm_main_score = run_mteb_rerank( + vllm_mteb_encoder(vllm_model), + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS, + ) + vllm_dtype = model_config.dtype + head_dtype = model_config.head_dtype + + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant + if model_info.mteb_score is None: + st_main_score, st_dtype = mteb_test_rerank_models_hf( + hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback + ) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" + + print("Model:", model_info.name) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_main_score) + print("SentenceTransformers:", st_dtype, st_main_score) + print("Difference:", st_main_score - vllm_main_score) + + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < atol diff --git a/tests/models/language/pooling_mteb_test/test_baai.py b/tests/models/language/pooling_mteb_test/test_baai.py new file mode 100644 index 0000000000000..bad13e2457146 --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_baai.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, + RerankModelInfo, +) + +from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models + +MODELS = [ + ########## BertModel + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-en", + architecture="BertModel", + mteb_score=0.779336792, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-en", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-en", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh-noinstruct", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-zh-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-zh-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh-v1.5", architecture="BertModel", enable_test=False + ), + ########## XLMRobertaModel + CLSPoolingEmbedModelInfo( + "BAAI/bge-m3", + architecture="XLMRobertaModel", + mteb_score=0.787343078, + enable_test=True, + ), + ########## Qwen2Model + LASTPoolingEmbedModelInfo( + "BAAI/bge-code-v1", + architecture="Qwen2Model", + mteb_score=0.75724465, + dtype="float32", + enable_test=True, + ), +] + +RERANK_MODELS = [ + ########## XLMRobertaForSequenceClassification + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-base", + architecture="XLMRobertaForSequenceClassification", + mteb_score=0.32398, + enable_test=True, + ), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-large", + architecture="XLMRobertaForSequenceClassification", + enable_test=False, + ), + CLSPoolingRerankModelInfo( + "BAAI/bge-reranker-v2-m3", + architecture="XLMRobertaForSequenceClassification", + enable_test=False, + ), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb( 
+ hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py similarity index 64% rename from tests/models/language/pooling/test_bge_reranker_v2_gemma.py rename to tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py index 206524d7caad3..9e95dd74c3978 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py @@ -7,46 +7,51 @@ import pytest import torch from tests.conftest import HfRunner - -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo -from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models +from tests.models.language.pooling_mteb_test.mteb_utils import ( + VllmMtebEncoder, + mteb_test_rerank_models, +) +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo RERANK_MODELS = [ - LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification"), + LASTPoolingRerankModelInfo( + "BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification", + mteb_score=0.33757, + hf_overrides={ + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + }, + ), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 class GemmaRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") @torch.no_grad() - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def get_inputs(pairs, tokenizer, prompt=None): if prompt is None: prompt = PROMPT sep = "\n" - prompt_inputs = tokenizer(prompt, - return_tensors=None, - add_special_tokens=False)["input_ids"] - sep_inputs = tokenizer(sep, - return_tensors=None, - add_special_tokens=False)["input_ids"] + prompt_inputs = tokenizer( + prompt, return_tensors=None, add_special_tokens=False + )["input_ids"] + sep_inputs = tokenizer(sep, return_tensors=None, add_special_tokens=False)[ + "input_ids" + ] inputs = [] for query, passage in pairs: query_inputs = tokenizer( @@ -70,8 +75,7 @@ class GemmaRerankerHfRunner(HfRunner): return_token_type_ids=False, add_special_tokens=False, ) - item["input_ids"] = item[ - "input_ids"] + sep_inputs + prompt_inputs + item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs item["attention_mask"] = [1] * len(item["input_ids"]) inputs.append(item) return tokenizer.pad( @@ -87,54 +91,44 @@ class GemmaRerankerHfRunner(HfRunner): inputs = inputs.to(self.model.device) _n_tokens = inputs["input_ids"].shape[1] logits = self.model(**inputs, return_dict=True).logits - _scores = (logits[:, -1, - self.yes_loc].view(-1, 
).float().sigmoid()) + _scores = ( + logits[:, -1, self.yes_loc] + .view( + -1, + ) + .float() + .sigmoid() + ) scores.append(_scores[0].item()) return torch.Tensor(scores) class GemmaMtebEncoder(VllmMtebEncoder): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt = PROMPT self.query_template = "A: {query}\n" self.document_template = "B: {doc}\n{prompt}" def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: - _sentences = [] for query, corpus, prompt in sentences: query = self.query_template.format(query=query) - corpus = self.document_template.format(doc=corpus, prompt=prompt) + corpus = self.document_template.format(doc=corpus, prompt=PROMPT) _sentences.append((query, corpus, prompt)) return super().predict(_sentences, *args, **kwargs) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, - monkeypatch) -> None: - monkeypatch.setenv("VLLM_USE_V1", "0") - - assert model_info.architecture == "GemmaForSequenceClassification" - - vllm_extra_kwargs: dict[str, Any] = { - "hf_overrides": { - "architectures": ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": "no_post_processing", - } - } - - mteb_test_rerank_models(GemmaRerankerHfRunner, - vllm_runner, - model_info, - vllm_extra_kwargs, - vllm_mteb_encoder=GemmaMtebEncoder) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: + mteb_test_rerank_models( + GemmaRerankerHfRunner, + vllm_runner, + model_info, + vllm_mteb_encoder=GemmaMtebEncoder, + ) diff --git a/tests/models/language/pooling_mteb_test/test_cross_encoder.py b/tests/models/language/pooling_mteb_test/test_cross_encoder.py new file mode 100644 index 0000000000000..638ffc7a62b0e --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_cross_encoder.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import ( + CLSPoolingRerankModelInfo, + LASTPoolingRerankModelInfo, + RerankModelInfo, +) + +from .mteb_utils import mteb_test_rerank_models + +RERANK_MODELS = [ + CLSPoolingRerankModelInfo( + "cross-encoder/ms-marco-TinyBERT-L-2-v2", + mteb_score=0.32898, + architecture="BertForSequenceClassification", + ), + LASTPoolingRerankModelInfo( + "tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + mteb_score=0.25736, + architecture="Qwen3ForSequenceClassification", + ), +] + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_gte.py b/tests/models/language/pooling_mteb_test/test_gte.py new file mode 100644 index 0000000000000..a22821fd65b5a --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_gte.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, + RerankModelInfo, +) + +from .mteb_utils import mteb_test_embed_models, 
mteb_test_rerank_models + +MODELS = [ + ########## BertModel + CLSPoolingEmbedModelInfo( + "thenlper/gte-large", + mteb_score=0.76807651, + architecture="BertModel", + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-base", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-small", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-large-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-base-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-small-zh", architecture="BertModel", enable_test=False + ), + ########### NewModel + # These three architectures are almost the same, but not exactly the same. + # For example, + # - whether to use token_type_embeddings + # - whether to use context expansion + # So only test one (the most widely used) model + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + mteb_score=0.775074696, + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False, + ), + ########### Qwen2ForCausalLM + LASTPoolingEmbedModelInfo( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + mteb_score=0.758473459018872, + architecture="Qwen2ForCausalLM", + enable_test=True, + ), + ########## ModernBertModel + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-modernbert-base", + mteb_score=0.748193353, + architecture="ModernBertModel", + enable_test=True, + ), + ########## Qwen3ForCausalLM + LASTPoolingEmbedModelInfo( + "Qwen/Qwen3-Embedding-0.6B", + mteb_score=0.771163695, + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True, + ), + LASTPoolingEmbedModelInfo( + "Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False, + ), +] + +RERANK_MODELS = [ + CLSPoolingRerankModelInfo( + # classifier_pooling: mean + "Alibaba-NLP/gte-reranker-modernbert-base", + mteb_score=0.33386, + architecture="ModernBertForSequenceClassification", + enable_test=True, + ), + CLSPoolingRerankModelInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + mteb_score=0.33062, + architecture="GteNewForSequenceClassification", + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: + mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_intfloat.py b/tests/models/language/pooling_mteb_test/test_intfloat.py new file mode 100644 index 0000000000000..1d078db69236a --- /dev/null +++ 
b/tests/models/language/pooling_mteb_test/test_intfloat.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + +from .mteb_utils import mteb_test_embed_models + +MODELS = [ + ########## BertModel + CLSPoolingEmbedModelInfo( + "intfloat/e5-small", + architecture="BertModel", + mteb_score=0.742285423, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "intfloat/e5-base", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "intfloat/e5-large", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-small", architecture="BertModel", enable_test=False + ), + ########## XLMRobertaModel + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + mteb_score=0.779325955, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False, + ), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py similarity index 54% rename from tests/models/language/pooling/test_jina.py rename to tests/models/language/pooling_mteb_test/test_jina.py index 37c5bdc97dd98..0a712b2542f3c 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling_mteb_test/test_jina.py @@ -4,58 +4,69 @@ from functools import partial import pytest +from tests.models.language.pooling.embed_utils import ( + check_embeddings_close, + correctness_test_embed_models, + matryoshka_fy, +) +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, + EmbedModelInfo, + RerankModelInfo, +) from vllm import PoolingParams -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, RerankModelInfo) -from .embed_utils import (check_embeddings_close, - correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ - CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3", - architecture="XLMRobertaModel", - is_matryoshka=True) + CLSPoolingEmbedModelInfo( + "jinaai/jina-embeddings-v3", + mteb_score=0.824413164, + architecture="XLMRobertaModel", + is_matryoshka=True, + ) ] RERANK_MODELS = [ CLSPoolingRerankModelInfo( "jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification") + mteb_score=0.33643, + architecture="XLMRobertaForSequenceClassification", + ) ] @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: 
EmbedModelInfo) -> None: def hf_model_callback(model): model.encode = partial(model.encode, task="text-matching") - mteb_test_embed_models(hf_runner, - vllm_runner, - model_info, - hf_model_callback=hf_model_callback) + mteb_test_embed_models( + hf_runner, vllm_runner, model_info, hf_model_callback=hf_model_callback + ) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: def hf_model_callback(model): model.encode = partial(model.encode, task="text-matching") - correctness_test_embed_models(hf_runner, - vllm_runner, - model_info, - example_prompts, - hf_model_callback=hf_model_callback) + correctness_test_embed_models( + hf_runner, + vllm_runner, + model_info, + example_prompts, + hf_model_callback=hf_model_callback, + ) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) @@ -78,32 +89,32 @@ def test_matryoshka( example_prompts = [str(s).strip() for s in example_prompts] with hf_runner( - model_info.name, - dtype=dtype, - is_sentence_transformer=True, + model_info.name, + dtype=dtype, + is_sentence_transformer=True, ) as hf_model: hf_outputs = hf_model.encode(example_prompts, task="text-matching") hf_outputs = matryoshka_fy(hf_outputs, dimensions) - with vllm_runner(model_info.name, - runner="pooling", - dtype=dtype, - max_model_len=None) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", dtype=dtype, max_model_len=None + ) as vllm_model: assert vllm_model.llm.llm_engine.model_config.is_matryoshka matryoshka_dimensions = ( - vllm_model.llm.llm_engine.model_config.matryoshka_dimensions) + vllm_model.llm.llm_engine.model_config.matryoshka_dimensions + ) assert matryoshka_dimensions is not None if dimensions not in matryoshka_dimensions: with pytest.raises(ValueError): vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) else: vllm_outputs = vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py new file mode 100644 index 0000000000000..fd04dc1990238 --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import pytest +import torch + +from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo + +from .mteb_utils import mteb_test_rerank_models + +mxbai_rerank_hf_overrides = { + "architectures": ["Qwen2ForSequenceClassification"], + "classifier_from_token": ["0", "1"], + "method": "from_2_way_softmax", +} + +RERANK_MODELS = [ + LASTPoolingRerankModelInfo( + "mixedbread-ai/mxbai-rerank-base-v2", + architecture="Qwen2ForSequenceClassification", + 
hf_overrides=mxbai_rerank_hf_overrides, + mteb_score=0.273, + enable_test=True, + ), + LASTPoolingRerankModelInfo( + "mixedbread-ai/mxbai-rerank-large-v2", + architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, + enable_test=False, + ), +] + + +class MxbaiRerankerHfRunner(HfRunner): + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") + self.no_loc = self.tokenizer.convert_tokens_to_ids("0") + + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + def process_inputs(pairs): + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") + for key in inputs: + inputs[key] = inputs[key].to(self.model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs): + logits = self.model(**inputs).logits[:, -1, :] + yes_logits = logits[:, self.yes_loc] + no_logits = logits[:, self.no_loc] + logits = yes_logits - no_logits + scores = logits.float().sigmoid() + return scores + + scores = [] + for prompt in prompts: + inputs = process_inputs([prompt]) + score = compute_logits(inputs) + scores.append(score[0].item()) + return torch.Tensor(scores) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_nomic.py b/tests/models/language/pooling_mteb_test/test_nomic.py new file mode 100644 index 0000000000000..c54a43052483a --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_nomic.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + +from .mteb_utils import mteb_test_embed_models + +MODELS = [ + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + mteb_score=0.737568559, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/CodeRankEmbed", architecture="NomicBertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + mteb_score=0.715488912, + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git 
a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py new file mode 100644 index 0000000000000..00e99f44cfdb1 --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import pytest +import torch + +from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo +from tests.utils import multi_gpu_test + +from .mteb_utils import mteb_test_rerank_models + +qwen3_reranker_hf_overrides = { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, +} + +RERANK_MODELS = [ + LASTPoolingRerankModelInfo( + "Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + mteb_score=0.25736, + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=True, + ), + LASTPoolingRerankModelInfo( + "Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=False, + ), +] + + +class Qwen3RerankerHfRunner(HfRunner): + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + def process_inputs(pairs): + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") + for key in inputs: + inputs[key] = inputs[key].to(self.model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs): + batch_scores = self.model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, self.token_true_id] + false_vector = batch_scores[:, self.token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp() + return scores + + scores = [] + for prompt in prompts: + inputs = process_inputs([prompt]) + score = compute_logits(inputs) + scores.append(score[0].item()) + return torch.Tensor(scores) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: + mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +@multi_gpu_test(num_gpus=2) +def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None: + assert model_info.architecture == "Qwen3ForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "tensor_parallel_size": 2, + } + + mteb_test_rerank_models( + Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs + ) diff --git a/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py 
new file mode 100644 index 0000000000000..3c30628aeaa49 --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + +from .mteb_utils import mteb_test_embed_models + +MODELS = [ + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + mteb_score=0.714927797, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + mteb_score=0.681146831, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + mteb_score=0.649088363, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + mteb_score=0.712258299, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + mteb_score=0.706622444, + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling_mteb_test/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py new file mode 100644 index 0000000000000..91b1ef828d0df --- /dev/null +++ b/tests/models/language/pooling_mteb_test/test_st_projector.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, +) + +from .mteb_utils import mteb_test_embed_models + +# ST models with projector (Dense) layers +ST_PROJECTOR_MODELS = [ + CLSPoolingEmbedModelInfo( + "TencentBAC/Conan-embedding-v1", + architecture="BertModel", + mteb_score=0.688611955, + enable_test=True, + ), + LASTPoolingEmbedModelInfo( + "google/embeddinggemma-300m", + architecture="Gemma3TextModel", + mteb_score=0.7473819294684156, + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: + mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/generation/test_common.py 
b/tests/models/multimodal/generation/test_common.py index 96208f8eda628..0572898368d6d 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -3,27 +3,40 @@ """Common tests for testing .generate() functionality for single / multiple image, embedding, and video support for different VLMs in vLLM. """ + import math import os from collections import defaultdict from pathlib import PosixPath import pytest -from transformers import (AutoModel, AutoModelForImageTextToText, - AutoModelForTextToWaveform, AutoModelForVision2Seq) +from transformers import ( + AutoModel, + AutoModelForImageTextToText, + AutoModelForTextToWaveform, +) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, - ImageTestAssets, VideoTestAssets, VllmRunner) -from ....utils import (create_new_process_for_each_test, large_gpu_mark, - multi_gpu_marks) +from ....conftest import ( + IMAGE_ASSETS, + AudioTestAssets, + HfRunner, + ImageTestAssets, + VideoTestAssets, + VllmRunner, +) +from ....utils import create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks from ...utils import check_outputs_equal from .vlm_utils import custom_inputs, model_utils, runners from .vlm_utils.case_filtering import get_parametrized_options -from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs, - VLMTestInfo, VLMTestType) +from .vlm_utils.types import ( + CustomTestOptions, + ExpandableVLMTestArgs, + VLMTestInfo, + VLMTestType, +) # This hack is needed for phi3v & paligemma models # ROCm Triton FA can run into shared memory issues with these models, @@ -32,25 +45,17 @@ from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs, if current_platform.is_rocm(): os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" -REQUIRES_V0_MODELS = [ - # V1 Test: not enough KV cache space in CI. 
- "fuyu", - # V1 Test: Deadlock issue when processing mm_inputs - "llava-onevision-transformers", -] - -# yapf: disable COMMON_BROADCAST_SETTINGS = { "test_type": VLMTestType.IMAGE, "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, "hf_model_kwargs": {"device_map": "auto"}, - "image_size_factors": [(.25, 0.5, 1.0)], + "image_size_factors": [(0.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", "mp", - ) + ), } ### Test configuration for specific models @@ -90,22 +95,20 @@ VLM_TEST_SETTINGS = { #### Core tests to always run in the CI "llava": VLMTestInfo( models=["llava-hf/llava-1.5-7b-hf"], - test_type=( - VLMTestType.EMBEDDING, - VLMTestType.IMAGE, - VLMTestType.CUSTOM_INPUTS - ), + test_type=(VLMTestType.EMBEDDING, VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS), prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", convert_assets_to_embeddings=model_utils.get_llava_embeddings, max_model_len=4096, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" + ), + limit_mm_per_prompt={"image": 4}, + ) + ], # TODO: Revert to "auto" when CPU backend can use torch > 2.6 dtype="bfloat16" if current_platform.is_cpu() else "auto", marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -114,47 +117,43 @@ VLM_TEST_SETTINGS = { models=["google/paligemma-3b-mix-224"], test_type=VLMTestType.IMAGE, prompt_formatter=identity, - img_idx_to_prompt = lambda idx: "", + img_idx_to_prompt=lambda idx: "", # Paligemma uses its own sample prompts because the default one fails - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "caption es", - "cherry_blossom": "What is in the picture?", - }), + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "caption es", + "cherry_blossom": "What is in the picture?", + } + ), auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, dtype="bfloat16", - marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501 + marks=[ + pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask") + ], ), "qwen2_5_vl": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), 
(0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen2_5_omni": VLMTestInfo( models=["Qwen/Qwen2.5-Omni-3B"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", + video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", max_model_len=4096, max_num_seqs=2, - num_logprobs= 6 if current_platform.is_cpu() else 5, + num_logprobs=6 if current_platform.is_cpu() else 5, auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, @@ -162,9 +161,9 @@ VLM_TEST_SETTINGS = { marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "ultravox": VLMTestInfo( - models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"], + models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"], test_type=VLMTestType.AUDIO, - prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 audio_idx_to_prompt=lambda idx: "<|audio|>", max_model_len=4096, max_num_seqs=2, @@ -178,40 +177,57 @@ VLM_TEST_SETTINGS = { "llava-onevision-transformers": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 max_model_len=16384, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "default_torch_num_threads": 1, + }, + # FIXME: Investigate why the test hangs + # when processing the 3rd prompt in vLLM + marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")], + ), + # Gemma3 has bidirectional mask on images + "gemma3-transformers": VLMTestInfo( + models=["google/gemma-3-4b-it"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda vid_prompt: f"<bos><start_of_turn>user\n{vid_prompt}<start_of_image><end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + max_model_len=4096, + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), + "idefics3-transformers": 
VLMTestInfo( + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], vllm_runner_kwargs={ "model_impl": "transformers", }, marks=[pytest.mark.core_model], ), - # FIXME(Isotr0py): Enable this test after - # https://github.com/huggingface/transformers/pull/39470 released - # "idefics3-transformers": VLMTestInfo( - # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], - # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 - # img_idx_to_prompt=lambda idx: "<image>", - # max_model_len=8192, - # max_num_seqs=2, - # auto_cls=AutoModelForImageTextToText, - # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, - # image_size_factors=[(0.25, 0.5, 1.0)], - # vllm_runner_kwargs={ - # "model_impl": "transformers", - # }, - # marks=[pytest.mark.core_model], - # ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -222,35 +238,22 @@ VLM_TEST_SETTINGS = { }, marks=[large_gpu_mark(min_gb=32)], ), - # Check "auto" with fallback to transformers - "internvl-transformers": VLMTestInfo( - models=["OpenGVLab/InternVL3-1B-hf"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", - max_model_len=4096, - use_tokenizer_eos=True, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "auto", - }, - auto_cls=AutoModelForImageTextToText, - marks=[pytest.mark.core_model], - ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<vlm_image>Please describe the image shortly.", - "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501 - }), - multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<vlm_image>Please describe the image shortly.", + "cherry_blossom": "<vlm_image>Please 
infer the season with reason.", + } + ), + multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", stop_str=["<|im_end|>"], image_size_factors=[(0.10, 0.15)], max_tokens=64, @@ -259,12 +262,14 @@ VLM_TEST_SETTINGS = { "aya_vision": VLMTestInfo( models=["CohereForAI/aya-vision-8b"], test_type=(VLMTestType.IMAGE), - prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>What is the season?", # noqa: E501 - }), - multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>What's the content in the center of the image?", + "cherry_blossom": "<image>What is the season?", + } + ), + multi_image_prompt="<image><image>Describe the two images in detail.", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -273,12 +278,14 @@ VLM_TEST_SETTINGS = { "aya_vision-multi_image": VLMTestInfo( models=["CohereForAI/aya-vision-8b"], test_type=(VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>What is the season?", # noqa: E501 - }), - multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>What's the content in the center of the image?", + "cherry_blossom": "<image>What is the season?", + } + ), + multi_image_prompt="<image><image>Describe the two images in detail.", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -303,27 +310,29 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, # For chameleon, we only compare the sequences - vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], - hf_output_post_proc = lambda hf_output, model: hf_output[:2], + vllm_output_post_proc=lambda vllm_output, model: vllm_output[:2], + hf_output_post_proc=lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, max_tokens=8, dtype="bfloat16", ), "deepseek_vl_v2": VLMTestInfo( - models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module + models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 max_model_len=4096, max_num_seqs=2, - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nPlease infer the season with reason in details.", # noqa: E501 - 
}), - multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nPlease infer the season with reason in details.", # noqa: E501 + } + ), + multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501 patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, - stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501 - image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], + stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], + image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], ), "fuyu": VLMTestInfo( models=["adept/fuyu-8b"], @@ -337,19 +346,18 @@ VLM_TEST_SETTINGS = { vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], - # FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we - # should enable this again after the fix is released: - # https://github.com/huggingface/transformers/pull/39915 - marks=[pytest.mark.skip("HF model is broken")], + marks=[large_gpu_mark(min_gb=32)], ), "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<start_of_image>What is the season?", # noqa: E501 - }), + prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "<start_of_image>What is the season?", + } + ), multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501 max_model_len=4096, max_num_seqs=2, @@ -361,11 +369,13 @@ VLM_TEST_SETTINGS = { "glm4v": VLMTestInfo( models=["zai-org/glm-4v-9b"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<|begin_of_image|><|endoftext|><|end_of_image|>What is the season?", # noqa: E501 - }), + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "<|begin_of_image|><|endoftext|><|end_of_image|>What is the season?", # noqa: E501 + } + ), max_model_len=2048, max_num_seqs=2, get_stop_token_ids=lambda tok: [151329, 151336, 151338], @@ -380,9 +390,9 @@ VLM_TEST_SETTINGS = { "glm4_1v": VLMTestInfo( models=["zai-org/GLM-4.1V-9B-Thinking"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 - img_idx_to_prompt=lambda idx: 
"<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", + video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", max_model_len=2048, max_num_seqs=2, get_stop_token_ids=lambda tok: [151329, 151336, 151338], @@ -399,23 +409,27 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, patch_hf_runner=model_utils.glm4_1v_patch_hf_runner, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.video_with_metadata_glm4_1v(), - limit_mm_per_prompt={"video": 1}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.video_with_metadata_glm4_1v(), + limit_mm_per_prompt={"video": 1}, + ) + ], marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( - models = [ + models=[ "h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-2b", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), + prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501 max_model_len=8192, use_tokenizer_eos=True, @@ -425,7 +439,7 @@ VLM_TEST_SETTINGS = { "idefics3": VLMTestInfo( models=["HuggingFaceTB/SmolVLM-256M-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 img_idx_to_prompt=lambda idx: "<image>", max_model_len=8192, max_num_seqs=2, @@ -440,11 +454,13 @@ VLM_TEST_SETTINGS = { # "OpenGVLab/Mono-InternVL-2B", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501 max_model_len=4096, use_tokenizer_eos=True, @@ -455,16 +471,30 @@ VLM_TEST_SETTINGS = { "OpenGVLab/InternVL3-1B", ], test_type=VLMTestType.VIDEO, - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 
video_idx_to_prompt=lambda idx: "<video>", max_model_len=8192, use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-hf": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO, + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", + video_idx_to_prompt=lambda idx: "<video>", + max_model_len=8192, + use_tokenizer_eos=True, + auto_cls=AutoModelForImageTextToText, + ), "kimi_vl": VLMTestInfo( models=["moonshotai/Kimi-VL-A3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501 max_model_len=8192, max_num_seqs=2, @@ -475,11 +505,11 @@ VLM_TEST_SETTINGS = { ), "llama4": VLMTestInfo( models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], - prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 img_idx_to_prompt=lambda _: "<|image|>", test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), distributed_executor_backend="mp", - image_size_factors=[(.25, 0.5, 1.0)], + image_size_factors=[(0.25, 0.5, 1.0)], hf_model_kwargs={"device_map": "auto"}, max_model_len=8192, max_num_seqs=4, @@ -495,28 +525,34 @@ VLM_TEST_SETTINGS = { max_model_len=10240, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]" - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]" + ), + limit_mm_per_prompt={"image": 4}, + ) + ], ), "llava_onevision": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.CUSTOM_INPUTS, - prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 - auto_cls=AutoModelForVision2Seq, + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( - formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - ), - 
limit_mm_per_prompt={"video": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( + formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + ), + limit_mm_per_prompt={"video": 4}, + ) + ], ), "llava_next_video": VLMTestInfo( models=["llava-hf/LLaVA-NeXT-Video-7B-hf"], @@ -525,7 +561,7 @@ VLM_TEST_SETTINGS = { num_video_frames=16, max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, ), "mantis": VLMTestInfo( @@ -558,7 +594,9 @@ VLM_TEST_SETTINGS = { img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", max_model_len=4096, max_num_seqs=2, - get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids( + ["<|im_end|>", "<|endoftext|>"] + ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, # FIXME: https://huggingface.co/openbmb/MiniCPM-o-2_6/discussions/49 @@ -571,13 +609,15 @@ VLM_TEST_SETTINGS = { img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", max_model_len=4096, max_num_seqs=2, - get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids( + ["<|im_end|>", "<|endoftext|>"] + ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, ), "minimax_vl_01": VLMTestInfo( models=["MiniMaxAI/MiniMax-VL-01"], - prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501 img_idx_to_prompt=lambda _: "<image>", test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), max_model_len=8192, @@ -599,8 +639,8 @@ VLM_TEST_SETTINGS = { "ovis1_6-gemma2": VLMTestInfo( models=["AIDC-AI/Ovis1.6-Gemma2-9B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", max_model_len=4096, max_num_seqs=2, dtype="half", @@ -612,8 +652,8 @@ VLM_TEST_SETTINGS = { "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", max_model_len=4096, max_num_seqs=2, dtype="half", @@ -623,13 +663,9 @@ VLM_TEST_SETTINGS = { ), "ovis2_5": VLMTestInfo( models=["AIDC-AI/Ovis2.5-2B"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: 
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", video_idx_to_prompt=lambda idx: "<video>\n", max_model_len=4096, max_num_seqs=2, @@ -641,7 +677,7 @@ VLM_TEST_SETTINGS = { "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n", # noqa: E501 img_idx_to_prompt=lambda idx: f"<|image_{idx}|>\n", max_model_len=4096, max_num_seqs=2, @@ -676,18 +712,14 @@ VLM_TEST_SETTINGS = { ), "qwen2_vl": VLMTestInfo( models=["Qwen/Qwen2-VL-2B-Instruct"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 - multi_image_prompt="Picture 1: <vlm_image>\nPicture 2: <vlm_image>\nDescribe these two images with one paragraph respectively.", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", + multi_image_prompt="Picture 1: <vlm_image>\nPicture 2: <vlm_image>\nDescribe these two images with one paragraph respectively.", # noqa: E501 max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.cpu_model], @@ -695,12 +727,14 @@ VLM_TEST_SETTINGS = { "skywork_r1v": VLMTestInfo( models=["Skywork/Skywork-R1V-38B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), - multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), + multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", max_model_len=4096, use_tokenizer_eos=True, patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner, @@ -715,6 +749,7 @@ VLM_TEST_SETTINGS = { 
max_num_seqs=2, auto_cls=AutoModelForImageTextToText, hf_output_post_proc=model_utils.smolvlm_trunc_hf_output, + num_logprobs=10, ), "tarsier": VLMTestInfo( models=["omni-research/Tarsier-7b"], @@ -732,9 +767,9 @@ VLM_TEST_SETTINGS = { VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO, ), - prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -747,11 +782,11 @@ VLM_TEST_SETTINGS = { prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForImageTextToText, - vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], - hf_output_post_proc = lambda hf_output, model: hf_output[:2], + vllm_output_post_proc=lambda vllm_output, model: vllm_output[:2], + hf_output_post_proc=lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + **COMMON_BROADCAST_SETTINGS, # type: ignore ), "llava-broadcast": VLMTestInfo( models=["llava-hf/llava-1.5-7b-hf"], @@ -760,7 +795,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + **COMMON_BROADCAST_SETTINGS, # type: ignore ), "llava_next-broadcast": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], @@ -769,12 +804,12 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + **COMMON_BROADCAST_SETTINGS, # type: ignore ), ### Custom input edge-cases for specific models "intern_vl-diff-patches": VLMTestInfo( models=["OpenGVLab/InternVL2-2B"], - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=4096, use_tokenizer_eos=True, @@ -783,7 +818,8 @@ VLM_TEST_SETTINGS = { CustomTestOptions( inputs=inp, limit_mm_per_prompt={"image": 2}, - ) for inp in custom_inputs.different_patch_input_cases_internvl() + ) + for inp in custom_inputs.different_patch_input_cases_internvl() ], ), "llava_onevision-multiple-images": VLMTestInfo( @@ -791,15 +827,19 @@ VLM_TEST_SETTINGS = { test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + auto_cls=AutoModelForImageTextToText, + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), 
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + ), + limit_mm_per_prompt={"image": 4}, + ) + ], ), # regression test for https://github.com/vllm-project/vllm/issues/15122 "qwen2_5_vl-windows-attention": VLMTestInfo( @@ -807,15 +847,16 @@ VLM_TEST_SETTINGS = { test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.windows_attention_image_qwen2_5_vl(), - limit_mm_per_prompt={"image": 1}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.windows_attention_image_qwen2_5_vl(), + limit_mm_per_prompt={"image": 1}, + ) + ], ), } -# yapf: enable def _mark_splits( @@ -836,7 +877,7 @@ def _mark_splits( new_test_settings = dict[str, VLMTestInfo]() for i in range(num_groups): - models_in_group = models[i * split_size:(i + 1) * split_size] + models_in_group = models[i * split_size : (i + 1) * split_size] for model in models_in_group: for info in test_infos_by_model[model]: @@ -867,14 +908,16 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2) VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=False, - )) -def test_single_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_single_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( tmp_path=tmp_path, @@ -892,14 +935,16 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=False, - )) -def test_multi_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_multi_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( tmp_path=tmp_path, @@ -917,14 +962,15 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=False, - )) -def test_image_embedding_models(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: 
ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_image_embedding_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( model_test_info=model_test_info, @@ -941,12 +987,15 @@ def test_image_embedding_models(model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=False, - )) -def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_video_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( model_test_info=model_test_info, @@ -963,12 +1012,15 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=False, - )) -def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_audio_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_audio_test( model_test_info=model_test_info, @@ -985,16 +1037,14 @@ def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=False, - )) + ), +) def test_custom_inputs_models( model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - monkeypatch, ): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( model_test_info=model_test_info, @@ -1011,15 +1061,17 @@ def test_custom_inputs_models( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_single_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( tmp_path=tmp_path, @@ -1037,15 +1089,17 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, - 
test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_multi_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( tmp_path=tmp_path, @@ -1063,16 +1117,16 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() -def test_image_embedding_models_heavy(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_image_embedding_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( model_test_info=model_test_info, @@ -1089,13 +1143,15 @@ def test_image_embedding_models_heavy(model_type: str, VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=True, - )) -def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_video_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( model_test_info=model_test_info, @@ -1112,13 +1168,15 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=True, - )) -def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") + ), +) +def test_audio_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_audio_test( model_test_info=model_test_info, @@ -1135,17 +1193,15 @@ def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_custom_inputs_models_heavy( model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - monkeypatch, ): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( model_test_info=model_test_info, diff --git a/tests/models/multimodal/generation/test_florence2.py 
b/tests/models/multimodal/generation/test_florence2.py deleted file mode 100644 index a622957f96f69..0000000000000 --- a/tests/models/multimodal/generation/test_florence2.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest -from PIL import Image - -from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt -from vllm.multimodal.image import rescale_image_size -from vllm.sequence import SampleLogprobs - -from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner -from ...utils import check_logprobs_close - -MODELS = ["microsoft/Florence-2-base"] -# Florence-2 model repo's tokenizer config is missing some special tokens. -# Therefore, we use a converted tokenizer from a forked repo -TOKENIZER = "Isotr0py/Florence-2-tokenizer" -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<OD>", # special task token which will output special tokens - "cherry_blossom": - "Describe in detail what is shown in the image.", -}) - - -def get_hf_images_prompts( - prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]], -) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]: - prompts, images = [], [] - for prompt in prompts_: - encoder_prompt = prompt["encoder_prompt"] - prompts.append( - ExplicitEncoderDecoderPrompt( - encoder_prompt=encoder_prompt["prompt"], - decoder_prompt=None, - )) - images.append(encoder_prompt["multi_modal_data"]["image"]) - return prompts, images - - -def hf_to_vllm_output(hf_output: tuple[list[int], str, - Optional[SampleLogprobs]]): - """Sanitize hf output to be comparable with vllm output.""" - output_ids, output_str, out_logprobs = hf_output - - output_str = output_str.replace("</s>", "").replace("<s>", "") - - return output_ids, output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: list[list[ExplicitEncoderDecoderPrompt]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - with vllm_runner(model, - max_num_seqs=8, - tokenizer_name=TOKENIZER, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_case = [ - vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, - max_tokens, - num_logprobs=num_logprobs, - skip_special_tokens=False, - ) for prompts in inputs - ] - - hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] - - with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.lm_head - hf_outputs_per_case = [ - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in hf_inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs], - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=1, - ) - - -# FIXME: https://github.com/huggingface/transformers/issues/38358 -@pytest.mark.skip("Model initialization fails") -@pytest.mark.core_model -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize( - 
"size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, model: str, - size_factors: list[int], dtype: str, max_tokens: int, - num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [[ - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=prompt, - multi_modal_data={"image": rescale_image_size(image, factor)}), - decoder_prompt=None, - ) for factor in size_factors - ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index f2e6fbfad6e80..ef08b1916aa5f 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -7,11 +7,10 @@ from typing import Optional import pytest from transformers import AutoModelForSpeechSeq2Seq +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest -from vllm.sequence import SampleLogprobs -from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, - VllmRunner) +from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -64,50 +63,49 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). 
# max_model_len should be greater than image_feature_size with vllm_runner( - model, - runner="generate", - max_model_len=max_model_len, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"audio": 1}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=64, - enforce_eager=True, + model, + runner="generate", + max_model_len=max_model_len, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"audio": 1}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=64, + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("audio", 1, audio_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + lora_request=lora_request, + ) for prompts, audios in inputs ] - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: - + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=[audios], - eos_token_id=eos_token_id) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=[audios], + eos_token_id=eos_token_id, + ) for prompts, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(output) for output in vllm_outputs - ], + outputs_1_lst=[vllm_to_hf_output(output) for output in vllm_outputs], name_0="hf", name_1="vllm", ) @@ -118,9 +116,16 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, - audio_assets: AudioTestAssets, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + model: str, + audio_assets: AudioTestAssets, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index 1ef56af33a094..a773db19825e1 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -28,8 +28,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: give the same result. 
""" - image_cherry = convert_image_mode( - ImageAsset("cherry_blossom").pil_image, "RGB") + image_cherry = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB") images = [image_cherry, image_stop] video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays @@ -47,29 +46,30 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: ), ] - with vllm_runner(model, - runner="generate", - dtype=dtype, - limit_mm_per_prompt={"image": 2}, - max_model_len=32768, - max_num_seqs=2, - tensor_parallel_size=1, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + runner="generate", + dtype=dtype, + limit_mm_per_prompt={"image": 2}, + max_model_len=32768, + max_num_seqs=2, + tensor_parallel_size=1, + enforce_eager=True, + ) as vllm_model: vllm_outputs_per_case = [ - vllm_model.generate_greedy(prompts, - max_tokens, - images=images, - videos=videos) + vllm_model.generate_greedy( + prompts, max_tokens, images=images, videos=videos + ) for prompts, images, videos in inputs ] all_results = [output[0][1] for output in vllm_outputs_per_case] - outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n")) - for total_str in all_results] - prompt_lengths = [prompt_len for _, prompt_len in outputs] - generated_strs = [ - total_str[prompt_len:] for total_str, prompt_len in outputs + outputs = [ + (total_str, total_str.find("assistant\n") + len("assistant\n")) + for total_str in all_results ] + prompt_lengths = [prompt_len for _, prompt_len in outputs] + generated_strs = [total_str[prompt_len:] for total_str, prompt_len in outputs] interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths interleaved_output_str, noninterleaved_output_str = generated_strs diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py index bacc9ef94f49d..fd3386ff67df2 100644 --- a/tests/models/multimodal/generation/test_maverick.py +++ b/tests/models/multimodal/generation/test_maverick.py @@ -18,13 +18,11 @@ from typing import Any import pytest import torch from safetensors.torch import save_file -from transformers import (AutoConfig, AutoProcessor, AutoTokenizer, - GenerationConfig) +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, GenerationConfig from vllm import LLM, SamplingParams from vllm.v1.executor.abstract import Executor -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, FullAttentionSpec from ....utils import multi_gpu_test @@ -93,8 +91,7 @@ def get_rope_layers_config(model_path: str) -> list[int]: def create_reduced_maverick_model( - original_model_name: - str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + original_model_name: str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", output_dir: str = "/tmp/reduced_maverick", text_layers: int = 4, num_experts: int = 4, @@ -118,7 +115,8 @@ def create_reduced_maverick_model( print( f"Creating reduced Maverick model with {text_layers} text layers and " - f"{vision_layers} vision layers...") + f"{vision_layers} vision layers..." + ) # Create output directory output_path = Path(output_dir) @@ -126,19 +124,23 @@ def create_reduced_maverick_model( if force_recreate: shutil.rmtree(output_path) else: - print(f"Output directory {output_dir} already exists. 
" - "Use --force-recreate to overwrite.") + print( + f"Output directory {output_dir} already exists. " + "Use --force-recreate to overwrite." + ) return str(output_path) output_path.mkdir(parents=True, exist_ok=True) try: print("Loading original model configuration...") - original_config = AutoConfig.from_pretrained(original_model_name, - trust_remote_code=True) + original_config = AutoConfig.from_pretrained( + original_model_name, trust_remote_code=True + ) print("Creating reduced configuration...") - reduced_config = create_reduced_config(original_config, text_layers, - num_experts, vision_layers) + reduced_config = create_reduced_config( + original_config, text_layers, num_experts, vision_layers + ) config_path = output_path / "config.json" with open(config_path, "w") as f: @@ -149,8 +151,7 @@ def create_reduced_maverick_model( copy_tokenizer_files(original_model_name, output_path) print("Creating reduced safetensors files...") - create_reduced_safetensors(original_config, reduced_config, - output_path) + create_reduced_safetensors(original_config, reduced_config, output_path) print("Creating preprocessor config...") create_preprocessor_config(original_config, output_path) @@ -173,9 +174,9 @@ def create_reduced_maverick_model( raise -def create_reduced_config(original_config: Any, text_layers: int, - num_experts: int, - vision_layers: int) -> dict[str, Any]: +def create_reduced_config( + original_config: Any, text_layers: int, num_experts: int, vision_layers: int +) -> dict[str, Any]: """Create a reduced configuration based on the original.""" # Convert config to dictionary @@ -185,23 +186,18 @@ def create_reduced_config(original_config: Any, text_layers: int, if "text_config" in config_dict: original_text_layers = config_dict["text_config"]["num_hidden_layers"] config_dict["text_config"]["num_hidden_layers"] = text_layers - print( - f"Reduced text layers from {original_text_layers} to {text_layers}" - ) + print(f"Reduced text layers from {original_text_layers} to {text_layers}") original_num_experts = config_dict["text_config"]["num_local_experts"] config_dict["text_config"]["num_local_experts"] = num_experts - print( - f"Reduced num experts from {original_num_experts} to {num_experts}" - ) + print(f"Reduced num experts from {original_num_experts} to {num_experts}") hidden_dim_divisor = 4 original_hidden_size = config_dict["text_config"]["hidden_size"] new_hidden_size = original_hidden_size // hidden_dim_divisor config_dict["text_config"]["hidden_size"] = new_hidden_size - print(f"Reduced hidden size from {original_hidden_size} to " - f"{new_hidden_size}") + print(f"Reduced hidden size from {original_hidden_size} to {new_hidden_size}") original_head_dim = config_dict["text_config"]["head_dim"] new_head_dim = original_head_dim // hidden_dim_divisor @@ -210,15 +206,12 @@ def create_reduced_config(original_config: Any, text_layers: int, # Reduce vision layers if "vision_config" in config_dict: - original_vision_layers = config_dict["vision_config"][ - "num_hidden_layers"] + original_vision_layers = config_dict["vision_config"]["num_hidden_layers"] config_dict["vision_config"]["num_hidden_layers"] = vision_layers - print(f"Reduced vision layers from {original_vision_layers} " - f"to {vision_layers}") + print(f"Reduced vision layers from {original_vision_layers} to {vision_layers}") # Update model name to indicate it's a reduced version - config_dict["_name_or_path"] = ( - f"reduced_maverick_{text_layers}t_{vision_layers}v") + config_dict["_name_or_path"] = 
f"reduced_maverick_{text_layers}t_{vision_layers}v" return config_dict @@ -227,16 +220,16 @@ def copy_tokenizer_files(original_model_name: str, output_path: Path) -> None: """Copy tokenizer files from the original model.""" try: - tokenizer = AutoTokenizer.from_pretrained(original_model_name, - trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + original_model_name, trust_remote_code=True + ) tokenizer.save_pretrained(output_path) print("Tokenizer files copied successfully") except Exception as e: print(f"Warning: Could not copy tokenizer files: {e}") -def create_preprocessor_config(original_config: Any, - output_path: Path) -> None: +def create_preprocessor_config(original_config: Any, output_path: Path) -> None: """Create preprocessor_config.json for multimodal model.""" # Try to load the original preprocessor config @@ -254,9 +247,9 @@ def create_preprocessor_config(original_config: Any, raise -def create_reduced_safetensors(original_config: Any, reduced_config: dict[str, - Any], - output_path: Path) -> None: +def create_reduced_safetensors( + original_config: Any, reduced_config: dict[str, Any], output_path: Path +) -> None: """Create safetensors files with weights for the reduced model.""" print("Generating synthetic weights for reduced model...") @@ -279,8 +272,7 @@ def create_reduced_safetensors(original_config: Any, reduced_config: dict[str, save_weights_to_safetensors(weights, output_path) -def create_text_model_weights( - text_config: dict[str, Any]) -> dict[str, torch.Tensor]: +def create_text_model_weights(text_config: dict[str, Any]) -> dict[str, torch.Tensor]: """Create synthetic weights for the text model with MoE structure.""" weights = {} @@ -291,19 +283,18 @@ def create_text_model_weights( intermediate_size_mlp = text_config["intermediate_size_mlp"] num_layers = text_config["num_hidden_layers"] num_attention_heads = text_config["num_attention_heads"] - num_key_value_heads = text_config.get("num_key_value_heads", - num_attention_heads) + num_key_value_heads = text_config.get("num_key_value_heads", num_attention_heads) # MoE specific parameters num_experts = text_config.get("num_local_experts") - assert (num_experts - is not None), "num_local_experts must be specified for MoE" + assert num_experts is not None, "num_local_experts must be specified for MoE" head_dim = hidden_size // num_attention_heads # Embedding layers weights["language_model.model.embed_tokens.weight"] = torch.randn( - vocab_size, hidden_size, dtype=torch.float16) + vocab_size, hidden_size, dtype=torch.float16 + ) # Transformer layers for layer_idx in range(num_layers): @@ -312,95 +303,105 @@ def create_text_model_weights( # Self-attention weights (separate q, k, v projections) weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn( - hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn( - hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn( - num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16) + num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn( - hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16 + 
) print("Self-attention weights created.") # Feed-forward weights - MoE pattern based on interleave_moe_layer_step # For interleave_moe_layer_step=2: layers 1,3,5,... are MoE, layers # 0,2,4,... are dense interleave_step = text_config.get("interleave_moe_layer_step", 1) - is_moe_layer = (interleave_step > 0 - and (layer_idx + 1) % interleave_step == 0) + is_moe_layer = interleave_step > 0 and (layer_idx + 1) % interleave_step == 0 if is_moe_layer: # MoE layer structure # 1. Router weights - weights[ - f"{layer_prefix}.feed_forward.router.weight"] = torch.randn( - num_experts, hidden_size, dtype=torch.float16) + weights[f"{layer_prefix}.feed_forward.router.weight"] = torch.randn( + num_experts, hidden_size, dtype=torch.float16 + ) # 2. Individual expert weights (not fused) for expert_idx in range(num_experts): - expert_prefix = ( - f"{layer_prefix}.feed_forward.experts.{expert_idx}") + expert_prefix = f"{layer_prefix}.feed_forward.experts.{expert_idx}" weights[f"{expert_prefix}.gate_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.up_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.down_proj.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) # Expert weight scales (FP8 quantization) - weights[ - f"{expert_prefix}.gate_proj.weight_scale"] = torch.ones( - intermediate_size, 1, dtype=torch.bfloat16) + weights[f"{expert_prefix}.gate_proj.weight_scale"] = torch.ones( + intermediate_size, 1, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.up_proj.weight_scale"] = torch.ones( - intermediate_size, 1, dtype=torch.bfloat16) - weights[ - f"{expert_prefix}.down_proj.weight_scale"] = torch.ones( - hidden_size, 1, dtype=torch.bfloat16) + intermediate_size, 1, dtype=torch.bfloat16 + ) + weights[f"{expert_prefix}.down_proj.weight_scale"] = torch.ones( + hidden_size, 1, dtype=torch.bfloat16 + ) # 3. 
Shared expert weights shared_expert_prefix = f"{layer_prefix}.feed_forward.shared_expert" weights[f"{shared_expert_prefix}.gate_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{shared_expert_prefix}.up_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{shared_expert_prefix}.down_proj.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) print(f"MoE feed-forward weights created for layer {layer_idx}.") else: # Dense layer structure - weights[f"{layer_prefix}.feed_forward.gate_proj.weight"] = ( - torch.randn(intermediate_size_mlp, - hidden_size, - dtype=torch.bfloat16)) - weights[f"{layer_prefix}.feed_forward.up_proj.weight"] = ( - torch.randn(intermediate_size_mlp, - hidden_size, - dtype=torch.bfloat16)) - weights[f"{layer_prefix}.feed_forward.down_proj.weight"] = ( - torch.randn(hidden_size, - intermediate_size_mlp, - dtype=torch.bfloat16)) + weights[f"{layer_prefix}.feed_forward.gate_proj.weight"] = torch.randn( + intermediate_size_mlp, hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.feed_forward.up_proj.weight"] = torch.randn( + intermediate_size_mlp, hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.feed_forward.down_proj.weight"] = torch.randn( + hidden_size, intermediate_size_mlp, dtype=torch.bfloat16 + ) print(f"Dense feed-forward weights created for layer {layer_idx}.") # Layer norms weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) - weights[ - f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( + hidden_size, dtype=torch.bfloat16 + ) print("Layer norms created.") # Final layer norm and output projection weights["language_model.model.norm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights["language_model.lm_head.weight"] = torch.randn( - vocab_size, hidden_size, dtype=torch.bfloat16) + vocab_size, hidden_size, dtype=torch.bfloat16 + ) return weights def create_vision_model_weights( - vision_config: dict[str, Any]) -> dict[str, torch.Tensor]: + vision_config: dict[str, Any], +) -> dict[str, torch.Tensor]: """Create synthetic weights for the vision model.""" weights = {} @@ -414,47 +415,62 @@ def create_vision_model_weights( layer_prefix = f"vision_model.model.layers.{layer_idx}" weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.q_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) 
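# A self-contained sketch of the sharded-safetensors layout that the
# save_weights_to_safetensors() helper in this file produces: one
# model-XXXXX-of-YYYYY.safetensors file per shard plus an index JSON mapping
# every weight name to its shard and recording the total byte size. The index
# filename and the one-shard-per-dict split below follow the usual Hugging Face
# convention and are assumptions, not taken verbatim from the test.
import json
from pathlib import Path

import torch
from safetensors.torch import save_file


def write_shards(shards: list[dict[str, torch.Tensor]], out_dir: Path) -> None:
    out_dir.mkdir(parents=True, exist_ok=True)
    weight_map: dict[str, str] = {}
    total_size = 0
    for i, shard in enumerate(shards):
        filename = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors"
        # Each shard is a plain {parameter name: tensor} dict, just like the
        # dicts built by create_text_model_weights() / create_vision_model_weights().
        save_file(shard, str(out_dir / filename))
        for name, tensor in shard.items():
            weight_map[name] = filename
            total_size += tensor.numel() * tensor.element_size()
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(out_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)


# Example usage with a tiny random tensor:
# write_shards(
#     [{"vision_model.model.layers.0.mlp.fc1.weight": torch.randn(8, 4)}],
#     Path("/tmp/reduced_maverick_sketch"),
# )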
weights[f"{layer_prefix}.self_attn.v_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc1.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc1.bias"] = torch.zeros( - intermediate_size, dtype=torch.bfloat16) + intermediate_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc2.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc2.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.input_layernorm.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) - weights[ - f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.post_attention_layernorm.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) return weights def create_shared_weights( - text_config: dict[str, Any], - vision_config: dict[str, Any]) -> dict[str, torch.Tensor]: + text_config: dict[str, Any], vision_config: dict[str, Any] +) -> dict[str, torch.Tensor]: """Create weights for shared components (vision-language connector)""" weights = {} @@ -464,13 +480,15 @@ def create_shared_weights( # Vision-language connector (projects vision features to text space) weights["multi_modal_projector.linear_1.weight"] = torch.randn( - text_hidden_size, projector_input_dim, dtype=torch.bfloat16) + text_hidden_size, projector_input_dim, dtype=torch.bfloat16 + ) return weights -def save_weights_to_safetensors(weights: dict[str, torch.Tensor], - output_path: Path) -> None: +def save_weights_to_safetensors( + weights: dict[str, torch.Tensor], output_path: Path +) -> None: """Save weights to safetensors files and create index.""" # Determine how to shard the weights @@ -507,18 +525,18 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor], else: # Multiple shards for i, shard in enumerate(shards): - filename = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors" + filename = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors" save_file(shard, output_path / filename) for name in shard: weight_map[name] = filename - print(f"Saved shard {i+1}/{len(shards)}: {filename}") + print(f"Saved shard {i + 1}/{len(shards)}: {filename}") # Create index file index_data = { "metadata": { - "total_size": - sum(tensor.numel() * tensor.element_size() - for tensor in weights.values()) + "total_size": sum( + tensor.numel() * tensor.element_size() for tensor in weights.values() + ) }, "weight_map": weight_map, } @@ -528,8 +546,9 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor], json.dump(index_data, f, indent=2) print(f"Created index 
file: {index_path}") - print(f"Total model size: " - f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB") + print( + f"Total model size: {index_data['metadata']['total_size'] / (1024**3):.2f} GB" + ) def check_attention_spec_interleaved_rope( @@ -540,8 +559,7 @@ def check_attention_spec_interleaved_rope( ): """Check that the attention spec is correct.""" assert isinstance(llm.llm_engine.model_executor, Executor) - kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs( - ) + kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs() for rank in range(num_ranks): kv_cache_specs = kv_cache_specs_per_rank[rank] assert len(kv_cache_specs.keys()) == num_attention_layers @@ -551,16 +569,14 @@ def check_attention_spec_interleaved_rope( else: expected_spec = ChunkedLocalAttentionSpec assert isinstance( - kv_cache_specs[ - f"language_model.model.layers.{i}.self_attn.attn"], - expected_spec) + kv_cache_specs[f"language_model.model.layers.{i}.self_attn.attn"], + expected_spec, + ) def run_reduced_model(llm: LLM, should_profile: bool = False) -> None: """Test the created reduced model with vLLM.""" - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=50) + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50) if should_profile: llm.start_profile() @@ -571,15 +587,15 @@ def run_reduced_model(llm: LLM, should_profile: bool = False) -> None: print("Test generation successful!") for output in outputs: print(f"Prompt: {output.prompt}") - print(f"Output: " - f"{output.outputs[0].text}") + print(f"Output: {output.outputs[0].text}") print("-" * 40) @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "original_model_name,text_layers,num_experts,vision_layers,", - [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)]) + [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)], +) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("tp,ep", [(2, True)]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -597,7 +613,6 @@ def test_dummy_maverick( profile: bool = False, ) -> None: # Disable multiprocessing allows us to access model executor from LLM engine - monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") model_path = create_reduced_maverick_model( @@ -640,7 +655,8 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="Create a reduced-layer Maverick model") + description="Create a reduced-layer Maverick model" + ) parser.add_argument( "--output-dir", default="/tmp/reduced_maverick", @@ -652,10 +668,7 @@ def main(): default=4, help="Number of text transformer layers", ) - parser.add_argument("--num-experts", - type=int, - default=4, - help="Number of experts") + parser.add_argument("--num-experts", type=int, default=4, help="Number of experts") parser.add_argument( "--vision-layers", type=int, @@ -667,12 +680,12 @@ def main(): action="store_true", help="Force recreation if output directory exists", ) - parser.add_argument("--test", - action="store_true", - help="Test the created model with vLLM") - parser.add_argument("--profile", - action="store_true", - help="Profile the created model with vLLM") + parser.add_argument( + "--test", action="store_true", help="Test the created model with vLLM" + ) + parser.add_argument( + "--profile", action="store_true", help="Profile the created model with vLLM" + ) parser.add_argument( "--test-original", action="store_true", @@ 
-687,16 +700,18 @@ def main(): args = parser.parse_args() if args.test: - test_dummy_maverick(original_model_name=args.original_model, - output_dir=args.output_dir, - text_layers=args.text_layers, - num_experts=args.num_experts, - vision_layers=args.vision_layers, - force_recreate=args.force_recreate, - tp=2, - ep=True, - enforce_eager=True, - profile=args.profile) + test_dummy_maverick( + original_model_name=args.original_model, + output_dir=args.output_dir, + text_layers=args.text_layers, + num_experts=args.num_experts, + vision_layers=args.vision_layers, + force_recreate=args.force_recreate, + tp=2, + ep=True, + enforce_eager=True, + profile=args.profile, + ) if args.test_original: run_maverick_serving(args.original_model) diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py deleted file mode 100644 index 1c32cc6d71c04..0000000000000 --- a/tests/models/multimodal/generation/test_mllama.py +++ /dev/null @@ -1,768 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional, overload - -import pytest -import torch -from packaging.version import Version -from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer -from transformers import __version__ as TRANSFORMERS_VERSION - -from vllm import LLM, SamplingParams -from vllm.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.model_executor.models.mllama import MllamaForConditionalGeneration -from vllm.multimodal.image import rescale_image_size -from vllm.sequence import SampleLogprobs - -from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, - PromptImageInput, VllmRunner) -from ....quantization.utils import is_quant_method_supported -from ....utils import (create_new_process_for_each_test, large_gpu_test, - multi_gpu_test) -from ...utils import check_logprobs_close - -_LIMIT_IMAGE_PER_PROMPT = 3 -MLLAMA_IMAGE_TOKEN_ID = 128256 - -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|image|><|begin_of_text|>The meaning of the image is", - "cherry_blossom": - "<|image|><|begin_of_text|>The city is", -}) - -text_only_prompts = [ - "The color of the sky is blue but sometimes it can also be", -] - -models = [ - "meta-llama/Llama-3.2-11B-Vision-Instruct", -] - -# Indices for inputs -TEXT_ONLY = '0' -IMAGE_AT_BEG = '1' -IMAGE_AT_MIDDLE = '2' -TWO_IMAGES = '3' - -# Input tokenized -prompt_data = { - # Tell me a story - TEXT_ONLY: [41551, 757, 264, 3446], - # <|image|> What's the content of this image - IMAGE_AT_BEG: - [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220], - # Hello <|image|>What' the content of this image - IMAGE_AT_MIDDLE: - [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217], - #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? 
# noqa: E501 - TWO_IMAGES: [ - MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30, - MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30 - ] -} - - -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - config = AutoConfig.from_pretrained(model) - image_token_id = config.image_token_index - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) - if token_id != image_token_id or output_ids[idx - 1] != image_token_id - ] - - hf_output_str = output_str - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -def _get_inputs( - image_assets: ImageTestAssets, - *, - size_factors: Optional[list[float]] = None, - sizes: Optional[list[tuple[int, int]]] = None, -) -> list[tuple[list[str], PromptImageInput]]: - images = [asset.pil_image for asset in image_assets] - - if size_factors is not None: - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - elif sizes is not None: - inputs_per_image = [( - [ - prompt if size is not None else text_only_prompts[0] - for size in sizes - ], - [ - image.resize(size) if size is not None else None - for size in sizes - ], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - if len(sizes) == 0: - inputs_per_image.append( - (text_only_prompts, [None] * len(text_only_prompts))) - else: - raise ValueError("You must provide either `size_factors` or `sizes`") - - return inputs_per_image - - -@overload -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - size_factors: list[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -@overload -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - sizes: list[tuple[int, int]], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - size_factors: Optional[list[float]] = None, - sizes: Optional[list[tuple[int, int]]] = None, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - _run_test( - hf_runner, - vllm_runner, - _get_inputs(image_assets, size_factors=size_factors, sizes=sizes), - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) - - -def _run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: list[tuple[list[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. 
- - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - # max_model_len should be greater than image_feature_size - with vllm_runner( - model, - dtype=dtype, - max_model_len=19212, # 3 max size images - max_num_seqs=3, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - with hf_runner(model, - dtype=dtype, - model_kwargs={"device_map": "auto"}, - auto_cls=AutoModelForImageTextToText) as hf_model: - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Fixture to clear backend cache before each test.""" - _cached_get_attn_backend.cache_clear() # Clear the cache - yield # This allows the test to run - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "sizes", - [ - # Text only - [], - # Single-size - [(512, 512)], - # Single-size, batched - [(512, 512), (512, 512), (512, 512)], - # Multi-size, batched - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028)], - # Multi-size, batched, including text only - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028), None], - # mllama has 8 possible aspect ratios, carefully set the sizes - # to cover all of them - ]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, - model, sizes, dtype, max_tokens, - num_logprobs, - attn_backend: _Backend) -> None: - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - sizes=sizes, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - 
-@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, - model, dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - cherry_blossom = image_assets[1].pil_image - - inputs = [( - [ - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501 - ], - [ - [stop_sign, cherry_blossom], - # Images with different sizes. - [ - stop_sign.resize((512, 512)), - stop_sign, - ], - [ - stop_sign, - stop_sign.resize((512, 1536)), - cherry_blossom.resize((512, 1024)), - ], - ])] - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, - dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - cherry_blossom = image_assets[1].pil_image - - inputs = [( - [ - "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 - "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501 - "which is a stop sign and which is a cherry blossom?", # noqa: E501 - ], - [ - [stop_sign], - [stop_sign, cherry_blossom], - ])] - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@create_new_process_for_each_test() -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def 
test_models_distributed( - hf_runner, - vllm_runner, - image_assets, - distributed_executor_backend, - model, - dtype, - max_tokens, - num_logprobs, -) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model=model, - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -def test_bnb_regression( - image_assets: ImageTestAssets, - model: str, - dtype: str, - max_tokens: int, -): - stop_sign = image_assets[0].pil_image - prompts = [ - { - "prompt": "<|begin_of_text|>The content of the image <|image|> is", - "multi_modal_data": { - "image": stop_sign - }, - }, - { - "prompt": - "The color of the sky is blue but sometimes it can also be", - }, - ] - # Test regression about QKVCrossParallelLinear - llm = LLM( - model=model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=2, - quantization="bitsandbytes", - ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=max_tokens, - ) - outputs = llm.generate(prompts, sampling_params) - assert outputs - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [32]) -def test_explicit_implicit_prompt( - image_assets: ImageTestAssets, - model: str, - dtype: str, - max_tokens: int, -): - stop_sign = image_assets[0].pil_image - # yapf: disable - prompts = [ - # explicit prompt - { - "encoder_prompt": { - "prompt": "<|image|>", - "multi_modal_data": {"image": stop_sign}, - }, - "decoder_prompt": { - "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501 - } - }, - { - "encoder_prompt": "Not <|image|>", - "decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 - }, - # implicit prompt - { - "prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 - "multi_modal_data": {"image": stop_sign}, - }, - { - "prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 - }, - ] - # yapf: enable - llm = LLM( - model=model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=2, - tensor_parallel_size=1, - ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=max_tokens, - ) - outputs = llm.generate(prompts, sampling_params) - n_prompts = len(prompts) - explicit_outputs = outputs[:n_prompts // 2] - implicit_outputs = outputs[n_prompts // 2:] - for exp_output, imp_output in zip(explicit_outputs, implicit_outputs): - assert exp_output.outputs[0].text == imp_output.outputs[0].text - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, - num_logprobs, attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - - with global_force_attn_backend_context_manager(attn_backend), vllm_runner( - 
model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=4, - tensor_parallel_size=1, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: - - # Regression tests for https://github.com/vllm-project/vllm/issues/10648 - - # Number of groups of image tokens is greater than the number of images - # provided (the whitespace between the tags is necessary) - prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images" # noqa: E501 - image = stop_sign - with pytest.raises(ValueError): - vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs, - images=[image]) - - # Batch of a text-only and image request that requires cross-attention - prompts = [ - "What is the capital of spain?", - "Text before the image...<|image|>What is in the image?", # noqa: E501 - ] - images = [ - None, - [stop_sign], - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - # Test the reverse order too for good measure - prompts = [ - "<|begin_of_text|>Text before the image...<|image|>What is in the image?", # noqa: E501 - "<|begin_of_text|>Hello!", - ] - images = [ - [stop_sign], - None, - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - # Mixed batch with text and images with different numbers of tiles - prompts = [ - "<|begin_of_text|>Hello!", - "<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501 - "<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501 - ] - images = [ - None, - [stop_sign], - # smaller image must be 2nd for the repro - [stop_sign.resize((448, 448))], - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - -class DummyModel: - image_token_id = MLLAMA_IMAGE_TOKEN_ID - - -@pytest.mark.core_model -@pytest.mark.parametrize( - "input_indices_and_output", - # inputs, (cross_attention_mask, kv_range_for_decode) - [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)), - ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)), - ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - ((23, 24), [[0, 6], [6, 12]])), - ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])), - ([TWO_IMAGES], ((18, 12), [[6, 12]])), - ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))]) -def test_get_cross_attention_mask(input_indices_and_output) -> None: - - input_indices, expected_output = input_indices_and_output - - sequences = [torch.tensor(prompt_data[i]) for i in input_indices] - num_tiles = [[2, 2] if i != TEXT_ONLY else [] for i in input_indices - if i != TEXT_ONLY] - input = torch.cat(sequences) - - seq_lens = [len(s) for s in sequences] - - attn_data = FlashAttentionMetadata( - seq_lens=seq_lens, - # Dummy values - enable_kv_scales_calculation=False, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=0, - slot_mapping=0, - multi_modal_placeholder_index_maps=None, - seq_lens_tensor=0, - max_prefill_seq_len=0, - max_decode_seq_len=0, - context_lens_tensor=None, - block_tables=None, - use_cuda_graph=False, - ) - - dummy = DummyModel() - - cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\ - .get_cross_attention_mask(dummy, - input, - attn_data, - num_tiles=num_tiles, - num_tokens_per_tile=3, - dtype=torch.bfloat16) - - expected_cross_attention_mask, expected_kv_range_for_decode = \ - expected_output - - assert kv_range_for_decode == expected_kv_range_for_decode - 
if expected_cross_attention_mask is not None: - assert cross_attention_mask is not None - assert cross_attention_mask.shape == expected_cross_attention_mask - else: - assert cross_attention_mask is None - - -@pytest.mark.core_model -@pytest.mark.parametrize( - "input_indices", - [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE], - [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]]) -def test_get_full_text_row_masked_out_mask(input_indices) -> None: - - sequences = [torch.tensor(prompt_data[i]) for i in input_indices] - - seq_lens = [len(s) for s in sequences] - - num_prefill_tokens = sum(seq_lens) - - # TEXT_ONLY is zero, so it will be masked out, - # other instances should not be. - encoder_seq_lens = [int(i) for i in input_indices] - - attn_data = FlashAttentionMetadata( - seq_lens=seq_lens, - encoder_seq_lens=encoder_seq_lens, - num_prefill_tokens=num_prefill_tokens, - # Dummy values - enable_kv_scales_calculation=False, - num_prefills=0, - num_decode_tokens=0, - slot_mapping=0, - multi_modal_placeholder_index_maps=None, - seq_lens_tensor=0, - max_prefill_seq_len=0, - max_decode_seq_len=0, - context_lens_tensor=None, - block_tables=None, - use_cuda_graph=False, - ) - - dummy = DummyModel() - - full_text_row_masked_out_mask = MllamaForConditionalGeneration\ - .get_full_text_row_masked_out_mask(dummy, - attn_data, - torch.get_default_device()) - - full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze() - full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist() - - idx = 0 - assert len(full_text_row_masked_out_mask) == num_prefill_tokens - for i, seq_len in enumerate(seq_lens): - must_be_masked = input_indices[i] != TEXT_ONLY - for _ in range(seq_len): - assert full_text_row_masked_out_mask[idx] == must_be_masked, \ - f"full_text_row_masked_out_mask[{idx}] must be " \ - f"'{must_be_masked}' " - idx += 1 - - -@pytest.mark.core_model -@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [ - ([6404], [[4]], [6404]), - ([0, 6404], [[4]], [6404]), - ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]), - ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]), -]) -def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles, - expected) -> None: - - dummy = DummyModel() - num_tokens_per_tile = 1601 - actual_encoder_seq_lens = MllamaForConditionalGeneration \ - ._get_and_validate_encoder_lens( - dummy, - encoder_seq_lens, - num_tiles, - num_tokens_per_tile, - ) - assert actual_encoder_seq_lens == expected, \ - f"Expected {expected} but got {actual_encoder_seq_lens}" diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py index db8984d8656fc..132c69285c5c7 100644 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ b/tests/models/multimodal/generation/test_phi4_multimodal.py @@ -14,26 +14,35 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.image import rescale_image_size from vllm.platforms import current_platform -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + PromptAudioInput, + PromptImageInput, + VllmRunner, +) from ....utils import large_gpu_test from ...utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|user|>\n<|image|>\nWhat's the content of the 
image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": - "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 + } +) +HF_MULTIIMAGE_IMAGE_PROMPT = ( + "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +) -model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct", - revision="refs/pr/70") +model_path = snapshot_download( + "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" +) # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join(model_path, "examples", - "what_is_shown_in_this_image.wav") +speech_question = os.path.join( + model_path, "examples", "what_is_shown_in_this_image.wav" +) models = [model_path] target_dtype = "half" @@ -48,8 +57,7 @@ if current_platform.is_rocm(): def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, - Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], model: str, *, max_model_len: int, @@ -75,28 +83,30 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). 
# max_model_len should be greater than image_feature_size with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, - trust_remote_code=False, + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, + trust_remote_code=False, ) as vllm_model: lora_request = LoRARequest("vision", 1, vision_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request, + ) for prompts, images, audios in inputs ] @@ -108,17 +118,18 @@ def run_test( hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + eos_token_id=eos_token_id, + ) for prompts, images, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, @@ -145,16 +156,27 @@ def run_test( @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, @@ -189,16 +211,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_multi_images_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + 
dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ ( [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], None, ), ] @@ -222,10 +254,15 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, - max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: - +def test_vision_speech_models( + hf_runner, + vllm_runner, + model, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=16000) image = ImageAsset("cherry_blossom").pil_image.convert("RGB") diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 67d35213d6422..e69d44c6a1319 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -12,36 +12,44 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm.assets.image import ImageAsset +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + PromptAudioInput, + PromptImageInput, + VllmRunner, +) from ....utils import large_gpu_test from ...utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": - "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 + } +) +HF_MULTIIMAGE_IMAGE_PROMPT = ( + "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +) model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. 
vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join(model_path, "examples", - "what_is_shown_in_this_image.wav") +speech_question = os.path.join( + model_path, "examples", "what_is_shown_in_this_image.wav" +) models = [model_path] -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str +): """Sanitize vllm output to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -71,8 +79,7 @@ if current_platform.is_rocm(): def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, - Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], model: str, *, max_model_len: int, @@ -98,27 +105,29 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size with vllm_runner( - model, - runner="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, + model, + runner="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("vision", 1, vision_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request, + ) for prompts, images, audios in inputs ] @@ -127,42 +136,36 @@ def run_test( pytest.skip("HF impl is not compatible with current transformers") hf_model_kwargs = {"_attn_implementation": "sdpa"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: - + with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id - def patch_hf_processor(*args, - text="", - images=None, - audio=None, - sampling_rate=None, - **kwargs): + def patch_hf_processor( + *args, text="", images=None, audio=None, sampling_rate=None, **kwargs + ): audios = None if audio is not None and sampling_rate is not None: audios = [(audio, sampling_rate)] - return hf_processor(*args, - text=text, - images=images, - audios=audios, - **kwargs) + return hf_processor( + *args, text=text, images=images, audios=audios, **kwargs + ) hf_model.processor = patch_hf_processor hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id, - num_logits_to_keep=0) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + 
eos_token_id=eos_token_id, + num_logits_to_keep=0, + ) for prompts, images, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, @@ -189,16 +192,27 @@ def run_test( @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, @@ -233,16 +247,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_multi_images_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ ( [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], None, ), ] @@ -266,10 +290,15 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, - max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: - +def test_vision_speech_models( + hf_runner, + vllm_runner, + model, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=None) image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index d39cf706786e2..bde07da9101ac 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -6,19 +6,18 @@ from typing import TYPE_CHECKING, Any, Optional import pytest from mistral_common.multimodal import download_image -from mistral_common.protocol.instruct.messages import ImageURLChunk +from mistral_common.protocol.instruct.chunk import ImageURLChunk from 
mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from transformers import AutoProcessor -from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt +from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins -from vllm.multimodal.inputs import PlaceholderRange -from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test -from ...utils import check_logprobs_close, dummy_hf_overrides +from ...utils import check_logprobs_close if TYPE_CHECKING: from _typeshed import StrPath @@ -29,42 +28,42 @@ MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID] IMG_URLS = [ - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/231-200x300.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/27-500x500.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/17-150x600.jpg", + "237-400x300.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "231-200x300.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/231-200x300.jpg", + "27-500x500.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/27-500x500.jpg", + "17-150x600.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/17-150x600.jpg", ] PROMPT = "Describe each image in one short sentence." 
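+# NOTE: IMG_URLS now holds bare filenames; test_chat resolves them to full URLs via the local_asset_server fixture (local_asset_server.url_for) before building the chat messages.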
def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "text": PROMPT, - }] + [{ - "type": "image_url", - "image_url": { - "url": url - } - } for url in urls], - }] + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": PROMPT, + } + ] + + [{"type": "image_url", "image_url": {"url": url}} for url in urls], + } + ] def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "content": PROMPT, - }, *({ - "type": "image", - "image": download_image(url) - } for url in urls)], - }] + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "content": PROMPT, + }, + *({"type": "image", "image": download_image(url)} for url in urls), + ], + } + ] def _create_engine_inputs(urls: list[str]) -> TokensPrompt: @@ -105,12 +104,6 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt: return engine_inputs -MSGS = [ - _create_msg_format(IMG_URLS[:1]), - _create_msg_format(IMG_URLS[:2]), - _create_msg_format(IMG_URLS), -] - SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) LIMIT_MM_PER_PROMPT = dict(image=4) @@ -132,11 +125,17 @@ def _dump_outputs_w_logprobs( outputs: OutputsLogprobs, filename: "StrPath", ) -> None: - json_data = [(tokens, text, [{ - k: asdict(v) - for k, v in token_logprobs.items() - } for token_logprobs in (logprobs or [])]) - for tokens, text, logprobs in outputs] + json_data = [ + ( + tokens, + text, + [ + {k: asdict(v) for k, v in token_logprobs.items()} + for token_logprobs in (logprobs or []) + ], + ) + for tokens, text, logprobs in outputs + ] with open(filename, "w") as f: json.dump(json_data, f) @@ -146,10 +145,17 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: with open(filename, "rb") as f: json_data = json.load(f) - return [(tokens, text, [{ - int(k): Logprob(**v) - for k, v in token_logprobs.items() - } for token_logprobs in logprobs]) for tokens, text, logprobs in json_data] + return [ + ( + tokens, + text, + [ + {int(k): Logprob(**v) for k, v in token_logprobs.items()} + for token_logprobs in logprobs + ], + ) + for tokens, text, logprobs in json_data + ] @large_gpu_test(min_gb=80) @@ -157,24 +163,27 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_chat( - vllm_runner, - max_model_len: int, - model: str, - dtype: str, + vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server ) -> None: - EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs( - FIXTURE_LOGPROBS_CHAT[model]) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", - max_model_len=max_model_len, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = [] - for msg in MSGS: + + urls_all = [local_asset_server.url_for(u) for u in IMG_URLS] + msgs = [ + _create_msg_format(urls_all[:1]), + _create_msg_format(urls_all[:2]), + _create_msg_format(urls_all), + ] + for msg in msgs: output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS) outputs.extend(output) @@ -184,49 +193,9 @@ 
def test_chat( for i in range(len(logprobs)): assert logprobs[i][-1] is None logprobs[i] = logprobs[i][:-1] - check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, - outputs_1_lst=logprobs, - name_0="h100_ref", - name_1="output") - - -@pytest.mark.parametrize("prompt,expected_ranges", - [(_create_engine_inputs_hf(IMG_URLS[:1]), - [PlaceholderRange(offset=11, length=494)]), - (_create_engine_inputs_hf(IMG_URLS[1:4]), [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ])]) -def test_multi_modal_placeholders(vllm_runner, prompt: TextPrompt, - expected_ranges: list[PlaceholderRange], - monkeypatch) -> None: - - # This placeholder checking test only works with V0 engine - # where `multi_modal_placeholders` is returned with `RequestOutput` - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner( - "mistral-community/pixtral-12b", - max_model_len=8192, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, - load_format="dummy", - hf_overrides=dummy_hf_overrides, - ) as vllm_model: - outputs = vllm_model.llm.generate(prompt) - - assert len(outputs) == 1, f"{len(outputs)=}" - output: RequestOutput = outputs[0] - assert hasattr(output, - "multi_modal_placeholders"), f"{output.__dict__=}" - assert "image" in output.multi_modal_placeholders, \ - f"{output.multi_modal_placeholders.keys()=}" - image_placeholder_ranges: list[ - PlaceholderRange] = output.multi_modal_placeholders["image"] - assert len(image_placeholder_ranges) == len( - expected_ranges), f"{image_placeholder_ranges=}" - for real_range, expected_range in zip(image_placeholder_ranges, - expected_ranges): - assert real_range.offset == expected_range.offset, \ - f"{real_range=} {expected_range=}" - assert real_range.length == expected_range.length, \ - f"{real_range=} {expected_range=}" + check_logprobs_close( + outputs_0_lst=EXPECTED_CHAT_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output", + ) diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py new file mode 100644 index 0000000000000..1a7d854352ae6 --- /dev/null +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.multimodal.video import sample_frames_from_video + +from ....conftest import VIDEO_ASSETS + +models = ["Qwen/Qwen2.5-VL-3B-Instruct"] +target_dtype = "bfloat16" + +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +def qwen2_5_vl_chat_template(*query): + return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 + + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen2_5_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_qwen2_5_vl_evs_functionality( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, +) -> None: + """Test EVS (Efficient Video Sampling) functionality with different + pruning rates. 
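+    It validates that the model runs end-to-end with each configured pruning rate and produces non-empty text output for a single video.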
+ """ + + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + prompts = [VIDEO_PROMPTS[0]] + videos = [sampled_vids[0]] + + # Initialize model with EVS configuration + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"video": 1}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) + + # Basic validation that we got a response + assert len(outputs) == 1 + output_ids, output_text = outputs[0] + + # Ensure we got some output + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_qwen2_5_vl_evs_batched_videos( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, +) -> None: + """Test EVS functionality with batched videos. + + This test validates that: + 1. The model handles batched video inputs correctly with EVS + 2. Both pruning configurations work with multiple videos + 3. The model doesn't crash when processing multiple videos simultaneously + """ + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + # Test batched videos + prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]] + videos = [sampled_vids[0], sampled_vids[0]] # Use same video twice for testing + + # Initialize model with EVS configuration + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"video": 2}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) + + # Basic validation that we got responses for both videos + assert len(outputs) == 2 + + for output_ids, output_text in outputs: + # Ensure we got some output for each video + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index c61c27ae204a3..a8f0ba8701850 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -11,17 +11,20 @@ from PIL import Image from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video -from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, - PromptVideoInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + VIDEO_ASSETS, + PromptImageInput, + PromptVideoInput, + VllmRunner, +) from ...utils import check_logprobs_close @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - V1 Test: batch_make_xxxxx_embeddings calls a V0 internal - """ - monkeypatch.setenv('VLLM_USE_V1', '0') +def 
enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") models = ["Qwen/Qwen2-VL-2B-Instruct"] @@ -36,28 +39,29 @@ def qwen2_vl_chat_template(*query): return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 -IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the biggest text's content in this image?", - ), - "cherry_blossom": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the season shown in this image? ", - "Reply with a short sentence (no more than 20 words)", - ), -}) +IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the biggest text's content in this image?", + ), + "cherry_blossom": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the season shown in this image? ", + "Reply with a short sentence (no more than 20 words)", + ), + } +) -VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "baby_reading": - qwen2_vl_chat_template( - VIDEO_PLACEHOLDER, - "Describe this video with a short sentence ", - "(no more than 20 words)", - ), -}) +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen2_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) MULTIIMAGE_PROMPT = qwen2_vl_chat_template( IMAGE_PLACEHOLDER, @@ -79,17 +83,19 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict): def batch_make_image_embeddings( - image_batches: list[Union[Image.Image, list[Image.Image]]], processor, - llm: VllmRunner) -> list[Qwen2VLPromptImageEmbeddingInput]: + image_batches: list[Union[Image.Image, list[Image.Image]]], + processor, + llm: VllmRunner, +) -> list[Qwen2VLPromptImageEmbeddingInput]: """batched image embeddings for Qwen2-VL - This will infer all images' embeddings in a single batch, + This will infer all images' embeddings in a single batch, and split the result according to input batches. image_batches: - Single-image batches: `list[Image.Image]` - Multiple-image batches: `list[list[Image.Image]]]` - + returns: `list[Qwen2VLPromptImageEmbeddingInput]` """ @@ -110,9 +116,9 @@ def batch_make_image_embeddings( # image to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=images, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=images, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] @@ -121,14 +127,14 @@ def batch_make_image_embeddings( with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - image_grid_thw_on_device = image_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=image_grid_thw_on_device) + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + image_grid_thw_on_device = image_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual( + pixel_values_on_device, grid_thw=image_grid_thw_on_device + ).cpu() - # V1 Test: this calls a V0 internal. 
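+    # LLM.apply_model pickles get_image_embeds, hence the enable_pickle fixture above setting VLLM_ALLOW_INSECURE_SERIALIZATION.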
image_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches @@ -140,21 +146,26 @@ def batch_make_image_embeddings( merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count]) + for grid_thw in image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ] + ) - result.append({ - "image_embeds": - image_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "image_grid_thw": - image_grid_thw[image_counter:image_counter + - cur_batch_image_count], - }) + result.append( + { + "image_embeds": image_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "image_grid_thw": image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ], + } + ) embed_counter += cur_batch_embed_len image_counter += cur_batch_image_count - # ensure we don't lost any images or embeddings + # ensure we don't lose any images or embeddings assert embed_counter == image_embeds.size(0) assert image_counter == image_grid_thw.size(0) assert len(image_batches) == len(result) @@ -163,13 +174,13 @@ def batch_make_image_embeddings( def batch_make_video_embeddings( - video_batches: PromptVideoInput, processor, - llm: VllmRunner) -> list[Qwen2VLPromptVideoEmbeddingInput]: + video_batches: PromptVideoInput, processor, llm: VllmRunner +) -> list[Qwen2VLPromptVideoEmbeddingInput]: """batched video embeddings for Qwen2-VL A NDArray represents a single video's all frames. - This will infer all videos' embeddings in a single batch, + This will infer all videos' embeddings in a single batch, and split the result according to input batches. video_batches: @@ -194,9 +205,9 @@ def batch_make_video_embeddings( # video to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=None, videos=videos, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=None, videos=videos, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] @@ -205,14 +216,14 @@ def batch_make_video_embeddings( with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - video_grid_thw_on_device = video_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=video_grid_thw_on_device) + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + video_grid_thw_on_device = video_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual( + pixel_values_on_device, grid_thw=video_grid_thw_on_device + ).cpu() - # V1 Test: this calls a V0 internal. 
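+    # Same pattern as batch_make_image_embeddings: the visual encoder output is moved to CPU inside get_image_embeds before being returned from apply_model.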
video_embeds = torch.concat(llm.apply_model(get_image_embeds)) # split into original batches @@ -224,21 +235,26 @@ def batch_make_video_embeddings( merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count]) + for grid_thw in video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ] + ) - result.append({ - "video_embeds": - video_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "video_grid_thw": - video_grid_thw[video_counter:video_counter + - cur_batch_video_count], - }) + result.append( + { + "video_embeds": video_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "video_grid_thw": video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ], + } + ) embed_counter += cur_batch_embed_len video_counter += cur_batch_video_count - # ensure we don't lost any videos or embeddings + # ensure we don't lose any videos or embeddings assert embed_counter == video_embeds.size(0) assert video_counter == video_grid_thw.size(0) assert len(video_batches) == len(result) @@ -266,25 +282,25 @@ def run_embedding_input_test( processor = AutoProcessor.from_pretrained(model) # max_model_len should be greater than image_feature_size - with vllm_runner(model, - runner="generate", - max_model_len=4000, - max_num_seqs=3, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=3, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit, "video": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + default_torch_num_threads=1, + ) as vllm_model: outputs_per_case_for_original_input = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images or None, - videos=videos or None) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None, + ) for prompts, images, videos in inputs ] @@ -293,17 +309,19 @@ def run_embedding_input_test( prompts, max_tokens, num_logprobs=num_logprobs, - images=batch_make_image_embeddings( - images, processor, vllm_model) if images else None, - videos=batch_make_video_embeddings( - videos, processor, vllm_model) if videos else None) + images=batch_make_image_embeddings(images, processor, vllm_model) + if images + else None, + videos=batch_make_video_embeddings(videos, processor, vllm_model) + if videos + else None, + ) for prompts, images, videos in inputs ] - for outputs_for_original_input, \ - outputs_for_embeddings_input \ - in zip(outputs_per_case_for_original_input, - outputs_per_case_for_embeddings_input): + for outputs_for_original_input, outputs_for_embeddings_input in zip( + outputs_per_case_for_original_input, outputs_per_case_for_embeddings_input + ): check_logprobs_close( outputs_0_lst=outputs_for_original_input, outputs_1_lst=outputs_for_embeddings_input, @@ -328,18 +346,26 @@ def run_embedding_input_test( @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, - size_factors, dtype: str, - 
max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_image_embeddings_input( + vllm_runner, + image_assets, + model, + size_factors, + dtype, + max_tokens, + num_logprobs, + monkeypatch, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: list[tuple[ - list[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( [prompt for _ in size_factors], [rescale_image_size(image, factor) for factor in size_factors], [], - ) for image, prompt in zip(images, IMAGE_PROMPTS)] + ) + for image, prompt in zip(images, IMAGE_PROMPTS) + ] run_embedding_input_test( vllm_runner, @@ -370,21 +396,27 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, - model, size_factors, - dtype: str, max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_multiple_image_embeddings_input( + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_case: list[tuple[list[str], PromptImageInput, - PromptVideoInput]] = [( - [MULTIIMAGE_PROMPT for _ in size_factors], - [[ - rescale_image_size(image, factor) - for image in images - ] for factor in size_factors], - [], - )] + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( + [MULTIIMAGE_PROMPT for _ in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], + [], + ) + ] run_embedding_input_test( vllm_runner, @@ -414,22 +446,29 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, - size_factors, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_qwen2_vl_video_embeddings_input( + vllm_runner, + video_assets, + model, + size_factors, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: num_frames = 4 sampled_vids = [ sample_frames_from_video(asset.np_ndarrays, num_frames) for asset in video_assets ] - inputs_per_case: list[tuple[ - list[str], PromptImageInput, PromptVideoInput]] = [( + inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [ + ( [prompt for _ in size_factors], [], [rescale_video_size(video, factor) for factor in size_factors], - ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] + ) + for video, prompt in zip(sampled_vids, VIDEO_PROMPTS) + ] run_embedding_input_test( vllm_runner, diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index e7e7bd3154a11..6bfec6c2c8d30 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -15,12 +15,12 @@ from ...registry import HF_EXAMPLE_MODELS MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" -AUDIO_PROMPTS = AUDIO_ASSETS.prompts({ - "mary_had_lamb": - "Transcribe this into English.", - "winning_call": - "What is happening in this audio clip?", -}) +AUDIO_PROMPTS = AUDIO_ASSETS.prompts( + { + "mary_had_lamb": 
"Transcribe this into English.", + "winning_call": "What is happening in this audio clip?", + } +) MULTI_AUDIO_PROMPT = "Describe each of the audios above." @@ -33,7 +33,7 @@ CHUNKED_PREFILL_KWARGS = { "enable_chunked_prefill": True, "max_num_seqs": 2, # Use a very small limit to exercise chunked prefill. - "max_num_batched_tokens": 16 + "max_num_batched_tokens": 16, } @@ -43,27 +43,33 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: for key, value in params_kwargs.items(): if isinstance(value, bool): if value: - args.append(f"--{key.replace('_','-')}") + args.append(f"--{key.replace('_', '-')}") else: - args.append(f"--{key.replace('_','-')}={value}") + args.append(f"--{key.replace('_', '-')}={value}") return args -@pytest.fixture(params=[ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) +@pytest.fixture( + params=[ + pytest.param({}, marks=pytest.mark.cpu_model), + pytest.param(CHUNKED_PREFILL_KWARGS), + ] +) def server(request, audio_assets: AudioTestAssets): args = [ - "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", "--limit-mm-per-prompt", - json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" + json.dumps({"audio": len(audio_assets)}), + "--trust-remote-code", ] + params_kwargs_to_cli_args(request.param) - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"} + ) as remote_server: yield remote_server @@ -77,12 +83,11 @@ def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) placeholder = f"{placeholder}\n" * audio_count - return tokenizer.apply_chat_template([{ - 'role': 'user', - 'content': f"{placeholder}{question}" - }], - tokenize=False, - add_generation_prompt=True) + return tokenizer.apply_chat_template( + [{"role": "user", "content": f"{placeholder}{question}"}], + tokenize=False, + add_generation_prompt=True, + ) def run_multi_audio_test( @@ -99,19 +104,21 @@ def run_multi_audio_test( model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - with vllm_runner(model, - dtype=dtype, - enforce_eager=True, - limit_mm_per_prompt={ - "audio": - max((len(audio) for _, audio in prompts_and_audios)) - }, - **kwargs) as vllm_model: + with vllm_runner( + model, + dtype=dtype, + enforce_eager=True, + limit_mm_per_prompt={ + "audio": max((len(audio) for _, audio in prompts_and_audios)) + }, + **kwargs, + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, num_logprobs=num_logprobs, - audios=[audios for _, audios in prompts_and_audios]) + audios=[audios for _, audios in prompts_and_audios], + ) # The HuggingFace model doesn't support multiple audios yet, so # just assert that some tokens were generated. 
@@ -122,21 +129,25 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("vllm_kwargs", [ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, num_logprobs: int, - vllm_kwargs: dict) -> None: - - vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, - VLLM_PLACEHOLDER) +@pytest.mark.parametrize( + "vllm_kwargs", + [ + pytest.param({}, marks=pytest.mark.cpu_model), + pytest.param(CHUNKED_PREFILL_KWARGS), + ], +) +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, + vllm_kwargs: dict, +) -> None: + vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -149,28 +160,25 @@ def test_models_with_multiple_audios(vllm_runner, async def test_online_serving(client, audio_assets: AudioTestAssets): """Exercises online serving with/without chunked prefill enabled.""" - messages = [{ - "role": - "user", - "content": [ - *[{ - "type": "audio_url", - "audio_url": { - "url": audio.url - } - } for audio in audio_assets], - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *[ + {"type": "audio_url", "audio_url": {"url": audio.url}} + for audio in audio_assets + ], + { + "type": "text", + "text": f"What's happening in these {len(audio_assets)} audio clips?", # noqa: E501 + }, + ], + } + ] - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index b4439dfe020c2..18a50c3a555da 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -6,8 +6,8 @@ import json import pytest import pytest_asyncio from mistral_common.audio import Audio -from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, - TextChunk, UserMessage) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -17,8 +17,12 @@ from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507" MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] @@ -30,10 +34,9 @@ def server(request, audio_assets: AudioTestAssets): json.dumps({"audio": len(audio_assets)}), ] + MISTRAL_FORMAT_ARGS - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as 
remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"} + ) as remote_server: yield remote_server @@ -64,15 +67,17 @@ def _get_prompt(audio_assets, question): @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -92,23 +97,17 @@ async def test_online_serving(client, audio_assets: AudioTestAssets): return audio_dict audio_chunks = [asset_to_chunk(asset) for asset in audio_assets] - messages = [{ - "role": - "user", - "content": [ - *audio_chunks, - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" - }, - ], - }] + text = f"What's happening in these {len(audio_assets)} audio clips?" + messages = [ + { + "role": "user", + "content": [*audio_chunks, {"type": "text", "text": text}], + } + ] - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4a65e8c95204e..766f09b0d3207 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -12,8 +12,7 @@ from ....utils import create_new_process_for_each_test, multi_gpu_test PROMPTS = [ { - "prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", "multi_modal_data": { "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, }, @@ -25,9 +24,8 @@ PROMPTS = [ "audio": AudioAsset("winning_call").audio_and_sample_rate, }, }, - "decoder_prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - } + "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + }, ] EXPECTED = { @@ -41,7 +39,7 @@ EXPECTED = { " is June and the third base. They're going to wave him in. The throw" " to the plate will be late. The Mariners are going to play for the" " American League Championship. I don't believe it. It just continues" - " by all five." + " by all five.", ], "openai/whisper-small": [ " The first words I spoke in the original pornograph. A little piece" @@ -51,7 +49,7 @@ EXPECTED = { " comes joy. Here is Junior to third base. They're gonna wave him" " in. The throw to the plate will be late. The Mariners are going to" " play for the American League Championship. I don't believe it. It" - " just continues. My, oh my." + " just continues. My, oh my.", ], "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" @@ -62,7 +60,7 @@ EXPECTED = { " Jorgen at third base. They're going to wave him in. 
The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh" - " my." + " my.", ], "openai/whisper-large-v3": [ " The first words I spoke in the original phonograph, a little piece" @@ -73,7 +71,7 @@ EXPECTED = { " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." + " my.", ], "openai/whisper-large-v3-turbo": [ " The first words I spoke in the original phonograph, a little piece" @@ -84,8 +82,8 @@ EXPECTED = { " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." - ] + " my.", + ], } @@ -100,11 +98,11 @@ def run_test( expected_list = EXPECTED[model] * 10 with vllm_runner( - model, - dtype="half", - max_model_len=448, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, + model, + dtype="half", + max_model_len=448, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: llm = vllm_model.llm @@ -122,8 +120,7 @@ def run_test( @pytest.mark.core_model -@pytest.mark.parametrize( - "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @create_new_process_for_each_test() def test_models(vllm_runner, model) -> None: run_test( diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 03c08240d6a81..096931cca09f7 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Helpers for building inputs that can be leveraged for different test types. 
-""" +"""Helpers for building inputs that can be leveraged for different test types.""" + from collections.abc import Iterable from pathlib import PosixPath from typing import Callable, Optional, Union @@ -10,20 +10,30 @@ import torch from vllm.multimodal.audio import AudioResampler from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets -from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS, - TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER, - TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, - ImageSizeWrapper, PromptWithMultiModalInput, SizeType, - VLMTestInfo) +from .types import ( + SINGLE_AUDIO_BASE_PROMPT, + SINGLE_IMAGE_BASE_PROMPTS, + TEST_AUDIO_PLACEHOLDER, + TEST_IMG_PLACEHOLDER, + TEST_VIDEO_PLACEHOLDER, + VIDEO_BASE_PROMPT, + ImageSizeWrapper, + PromptWithMultiModalInput, + SizeType, + VLMTestInfo, +) -def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], - str], - test_placeholder: str) -> str: +def replace_test_placeholder( + prompt: str, mm_idx_to_prompt: Callable[[int], str], test_placeholder: str +) -> str: """Given a prompt, replaces each test placeholder with the model-specific tag. """ @@ -35,11 +45,13 @@ def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], return img_prompt -def get_model_prompts(base_prompts: Iterable[str], - img_idx_to_prompt: Optional[Callable[[int], str]], - video_idx_to_prompt: Optional[Callable[[int], str]], - audio_idx_to_prompt: Optional[Callable[[int], str]], - prompt_formatter: Callable[[str], str]) -> list[str]: +def get_model_prompts( + base_prompts: Iterable[str], + img_idx_to_prompt: Optional[Callable[[int], str]], + video_idx_to_prompt: Optional[Callable[[int], str]], + audio_idx_to_prompt: Optional[Callable[[int], str]], + prompt_formatter: Callable[[str], str], +) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting to get the test prompt string for this model. 
@@ -56,19 +68,19 @@ def get_model_prompts(base_prompts: Iterable[str], # Replace the multimodal placeholders in the base prompt with # the correct ones for the model that we are testing if img_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - img_idx_to_prompt, - TEST_IMG_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, img_idx_to_prompt, TEST_IMG_PLACEHOLDER + ) if video_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - video_idx_to_prompt, - TEST_VIDEO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, video_idx_to_prompt, TEST_VIDEO_PLACEHOLDER + ) if audio_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - audio_idx_to_prompt, - TEST_AUDIO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, audio_idx_to_prompt, TEST_AUDIO_PLACEHOLDER + ) # Apply the prompt formatter to wrap the base prompt with # the correct media placeholders to get the model test prompt @@ -84,14 +96,15 @@ def build_single_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build single image inputs") + raise ValueError("Prompt formatter must be set to build single image inputs") - model_prompts = get_model_prompts(test_info.single_image_prompts, - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + test_info.single_image_prompts, + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) # For models that require a local path / URL encoded in the image; export # assets and encode into tmp_path for this test. 
This should be avoided @@ -110,8 +123,8 @@ def build_single_image_inputs_from_test_info( def build_single_image_inputs( - images, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + images, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: # For every image / prompt pair, get a pair containing two lists of # length size_factors, where the first contains duplicates of the model # prompt [str], and the second contains copies of the image after being @@ -125,7 +138,8 @@ def build_single_image_inputs( apply_image_size_scaling(image, size, size_wrapper.type) for size in size_wrapper.data ], - ) for image, prompt in zip(images, model_prompts) + ) + for image, prompt in zip(images, model_prompts) ] @@ -136,14 +150,15 @@ def build_multi_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build multi image inputs") + raise ValueError("Prompt formatter must be set to build multi image inputs") - model_prompts = get_model_prompts([test_info.multi_image_prompt], - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + [test_info.multi_image_prompt], + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) if test_info.prompt_path_encoder is not None: if tmp_path is None: @@ -164,16 +179,20 @@ def build_multi_image_inputs_from_test_info( def build_multi_image_inputs( - image_lists, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + image_lists, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - image_data=[[ - apply_image_size_scaling(image, size, size_wrapper.type) - for image in images - ] for size in size_wrapper.data], - ) for images, prompt in zip(image_lists, model_prompts) + image_data=[ + [ + apply_image_size_scaling(image, size, size_wrapper.type) + for image in images + ] + for size in size_wrapper.data + ], + ) + for images, prompt in zip(image_lists, model_prompts) ] @@ -185,10 +204,10 @@ def build_embedding_inputs_from_test_info( # These conditions will always be true if invoked through filtering, # but we still check them in case this is ever called directly if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build image embedding inputs") - if size_wrapper.type != SizeType.SIZE_FACTOR or not \ - all(factor == 1.0 for factor in size_wrapper.data): + raise ValueError("Prompt formatter must be set to build image embedding inputs") + if size_wrapper.type != SizeType.SIZE_FACTOR or not all( + factor == 1.0 for factor in size_wrapper.data + ): raise ValueError("Embedding tests require constant (1.0) size factors") if test_info.convert_assets_to_embeddings is None: raise ValueError("No conversion func for getting embeddings found") @@ -209,8 +228,7 @@ def build_embedding_inputs_from_test_info( assert len(images) == len(model_prompts) inputs = build_single_image_inputs(images, model_prompts, size_wrapper) - vllm_embeddings = build_single_image_inputs(embeds, model_prompts, - size_wrapper) + vllm_embeddings = build_single_image_inputs(embeds, model_prompts, size_wrapper) return inputs, vllm_embeddings @@ -235,22 
+253,23 @@ def build_video_inputs_from_test_info( for asset in video_assets ] - video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE - else rescale_video_size) + video_scaler = ( + resize_video if size_wrapper.type == SizeType.FIXED_SIZE else rescale_video_size + ) return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - video_data=[ - video_scaler(video, size) for size in size_wrapper.data - ], - ) for video, prompt in zip(sampled_vids, model_prompts) + video_data=[video_scaler(video, size) for size in size_wrapper.data], + ) + for video, prompt in zip(sampled_vids, model_prompts) ] -def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], - size_type: SizeType): - """Applies a size scaler to one image; this can be a an image size factor, +def apply_image_size_scaling( + image, size: Union[float, tuple[int, int]], size_type: SizeType +): + """Applies a size scaler to one image; this can be an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we # are considering size factors at constant scale, i.e., we just clone @@ -285,13 +304,16 @@ def build_audio_inputs_from_test_info( method="librosa", ) audios = [asset.audio_and_sample_rate for asset in audio_assets] - resampled_audios = [( - resampler.resample( - audio, - orig_sr=sr, - ), - int(resampler.target_sr), - ) for audio, sr in audios] + resampled_audios = [ + ( + resampler.resample( + audio, + orig_sr=sr, + ), + int(resampler.target_sr), + ) + for audio, sr in audios + ] return [ PromptWithMultiModalInput( diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 336e2dd2b1201..77e478e53c1fd 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -4,19 +4,28 @@ modality, getting all combinations (similar to pytest's parametrization), handling multimodal placeholder substitution, and so on. """ + import itertools from collections import OrderedDict from collections.abc import Iterable import pytest -from .types import (EMBEDDING_SIZE_FACTORS, ExpandableVLMTestArgs, - ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType) +from .types import ( + EMBEDDING_SIZE_FACTORS, + ExpandableVLMTestArgs, + ImageSizeWrapper, + SizeType, + VLMTestInfo, + VLMTestType, +) def get_filtered_test_settings( - test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, - new_proc_per_test: bool) -> dict[str, VLMTestInfo]: + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + new_proc_per_test: bool, +) -> dict[str, VLMTestInfo]: """Given the dict of potential test settings to run, return a subdict of tests who have the current test type enabled with the matching val for fork_per_test. 
@@ -25,7 +34,8 @@ def get_filtered_test_settings( def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): return test_info.test_type == test_type or ( isinstance(test_info.test_type, Iterable) - and test_type in test_info.test_type) + and test_type in test_info.test_type + ) matching_tests = {} for test_name, test_info in test_settings.items(): @@ -36,68 +46,74 @@ def get_filtered_test_settings( assert test_info.convert_assets_to_embeddings is not None # Custom test inputs need to explicitly define the mm limit/inputs if matches_test_type(test_info, VLMTestType.CUSTOM_INPUTS): - assert (test_info.custom_test_opts is not None - and isinstance(test_info.custom_test_opts, Iterable)) + assert test_info.custom_test_opts is not None and isinstance( + test_info.custom_test_opts, Iterable + ) # For all types besides custom inputs, we need a prompt formatter else: assert test_info.prompt_formatter is not None - # Everything looks okay; keep if this is has correct proc handling - if (test_info.distributed_executor_backend - is not None) == new_proc_per_test: + # Everything looks okay; keep if this is correct proc handling + if ( + test_info.distributed_executor_backend is not None + ) == new_proc_per_test: matching_tests[test_name] = test_info return matching_tests -def get_parametrized_options(test_settings: dict[str, VLMTestInfo], - test_type: VLMTestType, - create_new_process_for_each_test: bool): +def get_parametrized_options( + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + create_new_process_for_each_test: bool, +): """Converts all of our VLMTestInfo into an expanded list of parameters. This is similar to nesting pytest parametrize calls, but done directly through an itertools product so that each test can set things like size factors etc, while still running in isolated test cases. """ matching_tests = get_filtered_test_settings( - test_settings, test_type, create_new_process_for_each_test) + test_settings, test_type, create_new_process_for_each_test + ) # Ensure that something is wrapped as an iterable it's not already - ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e, ) + ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) def get_model_type_cases(model_type: str, test_info: VLMTestInfo): # This is essentially the same as nesting a bunch of mark.parametrize # decorators, but we do it programmatically to allow overrides for on # a per-model basis, while still being able to execute each of these # as individual test cases in pytest. 
- iter_kwargs = OrderedDict([ - ("model", ensure_wrapped(test_info.models)), - ("max_tokens", ensure_wrapped(test_info.max_tokens)), - ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), - ("dtype", ensure_wrapped(test_info.dtype)), - ("distributed_executor_backend", - ensure_wrapped(test_info.distributed_executor_backend)), - ]) + iter_kwargs = OrderedDict( + [ + ("model", ensure_wrapped(test_info.models)), + ("max_tokens", ensure_wrapped(test_info.max_tokens)), + ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), + ("dtype", ensure_wrapped(test_info.dtype)), + ( + "distributed_executor_backend", + ensure_wrapped(test_info.distributed_executor_backend), + ), + ] + ) # num_frames is video only if test_type == VLMTestType.VIDEO: - iter_kwargs["num_video_frames"] = ensure_wrapped( - test_info.num_video_frames) + iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) # No sizes passed for custom inputs, since inputs are directly provided if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) if wrapped_sizes is None: - raise ValueError( - f"Sizes must be set for test type {test_type}") + raise ValueError(f"Sizes must be set for test type {test_type}") iter_kwargs["size_wrapper"] = wrapped_sizes - #Otherwise expand the custom test options instead + # Otherwise expand the custom test options instead elif test_type == VLMTestType.CUSTOM_INPUTS: if test_info.custom_test_opts is None: raise ValueError("Test has type CUSTOM_INPUTS, but none given") iter_kwargs["custom_test_opts"] = test_info.custom_test_opts - # yapf: disable # Wrap all model cases in a pytest parameter & pass marks through return [ pytest.param( @@ -105,10 +121,10 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo], ExpandableVLMTestArgs( **{k: v for k, v in zip(iter_kwargs.keys(), case)} ), - marks=test_info.marks if test_info.marks is not None else [] - ) for case in list(itertools.product(*iter_kwargs.values())) + marks=test_info.marks if test_info.marks is not None else [], + ) + for case in list(itertools.product(*iter_kwargs.values())) ] - # yapf: enable # Get a list per model type, where each entry contains a tuple of all of # that model type's cases, then flatten them into the top level so that @@ -121,8 +137,8 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo], def get_wrapped_test_sizes( - test_info: VLMTestInfo, - test_type: VLMTestType) -> tuple[ImageSizeWrapper, ...]: + test_info: VLMTestInfo, test_type: VLMTestType +) -> tuple[ImageSizeWrapper, ...]: """Given a test info which may have size factors or fixed sizes, wrap them and combine them into an iterable, each of which will be used in parameter expansion. 
@@ -133,18 +149,18 @@ def get_wrapped_test_sizes( """ # If it is an embedding test, we always use the EMBEDDING_SIZE_FACTORS if test_type == VLMTestType.EMBEDDING: - return tuple([ - ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) - for factor in EMBEDDING_SIZE_FACTORS - ]) + return tuple( + [ + ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) + for factor in EMBEDDING_SIZE_FACTORS + ] + ) # Audio and Custom inputs have preprocessed inputs elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS): return tuple() - size_factors = test_info.image_size_factors \ - if test_info.image_size_factors else [] - fixed_sizes = test_info.image_sizes \ - if test_info.image_sizes else [] + size_factors = test_info.image_size_factors if test_info.image_size_factors else [] + fixed_sizes = test_info.image_sizes if test_info.image_sizes else [] wrapped_factors = [ ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) @@ -152,8 +168,7 @@ def get_wrapped_test_sizes( ] wrapped_sizes = [ - ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) - for size in fixed_sizes + ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) for size in fixed_sizes ] return tuple(wrapped_factors + wrapped_sizes) diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index a5d6948f06efd..5748ccc14c294 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Core test implementation to be shared across modalities.""" + from typing import Any, Callable, Optional import torch from transformers.models.auto.auto_factory import _BaseAutoModelClass -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import HfRunner, VllmRunner @@ -42,7 +43,7 @@ def run_test( tensor_parallel_size: int = 1, vllm_embeddings: Optional[torch.Tensor] = None, ): - """Modality agnostic test test executor for comparing HF/vLLM outputs.""" + """Modality agnostic test executor for comparing HF/vLLM outputs.""" # In the case of embeddings, vLLM takes separate input tensors vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs @@ -69,20 +70,24 @@ def run_test( vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode if model_info.hf_overrides: vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides + if model_info.skip_tokenizer_init: + vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) - with vllm_runner(model, - max_model_len=max_model_len, - max_num_seqs=max_num_seqs, - dtype=dtype, - limit_mm_per_prompt=limit_mm_per_prompt, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=enforce_eager, - runner=runner, - **vllm_runner_kwargs_) as vllm_model: + with vllm_runner( + model, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + dtype=dtype, + limit_mm_per_prompt=limit_mm_per_prompt, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=enforce_eager, + runner=runner, + **vllm_runner_kwargs_, + ) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() vllm_kwargs: dict[str, Any] = {} @@ -92,21 +97,19 @@ 
def run_test( vllm_kwargs["stop"] = stop_str for prompts, image_data, video_data, audio_data in vllm_inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) vllm_kwargs_with_mm_data = vllm_kwargs | mm_data vllm_output = vllm_model.generate_greedy_logprobs( prompts, max_tokens, num_logprobs=num_logprobs, - **vllm_kwargs_with_mm_data) + **vllm_kwargs_with_mm_data, + ) vllm_outputs_per_mm.append(vllm_output) - hf_model = hf_runner(model, - dtype=dtype, - auto_cls=auto_cls, - model_kwargs=hf_model_kwargs) + hf_model = hf_runner( + model, dtype=dtype, auto_cls=auto_cls, model_kwargs=hf_model_kwargs + ) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: @@ -126,16 +129,15 @@ def run_test( hf_kwargs["stop_strings"] = stop_str for prompts, image_data, video_data, audio_data in inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) hf_kwargs_with_mm_data = hf_kwargs | mm_data hf_output = hf_model.generate_greedy_logprobs_limit( prompts, max_tokens, num_logprobs=num_logprobs, tokenizer=tokenizer, - **hf_kwargs_with_mm_data) + **hf_kwargs_with_mm_data, + ) hf_outputs_per_mm.append(hf_output) # Apply output processing / sanitation to the vLLM and HF runner results @@ -147,8 +149,7 @@ def run_test( second_runner_processor=vllm_output_post_proc, ) - for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, - vllm_outputs_per_mm): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, vllm_outputs_per_mm): # This is usually check_logprobs_close, but it's passed through to # allow things like check_outputs_equal where needed comparator( @@ -168,15 +169,19 @@ def process_runner_outputs( ): """Applies the runner processor(s) to the runner outputs, if any.""" if first_runner_processor is not None: - first_runner_outputs = process_outputs(first_runner_processor, model, - first_runner_outputs) + first_runner_outputs = process_outputs( + first_runner_processor, model, first_runner_outputs + ) if second_runner_processor is not None: - second_runner_outputs = process_outputs(second_runner_processor, model, - second_runner_outputs) + second_runner_outputs = process_outputs( + second_runner_processor, model, second_runner_outputs + ) return first_runner_outputs, second_runner_outputs def process_outputs(output_processor, model, outputs_per_image): """Applies a model specific post-processor function to a runner's output""" - return [[output_processor(res, model) for res in outputs] - for outputs in outputs_per_image] + return [ + [output_processor(res, model) for res in outputs] + for outputs in outputs_per_image + ] diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index c53243b42e384..8f2f8bba39ca2 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -1,15 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" -from io import BytesIO + from typing import Callable -import requests -from PIL import Image - +from vllm.assets.image import ImageAsset from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, 
resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .builders import build_multi_image_inputs, build_single_image_inputs @@ -18,7 +19,7 @@ from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): """Builds inputs for multi-image (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. """ @@ -44,7 +45,7 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): stop_sign, rescale_image_size(stop_sign, 0.25), cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) + cherry_blossom.resize((488, 183)), ], cherry_blossom, ] @@ -57,10 +58,11 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): ] -def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], - num_frames: int = 16): +def multi_video_multi_aspect_ratio_inputs( + formatter: Callable[[str], str], num_frames: int = 16 +): """Builds inputs for multi-video (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. """ @@ -84,7 +86,7 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], video, rescale_video_size(video, 0.25), resize_video(video, (183, 488)), - resize_video(video, (488, 183)) + resize_video(video, (488, 183)), ], video, ] @@ -99,7 +101,9 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], def different_patch_input_cases_internvl(): images = [asset.pil_image.resize((896, 896)) for asset in IMAGE_ASSETS] - formatter = lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 + formatter = ( + lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 + ) single_img_prompts = [ "<image>\nWhat's the content in the center of the image?", "<image>\nWhat is the season?", @@ -118,14 +122,14 @@ def different_patch_input_cases_internvl(): def windows_attention_image_qwen2_5_vl(): - # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 - image_url = "https://aomediacodec.github.io/av1-avif/testFiles/Link-U/hato.jpg" - image = Image.open(BytesIO(requests.get(image_url).content)) + # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 # noqa: E501 + image = ImageAsset("hato").pil_image question = "Describe the image." 
img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" - prompt = (f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n<|im_start|>assistant\n" + ) wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) @@ -139,8 +143,9 @@ def video_with_metadata_glm4_1v(): formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] - video_input = [[(rescale_video_size(video_array, scale), metadata)] - for scale in scales] + video_input = [ + [(rescale_video_size(video_array, scale), metadata)] for scale in scales + ] prompts = [formatted_prompt] * len(video_input) return [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 8b7d051218f14..e51d895772c05 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -4,6 +4,7 @@ for manipulating the input / output of HF & vLLM test runners, which are typically specific to a small subset of models. """ + import types from pathlib import PosixPath from typing import Optional, Union @@ -15,11 +16,16 @@ import pytest import regex as re import torch from PIL.Image import Image -from transformers import (AutoConfig, AutoTokenizer, BatchFeature, - GenerationConfig, GenerationMixin) +from transformers import ( + AutoConfig, + AutoTokenizer, + BatchFeature, + GenerationConfig, + GenerationMixin, +) from transformers.video_utils import VideoMetadata -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -27,8 +33,7 @@ from .types import RunnerOutput ####### vLLM output processors functions -def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [blip2 models] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -42,8 +47,7 @@ def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [fuyu models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -53,8 +57,8 @@ def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, def qwen_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -64,8 +68,8 @@ def qwen_vllm_to_hf_output( def qwen2_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -75,8 +79,8 @@ def qwen2_vllm_to_hf_output( def kimiv_vl_vllm_to_hf_output( - 
vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -85,23 +89,25 @@ def kimiv_vl_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs -def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_image_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: config = AutoConfig.from_pretrained(model) mm_token_id = config.image_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) def llava_video_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: config = AutoConfig.from_pretrained(model) mm_token_id = config.video_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) -def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, - mm_token_id: int) -> RunnerOutput: +def _llava_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str, mm_token_id: int +) -> RunnerOutput: """Sanitize vllm output [Llava models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -109,7 +115,8 @@ def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != mm_token_id or output_ids[idx - 1] != mm_token_id ] @@ -128,8 +135,9 @@ def llava_onevision_hf_model_kwargs(model: str) -> dict: return config.to_dict() -def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_onevision_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: """Sanitize vllm output [llava-onevision] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -140,7 +148,8 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != video_token_id or output_ids[idx - 1] != video_token_id ] @@ -151,8 +160,7 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [mantis] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -161,8 +169,7 @@ def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, return output_ids, hf_output_str, out_logprobs -def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -180,8 +187,7 @@ def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, - model: 
str) -> RunnerOutput: +def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -192,7 +198,8 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] @@ -205,46 +212,40 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|end▁of▁sentence|>"): output_str = output_str.split("<|end▁of▁sentence|>")[0] return output_ids, output_str, out_logprobs -def idefics3_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def idefics3_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<end_of_utterance>"): output_str = output_str.split("<end_of_utterance>")[0] return output_ids, output_str, out_logprobs -def smolvlm_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def smolvlm_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: # Based on Idefics3 return idefics3_trunc_hf_output(hf_output, model) -def minicpmv_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): output_str = output_str.split("<|eot_id|>")[0] return output_ids, output_str, out_logprobs -def minimax_vl_01_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minimax_vl_01_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<end_of_sentence>"): output_str = output_str.split("<end_of_sentence>")[0] return output_ids, output_str, out_logprobs -def ultravox_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def ultravox_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output tokenizer = AutoTokenizer.from_pretrained(model) @@ -262,8 +263,8 @@ def get_llava_embeddings(image_assets: ImageTestAssets): ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, - assets: Union[list[ImageAsset], ImageTestAssets]) -> str: + tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], ImageTestAssets] +) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image in its @@ -313,8 +314,9 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return BatchFeature(data=inputs, tensor_type="pt") hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language.model.embed_tokens + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language.model.embed_tokens + ) 
return hf_model @@ -340,6 +342,29 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: + """Sanitize vllm output [gemma-3] to compare with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = [ + token_id + for idx, token_id in enumerate(output_ids) + if token_id != image_token_id + ] + + hf_output_str = output_str + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" hf_processor = hf_model.processor @@ -357,11 +382,10 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: assert len(contents) == len(images) return hf_processor.apply_chat_template( - [{ - "role": "user", - "image": image, - "content": content - } for image, content in zip(images, contents)], + [ + {"role": "user", "image": image, "content": content} + for image, content in zip(images, contents) + ], add_generation_prompt=True, tokenize=True, return_dict=True, @@ -369,8 +393,9 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: ) hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.transformer.output_layer + ) return hf_model @@ -387,10 +412,9 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: else: video_metadata = None - return hf_processor(*args, - videos=videos, - video_metadata=video_metadata, - **kwargs) + return hf_processor( + *args, videos=videos, video_metadata=video_metadata, **kwargs + ) hf_model.processor = processor return hf_model @@ -406,8 +430,9 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.use_msac = self.config.use_msac @@ -415,13 +440,14 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): - # yapf: disable + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.h2ovl import ( - IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_h2ovl, + ) - # yapf: enable images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_h2ovl( @@ -431,29 +457,26 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: max_num=self.max_num, use_thumbnail=self.use_thumbnail, use_msac=self.use_msac, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] 
for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = H2OVLProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -467,19 +490,23 @@ def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.skyworkr1v import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_skyworkr1v) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_skyworkr1v, + ) + images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_skyworkr1v( @@ -488,29 +515,26 @@ def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = SkyworkR1VProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = 
types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -524,8 +548,9 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch @@ -540,8 +565,13 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: **kwargs, ): from vllm.model_executor.models.internvl import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_internvl, video_to_pixel_values_internvl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_internvl, + video_to_pixel_values_internvl, + ) + images = [images] if isinstance(images, Image) else images videos = [videos] if isinstance(videos, np.ndarray) else videos if images is not None: @@ -552,7 +582,8 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] num_patches_images = [ pixel_value.shape[0] for pixel_value in pixel_values_images @@ -568,7 +599,8 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: min_num=1, max_num=1, use_thumbnail=False, - ) for video in videos + ) + for video in videos ] num_patches_videos = [ pixel_value.shape[0] for pixel_value in pixel_values_videos @@ -580,38 +612,37 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: while ("<image>" in text) or ("<video>" in text): image_index = text.find("<image>") video_index = text.find("<video>") - if image_index == -1 or (video_index > -1 - and video_index < image_index): + if image_index == -1 or ( + video_index > -1 and video_index < image_index + ): num_patches = num_patches_videos.pop(0) pixel_values.append(pixel_values_videos.pop(0)) - context_tokens = IMG_START + \ - IMG_CONTEXT * self.num_image_token + IMG_END - video_tokens = ''.join([ - f'Frame{i+1}: {context_tokens}' - for i in range(num_patches) - ]) - text = text.replace('<video>', video_tokens, 1) + context_tokens = ( + IMG_START + IMG_CONTEXT * self.num_image_token + IMG_END + ) + video_tokens = "".join( + [f"Frame{i + 1}: {context_tokens}" for i in range(num_patches)] + ) + text = text.replace("<video>", video_tokens, 1) else: num_patches = num_patches_images.pop(0) pixel_values.append(pixel_values_images.pop(0)) - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) pixel_values = torch.cat(pixel_values, dim=0) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = InternVLProcessor(hf_model) - 
hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -631,7 +662,7 @@ def _internvl_generate( input_embeds = input_embeds.reshape(B * N, C) input_ids = input_ids.reshape(B * N) - selected = (input_ids == self.img_context_token_id) + selected = input_ids == self.img_context_token_id assert selected.sum() != 0 input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) @@ -778,8 +809,9 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.llm.get_output_embeddings() + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.llm.get_output_embeddings() + ) def processor(*args, text="", images=None, **kwargs): text_tokenizer = hf_model.model.get_text_tokenizer() @@ -787,8 +819,7 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: prompt_start_and_end = { "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), - "llama": - ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), } for start, end in prompt_start_and_end.values(): @@ -797,7 +828,8 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: break prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( - text_or_conversations=text, images=images) + text_or_conversations=text, images=images + ) attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id) inputs = { @@ -813,8 +845,9 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.llm.get_output_embeddings() + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.llm.get_output_embeddings() + ) def processor(*args, text="", images=None, videos=None, **kwargs): if images is None: @@ -825,13 +858,11 @@ def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: videos = [] else: videos = [videos] if isinstance(videos, np.ndarray) else videos - videos = [[PIL.Image.fromarray(frame) for frame in vid] - for vid in videos] + videos = [[PIL.Image.fromarray(frame) for frame in vid] for vid in videos] prompt_start_and_end = { "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), - "llama": - ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), } for start, end in prompt_start_and_end.values(): @@ -842,21 +873,20 @@ def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: images_message = [{"type": "image", "image": img} for img in images] videos_message = [{"type": "video", "video": vid} for vid in videos] - messages = [{ - "role": - "user", - "content": [ - *images_message, - *videos_message, - { - "type": "text", - "text": text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + 
*images_message, + *videos_message, + {"type": "text", "text": text}, + ], + } + ] input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs( - messages=messages, enable_thinking=True) + messages=messages, enable_thinking=True + ) inputs = { "inputs": input_ids, "pixel_values": pixel_values, diff --git a/tests/models/multimodal/generation/vlm_utils/runners.py b/tests/models/multimodal/generation/vlm_utils/runners.py index 562f89df13470..c91ae117b5589 100644 --- a/tests/models/multimodal/generation/vlm_utils/runners.py +++ b/tests/models/multimodal/generation/vlm_utils/runners.py @@ -3,23 +3,34 @@ """Entrypoints for wrapping the core run_test implementation for specific test types / modalities. """ + from pathlib import PosixPath -from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets, - VideoTestAssets, VllmRunner) +from .....conftest import ( + AudioTestAssets, + HfRunner, + ImageTestAssets, + VideoTestAssets, + VllmRunner, +) from . import builders, core from .types import ExpandableVLMTestArgs, VLMTestInfo ####### Entrypoints for running different test types -def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_single_image_test( + *, + tmp_path: PosixPath, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs = builders.build_single_image_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper, tmp_path) + model_test_info, image_assets, test_case.size_wrapper, tmp_path + ) core.run_test( hf_runner=hf_runner, @@ -31,17 +42,23 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": 1}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_multi_image_test( + *, + tmp_path: PosixPath, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs = builders.build_multi_image_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper, tmp_path) + model_test_info, image_assets, test_case.size_wrapper, tmp_path + ) core.run_test( hf_runner=hf_runner, @@ -53,17 +70,22 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": len(image_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_embedding_test(*, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_embedding_test( + *, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + 
hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper) + model_test_info, image_assets, test_case.size_wrapper + ) core.run_test( hf_runner=hf_runner, @@ -76,7 +98,8 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, limit_mm_per_prompt={"image": 1}, vllm_embeddings=vllm_embeddings, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) def run_video_test( @@ -90,8 +113,11 @@ def run_video_test( assert test_case.size_wrapper is not None assert test_case.num_video_frames is not None inputs = builders.build_video_inputs_from_test_info( - model_test_info, video_assets, test_case.size_wrapper, - test_case.num_video_frames) + model_test_info, + video_assets, + test_case.size_wrapper, + test_case.num_video_frames, + ) core.run_test( hf_runner=hf_runner, @@ -103,7 +129,8 @@ def run_video_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"video": len(video_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) def run_audio_test( @@ -114,8 +141,7 @@ def run_audio_test( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, ): - inputs = builders.build_audio_inputs_from_test_info( - model_test_info, audio_assets) + inputs = builders.build_audio_inputs_from_test_info(model_test_info, audio_assets) core.run_test( hf_runner=hf_runner, @@ -127,13 +153,17 @@ def run_audio_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"audio": 1}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_custom_inputs_test(*, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner]): +def run_custom_inputs_test( + *, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], +): # Custom test cases can provide inputs directly, but they need to # explicitly provided a CustomTestConfig, which wraps the inputs and # the limit_mm_per_prompt @@ -155,4 +185,5 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt=limit_mm_per_prompt, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 9451131960885..6e82f7e3306ab 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Types for writing multimodal model tests.""" + from collections.abc import Iterable from enum import Enum from pathlib import PosixPath @@ -11,13 +12,20 @@ from pytest import MarkDecorator from transformers import AutoModelForCausalLM from 
transformers.models.auto.auto_factory import _BaseAutoModelClass -from vllm.config import RunnerOption -from vllm.sequence import SampleLogprobs +from vllm.config.model import RunnerOption +from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, - ImageTestAssets, PromptAudioInput, PromptImageInput, - PromptVideoInput) +from .....conftest import ( + AUDIO_ASSETS, + IMAGE_ASSETS, + HfRunner, + ImageAsset, + ImageTestAssets, + PromptAudioInput, + PromptImageInput, + PromptVideoInput, +) from ....utils import check_logprobs_close # meta image tag; will be replaced by the appropriate tag for the model @@ -25,28 +33,31 @@ TEST_IMG_PLACEHOLDER = "<vlm_image>" TEST_VIDEO_PLACEHOLDER = "<vlm_video>" TEST_AUDIO_PLACEHOLDER = "<lmm_audio>" -# yapf: disable -SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?", - "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?", -}) -SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({ - "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501 - "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?", # noqa: E501 -}) +SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?", + "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?", + } +) +SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts( + { + "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501 + "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?", # noqa: E501 + } +) MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n" # noqa: E501 VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?" 
-IMAGE_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] -EMBEDDING_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0)] +IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] +EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)] RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]] -# yapf: enable class PromptWithMultiModalInput(NamedTuple): """Holds the multimodal input for a single test case.""" + prompts: list[str] image_data: Optional[PromptImageInput] = None video_data: Optional[PromptVideoInput] = None @@ -100,8 +111,9 @@ class VLMTestInfo(NamedTuple): # Function for converting ImageAssets to image embeddings; # We need to define this explicitly for embedding tests - convert_assets_to_embeddings: Optional[Callable[[ImageTestAssets], - torch.Tensor]] = None + convert_assets_to_embeddings: Optional[ + Callable[[ImageTestAssets], list[torch.Tensor]] + ] = None # Exposed options for vLLM runner; we change these in a several tests, # but the defaults are derived from VllmRunner & the engine defaults @@ -137,12 +149,12 @@ class VLMTestInfo(NamedTuple): # Default expandable params per test; these defaults can be overridden in # instances of this object; the complete set of test cases for the model # is all combinations of .models + all fields below - max_tokens: Union[int, tuple[int]] = 128 - num_logprobs: Union[int, tuple[int]] = 5 - dtype: Union[str, Union[list[str], tuple[str, ...]]] = "auto" - distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None + max_tokens: int = 128 + num_logprobs: int = 5 + dtype: str = "auto" + distributed_executor_backend: Optional[str] = None # Only expanded in video tests - num_video_frames: Union[int, tuple[int]] = 16 + num_video_frames: int = 16 # Fixed image sizes / image size factors; most tests use image_size_factors # The values provided for these two fields will be stacked and expanded @@ -156,8 +168,8 @@ class VLMTestInfo(NamedTuple): # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], - str]] = None # noqa: E501 + Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], str] + ] = None # noqa: E501 # Allows configuring a test to run with custom inputs custom_test_opts: Optional[list[CustomTestOptions]] = None @@ -190,6 +202,7 @@ class VLMTestInfo(NamedTuple): class ExpandableVLMTestArgs(NamedTuple): """The expanded kwargs which correspond to a single test case.""" + model: str max_tokens: int num_logprobs: int diff --git a/tests/models/multimodal/pooling/test_clip.py b/tests/models/multimodal/pooling/test_clip.py new file mode 100644 index 0000000000000..95c678558f4fa --- /dev/null +++ b/tests/models/multimodal/pooling/test_clip.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import CLIPModel + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + "a photo of a stop sign", + "a photo of a cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "", + "cherry_blossom": "", + } +) + +MODELS = ["openai/clip-vit-base-patch32"] + + +def _run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # 
NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=77 + ) as vllm_model: + vllm_outputs = vllm_model.embed(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + inputs = hf_model.wrap_device(inputs) + + if "pixel_values" in inputs: + pooled_output = hf_model.model.get_image_features( + pixel_values=inputs.pixel_values, + ).squeeze(0) + else: + pooled_output = hf_model.model.get_text_features( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + ).squeeze(0) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text_image_no_crash( + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + texts = [HF_TEXT_PROMPTS[0]] + images = [image_assets[0].pil_image] + + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=77 + ) as vllm_model: + with pytest.raises(ValueError, match="not both"): + vllm_model.embed(texts, images=images) + + # Should still be able to run subsequent requests + vllm_model.embed(texts) + vllm_model.embed([""], images=images) diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py index f152ded3fb23a..7f30b1f299ba1 100644 --- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py +++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py @@ -17,18 +17,21 @@ HF_TEXT_PROMPTS = [ # T -> X ( "Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501, - Image.new("RGB", (56, 56))), + Image.new("RGB", (56, 56)), + ), # T -> X - ("Query: Retrieve an image of this caption: cherry blossom", - Image.new("RGB", (56, 56))), + ( + "Query: Retrieve an image of this caption: cherry blossom", + Image.new("RGB", (56, 56)), + ), ] -HF_IMAGE_PROMPTS = 
IMAGE_ASSETS.prompts({ - "stop_sign": - "What is shown in this image?", - "cherry_blossom": - "What is shown in this image?" -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "What is shown in this image?", + "cherry_blossom": "What is shown in this image?", + } +) MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"] @@ -36,34 +39,30 @@ MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"] def get_messages(image: Image.Image, text: str, embed_text: bool): # assert False, 'remember to use outer [] as required' if embed_text: - messages = [{ - "role": - "user", - "content": [ - { - "type": "image", - "image": Image.new("RGB", (56, 56)), - "resized_height": 1, - "resized_width": 1 - }, # need a dummy image here for an easier process. - { - "type": "text", - "text": text - }, - ] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": Image.new("RGB", (56, 56)), + "resized_height": 1, + "resized_width": 1, + }, # need a dummy image here for an easier process. + {"type": "text", "text": text}, + ], + } + ] else: - messages = [{ - "role": - "user", - "content": [{ - "type": "image", - "image": image - }, { - "type": "text", - "text": text - }] - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": text}, + ], + } + ] return messages @@ -71,8 +70,10 @@ def apply_chat_template_and_add_eos( messages: list[dict], apply_chat_template_fn: Callable, ): - prompt = apply_chat_template_fn( - messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>" + prompt = ( + apply_chat_template_fn(messages, tokenize=False, add_generation_prompt=True) + + "<|endoftext|>" + ) return prompt @@ -86,16 +87,14 @@ def _run_test( *, dtype: str, ) -> None: - '''SET PYTHONPATH''' + """SET PYTHONPATH""" # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - runner="pooling", - dtype=dtype, - enforce_eager=True, - max_model_len=8192) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=8192 + ) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() texts = [ # this is necessary because vllm_model.embed will not apply any @@ -105,25 +104,25 @@ def _run_test( apply_chat_template_and_add_eos( get_messages(image, text, False), apply_chat_template_fn=tokenizer.apply_chat_template, - ) for text, image in zip(input_texts, input_images) + ) + for text, image in zip(input_texts, input_images) # vllm will replace the pad token with the actual image, # which may be a placeholder image, later. 
] vllm_outputs = vllm_model.embed(texts, images=input_images) hf_outputs = [] - with hf_runner(model, - dtype=dtype, - auto_cls=Qwen2VLForConditionalGeneration) as hf_model: - + with hf_runner( + model, dtype=dtype, auto_cls=Qwen2VLForConditionalGeneration + ) as hf_model: prompts = [] - for text, image, embed_text in zip(input_texts, input_images, - embed_texts): + for text, image, embed_text in zip(input_texts, input_images, embed_texts): # dse requires non-standard input processing # because it needs an image_pad token messages = get_messages(image, text, embed_text) prompt = apply_chat_template_and_add_eos( - messages, hf_model.processor.apply_chat_template) + messages, hf_model.processor.apply_chat_template + ) prompts.append(prompt) @@ -145,9 +144,9 @@ def _run_test( return_dict=True, output_hidden_states=True, ) - pooled_output = F.normalize(outputs.hidden_states[-1][0, -1], - p=2, - dim=-1) + pooled_output = F.normalize( + outputs.hidden_states[-1][0, -1], p=2, dim=-1 + ) all_outputs.append(pooled_output.tolist()) @@ -170,8 +169,9 @@ def test_models_text( model: str, dtype: str, ) -> None: - input_texts_images = [(text, image_placeholder) - for text, image_placeholder in HF_TEXT_PROMPTS] + input_texts_images = [ + (text, image_placeholder) for text, image_placeholder in HF_TEXT_PROMPTS + ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] embed_texts = [True] * len(input_texts) @@ -198,8 +198,7 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 3e2be34a50ad5..b474e851319ae 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -29,7 +29,7 @@ def run_intern_vit_test( img_processor = CLIPImageProcessor.from_pretrained(model) images = [asset.pil_image for asset in image_assets] pixel_values = [ - img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype) + img_processor(images, return_tensors="pt").pixel_values.to(torch_dtype) for images in images ] @@ -37,15 +37,16 @@ def run_intern_vit_test( if not getattr(config, "norm_type", None): config.norm_type = "rms_norm" - hf_model = AutoModel.from_pretrained(model, - torch_dtype=torch_dtype, - trust_remote_code=True).to("cuda") + hf_model = AutoModel.from_pretrained( + model, torch_dtype=torch_dtype, trust_remote_code=True + ).to("cuda") hf_outputs_per_image = [ hf_model(pixel_value.to("cuda")).last_hidden_state for pixel_value in pixel_values ] from vllm.model_executor.models.intern_vit import InternVisionModel + vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) @@ -54,22 +55,23 @@ def run_intern_vit_test( vllm_model = vllm_model.to("cuda", torch_dtype) vllm_outputs_per_image = [ - vllm_model(pixel_values=pixel_value.to("cuda")) - for pixel_value in pixel_values + vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values ] del vllm_model cleanup_dist_env_and_memory() cos_similar = nn.CosineSimilarity(dim=-1) - for vllm_output, hf_output in zip(vllm_outputs_per_image, - hf_outputs_per_image): + for vllm_output, 
hf_output in zip(vllm_outputs_per_image, hf_outputs_per_image): assert cos_similar(vllm_output, hf_output).mean() > 0.99 -@pytest.mark.parametrize("model_id", [ - "OpenGVLab/InternViT-300M-448px", - "OpenGVLab/InternViT-6B-448px-V1-5", -]) +@pytest.mark.parametrize( + "model_id", + [ + "OpenGVLab/InternViT-300M-448px", + "OpenGVLab/InternViT-6B-448px-V1-5", + ], +) @pytest.mark.parametrize("dtype", ["half"]) def test_models(dist_init, image_assets, model_id, dtype: str) -> None: run_intern_vit_test( diff --git a/tests/models/multimodal/pooling/test_jinavl_reranker.py b/tests/models/multimodal/pooling/test_jinavl_reranker.py index 7ad7a8d284cba..853f56618290e 100644 --- a/tests/models/multimodal/pooling/test_jinavl_reranker.py +++ b/tests/models/multimodal/pooling/test_jinavl_reranker.py @@ -29,7 +29,6 @@ def vllm_reranker( query_type: str = "text", doc_type: str = "text", ): - def create_image_param(url: str) -> ChatCompletionContentPartImageParam: return {"type": "image_url", "image_url": {"url": f"{url}"}} @@ -38,23 +37,25 @@ def vllm_reranker( query = query_strs elif query_type == "image": query = ScoreMultiModalParam( - content=[create_image_param(url) for url in query_strs]) + content=[create_image_param(url) for url in query_strs] + ) documents: Union[list[str], ScoreMultiModalParam] if doc_type == "text": documents = document_strs elif doc_type == "image": documents = ScoreMultiModalParam( - content=[create_image_param(url) for url in document_strs]) + content=[create_image_param(url) for url in document_strs] + ) with vllm_runner( - model_name, - runner="pooling", - dtype=dtype, - max_num_seqs=2, - max_model_len=2048, - mm_processor_kwargs=mm_processor_kwargs, - limit_mm_per_prompt=limit_mm_per_prompt, + model_name, + runner="pooling", + dtype=dtype, + max_num_seqs=2, + max_model_len=2048, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, ) as vllm_model: outputs = vllm_model.llm.score(query, documents) @@ -78,16 +79,15 @@ def hf_reranker( data_pairs = [[query_strs[0], d] for d in document_strs] with hf_runner( - model_name, - dtype=dtype, - trust_remote_code=True, - auto_cls=AutoModel, - model_kwargs={"key_mapping": checkpoint_to_hf_mapper}, + model_name, + dtype=dtype, + trust_remote_code=True, + auto_cls=AutoModel, + model_kwargs={"key_mapping": checkpoint_to_hf_mapper}, ) as hf_model: - return hf_model.model.compute_score(data_pairs, - max_length=2048, - query_type=query_type, - doc_type=doc_type) + return hf_model.model.compute_score( + data_pairs, max_length=2048, query_type=query_type, doc_type=doc_type + ) # Visual Documents Reranking @@ -100,10 +100,12 @@ def test_model_text_image(hf_runner, vllm_runner, model_name, dtype): "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "text", "image") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "text", "image") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "text", "image" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "text", "image" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -127,10 +129,12 @@ def test_model_text_text(hf_runner, vllm_runner, model_name, dtype): lower computational requirements.""", # noqa: E501 "数据提取么?为什么不用正则啊,你用正则不就全解决了么?", ] - hf_outputs = hf_reranker(hf_runner, 
model_name, dtype, query, documents, - "text", "text") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "text", "text") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "text", "text" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "text", "text" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -157,10 +161,12 @@ def test_model_image_text(hf_runner, vllm_runner, model_name, dtype): "数据提取么?为什么不用正则啊,你用正则不就全解决了么?", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "image", "text") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "image", "text") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "image", "text" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "image", "text" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -178,10 +184,12 @@ def test_model_image_image(hf_runner, vllm_runner, model_name, dtype): "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "image", "image") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "image", "image") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "image", "image" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "image", "image" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) diff --git a/tests/models/multimodal/pooling/test_llava_next.py b/tests/models/multimodal/pooling/test_llava_next.py index 50826677581d0..2053ce3994831 100644 --- a/tests/models/multimodal/pooling/test_llava_next.py +++ b/tests/models/multimodal/pooling/test_llava_next.py @@ -24,9 +24,10 @@ from ...utils import check_embeddings_close # built with LAPACK support. 
pytestmark = pytest.mark.skipif( not current_platform.is_cuda(), - reason="Llava Next model uses op that is only supported in CUDA") + reason="Llava Next model uses op that is only supported in CUDA", +) -llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 +llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 HF_TEXT_PROMPTS = [ # T -> X @@ -34,18 +35,21 @@ HF_TEXT_PROMPTS = [ "The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501 ), # T -> X - llama3_template.format( - "cherry blossom\nSummary above sentence in one word: "), + llama3_template.format("cherry blossom\nSummary above sentence in one word: "), ] -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - # I -> X - "stop_sign": - llama3_template.format("<image>\nSummary above image in one word: "), - # I -> X - "cherry_blossom": - llama3_template.format("<image>\nSummary above image in one word: "), -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + # I -> X + "stop_sign": llama3_template.format( + "<image>\nSummary above image in one word: " + ), + # I -> X + "cherry_blossom": llama3_template.format( + "<image>\nSummary above image in one word: " + ), + } +) MODELS = ["royokong/e5-v"] @@ -63,23 +67,22 @@ def _run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - runner="pooling", - dtype=dtype, - max_model_len=4096, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, max_model_len=4096, enforce_eager=True + ) as vllm_model: vllm_outputs = vllm_model.embed(input_texts, images=input_images) - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForImageTextToText) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForImageTextToText + ) as hf_model: # Patch the issue where generation_config.json is missing - hf_model.processor.patch_size = \ - hf_model.model.config.vision_config.patch_size + hf_model.processor.patch_size = hf_model.model.config.vision_config.patch_size # Patch the issue where image_token_id # exceeds the maximum allowed vocab size hf_model.model.resize_token_embeddings( - hf_model.model.language_model.vocab_size + 1) + hf_model.model.language_model.vocab_size + 1 + ) all_inputs = hf_model.get_inputs(input_texts, images=input_images) @@ -91,8 +94,7 @@ def _run_test( return_dict=True, output_hidden_states=True, ) - pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], - dim=-1) + pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], dim=-1) all_outputs.append(pooled_output.tolist()) @@ -142,8 +144,7 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_phi3v.py b/tests/models/multimodal/pooling/test_phi3v.py index f918a0bd781ea..c799a5bd3e1ef 100644 --- a/tests/models/multimodal/pooling/test_phi3v.py +++ b/tests/models/multimodal/pooling/test_phi3v.py @@ -19,14 +19,14 @@ HF_TEXT_PROMPTS = [ "Retrieve an 
image of this caption: cherry blossom", ] -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - # T + I -> X - "stop_sign": - "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 - # I -> X - "cherry_blossom": - "<|image_1|> Represent the given image for classification", # noqa: E501 -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + # T + I -> X + "stop_sign": "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 + # I -> X + "cherry_blossom": "<|image_1|> Represent the given image for classification", # noqa: E501 + } +) MODELS = ["TIGER-Lab/VLM2Vec-Full"] @@ -44,14 +44,14 @@ def _run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, runner="pooling", dtype=dtype, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True + ) as vllm_model: vllm_outputs = vllm_model.embed(input_texts, images=input_images) # use eager mode for hf runner, since phi3_v didn't work with flash_attn hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: + with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: all_inputs = hf_model.get_inputs(input_texts, images=input_images) all_outputs = [] @@ -114,18 +114,21 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] # add cases for special_tokens - input_texts_images.append(( - "\n<s><|user|>\n <|image_1|>\n\t <s>" - "Represent the given image for classification<|end|>" - "\n<|assistant|>\n", - Image.open( - get_vllm_public_assets(filename="cherry_blossom.jpg", - s3_prefix=VLM_IMAGES_DIR)), - )) + input_texts_images.append( + ( + "\n<s><|user|>\n <|image_1|>\n\t <s>" + "Represent the given image for classification<|end|>" + "\n<|assistant|>\n", + Image.open( + get_vllm_public_assets( + filename="cherry_blossom.jpg", s3_prefix=VLM_IMAGES_DIR + ) + ), + ) + ) input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index e9be79fba911f..abf4150a91329 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -4,8 +4,6 @@ import pytest import torch -from vllm.utils import set_default_torch_num_threads - from ....conftest import VllmRunner @@ -21,32 +19,30 @@ def _run_test( vllm_runner: type[VllmRunner], model: str, ) -> None: - prompt = [ { # This model deals with no text input "prompt_token_ids": [1], "multi_modal_data": generate_test_mm_data(), - } for _ in range(10) + } + for _ in range(10) ] - with ( - set_default_torch_num_threads(1), - vllm_runner( - model, - runner="pooling", - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True, - # Limit the maximum number of sequences to avoid the - # test going OOM during the warmup run - max_num_seqs=32, - ) as vllm_model, - ): + with vllm_runner( + model, + runner="pooling", + 
dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, + ) as vllm_model: vllm_model.encode(prompt) -MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] +MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] @pytest.mark.core_model diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py new file mode 100644 index 0000000000000..80f594021ca8a --- /dev/null +++ b/tests/models/multimodal/pooling/test_radio.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from transformers import AutoConfig, AutoModel, CLIPImageProcessor + +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.model_executor.models.radio import RadioModel +from vllm.transformers_utils.configs.radio import RadioConfig +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ....conftest import ImageTestAssets + +# we use snapshot_download to prevent conflicts between +# dynamic_module and trust_remote_code for hf_runner +DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] + + +@torch.inference_mode() +def run_radio_test( + image_assets: ImageTestAssets, + model_id: str, + *, + dtype: str, +): + model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + img_processor = CLIPImageProcessor.from_pretrained(model) + images = [asset.pil_image for asset in image_assets] + # Input resolution must be a multiple of `self.min_resolution_step`. + # Using `self.get_nearest_supported_resolution`, for assets 432x642 the + # nearest supported resolution is 432x640. 
+ pixel_values = [ + img_processor(image, return_tensors="pt").pixel_values.to(torch_dtype)[ + :, :, :, :640 + ] + for image in images + ] + + config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + + hf_model = AutoModel.from_pretrained( + model_id, + config=config, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).to("cuda") + hf_model.eval() + + hf_outputs_per_image = [ + hf_model(pixel_value.to("cuda")).features for pixel_value in pixel_values + ] + + radio_config = RadioConfig( + model_name=config.args["model"], reg_tokens=config.args["register_multiple"] + ) + vllm_model = RadioModel(radio_config) + vllm_model.load_weights(hf_model.state_dict()) + vllm_model = vllm_model.to("cuda", torch_dtype) + + vllm_outputs_per_image = [ + vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values + ] + del vllm_model, hf_model + cleanup_dist_env_and_memory() + + cos_similar = nn.CosineSimilarity(dim=-1) + for vllm_output, hf_output in zip(vllm_outputs_per_image, hf_outputs_per_image): + assert cos_similar(vllm_output, hf_output).mean() > 0.99 + + +@pytest.mark.parametrize( + "model_id", + [ + "nvidia/C-RADIOv2-H", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_radio(dist_init, image_assets, model_id, dtype: str) -> None: + run_radio_test( + image_assets, + model_id, + dtype=dtype, + ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index adc8b2510d677..d1361f336a071 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -6,19 +6,28 @@ from typing import Optional, Union import numpy as np import pytest -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image from vllm.config import ModelConfig -from vllm.inputs import InputProcessingContext +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache -from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, - cached_tokenizer_from_config, - encode_tokens) +from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + MistralTokenizer, + cached_tokenizer_from_config, + encode_tokens, +) from ....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS @@ -30,13 +39,48 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: """ # Ensure video metadata is included if "video" in mm_data: + # GLM4.1V doesn't support multiple videos video = mm_data["video"] - mm_data["video"] = (video, { - "total_num_frames": len(video), - "fps": len(video), - "duration": 1, - "video_backend": "opencv" - }) + num_frames = len(video) + mm_data["video"] = ( + video, + { + "total_num_frames": num_frames, + "fps": num_frames, + "duration": 1, + "frames_indices": [i for i in range(num_frames)], + 
"video_backend": "opencv", + "do_sample_frames": True, + }, + ) + return mm_data + + +def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + Patch the multimodal data for Qwen3-VL model. + """ + + def create_metadata(frames: np.ndarray): + num_frames = len(frames) + return { + "total_num_frames": num_frames, + "fps": 2.0, + "duration": num_frames / 2.0, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + "do_sample_frames": True, + } + + # Ensure video metadata is included + if "video" in mm_data: + video = mm_data["video"] + if isinstance(video, list): + # multiple videos + mm_data["video"] = [(vid, create_metadata(vid)) for vid in video] + else: + # single video + mm_data["video"] = (video, create_metadata(video)) return mm_data @@ -63,6 +107,11 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -71,17 +120,30 @@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() - limit_mm_per_prompt = { + # Keep integer limits for local data generation + limit_mm_per_prompt_ints = { modality: 3 if limit is None else limit for modality, limit in supported_mm_limits.items() } - model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt + def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: + if modality == "video": + return VideoDummyOptions(count=count) + if modality == "image": + return ImageDummyOptions(count=count) + if modality == "audio": + return AudioDummyOptions(count=count) + return BaseDummyOptions(count=count) + + # Assign normalized DummyOptions to the model config + model_config.get_multimodal_config().limit_per_prompt = { + modality: _to_dummy_options(modality, count) + for modality, count in limit_mm_per_prompt_ints.items() + } baseline_processor = factories.build_processor(ctx, cache=None) cached_processor = factories.build_processor(ctx, cache=cache) @@ -93,28 +155,23 @@ def _test_processing_correctness( input_to_hit = { "image": Image.new("RGB", size=(128, 128)), "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), - "audio": (np.zeros((512, )), 16000), + "audio": (np.zeros((512,)), 16000), } input_factory = { - "image": - partial(random_image, rng, min_wh=128, max_wh=256), - "video": - partial(random_video, - rng, - min_frames=2, - max_frames=16, - min_wh=128, - max_wh=256), - "audio": - partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), + "image": partial(random_image, rng, min_wh=128, max_wh=256), + "video": partial( + random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256 + ), + "audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), } for batch_idx in range(num_batches): mm_data = { - k: - [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit + 1))] - for k, limit in limit_mm_per_prompt.items() + k: [ + (input_to_hit[k] if rng.rand() < hit_rate 
else input_factory[k]()) + for _ in range(rng.randint(limit + 1)) + ] + for k, limit in limit_mm_per_prompt_ints.items() } mm_counts = {k: len(vs) for k, vs in mm_data.items()} @@ -122,12 +179,16 @@ def _test_processing_correctness( # Mistral chat outputs tokens directly, rather than text prompts if isinstance(tokenizer, MistralTokenizer): images = mm_data.get("image", []) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ] + ), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) prompt = res.tokens else: @@ -160,7 +221,6 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { - "mllama": False, "ovis": False, "ovis2_5": False, "paligemma": False, @@ -176,8 +236,11 @@ _IGNORE_MM_KEYS = { } MM_DATA_PATCHES = { - # GLM4.1V requires video metadata to be included in the input + # GLM4.1V and Qwen3-VL requires video metadata to be included in the input "glm4v": glm4_1v_patch_mm_data, + "glm4v_moe": glm4_1v_patch_mm_data, + "qwen3_vl": qwen3_vl_patch_mm_data, + "qwen3_vl_moe": qwen3_vl_patch_mm_data, } @@ -249,87 +312,93 @@ def _test_processing_correctness_one( baseline_text_result, baseline_tokenized_result, ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {text_prompt=}, " - f"{token_prompt=}, {mm_data=})", + msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})", ) _assert_inputs_equal( cached_text_result, cached_tokenized_result, ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {text_prompt=}, " - f"{token_prompt=}, {mm_data=})", + msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})", ) -# yapf: disable -@pytest.mark.parametrize("model_id", [ - "rhymes-ai/Aria", - "CohereForAI/aya-vision-8b", - "Salesforce/blip2-opt-2.7b", - "facebook/chameleon-7b", - "CohereLabs/command-a-vision-07-2025", - "deepseek-ai/deepseek-vl2-tiny", - "microsoft/Florence-2-base", - "adept/fuyu-8b", - "google/gemma-3-4b-it", - "google/gemma-3n-E2B-it", - "zai-org/glm-4v-9b", - "zai-org/GLM-4.1V-9B-Thinking", - "zai-org/GLM-4.5V", - "ibm-granite/granite-speech-3.3-2b", - "h2oai/h2ovl-mississippi-800m", - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", - "HuggingFaceM4/Idefics3-8B-Llama3", - "internlm/Intern-S1", - "OpenGVLab/InternVL2-1B", - "OpenGVLab/InternVL3-1B", - "Kwai-Keye/Keye-VL-8B-Preview", - "moonshotai/Kimi-VL-A3B-Instruct", - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "llava-hf/llava-1.5-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - "llava-hf/LLaVA-NeXT-Video-7B-hf", - "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - "meta-llama/Llama-3.2-11B-Vision-Instruct", - "TIGER-Lab/Mantis-8B-siglip-llama3", - "openbmb/MiniCPM-Llama3-V-2_5", - "openbmb/MiniCPM-o-2_6", - "openbmb/MiniCPM-V-2_6", - "MiniMaxAI/MiniMax-VL-01", - "allenai/Molmo-7B-D-0924", - "allenai/Molmo-7B-O-0924", - "nvidia/NVLM-D-72B", - "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", - "AIDC-AI/Ovis1.6-Gemma2-9B", - "AIDC-AI/Ovis1.6-Llama3.2-3B", - "AIDC-AI/Ovis2-1B", - "AIDC-AI/Ovis2.5-2B", - "google/paligemma-3b-mix-224", - "google/paligemma2-3b-ft-docci-448", - "microsoft/Phi-3.5-vision-instruct", - "microsoft/Phi-4-multimodal-instruct", - "mistralai/Pixtral-12B-2409", - 
"mistral-community/pixtral-12b", - "Qwen/Qwen-VL-Chat", - "Qwen/Qwen2-VL-2B-Instruct", - "Qwen/Qwen2.5-VL-3B-Instruct", - "Qwen/Qwen2-Audio-7B-Instruct", - "Qwen/Qwen2.5-Omni-3B", - "YannQi/R-4B", - "Skywork/Skywork-R1V-38B", - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", - "stepfun-ai/step3", - "fixie-ai/ultravox-v0_5-llama-3_2-1b", - "openai/whisper-large-v3", - "omni-research/Tarsier-7b", - "omni-research/Tarsier2-Recap-7b", - "mistralai/Voxtral-Mini-3B-2507", -]) +@pytest.mark.parametrize( + "model_id", + [ + "rhymes-ai/Aria", + "CohereForAI/aya-vision-8b", + "Salesforce/blip2-opt-2.7b", + "facebook/chameleon-7b", + "CohereLabs/command-a-vision-07-2025", + "deepseek-ai/deepseek-vl2-tiny", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", + "adept/fuyu-8b", + "google/gemma-3-4b-it", + "google/gemma-3n-E2B-it", + "zai-org/glm-4v-9b", + "zai-org/GLM-4.1V-9B-Thinking", + "zai-org/GLM-4.5V", + "ibm-granite/granite-speech-3.3-2b", + "h2oai/h2ovl-mississippi-800m", + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + "HuggingFaceM4/Idefics3-8B-Llama3", + "internlm/Intern-S1", + "OpenGVLab/InternVL2-1B", + "OpenGVLab/InternVL3-1B", + "OpenGVLab/InternVL3_5-1B", + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + "OpenGVLab/InternVL3_5-30B-A3B", + "Kwai-Keye/Keye-VL-8B-Preview", + "Kwai-Keye/Keye-VL-1_5-8B", + "moonshotai/Kimi-VL-A3B-Instruct", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "llava-hf/LLaVA-NeXT-Video-7B-hf", + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "TIGER-Lab/Mantis-8B-siglip-llama3", + "mispeech/midashenglm-7b", + "openbmb/MiniCPM-Llama3-V-2_5", + "openbmb/MiniCPM-o-2_6", + "openbmb/MiniCPM-V-2_6", + "MiniMaxAI/MiniMax-VL-01", + "allenai/Molmo-7B-D-0924", + "allenai/Molmo-7B-O-0924", + "nvidia/NVLM-D-72B", + "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", + "AIDC-AI/Ovis1.6-Gemma2-9B", + "AIDC-AI/Ovis1.6-Llama3.2-3B", + "AIDC-AI/Ovis2-1B", + "AIDC-AI/Ovis2.5-2B", + "google/paligemma-3b-mix-224", + "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-3.5-vision-instruct", + "microsoft/Phi-4-multimodal-instruct", + "mistralai/Pixtral-12B-2409", + "mistral-community/pixtral-12b", + "Qwen/Qwen-VL-Chat", + "Qwen/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2.5-Omni-3B", + "Qwen/Qwen3-VL-4B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", + "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "YannQi/R-4B", + "Skywork/Skywork-R1V-38B", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "stepfun-ai/step3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "openai/whisper-large-v3", + "omni-research/Tarsier-7b", + "omni-research/Tarsier2-Recap-7b", + "mistralai/Voxtral-Mini-3B-2507", + ], +) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable def test_processing_correctness( model_id: str, hit_rate: float, diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a49842e1099c2..553a5f719bd35 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -5,14 +5,27 @@ import pytest from vllm.assets.video import VideoAsset from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) 
@pytest.mark.parametrize("expected_toks_per_frame", [299]) -@pytest.mark.parametrize("num_frames", [32, 128]) -@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)]) +@pytest.mark.parametrize( + "num_frames, fps, expected_grid_t", + [ + # pre-sampled fixed frames (unexpected behavior, + # but we still expect it to work without errors) + (32, 1, 16), + (32, 2, 16), + (128, 1, 64), + (128, 2, 64), + # post-sampled frames (expected behavior) + (-1, 1, 5), + (-1, 2, 10), + ], +) def test_processor_override( model_id: str, expected_toks_per_frame: int, @@ -43,10 +56,54 @@ def test_processor_override( # Ensure we have the right number of placeholders per num_crops size hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) - video_tok_count = processed_inputs["prompt_token_ids"].count( - video_token_id) - grid_t, _, _ = processed_inputs["mm_kwargs"].get_data( - )["video_grid_thw"][0] + video_tok_count = processed_inputs["prompt_token_ids"].count(video_token_id) + grid_t, _, _ = processed_inputs["mm_kwargs"].get_data()["video_grid_thw"][0] assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t + + +@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) +@pytest.mark.parametrize("fps", [2]) +def test_video_loader_consistency( + model_id: str, + fps: int, +): + """ + Ensure dynamic video loader (pre-sampled by loader) and normal video + loader (post-sampled by processor) produce same video processing outputs. + """ + ctx = build_model_context( + model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"video": 1}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {"fps": fps} + + # Build the image str / prompt based on the number of images we pass + prompt = "<|begin_of_video|><|video|><|end_of_video|>" + + video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path + with open(video_path, "rb") as f: + video_bytes = f.read() + + static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes) + dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes( + video_bytes, fps=fps + ) + + # pre-sampled loader shouldn't read all frames + assert len(dynamic_video) < len(static_video) + + static_mm_data = {"video": [(static_video, static_metadata)]} + dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} + + static_outputs = processor.apply(prompt, static_mm_data, hf_processor_mm_kwargs) + dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs) + + assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"] + assert ( + static_outputs["mm_kwargs"].get_data() + == dynamic_outputs["mm_kwargs"].get_data() + ) diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 1adfe21352c41..bd21d4008fa7b 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for H2OVL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -23,8 +24,10 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets, - get_h2ovl_target_ratios) + from 
vllm.model_executor.models.h2ovl import ( + calculate_h2ovl_targets, + get_h2ovl_target_ratios, + ) width, height = image.size @@ -101,24 +104,27 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches -@pytest.mark.parametrize("model_id", [ - "h2oai/h2ovl-mississippi-800m", - "h2oai/h2ovl-mississippi-2b", -]) +@pytest.mark.parametrize( + "model_id", + [ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", + ], +) @pytest.mark.parametrize( "size_factors", [ @@ -165,10 +171,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index d3a55993e5588..351b9d018eec2 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for Idefics3's multimodal preprocessing kwargs.""" + import pytest from transformers import Idefics3Config @@ -11,14 +12,13 @@ from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["HuggingFaceM4/Idefics3-8B-Llama3"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ ({"size": {"longest_edge": 364}}, 169), ({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -42,8 +42,11 @@ def test_processor_override( hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass - placeholders = "<image>" if num_imgs == 1 else "\n".join( - f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + placeholders = ( + "<image>" + if num_imgs == 1 + else "\n".join(f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + ) prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501 # Build mm_data @@ -57,8 +60,7 @@ def test_processor_override( # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) - assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ - "input_ids"][0] + assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] # Ensure we have the right number of placeholders per num_crops size image_token_id = ctx.get_hf_config().image_token_id diff --git 
a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index e4f25f5ac7123..6f6529cb9401a 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for InternVL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -24,7 +25,9 @@ def _get_expected_num_patches( max_num: int, ): from vllm.model_executor.models.internvl import ( - calculate_internvl_targets, get_internvl_target_ratios) + calculate_internvl_targets, + get_internvl_target_ratios, + ) width, height = image.size @@ -61,15 +64,15 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches @@ -122,10 +125,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index bea4f43567eee..4c0791ea3cece 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -11,8 +11,7 @@ from ....conftest import ImageTestAssets from ...utils import build_model_context -@pytest.mark.parametrize("model_id", - ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) +@pytest.mark.parametrize("model_id", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) @pytest.mark.parametrize("mm_processor_kwargs", [{}]) @pytest.mark.parametrize("num_imgs", [1, 5]) @pytest.mark.parametrize("mm_processor_cache_gb", [0, 4]) @@ -38,13 +37,14 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor() vocab = tokenizer.get_vocab() - prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \ - + "<|image|>" * num_imgs \ + prompt = ( + "<|begin_of_text|><|header_start|>user<|header_end|>" + + "<|image|>" * num_imgs + "<|eot|><|header_start|>assistant<|header_end|>" + ) mm_data = { "image": [ - image_assets[(i % len(image_assets))].pil_image - for i in range(num_imgs) + image_assets[(i % len(image_assets))].pil_image for i in range(num_imgs) ] } if tokenized_prompt: @@ -64,22 +64,23 @@ def test_processor_override( if tiles_x * tiles_y > 1: num_x_separators += (tiles_x - 1) * tiles_y num_y_separators += tiles_y - assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \ - == num_x_separators - assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \ - == num_y_separators + assert prompt_token_ids.count(vocab[hf_processor.tile_token]) == num_x_separators + assert ( + 
prompt_token_ids.count(vocab[hf_processor.tile_global_token]) + == num_y_separators + ) # image token offsets img_locs = processed_inputs["mm_placeholders"].get("image", []) assert len(img_locs) == num_imgs - assert [img_loc.offset for img_loc in img_locs] == \ - [i for i, v in enumerate(prompt_token_ids) \ - if v == config.boi_token_index] + assert [img_loc.offset for img_loc in img_locs] == [ + i for i, v in enumerate(prompt_token_ids) if v == config.boi_token_index + ] # patch sizes and masks - num_patches_per_chunk = processor.info.get_patch_per_chunk( - config.vision_config) - assert prompt_token_ids.count(config.image_token_index) \ + num_patches_per_chunk = processor.info.get_patch_per_chunk(config.vision_config) + assert ( + prompt_token_ids.count(config.image_token_index) == sum(mm_data["patches_per_image"]) * num_patches_per_chunk - assert len(mm_data["pixel_values"]) \ - == sum(mm_data["patches_per_image"]) + ) + assert len(mm_data["pixel_values"]) == sum(mm_data["patches_per_image"]) diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py index ca34d1d758a46..ffe7ca17b5d61 100644 --- a/tests/models/multimodal/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -22,8 +22,9 @@ def _validate_image_max_tokens_one( image_size: ImageSize, ) -> None: info = processor.info - feature_size = info.get_num_image_tokens(image_width=image_size.width, - image_height=image_size.height) + feature_size = info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) try: assert feature_size <= max_tokens, f"{feature_size} <= {max_tokens}" @@ -31,8 +32,9 @@ def _validate_image_max_tokens_one( failed_size_excs.append((image_size, exc)) -@pytest.mark.skip("This test takes around 5 minutes to run. " - "Comment this out to run it manually.") +@pytest.mark.skip( + "This test takes around 5 minutes to run. Comment this out to run it manually." 
+) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) def test_processor_max_tokens(model_id): ctx = build_model_context( @@ -66,9 +68,9 @@ def test_processor_max_tokens(model_id): pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -94,8 +96,10 @@ def _validate_image_prompt_replacements_one( # NOTE: There is a BOS token assert first_placeholder.offset == 1 - assert first_placeholder.length == ( - len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + assert ( + first_placeholder.length + == (len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + ) except Exception as exc: failed_size_excs.append((image_size, exc)) @@ -122,9 +126,9 @@ def _test_image_prompt_replacements( pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -138,11 +142,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), ImageSize(h, w)] ] _test_image_prompt_replacements( @@ -152,8 +162,9 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) -@pytest.mark.skip("This test takes around 2 hours to run. " - "Comment this out to run it manually.") +@pytest.mark.skip( + "This test takes around 2 hours to run. Comment this out to run it manually." +) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1]) def test_processor_prompt_replacements_all(model_id, num_imgs): diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py index e6344c4e7e6fd..f5c552fe6476a 100644 --- a/tests/models/multimodal/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -22,8 +22,9 @@ def _validate_image_max_tokens_one( image_size: ImageSize, ) -> None: info = processor.info - feature_size = info.get_num_image_tokens(image_width=image_size.width, - image_height=image_size.height) + feature_size = info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) try: assert feature_size <= max_tokens, f"{feature_size} <= {max_tokens}" @@ -31,10 +32,10 @@ def _validate_image_max_tokens_one( failed_size_excs.append((image_size, exc)) -@pytest.mark.skip("This test takes around 5 minutes to run. " - "Comment this out to run it manually.") -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.skip( + "This test takes around 5 minutes to run. 
Comment this out to run it manually." +) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) def test_processor_max_tokens(model_id): ctx = build_model_context( model_id, @@ -67,9 +68,9 @@ def test_processor_max_tokens(model_id): pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -94,8 +95,10 @@ def _validate_image_prompt_replacements_one( first_placeholder = image_placeholders[0] assert first_placeholder.offset == 0 - assert first_placeholder.length == len( - processed_inputs["prompt_token_ids"]) // num_imgs + assert ( + first_placeholder.length + == len(processed_inputs["prompt_token_ids"]) // num_imgs + ) except Exception as exc: failed_size_excs.append((image_size, exc)) @@ -121,14 +124,13 @@ def _test_image_prompt_replacements( pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( @@ -138,11 +140,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), ImageSize(h, w)] ] _test_image_prompt_replacements( @@ -152,10 +160,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) -@pytest.mark.skip("This test takes around 2 hours to run. " - "Comment this out to run it manually.") -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.skip( + "This test takes around 2 hours to run. Comment this out to run it manually." 
+) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1]) def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = build_model_context( diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py index 9387212e3f101..11e0001235110 100644 --- a/tests/models/multimodal/processing/test_minimax_vl_01.py +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -61,17 +61,17 @@ def _test_image_prompt_replacements( num_imgs: int, image_sizes: list[ImageSize], ) -> None: - failed_size_excs = list[tuple[ImageSize, Exception]]() for size in image_sizes: - _validate_image_prompt_replacements_one(processor, num_imgs, - failed_size_excs, size) + _validate_image_prompt_replacements_one( + processor, num_imgs, failed_size_excs, size + ) if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -85,11 +85,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), ImageSize(h, w)] ] _test_image_prompt_replacements( diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py deleted file mode 100644 index b42d3f89f3cbf..0000000000000 --- a/tests/models/multimodal/processing/test_mllama.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mllama's multimodal preprocessing and profiling.""" -import pytest -from transformers import MllamaConfig - -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.profiling import MultiModalProfiler - -from ...utils import build_model_context - - -@pytest.mark.parametrize("model_id", - ["meta-llama/Llama-3.2-11B-Vision-Instruct"]) -@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072]) -@pytest.mark.parametrize("max_num_seqs", [1, 2, 8]) -def test_profiling( - model_id: str, - max_model_len: int, - max_num_seqs: int, -): - # regression test for https://github.com/vllm-project/vllm/issues/13929 - from vllm.model_executor.models.mllama import calc_token_per_chunk - - model_config_kwargs = { - "max_model_len": max_model_len, - } - ctx = build_model_context( - model_id, - model_config_kwargs=model_config_kwargs, - limit_mm_per_prompt={"image": 1}, - ) - - mm_config = ctx.get_mm_config() - processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - profiler = MultiModalProfiler(processor) - - dummy_encoder_data = profiler.get_encoder_dummy_data( - max_model_len, - mm_counts=mm_config.limit_per_prompt, - ) - dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs( - max_model_len, - mm_counts=mm_config.limit_per_prompt, - ) - - hf_config = ctx.get_hf_config(MllamaConfig) - image_size = hf_config.vision_config.image_size - 
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) - ] * max_num_seqs - - mm_data = processor.apply( - prompt=dummy_mm_data.prompt, - mm_data=dummy_mm_data.mm_data, - hf_processor_mm_kwargs=dict(), - )["mm_kwargs"].get_data() - - # Get the actual number of encoder tokens for each sample. - # Because attn_metadata.encoder_seq_lens only counts the last - # group of images for each sample, which is used to cheat the - # block manager to allocate blocks for those images only. - # See MllamaMultiModalProcessor for more details. - num_tiles = [[t] for t in mm_data.pop("num_tiles")] - num_tokens_per_tile = calc_token_per_chunk(image_size) - actual_encoder_seq_lens = [ - sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles - ] - - # simulate mllama image-present prefill. - for actual_len, last_group_len in zip(actual_encoder_seq_lens, - encoder_seq_lens): - assert actual_len >= last_group_len diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index 3be77b5da63f2..e5ff2d1391b62 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for mllama's multimodal preprocessing and profiling.""" + import pytest from torch import prod from transformers import Llama4Config @@ -17,23 +18,23 @@ def test_profiling(model_id: str, max_model_len: int): model_config_kwargs = { "max_model_len": max_model_len, } + mm_counts = {"image": 1} ctx = build_model_context( model_id, model_config_kwargs=model_config_kwargs, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt=mm_counts, ) - mm_config = ctx.get_mm_config() processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) profiler = MultiModalProfiler(processor) decoder_dummy_data = profiler.get_decoder_dummy_data( max_model_len, - mm_counts=mm_config.limit_per_prompt, + mm_counts=mm_counts, ) dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs( max_model_len, - mm_counts=mm_config.limit_per_prompt, + mm_counts=mm_counts, ) hf_config = ctx.get_hf_config(Llama4Config) @@ -47,21 +48,25 @@ def test_profiling(model_id: str, max_model_len: int): image_size = hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size downsample_ratio = int( - round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) - tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio + round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)) + ) + tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch - num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ - 1] # x-y seperator tokens - total_tokens = total_num_patches.item() + num_tiles.item( - ) + 3 # image start, image, image end + num_tiles = ( + mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][1] + ) # x-y separator tokens + total_tokens = ( + total_num_patches.item() + num_tiles.item() + 3 + ) # image start, image, image end profiled_tokens = profiler.get_mm_max_contiguous_tokens( max_model_len, - mm_counts=mm_config.limit_per_prompt, + mm_counts=mm_counts, ) assert total_tokens == profiled_tokens["image"] assert total_tokens == sum( - placeholder.length for placeholder in - decoder_dummy_data.multi_modal_placeholders["image"]) + 
placeholder.length + for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] + ) diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index d9f1965a053df..6ff6f396fa338 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -24,7 +25,9 @@ def _get_expected_num_patches( max_num: int, ): from vllm.model_executor.models.nemotron_vl import ( - calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios) + calculate_nemotron_vl_targets, + get_nemotron_vl_target_ratios, + ) width, height = image.size @@ -63,22 +66,21 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) print(total_expected_num_patches) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<image>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches -@pytest.mark.parametrize("model_id", - ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"]) +@pytest.mark.parametrize("model_id", ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"]) @pytest.mark.parametrize( "size_factors", [ @@ -125,10 +127,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index 1f3646f794868..8faff2611e6fe 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for phi3v's multimodal preprocessing kwargs.""" + import pytest from vllm.multimodal import MULTIMODAL_REGISTRY @@ -10,7 +11,6 @@ from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ @@ -18,8 +18,8 @@ from ...utils import build_model_context ({"num_crops": 16}, 1921), # the default num_crops of phi-3.5-vision is 4 ({}, 757), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py index f16d261c2c6a4..5391555c26675 100644 --- a/tests/models/multimodal/processing/test_phi4mm.py +++ 
b/tests/models/multimodal/processing/test_phi4mm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for phi4mm's multimodal preprocessing kwargs.""" + import pytest from vllm.multimodal import MULTIMODAL_REGISTRY @@ -10,7 +11,6 @@ from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ @@ -18,8 +18,8 @@ from ...utils import build_model_context ({"dynamic_hd": 16}, 4433), # the default num_crops of phi-4-multimodal is 36 ({}, 9585), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -46,8 +46,7 @@ def test_processor_override( img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" - image_size = ctx.get_hf_config( - ).embd_layer["image_embd_layer"]["crop_size"] + image_size = ctx.get_hf_config().embd_layer["image_embd_layer"]["crop_size"] dummy_image_size = (image_size * 7, image_size * 7) dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} @@ -56,5 +55,6 @@ def test_processor_override( # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count( - _IMAGE_PLACEHOLDER_TOKEN_ID) + _IMAGE_PLACEHOLDER_TOKEN_ID + ) assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 985f4188fdb66..9f4cdb6789b2c 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -10,13 +10,13 @@ from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) -# yapf: disable @pytest.mark.parametrize( - ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [ + ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), + [ ({}, 1426, (5704, 1176)), ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -48,8 +48,7 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values"].shape assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs diff --git a/tests/models/multimodal/processing/test_smolvlm.py b/tests/models/multimodal/processing/test_smolvlm.py index af8f983388c6c..6f77d5516d147 100644 --- a/tests/models/multimodal/processing/test_smolvlm.py +++ b/tests/models/multimodal/processing/test_smolvlm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for smolvlm's multimodal preprocessing kwargs.""" + import pytest from transformers import 
SmolVLMConfig @@ -11,14 +12,13 @@ from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ ({"max_image_size": {"longest_edge": 384}}, 1377), ({"max_image_size": {"longest_edge": 768}}, 405), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -42,8 +42,11 @@ def test_processor_override( hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass - placeholders = "<image>" if num_imgs == 1 else "\n".join( - f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + placeholders = ( + "<image>" + if num_imgs == 1 + else "\n".join(f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + ) prompt = f"<|im_start|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501 # Build mm_data @@ -57,8 +60,7 @@ def test_processor_override( # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) - assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ - "input_ids"][0] + assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] # Ensure we have the right number of placeholders per num_crops size image_token_id = ctx.get_hf_config().image_token_id diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 79164f02c3398..6b6c53a50397b 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -1,30 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile from collections.abc import Iterable +from contextlib import contextmanager from functools import partial from typing import Any, Union -from unittest.mock import patch import numpy as np import pytest -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +import torch.nn as nn +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image -from vllm.config import ModelConfig -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine -from vllm.inputs import InputProcessingContext -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) -from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + supports_multimodal, +) +from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs +from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from 
vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads -from vllm.v1.core.kv_cache_utils import get_kv_cache_config -from vllm.v1.engine.core import EngineCore as V1EngineCore +from vllm.utils import is_list_of -from ....conftest import VllmRunner from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides @@ -38,16 +50,20 @@ ARCH_NEEDS_EXTRAS = [ "MiniCPMV", "PaliGemmaForConditionalGeneration", ] -REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"} +REPO_ID_TO_SKIP = { + "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", +} ImageInput = list[Image.Image] -VideoInput = Union[list[Image.Image], list[np.ndarray], - list[tuple[np.ndarray, dict[str, Any]]]] +VideoInput = Union[ + list[Image.Image], list[np.ndarray], list[tuple[np.ndarray, dict[str, Any]]] +] AudioInput = list[tuple[np.ndarray, int]] -def _resize_data(_data: Union[Image.Image, np.ndarray], - size_factor: float) -> Union[Image.Image, np.ndarray]: +def _resize_data( + _data: Union[Image.Image, np.ndarray], size_factor: float +) -> Union[Image.Image, np.ndarray]: assert size_factor <= 1, "Size factor must be less than 1" # Image input if isinstance(_data, Image.Image): @@ -67,24 +83,23 @@ def _resize_data(_data: Union[Image.Image, np.ndarray], return _data[..., :T, :H, :W, :C] # Audio input elif isinstance(_data, np.ndarray) and _data.ndim == 1: - return _data[:int(len(_data) * size_factor)] + return _data[: int(len(_data) * size_factor)] raise AssertionError("This line should be unreachable.") def resize_mm_data( - data: Union[ImageInput, VideoInput, AudioInput], - size_factors: tuple[float, - ...]) -> Union[ImageInput, VideoInput, AudioInput]: - size_factors = size_factors[:len(data)] + data: Union[ImageInput, VideoInput, AudioInput], size_factors: tuple[float, ...] +) -> Union[ImageInput, VideoInput, AudioInput]: + size_factors = size_factors[: len(data)] if is_list_of(data, (Image.Image, np.ndarray, list)): return [_resize_data(d, s) for d, s in zip(data, size_factors)] elif is_list_of(data, tuple): - return [(_resize_data(d, s), meta) - for (d, meta), s in zip(data, size_factors)] + return [(_resize_data(d, s), meta) for (d, meta), s in zip(data, size_factors)] raise ValueError("Unsupported multimodal data type.") def create_batched_mm_kwargs( + model_cls: type[SupportsMultiModal], model_config: ModelConfig, processor: BaseMultiModalProcessor, size_factors: tuple[float, ...] 
= (1.0, 0.5, 0.25), @@ -108,12 +123,16 @@ def create_batched_mm_kwargs( # Mistral chat outputs tokens directly, rather than text prompts if model_config.tokenizer_mode == "mistral": images = resized_mm_data.get("image", []) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ] + ), + ] + ) tokenizer = processing_info.get_tokenizer() res = tokenizer.mistral.encode_chat_completion(request) prompt = res.tokens @@ -124,23 +143,49 @@ def create_batched_mm_kwargs( mm_data=resized_mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, - )["mm_kwargs"] - items = [ - item for modality in supported_mm_limits - for item in mm_kwargs[modality] - ] - return group_mm_kwargs_by_modality(items) + )["mm_kwargs"].require_data() + items = [item for modality in supported_mm_limits for item in mm_kwargs[modality]] + return group_mm_kwargs_by_modality( + items, + merge_by_field_config=model_cls.merge_by_field_config, + ) -def get_model_id_to_test( - model_arch_list: Iterable[str]) -> list[tuple[str, str]]: +@contextmanager +def initialize_dummy_model( + model_cls: type[nn.Module], + model_config: ModelConfig, +): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(tensor_model_parallel_size=1) + vllm_config = VllmConfig(model_config=model_config) + with set_current_vllm_config(vllm_config=vllm_config): + with set_default_torch_dtype(model_config.dtype): + model = model_cls(vllm_config=vllm_config) + yield model + + del model + cleanup_dist_env_and_memory() + + +def get_model_id_to_test(model_arch_list: Iterable[str]) -> list[tuple[str, str]]: filtered_results = [] for model_arch in model_arch_list: model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS: available_repos = list( - map(lambda model_id: (model_arch, model_id), - [model_info.default, *model_info.extras.values()])) + map( + lambda model_id: (model_arch, model_id), + [model_info.default, *model_info.extras.values()], + ) + ) filtered_results.extend(available_repos) else: filtered_results.append((model_arch, model_info.default)) @@ -148,10 +193,9 @@ def get_model_id_to_test( @pytest.mark.parametrize( - "model_arch, model_id", - get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) -def test_model_tensor_schema(model_arch: str, model_id: str, - vllm_runner: type[VllmRunner], monkeypatch): + "model_arch, model_id", get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()) +) +def test_model_tensor_schema(model_arch: str, model_id: str): if model_arch in ARCH_TO_SKIP: pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") if model_id in REPO_ID_TO_SKIP: @@ -159,12 +203,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str, model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip", - check_max_version=False) + model_info.check_transformers_version(on_fail="skip", check_max_version=False) - hf_overrides_fn = partial(dummy_hf_overrides, - model_arch=model_arch, - 
exist_overrides=model_info.hf_overrides) + hf_overrides_fn = partial( + dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides, + ) model_config = ModelConfig( model_id, @@ -172,14 +217,26 @@ def test_model_tensor_schema(model_arch: str, model_id: str, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, - hf_overrides=model_info.hf_overrides, + hf_overrides=hf_overrides_fn, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + assert supports_multimodal(model_cls) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] - if not any( - hasattr(model_cls, f"_parse_and_validate_{m}_input") - for m in ["image", "video", "audio"]): + inputs_parse_methods = [] + for attr_name in dir(model_cls): + attr = getattr(model_cls, attr_name) + if hasattr(attr, "__annotations__"): + return_type = attr.__annotations__.get("return", None) + if return_type is not None and "Input" in str(return_type): + inputs_parse_methods.append(attr_name) + + if not any(inputs_parse_methods): pytest.skip(f"{model_arch} does not support tensor schema validation.") ctx = InputProcessingContext( @@ -193,67 +250,28 @@ def test_model_tensor_schema(model_arch: str, model_id: str, for modality, limit in supported_mm_limits.items() } - # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 + def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: + if modality == "video": + return VideoDummyOptions(count=count) + if modality == "image": + return ImageDummyOptions(count=count) + if modality == "audio": + return AudioDummyOptions(count=count) + return BaseDummyOptions(count=count) - def _initialize_kv_caches_v1(self, vllm_config): - kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( - vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) + model_config.get_multimodal_config().limit_per_prompt = { + modality: _to_dummy_options(modality, count) + for modality, count in limit_mm_per_prompt.items() + } + processor = factories.build_processor(ctx, cache=None) - # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config - return 1, 0, scheduler_kv_cache_config - - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): - m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") - - # TODO(Isotr0py): Can we avoid initializing engine? 
- with ( - set_default_torch_num_threads(1), - vllm_runner( - model_id, - tokenizer_name=model_info.tokenizer, - tokenizer_mode=model_info.tokenizer_mode, - revision=model_info.revision, - trust_remote_code=model_info.trust_remote_code, - max_model_len=model_info.max_model_len, - load_format="dummy", - hf_overrides=hf_overrides_fn, - limit_mm_per_prompt=limit_mm_per_prompt, - enforce_eager=True, - ) as vllm_model, + with initialize_dummy_model(model_cls, model_config) as model: + for modality, _, mm_kwargs in create_batched_mm_kwargs( + model_cls, model_config, processor ): - model_config = vllm_model.llm.llm_engine.model_config - llm_engine = vllm_model.llm.llm_engine - - if hasattr(llm_engine, "processor"): - # v1 processor - mm_registry = llm_engine.processor.mm_registry - else: - # v0 input_preprocessor - mm_registry = llm_engine.input_preprocessor.mm_registry - - processor = mm_registry.create_processor(model_config) - - def validate_model_input(model, modality: str, - mm_kwargs: MultiModalKwargs): - method_name = f"_parse_and_validate_{modality}_input" - if hasattr(model, method_name): - getattr(model, method_name)(**mm_kwargs) - - for modality, _, mm_kwargs in create_batched_mm_kwargs( - model_config, processor): - valid_func = partial(validate_model_input, - modality=modality, - mm_kwargs=mm_kwargs) - vllm_model.apply_model(valid_func) + for method_name in inputs_parse_methods: + print( + f"Testing `{method_name}` with modality={modality} " + f"and mm_kwargs{list(mm_kwargs.keys())}" + ) + getattr(model, method_name)(modality=modality, **mm_kwargs) diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py index 54a0be99384a8..e2a2186f470b4 100644 --- a/tests/models/multimodal/processing/test_transformers.py +++ b/tests/models/multimodal/processing/test_transformers.py @@ -7,9 +7,7 @@ from vllm.config import ModelConfig from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf: disable -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) def test_multimodal_processor(model_id): model_config = ModelConfig( model=model_id, @@ -18,9 +16,9 @@ def test_multimodal_processor(model_id): mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config) - image_pil = ImageAsset('cherry_blossom').pil_image + image_pil = ImageAsset("cherry_blossom").pil_image mm_data = {"image": image_pil} - str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 + str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 str_processed_inputs = mm_processor.apply( prompt=str_prompt, mm_data=mm_data, @@ -28,8 +26,23 @@ def test_multimodal_processor(model_id): ) ids_prompt = [ - 151644, 872, 220, 151646, 198, 3838, 374, 279, 2213, 315, 419, 2168, - 30, 151645, 151644, 77091, 198 + 151644, + 872, + 220, + 151646, + 198, + 3838, + 374, + 279, + 2213, + 315, + 419, + 2168, + 30, + 151645, + 151644, + 77091, + 198, ] ids_processed_inputs = mm_processor.apply( prompt=ids_prompt, @@ -37,4 +50,7 @@ def test_multimodal_processor(model_id): hf_processor_mm_kwargs={}, ) - assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"] + assert ( + str_processed_inputs["prompt_token_ids"] + == ids_processed_inputs["prompt_token_ids"] + ) diff --git a/tests/models/multimodal/test_mapping.py 
b/tests/models/multimodal/test_mapping.py index 7096810d8e15c..2179cf33a5735 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -19,7 +19,7 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]: """Create weights from safetensors checkpoint metadata""" metadata = try_get_safetensors_metadata(repo) weight_names = list(metadata.weight_map.keys()) - with torch.device('meta'): + with torch.device("meta"): return ((name, torch.empty(0)) for name in weight_names) @@ -59,6 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -81,6 +84,7 @@ def test_hf_model_weights_mapper(model_arch: str): weights_missing = ref_weight_names - weight_names weights_unmapped = weight_names - ref_weight_names - assert (not weights_missing and not weights_unmapped), ( + assert not weights_missing and not weights_unmapped, ( f"Following weights are not mapped correctly: {weights_unmapped}, " - f"Missing expected weights: {weights_missing}.") + f"Missing expected weights: {weights_missing}." + ) diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index bd696198931ff..c4c10832ede3a 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -11,12 +11,12 @@ from vllm.multimodal.image import rescale_image_size from ...conftest import IMAGE_ASSETS, ImageTestAssets, VllmRunner from ..utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - "cherry_blossom": - "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + "cherry_blossom": "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + } +) def run_awq_test( @@ -34,10 +34,13 @@ def run_awq_test( ): images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -45,37 +48,42 @@ def run_awq_test( # will hurt multiprocessing backend with fork method (the default method). 
# max_model_len should be greater than image_feature_size - with vllm_runner(source_model, - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + source_model, + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, + ) as vllm_model: source_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs_per_image ] - with vllm_runner(quant_model, - quantization="awq", - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + quant_model, + quantization="awq", + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, + ) as vllm_model: quant_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs_per_image ] - for source_outputs, quant_outputs in zip(source_outputs_per_image, - quant_outputs_per_image): + for source_outputs, quant_outputs in zip( + source_outputs_per_image, quant_outputs_per_image + ): # TODO: Check whether using original CLIPVisionModel can improve # consistency against HF check_logprobs_close( @@ -107,13 +115,16 @@ def run_awq_test( @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @torch.inference_mode() -def test_awq_models(vllm_runner, image_assets, source_model, quant_model, - size_factors, dtype, max_tokens, num_logprobs, - monkeypatch) -> None: - - # Test V1: this test hangs during setup on single-scale input. - # TODO: fixure out why and re-enable this on V1. - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_awq_models( + vllm_runner, + image_assets, + source_model, + quant_model, + size_factors, + dtype, + max_tokens, + num_logprobs, +) -> None: run_awq_test( vllm_runner, image_assets, diff --git a/tests/models/quantization/test_bitblas.py b/tests/models/quantization/test_bitblas.py index 754ac9a29a132..f516cc2724a6b 100644 --- a/tests/models/quantization/test_bitblas.py +++ b/tests/models/quantization/test_bitblas.py @@ -7,9 +7,10 @@ As a result, in this test, we just confirm that the top selected tokens of the bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the +result in very slight nondeterminism for bitblas. As a result, we re-run the test up to 3 times to see if we pass. 
""" + from dataclasses import dataclass import pytest @@ -24,8 +25,10 @@ class ModelPair: model_pairs = [ - ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", - model_gptq="hxbgsyxh/opt-125m-4bit-128g"), + ModelPair( + model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g", + ), ] @@ -43,16 +46,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_bitblas, - dtype=dtype, - quantization="bitblas") as bitblas_model: + with vllm_runner( + model_pair.model_bitblas, dtype=dtype, quantization="bitblas" + ) as bitblas_model: bitblas_outputs = bitblas_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py index e0e919b62b217..5e0421af1c17b 100644 --- a/tests/models/quantization/test_bitsandbytes.py +++ b/tests/models/quantization/test_bitsandbytes.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -'''Tests whether bitsandbytes computation is enabled correctly. +"""Tests whether bitsandbytes computation is enabled correctly. Run `pytest tests/quantization/test_bitsandbytes.py`. -''' - -import gc +""" import pytest -import torch from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported @@ -18,8 +15,10 @@ from ..utils import check_embeddings_close, check_logprobs_close models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), - ("mistralai/Mistral-7B-Instruct-v0.3", - "quantize inflight model with both HF and Mistral format weights") + ( + "mistralai/Mistral-7B-Instruct-v0.3", + "quantize inflight model with both HF and Mistral format weights", + ), ] models_4bit_to_embedding_test = [ @@ -31,72 +30,84 @@ models_4bit_to_moe_test = [ ] models_pre_qaunt_4bit_to_test = [ - ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', - 'read pre-quantized 4-bit FP4 model'), - ('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'), + ( + "PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed", + "read pre-quantized 4-bit FP4 model", + ), + ("poedator/opt-125m-bnb-4bit", "read pre-quantized 4-bit NF4 opt model"), ] models_pre_quant_8bit_to_test = [ - ('meta-llama/Llama-Guard-3-8B-INT8', - 'read pre-quantized llama 8-bit model'), + ("meta-llama/Llama-Guard-3-8B-INT8", "read pre-quantized llama 8-bit model"), ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"), ] -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) - 
validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, False, hf_model_kwargs) +def test_load_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, False, hf_model_kwargs + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_pre_qaunt_4bit_to_test) -def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, True) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) +def test_load_pre_quant_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, True + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_pre_quant_8bit_to_test) -def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, True) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) +def test_load_8bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, True + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) @multi_gpu_test(num_gpus=2) -def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) - validate_generated_texts(hf_runner, - vllm_runner, - example_prompts[:1], - model_name, - False, - hf_model_kwargs, - vllm_tp_size=2) +def test_load_tp_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) + validate_generated_texts( + hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + False, + hf_model_kwargs, + vllm_tp_size=2, + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) @multi_gpu_test(num_gpus=2) def 
test_load_pp_4bit_bnb_model(model_name, description) -> None: @@ -118,27 +129,37 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test) -def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: +def test_4bit_bnb_moe_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + ) + with vllm_runner( + model_name, + quantization="bitsandbytes", + enforce_eager=False, + default_torch_num_threads=1, + ) as llm: + vllm_outputs = llm.generate_greedy_logprobs( + example_prompts, max_tokens=32, num_logprobs=5 + ) - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - )) - with vllm_runner(model_name, - quantization='bitsandbytes', - enforce_eager=False) as llm: - vllm_outputs = llm.generate_greedy_logprobs(example_prompts, - max_tokens=32, - num_logprobs=5) - - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + with hf_runner( + model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1 + ) as llm: transformers_outputs = llm.generate_greedy_logprobs_limit( - example_prompts, max_tokens=32, num_logprobs=5) + example_prompts, max_tokens=32, num_logprobs=5 + ) check_logprobs_close( outputs_0_lst=transformers_outputs, outputs_1_lst=vllm_outputs, @@ -147,10 +168,11 @@ def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_4bit_to_embedding_test) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_4bit_to_embedding_test) @pytest.mark.parametrize("dtype", ["half"]) def test_4bit_bnb_embedding_model( model_name, @@ -160,7 +182,6 @@ def test_4bit_bnb_embedding_model( example_prompts, dtype: str, ) -> None: - # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -170,20 +191,23 @@ def test_4bit_bnb_embedding_model( example_prompts = [str(s).strip() for s in example_prompts] # Inflight 4bit quantization - with vllm_runner(model_name, - runner="pooling", - dtype=dtype, - gpu_memory_utilization=0.5, - quantization="bitsandbytes") as vllm_model: + with vllm_runner( + model_name, + runner="pooling", + dtype=dtype, + gpu_memory_utilization=0.5, + quantization="bitsandbytes", + default_torch_num_threads=1, + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) with hf_runner( - model_name, - dtype=dtype, - 
model_kwargs=hf_model_kwargs, - is_sentence_transformer=True, + model_name, + dtype=dtype, + model_kwargs=hf_model_kwargs, + is_sentence_transformer=True, + default_torch_num_threads=1, ) as hf_model: hf_outputs = hf_model.encode(example_prompts) @@ -208,47 +232,47 @@ def log_generated_texts(prompts, outputs, runner_name): return logged_texts -def validate_generated_texts(hf_runner, - vllm_runner, - prompts, - model_name, - pre_quant=False, - hf_model_kwargs=None, - vllm_tp_size=1, - max_tokens=8): - +def validate_generated_texts( + hf_runner, + vllm_runner, + prompts, + model_name, + pre_quant=False, + hf_model_kwargs=None, + vllm_tp_size=1, + max_tokens=8, +): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference - with vllm_runner(model_name, - quantization=None if pre_quant else 'bitsandbytes', - tensor_parallel_size=vllm_tp_size, - enforce_eager=False) as llm: - + with vllm_runner( + model_name, + quantization=None if pre_quant else "bitsandbytes", + tensor_parallel_size=vllm_tp_size, + enforce_eager=False, + default_torch_num_threads=1, + ) as llm: vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") - # Clean up the GPU memory for the next test - gc.collect() - torch.cuda.empty_cache() - if hf_model_kwargs is None: hf_model_kwargs = {} # Run with HF runner - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + with hf_runner( + model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1 + ) as llm: hf_outputs = llm.generate_greedy(prompts, max_tokens) hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") - # Clean up the GPU memory for the next test - gc.collect() - torch.cuda.empty_cache() # Compare the generated strings for hf_log, vllm_log in zip(hf_logs, vllm_logs): hf_str = hf_log["generated_text"] vllm_str = vllm_log["generated_text"] prompt = hf_log["prompt"] - assert hf_str == vllm_str, (f"Model: {model_name}" - f"Mismatch between HF and vLLM outputs:\n" - f"Prompt: {prompt}\n" - f"HF Output: '{hf_str}'\n" - f"vLLM Output: '{vllm_str}'") + assert hf_str == vllm_str, ( + f"Model: {model_name}" + f"Mismatch between HF and vLLM outputs:\n" + f"Prompt: {prompt}\n" + f"HF Output: '{hf_str}'\n" + f"vLLM Output: '{vllm_str}'" + ) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index afc27b6e0566e..55b149ae5da71 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -5,6 +5,7 @@ """Tests fp8 models against ground truth generation Note: these tests will only pass on L4 GPU. """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,31 +15,40 @@ from vllm.utils import STR_BACKEND_ENV_VAR from ..utils import check_logprobs_close -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) @pytest.mark.parametrize( "kv_cache_dtype,base_model,test_model", [ # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. - ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"), + ( + "fp8_e4m3", + "meta-llama/Llama-3.2-1B-Instruct", + "nm-testing/Llama-3.2-1B-Instruct-FP8-KV", + ), # Test BF16 checkpoint w. fp8_e5m2 kv-cache. 
- ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct"), + ( + "fp8_e5m2", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), # Test BF16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. - ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct") - ]) + ( + "fp8_e4m3", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + ], +) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) -# Due to low-precision numerical divergence, this test is too sensitive for -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_models( vllm_runner, example_prompts, @@ -49,7 +59,6 @@ def test_models( enforce_eager: bool, backend: str, tensor_parallel_size: int, - disable_async_output_proc: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: """ @@ -58,37 +67,39 @@ def test_models( """ if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): - pytest.skip( - f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") + pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") + + if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): + pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") with monkeypatch.context() as m: - m.setenv("TOKENIZERS_PARALLELISM", 'true') + m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv(STR_BACKEND_ENV_VAR, backend) MAX_MODEL_LEN = 1024 NUM_LOG_PROBS = 8 with vllm_runner( - base_model, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - kv_cache_dtype="auto", - disable_async_output_proc=disable_async_output_proc, + base_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype="auto", ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) with vllm_runner( - test_model, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) check_logprobs_close( outputs_0_lst=baseline_outputs, @@ -99,20 +110,20 @@ def test_models( @pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), - reason="test for the CPU backend.") +@pytest.mark.skipif(not current_platform.is_cpu(), reason="test for the CPU backend.") @pytest.mark.parametrize( "kv_cache_dtype,base_model,test_model", [ # Test BF16 checkpoint w. fp8_e5m2 kv-cache. 
- ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct"), - ]) + ( + "fp8_e5m2", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + ], +) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) -# Due to low-precision numerical divergence, this test is too sensitive for -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_cpu_models( vllm_runner, example_prompts, @@ -120,7 +131,6 @@ def test_cpu_models( base_model: str, test_model: str, max_tokens: int, - disable_async_output_proc: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: """ @@ -128,30 +138,30 @@ def test_cpu_models( numerical sensitive kernels. """ with monkeypatch.context() as m: - m.setenv("TOKENIZERS_PARALLELISM", 'true') + m.setenv("TOKENIZERS_PARALLELISM", "true") MAX_MODEL_LEN = 1024 NUM_LOG_PROBS = 8 with vllm_runner( - base_model, - max_model_len=MAX_MODEL_LEN, - dtype="bfloat16", - kv_cache_dtype="auto", - disable_async_output_proc=disable_async_output_proc, + base_model, + max_model_len=MAX_MODEL_LEN, + dtype="bfloat16", + kv_cache_dtype="auto", ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) with vllm_runner( - test_model, - max_model_len=MAX_MODEL_LEN, - dtype="bfloat16", - kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, + test_model, + max_model_len=MAX_MODEL_LEN, + dtype="bfloat16", + kv_cache_dtype=kv_cache_dtype, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) check_logprobs_close( outputs_0_lst=baseline_outputs, diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index 3e77d3e710393..5e2438857aeef 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -100,35 +100,37 @@ def check_model_outputs( ): tokenizer = AutoTokenizer.from_pretrained(model.original_model) if tokenizer.chat_template is not None: - messages = [[{ - 'role': 'user', - 'content': prompt - }] for prompt in prompts] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + messages = [[{"role": "user", "content": prompt}] for prompt in prompts] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Run gguf model. - with vllm_runner(model_name=model.gguf_model, - enforce_eager=True, - tokenizer_name=model.original_model, - dtype=dtype, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tp_size) as gguf_model: + with vllm_runner( + model_name=model.gguf_model, + enforce_eager=True, + tokenizer_name=model.original_model, + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tp_size, + ) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( - prompts[:-1], max_tokens, num_logprobs) + prompts[:-1], max_tokens, num_logprobs + ) # Run unquantized model. # Should run with tp=1, otherwise the test will stuck at # nccl initialization. 
with vllm_runner( - model_name=model.original_model, - enforce_eager=True, # faster tests - dtype=dtype, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as original_model: + model_name=model.original_model, + enforce_eager=True, # faster tests + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as original_model: original_outputs = original_model.generate_greedy_logprobs( - prompts[:-1], max_tokens, num_logprobs) + prompts[:-1], max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=original_outputs, @@ -138,12 +140,14 @@ def check_model_outputs( ) -@pytest.mark.skipif(not is_quant_method_supported("gguf"), - reason="gguf is not supported on this GPU type.") -@pytest.mark.parametrize("model", [ - pytest.param(test_config, marks=test_config.marks) - for test_config in MODELS -]) +@pytest.mark.skipif( + not is_quant_method_supported("gguf"), + reason="gguf is not supported on this GPU type.", +) +@pytest.mark.parametrize( + "model", + [pytest.param(test_config, marks=test_config.marks) for test_config in MODELS], +) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -157,12 +161,15 @@ def test_models( num_logprobs: int, tp_size: int, ) -> None: - check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens, - num_logprobs, tp_size) + check_model_outputs( + vllm_runner, example_prompts, model, dtype, max_tokens, num_logprobs, tp_size + ) -@pytest.mark.skipif(not is_quant_method_supported("gguf"), - reason="gguf is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gguf"), + reason="gguf is not supported on this GPU type.", +) @pytest.mark.parametrize("model", [LLAMA_CONFIG]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [8]) @@ -178,5 +185,6 @@ def test_distributed( num_logprobs: int, tp_size: int, ) -> None: - check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens, - num_logprobs, tp_size) + check_model_outputs( + vllm_runner, example_prompts, model, dtype, max_tokens, num_logprobs, tp_size + ) diff --git a/tests/models/quantization/test_gptq_bitblas.py b/tests/models/quantization/test_gptq_bitblas.py index c3aed77525de9..b29c5e769ce8f 100644 --- a/tests/models/quantization/test_gptq_bitblas.py +++ b/tests/models/quantization/test_gptq_bitblas.py @@ -7,9 +7,10 @@ As a result, in this test, we just confirm that the top selected tokens of the bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the +result in very slight nondeterminism for bitblas. As a result, we re-run the test up to 3 times to see if we pass. 
""" + from dataclasses import dataclass import pytest @@ -41,16 +42,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_gptq, - dtype=dtype, - quantization="bitblas") as bitblas_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="bitblas" + ) as bitblas_model: bitblas_outputs = bitblas_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_gptq_marlin.py b/tests/models/quantization/test_gptq_marlin.py index db70a3bd2c046..cf52ae39214d2 100644 --- a/tests/models/quantization/test_gptq_marlin.py +++ b/tests/models/quantization/test_gptq_marlin.py @@ -9,6 +9,7 @@ Note: Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin. As a result, we re-run the test up to 3 times to see if we pass. """ + import os import pytest @@ -26,20 +27,20 @@ MAX_MODEL_LEN = 1024 MODELS = [ # act_order==True, group_size=128 ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"), - # 8-bit, act_order==True, group_size=channelwise ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), - # 4-bit, act_order==True, group_size=128 - ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main") + ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main"), ] @pytest.mark.flaky(reruns=3) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="gptq_marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin") + or current_platform.is_rocm() + or not current_platform.is_cuda(), + reason="gptq_marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -55,29 +56,34 @@ def test_models( model_name, revision = model # Run marlin. - with vllm_runner(model_name=model_name, - revision=revision, - dtype=dtype, - quantization="marlin", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as gptq_marlin_model: - + with vllm_runner( + model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as gptq_marlin_model: gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( - example_prompts[:-1], max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs + ) _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. # The naive gptq kernel doesn't support bf16 yet. # Here we always compare fp16/bf16 gpt marlin kernel # to fp16 gptq kernel. 
- with vllm_runner(model_name=model_name, - revision=revision, - dtype="half", - quantization="gptq", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as gptq_model: + with vllm_runner( + model_name=model_name, + revision=revision, + dtype="half", + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts[:-1], max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_gptq_marlin_24.py b/tests/models/quantization/test_gptq_marlin_24.py index 9b86ae95ba5c7..85426ee5b0898 100644 --- a/tests/models/quantization/test_gptq_marlin_24.py +++ b/tests/models/quantization/test_gptq_marlin_24.py @@ -6,6 +6,7 @@ Note: GPTQ and Marlin_24 do not have bitwise correctness. As a result, in this test, we just confirm that the top selected tokens of the Marlin/GPTQ models are in the top 3 selections of each other. """ + from dataclasses import dataclass import pytest @@ -24,15 +25,18 @@ class ModelPair: model_pairs = [ # 4-bit, group_size == 128 - ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"), + ModelPair( + model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128", + ), # # 4-bit, group_size == channelwise # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", # model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), - # 8-bit, group_size == 128 - ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"), + ModelPair( + model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128", + ), # # 8-bit, group_size == channelwise # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", # model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), @@ -40,10 +44,12 @@ model_pairs = [ @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="Marlin24 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin_24") + or current_platform.is_rocm() + or not current_platform.is_cuda(), + reason="Marlin24 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [8]) @@ -56,16 +62,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_marlin, - dtype=dtype, - quantization="gptq_marlin_24") as marlin_24_model: + with vllm_runner( + model_pair.model_marlin, dtype=dtype, quantization="gptq_marlin_24" + ) as marlin_24_model: marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git 
a/tests/models/quantization/test_modelopt.py b/tests/models/quantization/test_modelopt.py index e23d4d9d211d8..db3af972bb778 100644 --- a/tests/models/quantization/test_modelopt.py +++ b/tests/models/quantization/test_modelopt.py @@ -5,6 +5,7 @@ """Tests Model Optimizer fp8 models against ground truth generation Note: these tests will only pass on H100 """ + import os import pytest @@ -22,13 +23,13 @@ MODELS = ["nvidia/Llama-3.1-8B-Instruct-FP8"] EXPECTED_STRS_MAP = { "nvidia/Llama-3.1-8B-Instruct-FP8": [ "You're referring to VLLM, a high-performance Large Language Model (LLM) inference and", - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and', + "Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ", + "The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and", 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - '**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir', - 'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる' + "**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir", + "The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to", + "The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of", + "Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる", ] } @@ -39,10 +40,12 @@ EXPECTED_STRS_MAP = { # the hardware being run on. # Disabled to prevent it from breaking the build @pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build.") -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") + reason="Prevent unstable test based on golden strings from breaking the build." 
+) +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: llm = LLM( @@ -55,12 +58,11 @@ def test_models(example_prompts, model_name) -> None: tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) for prompt in example_prompts ] params = SamplingParams(max_tokens=20, temperature=0) @@ -78,4 +80,5 @@ def test_models(example_prompts, model_name) -> None: generated_str = generations[i] expected_str = expected_strs[i] assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" + ) diff --git a/tests/models/quantization/test_mxfp4.py b/tests/models/quantization/test_mxfp4.py index 7b8a334bbc369..d598e405be817 100644 --- a/tests/models/quantization/test_mxfp4.py +++ b/tests/models/quantization/test_mxfp4.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # flake8: noqa -"""Tests Quark mxfp4 models against ground truth generation -""" +"""Tests Quark mxfp4 models against ground truth generation""" + import pytest from vllm import LLM, SamplingParams @@ -11,13 +11,13 @@ MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"] EXPECTED_STRS_MAP = { "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [ - '\n### Key Features\n\n* **High-throughput Inference**: vLL', - '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1', - 'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been', - 'A neural network is a machine learning model inspired by the structure of the human brain. It consists of', - '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol', - '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business', - 'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th', + "\n### Key Features\n\n* **High-throughput Inference**: vLL", + "\nArtificial intelligence (AI) has evolved significantly since its inception in the 1", + "Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been", + "A neural network is a machine learning model inspired by the structure of the human brain. 
It consists of", + "\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol", + "\nThe COVID-19 pandemic has had a profound impact on global economic structures and business", + "The Mona Lisa painting, created by Leonardo da Vinci in the early 16th", " everybody knows this proverbial saying, but did you know that it's not entirely accurate?", ] } @@ -38,4 +38,5 @@ def test_models(example_prompts, model_name) -> None: output_str = output.outputs[0].text expected_str = EXPECTED_STRS_MAP[model_name][i] assert expected_str == output_str, ( - f"Expected: {expected_str!r}\nvLLM: {output_str!r}") + f"Expected: {expected_str!r}\nvLLM: {output_str!r}" + ) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index b3c217e729e4a..9f45f142d68b1 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -4,6 +4,7 @@ """Tests Model Optimizer nvfp4 models against ground truth generation Note: these tests will only pass on B200 """ + import os from typing import List @@ -21,14 +22,14 @@ MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"] EXPECTED_STRS_MAP = { "nvidia/Llama-3.3-70B-Instruct-FP4": [ - 'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process', - 'A neural network is a type of machine learning model inspired by the structure and function of the human brain', - 'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts' + "vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference", + "Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ", + "Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process", + "A neural network is a type of machine learning model inspired by the structure and function of the human brain", + "In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push", + "The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading", + "The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of", + "Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts", ] } @@ -39,11 +40,13 @@ EXPECTED_STRS_MAP = { # the hardware being run on. # Disabled to prevent it from breaking the build @pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build " - " and test input model being too large and hanging the system.") -@pytest.mark.skipif(not is_quant_method_supported("modelopt_fp4"), - reason="modelopt_fp4 is not supported on this GPU type.") + reason="Prevent unstable test based on golden strings from breaking the build " + " and test input model being too large and hanging the system." 
+) +@pytest.mark.skipif( + not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: llm = LLM( @@ -56,12 +59,11 @@ def test_models(example_prompts, model_name) -> None: tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) for prompt in example_prompts ] params = SamplingParams(max_tokens=20, temperature=0) @@ -79,4 +81,5 @@ def test_models(example_prompts, model_name) -> None: generated_str = generations[i] expected_str = expected_strs[i] assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 25dbbd7fa9832..615b03998323a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -6,10 +6,11 @@ from dataclasses import dataclass, field from typing import Any, Literal, Optional import pytest +import torch from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION -from vllm.config import TokenizerMode +from vllm.config.model import ModelDType, TokenizerMode @dataclass(frozen=True) @@ -47,6 +48,23 @@ class _HfExamplesInfo: The reason for the minimum/maximum version requirement. """ + skip_tokenizer_init: bool = False + """ + If true, skip initialization of tokenizer and detokenizer. + """ + + dtype: ModelDType = "auto" + """ + The data type for the model weights and activations. + """ + + enforce_eager: bool = False + """ + Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + """ + is_available_online: bool = True """ Set this to ``False`` if the name of this architecture no longer exists on @@ -76,6 +94,15 @@ class _HfExamplesInfo: If not specified, the default revision will be used. """ + max_num_seqs: Optional[int] = None + """Maximum number of sequences to be processed in a single iteration.""" + + use_original_num_layers: bool = False + """ + If True, use the original number of layers from the model config + instead of minimal layers for testing. + """ + def check_transformers_version( self, *, @@ -87,8 +114,10 @@ class _HfExamplesInfo: If the installed transformers version does not meet the requirements, perform the given action. """ - if (self.min_transformers_version is None - and self.max_transformers_version is None): + if ( + self.min_transformers_version is None + and self.max_transformers_version is None + ): return None current_version = TRANSFORMERS_VERSION @@ -98,11 +127,17 @@ class _HfExamplesInfo: msg = f"`transformers=={current_version}` installed, but `transformers" # Only check the base version for the min/max version, otherwise preview # models cannot be run because `x.yy.0.dev0`<`x.yy.0` - if (check_min_version and min_version - and Version(cur_base_version) < Version(min_version)): + if ( + check_min_version + and min_version + and Version(cur_base_version) < Version(min_version) + ): msg += f">={min_version}` is required to run this model." 
- elif (check_max_version and max_version - and Version(cur_base_version) > Version(max_version)): + elif ( + check_max_version + and max_version + and Version(cur_base_version) > Version(max_version) + ): msg += f"<={max_version}` is required to run this model." else: return None @@ -134,378 +169,636 @@ class _HfExamplesInfo: pytest.skip(msg) -# yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", - trust_remote_code=True), - "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", - trust_remote_code=True), + "ApertusForCausalLM": _HfExamplesInfo( + "swiss-ai/Apertus-8B-Instruct-2509", + min_transformers_version="4.56.0", + ), + "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), + "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"), - "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", - trust_remote_code=True), - "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", - trust_remote_code=True), - "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", - trust_remote_code=True), - "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", - trust_remote_code=True), - "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", - min_transformers_version="4.56.0", - extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 - "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", - {"1b": "bigscience/bloomz-1b1"}), - "ChatGLMModel": _HfExamplesInfo("zai-org/chatglm3-6b", - trust_remote_code=True, - max_transformers_version="4.48"), - "ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501 - trust_remote_code=True), - "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", - trust_remote_code=True), - "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501 - trust_remote_code=True), + "ArcticForCausalLM": _HfExamplesInfo( + "Snowflake/snowflake-arctic-instruct", trust_remote_code=True + ), + "BaiChuanForCausalLM": _HfExamplesInfo( + "baichuan-inc/Baichuan-7B", trust_remote_code=True + ), + "BaichuanForCausalLM": _HfExamplesInfo( + "baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True + ), + "BailingMoeForCausalLM": _HfExamplesInfo( + "inclusionAI/Ling-lite-1.5", trust_remote_code=True + ), + "BailingMoeV2ForCausalLM": _HfExamplesInfo( + "inclusionAI/Ling-mini-2.0", trust_remote_code=True + ), + "BambaForCausalLM": _HfExamplesInfo( + "ibm-ai-platform/Bamba-9B-v1", + min_transformers_version="4.55.3", + extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}, + ), + "BloomForCausalLM": _HfExamplesInfo( + "bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"} + ), + "ChatGLMModel": _HfExamplesInfo( + "zai-org/chatglm3-6b", trust_remote_code=True, max_transformers_version="4.48" + ), + "ChatGLMForConditionalGeneration": _HfExamplesInfo( + "thu-coai/ShieldLM-6B-chatglm3", + trust_remote_code=True, + ), + "CohereForCausalLM": _HfExamplesInfo( + "CohereForAI/c4ai-command-r-v01", trust_remote_code=True + ), + "Cohere2ForCausalLM": _HfExamplesInfo( + "CohereForAI/c4ai-command-r7b-12-2024", + trust_remote_code=True, + ), + "CwmForCausalLM": _HfExamplesInfo( + "facebook/cwm", + trust_remote_code=True, + is_available_online=False, + ), "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"), - "DeciLMForCausalLM": 
_HfExamplesInfo("nvidia/Llama-3_3-Nemotron-Super-49B-v1", # noqa: E501 - trust_remote_code=True), + "DeciLMForCausalLM": _HfExamplesInfo( + "nvidia/Llama-3_3-Nemotron-Super-49B-v1", + trust_remote_code=True, + ), "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"), - "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501 - trust_remote_code=True), - "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 - trust_remote_code=True), - "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT", - min_transformers_version="4.54"), - "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", - min_transformers_version="4.54"), - "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", - trust_remote_code=True), - "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B", - min_transformers_version="4.54"), - "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 + "DeepseekV2ForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + ), + "DeepseekV3ForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-V3", + trust_remote_code=True, + ), + "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), + "Ernie4_5ForCausalLM": _HfExamplesInfo( + "baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54" + ), + "Ernie4_5_MoeForCausalLM": _HfExamplesInfo( + "baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54" + ), + "ExaoneForCausalLM": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True + ), + "Exaone4ForCausalLM": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54" + ), + "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), + "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), + "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), - "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it", - min_transformers_version="4.53"), + "Gemma3nForCausalLM": _HfExamplesInfo( + "google/gemma-3n-E2B-it", min_transformers_version="4.53" + ), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), - "Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5", - min_transformers_version="4.54"), # noqa: E501 - "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", - {"alias": "gpt2"}), - "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", - extras={"tiny": "bigcode/tiny_starcoder_py"}, # noqa: E501 - min_transformers_version="4.55.1", - transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 - "GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M", - {"6b": "EleutherAI/gpt-j-6b"}), - "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", - {"1b": "EleutherAI/pythia-1.4b"}), + "Glm4MoeForCausalLM": _HfExamplesInfo( + "zai-org/GLM-4.5", min_transformers_version="4.54" + ), + "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), + "GPTBigCodeForCausalLM": _HfExamplesInfo( + "bigcode/starcoder", + extras={"tiny": 
"bigcode/tiny_starcoder_py"}, + min_transformers_version="4.55.1", + transformers_version_reason="HF model broken in 4.55.0", + ), + "GPTJForCausalLM": _HfExamplesInfo( + "Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"} + ), + "GPTNeoXForCausalLM": _HfExamplesInfo( + "EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"} + ), "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 - "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 - "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", - trust_remote_code=True), - "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct", - trust_remote_code=True), + "GraniteMoeHybridForCausalLM": _HfExamplesInfo( + "ibm-granite/granite-4.0-tiny-preview", + min_transformers_version="4.55.3", + ), + "GraniteMoeSharedForCausalLM": _HfExamplesInfo( + "ibm-research/moe-7b-1b-active-shared-experts" + ), + "Grok1ModelForCausalLM": _HfExamplesInfo( + "hpcai-tech/grok-1", trust_remote_code=True + ), + "HunYuanMoEV1ForCausalLM": _HfExamplesInfo( + "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True + ), # TODO: Remove is_available_online once their config.json is fixed - "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124", - trust_remote_code=True, - is_available_online=False), - "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", - trust_remote_code=True), - "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", - trust_remote_code=True), - "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B", - trust_remote_code=True), - "InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct", - trust_remote_code=True), + "HunYuanDenseV1ForCausalLM": _HfExamplesInfo( + "tencent/Hunyuan-7B-Instruct-0124", + trust_remote_code=True, + is_available_online=False, + ), + "InternLMForCausalLM": _HfExamplesInfo( + "internlm/internlm-chat-7b", trust_remote_code=True + ), + "InternLM2ForCausalLM": _HfExamplesInfo( + "internlm/internlm2-chat-7b", trust_remote_code=True + ), + "InternLM2VEForCausalLM": _HfExamplesInfo( + "OpenGVLab/Mono-InternVL-2B", trust_remote_code=True + ), + "InternLM3ForCausalLM": _HfExamplesInfo( + "internlm/internlm3-8b-instruct", trust_remote_code=True + ), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), - "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - min_transformers_version="4.56.0", - extras={ - "tiny": "ai21labs/Jamba-tiny-dev", - "random": "ai21labs/Jamba-tiny-random", # noqa: E501 - }), - "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", - min_transformers_version="4.54"), - "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", - extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 - "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 - "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 - "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", - is_available_online=False), - "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - is_available_online=False), + "JambaForCausalLM": _HfExamplesInfo( + "ai21labs/AI21-Jamba-1.5-Mini", + min_transformers_version="4.55.3", + extras={ + "tiny": 
"ai21labs/Jamba-tiny-dev", + "random": "ai21labs/Jamba-tiny-random", + }, + ), + "Lfm2ForCausalLM": _HfExamplesInfo( + "LiquidAI/LFM2-1.2B", min_transformers_version="4.54" + ), + "Lfm2MoeForCausalLM": _HfExamplesInfo( + "LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58" + ), + "LlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-3.2-1B-Instruct", + extras={ + "guard": "meta-llama/Llama-Guard-3-1B", + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", + "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + }, + ), + "LLaMAForCausalLM": _HfExamplesInfo( + "decapoda-research/llama-7b-hf", is_available_online=False + ), + "Llama4ForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + is_available_online=False, + ), + "LongcatFlashForCausalLM": _HfExamplesInfo( + "meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True + ), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), - "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), - "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 - "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", - trust_remote_code=True), - "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", - trust_remote_code=True), + "Mamba2ForCausalLM": _HfExamplesInfo( + "mistralai/Mamba-Codestral-7B-v0.1", + min_transformers_version="4.55.3", + extras={ + "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", + }, + ), + "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), + "MiniCPMForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True + ), + "MiniCPM3ForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM3-4B", trust_remote_code=True + ), "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"), - "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", - trust_remote_code=True, - revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 - "MiniMaxM1ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-M1-40k", - trust_remote_code=True), + "MiniMaxText01ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-Text-01", + trust_remote_code=True, + revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3", + ), + "MiniMaxM1ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True + ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), - "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 - {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 - "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 + "MixtralForCausalLM": _HfExamplesInfo( + "mistralai/Mixtral-8x7B-Instruct-v0.1", + {"tiny": "TitanML/tiny-mixtral"}, + ), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), - "NemotronHForCausalLM": _HfExamplesInfo("nvidia/Nemotron-H-8B-Base-8K", - trust_remote_code=True), + "NemotronHForCausalLM": _HfExamplesInfo( + "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True + ), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), + "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), - "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", - {"1b": 
"facebook/opt-iml-max-1.3b"}), - "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", - trust_remote_code=True), + "OPTForCausalLM": _HfExamplesInfo( + "facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"} + ), + "OrionForCausalLM": _HfExamplesInfo( + "OrionStarAI/Orion-14B-Chat", trust_remote_code=True + ), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 - trust_remote_code=True, - v0_only=True, - max_model_len=10240), - "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", - trust_remote_code=True), - "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - max_transformers_version="4.53", - transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 - trust_remote_code=True), - "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", - max_transformers_version="4.53", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 - trust_remote_code=True), - "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct", - extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 + "PhiMoEForCausalLM": _HfExamplesInfo( + "microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True + ), + "Plamo2ForCausalLM": _HfExamplesInfo( + "pfnet/plamo-2-1b", + max_transformers_version="4.55.4", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True, + ), + "QWenLMHeadModel": _HfExamplesInfo( + "Qwen/Qwen-7B-Chat", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True, + ), + "Qwen2ForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"} + ), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "Qwen3NextForCausalLM": _HfExamplesInfo( + "Qwen/Qwen3-Next-80B-A3B-Instruct", + extras={"tiny-random": "tiny-random/qwen3-next-moe"}, + min_transformers_version="4.56.3", + ), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), - "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 - trust_remote_code=True, - is_available_online=False), + "SeedOssForCausalLM": _HfExamplesInfo( + "ByteDance-Seed/Seed-OSS-36B-Instruct", + trust_remote_code=True, + is_available_online=False, + ), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), - "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 + "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), - "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True), - "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", - trust_remote_code=True), - "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", - trust_remote_code=True), - "TeleFLMForCausalLM": 
_HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407", - trust_remote_code=True), - "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", - tokenizer="meta-llama/Llama-2-7b", - trust_remote_code=True), + "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", trust_remote_code=True), + "SolarForCausalLM": _HfExamplesInfo( + "upstage/solar-pro-preview-instruct", trust_remote_code=True + ), + "TeleChat2ForCausalLM": _HfExamplesInfo( + "Tele-AI/TeleChat2-3B", trust_remote_code=True + ), + "TeleFLMForCausalLM": _HfExamplesInfo( + "CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True + ), + "XverseForCausalLM": _HfExamplesInfo( + "xverse/XVERSE-7B-Chat", + tokenizer="meta-llama/Llama-2-7b", + trust_remote_code=True, + ), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), - "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", - trust_remote_code=True), + "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), - # [Encoder-decoder] - "BartModel": _HfExamplesInfo("facebook/bart-base"), - "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), - "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501 - hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 + "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), + "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), - "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - trust_remote_code=True), - "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5", - trust_remote_code=True, - hf_overrides={"architectures": ["GteNewModel"]}), # noqa: E501 - "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", - trust_remote_code=True), - "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 + "GteModel": _HfExamplesInfo( + "Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True + ), + "GteNewModel": _HfExamplesInfo( + "Alibaba-NLP/gte-base-en-v1.5", + trust_remote_code=True, + hf_overrides={"architectures": ["GteNewModel"]}, + ), + "InternLM2ForRewardModel": _HfExamplesInfo( + "internlm/internlm2-1_8b-reward", trust_remote_code=True + ), + "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), - "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True, v0_only=True), - "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True, v0_only=True), # noqa: E501 + "ModernBertModel": _HfExamplesInfo( + "Alibaba-NLP/gte-modernbert-base", trust_remote_code=True + ), + "NomicBertModel": _HfExamplesInfo( + "nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True + ), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", - max_transformers_version="4.53", - transformers_version_reason="HF model uses 
remote code that is not compatible with latest Transformers"), # noqa: E501 - "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", - max_transformers_version="4.53", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 + "Qwen2ForRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + ), + "Qwen2ForProcessRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + ), + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # [Multimodal] + "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), - "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", - trust_remote_code=True), - "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 - "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - is_available_online=False), # noqa: E501 + "Phi3VForCausalLM": _HfExamplesInfo( + "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True + ), + "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), + "PrithviGeoSpatialMAE": _HfExamplesInfo( + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model + # going OOM in CI + max_num_seqs=32, + ), + "Terratorch": _HfExamplesInfo( + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model going OOM in CI + max_num_seqs=32, + ), } _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { # [Decoder-only] - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 - + "GPT2ForSequenceClassification": _HfExamplesInfo( + "nie3e/sentiment-polish-gpt2-small" + ), # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo( + "cross-encoder/ms-marco-MiniLM-L-6-v2" + ), + "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), + "GteNewForSequenceClassification": _HfExamplesInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + 
trust_remote_code=True, + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + ), + "ModernBertForSequenceClassification": _HfExamplesInfo( + "Alibaba-NLP/gte-reranker-modernbert-base" + ), + "ModernBertForTokenClassification": _HfExamplesInfo( + "disham993/electrical-ner-ModernBERT-base" + ), + "RobertaForSequenceClassification": _HfExamplesInfo( + "cross-encoder/quora-roberta-base" + ), + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion - "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - v0_only=True, - hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 - "classifier_from_token": ["Yes"], # noqa: E501 - "method": "no_post_processing"}), # noqa: E501 - "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 - "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 - "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 + "GemmaForSequenceClassification": _HfExamplesInfo( + "BAAI/bge-reranker-v2-gemma", + hf_overrides={ + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + }, + ), + "LlamaForSequenceClassification": _HfExamplesInfo( + "Skywork/Skywork-Reward-V2-Llama-3.2-1B" + ), + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), + "Qwen3ForSequenceClassification": _HfExamplesInfo( + "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + ), } _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), - "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 - "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 - extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 - "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 - "Cohere2VisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/command-a-vision-07-2025"), # noqa: E501 - "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 - extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 + "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), + "Blip2ForConditionalGeneration": _HfExamplesInfo( + "Salesforce/blip2-opt-2.7b", + extras={"6b": "Salesforce/blip2-opt-6.7b"}, + ), + "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), + "Cohere2VisionForConditionalGeneration": _HfExamplesInfo( + "CohereLabs/command-a-vision-07-2025" + ), + "DeepseekVLV2ForCausalLM": _HfExamplesInfo( + "deepseek-ai/deepseek-vl2-tiny", + extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, + ), + "DotsOCRForCausalLM": _HfExamplesInfo( + "rednote-hilab/dots.ocr", trust_remote_code=True + ), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "Ernie4_5_VLMoeForConditionalGeneration": 
_HfExamplesInfo( + "baidu/ERNIE-4.5-VL-28B-A3B-PT", + trust_remote_code=True, + ), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), - "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 - min_transformers_version="4.53"), - "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 - "GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b", - trust_remote_code=True, - hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 - "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501 - "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", - min_transformers_version="4.56"), # noqa: E501 - "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", - trust_remote_code=True, - extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible."), # noqa: E501 - "HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501 - trust_remote_code=True), - "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 - {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501 - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55"), # noqa: E501 - "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", - trust_remote_code=True), # noqa: E501 - "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", - extras={"2B": "OpenGVLab/InternVL2-2B", - "3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501 - trust_remote_code=True), - "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 - trust_remote_code=True), - "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 - extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 - trust_remote_code=True), - "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - max_model_len=10240, - extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501 - ), - "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", - extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 - "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501 - "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 - "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 - "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 - "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 - "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", - trust_remote_code=True), - "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501 - trust_remote_code=True), - "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 
- trust_remote_code=True, - v0_only=True), - "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 - extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 - "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", - max_transformers_version="4.48", - transformers_version_reason="Incorrectly-detected `tensorflow` import.", # noqa: E501 - extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 - trust_remote_code=True), - "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", - trust_remote_code=True), - "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 - trust_remote_code=True), - "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, - max_transformers_version="4.53", - transformers_version_reason="HF model is not compatible", # noqa: E501 - extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", - "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 - "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", - trust_remote_code=True), - "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 - extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 - "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", - trust_remote_code=True, - max_transformers_version="4.48", - transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 - extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 - "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", - trust_remote_code=True), - "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501 - revision="refs/pr/70"), - "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 - tokenizer_mode="mistral"), - "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", - extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501 - trust_remote_code=True, - hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 - "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 - "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 - max_model_len=4096), + "Gemma3nForConditionalGeneration": _HfExamplesInfo( + "google/gemma-3n-E2B-it", + min_transformers_version="4.53", + ), + "GraniteSpeechForConditionalGeneration": _HfExamplesInfo( + "ibm-granite/granite-speech-3.3-2b" + ), + "GLM4VForCausalLM": _HfExamplesInfo( + "zai-org/glm-4v-9b", + trust_remote_code=True, + hf_overrides={"architectures": ["GLM4VForCausalLM"]}, + ), + "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), + "Glm4vMoeForConditionalGeneration": _HfExamplesInfo( + "zai-org/GLM-4.5V", min_transformers_version="4.56" + ), + "H2OVLChatModel": _HfExamplesInfo( + "h2oai/h2ovl-mississippi-800m", + trust_remote_code=True, + extras={"2b": "h2oai/h2ovl-mississippi-2b"}, + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + ), + "HCXVisionForCausalLM": _HfExamplesInfo( + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + trust_remote_code=True, + ), + "Idefics3ForConditionalGeneration": _HfExamplesInfo( + "HuggingFaceM4/Idefics3-8B-Llama3", + {"tiny": 
"HuggingFaceTB/SmolVLM-256M-Instruct"}, + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55", + ), + "InternS1ForConditionalGeneration": _HfExamplesInfo( + "internlm/Intern-S1", trust_remote_code=True + ), + "InternVLChatModel": _HfExamplesInfo( + "OpenGVLab/InternVL2-1B", + extras={ + "2B": "OpenGVLab/InternVL2-2B", + "3.0": "OpenGVLab/InternVL3-1B", + "3.5-qwen3": "OpenGVLab/InternVL3_5-1B", + "3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", + "3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + }, + trust_remote_code=True, + ), + "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), + "KeyeForConditionalGeneration": _HfExamplesInfo( + "Kwai-Keye/Keye-VL-8B-Preview", + trust_remote_code=True, + ), + "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo( + "Kwai-Keye/Keye-VL-1_5-8B", + trust_remote_code=True, + ), + "KimiVLForConditionalGeneration": _HfExamplesInfo( + "moonshotai/Kimi-VL-A3B-Instruct", + extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, + trust_remote_code=True, + ), + "Llama4ForConditionalGeneration": _HfExamplesInfo( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + max_model_len=10240, + extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, + ), + "LlavaForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-1.5-7b-hf", + extras={ + "mistral": "mistral-community/pixtral-12b", + "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic", + }, + ), + "LlavaNextForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-v1.6-mistral-7b-hf" + ), + "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo( + "llava-hf/LLaVA-NeXT-Video-7B-hf" + ), + "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), + "MantisForConditionalGeneration": _HfExamplesInfo( + "TIGER-Lab/Mantis-8B-siglip-llama3", + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ), + "MiDashengLMModel": _HfExamplesInfo( + "mispeech/midashenglm-7b", trust_remote_code=True + ), + "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), + "MiniCPMV": _HfExamplesInfo( + "openbmb/MiniCPM-Llama3-V-2_5", + extras={ + "2.6": "openbmb/MiniCPM-V-2_6", + "4.0": "openbmb/MiniCPM-V-4", + "4.5": "openbmb/MiniCPM-V-4_5", + }, + trust_remote_code=True, + ), + "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo( + "MiniMaxAI/MiniMax-VL-01", + trust_remote_code=True, + v0_only=True, + ), + "Mistral3ForConditionalGeneration": _HfExamplesInfo( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}, + ), + "MolmoForCausalLM": _HfExamplesInfo( + "allenai/Molmo-7B-D-0924", + max_transformers_version="4.48", + transformers_version_reason="Incorrectly-detected `tensorflow` import.", + extras={"olmo": "allenai/Molmo-7B-O-0924"}, + trust_remote_code=True, + ), + "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), + "Llama_Nemotron_Nano_VL": _HfExamplesInfo( + "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", + trust_remote_code=True, + ), + "NemotronH_Nano_VL_V2": _HfExamplesInfo( + "nano_vl_dummy", is_available_online=False, trust_remote_code=True + ), + "Ovis": _HfExamplesInfo( + "AIDC-AI/Ovis2-1B", + trust_remote_code=True, + max_transformers_version="4.53", + transformers_version_reason="HF model is not compatible", + extras={ + "1.6-llama": 
"AIDC-AI/Ovis1.6-Llama3.2-3B", + "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B", + }, + ), + "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True), + "PaliGemmaForConditionalGeneration": _HfExamplesInfo( + "google/paligemma-3b-mix-224", + extras={"v2": "google/paligemma2-3b-ft-docci-448"}, + ), + "Phi3VForCausalLM": _HfExamplesInfo( + "microsoft/Phi-3-vision-128k-instruct", + trust_remote_code=True, + max_transformers_version="4.48", + transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 + extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}, + ), + "Phi4MMForCausalLM": _HfExamplesInfo( + "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True + ), + "Phi4MultimodalForCausalLM": _HfExamplesInfo( + "microsoft/Phi-4-multimodal-instruct", + revision="refs/pr/70", + ), + "PixtralForConditionalGeneration": _HfExamplesInfo( + "mistralai/Pixtral-12B-2409", + tokenizer_mode="mistral", + ), + "QwenVLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen-VL", + extras={"chat": "Qwen/Qwen-VL-Chat"}, + trust_remote_code=True, + hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, + ), + "Qwen2AudioForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen2-Audio-7B-Instruct" + ), + "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-3B-Instruct", + max_model_len=4096, + ), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), - "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 - "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", - trust_remote_code=True), - "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", - trust_remote_code=True), - "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55"), # noqa: E501 - "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True), - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 - trust_remote_code=True), - "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501 - "Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501 - hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501 + "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), + "Qwen3VLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-VL-4B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", + is_available_online=False, + ), + "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-VL-30B-A3B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", + is_available_online=False, + ), + "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-Omni-30B-A3B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", + ), + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), + "SkyworkR1VChatModel": _HfExamplesInfo( + "Skywork/Skywork-R1V-38B", trust_remote_code=True + ), + "SmolVLMForConditionalGeneration": _HfExamplesInfo( + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55", + ), + "Step3VLForConditionalGeneration": 
_HfExamplesInfo( + "stepfun-ai/step3", trust_remote_code=True + ), + "UltravoxModel": _HfExamplesInfo( + "fixie-ai/ultravox-v0_5-llama-3_2-1b", + trust_remote_code=True, + ), + "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), + "Tarsier2ForConditionalGeneration": _HfExamplesInfo( + "omni-research/Tarsier2-Recap-7b", + hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}, + ), "VoxtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Voxtral-Mini-3B-2507", min_transformers_version="4.54", @@ -513,70 +806,120 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] - # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer - # Therefore, we borrow the BartTokenizer from the original Bart model - "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501 - trust_remote_code=True), # noqa: E501 - "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 - "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # [Cross-encoder] - "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 + "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), } _SPECULATIVE_DECODING_EXAMPLE_MODELS = { - "MedusaModel": _HfExamplesInfo("JackFram/llama-68m", - speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 + "MedusaModel": _HfExamplesInfo( + "JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random" + ), # Temporarily disabled. # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. 
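# Each entry in this dict pairs a target checkpoint with a draft model
# (`speculative_model`); test_initialization.py later forwards the pair as
# speculative_config={"model": <draft>, "num_speculative_tokens": 1}.
# A minimal sketch of the equivalent standalone usage, assuming only the
# keyword arguments that the initialization test itself passes to LLM(...)
# and reusing model names that already appear in this dict:
#
#     llm = LLM(
#         "luccafong/deepseek_mtp_main_random",
#         speculative_config={
#             "model": "luccafong/deepseek_mtp_draft_random",
#             "num_speculative_tokens": 1,
#         },
#         trust_remote_code=True,
#         load_format="dummy",
#     )
#
# The commented-out MLPSpeculatorPreTrainedModel entry below stays disabled
# until the TODO above is resolved.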
- # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", - # speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 - "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", - speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 - trust_remote_code=True), - "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", - speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 - trust_remote_code=True), - "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", - trust_remote_code=True, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 - "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 - trust_remote_code=True, - speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", - tokenizer="meta-llama/Llama-3.1-8B-Instruct"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # trust_remote_code=True, - # speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # tokenizer="Qwen/Qwen3-8B"), + # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo( + # "JackFram/llama-160m", + # speculative_model="ibm-ai-platform/llama-160m-accelerator" + # ), + "DeepSeekMTPModel": _HfExamplesInfo( + "luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random", + trust_remote_code=True, + ), + "EagleDeepSeekMTPModel": _HfExamplesInfo( + "eagle618/deepseek-v3-random", + speculative_model="eagle618/eagle-deepseek-v3-random", + trust_remote_code=True, + ), + "EagleLlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Meta-Llama-3-8B-Instruct", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", + ), + "Eagle3LlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-3.1-8B-Instruct", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct", + use_original_num_layers=True, + max_model_len=10240, + ), + "LlamaForCausalLMEagle3": _HfExamplesInfo( + "Qwen/Qwen3-8B", + trust_remote_code=True, + speculative_model="AngelSlim/Qwen3-8B_eagle3", + tokenizer="Qwen/Qwen3-8B", + use_original_num_layers=True, + ), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", - tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 - "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", - trust_remote_code=True, - is_available_online=False, - speculative_model="openbmb/MiniCPM-2B-sft-bf16", - tokenizer="openbmb/MiniCPM-2B-sft-bf16"), - "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", - trust_remote_code=True, - speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"), - "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", - speculative_model="zai-org/GLM-4.5", - min_transformers_version="4.54", - is_available_online=False), - "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", - trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL") + tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct", + ), + "EagleMiniCPMForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM-1B-sft-bf16", + trust_remote_code=True, + 
is_available_online=False, + speculative_model="openbmb/MiniCPM-2B-sft-bf16", + tokenizer="openbmb/MiniCPM-2B-sft-bf16", + ), + "ErnieMTPModel": _HfExamplesInfo( + "baidu/ERNIE-4.5-21B-A3B-PT", + trust_remote_code=True, + speculative_model="baidu/ERNIE-4.5-21B-A3B-PT", + ), + "Glm4MoeMTPModel": _HfExamplesInfo( + "zai-org/GLM-4.5", + speculative_model="zai-org/GLM-4.5", + min_transformers_version="4.56", + is_available_online=False, + ), + "LongCatFlashMTPModel": _HfExamplesInfo( + "meituan-longcat/LongCat-Flash-Chat", + trust_remote_code=True, + speculative_model="meituan-longcat/LongCat-Flash-Chat", + ), + "MiMoMTPModel": _HfExamplesInfo( + "XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True, + speculative_model="XiaomiMiMo/MiMo-7B-RL", + ), + "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-7B-Instruct", + speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl", + ), + "Qwen3NextMTP": _HfExamplesInfo( + "Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3" + ), } _TRANSFORMERS_BACKEND_MODELS = { - "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"), - "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 - "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), + "TransformersEmbeddingModel": _HfExamplesInfo( + "BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0" + ), + "TransformersForSequenceClassification": _HfExamplesInfo( + "papluca/xlm-roberta-base-language-detection", + min_transformers_version="4.57.0.dev0", + ), + "TransformersForCausalLM": _HfExamplesInfo( + "hmellor/Ilama-3.2-1B", trust_remote_code=True + ), + "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "TransformersMoEForCausalLM": _HfExamplesInfo( + "allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEForMultimodalLM": _HfExamplesInfo( + "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEEmbeddingModel": _HfExamplesInfo( + "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEForSequenceClassification": _HfExamplesInfo( + "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" + ), } _EXAMPLE_MODELS = { @@ -599,7 +942,12 @@ class HfExampleModels: return self.hf_models.keys() def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: - return self.hf_models[model_arch] + try: + return self.hf_models[model_arch] + except KeyError: + raise ValueError( + f"No example model defined for {model_arch}; please update this file." + ) from None def find_hf_info(self, model_id: str) -> _HfExamplesInfo: for info in self.hf_models.values(): @@ -611,7 +959,9 @@ class HfExampleModels: if any(extra == model_id for extra in info.extras.values()): return info - raise ValueError(f"No example model defined for {model_id}") + raise ValueError( + f"No example model defined for {model_id}; please update this file." 
+ ) HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index bbd3da982af84..f501798ffa36b 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,21 +7,55 @@ from unittest.mock import patch import pytest from vllm import LLM -from vllm.config import ModelImpl -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import get_kv_cache_config +from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, +) from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test -from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, - HF_EXAMPLE_MODELS, HfExampleModels) +from .registry import ( + _TRANSFORMERS_BACKEND_MODELS, + AUTO_EXAMPLE_MODELS, + HF_EXAMPLE_MODELS, + HfExampleModels, +) from .utils import dummy_hf_overrides +# This minimal list of model architectures is smaller than the total list of +# supported models. The intention is that in the "typical" regression testing +# scenario, we only test initializing these models. This subset was chosen +# to include representative examples of model varieties/workloads (conditional +# generation, sequence classification, causal LM, ranking, chat, reward model, +# multimodal, geospatial, voice, embedding, MTP) +MINIMAL_MODEL_ARCH_LIST = [ + "LlavaForConditionalGeneration", + "Llama4ForConditionalGeneration", + "BertForSequenceClassification", + "Gemma3nForCausalLM", + "JinaVLForRanking", + "InternVLChatModel", + "InternLM2ForRewardModel", + "TransformersForMultimodalLM", + "PrithviGeoSpatialMAE", + "UltravoxModel", + "DeepSeekMTPModel", + "XLMRobertaModel", +] + +# This list is the complement of the minimal list above. The intention is that +# this list of models is only tested in a "special case" i.e. most PRs should +# not test these models +OTHER_MODEL_ARCH_LIST = set(HF_EXAMPLE_MODELS.get_supported_archs()) - set( + MINIMAL_MODEL_ARCH_LIST +) + @create_new_process_for_each_test() -def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, - EXAMPLE_MODELS: HfExampleModels): +def can_initialize( + model_arch: str, monkeypatch: pytest.MonkeyPatch, EXAMPLE_MODELS: HfExampleModels +): """The reason for using create_new_process_for_each_test is to avoid the WARNING: "We must use the 'spawn' multiprocessing start method. 
Overriding @@ -34,74 +68,87 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - hf_overrides_fn = partial(dummy_hf_overrides, - model_arch=model_arch, - exist_overrides=model_info.hf_overrides) - - if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): - from vllm.model_executor.models.llama4 import Llama4ForCausalLM - from vllm.model_executor.models.registry import ModelRegistry - ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) + hf_overrides_fn = partial( + dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides, + use_original_num_layers=getattr(model_info, "use_original_num_layers", False), + ) # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( + kv_cache_configs = get_kv_cache_configs( vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, + kv_cache_specs, + [10 * GiB_bytes], ) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): + with ( + patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), + monkeypatch.context() as m, + ): if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") - if model_arch == "Phi4FlashForCausalLM": - # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend - m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") + # NOTE(woosuk): skip the test for V0-only models + return if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3. 
- m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + if model_arch == "WhisperForConditionalGeneration": + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") LLM( model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, + enforce_eager=model_info.enforce_eager, + skip_tokenizer_init=model_info.skip_tokenizer_init, + dtype=model_info.dtype, speculative_config={ "model": model_info.speculative_model, "num_speculative_tokens": 1, - } if model_info.speculative_model else None, + } + if model_info.speculative_model + else None, trust_remote_code=model_info.trust_remote_code, max_model_len=model_info.max_model_len, # these tests seem to produce leftover memory gpu_memory_utilization=0.80, load_format="dummy", - model_impl=ModelImpl.TRANSFORMERS - if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, + model_impl="transformers" + if model_arch in _TRANSFORMERS_BACKEND_MODELS + else "vllm", hf_overrides=hf_overrides_fn, + max_num_seqs=model_info.max_num_seqs, ) -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) -def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("model_arch", MINIMAL_MODEL_ARCH_LIST) +def test_can_initialize_small_subset(model_arch: str, monkeypatch: pytest.MonkeyPatch): + """Test initializing small subset of supported models""" if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) -@pytest.mark.parametrize("model_arch", - AUTO_EXAMPLE_MODELS.get_supported_archs()) -def test_implicit_converted_models(model_arch: str, - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("model_arch", OTHER_MODEL_ARCH_LIST) +def test_can_initialize_large_subset(model_arch: str, monkeypatch: pytest.MonkeyPatch): + """Test initializing large subset of supported models + + This test covers the complement of the tests covered in the "small subset" + test. 
+ """ + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") + can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) + + +@pytest.mark.parametrize("model_arch", AUTO_EXAMPLE_MODELS.get_supported_archs()) +def test_implicit_converted_models(model_arch: str, monkeypatch: pytest.MonkeyPatch): can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 4aa7bb7297893..15e94eef4aa00 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -50,9 +50,9 @@ def test_oot_registration_embedding( with monkeypatch.context() as m: m.setenv("VLLM_PLUGINS", "register_dummy_model") prompts = ["Hello, my name is", "The text does not matter"] - llm = LLM(model=dummy_gemma2_embedding_path, - load_format="dummy", - max_model_len=2048) + llm = LLM( + model=dummy_gemma2_embedding_path, load_format="dummy", max_model_len=2048 + ) outputs = llm.embed(prompts) for output in outputs: @@ -69,27 +69,28 @@ def test_oot_registration_multimodal( ): with monkeypatch.context() as m: m.setenv("VLLM_PLUGINS", "register_dummy_model") - prompts = [{ - "prompt": "What's in the image?<image>", - "multi_modal_data": { - "image": image + prompts = [ + { + "prompt": "What's in the image?<image>", + "multi_modal_data": {"image": image}, }, - }, { - "prompt": "Describe the image<image>", - "multi_modal_data": { - "image": image + { + "prompt": "Describe the image<image>", + "multi_modal_data": {"image": image}, }, - }] + ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=dummy_llava_path, - load_format="dummy", - max_num_seqs=1, - trust_remote_code=True, - gpu_memory_utilization=0.98, - max_model_len=4096, - enforce_eager=True, - limit_mm_per_prompt={"image": 1}) + llm = LLM( + model=dummy_llava_path, + load_format="dummy", + max_num_seqs=1, + trust_remote_code=True, + gpu_memory_utilization=0.98, + max_model_len=4096, + enforce_eager=True, + limit_mm_per_prompt={"image": 1}, + ) first_token = llm.get_tokenizer().decode(0) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 8769ad45eb93e..9017a0fd91407 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,16 +6,22 @@ import warnings import pytest import torch.cuda -from vllm.model_executor.models import (is_pooling_model, - is_text_generation_model, - supports_multimodal) -from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model, - as_seq_cls_model) -from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, - _SPECULATIVE_DECODING_MODELS, - _TEXT_GENERATION_MODELS, - ModelRegistry) +from vllm.model_executor.models import ( + is_pooling_model, + is_text_generation_model, + supports_multimodal, +) +from vllm.model_executor.models.adapters import ( + as_embedding_model, + as_reward_model, + as_seq_cls_model, +) +from vllm.model_executor.models.registry import ( + _MULTIMODAL_MODELS, + _SPECULATIVE_DECODING_MODELS, + _TEXT_GENERATION_MODELS, + ModelRegistry, +) from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -24,6 +30,9 @@ from .registry import HF_EXAMPLE_MODELS @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): + # Skip if transformers version is incompatible + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + 
model_info.check_transformers_version(on_fail="skip") # Ensure all model classes can be imported successfully model_cls = ModelRegistry._try_load_model_cls(model_arch) assert model_cls is not None @@ -31,8 +40,7 @@ def test_registry_imports(model_arch): if model_arch in _SPECULATIVE_DECODING_MODELS: return # Ignore these models which do not have a unified format - if (model_arch in _TEXT_GENERATION_MODELS - or model_arch in _MULTIMODAL_MODELS): + if model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS: assert is_text_generation_model(model_cls) # All vLLM models should be convertible to a pooling model @@ -45,14 +53,16 @@ def test_registry_imports(model_arch): @create_new_process_for_each_test() -@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ - ("LlamaForCausalLM", False, False, False), - ("MllamaForConditionalGeneration", True, False, False), - ("LlavaForConditionalGeneration", True, True, False), - ("BertForSequenceClassification", False, False, True), - ("RobertaForSequenceClassification", False, False, True), - ("XLMRobertaForSequenceClassification", False, False, True), -]) +@pytest.mark.parametrize( + "model_arch,is_mm,init_cuda,is_ce", + [ + ("LlamaForCausalLM", False, False, False), + ("LlavaForConditionalGeneration", True, True, False), + ("BertForSequenceClassification", False, False, True), + ("RobertaForSequenceClassification", False, False, True), + ("XLMRobertaForSequenceClassification", False, False, True), + ], +) def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): model_info = ModelRegistry._try_inspect_model_cls(model_arch) assert model_info is not None @@ -68,7 +78,8 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): warnings.warn( "This model no longer initializes CUDA on import. " "Please test using a different one.", - stacklevel=2) + stacklevel=2, + ) @create_new_process_for_each_test() @@ -80,7 +91,8 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): # ("MLPSpeculatorPreTrainedModel", False, False), ("DeepseekV2ForCausalLM", True, False), ("Qwen2VLForConditionalGeneration", True, True), - ]) + ], +) def test_registry_is_pp(model_arch, is_pp, init_cuda): model_info = ModelRegistry._try_inspect_model_cls(model_arch) assert model_info is not None @@ -95,13 +107,16 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda): warnings.warn( "This model no longer initializes CUDA on import. 
" "Please test using a different one.", - stacklevel=2) + stacklevel=2, + ) def test_hf_registry_coverage(): - untested_archs = (ModelRegistry.get_supported_archs() - - HF_EXAMPLE_MODELS.get_supported_archs()) + untested_archs = ( + ModelRegistry.get_supported_archs() - HF_EXAMPLE_MODELS.get_supported_archs() + ) assert not untested_archs, ( "Please add the following architectures to " - f"`tests/models/registry.py`: {untested_archs}") + f"`tests/models/registry.py`: {untested_archs}" + ) diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py new file mode 100644 index 0000000000000..cadce5d2b2bb7 --- /dev/null +++ b/tests/models/test_terratorch.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.conftest import VllmRunner + + +@pytest.mark.parametrize( + "model", + [ + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "mgazz/Prithvi_v2_eo_300_tl_unet_agb", + ], +) +def test_inference( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) + location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) + prompt = dict( + prompt_token_ids=[1], + multi_modal_data=dict( + pixel_values=pixel_values, location_coords=location_coords + ), + ) + with vllm_runner( + model, + runner="pooling", + dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, + ) as vllm_model: + vllm_output = vllm_model.llm.encode(prompt) + assert torch.equal( + torch.isnan(vllm_output[0].outputs.data).any(), torch.tensor(False) + ) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 66ff8f7a54d31..b434c0955be7e 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test the functionality of the Transformers backend.""" + from typing import Any, Optional, Union import pytest @@ -8,9 +9,15 @@ import pytest from vllm.platforms import current_platform from ..conftest import HfRunner, VllmRunner -from ..core.block.e2e.test_correctness_sliding_window import prep_prompts -from ..utils import multi_gpu_test -from .utils import check_logprobs_close +from ..utils import multi_gpu_test, prep_prompts +from .registry import HF_EXAMPLE_MODELS +from .utils import check_embeddings_close, check_logprobs_close + + +def get_model(arch: str) -> str: + model_info = HF_EXAMPLE_MODELS.get_hf_info(arch) + model_info.check_transformers_version(on_fail="skip") + return model_info.default def check_implementation( @@ -54,13 +61,16 @@ def check_implementation( @pytest.mark.skipif( current_platform.is_rocm(), - reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.") + reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.", +) @pytest.mark.parametrize( "model,model_impl", [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE - ]) # trust_remote_code=True by default + ("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE + ], +) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -68,23 +78,34 @@ def test_models( model: str, model_impl: 
str, ) -> None: - check_implementation(hf_runner, - vllm_runner, - example_prompts, - model, - model_impl=model_impl) + import transformers + from packaging.version import Version + + installed = Version(transformers.__version__) + required = Version("4.57.0.dev0") + if model == "allenai/OLMoE-1B-7B-0924" and installed < required: + pytest.skip( + "MoE models with the Transformers backend require " + f"transformers>={required}, but got {installed}" + ) + + check_implementation( + hf_runner, vllm_runner, example_prompts, model, model_impl=model_impl + ) def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None: prompts, _, _ = prep_prompts(4, (800, 801)) kwargs_ref = {"max_model_len": 8192, "enforce_eager": True} kwargs_test = {"model_impl": "transformers", **kwargs_ref} - check_implementation(vllm_runner, - vllm_runner, - prompts, - model="hmellor/tiny-random-Gemma2ForCausalLM", - kwargs_ref=kwargs_ref, - kwargs_test=kwargs_test) + check_implementation( + vllm_runner, + vllm_runner, + prompts, + model="hmellor/tiny-random-Gemma2ForCausalLM", + kwargs_ref=kwargs_ref, + kwargs_test=kwargs_test, + ) @multi_gpu_test(num_gpus=2) @@ -94,24 +115,28 @@ def test_distributed( example_prompts, ): kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} - check_implementation(hf_runner, - vllm_runner, - example_prompts, - "meta-llama/Llama-3.2-1B-Instruct", - kwargs_test=kwargs) - - -@pytest.mark.skipif( - current_platform.is_rocm(), - reason="bitsandbytes quantization is currently not supported in rocm.") -@pytest.mark.parametrize("model, quantization_kwargs", [ - ( + check_implementation( + hf_runner, + vllm_runner, + example_prompts, "meta-llama/Llama-3.2-1B-Instruct", - { - "quantization": "bitsandbytes", - }, - ), -]) + kwargs_test=kwargs, + ) + + +@pytest.mark.parametrize( + "model, quantization_kwargs", + [ + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}), + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}), + ( + "meta-llama/Llama-3.2-1B-Instruct", + { + "quantization": "bitsandbytes", + }, + ), + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_quantization( @@ -122,22 +147,34 @@ def test_quantization( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner( - model, model_impl="auto", enforce_eager=True, - **quantization_kwargs) as vllm_model: # type: ignore[arg-type] - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + if ( + current_platform.is_rocm() + and quantization_kwargs.get("quantization", "") == "bitsandbytes" + ): + pytest.skip("bitsandbytes quantization is currently not supported in rocm.") with vllm_runner( - model, - model_impl="transformers", - enforce_eager=True, - **quantization_kwargs) as vllm_model: # type: ignore[arg-type] + model, + model_impl="auto", + enforce_eager=True, + **quantization_kwargs, # type: ignore[arg-type] + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs + ) + + with vllm_runner( + model, + model_impl="transformers", + enforce_eager=True, + **quantization_kwargs, # type: ignore[arg-type] + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.using_transformers_backend() transformers_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs + ) 
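# The two runs above exercise the same quantized checkpoint through
# model_impl="auto" (vLLM's native implementation where one exists) and
# model_impl="transformers". check_logprobs_close below (from
# tests/models/utils.py) broadly treats the runs as matching when each
# sampled token appears among the other run's top-`num_logprobs` candidates,
# so small numerical drift between the two code paths is tolerated while
# real regressions still fail the test.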
check_logprobs_close( outputs_0_lst=transformers_outputs, @@ -153,51 +190,66 @@ def test_quantization( # Layers live in `layers` "Qwen/Qwen3-Embedding-0.6B", # Layers live in `model.layers` - "meta-llama/Llama-3.2-1B-Instruct" + "meta-llama/Llama-3.2-1B-Instruct", ], ) def test_embed_loading(vllm_runner, model): - with vllm_runner(model, - max_model_len=1024, - enforce_eager=True, - runner="pooling", - model_impl="transformers") as model_test: + with vllm_runner( + model, + max_model_len=1024, + enforce_eager=True, + runner="pooling", + model_impl="transformers", + ) as model_test: model_config = model_test.llm.llm_engine.model_config assert model_config.using_transformers_backend() @pytest.mark.parametrize( - "model", - ["jason9693/Qwen2.5-1.5B-apeach"], + "arch", ["TransformersEmbeddingModel", "TransformersForSequenceClassification"] ) -@pytest.mark.parametrize("dtype", ["float"]) -def test_classify( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, -) -> None: - import torch - from transformers import AutoModelForSequenceClassification +def test_pooling(hf_runner, vllm_runner, example_prompts, arch): + model = get_model(arch) - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - model_impl="transformers") as vllm_model: + vllm_kwargs = dict( + max_model_len=None, + model_impl="transformers", + compilation_config=dict(cudagraph_capture_sizes=[8]), + ) + + hf_kwargs = dict() + if arch == "TransformersEmbeddingModel": + hf_kwargs["is_sentence_transformer"] = True + elif arch == "TransformersForSequenceClassification": + from transformers import AutoModelForSequenceClassification + + hf_kwargs["auto_cls"] = AutoModelForSequenceClassification + + # The example_prompts has ending "\n", for example: + # "Write a short story about a robot that dreams for the first time.\n" + # sentence_transformers will strip the input texts, see: + # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159 + # This makes the input_ids different between hf_model and vllm_model. + # So we need to strip the input texts to avoid test failing. 
+ example_prompts = [str(s).strip() for s in example_prompts] + + with ( + vllm_runner(model, **vllm_kwargs) as vllm_model, + hf_runner(model, **hf_kwargs) as hf_model, + ): model_config = vllm_model.llm.llm_engine.model_config assert model_config.using_transformers_backend() - vllm_outputs = vllm_model.classify(example_prompts) + if arch == "TransformersEmbeddingModel": + vllm_outputs = vllm_model.embed(example_prompts) + hf_outputs = hf_model.encode(example_prompts) + elif arch == "TransformersForSequenceClassification": + vllm_outputs = vllm_model.classify(example_prompts) + hf_outputs = hf_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: - hf_outputs = hf_model.classify(example_prompts) - - for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): - hf_output = torch.tensor(hf_output) - vllm_output = torch.tensor(vllm_output) - - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index b52327a1844f6..7cc4ee3c1856f 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest import torch from vllm.model_executor.models.utils import AutoWeightsLoader +pytestmark = pytest.mark.cpu_test + class ModuleWithBatchNorm(torch.nn.Module): - def __init__(self): super().__init__() self.bn = torch.nn.BatchNorm1d(2) @@ -17,7 +19,6 @@ class ModuleWithBatchNorm(torch.nn.Module): class ModuleWithNestedBatchNorm(torch.nn.Module): - def __init__(self): super().__init__() self.nested_mod = ModuleWithBatchNorm() @@ -64,9 +65,11 @@ def test_module_with_child_containing_batchnorm_can_autoload(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod) @@ -74,9 +77,9 @@ def test_module_with_child_containing_batchnorm_can_autoload(): # Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 @@ -98,9 +101,11 @@ def test_module_skip_prefix(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."]) @@ -108,9 +113,9 @@ def test_module_skip_prefix(): # 
Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 @@ -134,9 +139,11 @@ def test_module_skip_substr(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."]) @@ -144,7 +151,7 @@ def test_module_skip_substr(): # Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 310d3a3719b65..b323bca79f4e7 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,15 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math import pytest import torch +import torch.multiprocessing as mp -from vllm.model_executor.models.vision import resolve_visual_encoder_outputs +from tests.utils import multi_gpu_test +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.models.vision import ( + get_load_balance_assignment, + resolve_visual_encoder_outputs, + run_dp_sharded_mrope_vision_model, + run_dp_sharded_vision_model, +) +from vllm.platforms import current_platform +from vllm.utils import get_open_port, update_environment_variables + +pytestmark = pytest.mark.cpu_test @pytest.mark.parametrize( - ("feature_sample_layers", "num_layers_loaded", "max_possible_layers", - "expected_features"), + ("select_layers", "num_layers_loaded", "max_possible_layers", "expected_features"), [ # All layers loaded ([1, 10], 10, 10, [1, 10]), @@ -17,19 +33,456 @@ from vllm.model_executor.models.vision import resolve_visual_encoder_outputs # Some layers not loaded ([1, 10], 10, 20, [1, 10]), ([-20, -11], 10, 20, [1, 10]), - ]) -def test_resolve_visual_encoder_outputs(feature_sample_layers, - num_layers_loaded, max_possible_layers, - expected_features): + ], +) +def test_resolve_visual_encoder_outputs( + select_layers, num_layers_loaded, max_possible_layers, expected_features +): """ Test that offsets are correctly handled for vision feature layers. 
""" - encoder_outputs = [ - torch.tensor([idx]) for idx in range(num_layers_loaded + 1) - ] + encoder_outputs = [torch.tensor([idx]) for idx in range(num_layers_loaded + 1)] output_tensor = resolve_visual_encoder_outputs( encoder_outputs=encoder_outputs, - feature_sample_layers=feature_sample_layers, post_layer_norm=None, - max_possible_layers=max_possible_layers) + select_layers=select_layers, + max_possible_layers=max_possible_layers, + ) assert torch.equal(torch.tensor(expected_features), output_tensor) + + +class SimpleLinearModel(torch.nn.Module): + """A simple linear vision model for testing.""" + + def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): + super().__init__() + self.flatten = torch.nn.Flatten() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor): + # Flatten the input and apply linear transformation + x = self.flatten(x) + return self.linear(x) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 4, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_vision_model_vs_direct( + local_rank: int, world_size: int, batch_size: int, master_port: int +): + """ + Test that run_dp_sharded_vision_model produces the same results as + calling the model directly. + """ + + # Set random seed for reproducibility + current_platform.seed_everything(0) + + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create a test input tensor + image_input = torch.randn(batch_size, 3, 224, 224) + + # Create a simple linear model + vision_model = SimpleLinearModel() + + # Run the model directly on the full input + with torch.inference_mode(): + direct_output = vision_model(image_input) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_vision_model(image_input, vision_model) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," + "expected_grouped_sizes_per_gpu,test_description", + [ + # Empty input + ([], 2, [], [0, 0], [0, 0], "empty input"), + # Fewer samples than GPUs + ( + [100, 200], + 4, + [1, 0], + [1, 1, 0, 0], + [200, 100, 0, 0], + "fewer samples than GPUs", + ), + # Single GPU + ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), + # Balanced assignment + ( + [100, 100, 100, 100], + 2, + [0, 2, 1, 3], + [2, 2], + [200, 200], + "balanced assignment", + ), + # Unbalanced sizes - this one is trickier since the algorithm is greedy + ( + 
[1000, 100, 200, 50], + 2, + [0, 2, 1, 3], + [1, 3], + [1000, 350], + "unbalanced sizes", + ), + ], +) +def test_get_load_balance_assignment_cases( + sizes, + num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description, +): + """Test get_load_balance_assignment with various input cases.""" + result = get_load_balance_assignment(sizes, num_gpus=num_gpus) + (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result + + # Common assertions for all cases + assert len(shuffle_indices) == len(sizes) + assert len(gpu_sample_counts) == num_gpus + assert len(grouped_sizes_per_gpu) == num_gpus + assert sum(gpu_sample_counts) == len(sizes) + + assert shuffle_indices == expected_shuffle_indices + + assert gpu_sample_counts == expected_gpu_sample_counts + assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu + + +class SimpleMRopeVisionModel(torch.nn.Module): + """A simple vision model for testing mrope functionality.""" + + def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.out_hidden_size = out_hidden_size + self.linear = torch.nn.Linear(768, out_hidden_size) + + def forward(self, pixel_values: torch.Tensor, grid_thw_list: list[list[int]]): + """Simple forward pass that simulates spatial merging.""" + # Apply linear transformation + embeddings = self.linear(pixel_values) + + # Simulate spatial merging by reducing the number of patches + merge_factor = self.spatial_merge_size * self.spatial_merge_size + + # Group patches and merge spatially + merged_embeddings = [] + start_idx = 0 + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + end_idx = start_idx + num_patches + + # Get patches for this image + image_patches = embeddings[start_idx:end_idx] + + # Simulate spatial merging by averaging groups of patches + merged_patches = num_patches // merge_factor + if merged_patches > 0: + # Reshape and average to simulate merging + reshaped = image_patches[: merged_patches * merge_factor].view( + merged_patches, merge_factor, -1 + ) + merged = reshaped.mean(dim=1) + merged_embeddings.append(merged) + + start_idx = end_idx + + if merged_embeddings: + return torch.cat(merged_embeddings, dim=0) + else: + return torch.empty( + (0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "batch_size", + [ + 1, # Single image + 3, # Small batch + 5, # Odd batch size (for testing padding) + ], +) +def test_run_dp_sharded_mrope_vision_model(batch_size: int): + world_size = 2 + # Launch processes + mp.spawn( + run_dp_sharded_mrope_vision_model_vs_direct, + args=( + world_size, + batch_size, + get_open_port(), + ), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_vs_direct( + local_rank: int, world_size: int, batch_size: int, master_port: int +): + """ + Test that run_dp_sharded_mrope_vision_model produces the same results as + calling the model directly. 
+ """ + # Set random seed for reproducibility + current_platform.seed_everything(0) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create test data + grid_thw_list = [] + pixel_values_list = [] + + for i in range(batch_size): + # Varying image sizes for better testing + t, h, w = 1, 4 + i, 4 + i + grid_thw_list.append([t, h, w]) + + num_patches = t * h * w + # Create random pixel values for this image + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + # Concatenate all pixel values + pixel_values = torch.cat(pixel_values_list, dim=0) + + # Create a simple mrope vision model + vision_model = SimpleMRopeVisionModel() + + # Run the model directly on the full input (only on rank 0) + if local_rank == 0: + with torch.inference_mode(): + direct_output = vision_model(pixel_values, grid_thw_list) + + # Run the model through the sharded function + with torch.inference_mode(): + sharded_output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + sharded_output = torch.cat(sharded_output, dim=0) + + # Check that the world size is set up correctly + assert get_tensor_model_parallel_world_size() == world_size + + # Compare outputs (only on rank 0) + if local_rank == 0: + # Check that the outputs have the same shape + assert direct_output.shape == sharded_output.shape + # Check that the outputs are close (they should be identical) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) + + +@multi_gpu_test(num_gpus=2) +def test_run_dp_sharded_mrope_vision_model_empty_input(): + world_size = 2 + mp.spawn( + run_dp_sharded_mrope_vision_model_empty_input_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_empty_input_worker( + local_rank: int, world_size: int, master_port: int +): + """Test run_dp_sharded_mrope_vision_model with empty input.""" + # Set up distributed environment + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create empty inputs + pixel_values = torch.empty((0, 768)) + grid_thw_list: list[list[int]] = [] + + vision_model = SimpleMRopeVisionModel() + + # Should handle empty input gracefully + with torch.inference_mode(): + output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + + assert len(output) == 0 + + +@multi_gpu_test(num_gpus=4) +def test_run_dp_sharded_mrope_vision_model_uneven_load(): + world_size = 4 + mp.spawn( + run_dp_sharded_mrope_vision_model_uneven_load_worker, + args=(world_size, get_open_port()), + nprocs=world_size, + ) + + +def run_dp_sharded_mrope_vision_model_uneven_load_worker( + local_rank: int, world_size: int, master_port: int 
+): + """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" + # Set up distributed environment + current_platform.seed_everything(123) + device = f"{current_platform.device_name}:{local_rank}" + current_platform.set_device(device) + torch.set_default_device(device) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Create images with very different sizes + grid_thw_list = [ + [1, 2, 2], # Small: 4 patches + [1, 8, 8], # Large: 64 patches + [1, 3, 3], # Medium: 9 patches + ] + + pixel_values_list = [] + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel() + + # Should handle uneven distribution without errors + with torch.inference_mode(): + output_tuple = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + + # Verify output shape is reasonable + merge_factor = vision_model.spatial_merge_size**2 + expected_output_patches = list( + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list + ) + + for i, output in enumerate(output_tuple): + assert output.shape[0] == expected_output_patches[i] + assert output.shape[1] == vision_model.out_hidden_size + + +@pytest.mark.parametrize("spatial_merge_size", [2, 4]) +def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): + """Test SimpleMRopeVisionModel with different spatial merge sizes.""" + device = current_platform.device_type + + grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images + pixel_values_list = [] + + for grid_thw in grid_thw_list: + num_patches = math.prod(grid_thw) + image_pixels = torch.randn(num_patches, 768, device=device) + pixel_values_list.append(image_pixels) + + pixel_values = torch.cat(pixel_values_list, dim=0) + vision_model = SimpleMRopeVisionModel(spatial_merge_size=spatial_merge_size).to( + device + ) + + with torch.inference_mode(): + output = vision_model(pixel_values, grid_thw_list) + + # Verify output dimensions based on spatial merging + total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) + merge_factor = spatial_merge_size**2 + expected_output_patches = total_patches // merge_factor + + assert output.shape[0] == expected_output_patches + assert output.shape[1] == vision_model.out_hidden_size diff --git a/tests/models/utils.py b/tests/models/utils.py index 84aeb927c5fa9..84697ad68d441 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,15 +3,17 @@ import warnings from collections.abc import Sequence -from typing import Any, NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.config import ModelConfig, ModelDType, RunnerOption -from vllm.inputs import InputContext -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.config.model import ModelConfig, ModelDType, RunnerOption +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs +from vllm.multimodal.processing import InputProcessingContext +from vllm.transformers_utils.tokenizer import 
cached_tokenizer_from_config from .registry import HF_EXAMPLE_MODELS @@ -31,16 +33,18 @@ def check_outputs_equal( """ assert len(outputs_0_lst) == len(outputs_1_lst) - for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate( + zip(outputs_0_lst, outputs_1_lst) + ): output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 # The text and token outputs should exactly match - fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}" + ) assert output_str_0 == output_str_1, fail_msg assert output_ids_0 == output_ids_1, fail_msg @@ -52,9 +56,9 @@ def check_outputs_equal( # * List of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, - float]], - SampleLogprobs]]] +TokensTextLogprobs = tuple[ + list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]] +] # Allow for tokens to be represented as str's rather than IDs; # tuple of @@ -63,9 +67,9 @@ TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, # * Optional list of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], - list[dict[str, - Logprob]]]]] +TextTextLogprobs = tuple[ + list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]] +] # Representation of generated sequence as a tuple of # * Token ID list @@ -75,18 +79,21 @@ TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], # # Allows prompt logprobs to be requested. TokensTextLogprobsPromptLogprobs = tuple[ - list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]], - Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]]] + list[int], + str, + Optional[Union[list[dict[int, float]], SampleLogprobs]], + Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]], +] def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs, - TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs, - TextTextLogprobs]], + outputs_0_lst: Sequence[ + Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + ], + outputs_1_lst: Sequence[ + Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + ], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -126,9 +133,9 @@ def check_logprobs_close( assert len(outputs_0_lst) == len(outputs_1_lst) # Loop through responses to each prompt. 
- for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate( + zip(outputs_0_lst, outputs_1_lst) + ): assert len(outputs_0) == len(outputs_1) if len(outputs_0) == 3: assert len(outputs_1) == 3 @@ -153,17 +160,18 @@ def check_logprobs_close( ) = outputs_1 # Test prompt logprobs closeness - if (prompt_logprobs_0 is not None - and prompt_logprobs_1 is not None): + if prompt_logprobs_0 is not None and prompt_logprobs_1 is not None: # Both sequences' prompt logprobs lists are not `None`` # (although individual list elements may be `None`); # for each token's logprobs: for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( - zip(prompt_logprobs_0, prompt_logprobs_1)): + zip(prompt_logprobs_0, prompt_logprobs_1) + ): fail_msg = ( f"Prompt logprobs test:" f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" - f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}" + ) if logprobs_elem_0 is None: # If the seq 0 token's logprobs are `None`, @@ -174,20 +182,24 @@ def check_logprobs_close( # the seq 1 token's logprobs must not be `None` assert logprobs_elem_1 is not None, fail_msg # Logprobs check: top-k token choices must be the same - assert (set(logprobs_elem_0.keys()) == set( - logprobs_elem_1.keys())), fail_msg + assert set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys() + ), fail_msg else: # Both sequence logprobs lists must be `None` - fail_msg = (f"Prompt logprobs test:" - f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" - f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}" + ) - assert (prompt_logprobs_0 is None - and prompt_logprobs_1 is None), fail_msg + assert prompt_logprobs_0 is None and prompt_logprobs_1 is None, fail_msg else: - raise ValueError(f"Outputs tuple must have 3 or 4 elements but " - f"{len(outputs_0)} elements were provided: " - f"{outputs_0}") + raise ValueError( + f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}" + ) if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) @@ -204,9 +216,9 @@ def check_logprobs_close( logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:] # Loop through generated tokens. 
- for idx, (output_id_0, - output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - + for idx, (output_id_0, output_id_1) in enumerate( + zip(output_ids_0, output_ids_1) + ): is_tok_mismatch = output_id_0 != output_id_1 # If generated tokens don't match @@ -221,7 +233,8 @@ def check_logprobs_close( f"Test{prompt_idx}:" f"\nMatched tokens:\t{output_ids_0[:idx]}" f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}" - f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}") + f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}" + ) assert logprobs_elem_0 is not None, fail_msg assert logprobs_elem_1 is not None, fail_msg @@ -242,9 +255,11 @@ def check_logprobs_close( if output_str_0 != output_str_1 and warn_on_mismatch: # The token outputs exactly match, # so the text outputs should exactly match as well - fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}" + ) with warnings.catch_warnings(): # This ensures that repeated warnings are shown @@ -263,7 +278,7 @@ def build_model_context( limit_mm_per_prompt: Optional[dict[str, int]] = None, mm_processor_cache_gb: int = 0, ): - """Creates an InputContext for a given model. + """Creates an InputProcessingContext for a given model. Args: model_id: ID of the model being considered. @@ -272,7 +287,7 @@ def build_model_context( limit_mm_per_prompt: Multimodal limits. Returns: - InputContext for the model being considered. + InputProcessingContext for the model being considered. """ model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -293,9 +308,15 @@ def build_model_context( limit_mm_per_prompt=limit_mm_per_prompt, mm_processor_cache_gb=mm_processor_cache_gb, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, **model_config_kwargs, ) - return InputContext(model_config) + + return InputProcessingContext( + model_config, + tokenizer=cached_tokenizer_from_config(model_config), + ) def check_embeddings_close( @@ -309,18 +330,22 @@ def check_embeddings_close( assert len(embeddings_0_lst) == len(embeddings_1_lst) for prompt_idx, (embeddings_0, embeddings_1) in enumerate( - zip(embeddings_0_lst, embeddings_1_lst)): + zip(embeddings_0_lst, embeddings_1_lst) + ): assert len(embeddings_0) == len(embeddings_1), ( - f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") + f"Length mismatch: {len(embeddings_0)} vs. 
{len(embeddings_1)}" + ) - sim = F.cosine_similarity(torch.tensor(embeddings_0), - torch.tensor(embeddings_1), - dim=0) + sim = F.cosine_similarity( + torch.tensor(embeddings_0), torch.tensor(embeddings_1), dim=0 + ) - fail_msg = (f"Test{prompt_idx}:" - f"\nCosine similarity: \t{sim:.4f}" - f"\n{name_0}:\t{embeddings_0[:16]!r}" - f"\n{name_1}:\t{embeddings_1[:16]!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\nCosine similarity: \t{sim:.4f}" + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}" + ) assert sim >= 1 - tol, fail_msg @@ -339,45 +364,61 @@ def softmax(data): return F.softmax(data, dim=-1) -class EmbedModelInfo(NamedTuple): +@dataclass +class ModelInfo: name: str - is_matryoshka: bool = False - matryoshka_dimensions: Optional[list[int]] = None architecture: str = "" dtype: str = "auto" + hf_dtype: str = "float32" + hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" enable_test: bool = True +@dataclass +class EmbedModelInfo(ModelInfo): + mteb_score: Optional[float] = None + is_matryoshka: bool = False + matryoshka_dimensions: Optional[list[int]] = None + + +@dataclass class CLSPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingEmbedModelInfo(EmbedModelInfo): default_pooling_type: str = "LAST" -class RerankModelInfo(NamedTuple): - name: str - architecture: str = "" - dtype: str = "auto" - default_pooling_type: str = "" - enable_test: bool = True +@dataclass +class RerankModelInfo(ModelInfo): + mteb_score: Optional[float] = None +@dataclass class CLSPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "CLS" +@dataclass class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" +@dataclass +class GenerateModelInfo(ModelInfo): + hf_dtype: str = "auto" + hf_ppl: Optional[float] = None + + def dummy_hf_overrides( hf_config: PretrainedConfig, *, model_arch: str = "", exist_overrides: Optional[dict[str, Any]] = None, + use_original_num_layers: bool = False, ) -> PretrainedConfig: """ Dummy HF overrides function used to create dummy model @@ -389,57 +430,89 @@ def dummy_hf_overrides( # Ensure at least 2 expert per group # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) + n_group = getattr(text_config, "n_group", None) num_experts = n_group * 2 if n_group is not None else 2 # we use three layers for Gemma-3n to check # both normal layer and kv_shared_layer - num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration" - else 1) - text_config.update({ - "num_layers": 1, - "num_hidden_layers": num_hidden_layers, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, + if use_original_num_layers: + # Use the original number of layers from the config + num_layers = getattr(text_config, "num_layers", 1) + num_hidden_layers = getattr(text_config, "num_hidden_layers", 1) + else: + # Use minimal layers for testing + num_layers = 1 + num_hidden_layers = 3 if model_arch == "Gemma3nForConditionalGeneration" else 1 + + update_dict = { + "num_layers": num_layers, # For Gemma-3n "num_kv_shared_layers": 1, - }) + } + + class DummyConfig: + hf_text_config = text_config + + # Only set MoE related config when the model has MoE layers. + # Otherwise all models detected as MoE by _get_transformers_backend_cls. 
+ if ModelConfig.get_num_experts(DummyConfig) > 0: + update_dict.update( + { + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + } + ) + + # Update num_hidden_layers for non-Longcat architectures + if model_arch != "LongcatFlashForCausalLM" and model_arch != "LongCatFlashMTPModel": + update_dict["num_hidden_layers"] = num_hidden_layers + + text_config.update(update_dict) if hasattr(hf_config, "vision_config"): - hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) + hf_config.vision_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + } + ) # e.g.: ibm-granite/granite-speech-3.3-2b if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) + hf_config.encoder_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + } + ) # e.g.: Qwen/Qwen2-Audio-7B-Instruct if hasattr(hf_config, "audio_config"): - hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) + hf_config.audio_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + "encoder_layers": 1, + } + ) return hf_config -def check_transformers_version(model: str, - min_transformers_version: Optional[str] = None, - max_transformers_version: Optional[str] = None): +def check_transformers_version( + model: str, + min_transformers_version: Optional[str] = None, + max_transformers_version: Optional[str] = None, +): from .registry import _HfExamplesInfo - return _HfExamplesInfo(model, - min_transformers_version=min_transformers_version, - max_transformers_version=max_transformers_version - ).check_transformers_version(on_fail="skip") + return _HfExamplesInfo( + model, + min_transformers_version=min_transformers_version, + max_transformers_version=max_transformers_version, + ).check_transformers_version(on_fail="skip") diff --git a/tests/mq_llm_engine/conftest.py b/tests/mq_llm_engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/mq_llm_engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py deleted file mode 100644 index 5ff08cbb32487..0000000000000 --- a/tests/mq_llm_engine/test_abort.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that aborting is handled properly.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" -EXPECTED_TOKENS = 250 - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_id_to_be_aborted = "request-aborted" - request_ids_a = [f"request-a-{idx}" for idx in range(10)] - request_ids_b = [f"request-b-{idx}" for idx in range(10)] - - # Requests started before one to be aborted. - tasks = [] - for request_id in request_ids_a: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Aborted. - task_aborted = asyncio.create_task( - generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) - - # Requests started after one to be aborted. - for request_id in request_ids_b: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Actually abort. - await asyncio.sleep(0.5) - await client.abort(request_id_to_be_aborted) - - # Confirm that we got all the EXPECTED tokens from the requests. - for task in tasks: - count, request_id = await task - assert count == EXPECTED_TOKENS, ( - f"{request_id} generated only {count} tokens") - - # Cancel task (this will hang indefinitely if not). - task_aborted.cancel() - - # Shutdown. 
- client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py deleted file mode 100644 index 77e3732cd06c6..0000000000000 --- a/tests/mq_llm_engine/test_error_handling.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that various errors are handled properly.""" - -import asyncio -import tempfile -import time -import uuid -from unittest.mock import Mock - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.engine.multiprocessing import MQEngineDeadError -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroupMetadata -from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.execute_model = Mock( - side_effect=RAISED_ERROR(RAISED_VALUE)) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_evil_forward(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_forward) as engine: - - client = await engine.make_client() - - # Server should be healthy after initial probe. - await asyncio.sleep(2.0) - await client.check_health() - - # Throws an error that should get ENGINE_DEAD_ERROR. - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - assert client.errored - - await asyncio.sleep(1.0) - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Shutdown. - client.close() - - -def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, - ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_health_check(tmp_socket): - with RemoteMQLLMEngine( - engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_model_executor_health) as engine: - - client = await engine.make_client() - assert client.is_running - - # Health probe should throw RAISED_ERROR. - await asyncio.sleep(15.) 
- - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Generate call should throw ENGINE_DEAD_ERROR - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - client.close() - - -def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during abort call. - engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Trigger an abort on the client side. - # This request ID does not exist, and will cause the engine to error - await client.abort(request_id="foo") - - # Future generation requests will now fail - # with reference to the original KeyError("foo") - with pytest.raises(MQEngineDeadError) as execinfo: - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - assert "KeyError" in repr(execinfo.value) - assert client.errored - - # This should raise the original error. - with pytest.raises(RAISED_ERROR): - await client.check_health() - - client.close() - - -@pytest.mark.asyncio -async def test_batch_error(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Batch of requests - async def do_generate(client): - # min_tokens=2048 to keep busy the engine busy - # to get enough time to get process a request - # that will crash the engine - params = SamplingParams(min_tokens=2048, max_tokens=2048) - async for _ in client.generate(prompt="Hello my name is", - sampling_params=params, - request_id=str(uuid.uuid4())): - pass - - tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] - - # This request will force a processing batch to raise - # an exception and next the engine get errored - await client.abort(request_id="foo") - - # The batch of those request failed, then they - # should get the same exception as a MQEngineDeadError. - errors = await asyncio.gather(*tasks, return_exceptions=True) - for e in errors: - assert isinstance(e, MQEngineDeadError) - assert "KeyError" in repr(e) - - client.close() - - -@pytest.mark.asyncio -async def test_bad_request(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - # Invalid request should fail, but not crash the server. - with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-1", - lora_request=LoRARequest( - "invalid-lora", 1, - "invalid-path")): - pass - - # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-2"): - pass - - # Shutdown. 
- client.close() - - -@pytest.mark.asyncio -async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - # When LLMEngine is loaded, it will crash. - def mock_init(): - raise ValueError - - m.setattr(LLMEngine, "__init__", mock_init) - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 100, ( - "Expected vLLM to gracefully shutdown in <100s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass - - -@pytest.mark.asyncio -async def test_engine_process_death(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - assert client.is_running - - # kill the engine process - engine.proc.kill() - - # Generate call should fail - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - # And the health check should show the engine is dead - with pytest.raises(RuntimeError, match="Engine process .* died"): - await client.check_health() - - client.close() - - -def run_with_evil_input_processing(engine_args: AsyncEngineArgs, - ipc_path: str): - """Simulate an exception while preparing inputs for the model. - In the wild, this could be something like a multimodal input processor - failing on invalid image data.""" - - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - runner = engine.engine.model_executor.driver_worker.worker.model_runner - - # Raise error in the model runner when adding a sequence group. - # See class ModelInputForGPUBuilder - def raiser(_, seq_group_metadata: SequenceGroupMetadata): - if seq_group_metadata.request_id.startswith("evil"): - raise RAISED_ERROR(RAISED_VALUE) - - runner.builder.per_seq_group_compute_fns.append(raiser) - - # Run engine. 
- engine.start() - - -@pytest.mark.asyncio -async def test_failed_inputs(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_input_processing) as engine: - - client = await engine.make_client() - assert client.is_running - - # Engine should be healthy - await client.check_health() - - async def run_failing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id="evil" + str(uuid.uuid4())): - pass - - async def run_passing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - - passing_tasks = [ - asyncio.create_task(run_passing_request()) for _ in range(10) - ] - failing_tasks = [ - asyncio.create_task(run_failing_request()) for _ in range(10) - ] - await asyncio.gather(*failing_tasks, return_exceptions=True) - await asyncio.gather(*passing_tasks) - - # All the bad inputs should have raised - for task in failing_tasks: - with pytest.raises(RAISED_ERROR): - task.result() - - # But all good inputs should have still succeeded - for task in passing_tasks: - task.result() - - # And the engine should remain healthy - assert not client.errored - await client.check_health() - - client.close() diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py deleted file mode 100644 index c934706611ae3..0000000000000 --- a/tests/mq_llm_engine/test_load.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -NUM_EXPECTED_TOKENS = 10 -NUM_REQUESTS = 10000 - -# Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_load(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] - - # Create concurrent requests. - tasks = [] - for request_id in request_ids: - tasks.append( - asyncio.create_task( - generate(client, request_id, NUM_EXPECTED_TOKENS))) - - # Confirm that we got all the EXPECTED tokens from the requests. - failed_request_id = None - tokens = None - for task in tasks: - num_generated_tokens, request_id = await task - if (num_generated_tokens != NUM_EXPECTED_TOKENS - and failed_request_id is None): - failed_request_id = request_id - tokens = num_generated_tokens - - assert failed_request_id is None, ( - f"{failed_request_id} generated {tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") - - # Shutdown. 
- client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py deleted file mode 100644 index 7976d5031aea1..0000000000000 --- a/tests/mq_llm_engine/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import multiprocessing -from typing import Callable, Union - -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.outputs import RequestOutput -from vllm.usage.usage_lib import UsageContext - - -async def generate( - client: MQLLMEngineClient, - request_id: str, - num_tokens: int, - return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]: - - final_output = None - count = 0 - async for out in client.generate( - request_id=request_id, - prompt="Hello my name is Robert and", - sampling_params=SamplingParams(max_tokens=num_tokens, - temperature=0)): - - count += 1 - final_output = out - await asyncio.sleep(0.) - - if return_output: - return final_output - - # Confirm we generated all the tokens we expected. - return count, request_id - - -def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Run engine. - engine.start() - - -class RemoteMQLLMEngine: - - def __init__(self, - engine_args: AsyncEngineArgs, - ipc_path: str, - run_fn: Callable = run_normal) -> None: - - self.engine_args = engine_args - self.ipc_path = ipc_path - context = multiprocessing.get_context("spawn") - self.proc = context.Process(target=run_fn, - args=(engine_args, ipc_path)) - self.proc.start() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.proc.kill() - - async def make_client(self) -> MQLLMEngineClient: - engine_config = self.engine_args.create_engine_config() - client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid) - while True: - try: - await client.setup() - break - except TimeoutError: - assert self.proc.is_alive() - return client diff --git a/tests/multimodal/test_audio.py b/tests/multimodal/test_audio.py new file mode 100644 index 0000000000000..189b319e5fcde --- /dev/null +++ b/tests/multimodal/test_audio.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# test_audio.py +import base64 +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest + +from vllm.multimodal.audio import ( + AudioMediaIO, + AudioResampler, + resample_audio_librosa, + resample_audio_scipy, +) + + +@pytest.fixture +def dummy_audio(): + return np.array([0.0, 0.1, 0.2, 0.3, 0.4], dtype=float) + + +def test_resample_audio_librosa(dummy_audio): + with patch("vllm.multimodal.audio.librosa.resample") as mock_resample: + mock_resample.return_value = dummy_audio * 2 + out = resample_audio_librosa(dummy_audio, orig_sr=44100, target_sr=22050) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) + assert np.all(out == dummy_audio * 2) + + +def test_resample_audio_scipy(dummy_audio): + out_down = resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=2) + out_up = resample_audio_scipy(dummy_audio, orig_sr=2, target_sr=4) + out_same = 
resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=4) + + assert len(out_down) == 3 + assert len(out_up) == 10 + assert np.all(out_same == dummy_audio) + + +@pytest.mark.xfail(reason="resample_audio_scipy is buggy for non-integer ratios") +def test_resample_audio_scipy_non_integer_ratio(dummy_audio): + out = resample_audio_scipy(dummy_audio, orig_sr=5, target_sr=3) + + expected_len = int(round(len(dummy_audio) * 3 / 5)) + assert len(out) == expected_len + + assert isinstance(out, np.ndarray) + assert np.isfinite(out).all() + + +def test_audio_resampler_librosa_calls_resample(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="librosa") + with patch("vllm.multimodal.audio.resample_audio_librosa") as mock_resample: + mock_resample.return_value = dummy_audio + out = resampler.resample(dummy_audio, orig_sr=44100) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) + assert np.all(out == dummy_audio) + + +def test_audio_resampler_scipy_calls_resample(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="scipy") + with patch("vllm.multimodal.audio.resample_audio_scipy") as mock_resample: + mock_resample.return_value = dummy_audio + out = resampler.resample(dummy_audio, orig_sr=44100) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) + assert np.all(out == dummy_audio) + + +def test_audio_resampler_invalid_method(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="invalid") + with pytest.raises(ValueError): + resampler.resample(dummy_audio, orig_sr=44100) + + +def test_audio_resampler_no_target_sr(dummy_audio): + resampler = AudioResampler(target_sr=None) + with pytest.raises(RuntimeError): + resampler.resample(dummy_audio, orig_sr=44100) + + +@pytest.fixture +def dummy_audio_bytes(): + return b"FAKEAUDIOBYTES" + + +def test_audio_media_io_load_bytes(dummy_audio_bytes): + audio_io = AudioMediaIO() + with patch("vllm.multimodal.audio.librosa.load") as mock_load: + mock_load.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_bytes(dummy_audio_bytes) + mock_load.assert_called_once() + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_load_base64(dummy_audio_bytes): + audio_io = AudioMediaIO() + encoded = base64.b64encode(dummy_audio_bytes).decode("utf-8") + with patch.object(AudioMediaIO, "load_bytes") as mock_load_bytes: + mock_load_bytes.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_base64("audio/wav", encoded) + mock_load_bytes.assert_called_once() + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_load_file(): + audio_io = AudioMediaIO() + path = Path("/fake/path.wav") + with patch("vllm.multimodal.audio.librosa.load") as mock_load: + mock_load.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_file(path) + mock_load.assert_called_once_with(path, sr=None) + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_encode_base64(dummy_audio): + audio_io = AudioMediaIO() + media = (dummy_audio, 16000) + with patch("vllm.multimodal.audio.soundfile.write") as mock_write: + + def write_to_buffer(buffer, *_args, **_kwargs): + buffer.write(b"dummy_wav_data") + + mock_write.side_effect = write_to_buffer + + out = audio_io.encode_base64(media) + decoded = base64.b64decode(out) + assert decoded == b"dummy_wav_data" + mock_write.assert_called_once() diff --git a/tests/multimodal/test_cache.py 
b/tests/multimodal/test_cache.py index 088cd00db2e04..fe983990b90c8 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,52 +1,228 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np import pytest import torch -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField) +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import ( + MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + engine_receiver_cache_from_config, + processor_cache_from_config, +) +from vllm.multimodal.hasher import MultiModalHasher +from vllm.multimodal.inputs import ( + MultiModalFieldElem, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, +) +from vllm.multimodal.processing import PromptInsertion + +pytestmark = pytest.mark.cpu_test -def _dummy_elem(modality: str, key: str, size: int): +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size,), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8)) + return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=data, field=MultiModalSharedField(1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() - ]) +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): + return MultiModalKwargsItem.from_elems( + [_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()] + ) -def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargsItems.from_seq([ - _dummy_item(modality, size_by_key) - for modality, size_by_key in size_by_key_modality.items() - ]) +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): + return MultiModalKwargsItems.from_seq( + [ + _dummy_item(modality, size_by_key, rng=rng) + for modality, size_by_key in size_by_key_modality.items() + ] + ) -# yapf: disable @pytest.mark.parametrize( ("item", "expected_size"), [ (_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100, "a2": 110}), 210), (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501 + ( + _dummy_items( + {"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}} + ).get_data(), + 460, + ), # noqa: E501 ], ) -# yapf: enable def test_cache_item_size(item, expected_size): cache = MultiModalCache.get_lru_cache(2048, type(item)) cache[""] = item assert cache.currsize == expected_size - cache[""] = MultiModalCacheItemMetadata.wraps(item) + prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) + assert 
cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig( + model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + mm_processor_cache_gb=mm_processor_cache_gb, + ), + parallel_config=ParallelConfig(data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY) + cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY) + cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY) + cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY) + + cache_size_gb = max( + config_0.model_config.multimodal_config.mm_processor_cache_gb, + config_1.model_config.multimodal_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item + for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item + for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out + else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + 
_compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 75a233c2567cb..29064f2737834 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -10,6 +10,8 @@ from PIL import Image, ImageDraw from vllm.multimodal.hasher import MultiModalHasher +pytestmark = pytest.mark.cpu_test + ASSETS_DIR = Path(__file__).parent / "assets" assert ASSETS_DIR.exists() @@ -45,10 +47,11 @@ def test_hash_collision_image_transpose(): assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2) -def test_hash_collision_tensor_shape(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_hash_collision_tensor_shape(dtype): # The hash should be different though the data is the same when flattened - arr1 = torch.zeros((5, 10, 20, 3)) - arr2 = torch.zeros((10, 20, 5, 3)) + arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype) + arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype) hasher = MultiModalHasher assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2) @@ -87,8 +90,6 @@ def test_hash_image_exif_id(): hasher = MultiModalHasher # first image has UUID in ImageID, so it should hash to that UUID - assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs( - image=id.bytes) + assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(image=id.bytes) # second image has non-UUID in ImageID, so it should hash to the image data - assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs( - image=image2a) + assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(image=image2a) diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py index 271a85f1195ec..329a5b0494cb6 100644 --- a/tests/multimodal/test_image.py +++ b/tests/multimodal/test_image.py @@ -8,6 +8,8 @@ from PIL import Image, ImageChops from vllm.multimodal.image import ImageMediaIO, convert_image_mode +pytestmark = pytest.mark.cpu_test + ASSETS_DIR = Path(__file__).parent / "assets" assert ASSETS_DIR.exists() @@ -41,8 +43,7 @@ def test_rgba_to_rgb(): def test_rgba_to_rgb_custom_background(tmp_path): """Test RGBA to RGB conversion with custom background colors.""" # Create a simple RGBA image with transparent and opaque pixels - rgba_image = Image.new("RGBA", (10, 10), - (255, 0, 0, 255)) # Red with full opacity + rgba_image = Image.new("RGBA", (10, 10), (255, 0, 0, 255)) # Red with full opacity # Make top-left quadrant transparent for i in range(5): @@ -92,7 +93,7 @@ def test_rgba_to_rgb_custom_background(tmp_path): assert blue_numpy[0][0][2] == 255 # B # Test 4: Test with load_bytes method - with open(test_image_path, 'rb') as f: + with open(test_image_path, "rb") as f: image_data = f.read() image_io_green = ImageMediaIO(rgba_background_color=(0, 255, 0)) @@ -109,39 +110,47 @@ def test_rgba_background_color_validation(): """Test that invalid rgba_background_color values are properly rejected.""" # Test invalid types - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color="255,255,255") - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=255) # Test wrong number of 
elements - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, 255)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, 255, 255, 255)) # Test non-integer values - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255.0, 255.0, 255.0)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, "255", 255)) # Test out of range values - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(256, 255, 255)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, -1, 255)) # Test that valid values work diff --git a/tests/multimodal/test_inputs.py b/tests/multimodal/test_inputs.py index ffb3a6fe86b46..88e92bee3a292 100644 --- a/tests/multimodal/test_inputs.py +++ b/tests/multimodal/test_inputs.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest import torch from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +pytestmark = pytest.mark.cpu_test -def assert_nested_tensors_equal(expected: NestedTensors, - actual: NestedTensors): + +def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): assert type(expected) == type(actual) # noqa: E721 if isinstance(expected, torch.Tensor): assert torch.equal(expected, actual) @@ -16,8 +18,9 @@ def assert_nested_tensors_equal(expected: NestedTensors, assert_nested_tensors_equal(expected_item, actual_item) -def assert_multimodal_inputs_equal(expected: MultiModalKwargs, - actual: MultiModalKwargs): +def assert_multimodal_inputs_equal( + expected: MultiModalKwargs, actual: MultiModalKwargs +): assert set(expected.keys()) == set(actual.keys()) for key in expected: assert_nested_tensors_equal(expected[key], actual[key]) @@ -49,19 +52,10 @@ def test_multimodal_input_batch_nested_tensors(): a = torch.rand([2, 3]) b = torch.rand([2, 3]) c = torch.rand([2, 3]) - result = MultiModalKwargs.batch([{ - "image": [a] - }, { - "image": [b] - }, { - "image": [c] - }]) - assert_multimodal_inputs_equal(result, { - "image": - torch.stack([a.unsqueeze(0), - b.unsqueeze(0), - c.unsqueeze(0)]) - }) + result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}]) + assert_multimodal_inputs_equal( + result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])} + ) def test_multimodal_input_batch_heterogeneous_lists(): @@ -70,8 +64,8 @@ def test_multimodal_input_batch_heterogeneous_lists(): c = torch.rand([1, 2, 3]) result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal( - 
result, - {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) + result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]} + ) def test_multimodal_input_batch_multiple_batchable_lists(): @@ -81,9 +75,8 @@ def test_multimodal_input_batch_multiple_batchable_lists(): d = torch.rand([1, 2, 3]) result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}]) assert_multimodal_inputs_equal( - result, - {"image": torch.stack([torch.stack([a, b]), - torch.stack([c, d])])}) + result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])} + ) def test_multimodal_input_batch_mixed_stacking_depths(): diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index cb489c47fd8fd..a542b068a42b6 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -8,27 +8,27 @@ import numpy as np import pytest from vllm.config import ModelConfig -from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - PromptIndexTargets, PromptInsertion, - PromptReplacement, apply_text_matches, - apply_token_matches, - find_mm_placeholders, - find_text_matches, find_token_matches, - iter_token_matches, - replace_token_matches) -# yapf: enable +from vllm.multimodal.processing import ( + InputProcessingContext, + PlaceholderFeaturesInfo, + PromptIndexTargets, + PromptInsertion, + PromptReplacement, + apply_text_matches, + apply_token_matches, + find_mm_placeholders, + iter_token_matches, + replace_token_matches, +) from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import full_groupby from .utils import random_image +pytestmark = pytest.mark.cpu_test + -# yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), [ @@ -38,34 +38,34 @@ from .utils import random_image [32000, 32000, 32000], [32000], [ - { "start_idx": 0, "end_idx": 1 }, - { "start_idx": 1, "end_idx": 2 }, - { "start_idx": 2, "end_idx": 3 }, + {"start_idx": 0, "end_idx": 1}, + {"start_idx": 1, "end_idx": 2}, + {"start_idx": 2, "end_idx": 3}, ], ), ( [32000, 32000, 32000], [32000, 32000], - [{ "start_idx": 0, "end_idx": 2 }], + [{"start_idx": 0, "end_idx": 2}], ), ( [32000, 32000, 32000], [32000, 32000, 32000], - [{ "start_idx": 0, "end_idx": 3 }], + [{"start_idx": 0, "end_idx": 3}], ), ( [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], [28747, 32000], [ - { "start_idx": 1, "end_idx": 3 }, - { "start_idx": 6, "end_idx": 8 }, + {"start_idx": 1, "end_idx": 3}, + {"start_idx": 6, "end_idx": 8}, ], ), ( [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], [28747, 32000, 32000, 32000], [ - { "start_idx": 1, "end_idx": 5 }, + {"start_idx": 1, "end_idx": 5}, ], ), ( @@ -75,12 +75,14 @@ from .utils import random_image ), ], ) -# yapf: enable -def test_iter_token_matches(token_ids, match_ids, expected): - result = list(iter_token_matches(token_ids, match_ids)) +@pytest.mark.parametrize("start_idx", [0, 4, 8]) +def test_iter_token_matches(token_ids, match_ids, expected, start_idx): + result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result] == expected + assert [item._asdict() for item in result] == [ + item for item in expected if item["start_idx"] >= start_idx + ] # Invariants match_lens = [end - start for 
start, end in result] @@ -88,7 +90,6 @@ def test_iter_token_matches(token_ids, match_ids, expected): assert all(match_len == len(match_ids) for match_len in match_lens) -# yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "new_ids", "expected"), [ @@ -132,7 +133,6 @@ def test_iter_token_matches(token_ids, match_ids, expected): ), ], ) -# yapf: enable def test_replace_token_matches(token_ids, match_ids, new_ids, expected): result = replace_token_matches(token_ids, match_ids, new_ids) @@ -140,7 +140,6 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): assert result == expected -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "expected_by_key"), [ @@ -157,11 +156,11 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): "pattern_1": [], "pattern_2": [], "pattern_3": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_4": [], "pattern_5": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], }, ), @@ -177,26 +176,26 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 1 }, - { "start_idx": 1, "end_idx": 2 }, - { "start_idx": 2, "end_idx": 3 }, - { "start_idx": 3, "end_idx": 4 }, + {"start_idx": 0, "end_idx": 1}, + {"start_idx": 1, "end_idx": 2}, + {"start_idx": 2, "end_idx": 3}, + {"start_idx": 3, "end_idx": 4}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 2 }, - { "start_idx": 2, "end_idx": 4 }, + {"start_idx": 0, "end_idx": 2}, + {"start_idx": 2, "end_idx": 4}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 3 }, + {"start_idx": 0, "end_idx": 3}, ], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 1, "end_idx": 1 }, + {"start_idx": 1, "end_idx": 1}, ], "pattern_6": [ - { "start_idx": 4, "end_idx": 4 }, + {"start_idx": 4, "end_idx": 4}, ], }, ), @@ -212,26 +211,25 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): }, { "pattern_1": [ - { "start_idx": 1, "end_idx": 3 }, - { "start_idx": 6, "end_idx": 8 }, + {"start_idx": 1, "end_idx": 3}, + {"start_idx": 6, "end_idx": 8}, ], "pattern_2": [ - { "start_idx": 1, "end_idx": 5 }, + {"start_idx": 1, "end_idx": 5}, ], "pattern_3": [], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [], "pattern_6": [ - { "start_idx": 10, "end_idx": 10 }, + {"start_idx": 10, "end_idx": 10}, ], }, ), ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# yapf: enable def test_find_token_matches( prompt, target_by_key, @@ -241,27 +239,28 @@ def test_find_token_matches( # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_token_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_token_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == 
expected_by_key -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "expected_by_key"), [ @@ -277,16 +276,16 @@ def test_find_token_matches( "pattern_5": PromptIndexTargets.end(), }, { - "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_1": [{"start_idx": 0, "end_idx": 0}], "pattern_2": [], "pattern_3": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_4": [], "pattern_5": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], - } + }, ), ( "<image><image><image><image>", @@ -300,26 +299,26 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 7 }, - { "start_idx": 7, "end_idx": 14 }, - { "start_idx": 14, "end_idx": 21 }, - { "start_idx": 21, "end_idx": 28 }, + {"start_idx": 0, "end_idx": 7}, + {"start_idx": 7, "end_idx": 14}, + {"start_idx": 14, "end_idx": 21}, + {"start_idx": 21, "end_idx": 28}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 14 }, - { "start_idx": 14, "end_idx": 28 }, + {"start_idx": 0, "end_idx": 14}, + {"start_idx": 14, "end_idx": 28}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 21 }, + {"start_idx": 0, "end_idx": 21}, ], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 7, "end_idx": 7 }, + {"start_idx": 7, "end_idx": 7}, ], "pattern_6": [ - { "start_idx": 28, "end_idx": 28 }, + {"start_idx": 28, "end_idx": 28}, ], }, ), @@ -335,21 +334,21 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 13 }, - { "start_idx": 27, "end_idx": 40 }, + {"start_idx": 0, "end_idx": 13}, + {"start_idx": 27, "end_idx": 40}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 27 }, + {"start_idx": 0, "end_idx": 27}, ], "pattern_3": [], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 13, "end_idx": 13 }, + {"start_idx": 13, "end_idx": 13}, ], "pattern_6": [ - { "start_idx": 48, "end_idx": 48 }, + {"start_idx": 48, "end_idx": 48}, ], }, ), @@ -363,22 +362,21 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 9 }, - { "start_idx": 16, "end_idx": 25 }, + {"start_idx": 0, "end_idx": 9}, + {"start_idx": 16, "end_idx": 25}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 16 }, - { "start_idx": 16, "end_idx": 32 }, + {"start_idx": 0, "end_idx": 16}, + {"start_idx": 16, "end_idx": 32}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 25 }, + {"start_idx": 0, "end_idx": 25}, ], }, ), ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# yapf: enable def test_find_text_matches( prompt, target_by_key, @@ -388,27 +386,28 @@ def test_find_text_matches( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_text_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_text_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == 
expected_by_key -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ @@ -536,9 +535,8 @@ def test_find_text_matches( }, }, ), - ] + ], ) -# yapf: enable def test_find_update_text( prompt, target_by_key, @@ -549,42 +547,39 @@ def test_find_update_text( mock_tokenizer = cast(AnyTokenizer, object()) for ( - update_type, - expected_by_mm_count, + update_type, + expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_text_matches( + mm_prompt_updates = { + key: [ + [update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count) + ] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_text_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ - # Tokenized test cases of `test_find_replace_text` + # Tokenized test cases of `test_find_update_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -606,8 +601,43 @@ def test_find_update_text( { PromptInsertion: { 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], - 1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501 - 2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501 + 1: [ + 1, + 9833, + 28747, + 32000, + 32000, + 32000, + 9833, + 28747, + 32000, + 32000, + 918, + 1550, + 918, + 1550, + ], # noqa: E501 + 2: [ + 1, + 9833, + 28747, + 32000, + 32000, + 32000, + 32000, + 32000, + 9833, + 28747, + 32000, + 32000, + 918, + 1550, + 918, + 1550, + 1550, + 918, + 1550, + ], # noqa: E501 }, PromptReplacement: { 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -710,9 +740,8 @@ def test_find_update_text( }, }, ), - ] + ], ) -# yapf: enable def test_find_update_tokens( prompt, target_by_key, @@ -723,38 +752,35 @@ def test_find_update_tokens( mock_tokenizer = cast(AnyTokenizer, object()) for ( - update_type, - expected_by_mm_count, + update_type, + expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_token_matches( + mm_prompt_updates = { + key: [ + [update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count) + ] + for key, target in target_by_key.items() + } + + new_prompt, result = 
apply_token_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected -# yapf: disable @pytest.mark.parametrize( "repl_by_key", [ @@ -791,8 +817,7 @@ def test_find_update_tokens( is_embed=None, ), ], - } - + }, ), ( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], @@ -823,7 +848,7 @@ def test_find_update_tokens( ), ], # No match for pattern_4 as it has lower priority than pattern_1 - } + }, ), ( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], @@ -862,12 +887,11 @@ def test_find_update_tokens( is_embed=None, ), ], - } + }, ), - ] + ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# yapf: enable def test_find_mm_placeholders( repl_by_key, prompt, @@ -878,17 +902,11 @@ def test_find_mm_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) mm_prompt_updates = { - key: [update_type(key, [], repl).bind(mock_tokenizer)] + key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] for key, repl in repl_by_key.items() } - result = find_mm_placeholders( - mm_prompt_updates, - prompt, - # Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - ) + result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) # Only displayed on error print("result:", result) @@ -900,8 +918,15 @@ def test_find_mm_placeholders( @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), - [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), - (2, 1, False), (2, 2, True)], + [ + (0, 0, True), + (0, 1, True), + (1, 0, False), + (1, 1, True), + (1, 2, True), + (2, 1, False), + (2, 2, True), + ], ) def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): limit_mm_per_prompt = {"image": limit} @@ -916,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): profiler = MultiModalProfiler(processor) - if is_valid: - exc_ctx = nullcontext() - else: - exc_ctx = pytest.raises(ValueError, match="At most") + exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") with exc_ctx: profiler.get_decoder_dummy_data( @@ -931,8 +953,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("num_images", "limit", "is_valid"), - [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), - (2, 1, False), (2, 2, True)], + [ + (0, 0, True), + (0, 1, True), + (1, 0, False), + (1, 1, True), + (1, 2, True), + (2, 1, False), + (2, 2, True), + ], ) def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): limit_mm_per_prompt = {"image": limit} @@ -953,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): else: mm_data = {"image": [image] * num_images} - if is_valid: - exc_ctx = nullcontext() - else: - exc_ctx = pytest.raises(ValueError, match="At most") + exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") with exc_ctx: processor.apply( @@ -967,7 +993,6 @@ def 
test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): class DummyProcessor: - def __init__(self, a: int = 0, b: int = 0) -> None: super().__init__() @@ -983,7 +1008,6 @@ class DummyProcessor: return dict(a=a, c=c) -# yapf: disable @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy @pytest.mark.parametrize( ("config_kwargs", "inference_kwargs", "expected_kwargs"), @@ -997,7 +1021,6 @@ class DummyProcessor: ({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}), ], ) -# yapf: enable def test_hf_processor_init_kwargs( model_id, config_kwargs, @@ -1021,7 +1044,6 @@ def test_hf_processor_init_kwargs( assert getattr(processor, k) == v -# yapf: disable @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy @pytest.mark.parametrize( ("config_kwargs", "inference_kwargs", "expected_kwargs"), @@ -1035,7 +1057,6 @@ def test_hf_processor_init_kwargs( ({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}), ], ) -# yapf: enable def test_hf_processor_call_kwargs( model_id, config_kwargs, diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py index d31e75bc279f6..3b01bda7f54c8 100644 --- a/tests/multimodal/test_registry.py +++ b/tests/multimodal/test_registry.py @@ -11,28 +11,24 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from ..models.utils import build_model_context +pytestmark = pytest.mark.cpu_test + @pytest.mark.parametrize( "model_id,limit_mm_per_prompt,expected", [ ("Qwen/Qwen2-0.5B-Instruct", {}, False), ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True), - ("Qwen/Qwen2.5-VL-3B-Instruct", { - "image": 0, - "video": 0 - }, False), - ("Qwen/Qwen2.5-VL-3B-Instruct", { - "image": 0 - }, True), + ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0, "video": 0}, False), + ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0}, True), ], ) @pytest.mark.core_model def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected): - """Test supports_multimodal_inputs returns correct boolean for various + """Test supports_multimodal_inputs returns correct boolean for various configs.""" ctx = build_model_context( model_id, limit_mm_per_prompt=limit_mm_per_prompt, ) - assert MULTIMODAL_REGISTRY.supports_multimodal_inputs( - ctx.model_config) is expected \ No newline at end of file + assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(ctx.model_config) is expected diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index a028c668c8ab7..ea795fcbbde55 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,40 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -import math import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, NamedTuple import numpy as np import pytest -import torch -import torch.multiprocessing as mp from PIL import Image, ImageChops -from tests.utils import multi_gpu_test -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange -from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, - get_load_balance_assignment, - run_dp_sharded_mrope_vision_model, - run_dp_sharded_vision_model) -from vllm.platforms import current_platform -from vllm.utils import get_open_port, update_environment_variables - -if TYPE_CHECKING: - from vllm.multimodal.inputs 
import MultiModalPlaceholderDict +from vllm.multimodal.utils import MediaConnector, argsort_mm_positions # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] TEST_VIDEO_URLS = [ @@ -45,21 +29,19 @@ TEST_VIDEO_URLS = [ @pytest.fixture(scope="module") -def url_images() -> dict[str, Image.Image]: - connector = MediaConnector() - +def url_images(local_asset_server) -> dict[str, Image.Image]: return { - image_url: connector.fetch_image(image_url) - for image_url in TEST_IMAGE_URLS + image_url: local_asset_server.get_image_asset(image_url) + for image_url in TEST_IMAGE_ASSETS } def get_supported_suffixes() -> tuple[str, ...]: # We should at least test the file types mentioned in GPT-4 with Vision - OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif') + OPENAI_SUPPORTED_SUFFIXES = (".png", ".jpeg", ".jpg", ".webp", ".gif") # Additional file types that are supported by us - EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff') + EXTRA_SUPPORTED_SUFFIXES = (".bmp", ".tiff") return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES @@ -69,7 +51,7 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_http(image_url: str): connector = MediaConnector() @@ -79,12 +61,19 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) -async def test_fetch_image_base64(url_images: dict[str, Image.Image], - image_url: str, suffix: str): - connector = MediaConnector() - url_image = url_images[image_url] +async def test_fetch_image_base64( + url_images: dict[str, Image.Image], raw_image_url: str, suffix: str +): + connector = MediaConnector( + # Domain restriction should not apply to data URLs. 
+ allowed_media_domains=[ + "www.bogotobogo.com", + "github.com", + ] + ) + url_image = url_images[raw_image_url] try: mime_type = Image.MIME[Image.registered_extensions()[suffix]] @@ -92,14 +81,14 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image], try: mime_type = mimetypes.types_map[suffix] except KeyError: - pytest.skip('No MIME type') + pytest.skip("No MIME type") with NamedTemporaryFile(suffix=suffix) as f: try: url_image.save(f.name) except Exception as e: - if e.args[0] == 'cannot write mode RGBA as JPEG': - pytest.skip('Conversion not supported') + if e.args[0] == "cannot write mode RGBA as JPEG": + pytest.skip("Conversion not supported") raise @@ -117,7 +106,7 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image], @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_local_files(image_url: str): connector = MediaConnector() @@ -125,35 +114,41 @@ async def test_fetch_image_local_files(image_url: str): local_connector = MediaConnector(allowed_local_media_path=temp_dir) origin_image = connector.fetch_image(image_url) - origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), - quality=100, - icc_profile=origin_image.info.get('icc_profile')) + origin_image.save( + os.path.join(temp_dir, os.path.basename(image_url)), + quality=100, + icc_profile=origin_image.info.get("icc_profile"), + ) image_async = await local_connector.fetch_image_async( - f"file://{temp_dir}/{os.path.basename(image_url)}") + f"file://{temp_dir}/{os.path.basename(image_url)}" + ) image_sync = local_connector.fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}") + f"file://{temp_dir}/{os.path.basename(image_url)}" + ) # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() with pytest.raises(ValueError, match="must be a subpath"): await local_connector.fetch_image_async( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(RuntimeError, match="Cannot load local files"): await connector.fetch_image_async( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(ValueError, match="must be a subpath"): local_connector.fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(RuntimeError, match="Cannot load local files"): - connector.fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + connector.fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") @pytest.mark.asyncio -async def test_fetch_image_local_files_with_space_in_name(): - image_url = TEST_IMAGE_URLS[0] +@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True) +async def test_fetch_image_local_files_with_space_in_name(image_url: str): connector = MediaConnector() with TemporaryDirectory() as temp_dir: @@ -161,18 +156,19 @@ async def test_fetch_image_local_files_with_space_in_name(): origin_image = connector.fetch_image(image_url) filename = "file name with space.jpg" - origin_image.save(os.path.join(temp_dir, filename), - quality=100, - icc_profile=origin_image.info.get('icc_profile')) + origin_image.save( + os.path.join(temp_dir, filename), + quality=100, + icc_profile=origin_image.info.get("icc_profile"), + ) try: image_async = 
await local_connector.fetch_image_async( - f"file://{temp_dir}/{filename}") - image_sync = local_connector.fetch_image( - f"file://{temp_dir}/{filename}") + f"file://{temp_dir}/{filename}" + ) + image_sync = local_connector.fetch_image(f"file://{temp_dir}/{filename}") except FileNotFoundError as e: - pytest.fail( - "Failed to fetch image with space in name: {}".format(e)) + pytest.fail("Failed to fetch image with space in name: {}".format(e)) # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() @@ -195,9 +191,12 @@ async def test_fetch_image_error_conversion(): @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) async def test_fetch_video_http(video_url: str, num_frames: int): connector = MediaConnector( - media_io_kwargs={"video": { - "num_frames": num_frames, - }}) + media_io_kwargs={ + "video": { + "num_frames": num_frames, + } + } + ) video_sync, metadata_sync = connector.fetch_video(video_url) video_async, metadata_async = await connector.fetch_video_async(video_url) @@ -205,18 +204,41 @@ async def test_fetch_video_http(video_url: str, num_frames: int): assert metadata_sync == metadata_async -# Used for `test_argsort_mm_positions`. -class TestCase(NamedTuple): - mm_positions: "MultiModalPlaceholderDict" - expected_modality_idxs: list[tuple[str, int]] +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("max_duration", [1, 60, 1800]) +@pytest.mark.parametrize("requested_fps", [2, 24]) +async def test_fetch_video_http_with_dynamic_loader( + video_url: str, + max_duration: int, + requested_fps: int, + monkeypatch: pytest.MonkeyPatch, +): + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic") + connector = MediaConnector( + media_io_kwargs={ + "video": { + "max_duration": max_duration, + "requested_fps": requested_fps, + } + } + ) + + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async(video_url) + + assert np.array_equal(video_sync, video_async) + assert metadata_sync == metadata_async + assert metadata_sync["video_backend"] == "opencv_dynamic" -def test_argsort_mm_positions(): - - test_cases = [ +@pytest.mark.parametrize( + "case", + [ # Single modality ## Internally sorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=0, length=2), @@ -229,7 +251,7 @@ def test_argsort_mm_positions(): ], ), ## Internally unsorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=3, length=2), @@ -241,10 +263,9 @@ def test_argsort_mm_positions(): ("image", 0), ], ), - # Two modalities ## Internally sorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=7, length=4), @@ -253,7 +274,7 @@ def test_argsort_mm_positions(): "audio": [ PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=2, length=3), - ] + ], }, expected_modality_idxs=[ ("audio", 0), @@ -263,7 +284,7 @@ def test_argsort_mm_positions(): ], ), ## Interleaved, internally sorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=0, length=4), @@ -272,7 +293,7 @@ def test_argsort_mm_positions(): "audio": [ PlaceholderRange(offset=5, length=2), PlaceholderRange(offset=11, length=4), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -282,7 +303,7 @@ def test_argsort_mm_positions(): ], ), ## Interleaved, internally unsorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=8, length=2), @@ -291,7 +312,7 @@ def 
test_argsort_mm_positions(): "audio": [ PlaceholderRange(offset=11, length=4), PlaceholderRange(offset=5, length=2), - ] + ], }, expected_modality_idxs=[ ("image", 1), @@ -300,10 +321,9 @@ def test_argsort_mm_positions(): ("audio", 0), ], ), - # Three modalities ## Internally sorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=15, length=7), @@ -316,7 +336,7 @@ def test_argsort_mm_positions(): PlaceholderRange(offset=3, length=4), PlaceholderRange(offset=7, length=5), PlaceholderRange(offset=12, length=6), - ] + ], }, expected_modality_idxs=[ ("audio", 0), @@ -328,7 +348,7 @@ def test_argsort_mm_positions(): ], ), ## Interleaved, internally sorted - TestCase( + dict( mm_positions={ "image": [ PlaceholderRange(offset=0, length=2), @@ -340,7 +360,7 @@ def test_argsort_mm_positions(): ], "video": [ PlaceholderRange(offset=8, length=5), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -350,8 +370,8 @@ def test_argsort_mm_positions(): ("image", 2), ], ), - ## Interleaved, internally sunorted - TestCase( + ## Interleaved, internally unsorted + dict( mm_positions={ "image": [ PlaceholderRange(offset=0, length=2), @@ -363,7 +383,7 @@ def test_argsort_mm_positions(): ], "video": [ PlaceholderRange(offset=8, length=5), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -373,415 +393,41 @@ def test_argsort_mm_positions(): ("image", 1), ], ), - ] - - for mm_positions, expected_modality_idxs in test_cases: - modality_idxs = argsort_mm_positions(mm_positions) - - assert modality_idxs == expected_modality_idxs - - -class SimpleLinearModel(torch.nn.Module): - """A simple linear vision model for testing.""" - - def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32): - super().__init__() - self.flatten = torch.nn.Flatten() - self.linear = torch.nn.Linear(input_dim, output_dim) - - def forward(self, x: torch.Tensor): - # Flatten the input and apply linear transformation - x = self.flatten(x) - return self.linear(x) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 4, # Small batch - 5, # Odd batch size (for testing padding) ], ) -def test_run_dp_sharded_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, +def test_argsort_mm_positions(case): + mm_positions = case["mm_positions"] + expected_modality_idxs = case["expected_modality_idxs"] + + modality_idxs = argsort_mm_positions(mm_positions) + + assert modality_idxs == expected_modality_idxs + + +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) +async def test_allowed_media_domains(video_url: str, num_frames: int): + connector = MediaConnector( + media_io_kwargs={ + "video": { + "num_frames": num_frames, + } + }, + allowed_media_domains=[ + "www.bogotobogo.com", + "github.com", + ], ) + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async(video_url) + assert np.array_equal(video_sync, video_async) + assert metadata_sync == metadata_async -def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, - batch_size: int, master_port: int): - """ - Test that run_dp_sharded_vision_model produces the same results as - calling the model directly. 
- """ + disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png" + with pytest.raises(ValueError): + _, _ = connector.fetch_video(disallowed_url) - # Set random seed for reproducibility - current_platform.seed_everything(0) - - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create a test input tensor - image_input = torch.randn(batch_size, 3, 224, 224) - - # Create a simple linear model - vision_model = SimpleLinearModel() - - # Run the model directly on the full input - with torch.inference_mode(): - direct_output = vision_model(image_input) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_vision_model(image_input, vision_model) - - # Check that the world size is setup correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize( - "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts," - "expected_grouped_sizes_per_gpu,test_description", - [ - # Empty input - ([], 2, [], [0, 0], [0, 0], "empty input"), - - # Fewer samples than GPUs - ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 - ], "fewer samples than GPUs"), - - # Single GPU - ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), - - # Balanced assignment - ([100, 100, 100, 100 - ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), - - # Unbalanced sizes - this one is trickier since the algorithm is greedy - ([1000, 100, 200, 50], 2, [0, 2, 1, 3 - ], [1, 3], [1000, 350], "unbalanced sizes"), - ], -) -def test_get_load_balance_assignment_cases(sizes, num_gpus, - expected_shuffle_indices, - expected_gpu_sample_counts, - expected_grouped_sizes_per_gpu, - test_description): - """Test get_load_balance_assignment with various input cases.""" - result = get_load_balance_assignment(sizes, num_gpus=num_gpus) - (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result - - # Common assertions for all cases - assert len(shuffle_indices) == len(sizes) - assert len(gpu_sample_counts) == num_gpus - assert len(grouped_sizes_per_gpu) == num_gpus - assert sum(gpu_sample_counts) == len(sizes) - - assert shuffle_indices == expected_shuffle_indices - - assert gpu_sample_counts == expected_gpu_sample_counts - assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu - - -class SimpleMRopeVisionModel(torch.nn.Module): - """A simple vision model for testing mrope functionality.""" - - def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): - super().__init__() - self.spatial_merge_size = spatial_merge_size - self.out_hidden_size = out_hidden_size - self.linear = torch.nn.Linear(768, out_hidden_size) - - def forward(self, pixel_values: torch.Tensor, - grid_thw_list: list[list[int]]): - """Simple forward pass that simulates spatial merging.""" - # Apply linear transformation - 
embeddings = self.linear(pixel_values) - - # Simulate spatial merging by reducing the number of patches - merge_factor = self.spatial_merge_size * self.spatial_merge_size - - # Group patches and merge spatially - merged_embeddings = [] - start_idx = 0 - - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - end_idx = start_idx + num_patches - - # Get patches for this image - image_patches = embeddings[start_idx:end_idx] - - # Simulate spatial merging by averaging groups of patches - merged_patches = num_patches // merge_factor - if merged_patches > 0: - # Reshape and average to simulate merging - reshaped = image_patches[:merged_patches * merge_factor].view( - merged_patches, merge_factor, -1) - merged = reshaped.mean(dim=1) - merged_embeddings.append(merged) - - start_idx = end_idx - - if merged_embeddings: - return torch.cat(merged_embeddings, dim=0) - else: - return torch.empty((0, self.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize( - "batch_size", - [ - 1, # Single image - 3, # Small batch - 5, # Odd batch size (for testing padding) - ], -) -def test_run_dp_sharded_mrope_vision_model(batch_size: int): - world_size = 2 - # Launch processes - mp.spawn( - run_dp_sharded_mrope_vision_model_vs_direct, - args=( - world_size, - batch_size, - get_open_port(), - ), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, - world_size: int, - batch_size: int, - master_port: int): - """ - Test that run_dp_sharded_mrope_vision_model produces the same results as - calling the model directly. - """ - # Set random seed for reproducibility - current_platform.seed_everything(0) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - # initialize distributed - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create test data - grid_thw_list = [] - pixel_values_list = [] - - for i in range(batch_size): - # Varying image sizes for better testing - t, h, w = 1, 4 + i, 4 + i - grid_thw_list.append([t, h, w]) - - num_patches = t * h * w - # Create random pixel values for this image - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - # Concatenate all pixel values - pixel_values = torch.cat(pixel_values_list, dim=0) - - # Create a simple mrope vision model - vision_model = SimpleMRopeVisionModel() - - # Run the model directly on the full input (only on rank 0) - if local_rank == 0: - with torch.inference_mode(): - direct_output = vision_model(pixel_values, grid_thw_list) - - # Run the model through the sharded function - with torch.inference_mode(): - sharded_output = run_dp_sharded_mrope_vision_model( - vision_model, pixel_values, grid_thw_list) - sharded_output = torch.cat(sharded_output, dim=0) - - # Check that the world size is setup correctly - assert get_tensor_model_parallel_world_size() == world_size - - # Compare outputs (only on rank 0) - if local_rank == 0: - # Check that the outputs have the same shape - assert direct_output.shape == sharded_output.shape - # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, - sharded_output, - rtol=1e-5, - 
atol=1e-5) - - -@multi_gpu_test(num_gpus=2) -def test_run_dp_sharded_mrope_vision_model_empty_input(): - world_size = 2 - mp.spawn( - run_dp_sharded_mrope_vision_model_empty_input_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_empty_input_worker( - local_rank: int, world_size: int, master_port: int): - """Test run_dp_sharded_mrope_vision_model with empty input.""" - # Set up distributed environment - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create empty inputs - pixel_values = torch.empty((0, 768)) - grid_thw_list: list[list[int]] = [] - - vision_model = SimpleMRopeVisionModel() - - # Should handle empty input gracefully - with torch.inference_mode(): - output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values, - grid_thw_list) - - assert len(output) == 0 - - -@multi_gpu_test(num_gpus=4) -def test_run_dp_sharded_mrope_vision_model_uneven_load(): - world_size = 4 - mp.spawn( - run_dp_sharded_mrope_vision_model_uneven_load_worker, - args=(world_size, get_open_port()), - nprocs=world_size, - ) - - -def run_dp_sharded_mrope_vision_model_uneven_load_worker( - local_rank: int, world_size: int, master_port: int): - """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" - # Set up distributed environment - current_platform.seed_everything(123) - device = f"{current_platform.device_name}:{local_rank}" - current_platform.set_device(device) - torch.set_default_device(device) - - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) - - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) - - # Create images with very different sizes - grid_thw_list = [ - [1, 2, 2], # Small: 4 patches - [1, 8, 8], # Large: 64 patches - [1, 3, 3], # Medium: 9 patches - ] - - pixel_values_list = [] - for grid_thw in grid_thw_list: - num_patches = math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel() - - # Should handle uneven distribution without errors - with torch.inference_mode(): - output_tuple = run_dp_sharded_mrope_vision_model( - vision_model, pixel_values, grid_thw_list) - - # Verify output shape is reasonable - merge_factor = vision_model.spatial_merge_size**2 - expected_output_patches = list( - math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) - - for i, output in enumerate(output_tuple): - assert output.shape[0] == expected_output_patches[i] - assert output.shape[1] == vision_model.out_hidden_size - - -@pytest.mark.parametrize("spatial_merge_size", [2, 4]) -def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): - """Test SimpleMRopeVisionModel with different spatial merge sizes.""" - device = current_platform.device_type - - grid_thw_list = [[1, 4, 4], [1, 6, 6]] # Two images - pixel_values_list = [] - - for grid_thw in grid_thw_list: - num_patches = 
math.prod(grid_thw) - image_pixels = torch.randn(num_patches, 768, device=device) - pixel_values_list.append(image_pixels) - - pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel( - spatial_merge_size=spatial_merge_size).to(device) - - with torch.inference_mode(): - output = vision_model(pixel_values, grid_thw_list) - - # Verify output dimensions based on spatial merging - total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list) - merge_factor = spatial_merge_size**2 - expected_output_patches = total_patches // merge_factor - - assert output.shape[0] == expected_output_patches - assert output.shape[1] == vision_model.out_hidden_size + with pytest.raises(ValueError): + _, _ = await connector.fetch_video_async(disallowed_url) diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py index 05b7b84be7f34..6572616769a91 100644 --- a/tests/multimodal/test_video.py +++ b/tests/multimodal/test_video.py @@ -12,11 +12,12 @@ from PIL import Image from vllm.assets.base import get_vllm_public_assets from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list from vllm.multimodal.image import ImageMediaIO -from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader, - VideoMediaIO) +from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader, VideoMediaIO from .utils import cosine_similarity, create_video_from_image, normalize_image +pytestmark = pytest.mark.cpu_test + NUM_FRAMES = 10 FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3) FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3) @@ -24,7 +25,6 @@ FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3) @VIDEO_LOADER_REGISTRY.register("test_video_loader_1") class TestVideoLoader1(VideoLoader): - @classmethod def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: return FAKE_OUTPUT_1 @@ -32,7 +32,6 @@ class TestVideoLoader1(VideoLoader): @VIDEO_LOADER_REGISTRY.register("test_video_loader_2") class TestVideoLoader2(VideoLoader): - @classmethod def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: return FAKE_OUTPUT_2 @@ -55,13 +54,10 @@ def test_video_loader_type_doesnt_exist(): @VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps") class Assert10Frames1FPSVideoLoader(VideoLoader): - @classmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - fps: float = -1.0, - **kwargs) -> npt.NDArray: + def load_bytes( + cls, data: bytes, num_frames: int = -1, fps: float = -1.0, **kwargs + ) -> npt.NDArray: assert num_frames == 10, "bad num_frames" assert fps == 1.0, "bad fps" return FAKE_OUTPUT_2 @@ -77,11 +73,8 @@ def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch): _ = videoio.load_bytes(b"test") videoio = VideoMediaIO( - imageio, **{ - "num_frames": 10, - "fps": 1.0, - "not_used": "not_used" - }) + imageio, **{"num_frames": 10, "fps": 1.0, "not_used": "not_used"} + ) _ = videoio.load_bytes(b"test") with pytest.raises(AssertionError, match="bad num_frames"): @@ -104,8 +97,9 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str): Test all functions that use OpenCV for video I/O return RGB format. Both RGB and grayscale videos are tested. 
""" - image_path = get_vllm_public_assets(filename="stop_sign.jpg", - s3_prefix="vision_model_images") + image_path = get_vllm_public_assets( + filename="stop_sign.jpg", s3_prefix="vision_model_images" + ) image = Image.open(image_path) with tempfile.TemporaryDirectory() as tmpdir: if not is_color: @@ -125,21 +119,24 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str): frames = video_to_ndarrays(video_path) for frame in frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 pil_frames = video_to_pil_images_list(video_path) for frame in pil_frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path)) for frame in io_frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 9a58292f9f4a5..485bde939f690 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -8,7 +8,7 @@ from PIL import Image def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int): - w, h = rng.randint(min_wh, max_wh, size=(2, )) + w, h = rng.randint(min_wh, max_wh, size=(2,)) arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) return Image.fromarray(arr) @@ -21,7 +21,7 @@ def random_video( max_wh: int, ): num_frames = rng.randint(min_frames, max_frames) - w, h = rng.randint(min_wh, max_wh, size=(2, )) + w, h = rng.randint(min_wh, max_wh, size=(2,)) return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) @@ -66,14 +66,13 @@ def create_video_from_image( return video_path -def cosine_similarity(A: npt.NDArray, - B: npt.NDArray, - axis: int = -1) -> npt.NDArray: +def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray: """Compute cosine similarity between two vectors.""" - return (np.sum(A * B, axis=axis) / - (np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis))) + return np.sum(A * B, axis=axis) / ( + np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis) + ) def normalize_image(image: npt.NDArray) -> npt.NDArray: """Normalize image to [0, 1] range.""" - return image.astype(np.float32) / 255.0 \ No newline at end of file + return image.astype(np.float32) / 255.0 diff --git a/tests/neuron/1_core/test_activation.py b/tests/neuron/1_core/test_activation.py deleted file mode 100644 index 2d6e5f523cb85..0000000000000 --- a/tests/neuron/1_core/test_activation.py +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.activation import FastGELU, SiluAndMul -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) -@pytest.mark.parametrize("num_tokens,d,dtype", [ - (7, 512, 
torch.half), - (7, 512, torch.float), - (83, 512, torch.half), -]) -@torch.inference_mode() -def test_act_and_mul( - activation: str, - num_tokens: int, - d: int, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) - if activation == "silu_and_mul": - layer = SiluAndMul() - fn = layer.forward_native - elif activation == "gelu_fast": - layer = FastGELU() - fn = F.gelu - else: - raise NotImplementedError( - f"activation {activation} is not implemented.") - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." - out = layer.to(device=device).forward_neuron(x) - ref_out = fn(x.cpu()) - torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) diff --git a/tests/neuron/1_core/test_block_table.py b/tests/neuron/1_core/test_block_table.py deleted file mode 100644 index efec56360c142..0000000000000 --- a/tests/neuron/1_core/test_block_table.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.language as nl -import pytest -import torch -import torch.nn.functional as F -from neuronxcc import nki - -from vllm.attention.ops.nki_flash_attn import ( - load_block_tables, transform_block_tables_for_indirect_load) - - -def is_power_of_2(n): - return n > 0 and (n & (n - 1) == 0) - - -def nki_load_and_transform_block_tables( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" - block_tables_sbuf = load_block_tables(block_tables, num_tiles, - num_blocks_per_tile) - - # we need to pass an Index as head_id - head_id = nl.arange(1)[None, :] + head_id - - block_tables_transposed = transform_block_tables_for_indirect_load( - block_tables_sbuf, block_size_tiling_factor, num_head, head_id) - B_P_SIZE = 128 - assert block_tables_transposed.shape[1] == B_P_SIZE - - out = nl.ndarray( - block_tables_transposed.shape, - dtype=nl.int32, - buffer=nl.shared_hbm, - ) - for i in nl.affine_range(block_tables_transposed.shape[0]): - nl.store(dst=out[i], value=block_tables_transposed[i]) - return out - - -def ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert block_tables.numel() == num_tiles * num_blocks_per_tile - block_tables = block_tables.view(num_tiles, num_blocks_per_tile) - B_F_SIZE = 128 - num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE - block_tables = F.pad( - block_tables, - (0, 0, 0, num_tiles_padded - num_tiles), - "constant", - 0, - ) - - block_tables = block_tables * num_head + head_id - block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) - offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) - block_tables = block_tables * block_size_tiling_factor + offset - block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() - - num_blocks_per_tile = block_tables_transposed.shape[0] - assert num_blocks_per_tile % B_F_SIZE == 0 - return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, - B_F_SIZE, num_tiles_padded) - - -@pytest.mark.parametrize( - "q_head_per_kv_head,head_id", - [ - (1, 0), - (3, 1), - ], -) -@pytest.mark.parametrize( - "num_tiles,num_blocks_per_tile", - [ - 
(1, 1), - (13, 16), - (17, 128), - (35, 512), - (128, 128), - (130, 64), - (280, 256), - (315, 1), - ], -) -@torch.inference_mode() -def test_load_and_transform_block_tables( - monkeypatch: pytest.MonkeyPatch, - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(10000) - torch.set_printoptions(sci_mode=False) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - B_P_SIZE = 128 - if num_blocks_per_tile < B_P_SIZE: - assert B_P_SIZE % num_blocks_per_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile - else: - block_size_tiling_factor = 1 - max_num_blocks = 100000 - block_tables = torch.randint( - 0, - max_num_blocks, - (num_tiles * num_blocks_per_tile, ), - dtype=torch.int32, - ) - nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( - block_tables.to(device=device), - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ).cpu() - ref_out = ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ) - assert (nki_out.shape == ref_out.shape - ), f"{nki_out.shape=} != {ref_out.shape=}" - assert torch.all(nki_out == ref_out) diff --git a/tests/neuron/1_core/test_cache.py b/tests/neuron/1_core/test_cache.py deleted file mode 100644 index 670889ad6b58d..0000000000000 --- a/tests/neuron/1_core/test_cache.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.attention.ops.nki_flash_attn import reshape_and_cache - - -@pytest.mark.parametrize( - "num_tokens, n_kv_head, d_head, num_blocks, block_size", - [ - # Small model configuration (e.g., GPT-2 small) - (32, 12, 64, 4, 128), # Typical sequence processing - (1, 12, 64, 4, 128), # Single token update - (128, 12, 64, 4, 128), # Longer sequence - - # Medium model configuration (e.g., GPT-2 medium) - (64, 16, 96, 8, 256), # Standard batch - (256, 16, 96, 8, 256), # Large batch - - # Large model configuration (e.g., GPT-3 style) - (48, 32, 128, 16, 512), # Typical processing window - (512, 32, 128, 16, 512), # Full context window - - # Edge cases and stress tests - (1024, 8, 32, 32, 32), # Many tokens, small heads - (16, 64, 256, 4, 64), # Few tokens, many heads - (2048, 24, 128, 64, 128), # Large scale test - - # Minimal configurations for debugging - (4, 2, 16, 2, 16), # Tiny test case - (1, 1, 8, 1, 8), # Minimal possible - ]) -def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks, - block_size): - # Set random seed for reproducibility - torch.manual_seed(42) - - # Create CPU tensors for reference implementation - key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens] - - # Run reference implementation on CPU - block_indices = torch.div(slot_mapping_cpu, - block_size, - rounding_mode="floor") - 
block_offsets = slot_mapping_cpu % block_size - - for i in range(num_tokens): - block_idx = block_indices[i] - block_offset = block_offsets[i] - key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i] - value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i] - - # Create XLA device tensors - device = torch.device('xla') - key = key_cpu.to(device) - value = value_cpu.to(device) - key_cache = torch.zeros_like(key_cache_cpu, device=device) - value_cache = torch.zeros_like(value_cache_cpu, device=device) - slot_mapping = slot_mapping_cpu.to(device) - kv_cache = torch.stack([key_cache, value_cache]) - - # Run vectorized implementation on XLA device - reshape_and_cache(key, value, kv_cache, slot_mapping) - key_cache, value_cache = torch.unbind(kv_cache, dim=0) - - # Move results back to CPU for comparison - key_cache_result = key_cache.cpu() - value_cache_result = value_cache.cpu() - - # Assert results match - torch.testing.assert_close(key_cache_result, - key_cache_cpu, - rtol=1e-5, - atol=1e-5) - torch.testing.assert_close(value_cache_result, - value_cache_cpu, - rtol=1e-5, - atol=1e-5) diff --git a/tests/neuron/1_core/test_layernorm.py b/tests/neuron/1_core/test_layernorm.py deleted file mode 100644 index c6fce1d1a0630..0000000000000 --- a/tests/neuron/1_core/test_layernorm.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ - (7, 8, False, torch.half), - (83, 768, False, torch.half), - (83, 768, True, torch.half), - (83, 768, True, torch.bfloat16), - (83, 768, True, torch.float32), -]) -@torch.inference_mode() -def test_rms_norm( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - layer = RMSNorm(hidden_size).to(dtype=dtype) - layer.weight.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) - x *= scale - residual = torch.randn_like(x) * scale if add_residual else None - - residual_cpu = residual.cpu() if add_residual else None - ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." - out = layer.to(device=device)(x, residual) - - # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger - # numerical errors than other operators because they involve reductions. - # Therefore, we use a larger tolerance. 
- if add_residual: - assert out[0].is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out[0].cpu(), - ref_out[0], - atol=1e-2, - rtol=1e-2) - torch.testing.assert_close(out[1].cpu(), - ref_out[1], - atol=1e-2, - rtol=1e-2) - else: - assert out.is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/neuron/1_core/test_logits_processor.py b/tests/neuron/1_core/test_logits_processor.py deleted file mode 100644 index ce9eadf5a883e..0000000000000 --- a/tests/neuron/1_core/test_logits_processor.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import patch - -import pytest -import torch - -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available - - -class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size, scale=scale) - self.fake_logits = fake_logits.clone() - - def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) - logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - return input_tensor, fake_logits, logits_processor - - -RANDOM_SEEDS = list(range(8)) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_logits_processors(seed: int): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - set_random_seed(seed) - torch.set_default_device("cpu") - batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - seq_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - - fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) diff --git a/tests/neuron/1_core/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py deleted file mode 100644 index 5f3268810f9fe..0000000000000 --- a/tests/neuron/1_core/test_neuron_model_runner.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from unittest.mock import MagicMock - -from vllm.config import VllmConfig -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sampling_params import SamplingParams -from vllm.sequence import SequenceData, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import NeuronModelRunner - -os.environ[ - 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value - - -def _create_neuron_model_runner(model: str, *args, - **kwargs) -> NeuronModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - vllm_config = VllmConfig( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - ) - neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) - return neuron_model_runner - - -def test_update_neuron_sampling_params_not_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: default sampling parameters - # Index 1: sequecne 0's sampling parameters. 
- neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [1.0, 0.5] - assert neuron_sampling_params.top_k == [ - model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 - ] - assert neuron_sampling_params.top_p == [1.0, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) - - -def test_update_neuron_sampling_params_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ), - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={1: SequenceData.from_seqs([4, 5, 6])}, - sampling_params=SamplingParams(temperature=0.2, - top_k=2, - top_p=0.2), - block_tables={1: [0]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: sequence 1's sampling parameters - # Index 1: sequecne 0's sampling parameters. - neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [0.2, 0.5] - assert neuron_sampling_params.top_k == [2, 1] - assert neuron_sampling_params.top_p == [0.2, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) diff --git a/tests/neuron/1_core/test_neuron_quant.py b/tests/neuron/1_core/test_neuron_quant.py deleted file mode 100644 index 0863002695928..0000000000000 --- a/tests/neuron/1_core/test_neuron_quant.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.layers.quantization.neuron_quant import ( - NeuronQuantConfig) - - -def test_get_supported_act_dtypes(): - neuron_quant_config = NeuronQuantConfig() - supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes() - target_list = ["any_dtype1", "any_dtype2"] - for dtype in target_list: - assert dtype in supported_act_dtypes diff --git a/tests/neuron/1_core/test_prefix_prefill.py b/tests/neuron/1_core/test_prefix_prefill.py deleted file mode 100644 index abf7febc2955c..0000000000000 --- a/tests/neuron/1_core/test_prefix_prefill.py +++ /dev/null @@ -1,514 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest -import torch -import torch.nn.functional as F - -from vllm.utils import cdiv - - -class BlockDiagonalCausalFromBottomRightMask: - - @staticmethod - def _from_seqlens(query_lens, seq_lens, block_size=None): - from torch import logical_and, logical_or - - contexted = block_size is None - context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) - n_queries = sum(query_lens) - 
num_seqs = len(query_lens) - if contexted: - key_lens_blockaligned = seq_lens - else: - n_blocks_per_seq = (context_lens + block_size - 1) // block_size - offset_per_seq = n_blocks_per_seq * block_size - key_lens_blockaligned = offset_per_seq[:num_seqs].tolist() - n_keys = sum(key_lens_blockaligned) - - a = (torch.arange(n_queries).reshape(n_queries, - 1).expand(n_queries, n_keys)) - b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys) - q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0) - k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0) - - prior_mask = torch.zeros(n_queries, n_keys) - new_masks: list[torch.Tensor] = [] - for seq_id in range(num_seqs): - ri = q_cumsum[seq_id] - ci = k_cumsum[seq_id] - nr = query_lens[seq_id] - - if contexted: - nc = seq_lens[seq_id] - a_offset = ci + nc - ri - nr - new_mask = (a + a_offset) >= b - else: - nc = context_lens[seq_id] - a_offset = ci + nc - 1 - new_mask = a_offset >= b - - left_mask = b >= ci - top_mask = a >= ri - bottom_mask = a < (ri + nr) - - new_mask = logical_and( - logical_and(logical_and(new_mask, left_mask), top_mask), - bottom_mask, - ) - prior_mask = logical_or(prior_mask, new_mask) - new_masks = new_masks + [new_mask] - return prior_mask - - @staticmethod - def from_seqlens(query_lens, seq_lens, block_size=None): - contexted = block_size is None - if contexted: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens) - active_mask = None - else: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens, block_size) - active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, query_lens) - return prior_mask, active_mask - - -def ref_softmax(x: torch.Tensor, - dim: int, - mixed_precision=False, - return_max_reduce=False): - max_value = torch.amax(x, dim=dim, keepdims=True) - exp = torch.exp(x - max_value) - if mixed_precision: - sum_value = torch.sum(exp.astype(torch.float32), - dim=dim, - keepdims=True).astype(x.dtype) - else: - sum_value = torch.sum(exp, dim=dim, keepdims=True) - if return_max_reduce: - return exp / sum_value, max_value, torch.reciprocal(sum_value) - return exp / sum_value - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, - return_max_reduce: Optional[bool] = False, -) -> torch.Tensor: - scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - masked_score = scaled_qk + attn_mask.float() - if return_max_reduce: - norm_score, cached_max, cached_sum_reciprocal = ref_softmax( - masked_score, dim=-1, return_max_reduce=True) - else: - norm_score = ref_softmax(masked_score, dim=-1) - out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value) - if return_max_reduce: - return ( - out, - cached_max, - cached_sum_reciprocal, - norm_score, - masked_score, - scaled_qk, - ) - else: - return (out, ) - - -def ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, -): - scale = float(1.0 / (head_size**0.5)) - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - - attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) - - # convert binary mask to -inf values - attn_mask = 
torch.logical_not(attn_mask) - attn_mask = attn_mask.float() * -30000 - - output, *debug_tensors = ref_masked_attention( - query, - key, - value, - scale, - attn_mask, - return_max_reduce=return_max_reduce, - ) - - output = output.unsqueeze(1) - if return_max_reduce: - cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( - debug_tensors) - return ( - output, - cached_max, - cached_sum_reciprocal, - lse, - masked_score, - scaled_qk, - ) - else: - return output - - -def sample_inputs( - prefill_batch_size, - decode_batch_size, - min_query_len, - max_query_len, - min_ctx_len, - max_ctx_len, - block_size, - num_heads, - num_kv_heads, - head_size, - dtype, -): - batch_size = prefill_batch_size + decode_batch_size - max_model_len = (max_query_len + max_ctx_len) * 4 - max_block_per_request = max_model_len // block_size - cache_size = (batch_size * max_block_per_request) + 2 - prefill_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (prefill_batch_size, ), - dtype=torch.long).tolist() - decode_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (decode_batch_size, ), - dtype=torch.long).tolist() - ctx_lens = prefill_ctx_lens + decode_ctx_lens - query_lens = torch.randint( - min_query_len, - max_query_len + 1, - (prefill_batch_size, ), - dtype=torch.long, - ).tolist() + [1 for _ in range(decode_batch_size)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1, 1) - torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1, 1) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:batch_size * max_block_per_request].view( - batch_size, max_block_per_request) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(batch_size): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - kv_cache = torch.stack([k_cache, v_cache]) - - return ( - query, - k, - v, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) - - -def get_active_block_tables(block_tables, query_lens, seq_lens, 
block_size, - num_blocks): - context_lens = seq_lens - query_lens - blocks_per_seq = (context_lens + block_size - 1) // block_size - num_seqs = len(seq_lens) - active_blocks: list[int] = [] - for seq_id in range(num_seqs): - active_blocks = ( - active_blocks + - block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) - return F.pad( - torch.tensor(active_blocks, dtype=torch.int32), - (0, num_blocks - len(active_blocks)), - "constant", - 0, - ) - - -@pytest.mark.parametrize( - "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision", - [ - # Test minimal configurations (small block size) - (1, 199, 1, 512, 4, 2, 8, False - ), # minimal block size, small dimensions - (1, 199, 1, 512, 4, 2, 8, True), # same with mixed precision - - # Test common/medium configurations - (4, 12, 32, 2048, 32, 8, 64, False), # common case, larger heads - (4, 12, 32, 2048, 16, 4, 32, - True), # medium size, mixed precision, grouped-query attention (GQA) - - # Test large configurations - (4, 12, 256, 8192, 8, 1, 128, False), # large blocks, large head size - (4, 12, 256, 8192, 64, 8, 64, True), # large blocks, many heads - - # Test asymmetric configurations - (2, 24, 64, 4096, 12, 4, 96, False), # varied batch sizes - (8, 8, 128, 2048, 24, 2, 48, True), # balanced batches - - # Test edge cases - (1, 128, 16, 1024, 4, 2, 16, False), # large decode batch - (16, 4, 8, 1024, 4, 2, 128, True), # large prefill batch - (4, 12, 32, 2048, 16, 1, 32, True), # multi-head attention (MHA) - (4, 12, 32, 2048, 16, 16, 32, True), # multi-query attention (MQA) - ]) -@torch.inference_mode() -def test_contexted_kv_attention( - monkeypatch: pytest.MonkeyPatch, - prefill_batch_size: int, - decode_batch_size: int, - num_heads: int, - num_queries_per_kv: int, - head_size: int, - block_size: int, - large_tile_size, - mixed_precision: bool, -) -> None: - - import torch_xla.core.xla_model as xm - - from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc, - reorder_context_mask) - - assert large_tile_size % block_size == 0 - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(0) - torch.set_printoptions(sci_mode=False) - torch.set_default_device("cpu") - dtype = torch.float32 - - min_ctx_len = 32 - max_ctx_len = 1024 - min_query_len = 16 - max_query_len = 512 - num_kv_heads = num_heads // num_queries_per_kv - ( - query, - k_active, - v_active, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) = sample_inputs( - prefill_batch_size=prefill_batch_size, - decode_batch_size=decode_batch_size, - min_query_len=min_query_len, - max_query_len=max_query_len, - min_ctx_len=min_ctx_len, - max_ctx_len=max_ctx_len, - block_size=block_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - ) - - output_ref = ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, - ) - - # build neuron program - B_P_SIZE = 128 - assert (large_tile_size >= B_P_SIZE - ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" - - def pad_to_multiple(a, b): - return cdiv(a, b) * b - - def pad_to_next_power_of_2(a): - assert a > 0 - return 2**int(a - 1).bit_length() - - # calculate input shapes - max_num_queries = pad_to_next_power_of_2(sum(query_lens)) - context_lens = 
torch.tensor(seq_lens) - torch.tensor(query_lens) - num_active_blocks = cdiv(context_lens, block_size).sum().item() - num_active_blocks = pad_to_multiple(num_active_blocks, - large_tile_size // block_size) - context_kv_len = num_active_blocks * block_size - assert ( - context_kv_len % - large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" - - # pad QKV tensors - pad_dims = ( - 0, - 0, - 0, - 0, - 0, - max_num_queries - query.shape[0], - ) - query = F.pad(query, pad_dims, "constant", 0) - k = F.pad(k_active, pad_dims, "constant", 0) - v = F.pad(v_active, pad_dims, "constant", 0) - - # permute QKV tensors - # query: (1, n_heads, d, seq_q) - # key: (1, n_kv_heads, d, seq_k) - # value: (1, n_kv_heads, seq_v, d) - query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() - kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous() - - # transform block table - active_block_table = get_active_block_tables( - block_table.cpu(), - torch.tensor(query_lens).cpu(), - torch.tensor(seq_lens).cpu(), - block_size, - num_active_blocks, - ) - - # Build attention masks - prior_mask, active_mask = ( - BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens, block_size=block_size)) - prior_mask_padded = F.pad( - prior_mask, - ( - 0, - context_kv_len - prior_mask.shape[1], - 0, - max_num_queries - prior_mask.shape[0], - ), - "constant", - 0, - ).bool() - active_mask_padded = F.pad( - active_mask, - ( - 0, - max_num_queries - active_mask.shape[1], - 0, - max_num_queries - active_mask.shape[0], - ), - "constant", - 0, - ).bool() - attn_mask = torch.concat([prior_mask_padded, active_mask_padded], - dim=1) - - attn_mask = reorder_context_mask(attn_mask, large_tile_size, - block_size) - - input_args = ( - query.to(device=device), - k.to(device=device), - v.to(device=device), - kv_cache.to(device=device), - active_block_table.to(device=device), - attn_mask.to(device=device), - ) - input_kwargs = dict( - n_kv_head=num_kv_heads, - head_size=head_size, - mixed_precision=mixed_precision, - LARGE_TILE_SZ=large_tile_size, - ) - - output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) - - num_actual_tokens = sum(query_lens) - # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.cpu().permute(0, 2, 1, 3) - output_nki = output_nki[0, :num_actual_tokens, :, :] - output_ref_padded = F.pad( - output_ref, - (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]), - "constant", - 0, - ) - output_ref = output_ref_padded.transpose( - 0, 1)[0, :num_actual_tokens, :, :] - - torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) diff --git a/tests/neuron/1_core/test_rotary_embedding.py b/tests/neuron/1_core/test_rotary_embedding.py deleted file mode 100644 index a7ac79729986d..0000000000000 --- a/tests/neuron/1_core/test_rotary_embedding.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests for miscellaneous utilities -""" - -import pytest -import torch - -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.platforms import current_platform - - -@pytest.mark.parametrize( - "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [ - (16, False, 32, 32, 1024, True), - (16, False, 32, 128, 1024, True), - (16, True, 32, 32, 1024, True), - (16, True, 32, 128, 1024, True), - (16, False, 32, 
128, 1024, False), - (16, True, 32, 128, 1024, False), - ]) -def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, - head_size, seq_len, use_key): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - - batch_size = 1 - base = 10000 - num_heads = 8 - - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) - - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device="cpu") - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=torch.float32, - device="cpu") - key = torch.randn_like(query) if use_key else None - assert positions.is_cpu, \ - "reference input tensor is expected to be CPU tensor." - ref_query, ref_key = rot.to(device="cpu").forward_native( - positions, query, key) - out_query, out_key = rot.to(device=device).forward_neuron( - positions.to(device=device), query.to(device=device), - key.to(device=device) if key is not None else None) - if use_key: - assert out_query.is_xla and out_key.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_key.cpu(), - ref_key, - atol=1e-2, - rtol=1e-2) - else: - assert out_key is None, "expected returned key to be None" - assert out_query.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_query.cpu(), - ref_query, - atol=1e-2, - rtol=1e-2) diff --git a/tests/neuron/2_core/test_comm_ops.py b/tests/neuron/2_core/test_comm_ops.py deleted file mode 100644 index 85a48dae58aaf..0000000000000 --- a/tests/neuron/2_core/test_comm_ops.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -from typing import Callable -from unittest.mock import patch - -import pytest -import torch -import torch_xla.distributed.xla_multiprocessing as xmp -from typing_extensions import ParamSpec - -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.utils import get_distributed_init_method, get_open_port - -_P = ParamSpec("_P") - - -def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to reinitialize the Neuron Runtime before executing a test. - This is necessary for distributed tests which need to reallocate Neuron - Cores to separate subprocesses. 
- """ - - @functools.wraps(f) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - runtime = torch.classes.neuron.Runtime() - runtime.initialize() - runtime.unsafe_close() - - f(*args, **kwargs) - runtime.initialize() - - return wrapper - - -def all_gather_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_dimensions = 3 - tensor_size = list(range(2, num_dimensions + 2)) - total_size = 1 - for s in tensor_size: - total_size *= s - - all_gather_dimension = -1 - all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="xla").reshape(tensor_size) * (r + 1) - for r in range(tp_degree) - ] - expected = torch.cat(all_tensors, dim=all_gather_dimension) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_gather(t, all_gather_dimension) - torch.testing.assert_close(t, expected) - - -def all_reduce_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_elements = 8 - all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1) - for r in range(tp_degree) - ] - expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_reduce(t) - torch.testing.assert_close(t, expected) - - -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", - [all_reduce_test_worker, all_gather_test_worker]) -@reinitialize_neuron_runtime -def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size, - test_target): - - with patch('torch_xla._XLAC._xla_runtime_is_initialized', - return_value=False): - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size)) - monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES", - ','.join(['1' for _ in range(tp_size)])) - - xmp.spawn(test_target, args=(tp_size, distributed_init_method)) diff --git a/tests/neuron/2_core/test_eagle.py b/tests/neuron/2_core/test_eagle.py deleted file mode 100644 index cac642af03101..0000000000000 --- a/tests/neuron/2_core/test_eagle.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -import os -import shutil -import tempfile - -import torch -from huggingface_hub import snapshot_download -from safetensors import safe_open - -from vllm import LLM, SamplingParams - - -def patch_eagle_draft_with_lm_head(target_model_id: str, - draft_model_id: str) -> str: - # In NxDI, draft model checkpoint must include lm_head weights from target - # model. 
For more details see https://awsdocs-neuron.readthedocs-hosted.com - # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html - # #eagle-checkpoint-compatibility - final_draft_dir = "/tmp/patched_eagle_draft" - - with tempfile.TemporaryDirectory() as tmp_dir: - target_dir = snapshot_download(repo_id=target_model_id, - local_dir=os.path.join( - tmp_dir, "target")) - draft_dir = snapshot_download(repo_id=draft_model_id, - local_dir=os.path.join(tmp_dir, "draft")) - - lm_head_key = "lm_head.weight" - index_path = os.path.join(target_dir, "model.safetensors.index.json") - with open(index_path) as f: - index = json.load(f) - shard_name = index["weight_map"][lm_head_key] - target_safetensor_path = os.path.join(target_dir, shard_name) - - with safe_open(target_safetensor_path, framework="pt") as f: - target_lm_head = f.get_tensor(lm_head_key) - - draft_path = os.path.join(draft_dir, "pytorch_model.bin") - draft_state_dict = torch.load(draft_path, map_location="cpu") - draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16) - torch.save(draft_state_dict, draft_path) - - shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True) - - return final_draft_dir - - -def test_eagle(): - patched_draft_path = patch_eagle_draft_with_lm_head( - target_model_id="meta-llama/Llama-2-7b-hf", - draft_model_id="yuhuili/EAGLE-llama2-chat-7B") - llm = LLM( - model="meta-llama/Llama-2-7b-hf", - speculative_config={ - "model": patched_draft_path, - "num_speculative_tokens": 5, - "max_model_len": 128 - }, - max_num_seqs=1, - max_model_len=128, - tensor_parallel_size=2, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - "fused_qkv": True - }, - ) - prompts = [ - "The president of the United States is", - ] - outputs = llm.generate(prompts, SamplingParams(top_k=1)) - expected_output = " the head of state and head of government of " \ - "the United States. The president direct" - - for output in outputs: - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Eagle speculation test passed.") diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py deleted file mode 100644 index ff59be1725b6c..0000000000000 --- a/tests/neuron/2_core/test_mistral.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - - -def test_mistral(): - llm = LLM(model="mistralai/Mistral-7B-v0.1", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=128, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True - }) - - # Send more prompts than the compiled batch size (4) and request - # varying generation lengths to test accuracy related to Neuron - # specific sequence id sorting. 
- prompts = [ - "The president of the United States is", - "The capital of France is", - "What is Annapurna labs?", - "I believe the meaning of life is", - "Tell me a story about a brave knight", - "Hello, my name is Llama", - ] - - sampling_params = [ - SamplingParams(top_k=1, max_tokens=10), - SamplingParams(top_k=1, max_tokens=20), - SamplingParams(top_k=1, max_tokens=30), - SamplingParams(top_k=1, max_tokens=40), - SamplingParams(top_k=1, max_tokens=50), - SamplingParams(top_k=1, max_tokens=60) - ] - - outputs = llm.generate(prompts, sampling_params) - - expected_outputs = [ - " the most powerful person in the world. He is", - " a city of many faces. It is a city of history, culture, art, " - "fashion, and", - "\n\nAnnapurna Labs is a semiconductor company that was founded " - "in 2013 by Amazon. The company is", - " to be happy.\n\nI believe that happiness is a choice.\n\nI " - "believe that happiness is a state of mind.\n\nI believe that " - "happiness is a journey.\n\nI believe", - " who rescued a princess from a dragon.\n\nTell me a story about" - " a princess who rescued herself from a dragon.\n\nTell me a " - "story about a princess who rescued herself from a dragon and " - "then rescued a knight from", - " and I am a 10 year old male. I am a very friendly and " - "affectionate boy who loves to be around people. I am a very " - "active boy who loves to play and run around. I am a very smart " - "boy who loves to learn new things. I am a very loyal boy" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Mistral test passed.") diff --git a/tests/neuron/2_core/test_multi_lora.py b/tests/neuron/2_core/test_multi_lora.py deleted file mode 100644 index 52ca9fe7b6667..0000000000000 --- a/tests/neuron/2_core/test_multi_lora.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from huggingface_hub import snapshot_download - -from vllm import LLM, SamplingParams -from vllm.lora.request import LoRARequest - - -def test_llama_single_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=1, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_1]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. 
The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) - - -def test_llama_multiple_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": - False, - "skip_warmup": - True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }, { - "name": "lora_id_2", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=2, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - lora_req_2 = LoRARequest("lora_id_2", 1, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_2]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) diff --git a/tests/plugins/lora_resolvers/test_filesystem_resolver.py b/tests/plugins/lora_resolvers/test_filesystem_resolver.py index 3e2c2577da66c..cd98efdd13909 100644 --- a/tests/plugins/lora_resolvers/test_filesystem_resolver.py +++ b/tests/plugins/lora_resolvers/test_filesystem_resolver.py @@ -13,11 +13,10 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora" PA_NAME = "swapnilbp/llama_tweet_ptune" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def adapter_cache(request, tmpdir_factory): # Create dir that mimics the structure of the adapter cache - adapter_cache = tmpdir_factory.mktemp( - request.module.__name__) / "adapter_cache" + adapter_cache = tmpdir_factory.mktemp(request.module.__name__) / "adapter_cache" return adapter_cache diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py new file mode 100644 index 0000000000000..4bbb79c98a82a --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +def register_prithvi(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py new file mode 100644 index 0000000000000..a2a8d0ec9aba4 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import base64 +import datetime +import os +import tempfile +import urllib.request +from collections.abc import Sequence +from typing 
import Any, Union + +import albumentations +import numpy as np +import rasterio +import regex as re +import torch +from einops import rearrange +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput +from vllm.plugins.io_processors.interface import ( + IOProcessor, + IOProcessorInput, + IOProcessorOutput, +) + +from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput + +logger = init_logger(__name__) + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +OFFSET = 0 +PERCENTILE = 99 + +DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5] + +datamodule_config: DataModuleConfig = { + "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], + "batch_size": 16, + "constant_scale": 0.0001, + "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": True, + "no_data_replace": 0.0, + "no_label_replace": -1, + "num_workers": 8, + "test_transform": [ + albumentations.Resize( + always_apply=False, height=448, interpolation=1, p=1, width=448 + ), + albumentations.pytorch.ToTensorV2( + transpose_mask=False, always_apply=True, p=1.0 + ), + ], +} + + +def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes: + """Save multi-band image in Geotiff file. + + Args: + image: np.ndarray with shape (bands, height, width) + meta: dict with meta info. + out_format: output format, either "path" or "b64_json". + """ + if out_format == "path": + # create temp file + file_path = os.path.join(os.getcwd(), "prediction.tiff") + with rasterio.open(file_path, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + return file_path + elif out_format == "b64_json": + with tempfile.NamedTemporaryFile() as tmpfile: + with rasterio.open(tmpfile.name, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + file_data = tmpfile.read() + return base64.b64encode(file_data) + + else: + raise ValueError("Unknown output format") + + +def _convert_np_uint8(float_image: torch.Tensor): + image = float_image.numpy() * 255.0 + image = image.astype(dtype=np.uint8) + + return image + + +def read_geotiff( + file_path: str | None = None, + path_type: str | None = None, + file_data: bytes | None = None, +) -> tuple[torch.Tensor, dict, tuple[float, float] | None]: + """Read all bands from *file_path* and return image + meta info. + + Args: + file_path: path to image file.
+ + Returns: + np.ndarray with shape (bands, height, width) + meta info dict + """ + + if all([x is None for x in [file_path, path_type, file_data]]): + raise Exception("All input fields to read_geotiff are None") + write_to_file: bytes | None = None + path: str | None = None + if file_data is not None: + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(file_data) + # path = tmpfile.name + + write_to_file = file_data + elif file_path is not None and path_type == "url": + resp = urllib.request.urlopen(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(resp.read()) + # path = tmpfile.name + write_to_file = resp.read() + elif file_path is not None and path_type == "path": + path = file_path + elif file_path is not None and path_type == "b64_json": + image_data = base64.b64decode(file_path) + # with tempfile.NamedTemporaryFile() as tmpfile: + # tmpfile.write(image_data) + # path = tmpfile.name + write_to_file = image_data + else: + raise Exception("Wrong combination of parameters to read_geotiff") + + with tempfile.NamedTemporaryFile() as tmpfile: + path_to_use = None + if write_to_file: + tmpfile.write(write_to_file) + path_to_use = tmpfile.name + elif path: + path_to_use = path + + with rasterio.open(path_to_use) as src: + img = src.read() + meta = src.meta + try: + coords = src.lnglat() + except Exception: + # Cannot read coords + coords = None + + return img, meta, coords + + +def load_image( + data: Union[list[str]], + path_type: str, + mean: list[float] | None = None, + std: list[float] | None = None, + indices: Union[list[int], None] | None = None, +): + """Build an input example by loading images in *file_paths*. + + Args: + file_paths: list of file paths . + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. 
+ + Returns: + np.array containing created example + list of meta info for each image in *file_paths* + """ + + imgs = [] + metas = [] + temporal_coords = [] + location_coords = [] + + for file in data: + # if isinstance(file, bytes): + # img, meta, coords = read_geotiff(file_data=file) + # else: + img, meta, coords = read_geotiff(file_path=file, path_type=path_type) + # Rescaling (don't normalize on nodata) + img = np.moveaxis(img, 0, -1) # channels last for rescaling + if indices is not None: + img = img[..., indices] + if mean is not None and std is not None: + img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) + + imgs.append(img) + metas.append(meta) + if coords is not None: + location_coords.append(coords) + + try: + match = re.search(r"(\d{7,8}T\d{6})", file) + if match: + year = int(match.group(1)[:4]) + julian_day = match.group(1).split("T")[0][4:] + if len(julian_day) == 3: + julian_day = int(julian_day) + else: + julian_day = ( + datetime.datetime.strptime(julian_day, "%m%d") + .timetuple() + .tm_yday + ) + temporal_coords.append([year, julian_day]) + except Exception: + logger.exception("Could not extract timestamp for %s", file) + + imgs = np.stack(imgs, axis=0) # num_frames, H, W, C + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W + imgs = np.expand_dims(imgs, axis=0) # add batch dim + + return imgs, temporal_coords, location_coords, metas + + +class PrithviMultimodalDataProcessor(IOProcessor): + indices = [0, 1, 2, 3, 4, 5] + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + self.datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config["data_root"], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform"], + ) + self.img_size = 512 + self.h1 = 1 + self.w1 = 1 + self.original_h = 512 + self.original_w = 512 + self.batch_size = 1 + self.meta_data = None + self.requests_cache: dict[str, dict[str, Any]] = {} + self.indices = DEFAULT_INPUT_INDICES + + def parse_request(self, request: Any) -> IOProcessorInput: + if type(request) is dict: + image_prompt = ImagePrompt(**request) + return image_prompt + if isinstance(request, IOProcessorRequest): + if not hasattr(request, "data"): + raise ValueError("missing 'data' field in OpenAIBaseModel Request") + + request_data = request.data + + if type(request_data) is dict: + return ImagePrompt(**request_data) + else: + raise ValueError("Unable to parse the request data") + + raise ValueError("Unable to parse request") + + def output_to_response( + self, plugin_output: IOProcessorOutput + ) -> IOProcessorResponse: + return IOProcessorResponse( + request_id=plugin_output.request_id, + data=plugin_output, + ) + + def pre_process( + self, + prompt: IOProcessorInput, + request_id: str | None = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + image_data = dict(prompt) + + if request_id: + self.requests_cache[request_id] = { + "out_format": image_data["out_data_format"], + } + + input_data, temporal_coords, location_coords, meta_data = load_image( + data=[image_data["data"]], + indices=self.indices, + path_type=image_data["data_format"], + ) + + self.meta_data = meta_data[0] + + if input_data.mean() > 1: + input_data = input_data / 10000 # Convert to range 0-1 + + self.original_h, self.original_w = input_data.shape[-2:] + pad_h = (self.img_size - (self.original_h %
self.img_size)) % self.img_size + pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size + input_data = np.pad( + input_data, + ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), + mode="reflect", + ) + + batch = torch.tensor(input_data) + windows = batch.unfold(3, self.img_size, self.img_size).unfold( + 4, self.img_size, self.img_size + ) + self.h1, self.w1 = windows.shape[3:5] + windows = rearrange( + windows, + "b c t h1 w1 h w -> (b h1 w1) c t h w", + h=self.img_size, + w=self.img_size, + ) + + # Split into batches if number of windows > batch_size + num_batches = ( + windows.shape[0] // self.batch_size + if windows.shape[0] > self.batch_size + else 1 + ) + windows = torch.tensor_split(windows, num_batches, dim=0) + + if temporal_coords: + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) + else: + temporal_coords = None + if location_coords: + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) + else: + location_coords = None + + prompts = [] + for window in windows: + # Apply standardization + window = self.datamodule.test_transform( + image=window.squeeze().numpy().transpose(1, 2, 0) + ) + window = self.datamodule.aug(window)["image"] + prompts.append( + { + "prompt_token_ids": [1], + "multi_modal_data": { + "pixel_values": window.to(torch.float16)[0], + "location_coords": location_coords.to(torch.float16), + }, + } + ) + + return prompts + + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: str | None = None, + **kwargs, + ) -> IOProcessorOutput: + pred_imgs_list = [] + + if request_id and (request_id in self.requests_cache): + out_format = self.requests_cache[request_id]["out_format"] + else: + out_format = "b64_json" + + for output in model_output: + y_hat = output.outputs.data.argmax(dim=1) + pred = torch.nn.functional.interpolate( + y_hat.unsqueeze(1).float(), + size=self.img_size, + mode="nearest", + ) + pred_imgs_list.append(pred) + + pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0) + + # Build images from patches + pred_imgs = rearrange( + pred_imgs, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h=self.img_size, + w=self.img_size, + b=1, + c=1, + h1=self.h1, + w1=self.w1, + ) + + # Cut padded area back to original size + pred_imgs = pred_imgs[..., : self.original_h, : self.original_w] + + # Squeeze (batch size 1) + pred_imgs = pred_imgs[0] + + if not self.meta_data: + raise ValueError("No metadata available for the current task") + self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) + out_data = save_geotiff( + _convert_np_uint8(pred_imgs), self.meta_data, out_format + ) + + return ImageRequestOutput( + type=out_format, format="tiff", data=out_data, request_id=request_id + ) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py new file mode 100644 index 0000000000000..21a5c3754c36f --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Literal, Optional, TypedDict, Union + +import albumentations +from pydantic import BaseModel + + +class DataModuleConfig(TypedDict): + bands: list[str] + batch_size: int + constant_scale: float + data_root: str + drop_last: bool + no_data_replace: float + no_label_replace: int + num_workers: int + test_transform: 
list[albumentations.core.transforms_interface.BasicTransform] + + +class ImagePrompt(BaseModel): + data_format: Literal["b64_json", "bytes", "url", "path"] + """ + This is the data type for the input image + """ + + image_format: str + """ + This is the image format (e.g., jpeg, png, etc.) + """ + + out_data_format: Literal["b64_json", "url"] + + data: Any + """ + Input image data + """ + + +MultiModalPromptType = Union[ImagePrompt] + + +class ImageRequestOutput(BaseModel): + """ + The output data of an image request to vLLM. + + Args: + type (str): The data content type [path, object] + format (str): The image format (e.g., jpeg, png, etc.) + data (Any): The resulting data. + """ + + type: Literal["path", "b64_json"] + format: str + data: str + request_id: Optional[str] = None diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py new file mode 100644 index 0000000000000..3ddda1a47bbe4 --- /dev/null +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from setuptools import setup + +setup( + name="prithvi_io_processor_plugin", + version="0.1", + packages=["prithvi_io_processor"], + entry_points={ + "vllm.io_processor_plugins": [ + "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501 + ] + }, +) diff --git a/tests/plugins/vllm_add_dummy_model/setup.py b/tests/plugins/vllm_add_dummy_model/setup.py index 6307bb63897ac..eeffac5d3eddd 100644 --- a/tests/plugins/vllm_add_dummy_model/setup.py +++ b/tests/plugins/vllm_add_dummy_model/setup.py @@ -3,10 +3,11 @@ from setuptools import setup -setup(name='vllm_add_dummy_model', - version='0.1', - packages=['vllm_add_dummy_model'], - entry_points={ - 'vllm.general_plugins': - ["register_dummy_model = vllm_add_dummy_model:register"] - }) +setup( + name="vllm_add_dummy_model", + version="0.1", + packages=["vllm_add_dummy_model"], + entry_points={ + "vllm.general_plugins": ["register_dummy_model = vllm_add_dummy_model:register"] + }, +) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py index b2085b01c45c1..457187e4b492e 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py @@ -19,5 +19,4 @@ def register(): ) if "MyLlava" not in ModelRegistry.get_supported_archs(): - ModelRegistry.register_model("MyLlava", - "vllm_add_dummy_model.my_llava:MyLlava") + ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava") diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index fc654f20fff22..a22a10eab47dc 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -15,7 +15,6 @@ from vllm.sequence import IntermediateTensors class MyGemma2Embedding(nn.Module): - is_pooling_model = True hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) @@ -23,19 +22,23 @@ class MyGemma2Embedding(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = 
Gemma2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def forward( self, @@ -58,8 +61,8 @@ class MyGemma2Embedding(nn.Module): return torch.zeros_like(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) + weights = ( + (name, data) for name, data in weights if not name.startswith("lm_head.") + ) return self.model.load_weights(weights) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index da97cf7e2b40b..9e6f5c3a77e3c 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -5,24 +5,24 @@ from typing import Optional import torch -from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder, - LlavaForConditionalGeneration, - LlavaMultiModalProcessor, - LlavaProcessingInfo) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.llava import ( + LlavaDummyInputsBuilder, + LlavaForConditionalGeneration, + LlavaMultiModalProcessor, + LlavaProcessingInfo, +) from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, - info=LlavaProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + LlavaMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) class MyLlava(LlavaForConditionalGeneration): - - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py index 8c34407e3e071..c02299f5d44f2 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -6,16 +6,12 @@ from typing import Optional import torch from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata class MyOPTForCausalLM(OPTForCausalLM): - - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if 
logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py index a531826628cda..b976dddb7fb5d 100644 --- a/tests/plugins/vllm_add_dummy_platform/setup.py +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -4,13 +4,15 @@ from setuptools import setup setup( - name='vllm_add_dummy_platform', - version='0.1', - packages=['vllm_add_dummy_platform'], + name="vllm_add_dummy_platform", + version="0.1", + packages=["vllm_add_dummy_platform"], entry_points={ - 'vllm.platform_plugins': [ + "vllm.platform_plugins": [ "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa ], - "vllm.general_plugins": - ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"], - }) + "vllm.general_plugins": [ + "dummy_custom_ops = vllm_add_dummy_platform:register_ops" + ], + }, +) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py index e38fb2fbf934e..f2d516f52b8b3 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionBackend) +from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend class DummyAttentionBackend(PlaceholderAttentionBackend): - @staticmethod def get_name() -> str: return "Dummy_Backend" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py index 1fcc3fc666173..b730285745269 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py @@ -15,6 +15,5 @@ class DummyRotaryEmbedding(RotaryEmbedding): super().__init__(*args, **kwargs) self.addition_config = True - def forward_oot(self, *args, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + def forward_oot(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: return super().forward_oot(*args, **kwargs) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index 8d0687b49bb47..0389e28746cbb 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None -from vllm import envs class DummyPlatform(Platform): @@ -19,12 +18,18 @@ class DummyPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - if envs.VLLM_USE_V1: - compilation_config = vllm_config.compilation_config - # Activate custom ops for v1. 
- compilation_config.custom_ops = ["all"] + vllm_config.compilation_config.custom_ops = ["all"] - def get_attn_backend_cls(self, backend_name, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink): + def get_attn_backend_cls( + self, + backend_name, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/plugins_tests/conftest.py b/tests/plugins_tests/conftest.py deleted file mode 100644 index c8c1b81ca2183..0000000000000 --- a/tests/plugins_tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') \ No newline at end of file diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py new file mode 100644 index 0000000000000..912b32755e80f --- /dev/null +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.plugins.io_processors import get_io_processor +from vllm.pooling_params import PoolingParams + +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" + +image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + +def test_loading_missing_plugin(): + vllm_config = VllmConfig() + with pytest.raises(ValueError): + get_io_processor(vllm_config, "wrong_plugin") + + +@pytest.fixture(scope="function") +def server(): + args = [ + "--runner", + "pooling", + "--enforce-eager", + "--trust-remote-code", + "--skip-tokenizer-init", + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + "--max-num-seqs", + "32", + "--io-processor-plugin", + "prithvi_to_tiff", + "--model-impl", + "terratorch", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_prithvi_mae_plugin_online( + server: RemoteOpenAIServer, + model_name: str, +): + request_payload_url = { + "data": { + "data": image_url, + "data_format": "url", + "image_format": "tiff", + "out_data_format": "b64_json", + }, + "priority": 0, + "model": model_name, + "softmax": False, + } + + ret = requests.post( + server.url_for("pooling"), + json=request_payload_url, + ) + + response = ret.json() + + # verify the request response is in the correct format + assert (parsed_response := IOProcessorResponse(**response)) + + # verify the output is formatted as expected for this plugin + plugin_data = parsed_response.data + + assert all( + plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"] + ) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. 
+ base64.b64decode(plugin_data["data"]) + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + + with vllm_runner( + model_name, + runner="pooling", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + max_num_seqs=1, + model_impl="terratorch", + io_processor_plugin="prithvi_to_tiff", + ) as llm_runner: + pooler_output = llm_runner.get_llm().encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + # verify the output is formatted as expected for this plugin + assert all( + hasattr(output, attr) for attr in ["type", "format", "data", "request_id"] + ) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(output.data) diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index 1d7e4475011d0..4dace171a8d3b 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -10,29 +10,38 @@ from vllm.plugins import load_general_plugins def test_platform_plugins(): # simulate workload by running an example import runpy + current_file = __file__ import os + example_file = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(current_file))), - "examples", "offline_inference/basic/basic.py") + "examples", + "offline_inference/basic/basic.py", + ) runpy.run_path(example_file) # check if the plugin is loaded correctly from vllm.platforms import _init_trace, current_platform + assert current_platform.device_name == "DummyDevice", ( f"Expected DummyDevice, got {current_platform.device_name}, " "possibly because current_platform is imported before the plugin" - f" is loaded. The first import:\n{_init_trace}") + f" is loaded. The first import:\n{_init_trace}" + ) def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch): # simulate workload by running an example load_general_plugins() from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16) assert layer.__class__.__name__ == "DummyRotaryEmbedding", ( f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, " - "possibly because the custom op is not registered correctly.") + "possibly because the custom op is not registered correctly." + ) assert hasattr(layer, "addition_config"), ( "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, " - "which is set by the custom op.") + "which is set by the custom op." 
+ ) diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py index 8c21216108685..45902cc874c30 100644 --- a/tests/plugins_tests/test_scheduler_plugins.py +++ b/tests/plugins_tests/test_scheduler_plugins.py @@ -3,35 +3,28 @@ import pytest -from vllm.core.scheduler import Scheduler from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine.llm_engine import LLMEngine -class DummyV0Scheduler(Scheduler): - - def schedule(self): - raise Exception("Exception raised by DummyV0Scheduler") - - -class DummyV1Scheduler(V1Scheduler): - +class DummyV1Scheduler(Scheduler): def schedule(self): raise Exception("Exception raised by DummyV1Scheduler") -def test_scheduler_plugins_v0(monkeypatch: pytest.MonkeyPatch): +def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - with pytest.raises(Exception) as exception_info: + # Explicitly turn off engine multiprocessing so + # that the scheduler runs in this process + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + with pytest.raises(Exception) as exception_info: engine_args = EngineArgs( model="facebook/opt-125m", enforce_eager=True, # reduce test time - scheduler_cls=DummyV0Scheduler, + scheduler_cls=DummyV1Scheduler, ) engine = LLMEngine.from_engine_args(engine_args=engine_args) @@ -40,30 +33,4 @@ def test_scheduler_plugins_v0(monkeypatch: pytest.MonkeyPatch): engine.add_request("0", "foo", sampling_params) engine.step() - assert str( - exception_info.value) == "Exception raised by DummyV0Scheduler" - - -def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - # Explicitly turn off engine multiprocessing so - # that the scheduler runs in this process - m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - - with pytest.raises(Exception) as exception_info: - - engine_args = EngineArgs( - model="facebook/opt-125m", - enforce_eager=True, # reduce test time - scheduler_cls=DummyV1Scheduler, - ) - - engine = V1LLMEngine.from_engine_args(engine_args=engine_args) - - sampling_params = SamplingParams(max_tokens=1) - engine.add_request("0", "foo", sampling_params) - engine.step() - - assert str( - exception_info.value) == "Exception raised by DummyV1Scheduler" + assert str(exception_info.value) == "Exception raised by DummyV1Scheduler" diff --git a/tests/quantization/fp_quant.py b/tests/quantization/fp_quant.py new file mode 100644 index 0000000000000..664ce9d111e4e --- /dev/null +++ b/tests/quantization/fp_quant.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test model set-up and inference for quantized HF models supported +on the GPU backend using FPQuant. + +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_fp_quant.py`. 
+""" + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +MODELS = [ + "ISTA-DASLab/Qwen3-0.6B-RTN-NVFP4", + "ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4", +] +DTYPE = ["bfloat16"] +EAGER = [True, False] + + +@pytest.mark.skipif( + not is_quant_method_supported("fp_quant"), + reason="FPQuant is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("eager", EAGER) +def test_fpquant(vllm_runner, model, eager): + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/tests/quantization/reference_mxfp4.py b/tests/quantization/reference_mxfp4.py index 2ef251933f681..d84659ed035eb 100644 --- a/tests/quantization/reference_mxfp4.py +++ b/tests/quantization/reference_mxfp4.py @@ -14,14 +14,15 @@ FLOAT8_E8M0_MAX_EXP = 127 FLOAT4_EXP_BIAS = 1 FLOAT4_MANTISSA_BITS = 1 -FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) -FLOAT16_SIGN_EXPONENT_MASK = (( - (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS) +FLOAT16_VAL_TO_ADD = 1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1) +FLOAT16_SIGN_EXPONENT_MASK = ( + (1 << (FLOAT16_EXP_BITS + 1)) - 1 +) << FLOAT16_MANTISSA_BITS -BFLOAT16_VAL_TO_ADD = (1 << - (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) -BFLOAT16_SIGN_EXPONENT_MASK = (( - (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS) +BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1) +BFLOAT16_SIGN_EXPONENT_MASK = ( + (1 << (BFLOAT16_EXP_BITS + 1)) - 1 +) << BFLOAT16_MANTISSA_BITS def e8m0_to_half(scale, half_dtype: torch.dtype): @@ -30,19 +31,19 @@ def e8m0_to_half(scale, half_dtype: torch.dtype): scale_exp = scale.to(torch.int16) - 127 # This can be implemented with bitwise operations in a proper kernel. - scale_half = 2.0**(scale_exp.to(torch.float)) + scale_half = 2.0 ** (scale_exp.to(torch.float)) return scale_half.to(half_dtype) -def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, - half_exp_bias: int, half_mantissa_bits: int): +def upcast_fp4_to_fp16_or_bf16( + val, float_dtype: torch.dtype, half_exp_bias: int, half_mantissa_bits: int +): assert val.dtype == torch.uint8 - unpacked = torch.zeros(*val.shape[:-1], - val.shape[-1] * 2, - dtype=torch.uint8, - device=val.device) + unpacked = torch.zeros( + *val.shape[:-1], val.shape[-1] * 2, dtype=torch.uint8, device=val.device + ) unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits. unpacked[..., ::2] = val & 0x0F # Extract low 4 bits. 
@@ -72,8 +73,11 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, new_exp = new_exp.to(torch.int32) sign = sign.to(torch.int32) - qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( - new_mantissa << (half_mantissa_bits - 1)) + qdq_val = ( + (sign << 15) + + (new_exp << half_mantissa_bits) + + (new_mantissa << (half_mantissa_bits - 1)) + ) assert qdq_val.max() <= 65535 assert qdq_val.min() >= 0 @@ -84,8 +88,9 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, return result -def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: +def dq_mxfp4_torch( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: assert x.dtype == torch.uint8 assert scale.dtype == torch.uint8 @@ -98,10 +103,12 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, scale_half = e8m0_to_half(scale, half_dtype=float_dtype) - x_half = upcast_fp4_to_fp16_or_bf16(x, - float_dtype=float_dtype, - half_exp_bias=half_exp_bias, - half_mantissa_bits=half_mantissa_bits) + x_half = upcast_fp4_to_fp16_or_bf16( + x, + float_dtype=float_dtype, + half_exp_bias=half_exp_bias, + half_mantissa_bits=half_mantissa_bits, + ) x_half = x_half.reshape(*x_half.shape[:-1], -1, 32) x_half = x_half * scale_half[..., None] @@ -110,8 +117,9 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, return x_half -def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, - half_exp_bias: int): +def fp16_to_fp4_simulate( + val, half_mantissa_bits: int, half_exp_bits: int, half_exp_bias: int +): # Casts an fp16/bf16 input to the restricted values of float4_e2m1, # that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, # -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]. @@ -119,7 +127,7 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, float_type = val.dtype # "rshift_cuda" not implemented for 'UInt16' - val_view = val.view(torch.int16) #.to(torch.int32) + val_view = val.view(torch.int16) # .to(torch.int32) exp = val_view >> half_mantissa_bits exp = exp & ((1 << half_exp_bits) - 1) @@ -147,23 +155,15 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, tail = mantissa_plus_one & ((1 << tail_bits) - 1) - round_close = (tail < half) # round towards 0 - round_away = (tail > half) # round away from 0 + round_close = tail < half # round towards 0 + round_away = tail > half # round away from 0 tie = tail == half - new_mantissa_close = torch.zeros(val.shape, - device=val.device, - dtype=torch.bool) - new_exp_close = torch.zeros(val.shape, - device=val.device, - dtype=torch.uint16) + new_mantissa_close = torch.zeros(val.shape, device=val.device, dtype=torch.bool) + new_exp_close = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) - new_mantissa_away = torch.zeros(val.shape, - device=val.device, - dtype=torch.bool) - new_exp_away = torch.zeros(val.shape, - device=val.device, - dtype=torch.uint16) + new_mantissa_away = torch.zeros(val.shape, device=val.device, dtype=torch.bool) + new_exp_away = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) @@ -202,27 +202,29 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1)) # Gather round up, round down and tie. 
- new_exp = round_away * new_exp_away \ - + round_close * new_exp_close \ - + tie * new_exp_tie + new_exp = ( + round_away * new_exp_away + round_close * new_exp_close + tie * new_exp_tie + ) - new_mantissa = round_away * new_mantissa_away \ - + round_close * new_mantissa_close + new_mantissa = round_away * new_mantissa_away + round_close * new_mantissa_close # if new_exp > 3: # new_mantissa = 1 - new_mantissa = new_mantissa + (new_exp > - (2 + half_exp_bias)) * (new_mantissa == 0) + new_mantissa = new_mantissa + (new_exp > (2 + half_exp_bias)) * (new_mantissa == 0) # Clamp the exponent to acceptable values. new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp( - new_exp, half_exp_bias - 2, half_exp_bias + 2) + new_exp, half_exp_bias - 2, half_exp_bias + 2 + ) sign = sign.to(torch.int32) new_mantissa = new_mantissa.to(torch.int32) - qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( - new_mantissa << (half_mantissa_bits - 1)) + qdq_val = ( + (sign << 15) + + (new_exp << half_mantissa_bits) + + (new_mantissa << (half_mantissa_bits - 1)) + ) assert qdq_val.max() <= 65535 assert qdq_val.min() >= 0 @@ -233,8 +235,9 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, return result -def qdq_mxfp4_torch(x: torch.Tensor, - scale_calculation_mode: str = "even") -> torch.Tensor: +def qdq_mxfp4_torch( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: half_dtype = x.dtype if half_dtype == torch.float16: @@ -258,8 +261,7 @@ def qdq_mxfp4_torch(x: torch.Tensor, block_max = block_max.view(torch.uint16).to(torch.int32) - block_max_uint = torch.bitwise_and(block_max + val_to_add, - sign_exponent_mask) + block_max_uint = torch.bitwise_and(block_max + val_to_add, sign_exponent_mask) assert block_max_uint.max() <= 65535 assert block_max_uint.min() >= 0 @@ -268,20 +270,23 @@ def qdq_mxfp4_torch(x: torch.Tensor, block_max = block_max_uint.view(half_dtype) - scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to( - torch.int32) - 2 + scale_exp = ( + FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(torch.int32) - 2 + ) scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP) - scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP) + scale = 2.0 ** (scale_exp - FLOAT8_E8M0_MAX_EXP) scale = scale.to(half_dtype) x = x / scale[..., None] - x_fp4 = fp16_to_fp4_simulate(x, - half_exp_bits=half_exp_bits, - half_mantissa_bits=half_mantissa_bits, - half_exp_bias=half_exp_bias) + x_fp4 = fp16_to_fp4_simulate( + x, + half_exp_bits=half_exp_bits, + half_mantissa_bits=half_mantissa_bits, + half_exp_bias=half_exp_bias, + ) x_fp4 = x_fp4 * scale[..., None] return x_fp4.reshape(*x_fp4.shape[:-2], -1) diff --git a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py index 1c41d904b8168..69632ae6cac70 100644 --- a/tests/quantization/test_auto_round.py +++ b/tests/quantization/test_auto_round.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test model set-up and inference for quantized HF models supported - on the AutoRound. +on the AutoRound. - Validating the configuration and printing results for manual checking. +Validating the configuration and printing results for manual checking. - Run `pytest tests/quantization/test_auto_round.py`. +Run `pytest tests/quantization/test_auto_round.py`. 
""" import pytest @@ -14,18 +14,19 @@ from vllm.platforms import current_platform MODELS = [ "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq - "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq + "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound", ##auto_round:auto_awq ] -@pytest.mark.skipif(not current_platform.is_cpu() - and not current_platform.is_xpu() - and not current_platform.is_cuda(), - reason="only supports CPU/XPU/CUDA backend.") +@pytest.mark.skipif( + not current_platform.is_cpu() + and not current_platform.is_xpu() + and not current_platform.is_cuda(), + reason="only supports CPU/XPU/CUDA backend.", +) @pytest.mark.parametrize("model", MODELS) def test_auto_round(vllm_runner, model): with vllm_runner(model) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=8) + output = llm.generate_greedy(["The capital of France is"], max_tokens=8) assert output print(f"{output[0][1]}") diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py new file mode 100644 index 0000000000000..3ad68172d771e --- /dev/null +++ b/tests/quantization/test_blackwell_moe.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import os +from typing import Optional + +import pytest + +from tests.utils import RemoteOpenAIServer +from vllm.platforms import current_platform + +if not current_platform.is_device_capability(100): + pytest.skip( + "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True + ) + + +@pytest.fixture(scope="module", autouse=True) +def set_test_environment(): + """Sets environment variables required for this test module.""" + # Make sure TRTLLM attention is available + os.environ["VLLM_HAS_FLASHINFER_CUBIN"] = "1" + # Set compilation threads to 16 to speed up startup + os.environ["FLASHINFER_NVCC_THREADS"] = "16" + + +# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4, +# "text_config": {"num_layers": 4, "num_hidden_layers": 4}} +dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4} + + +def can_initialize(model: str, extra_args: Optional[list[str]] = None): + # Server arguments + extra_args = extra_args if extra_args is not None else [] + server_args = [ + "--max-model-len", + "2048", + "--max-num-batched-tokens", + "256", + "--load-format", + "dummy", + "--trust-remote-code", + "--limit-mm-per-prompt", + json.dumps({"image": 0}), + *extra_args, + ] + + # Launch server and make a simple request + with RemoteOpenAIServer( + model, + server_args, + max_wait_seconds=1500, # Due to FlashInfer compile + override_hf_configs=dummy_hf_overrides, + ) as server: + client = server.get_client() + # Make a simple request to verify the server works + completion = client.completions.create( + model=model, + prompt=["Hello, World!"], + temperature=0, + max_tokens=2, + ) + print(completion) + assert completion.choices[0].text is not None + + +## Llama4 ## + + +@pytest.mark.skip( + reason=( + "RuntimeError: run_moe() Expected a value of type " + "'Optional[List[Tensor]]' for argument '_9' but instead found type " + "'list'." 
+ ) +) +def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") + can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") + + +@pytest.mark.skip(reason="Works, but takes too long to run") +def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") + can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") + + +@pytest.mark.skip(reason="Works, but takes too long to run") +def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") + can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4") + + +@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options") +def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") + can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4") + + +## DeepSeekV3 ## + + +def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") + can_initialize("deepseek-ai/DeepSeek-V3.1") + + +@pytest.mark.skip( + reason=( + "Known issue: lack of kernel support. " + "Expected failure: assert self.block_quant is None" + ) +) +def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") + can_initialize("deepseek-ai/DeepSeek-V3.1") + + +def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") + can_initialize("deepseek-ai/DeepSeek-V3.1") + + +def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") + can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2") + + +@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options") +def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") + can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2") + + +## GPT-OSS ## + + +def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1") + can_initialize("openai/gpt-oss-20b") + + +def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1") + can_initialize("openai/gpt-oss-20b") + + +def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") + can_initialize("openai/gpt-oss-20b") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 296743dbfa041..824d927724e02 100644 --- a/tests/quantization/test_compressed_tensors.py +++ 
b/tests/quantization/test_compressed_tensors.py @@ -13,15 +13,25 @@ from compressed_tensors.quantization import QuantizationType from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensors24, + CompressedTensorsLinearMethod, + CompressedTensorsW4A4Fp4, + CompressedTensorsW4A8Fp8, + CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, +) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) + sparse_cutlass_supported, +) from vllm.platforms import current_platform # AITER only supports per-channel-per-channel INT8 gemm @@ -29,7 +39,7 @@ from vllm.platforms import current_platform # It does not support mix precision MM and mix quantization scheme. ROCM_AITER_SUPPORTED_INT8_MODEL = [ "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2" + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", ] # TritonScaledMMLinearKernel only supports symmetric quantization. @@ -43,12 +53,9 @@ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [ @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. 
- """ - if not current_platform.is_cpu(): - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @pytest.mark.parametrize( @@ -80,8 +87,10 @@ def use_v0_only(monkeypatch): def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): model_path, strategy, quant_type, shape_0, is_symmetric = model_args - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") with vllm_runner(model_path, enforce_eager=True) as llm: @@ -106,14 +115,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): assert zp_valid(gate_up_proj.input_zero_point) assert zp_valid(down_proj.input_zero_point) - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(o_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(gate_up_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(down_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert qkv_proj.scheme.strategy == strategy @@ -151,7 +156,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize( - "use_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_aiter", [True, False] if current_platform.is_rocm() else [False] +) def test_compressed_tensors_w8a8_logprobs( hf_runner, vllm_runner, @@ -162,33 +168,36 @@ def test_compressed_tensors_w8a8_logprobs( use_aiter, monkeypatch, ): - - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") if use_aiter: if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: - pytest.skip( - f"Skip model {model_path} as it is not support by aiter.") + pytest.skip(f"Skip model {model_path} as it is not support by aiter.") # this will enable VLLM_ROCM_USE_AITER_LINEAR monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") dtype = "bfloat16" - # skip language translation prompt for the static per tensor asym model - if (model_path == - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" - ): # noqa: E501 + # skip language translation prompt for the static per tensor models + if model_path in ( + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + ): example_prompts = example_prompts[0:-1] with hf_runner(model_path, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model_path, dtype=dtype) as vllm_model: vllm_outputs = 
vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -224,7 +233,8 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): ], ) @pytest.mark.parametrize( - "use_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_aiter", [True, False] if current_platform.is_rocm() else [False] +) def test_compressed_tensors_w8a8_dynamic_per_token( vllm_runner, model_args, @@ -233,14 +243,15 @@ def test_compressed_tensors_w8a8_dynamic_per_token( ): model_path, strategy = model_args - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") if use_aiter: if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: - pytest.skip( - f"Skip model {model_path} as it is not support by aiter.") + pytest.skip(f"Skip model {model_path} as it is not support by aiter.") # this will enable VLLM_ROCM_USE_AITER_LINEAR monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -251,8 +262,7 @@ def test_compressed_tensors_w8a8_dynamic_per_token( qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert not qkv_proj.scheme.is_static_input_scheme assert qkv_proj.scheme.strategy == strategy @@ -266,21 +276,60 @@ def test_compressed_tensors_w8a8_dynamic_per_token( @pytest.mark.parametrize( "wNa16_args", - [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8, - True, False), - ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True, - False), - ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4, - True, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128, - 8, False, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", - "channel", None, 8, False, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", - "group", 128, 8, False, True)], + [ + ( + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + "channel", + None, + 8, + True, + False, + ), + ( + "nm-testing/tinyllama-oneshot-w4a16-group128-v2", + "group", + 128, + 8, + True, + False, + ), + ( + "nm-testing/tinyllama-oneshot-w8a16-per-channel", + "channel", + None, + 4, + True, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", + "group", + 128, + 8, + False, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", + "channel", + None, + 8, + False, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", + "group", + 128, + 8, + False, + True, + ), + ], +) +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform." 
) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="The tests are skipped on non-CUDA platform.") def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args with vllm_runner(model) as llm: @@ -289,13 +338,11 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16) assert qkv_proj.scheme.strategy == strategy - assert qkv_proj.scheme.group_size == (-1 - if group is None else group) + assert qkv_proj.scheme.group_size == (-1 if group is None else group) assert qkv_proj.scheme.pack_factor == pack_factor assert qkv_proj.scheme.symmetric == symmetric @@ -307,8 +354,9 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): assert output -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) def test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" with vllm_runner(model_path) as llm: @@ -318,8 +366,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) assert qkv_proj.weight_packed.dtype is torch.int32 @@ -338,8 +385,7 @@ def test_compressed_tensors_fp8(vllm_runner): qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance( qkv_proj.scheme, (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8), @@ -359,8 +405,13 @@ def test_compressed_tensors_fp8(vllm_runner): assert output -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") +@pytest.mark.skipif( + not current_platform.is_kv_cache_dtype_supported("fp8", None), + reason="FP8 KV cache is not supported on this device.", +) +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+) def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: @@ -372,10 +423,7 @@ def test_compressed_tensors_kv_cache(vllm_runner): not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.", ) -def _test_2of4_quant_models(qkv_proj, - weight_strategy, - input_strategy, - format="dense"): +def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) @@ -389,8 +437,7 @@ def _test_2of4_quant_models(qkv_proj, @pytest.mark.skipif( - not current_platform.is_cuda() - or not current_platform.has_device_capability(90), + not current_platform.is_cuda() or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -437,8 +484,7 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4): @pytest.mark.skipif( - not current_platform.is_cuda() - or not current_platform.has_device_capability(90), + not current_platform.is_cuda() or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -599,17 +645,14 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) assert qkv_proj.scheme.weight_quant is None assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = ( - qkv_proj.quant_method.quantization_config.sparsity_scheme_map - ) # noqa: E501 + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 assert sparsity_map.get("Linear").format == "dense" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -625,7 +668,8 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): reason="Cutlass is not yet supported on this GPU type.", ) @pytest.mark.parametrize( - "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) + "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")] +) def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): model = args_2of4 with vllm_runner(model) as llm: @@ -634,17 +678,14 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) assert qkv_proj.scheme.weight_quant is None assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = ( - qkv_proj.quant_method.quantization_config.sparsity_scheme_map - ) # noqa: E501 + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 assert sparsity_map.get("Linear").format == "sparse-24-bitmask" assert sparsity_map.get("Linear").sparsity_structure == 
"2:4" @@ -657,9 +698,11 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): @pytest.mark.parametrize( "args", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", - CompressedTensorsW4A16Fp4), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4)]) + [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4), + ], +) def test_compressed_tensors_nvfp4(vllm_runner, args): model, scheme = args with vllm_runner(model, enforce_eager=True) as llm: @@ -668,11 +711,12 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - if isinstance(qkv_proj.scheme, scheme) or isinstance( - qkv_proj.scheme, - CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported(): + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + if ( + isinstance(qkv_proj.scheme, scheme) + or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4) + and not cutlass_fp4_supported() + ): assert True else: raise AssertionError("FP4 Scheme Mismatch") @@ -683,3 +727,95 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() or not current_platform.has_device_capability(90), + reason="W4A8 FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args", + [("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)], +) +def test_compressed_tensors_w4a8_fp8(vllm_runner, args): + model, scheme = args + with vllm_runner(model, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, scheme) + + assert proj.weight_packed.dtype is torch.int32 + assert proj.weight_scale.dtype is torch.float8_e4m3fn + assert proj.weight_chan_scale.dtype is torch.float32 + assert proj.scheme.group_size == 128 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+) +@pytest.mark.parametrize( + "model,prompt,exp_perplexity", + [ + ( + "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ( + "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ], +) +def test_compressed_tensors_transforms_perplexity( + vllm_runner, model, prompt, exp_perplexity +): + with vllm_runner(model, enforce_eager=True) as llm: + perplexity = llm.generate_prompt_perplexity([prompt])[0] + print(perplexity) + assert perplexity <= exp_perplexity + + +def test_compressed_tensors_fp8_block_enabled(vllm_runner): + model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" + with vllm_runner(model_path) as llm: + fp8_dtype = current_platform.fp8_dtype() + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert isinstance( + qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp + ) + + assert qkv_proj.weight.dtype is fp8_dtype + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight.shape) == 2 + assert len(qkv_proj.weight_scale.shape) == 2 + + input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + assert isinstance(input_quant_op, QuantFP8) + assert input_quant_op._forward_method == input_quant_op.forward_cuda + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 1843bffd21159..797b565b91af6 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -33,7 +33,6 @@ MODEL_ARG_EXPTYPES = [ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), - # AUTOAWQ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), @@ -55,4 +54,5 @@ def test_auto_gptq(model_arg_exptype: tuple[str, None, str]) -> None: assert found_quantization_type == expected_type, ( f"Expected quant_type == {expected_type} for {model_path}, " f"but found {found_quantization_type} " - f"for no --quantization {quantization_arg} case") + f"for no --quantization {quantization_arg} case" + ) diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py index 08d9573ecf0b8..25d1dc59f6174 100644 --- a/tests/quantization/test_cpu_offload.py +++ b/tests/quantization/test_cpu_offload.py @@ -1,77 +1,108 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Expanded quantized model tests for CPU offloading -# Base tests: tests/basic_correctness/test_cpu_offload.py - -import pytest - -from tests.quantization.utils import is_quant_method_supported - -from ..utils import compare_two_settings - - -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") -def test_cpu_offload_fp8(): - # Test quantization of an unquantized checkpoint - compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", - ["--quantization", "fp8"], - ["--quantization", "fp8", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test loading a quantized checkpoint 
- compare_two_settings("neuralmagic/Qwen2-1.5B-Instruct-FP8", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="gptq_marlin is not supported on this GPU type.") -def test_cpu_offload_gptq(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test GPTQ Marlin - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test GPTQ - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", - ["--quantization", "gptq"], - ["--quantization", "gptq", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("awq_marlin"), - reason="awq_marlin is not supported on this GPU type.") -def test_cpu_offload_awq(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test AWQ Marlin - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test AWQ - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ", - ["--quantization", "awq"], - ["--quantization", "awq", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="gptq_marlin is not supported on this GPU type.") -def test_cpu_offload_compressed_tensors(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test wNa16 - compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test w4a16_marlin24 - compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", - [], ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test w8a8 - compare_two_settings( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Expanded quantized model tests for CPU offloading +# Base tests: tests/basic_correctness/test_cpu_offload.py + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from ..utils import compare_two_settings + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) +def test_cpu_offload_fp8(): + # Test quantization of an unquantized checkpoint + compare_two_settings( + "meta-llama/Llama-3.2-1B-Instruct", + ["--quantization", "fp8"], + ["--quantization", "fp8", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test loading a quantized checkpoint + compare_two_settings( + "neuralmagic/Qwen2-1.5B-Instruct-FP8", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_gptq(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test GPTQ Marlin + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test GPTQ + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", + 
["--quantization", "gptq"], + ["--quantization", "gptq", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("awq_marlin"), + reason="awq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_awq(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test AWQ Marlin + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-AWQ", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test AWQ + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-AWQ", + ["--quantization", "awq"], + ["--quantization", "awq", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_compressed_tensors(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test wNa16 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test w4a16_marlin24 + compare_two_settings( + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test w8a8 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 1e3e69e008bd4..2a72f734e431b 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -2,9 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # flake8: noqa -"""Tests experts_int8 quantization startup and generation, +"""Tests experts_int8 quantization startup and generation, doesn't test correctness """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,8 +15,10 @@ from ..models.registry import HF_EXAMPLE_MODELS MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] -@pytest.mark.skipif(not is_quant_method_supported("experts_int8"), - reason="ExpertsInt8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("experts_int8"), + reason="ExpertsInt8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -30,6 +33,5 @@ def test_model_experts_int8_startup( model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_transformers_version(on_fail="skip") - with vllm_runner(model, dtype=dtype, - quantization="experts_int8") as vllm_model: + with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index d781f462b4ad7..6b9a33059815f 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -4,13 +4,16 @@ Run `pytest tests/quantization/test_fp8.py --forked`. 
""" + import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, - Fp8LinearMethod) +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8KVCacheMethod, + Fp8LinearMethod, +) from vllm.platforms import current_platform MODELS = [ @@ -20,15 +23,18 @@ MODELS = [ ] -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: - + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_model_load_and_run( + vllm_runner, model_id: str, force_marlin: bool, use_rocm_aiter: bool, monkeypatch +) -> None: if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -50,18 +56,22 @@ KV_CACHE_MODELS = [ ] -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, - use_rocm_aiter: bool, monkeypatch): + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_kv_cache_model_load_and_run( + vllm_runner, model_id: str, use_rocm_aiter: bool, monkeypatch +): if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: def check_model(model): @@ -93,26 +103,34 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, print(outputs[0][1]) -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_load_fp16_model( + vllm_runner, + kv_cache_dtype: str, + force_marlin: bool, + use_rocm_aiter: bool, + monkeypatch, +) -> None: if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") + # `LLM.apply_model` requires pickling a function. 
+ monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") - with vllm_runner("facebook/opt-125m", - quantization="fp8", - kv_cache_dtype=kv_cache_dtype) as llm: + with vllm_runner( + "facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype + ) as llm: def check_model(model): fc1 = model.model.decoder.layers[0].fc1 @@ -139,26 +157,29 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, pytest.skip( "Skip `test_load_fp16_model`. " "It only runs on ROCm platform with FP8 compute." - " e.g. MI300X and above.") + " e.g. MI300X and above." + ) else: # unsupported platform - pytest.skip("Skip `test_load_fp16_model`. " - "It only runs on CUDA and ROCm platform.") + pytest.skip( + "Skip `test_load_fp16_model`. " + "It only runs on CUDA and ROCm platform." + ) llm.apply_model(check_model) -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_scaled_fp8_quant(dtype) -> None: - def quantize_ref(tensor, inv_scale): # The reference implementation that fully aligns to # the kernel being tested. finfo = torch.finfo(torch.float8_e4m3fn) scale = inv_scale.reciprocal() - qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, - max=finfo.max) + qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) qweight = qweight.to(torch.float8_e4m3fn) return qweight @@ -177,26 +198,23 @@ def test_scaled_fp8_quant(dtype) -> None: # Reference dynamic quantizaton y = quantize_ref(x, inv_scale) - torch.testing.assert_close(ref_y, - per_tensor_dequantize(y, inv_scale, dtype)) + torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) # Static quantization y, _ = ops.scaled_fp8_quant(x, inv_scale) - torch.testing.assert_close(ref_y, - per_tensor_dequantize(y, inv_scale, dtype)) + torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) # Padding y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17) assert y.shape[0] == 17 torch.testing.assert_close( ref_y, - per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, - dtype)) + per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype), + ) # non-contiguous input with padding m, n, padded_stride = 975, 512, 576 - padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * - 13).to(dtype) + padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype) x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1) assert not x_nc.is_contiguous() @@ -209,19 +227,21 @@ def test_scaled_fp8_quant(dtype) -> None: # reference dynamic quantization y_nc = quantize_ref(x_nc, inv_scale_nc) torch.testing.assert_close( - ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype) + ) # static quantization y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc) torch.testing.assert_close( - ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype) + ) # padding after non-contiguous input quantization - y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, - inv_scale_nc, - num_token_padding=m + 10) + y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc, 
num_token_padding=m + 10) assert y_nc_pad.shape[0] == m + 10 torch.testing.assert_close( ref_y_nc, - per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), - inv_scale_nc, dtype)) + per_tensor_dequantize( + torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype + ), + ) diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py index aea50e99c1dd5..c71f4b8156113 100644 --- a/tests/quantization/test_gptq_dynamic.py +++ b/tests/quantization/test_gptq_dynamic.py @@ -10,10 +10,10 @@ import torch from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinearMethod from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_dynamic_override) + get_dynamic_override, +) PROMPT = "On the surface of Mars, we found" @@ -21,51 +21,59 @@ PROMPT = "On the surface of Mars, we found" # The second layer is quantized using bits=8, group_size=32 # All other layers (layer index >= 2) are not quantized MODEL_QUANT = [ - ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", - True), - ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", - False), + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", True), + ( + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + False, + ), ] @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT) -def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool, - monkeypatch): - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_gptq_with_dynamic( + vllm_runner, model_id: str, use_marlin_kernel: bool, monkeypatch +): + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) + linear_method_cls = ( + GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) + ) - linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else ( - GPTQLinearMethod) + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm: - for name, submodule in (vllm_model.llm.llm_engine.model_executor. 
- driver_worker.model_runner.model.named_modules()): - if name == "lm_head": - assert isinstance(submodule.quant_method, linear_method_cls) - elif name == 'model.layers.0.self_attn.qkv_proj': - # The first layer is quantized using bits=4, group_size=128 - # desc_act=True - assert isinstance(submodule.quant_method, linear_method_cls) - config = submodule.quant_method.quant_config - assert config.weight_bits == 4 - assert config.group_size == 128 - assert config.desc_act - elif name == 'model.layers.1.self_attn.qkv_proj': - # The second layer is quantized using bits=8, group_size=32 - # desc_act=False - assert isinstance(submodule.quant_method, linear_method_cls) - config = submodule.quant_method.quant_config - assert get_dynamic_override(config, layer_name=name, - key="bits") == 8 - assert get_dynamic_override(config, - layer_name=name, - key="group_size") == 32 - assert not get_dynamic_override( - config, layer_name=name, key="desc_act") - elif (name == 'model.layers.2.self_attn.qkv_proj' - or name == 'model.layers.2.mlp.gate_up_proj'): - # All other layers (layer index >= 2) are not quantized - assert isinstance(submodule.quant_method, UnquantizedLinearMethod) + def check_model(model): + for name, submodule in model.named_modules(): + if name == "lm_head": + assert isinstance(submodule.quant_method, linear_method_cls) + elif name == "model.layers.0.self_attn.qkv_proj": + # The first layer is quantized using bits=4, group_size=128 + # desc_act=True + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert config.weight_bits == 4 + assert config.group_size == 128 + assert config.desc_act + elif name == "model.layers.1.self_attn.qkv_proj": + # The second layer is quantized using bits=8, group_size=32 + # desc_act=False + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert ( + get_dynamic_override(config, layer_name=name, key="bits") == 8 + ) + assert ( + get_dynamic_override(config, layer_name=name, key="group_size") + == 32 + ) + assert not get_dynamic_override( + config, layer_name=name, key="desc_act" + ) + elif ( + name == "model.layers.2.self_attn.qkv_proj" + or name == "model.layers.2.mlp.gate_up_proj" + ): + # All other layers (layer index >= 2) are not quantized + assert isinstance(submodule.quant_method, UnquantizedLinearMethod) - del vllm_model + llm.apply_model(check_model) diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py index 34b1b6c2e5b6d..ae9b1df3377dc 100644 --- a/tests/quantization/test_ipex_quant.py +++ b/tests/quantization/test_ipex_quant.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test model set-up and inference for quantized HF models supported - on the CPU/GPU backend using IPEX (including AWQ/GPTQ). - - Validating the configuration and printing results for manual checking. +on the CPU/GPU backend using IPEX (including AWQ/GPTQ). - Run `pytest tests/quantization/test_ipex_quant.py`. +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_ipex_quant.py`. 
""" import pytest @@ -19,14 +19,14 @@ MODELS = [ DTYPE = ["bfloat16"] -@pytest.mark.skipif(not current_platform.is_cpu() - and not current_platform.is_xpu(), - reason="only supports Intel CPU/XPU backend.") +@pytest.mark.skipif( + not current_platform.is_cpu() and not current_platform.is_xpu(), + reason="only supports Intel CPU/XPU backend.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", DTYPE) def test_ipex_quant(vllm_runner, model, dtype): with vllm_runner(model, dtype=dtype) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output print(output) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index b24964a9d0a9f..bae8b7f7d535b 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -9,10 +9,10 @@ import pytest import torch from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( - UnquantizedEmbeddingMethod) + UnquantizedEmbeddingMethod, +) PROMPT = "On the surface of Mars, we found" @@ -29,22 +29,22 @@ def test_lm_head( lm_head_quantized: bool, monkeypatch, ) -> None: - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner(model_id, dtype=torch.float16, - max_model_len=2048) as vllm_model: + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model: def check_model(model): lm_head_layer = model.lm_head if lm_head_quantized: - assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod)) + assert isinstance( + lm_head_layer.quant_method, + (GPTQLinearMethod, GPTQMarlinLinearMethod), + ) else: - assert isinstance(lm_head_layer.quant_method, - UnquantizedEmbeddingMethod) + assert isinstance( + lm_head_layer.quant_method, UnquantizedEmbeddingMethod + ) vllm_model.apply_model(check_model) - print( - vllm_model.generate_greedy(["Hello my name is"], - max_tokens=10)[0][1]) + print(vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1]) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index fcbfa681d75c9..8abf65d29784d 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -11,33 +11,34 @@ import pytest import torch from tests.quantization.utils import is_quant_method_supported -from vllm.platforms import current_platform @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. 
- """ - if not current_platform.is_cpu(): - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.skipif(not is_quant_method_supported("modelopt"), - reason="ModelOpt FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("modelopt"), + reason="ModelOpt FP8 is not supported on this GPU type.", +) def test_modelopt_fp8_checkpoint_setup(vllm_runner): """Test ModelOpt FP8 checkpoint loading and structure validation.""" - # TODO: provide a small publically available test checkpoint - model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" - "TinyLlama-1.1B-Chat-v1.0-fp8-0710") + # TODO: provide a small publicly available test checkpoint + model_path = ( + "/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" + "TinyLlama-1.1B-Chat-v1.0-fp8-0710" + ) # Skip test if checkpoint doesn't exist if not os.path.exists(model_path): - pytest.skip(f"Test checkpoint not found at {model_path}. " - "This test requires a local ModelOpt FP8 checkpoint.") + pytest.skip( + f"Test checkpoint not found at {model_path}. " + "This test requires a local ModelOpt FP8 checkpoint." + ) - with vllm_runner(model_path, quantization="modelopt", - enforce_eager=True) as llm: + with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -49,11 +50,12 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner): # Check that ModelOpt quantization method is properly applied from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptFp8LinearMethod) + ModelOptFp8LinearMethod, + ) + assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod) assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod) - assert isinstance(gate_up_proj.quant_method, - ModelOptFp8LinearMethod) + assert isinstance(gate_up_proj.quant_method, ModelOptFp8LinearMethod) assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod) # Check weight dtype is FP8 @@ -63,23 +65,23 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner): assert down_proj.weight.dtype == torch.float8_e4m3fn # Check scales are present and have correct dtype - assert hasattr(qkv_proj, 'weight_scale') - assert hasattr(qkv_proj, 'input_scale') + assert hasattr(qkv_proj, "weight_scale") + assert hasattr(qkv_proj, "input_scale") assert qkv_proj.weight_scale.dtype == torch.float32 assert qkv_proj.input_scale.dtype == torch.float32 - assert hasattr(o_proj, 'weight_scale') - assert hasattr(o_proj, 'input_scale') + assert hasattr(o_proj, "weight_scale") + assert hasattr(o_proj, "input_scale") assert o_proj.weight_scale.dtype == torch.float32 assert o_proj.input_scale.dtype == torch.float32 - assert hasattr(gate_up_proj, 'weight_scale') - assert hasattr(gate_up_proj, 'input_scale') + assert hasattr(gate_up_proj, "weight_scale") + assert hasattr(gate_up_proj, "input_scale") assert gate_up_proj.weight_scale.dtype == torch.float32 assert gate_up_proj.input_scale.dtype == torch.float32 - assert hasattr(down_proj, 'weight_scale') - assert hasattr(down_proj, 'input_scale') + assert hasattr(down_proj, "weight_scale") + assert hasattr(down_proj, "input_scale") assert down_proj.weight_scale.dtype == torch.float32 assert down_proj.input_scale.dtype == torch.float32 diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py index 5f78bc30504c0..e8ea4148585bf 100644 --- 
a/tests/quantization/test_ptpc_fp8.py +++ b/tests/quantization/test_ptpc_fp8.py @@ -4,31 +4,53 @@ Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. """ + import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod -from vllm.model_executor.layers.quantization.ptpc_fp8 import ( - PTPCFp8LinearMethod) +from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod from vllm.platforms import current_platform +UNSUPPORTED_STR = ( + "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only " + "support output dtype of bfloat16. torch.float16 is specified." +) -@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"), - reason="PTPC FP8 is not supported on this GPU type.") -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="This test is for ROCm GPU.") + +@pytest.fixture(scope="function", autouse=True) +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + +@pytest.mark.skipif( + not is_quant_method_supported("ptpc_fp8"), + reason="PTPC FP8 is not supported on this GPU type.", +) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.") @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: try: - with vllm_runner("facebook/opt-125m", - dtype=dtype, - quantization="ptpc_fp8", - kv_cache_dtype=kv_cache_dtype) as llm: + llm = vllm_runner( + "facebook/opt-125m", + dtype=dtype, + quantization="ptpc_fp8", + kv_cache_dtype=kv_cache_dtype, + ) + except AssertionError as e: + if str(e) == UNSUPPORTED_STR: + # If the error message matches, the test passes + return + else: + # If the error message does not match, re-raise the exception + raise - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + with llm: + + def check_model(model): fc1 = model.model.decoder.layers[0].fc1 assert isinstance(fc1.quant_method, PTPCFp8LinearMethod) if kv_cache_dtype == "ptpc_fp8": @@ -40,17 +62,8 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: if current_platform.has_device_capability(94): # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fnuz - else: - pytest.skip() - output = llm.generate_greedy("Hello my name is", max_tokens=20) - assert output - except AssertionError as e: - if str( - e - ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501 - # If the error message matches, the test passes - pass - else: - # If the error message does not match, re-raise the exception - raise + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 4a0c8ba4d8a95..1e65d9a995ce2 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -4,13 +4,14 @@ Run `pytest tests/quantization/test_quark.py`. -See also `tests/kernels/moe/test_mxfp4_moe.py`. +See also `tests/kernels/moe/test_ocp_mx_moe.py`. 
""" -import importlib import importlib.metadata import os from dataclasses import dataclass +from importlib.util import find_spec +from typing import Optional import huggingface_hub import lm_eval @@ -19,44 +20,45 @@ import torch from packaging import version from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 - QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkLinearMethod, + QuarkW8A8Fp8, + QuarkW8A8Int8, +) from vllm.platforms import current_platform from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch -QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( - "quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') +QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") if QUARK_MXFP4_AVAILABLE: - from quark.torch.export.nn.modules.realquantizer import ( - StaticScaledRealQuantizer) + from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer from quark.torch.kernel import mx as mx_kernel from quark.torch.quantization.config.config import FP4PerGroupSpec try: huggingface_hub.list_repo_refs( - "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ") + "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ" + ) HF_HUB_AMD_ORG_ACCESS = True except huggingface_hub.errors.RepositoryNotFoundError: HF_HUB_AMD_ORG_ACCESS = False @pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module relies on V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') +def enable_pickle(monkeypatch): + """`LLM.apply_model` requires pickling a function.""" + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8']) -@pytest.mark.parametrize('tp', [1]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("tp", [1]) def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp): model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test" - with vllm_runner(model_path, - kv_cache_dtype=kv_cache_dtype, - tensor_parallel_size=tp) as llm: + with vllm_runner( + model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp + ) as llm: def check_model(model): layer = model.model.layers[0] @@ -77,7 +79,31 @@ def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp): assert output -@pytest.mark.parametrize('tp', [1]) +@pytest.mark.parametrize("tp", [1]) +def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp): + model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts" + with vllm_runner(model_path, tensor_parallel_size=tp) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8) + + if isinstance(qkv_proj.scheme, QuarkW8A8Fp8): + assert qkv_proj.weight.dtype is current_platform.fp8_dtype() + assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1] + assert qkv_proj.weight_scale.shape[1] == 1 + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + + +@pytest.mark.parametrize("tp", [1]) def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" with vllm_runner(model_path, tensor_parallel_size=tp) as llm: 
@@ -103,17 +129,18 @@ def test_quark_fp8_parity(vllm_runner): llm_kwargs = { "tensor_parallel_size": 1, "enforce_eager": True, - "gpu_memory_utilization": 0.1 + "gpu_memory_utilization": 0.1, } - with (vllm_runner(quark_model_id, **llm_kwargs) as - quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): - quark_model = (quark_handle.llm.llm_engine.model_executor. - driver_worker.model_runner.model) - quark_state_dict = quark_model.state_dict() + with ( + vllm_runner(quark_model_id, **llm_kwargs) as quark_handle, + vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle, + ): - fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker. - model_runner.model) - fp8_state_dict = fp8_model.state_dict() + def get_state_dict(model): + return {k: v.cpu() for k, v in model.state_dict().items()} + + (quark_state_dict,) = quark_handle.apply_model(get_state_dict) + (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict) assert fp8_state_dict.keys() == quark_state_dict.keys() @@ -122,38 +149,93 @@ def test_quark_fp8_parity(vllm_runner): @dataclass -class ModelCase: - model_id: str - tp: int - - -@dataclass -class GSM8KAccuracyTestConfig: +class AccuracyTestConfig: model_name: str excepted_value: float - def get_model_args(self) -> str: - return ( - f"pretrained={self.model_name}," - "dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768" - ) + def get_model_args( + self, + tp_size: int, + model_max_len: Optional[int] = None, + kwargs: Optional[dict] = None, + ) -> dict: + if kwargs is None: + kwargs = {} + + model_args = { + "pretrained": self.model_name, + "dtype": "auto", + "add_bos_token": True, + "tensor_parallel_size": tp_size, + "gpu_memory_utilization": 0.7, + **kwargs, + } + if model_max_len is not None: + model_args["max_model_len"] = model_max_len + + return model_args -ACCURACY_CONFIGS = [ +GSM8K_ACCURACY_CONFIGS = [ # Private model. - GSM8KAccuracyTestConfig( + AccuracyTestConfig( model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", - excepted_value=0.96), + excepted_value=0.96, + ), +] + +WIKITEXT_ACCURACY_CONFIGS = [ + AccuracyTestConfig( + model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", + excepted_value=11.3, + ), + AccuracyTestConfig( + model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", + excepted_value=10.6, + ), + AccuracyTestConfig( + model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4 + ), ] -@pytest.mark.parametrize("config", ACCURACY_CONFIGS) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS) +@pytest.mark.parametrize("tp_size", [1, 2]) +def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): + if torch.cuda.device_count() < tp_size: + pytest.skip( + f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}" + ) + + task = "wikitext" + rtol = 0.1 + + # Smaller cuda_graph_sizes to speed up the test. 
+ results = lm_eval.simple_evaluate( + model="vllm", + model_args=config.get_model_args( + tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]} + ), + tasks=task, + batch_size=64, + ) + + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][task]["word_perplexity,none"] + assert ( + measured_value < EXPECTED_VALUE + rtol + and measured_value > EXPECTED_VALUE - rtol + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif( not HF_HUB_AMD_ORG_ACCESS, - reason="Read access to huggingface.co/amd is required for this test.") -def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): + reason="Read access to huggingface.co/amd is required for this test.", +) +def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): if torch.cuda.device_count() < 8: pytest.skip( f"This test requires >=8 gpus, got only {torch.cuda.device_count()}" @@ -166,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): results = lm_eval.simple_evaluate( model="vllm", - model_args=config.get_model_args(), + model_args=config.get_model_args(tp_size=8, model_max_len=38768), tasks=task, batch_size=64, num_fewshot=8, @@ -174,28 +256,26 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): EXPECTED_VALUE = config.excepted_value measured_value = results["results"][task]["exact_match,strict-match"] - assert (measured_value - rtol < EXPECTED_VALUE - and measured_value + rtol > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - rtol < EXPECTED_VALUE + and measured_value + rtol > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" del os.environ["VLLM_USE_TRITON_FLASH_ATTN"] -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("scalings", - [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) -def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, - scalings: list[int]): +@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]): torch.manual_seed(0) hidden_size = 64 * 32 - inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - - 0.5) * 2 + inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2 for i in range(hidden_size // 32): - inp[:, i * 32:(i + 1) * - 32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + inp[:, i * 32 : (i + 1) * 32] = ( + inp[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)] + ) inp_kernel = inp.clone() inp_kernel_clone = inp_kernel.clone() @@ -204,20 +284,20 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, res_torch = qdq_mxfp4_torch(inp_kernel, "even") for i in range(hidden_size // 32): - assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32])) - assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32])) + assert torch.all(torch.isfinite(res_hip[:, i * 32 : (i + 1) * 32])) + assert torch.all(torch.isfinite(res_torch[:, i * 32 : (i + 1) * 32])) - torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32], - res_torch[:, i * 32:(i + 1) * 32]) + 
torch.testing.assert_close( + res_hip[:, i * 32 : (i + 1) * 32], res_torch[:, i * 32 : (i + 1) * 32] + ) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("scalings", - [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) -def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, - scalings: list[int]): +@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_dequant_kernel_match_quark( + float_dtype: torch.dtype, scalings: list[int] +): qspec = FP4PerGroupSpec( ch_axis=-1, group_size=32, @@ -244,8 +324,9 @@ def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, # Make it so that different groups have different scales. for i in range(hidden_size // 32): - w[:, i * 32:(i + 1) * - 32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + w[:, i * 32 : (i + 1) * 32] = ( + w[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)] + ) observer(w) scale, _ = observer._calculate_qparams() diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index 84705e92c85bb..b70c2ee7fe2e6 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -6,18 +6,25 @@ See https://github.com/vllm-project/vllm/issues/11926 for more details. Run `pytest tests/quantization/test_register_quantization_config.py`. """ + from typing import Any, Optional import pytest import torch import torch.nn.functional as F -from vllm.model_executor.layers.linear import LinearBase # noqa: E501 -from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + LinearBase, # noqa: E501 + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import ( - QuantizationMethods, get_quantization_config, register_quantization_config) + QuantizationMethods, + get_quantization_config, + register_quantization_config, +) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig) + QuantizationConfig, +) class FakeQuantLinearMethod(UnquantizedLinearMethod): @@ -28,10 +35,12 @@ class FakeQuantLinearMethod(UnquantizedLinearMethod): super().__init__() self.num_bits = num_bits - def apply(self, - layer: "torch.nn.Module", - x: "torch.Tensor", - bias: Optional["torch.Tensor"] = None) -> "torch.Tensor": + def apply( + self, + layer: "torch.nn.Module", + x: "torch.Tensor", + bias: Optional["torch.Tensor"] = None, + ) -> "torch.Tensor": """Perform fake quantization before the linear layer.""" # Calculate the scales dynamically @@ -40,8 +49,11 @@ class FakeQuantLinearMethod(UnquantizedLinearMethod): scales = (max_val - min_val) / (2**self.num_bits - 1) # Fake quantize the input - quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1), - 2**(self.num_bits - 1) - 1) + quant_x = torch.clamp( + torch.round(x / scales), + -(2 ** (self.num_bits - 1)), + 2 ** (self.num_bits - 1) - 1, + ) dequant_x = quant_x * scales return F.linear(dequant_x, layer.weight, bias) @@ -79,8 +91,9 @@ class CustomQuantConfig(QuantizationConfig): """Create a config class from the model's quantization config.""" return CustomQuantConfig(num_bits=config.get("num_bits", 8)) - def get_quant_method(self, layer: 
"torch.nn.Module", - prefix: str) -> Optional["FakeQuantLinearMethod"]: + def get_quant_method( + self, layer: "torch.nn.Module", prefix: str + ) -> Optional["FakeQuantLinearMethod"]: """Get the quantize method to use for the quantized layer.""" if isinstance(layer, LinearBase): return FakeQuantLinearMethod(num_bits=self.num_bits) @@ -99,24 +112,29 @@ def test_register_quantization_config(): register_quantization_config("custom_quant")(CustomQuantConfig) -@pytest.mark.parametrize(argnames="model", - argvalues=[ - "meta-llama/Llama-3.2-1B-Instruct", - ]) +@pytest.mark.parametrize( + argnames="model", + argvalues=[ + "meta-llama/Llama-3.2-1B-Instruct", + ], +) def test_custom_quant(vllm_runner, model, monkeypatch): """Test infer with the custom quantization method.""" - # vllm_runner.apply_model() relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner(model_name=model, - quantization="custom_quant", - enforce_eager=True) as llm: + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 - layer = model.model.layers[0] - qkv_proj = layer.self_attn.qkv_proj + with vllm_runner( + model_name=model, quantization="custom_quant", enforce_eager=True + ) as llm: - # Check the quantization method is FakeQuantLinearMethod - assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod) + def check_model(model): + layer = model.model.layers[0] + qkv_proj = layer.self_attn.qkv_proj + + # Check the quantization method is FakeQuantLinearMethod + assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod) + + llm.apply_model(check_model) output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py index bc2b468f97d8c..370625ed34792 100644 --- a/tests/quantization/test_rtn.py +++ b/tests/quantization/test_rtn.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright © 2025, Oracle and/or its affiliates. 
-"""Tests RTN quantization startup and generation, +"""Tests RTN quantization startup and generation, doesn't test correctness """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,8 +15,10 @@ MODELS = [ ] -@pytest.mark.skipif(not is_quant_method_supported("rtn"), - reason="RTN is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("rtn"), + reason="RTN is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -27,6 +30,5 @@ def test_model_rtn_startup( dtype: str, max_tokens: int, ) -> None: - with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index eef3568efea12..bc24c51b57b28 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -13,14 +13,14 @@ TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") def test_pre_quantized_model(vllm_runner): - with vllm_runner("drisspg/fp8-opt-125m", - quantization="torchao", - dtype="bfloat16", - enforce_eager=True) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + "drisspg/fp8-opt-125m", + quantization="torchao", + dtype="bfloat16", + enforce_eager=True, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output - print(output) @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") @@ -29,50 +29,225 @@ def test_pre_quantized_model(vllm_runner): [ "cuda:0", # {"": "cuda"}, - ]) -def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, - pt_load_map_location): + ], +) +def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, pt_load_map_location): torch._dynamo.reset() model_name = "jerryzh168/opt-125m-int8wo-partial-quant" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location=pt_load_map_location) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location=pt_load_map_location, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output - print(output) @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") def test_opt_125m_int4wo_model_per_module_quant(vllm_runner): torch._dynamo.reset() model_name = "jerryzh168/opt-125m-int4wo-per-module" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location="cuda:0") as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output - print(output) @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): torch._dynamo.reset() model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - 
pt_load_map_location="cuda:0") as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): + torch._dynamo.reset() + model_name = "torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2-0.14.0.dev" + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +def test_on_the_fly_quant_config_dict_json(vllm_runner): + """Testing on the fly quantization, load_weights integration point, + with config dict serialized to json string + """ + torch._dynamo.reset() + model_name = "facebook/opt-125m" + + import json + + from torchao.core.config import config_to_dict + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow + + torchao_quant_config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ) + hf_overrides = { + "quantization_config_dict_json": json.dumps( + config_to_dict(torchao_quant_config) + ) + } + with vllm_runner( + model_name=model_name, + dtype="bfloat16", + pt_load_map_location="cuda:0", + quantization="torchao", + hf_overrides=hf_overrides, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +def test_on_the_fly_quant_config_file(vllm_runner): + """Testing on the fly quantization, load_weights integration point, + with config file + """ + torch._dynamo.reset() + model_name = "facebook/opt-125m" + import json + from tempfile import NamedTemporaryFile + + from torchao.core.config import config_to_dict + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow + + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + + with NamedTemporaryFile(mode="w", delete=False) as f: + f.write(json.dumps(config_to_dict(config))) + # close the file to save it + f.close() + config_file_name = str(f.name) + + hf_overrides = {"quantization_config_file": config_file_name} + with vllm_runner( + model_name=model_name, + dtype="bfloat16", + pt_load_map_location="cuda:0", + quantization="torchao", + hf_overrides=hf_overrides, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +def test_reload_weights(): + import json + + from torchao.core.config import config_to_dict + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow + + from vllm import LLM, SamplingParams + + torchao_quant_config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ) + + hf_overrides = { + "quantization_config_dict_json": json.dumps( + 
config_to_dict(torchao_quant_config) + ) + } + + llm = LLM( + model="Qwen/Qwen3-0.6B", + dtype="bfloat16", + load_format="dummy", + enforce_eager=True, + quantization="torchao", + hf_overrides=hf_overrides, + ) + # Update load format from `dummy` to `auto` + llm.collective_rpc( + "update_config", args=({"load_config": {"load_format": "auto"}},) + ) + # Now reload real weights inplace + llm.collective_rpc("reload_weights") + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0, top_p=0.95) + outputs = llm.generate(prompts, sampling_params) + # make sure it runs + for output in outputs: + generated_text = output.outputs[0].text + assert generated_text + # can also uncomment locally to make sure the generated + # output makes sense + # prompt = output.prompt + # print(f"Prompt: {prompt!r}") + # print(f"Output: {generated_text!r}") + # print("-" * 60) + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner): + torch._dynamo.reset() + model_name = ( + "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors" + ) + with vllm_runner(model_name=model_name, dtype="bfloat16") as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner): + torch._dynamo.reset() + model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev" + with vllm_runner( + model_name=model_name, dtype="bfloat16", pt_load_map_location="cuda:0" + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output - print(output) if __name__ == "__main__": diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py new file mode 100644 index 0000000000000..ddda50fe770a6 --- /dev/null +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + + +# Create a concrete test implementation of BaseThinkingReasoningParser +class TestThinkingReasoningParser(BaseThinkingReasoningParser): + """Test implementation of BaseThinkingReasoningParser.""" + + @property + def start_token(self) -> str: + return "<test:think>" + + @property + def end_token(self) -> str: + return "</test:think>" + + +class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser): + """Alternative test implementation with 
different tokens.""" + + @property + def start_token(self) -> str: + return "<alt:start>" + + @property + def end_token(self) -> str: + return "<alt:end>" + + +# Use a test model +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def test_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + # Add custom test tokens + test_tokens = ["<test:think>", "</test:think>", "<alt:start>", "<alt:end>"] + existing_tokens = set(tokenizer.get_vocab().keys()) + new_tokens = [token for token in test_tokens if token not in existing_tokens] + if new_tokens: + tokenizer.add_tokens(new_tokens) + return tokenizer + + +class TestBaseThinkingReasoningParserInit: + """ + Test initialization and basic properties of + BaseThinkingReasoningParser. + """ + + def test_successful_initialization(self, test_tokenizer): + """Test successful initialization with valid tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + assert parser.start_token == "<test:think>" + assert parser.end_token == "</test:think>" + assert parser.start_token_id is not None + assert parser.end_token_id is not None + + def test_initialization_with_missing_tokenizer(self): + """Test that initialization fails without tokenizer.""" + with pytest.raises(ValueError, match="model tokenizer must be passed"): + TestThinkingReasoningParser(None) + + def test_initialization_with_missing_tokens(self, test_tokenizer): + """Test that initialization fails when tokens are not in vocabulary.""" + + # Create a parser with tokens not in vocabulary + class MissingTokenParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "<missing:start>" + + @property + def end_token(self) -> str: + return "<missing:end>" + + with pytest.raises( + RuntimeError, match="could not locate think start/end tokens" + ): + MissingTokenParser(test_tokenizer) + + def test_initialization_with_empty_tokens(self, test_tokenizer): + """Test that initialization fails with empty token strings.""" + + class EmptyTokenParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + with pytest.raises( + ValueError, match="start_token and end_token must be defined" + ): + EmptyTokenParser(test_tokenizer) + + +class TestBaseThinkingReasoningParserMethods: + """Test the methods of BaseThinkingReasoningParser.""" + + def test_is_reasoning_end(self, test_tokenizer): + """Test the is_reasoning_end method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token present + assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True + + # Test without end token + assert parser.is_reasoning_end([1, 2, 3, 4]) is False + + # Test with empty list + assert parser.is_reasoning_end([]) is False + + def test_extract_content_ids(self, test_tokenizer): + """Test the extract_content_ids method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + + # Test with end token in the middle + input_ids = [1, 2, end_token_id, 4, 5] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [4, 5] + + # Test with end token at the end + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test without end token + input_ids = [1, 2, 3, 4] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == 
[] + + # Test with end token as last element (should not extract) + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + +class TestBaseThinkingReasoningParserExtraction: + """Test reasoning content extraction methods.""" + + def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer): + """Test extraction when both start and end tokens are present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "<test:think>This is reasoning</test:think>This is content" + reasoning, content = parser.extract_reasoning_content(model_output, request) + + assert reasoning == "This is reasoning" + assert content == "This is content" + + def test_extract_reasoning_content_only_end_token(self, test_tokenizer): + """Test extraction when only end token is present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "This is reasoning</test:think>This is content" + reasoning, content = parser.extract_reasoning_content(model_output, request) + + assert reasoning == "This is reasoning" + assert content == "This is content" + + def test_extract_reasoning_content_no_end_token(self, test_tokenizer): + """Test extraction when no end token is present.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "This is just content" + reasoning, content = parser.extract_reasoning_content(model_output, request) + + assert reasoning == "This is just content" + assert content is None + + def test_extract_reasoning_content_empty_output(self, test_tokenizer): + """Test extraction with empty output.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "" + reasoning, content = parser.extract_reasoning_content(model_output, request) + + assert reasoning == "" + assert content is None + + def test_extract_reasoning_content_only_tokens(self, test_tokenizer): + """Test extraction with only tokens and no content.""" + parser = TestThinkingReasoningParser(test_tokenizer) + request = ChatCompletionRequest(messages=[], model="test-model") + + model_output = "<test:think></test:think>" + reasoning, content = parser.extract_reasoning_content(model_output, request) + + assert reasoning == "" + assert content is None + + +class TestBaseThinkingReasoningParserStreaming: + """Test streaming functionality of BaseThinkingReasoningParser.""" + + @pytest.mark.parametrize("streaming", [True, False]) + def test_simple_reasoning_extraction(self, test_tokenizer, streaming): + """ + Test basic reasoning extraction in both + streaming and non-streaming modes. 
+ """ + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = [ + "<test:think>", + "Some ", + "reasoning ", + "content", + "</test:think>", + "Final ", + "answer", + ] + + reasoning, content = run_reasoning_extraction( + parser, model_output, streaming=streaming + ) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_incremental_deltas(self, test_tokenizer): + """Test streaming processing with small incremental deltas.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "<test:think>", + "Some ", + "reasoning ", + "content", + "</test:think>", + "Final ", + "answer", + ] + + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" + + def test_streaming_with_start_token(self, test_tokenizer): + """Test streaming with start token included.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "<test:think>", + "Some ", + "reasoning", + "</test:think>", + "Answer", + ] + + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) + + assert reasoning == "Some reasoning" + assert content == "Answer" + + def test_streaming_no_end_token(self, test_tokenizer): + """Test streaming when no end token is encountered.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "<test:think>", + "Some ", + "reasoning ", + "without ", + "end", + ] + + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) + + assert reasoning == "Some reasoning without end" + assert content is None + + def test_streaming_only_end_token(self, test_tokenizer): + """Test streaming when only end token appears.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + deltas = [ + "<test:think>", + "Reasoning ", + "content", + "</test:think>", + "Final", + ] + + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) + + assert reasoning == "Reasoning content" + assert content == "Final" + + +class TestBaseThinkingReasoningParserMultipleImplementations: + """ + Test that multiple implementations of + BaseThinkingReasoningParser work correctly. + """ + + def test_different_token_implementations(self, test_tokenizer): + """ + Test that different implementations + with different tokens work independently. 
+ """ + parser1 = TestThinkingReasoningParser(test_tokenizer) + parser2 = TestThinkingReasoningParserAlt(test_tokenizer) + + # Test parser1 + model_output1 = "Reasoning1</test:think>Content1" + reasoning1, content1 = run_reasoning_extraction(parser1, [model_output1]) + assert reasoning1 == "Reasoning1" + assert content1 == "Content1" + + # Test parser2 + model_output2 = "Reasoning2<alt:end>Content2" + reasoning2, content2 = run_reasoning_extraction(parser2, [model_output2]) + assert reasoning2 == "Reasoning2" + assert content2 == "Content2" + + # Verify tokens are different + assert parser1.start_token != parser2.start_token + assert parser1.end_token != parser2.end_token + assert parser1.start_token_id != parser2.start_token_id + assert parser1.end_token_id != parser2.end_token_id + + +class TestBaseThinkingReasoningParserEdgeCases: + """Test edge cases and error conditions.""" + + def test_multiple_end_tokens(self, test_tokenizer): + """Test behavior with multiple end tokens.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = "First</test:think>Middle</test:think>Last" + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should stop at first end token + assert reasoning == "First" + assert content == "Middle</test:think>Last" + + def test_nested_tokens(self, test_tokenizer): + """Test behavior with nested-like token patterns.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = "<test:think>Outer<test:think>Inner</test:think>Content" + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should process normally, start from first start token + assert reasoning == "Outer<test:think>Inner" + assert content == "Content" + + def test_malformed_tokens(self, test_tokenizer): + """Test behavior with malformed token-like strings.""" + parser = TestThinkingReasoningParser(test_tokenizer) + + model_output = "<test:thinking>Not a real token</test:thinking>Content" + reasoning, content = run_reasoning_extraction(parser, [model_output]) + + # Should treat as regular content since tokens don't match exactly + assert reasoning == ("<test:thinking>Not a real token</test:thinking>Content") + assert content is None diff --git a/tests/reasoning/test_deepseekr1_reasoning_parser.py b/tests/reasoning/test_deepseekr1_reasoning_parser.py index 987f3c48de0c0..946d01c123c5d 100644 --- a/tests/reasoning/test_deepseekr1_reasoning_parser.py +++ b/tests/reasoning/test_deepseekr1_reasoning_parser.py @@ -259,15 +259,15 @@ def test_reasoning( output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"]) # decode everything to tokens output_tokens: list[str] = [ - deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) - for token in output + deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(deepseek_r1_qwen_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + deepseek_r1_qwen_tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -281,7 +281,8 @@ def test_reasoning( if param_dict["content"] is not None: content = parser.extract_content_ids(output_ids) assert content == 
deepseek_r1_qwen_tokenizer.convert_tokens_to_ids( - deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"])) + deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]) + ) else: content = parser.extract_content_ids(output) assert content == [] diff --git a/tests/reasoning/test_glm4_moe_reasoning_parser.py b/tests/reasoning/test_glm4_moe_reasoning_parser.py new file mode 100644 index 0000000000000..0a8595a00fcb5 --- /dev/null +++ b/tests/reasoning/test_glm4_moe_reasoning_parser.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "glm45" +start_token = "<think>" +end_token = "</think>" + +REASONING_MODEL_NAME = "zai-org/GLM-4.5" + + +@pytest.fixture(scope="module") +def glm45_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +WITH_THINK = { + "output": "<think>This is a reasoning section</think>This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} + +WITH_THINK_STREAM = { + "output": "<think>This is a reasoning section</think>This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} + +WITHOUT_THINK = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": False, +} + +WITHOUT_THINK_STREAM = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": False, +} + +COMPLETE_REASONING = { + "output": "<think>This is a reasoning section</think>", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +MULTILINE_REASONING = { + "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat", + "reasoning_content": "This is a reasoning\nsection", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +ONLY_OPEN_TAG = { + "output": "<think>This is a reasoning section", + "reasoning_content": None, + "content": "<think>This is a reasoning section", + "is_reasoning_end": False, +} + +ONLY_OPEN_TAG_STREAM = { + "output": "<think>This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} + +TEST_CASES = [ + pytest.param( + False, + WITH_THINK, + id="with_think", + ), + pytest.param( + True, + WITH_THINK_STREAM, + id="with_think_stream", + ), + pytest.param( + False, + WITHOUT_THINK, + id="without_think", + ), + pytest.param( + True, + WITHOUT_THINK_STREAM, + id="without_think_stream", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_stream", + ), + pytest.param( + False, + MULTILINE_REASONING, + id="multiline_reasoning", + ), + pytest.param( + True, + MULTILINE_REASONING, + id="multiline_reasoning_stream", + ), + pytest.param( + False, + ONLY_OPEN_TAG, + id="only_open_tag", + ), + pytest.param( + True, + ONLY_OPEN_TAG_STREAM, + id="only_open_tag_stream", + ), +] + +STILL_REASONING_PROMPT = """[gMASK]<sop><|system|> +You are a helpful assistant.<|user|> +What is the capital of France?<|assistant|> +<think>The user is asking for the capital of""" + 
+DONE_REASONING_PROMPT = """[gMASK]<sop><|system|> +You are a helpful assistant.<|user|> +What is the capital of France?<|assistant|> +<think>The user is asking for the capital of France.</think> +The capital of France is Paris.""" + +MULTI_TURN_STILL_REASONING_PROMPT = """[gMASK]<sop><|system|> +You are a helpful assistant.<|user|> +What is the capital of France?<|assistant|> +<think></think> +The capital of France is Paris.<|user|> +What about Chile?<|assistant|> +<think>The user is asking for the capital of""" + +MULTI_TURN_DONE_REASONING_PROMPT = """[gMASK]<sop><|system|> +You are a helpful assistant.<|user|> +What is the capital of France?<|assistant|> +<think></think> +The capital of France is Paris.<|user|> +What about Chile?<|assistant|> +<think>The user is asking for the capital of Chile.</think> +The capital of Chile is Santiago.""" + +REASONING_END_TEST_CASES = [ + pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"), + pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"), + pytest.param( + MULTI_TURN_STILL_REASONING_PROMPT, False, id="multi_turn_still_reasoning" + ), + pytest.param( + MULTI_TURN_DONE_REASONING_PROMPT, True, id="multi_turn_done_reasoning" + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + glm45_tokenizer, +): + output = glm45_tokenizer.tokenize(param_dict["output"]) + output_tokens: list[str] = [ + glm45_tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + glm45_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] + + output_ids = glm45_tokenizer.convert_tokens_to_ids(output) + is_reasoning_end = parser.is_reasoning_end(output_ids) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + +@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES) +def test_is_reasoning_end_full_prompt( + prompt: str, is_reasoning_end: bool, glm45_tokenizer +): + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + glm45_tokenizer + ) + tokens = glm45_tokenizer.tokenize(prompt) + token_ids = glm45_tokenizer.convert_tokens_to_ids(tokens) + check_is_reasoning_end = parser.is_reasoning_end(token_ids) + assert check_is_reasoning_end == is_reasoning_end diff --git a/tests/reasoning/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py index 38cab73a45f22..de1663408d72d 100644 --- a/tests/reasoning/test_granite_reasoning_parser.py +++ b/tests/reasoning/test_granite_reasoning_parser.py @@ -11,8 +11,7 @@ START_REASONING = "Here is my thought process:" START_RESPONSE = "Here is my response:" SIMPLE_REASONING = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -27,14 +26,12 @@ NO_REASONING = { "content": "This is content", } MULTIPLE_LINES = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } 
REASONING_WITH_THINK = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -44,8 +41,7 @@ COMPLETE_REASONING_WITH_THINK = { "content": None, } MULTIPLE_LINES_WITH_THINK = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } @@ -137,12 +133,13 @@ def test_reasoning( output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -229,18 +226,15 @@ STREAMING_9 = { ## The Response is ongoing, and the delta mixes reasoning content / content STREAMING_10 = { "previous_text": "Here is my thought process: foo", - "current_text": - "Here is my thought process: foo bar Here is my response: baz", + "current_text": "Here is my thought process: foo bar Here is my response: baz", "delta_text": " bar Here is my response: baz", "reasoning_content": " bar ", "content": " baz", } # The delta text starts a new substring that might be a response special seq STREAMING_11 = { - "previous_text": - "Here is my thought process: This is a reasoning section ", - "current_text": - "Here is my thought process: This is a reasoning section Here", + "previous_text": "Here is my thought process: This is a reasoning section ", + "current_text": "Here is my thought process: This is a reasoning section Here", "delta_text": "Here", "reasoning_content": None, "content": None, @@ -320,14 +314,17 @@ STREAMING_SUBCASES = [ @pytest.mark.parametrize("param_dict", STREAMING_SUBCASES) def test_streaming_subcases(param_dict): # Get all of the token IDs - previous_token_ids = tokenizer.encode( - param_dict["previous_text"] - ) if param_dict["previous_text"] is not None else [] + previous_token_ids = ( + tokenizer.encode(param_dict["previous_text"]) + if param_dict["previous_text"] is not None + else [] + ) current_token_ids = tokenizer.encode(param_dict["current_text"]) delta_token_ids = tokenizer.encode(param_dict["delta_text"]) - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) response = parser.extract_reasoning_content_streaming( previous_text=param_dict["previous_text"], @@ -339,8 +336,7 @@ def test_streaming_subcases(param_dict): ) # Streaming currently expects at least one of reasoning content / content, # so the response should return None in that case. 
- if param_dict["reasoning_content"] is None and param_dict[ - "content"] is None: + if param_dict["reasoning_content"] is None and param_dict["content"] is None: assert response is None else: assert isinstance(response, DeltaMessage) diff --git a/tests/reasoning/test_hunyuan_reasoning_parser.py b/tests/reasoning/test_hunyuan_reasoning_parser.py index f9238267f02ed..b7e3ea73ccdef 100644 --- a/tests/reasoning/test_hunyuan_reasoning_parser.py +++ b/tests/reasoning/test_hunyuan_reasoning_parser.py @@ -13,15 +13,13 @@ START_RESPONSE = "\n</think>\n<answer>\n" END_RESPONSE = "\n</answer>" NO_REASONING_QUICK_THROUGHT = { - "output": - f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "output": f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}", # noqa: E501 "reasoning_content": None, "content": "This is the rest", } SIMPLE_REASONING = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -42,14 +40,12 @@ NO_REASONING = { "content": "This is content", } MULTIPLE_LINES = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } REASONING_WITH_THINK = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -59,8 +55,7 @@ COMPLETE_REASONING_WITH_THINK = { "content": None, } MULTIPLE_LINES_WITH_THINK = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } @@ -122,9 +117,7 @@ TEST_CASES = [ NO_REASONING, id="no_reasoning_streaming", ), - pytest.param(True, - NO_REASONING_QUICK_THROUGHT, - id="no_reasoning_quick_stream"), + pytest.param(True, NO_REASONING_QUICK_THROUGHT, id="no_reasoning_quick_stream"), pytest.param( True, MULTIPLE_LINES, @@ -148,8 +141,9 @@ TEST_CASES = [ ] # Global tokenizer initialization to avoid repeated loading -tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", - trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained( + "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True +) @pytest.mark.parametrize("streaming, param_dict", TEST_CASES) @@ -162,12 +156,13 @@ def test_reasoning( output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] diff --git a/tests/reasoning/test_mistral_reasoning_parser.py 
b/tests/reasoning/test_mistral_reasoning_parser.py index 91a22f6f5d720..ff7f94b40ee11 100644 --- a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -2,9 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from mistral_common.tokens.tokenizers.base import SpecialTokens -from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, - Tekkenizer) from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -15,29 +12,9 @@ parser_name = "mistral" @pytest.fixture(scope="module") def mistral_tokenizer(): - # TODO(Julien): upon model release change to a tokenizer already configured. - # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507") - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= + "mistralai/Magistral-Small-2509" + ) return mistral_tokenizer @@ -290,39 +267,45 @@ def test_mistral_reasoning( if index_think != -1: output_before_think = output[:index_think] output_tokens += mistral_tokenizer.tokenizer.encode( - output_before_think, False, False) + output_before_think, False, False + ) output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK] if index_end_think != -1: - output_middle = output[index_think + len_think:index_end_think] - output_after_think = output[index_end_think + len_end_think:] + output_middle = output[index_think + len_think : index_end_think] + output_after_think = output[index_end_think + len_end_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_middle, False, False) + output_middle, False, False + ) output_tokens += [mistral_tokenizer.instruct.END_THINK] output_tokens += mistral_tokenizer.tokenizer.encode( - output_after_think, False, False) + output_after_think, False, False + ) else: - output_middle = output[index_think + len_think:] + output_middle = output[index_think + len_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_middle, False, False) + output_middle, False, False + ) elif index_end_think != -1: output_before_think = output[:index_end_think] - output_after_think = output[index_end_think + len_end_think:] + output_after_think = output[index_end_think + len_end_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_before_think, False, False) + output_before_think, False, False + ) output_tokens += [mistral_tokenizer.instruct.END_THINK] output_tokens += mistral_tokenizer.tokenizer.encode( - output_after_think, False, False) + output_after_think, False, 
False + ) else: - output_tokens += mistral_tokenizer.tokenizer.encode( - output, False, False) + output_tokens += mistral_tokenizer.tokenizer.encode(output, False, False) - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(mistral_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + mistral_tokenizer + ) - reasoning, content = run_reasoning_extraction_mistral(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction_mistral( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -335,7 +318,8 @@ def test_mistral_reasoning( if param_dict["content"] is not None: content = parser.extract_content_ids(output_tokens) assert content == mistral_tokenizer.tokenizer.encode( - param_dict["content"], bos=False, eos=False) + param_dict["content"], bos=False, eos=False + ) else: content = parser.extract_content_ids(output_tokens) assert content == [] diff --git a/tests/reasoning/test_olmo3_reasoning_parser.py b/tests/reasoning/test_olmo3_reasoning_parser.py new file mode 100644 index 0000000000000..4a2eca994610e --- /dev/null +++ b/tests/reasoning/test_olmo3_reasoning_parser.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "olmo3" +START_REASONING = "<think>" +END_REASONING = "</think>" + +NO_REASONING = { + "output": f"{START_REASONING}{END_REASONING}No thoughts, head empty!", + "reasoning_content": None, + "content": "No thoughts, head empty!", +} + +NO_REASONING_WITH_NEWLINE = { + "output": f"{START_REASONING}\n{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": "\n", + "content": "\n\nNo thoughts, head empty!", +} + +SIMPLE_REASONING = { + "output": f"{START_REASONING}This is a reasoning section{END_REASONING}This is the rest", # noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} + +SIMPLE_REASONING_WITH_NEWLINE = { + "output": f"{START_REASONING} Look!\n\nI'm thinking...{END_REASONING}\nThis is the rest", # noqa: E501 + "reasoning_content": " Look!\n\nI'm thinking...", + "content": "\nThis is the rest", +} + +SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES = { + "output": f"{START_REASONING}\nLook!\nI'm thinking...\n\n{END_REASONING}\n\n\nThis is the rest", # noqa: E501 + "reasoning_content": "\nLook!\nI'm thinking...\n\n", + "content": "\n\n\nThis is the rest", +} + +NO_REASONING_ONLY_END_THINK = { + "output": f"{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": None, + "content": "\n\nNo thoughts, head empty!", +} + +REASONING_ONLY_END_THINK = { + "output": f"The user is asking me not to think.{END_REASONING}No thoughts!", + "reasoning_content": "The user is asking me not to think.", + "content": "No thoughts!", +} + +TEST_CASES = [ + pytest.param( + False, # not streaming + NO_REASONING, + id="no_reasoning", + ), + pytest.param( + False, # not streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline", + ), + pytest.param( + False, # not streaming + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + False, # not streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline", + ), + 
pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines", + ), + pytest.param( + False, # not streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think", + ), + pytest.param( + False, # not streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think", + ), + pytest.param( + True, # enable streaming + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think_streaming", + ), + pytest.param( + True, # enable streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think_streaming", + ), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer") + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict[str, str], +): + output = tokenizer.tokenize(param_dict["output"]) + + # decode everything to tokens + model_output: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser: ReasoningParser = parser_cls(tokenizer) + + reasoning, content = run_reasoning_extraction( + reasoning_parser=parser, model_output=model_output, streaming=streaming + ) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/reasoning/test_qwen3_reasoning_parser.py b/tests/reasoning/test_qwen3_reasoning_parser.py index 2d5557d5cdc13..c06e40d72de2c 100644 --- a/tests/reasoning/test_qwen3_reasoning_parser.py +++ b/tests/reasoning/test_qwen3_reasoning_parser.py @@ -50,8 +50,7 @@ COMPLETE_REASONING = { "content": None, } MULTILINE_REASONING = { - "output": - "<think>This is a reasoning\nsection</think>This is the rest\nThat", + "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat", "reasoning_content": "This is a reasoning\nsection", "content": "This is the rest\nThat", } @@ -131,12 +130,13 @@ def test_reasoning( output_tokens: list[str] = [ qwen3_tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(qwen3_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + qwen3_tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] diff --git a/tests/reasoning/test_seedoss_reasoning_parser.py b/tests/reasoning/test_seedoss_reasoning_parser.py new file mode 100644 index 0000000000000..b356b8545f412 --- /dev/null +++ b/tests/reasoning/test_seedoss_reasoning_parser.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, cast + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "seed_oss" +start_token = "<seed:think>" +end_token = "</seed:think>" + +# Use a test model that contains our custom tokens +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + + +@pytest.fixture(scope="module") +def seedoss_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + # Add custom SeedOSS tokens if they don't exist + if start_token not in tokenizer.get_vocab(): + tokenizer.add_tokens([start_token, end_token]) + return tokenizer + + +SIMPLE_REASONING: dict[str, Any] = { + "output": "This is a reasoning section</seed:think>This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING: dict[str, Any] = { + "output": "This is a reasoning section</seed:think>", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +NO_CONTENT: dict[str, Any] = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING_STREAMING: dict[str, Any] = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES: dict[str, Any] = { + "output": "This\nThat</seed:think>This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +WITH_START_TOKEN: dict[str, Any] = { + "output": ("<seed:think>This is a reasoning section</seed:think>This is the rest"), + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +ONLY_END_TOKEN: dict[str, Any] = { + "output": "Some reasoning</seed:think>This is the rest", + "reasoning_content": "Some reasoning", + "content": "This is the rest", + "is_reasoning_end": True, +} +NO_TOKENS: dict[str, Any] = { + "output": "This is just content without any reasoning tokens", + "reasoning_content": "This is just content without any reasoning tokens", + "content": None, + "is_reasoning_end": False, +} + + +def test_seedoss_reasoning_parser_creation(seedoss_tokenizer): + """Test that the SeedOSS reasoning parser can be created and registered.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + assert isinstance(parser, ReasoningParser) + assert parser.start_token == start_token + assert parser.end_token == end_token + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_simple_reasoning(seedoss_tokenizer, streaming): + """Test basic reasoning extraction with both tokens.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming + ) + + assert reasoning == SIMPLE_REASONING["reasoning_content"] + assert content == SIMPLE_REASONING["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_complete_reasoning(seedoss_tokenizer, streaming): + """Test reasoning extraction when there's no content after reasoning.""" + 
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming + ) + + assert reasoning == COMPLETE_REASONING["reasoning_content"] + assert content == COMPLETE_REASONING["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_content(seedoss_tokenizer, streaming): + """Test when there's no end token - everything is reasoning content.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, NO_CONTENT["output"])], streaming=streaming + ) + + assert reasoning == NO_CONTENT["reasoning_content"] + assert content == NO_CONTENT["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_multiple_lines(seedoss_tokenizer, streaming): + """Test reasoning extraction with multiline content.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming + ) + + assert reasoning == MULTIPLE_LINES["reasoning_content"] + assert content == MULTIPLE_LINES["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_with_start_token(seedoss_tokenizer, streaming): + """Test reasoning extraction with both start and end tokens.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming + ) + + assert reasoning == WITH_START_TOKEN["reasoning_content"] + assert content == WITH_START_TOKEN["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_only_end_token(seedoss_tokenizer, streaming): + """ + Test reasoning extraction with only end token + (SeedOSS typical behavior). 
+ """ + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming + ) + + assert reasoning == ONLY_END_TOKEN["reasoning_content"] + assert content == ONLY_END_TOKEN["content"] + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tokens(seedoss_tokenizer, streaming): + """Test when there are no reasoning tokens at all.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + reasoning, content = run_reasoning_extraction( + parser, [cast(str, NO_TOKENS["output"])], streaming=streaming + ) + + assert reasoning == NO_TOKENS["reasoning_content"] + assert content == NO_TOKENS["content"] + + +def test_is_reasoning_end(seedoss_tokenizer): + """Test the is_reasoning_end method.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + # Test with end token present + end_token_id = parser.end_token_id + assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True + + # Test without end token + assert parser.is_reasoning_end([1, 2, 3, 4]) is False + + +def test_extract_content_ids(seedoss_tokenizer): + """Test the extract_content_ids method.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + end_token_id = parser.end_token_id + + # Test with end token in the middle + input_ids = [1, 2, end_token_id, 4, 5] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [4, 5] + + # Test with end token at the end + input_ids = [1, 2, 3, end_token_id] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + # Test without end token + input_ids = [1, 2, 3, 4] + content_ids = parser.extract_content_ids(input_ids) + assert content_ids == [] + + +def test_streaming_delta_processing(seedoss_tokenizer): + """Test streaming processing with small deltas.""" + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser = parser_cls(seedoss_tokenizer) + + # Test streaming with incremental tokens + deltas = ["Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer"] + + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) + + assert reasoning == "Some reasoning content" + assert content == "Final answer" diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index 9af5fa5addbc2..788136e996815 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -3,14 +3,12 @@ from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.reasoning import ReasoningParser from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: - def __init__(self): self.reasoning_content = None self.other_content = None @@ -19,8 +17,8 @@ class StreamingReasoningReconstructor: # content and the reasoning content should not be present # at the same time assert delta.content is None or delta.reasoning_content is None, ( - "Both content and reasoning content are present in the " - "delta message") + "Both content and reasoning content are present in the delta message" + ) if delta.content is not None: if self.other_content is None: 
self.other_content = delta.content @@ -51,7 +49,8 @@ def run_reasoning_extraction( ) else: reasoning, content = run_reasoning_extraction_nonstreaming( - reasoning_parser, model_output, request) + reasoning_parser, model_output, request + ) return reasoning, content @@ -61,8 +60,9 @@ def run_reasoning_extraction_mistral( request: Union[ChatCompletionRequest, None] = None, streaming: bool = False, ) -> tuple[Optional[str], Optional[str]]: - assert isinstance(reasoning_parser.model_tokenizer, - MistralTokenizer), type(reasoning_parser.model_tokenizer) + assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( + reasoning_parser.model_tokenizer + ) if streaming: reconstructor = run_reasoning_extraction_streaming_mistral( reasoning_parser, @@ -75,9 +75,11 @@ def run_reasoning_extraction_mistral( ) else: str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens( - model_output) + model_output + ) reasoning, content = run_reasoning_extraction_nonstreaming( - reasoning_parser, str_output, request) + reasoning_parser, str_output, request + ) return reasoning, content @@ -88,7 +90,8 @@ def run_reasoning_extraction_nonstreaming( ) -> tuple[Optional[str], Optional[str]]: request = request or ChatCompletionRequest(messages=[], model="test-model") return reasoning_parser.extract_reasoning_content( - model_output=''.join(model_output), request=request) + model_output="".join(model_output), request=request + ) def run_reasoning_extraction_streaming( @@ -128,16 +131,16 @@ def run_reasoning_extraction_streaming_mistral( model_deltas: list[int], request: Union[ChatCompletionRequest, None] = None, ) -> StreamingReasoningReconstructor: - assert isinstance(reasoning_parser.model_tokenizer, - MistralTokenizer), type(reasoning_parser.model_tokenizer) + assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( + reasoning_parser.model_tokenizer + ) request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingReasoningReconstructor() previous_text = "" previous_tokens: list[int] = [] for model_delta in model_deltas: token_delta = [model_delta] - delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens( - [model_delta])[0] + delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens([model_delta])[0] current_text = previous_text + delta current_tokens = previous_tokens + token_delta delta_message = reasoning_parser.extract_reasoning_content_streaming( diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index bdf48c7687b25..78f5ab3e2d19c 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -10,13 +10,6 @@ from transformers import AutoModelForSeq2SeqLM from vllm.assets.audio import AudioAsset - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - # FIXME(zhuohan): The test can not pass if we: # 1. Increase max_tokens to 256. # 2. Increase beam_width to 8. 
@@ -43,19 +36,21 @@ def test_beam_search_single_input( ) -> None: example_prompts = example_prompts[:1] with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, - max_tokens) + hf_outputs = hf_model.generate_beam_search( + example_prompts, beam_width, max_tokens + ) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search( + example_prompts, beam_width, max_tokens + ) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for j, (hf_text, - vllm_text) in enumerate(zip(hf_output_texts, - vllm_output_texts)): + for j, (hf_text, vllm_text) in enumerate( + zip(hf_output_texts, vllm_output_texts) + ): print(f">>>{j}-th hf output:") print(hf_text) print(f">>>{j}-th vllm output:") @@ -63,8 +58,62 @@ def test_beam_search_single_input( assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], ( - f"Test{i} output{j}:\nHF: {hf_output_ids}\n" - f"vLLM: {vllm_output_ids}") + f"Test{i} output{j}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}" + ) + + +@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_with_concurrency_limit( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + # example_prompts[1]&[3]&[7] fail for an unknown reason even without the + # concurrency limit; skip them for now. 
+ example_prompts = example_prompts[:8] + concurrency_limit = 2 + assert len(example_prompts) > concurrency_limit + with vllm_runner(model, dtype=dtype) as vllm_model: + outputs_with_limit = vllm_model.generate_beam_search( + example_prompts, beam_width, max_tokens, concurrency_limit=concurrency_limit + ) + outputs_without_limit = [] + + for i in range(0, len(example_prompts), concurrency_limit): + outputs_without_limit.extend( + vllm_model.generate_beam_search( + example_prompts[i : i + concurrency_limit], beam_width, max_tokens + ) + ) + + correct = True + for i in range(len(example_prompts)): + output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] + output_ids_without_limit, output_texts_without_limit = outputs_without_limit[i] + for j, (text_with_limit, text_without_limit) in enumerate( + zip(output_texts_with_limit, output_texts_without_limit) + ): + print(f">>>{j}-th with limit output:") + print(text_with_limit) + print(f">>>{j}-th without limit output:") + print(text_without_limit) + assert len(output_ids_with_limit) == len(output_ids_without_limit) + for j in range(len(output_ids_with_limit)): + if output_ids_with_limit[j] != output_ids_without_limit[j]: + print( + f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}" + ) + correct = False + assert correct @pytest.mark.parametrize("dtype", ["half"]) @@ -85,11 +134,10 @@ def test_beam_search_passes_multimodal_data( model = "Qwen/Qwen2-Audio-7B-Instruct" audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>" prompts = [ - f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501 + f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 ] - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model: audio_token_id = hf_model.config.audio_token_index eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|> hf_outputs = hf_model.generate_beam_search( @@ -107,17 +155,15 @@ def test_beam_search_passes_multimodal_data( audios=audios, ) - seq_with_no_audio_toks = lambda seq: [ - tok for tok in seq if tok != audio_token_id - ] + seq_with_no_audio_toks = lambda seq: [tok for tok in seq if tok != audio_token_id] for i in range(len(prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for j, (hf_text, - vllm_text) in enumerate(zip(hf_output_texts, - vllm_output_texts)): + for j, (hf_text, vllm_text) in enumerate( + zip(hf_output_texts, vllm_output_texts) + ): print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:") print(hf_text) print(f">>>{j}-th vllm output:") @@ -130,12 +176,10 @@ def test_beam_search_passes_multimodal_data( # token to match features, while the vLLM helper maintains the # single audio token in the input text filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j]) - filtered_vllm_output_ids = seq_with_no_audio_toks( - vllm_output_ids[j]) + filtered_vllm_output_ids = seq_with_no_audio_toks(vllm_output_ids[j]) # HF output IDs may contain the end of sequence - if len(filtered_hf_output_ids - ) == len(filtered_vllm_output_ids) + 1: + if len(filtered_hf_output_ids) == len(filtered_vllm_output_ids) + 1: assert filtered_hf_output_ids[-1] == eos_token_id filtered_hf_output_ids = filtered_hf_output_ids[:-1] diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 
ea4a17dd2306f..d1609b24cc5a8 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -9,13 +9,6 @@ import pytest from vllm import SamplingParams - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - # We also test with llama because it has generation_config to specify EOS # (past regression). MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"] @@ -32,11 +25,11 @@ def test_ignore_eos( max_tokens: int, ) -> None: with vllm_runner(model, dtype=dtype) as vllm_model: - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) for prompt in example_prompts: ignore_eos_output = vllm_model.llm.generate( - prompt, sampling_params=sampling_params) + prompt, sampling_params=sampling_params + ) output_length = len(ignore_eos_output[0].outputs[0].token_ids) assert output_length == max_tokens diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py deleted file mode 100644 index 87f40b1005312..0000000000000 --- a/tests/samplers/test_logprobs.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm import SamplingParams - -from ..conftest import VllmRunner - -MODELS = ["distilbert/distilgpt2"] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module is V0 only since it uses dtype=float, so - set VLLM_USE_V1=0 for all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", - ["float"]) # needed for comparing logprobs with HF -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size -@pytest.mark.parametrize("detokenize", [True, False]) -def test_get_prompt_logprobs( - hf_runner, - vllm_runner, - model, - dtype, - chunked_prefill_token_size: int, - num_top_logprobs: int, - detokenize: bool, - example_prompts, -): - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - - max_tokens = 5 - with hf_runner(model, dtype=dtype) as hf_model: - hf_logprobs = hf_model.generate_greedy_logprobs( - example_prompts, - max_tokens=max_tokens, - ) - - with vllm_runner( - model, - dtype=dtype, - max_logprobs=num_top_logprobs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - ) as vllm_model: - vllm_sampling_params = SamplingParams(max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_top_logprobs, - temperature=0.0, - detokenize=detokenize) - vllm_results = vllm_model.llm.generate( - example_prompts, sampling_params=vllm_sampling_params) - - # Test whether logprobs are included in the results. 
- for result in vllm_results: - assert result.prompt_logprobs is not None - assert result.outputs[0].logprobs is not None - assert len(result.outputs[0].logprobs) == max_tokens - for logprobs in result.outputs[0].logprobs: - # If the output token is not included in the top X - # logprob, it can return 1 more data - assert (len(logprobs) == num_top_logprobs - or len(logprobs) == num_top_logprobs + 1) - output_text = result.outputs[0].text - output_string_from_most_likely_tokens_lst: list[str] = [] - for top_logprobs in result.outputs[0].logprobs: - top_logprob = next(iter(top_logprobs.values())) - output_string_from_most_likely_tokens_lst.append( - top_logprob.decoded_token) - - if detokenize: - output_string_from_most_likely_tokens = "".join( - output_string_from_most_likely_tokens_lst) - assert output_text == output_string_from_most_likely_tokens, ( - "The output text from the top logprob for each token position " - "should be the same as the output text in the result.") - else: - assert output_text == '' - assert output_string_from_most_likely_tokens_lst == ([None] * - max_tokens) - - # The first prompt logprob is always None - assert result.prompt_logprobs[0] is None - for prompt_logprobs in result.prompt_logprobs[1:]: - # If the prompt token is not included in the top X - # logprob, it can return 1 more data - assert (len(prompt_logprobs) == num_top_logprobs - or len(prompt_logprobs) == num_top_logprobs + 1) - - # Test whether prompt logprobs are consistent with HF - for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): - # Check prompt logprobs - # The first prompt logprob is always None, so we compare it from 1:. - vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] - for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): - for token_id, logprob in vllm_prompt_logprob_dict.items(): - torch.testing.assert_close(logprob.logprob, - hf_logprob[0][i][token_id].item(), - atol=1e-2, - rtol=1e-2) - vllm_sample_logprobs = vllm_result.outputs[0].logprobs - for i, top_logprobs in enumerate(vllm_sample_logprobs): - for token_id, sample_logprob in top_logprobs.items(): - logprob = sample_logprob.logprob - torch.testing.assert_close(logprob, - hf_logprob[i][-1][token_id].item(), - atol=1e-2, - rtol=1e-2) - if detokenize: - assert isinstance(sample_logprob.decoded_token, str), ( - "The token should be decoded by the time it is returned" - " to the user.") - - # Test if prompt logprobs are correctly set. - for vllm_result in vllm_results: - token_ids = vllm_result.prompt_token_ids - prompt_logprobs = vllm_result.prompt_logprobs - - # The first token doesn't have logprob. 
- assert prompt_logprobs[0] is None - - for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]): - assert token_id in logprob_dict - - -def test_max_logprobs(): - runner = VllmRunner("facebook/opt-125m", max_logprobs=1) - vllm_sampling_params = SamplingParams(logprobs=1) - # should pass - runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - - bad_sampling_params = SamplingParams(logprobs=2) - with pytest.raises(ValueError): - runner.generate(["Hello world"], sampling_params=bad_sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("detokenize", [True, False]) -def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, - detokenize: bool, example_prompts): - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - max_tokens = 5 - - with vllm_runner( - model, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - ) as vllm_model: - sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, - logprobs=None, - temperature=0.0, - detokenize=detokenize) - results_logprobs_none = vllm_model.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_none) - - for i in range(len(results_logprobs_none)): - assert results_logprobs_none[i].outputs[0].logprobs is None - assert results_logprobs_none[i].outputs[0].cumulative_logprob is None diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 128e8f552a161..fa0ca48f9bd9c 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -5,20 +5,14 @@ Run `pytest tests/samplers/test_no_bad_words.py`. 
""" + from typing import Optional -import pytest from transformers import AutoTokenizer from vllm import LLM, SamplingParams -@pytest.fixture(autouse=True) -def v1(monkeypatch): - """Only run on vLLM v1.""" - monkeypatch.setenv('VLLM_USE_V1', '1') - - def _generate( llm: LLM, prompt: str, @@ -49,25 +43,24 @@ class TestOneTokenBadWord: TARGET_TOKEN = "you" def setup_method(self, method): - self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, - add_prefix_space=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.MODEL, add_prefix_space=True + ) self.num_prompt_tokens = len(self._encode(self.PROMPT)) - self.target_token_id = self._encode(self.TARGET_TOKEN, - add_special_tokens=False)[0] + self.target_token_id = self._encode( + self.TARGET_TOKEN, add_special_tokens=False + )[0] def test_one_token_bad_word(self, vllm_runner): with vllm_runner(self.MODEL) as llm: output_token_ids = self._generate(llm) assert output_token_ids[0] == self.target_token_id - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN]) assert self.target_token_id not in output_token_ids - def _generate(self, - llm: LLM, - bad_words: Optional[list[str]] = None) -> list[int]: + def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, @@ -75,11 +68,8 @@ class TestOneTokenBadWord: bad_words=bad_words, ) - def _encode(self, - prompt: str, - add_special_tokens: bool = True) -> list[int]: - return self.tokenizer(prompt, - add_special_tokens=add_special_tokens).input_ids + def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]: + return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids class TestTwoTokenBadWord: @@ -92,72 +82,80 @@ class TestTwoTokenBadWord: NEIGHBOUR_TOKEN2 = "older" def setup_method(self, method): - self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, - add_prefix_space=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.MODEL, add_prefix_space=True + ) self.num_prompt_tokens = len(self._encode(self.PROMPT)) - self.target_token_id1 = self._encode(self.TARGET_TOKEN1, - add_special_tokens=False)[0] - self.target_token_id2 = self._encode(self.TARGET_TOKEN2, - add_special_tokens=False)[0] - self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2, - add_special_tokens=False)[0] + self.target_token_id1 = self._encode( + self.TARGET_TOKEN1, add_special_tokens=False + )[0] + self.target_token_id2 = self._encode( + self.TARGET_TOKEN2, add_special_tokens=False + )[0] + self.neighbour_token_id2 = self._encode( + self.NEIGHBOUR_TOKEN2, add_special_tokens=False + )[0] def test_two_token_bad_word(self, vllm_runner): with vllm_runner(self.MODEL, dtype="half") as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN1]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN1]) assert self.target_token_id1 not in output_token_ids - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN2]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN2]) assert output_token_ids[0] == self.target_token_id1 assert self.target_token_id2 not in output_token_ids output_token_ids = self._generate( - llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}']) + llm, 
bad_words=[f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"] + ) assert output_token_ids[0] == self.target_token_id1 assert output_token_ids[:2] != [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.target_token_id2]) + output_token_ids, [self.target_token_id1, self.target_token_id2] + ) # Model dependent behaviour assert output_token_ids[:2] == [ - self.target_token_id1, self.neighbour_token_id2 + self.target_token_id1, + self.neighbour_token_id2, ] output_token_ids = self._generate( llm, bad_words=[ - f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}', - f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}' - ]) + f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}", + f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}", + ], + ) assert output_token_ids[0] == self.target_token_id1 assert output_token_ids[:2] != [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.target_token_id2]) + output_token_ids, [self.target_token_id1, self.target_token_id2] + ) assert output_token_ids[:2] != [ - self.target_token_id1, self.neighbour_token_id2 + self.target_token_id1, + self.neighbour_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.neighbour_token_id2]) - assert ((self.target_token_id2 in output_token_ids) - or (self.neighbour_token_id2 in output_token_ids)) + output_token_ids, [self.target_token_id1, self.neighbour_token_id2] + ) + assert (self.target_token_id2 in output_token_ids) or ( + self.neighbour_token_id2 in output_token_ids + ) - def _generate(self, - llm: LLM, - bad_words: Optional[list[str]] = None) -> list[int]: + def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, @@ -187,8 +185,5 @@ class TestTwoTokenBadWord: return False - def _encode(self, - prompt: str, - add_special_tokens: bool = True) -> list[int]: - return self.tokenizer(prompt, - add_special_tokens=add_special_tokens).input_ids + def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]: + return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index 86fc14dc85f80..1359e6403e4c3 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -8,12 +8,6 @@ from vllm import SamplingParams MODELS = ["distilbert/distilgpt2"] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_ranks( @@ -26,25 +20,27 @@ def test_ranks( num_top_logprobs = 5 num_prompt_logprobs = 5 - with vllm_runner(model, dtype=dtype, - max_logprobs=num_top_logprobs) as vllm_model: - + with vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) as vllm_model: ## Test greedy logprobs ranks vllm_sampling_params = SamplingParams( temperature=0.0, top_p=1.0, max_tokens=max_tokens, logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) - vllm_results = vllm_model.generate_w_logprobs(example_prompts, - vllm_sampling_params) + prompt_logprobs=num_prompt_logprobs, + ) + vllm_results = vllm_model.generate_w_logprobs( + example_prompts, vllm_sampling_params + ) ## Test non-greedy logprobs ranks - sampling_params = 
SamplingParams(temperature=1.0, - top_p=1.0, - max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) + sampling_params = SamplingParams( + temperature=1.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs, + ) res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) for result in vllm_results: diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py deleted file mode 100644 index 520b88d03ac8e..0000000000000 --- a/tests/samplers/test_sampler.py +++ /dev/null @@ -1,769 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools -import random -from dataclasses import dataclass -from typing import Optional -from unittest.mock import Mock, patch - -import pytest -import torch -from transformers import GenerationConfig, GenerationMixin - -import vllm.envs as envs -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter, is_pin_memory_available - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This file tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -class MockLogitsSampler(Sampler): - - def __init__(self, fake_logits: torch.Tensor): - super().__init__() - self.fake_logits = fake_logits - - def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, VOCAB_SIZE), - 1e-2, - dtype=input_tensor.dtype) - sampler = MockLogitsSampler(fake_logits) - return input_tensor, fake_logits, sampler - - -VOCAB_SIZE = 32000 -RANDOM_SEEDS = list(range(128)) -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -def _do_sample( - batch_size: int, - input_tensor: torch.Tensor, - sampler: MockLogitsSampler, - sampling_params: SamplingParams, - device: str, -): - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params, - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_greedy(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler = _prepare_test(batch_size) - - sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - expected = torch.argmax(fake_logits, dim=-1) - for i, sequence_output in 
enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == expected[i].item() - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - for i in range(batch_size): - fake_logits[i, i] = 1e2 - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - for i, sequence_output in enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == i - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random_seed(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - for i in range(batch_size): - fake_logits[i, i] = 1e2 - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - for i, sequence_output in enumerate(sampler_output): - for nth_output in sequence_output.samples: - assert nth_output.output_token == i - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_random_seed_deterministic(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - assert first_sampler_output == second_sampler_output - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_min_tokens_penalty(seed: int, device: str): - seq_id_counter = Counter(start=random.randint(0, 100)) - set_random_seed(seed) - torch.set_default_device(device) - - def create_sampling_params(min_tokens, - eos_token_id=0, - *, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams( - min_tokens=min_tokens, - max_tokens=9999, # keep higher than max of min_tokens - stop_token_ids=stop_token_ids, - # requesting prompt_logprobs changes the structure of `logits` - prompt_logprobs=prompt_logprobs, - ) - sampling_params.all_stop_token_ids.add(eos_token_id) - return sampling_params - - def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceData.from_seqs( - random.choices(range(0, VOCAB_SIZE), k=num_input)) - if num_generated > 0: - seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), - k=num_generated) - return seq_data - - def generate_test_case(): - # generate multiple seq groups but limit total batch size - batch_size = random.randint(1, 128) - - expected_penalization = [] - sequence_metadata_list: list[SequenceGroupMetadata] = [] - # 20% chance to generate seq group metadata list with 
all prompts - is_prompt = random.random() < 0.2 - while batch_size > 0: - num_seqs = 1 if is_prompt else random.randint(1, batch_size) - - eos_token_id = random.randint(0, VOCAB_SIZE - 1) - min_tokens = random.randint(0, 50) - num_stop_tokens = random.randint(0, 8) - if num_stop_tokens > 0: - stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1), - k=num_stop_tokens) - else: - stop_token_ids = None - - sampling_params = create_sampling_params( - min_tokens=min_tokens, - eos_token_id=eos_token_id, - stop_token_ids=stop_token_ids) - - seq_data: dict[int, SequenceData] = {} - seq_group_penalization: list[bool] = [] - for _ in range(num_seqs): - num_input = random.randint(1, 100) - num_generated = 0 if is_prompt else random.randint(1, 100) - seq_data[next(seq_id_counter)] = create_sequence_data( - num_input=num_input, num_generated=num_generated) - seq_group_penalization.append(num_generated < min_tokens) - - expected_penalization.extend(seq_group_penalization) - sequence_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{batch_size}", - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=sampling_params, - block_tables={}, - )) - batch_size -= num_seqs - - return { - "expected_penalization": expected_penalization, - "seq_group_metadata_list": sequence_metadata_list, - } - - # define some explicit test cases for edge case behavior - prompt_without_penalization = { - "expected_penalization": [False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params(0), - block_tables={}, - ), - ] - } - - prompt_with_penalization = { - "expected_penalization": [True], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params(1), - block_tables={}, - ), - ] - } - - prompt_with_penalization_and_prompt_logprobs = { - "expected_penalization": [False, False, True], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(num_input=3), - }, - sampling_params=create_sampling_params(1, prompt_logprobs=3), - block_tables={}, - ), - ] - } - - stop_penalizing_after_min_tokens = { - "expected_penalization": [False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_1", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=1), - }, - sampling_params=create_sampling_params(1), - block_tables={}, - ) - ] - } - - stop_token_ids = [42, 99, 42, 0] # intentional duplication - prompt_combination = { - "expected_penalization": [False, True, False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - request_id="test_2", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(num_input=2), - }, - sampling_params=create_sampling_params(1, prompt_logprobs=3), - block_tables={}, - ), - SequenceGroupMetadata( - request_id="test_3", - is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, - sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), - block_tables={}, - ) - ] - } - - stop_token_ids = [1, 999, 37, 37] # intentional duplication - decode_combination = { - "expected_penalization": [True, False, False, True, False], - "seq_group_metadata_list": [ - SequenceGroupMetadata( - 
request_id="test_1", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=1), - next(seq_id_counter): - create_sequence_data(num_generated=100), - }, - sampling_params=create_sampling_params( - 2, stop_token_ids=stop_token_ids), - block_tables={}, - ), - SequenceGroupMetadata( - request_id="test_2", - is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=20), - next(seq_id_counter): - create_sequence_data(num_generated=1), - next(seq_id_counter): - create_sequence_data(num_generated=10), - }, - sampling_params=create_sampling_params( - 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), - block_tables={}, - ), - ] - } - - if seed == 0: - test_cases = [ - prompt_without_penalization, - prompt_with_penalization, - prompt_with_penalization_and_prompt_logprobs, - stop_penalizing_after_min_tokens, - prompt_combination, - decode_combination, - ] - else: - test_cases = [generate_test_case()] - - def run_test_case(*, expected_penalization: list[bool], - seq_group_metadata_list: list[SequenceGroupMetadata]): - assert expected_penalization, \ - "Invalid test case, need expected_penalization" - assert seq_group_metadata_list, \ - "Invalid test case, need seq_group_metadata_list" - - batch_size = 0 - seq_lens: list[int] = [] - sampling_params_per_row: list[SamplingParams] = [] - for sgm in seq_group_metadata_list: - sampling_params = sgm.sampling_params - - num_rows = len(sgm.seq_data) - if sgm.is_prompt: - # a prompt seq_group has only one sequence - seq_data = next(iter(sgm.seq_data.values())) - prompt_len = seq_data.get_prompt_len() - seq_lens.append(prompt_len) - - assert sgm.sampling_params is not None - if sgm.sampling_params.prompt_logprobs: - # with prompt_logprobs each token in the prompt has a row in - # logits - num_rows = prompt_len - - batch_size += num_rows - sampling_params_per_row.extend( - itertools.repeat(sampling_params, num_rows)) - - assert len( - expected_penalization - ) == batch_size, \ - ("Invalid test case, expected_penalization does not match computed" - "batch size") - - _, fake_logits, sampler = _prepare_test(batch_size) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens=seq_lens if seq_lens else None, - query_lens=seq_lens if seq_lens else [1] * batch_size, - device=device, - pin_memory=is_pin_memory_available()) - # the logits tensor is modified in-place by the sampler - _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) - - for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_row)): - - tokens_to_check = sampling_params.all_stop_token_ids - - if should_penalize: - for token_id in tokens_to_check: - assert fake_logits[logits_idx, token_id] == -float( - 'inf' - ), f"Expected token {token_id} for logits row {logits_idx}" - " to be penalized" - # no other tokens should be set to -inf - assert torch.count_nonzero( - fake_logits[logits_idx, :] == -float('inf')) == len( - tokens_to_check - ), f"Expected only {len(tokens_to_check)} to be penalized" - else: - # no tokens should be set to -inf - assert torch.count_nonzero( - fake_logits[logits_idx, :] == - -float('inf')) == 0, "No tokens should have been penalized" - - for test_case in test_cases: - run_test_case(**test_case) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_mixed(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - 
batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler = _prepare_test(batch_size) - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - expected_tokens: list[Optional[list[int]]] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - expected: Optional[list[int]] = None - sampling_type = random.randint(0, 2) - if sampling_type == 0: - sampling_params = SamplingParams(temperature=0) - expected = [int(torch.argmax(fake_logits[i], dim=-1).item())] - elif sampling_type in (1, 2): - n = random.randint(1, 10) - sampling_params = SamplingParams( - temperature=random.random() + 0.1, - top_p=min(random.random() + 0.1, 1), - top_k=random.randint(0, 10), - n=n, - presence_penalty=random.randint(0, 1), - ) - if sampling_type == 2: - sampling_params.seed = random.randint(0, 10000) - else: - for idx in range(n): - fake_logits[i, i + idx] = 1e2 - expected = list(range(i, i + n)) - - expected_tokens.append(expected) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params, - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - generators: dict[str, torch.Generator] = {} - - def test_sampling(): - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available(), - generators=generators) - sampler_output = sampler(logits=fake_logits, - sampling_metadata=sampling_metadata) - - for i, (sequence_output, metadata) in enumerate( - zip(sampler_output, seq_group_metadata_list)): - assert metadata.sampling_params is not None - - if (metadata.sampling_params.seed is not None - and expected_tokens[i] is None): - # Record seeded random result to compare with results of - # second invocation - expected_tokens[i] = [ - nth_output.output_token - for nth_output in sequence_output.samples - ] - continue - - expected_tokens_item = expected_tokens[i] - assert expected_tokens_item is not None - - for n, nth_output in enumerate(sequence_output.samples): - assert metadata.sampling_params is not None - - if (metadata.sampling_params.temperature == 0 - or metadata.sampling_params.seed is not None): - # Ensure exact matches for greedy or random with seed - assert nth_output.output_token == expected_tokens_item[n] - else: - # For non-seeded random check that one of the high-logit - # tokens were chosen - assert nth_output.output_token in expected_tokens_item - - # Test batch - test_sampling() - - # Shuffle the batch and resample - target_index = list(range(batch_size)) - for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, seq_lens): - random.Random(seed).shuffle(list_to_shuffle) - target_index = torch.tensor(target_index) - input_tensor.data = input_tensor.index_select(0, target_index) - fake_logits.data = fake_logits.index_select(0, target_index) - - # This time, results of seeded random samples will be compared with - # the corresponding sample in the pre-shuffled batch - test_sampling() - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_top_k_top_p(seed: int, device: str): - set_random_seed(seed) - batch_size = random.randint(1, 256) - top_k = random.randint(100, 500) - top_p = random.random() * 0.1 - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), - device=device, - dtype=torch.float16) - fake_logits = 
torch.normal(0, - 5, - size=(batch_size, vocab_size), - device=input_tensor.device, - dtype=input_tensor.dtype) - sampler = MockLogitsSampler(fake_logits) - - generation_model = GenerationMixin() - generation_config = GenerationConfig(top_k=top_k, - top_p=top_p, - do_sample=True) - - @dataclass - class MockConfig: - is_encoder_decoder: bool = False - - generation_model.config = MockConfig() # needed by the following method - generation_model._prepare_special_tokens(generation_config, device=device) - processors = generation_model._get_logits_processor(generation_config, - None, - None, - None, [], - device=device) - assert len(processors) == 2 # top_p and top_k - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams( - temperature=1, - top_k=top_k, - top_p=top_p, - ), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - - sample_probs = None - - def mock_sample(probs, *args, **kwargs): - nonlocal sample_probs - sample_probs = probs - return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] - for prob in probs], None) - - # top-k and top-p is only calculated when flashinfer kernel is not available - with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \ - patch("vllm.model_executor.layers.sampler." - "flashinfer_top_k_top_p_sampling", None): - sampler(logits=fake_logits, sampling_metadata=sampling_metadata) - - assert sample_probs is not None - - hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone()) - hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) - torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5) - assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_flashinfer_fallback(seed: int, device: str): - if not envs.VLLM_USE_FLASHINFER_SAMPLER: - pytest.skip("Flashinfer sampler is disabled") - - pytest.skip("After FlashInfer 0.2.3, sampling will never fail") - - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - def failing_flashinfer_sampling(*_args, **_kwargs): - return None, torch.zeros(batch_size, device=device, dtype=torch.int32) - - sampling_params = SamplingParams( - temperature=1.0, - n=random.randint(1, 10), - seed=random.randint(0, 10000), - ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - with patch( - "vllm.model_executor.layers.sampler." 
- "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling): - fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - - assert sampler_output == fallback_sampler_output - - -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_repetition_penalty_mixed(device: str): - - vocab_size = 8 - - def test_sampling_params(sampling_params: list[SamplingParams]): - - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - seq_lens: list[int] = [] - for i in range(2): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=sampling_params[i], - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - - fake_logits = torch.full((2, vocab_size), - 1e-2, - device=device, - dtype=torch.float16) - - fake_logits[:, 5] = 1.1e-2 - fake_logits[:, 1] = 1.2e-2 - - sampler = MockLogitsSampler(fake_logits) - - sampler_output = sampler(logits=fake_logits, - sampling_metadata=sampling_metadata) - - generated_tokens = [] - for output in sampler_output: - generated_tokens.append(output.samples[0].output_token) - - return generated_tokens - - # one configuration is greedy with repetition_penalty - sampling_params_rep = SamplingParams( - temperature=0.0, - repetition_penalty=2.0, - ) - - # other configuration is sampling w/o repetition_penalty - sampling_params_sample = SamplingParams( - temperature=1.0, - top_k=1, - seed=42, - ) - - tokens1 = test_sampling_params( - [sampling_params_rep, sampling_params_sample]) - - tokens2 = test_sampling_params( - [sampling_params_sample, sampling_params_rep]) - - assert tokens1[0] == tokens2[1] - assert tokens1[1] == tokens2[0] - - -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_include_gpu_probs_tensor(device: str): - set_random_seed(42) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - sampler.include_gpu_probs_tensor = True - sampler.should_modify_greedy_probs_inplace = False - - sampling_params = SamplingParams(temperature=0) - - mock_inplace = Mock() - with patch( - "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace", - mock_inplace): - - sampler_output = _do_sample(batch_size, fake_logits, sampler, - sampling_params, device) - mock_inplace.assert_not_called() - - assert sampler_output.sampled_token_probs is not None - assert sampler_output.logprobs is not None - assert sampler_output.sampled_token_ids is not None diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py deleted file mode 100644 index 5a0efd98acc16..0000000000000 --- a/tests/samplers/test_seeded_generate.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Verify that seeded random sampling is deterministic. - -Run `pytest tests/samplers/test_seeded_generate.py`. 
-""" -import copy -import random -from itertools import combinations - -import pytest - -from vllm import SamplingParams -from vllm.model_executor.utils import set_random_seed - -MODEL = "facebook/opt-125m" -RANDOM_SEEDS = list(range(5)) - - -@pytest.fixture -def vllm_model(vllm_runner, monkeypatch): - # This file relies on V0 internals. - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner(MODEL, dtype="half") as vllm_model: - yield vllm_model - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_random_sample_with_seed( - vllm_model, - example_prompts, - seed: int, -) -> None: - set_random_seed(seed) - - sampling_params = SamplingParams( - # Parameters to ensure sufficient randomness - temperature=3.0, - top_p=min(random.random() + 0.3, 1), - top_k=random.randint(5, 20), - n=random.randint(1, 10), - presence_penalty=random.randint(0, 1), - max_tokens=8, - ignore_eos=True, - ) - - sampling_params_seed_1 = copy.deepcopy(sampling_params) - sampling_params_seed_1.seed = 100 - sampling_params_seed_2 = copy.deepcopy(sampling_params) - sampling_params_seed_2.seed = 200 - - llm = vllm_model.llm - - for prompt in example_prompts: - for params in ( - sampling_params, - sampling_params_seed_1, - sampling_params_seed_2, - sampling_params, - sampling_params_seed_1, - sampling_params_seed_2, - ): - llm._add_request(prompt, params=params) - - results = llm._run_engine(use_tqdm=False) - all_outputs = [[out.token_ids for out in output.outputs] - for output in results] - - for i in range(0, len(example_prompts), 6): - outputs = all_outputs[i:i + 6] - - # verify all non-seeded requests differ - for output_a, output_b in combinations( - (outputs[0], outputs[1], outputs[2], outputs[3]), - 2, - ): - assert output_a != output_b - - # verify requests with the same seed match - assert outputs[1] == outputs[4] - assert outputs[2] == outputs[5] - - # verify generations within the same parallel sampling group differ - for output in outputs: - for sub_output_a, sub_output_b in combinations(output, 2): - assert sub_output_a != sub_output_b diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 45ddb2178722a..19ba32d8dee4c 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -3,38 +3,67 @@ import pytest import torch +from vllm.config import SpeculativeConfig from vllm.model_executor.models.interfaces import supports_eagle3 @pytest.mark.parametrize( "model_path", - [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")]) -def test_llama(vllm_runner, example_prompts, model_path, monkeypatch): + [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", + id="qwen3-eagle3-speculator-w4a16-verifier", + ), + pytest.param( + "nm-testing/random-weights-llama3.1.8b-2layer-eagle3", + id="llama3-eagl3-multiple-layers", + ), + ], +) +def test_eagle3_speculators_model( + vllm_runner, example_prompts, model_path, monkeypatch +): + """ + Test Eagle3 speculators models properly initialize speculative decoding. + + This test verifies: + 1. Eagle3 support is detected for the model + 2. Speculative config is automatically initialized from embedded config + 3. 
The draft model path is correctly set to the speculators model + 4. Speculative tokens count is valid + 5. Text generation works with speculative decoding enabled + """ # Set environment variable for V1 engine serialization monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + # Verify Eagle3 support is detected eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported + assert eagle3_supported, f"Eagle3 should be supported for {model_path}" - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens=20) - print(vllm_outputs) - assert vllm_outputs + vllm_config = vllm_model.llm.llm_engine.vllm_config + assert isinstance(vllm_config.speculative_config, SpeculativeConfig), ( + "Speculative config should be initialized for speculators model" + ) -@pytest.mark.parametrize( - "model_path", - [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")]) -def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch): - # Set environment variable for V1 engine serialization - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + spec_config = vllm_config.speculative_config + assert spec_config.num_speculative_tokens > 0, ( + f"Expected positive speculative tokens, " + f"got {spec_config.num_speculative_tokens}" + ) - with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: - eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported + assert spec_config.model == model_path, ( + f"Draft model should be {model_path}, got {spec_config.model}" + ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens=20) - print(vllm_outputs) - assert vllm_outputs + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) + assert vllm_outputs, f"No outputs generated for speculators model {model_path}" diff --git a/tests/standalone_tests/lazy_imports.py b/tests/standalone_tests/lazy_imports.py index 21bcb6b822d1f..ddcdd2a51ab9f 100644 --- a/tests/standalone_tests/lazy_imports.py +++ b/tests/standalone_tests/lazy_imports.py @@ -37,4 +37,5 @@ if use_blame: assert not any_module_imported(), ( f"Some the modules in {module_names} are imported. To see the first" - f" import location, run the test with `use_blame=True`.") + f" import location, run the test with `use_blame=True`." +) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py deleted file mode 100644 index edc0849dff33f..0000000000000 --- a/tests/test_cache_block_hashing.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test hashing of cache blocks. - -Run `pytest tests/test_cache_block_hashing.py`. -""" -from typing import Optional - -import pytest - -from vllm.inputs import token_inputs -from vllm.lora.request import LoRARequest -from vllm.sequence import Sequence -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - -# Make two prefixes with different first blocks. -prefix_start = [("You are an expert"), ("You are a")] -prefix_common = ( - " school principal, skilled in effectively managing " - "faculty and staff. Draft 10-15 questions for a potential first grade " - "Head Teacher for my K-12, all-girls', independent school that emphasizes " - "community, joyful discovery, and life-long learning. The candidate is " - "coming in for a first-round panel interview for a 8th grade Math " - "teaching role. 
They have 5 years of previous teaching experience " - "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. Based on this, fulfill " - "the following: ") -prefixes = [start + prefix_common for start in prefix_start] - -# Sample prompts. -sample_prompts = [ - "Hello, my name is", "The president of the United States is", - "The capital of France is", "The future of AI is" -] - - -# Helper function. -def flatten_2d(li): - return [lss for ls in li for lss in ls] - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_num_seqs", [256]) -@pytest.mark.parametrize("concurrent_lora_int_ids", - [[None], [1], [None, 1], [None, 1, 2], [1, 2]]) -def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, - concurrent_lora_int_ids: list[Optional[int]]): - - tokenizer = TokenizerGroup( - tokenizer_id="facebook/opt-125m", - enable_lora=False, - max_num_seqs=max_num_seqs, - max_input_length=None, - ) - - hashes: list[list[list[int]]] = [] - - for prefix in prefixes: - for lora_int_id in concurrent_lora_int_ids: - lora_request = None - - if lora_int_id is not None: - lora_request = LoRARequest( - f"example_lora_{lora_int_id}", - lora_int_id, - f"example/path/to/lora_{lora_int_id}", - ) - - hashes.append([]) - prompts = [prefix + prompt for prompt in sample_prompts] - for seq_id, prompt in enumerate(prompts): - hashes[-1].append([]) - prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, - inputs=token_inputs(prompt_token_ids, - prompt=prompt), - block_size=block_size, - eos_token_id=tokenizer.tokenizer.eos_token_id, - lora_request=lora_request) - - num_blocks = len(prompt_token_ids) // block_size - for idx in range(num_blocks): - hashes[-1][-1].append(seq.hash_of_block(idx)) - - # Check that hashes made with two prefixes with different first blocks are - # different everywhere. - for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): - assert (hash0 != hash1) - - # Check that hashes of different prompts made with the same prefix are the - # same until the hashes that contain the prompt. 
- for hash_pref in hashes: - same_hashes = [tuple(h[:-1]) for h in hash_pref] - different_hashes = [h[-1] for h in hash_pref] - assert (len(set(same_hashes)) == 1) - assert (len(set(different_hashes)) == len(different_hashes)) diff --git a/tests/test_config.py b/tests/test_config.py index 957771a4226bc..bba2fbec3db29 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import MISSING, Field, asdict, dataclass, field +from unittest.mock import patch import pytest from vllm.compilation.backends import VllmBackend -from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - get_field, update_config) +from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config +from vllm.config.load import LoadConfig +from vllm.config.utils import get_field from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform @@ -20,8 +23,8 @@ def test_compile_config_repr_succeeds(): # test that repr(config) succeeds val = repr(config) - assert 'VllmConfig' in val - assert 'inductor_passes' in val + assert "VllmConfig" in val + assert "inductor_passes" in val @dataclass @@ -48,8 +51,7 @@ def test_get_field(): @dataclass class _TestNestedConfig: - a: _TestConfigFields = field( - default_factory=lambda: _TestConfigFields(a=0)) + a: _TestConfigFields = field(default_factory=lambda: _TestConfigFields(a=0)) def test_update_config(): @@ -76,65 +78,60 @@ def test_update_config(): # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("distilbert/distilgpt2", "generate", "none", "generate"), ("intfloat/multilingual-e5-small", "pooling", "none", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", - "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"), ("openai/whisper-small", "generate", "none", "transcription"), ], ) -def test_auto_task(model_id, expected_runner_type, expected_convert_type, - expected_task): +def test_auto_task( + model_id, expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="auto") assert config.runner_type == expected_runner_type assert config.convert_type == expected_convert_type - assert expected_task in config.supported_tasks # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("distilbert/distilgpt2", "pooling", "embed", "embed"), ("intfloat/multilingual-e5-small", "pooling", "embed", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", - "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"), ("openai/whisper-small", "pooling", "embed", "embed"), ], ) -def test_score_task(model_id, expected_runner_type, expected_convert_type, - expected_task): +def test_score_task( + model_id, 
expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="score") assert config.runner_type == expected_runner_type assert config.convert_type == expected_convert_type - assert expected_task in config.supported_tasks # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("openai/whisper-small", "generate", "none", "transcription"), ], ) -def test_transcription_task(model_id, expected_runner_type, - expected_convert_type, expected_task): +def test_transcription_task( + model_id, expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="transcription") assert config.runner_type == expected_runner_type assert config.convert_type == expected_convert_type - assert expected_task in config.supported_tasks @pytest.mark.parametrize( @@ -200,31 +197,27 @@ def test_disable_sliding_window(model_id_expected): assert model_config.max_model_len == expected -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_get_pooling_config(): model_id = "sentence-transformers/all-MiniLM-L12-v2" model_config = ModelConfig(model_id) - pooling_config = model_config._init_pooler_config() - assert pooling_config is not None - - assert pooling_config.normalize - assert pooling_config.pooling_type == PoolingType.MEAN.name + assert model_config.pooler_config is not None + assert model_config.pooler_config.normalize + assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_get_pooling_config_from_args(): model_id = "sentence-transformers/all-MiniLM-L12-v2" - model_config = ModelConfig(model_id) + pooler_config = PoolerConfig(pooling_type="CLS", normalize=True) + model_config = ModelConfig(model_id, pooler_config=pooler_config) - override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True) - model_config.override_pooler_config = override_pooler_config - - pooling_config = model_config._init_pooler_config() - assert pooling_config is not None - assert asdict(pooling_config) == asdict(override_pooler_config) + assert asdict(model_config.pooler_config) == asdict(pooler_config) @pytest.mark.parametrize( @@ -233,16 +226,18 @@ def test_get_pooling_config_from_args(): ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM ("intfloat/e5-small", "CLS", "MEAN"), # BertModel ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward - ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward - ]) + ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward + ], +) def test_default_pooling_type(model_id, default_pooling_type, pooling_type): model_config = ModelConfig(model_id) assert model_config._model_info.default_pooling_type == default_pooling_type assert model_config.pooler_config.pooling_type == pooling_type -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." 
+) def test_get_bert_tokenization_sentence_transformer_config(): model_id = "BAAI/bge-base-en-v1.5" bge_model_config = ModelConfig(model_id) @@ -270,17 +265,18 @@ def test_rope_customization(): "rope_theta": TEST_ROPE_THETA, }, ) - assert getattr(llama_model_config.hf_config, "rope_scaling", - None) == TEST_ROPE_SCALING - assert getattr(llama_model_config.hf_config, "rope_theta", - None) == TEST_ROPE_THETA + assert ( + getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING + ) + assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA assert llama_model_config.max_model_len == 16384 longchat_model_config = ModelConfig("lmsys/longchat-13b-16k") # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config assert all( longchat_model_config.hf_config.rope_scaling.get(key) == value - for key, value in LONGCHAT_ROPE_SCALING.items()) + for key, value in LONGCHAT_ROPE_SCALING.items() + ) assert longchat_model_config.max_model_len == 16384 longchat_model_config = ModelConfig( @@ -289,29 +285,68 @@ def test_rope_customization(): "rope_scaling": TEST_ROPE_SCALING, }, ) - assert getattr(longchat_model_config.hf_config, "rope_scaling", - None) == TEST_ROPE_SCALING + assert ( + getattr(longchat_model_config.hf_config, "rope_scaling", None) + == TEST_ROPE_SCALING + ) assert longchat_model_config.max_model_len == 4096 -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Encoder Decoder models not supported on ROCm.") -@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ - ("facebook/opt-125m", False), - ("facebook/bart-base", True), - ("meta-llama/Llama-3.2-1B-Instruct", False), - ("meta-llama/Llama-3.2-11B-Vision", True), -]) +def test_nested_hf_overrides(): + """Test that nested hf_overrides work correctly.""" + # Test with a model that has text_config + model_config = ModelConfig( + "Qwen/Qwen2-VL-2B-Instruct", + hf_overrides={ + "text_config": { + "hidden_size": 1024, + }, + }, + ) + assert model_config.hf_config.text_config.hidden_size == 1024 + + # Test with deeply nested overrides + model_config = ModelConfig( + "Qwen/Qwen2-VL-2B-Instruct", + hf_overrides={ + "text_config": { + "hidden_size": 2048, + "num_attention_heads": 16, + }, + "vision_config": { + "hidden_size": 512, + }, + }, + ) + assert model_config.hf_config.text_config.hidden_size == 2048 + assert model_config.hf_config.text_config.num_attention_heads == 16 + assert model_config.hf_config.vision_config.hidden_size == 512 + + +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm." 
+) +@pytest.mark.parametrize( + ("model_id", "is_encoder_decoder"), + [ + ("facebook/opt-125m", False), + ("openai/whisper-tiny", True), + ("meta-llama/Llama-3.2-1B-Instruct", False), + ], +) def test_is_encoder_decoder(model_id, is_encoder_decoder): config = ModelConfig(model_id) assert config.is_encoder_decoder == is_encoder_decoder -@pytest.mark.parametrize(("model_id", "uses_mrope"), [ - ("facebook/opt-125m", False), - ("Qwen/Qwen2-VL-2B-Instruct", True), -]) +@pytest.mark.parametrize( + ("model_id", "uses_mrope"), + [ + ("facebook/opt-125m", False), + ("Qwen/Qwen2-VL-2B-Instruct", True), + ], +) def test_uses_mrope(model_id, uses_mrope): config = ModelConfig(model_id) @@ -345,7 +380,8 @@ def test_generation_config_loading(): model_config = ModelConfig( model_id, generation_config="auto", - override_generation_config=override_generation_config) + override_generation_config=override_generation_config, + ) override_result = correct_generation_config.copy() override_result.update(override_generation_config) @@ -357,17 +393,19 @@ def test_generation_config_loading(): model_config = ModelConfig( model_id, generation_config="vllm", - override_generation_config=override_generation_config) + override_generation_config=override_generation_config, + ) assert model_config.get_diff_sampling_param() == override_generation_config -@pytest.mark.parametrize("pt_load_map_location", [ - "cuda", - { - "": "cuda" - }, -]) +@pytest.mark.parametrize( + "pt_load_map_location", + [ + "cuda", + {"": "cuda"}, + ], +) def test_load_config_pt_load_map_location(pt_load_map_location): load_config = LoadConfig(pt_load_map_location=pt_load_map_location) config = VllmConfig(load_config=load_config) @@ -376,15 +414,18 @@ def test_load_config_pt_load_map_location(pt_load_map_location): @pytest.mark.parametrize( - ("model_id", "max_model_len", "expected_max_len", "should_raise"), [ + ("model_id", "max_model_len", "expected_max_len", "should_raise"), + [ ("BAAI/bge-reranker-base", None, 512, False), ("BAAI/bge-reranker-base", 256, 256, False), ("BAAI/bge-reranker-base", 513, 512, True), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True), - ]) -def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len, - should_raise): + ], +) +def test_get_and_verify_max_len( + model_id, max_model_len, expected_max_len, should_raise +): """Test get_and_verify_max_len with different configurations.""" model_config = ModelConfig(model_id) @@ -394,3 +435,117 @@ def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len, else: actual_max_len = model_config.get_and_verify_max_len(max_model_len) assert actual_max_len == expected_max_len + + +class MockConfig: + """Simple mock object for testing maybe_pull_model_tokenizer_for_runai""" + + def __init__(self, model: str, tokenizer: str): + self.model = model + self.tokenizer = tokenizer + self.model_weights = None + + +@pytest.mark.parametrize( + "s3_url", + [ + "s3://example-bucket-1/model/", + "s3://example-bucket-2/model/", + ], +) +@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files") +def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url): + """Test that S3 URLs create deterministic local directories for model and + tokenizer.""" + # Mock pull_files to avoid actually downloading files during tests + mock_pull_files.return_value = None + + # Create first mock and run the method + config1 = MockConfig(model=s3_url, tokenizer=s3_url) + 
ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url) + + # Check that model and tokenizer point to existing directories + assert os.path.exists(config1.model), ( + f"Model directory does not exist: {config1.model}" + ) + assert os.path.isdir(config1.model), ( + f"Model path is not a directory: {config1.model}" + ) + assert os.path.exists(config1.tokenizer), ( + f"Tokenizer directory does not exist: {config1.tokenizer}" + ) + assert os.path.isdir(config1.tokenizer), ( + f"Tokenizer path is not a directory: {config1.tokenizer}" + ) + + # Verify that the paths are different from the original S3 URL + assert config1.model != s3_url, "Model path should be converted to local directory" + assert config1.tokenizer != s3_url, ( + "Tokenizer path should be converted to local directory" + ) + + # Store the original paths + created_model_dir = config1.model + create_tokenizer_dir = config1.tokenizer + + # Create a new mock and run the method with the same S3 URL + config2 = MockConfig(model=s3_url, tokenizer=s3_url) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url) + + # Check that the new directories exist + assert os.path.exists(config2.model), ( + f"Model directory does not exist: {config2.model}" + ) + assert os.path.isdir(config2.model), ( + f"Model path is not a directory: {config2.model}" + ) + assert os.path.exists(config2.tokenizer), ( + f"Tokenizer directory does not exist: {config2.tokenizer}" + ) + assert os.path.isdir(config2.tokenizer), ( + f"Tokenizer path is not a directory: {config2.tokenizer}" + ) + + # Verify that the paths are deterministic (same as before) + assert config2.model == created_model_dir, ( + f"Model paths are not deterministic. " + f"Original: {created_model_dir}, New: {config2.model}" + ) + assert config2.tokenizer == create_tokenizer_dir, ( + f"Tokenizer paths are not deterministic. " + f"Original: {create_tokenizer_dir}, New: {config2.tokenizer}" + ) + + +@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files") +def test_s3_url_different_models_create_different_directories(mock_pull_files): + """Test that different S3 URLs create different local directories.""" + # Mock pull_files to avoid actually downloading files during tests + mock_pull_files.return_value = None + + s3_url1 = "s3://example-bucket-1/model/" + s3_url2 = "s3://example-bucket-2/model/" + + # Create mocks with different S3 URLs and run the method + config1 = MockConfig(model=s3_url1, tokenizer=s3_url1) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1) + + config2 = MockConfig(model=s3_url2, tokenizer=s3_url2) + ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2) + + # Verify that different URLs produce different directories + assert config1.model != config2.model, ( + f"Different S3 URLs should create different model directories. " + f"URL1 model: {config1.model}, URL2 model: {config2.model}" + ) + assert config1.tokenizer != config2.tokenizer, ( + f"Different S3 URLs should create different tokenizer directories. 
" + f"URL1 tokenizer: {config1.tokenizer}, " + f"URL2 tokenizer: {config2.tokenizer}" + ) + + # Verify that both sets of directories exist + assert os.path.exists(config1.model) and os.path.isdir(config1.model) + assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer) + assert os.path.exists(config2.model) and os.path.isdir(config2.model) + assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer) diff --git a/tests/test_embedded_commit.py b/tests/test_embedded_commit.py index b9593e2a3b7c0..687a15446fc2a 100644 --- a/tests/test_embedded_commit.py +++ b/tests/test_embedded_commit.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import vllm - - -def test_embedded_commit_defined(): - assert hasattr(vllm, "__version__") - assert hasattr(vllm, "__version_tuple__") - assert vllm.__version__ != "dev" - assert vllm.__version_tuple__ != (0, 0, "dev") +import vllm + + +def test_embedded_commit_defined(): + assert hasattr(vllm, "__version__") + assert hasattr(vllm, "__version_tuple__") + assert vllm.__version__ != "dev" + assert vllm.__version_tuple__ != (0, 0, "dev") diff --git a/tests/test_envs.py b/tests/test_envs.py new file mode 100644 index 0000000000000..62d529c363608 --- /dev/null +++ b/tests/test_envs.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from unittest.mock import patch + +import pytest + +from vllm.envs import env_list_with_choices, env_with_choices + + +class TestEnvWithChoices: + """Test cases for env_with_choices function.""" + + def test_default_value_returned_when_env_not_set(self): + """Test default is returned when env var is not set.""" + env_func = env_with_choices( + "NONEXISTENT_ENV", "default", ["option1", "option2"] + ) + assert env_func() == "default" + + def test_none_default_returned_when_env_not_set(self): + """Test that None is returned when env not set and default is None.""" + env_func = env_with_choices("NONEXISTENT_ENV", None, ["option1", "option2"]) + assert env_func() is None + + def test_valid_value_returned_case_sensitive(self): + """Test that valid value is returned in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) + assert env_func() == "option1" + + def test_valid_lowercase_value_returned_case_insensitive(self): + """Test that lowercase value is accepted in case insensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["OPTION1", "OPTION2"], case_sensitive=False + ) + assert env_func() == "option1" + + def test_valid_uppercase_value_returned_case_insensitive(self): + """Test that uppercase value is accepted in case insensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=False + ) + assert env_func() == "OPTION1" + + def test_invalid_value_raises_error_case_sensitive(self): + """Test that invalid value raises ValueError in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): + env_func() + + def 
test_case_mismatch_raises_error_case_sensitive(self): + """Test that case mismatch raises ValueError in case sensitive mode.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) + with pytest.raises( + ValueError, match="Invalid value 'OPTION1' for TEST_ENV" + ): + env_func() + + def test_invalid_value_raises_error_case_insensitive(self): + """Test that invalid value raises ValueError when case insensitive.""" + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=False + ) + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): + env_func() + + def test_callable_choices_resolved_correctly(self): + """Test that callable choices are resolved correctly.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}): + env_func = env_with_choices("TEST_ENV", "default", get_choices) + assert env_func() == "dynamic1" + + def test_callable_choices_with_invalid_value(self): + """Test that callable choices raise error for invalid values.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "invalid"}): + env_func = env_with_choices("TEST_ENV", "default", get_choices) + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): + env_func() + + +class TestEnvListWithChoices: + """Test cases for env_list_with_choices function.""" + + def test_default_list_returned_when_env_not_set(self): + """Test that default list is returned when env var is not set.""" + env_func = env_list_with_choices( + "NONEXISTENT_ENV", ["default1", "default2"], ["option1", "option2"] + ) + assert env_func() == ["default1", "default2"] + + def test_empty_default_list_returned_when_env_not_set(self): + """Test that empty default list is returned when env not set.""" + env_func = env_list_with_choices("NONEXISTENT_ENV", [], ["option1", "option2"]) + assert env_func() == [] + + def test_single_valid_value_parsed_correctly(self): + """Test that single valid value is parsed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": "option1"}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + assert env_func() == ["option1"] + + def test_multiple_valid_values_parsed_correctly(self): + """Test that multiple valid values are parsed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_values_with_whitespace_trimmed(self): + """Test that values with whitespace are trimmed correctly.""" + with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_empty_values_filtered_out(self): + """Test that empty values are filtered out.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + assert env_func() == ["option1", "option2"] + + def test_empty_string_returns_default(self): + """Test that empty string returns default.""" + with patch.dict(os.environ, {"TEST_ENV": ""}): + env_func = env_list_with_choices( + "TEST_ENV", ["default"], ["option1", "option2"] + ) + assert 
env_func() == ["default"] + + def test_only_commas_returns_default(self): + """Test that string with only commas returns default.""" + with patch.dict(os.environ, {"TEST_ENV": ",,,"}): + env_func = env_list_with_choices( + "TEST_ENV", ["default"], ["option1", "option2"] + ) + assert env_func() == ["default"] + + def test_case_sensitive_validation(self): + """Test case sensitive validation.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}): + env_func = env_list_with_choices( + "TEST_ENV", [], ["option1", "option2"], case_sensitive=True + ) + with pytest.raises(ValueError, match="Invalid value 'OPTION2' in TEST_ENV"): + env_func() + + def test_case_insensitive_validation(self): + """Test case insensitive validation.""" + with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}): + env_func = env_list_with_choices( + "TEST_ENV", [], ["option1", "option2"], case_sensitive=False + ) + assert env_func() == ["OPTION1", "option2"] + + def test_invalid_value_in_list_raises_error(self): + """Test that invalid value in list raises ValueError.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"): + env_func() + + def test_callable_choices_resolved_correctly(self): + """Test that callable choices are resolved correctly.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}): + env_func = env_list_with_choices("TEST_ENV", [], get_choices) + assert env_func() == ["dynamic1", "dynamic2"] + + def test_callable_choices_with_invalid_value(self): + """Test that callable choices raise error for invalid values.""" + + def get_choices(): + return ["dynamic1", "dynamic2"] + + with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}): + env_func = env_list_with_choices("TEST_ENV", [], get_choices) + with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"): + env_func() + + def test_duplicate_values_preserved(self): + """Test that duplicate values in the list are preserved.""" + with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + assert env_func() == ["option1", "option1", "option2"] diff --git a/tests/test_inputs.py b/tests/test_inputs.py index e549834faf6f7..50a273016ab80 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -3,15 +3,20 @@ import pytest +from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs.parse import parse_raw_prompts +from vllm.inputs.preprocess import InputPreprocessor +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs + +pytestmark = pytest.mark.cpu_test STRING_INPUTS = [ - '', - 'foo', - 'foo bar', - 'foo baz bar', - 'foo bar qux baz', + "", + "foo", + "foo bar", + "foo baz bar", + "foo bar qux baz", ] TOKEN_INPUTS = [ @@ -29,52 +34,106 @@ INPUTS_SLICES = [ ] -def test_parse_single_batch_empty(): +def test_parse_raw_single_batch_empty(): with pytest.raises(ValueError, match="at least one prompt"): - parse_and_batch_prompt([]) + parse_raw_prompts([]) with pytest.raises(ValueError, match="at least one prompt"): - parse_and_batch_prompt([[]]) + parse_raw_prompts([[]]) -@pytest.mark.parametrize('string_input', STRING_INPUTS) -def test_parse_single_batch_string_consistent(string_input: 
str): - assert parse_and_batch_prompt(string_input) \ - == parse_and_batch_prompt([string_input]) +@pytest.mark.parametrize("string_input", STRING_INPUTS) +def test_parse_raw_single_batch_string_consistent(string_input: str): + assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input]) -@pytest.mark.parametrize('token_input', TOKEN_INPUTS) -def test_parse_single_batch_token_consistent(token_input: list[int]): - assert parse_and_batch_prompt(token_input) \ - == parse_and_batch_prompt([token_input]) +@pytest.mark.parametrize("token_input", TOKEN_INPUTS) +def test_parse_raw_single_batch_token_consistent(token_input: list[int]): + assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input]) -@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) -def test_parse_single_batch_string_slice(inputs_slice: slice): - assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ - == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) +@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) +def test_parse_raw_single_batch_string_slice(inputs_slice: slice): + assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts( + STRING_INPUTS[inputs_slice] + ) -# yapf: disable -@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [ - (None, [{}, {}]), - ({}, [{}, {}]), - ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), - ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), -]) -# yapf: enable +@pytest.mark.parametrize( + "mm_processor_kwargs,expected_mm_kwargs", + [ + (None, [{}, {}]), + ({}, [{}, {}]), + ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), + ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), + ], +) def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): """Test mm_processor_kwargs init for zipping enc/dec prompts.""" - encoder_prompts = ['An encoder prompt', 'Another encoder prompt'] - decoder_prompts = ['A decoder prompt', 'Another decoder prompt'] - zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts, - mm_processor_kwargs) + encoder_prompts = ["An encoder prompt", "Another encoder prompt"] + decoder_prompts = ["A decoder prompt", "Another decoder prompt"] + zipped_prompts = zip_enc_dec_prompts( + encoder_prompts, decoder_prompts, mm_processor_kwargs + ) assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts) - for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts, - expected_mm_kwargs, - zipped_prompts): + for enc, dec, exp_kwargs, zipped in zip( + encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts + ): assert isinstance(zipped, dict) assert len(zipped.keys()) == 3 - assert zipped['encoder_prompt'] == enc - assert zipped['decoder_prompt'] == dec - assert zipped['mm_processor_kwargs'] == exp_kwargs + assert zipped["encoder_prompt"] == enc + assert zipped["decoder_prompt"] == dec + assert zipped["mm_processor_kwargs"] == exp_kwargs + + +@pytest.mark.parametrize( + "model_id", + [ + "facebook/opt-125m", + ], +) +@pytest.mark.parametrize( + "prompt", + [ + { + "prompt": "", + "multi_modal_data": {"dummy": []}, + }, + { + "prompt_token_ids": [], + "multi_modal_data": {"dummy": []}, + }, + ], +) +def test_preprocessor_text_no_mm_inputs(model_id, prompt): + model_config = ModelConfig(model=model_id) + tokenizer = init_tokenizer_from_configs(model_config) + input_preprocessor = InputPreprocessor(model_config, tokenizer) + + with pytest.raises(ValueError, match="does not support multimodal inputs"): + input_preprocessor.preprocess(prompt) + 
+ +@pytest.mark.parametrize( + "model_id", + [ + "facebook/chameleon-7b", + ], +) +@pytest.mark.parametrize( + "prompt", + [ + "", + {"prompt_token_ids": []}, + ], +) +def test_preprocessor_always_mm_code_path(model_id, prompt): + model_config = ModelConfig(model=model_id) + tokenizer = init_tokenizer_from_configs(model_config) + input_preprocessor = InputPreprocessor(model_config, tokenizer) + + # HF processor adds sep token + sep_token_id = tokenizer.vocab[tokenizer.sep_token] + + processed_inputs = input_preprocessor.preprocess(prompt) + assert sep_token_id in processed_inputs["prompt_token_ids"] diff --git a/tests/test_logger.py b/tests/test_logger.py index 0bfb449cdf213..ec368d4897b5a 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -16,8 +16,13 @@ from uuid import uuid4 import pytest from vllm.entrypoints.logger import RequestLogger -from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, - enable_trace_function_call, init_logger) +from vllm.logger import ( + _DATE_FORMAT, + _FORMAT, + _configure_vllm_root_logger, + enable_trace_function_call, + init_logger, +) from vllm.logging_utils import NewLineFormatter from vllm.logging_utils.dump_input import prepare_object_to_dump @@ -129,8 +134,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write("---\nloggers: []\nversion: 1") logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(JSONDecodeError) as ex_info: _configure_vllm_root_logger() assert ex_info.type == JSONDecodeError @@ -138,24 +142,24 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) -@pytest.mark.parametrize("unexpected_config", ( - "Invalid string", - [{ - "version": 1, - "loggers": [] - }], - 0, -)) +@pytest.mark.parametrize( + "unexpected_config", + ( + "Invalid string", + [{"version": 1, "loggers": []}], + 0, + ), +) def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( - unexpected_config: Any): + unexpected_config: Any, +): """This test calls _configure_vllm_root_logger again to test custom logging config behavior, however it fails before any change in behavior or configuration occurs.""" with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(unexpected_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() assert ex_info.type == ValueError # noqa: E721 @@ -174,14 +178,15 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): "propagate": False, } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name), patch( - "vllm.logger.dictConfig") as dict_config_mock: + with ( + patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name), + patch("vllm.logger.dictConfig") as dict_config_mock, + ): _configure_vllm_root_logger() 
dict_config_mock.assert_called_with(valid_logging_config) @@ -197,19 +202,19 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): "handlers": [], } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(RuntimeError) as ex_info: _configure_vllm_root_logger() assert ex_info.type is RuntimeError expected_message_snippet = ( "VLLM_CONFIGURE_LOGGING evaluated to false, but " - "VLLM_LOGGING_CONFIG_PATH was given.") + "VLLM_LOGGING_CONFIG_PATH was given." + ) assert expected_message_snippet in str(ex_info) # Remember! The root logger is assumed to have been configured as @@ -223,11 +228,11 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): def test_prepare_object_to_dump(): - str_obj = 'str' + str_obj = "str" assert prepare_object_to_dump(str_obj) == "'str'" list_obj = [1, 2, 3] - assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(list_obj) == "[1, 2, 3]" dict_obj = {"a": 1, "b": "b"} assert prepare_object_to_dump(dict_obj) in [ @@ -236,9 +241,9 @@ def test_prepare_object_to_dump(): ] set_obj = {1, 2, 3} - assert prepare_object_to_dump(set_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(set_obj) == "[1, 2, 3]" - tuple_obj = ('a', 'b', 'c') + tuple_obj = ("a", "b", "c") assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']" class CustomEnum(enum.Enum): @@ -253,8 +258,7 @@ def test_prepare_object_to_dump(): a: int b: str - assert (prepare_object_to_dump(CustomClass( - 1, "b")) == "CustomClass(a=1, b='b')") + assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')" def test_request_logger_log_outputs(): @@ -467,7 +471,7 @@ def test_request_logger_log_outputs_integration(): def test_streaming_complete_logs_full_text_content(): """Test that streaming complete logging includes - full accumulated text, not just token count.""" + full accumulated text, not just token count.""" mock_logger = MagicMock() with patch("vllm.entrypoints.logger.logger", mock_logger): diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 4bb1c20f77f1d..7b234884c569e 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -1,15 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + from vllm.outputs import RequestOutput +pytestmark = pytest.mark.cpu_test + def test_request_output_forward_compatible(): - output = RequestOutput(request_id="test_request_id", - prompt="test prompt", - prompt_token_ids=[1, 2, 3], - prompt_logprobs=None, - outputs=[], - finished=False, - example_arg_added_in_new_version="some_value") + output = RequestOutput( + request_id="test_request_id", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[], + finished=False, + example_arg_added_in_new_version="some_value", + ) assert output is not None diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index 52c03015483c9..e3561ac3a577e 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -8,9 +8,11 @@ from vllm.config import ModelConfig EMBEDDING_MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), - 
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - matryoshka_dimensions=[256]), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256], + ), ] @@ -65,8 +67,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): if model_info.is_matryoshka: assert model_info.matryoshka_dimensions is not None - pooling_params = PoolingParams( - dimensions=model_info.matryoshka_dimensions[0]) + pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0]) pooling_params.verify(task=task, model_config=model_config) diff --git a/tests/test_regression.py b/tests/test_regression.py index f5f1ed8e805e0..8a9829e4dba5f 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -6,6 +6,7 @@ It should include tests that are reported by users and making sure they will never happen again. """ + import gc import pytest @@ -18,12 +19,12 @@ from vllm import LLM, SamplingParams def test_duplicated_ignored_sequence_group(): """https://github.com/vllm-project/vllm/issues/1655""" - sampling_params = SamplingParams(temperature=0.01, - top_p=0.1, - max_tokens=256) - llm = LLM(model="distilbert/distilgpt2", - max_num_batched_tokens=4096, - tensor_parallel_size=1) + sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256) + llm = LLM( + model="distilbert/distilgpt2", + max_num_batched_tokens=4096, + tensor_parallel_size=1, + ) prompts = ["This is a short prompt", "This is a very long prompt " * 1000] outputs = llm.generate(prompts, sampling_params=sampling_params) @@ -31,12 +32,12 @@ def test_duplicated_ignored_sequence_group(): def test_max_tokens_none(): - sampling_params = SamplingParams(temperature=0.01, - top_p=0.1, - max_tokens=None) - llm = LLM(model="distilbert/distilgpt2", - max_num_batched_tokens=4096, - tensor_parallel_size=1) + sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) + llm = LLM( + model="distilbert/distilgpt2", + max_num_batched_tokens=4096, + tensor_parallel_size=1, + ) prompts = ["Just say hello!"] outputs = llm.generate(prompts, sampling_params=sampling_params) diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py index 8324b225a8ce5..5a162fa8f791b 100644 --- a/tests/test_routing_simulator.py +++ b/tests/test_routing_simulator.py @@ -13,7 +13,9 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.routing_simulator import ( - DistributionBasedRouting, RoutingSimulator) + DistributionBasedRouting, + RoutingSimulator, +) @pytest.fixture @@ -60,10 +62,10 @@ def test_basic_functionality( ), f"Wrong ids shape for {strategy}" # Check that expert IDs are valid - assert (topk_ids.min() - >= 0), f"Invalid expert ID (negative) for {strategy}" - assert (topk_ids.max() - < num_experts), f"Invalid expert ID (too large) for {strategy}" + assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max() < num_experts, ( + f"Invalid expert ID (too large) for {strategy}" + ) def test_routing_strategy_integration(monkeypatch, device): @@ -96,25 +98,26 @@ def test_routing_strategy_integration(monkeypatch, device): envs.environment_variables[env_name] = lambda s=strategy: s # Test the select_experts method - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=hidden_states, router_logits=router_logits, top_k=top_k, use_grouped_topk=False, renormalize=True, - indices_type=torch.long) + indices_type=torch.long, + 
) # Verify output shapes - assert topk_weights.shape == ( - num_tokens, top_k), f"Wrong weights shape for {strategy}" - assert topk_ids.shape == (num_tokens, - top_k), f"Wrong ids shape for {strategy}" + assert topk_weights.shape == (num_tokens, top_k), ( + f"Wrong weights shape for {strategy}" + ) + assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}" # Verify expert IDs are valid - assert topk_ids.min( - ) >= 0, f"Invalid expert ID (negative) for {strategy}" - assert topk_ids.max( - ) < num_experts, f"Invalid expert ID (too large) for {strategy}" + assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max() < num_experts, ( + f"Invalid expert ID (too large) for {strategy}" + ) def test_distribution_based_routing_with_custom_strategy(): @@ -123,9 +126,7 @@ def test_distribution_based_routing_with_custom_strategy(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Register custom distribution-based strategy - custom_strategy = DistributionBasedRouting(distribution="normal", - mean=2.0, - std=0.5) + custom_strategy = DistributionBasedRouting(distribution="normal", mean=2.0, std=0.5) RoutingSimulator.register_strategy("custom_normal", custom_strategy) # Test data @@ -142,7 +143,8 @@ def test_distribution_based_routing_with_custom_strategy(): hidden_states=hidden_states, router_logits=router_logits, strategy_name="custom_normal", - top_k=top_k) + top_k=top_k, + ) # Check output shapes assert topk_weights.shape == (num_tokens, top_k) @@ -165,7 +167,8 @@ def test_instance_compatibility(): hidden_states=hidden_states, router_logits=router_logits, strategy_name="uniform_random", - top_k=2) + top_k=2, + ) assert topk_weights.shape == (10, 2) assert topk_ids.shape == (10, 2) diff --git a/tests/test_sampling_params.py b/tests/test_sampling_params.py deleted file mode 100644 index 7330f61e67689..0000000000000 --- a/tests/test_sampling_params.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the SamplingParams class. -""" - -import pytest - -from vllm import SamplingParams -from vllm.config import ModelConfig -from vllm.entrypoints.openai.protocol import ChatCompletionRequest - -MODEL_NAME = "Qwen/Qwen1.5-7B" - - -def test_max_tokens_none(): - """max_tokens=None should be allowed""" - SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) - - -@pytest.fixture(scope="module") -def model_config(): - return ModelConfig( - MODEL_NAME, - seed=0, - dtype="float16", - ) - - -@pytest.fixture(scope="module") -def default_max_tokens(): - return 4096 - - -def test_sampling_params_from_request_with_no_guided_decoding_backend( - model_config, default_max_tokens): - # guided_decoding_backend is not present at request level - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - MODEL_NAME, - 'response_format': { - 'type': 'json_object', - }, - }) - - sampling_params = request.to_sampling_params( - default_max_tokens, - model_config.logits_processor_pattern, - ) - # we do not expect any backend to be present and the default - # guided_decoding_backend at engine level will be used. 
- assert sampling_params.guided_decoding.backend is None - - -@pytest.mark.parametrize("request_level_guided_decoding_backend,expected", - [("xgrammar", "xgrammar"), ("guidance", "guidance"), - ("outlines", "outlines")]) -def test_sampling_params_from_request_with_guided_decoding_backend( - request_level_guided_decoding_backend: str, expected: str, - model_config, default_max_tokens): - - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - MODEL_NAME, - 'response_format': { - 'type': 'json_object', - }, - 'guided_decoding_backend': - request_level_guided_decoding_backend, - }) - - sampling_params = request.to_sampling_params( - default_max_tokens, - model_config.logits_processor_pattern, - ) - # backend correctly identified in resulting sampling_params - assert sampling_params.guided_decoding.backend == expected diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py index ef4aef3afc2e2..5361efbbdf6fb 100644 --- a/tests/test_scalartype.py +++ b/tests/test_scalartype.py @@ -7,21 +7,24 @@ import torch from vllm.scalar_type import scalar_types -@pytest.mark.parametrize("type_tuple", ( - (-8, 7, scalar_types.int4), - (0, 15, scalar_types.uint4), - (-8, 7, scalar_types.uint4b8), - (-128, 127, scalar_types.uint8b128), - (-6., 6., scalar_types.float4_e2m1f), - (-28., 28., scalar_types.float6_e3m2f), - (torch.int8, scalar_types.int8), - (torch.uint8, scalar_types.uint8), - (torch.float8_e5m2, scalar_types.float8_e5m2), - (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), - (torch.bfloat16, scalar_types.float16_e8m7), - (torch.float16, scalar_types.float16_e5m10), -), - ids=lambda x: str(x)) +@pytest.mark.parametrize( + "type_tuple", + ( + (-8, 7, scalar_types.int4), + (0, 15, scalar_types.uint4), + (-8, 7, scalar_types.uint4b8), + (-128, 127, scalar_types.uint8b128), + (-6.0, 6.0, scalar_types.float4_e2m1f), + (-28.0, 28.0, scalar_types.float6_e3m2f), + (torch.int8, scalar_types.int8), + (torch.uint8, scalar_types.uint8), + (torch.float8_e5m2, scalar_types.float8_e5m2), + (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), + (torch.bfloat16, scalar_types.float16_e8m7), + (torch.float16, scalar_types.float16_e5m10), + ), + ids=lambda x: str(x), +) def test_scalar_type_min_max(type_tuple): print(type_tuple) if len(type_tuple) == 3: diff --git a/tests/test_seed_behavior.py b/tests/test_seed_behavior.py index e9138b9e8eb61..adc8a1a4bf08e 100644 --- a/tests/test_seed_behavior.py +++ b/tests/test_seed_behavior.py @@ -1,25 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - -import numpy as np -import torch - -from vllm.platforms.interface import Platform - - -def test_seed_behavior(): - # Test with a specific seed - Platform.seed_everything(42) - random_value_1 = random.randint(0, 100) - np_random_value_1 = np.random.randint(0, 100) - torch_random_value_1 = torch.randint(0, 100, (1, )).item() - - Platform.seed_everything(42) - random_value_2 = random.randint(0, 100) - np_random_value_2 = np.random.randint(0, 100) - torch_random_value_2 = torch.randint(0, 100, (1, )).item() - - assert random_value_1 == random_value_2 - assert np_random_value_1 == np_random_value_2 - assert torch_random_value_1 == torch_random_value_2 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import numpy as np +import torch + +from vllm.platforms.interface import Platform + + +def test_seed_behavior(): + # Test with a specific seed + 
Platform.seed_everything(42) + random_value_1 = random.randint(0, 100) + np_random_value_1 = np.random.randint(0, 100) + torch_random_value_1 = torch.randint(0, 100, (1,)).item() + + Platform.seed_everything(42) + random_value_2 = random.randint(0, 100) + np_random_value_2 = np.random.randint(0, 100) + torch_random_value_2 = torch.randint(0, 100, (1,)).item() + + assert random_value_1 == random_value_2 + assert np_random_value_1 == np_random_value_2 + assert torch_random_value_1 == torch_random_value_2 diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1b019be9e56dc..27af05bec22dc 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,108 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest import torch -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - SequenceData, SequenceOutput) - -from .core.utils import create_dummy_prompt - - -@pytest.fixture -def sample_outputs(): - return [ - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) - ], - prompt_logprobs=None) for i in range(5) - ] - - -@pytest.fixture -def sampler_output(sample_outputs): - return SamplerOutput(outputs=sample_outputs) - - -def test_sampler_output_initialization(sampler_output, sample_outputs): - assert len(sampler_output) == len(sample_outputs) - assert sampler_output.sampled_token_probs is None - assert sampler_output.sampled_token_ids is None - - -def test_sampler_output_getitem(sampler_output, sample_outputs): - assert sampler_output[2] == sample_outputs[2] - - -def test_sampler_output_setitem(sampler_output): - new_output = CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) - ], - prompt_logprobs=None) - sampler_output[2] = new_output - assert sampler_output[2] == new_output - - -def test_sampler_output_len(sampler_output, sample_outputs): - assert len(sampler_output) == len(sample_outputs) - - -def test_sampler_output_eq(sample_outputs): - sampler_output1 = SamplerOutput(outputs=sample_outputs) - sampler_output2 = SamplerOutput(outputs=sample_outputs.copy()) - sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) - assert sampler_output1 == sampler_output2 - assert sampler_output1 != sampler_output3 - - -def test_sequence_data_prefill(): - seq_data = SequenceData.from_seqs([1, 2, 3, 4]) - assert seq_data.get_num_uncomputed_tokens() == 4 - assert seq_data.get_num_computed_tokens() == 0 - # advance by 2 - seq_data.update_num_computed_tokens(2) - assert seq_data.get_num_uncomputed_tokens() == 2 - assert seq_data.get_num_computed_tokens() == 2 - - # advance by 1 - seq_data.update_num_computed_tokens(1) - assert seq_data.get_num_uncomputed_tokens() == 1 - assert seq_data.get_num_computed_tokens() == 3 - - # append tokens and reset, simulating recompute - seq_data.append_token_id(1, logprob=0.0) - seq_data.reset_state_for_recompute() - assert seq_data.get_num_uncomputed_tokens() == 5 - assert seq_data.get_num_computed_tokens() == 0 - - -def test_sequence_group_stage(): - _, seq_group = create_dummy_prompt("1", 12) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(6) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(5) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(1) - assert seq_group.is_prefill() is False - seqs = 
seq_group.get_seqs() - assert len(seqs) == 1 - seqs[0].data.append_token_id(1, logprob=0.0) - for seq in seq_group.get_seqs(): - seq.reset_state_for_recompute() - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(5) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(7) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(1) - assert seq_group.is_prefill() is False +from vllm.sequence import IntermediateTensors def test_sequence_intermediate_tensors_equal(): - class AnotherIntermediateTensors(IntermediateTensors): pass @@ -115,22 +19,31 @@ def test_sequence_intermediate_tensors_equal(): assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 different_key_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) difference_key_intermediate_tensors_2 = IntermediateTensors( - {"2": torch.zeros([2, 4], dtype=torch.int32)}) - assert (different_key_intermediate_tensors_1 - != difference_key_intermediate_tensors_2) + {"2": torch.zeros([2, 4], dtype=torch.int32)} + ) + assert different_key_intermediate_tensors_1 != difference_key_intermediate_tensors_2 same_key_different_value_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) same_key_different_value_intermediate_tensors_2 = IntermediateTensors( - {"1": torch.zeros([2, 5], dtype=torch.int32)}) - assert (same_key_different_value_intermediate_tensors_1 - != same_key_different_value_intermediate_tensors_2) + {"1": torch.zeros([2, 5], dtype=torch.int32)} + ) + assert ( + same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2 + ) same_key_same_value_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) same_key_same_value_intermediate_tensors_2 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) - assert (same_key_same_value_intermediate_tensors_1 == - same_key_same_value_intermediate_tensors_2) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) + assert ( + same_key_same_value_intermediate_tensors_1 + == same_key_same_value_intermediate_tensors_2 + ) diff --git a/tests/test_test.py b/tests/test_test.py deleted file mode 100644 index dc8c9814ede39..0000000000000 --- a/tests/test_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm import LLM, envs -from vllm.sampling_params import SamplingParams - -if not envs.VLLM_USE_V1: - pytest.skip( - "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", - allow_module_level=True, - ) - - -@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) -# TODO TPU will appear busy if we fan-out test params here -@pytest.mark.parametrize("n_prompts", [1]) -def test_logprobs(model_name: str, n_prompts: int): - """ - Request top logprobs with different sampling settings and check - that results contains the requested number, ordered ascendingly. 
- """ - - def check_num_logprobs(logprobs, expected_num: int): - for step in logprobs: - prev_logp = 1.0 - # order by rank - sorted_step = dict( - sorted(step.items(), key=lambda item: item[1].rank)) - - if len(step) != expected_num: - print("watch out", sorted_step) - - # check results are ordered by prob value - # assert len(step) == expected_num - for rankno, (tid, logp) in enumerate(sorted_step.items()): - assert logp.logprob <= prev_logp - prev_logp = logp.logprob - assert logp.rank == rankno + 1 - - llm = LLM(model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=128, - max_num_batched_tokens=128) - prompts = [ - "Write a short story about a robot that dreams for the first time." - ] * n_prompts - greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\ - logprobs=4) - regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4) - topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4, top_k=12, top_p=0.5) - - for sp in [greedy_sampling_params, regular_sampling_params, \ - topkp_sampling_params]: - output = llm.generate(prompts, sp) - for o in output: - check_num_logprobs(o.outputs[0].logprobs, 4) diff --git a/tests/test_triton_utils.py b/tests/test_triton_utils.py index 64f72668f29ce..7fe0a5d9c5176 100644 --- a/tests/test_triton_utils.py +++ b/tests/test_triton_utils.py @@ -5,8 +5,7 @@ import sys import types from unittest import mock -from vllm.triton_utils.importing import (TritonLanguagePlaceholder, - TritonPlaceholder) +from vllm.triton_utils.importing import TritonLanguagePlaceholder, TritonPlaceholder def test_triton_placeholder_is_module(): @@ -52,8 +51,7 @@ def test_triton_placeholder_decorators_with_args(): def bar(x): return x - @triton.heuristics( - {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64}) + @triton.heuristics({"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64}) def baz(x): return x @@ -69,6 +67,8 @@ def test_triton_placeholder_language(): assert lang.constexpr is None assert lang.dtype is None assert lang.int64 is None + assert lang.int32 is None + assert lang.tensor is None def test_triton_placeholder_language_from_parent(): @@ -87,6 +87,7 @@ def test_no_triton_fallback(): # mock triton not being installed with mock.patch.dict(sys.modules, {"triton": None}): from vllm.triton_utils import HAS_TRITON, tl, triton + assert HAS_TRITON is False assert triton.__class__.__name__ == "TritonPlaceholder" assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder" diff --git a/tests/test_version.py b/tests/test_version.py index fd07abb59b1f8..928f742f1de8f 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -31,7 +31,8 @@ def test_version_tuple(): ((1, 0, 0), "1.-1", True), ((1, 0, 0), "0.9", False), ((1, 0, 0), "0.17", False), - ]) + ], +) def test_prev_minor_version_was(version_tuple, version_str, expected): with patch("vllm.version.__version_tuple__", version_tuple): assert version._prev_minor_version_was(version_str) == expected diff --git a/tests/test_vllm_port.py b/tests/test_vllm_port.py index 88e1efd8fdbb6..68bd511635dc1 100644 --- a/tests/test_vllm_port.py +++ b/tests/test_vllm_port.py @@ -23,14 +23,17 @@ def test_get_vllm_port_valid(): def test_get_vllm_port_invalid(): """Test when VLLM_PORT is set to a non-integer value.""" - with (patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True), - pytest.raises(ValueError, match="must be a valid integer")): + with ( + patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True), + 
pytest.raises(ValueError, match="must be a valid integer"), + ): get_vllm_port() def test_get_vllm_port_uri(): """Test when VLLM_PORT is set to a URI.""" - with (patch.dict(os.environ, {"VLLM_PORT": "tcp://localhost:5678"}, - clear=True), - pytest.raises(ValueError, match="appears to be a URI")): + with ( + patch.dict(os.environ, {"VLLM_PORT": "tcp://localhost:5678"}, clear=True), + pytest.raises(ValueError, match="appears to be a URI"), + ): get_vllm_port() diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index 07217611ea4d2..074039f9e5134 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -6,17 +6,16 @@ from copy import deepcopy import pytest from transformers import AutoTokenizer -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - get_cached_tokenizer) +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"]) def test_cached_tokenizer(model_id: str): - reference_tokenizer = AutoTokenizer.from_pretrained(model_id, - trust_remote_code=True) + reference_tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=True + ) reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"}) - reference_tokenizer.add_special_tokens( - {"additional_special_tokens": ["<SEP>"]}) + reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]}) cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) _check_consistency(cached_tokenizer, reference_tokenizer) @@ -32,13 +31,13 @@ def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer): # Cached attributes assert target.all_special_ids == expected.all_special_ids assert target.all_special_tokens == expected.all_special_tokens - assert (target.all_special_tokens_extended == - expected.all_special_tokens_extended) + assert target.all_special_tokens_extended == expected.all_special_tokens_extended assert target.get_vocab() == expected.get_vocab() assert len(target) == len(expected) # Other attributes - assert getattr(target, "padding_side", - None) == getattr(expected, "padding_side", None) + assert getattr(target, "padding_side", None) == getattr( + expected, "padding_side", None + ) assert target.encode("prompt") == expected.encode("prompt") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index ccafc88461275..14dcab7707d4e 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -5,18 +5,16 @@ from collections.abc import Generator from typing import Any, Optional import pytest -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from vllm.inputs import token_inputs -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, - IncrementalDetokenizer, - SlowIncrementalDetokenizer) +from vllm.v1.engine.detokenizer import ( + FastIncrementalDetokenizer, + IncrementalDetokenizer, + SlowIncrementalDetokenizer, +) 
SPECIAL_TOKS_TRUTH = [ "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa @@ -48,35 +46,35 @@ TOKENIZERS = [ ] -def _run_incremental_decode(tokenizer, - all_input_ids, - skip_special_tokens: bool, - starting_index: int, - spaces_between_special_tokens: bool = True, - fast: Optional[bool] = None): - +def _run_incremental_decode( + tokenizer, + all_input_ids, + skip_special_tokens: bool, + starting_index: int, + spaces_between_special_tokens: bool = True, + fast: Optional[bool] = None, +): prompt_token_ids = all_input_ids[:starting_index] params = SamplingParams( skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest("", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, - cache_salt=None, - data_parallel_rank=None) + request = EngineCoreRequest( + request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) if fast is None: - detokenizer = IncrementalDetokenizer.from_new_request( - tokenizer, request) + detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) elif fast: detokenizer = FastIncrementalDetokenizer(tokenizer, request) else: @@ -93,9 +91,11 @@ def _run_incremental_decode(tokenizer, @pytest.fixture def tokenizer(tokenizer_name): - return (MistralTokenizer.from_pretrained(tokenizer_name) - if "mistral" in tokenizer_name else - AutoTokenizer.from_pretrained(tokenizer_name)) + return ( + MistralTokenizer.from_pretrained(tokenizer_name) + if "mistral" in tokenizer_name + else AutoTokenizer.from_pretrained(tokenizer_name) + ) @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"]) @@ -107,7 +107,8 @@ def tokenizer(tokenizer_name): "ပုံပြင်လေးပြောပြပါ", # Using "URGENCY" since "CY" has token id 130282 "URGENCY🌶️", - ]) + ], +) def test_mistral_edge_case(tokenizer, truth): """Test for a specific edge cases with V3-Tekken MistralTokenizer. 
@@ -120,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth): tokenizer, all_input_ids, skip_special_tokens=True, - starting_index=starting_index) + starting_index=starting_index, + ) assert decoded_text == truth assert out_ids == all_input_ids[starting_index:] @@ -129,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth): def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: if "mistral" in tokenizer_name: yield ( - True if request.param else - pytest.skip("mistral doesn't support skip_special_tokens=False")) + True + if request.param + else pytest.skip("mistral doesn't support skip_special_tokens=False") + ) else: yield bool(request.param) @@ -141,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) @pytest.mark.parametrize("spaces_between_special_tokens", (True, False)) @pytest.mark.parametrize("fast", (True, False)) -def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, - spaces_between_special_tokens, fast): +def test_decode_streaming( + tokenizer, + truth, + with_prompt, + skip_special_tokens, + spaces_between_special_tokens, + fast, +): if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): pytest.skip() @@ -151,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, if not fast and isinstance(tokenizer, PreTrainedTokenizerFast): # Fix up inconsistency in fast/slow tokenizer behaviour. - tokenizer.add_special_tokens({ - "additional_special_tokens": [ - at for at in - tokenizer._tokenizer.get_added_tokens_decoder().values() - if at.special - ] - }) + tokenizer.add_special_tokens( + { + "additional_special_tokens": [ + at + for at in tokenizer._tokenizer.get_added_tokens_decoder().values() + if at.special + ] + } + ) - extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \ + extra_decode_args = ( + {} + if not isinstance(tokenizer, PreTrainedTokenizer) else {"spaces_between_special_tokens": spaces_between_special_tokens} + ) truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids if tokenizer.bos_token_id is not None: truth_tokens.insert(0, tokenizer.bos_token_id) truth_tokens.append(tokenizer.eos_token_id) - new_truth = tokenizer.decode(truth_tokens, - skip_special_tokens=skip_special_tokens, - **extra_decode_args) + new_truth = tokenizer.decode( + truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args + ) if with_prompt: num_prompt_tokens = len( - tokenizer(truth[:len(truth) // 2], - add_special_tokens=False).input_ids) + tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids + ) if tokenizer.bos_token_id is not None: num_prompt_tokens += 1 @@ -182,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, generated_input_ids = truth_tokens[num_prompt_tokens:] all_input_ids = prompt_input_ids + generated_input_ids starting_index = len(prompt_input_ids) - prompt = tokenizer.decode(prompt_input_ids, - skip_special_tokens=skip_special_tokens, - **extra_decode_args) + prompt = tokenizer.decode( + prompt_input_ids, + skip_special_tokens=skip_special_tokens, + **extra_decode_args, + ) - generated = new_truth[len(prompt):] + generated = new_truth[len(prompt) :] else: generated = new_truth starting_index = 0 @@ -198,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, skip_special_tokens=skip_special_tokens, 
starting_index=starting_index, spaces_between_special_tokens=spaces_between_special_tokens, - fast=fast) + fast=fast, + ) assert decoded_text == generated assert out_ids == all_input_ids[starting_index:] @@ -211,205 +229,13 @@ def test_oov_decode(tokenizer, fast): pytest.skip() decoded_text, out_ids = _run_incremental_decode( - tokenizer, [len(tokenizer)], + tokenizer, + [len(tokenizer)], skip_special_tokens=True, starting_index=0, spaces_between_special_tokens=True, - fast=fast) + fast=fast, + ) - assert decoded_text == '' + assert decoded_text == "" assert out_ids == [len(tokenizer)] - - -@pytest.fixture -def detokenizer(tokenizer_name: str) -> Detokenizer: - tokenizer_group = TokenizerGroup( - tokenizer_id=tokenizer_name, - enable_lora=False, - max_num_seqs=100, - max_input_length=None, - tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", - trust_remote_code=False, - revision=None, - ) - - return Detokenizer(tokenizer_group) - - -@pytest.fixture(name="complete_sequence_token_ids") -def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer) -> list[int]: - return tokenizer(complete_sequence, add_special_tokens=False).input_ids - - -def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [] - return Sequence( - seq_id=0, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - -def create_dummy_logprobs( - complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: - return [{ - token_id: Logprob(logprob=0.0), - token_id + 1: Logprob(logprob=0.1) - } for token_id in complete_sequence_token_ids] - - -def create_dummy_prompt_logprobs( - complete_sequence_token_ids: list[int] -) -> list[Optional[dict[int, Any]]]: - # logprob for the first prompt token is None. - logprobs: list[Optional[dict[int, Any]]] = [None] - logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) - return logprobs - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) -def test_decode_sequence_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer, - skip_special_tokens: bool): - """Verify Detokenizer decodes logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - logprobs=2) - - # Run sequentially. - seq = create_sequence() - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) - sequential_logprobs_text_chosen_token: list[str] = [] - sequential_logprobs_text_other_token: list[str] = [] - for new_token, logprobs in zip(complete_sequence_token_ids, - dummy_logprobs): - seq.append_token_id(new_token, logprobs) - detokenizer.decode_sequence_inplace(seq, sampling_params) - sequential_logprobs_text_chosen_token.append( - seq.output_logprobs[-1][new_token].decoded_token) - sequential_logprobs_text_other_token.append( - seq.output_logprobs[-1][new_token + 1].decoded_token) - sequential_result = seq.output_text - - assert sequential_result == "".join(sequential_logprobs_text_chosen_token) - assert sequential_result != "".join(sequential_logprobs_text_other_token) - - if not skip_special_tokens: - # Text for logprobs for the chosen token should be the same as the - # generated text. Note that this will only be true if we skip - # special tokens. 
- assert sequential_result == complete_sequence - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer): - - # We want to use skip_special_tokens=False here but Mistral tokenizers - # don't support that. - if complete_sequence not in SPECIAL_TOKS_TRUTH: - skip_special_tokens = True - elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None), - MistralTokenizer): - skip_special_tokens = False - else: - pytest.skip("MistralTokenizers don't support " - "skip_special_tokens=False") - return - """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - prompt_logprobs=1) - - # Run sequentially. - seq = create_sequence(complete_sequence_token_ids) - seq_group = SequenceGroup(request_id="1", - seqs=[seq], - sampling_params=sampling_params, - arrival_time=0.0) - dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, - dummy_logprobs, - position_offset=0) - # First logprob is None. - decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ - 1:] # type: ignore - - # decoded_prompt_logprobs doesn't contain the first token. - token_ids = complete_sequence_token_ids - tokenizer = detokenizer.get_tokenizer_for_seq(seq) - text_full = tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) - text_first = tokenizer.decode(token_ids[0], - skip_special_tokens=skip_special_tokens) - text = text_full[len(text_first):] - - # Text for logprobs for the chosen token should be the same as the - # prompt text. Note that the first logprob is None. - assert text == "".join([ - logprobs[token_id].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - assert text != "".join([ - logprobs[token_id + 1].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1]) -def test_decode_prompt_logprobs_chunked_prefill( - vllm_runner, - model, - chunked_prefill_token_size: int, - example_prompts, - monkeypatch, -): - # VLLM V1 does not use incremental detokenization for - # prompt logprobs, so this test strategy is irrelevant. - monkeypatch.setenv("VLLM_USE_V1", "0") - - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - - with vllm_runner(model, - dtype="half", - max_logprobs=5, - gpu_memory_utilization=0.5, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - - vllm_sampling_params = SamplingParams(max_tokens=10, - logprobs=5, - prompt_logprobs=5, - temperature=0.0) - vllm_results = vllm_model.llm.generate( - example_prompts, sampling_params=vllm_sampling_params) - - for idx, result in enumerate(vllm_results): - assert result.prompt_logprobs is not None - assert result.prompt_logprobs[0] is None - - # Compared detokenized prompts ids to original prompt. 
- generated_string = "" - for (prompt_token, - prompt_logprobs) in zip(result.prompt_token_ids[1:], - result.prompt_logprobs[1:]): - # prompt_logprobs is a dict of the token_id: logprob - # We select the token_id corresponding to the actual prompt - # Decoded token in the detokenized string corresponding to this - # prompt token. - generated_string += prompt_logprobs[prompt_token].decoded_token - - assert generated_string == example_prompts[idx], ( - "Detokenized prompt logprobs do not match original prompt") diff --git a/tests/tokenization/test_do_lower_case.py b/tests/tokenization/test_do_lower_case.py index 7aa655e1c3b45..8aff50b351e31 100644 --- a/tests/tokenization/test_do_lower_case.py +++ b/tests/tokenization/test_do_lower_case.py @@ -13,6 +13,6 @@ TOKENIZER_NAMES = ["BAAI/bge-base-en"] def test_special_tokens(tokenizer_name: str, n_tokens: int): tokenizer = get_tokenizer(tokenizer_name, revision="main") - prompts = '[UNK]' * n_tokens + prompts = "[UNK]" * n_tokens prompt_token_ids = tokenizer.encode(prompts) assert len(prompt_token_ids) == n_tokens + 2 diff --git a/tests/tokenization/test_get_eos.py b/tests/tokenization/test_get_eos.py index d8288429351c4..921d77b1b335e 100644 --- a/tests/tokenization/test_get_eos.py +++ b/tests/tokenization/test_get_eos.py @@ -5,6 +5,7 @@ This test file includes some cases where it is inappropriate to only get the `eos_token_id` from the tokenizer as defined by {meth}`vllm.LLMEngine._get_eos_token_id`. """ + from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer import get_tokenizer @@ -15,8 +16,7 @@ def test_get_llama3_eos_token(): tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 128009 - generation_config = try_get_generation_config(model_name, - trust_remote_code=False) + generation_config = try_get_generation_config(model_name, trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == [128001, 128008, 128009] @@ -27,7 +27,6 @@ def test_get_blip2_eos_token(): tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 2 - generation_config = try_get_generation_config(model_name, - trust_remote_code=False) + generation_config = try_get_generation_config(model_name, trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == 50118 diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index 69b3c6294284b..ebf107217c3cb 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -1,188 +1,2209 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import pytest -from mistral_common.protocol.instruct.messages import (AssistantMessage, - ToolMessage, - UserMessage) -from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.protocol.instruct.tool_calls import (Function, - FunctionCall, Tool, - ToolCall) +from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.transformers_utils.tokenizers.mistral import ( - make_mistral_chat_completion_request) + MistralTokenizer, + _prepare_apply_chat_template_tools_and_messages, +) @pytest.mark.parametrize( - "openai_request,expected_mistral_request", - [( - { - "messages": [{ - "role": "user", - "content": "What is the current local 
date and time?", - }], - "tools": [{ - "type": "function", - "function": { - "description": "Fetch the current local date and time.", - "name": "get_current_time", - }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") - ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) - ], + "openai_request,expected_mistral_output", + [ + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + }, + } + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } + ], + ), ), - ), - ( - { - "messages": - [{ - "role": "user", - "content": "What is the current local date and time?", - }], - "tools": [{ - "type": "function", - "function": { - "description": "Fetch the current local date and time.", - "name": "get_current_time", - "parameters": None, - }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage( - content="What is the current local date and time?") - ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) - ], - ), - )], + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } + ], + ), + ), + ], ) -def test_make_mistral_chat_completion_request(openai_request, - expected_mistral_request): - actual_request = make_mistral_chat_completion_request( - openai_request["messages"], openai_request["tools"]) - assert actual_request == expected_mistral_request +def test_prepare_apply_chat_template_tools_and_messages( + openai_request, expected_mistral_output +): + actual_request = _prepare_apply_chat_template_tools_and_messages( + openai_request["messages"], openai_request["tools"] + ) + assert actual_request == expected_mistral_output # Tool use with list content and reasoning_content -@pytest.mark.parametrize("openai_request,expected_mistral_request", [( - { - "messages": [ +@pytest.mark.parametrize( + "openai_request,expected_mistral_output", + [ + ( { - "role": "user", - "content": "What's the weather in Paris?", - }, - { - "role": - "assistant", - "reasoning_content": - None, - "content": - None, - "tool_calls": [{ - "id": "call123", - "type": "function", - "function": { + "messages": [ + { + "role": "user", + "content": "What's the weather in Paris?", + }, + { + "role": "assistant", + "reasoning_content": None, + "content": None, + "tool_calls": [ + { + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + 
"content": [{"type": "text", "text": "Rainy"}], "name": "get_weather", - "arguments": '{"city": "Paris"}', + "tool_call_id": "call123", }, - }], + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], }, - { - "role": "tool", - "content": [{ - "type": "text", - "text": "Rainy" - }], - "name": "get_weather", - "tool_call_id": "call123", - }, - ], - "tools": [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Gets the current weather in a city.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } + ( + [ + { + "role": "user", + "content": "What's the weather in Paris?", }, - "required": ["city"], - }, - }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What's the weather in Paris?"), - AssistantMessage( - content=None, - tool_calls=[ - ToolCall( - id="call123", - function=FunctionCall( - name="get_weather", - arguments='{"city": "Paris"}', - ), - ) + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "content": [{"type": "text", "text": "Rainy"}], + "name": "get_weather", + "tool_call_id": "call123", + }, + ], + [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } ], ), - ToolMessage( - content="Rainy", - tool_call_id="call123", - name="get_weather", + ) + ], +) +def test_prepare_apply_chat_template_tools_and_messages_list_content( + openai_request, expected_mistral_output +): + actual_request = _prepare_apply_chat_template_tools_and_messages( + openai_request["messages"], openai_request["tools"] + ) + assert actual_request == expected_mistral_output + + +def test_prepare_apply_chat_template_generation_prompt_and_continue(): + messages = [{"role": "assistant", "content": "Hello"}] + tools: list[dict[str, Any]] = [] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + + messages = [{"role": "user", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + assert out_messages == [{"role": "user", "content": "Hello"}] + + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True, continue_final_message=True + ) + + messages = [{"role": "assistant", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=False, continue_final_message=True + ) + assert out_messages == [{"role": "assistant", "content": "Hello"}] + + messages = [{"role": "user", "content": "Hello"}] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=False, continue_final_message=True + ) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(request) -> 
MistralTokenizer: + return MistralTokenizer.from_pretrained(request.param) + + +@pytest.mark.parametrize( + "mistral_tokenizer", + ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"], + indirect=True, +) +class TestMistralTokenizer: + def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer): + attributes = [ + mistral_tokenizer.all_special_tokens, + mistral_tokenizer.all_special_tokens_extended, + ] + + for attribute in attributes: + if mistral_tokenizer.is_tekken: + assert attribute == [ + "<unk>", + "<s>", + "</s>", + "[INST]", + "[/INST]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + "[TOOL_CALLS]", + "[IMG]", + "<pad>", + "[IMG_BREAK]", + "[IMG_END]", + "[PREFIX]", + "[MIDDLE]", + "[SUFFIX]", + "[SYSTEM_PROMPT]", + "[/SYSTEM_PROMPT]", + "[TOOL_CONTENT]", + ] + [f"<SPECIAL_{i}>" for i in range(20, 32)] + [ + "[ARGS]", + "[CALL_ID]", + "[THINK]", + "[/THINK]", + ] + [f"<SPECIAL_{i}>" for i in range(36, 1000)] + else: + assert attribute == [ + "<s>", + "</s>", + "[INST]", + "[/INST]", + "[TOOL_CALLS]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + ] + [f"[control_{i}]" for i in range(8, 769)] + + def get_vocab(self, mistral_tokenizer: MistralTokenizer): + assert ( + mistral_tokenizer.get_vocab() + == mistral_tokenizer.transformers_tokenizer.get_vocab() + ) + + def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer): + assert mistral_tokenizer.get_added_vocab() == {} + + def test_encode_one(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686] + ) + + assert mistral_tokenizer.encode_one("Hello world !") == token_ids + assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids + assert ( + mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode_one( + "Hello world !", truncation=False, max_length=1 + ) + == token_ids + ) + + def test_encode(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [1, 22177, 4304, 2662, 2] + if mistral_tokenizer.is_tekken + else [1, 23325, 2294, 1686, 2] + ) + + assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1] + assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2] + assert ( + mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3) + == token_ids[:-1] + ) + + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=True) + == token_ids + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, max_length=3 + ) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, truncation=False, max_length=3 + ) + == token_ids + ) + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=False) + == token_ids[1:-1] + ) + + @pytest.mark.parametrize( + "openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output", + [ + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + } + ], + }, + True, + False, + ([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]), + ("<s>[INST]▁Hello▁world▁![/INST]", ("<s>[INST]Hello world ![/INST]")), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an 
AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + }, + True, + False, + ( + [1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4], + [1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4], + ), + ( + "<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]", + ( + "<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]" # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + ], + ), + ( + '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]', + ( + '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "123456789", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "123456789", + "content": '{"temperature": 20, "unit": "celsius"}', + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather 
in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ), + ( + '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]', + ( + '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + }, + { + "role": 
"assistant", + "content": "Hello ", + }, + ], + }, + False, + True, + ( + [1, 3, 23325, 2294, 1686, 4, 23325], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ), + ( + "<s>[INST]▁Hello▁world▁![/INST]▁Hello", + ("<s>[INST]Hello world ![/INST]Hello</s>"), + ), ), ], - tools=[ - Tool( - type="function", - function=Function( - name="get_weather", - description="Gets the current weather in a city.", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } - }, - "required": ["city"], - }, - ), + ) + def test_apply_chat_template( + self, + mistral_tokenizer: MistralTokenizer, + openai_request: dict[str, Any], + add_generation_prompt: bool, + continue_final_message: bool, + expected_output: tuple[list[int], list[int]], + decoded_expected_output: tuple[str, str], + ): + actual_output = mistral_tokenizer.apply_chat_template( + openai_request["messages"], + tools=openai_request.get("tools", []), + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + decoded_actual_output = mistral_tokenizer.tokenizer.decode( + actual_output, SpecialTokenPolicy.KEEP + ) + + assert actual_output == expected_output[mistral_tokenizer.is_tekken] + assert ( + decoded_actual_output + == decoded_expected_output[mistral_tokenizer.is_tekken] + ) + + def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer): + messages = [{"role": "user", "content": "Hello world !"}] + + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=True, ) - ], - ), -)]) -def test_make_mistral_chat_completion_request_list_content( - openai_request, expected_mistral_request): - actual_request = make_mistral_chat_completion_request( - openai_request["messages"], openai_request["tools"]) - assert actual_request == expected_mistral_request + + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=True, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=False, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(InvalidMessageStructureException): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=False, + ) + + @pytest.mark.parametrize( + "skip_special_tokens,expected_tokens", + ( + ( + False, + ( + "<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>", + "<s>[INST]Hello world ![/INST]Hello</s>", + ), + ), + (True, ("Hello world ! 
Hello", "Hello world !Hello")), + ), + ) + def test_decode( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + expected_tokens: tuple[str, str], + ): + ids = ( + [1, 3, 23325, 2294, 1686, 4, 23325, 2], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ) + assert ( + mistral_tokenizer.decode( + ids[mistral_tokenizer.is_tekken], + skip_special_tokens=skip_special_tokens, + ) + == expected_tokens[mistral_tokenizer.is_tekken] + ) + + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): + tokens = ( + [ + "<s>", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "</s>", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "<s>", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "</s>", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ) + + expected_strings = ( + '[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, 
"id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}', # noqa: E501 + 'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}', # noqa: E501 + ) + + assert ( + mistral_tokenizer.convert_tokens_to_string( + tokens[mistral_tokenizer.is_tekken] + ) + == expected_strings[mistral_tokenizer.is_tekken] + ) + + @pytest.mark.parametrize( + "skip_special_tokens,tuple_expected_tokens", + ( + ( + True, + ( + [ + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + ], + [ + "I", + " am", + " an", + " AI", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "Hello", + " world", + " !", + "[TOOL_CALLS]", + "get", + "_", + "weather", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + ], + ), + ), + ( + False, + ( + [ + "<s>", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + 
'":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "</s>", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "<s>", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "</s>", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ), + ), + ), + ) + def test_convert_ids_to_tokens( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + tuple_expected_tokens: tuple[list[str], list[str]], + ): + tuple_ids = ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 1113, + 
29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ) + + ids = tuple_ids[mistral_tokenizer.is_tekken] + expected_tokens = tuple_expected_tokens[mistral_tokenizer.is_tekken] + actual_tokens = mistral_tokenizer.convert_ids_to_tokens( + ids, skip_special_tokens=skip_special_tokens + ) + assert actual_tokens == expected_tokens diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py index 09a3638fd2ed1..e86bb03883b5e 100644 --- a/tests/tokenization/test_tokenizer.py +++ b/tests/tokenization/test_tokenizer.py @@ -19,5 +19,5 @@ def test_tokenizer_revision(tokenizer_name: str): assert isinstance(tokenizer, PreTrainedTokenizerBase) # Assume that "never" branch always does not exist - with pytest.raises(OSError, match='not a valid git identifier'): + with pytest.raises(OSError, match="not a valid git identifier"): get_tokenizer(tokenizer_name, revision="never") diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py deleted file mode 100644 index 0570c1525e111..0000000000000 --- a/tests/tokenization/test_tokenizer_group.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - - -@pytest.mark.asyncio -async def test_tokenizer_group(): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( - prompt="prompt", lora_request=None) - assert reference_tokenizer.encode( - "prompt") == await tokenizer_group.encode_async(prompt="prompt", - lora_request=None) - assert isinstance(tokenizer_group.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - None) == await tokenizer_group.get_lora_tokenizer_async(None) diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py index 5abb101644086..de67c3e798c4e 100644 --- a/tests/tokenization/test_tokenizer_registry.py +++ b/tests/tokenization/test_tokenizer_registry.py @@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, Optional, Union from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.transformers_utils.tokenizer_base import (TokenizerBase, - TokenizerRegistry) +from vllm.transformers_utils.tokenizer_base import 
TokenizerBase, TokenizerRegistry if TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam class TestTokenizer(TokenizerBase): - @classmethod def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer": return TestTokenizer() @@ -57,6 +55,10 @@ class TestTokenizer(TokenizerBase): def max_token_id(self) -> int: raise NotImplementedError() + @property + def truncation_side(self) -> str: + raise NotImplementedError() + def __call__( self, text: Union[str, list[str], list[int]], @@ -81,23 +83,23 @@ class TestTokenizer(TokenizerBase): ) -> list[int]: raise NotImplementedError() - def encode(self, - text: str, - add_special_tokens: Optional[bool] = None) -> list[int]: + def encode(self, text: str, add_special_tokens: Optional[bool] = None) -> list[int]: raise NotImplementedError() - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: raise NotImplementedError() def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: raise NotImplementedError() def convert_ids_to_tokens( @@ -109,9 +111,9 @@ class TestTokenizer(TokenizerBase): def test_customized_tokenizer(): - TokenizerRegistry.register("test_tokenizer", - "tests.tokenization.test_tokenizer_registry", - "TestTokenizer") + TokenizerRegistry.register( + "test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer" + ) tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer") assert isinstance(tokenizer, TestTokenizer) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index 510b54790cd90..ff9cdeeb73752 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -13,13 +13,13 @@ from .utils import ARGS, CONFIGS, ServerConfig # select models to test based on command line arguments def pytest_addoption(parser): - parser.addoption("--models", - nargs="+", - help="Specify one or more models to test") - parser.addoption("--extended", - action="store_true", - default=False, - help="invoke extended tests requiring large GPUs") + parser.addoption("--models", nargs="+", help="Specify one or more models to test") + parser.addoption( + "--extended", + action="store_true", + default=False, + help="invoke extended tests requiring large GPUs", + ) # for each server config, download the model and return the config @@ -29,8 +29,10 @@ def server_config(request): models = request.config.getoption("--models") config_keys_to_test = [ - key for key in CONFIGS if (models is None or key in models) and ( - extended or not CONFIGS[key].get("extended", False)) + key + for key in CONFIGS + if (models is None or key in models) + and (extended or not CONFIGS[key].get("extended", False)) ] config_key = request.param @@ -40,8 +42,9 @@ def server_config(request): config = CONFIGS[config_key] if current_platform.is_rocm() and not config.get("supports_rocm", True): - pytest.skip("The {} model can't be tested on the ROCm platform".format( - config["model"])) + pytest.skip( + "The {} model can't be tested on the ROCm platform".format(config["model"]) + ) # download model and tokenizer using transformers 
snapshot_download(config["model"]) @@ -53,8 +56,9 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=480) as server: + with RemoteOpenAIServer( + model, ARGS + args_for_model, max_wait_seconds=480 + ) as server: yield server diff --git a/tests/runai_model_streamer_test/__init__.py b/tests/tool_use/mistral/__init__.py similarity index 100% rename from tests/runai_model_streamer_test/__init__.py rename to tests/tool_use/mistral/__init__.py diff --git a/tests/mistral_tool_use/conftest.py b/tests/tool_use/mistral/conftest.py similarity index 76% rename from tests/mistral_tool_use/conftest.py rename to tests/tool_use/mistral/conftest.py index e89e60c5a02ec..9b0a6eb27fca7 100644 --- a/tests/mistral_tool_use/conftest.py +++ b/tests/tool_use/mistral/conftest.py @@ -12,13 +12,14 @@ from .utils import ARGS, CONFIGS, ServerConfig # for each server config, download the model and return the config -@pytest.fixture(scope="session", params=CONFIGS.keys()) +@pytest.fixture(scope="package", params=CONFIGS.keys()) def server_config(request): config = CONFIGS[request.param] if current_platform.is_rocm() and not config.get("supports_rocm", True): - pytest.skip("The {} model can't be tested on the ROCm platform".format( - config["model"])) + pytest.skip( + "The {} model can't be tested on the ROCm platform".format(config["model"]) + ) # download model and tokenizer using transformers snapshot_download(config["model"]) @@ -26,12 +27,13 @@ def server_config(request): # run this for each server config -@pytest.fixture(scope="session") +@pytest.fixture(scope="package") def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=480) as server: + with RemoteOpenAIServer( + model, ARGS + args_for_model, max_wait_seconds=480 + ) as server: yield server diff --git a/tests/mistral_tool_use/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py similarity index 88% rename from tests/mistral_tool_use/test_mistral_tool_calls.py rename to tests/tool_use/mistral/test_mistral_tool_calls.py index 9bf6863f3f2b7..3c4a543abe412 100644 --- a/tests/mistral_tool_use/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -19,12 +19,12 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL], tool_choice=WEATHER_TOOL, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 1 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral diff --git a/tests/mistral_tool_use/utils.py b/tests/tool_use/mistral/utils.py similarity index 68% rename from tests/mistral_tool_use/utils.py rename to tests/tool_use/mistral/utils.py index 7a026cd9bb619..13a234f8e26be 100644 --- a/tests/mistral_tool_use/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -18,17 +18,16 @@ ARGS: list[str] = ["--max-model-len", "1024"] CONFIGS: dict[str, ServerConfig] = { "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "model": 
"mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--tokenizer-mode", "mistral", - "--ignore-patterns=\"consolidated.safetensors\"" + "--tokenizer-mode", + "mistral", + '--ignore-patterns="consolidated.safetensors"', ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, } diff --git a/tests/tool_use/test_chat_completion_request_validations.py b/tests/tool_use/test_chat_completion_request_validations.py index a30c58b09fe8f..50cd9e4279b2a 100644 --- a/tests/tool_use/test_chat_completion_request_validations.py +++ b/tests/tool_use/test_chat_completion_request_validations.py @@ -8,68 +8,56 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest def test_chat_completion_request_with_no_tools(): # tools key is not present - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + } + ) + assert request.tool_choice == "none" # tools key is None - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tools': - None - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tools": None, + } + ) + assert request.tool_choice == "none" # tools key present but empty - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tools': [] - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tools": [], + } + ) + assert request.tool_choice == "none" -@pytest.mark.parametrize('tool_choice', ['auto', 'required']) +@pytest.mark.parametrize("tool_choice", ["auto", "required"]) def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice): - with pytest.raises(ValueError, - match="When using `tool_choice`, `tools` must be set."): - ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tool_choice': - tool_choice - }) + with pytest.raises( + ValueError, match="When using `tool_choice`, `tools` must be set." + ): + ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tool_choice": tool_choice, + } + ) - with pytest.raises(ValueError, - match="When using `tool_choice`, `tools` must be set."): - ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tool_choice': - tool_choice, - 'tools': - None - }) + with pytest.raises( + ValueError, match="When using `tool_choice`, `tools` must be set." 
+ ): + ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tool_choice": tool_choice, + "tools": None, + } + ) diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index 8c01c86e29f2f..425d3879985e7 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -4,16 +4,21 @@ import openai import pytest -from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig, - ensure_system_prompt) +from .utils import ( + MESSAGES_WITHOUT_TOOLS, + WEATHER_TOOL, + ServerConfig, + ensure_system_prompt, +) # test: make sure chat completions without tools provided work even when tools # are enabled. This makes sure tool call chat templates work, AND that the tool # parser stream processing doesn't change the output of the model. @pytest.mark.asyncio -async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, - server_config: ServerConfig): +async def test_chat_completion_without_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +): models = await client.models.list() model_name: str = models.data[0].id chat_completion = await client.chat.completions.create( @@ -21,7 +26,8 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, temperature=0, max_completion_tokens=150, model=model_name, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content @@ -32,8 +38,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, assert stop_reason != "tool_calls" # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 # make the same request, streaming stream = await client.chat.completions.create( @@ -55,7 +60,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, # make sure the role is assistant if delta.role: assert not role_sent - assert delta.role == 'assistant' + assert delta.role == "assistant" role_sent = True if delta.content: @@ -80,8 +85,9 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools @pytest.mark.asyncio -async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, - server_config: ServerConfig): +async def test_chat_completion_with_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +): models = await client.models.list() model_name: str = models.data[0].id chat_completion = await client.chat.completions.create( @@ -90,19 +96,19 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, max_completion_tokens=150, model=model_name, tools=[WEATHER_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content # check to make sure we got text assert output_text is not None - assert stop_reason != 'tool_calls' + assert stop_reason != "tool_calls" assert len(output_text) > 0 # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) + assert choice.message.tool_calls is None or 
len(choice.message.tool_calls) == 0 # make the same request, streaming stream = await client.chat.completions.create( @@ -125,7 +131,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, # make sure the role is assistant if delta.role: - assert delta.role == 'assistant' + assert delta.role == "assistant" role_sent = True if delta.content: @@ -142,6 +148,6 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, assert role_sent assert finish_reason_count == 1 assert chunk.choices[0].finish_reason == stop_reason - assert chunk.choices[0].finish_reason != 'tool_calls' + assert chunk.choices[0].finish_reason != "tool_calls" assert len(chunks) assert "".join(chunks) == output_text diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_use/test_deepseekv31_tool_parser.py new file mode 100644 index 0000000000000..9b7e71b49c05b --- /dev/null +++ b/tests/tool_use/test_deepseekv31_tool_parser.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL = "deepseek-ai/DeepSeek-V3.1" + + +@pytest.fixture(scope="module") +def deepseekv31_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def parser(deepseekv31_tokenizer): + return DeepSeekV31ToolParser(deepseekv31_tokenizer) + + +def test_extract_tool_calls_with_tool(parser): + model_output = ( + "normal text" + + "<|tool▁calls▁begin|>" + + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + + "<|tool▁calls▁end|>" + ) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "foo" + assert result.tool_calls[0].function.arguments == '{"x":1}' + assert result.content == "normal text" + + +def test_extract_tool_calls_with_multiple_tools(parser): + model_output = ( + "some prefix text" + + "<|tool▁calls▁begin|>" + + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + + '<|tool▁call▁begin|>bar<|tool▁sep|>{"y":2}<|tool▁call▁end|>' + + "<|tool▁calls▁end|>" + + " some suffix text" + ) + + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called + assert len(result.tool_calls) == 2 + + assert result.tool_calls[0].function.name == "foo" + assert result.tool_calls[0].function.arguments == '{"x":1}' + + assert result.tool_calls[1].function.name == "bar" + assert result.tool_calls[1].function.arguments == '{"y":2}' + + # prefix is content + assert result.content == "some prefix text" diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py index 91913c933184e..6f1f6671d9b3c 100644 --- a/tests/tool_use/test_glm4_moe_tool_parser.py +++ b/tests/tool_use/test_glm4_moe_tool_parser.py @@ -10,6 +10,8 @@ from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser from vllm.transformers_utils.tokenizer import get_tokenizer +pytestmark = pytest.mark.cpu_test + pytest.skip("skip glm4_moe parser test", allow_module_level=True) # Use a common model that is likely to be available MODEL = "zai-org/GLM-4.5" @@ -25,12 +27,14 @@ def glm4_moe_tool_parser(glm4_moe_tokenizer): return Glm4MoeModelToolParser(glm4_moe_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - 
expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 0 @@ -45,7 +49,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): model_output = "This is a test" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -71,14 +76,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>fahrenheit</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], None, ), @@ -100,22 +109,30 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>fahrenheit</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), @@ -129,14 +146,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>celsius</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], "I'll help you check the weather.", ), @@ -150,37 +171,51 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>celsius</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "New York", - "state": "NY", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "New York", + "state": "NY", + "unit": "celsius", + } + ), + ) + ) ], None, ), - ("""I will help you get the weather.<tool_call>get_weather + ( + """I will help you get the weather.<tool_call>get_weather <arg_key>city</arg_key> <arg_value>Beijing</arg_value> <arg_key>date</arg_key> <arg_value>2025-08-01</arg_value> - </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - 
arguments=json.dumps({ - "city": "Beijing", - "date": "2025-08-01", - }), - )) - ], "I will help you get the weather."), + </tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + "date": "2025-08-01", + } + ), + ) + ) + ], + "I will help you get the weather.", + ), ], ) -def test_extract_tool_calls(glm4_moe_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -200,7 +235,8 @@ I will help you get the weather. </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 @@ -222,7 +258,8 @@ def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Should handle malformed XML gracefully # The parser should either extract what it can or return no tool calls @@ -237,12 +274,12 @@ def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "get_current_time" + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_time" # Empty arguments should result in empty JSON object assert extracted_tool_calls.tool_calls[0].function.arguments == "{}" @@ -268,7 +305,8 @@ meaningwhile, I will also check the weather in Shanghai. </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 2 @@ -319,8 +357,7 @@ def test_streaming_basic_functionality(glm4_moe_tool_parser): # The result behavior depends on the streaming state # This test mainly ensures no exceptions are thrown - assert result is None or hasattr(result, 'tool_calls') or hasattr( - result, 'content') + assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content") def test_streaming_no_tool_calls(glm4_moe_tool_parser): @@ -339,7 +376,7 @@ def test_streaming_no_tool_calls(glm4_moe_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
@@ -365,7 +402,7 @@ def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser): # Should return content when no tool call tokens are detected assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == "get the weather.<tool_call>" @@ -381,7 +418,8 @@ def test_extract_tool_calls_special_characters(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 @@ -402,7 +440,8 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser): <arg_value>2025-08-01</arg_value>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Incomplete tool calls should not be extracted assert not extracted_tool_calls.tools_called diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 35153139350bf..44d42bbd72b04 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -9,12 +9,13 @@ import partial_json_parser import pytest from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import JambaToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer +pytestmark = pytest.mark.cpu_test + MODEL = "ai21labs/Jamba-tiny-dev" @@ -28,12 +29,14 @@ def jamba_tool_parser(jamba_tokenizer): return JambaToolParser(jamba_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -42,10 +45,9 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def stream_delta_message_generator( - jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, - model_output: str) -> Generator[DeltaMessage, None, None]: - all_token_ids = jamba_tokenizer.encode(model_output, - add_special_tokens=False) + jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str +) -> Generator[DeltaMessage, None, None]: + all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -54,18 +56,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] + current_token_ids = all_token_ids[: i + 1] - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=jamba_tokenizer, - 
all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=jamba_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -82,8 +85,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = previous_tokens + new_tokens if previous_tokens\ - else new_tokens + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -91,7 +95,8 @@ def stream_delta_message_generator( def test_extract_tool_calls_no_tools(jamba_tool_parser): model_output = "This is a test" extracted_tool_calls = jamba_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -106,54 +111,63 @@ def test_extract_tool_calls_no_tools(jamba_tool_parser): argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - None), + None, + ), ( - ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " Sure! let me call the tool for you."), + " Sure! 
let me call the tool for you.", + ), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), ], - None) + None, + ), ], ) -def test_extract_tool_calls(jamba_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + jamba_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = jamba_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -170,63 +184,75 @@ def test_extract_tool_calls(jamba_tool_parser, model_output, ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ('''This is a test''', [], '''This is a test'''), + ("""This is a test""", [], """This is a test"""), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " "), + " ", + ), ( - ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " Sure! let me call the tool for you."), + " Sure! 
let me call the tool for you.", + ), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), ], - " ") + " ", + ), ], ) -def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, - model_output, expected_tool_calls, - expected_content): - other_content: str = '' +def test_extract_tool_calls_streaming( + jamba_tool_parser, + jamba_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + other_content: str = "" function_names: list[str] = [] function_args_strs: list[str] = [] tool_call_idx: int = -1 tool_call_ids: list[Optional[str]] = [] for delta_message in stream_delta_message_generator( - jamba_tool_parser, jamba_tokenizer, model_output): + jamba_tool_parser, jamba_tokenizer, model_output + ): # role should never be streamed from tool parser assert not delta_message.role @@ -262,18 +288,22 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, # make sure they're a string and then add them to the list assert isinstance(tool_call.function.arguments, str) - function_args_strs[ - tool_call.index] += tool_call.function.arguments + function_args_strs[tool_call.index] += tool_call.function.arguments assert other_content == expected_content actual_tool_calls = [ - ToolCall(id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=partial_json_parser.ensure_json( - function_args_str, Allow.OBJ | Allow.STR))) + ToolCall( + id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR + ), + ), + ) for tool_call_id, function_name, function_args_str in zip( - tool_call_ids, function_names, function_args_strs) + tool_call_ids, function_names, function_args_strs + ) ] assert_tool_calls(actual_tool_calls, expected_tool_calls) diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py index bd030632f167b..43feae4d865ed 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -10,6 +10,8 @@ from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser from vllm.transformers_utils.tokenizer import get_tokenizer +pytestmark = pytest.mark.cpu_test + # Use a common model that is likely to be available MODEL = "moonshotai/Kimi-K2-Instruct" @@ -24,27 +26,31 @@ def 
kimi_k2_tool_parser(kimi_k2_tokenizer): return KimiK2ToolParser(kimi_k2_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): - + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert actual_tool_call.type == "function" assert actual_tool_call.function == expected_tool_call.function # assert tool call id format assert actual_tool_call.id.startswith("functions.") - assert actual_tool_call.id.split(':')[-1].isdigit() - assert actual_tool_call.id.split('.')[1].split( - ':')[0] == expected_tool_call.function.name + assert actual_tool_call.id.split(":")[-1].isdigit() + assert ( + actual_tool_call.id.split(".")[1].split(":")[0] + == expected_tool_call.function.name + ) def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): model_output = "This is a test" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -61,14 +67,18 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""", [ - ToolCall(id='functions.get_weather:0', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Beijing", - }, ), - ), - type='function') + ToolCall( + id="functions.get_weather:0", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + }, + ), + ), + type="function", + ) ], "I'll help you check the weather. ", ), @@ -77,31 +87,41 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|> functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""", [ - ToolCall(id='functions.get_weather:0', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Beijing", - }, ), - ), - type='function'), - ToolCall(id='functions.get_weather:1', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Shanghai", - }, ), - ), - type='function') + ToolCall( + id="functions.get_weather:0", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + }, + ), + ), + type="function", + ), + ToolCall( + id="functions.get_weather:1", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Shanghai", + }, + ), + ), + type="function", + ), ], "I'll help you check the weather. 
", ), ], ) -def test_extract_tool_calls(kimi_k2_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -116,15 +136,14 @@ functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 2 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "invalid_get_weather" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "valid_get_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather" + assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather" def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser): @@ -134,13 +153,13 @@ functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "valid_get_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather" def test_streaming_basic_functionality(kimi_k2_tool_parser): @@ -168,8 +187,7 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_ # The result might be None or contain tool call information # This depends on the internal state management - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: assert len(result.tool_calls) >= 0 @@ -189,5 +207,5 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py index ddf26007121e5..8610656fa288d 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_use/test_minimax_tool_parser.py @@ -7,11 +7,16 @@ from typing import Any import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionToolsParam, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser from vllm.transformers_utils.tokenizer import get_tokenizer +pytestmark = pytest.mark.cpu_test + # Use a common model that is likely to be available MODEL = "MiniMaxAi/MiniMax-M1-40k" @@ -29,60 +34,48 @@ def minimax_tool_parser(minimax_tokenizer): @pytest.fixture def sample_tools(): return [ - ChatCompletionToolsParam(type="function", - function={ - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - }, - "state": { - "type": "string", - "description": - "The state code" - }, - "unit": { - "type": "string", - "enum": - ["fahrenheit", "celsius"] - } - }, - "required": ["city", "state"] - } - }), - ChatCompletionToolsParam(type="function", - function={ - "name": "calculate_area", - "description": - "Calculate area of a shape", - "parameters": { - "type": "object", - "properties": { - "shape": { - "type": "string" - }, - "dimensions": { - "type": "object" - }, - "precision": { - "type": "integer" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], + }, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -93,7 +86,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def test_extract_tool_calls_no_tools(minimax_tool_parser): model_output = "This is a test" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -114,14 +108,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Dallas", 
"state": "TX", "unit": "fahrenheit"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], None, ), @@ -131,22 +129,30 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), @@ -155,14 +161,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], "I'll help you check the weather.", ), @@ -171,14 +181,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "New York", - "state": "NY", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "New York", + "state": "NY", + "unit": "celsius", + } + ), + ) + ) ], None, ), @@ -186,22 +200,28 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): """<tool_calls> {"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Boston", - "state": "MA", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Boston", + "state": "MA", + } + ), + ) + ) ], None, ), ], ) -def test_extract_tool_calls(minimax_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + minimax_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -219,8 +239,7 @@ I'll help you with that. 
<tool_calls> {"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}} </tool_calls>""" - processed_output = minimax_tool_parser.preprocess_model_output( - model_output) + processed_output = minimax_tool_parser.preprocess_model_output(model_output) # The tool call within thinking tags should be removed assert "fake_tool" not in processed_output @@ -242,12 +261,12 @@ Let me help you with the weather. <tool_calls> </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "get_current_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" # Content extraction is based on the position of the first <tool_calls> in the original model_output # Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one @@ -268,14 +287,14 @@ def test_extract_tool_calls_invalid_json(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 2 assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "another_valid_tool" + assert extracted_tool_calls.tool_calls[1].function.name == "another_valid_tool" def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser): @@ -288,14 +307,14 @@ def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid tool calls with both name and arguments assert len(extracted_tool_calls.tool_calls) == 2 assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "another_valid_tool" + assert extracted_tool_calls.tool_calls[1].function.name == "another_valid_tool" def test_streaming_basic_functionality(minimax_tool_parser): @@ -324,8 +343,7 @@ def test_streaming_basic_functionality(minimax_tool_parser): # The result might be None or contain tool call information # This depends on the internal state management - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: assert len(result.tool_calls) >= 0 @@ -350,7 +368,7 @@ def test_streaming_with_content_before_tool_calls(minimax_tool_parser): request=None, ) - if result is not None and hasattr(result, 'content'): + if result is not None and hasattr(result, "content"): # Should contain some content assert result.content is not None @@ -371,7 +389,7 @@ def test_streaming_no_tool_calls(minimax_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
@@ -397,8 +415,7 @@ def test_streaming_with_thinking_tags(minimax_tool_parser): # The preprocessing should remove tool calls from thinking tags # and only process the real tool call - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: for tool_call in result.tool_calls: assert tool_call.function.name != "ignored" @@ -417,7 +434,8 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Multiline JSON is currently not supported, should return no tools called assert not extracted_tool_calls.tools_called @@ -447,7 +465,7 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', # Stage 6: Tool calls closed '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool', - '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>' + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>', ] function_name_sent = False @@ -455,8 +473,7 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): for i, current_text in enumerate(stages): previous_text = stages[i - 1] if i > 0 else "" - delta_text = current_text[len(previous_text - ):] if i > 0 else current_text + delta_text = current_text[len(previous_text) :] if i > 0 else current_text result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -471,30 +488,27 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): print(f"Stage {i}: Current text: {repr(current_text)}") print(f"Stage {i}: Delta text: {repr(delta_text)}") - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: tool_call = result.tool_calls[0] # Check if function name is sent (should happen only once) if tool_call.function and tool_call.function.name: assert tool_call.function.name == "get_current_weather" function_name_sent = True - print( - f"Stage {i}: Function name sent: {tool_call.function.name}" - ) + print(f"Stage {i}: Function name sent: {tool_call.function.name}") # Check if arguments are sent incrementally if tool_call.function and tool_call.function.arguments: args_fragment = tool_call.function.arguments - print( - f"Stage {i}: Got arguments fragment: {repr(args_fragment)}" - ) + print(f"Stage {i}: Got arguments fragment: {repr(args_fragment)}") # For incremental output, each fragment should be new content only # The fragment should not contain all previous content if i >= 2 and previous_args_content: # After we start getting arguments # The new fragment should not be identical to or contain all previous content - assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}" + assert args_fragment != previous_args_content, ( + f"Fragment should be incremental, not cumulative: {args_fragment}" + ) # If this is truly incremental, the fragment should be relatively small # compared to the complete 
arguments so far @@ -518,7 +532,9 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): minimax_tool_parser.streamed_args_for_tool = [] # Simulate two consecutive calls with growing arguments - call1_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + call1_text = ( + '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + ) call2_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}' print(f"Call 1 text: {repr(call1_text)}") @@ -536,7 +552,7 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): ) print(f"Result 1: {result1}") - if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls: + if result1 and hasattr(result1, "tool_calls") and result1.tool_calls: for i, tc in enumerate(result1.tool_calls): print(f" Tool call {i}: {tc}") @@ -552,13 +568,12 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): ) print(f"Result 2: {result2}") - if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls: + if result2 and hasattr(result2, "tool_calls") and result2.tool_calls: for i, tc in enumerate(result2.tool_calls): print(f" Tool call {i}: {tc}") # Verify the second call only returns the delta - if result2 is not None and hasattr(result2, - 'tool_calls') and result2.tool_calls: + if result2 is not None and hasattr(result2, "tool_calls") and result2.tool_calls: tool_call = result2.tool_calls[0] if tool_call.function and tool_call.function.arguments: args_delta = tool_call.function.arguments @@ -566,17 +581,21 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): # Should only contain the new part, not the full arguments # The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"' - assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}" + assert ( + ', "param2": "value2"}' in args_delta + or '"param2": "value2"' in args_delta + ), f"Expected delta containing param2, got: {args_delta}" # Should NOT contain the previous parameter data - assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}" + assert '"param1": "value1"' not in args_delta, ( + f"Arguments delta should not contain previous data: {args_delta}" + ) # The delta should be relatively short (incremental, not cumulative) - expected_max_length = len( - ', "param2": "value2"}') + 10 # Some tolerance - assert len( - args_delta - ) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}" + expected_max_length = len(', "param2": "value2"}') + 10 # Some tolerance + assert len(args_delta) <= expected_max_length, ( + f"Delta seems too long (possibly cumulative): {args_delta}" + ) print("✓ Delta validation passed") else: @@ -603,40 +622,39 @@ def test_streaming_openai_compatibility(minimax_tool_parser): # Test scenario: simple buffering without complex tool call context test_cases: list[dict[str, Any]] = [ { - 'stage': 'Token: <', - 'previous': '', - 'current': '<', - 'delta': '<', - 'expected_content': None, # Should be buffered + "stage": "Token: <", + "previous": "", + "current": "<", + "delta": "<", + "expected_content": None, # Should be buffered }, { - 'stage': 'Token: tool_calls>', - 'previous': '<', - 'current': '<tool_calls>', - 'delta': 'tool_calls>', - 'expected_content': None, # Complete tag, should not output + "stage": "Token: tool_calls>", + "previous": "<", + "current": "<tool_calls>", 
+ "delta": "tool_calls>", + "expected_content": None, # Complete tag, should not output }, { - 'stage': 'Regular content', - 'previous': 'Hello', - 'current': 'Hello world', - 'delta': ' world', - 'expected_content': ' world', # Normal content should pass through + "stage": "Regular content", + "previous": "Hello", + "current": "Hello world", + "delta": " world", + "expected_content": " world", # Normal content should pass through }, { - 'stage': 'Content with end tag start', - 'previous': 'Text', - 'current': 'Text content</tool_', - 'delta': ' content</tool_', - 'expected_content': - ' content', # Content part output, </tool_ buffered + "stage": "Content with end tag start", + "previous": "Text", + "current": "Text content</tool_", + "delta": " content</tool_", + "expected_content": " content", # Content part output, </tool_ buffered }, { - 'stage': 'Complete end tag', - 'previous': 'Text content</tool_', - 'current': 'Text content</tool_calls>', - 'delta': 'calls>', - 'expected_content': None, # Complete close tag, should not output + "stage": "Complete end tag", + "previous": "Text content</tool_", + "current": "Text content</tool_calls>", + "delta": "calls>", + "expected_content": None, # Complete close tag, should not output }, ] @@ -647,9 +665,9 @@ def test_streaming_openai_compatibility(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -659,15 +677,18 @@ def test_streaming_openai_compatibility(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content, got {result}" + ) print("✓ No content output as expected") else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + ) print(f"✓ Content matches: {repr(result.content)}") print("✓ Streaming test with buffering completed successfully") @@ -688,35 +709,26 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): # Test scenario: tool calls within thinking tags should be ignored test_cases: list[dict[str, Any]] = [ { - 'stage': 'Start thinking', - 'previous': '', - 'current': '<think>I need to use a tool. <tool_calls>', - 'delta': '<think>I need to use a tool. <tool_calls>', - 'expected_content': - '<think>I need to use a tool. <tool_calls>', # Should pass through as content + "stage": "Start thinking", + "previous": "", + "current": "<think>I need to use a tool. <tool_calls>", + "delta": "<think>I need to use a tool. <tool_calls>", + "expected_content": "<think>I need to use a tool. <tool_calls>", # Should pass through as content }, { - 'stage': - 'Tool call in thinking', - 'previous': - '<think>I need to use a tool. <tool_calls>', - 'current': - '<think>I need to use a tool. 
<tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', - 'delta': - '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', - 'expected_content': - '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags + "stage": "Tool call in thinking", + "previous": "<think>I need to use a tool. <tool_calls>", + "current": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + "delta": '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + "expected_content": '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags }, { - 'stage': 'Real tool call after thinking', - 'previous': - '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>', - 'current': - '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>', - 'delta': '\n<tool_calls>', - 'expected_content': - '\n', # Should output '\n' and suppress <tool_calls> - } + "stage": "Real tool call after thinking", + "previous": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>', + "current": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>', + "delta": "\n<tool_calls>", + "expected_content": "\n", # Should output '\n' and suppress <tool_calls> + }, ] for i, test_case in enumerate(test_cases): @@ -726,9 +738,9 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -738,25 +750,32 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if 'expected_content' in test_case: - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if "expected_content" in test_case: + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content, got {result}" + ) else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + ) print(f"✓ Content matches: {repr(result.content)}") # Check tool calls - if test_case.get('expected_tool_call'): - assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \ - f"Stage {i}: Expected tool call, got {result}" + if test_case.get("expected_tool_call"): + assert ( + result is not None + and hasattr(result, "tool_calls") + and result.tool_calls + ), f"Stage {i}: Expected tool call, got {result}" tool_call = 
result.tool_calls[0] - assert tool_call.function.name == "real_tool", \ + assert tool_call.function.name == "real_tool", ( f"Expected real_tool, got {tool_call.function.name}" + ) print(f"✓ Real tool call detected: {tool_call.function.name}") print("✓ Thinking tag buffering test completed successfully") @@ -782,104 +801,79 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): # Complex scenario: tools inside thinking tags and multiple tools in one group test_stages: list[dict[str, Any]] = [ { - 'stage': 'Initial content', - 'previous': '', - 'current': 'Let me help you with this task.', - 'delta': 'Let me help you with this task.', - 'expected_content': 'Let me help you with this task.', - 'expected_tool_calls': 0, + "stage": "Initial content", + "previous": "", + "current": "Let me help you with this task.", + "delta": "Let me help you with this task.", + "expected_content": "Let me help you with this task.", + "expected_tool_calls": 0, }, { - 'stage': 'Start thinking tag', - 'previous': 'Let me help you with this task.', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.', - 'delta': '<think>I need to analyze this situation first.', - 'expected_content': - '<think>I need to analyze this situation first.', - 'expected_tool_calls': 0, + "stage": "Start thinking tag", + "previous": "Let me help you with this task.", + "current": "Let me help you with this task.<think>I need to analyze this situation first.", + "delta": "<think>I need to analyze this situation first.", + "expected_content": "<think>I need to analyze this situation first.", + "expected_tool_calls": 0, }, { - 'stage': 'Tool call inside thinking tag starts', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', - 'delta': '<tool_calls>', - 'expected_content': - '<tool_calls>', # Inside thinking tags, tool tags should be preserved as content - 'expected_tool_calls': 0, + "stage": "Tool call inside thinking tag starts", + "previous": "Let me help you with this task.<think>I need to analyze this situation first.", + "current": "Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>", + "delta": "<tool_calls>", + "expected_content": "<tool_calls>", # Inside thinking tags, tool tags should be preserved as content + "expected_tool_calls": 0, }, { - 'stage': 'Complete tool call inside thinking tag', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'delta': - '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'expected_content': - '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'expected_tool_calls': - 0, # Tools inside thinking tags should be ignored + "stage": "Complete tool call inside thinking tag", + "previous": "Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>", + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "delta": '\n{"name": 
"internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "expected_content": '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "expected_tool_calls": 0, # Tools inside thinking tags should be ignored }, { - 'stage': 'End thinking tag', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', - 'delta': '</think>', - 'expected_content': '</think>', - 'expected_tool_calls': 0, + "stage": "End thinking tag", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + "delta": "</think>", + "expected_content": "</think>", + "expected_tool_calls": 0, }, { - 'stage': 'Multiple tools group starts', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', - 'delta': - '\nNow I need to get weather information and calculate area.<tool_calls>', - 'expected_content': - '\nNow I need to get weather information and calculate area.', # <tool_calls> should be filtered - 'expected_tool_calls': 0, + "stage": "Multiple tools group starts", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + "delta": "\nNow I need to get weather information and calculate area.<tool_calls>", + "expected_content": "\nNow I need to get weather information and calculate area.", # <tool_calls> should be filtered + "expected_tool_calls": 0, }, { - 'stage': 'First tool in group', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 
'delta': - '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 'expected_content': - None, # No content should be output when tool call is in progress - 'expected_tool_calls': 1, - 'expected_tool_name': 'get_current_weather', + "stage": "First tool in group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "delta": '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "expected_content": None, # No content should be output when tool call is in progress + "expected_tool_calls": 1, + "expected_tool_name": "get_current_weather", }, { - 'stage': 'Second tool in group', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'delta': - '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'expected_content': None, - 'expected_tool_calls': 1, - 'expected_tool_name': 'calculate_area', + "stage": "Second tool in group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "delta": '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "expected_content": None, + "expected_tool_calls": 1, + "expected_tool_name": "calculate_area", }, { - 'stage': 'Complete tool calls group', - 'previous': - 
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>', - 'delta': '</tool_calls>', - 'expected_content': None, - 'expected_tool_calls': 0, - } + "stage": "Complete tool calls group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>', + "delta": "</tool_calls>", + "expected_content": None, + "expected_tool_calls": 0, + }, ] tool_calls_count = 0 @@ -893,9 +887,9 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -905,53 +899,64 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content output, got {result}" + ) print("✓ No content output as expected") else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content output, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}" + 
) print(f"✓ Content matches: {repr(result.content)}") # Check tool calls - expected_tool_calls = test_case['expected_tool_calls'] - actual_tool_calls = len(result.tool_calls) if result and hasattr( - result, 'tool_calls') and result.tool_calls else 0 + expected_tool_calls = test_case["expected_tool_calls"] + actual_tool_calls = ( + len(result.tool_calls) + if result and hasattr(result, "tool_calls") and result.tool_calls + else 0 + ) if expected_tool_calls > 0: - assert actual_tool_calls >= expected_tool_calls, \ + assert actual_tool_calls >= expected_tool_calls, ( f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}" + ) - if 'expected_tool_name' in test_case: + if "expected_tool_name" in test_case: # Find the tool call with the expected name found_tool_call = None for tool_call in result.tool_calls: - if tool_call.function.name == test_case[ - 'expected_tool_name']: + if tool_call.function.name == test_case["expected_tool_name"]: found_tool_call = tool_call break - assert found_tool_call is not None, \ + assert found_tool_call is not None, ( f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}" + ) print(f"✓ Tool call correct: {found_tool_call.function.name}") # Ensure tools inside thinking tags are not called - assert found_tool_call.function.name != "internal_analysis", \ + assert found_tool_call.function.name != "internal_analysis", ( f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called" + ) tool_calls_count += actual_tool_calls print(f"✓ Detected {actual_tool_calls} tool calls") else: - assert actual_tool_calls == 0, \ + assert actual_tool_calls == 0, ( f"Stage {i}: Expected no tool calls, got {actual_tool_calls}" + ) # Verify overall results print("\n=== Test Summary ===") print(f"Total tool calls count: {tool_calls_count}") - assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + assert tool_calls_count >= 2, ( + f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + ) print("✓ Complex streaming test completed:") print(" - ✓ Tools inside thinking tags correctly ignored") @@ -985,8 +990,8 @@ Now I'll get the weather information for you. <tool_calls> # Stream character by character for i in range(1, len(complete_text) + 1): current_text = complete_text[:i] - previous_text = complete_text[:i - 1] if i > 1 else "" - delta_text = complete_text[i - 1:i] + previous_text = complete_text[: i - 1] if i > 1 else "" + delta_text = complete_text[i - 1 : i] # Show progress every 50 characters if i % 50 == 0 or i == len(complete_text): @@ -1005,36 +1010,35 @@ Now I'll get the weather information for you. 
<tool_calls> # Collect results if result is not None: - if hasattr(result, 'content') and result.content: + if hasattr(result, "content") and result.content: content_fragments.append(result.content) # Log important content fragments if any( - keyword in result.content for keyword in - ['<think>', '</think>', '<tool_calls>', '</tool_calls>']): - print( - f" Char {i}: Content fragment: {repr(result.content)}" - ) + keyword in result.content + for keyword in [ + "<think>", + "</think>", + "<tool_calls>", + "</tool_calls>", + ] + ): + print(f" Char {i}: Content fragment: {repr(result.content)}") - if hasattr(result, 'tool_calls') and result.tool_calls: + if hasattr(result, "tool_calls") and result.tool_calls: for tool_call in result.tool_calls: tool_info = { - 'character_position': - i, - 'function_name': - tool_call.function.name - if tool_call.function else None, - 'arguments': - tool_call.function.arguments - if tool_call.function else None, + "character_position": i, + "function_name": tool_call.function.name + if tool_call.function + else None, + "arguments": tool_call.function.arguments + if tool_call.function + else None, } tool_calls_detected.append(tool_info) - print( - f" Char {i}: Tool call detected: {tool_call.function.name}" - ) + print(f" Char {i}: Tool call detected: {tool_call.function.name}") if tool_call.function.arguments: - print( - f" Arguments: {repr(tool_call.function.arguments)}" - ) + print(f" Arguments: {repr(tool_call.function.arguments)}") # Verify results print("\n=== Streaming Test Results ===") @@ -1042,68 +1046,74 @@ Now I'll get the weather information for you. <tool_calls> print(f"Total tool calls detected: {len(tool_calls_detected)}") # Reconstruct content from fragments - reconstructed_content = ''.join(content_fragments) + reconstructed_content = "".join(content_fragments) print(f"Reconstructed content length: {len(reconstructed_content)}") # Verify thinking tags content is preserved - assert '<think>' in reconstructed_content, "Opening thinking tag should be preserved in content" - assert '</think>' in reconstructed_content, "Closing thinking tag should be preserved in content" + assert "<think>" in reconstructed_content, ( + "Opening thinking tag should be preserved in content" + ) + assert "</think>" in reconstructed_content, ( + "Closing thinking tag should be preserved in content" + ) # Verify that tool calls inside thinking tags are NOT extracted as actual tool calls thinking_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'internal_analysis' + tc for tc in tool_calls_detected if tc["function_name"] == "internal_analysis" ] - assert len( - thinking_tool_calls - ) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + assert len(thinking_tool_calls) == 0, ( + f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + ) # Verify that real tool calls outside thinking tags ARE extracted weather_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'get_current_weather' + tc for tc in tool_calls_detected if tc["function_name"] == "get_current_weather" ] area_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'calculate_area' + tc for tc in tool_calls_detected if tc["function_name"] == "calculate_area" ] print(tool_calls_detected) - assert len(weather_tool_calls - ) > 0, "get_current_weather tool call should be detected" - assert len( - area_tool_calls) > 0, "calculate_area tool call should be detected" + 
assert len(weather_tool_calls) > 0, ( + "get_current_weather tool call should be detected" + ) + assert len(area_tool_calls) > 0, "calculate_area tool call should be detected" # Verify tool call arguments are properly streamed - weather_args_found = any(tc['arguments'] for tc in weather_tool_calls - if tc['arguments']) - area_args_found = any(tc['arguments'] for tc in area_tool_calls - if tc['arguments']) + weather_args_found = any( + tc["arguments"] for tc in weather_tool_calls if tc["arguments"] + ) + area_args_found = any(tc["arguments"] for tc in area_tool_calls if tc["arguments"]) print(f"Weather tool call with arguments: {weather_args_found}") print(f"Area tool call with arguments: {area_args_found}") # Verify content before and after tool calls - assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved" - assert 'Here are the results.' in reconstructed_content, "Final content should be preserved" + assert "I'll help you with the weather analysis." in reconstructed_content, ( + "Initial content should be preserved" + ) + assert "Here are the results." in reconstructed_content, ( + "Final content should be preserved" + ) # Verify that <tool_calls> and </tool_calls> tags are not included in the final content # (they should be filtered out when not inside thinking tags) content_outside_thinking = reconstructed_content # Remove thinking tag content to check content outside - if '<think>' in content_outside_thinking and '</think>' in content_outside_thinking: - start_think = content_outside_thinking.find('<think>') - end_think = content_outside_thinking.find('</think>') + len('</think>') - content_outside_thinking = content_outside_thinking[: - start_think] + content_outside_thinking[ - end_think:] + if "<think>" in content_outside_thinking and "</think>" in content_outside_thinking: + start_think = content_outside_thinking.find("<think>") + end_think = content_outside_thinking.find("</think>") + len("</think>") + content_outside_thinking = ( + content_outside_thinking[:start_think] + + content_outside_thinking[end_think:] + ) # Outside thinking tags, tool_calls tags should be filtered - tool_calls_in_content = content_outside_thinking.count('<tool_calls>') - assert tool_calls_in_content == 0, f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" - - print( - "\n=== Character-by-character streaming test completed successfully ===" + tool_calls_in_content = content_outside_thinking.count("<tool_calls>") + assert tool_calls_in_content == 0, ( + f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" ) + + print("\n=== Character-by-character streaming test completed successfully ===") print("✓ Tool calls inside thinking tags correctly ignored") print("✓ Tool calls outside thinking tags correctly detected") print("✓ Content properly streamed and reconstructed") @@ -1111,8 +1121,7 @@ Now I'll get the weather information for you. 
<tool_calls> print("✓ Character-level streaming works correctly") -def test_streaming_character_by_character_simple_tool_call( - minimax_tool_parser): +def test_streaming_character_by_character_simple_tool_call(minimax_tool_parser): """Test character-by-character streaming for a simple tool call scenario.""" # Reset streaming state reset_streaming_state(minimax_tool_parser) @@ -1129,8 +1138,8 @@ def test_streaming_character_by_character_simple_tool_call( for i in range(1, len(simple_text) + 1): current_text = simple_text[:i] - previous_text = simple_text[:i - 1] if i > 1 else "" - delta_text = simple_text[i - 1:i] + previous_text = simple_text[: i - 1] if i > 1 else "" + delta_text = simple_text[i - 1 : i] result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -1143,19 +1152,17 @@ def test_streaming_character_by_character_simple_tool_call( ) if result: - if hasattr(result, 'content') and result.content: + if hasattr(result, "content") and result.content: content_parts.append(result.content) print( f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}" ) - if hasattr(result, 'tool_calls') and result.tool_calls: + if hasattr(result, "tool_calls") and result.tool_calls: for tool_call in result.tool_calls: if tool_call.function and tool_call.function.name: tool_name_sent = True - print( - f" Char {i}: Tool name: {tool_call.function.name}" - ) + print(f" Char {i}: Tool name: {tool_call.function.name}") if tool_call.function and tool_call.function.arguments: tool_args_sent = True print( @@ -1163,12 +1170,14 @@ def test_streaming_character_by_character_simple_tool_call( ) # Verify basic expectations - reconstructed_content = ''.join(content_parts) + reconstructed_content = "".join(content_parts) print(f"Final reconstructed content: {repr(reconstructed_content)}") assert tool_name_sent, "Tool name should be sent during streaming" assert tool_args_sent, "Tool arguments should be sent during streaming" - assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved" + assert "Let me check the weather." 
in reconstructed_content, ( + "Initial content should be preserved" + ) print("✓ Simple character-by-character test passed") @@ -1188,8 +1197,8 @@ def test_streaming_character_by_character_with_buffering(minimax_tool_parser): for i in range(1, len(buffering_text) + 1): current_text = buffering_text[:i] - previous_text = buffering_text[:i - 1] if i > 1 else "" - delta_text = buffering_text[i - 1:i] + previous_text = buffering_text[: i - 1] if i > 1 else "" + delta_text = buffering_text[i - 1 : i] result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -1201,16 +1210,18 @@ def test_streaming_character_by_character_with_buffering(minimax_tool_parser): request=None, ) - if result and hasattr(result, 'content') and result.content: + if result and hasattr(result, "content") and result.content: all_content.append(result.content) print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}") - final_content = ''.join(all_content) + final_content = "".join(all_content) print(f"Final content: {repr(final_content)}") # The parser should handle the edge case where </tool_calls> appears before <tool_calls> assert "Hello" in final_content, "Initial 'Hello' should be preserved" - assert "world" in final_content, "Content after false closing tag should be preserved" + assert "world" in final_content, ( + "Content after false closing tag should be preserved" + ) assert "done" in final_content, "Final content should be preserved" print("✓ Buffering character-by-character test passed") diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py new file mode 100644 index 0000000000000..f6223f3fdce4f --- /dev/null +++ b/tests/tool_use/test_openai_tool_parser.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest +from openai_harmony import ( + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + Role, + SystemContent, + load_harmony_encoding, +) + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL = "gpt2" + + +@pytest.fixture(scope="module") +def openai_tokenizer(): + # The parser does not use the tokenizer, but the constructor requires it. 
+ return get_tokenizer(MODEL) + + +@pytest.fixture +def openai_tool_parser(openai_tokenizer): + return OpenAIToolParser(openai_tokenizer) + + +@pytest.fixture(scope="module") +def harmony_encoding(): + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall], +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 16 # Default from protocol.py + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + +def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.SYSTEM, + SystemContent.new(), + ), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Talk like a pirate!"), + ), + Message.from_role_and_content(Role.USER, "Arrr, how be you?"), + Message.from_role_and_content( + Role.ASSISTANT, "This is a test" + ).with_channel("final"), + ] + ) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT + ) + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert not extracted_info.tools_called + assert extracted_info.tool_calls == [] + assert extracted_info.content == "This is a test" + + +@pytest.mark.parametrize( + "tool_args", + [ + '{"location": "Tokyo"}', + '{\n"location": "Tokyo"\n}', + ], +) +def test_extract_tool_calls_single_tool( + openai_tool_parser, harmony_encoding, tool_args +): + convo = Conversation.from_messages( + [ + Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, tool_args) + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + ] + ) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ) + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None + + +def test_extract_tool_calls_multiple_tools( + openai_tool_parser, + harmony_encoding, +): + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?" + ), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_user_location") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.no_content_type"), + Message.from_role_and_content(Role.ASSISTANT, "foo") + .with_channel("commentary") + .with_recipient("functions.not_json_no_content_type"), + Message.from_role_and_content(Role.ASSISTANT, "{}") + .with_channel("commentary") + .with_recipient("functions.empty_args") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, "") + .with_channel("commentary") + .with_recipient("functions.no_args") + .with_content_type("json"), + ] + ) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, + Role.ASSISTANT, + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="get_user_location", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="no_content_type", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="not_json_no_content_type", + arguments="foo", + ) + ), + ToolCall( + function=FunctionCall( + name="empty_args", + arguments=json.dumps({}), + ) + ), + ToolCall( + function=FunctionCall( + name="no_args", + arguments="", + ) + ), + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None + + +def test_extract_tool_calls_with_content( + openai_tool_parser, + harmony_encoding, +): + final_content = "This tool call will get the weather." + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?" + ), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel( + "final" + ), + ] + ) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, + Role.ASSISTANT, + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content == final_content diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index fff20c68d6212..159966365ec45 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -7,9 +7,13 @@ from typing import Optional import openai import pytest -from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, - WEATHER_TOOL, ServerConfig) +from .utils import ( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + SEARCH_TOOL, + WEATHER_TOOL, + ServerConfig, +) # test: getting the model to generate parallel tool calls (streaming/not) @@ -17,12 +21,15 @@ from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, # may be added in the future. e.g. llama 3.1 models are not designed to support # parallel tool calls. @pytest.mark.asyncio -async def test_parallel_tool_calls(client: openai.AsyncOpenAI, - server_config: ServerConfig): - +async def test_parallel_tool_calls( + client: openai.AsyncOpenAI, server_config: ServerConfig +): if not server_config.get("supports_parallel", True): - pytest.skip("The {} model doesn't support parallel tool calls".format( - server_config["model"])) + pytest.skip( + "The {} model doesn't support parallel tool calls".format( + server_config["model"] + ) + ) models = await client.models.list() model_name: str = models.data[0].id @@ -32,7 +39,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, max_completion_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -69,7 +77,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) role_name: Optional[str] = None finish_reason_count: int = 0 @@ -80,24 +89,22 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, tool_call_id_count: int = 0 async for chunk in stream: - # if there's a finish reason make sure it's tools if chunk.choices[0].finish_reason: finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' + assert chunk.choices[0].finish_reason == "tool_calls" # if a role is being streamed make sure it wasn't already set to # something else if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' + assert not role_name or role_name == "assistant" + role_name = "assistant" # if a tool call is streamed make sure there's exactly one 
# (based on the request parameters streamed_tool_calls = chunk.choices[0].delta.tool_calls if streamed_tool_calls and len(streamed_tool_calls) > 0: - # make sure only one diff is present - correct even for parallel assert len(streamed_tool_calls) == 1 tool_call = streamed_tool_calls[0] @@ -110,8 +117,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, # if a tool call ID is streamed, make sure one hasn't been already if tool_call.id: tool_call_id_count += 1 - assert (isinstance(tool_call.id, str) - and (len(tool_call.id) >= 9)) + assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9) # if parts of the function start being streamed if tool_call.function: @@ -125,32 +131,32 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, # make sure they're a string and then add them to the list assert isinstance(tool_call.function.arguments, str) - tool_call_args[ - tool_call.index] += tool_call.function.arguments + tool_call_args[tool_call.index] += tool_call.function.arguments assert finish_reason_count == 1 - assert role_name == 'assistant' + assert role_name == "assistant" - assert (len(non_streamed_tool_calls) == len(tool_call_names) == - len(tool_call_args)) + assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args) for i in range(2): assert non_streamed_tool_calls[i].function.name == tool_call_names[i] streamed_args = json.loads(tool_call_args[i]) - non_streamed_args = json.loads( - non_streamed_tool_calls[i].function.arguments) + non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments) assert streamed_args == non_streamed_args # test: providing parallel tool calls back to the model to get a response # (streaming/not) @pytest.mark.asyncio -async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, - server_config: ServerConfig): - +async def test_parallel_tool_calls_with_results( + client: openai.AsyncOpenAI, server_config: ServerConfig +): if not server_config.get("supports_parallel", True): - pytest.skip("The {} model doesn't support parallel tool calls".format( - server_config["model"])) + pytest.skip( + "The {} model doesn't support parallel tool calls".format( + server_config["model"] + ) + ) models = await client.models.list() model_name: str = models.data[0].id @@ -160,14 +166,14 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, max_completion_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 assert choice.message.content is not None assert "98" in choice.message.content # Dallas temp in tool response assert "78" in choice.message.content # Orlando temp in tool response @@ -179,7 +185,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) chunks: list[str] = [] finish_reason_count = 0 diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 40c3158e9e683..20fa3b08c7b98 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -7,16 +7,23 @@ from typing import 
Optional import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( - Qwen3CoderToolParser) -from vllm.transformers_utils.detokenizer import detokenize_incrementally + Qwen3CoderToolParser, +) +from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" +pytestmark = pytest.mark.cpu_test + +MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" @pytest.fixture(scope="module") @@ -29,79 +36,80 @@ def qwen3_tool_parser(qwen3_tokenizer): return Qwen3CoderToolParser(qwen3_tokenizer) +@pytest.fixture +def qwen3_xml_tool_parser(qwen3_tokenizer): + return Qwen3XMLToolParser(qwen3_tokenizer) + + +@pytest.fixture(params=["original", "xml"]) +def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request): + """Parameterized fixture that provides both parser types for testing""" + if request.param == "original": + return qwen3_tool_parser + else: + return qwen3_xml_tool_parser + + @pytest.fixture def sample_tools(): return [ - ChatCompletionToolsParam(type="function", - function={ - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - }, - "state": { - "type": "string", - "description": - "The state code" - }, - "unit": { - "type": "string", - "enum": - ["fahrenheit", "celsius"] - } - }, - "required": ["city", "state"] - } - }), - ChatCompletionToolsParam(type="function", - function={ - "name": "calculate_area", - "description": - "Calculate area of a shape", - "parameters": { - "type": "object", - "properties": { - "shape": { - "type": "string" - }, - "dimensions": { - "type": "object" - }, - "precision": { - "type": "integer" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], + }, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): # Qwen3 parser doesn't generate IDs during extraction assert actual_tool_call.type 
== "function" - assert ( - actual_tool_call.function.name == expected_tool_call.function.name) - assert (json.loads(actual_tool_call.function.arguments) == json.loads( - expected_tool_call.function.arguments)) + assert actual_tool_call.function.name == expected_tool_call.function.name + assert json.loads(actual_tool_call.function.arguments) == json.loads( + expected_tool_call.function.arguments + ) def stream_delta_message_generator( - qwen3_tool_parser: Qwen3CoderToolParser, + qwen3_tool_parser, qwen3_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None + request: Optional[ChatCompletionRequest] = None, ) -> Generator[DeltaMessage, None, None]: - all_token_ids = qwen3_tokenizer.encode(model_output, - add_special_tokens=False) + all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -110,18 +118,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] + current_token_ids = all_token_ids[: i + 1] - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=qwen3_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=qwen3_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -138,16 +147,18 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = (previous_tokens + - new_tokens if previous_tokens else new_tokens) + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset -def test_extract_tool_calls_no_tools(qwen3_tool_parser): +def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): model_output = "This is a test response without any tool calls" - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -163,7 +174,8 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser): ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ('''<tool_call> + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -175,16 +187,21 @@ TX fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], None), - ('''Sure! 
Let me check the weather for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + None, + ), + ( + """Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -196,16 +213,21 @@ TX fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], "Sure! Let me check the weather for you."), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + "Sure! Let me check the weather for you.", + ), + ( + """<tool_call> <function=calculate_area> <parameter=shape> rectangle @@ -218,18 +240,25 @@ rectangle 2 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "rectangle", - "dimensions": { - "width": 10, - "height": 20 - }, - "precision": 2 - }))) - ], None), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "rectangle", + "dimensions": {"width": 10, "height": 20}, + "precision": 2, + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -254,23 +283,29 @@ FL fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) - ], None), - ('''Let me calculate that area for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), + ], + None, + ), + ( + """Let me calculate that area for you.<tool_call> <function=calculate_area> <parameter=shape> circle @@ -282,25 +317,36 @@ circle 3 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "circle", - "dimensions": { - "radius": 15.5 - }, - "precision": 3 - }))) - ], "Let me calculate that area for you."), +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "circle", + "dimensions": {"radius": 15.5}, + "precision": 3, + } + ), + ) + ) + ], + "Let me calculate that area for you.", + ), ], ) -def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, - expected_tool_calls, expected_content): - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( - model_output, request=request) +def test_extract_tool_calls( + qwen3_tool_parser_parametrized, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): + request = ChatCompletionRequest(model=MODEL, 
messages=[], tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request + ) assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -308,59 +354,51 @@ def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, assert extracted_tool_calls.content == expected_content -def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools): +def test_extract_tool_calls_fallback_no_tags( + qwen3_tool_parser_parametrized, sample_tools +): """Test fallback parsing when XML tags are missing""" - model_output = '''<function=get_current_weather> + model_output = """<function=get_current_weather> <parameter=city> Dallas </parameter> <parameter=state> TX </parameter> -</function>''' +</function>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( - model_output, request=request) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request + ) assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert (extracted_tool_calls.tool_calls[0].function.name == - "get_current_weather") + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" -def test_extract_tool_calls_type_conversion(qwen3_tool_parser): +def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): """Test parameter type conversion based on tool schema""" tools = [ - ChatCompletionToolsParam(type="function", - function={ - "name": "test_types", - "parameters": { - "type": "object", - "properties": { - "int_param": { - "type": "integer" - }, - "float_param": { - "type": "float" - }, - "bool_param": { - "type": "boolean" - }, - "str_param": { - "type": "string" - }, - "obj_param": { - "type": "object" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + }, + ) ] - model_output = '''<tool_call> + model_output = """<tool_call> <function=test_types> <parameter=int_param> 42 @@ -378,11 +416,12 @@ hello world {"key": "value"} </parameter> </function> -</tool_call>''' +</tool_call>""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) - extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( - model_output, request=request) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request + ) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) assert args["int_param"] == 42 @@ -397,12 +436,15 @@ hello world "no_tools", "single_tool", "single_tool_with_content", + "single_tool_multiline_param", "parallel_tools", + "tool_with_typed_params", # Added this test case ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ("This is a test without tools", [], "This is a test without tools"), - ('''<tool_call> + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -414,16 +456,21 @@ TX fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - 
function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], ""), - ('''Sure! Let me check the weather for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + None, + ), + ( + """Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -435,16 +482,52 @@ TX fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], "Sure! Let me check the weather for you."), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + "Sure! Let me check the weather for you.", + ), + ( + """<tool_call> +<function=calculate_area> +<parameter=shape> +rectangle +</parameter> +<parameter=dimensions> +{"width": 10, + "height": 20} +</parameter> +<parameter=precision> +2 +</parameter> +</function> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "rectangle", + "dimensions": {"width": 10, "height": 20}, + "precision": 2, + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -469,37 +552,77 @@ FL celsius </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "celsius" - }))) - ], ""), +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "celsius"} + ), + ) + ), + ], + None, + ), + # Added tool_with_typed_params test case + ( + """Let me calculate that area for you.<tool_call> +<function=calculate_area> +<parameter=shape> +circle +</parameter> +<parameter=dimensions> +{"radius": 15.5} +</parameter> +<parameter=precision> +3 +</parameter> +</function> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "circle", + "dimensions": {"radius": 15.5}, + "precision": 3, + } + ), + ) + ) + ], + "Let me calculate that area for you.", + ), ], ) -def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, - sample_tools, model_output, - expected_tool_calls, expected_content): - """Test incremental streaming behavior""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) +def test_extract_tool_calls_streaming( + qwen3_tool_parser_parametrized, + qwen3_tokenizer, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): + """Test incremental streaming behavior including typed parameters""" + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - other_content = '' + other_content = "" tool_states = {} # Track 
state per tool index for delta_message in stream_delta_message_generator( - qwen3_tool_parser, qwen3_tokenizer, model_output, request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): # role should never be streamed from tool parser assert not delta_message.role @@ -516,7 +639,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, "id": None, "name": None, "arguments": "", - "type": None + "type": None, } # First chunk should have id, name, and type @@ -535,11 +658,10 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, if tool_call.function.arguments is not None: # Accumulate arguments incrementally - tool_states[idx][ - "arguments"] += tool_call.function.arguments + tool_states[idx]["arguments"] += tool_call.function.arguments # Verify final content - assert other_content == expected_content + assert other_content == (expected_content or "") # Handle None case # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) @@ -559,11 +681,127 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, assert actual_args == expected_args -def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, - qwen3_tokenizer, - sample_tools): +def test_extract_tool_calls_missing_closing_parameter_tag( + qwen3_tool_parser_parametrized, sample_tools +): + """Test handling of missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = """Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>""" + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request + ) + + # The parser should handle the malformed XML gracefully + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + + # Verify the function name is correct + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" + + # Verify the arguments are parsed despite the missing closing tag + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert "city" in args + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + # Check that content before the tool call is preserved + assert "Let me check the weather for you:" in extracted_tool_calls.content + + +def test_extract_tool_calls_streaming_missing_closing_tag( + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools +): + """Test streaming with missing closing </parameter> tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = """Let me check the weather for you: +<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>""" + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + + other_content = "" + tool_states = {} + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = 
tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "Let me check the weather for you:" in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing closing tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + +def test_extract_tool_calls_streaming_incremental( + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools +): """Test that streaming is truly incremental""" - model_output = '''I'll check the weather.<tool_call> + model_output = """I'll check the weather.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -572,15 +810,14 @@ Dallas TX </parameter> </function> -</tool_call>''' +</tool_call>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) chunks = [] for delta_message in stream_delta_message_generator( - qwen3_tool_parser, qwen3_tokenizer, model_output, request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): chunks.append(delta_message) # Should have multiple chunks @@ -595,7 +832,7 @@ TX for chunk in chunks: if chunk.tool_calls and chunk.tool_calls[0].id: header_found = True - assert (chunk.tool_calls[0].function.name == "get_current_weather") + assert chunk.tool_calls[0].function.name == "get_current_weather" assert chunk.tool_calls[0].type == "function" # Empty initially assert chunk.tool_calls[0].function.arguments == "" @@ -616,3 +853,43 @@ TX parsed_args = json.loads(full_args) assert parsed_args["city"] == "Dallas" assert parsed_args["state"] == "TX" + + +def test_extract_tool_calls_complex_type_with_single_quote( + qwen3_tool_parser_parametrized, +): + """Test parameter type conversion based on tool schema""" + tools = [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + }, + ) + ] + + model_output = """<tool_call> +<function=test_types> +<parameter=obj_param> +{'key': 'value'} +</parameter> +</function> +</tool_call>""" + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) + extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( + model_output, request=request + ) + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["obj_param"] == {"key": "value"} diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index d85bc9bbf1b30..eddb5a9b9f5ec 
100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -8,14 +8,19 @@ from typing import Optional import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer +pytestmark = pytest.mark.cpu_test + # Use a common model that is likely to be available MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" @@ -43,51 +48,56 @@ def sample_tools(): "properties": { "location": { "type": "string", - "description": - "City and country e.g. Bogotá, Colombia" + "description": "City and country e.g. Bogotá, Colombia", }, "unit": { "type": "string", - "description": "this is the unit of temperature" - } + "description": "this is the unit of temperature", + }, }, "required": ["location"], - "additionalProperties": False + "additionalProperties": False, }, "returns": { "type": "object", "properties": { "temperature": { "type": "number", - "description": "temperature in celsius" + "description": "temperature in celsius", } }, "required": ["temperature"], - "additionalProperties": False + "additionalProperties": False, }, - "strict": True - }), + "strict": True, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): # Seed-OSS tool call will not generate id assert actual_tool_call.type == "function" assert actual_tool_call.function == expected_tool_call.function assert actual_tool_call.function.name == expected_tool_call.function.name - assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + assert ( + actual_tool_call.function.arguments == expected_tool_call.function.arguments + ) def test_extract_tool_calls_no_tools(seed_oss_tool_parser): model_output = "This is a test response without any tool calls" extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] @@ -102,22 +112,24 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - """<seed:tool_call>\n<function=get_weather>\n""" - """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", - [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') - ], - 
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - ), + ( + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) + ], + None, + ), ( """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -134,13 +146,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" """\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) ], """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -172,15 +188,18 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps( - { - "location": "Barcelona, Spain", - "unit": "celsius", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, + ), + ), + type="function", + ) ], """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ """First, I need to remember the function I can use: get_weather. The function requires a """ @@ -199,13 +218,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): ), ], ) -def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, - expected_tool_calls, expected_content): - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) +def test_extract_tool_calls( + seed_oss_tool_parser, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( - model_output, request=request) # type: ignore[arg-type] + model_output, request=request + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -228,7 +251,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
@@ -236,10 +259,9 @@ def stream_delta_message_generator( seed_oss_tool_parser: SeedOssToolParser, seed_oss_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None + request: Optional[ChatCompletionRequest] = None, ) -> Generator[DeltaMessage, None, None]: - all_token_ids = seed_oss_tokenizer.encode(model_output, - add_special_tokens=False) + all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -248,18 +270,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] + current_token_ids = all_token_ids[: i + 1] - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=seed_oss_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -276,8 +299,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = (previous_tokens + - new_tokens if previous_tokens else new_tokens) + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -290,22 +314,27 @@ def stream_delta_message_generator( ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - """<seed:tool_call>\n<function=get_weather>\n""" - """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", - [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') - ], - """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - ), + ( + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) + ], + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""", + ), ( """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. 
Looking at the functions available, """ @@ -322,13 +351,17 @@ def stream_delta_message_generator( """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" """\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) ], """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -360,15 +393,18 @@ def stream_delta_message_generator( """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps( - { - "location": "Barcelona, Spain", - "unit": "celsius", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, + ), + ), + type="function", + ) ], """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ """First, I need to remember the function I can use: get_weather. The function requires a """ @@ -387,19 +423,23 @@ def stream_delta_message_generator( ), ], ) -def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, - sample_tools, model_output, expected_tool_calls, - expected_content): +def test_streaming_tool_calls( + seed_oss_tool_parser, + seed_oss_tokenizer, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): """Test incremental streaming behavior""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - other_content = '' + other_content = "" tool_states = {} # Track state per tool index for delta_message in stream_delta_message_generator( - seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request + ): # role should never be streamed from tool parser assert not delta_message.role @@ -416,7 +456,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, "id": None, "name": None, "arguments": "", - "type": None + "type": None, } # First chunk should have id, name, and type @@ -435,8 +475,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, if tool_call.function.arguments is not None: # Accumulate arguments incrementally - tool_states[idx][ - "arguments"] += tool_call.function.arguments + tool_states[idx]["arguments"] += tool_call.function.arguments # Verify final content assert other_content == expected_content diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 53ba03a0ae109..64186aaac6a74 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -7,8 +7,12 @@ from typing import Optional import openai import pytest -from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, - SEARCH_TOOL, WEATHER_TOOL) +from .utils import ( + MESSAGES_ASKING_FOR_TOOLS, + 
MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, + WEATHER_TOOL, +) # test: request a chat completion that should return tool calls, so we know they @@ -23,17 +27,18 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason tool_calls = chat_completion.choices[0].message.tool_calls # make sure a tool call is present - assert choice.message.role == 'assistant' + assert choice.message.role == "assistant" assert tool_calls is not None assert len(tool_calls) == 1 - assert tool_calls[0].type == 'function' + assert tool_calls[0].type == "function" assert tool_calls[0].function is not None assert isinstance(tool_calls[0].id, str) assert len(tool_calls[0].id) >= 9 @@ -54,7 +59,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert stop_reason == "tool_calls" function_name: Optional[str] = None - function_args_str: str = '' + function_args_str: str = "" tool_call_id: Optional[str] = None role_name: Optional[str] = None finish_reason_count: int = 0 @@ -67,20 +72,21 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) async for chunk in stream: assert chunk.choices[0].index == 0 if chunk.choices[0].finish_reason: finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' + assert chunk.choices[0].finish_reason == "tool_calls" # if a role is being streamed make sure it wasn't already set to # something else if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' + assert not role_name or role_name == "assistant" + role_name = "assistant" # if a tool call is streamed make sure there's exactly one # (based on the request parameters @@ -108,7 +114,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): function_args_str += tool_call.function.arguments assert finish_reason_count == 1 - assert role_name == 'assistant' + assert role_name == "assistant" assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9) # validate the name and arguments @@ -148,14 +154,14 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): max_completion_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 assert choice.message.content is not None assert "98" in choice.message.content # the temperature from the response @@ -166,7 +172,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) chunks: list[str] = [] finish_reason_count = 0 diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index e0ed221a93e12..d52c141f6210d 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -8,10 +8,14 @@ import pytest import regex as re from pydantic import TypeAdapter -from 
vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +pytestmark = pytest.mark.cpu_test + EXAMPLE_TOOLS = [ { "type": "function", @@ -22,18 +26,16 @@ EXAMPLE_TOOLS = [ "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to find the weather for" + "type": "string", + "description": "The city to find the weather for" ", e.g. 'San Francisco'", }, }, "required": ["city"], - "additionalProperties": False + "additionalProperties": False, }, }, - "strict": True + "strict": True, }, { "type": "function", @@ -44,35 +46,34 @@ EXAMPLE_TOOLS = [ "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to get the forecast for, e.g. 'New York'", + "type": "string", + "description": "The city to get the forecast for, e.g. " + "'New York'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, }, "required": ["city", "days"], - "additionalProperties": False + "additionalProperties": False, }, }, - "strict": True + "strict": True, }, ] -def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, - should_match: bool): +def _compile_and_check( + tools: list[ChatCompletionToolsParam], sample_output, should_match: bool +): self = MagicMock(tool_choice="required", tools=tools) - schema = ChatCompletionRequest._get_guided_json_from_tool(self) + schema = ChatCompletionRequest._get_json_schema_from_tool(self) assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide from outlines_core.json_schema import build_regex_from_schema + regex = build_regex_from_schema(json.dumps(schema)) compiled = re.compile(regex) matches = compiled.fullmatch(json.dumps(sample_output)) is not None @@ -81,65 +82,31 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, VALID_TOOL_OUTPUTS = [ - ([{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }], True), - ([{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Berlin" - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, { - "name": "get_forecast", - "parameters": { - "city": "Berlin", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Berlin" - } - }], True), + ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True), + ( + [ + {"name": "get_current_weather", "parameters": {"city": "Vienna"}}, + {"name": "get_current_weather", "parameters": {"city": "Berlin"}}, + ], + True, + ), + ([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True), + ( + [ + {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": 
"Vienna"}}, + ], + True, + ), + ( + [ + {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": "Vienna"}}, + {"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": "Berlin"}}, + ], + True, + ), ] VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS] @@ -147,92 +114,100 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS] @pytest.mark.parametrize( "sample_output, should_match", - VALID_TOOL_OUTPUTS + [ + VALID_TOOL_OUTPUTS + + [ (None, False), ([], False), # empty list cannot be generated ({}, False), # empty object cannot be generated ([{}], False), # list with empty object cannot be generated ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather" - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather" + } + ], + False, + ), ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": {} - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": {}, + } + ], + False, + ), ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": None - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": None, + } + ], + False, + ), ( { # tool call without lists cannot be generated "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } + "parameters": {"city": "Vienna"}, }, - False), + False, + ), ( - [{ # tool call with extra parameters cannot be generated - "name": "get_current_weather", - "parameters": { - "city": "Vienna", - "extra": "value" + [ + { # tool call with extra parameters cannot be generated + "name": "get_current_weather", + "parameters": {"city": "Vienna", "extra": "value"}, } - }], - False), + ], + False, + ), ( - [{ # tool call where parameters are first cannot be generated - "parameters": { - "city": "Vienna" - }, - "name": "get_current_weather" - }], - False), - ( - [{ # tool call without all required parameters cannot be generated - "name": "get_forecast", - "parameters": { - "city": "Vienna" + [ + { # tool call where parameters are first cannot be generated + "parameters": {"city": "Vienna"}, + "name": "get_current_weather", } - }], - False), + ], + False, + ), + ( + [ + { # tool call without all required parameters cannot be generated + "name": "get_forecast", + "parameters": {"city": "Vienna"}, + } + ], + False, + ), ( # tool call with incorrect name/parameters cannot be generated - [{ - "name": "get_weather", - "parameters": { - "city": "Vienna", - "days": 7 - } - }], False), + [{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}], + False, + ), ( # tool call with both valid and empty function cannot be generated - [{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, {}], False), - ]) -def test_guided_json(sample_output, should_match): - _compile_and_check(tools=TypeAdapter( - list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS), - sample_output=sample_output, - should_match=should_match) + [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}], + False, + ), + ], +) +def test_structured_outputs_json(sample_output, should_match): + _compile_and_check( + 
tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python( + EXAMPLE_TOOLS + ), + sample_output=sample_output, + should_match=should_match, + ) -def update_parameters_none( - tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam: +def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam: tool.function.parameters = None return tool def update_parameters_empty_dict( - tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam: + tool: ChatCompletionToolsParam, +) -> ChatCompletionToolsParam: tool.function.parameters = {} return tool @@ -245,47 +220,60 @@ def update_parameters_empty_dict( ({}, False), # empty object cannot be generated ([{}], False), # list with empty object cannot be generated ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather" - }], - False), - ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": None - }], - False), - ( - [{ # function with extra parameters cannot be generated - "name": "get_current_weather", - "parameters": { - "extra": "value" + [ + { # function without required parameters cannot be generated + "name": "get_current_weather" } - }], - False), + ], + False, + ), ( - [{ # only function with empty parameters object is valid - "name": "get_current_weather", - "parameters": {} - }], - True), - ]) + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": None, + } + ], + False, + ), + ( + [ + { # function with extra parameters cannot be generated + "name": "get_current_weather", + "parameters": {"extra": "value"}, + } + ], + False, + ), + ( + [ + { # only function with empty parameters object is valid + "name": "get_current_weather", + "parameters": {}, + } + ], + True, + ), + ], +) @pytest.mark.parametrize( - "update_parameters", - [update_parameters_none, update_parameters_empty_dict]) -def test_guided_json_without_parameters(sample_output, should_match, - update_parameters): + "update_parameters", [update_parameters_none, update_parameters_empty_dict] +) +def test_structured_outputs_json_without_parameters( + sample_output, should_match, update_parameters +): updated_tools = [deepcopy(EXAMPLE_TOOLS[0])] - tools = TypeAdapter( - list[ChatCompletionToolsParam]).validate_python(updated_tools) + tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools) tools = list(map(update_parameters, tools)) - assert all([ - tool.function.parameters is None or tool.function.parameters == {} - for tool in tools - ]) - _compile_and_check(tools=tools, - sample_output=sample_output, - should_match=should_match) + assert all( + [ + tool.function.parameters is None or tool.function.parameters == {} + for tool in tools + ] + ) + _compile_and_check( + tools=tools, sample_output=sample_output, should_match=should_match + ) @pytest.mark.parametrize("output", VALID_TOOLS) @@ -303,7 +291,7 @@ def test_streaming_output_valid(output, empty_params, delta_len): function_name_returned = False messages = [] for i in range(0, len(output_json), delta_len): - delta_text = output_json[i:i + delta_len] + delta_text = output_json[i : i + delta_len] current_text = previous_text + delta_text delta_message, function_name_returned = ( @@ -312,7 +300,9 @@ def test_streaming_output_valid(output, empty_params, delta_len): previous_text=previous_text, current_text=current_text, delta_text=delta_text, - function_name_returned=function_name_returned)) + 
function_name_returned=function_name_returned, + ) + ) if delta_message: messages.append(delta_message) @@ -326,12 +316,14 @@ def test_streaming_output_valid(output, empty_params, delta_len): if len(combined_messages) > 1: combined_messages += "}," - combined_messages += '{"name": "' + \ - message.tool_calls[0].function.name + \ - '", "parameters": ' + \ - message.tool_calls[0].function.arguments + combined_messages += ( + '{"name": "' + + message.tool_calls[0].function.name + + '", "parameters": ' + + message.tool_calls[0].function.arguments + ) else: combined_messages += message.tool_calls[0].function.arguments combined_messages += "}]" assert json.loads(combined_messages) == output - assert json.dumps(json.loads(combined_messages)) == output_json \ No newline at end of file + assert json.dumps(json.loads(combined_messages)) == output_json diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 8d26b90515901..bdac878db4e76 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -2,12 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +from collections.abc import Generator +from typing import Optional import pytest -from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +pytestmark = pytest.mark.cpu_test # Use a common model that is likely to be available MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r" @@ -23,12 +33,14 @@ def xlam_tool_parser(xlam_tokenizer): return xLAMToolParser(xlam_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -36,10 +48,62 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], assert actual_tool_call.function == expected_tool_call.function +def stream_delta_message_generator( + xlam_tool_parser: xLAMToolParser, + xlam_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None, +) -> Generator[DeltaMessage, None, None]: + all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=xlam_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) + + current_text = previous_text + delta_text 
+ + delta_message = xlam_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + def test_extract_tool_calls_no_tools(xlam_tool_parser): model_output = "This is a test" extracted_tool_calls = xlam_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -51,79 +115,120 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): "single_tool_with_think_tag", "single_tool_with_json_code_block", "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ( """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), ( """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "<think>I'll help you with that.</think>", ), ( """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I'll help you with that.", ), ( """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + 
arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I'll check the weather for you.", ), + ( + """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) + ], + "I'll help you check the weather.", + ), ], ) -def test_extract_tool_calls(xlam_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + xlam_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = xlam_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -138,25 +243,30 @@ def test_extract_tool_calls(xlam_tool_parser, model_output, ( """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], None, ), ], ) -def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output, - expected_tool_calls, - expected_content): +def test_extract_tool_calls_list_structure( + xlam_tool_parser, model_output, expected_tool_calls, expected_content +): """Test extraction of tool calls when the model outputs a list-structured tool call.""" # noqa: E501 extracted_tool_calls = xlam_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -167,20 +277,25 @@ def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output, # Test for preprocess_model_output method def test_preprocess_model_output(xlam_tool_parser): # Test with list structure - model_output = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + model_output = ( + """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + ) content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content is None assert potential_tool_calls == model_output # Test with thinking tag model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == "<think>I'll help you with that.</think>" assert ( - potential_tool_calls == - '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]') + potential_tool_calls + == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]' + ) # Test with JSON code block model_output = """I'll help you with that. 
@@ -188,14 +303,16 @@ def test_preprocess_model_output(xlam_tool_parser): [{"name": "get_current_weather", "arguments": {"city": "Seattle"}}] ```""" content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == "I'll help you with that." assert "get_current_weather" in potential_tool_calls # Test with no tool calls model_output = """I'll help you with that.""" content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == model_output assert potential_tool_calls is None @@ -209,7 +326,9 @@ def test_streaming_with_list_structure(xlam_tool_parser): xlam_tool_parser.current_tool_id = -1 # Simulate receiving a message with list structure - current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + current_text = ( + """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + ) # First call to set up the tool xlam_tool_parser.extract_tool_calls_streaming( @@ -223,8 +342,7 @@ def test_streaming_with_list_structure(xlam_tool_parser): ) # Make sure the tool is set up correctly - assert (xlam_tool_parser.current_tool_id - >= 0), "Tool index should be initialized" + assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized" # Manually set up the state for sending the tool name xlam_tool_parser.current_tools_sent = [False] @@ -245,3 +363,175 @@ def test_streaming_with_list_structure(xlam_tool_parser): assert hasattr(result, "tool_calls") assert len(result.tool_calls) == 1 assert result.tool_calls[0].function.name == "get_current_weather" + + +@pytest.mark.parametrize( + ids=[ + "parallel_tool_calls", + "single_tool_with_think_tag", + "single_tool_with_json_code_block", + "single_tool_with_tool_calls_tag", + "single_tool_with_tool_call_xml_tags", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), + ], + "", + ), + ( + """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) + ], + "<think>I'll help you with that.</think>", + ), + ( + """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + 
arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) + ], + "", + ), + ( + """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) + ], + "I can help with that.", + ), + ], +) +def test_extract_tool_calls_streaming_incremental( + xlam_tool_parser, + xlam_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501 + request = ChatCompletionRequest(model=MODEL, messages=[], tools=[]) + + chunks = [] + for delta_message in stream_delta_message_generator( + xlam_tool_parser, xlam_tokenizer, model_output, request + ): + chunks.append(delta_message) + + # Should have multiple chunks + assert len(chunks) >= 3 + + # Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501 + header_found = False + expected_first_tool = expected_tool_calls[0] + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].id: + header_found = True + assert ( + chunk.tool_calls[0].function.name == expected_first_tool.function.name + ) + assert chunk.tool_calls[0].type == "function" + # Arguments may be empty initially or None + if chunk.tool_calls[0].function.arguments is not None: + # If present, should be empty string initially + assert chunk.tool_calls[0].function.arguments == "" + break + assert header_found + + # Should have chunks with incremental arguments + arg_chunks = [] + for chunk in chunks: + if ( + chunk.tool_calls + and chunk.tool_calls[0].function.arguments + and chunk.tool_calls[0].function.arguments != "" + and chunk.tool_calls[0].index + == 0 # Only collect arguments from the first tool call + ): + arg_chunks.append(chunk.tool_calls[0].function.arguments) + + # Arguments should be streamed incrementally + assert len(arg_chunks) > 1 + + # Concatenated arguments should form valid JSON for the first tool call + full_args = "".join(arg_chunks) + parsed_args = json.loads(full_args) + expected_args = json.loads(expected_first_tool.function.arguments) + assert parsed_args == expected_args diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index a17fab9aecbca..835d07608e408 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -4,8 +4,7 @@ from copy import deepcopy from typing import Any, Optional -from openai.types.chat import (ChatCompletionMessageParam, - ChatCompletionToolParam) +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam from typing_extensions import TypedDict from tests.utils import VLLM_PATH @@ -20,8 +19,9 @@ class ServerConfig(TypedDict, total=False): extended: Optional[bool] # tests do not run in CI automatically -def patch_system_prompt(messages: list[dict[str, Any]], - system_prompt: str) -> list[dict[str, Any]]: +def patch_system_prompt( + messages: list[dict[str, Any]], system_prompt: str +) -> list[dict[str, Any]]: new_messages = deepcopy(messages) if new_messages[0]["role"] == "system": new_messages[0]["content"] = system_prompt @@ -30,8 +30,9 @@ def patch_system_prompt(messages: list[dict[str, Any]], return new_messages -def ensure_system_prompt(messages: list[dict[str, Any]], - config: ServerConfig) -> list[dict[str, Any]]: +def 
ensure_system_prompt( + messages: list[dict[str, Any]], config: ServerConfig +) -> list[dict[str, Any]]: prompt = config.get("system_prompt") if prompt: return patch_system_prompt(messages, prompt) @@ -42,92 +43,102 @@ def ensure_system_prompt(messages: list[dict[str, Any]], # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. ARGS: list[str] = [ - "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs", - "256" + "--enable-auto-tool-choice", + "--max-model-len", + "1024", + "--max-num-seqs", + "256", ] CONFIGS: dict[str, ServerConfig] = { "hermes": { - "model": - "NousResearch/Hermes-3-Llama-3.1-8B", + "model": "NousResearch/Hermes-3-Llama-3.1-8B", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "hermes", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "hermes", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja"), ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, "llama": { - "model": - "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama3_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama3_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja"), ], - "supports_parallel": - False, + "supports_parallel": False, }, "llama3.2": { - "model": - "meta-llama/Llama-3.2-3B-Instruct", + "model": "meta-llama/Llama-3.2-3B-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama3_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama3_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja"), ], - "supports_parallel": - False, + "supports_parallel": False, }, "llama4": { - "model": - "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama4_pythonic", "--chat-template", - str(VLLM_PATH / - "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", - "4" + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama4_pythonic", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"), + "-tp", + "4", ], - "supports_parallel": - False, - "extended": - True + "supports_parallel": False, + "extended": True, }, "llama4_json": { - "model": - "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ - "--enforce-eager", 
"--no-enable-prefix-caching", "-tp", "4", - "--distributed-executor-backend", "mp", "--tool-call-parser", - "llama4_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "-tp", + "4", + "--distributed-executor-backend", + "mp", + "--tool-call-parser", + "llama4_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja"), ], - "supports_parallel": - True, - "extended": - True + "supports_parallel": True, + "extended": True, }, "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "mistral", "--chat-template", + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "mistral", + "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), - "--ignore-patterns=\"consolidated.safetensors\"" + '--ignore-patterns="consolidated.safetensors"', ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, # V1 Test: Passing locally but failing in CI. This runs the # V0 Engine because of CPU offloading. Need to debug why. @@ -146,49 +157,50 @@ CONFIGS: dict[str, ServerConfig] = { # False, # }, "granite-3.0-8b": { - "model": - "ibm-granite/granite-3.0-8b-instruct", + "model": "ibm-granite/granite-3.0-8b-instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "granite", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "granite", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"), ], }, "granite-3.1-8b": { - "model": - "ibm-granite/granite-3.1-8b-instruct", + "model": "ibm-granite/granite-3.1-8b-instruct", "arguments": [ "--enforce-eager", "--no-enable-prefix-caching", "--tool-call-parser", "granite", ], - "supports_parallel": - True, + "supports_parallel": True, }, "internlm": { - "model": - "internlm/internlm2_5-7b-chat", + "model": "internlm/internlm2_5-7b-chat", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "internlm", "--chat-template", - str(VLLM_PATH / - "examples/tool_chat_template_internlm2_tool.jinja"), - "--trust_remote_code" + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "internlm", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"), + "--trust_remote_code", ], - "supports_parallel": - False, + "supports_parallel": False, }, "toolACE": { - "model": - "Team-ACE/ToolACE-8B", + "model": "Team-ACE/ToolACE-8B", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "pythonic", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "pythonic", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja"), ], - 
"supports_parallel": - True, + "supports_parallel": True, }, } @@ -201,37 +213,31 @@ WEATHER_TOOL: ChatCompletionToolParam = { "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to find the weather for, " - "e.g. 'San Francisco'" + "type": "string", + "description": "The city to find the weather for, " + "e.g. 'San Francisco'", }, "state": { - "type": - "string", - "description": - "must the two-letter abbreviation for the state " + "type": "string", + "description": "must the two-letter abbreviation for the state " "that the city is in, e.g. 'CA' which would " - "mean 'California'" + "mean 'California'", }, "unit": { "type": "string", "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } - } - } - } + "enum": ["celsius", "fahrenheit"], + }, + }, + }, + }, } SEARCH_TOOL: ChatCompletionToolParam = { "type": "function", "function": { - "name": - "web_search", - "description": - "Search the internet and get a summary of the top " + "name": "web_search", + "description": "Search the internet and get a summary of the top " "10 webpages. Should only be used if you don't know " "the answer to a user query, and the results are likely" "to be able to be found with a web search", @@ -239,124 +245,98 @@ SEARCH_TOOL: ChatCompletionToolParam = { "type": "object", "properties": { "search_term": { - "type": - "string", - "description": - "The term to use in the search. This should" + "type": "string", + "description": "The term to use in the search. This should" "ideally be keywords to search for, not a" - "natural-language question" + "natural-language question", } }, - "required": ["search_term"] - } - } + "required": ["search_term"], + }, + }, } -MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "Hi! How are you?" -}, { - "role": - "assistant", - "content": - "I'm doing great! How can I assist you?" -}, { - "role": - "user", - "content": - "Can you tell me a joke please?" -}] +MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "Hi! How are you?"}, + {"role": "assistant", "content": "I'm doing great! How can I assist you?"}, + {"role": "user", "content": "Can you tell me a joke please?"}, +] -MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}] +MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"} +] -MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas is 98 degrees fahrenheit, with partly" - "cloudy skies and a low chance of rain." 
-}] +MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain.", + }, +] -MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}] +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?", + } +] -MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }, { - "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Orlando", "state": "Fl", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas TX is 98 degrees fahrenheit with mostly " - "cloudy skies and a chance of rain in the evening." -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "content": - "The weather in Orlando FL is 78 degrees fahrenheit with clear" - "skies." 
-}] +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}', + }, + }, + { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}', + }, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening.", + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies.", + }, +] diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py index b0475894a114e..22d838d272643 100644 --- a/tests/tools/test_config_validator.py +++ b/tests/tools/test_config_validator.py @@ -7,11 +7,11 @@ import pytest from tools.validate_config import validate_ast -_TestConfig1 = ''' +_TestConfig1 = """ @config class _TestConfig1: pass -''' +""" _TestConfig2 = ''' @config @@ -21,12 +21,12 @@ class _TestConfig2: """docstring""" ''' -_TestConfig3 = ''' +_TestConfig3 = """ @config @dataclass class _TestConfig3: a: int = 1 -''' +""" _TestConfig4 = ''' @config @@ -37,12 +37,15 @@ class _TestConfig4: ''' -@pytest.mark.parametrize(("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), -]) +@pytest.mark.parametrize( + ("test_config", "expected_error"), + [ + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), + ], +) def test_config(test_config, expected_error): tree = ast.parse(test_config) with pytest.raises(Exception, match=expected_error): diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py index 636108e985816..9780092b25e66 100644 --- a/tests/tpu/lora/test_lora.py +++ b/tests/tpu/lora/test_lora.py @@ -17,30 +17,21 @@ from vllm.lora.request import LoRARequest # 100 training iterations with a training batch size of 100. 
-@pytest.fixture(scope="function", autouse=True) -def use_v1_only(monkeypatch: pytest.MonkeyPatch): - """ - Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1 - for all tests in this file - """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - yield - - def setup_vllm(num_loras: int, tp: int) -> vllm.LLM: - return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", - max_model_len=256, - max_seq_len_to_capture=256, - max_num_seqs=8, - tensor_parallel_size=tp, - enable_lora=True, - max_loras=num_loras, - max_lora_rank=8) + return vllm.LLM( + model="Qwen/Qwen2.5-3B-Instruct", + max_model_len=256, + max_num_seqs=8, + tensor_parallel_size=tp, + enable_lora=True, + max_loras=num_loras, + max_lora_rank=8, + ) -TPU_TENSOR_PARALLEL_SIZES = [1, tpu.num_available_chips() - ] if tpu.num_available_chips() > 1 else [1] +TPU_TENSOR_PARALLEL_SIZES = ( + [1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1] +) @pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES) @@ -56,12 +47,19 @@ def test_single_lora(tp: int): prompt = "What is 1+1? \n" lora_request = LoRARequest( - "lora_adapter_1", 1, - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter") - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams(max_tokens=256, - temperature=0), - lora_request=lora_request)[0].outputs[0].text + "lora_adapter_1", + 1, + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter", + ) + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=lora_request, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] @@ -74,13 +72,12 @@ def test_lora_hotswapping(tp: int): """ This test ensures we can run multiple LoRA adapters on the TPU backend, even if we only have space to store 1. - + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. """ - lora_name_template = \ - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) for i in range(1, 5) @@ -91,10 +88,15 @@ def test_lora_hotswapping(tp: int): prompt = "What is 1+1? \n" for i, req in enumerate(lora_requests): - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams( - max_tokens=256, temperature=0), - lora_request=req)[0].outputs[0].text + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=req, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] assert answer.isdigit() @@ -106,12 +108,11 @@ def test_multi_lora(tp: int): """ This test ensures we can run multiple LoRA adapters on the TPU backend, when we have enough space to store all of them. - + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. """ - lora_name_template = \ - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) for i in range(1, 5) @@ -122,10 +123,15 @@ def test_multi_lora(tp: int): prompt = "What is 1+1? 
\n" for i, req in enumerate(lora_requests): - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams( - max_tokens=256, temperature=0), - lora_request=req)[0].outputs[0].text + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=req, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 448b8b2bc094f..5acfa484f0c13 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -26,16 +26,15 @@ def test_tpu_compilation(): # Currently, top-p sampling is disabled. `top_p` should be 1.0. N = 1 - sampling_params = SamplingParams(temperature=0.7, - top_p=1.0, - n=N, - max_tokens=16) + sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=256, - max_model_len=256, - max_num_seqs=32, - enforce_eager=False) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_batched_tokens=256, + max_model_len=256, + max_num_seqs=32, + enforce_eager=False, + ) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): @@ -45,7 +44,8 @@ def test_tpu_compilation(): assert generated_text.startswith(answer) compiled_codes = sorted( - glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))) + glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")) + ) for i, compiled_code in enumerate(compiled_codes): print("{} file: {}".format(i + 1, compiled_code)) @@ -66,9 +66,10 @@ def test_tpu_compilation(): # Check all the compilations are as expected. The dump files include the # captured graph for the forward function of the nn.Module. 
- compiled_fns = sorted(glob.glob( - os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")), - key=lambda s: extract_compiled_index(s)) + compiled_fns = sorted( + glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")), + key=lambda s: extract_compiled_index(s), + ) for i, compiled_fn in enumerate(compiled_fns): print("{} file: {}".format(i + 1, compiled_fn)) @@ -82,4 +83,4 @@ def test_tpu_compilation(): # ragged_paged_attention with open(compiled_fns[1]) as f: content = f.read() - assert (kv_cache_prefix in content and attn_prefix in content) + assert kv_cache_prefix in content and attn_prefix in content diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 9c90df1b77010..102e5ddf16d6d 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -15,17 +15,20 @@ from ..utils import compare_two_settings def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_RPC_TIMEOUT", "30000") - compare_two_settings("Qwen/Qwen2.5-1.5B-Instruct", - arg1=[ - "--max-model-len=256", - "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_ONCE}", - ], - arg2=[ - "--max-model-len=256", "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_AS_IS}" - ], - env1={}, - env2={}) + compare_two_settings( + "Qwen/Qwen2.5-1.5B-Instruct", + arg1=[ + "--max-model-len=256", + "--max-num-seqs=32", + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_ONCE}", + ], + arg2=[ + "--max-model-len=256", + "--max-num-seqs=32", + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_AS_IS}", + ], + env1={}, + env2={}, + ) diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py index 407a824d81748..e3236d20bf673 100644 --- a/tests/tpu/test_moe_pallas.py +++ b/tests/tpu/test_moe_pallas.py @@ -4,16 +4,15 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`. """ + import pytest import torch +import torch_xla -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.fused_moe.moe_pallas import ( - fused_moe as pallas_moe) +from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as torch_moe) -# yapf: enable + fused_moe as torch_moe, +) from vllm.platforms import current_platform if not current_platform.is_tpu(): @@ -42,6 +41,7 @@ def test_pallas_moe( dtype: torch.dtype, ): import torch_xla.core.xla_model as xm + with torch.device(xm.xla_device()): a = torch.randn((m, k), dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 @@ -77,7 +77,7 @@ def test_pallas_moe( expert_map=e_map, renormalize=False, ) - xm.mark_step() + torch_xla.sync(wait=False) # Compare outputs torch.testing.assert_close( diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 6cefbae4bdd18..151be5f17fe89 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -17,18 +17,18 @@ class GSM8KAccuracyTestConfig: expected_value: float def get_model_args(self) -> str: - return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=32") + return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32" # NOTE: Accuracy scores measured on GPUs. 
ACCURACY_CONFIGS = [ GSM8KAccuracyTestConfig( model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - expected_value=0.76), # no bias + expected_value=0.76, + ), # no bias # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As - # a follow up, move this into the LM-EVAL section of the CI. + # a follow-up, move this into the LM-EVAL section of the CI. # GSM8KAccuracyTestConfig( # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", # expected_value=0.66), # bias in QKV layers @@ -37,7 +37,6 @@ ACCURACY_CONFIGS = [ @pytest.mark.parametrize("config", ACCURACY_CONFIGS) def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): - results = lm_eval.simple_evaluate( model="vllm", model_args=config.get_model_args(), @@ -47,6 +46,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): EXPECTED_VALUE = config.expected_value measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py deleted file mode 100644 index 4dbae7c15de3a..0000000000000 --- a/tests/tracing/test_tracing.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# ruff: noqa -# type: ignore -from __future__ import annotations - -import threading -from collections.abc import Iterable -from concurrent import futures -from typing import Callable, Generator, Literal - -import grpc -import pytest -from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( - ExportTraceServiceResponse) -from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( - TraceServiceServicer, add_TraceServiceServicer_to_server) -from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue -from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_INSECURE) - -from vllm import LLM, SamplingParams -from vllm.tracing import SpanAttributes - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" - -FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', - 'array_value'] - - -def decode_value(value: AnyValue): - field_decoders: dict[FieldName, Callable] = { - "bool_value": (lambda v: v.bool_value), - "string_value": (lambda v: v.string_value), - "int_value": (lambda v: v.int_value), - "double_value": (lambda v: v.double_value), - "array_value": - (lambda v: [decode_value(item) for item in v.array_value.values]), - } - for field, decoder in field_decoders.items(): - if value.HasField(field): - return decoder(value) - raise ValueError(f"Couldn't decode value: {value}") - - -def decode_attributes(attributes: Iterable[KeyValue]): - return {kv.key: decode_value(kv.value) for kv in attributes} - - -class FakeTraceService(TraceServiceServicer): - - def __init__(self): - self.request = None - self.evt = threading.Event() - - def Export(self, request, context): - self.request = request - self.evt.set() - return ExportTraceServiceResponse() - - -@pytest.fixture -def trace_service() -> Generator[FakeTraceService, None, None]: - """Fixture to set up a fake gRPC trace service""" - server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) - service = FakeTraceService() - add_TraceServiceServicer_to_server(service, server) - server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) - server.start() - - yield service - - server.stop(None) - - -def test_traces( - monkeypatch: pytest.MonkeyPatch, - trace_service: FakeTraceService, -): - with monkeypatch.context() as m: - m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") - - sampling_params = SamplingParams( - temperature=0.01, - top_p=0.1, - max_tokens=256, - ) - model = "facebook/opt-125m" - llm = LLM( - model=model, - otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, - ) - prompts = ["This is a short prompt"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - - timeout = 5 - if not trace_service.evt.wait(timeout): - raise TimeoutError( - f"The fake trace service didn't receive a trace within " - f"the {timeout} seconds timeout") - - request = trace_service.request - assert len(request.resource_spans) == 1, ( - f"Expected 1 resource span, " - f"but got {len(request.resource_spans)}") - assert len(request.resource_spans[0].scope_spans) == 1, ( - f"Expected 1 scope span, " - f"but got {len(request.resource_spans[0].scope_spans)}") - assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( - f"Expected 1 span, " - f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") - - attributes = decode_attributes( - request.resource_spans[0].scope_spans[0].spans[0].attributes) - assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE - ) == sampling_params.temperature - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS - ) == sampling_params.max_tokens - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( - outputs[0].prompt_token_ids) - completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens - 
metrics = outputs[0].metrics - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE - ) == metrics.time_in_queue - ttft = metrics.first_token_time - metrics.arrival_time - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft - e2e_time = metrics.finished_time - metrics.arrival_time - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time - assert metrics.scheduler_time > 0 - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER - ) == metrics.scheduler_time - # Model forward and model execute should be none, since detailed traces is - # not enabled. - assert metrics.model_forward_time is None - assert metrics.model_execute_time is None - - -def test_traces_with_detailed_steps( - monkeypatch: pytest.MonkeyPatch, - trace_service: FakeTraceService, -): - with monkeypatch.context() as m: - m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") - - sampling_params = SamplingParams( - temperature=0.01, - top_p=0.1, - max_tokens=256, - ) - model = "facebook/opt-125m" - llm = LLM( - model=model, - otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, - collect_detailed_traces=["all"], - ) - prompts = ["This is a short prompt"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - - timeout = 5 - if not trace_service.evt.wait(timeout): - raise TimeoutError( - f"The fake trace service didn't receive a trace within " - f"the {timeout} seconds timeout") - - request = trace_service.request - assert len(request.resource_spans) == 1, ( - f"Expected 1 resource span, " - f"but got {len(request.resource_spans)}") - assert len(request.resource_spans[0].scope_spans) == 1, ( - f"Expected 1 scope span, " - f"but got {len(request.resource_spans[0].scope_spans)}") - assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( - f"Expected 1 span, " - f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") - - attributes = decode_attributes( - request.resource_spans[0].scope_spans[0].spans[0].attributes) - assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE - ) == sampling_params.temperature - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS - ) == sampling_params.max_tokens - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( - outputs[0].prompt_token_ids) - completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens - metrics = outputs[0].metrics - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE - ) == metrics.time_in_queue - ttft = metrics.first_token_time - metrics.arrival_time - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft - e2e_time = metrics.finished_time - metrics.arrival_time - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time - assert metrics.scheduler_time > 0 - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER - ) == metrics.scheduler_time - assert metrics.model_forward_time > 0 - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD - ) == pytest.approx(metrics.model_forward_time / 1000) - assert 
metrics.model_execute_time > 0 - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE - ) == metrics.model_execute_time - assert metrics.model_forward_time < 1000 * metrics.model_execute_time diff --git a/tests/tensorizer_loader/__init__.py b/tests/transformers_utils/__init__.py similarity index 100% rename from tests/tensorizer_loader/__init__.py rename to tests/transformers_utils/__init__.py diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py new file mode 100644 index 0000000000000..9372cb9d46d30 --- /dev/null +++ b/tests/transformers_utils/test_config_parser_registry.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Optional, Union + +import pytest +from transformers import PretrainedConfig + +from vllm.transformers_utils.config import get_config_parser, register_config_parser +from vllm.transformers_utils.config_parser_base import ConfigParserBase + + +@register_config_parser("custom_config_parser") +class CustomConfigParser(ConfigParserBase): + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError + + +def test_register_config_parser(): + assert isinstance(get_config_parser("custom_config_parser"), CustomConfigParser) + + +def test_invalid_config_parser(): + with pytest.raises(ValueError): + + @register_config_parser("invalid_config_parser") + class InvalidConfigParser: + pass diff --git a/tests/utils.py b/tests/utils.py index 4dba5494665a3..b853542c241fc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,21 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import contextlib import copy import functools import importlib import json import os +import random import signal import subprocess import sys import tempfile import time import warnings -from contextlib import contextmanager, suppress +from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union +from unittest.mock import patch import cloudpickle import httpx @@ -30,20 +33,29 @@ from typing_extensions import ParamSpec import vllm.envs as envs from tests.models.utils import TextTextLogprobs -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import (FlexibleArgumentParser, GB_bytes, - cuda_device_count_stateless, get_open_port) +from vllm.utils import ( + FlexibleArgumentParser, + GB_bytes, + cuda_device_count_stateless, + get_open_port, +) if current_platform.is_rocm(): - from amdsmi import (amdsmi_get_gpu_vram_usage, - amdsmi_get_processor_handles, amdsmi_init, - amdsmi_shut_down) + from amdsmi import ( + amdsmi_get_gpu_vram_usage, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + ) @contextmanager def _nvml(): @@ -53,9 +65,12 
@@ if current_platform.is_rocm(): finally: amdsmi_shut_down() elif current_platform.is_cuda(): - from vllm.third_party.pynvml import (nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, nvmlInit, - nvmlShutdown) + from vllm.third_party.pynvml import ( + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, + ) @contextmanager def _nvml(): @@ -78,58 +93,61 @@ VLLM_PATH = Path(__file__).parent.parent class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key - def _start_server(self, model: str, vllm_serve_args: list[str], - env_dict: Optional[dict[str, str]]) -> None: - """Subclasses override this method to customize server process launch - """ + def _start_server( + self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + ) -> None: + """Subclasses override this method to customize server process launch""" env = os.environ.copy() # the current process might initialize cuda, # to be safe, we should use spawn method - env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" if env_dict is not None: env.update(env_dict) + serve_cmd = ["vllm", "serve", model, *vllm_serve_args] + print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}") self.proc: subprocess.Popen = subprocess.Popen( - ["vllm", "serve", model, *vllm_serve_args], + serve_cmd, env=env, stdout=sys.stdout, stderr=sys.stderr, ) - def __init__(self, - model: str, - vllm_serve_args: list[str], - *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, - auto_port: bool = True, - max_wait_seconds: Optional[float] = None, - override_hf_configs: Optional[dict[str, Any]] = None) -> None: + def __init__( + self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None, + ) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: - raise ValueError("You have manually specified the port " - "when `auto_port=True`.") + raise ValueError( + "You have manually specified the port when `auto_port=True`." + ) # No need for a port if using unix sockets if "--uds" not in vllm_serve_args: # Don't mutate the input args - vllm_serve_args = vllm_serve_args + [ - "--port", str(get_open_port()) - ] + vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())] if seed is not None: if "--seed" in vllm_serve_args: - raise ValueError("You have manually specified the seed " - f"when `seed={seed}`.") + raise ValueError( + f"You have manually specified the seed when `seed={seed}`." 
+ ) vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] if override_hf_configs is not None: vllm_serve_args = vllm_serve_args + [ "--hf-overrides", - json.dumps(override_hf_configs) + json.dumps(override_hf_configs), ] - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") parser = ServeSubcommand().subparser_init(subparsers) args = parser.parse_args(["--model", model, *vllm_serve_args]) @@ -138,11 +156,10 @@ class RemoteOpenAIServer: self.host = None self.port = None else: - self.host = str(args.host or 'localhost') + self.host = str(args.host or "localhost") self.port = int(args.port) - self.show_hidden_metrics = \ - args.show_hidden_metrics_for_version is not None + self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None # download the model before starting the server to avoid timeout is_local = os.path.isdir(model) @@ -156,8 +173,7 @@ class RemoteOpenAIServer: self._start_server(model, vllm_serve_args, env_dict) max_wait_seconds = max_wait_seconds or 240 - self._wait_for_server(url=self.url_for("health"), - timeout=max_wait_seconds) + self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) def __enter__(self): return self @@ -177,8 +193,11 @@ class RemoteOpenAIServer: def _wait_for_server(self, *, url: str, timeout: float): # run health check start = time.time() - client = (httpx.Client(transport=httpx.HTTPTransport( - uds=self.uds)) if self.uds else requests) + client = ( + httpx.Client(transport=httpx.HTTPTransport(uds=self.uds)) + if self.uds + else requests + ) while True: try: if client.get(url).status_code == 200: @@ -194,13 +213,15 @@ class RemoteOpenAIServer: time.sleep(0.5) if time.time() - start > timeout: - raise RuntimeError( - "Server failed to start in time.") from None + raise RuntimeError("Server failed to start in time.") from None @property def url_root(self) -> str: - return (f"http://{self.uds.split('/')[-1]}" - if self.uds else f"http://{self.host}:{self.port}") + return ( + f"http://{self.uds.split('/')[-1]}" + if self.uds + else f"http://{self.host}:{self.port}" + ) def url_for(self, *parts: str) -> str: return self.url_root + "/" + "/".join(parts) @@ -218,42 +239,47 @@ class RemoteOpenAIServer: def get_async_client(self, **kwargs): if "timeout" not in kwargs: kwargs["timeout"] = 600 - return openai.AsyncOpenAI(base_url=self.url_for("v1"), - api_key=self.DUMMY_API_KEY, - max_retries=0, - **kwargs) + return openai.AsyncOpenAI( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) class RemoteOpenAIServerCustom(RemoteOpenAIServer): """Launch test server with custom child process""" - def _start_server(self, model: str, vllm_serve_args: list[str], - env_dict: Optional[dict[str, str]]) -> None: + def _start_server( + self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + ) -> None: self.proc: Process = Process( - target=self.child_process_fxn, - args=(env_dict, model, - vllm_serve_args)) # type: ignore[assignment] + target=self.child_process_fxn, args=(env_dict, model, vllm_serve_args) + ) # type: ignore[assignment] self.proc.start() - def __init__(self, - model: str, - vllm_serve_args: list[str], - child_process_fxn: Callable[ - [Optional[dict[str, str]], str, list[str]], None], - *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, - auto_port: bool = True, - 
max_wait_seconds: Optional[float] = None) -> None: + def __init__( + self, + model: str, + vllm_serve_args: list[str], + child_process_fxn: Callable[[Optional[dict[str, str]], str, list[str]], None], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + ) -> None: """Store custom child process function then invoke superclass constructor which will indirectly launch it.""" self.child_process_fxn = child_process_fxn - super().__init__(model=model, - vllm_serve_args=vllm_serve_args, - env_dict=env_dict, - seed=seed, - auto_port=auto_port, - max_wait_seconds=max_wait_seconds) + super().__init__( + model=model, + vllm_serve_args=vllm_serve_args, + env_dict=env_dict, + seed=seed, + auto_port=auto_port, + max_wait_seconds=max_wait_seconds, + ) def _poll(self) -> Optional[int]: return self.proc.exitcode @@ -275,17 +301,18 @@ def _test_completion( results = [] # test with text prompt - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=5, - temperature=0.0) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=5, temperature=0.0 + ) - results.append({ - "test": "single_completion", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + results.append( + { + "test": "single_completion", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test using token IDs completion = client.completions.create( @@ -295,43 +322,42 @@ def _test_completion( temperature=0.0, ) - results.append({ - "test": "token_ids", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + results.append( + { + "test": "token_ids", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test seeded random sampling - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=5, - seed=33, - temperature=1.0) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0 + ) - results.append({ - "test": "seeded_sampling", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + results.append( + { + "test": "seeded_sampling", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test seeded random sampling with multiple prompts - completion = client.completions.create(model=model, - prompt=[prompt, prompt], - max_tokens=5, - seed=33, - temperature=1.0) + completion = client.completions.create( + model=model, prompt=[prompt, prompt], max_tokens=5, seed=33, temperature=1.0 + ) - results.append({ - "test": - "seeded_sampling", - "text": [choice.text for choice in completion.choices], - "finish_reason": - [choice.finish_reason for choice in completion.choices], - "usage": - completion.usage, - }) + results.append( + { + "test": "seeded_sampling", + "text": [choice.text for choice in completion.choices], + "finish_reason": [choice.finish_reason for choice in completion.choices], + "usage": completion.usage, + } + ) # test simple list batch = client.completions.create( @@ -341,11 +367,13 @@ def _test_completion( temperature=0.0, ) - results.append({ - "test": 
"simple_list", - "text0": batch.choices[0].text, - "text1": batch.choices[1].text, - }) + results.append( + { + "test": "simple_list", + "text0": batch.choices[0].text, + "text1": batch.choices[1].text, + } + ) # test streaming batch = client.completions.create( @@ -362,10 +390,12 @@ def _test_completion( choice = chunk.choices[0] texts[choice.index] += choice.text - results.append({ - "test": "streaming", - "texts": texts, - }) + results.append( + { + "test": "streaming", + "texts": texts, + } + ) return results @@ -378,19 +408,19 @@ def _test_completion_close( results = [] # test with text prompt - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=1, - logprobs=5, - temperature=0.0) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=1, logprobs=5, temperature=0.0 + ) logprobs = completion.choices[0].logprobs.top_logprobs[0] logprobs = {k: round(v, 2) for k, v in logprobs.items()} - results.append({ - "test": "completion_close", - "logprobs": logprobs, - }) + results.append( + { + "test": "completion_close", + "logprobs": logprobs, + } + ) return results @@ -402,26 +432,21 @@ def _test_chat( ): results = [] - messages = [{ - "role": "user", - "content": [{ - "type": "text", - "text": prompt - }] - }] + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] # test with text prompt - chat_response = client.chat.completions.create(model=model, - messages=messages, - max_tokens=5, - temperature=0.0) + chat_response = client.chat.completions.create( + model=model, messages=messages, max_tokens=5, temperature=0.0 + ) - results.append({ - "test": "completion_close", - "text": chat_response.choices[0].message.content, - "finish_reason": chat_response.choices[0].finish_reason, - "usage": chat_response.usage, - }) + results.append( + { + "test": "completion_close", + "text": chat_response.choices[0].message.content, + "finish_reason": chat_response.choices[0].finish_reason, + "usage": chat_response.usage, + } + ) return results @@ -440,11 +465,13 @@ def _test_embeddings( encoding_format="float", ) - results.append({ - "test": "single_embedding", - "embedding": embeddings.data[0].embedding, - "usage": embeddings.usage, - }) + results.append( + { + "test": "single_embedding", + "embedding": embeddings.data[0].embedding, + "usage": embeddings.usage, + } + ) return results @@ -457,74 +484,75 @@ def _test_image_text( results = [] # test pure text input - messages = [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "How do you feel today?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "How do you feel today?"}, + ], + } + ] - chat_completion = client.chat.completions.create(model=model_name, - messages=messages, - temperature=0.0, - max_tokens=1, - logprobs=True, - top_logprobs=5) + chat_completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5, + ) top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs for x in top_logprobs: x.logprob = round(x.logprob, 2) - results.append({ - "test": "pure_text", - "logprobs": top_logprobs, - }) + results.append( + { + "test": "pure_text", + "logprobs": top_logprobs, + } + ) - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] - chat_completion = client.chat.completions.create(model=model_name, - messages=messages, - temperature=0.0, - max_tokens=1, - logprobs=True, - top_logprobs=5) + chat_completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5, + ) top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs - results.append({ - "test": "text_image", - "logprobs": top_logprobs, - }) + results.append( + { + "test": "text_image", + "logprobs": top_logprobs, + } + ) return results -def compare_two_settings(model: str, - arg1: list[str], - arg2: list[str], - env1: Optional[dict[str, str]] = None, - env2: Optional[dict[str, str]] = None, - *, - method: str = "generate", - max_wait_seconds: Optional[float] = None) -> None: +def compare_two_settings( + model: str, + arg1: list[str], + arg2: list[str], + env1: Optional[dict[str, str]] = None, + env2: Optional[dict[str, str]] = None, + *, + method: str = "generate", + max_wait_seconds: Optional[float] = None, +) -> None: """ Launch API server with two different sets of arguments/environments and compare the results of the API calls. @@ -546,12 +574,14 @@ def compare_two_settings(model: str, ) -def compare_all_settings(model: str, - all_args: list[list[str]], - all_envs: list[Optional[dict[str, str]]], - *, - method: str = "generate", - max_wait_seconds: Optional[float] = None) -> None: +def compare_all_settings( + model: str, + all_args: list[list[str]], + all_envs: list[Optional[dict[str, str]]], + *, + method: str = "generate", + max_wait_seconds: Optional[float] = None, +) -> None: """ Launch API server with several different sets of arguments/environments and compare the results of the API calls with the first set of arguments. 
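# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# The reformatted `compare_two_settings` helper above launches one
# RemoteOpenAIServer per argument/environment set and asserts that both sets
# return identical results from the OpenAI-compatible API. A minimal,
# hypothetical invocation; the model name and CLI flags below are assumed
# placeholders, not taken from this diff:
from tests.utils import compare_two_settings

compare_two_settings(
    "facebook/opt-125m",              # assumed small model for illustration
    arg1=["--max-model-len=256", "--enforce-eager"],
    arg2=["--max-model-len=256"],
    env1={},
    env2={},
    method="generate",                # exercises the completions paths shown above
    max_wait_seconds=240,
)
# -------------------------------------------------------------------------------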
@@ -601,21 +631,22 @@ def compare_all_settings(model: str, args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT] compare_results: list = [] results = ref_results if i == 0 else compare_results - with RemoteOpenAIServer(model, - args, - env_dict=env, - max_wait_seconds=max_wait_seconds) as server: + with RemoteOpenAIServer( + model, args, env_dict=env, max_wait_seconds=max_wait_seconds + ) as server: client = server.get_client() # test models list models = client.models.list() models = models.data served_model = models[0] - results.append({ - "test": "models_list", - "id": served_model.id, - "root": served_model.root, - }) + results.append( + { + "test": "models_list", + "id": served_model.id, + "root": served_model.root, + } + ) if method == "generate": results += _test_completion(client, model, prompt, token_ids) @@ -625,8 +656,9 @@ def compare_all_settings(model: str, results += _test_chat(client, model, prompt) elif method == "generate_with_image": results += _test_image_text( - client, model, - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png" + client, + model, + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ) elif method == "encode": results += _test_embeddings(client, model, prompt) @@ -639,8 +671,7 @@ def compare_all_settings(model: str, ref_envs = all_envs[0] compare_args = all_args[i] compare_envs = all_envs[i] - for ref_result, compare_result in zip(ref_results, - compare_results): + for ref_result, compare_result in zip(ref_results, compare_results): ref_result = copy.deepcopy(ref_result) compare_result = copy.deepcopy(compare_result) if "embedding" in ref_result and method == "encode": @@ -651,7 +682,8 @@ def compare_all_settings(model: str, ) assert sim >= 0.999, ( f"Embedding for {model=} are not the same.\n" - f"cosine_similarity={sim}\n") + f"cosine_similarity={sim}\n" + ) del ref_result["embedding"] del compare_result["embedding"] assert ref_result == compare_result, ( @@ -659,7 +691,8 @@ def compare_all_settings(model: str, f"{ref_args=} {ref_envs=}\n" f"{compare_args=} {compare_envs=}\n" f"{ref_result=}\n" - f"{compare_result=}\n") + f"{compare_result=}\n" + ) def init_test_distributed_environment( @@ -674,7 +707,8 @@ def init_test_distributed_environment( world_size=pp_size * tp_size, rank=rank, distributed_init_method=distributed_init_method, - local_rank=local_rank) + local_rank=local_rank, + ) ensure_model_parallel_initialized(tp_size, pp_size) @@ -697,9 +731,16 @@ def multi_process_parallel( ray.init( runtime_env={ "working_dir": VLLM_PATH, - "excludes": - ["build", ".git", "cmake-build-*", "shellcheck", "dist"] - }) + "excludes": [ + "build", + ".git", + "cmake-build-*", + "shellcheck", + "dist", + "ep_kernels_workspace", + ], + } + ) distributed_init_port = get_open_port() refs = [] @@ -711,7 +752,8 @@ def multi_process_parallel( pp_size, rank, distributed_init_port, - ), ) + ), + ) ray.get(refs) ray.shutdown() @@ -740,11 +782,13 @@ def get_physical_device_indices(devices): @_nvml() -def wait_for_gpu_memory_to_clear(*, - devices: list[int], - threshold_bytes: Optional[int] = None, - threshold_ratio: Optional[float] = None, - timeout_s: float = 120) -> None: +def wait_for_gpu_memory_to_clear( + *, + devices: list[int], + threshold_bytes: Optional[int] = None, + threshold_ratio: Optional[float] = None, + timeout_s: float = 120, +) -> None: assert threshold_bytes is not None or threshold_ratio is not None # Use nvml instead of pytorch to reduce measurement error from torch cuda # context. 
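# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# `wait_for_gpu_memory_to_clear` (reformatted in the hunk below) polls device
# memory (NVML on CUDA, amdsmi on ROCm), sleeping 5 s between checks, until every
# listed device falls under the byte or ratio threshold, and raises ValueError
# once `timeout_s` elapses. A minimal, hypothetical call; the device index and
# the 2 GiB threshold are assumed values for illustration:
from tests.utils import wait_for_gpu_memory_to_clear

wait_for_gpu_memory_to_clear(
    devices=[0],                  # physical GPU indices to watch
    threshold_bytes=2 * 2**30,    # treat the GPU as clear below 2 GiB used
    timeout_s=60,                 # give up (ValueError) after 60 seconds
)
# -------------------------------------------------------------------------------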
@@ -765,29 +809,33 @@ def wait_for_gpu_memory_to_clear(*, gb_used = mem_info.used / 2**30 gb_total = mem_info.total / 2**30 output_raw[device] = (gb_used, gb_total) - output[device] = f'{gb_used:.02f}/{gb_total:.02f}' + output[device] = f"{gb_used:.02f}/{gb_total:.02f}" - print('gpu memory used/total (GiB): ', end='') + print("gpu memory used/total (GiB): ", end="") for k, v in output.items(): - print(f'{k}={v}; ', end='') - print('') + print(f"{k}={v}; ", end="") + print("") if threshold_bytes is not None: is_free = lambda used, total: used <= threshold_bytes / 2**30 - threshold = f"{threshold_bytes/2**30} GiB" + threshold = f"{threshold_bytes / 2**30} GiB" else: is_free = lambda used, total: used / total <= threshold_ratio threshold = f"{threshold_ratio:.2f}" dur_s = time.time() - start_time if all(is_free(used, total) for used, total in output_raw.values()): - print(f'Done waiting for free GPU memory on devices {devices=} ' - f'({threshold=}) {dur_s=:.02f}') + print( + f"Done waiting for free GPU memory on devices {devices=} " + f"({threshold=}) {dur_s=:.02f}" + ) break if dur_s >= timeout_s: - raise ValueError(f'Memory of devices {devices=} not free after ' - f'{dur_s=:.02f} ({threshold=})') + raise ValueError( + f"Memory of devices {devices=} not free after " + f"{dur_s=:.02f} ({threshold=})" + ) time.sleep(5) @@ -795,70 +843,139 @@ def wait_for_gpu_memory_to_clear(*, _P = ParamSpec("_P") -def fork_new_process_for_each_test( - f: Callable[_P, None]) -> Callable[_P, None]: +def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: """Decorator to fork a new process for each test function. See https://github.com/vllm-project/vllm/issues/7053 for more details. """ - @functools.wraps(f) + @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Make the process the leader of its own process group # to avoid sending SIGTERM to the parent process os.setpgrp() from _pytest.outcomes import Skipped - pid = os.fork() - print(f"Fork a new process to run a test {pid}") - if pid == 0: - try: - f(*args, **kwargs) - except Skipped as e: - # convert Skipped to exit code 0 - print(str(e)) - os._exit(0) - except Exception: - import traceback - traceback.print_exc() - os._exit(1) + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. + with ( + tempfile.NamedTemporaryFile( + delete=False, + mode="w+b", + prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", + suffix=".exc", + ) as exc_file, + ExitStack() as delete_after, + ): + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. + delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {"pickled_exception": e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. 
+ exc_to_serialize = { + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": tb_string, + } + try: + with open(exc_file_path, "wb") as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. + print(tb_string) + os._exit(1) + else: + os._exit(0) else: - os._exit(0) - else: - pgid = os.getpgid(pid) - _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes - os.killpg(pgid, signal.SIGTERM) - # restore the signal handler - signal.signal(signal.SIGTERM, old_signal_handler) - assert _exitcode == 0, (f"function {f} failed when called with" - f" args {args} and kwargs {kwargs}") + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with ( + contextlib.suppress(Exception), + open(exc_file_path, "rb") as f, + ): + exc_info = cloudpickle.load(f) + + if ( + original_exception := exc_info.get("pickled_exception") + ) is not None: + # Re-raise the actual exception object if it was + # successfully pickled. + assert isinstance(original_exception, Exception) + raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})" + ) from None return wrapper -def spawn_new_process_for_each_test( - f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to spawn a new process for each test function. 
- """ +def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function.""" @functools.wraps(f) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Check if we're already in a subprocess - if os.environ.get('RUNNING_IN_SUBPROCESS') == '1': + if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": # If we are, just run the function directly return f(*args, **kwargs) import torch.multiprocessing as mp + with suppress(RuntimeError): - mp.set_start_method('spawn') + mp.set_start_method("spawn") # Get the module module_name = f.__module__ # Create a process with environment variable set env = os.environ.copy() - env['RUNNING_IN_SUBPROCESS'] = '1' + env["RUNNING_IN_SUBPROCESS"] = "1" with tempfile.TemporaryDirectory() as tempdir: output_filepath = os.path.join(tempdir, "new_process.tmp") @@ -868,29 +985,29 @@ def spawn_new_process_for_each_test( cmd = [sys.executable, "-m", f"{module_name}"] - returned = subprocess.run(cmd, - input=input_bytes, - capture_output=True, - env=env) + returned = subprocess.run( + cmd, input=input_bytes, capture_output=True, env=env + ) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in subprocess:\n" - f"{returned.stderr.decode()}") from e + raise RuntimeError( + f"Error raised in subprocess:\n{returned.stderr.decode()}" + ) from e return wrapper def create_new_process_for_each_test( - method: Optional[Literal["spawn", "fork"]] = None + method: Optional[Literal["spawn", "fork"]] = None, ) -> Callable[[Callable[_P, None]], Callable[_P, None]]: """Creates a decorator that runs each test function in a new process. Args: - method: The process creation method. Can be either "spawn" or "fork". + method: The process creation method. Can be either "spawn" or "fork". If not specified, it defaults to "spawn" on ROCm and XPU platforms and "fork" otherwise. @@ -901,8 +1018,7 @@ def create_new_process_for_each_test( use_spawn = current_platform.is_rocm() or current_platform.is_xpu() method = "spawn" if use_spawn else "fork" - assert method in ["spawn", - "fork"], "Method must be either 'spawn' or 'fork'" + assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" if method == "fork": return fork_new_process_for_each_test @@ -986,7 +1102,7 @@ async def completions_with_server_args( max_wait_seconds: int = 240, max_tokens: Union[int, list] = 5, ) -> list[Completion]: - '''Construct a remote OpenAI server, obtain an async client to the + """Construct a remote OpenAI server, obtain an async client to the server & invoke the completions API to obtain completions. 
Args: @@ -1002,7 +1118,7 @@ async def completions_with_server_args( Returns: OpenAI Completion instance - ''' + """ if isinstance(max_tokens, int): max_tokens = [max_tokens] * len(prompts) @@ -1010,17 +1126,21 @@ async def completions_with_server_args( assert len(max_tokens) == len(prompts) outputs = None - with RemoteOpenAIServer(model_name, - server_cli_args, - max_wait_seconds=max_wait_seconds) as server: + with RemoteOpenAIServer( + model_name, server_cli_args, max_wait_seconds=max_wait_seconds + ) as server: client = server.get_async_client() - outputs = [ client.completions.create(model=model_name, - prompt=[p], - temperature=0, - stream=False, - max_tokens=max_tok, - logprobs=num_logprobs) \ - for p, max_tok in zip(prompts, max_tokens) ] + outputs = [ + client.completions.create( + model=model_name, + prompt=[p], + temperature=0, + stream=False, + max_tokens=max_tok, + logprobs=num_logprobs, + ) + for p, max_tok in zip(prompts, max_tokens) + ] outputs = await asyncio.gather(*outputs) assert outputs is not None, "Completion API call failed." @@ -1029,24 +1149,31 @@ async def completions_with_server_args( def get_client_text_generations(completions: list[Completion]) -> list[str]: - '''Extract generated tokens from the output of a + """Extract generated tokens from the output of a request made to an Open-AI-protocol completions endpoint. - ''' + """ assert all([len(x.choices) == 1 for x in completions]) return [x.choices[0].text for x in completions] def get_client_text_logprob_generations( - completions: list[Completion]) -> list[TextTextLogprobs]: - '''Operates on the output of a request made to an Open-AI-protocol + completions: list[Completion], +) -> list[TextTextLogprobs]: + """Operates on the output of a request made to an Open-AI-protocol completions endpoint; obtains top-rank logprobs for each token in each {class}`SequenceGroup` - ''' + """ text_generations = get_client_text_generations(completions) - text = ''.join(text_generations) - return [(text_generations, text, - (None if x.logprobs is None else x.logprobs.top_logprobs)) - for completion in completions for x in completion.choices] + text = "".join(text_generations) + return [ + ( + text_generations, + text, + (None if x.logprobs is None else x.logprobs.top_logprobs), + ) + for completion in completions + for x in completion.choices + ] def has_module_attribute(module_name, attribute_name): @@ -1062,15 +1189,74 @@ def has_module_attribute(module_name, attribute_name): def get_attn_backend_list_based_on_platform() -> list[str]: if current_platform.is_cuda(): - return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"] + return ["FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN"] elif current_platform.is_rocm(): - attn_backend_list = ["TRITON_ATTN_VLLM_V1"] + attn_backend_list = ["TRITON_ATTN"] try: import aiter # noqa: F401 - attn_backend_list.append("FLASH_ATTN_VLLM_V1") + + attn_backend_list.append("FLASH_ATTN") except Exception: - print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed") + print("Skip FLASH_ATTN on ROCm as aiter is not installed") return attn_backend_list + elif current_platform.is_xpu(): + return ["FLASH_ATTN", "TRITON_ATTN"] else: raise ValueError("Unsupported platform") + + +@contextmanager +def override_cutlass_fp8_supported(value: bool): + with patch( + "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", + return_value=value, + ): + yield + + +def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): + """ + Generate prompts which a bunch of 
assignments, + then asking for the value of one of them. + The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct. + Args: + batch_size: number of prompts to generate + ln_range: an argument to control the length of the prompt + """ + prompts: list[str] = [] + answer: list[int] = [] + indices: list[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = ( + "```python\n# We set a number of variables, " + + f"x{idx} will be important later\n" + ) + ln = random.randint(*ln_range) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers( + indices: list[int], answer: list[int], outputs: list[str], accept_rate: float = 0.7 +): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok >= accept_rate diff --git a/tests/utils_/test_cache.py b/tests/utils_/test_cache.py new file mode 100644 index 0000000000000..e361006fd8e66 --- /dev/null +++ b/tests/utils_/test_cache.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.utils.cache import CacheInfo, LRUCache + + +class TestLRUCache(LRUCache): + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + + assert cache.get(2) == 2 + assert cache.stat() == CacheInfo(hits=1, total=1) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + assert cache[2] == 2 + assert cache.stat() == CacheInfo(hits=2, total=2) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + assert cache.get(-1) is None + assert cache.stat() == CacheInfo(hits=2, total=3) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=1) + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + 
assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/utils_/test_gc_utils.py b/tests/utils_/test_gc_utils.py new file mode 100644 index 0000000000000..f1d0de87c81ba --- /dev/null +++ b/tests/utils_/test_gc_utils.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any + +from vllm.utils.gc_utils import ( + GCDebugConfig, + _compute_detailed_type, + _compute_top_gc_collected_objects, +) + + +@dataclass +class Normal: + v: int + + +@dataclass +class ListWrapper: + vs: list[int] + + def __len__(self) -> int: + return len(self.vs) + + +def test_compute_detailed_type(): + assert ( + _compute_detailed_type(Normal(v=8)) + == "<class 'tests.utils_.test_gc_utils.Normal'>" + ) + + assert _compute_detailed_type([1, 2, 3]) == "<class 'list'>(size:3)" + assert _compute_detailed_type({4, 5}) == "<class 'set'>(size:2)" + assert _compute_detailed_type({6: 7}) == "<class 'dict'>(size:1)" + assert ( + _compute_detailed_type(ListWrapper(vs=[])) + == "<class 'tests.utils_.test_gc_utils.ListWrapper'>(size:0)" + ) + + +def test_compute_top_gc_collected_objects(): + objects: list[Any] = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + {13, 14}, + {15: 16, 17: 18}, + Normal(v=19), + Normal(v=20), + Normal(v=21), + ] + assert _compute_top_gc_collected_objects(objects, top=-1) == "" + assert _compute_top_gc_collected_objects(objects, top=0) == "" + assert ( + _compute_top_gc_collected_objects(objects, top=1) + == " 4:<class 'list'>(size:3)" + ) + assert _compute_top_gc_collected_objects(objects, top=2) == "\n".join( + [ + " 4:<class 'list'>(size:3)", + " 3:<class 'tests.utils_.test_gc_utils.Normal'>", + ] + ) + assert _compute_top_gc_collected_objects(objects, top=3) == "\n".join( + [ + " 4:<class 'list'>(size:3)", + " 3:<class 'tests.utils_.test_gc_utils.Normal'>", + " 1:<class 'set'>(size:2)", + ] + ) + + +def test_gc_debug_config(): + assert not GCDebugConfig(None).enabled + assert not GCDebugConfig("").enabled + assert not GCDebugConfig("0").enabled + + config = GCDebugConfig("1") + assert config.enabled + assert config.top_objects == -1 + + config = GCDebugConfig('{"top_objects":5}') + assert config.enabled + assert config.top_objects == 5 diff --git a/tests/utils_/test_tensor_schema.py b/tests/utils_/test_tensor_schema.py index 6aa781c1564de..c86bed75472c9 100644 --- a/tests/utils_/test_tensor_schema.py +++ b/tests/utils_/test_tensor_schema.py @@ -6,37 +6,38 @@ import torch from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from 
vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs +from vllm.model_executor.models.hyperclovax_vision import HCXVisionVideoPixelInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs def test_tensor_schema_valid_tensor(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_optional_fields(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=None, ) - Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), ) + Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32)) def test_tensor_schema_constant_dim_failure(): with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 + pixel_values=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_invalid_types_in_list(): - with pytest.raises(ValueError, match="is not a torch.Tensor"): + with pytest.raises(TypeError, match="is not one of the expected types"): Phi3VImagePixelInputs( - data=[ + pixel_values=[ torch.randn(64, 3, 32, 32), "not_a_tensor", torch.randn(64, 3, 32, 32), @@ -48,27 +49,29 @@ def test_tensor_schema_invalid_types_in_list(): def test_tensor_schema_rank_mismatch(): with pytest.raises(ValueError, match="has rank 3 but expected 5"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3), + pixel_values=torch.randn(16, 64, 3), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_missing_required_field(): - with pytest.raises(ValueError, match="Required field 'data' is missing"): - Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), ) + with pytest.raises(ValueError, match="Required field 'pixel_values' is missing"): + Phi3VImagePixelInputs( + image_sizes=torch.randint(0, 256, (16, 2)), + ) def test_tensor_schema_symbolic_dim_mismatch(): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): Phi3VImagePixelInputs( - data=torch.randn(12, 64, 3, 32, 32), + pixel_values=torch.randn(12, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_list_tensor_valid(): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32) for _ in range(16)], + pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)], image_sizes=torch.randint(0, 256, (16, 2)), ) @@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid(): def test_tensor_schema_variable_patch_counts_valid(): # Each image has a different number of patches (p) # Each tensor has shape (p, 3, 32, 32) - data = [ - torch.randn(16, 3, 32, 32), # p = 16 - torch.randn(32, 3, 32, 32), # p = 32 - torch.randn(64, 3, 32, 32), # p = 64 - ] - image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3 Phi3VImagePixelInputs( - data=data, - image_sizes=image_sizes, + pixel_values=[ + torch.randn(16, 3, 32, 32), # p = 16 + torch.randn(32, 3, 32, 32), # p = 32 + torch.randn(64, 3, 32, 32), # p = 64 + ], + image_sizes=torch.randint(0, 256, (3, 2)), # bn = 3 ) def test_tensor_schema_tuple_tensor_valid(): Phi3VImagePixelInputs( - data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), + pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), image_sizes=torch.randint(0, 256, (16, 2)), ) +def test_tensor_schema_double_nested_tensors(): + x = torch.rand(4, 3, 32, 32) + y = torch.rand(2, 3, 32, 32) + + 
HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y])) + + def test_tensor_schema_inconsistent_shapes_in_list(): with pytest.raises(ValueError, match="contains inconsistent shapes"): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32), - torch.randn(64, 3, 16, 16)] + - [torch.randn(64, 3, 32, 32) for _ in range(14)], + pixel_values=[ + torch.randn(64, 3, 32, 32), + torch.randn(64, 3, 16, 16), + *(torch.randn(64, 3, 32, 32) for _ in range(14)), + ], image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_empty_list(): - with pytest.raises(ValueError, match="is an empty list"): + with pytest.raises(ValueError, match="is an empty sequence"): Phi3VImagePixelInputs( - data=[], + pixel_values=[], image_sizes=torch.randint(0, 256, (0, 2)), ) @@ -117,39 +127,33 @@ def test_tensor_schema_validation_disabled_skips_shape_check(): # This should NOT raise, because validation is turned off # This would normally fail (dim[2] should be 3, not 4) Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), + pixel_values=torch.randn(16, 64, 4, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), validate=False, ) def test_tensor_schema_with_valid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 + pixel_values = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 image_sizes = torch.randint(0, 256, (16, 2)) Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, - resolve_bindings={ - "h": 336, - "w": 336 - }, + resolve_bindings={"h": 336, "w": 336}, ) def test_tensor_schema_with_invalid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 + pixel_values = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 image_sizes = torch.randint(0, 256, (16, 2)) # Should raise because 'h' and 'w' don't match resolve bindings with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"): Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, - resolve_bindings={ - "h": 336, - "w": 336 - }, + resolve_bindings={"h": 336, "w": 336}, ) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 084d82dee11b3..308629ab05834 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -5,35 +5,53 @@ import asyncio import hashlib import json +import os import pickle import socket +import tempfile from collections.abc import AsyncIterator +from pathlib import Path from unittest.mock import patch import pytest import torch +import yaml import zmq from transformers import AutoTokenizer from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.transformers_utils.detokenizer_utils import ( - convert_ids_list_to_tokens) -from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, - MemorySnapshot, PlaceholderModule, StoreBoolean, - bind_kv_cache, common_broadcastable_dtype, - current_stream, deprecate_kwargs, get_open_port, - get_tcp_uri, is_lossless_cast, join_host_port, - make_zmq_path, make_zmq_socket, memory_profiling, - merge_async_iterators, sha256, split_host_port, - split_zmq_path, supports_kw, swap_dict_values) +from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens + +from vllm.utils import ( + FlexibleArgumentParser, + MemorySnapshot, + PlaceholderModule, + bind_kv_cache, + common_broadcastable_dtype, + current_stream, + deprecate_kwargs, + get_open_port, + get_tcp_uri, + is_lossless_cast, + join_host_port, + make_zmq_path, + 
make_zmq_socket, + memory_profiling, + merge_async_iterators, + sha256, + split_host_port, + split_zmq_path, + supports_kw, + swap_dict_values, + unique_filepath, +) from ..utils import create_new_process_for_each_test, error_on_warning @pytest.mark.asyncio async def test_merge_async_iterators(): - async def mock_async_iterator(idx: int): try: while True: @@ -57,8 +75,7 @@ async def test_merge_async_iterators(): for iterator in iterators: try: - # Can use anext() in python >= 3.10 - await asyncio.wait_for(iterator.__anext__(), 1) + await asyncio.wait_for(anext(iterator), 1) except StopAsyncIteration: # All iterators should be cancelled and print this message. print("Iterator was cancelled normally") @@ -67,7 +84,6 @@ async def test_merge_async_iterators(): def test_deprecate_kwargs_always(): - @deprecate_kwargs("old_arg", is_deprecated=True) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -80,7 +96,6 @@ def test_deprecate_kwargs_always(): def test_deprecate_kwargs_never(): - @deprecate_kwargs("old_arg", is_deprecated=False) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -115,7 +130,6 @@ def test_deprecate_kwargs_dynamic(): def test_deprecate_kwargs_additional_message(): - @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -140,99 +154,107 @@ def test_get_open_port(monkeypatch: pytest.MonkeyPatch): @pytest.fixture def parser(): parser = FlexibleArgumentParser() - parser.add_argument('--image-input-type', - choices=['pixel_values', 'image_features']) - parser.add_argument('--model-name') - parser.add_argument('--batch-size', type=int) - parser.add_argument('--enable-feature', action='store_true') - parser.add_argument('--hf-overrides', type=json.loads) - parser.add_argument('-O', '--compilation-config', type=json.loads) + parser.add_argument( + "--image-input-type", choices=["pixel_values", "image_features"] + ) + parser.add_argument("--model-name") + parser.add_argument("--batch-size", type=int) + parser.add_argument("--enable-feature", action="store_true") + parser.add_argument("--hf-overrides", type=json.loads) + parser.add_argument("-O", "--compilation-config", type=json.loads) return parser @pytest.fixture def parser_with_config(): parser = FlexibleArgumentParser() - parser.add_argument('serve') - parser.add_argument('model_tag', nargs='?') - parser.add_argument('--model', type=str) - parser.add_argument('--served-model-name', type=str) - parser.add_argument('--config', type=str) - parser.add_argument('--port', type=int) - parser.add_argument('--tensor-parallel-size', type=int) - parser.add_argument('--trust-remote-code', action='store_true') + parser.add_argument("serve") + parser.add_argument("model_tag", nargs="?") + parser.add_argument("--model", type=str) + parser.add_argument("--served-model-name", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--port", type=int) + parser.add_argument("--tensor-parallel-size", type=int) + parser.add_argument("--trust-remote-code", action="store_true") return parser def test_underscore_to_dash(parser): - args = parser.parse_args(['--image_input_type', 'pixel_values']) - assert args.image_input_type == 'pixel_values' + args = parser.parse_args(["--image_input_type", "pixel_values"]) + assert args.image_input_type == "pixel_values" def test_mixed_usage(parser): - args = parser.parse_args([ - '--image_input_type', 'image_features', '--model-name', - 'facebook/opt-125m' - ]) - assert 
args.image_input_type == 'image_features' - assert args.model_name == 'facebook/opt-125m' + args = parser.parse_args( + ["--image_input_type", "image_features", "--model-name", "facebook/opt-125m"] + ) + assert args.image_input_type == "image_features" + assert args.model_name == "facebook/opt-125m" def test_with_equals_sign(parser): args = parser.parse_args( - ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m']) - assert args.image_input_type == 'pixel_values' - assert args.model_name == 'facebook/opt-125m' + ["--image_input_type=pixel_values", "--model-name=facebook/opt-125m"] + ) + assert args.image_input_type == "pixel_values" + assert args.model_name == "facebook/opt-125m" def test_with_int_value(parser): - args = parser.parse_args(['--batch_size', '32']) + args = parser.parse_args(["--batch_size", "32"]) assert args.batch_size == 32 - args = parser.parse_args(['--batch-size', '32']) + args = parser.parse_args(["--batch-size", "32"]) assert args.batch_size == 32 def test_with_bool_flag(parser): - args = parser.parse_args(['--enable_feature']) + args = parser.parse_args(["--enable_feature"]) assert args.enable_feature is True - args = parser.parse_args(['--enable-feature']) + args = parser.parse_args(["--enable-feature"]) assert args.enable_feature is True def test_invalid_choice(parser): with pytest.raises(SystemExit): - parser.parse_args(['--image_input_type', 'invalid_choice']) + parser.parse_args(["--image_input_type", "invalid_choice"]) def test_missing_required_argument(parser): - parser.add_argument('--required-arg', required=True) + parser.add_argument("--required-arg", required=True) with pytest.raises(SystemExit): parser.parse_args([]) def test_cli_override_to_config(parser_with_config, cli_config_file): - args = parser_with_config.parse_args([ - 'serve', 'mymodel', '--config', cli_config_file, - '--tensor-parallel-size', '3' - ]) + args = parser_with_config.parse_args( + ["serve", "mymodel", "--config", cli_config_file, "--tensor-parallel-size", "3"] + ) assert args.tensor_parallel_size == 3 - args = parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - cli_config_file - ]) + args = parser_with_config.parse_args( + ["serve", "mymodel", "--tensor-parallel-size", "3", "--config", cli_config_file] + ) assert args.tensor_parallel_size == 3 assert args.port == 12312 - args = parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - cli_config_file, '--port', '666' - ]) + args = parser_with_config.parse_args( + [ + "serve", + "mymodel", + "--tensor-parallel-size", + "3", + "--config", + cli_config_file, + "--port", + "666", + ] + ) assert args.tensor_parallel_size == 3 assert args.port == 666 def test_config_args(parser_with_config, cli_config_file): args = parser_with_config.parse_args( - ['serve', 'mymodel', '--config', cli_config_file]) + ["serve", "mymodel", "--config", cli_config_file] + ) assert args.tensor_parallel_size == 2 assert args.trust_remote_code @@ -240,22 +262,31 @@ def test_config_args(parser_with_config, cli_config_file): def test_config_file(parser_with_config): with pytest.raises(FileNotFoundError): parser_with_config.parse_args( - ['serve', 'mymodel', '--config', 'test_config.yml']) + ["serve", "mymodel", "--config", "test_config.yml"] + ) with pytest.raises(ValueError): parser_with_config.parse_args( - ['serve', 'mymodel', '--config', './data/test_config.json']) + ["serve", "mymodel", "--config", "./data/test_config.json"] + ) with 
pytest.raises(ValueError): - parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - '--batch-size', '32' - ]) + parser_with_config.parse_args( + [ + "serve", + "mymodel", + "--tensor-parallel-size", + "3", + "--config", + "--batch-size", + "32", + ] + ) def test_no_model_tag(parser_with_config, cli_config_file): with pytest.raises(ValueError): - parser_with_config.parse_args(['serve', '--config', cli_config_file]) + parser_with_config.parse_args(["serve", "--config", cli_config_file]) def test_dict_args(parser): @@ -318,7 +349,7 @@ def test_dict_args(parser): }, "key14": { "key15": "-minus.and.dot", - } + }, } assert parsed_args.compilation_config == { "level": 1, @@ -352,7 +383,6 @@ def test_duplicate_dict_args(caplog_vllm, parser): assert "-O.level" in caplog_vllm.text -# yapf: enable @pytest.mark.parametrize( "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", [ @@ -370,24 +400,28 @@ def test_duplicate_dict_args(caplog_vllm, parser): (lambda foo, **kwargs: None, "something_else", False, True, True), (lambda foo, **kwargs: None, "kwargs", True, True, False), (lambda foo, **kwargs: None, "foo", True, True, False), - ]) -# yapf: disable -def test_supports_kw(callable,kw_name,requires_kw_only, - allow_var_kwargs,is_supported): - assert supports_kw( - callable=callable, - kw_name=kw_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs - ) == is_supported + ], +) +def test_supports_kw( + callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported +): + assert ( + supports_kw( + callable=callable, + kw_name=kw_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + == is_supported + ) @create_new_process_for_each_test() def test_memory_profiling(): # Fake out some model loading + inference memory usage to test profiling # Memory used by other processes will show up as cuda usage outside of torch - from vllm.distributed.device_communicators.cuda_wrapper import ( - CudaRTLibrary) + from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + lib = CudaRTLibrary() # 512 MiB allocation outside of this instance handle1 = lib.cudaMalloc(512 * 1024 * 1024) @@ -396,9 +430,9 @@ def test_memory_profiling(): # load weights - weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) + weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32) - weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB + weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB def measure_current_non_torch(): free, total = torch.cuda.mem_get_info() @@ -407,11 +441,14 @@ def test_memory_profiling(): current_non_torch = current_used - current_torch return current_non_torch - with memory_profiling(baseline_snapshot=baseline_snapshot, - weights_memory=weights_memory) as result, \ - monitor(measure_current_non_torch) as monitored_values: + with ( + memory_profiling( + baseline_snapshot=baseline_snapshot, weights_memory=weights_memory + ) as result, + monitor(measure_current_non_torch) as monitored_values, + ): # make a memory spike, 1 GiB - spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) + spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32) del spike # Add some extra non-torch memory 256 MiB (simulate NCCL) @@ -426,7 +463,7 @@ def test_memory_profiling(): # 5% tolerance is caused by cuda runtime. 
# we cannot control cuda runtime in the granularity of bytes, # which causes a small error (<10 MiB in practice) - non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa + non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa assert abs(non_torch_ratio - 1) <= 0.05 assert result.torch_peak_increase == 1024 * 1024 * 1024 del weights @@ -438,237 +475,83 @@ def test_bind_kv_cache(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), ] bind_kv_cache(ctx, [kv_cache]) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3] + def test_bind_kv_cache_kv_sharing(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), ] shared_kv_cache_layers = { - 'layers.2.self_attn': 'layers.1.self_attn', - 'layers.3.self_attn': 'layers.0.self_attn' + "layers.2.self_attn": "layers.1.self_attn", + "layers.3.self_attn": "layers.0.self_attn", } bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0] + def test_bind_kv_cache_non_attention(): from vllm.attention import Attention # example from Jamba PP=2 ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), + "model.layers.20.attn": Attention(32, 128, 0.1), + "model.layers.28.attn": Attention(32, 128, 0.1), } kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), ] bind_kv_cache(ctx, [kv_cache]) - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0] - assert 
ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] - - -def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch): - # V1 TESTS: ENCODER_DECODER is not supported on V1 yet. - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - - from vllm.attention import Attention, AttentionType - - # example from bart - ctx = { - 'encoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), - 'decoder.layers.0.encoder_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), - 'decoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), - } - - kv_cache = [ - torch.zeros((1, )), - ] - encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache - - bind_kv_cache(ctx, [kv_cache]) - assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache - assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] - assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] + assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0] + assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1] def test_bind_kv_cache_pp(): with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs - cfg = VllmConfig( - parallel_config=ParallelConfig(pipeline_parallel_size=2)) + cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), } - kv_cache = [ - [torch.zeros((1, ))], - [torch.zeros((1, ))] - ] + kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]] bind_kv_cache(ctx, kv_cache) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0] - assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0] + assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0] -class TestLRUCache(LRUCache): - - def _on_remove(self, key, value): - if not hasattr(self, "_remove_counter"): - self._remove_counter = 0 - self._remove_counter += 1 - - -def test_lru_cache(): - cache = TestLRUCache(3) - assert cache.stat() == CacheInfo(hits=0, total=0) - assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) - - cache.put(1, 1) - assert len(cache) == 1 - - cache.put(1, 1) - assert len(cache) == 1 - - cache.put(2, 2) - assert len(cache) == 2 - - cache.put(3, 3) - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache.put(4, 4) - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - - assert cache.get(2) == 2 - assert cache.stat() == CacheInfo(hits=1, total=1) - assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) - - assert cache[2] == 2 - assert cache.stat() == CacheInfo(hits=2, total=2) - assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) - - cache.put(5, 5) - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - assert cache.pop(5) == 5 - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - assert cache.get(-1) is None - assert cache.stat() == CacheInfo(hits=2, total=3) - assert cache.stat(delta=True) == CacheInfo(hits=0, total=1) - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.get(10) - assert len(cache) == 2 - assert 
set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.put(6, 6) - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - cache.remove_oldest() - assert len(cache) == 2 - assert set(cache.cache) == {2, 6} - assert cache._remove_counter == 4 - - cache.clear() - assert len(cache) == 0 - assert cache._remove_counter == 6 - assert cache.stat() == CacheInfo(hits=0, total=0) - assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) - - cache._remove_counter = 0 - - cache[1] = 1 - assert len(cache) == 1 - - cache[1] = 1 - assert len(cache) == 1 - - cache[2] = 2 - assert len(cache) == 2 - - cache[3] = 3 - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache[4] = 4 - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - assert cache[2] == 2 - - cache[5] = 5 - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - del cache[5] - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache[6] = 6 - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - -# yapf: disable @pytest.mark.parametrize( ("src_dtype", "tgt_dtype", "expected_result"), [ @@ -702,12 +585,10 @@ def test_lru_cache(): (torch.complex64, torch.complex32, False), ], ) -# yapf: enable def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result -# yapf: disable @pytest.mark.parametrize( ("dtypes", "expected_result"), [ @@ -717,7 +598,6 @@ def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 ], ) -# yapf: enable def test_common_broadcastable_dtype(dtypes, expected_result): assert common_broadcastable_dtype(dtypes) == expected_result @@ -762,7 +642,6 @@ def test_placeholder_module_error_handling(): _ = placeholder_attr.module -# yapf: disable @pytest.mark.parametrize( "obj,key1,key2", [ @@ -772,8 +651,8 @@ def test_placeholder_module_error_handling(): ({1: "a", 2: "b"}, 1, 3), # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), - ]) -# yapf: enable + ], +) def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) @@ -787,66 +666,103 @@ def test_swap_dict_values(obj, key1, key2): assert key1 not in obj -def test_model_specification(parser_with_config, cli_config_file, - cli_config_file_with_model): +def test_model_specification( + parser_with_config, cli_config_file, cli_config_file_with_model +): # Test model in CLI takes precedence over config args = parser_with_config.parse_args( - ['serve', 'cli-model', '--config', cli_config_file_with_model]) - assert args.model_tag == 'cli-model' - assert args.served_model_name == 'mymodel' + ["serve", "cli-model", "--config", cli_config_file_with_model] + ) + assert args.model_tag == "cli-model" + assert args.served_model_name == "mymodel" # Test model from config file works - args = parser_with_config.parse_args([ - 'serve', - '--config', - cli_config_file_with_model, - ]) - assert args.model == 'config-model' - assert args.served_model_name == 'mymodel' + args = parser_with_config.parse_args( + [ + "serve", + "--config", + cli_config_file_with_model, + ] + ) + assert 
args.model == "config-model" + assert args.served_model_name == "mymodel" # Test no model specified anywhere raises error with pytest.raises(ValueError, match="No model specified!"): - parser_with_config.parse_args(['serve', '--config', cli_config_file]) + parser_with_config.parse_args(["serve", "--config", cli_config_file]) # Test using --model option raises error - with pytest.raises( - ValueError, - match= - ("With `vllm serve`, you should provide the model as a positional " - "argument or in a config file instead of via the `--model` option."), - ): - parser_with_config.parse_args(['serve', '--model', 'my-model']) + # with pytest.raises( + # ValueError, + # match= + # ("With `vllm serve`, you should provide the model as a positional " + # "argument or in a config file instead of via the `--model` option."), + # ): + # parser_with_config.parse_args(['serve', '--model', 'my-model']) + + # Test using --model option back-compatibility + # (when back-compatibility ends, the above test should be uncommented + # and the below test should be removed) + args = parser_with_config.parse_args( + [ + "serve", + "--tensor-parallel-size", + "2", + "--model", + "my-model", + "--trust-remote-code", + "--port", + "8001", + ] + ) + assert args.model is None + assert args.tensor_parallel_size == 2 + assert args.trust_remote_code is True + assert args.port == 8001 + + args = parser_with_config.parse_args( + [ + "serve", + "--tensor-parallel-size=2", + "--model=my-model", + "--trust-remote-code", + "--port=8001", + ] + ) + assert args.model is None + assert args.tensor_parallel_size == 2 + assert args.trust_remote_code is True + assert args.port == 8001 # Test other config values are preserved - args = parser_with_config.parse_args([ - 'serve', - 'cli-model', - '--config', - cli_config_file_with_model, - ]) + args = parser_with_config.parse_args( + [ + "serve", + "cli-model", + "--config", + cli_config_file_with_model, + ] + ) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True assert args.port == 12312 -@pytest.mark.parametrize("input", [(), ("abc", ), (None, ), - (None, bool, [1, 2, 3])]) -@pytest.mark.parametrize("output", [0, 1, 2]) -def test_sha256(input: tuple, output: int): - hash = sha256(input) - assert hash is not None - assert isinstance(hash, int) - assert hash != 0 +@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])]) +def test_sha256(input: tuple): + digest = sha256(input) + assert digest is not None + assert isinstance(digest, bytes) + assert digest != b"" - bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), - byteorder="big") + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + assert digest == hashlib.sha256(input_bytes).digest() # hashing again, returns the same value - assert hash == sha256(input) + assert digest == sha256(input) # hashing different input, returns different value - assert hash != sha256(input + (1, )) + assert digest != sha256(input + (1,)) @pytest.mark.parametrize( @@ -856,7 +772,8 @@ def test_sha256(input: tuple, output: int): ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ]) + ], +) def test_split_zmq_path(path, expected): assert split_zmq_path(path) == expected @@ -868,7 +785,8 @@ def test_split_zmq_path(path, expected): "tcp://127.0.0.1", # Missing port "tcp://[::1]", # Missing port for 
IPv6 "tcp://:5555", # Missing host - ]) + ], +) def test_split_zmq_path_invalid(invalid_path): with pytest.raises(ValueError): split_zmq_path(invalid_path) @@ -890,8 +808,9 @@ def test_make_zmq_socket_ipv6(): zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) # Verify that the IPV6 option is set - assert zsock.getsockopt( - zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + assert zsock.getsockopt(zmq.IPV6) == 1, ( + "IPV6 option should be enabled for IPv6 addresses" + ) # Clean up zsock.close() @@ -944,19 +863,48 @@ def test_join_host_port(): assert join_host_port("::1", 5555) == "[::1]:5555" +def test_json_count_leaves(): + """Test json_count_leaves function from jsontree utility.""" + from vllm.utils.jsontree import json_count_leaves + + # Single leaf values + assert json_count_leaves(42) == 1 + assert json_count_leaves("hello") == 1 + assert json_count_leaves(None) == 1 + + # Empty containers + assert json_count_leaves([]) == 0 + assert json_count_leaves({}) == 0 + assert json_count_leaves(()) == 0 + + # Flat structures + assert json_count_leaves([1, 2, 3]) == 3 + assert json_count_leaves({"a": 1, "b": 2}) == 2 + assert json_count_leaves((1, 2, 3)) == 3 + + # Nested structures + nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} + assert json_count_leaves(nested_dict) == 3 + + nested_list = [1, [2, 3], 4] + assert json_count_leaves(nested_list) == 4 + + mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} + assert json_count_leaves(mixed_nested) == 4 + + def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") # token_ids = [9707, 11, 1879, 0] - assert tokenizer.convert_ids_to_tokens(token_ids) == [ - 'Hello', ',', 'Ġworld', '!' 
- ] + assert tokenizer.convert_ids_to_tokens(token_ids) == ["Hello", ",", "Ġworld", "!"] tokens = convert_ids_list_to_tokens(tokenizer, token_ids) - assert tokens == ['Hello', ',', ' world', '!'] + assert tokens == ["Hello", ",", " world", "!"] def test_current_stream_multithread(): import threading + if not torch.cuda.is_available(): pytest.skip("CUDA not available") @@ -975,13 +923,18 @@ def test_current_stream_multithread(): child_thread.start() try: - assert thread_stream_ready.wait( - timeout=5), "Child thread failed to enter stream context in time" + assert thread_stream_ready.wait(timeout=5), ( + "Child thread failed to enter stream context in time" + ) main_current_stream = current_stream() - assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread" - assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream" + assert main_current_stream != child_stream, ( + "Main thread's current_stream was contaminated by child thread" + ) + assert main_current_stream == main_default_stream, ( + "Main thread's current_stream is not the default stream" + ) # Notify child thread it can exit thread_can_exit.set() @@ -991,3 +944,52 @@ def test_current_stream_multithread(): child_thread.join(timeout=5) if child_thread.is_alive(): pytest.fail("Child thread failed to exit properly") + + +def test_load_config_file(tmp_path): + # Define the configuration data + config_data = { + "enable-logging": True, + "list-arg": ["item1", "item2"], + "port": 12323, + "tensor-parallel-size": 4, + } + + # Write the configuration data to a temporary YAML file + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as config_file: + yaml.dump(config_data, config_file) + + # Initialize the parser + parser = FlexibleArgumentParser() + + # Call the function with the temporary file path + processed_args = parser.load_config_file(str(config_file_path)) + + # Expected output + expected_args = [ + "--enable-logging", + "--list-arg", + "item1", + "item2", + "--port", + "12323", + "--tensor-parallel-size", + "4", + ] + + # Assert that the processed arguments match the expected output + assert processed_args == expected_args + os.remove(str(config_file_path)) + + +def test_unique_filepath(): + temp_dir = tempfile.mkdtemp() + path_fn = lambda i: Path(temp_dir) / f"file_{i}.txt" + paths = set() + for i in range(10): + path = unique_filepath(path_fn) + path.write_text("test") + paths.add(path) + assert len(paths) == 10 + assert len(list(Path(temp_dir).glob("*.txt"))) == 10 diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 60e04ad9069e7..7fee73da15a2a 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -2,29 +2,44 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 attention backends without GPUModelRunner dependency.""" +from functools import partial +from typing import Optional, Union + import pytest import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - set_kv_cache_layout) +from tests.v1.attention.utils import ( + 
BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + try_get_attention_backend, +) +from vllm.attention.backends.registry import _Backend +from vllm.config import ModelConfig +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + set_kv_cache_layout, +) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN + _Backend.FLASH_ATTN, + _Backend.FLASHINFER, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW", ] # Remove flashinfer from the list if it's not available try: import flashinfer # noqa: F401 except ImportError: - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1) + BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) def _convert_dtype_to_torch(dtype): @@ -44,60 +59,40 @@ def _convert_dtype_to_torch(dtype): # Define common batch configurations BATCH_SPECS = { - "small_decode": - BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), - "small_prefill": - BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), - "mixed_small": - BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), - "medium_decode": - BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], - query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), - "medium_prefill": - BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), - "mixed_medium": - BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], - query_lens=[1, 1, 1, 7, 7, 7]), - "large_decode": - BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), - "large_prefill": - BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), - "single_decode": - BatchSpec(seq_lens=[1024], query_lens=[1]), - "single_prefill": - BatchSpec(seq_lens=[1024], query_lens=[64]), + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec( + seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16] + ), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - 2, # K and V - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( - k_contexts: list[torch.Tensor], - v_contexts: list[torch.Tensor], - block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int, - common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> 
torch.Tensor: + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, +) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - + Args: k_contexts: List of key context tensors for each sequence v_contexts: List of value context tensors for each sequence @@ -109,28 +104,26 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + Returns: Tuple of (kv_cache, updated_block_table) """ batch_size = len(k_contexts) seq_lens = common_attn_metadata.seq_lens_cpu - query_lens = common_attn_metadata.query_start_loc_cpu[ - 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) context_lens = common_attn_metadata.num_computed_tokens_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping # Create KV cache - kv_cache = torch.empty(2, - num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + kv_cache = torch.empty( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device + ) kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) # Populate the cache with the context tokens @@ -179,8 +172,8 @@ def create_and_prepopulate_kv_cache( start = common_attn_metadata.query_start_loc_cpu[i] end = common_attn_metadata.query_start_loc_cpu[i + 1] slot_mapping[start:end] = block_table[ - i, - block_indices] * block_size + token_inter_block_offsets.to(device) + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) return kv_cache @@ -193,23 +186,38 @@ class MockAttentionLayer: self._k_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(1.0, device=device) # Add float versions for flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 -def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - layer_names: list[str], vllm_config, - device: torch.device, - common_attn_metadata: CommonAttentionMetadata, - query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor) -> torch.Tensor: +def run_attention_backend( + backend: _Backend, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + sliding_window: Optional[int] = None, +) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = _Backend.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if backend == _Backend.FLASHINFER_VLLM_V1: + if actual_backend == 
_Backend.FLASHINFER: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -218,20 +226,19 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { - layer_name: - PerLayerParameters( + layer_name: PerLayerParameters( window_left=-1, # No sliding window logits_soft_cap=0.0, # No soft cap - sm_scale=1.0 / (head_size**0.5) # Standard scale + sm_scale=1.0 / (head_size**0.5), # Standard scale ) for layer_name in layer_names } with unittest.mock.patch( - 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', - mock_get_per_layer_parameters): - builder = builder_cls(kv_cache_spec, layer_names, vllm_config, - device) + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -239,6 +246,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == _Backend.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -246,9 +255,11 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Instantiate implementation num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() scale = 1.0 / (head_size**0.5) impl = impl_cls( @@ -257,7 +268,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, scale=scale, num_kv_heads=num_kv_heads, alibi_slopes=None, - sliding_window=None, + sliding_window=sliding_window, kv_cache_dtype="auto", ) @@ -268,24 +279,23 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. - output = impl.forward(mock_layer, - query, - key, - value, - kv_cache, - attn_metadata, - output=output) + output = impl.forward( + mock_layer, query, key, value, kv_cache, attn_metadata, output=output + ) return output -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_backend_correctness(batch_spec_name: str, model: str): +def _test_backend_correctness( + batch_spec: BatchSpec, + model: str, + backend_to_test: list[Union[_Backend, str]], + mask_mod, + *, + block_size: int = 16, + atol: float = 1e-2, + rtol: float = 1e-2, +): """ Test that all backends produce similar outputs to a reference implementation using torch.nn.functional.scaled_dot_product_attention. @@ -301,10 +311,13 @@ def test_backend_correctness(batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" - batch_spec = BATCH_SPECS[batch_spec_name] - vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens), - num_gpu_blocks=8192) + current_platform.seed_everything(42) + vllm_config = create_vllm_config( + model_name=model, + max_model_len=max(batch_spec.seq_lens), + block_size=block_size, + num_gpu_blocks=8192, + ) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -314,10 +327,13 @@ def test_backend_correctness(batch_spec_name: str, model: str): seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() + sliding_window = vllm_config.model_config.get_sliding_window() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size scale = 1.0 / (head_size**0.5) @@ -333,21 +349,9 @@ def test_backend_correctness(batch_spec_name: str, model: str): context_len = s_len - q_len # Generate Q, K, V for the whole sequence to be used in SDPA - q = torch.randn(q_len, - num_q_heads, - head_size, - dtype=dtype, - device=device) - k_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) - v_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) # SDPA expects (N, H, L, D), so unsqueeze batch and permute q_sdpa_in = q.unsqueeze(0).transpose(1, 2) @@ -357,7 +361,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): if num_q_heads != num_kv_heads: assert num_q_heads % num_kv_heads == 0, ( f"num_q_heads ({num_q_heads}) must be divisible by " - f"num_kv_heads ({num_kv_heads})") + f"num_kv_heads ({num_kv_heads})" + ) repeats = num_q_heads // num_kv_heads k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) @@ -365,22 +370,20 @@ def test_backend_correctness(batch_spec_name: str, model: str): # Create causal mask: query token i attends to positions 0 to # (context_len + i) kv_len = s_len - offset = context_len - attn_mask = torch.full((q_len, kv_len), - float('-inf'), - device=device, - dtype=dtype) - for i in range(q_len): - attn_mask[i, :offset + i + 1] = 0.0 - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + final_mask_mod = partial(mask_mod, context_len=context_len) + block_mask = create_block_mask( + final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device + ) + sdpa_out_i = flex_attention( q_sdpa_in, k_sdpa_in, v_sdpa_in, - attn_mask=attn_mask, + block_mask=block_mask, scale=scale, - enable_gqa=True) - # Convert back to (L, H, D) + enable_gqa=True, + ) + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) # Inputs for vLLM backends are just the new tokens @@ -398,7 +401,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): sdpa_output = torch.cat(all_sdpa_outputs, dim=0) common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) 
# 3. Simulate Paged KV Cache and a realistic slot_mapping kv_cache = create_and_prepopulate_kv_cache( @@ -411,62 +415,166 @@ def test_backend_correctness(batch_spec_name: str, model: str): device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, common_attn_metadata=common_attn_metadata, - randomize_blocks=True) + randomize_blocks=True, + ) # 4. Run vLLM backends and compare # Note: flex_attention has known Triton kernel compatibility issues # with test infrastructures - for backend_name in BACKENDS_TO_TEST: + for backend_name in backend_to_test: # FlashAttentionm + FlexAttention: # [2, num_blocks, block_size, num_kv_heads, head_size] # FlashInfer: # [num_blocks, 2, block_size, num_kv_heads, head_size] # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache - if backend_name == _Backend.FLASHINFER_VLLM_V1: + if backend_name == _Backend.FLASHINFER: kv_cache_for_backend = kv_cache.transpose(0, 1) # For FlashInfer default to HND layout and - kv_cache_for_backend = kv_cache_for_backend.transpose( - 2, 3).contiguous().transpose(2, 3) + kv_cache_for_backend = ( + kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) + ) set_kv_cache_layout("HND") - backend_output = run_attention_backend(backend_name, kv_cache_spec, - ["placeholder"], vllm_config, - device, common_attn_metadata, - query_vllm, key_vllm, - value_vllm, - kv_cache_for_backend) + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + ) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") + f"SDPA shape {sdpa_output.shape}" + ) assert backend_output.dtype == sdpa_output.dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_output.dtype}" + ) assert torch.isfinite(backend_output).all(), ( - f"[{backend_name}] produced non-finite values") + f"[{backend_name}] produced non-finite values" + ) # Check numerical similarity - rtol = 1e-2 - atol = 5e-3 + def error_msg(msg: str, backend_name: str): + return f"[{backend_name}] output differs from SDPA baseline. {msg}" - if backend_name == _Backend.FLEX_ATTENTION: - atol = 5e-1 # TODO: figure out why flex_attention has such large - # numerical differences for medium_decode, medium_prefill, - # mixed_medium + torch.testing.assert_close( + backend_output, + sdpa_output, + rtol=rtol, + atol=atol, + msg=partial(error_msg, backend_name=backend_name), + ) - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() - max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() - all_close = torch.allclose(backend_output, - sdpa_output, - rtol=rtol, - atol=atol) - assert all_close, ( - f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_causal_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with causal attention.""" + + def causal_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + ): + return (q_idx + context_len) >= kv_idx + + batch_spec = BATCH_SPECS[batch_spec_name] + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) + SMALL_BLOCK_BACKENDS = [ + x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness( + batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128 + ) + + +SLIDING_WINDOW_BACKENDS_TO_TEST = [ + _Backend.FLASH_ATTN, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + "FLEX_ATTENTION_SLOW", +] + + +@pytest.mark.parametrize( + "batch_spec_name", + ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], +) +@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) +def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with sliding window attention.""" + + def sliding_window_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + sliding_window: int, + ): + causal_mask = q_idx + context_len >= kv_idx + window_mask = q_idx + context_len - kv_idx < sliding_window + return causal_mask & window_mask + + batch_spec = BATCH_SPECS[batch_spec_name] + model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens)) + sliding_window = model_config.get_sliding_window() + sliding_window_mask_mod_fn = partial( + sliding_window_mask_mod, sliding_window=sliding_window + ) + + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) + SMALL_BLOCK_BACKENDS = [ + x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness( + batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn + ) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness( + batch_spec, + model, + LARGE_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + block_size=128, + ) diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py new file mode 100644 index 0000000000000..6464bb52a4eaa --- /dev/null +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for mamba attention backend selectors.""" + +from types import SimpleNamespace + +import pytest + +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from 
vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend + + +@pytest.mark.parametrize( + "layer_class, init_kwargs, expected_backend, expected_mamba_type", + [ + ( + MambaMixer, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + time_step_rank=8, + use_conv_bias=True, + use_bias=False, + use_rms_norm=True, + ), + Mamba1AttentionBackend, + "mamba1", + ), + ( + MambaMixer2, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + use_conv_bias=True, + use_bias=False, + n_groups=1, + num_heads=8, + head_dim=32, + ), + Mamba2AttentionBackend, + "mamba2", + ), + ( + MiniMaxText01LinearAttention, + dict( + hidden_size=128, + hidden_inner_size=256, + num_heads=8, + head_dim=32, + max_position=2048, + block_size=64, + num_hidden_layer=12, + layer_idx=0, + linear_layer_idx=0, + ), + LinearAttentionBackend, + "linear_attention", + ), + ( + ShortConv, + dict( + config=SimpleNamespace(conv_L_cache=32, conv_bias=True), + dim=128, + layer_idx=0, + ), + ShortConvAttentionBackend, + "short_conv", + ), + ], +) +def test_mamba_layers_get_attn_backend( + dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type +): + """Test that Mamba-like layers return the correct attention backend.""" + layer = layer_class(**init_kwargs) + + backend_class = layer.get_attn_backend() + assert backend_class is expected_backend + assert layer.mamba_type == expected_mamba_type + + +@pytest.mark.parametrize( + "layer_class,expected_backend,expected_mamba_type", + [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), + ], +) +def test_mamba_layers_have_unified_interface( + layer_class, expected_backend, expected_mamba_type +): + """Test that all Mamba layers have the unified get_attn_backend + interface.""" + assert hasattr(layer_class, "get_attn_backend"), ( + f"{layer_class.__name__} should have get_attn_backend method" + ) + assert hasattr(layer_class, "mamba_type"), ( + f"{layer_class.__name__} should have mamba_type property" + ) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 3fc1011d5042e..1cbd0fe56be6d 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -5,11 +5,15 @@ import pytest import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import create_common_attn_metadata -from vllm.v1.attention.backends.utils import (UbatchSlice, - _make_metadata_with_slice, - slice_query_start_locs, - split_attn_metadata) +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from vllm.v1.attention.backends.utils import ( + UBatchSlice, + _make_metadata_with_slice, + slice_query_start_locs, + split_attn_metadata, + split_decodes_and_prefills, +) +from vllm.v1.worker.ubatch_utils import create_ubatch_slices @pytest.fixture @@ -77,9 +81,7 @@ def 
small_decode_metadata(): """Create metadata for small decode batch""" batch_spec = BATCH_SPECS["small_decode"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) @pytest.fixture @@ -87,9 +89,7 @@ def large_decode_metadata(): """Create metadata for small decode batch""" batch_spec = BATCH_SPECS["large_decode"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) @pytest.fixture @@ -97,16 +97,14 @@ def mixed_small_metadata(): """Create metadata for mixed small batch""" batch_spec = BATCH_SPECS["mixed_small"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) # Tests for _make_metadata_with_slice def test_make_metadata_with_slice_decode_batch(small_decode_metadata): """Test slicing decode batch metadata""" # Split first request only - ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) + ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1)) result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) @@ -120,8 +118,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" - ubatch_slice = UbatchSlice(slice(1, 3), - slice(1, 7)) # Requests 1-3, tokens 1-7 + ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) @@ -137,9 +134,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): num_tokens = large_decode_metadata.num_reqs mid_point = num_tokens // 2 ubatch_slices = [ - UbatchSlice(slice(0, mid_point), slice(0, mid_point)), - UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, - num_tokens)), + UBatchSlice(slice(0, mid_point), slice(0, mid_point)), + UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)), ] results = split_attn_metadata(ubatch_slices, large_decode_metadata) @@ -155,3 +151,199 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert results[1].num_reqs == mid_point assert results[1].num_actual_tokens == mid_point assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) + + +def apply_split_decodes_and_prefills( + query_lens: list[int], decode_threshold: int, require_uniform: bool +): + """Helper function to apply split_decodes_and_prefills and return + the results.""" + device = torch.device("cpu") + seq_lens = [10 * (i + 1) for i in range(len(query_lens))] + common_metadata = create_common_attn_metadata( + BatchSpec(seq_lens=seq_lens, query_lens=query_lens), + block_size=16, + device=device, + ) + return split_decodes_and_prefills( + common_metadata, + decode_threshold=decode_threshold, + require_uniform=require_uniform, + ) + + +def test_split_decodes_and_prefills_nonuniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, False) + ) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_short_decodes(): + query_lens = [1, 2, 1, 3, 2, 1, 
2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False) + ) + assert num_decodes == 7 + assert num_prefills == 0 + assert num_decode_tokens == sum(query_lens) + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_nonuniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, False) + ) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_nonuniform_mixed_batch(): + query_lens = [2, 1, 3, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, False) + ) + assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4 + assert num_prefills == 4 # 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 10 # 2 + 1 + 3 + 4 + assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_all_ones(): + query_lens = [1, 1, 1] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 1, True) + ) + assert num_decodes == 3 + assert num_prefills == 0 + assert num_decode_tokens == 3 + assert num_prefill_tokens == 0 + + +def test_split_decodes_and_prefills_uniform_all_short_decodes(): + query_lens = [2, 2, 1, 3, 2, 1, 2] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True) + ) + assert num_decodes == 2 + assert num_prefills == 5 + assert num_decode_tokens == 4 + assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2) + + +def test_split_decodes_and_prefills_uniform_all_prefills(): + query_lens = [4, 5, 6, 7] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True) + ) + assert num_decodes == 0 + assert num_prefills == 4 + assert num_decode_tokens == 0 + assert num_prefill_tokens == sum(query_lens) + + +def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes(): + query_lens = [2, 2, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, True) + ) + assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform + assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4 + assert num_decode_tokens == 6 # 2 + 2 + 2 + assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8 + + +def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): + query_lens = [2, 1, 2, 4, 5, 6, 7, 8] + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 4, True) + ) + assert num_decodes == 1 # only the first 2 is taken as decode + assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform + assert num_decode_tokens == 2 # only the first 2 + assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens + + +@pytest.mark.parametrize( + "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", + [ + # Split in the middle of request 1 + ([32, 40], [8, 8], 12, 2, 1), + # Split inside the first request + ([32, 40], [8, 8], 4, 1, 2), + ], +) +def test_prefill_split_across_ubatches( + seq_lens, query_lens, split_point, expected_first_reqs, expected_second_reqs +): + """Test splitting a prefill across ubatches""" + 
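The assertions in the split tests above all reduce to a single prefix rule over the query lengths. A minimal standalone sketch of that rule, for orientation only (the function name here is hypothetical; the real split_decodes_and_prefills operates on CommonAttentionMetadata tensors rather than Python lists):

def classify_decodes_and_prefills(query_lens, decode_threshold, require_uniform):
    # Decodes are the leading requests whose query length is at most the
    # threshold; with require_uniform=True they must also all equal the first
    # request's query length. Everything from the first break onward is prefill.
    num_decodes = 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        if require_uniform and num_decodes > 0 and q_len != query_lens[0]:
            break
        num_decodes += 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    return (
        num_decodes,
        len(query_lens) - num_decodes,
        num_decode_tokens,
        sum(query_lens) - num_decode_tokens,
    )

# The expectations asserted above fall out directly:
assert classify_decodes_and_prefills([2, 1, 3, 4, 5, 6, 7, 8], 4, False) == (4, 4, 10, 26)
assert classify_decodes_and_prefills([2, 2, 1, 3, 2, 1, 2], 3, True) == (2, 5, 4, 9)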
import numpy as np + + device = torch.device("cpu") + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) + common = create_common_attn_metadata(batch_spec, block_size=16, device=device) + + num_scheduled_tokens = np.array(query_lens, dtype=np.int32) + qsl_np = common.query_start_loc_cpu.numpy() + num_tokens = common.num_actual_tokens + + ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point) + assert len(ubatch_slices) == 2 + + first_meta = _make_metadata_with_slice(ubatch_slices[0], common) + second_meta = _make_metadata_with_slice(ubatch_slices[1], common) + + # Token counts match the split + assert first_meta.num_actual_tokens == split_point + assert second_meta.num_actual_tokens == num_tokens - split_point + + # Number of requests per ubatch + assert first_meta.num_reqs == expected_first_reqs + assert second_meta.num_reqs == expected_second_reqs + + # Identify which request is split and how many tokens are in the first chunk + split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1) + tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx]) + orig_q_lens = common.query_start_loc_cpu[1:] - common.query_start_loc_cpu[:-1] + + # Check query length continuity: first-chunk + second-chunk == original qlen + # First ubatch last request query length + qlen_first_last = int( + first_meta.query_start_loc_cpu[-1] - first_meta.query_start_loc_cpu[-2] + ) + # Second ubatch first request query length + qlen_second_first = int( + second_meta.query_start_loc_cpu[1] - second_meta.query_start_loc_cpu[0] + ) + assert qlen_first_last == tokens_in_first_chunk + assert qlen_first_last + qlen_second_first == int(orig_q_lens[split_req_idx]) + + # Check seq_lens adjustments + # Context lengths per original request + context_lens = [s - q for s, q in zip(seq_lens, query_lens)] + + # First ubatch: last request's seq_len should be + # context + tokens_in_first_chunk + expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk + assert int(first_meta.seq_lens[-1]) == expected_seqlen + + # For full preceding requests in first ubatch, seq_lens should match + # originals + for i in range(first_meta.num_reqs - 1): + assert int(first_meta.seq_lens[i]) == seq_lens[i] + + # Second ubatch: first request (continuation) seq_len should be full + # original + assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx] + # Any following full requests in second ubatch should match originals + for j in range(1, second_meta.num_reqs): + # Map to original request index + orig_idx = split_req_idx + j + assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx] diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index 8c5a63653db9f..faace3473a281 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -7,8 +7,7 @@ import pytest import torch from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata -from vllm.v1.attention.backends.utils import ( - make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches @dataclass @@ -46,21 +45,24 @@ test_data_list = [ [17, 17], # local-batch 5, (batch 1, starting from k[16]) [20, 21], # local-batch 6, (batch 2, starting from k[4]) [22, 23], # local-batch 7, (batch 2, starting from k[8]) - ]), + ], + ), # Case where block indices are not clipped to block table ncols-1 # because tokens_in_last_block == attn_chunk_size - 
LocalAttentionTestData(batch_spec=BatchSpec( - query_lens=[8], - seq_lens=[12], + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[8], + seq_lens=[12], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[4, 4], + expected_k_seqlens=[4, 4], + expected_local_block_table=[ + [2, 3], + [4, 5], + ], ), - attn_chunk_size=4, - block_size=2, - expected_q_seqlens=[4, 4], - expected_k_seqlens=[4, 4], - expected_local_block_table=[ - [2, 3], - [4, 5], - ]), # Case where all kv_seq positions are involved in attn LocalAttentionTestData( batch_spec=BatchSpec( @@ -76,7 +78,8 @@ test_data_list = [ [0, 1], [2, 3], [4, 4], - ]), + ], + ), # Case where attn_chunk_size > kv_seq_len # so no extra mini virtual batches are created LocalAttentionTestData( @@ -97,7 +100,8 @@ test_data_list = [ # is calculated as (attn_chunk_size // block_size) expected_local_block_table=[ [0, 1, 2, 2, 2], - ]), + ], + ), # Block size equal to chunk size # Expect single page per batch in local batch table LocalAttentionTestData( @@ -118,7 +122,8 @@ test_data_list = [ [1], # local-batch 1, (batch 0, starting from k[4]) [2], # local-batch 1, (batch 0, starting from k[0]) [3], # local-batch 1, (batch 0, starting from k[4]) - ]), + ], + ), # Case where query falls in the second attention chunk # k_toks > 0 1 2 3 4 # q_toks v _____________ @@ -128,17 +133,19 @@ test_data_list = [ # 3 | 1 1 1 1 # 4 | 1 # where tokens 0,1,2,3 have been pre-computed - LocalAttentionTestData(batch_spec=BatchSpec( - query_lens=[1], - seq_lens=[5], + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[1], + seq_lens=[5], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[1], + expected_k_seqlens=[1], + expected_local_block_table=[ + [2, 2], + ], ), - attn_chunk_size=4, - block_size=2, - expected_q_seqlens=[1], - expected_k_seqlens=[1], - expected_local_block_table=[ - [2, 2], - ]), ] @@ -160,14 +167,14 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): # Use torch.arange instead of torch.randint so we can assert on # block table tensor values. 
The block table will have shape # (num_batches, cdiv(max_seq_len, block_size)) and the values will be - # aranged from 0 to cdiv(max_seq_len, block_size)-1 + # arranged from 0 to cdiv(max_seq_len, block_size)-1 arange_block_indices=True, ) # Call the function - result = make_local_attention_virtual_batches(attn_chunk_size, - common_attn_metadata, - block_size) + result = make_local_attention_virtual_batches( + attn_chunk_size, common_attn_metadata, block_size + ) # Convert to numpy for easier comparison actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy()) @@ -184,13 +191,11 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens) np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens) - expected_block_table_tensor =\ - torch.tensor(expected_local_block_table, - dtype=torch.int32, - device=device) + expected_block_table_tensor = torch.tensor( + expected_local_block_table, dtype=torch.int32, device=device + ) print(f"Expected block table:\n{expected_block_table_tensor}") print(f"Actual block table:\n{result.block_table_tensor}") - torch.testing.assert_close(result.block_table_tensor, - expected_block_table_tensor) + torch.testing.assert_close(result.block_table_tensor, expected_block_table_tensor) diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py deleted file mode 100644 index 4245b50c71310..0000000000000 --- a/tests/v1/attention/test_mamba_selectors.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mamba attention backend selectors.""" - -import pytest - -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend - - -@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"], - argvalues=[("mamba2", Mamba2AttentionBackend)]) -def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend): - backend_class = get_mamba_attn_backend(mamba_type) - - assert backend_class is expected_backend - - -def test_get_mamba_attn_backend_unsupported(): - unsupported_types = ["mamba", ""] - - for mamba_type in unsupported_types: - err_message = f"Mamba Attention type {mamba_type} is not supported yet." 
- with pytest.raises(NotImplementedError, match=err_message): - get_mamba_attn_backend(mamba_type) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 24070358799ef..3b6a9115435c4 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -2,26 +2,33 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 MLA backends without GPUModelRunner dependency.""" +from typing import Optional, Union + import pytest import torch -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + try_get_attention_backend, +) +from vllm import _custom_ops as ops +from vllm.attention.backends.registry import _Backend from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, - _Backend.TRITON_MLA_VLLM_V1 + _Backend.CUTLASS_MLA, + _Backend.FLASHMLA, + _Backend.FLASH_ATTN_MLA, + _Backend.TRITON_MLA, ] # Remove CUTLASS_MLA from the list if not using sm100 -if not torch.cuda.is_available() or torch.cuda.get_device_properties( - 0).major < 10: +if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) torch.manual_seed(42) @@ -44,100 +51,124 @@ def _convert_dtype_to_torch(dtype): # Define common batch configurations BATCH_SPECS = { - "small_decode": - BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), - "small_prefill": - BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), - "mixed_small": - BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), - "medium_decode": - BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], - query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), - "medium_prefill": - BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), - "mixed_medium": - BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], - query_lens=[1, 1, 1, 7, 7, 7]), - "large_decode": - BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), - "large_prefill": - BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), - "single_decode": - BatchSpec(seq_lens=[1024], query_lens=[1]), - "single_prefill": - BatchSpec(seq_lens=[1024], query_lens=[64]), + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec( + seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16] + ), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - 
"""Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.head_size, # latent dimension - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( - kv_c_contexts: list[torch.Tensor], - k_pe_contexts: list[torch.Tensor], - block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int, - common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> torch.Tensor: + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, + kv_cache_dtype: Optional[str] = None, + scale: Union[float, torch.Tensor] = 1.0, +) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. - + Args: kv_c_contexts: List of latent KV context tensors for each sequence k_pe_contexts: List of key positional embedding context tensors for each sequence block_size: Size of each block - num_kv_heads: Number of KV heads (should be 1 for MLA) head_size: Size of each head (latent dimension) dtype: Data type for the cache device: Device to create the cache on num_blocks: Total number of blocks in the cache common_attn_metadata: Common attention metadata - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + kv_cache_dtype: Optional kv cache dtype string. When set to + "fp8_ds_mla" the cache is populated using the + fp8 DeepSeek MLA layout via concat_and_cache_mla. + scale: Scaling factor forwarded to concat_and_cache_mla when the + fp8 cache layout is requested. 
+ Returns: MLA KV cache tensor """ batch_size = len(kv_c_contexts) seq_lens = common_attn_metadata.seq_lens_cpu - query_lens = common_attn_metadata.query_start_loc_cpu[ - 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) context_lens = common_attn_metadata.num_computed_tokens_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - # Create MLA KV cache: (num_blocks, block_size, head_size) - kv_cache = torch.empty(num_blocks, - block_size, - head_size, - dtype=dtype, - device=device) - kv_cache_flat = kv_cache.view(-1, head_size) + use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla" + + if use_fp8_ds_mla: + if not kv_c_contexts: + raise ValueError( + "kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype" + ) + kv_lora_rank = kv_c_contexts[0].shape[-1] + rope_dim = k_pe_contexts[0].shape[-1] + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + kv_cache = torch.zeros( + num_blocks, block_size, entry_size, dtype=torch.uint8, device=device + ) + scale_tensor = ( + scale + if isinstance(scale, torch.Tensor) + else torch.tensor(scale, dtype=torch.float32, device=device) + ) + scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) + else: + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty( + num_blocks, block_size, head_size, dtype=dtype, device=device + ) + kv_cache_flat = kv_cache.view(-1, head_size) # Populate the cache with the context tokens # Start from block_id=1 since block_id=0 is considered the null block start_block_idx = 1 for i in range(batch_size): kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] - kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + context_len = kv_c_context.shape[0] + if context_len == 0: + start_block_idx += cdiv(int(seq_lens[i]), block_size) + continue + start = start_block_idx * block_size - end = start + kv_context.shape[0] - kv_cache_flat[start:end, ...] = kv_context + + if use_fp8_ds_mla: + slots = torch.arange(context_len, device=device, dtype=torch.long) + start + ops.concat_and_cache_mla( + kv_c_context, + k_pe_context.squeeze(1), + kv_cache, + slots, + kv_cache_dtype="fp8_ds_mla", + scale=scale_tensor, + ) + else: + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] = kv_context # Stay block aligned and allocate enough blocks for the new tokens start_block_idx += cdiv(int(seq_lens[i]), block_size) @@ -146,15 +177,14 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + perm = ( + torch.randperm(blocks_end - 1) + 1 + ) # Random permutation starting from block 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) # Sequential order starting from block 1 inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 # Add 1 to account for starting from block 1 kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] 
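As a side note on the shuffle just above: a self-contained check, illustrative only and not part of the diff, of why inv_perm built from torch.argsort(perm) maps each original block id to the slot its data was moved to:

import torch

torch.manual_seed(0)
blocks_end = 8
perm = torch.randperm(blocks_end - 1) + 1   # random source block for each target slot
inv_perm = torch.zeros(blocks_end, dtype=torch.long)
inv_perm[1:] = torch.argsort(perm) + 1      # original block id -> new slot

old = torch.arange(blocks_end)              # pretend block i holds the value i
new = old.clone()
new[1:blocks_end] = old[perm]               # same shuffle as the helper above

# Looking up every original block through inv_perm recovers its contents.
assert torch.equal(new[inv_perm[1:]], old[1:])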
# Construct the right block table @@ -175,8 +205,8 @@ def create_and_prepopulate_kv_cache( start = common_attn_metadata.query_start_loc_cpu[i] end = common_attn_metadata.query_start_loc_cpu[i + 1] slot_mapping[start:end] = block_table[ - i, - block_indices] * block_size + token_inter_block_offsets.to(device) + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) return kv_cache @@ -190,18 +220,26 @@ class MockAttentionLayer: self._v_scale = torch.tensor(1.0, device=device) -def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - layer_names: list[str], vllm_config, - device: torch.device, - common_attn_metadata: CommonAttentionMetadata, - query: torch.Tensor, kv_c: torch.Tensor, - k_pe: torch.Tensor, kv_cache: torch.Tensor, - kv_lora_rank: int, qk_nope_head_dim: int, - qk_rope_head_dim: int, v_head_dim: int, - mock_kv_b_proj) -> torch.Tensor: +def run_attention_backend( + backend: _Backend, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + mock_kv_b_proj, +) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + builder_cls, impl_cls = try_get_attention_backend(backend) # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) @@ -212,9 +250,11 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Instantiate MLA implementation num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() scale = 1.0 / (head_size**0.5) impl = impl_cls( @@ -244,30 +284,35 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Create mock layer and output buffer mock_layer = MockAttentionLayer(device) num_tokens = query.shape[0] - output = torch.empty(num_tokens, - num_heads * v_head_dim, - dtype=query.dtype, - device=query.device) + output = torch.empty( + num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device + ) # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. 
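For orientation, the width of the output buffer allocated above comes from MLA's asymmetric head dimensions. A quick arithmetic sketch, assuming the DeepSeek-V2-Lite-style dims used elsewhere in these tests (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, v_head_dim=128):

kv_lora_rank, qk_rope_head_dim = 512, 64
qk_nope_head_dim, v_head_dim = 128, 128

q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # width each query head consumes
cached_dim = kv_lora_rank + qk_rope_head_dim      # width cached per token (latent + rope)
assert (q_head_dim, cached_dim) == (192, 576)

# The per-head output is v_head_dim wide, hence the buffer shape
# (num_tokens, num_heads * v_head_dim) rather than num_heads * q_head_dim.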
- output = impl.forward(mock_layer, - query, - kv_c, - k_pe, - kv_cache, - attn_metadata, - output=output) + output = impl.forward( + mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output + ) return output -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) def test_backend_correctness(dist_init, batch_spec_name: str, model: str): """ @@ -286,9 +331,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): 5. Comparing the vLLM backend's output to the ground-truth SDPA output. """ batch_spec = BATCH_SPECS[batch_spec_name] - vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens), - num_gpu_blocks=2048) + vllm_config = create_vllm_config( + model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 + ) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -298,9 +343,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size @@ -309,28 +353,28 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): qk_nope_head_dim = 128 v_head_dim = 128 total_head_size = kv_lora_rank + qk_rope_head_dim - assert kv_lora_rank + qk_rope_head_dim == head_size, \ + assert kv_lora_rank + qk_rope_head_dim == head_size, ( f"MLA dimensions don't match: {total_head_size} != {head_size}" + ) scale = 1.0 / (total_head_size**0.5) # 2. 
Generate data and compute SDPA reference output for MLA all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] - all_sdpa_outputs = [] + all_sdpa_outputs: list[list[torch.Tensor]] = [] kv_c_contexts, k_pe_contexts = [], [] # Create shared MLA weight matrices for consistency across all sequences - W_UK = torch.randn(kv_lora_rank, - num_q_heads, - qk_nope_head_dim, - dtype=dtype, - device=device) - W_UV = torch.randn(kv_lora_rank, - num_q_heads, - v_head_dim, - dtype=dtype, - device=device) + W_UK = torch.randn( + kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + W_UV = torch.randn( + kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device + ) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + for i, backend in enumerate(BACKENDS_TO_TEST): + all_sdpa_outputs.append([]) + for i in range(batch_size): s_len = seq_lens[i] q_len = query_lens[i] @@ -339,104 +383,108 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Generate MLA tensors # Q has both nope and rope components: # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] - q_c = torch.randn(q_len, - num_q_heads, - qk_nope_head_dim + qk_rope_head_dim, - dtype=dtype, - device=device) + q_c = torch.randn( + q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device, + ) # KV_C (latent K/V): [s_len, kv_lora_rank] - kv_c_full = torch.randn(s_len, - kv_lora_rank, - dtype=dtype, - device=device) + kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device) # K_PE (rope component): [s_len, 1, qk_rope_head_dim] - k_pe_full = torch.randn(s_len, - 1, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) - # Determine if this is decode (single token) - # or prefill (multiple tokens) - is_decode = q_len == 1 + # Determine if this is decode or prefill + is_decode = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = try_get_attention_backend(backend) + is_decode.append(q_len <= builder_cls.reorder_batch_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - if is_decode: - # Decode path: MQA-style attention in latent space - # Transform q_nope to latent space: q_nope @ W_UK - # q_nope: [1, num_heads, qk_nope_head_dim] - # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] - ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, - W_UK) # [1, num_heads, kv_lora_rank] + ####################################################### + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum( + "qnh,lnh->qnl", q_nope, W_UK + ) # [1, num_heads, kv_lora_rank] - # Build MQA attention inputs - # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] - q_mqa = torch.cat([ql_nope, q_pe], dim=-1) - # K: [s_len, kv_lora_rank + qk_rope_head_dim] - # (broadcasted to all heads) - k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) - k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) - # V: [s_len, kv_lora_rank] (broadcasted to all heads) - v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = 
torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + # Create custom attention mask for decode path: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their position + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( - 0) # [1, num_heads, kv_lora_rank] + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) - # Project back to output space: sdpa_out @ W_UV - sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - else: - # Prefill path: MHA-style attention with full sequence - # Apply kv_b_proj to the full kv_c tensor - kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, - kv_b_proj_weight) - k_nope_full, v_full = kv_nope_full.split( - [qk_nope_head_dim, v_head_dim], dim=-1) + sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) + sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( + 0 + ) # [1, num_heads, kv_lora_rank] - # Build attention inputs for full sequence - q_mha = torch.cat([q_nope, q_pe], - dim=-1) # [q_len, num_heads, total_dim] - k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) - k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV) + sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2) - # Create custom attention mask: - # - Query tokens can attend to all context tokens - # - Query tokens can only attend to query tokens up to their pos - attn_mask = torch.ones(q_len, - s_len, - dtype=torch.bool, - device=device) - # Apply causal mask only to the query portion (context_len onwards) - causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) - attn_mask[:, context_len:] = causal_mask + ####################################################### + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1) - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) - # Single attention call with custom mask - sdpa_out_i = 
torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask - all_sdpa_outputs.append(sdpa_out_i) + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) + sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) + sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) + + for i, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[i]: + all_sdpa_outputs[i].append(sdpa_out_i_decode) + else: + all_sdpa_outputs[i].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -451,72 +499,92 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear - mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, - output_size=num_q_heads * - (qk_nope_head_dim + v_head_dim), - bias=False).to(device=device, - dtype=dtype) + + mock_kv_b_proj = ColumnParallelLinear( + input_size=kv_lora_rank, + output_size=num_q_heads * (qk_nope_head_dim + v_head_dim), + bias=False, + ).to(device=device, dtype=dtype) # Set the mock weights to match our reference implementation # Reshape W_UK and W_UV to match the expected kv_b_proj format # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] kv_b_proj_weight = kv_b_proj_weight.view( - kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim) + ) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) # Create metadata using original batch spec common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) # 3. Simulate Paged KV Cache and a realistic slot_mapping kv_cache = create_and_prepopulate_kv_cache( kv_c_contexts=kv_c_contexts, k_pe_contexts=k_pe_contexts, block_size=block_size, - num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks, common_attn_metadata=common_attn_metadata, - randomize_blocks=True) + randomize_blocks=True, + ) # 4. 
Run vLLM backends and compare - for backend_name in BACKENDS_TO_TEST: + for i, backend_name in enumerate(BACKENDS_TO_TEST): backend_output = run_attention_backend( - backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, - common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, - kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, - mock_kv_b_proj) + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + mock_kv_b_proj, + ) # Check shape and dtype consistency - assert backend_output.shape == sdpa_output.shape, ( + assert backend_output.shape == sdpa_outputs[i].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") - assert backend_output.dtype == sdpa_output.dtype, ( + f"SDPA shape {sdpa_outputs[i].shape}" + ) + assert backend_output.dtype == sdpa_outputs[i].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_outputs[i].dtype}" + ) assert torch.isfinite(backend_output).all(), ( - f"[{backend_name}] produced non-finite values") + f"[{backend_name}] produced non-finite values" + ) # Check numerical similarity rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() - all_close = torch.allclose(backend_output, - sdpa_output, - rtol=rtol, - atol=atol) + torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i]) + ).item() + all_close = torch.allclose( + backend_output, sdpa_outputs[i], rtol=rtol, atol=atol + ) assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})" + ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py new file mode 100644 index 0000000000000..25de65a56b379 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -0,0 +1,380 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for the FlashMLA sparse backend utilities.""" + +import math +from types import MethodType, SimpleNamespace + +import numpy as np +import pytest +import torch + +from tests.v1.attention.test_mla_backends import ( + BATCH_SPECS, + BatchSpec, + MockAttentionLayer, + create_and_prepopulate_kv_cache, +) +from tests.v1.attention.utils import ( + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, +) +from vllm import _custom_ops as ops +from vllm.attention.ops import flashmla +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend +from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks + +SPARSE_BACKEND_BATCH_SPECS = { + name: BATCH_SPECS[name] + for name in [ + "mixed_small", + "mixed_medium", + "small_prefill", + "medium_prefill", + "single_prefill", + ] +} + +SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec( + seq_lens=[1024] * 2, query_lens=[256] * 2 +) +SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec( + seq_lens=[256] * 2, query_lens=[256] * 2 +) + + +def _dequantize_fp8_ds_mla_entry( + cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Dequantize a single fp8_ds_mla cache entry back to latent + rope.""" + + # The first kv_lora_rank bytes store FP8 latent values with one scale per + # 128 element tile written as float32 right after the latent payload. 
+ scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4] + latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device) + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = tile_start + 128 + ops.convert_fp8( + latent[tile_start:tile_end], + cache_slice[tile_start:tile_end], + float(scales[tile_idx].item()), + kv_dtype="fp8", + ) + latent = latent.to(dtype) + + rope_offset = kv_lora_rank // 2 + 8 + rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim] + return latent, rope_vals.clone() + + +def _quantize_dequantize_fp8_ds_mla( + kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.""" + + if kv_c.numel() == 0: + return kv_c.clone(), k_pe.clone() + + kv_lora_rank = kv_c.shape[-1] + rope_dim = k_pe.shape[-1] + num_tokens = kv_c.shape[0] + num_blocks = max(1, math.ceil(num_tokens / block_size)) + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + + tmp_cache = torch.zeros( + num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device + ) + slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device) + + ops.concat_and_cache_mla( + kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale + ) + + dequant_kv_c = torch.empty_like(kv_c) + dequant_k_pe = torch.empty_like(k_pe) + + for token_idx in range(num_tokens): + slot = slot_mapping[token_idx].item() + block_idx = slot // block_size + block_offset = slot % block_size + cache_slice = tmp_cache[block_idx, block_offset] + latent, rope_vals = _dequantize_fp8_ds_mla_entry( + cache_slice, kv_lora_rank, rope_dim, kv_c.dtype + ) + dequant_kv_c[token_idx] = latent + dequant_k_pe[token_idx] = rope_vals + + return dequant_kv_c, dequant_k_pe + + +@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) +def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for sparse MLA decode test") + + device = torch.device("cuda") + dtype = torch.bfloat16 + + batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] + + # Model hyper-parameters (kept intentionally small for the unit test) + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + v_head_dim = 128 + head_size = kv_lora_rank + qk_rope_head_dim + topk_tokens = 2048 + + max_seqlen = max(batch_spec.seq_lens) + total_cache_tokens = sum(batch_spec.seq_lens) + block_size = 64 + + vllm_config = create_vllm_config( + model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", + max_model_len=max_seqlen, + num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1), + block_size=block_size, + hf_config_override={ + "index_topk": topk_tokens, + "attn_module_list_cfg": [{"topk_tokens": topk_tokens}], + }, + ) + model_config = vllm_config.model_config + model_config.hf_text_config = SimpleNamespace( + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + model_type="deepseek_v2", + ) + model_config.dtype = dtype + model_config.get_num_attention_heads = MethodType( + lambda self, parallel_config: num_heads, model_config + ) + model_config.get_num_kv_heads = MethodType( + lambda self, parallel_config: 1, model_config + ) + model_config.get_head_size = 
MethodType(lambda self: head_size, model_config) + model_config.get_sliding_window = MethodType(lambda self: None, model_config) + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + torch.manual_seed(0) + + scale = 1.0 / math.sqrt(head_size) + + # Shared MLA projection weights to keep reference and backend in sync + W_UK = torch.randn( + kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device) + + # Build synthetic decode-only workload + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + kv_c_contexts, k_pe_contexts = [], [] + reference_outputs = [] + + kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + for i in range(batch_spec.batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + ctx_len = s_len - q_len + + q_c = torch.rand( + q_len, + num_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device, + ) + kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device) + k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) + + kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla( + kv_c_full, + k_pe_full.squeeze(1), + block_size=vllm_config.cache_config.block_size, + scale=kv_cache_scale, + ) + + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK) + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + + k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1) + + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, ctx_len:] = causal_mask + + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) + sdpa_out = sdpa_out.transpose(1, 2).squeeze(0) + + sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) + reference_outputs.append(sdpa_out.flatten(start_dim=-2)) + + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[ctx_len:]) + all_k_pe_vllm.append(k_pe_full[ctx_len:]) + kv_c_contexts.append(kv_c_full[: ctx_len + 1]) + k_pe_contexts.append(k_pe_full[: ctx_len + 1]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_reference = torch.cat(reference_outputs, dim=0) + + vllm_config.cache_config.cache_dtype = kv_cache_dtype + vllm_config.model_config.hf_config.index_topk = topk_tokens + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + vllm_config.cache_config.block_size, + device, + arange_block_indices=True, + ) + + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=vllm_config.cache_config.block_size, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=False, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + scale=kv_cache_scale, + ) + + builder_cls = FlashMLASparseBackend.get_builder_cls() + builder = 
builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device) + metadata = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( + starts[:-1], seg_lengths + ) + seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32) + prefix_lengths = seq_lengths - seg_lengths + positions += np.repeat(prefix_lengths, seg_lengths) + + pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32) + topk = metadata.topk_tokens + debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0) + token_positions = pos_gpu.unsqueeze(1) + causal_mask = debug_indices <= token_positions + debug_indices = torch.where( + causal_mask, debug_indices, torch.full_like(debug_indices, -1) + ) + + # FlashMLASparseImpl now reads top-k indices from the indexer-provided + # buffer, so emulate that contract with a simple namespace mock. + debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone() + mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) + + ok, reason = flashmla.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) + + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim) + ) + + mock_kv_b_proj = ColumnParallelLinear( + input_size=kv_lora_rank, + output_size=num_heads * (qk_nope_head_dim + v_head_dim), + bias=False, + ).to(device=device, dtype=dtype) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) + + impl_cls = FlashMLASparseBackend.get_impl_cls() + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer, + ) + + impl.process_weights_after_loading(dtype) + + layer = MockAttentionLayer(device) + out_buffer = torch.empty( + metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device + ) + + with torch.inference_mode(): + backend_output = impl.forward( + layer, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + metadata, + output=out_buffer, + ) + + assert backend_output.shape == sdpa_reference.shape + assert backend_output.dtype == sdpa_reference.dtype + assert torch.isfinite(backend_output).all() + + torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) + + +@pytest.mark.parametrize( + "seq_lens,max_buf,start,expected", + [ + # Basic split: totals per chunk ≤ max_buf + (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]), + # Non-zero start index + (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]), + # Exact fits should split between items when adding the next would + # overflow + (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]), + # All requests fit in a single chunk + (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]), + # Large buffer with non-zero start + (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]), + ], +) +def test_split_prefill_chunks(seq_lens, 
max_buf, start, expected): + out = split_prefill_chunks(seq_lens, max_buf, start) + assert out == expected diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 6a08cdc56f736..819cd81be358d 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -3,23 +3,36 @@ """Utility functions for attention-related v1 tests.""" from dataclasses import dataclass -from typing import Union +from typing import Optional, Union import pytest import torch -from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, - LoadConfig, ModelConfig, ModelDType, ParallelConfig, - SchedulerConfig, VllmConfig) -from vllm.platforms import _Backend, current_platform +from vllm.attention.backends.abstract import AttentionImpl +from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.config.model import ModelDType from vllm.utils import resolve_obj_by_qualname -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import FullAttentionSpec @dataclass class BatchSpec: """Specification for a batch configuration (workload shape only).""" + seq_lens: list[int] query_lens: list[int] @@ -37,26 +50,25 @@ class BatchSpec: def create_common_attn_metadata( - batch_spec: BatchSpec, - block_size: int, - device: torch.device, - max_block_idx: int = 1000, - arange_block_indices: bool = False) -> CommonAttentionMetadata: + batch_spec: BatchSpec, + block_size: int, + device: torch.device, + max_block_idx: int = 1000, + arange_block_indices: bool = False, +) -> CommonAttentionMetadata: """Create CommonAttentionMetadata from a BatchSpec and ModelParams.""" # Create query start locations - query_start_loc = torch.zeros(batch_spec.batch_size + 1, - dtype=torch.int32, - device=device) - query_start_loc[1:] = torch.tensor(batch_spec.query_lens, - dtype=torch.int32, - device=device).cumsum(0) + query_start_loc = torch.zeros( + batch_spec.batch_size + 1, dtype=torch.int32, device=device + ) + query_start_loc[1:] = torch.tensor( + batch_spec.query_lens, dtype=torch.int32, device=device + ).cumsum(0) query_start_loc_cpu = query_start_loc.cpu() num_tokens = batch_spec.compute_num_tokens() # Create sequence lengths - seq_lens = torch.tensor(batch_spec.seq_lens, - dtype=torch.int32, - device=device) + seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() max_seq_len = int(seq_lens_cpu.max()) @@ -71,24 +83,23 @@ def create_common_attn_metadata( max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size if arange_block_indices: num_blocks = batch_spec.batch_size * max_blocks - block_table_tensor = torch.arange(num_blocks, - dtype=torch.int32, - device=device).view( - batch_spec.batch_size, - max_blocks) - slot_mapping = torch.arange(num_tokens, - dtype=torch.int64, - device=device).view(num_tokens) + block_table_tensor = torch.arange( + num_blocks, dtype=torch.int32, device=device + ).view(batch_spec.batch_size, max_blocks) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view( + num_tokens + ) else: - block_table_tensor = torch.randint(0, - max_block_idx, - (batch_spec.batch_size, max_blocks), - dtype=torch.int32, - device=device) - slot_mapping = torch.randint(0, - 
max_block_idx, (num_tokens, ), - dtype=torch.int64, - device=device) + block_table_tensor = torch.randint( + 0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device, + ) + slot_mapping = torch.randint( + 0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device + ) # Calculate max query length max_query_len = max(batch_spec.query_lens) @@ -109,76 +120,45 @@ def create_common_attn_metadata( ) -def get_attention_backend(backend_name: _Backend): - """Set up attention backend classes for testing. - - Args: - backend_name: Name of the backend ("flash_attn", "flashinfer", etc.) - vllm_config: VllmConfig instance - - Returns: - Tuple of (backend_builder_class, backend_impl_class) - """ - backend_map = { - _Backend.FLASH_ATTN_VLLM_V1: - ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - if current_platform.is_cuda() else - "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" - ), - _Backend.FLASHINFER_VLLM_V1: - "vllm.v1.attention.backends.flashinfer.FlashInferBackend", - _Backend.FLEX_ATTENTION: - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", - _Backend.TRITON_ATTN_VLLM_V1: - "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", - _Backend.TREE_ATTN: - "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", - _Backend.XFORMERS_VLLM_V1: - "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", - _Backend.CUTLASS_MLA: - "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", - _Backend.FLASHMLA_VLLM_V1: - "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", - _Backend.TRITON_MLA_VLLM_V1: - "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", - } - - if backend_name not in backend_map: - raise ValueError(f"Unknown backend: {backend_name}") - - backend_class_name = backend_map[backend_name] - +def try_get_attention_backend( + backend: _Backend, +) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: + """Try to get the attention backend class, skipping test if not found.""" + backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_name) + backend_class = resolve_obj_by_qualname(backend_class_str) return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_name} not available: {e}") + pytest.skip(f"{backend_class_str} not available: {e}") + raise AssertionError("unreachable") from None -def create_standard_kv_cache_spec( - vllm_config: VllmConfig) -> FullAttentionSpec: +def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: """Create a FullAttentionSpec from ModelParams only.""" return FullAttentionSpec( block_size=vllm_config.cache_config.block_size, num_kv_heads=vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config), + vllm_config.parallel_config + ), head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, - use_mla=vllm_config.model_config.use_mla, sliding_window=vllm_config.model_config.get_sliding_window(), ) -def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", - tensor_parallel_size: int = 1, - max_model_len: int = 1024, - dtype: Union[ModelDType, torch.dtype] = "auto", - num_gpu_blocks: int = 1000, - block_size: int = 16, - max_num_seqs: int = 256, - max_num_batched_tokens: int = 8192, - enable_chunked_prefill: bool = True, - add_mock_model_methods: bool = True) -> VllmConfig: +def create_vllm_config( + model_name: str = 
"meta-llama/Meta-Llama-3-8B", + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, + block_size: int = 16, + max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, + add_mock_model_methods: bool = True, + hf_config_override: Optional[dict] = None, +) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" model_config = ModelConfig( @@ -201,7 +181,8 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( - tensor_parallel_size=tensor_parallel_size, ) + tensor_parallel_size=tensor_parallel_size, + ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, @@ -219,15 +200,20 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", # but some backends expect to query the model for layer-specific # parameters import types - model_config.get_num_layers = types.MethodType(lambda self: 1, - model_config) + + model_config.get_num_layers = types.MethodType(lambda self: 1, model_config) model_config.get_sliding_window_for_layer = types.MethodType( - lambda self, i: None, model_config) + lambda self, i: None, model_config + ) model_config.get_logits_soft_cap_for_layer = types.MethodType( - lambda self, i: 0.0, model_config) + lambda self, i: 0.0, model_config + ) model_config.get_sm_scale_for_layer = types.MethodType( - lambda self, i: 1.0 / model_config.get_head_size()**0.5, - model_config) + lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config + ) + + if hf_config_override: + model_config.hf_config.update(hf_config_override) return VllmConfig( model_config=model_config, @@ -240,12 +226,14 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) -def create_dummy_kv_cache(block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: +def create_dummy_kv_cache( + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int = 100, +) -> torch.Tensor: """Create a dummy KV cache tensor for testing.""" kv_cache = torch.randn( num_blocks, @@ -254,5 +242,95 @@ def create_dummy_kv_cache(block_size: int, num_kv_heads, head_size, dtype=dtype, - device=device) + device=device, + ) return kv_cache + + +@dataclass +class BackendConfig: + name: str + env_vars: dict + comp_config: dict # compilation config + specific_gpu_arch: Optional[tuple] = None + + +# Define all backend configurations of full cudagraph to be tested +full_cg_backend_configs = { + # FA3 on Hopper + "FA3": BackendConfig( + name="FA3", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "3", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0), + ), + # FlashMLA on Hopper + "FlashMLA": BackendConfig( + name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0), + ), + # Cutlass MLA on Blackwell + "CutlassMLA": BackendConfig( + name="CutlassMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", + "FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(10, 0), + ), + # 
FlashAttention MLA on Hopper + "FlashAttentionMLA": BackendConfig( + name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0), + ), + # FA2 + "FA2": BackendConfig( + name="FA2", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "2", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), + # Triton Attention + "TritonAttn": BackendConfig( + name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), + # FlashInfer + "FlashInfer": BackendConfig( + name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), +} diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 123faaebb2833..6d870b5640dfb 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -11,12 +11,16 @@ from vllm.v1.utils import ConstantList from .utils import create_requests, create_scheduler +pytestmark = pytest.mark.cpu_test + def _make_model_runner_output( - scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput: + scheduler_output: SchedulerOutput, +) -> ModelRunnerOutput: req_ids = list(scheduler_output.num_scheduled_tokens.keys()) return ModelRunnerOutput( req_ids=req_ids, + req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, sampled_token_ids=[[i] for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, @@ -69,8 +73,7 @@ def test_abort(): if not abort_order: return req = requests[abort_order.pop(0)] - scheduler.finish_requests(req.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED) while sched_outputs: # Abort a scheduled request. @@ -106,8 +109,7 @@ def test_preempt(): if not abort_order: return req = requests[abort_order.pop(0)] - scheduler.finish_requests(req.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED) while sched_outputs: # Abort a scheduled request. @@ -129,15 +131,19 @@ def test_prefix_caching_for_prefill_dedup(): CHUNK_SIZE = 1000 BLOCK_SIZE = 16 num_prompt_tokens = 100 - scheduler = create_scheduler(async_scheduling=True, - max_num_batched_tokens=CHUNK_SIZE, - enable_prefix_caching=True, - block_size=BLOCK_SIZE) - requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens, - max_tokens=3, - same_prompt=True, - block_size=BLOCK_SIZE) + scheduler = create_scheduler( + async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE, + ) + requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=3, + same_prompt=True, + block_size=BLOCK_SIZE, + ) requests_copy = requests.copy() # Two requests with the same prompt. 
@@ -179,14 +185,18 @@ def test_prefix_caching_for_multi_turn(): BLOCK_SIZE = 16 num_prompt_tokens = 100 num_output_tokens = 200 - scheduler = create_scheduler(async_scheduling=True, - max_num_batched_tokens=CHUNK_SIZE, - enable_prefix_caching=True, - block_size=BLOCK_SIZE) - requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens, - block_size=BLOCK_SIZE) + scheduler = create_scheduler( + async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE, + ) + requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE, + ) for req in requests: scheduler.add_request(req) @@ -206,14 +216,16 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. - next_turn_requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens + - num_output_tokens, - max_tokens=num_output_tokens, - block_size=BLOCK_SIZE) + next_turn_requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE, + ) for i, req in enumerate(next_turn_requests): - req.prompt_token_ids = (requests[i].prompt_token_ids + - list(requests[i].output_token_ids)) + req.prompt_token_ids = requests[i].prompt_token_ids + list( + requests[i].output_token_ids + ) req._all_token_ids = req.prompt_token_ids.copy() req.all_token_ids = ConstantList(req._all_token_ids) req.block_hashes = [] @@ -227,5 +239,4 @@ def test_prefix_caching_for_multi_turn(): # Make sure the next-turn requests get prefix cache hit by the previous # requests. for req in next_turn_requests: - assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * - BLOCK_SIZE) + assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py new file mode 100644 index 0000000000000..8a52b5bd78977 --- /dev/null +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager + +pytestmark = pytest.mark.cpu_test + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + def __init__(self, request_id, mm_hashes, token_counts): + self.request_id = request_id + self._token_counts = token_counts + self.mm_features = [] + for i, mm_hash in enumerate(mm_hashes): + feature = MultiModalFeatureSpec( + data=None, + modality="image", + identifier=mm_hash, + mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]), + ) + self.mm_features.append(feature) + + def get_num_encoder_tokens(self, input_id: int) -> int: + return self._token_counts[input_id] + + +# ------------------ Unit Tests ------------------ # +def test_basic_allocate_and_reuse(): + cache = EncoderCacheManager(cache_size=10) + req = MockRequest("r1", ["imgA"], [4]) + + assert not cache.check_and_update_cache(req, 0) + assert cache.can_allocate(req, 0, int(1e9), 0) + + cache.allocate(req, 0) + + assert cache.check_and_update_cache(req, 0) + assert "r1" in cache.cached["imgA"] + assert cache.num_free_slots == 6 + + # Free twice to bring refcount to 0. 
+ cache.free_encoder_input(req, 0) + cache.free_encoder_input(req, 0) + + assert not cache.cached["imgA"] + assert "imgA" in cache.freeable + assert cache.num_freeable_slots == 10 + assert cache.num_free_slots == 6 + + +def test_freeing_decreases_refcount_and_moves_to_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req2", ["img3"], [5]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert len(manager.cached["img3"]) == 1 + + manager.free_encoder_input(req, 0) + + assert not manager.cached["img3"] + assert "img3" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_free_request_frees_all_inputs(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req3", ["a", "b"], [2, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 1, int(1e9), 0) + manager.allocate(req, 1) + + assert len(manager.cached["a"]) == 1 + assert len(manager.cached["b"]) == 1 + + manager.free(req) + + assert not manager.cached["a"] + assert not manager.cached["b"] + assert "a" in manager.freeable + assert "b" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_eviction_when_cache_is_full(): + manager = EncoderCacheManager(cache_size=10) + + req1 = MockRequest("req1", ["x"], [6]) + req2 = MockRequest("req2", ["y"], [5]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + # 'x' should have been evicted. + assert "x" not in manager.cached + assert "x" in manager.get_freed_mm_hashes() + + +def test_get_cached_input_ids(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + assert manager.can_allocate(req, 2, int(1e9), 0) + manager.allocate(req, 2) + + cached_ids = manager.get_cached_input_ids(req) + assert cached_ids == {0, 2} + + +def test_has_cache_restores_from_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqY", ["imgZ"], [4]) + + assert manager.can_allocate(req, 0, int(1e9), 0) + manager.allocate(req, 0) + + manager.free_encoder_input(req, 0) + + # Should restore from freeable. + assert manager.check_and_update_cache(req, 0) + assert len(manager.cached["imgZ"]) == 1 + assert "imgZ" not in manager.freeable + assert manager.num_freeable_slots == 6 + + +def test_get_freed_mm_hashes_clears_freed_list(): + manager = EncoderCacheManager(cache_size=10) + req1 = MockRequest("reqA", ["a"], [5]) + req2 = MockRequest("reqB", ["b"], [6]) + + assert manager.can_allocate(req1, 0, int(1e9), 0) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + # Should trigger eviction of 'a'. 
+ assert manager.can_allocate(req2, 0, int(1e9), 0) + manager.allocate(req2, 0) + + freed = manager.get_freed_mm_hashes() + assert "a" in freed + assert manager.get_freed_mm_hashes() == [] + + +def test_schedule_request_multi_images_respect_space_limit(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 100 + + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) + + +def test_schedule_request_multi_images_respect_compute_limit(): + manager = EncoderCacheManager(cache_size=100) + req = MockRequest("reqA", ["a", "b"], [5, 6]) + compute_budget = 10 + num_tokens_to_schedule = 0 + assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) + num_tokens_to_schedule += req.get_num_encoder_tokens(0) + compute_budget -= req.get_num_encoder_tokens(0) + + assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 47c74aff1e753..714a540e86b5e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -6,26 +6,55 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit +from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager -# disable yapf here as it formats differently than isort such that both fail -# yapf: disable from vllm.v1.core.kv_cache_utils import ( - FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, - estimate_max_model_len, generate_block_hash_extra_keys, - get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, unify_kv_cache_configs) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, - SlidingWindowSpec) -from vllm.v1.metrics.stats import PrefixCacheStats + BlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + estimate_max_model_len, + generate_block_hash_extra_keys, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_max_concurrency_for_kv_cache_config, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + is_kv_cache_spec_uniform, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request -# yapf: enable +pytestmark = pytest.mark.cpu_test + + +@pytest.fixture(autouse=True) +def _auto_init_hash_fn(request): + hash_fn: Callable + if "hash_fn" in request.fixturenames: + hash_fn = init_none_hash(request.getfixturevalue("hash_fn")) + else: + hash_fn = sha256 + 
init_none_hash(hash_fn) def make_request( @@ -37,79 +66,83 @@ def make_request( mm_hashes: Optional[list[str]] = None, cache_salt: Optional[str] = None, ): - if mm_positions is None: - mm_kwargs = None - else: - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_positions) + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) -def new_kv_cache_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - use_mla=False, - sliding_window=None): - return FullAttentionSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - use_mla=use_mla, - sliding_window=sliding_window) +def new_kv_cache_spec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + sliding_window=None, +): + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) -def new_sliding_window_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - use_mla=False, - sliding_window=1): - return SlidingWindowSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - use_mla=use_mla, - sliding_window=sliding_window) +def new_sliding_window_spec( + block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, sliding_window=1 +): + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils # case 1: PYTHONHASHSEED is not set, use random with monkeypatch.context() as m: - m.delenv('PYTHONHASHSEED', raising=False) + m.delenv("PYTHONHASHSEED", raising=False) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert reloaded_kv_cache_utils.NONE_HASH != 0 + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) + assert reloaded_kv_cache_utils.NONE_HASH != b"" # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: - m.setenv('PYTHONHASHSEED', 'python hash seed') + m.setenv("PYTHONHASHSEED", "python hash seed") reloaded_kv_cache_utils = 
importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) + assert hash_fn("python hash seed") == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - import vllm.v1.core.kv_cache_utils - # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 @@ -123,8 +156,7 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, - token_ids=(1, 2, 3)) + block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0) block.block_hash = block_hash assert block.block_hash == block_hash @@ -178,10 +210,8 @@ def test_free_kv_cache_block_queue_operations(): for _ in range(4): queue.popleft() assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Attempt to pop from an empty queue with pytest.raises(ValueError) as e: @@ -197,10 +227,8 @@ def test_free_kv_cache_block_queue_append_n(): # fake_head->fake_tail queue.append_n([]) assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Append 1 block # fake_head->b0->fake_tail queue.append_n(blocks[0:1]) @@ -240,12 +268,27 @@ def test_free_kv_cache_block_queue_append_n(): assert blocks[3].next_free_block is queue.fake_free_list_tail assert queue.fake_free_list_tail.prev_free_block is blocks[3] + # Create an empty FreeKVCacheBlockQueue + invalid_queue = FreeKVCacheBlockQueue([]) + # Set prev_free_block to None; this will trigger the assertion in append_n + invalid_queue.fake_free_list_tail.prev_free_block = None + with pytest.raises(AssertionError): + # Append 1 block + # fake_head->fake_tail + invalid_queue.append_n(blocks[0:1]) + assert invalid_queue.num_free_blocks == 0 + assert ( + invalid_queue.fake_free_list_head.next_free_block + == invalid_queue.fake_free_list_tail + ) + def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] - # Create a empty FreeKVCacheBlockQueue with these blocks + # Create an empty FreeKVCacheBlockQueue with these blocks queue = FreeKVCacheBlockQueue( - [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]]) + [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]] + ) assert queue.num_free_blocks == 6 assert queue.fake_free_list_head.next_free_block is blocks[1] assert blocks[1].prev_free_block is queue.fake_free_list_head @@ -265,9 +308,11 @@ def test_free_kv_cache_block_queue_popleft_n(): # Pop 0 block # fake_head->b1->b3->b5->b4->b0->b2->fake_tail assert len(queue.popleft_n(0)) == 0 + assert queue.num_free_blocks == 6 # Pop 1 block # fake_head->b3->b5->b4->b0->b2->fake_tail result_blocks = queue.popleft_n(1) + assert 
queue.num_free_blocks == 5 assert len(result_blocks) == 1 assert result_blocks[0] is blocks[1] for block in result_blocks: @@ -277,6 +322,7 @@ def test_free_kv_cache_block_queue_popleft_n(): # fake_head->b4->b0->b2->fake_tail result_blocks = queue.popleft_n(2) assert len(result_blocks) == 2 + assert queue.num_free_blocks == 3 assert result_blocks[0] is blocks[3] assert result_blocks[1] is blocks[5] for block in result_blocks: @@ -286,6 +332,7 @@ def test_free_kv_cache_block_queue_popleft_n(): # fake_head->fake_tail result_blocks = queue.popleft_n(3) assert len(result_blocks) == 3 + assert queue.num_free_blocks == 0 assert result_blocks[0] is blocks[4] assert result_blocks[1] is blocks[0] assert result_blocks[2] is blocks[2] @@ -315,8 +362,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks(): # Append a block back and check again queue.append(block_to_remove) - assert queue.get_all_free_blocks() == \ - blocks[1:2] + blocks[3:] + [block_to_remove] + assert queue.get_all_free_blocks() == blocks[1:2] + blocks[3:] + [block_to_remove] def test_generate_block_hash_extra_keys(): @@ -332,12 +378,12 @@ def test_generate_block_hash_extra_keys(): # Test with no extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with partial overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with no overlap @@ -347,7 +393,7 @@ def test_generate_block_hash_extra_keys(): # Test with multiple extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0) - assert extra_keys == ('hash1', 'hash2') + assert extra_keys == ("hash1", "hash2") assert next_mm_idx == 2 @@ -375,9 +421,9 @@ def test_generate_block_hash_extra_keys_cache_salt(): # salt is added for the first token extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) # no salt added for other tokens extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0) @@ -397,33 +443,26 @@ def test_generate_block_hash_extra_keys_cache_salt(): ) # Test with no extra keys - extra_keys, next_mm_idx = generate_block_hash_extra_keys( - request_mm, 0, 5, 0) + extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0) assert extra_keys == ("hash1", "salt") assert next_mm_idx == 1 -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) - parent_block_hash = 123 + parent_block_hash = BlockHash(b"123") curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - curr_block_token_ids, extra_keys) - assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) - assert block_hash.hash_value == hash_fn( - (parent_block_hash, curr_block_token_ids, extra_keys)) - assert block_hash.token_ids == curr_block_token_ids - assert block_hash.extra_keys == extra_keys + block_hash = hash_block_tokens( + hash_fn, parent_block_hash, curr_block_token_ids, extra_keys + ) + expected = hash_fn((parent_block_hash, curr_block_token_ids, 
extra_keys)) + assert block_hash == expected -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_request_block_hasher(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -438,22 +477,12 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) - assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) - - # Check the first block - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys == ("hash1", ) - - # Check the second block - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys == ("hash2", ) + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",))) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",))) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_tokens_different_mm_input(hash_fn): - init_none_hash(hash_fn) - request1 = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -480,10 +509,8 @@ def test_hash_tokens_different_mm_input(hash_fn): assert block_hashes1[1] != block_hashes2[1] -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_request_tokens_no_mm_inputs(hash_fn): - init_none_hash(hash_fn) - request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -496,33 +523,31 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys is None - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys is None + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) + + +def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats: + return PrefixCacheStats(requests=requests, queries=queries, hits=hits) def test_metrics(): """ Test the prefix caching metrics. 
""" - - def stats(requests, queries, hits): - return PrefixCacheStats(requests=requests, queries=queries, hits=hits) - - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 - metrics.observe(stats(1, 20, 9)) + metrics.observe(_stats(1, 20, 9)) # 9 / 20 = 0.45 assert metrics.hit_rate == 0.45 - metrics.observe(stats(4, 80, 16)) + metrics.observe(_stats(4, 80, 16)) # 25 / 100 = 0.25 assert metrics.hit_rate == 0.25 - metrics.observe(stats(1, 10, 2)) + metrics.observe(_stats(1, 10, 2)) # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 assert metrics.aggregated_requests == 5 @@ -538,96 +563,388 @@ def test_metrics(): assert not metrics.query_queue -def test_unify_kv_cache_configs(): - same_kv_cache_config = [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - KVCacheConfig( - num_blocks=20, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - ] - unify_kv_cache_configs(same_kv_cache_config) - assert same_kv_cache_config[0].num_blocks == 10 - assert same_kv_cache_config[1].num_blocks == 10 +def test_metrics_empty_stats(): + """ + Test the prefix caching metrics with empty stats. + """ + metrics = CachingMetrics(max_recent_requests=5) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(1, 20, 9)) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(4, 80, 16)) + metrics.observe(_stats(0, 0, 0)) + metrics.observe(_stats(1, 10, 2)) + # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 + assert metrics.aggregated_requests == 5 + assert metrics.aggregated_query_total == 90 + assert metrics.aggregated_query_hit == 18 + assert metrics.hit_rate == 0.2 - need_sort_kv_cache_config = [ + # Only the latest added stats preserved 10 / 20 = 0.5 + metrics.observe(_stats(11, 20, 10)) + assert metrics.aggregated_requests == 11 + assert metrics.aggregated_query_total == 20 + assert metrics.aggregated_query_hit == 10 + assert metrics.hit_rate == 0.5 + + # Only the latest added stats preserved 30 / 40 = 0.75 + metrics.observe(_stats(22, 40, 30)) + assert metrics.aggregated_requests == 22 + assert metrics.aggregated_query_total == 40 + assert metrics.aggregated_query_hit == 30 + assert metrics.hit_rate == 0.75 + + +def test_get_kv_cache_configs_multiple_workers(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + + ref_kv_cache_spec = new_kv_cache_spec() + same_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + ] + + # Basic case. All things are the same. 
+ kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), ] - unify_kv_cache_configs(need_sort_kv_cache_config) - assert need_sort_kv_cache_config[0].num_blocks == 10 - assert need_sort_kv_cache_config[1].num_blocks == 10 - - diff_kv_cache_config = [ + # Different available memory. This is the case for TP. + # Use the smallest memory available. + kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 20, + ], + ) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + ] + + # Different KV cache specs. This is the case for PP. + different_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer2": new_kv_cache_spec(), + "layer3": new_kv_cache_spec(), + }, + ] + + # Different workers have different layers. 
+ kv_cache_configs = get_kv_cache_configs( + vllm_config, + different_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=8)), ], ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()), + ], + ), + ] + + # Some layers are the same, some are different. This is the case for TP+PP + tp_pp_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + ] + + kv_cache_configs = get_kv_cache_configs( + vllm_config, + tp_pp_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + ] + + # Different workers have different types of layers. This is the case for + # hybrid models + PP. 
+ different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_sliding_window_spec(), + "layer4": new_sliding_window_spec(), + }, + ] + kv_cache_configs = get_kv_cache_configs( + vllm_config, + different_type_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + KVCacheGroupSpec([], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer4"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec([], ref_kv_cache_spec), + KVCacheGroupSpec(["layer3", "layer4"], new_sliding_window_spec()), + ], + ), + ] + + # When divided into multiple KVCacheGroups, need to ensure the number of + # layers per group is similar. + different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_sliding_window_spec(), + "layer3": new_sliding_window_spec(), + }, + { + "layer4": new_kv_cache_spec(), + "layer5": new_sliding_window_spec(), + "layer6": new_sliding_window_spec(), + }, + ] + kv_cache_configs = get_kv_cache_configs( + vllm_config, + different_type_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 10, + ref_kv_cache_spec.page_size_bytes * 10, + ], + ) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1", "layer2", "layer3"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer2"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer3"], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4", "layer5", "layer6"], + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer4"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer6"], new_sliding_window_spec()), + ], + ), + ] + + # Have conflicting layers. Need to raise an error. 
+ conflicting_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer1": new_sliding_window_spec(), + }, ] with pytest.raises(AssertionError): - unify_kv_cache_configs(diff_kv_cache_config) + get_kv_cache_configs( + vllm_config, + conflicting_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) def test_merge_kv_cache_spec(): @@ -657,7 +974,6 @@ def test_merge_kv_cache_spec(): num_kv_heads=full_spec.num_kv_heads, head_size=full_spec.head_size, dtype=full_spec.dtype, - use_mla=full_spec.use_mla, sliding_window=1, ), ] @@ -673,14 +989,16 @@ def test_merge_kv_cache_spec(): ] with pytest.raises(ValueError): different_sliding_window_layer_specs[0].merge( - different_sliding_window_layer_specs) + different_sliding_window_layer_specs + ) same_sliding_window_layer_specs = [ new_kv_cache_spec(num_kv_heads=32, sliding_window=1), new_kv_cache_spec(num_kv_heads=32, sliding_window=1), ] merged_layer_spec = same_sliding_window_layer_specs[0].merge( - same_sliding_window_layer_specs) + same_sliding_window_layer_specs + ) assert merged_layer_spec.sliding_window == 1 same_sliding_window_layer_spec_with_none = [ @@ -688,49 +1006,51 @@ def test_merge_kv_cache_spec(): new_kv_cache_spec(num_kv_heads=32, sliding_window=None), ] merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( - same_sliding_window_layer_spec_with_none) + same_sliding_window_layer_spec_with_none + ) assert merged_layer_spec.sliding_window == 1 -def test_is_kv_cache_type_uniform(): +def test_is_kv_cache_spec_uniform(): kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_kv_cache_spec(num_kv_heads=32), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), } - assert not is_kv_cache_type_uniform(kv_cache_spec) + assert not is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2), } - assert not is_kv_cache_type_uniform(kv_cache_spec) + assert not is_kv_cache_spec_uniform(kv_cache_spec) @pytest.mark.parametrize( - ("model_id", "max_model_len", "want_estimated_max_len"), [ + ("model_id", "max_model_len", "want_estimated_max_len"), + [ ("Qwen/Qwen1.5-7B", 16385, 16384), ("Qwen/Qwen1.5-7B", 16383, 16383), - ]) -def test_estimate_max_model_len(model_id, max_model_len, - want_estimated_max_len): + ], +) +def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len): # Create a VllmConfig model_config = ModelConfig( model_id, @@ -754,11 +1074,11 @@ def test_estimate_max_model_len(model_id, max_model_len, num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) # Estimate the maximum model length, 16384 model_len need 8GB - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - 8 
* GiB_bytes) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, 8 * GiB_bytes + ) assert estimated_max_len == want_estimated_max_len @@ -772,8 +1092,9 @@ def test_get_max_concurrency_for_kv_cache_config(): dtype="float16", max_model_len=max_model_len, ) - scheduler_config = SchedulerConfig(max_num_batched_tokens=1024, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=1024, enable_chunked_prefill=True + ) vllm_config = VllmConfig( model_config=model_config, @@ -785,7 +1106,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) sliding_window_spec = SlidingWindowSpec( @@ -793,7 +1113,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, sliding_window=1024, ) @@ -801,38 +1120,39 @@ def test_get_max_concurrency_for_kv_cache_config(): num_blocks=int(1024 * 1.5), kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), ], ) max_concurrency_full_attention = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_full_attention) + vllm_config, kv_cache_config_full_attention + ) assert max_concurrency_full_attention == 1.5 kv_cache_config_sliding_window = KVCacheConfig( num_blocks=129 * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], sliding_window_spec), ], ) max_concurrency_sliding_window = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_sliding_window) + vllm_config, kv_cache_config_sliding_window + ) assert max_concurrency_sliding_window == 3 kv_cache_config_hybrid_model = KVCacheConfig( num_blocks=(1024 + 129) * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), - KVCacheGroupSpec([f"layer_{i}" for i in range(32, 64)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), + KVCacheGroupSpec( + [f"layer_{i}" for i in range(32, 64)], sliding_window_spec + ), ], ) max_concurrency_hybrid_model = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_hybrid_model) + vllm_config, kv_cache_config_hybrid_model + ) assert max_concurrency_hybrid_model == 3 @@ -845,8 +1165,7 @@ def test_allocate_with_lookahead(): KVCacheTensor(size=100, shared_by=["layer1"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], - new_kv_cache_spec(block_size=block_size)), + KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), ], ) @@ -859,8 +1178,7 @@ def test_allocate_with_lookahead(): ) # Test case 1: Requires additional lookahead tokens - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -869,8 +1187,7 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( 
request, @@ -881,8 +1198,7 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -891,7 +1207,7 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 -def test_get_kv_cache_config(): +def test_get_kv_cache_config_one_worker(): # pass max_model_len to pass check_enough_kv_cache_memory model_config = ModelConfig(max_model_len=16) vllm_config = VllmConfig(model_config=model_config) @@ -899,77 +1215,78 @@ def test_get_kv_cache_config(): mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # all layers are full attention -> single group kv_cache_specs_full = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), } - kv_cache_config_full = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_full = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] + print(kv_cache_config_full) assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) # all layers are sliding window -> single group kv_cache_specs_sliding = { - 'layer_1': new_sliding_window_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_sliding_window_spec(), + "layer_2": new_sliding_window_spec(), } - kv_cache_config_sliding = get_kv_cache_config( - vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32) + kv_cache_config_sliding = get_kv_cache_configs( + vllm_config, [kv_cache_specs_sliding], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) - ]) + ], + ) # full + sliding, but disable_hybrid_kv_cache_manager vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - 
KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], - new_kv_cache_spec(sliding_window=1)), + KVCacheGroupSpec( + ["layer_1", "layer_2"], new_kv_cache_spec(sliding_window=1) + ), ], ) vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False # full + sliding, with hybrid_kv_cache_manager kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=64, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 64, - shared_by=["layer_1", "layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 64, shared_by=["layer_1", "layer_2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()), @@ -979,90 +1296,243 @@ def test_get_kv_cache_config(): # 2 full + 4 sliding, 2 layers per group kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_sliding_window_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_sliding_window_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) - assert kv_cache_config_hybrid == KVCacheConfig( - num_blocks=32, - kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_5"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_4", "layer_6"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_3", "layer_4"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_5", "layer_6"], - new_sliding_window_spec()), - ], - ) - - # 3 full + 7 sliding, pad to 3 full + 9 sliding - kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_kv_cache_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), - 'layer_7': new_sliding_window_spec(), - 'layer_8': new_sliding_window_spec(), - 'layer_9': new_sliding_window_spec(), - 'layer_10': new_sliding_window_spec(), - } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_8"]), - 
KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_6", "layer_9"]), + shared_by=["layer_1", "layer_3", "layer_4"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_5", "layer_6"], + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], - new_kv_cache_spec()), - KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer_3", "layer_5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_4", "layer_6"], new_sliding_window_spec()), ], ) - # different hidden size, unimplemented + # 3 full + 7 sliding, pad to 3 full + 9 sliding kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=128), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_kv_cache_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), + "layer_7": new_sliding_window_spec(), + "layer_8": new_sliding_window_spec(), + "layer_9": new_sliding_window_spec(), + "layer_10": new_sliding_window_spec(), } - with pytest.raises(NotImplementedError): - get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, - mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_4", "layer_5", "layer_6"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_7", "layer_8", "layer_9"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_3", "layer_10"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], new_kv_cache_spec()), + KVCacheGroupSpec( + ["layer_4", "layer_7", "layer_10"], new_sliding_window_spec() + ), + KVCacheGroupSpec(["layer_5", "layer_8"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_6", "layer_9"], new_sliding_window_spec()), + ], + ) + + # different hidden size + kv_cache_specs_hybrid = { + "layer_1": new_kv_cache_spec(head_size=128), + "layer_2": new_kv_cache_spec(head_size=64), + } + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs_hybrid + ), + ) + ], + ) # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 - kv_cache_config_override_blocks = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_override_blocks = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ - 
KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) + + +def test_get_kv_cache_configs_attention_free(): + kv_cache_specs: dict[str, KVCacheSpec] = {} + vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16)) + kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=[], + ) + ] + + +def test_generate_uniform_type_kv_cache_specs(): + # All layers are full attention, can be merged + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec == UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs + ) + + # Full attention + sliding window, cannot be merged + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(sliding_window=1), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + # different order of full attention + sliding window, cannot be merged + kv_cache_specs = { + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_kv_cache_spec(), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + # Same-size sliding window, can be merged + kv_cache_specs = { + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_sliding_window_spec(sliding_window=1, head_size=128), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec == UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs + ) + + # different block sizes, cannot be merged + kv_cache_specs = { + "layer_1": new_kv_cache_spec(block_size=16), + "layer_2": new_kv_cache_spec(block_size=32), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + +def test_generate_scheduler_kv_cache_config(): + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), + } + kv_cache_configs = [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs + ), + ), + ], + ) + ] + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + assert scheduler_kv_cache_config == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) + + +def new_mla_spec(cache_dtype_str=None): + return MLAAttentionSpec( + block_size=16, + num_kv_heads=16, + head_size=64, + dtype=torch.float32, + cache_dtype_str=cache_dtype_str, + ) + + +def test_merge_mla_spec(): + kv_cache_specs = [ + new_mla_spec(), + new_mla_spec(), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec() + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + 
new_mla_spec(cache_dtype_str="fp8_ds_mla"), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla") + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_mla_spec(cache_dtype_str=None), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_kv_cache_spec(), + new_mla_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_kv_cache_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) diff --git a/tests/v1/core/test_kv_sharing.py b/tests/v1/core/test_kv_sharing.py new file mode 100644 index 0000000000000..e6d37b1d63c8c --- /dev/null +++ b/tests/v1/core/test_kv_sharing.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec +from vllm.v1.worker.utils import add_kv_sharing_layers_to_kv_cache_groups + +pytestmark = pytest.mark.cpu_test + + +def new_kv_cache_spec(): + return FullAttentionSpec(16, 1, 1, torch.float32, False) + + +def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): + """ + Test initializing KV cache sharing with different attention groups. + Layers in the same KV cache group might be placed in different attn groups + if they have different attention backends. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + # Layers 0 and 1 both belong in KV cache group 0 + # However, if they have different attention backends, they will be + # placed in different attention groups for KV cache group 0 + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], new_kv_cache_spec()), + ] + + add_kv_sharing_layers_to_kv_cache_groups( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + ) + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + ] + + +def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): + """ + Test case assuming that all layers in the same KV cache group have the same + attention backends. This is true for most models. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], new_kv_cache_spec()), + ] + + add_kv_sharing_layers_to_kv_cache_groups( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + ) + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + ] + + +def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): + """ + Test KV sharing set up when no attention groups are provided. + This is the case for the TPU model runner, which doesn't have + support for attention groups yet. 
+ """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()), + ] + + add_kv_sharing_layers_to_kv_cache_groups( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + ) + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 2 + assert kv_cache_groups[0].layer_names == ["model.layers.0", "model.layers.2"] + assert kv_cache_groups[1].layer_names == ["model.layers.1", "model.layers.3"] diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 89824768ed909..d08c1bcc57bd5 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,18 +8,46 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256, sha256_cbor_64bit -from vllm.v1.core.block_pool import BlockPool +from vllm.utils import sha256, sha256_cbor +from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, - get_request_block_hasher, - hash_block_tokens, init_none_hash) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, SlidingWindowSpec) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + KVCacheBlock, + get_block_hash, + get_group_id, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + SlidingWindowSpec, +) + +pytestmark = pytest.mark.cpu_test + + +@pytest.fixture(autouse=True) +def _auto_init_hash_fn(request): + hash_fn: Callable + if "hash_fn" in request.fixturenames: + hash_fn = init_none_hash(request.getfixturevalue("hash_fn")) + else: + hash_fn = sha256 + init_none_hash(hash_fn) def make_request( @@ -32,24 +60,29 @@ def make_request( prompt_logprobs: Optional[int] = None, cache_salt: Optional[str] = None, ): - if mm_positions is None: - mm_kwargs = None - else: - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_positions) + mm_features = [] + if mm_positions is not None: + for j, position in enumerate(mm_positions): + identifier = mm_hashes[j] if mm_hashes else f"hash_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams( - max_tokens=17, prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) + return 
Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -59,46 +92,41 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, 1, 1, torch.float32), ) ], ) -def make_kv_cache_config_hybrid_model(block_size: int, - num_blocks: int) -> KVCacheConfig: +def make_kv_cache_config_hybrid_model( + block_size: int, num_blocks: int +) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, 1, 1, torch.float32), ), KVCacheGroupSpec( ["layer2"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - False, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), KVCacheGroupSpec( ["layer3"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - False, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), ], ) -@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) -def test_prefill(hash_algo): +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_prefill(hash_fn): block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -106,10 +134,6 @@ def test_prefill(hash_algo): enable_caching=True, ) - # choose the hash function according to the parameter - hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else - sha256 if hash_algo == "sha256" else hash) - # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -122,41 +146,41 @@ def test_prefill(hash_algo): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): - block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata - for block_id in (4, ): + for block_id in (4,): assert 
manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -175,30 +199,27 @@ def test_prefill(hash_algo): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids, block_size, - hash_fn) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([6], ) + blocks = manager.allocate_slots( + req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([6],) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert free_block_queue.num_free_blocks == 6 - assert all( - [b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) + assert all([b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6 manager.free(req2) @@ -208,17 +229,23 @@ def test_prefill(hash_algo): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 16 * 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # This block ID order also checks the eviction order. 
- assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) + assert blocks is not None and blocks.get_block_ids() == ( + [7, 8, 9, 10, 4, 5, 6, 3, 2, 1], + ) assert free_block_queue.num_free_blocks == 0 - assert (free_block_queue.fake_free_list_head.next_free_block - is free_block_queue.fake_free_list_tail) - assert (free_block_queue.fake_free_list_tail.prev_free_block - is free_block_queue.fake_free_list_head) + assert ( + free_block_queue.fake_free_list_head.next_free_block + is free_block_queue.fake_free_list_tail + ) + assert ( + free_block_queue.fake_free_list_tail.prev_free_block + is free_block_queue.fake_free_list_head + ) def test_prefill_hybrid_model(): @@ -229,7 +256,7 @@ def test_prefill_hybrid_model(): enable_caching=True, ) - hash_fn = hash + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(block_size)] @@ -243,24 +270,27 @@ def test_prefill_hybrid_model(): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, - 8], [9, 10, 11, 12]) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ( + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + ) # Check full block metadata parent_block_hash = None - for length, block_ids in zip((1, 2, 3), - ((1, 5, 9), (2, 6, 10), (3, 7, 11))): - block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) - for block_id in block_ids: - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + for length, block_ids in zip((1, 2, 3), ((1, 5, 9), (2, 6, 10), (3, 7, 11))): + block_tokens = tuple(all_token_ids[(length - 1) * 16 : length * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) + for group_id, block_id in enumerate(block_ids): + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == group_id assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, 8, 12): @@ -270,18 +300,16 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, - 7], [0, 10, 11]) + assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([13], [14], [15]) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in 
computed_blocks.blocks: for block in block_per_group: if block != manager.block_pool.null_block: @@ -292,84 +320,103 @@ def test_prefill_hybrid_model(): manager.free(req1) cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block) + manager.block_pool.cached_block_hash_to_block._cache + ) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], - expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, hash) + def test_partial_request_hit( + request_id: str, + hash_to_evict: list[BlockHashWithGroupId], + expect_hit_length: int, + ): + req = make_request( + request_id, common_token_ids + unique_token_ids, block_size, sha256 + ) for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block.pop( - hash_with_group_id) + manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block[ - hash_with_group_id] = cached_block_hash_to_block_bak[ - hash_with_group_id] + manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = ( + cached_block_hash_to_block_bak[hash_with_group_id] + ) manager.free(req) # Evict the blocks outside sliding window, does not affect the hit length. - test_partial_request_hit("2", [ - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2) - ], 3) + test_partial_request_hit( + "2", + [ + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 3, + ) # Evict the first block of full attention, makes total cache miss. - test_partial_request_hit("3", [ - BlockHashWithGroupId(block_hashes[0], 0), - ], 0) + test_partial_request_hit( + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0 + ) # Evict the last block of all layers, reduces the hit length to 2. - test_partial_request_hit("4", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[2], 1), - BlockHashWithGroupId(block_hashes[2], 2), - ], 2) + test_partial_request_hit( + "4", + [ + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), + ], + 2, + ) # Evict the last block of full attention, reduces the hit length to 2. - test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], - 2) + test_partial_request_hit( + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], - 2) + test_partial_request_hit( + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], - 2) + test_partial_request_hit( + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2 + ) # Evict different set of blocks for full attention and sliding window makes # total cache miss. # The cache hit length of full attention is 1 * block_size. 
    # The cache hit length of sliding window is 2 * block_size.
-    # Then it is cache miss as the two type of layers have different hit length.
-    test_partial_request_hit("8", [
-        BlockHashWithGroupId(block_hashes[2], 0),
-        BlockHashWithGroupId(block_hashes[0], 1),
-        BlockHashWithGroupId(block_hashes[0], 2),
-    ], 0)
+    # Then it is cache miss as the two types of layers
+    # have different hit lengths.
+    test_partial_request_hit(
+        "8",
+        [
+            make_block_hash_with_group_id(block_hashes[2], 0),
+            make_block_hash_with_group_id(block_hashes[0], 1),
+            make_block_hash_with_group_id(block_hashes[0], 2),
+        ],
+        0,
+    )


 def test_prefill_plp():
-    '''Test prefill with APC and some prompt logprobs (plp) requests.
+    """Test prefill with APC and some prompt logprobs (plp) requests.

     1. Schedule plp request and validate APC block allocation
     2. Schedule non-plp request and validate blocks
     3. Schedule plp request; no hit should occur; validate blocks
-    '''
+    """
     block_size = 16
     manager = KVCacheManager(
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
     )
-    # the default hash function is hash
-    hash_fn = hash
+    # the default hash function is sha256
+    hash_fn = sha256

     # Complete 3 blocks (48 tokens)
     common_token_ids = [i for i in range(3) for _ in range(16)]

@@ -379,34 +426,31 @@ def test_prefill_plp():
     # Incomplete 1 block (7 tokens)
     unique_token_ids = [3] * 7
     all_token_ids = common_token_ids + unique_token_ids
-    req0 = make_request("0",
-                        all_token_ids,
-                        block_size,
-                        hash_fn,
-                        prompt_logprobs=5)
+    req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(req0.block_hashes) == 3
     assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
-    blocks = manager.allocate_slots(req0, 55,
-                                    len(computed_blocks.blocks[0]) * 16,
-                                    computed_blocks)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    blocks = manager.allocate_slots(
+        req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks
+    )
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],)
     req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

     # Check full block metadata
     parent_block_hash = None
     for block_id in (1, 2, 3):
-        block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
-        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
-                                       block_tokens)
-        assert manager.block_pool.blocks[
-            block_id].block_hash.block_hash == block_hash
+        block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16])
+        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
+        blk_hash = manager.block_pool.blocks[block_id].block_hash
+        assert blk_hash is not None
+        assert get_block_hash(blk_hash) == block_hash
+        assert get_group_id(blk_hash) == 0
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
-        parent_block_hash = block_hash.hash_value
+        parent_block_hash = block_hash

     # Check partial block metadata
-    for block_id in (4, ):
+    for block_id in (4,):
         assert manager.block_pool.blocks[block_id].block_hash is None
         assert manager.block_pool.blocks[block_id].ref_cnt == 1

@@ -414,17 +458,16 @@ def test_prefill_plp():
     # Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -442,29 +485,27 @@ def test_prefill_plp(): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks unique_token_ids = [3] * 6 - req2 = make_request("2", - common_token_ids + unique_token_ids, - block_size, - hash_fn, - prompt_logprobs=5) + req2 = make_request( + "2", common_token_ids + unique_token_ids, block_size, hash_fn, prompt_logprobs=5 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes - assert block_ids != ([1, 2, 3, 4], ) + assert block_ids != ([1, 2, 3, 4],) # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -488,26 +529,29 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - hash) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Append slots without allocating a new block. 
req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) # Append slots with allocating a new block. req0.num_computed_tokens = 59 @@ -515,14 +559,22 @@ def test_decode(): # the preallocated block. for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 19, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 19, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-2].block_hash is not None - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-2] + .block_hash + is not None + ) + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) def test_evict(): @@ -534,26 +586,27 @@ def test_evict(): ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id)), block_size, hash) + req0 = make_request("0", list(range(last_token_id)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial + blocks = manager.allocate_slots( + req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + # 5 full + 1 partial + assert blocks is not None and len(blocks.blocks[0]) == 6 # 3 blocks. - req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16)), block_size, - hash) + req1 = make_request( + "1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 3 # 3 full blocks + blocks = manager.allocate_slots( + req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -563,19 +616,18 @@ def test_evict(): manager.free(req1) assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. 
- req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == ([1, 2], ) + assert computed_blocks.get_block_ids() == ([1, 2],) assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([10], ) + blocks = manager.allocate_slots( + req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([10],) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -593,31 +645,30 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens)), block_size, hash) + req = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 1 + blocks = manager.allocate_slots( + req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 1 # Deallocate the block. manager.free(req) # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1)), block_size, hash) + req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 1 + blocks = manager.allocate_slots( + req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 1 - assert manager.block_pool.blocks[blocks.blocks[0] - [0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks.blocks[0][0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -634,26 +685,27 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 1 + blocks = manager.allocate_slots( + req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. 
- req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, hash) + req1 = make_request( + "1", list(range(num_tokens, num_tokens * 2)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 1 + blocks = manager.allocate_slots( + req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 # Free the blocks. @@ -662,16 +714,19 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size - blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 1 + blocks = manager.allocate_slots( + req2, + num_tokens * 2 - num_tokens, + len(computed_blocks.blocks[0]) * 16, + computed_blocks, + ) + assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -686,49 +741,48 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10)), block_size, - hash) # 2 blocks and some more + req1 = make_request( + "1", list(range(10)), block_size, sha256 + ) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 3 + blocks = manager.allocate_slots( + req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) # No caching. - req2 = make_request("2", list(range(16)), block_size, - hash) # shared prefix + req2 = make_request("2", list(range(16)), block_size, sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert len(blocks.blocks[0]) == 4 + blocks = manager.allocate_slots( + req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and len(blocks.blocks[0]) == 4 # New requests should not have any blocks. 
- req3 = make_request("3", list(range(4)), block_size, hash) + req3 = make_request("3", list(range(4)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert not blocks -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks function of KVCacheManager. """ - init_none_hash(hash_fn) block_size = 4 block_pool = BlockPool( @@ -757,7 +811,8 @@ def test_cache_blocks(hash_fn): assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) - # Test that blocks that don't start from the beginning are cached correctly. + # Test that blocks that don't start from the beginning are cached + # correctly. blocks += [KVCacheBlock(block_id=2)] block_pool.cache_full_blocks( request=req, @@ -783,7 +838,7 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14)), block_size, hash) + req = make_request("0", list(range(14)), block_size, sha256) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] @@ -817,30 +872,48 @@ def test_cache_blocks_multi_group(): # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0, 1]) is None + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) + is None + ) def test_mm_prefix_caching(): """ This tests 
that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -863,58 +936,73 @@ def test_mm_prefix_caching(): # A unique image plus some text tokens. unique_token_ids = [-1] * 7 + [100] * 4 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req0 = make_request("0", - all_token_ids, - block_size, - hash, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req0 = make_request( + "0", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("aaa", ) - assert block_hashes[1].extra_keys == ("aaa", "bbb") - assert block_hashes[2].extra_keys == ("bbb", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",)) + ) + assert block_hashes[1] == sha256( + ( + block_hashes[0], + tuple(all_token_ids[block_size : block_size * 2]), + ("aaa", "bbb"), + ) + ) + assert block_hashes[2] == sha256( + ( + block_hashes[1], + tuple(all_token_ids[block_size * 2 : block_size * 3]), + ("bbb",), + ) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys == ("ccc", ) + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",)) + ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req1 = make_request("1", - all_token_ids, - block_size, - hash, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req1 = make_request( + "1", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -935,39 +1023,46 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. 
common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt1", ) - assert block_hashes[1].extra_keys is None - assert block_hashes[2].extra_keys is None + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1",)) + ) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # Now one more block that should not have extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys is None + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(token_ids[3 * block_size :] + [8] * 5), None) + ) # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -975,13 +1070,21 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. 
token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt2", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2",)) + ) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1000,26 +1103,28 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids, block_size, hash) + req0 = make_request("0", common_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req0, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id] + req0.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2, block_size, hash) + req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req1, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req1.request_id] + req1.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -1028,26 +1133,32 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2, block_size, hash) + req2 = make_request("2", [7] * block_size * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * block_size, - computed_blocks) + manager.allocate_slots( + req2, + block_size * 2, + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. 
assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3, block_size, hash) + req3 = make_request("3", common_token_ids * 3, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) is None + assert ( + manager.allocate_slots( + req3, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + is None + ) # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. @@ -1065,20 +1176,20 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, block_size, hash) + req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids, block_size, hash) + req1 = make_request("1", all_token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 - blocks = manager.allocate_slots(req1, 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() @@ -1105,13 +1216,13 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
- req = make_request("0", list(range(16)), block_size, hash) + req = make_request("0", list(range(16)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -1120,15 +1231,9 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): pool = BlockPool(num_gpu_blocks=4, enable_caching=True) - block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, - token_ids=(100, )), - group_id=1000) - block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20, - token_ids=(200, )), - group_id=2000) - block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30, - token_ids=(300, )), - group_id=3000) + block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) + block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) + block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hashes = [ block_hash0, block_hash1, @@ -1140,49 +1245,36 @@ def test_maybe_evict_cached_block(): # Manually add all blocks to cached_blocks for block, block_hash in zip(pool.blocks, block_hashes): block.block_hash = block_hash - pool.cached_block_hash_to_block[block_hash][block.block_id] = block + pool.cached_block_hash_to_block.insert(block_hash, block) block0, block1, block2, block3 = pool.blocks - assert pool.cached_block_hash_to_block == { + assert pool.cached_block_hash_to_block._cache == { block_hash0: { block0.block_id: block0, - block3.block_id: block3 + block3.block_id: block3, }, - block_hash1: { - block1.block_id: block1 - }, - block_hash2: { - block2.block_id: block2 - } + block_hash1: block1, + block_hash2: block2, } # Evict block1 pool._maybe_evict_cached_block(block1) - assert pool.cached_block_hash_to_block == { - block_hash0: { - block0.block_id: block0, - block3.block_id: block3 - }, - block_hash2: { - block2.block_id: block2 - } + assert pool.cached_block_hash_to_block._cache == { + block_hash0: {block0.block_id: block0, block3.block_id: block3}, + block_hash2: block2, } # Evict block0: block_hash0 entry should NOT be removed, as block3 # also use the same hash pool._maybe_evict_cached_block(block0) - assert pool.cached_block_hash_to_block == { - block_hash0: { - block3.block_id: block3 - }, - block_hash2: { - block2.block_id: block2 - } + assert pool.cached_block_hash_to_block._cache == { + block_hash0: {block3.block_id: block3}, + block_hash2: block2, } # Evict block2 pool._maybe_evict_cached_block(block2) - assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}} + assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}} # Evict block3 pool._maybe_evict_cached_block(block3) - assert pool.cached_block_hash_to_block == {} + assert pool.cached_block_hash_to_block._cache == {} @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) @@ -1202,13 +1294,16 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() block = events[-1] - assert (len(block.block_hashes) == blocks_to_cache == 
len( - manager.block_pool.cached_block_hash_to_block)) + assert ( + len(block.block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 @@ -1218,16 +1313,19 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens)), block_size, hash) + req1 = make_request("1", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() for blocks in events[:-1]: assert blocks.block_hashes[0] in stored_block_hash assert len(events) == blocks_to_cache + 1 - assert (isinstance(events[-2], BlockRemoved)) - assert (len(events[-1].block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + assert isinstance(events[-2], BlockRemoved) + assert ( + len(events[-1].block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) # All Blocks Cleared # Should see a single all blocks cleared event @@ -1240,7 +1338,7 @@ def test_kv_cache_events(blocks_to_cache: int): def test_eagle_enabled_removes_last_block(): - """Verify Eagle does NOT remove blocks when request + """Verify Eagle does NOT remove blocks when request length is divisible by block size.""" block_size = 16 manager = KVCacheManager( @@ -1252,17 +1350,17 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids, block_size, hash) + req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) + req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1283,17 +1381,17 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1309,13 +1407,12 @@ def test_eagle_with_sliding_window(): head_size=1, dtype=torch.float32, sliding_window=block_size, - use_mla=False, ) manager = KVCacheManager( KVCacheConfig( 
num_blocks=10, kv_cache_tensors=[], - kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], + kv_cache_groups=[KVCacheGroupSpec(["layer"], sliding_window_spec)], ), max_model_len=8192, enable_caching=True, @@ -1324,37 +1421,118 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # record the block hash of the first block in the request for later use block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # Evict the first block in the request - assert manager.block_pool.get_cached_block( - block_hash_first_block, kv_cache_group_ids=[0]) is not None - manager.block_pool.cached_block_hash_to_block.pop( - BlockHashWithGroupId(block_hash_first_block, 0)) + assert ( + manager.block_pool.get_cached_block( + block_hash_first_block, kv_cache_group_ids=[0] + ) + is not None + ) + manager.block_pool.cached_block_hash_to_block._cache.pop( + make_block_hash_with_group_id(block_hash_first_block, 0) + ) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, hash) + req_after_evict = make_request( + "partial_eagle_after_evict", token_ids, block_size, sha256 + ) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. 
assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 + + +def test_block_lookup_cache_single_block_per_key(): +    cache = BlockHashToBlockMap() +    key0 = BlockHashWithGroupId(b"hash0") +    key1 = BlockHashWithGroupId(b"hash1") +    key2 = BlockHashWithGroupId(b"hash2") +    block0 = KVCacheBlock(0) +    block1 = KVCacheBlock(1) + +    assert cache.get_one_block(key0) is None +    assert cache.get_one_block(key1) is None +    assert cache.get_one_block(key2) is None +    # key0 inserted +    cache.insert(key0, block0) +    assert cache.get_one_block(key0) is block0 +    assert cache.get_one_block(key1) is None +    assert cache.get_one_block(key2) is None +    # key1 inserted +    cache.insert(key1, block1) +    assert cache.get_one_block(key0) is block0 +    assert cache.get_one_block(key1) is block1 +    assert cache.get_one_block(key2) is None +    # No block popped due to block_id mismatch +    assert cache.pop(key0, 100) is None +    assert cache.get_one_block(key0) is block0 +    assert cache.get_one_block(key1) is block1 +    assert cache.get_one_block(key2) is None +    # block popped with (key0, block ID 0) +    assert cache.pop(key0, 0) is block0 +    assert cache.get_one_block(key0) is None +    assert cache.get_one_block(key1) is block1 +    assert cache.get_one_block(key2) is None +    # No block popped due to block_id mismatch +    assert cache.pop(key0, 1) is None +    assert cache.get_one_block(key0) is None +    assert cache.get_one_block(key1) is block1 +    assert cache.get_one_block(key2) is None +    # block popped with (key1, block ID 1) +    assert cache.pop(key1, 1) is block1 +    assert cache.get_one_block(key0) is None +    assert cache.get_one_block(key1) is None +    assert cache.get_one_block(key2) is None + + +def test_block_lookup_cache_multi_blocks_per_key(): +    cache = BlockHashToBlockMap() +    key0 = BlockHashWithGroupId(b"hash0") +    key1 = BlockHashWithGroupId(b"hash1") +    block00 = KVCacheBlock(0) +    block01 = KVCacheBlock(1) +    block10 = KVCacheBlock(10) +    block11 = KVCacheBlock(11) + +    assert cache.get_one_block(key0) is None +    assert cache.get_one_block(key1) is None + +    cache.insert(key0, block00) +    cache.insert(key0, block01) +    cache.insert(key1, block10) +    cache.insert(key1, block11) + +    assert cache.get_one_block(key0) is block00 +    assert cache.pop(key0, 0) is block00 +    assert cache.get_one_block(key0) is block01 +    assert cache.pop(key0, 1) is block01 +    assert cache.get_one_block(key0) is None +    assert cache.pop(key0, 2) is None + +    assert cache.get_one_block(key1) is block10 +    assert cache.pop(key1, 10) is block10 +    assert cache.get_one_block(key1) is block11 +    assert cache.pop(key1, 11) is block11 +    assert cache.get_one_block(key1) is None +    assert cache.pop(key1, 12) is None diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index a2efbec0e610a..ff15af70b88bc 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,19 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses from typing import Optional from unittest.mock import Mock import pytest import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.config import ( +    CacheConfig, +    KVTransferConfig, +    ModelConfig, +    SchedulerConfig, +    SpeculativeConfig, +    VllmConfig, +) +from vllm.multimodal.inputs import ( +    MultiModalFeatureSpec, + 
MultiModalKwargsItem, + PlaceholderRange, +) +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -21,6 +35,8 @@ from vllm.v1.structured_output.request import StructuredOutputRequest from .utils import EOS_TOKEN_ID, create_requests, create_scheduler +pytestmark = pytest.mark.cpu_test + def test_add_requests(): scheduler = create_scheduler() @@ -39,8 +55,7 @@ def test_finish_request(): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) assert request.request_id not in scheduler.requests assert len(scheduler.waiting) == 9 - i @@ -52,23 +67,25 @@ def test_get_num_unfinished_requests(): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_STOPPED) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED) assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 -@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ - (None, None), - (True, 5), -]) -def test_schedule(enable_prefix_caching: Optional[bool], - prompt_logprobs: Optional[int]): - '''Test scheduling. +@pytest.mark.parametrize( + "enable_prefix_caching, prompt_logprobs", + [ + (None, None), + (True, 5), + ], +) +def test_schedule( + enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] +): + """Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs - ''' + """ scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) - requests = create_requests(num_requests=10, - prompt_logprobs=prompt_logprobs) + requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs) for request in requests: scheduler.add_request(request) @@ -90,8 +107,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], def test_schedule_multimodal_requests(): scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf") - mm_positions = [[PlaceholderRange(offset=i, length=100)] - for i in range(10)] + mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)] requests = create_requests( num_requests=10, num_tokens=200, @@ -124,8 +140,7 @@ def test_schedule_partial_requests(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, ) - mm_positions = [[PlaceholderRange(offset=100, length=600)] - for _ in range(3)] + mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)] requests = create_requests( num_requests=3, num_tokens=800, @@ -148,8 +163,10 @@ def test_schedule_partial_requests(): # The third request is also scheduled partially. # The <img> tokens are not scheduled because of the encoder budget. 
assert output.num_scheduled_tokens[requests[2].request_id] == 100 + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, # Only the first request has a sampled token id because # the rest requests are still being prefilled. sampled_token_ids=[[0], [], []], @@ -182,9 +199,9 @@ def test_no_mm_input_chunking(): max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] - requests = create_requests(num_requests=1, - num_tokens=1200, - mm_positions=mm_positions) + requests = create_requests( + num_requests=1, num_tokens=1200, mm_positions=mm_positions + ) for request in requests: scheduler.add_request(request) @@ -195,8 +212,10 @@ def test_no_mm_input_chunking(): # We want to only see the 400 text tokens at the start scheduled assert output.num_scheduled_tokens[requests[0].request_id] == 400 + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, @@ -253,8 +272,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): assert output.num_scheduled_tokens[requests[1].request_id] == 400 # The third request is also scheduled partially - 1024 - 400 - 400 = 224. assert output.num_scheduled_tokens[requests[2].request_id] == 224 + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, @@ -278,6 +299,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): # All the remaining tokens in the third request are processed. 
model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], + req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], logprobs=None, prompt_logprobs_dict={}, @@ -291,8 +313,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 - assert output2.num_scheduled_tokens[ - requests[2].request_id] == 800 - 224 - 224 + assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224 def test_stop_via_update_from_output(): @@ -310,30 +331,31 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, + num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [], - requests[1].request_id: [10] + requests[1].request_id: [10], }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, - grammar_bitmask=None) + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - sampled_token_ids=[[EOS_TOKEN_ID], - [10, - 11]], # First request hits EOS, second continues + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[ + [EOS_TOKEN_ID], + [10, 11], + ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -347,9 +369,7 @@ def test_stop_via_update_from_output(): # Test case 2: Stop on custom stop token scheduler = create_scheduler(num_speculative_tokens=2) - requests = create_requests(num_requests=2, - max_tokens=10, - stop_token_ids=[42, 43]) + requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req @@ -359,30 +379,28 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, + num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2}, total_num_scheduled_tokens=5, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 42], - requests[1].request_id: [13] + requests[1].request_id: [13], }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - sampled_token_ids=[[10, 42, 12], - [13, 14]], # First request hits stop token + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -407,30 +425,28 @@ def 
test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, + num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1}, total_num_scheduled_tokens=4, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 11], - requests[1].request_id: [] + requests[1].request_id: [], }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - sampled_token_ids=[[10, 11, 12], - [13]], # First request exceeds max_tokens + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -439,8 +455,7 @@ def test_stop_via_update_from_output(): assert scheduler.running[0].request_id == requests[1].request_id assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED assert requests[0].request_id in scheduler.finished_req_ids - assert list(requests[0].output_token_ids) == [10, 11 - ] # Truncated to max_tokens + assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens assert list(requests[1].output_token_ids) == [13] # Test case 4: Ignore EOS flag @@ -457,17 +472,17 @@ def test_stop_via_update_from_output(): num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, + scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, - grammar_bitmask=None) + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], logprobs=None, prompt_logprobs_dict={}, @@ -482,12 +497,106 @@ def test_stop_via_update_from_output(): assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] -@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ - (None, None), - (True, 5), -]) -def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], - prompt_logprobs: Optional[int]): +def test_check_stop_min_tokens(): + """Test that requests don't stop when min_tokens requirement isn't met.""" + from vllm.v1.core.sched.utils import check_stop + + # Test case 1: num_output_tokens < min_tokens + # Should return False (don't stop) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=5, + ) + request = Request( + request_id="0", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + # Simulate having generated 3 output tokens (less than min_tokens=5) + request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present + + result = check_stop(request, max_model_len=100) + assert result is False, "Should not stop when num_output_tokens<min_tokens" + + # Test case 2: num_output_tokens >= min_tokens + # Should follow normal 
stopping logic (stop on EOS) + request.append_output_token_ids( + [ + 10, + 11, + 12, + 13, + 14, + EOS_TOKEN_ID, + ] + ) # 6 tokens > min_tokens + + result = check_stop(request, max_model_len=100) + assert result is True, "Should stop on EOS when min_tokens met" + assert request.status == RequestStatus.FINISHED_STOPPED + + # Test case 3: min_tokens = 0, should follow normal stopping logic + sampling_params_no_min = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=0, + ) + request_no_min = Request( + request_id="1", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params_no_min, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + request_no_min.append_output_token_ids([10, EOS_TOKEN_ID]) + + result = check_stop(request_no_min, max_model_len=100) + assert result is True, "Should stop on EOS when min_tokens=0" + assert request_no_min.status == RequestStatus.FINISHED_STOPPED + + # Test case 4: min_tokens > 0 with stop token (not EOS) + sampling_params_stop = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=5, + stop_token_ids=[42], + ) + request_stop = Request( + request_id="2", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params_stop, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + # Only 3 output tokens, less than min_tokens=5, but has stop token + request_stop.append_output_token_ids([10, 11, 42]) + result = check_stop(request_stop, max_model_len=100) + assert result is False, "Should not stop when num_output_tokens<min_tokens" + + # Test case 5: min_tokens met, should stop on stop token + request_stop.append_output_token_ids( + [10, 11, 12, 13, 14, 42] + ) # 6 tokens >= min_tokens=5 + + result = check_stop(request_stop, max_model_len=100) + assert result is True, "Should stop on stop token when min_tokens met" + assert request_stop.status == RequestStatus.FINISHED_STOPPED + assert request_stop.stop_reason == 42 + + +@pytest.mark.parametrize( + "enable_prefix_caching, prompt_logprobs", + [ + (None, None), + (True, 5), + ], +) +def test_schedule_concurrent_batches( + enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] +): scheduler = create_scheduler( max_num_batched_tokens=1024, max_num_seqs=2, @@ -503,19 +612,18 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], scheduler.add_request(requests[0]) scheduler_output0 = scheduler.schedule() assert len(scheduler_output0.scheduled_new_reqs) == 1 - assert scheduler_output0.num_scheduled_tokens[ - requests[0].request_id] == 512 + assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512 # The first request is still running, so only schedule the second request. scheduler.add_request(requests[1]) scheduler_output1 = scheduler.schedule() assert len(scheduler_output1.scheduled_new_reqs) == 1 - assert scheduler_output1.num_scheduled_tokens[ - requests[1].request_id] == 512 + assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512 # Model output of the first request. model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, @@ -532,6 +640,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], # Model output of the second request. 
model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], + req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, @@ -543,10 +652,12 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. - scheduler = create_scheduler(max_num_batched_tokens=100, - block_size=16, - num_blocks=11, - enable_prefix_caching=False) + scheduler = create_scheduler( + max_num_batched_tokens=100, + block_size=16, + num_blocks=11, + enable_prefix_caching=False, + ) requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. @@ -566,6 +677,7 @@ def test_preempt_during_execution(): # Get the output of the first request. model_runner_output0 = ModelRunnerOutput( req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, @@ -582,6 +694,7 @@ def test_preempt_during_execution(): model_runner_output1 = ModelRunnerOutput( req_ids=[requests[1].request_id], + req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], logprobs=None, prompt_logprobs_dict={}, @@ -601,13 +714,16 @@ def test_preempt_during_execution(): [ ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], - (2, 3, 3, [2, 1])), # multiple sequences + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence ([[]], [[5]], (0, 0, 0, [0])), # empty sequence - ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], - (2, 6, 3, [2, 1, 0])), # multiple mismatches - ]) + ( + [[1, 2, 3], [4, 5, 6]], + [[1, 2, 7], [4, 8]], + (2, 6, 3, [2, 1, 0]), + ), # multiple mismatches + ], +) def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): """Test scheduling behavior with speculative decoding. 
@@ -619,9 +735,11 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) req_ids = [] + req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) + req_to_index[request.request_id] = i # Schedule a decode, which will also draft speculative tokens output = scheduler.schedule() @@ -634,13 +752,13 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): model_runner_output = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) - engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) draft_token_ids = DraftTokenIds(req_ids, spec_tokens) scheduler.update_draft_token_ids(draft_token_ids) @@ -655,38 +773,44 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): # No draft or accepted tokens counted yet assert not engine_core_outputs or ( - engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None) + engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None + ) # Schedule the speculated tokens for validation output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 0 # The sampled token and speculated tokens - assert output.total_num_scheduled_tokens == \ - len(requests) + sum(len(ids) for ids in spec_tokens) + assert output.total_num_scheduled_tokens == len(requests) + sum( + len(ids) for ids in spec_tokens + ) for i in range(len(requests)): req_id = requests[i].request_id assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i]) if spec_tokens[i]: - assert len(output.scheduled_spec_decode_tokens[req_id]) == \ - len(spec_tokens[i]) + assert len(output.scheduled_spec_decode_tokens[req_id]) == len( + spec_tokens[i] + ) else: assert req_id not in output.scheduled_spec_decode_tokens model_runner_output = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=output_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) - engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) - scheduler_stats = engine_core_outputs[0].scheduler_stats \ - if engine_core_outputs else None + scheduler_stats = ( + engine_core_outputs[0].scheduler_stats if engine_core_outputs else None + ) if expected[0] == 0: + assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is None else: + assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is not None stats = scheduler_stats.spec_decoding_stats assert stats.num_drafts == expected[0] @@ -723,18 +847,25 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req in requests: - blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req.request_id]) + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks[req.request_id] hashes = req.block_hashes - assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
- num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) + assert ( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block[req.request_id] + == EXPECTED_TOTAL_BLOCKS + ) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. BLOCKS_PER_REQ = num_tokens / block_size - assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == - num_total_blocks - num_requests * BLOCKS_PER_REQ) + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() + == num_total_blocks - num_requests * BLOCKS_PER_REQ + ) def _step_until_done( @@ -773,32 +904,38 @@ def test_kv_connector_basic(): enable_prefix_caching=True, use_kv_connector=True, ) - NUM_TOTAL_BLOCKS = ( - scheduler.kv_cache_manager.block_pool.get_num_free_blocks()) + NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks() BLOCK_SIZE = scheduler.cache_config.block_size # Mock External Cache Hit. NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) ###################################################### # FIRST SET OF REQUESTS - External Hit Only NUM_REQUESTS = 2 NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2 MAX_TOKENS = 3 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] + req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) + req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, @@ -815,15 +952,17 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, - NUM_REQUESTS, NUM_TOTAL_BLOCKS) + _assert_right_kv_cache_manager( + scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS + ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_TOTAL_BLOCKS + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS + ) ###################################################### # SECOND SET OF REQUESTS - Local And External Hit @@ -831,17 +970,22 @@ def test_kv_connector_basic(): # We will get a local prefix cache hit for the first # NUM_TOKENS_PREFIX tokens since they are used above. 
NUM_TOKENS = NUM_TOKENS_PREFIX * 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] + req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) + req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, @@ -855,19 +999,23 @@ def test_kv_connector_basic(): output=output, num_requests=NUM_REQUESTS, # Just the incremental tokens after local + remote cache hit. - expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX - - NUM_MATCHED_NEW_TOKENS)) + expected_num_scheduled_tokens=( + NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS + ), + ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, - NUM_REQUESTS, NUM_TOTAL_BLOCKS) + _assert_right_kv_cache_manager( + scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS + ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_TOTAL_BLOCKS + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS + ) def test_kv_connector_unable_to_allocate(): @@ -888,24 +1036,31 @@ def test_kv_connector_unable_to_allocate(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. NUM_REQUESTS = 2 NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE MAX_TOKENS = 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] + req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) + req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, @@ -914,33 +1069,33 @@ def test_kv_connector_unable_to_allocate(): # Just one request should be running. output = scheduler.schedule() - _assert_right_scheduler_output(output, - num_requests=1, - expected_num_scheduled_tokens=NUM_TOKENS - - NUM_MATCHED_NEW_TOKENS) + _assert_right_scheduler_output( + output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # All memory should be freed, with one request waiting. 
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 # Just one request should be running. output = scheduler.schedule() - _assert_right_scheduler_output(output, - num_requests=1, - expected_num_scheduled_tokens=NUM_TOKENS - - NUM_MATCHED_NEW_TOKENS) + _assert_right_scheduler_output( + output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # All memory should be freed, with no requests waiting / running. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 0 @@ -965,7 +1120,9 @@ def test_kv_connector_handles_preemption(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) # Create two requests. # Both can be scheduled at first, but the second request @@ -973,17 +1130,22 @@ def test_kv_connector_handles_preemption(): NUM_REQUESTS = 2 NUM_TOKENS = BLOCK_SIZE * 2 + 1 MAX_TOKENS = BLOCK_SIZE * 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] + req_to_index = {} for i, request in enumerate(requests): scheduler.add_request(request) req_ids.append(request.request_id) + req_to_index[request.request_id] = i MODEL_RUNNER_OUTPUT = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), logprobs=None, prompt_logprobs_dict={}, @@ -996,7 +1158,8 @@ def test_kv_connector_handles_preemption(): output, # 2 remote kv cache hits. 
num_requests=2, - expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1006,7 +1169,8 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1016,7 +1180,8 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1029,14 +1194,14 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 # Restarts the preempted request - generate 3rd token. # This will have a local and remote cache hit. @@ -1061,18 +1226,19 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], + req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, sampled_token_ids=[[1000]] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, @@ -1093,14 +1259,24 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks + ) + == 0 + ) + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block + ) + == 0 + ) num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. 
@@ -1120,9 +1296,9 @@ def test_memory_leak(): NUM_REQUESTS = 5 NUM_TOKENS = 10 MAX_TOKENS = 10 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + requests = create_requests( + num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS + ) # Add each request. for request in requests: @@ -1156,7 +1332,7 @@ def create_scheduler_with_priority( max_model_len: Optional[int] = None, num_speculative_tokens: Optional[int] = None, ) -> Scheduler: - '''Create scheduler with priority policy enabled. + """Create scheduler with priority policy enabled. Args: model: model under test @@ -1168,7 +1344,7 @@ def create_scheduler_with_priority( Returns: {class}`Scheduler` instance with priority scheduling - ''' + """ if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -1187,9 +1363,11 @@ def create_scheduler_with_priority( seed=42, ) # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) + kwargs_cache = ( + {} + if enable_prefix_caching is None + else {"enable_prefix_caching": enable_prefix_caching} + ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, @@ -1197,16 +1375,21 @@ def create_scheduler_with_priority( cache_dtype="auto", **kwargs_cache, ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) speculative_config: Optional[SpeculativeConfig] = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=num_speculative_tokens) + model="ngram", num_speculative_tokens=num_speculative_tokens + ) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -1219,9 +1402,9 @@ def create_scheduler_with_priority( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) cache_config.num_gpu_blocks = num_blocks @@ -1230,18 +1413,21 @@ def create_scheduler_with_priority( kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, ) def create_requests_with_priority( - num_requests: int, - priorities: list[int], - arrival_times: Optional[list[float]] = None, - num_tokens: int = 10, - mm_positions: Optional[list[list[PlaceholderRange]]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): + num_requests: int, + priorities: list[int], + arrival_times: Optional[list[float]] = None, + num_tokens: int = 10, + mm_positions: Optional[list[list[PlaceholderRange]]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None, + starting_idx: int = 0, +): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests if arrival_times is not None: @@ -1249,27 +1435,33 
@@ def create_requests_with_priority( else: arrival_times = [float(i) for i in range(num_requests)] - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs, + ) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_position) - else: - mm_position = None - mm_kwargs = None + for j, position in enumerate(mm_position): + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) + request = Request( - request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, + request_id=f"{i + starting_idx}", + prompt_token_ids=[i + starting_idx] * num_tokens, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=mm_kwargs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, arrival_time=arrival_times[i], priority=priorities[i], @@ -1287,9 +1479,9 @@ def test_priority_scheduling_basic_ordering(): # Priority 0 (highest), 1, 2 (lowest) priorities = [2, 0, 1] # Add in non-priority order arrival_times = [1.0, 2.0, 3.0] # All different arrival times - requests = create_requests_with_priority(num_requests=3, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=3, priorities=priorities, arrival_times=arrival_times + ) # Add requests in non-priority order for request in requests: @@ -1315,9 +1507,9 @@ def test_priority_scheduling_arrival_time_tiebreaker(): # Create requests with same priority but different arrival times priorities = [1, 1, 1] # All same priority arrival_times = [3.0, 1.0, 2.0] # Different arrival times - requests = create_requests_with_priority(num_requests=3, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=3, priorities=priorities, arrival_times=arrival_times + ) # Add requests in non-arrival order for request in requests: @@ -1342,9 +1534,9 @@ def test_priority_scheduling_mixed_priority_and_arrival(): # Create requests with mixed priorities and arrival times priorities = [2, 1, 1, 0] # Mixed priorities arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times - requests = create_requests_with_priority(num_requests=4, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=4, priorities=priorities, arrival_times=arrival_times + ) # Add requests for request in requests: @@ -1381,7 +1573,7 @@ def test_priority_scheduling_preemption(): num_requests=2, priorities=[5, 5], # Low priority arrival_times=[1.0, 2.0], - num_tokens=30 # Large enough to consume significant memory + num_tokens=30, # Large enough to consume significant memory ) # Add and schedule low priority requests @@ -1394,6 +1586,9 @@ def test_priority_scheduling_preemption(): # Simulate model execution to move requests to running state model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], + req_id_to_index={ + req.request_id: i for i, req in 
enumerate(low_priority_requests) + }, sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, @@ -1410,7 +1605,7 @@ def test_priority_scheduling_preemption(): num_requests=1, priorities=[0], # High priority arrival_times=[3.0], - num_tokens=30 # Large enough to require significant memory + num_tokens=30, # Large enough to require significant memory )[0] scheduler.add_request(high_priority_request) @@ -1451,10 +1646,8 @@ def test_priority_scheduling_no_preemption_when_space_available(): # Add two low-priority running requests low_priority_requests = create_requests_with_priority( - num_requests=2, - priorities=[5, 5], - arrival_times=[1.0, 2.0], - num_tokens=30) + num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30 + ) for request in low_priority_requests: scheduler.add_request(request) @@ -1462,6 +1655,9 @@ def test_priority_scheduling_no_preemption_when_space_available(): output = scheduler.schedule() model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], + req_id_to_index={ + req.request_id: i for i, req in enumerate(low_priority_requests) + }, sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, prompt_logprobs_dict={}, @@ -1470,10 +1666,9 @@ def test_priority_scheduling_no_preemption_when_space_available(): scheduler.update_from_output(output, model_output) # Add high-priority request - high_priority_request = create_requests_with_priority(num_requests=1, - priorities=[0], - arrival_times=[3.0], - num_tokens=30)[0] + high_priority_request = create_requests_with_priority( + num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30 + )[0] scheduler.add_request(high_priority_request) @@ -1501,7 +1696,8 @@ def test_priority_scheduling_preemption_victim_selection(): num_requests=3, priorities=[3, 2, 0], # Different priorities: low, medium, high arrival_times=[1.0, 2.0, 3.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1540,7 +1736,8 @@ def test_priority_scheduling_equal_priority_preemption(): num_requests=3, priorities=[2, 2, 2], # Same priority arrival_times=[3.0, 1.0, 2.0], # Different arrival times - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1576,7 +1773,8 @@ def test_priority_scheduling_waiting_queue_order(): num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1611,9 +1809,9 @@ def test_priority_scheduling_fcfs_fallback(): # Create requests with same priority but different arrival times priorities = [1, 1, 1, 1] # All same priority arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times - requests = create_requests_with_priority(num_requests=4, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=4, priorities=priorities, arrival_times=arrival_times + ) # Add requests for request in requests: @@ -1643,7 +1841,8 @@ def test_priority_scheduling_with_limited_slots(): num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1681,10 +1880,12 @@ def test_priority_scheduling_heap_property(): # Add requests in random priority order priorities = [5, 1, 8, 3, 2, 7, 4, 6] arrival_times = [float(i) for i in range(len(priorities))] - requests = 
create_requests_with_priority(num_requests=len(priorities), - priorities=priorities, - arrival_times=arrival_times, - num_tokens=10) + requests = create_requests_with_priority( + num_requests=len(priorities), + priorities=priorities, + arrival_times=arrival_times, + num_tokens=10, + ) # Add all requests for request in requests: @@ -1702,6 +1903,7 @@ def test_priority_scheduling_heap_property(): # Simulate completion to make room for next request model_output = ModelRunnerOutput( req_ids=[req.req_id], + req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], logprobs=None, prompt_logprobs_dict={}, @@ -1710,8 +1912,7 @@ def test_priority_scheduling_heap_property(): scheduler.update_from_output(output, model_output) # Finish the request to make room for the next one - scheduler.finish_requests(req.req_id, - RequestStatus.FINISHED_STOPPED) + scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED) # Verify requests were scheduled in priority order (lowest value first) expected_priorities = sorted(priorities) @@ -1730,18 +1931,16 @@ def test_schedule_skip_tokenizer_init(): def test_schedule_skip_tokenizer_init_structured_output_request(): scheduler = create_scheduler(skip_tokenizer_init=True) - guided_params = GuidedDecodingParams(regex="[0-9]+") + structured_outputs_params = StructuredOutputsParams(regex="[0-9]+") sampling_params = SamplingParams( ignore_eos=False, max_tokens=16, - guided_decoding=guided_params, + structured_outputs=structured_outputs_params, ) request = Request( request_id="0", prompt_token_ids=[0, 1], - multi_modal_kwargs=None, - multi_modal_hashes=None, - multi_modal_placeholders=None, + mm_features=None, sampling_params=sampling_params, pooling_params=None, eos_token_id=EOS_TOKEN_ID, @@ -1752,3 +1951,174 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 + + +def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): + """Test that priority scheduling preempts lower priority requests + when out of KV cache space.""" + # Create scheduler with very limited memory to force preemption + scheduler = create_scheduler_with_priority( + max_num_seqs=2, # Allow multiple requests + max_num_batched_tokens=200, + num_blocks=5, # Can hold 64 tokens (first block is null) + block_size=16, # Standard block size + use_kv_connector=True, + ) + + # Create a request and schedule it + request_low = create_requests_with_priority( + num_requests=1, + priorities=[1], + arrival_times=[0.0], + num_tokens=30, + starting_idx=0, + )[0] + scheduler.add_request(request_low) + # 1st schedule + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Simulate model execution - 1st decode + model_output = ModelRunnerOutput( + req_ids=[request_low.request_id], + req_id_to_index={request_low.request_id: 0}, + sampled_token_ids=[[100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Create a high priority request and schedule it + request_high = create_requests_with_priority( + num_requests=1, + priorities=[0], + arrival_times=[1.0], + num_tokens=32, + starting_idx=1, + )[0] + scheduler.add_request(request_high) + # 2nd schedule + output = scheduler.schedule() + # KV cache should be full at this point + assert 
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 2 + + # Simulate model execution - 2nd decode + requests = [request_low, request_high] + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[100] for _ in requests], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # 3rd schedule - this should trigger preemption + # req_low needs 32 tokens = 2 blocks + # req_high needs 33 tokens = 3 blocks + # so doesn't fit in 4 blocks. + output = scheduler.schedule() + + # Should have preempted req_low + assert len(output.scheduled_new_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id + assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + + # Simulate model execution - 3rd decode + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[], [100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + # Finish the requests to make room for the preempted requests to resume + scheduler.update_from_output(output, model_output) + scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED) + + # 4th Schedule - this should trigger the resumption + output = scheduler.schedule() + scheduled_cached_reqs = output.scheduled_cached_reqs + resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption + + assert len(output.scheduled_new_reqs) == 0 + assert scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Preempted request resumed in scheduled_cached_reqs + assert len(resumed_from_preemption) == 1 + assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1 + assert resumed_from_preemption[0] + assert scheduled_cached_reqs.req_ids[0] == request_low.request_id + assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None + # Resumed tokens include 30 prompt tokens and 2 decoded tokens + assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32 + assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100 + + +@pytest.mark.parametrize( + ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"), + [ + (True, False, True), + (False, False, False), + # Encoder-decoder models should always have it disabled + (False, True, False), + (True, True, False), + ], +) +def test_chunked_prefill_disabled_for_encoder_decoder( + enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool +) -> None: + """Validate that chunked prefill is appropriately disabled for + encoder-decoder models.""" + scheduler_config = SchedulerConfig( + enable_chunked_prefill=enable_chunked_prefill, + is_encoder_decoder=is_encoder_decoder, + ) + + # `is_encoder_decoder` should only be used during construction + # of the config, and otherwise stored in the model config. 
+ assert "is_encoder_decoder" not in vars(scheduler_config) + assert "is_encoder_decoder" not in [ + f.name for f in dataclasses.fields(scheduler_config) + ] + _validate_chunked_prefill_settings_for_encoder_decoder( + scheduler_config, is_encoder_decoder, expect_enabled + ) + + # Ensure it is retained in VllmConfig, even after its post-init. + vllm_config = VllmConfig(scheduler_config=scheduler_config) + _validate_chunked_prefill_settings_for_encoder_decoder( + vllm_config.scheduler_config, is_encoder_decoder, expect_enabled + ) + + +def _validate_chunked_prefill_settings_for_encoder_decoder( + scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool +) -> None: + """Validate chunked prefill settings in the scheduler config for + encoder-decoder models.""" + assert scheduler_config.chunked_prefill_enabled is expect_enabled + assert scheduler_config.enable_chunked_prefill is expect_enabled + if is_encoder_decoder: + # Encoder-decoder models should automatically disable chunked multimodal + # inputs as well + assert scheduler_config.disable_chunked_mm_input is not expect_enabled + if is_encoder_decoder and not expect_enabled: + assert scheduler_config.long_prefill_token_threshold == 0 diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index bd0320baef871..90f8757ae4939 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -1,27 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import pytest from vllm import LLM -if os.getenv("VLLM_USE_V1", "0") != "1": - pytest.skip("Test package requires V1", allow_module_level=True) - MODEL = "meta-llama/Llama-3.2-1B" PROMPT = "Hello my name is Robert and I" @pytest.fixture(scope="module") def llm() -> LLM: - return LLM(MODEL, - enforce_eager=True, - enable_prefix_caching=True, - long_prefill_token_threshold=2, - max_num_batched_tokens=6, - max_num_seqs=3, - block_size=16) + return LLM( + MODEL, + enforce_eager=True, + enable_prefix_caching=True, + long_prefill_token_threshold=2, + max_num_batched_tokens=6, + max_num_seqs=3, + block_size=16, + ) def test_concurrent_partial_prefill(llm): diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 7dcebba491fab..a27f32938c08b 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -3,28 +3,32 @@ import random +import pytest import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + KVCacheBlock, + make_block_hash_with_group_id, +) from vllm.v1.core.single_type_kv_cache_manager import ( - ChunkedLocalAttentionManager, SlidingWindowManager) -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - SlidingWindowSpec) + ChunkedLocalAttentionManager, + SlidingWindowManager, +) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowSpec + +pytestmark = pytest.mark.cpu_test def get_sliding_window_manager(sliding_window_spec, block_pool): - return SlidingWindowManager(sliding_window_spec, - block_pool, - kv_cache_group_id=0) + return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0) -def get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool): - return 
ChunkedLocalAttentionManager(chunked_local_attention_spec, - block_pool, - kv_cache_group_id=0) +def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): + return ChunkedLocalAttentionManager( + chunked_local_attention_spec, block_pool, kv_cache_group_id=0 + ) def test_chunked_local_attention_possible_cached_prefix(): @@ -35,28 +39,29 @@ def test_chunked_local_attention_possible_cached_prefix(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) - manager = get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool) + manager = get_chunked_local_attention_manager( + chunked_local_attention_spec, block_pool + ) def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -64,11 +69,14 @@ def test_chunked_local_attention_possible_cached_prefix(): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=chunked_local_attention_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block - for block in computed_blocks[:(expect_length - 1) // 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: (expect_length - 1) // 2] + ) run_one_case([True], 0, 1) run_one_case([True], 1, 1) @@ -101,7 +109,6 @@ def test_sliding_window_possible_cached_prefix(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -109,19 +116,20 @@ def test_sliding_window_possible_cached_prefix(): def run_one_case(block_is_cached, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { - i: block_pool.blocks[i + 10], - } + block_pool.cached_block_hash_to_block.insert( + make_block_hash_with_group_id(block_hash, 0), + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -129,16 +137,18 @@ def test_sliding_window_possible_cached_prefix(): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=sliding_window_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - 
assert all(block == block_pool.null_block - for block in computed_blocks[:expect_length - 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: expect_length - 2] + ) for i in range(2): if i < expect_length: block_index = expect_length - i - 1 - assert computed_blocks[ - block_index].block_id == block_index + 10 + assert computed_blocks[block_index].block_id == block_index + 10 run_one_case([False] * 10, 0) run_one_case([True], 1) @@ -147,17 +157,16 @@ def test_sliding_window_possible_cached_prefix(): run_one_case([True, True, False], 2) run_one_case([True, True, True], 3) run_one_case([True, True, True, False], 3) - run_one_case([ - True, True, False, True, False, False, True, True, False, True, True, - True - ], 12) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False - ], 8) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False, - True - ], 8) + run_one_case( + [True, True, False, True, False, False, True, True, False, True, True, True], 12 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False], 8 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False, True], + 8, + ) def test_chunked_local_attention_remove_skipped_blocks(): @@ -167,7 +176,6 @@ def test_chunked_local_attention_remove_skipped_blocks(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -178,8 +186,8 @@ def test_chunked_local_attention_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -190,7 +198,17 @@ def test_chunked_local_attention_remove_skipped_blocks(): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -219,7 +237,6 @@ def test_sliding_window_remove_skipped_blocks(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -230,8 +247,8 @@ def test_sliding_window_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -242,7 +259,17 @@ def test_sliding_window_remove_skipped_blocks(): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -287,19 +314,21 @@ def test_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) manager = 
get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) def test_chunked_local_attention_get_num_blocks_to_allocate(): @@ -310,16 +339,18 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, attention_chunk_size=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 849c3f59ae527..c11cf3e817d19 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -4,16 +4,29 @@ from typing import Optional, Union import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.config import ( + CacheConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.utils import sha256 +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -35,7 +48,7 @@ def create_scheduler( skip_tokenizer_init: bool = False, async_scheduling: bool = False, ) -> Union[Scheduler, AsyncScheduler]: - '''Create scheduler under test. + """Create scheduler under test. 
Args: model: model under test @@ -47,7 +60,7 @@ def create_scheduler( Returns: {class}`Scheduler` instance - ''' + """ if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -67,9 +80,11 @@ def create_scheduler( skip_tokenizer_init=skip_tokenizer_init, ) # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) + kwargs_cache = ( + {} + if enable_prefix_caching is None + else {"enable_prefix_caching": enable_prefix_caching} + ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, @@ -77,16 +92,21 @@ def create_scheduler( cache_dtype="auto", **kwargs_cache, ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) speculative_config: Optional[SpeculativeConfig] = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=num_speculative_tokens) + model="ngram", num_speculative_tokens=num_speculative_tokens + ) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -99,9 +119,9 @@ def create_scheduler( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) cache_config.num_gpu_blocks = num_blocks @@ -109,6 +129,7 @@ def create_scheduler( return scheduler_cls( vllm_config=vllm_config, kv_cache_config=kv_cache_config, + block_size=block_size, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), ) @@ -129,35 +150,40 @@ def create_requests( ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) + block_hasher = get_request_block_hasher(block_size, sha256) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs, + ) requests = [] for i in range(num_requests): + mm_features = [] if mm_positions is not None: mm_position = mm_positions[i] - mm_item = MultiModalKwargsItem.dummy("dummy_m") - mm_kwargs = [mm_item] * len(mm_position) - mm_hashes = ["hash"] * len(mm_position) - else: - mm_position = None - mm_kwargs = None - mm_hashes = None - prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * - num_tokens) + for j, position in enumerate(mm_position): + # Dummy hash for each mm item should be unique + # since encoder cache tracks entries by hash + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) + + prompt_token_ids = [0] * num_tokens if same_prompt else [i] * 
num_tokens request = Request( request_id=f"{i}", prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=mm_kwargs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=mm_hashes, + mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, block_hasher=block_hasher, ) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 64f2fa462802f..59841a446db3e 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -9,8 +9,14 @@ import torch.nn as nn from tests.utils import create_new_process_for_each_test from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.platforms import current_platform from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -18,7 +24,6 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # Helper MLP for testing class SimpleMLP(nn.Module): - def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 10) @@ -28,8 +33,9 @@ class SimpleMLP(nn.Module): return self.fc2(self.fc1(x)) -def _create_vllm_config(compilation_config: CompilationConfig, - max_num_seqs: int = 8) -> MagicMock: +def _create_vllm_config( + compilation_config: CompilationConfig, max_num_seqs: int = 8 +) -> MagicMock: mock_config = MagicMock(spec=VllmConfig) mock_config.compilation_config = compilation_config mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) @@ -43,54 +49,39 @@ def _create_vllm_config(compilation_config: CompilationConfig, class TestCudagraphDispatcher: - @pytest.mark.parametrize( - "params", + "case_id,cudagraph_mode_str,compilation_level", [ # Test case 0: Full CG for mixed batches, no separate routine - { - "case_id": 0, - "cudagraph_mode": "FULL", - "compilation_level": CompilationLevel.NO_COMPILATION, - }, + (0, "FULL", CompilationLevel.NO_COMPILATION), # Test case 1: Full CG for uniform batches, piecewise for mixed - { - "case_id": 1, - "cudagraph_mode": "FULL_AND_PIECEWISE", - "compilation_level": CompilationLevel.PIECEWISE, - }, + (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION), # Test case 2: Full CG for uniform batches, no CG for mixed - { - "case_id": 2, - "cudagraph_mode": "FULL_DECODE_ONLY", - "compilation_level": CompilationLevel.NO_COMPILATION, - }, + (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION), # Test case 3: Piecewise for all - { - "case_id": 3, - "cudagraph_mode": "PIECEWISE", - "compilation_level": CompilationLevel.PIECEWISE, - }, - ]) - def test_dispatcher(self, params): + (3, "PIECEWISE", CompilationLevel.PIECEWISE), + ], + ) + def test_dispatcher(self, case_id, cudagraph_mode_str, compilation_level): # Setup dispatcher comp_config = CompilationConfig( - cudagraph_mode=params["cudagraph_mode"], - level=params["compilation_level"], - cudagraph_capture_sizes=[1, 8]) + cudagraph_mode=cudagraph_mode_str, + level=compilation_level, + cudagraph_capture_sizes=[1, 8], + ) config = _create_vllm_config(comp_config, max_num_seqs=8) dispatcher = CudagraphDispatcher(config) dispatcher.initialize_cudagraph_keys( -
cudagraph_mode=comp_config.cudagraph_mode, - uniform_decode_query_len=1) + cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1 + ) # Verify the key is initialized correctly - if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 - if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]: + if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 @@ -99,10 +90,10 @@ class TestCudagraphDispatcher: # 1. non-uniform batch, size in cudagraph size list desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) rt_mode, key = dispatcher.dispatch(desc_full_exact) - if params["cudagraph_mode"] == "FULL": + if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL assert key == desc_full_exact - elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]: + elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact else: @@ -111,15 +102,13 @@ class TestCudagraphDispatcher: # 2. uniform decode batch, size in cudagraph size list desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True) rt_mode, key = dispatcher.dispatch(desc_uniform_exact) - if params["cudagraph_mode"] == "FULL": + if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact.non_uniform - elif params["cudagraph_mode"] in [ - "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE" - ]: + elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]: assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact - elif params["cudagraph_mode"] == "PIECEWISE": + elif cudagraph_mode_str == "PIECEWISE": assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_uniform_exact.non_uniform else: @@ -131,10 +120,18 @@ class TestCudagraphDispatcher: assert rt_mode == CUDAGraphMode.NONE assert key is None + # 4. Cascade attention should have a fall back mode + desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) + rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True) + if "PIECEWISE" in cudagraph_mode_str: # string contains check + assert rt_mode == CUDAGraphMode.PIECEWISE + assert key == desc_full_exact.non_uniform + else: + assert rt_mode == CUDAGraphMode.NONE + @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: - def setup_method(self): self.vllm_config = _create_vllm_config(CompilationConfig()) self.model = SimpleMLP().to("cuda") @@ -143,26 +140,30 @@ class TestCUDAGraphWrapper: @create_new_process_for_each_test("spawn") def test_capture_and_replay(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) # 0. 
global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): wrapper(self.input_tensor) # 1. Capture - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.FULL, - batch_descriptor=batch_descriptor),\ - patch("torch.cuda.graph", - wraps=torch.cuda.graph) as mock_cuda_graph: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + ): output1 = wrapper(self.input_tensor) # capturing phase should generate a zero output assert torch.allclose(output1, torch.zeros_like(output1)) @@ -173,13 +174,17 @@ class TestCUDAGraphWrapper: assert entry.cudagraph is not None # 2. Replay - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.FULL, - batch_descriptor=batch_descriptor),\ - patch.object(entry.cudagraph, 'replay', - wraps=entry.cudagraph.replay) as mock_replay: + batch_descriptor=batch_descriptor, + ), + patch.object( + entry.cudagraph, "replay", wraps=entry.cudagraph.replay + ) as mock_replay, + ): output2 = wrapper(self.input_tensor) mock_replay.assert_called_once() @@ -189,20 +194,23 @@ class TestCUDAGraphWrapper: @create_new_process_for_each_test("spawn") def test_bypass_on_mode_mismatch(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=batch_descriptor), \ - patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_cuda_graph, \ - patch.object(self.model, 'forward', - wraps=self.model.forward) as mock_forward: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + patch.object( + self.model, "forward", wraps=self.model.forward + ) as mock_forward, + ): wrapper(self.input_tensor) mock_cuda_graph.assert_not_called() mock_forward.assert_called_once() @@ -210,18 +218,20 @@ class TestCUDAGraphWrapper: @create_new_process_for_each_test("spawn") def test_bypass_on_mode_none(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=batch_descriptor), \ - patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_cuda_graph: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + ): wrapper(self.input_tensor) mock_cuda_graph.assert_not_called() assert not wrapper.concrete_cudagraph_entries @@ -229,38 +239,44 @@ class TestCUDAGraphWrapper: @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCudagraphIntegration: - def 
setup_method(self): # only FULL mode for non-uniform batches - self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE, - cudagraph_mode="FULL", - cudagraph_capture_sizes=[10, 20]) + self.comp_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + cudagraph_mode="FULL", + cudagraph_capture_sizes=[10, 20], + ) self.vllm_config = _create_vllm_config(self.comp_config) self.dispatcher = CudagraphDispatcher(self.vllm_config) self.dispatcher.initialize_cudagraph_keys( - self.comp_config.cudagraph_mode, uniform_decode_query_len=1) + self.comp_config.cudagraph_mode, uniform_decode_query_len=1 + ) - def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode, - batch_descriptor): + def _run_and_monitor_call( + self, wrapper, input_tensor, runtime_mode, batch_descriptor + ): """Helper to run a single call and monitor the action.""" - with patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_graph_context, \ - patch.object(wrapper, 'runnable', - wraps=wrapper.runnable) as mock_runnable: + with ( + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context, + patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable, + ): + entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None) - entry = wrapper.concrete_cudagraph_entries.get( - batch_descriptor, None) - - context = set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=runtime_mode, - batch_descriptor=batch_descriptor) + context = set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=runtime_mode, + batch_descriptor=batch_descriptor, + ) mock_replay = MagicMock() if entry and entry.cudagraph: - with context, \ - patch.object(entry.cudagraph, 'replay', - new_callable=MagicMock) as mock_replay: + with ( + context, + patch.object( + entry.cudagraph, "replay", new_callable=MagicMock + ) as mock_replay, + ): wrapper(input_tensor) else: with context: @@ -281,8 +297,7 @@ class TestCudagraphIntegration: @create_new_process_for_each_test("spawn") def test_capture_replay_bypass_logic(self): model = SimpleMLP().to("cuda") - full_wrapper = CUDAGraphWrapper(model, self.vllm_config, - CUDAGraphMode.FULL) + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) max_bs = 16 persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda") input_1 = persistent_input_buffer[:1] @@ -294,75 +309,79 @@ class TestCudagraphIntegration: desc_3_unseen = BatchDescriptor(num_tokens=3) # 0. global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): full_wrapper(input_1) rt_mode, key = self.dispatcher.dispatch(desc_1) # 1. Capture first shape - action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "capture_global" # 2. Replay first shape - action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "replay" rt_mode, key = self.dispatcher.dispatch(desc_2) # 3. 
Capture second shape - action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key) assert action == "capture_global" # 4. Replay second shape - action = self._run_and_monitor_call(full_wrapper, input_2, - CUDAGraphMode.FULL, desc_2) + action = self._run_and_monitor_call( + full_wrapper, input_2, CUDAGraphMode.FULL, desc_2 + ) assert action == "replay" # 5. Bypass if no key match rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) assert rt_mode == CUDAGraphMode.NONE - action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key) assert action == "bypass" # capture unseen shape is not allowed after disable set_cudagraph_capturing_enabled(False) with pytest.raises(RuntimeError): - self._run_and_monitor_call(full_wrapper, input_3, - CUDAGraphMode.FULL, desc_3_unseen) + self._run_and_monitor_call( + full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen + ) set_cudagraph_capturing_enabled(True) @create_new_process_for_each_test("spawn") def test_nested_wrappers(self): """Tests a scenario with a PIECEWISE wrapper inside a FULL one.""" model = SimpleMLP().to("cuda") - full_wrapper = CUDAGraphWrapper(model, self.vllm_config, - CUDAGraphMode.FULL) + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) input_1 = torch.randn(1, 10, device="cuda") # Setup: Inner model is wrapped with PIECEWISE, outer with FULL inner_model = SimpleMLP().to("cuda") - piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config, - CUDAGraphMode.PIECEWISE) + piecewise_wrapper = CUDAGraphWrapper( + inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE + ) inner_model.forward = MagicMock(wraps=inner_model.forward) outer_model = SimpleMLP().to("cuda") # When outer model is called, it calls the piecewise_wrapper - outer_model.forward = MagicMock(wraps=outer_model.forward, - side_effect=piecewise_wrapper) - full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config, - CUDAGraphMode.FULL) + outer_model.forward = MagicMock( + wraps=outer_model.forward, side_effect=piecewise_wrapper + ) + full_wrapper = CUDAGraphWrapper( + outer_model, self.vllm_config, CUDAGraphMode.FULL + ) desc_1 = BatchDescriptor(num_tokens=1) # 0. global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): full_wrapper(input_1) # --- Test runtime mode FULL--- @@ -370,8 +389,9 @@ class TestCudagraphIntegration: # The inner mock should be called once inside the graph capture. outer_model.forward.reset_mock() inner_model.forward.reset_mock() - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.FULL, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.FULL, desc_1 + ) assert action == "capture_global" assert outer_model.forward.call_count == 1 assert inner_model.forward.call_count == 1 @@ -379,8 +399,9 @@ class TestCudagraphIntegration: # Run again. Expect outer wrapper to replay. # The outer model should NOT be called because the whole graph # is replayed. 
- action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.FULL, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.FULL, desc_1 + ) assert action == "replay" assert outer_model.forward.call_count == 1 # No new call assert inner_model.forward.call_count == 1 @@ -391,16 +412,18 @@ class TestCudagraphIntegration: # Run with PIECEWISE mode context. # Expect outer wrapper to bypass and call inner wrapper. # Inner wrapper should capture. - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.PIECEWISE, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1 + ) assert action == "capture_global" assert outer_model.forward.call_count == 1 assert inner_model.forward.call_count == 1 # Run again with PIECEWISE. # Outer bypasses, inner replays. - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.PIECEWISE, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1 + ) assert action == "bypass" assert outer_model.forward.call_count == 2 assert inner_model.forward.call_count == 1 diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 81655e4175006..8c8148ae20948 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -4,12 +4,11 @@ import contextlib import os import weakref from contextlib import ExitStack -from dataclasses import dataclass -from typing import Optional import pytest from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM from vllm.config import CompilationConfig from vllm.platforms import current_platform @@ -34,57 +33,6 @@ def temporary_environ(env_vars): os.environ[k] = v -@dataclass -class BackendConfig: - name: str - env_vars: dict - comp_config: dict - specific_gpu_arch: Optional[tuple] = None - - -# Define all backend configurations of full cudagraph to be tested -backend_configs = { - # FA3 on Hopper - "FA3": - BackendConfig(name="FA3", - env_vars={"VLLM_FLASH_ATTN_VERSION": "3"}, - comp_config={ - "cudagraph_mode": "FULL", - }, - specific_gpu_arch=(9, 0)), - # FlashMLA on Hopper - "FlashMLA": - BackendConfig(name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }, - specific_gpu_arch=(9, 0)), - # FA2 - "FA2": - BackendConfig(name="FA2", - env_vars={"VLLM_FLASH_ATTN_VERSION": "2"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), - # Triton Attention - "TritonAttn": - BackendConfig(name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), - # FlashInfer - "FlashInfer": - BackendConfig(name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), -} - # test attention backend and cudagraph_mode combo # (backend_name, cudagraph_mode, supported) combo_cases_1 = [ @@ -97,9 +45,8 @@ combo_cases_1 = [ ] -@pytest.mark.parametrize("combo_case", combo_cases_1) -def test_backend_and_cudagraph_mode_combo(combo_case): - backend_name, cudagraph_mode, supported = combo_case +@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1) +def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supported): if backend_name 
== "FlashInfer": try: import flashinfer # noqa: F401 @@ -107,25 +54,30 @@ def test_backend_and_cudagraph_mode_combo(combo_case): pytest.skip("FlashInfer is not installed") backend_config = backend_configs[backend_name] # Dynamically skip test if GPU capability is not met - if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ - != current_platform.get_device_capability(): + if ( + backend_config.specific_gpu_arch + and backend_config.specific_gpu_arch != current_platform.get_device_capability() + ): pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") - env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + env_vars = backend_configs[backend_name].env_vars with temporary_environ(env_vars), ExitStack() as stack: if not supported: stack.enter_context(pytest.raises(Exception)) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_seqs=256, - trust_remote_code=True, - gpu_memory_utilization=0.45, - max_model_len=1024, - compilation_config=CompilationConfig( - level=3, cudagraph_mode=cudagraph_mode)) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=3, cudagraph_mode=cudagraph_mode + ), + ) llm.generate(["Hello, my name is"] * 10) - + # when above code raises, `llm` may be undefined, so we need to catch that try: llm = weakref.proxy(llm) del llm @@ -144,10 +96,13 @@ combo_cases_2 = [ ("FA2", "FULL", 0, True), # no compilation + full cudagraph ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph - ("FA2", "PIECEWISE", 3, - True), # piecewise compilation + piecewise cudagraph - ("FA2", "FULL_AND_PIECEWISE", 0, - False), # piecewise cudagraph not supported without piecewise compilation + ("FA2", "PIECEWISE", 3, True), # piecewise compilation + piecewise cudagraph + ( + "FA2", + "FULL_AND_PIECEWISE", + 0, + False, + ), # piecewise cudagraph not supported without piecewise compilation ("FA2", "FULL_AND_PIECEWISE", 3, True), ("FA2", "FULL_DECODE_ONLY", 0, True), ("FA2", "FULL_DECODE_ONLY", 3, True), @@ -156,25 +111,30 @@ combo_cases_2 = [ ] -@pytest.mark.parametrize("combo_case", combo_cases_2) +@pytest.mark.parametrize( + "backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2 +) def test_cudagraph_compilation_combo(combo_case): - backend_name, cudagraph_mode, compilation_level, supported\ - = combo_case + backend_name, cudagraph_mode, compilation_level, supported = combo_case - env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} + env_vars = backend_configs[backend_name].env_vars with temporary_environ(env_vars), ExitStack() as stack: if not supported: stack.enter_context(pytest.raises(Exception)) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_seqs=256, - trust_remote_code=True, - gpu_memory_utilization=0.45, - max_model_len=1024, - compilation_config=CompilationConfig( - level=compilation_level, cudagraph_mode=cudagraph_mode)) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=compilation_level, cudagraph_mode=cudagraph_mode + ), + ) llm.generate(["Hello, my name is"] * 10) + # when above code raises, `llm` may be undefined, so we need to catch that try: llm = weakref.proxy(llm) del llm diff --git a/tests/tracing/__init__.py 
b/tests/v1/distributed/__init__.py similarity index 100% rename from tests/tracing/__init__.py rename to tests/v1/distributed/__init__.py diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py similarity index 63% rename from tests/v1/test_async_llm_dp.py rename to tests/v1/distributed/test_async_llm_dp.py index c2610a87ac780..28bb91f34c39b 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/distributed/test_async_llm_dp.py @@ -13,12 +13,11 @@ from vllm import SamplingParams from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType -from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.metrics.loggers import StatLoggerBase -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats DP_SIZE = int(os.getenv("DP_SIZE", 2)) @@ -29,40 +28,40 @@ engine_args = AsyncEngineArgs( data_parallel_size=DP_SIZE, ) -if not current_platform.supports_v1(engine_args.create_model_config()): - pytest.skip(reason="Requires V1-supporting platform.", - allow_module_level=True) - async def generate( - engine: AsyncLLM, - request_id: str, - prompt: PromptType, - output_kind: RequestOutputKind, - max_tokens: int, - prompt_logprobs: Optional[int] = None, - data_parallel_rank: Optional[int] = None) -> tuple[int, str]: + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + output_kind: RequestOutputKind, + max_tokens: int, + prompt_logprobs: Optional[int] = None, + data_parallel_rank: Optional[int] = None, +) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) count = 0 - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True, - output_kind=output_kind, - temperature=0, - prompt_logprobs=prompt_logprobs) - async for out in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params, - data_parallel_rank=data_parallel_rank): - + sampling_params = SamplingParams( + max_tokens=max_tokens, + ignore_eos=True, + output_kind=output_kind, + temperature=0, + prompt_logprobs=prompt_logprobs, + ) + async for out in engine.generate( + request_id=request_id, + prompt=prompt, + sampling_params=sampling_params, + data_parallel_rank=data_parallel_rank, + ): num_tokens = len(out.outputs[0].token_ids) if output_kind == RequestOutputKind.DELTA: count += num_tokens else: count = num_tokens - await asyncio.sleep(0.) 
+ await asyncio.sleep(0.0) return count, request_id @@ -75,10 +74,11 @@ async def generate( ], ) @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) +@pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.asyncio -async def test_load(output_kind: RequestOutputKind, - data_parallel_backend: str): - +async def test_load( + output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool +): stats_loggers = {} @dataclass @@ -89,24 +89,27 @@ async def test_load(output_kind: RequestOutputKind, def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): stats_loggers[engine_index] = self - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, + engine_idx: int = 0, + ): if iteration_stats: - self.finished_req_count += len( - iteration_stats.finished_requests) + self.finished_req_count += len(iteration_stats.finished_requests) def log_engine_initialized(self): self.init_count += 1 with ExitStack() as after: - prompt = "This is a test of data parallel" engine_args.data_parallel_backend = data_parallel_backend - engine = AsyncLLM.from_engine_args(engine_args, - stat_loggers=[SimpleStatsLogger]) + engine_args.async_scheduling = async_scheduling + engine = AsyncLLM.from_engine_args( + engine_args, stat_loggers=[SimpleStatsLogger] + ) after.callback(engine.shutdown) NUM_REQUESTS = 100 @@ -119,20 +122,23 @@ async def test_load(output_kind: RequestOutputKind, for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + generate( + engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS + ) + ) + ) # Short sleep to ensure that requests are distributed. await asyncio.sleep(0.01) # Confirm that we got all the EXPECTED tokens from the requests. 
- done, pending = await asyncio.wait(tasks, - return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) for task in pending: task.cancel() for task in done: num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + f"expected {NUM_EXPECTED_TOKENS}" + ) assert not engine.output_processor.has_unfinished_requests() @@ -156,5 +162,6 @@ async def test_load(output_kind: RequestOutputKind, for sl in stats_loggers.values(): slogger: SimpleStatsLogger = sl - assert slogger.finished_req_count > NUM_REQUESTS // ( - DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}" + assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), ( + f"requests are imbalanced: {stats_loggers}" + ) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/distributed/test_external_lb_dp.py similarity index 64% rename from tests/v1/test_external_lb_dp.py rename to tests/v1/distributed/test_external_lb_dp.py index 4a5c47fead58f..912f8cffe7f6d 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/distributed/test_external_lb_dp.py @@ -9,6 +9,7 @@ from contextlib import AsyncExitStack import openai # use the official client for correctness check import pytest import pytest_asyncio +import requests from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform @@ -25,12 +26,14 @@ class ExternalLBServerManager: """Manages data parallel vLLM server instances for external load balancer testing.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.tp_size = tp_size @@ -46,20 +49,22 @@ class ExternalLBServerManager: server_args = self.base_server_args.copy() # Add external LB specific arguments - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-rank", - str(rank), - "--data-parallel-size-local", - "1", - "--tensor-parallel-size", - str(self.tp_size), - "--port", - str(8000 + rank), # Different port for each rank - "--api-server-count", - str(self.api_server_count), - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-rank", + str(rank), + "--data-parallel-size-local", + "1", + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + rank), # Different port for each rank + "--api-server-count", + str(self.api_server_count), + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(r: int, sargs: list[str]): @@ -70,23 +75,24 @@ class ExternalLBServerManager: sargs, auto_port=False, env_dict={ - current_platform.device_control_env_var: - ",".join( - str( - current_platform. 
- device_id_to_physical_device_id(i)) - for i in range(r * TP_SIZE, (r + 1) * TP_SIZE)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(r * TP_SIZE, (r + 1) * TP_SIZE) + ), + }, + ) server.__enter__() - print(f"Server rank {r} started successfully with " - f"{self.api_server_count} API servers") + print( + f"Server rank {r} started successfully with " + f"{self.api_server_count} API servers" + ) self.servers.append((server, sargs)) except Exception as e: print(f"Failed to start server rank {r}: {e}") raise - thread = threading.Thread(target=start_server, - args=(rank, server_args)) + thread = threading.Thread(target=start_server, args=(rank, server_args)) thread.start() self.server_threads.append(thread) @@ -127,11 +133,19 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) -def servers(request, default_server_args): +def server_manager(request, default_server_args): api_server_count = request.param - with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, - default_server_args) as server_list: - yield server_list + server_manager = ExternalLBServerManager( + MODEL_NAME, DP_SIZE, api_server_count, default_server_args + ) + + with server_manager: + yield server_manager + + +@pytest.fixture +def servers(server_manager): + return server_manager.servers @pytest_asyncio.fixture @@ -144,21 +158,51 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): ] +def _get_parallel_config(server: RemoteOpenAIServer): + response = requests.get(server.url_for("server_info?config_format=json")) + response.raise_for_status() + + vllm_config = response.json()["vllm_config"] + return vllm_config["parallel_config"] + + +def test_external_lb_server_info(server_manager): + servers = server_manager.servers + api_server_count = server_manager.api_server_count + + for i, (server, _) in enumerate(servers): + print(f"Testing {i=}") + + # Each request will hit one of the API servers + # `n_reqs` is set so that there is a good chance each server + # receives at least one request + n_reqs = 2 * api_server_count * api_server_count + parallel_configs = [_get_parallel_config(server) for _ in range(n_reqs)] + api_process_counts = [c["_api_process_count"] for c in parallel_configs] + api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] + + assert all(c == api_server_count for c in api_process_counts), ( + api_process_counts + ) + assert all(0 <= r < api_server_count for r in api_process_ranks), ( + api_process_ranks + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [MODEL_NAME], ) -async def test_external_lb_single_completion(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: - +async def test_external_lb_single_completion( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(client: openai.AsyncOpenAI): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=10, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=10, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -212,11 +256,14 @@ async def test_external_lb_single_completion(clients: list[ _, server_args = servers[0] api_server_count = ( - 
server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) print( f"Successfully completed external LB test with {len(clients)} servers " - f"(API server count: {api_server_count})") + f"(API server count: {api_server_count})" + ) @pytest.mark.asyncio @@ -224,9 +271,11 @@ async def test_external_lb_single_completion(clients: list[ "model_name", [MODEL_NAME], ) -async def test_external_lb_completion_streaming(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: +async def test_external_lb_completion_streaming( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" async def make_streaming_request(client: openai.AsyncOpenAI): @@ -240,11 +289,9 @@ async def test_external_lb_completion_streaming(clients: list[ single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -255,16 +302,15 @@ async def test_external_lb_completion_streaming(clients: list[ last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." 
+ ) return True # Indicate success for this request # Test single request to each server @@ -280,10 +326,7 @@ async def test_external_lb_completion_streaming(clients: list[ all_tasks = [] for i, client in enumerate(clients): - tasks = [ - make_streaming_request(client) - for _ in range(num_requests_per_server) - ] + tasks = [make_streaming_request(client) for _ in range(num_requests_per_server)] all_tasks.extend(tasks) results = await asyncio.gather(*all_tasks) @@ -295,10 +338,7 @@ async def test_external_lb_completion_streaming(clients: list[ # Second burst of streaming requests all_tasks = [] for i, client in enumerate(clients): - tasks = [ - make_streaming_request(client) - for _ in range(num_requests_per_server) - ] + tasks = [make_streaming_request(client) for _ in range(num_requests_per_server)] all_tasks.extend(tasks) results = await asyncio.gather(*all_tasks) @@ -307,7 +347,11 @@ async def test_external_lb_completion_streaming(clients: list[ _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed external LB streaming test with " - f"{len(clients)} servers (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed external LB streaming test with " + f"{len(clients)} servers (API server count: {api_server_count})" + ) diff --git a/tests/v1/test_hybrid_lb_dp.py b/tests/v1/distributed/test_hybrid_lb_dp.py similarity index 65% rename from tests/v1/test_hybrid_lb_dp.py rename to tests/v1/distributed/test_hybrid_lb_dp.py index 293b1257be6bb..aa25130752a49 100644 --- a/tests/v1/test_hybrid_lb_dp.py +++ b/tests/v1/distributed/test_hybrid_lb_dp.py @@ -9,9 +9,10 @@ from contextlib import AsyncExitStack import openai # use the official client for correctness check import pytest import pytest_asyncio +import requests from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing from vllm.platforms import current_platform MODEL_NAME = "ibm-research/PowerMoE-3b" @@ -27,17 +28,19 @@ DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node class HybridLBServerManager: - """Manages hybrid data parallel vLLM server instances where each node - runs a single logical API server that balances requests only to the + """Manages hybrid data parallel vLLM server instances where each node + runs a single logical API server that balances requests only to the DP engines running on that same node.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - dp_size_local: int = DP_SIZE_LOCAL, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_size_local: int = DP_SIZE_LOCAL, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.dp_size_local = dp_size_local @@ -58,25 +61,27 @@ class HybridLBServerManager: start_rank = node_id * self.dp_size_local # Add hybrid LB specific arguments - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_size_local), - "--data-parallel-start-rank", - str(start_rank), - "--data-parallel-hybrid-lb", # Enable hybrid LB mode - "--tensor-parallel-size", - str(self.tp_size), - "--port", - str(8000 + 
node_id), # Different port for each node - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size_local), + "--data-parallel-start-rank", + str(start_rank), + "--data-parallel-hybrid-lb", # Enable hybrid LB mode + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + node_id), # Different port for each node + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(node: int, sargs: list[str]): @@ -92,24 +97,25 @@ class HybridLBServerManager: sargs, auto_port=False, env_dict={ - current_platform.device_control_env_var: - ",".join( - str( - current_platform. - device_id_to_physical_device_id(i)) - for i in range(gpu_start, gpu_end)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(gpu_start, gpu_end) + ), + }, + ) server.__enter__() - print(f"Hybrid LB node {node} started successfully with " - f"{self.dp_size_local} local DP ranks and " - f"{self.api_server_count} API servers") + print( + f"Hybrid LB node {node} started successfully with " + f"{self.dp_size_local} local DP ranks and " + f"{self.api_server_count} API servers" + ) self.servers.append((server, sargs)) except Exception as e: print(f"Failed to start hybrid LB node {node}: {e}") raise - thread = threading.Thread(target=start_server, - args=(node_id, server_args)) + thread = threading.Thread(target=start_server, args=(node_id, server_args)) thread.start() self.server_threads.append(thread) @@ -150,12 +156,24 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) -def servers(request, default_server_args): +def server_manager(request, default_server_args): api_server_count = request.param - with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, - default_server_args, DP_SIZE_LOCAL, - TP_SIZE) as server_list: - yield server_list + server_manager = HybridLBServerManager( + MODEL_NAME, + DP_SIZE, + api_server_count, + default_server_args, + DP_SIZE_LOCAL, + TP_SIZE, + ) + + with server_manager: + yield server_manager + + +@pytest.fixture +def servers(server_manager): + return server_manager.servers @pytest_asyncio.fixture @@ -168,22 +186,51 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): ] +def _get_parallel_config(server: RemoteOpenAIServer): + response = requests.get(server.url_for("server_info?config_format=json")) + response.raise_for_status() + + vllm_config = response.json()["vllm_config"] + return vllm_config["parallel_config"] + + +def test_hybrid_dp_server_info(server_manager): + servers = server_manager.servers + api_server_count = server_manager.api_server_count + + for i, (server, _) in enumerate(servers): + print(f"Testing {i=}") + + # Each request will hit one of the API servers + # `n_reqs` is set so that there is a good chance each server + # receives at least one request + n_reqs = 2 * api_server_count * api_server_count + parallel_configs = [_get_parallel_config(server) for _ in range(n_reqs)] + api_process_counts = [c["_api_process_count"] for c in parallel_configs] + api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] + + assert 
all(c == api_server_count for c in api_process_counts), ( + api_process_counts + ) + assert all(0 <= r < api_server_count for r in api_process_ranks), ( + api_process_ranks + ) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [MODEL_NAME], ) -async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], - servers: list[tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: - +async def test_hybrid_lb_completion( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(client: openai.AsyncOpenAI): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -207,9 +254,7 @@ async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], for i, client in enumerate(clients): result = await make_request(client) assert result is not None - print( - f"Hybrid LB node {i} handled single completion request successfully" - ) + print(f"Hybrid LB node {i} handled single completion request successfully") await asyncio.sleep(0.5) @@ -240,8 +285,10 @@ async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) print( f"Successfully completed hybrid LB test with {len(clients)} nodes " f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})" @@ -258,9 +305,11 @@ async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], "model_name", [MODEL_NAME], ) -async def test_hybrid_lb_completion_streaming(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: +async def test_hybrid_lb_completion_streaming( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" async def make_streaming_request(client: openai.AsyncOpenAI): @@ -274,11 +323,9 @@ async def test_hybrid_lb_completion_streaming(clients: list[ single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -289,25 +336,22 @@ async def test_hybrid_lb_completion_streaming(clients: list[ last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." 
+ assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." + ) return True # Indicate success for this request # Test single request to each node for i, client in enumerate(clients): result = await make_streaming_request(client) assert result is not None - print( - f"Hybrid LB node {i} handled single streaming request successfully" - ) + print(f"Hybrid LB node {i} handled single streaming request successfully") await asyncio.sleep(0.5) @@ -338,11 +382,15 @@ async def test_hybrid_lb_completion_streaming(clients: list[ _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed hybrid LB streaming test with " - f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " - f"API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed hybrid LB streaming test with " + f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " + f"API server count: {api_server_count})" + ) # Check request balancing within each node for i, (server, _) in enumerate(servers): diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/distributed/test_internal_lb_dp.py similarity index 64% rename from tests/v1/test_internal_lb_dp.py rename to tests/v1/distributed/test_internal_lb_dp.py index 2b031865cad76..452d3682e65de 100644 --- a/tests/v1/test_internal_lb_dp.py +++ b/tests/v1/distributed/test_internal_lb_dp.py @@ -10,9 +10,10 @@ from typing import Optional, cast import openai # use the official client for correctness check import pytest import pytest_asyncio +import requests from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing from vllm.platforms import current_platform MODEL_NAME = "ibm-research/PowerMoE-3b" @@ -30,66 +31,71 @@ class MultinodeInternalLBServerManager: """Manages multi-node data parallel vLLM server instances for internal load balancer testing using --headless mode.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - dp_per_node: int = 1, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_per_node: int = 1, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.dp_per_node = dp_per_node self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, - list[str]]]] = [None] * (dp_size // - dp_per_node) + self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * ( + dp_size // dp_per_node + ) self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: """Start all server instances for multi-node internal LB mode.""" - for server_idx, rank in enumerate( - range(0, self.dp_size, self.dp_per_node)): + for server_idx, rank in enumerate(range(0, self.dp_size, self.dp_per_node)): # Create server args 
for this specific rank server_args = self.base_server_args.copy() if rank == 0: # Head node - runs API server and first DP rank - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_per_node), - "--tensor-parallel-size", - str(self.tp_size), - "--port", - "8000", # Single endpoint for all requests - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", # Single endpoint for all requests + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) else: # Secondary nodes - run in headless mode - server_args.extend([ - "--headless", - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_per_node), - "--data-parallel-start-rank", - str(rank), - "--tensor-parallel-size", - str(self.tp_size), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--data-parallel-start-rank", + str(rank), + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(sidx: int, r: int, sargs: list[str]): @@ -101,18 +107,19 @@ class MultinodeInternalLBServerManager: sargs, auto_port=False, env_dict={ - current_platform.device_control_env_var: - ",".join( - str( - current_platform. 
- device_id_to_physical_device_id(i)) - for i in range(r, r + gpus_per_node)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(r, r + gpus_per_node) + ), + }, + ) server.__enter__() if r == 0: print( f"Head node (rank {r}) started successfully with " - f"{self.api_server_count} API servers") + f"{self.api_server_count} API servers" + ) else: print(f"Headless node (rank {r}) started successfully") self.servers[sidx] = (server, sargs) @@ -121,8 +128,9 @@ class MultinodeInternalLBServerManager: traceback.print_exc() raise - thread = threading.Thread(target=start_server, - args=(server_idx, rank, server_args)) + thread = threading.Thread( + target=start_server, args=(server_idx, rank, server_args) + ) thread.start() self.server_threads.append(thread) @@ -154,19 +162,20 @@ class APIOnlyServerManager: """Manages API-only server (Node 0) and headless engines server (Node 1) for testing separated API server and engine configuration.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, - list[str]]]] = [None] * 2 + self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * 2 self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: @@ -174,38 +183,42 @@ class APIOnlyServerManager: # Start API-only server (Node 0) - no engines, only API server api_server_args = self.base_server_args.copy() - api_server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - "0", # No engines on this node - "--tensor-parallel-size", - str(self.tp_size), - "--port", - "8000", - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + api_server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + "0", # No engines on this node + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Start headless engines server (Node 1) - all engines, no API server engines_server_args = self.base_server_args.copy() - engines_server_args.extend([ - "--headless", - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_size), # All engines on this node - "--tensor-parallel-size", - str(self.tp_size), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + engines_server_args.extend( + [ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size), # All engines on this node + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use threads to start both servers in parallel def start_api_server(): @@ -214,10 +227,16 @@ class APIOnlyServerManager: self.model_name, api_server_args, 
auto_port=False, - env_dict={}) # No GPUs needed for API-only server + env_dict={ + "VLLM_SERVER_DEV_MODE": "1", + # No GPUs needed for API-only server + }, + ) server.__enter__() - print(f"API-only server started successfully with " - f"{self.api_server_count} API servers") + print( + f"API-only server started successfully with " + f"{self.api_server_count} API servers" + ) self.servers[0] = (server, api_server_args) except Exception as e: print(f"Failed to start API-only server: {e}") @@ -230,16 +249,17 @@ class APIOnlyServerManager: engines_server_args, auto_port=False, env_dict={ - current_platform.device_control_env_var: - ",".join( - str( - current_platform. - device_id_to_physical_device_id(i)) - for i in range(self.dp_size * self.tp_size)) - }) + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(self.dp_size * self.tp_size) + ) + }, + ) server.__enter__() - print(f"Headless engines server started successfully with " - f"{self.dp_size} engines") + print( + f"Headless engines server started successfully with " + f"{self.dp_size} engines" + ) self.servers[1] = (server, engines_server_args) except Exception as e: print(f"Failed to start headless engines server: {e}") @@ -293,22 +313,33 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) -def servers(request, default_server_args): +def server_manager(request, default_server_args): api_server_count = request.param - with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE, - api_server_count, - default_server_args, - DP_SIZE // NUM_NODES, - TP_SIZE) as server_list: - yield server_list + server_manager = MultinodeInternalLBServerManager( + MODEL_NAME, + DP_SIZE, + api_server_count, + default_server_args, + DP_SIZE // NUM_NODES, + TP_SIZE, + ) + + with server_manager: + yield server_manager + + +@pytest.fixture +def servers(server_manager): + return server_manager.servers @pytest.fixture(scope="module", params=[1, 4]) def api_only_servers(request, default_server_args): """Fixture for API-only server + headless engines configuration.""" api_server_count = request.param - with APIOnlyServerManager(MODEL_NAME, DP_SIZE, api_server_count, - default_server_args, TP_SIZE) as server_list: + with APIOnlyServerManager( + MODEL_NAME, DP_SIZE, api_server_count, default_server_args, TP_SIZE + ) as server_list: yield server_list @@ -322,8 +353,7 @@ async def client(servers: list[tuple[RemoteOpenAIServer, list[str]]]): @pytest_asyncio.fixture -async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]]): +async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer, list[str]]]): """Client fixture for API-only server configuration.""" # Connect to the API-only server (first server in the list) api_server = api_only_servers[0][0] @@ -331,22 +361,44 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer, yield client +def _get_parallel_config(server: RemoteOpenAIServer): + response = requests.get(server.url_for("server_info?config_format=json")) + response.raise_for_status() + + vllm_config = response.json()["vllm_config"] + return vllm_config["parallel_config"] + + +def test_multinode_dp_server_info(server_manager): + head_server = server_manager.servers[0][0] + api_server_count = server_manager.api_server_count + + # Each request will hit one of the API servers + # `n_reqs` is set so that there is a good chance each server + # receives at least one request + n_reqs = 2 * api_server_count * 
api_server_count + parallel_configs = [_get_parallel_config(head_server) for _ in range(n_reqs)] + api_process_counts = [c["_api_process_count"] for c in parallel_configs] + api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] + + assert all(c == api_server_count for c in api_process_counts), api_process_counts + assert all(0 <= r < api_server_count for r in api_process_ranks), api_process_ranks + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [MODEL_NAME], ) -async def test_multinode_dp_completion(client: openai.AsyncOpenAI, - servers: list[tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: - +async def test_multinode_dp_completion( + client: openai.AsyncOpenAI, + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -369,9 +421,7 @@ async def test_multinode_dp_completion(client: openai.AsyncOpenAI, # Test single request result = await make_request() assert result is not None - print( - "Multi-node internal LB handled single completion request successfully" - ) + print("Multi-node internal LB handled single completion request successfully") await asyncio.sleep(0.5) @@ -400,10 +450,14 @@ async def test_multinode_dp_completion(client: openai.AsyncOpenAI, _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed multi-node internal LB test with " - f"{len(servers)} DP ranks (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed multi-node internal LB test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics head_server = servers[0][0] @@ -415,11 +469,11 @@ async def test_multinode_dp_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, - servers: list[ - tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: +async def test_multinode_dp_completion_streaming( + client: openai.AsyncOpenAI, + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" 
async def make_streaming_request(): @@ -433,11 +487,9 @@ async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -448,23 +500,21 @@ async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." + ) return True # Indicate success for this request # Test single streaming request result = await make_streaming_request() assert result is not None - print( - "Multi-node internal LB handled single streaming request successfully") + print("Multi-node internal LB handled single streaming request successfully") await asyncio.sleep(0.5) @@ -494,10 +544,14 @@ async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed multi-node internal LB streaming test with " - f"{len(servers)} DP ranks (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed multi-node internal LB streaming test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics head_server = servers[0][0] @@ -510,17 +564,16 @@ async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, [MODEL_NAME], ) async def test_api_only_multinode_dp_completion( - api_only_client: openai.AsyncOpenAI, - api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]], model_name: str) -> None: + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: """Test API-only server with all engines on separate headless server.""" async def make_request(): completion = await api_only_client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -573,11 +626,14 @@ async def 
test_api_only_multinode_dp_completion( api_server, api_server_args = api_only_servers[0] api_server_count = ( - api_server_args.count('--api-server-count') - and api_server_args[api_server_args.index('--api-server-count') + 1] - or 1) - print(f"Successfully completed API-only multi-node test with {DP_SIZE} " - f"engines on headless server (API server count: {api_server_count})") + api_server_args.count("--api-server-count") + and api_server_args[api_server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed API-only multi-node test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics check_request_balancing(api_server, DP_SIZE) @@ -589,9 +645,10 @@ async def test_api_only_multinode_dp_completion( [MODEL_NAME], ) async def test_api_only_multinode_dp_completion_streaming( - api_only_client: openai.AsyncOpenAI, - api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]], model_name: str) -> None: + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: """Test API-only server streaming with all engines on separate headless server.""" prompt = "What is an LLM?" @@ -607,11 +664,9 @@ async def test_api_only_multinode_dp_completion_streaming( single_output = single_completion.choices[0].text # Perform the streaming request - stream = await api_only_client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await api_only_client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -622,16 +677,15 @@ async def test_api_only_multinode_dp_completion_streaming( last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." 
+ ) return True # Indicate success for this request # Test single streaming request @@ -666,11 +720,14 @@ async def test_api_only_multinode_dp_completion_streaming( _, api_server_args = api_only_servers[0] api_server_count = ( - api_server_args.count('--api-server-count') - and api_server_args[api_server_args.index('--api-server-count') + 1] - or 1) - print(f"Successfully completed API-only streaming test with {DP_SIZE} " - f"engines on headless server (API server count: {api_server_count})") + api_server_args.count("--api-server-count") + and api_server_args[api_server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed API-only streaming test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics api_server = api_only_servers[0][0] diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py new file mode 100644 index 0000000000000..0f7ccb35a7576 --- /dev/null +++ b/tests/v1/e2e/test_async_sched_and_preempt.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import pytest + +from vllm import SamplingParams + +from ...conftest import VllmRunner +from ...models.utils import check_outputs_equal + +MODEL = "Qwen/Qwen3-0.6B" + + +def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): + """Test consistency of combos of async scheduling, preemption, + uni/multiproc executor, and various sampling parameters.""" + + first_prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(10)) + + " are:" + ) + example_prompts = [first_prompt, "In one word, the capital of France is "] + [ + f"Tell me about the number {i}: " for i in range(32) + ] + + sampling_param_tests: list[dict[str, Any]] = [ + dict(), + # dict(min_tokens=20), + dict(presence_penalty=-1.0), + dict(bad_words=["the", " the"]), + ] + + default_params = dict( + temperature=0.0, # greedy + max_tokens=20, + ) + + with monkeypatch.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") + # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1") + + outputs: list[tuple[str, list]] = [] + for test_preemption in [False, True]: + for executor in ["mp", "uni"]: + for async_scheduling in [False, True]: + cache_arg: dict[str, Any] = ( + dict(num_gpu_blocks_override=32) + if test_preemption + else dict(gpu_memory_utilization=0.7) + ) + test_config = ( + f"executor={executor}, preemption={test_preemption}," + f" async_sched={async_scheduling}" + ) + print("-" * 80) + print(f"---- TESTING: {test_config}") + print("-" * 80) + with VllmRunner( + MODEL, + max_model_len=512, + enforce_eager=True, + async_scheduling=async_scheduling, + distributed_executor_backend=executor, + dtype="float32", # avoid precision errors + **cache_arg, + ) as vllm_model: + results = [] + for override_params in sampling_param_tests: + print(f"----------- RUNNING PARAMS: {override_params}") + results.append( + vllm_model.generate( + example_prompts, + sampling_params=SamplingParams( + **default_params, **override_params + ), + ) + ) + + if not outputs: + # First check that the different parameter configs + # actually result in different output. 
+ for other_test, params in zip( + results[1:], sampling_param_tests[1:] + ): + with pytest.raises(AssertionError): + check_outputs_equal( + outputs_0_lst=results[0], + outputs_1_lst=other_test, + name_0=f"baseline params={params}", + name_1=f"other params={params}", + ) + + outputs.append((test_config, results)) + + baseline_config, baseline_tests = outputs[0] + + for test_config, test_outputs in outputs[1:]: + for base_outs, test_outs, params in zip( + baseline_tests, test_outputs, sampling_param_tests + ): + check_outputs_equal( + outputs_0_lst=base_outs, + outputs_1_lst=test_outs, + name_0=f"baseline=[{baseline_config}], params={params}", + name_1=f"config=[{test_config}], params={params}", + ) + + print(f"PASSED: config=[{test_config}], params={params}") diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index f2f460513605f..0fcb97fe63055 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -9,13 +9,17 @@ from ...utils import create_new_process_for_each_test @create_new_process_for_each_test() -@pytest.mark.parametrize("attn_backend", - ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) +@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" + if attn_backend == "FLASHINFER": + pytest.skip( + "This test is failing with FlashInfer backend and " + "needs investigation. See issue #25679." + ) + with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 4dfe1d3bb33fa..71b0e86c75c18 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -6,8 +6,7 @@ import pytest from vllm import LLM, SamplingParams -from ...core.block.e2e.test_correctness_sliding_window import (check_answers, - prep_prompts) +from ...utils import check_answers, prep_prompts @dataclass @@ -27,51 +26,53 @@ model_config = { [ "bigcode/starcoder2-3b", # sliding window only "google/gemma-3-1b-it", # sliding window + full attention - ]) + ], +) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False]) -def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, - disable_hybrid_kv_cache_manager): +def test_sliding_window_retrieval( + model, batch_size, seed, disable_hybrid_kv_cache_manager +): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). If we tell it upfront which we are going to be looking for, then it answers correctly (mostly). 
""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + test_config = model_config[model] - test_config = model_config[model] + llm = LLM( + model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager + ) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - llm = LLM( - model=model, - disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager) - sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + prompts, answer, indices = prep_prompts(batch_size, ln_range=test_config.ln_range) - prompts, answer, indices = prep_prompts(batch_size, - ln_range=test_config.ln_range) + check_length(prompts, llm, test_config.sliding_window) - check_length(prompts, llm, test_config.sliding_window) + # Fresh generation + responses = llm.generate(prompts, sampling_params) + check_answers( + indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0, + ) - # Fresh generation - responses = llm.generate(prompts, sampling_params) - check_answers(indices, - answer, - [response.outputs[0].text for response in responses], - accept_rate=1.0) - - # Re-generate with the same prompts to test prefix caching - responses = llm.generate(prompts, sampling_params) - check_answers(indices, - answer, - [response.outputs[0].text for response in responses], - accept_rate=1.0) + # Re-generate with the same prompts to test prefix caching + responses = llm.generate(prompts, sampling_params) + check_answers( + indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0, + ) def check_length(prompts: list[str], llm: LLM, sliding_window: int): """ - Check if the prompt length is valid, i.e., longer than the sliding window + Check if the prompt length is valid, i.e., longer than the sliding window size and shorter than the model's max length. 
Args: @@ -81,9 +82,9 @@ def check_length(prompts: list[str], llm: LLM, sliding_window: int): """ tokenizer = llm.get_tokenizer() max_model_len = llm.llm_engine.model_config.max_model_len - assert any( - len(tokenizer.encode(prompt)) > sliding_window - for prompt in prompts), "Prompt is too short for test" - assert all( - len(tokenizer.encode(prompt)) <= max_model_len - for prompt in prompts), "Prompt is too long for test" + assert any(len(tokenizer.encode(prompt)) > sliding_window for prompt in prompts), ( + "Prompt is too short for test" + ) + assert all(len(tokenizer.encode(prompt)) <= max_model_len for prompt in prompts), ( + "Prompt is too long for test" + ) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index d72e50e5196b8..89e5f26ac627f 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import pytest import torch @@ -10,12 +9,6 @@ import torch from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n_mm import ( - Gemma3nForConditionalGeneration) -from vllm.model_executor.models.registry import ModelRegistry -from vllm.model_executor.models.utils import extract_layer_index -from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test SEED = 42 -class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward(input_ids, positions, - intermediate_tensors, inputs_embeds, - **kwargs) - attn_metadata = get_forward_context().attn_metadata - # attn_metadata is None during dummy runs - if (attn_metadata is not None - and self.language_model.cache_config.kv_sharing_fast_prefill): - assert isinstance(attn_metadata, dict) # true in V1 - # Gemma3n-E2B has 30 layers, with last 20 layers being - # cross-decoder layers. 
Check attention metadata is correct - for layer_name, metadata in attn_metadata.items(): - layer_idx = extract_layer_index(layer_name) - if layer_idx >= 20: - assert hasattr(metadata, 'logits_indices_padded') - assert hasattr(metadata, 'num_logits_indices') - else: - assert not hasattr(metadata, 'logits_indices_padded') - assert not hasattr(metadata, 'num_logits_indices') - - # Last layer will be a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.language_model.model.layers[-1].self_attn.attn.layer_name] - logits_indices_padded = (layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] - - return hidden_states - - @pytest.fixture def test_prompts(): """ @@ -119,24 +64,23 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", - TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and # managing buffers for cudagraph cudagraph_copy_inputs=True, level=CompilationLevel.PIECEWISE - if not enforce_eager else CompilationLevel.NO_COMPILATION) + if not enforce_eager + else CompilationLevel.NO_COMPILATION, + ) with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - # Make scheduling deterministic for reproducibility m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") @@ -150,21 +94,21 @@ def test_kv_sharing_fast_prefill( cleanup(llm, compilation_config) - llm = LLM(model="google/gemma-3n-E2B-it", - enforce_eager=enforce_eager, - compilation_config=compilation_config, - seed=SEED, - kv_sharing_fast_prefill=True) + llm = LLM( + model="google/gemma-3n-E2B-it", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + seed=SEED, + kv_sharing_fast_prefill=True, + ) optimized_responses = llm.generate(test_prompts, sampling_params) cleanup(llm, compilation_config) misses = 0 - for ref_response, optimized_response in zip(ref_responses, - optimized_responses): - if ref_response.outputs[0].text != optimized_response.outputs[ - 0].text: + for ref_response, optimized_response in zip(ref_responses, optimized_responses): + if ref_response.outputs[0].text != optimized_response.outputs[0].text: misses += 1 assert misses == 0 diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py index f013425cb59df..e00a3d58debe3 100644 --- a/tests/v1/e2e/test_min_tokens.py +++ b/tests/v1/e2e/test_min_tokens.py @@ -13,7 +13,6 @@ Covers: 5) Multiple stop conditions """ -import os from typing import Optional, Union import pytest @@ -46,29 +45,36 @@ class MinTokensTestCase: self.expected_exact_len = expected_exact_len def 
__str__(self): - return (f"{self.name}: min={self.min_tokens}, " - f"max={self.max_tokens}, stop={self.stop}") + return ( + f"{self.name}: min={self.min_tokens}, " + f"max={self.max_tokens}, stop={self.stop}" + ) # Test scenarios covering all critical cases MIN_TOKENS_TEST_CASES = [ # === BASIC FUNCTIONALITY (should work) === - MinTokensTestCase(name="basic_min_tokens_no_stop", - min_tokens=8, - max_tokens=20, - stop=None, - expected_min_len=8), - MinTokensTestCase(name="min_tokens_zero", - min_tokens=0, - max_tokens=10, - stop=None, - expected_min_len=0), - MinTokensTestCase(name="min_equals_max_no_stop", - min_tokens=15, - max_tokens=15, - stop=None, - expected_exact_len=15), - + MinTokensTestCase( + name="basic_min_tokens_no_stop", + min_tokens=8, + max_tokens=20, + stop=None, + expected_min_len=8, + ), + MinTokensTestCase( + name="min_tokens_zero", + min_tokens=0, + max_tokens=10, + stop=None, + expected_min_len=0, + ), + MinTokensTestCase( + name="min_equals_max_no_stop", + min_tokens=15, + max_tokens=15, + stop=None, + expected_exact_len=15, + ), # === STOP STRINGS WITH MIN_TOKENS === # These tests expose the detokenizer bug where stop strings # bypass min_tokens @@ -94,9 +100,11 @@ MIN_TOKENS_TEST_CASES = [ expected_min_len=5, ), marks=pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), - strict=False), + reason=( + "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)" + ), + strict=False, + ), id="min_tokens_with_comprehensive_stops", ), pytest.param( @@ -108,12 +116,13 @@ MIN_TOKENS_TEST_CASES = [ expected_min_len=3, ), marks=pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), - strict=False), + reason=( + "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)" + ), + strict=False, + ), id="min_tokens_with_simple_char_stop", ), - # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) === # These test the MinTokensLogitsProcessor handling of EOS tokens pytest.param( @@ -125,35 +134,32 @@ MIN_TOKENS_TEST_CASES = [ expected_exact_len=20, ), marks=pytest.mark.xfail( - reason= - ("Potential logits-processor bug: EOS tokens may bypass min_tokens" - ), + reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"), strict=False, ), id="min_equals_max_eos_only", ), - # === EDGE CASES === - MinTokensTestCase(name="large_min_tokens", - min_tokens=50, - max_tokens=60, - stop=None, - expected_min_len=50), + MinTokensTestCase( + name="large_min_tokens", + min_tokens=50, + max_tokens=60, + stop=None, + expected_min_len=50, + ), MinTokensTestCase( name="min_tokens_with_empty_stop_list", min_tokens=5, max_tokens=15, stop=[], # Empty stop list - expected_min_len=5), + expected_min_len=5, + ), ] @pytest.fixture(scope="module") def llm_v1(): """Create V1 LLM instance for testing""" - # Ensure V1 engine is used - os.environ["VLLM_USE_V1"] = "1" - llm = LLM( model=TEST_MODEL, tensor_parallel_size=1, @@ -170,25 +176,27 @@ def get_token_count(output: RequestOutput) -> int: return len(output.outputs[0].token_ids) -def assert_min_tokens_satisfied(output: RequestOutput, - test_case: MinTokensTestCase) -> None: +def assert_min_tokens_satisfied( + output: RequestOutput, test_case: MinTokensTestCase +) -> None: """Assert that min_tokens requirement is satisfied""" token_count = get_token_count(output) - stop_reason = (output.outputs[0].stop_reason - if output.outputs else "no output") + stop_reason = output.outputs[0].stop_reason if 
output.outputs else "no output" if test_case.expected_exact_len is not None: # Exact length requirement assert token_count == test_case.expected_exact_len, ( f"Expected exactly {test_case.expected_exact_len} tokens, " f"got {token_count} tokens. " - f"Stop reason: {stop_reason}") + f"Stop reason: {stop_reason}" + ) else: # Minimum length requirement assert token_count >= (test_case.expected_min_len or 0), ( f"Expected at least {test_case.expected_min_len} tokens, " f"got {token_count} tokens. " - f"Stop reason: {stop_reason}") + f"Stop reason: {stop_reason}" + ) @pytest.mark.parametrize( @@ -199,13 +207,13 @@ def assert_min_tokens_satisfied(output: RequestOutput, def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): """ Comprehensive test for min_tokens functionality in V1 engine. - + This test covers all critical scenarios for min_tokens: - Basic functionality (should work) - Stop strings with min_tokens (known bug) - EOS tokens with min_tokens (potential bug) - Edge cases - + Args: llm_v1: V1 LLM instance test_case: Test scenario parameters @@ -218,7 +226,7 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): max_tokens=test_case.max_tokens, stop=test_case.stop, temperature=GREEDY, - include_stop_str_in_output=True # Include stop strings for debugging + include_stop_str_in_output=True, # Include stop strings for debugging ) # Use simple prompt. Comprehensive stop lists should catch any generation @@ -250,13 +258,11 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): def test_min_tokens_basic_functionality(llm_v1: LLM): """ Test basic min_tokens functionality without stop conditions. - + This is a baseline test that should always pass and validates that min_tokens works correctly in the simple case. """ - sampling_params = SamplingParams(min_tokens=10, - max_tokens=20, - temperature=GREEDY) + sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY) prompt = "Once upon a time" outputs = llm_v1.generate([prompt], sampling_params) @@ -269,17 +275,16 @@ def test_min_tokens_basic_functionality(llm_v1: LLM): @pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), + reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"), strict=False, ) def test_min_tokens_stop_strings_bug(llm_v1: LLM): """ Test the specific bug where stop strings bypass min_tokens. - + This test specifically reproduces the bug Calvin is fixing in PR #22014. It should fail until that fix is merged. - + Strategy: Use guaranteed stop characters that will appear in any generated text. """ @@ -291,7 +296,8 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM): # Common letter; likely appears early stop=["e"], temperature=GREEDY, - include_stop_str_in_output=True) + include_stop_str_in_output=True, + ) # Simple prompt that will generate text containing "e" prompt = "The quick brown fox" @@ -308,23 +314,25 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM): # This assertion should fail due to the bug - if stop string is found early, # the model should still continue generating until min_tokens is reached - stop_reason = (outputs[0].outputs[0].stop_reason - if outputs[0].outputs else "no output") - assert token_count >= 15, ("Bug confirmed: " - f"{token_count} tokens < min_tokens=15. " - f"Reason: {stop_reason}. 
" - f"Text: {repr(generated_text)}") + stop_reason = ( + outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output" + ) + assert token_count >= 15, ( + "Bug confirmed: " + f"{token_count} tokens < min_tokens=15. " + f"Reason: {stop_reason}. " + f"Text: {repr(generated_text)}" + ) @pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), + reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"), strict=False, ) def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): """ Guaranteed test for stop strings bypassing min_tokens bug. - + Strategy: Use very low temperature and multiple common stop strings to virtually guarantee early detection, combined with long min_tokens to ensure the bug is exposed regardless of model behavior. @@ -337,7 +345,8 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): # Use multiple very common patterns - at least one will appear stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"], temperature=GREEDY, - include_stop_str_in_output=True) + include_stop_str_in_output=True, + ) # Simple prompt that will generate some text prompt = "The cat" @@ -346,8 +355,7 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): assert len(outputs) == 1 token_count = get_token_count(outputs[0]) generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" - stop_reason = (outputs[0].outputs[0].stop_reason - if outputs[0].outputs else "unknown") + stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown" print(f"Generated text: {repr(generated_text)}") print(f"Token count: {token_count}") @@ -357,21 +365,23 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): # will trigger early termination before min_tokens=50 is reached # It's virtually impossible to generate 50 tokens without hitting # at least one of: e, a, i, o, u, space, t, n, s, r - finish_reason = (outputs[0].outputs[0].finish_reason - if outputs[0].outputs else "unknown") + finish_reason = ( + outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown" + ) print(f"Finish reason: {finish_reason}") if finish_reason == "stop": - assert token_count >= 50, ("Bug confirmed: " - f"{token_count} tokens < min_tokens=50. " - f"Reason: {finish_reason}. " - f"Text: {repr(generated_text)}") + assert token_count >= 50, ( + "Bug confirmed: " + f"{token_count} tokens < min_tokens=50. " + f"Reason: {finish_reason}. " + f"Text: {repr(generated_text)}" + ) @pytest.mark.xfail( - reason=( - "Potential logits-processor bug: EOS tokens may bypass min_tokens"), + reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"), strict=False, ) def test_min_tokens_eos_behavior(llm_v1: LLM): @@ -404,8 +414,14 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): finish_no_min = choice_no_min.finish_reason stop_no_min = choice_no_min.stop_reason - print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min, - " stop_reason=", stop_no_min) + print( + "[no-min] tokens=", + len(ids_no_min), + " finish=", + finish_no_min, + " stop_reason=", + stop_no_min, + ) assert finish_no_min == "stop", ( f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}" @@ -414,7 +430,8 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): "For EOS-based stop (no user stop strings), stop_reason should be None." 
) assert len(ids_no_min) < max_toks, ( - f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}") + f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}" + ) # Case 2: WITH min_tokens sp_with_min = SamplingParams( @@ -430,23 +447,31 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): finish_with_min = choice_with_min.finish_reason stop_with_min = choice_with_min.stop_reason - print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min, - " stop_reason=", stop_with_min) + print( + "[with-min] tokens=", + len(ids_with_min), + " finish=", + finish_with_min, + " stop_reason=", + stop_with_min, + ) # Exact length reached; EOS should have been blocked assert len(ids_with_min) == max_toks, ( - f"Expected exactly {max_toks} tokens with min_tokens; " - f"got {len(ids_with_min)}") + f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}" + ) assert finish_with_min == "length", ( - f"Expected finish_reason 'length'; got {finish_with_min}") + f"Expected finish_reason 'length'; got {finish_with_min}" + ) assert eos_token_id not in ids_with_min, ( - "EOS token id should not appear when min_tokens prevents early EOS.") + "EOS token id should not appear when min_tokens prevents early EOS." + ) def test_min_tokens_validation(): """ Test that SamplingParams correctly validates min_tokens parameters. - + This tests the parameter validation logic in SamplingParams. """ # Valid cases @@ -456,14 +481,14 @@ def test_min_tokens_validation(): # Invalid cases with pytest.raises( - ValueError, - match="min_tokens must be greater than or equal to 0", + ValueError, + match="min_tokens must be greater than or equal to 0", ): SamplingParams(min_tokens=-1, max_tokens=10) with pytest.raises( - ValueError, - match="min_tokens must be less than or equal to max_tokens", + ValueError, + match="min_tokens must be less than or equal to max_tokens", ): SamplingParams(min_tokens=15, max_tokens=10) @@ -474,6 +499,6 @@ if __name__ == "__main__": Usage: cd vllm/ - VLLM_USE_V1=1 python -m pytest tests/v1/e2e/test_min_tokens.py -v + python -m pytest tests/v1/e2e/test_min_tokens.py -v """ pytest.main([__file__, "-v"]) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index bd0fa6b80781a..fbbbd0389c265 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -8,13 +8,15 @@ from typing import Any, Union import pytest import torch -from tests.utils import get_attn_backend_list_based_on_platform +from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory from vllm.platforms import current_platform +MTP_SIMILARITY_RATE = 0.8 + def get_test_prompts(mm_enabled: bool): prompt_types = ["repeat", "sentence"] @@ -46,19 +48,17 @@ def get_test_prompts(mm_enabled: bool): give no other output than that simple sentence without quotes. 
""" elif kind == "mm": - placeholders = [{ - "type": "image_url", - "image_url": { - "url": - f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" - }, - }] + placeholders = [ + { + "type": "image_url", + "image_url": { + "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + } + ] prompt = [ *placeholders, - { - "type": "text", - "text": "The meaning of the image is" - }, + {"type": "text", "text": "The meaning of the image is"}, ] else: raise ValueError(f"Unknown prompt type: {kind}") @@ -82,82 +82,122 @@ def test_ngram_correctness( sampling_config: SamplingParams, model_name: str, ): - ''' - Compare the outputs of a original LLM and a speculative LLM + """ + Compare the outputs of an original LLM and a speculative LLM should be the same when using ngram speculative decoding. - ''' - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts(mm_enabled=False) + """ + test_prompts = get_test_prompts(mm_enabled=False) - ref_llm = LLM(model=model_name, max_model_len=1024) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) - del ref_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + ref_llm = LLM(model=model_name, max_model_len=1024) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - spec_llm = LLM( - model=model_name, - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, - max_model_len=1024, - ) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 70% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.7 * len(ref_outputs)) - del spec_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. 
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
 
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled"],
     [
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
         pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            (
+                "eagle3",
+                "Qwen/Qwen2.5-VL-7B-Instruct",
+                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
+                1,
+            ),
             False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+            marks=pytest.mark.skip(
+                reason="Skipping due to its head_dim not being a multiple of 32"
+            ),
+        ),
+        (
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+        ),
+        (
+            (
+                "eagle3",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+        ),
         pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
+            False,
+            marks=large_gpu_mark(min_gb=80),
+        ),  # works on 4x H100
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
             True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        (("eagle", "eagle618/deepseek-v3-random",
-          "eagle618/eagle-deepseek-v3-random", 1), False),
+            marks=large_gpu_mark(min_gb=80),
+        ),  # works on 4x H100
+        (
+            (
+                "eagle",
+                "eagle618/deepseek-v3-random",
+                "eagle618/eagle-deepseek-v3-random",
+                1,
+            ),
+            False,
+        ),
     ],
     ids=[
-        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
-        # "qwen3_eagle3",
+        "qwen3_eagle3",
+        "qwen2_5_vl_eagle3",
         "llama3_eagle",
         "llama3_eagle3",
         "llama4_eagle",
         "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
-@pytest.mark.parametrize("attn_backend",
-                         get_attn_backend_list_based_on_platform())
+        "deepseek_eagle",
+    ],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -169,33 +209,40 @@ def test_eagle_correctness(
         # TODO: Fix this flaky test
         pytest.skip(
             "TREE_ATTN is flaky in the test disable for now until it can be "
-            "reolved (see https://github.com/vllm-project/vllm/issues/22922)")
+            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
+        )
 
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
 
-    '''
+    """
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     model_setup: (method, model_name, eagle_model_name, tp_size)
-    '''
+    """
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection because its vision
+            # encoder has head_dim 88, which is incompatible with FLASH_ATTN
+            # and needs to fall back to Flex Attn
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
-        if (attn_backend == "TRITON_ATTN_VLLM_V1"
-                and not current_platform.is_rocm()):
-            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
-                        "multi-token eagle spec decode on current platform")
+        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+            pytest.skip(
+                "TRITON_ATTN does not support "
+                "multi-token eagle spec decode on current platform"
+            )
 
-        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")
 
         method, model_name, spec_model_name, tp_size = model_setup
 
-        ref_llm = LLM(model=model_name,
-                      max_model_len=2048,
-                      tensor_parallel_size=tp_size)
+        ref_llm = LLM(
+            model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
         torch.cuda.empty_cache()
@@ -230,3 +277,70 @@ def test_eagle_correctness(
         del spec_llm
         torch.cuda.empty_cache()
         cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+    ],
+    ids=["mimo", "deepseek"],
+)
+def test_mtp_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, int],
+    mm_enabled: bool,
+):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
+    """
+    Compare the outputs of an original LLM and a speculative LLM
+    should be the same when using MTP speculative decoding.
+    model_setup: (method, model_name, tp_size)
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_MLA_DISABLE", "1")
+
+        method, model_name, tp_size = model_setup
+
+        ref_llm = LLM(
+            model=model_name,
+            max_model_len=2048,
+            tensor_parallel_size=tp_size,
+            trust_remote_code=True,
+        )
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "num_speculative_tokens": 1,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 80% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+ assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index d7722142b207f..c5c5d35b83c3e 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -5,14 +5,16 @@ import pytest import torch from transformers import AutoTokenizer -from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, - NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, - TOKENIZER_NAME, - DummyOutputProcessorTestVectors, - generate_dummy_prompt_logprobs_tensors, - generate_dummy_sample_logprobs) +from tests.v1.engine.utils import ( + NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + PROMPT_LEN, + TOKENIZER_NAME, + DummyOutputProcessorTestVectors, + generate_dummy_prompt_logprobs_tensors, + generate_dummy_sample_logprobs, +) from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from ...distributed.conftest import publisher_config, random_port # noqa: F401 @@ -24,7 +26,7 @@ EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor] def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: """Generate output processor dummy test vectors, without logprobs - + Returns: DummyOutputProcessorTestVectors instance with no logprobs """ @@ -32,9 +34,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config() # Tokenize prompts under test & create dummy generated tokens - prompt_tokens = [ - tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS - ] + prompt_tokens = [tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS] generation_tokens = [ tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS ] @@ -43,14 +43,9 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer.decode(prompt_tokens, skip_special_tokens=True) for prompt_tokens in prompt_tokens ] - prompt_strings_len = [ - len(prompt_string) for prompt_string in prompt_strings - ] + prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings] return DummyOutputProcessorTestVectors( tokenizer=tokenizer, - tokenizer_group=init_tokenizer_from_configs( - vllm_config.model_config, vllm_config.scheduler_config, - vllm_config.lora_config), vllm_config=vllm_config, full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], prompt_tokens=prompt_tokens, @@ -62,13 +57,14 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len) ], prompt_logprobs=[], - generation_logprobs=[]) + generation_logprobs=[], + ) @pytest.fixture def dummy_test_vectors() -> DummyOutputProcessorTestVectors: """Generate output processor dummy test vectors, with logprobs - + Returns: DummyOutputProcessorTestVectors instance with logprobs """ @@ -80,12 +76,16 @@ def dummy_test_vectors() -> DummyOutputProcessorTestVectors: generate_dummy_sample_logprobs( sampled_tokens_list=tokens_list, num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST, - tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens + tokenizer=dtv.tokenizer, + ) + for tokens_list in dtv.generation_tokens ] dtv.prompt_logprobs = [ generate_dummy_prompt_logprobs_tensors( prompt_tokens_list=tokens_list, num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST, - 
tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens + tokenizer=dtv.tokenizer, + ) + for tokens_list in dtv.prompt_tokens ] return dtv diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index df04a14af70ce..444d771a18d63 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -21,16 +21,16 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import LoggingStatLogger if not current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) TEXT_ENGINE_ARGS = AsyncEngineArgs( model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, ) -VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct", - enforce_eager=True) +VISION_ENGINE_ARGS = AsyncEngineArgs( + model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True +) TEXT_PROMPT = "Hello my name is Robert and" @@ -38,12 +38,11 @@ VISION_PROMPT_TEMPLATE = ( "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") + "<|im_start|>assistant\n" +) VISION_PROMPT = { "prompt": VISION_PROMPT_TEMPLATE, - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image - }, + "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, } @@ -70,10 +69,9 @@ async def generate( n=n, prompt_logprobs=prompt_logprobs, ) - async for out in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params): - + async for out in engine.generate( + request_id=request_id, prompt=prompt, sampling_params=sampling_params + ): num_tokens = sum(len(output.token_ids) for output in out.outputs) if output_kind == RequestOutputKind.DELTA: count += num_tokens @@ -89,24 +87,19 @@ async def generate( @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.parametrize( "engine_args,prompt", [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], ) @pytest.mark.asyncio async def test_load( - monkeypatch: pytest.MonkeyPatch, output_kind: RequestOutputKind, engine_args: AsyncEngineArgs, prompt: PromptType, ): - # TODO(rickyx): Remove monkeypatch once we have a better way to test V1 - # so that in the future when we switch, we don't have to change all the - # tests. - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) @@ -121,40 +114,40 @@ async def test_load( for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + generate( + engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS + ) + ) + ) # Confirm that we got all the EXPECTED tokens from the requests. 
- done, pending = await asyncio.wait(tasks, - return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) for task in pending: task.cancel() for task in done: num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + f"expected {NUM_EXPECTED_TOKENS}" + ) assert not engine.output_processor.has_unfinished_requests() @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.parametrize( "engine_args,prompt", [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], ) @pytest.mark.asyncio async def test_abort( - monkeypatch: pytest.MonkeyPatch, output_kind: RequestOutputKind, engine_args: AsyncEngineArgs, prompt: PromptType, ): - - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) @@ -170,14 +163,17 @@ async def test_abort( # Create concurrent requests. tasks: list[asyncio.Task] = [] for idx, request_id in enumerate(request_ids): - max_tokens = (NUM_EXPECTED_TOKENS_LONG if - (idx - in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + max_tokens = ( + NUM_EXPECTED_TOKENS_LONG + if (idx in REQUEST_IDS_TO_ABORT) + else NUM_EXPECTED_TOKENS + ) n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - max_tokens, n))) + generate(engine, request_id, prompt, output_kind, max_tokens, n) + ) + ) # API server cancels requests when they disconnect. for idx in REQUEST_IDS_TO_ABORT: @@ -197,7 +193,8 @@ async def test_abort( expected_tokens = NUM_EXPECTED_TOKENS * n assert num_generated_tokens == expected_tokens, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {expected_tokens}") + f"expected {expected_tokens}" + ) # Make sure all aborted requests were really aborted. assert not engine.output_processor.has_unfinished_requests() @@ -205,24 +202,19 @@ async def test_abort( # Confirm we can do another generation. request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}" task = asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS)) + generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS) + ) num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.asyncio -async def test_multi_abort( - monkeypatch: pytest.MonkeyPatch, - output_kind: RequestOutputKind, -): - - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - +async def test_multi_abort(output_kind: RequestOutputKind): + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) after.callback(engine.shutdown) @@ -238,14 +230,19 @@ async def test_multi_abort( # Create concurrent requests. 
tasks: list[asyncio.Task] = [] for idx, request_id in enumerate(request_ids): - max_tokens = (NUM_EXPECTED_TOKENS_LONG if - (idx - in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + max_tokens = ( + NUM_EXPECTED_TOKENS_LONG + if (idx in REQUEST_IDS_TO_ABORT) + else NUM_EXPECTED_TOKENS + ) n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 tasks.append( asyncio.create_task( - generate(engine, request_id, TEXT_PROMPT, output_kind, - max_tokens, n))) + generate( + engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n + ) + ) + ) # Let requests start await asyncio.sleep(0.5) @@ -261,25 +258,26 @@ async def test_multi_abort( for idx, result in enumerate(results): if idx in REQUEST_IDS_TO_ABORT: # Aborted requests should return partial results - assert isinstance( - result, tuple - ), f"Request {idx} should have completed with partial results" + assert isinstance(result, tuple), ( + f"Request {idx} should have completed with partial results" + ) num_generated_tokens, request_id = result # Should have generated some tokens before abort assert num_generated_tokens > 0, ( - f"Aborted request " - f"{request_id} should have generated some tokens") + f"Aborted request {request_id} should have generated some tokens" + ) else: # Non-aborted requests should complete normally - assert isinstance( - result, - tuple), f"Request {idx} should have completed successfully" + assert isinstance(result, tuple), ( + f"Request {idx} should have completed successfully" + ) num_generated_tokens, request_id = result n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 expected_tokens = NUM_EXPECTED_TOKENS * n assert num_generated_tokens == expected_tokens, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {expected_tokens}") + f"expected {expected_tokens}" + ) # Make sure all aborted requests were cleaned up assert not engine.output_processor.has_unfinished_requests() @@ -292,15 +290,11 @@ async def test_multi_abort( ) @pytest.mark.asyncio async def test_finished_flag( - monkeypatch: pytest.MonkeyPatch, n: int, engine_args: AsyncEngineArgs, prompt: PromptType, ): - - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) @@ -314,9 +308,9 @@ async def test_finished_flag( ) outputs = [ out - async for out in engine.generate(request_id="request-33", - prompt=prompt, - sampling_params=sampling_params) + async for out in engine.generate( + request_id="request-33", prompt=prompt, sampling_params=sampling_params + ) ] # Assert only the last output has the finished flag set @@ -329,13 +323,11 @@ async def test_finished_flag( [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], ) @pytest.mark.asyncio -async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, - engine_args: AsyncEngineArgs, - prompt: PromptType): +async def test_mid_stream_cancellation( + engine_args: AsyncEngineArgs, prompt: PromptType +): """Test that requests can be cancelled mid-stream.""" - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) @@ -358,7 +350,9 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, RequestOutputKind.DELTA, NUM_TOKENS, cancel_after=NUM_EXPECTED_TOKENS, - ))) + ) + ) + ) # Wait for all tasks to 
complete results = await asyncio.gather(*tasks) @@ -367,7 +361,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, for num_generated_tokens, request_id in results: assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} tokens but " - f"expected to cancel after {NUM_EXPECTED_TOKENS}") + f"expected to cancel after {NUM_EXPECTED_TOKENS}" + ) # Make sure no requests are left hanging assert not engine.output_processor.has_unfinished_requests() @@ -375,15 +370,16 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, # Confirm we can reuse the request id after the cancellations. request_id = request_ids[0] task = asyncio.create_task( - generate(engine, request_id, prompt, RequestOutputKind.DELTA, - NUM_EXPECTED_TOKENS)) + generate( + engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS + ) + ) num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() class MockLoggingStatLogger(LoggingStatLogger): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): super().__init__(vllm_config, engine_index) self.log = MagicMock() @@ -393,12 +389,10 @@ class MockLoggingStatLogger(LoggingStatLogger): async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. If a customized logger is provided at the init, it should - be used directly. + be added to the default loggers. """ - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args( TEXT_ENGINE_ARGS, @@ -410,42 +404,46 @@ async def test_customize_loggers(monkeypatch): stat_loggers = engine.logger_manager.per_engine_logger_dict assert len(stat_loggers) == 1 - assert len(stat_loggers[0]) == 1 + assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger stat_loggers[0][0].log.assert_called_once() @pytest.mark.asyncio(scope="module") -async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - +async def test_dp_rank_argument(): + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) after.callback(engine.shutdown) - sampling_params = SamplingParams(max_tokens=100, - output_kind=RequestOutputKind.DELTA, - temperature=1.0, - seed=33) + sampling_params = SamplingParams( + max_tokens=100, + output_kind=RequestOutputKind.DELTA, + temperature=1.0, + seed=33, + ) # Test with valid DP rank. - async for _ in engine.generate(request_id="request-34", - prompt=TEXT_PROMPT, - sampling_params=sampling_params, - data_parallel_rank=0): + async for _ in engine.generate( + request_id="request-34", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=0, + ): pass # Test with out-of-range DP rank. 
with pytest.raises(ValueError): - async for _ in engine.generate(request_id="request-35", - prompt=TEXT_PROMPT, - sampling_params=sampling_params, - data_parallel_rank=1): + async for _ in engine.generate( + request_id="request-35", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=1, + ): pass @pytest.mark.asyncio -async def test_check_health(monkeypatch: pytest.MonkeyPatch): +async def test_check_health(): """Test that check_health returns normally for healthy engine and raises EngineDeadError when the engine is dead. """ @@ -453,9 +451,7 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): from vllm.v1.engine.exceptions import EngineDeadError - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) after.callback(engine.shutdown) @@ -464,10 +460,14 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): await engine.check_health() # Test 2: Mock the errored property to simulate a dead engine - with patch.object(type(engine), - 'errored', - new_callable=lambda: property(lambda self: True) - ), pytest.raises(EngineDeadError): + with ( + patch.object( + type(engine), + "errored", + new_callable=lambda: property(lambda self: True), + ), + pytest.raises(EngineDeadError), + ): await engine.check_health() # Test 3: Verify healthy engine still works after mock @@ -475,17 +475,13 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.asyncio -async def test_abort_final_output( - monkeypatch: pytest.MonkeyPatch, - output_kind: RequestOutputKind, -): +async def test_abort_final_output(output_kind: RequestOutputKind): """Test that abort() returns a final output with correct information.""" - with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") - + with ExitStack() as after: with set_default_torch_num_threads(1): engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) after.callback(engine.shutdown) @@ -503,8 +499,8 @@ async def test_abort_final_output( outputs: list[RequestOutput] = [] generated = asyncio.create_task( - collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, - outputs)) + collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs) + ) # Let it generate some tokens await asyncio.sleep(0.5) @@ -524,14 +520,13 @@ async def test_abort_final_output( assert final_output.outputs[0].stop_reason is None # Verify num_cached_tokens is set correctly - assert hasattr(final_output, 'num_cached_tokens') + assert hasattr(final_output, "num_cached_tokens") assert final_output.num_cached_tokens >= 0 # If we got intermediate outputs, verify they are consistent if output_kind == RequestOutputKind.DELTA: # For DELTA, sum all intermediate tokens should <= final tokens - token_count = sum( - len(output.outputs[0].token_ids) for output in outputs) + token_count = sum(len(output.outputs[0].token_ids) for output in outputs) assert token_count > 0 # This would ordinarily be 0, but could end up > 0 if the # final abort is coalesced with another chunk in the output queue. 
@@ -553,9 +548,9 @@ async def collect_outputs( ) -> Optional[RequestOutput]: """Helper to collect outputs and return the final one.""" final_output: Optional[RequestOutput] = None - async for output in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params): + async for output in engine.generate( + request_id=request_id, prompt=prompt, sampling_params=sampling_params + ): if not output.finished: outputs_list.append(output) final_output = output diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index f70a3ce147ff2..943402e429b6a 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -5,25 +5,19 @@ from argparse import ArgumentError import pytest -from vllm import envs from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser -if not envs.VLLM_USE_V1: - pytest.skip( - "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", - allow_module_level=True, - ) - def test_prefix_caching_from_cli(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert (vllm_config.cache_config.enable_prefix_caching - ), "V1 turns on prefix caching by default." + assert vllm_config.cache_config.enable_prefix_caching, ( + "V1 turns on prefix caching by default." + ) # Turn it off possible with flag. args = parser.parse_args(["--no-enable-prefix-caching"]) @@ -36,18 +30,18 @@ def test_prefix_caching_from_cli(): assert vllm_config.cache_config.enable_prefix_caching # default hash algorithm is "builtin" - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" + + # set hash algorithm to sha256_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor" # set hash algorithm to sha256 args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" - # set hash algorithm to builtin - args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"]) - vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" - # an invalid hash algorithm raises an error parser.exit_on_error = False with pytest.raises(ArgumentError): @@ -56,10 +50,10 @@ def test_prefix_caching_from_cli(): def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") - vllm_config: VllmConfig = engine_args.create_engine_config( - UsageContext.LLM_CLASS) + vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) from vllm.platforms import current_platform + device_name = current_platform.get_device_name().lower() if "h100" in device_name or "h200" in device_name: # For H100 and H200, we use larger default values. 
@@ -75,7 +69,6 @@ def test_defaults_with_usage_context(): assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens # noqa: E501 engine_args = EngineArgs(model="facebook/opt-125m") - vllm_config = engine_args.create_engine_config( - UsageContext.OPENAI_API_SERVER) + vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens # noqa: E501 diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 2ea957a3e230f..997b2b74bb6b5 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -22,8 +22,7 @@ from vllm.v1.outputs import ModelRunnerOutput from ...utils import create_new_process_for_each_test, multi_gpu_test if not current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -35,9 +34,7 @@ def make_request() -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=SamplingParams(), pooling_params=None, eos_token_id=None, @@ -49,208 +46,196 @@ def make_request() -> EngineCoreRequest: @create_new_process_for_each_test() -def test_engine_core(monkeypatch: pytest.MonkeyPatch): +def test_engine_core(): + """Setup the EngineCore.""" + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - """Setup the EngineCore.""" - engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() - executor_class = Executor.get_class(vllm_config) + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) + """Test basic request lifecycle.""" - with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) - """Test basic request lifecycle.""" + # First request. + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 - # First request. - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 0 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 1 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 1 + # Second request. + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 1 - # Second request. 
- engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 1 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 2 + # Add two requests in a row. + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) + assert len(engine_core.scheduler.waiting) == 2 + assert len(engine_core.scheduler.running) == 2 - # Add two requests in a row. - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) - assert len(engine_core.scheduler.waiting) == 2 - assert len(engine_core.scheduler.running) == 2 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 4 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 4 + # Loop through until they are all done. + while (outs := engine_core.step()[0].get(0)) and outs.outputs: + pass - # Loop through until they are all done. - while (outs := engine_core.step()[0].get(0)) and outs.outputs: - pass + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 + """Test abort cycle.""" - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 0 - """Test abort cycle.""" + # Basic abort. + req = make_request() + request_id = req.request_id - # Basic abort. 
- req = make_request() - request_id = req.request_id + engine_core.add_request(*engine_core.preprocess_add_request(req)) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + assert engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() - engine_core.add_request(*engine_core.preprocess_add_request(req)) - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 0 - assert engine_core.scheduler.has_unfinished_requests() - assert not engine_core.scheduler.has_finished_requests() + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 1 + assert engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 1 - assert engine_core.scheduler.has_unfinished_requests() - assert not engine_core.scheduler.has_finished_requests() + engine_core.abort_requests([request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 + assert not engine_core.scheduler.has_unfinished_requests() + assert engine_core.scheduler.has_finished_requests() - engine_core.abort_requests([request_id]) - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 0 - assert not engine_core.scheduler.has_unfinished_requests() - assert engine_core.scheduler.has_finished_requests() + _ = engine_core.step() + assert not engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() - _ = engine_core.step() - assert not engine_core.scheduler.has_unfinished_requests() - assert not engine_core.scheduler.has_finished_requests() + # Add, step, abort 1 of the 3. + req0 = make_request() + req1 = make_request() + req2 = make_request() - # Add, step, abort 1 of the 3. - req0 = make_request() - req1 = make_request() - req2 = make_request() + engine_core.add_request(*engine_core.preprocess_add_request(req0)) + engine_core.add_request(*engine_core.preprocess_add_request(req1)) + assert len(engine_core.scheduler.waiting) == 2 + assert len(engine_core.scheduler.running) == 0 - engine_core.add_request(*engine_core.preprocess_add_request(req0)) - engine_core.add_request(*engine_core.preprocess_add_request(req1)) - assert len(engine_core.scheduler.waiting) == 2 - assert len(engine_core.scheduler.running) == 0 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 2 + engine_core.add_request(*engine_core.preprocess_add_request(req2)) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 2 - engine_core.add_request(*engine_core.preprocess_add_request(req2)) - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 2 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 3 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 3 + # Abort just one. 
+ engine_core.abort_requests([req1.request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 - # Abort just one. - engine_core.abort_requests([req1.request_id]) - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 2 + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 - _ = engine_core.step() - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 2 + # Abort the other requests at the same time. + engine_core.abort_requests([req2.request_id, req0.request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 - # Abort the other requests at the same time. - engine_core.abort_requests([req2.request_id, req0.request_id]) - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 0 + # Sending duplicate requests with same request_id + req0 = make_request() + req1 = make_request() + req0.request_id = req1.request_id = "test" + engine_core.add_request(*engine_core.preprocess_add_request(req0)) - # Sending duplicate requests with same request_id - req0 = make_request() - req1 = make_request() - req0.request_id = req1.request_id = "test" - engine_core.add_request(*engine_core.preprocess_add_request(req0)) + while (outs := engine_core.step()[0].get(0)) and outs.outputs: + pass - while (outs := engine_core.step()[0].get(0)) and outs.outputs: - pass + engine_core.add_request(*engine_core.preprocess_add_request(req1)) + while (outs := engine_core.step()[0].get(0)) and outs.outputs: + pass - engine_core.add_request(*engine_core.preprocess_add_request(req1)) - while (outs := engine_core.step()[0].get(0)) and outs.outputs: - pass - - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 0 + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 @create_new_process_for_each_test() -def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): +def test_engine_core_advanced_sampling(): """ A basic end-to-end test to verify that the engine functions correctly when additional sampling parameters, such as top_p, min_tokens, and presence_penalty, are set. """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - """Setup the EngineCore.""" - engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() - executor_class = Executor.get_class(vllm_config) + """Setup the EngineCore.""" + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) - with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) - """Test basic request lifecycle.""" - # First request. - request: EngineCoreRequest = make_request() - request.sampling_params = SamplingParams( - min_tokens=4, - presence_penalty=1.0, - frequency_penalty=1.0, - repetition_penalty=0.1, - stop_token_ids=[1001, 1002], + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True ) - engine_core.add_request(*engine_core.preprocess_add_request(request)) + """Test basic request lifecycle.""" + # First request. 
+ request: EngineCoreRequest = make_request() + request.sampling_params = SamplingParams( + min_tokens=4, + presence_penalty=1.0, + frequency_penalty=1.0, + repetition_penalty=0.1, + stop_token_ids=[1001, 1002], + ) + engine_core.add_request(*engine_core.preprocess_add_request(request)) - def _check_engine_state(): - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 0 - # Loop through until they are all done. - while (outs := engine_core.step()[0].get(0)) and outs.outputs: - pass - assert len(engine_core.scheduler.waiting) == 0 - assert len(engine_core.scheduler.running) == 0 + def _check_engine_state(): + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + # Loop through until they are all done. + while (outs := engine_core.step()[0].get(0)) and outs.outputs: + pass + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 - _check_engine_state() + _check_engine_state() - # Second request. - request2 = make_request() - request2.sampling_params = SamplingParams( - top_p=0.99, - top_k=50, - ) - engine_core.add_request(*engine_core.preprocess_add_request(request2)) - _check_engine_state() + # Second request. + request2 = make_request() + request2.sampling_params = SamplingParams( + top_p=0.99, + top_k=50, + ) + engine_core.add_request(*engine_core.preprocess_add_request(request2)) + _check_engine_state() @create_new_process_for_each_test() -def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): +def test_engine_core_concurrent_batches(): """ Test that the engine can handle multiple concurrent batches. """ - def make_request_with_max_tokens(req_id: str, - max_tokens: int) -> EngineCoreRequest: + def make_request_with_max_tokens(req_id: str, max_tokens: int) -> EngineCoreRequest: request = make_request() request.request_id = req_id request.sampling_params.max_tokens = max_tokens return request class DummyExecutor(UniProcExecutor): - - def initialize_from_config( - self, kv_cache_configs: list[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: super().initialize_from_config(kv_cache_configs) # Create a thread pool with a single worker @@ -259,12 +244,15 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): def execute_model( self, scheduler_output, + non_block=False, ) -> Future[ModelRunnerOutput]: """Make execute_model non-blocking.""" + # DummyExecutor used only for testing async case. + assert non_block + def _execute(): - output = self.collective_rpc("execute_model", - args=(scheduler_output, )) + output = self.collective_rpc("execute_model", args=(scheduler_output,)) # Make a copy because output[0] may be reused # by the next batch. return copy.deepcopy(output[0]) @@ -277,193 +265,166 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): return 2 def shutdown(self): - if hasattr(self, 'thread_pool'): + if hasattr(self, "thread_pool"): self.thread_pool.shutdown(wait=False) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - engine_args = EngineArgs( - model=MODEL_NAME, - # To test concurrent batches. - max_num_seqs=2, - # Avoid all requests being scheduled once. - enable_prefix_caching=False, - max_num_batched_tokens=10, - # Reduce startup time. - enforce_eager=True, + engine_args = EngineArgs( + model=MODEL_NAME, + # To test concurrent batches. + max_num_seqs=2, + # Avoid all requests being scheduled once. 
+ enable_prefix_caching=False, + max_num_batched_tokens=10, + # Reduce startup time. + enforce_eager=True, + ) + vllm_config = engine_args.create_engine_config() + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, log_stats=False, executor_class=DummyExecutor ) - vllm_config = engine_args.create_engine_config() - with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - log_stats=False, - executor_class=DummyExecutor) - assert engine_core.batch_queue is not None + assert engine_core.batch_queue is not None - # Add two requests in a row. Each request have 12 prompt tokens. - req0 = make_request_with_max_tokens("0", 5) - engine_core.add_request(*engine_core.preprocess_add_request(req0)) - req1 = make_request_with_max_tokens("1", 5) - engine_core.add_request(*engine_core.preprocess_add_request(req1)) + # Add two requests in a row. Each request have 12 prompt tokens. + req0 = make_request_with_max_tokens("0", 5) + engine_core.add_request(*engine_core.preprocess_add_request(req0)) + req1 = make_request_with_max_tokens("1", 5) + engine_core.add_request(*engine_core.preprocess_add_request(req1)) - # Schedule Batch 1: (10, req0) - assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 1 - scheduler_output = engine_core.batch_queue.queue[-1][1] - assert scheduler_output.num_scheduled_tokens["0"] == 10 - # num_computed_tokens should have been updated immediately. - assert engine_core.scheduler.requests[ - req0.request_id].num_computed_tokens == 10 + # Schedule Batch 1: (10, req0) + assert engine_core.step_with_batch_queue()[0] is None + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] + assert scheduler_output.num_scheduled_tokens["0"] == 10 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[req0.request_id].num_computed_tokens == 10 - # Schedule Batch 2: (2, req0), (8, req1) - assert engine_core.step_with_batch_queue()[0] is None - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] - assert scheduler_output.num_scheduled_tokens["0"] == 2 - assert scheduler_output.num_scheduled_tokens["1"] == 8 - # num_computed_tokens should have been updated immediately. - assert engine_core.scheduler.requests["0"].num_computed_tokens == 12 - assert engine_core.scheduler.requests["1"].num_computed_tokens == 8 + # Schedule Batch 2: (2, req0), (8, req1) + assert engine_core.step_with_batch_queue()[0] == {} + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] + assert scheduler_output.num_scheduled_tokens["0"] == 2 + assert scheduler_output.num_scheduled_tokens["1"] == 8 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests["0"].num_computed_tokens == 12 + assert engine_core.scheduler.requests["1"].num_computed_tokens == 8 - assert engine_core.scheduler.get_num_unfinished_requests() == 2 + assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Batch queue is full. Finish Batch 1. - engine_core.step_with_batch_queue() + # Finish Batch 1 and schedule Batch 3: (4, req1). + # Note that req0 cannot be scheduled + # because it is in the decoding stage now. 
+ engine_core.step_with_batch_queue() + assert len(engine_core.batch_queue) == 1 + scheduler_output = engine_core.batch_queue[-1][1] + assert scheduler_output.num_scheduled_tokens["1"] == 4 - # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled - # because it is in the decoding stage now. - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] - assert scheduler_output.num_scheduled_tokens["1"] == 4 + # Finish Batch 2. Get first token of req0. + # Schedule Batch 4: (1, req0). + output = engine_core.step_with_batch_queue()[0].get(0) + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 + scheduler_output = engine_core.batch_queue[-1][1] + assert scheduler_output.num_scheduled_tokens["0"] == 1 - # Batch queue is full. Finish Batch 2. Get first token of req0. - output = engine_core.step_with_batch_queue()[0].get(0) + # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1). + output = engine_core.step_with_batch_queue()[0].get(0) + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 + scheduler_output = engine_core.batch_queue[-1][1] + assert scheduler_output.num_scheduled_tokens["1"] == 1 + + # Loop until req0 is finished. + req_id = 0 + expected_num_tokens = [ + engine_core.scheduler.requests["0"].num_tokens + 1, + engine_core.scheduler.requests["1"].num_tokens + 1, + ] + while engine_core.scheduler.get_num_unfinished_requests() == 2: + output = engine_core.step_with_batch_queue()[0] + # Every step consumes an output. assert output is not None - assert len(output.outputs) == 1 - assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 - - # Schedule Batch 4: (1, req0). - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] - assert scheduler_output.num_scheduled_tokens["0"] == 1 - - # Batch queue is full. Finish Batch 3. Get first token of req1. - output = engine_core.step_with_batch_queue()[0].get(0) - assert output is not None - assert len(output.outputs) == 1 - assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 - - # Schedule Batch 5: (1, req1). - engine_core.step_with_batch_queue() - assert engine_core.batch_queue.qsize() == 2 - scheduler_output = engine_core.batch_queue.queue[-1][1] - assert scheduler_output.num_scheduled_tokens["1"] == 1 - - # Loop until req0 is finished. - step = 0 - req_id = 0 - expected_num_tokens = [ - engine_core.scheduler.requests["0"].num_tokens + 1, - engine_core.scheduler.requests["1"].num_tokens + 1, - ] - while engine_core.scheduler.get_num_unfinished_requests() == 2: - output = engine_core.step_with_batch_queue()[0] - if step % 2 == 0: - # Even steps consumes an output. - assert output is not None - assert len(output[0].outputs) == 1 - if req_id in engine_core.scheduler.requests: - assert engine_core.scheduler.requests[ - req_id].num_tokens == expected_num_tokens[req_id] - expected_num_tokens[req_id] += 1 - req_id = (req_id + 1) % 2 - else: - # Odd steps schedules a new batch. 
- assert output is None - step += 1 + assert len(output[0].outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert ( + engine_core.scheduler.requests[req_id].num_tokens + == expected_num_tokens[req_id] + ) + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 @multi_gpu_test(num_gpus=2) -def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch): +def test_engine_core_tp(): """ Test engine can initialize worker in tp properly """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - """Setup the EngineCore.""" - engine_args = EngineArgs( - model=MODEL_NAME, - tensor_parallel_size=2, - # Reduce startup time. - enforce_eager=True, + """Setup the EngineCore.""" + engine_args = EngineArgs( + model=MODEL_NAME, + tensor_parallel_size=2, + # Reduce startup time. + enforce_eager=True, + ) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) + + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True ) - vllm_config = engine_args.create_engine_config() - executor_class = Executor.get_class(vllm_config) - with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + def get_worker_cache_config_field(worker, key: str): + return getattr(worker.cache_config, key) - def get_worker_cache_config_field(worker, key: str): - return getattr(worker.cache_config, key) - - num_gpu_blocks = engine_core.collective_rpc( - get_worker_cache_config_field, args=("num_gpu_blocks", )) - num_cpu_blocks = engine_core.collective_rpc( - get_worker_cache_config_field, args=("num_cpu_blocks", )) - assert all(x is not None for x in num_gpu_blocks) - assert all(x is not None for x in num_cpu_blocks) + num_gpu_blocks = engine_core.collective_rpc( + get_worker_cache_config_field, args=("num_gpu_blocks",) + ) + num_cpu_blocks = engine_core.collective_rpc( + get_worker_cache_config_field, args=("num_cpu_blocks",) + ) + assert all(x is not None for x in num_gpu_blocks) + assert all(x is not None for x in num_cpu_blocks) @create_new_process_for_each_test() -def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch): +def test_engine_core_invalid_request_id_type(): """Test that engine raises TypeError for non-string request_id.""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = Executor.get_class(vllm_config) - engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() - executor_class = Executor.get_class(vllm_config) + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) - with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + # Test with UUID object (common mistake) + uuid_request = make_request() + uuid_request.request_id = uuid.uuid4() # UUID object instead of string - # Test with UUID object (common mistake) - uuid_request = make_request() - uuid_request.request_id = uuid.uuid4() # UUID object instead of string + with pytest.raises(TypeError, match="request_id must be a string, got.*UUID"): + engine_core.add_request(*engine_core.preprocess_add_request(uuid_request)) - with pytest.raises(TypeError, - match="request_id must be 
a string, got.*UUID"): - engine_core.add_request( - *engine_core.preprocess_add_request(uuid_request)) + # Test with integer + int_request = make_request() + int_request.request_id = 12345 - # Test with integer - int_request = make_request() - int_request.request_id = 12345 + with pytest.raises(TypeError, match="request_id must be a string, got.*int"): + engine_core.add_request(*engine_core.preprocess_add_request(int_request)) - with pytest.raises(TypeError, - match="request_id must be a string, got.*int"): - engine_core.add_request( - *engine_core.preprocess_add_request(int_request)) + # Test with None + none_request = make_request() + none_request.request_id = None - # Test with None - none_request = make_request() - none_request.request_id = None + with pytest.raises(TypeError, match="request_id must be a string, got.*NoneType"): + engine_core.add_request(*engine_core.preprocess_add_request(none_request)) - with pytest.raises(TypeError, - match="request_id must be a string, got.*NoneType"): - engine_core.add_request( - *engine_core.preprocess_add_request(none_request)) - - # Verify engine is still functional after errors - valid_request = make_request() - engine_core.add_request( - *engine_core.preprocess_add_request(valid_request)) - assert len(engine_core.scheduler.waiting) == 1 - assert len(engine_core.scheduler.running) == 0 + # Verify engine is still functional after errors + valid_request = make_request() + engine_core.add_request(*engine_core.preprocess_add_request(valid_request)) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 37eb869fe69a3..bc04d1f93f951 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -8,7 +8,7 @@ import time import uuid from dataclasses import dataclass from threading import Thread -from typing import Optional, Union +from typing import Any, Optional, Union from unittest.mock import MagicMock import pytest @@ -17,16 +17,14 @@ from transformers import AutoTokenizer from tests.utils import multi_gpu_test from vllm import SamplingParams -from vllm.distributed.kv_events import (BlockStored, KVEventBatch, - ZmqEventPublisher) +from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext from vllm.utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, - SyncMPClient) +from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.abstract import Executor @@ -34,8 +32,7 @@ from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test if not current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -44,17 +41,15 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids def make_request( - params: SamplingParams, - prompt_tokens_ids: Optional[list[int]] = None) -> 
EngineCoreRequest: + params: SamplingParams, prompt_tokens_ids: Optional[list[int]] = None +) -> EngineCoreRequest: if not prompt_tokens_ids: prompt_tokens_ids = PROMPT_TOKENS return EngineCoreRequest( request_id=str(uuid.uuid4()), prompt_token_ids=prompt_tokens_ids, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, sampling_params=params, pooling_params=None, eos_token_id=None, @@ -66,7 +61,6 @@ def make_request( def loop_until_done(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = client.get_output().outputs @@ -84,7 +78,6 @@ def loop_until_done(client: EngineCoreClient, outputs: dict): async def loop_until_done_async(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = (await client.get_output_async()).outputs @@ -102,7 +95,6 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict): async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = (await client.get_output_async()).outputs @@ -121,10 +113,9 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): # Dummy utility function to monkey-patch into engine core. -def echo(self, - msg: str, - err_msg: Optional[str] = None, - sleep: Optional[float] = None) -> str: +def echo( + self, msg: str, err_msg: Optional[str] = None, sleep: Optional[float] = None +) -> str: print(f"echo util function called: {msg}, {err_msg}") if sleep is not None: time.sleep(sleep) @@ -135,18 +126,15 @@ def echo(self, @create_new_process_for_each_test() @pytest.mark.parametrize("multiprocessing_mode", [True, False]) -def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, - multiprocessing_mode: bool): - +def test_engine_core_client( + monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool +): with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - # Monkey-patch core engine utility function to test. m.setattr(EngineCore, "echo", echo, raising=False) engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -174,7 +162,8 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, for req_id in request_ids: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{outputs[req_id]=}, {MAX_TOKENS=}") + f"{outputs[req_id]=}, {MAX_TOKENS=}" + ) """Abort Request Cycle.""" # Note: this code pathway will only work for multiprocessing @@ -193,10 +182,12 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, for idx, req_id in enumerate(request_ids): if idx % 2 == 0: assert len(outputs[req_id]) < MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) else: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) """Abort after request is finished.""" # Note: this code pathway will only work for multiprocessing @@ -204,7 +195,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, request = requests[0] client.add_request(request) - time.sleep(10.) 
+ time.sleep(10.0) client.abort_requests([request.request_id]) @@ -224,16 +215,14 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - # Monkey-patch core engine utility function to test. m.setattr(EngineCore, "echo", echo, raising=False) engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -263,7 +252,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): for req_id in request_ids: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{outputs[req_id]=}, {MAX_TOKENS=}") + f"{outputs[req_id]=}, {MAX_TOKENS=}" + ) """Abort Request Cycle.""" # Add requests to the engine. @@ -279,10 +269,12 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): for idx, req_id in enumerate(request_ids): if idx % 2 == 0: assert len(outputs[req_id]) < MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) else: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) """Utility method invocation""" core_client: AsyncMPClient = client @@ -298,8 +290,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): # Test that cancelling the utility call doesn't destabilize the # engine. util_task = asyncio.create_task( - core_client.call_utility_async("echo", "testarg2", None, - 0.5)) # sleep for 0.5 sec + core_client.call_utility_async("echo", "testarg2", None, 0.5) + ) # sleep for 0.5 sec await asyncio.sleep(0.05) cancelled = util_task.cancel() assert cancelled @@ -307,9 +299,9 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): # Ensure client is still functional. The engine runs utility # methods in a single thread so this request won't be processed # until the cancelled sleeping one is complete. - result = await asyncio.wait_for(core_client.call_utility_async( - "echo", "testarg3"), - timeout=1.0) + result = await asyncio.wait_for( + core_client.call_utility_async("echo", "testarg3"), timeout=1.0 + ) assert result == "testarg3" finally: client.shutdown() @@ -333,13 +325,50 @@ def echo_dc( return [val for _ in range(3)] if return_list else val +# Dummy utility function to test dict serialization with custom types. +def echo_dc_dict( + self, + msg: str, + return_dict: bool = False, +) -> Union[MyDataclass, dict[str, MyDataclass]]: + print(f"echo dc dict util function called: {msg}") + val = None if msg is None else MyDataclass(msg) + # Return dict of dataclasses to verify support for returning dicts + # with custom value types. + if return_dict: + return {"key1": val, "key2": val, "key3": val} + else: + return val + + +# Dummy utility function to test nested structures with custom types. 
+def echo_dc_nested( + self, + msg: str, + structure_type: str = "list_of_dicts", +) -> Any: + print(f"echo dc nested util function called: {msg}, structure: {structure_type}") + val = None if msg is None else MyDataclass(msg) + + if structure_type == "list_of_dicts": # noqa + # Return list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}] + return [{"a": val, "b": val}, {"c": val, "d": val}] + elif structure_type == "dict_of_lists": + # Return dict of lists: {"list1": [val, val], "list2": [val, val]} + return {"list1": [val, val], "list2": [val, val]} + elif structure_type == "deep_nested": + # Return deeply nested: {"outer": [{"inner": [val, val]}, + # {"inner": [val]}]} + return {"outer": [{"inner": [val, val]}, {"inner": [val]}]} + else: + return val + + @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_util_method_custom_return( - monkeypatch: pytest.MonkeyPatch): - + monkeypatch: pytest.MonkeyPatch, +): with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - # Must set insecure serialization to allow returning custom types. m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @@ -348,7 +377,8 @@ async def test_engine_core_client_util_method_custom_return( engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -364,103 +394,259 @@ async def test_engine_core_client_util_method_custom_return( # Test utility method returning custom / non-native data type. core_client: AsyncMPClient = client - result = await core_client.call_utility_async( - "echo_dc", "testarg2", False) - assert isinstance(result, - MyDataclass) and result.message == "testarg2" - result = await core_client.call_utility_async( - "echo_dc", "testarg2", True) + result = await core_client.call_utility_async("echo_dc", "testarg2", False) + assert isinstance(result, MyDataclass) and result.message == "testarg2" + result = await core_client.call_utility_async("echo_dc", "testarg2", True) assert isinstance(result, list) and all( - isinstance(r, MyDataclass) and r.message == "testarg2" - for r in result) + isinstance(r, MyDataclass) and r.message == "testarg2" for r in result + ) # Test returning None and list of Nones - result = await core_client.call_utility_async( - "echo_dc", None, False) + result = await core_client.call_utility_async("echo_dc", None, False) assert result is None - result = await core_client.call_utility_async( - "echo_dc", None, True) + result = await core_client.call_utility_async("echo_dc", None, True) assert isinstance(result, list) and all(r is None for r in result) finally: client.shutdown() +@pytest.mark.asyncio(loop_scope="function") +async def test_engine_core_client_util_method_custom_dict_return( + monkeypatch: pytest.MonkeyPatch, +): + with monkeypatch.context() as m: + # Must set insecure serialization to allow returning custom types. + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # Monkey-patch core engine utility function to test. 
+ m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False) + + engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT + ) + executor_class = Executor.get_class(vllm_config) + + with set_default_torch_num_threads(1): + client = EngineCoreClient.make_client( + multiprocess_mode=True, + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=True, + ) + + try: + # Test utility method returning custom / non-native data type. + core_client: AsyncMPClient = client + + # Test single object return + result = await core_client.call_utility_async( + "echo_dc_dict", "testarg3", False + ) + assert isinstance(result, MyDataclass) and result.message == "testarg3" + + # Test dict return with custom value types + result = await core_client.call_utility_async( + "echo_dc_dict", "testarg3", True + ) + assert isinstance(result, dict) and len(result) == 3 + for key, val in result.items(): + assert key in ["key1", "key2", "key3"] + assert isinstance(val, MyDataclass) and val.message == "testarg3" + + # Test returning dict with None values + result = await core_client.call_utility_async("echo_dc_dict", None, True) + assert isinstance(result, dict) and len(result) == 3 + for key, val in result.items(): + assert key in ["key1", "key2", "key3"] + assert val is None + + finally: + client.shutdown() + + +@pytest.mark.asyncio(loop_scope="function") +async def test_engine_core_client_util_method_nested_structures( + monkeypatch: pytest.MonkeyPatch, +): + with monkeypatch.context() as m: + # Must set insecure serialization to allow returning custom types. + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # Monkey-patch core engine utility function to test. 
+ m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False) + + engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT + ) + executor_class = Executor.get_class(vllm_config) + + with set_default_torch_num_threads(1): + client = EngineCoreClient.make_client( + multiprocess_mode=True, + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=True, + ) + + try: + core_client: AsyncMPClient = client + + # Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}] + result = await core_client.call_utility_async( + "echo_dc_nested", "nested1", "list_of_dicts" + ) + assert isinstance(result, list) and len(result) == 2 + for i, item in enumerate(result): + assert isinstance(item, dict) + if i == 0: + assert "a" in item and "b" in item + assert ( + isinstance(item["a"], MyDataclass) + and item["a"].message == "nested1" + ) + assert ( + isinstance(item["b"], MyDataclass) + and item["b"].message == "nested1" + ) + else: + assert "c" in item and "d" in item + assert ( + isinstance(item["c"], MyDataclass) + and item["c"].message == "nested1" + ) + assert ( + isinstance(item["d"], MyDataclass) + and item["d"].message == "nested1" + ) + + # Test dict of lists: {"list1": [val, val], "list2": [val, val]} + result = await core_client.call_utility_async( + "echo_dc_nested", "nested2", "dict_of_lists" + ) + assert isinstance(result, dict) and len(result) == 2 + assert "list1" in result and "list2" in result + for key, lst in result.items(): + assert isinstance(lst, list) and len(lst) == 2 + for item in lst: + assert isinstance(item, MyDataclass) and item.message == "nested2" + + # Test deeply nested: {"outer": [{"inner": [val, val]}, + # {"inner": [val]}]} + result = await core_client.call_utility_async( + "echo_dc_nested", "nested3", "deep_nested" + ) + assert isinstance(result, dict) and "outer" in result + outer_list = result["outer"] + assert isinstance(outer_list, list) and len(outer_list) == 2 + + # First dict in outer list should have "inner" with 2 items + inner_dict1 = outer_list[0] + assert isinstance(inner_dict1, dict) and "inner" in inner_dict1 + inner_list1 = inner_dict1["inner"] + assert isinstance(inner_list1, list) and len(inner_list1) == 2 + for item in inner_list1: + assert isinstance(item, MyDataclass) and item.message == "nested3" + + # Second dict in outer list should have "inner" with 1 item + inner_dict2 = outer_list[1] + assert isinstance(inner_dict2, dict) and "inner" in inner_dict2 + inner_list2 = inner_dict2["inner"] + assert isinstance(inner_list2, list) and len(inner_list2) == 1 + assert ( + isinstance(inner_list2[0], MyDataclass) + and inner_list2[0].message == "nested3" + ) + + # Test with None values in nested structures + result = await core_client.call_utility_async( + "echo_dc_nested", None, "list_of_dicts" + ) + assert isinstance(result, list) and len(result) == 2 + for item in result: + assert isinstance(item, dict) + for val in item.values(): + assert val is None + + finally: + client.shutdown() + + @pytest.mark.parametrize( "multiprocessing_mode,publisher_config", [(True, "tcp"), (False, "inproc")], indirect=["publisher_config"], ) def test_kv_cache_events( - monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool, publisher_config, ): + block_size = 16 + num_blocks = 2 - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - block_size = 16 - num_blocks = 2 + engine_args = EngineArgs( + 
model=MODEL_NAME, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size, + ) + engine_args.kv_events_config = publisher_config - engine_args = EngineArgs( - model=MODEL_NAME, - enforce_eager=True, - enable_prefix_caching=True, - block_size=block_size, + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) + + executor_class = Executor.get_class(vllm_config) + with set_default_torch_num_threads(1): + client = EngineCoreClient.make_client( + multiprocess_mode=multiprocessing_mode, + asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, ) - engine_args.kv_events_config = publisher_config + endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") + subscriber = MockSubscriber( + endpoint, topic=publisher_config.topic, decode_type=KVEventBatch + ) - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + try: + custom_tokens = list(range(num_blocks * block_size)) + sampling_params = SamplingParams(max_tokens=1) + request = make_request(sampling_params, custom_tokens) + client.add_request(request) - executor_class = Executor.get_class(vllm_config) - with set_default_torch_num_threads(1): - client = EngineCoreClient.make_client( - multiprocess_mode=multiprocessing_mode, - asyncio_mode=False, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=False, - ) - endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") - subscriber = MockSubscriber(endpoint, - topic=publisher_config.topic, - decode_type=KVEventBatch) + outputs: dict[str, list] = {request.request_id: []} + loop_until_done(client, outputs) - try: - custom_tokens = list(range(num_blocks * block_size)) - sampling_params = SamplingParams(max_tokens=1) - request = make_request(sampling_params, custom_tokens) - client.add_request(request) + result = subscriber.receive_one(timeout=1000) + assert result is not None, "No message received" - outputs: dict[str, list] = {request.request_id: []} - loop_until_done(client, outputs) + seq, received = result - result = subscriber.receive_one(timeout=1000) - assert result is not None, "No message received" - - seq, received = result - - assert seq == 0, "Sequence number mismatch" - assert (len(received.events) == 1 - ), "We should have exactly one BlockStored event" - event = received.events[0] - assert isinstance( - event, BlockStored), "We should have a BlockStored event" - assert (len(event.block_hashes) == num_blocks - ), "We should have a BlockStored event with 2 block_hashes" - assert (event.block_size == block_size - ), "Block size should be the same as the block size" - assert (event.parent_block_hash - is None), "Parent block hash should be None" - assert event.lora_id is None, "Lora id should be None" - assert (len(event.token_ids) == num_blocks * block_size - ), "Token ids should be the same as the custom tokens" - assert (event.token_ids == custom_tokens - ), "Token ids should be the same as the custom tokens" - finally: - client.shutdown() - subscriber.close() + assert seq == 0, "Sequence number mismatch" + assert len(received.events) == 1, "We should have exactly one BlockStored event" + event = received.events[0] + assert isinstance(event, BlockStored), "We should have a BlockStored event" + assert len(event.block_hashes) == num_blocks, ( + "We should have a BlockStored event with 2 block_hashes" + ) + assert event.block_size == block_size, ( + "Block size should be the same as the block size" + ) + assert event.parent_block_hash is None, 
"Parent block hash should be None" + assert event.lora_id is None, "Lora id should be None" + assert len(event.token_ids) == num_blocks * block_size, ( + "Token ids should be the same as the custom tokens" + ) + assert event.token_ids == custom_tokens, ( + "Token ids should be the same as the custom tokens" + ) + finally: + client.shutdown() + subscriber.close() @pytest.mark.asyncio @@ -471,110 +657,96 @@ def test_kv_cache_events( ) @multi_gpu_test(num_gpus=4) async def test_kv_cache_events_dp( - monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool, publisher_config, ): + block_size = 16 + num_blocks = 2 + dp_size = 2 + tp_size = 2 - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - block_size = 16 - num_blocks = 2 - dp_size = 2 - tp_size = 2 + engine_args = EngineArgs( + model=MODEL_NAME, + enforce_eager=True, + enable_prefix_caching=True, + data_parallel_size=dp_size, + tensor_parallel_size=tp_size, + block_size=block_size, + ) + engine_args.kv_events_config = publisher_config - engine_args = EngineArgs( - model=MODEL_NAME, - enforce_eager=True, - enable_prefix_caching=True, - data_parallel_size=dp_size, - tensor_parallel_size=tp_size, - block_size=block_size, + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) + + executor_class = Executor.get_class(vllm_config) + with set_default_torch_num_threads(1): + client = EngineCoreClient.make_client( + multiprocess_mode=multiprocessing_mode, + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, ) - engine_args.kv_events_config = publisher_config + await asyncio.sleep(1) - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + # Build endpoints for all DP ranks + base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") + endpoints = [] + for i in range(dp_size): + offset_endpoint = ZmqEventPublisher.offset_endpoint_port(base_endpoint, i) + endpoints.append(offset_endpoint) - executor_class = Executor.get_class(vllm_config) - with set_default_torch_num_threads(1): - client = EngineCoreClient.make_client( - multiprocess_mode=multiprocessing_mode, - asyncio_mode=True, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=False, - ) - await asyncio.sleep(1) + subscriber = MockSubscriber( + endpoints, topic=publisher_config.topic, decode_type=KVEventBatch + ) - # Build endpoints for all DP ranks - base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") - endpoints = [] - for i in range(dp_size): - offset_endpoint = ZmqEventPublisher.offset_endpoint_port( - base_endpoint, i) - endpoints.append(offset_endpoint) + try: + custom_tokens = list(range(num_blocks * block_size)) + sampling_params = SamplingParams(max_tokens=1) + all_request_ids = [] - subscriber = MockSubscriber(endpoints, - topic=publisher_config.topic, - decode_type=KVEventBatch) + # Create and add 25 requests + # NOTE: attempts to force routing to both dp groups but can be flaky + for i in range(25): + await asyncio.sleep(0.01) + request = make_request(sampling_params, custom_tokens) + await client.add_request_async(request) + all_request_ids.append(request.request_id) - try: - custom_tokens = list(range(num_blocks * block_size)) - sampling_params = SamplingParams(max_tokens=1) - all_request_ids = [] + await asyncio.sleep(0.1) - # Create and add 25 requests - # NOTE: attempts to force routing to both dp groups but can be flaky - for i in range(25): - await asyncio.sleep(0.01) - request = make_request(sampling_params, 
custom_tokens) - await client.add_request_async(request) - all_request_ids.append(request.request_id) + # Initialize outputs dict for all requests + outputs: dict[str, list] = {req_id: [] for req_id in all_request_ids} - await asyncio.sleep(0.1) + print("processing requests...") + await asyncio.wait_for( + loop_until_fully_done_async(client, outputs), timeout=20.0 + ) - # Initialize outputs dict for all requests - outputs: dict[str, list] = { - req_id: [] - for req_id in all_request_ids - } + # Receive from subscriber until no more messages + print("collecting results...") + results = [] + while True: + result = subscriber.receive_one(timeout=1) + print(result) + if result is None: + break + results.append(result) - print("processing requests...") - await asyncio.wait_for(loop_until_fully_done_async( - client, outputs), - timeout=20.0) + # Collect all events and data_parallel_ranks from all results + all_dp_ranks = [received.data_parallel_rank for (_, received) in results] + unique_dps = set(all_dp_ranks) + assert len(unique_dps) == 2, ( + f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}" + ) - # Receive from subscriber until no more messages - print("collecting results...") - results = [] - while True: - result = subscriber.receive_one(timeout=1) - print(result) - if result is None: - break - results.append(result) - - # Collect all events and data_parallel_ranks from all results - all_dp_ranks = [ - received.data_parallel_rank for (_, received) in results - ] - unique_dps = set(all_dp_ranks) - assert ( - len(unique_dps) == 2 - ), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}" - - finally: - client.shutdown() - subscriber.close() + finally: + client.shutdown() + subscriber.close() @pytest.mark.timeout(20) def test_startup_failure(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m, pytest.raises(Exception) as e_info: - m.setenv("VLLM_USE_V1", "1") - # Monkey-patch to extract core process pid while it's starting. core_proc_pid = [None] cepm_ctor = CoreEngineProcManager.__init__ @@ -588,7 +760,8 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): t = time.time() engine_args = EngineArgs(model=MODEL_NAME) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) print(f"VllmConfig creation took {time.time() - t:.2f} seconds.") @@ -616,8 +789,7 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): @create_new_process_for_each_test() -def test_engine_core_proc_instantiation_cuda_empty( - monkeypatch: pytest.MonkeyPatch): +def test_engine_core_proc_instantiation_cuda_empty(monkeypatch: pytest.MonkeyPatch): """ Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES is empty. This ensures the engine frontend does not need access to GPUs. 
@@ -634,18 +806,13 @@ def test_engine_core_proc_instantiation_cuda_empty( # Only implement the methods that are actually called during init from vllm.v1.kv_cache_interface import FullAttentionSpec - mock_spec = FullAttentionSpec(block_size=16, - num_kv_heads=1, - head_size=64, - dtype=torch.float16, - use_mla=False) - mock_executor.get_kv_cache_specs.return_value = [{ - "default": mock_spec - }] - mock_executor.determine_available_memory.return_value = [ - 1024 * 1024 * 1024 - ] + mock_spec = FullAttentionSpec( + block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16 + ) + + mock_executor.get_kv_cache_specs.return_value = [{"default": mock_spec}] + mock_executor.determine_available_memory.return_value = [1024 * 1024 * 1024] mock_executor.initialize_from_config.return_value = None mock_executor.max_concurrent_batches = 1 @@ -654,24 +821,26 @@ def test_engine_core_proc_instantiation_cuda_empty( mock_executor_class.side_effect = create_mock_executor with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices from vllm.v1.engine.utils import EngineZmqAddresses - def mock_startup_handshake(self, handshake_socket, local_client, - headless, parallel_config): - return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"], - outputs=["tcp://127.0.0.1:5556"], - coordinator_input=None, - coordinator_output=None) + def mock_startup_handshake( + self, handshake_socket, local_client, headless, parallel_config + ): + return EngineZmqAddresses( + inputs=["tcp://127.0.0.1:5555"], + outputs=["tcp://127.0.0.1:5556"], + coordinator_input=None, + coordinator_output=None, + ) # Background processes are not important here m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake) vllm_config = EngineArgs( - model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True).create_engine_config() + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ).create_engine_config() engine_core_proc = EngineCoreProc( vllm_config=vllm_config, local_client=True, diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index f028b4ab1d73f..77e67d54e587e 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -26,39 +26,153 @@ def test_fast_inc_detok_invalid_utf8_err_case(): prompt_token_ids = [107, 4606, 236787, 107] params = SamplingParams(skip_special_tokens=True) request = EngineCoreRequest( - "test", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, + request_id="test", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None, ) detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) - assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \ + assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", ( "Should use FastIncrementalDetokenizer by default" + ) # Process tokens incrementally test_tokens = [ - 236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908, - 147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292, - 827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418, - 569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118, - 35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140, - 236775, 5654, 1083, 623, 110733, 46291, 827, 
107, 140, 236775, 5654, - 236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654, - 236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817, - 4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509, - 19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398, - 432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745, - 2555, 513, 236789, 602, 31118, 569 + 236840, + 107, + 138, + 236782, + 107, + 140, + 236775, + 6265, + 1083, + 623, + 121908, + 147418, + 827, + 107, + 140, + 236775, + 6265, + 236779, + 2084, + 1083, + 623, + 203292, + 827, + 107, + 140, + 236775, + 6265, + 236779, + 7777, + 1083, + 623, + 121908, + 147418, + 569, + 537, + 236789, + 65880, + 569, + 537, + 236789, + 62580, + 853, + 115693, + 210118, + 35178, + 16055, + 1270, + 759, + 215817, + 4758, + 1925, + 1117, + 827, + 107, + 140, + 236775, + 5654, + 1083, + 623, + 110733, + 46291, + 827, + 107, + 140, + 236775, + 5654, + 236779, + 2084, + 1083, + 623, + 136955, + 56731, + 827, + 107, + 140, + 236775, + 5654, + 236779, + 7777, + 1083, + 623, + 194776, + 2947, + 496, + 109811, + 1608, + 890, + 215817, + 4758, + 1925, + 1117, + 2789, + 432, + 398, + 602, + 31118, + 569, + 124866, + 134772, + 509, + 19478, + 1640, + 33779, + 236743, + 236770, + 236819, + 236825, + 236771, + 432, + 398, + 432, + 237167, + 827, + 107, + 140, + 236775, + 77984, + 1083, + 623, + 2709, + 236745, + 2555, + 513, + 236789, + 602, + 31118, + 569, ] output = "" @@ -68,9 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case(): finished = i == len(test_tokens) - 1 output += detokenizer.get_next_output_text(finished, delta=True) - -# fmt: off - assert output == r'''[ + assert ( + output + == r"""[ { "source": "Résultats", "source_type": "CONCEPT", @@ -78,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case(): "target": "Israël", "target_type": "ORGANIZATION", "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »", - "relationship": "Obtention d'un niveau de''' + "relationship": "Obtention d'un niveau de""" + ) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 2848420c22085..3f6f2211556f5 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -3,12 +3,12 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest from vllm import LLM -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector if TYPE_CHECKING: @@ -21,12 +21,10 @@ DTYPE = "half" def _vllm_model( apc: bool, vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, *, skip_tokenizer_init: bool = False, ): """Set up VllmRunner instance.""" - monkeypatch.setenv("VLLM_USE_V1", "1") return vllm_runner( MODEL, dtype=DTYPE, @@ -43,17 +41,18 @@ def _vllm_model( # env var adjustment via monkeypatch scope="function", # Prefix caching - params=[False, True]) -def vllm_model(vllm_runner, request, monkeypatch): + params=[False, True], +) +def vllm_model(vllm_runner, request): """VllmRunner test fixture parameterized by APC True/False.""" - with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model: + with _vllm_model(request.param, vllm_runner) as vllm_model: yield vllm_model @pytest.fixture(scope="function") -def 
vllm_model_apc(vllm_runner, monkeypatch): +def vllm_model_apc(vllm_runner): """VllmRunner test fixture with APC.""" - with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model: + with _vllm_model(True, vllm_runner) as vllm_model: yield vllm_model @@ -62,21 +61,21 @@ def vllm_model_apc(vllm_runner, monkeypatch): # env var adjustment via monkeypatch scope="function", # Prefix caching - params=[False, True]) -def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch): + params=[False, True], +) +def vllm_model_skip_tokenizer_init(vllm_runner, request): """VllmRunner test fixture with APC.""" with _vllm_model( - request.param, - vllm_runner, - monkeypatch, - skip_tokenizer_init=True, + request.param, + vllm_runner, + skip_tokenizer_init=True, ) as vllm_model: yield vllm_model def _get_test_sampling_params( prompt_list: list[str], - seed: Optional[int] = 42, + seed: int | None = 42, structured_outputs: bool = False, ) -> tuple[list[SamplingParams], list[int]]: """Generate random sampling params for a batch.""" @@ -97,9 +96,11 @@ def _get_test_sampling_params( top_p=0.95, n=n, seed=seed, - guided_decoding=GuidedDecodingParams( - regex="[0-9]+") if structured_outputs else None, - ) for n in n_list + structured_outputs=StructuredOutputsParams(regex="[0-9]+") + if structured_outputs + else None, + ) + for n in n_list ], n_list @@ -132,26 +133,23 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: for out, n in zip(outputs, n_list): completion_counts: dict[str, int] = {} # Assert correct number of completions - assert len(out.outputs) == n, ( - f"{len(out.outputs)} completions; {n} expected.") + assert len(out.outputs) == n, f"{len(out.outputs)} completions; {n} expected." for idx in range(n): comp = out.outputs[idx] # Assert correct completion indices - assert comp.index == idx, (f"Index {comp.index}; expected {idx}.") + assert comp.index == idx, f"Index {comp.index}; expected {idx}." text = comp.text completion_counts[text] = completion_counts.get(text, 0) + 1 # Assert unique completions if len(completion_counts) != n: - repeats = { - txt: num - for (txt, num) in completion_counts.items() if num > 1 - } + repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1} raise AssertionError( f"{len(completion_counts)} unique completions; expected" - f" {n}. Repeats: {repeats}") + f" {n}. 
Repeats: {repeats}" + ) -def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): +def test_engine_metrics(vllm_runner, example_prompts): max_tokens = 100 # Use spec decoding to test num_accepted_tokens_per_pos speculative_config = { @@ -160,15 +158,14 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): "prompt_lookup_min": 3, "num_speculative_tokens": 5, } - monkeypatch.setenv("VLLM_USE_V1", "1") + with vllm_runner( - MODEL, - speculative_config=speculative_config, - disable_log_stats=False, + MODEL, + speculative_config=speculative_config, + disable_log_stats=False, ) as vllm_model: llm: LLM = vllm_model.llm - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens) + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = llm.generate(example_prompts, sampling_params) n_prompts = len(example_prompts) @@ -192,15 +189,14 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): num_requests_running = find_metric("vllm:num_requests_running") assert len(num_requests_running) == 1 assert isinstance(num_requests_running[0], Gauge) - assert num_requests_running[0].value == .0 + assert num_requests_running[0].value == 0.0 generation_tokens = find_metric("vllm:generation_tokens") assert len(generation_tokens) == 1 assert isinstance(generation_tokens[0], Counter) assert generation_tokens[0].value == total_tokens - request_generation_tokens = find_metric( - "vllm:request_generation_tokens") + request_generation_tokens = find_metric("vllm:request_generation_tokens") assert len(request_generation_tokens) == 1 assert isinstance(request_generation_tokens[0], Histogram) assert "+Inf" in request_generation_tokens[0].buckets @@ -209,16 +205,15 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): assert request_generation_tokens[0].sum == total_tokens num_accepted_tokens_per_pos = find_metric( - "vllm:spec_decode_num_accepted_tokens_per_pos") + "vllm:spec_decode_num_accepted_tokens_per_pos" + ) assert len(num_accepted_tokens_per_pos) == 1 assert isinstance(num_accepted_tokens_per_pos[0], Vector) assert len(num_accepted_tokens_per_pos[0].values) == 5 @pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"]) -def test_skip_tokenizer_initialization(model: str, - monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_V1", "1") +def test_skip_tokenizer_initialization(model: str): # This test checks if the flag skip_tokenizer_init skips the initialization # of tokenizer and detokenizer. The generated output is expected to contain # token ids. 
@@ -232,8 +227,9 @@ def test_skip_tokenizer_initialization(model: str, with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) + outputs = llm.generate( + {"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params + ) assert len(outputs) > 0 completions = outputs[0].outputs assert len(completions) > 0 diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index c113439a70228..9ebf7f09503e5 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -7,18 +7,20 @@ from typing import Optional import pytest -from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, - NUM_SAMPLE_LOGPROBS_UNDER_TEST, - STOP_STRINGS, - DummyOutputProcessorTestVectors, - MockEngineCore) +from tests.v1.engine.utils import ( + NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + STOP_STRINGS, + DummyOutputProcessorTestVectors, + MockEngineCore, +) +from vllm import PoolingParams +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.output_processor import (OutputProcessor, - RequestOutputCollector) +from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.metrics.stats import IterationStats @@ -39,35 +41,34 @@ def _ref_convert_id_to_token( @pytest.mark.parametrize( - "request_output_kind", - [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -def test_incremental_detokenization(request_output_kind: RequestOutputKind, - dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, - log_stats=False) - engine_core = MockEngineCore( - tokens_list=dummy_test_vectors.generation_tokens) + "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) +def test_incremental_detokenization( + request_output_kind: RequestOutputKind, dummy_test_vectors +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) + engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. 
requests = [ - EngineCoreRequest(request_id=f"request-{idx}", - prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, - eos_token_id=None, - lora_request=None, - cache_salt=None, - data_parallel_rank=None, - sampling_params=SamplingParams( - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=request_output_kind, - stop=[], - include_stop_str_in_output=False, - ), - pooling_params=None) + EngineCoreRequest( + request_id=f"request-{idx}", + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + ), + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -103,8 +104,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, # Confirmed tracked values matches what we expected. for idx, (ref_gen_str, ref_gen_toks) in enumerate( - zip(dummy_test_vectors.generation_strings, - dummy_test_vectors.generation_tokens)): + zip(dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens) + ): gen_str = gen_strings[f"request-{idx}"] gen_toks = gen_tokens[f"request-{idx}"] @@ -135,9 +136,11 @@ def _validate_logprobs( ref_prompt_logprobs = dtv.prompt_logprobs[req_idx] if num_sample_logprobs is not None: # Validate sample logprobs - assert logprobs is not None, (f"Request {req_id} requires sample" - " logprobs but sample logprobs are" - " None.") + assert logprobs is not None, ( + f"Request {req_id} requires sample" + " logprobs but sample logprobs are" + " None." + ) # Require num sampled tokens to match num # sampled logprobs - especially important # to check since the detokenizer can cause @@ -148,44 +151,51 @@ def _validate_logprobs( assert num_new_tokens == len_sample_logprobs, ( f"Request {req_id} has {num_new_tokens}" " completion tokens but has" - f" {len_sample_logprobs} sample logprobs.") + f" {len_sample_logprobs} sample logprobs." + ) ref_cumulative_logprob = 0.0 - for idx, (sampled_token, - pos_logprob_dict) in enumerate(zip(new_tokens, - logprobs)): + for idx, (sampled_token, pos_logprob_dict) in enumerate( + zip(new_tokens, logprobs) + ): # Break out the reference log probability value & # logprob token id tensors associated with this # position in the completion. Also break out the # sampled token ranks - (ref_pos_logprob_toks, ref_pos_logprob_vals, - ref_sampled_token_rank) = ref_logprobs[idx] + (ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank) = ( + ref_logprobs[idx] + ) # For each position in the completion sequence, # ensure the actual sampled token is among the # logprobs assert sampled_token in pos_logprob_dict, ( f"Sampled token {sampled_token} not" - f" present in logprob at index {idx}") + f" present in logprob at index {idx}" + ) # Validate number of sample logprobs num_lp_toks = len(pos_logprob_dict) - assert (num_lp_toks == num_sample_logprobs - or num_lp_toks == num_sample_logprobs + - 1), ("Valid numbers of sample logprobs are" - f" {num_sample_logprobs} or" - f" {num_sample_logprobs+1} but" - f" {num_lp_toks} logprobs found at" - f" position {idx}. 
Logprobs dict:" - f" {pos_logprob_dict}") + assert ( + num_lp_toks == num_sample_logprobs + or num_lp_toks == num_sample_logprobs + 1 + ), ( + "Valid numbers of sample logprobs are" + f" {num_sample_logprobs} or" + f" {num_sample_logprobs + 1} but" + f" {num_lp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}" + ) # Validate sampled token logprob rank smp_lp = pos_logprob_dict[sampled_token] smp_lp_rank = smp_lp.rank - assert (ref_sampled_token_rank == smp_lp_rank), ( + assert ref_sampled_token_rank == smp_lp_rank, ( "Sampled token logprob rank" f" {smp_lp_rank} does not match" " correct value" f" {ref_sampled_token_rank}" - f" in Logprob {smp_lp}") + f" in Logprob {smp_lp}" + ) # Validate that the logprob processor yields # the correct log probabilities and valid @@ -199,7 +209,8 @@ def _validate_logprobs( ref_tok_id = ref_pos_logprob_toks[jdx] assert ref_tok_id in pos_logprob_dict, ( f"Expected token {ref_tok_id} to be" - f" in logprob dict but it is not.") + f" in logprob dict but it is not." + ) # Extract actually-generated logprob # info @@ -209,40 +220,43 @@ def _validate_logprobs( # A "top" (rank 1) logprob must be # present - rank_one_appears = (True - if lp_rank == 1 else rank_one_appears) + rank_one_appears = True if lp_rank == 1 else rank_one_appears # Rank must be >= 1 - assert lp_rank >= 1, (f"Logprob {lp} has invalid" - f" rank {lp_rank} < 1." - f" Logprob dict: {pos_logprob_dict}") + assert lp_rank >= 1, ( + f"Logprob {lp} has invalid" + f" rank {lp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}" + ) # Validate log probability assert math.isclose(lp_val, ref_lp_val), ( f"Token id {ref_tok_id} appears in logprobs dict" f" at position {idx} in completion with log" f" probability {lp_val} but {ref_lp_val} was" - f" expected. Logprob: {lp}") + f" expected. Logprob: {lp}" + ) - assert rank_one_appears, (f"No Logprob has rank 1" - " in the following Logprob" - f" dict: {pos_logprob_dict}") + assert rank_one_appears, ( + f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}" + ) # Validate logprobs detokenization for lp_tok in pos_logprob_dict: # Confirm that sample logprob decoded token matches # the logprob token id at this sequence position decoded_token = pos_logprob_dict[lp_tok].decoded_token - ref_decoded_token = _ref_convert_id_to_token( - dtv.tokenizer, lp_tok) + ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok) assert decoded_token == ref_decoded_token, ( f"Sampled logprob token id {lp_tok} decodes to" f" {ref_decoded_token} but Logprob decoded" f" token is {decoded_token} instead" - f" (at position {idx})") + f" (at position {idx})" + ) - ref_cumulative_logprob += pos_logprob_dict[ - sampled_token].logprob + ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob # Assert that cumulative logprobs are correct assert math.isclose(cumulative_logprob, ref_cumulative_logprob) else: @@ -255,7 +269,8 @@ def _validate_logprobs( assert prompt_logprobs is not None, ( f"Request {req_id} requires prompt" " logprobs but prompt logprobs are" - " None.") + " None." + ) # Require num prompt tokens to match num # prompt logprobs num_prompt_tokens = len(prompt_token_ids) @@ -263,56 +278,70 @@ def _validate_logprobs( assert num_prompt_tokens == len_prompt_logprobs, ( f"Request {req_id} has {num_prompt_tokens}" " prompt tokens but has" - f" {len_prompt_logprobs} prompt logprobs.") + f" {len_prompt_logprobs} prompt logprobs." 
+ ) # First prompt logprob is None first_plp_dict = prompt_logprobs[0] assert first_plp_dict is None, ( f"Request {req_id} first prompt logprob" f" should be None but has following value" - f" instead: {first_plp_dict}") + f" instead: {first_plp_dict}" + ) # Break out the reference prompt log prob value & # logprob token id matrices for the whole prompt. # Also break out the prompt token rank vector - (ref_prompt_logprob_toks, ref_prompt_logprob_vals, - ref_prompt_token_ranks) = ref_prompt_logprobs + ( + ref_prompt_logprob_toks, + ref_prompt_logprob_vals, + ref_prompt_token_ranks, + ) = ref_prompt_logprobs for idx, (prompt_token, pos_logprob_dict) in enumerate( - zip(prompt_token_ids[1:], prompt_logprobs[1:])): - + zip(prompt_token_ids[1:], prompt_logprobs[1:]) + ): # Break out the reference prompt log prob value # vector, prompt logprob token id vector, and # prompt token rank at the current position. - (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals, - ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :], - ref_prompt_logprob_vals[idx, :], - ref_prompt_token_ranks[idx]) + ( + ref_pos_prompt_logprob_toks, + ref_pos_prompt_logprob_vals, + ref_pos_prompt_token_rank, + ) = ( + ref_prompt_logprob_toks[idx, :], + ref_prompt_logprob_vals[idx, :], + ref_prompt_token_ranks[idx], + ) # For each position in the prompt sequence, # ensure the actual prompt token is among the # logprobs assert prompt_token in pos_logprob_dict, ( - f"Prompt token {prompt_token} not" - f" present in logprob at index {idx}") + f"Prompt token {prompt_token} not present in logprob at index {idx}" + ) # Validate number of prompt logprobs num_plp_toks = len(pos_logprob_dict) - assert (num_plp_toks == num_prompt_logprobs - or num_plp_toks == num_prompt_logprobs + - 1), ("Valid numbers of prompt logprobs are" - f" {num_prompt_logprobs} or" - f" {num_prompt_logprobs+1} but" - f" {num_plp_toks} logprobs found at" - f" position {idx}. Logprobs dict:" - f" {pos_logprob_dict}") + assert ( + num_plp_toks == num_prompt_logprobs + or num_plp_toks == num_prompt_logprobs + 1 + ), ( + "Valid numbers of prompt logprobs are" + f" {num_prompt_logprobs} or" + f" {num_prompt_logprobs + 1} but" + f" {num_plp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}" + ) # Validate prompt token logprob rank prmpt_tok_lp = pos_logprob_dict[prompt_token] prmpt_tok_lp_rank = prmpt_tok_lp.rank ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank - assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), ( + assert ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank, ( "Prompt token logprob rank" f" {prmpt_tok_lp_rank} does not match" " correct value" f" {ref_prmpt_tok_lp_rank}" - f" in Logprob {prmpt_tok_lp}") + f" in Logprob {prmpt_tok_lp}" + ) # Validate that the logprob processor yields # the correct prompt log probs and valid @@ -326,7 +355,8 @@ def _validate_logprobs( ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx]) assert ref_tok_id in pos_logprob_dict, ( f"Expected token {ref_tok_id} to be" - f" in logprob dict but it is not.") + f" in logprob dict but it is not." + ) # Extract actually-generated logprob # info @@ -336,89 +366,93 @@ def _validate_logprobs( # A "top" (rank 1) logprob must be # present - rank_one_appears = (True - if plp_rank == 1 else rank_one_appears) + rank_one_appears = True if plp_rank == 1 else rank_one_appears # Rank must be >= 1 assert plp_rank >= 1, ( f"Logprob {plp} has invalid" f" rank {plp_rank} < 1." 
- f" Logprob dict: {pos_logprob_dict}") + f" Logprob dict: {pos_logprob_dict}" + ) # Validate log probability assert math.isclose(plp_val, ref_plp_val), ( f"Token id {ref_tok_id} appears in logprobs dict" f" at position {idx} in completion with log" f" probability {plp_val} but {ref_plp_val} was" - f" expected. Logprob: {plp}") + f" expected. Logprob: {plp}" + ) - assert rank_one_appears, (f"No Logprob has rank 1" - " in the following Logprob" - f" dict: {pos_logprob_dict}") + assert rank_one_appears, ( + f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}" + ) # Validate prompt logprob detokenization for plp_tok in pos_logprob_dict: # Confirm that prompt logprob decoded token matches # the logprob token id at this sequence position decoded_token = pos_logprob_dict[plp_tok].decoded_token - ref_decoded_token = _ref_convert_id_to_token( - dtv.tokenizer, plp_tok) + ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok) assert decoded_token == ref_decoded_token, ( f"Prompt logprob token id {plp_tok} decodes to" f" {ref_decoded_token} but Logprob decoded" f" token is {decoded_token} instead" - f" (at position {idx})") + f" (at position {idx})" + ) else: # Prompt logprobs disabled for this request assert prompt_logprobs is None @pytest.mark.parametrize( - "request_output_kind", - [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -@pytest.mark.parametrize("num_sample_logprobs", - [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -@pytest.mark.parametrize("num_prompt_logprobs", - [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) -def test_logprobs_processor(request_output_kind: RequestOutputKind, - num_sample_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], - dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, - log_stats=False) + "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) +@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor( + request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors, +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, - generated_logprobs_raw=None if num_sample_logprobs is None else - dummy_test_vectors.generation_logprobs, + generated_logprobs_raw=None + if num_sample_logprobs is None + else dummy_test_vectors.generation_logprobs, prompt_logprobs_raw=None - if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs) + if num_prompt_logprobs is None + else dummy_test_vectors.prompt_logprobs, + ) # Make N requests. 
request_id_list = [ - f"request-{idx}" - for idx in range(len(dummy_test_vectors.prompt_strings)) + f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings)) ] requests = [ - EngineCoreRequest(request_id=request_id_list[idx], - prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, - eos_token_id=None, - lora_request=None, - cache_salt=None, - data_parallel_rank=None, - sampling_params=SamplingParams( - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=request_output_kind, - stop=[], - include_stop_str_in_output=False, - logprobs=num_sample_logprobs, - prompt_logprobs=num_prompt_logprobs, - ), - pooling_params=None) + EngineCoreRequest( + request_id=request_id_list[idx], + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + ), + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -449,7 +483,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, prompt_logprobs = request_output.prompt_logprobs logprobs = request_output.outputs[0].logprobs gen_cumulative_logprobs[request_id] = request_output.outputs[ - 0].cumulative_logprob + 0 + ].cumulative_logprob if request_id not in gen_logprobs: # Start tracking sample and prompt logprobs for this request gen_tokens[request_id] = new_tokens @@ -466,10 +501,16 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, plp.extend(prompt_logprobs) # Confirmed tracked logprobs match what we expect - _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, - gen_cumulative_logprobs, dummy_test_vectors, - request_id_list, num_sample_logprobs, - num_prompt_logprobs) + _validate_logprobs( + gen_tokens, + gen_logprobs, + gen_prompt_logprobs, + gen_cumulative_logprobs, + dummy_test_vectors, + request_id_list, + num_sample_logprobs, + num_prompt_logprobs, + ) assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() @@ -477,15 +518,23 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, @pytest.mark.parametrize( "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs", - [(False, "stop_token_ids", False, None), - (True, "stop_token_ids", False, None), - (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), - (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), - (False, "eos_token_id", False, None), (True, "eos_token_id", False, None), - (False, "eos_token_id", True, None)]) -def test_stop_token(include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], stop_token_type: str, - ignore_eos: bool, dummy_test_vectors): + [ + (False, "stop_token_ids", False, None), + (True, "stop_token_ids", False, None), + (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), + (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), + (False, "eos_token_id", False, None), + (True, "eos_token_id", False, None), + (False, "eos_token_id", True, None), + ], +) +def test_stop_token( + include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + stop_token_type: str, + ignore_eos: bool, + 
dummy_test_vectors, +): """Test output processor EOS/stop token handling. Send mock engine core request to mock engine core and pass core outputs @@ -526,9 +575,10 @@ def test_stop_token(include_stop_str_in_output: bool, dummy_test_vectors: dummy engine core outputs and other data structures """ model_id = dummy_test_vectors.tokenizer.name_or_path - if model_id != 'meta-llama/Llama-3.2-1B': - raise AssertionError("Test requires meta-llama/Llama-3.2-1B but " - f"{model_id} is in use.") + if model_id != "meta-llama/Llama-3.2-1B": + raise AssertionError( + f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use." + ) do_logprobs = num_sample_logprobs is not None # EOS under test; if False, stop_token_ids under test is_eos_test = stop_token_type == "eos_token_id" @@ -539,18 +589,16 @@ def test_stop_token(include_stop_str_in_output: bool, ) # '<|end_of_text|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, - log_stats=False) + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) # Dummy engine core outputs, with control tokens suffixed to test stops - suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids) + suffix_token = [eos_token_id] if is_eos_test else stop_token_ids assert suffix_token is not None and isinstance(suffix_token[0], int) generation_string = dummy_test_vectors.generation_strings[0] - generation_tokens = (dummy_test_vectors.generation_tokens[0] + - 2 * suffix_token) + generation_tokens = dummy_test_vectors.generation_tokens[0] + 2 * suffix_token if do_logprobs: - generation_logprobs = ( - dummy_test_vectors.generation_logprobs[0] + - 2 * [dummy_test_vectors.generation_logprobs[0][-1]]) + generation_logprobs = dummy_test_vectors.generation_logprobs[0] + 2 * [ + dummy_test_vectors.generation_logprobs[0][-1] + ] prompt_string = dummy_test_vectors.prompt_strings[0] prompt_tokens = dummy_test_vectors.prompt_tokens[0] engine_core = MockEngineCore( @@ -559,18 +607,17 @@ def test_stop_token(include_stop_str_in_output: bool, prompt_logprobs_raw=None, eos_token_id=eos_token_id, stop_token_ids=stop_token_ids, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + ) # Make request. request_id = "request-0" request = EngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=eos_token_id, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -585,7 +632,8 @@ def test_stop_token(include_stop_str_in_output: bool, prompt_logprobs=None, ignore_eos=ignore_eos, ), - pooling_params=None) + pooling_params=None, + ) # Add request to the detokenizer. output_processor.add_request(request, prompt_string) @@ -610,7 +658,7 @@ def test_stop_token(include_stop_str_in_output: bool, # Update tracking. 
request_output = request_outputs[0] if request_output.finished: - finish_reason = ("length" if is_eos_ignore_test else "stop") + finish_reason = "length" if is_eos_ignore_test else "stop" assert request_output.outputs[0].finish_reason == finish_reason gen_string += request_output.outputs[0].text @@ -619,7 +667,7 @@ def test_stop_token(include_stop_str_in_output: bool, gen_logprobs.extend(request_output.outputs[0].logprobs) # Validate generated text - control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>' + control_token = "<|end_of_text|>" if is_eos_test else "<|eot_id|>" if is_eos_ignore_test: # Length-based stop; expect full string ref_str = generation_string + 2 * control_token @@ -629,14 +677,15 @@ def test_stop_token(include_stop_str_in_output: bool, else: # Stop token triggered but not in output ref_str = generation_string - assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}") + assert gen_string == ref_str, f"{gen_string=}, {ref_str=}" if do_logprobs: # Validate number of sample logprobs num_tokens = len(gen_tokens) num_logprobs = len(gen_logprobs) assert num_tokens == num_logprobs, ( - f"Token count ({num_tokens}) != logprobs count ({num_logprobs})") + f"Token count ({num_tokens}) != logprobs count ({num_logprobs})" + ) # Check requests are finished assert output_processor.get_num_unfinished_requests() == 0 @@ -644,32 +693,32 @@ def test_stop_token(include_stop_str_in_output: bool, @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -@pytest.mark.parametrize("num_sample_logprobs", - [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -def test_stop_string(include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, - log_stats=False) +@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +def test_stop_string( + include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + dummy_test_vectors, +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, generated_logprobs_raw=dummy_test_vectors.generation_logprobs - if num_sample_logprobs else None, - prompt_logprobs_raw=None) + if num_sample_logprobs + else None, + prompt_logprobs_raw=None, + ) # Make N requests. 
request_id_list = [ - f"request-{idx}" - for idx in range(len(dummy_test_vectors.prompt_strings)) + f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings)) ] requests = [ EngineCoreRequest( request_id=request_id_list[idx], prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, @@ -682,7 +731,8 @@ def test_stop_string(include_stop_str_in_output: bool, logprobs=num_sample_logprobs, prompt_logprobs=None, ), - pooling_params=None) + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -722,7 +772,8 @@ def test_stop_string(include_stop_str_in_output: bool, prompt_logprobs = request_output.prompt_logprobs logprobs = request_output.outputs[0].logprobs gen_cumulative_logprobs[request_id] = request_output.outputs[ - 0].cumulative_logprob + 0 + ].cumulative_logprob if request_id not in gen_strings: gen_strings[request_id] = new_text gen_tokens[request_id] = new_tokens @@ -740,8 +791,8 @@ def test_stop_string(include_stop_str_in_output: bool, # Confirmed tracked values matches what we expected. for idx, (ref_gen_str, stop_str) in enumerate( - zip(dummy_test_vectors.generation_strings, STOP_STRINGS)): - + zip(dummy_test_vectors.generation_strings, STOP_STRINGS) + ): # Request should be aborted. request_id = f"request-{idx}" assert request_id in aborted @@ -755,24 +806,28 @@ def test_stop_string(include_stop_str_in_output: bool, ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str if include_stop_str_in_output: - assert gen_str == ref_str_inc_stop, ( - f"{gen_str=}, {ref_str_inc_stop=}") + assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}" else: - assert gen_str == ref_str_exc_stop, ( - f"{gen_str=}, {ref_str_exc_stop=}") + assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}" # Confirmed tracked logprobs match what we expect - _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, - gen_cumulative_logprobs, dummy_test_vectors, - request_id_list, num_sample_logprobs, None) + _validate_logprobs( + gen_tokens, + gen_logprobs, + gen_prompt_logprobs, + gen_cumulative_logprobs, + dummy_test_vectors, + request_id_list, + num_sample_logprobs, + None, + ) assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() def test_iteration_stats(dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, - log_stats=True) + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core_timestamp = time.monotonic() @@ -781,17 +836,16 @@ def test_iteration_stats(dummy_test_vectors): EngineCoreRequest( request_id=f"request-{idx}", prompt_token_ids=prompt_tokens, - arrival_time=0, - mm_kwargs=None, - mm_hashes=None, - mm_placeholders=None, + mm_features=None, eos_token_id=None, + arrival_time=0, lora_request=None, cache_salt=None, data_parallel_rank=None, sampling_params=SamplingParams(), pooling_params=None, - ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + ) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add all requests except one to the OutputProcessor. @@ -803,12 +857,13 @@ def test_iteration_stats(dummy_test_vectors): # First iteration has 2 prefills. 
outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) - total_prompt_tokens = sum([ - len(prompt_tokens) - for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] - ]) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) + total_prompt_tokens = sum( + [ + len(prompt_tokens) + for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] + ] + ) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active @@ -816,8 +871,7 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active @@ -827,8 +881,7 @@ def test_iteration_stats(dummy_test_vectors): num_active += 1 outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens @@ -837,8 +890,7 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active @@ -862,16 +914,13 @@ async def test_request_output_collector(): text=TEXT, token_ids=[idx], cumulative_logprob=(idx + 1 * 1.0), - logprobs=[{ - "a": idx, - "b": idx - }], - finish_reason="length" if - (idx == NUM_REQS - 1) else None, + logprobs=[{"a": idx, "b": idx}], + finish_reason="length" if (idx == NUM_REQS - 1) else None, ) ], finished=(idx == NUM_REQS - 1), - ) for idx in range(NUM_REQS) + ) + for idx in range(NUM_REQS) ] collector = RequestOutputCollector(RequestOutputKind.DELTA) @@ -897,8 +946,7 @@ async def test_request_output_collector(): assert not output.finished # Text, token_ids, and logprobs should get merged. assert output.outputs[0].text == TEXT * num_to_put - for tok_0, tok_1 in zip(output.outputs[0].token_ids, - list(range(num_to_put))): + for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))): assert tok_0 == tok_1 assert len(output.outputs[0].logprobs) == num_to_put @@ -919,8 +967,7 @@ async def test_request_output_collector(): assert output.outputs[0].finish_reason == "length" # Text, token_ids, and logprobs should get merged. 
assert output.outputs[0].text == TEXT * num_to_put - for tok_0, tok_1 in zip(output.outputs[0].token_ids, - list(range(num_to_put))): + for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))): assert tok_0 == tok_1 assert len(output.outputs[0].logprobs) == num_to_put @@ -1008,3 +1055,34 @@ async def test_cumulative_output_collector_n(): third = [k for k in result.outputs if k.index == 2] assert len(third) == 1 assert third[0].text == "c" + + +@pytest.mark.parametrize("runner", ["generate", "pooling"]) +def test_abort_requests(runner: str, dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) + requests = [ + EngineCoreRequest( + request_id=f"request-{idx}", + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams() if runner == "generate" else None, + pooling_params=PoolingParams(task="embed") if runner == "pooling" else None, + ) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + ] + + for request in requests: + if runner == "generate": + output_kind = request.sampling_params.output_kind + else: + output_kind = request.pooling_params.output_kind + queue = RequestOutputCollector(output_kind=output_kind) + output_processor.add_request(request, None, queue=queue) + + for request in requests: + output_processor.abort_requests([request.request_id]) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py new file mode 100644 index 0000000000000..cb6865e42ef8b --- /dev/null +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import processor as processor_mod +from vllm.v1.engine.processor import Processor + +cherry_pil_image = ImageAsset("cherry_blossom").pil_image +stop_pil_image = ImageAsset("stop_sign").pil_image +baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays + + +# Mock processor for testing +def _mk_processor( + monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True +) -> Processor: + """ + Create a Processor instance with minimal configuration suitable for unit + tests without accessing external resources. + """ + monkeypatch.setattr( + ModelConfig, "try_get_generation_config", lambda self: {}, raising=True + ) + monkeypatch.setattr( + ModelConfig, "__post_init__", lambda self, *args: None, raising=True + ) + monkeypatch.setattr( + ModelConfig, + "verify_with_parallel_config", + lambda self, parallel_config: None, + raising=True, + ) + monkeypatch.setattr( + processor_mod, + "processor_cache_from_config", + lambda vllm_config, mm_registry: None, + raising=True, + ) + + monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True) + + model_config = ModelConfig( + skip_tokenizer_init=True, + max_model_len=128, + mm_processor_cache_gb=mm_cache_gb, + generation_config="vllm", + tokenizer="dummy", + ) + + # Minimal multimodal_config to satisfy references in + # Processor.process_inputs. 
+ class _MockMMConfig: + def __init__(self, gb: float): + self.mm_processor_cache_gb = gb + + model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined] + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), + device_config=DeviceConfig(device="cpu"), + ) + + return Processor(vllm_config, tokenizer=None) + + +def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image>\nDescribe\nASSISTANT:", + "multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]}, + # Mismatch: 2 items but only 1 uuid provided + "multi_modal_uuids": {"image": ["hash_cherry"]}, + } + + with pytest.raises(ValueError, match="must have same length as data"): + processor.process_inputs( + request_id="req-1", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +def test_multi_modal_uuids_missing_modality_raises(monkeypatch): + processor = _mk_processor(monkeypatch) + + prompt = { + "prompt": "USER: <image><video>\nDescribe\nASSISTANT:", + # Two modalities provided in data + "multi_modal_data": { + "image": [cherry_pil_image], + "video": [baby_reading_np_ndarrays], + }, + # Only image uuids provided; video missing should raise + "multi_modal_uuids": {"image": ["hash_cherry"]}, + } + + with pytest.raises(ValueError, match="must be provided if multi_modal_data"): + processor.process_inputs( + request_id="req-2", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + +@pytest.mark.parametrize( + "mm_cache_gb, enable_prefix_caching", + [ + (4.0, True), # default behavior + (4.0, False), # prefix caching disabled + (0.0, True), # processor cache disabled + ], +) +def test_multi_modal_uuids_accepts_none_and_passes_through( + monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool +): + processor = _mk_processor( + monkeypatch, + mm_cache_gb=mm_cache_gb, + enable_prefix_caching=enable_prefix_caching, + ) + + # Capture the overrides passed to InputPreprocessor.preprocess + captured: dict[str, object] = {} + + def fake_preprocess( + prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None + ): + captured["mm_uuids"] = mm_uuids + # Minimal processed inputs for decoder-only flow + return {"type": "token", "prompt_token_ids": [1]} + + # Monkeypatch only the bound preprocess method on this instance + monkeypatch.setattr( + processor.input_preprocessor, "preprocess", fake_preprocess, raising=True + ) + + # Use a consistent two-image scenario across all configurations + mm_uuids = {"image": [None, "hash_stop"], "video": None} + prompt = { + "prompt": "USER: <image><image>\nTwo images\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id="req-3", + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + assert captured["mm_uuids"] == mm_uuids + + +def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): + # When both processor cache is 0 and prefix caching disabled, the + # processor builds overrides from request id instead of using user UUIDs. 
+ processor = _mk_processor(monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False) + + captured: dict[str, object] = {} + + def fake_preprocess( + prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None + ): + captured["mm_uuids"] = mm_uuids + return {"type": "token", "prompt_token_ids": [1]} + + monkeypatch.setattr( + processor.input_preprocessor, "preprocess", fake_preprocess, raising=True + ) + + request_id = "req-42" + mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"} + prompt = { + "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:", + "multi_modal_data": { + "image": [cherry_pil_image, stop_pil_image], + "video": baby_reading_np_ndarrays, + }, + "multi_modal_uuids": mm_uuids, + } + + processor.process_inputs( + request_id=request_id, + prompt=prompt, # type: ignore[arg-type] + params=SamplingParams(), + ) + + # Expect request-id-based overrides are passed through + assert captured["mm_uuids"] == { + "image": [f"{request_id}-image-0", f"{request_id}-image-1"], + "video": [f"{request_id}-video-0"], + } diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index b58bc75fc9565..9b720f6eb668e 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -9,7 +9,6 @@ import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector( upper: float, ) -> torch.Tensor: """Create a random vector of top logprob float values. - + Use to create fake sample logprobs for testing. Note that a real production scenario would require @@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix( upper: float, ) -> torch.Tensor: """Create a random matrix of top logprob float values. - + Use to create fake prompt logprobs for testing. Note that a real production scenario would require @@ -83,11 +82,12 @@ def _create_random_top_logprob_test_matrix( def _create_random_top_token_test_vector( - num_logprobs: int, - lower: int, - upper: int, - sampled_token_id: int, - adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]: + num_logprobs: int, + lower: int, + upper: int, + sampled_token_id: int, + adjust_num_logprobs: bool = True, +) -> tuple[torch.Tensor, int]: """Create a random vector of top logprob token indices Use to create fake sample logprobs for testing. 
The sampled token @@ -128,8 +128,9 @@ def _create_random_top_token_test_vector( # Check if the sampled_token_id occurs in choice_tensor[1:] if sampled_token_id in choice_tensor[1:]: - sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero( - as_tuple=True)[0].item() + sampled_token_rank = ( + (choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item() + ) else: # If not found, assign a random int between num_logprobs and 50700 sampled_token_rank = random.randint(num_logprobs, 50700) @@ -165,9 +166,12 @@ def _create_random_top_token_test_matrix( num_elements = shape[0] * shape[1] choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower matrix = torch.cat( - (torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), - choice_tensor.view(shape)), - dim=1) + ( + torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), + choice_tensor.view(shape), + ), + dim=1, + ) # Initialize the tensor for storing the ranks prompt_token_ranks = torch.empty(shape[0], dtype=torch.int) @@ -175,8 +179,7 @@ def _create_random_top_token_test_matrix( # Iterate over each row to check presence of # tokens_list[rdx] and determine its index for rdx in range(shape[0]): - row = matrix[rdx, - 1:] # Skip the first column as it contains the token list + row = matrix[rdx, 1:] # Skip the first column as it contains the token list token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0] if token_index.numel() > 0: prompt_token_ranks[rdx] = token_index.item() @@ -230,19 +233,21 @@ def generate_dummy_sample_logprobs( ( token_vector, sampled_token_rank, - ) = _create_random_top_token_test_vector(num_logprobs, 0, - len(tokenizer.vocab) - 1, - sampled_token_id) + ) = _create_random_top_token_test_vector( + num_logprobs, 0, len(tokenizer.vocab) - 1, sampled_token_id + ) res.append( - (token_vector, - _create_random_top_logprob_test_vector(num_logprobs + 1, -100, - 0), sampled_token_rank)) + ( + token_vector, + _create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0), + sampled_token_rank, + ) + ) # Convert tensors in the list tuples to Python lists res_list_format = [ - (log_probs_tensor.tolist(), token_ids_tensor.tolist(), - sampled_token_rank) + (log_probs_tensor.tolist(), token_ids_tensor.tolist(), sampled_token_rank) for log_probs_tensor, token_ids_tensor, sampled_token_rank in res ] @@ -283,20 +288,25 @@ def generate_dummy_prompt_logprobs_tensors( token_vector, prompt_token_ranks, ) = _create_random_top_token_test_matrix( - (num_prompt_logprobs, num_logprobs), 0, - len(tokenizer.vocab) - 1, prompt_tokens_list[1:]) + (num_prompt_logprobs, num_logprobs), + 0, + len(tokenizer.vocab) - 1, + prompt_tokens_list[1:], + ) return LogprobsTensors( token_vector, _create_random_top_logprob_test_matrix( - (num_prompt_logprobs, num_logprobs + 1), -100, 0), - prompt_token_ranks) + (num_prompt_logprobs, num_logprobs + 1), -100, 0 + ), + prompt_token_ranks, + ) @dataclass class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" + tokenizer: GeneralTokenizerType - tokenizer_group: TokenizerGroup vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens prompt_tokens: list[list[int]] @@ -322,9 +332,9 @@ class MockEngineCore: # For each request, for each sampled token offset, # a tuple of # (list of topk token ids, list of sample logprob vals, rank) - generated_logprobs_raw: Optional[list[list[tuple[list[int], - list[float], - int]]]] = None, + generated_logprobs_raw: Optional[ + list[list[tuple[list[int], list[float], int]]] + ] = 
None, # For each request, a tuple of # (prompt logprob val matrix, prompt logprob tok id matrix); # each matrix has dimensions @@ -357,7 +367,8 @@ class MockEngineCore: if do_logprobs: assert self.generated_logprobs_raw is not None (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( - self.generated_logprobs_raw[req_idx][token_idx]) + self.generated_logprobs_raw[req_idx][token_idx] + ) logprobs = LogprobsLists( [logprobs_token_ids_], [logprobs_], diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index ffe0612124660..40b9d1fe850c6 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -26,8 +26,10 @@ def sample_token_ids(): @pytest.fixture def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) # Note: Ensure this only uses attributes compatible with xgrammar @@ -36,53 +38,44 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", "items": { "type": "string", - } + }, }, "grade": { "type": "string", - "pattern": "^[A-D]$" # Regex pattern + "pattern": "^[A-D]$", # Regex pattern }, "email": { "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, + "company": {"type": "string"}, "duration": { "type": "number", "minimum": 0.0, "maximum": 100.0, # Numeric range }, - "position": { - "type": "string" - } + "position": {"type": "string"}, }, "required": ["company", "duration", "position"], - "additionalProperties": False + "additionalProperties": False, }, "minItems": 0, - "maxItems": 3 - } + "maxItems": 3, + }, }, - "required": - ["name", "age", "skills", "grade", "email", "work_history"], - "additionalProperties": False + "required": ["name", "age", "skills", "grade", "email", "work_history"], + "additionalProperties": False, } @@ -94,67 +87,60 @@ def unsupported_json_schema(): "properties": { "score": { "type": "integer", - "multipleOf": 5 # Numeric multiple + "multipleOf": 5, # Numeric multiple }, "tags": { "type": "array", - "items": { - "type": "string", - "minLength": 10, - "maxLength": 20 - } - } + "items": {"type": "string", "minLength": 10, "maxLength": 20}, + }, }, "required": ["score", "tags"], - "additionalProperties": False + "additionalProperties": False, } @pytest.fixture def sample_definition_json_schema(): return { - '$defs': { - 'Step': { - 'properties': { - 'explanation': { - 'title': 'Explanation', - 'type': 'string' - }, - 'output': { - 'title': 'Output', - 'type': 'string' - } + "$defs": { + "Step": { + "properties": { + "explanation": {"title": "Explanation", "type": "string"}, + "output": {"title": "Output", "type": "string"}, }, - 'required': ['explanation', 'output'], - 'title': 'Step', - 'type': 'object' + "required": ["explanation", "output"], + "title": "Step", + "type": "object", } }, - 'properties': { - 'steps': { - 'items': { - '$ref': '#/$defs/Step' - }, - 'title': 'Steps', - 'type': 'array' + "properties": { + "steps": { + "items": {"$ref": "#/$defs/Step"}, + "title": "Steps", + "type": "array", }, - 'final_answer': { - 'title': 'Final Answer', - 'type': 
'string' - } + "final_answer": {"title": "Final Answer", "type": "string"}, }, - 'required': ['steps', 'final_answer'], - 'title': 'MathReasoning', - 'type': 'object', - "additionalProperties": False + "required": ["steps", "final_answer"], + "title": "MathReasoning", + "type": "object", + "additionalProperties": False, } @pytest.fixture -def sample_guided_choice(): +def sample_structured_outputs_choices(): return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" + "Python", + "Java", + "JavaScript", + "C++", + "C#", + "PHP", + "TypeScript", + "Ruby", + "Swift", + "Kotlin", ] @@ -172,11 +158,11 @@ number ::= "1" | "2" @pytest.fixture def sample_sql_lark(): - return (""" + return """ start: select_statement select_statement: "SELECT" column "from" table "where" condition column: "col_1" | "col_2" table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""" diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 572af0175d114..b5d04679317e6 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -5,6 +5,7 @@ from __future__ import annotations import json +from dataclasses import fields from enum import Enum from typing import TYPE_CHECKING, Any @@ -15,15 +16,20 @@ import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.config import StructuredOutputsConfig from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import ( + GuidedDecodingParams, + SamplingParams, + StructuredOutputsParams, +) if TYPE_CHECKING: - from vllm.config import TokenizerMode + from vllm.config.model import TokenizerMode NGRAM_SPEC_CONFIG = { "model": "[ngram]", @@ -41,19 +47,18 @@ EAGLE_SPEC_CONFIG = { PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", - NGRAM_SPEC_CONFIG), - #FIXME: This test is flaky on CI thus disabled - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", - NGRAM_SPEC_CONFIG), + ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), + # FIXME: These tests are flaky on CI thus disabled. 
Tracking in Issue #24402 + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + # ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), - ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", - EAGLE_SPEC_CONFIG) + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG), ] PARAMS_MODELS_TOKENIZER_MODE = [ @@ -75,52 +80,56 @@ class CarDescription(BaseModel): car_type: CarType -def _load_json(s: str, backend: str) -> str: - if backend != "xgrammar": - return json.loads(s) +def test_guided_decoding_deprecated(): + with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"): + guided_decoding = GuidedDecodingParams(json_object=True) - # xgrammar specific workarounds - # https://github.com/mlc-ai/xgrammar/issues/286 - s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s) - return json.loads(s) + structured_outputs = StructuredOutputsParams(json_object=True) + assert fields(guided_decoding) == fields(structured_outputs) + + with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"): + sp1 = SamplingParams(guided_decoding=guided_decoding) + + with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"): + sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding) + + assert sp1 == sp2 + assert sp1.structured_outputs == guided_decoding @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) + "model_name, backend, tokenizer_mode, speculative_config", + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE, +) def test_structured_output( - monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], unsupported_json_schema: dict[str, Any], sample_sql_ebnf: str, sample_sql_lark: str, sample_regex: str, - sample_guided_choice: str, - guided_decoding_backend: str, + sample_structured_outputs_choices: str, + backend: str, tokenizer_mode: str, model_name: str, speculative_config: dict[str, Any], ): - monkeypatch.setenv("VLLM_USE_V1", "1") - if current_platform.is_tpu() and speculative_config: pytest.skip("TPU does not support speculative decoding") - # Don't use eager execution on TPUs because we want to test for no - # recompilation at runtime - enforce_eager = bool(not current_platform.is_tpu()) # Use a single LLM instance for several scenarios to # speed up the test suite. 
llm = LLM( model=model_name, - enforce_eager=enforce_eager, + enforce_eager=True, max_model_len=1024, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=(guided_decoding_backend - in {"xgrammar", "guidance"}), + structured_outputs_config=dict( + backend=backend, disable_any_whitespace=backend in {"xgrammar", "guidance"} + ), + seed=120, tokenizer_mode=tokenizer_mode, - speculative_config=speculative_config) + speculative_config=speculative_config, + ) # # Test 1: Generate JSON output based on a provided schema @@ -128,11 +137,14 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + structured_outputs=StructuredOutputsParams(json=sample_json_schema), + ) - prompt = ("Give an example JSON for an employee profile that fits this " - "schema. Make the response as short as possible. Schema: " - f"{sample_json_schema}") + prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}" + ) outputs = llm.generate( [prompt] * 2, sampling_params=sampling_params, @@ -148,27 +160,38 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - assert "\n" not in generated_text + if backend != "lm-format-enforcer": + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON from backend={backend}: {generated_text!r}\n" + f"Schema: {sample_json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=sample_json_schema) # # Test 2: Generate JSON object without a schema # - if guided_decoding_backend != "outlines": + if backend != "outlines": sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, n=2, - guided_decoding=GuidedDecodingParams(json_object=True)) + structured_outputs=StructuredOutputsParams(json_object=True), + ) - outputs = llm.generate(prompts=( - "Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old. " - "Make the response as short as possible."), - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts=( + "Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old. " + "Make the response as short as possible." + ), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -190,24 +213,30 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) - if guided_decoding_backend.startswith("xgrammar"): - with pytest.raises(ValueError, - match="The provided JSON schema contains features " - "not supported by xgrammar."): - - prompt = (f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. 
" - f"Make the response as short as possible.") + structured_outputs=StructuredOutputsParams(json=unsupported_json_schema), + ) + if backend.startswith("xgrammar"): + with pytest.raises( + ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar.", + ): + prompt = ( + f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible." + ) llm.generate( [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) else: - prompt = (f"Give an example JSON object for a grade that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.") + prompt = ( + f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible." + ) outputs = llm.generate( prompt, sampling_params=sampling_params, @@ -225,7 +254,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend != "outlines": + if backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL statement using EBNF grammar # @@ -233,11 +262,14 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) + structured_outputs=StructuredOutputsParams(grammar=sample_sql_ebnf), + ) outputs = llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -252,8 +284,7 @@ def test_structured_output( assert generated_text is not None # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") assert generated_text.strip() == ground_truth @@ -266,11 +297,14 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) + structured_outputs=StructuredOutputsParams(grammar=sample_sql_lark), + ) outputs = llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible." 
+ ), sampling_params=sampling_params, use_tqdm=True, ) @@ -286,12 +320,12 @@ def test_structured_output( # use Lark to parse the output, and make sure it's a valid parse tree from lark import Lark + parser = Lark(sample_sql_lark) parser.parse(generated_text) # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") assert generated_text.strip() == ground_truth @@ -304,12 +338,15 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar="not a grammar")) + structured_outputs=StructuredOutputsParams(grammar="not a grammar"), + ) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short " - "as possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short " + "as possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -320,10 +357,13 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - guided_decoding=GuidedDecodingParams(regex=sample_regex)) + structured_outputs=StructuredOutputsParams(regex=sample_regex), + ) - prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") + prompt = ( + f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible." + ) outputs = llm.generate( [prompt] * 2, sampling_params=sampling_params, @@ -347,11 +387,16 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + structured_outputs=StructuredOutputsParams( + choice=sample_structured_outputs_choices + ), + ) outputs = llm.generate( - ("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ( + "The best language for type-safe systems programming is " + "(Make the response as short as possible.) " + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -363,7 +408,7 @@ def test_structured_output( generated_text = output.outputs[0].text print(generated_text) assert generated_text is not None - assert generated_text in sample_guided_choice + assert generated_text in sample_structured_outputs_choices print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") # @@ -373,12 +418,15 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema), + ) outputs = llm.generate( - ("Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), + ( + "Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible." 
+ ), sampling_params=sampling_params, use_tqdm=True, ) @@ -393,7 +441,13 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON from backend={backend}: {generated_text!r}\n" + f"Schema: {json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=json_schema) # @@ -407,21 +461,24 @@ def test_structured_output( "description": { "type": "string", "maxLength": max_length, - "minLength": min_length + "minLength": min_length, } }, "required": ["description"], - "additionalProperties": False + "additionalProperties": False, } sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema), + ) outputs = llm.generate( - ("Generate a description of a frog using 50 characters. " - "Make the response as short as possible."), + ( + "Generate a description of a frog using 50 characters. " + "Make the response as short as possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -436,37 +493,42 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - output_json = json.loads(generated_text) + try: + output_json = json.loads(generated_text) + except json.JSONDecodeError as e: + pytest.fail( + f"Invalid JSON from backend={backend}: {generated_text!r}\n" + f"Schema: {json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend != "outlines": + if backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # structural_tag_config = { - "type": - "structural_tag", - "structures": [{ - "begin": "<function=get_weather>", - "schema": { - "type": "object", - "properties": { - "city": { - "type": "string" - } + "type": "structural_tag", + "structures": [ + { + "begin": "<function=get_weather>", + "schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "additionalProperties": False, }, - "additionalProperties": False - }, - "end": "</function>" - }], - "triggers": ["<function="] + "end": "</function>", + } + ], + "triggers": ["<function="], } sampling_params = SamplingParams( temperature=0.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams( - structural_tag=json.dumps(structural_tag_config))) + structured_outputs=StructuredOutputsParams( + structural_tag=json.dumps(structural_tag_config) + ), + ) prompt = """ You have access to the following function to retrieve the weather in a city: @@ -508,9 +570,7 @@ Make the response as short as possible. """ # Change this once other backends support structural_tag - outputs = llm.generate(prompt, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None for output in outputs: @@ -520,12 +580,13 @@ Make the response as short as possible. 
assert generated_text is not None # Search for function call pattern in the response - function_call_pattern = r'<function=get_weather>(.*?)</function>' + function_call_pattern = r"<function=get_weather>(.*?)</function>" matches = re.findall(function_call_pattern, generated_text) if not matches: - print(f"Warning: No function calls found in response: " - f"{generated_text!r}") + print( + f"Warning: No function calls found in response: {generated_text!r}" + ) continue # Take the first function call if multiple are found @@ -536,29 +597,32 @@ Make the response as short as possible. assert isinstance(json_content["city"], str) print(f"Found valid function call: {generated_text!r}") except (json.JSONDecodeError, AssertionError) as e: - pytest.fail("Invalid function call format: " - f"{generated_text!r}\nError: {str(e)}") + pytest.fail( + f"Invalid function call format: {generated_text!r}\nError: {str(e)}" + ) @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 + "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 [ - ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", - "deepseek_r1", NGRAM_SPEC_CONFIG), + ( + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "xgrammar", + "auto", + "deepseek_r1", + NGRAM_SPEC_CONFIG, + ), ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None), ], ) def test_structured_output_with_reasoning_matrices( - monkeypatch: pytest.MonkeyPatch, - guided_decoding_backend: str, + backend: str, tokenizer_mode: TokenizerMode, reasoning_parser: str, model_name: str, speculative_config: dict[str, Any] | None, ): - monkeypatch.setenv("VLLM_USE_V1", "1") - if current_platform.is_tpu() and speculative_config: pytest.skip("TPU does not support speculative decoding") @@ -571,26 +635,25 @@ def test_structured_output_with_reasoning_matrices( enforce_eager=bool(not current_platform.is_tpu()), max_model_len=1024, max_num_seqs=16, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=True, + structured_outputs_config=dict( + backend=backend, + disable_any_whitespace=backend in {"xgrammar", "guidance"}, + reasoning_parser=reasoning_parser, + ), tokenizer_mode=tokenizer_mode, - reasoning_parser=reasoning_parser, speculative_config=speculative_config, ) - tokenizer = llm.get_tokenizer(None) + tokenizer = llm.get_tokenizer() reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( - tokenizer=tokenizer) + tokenizer=tokenizer + ) reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?" 
# noqa: E501 reasoning_schema = { "type": "object", - "properties": { - "result": { - "type": "integer" - } - }, + "properties": {"result": {"type": "integer"}}, "required": ["result"], - "additionalProperties": False + "additionalProperties": False, } if "Qwen3" in model_name: reasoning_prompt += "<think>\n" @@ -598,7 +661,7 @@ def test_structured_output_with_reasoning_matrices( sampling_params = SamplingParams( temperature=0.1, max_tokens=8192, - guided_decoding=GuidedDecodingParams(json=reasoning_schema), + structured_outputs=StructuredOutputsParams(json=reasoning_schema), ) outputs = llm.generate( [reasoning_prompt], @@ -611,11 +674,8 @@ def test_structured_output_with_reasoning_matrices( assert output is not None and isinstance(output, RequestOutput) prompt = output.prompt generated_text = output.outputs[0].text - reasoning_content, content = run_reasoning_extraction( - reasoner, [generated_text]) - print( - f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" - ) + reasoning_content, content = run_reasoning_extraction(reasoner, [generated_text]) + print(f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}") assert content is not None and reasoning_content is not None output_json = json.loads(content) @@ -623,39 +683,38 @@ def test_structured_output_with_reasoning_matrices( @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("model_name, tokenizer_mode", - PARAMS_MODELS_TOKENIZER_MODE) +@pytest.mark.parametrize("model_name, tokenizer_mode", PARAMS_MODELS_TOKENIZER_MODE) def test_structured_output_auto_mode( - monkeypatch: pytest.MonkeyPatch, unsupported_json_schema: dict[str, Any], model_name: str, tokenizer_mode: str, ): - monkeypatch.setenv("VLLM_USE_V1", "1") - - llm = LLM(model=model_name, - max_model_len=1024, - guided_decoding_backend="auto", - tokenizer_mode=tokenizer_mode) + llm = LLM( + model=model_name, + max_model_len=1024, + structured_outputs_config=dict(backend="auto"), + tokenizer_mode=tokenizer_mode, + ) sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) + structured_outputs=StructuredOutputsParams(json=unsupported_json_schema), + ) prompts = ( "Give an example JSON object for a grade " "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as possible.") + f"{unsupported_json_schema}. Make the response as short as possible." + ) # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. 
outputs.extend( - llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) + ) assert outputs is not None for output in outputs: @@ -671,29 +730,25 @@ def test_structured_output_auto_mode( @pytest.mark.skip_global_cleanup -def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_USE_V1", "1") - - llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", - max_model_len=1024, - guided_decoding_backend="guidance", - guided_decoding_disable_any_whitespace=True, - guided_decoding_disable_additional_properties=True) +def test_guidance_no_additional_properties(): + llm = LLM( + model="Qwen/Qwen2.5-1.5B-Instruct", + max_model_len=1024, + structured_outputs_config=dict( + backend="guidance", + disable_any_whitespace=True, + disable_additional_properties=True, + ), + ) schema = { - 'type': 'object', - 'properties': { - 'a1': { - 'type': 'string' - }, - 'a2': { - 'type': 'string' - }, - 'a3': { - 'type': 'string' - } + "type": "object", + "properties": { + "a1": {"type": "string"}, + "a2": {"type": "string"}, + "a3": {"type": "string"}, }, - 'required': ['a1', 'a2', 'a3'], + "required": ["a1", "a2", "a3"], } prompt = ( @@ -701,17 +756,19 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. " "Make the response as short as possible." - "<|im_end|>\n<|im_start|>assistant\n") + "<|im_end|>\n<|im_start|>assistant\n" + ) def generate_with_backend(backend): - guided_params = GuidedDecodingParams( + structured_outputs_params = StructuredOutputsParams( json=schema, backend=backend, disable_any_whitespace=True, - disable_additional_properties=True) - sampling_params = SamplingParams(temperature=0, - max_tokens=256, - guided_decoding=guided_params) + disable_additional_properties=True, + ) + sampling_params = SamplingParams( + temperature=0, max_tokens=256, structured_outputs=structured_outputs_params + ) outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None @@ -731,15 +788,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a6" not in generated -@pytest.mark.parametrize("guided_decoding_backend", - ["guidance", "xgrammar", "outlines"]) -def test_structured_output_batched_with_non_guided_requests( - monkeypatch: pytest.MonkeyPatch, +@pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_structured_outputs_requests( sample_json_schema: dict[str, Any], - guided_decoding_backend: str, + backend: str, ): - monkeypatch.setenv("VLLM_USE_V1", "1") - # Don't use eager execution on TPUs because we want to test for no # recompilation at runtime enforce_eager = bool(not current_platform.is_tpu()) @@ -748,24 +801,27 @@ def test_structured_output_batched_with_non_guided_requests( model="meta-llama/Meta-Llama-3.1-8B-Instruct", enforce_eager=enforce_eager, max_model_len=1024, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=(guided_decoding_backend - in {"xgrammar", "guidance"}), + structured_outputs_config=StructuredOutputsConfig( + backend=backend, + disable_any_whitespace=backend in {"xgrammar", "guidance"}, + ), ) - guided_prompt = ( + structured_outputs_prompt = ( "Give an example JSON for an employee profile that fits this " "schema. 
Make the response as short as possible. Schema: " - f"{sample_json_schema}") + f"{sample_json_schema}" + ) - non_guided_prompt = "The diameter of the Earth in kilometers is " + non_structured_outputs_prompt = "The diameter of the Earth in kilometers is " - prompts = [guided_prompt, non_guided_prompt] + prompts = [structured_outputs_prompt, non_structured_outputs_prompt] sampling_params = [ SamplingParams( temperature=1.0, max_tokens=400, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + structured_outputs=StructuredOutputsParams(json=sample_json_schema), + ), # No max tokens, temp=0 to assert on contents SamplingParams( seed=42, @@ -774,9 +830,9 @@ def test_structured_output_batched_with_non_guided_requests( ), ] - outputs = llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts=prompts, sampling_params=sampling_params, use_tqdm=True + ) assert outputs is not None @@ -796,16 +852,15 @@ def test_structured_output_batched_with_non_guided_requests( print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") if index == 0: - # First prompt is guided, expect valid JSON + # First prompt is structured outputs, expect valid JSON assert "\n" not in generated_text output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, - schema=sample_json_schema) + jsonschema.validate(instance=output_json, schema=sample_json_schema) else: - # Second prompt is not guided, expect valid output + # Second prompt is not structured outputs, expect valid output # Cannot assert on exact output, but we can expect it to be factual assert "12,742" in generated_text - # non-guided requests should not return a valid JSON here + # non-structured outputs requests should not return a valid JSON here with pytest.raises(ValueError): output_json = json.loads(generated_text) diff --git a/tests/v1/entrypoints/openai/responses/conftest.py b/tests/v1/entrypoints/openai/responses/conftest.py index 2d677a00b646a..ad7594a3dd6dd 100644 --- a/tests/v1/entrypoints/openai/responses/conftest.py +++ b/tests/v1/entrypoints/openai/responses/conftest.py @@ -23,9 +23,9 @@ def default_server_args(): @pytest.fixture(scope="module") def server_with_store(default_server_args): with RemoteOpenAIServer( - MODEL_NAME, - default_server_args, - env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, + MODEL_NAME, + default_server_args, + env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, ) as remote_server: yield remote_server diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 7a0baa5767cba..dd3a563e9570a 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import openai # use the official client for correctness check +import openai.types.responses as openai_responses_types import pytest @@ -35,24 +36,14 @@ async def test_instructions(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_chat(client: openai.AsyncOpenAI): - response = await client.responses.create(input=[ - { - "role": "system", - "content": "Finish the answer with QED." - }, - { - "role": "user", - "content": "What is 5 * 3?" - }, - { - "role": "assistant", - "content": "15. QED." - }, - { - "role": "user", - "content": "Multiply the result by 2." 
- }, - ], ) + response = await client.responses.create( + input=[ + {"role": "system", "content": "Finish the answer with QED."}, + {"role": "user", "content": "What is 5 * 3?"}, + {"role": "assistant", "content": "15. QED."}, + {"role": "user", "content": "Multiply the result by 2."}, + ], + ) print(response) output_text = response.output[-1].content[0].text @@ -62,15 +53,14 @@ async def test_chat(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_chat_with_input_type(client: openai.AsyncOpenAI): - response = await client.responses.create(input=[ - { - "role": "user", - "content": [{ - "type": "input_text", - "text": "Hello!" - }], - }, - ], ) + response = await client.responses.create( + input=[ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello!"}], + }, + ], + ) print(response) assert response.status == "completed" @@ -86,3 +76,18 @@ async def test_logprobs(client: openai.AsyncOpenAI): outputs = response.output assert outputs[-1].content[-1].logprobs assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 + + +@pytest.mark.asyncio +async def test_streaming(client: openai.AsyncOpenAI): + stream = await client.responses.create( + input="What is 13 * 24?", + stream=True, + ) + events = [event async for event in stream] + assert isinstance(events[0], openai_responses_types.ResponseCreatedEvent) + assert any( + isinstance(event, openai_responses_types.ResponseTextDeltaEvent) + for event in events + ) + assert isinstance(events[-1], openai_responses_types.ResponseCompletedEvent) diff --git a/tests/v1/entrypoints/openai/responses/test_image.py b/tests/v1/entrypoints/openai/responses/test_image.py index c8d09fd39fb13..980d83b787e7a 100644 --- a/tests/v1/entrypoints/openai/responses/test_image.py +++ b/tests/v1/entrypoints/openai/responses/test_image.py @@ -8,17 +8,17 @@ import pytest import pytest_asyncio from tests.utils import RemoteOpenAIServer -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 # Use a small vision model for testing MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -38,9 +38,9 @@ def default_image_server_args(): @pytest.fixture(scope="module") def image_server(default_image_server_args): with RemoteOpenAIServer( - MODEL_NAME, - default_image_server_args, - 
env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, + MODEL_NAME, + default_image_server_args, + env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, ) as remote_server: yield remote_server @@ -52,34 +52,33 @@ async def client(image_server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) -async def test_single_chat_session_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) +async def test_single_chat_session_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_image", - "image_url": image_url, - "detail": "auto", - }, - { - "type": "input_text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": image_url, + "detail": "auto", + }, + {"type": "input_text", "text": content_text}, + ], + } + ] # test image url response = await client.responses.create( @@ -91,30 +90,27 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) async def test_single_chat_session_image_base64encoded( client: openai.AsyncOpenAI, model_name: str, - image_url: str, + raw_image_url: str, base64_encoded_image: dict[str, str], ): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_image", - "image_url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}", - "detail": "auto", - }, - { - "type": "input_text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", # noqa: E501 + "detail": "auto", + }, + {"type": "input_text", "text": content_text}, + ], + } + ] # test image base64 response = await client.responses.create( model=model_name, @@ -127,24 +123,28 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) -async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: list[str]): - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "input_image", - "image_url": image_url, - "detail": "auto", - } for image_url in image_urls), - { - "type": "input_text", - "text": "What's in this image?" 
- }, - ], - }] + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True, +) +async def test_multi_image_input( + client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + { + "type": "input_image", + "image_url": image_url, + "detail": "auto", + } + for image_url in image_urls + ), + {"type": "input_text", "text": "What's in this image?"}, + ], + } + ] if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input @@ -155,10 +155,12 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, # the server should still work afterwards response = await client.responses.create( model=model_name, - input=[{ - "role": "user", - "content": "What's the weather like in Paris today?", - }], + input=[ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ], ) assert len(response.output_text) > 0 else: diff --git a/tests/v1/entrypoints/openai/responses/test_stateful.py b/tests/v1/entrypoints/openai/responses/test_stateful.py index a2d581ef7ced8..6f7edb6bd7e78 100644 --- a/tests/v1/entrypoints/openai/responses/test_stateful.py +++ b/tests/v1/entrypoints/openai/responses/test_stateful.py @@ -24,8 +24,7 @@ async def test_store(client: openai.AsyncOpenAI): assert response.status == "completed" # The response should not be found. - with pytest.raises(openai.NotFoundError, - match="Response with id .* not found."): + with pytest.raises(openai.NotFoundError, match="Response with id .* not found."): await client.responses.retrieve(response.id) @@ -53,8 +52,8 @@ async def test_background(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_background_error(client: openai.AsyncOpenAI): with pytest.raises( - openai.BadRequestError, - match="background can only be used when `store` is true"): + openai.BadRequestError, match="background can only be used when `store` is true" + ): _ = await client.responses.create( input="What is 13 * 24?", background=True, @@ -87,8 +86,9 @@ async def test_cancel_completed(client: openai.AsyncOpenAI): response = await client.responses.create(input="Hello") assert response.status == "completed" - with pytest.raises(openai.BadRequestError, - match="Cannot cancel a synchronous response."): + with pytest.raises( + openai.BadRequestError, match="Cannot cancel a synchronous response." + ): await client.responses.cancel(response.id) @@ -97,7 +97,8 @@ async def test_previous_response_id(client: openai.AsyncOpenAI): response1 = await client.responses.create( instructions="You are tested on your ability to retrieve the correct " "information from the previous response.", - input="Hello, my name is John.") + input="Hello, my name is John.", + ) response2 = await client.responses.create( input="Actually, my name is not John. My real name is Mark.", @@ -118,7 +119,8 @@ async def test_two_responses_with_same_prev_id(client: openai.AsyncOpenAI): response1 = await client.responses.create( instructions="You are tested on your ability to retrieve the correct " "information from the previous response.", - input="Hello, my name is John.") + input="Hello, my name is John.", + ) # Both response 2 and 3 use response 1 as the previous response. 
response2 = client.responses.create( diff --git a/tests/v1/entrypoints/openai/responses/test_structured_output.py b/tests/v1/entrypoints/openai/responses/test_structured_output.py index c4c43a87b601a..db8b87768e44f 100644 --- a/tests/v1/entrypoints/openai/responses/test_structured_output.py +++ b/tests/v1/entrypoints/openai/responses/test_structured_output.py @@ -11,14 +11,10 @@ from pydantic import BaseModel async def test_structured_output(client: openai.AsyncOpenAI): response = await client.responses.create( input=[ - { - "role": "system", - "content": "Extract the event information." - }, + {"role": "system", "content": "Extract the event information."}, { "role": "user", - "content": - "Alice and Bob are going to a science fair on Friday.", + "content": "Alice and Bob are going to a science fair on Friday.", }, ], text={ @@ -28,18 +24,9 @@ async def test_structured_output(client: openai.AsyncOpenAI): "schema": { "type": "object", "properties": { - "event_name": { - "type": "string" - }, - "date": { - "type": "string" - }, - "participants": { - "type": "array", - "items": { - "type": "string" - } - }, + "event_name": {"type": "string"}, + "date": {"type": "string"}, + "participants": {"type": "array", "items": {"type": "string"}}, }, "required": ["event_name", "date", "participants"], "additionalProperties": False, @@ -65,7 +52,6 @@ async def test_structured_output(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_structured_output_with_parse(client: openai.AsyncOpenAI): - class CalendarEvent(BaseModel): event_name: str date: str diff --git a/tests/v1/entrypoints/openai/test_chat_completion.py b/tests/v1/entrypoints/openai/test_chat_completion.py index dffb32846c05e..522c72b559556 100644 --- a/tests/v1/entrypoints/openai/test_chat_completion.py +++ b/tests/v1/entrypoints/openai/test_chat_completion.py @@ -40,8 +40,7 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_invalid_json_schema(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_invalid_json_schema(client: openai.AsyncOpenAI, model_name: str) -> None: invalid_json_schema = { "$defs": { "CarType": { @@ -51,33 +50,29 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, } }, "properties": { - "brand": { - "title": "Brand", - "type": "string" - }, - "model": { - "title": "Model", - "type": "string" - }, - "car_type": { - "$ref": "#/$defs/CarType" - }, + "brand": {"title": "Brand", "type": "string"}, + "model": {"title": "Model", "type": "string"}, + "car_type": {"$ref": "#/$defs/CarType"}, "foo": "bar", }, "required": ["brand", "model", "car_type"], "title": "CarDescription", "type": "object", } - prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_json": invalid_json_schema}, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={"structured_outputs": {"json": invalid_json_schema}}, ) @@ -87,21 +82,22 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, [MODEL_NAME], ) async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. 
Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "guided_regex": r"[.*", - "stop": ["\n"] - }, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={"structured_outputs": {"regex": r"[.*"}, "stop": ["\n"]}, ) @@ -125,14 +121,20 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): number ::= "1 " | "2 " """ - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table." + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={ + "structured_outputs": {"grammar": invalid_simplified_sql_grammar} + }, ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 3a65583fab8d3..66dbed2b9fddf 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -31,12 +31,13 @@ def default_server_args(): ] -@pytest.fixture(scope="module", - params=[["--no-enable-prefix-caching"], - [ - "--no-enable-prefix-caching", - "--disable-frontend-multiprocessing" - ]]) +@pytest.fixture( + scope="module", + params=[ + ["--no-enable-prefix-caching"], + ["--no-enable-prefix-caching", "--disable-frontend-multiprocessing"], + ], +) def server(default_server_args, request): if request.param: default_server_args = default_server_args + request.param @@ -55,12 +56,10 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_single_completion(client: openai.AsyncOpenAI, - model_name: str) -> None: - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str) -> None: + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -69,7 +68,8 @@ async def test_single_completion(client: openai.AsyncOpenAI, assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) # test using token IDs completion = await client.completions.create( @@ -147,11 +147,12 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME], ) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str) -> None: - +async def test_too_many_completion_logprobs( + client: openai.AsyncOpenAI, model_name: str +) -> None: with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using 
token IDs await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -163,7 +164,8 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ) ... with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using token IDs stream = await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -188,13 +190,13 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), - (MODEL_NAME, 0), - (MODEL_NAME, 1), - (MODEL_NAME, None)]) -async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +@pytest.mark.parametrize( + "model_name, prompt_logprobs", + [(MODEL_NAME, -1), (MODEL_NAME, 0), (MODEL_NAME, 1), (MODEL_NAME, None)], +) +async def test_prompt_logprobs_completion( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, @@ -223,8 +225,9 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_completion_streaming( + client: openai.AsyncOpenAI, model_name: str +) -> None: prompt = "What is an LLM?" single_completion = await client.completions.create( @@ -234,11 +237,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, temperature=0.0, ) single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: @@ -257,8 +258,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_parallel_no_streaming(client: openai.AsyncOpenAI, - model_name: str): +async def test_parallel_no_streaming(client: openai.AsyncOpenAI, model_name: str): """Parallel sampling without streaming. A single request output contains a list of completions. """ @@ -268,27 +268,26 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, max_tokens = 50 # we want some to finish earlier than others # High temperature to maximize chance of unique completions. - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - temperature=1.0, - stream=False, - logprobs=0, - seed=42) + completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=1.0, + stream=False, + logprobs=0, + seed=42, + ) # Assert `n` completions num_completions = len(completion.choices) - assert num_completions == n, ( - f"Num completions {num_completions} but expected {n}.") + assert num_completions == n, f"Num completions {num_completions} but expected {n}." completion_repeats: dict[str, int] = {} output_token_lengths = set() for idx, choice in enumerate(completion.choices): # Assert correct completion index & some finish reason. 
- assert choice.index == idx, ( - f"Index {choice.index} but expected {idx}.") - assert choice.finish_reason is not None, ( - "None finish_reason is invalid.") + assert choice.index == idx, f"Index {choice.index} but expected {idx}." + assert choice.finish_reason is not None, "None finish_reason is invalid." text = choice.text completion_repeats[text] = completion_repeats.get(text, 0) + 1 output_token_lengths.add(len(choice.logprobs.tokens)) @@ -297,13 +296,10 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: - repeats = { - txt: num - for (txt, num) in completion_repeats.items() if num > 1 - } + repeats = {txt: num for (txt, num) in completion_repeats.items() if num > 1} raise AssertionError( - f"Expected {n} unique completions, got {num_unique};" - f" repeats: {repeats}.") + f"Expected {n} unique completions, got {num_unique}; repeats: {repeats}." + ) @pytest.mark.asyncio @@ -321,13 +317,15 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): n = 3 max_tokens = 50 # we want some to finish earlier than others - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - temperature=1.0, - stream=True, - seed=42) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=1.0, + stream=True, + seed=42, + ) chunks: list[list[str]] = [[] for _ in range(n)] finish_reason_count = 0 async for chunk in stream: @@ -338,7 +336,8 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): finish_reason_count += 1 # Assert `n` completions with correct finish reasons assert finish_reason_count == n, ( - f"Expected {n} completions with valid indices and finish_reason.") + f"Expected {n} completions with valid indices and finish_reason." + ) completion_repeats: dict[str, int] = {} chunk_lengths = set() for chunk in chunks: @@ -346,7 +345,8 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert correct number of completion tokens chunk_lengths.add(chunk_len) assert chunk_len <= max_tokens, ( - f"max_tokens={max_tokens} but chunk len is {chunk_len}.") + f"max_tokens={max_tokens} but chunk len is {chunk_len}." + ) text = "".join(chunk) completion_repeats[text] = completion_repeats.get(text, 0) + 1 print(text) @@ -355,12 +355,10 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: - repeats = { - txt: num - for (txt, num) in completion_repeats.items() if num > 1 - } - raise AssertionError(f"{num_unique} unique completions, expected {n};" - f" repeats: {repeats}") + repeats = {txt: num for (txt, num) in completion_repeats.items() if num > 1} + raise AssertionError( + f"{num_unique} unique completions, expected {n}; repeats: {repeats}" + ) @pytest.mark.asyncio @@ -368,114 +366,122 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME], ) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): +async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): prompt = "What is the capital of France?" 
# Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - True, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: assert chunk.usage is None else: assert chunk.usage is None - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - True, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is not None assert chunk.usage.prompt_tokens > 0 assert chunk.usage.completion_tokens > 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if chunk.choices[0].finish_reason is not None: - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + 
final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options= # {"include_usage": None} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options= # {"include_usage": True} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": None} @@ -486,7 +492,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": None}) + stream_options={"continuous_usage_stats": None}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": True} @@ -497,7 +504,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": True}) + stream_options={"continuous_usage_stats": True}, + ) @pytest.mark.asyncio @@ -528,15 +536,19 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): extra_body=dict( # NOTE: this has to be true for n > 1 in vLLM, but # not necessary for official client. 
- use_beam_search=True), + use_beam_search=True + ), ) assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" + assert batch.choices[0].text != batch.choices[1].text, ( + "beam search should be different" + ) + assert batch.choices[0].text == batch.choices[2].text, ( + "two copies of the same prompt should be the same" + ) + assert batch.choices[1].text == batch.choices[3].text, ( + "two copies of the same prompt should be the same" + ) # test streaming batch = await client.completions.create( @@ -560,31 +572,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME], ) @pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): +async def test_echo_logprob_completion( + client: openai.AsyncOpenAI, model_name: str, logprobs_arg: int +): tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) # test using text and token IDs for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) + completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg, + ) - prompt_text = tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, list) else prompt assert re.search(r"^" + prompt_text, completion.choices[0].text) logprobs = completion.choices[0].logprobs assert logprobs is not None assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) + assert len(logprobs.token_logprobs) > 5 and logprobs.token_logprobs[0] is None + assert len(logprobs.top_logprobs) > 5 and logprobs.top_logprobs[0] is None for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 @@ -593,8 +604,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_invalid_json_schema(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_invalid_json_schema(client: openai.AsyncOpenAI, model_name: str) -> None: invalid_json_schema = { "$defs": { "CarType": { @@ -604,30 +614,24 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, } }, "properties": { - "brand": { - "title": "Brand", - "type": "string" - }, - "model": { - "title": "Model", - "type": "string" - }, - "car_type": { - "$ref": "#/$defs/CarType" - }, + "brand": {"title": "Brand", "type": "string"}, + "model": {"title": "Model", "type": "string"}, + "car_type": {"$ref": "#/$defs/CarType"}, "foo": "bar", }, "required": ["brand", "model", "car_type"], "title": "CarDescription", "type": "object", } - prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and car_type 
of" + "the most iconic car from the 90's" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, - extra_body={"guided_json": invalid_json_schema}, + extra_body={"structured_outputs": {"json": invalid_json_schema}}, ) @@ -637,18 +641,17 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, [MODEL_NAME], ) async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, - extra_body={ - "guided_regex": r"[.*", - "stop": ["\n"] - }, + extra_body={"structured_outputs": {"regex": r"[.*"}, "stop": ["\n"]}, ) @@ -672,25 +675,29 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): number ::= "1 " | "2 " """ - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table." + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, - extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + extra_body={ + "structured_outputs": {"grammar": invalid_simplified_sql_grammar} + }, ) @pytest.mark.asyncio -async def test_completion_with_empty_prompt_embeds( - client: openai.AsyncOpenAI) -> None: +async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None: """Test completion with empty prompt embeds.""" - payload: dict[str, list] = {"prompt_embeds": []} + payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} headers: dict[str, str] = {"Content-Type": "application/json"} # base_url = http://localhost:8000/v1/completions - response = requests.post(f"{client.base_url}completions", - headers=headers, - json=payload) + response = requests.post( + f"{client.base_url}completions", headers=headers, json=payload + ) assert response.status_code == 200, ( - f"Expected status code 200, got {response.status_code}. ") + f"Expected status code 200, got {response.status_code}. 
" + ) diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 41f1d02bf7870..3c2b3de339585 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -37,9 +37,9 @@ def default_image_embeds_server_args() -> list[str]: @pytest.fixture(scope="module") def server_with_image_embeds(default_image_embeds_server_args): - with RemoteOpenAIServer(MODEL_NAME, - default_image_embeds_server_args, - max_wait_seconds=600) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, default_image_embeds_server_args, max_wait_seconds=600 + ) as remote_server: yield remote_server @@ -57,7 +57,7 @@ def encode_image_embedding_to_base64(image_embedding) -> str: torch.save(image_embedding, buffer) buffer.seek(0) binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_image_embedding = base64.b64encode(binary_data).decode("utf-8") return base64_image_embedding @@ -75,19 +75,13 @@ async def test_completions_with_image_embeds( base64_image_embedding = encode_image_embedding_to_base64(image_embeds) chat_completion = await client_with_image_embeds.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": - "user", + "role": "user", "content": [ { - "type": - "text", - "text": - "Describe these images separately. For each image," + "type": "text", + "text": "Describe these images separately. For each image," "reply with a short sentence (no more than 10 words).", }, { diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index f7c31b0c43778..55328f0cf0f09 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -8,7 +8,7 @@ import pytest import pytest_asyncio from tests.utils import RemoteOpenAIServer -from tests.v1.test_utils import check_request_balancing +from tests.v1.utils import check_request_balancing MODEL_NAME = "ibm-research/PowerMoE-3b" @@ -50,16 +50,13 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_single_completion(client: openai.AsyncOpenAI, - server: RemoteOpenAIServer, - model_name: str) -> None: - +async def test_single_completion( + client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str +) -> None: async def make_request(): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=10, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=10, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -108,9 +105,9 @@ async def test_single_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_completion_streaming(client: openai.AsyncOpenAI, - server: RemoteOpenAIServer, - model_name: str) -> None: +async def test_completion_streaming( + client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str +) -> None: prompt = "What is an LLM?" 
async def make_streaming_request(): @@ -124,11 +121,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -139,16 +134,15 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." + ) return True # Indicate success for this request # Test single request @@ -162,9 +156,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, tasks = [make_streaming_request() for _ in range(num_requests)] results = await asyncio.gather(*tasks) - assert len( - results - ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert len(results) == num_requests, ( + f"Expected {num_requests} results, got {len(results)}" + ) assert all(results), "Not all streaming requests completed successfully." await asyncio.sleep(0.5) @@ -172,9 +166,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, tasks = [make_streaming_request() for _ in range(num_requests)] results = await asyncio.gather(*tasks) - assert len( - results - ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert len(results) == num_requests, ( + f"Expected {num_requests} results, got {len(results)}" + ) assert all(results), "Not all streaming requests completed successfully." # Check request balancing via Prometheus metrics if DP_SIZE > 1 diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index bdd5155c1481d..c8bcd62d66802 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -14,20 +14,20 @@ from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.executor.multiproc_executor import MultiprocExecutor -class Mock: - ... +class Mock: ... 
class CustomMultiprocExecutor(MultiprocExecutor): - - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - unique_reply_rank: Optional[int] = None) -> list[Any]: - # Drop marker to show that this was ran + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: + # Drop marker to show that this was run with open(".marker", "w"): ... return super().collective_rpc(method, timeout, args, kwargs) @@ -47,17 +47,22 @@ def test_custom_executor_type_checking(): ) LLMEngine.from_engine_args(engine_args) with pytest.raises(ValueError): - engine_args = AsyncEngineArgs(model=MODEL, - gpu_memory_utilization=0.2, - max_model_len=8192, - distributed_executor_backend=Mock) + engine_args = AsyncEngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock, + ) AsyncLLM.from_engine_args(engine_args) -@pytest.mark.parametrize("distributed_executor_backend", [ - CustomMultiprocExecutor, - "tests.v1.executor.test_executor.CustomMultiprocExecutor" -]) +@pytest.mark.parametrize( + "distributed_executor_backend", + [ + CustomMultiprocExecutor, + "tests.v1.executor.test_executor.CustomMultiprocExecutor", + ], +) def test_custom_executor(distributed_executor_backend, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -82,10 +87,13 @@ def test_custom_executor(distributed_executor_backend, tmp_path): os.chdir(cwd) -@pytest.mark.parametrize("distributed_executor_backend", [ - CustomMultiprocExecutorAsync, - "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync" -]) +@pytest.mark.parametrize( + "distributed_executor_backend", + [ + CustomMultiprocExecutorAsync, + "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync", + ], +) def test_custom_executor_async(distributed_executor_backend, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -103,9 +111,9 @@ def test_custom_executor_async(distributed_executor_backend, tmp_path): sampling_params = SamplingParams(max_tokens=1) async def t(): - stream = engine.generate(request_id="0", - prompt="foo", - sampling_params=sampling_params) + stream = engine.generate( + request_id="0", prompt="foo", sampling_params=sampling_params + ) async for x in stream: ... diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py new file mode 100644 index 0000000000000..31f6f377da624 --- /dev/null +++ b/tests/v1/generation/test_batch_invariance.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import random +import string + +import pytest +import torch + +from vllm import LLM, SamplingParams + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + # Lightweight random prompt generator to vary prompt lengths and content. 
+ vocab = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "hotel", + "india", + "juliet", + "kilo", + "lima", + "mike", + "november", + "oscar", + "papa", + "quebec", + "romeo", + "sierra", + "tango", + "uniform", + "victor", + "whiskey", + "xray", + "yankee", + "zulu", + ] + n = random.randint(min_words, max_words) + words = random.choices(vocab, k=n) + + # Add some noise and punctuation variability + if random.random() < 0.5: + words[0] = words[0].capitalize() + if random.random() < 0.2: + words.append("".join(random.choices(string.ascii_lowercase, k=5))) + punct = random.choice([".", "?", "!", "...", ""]) + return " ".join(words) + punct + + +@pytest.mark.timeout(1000) +def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): + """ + Ensures that the same request (the 'needle' prompt) yields identical output + whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), + using the high-level v1 LLM() API only (no manual batching). + + Strategy: + - Create two LLM engines with identical config except max_num_seqs: 1 vs N. + - Compute a baseline output for the needle prompt with the bs=1 engine. + - For many trials, generate a batch (size N) where the needle appears at a + random position among random filler prompts using the bs=N engine. + - Track how many trials match vs mismatch, and report totals at the end. + The test fails if any mismatches occur, but we still dump pass/fail + counts. + + Notes: + - Use seeded stochastic sampling with a fixed seed to test determinism. + - Outputs are intentionally longer and sampled at higher temperature/top_p + to produce a more random-sounding phrase, yet remain deterministic by + seed. + - Keep max_tokens and max_model_len bounded for speed and memory use. + """ + random.seed(12345) + + # Allow overrides from environment (useful for CI tuning) + # "facebook/opt-125m" is too small, doesn't reliably test determinism + model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) + batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64")) + assert batch_size >= 2, "Batch size should be >= 2 to mix needle." + + # Keep GPU memory usage low to avoid startup allocation failures. + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")) + swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) + + # Sampling parameters: longer outputs with a more random-sounding + # continuation,but still deterministic due to fixed seed. + temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0")) + top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95")) + max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128")) + + sampling = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = "There once was a " + + llm_bs1 = None + llm_bsN = None + try: + # Engine with bs=1 behavior + llm_bs1 = LLM_with_max_seqs( + model=model, + max_num_seqs=1, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + # Baseline generation for the needle prompt alone. 
+ baseline_out = llm_bs1.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + assert len(baseline_out[0].outputs) >= 1 + baseline_text = baseline_out[0].outputs[0].text + + # Engine with larger batch limit (e.g., 64) + llm_bsN = LLM_with_max_seqs( + model=model, + max_num_seqs=batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + mismatches = 0 + + for trial in range(num_trials): + # Create a batch of size `batch_size` and insert the needle at + # a random index + prompts: list[str] = [] + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append(_random_prompt()) + + # Generate with the larger-batch engine + outputs = llm_bsN.generate(prompts, sampling) + # Find the needle output by position + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + assert len(needle_output.outputs) >= 1 + text = needle_output.outputs[0].text + + if text != baseline_text: + mismatches += 1 + + passes = num_trials - mismatches + # Dump how many passed vs failed + print( + f"[determinism] total={num_trials}, passed={passes}, " + f"failed={mismatches}, batch_size={batch_size}" + ) + + if mismatches > 0: + pytest.fail( + f"Nondeterministic outputs detected: {mismatches} failed out " + f"of {num_trials} trials (batch_size={batch_size})." + ) + + finally: + # Ensure engines are shutdown to free GPU/VRAM across test sessions + if llm_bs1 is not None: + with contextlib.suppress(Exception): + llm_bs1.shutdown() + if llm_bsN is not None: + with contextlib.suppress(Exception): + llm_bsN.shutdown() + + +def _extract_step_logprobs(request_output): + if getattr(request_output, "outputs", None): + inner = request_output.outputs[0] + if hasattr(inner, "logprobs") and inner.logprobs is not None: + t = torch.tensor( + [ + inner.logprobs[i][tid].logprob + for i, tid in enumerate(inner.token_ids) + ], + dtype=torch.float32, + ) + return t + + return None + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Requires CUDA to match production inference path.", +) +def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2(): + # model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m") + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + # Force float32 to avoid precision-induced differences. + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enforce_eager=True, # helps reduce nondeterminism from some backends + ) + + prompts = [ + "The capital of France is", + "The capital of Germany is", + ] + + sp = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=8, + # Seed shouldn't matter at temperature=0, but keeping it stable anyway. + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + bs1_logprobs_per_prompt = [] + for p in prompts: + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bs1_logprobs_per_prompt.append(step_logprobs) + + # BS=2: run prompts in a batch and collect logprobs per step for each + # prompt. 
+ outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bs2_logprobs_per_prompt = [] + for o in outs_batched: + step_logprobs = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bs2_logprobs_per_prompt.append(step_logprobs) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs. + for i, (logprobs_bs1, logprobs_bs2) in enumerate( + zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt) + ): + assert len(logprobs_bs1) == len(logprobs_bs2), ( + f"Different number of generation steps for prompt index {i}: " + f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)" + ) + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)): + assert a.shape == b.shape, ( + f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}" + ) + # Bitwise exact equality. + assert torch.equal(a, b), ( + f"Bitwise logprobs mismatch at prompt {i}, step {t} " + f"(dtype={a.dtype}, shape={a.shape})." + ) + + +def LLM_with_max_seqs( + model: str, + max_num_seqs: int, + gpu_memory_utilization: float, + max_model_len: int, + swap_space: int, +) -> LLM: + """ + Helper to construct an LLM with a specific max_num_seqs (batch-size limit) + using the high-level v1 LLM API, while constraining memory usage. + """ + return LLM( + model=model, + max_num_seqs=max_num_seqs, + # Constrain GPU memory pool so test can run even on busy GPUs. + gpu_memory_utilization=gpu_memory_utilization, + # Keep KV cache footprint small while allowing longer outputs. + max_model_len=max_model_len, + # Allow some CPU offload if needed. + swap_space=swap_space, + # Keep things lean and CI-friendly. + dtype="auto", + # Single-GPU by default; override externally if desired. + tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1", + enable_prefix_caching=False, + # Enable for MOE models + # enable_expert_parallel=True, + ) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 9322410ec99e9..3b0f2d102c1ff 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -1,6 +1,31 @@ #!/bin/bash set -xe +# Parse command line arguments +KV_BUFFER_DEVICE="cuda" # Default to cuda +while [[ $# -gt 0 ]]; do + case $1 in + --kv_buffer_device) + KV_BUFFER_DEVICE="$2" + shift 2 + ;; + *) + echo "Unknown option $1" + echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" + exit 1 + ;; + esac +done + +echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" + +# Build the kv-transfer-config once +if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' +else + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}" +fi + # Models to run MODELS=( "Qwen/Qwen3-0.6B" @@ -79,18 +104,21 @@ run_tests_for_model() { # Calculate port number (base port + instance number) PORT=$((8100 + i)) - # Calculate side channel port. Avoid clash with with TP workers. + # Calculate side channel port. Avoid clash with with TP workers. 
SIDE_CHANNEL_PORT=$((5559 + i)) echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + UCX_NET_DEVICES=all \ + VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ + vllm serve $model_name \ --port $PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" @@ -117,12 +145,15 @@ run_tests_for_model() { echo "Starting decode instance $i on GPU $GPU_ID, port $PORT" # Build the command with or without model-specific args - BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + UCX_NET_DEVICES=all \ + VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ + vllm serve $model_name \ --port $PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ --tensor-parallel-size $DECODER_TP_SIZE \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh old mode 100644 new mode 100755 index b64461292910d..c48b452e24cd4 --- a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh @@ -1,6 +1,33 @@ #!/bin/bash set -xe +# Parse command line arguments +KV_BUFFER_DEVICE="cuda" # Default to cuda +PREFILL_GPU_ID=4 # Default GPU IDs +DECODE_GPU_ID=5 +while [[ $# -gt 0 ]]; do + case $1 in + --kv_buffer_device) + KV_BUFFER_DEVICE="$2" + shift 2 + ;; + *) + echo "Unknown option $1" + echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" + exit 1 + ;; + esac +done + +echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)" + +# Build the kv-transfer-config once +if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' +else + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}" +fi + # Models to run MODELS=( "Qwen/Qwen3-0.6B" @@ -50,15 +77,15 @@ run_tests_for_model() { # Get model-specific arguments local model_args=$(get_model_args "$model_name") - + # Start prefill instance PREFILL_PORT=8001 - BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \ --port $PREFILL_PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" @@ -72,11 +99,11 @@ run_tests_for_model() { DECODE_PORT=8002 # Build the command with or without model-specific args - BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \ + BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \ --port $DECODE_PORT \ --enforce-eager 
\ --gpu-memory-utilization 0.2 \ - --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'" + --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh index ea125f99fc42c..fa1738bb31940 100644 --- a/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh @@ -53,7 +53,6 @@ cleanup() { launch_baseline() { BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; VLLM_LOGGING_LEVEL=DEBUG \ - VLLM_USE_V1=1 \ PJRT_DEVICE=TPU \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ @@ -73,7 +72,6 @@ launch_pd() { UCX_TLS=tcp \ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ VLLM_LOGGING_LEVEL=DEBUG \ - VLLM_USE_V1=1 \ VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ PJRT_DEVICE=TPU \ @@ -93,7 +91,6 @@ launch_pd() { UCX_TLS=tcp \ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ VLLM_LOGGING_LEVEL=DEBUG \ - VLLM_USE_V1=1 \ PJRT_DEVICE=TPU \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh index 8ba653770c4f0..3d63822371bed 100644 --- a/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh @@ -55,7 +55,6 @@ launch_pd() { UCX_TLS=tcp \ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ VLLM_LOGGING_LEVEL=DEBUG \ - VLLM_USE_V1=1 \ VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ PJRT_DEVICE=TPU \ @@ -75,7 +74,6 @@ launch_pd() { UCX_TLS=tcp \ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ VLLM_LOGGING_LEVEL=DEBUG \ - VLLM_USE_V1=1 \ PJRT_DEVICE=TPU \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index e5d66ffeeeb23..b301968e5bf84 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -12,12 +12,12 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 # Model-specific expected values -EXPECTED_VALUES = { - "Qwen/Qwen3-0.6B": 0.41, - "deepseek-ai/deepseek-vl2-small": 0.59 -} +EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59} -SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 +SIMPLE_PROMPT = ( + "The best part about working on vLLM is that I got to meet so many people across " + "various different organizations like UCB, Google, and Meta which means", +) # Get model name from environment variable MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") @@ -25,8 +25,7 @@ MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") def run_simple_prompt(): client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) - completion = client.completions.create(model=MODEL_NAME, - prompt=SIMPLE_PROMPT) + completion = client.completions.create(model=MODEL_NAME, 
prompt=SIMPLE_PROMPT) print("-" * 50) print(f"Completion results for {MODEL_NAME}:") @@ -38,9 +37,11 @@ def test_accuracy(): """Run the end to end accuracy test.""" run_simple_prompt() - model_args = (f"model={MODEL_NAME}," - f"base_url={BASE_URL}/completions," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + model_args = ( + f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -52,11 +53,14 @@ def test_accuracy(): expected_value = EXPECTED_VALUES.get(MODEL_NAME) if expected_value is None: - print(f"Warning: No expected value found for {MODEL_NAME}. " - "Skipping accuracy check.") + print( + f"Warning: No expected value found for {MODEL_NAME}. " + "Skipping accuracy check." + ) print(f"Measured value: {measured_value}") return - assert (measured_value - RTOL < expected_value - and measured_value + RTOL > expected_value - ), f"Expected: {expected_value} | Measured: {measured_value}" + assert ( + measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py index 697e101c35926..caa4aab870abe 100644 --- a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -43,37 +43,39 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool: if response.status_code == 200: return True else: - print(f"Attempt {attempt + 1}: Server returned status code " - "{response.status_code}") + print( + f"Attempt {attempt + 1}: Server returned status code " + "{response.status_code}" + ) except requests.exceptions.RequestException as e: print(f"Attempt {attempt + 1}: Error connecting to server: {e}") time.sleep(1) # Wait before retrying return False -def run_simple_prompt(base_url: str, model_name: str, input_prompt: str, - use_chat_endpoint: bool) -> str: +def run_simple_prompt( + base_url: str, model_name: str, input_prompt: str, use_chat_endpoint: bool +) -> str: client = openai.OpenAI(api_key="EMPTY", base_url=base_url) if use_chat_endpoint: completion = client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": [{ - "type": "text", - "text": input_prompt - }] - }], + messages=[ + {"role": "user", "content": [{"type": "text", "text": input_prompt}]} + ], max_completion_tokens=MAX_OUTPUT_LEN, temperature=0.0, - seed=42) + seed=42, + ) return completion.choices[0].message.content else: - completion = client.completions.create(model=model_name, - prompt=input_prompt, - max_tokens=MAX_OUTPUT_LEN, - temperature=0.0, - seed=42) + completion = client.completions.create( + model=model_name, + prompt=input_prompt, + max_tokens=MAX_OUTPUT_LEN, + temperature=0.0, + seed=42, + ) return completion.choices[0].text @@ -90,7 +92,8 @@ def main(): "--service_url", # Name of the first argument type=str, required=True, - help="The vLLM service URL.") + help="The vLLM service URL.", + ) parser.add_argument( "--model_name", # Name of the first argument @@ -127,28 +130,30 @@ def main(): if not os.path.exists(args.file_name): raise ValueError( f"In disagg mode, the output file {args.file_name} from " - "non-disagg. baseline does not exist.") + "non-disagg. baseline does not exist." 
+ ) service_url = f"{args.service_url}/v1" if not check_vllm_server(health_check_url): - raise RuntimeError( - f"vllm server: {args.service_url} is not ready yet!") + raise RuntimeError(f"vllm server: {args.service_url} is not ready yet!") output_strs = dict() for i, prompt in enumerate(SAMPLE_PROMPTS): - use_chat_endpoint = (i % 2 == 1) - output_str = run_simple_prompt(base_url=service_url, - model_name=args.model_name, - input_prompt=prompt, - use_chat_endpoint=use_chat_endpoint) + use_chat_endpoint = i % 2 == 1 + output_str = run_simple_prompt( + base_url=service_url, + model_name=args.model_name, + input_prompt=prompt, + use_chat_endpoint=use_chat_endpoint, + ) print(f"Prompt: {prompt}, output: {output_str}") output_strs[prompt] = output_str if args.mode == "baseline": # baseline: save outputs try: - with open(args.file_name, 'w') as json_file: + with open(args.file_name, "w") as json_file: json.dump(output_strs, json_file, indent=4) except OSError as e: print(f"Error writing to file: {e}") diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py index 8439e30be154b..268a1845a2bba 100644 --- a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -12,8 +12,7 @@ PROXY_HOST = os.getenv("PROXY_HOST", "localhost") PROXY_PORT = os.getenv("PROXY_PORT", None) if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: - raise ValueError( - "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + raise ValueError("Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 @@ -41,13 +40,13 @@ def test_edge_cases(): # (1) Check that we can handle a very short prompt, # less than the length of the block size. - completion = proxy_client.completions.create(model=MODEL, - prompt=SHORT_PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=SHORT_PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create(model=MODEL, - prompt=SHORT_PROMPT, - temperature=0) + completion = prefill_client.completions.create( + model=MODEL, prompt=SHORT_PROMPT, temperature=0 + ) prefill_response = completion.choices[0].text print(f"SMALL PROMPT: {proxy_response=}") assert proxy_response == prefill_response @@ -55,27 +54,27 @@ def test_edge_cases(): # (2) Check that we can handle a full prefix cache # hit on the D worker but not on the P worker. # (2a): prime the D worker. 
- completion = decode_client.completions.create(model=MODEL, - prompt=PROMPT, - temperature=0) + completion = decode_client.completions.create( + model=MODEL, prompt=PROMPT, temperature=0 + ) decode_response = completion.choices[0].text # (2b): send via the P/D setup - completion = proxy_client.completions.create(model=MODEL, - prompt=PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text print(f"FULL CACHE HIT: {proxy_response=}") assert proxy_response == decode_response # (3) Check that we can handle a partial prefix cache # hit on the D worker. - completion = proxy_client.completions.create(model=MODEL, - prompt=LONG_PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=LONG_PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create(model=MODEL, - prompt=LONG_PROMPT, - temperature=0) + completion = prefill_client.completions.create( + model=MODEL, prompt=LONG_PROMPT, temperature=0 + ) prefill_response = completion.choices[0].text print(f"PARTIAL CACHE HIT: {proxy_response=}") assert proxy_response == prefill_response diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 905ae0ea71722..37d70510fe256 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -27,49 +27,45 @@ async def lifespan(app: FastAPI): # Create prefill clients for i, (host, port) in enumerate(global_args.prefiller_instances): - prefiller_base_url = f'http://{host}:{port}/v1' - app.state.prefill_clients.append({ - 'client': - httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), - 'host': - host, - 'port': - port, - 'id': - i - }) + prefiller_base_url = f"http://{host}:{port}/v1" + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + "host": host, + "port": port, + "id": i, + } + ) # Create decode clients for i, (host, port) in enumerate(global_args.decoder_instances): - decoder_base_url = f'http://{host}:{port}/v1' - app.state.decode_clients.append({ - 'client': - httpx.AsyncClient(timeout=None, base_url=decoder_base_url), - 'host': - host, - 'port': - port, - 'id': - i - }) + decoder_base_url = f"http://{host}:{port}/v1" + app.state.decode_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + "host": host, + "port": port, + "id": i, + } + ) # Initialize round-robin iterators - app.state.prefill_iterator = itertools.cycle( - range(len(app.state.prefill_clients))) - app.state.decode_iterator = itertools.cycle( - range(len(app.state.decode_clients))) + app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) - print(f"Initialized {len(app.state.prefill_clients)} prefill clients " - f"and {len(app.state.decode_clients)} decode clients.") + print( + f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." 
+ ) yield # Shutdown: Close all clients for client_info in app.state.prefill_clients: - await client_info['client'].aclose() + await client_info["client"].aclose() for client_info in app.state.decode_clients: - await client_info['client'].aclose() + await client_info["client"].aclose() # Update FastAPI app initialization to use lifespan @@ -83,43 +79,38 @@ def parse_args(): parser.add_argument("--host", type=str, default="localhost") # For prefiller instances - parser.add_argument("--prefiller-hosts", - "--prefiller-host", - type=str, - nargs="+", - default=["localhost"]) - parser.add_argument("--prefiller-ports", - "--prefiller-port", - type=int, - nargs="+", - default=[8100]) + parser.add_argument( + "--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"], + ) + parser.add_argument( + "--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100] + ) # For decoder instances - parser.add_argument("--decoder-hosts", - "--decoder-host", - type=str, - nargs="+", - default=["localhost"]) - parser.add_argument("--decoder-ports", - "--decoder-port", - type=int, - nargs="+", - default=[8200]) + parser.add_argument( + "--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"] + ) + parser.add_argument( + "--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200] + ) args = parser.parse_args() # Validate and pair hosts with ports if len(args.prefiller_hosts) != len(args.prefiller_ports): raise ValueError( - "Number of prefiller hosts must match number of prefiller ports") + "Number of prefiller hosts must match number of prefiller ports" + ) if len(args.decoder_hosts) != len(args.decoder_ports): - raise ValueError( - "Number of decoder hosts must match number of decoder ports") + raise ValueError("Number of decoder hosts must match number of decoder ports") # Create tuples of (host, port) for each service type - args.prefiller_instances = list( - zip(args.prefiller_hosts, args.prefiller_ports)) + args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports)) args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) return args @@ -136,29 +127,30 @@ def get_next_client(app, service_type: str): Returns: The next client to use """ - if service_type == 'prefill': + if service_type == "prefill": client_idx = next(app.state.prefill_iterator) return app.state.prefill_clients[client_idx] - elif service_type == 'decode': + elif service_type == "decode": client_idx = next(app.state.decode_iterator) return app.state.decode_clients[client_idx] else: raise ValueError(f"Unknown service type: {service_type}") -async def send_request_to_service(client_info: dict, endpoint: str, - req_data: dict, request_id: str): +async def send_request_to_service( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): """ Send a request to a service using a client from the pool. 
""" req_data = req_data.copy() - req_data['kv_transfer_params'] = { + req_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "remote_engine_id": None, "remote_block_ids": None, "remote_host": None, - "remote_port": None + "remote_port": None, } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -168,31 +160,31 @@ async def send_request_to_service(client_info: dict, endpoint: str, del req_data["stream_options"] headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - response = await client_info['client'].post(endpoint, - json=req_data, - headers=headers) + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) response.raise_for_status() return response -async def stream_service_response(client_info: dict, endpoint: str, - req_data: dict, request_id: str): +async def stream_service_response( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): """ Asynchronously stream response from a service using a client from the pool. """ headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - async with client_info['client'].stream("POST", - endpoint, - json=req_data, - headers=headers) as response: + async with client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk @@ -204,40 +196,39 @@ async def _handle_completions(api: str, request: Request): request_id = str(uuid.uuid4()) # Get the next prefill client in round-robin fashion - prefill_client_info = get_next_client(request.app, 'prefill') + prefill_client_info = get_next_client(request.app, "prefill") # Send request to prefill service - response = await send_request_to_service(prefill_client_info, api, - req_data, request_id) + response = await send_request_to_service( + prefill_client_info, api, req_data, request_id + ) # Extract the needed fields response_json = response.json() - kv_transfer_params = response_json.get('kv_transfer_params', {}) + kv_transfer_params = response_json.get("kv_transfer_params", {}) if kv_transfer_params: req_data["kv_transfer_params"] = kv_transfer_params # Get the next decode client in round-robin fashion - decode_client_info = get_next_client(request.app, 'decode') + decode_client_info = get_next_client(request.app, "decode") logger.debug("Using %s %s", prefill_client_info, decode_client_info) # Stream response from decode service async def generate_stream(): - async for chunk in stream_service_response(decode_client_info, - api, - req_data, - request_id=request_id): + async for chunk in stream_service_response( + decode_client_info, api, req_data, request_id=request_id + ): yield chunk - return StreamingResponse(generate_stream(), - media_type="application/json") + return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback + exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - f" - {api} endpoint") + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @@ -259,13 +250,14 @@ async def healthcheck(): return { "status": "ok", "prefill_instances": len(app.state.prefill_clients), - "decode_instances": len(app.state.decode_clients) + "decode_instances": 
len(app.state.decode_clients), } -if __name__ == '__main__': +if __name__ == "__main__": global global_args global_args = parse_args() import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py new file mode 100644 index 0000000000000..0bb67b574fa14 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501 + SharedStorageConnectorMetadata, +) +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_initialized, + get_kv_transfer_group, +) +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin + +# Importing utils registers TestSharedStorageConnector with the factory +from .utils import create_vllm_config + + +def _make_empty_scheduler_output(): + return SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + kv_connector_metadata=SharedStorageConnectorMetadata(), + ) + + +def test_kv_connector_mixin_clears_metadata(): + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector" + vllm_config.kv_transfer_config.kv_role = "kv_both" + vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit" + + # Initialize the global connector instance + ensure_kv_transfer_initialized(vllm_config) + + try: + # Minimal scheduler output with empty metadata; mixin should still + # bind/clear metadata even if no loads happen + scheduler_output = _make_empty_scheduler_output() + + # Invoke the no-forward path which uses the mixin context manager + KVConnectorModelRunnerMixin.kv_connector_no_forward( + scheduler_output, vllm_config + ) + + # Verify clear_connector_metadata was called on the connector + connector = get_kv_transfer_group() + assert connector._connector_metadata is None + # Test connector wrapper records method calls + assert connector.call_record.get("bind_connector_metadata", 0) == 1 + assert connector.call_record.get("clear_connector_metadata", 0) == 1 + finally: + # Ensure we clean up the global connector between tests + KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown() diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py new file mode 100644 index 0000000000000..0902fbfe85f33 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + + +def 
_make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def scheduler(): + vllm_config = create_vllm_config() + return create_scheduler(vllm_config) + + +@pytest.mark.parametrize( + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", + [ + (100, 99, {0, 98}), + (100, 99, {50, 98}), + (100, 99, {98}), + ], +) +def test_async_load_failure( + scheduler: Scheduler, + num_prompt_blocks: int, + num_external_computed_blocks: int, + invalid_block_idxs: set[int], +): + assert num_prompt_blocks >= num_external_computed_blocks + + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size + + request1 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request1) + request2 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request2) + request3 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request3) + + # Mock KV connector method. + # req_id -> num_external_computed_tokens + req_num_new_matched_tokens = { + request1.request_id: num_external_computed_tokens, + request2.request_id: num_external_computed_tokens, + request3.request_id: num_external_computed_tokens, + } + + scheduler.connector = Mock() + scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True) + ) + scheduler.connector.take_events.return_value = () + + scheduler_output = scheduler.schedule() + + assert len(scheduler.waiting) == 3 + for request in scheduler.waiting: + assert request.num_computed_tokens == 0 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 + + # Simulate a failure in loading some of request2 blocks. 
+ (req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id) + invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs} + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving={request1.request_id, request3.request_id}, + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + scheduler.update_from_output(scheduler_output, model_runner_output) + + min_invalid_block_idx = min(invalid_block_idxs) + + assert len(scheduler.waiting) == 3 + for request in scheduler.waiting: + if request.request_id == request2.request_id: + assert request.num_computed_tokens == ( + min_invalid_block_idx * scheduler.block_size + ) + else: + assert request.num_computed_tokens == 0 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert scheduler.failed_recving_kv_req_ids == {request2.request_id} + assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 + + +@pytest.mark.parametrize( + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", + [ + (100, 99, {0, 98}), + (100, 99, {50, 98}), + (100, 99, {98}), + ], +) +def test_sync_load_failure( + scheduler: Scheduler, + num_prompt_blocks: int, + num_external_computed_blocks: int, + invalid_block_idxs: set[int], +): + assert num_prompt_blocks >= num_external_computed_blocks + + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size + + request1 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request1) + request2 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request2) + request3 = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request3) + + # Mock KV connector method. + # req_id -> num_external_computed_tokens + req_num_new_matched_tokens = { + request1.request_id: num_external_computed_tokens, + request2.request_id: num_external_computed_tokens, + request3.request_id: num_external_computed_tokens, + } + + scheduler.connector = Mock() + scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False) + ) + scheduler.connector.request_finished.return_value = (False, None) + scheduler.connector.take_events.return_value = () + + scheduler_output = scheduler.schedule() + + # req_id -> num_computed_tokens + expected_computed_tokens = { + request1.request_id: num_external_computed_tokens, + request2.request_id: num_external_computed_tokens, + request3.request_id: num_external_computed_tokens, + } + + assert len(scheduler.running) == 3 + assert len(scheduler_output.scheduled_new_reqs) == 3 + for request in scheduler_output.scheduled_new_reqs: + assert request.num_computed_tokens == expected_computed_tokens[request.req_id] + assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 + + # Simulate a failure in loading some of request2 blocks. 
+ req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0] + invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs} + model_runner_output = create_model_runner_output( + [request1, request2, request3], + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + scheduler.update_from_output(scheduler_output, model_runner_output) + + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == request2.request_id + assert scheduler.running[0].num_computed_tokens == ( + min(invalid_block_idxs) * scheduler.block_size + ) + assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 + assert scheduler.connector.request_finished.call_count == 2 + + +@pytest.mark.parametrize( + "num_prompt_blocks," + "num_external_computed_blocks," + "num_common_prefix_blocks," + "invalid_block_idxs", + [ + (100, 99, 50, {0, 49}), + (100, 99, 50, {25, 49}), + (100, 99, 50, {49}), + ], +) +def test_sync_load_failure_with_shared_blocks( + scheduler: Scheduler, + num_prompt_blocks: int, + num_external_computed_blocks: int, + num_common_prefix_blocks: int, + invalid_block_idxs: set[int], +): + assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks + + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size + common_prefix_len = num_common_prefix_blocks * scheduler.block_size + + request1 = create_request( + num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len + ) + scheduler.add_request(request=request1) + request2 = create_request( + num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len + ) + scheduler.add_request(request=request2) + + # Mock KV connector method. + # req_id -> num_external_computed_tokens + req_num_new_matched_tokens = { + request1.request_id: num_external_computed_tokens, + } + + scheduler.connector = Mock() + scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False) + ) + scheduler.connector.take_events.return_value = () + + scheduler_output = scheduler.schedule() + + # req_id -> num_computed_tokens + expected_computed_tokens = { + request1.request_id: num_external_computed_tokens, + request2.request_id: common_prefix_len, + } + + assert len(scheduler.running) == 2 + assert len(scheduler_output.scheduled_new_reqs) == 2 + for request in scheduler_output.scheduled_new_reqs: + assert request.num_computed_tokens == expected_computed_tokens[request.req_id] + assert scheduler.connector.get_num_new_matched_tokens.call_count == 2 + + # Simulate a failure in loading some of the shared blocks. 
+ req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs} + model_runner_output = create_model_runner_output( + [request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True + ) + + scheduler.update_from_output(scheduler_output, model_runner_output) + + # req_id -> num_computed_tokens + # all the common prefix blocks will be computed by request1 + expected_computed_tokens = { + request1.request_id: min(invalid_block_idxs) * scheduler.block_size, + request2.request_id: common_prefix_len, + } + + assert len(scheduler.running) == 2 + for request in scheduler.running: + assert ( + request.num_computed_tokens == expected_computed_tokens[request.request_id] + ) + assert scheduler.connector.get_num_new_matched_tokens.call_count == 2 + + +@pytest.mark.parametrize( + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", + [ + (100, 99, {0, 50, 98}), + (100, 99, {98, 50, 0}), + ], +) +def test_async_progressive_load_failure( + scheduler: Scheduler, + num_prompt_blocks: int, + num_external_computed_blocks: int, + invalid_block_idxs: set[int], +): + assert num_prompt_blocks >= num_external_computed_blocks + + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size + + request = create_request(num_tokens=num_prompt_tokens) + scheduler.add_request(request=request) + + # Mock KV connector method. + # req_id -> num_external_computed_tokens + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + scheduler.connector = Mock() + scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True) + ) + scheduler.connector.take_events.return_value = () + + scheduler_output = scheduler.schedule() + + assert len(scheduler.waiting) == 1 + assert scheduler.waiting.peek_request().request_id == request.request_id + assert request.num_computed_tokens == 0 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 + + min_invalid_block_idx = max(invalid_block_idxs) + 1 + # Simulate failures when progressively loading request blocks. 
+ for invalid_block_idx in invalid_block_idxs: + (req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id) + invalid_block_ids = {req_block_ids[invalid_block_idx]} + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving=set(), + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + scheduler.update_from_output(scheduler_output, model_runner_output) + + min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx) + + assert len(scheduler.waiting) == 1 + assert scheduler.waiting.peek_request().request_id == request.request_id + assert request.num_computed_tokens == ( + min_invalid_block_idx * scheduler.block_size + ) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert scheduler.failed_recving_kv_req_ids == {request.request_id} + assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index b1780d8a9af80..74ae3ca9a8633 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -52,29 +52,26 @@ def test_multi_shared_storage_connector_consistency(): kv_connector="MultiConnector", kv_role="kv_both", kv_connector_extra_config={ - "connectors": [{ - "kv_connector": - "TestSharedStorageConnector", - "kv_role": - "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path), - "name": "storage1", + "connectors": [ + { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + }, + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", }, - "kv_connector_module_path": - "tests.v1.kv_connector.unit.utils", - }, { - "kv_connector": - "TestSharedStorageConnector", - "kv_role": - "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path), - "name": "storage2", + { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + }, + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", }, - "kv_connector_module_path": - "tests.v1.kv_connector.unit.utils", - }] + ] }, ) @@ -93,14 +90,16 @@ def test_multi_shared_storage_connector_consistency(): local_subdirs = list(storage_1_path.iterdir()) external_subdirs = list(storage_2_path.iterdir()) - assert len( - local_subdirs - ) > 0, f"Local storage path {storage_1_path} is empty after generation." + assert len(local_subdirs) > 0, ( + f"Local storage path {storage_1_path} is empty after generation." + ) assert len(external_subdirs) > 0, ( - f"External storage path {storage_2_path} is empty after generation.") + f"External storage path {storage_2_path} is empty after generation." 
+ ) assert len(local_subdirs) == len(external_subdirs), ( f"Mismatch in number of cache entries: " - f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + f"Local={len(local_subdirs)}, External={len(external_subdirs)}" + ) # The subdirectories should correspond to the prompt hashes # Since prompts are the same, the hash directories should be the same name @@ -113,29 +112,39 @@ def test_multi_shared_storage_connector_consistency(): # Compare the contents of each corresponding cache directory for subdir_name in local_subdir_names: print(f"Comparing contents of cache directory: {subdir_name}") - assert _compare_directories(storage_1_path / subdir_name, - storage_2_path / subdir_name), \ - (f"Contents differ for cache directory '{subdir_name}' between " - f"{storage_1_path} and {storage_2_path}") + assert _compare_directories( + storage_1_path / subdir_name, storage_2_path / subdir_name + ), ( + f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}" + ) events = get_connector_events() # get_num_new_matched_tokens and update_state_after_alloc will be called # on each connector in turn. assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage1-WORKER"][:5] == [ - 'register_kv_caches', 'bind_connector_metadata', 'start_load_kv', - 'wait_for_layer_load', 'save_kv_layer' + "register_kv_caches", + "bind_connector_metadata", + "start_load_kv", + "wait_for_layer_load", + "save_kv_layer", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage2-WORKER"][:5] == [ - 'register_kv_caches', 'bind_connector_metadata', 'start_load_kv', - 'wait_for_layer_load', 'save_kv_layer' + "register_kv_caches", + "bind_connector_metadata", + "start_load_kv", + "wait_for_layer_load", + "save_kv_layer", ] # Reset prefix cache or else we'll just get the tokens back from there. @@ -151,12 +160,14 @@ def test_multi_shared_storage_connector_consistency(): # on that one but with zero blocks for others (first nonzero match is # chosen). assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[7] 96", + "build_connector_meta", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] # Delete storage1 connector state @@ -175,12 +186,14 @@ def test_multi_shared_storage_connector_consistency(): # a hit, so update_state_after_alloc will only be called with allocated # blocks for the second connector. 
assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[7] 96", + "build_connector_meta", ] # Clean up @@ -191,15 +204,14 @@ def test_multi_shared_storage_connector_consistency(): def get_connector_events() -> dict[str, list[str]]: # Read in connector events and reset the files. import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") connector_events = {} for fname in event_files: name = fname.split("connector_")[1].split("_events.log")[0] try: with open(fname, "r+") as f: - connector_events[name] = [ - line.strip() for line in f if line.strip() - ] + connector_events[name] = [line.strip() for line in f if line.strip()] f.truncate(0) except Exception as e: print(f"[ERROR] Could not read connector events for {name}: {e}") @@ -211,5 +223,5 @@ def test_engine_id_conflict(): configs = [KVTransferConfig() for _ in range(2)] ids = [config.engine_id for config in configs] assert ids[0] != ids[1], ( - "Engine IDs should be different for different configs. " - f"Got {ids}") + f"Engine IDs should be different for different configs. Got {ids}" + ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 040b44dc5d2ca..a1f53cb255630 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -18,22 +18,80 @@ import torch from vllm import LLM from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiKVConnectorStats, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( - KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, - NixlConnectorWorker) + KVConnectorRole, + NixlAgentMetadata, + NixlConnector, + NixlConnectorMetadata, + NixlConnectorWorker, + NixlKVConnectorStats, +) +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_shutdown, + has_kv_transfer_group, +) from vllm.forward_context import ForwardContext +from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import RequestStatus from .utils import create_request, create_scheduler, create_vllm_config +@pytest.fixture(scope="module", autouse=True) +def clear_kv_transfer(): + """ + The test cases in this file use `VLLM_ENABLE_V1_MULTIPROCESSING=0`, + causing the global variable `_KV_CONNECTOR_AGENT` + to be assigned but never deleted. + + Since the current pytest process does not terminate and instead + continues running tests from other files, + this global variable remains in memory and interferes + with test cases in other modules. + + So we use this fixture to ensure that the global variable + `_KV_CONNECTOR_AGENT` is properly cleaned up after each test. 
+ """ + yield + if has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + + +def get_default_xfer_telemetry( + xferDurationS: float = 1, + postDurationS: float = 1, + totalBytes: int = 1, + descCount: int = 1, +) -> dict: + class AttributeDict(dict): + __slots__ = () + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ # type: ignore[assignment] + + # We can't instantiate nixlXferTelemetry because it's read only and + # ray env does not have NIXL, so we must fake it + return AttributeDict( + xferDuration=xferDurationS * 1e6, # in us + postDuration=postDurationS * 1e6, # in us + totalBytes=totalBytes, + descCount=descCount, + ) + + class FakeNixlWrapper: """Mock implementation of NixlWrapper for testing. We don't inherit from nixl._api.nixl_agent because nixl may not be installed. - + Note: The complete source of this class is also used in the `_make_fake_nixl_pkg` function to create a fake nixl package for Ray workers. @@ -44,13 +102,15 @@ class FakeNixlWrapper: def __init__(self, agent_name: str, *args, **kwargs): self._cycles_before_xfer_done = 0 - self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( - lambda: 0) + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(lambda: 0) def get_reg_descs(self, caches_data, memory_type: str) -> list: return [str(uuid.uuid4()) for _ in caches_data] - def register_memory(self, descs) -> None: + def register_memory(self, descs, backends) -> None: + pass + + def deregister_memory(self, descs) -> None: pass def get_xfer_descs(self, blocks_data, memory_type: str) -> list: @@ -70,8 +130,7 @@ class FakeNixlWrapper: return {} def check_xfer_state(self, handle: int) -> str: - if self._check_xfer_state_cycles[ - handle] >= self._cycles_before_xfer_done: + if self._check_xfer_state_cycles[handle] >= self._cycles_before_xfer_done: return "DONE" self._check_xfer_state_cycles[handle] += 1 return "PROC" @@ -79,21 +138,32 @@ class FakeNixlWrapper: def release_xfer_handle(self, handle: int) -> None: pass + def release_dlist_handle(self, handle: int) -> None: + pass + + def remove_remote_agent(self, agent: str) -> None: + pass + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: pass - def make_prepped_xfer(self, - xfer_type: str, - local_xfer_side_handle: int, - local_block_descs_ids: list[int], - remote_xfer_side_handle: int, - remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None) -> int: + def make_prepped_xfer( + self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None, + ) -> int: return uuid.uuid4().int def transfer(self, handle: int) -> str: return "PROC" + def get_xfer_telemetry(self, handle: int) -> dict: + return get_default_xfer_telemetry() + ############################################################ # Follow are for changing the behavior during testing. ############################################################ @@ -106,7 +176,7 @@ class FakeNixlWrapper: def _make_fake_nixl_pkg(): """Context manager that creates a temporary package making `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. - + Automatically cleans up the temporary directory when done. 
""" with tempfile.TemporaryDirectory() as td: @@ -131,6 +201,11 @@ nixl_agent = FakeNixlWrapper with open(os.path.join(pkg_root, "__init__.py"), "w") as f: f.write(stub) + # Mock nixlXferTelemetry class + pkg_root2 = os.path.join(td, "nixl", "_bindings") + os.makedirs(pkg_root2, exist_ok=True) + with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: + f.write("class nixlXferTelemetry: pass") # touch parent package open(os.path.join(td, "nixl", "__init__.py"), "w").close() yield td @@ -147,10 +222,12 @@ def test_basic_interface(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) request_id = request.request_id scheduler.add_request(request) @@ -166,8 +243,11 @@ def test_basic_interface(): req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( - req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]): + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): assert block_id == block.block_id @@ -187,11 +267,13 @@ def test_prompt_less_than_block_size(): NUM_TOKENS = int(BLOCK_SIZE * 0.5) # Request will have 1 partial remote block. - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True, - num_remote_blocks=1) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=1, + ) scheduler.add_request(request) scheduler_output = scheduler.schedule() @@ -204,21 +286,22 @@ def test_prompt_less_than_block_size(): class FakeNixlConnectorWorker(NixlConnectorWorker): - REMOTE_ENGINE_ID = "remote_engine" def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, - expected_engine_id: str) -> dict[int, str]: + def _nixl_handshake( + self, host: str, port: int, remote_tp_size: int, expected_engine_id: str + ) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by # gpu_model_runner. Here we just hardcode some dummy values. - self.slot_size_bytes = 4096 - self.block_len = self.slot_size_bytes * self.block_size + slot_size_bytes = 4096 + self.slot_size_per_layer = [slot_size_bytes] + self.block_len_per_layer = [slot_size_bytes * self.block_size] self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks @@ -230,27 +313,29 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], num_blocks=1, - block_len=self.block_len, + block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. 
kv_cache_layout="HND", ), - remote_tp_size=remote_tp_size) + remote_tp_size=remote_tp_size, + ) return {0: remote_agent_name} class TestNixlHandshake: - @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_multi_xfer_one_engine( self, # dist_init is a fixture that initializes the distributed environment. - dist_init): + dist_init, + ): """Test case where multiple xfers are initiated to the same engine. - + This test triggers the connector to load remote KV for the same `request_id`. The transfer is not done immediately due to `set_cycles_before_xfer_done`, so there is a state where there are @@ -264,9 +349,9 @@ class TestNixlHandshake: # Test worker role in decode server. connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) - assert isinstance(connector.connector_worker.nixl_wrapper, - FakeNixlWrapper) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) num_xfers = 4 while True: @@ -277,21 +362,19 @@ class TestNixlHandshake: num_xfers -= 1 metadata.add_new_req( request_id=request_id, - local_block_ids=[ - num_xfers + 1, num_xfers + 2, num_xfers + 3 - ], + local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], kv_transfer_params={ - "remote_block_ids": - [num_xfers + 4, num_xfers + 5, num_xfers + 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": - "localhost", - "remote_port": - 1234, - "remote_tp_size": - 1, - }) + "remote_block_ids": [ + num_xfers + 4, + num_xfers + 5, + num_xfers + 6, + ], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) connector.bind_connector_metadata(metadata) # Mimic maybe_setup_kv_connector in gpu_model_runner. @@ -303,8 +386,9 @@ class TestNixlHandshake: _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) # Mimic get_finished_kv_transfers in gpu_model_runner. _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -316,20 +400,25 @@ class TestNixlHandshake: @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) - @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [ - (1, 1), - (2, 1), - (4, 2), - (4, 4), - ]) + FakeNixlWrapper, + ) + @pytest.mark.parametrize( + "decode_tp_size, prefill_tp_size", + [ + (1, 1), + (2, 1), + (4, 2), + (4, 4), + ], + ) def test_async_load_kv( - self, - # Fixture that initializes the distributed environment. - dist_init, - # Simulate consumer-producer TP sizes. - decode_tp_size, - prefill_tp_size): + self, + # Fixture that initializes the distributed environment. + dist_init, + # Simulate consumer-producer TP sizes. + decode_tp_size, + prefill_tp_size, + ): """Test that NixlConnector's start_load_kv should be non-blocking.""" vllm_config = create_vllm_config() @@ -338,18 +427,20 @@ class TestNixlHandshake: # Test worker role in decode server. 
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id) + vllm_config, connector.engine_id + ) metadata = NixlConnectorMetadata() - metadata.add_new_req(request_id="id", - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "remote_tp_size": prefill_tp_size, - }) + metadata.add_new_req( + request_id="id", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": prefill_tp_size, + }, + ) connector.bind_connector_metadata(metadata) timeout = 2.5 @@ -363,8 +454,9 @@ class TestNixlHandshake: _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) time.sleep(0.5) # backoff for the async handshake to complete. connector.bind_connector_metadata(NixlConnectorMetadata()) _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -374,11 +466,13 @@ class TestNixlHandshake: @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_concurrent_load_kv( self, # dist_init is a fixture that initializes the distributed environment. - dist_init): + dist_init, + ): """Test that multiple start_load_kv calls should occur concurrently.""" vllm_config = create_vllm_config() @@ -386,20 +480,22 @@ class TestNixlHandshake: # Test worker role in decode server. connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id) + vllm_config, connector.engine_id + ) metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): - metadata.add_new_req(request_id=f"id_{i}", - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "remote_tp_size": 1, - }) + metadata.add_new_req( + request_id=f"id_{i}", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) connector.bind_connector_metadata(metadata) timeout = 2.5 * total_reqs @@ -414,8 +510,9 @@ class TestNixlHandshake: _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) time.sleep(0.5) # backoff for the async handshake to complete. 
connector.bind_connector_metadata(NixlConnectorMetadata()) _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -427,7 +524,8 @@ class TestNixlHandshake: @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): """ Verify that adding a remote agent fails if kv_cache_layout differs. @@ -438,29 +536,30 @@ class TestNixlHandshake: # Mock TP world size to 2 to force heterogeneous TP when # remote_tp_size=1 with patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 - return_value=2): + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=2, + ): # Initialize connector and worker (with fake NIXL wrapper) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) worker = connector.connector_worker # Minimal local registration params used by add_remote_agent - worker.slot_size_bytes = 4096 - worker.block_len = worker.slot_size_bytes * worker.block_size + worker.slot_size_per_layer = [4096] + worker.block_len_per_layer = [4096 * worker.block_size] worker.num_blocks = 1 worker.dst_num_blocks[worker.engine_id] = worker.num_blocks # Metadata with different kv_cache_layout than local worker - mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \ - else "NHD" + mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD" meta = NixlAgentMetadata( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], num_blocks=1, - block_len=worker.block_len, + block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, ) @@ -475,14 +574,226 @@ class TestNixlHandshake: # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then # the rest of the tests. +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_kv_connector_stats(dist_init): + """Test that KV transfer stats are properly recorded and retrieved.""" + vllm_config = create_vllm_config() + + # Test worker role in decode server. 
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config, connector.engine_id, hand_shake_latency=0
+    )
+
+    # Verify that xfer_stats starts empty
+    initial_stats = connector.get_kv_connector_stats()
+    assert initial_stats is None
+
+    # Create transfer metadata
+    request_id = "test_req_for_stats"
+    metadata = NixlConnectorMetadata()
+    metadata.add_new_req(
+        request_id=request_id,
+        local_block_ids=[1, 2, 3],
+        kv_transfer_params={
+            "remote_block_ids": [4, 5, 6],
+            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+            "remote_host": "localhost",
+            "remote_port": 1234,
+            "remote_tp_size": 1,
+        },
+    )
+    connector.bind_connector_metadata(metadata)
+
+    # Start the transfer
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Verify stats are recorded after transfer is complete
+    max_iterations = 2
+    # Clear metadata before start_load_kv to prevent reprocessing same request
+    connector.bind_connector_metadata(NixlConnectorMetadata())
+    for _ in range(max_iterations):
+        # Need to call start_load_kv to process completed handshakes
+        connector.start_load_kv(dummy_ctx)
+        _, done_recving = connector.get_finished(finished_req_ids=set())
+        if len(done_recving) > 0 and request_id in done_recving:
+            break
+        time.sleep(0.1)  # Small delay to allow background handshake to complete
+    else:
+        pytest.fail("Transfer did not complete within expected iterations")
+
+    # Now check that stats were recorded
+    stats_after_transfer = connector.get_kv_connector_stats()
+    assert isinstance(stats_after_transfer, NixlKVConnectorStats)
+
+    # Verify stats values are recorded
+    assert not stats_after_transfer.is_empty()
+    assert stats_after_transfer.num_successful_transfers == 1
+
+    # Verify stats are reset after retrieval
+    stats_after_reset = connector.get_kv_connector_stats()
+    assert stats_after_reset is None
+
+
+def test_kv_connector_stats_aggregation():
+    """
+    Test KV transfer stats aggregation across TP ranks using
+    KVOutputAggregator (used by MultiprocExecutor).
+ """ + + # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing + # done in MultiprocExecutor.execute_model + aggregator = KVOutputAggregator(world_size=3) + + # Create stats for multiple workers with different transfer patterns + worker1_stats = NixlKVConnectorStats() + worker2_stats = NixlKVConnectorStats() + worker3_stats = NixlKVConnectorStats() + + # Record different transfers on each worker + # Worker 1: 2 transfers + stats = get_default_xfer_telemetry() + worker1_stats.record_transfer(stats) + worker1_stats.record_transfer(stats) + + # Worker 2: 1 transfer + worker2_stats.record_transfer(stats) + + # Worker 3: 3 transfers + stats = get_default_xfer_telemetry( + xferDurationS=2, postDurationS=2, totalBytes=2, descCount=2 + ) + worker3_stats.record_transfer(stats) + worker3_stats.record_transfer(stats) + worker3_stats.record_transfer(stats) + + # Create ModelRunnerOutput instances for each worker + worker_outputs = [] + for i, worker_stats in enumerate([worker1_stats, worker2_stats, worker3_stats]): + output = ModelRunnerOutput( + req_ids=[f"req_{i}"], + req_id_to_index={f"req_{i}": 0}, + sampled_token_ids=[[123]], # dummy token + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=set([f"req_{i}_send"]) + if i < 2 + else None, # Workers 0,1 finished sending + finished_recving=set([f"req_{i}_recv"]) + if i > 0 + else None, # Workers 1,2 finished receiving + kv_connector_stats=worker_stats, + ), + ) + worker_outputs.append(output) + + # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) + kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats + assert isinstance(kv_connector_stats, NixlKVConnectorStats) + # Number of total transfers across all workers. + assert kv_connector_stats.num_successful_transfers == 6 + # Logging proc, call reduce() to get CLI-friendly stats. + cli_stats = kv_connector_stats.reduce() + assert cli_stats["Avg xfer time (ms)"] == 1500.0 + assert cli_stats["Avg post time (ms)"] == 1500.0 + assert cli_stats["Avg number of descriptors"] == 1.5 + + +def test_multi_kv_connector_stats_aggregation(): + """ + Test MultiKVConnectorStats aggregation across TP ranks using + KVOutputAggregator (used by MultiprocExecutor). + """ + + aggregator = KVOutputAggregator(world_size=3) + + from dataclasses import dataclass + + # Mock a KVConnectorStats class for testing aggregation over connectors. 
+ @dataclass + class FooKVConnectorStats(KVConnectorStats): + def reset(self): + self.data = {"num_foo_transfers": 0} + + def record_transfer(self): + if "num_foo_transfers" not in self.data: + self.data["num_foo_transfers"] = 0 + self.data["num_foo_transfers"] += 1 + + def is_empty(self) -> bool: + return self.data["num_foo_transfers"] == 0 + + def aggregate(self, other: "FooKVConnectorStats") -> "FooKVConnectorStats": + if not other.is_empty(): + self.data["num_foo_transfers"] += other.data["num_foo_transfers"] + return self + + def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats: + data: dict[str, KVConnectorStats] = {} + if nixl_count > 0: + nixl_stats = NixlKVConnectorStats() + for _ in range(nixl_count): + nixl_stats.record_transfer(get_default_xfer_telemetry()) + data["NixlConnector"] = nixl_stats + if foo_count > 0: + foo_stats = FooKVConnectorStats() + for _ in range(foo_count): + foo_stats.record_transfer() + data["FooConnector"] = foo_stats + return MultiKVConnectorStats(data=data) + + # Create heterogeneous stats across 3 workers + worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo) + + worker_outputs: list[ModelRunnerOutput] = [] + for i, (nixl, foo) in enumerate(worker_patterns): + stats = make_multi_stats(nixl, foo) + output = ModelRunnerOutput( + req_ids=[f"req_{i}"], + req_id_to_index={f"req_{i}": 0}, + sampled_token_ids=[[123]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=set([f"req_{i}_send"]) if i < 2 else None, + finished_recving=set([f"req_{i}_recv"]) if i > 0 else None, + kv_connector_stats=stats, + ), + ) + worker_outputs.append(output) + + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) + kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats + assert isinstance(kv_connector_stats, MultiKVConnectorStats) + + # Validate per-connector totals across workers + assert isinstance(kv_connector_stats["NixlConnector"], NixlKVConnectorStats) + assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5 + assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats) + assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 + + @pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, +) def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): """ Test lifecycle of an aborted Remote Prefill request hitting the timeout. 
- -----> P + -----> P | {process request} <-/--- | {result is NOT delivered, eg proxy is down} | @@ -513,6 +824,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): "working_dir": working_dir, # ship fake nixl package "env_vars": { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + # TODO: for ray to carry over, remove once we set + "NIXL_TELEMETRY_ENABLE": "1", }, } ray.init(runtime_env=runtime_env) @@ -537,39 +850,38 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params = SamplingParams( temperature=0.0, max_tokens=1, - extra_args={"kv_transfer_params": remote_prefill_opts}) + extra_args={"kv_transfer_params": remote_prefill_opts}, + ) scheduler = llm.llm_engine.engine_core.engine_core.scheduler req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks + 0 + ].req_to_blocks padding = "Just making this request a little longer so that we're sure " "we're not hitting the small-request lower bound beneath which we don't " "actually trigger the whole kv transfer, but rather just recompute the " "blocks on D." - _ = llm.generate([f"What is the capital of Japan? {padding}"], - sampling_params) + _ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params) # Request finished but not freed - assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks + assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks # Some other request, 0 still not freed - _ = llm.generate([f"What is the capital of Italy? {padding}"], - sampling_params) - assert '0' in req_to_blocks - assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks + _ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params) + assert "0" in req_to_blocks + assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks # Wait for timeout and trigger another scheduler loop time.sleep(timeout) - _ = llm.generate([f"What is the capital of France? {padding}"], - sampling_params) + _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) # Request-0 times out and is cleared! - assert '0' not in req_to_blocks + assert "0" not in req_to_blocks def test_register_kv_caches(dist_init): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. - + This test verifies: 1. 
nixl_wrapper.get_reg_descs() is called with caches_data containing tensor metadata @@ -580,10 +892,9 @@ def test_register_kv_caches(dist_init): vllm_config = create_vllm_config() # Create test kv cache tensors using proper backend shape - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, - block_size=16, - num_kv_heads=4, - head_size=64) + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) kv_caches = { @@ -593,21 +904,30 @@ def test_register_kv_caches(dist_init): } # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size( - ) * shared_tensor[0].numel() + expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() expected_base_addrs = [ - shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), ] - with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 - + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" + ) as mock_nixl_wrapper, + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" + ), + ): # noqa: E501 # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) # Get the mock instance mock_wrapper_instance = mock_nixl_wrapper.return_value @@ -623,12 +943,13 @@ def test_register_kv_caches(dist_init): for i, cache_entry in enumerate(caches_data): base_addr, size, _tp_rank, _ = cache_entry - assert size == expected_tensor_size, \ - f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ - f"got {size}" - assert base_addr == expected_base_addrs[i], \ - f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + assert size == expected_tensor_size, ( + f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}" + ) + assert base_addr == expected_base_addrs[i], ( + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " f"got {base_addr}" + ) # Verify get_xfer_descs was called with blocks_data assert mock_wrapper_instance.get_xfer_descs.called @@ -636,13 +957,190 @@ def test_register_kv_caches(dist_init): # Validate blocks_data structure and size expected_blocks_count = 8 - assert len(blocks_data) == expected_blocks_count, \ - f"Expected {expected_blocks_count} blocks, " \ - f"got {len(blocks_data)}" + assert len(blocks_data) == expected_blocks_count, ( + f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" + ) expected_block_len = expected_tensor_size // 2 for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry - assert block_len == expected_block_len, \ - f"Block entry {i}: Expected block len {expected_block_len}, " \ + assert block_len == 
expected_block_len, ( + f"Block entry {i}: Expected block len {expected_block_len}, " f"got {block_len}" + ) + + +class FakePlatform(Platform): + device_type: str = "oot" + + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. + """ + return {"oot": ("oot",)} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return "VRAM" + + +@pytest.mark.parametrize( + "kv_buffer_device, nixl_memory_type", + [ + ("oot", "VRAM"), + ], +) +def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type): + """ + Test that register_kv_caches() passes the correct memory types from the + config to the nixl_wrapper. + """ + vllm_config = create_vllm_config() + # Override the default memory types in the config + vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + _NIXL_SUPPORTED_DEVICE, + ) + + _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) + + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", + FakePlatform, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", + _NIXL_SUPPORTED_DEVICE, + ), + ): # noqa: E501 + # Create connector and replace its worker with a fake one for isolation + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + + # Verify get_reg_descs was called with the correct memory_type + assert connector.connector_worker.kv_buffer_device == kv_buffer_device + assert connector.connector_worker.nixl_memory_type == nixl_memory_type + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_shutdown_cleans_up_resources(dist_init): + """Test that shutdown() properly cleans up all resources.""" + vllm_config = create_vllm_config() + + worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) + nixl_wrapper = worker.nixl_wrapper + + with ( + patch.object(worker, "_handshake_initiation_executor") as mock_exec, + patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, + patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, + patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, + patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, + patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, + ): + worker._recving_transfers = {"req1": [(123, time.perf_counter())]} + worker.src_xfer_side_handle = 456 + worker.dst_xfer_side_handles = {"engine1": 789} + worker._remote_agents = {"engine1": {0: "agent1"}} + worker._registered_descs = ["desc1", "desc2"] + + worker.shutdown() + + # Test idempotency + worker.shutdown() + worker.shutdown() + + mock_exec.shutdown.assert_called_with(wait=False) + mock_listener.join.assert_called_once_with(timeout=0) + + mock_rel_xfer.assert_called_once_with(123) + assert mock_rel_dlist.call_count == 2 + mock_rel_dlist.assert_any_call(456) # src handle + mock_rel_dlist.assert_any_call(789) # dst handle + 
mock_rem_agent.assert_called_once_with("agent1")
+    assert mock_dereg.call_count == 2
+    mock_dereg.assert_any_call("desc1")
+    mock_dereg.assert_any_call("desc2")
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper,
+)
+def test_aborted_request_removed_from_worker_in_batch(dist_init):
+    """
+    Create and schedule a request so that P adds it to in-batch tracking via
+    the real scheduler, then simulate an abort (request not in next scheduler
+    iteration) and verify the worker no longer tracks it as in-batch.
+    """
+    vllm_config = create_vllm_config()
+
+    scheduler = create_scheduler(vllm_config)
+    # KVConnector Worker in P
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config, connector.engine_id, hand_shake_latency=0
+    )
+
+    # Create a request that triggers do_remote_decode so that
+    # the scheduler adds it to reqs_in_batch
+    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
+    scheduler.add_request(req)
+
+    # First scheduling pass - examine build_connector_meta output
+    sched_out = scheduler.schedule()
+    kv_meta = sched_out.kv_connector_metadata
+    assert kv_meta is not None
+    assert isinstance(kv_meta, NixlConnectorMetadata)
+    assert req.request_id in kv_meta.reqs_in_batch
+
+    #### Model Runner start ####
+    # Bind scheduler-produced metadata and start worker processing.
+    connector.bind_connector_metadata(kv_meta)
+
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Ensure it was tracked by the worker
+    assert req.request_id in connector.connector_worker._reqs_to_process
+
+    #### Model Runner end ####
+
+    # Abort request - request_finished call in connector scheduler
+    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
+    # Second scheduling pass - build metadata with aborted request
+    sched_out2 = scheduler.schedule()
+    kv_meta2 = sched_out2.kv_connector_metadata
+    assert kv_meta2 is not None
+    assert isinstance(kv_meta2, NixlConnectorMetadata)
+    assert req.request_id not in kv_meta2.reqs_in_batch
+
+    # Bind empty/abort metadata and run worker step
+    #### Model Runner start ####
+    connector.bind_connector_metadata(kv_meta2)
+    connector.start_load_kv(dummy_ctx)
+
+    # After abort, the worker should not keep tracking it as "in-batch"
+    assert req.request_id not in connector.connector_worker._reqs_to_process
+    #### Model Runner end ####
diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
new file mode 100644
index 0000000000000..46a5c097094eb
--- /dev/null
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -0,0 +1,530 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+from collections.abc import Iterable, Iterator
+from dataclasses import dataclass
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+
+from vllm import SamplingParams
+from vllm.config import KVTransferConfig, VllmConfig
+from vllm.distributed.kv_events import BlockRemoved, BlockStored
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
+from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
+    OffloadingConnector,
+    OffloadingConnectorMetadata,
+)
+from vllm.forward_context import ForwardContext
+from vllm.utils
import sha256 +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + get_request_block_hasher, + init_none_hash, +) +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) +from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + TransferResult, + TransferSpec, +) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.request import Request + +from .utils import ( + EOS_TOKEN_ID, + create_model_runner_output, + create_scheduler, + create_vllm_config, +) + + +class MockLoadStoreSpec(LoadStoreSpec): + def __init__(self, block_hashes: Iterable[BlockHash]): + self.block_hashes: list[BlockHash] = list(block_hashes) + + @staticmethod + def medium() -> str: + return "Mock" + + def __repr__(self) -> str: + return repr(self.block_hashes) + + +class MockOffloadingHandler(OffloadingHandler): + def __init__(self): + self.completed_transfers: list[TransferResult] = [] + self.completed_specs: list[TransferSpec] = [] + + def get_finished(self) -> list[TransferResult]: + finished = self.completed_transfers + self.completed_transfers = [] + return finished + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + self.completed_specs.append(spec) + self.completed_transfers.append((job_id, True)) + return True + + +class MockOffloadingSpec(OffloadingSpec): + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + self.manager = MagicMock(spec=OffloadingManager) + self.manager.lookup.return_value = 0 + self.manager.prepare_load = lambda block_hashes: ( + MockLoadStoreSpec(block_hashes) + ) + self.handler = MockOffloadingHandler() + + def get_manager(self) -> OffloadingManager: + return self.manager + + def get_handlers( + self, _ + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: + yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler + yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler + + def get_completed_transfers(self) -> list[TransferSpec]: + specs = self.handler.completed_specs + self.handler.completed_specs = [] + return specs + + +@dataclass +class TransferSummary: + gpu_block_indices: list[int] + offload_addresses: list[Any] + + +class RequestRunner: + def __init__( + self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int + ): + self.offloaded_block_size: int = offloaded_block_size + self.gpu_block_size: int = gpu_block_size + self.num_gpu_blocks: int = num_gpu_blocks + + self.req_id: int = -1 + + vllm_config = create_vllm_config( + block_size=gpu_block_size, max_num_batched_tokens=1000 + ) + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="OffloadingConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "spec_name": "MockOffloadingSpec", + "spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector", # noqa: E501 + "block_size": offloaded_block_size, + }, + ) + + self.scheduler: Scheduler = create_scheduler( + vllm_config, num_blocks=num_gpu_blocks + ) + self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER) + + # register worker kv_caches to enable OffloadingWorker creations + self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)}) + + # extract connector of scheduler + scheduler_connector = self.scheduler.connector + assert 
scheduler_connector is not None + assert isinstance(scheduler_connector, OffloadingConnector) + self.scheduler_connector: OffloadingConnector = scheduler_connector + + # extract mocked OffloadingManager of scheduler connector + connector_scheduler = scheduler_connector.connector_scheduler + assert connector_scheduler is not None + manager = connector_scheduler.manager + assert isinstance(manager, MagicMock) + self.manager: MagicMock = manager + + assert connector_scheduler.gpu_block_size == gpu_block_size + assert connector_scheduler.offloaded_block_size == offloaded_block_size + + # extract OffloadingSpec of worker_connector + connector_worker = self.worker_connector.connector_worker + assert connector_worker is not None + offloading_spec = connector_worker.spec + assert isinstance(offloading_spec, MockOffloadingSpec) + self.offloading_spec: MockOffloadingSpec = offloading_spec + + # mapping (offloading address) -> gpu_block_index + self.offloaded: dict[Any, int] = {} + + self.pending_loads_count: int = 0 + self.pending_stores_count: int = 0 + + self.completed_loads: list[TransferSummary] = [] + self.completed_stores: list[TransferSummary] = [] + + # maps {block_id: block_offset} + self.gpu_block_index: dict[int, int] = {} + + init_none_hash(sha256) + self._block_hasher = get_request_block_hasher(gpu_block_size, sha256) + + self._dummy_ctx: ForwardContext = ForwardContext( + no_compile_layers={}, attn_metadata={}, virtual_engine=0 + ) + + def new_request(self, token_ids: list[int]): + assert not self.scheduler.requests + self.req_id += 1 + + req = Request( + request_id=str(self.req_id), + prompt_token_ids=token_ids, + sampling_params=SamplingParams(max_tokens=1000), + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + block_hasher=self._block_hasher, + ) + + self.scheduler.add_request(req) + + def _wait_for_transfers(self): + block_size_factor = self.offloaded_block_size // self.gpu_block_size + + while self.pending_loads_count or self.pending_stores_count: + for transfer_spec in self.offloading_spec.get_completed_transfers(): + src_spec, dst_spec = transfer_spec + + if isinstance(src_spec, GPULoadStoreSpec): + store = True + gpu_spec = src_spec + offload_spec = dst_spec + else: + store = False + gpu_spec = dst_spec + offload_spec = src_spec + + assert isinstance(offload_spec, MockLoadStoreSpec) + assert isinstance(gpu_spec, GPULoadStoreSpec) + + gpu_block_indices: list[int] = [] + for block_id in gpu_spec.block_ids: + gpu_block_indices.append(self.gpu_block_index[block_id.item()]) + + # list of (block_hash, sub_block_offset) + offload_addresses: list[Any] = [] + for block_hash in offload_spec.block_hashes: + for sub_block_idx in range(block_size_factor): + offload_addresses.append((block_hash, sub_block_idx)) + + if store: + assert len(gpu_block_indices) == len(offload_addresses) + + self.completed_stores.append( + TransferSummary(gpu_block_indices, offload_addresses) + ) + self.pending_stores_count -= 1 + else: + remainder_sub_block_count = len(offload_addresses) - len( + gpu_block_indices + ) + assert remainder_sub_block_count >= 0 + assert remainder_sub_block_count < block_size_factor + offload_addresses = offload_addresses[remainder_sub_block_count:] + + self.completed_loads.append( + TransferSummary(gpu_block_indices, offload_addresses) + ) + self.pending_loads_count -= 1 + + def _update_gpu_block_idx(self): + for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks.values(): + for block_idx, block in enumerate(blocks): + 
self.gpu_block_index[block.block_id] = block_idx + + def _run(self, decoded_tokens: list[int]): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + """ + + tokens_iter = iter(decoded_tokens) + token_id = next(tokens_iter, None) + while token_id is not None: + assert self.scheduler.requests + + scheduler_output = self.scheduler.schedule() + self._update_gpu_block_idx() + + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) + + self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) + self.pending_stores_count += len(kv_connector_metadata.reqs_to_store) + + self.worker_connector.bind_connector_metadata(kv_connector_metadata) + self.worker_connector.start_load_kv(self._dummy_ctx) + + if scheduler_output.total_num_scheduled_tokens > 0: + self.worker_connector.wait_for_save() + + finished_sending, finished_recving = self.worker_connector.get_finished( + scheduler_output.finished_req_ids + ) + + self.worker_connector.clear_connector_metadata() + + model_runner_output = create_model_runner_output( + reqs=self.scheduler.running, + finished_sending=finished_sending, + finished_recving=finished_recving, + token_id=token_id, + ) + + if self.scheduler.running: + token_id = next(tokens_iter, None) + + self.scheduler.update_from_output(scheduler_output, model_runner_output) + + self._wait_for_transfers() + + # run one more step to update finished stored + if EOS_TOKEN_ID in decoded_tokens: + assert not self.scheduler.running + + while self.scheduler.requests: + scheduler_output = self.scheduler.schedule() + + finished_sending, finished_recving = self.worker_connector.get_finished( + scheduler_output.finished_req_ids + ) + + assert not finished_recving + + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending + ) + + self.scheduler.update_from_output(scheduler_output, model_runner_output) + + def run( + self, + decoded_tokens: list[int], + expected_stored_gpu_block_indexes: tuple[int, ...] = (), + expected_loaded_gpu_block_indexes: tuple[int, ...] = (), + ): + """ + Runs multiple engine (scheduler + worker) steps. + Assumes a single request is running. + + Args: + decoded_tokens: the tokens to yield at each step. + expected_stored_gpu_block_indexes: GPU block indexes + that are expected to be written during the run. + expected_loaded_gpu_block_indexes: GPU block indexes + that are expected to be loaded during the run. 
+ """ + + self.manager.reset_mock() + self._run(decoded_tokens) + + loaded_gpu_block_indexes: set[int] = set() + for transfer in self.completed_loads: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses + ): + loaded_gpu_block_indexes.add(gpu_block_idx) + assert gpu_block_idx == self.offloaded[offloaded_address] + + assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes + self.completed_loads.clear() + + stored_gpu_block_indexes: set[int] = set() + for transfer in self.completed_stores: + for gpu_block_idx, offloaded_address in zip( + transfer.gpu_block_indices, transfer.offload_addresses + ): + stored_gpu_block_indexes.add(gpu_block_idx) + self.offloaded[offloaded_address] = gpu_block_idx + + assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes + self.completed_stores.clear() + + +@pytest.fixture +def request_runner(): + runners = [] + + def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks): + runner = RequestRunner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) + runners.append(runner) + return runner + + yield runner_factory # pass factory to the test + + +def generate_store_output(block_hashes: Iterable[BlockHash]): + block_hashes = list(block_hashes) + return PrepareStoreOutput( + block_hashes_to_store=list(block_hashes), + store_spec=MockLoadStoreSpec(block_hashes), + block_hashes_evicted=[], + ) + + +def test_offloading_connector(request_runner): + offloaded_block_size = 12 + gpu_block_size = 4 + num_gpu_blocks = 100 + block_size_factor = offloaded_block_size // gpu_block_size + + runner = request_runner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) + + # 3 blocks, store just the middle block (skip first and last) + # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8] + runner.new_request(token_ids=[0] * offloaded_block_size * 3) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output(list(block_hashes)[1:2]) + ) + runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5)) + + # add block missing 1 token -> no offload + runner.run(decoded_tokens=[0] * (offloaded_block_size - 1)) + runner.manager.prepare_store.assert_not_called() + + # +1 token -> single block, fail prepare_store + runner.manager.prepare_store.side_effect = lambda block_hashes: None + runner.run(decoded_tokens=[0]) + runner.manager.prepare_store.assert_called() + + # 1 more block, now set block_hashes_to_store = [] + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.run(decoded_tokens=[0] * offloaded_block_size) + + # 1 more block, now check touch was called with all 6 blocks + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output(block_hashes) + ) + runner.run( + decoded_tokens=[0] * offloaded_block_size, + expected_stored_gpu_block_indexes=(15, 16, 17), + ) + runner.manager.touch.assert_called() + block_hashes1 = list(runner.manager.touch.call_args.args[0]) + assert len(block_hashes1) == 6 + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # create a new request differing only on the last token + runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1]) + runner.run( + decoded_tokens=[0], + expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)), + ) + 
runner.manager.touch.assert_called() + block_hashes2 = list(runner.manager.touch.call_args.args[0]) + assert len(block_hashes2) == 6 + + # verify hashes are the same, except for the last block + assert block_hashes1[:5] == block_hashes2[:5] + assert block_hashes1[5] != block_hashes2[5] + + # terminate request + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + + # full_block_tokens - num_computed_tokens < offloaded_block_size + runner.new_request( + token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size) + ) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_not_called() + + # single block lookup with no hits + runner.new_request(token_ids=[1] * offloaded_block_size) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.run(decoded_tokens=[EOS_TOKEN_ID]) + runner.manager.lookup.assert_called() + assert len(list(runner.manager.lookup.call_args.args[0])) == 1 + + # single block lookup with a hit + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.manager.lookup.return_value = 1 + runner.run( + decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2) + ) + + # single block lookup with a hit in a middle block + runner.new_request( + token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size + ) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.manager.lookup.return_value = 1 + runner.run( + decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5) + ) + + # test take_events + def to_hashes(int_hashes: list[int]) -> list[BlockHash]: + return [BlockHash(str(i).encode()) for i in int_hashes] + + def take_events() -> Iterable[OffloadingEvent]: + yield OffloadingEvent( + block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False + ) + yield OffloadingEvent( + block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True + ) + + runner.manager.take_events.side_effect = take_events + events = list(runner.scheduler_connector.take_events()) + assert len(events) == 2 + event = events[0] + assert isinstance(event, BlockStored) + assert event.block_hashes == to_hashes([1, 2, 3]) + assert event.block_size == 16 + assert event.medium == "A" + assert event.token_ids == [] + assert event.parent_block_hash is None + assert event.lora_id is None + event = events[1] + assert isinstance(event, BlockRemoved) + assert event.block_hashes == to_hashes([4, 5, 6]) + assert event.medium == "B" diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py index 5d2b27a9eb4da..d05cbe1a2fd46 100644 --- a/tests/v1/kv_connector/unit/test_output_aggreagator.py +++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py @@ -3,34 +3,41 @@ from concurrent.futures import Future from typing import Optional +import pytest + from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +pytestmark = pytest.mark.cpu_test + class DummyModelRunnerOutput(ModelRunnerOutput): - - def __init__(self, - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None): + def __init__( + self, + 
finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None, + invalid_block_ids: Optional[set[int]] = None, + ): self.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, + invalid_block_ids=invalid_block_ids or set(), ) def __repr__(self): return ( f"DummyModelRunnerOutput(" f"finished_sending={self.kv_connector_output.finished_sending}," - f"finished_recving={self.kv_connector_output.finished_recving})") + f"finished_recving={self.kv_connector_output.finished_recving})" + f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})" + ) def test_aggregate_workers_output(): aggregator = KVOutputAggregator(world_size=2) - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - output2 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) + output1 = DummyModelRunnerOutput() + output2 = DummyModelRunnerOutput() aggregated = aggregator.aggregate([output1, output2]) @@ -38,30 +45,44 @@ def test_aggregate_workers_output(): aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None assert aggregated.finished_recving is None + assert not aggregated.invalid_block_ids - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving=None) - - aggregated = aggregator.aggregate([output1, output2]) - - assert aggregated is output1 - aggregated = aggregated.kv_connector_output - assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, finished_recving={"req2"} + ) + output2 = DummyModelRunnerOutput(invalid_block_ids={1}) aggregated = aggregator.aggregate([output1, output2]) assert aggregated is output1 aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} + assert aggregated.finished_recving is None + assert aggregated.invalid_block_ids == {1} + + output1 = DummyModelRunnerOutput(invalid_block_ids={2}) + output2 = DummyModelRunnerOutput(finished_sending={"req1"}) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + aggregated = aggregated.kv_connector_output + assert aggregated.finished_sending == {"req1"} + assert aggregated.finished_recving is None + assert aggregated.invalid_block_ids == {2} + + output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4}) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, invalid_block_ids={4, 5} + ) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + aggregated = aggregated.kv_connector_output + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {"req2"} + assert aggregated.invalid_block_ids == {3, 4, 5} def test_async_aggregate_workers_output(): @@ -71,10 +92,8 @@ def test_async_aggregate_workers_output(): future2: Future[DummyModelRunnerOutput] = Future() result_future = aggregator.async_aggregate([future1, future2]) - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - output2 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) + output1 = 
DummyModelRunnerOutput() + output2 = DummyModelRunnerOutput() future1.set_result(output1) future2.set_result(output2) @@ -84,33 +103,16 @@ def test_async_aggregate_workers_output(): aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None assert aggregated.finished_recving is None + assert not aggregated.invalid_block_ids future1 = Future() future2 = Future() result_future = aggregator.async_aggregate([future1, future2]) - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving=None) - future1.set_result(output1) - future2.set_result(output2) - - assert result_future.done() - aggregated = result_future.result() - assert aggregated is output1 - aggregated = aggregated.kv_connector_output - assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None - - future1 = Future() - future2 = Future() - result_future = aggregator.async_aggregate([future1, future2]) - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, finished_recving={"req2"} + ) + output2 = DummyModelRunnerOutput(invalid_block_ids={1}) future1.set_result(output1) future2.set_result(output2) @@ -119,4 +121,41 @@ def test_async_aggregate_workers_output(): assert aggregated is output1 aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} + assert aggregated.finished_recving is None + assert aggregated.invalid_block_ids == {1} + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(invalid_block_ids={2}) + output2 = DummyModelRunnerOutput(finished_sending={"req1"}) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + aggregated = aggregated.kv_connector_output + assert aggregated.finished_sending == {"req1"} + assert aggregated.finished_recving is None + assert aggregated.invalid_block_ids == {2} + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4}) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, invalid_block_ids={4, 5} + ) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + aggregated = aggregated.kv_connector_output + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {"req2"} + assert aggregated.invalid_block_ids == {3, 4, 5} diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index d8c56ac42f718..b2ec2ddfb64da 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -2,11 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import pytest + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import FinishReason, RequestStatus -from .utils import (assert_scheduler_empty, 
create_model_runner_output, - create_request, create_scheduler, create_vllm_config) +from .utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test def test_basic_lifecycle(): @@ -20,11 +29,13 @@ def test_basic_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - max_tokens=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -32,6 +43,7 @@ def test_basic_lifecycle(): # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -39,10 +51,11 @@ def test_basic_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) - # Ensure the request is finished after 1 tokens. + # Ensure the request is finished after 1 token. assert request.is_finished() assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED output = engine_core_outputs[0].outputs[0] @@ -55,14 +68,17 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. + assert len(scheduler.requests) == 1 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 # STEP (2): Send Finished to PB. # (2a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 1 assert request_id in scheduler_output.finished_req_ids @@ -79,6 +95,7 @@ def test_basic_lifecycle(): # STEP (3): Finished sending. # (3a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 @@ -88,7 +105,8 @@ def test_basic_lifecycle(): # (3b): execute_model() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=[request_id]) + finished_sending={request_id} + ) # (3c): update_from_output() scheduler.update_from_output(scheduler_output, model_runner_output) @@ -106,17 +124,20 @@ def test_short_prompt_lifecycle(): # Not enough tokens for full block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_TOKENS = BLOCK_SIZE // 2 - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - max_tokens=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request) # STEP (1): Prefill. 
# (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -128,20 +149,21 @@ def test_short_prompt_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params - assert (len(kv_transfer_params["remote_block_ids"]) == 1) + assert len(kv_transfer_params["remote_block_ids"]) == 1 # Confirm we do not have any memory leaks after req lifecycle. # We need to mark sending finish to clear data for persistent batch. scheduler_output = scheduler.schedule() # Use create_model_runner_output to pass kv_connector_output along model_runner_output = create_model_runner_output( - reqs=[request], finished_sending=[request.request_id]) + reqs=[request], finished_sending={request.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) def test_prefix_cache_lifecycle(): - """Test that remote decode params still works with a prefix cache hit.""" + """Test that remote decode params still work with a prefix cache hit.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -151,16 +173,17 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS + ) scheduler.add_request(request_normal) scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) - scheduler.schedule() + scheduler_output = scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) ##################### @@ -170,10 +193,12 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS -= 1 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request_remote = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() @@ -183,14 +208,55 @@ def test_prefix_cache_lifecycle(): # Ensure we send all block ids, including the partial blocks, # even if there is a cache hit. - assert (len( - kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + - 1)) + assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1) # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=[request_remote.request_id]) + finished_sending={request_remote.request_id} + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert_scheduler_empty(scheduler) + + +def test_abort_during_kv_transfer(): + """Test aborting request does not release blocks for remote decode.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. 
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) + + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + # Request removed from PB but blocks should not be freed. + assert len(scheduler.requests) == 1 + + # Abort the request, and check the blocks are still not freed + scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED) + assert len(scheduler.requests) == 1 + + # Simulate a finished sending notification + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=[request.request_id] + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 21fec5344255c..b9588ebcd2110 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -2,11 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import pytest + from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import FinishReason, RequestStatus -from .utils import (assert_scheduler_empty, create_model_runner_output, - create_request, create_scheduler, create_vllm_config) +from .utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test def test_basic_lifecycle(): @@ -20,12 +29,15 @@ def test_basic_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) START_FREE_BLOCK_QUEUE_SIZE = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -44,16 +56,16 @@ def test_basic_lifecycle(): # Req waiting for KVs with no computed/scheduled toks ... assert len(scheduler.waiting) == 1 assert request in scheduler.waiting - assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) - assert (request.num_computed_tokens == 0) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 # ... but should have (uncached) blocks allocated to it. 
block_pool = scheduler.kv_cache_manager.block_pool - assert (block_pool.free_block_queue.num_free_blocks - < START_FREE_BLOCK_QUEUE_SIZE) + assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE assert len(block_pool.cached_block_hash_to_block) == 0 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block._block_hash is None @@ -61,8 +73,9 @@ def test_basic_lifecycle(): model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) assert not engine_core_outputs or not engine_core_outputs[0].outputs # STEP (2): @@ -74,13 +87,15 @@ def test_basic_lifecycle(): # (2b): forward(): request finishes recv. model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving=[request_id]) + finished_recving={request_id} + ) # (2c): update_from_output(): - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # STEP (3): # (3a): schedule(): this should actually schedule. @@ -90,10 +105,11 @@ def test_basic_lifecycle(): # Confirm the block are actually allocated. num_hashed_blocks = 0 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) + num_hashed_blocks += 1 if block._block_hash is not None else 0 assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS # Confirm the rest of the prompt is scheduled in this step. @@ -101,7 +117,7 @@ def test_basic_lifecycle(): num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] num_computed_tokens = scheduled_req.num_computed_tokens total_prompt_tokens = len(scheduled_req.prompt_token_ids) - assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens # (3b): execute_model() model_runner_output = create_model_runner_output([request]) @@ -111,8 +127,9 @@ def test_basic_lifecycle(): # Step (4): Hit EOS. 
scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output([request], use_eos=True) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) scheduler.schedule() outputs = engine_core_outputs[0].outputs @@ -133,10 +150,12 @@ def test_interleaved_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request_remote = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) request_local_a = create_request( request_id=2, block_size=BLOCK_SIZE, @@ -165,8 +184,7 @@ def test_interleaved_lifecycle(): assert len(scheduler_output.scheduled_new_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 - model_runner_output = create_model_runner_output( - [request_local_a, request_local_b]) + model_runner_output = create_model_runner_output([request_local_a, request_local_b]) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 3: continue running, KVs not arrived yet. @@ -177,7 +195,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - reqs=[request_local_a, request_local_b]) + reqs=[request_local_a, request_local_b] + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 @@ -192,8 +211,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - [request_local_a, request_local_b], - finished_recving=[request_remote.request_id]) + [request_local_a, request_local_b], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 5: RECVed KVs are sent to ModelRunner. @@ -204,7 +223,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - [request_local_a, request_local_b, request_remote]) + [request_local_a, request_local_b, request_remote] + ) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 6: Hit EOS and free. @@ -242,16 +262,16 @@ def test_no_spurious_prefix_caching(): request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, + common_prefix_len=NUM_TOKENS, do_remote_prefill=True, - use_all_1s_for_prompt_tokens=True, ) request_local = create_request( request_id=2, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, + common_prefix_len=NUM_TOKENS, do_remote_prefill=False, - use_all_1s_for_prompt_tokens=True, ) # Schedule the remote prefill request. This should not @@ -269,15 +289,17 @@ def test_no_spurious_prefix_caching(): assert len(scheduler.waiting) == 1 local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_local.request_id] + 0 + ].req_to_blocks[request_local.request_id] remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_remote.request_id] + 0 + ].req_to_blocks[request_remote.request_id] # Local should have cached blocks (but not all due to preallocate). 
num_hashed_blocks = 0 for block in local_blocks: assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) + num_hashed_blocks += 1 if block._block_hash is not None else 0 assert num_hashed_blocks > 0 # Remote blocks should not be cached. @@ -297,10 +319,12 @@ def test_full_block_prompt(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -308,8 +332,11 @@ def test_full_block_prompt(): # STEP (1): Initialize a recv. scheduler_output = scheduler.schedule() # All blocks should be allocated. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ] + ) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT scheduler.update_from_output(scheduler_output, model_runner_output) @@ -318,22 +345,25 @@ def test_full_block_prompt(): scheduler_output = scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving=[request_id]) + finished_recving={request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # # STEP (3): Run as usual. scheduler_output = scheduler.schedule() # We need to recompute the final token of the prompt to generate # the first new token, so we should not have a new block. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ] + ) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS - assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == - NUM_TOKENS - 1) - assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1 + assert scheduler_output.num_scheduled_tokens[request_id] == 1 model_runner_output = create_model_runner_output([request]) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -341,8 +371,9 @@ def test_full_block_prompt(): # # Step (4): Hit EOS. 
scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output([request], use_eos=True) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) scheduler.schedule() outputs = engine_core_outputs[0].outputs @@ -371,13 +402,15 @@ def test_cannot_schedule_after_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_LOCAL) - request_remote = create_request(request_id=2, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_REMOTE, - do_remote_prefill=True) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL + ) + request_remote = create_request( + request_id=2, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True, + ) # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). scheduler.add_request(request_normal) @@ -398,7 +431,8 @@ def test_cannot_schedule_after_recv(): # Step 3: finish recving (5 blocks in use) scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[request_normal], finished_recving=[request_remote.request_id]) + reqs=[request_normal], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 @@ -407,7 +441,8 @@ def test_cannot_schedule_after_recv(): # because the transfer is completed. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[request_normal, request_remote]) + reqs=[request_normal, request_remote] + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 0 @@ -422,8 +457,9 @@ def test_cannot_schedule_after_recv(): # Step 6: finish the request, free it. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -432,16 +468,19 @@ def test_cannot_schedule_after_recv(): # request is retrieved from preempted list. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) - assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] == - NUM_PROMPT_BLOCKS * BLOCK_SIZE) + assert ( + scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] + == NUM_PROMPT_BLOCKS * BLOCK_SIZE + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # Step 8: free everything. 
scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_remote], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_remote], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) @@ -466,13 +505,15 @@ def test_cannot_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_LOCAL) - request_remote = create_request(request_id=2, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_REMOTE, - do_remote_prefill=True) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL + ) + request_remote = create_request( + request_id=2, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True, + ) # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). scheduler.add_request(request_normal) @@ -491,12 +532,13 @@ def test_cannot_recv(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # Should not have KV transfer in progress. - assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS) + assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS # Step 3: finish the request, free it. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -507,12 +549,13 @@ def test_cannot_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 - assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS # Step 5: finish recving (5 blocks in use) scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[], finished_recving=[request_remote.request_id]) + reqs=[], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -526,8 +569,9 @@ def test_cannot_recv(): # Step 7: free everything. 
scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_remote], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_remote], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index db203b81f15fc..e7013a794a8c6 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -33,20 +33,26 @@ def _check_path_len(path): def _list_path(path): - """Return the list of foldername (hashes generatd) under the path""" + """Return the list of foldername (hashes generated) under the path""" return list(path.iterdir()) -def run_test(tmp_path, processor, llm: LLM, question: str, - image_urls: list[Image], expected_len: int, info: str): +def run_test( + tmp_path, + processor, + llm: LLM, + question: str, + image_urls: list[Image], + expected_len: int, + info: str, +): """ One individual test to process the prompt and output base on 1 set of input - Then check if the length in the strorage path matches the expected length + Then check if the length in the storage path matches the expected length `info` introduces details or purpose of the individual test """ print(f"***info: {info}***") - print( - f"**Expected storage path length after llm generate: {expected_len}**") + print(f"**Expected storage path length after llm generate: {expected_len}**") process_prompt(processor, llm, question, image_urls) print(f"Path matched expected length: {_check_path_len(tmp_path)}") @@ -54,51 +60,42 @@ def run_test(tmp_path, processor, llm: LLM, question: str, assert _check_path_len(tmp_path) == expected_len, ( f"Expect storage path length {expected_len} ;", - f"but end up {_check_path_len(tmp_path)} instead. ", f"Info: {info}") + f"but end up {_check_path_len(tmp_path)} instead. ", + f"Info: {info}", + ) -def process_prompt(processor, llm: LLM, question: str, - image_urls: list[Image]): +def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): """ Form the prompt based on the text and image input, then llm generate output """ - placeholders = [{ - "type": "image_url", - "image_url": { - "url": f"data:image;base64,{encode_image_base64(image_pil)}" + placeholders = [ + { + "type": "image_url", + "image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"}, } - } for image_pil in image_urls] + for image_pil in image_urls + ] messages = [ - { - "role": "system", - "content": "You are a helpful assistant." 
- }, + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ *placeholders, - { - "type": "text", - "text": question - }, + {"type": "text", "text": question}, ], }, ] - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) outputs = llm.generate( { - "prompt": - prompt, - **({ - "multi_modal_data": { - "image": [*image_urls] - } - } if image_urls else {}) + "prompt": prompt, + **({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}), }, sampling_params=SAMPLING_PARAMS, ) @@ -114,8 +111,8 @@ def process_prompt(processor, llm: LLM, question: str, def test_shared_storage_connector_hashes(tmp_path): """ Tests that SharedStorageConnector saves KV to the storage locations - with proper hashes; that are unique for inputs with identical text but - differnt images (same size), or same multiple images but different orders. + with proper hashes; that are unique for inputs with identical text but + different images (same size), or same multiple images but different orders. """ # Using tmp_path as the storage path to store KV print(f"KV storage path at: {str(tmp_path)}") @@ -124,7 +121,8 @@ def test_shared_storage_connector_hashes(tmp_path): kv_transfer_config = KVTransferConfig( kv_connector="SharedStorageConnector", kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": str(tmp_path)}) + kv_connector_extra_config={"shared_storage_path": str(tmp_path)}, + ) engine_args = EngineArgs( model=MODEL_NAME, @@ -157,56 +155,88 @@ def test_shared_storage_connector_hashes(tmp_path): # Prepare the input cases input_cases = [ - InputCase(text=TEXT_PROMPTS[0], - img=[image_1], - expected_len=1, - info="image_1 single input the first time."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2], - expected_len=2, - info=("image_2 single input the first time. " - "It is in same pixel size with image_1, yet it " - "should be able to form a new unique hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1], - expected_len=2, - info=("image_1 single input the 2nd time. " - "It should not form aother new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2], - expected_len=2, - info=("image_2 single input the 2nd time. " - "It should not form aother new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1, image_2], - expected_len=3, - info="image_1 with image_2 input the first time."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2, image_1], - expected_len=4, - info="The image order is swapped. Should form new hash."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1, image_2], - expected_len=4, - info=("[image_1, image_2] input the 2nd time. " - "It should not form aother new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2, image_1], - expected_len=4, - info=("[image_2, image_1] input the 2nd time. 
" - "It should not form aother new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[], - expected_len=5, - info="Pure text input test as a case-control"), - InputCase(text=TEXT_PROMPTS[0], - img=[], - expected_len=5, - info="Identical pure text input as a case-control"), - InputCase(text=TEXT_PROMPTS[1], - img=[], - expected_len=6, - info="Another pure text input as a case-control"), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1], + expected_len=1, + info="image_1 single input the first time.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2], + expected_len=2, + info=( + "image_2 single input the first time. " + "It is in same pixel size with image_1, yet it " + "should be able to form a new unique hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1], + expected_len=2, + info=( + "image_1 single input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2], + expected_len=2, + info=( + "image_2 single input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1, image_2], + expected_len=3, + info="image_1 with image_2 input the first time.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2, image_1], + expected_len=4, + info="The image order is swapped. Should form new hash.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1, image_2], + expected_len=4, + info=( + "[image_1, image_2] input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2, image_1], + expected_len=4, + info=( + "[image_2, image_1] input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[], + expected_len=5, + info="Pure text input test as a case-control", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[], + expected_len=5, + info="Identical pure text input as a case-control", + ), + InputCase( + text=TEXT_PROMPTS[1], + img=[], + expected_len=6, + info="Another pure text input as a case-control", + ), ] # Run tests diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 0dd57dfcc95c5..6f51b9bbcbdaa 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -2,23 +2,33 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile from collections import defaultdict +from itertools import count from typing import Any, Callable, Optional import torch from vllm import SamplingParams -from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, - ModelConfig, SchedulerConfig, VllmConfig) -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) + SharedStorageConnector, +) +from vllm.utils import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) 
+from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -40,14 +50,24 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks + ) + == 0 + ) + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block + ) + == 0 + ) num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. @@ -60,12 +80,15 @@ def create_vllm_config( max_num_seqs: int = 16, max_num_batched_tokens: int = 64, block_size: int = 16, + max_model_len: int = 10000, + enable_chunked_prefill: bool = True, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, + max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, ) model_config = ModelConfig( model=model, @@ -85,11 +108,13 @@ def create_vllm_config( kv_connector="NixlConnector", kv_role="kv_both", ) - return VllmConfig(scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - device_config=DeviceConfig("cpu")) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) def create_scheduler( @@ -102,9 +127,9 @@ def create_scheduler( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks @@ -113,58 +138,64 @@ def create_scheduler( kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, ) +_request_count = count(1) _none_hash_initialized = False -def create_request(request_id: int, - num_tokens: int = 10, - max_tokens: int = 16, - do_remote_decode: bool = False, - do_remote_prefill: bool = False, - use_all_1s_for_prompt_tokens: bool = False, - num_remote_blocks: int = 3, - block_size: int = 16, - hash_fn: Callable = hash) -> Request: +def create_request( + request_id: Optional[int] = None, + num_tokens: int = 10, + common_prefix_len=0, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + num_remote_blocks: int = 3, + block_size: int = 16, + hash_fn: 
Callable = sha256, +) -> Request: """Make dummy request for testing.""" + assert num_tokens >= common_prefix_len >= 0 + + if request_id is None: + request_id = next(_request_count) + global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(hash_fn) _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = dict(do_remote_prefill=False, - do_remote_decode=True) + kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True) elif do_remote_prefill: - kv_transfer_params = dict(do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id="my-engine-id", - remote_block_ids=list( - range(num_remote_blocks)), - remote_host="my-host", - remote_port=1234) + kv_transfer_params = dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + ) max_tokens = 1 if do_remote_decode else max_tokens sampling_params = SamplingParams(max_tokens=max_tokens) - if use_all_1s_for_prompt_tokens: - prompt_token_ids = [1] * num_tokens - else: - prompt_token_ids = [i * request_id for i in range(num_tokens)] + common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else [] + suffix = [i * request_id for i in range(num_tokens - common_prefix_len)] + prompt_token_ids = common_prefix + suffix req = Request( request_id=f"id-{request_id}", prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, - multi_modal_kwargs=None, - multi_modal_placeholders=None, - multi_modal_hashes=None, + mm_features=None, eos_token_id=EOS_TOKEN_ID, block_hasher=get_request_block_hasher(block_size, hash_fn), ) @@ -174,29 +205,40 @@ def create_request(request_id: int, def create_model_runner_output( reqs: list[Request], - finished_sending: Optional[list[str]] = None, - finished_recving: Optional[list[str]] = None, + finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None, + invalid_block_ids: Optional[set[int]] = None, use_eos: bool = False, + token_id: int = 0, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" # Make request data. req_ids = [req.request_id for req in reqs] + req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)} # Make sampled tokens. - sampled_token = EOS_TOKEN_ID if use_eos else 0 + sampled_token = EOS_TOKEN_ID if use_eos else token_id sampled_token_ids = [[sampled_token] for _ in req_ids] - kv_connector_output = None if ( - finished_sending is None - and finished_recving is None) else KVConnectorOutput( + kv_connector_output = ( + None + if ( + finished_sending is None + and finished_recving is None + and invalid_block_ids is None + ) + else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, + invalid_block_ids=invalid_block_ids or set(), ) + ) # Make output data structure. 
return ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=req_id_to_index, sampled_token_ids=sampled_token_ids, logprobs=None, prompt_logprobs_dict={}, @@ -206,22 +248,30 @@ def create_model_runner_output( class TestSharedStorageConnector(SharedStorageConnector): - def __init__(self, config: VllmConfig, role): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self._connector = SharedStorageConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}-{self.role.name}_events.log" + self._event_file = ( + tempfile.gettempdir() + + f"/connector_{self.name}-{self.role.name}_events.log" + ) # Start with an empty file with open(self._event_file, "w") as _: pass def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion + if name in ( + "_connector", + "call_record", + "name", + "_event_file", + "__class__", + "__dict__", + "__getattribute__", + "__init__", + ): # avoid recursion return object.__getattribute__(self, name) if not hasattr(self._connector, name): return object.__getattribute__(self, name) @@ -240,21 +290,20 @@ class TestSharedStorageConnector(SharedStorageConnector): if isinstance(arg, int): to_log.append(str(arg)) elif isinstance(arg, KVCacheBlocks): - to_log.append( - f"num_blocks={[len(b) for b in arg.blocks]}") + to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}") # Log the event as a line to the file try: with open(self._event_file, "a") as f: - f.write(' '.join(to_log) + "\n") + f.write(" ".join(to_log) + "\n") except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") + print(f"[ERROR] Could not log event {name} for {self.name}: {e}") return attr(*args, **kwargs) return wrapper return attr -KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__, - TestSharedStorageConnector.__name__) +KVConnectorFactory.register_connector( + "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__ +) diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py new file mode 100644 index 0000000000000..81b57f1ca0c8d --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import time + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.flashinfer import FlashInferBackend +from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler + +NUM_GPU_BLOCKS = [64] +NUM_CPU_BLOCKS = [256] +GPU_BLOCK_SIZES = [16] +GPU_BLOCKS_PER_CPU_BLOCK = [1, 3] +HEAD_SIZES = [64] +NUM_HEADS = [8] +NUM_LAYERS = [4] +DTYPES = [torch.bfloat16] +SEEDS = [0] +CUDA_DEVICES = ["cuda:0"] +NUM_MAPPINGS = [3] + + +@pytest.mark.parametrize("gpu_to_cpu", [True, False]) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES) 
+@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK) +@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS) +@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS) +@pytest.mark.parametrize("num_layers", NUM_LAYERS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_transfer( + gpu_to_cpu: bool, + num_mappings: int, + head_size: int, + num_heads: int, + gpu_block_size: int, + gpu_blocks_per_cpu_block: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + num_layers: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + + # create per-layer GPU KV caches + attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend] + + gpu_caches = {} + attn_backends = {} + for i in range(num_layers): + layer_name = f"layer {i}" + + attn_backend = attn_backends_list[i % len(attn_backends_list)] + attn_backends[layer_name] = attn_backend + + gpu_cache_shape = attn_backend.get_kv_cache_shape( + num_gpu_blocks, gpu_block_size, num_heads, head_size + ) + gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device) + + # create handler + cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size + handler = CpuGpuOffloadingHandler( + attn_backends=attn_backends, + gpu_block_size=gpu_block_size, + cpu_block_size=cpu_block_size, + num_cpu_blocks=num_cpu_blocks, + gpu_caches=gpu_caches, + ) + + # select block mappings + gpu_blocks = random.sample( + range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block + ) + cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings) + + # convert cpu blocks to gpu block size + cpu_blocks_in_gpu_block_size = [] + for cpu_block in cpu_blocks: + base_block_id = cpu_block * gpu_blocks_per_cpu_block + for i in range(gpu_blocks_per_cpu_block): + cpu_blocks_in_gpu_block_size.append(i + base_block_id) + + # maybe skip a GPU block to test writing to the middle of a CPU block + if gpu_to_cpu: + gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :] + cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[ + gpu_blocks_per_cpu_block - 1 : + ] + + # set transfer direction + if gpu_to_cpu: + src_kv_caches = handler.gpu_tensors + dst_kv_caches = handler.cpu_tensors + src_spec_class = GPULoadStoreSpec + dst_spec_class = CPULoadStoreSpec + src_blocks = gpu_blocks + dst_blocks = cpu_blocks + src_blocks_in_gpu_block_size = gpu_blocks + dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size + dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block + else: + src_kv_caches = handler.cpu_tensors + dst_kv_caches = handler.gpu_tensors + src_spec_class = CPULoadStoreSpec + dst_spec_class = GPULoadStoreSpec + src_blocks = cpu_blocks + dst_blocks = gpu_blocks + src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size + dst_blocks_in_gpu_block_size = gpu_blocks + dst_size_in_gpu_blocks = num_gpu_blocks + + # build dst -> src mapping + dst_to_src = {} + for src_block, dst_block in zip( + src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size + ): + dst_to_src[dst_block] = src_block + + # build transfer specs + src_spec = src_spec_class(src_blocks) + dst_spec = dst_spec_class(dst_blocks) + + # clone src and dst tensors before transfer + orig_src_caches = [x.clone() for x in src_kv_caches] + orig_dst_caches = [x.clone() for x in dst_kv_caches] + + # call transfer function + assert handler.transfer_async(1, (src_spec, 
dst_spec)) + assert set(handler.transfer_events.keys()) == {1} + + # wait for transfer to complete + end_time = time.time() + 10 + while time.time() < end_time: + finished = handler.get_finished() + if finished: + assert finished == [(1, True)] + break + time.sleep(0.1) + + # verify src tensors did not change + for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches): + assert torch.equal(orig_tensor, tensor) + + # verify dst tensors + for dst_block in range(dst_size_in_gpu_blocks): + src_block_candidate = dst_to_src.get(dst_block) + for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( + src_kv_caches, + dst_kv_caches, + orig_dst_caches, + handler.kv_dim_before_num_blocks, + ): + if kv_dim: + # iterate over key, value + for i in range(2): + if src_block_candidate is not None: + expected_value = src_cache[i][src_block_candidate] + else: + expected_value = orig_dst_cache[i][dst_block] + torch.testing.assert_close( + dst_cache[i][dst_block].cpu(), expected_value.cpu() + ) + else: + if src_block_candidate is not None: + expected_value = src_cache[src_block_candidate] + else: + expected_value = orig_dst_cache[dst_block] + torch.testing.assert_close( + dst_cache[dst_block].cpu(), expected_value.cpu() + ) diff --git a/tests/v1/kv_offload/test_cpu_manager.py b/tests/v1/kv_offload/test_cpu_manager.py new file mode 100644 index 0000000000000..57884f846b513 --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_manager.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + PrepareStoreOutput, +) +from vllm.v1.kv_offload.backends.cpu import CPUBackend +from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec + + +@dataclass +class ExpectedPrepareStoreOutput: + block_hashes_to_store: list[int] + store_block_ids: list[int] + block_hashes_evicted: list[int] + + +def to_hashes(int_hashes: list[int]) -> list[BlockHash]: + return [BlockHash(str(i).encode()) for i in int_hashes] + + +def verify_store_output( + prepare_store_output: Optional[PrepareStoreOutput], + expected_prepare_store_output: ExpectedPrepareStoreOutput, +): + assert prepare_store_output is not None + assert prepare_store_output.block_hashes_to_store == to_hashes( + expected_prepare_store_output.block_hashes_to_store + ) + assert prepare_store_output.block_hashes_evicted == to_hashes( + expected_prepare_store_output.block_hashes_evicted + ) + store_spec = prepare_store_output.store_spec + assert isinstance(store_spec, CPULoadStoreSpec) + expected_array = np.array( + expected_prepare_store_output.store_block_ids, dtype=np.int64 + ) + assert np.array_equal(expected_array, store_spec.block_ids) + + +def verify_load_output( + prepare_load_output: LoadStoreSpec, expected_prepare_load_output: list[int] +): + assert isinstance(prepare_load_output, CPULoadStoreSpec) + expected_array = np.array(expected_prepare_load_output, dtype=np.int64) + assert np.array_equal(expected_array, prepare_load_output.block_ids) + + +def verify_events( + events: Iterable[OffloadingEvent], + block_size: int, + expected_stores: tuple[set[int], ...] = (), + expected_evictions: tuple[set[int], ...] 
= (), +): + stores: list[set[BlockHash]] = [] + evictions: list[set[BlockHash]] = [] + for event in events: + assert event.medium == CPULoadStoreSpec.medium() + assert event.block_size == block_size + if event.removed: + evictions.append(set(event.block_hashes)) + else: + stores.append(set(event.block_hashes)) + + def to_hash_sets(int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]: + return tuple([set(to_hashes(list(int_set))) for int_set in int_sets]) + + assert tuple(evictions) == to_hash_sets(expected_evictions) + assert tuple(stores) == to_hash_sets(expected_stores) + + +def test_cpu_manager(): + """ + Tests LRUOffloadingManager with a CPUBackend. + """ + # initialize a CPU backend with a capacity of 4 blocks + block_size = 256 + cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) + cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True) + + # prepare store [1, 2] + prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[1, 2], + store_block_ids=[0, 1], + block_hashes_evicted=[], + ), + ) + + # lookup [1, 2] -> not ready + assert cpu_manager.lookup(to_hashes([1, 2])) == 0 + + # no events so far + assert list(cpu_manager.take_events()) == [] + + # complete store [1, 2] + cpu_manager.complete_store(to_hashes([1, 2])) + verify_events( + cpu_manager.take_events(), block_size=block_size, expected_stores=({1, 2},) + ) + + # lookup [1, 2] + assert cpu_manager.lookup(to_hashes([1])) == 1 + assert cpu_manager.lookup(to_hashes([1, 2])) == 2 + assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2 + + # prepare store [2, 3, 4, 5] -> evicts [1] + prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[3, 4, 5], + store_block_ids=[2, 3, 0], + block_hashes_evicted=[1], + ), + ) + + # verify eviction event + verify_events( + cpu_manager.take_events(), block_size=block_size, expected_evictions=({1},) + ) + + # prepare store with no space + assert cpu_manager.prepare_store(to_hashes([1, 6])) is None + + # complete store [2, 3, 4, 5] + cpu_manager.complete_store(to_hashes([2, 3, 4, 5])) + + # prepare load [2, 3] + prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3])) + verify_load_output(prepare_load_output, [1, 2]) + + # prepare store with no space ([2, 3] is being loaded) + assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None + + # complete load [2, 3] + cpu_manager.complete_load(to_hashes([2, 3])) + + # prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest) + prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[6, 7, 8], + store_block_ids=[3, 2, 1], + block_hashes_evicted=[2, 3, 4], + ), + ) + + # complete store [6, 7, 8] + cpu_manager.complete_store(to_hashes([6, 7, 8])) + + # touch [5, 6, 7] (move to end of LRU order) + cpu_manager.touch(to_hashes([5, 6, 7])) + + # prepare store [7, 9] -> evicts [8] (oldest following previous touch) + prepare_store_output = cpu_manager.prepare_store(to_hashes([9])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[9], + store_block_ids=[1], + block_hashes_evicted=[8], + ), + ) + + # complete store [7, 9] with failure + cpu_manager.complete_store(to_hashes([7, 9]), success=False) + + # assert [7] is still stored, but 
[9] is not + assert cpu_manager.lookup(to_hashes([7])) == 1 + assert cpu_manager.lookup(to_hashes([9])) == 0 + + verify_events( + cpu_manager.take_events(), + block_size=block_size, + expected_stores=({3, 4, 5}, {6, 7, 8}), + expected_evictions=({2, 3, 4}, {8}), + ) diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py new file mode 100644 index 0000000000000..0d90cc715fd48 --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +CPU_BLOCK_SIZES = [16, 48] + + +@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) +def test_cpu_offloading(cpu_block_size: int) -> None: + """ + Tests OffloadingConnector with CPUOffloadingSpec. + """ + + # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default) + kv_transfer_config = KVTransferConfig( + kv_connector="OffloadingConnector", + kv_role="kv_both", + kv_connector_extra_config={"num_cpu_blocks": 100, "block_size": cpu_block_size}, + ) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + + prompts = ["Hi " * 100] + sampling_params = SamplingParams(temperature=0, max_tokens=20) + + # run generation - this should trigger saving KV cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cold_time = time.time() - start_time + + # run generation again - should hit the GPU prefix cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + gpu_hit_time = time.time() - start_time + + # reset prefix cache to avoid GPU hit. 
+ llm.reset_prefix_cache() + + # sleep for a sec to make sure CPU finished storing + time.sleep(1) + + # run generation again - this should trigger loading from CPU + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cpu_hit_time = time.time() - start_time + + print("Generation times:") + print(f" Cold: {cold_time * 1000:.2f}ms") + print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms") + print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms") diff --git a/tests/v1/kv_offload/test_worker.py b/tests/v1/kv_offload/test_worker.py new file mode 100644 index 0000000000000..6fcd408f3c593 --- /dev/null +++ b/tests/v1/kv_offload/test_worker.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.v1.kv_offload.abstract import LoadStoreSpec +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + OffloadingWorker, + TransferResult, + TransferSpec, +) + + +class LoadStoreSpec1(LoadStoreSpec): + def __init__( + self, + submit_success: bool = True, + async_success: bool = True, + exception: bool = False, + ): + self.finished = False + self.submit_success = submit_success + self.async_success = async_success + self.exception = exception + + @staticmethod + def medium() -> str: + return "1" + + def __repr__(self): + return f"{self.medium()}: {id(self)}" + + +class LoadStoreSpec2(LoadStoreSpec): + @staticmethod + def medium() -> str: + return "2" + + def __repr__(self): + return f"{self.medium()}: {id(self)}" + + +class OffloadingHandler1To2(OffloadingHandler): + def __init__(self): + self.transfers: dict[int, LoadStoreSpec1] = {} + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + src, dst = spec + assert isinstance(src, LoadStoreSpec1) + assert isinstance(dst, LoadStoreSpec2) + + if src.exception: + raise Exception("An expected exception. Don't worry!") + if not src.submit_success: + return False + + self.transfers[job_id] = src + return True + + def get_finished(self) -> list[TransferResult]: + finished = [] + for job_id, spec in list(self.transfers.items()): + if spec.finished: + finished.append((job_id, spec.async_success)) + del self.transfers[job_id] + return finished + + +class OffloadingHandler2To1(OffloadingHandler): + def __init__(self): + self.transfers: dict[int, LoadStoreSpec1] = {} + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + src, dst = spec + assert isinstance(src, LoadStoreSpec2) + assert isinstance(dst, LoadStoreSpec1) + + self.transfers[job_id] = dst + return True + + def get_finished(self) -> list[TransferResult]: + finished = [] + for job_id, spec in list(self.transfers.items()): + if spec.finished: + finished.append((job_id, spec.async_success)) + del self.transfers[job_id] + return finished + + +def test_offloading_worker(): + """ + Tests OffloadingWorker with 2 handlers. + One handler performs 1->2 transfers, and the other handles 2->1. 
+ """ + worker = OffloadingWorker() + handler1to2 = OffloadingHandler1To2() + handler2to1 = OffloadingHandler2To1() + worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2) + worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1) + + # 1st transfer 1->2 (exception) + src1 = LoadStoreSpec1(exception=True) + dst1 = LoadStoreSpec2() + assert not worker.transfer_async(1, (src1, dst1)) + + # 2ed transfer 1->2 (failure to submit) + src2 = LoadStoreSpec1(submit_success=False) + dst2 = LoadStoreSpec2() + assert not worker.transfer_async(2, (src2, dst2)) + + # 3rd transfer 1->2 (failure) + src3 = LoadStoreSpec1(async_success=False) + dst3 = LoadStoreSpec2() + assert worker.transfer_async(3, (src3, dst3)) + + # 4th transfer 1->2 (success) + src4 = LoadStoreSpec1() + dst4 = LoadStoreSpec2() + worker.transfer_async(4, (src4, dst4)) + assert set(handler1to2.transfers.keys()) == {3, 4} + + # 5th transfer 2->1 + src5 = LoadStoreSpec2() + dst5 = LoadStoreSpec1() + worker.transfer_async(5, (src5, dst5)) + assert set(handler2to1.transfers.keys()) == {5} + + # no transfer completed yet + assert worker.get_finished() == [] + + # complete 3rd, 4th + src3.finished = True + src4.finished = True + + # 6th transfer 1->2 + src6 = LoadStoreSpec1() + dst6 = LoadStoreSpec2() + worker.transfer_async(6, (src6, dst6)) + + # 7th transfer 2->1 + src7 = LoadStoreSpec2() + dst7 = LoadStoreSpec1() + worker.transfer_async(7, (src7, dst7)) + + # 6th and 7th transfers started + assert 6 in handler1to2.transfers + assert 7 in handler2to1.transfers + + # verify result of 3rd and 4th transfers + assert sorted(worker.get_finished()) == [(3, False), (4, True)] + + # complete 6th and 7th transfers + src6.finished = True + dst7.finished = True + assert sorted(worker.get_finished()) == [(6, True), (7, True)] diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 43caef79b02f7..538b6281f5a07 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -10,24 +10,28 @@ import pytest import torch from tests.utils import create_new_process_for_each_test -from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, - create_penalty_tensor, - create_prompt_tokens_tensor, - fake_apply_logitsprocs, - fake_update_logitsprocs_state) +from tests.v1.sample.utils import ( + LogitsprocsTestFakes, + create_fake_logits, + create_penalty_tensor, + create_prompt_tokens_tensor, + fake_apply_logitsprocs, + fake_update_logitsprocs_state, +) from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available -# yapf: disable -from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, - LogitBiasLogitsProcessor, - LogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - MoveDirectionality, - build_logitsprocs) -# yapf: enable +from vllm.v1.sample.logits_processor import ( + BatchUpdate, + BatchUpdateBuilder, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + MoveDirectionality, + build_logitsprocs, +) from vllm.v1.sample.metadata import SamplingMetadata PIN_MEMORY_AVAILABLE = is_pin_memory_available() @@ -49,9 +53,10 @@ LogitprocType = Union[type[LogitsProcessor], str] class LogitsProcsRequestParams: """Encapsulates key params for a single request in a batch. 
- + Params can be customized based on the enabled logitproc """ + workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test @@ -64,14 +69,13 @@ class LogitsProcsRequestParams: # Number of output tokens is randomly 0 or twice the min-tokens # threshold which will be used in testing. Output token values # don't matter *for these tests* so use 0 as a dummy value - self.out_tokens = ([0] * - (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)) self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): """For debugging""" - summ = ', '.join(f'{k}={v}' for k, v in vars(self).items()) + summ = ", ".join(f"{k}={v}" for k, v in vars(self).items()) return f"MyClass({summ})" @@ -86,12 +90,13 @@ def _generate_fake_sampling_metadata( prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( - np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + np.random.randint(0, vocab_size, size=num_output_tokens).tolist() + ) prompt_token_ids.append( - np.random.randint(0, - vocab_size, - size=np.random.randint( - 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + np.random.randint( + 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) + ).tolist() + ) logitsprocs = build_logitsprocs( vllm_config=VllmConfig(), device=device, @@ -99,15 +104,16 @@ def _generate_fake_sampling_metadata( is_pooling_model=False, ) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size, ), 0.0), + temperature=torch.full((batch_size,), 0.0), all_greedy=True, all_random=False, top_p=None, top_k=None, generators={}, max_num_logprobs=0, - prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids, - vocab_size, device), + prompt_token_ids=create_prompt_tokens_tensor( + prompt_token_ids, vocab_size, device + ), output_token_ids=output_token_ids, frequency_penalties=create_penalty_tensor(batch_size, 0.0, device), presence_penalties=create_penalty_tensor(batch_size, 0.0, device), @@ -115,7 +121,8 @@ def _generate_fake_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=logitsprocs) + logitsprocs=logitsprocs, + ) return fake_sampling_metadata @@ -127,15 +134,15 @@ def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes: fake_logits[i, 0] = 10.0 # High logit for first token fake_logits[i, 1:] = 1e-2 # Others remain low sampling_metadata = _generate_fake_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) return LogitsprocsTestFakes( logits=fake_logits, sampling_metadata=sampling_metadata, ) -def _sampling_params_from_logitproc( - logitproc_type: LogitprocType) -> SamplingParams: +def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams: """Customize request SamplingParams for a specified logitproc""" # SamplingParams for req with no logitproc kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0} @@ -150,7 +157,7 @@ def _generate_mixed_logitsprocs_batch_params( ) -> list[LogitsProcsRequestParams]: """Define key params for a batch of requests with a different logitproc enabled per request. - + The batch will have `reqs_per_logitproc` repeats for all `logitsprocs_types` under test, including the case where no logitsproc is enabled. 
The batch is randomly shuffled. The @@ -173,7 +180,8 @@ def _generate_mixed_logitsprocs_batch_params( return [ LogitsProcsRequestParams( workload_index=idx, - logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc]) + logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc], + ) for idx, pdx in enumerate(batch_perm) ] @@ -185,10 +193,12 @@ def _raise_error_invalid( step_idx: int, err_cls: type[Exception] = ValueError, ) -> None: - raise err_cls(f"Validation failed for step={step_idx}, " - f"batch_index={batch_index}, " - f"workload_index={request_params.workload_index}, " - f"req_params={request_params}. Reason: {msg_suffix}") + raise err_cls( + f"Validation failed for step={step_idx}, " + f"batch_index={batch_index}, " + f"workload_index={request_params.workload_index}, " + f"req_params={request_params}. Reason: {msg_suffix}" + ) def _logit_bias_params(kwargs: dict) -> None: @@ -208,8 +218,7 @@ def _logit_bias_validate( ) -> None: """Validate logit bias logitproc applied correctly""" logit_bias = request_params.params.logit_bias - logits_old = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() logits_new = logits_new[batch_index].cpu() for token_id in range(VOCAB_SIZE): logit_old_value = logits_old[token_id] @@ -218,22 +227,28 @@ def _logit_bias_validate( bias_value = logit_bias[token_id] exp_value = bias_value + logit_old_value if logit_new_value != pytest.approx(exp_value): - _raise_error_invalid(msg_suffix=( - f"Biased token {token_id} logit value {logit_new_value} " - f"does not match expected value {exp_value} " - f"given bias {bias_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Biased token {token_id} logit value {logit_new_value} " + f"does not match expected value {exp_value} " + f"given bias {bias_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) else: if logit_new_value != pytest.approx(logit_old_value): - _raise_error_invalid(msg_suffix=( - f"Unbiased token {token_id} logit value {logit_new_value} " - f"does not match expected value {logit_old_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unbiased token {token_id} logit value {logit_new_value} " + f"does not match expected value {logit_old_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) def _min_p_params(kwargs: dict) -> None: @@ -259,26 +274,27 @@ def _min_p_validate( msg_suffix="Invalid: dominant token 0 masked (-inf)", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: if request_params.params.min_p > 0.0: # Non-dominant tokens should be masked when min_p > 0 if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: non-dominant token {token_id} not masked", + msg_suffix=f"Invalid: non-dominant token {token_id} not masked", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: # No masking when min_p is 0 if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: token {token_id} masked when min_p=0.0", + msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def 
_min_tokens_params(kwargs: dict) -> None: @@ -303,7 +319,8 @@ def _min_tokens_validate( min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD ref_all_stop_token_ids = request_params.params.all_stop_token_ids mt_lp: MinTokensLogitsProcessor = next( - test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)) + test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor) + ) assert isinstance(mt_lp, MinTokensLogitsProcessor) min_tok = mt_lp.min_toks.get(batch_index, None) @@ -312,38 +329,50 @@ def _min_tokens_validate( (_, out_tok, all_stop_token_ids) = min_tok num_out_tokens = len(out_tok) if num_out_tokens != ref_num_out_tokens: - _raise_error_invalid(msg_suffix=( - "Number of output tokens in min-token logit processor " - f"request metadata ({num_out_tokens}) does not match " - f"reference ({ref_num_out_tokens})."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Number of output tokens in min-token logit processor " + f"request metadata ({num_out_tokens}) does not match " + f"reference ({ref_num_out_tokens})." + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if ref_all_stop_token_ids != all_stop_token_ids: - _raise_error_invalid(msg_suffix=( - "Stop token ids do not match reference; all_stop_token_ids: " - f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " - f"{sorted(ref_all_stop_token_ids)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Stop token ids do not match reference; all_stop_token_ids: " + f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " + f"{sorted(ref_all_stop_token_ids)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min reached, but batch " - "index is recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min reached, but batch " + "index is recognized by min-tokens logits processor." + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) elif not min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min not reached, but batch " - "index is not recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min not reached, but batch " + "index is not recognized by min-tokens logits processor." 
+ ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) # Validate min-token logits for token_id in range(VOCAB_SIZE): @@ -351,21 +380,27 @@ def _min_tokens_validate( if token_id in ref_all_stop_token_ids and not min_reached: if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} is a stop token and " - "the sequence has not reached min length, " - "but the token is not masked " - f"(logit={logits_for_token})"), + msg_suffix=( + f"Token {token_id} is a stop token and " + "the sequence has not reached min length, " + "but the token is not masked " + f"(logit={logits_for_token})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} should not be masked but " - f"is (output len={ref_num_out_tokens})"), + msg_suffix=( + f"Token {token_id} should not be masked but " + f"is (output len={ref_num_out_tokens})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _none_validate( @@ -377,52 +412,58 @@ def _none_validate( step_idx: int, ) -> None: """Validate that no logits processors are applied""" - logits = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() ref_logits = logits_new[batch_index] if not torch.all(ref_logits == logits): - mismatch_toks = (ref_logits - != logits).nonzero(as_tuple=True)[0].tolist() + mismatch_toks = (ref_logits != logits).nonzero(as_tuple=True)[0].tolist() mismatch_strs = [] for token in mismatch_toks: val = float(logits[token]) ref_val = float(ref_logits[token]) mismatch_strs.append(f"({token=},{val=},{ref_val=})") - _raise_error_invalid(msg_suffix=( - f"Unexpected modification of logits: {','.join(mismatch_strs)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unexpected modification of logits: {','.join(mismatch_strs)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) class LogitsprocTestHelpers(NamedTuple): """Supports setting up and validating logitsprocs unit tests.""" + eval_fxn: Callable gen_request_fxn: Optional[Callable] = None logitsprocs_test_mapping = { - STR_NO_LOGITPROC: - LogitsprocTestHelpers(eval_fxn=_none_validate), - LogitBiasLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params, - eval_fxn=_logit_bias_validate), - MinPLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_p_params, - eval_fxn=_min_p_validate), - MinTokensLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params, - eval_fxn=_min_tokens_validate), + STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate), + LogitBiasLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate + ), + MinPLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate + ), + MinTokensLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate + ), } def _get_test_cases() -> list[list[str]]: """Each test case is a set of logitsprocs""" logitsprocs_types = list(logitsprocs_test_mapping.keys()) - return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] - for logitproc_type in logitsprocs_types 
- if logitproc_type != STR_NO_LOGITPROC - ] + [logitsprocs_types] + return ( + [[STR_NO_LOGITPROC]] + + [ + [logitproc_type, STR_NO_LOGITPROC] + for logitproc_type in logitsprocs_types + if logitproc_type != STR_NO_LOGITPROC + ] + + [logitsprocs_types] + ) def _generate_fake_step_update( @@ -440,11 +481,18 @@ def _generate_fake_step_update( # Other 50%: add a limited number of reqs (less than the number # of workload reqs remaining, less than an arbitrary max) # If no workload reqs remain: 100% of steps have 0 adds - num_step_add = random.choice([ - 0, - random.randint(1, min(max_add_remove_per_step, - workload_reqs_remaining)) - ]) if workload_reqs_remaining else 0 + num_step_add = ( + random.choice( + [ + 0, + random.randint( + 1, min(max_add_remove_per_step, workload_reqs_remaining) + ), + ] + ) + if workload_reqs_remaining + else 0 + ) # 50% of steps: remove no requests # Other 50%: remove a limited number of reqs (less than the number @@ -452,9 +500,11 @@ def _generate_fake_step_update( # If persistent batch is empty: 100% of steps have 0 removals until # more requests are added. Assume that removed requests are always # drawn from the current batch, before new adds - num_step_remove = random.choice([ - 0, random.randint(1, min(max_add_remove_per_step, batch_size)) - ]) if batch_size else 0 + num_step_remove = ( + random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))]) + if batch_size + else 0 + ) num_step_add_replace = min(num_step_add, num_step_remove) @@ -463,23 +513,34 @@ def _generate_fake_step_update( batch_update_builder.removed_append(removal) # Get added requests from workload - for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]: + for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]: # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, - add_req_params.prompt_tokens, add_req_params.out_tokens)) + ( + add_remove_idx, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + ) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch - add_reqs_append = workload_params[(wdx + - num_step_add_replace):(wdx + - num_step_add)] - batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, - add_req_params.out_tokens) - for adx, add_req_params in enumerate(add_reqs_append) - ]) + add_reqs_append = workload_params[ + (wdx + num_step_add_replace) : (wdx + num_step_add) + ] + batch_update_builder.added.extend( + [ + ( + adx + batch_size, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + for adx, add_req_params in enumerate(add_reqs_append) + ] + ) persistent_batch.extend(add_reqs_append) pre_condense_batch_size = len(persistent_batch) wdx += num_step_add # Update workload offset @@ -488,8 +549,10 @@ def _generate_fake_step_update( last_nonempty_index = pre_condense_batch_size - 1 condensed_to_idxs = set() while batch_update_builder.removed: - if (last_nonempty_index in batch_update_builder.removed - or last_nonempty_index in condensed_to_idxs): + if ( + last_nonempty_index in batch_update_builder.removed + or last_nonempty_index in condensed_to_idxs + ): last_nonempty_index -= 1 continue # last_nonempty_index is the highest persistent batch index that was @@ -504,11 +567,10 @@ def _generate_fake_step_update( # move 
last_nonempty_index -> first_empty_index batch_update_builder.pop_removed() condensed_to_idxs.add(first_empty_index) - persistent_batch[first_empty_index] = persistent_batch[ - last_nonempty_index] + persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index] batch_update_builder.moved.append( - (last_nonempty_index, first_empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) last_nonempty_index -= 1 @@ -519,23 +581,26 @@ def _generate_fake_step_update( persistent_batch[:] = persistent_batch[0:condensed_batch_size] if condensed_batch_size > 1: - # Simulate arbitrary reorder_batch() in the kernel backend + # Simulate arbitrary batch ordering in the kernel backend # Generate a random number k of non-overlapping swap tuples k = random.randint(0, condensed_batch_size // 2) idxs = list(range(condensed_batch_size)) random.shuffle(idxs) - swaps = [ - tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k) - ] - batch_update_builder.moved.extend([ - (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps - ]) + swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)] + batch_update_builder.moved.extend( + [(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps] + ) for adx, bdx in swaps: - persistent_batch[adx], persistent_batch[bdx] = persistent_batch[ - bdx], persistent_batch[adx] + persistent_batch[adx], persistent_batch[bdx] = ( + persistent_batch[bdx], + persistent_batch[adx], + ) - return (batch_update_builder.get_and_reset(condensed_batch_size), wdx, - workload_size - wdx) + return ( + batch_update_builder.get_and_reset(condensed_batch_size), + wdx, + workload_size - wdx, + ) def _assert_valid( @@ -550,8 +615,10 @@ def _assert_valid( # Trivial case of empty persistent batch assert len(persistent_batch) == 0 if logits_w_lp.shape[0] != 0: - raise ValueError("Fake persistent batch is empty but logitsprocs " - f"output batch has shape {logits_w_lp.shape}") + raise ValueError( + "Fake persistent batch is empty but logitsprocs " + f"output batch has shape {logits_w_lp.shape}" + ) return # Validate logits for each fake request @@ -560,36 +627,40 @@ def _assert_valid( # Invoke the appropriate validation function for # the logitproc employed by this request fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn - fxn(test_fakes=test_fakes, + fxn( + test_fakes=test_fakes, persistent_batch=persistent_batch, logits_new=logits_w_lp, batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) @create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) -def test_logitsprocs(device: str, reqs_per_logitproc: int, - logitsprocs_under_test: list[str]): +def test_logitsprocs( + device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str] +): random.seed(40) torch.set_default_device(device) # Define a shuffled batch of requests which individually use a different # logitproc, or no logitproc at all workload_params = _generate_mixed_logitsprocs_batch_params( - reqs_per_logitproc=reqs_per_logitproc, - logitsprocs_types=logitsprocs_under_test) + reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test + ) workload_size = len(workload_params) # Create fake test data structures for testing. 
test_fakes = _generate_test_fakes(workload_size, device) wdx = 0 # Next request index in workload to add - persistent_batch: list[LogitsProcsRequestParams] = [ - ] # Persistent batch state, as list of workload indices + persistent_batch: list[ + LogitsProcsRequestParams + ] = [] # Persistent batch state, as list of workload indices # Generate fake removed request indices from current persistent # batch before adds diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index a7fde1990f7ed..95ddb18491691 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -2,36 +2,46 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import sys -from typing import Union +from typing import Any, Union import pytest from tests.utils import create_new_process_for_each_test -# yapf: disable -from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, - DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, - MAX_TOKENS, MODEL_NAME, - POOLING_MODEL_NAME, TEMP_GREEDY, - CustomLogitprocSource, - DummyLogitsProcessor, - dummy_module) +from tests.v1.logits_processors.utils import ( + DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, + MODEL_NAME, + POOLING_MODEL_NAME, + TEMP_GREEDY, + CustomLogitprocSource, + DummyLogitsProcessor, + WrappedPerReqLogitsProcessor, + dummy_module, + prompts, +) from tests.v1.logits_processors.utils import entry_points as fake_entry_points -from tests.v1.logits_processors.utils import prompts -# yapf: enable from vllm import LLM, SamplingParams -from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS, - LogitsProcessor) +from vllm.v1.sample.logits_processor import ( + STR_POOLING_REJECTS_LOGITSPROCS, + STR_SPEC_DEC_REJECTS_LOGITSPROCS, + LogitsProcessor, +) # Create a mixture of requests which do and don't utilize the dummy logitproc sampling_params_list = [ - SamplingParams(temperature=TEMP_GREEDY, - max_tokens=MAX_TOKENS, - extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams( + temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}, + ), SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), - SamplingParams(temperature=TEMP_GREEDY, - max_tokens=MAX_TOKENS, - extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams( + temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 67}, + ), SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), ] @@ -48,7 +58,7 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: 2. Server has *not* loaded dummy logitproc; test that all requests behave as if logitproc is *not* operating (output matches reference `LLM` output.) 
- + Args: kwargs: `LLM` constructor kwargs logitproc_loaded: server has loaded dummy logitproc if True @@ -72,7 +82,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: # Validate outputs for bdx, (out_lp, out_ref, params) in enumerate( - zip(outputs_logitproc, outputs_ref, sampling_params_list)): + zip(outputs_logitproc, outputs_ref, sampling_params_list) + ): lp_toks = out_lp.outputs[0].token_ids if logitproc_loaded and params.extra_args: # This request exercises custom logitproc; validate that logitproc @@ -80,8 +91,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: target_token = params.extra_args[DUMMY_LOGITPROC_ARG] if not all(x == target_token for x in lp_toks): raise AssertionError( - f"Request {bdx} generated {lp_toks}, shoud all be " - f"{target_token}") + f"Request {bdx} generated {lp_toks}, should all be {target_token}" + ) else: # This request does not exercise custom logitproc (or custom # logitproc is not enabled on this server); validate against @@ -89,16 +100,15 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: ref_toks = out_ref.outputs[0].token_ids if lp_toks != ref_toks: raise AssertionError( - f"Request {bdx} generated {lp_toks}, should match " - f"{ref_toks}") + f"Request {bdx} generated {lp_toks}, should match {ref_toks}" + ) @create_new_process_for_each_test() @pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) -def test_custom_logitsprocs(monkeypatch, - logitproc_source: CustomLogitprocSource): +def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource): """Test offline Python interface for passing custom logitsprocs - + Construct an `LLM` instance which loads a custom logitproc that has a well-defined behavior (mask out all tokens except one `target_token`) @@ -117,7 +127,7 @@ def test_custom_logitsprocs(monkeypatch, instance output * Logitproc passed in via {entrypoint, class object, fully-qualified class name (FQCN)} - test that dummy logitproc is utilized correctly when - provided via any of these three possible sources + provided via any of these three possible sources Args: monkeypatch: for setting env vars @@ -141,6 +151,7 @@ def test_custom_logitsprocs(monkeypatch, # Scenario: vLLM loads a logitproc from a preconfigured entrypoint # To that end, mock a dummy logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore # fork is required for workers to see entrypoint patch @@ -162,15 +173,54 @@ def test_custom_logitsprocs(monkeypatch, @create_new_process_for_each_test() -@pytest.mark.parametrize("logitproc_source", [ - CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, - CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, - CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, -]) -def test_pooling_rejects_custom_logitsprocs( - monkeypatch, logitproc_source: CustomLogitprocSource): +def test_custom_logitsprocs_req(monkeypatch): + """Test passing request-level logits processor to offline Python interface + + Wrap a request-level logits processor to create a batch level logits + processor that has a well-defined behavior (mask out all tokens except one + `target_token`) + + Construct an `LLM` instance which loads the wrapped logits processor. Pass + the custom logitproc as a class object. + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. 
+ + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Args: + monkeypatch: for setting env vars + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + _run_test( + {"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True + ) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"]) +@pytest.mark.parametrize( + "logitproc_source", + [ + CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, + CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, + CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, + ], +) +def test_rejects_custom_logitsprocs( + monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource +): """Validate that vLLM engine initialization properly rejects custom - logitsprocs when the model is a pooling model. + logitsprocs when the model is a pooling model or speculative decoding + enabled. Use `LLM` entrypoint. We expect `LLM` initialization to fail before the logitproc is actually loaded. @@ -194,44 +244,57 @@ def test_pooling_rejects_custom_logitsprocs( monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") random.seed(40) + test_params: dict[str, dict[str, Any]] = { + "pooling": { + "runner": "pooling", + "model": POOLING_MODEL_NAME, + "error_message": STR_POOLING_REJECTS_LOGITSPROCS, + "speculative_config": None, + }, + "spec_dec": { + "runner": "auto", + "model": MODEL_NAME, + "error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS, + "speculative_config": {"model": "ngram", "num_speculative_tokens": 1}, + }, + } + + config = test_params[model_scenario] + + llm_kwargs: dict[str, Any] = { + "runner": config["runner"], + "model": config["model"], + "gpu_memory_utilization": 0.1, + "speculative_config": config["speculative_config"], + } + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: - # Scenario: vLLM loads a pooling model and ignores a logitproc that is + # Scenario: vLLM loads a model and ignores a logitproc that is # available at a preconfigured entrypoint # Patch in dummy logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore # fork is required for entrypoint patch to be visible to workers, # although they should ignore the entrypoint patch anyway monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") - llm = LLM( - runner="pooling", - model=POOLING_MODEL_NAME, - gpu_memory_utilization=0.1, - ) + llm = LLM(**llm_kwargs) # Require that no logitsprocs have been loaded - assert sum([ - 1 for _ in llm.llm_engine.model_executor.driver_worker.worker. 
- model_runner.input_batch.logitsprocs.all - ]) == 0 + worker = llm.llm_engine.model_executor.driver_worker.worker + assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0 return - kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: # Scenario: load logitproc based on fully-qualified class name (FQCN) - kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: # Scenario: load logitproc from provided class object - kwargs["logits_processors"] = [DummyLogitsProcessor] + llm_kwargs["logits_processors"] = [DummyLogitsProcessor] - with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS): - # Require that loading a pooling model alongside the logitproc raises + with pytest.raises(ValueError, match=config["error_message"]): + # Require that loading a model alongside the logitproc raises # the appropriate exception. - LLM( - runner="pooling", - model=POOLING_MODEL_NAME, - gpu_memory_utilization=0.1, - **kwargs, - ) + LLM(**llm_kwargs) diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index a01a479e5b248..9c5b4ff0ba170 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -10,18 +10,18 @@ import openai import pytest import pytest_asyncio -from tests.utils import (RemoteOpenAIServerCustom, - create_new_process_for_each_test) -# yapf: disable -from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, - DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, - MAX_TOKENS, MODEL_NAME, - TEMP_GREEDY, dummy_module) +from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test +from tests.v1.logits_processors.utils import ( + DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, + MODEL_NAME, + TEMP_GREEDY, + dummy_module, + prompts, +) from tests.v1.logits_processors.utils import entry_points as fake_entry_points -from tests.v1.logits_processors.utils import prompts - -# yapf: enable def _server_with_logitproc_entrypoint( @@ -33,11 +33,12 @@ def _server_with_logitproc_entrypoint( # Patch `entry_points` to inject logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore from vllm.entrypoints.cli import main # fork is required for workers to see entrypoint patch - os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -55,10 +56,11 @@ def _server_with_logitproc_module( # Patch `modules` to inject dummy logitproc module from vllm.entrypoints.cli import main + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module # fork is required for workers to see entrypoint patch - os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -80,8 +82,9 @@ def default_server_args(): ] -@pytest.fixture(scope="function", - params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +@pytest.fixture( + scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]] +) def server(default_server_args, request, monkeypatch): """Consider two server configurations: (1) --logits-processors cli arg specifies dummy logits processor via fully- 
@@ -102,8 +105,7 @@ def server(default_server_args, request, monkeypatch): args = default_server_args _server_fxn = _server_with_logitproc_entrypoint - with RemoteOpenAIServerCustom(MODEL_NAME, args, - _server_fxn) as remote_server: + with RemoteOpenAIServerCustom(MODEL_NAME, args, _server_fxn) as remote_server: yield remote_server @@ -133,7 +135,7 @@ api_keyword_args = { ) async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): """Test custom logitsprocs when starting OpenAI server from CLI - + Launch vLLM OpenAI-compatible server, configured to load a custom logitproc that has a well-defined behavior (mask out all tokens except one `target_token`). @@ -157,9 +159,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): # For requests which activate the dummy logitproc, choose one of # two `target_token` values which are known not to be EOS tokens request_keyword_args["extra_body"] = { - "vllm_xargs": { - DUMMY_LOGITPROC_ARG: target_token - } + "vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token} } batch = await client.completions.create( model=model_name, @@ -173,8 +173,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): choices: openai.types.CompletionChoice = batch.choices toks = choices[0].logprobs.tokens if not all([x == toks[0] for x in toks]): - raise AssertionError( - f"Generated {toks} should all be {toks[0]}") + raise AssertionError(f"Generated {toks} should all be {toks[0]}") # Alternate whether to activate dummy logitproc for each request use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index c0bfc1a18feca..9a1d5505a5f99 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -3,15 +3,23 @@ import types from enum import Enum, auto -from typing import Optional +from typing import Any, Optional import torch from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, - LogitsProcessor, - MoveDirectionality) +from vllm.v1.sample.logits_processor import ( + LOGITSPROCS_GROUP, + AdapterLogitsProcessor, + BatchUpdate, + LogitsProcessor, + RequestLogitsProcessor, +) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates + +logger = init_logger(__name__) MODEL_NAME = "facebook/opt-125m" POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" @@ -25,6 +33,7 @@ DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" class CustomLogitprocSource(Enum): """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) @@ -43,54 +52,38 @@ prompts = [ class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): - self.req_info: dict[int, SamplingParams] = {} + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests. 
- for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and (target_token := - params.extra_args.get("target_token")): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + process_dict_updates( + self.req_info, + batch_update, + lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: return logits # Save target values before modification - rows_list = list(self.req_info.keys()) - cols = torch.tensor([self.req_info[i] for i in rows_list], - dtype=torch.long, - device=logits.device) - rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + cols = torch.tensor( + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device + ) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens - logits[rows] = float('-inf') + logits[rows] = float("-inf") logits[rows, cols] = values_to_keep return logits @@ -123,5 +116,63 @@ class EntryPoints(list): self.names = [ep.name for ep in eps] +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. 
+ + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + """Fake version of importlib.metadata.entry_points""" entry_points = lambda group: EntryPoints(group) diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py new file mode 100644 index 0000000000000..bf780b1f36adf --- /dev/null +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + +import pytest + +from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM +from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger + + +class DummyStatLogger: + """ + A dummy stat logger for testing purposes. + Implements the minimal interface expected by StatLoggerManager. + """ + + def __init__(self, vllm_config, engine_idx): + self.vllm_config = vllm_config + self.engine_idx = engine_idx + self.recorded = [] + self.logged = False + self.engine_initialized = False + + def record(self, scheduler_stats, iteration_stats, engine_idx): + self.recorded.append((scheduler_stats, iteration_stats, engine_idx)) + + def log(self): + self.logged = True + + def log_engine_initialized(self): + self.engine_initialized = True + + +@pytest.fixture +def log_stats_enabled_engine_args(): + """ + Shared fixture providing common AsyncEngineArgs configuration + used across multiple tests. + """ + return AsyncEngineArgs( + model="distilbert/distilgpt2", + dtype="half", + disable_log_stats=False, + enforce_eager=True, + ) + + +@pytest.mark.asyncio +async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args): + """ + RayPrometheusStatLogger should replace the default PrometheusStatLogger + """ + + engine = AsyncLLM.from_engine_args( + log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger] + ) + assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger) + engine.shutdown() + + +@pytest.mark.asyncio +async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): + """ + It's still possible to use custom stat loggers exclusively by passing + disable_log_stats=True in addition to a list of custom stat loggers. 
+ """ + # Create engine_args with disable_log_stats=True for this test + disabled_log_engine_args = copy.deepcopy(log_stats_enabled_engine_args) + disabled_log_engine_args.disable_log_stats = True + + # Disable default loggers; pass custom stat logger to the constructor + engine = AsyncLLM.from_engine_args( + disabled_log_engine_args, stat_loggers=[DummyStatLogger] + ) + + assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 + assert isinstance( + engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger + ) + + # log_stats is still True, since custom stat loggers are used + assert engine.log_stats + + engine.shutdown() diff --git a/tests/v1/test_metrics_reader.py b/tests/v1/metrics/test_metrics_reader.py similarity index 78% rename from tests/v1/test_metrics_reader.py rename to tests/v1/metrics/test_metrics_reader.py index c05de5e4cb645..1c90e6d335274 100644 --- a/tests/v1/test_metrics_reader.py +++ b/tests/v1/metrics/test_metrics_reader.py @@ -4,8 +4,15 @@ import prometheus_client import pytest -from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector, - get_metrics_snapshot) +from vllm.v1.metrics.reader import ( + Counter, + Gauge, + Histogram, + Vector, + get_metrics_snapshot, +) + +pytestmark = pytest.mark.cpu_test @pytest.fixture(autouse=True) @@ -18,10 +25,12 @@ def test_registry(monkeypatch): @pytest.mark.parametrize("num_engines", [1, 4]) def test_gauge_metric(test_registry, num_engines): - g = prometheus_client.Gauge("vllm:test_gauge", - "Test gauge metric", - labelnames=["model", "engine_index"], - registry=test_registry) + g = prometheus_client.Gauge( + "vllm:test_gauge", + "Test gauge metric", + labelnames=["model", "engine_index"], + registry=test_registry, + ) for i in range(num_engines): g.labels(model="foo", engine_index=str(i)).set(98.5) @@ -39,10 +48,12 @@ def test_gauge_metric(test_registry, num_engines): @pytest.mark.parametrize("num_engines", [1, 4]) def test_counter_metric(test_registry, num_engines): - c = prometheus_client.Counter("vllm:test_counter", - "Test counter metric", - labelnames=["model", "engine_index"], - registry=test_registry) + c = prometheus_client.Counter( + "vllm:test_counter", + "Test counter metric", + labelnames=["model", "engine_index"], + registry=test_registry, + ) for i in range(num_engines): c.labels(model="bar", engine_index=str(i)).inc(19) @@ -60,11 +71,13 @@ def test_counter_metric(test_registry, num_engines): @pytest.mark.parametrize("num_engines", [1, 4]) def test_histogram_metric(test_registry, num_engines): - h = prometheus_client.Histogram("vllm:test_histogram", - "Test histogram metric", - labelnames=["model", "engine_index"], - buckets=[10, 20, 30, 40, 50], - registry=test_registry) + h = prometheus_client.Histogram( + "vllm:test_histogram", + "Test histogram metric", + labelnames=["model", "engine_index"], + buckets=[10, 20, 30, 40, 50], + registry=test_registry, + ) for i in range(num_engines): hist = h.labels(model="blaa", engine_index=str(i)) hist.observe(42) @@ -95,7 +108,8 @@ def test_vector_metric(test_registry, num_engines): "vllm:spec_decode_num_accepted_tokens_per_pos", "Vector-like counter metric", labelnames=["position", "model", "engine_index"], - registry=test_registry) + registry=test_registry, + ) for i in range(num_engines): c.labels(position="0", model="llama", engine_index=str(i)).inc(10) c.labels(position="1", model="llama", engine_index=str(i)).inc(5) diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py index 92f6c6f0e89cd..f08d9f684921d 
100644 --- a/tests/v1/metrics/test_ray_metrics.py +++ b/tests/v1/metrics/test_ray_metrics.py @@ -1,23 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import pytest import ray -from vllm.config import ModelDType +from vllm.config.model import ModelDType from vllm.sampling_params import SamplingParams from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM -from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger - - -@pytest.fixture(scope="function", autouse=True) -def use_v1_only(monkeypatch): - """ - The change relies on V1 APIs, so set VLLM_USE_V1=1. - """ - monkeypatch.setenv('VLLM_USE_V1', '1') - +from vllm.v1.metrics.ray_wrappers import RayPrometheusMetric, RayPrometheusStatLogger MODELS = [ "distilbert/distilgpt2", @@ -33,24 +23,19 @@ def test_engine_log_metrics_ray( dtype: ModelDType, max_tokens: int, ) -> None: - """ Simple smoke test, verifying this can be used without exceptions. + """Simple smoke test, verifying this can be used without exceptions. Need to start a Ray cluster in order to verify outputs.""" @ray.remote(num_gpus=1) class EngineTestActor: - async def run(self): - # Set environment variable inside the Ray actor since environment - # variables from pytest fixtures don't propagate to Ray actors - os.environ['VLLM_USE_V1'] = '1' - - engine_args = AsyncEngineArgs(model=model, - dtype=dtype, - disable_log_stats=False, - enforce_eager=True) + engine_args = AsyncEngineArgs( + model=model, dtype=dtype, disable_log_stats=False, enforce_eager=True + ) engine = AsyncLLM.from_engine_args( - engine_args, stat_loggers=[RayPrometheusStatLogger]) + engine_args, stat_loggers=[RayPrometheusStatLogger] + ) for i, prompt in enumerate(example_prompts): results = engine.generate( @@ -65,3 +50,47 @@ def test_engine_log_metrics_ray( # Create the actor and call the async method actor = EngineTestActor.remote() # type: ignore[attr-defined] ray.get(actor.run.remote()) + + +def test_sanitized_opentelemetry_name(): + """Test the metric name sanitization logic for Ray.""" + + # Only a-z, A-Z, 0-9, _, test valid characters are preserved + valid_name = "valid_metric_123_abcDEF" + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(valid_name) == valid_name + ) + + # Test dash, dot, are replaced + name_with_dash_dot = "metric-name.test" + expected = "metric_name_test" + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_dash_dot) + == expected + ) + + # Test colon is replaced with underscore + name_with_colon = "metric:name" + expected = "metric_name" + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_colon) + == expected + ) + + # Test multiple invalid characters are replaced + name_with_invalid = "metric:name@with#special%chars" + expected = "metric_name_with_special_chars" + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_invalid) + == expected + ) + + # Test mixed valid and invalid characters + complex_name = "vllm:engine_stats/time.latency_ms-99p" + expected = "vllm_engine_stats_time_latency_ms_99p" + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(complex_name) == expected + ) + + # Test empty string + assert RayPrometheusMetric._get_sanitized_opentelemetry_name("") == "" diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py new file mode 100644 index 0000000000000..67a2d1739b6bb --- /dev/null +++ b/tests/v1/metrics/test_stats.py @@ -0,0 +1,23 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.v1.metrics.stats import IterationStats + + +def test_iteration_stats_repr(): + iteration_stats = IterationStats() + iteration_stats.iteration_timestamp = 0 + expected_repr = ( + "IterationStats(" + "iteration_timestamp=0, " + "num_generation_tokens=0, " + "num_prompt_tokens=0, " + "num_preempted_reqs=0, " + "finished_requests=[], " + "max_num_generation_tokens_iter=[], " + "n_params_iter=[], " + "time_to_first_tokens_iter=[], " + "inter_token_latencies_iter=[], " + "waiting_lora_adapters={}, " + "running_lora_adapters={})" + ) + assert repr(iteration_stats) == expected_repr diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index e835c029634ce..86b75deadda7d 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -3,16 +3,20 @@ import itertools from collections.abc import Generator +from typing import get_args import pytest import torch from tests.v1.sample.utils import ( - BatchLogprobsComposition, BatchLogprobsSpecType, + BatchLogprobsComposition, + BatchLogprobsSpecType, assert_incr_detok_str_matches_non_incr_detok_str, - compute_correct_cumulative_logprob, get_test_batch) + compute_correct_cumulative_logprob, + get_test_batch, +) from vllm import SamplingParams -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode from ...conftest import HfRunner, VllmRunner @@ -28,22 +32,23 @@ SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT @pytest.fixture( scope="module", # Parameterize APC - params=[False, True]) + params=[False, True], +) def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: with vllm_runner( - MODEL, - dtype=DTYPE, - max_logprobs=7, - # Very small number of batched tokens to ensure - # that we test chunking. - max_num_batched_tokens=16, - max_num_seqs=16, - max_model_len=128, - enforce_eager=True, - #TODO: enable this once we support it for - # prompt logprobs. - enable_prefix_caching=request.param, - gpu_memory_utilization=0.4, # up to 2 alive concurrently + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + # TODO: enable this once we support it for + # prompt logprobs. 
+ enable_prefix_caching=request.param, + gpu_memory_utilization=0.4, # up to 2 alive concurrently ) as vllm_model: yield vllm_model @@ -95,8 +100,8 @@ def _repeat_logprob_config( num_test_prompts = len(test_prompts) # Make sure there is a logprobs configuration for each test prompt logprob_prompt_logprob_list = list( - itertools.islice(itertools.cycle(logprob_prompt_logprob_list), - num_test_prompts)) + itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts) + ) # Now the number of prompts should match the number of sample params combos assert num_test_prompts == len(logprob_prompt_logprob_list) return logprob_prompt_logprob_list @@ -114,24 +119,28 @@ def _run_and_validate( do_apc: bool, ) -> None: vllm_results = vllm_model.llm.generate( - test_prompts, sampling_params=vllm_sampling_params) + test_prompts, sampling_params=vllm_sampling_params + ) for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( - vllm_results, hf_logprobs, hf_outputs, - logprob_prompt_logprob_list): - + vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list + ): # Extract request-level (prompt)logprobs config num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob # Test whether sampled token output is consistent between vLLM and HF # vLLM prompt+completion should match HF output if temperature == 0.0: - assert (vllm_result.prompt_token_ids + - vllm_result.outputs[0].token_ids == hf_output[0]) + assert ( + vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids + == hf_output[0] + ) else: # Sampled tokens won't match if not greedy - assert (vllm_result.prompt_token_ids == hf_output[0] - [:len(vllm_result.prompt_token_ids)]) + assert ( + vllm_result.prompt_token_ids + == hf_output[0][: len(vllm_result.prompt_token_ids)] + ) # Validate sample logprobs if num_top_logprobs is not None: @@ -140,8 +149,9 @@ def _run_and_validate( # correct assert vllm_result.outputs[0].logprobs is not None assert len(vllm_result.outputs[0].logprobs) == max_tokens - for logprobs, token_id in zip(vllm_result.outputs[0].logprobs, - vllm_result.outputs[0].token_ids): + for logprobs, token_id in zip( + vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids + ): assert logprobs is not None # Confirm that the output token appears among the logprobs @@ -158,23 +168,26 @@ def _run_and_validate( if num_top_logprobs > 0: # We should have an entry for each of the topk ranks all_ranks = {lp.rank for lp in logprobs.values()} - assert all(r in all_ranks - for r in range(1, num_top_logprobs + 1)) + assert all(r in all_ranks for r in range(1, num_top_logprobs + 1)) output_text = vllm_result.outputs[0].text output_string_from_most_likely_tokens_lst: list[str] = [] for top_logprobs in vllm_result.outputs[0].logprobs: top_logprob = next(iter(top_logprobs.values())) output_string_from_most_likely_tokens_lst.append( - top_logprob.decoded_token) + top_logprob.decoded_token + ) output_string_from_most_likely_tokens = "".join( - output_string_from_most_likely_tokens_lst) + output_string_from_most_likely_tokens_lst + ) assert_incr_detok_str_matches_non_incr_detok_str( - output_text, output_string_from_most_likely_tokens, + output_text, + output_string_from_most_likely_tokens, "The output text from the top logprob for each token " "position should be the same as the output text in the " - "result.") + "result.", + ) # Compare vLLM sample logprobs to HF vllm_sample_logprobs = vllm_result.outputs[0].logprobs @@ -186,11 +199,12 @@ def _run_and_validate( logprob, 
hf_logprob[i][-1][token_id].item(), atol=1e-2, - rtol=1e-2) - assert isinstance( - sample_logprob.decoded_token, - str), ("The token should be decoded by the time it is" - " returned to the user.") + rtol=1e-2, + ) + assert isinstance(sample_logprob.decoded_token, str), ( + "The token should be decoded by the time it is" + " returned to the user." + ) # At this point we know the sample logprobs are correct for this # request. Validate that cumulative_logprob is actually the sum. @@ -200,7 +214,8 @@ def _run_and_validate( vllm_result.outputs[0].cumulative_logprob, compute_correct_cumulative_logprob(vllm_result.outputs[0]), atol=1e-6, - rtol=1e-6) + rtol=1e-6, + ) else: # Logprobs disabled for this request; should be None assert vllm_result.outputs[0].logprobs is None @@ -213,17 +228,17 @@ def _run_and_validate( assert vllm_result.prompt_logprobs[0] is None # - Prompt logprobs are returned for all indices in # the prompt - assert len(vllm_result.prompt_logprobs) == len( - vllm_result.prompt_token_ids) + assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids) for prompt_logprobs, prompt_token_id in zip( - vllm_result.prompt_logprobs[1:], - vllm_result.prompt_token_ids[1:]): + vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:] + ): assert prompt_logprobs is not None # Confirm that the prompt token appears among the logprobs assert prompt_token_id in prompt_logprobs - token_in_topk = prompt_logprobs[ - prompt_token_id].rank <= num_top_prompt_logprobs + token_in_topk = ( + prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs + ) # If the prompt token is not included in the top K # logprob, it can return 1 more data @@ -235,8 +250,9 @@ def _run_and_validate( if num_top_prompt_logprobs > 0: # We should have an entry for each of the topk ranks all_ranks = {lp.rank for lp in prompt_logprobs.values()} - assert all(r in all_ranks - for r in range(1, num_top_prompt_logprobs + 1)) + assert all( + r in all_ranks for r in range(1, num_top_prompt_logprobs + 1) + ) # Compare prompt logprobs to HF # The first prompt logprob is always None, so we compare it from @@ -248,19 +264,23 @@ def _run_and_validate( logprob.logprob, hf_logprob[0][i][token_id].item(), atol=2e-2, - rtol=2e-2) + rtol=2e-2, + ) else: assert vllm_result.prompt_logprobs is None -@pytest.mark.parametrize("batch_logprobs_composition", - [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]) +@pytest.mark.parametrize( + "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT] +) @pytest.mark.parametrize("temperature", [0.0, 2.0]) def test_get_logprobs_and_prompt_logprobs( - hf_model, vllm_model, - batch_logprobs_composition: BatchLogprobsComposition, - temperature: float, example_prompts: list[str], - monkeypatch: pytest.MonkeyPatch) -> None: + hf_model, + vllm_model, + batch_logprobs_composition: BatchLogprobsComposition, + temperature: float, + example_prompts: list[str], +) -> None: """Test V1 Engine logprobs & prompt logprobs Exercise a variety of combinations of `logprobs` and `prompt_logprobs` @@ -287,212 +307,204 @@ def test_get_logprobs_and_prompt_logprobs( temperature: "temperature" sampling parameter example_prompts: example prompt fixture """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching - if do_apc and (temperature < 2.0 - or batch_logprobs_composition != SAMPLE_PROMPT): - # Skip some test-cases to save time. 
- pytest.skip() - test_prompts = example_prompts + do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching + if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT): + # Skip some test-cases to save time. + pytest.skip() + test_prompts = example_prompts - max_tokens = 5 - hf_outputs = hf_model.generate_greedy( - test_prompts, + max_tokens = 5 + hf_outputs = hf_model.generate_greedy( + test_prompts, + max_tokens=max_tokens, + ) + hf_logprobs = hf_model.generate_greedy_logprobs( + test_prompts, + max_tokens=max_tokens, + ) + + # Batch has mixed sample params + # (different logprobs/prompt logprobs combos) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) + + # Ensure that each test prompt has a logprob config for testing + logprob_prompt_logprob_list = _repeat_logprob_config( + test_prompts, logprob_prompt_logprob_list + ) + # Generate SamplingParams + vllm_sampling_params = [ + SamplingParams( max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984, ) - hf_logprobs = hf_model.generate_greedy_logprobs( - test_prompts, + for num_lp, num_plp in logprob_prompt_logprob_list + ] + for _ in range(2 if do_apc else 1): + _run_and_validate( + vllm_model=vllm_model, + test_prompts=test_prompts, + vllm_sampling_params=vllm_sampling_params, + hf_logprobs=hf_logprobs, + hf_outputs=hf_outputs, + logprob_prompt_logprob_list=logprob_prompt_logprob_list, + temperature=temperature, max_tokens=max_tokens, + do_apc=do_apc, ) - # Batch has mixed sample params - # (different logprobs/prompt logprobs combos) - logprob_prompt_logprob_list = get_test_batch( - batch_logprobs_composition) - # Ensure that each test prompt has a logprob config for testing - logprob_prompt_logprob_list = _repeat_logprob_config( - test_prompts, logprob_prompt_logprob_list) - # Generate SamplingParams - vllm_sampling_params = [ - SamplingParams(max_tokens=max_tokens, - logprobs=num_lp, - prompt_logprobs=num_plp, - temperature=temperature, - seed=1984) - for num_lp, num_plp in logprob_prompt_logprob_list - ] - for _ in range(2 if do_apc else 1): - _run_and_validate( - vllm_model=vllm_model, - test_prompts=test_prompts, - vllm_sampling_params=vllm_sampling_params, - hf_logprobs=hf_logprobs, - hf_outputs=hf_outputs, - logprob_prompt_logprob_list=logprob_prompt_logprob_list, - temperature=temperature, - max_tokens=max_tokens, - do_apc=do_apc) - - -def test_max_logprobs(monkeypatch: pytest.MonkeyPatch): +def test_max_logprobs(): """vLLM v1 engine should fail a request with `logprobs > max_logprobs` Should also fail for `prompt_logprobs > max_logprobs` APC should not matter as this test checks basic request validation. 
""" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + runner = VllmRunner( + "facebook/opt-125m", + max_logprobs=1, + enable_prefix_caching=False, + # 2 other llms alive during whole session + gpu_memory_utilization=0.15, + max_model_len=256, + ) + vllm_sampling_params = SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - runner = VllmRunner( - "facebook/opt-125m", - max_logprobs=1, - enable_prefix_caching=False, - # 2 other llms alive during whole session - gpu_memory_utilization=0.15, - max_model_len=256) - vllm_sampling_params = SamplingParams(logprobs=1) - # should pass - runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - - bad_sampling_params = SamplingParams(logprobs=2) - with pytest.raises(ValueError): - runner.generate(["Hello world"], - sampling_params=bad_sampling_params) + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) -def test_none_logprobs(vllm_model, example_prompts, - monkeypatch: pytest.MonkeyPatch): +def test_none_logprobs(vllm_model, example_prompts): """Engine should return `logprobs` and `prompt_logprobs` as `None` Args: vllm_model: vLLM model fixture example_prompts: list of example prompts (test fixture) """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - max_tokens = 5 + max_tokens = 5 - sampling_params_logprobs_none = SamplingParams( - max_tokens=max_tokens, - logprobs=None, - prompt_logprobs=None, - temperature=0.0, - ) - results_logprobs_none = vllm_model.llm.generate( - example_prompts, - sampling_params=sampling_params_logprobs_none, - ) + sampling_params_logprobs_none = SamplingParams( + max_tokens=max_tokens, + logprobs=None, + prompt_logprobs=None, + temperature=0.0, + ) + results_logprobs_none = vllm_model.llm.generate( + example_prompts, + sampling_params=sampling_params_logprobs_none, + ) - for i in range(len(results_logprobs_none)): - # Check sample logprobs are None - assert results_logprobs_none[i].outputs[0].logprobs is None - assert results_logprobs_none[i].outputs[ - 0].cumulative_logprob is None - # Check prompt logprobs are None - assert results_logprobs_none[i].prompt_logprobs is None + for i in range(len(results_logprobs_none)): + # Check sample logprobs are None + assert results_logprobs_none[i].outputs[0].logprobs is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None + # Check prompt logprobs are None + assert results_logprobs_none[i].prompt_logprobs is None -def test_zero_logprobs(vllm_model, example_prompts, - monkeypatch: pytest.MonkeyPatch): +def test_zero_logprobs(vllm_model, example_prompts): """Engine should return sampled token and prompt token logprobs Args: vllm_model: vLLM model fixture example_prompts: list of example prompts (test fixture) """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - max_tokens = 5 + max_tokens = 5 - sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, - logprobs=0, - prompt_logprobs=0, - temperature=0.0) - results_logprobs_zero = vllm_model.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_zero) + sampling_params_logprobs_zero = SamplingParams( + max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0 + ) + results_logprobs_zero = vllm_model.llm.generate( + example_prompts, sampling_params=sampling_params_logprobs_zero + ) - for i in range(len(results_logprobs_zero)): - # Check 
that there is one sample logprob dict for each - # sample token - logprobs = results_logprobs_zero[i].outputs[0].logprobs - prompt_logprobs = results_logprobs_zero[i].prompt_logprobs - sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids - prompt_token_ids = results_logprobs_zero[i].prompt_token_ids - assert logprobs is not None - assert len(sampled_token_ids) == len(logprobs) - assert results_logprobs_zero[i].outputs[ - 0].cumulative_logprob is not None - # Check that there is one prompt logprob dict for each - # prompt token - assert prompt_logprobs is not None - assert len(prompt_token_ids) == len(prompt_logprobs) + for i in range(len(results_logprobs_zero)): + # Check that there is one sample logprob dict for each + # sample token + logprobs = results_logprobs_zero[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_zero[i].prompt_logprobs + sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids + prompt_token_ids = results_logprobs_zero[i].prompt_token_ids + assert logprobs is not None + assert len(sampled_token_ids) == len(logprobs) + assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None + # Check that there is one prompt logprob dict for each + # prompt token + assert prompt_logprobs is not None + assert len(prompt_token_ids) == len(prompt_logprobs) -def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): - """Engine should return all vocabulary logprobs +def test_all_logprobs(example_prompts): + """Engine should return all vocabulary logprobs and prompt logprobs Args: example_prompts: list of example prompts (test fixture) """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - runner = VllmRunner( - "facebook/opt-125m", - max_logprobs=-1, - enable_prefix_caching=False, - # 2 other llms alive during whole session - gpu_memory_utilization=0.15, - max_model_len=256) - sampling_params_logprobs_all = SamplingParams(max_tokens=5, - logprobs=-1) - results_logprobs_all = runner.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_all) - vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() - for i in range(len(results_logprobs_all)): - logprobs = results_logprobs_all[i].outputs[0].logprobs - assert logprobs is not None - for logprob in logprobs: - assert len(logprob) == vocab_size + runner = VllmRunner( + "facebook/opt-125m", + max_logprobs=-1, + enable_prefix_caching=False, + # 2 other llms alive during whole session + gpu_memory_utilization=0.15, + max_model_len=256, + ) + + sampling_params_logprobs_all = SamplingParams( + max_tokens=5, logprobs=-1, prompt_logprobs=-1 + ) + results_logprobs_all = runner.llm.generate( + example_prompts, sampling_params=sampling_params_logprobs_all + ) + vocab_size = runner.llm.llm_engine.model_config.get_vocab_size() + + for i in range(len(results_logprobs_all)): + logprobs = results_logprobs_all[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_all[i].prompt_logprobs + assert logprobs is not None + for logprob in logprobs: + assert len(logprob) == vocab_size + assert prompt_logprobs is not None + assert prompt_logprobs[0] is None + for prompt_logprob in prompt_logprobs[1:]: + assert len(prompt_logprob) == vocab_size -@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) -def test_logprobs_mode(logprobs_mode: LogprobsMode, - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode)) +def test_logprobs_mode(logprobs_mode: LogprobsMode): """Test with LLM engine with different 
logprobs_mode. For logprobs, we should have non-positive values. For logits, we should expect at least one positive values. """ from vllm import LLM - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - llm = LLM( - "facebook/opt-125m", - max_logprobs=5, - enable_prefix_caching=False, - # 2 other llms alive during whole session - gpu_memory_utilization=0.05, - max_model_len=16, - logprobs_mode=logprobs_mode) - vllm_sampling_params = SamplingParams(logprobs=1) - results = llm.generate(["Hello world"], - sampling_params=vllm_sampling_params) + llm = LLM( + "facebook/opt-125m", + max_logprobs=5, + enable_prefix_caching=False, + # 2 other llms alive during whole session + gpu_memory_utilization=0.05, + max_model_len=16, + logprobs_mode=logprobs_mode, + ) + vllm_sampling_params = SamplingParams(logprobs=1) + results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params) - total_token_with_logprobs = 0 - positive_values = 0 - for output in results[0].outputs: - for logprobs in output.logprobs: - for token_id in logprobs: - logprob = logprobs[token_id] - if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, - LogprobsMode.PROCESSED_LOGPROBS): - assert logprob.logprob <= 0 - if logprob.logprob > 0: - positive_values = positive_values + 1 - total_token_with_logprobs = total_token_with_logprobs + 1 - assert total_token_with_logprobs >= len(results[0].outputs) - if logprobs_mode in (LogprobsMode.RAW_LOGITS, - LogprobsMode.PROCESSED_LOGITS): - assert positive_values > 0 - del llm + total_token_with_logprobs = 0 + positive_values = 0 + for output in results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + logprob = logprobs[token_id] + if logprobs_mode in ("raw_logprobs", "processed_logprobs"): + assert logprob.logprob <= 0 + if logprob.logprob > 0: + positive_values = positive_values + 1 + total_token_with_logprobs = total_token_with_logprobs + 1 + assert total_token_with_logprobs >= len(results[0].outputs) + if logprobs_mode in ("raw_logits", "processed_logits"): + assert positive_values > 0 + del llm diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py index 7f41355ff7ce4..b3233e50fbf18 100644 --- a/tests/v1/sample/test_logprobs_e2e.py +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -15,22 +15,23 @@ EXPECTED_VALUE = 0.62 MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501 SERVER_ARGS = [ - "--enforce_eager", "--no_enable_prefix_caching", - "--gpu-memory-utilization=0.8" + "--enforce_eager", + "--no_enable_prefix_caching", + "--gpu-memory-utilization=0.8", ] NUM_CONCURRENT = 100 def test_prompt_logprobs_e2e(): - results = lm_eval.simple_evaluate(model="vllm", - model_args=MODEL_ARGS, - tasks=TASK, - batch_size="auto") + results = lm_eval.simple_evaluate( + model="vllm", model_args=MODEL_ARGS, tasks=TASK, batch_size="auto" + ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" def test_prompt_logprobs_e2e_server(): @@ -40,7 +41,8 @@ def test_prompt_logprobs_e2e_server(): model_args = ( f"model={MODEL}," f"base_url={url}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + 
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -49,6 +51,7 @@ def test_prompt_logprobs_e2e_server(): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 4e912f98f376f..8df10f8c3afa5 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -6,11 +6,11 @@ import pytest import torch import torch.nn.functional as F +from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, - RejectionSampler) +from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = current_platform.device_type @@ -21,10 +21,13 @@ def rejection_sampler(): return RejectionSampler() -def create_logits_tensor(output_token_ids: list[list[int]], - vocab_size: int = 100) -> torch.Tensor: +def create_logits_tensor( + output_token_ids: list[list[int]], + vocab_size: int = 100, + token_idx_to_override: Optional[int] = None, +) -> torch.Tensor: """Helper function to create logits tensor that - will produce desired token ids on argmax""" + will produce desired token ids on argmax""" token_ids = [tokens[:-1] for tokens in output_token_ids] num_total_tokens = sum(len(tokens) for tokens in token_ids) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) @@ -33,19 +36,29 @@ def create_logits_tensor(output_token_ids: list[list[int]], for j, token_id in enumerate(tokens): logits[start_loc + j, token_id] = 100.0 start_loc += len(tokens) + if token_idx_to_override: + logits[:, token_idx_to_override] = 99.0 return logits def create_sampling_metadata( all_greedy: bool, + output_token_ids: Optional[list[list[int]]] = None, + prompt_token_ids: Optional[torch.Tensor] = None, + spec_token_ids: Optional[torch.Tensor] = None, temperature: Optional[torch.Tensor] = None, top_k: Optional[torch.Tensor] = None, top_p: Optional[torch.Tensor] = None, generators: Optional[dict[int, Any]] = None, + frequency_penalties: Optional[list[float]] = None, + presence_penalties: Optional[list[float]] = None, + repetition_penalties: Optional[list[float]] = None, + bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None, + allowed_token_ids_mask: Optional[torch.Tensor] = None, ) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set - to the given value. Either all greedy or all random sampling - is used. + to the given value. Either all greedy or all random sampling + is used. 
""" generators = generators or {} if all_greedy: @@ -53,6 +66,21 @@ def create_sampling_metadata( else: assert temperature is not None + if any([frequency_penalties, presence_penalties, repetition_penalties]): + no_penalties = False + + assert output_token_ids + assert len(output_token_ids) > 0 + + frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE) + presence_penalties = torch.tensor(presence_penalties, device=DEVICE) + repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE) + else: + no_penalties = True + frequency_penalties = torch.tensor([]) + presence_penalties = torch.tensor([]) + repetition_penalties = torch.tensor([]) + return SamplingMetadata( temperature=temperature, all_greedy=all_greedy, @@ -61,14 +89,15 @@ def create_sampling_metadata( top_k=top_k, generators=generators, max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - allowed_token_ids_mask=None, - bad_words_token_ids={}, + no_penalties=no_penalties, + prompt_token_ids=prompt_token_ids, + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, + output_token_ids=[] if output_token_ids is None else output_token_ids, + spec_token_ids=[] if spec_token_ids is None else spec_token_ids, + allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids, logitsprocs=LogitsProcessors(), ) @@ -81,10 +110,10 @@ def test_perfect_match(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -93,9 +122,7 @@ def test_perfect_match(rejection_sampler): bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected = torch.tensor([[1, 2, 3, 4]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) @@ -106,10 +133,10 @@ def test_early_mismatch(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -129,15 +156,16 @@ def test_early_mismatch(rejection_sampler): def test_multiple_sequences(rejection_sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3]] - output_tokens = [[1, 2, 5], [3, - 4]] # Two sequences with bonus tokens 5 and 4 + output_tokens = [[1, 2, 5], [3, 4]] # Two sequences with bonus tokens 5 and 4 metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = 
torch.tensor( - [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -146,9 +174,9 @@ def test_multiple_sequences(rejection_sampler): bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor( + [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device + ) assert torch.equal(output, expected) @@ -159,10 +187,10 @@ def test_single_token_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -182,10 +210,10 @@ def test_empty_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -201,15 +229,16 @@ def test_empty_sequence(rejection_sampler): def test_multiple_mismatches(rejection_sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] - output_tokens = [[1, 2, 7, 6], [4, 8, 6, - 9]] # Mismatches in both sequences + output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]] # Mismatches in both sequences metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( - [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -219,8 +248,10 @@ def test_multiple_mismatches(rejection_sampler): sampling_metadata=metadata, ) expected = torch.tensor( - [[1, 2, 7, PLACEHOLDER_TOKEN_ID], - [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + [ + [1, 2, 7, PLACEHOLDER_TOKEN_ID], + [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID], + ], dtype=torch.int, device=logits.device, ) @@ -232,18 +263,23 @@ def test_multiple_mismatches(rejection_sampler): [ ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch - ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], - [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches - ]) -def test_parametrized_cases(rejection_sampler, spec_tokens, 
output_tokens, - expected): + ( + [[1, 2], [3, 4]], + [[1, 5, 6], [3, 4, 7]], + [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]], + ), # Mixed matches + ], +) +def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected): """Parametrized test for various matching scenarios""" metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor( + [tokens[-1] for tokens in output_tokens], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -252,9 +288,7 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected_tensor = torch.tensor(expected, - dtype=torch.int, - device=logits.device) + expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) assert torch.equal(output, expected_tensor) @@ -273,22 +307,15 @@ def test_deterministic_when_seeded( n_rep: int, ): num_tokens = batch_size * k - draft_probs = torch.rand(num_tokens, - vocab_size, - dtype=torch.float32, - device=DEVICE) + draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE) draft_probs = F.softmax(draft_probs, dim=-1) target_logits = torch.rand_like(draft_probs) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64, - device=DEVICE) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device=DEVICE) + bonus_token_ids = torch.randint( + low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE + ) + draft_token_ids = torch.randint( + low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE + ) seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded @@ -296,17 +323,17 @@ def test_deterministic_when_seeded( for _ in range(n_rep): seeded_seqs = { i: torch.Generator(device=DEVICE).manual_seed(i) - for i in range(batch_size) if seeded_mask[i] + for i in range(batch_size) + if seeded_mask[i] } - temperature = torch.ones(batch_size, - dtype=torch.float32, - device=DEVICE) - sampling_metadata = create_sampling_metadata(all_greedy=False, - temperature=temperature, - generators=seeded_seqs) + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata( + all_greedy=False, temperature=temperature, generators=seeded_seqs + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=DEVICE) + draft_token_ids.tolist(), device=DEVICE + ) rep_result = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, @@ -352,8 +379,7 @@ def test_rejection_sampling_approximates_target_distribution(): num_reference_probs = 100 # Prepare draft, target, and reference probability distributions - draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), - dim=-1) + draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1) target_logits = torch.rand(vocab_size, dtype=torch.float32) target_probs = F.softmax(target_logits, dim=-1) reference_probs = F.softmax( @@ -368,38 +394,48 @@ def test_rejection_sampling_approximates_target_distribution(): 
for num_samples in sample_sizes: # Sample using rejection sampling. rej_sample_probs = estimate_rejection_sampling_pdf( - draft_probs, target_logits, k, vocab_size, num_samples) + draft_probs, target_logits, k, vocab_size, num_samples + ) rej_sample_probs = rej_sample_probs.to(DEVICE) # Average distance from reference probs. - reference_vs_rejsample_dist = torch.dist( - reference_probs, - rej_sample_probs).item() / reference_probs.shape[0] - target_vs_rejsample_dist = torch.dist(target_probs, - rej_sample_probs).item() + reference_vs_rejsample_dist = ( + torch.dist(reference_probs, rej_sample_probs).item() + / reference_probs.shape[0] + ) + target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item() distance_wrt_reference.append(reference_vs_rejsample_dist) distance_wrt_target.append(target_vs_rejsample_dist) relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) + distance_wrt_target + ) relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) + distance_wrt_reference + ) - print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " - f"{reference_vs_rejsample_dist=:.05f}") - print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " - f"{relative_change_in_distance_wrt_reference=:.02f}") + print( + f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}" + ) + print( + f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}" + ) relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) + distance_wrt_target + ) relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) + distance_wrt_reference + ) expected_improvement_multiplier = 20 - assert (relative_change_in_distance_wrt_target - > relative_change_in_distance_wrt_reference * - expected_improvement_multiplier) + assert ( + relative_change_in_distance_wrt_target + > relative_change_in_distance_wrt_reference * expected_improvement_multiplier + ) def get_ratio_first_to_last(elements: list[float]) -> float: @@ -427,28 +463,29 @@ def estimate_rejection_sampling_pdf( rejection_sampler = RejectionSampler() num_tokens = num_samples * k # Repeat draft probs num_samples * k times. - draft_probs = draft_probs.reshape(1, 1, - vocab_size).repeat(num_samples, k, 1) + draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) # Repeat target probs num_tokens times. target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1) # Randomly sample draft token ids from draft probs. - draft_token_ids = torch.multinomial(draft_probs[:, 0, :], - num_samples=k, - replacement=True).reshape( - num_samples, k) + draft_token_ids = torch.multinomial( + draft_probs[:, 0, :], num_samples=k, replacement=True + ).reshape(num_samples, k) draft_probs = draft_probs.view(num_tokens, vocab_size) # Bonus tokens not used but required. 
- bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, - device=DEVICE).repeat(num_samples, 1) + bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat( + num_samples, 1 + ) temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) - sampling_metadata = create_sampling_metadata(all_greedy=False, - temperature=temperature) + sampling_metadata = create_sampling_metadata( + all_greedy=False, temperature=temperature + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=bonus_token_ids.device) + draft_token_ids.tolist(), device=bonus_token_ids.device + ) output_token_ids = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, @@ -458,11 +495,12 @@ def estimate_rejection_sampling_pdf( ) output_token_ids = output_token_ids[:, :-1].flatten() - hist = torch.histogram(output_token_ids.to(dtype=torch.float, - device="cpu"), - bins=vocab_size, - range=(0, vocab_size), - density=True) + hist = torch.histogram( + output_token_ids.to(dtype=torch.float, device="cpu"), + bins=vocab_size, + range=(0, vocab_size), + density=True, + ) return hist.hist @@ -480,9 +518,9 @@ def _test_masked_logits( num_tokens = batch_size * num_draft_tokens # Create random draft probabilities. - draft_probs = torch.rand((num_tokens, vocab_size), - dtype=torch.float32, - device=DEVICE) + draft_probs = torch.rand( + (num_tokens, vocab_size), dtype=torch.float32, device=DEVICE + ) draft_probs = F.softmax(draft_probs, dim=-1) # Randomly sample draft token ids from draft probs @@ -491,9 +529,7 @@ def _test_masked_logits( draft_token_ids = draft_token_ids.tolist() # Bonus tokens not used but required - bonus_token_ids = torch.zeros((batch_size, 1), - dtype=torch.int64, - device=DEVICE) + bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) # Create spec decode metadata spec_decode_metadata = SpecDecodeMetadata.make_dummy( @@ -531,8 +567,7 @@ def test_top_k(rejection_sampler, top_k): # Randomly create top-k indices. 
top_k_indices = [ - torch.randperm(vocab_size, device=DEVICE)[:top_k] - for _ in range(num_tokens) + torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens) ] top_k_indices = torch.stack(top_k_indices) @@ -550,9 +585,7 @@ def test_top_k(rejection_sampler, top_k): sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, - top_k=torch.tensor([top_k] * batch_size, - device=DEVICE, - dtype=torch.int64), + top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64), ) _test_masked_logits( @@ -595,9 +628,7 @@ def test_top_p(rejection_sampler, top_p): sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, - top_p=torch.tensor([top_p] * batch_size, - device=DEVICE, - dtype=torch.float32), + top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32), ) _test_masked_logits( @@ -609,3 +640,136 @@ def test_top_p(rejection_sampler, top_p): unmasked_indices=top_p_indices, sampling_metadata=sampling_metadata, ) + + +########################### Tests for Logit Processors ################### +def test_frequency_penalties(rejection_sampler): + """Test rejection sampling with frequency penalties""" + spec_tokens = [[1, 1, 1], [], [1, 1, 1]] + output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]] # 1, 7 and 1 are the bonus tokens + + num_requests = len(spec_tokens) + logits = create_logits_tensor(output_tokens, token_idx_to_override=15) + metadata = create_sampling_metadata( + all_greedy=True, + output_token_ids=[[2], [3], [4]], + spec_token_ids=spec_tokens, + prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE), + frequency_penalties=[1.5, 1.5, 0.7], + presence_penalties=[0.0] * num_requests, + repetition_penalties=[1.0] * num_requests, + ) + bonus_token_tensor = torch.tensor( + [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor( + [[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]], + dtype=torch.int, + device=logits.device, + ) + assert torch.equal(output, expected) + + +def test_bad_words(rejection_sampler): + """Test rejection sampling with bad words constraints""" + spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]] + output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] + + logits = create_logits_tensor(output_tokens, token_idx_to_override=15) + metadata = create_sampling_metadata( + all_greedy=True, + output_token_ids=[[2], [3], [4]], + spec_token_ids=spec_tokens, + bad_words_token_ids={ + 0: [ + [ + 2, + ] + ], + 1: [ + [ + 2, + ] + ], + # Do not apply bad words to the last request + }, + ) + bonus_token_tensor = torch.tensor( + [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + + expected = torch.tensor( + [[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]], + dtype=torch.int, + device=logits.device, + ) + assert torch.equal(output, expected) + + +def test_allowed_token_ids(rejection_sampler): + """Test rejection sampling with allowed token ids""" +
spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]] + output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]] + # Not allowed tokens: + # 0: 0-4 + # 1: 1-5 + # 2: 2-6 + num_allowed_token_ids = 5 + + # Use the token 15 as the sampler choose if a token rejected + logits = create_logits_tensor(output_tokens, token_idx_to_override=15) + + batch_size = len(output_tokens) + _, vocab_size = logits.size() + mask = create_allowed_token_ids( + batch_size=batch_size, + vocab_size=vocab_size, + num_allowed_token_ids=num_allowed_token_ids, + device=logits.device, + ) + metadata = create_sampling_metadata( + all_greedy=True, + output_token_ids=[[], [], []], + spec_token_ids=spec_tokens, + allowed_token_ids_mask=mask, + ) + bonus_token_tensor = torch.tensor( + [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + + expected = torch.tensor( + [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]], + dtype=torch.int, + device=logits.device, + ) + assert torch.equal(output, expected) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 53215f88bb27e..edc6acae848aa 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - import numpy as np import pytest import torch +from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.sample.logits_processor import LogitsProcessors @@ -29,12 +28,12 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits -def _create_penalty_tensor(batch_size: int, penalty_value: float, - device: torch.device) -> torch.Tensor: - return torch.full((batch_size, ), - fill_value=penalty_value, - dtype=torch.float, - device=device) +def _create_penalty_tensor( + batch_size: int, penalty_value: float, device: torch.device +) -> torch.Tensor: + return torch.full( + (batch_size,), fill_value=penalty_value, dtype=torch.float, device=device + ) def _create_prompt_tokens_tensor( @@ -51,36 +50,18 @@ def _create_prompt_tokens_tensor( ) -def _create_allowed_token_ids( +def _create_bad_words_token_ids( batch_size: int, vocab_size: int, - num_allowed_token_ids: int, - device: torch.device, -) -> Optional[torch.Tensor]: - mask: Optional[torch.Tensor] = None - for i in range(batch_size): - if i % 2 == 1: - continue - if mask is None: - mask = torch.zeros((batch_size, vocab_size), - dtype=torch.bool, - device=device) - start = min(i, vocab_size - 1) - end = min(i + num_allowed_token_ids, vocab_size - 1) - mask[i, start:end] = True - return mask - - -def _create_bad_words_token_ids( - batch_size: int, vocab_size: int, - bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]: + bad_words_lengths: tuple[int, ...], +) -> dict[int, list[list[int]]]: bad_words_token_ids = {} for batch_idx in range(batch_size): token_ids_single_batch = [] for bad_words_length in bad_words_lengths: - token_ids = np.random.choice(vocab_size, - size=bad_words_length, - replace=True).tolist() + token_ids = np.random.choice( + vocab_size, 
size=bad_words_length, replace=True + ).tolist() token_ids_single_batch.append(token_ids) bad_words_token_ids[batch_idx] = token_ids_single_batch if batch_size >= 2: @@ -93,26 +74,27 @@ def _create_bad_words_token_ids( # Returns all last tokens of bad word sequences that share the same prefix # as `given_prefix` (excluding the last token). def _collect_suffixes_with_same_prefix( - given_prefix: list[int], - bad_words_token_ids: list[list[int]]) -> list[int]: + given_prefix: list[int], bad_words_token_ids: list[list[int]] +) -> list[int]: return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix] # generate a valid token id that is not in bad_words_token_ids -def _generate_valid_token_id(bad_words_token_ids: list[list[int]], - vocab_size: int) -> int: +def _generate_valid_token_id( + bad_words_token_ids: list[list[int]], vocab_size: int +) -> int: forbidden_start_tokens = set() for bad_word in bad_words_token_ids: forbidden_start_tokens.add(bad_word[0]) # Get a safe token that's not in forbidden starts - safe_token_candidates = list( - set(range(vocab_size)) - forbidden_start_tokens) + safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens) # Pick a random safe token return np.random.choice(safe_token_candidates) def _update_output_token_ids_for_bad_words( - metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]: + metadata: SamplingMetadata, vocab_size: int +) -> dict[int, list[int]]: bad_words_last_tokens = {} for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items(): output_token_ids = metadata.output_token_ids[batch_idx] @@ -130,12 +112,13 @@ def _update_output_token_ids_for_bad_words( # Collect all last tokens from other bad words # that share this prefix bad_words_last_token.extend( - _collect_suffixes_with_same_prefix( - prefix, bad_words_token_ids)) + _collect_suffixes_with_same_prefix(prefix, bad_words_token_ids) + ) break # Maximum one update to output_token_ids else: # Make sure no accidental match to bad words output_token_ids[-1] = _generate_valid_token_id( - bad_words_token_ids, vocab_size) + bad_words_token_ids, vocab_size + ) bad_words_last_tokens[batch_idx] = bad_words_last_token return bad_words_last_tokens @@ -150,23 +133,26 @@ def _create_default_sampling_metadata( prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( - np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + np.random.randint(0, vocab_size, size=num_output_tokens).tolist() + ) prompt_token_ids.append( - np.random.randint(0, - vocab_size, - size=np.random.randint( - 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + np.random.randint( + 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) + ).tolist() + ) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size, ), 0.0), + temperature=torch.full((batch_size,), 0.0), all_greedy=True, all_random=False, top_p=None, top_k=None, generators={}, max_num_logprobs=0, - prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, - vocab_size, device), + prompt_token_ids=_create_prompt_tokens_tensor( + prompt_token_ids, vocab_size, device + ), output_token_ids=output_token_ids, + spec_token_ids=[[] for _ in range(batch_size)], frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), @@ -179,8 +165,8 @@ def _create_default_sampling_metadata( def 
_create_weighted_output_token_list( - batch_size: int, - vocab_size: int) -> tuple[list[list[int]], list[list[int]]]: + batch_size: int, vocab_size: int +) -> tuple[list[list[int]], list[list[int]]]: """ Creates an output token list where each token occurs a distinct number of times. @@ -201,14 +187,13 @@ def _create_weighted_output_token_list( output_token_ids: list[list[int]] = [] sorted_token_ids_in_output: list[list[int]] = [] for _ in range(batch_size): - distinct_token_ids = np.random.choice(vocab_size, - size=np.random.randint(1, 10), - replace=False).tolist() + distinct_token_ids = np.random.choice( + vocab_size, size=np.random.randint(1, 10), replace=False + ).tolist() sorted_token_ids_in_output.append(distinct_token_ids) output_token_ids_for_batch = [] for index, token_id in enumerate(distinct_token_ids): - output_token_ids_for_batch.extend( - [token_id for _ in range(index + 1)]) + output_token_ids_for_batch.extend([token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) return output_token_ids, sorted_token_ids_in_output @@ -216,8 +201,9 @@ def _create_weighted_output_token_list( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) -def test_sampler_presence_penalty(device: str, batch_size: int, - presence_penalty: float): +def test_sampler_presence_penalty( + device: str, batch_size: int, presence_penalty: float +): """ Test to verify that if presence penalty is enabled then tokens are penalized as per their presence in the existing output. @@ -227,13 +213,17 @@ def test_sampler_presence_penalty(device: str, batch_size: int, # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) output_token_ids = sampling_metadata.output_token_ids sampling_metadata.presence_penalties = _create_penalty_tensor( - batch_size, presence_penalty, torch.device(device)) + batch_size, presence_penalty, torch.device(device) + ) sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): # Since all tokens initially have the same logits, the non-penalized @@ -261,8 +251,9 @@ def test_sampler_presence_penalty(device: str, batch_size: int, @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) -def test_sampler_frequency_penalty(device: str, batch_size: int, - frequency_penalty: float): +def test_sampler_frequency_penalty( + device: str, batch_size: int, frequency_penalty: float +): """ Test to verify that if frequency penalty is enabled then tokens are penalized as per their frequency of occurrence. @@ -272,34 +263,36 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.frequency_penalties = _create_penalty_tensor( - batch_size, frequency_penalty, torch.device(device)) - output_token_ids, sorted_token_ids_in_output = \ - _create_weighted_output_token_list( - batch_size, - VOCAB_SIZE, - ) + batch_size, frequency_penalty, torch.device(device) + ) + output_token_ids, sorted_token_ids_in_output = _create_weighted_output_token_list( + batch_size, + VOCAB_SIZE, + ) sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() penalized_token_id = logits[batch_idx].argmin().item() - distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[ - batch_idx] + distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[batch_idx] most_frequent_token_id = distinct_sorted_token_ids_in_output[ - len(distinct_sorted_token_ids_in_output) - 1] + len(distinct_sorted_token_ids_in_output) - 1 + ] if frequency_penalty > 0: # If `frequency_penalty` is set to > 0, it indicates # a preference for new tokens over existing ones. Verify that the # non-penalized token ID is not present in the output, while the # most penalized token is the one that occurs most frequently in # the output. - assert (non_penalized_token_id - not in distinct_sorted_token_ids_in_output) + assert non_penalized_token_id not in distinct_sorted_token_ids_in_output assert penalized_token_id == most_frequent_token_id elif frequency_penalty < 0: # If `frequency_penalty` is set to < 0, it indicates @@ -314,8 +307,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) -def test_sampler_repetition_penalty(device: str, batch_size: int, - repetition_penalty: float): +def test_sampler_repetition_penalty( + device: str, batch_size: int, repetition_penalty: float +): """ Test to verify that when the repetition penalty is enabled, tokens are penalized based on their presence in the prompt or the existing @@ -326,42 +320,54 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.repetition_penalties = _create_penalty_tensor( - batch_size, repetition_penalty, torch.device(device)) + batch_size, repetition_penalty, torch.device(device) + ) sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() penalized_token_id = logits[batch_idx].argmin().item() - prompt_tokens = sampling_metadata.prompt_token_ids[ - batch_idx][:].tolist() + prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx][:].tolist() output_tokens = sampling_metadata.output_token_ids[batch_idx] if repetition_penalty > 1.0: # If `repetition_penalty` > 1.0, verify that the non-penalized # token ID has not been seen before, while the penalized token ID # exists either in the prompt or the output. - assert (non_penalized_token_id not in prompt_tokens - and non_penalized_token_id not in output_tokens) - assert (penalized_token_id in prompt_tokens - or penalized_token_id in output_tokens) + assert ( + non_penalized_token_id not in prompt_tokens + and non_penalized_token_id not in output_tokens + ) + assert ( + penalized_token_id in prompt_tokens + or penalized_token_id in output_tokens + ) elif repetition_penalty < 1.0: # If `repetition_penalty` < 1.0, verify that the penalized # token ID has not been seen before, while the non-penalized # token ID exists either in the prompt or the output. - assert (penalized_token_id not in prompt_tokens - and penalized_token_id not in output_tokens) - assert (non_penalized_token_id in prompt_tokens - or non_penalized_token_id in output_tokens) + assert ( + penalized_token_id not in prompt_tokens + and penalized_token_id not in output_tokens + ) + assert ( + non_penalized_token_id in prompt_tokens + or non_penalized_token_id in output_tokens + ) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) -def test_sampler_allowed_token_ids(device: str, batch_size: int, - num_allowed_token_ids: int): +def test_sampler_allowed_token_ids( + device: str, batch_size: int, num_allowed_token_ids: int +): """ Test to verify that when the repetition penalty is enabled, tokens are penalized based on their presence in the prompt or the existing @@ -372,8 +378,9 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) - mask = _create_allowed_token_ids( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) + mask = create_allowed_token_ids( batch_size=batch_size, vocab_size=VOCAB_SIZE, num_allowed_token_ids=num_allowed_token_ids, @@ -381,7 +388,9 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, ) sampling_metadata.allowed_token_ids_mask = mask sampler = Sampler() - logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata) + logits = sampler.apply_logits_processors( + fake_logits, sampling_metadata, predict_bonus_token=False + ) logits = logits.cpu() for batch_idx in range(batch_size): logits_for_req = logits[batch_idx] @@ -392,17 +401,19 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, start = min(batch_idx, VOCAB_SIZE - 1) end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1) if token_id >= start and token_id < end: - assert logits_for_req[token_id] == -float( - "inf"), f"{batch_idx}, {token_id}" + assert logits_for_req[token_id] == -float("inf"), ( + f"{batch_idx}, {token_id}" + ) else: assert logits_for_req[token_id] != -float("inf") @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) -@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)]) -def test_sampler_bad_words(device: str, batch_size: int, - bad_words_lengths: list[tuple[int]]): +@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)]) +def test_sampler_bad_words( + device: str, batch_size: int, bad_words_lengths: tuple[int, ...] +): """ Test to verify that when the bad words restriction is present, tokens are penalized based on their match with the bad words. @@ -412,19 +423,26 @@ def test_sampler_bad_words(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids( - batch_size, VOCAB_SIZE, bad_words_lengths) + batch_size, VOCAB_SIZE, bad_words_lengths + ) bad_words_last_tokens = _update_output_token_ids_for_bad_words( - sampling_metadata, VOCAB_SIZE) + sampling_metadata, VOCAB_SIZE + ) sampler = Sampler() - logits = sampler.apply_bad_words(fake_logits, sampling_metadata) + logits = sampler.apply_logits_processors( + fake_logits, sampling_metadata, predict_bonus_token=False + ) logits = logits.cpu() for batch_idx in range(batch_size): logits_for_req = logits[batch_idx] for token_id in range(VOCAB_SIZE): - if (batch_idx in bad_words_last_tokens - and token_id in bad_words_last_tokens[batch_idx]): + if ( + batch_idx in bad_words_last_tokens + and token_id in bad_words_last_tokens[batch_idx] + ): assert logits_for_req[token_id] == -float("inf") else: assert logits_for_req[token_id] != -float("inf") diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index f53e1e1c485d6..bdde28fe0342a 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -1,14 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import pytest from vllm import LLM, SamplingParams -if os.getenv("VLLM_USE_V1", "0") != "1": - pytest.skip("Test package requires V1", allow_module_level=True) - MODEL = "meta-llama/Llama-3.2-1B" PROMPT = "Hello my name is Robert and I" @@ -66,9 +62,9 @@ def test_stop(llm): # Output should not contain the stop word. assert len(new_split_text) == STOP_IDX - params = SamplingParams(temperature=0, - stop=split_text[STOP_IDX], - include_stop_str_in_output=True) + params = SamplingParams( + temperature=0, stop=split_text[STOP_IDX], include_stop_str_in_output=True + ) output = llm.generate(PROMPT, params) new_split_text = output[0].outputs[0].text.split() @@ -103,8 +99,8 @@ def test_detokenize_false(llm): assert len(output[0].outputs[0].text) == 0 output = llm.generate( - PROMPT, SamplingParams(detokenize=False, logprobs=3, - prompt_logprobs=3)) + PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3) + ) assert len(output[0].outputs[0].token_ids) > 0 assert len(output[0].outputs[0].text) == 0 @@ -131,8 +127,7 @@ def test_bad_words(llm): assert bad_words_1 not in new_text bad_words_2 = new_text.split()[-1] - params = SamplingParams(temperature=0, - bad_words=[bad_words_1, bad_words_2]) + params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2]) output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text assert bad_words_1 not in new_text @@ -158,8 +153,7 @@ def test_allowed_token_ids(llm): TOKEN_ID = 10 allowed_token_ids = [TOKEN_ID] - output = llm.generate(PROMPT, - SamplingParams(allowed_token_ids=allowed_token_ids)) + output = llm.generate(PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids)) assert output[0].outputs[0].token_ids[-1] == TOKEN_ID # Reject empty allowed_token_ids. 
@@ -175,14 +169,6 @@ def test_allowed_token_ids(llm): _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000])) -def test_priority(llm): - """Check that we reject requests with priority.""" - - # Reject all allowed token ids - with pytest.raises(ValueError): - _ = llm.generate(PROMPT, priority=[1]) - - def test_seed(llm): """Check that seed impacts randomness.""" diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index ccf38c31d39e6..c70cbebe22caa 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,8 +5,10 @@ import torch from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - is_flashinfer_available) +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + is_flashinfer_available, +) DEVICE = current_platform.device_type @@ -30,19 +32,18 @@ def reset_default_device(): def test_topk_impl_equivalence(): - torch.set_default_device(DEVICE) generator = Generator(device=DEVICE).manual_seed(33) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) # Random top-k values between 1 and 9. - k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) + k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator) # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). k.masked_fill_( - torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool), - VOCAB_SIZE) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE + ) # Top-k only implementation result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) @@ -55,7 +56,7 @@ def test_topk_impl_equivalence(): def test_flashinfer_sampler(): - ''' + """ This test verifies that the FlashInfer top-k and top-p sampling implementation produces the same results as the Python implementation. @@ -63,11 +64,10 @@ def test_flashinfer_sampler(): top-p prob renorm (it did provide fused sampling but we cannot compare sampling results due to randomness), so we will compare the probability renormed consequently by top-k and then top-p of FlashInfer implementation. 
- ''' + """ if not FLASHINFER_ENABLED: - pytest.skip( - "FlashInfer not installed or not available on this platform.") + pytest.skip("FlashInfer not installed or not available on this platform.") torch.set_default_device(DEVICE) generator = Generator(device=DEVICE).manual_seed(42) @@ -76,23 +76,21 @@ def test_flashinfer_sampler(): logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) # Generate various top-k and top-p values - k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) - p_values = torch.rand( - (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0] + k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator) + p_values = ( + torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5 + ) # range in [0.5, 1.0] # Sometimes disable top-k (k=vocab_size) k_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), VOCAB_SIZE) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), + VOCAB_SIZE, + ) # Sometimes disable top-p (p=1.0) p_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), 1.0) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0 + ) python_logits = apply_top_k_top_p( logits=logits.clone(), @@ -113,5 +111,6 @@ def test_flashinfer_sampler(): ) # Compare the results - assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \ + assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), ( "FlashInfer and Python sampling implementations do not match!" + ) diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index e33efb413d026..b1c63327b852b 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -16,6 +16,7 @@ from vllm.v1.sample.metadata import SamplingMetadata class BatchLogprobsComposition(Enum): """Types of logprobs configs to include in test batch""" + NONE = 0 SAMPLE = 1 PROMPT = 2 @@ -26,10 +27,10 @@ BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]] def get_test_batch( - batch_logprobs_composition: BatchLogprobsComposition + batch_logprobs_composition: BatchLogprobsComposition, ) -> BatchLogprobsSpecType: """Generate logprobs configs for a batch of requests - + A given request's logprobs configuration is (1) num_sample_logprobs and (2) num_prompt_logprobs. The batch logprobs configuration is the list of request logprobs configs. @@ -101,7 +102,7 @@ def assert_incr_detok_str_matches_non_incr_detok_str( msg: str, ) -> None: """Compare incrementally detok. text to non-incrementally detok. text - + Fail if the strings mismatch after non-alphanumeric characters are stripped out. 
@@ -120,15 +121,15 @@ def assert_incr_detok_str_matches_non_incr_detok_str( tokens msg: error message if `assert` fails """ - rgx = r'[^a-zA-Z0-9]+' - assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub( - rgx, '', non_incremental_detokenization_str)), (msg) + rgx = r"[^a-zA-Z0-9]+" + assert re.sub(rgx, "", incremental_detokenization_str) == re.sub( + rgx, "", non_incremental_detokenization_str + ), msg -def compute_correct_cumulative_logprob( - completion_output: CompletionOutput) -> float: +def compute_correct_cumulative_logprob(completion_output: CompletionOutput) -> float: """Compute known-good value for evaluating cumulative logprob - + Args: completion_output: completion output from engine @@ -146,12 +147,12 @@ def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits -def create_penalty_tensor(batch_size: int, penalty_value: float, - device: torch.device) -> torch.Tensor: - return torch.full((batch_size, ), - fill_value=penalty_value, - dtype=torch.float, - device=device) +def create_penalty_tensor( + batch_size: int, penalty_value: float, device: torch.device +) -> torch.Tensor: + return torch.full( + (batch_size,), fill_value=penalty_value, dtype=torch.float, device=device + ) def create_prompt_tokens_tensor( @@ -170,6 +171,7 @@ def create_prompt_tokens_tensor( class LogitsprocsTestFakes(NamedTuple): """Wraps fake data structures to support testing""" + logits: torch.Tensor sampling_metadata: SamplingMetadata @@ -178,15 +180,16 @@ class LogitsprocsTestFakes(NamedTuple): cls: type[LogitsProcessor], ) -> Iterator[LogitsProcessor]: """Yield logits processors of a specific class. - + Args: cls: :class:`LogitsProcessor` subclass Returns: Iterator over logits processors """ - return (lp for lp in self.sampling_metadata.logitsprocs.all - if isinstance(lp, cls)) + return ( + lp for lp in self.sampling_metadata.logitsprocs.all if isinstance(lp, cls) + ) def get_logitsprocs(self) -> Iterator[LogitsProcessor]: """Iterator over all logits processors.""" @@ -208,8 +211,27 @@ def fake_apply_logitsprocs( slice_indices: list[int], ) -> torch.Tensor: """Imitate application of logits processors in engine core""" - logits = test_fakes.logits[torch.tensor(slice_indices, - dtype=torch.long)].clone() + logits = test_fakes.logits[torch.tensor(slice_indices, dtype=torch.long)].clone() for processor in test_fakes.get_logitsprocs(): logits = processor.apply(logits) return logits + + +def create_allowed_token_ids( + batch_size: int, + vocab_size: int, + num_allowed_token_ids: int, + device: torch.device, +) -> Optional[torch.Tensor]: + mask: Optional[torch.Tensor] = None + for i in range(batch_size): + if i % 2 == 1: + continue + if mask is None: + mask = torch.zeros( + (batch_size, vocab_size), dtype=torch.bool, device=device + ) + start = min(i, vocab_size - 1) + end = min(i + num_allowed_token_ids, vocab_size - 1) + mask[i, start:end] = True + return mask diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index 682d84dc23d12..d943578278641 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -5,8 +5,10 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import 
RequestOutputKind @@ -21,8 +23,9 @@ MODELS = ["meta-llama/Llama-3.2-1B"] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("send_one_request", [False, True]) -async def test_async_llm_delete(model: str, tensor_parallel_size: int, - send_one_request: bool) -> None: +async def test_async_llm_delete( + model: str, tensor_parallel_size: int, send_one_request: bool +) -> None: """Test that AsyncLLM frees GPU memory upon deletion. AsyncLLM always uses an MP client. @@ -34,19 +37,21 @@ async def test_async_llm_delete(model: str, tensor_parallel_size: int, if cuda_device_count_stateless() < tensor_parallel_size: pytest.skip(reason="Not enough CUDA devices") - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) # Instantiate AsyncLLM; make request to complete any deferred # initialization; then delete instance async_llm = AsyncLLM.from_engine_args(engine_args) if send_one_request: async for _ in async_llm.generate( - "Hello my name is", - request_id="abc", - sampling_params=SamplingParams( - max_tokens=1, output_kind=RequestOutputKind.DELTA)): + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=1, output_kind=RequestOutputKind.DELTA + ), + ): pass del async_llm @@ -62,9 +67,13 @@ async def test_async_llm_delete(model: str, tensor_parallel_size: int, @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("send_one_request", [False, True]) -def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, - enable_multiprocessing: bool, - send_one_request: bool) -> None: +def test_llm_delete( + monkeypatch, + model: str, + tensor_parallel_size: int, + enable_multiprocessing: bool, + send_one_request: bool, +) -> None: """Test that LLM frees GPU memory upon deletion. TODO(andy) - LLM without multiprocessing. @@ -83,12 +92,13 @@ def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, # Instantiate LLM; make request to complete any deferred # initialization; then delete instance - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + llm = LLM( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) if send_one_request: - llm.generate("Hello my name is", - sampling_params=SamplingParams(max_tokens=1)) + llm.generate( + "Hello my name is", sampling_params=SamplingParams(max_tokens=1) + ) del llm # Confirm all the processes are cleaned up. 
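The shutdown tests above and below all verify cleanup the same way: after deleting or crashing the engine, they wait for device memory to drop back under a byte threshold within a timeout, via wait_for_gpu_memory_to_clear together with SHUTDOWN_TEST_THRESHOLD_BYTES and SHUTDOWN_TEST_TIMEOUT_SEC from the test utilities, whose bodies are not part of this diff. A minimal sketch of that polling pattern, assuming only the public torch CUDA API and a hypothetical helper name (not the vLLM utility itself):

import time

import torch


def wait_until_gpu_memory_clears(threshold_bytes: int, timeout_s: float) -> None:
    # Illustrative sketch, not the vLLM helper: poll device-wide used memory
    # (total minus free, as reported by the CUDA driver) until it drops below
    # `threshold_bytes`, or raise once `timeout_s` seconds have elapsed.
    deadline = time.monotonic() + timeout_s
    while True:
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        used_bytes = total_bytes - free_bytes
        if used_bytes < threshold_bytes:
            return
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"GPU memory still at {used_bytes} bytes "
                f"(threshold {threshold_bytes}) after {timeout_s} s"
            )
        time.sleep(1.0)

Polling the driver-reported total rather than torch.cuda.memory_allocated() matters here, since the engine workers run in separate processes whose allocations the local torch allocator does not see.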
diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index 523b7ee231151..383348e88540a 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -7,8 +7,10 @@ import asyncio import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM @@ -26,8 +28,10 @@ def evil_forward(self, *args, **kwargs): if not hasattr(self, "num_calls"): self.num_calls = 0 - if (self.num_calls == NUMBER_OF_GOOD_PASSES - and get_tensor_model_parallel_rank() == 0): + if ( + self.num_calls == NUMBER_OF_GOOD_PASSES + and get_tensor_model_parallel_rank() == 0 + ): raise Exception("Simulated illegal memory access on Rank 0!") self.num_calls += 1 @@ -37,10 +41,11 @@ def evil_forward(self, *args, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("model", MODELS) -async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, - model: str) -> None: +async def test_async_llm_model_error( + monkeypatch, tensor_parallel_size: int, model: str +) -> None: """Test that AsyncLLM propagates a forward pass error and frees memory. - + AsyncLLM always uses an MP client. """ if cuda_device_count_stateless() < tensor_parallel_size: @@ -49,15 +54,15 @@ async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, # Monkeypatch an error in the model. monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward) - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) async_llm = AsyncLLM.from_engine_args(engine_args) async def generate(request_id: str): - generator = async_llm.generate("Hello my name is", - request_id=request_id, - sampling_params=SamplingParams()) + generator = async_llm.generate( + "Hello my name is", request_id=request_id, sampling_params=SamplingParams() + ) try: async for _ in generator: pass @@ -77,9 +82,9 @@ async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, # We should not be able to make another request. with pytest.raises(EngineDeadError): - async for _ in async_llm.generate("Hello my name is", - request_id="abc", - sampling_params=SamplingParams()): + async for _ in async_llm.generate( + "Hello my name is", request_id="abc", sampling_params=SamplingParams() + ): raise Exception("We should not get here.") # Confirm all the processes are cleaned up. @@ -98,8 +103,9 @@ async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("model", MODELS) -def test_llm_model_error(monkeypatch, tensor_parallel_size: int, - enable_multiprocessing: bool, model: str) -> None: +def test_llm_model_error( + monkeypatch, tensor_parallel_size: int, enable_multiprocessing: bool, model: str +) -> None: """Test that LLM propagates a forward pass error and frees memory. 
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing and >1 rank @@ -108,19 +114,17 @@ def test_llm_model_error(monkeypatch, tensor_parallel_size: int, pytest.skip(reason="Not enough CUDA devices") with monkeypatch.context() as m: - MP_VALUE = "1" if enable_multiprocessing else "0" m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) # Monkeypatch an error in the model. m.setattr(LlamaForCausalLM, "forward", evil_forward) - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + llm = LLM( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) - with pytest.raises( - EngineDeadError if enable_multiprocessing else Exception): + with pytest.raises(EngineDeadError if enable_multiprocessing else Exception): llm.generate("Hello my name is Robert and I") # Confirm all the processes are cleaned up. diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py index a077d48fecbba..013b929e3df65 100644 --- a/tests/v1/shutdown/test_processor_error.py +++ b/tests/v1/shutdown/test_processor_error.py @@ -30,9 +30,9 @@ async def test_async_llm_processor_error(model: str) -> None: async def generate(request_id: str): # [] is not allowed and will raise a ValueError in Processor. - generator = async_llm.generate(TokensPrompt([]), - request_id=request_id, - sampling_params=SamplingParams()) + generator = async_llm.generate( + TokensPrompt([]), request_id=request_id, sampling_params=SamplingParams() + ) try: async for _ in generator: pass @@ -55,11 +55,12 @@ async def test_async_llm_processor_error(model: str) -> None: EXPECTED_TOKENS = 5 outputs = [] async for out in async_llm.generate( - "Hello my name is", - request_id="abc", - sampling_params=SamplingParams( - max_tokens=EXPECTED_TOKENS, - output_kind=RequestOutputKind.DELTA)): + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=EXPECTED_TOKENS, output_kind=RequestOutputKind.DELTA + ), + ): outputs.append(out) generated_tokens = [] diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 88fc5297aaf50..019c0c4d7cf07 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -5,8 +5,10 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs @@ -30,9 +32,9 @@ def evil_method(self, *args, **kwargs): @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) -def test_async_llm_startup_error(monkeypatch, model: str, - tensor_parallel_size: int, - failing_method: str) -> None: +def test_async_llm_startup_error( + monkeypatch, model: str, tensor_parallel_size: int, failing_method: str +) -> None: """Test that AsyncLLM propagates an __init__ error & frees memory. Test profiling (forward()) and load weights failures. AsyncLLM always uses an MP client. @@ -43,9 +45,9 @@ def test_async_llm_startup_error(monkeypatch, model: str, # Monkeypatch an error in the model. 
monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) # Confirm we get an exception. with pytest.raises(Exception, match="initialization failed"): @@ -63,9 +65,13 @@ def test_async_llm_startup_error(monkeypatch, model: str, @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) -def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, - enable_multiprocessing: bool, - failing_method: str) -> None: +def test_llm_startup_error( + monkeypatch, + model: str, + tensor_parallel_size: int, + enable_multiprocessing: bool, + failing_method: str, +) -> None: """Test that LLM propagates an __init__ error and frees memory. Test profiling (forward()) and load weights failures. TODO(andy) - LLM without multiprocessing. @@ -76,7 +82,6 @@ def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, pytest.skip(reason="Not enough CUDA devices") with monkeypatch.context() as m: - MP_VALUE = "1" if enable_multiprocessing else "0" m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) @@ -84,12 +89,16 @@ def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) with pytest.raises( - Exception, - match="initialization failed" - if enable_multiprocessing else "Simulated Error in startup!"): - _ = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + Exception, + match="initialization failed" + if enable_multiprocessing + else "Simulated Error in startup!", + ): + _ = LLM( + model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size, + ) # Confirm all the processes are cleaned up. 
wait_for_gpu_memory_to_clear( diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 7b8445a0b2878..0f0a3722ef2dd 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -8,16 +8,28 @@ import pytest import torch from tests.utils import get_attn_backend_list_based_on_platform -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, - create_standard_kv_cache_spec, - get_attention_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VllmConfig) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + try_get_attention_backend, +) +from vllm.attention.backends.registry import _Backend +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -27,11 +39,9 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def _create_proposer( method: str, num_speculative_tokens: int, - speculative_token_tree: Optional[list[tuple[int]]] = None, + speculative_token_tree: Optional[list[tuple[int, ...]]] = None, ) -> EagleProposer: - model_config = ModelConfig(model=model_dir, - runner="generate", - max_model_len=100) + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) # Choose model directory based on method draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir @@ -57,10 +67,96 @@ def _create_proposer( device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig()) + scheduler_config=SchedulerConfig(), + ) - return EagleProposer(vllm_config=vllm_config, - device=current_platform.device_type) + return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) + + +def test_prepare_next_token_ids(): + """ + Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded. + Each will produce a device tensor of next_token_ids, taking as input + either the GPU tensor of sampled_token_ids with -1 for rejected tokens, + or the CPU python list[list[int]] with the rejected tokens removed. 
+ """ + device = torch.device(current_platform.device_type) + + num_requests = 4 + num_speculative_tokens = 4 + batch_spec = BatchSpec( + seq_lens=[num_speculative_tokens + 1] * num_requests, + query_lens=[num_speculative_tokens + 1] * num_requests, + ) + + req_ids = [f"req_{i + 1}" for i in range(num_requests)] + mock_input_batch = mock.MagicMock(spec=InputBatch) + mock_input_batch.req_ids = req_ids + mock_input_batch.num_reqs = num_requests + mock_input_batch.vocab_size = 100 + + mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids} + mock_requests = {} + for req_id in req_ids: + mock_request = mock.MagicMock(spec=CachedRequestState) + # Each request will have a backup next token id of 10, 20, 30, 40 + mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10 + mock_request.num_computed_tokens = 0 + mock_requests[req_id] = mock_request + + sampled_token_ids = [ + [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled + [0, 1, 2, 3, 4], # all accepted, "4" sampled + [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" + [-1, -1, -1, -1, -1], # this request will be discarded + ] + sampled_token_ids_tensor = torch.tensor( + sampled_token_ids, dtype=torch.int32, device=device + ) + sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] + + expected_next_token_ids_cpu = [1, 4, 30, 40] + expected_next_token_ids_tensor = torch.tensor( + expected_next_token_ids_cpu, dtype=torch.int32, device=device + ) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( + sampled_token_ids_cpu, + mock_requests, + mock_input_batch, + mock_num_scheduled_tokens, + ) + + assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) + num_discarded_reqs = 1 + + expected_valid_sampled_tokens_count = torch.tensor( + [2, 5, 0, 0], dtype=torch.int32, device=device + ) + + next_token_ids_from_padded, valid_sampled_tokens_count = ( + proposer.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids_tensor, + mock_requests, + mock_input_batch, + discarded_req_indices, + num_discarded_reqs, + ) + ) + + assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count) def test_prepare_inputs(): @@ -89,18 +185,38 @@ def test_prepare_inputs(): device=device, ) - # Rejected tokens per request: [1, 3, 2] - num_rejected_tokens = torch.tensor([1, 3, 2], - dtype=torch.int32, - device=device) + # If there are `k` sampled tokens, then `k-1` tokens are draft tokens + # from the previous iteration, and the last token is the bonus token sampled + # from the base model. 
+ num_draft_tokens = [3, 6, 4] # one less than query_lens + # num rejected tokens is [1, 3, 2] + ACCEPT_TOKEN = 0 + BONUS_TOKEN = 1 + REJECT_TOKEN = -1 + sampled_token_ids = [ + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + [ + ACCEPT_TOKEN, + ACCEPT_TOKEN, + ACCEPT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + BONUS_TOKEN, + ], + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + ] + sampled_token_ids = [ + [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids + ] # Expected calculations: # query_len_per_req = [4, 7, 5] # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens) # Expected cumulative counts: [0, 3, 7, 10] - expected_cu_num_tokens = torch.tensor([0, 3, 7, 10], - dtype=torch.int32, - device=device) + expected_cu_num_tokens = torch.tensor( + [0, 3, 7, 10], dtype=torch.int32, device=device + ) # Expected token indices (mapped from original positions): # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3) @@ -117,41 +233,117 @@ def test_prepare_inputs(): 7, # Second request: 4 tokens (7-3) 11, 12, - 13 # Third request: 3 tokens (5-2) + 13, # Third request: 3 tokens (5-2) ], dtype=torch.int32, - device=device) + device=device, + ) proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, num_rejected_tokens.cpu()) + common_attn_metadata, sampled_token_ids, num_draft_tokens + ) - assert torch.equal(updated_metadata.query_start_loc, - expected_cu_num_tokens) + assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) +def test_prepare_inputs_padded(): + """ + Input scenario is 3 requests with num_speculative_tokens == 2 and: + - Request 1: query_len = 3, rejected = 1 + - Request 2: query_len = 3, rejected = 0 + - Request 3: query_len = 3, rejected = 2 + + Expected outputs: + token_indices: [0, 1, 2, + 3, 4, 5, + 6, 7, 8] + Reason: Deferred computation should not disturb the original indices. + + token_indices_to_sample: [1, 5, 6] + Reason: After accounting for rejections, these are the valid token positions + from the original indices to sample from. 
+ """ + + device = torch.device(current_platform.device_type) + + expected_token_indices = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device + ) + expected_token_indices_to_sample = torch.tensor( + [1, 5, 6], dtype=torch.int32, device=device + ) + + num_speculative_tokens = 2 + batch_spec = BatchSpec( + seq_lens=[3, 3, 3], + query_lens=[3, 3, 3], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] + expected_query_start_loc = torch.tensor( + [0, 3, 6, 9], dtype=torch.int32, device=device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids=[[0] * num_speculative_tokens] * 3, + device=device, + ) + + # num_rejected_tokens = [1, 0, 2] + # num_draft_tokens = [2, 2, 2] + # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens + valid_sampled_tokens_count = torch.tensor( + [2, 3, 1], dtype=torch.int32, device=device + ) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + output_metadata, token_indices, token_indices_to_sample = ( + proposer.prepare_inputs_padded( + common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + ) + ) + + assert output_metadata.max_query_len == 3 + assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) + assert torch.equal(token_indices, expected_token_indices) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) + + @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) -@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') -@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') -@mock.patch('vllm.v1.spec_decode.eagle.get_model') -def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, - attn_backend, pp_size, use_distinct_embed_tokens, - monkeypatch): - +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_load_model( + mock_get_model, + mock_get_layers, + mock_get_pp_group, + method, + attn_backend, + pp_size, + use_distinct_embed_tokens, + monkeypatch, +): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN_VLLM_V1" - and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN_VLLM_V1 does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) - if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") # Setup draft model mock @@ -168,22 +360,28 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, # Setup mocks for attention layers target_attn_layers = { "target_attn_1": mock.MagicMock(), - "target_attn_2": mock.MagicMock() + "target_attn_2": mock.MagicMock(), } + target_indx_layers: dict[str, mock.MagicMock] = {} # Draft model has one extra 
attention layer compared to target model - all_attn_layers = { - **target_attn_layers, "draft_extra_attn": mock.MagicMock() - } + all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()} + + all_indx_layers: dict[str, mock.MagicMock] = {} # Make mock_get_layers return different values for each call - mock_get_layers.side_effect = [target_attn_layers, all_attn_layers] + mock_get_layers.side_effect = [ + target_attn_layers, + target_indx_layers, + all_attn_layers, + all_indx_layers, + ] # Setup mock for pp group to return the appropriate value for world size mock_pp_group = mock.MagicMock() mock_pp_group.world_size = pp_size mock_get_pp_group.return_value = mock_pp_group - # Setup the target model mock with a custom class so that + # Set up the target model mock with a custom class so that # isinstance() checks match the expected type. class _TargetModelStub(LlamaForCausalLM): model: mock.MagicMock @@ -194,6 +392,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, target_model.model.embed_tokens.weight.shape = (131072, 4096) from vllm.model_executor.models import SupportsMultiModal + assert not isinstance(target_model, SupportsMultiModal) if method == "eagle": @@ -215,33 +414,32 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, # Verify that the embed tokens are set correctly # If pp_size is > 1, the embed tokens should be distinct if pp_size > 1 or use_distinct_embed_tokens: - assert proposer.model.model.embed_tokens != \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens != target_model.model.embed_tokens else: # When pp_size is 1 and the draft and target models have # embed_tokens of the same shape, they should be shared. - assert proposer.model.model.embed_tokens == \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens == target_model.model.embed_tokens @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN_VLLM_V1" - and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN_VLLM_V1 does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) - if (attn_backend == "TREE_ATTN"): - pytest.skip("TREE_ATTN is tested separately in test_propose_tree" - "because it requires special input mocking.") + if attn_backend == "TREE_ATTN": + pytest.skip( + "TREE_ATTN is tested separately in test_propose_tree" + "because it requires special input mocking." 
+ ) - if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") # Use GPU device @@ -326,31 +524,22 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): device=device, ) - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) sampling_metadata = mock.MagicMock() - if attn_backend == "FLASH_ATTN_VLLM_V1": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.FLASH_ATTN_VLLM_V1) - elif attn_backend == "TRITON_ATTN_VLLM_V1": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TRITON_ATTN_VLLM_V1) + if attn_backend == "FLASH_ATTN": + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + elif attn_backend == "TRITON_ATTN": + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -364,14 +553,22 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder + ) - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) assert result.shape == (batch_size, num_speculative_tokens) @@ -380,13 +577,14 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): # Example for num_speculative_tokens=1: # [[42], [60]] expected_tokens = torch.tensor( - [[base_token_ids[0]], [base_token_ids[1]]], device=device) + [[base_token_ids[0]], [base_token_ids[1]]], device=device + ) else: # Example for num_speculative_tokens=3: # [[42, 43, 44], [60, 61, 62]] - expected_tokens = torch.zeros((batch_size, num_speculative_tokens), - dtype=torch.int64, - device=device) + expected_tokens = 
torch.zeros( + (batch_size, num_speculative_tokens), dtype=torch.int64, device=device + ) for i in range(batch_size): for j in range(num_speculative_tokens): expected_tokens[i, j] = base_token_ids[i] + j @@ -398,12 +596,12 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): @pytest.mark.parametrize( "spec_token_tree", [ - [(0, )], # A single token - [(0, ), (0, 0), (0, 0, 0)], # Chain - [(0, ), (1, ), (2, )], # Parallel - [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), - (2, 1)], # Tree - ]) + [(0,)], # A single token + [(0,), (0, 0), (0, 0, 0)], # Chain + [(0,), (1,), (2,)], # Parallel + [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree + ], +) def test_propose_tree(spec_token_tree): # Get GPU device. device = torch.device(current_platform.device_type) @@ -418,9 +616,9 @@ def test_propose_tree(spec_token_tree): num_speculative_tokens = len(spec_token_tree) # Create proposer first so we can use its actual hidden_size. - proposer = _create_proposer("eagle", - num_speculative_tokens, - speculative_token_tree=spec_token_tree) + proposer = _create_proposer( + "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree + ) # Get the hidden_size from the proposer to ensure consistency. hidden_size = proposer.hidden_size @@ -441,32 +639,31 @@ def test_propose_tree(spec_token_tree): model_mock = mock.MagicMock() # Mock the model forward calls. - forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), - torch.zeros(total_tokens, hidden_size, device=device))] + forward_returns = [ + ( + torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(total_tokens, hidden_size, device=device), + ) + ] for cu_num_drafts in proposer.cu_drafts_per_level: - h_logits = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) - h_states = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) + h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) + h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) forward_returns.append((h_logits, h_states)) model_mock.side_effect = forward_returns # Mock the compute_logits calls. - cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, - dtype=torch.int32, - device=device) + cu_num_drafts_tensor = torch.tensor( + [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device + ) logits_returns = [] for level, num_children in enumerate(proposer.child_drafts_per_level): token_ids = base_token_ids + cu_num_drafts_tensor[level] - level_num_drafts = cu_num_drafts_tensor[ - level + 1] - cu_num_drafts_tensor[level] + level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level] level_logits = [] for i in range(level_num_drafts // num_children): level_logits.append( - create_deterministic_logits(token_ids + i * num_children, - num_children)) + create_deterministic_logits(token_ids + i * num_children, num_children) + ) logits_returns.append(torch.stack(level_logits, dim=1)) model_mock.compute_logits.side_effect = logits_returns @@ -477,7 +674,7 @@ def test_propose_tree(spec_token_tree): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. 
- attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, @@ -488,23 +685,23 @@ def test_propose_tree(spec_token_tree): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder + ) # Setup inputs for the proposer. - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) batch_spec = BatchSpec( seq_lens=seq_lens, query_lens=seq_lens, @@ -517,18 +714,22 @@ def test_propose_tree(spec_token_tree): sampling_metadata = mock.MagicMock() # Propose draft tokens. - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) assert result.shape == (batch_size, num_speculative_tokens) # The tokens are expected to be consecutive integers starting # from the base token IDs. expected_tokens = base_token_ids[:, None] + torch.arange( - num_speculative_tokens, dtype=torch.int64, device=device) + num_speculative_tokens, dtype=torch.int64, device=device + ) # Verify that the draft tokens match our expectations. 
assert torch.equal(result, expected_tokens) diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index a5b10bb518668..bc779f6bd9c4d 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -33,20 +33,20 @@ def test_ngram_max_len(num_speculative_tokens: int): @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) -def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, attn_backend: str): +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) +def test_eagle_max_len( + monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str +): with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN_VLLM_V1" - and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN_VLLM_V1 does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) - if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): m.setenv("VLLM_ROCM_USE_AITER", "1") llm = LLM( diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py new file mode 100644 index 0000000000000..9ca7cf9e3e0e1 --- /dev/null +++ b/tests/v1/spec_decode/test_mtp.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest import mock + +import pytest +import torch + +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + try_get_attention_backend, +) +from vllm.attention.backends.registry import _Backend +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.config.load import LoadConfig +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.platforms import current_platform +from vllm.v1.spec_decode.eagle import EagleProposer + +mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base" + + +def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: + """Create an MTP proposer with unified model configuration.""" + model_config = ModelConfig( + model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True + ) + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + model=mimo_7b_dir, + method="mtp", + num_speculative_tokens=num_speculative_tokens, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig(), + ) + + return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) + + +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_mtp_load_model_unified(mock_get_model, mock_get_layers, 
mock_get_pp_group): + """Test MTP-specific model loading with unified model approach.""" + + # Setup mocks + mock_model = mock.MagicMock() + mock_model.model.embed_tokens.weight.shape = (131072, 4096) + mock_get_model.return_value = mock_model + + target_attn_layers = {"target_attn_1": mock.MagicMock()} + all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()} + target_indexer_layers: dict = {} + all_indexer_layers: dict = {} + + mock_get_layers.side_effect = [ + target_attn_layers, + target_indexer_layers, + all_attn_layers, + all_indexer_layers, + ] + + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + mock_get_pp_group.return_value = mock_pp_group + + # Create target model + class _TargetModelStub(LlamaForCausalLM): + model: mock.MagicMock + lm_head: mock.MagicMock + + target_model = mock.create_autospec(_TargetModelStub, instance=True) + target_model.model = mock.MagicMock() + target_model.model.embed_tokens.weight.shape = (131072, 4096) + target_model.lm_head = mock.MagicMock() + + # Create MTP proposer + proposer = _create_mtp_proposer(num_speculative_tokens=4) + proposer.load_model(target_model) + + # Verify MTP-specific behavior: + # Model is loaded + mock_get_model.assert_called_once() + # MTP shares lm_head with target model + assert proposer.model.lm_head == target_model.lm_head + # MTP shares embed_tokens with target model + assert proposer.model.model.embed_tokens == target_model.model.embed_tokens + + +@pytest.mark.parametrize("num_speculative_tokens", [1]) +def test_mtp_propose(num_speculative_tokens, monkeypatch): + """Test that MTP's forward method returns hidden states directly""" + + device = torch.device(current_platform.device_type) + batch_size = 2 + seq_lens = [5, 3] + total_tokens = sum(seq_lens) + vocab_size = 100 + + proposer = _create_mtp_proposer(num_speculative_tokens) + hidden_size = proposer.hidden_size + + # Mock the MTP model to verify it returns hidden states directly + model_mock = mock.MagicMock() + + # MTP returns hidden states directly + if num_speculative_tokens == 1: + model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device) + else: + # Multiple forward passes for multi-token speculation + forward_returns = [] + for i in range(num_speculative_tokens): + if i == 0: + h_states = torch.zeros(total_tokens, hidden_size, device=device) + else: + h_states = torch.zeros(batch_size, hidden_size, device=device) + forward_returns.append(h_states) + model_mock.side_effect = forward_returns + + # Mock compute_logits + def create_deterministic_logits(batch_size, vocab_size, token_offset): + logits = torch.full((batch_size, vocab_size), -100.0, device=device) + logits[:, token_offset] = 100.0 + return logits + + if num_speculative_tokens == 1: + model_mock.compute_logits.return_value = create_deterministic_logits( + batch_size, vocab_size, 42 + ) + else: + logits_returns = [ + create_deterministic_logits(batch_size, vocab_size, 42 + i) + for i in range(num_speculative_tokens) + ] + model_mock.compute_logits.side_effect = logits_returns + + proposer.model = model_mock + proposer.attn_layer_names = ["layer.0"] + + # Prepare inputs + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [ + torch.arange(seq_lens[0], device=device), + torch.arange(seq_lens[1], device=device), + ] + ) + 
target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) + sampling_metadata = mock.MagicMock() + + # Setup attention metadata + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, + vllm_config=proposer.vllm_config, + device=device, + ) + + proposer.runner = mock.MagicMock() + proposer.attn_metadata_builder = attn_metadata_builder + + # Run propose + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) + + # Verify the model was called correctly + assert model_mock.called + # Verify output shape + assert result.shape == (batch_size, num_speculative_tokens) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 4193f4041b32b..692c39282c372 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -4,107 +4,189 @@ import numpy as np from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.v1.spec_decode.ngram_proposer import ( - NgramProposer, _find_longest_matched_ngram_and_propose_tokens) + NgramProposer, + _find_longest_matched_ngram_and_propose_tokens, +) def test_find_longest_matched_ngram_and_propose_tokens(): tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2) is None + result = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2 + ) + assert len(result) == 0 tokens = np.array([1, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2), np.array([4, 1])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2 + ), + np.array([4, 1]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=2), np.array([4, 1])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2 + ), + np.array([4, 1]), + ) tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( - 
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) # Return on the first match np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=2), np.array([6, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2 + ), + np.array([6, 2]), + ) def test_ngram_proposer(): - - def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: + def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Dummy model config. Just to set max_model_len. model_config = ModelConfig(model="facebook/opt-125m") return NgramProposer( - vllm_config=VllmConfig(model_config=model_config, - speculative_config=SpeculativeConfig( - prompt_lookup_min=min_n, - prompt_lookup_max=max_n, - num_speculative_tokens=k, - method="ngram", - ))) + vllm_config=VllmConfig( + model_config=model_config, + speculative_config=SpeculativeConfig( + prompt_lookup_min=min_n, + prompt_lookup_max=max_n, + num_speculative_tokens=k, + method="ngram", + ), + ) + ) # No match. - result = ngram_proposer( - min_n=2, max_n=2, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) - assert result is None + token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram. - result = ngram_proposer( - min_n=4, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert result is None + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) + result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 # No match for 4-gram but match for 3-gram. - result = ngram_proposer( - min_n=3, max_n=4, - k=2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) - assert np.array_equal(result, np.array([4, 1])) + token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]]) + result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[4, 1]])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. 
- result = ngram_proposer(min_n=3, max_n=4, k=2).propose( - context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) - assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] + token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]]) + result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]] # Match for 2-gram and 3-gram, but not 4-gram. - result = ngram_proposer(min_n=2, max_n=4, k=2).propose( - context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) - assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] + token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]]) + result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]] # Multiple 3-gram matched, but always pick the first one. - result = ngram_proposer( - min_n=3, max_n=3, k=2).propose(context_token_ids=np.array( - [1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3])) - assert np.array_equal(result, np.array([100, 1])) + token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) + result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert np.array_equal(result, np.array([[100, 1]])) + + # check empty input + token_ids_cpu = np.array([[]]) + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0]], + req_ids=["0"], + num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 0 + + # check multibatch input + # first request has 5 tokens and a match + # second request has 3 tokens and no match. 
Padded with -1 for max len 5 + token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]]) + result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([5, 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], np.array([3, 1])) + assert np.array_equal(result[1], np.array([])) + + # test if 0 threads available: can happen if TP size > CPU count + ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) + ngram_proposer.num_numba_thread_available = 0 + # set max_model_len to 2 * threshold to ensure multithread is used + num_tokens_threshold = ngram_proposer.num_tokens_threshold + ngram_proposer.max_model_len = 2 * num_tokens_threshold + # using multibatch test + middle_integer = num_tokens_threshold // 2 + input_1 = [_ for _ in range(num_tokens_threshold)] + input_1 += [middle_integer, middle_integer + 1] + input_2 = [-1] * len(input_1) + input_2[:3] = [4, 5, 6] + token_ids_cpu = np.array([input_1, input_2]) + result = ngram_proposer.propose( + sampled_token_ids=[[0], [1]], + req_ids=["0", "1"], + num_tokens_no_spec=np.array([len(input_1), 3]), + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result[0]) == 2 + assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3])) + assert np.array_equal(result[1], np.array([])) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 6317817408661..b31a2f27f54b0 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -6,9 +6,12 @@ from typing import Optional import torch -from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + create_standard_kv_cache_spec, + create_vllm_config, + try_get_attention_backend, +) +from vllm.attention.backends.registry import _Backend from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -41,10 +44,11 @@ def forward_attention( num_kv_heads = k.shape[-2] # Initialize the query and KV sequence lengths. query_start_loc = q_len * torch.arange( - batch_size + 1, device=q.device, dtype=torch.int32) + batch_size + 1, device=q.device, dtype=torch.int32 + ) query_lens = torch.diff(query_start_loc) seq_lens = torch.full( - (batch_size, ), + (batch_size,), seqlen_k, device=q.device, dtype=torch.int32, @@ -54,14 +58,13 @@ def forward_attention( max_query_len = q_len num_actual_tokens = query_start_loc[-1] - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) layer = MockAttentionLayer() # Build common metadata. model_name = "meta-llama/Meta-Llama-3-8B" - builder_cls, impl_cls = get_attention_backend(backend) - vllm_config = create_vllm_config(model_name=model_name, - max_model_len=max(seq_lens)) + builder_cls, impl_cls = try_get_attention_backend(backend) + vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens)) if spec_token_tree is not None: # Create speculative config if token tree is specified. 
vllm_config.speculative_config = SpeculativeConfig( @@ -70,7 +73,8 @@ def forward_attention( model=model_name, method="eagle", num_speculative_tokens=num_spec_tokens, - speculative_token_tree=spec_token_tree) + speculative_token_tree=spec_token_tree, + ) kv_cache_spec = create_standard_kv_cache_spec(vllm_config) builder = builder_cls(kv_cache_spec, [], vllm_config, q.device) common_attn_metadata = CommonAttentionMetadata( @@ -127,8 +131,7 @@ def test_tree_attn_correctness() -> None: device = "cuda" tree_attn_masks = { # Chain. - "[(0,), (0, 0), (0, 0, 0)]": - torch.tensor( + "[(0,), (0, 0), (0, 0, 0)]": torch.tensor( [ [1, 0, 0, 0], [1, 1, 0, 0], @@ -139,8 +142,7 @@ def test_tree_attn_correctness() -> None: dtype=torch.int32, ), # Tree. - "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": - torch.tensor( + "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor( [ [1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], @@ -187,7 +189,7 @@ def test_tree_attn_correctness() -> None: dtype=torch.bfloat16, ) - # Setup the block table and KV cache for paged KV. + # Set up the block table and KV cache for paged KV. assert max_sequence_length % block_size == 0 max_blocks_per_batch = max_sequence_length // block_size kv_cache = torch.randn( @@ -201,8 +203,7 @@ def test_tree_attn_correctness() -> None: device=q.device, dtype=torch.bfloat16, ) - num_alloc_blocks_per_batch = math.ceil(seqlen_k / - block_size) + num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size) block_table = torch.zeros( (batch_size, max_blocks_per_batch), device=q.device, @@ -216,13 +217,12 @@ def test_tree_attn_correctness() -> None: ) if randomize_blocks: # Randomize the block ids. - block_ids = block_ids[torch.randperm( - block_ids.numel())] - block_table[:, : - num_alloc_blocks_per_batch] = block_ids.view( - -1, num_alloc_blocks_per_batch) + block_ids = block_ids[torch.randperm(block_ids.numel())] + block_table[:, :num_alloc_blocks_per_batch] = block_ids.view( + -1, num_alloc_blocks_per_batch + ) - # Setup the slot mapping for the input KVs. + # Set up the slot mapping for the input KVs. tree_positions = sequence_position + torch.arange( 0, tree_size_q, @@ -230,7 +230,8 @@ def test_tree_attn_correctness() -> None: dtype=torch.int64, ).repeat(batch_size, 1) tree_slot_mapping = _gen_slot_mapping( - tree_positions, block_table, block_size) + tree_positions, block_table, block_size + ) # Compute attention for the tree. tree_attn_output = forward_attention( @@ -252,8 +253,7 @@ def test_tree_attn_correctness() -> None: for q_index in range(tree_size_q): # Get the q, k, and v for the branch. branch_mask = tree_attn_mask[q_index, :] - branch_indices = torch.nonzero(branch_mask, - as_tuple=True)[0] + branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0] q_len = branch_indices.shape[0] q_branch = q[:, branch_indices] k_branch = k[:, branch_indices] @@ -267,7 +267,8 @@ def test_tree_attn_correctness() -> None: dtype=torch.int64, ).repeat(batch_size, 1) branch_slot_mapping = _gen_slot_mapping( - branch_positions, block_table, block_size) + branch_positions, block_table, block_size + ) # Compute flash attention for the branch. flash_attn_output = forward_attention( @@ -278,7 +279,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, - backend=_Backend.FLASH_ATTN_VLLM_V1, + backend=_Backend.FLASH_ATTN, ).view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. 
@@ -286,16 +287,19 @@ def test_tree_attn_correctness() -> None: tree_attn_output[:, branch_indices], flash_attn_output, atol=7.81e-3, - ), (f"outputs are not close for " + ), ( + f"outputs are not close for " f"batch_size: {batch_size}, " f"num_heads: {num_heads}, " f"sequence_position: {sequence_position}, " f"tree_attn_mask: {tree_attn_mask}, " - f"q_index: {q_index}.") + f"q_index: {q_index}." + ) -def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor, - block_size: int): +def _gen_slot_mapping( + positions: torch.Tensor, block_table: torch.Tensor, block_size: int +): block_indices = positions // block_size blocks = block_table.gather(dim=1, index=block_indices) return (blocks * block_size + positions % block_size).view(-1) diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 4e7c4b33e8c47..b285658af3d1a 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -4,88 +4,50 @@ import pytest from vllm.v1.structured_output.backend_xgrammar import ( - has_xgrammar_unsupported_json_features) + has_xgrammar_unsupported_json_features, +) + +pytestmark = pytest.mark.cpu_test @pytest.fixture def unsupported_string_schemas(): return [ - { - "type": "string", - "format": "email" - }, + {"type": "string", "format": "email"}, ] @pytest.fixture def unsupported_integer_schemas(): return [ - { - "type": "integer", - "multipleOf": 120 - }, + {"type": "integer", "multipleOf": 120}, ] @pytest.fixture def unsupported_number_schemas(): return [ - { - "type": "number", - "multipleOf": 120 - }, + {"type": "number", "multipleOf": 120}, ] @pytest.fixture def unsupported_array_schemas(): return [ - { - "type": "array", - "uniqueItems": True - }, - { - "type": "array", - "contains": { - "type": "string" - } - }, - { - "type": "array", - "minContains": 1 - }, - { - "type": "array", - "maxContains": 5 - }, + {"type": "array", "uniqueItems": True}, + {"type": "array", "contains": {"type": "string"}}, + {"type": "array", "minContains": 1}, + {"type": "array", "maxContains": 5}, ] @pytest.fixture def unsupported_object_schemas(): return [ - { - "type": "object", - "minProperties": 1 - }, - { - "type": "object", - "maxProperties": 5 - }, - { - "type": "object", - "propertyNames": { - "pattern": "^[a-z]+$" - } - }, - { - "type": "object", - "patternProperties": { - "^S": { - "type": "string" - } - } - }, + {"type": "object", "minProperties": 1}, + {"type": "object", "maxProperties": 5}, + {"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}}, + {"type": "object", "patternProperties": {"^S": {"type": "string"}}}, ] @@ -94,75 +56,50 @@ def supported_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "status": { - "type": "string" - }, - "scores": { - "type": "array", - "items": { - "type": "number" - } - }, - "car_type": { - "type": "string", - "enum": ["sedan", "suv", "truck"] - }, - "car_brand": { - "type": "string", - "pattern": "^[a-zA-Z]+$" - }, - "short_description": { - "type": "string", - "maxLength": 50 - }, - "mileage": { - "type": "number", - "minimum": 0, - "maximum": 1000000 - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, + "status": {"type": "string"}, + "scores": {"type": "array", "items": {"type": "number"}}, + "car_type": {"type": "string", "enum": ["sedan", "suv", "truck"]}, + "car_brand": {"type": "string", "pattern": "^[a-zA-Z]+$"}, + "short_description": {"type": "string", "maxLength": 50}, 
+ "mileage": {"type": "number", "minimum": 0, "maximum": 1000000}, "model_year": { "type": "integer", "exclusiveMinimum": 1900, - "exclusiveMaximum": 2100 - }, - "long_description": { - "type": "string", - "minLength": 50, - "maxLength": 2000 + "exclusiveMaximum": 2100, }, + "long_description": {"type": "string", "minLength": 50, "maxLength": 2000}, "address": { "type": "object", "properties": { - "street": { - "type": "string" - }, - "city": { - "type": "string" - } - } - } - } + "street": {"type": "string"}, + "city": {"type": "string"}, + }, + }, + }, } -@pytest.mark.parametrize("schema_type", [ - "unsupported_string_schemas", "unsupported_integer_schemas", - "unsupported_number_schemas", "unsupported_array_schemas", - "unsupported_object_schemas" -]) +@pytest.mark.parametrize( + "schema_type", + [ + "unsupported_string_schemas", + "unsupported_integer_schemas", + "unsupported_number_schemas", + "unsupported_array_schemas", + "unsupported_object_schemas", + ], +) def test_unsupported_json_features_by_type(schema_type, request): schemas = request.getfixturevalue(schema_type) for schema in schemas: - assert has_xgrammar_unsupported_json_features( - schema), f"Schema should be unsupported: {schema}" + assert has_xgrammar_unsupported_json_features(schema), ( + f"Schema should be unsupported: {schema}" + ) def test_supported_json_features(supported_schema): - assert not has_xgrammar_unsupported_json_features( - supported_schema), "Schema should be supported" + assert not has_xgrammar_unsupported_json_features(supported_schema), ( + "Schema should be supported" + ) diff --git a/tests/v1/test_kv_sharing.py b/tests/v1/test_kv_sharing.py deleted file mode 100644 index 6b01b7d3e1d6c..0000000000000 --- a/tests/v1/test_kv_sharing.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import Mock - -import torch - -from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionBackend, FlashAttentionMetadataBuilder) -from vllm.v1.attention.backends.flex_attention import ( - FlexAttentionBackend, FlexAttentionMetadataBuilder) -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec -from vllm.v1.worker.utils import (AttentionGroup, - initialize_kv_cache_for_kv_sharing) - - -def new_kv_cache_spec(): - return FullAttentionSpec(16, 1, 1, torch.float32, False) - - -def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): - """ - Test initializing KV cache sharing with different attention groups. - Layers in the same KV cache group might be placed in different attn groups - if they have different attention backends. 
- """ - shared_kv_cache_layers = { - "model.layers.2": "model.layers.0", - "model.layers.3": "model.layers.1", - } - - # Layers 0 and 1 both belong in KV cache group 0 - # However, if they have have different attention backends, they will be - # placed in different attention groups for KV cache group 0 - kv_cache_groups = [ - KVCacheGroupSpec(["model.layers.0", "model.layers.1"], - new_kv_cache_spec()), - ] - - attn_groups = [ - # KV cache group 0 has two attention groups - [ - AttentionGroup( - backend=FlashAttentionBackend, - metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), - layer_names=["model.layers.0"], - ), - AttentionGroup( - backend=FlexAttentionBackend, - metadata_builder=Mock(spec=FlexAttentionMetadataBuilder), - layer_names=["model.layers.1"], - ), - ], - ] - - # Only layers 0 and 1 will have KV caches allocated - kv_caches = { - "model.layers.0": torch.zeros(1, 2, 3), - "model.layers.1": torch.ones(1, 2, 3), - } - - initialize_kv_cache_for_kv_sharing( - shared_kv_cache_layers=shared_kv_cache_layers, - kv_cache_groups=kv_cache_groups, - kv_caches=kv_caches, - attn_groups=attn_groups, - ) - - # Check that the KV caches were shared correctly - assert kv_caches["model.layers.2"].data_ptr( - ) == kv_caches["model.layers.0"].data_ptr() - assert kv_caches["model.layers.3"].data_ptr( - ) == kv_caches["model.layers.1"].data_ptr() - - # Check that the layers were added to the correct KV cache group - assert len(kv_cache_groups) == 1 - assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" - ] - - # Check that the layers were added to the attention groups - assert len(attn_groups) == 1 and len(attn_groups[0]) == 2 - assert attn_groups[0][0].layer_names == [ - "model.layers.0", "model.layers.2" - ] - assert attn_groups[0][1].layer_names == [ - "model.layers.1", "model.layers.3" - ] - - -def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): - """ - Test case assuming that all layers in the same KV cache group have the same - attention backends. This is true for most models. 
- """ - shared_kv_cache_layers = { - "model.layers.2": "model.layers.0", - "model.layers.3": "model.layers.1", - } - - kv_cache_groups = [ - KVCacheGroupSpec(["model.layers.0", "model.layers.1"], - new_kv_cache_spec()), - ] - - attn_groups = [ - # KV cache group 0 has a single attention group - # as all layers have the same flash attention backend - [ - AttentionGroup( - backend=FlashAttentionBackend, - metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), - layer_names=["model.layers.0", "model.layers.1"], - ), - ], - ] - - kv_caches = { - "model.layers.0": torch.zeros(1, 2, 3), - "model.layers.1": torch.ones(1, 2, 3), - } - - initialize_kv_cache_for_kv_sharing( - shared_kv_cache_layers=shared_kv_cache_layers, - kv_cache_groups=kv_cache_groups, - kv_caches=kv_caches, - attn_groups=attn_groups, - ) - - # Check that the KV caches were shared correctly - assert kv_caches["model.layers.2"].data_ptr( - ) == kv_caches["model.layers.0"].data_ptr() - assert kv_caches["model.layers.3"].data_ptr( - ) == kv_caches["model.layers.1"].data_ptr() - - # Check that the layers were added to the correct KV cache group - assert len(kv_cache_groups) == 1 - assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" - ] - - # Check that the layers were added to the attention groups - assert len(attn_groups) == 1 and len(attn_groups[0]) == 1 - assert attn_groups[0][0].layer_names == [ - "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" - ] - - -def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): - """ - Test KV sharing set up when no attention groups are provided. - This is the case for the TPU model runner, which doesn't have - support for attention groups yet. - """ - shared_kv_cache_layers = { - "model.layers.2": "model.layers.0", - "model.layers.3": "model.layers.1", - } - - kv_cache_groups = [ - KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()), - KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()), - ] - - kv_caches = { - "model.layers.0": torch.zeros(1, 2, 3), - "model.layers.1": torch.ones(1, 2, 3), - } - - initialize_kv_cache_for_kv_sharing( - shared_kv_cache_layers=shared_kv_cache_layers, - kv_cache_groups=kv_cache_groups, - kv_caches=kv_caches, - ) - - # Check that the KV caches were shared correctly - assert kv_caches["model.layers.2"].data_ptr( - ) == kv_caches["model.layers.0"].data_ptr() - assert kv_caches["model.layers.3"].data_ptr( - ) == kv_caches["model.layers.1"].data_ptr() - - # Check that the layers were added to the correct KV cache group - assert len(kv_cache_groups) == 2 - assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.2" - ] - assert kv_cache_groups[1].layer_names == [ - "model.layers.1", "model.layers.3" - ] diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1f16e92f657e0..5d3bb924590ad 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -7,34 +7,16 @@ import pytest import vllm.envs as envs from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine - -UNSUPPORTED_MODELS_V1 = [ - "openai/whisper-large-v3", # transcription - "facebook/bart-large-cnn", # encoder decoder -] MODEL = "meta-llama/Llama-3.2-1B-Instruct" -@pytest.mark.parametrize("model", UNSUPPORTED_MODELS_V1) -def test_reject_unsupported_models(monkeypatch, model): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - args = AsyncEngineArgs(model=model) - - with 
pytest.raises(NotImplementedError): - _ = args.create_engine_config() - m.delenv("VLLM_USE_V1") - - def test_reject_bad_config(monkeypatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") def test_unsupported_configs(monkeypatch): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -46,24 +28,6 @@ def test_unsupported_configs(monkeypatch): }, ).create_engine_config() - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - preemption_mode="swap", - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - disable_async_output_proc=True, - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - scheduler_delay_factor=1.2, - ).create_engine_config() - def test_enable_by_default_fallback(monkeypatch): with monkeypatch.context() as m: @@ -78,12 +42,6 @@ def test_enable_by_default_fallback(monkeypatch): assert envs.VLLM_USE_V1 m.delenv("VLLM_USE_V1") - # Should fall back to V0 for supported model. - _ = AsyncEngineArgs( - model=UNSUPPORTED_MODELS_V1[0]).create_engine_config() - assert not envs.VLLM_USE_V1 - m.delenv("VLLM_USE_V1") - def test_v1_llm_by_default(monkeypatch): with monkeypatch.context() as m: @@ -95,43 +53,3 @@ def test_v1_llm_by_default(monkeypatch): print(llm.generate("Hello my name is")) assert hasattr(llm.llm_engine, "engine_core") m.delenv("VLLM_USE_V1") - - -def test_v1_attn_backend(monkeypatch): - with monkeypatch.context() as m: - if os.getenv("VLLM_USE_V1", None): - m.delenv("VLLM_USE_V1") - m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") - - # Fall back to V0. - _ = AsyncEngineArgs(model=MODEL).create_engine_config() - assert not envs.VLLM_USE_V1 - m.delenv("VLLM_USE_V1") - - # Reject if V1. - m.setenv("VLLM_USE_V1", "1") - with pytest.raises(NotImplementedError): - AsyncEngineArgs(model=MODEL).create_engine_config() - m.delenv("VLLM_USE_V1") - - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHMLA") - _ = AsyncEngineArgs(model=MODEL).create_engine_config() - assert envs.VLLM_USE_V1 - m.delenv("VLLM_USE_V1") - - -def test_reject_using_constructor_directly(monkeypatch): - with monkeypatch.context() as m: - if os.getenv("VLLM_USE_V1", None): - m.delenv("VLLM_USE_V1") - - # Sets VLLM_USE_V1=1. - vllm_config = AsyncEngineArgs(model=MODEL).create_engine_config() - - # This uses the V0 constructor directly. 
- with pytest.raises(ValueError): - AsyncLLMEngine(vllm_config, - AsyncLLMEngine._get_executor_cls(vllm_config), - log_stats=True) - - m.delenv("VLLM_USE_V1") diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 118b40d0ef418..a306a2b040d3a 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,16 +9,21 @@ import numpy as np import pytest import torch -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, NestedTensors) +from vllm.multimodal.inputs import ( + MultiModalBatchedField, + MultiModalFieldElem, + MultiModalFlatField, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, + NestedTensors, +) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +pytestmark = pytest.mark.cpu_test + class UnrecognizedType(UserDict): - def __init__(self, an_int: int): super().__init__() self.an_int = an_int @@ -45,10 +50,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") obj = MyType( - tensor1=torch.randint(low=0, - high=100, - size=(1024, ), - dtype=torch.int32), + tensor1=torch.randint(low=0, high=100, size=(1024,), dtype=torch.int32), a_string="hello", list_of_tensors=[ torch.rand((1, 10), dtype=torch.float32), @@ -56,8 +58,9 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): torch.tensor(1984), # test scalar too # Make sure to test bf16 which numpy doesn't support. torch.rand((3, 5, 1000), dtype=torch.bfloat16), - torch.tensor([float("-inf"), float("inf")] * 1024, - dtype=torch.bfloat16), + torch.tensor( + [float("-inf"), float("inf")] * 1024, dtype=torch.bfloat16 + ), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -101,22 +104,24 @@ class MyRequest(msgspec.Struct): def test_multimodal_kwargs(): - e1 = MultiModalFieldElem("audio", "a0", - torch.zeros(1000, dtype=torch.bfloat16), - MultiModalBatchedField()) + e1 = MultiModalFieldElem( + "audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField() + ) e2 = MultiModalFieldElem( "video", "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], - MultiModalFlatField( - [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), + MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), + ) + e3 = MultiModalFieldElem( + "image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4) ) - e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, - dtype=torch.int32), - MultiModalSharedField(4)) e4 = MultiModalFieldElem( - "image", "i1", torch.zeros(1000, dtype=torch.int32), - MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2)) + "image", + "i1", + torch.zeros(1000, dtype=torch.int32), + MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2), + ) audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) @@ -162,16 +167,14 @@ def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.tensor1, obj2.tensor1) assert obj1.a_string == obj2.a_string assert all( - torch.equal(a, b) - for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) + torch.equal(a, b) for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors) + ) assert np.array_equal(obj1.numpy_array, obj2.numpy_array) assert obj1.unrecognized.an_int == obj2.unrecognized.an_int assert torch.equal(obj1.small_f_contig_tensor, 
obj2.small_f_contig_tensor) assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor) - assert torch.equal(obj1.small_non_contig_tensor, - obj2.small_non_contig_tensor) - assert torch.equal(obj1.large_non_contig_tensor, - obj2.large_non_contig_tensor) + assert torch.equal(obj1.small_non_contig_tensor, obj2.small_non_contig_tensor) + assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) @@ -208,8 +211,9 @@ def test_tensor_serialization(): decoded = decoder.decode(encoded) # Verify the decoded tensor matches the original - assert torch.allclose( - tensor, decoded), "Decoded tensor does not match the original tensor." + assert torch.allclose(tensor, decoded), ( + "Decoded tensor does not match the original tensor." + ) def test_numpy_array_serialization(): @@ -227,13 +231,12 @@ def test_numpy_array_serialization(): decoded = decoder.decode(encoded) # Verify the decoded array matches the original - assert np.allclose( - array, - decoded), "Decoded numpy array does not match the original array." + assert np.allclose(array, decoded), ( + "Decoded numpy array does not match the original array." + ) class CustomClass: - def __init__(self, value): self.value = value @@ -242,7 +245,8 @@ class CustomClass: def test_custom_class_serialization_allowed_with_pickle( - monkeypatch: pytest.MonkeyPatch): + monkeypatch: pytest.MonkeyPatch, +): """Test that serializing a custom class succeeds when allow_pickle=True.""" with monkeypatch.context() as m: @@ -259,8 +263,7 @@ def test_custom_class_serialization_allowed_with_pickle( decoded = decoder.decode(encoded) # Verify the decoded object matches the original - assert obj == decoded, ( - "Decoded object does not match the original object.") + assert obj == decoded, "Decoded object does not match the original object." def test_custom_class_serialization_disallowed_without_pickle(): diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index 865b58bc7f4b0..f3495b00d3d4c 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -4,6 +4,7 @@ Run `pytest tests/v1/tpu/test_basic.py`. 
""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -32,51 +33,51 @@ MAX_NUM_REQS = [16, 1024] # TENSOR_PARALLEL_SIZES = [1, 4] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS) def test_basic( vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, model: str, max_tokens: int, tensor_parallel_size: int, max_num_seqs: int, ) -> None: - prompt = "The next numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The next numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + with vllm_runner( + model, + # Note: max_num_batched_tokens == 1024 is needed here to + # actually test chunked prompt + max_num_batched_tokens=1024, + max_model_len=8192, + gpu_memory_utilization=0.7, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + output = vllm_outputs[0][1] - with vllm_runner( - model, - # Note: max_num_batched_tokens == 1024 is needed here to - # actually test chunked prompt - max_num_batched_tokens=1024, - max_model_len=8192, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - output = vllm_outputs[0][1] - - assert "1024" in output or "0, 1" in output + assert "1024" in output or "0, 1" in output @pytest.mark.skip(reason="Temporarily disabled due to timeout") -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) @pytest.mark.parametrize("max_tokens", [8]) @pytest.mark.parametrize("max_num_seqs", [16]) def test_phi3( vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, max_tokens: int, max_num_seqs: int, ) -> None: @@ -93,30 +94,27 @@ def test_phi3( # test head dim = 96 model = "microsoft/Phi-3-mini-128k-instruct" - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - with vllm_runner(model, - max_num_batched_tokens=256, - max_num_seqs=max_num_seqs) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) - # vllm_outputs is a list of tuples whose first element is the token id - # and the second element is the output (including the prompt). - for output, answer in zip(vllm_outputs, answers): - generated_text = output[1] - assert answer in generated_text + with vllm_runner( + model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + # vllm_outputs is a list of tuples whose first element is the token id + # and the second element is the output (including the prompt). 
+ for output, answer in zip(vllm_outputs, answers): + generated_text = output[1] + assert answer in generated_text TP_SIZE_8 = 8 -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a test for TPU only") -@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8, - reason=f"This test requires {TP_SIZE_8} TPU chips.") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") +@pytest.mark.skipif( + tpu.num_available_chips() < TP_SIZE_8, + reason=f"This test requires {TP_SIZE_8} TPU chips.", +) def test_gemma3_27b_with_text_input_and_tp( vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, ) -> None: model = "google/gemma-3-27b-it" max_tokens = 16 @@ -133,49 +131,47 @@ def test_gemma3_27b_with_text_input_and_tp( " but in rising every time we fall.", ] - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - - with vllm_runner( - model, - max_num_batched_tokens=256, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) - # vllm_outputs is a list of tuples whose first element is the token id - # and the second element is the output (including the prompt). - for output, answer in zip(vllm_outputs, answers): - generated_text = output[1] - assert answer in generated_text + with vllm_runner( + model, + max_num_batched_tokens=256, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + # vllm_outputs is a list of tuples whose first element is the token id + # and the second element is the output (including the prompt). + for output, answer in zip(vllm_outputs, answers): + generated_text = output[1] + assert answer in generated_text -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) def test_w8a8_quantization( vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, ) -> None: model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8" max_tokens = 5 tensor_parallel_size = 1 max_num_seqs = 4 - prompt = "The next numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The next numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + with vllm_runner( + model, + max_num_batched_tokens=64, + max_model_len=4096, + gpu_memory_utilization=0.7, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + output = vllm_outputs[0][1] - with vllm_runner( - model, - max_num_batched_tokens=64, - max_model_len=4096, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - output = vllm_outputs[0][1] - - assert "1024" in output or "0, 1" in output + assert "1024" in output or "0, 1" in output diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py index acb607247d754..99d5f98351ad2 100644 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -10,61 +10,69 @@ import 
vllm.v1.attention.backends.pallas # noqa: F401 from vllm.platforms import current_platform -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a test for TPU only") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") @pytest.mark.parametrize("page_size", [32, 33]) @pytest.mark.parametrize("combined_kv_head_num", [2, 16]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("num_slices_per_block", [4, 8]) -def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, - head_dim: int, num_slices_per_block: int): +def test_kv_cache_update_kernel( + page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int +): page_num = 1000 padded_num_tokens = 128 kv_cache_cpu = torch.zeros( (page_num * page_size, combined_kv_head_num, head_dim), dtype=torch.bfloat16, - device="cpu") + device="cpu", + ) kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) new_kv_cpu = torch.randn( (padded_num_tokens, combined_kv_head_num, head_dim), dtype=torch.bfloat16, - device="cpu") + device="cpu", + ) new_kv_xla = new_kv_cpu.to(torch_xla.device()) - slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], - dtype=np.int32) + slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32) num_kv_update_slices = len(slice_lens) - kv_cache_start_indices = np.array([ - page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, - page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 - ], - dtype=np.int32) + kv_cache_start_indices = np.array( + [ + page_size * 2 - 7, + page_size * 2, + page_size * 3, + page_size * 4 + 6, + page_size * 5 + 7, + page_size * 6 + 8, + page_size * 15 + 3, + ], + dtype=np.int32, + ) new_kv_cache_indices = np.concatenate( - [np.array([0], dtype=np.int32), - np.cumsum(slice_lens[:-1])]) + [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])] + ) slot_mapping = np.stack( - [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) + [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1 + ) slot_mapping = np.transpose(slot_mapping) - slot_mapping_cpu = torch.tensor(slot_mapping, - device="cpu", - dtype=torch.int32) + slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32) slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) - num_kv_update_slices_xla = torch.tensor([num_kv_update_slices], - device=torch_xla.device(), - dtype=torch.int32) + num_kv_update_slices_xla = torch.tensor( + [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32 + ) torch_xla.sync() torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( - new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla, - page_size, num_slices_per_block) + new_kv_xla, + slot_mapping_xla, + kv_cache_xla, + num_kv_update_slices_xla, + page_size, + num_slices_per_block, + ) kv_cache_xla.copy_(new_kv_cache_xla) torch_xla.sync() - for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, - slice_lens): - kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] + for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens): + kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :] - assert torch.allclose(kv_cache_xla.cpu(), - kv_cache_cpu, - atol=1e-4, - rtol=1e-4) + assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4) diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py index 
9d690851b70eb..5debdf85bea8d 100644 --- a/tests/v1/tpu/test_mha_attn.py +++ b/tests/v1/tpu/test_mha_attn.py @@ -19,8 +19,7 @@ from vllm.platforms import current_platform @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() @@ -49,8 +48,7 @@ NUM_KV_HEADS = [1] HEAD_SIZES = [64, 80] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -68,19 +66,12 @@ def test_mha_attn_forward( current_platform.seed_everything(0) # These are expected to be f32 q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device) - k = torch.randn(batch_size, - seq_len, - num_kv_heads * head_size, - device=device) - v = torch.randn(batch_size, - seq_len, - num_kv_heads * head_size, - device=device) + k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) + v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads) + attn = MultiHeadAttention( + num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads + ) output = attn(q, k, v) assert num_heads % num_kv_heads == 0 diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index bcc2993028dd6..5bf823417d4dc 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -4,47 +4,42 @@ import openai import pytest -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 from vllm.platforms import current_platform -from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS from ...utils import RemoteOpenAIServer @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: encode_image_base64( + local_asset_server.get_image_asset(image_asset) + ) + for image_asset in TEST_IMAGE_ASSETS } @pytest.mark.asyncio -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") @pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) -async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, - str]): - +async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str]): pytest.skip("Skip this test until it's fixed.") def whats_in_this_image_msg(b64): - return [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" 
- }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{b64}" + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, }, - }, - ], - }] + ], + } + ] server_args = [ "--max-model-len", @@ -61,19 +56,20 @@ async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, ] # Server will pre-compile on first startup (takes a long time). - with RemoteOpenAIServer(model_name, server_args, - max_wait_seconds=600) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, max_wait_seconds=600 + ) as remote_server: client: openai.AsyncOpenAI = remote_server.get_async_client() # Other requests now should be much faster - for image_url in TEST_IMAGE_URLS: + for image_url in TEST_IMAGE_ASSETS: image_base64 = base64_encoded_image[image_url] - chat_completion_from_base64 = await client.chat.completions\ - .create( + chat_completion_from_base64 = await client.chat.completions.create( model=model_name, messages=whats_in_this_image_msg(image_base64), max_completion_tokens=24, - temperature=0.0) + temperature=0.0, + ) result = chat_completion_from_base64 assert result choice = result.choices[0] diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index bfba3af57f715..0a994e99bade1 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -5,8 +5,7 @@ from unittest.mock import ANY, patch import torch from vllm.attention.backends.abstract import AttentionType -from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl, - PallasMetadata) +from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata def test_ragged_paged_attention(): @@ -33,10 +32,12 @@ def test_ragged_paged_attention(): ) class FakeAttentionLayer: + _q_scale_float: float _k_scale_float: float _v_scale_float: float layer = FakeAttentionLayer() + layer._q_scale_float = 1.0 layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 @@ -51,14 +52,14 @@ def test_ragged_paged_attention(): max_num_reqs = 8 max_num_blocks_per_req = 8 num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32) - block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), - dtype=torch.int32) - context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32) + block_tables = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), dtype=torch.int32 + ) + context_lens = torch.ones((max_num_reqs,), dtype=torch.int32) query_lens = [1] * max_num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32) + query_start_loc = torch.cumsum( + torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ) num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -70,8 +71,7 @@ def test_ragged_paged_attention(): num_slices_per_kv_cache_update_block=8, ) - with patch("torch.ops.xla.ragged_paged_attention" - ) as mock_ragged_paged_attention: + with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention: attn_impl.forward( layer=layer, query=query, diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py index f4a2d5ac853a8..b7b6835c40ccb 100644 --- a/tests/v1/tpu/test_perf.py +++ b/tests/v1/tpu/test_perf.py @@ -4,6 +4,7 @@ Run `pytest tests/v1/tpu/test_perf.py`. 
""" + from __future__ import annotations import time @@ -37,7 +38,6 @@ TEST_PARAMS = [ # open(/dev/vfio/0): Device or resource busy: Device or resource busy; # Couldn't open iommu group /dev/vfio/0 # => Investigate - # TestParams( # model="Qwen/Qwen2.5-1.5B-Instruct", # num_prompts=1, @@ -59,16 +59,14 @@ TEST_PARAMS = [ num_prompts=64, prefix_len=500, decode_len=50, - # commit id: ccb246776d93ef105904a8ec015b3587240a1183 # tpu: v5lite (old vllm CI/CD) # expected_avg_time=1.4, # err_tol=0.30, - # (This is the active CI/CD instance) # commit id: ccb246776d93ef105904a8ec015b3587240a1183 # tpu: v6e (current vllm CI/CD) - expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH= + expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH= err_tol=0.20, ), ] @@ -81,66 +79,72 @@ MAX_NUM_SEQS = 32 GPU_UTIL = 0.9 -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic performance test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), + reason="This is a basic performance test for TPU only", +) @pytest.mark.parametrize("params", TEST_PARAMS) def test_perf( vllm_runner: type[VllmRunner], - monkeypatch: pytest.MonkeyPatch, params: TestParams, ) -> None: - tokenizer = get_tokenizer(params.model, - tokenizer_mode="auto", - trust_remote_code=True) + tokenizer = get_tokenizer( + params.model, tokenizer_mode="auto", trust_remote_code=True + ) prompts = [] for i in range(params.num_prompts): - prefix_token_ids = np.random.randint(0, - tokenizer.vocab_size, - size=params.prefix_len).tolist() + prefix_token_ids = np.random.randint( + 0, tokenizer.vocab_size, size=params.prefix_len + ).tolist() prompt = tokenizer.decode(prefix_token_ids) prompts.append(prompt) print( "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format( - len(prompts), params.prefix_len, params.decode_len)) + len(prompts), params.prefix_len, params.decode_len + ) + ) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") + sampling_params = SamplingParams( + max_tokens=params.decode_len, temperature=1.0, min_p=0.0 + ) - sampling_params = SamplingParams(max_tokens=params.decode_len, - temperature=1.0, - min_p=0.0) + with vllm_runner( + params.model, + max_num_batched_tokens=MAX_MODEL_LEN, + max_model_len=MAX_MODEL_LEN, + max_num_seqs=MAX_NUM_SEQS, + gpu_memory_utilization=GPU_UTIL, + enforce_eager=False, + tensor_parallel_size=1, + ) as vllm_model: + print(" -- Warmup / Compile") + for i in range(NUM_WARMUPS): + _ = vllm_model.generate(prompts, sampling_params) - with vllm_runner(params.model, - max_num_batched_tokens=MAX_MODEL_LEN, - max_model_len=MAX_MODEL_LEN, - max_num_seqs=MAX_NUM_SEQS, - gpu_memory_utilization=GPU_UTIL, - enforce_eager=False, - tensor_parallel_size=1) as vllm_model: - print(" -- Warmup / Compile") - for i in range(NUM_WARMUPS): - _ = vllm_model.generate(prompts, sampling_params) + print(" -- Benchmarking... ") + times = [] + for i in range(NUM_RUNS): + start_time = time.time() + _ = vllm_model.generate(prompts, sampling_params) + times.append(time.time() - start_time) - print(" -- Benchmarking... 
") - times = [] - for i in range(NUM_RUNS): - start_time = time.time() - _ = vllm_model.generate(prompts, sampling_params) - times.append(time.time() - start_time) + avg_time = sum(times) / len(times) - avg_time = sum(times) / len(times) + print(" -- avg_time = {}".format(avg_time)) + print( + " -- expected_avg_time = {} with err_tol = {}".format( + params.expected_avg_time, params.err_tol + ) + ) + diff = avg_time - params.expected_avg_time + ok = diff < params.err_tol + if diff < -params.err_tol: + print( + " !! WARNING !! Performance has improved by {}, " + "it may be necessary to fine-tune the " + "expected_avg_time = {}".format(-diff, params.expected_avg_time) + ) - print(" -- avg_time = {}".format(avg_time)) - print(" -- expected_avg_time = {} with err_tol = {}".format( - params.expected_avg_time, params.err_tol)) - diff = avg_time - params.expected_avg_time - ok = diff < params.err_tol - if diff < -params.err_tol: - print(" !! WARNING !! Performance has improved by {}, " - "it may be necessary to fine-tune the " - "expected_avg_time = {}".format( - -diff, params.expected_avg_time)) - - assert ok, " !! ERROR !! Regression detected" + assert ok, " !! ERROR !! Regression detected" diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index fa950e5f7f85b..58f6292b05a72 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -10,21 +10,20 @@ from vllm.sampling_params import SamplingParams @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") def test_sampler_different(model_name: str): """ - Test significantly different sampling params to assert the model produces + Test significantly different sampling params to assert the model produces different results. """ - llm = LLM(model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=512, - max_num_batched_tokens=256) - prompts = [ - "Write a short story about a robot that dreams for the first time." - ] + llm = LLM( + model_name, + enforce_eager=False, + max_num_seqs=1, + max_model_len=512, + max_num_batched_tokens=256, + ) + prompts = ["Write a short story about a robot that dreams for the first time."] sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64) output = llm.generate(prompts, sampling_params) @@ -47,7 +46,9 @@ def test_sampler_different(model_name: str): max_tokens=64, # Vary number of ks top_k=random.randint(4, 12), - top_p=random.random()) for _ in range(B) + top_p=random.random(), + ) + for _ in range(B) ] # Make sure first two reqs have the same K/P sampling_params[0] = sampling_params[1] @@ -61,20 +62,18 @@ def test_sampler_different(model_name: str): @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) # TODO TPU will appear busy if we fan-out test params here @pytest.mark.parametrize("n_prompts", [1]) -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") def test_logprobs(model_name: str, n_prompts: int): """ Request top logprobs with different sampling settings and check - that results contains the requested number, ordered ascendingly. + that results contains the requested number, ordered ascendingly. 
""" def check_num_logprobs(logprobs, expected_num: int): for step in logprobs: prev_logp = 1.0 # order by rank - sorted_step = dict( - sorted(step.items(), key=lambda item: item[1].rank)) + sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank)) # Can contain the sampled token assert len(step) == expected_num or len(step) == expected_num + 1 @@ -84,23 +83,23 @@ def test_logprobs(model_name: str, n_prompts: int): prev_logp = logp.logprob assert logp.rank == rankno + 1 - llm = LLM(model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=128, - max_num_batched_tokens=128) + llm = LLM( + model_name, + enforce_eager=False, + max_num_seqs=1, + max_model_len=128, + max_num_batched_tokens=128, + ) prompts = [ "Write a short story about a robot that dreams for the first time." ] * n_prompts - greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\ - logprobs=4) - regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4) - topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4, top_k=12, top_p=0.5) + greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4) + regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4) + topkp_sampling_params = SamplingParams( + temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5 + ) - for sp in [greedy_sampling_params, regular_sampling_params, \ - topkp_sampling_params]: + for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]: output = llm.generate(prompts, sp) for o in output: check_num_logprobs(o.outputs[0].logprobs, 4) diff --git a/tests/v1/tpu/test_spmd_model_weight_loading.py b/tests/v1/tpu/test_spmd_model_weight_loading.py index ad234df0c8ed7..be866bf90a792 100644 --- a/tests/v1/tpu/test_spmd_model_weight_loading.py +++ b/tests/v1/tpu/test_spmd_model_weight_loading.py @@ -9,14 +9,18 @@ import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tpu import TPUModelLoader def _setup_environment(model): - engine_args = EngineArgs(model=model, ) + engine_args = EngineArgs( + model=model, + ) vllm_config = engine_args.create_engine_config() with set_current_vllm_config(vllm_config): temp_file = tempfile.mkstemp()[1] @@ -25,7 +29,8 @@ def _setup_environment(model): 0, local_rank=0, distributed_init_method=f"file://{temp_file}", - backend="gloo") + backend="gloo", + ) # Under single worker mode, full model is init first and then # partitioned using GSPMD. 
ensure_model_parallel_initialized(1, 1) @@ -42,7 +47,7 @@ def _get_spmd_mesh(): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) return MESH @@ -53,15 +58,17 @@ def _get_spmd_mesh(): # Skip large models due to CI runner disk space limitations # "meta-llama/Llama-3.1-8B-Instruct", # "meta-llama/Llama-3.1-70B-Instruct", - ]) + ], +) def test_tpu_model_loader(model): # Skip the 70B test if there are less than 8 chips # TODO: Query using torch xla API, the query API is not working # with SPMD now. However, This test is running under SPMD mode. - if '70B' in model and xr.global_runtime_device_count() < 8: + if "70B" in model and xr.global_runtime_device_count() < 8: pytest.skip( "Skipping 70B model if the TPU VM has less than 8 chips to \ - avoid OOM.") + avoid OOM." + ) vllm_config = _setup_environment(model) loader = TPUModelLoader(load_config=vllm_config.load_config) diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index ca5c067b364e0..c2fc24442c7cd 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -4,10 +4,14 @@ import math import pytest import torch +import torch_xla from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_tpu) +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p + +# isort: off +from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu +# isort: on if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) @@ -25,11 +29,10 @@ def test_topk_equivalence_to_native_impl(): logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) # Random top-k values between 1 and 10. - k = torch.randint(1, 10, (BATCH_SIZE, )) + k = torch.randint(1, 10, (BATCH_SIZE,)) # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), - VOCAB_SIZE) + k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE) result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) @@ -45,21 +48,19 @@ def test_topp_result_sums_past_p(): probs = logits.softmax(dim=-1) # Random top-p values between 0 and 1. - p = torch.rand((BATCH_SIZE, )) + p = torch.rand((BATCH_SIZE,)) # Set p=1 for ~50% of requests in the batch (top-p disabled). - p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1) + p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1) no_op_k = torch.tensor([VOCAB_SIZE]) - logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), - k=no_op_k, - p=p) + logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p) # Verify that the masked logit's probability sums to at least p. probs.masked_fill_(logits_masked.isinf(), 0) masked_prob_sum = probs.sum(dim=-1) - xm.mark_step() + torch_xla.sync() # Perform assertion on CPU. 
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu())) @@ -67,18 +68,18 @@ def test_topp_result_sums_past_p(): def test_topp_basic(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([3, 3]), - p=torch.tensor([0.79, 0.79])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79]) + ) - xm.mark_step() + torch_xla.sync() # Expect the smallest elements to be dropped. expected_result = logits.clone().cpu() @@ -89,18 +90,18 @@ def test_topp_basic(): def test_topp_select_all(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([3, 3]), - p=torch.tensor([1.0, 1.0])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0]) + ) - xm.mark_step() + torch_xla.sync() assert torch.allclose(logits.cpu(), result.cpu()) @@ -109,16 +110,14 @@ def test_topp_with_ties(): with torch.device(xm.xla_device()): # Input has multiple math.log(0.3). logits = torch.tensor( - [[math.log(0.3), - math.log(0.3), - math.log(0.3), - math.log(0.1)]]) + [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([4]), - p=torch.tensor([0.2])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2]) + ) - xm.mark_step() + torch_xla.sync() # All tie values are included in the top-p set. Tie breaking is left # to be done during final sampling (all tie tokens have equal @@ -130,19 +129,19 @@ def test_topp_with_ties(): def test_both_topk_topp(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) # Set k=1 for the first batch. - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([1, 3]), - p=torch.tensor([0.79, 0.79])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79]) + ) - xm.mark_step() + torch_xla.sync() # Since for the first batch k=1, expect only the largest element gets # selected. diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py index 991070dc9239d..50001567a9588 100644 --- a/tests/v1/tpu/test_tpu_int8.py +++ b/tests/v1/tpu/test_tpu_int8.py @@ -4,11 +4,11 @@ Run `pytest tests/quantization/test_tpu_int8.py`. 
""" + import pytest from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization.tpu_int8 import ( - TPUInt8LinearMethod) +from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod from vllm.platforms import current_platform from ...models.registry import HF_EXAMPLE_MODELS @@ -16,8 +16,9 @@ from ...models.registry import HF_EXAMPLE_MODELS MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="TPU Int8 is only enabled for TPUs.") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs." +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -26,20 +27,28 @@ MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] [ # w8a8 dynamic activation { - 'quantization_config': { - 'quant_method': 'tpu_int8', - 'activation_scheme': 'dynamic' + "quantization_config": { + "quant_method": "tpu_int8", + "activation_scheme": "dynamic", } } - ]) -def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, - hf_overrides: dict, monkeypatch) -> None: + ], +) +def test_model_tpu_int8( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, + hf_overrides: dict, + monkeypatch, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_transformers_version(on_fail="skip") - activation_scheme = hf_overrides.get('quantization_config', - {}).get('activation_scheme') - quantize_activation = activation_scheme == 'dynamic' + activation_scheme = hf_overrides.get("quantization_config", {}).get( + "activation_scheme" + ) + quantize_activation = activation_scheme == "dynamic" # Allows using apply_model monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") @@ -48,13 +57,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, prompts = [ "A robot may not injure a human being", - "It is only with the heart that one can see rightly;", - "The greatest glory in living lies not in never falling,", ] answers = [ - "or, being injured, not kill, except in", - "without the heart, one can only see wrongly.", - "but in rising every time we fall. 
- Nelson" + "or kill a human being", ] with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm: diff --git a/tests/v1/tpu/test_tpu_qkv_linear.py b/tests/v1/tpu/test_tpu_qkv_linear.py index 46fa1193881fa..098d925505424 100644 --- a/tests/v1/tpu/test_tpu_qkv_linear.py +++ b/tests/v1/tpu/test_tpu_qkv_linear.py @@ -9,8 +9,10 @@ import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.linear import QKVParallelLinear @@ -36,7 +38,8 @@ def setup_environment(): 0, local_rank=0, distributed_init_method=f"file://{temp_file}", - backend="gloo") + backend="gloo", + ) ensure_model_parallel_initialized(1, 1) yield @@ -51,7 +54,7 @@ def _get_spmd_mesh(): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) return MESH @@ -59,7 +62,7 @@ def _get_spmd_mesh(): # `xr.use_spmd()` will set a global state, and this state is not reversible. # Therefore, non-SPMD tests should be run before SPMD tests. @pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()]) -@pytest.mark.parametrize("device", ['cpu', 'xla']) +@pytest.mark.parametrize("device", ["cpu", "xla"]) @torch.no_grad() def test_xla_qkv_linear(bias, mesh, device): torch.manual_seed(123) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 5a05781a03f2a..df9fcdc37fa37 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -4,18 +4,25 @@ import pytest from vllm.attention.layer import Attention -from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CacheConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) +from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.tpu_model_runner import ( - TPUModelRunner, _get_padded_num_reqs_with_upper_limit, - _get_padded_token_len, _get_req_paddings, _get_token_paddings) + TPUModelRunner, + _get_padded_num_reqs_with_upper_limit, + _get_padded_token_len, + _get_req_paddings, + _get_token_paddings, +) def get_vllm_config(): @@ -64,15 +71,14 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], sampling_params=SamplingParams(), pooling_params=PoolingParams(), - block_ids=([0], ), # block_ids should be tuple[list[int]] + block_ids=([0],), # block_ids should be tuple[list[int]] num_computed_tokens=0, 
lora_request=None, - )) + ) + ) num_scheduled_tokens[req_id] = 3 total_num_scheduled_tokens += num_scheduled_tokens[req_id] @@ -85,7 +91,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -101,7 +107,7 @@ def _is_req_added(model_runner, req_id: str) -> bool: def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: """Check if the request state block IDs match the block table. - + This function handles both legacy BlockTable and new MultiGroupBlockTable structures for backward compatibility. """ @@ -127,7 +133,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: return False num_blocks = block_table.num_blocks_per_row[req_index] - block_table_values = block_table.block_table_np[req_index, :num_blocks] + block_table_values = block_table.block_table.np[req_index, :num_blocks] return (block_table_values == req_block_ids).all() @@ -164,7 +170,7 @@ def test_update_states_request_finished(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -194,7 +200,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -208,7 +214,7 @@ def test_update_states_request_resumed(model_runner): req_ids=[req_id], resumed_from_preemption=[False], new_token_ids=[[]], - new_block_ids=[([], )], + new_block_ids=[([],)], num_computed_tokens=[0], ) @@ -221,7 +227,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -252,7 +258,7 @@ def test_update_states_no_changes(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -287,7 +293,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -305,27 +311,23 @@ def test_get_paddings(): # Bucketed padding min_token_size, max_token_size, padding_gap = 16, 512, 64 expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) # Bucketed padding with max_token_size not a power of two. max_token_size = 317 expected_paddings = [16, 32, 64, 128, 192, 256, 320] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings # Exponential padding. 
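    # Illustrative note, inferred from the expected lists in this test: with a
    # non-zero padding_gap the buckets double only up to 2 * padding_gap and
    # then grow linearly in steps of padding_gap (16, 32, 64, 128, 192, 256,
    # ...), whereas with padding_gap == 0 they keep doubling until they cover
    # max_token_size (16, 32, 64, ..., 1024).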
max_token_size, padding_gap = 1024, 0 expected_paddings = [16, 32, 64, 128, 256, 512, 1024] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings # Exponential padding with max_token_size not a power of two. max_token_size = 317 expected_paddings = [16, 32, 64, 128, 256, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings @@ -352,32 +354,31 @@ def test_get_req_paddings(): assert _get_req_paddings(8, 36) == [8, 16, 32, 36] -def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order( - model_runner): +def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner): layer_0 = "model.layers.0.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn" error_msg = f"{layer_1} must come before the current layer" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, kv_sharing_target_layer_name=layer_1, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -389,25 +390,25 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner): invalid_layer = "model.layers.0.cross_attn.attn" error_msg = f"{invalid_layer} is not a valid Attention layer in the model" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, # invalid layer: cross_attn.atn doesn't exist! 
kv_sharing_target_layer_name=invalid_layer, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -418,26 +419,26 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner): layer_1 = "model.layers.1.self_attn.attn" error_msg = f"{layer_1} cannot be the same as the current layer" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -449,20 +450,18 @@ def test_init_kv_cache_without_kv_sharing(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -477,17 +476,17 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for each layer KV can be calculated as # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim) - # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB + # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without # max_context_len = available_memory / (page_size / block_size) / num_caches # max_context_len = 5GB / (512KB / 128) / 2 = 655360 @@ -497,8 +496,9 @@ def test_init_kv_cache_without_kv_sharing(): # this will only allocate 2 block worth of memory (2 * 512kb) kv_cache_config.num_blocks = 1 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = ( - kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) + kv_cache_tensor.size = kv_cache_spec[ + kv_cache_tensor.shared_by[0] + ].page_size_bytes model_runner.initialize_kv_cache(kv_cache_config) @@ -520,21 +520,19 @@ def test_init_kv_cache_with_kv_sharing_valid(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name="model.layers.0.self_attn.attn", - ) + ), } # suppress var not used error assert 
fwd_context is not None @@ -552,24 +550,23 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 2 * 20480 # 20GB / 512KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # compared to no KV sharing assert kv_cache_config.kv_cache_tensors[0].size == available_memory - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == (2 * 655360) # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (512kb) kv_cache_config.num_blocks = 1 - kv_cache_config.kv_cache_tensors[0].size =\ - kv_cache_spec[layer_0].page_size_bytes + kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes model_runner.initialize_kv_cache(kv_cache_config) diff --git a/tests/worker/__init__.py b/tests/v1/tracing/__init__.py similarity index 100% rename from tests/worker/__init__.py rename to tests/v1/tracing/__init__.py diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py new file mode 100644 index 0000000000000..505da41631438 --- /dev/null +++ b/tests/v1/tracing/test_tracing.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa +# type: ignore +from __future__ import annotations + +import threading +from collections.abc import Iterable +from concurrent import futures +from typing import Callable, Generator, Literal + +import grpc +import pytest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceResponse, +) +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( + TraceServiceServicer, + add_TraceServiceServicer_to_server, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE + +from vllm import LLM, SamplingParams +from vllm.tracing import SpanAttributes + +FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" + +FieldName = Literal[ + "bool_value", "string_value", "int_value", "double_value", "array_value" +] + + +def decode_value(value: AnyValue): + field_decoders: dict[FieldName, Callable] = { + "bool_value": (lambda v: v.bool_value), + "string_value": (lambda v: v.string_value), + "int_value": (lambda v: v.int_value), + "double_value": (lambda v: v.double_value), + "array_value": ( + lambda v: [decode_value(item) for item in v.array_value.values] + ), + } + for field, decoder in field_decoders.items(): + if value.HasField(field): + return decoder(value) + raise ValueError(f"Couldn't decode value: {value}") + + +def decode_attributes(attributes: Iterable[KeyValue]): + return {kv.key: decode_value(kv.value) for kv in attributes} + + +class FakeTraceService(TraceServiceServicer): + def __init__(self): + self.request = None + self.evt = threading.Event() + + def Export(self, request, context): + self.request = 
request + self.evt.set() + return ExportTraceServiceResponse() + + +@pytest.fixture +def trace_service() -> Generator[FakeTraceService, None, None]: + """Fixture to set up a fake gRPC trace service""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + service = FakeTraceService() + add_TraceServiceServicer_to_server(service, server) + server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) + server.start() + + yield service + + server.stop(None) + + +def test_traces( + monkeypatch: pytest.MonkeyPatch, + trace_service: FakeTraceService, +): + with monkeypatch.context() as m: + m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") + + sampling_params = SamplingParams( + temperature=0.01, + top_p=0.1, + max_tokens=256, + ) + model = "facebook/opt-125m" + llm = LLM( + model=model, + otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, + gpu_memory_utilization=0.3, + disable_log_stats=False, + ) + prompts = ["This is a short prompt"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + print(f"test_traces outputs is : {outputs}") + + timeout = 10 + if not trace_service.evt.wait(timeout): + raise TimeoutError( + f"The fake trace service didn't receive a trace within " + f"the {timeout} seconds timeout" + ) + + request = trace_service.request + assert len(request.resource_spans) == 1, ( + f"Expected 1 resource span, but got {len(request.resource_spans)}" + ) + assert len(request.resource_spans[0].scope_spans) == 1, ( + f"Expected 1 scope span, " + f"but got {len(request.resource_spans[0].scope_spans)}" + ) + assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( + f"Expected 1 span, " + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}" + ) + + attributes = decode_attributes( + request.resource_spans[0].scope_spans[0].spans[0].attributes + ) + # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE) + == sampling_params.temperature + ) + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p + ) + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) + == sampling_params.max_tokens + ) + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n + assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( + outputs[0].prompt_token_ids + ) + completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) + assert ( + attributes.get(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) + == completion_tokens + ) + + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0 + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0 + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0 diff --git a/tests/v1/test_utils.py b/tests/v1/utils.py similarity index 56% rename from tests/v1/test_utils.py rename to tests/v1/utils.py index 00d98a873a310..993ad8a947d03 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/utils.py @@ -1,79 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import pytest import regex as re import requests -import torch from tests.utils import RemoteOpenAIServer -from vllm.v1.worker.utils import bind_kv_cache - - -def test_bind_kv_cache(): - from vllm.attention import Attention - - ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 
'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), - } - kv_cache = { - 'layers.0.self_attn': torch.zeros((1, )), - 'layers.1.self_attn': torch.zeros((1, )), - 'layers.2.self_attn': torch.zeros((1, )), - 'layers.3.self_attn': torch.zeros((1, )), - } - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ - 'layers.0.self_attn'] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ - 'layers.1.self_attn'] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ - 'layers.2.self_attn'] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ - 'layers.3.self_attn'] - - assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] - assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] - assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] - assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] - - -def test_bind_kv_cache_non_attention(): - from vllm.attention import Attention - - # example from Jamba PP=2 - ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), - } - kv_cache = { - 'model.layers.20.attn': torch.zeros((1, )), - 'model.layers.28.attn': torch.zeros((1, )), - } - - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache(kv_cache, ctx, runner_kv_caches) - - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ - 'model.layers.20.attn'] - assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ - 'model.layers.28.attn'] - - assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] - assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] - # Prometheus metrics utilities for testing -def get_prometheus_metrics( - server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: +def get_prometheus_metrics(server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: """Fetch and parse Prometheus metrics from the /metrics endpoint. - + Returns: Dict mapping metric names to their values grouped by labels. 
For example: {"vllm:request_success": { @@ -88,14 +26,14 @@ def get_prometheus_metrics( # Regex patterns for Prometheus metrics metric_with_labels = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$') - metric_simple = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$') + r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$" + ) + metric_simple = re.compile(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$") - for line in response.text.split('\n'): + for line in response.text.split("\n"): line = line.strip() # Skip comments and empty lines - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Try to match metric with labels first @@ -106,7 +44,7 @@ def get_prometheus_metrics( value = float(value_str) if metric_name not in metrics: metrics[metric_name] = {} - metrics[metric_name][f'{{{labels_part}}}'] = value + metrics[metric_name][f"{{{labels_part}}}"] = value except ValueError: continue else: @@ -118,7 +56,7 @@ def get_prometheus_metrics( value = float(value_str) if metric_name not in metrics: metrics[metric_name] = {} - metrics[metric_name][''] = value + metrics[metric_name][""] = value except ValueError: continue @@ -128,10 +66,9 @@ def get_prometheus_metrics( return {} -def get_engine_request_counts( - metrics: dict[str, dict[str, float]]) -> dict[str, float]: +def get_engine_request_counts(metrics: dict[str, dict[str, float]]) -> dict[str, float]: """Extract request counts per engine from Prometheus metrics. - + Returns: Dict mapping engine indices to request counts. For example: {"0": 15.0, "1": 12.0} @@ -156,7 +93,7 @@ def get_engine_request_counts( def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): """Check request balancing via Prometheus metrics if dp_size > 1. - + Args: server: The RemoteOpenAIServer instance dp_size: Number of data parallel ranks @@ -175,7 +112,8 @@ def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): assert len(engines_with_requests) == dp_size, ( f"Expected requests to be distributed across multiple engines," f" but only engine(s) {engines_with_requests} received " - f"requests. Engine counts: {engine_counts}") + f"requests. 
Engine counts: {engine_counts}" + ) # Verify that the load is reasonably balanced # (no engine should handle all requests) @@ -183,4 +121,5 @@ def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): for count in engine_counts.values(): assert count > total_requests // (dp_size + 1), ( - f"requests are imbalanced: {engine_counts}") + f"requests are imbalanced: {engine_counts}" + ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index d7b4746562beb..5a598dcab7189 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -28,14 +29,11 @@ CUDA_DEVICES = [ MAX_NUM_PROMPT_TOKENS = 64 -def _compare_objs(obj1, - obj2, - skip: Sequence = ("logitsprocs", "batch_update_builder")): +def _compare_objs(obj1, obj2, skip: Sequence = ("logitsprocs", "batch_update_builder")): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) - attr_names = set([ - a[0] for a in attrs - if not (a[0].startswith('__') and a[0].endswith('__')) - ]) + attr_names = set( + [a[0] for a in attrs if not (a[0].startswith("__") and a[0].endswith("__"))] + ) for attr_name in attr_names: if attr_name in skip: continue @@ -45,8 +43,8 @@ def _compare_objs(obj1, is_same = False if isinstance(a, torch.Tensor): - if (a.numel() == 0 or b.numel() == 0): - is_same = (a.numel() == 0 and b.numel() == 0) + if a.numel() == 0 or b.numel() == 0: + is_same = a.numel() == 0 and b.numel() == 0 elif torch.allclose(a, b): is_same = True elif isinstance(a, np.ndarray): @@ -61,12 +59,16 @@ def _compare_objs(obj1, is_same = True # if we make it here must be same elif a == b: is_same = True - assert is_same, f"Attribute {attr_name} is different"\ - f" in {obj1} and {obj2}: {a} != {b}" + elif isinstance(a, CpuGpuBuffer): + is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu) + assert is_same, ( + f"Attribute {attr_name} is different in {obj1} and {obj2}: {a} != {b}" + ) -def _remove_requests(input_batch: InputBatch, batch_size: int, - reqs: list[CachedRequestState]) -> set[str]: +def _remove_requests( + input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState] +) -> set[str]: """ Remove some requests randomly from the batch and returns set of request removed @@ -106,10 +108,9 @@ def _construct_expected_sampling_metadata( temperature = [0.0 for _ in range(num_reqs)] min_tokens = {} logit_bias = [None] * num_reqs - allowed_token_ids_mask = torch.zeros(num_reqs, - VOCAB_SIZE, - dtype=torch.bool, - device=device) + allowed_token_ids_mask = torch.zeros( + num_reqs, VOCAB_SIZE, dtype=torch.bool, device=device + ) bad_words_token_ids = {} for req in reqs: if req.req_id not in req_ids_retained: @@ -117,35 +118,40 @@ def _construct_expected_sampling_metadata( index_in_input_batch = req_id_index_in_input_batch[req.req_id] output_token_ids[index_in_input_batch] = req.output_token_ids prompt_token_ids[index_in_input_batch] = req.prompt_token_ids - presence_penalties[ - index_in_input_batch] = req.sampling_params.presence_penalty + presence_penalties[index_in_input_batch] = 
req.sampling_params.presence_penalty frequency_penalties[index_in_input_batch] = ( - req.sampling_params.frequency_penalty) + req.sampling_params.frequency_penalty + ) repetition_penalties[index_in_input_batch] = ( - req.sampling_params.repetition_penalty) + req.sampling_params.repetition_penalty + ) top_k[index_in_input_batch] = req.sampling_params.top_k top_p[index_in_input_batch] = req.sampling_params.top_p temperature[index_in_input_batch] = req.sampling_params.temperature min_tokens[index_in_input_batch] = ( req.sampling_params.min_tokens, - req.sampling_params.all_stop_token_ids) + req.sampling_params.all_stop_token_ids, + ) logit_bias[index_in_input_batch] = req.sampling_params.logit_bias if req.sampling_params.allowed_token_ids: allowed_token_ids_mask[index_in_input_batch][ - req.sampling_params.allowed_token_ids] = True + req.sampling_params.allowed_token_ids + ] = True if req.sampling_params.bad_words_token_ids: - bad_words_token_ids[ - index_in_input_batch] = req.sampling_params.bad_words_token_ids + bad_words_token_ids[index_in_input_batch] = ( + req.sampling_params.bad_words_token_ids + ) return SamplingMetadata( - temperature=torch.tensor(temperature, dtype=torch.float, - device=device), + temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, all_random=True, - top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( - top_p, dtype=torch.float, device=device), - top_k=None if all(x == 0 for x in top_k) else torch.tensor( - top_k, dtype=torch.int, device=device), + top_p=None + if all(x == 1.0 for x in top_p) + else torch.tensor(top_p, dtype=torch.float, device=device), + top_k=None + if all(x == 0 for x in top_k) + else torch.tensor(top_k, dtype=torch.int, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -154,19 +160,22 @@ def _construct_expected_sampling_metadata( device=torch.device(device), dtype=torch.int64, ), - frequency_penalties=torch.tensor(frequency_penalties, - dtype=torch.float, - device=device), - presence_penalties=torch.tensor(presence_penalties, - dtype=torch.float, - device=device), - repetition_penalties=torch.tensor(repetition_penalties, - dtype=torch.float, - device=device), + frequency_penalties=torch.tensor( + frequency_penalties, dtype=torch.float, device=device + ), + presence_penalties=torch.tensor( + presence_penalties, dtype=torch.float, device=device + ), + repetition_penalties=torch.tensor( + repetition_penalties, dtype=torch.float, device=device + ), output_token_ids=output_token_ids, - no_penalties=(all(x == 0 for x in presence_penalties) - and all(x == 0 for x in frequency_penalties) - and all(x == 1 for x in repetition_penalties)), + spec_token_ids=[[] for _ in range(len(output_token_ids))], + no_penalties=( + all(x == 0 for x in presence_penalties) + and all(x == 0 for x in frequency_penalties) + and all(x == 1 for x in repetition_penalties) + ), allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, logitsprocs=LogitsProcessors(), @@ -182,8 +191,7 @@ def _create_sampling_params(): frequency_penalty=np.random.uniform(-2.0, 2.0), min_tokens=np.random.randint(1, 10), stop_token_ids=[ - np.random.randint(0, VOCAB_SIZE) - for _ in range(np.random.randint(10)) + np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10)) ], logit_bias={0: np.random.uniform(-3.0, 3.0)}, ) @@ -203,9 +211,8 @@ def _construct_cached_request_state(req_id_suffix: int): prompt_token_ids=prompt_token_ids, 
sampling_params=_create_sampling_params(), pooling_params=None, - mm_kwargs=[], - mm_positions=[], - block_ids=([], ), + mm_features=[], + block_ids=([],), generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -234,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -260,19 +268,18 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( - reqs, - req_ids_retained, - input_batch.req_id_to_index, - device=torch.device(device)) + reqs, req_ids_retained, input_batch.req_id_to_index, device=torch.device(device) + ) def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: - return (t1 is None - and t2 is None) or (t1 is not None and t2 is not None - and torch.allclose(t1, t2)) + return (t1 is None and t2 is None) or ( + t1 is not None and t2 is not None and torch.allclose(t1, t2) + ) # Assert the actual and expected output. - assert torch.allclose(expected_sampling_metadata.temperature, - sampling_metadata.temperature) + assert torch.allclose( + expected_sampling_metadata.temperature, sampling_metadata.temperature + ) assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) assert torch.allclose( @@ -287,25 +294,29 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): expected_sampling_metadata.repetition_penalties, sampling_metadata.repetition_penalties, ) - assert torch.allclose(expected_sampling_metadata.prompt_token_ids, - sampling_metadata.prompt_token_ids) - assert (expected_sampling_metadata.output_token_ids == - sampling_metadata.output_token_ids) - assert expected_sampling_metadata.no_penalties == \ - sampling_metadata.no_penalties + assert torch.allclose( + expected_sampling_metadata.prompt_token_ids, sampling_metadata.prompt_token_ids + ) + assert ( + expected_sampling_metadata.output_token_ids + == sampling_metadata.output_token_ids + ) + assert expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties if sampling_metadata.allowed_token_ids_mask: assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, - sampling_metadata.allowed_token_ids_mask) - assert expected_sampling_metadata.bad_words_token_ids == \ - sampling_metadata.bad_words_token_ids + sampling_metadata.allowed_token_ids_mask, + ) + assert ( + expected_sampling_metadata.bad_words_token_ids + == sampling_metadata.bad_words_token_ids + ) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("swap_list", [((0, 1), )]) -def test_swap_states_in_input_batch(device: str, batch_size: int, - swap_list: list): +@pytest.mark.parametrize("swap_list", [((0, 1),)]) +def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list): """ Tests the logic for managing sampling metadata in the InputBatch. 
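For readers scanning the reformatted assertions above: they keep the convention that a sampling field left at its default for every request collapses to None rather than to a tensor of default values (see the top_p / top_k construction earlier in this file's diff). A minimal standalone sketch of that convention; `collapse_if_default` is a hypothetical name used only for illustration, not part of this patch:

```python
from typing import Optional

import torch


def collapse_if_default(values: list[float], default: float) -> Optional[torch.Tensor]:
    # Mirror the tests' expectation: an all-default column is represented
    # as None instead of a tensor filled with the default value.
    if all(v == default for v in values):
        return None
    return torch.tensor(values, dtype=torch.float)


assert collapse_if_default([1.0, 1.0], default=1.0) is None       # e.g. top_p never set
assert collapse_if_default([0.9, 1.0], default=1.0) is not None   # mixed values -> tensor
```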
@@ -325,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -334,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] @@ -350,8 +363,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, reordered_reqs = reqs.copy() for swap_pair in swap_list: - reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \ - reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]] + reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = ( + reordered_reqs[swap_pair[1]], + reordered_reqs[swap_pair[0]], + ) input_batch.swap_states(swap_pair[0], swap_pair[1]) for req_index in range(batch_size): diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b9b2314ce573f..817cd7f10c1c6 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,27 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - import numpy as np import pytest import torch from vllm.attention import Attention -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig, set_current_vllm_config) -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.config import ( + CacheConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, update_environment_variables -from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) +from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -37,11 +45,9 @@ def initialize_kv_cache(runner: GPUModelRunner): """ attn_spec = FullAttentionSpec( block_size=BLOCK_SIZE, - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), + num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config), head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, - use_mla=False, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( @@ -61,7 +67,8 @@ def initialize_kv_cache(runner: GPUModelRunner): device=runner.device, pin_memory=runner.pin_memory, vocab_size=runner.model_config.get_vocab_size(), - block_sizes=[ + 
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], + kernel_block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], ) @@ -101,8 +108,9 @@ def model_runner(): model_config = vllm_config.model_config num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) head_size = model_config.get_head_size() - vllm_config.compilation_config.static_forward_context[ - "layer.0"] = Attention(num_heads, head_size, 0.1) + vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( + num_heads, head_size, 0.1 + ) runner = GPUModelRunner(vllm_config, DEVICE) initialize_kv_cache(runner) return runner @@ -120,15 +128,14 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], sampling_params=SamplingParams(), pooling_params=None, - block_ids=([0], ), + block_ids=([0],), num_computed_tokens=0, lora_request=None, - )) + ) + ) num_scheduled_tokens[req_id] = 3 total_num_scheduled_tokens += num_scheduled_tokens[req_id] @@ -141,7 +148,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -155,22 +162,22 @@ def _is_req_added(model_runner, req_id: str) -> bool: return req_id in model_runner.requests -def _is_sampling_metadata_changed(model_runner, - sampling_metadata_before: SamplingMetadata): - return model_runner.input_batch.sampling_metadata is not ( - sampling_metadata_before) +def _is_sampling_metadata_changed( + model_runner, sampling_metadata_before: SamplingMetadata +): + return model_runner.input_batch.sampling_metadata is not (sampling_metadata_before) def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] block_table = model_runner.input_batch.block_table[0] req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len( - req_state.block_ids[0]): + if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids[0]): return False num_blocks = block_table.num_blocks_per_row[req_index] - return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids[0]).all() + return ( + block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0] + ).all() def test_update_states_new_request(model_runner, dist_init): @@ -207,7 +214,7 @@ def test_update_states_request_finished(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -239,7 +246,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -253,8 +260,10 @@ def test_update_states_request_resumed(model_runner, dist_init): req_ids=[req_id], resumed_from_preemption=[False], new_token_ids=[[]], - new_block_ids=([[0]], ), + resumed_req_token_ids=[None], + new_block_ids=([[0]],), num_computed_tokens=[0], + num_output_tokens=[0], ) scheduler_output = SchedulerOutput( @@ -266,7 +275,7 @@ def 
test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -285,46 +294,58 @@ def test_get_nans_in_logits(model_runner, dist_init): scheduler_output = _schedule_new_request(*req_ids) model_runner._update_states(scheduler_output) - logits = torch.tensor([ - [1.0, 2.0, 3.0], - [3.0, 2.0, 1.0], - ], device=DEVICE) + logits = torch.tensor( + [ + [1.0, 2.0, 3.0], + [3.0, 2.0, 1.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 0, "req_1": 0} - logits = torch.tensor([ - [1.0, float('nan'), 3.0], - [4.0, float('nan'), float('nan')], - ], - device=DEVICE) + logits = torch.tensor( + [ + [1.0, float("nan"), 3.0], + [4.0, float("nan"), float("nan")], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 1, "req_1": 2} - logits = torch.tensor([ - [1.0, 2.0, 3.0], - [4.0, float('nan'), float('nan')], - ], - device=DEVICE) + logits = torch.tensor( + [ + [1.0, 2.0, 3.0], + [4.0, float("nan"), float("nan")], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 0, "req_1": 2} result = model_runner._get_nans_in_logits(logits=None) assert result == {"req_0": 0, "req_1": 0} - logits = torch.tensor([ - [1.0, float('nan'), 3.0], - ], device=DEVICE) + logits = torch.tensor( + [ + [1.0, float("nan"), 3.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) - assert result == {'req_0': 1, 'req_1': 0} + assert result == {"req_0": 1, "req_1": 0} - logits = torch.tensor([ - [float('nan'), float('nan'), 2.0], - [1.0, 2.0, 3.0], - [float('nan'), 2.0, 3.0], - ], - device=DEVICE) + logits = torch.tensor( + [ + [float("nan"), float("nan"), 2.0], + [1.0, 2.0, 3.0], + [float("nan"), 2.0, 3.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) - assert result == {'req_0': 2, 'req_1': 0} + assert result == {"req_0": 2, "req_1": 0} def test_update_states_no_changes(model_runner, dist_init): @@ -347,7 +368,7 @@ def test_update_states_no_changes(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -384,7 +405,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -402,36 +423,40 @@ def test_update_states_request_unscheduled(model_runner, dist_init): def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. 
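    # Concretely (an illustrative note based on the assertions below): the
    # logical shape stays [2, NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size] for
    # every stride order, but a non-identity order such as (1, 4, 0, 2, 3)
    # changes the physical layout, so those caches end up non-contiguous,
    # while the identity order (0, 1, 2, 3, 4) keeps them contiguous.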
- n_heads = model_runner.model_config.get_num_kv_heads( - model_runner.parallel_config) + n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) expected_kv_cache_shape = [ - 2, NUM_BLOCKS, BLOCK_SIZE, n_heads, - model_runner.model_config.get_head_size() + 2, + NUM_BLOCKS, + BLOCK_SIZE, + n_heads, + model_runner.model_config.get_head_size(), ] # TODO mla test - default_stride = list(range(5)) + default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape - rnd_stride = tuple(random.sample(default_stride, len(default_stride))) + for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): - def rnd_stride_order(): - return rnd_stride + def rnd_stride_order(test_stride=test_stride): + return test_stride - # Patch the attention backend class and re-trigger the KV cache creation. - for attn_group in model_runner._attn_group_iterator(): - attn_backend = attn_group.backend - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", - rnd_stride_order) + # Patch the attention backend class and re-trigger the KV cache creation + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + monkeypatch.setattr( + attn_backend, "get_kv_cache_stride_order", rnd_stride_order + ) - model_runner.attn_groups = [] - model_runner.initialize_kv_cache(model_runner.kv_cache_config) + model_runner.attn_groups = [] + model_runner.kv_caches = [] + model_runner.initialize_kv_cache(model_runner.kv_cache_config) - # Shape is unchanged, but layout may differ - kv_cache_shape = model_runner.kv_caches[0].shape - assert list(kv_cache_shape) == expected_kv_cache_shape - if default_stride == rnd_stride: - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) - else: - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) + # Shape is unchanged, but layout may differ + kv_cache_shape = model_runner.kv_caches[0].shape + assert list(kv_cache_shape) == expected_kv_cache_shape + if default_stride == test_stride: + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) + else: + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) def test_update_config(model_runner): @@ -451,14 +476,13 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): model_runner_2.update_config({"load_config": {"load_format": "dummy"}}) model_runner_2.load_model() # Initial model loading with dummy weights assert str(model_runner.get_model().state_dict()) != str( - model_runner_2.get_model().state_dict()) - model_runner_2.update_config( - {"load_config": { - "load_format": original_load_format - }}) + model_runner_2.get_model().state_dict() + ) + model_runner_2.update_config({"load_config": {"load_format": original_load_format}}) model_runner_2.reload_weights() # Load real weights inplace assert str(model_runner.get_model().state_dict()) == str( - model_runner_2.get_model().state_dict()) + model_runner_2.get_model().state_dict() + ) def test_reload_weights_before_load_model(model_runner): @@ -475,21 +499,19 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, kv_sharing_target_layer_name=layer_1, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert 
fwd_context is not None @@ -503,22 +525,20 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): error_msg = f"{invalid_layer} is not a valid Attention layer in the model" with pytest.raises(ValueError, match=error_msg): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, # invalid layer: cross_attn.atn doesn't exist! kv_sharing_target_layer_name=invalid_layer, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -533,21 +553,19 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -560,20 +578,18 @@ def test_init_kv_cache_without_kv_sharing(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -588,15 +604,15 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for layer 0's kv_cache_spec is 32KB num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == 1310720 @@ -604,8 +620,9 @@ def test_init_kv_cache_without_kv_sharing(): # this will only allocate 2 block worth of memory (2 * 32kb) kv_cache_config.num_blocks = 1 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = ( - kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) + kv_cache_tensor.size = kv_cache_spec[ + kv_cache_tensor.shared_by[0] + ].page_size_bytes runner.initialize_kv_cache(kv_cache_config) @@ -628,21 +645,19 @@ def test_init_kv_cache_with_kv_sharing_valid(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name="model.layers.0.self_attn.attn", - ) + ), } # suppress var not used error assert fwd_context is not None @@ -660,24 +675,23 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate 
(available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 655360 # 20GB / 32KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # compared to no KV sharing assert kv_cache_config.kv_cache_tensors[0].size == available_memory - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == 2 * 1310720 # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (32kb) kv_cache_config.num_blocks = 1 - kv_cache_config.kv_cache_tensors[0].size =\ - kv_cache_spec[layer_0].page_size_bytes + kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) kv_cache_config_after_init = runner.kv_cache_config @@ -690,30 +704,30 @@ def test_init_kv_cache_with_kv_sharing_valid(): # check layer 1 added to kv cache group's layer names assert len(kv_cache_config_after_init.kv_cache_groups) == 1 assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ - 0] == layer_0 - assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ - 1] == layer_1 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): - ''' + """ The GPU model runner creates different views into the KVCacheTensors for the attention and mamba layers (via _reshape_kv_cache_tensors function). 
This test verifies that the views are compatible: writing a mamba block - will not corrupt an attention block and vice-versa - ''' + will not corrupt an attention block and vice versa + """ current_platform.seed_everything(42) - update_environment_variables({ - 'RANK': "0", - 'LOCAL_RANK': "0", - 'WORLD_SIZE': "1", - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=1) torch.set_default_dtype(torch.float16) @@ -754,8 +768,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): fwd_context = {} for key in [layer_0, layer_1]: fwd_context[key] = Attention( - num_heads=model_config.get_num_attention_heads( - parallel_config), + num_heads=model_config.get_num_attention_heads(parallel_config), num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), scale=1.0, @@ -763,13 +776,12 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ) for key in [layer_2, layer_3, layer_4, layer_5]: fwd_context[key] = MambaMixer2( - hidden_size = hf_config.hidden_size, - ssm_state_size = hf_config.mamba_d_state, - conv_kernel_size = hf_config.mamba_d_conv, - intermediate_size = hf_config.mamba_expand *\ - hf_config.hidden_size, - use_conv_bias = hf_config.mamba_conv_bias, - use_bias = hf_config.mamba_proj_bias, + hidden_size=hf_config.hidden_size, + ssm_state_size=hf_config.mamba_d_state, + conv_kernel_size=hf_config.mamba_d_conv, + intermediate_size=hf_config.mamba_expand * hf_config.hidden_size, + use_conv_bias=hf_config.mamba_conv_bias, + use_bias=hf_config.mamba_proj_bias, n_groups=hf_config.mamba_n_groups, num_heads=hf_config.mamba_n_heads, head_dim=hf_config.mamba_d_head, @@ -784,15 +796,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): vllm_ctx = vllm_config.compilation_config.static_forward_context with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") runner = GPUModelRunner(vllm_config, DEVICE) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] runner.initialize_kv_cache(kv_cache_config) # random partition of blocks @@ -801,43 +813,238 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): num_blocks = kv_cache_config.num_blocks ind = np.arange(num_blocks) np.random.shuffle(ind) - blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):] + blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] attn_shape = vllm_ctx[layer_0].kv_cache[0].shape conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape # assert we are using FlashInfer - assert attn_shape[0] == num_blocks + assert attn_shape[0] % num_blocks == 0 + block_split_ratio = attn_shape[0] // num_blocks - attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]), - device=DEVICE, - fill_value=3.33) - conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]), - device=DEVICE, - fill_value=6.66) - ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]), - device=DEVICE, - fill_value=9.99) + # use small blocks for testing to avoid memory issues + test_block_size = min(2, len(blocks0), 
len(blocks1)) + + # use non-overlapping blocks to avoid data contamination + # Split kernel blocks: first half for attention, second half for mamba + mid_point = num_blocks // 2 + + # attention uses kernel blocks from first half (mapped to logical blocks) + kv_blocks_for_attention = np.array([0, 1])[:test_block_size] + + # mamba uses kernel blocks from second half + kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] + + # create small constant tensors for testing with corrected shapes + # attention: [block_size, ...] starting from dimension 2 + attn_constant_shape = attn_shape[2:] + conv_constant_shape = conv_shape[1:] + ssm_constant_shape = ssm_shape[1:] + + attn_blocks_constant = torch.full( + (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 + ) + conv_blocks_constant = torch.full( + (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 + ) + ssm_blocks_constant = torch.full( + (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 + ) + + # Fill attention blocks with constants using kv block indices + kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio - # fill all attention blocks with constant for layer in [layer_0, layer_1]: - vllm_ctx[layer].kv_cache[0][ - blocks0, :] = attn_blocks_constant.detach().clone() + # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] + for i, kernel_block in enumerate(kernel_blocks_for_attention): + vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] - # fill all mamba blocks with constant + # fill mamba blocks with constants using kernel block indices for layer in [layer_2, layer_3, layer_4, layer_5]: - vllm_ctx[layer].kv_cache[0][0][ - blocks1, :] = conv_blocks_constant.detach().clone() - vllm_ctx[layer].kv_cache[0][1][ - blocks1, :] = ssm_blocks_constant.detach().clone() + # mamba: kv_cache[0][component][kernel_block_idx, ...] 
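        # Note the indexing asymmetry exercised here: the attention path above
        # converts kv-manager block ids into kernel block ids
        # (kv_block * block_split_ratio), while this mamba path indexes its
        # conv and ssm state with the kv block ids as given.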
+ for i, kv_block in enumerate(kv_blocks_for_mamba): + vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] + vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] # verify attention and mamba contents are correct for layer in [layer_0, layer_1]: - assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :], - attn_blocks_constant) + for i, kernel_block in enumerate(kernel_blocks_for_attention): + actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] + expected = attn_blocks_constant[i] + + # Check K and V separately + assert torch.equal(actual_kv[0], expected) + assert torch.equal(actual_kv[1], expected) + for layer in [layer_2, layer_3, layer_4, layer_5]: - assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :], - conv_blocks_constant) - assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :], - ssm_blocks_constant) + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) + + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) + + +def test_hybrid_block_table_initialization(): + """Test hybrid block table with different kernel and kvcache_manager block + sizes.""" + from vllm.v1.worker.block_table import BlockTable + + # Test configuration: kvcache_manager block size = 32, + # kernel block size = 16 + block_size = 32 + kernel_block_sizes = [16] + max_num_reqs = 10 + max_num_blocks_per_req = 20 + max_num_batched_tokens = 512 + + block_table = BlockTable( + block_size=block_size, + max_num_reqs=max_num_reqs, + max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=False, + device=torch.device(DEVICE), + kernel_block_size=kernel_block_sizes[0], + ) + + # Verify hybrid block configuration + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_block_sizes[0] + assert block_table.blocks_per_kv_block == ( + block_size // kernel_block_sizes[0] + ) # Changed to use first element + + # Test block table conversion logic + # One kvcache_manager block should map to multiple kernel blocks + kvcache_manager_blocks = [0, 1, 2] + + # Verify that kvcache_manager blocks can be converted to kernel blocks + # and that block table operations work correctly. + req_index = 0 + block_table.append_row(kvcache_manager_blocks, req_index) + # Get expected kernel blocks from the implementation for verification. 
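    # Worked example (hedged, the exact mapping is owned by the block table):
    # with block_size=32 and kernel_block_size=16, blocks_per_kv_block == 2,
    # so a contiguous mapping would expand manager blocks [0, 1, 2] into
    # kernel blocks [0, 1, 2, 3, 4, 5]. The check below compares against
    # whatever _map_to_kernel_blocks actually returns rather than hard-coding
    # that assumption.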
+ expected_kernel_blocks = block_table._map_to_kernel_blocks( + np.array(kvcache_manager_blocks) + ) + # Verify block table state + assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) + assert np.array_equal( + block_table.block_table.np[req_index, : len(expected_kernel_blocks)], + expected_kernel_blocks, + ) + + +def test_input_batch_with_kernel_block_sizes(): + """Test InputBatch initialization with kernel_block_sizes parameter.""" + max_num_reqs = 10 + max_model_len = 512 + max_num_batched_tokens = 512 + device = torch.device(DEVICE) + pin_memory = False + vocab_size = 50272 + + # Test with different kernel block sizes + block_sizes = [32, 64] + kernel_block_sizes = [16, 32] + + input_batch = InputBatch( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + device=device, + pin_memory=pin_memory, + vocab_size=vocab_size, + block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, + ) + + # Verify that block tables were created with kernel block sizes + assert len(input_batch.block_table.block_tables) == len(block_sizes) + + for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)): + block_table = input_batch.block_table.block_tables[i] + if kv_size != kernel_size: + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_size + else: + assert block_table.use_hybrid_blocks is False + assert block_table.block_size == kernel_size + + +def test_hybrid_cache_integration(model_runner, dist_init): + """Test hybrid cache architecture integration with GPUModelRunner.""" + # Create a new model runner with hybrid cache configuration + vllm_config = get_vllm_config() + + # Configure hybrid cache with different kvcache_manager block size + vllm_config.cache_config.block_size = 32 + + model_config = vllm_config.model_config + num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = model_config.get_head_size() + vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( + num_heads, head_size, 0.1 + ) + + runner = GPUModelRunner(vllm_config, DEVICE) + + # Initialize KV cache with configuration + attn_spec = FullAttentionSpec( + block_size=16, # Use kernel block size directly + num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + ) + tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS + kv_cache_config = KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[ + KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) + ], + ) + runner.kv_cache_config = kv_cache_config + + # Initialize input batch with kernel block sizes + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], + kernel_block_sizes=[16], + ) # Use kernel block size + + runner.initialize_attn_backend(kv_cache_config) + + # Verify hybrid block table configuration + block_table = runner.input_batch.block_table.block_tables[0] + assert block_table.block_size == ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ) + + # Test request processing with hybrid 
blocks + req_id = "hybrid_req_0" + scheduler_output = _schedule_new_request(req_id) + + # Update states should work with hybrid blocks + runner._update_states(scheduler_output) + assert _is_req_scheduled(runner, req_id) + assert _is_req_state_block_table_match(runner, req_id) diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py new file mode 100644 index 0000000000000..f987b09e603e7 --- /dev/null +++ b/tests/v1/worker/test_utils.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.worker.utils import bind_kv_cache + + +def test_bind_kv_cache(): + from vllm.attention import Attention + + ctx = { + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), + } + kv_cache = { + "layers.0.self_attn": torch.zeros((1,)), + "layers.1.self_attn": torch.zeros((1,)), + "layers.2.self_attn": torch.zeros((1,)), + "layers.3.self_attn": torch.zeros((1,)), + } + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"] + + assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"] + assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"] + assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"] + assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"] + + +def test_bind_kv_cache_non_attention(): + from vllm.attention import Attention + + # example from Jamba PP=2 + ctx = { + "model.layers.20.attn": Attention(32, 128, 0.1), + "model.layers.28.attn": Attention(32, 128, 0.1), + } + kv_cache = { + "model.layers.20.attn": torch.zeros((1,)), + "model.layers.28.attn": torch.zeros((1,)), + } + + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + + assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"] + assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"] + + assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] + assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] diff --git a/tests/v1/worker/test_worker_memory_snapshot.py b/tests/v1/worker/test_worker_memory_snapshot.py new file mode 100644 index 0000000000000..cbfb9a8dc0b60 --- /dev/null +++ b/tests/v1/worker/test_worker_memory_snapshot.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing as mp +import os +import tempfile +from multiprocessing import Queue +from typing import Optional +from unittest.mock import patch + +import pytest +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import MemorySnapshot +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment + +# Global queue to track operation order across processes +_QUEUE: Optional[Queue] = None + + +def track_operation(operation: str, rank: int): + """Track when an operation happens and its rank.""" + if _QUEUE is not None: + _QUEUE.put((operation, rank)) + + +def make_operation_tracker(operation_name: 
str, original_func): + """Create a mock function that tracks when an operation is called. + + Args: + operation_name: Name to use when tracking this operation + original_func: The original function to wrap + + Returns: + A wrapper function that tracks the operation and calls the original + """ + + def wrapper(*args, **kwargs): + rank = int(os.environ.get("RANK", "-1")) + track_operation(operation_name, rank) + return original_func(*args, **kwargs) + + return wrapper + + +def worker_process( + rank: int, + world_size: int, + distributed_init_method: str, + queue: Queue, + error_queue: Queue, +): + """Worker process that initializes a GPU worker with proper tracking.""" + global _QUEUE + _QUEUE = queue + + try: + # Set environment variables + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Create vLLM config with small model + vllm_config = EngineArgs( + model="facebook/opt-125m", tensor_parallel_size=2, load_format="dummy" + ).create_engine_config() + + # Create worker + worker = Worker( + vllm_config=vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + # Get original functions before patching + original_init_worker = init_worker_distributed_environment + original_memory_snapshot_init = MemorySnapshot.__init__ + original_all_reduce = torch.distributed.all_reduce + + # Apply minimal patches to track operation order + init_patch = patch( + "vllm.v1.worker.gpu_worker.init_worker_distributed_environment", + side_effect=make_operation_tracker( + "init_distributed", original_init_worker + ), + ) + memory_patch = patch.object( + MemorySnapshot, + "__init__", + make_operation_tracker("memory_snapshot", original_memory_snapshot_init), + ) + all_reduce_patch = patch( + "torch.distributed.all_reduce", + side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce), + ) + + with init_patch, memory_patch, all_reduce_patch: + # Initialize device (this is where we test the order) + worker.init_device() + + # Load model to ensure everything works + worker.load_model() + + # Signal success + queue.put(("success", rank)) + + except Exception as e: + error_queue.put((rank, str(e), type(e).__name__)) + raise + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism" +) +def test_init_distributed_is_called_before_memory_snapshot(): + """Test that distributed env is setup before memory snapshot. + + This test makes sure during worker initialization, the initial memory + snapshot is taken after distributed env is setup to include all the buffers + allocated by distributed env. 
+ """ + world_size = 2 + + # Create a temporary file for distributed init + with tempfile.NamedTemporaryFile(delete=False) as f: + distributed_init_method = f"file://{f.name}" + + # Create queues for inter-process communication + ctx = mp.get_context("spawn") + operation_queue = ctx.Queue() + error_queue = ctx.Queue() + + # Start worker processes + processes = [] + for rank in range(world_size): + p = ctx.Process( + target=worker_process, + args=( + rank, + world_size, + distributed_init_method, + operation_queue, + error_queue, + ), + ) + p.start() + processes.append(p) + + # Wait for all processes to complete + for p in processes: + p.join(timeout=60) # 60 second timeout + + # Check for errors + errors = [] + while not error_queue.empty(): + rank, error_msg, error_type = error_queue.get() + errors.append(f"Rank {rank}: {error_type}: {error_msg}") + + if errors: + pytest.fail("Worker processes failed:\n" + "\n".join(errors)) + + # Collect all operations from the queue + operations = [] + while not operation_queue.empty(): + operations.append(operation_queue.get()) + + # Verify we got operations from both ranks + print(f"Collected operations: {operations}") + + # Check operations for each rank + for rank in range(world_size): + rank_ops = [op for op, r in operations if r == rank] + print(f"\nRank {rank} operations: {rank_ops}") + + # Raises ValueError if the operation is not found + init_distributed = rank_ops.index("init_distributed") + nccl_all_reduce = rank_ops.index("nccl_all_reduce") + memory_snapshot = rank_ops.index("memory_snapshot") + + # Verify order: init_distributed should happen before memory_snapshot + assert init_distributed < nccl_all_reduce < memory_snapshot, ( + f"Rank {rank}: init_distributed (index {init_distributed}) " + f"must happen before nccl_all_reduce (index {nccl_all_reduce}) " + f"and memory_snapshot (index {memory_snapshot})" + ) + + # Clean up + os.unlink(distributed_init_method.replace("file://", "")) diff --git a/tests/vllm_test_utils/setup.py b/tests/vllm_test_utils/setup.py index 83be8bdce85cf..4cb66b556e5a7 100644 --- a/tests/vllm_test_utils/setup.py +++ b/tests/vllm_test_utils/setup.py @@ -4,7 +4,7 @@ from setuptools import setup setup( - name='vllm_test_utils', - version='0.1', - packages=['vllm_test_utils'], + name="vllm_test_utils", + version="0.1", + packages=["vllm_test_utils"], ) diff --git a/tests/vllm_test_utils/vllm_test_utils/blame.py b/tests/vllm_test_utils/vllm_test_utils/blame.py index 49fd083ef19c8..e2cab92ea22b2 100644 --- a/tests/vllm_test_utils/vllm_test_utils/blame.py +++ b/tests/vllm_test_utils/vllm_test_utils/blame.py @@ -26,7 +26,7 @@ def blame(func: Callable) -> Generator[BlameResult, None, None]: ```python with blame(lambda: some_condition()) as result: # do something - + if result.found: print(result.trace_stack) """ @@ -34,7 +34,7 @@ def blame(func: Callable) -> Generator[BlameResult, None, None]: def _trace_calls(frame, event, arg=None): nonlocal result - if event in ['call', 'return']: + if event in ["call", "return"]: # for every function call or return try: # Temporarily disable the trace function diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py index 9454221b273e6..e2f1212ed554b 100644 --- a/tests/vllm_test_utils/vllm_test_utils/monitor.py +++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py @@ -19,8 +19,8 @@ class MonitoredValues(Generic[_T]): @contextlib.contextmanager def monitor( - measure_func: Callable[[], - _T]) -> Generator[MonitoredValues[_T], None, 
None]: + measure_func: Callable[[], _T], +) -> Generator[MonitoredValues[_T], None, None]: """ Trace the function calls to continuously monitor the change of a value. @@ -28,23 +28,23 @@ def monitor( Usage: ```python - def measure_func(): - ... # measure the current value + ... # measure the current value return current_value + with monitor(measure_func) as monitored_values: # do something - - monitored_values.values # all changes of the values - monitored_values.trace_stacks # trace stacks of every change + + monitored_values.values # all changes of the values + monitored_values.trace_stacks # trace stacks of every change ``` """ monitored_values = MonitoredValues[_T]() def _trace_calls(frame, event, arg=None): nonlocal monitored_values - if event in ['line']: + if event in ["line"]: # triggered by every line of Python code. # only Python functions will trigger it, # c/cpp functions will not trigger it. @@ -53,11 +53,14 @@ def monitor( sys.settrace(None) # do a measurement current_value = measure_func() - if len(monitored_values.values - ) == 0 or current_value != monitored_values.values[-1]: + if ( + len(monitored_values.values) == 0 + or current_value != monitored_values.values[-1] + ): monitored_values.values.append(current_value) - monitored_values.trace_stacks.append("".join( - traceback.format_stack())) + monitored_values.trace_stacks.append( + "".join(traceback.format_stack()) + ) # Re-enable the trace function sys.settrace(_trace_calls) except NameError: diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py index 3aabae099073e..6587730682088 100644 --- a/tests/weight_loading/test_weight_loading.py +++ b/tests/weight_loading/test_weight_loading.py @@ -9,35 +9,39 @@ import torch from vllm.platforms import current_platform MAX_MODEL_LEN = 1024 -MODEL_NAME = os.environ.get("MODEL_NAME", - "robertgshaw2/zephyr-7b-beta-channelwise-gptq") +MODEL_NAME = os.environ.get( + "MODEL_NAME", "robertgshaw2/zephyr-7b-beta-channelwise-gptq" +) REVISION = os.environ.get("REVISION", "main") QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin") MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80") @pytest.mark.skipif( - MODEL_NAME == "casperhansen/deepseek-coder-v2-instruct-awq", - reason="OOM in the CI") + MODEL_NAME == "casperhansen/deepseek-coder-v2-instruct-awq", reason="OOM in the CI" +) @pytest.mark.skipif( not current_platform.has_device_capability(int(MIN_CAPABILITY)), - reason="Current system does not have minimum capability.") + reason="Current system does not have minimum capability.", +) def test_weight_loading(vllm_runner): """ Test parameter weight loading with tp>1. """ # MoE models need fp16. 
- NEEDS_FP16 = (QUANTIZATION == "gptq" or MODEL_NAME - == "nm-testing/test-w4a16-mixtral-actorder-group") + NEEDS_FP16 = ( + QUANTIZATION == "gptq" + or MODEL_NAME == "nm-testing/test-w4a16-mixtral-actorder-group" + ) with vllm_runner( - model_name=MODEL_NAME, - revision=REVISION, - dtype=torch.half if NEEDS_FP16 else "auto", - quantization=None if QUANTIZATION == "None" else QUANTIZATION, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=2) as model: - + model_name=MODEL_NAME, + revision=REVISION, + dtype=torch.half if NEEDS_FP16 else "auto", + quantization=None if QUANTIZATION == "None" else QUANTIZATION, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=2, + ) as model: output = model.generate_greedy("Hello world!", max_tokens=20) print(output) assert output diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py deleted file mode 100644 index 3f202d4dbe948..0000000000000 --- a/tests/worker/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') \ No newline at end of file diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py deleted file mode 100644 index 35ac90b38e840..0000000000000 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ /dev/null @@ -1,648 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools - -import pytest -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -BATCH_SIZES = [1, 4, 16, 64, 256] - - -def _create_model_runner(model: str, *args, - **kwargs) -> EncoderDecoderModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - model_runner = EncoderDecoderModelRunner( - vllm_config=engine_config, - is_driver_worker=True, - ) - return model_runner - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output - for empty seq group list""" - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - ( - input_tokens, - input_positions, - encoder_input_tokens, - encoder_input_positions, - attn_metadata, - return_seq_lens, - ) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.encoder_input_tokens, - model_input.encoder_input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) - assert input_tokens is None - assert input_positions is None - assert encoder_input_tokens is None - assert encoder_input_positions is None - assert attn_metadata is None - assert return_seq_lens is None - - 
-@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_prompt(batch_size): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce prefill-phase model inputs & attention metadata. - - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for prompts. 
- # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills > 0 - assert attn_metadata.num_decode_tokens == 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == max(seq_lens) - assert attn_metadata.max_decode_seq_len == 0 - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. - start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs & context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.zeros(attn_metadata.context_lens_tensor.shape[0], - dtype=torch.int, - device=device), - ) - - # Verify block tables are correct for prompts - # - Decoder self-attention - expected = torch.tensor( - [[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Cuda graph should not be used for prefill. - assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == sum(encoder_seq_lens) - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. 
the end of - # each sequence) in the prefill phase - - expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: - # Compute the index offset of the final token in each - # prompt (recall that the prompts are concatenated) - expected_selected_token_indices.append(selected_token_start_idx + - seq_len - 1) - selected_token_start_idx += seq_len - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce decode-phase model inputs & attention metadata. - - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * multiple_seqs_per_seq_group - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping 
= attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for decode phase. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. - start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += 1 - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - # Test seq_start_loc and context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.tensor([seq_len - 1 for seq_len in seq_lens], - dtype=torch.int, - device=device)) - - # Verify block tables are correct for prompts - # - Decoder self-attention - flattened_block_tables = [ - block_table for block_table in block_tables.values() - ] - expected = torch.tensor(flattened_block_tables * - len(seq_group_metadata_list), - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - expected = torch.tensor([ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ], - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. 
- assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(seq_lens) - assert len(input_positions) == len(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. the end of - # each sequence) in the decode phase - - expected_selected_token_indices = [] - for selected_token_start_idx, seq_len in enumerate(seq_lens): - # Compute the index offset of the final token in each - # sequence's decoded outputs; since a single token is - # decoded per iteration per sequence, then the length - # of the decoded tokens for a given sequence is 1 and - # the final index offset into a given sequence's - # generated tokens is 0 (i.e. the expected sampling index - # for a given sequence is just `selected_token_start_idx`) - expected_selected_token_indices.append(selected_token_start_idx) - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): - """ - Tests that for encoder-decoder models with CUDA Graph capture and replay - enabled, the tensors used during the decode phase are correctly padded - for varying input batch sizes. 
- """ - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=False, - ) - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - - cross_block_table = [2] - expanded_batch_size = 0 - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - expanded_batch_size = expanded_batch_size + len( - seq_group_metadata.seq_data) - seq_group_metadata_list.append(seq_group_metadata) - - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - - # With CUDA Graph capture and replay enabled, the decoder and encoder - # input sequences will be padded. Create the expected padded tensors - # accordingly. - graph_batch_size = model_runner.vllm_config.pad_for_cudagraph( - expanded_batch_size) - cuda_graph_pad_size = graph_batch_size - expanded_batch_size - padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) - padded_encoder_seq_lens = encoder_seq_lens + list( - itertools.repeat(1, cuda_graph_pad_size)) - - assert return_seq_lens == padded_seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal( - attn_metadata.seq_lens_tensor, - torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == padded_seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) - - # Verify block tables are correct for prompts - # - Decoder self-attention. Pad the block tables as expected. 
- flattened_block_tables = [ - block_table for _ in range(len(seq_group_metadata_list)) - for block_table in block_tables.values() - ] - flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - flattened_block_tables, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention. Pad the cross-attention block tables - # as expected. - expected = [ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - expected, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. - assert attn_metadata.use_cuda_graph is True - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(padded_seq_lens) - assert len(input_positions) == len(padded_seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py deleted file mode 100644 index 2031f41fab87d..0000000000000 --- a/tests/worker/test_model_input.py +++ /dev/null @@ -1,167 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses - -import torch - -from vllm.attention import AttentionMetadata, AttentionMetadataBuilder -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.utils import CommonAttentionState -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -from vllm.worker.pooling_model_runner import ( - ModelInputForGPUWithPoolingMetadata) - - -class MockAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - raise NotImplementedError - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AttentionMetadata - - @staticmethod - def get_builder_cls() -> type["AttentionMetadataBuilder"]: - return AttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> tuple[int, ...]: - raise 
NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - pass - - @staticmethod - def copy_blocks( - kv_caches: list[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - pass - - -def test_model_runner_input(): - sampling_metadata = SamplingMetadata( - ["seq_group"], - "selected_token_indices", - "categorized_sample_indices", - "num_prompts", - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - sampling_metadata=sampling_metadata, - attn_metadata=attn_metadata) - - assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata) - - # Test round trip serialization. - tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = ( - ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - # Check that received copy has correct values. - assert isinstance(received_model_input, - ModelInputForGPUWithSamplingMetadata) - assert received_model_input.input_tokens is not None - assert ( - received_model_input.input_tokens == model_input.input_tokens).all() - assert received_model_input.input_positions is not None - assert (received_model_input.input_positions == model_input.input_positions - ).all() - assert received_model_input.multi_modal_kwargs is None - assert (received_model_input.multi_modal_kwargs == - model_input.multi_modal_kwargs) - assert received_model_input.lora_requests is None - assert received_model_input.lora_requests == model_input.lora_requests - assert received_model_input.lora_mapping is None - assert received_model_input.lora_mapping == model_input.lora_mapping - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_model_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # For sampling metadata, only selected_token_indices is copied. - assert (received_model_input.sampling_metadata.selected_token_indices == - sampling_metadata.selected_token_indices) - assert received_model_input.sampling_metadata.seq_groups is None - - -def test_embedding_model_runner_input(): - pooling_metadata = PoolingMetadata( - seq_groups=[[0]], - seq_data={}, - prompt_lens=[1], - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - model_input = ModelInputForGPUWithPoolingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - pooling_metadata=pooling_metadata, - attn_metadata=attn_metadata) - - assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata) - - # Test round trip serialization. - tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = ( - ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - # Check that received copy has correct values. 
- assert isinstance(received_model_input, - ModelInputForGPUWithPoolingMetadata) - assert received_model_input.input_tokens is not None - assert ( - received_model_input.input_tokens == model_input.input_tokens).all() - assert received_model_input.input_positions is not None - assert (received_model_input.input_positions == model_input.input_positions - ).all() - assert received_model_input.multi_modal_kwargs is None - assert (received_model_input.multi_modal_kwargs == - model_input.multi_modal_kwargs) - assert received_model_input.lora_requests is None - assert received_model_input.lora_requests == model_input.lora_requests - assert received_model_input.lora_mapping is None - assert received_model_input.lora_mapping == model_input.lora_mapping - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_model_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # Pooling metadata is not broadcast. - assert received_model_input.pooling_metadata is None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py deleted file mode 100644 index 0be25aa2fc35d..0000000000000 --- a/tests/worker/test_model_runner.py +++ /dev/null @@ -1,462 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import get_open_port -from vllm.worker.model_runner import ModelRunner - - -def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - model_runner = ModelRunner( - vllm_config=engine_config, - is_driver_worker=True, - ) - return model_runner - - -def test_deepseek_mla_attn_backend_module(): - model_runner = _create_model_runner( - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - trust_remote_code=True, - enable_chunked_prefill=False, - ) - assert model_runner.attn_backend.__name__ == "TritonMLABackend" - - -@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("use_prompt_embeds", [True, False]) -def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enable_prompt_embeds=True, - ) - - seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - expected_input_embeds_len = 0 - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * seq_len, - prompt_embeds=torch.rand(seq_len, 10), - ) - expected_input_embeds_len += seq_len - else: - seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len)) - - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - 
block_tables=block_tables, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - - expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: - expected_selected_token_indices.append(selected_token_start_idx + - seq_len - 1) - selected_token_start_idx += seq_len - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - - # Verify input metadata is correct for prompts. - device = model_runner.device - assert attn_metadata.num_prefills > 0 - assert attn_metadata.num_decode_tokens == 0 - torch.testing.assert_close( - attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == max(seq_lens) - assert attn_metadata.max_decode_seq_len == 0 - - # Test subquery start locs. - start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - - # Test seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - torch.testing.assert_close( - attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - torch.testing.assert_close( - attn_metadata.context_lens_tensor, - torch.zeros(attn_metadata.context_lens_tensor.shape[0], - dtype=torch.int, - device=device)) - - expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) - torch.testing.assert_close(attn_metadata.block_tables, expected) - # Cuda graph should not be used for prerill. 
- assert attn_metadata.use_cuda_graph is False - - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - if expected_input_embeds_len == 0: - torch.testing.assert_close(input_tokens, input_positions) - assert input_embeds is None - else: - assert len(input_embeds) == expected_input_embeds_len - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - torch.allclose(input_tokens, input_positions) - - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("use_prompt_embeds", [True, False]) -def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=False, - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enable_prompt_embeds=True, - ) - - context_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - # Assume each seq group finishes prefill. - for i in range(batch_size): - # make sure all tokens fit into one block - context_len = i % (model_runner.block_size - 1) + 1 - context_lens.append(context_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * context_len, - prompt_embeds=torch.rand(context_len, 10), - ) - output_embed = torch.rand(10) - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(context_len)) - output_embed = None - seq_data.update_num_computed_tokens(context_len) - # Append one token ID since prefill is finished. - seq_data.append_token_id(1, 0, output_embed) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - slot_mapping = attn_metadata.slot_mapping - - assert len(slot_mapping) == len(input_tokens) - - expected_bs = model_runner.vllm_config.pad_for_cudagraph( - len(seq_group_metadata_list)) - # Verify input metadata is correct for prompts. 
- device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_prefill_tokens == 0 - seq_lens = [context_len + 1 for context_len in context_lens] - # seq_lens are padded to expected_bs - for _ in range(expected_bs - len(seq_lens)): - seq_lens.append(1) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.num_decode_tokens == len(seq_lens) - start_idx = 0 - start_loc = [start_idx] - for _ in context_lens: - # decode has only 1 token for query. - start_idx += 1 - start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) - - torch.testing.assert_close( - attn_metadata.context_lens_tensor, - torch.tensor(context_lens, dtype=torch.int, device=device)) - assert attn_metadata.max_decode_seq_len == max(seq_lens) - torch.testing.assert_close( - attn_metadata.seq_lens_tensor[:len(seq_lens)], - torch.tensor(seq_lens, dtype=torch.int, device=device)) - - # block table's first index corresponds to each batch, meaning in - # decoding it is each token. - assert attn_metadata.block_tables.shape[0] == len(input_tokens) - # Block table's second dim corresponds to each token's block number. - # It is padded up to - assert attn_metadata.block_tables.shape[1] == ( - model_runner.get_max_block_per_batch()) - assert attn_metadata.use_cuda_graph is True - - assert len(input_tokens) == expected_bs - assert len(input_positions) == expected_bs - if use_prompt_embeds: - expected_input_embeds_length = start_loc[-1] - assert len(input_embeds) == expected_input_embeds_length - assert expected_input_embeds_length <= expected_bs - else: - assert input_embeds is None - - # Verify Sampling - expected_selected_token_indices = [] - for selected_token_start_idx, _ in enumerate(context_lens): - expected_selected_token_indices.append(selected_token_start_idx) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query lens is all 1 for decode. 
- query_lens=[1 for _ in range(len(context_lens))], - device=model_runner.device, - pin_memory=model_runner.pin_memory) - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - - -def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output.""" - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=False, - ) - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - - assert input_tokens is None - assert input_positions is None - assert attn_metadata is None - - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - - assert input_tokens is None - assert input_positions is None - assert input_embeds is None - assert attn_metadata is None - assert return_seq_lens is None - - -@pytest.fixture -def distributed_init(): - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", - local_rank=0) - ensure_model_parallel_initialized(1, 1) - - -@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) -@pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize('use_prompt_embeds', [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, - distributed_init, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=enforce_eager, - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=True, - enable_prompt_embeds=True, - ) - - # Add prefill requests. 
- seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - prefill_metadata_list: list[SequenceGroupMetadata] = [] - decode_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - prefill_batch_size = batch_size // 2 - decode_batch_size = batch_size - prefill_batch_size - expected_input_embeds_len = 0 - for i in range(prefill_batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * seq_len, - prompt_embeds=torch.rand(seq_len, 10), - ) - expected_input_embeds_len += seq_len - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(seq_len), ) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - prefill_metadata_list.append(seq_group_metadata) - - # Add decode requests - for i in range(prefill_batch_size, batch_size): - # make sure all tokens fit into one block - context_len = i % (model_runner.block_size - 1) + 1 - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * context_len, - prompt_embeds=torch.rand(context_len, 10), - ) - output_embed = torch.rand(10) - # This also iterates the expected input_embeds, because the model - # needs both the input and output embeddings passed into together - expected_input_embeds_len += 1 - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(context_len), ) - output_embed = None - assert len(seq_data.prompt_token_ids) == context_len - seq_data.append_token_id(1, 0, output_embed) - seq_data.update_num_computed_tokens(context_len) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - decode_metadata_list.append(seq_group_metadata) - - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - - prefill_meta_actual = attn_metadata.prefill_metadata - decode_meta_actual = attn_metadata.decode_metadata - - assert len(attn_metadata.slot_mapping) == len(input_tokens) - assert len(input_positions) == len(input_tokens) - assert attn_metadata.num_prefills == prefill_batch_size - assert attn_metadata.num_decode_tokens == decode_batch_size - assert attn_metadata.num_prefill_tokens == sum(seq_lens) - if expected_input_embeds_len == 0: - assert input_embeds is None - else: - assert len(input_embeds) == expected_input_embeds_len - - # Verify attn metadata is consistent. We don't need to test individual - # values here because they are tested above. 
- attn_metadata = model_runner._prepare_model_input_tensors( - seq_group_metadata_list).attn_metadata - - for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), - vars(prefill_meta_actual)): - assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), - vars(decode_meta_actual)): - assert attr_expected[1] == attr_actual[1] diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py deleted file mode 100644 index d8767f700b576..0000000000000 --- a/tests/worker/test_profile.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.worker import Worker - - -def test_gpu_memory_profiling(): - # Tests the gpu profiling that happens in order to determine the number of - # KV cache blocks that we can allocate on the GPU. - # This test mocks the maximum available gpu memory so that it can run on - # any gpu setup. - - # Set up engine args to build a worker. - engine_args = EngineArgs(model="facebook/opt-125m", - dtype="half", - load_format="dummy") - engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 1000 - engine_config.cache_config.num_cpu_blocks = 1000 - - # Create the worker. - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - worker = Worker( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=True, - ) - - # Set 10GiB as the total gpu ram to be device-agnostic - def mock_mem_info(): - current_usage = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - mock_total_bytes = 10 * 1024**3 - free = mock_total_bytes - current_usage - - return (free, mock_total_bytes) - - from unittest.mock import patch - with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info): - # Load the model so we can profile it - worker.init_device() - worker.load_model() - gpu_blocks, _ = worker.determine_num_available_blocks() - - # Peak vram usage by torch should be 0.47 GiB - # Model weights take 0.25 GiB - # No memory should be allocated outside of torch - # 9.0 GiB should be the utilization target - # 8.28 GiB should be available for the KV cache - block_size = CacheEngine.get_cache_block_size( - engine_config.cache_config, engine_config.model_config, - engine_config.parallel_config) - - expected_blocks = (8.28 * 1024**3) // block_size - - # Check within a small tolerance for portability - # Hardware, kernel, or dependency changes could all affect memory - # utilization. - # A 100 block tolerance here should be about 60MB of wiggle room. - assert abs(gpu_blocks - expected_blocks) < 100 diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py deleted file mode 100644 index 6d9f404ac207b..0000000000000 --- a/tests/worker/test_swap.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.worker import Worker - - -def test_swap() -> None: - # Configure the engine. 
- engine_args = EngineArgs(model="distilbert/distilgpt2", - dtype="half", - load_format="dummy") - engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 1000 - engine_config.cache_config.num_cpu_blocks = 1000 - - # Create the worker. - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - worker = Worker( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=True, - ) - - # Initialize the worker. - worker.init_device() - worker.load_model() - worker.initialize_cache( - num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, - num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) - - # Randomly initialize the cache. - gpu_cache = worker.cache_engine[0].gpu_cache - cpu_cache = worker.cache_engine[0].cpu_cache - num_layers = len(gpu_cache) - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - gpu_key_cache.random_() - gpu_value_cache.random_() - cpu_key_cache, cpu_value_cache = cpu_cache[i] - cpu_key_cache.random_() - cpu_value_cache.random_() - - allclose = lambda a, b: torch.allclose( - a.cuda(), b.cuda(), rtol=0.0, atol=0.0) - - # Test swap out. - blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=[], - blocks_to_swap_in=[], - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=[], - ) - worker.execute_model(execute_model_req=execute_model_req) - - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_out: - assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) - assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) - - # Test swap in. - execute_model_req.blocks_to_swap_out = [] - execute_model_req.blocks_to_swap_in = [ - (19, 45), - (67, 23), - (12, 78), - (40, 99), - (1, 71), - ] - worker.execute_model(execute_model_req=execute_model_req) - - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in execute_model_req.blocks_to_swap_in: - assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) - assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/tools/check_init_lazy_imports.py b/tools/check_init_lazy_imports.py index e8e6f07cc33fc..9255aa17db6a6 100644 --- a/tools/check_init_lazy_imports.py +++ b/tools/check_init_lazy_imports.py @@ -17,12 +17,16 @@ REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py" # If you need to add items to whitelist, do it here. 
-ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({ - "vllm.env_override", -}) -ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({ - ".version", -}) +ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset( + { + "vllm.env_override", + } +) +ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset( + { + ".version", + } +) def _is_internal(name: str | None, *, level: int = 0) -> bool: @@ -34,8 +38,7 @@ def _is_internal(name: str | None, *, level: int = 0) -> bool: def _fail(violations: Iterable[tuple[int, str]]) -> None: - print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", - file=sys.stderr) + print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr) for lineno, msg in violations: print(f" Line {lineno}: {msg}", file=sys.stderr) sys.exit(1) @@ -48,7 +51,6 @@ def main() -> None: violations: list[tuple[int, str]] = [] class Visitor(ast.NodeVisitor): - def __init__(self) -> None: super().__init__() self._in_type_checking = False @@ -56,10 +58,10 @@ def main() -> None: def visit_If(self, node: ast.If) -> None: guard_is_type_checking = False test = node.test - if isinstance(test, ast.Attribute) and isinstance( - test.value, ast.Name): - guard_is_type_checking = (test.value.id == "typing" - and test.attr == "TYPE_CHECKING") + if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name): + guard_is_type_checking = ( + test.value.id == "typing" and test.attr == "TYPE_CHECKING" + ) elif isinstance(test, ast.Name): guard_is_type_checking = test.id == "TYPE_CHECKING" @@ -79,24 +81,28 @@ def main() -> None: return for alias in node.names: module_name = alias.name - if _is_internal( - module_name) and module_name not in ALLOWED_IMPORTS: - violations.append(( - node.lineno, - f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501 - )) + if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS: + violations.append( + ( + node.lineno, + f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501 + ) + ) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if self._in_type_checking: return module_as_written = ("." * node.level) + (node.module or "") - if _is_internal( - node.module, level=node.level - ) and module_as_written not in ALLOWED_FROM_MODULES: - violations.append(( - node.lineno, - f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501 - )) + if ( + _is_internal(node.module, level=node.level) + and module_as_written not in ALLOWED_FROM_MODULES + ): + violations.append( + ( + node.lineno, + f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501 + ) + ) Visitor().visit(tree) diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py deleted file mode 100644 index ad0ae45d1d465..0000000000000 --- a/tools/check_pickle_imports.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -import sys - -import regex as re - -try: - import pathspec -except ImportError: - print( - "ERROR: The 'pathspec' library is required. " - "Install it with 'pip install pathspec'.", - file=sys.stderr) - sys.exit(2) - -# List of files (relative to repo root) that are allowed to import pickle or -# cloudpickle -# -# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST: -# The pickle and cloudpickle modules are known to be unsafe when deserializing -# data from potentially untrusted parties. 
They have resulted in multiple CVEs -# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly. -# Before adding new uses of pickle/cloudpickle, please consider safer -# alternatives like msgpack or pydantic that are already in use in vLLM. Only -# add to this list if absolutely necessary and after careful security review. -ALLOWED_FILES = set([ - # pickle - 'vllm/v1/serial_utils.py', - 'vllm/v1/executor/multiproc_executor.py', - 'vllm/multimodal/hasher.py', - 'vllm/transformers_utils/config.py', - 'vllm/model_executor/models/registry.py', - 'tests/utils_/test_utils.py', - 'tests/tokenization/test_cached_tokenizer.py', - 'vllm/distributed/utils.py', - 'vllm/distributed/parallel_state.py', - 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/all_reduce_utils.py', - 'vllm/distributed/device_communicators/shm_broadcast.py', - 'vllm/engine/multiprocessing/engine.py', - 'benchmarks/kernels/graph_machete_bench.py', - 'benchmarks/kernels/benchmark_lora.py', - 'benchmarks/kernels/benchmark_machete.py', - 'benchmarks/fused_kernels/layernorm_rms_benchmarks.py', - 'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py', - 'benchmarks/cutlass_benchmarks/sparse_benchmarks.py', - # cloudpickle - 'vllm/worker/worker_base.py', - 'vllm/executor/mp_distributed_executor.py', - 'vllm/executor/ray_distributed_executor.py', - 'vllm/entrypoints/llm.py', - 'tests/utils.py', - # pickle and cloudpickle - 'vllm/utils/__init__.py', - 'vllm/v1/serial_utils.py', - 'vllm/v1/executor/multiproc_executor.py', - 'vllm/transformers_utils/config.py', - 'vllm/model_executor/models/registry.py', - 'vllm/engine/multiprocessing/client.py', - 'vllm/engine/multiprocessing/engine.py', -]) - -PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" - r"|from\s+(pickle|cloudpickle)\s+import\b)") - - -def is_python_file(path): - return path.endswith('.py') - - -def scan_file(path): - with open(path, encoding='utf-8') as f: - for line in f: - if PICKLE_RE.match(line): - return True - return False - - -def load_gitignore(repo_root): - gitignore_path = os.path.join(repo_root, '.gitignore') - patterns = [] - if os.path.exists(gitignore_path): - with open(gitignore_path, encoding='utf-8') as f: - patterns = f.read().splitlines() - # Always ignore .git directory - patterns.append('.git/') - return pathspec.PathSpec.from_lines('gitwildmatch', patterns) - - -def main(): - repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - spec = load_gitignore(repo_root) - bad_files = [] - for dirpath, _, filenames in os.walk(repo_root): - for filename in filenames: - if not is_python_file(filename): - continue - abs_path = os.path.join(dirpath, filename) - rel_path = os.path.relpath(abs_path, repo_root) - # Skip ignored files - if spec.match_file(rel_path): - continue - if scan_file(abs_path) and rel_path not in ALLOWED_FILES: - bad_files.append(rel_path) - if bad_files: - print("\nERROR: The following files import 'pickle' or 'cloudpickle' " - "but are not in the allowed list:") - for f in bad_files: - print(f" {f}") - print("\nIf this is intentional, update the allowed list in " - "tools/check_pickle_imports.py.") - sys.exit(1) - sys.exit(0) - - -def test_regex(): - test_cases = [ - # Should match - ("import pickle", True), - ("import cloudpickle", True), - ("import pickle as pkl", True), - ("import cloudpickle as cpkl", True), - ("from pickle import *", True), - ("from cloudpickle import dumps", True), - ("from pickle import dumps, loads", True), - ("from cloudpickle import 
(dumps, loads)", True), - (" import pickle", True), - ("\timport cloudpickle", True), - ("from pickle import loads", True), - # Should not match - ("import somethingelse", False), - ("from somethingelse import pickle", False), - ("# import pickle", False), - ("print('import pickle')", False), - ("import pickleas as asdf", False), - ] - for i, (line, should_match) in enumerate(test_cases): - result = bool(PICKLE_RE.match(line)) - assert result == should_match, ( - f"Test case {i} failed: '{line}' " - f"(expected {should_match}, got {result})") - print("All regex tests passed.") - - -if __name__ == '__main__': - if '--test-regex' in sys.argv: - test_regex() - else: - main() diff --git a/tools/check_spdx_header.py b/tools/check_spdx_header.py index ced10ba9097bc..1fcca12519ffa 100644 --- a/tools/check_spdx_header.py +++ b/tools/check_spdx_header.py @@ -7,6 +7,7 @@ from enum import Enum class SPDXStatus(Enum): """SPDX header status enumeration""" + EMPTY = "empty" # empty __init__.py COMPLETE = "complete" MISSING_LICENSE = "missing_license" # Only has copyright line @@ -16,7 +17,8 @@ class SPDXStatus(Enum): FULL_SPDX_HEADER = ( "# SPDX-License-Identifier: Apache-2.0\n" - "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project") + "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" +) LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0" COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501 @@ -123,8 +125,9 @@ def main(): continue # Collect all files that need fixing - all_files_to_fix = (files_missing_both + files_missing_copyright + - files_missing_license) + all_files_to_fix = ( + files_missing_both + files_missing_copyright + files_missing_license + ) if all_files_to_fix: print("The following files are missing the SPDX header:") if files_missing_both: diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py index c01d9d4ab079a..1b83074fe0d20 100644 --- a/tools/check_triton_import.py +++ b/tools/check_triton_import.py @@ -23,8 +23,7 @@ def is_allowed_file(current_file: str) -> bool: def is_forbidden_import(line: str) -> bool: stripped = line.strip() - return bool( - FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES + return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES def parse_diff(diff: str) -> list[str]: @@ -42,24 +41,24 @@ def parse_diff(diff: str) -> list[str]: elif line.startswith("@@"): match = re.search(r"\+(\d+)", line) if match: - current_lineno = int( - match.group(1)) - 1 # next "+ line" is here + current_lineno = int(match.group(1)) - 1 # next "+ line" is here elif line.startswith("+") and not line.startswith("++"): current_lineno += 1 code_line = line[1:] if is_forbidden_import(code_line): violations.append( - f"{current_file}:{current_lineno}: {code_line.strip()}") + f"{current_file}:{current_lineno}: {code_line.strip()}" + ) return violations def get_diff(diff_type: str) -> str: if diff_type == "staged": return subprocess.check_output( - ["git", "diff", "--cached", "--unified=0"], text=True) + ["git", "diff", "--cached", "--unified=0"], text=True + ) elif diff_type == "unstaged": - return subprocess.check_output(["git", "diff", "--unified=0"], - text=True) + return subprocess.check_output(["git", "diff", "--unified=0"], text=True) else: raise ValueError(f"Unknown diff_type: {diff_type}") @@ -75,8 +74,10 @@ def main(): print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) if all_violations: - print("❌ Forbidden direct 
`import triton` detected." - " ➤ Use `from vllm.triton_utils import triton` instead.\n") + print( + "❌ Forbidden direct `import triton` detected." + " ➤ Use `from vllm.triton_utils import triton` instead.\n" + ) for v in all_violations: print(f"❌ {v}") return 1 diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py index 63ceee5829aba..69f43cadc7677 100644 --- a/tools/enforce_regex_import.py +++ b/tools/enforce_regex_import.py @@ -7,24 +7,23 @@ from pathlib import Path import regex as re -FORBIDDEN_PATTERNS = re.compile( - r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)') +FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)") ALLOWED_PATTERNS = [ - re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'), - re.compile(r'^\s*import\s+regex\s*$'), + re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"), + re.compile(r"^\s*import\s+regex\s*$"), ] def get_staged_python_files() -> list[str]: try: result = subprocess.run( - ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'], + ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"], capture_output=True, text=True, - check=True) - files = result.stdout.strip().split( - '\n') if result.stdout.strip() else [] - return [f for f in files if f.endswith('.py')] + check=True, + ) + files = result.stdout.strip().split("\n") if result.stdout.strip() else [] + return [f for f in files if f.endswith(".py")] except subprocess.CalledProcessError: return [] @@ -33,13 +32,14 @@ def is_forbidden_import(line: str) -> bool: line = line.strip() return bool( FORBIDDEN_PATTERNS.match(line) - and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)) + and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS) + ) def check_file(filepath: str) -> list[tuple[int, str]]: violations = [] try: - with open(filepath, encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if is_forbidden_import(line): violations.append((line_num, line.strip())) @@ -72,9 +72,7 @@ def main() -> int: if total_violations > 0: print(f"\n💡 Found {total_violations} violation(s).") print("❌ Please replace 'import re' with 'import regex as re'") - print( - " Also replace 'from re import ...' with 'from regex import ...'" - ) # noqa: E501 + print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501 print("✅ Allowed imports:") print(" - import regex as re") print(" - import regex") # noqa: E501 diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index e163c83e8b513..5a3d734190c1a 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -10,8 +10,12 @@ if [ ! -d "$WORKSPACE" ]; then mkdir -p $WORKSPACE fi +# configurable pip command (default: pip3) +PIP_CMD=${PIP_CMD:-pip3} +CUDA_HOME=${CUDA_HOME:-/usr/local/cuda} + # install dependencies if not installed -pip3 install cmake torch ninja +$PIP_CMD install cmake torch ninja # build nvshmem pushd $WORKSPACE @@ -77,6 +81,7 @@ clone_repo() { local repo_url=$1 local dir_name=$2 local key_file=$3 + local commit_hash=$4 if [ -d "$dir_name" ]; then # Check if directory has uncommitted changes (dirty) @@ -87,27 +92,35 @@ clone_repo() { echo "$dir_name directory exists but clone appears incomplete, cleaning up and re-cloning" rm -rf "$dir_name" git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. 
+ fi else echo "$dir_name directory exists and appears complete; manually update if needed" fi else git clone "$repo_url" + if [ -n "$commit_hash" ]; then + cd "$dir_name" + git checkout "$commit_hash" + cd .. + fi fi } # build and install pplx, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" +clone_repo "https://github.com/ppl-ai/pplx-kernels" "pplx-kernels" "setup.py" "c336faf" cd pplx-kernels -# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 -# PIP_NO_BUILD_ISOLATION=0 disables build isolation -PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . +$PIP_CMD install --no-build-isolation -vvv -e . popd # build and install deepep, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" +clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf" cd DeepEP export NVSHMEM_DIR=$WORKSPACE/nvshmem_install -PIP_NO_BUILD_ISOLATION=0 pip install -vvv -e . +$PIP_CMD install --no-build-isolation -vvv -e . popd diff --git a/tools/flashinfer-build.sh b/tools/flashinfer-build.sh new file mode 100644 index 0000000000000..6c14d87348c3a --- /dev/null +++ b/tools/flashinfer-build.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# This script is used to build FlashInfer wheels with AOT kernels + +set -ex + +# FlashInfer configuration +FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" +FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}" +CUDA_VERSION="${CUDA_VERSION}" +BUILD_WHEEL="${BUILD_WHEEL:-true}" + +if [[ -z "${FLASHINFER_GIT_REF}" ]]; then + echo "❌ FLASHINFER_GIT_REF must be specified" >&2 + exit 1 +fi + +if [[ -z "${CUDA_VERSION}" ]]; then + echo "❌ CUDA_VERSION must be specified" >&2 + exit 1 +fi + +echo "🏗️ Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}" + +# Clone FlashInfer +git clone --depth 1 --recursive --shallow-submodules \ + --branch ${FLASHINFER_GIT_REF} \ + ${FLASHINFER_GIT_REPO} flashinfer + +# Set CUDA arch list based on CUDA version +# Exclude CUDA arches for older versions (11.x and 12.0-12.7) +if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" +elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" +else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" +fi + +echo "🏗️ Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + +pushd flashinfer + # Make sure the wheel is built for the correct CUDA version + export UV_TORCH_BACKEND=cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + + # Build AOT kernels + export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" + export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" + python3 -m flashinfer.aot + + if [[ "${BUILD_WHEEL}" == "true" ]]; then + # Build wheel for distribution + uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist . + echo "✅ FlashInfer wheel built successfully in flashinfer-dist/" + else + # Install directly (for Dockerfile) + uv pip install --system --no-build-isolation --force-reinstall . 
+ echo "✅ FlashInfer installed successfully" + fi +popd + +# Cleanup +rm -rf flashinfer \ No newline at end of file diff --git a/tools/generate_cmake_presets.py b/tools/generate_cmake_presets.py index 5f92f2f5848fa..85847c2c0fe80 100644 --- a/tools/generate_cmake_presets.py +++ b/tools/generate_cmake_presets.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import json import multiprocessing import os @@ -11,8 +12,7 @@ try: # most reliable source of truth for vLLM's build. from torch.utils.cpp_extension import CUDA_HOME except ImportError: - print("Warning: PyTorch not found. " - "Falling back to CUDA_HOME environment variable.") + print("Warning: PyTorch not found. Falling back to CUDA_HOME environment variable.") CUDA_HOME = os.environ.get("CUDA_HOME") @@ -26,7 +26,7 @@ def get_cpu_cores(): return multiprocessing.cpu_count() -def generate_presets(output_path="CMakeUserPresets.json"): +def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False): """Generates the CMakeUserPresets.json file.""" print("Attempting to detect your system configuration...") @@ -37,8 +37,7 @@ def generate_presets(output_path="CMakeUserPresets.json"): prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc") if os.path.exists(prospective_path): nvcc_path = prospective_path - print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: " - f"{nvcc_path}") + print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}") if not nvcc_path: nvcc_path = which("nvcc") @@ -48,7 +47,8 @@ def generate_presets(output_path="CMakeUserPresets.json"): if not nvcc_path: nvcc_path_input = input( "Could not automatically find 'nvcc'. Please provide the full " - "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ") + "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): " + ) nvcc_path = nvcc_path_input.strip() print(f"Using NVCC path: {nvcc_path}") @@ -61,12 +61,13 @@ def generate_presets(output_path="CMakeUserPresets.json"): "Could not automatically find Python executable. Please provide " "the full path to your Python executable for vLLM development " "(typically from your virtual environment, e.g., " - "/home/user/venvs/vllm/bin/python): ") + "/home/user/venvs/vllm/bin/python): " + ) python_executable = input(python_executable_prompt).strip() if not python_executable: raise ValueError( - "Could not determine Python executable. Please provide it " - "manually.") + "Could not determine Python executable. Please provide it manually." + ) print(f"Using Python executable: {python_executable}") @@ -74,20 +75,23 @@ def generate_presets(output_path="CMakeUserPresets.json"): cpu_cores = get_cpu_cores() nvcc_threads = min(4, cpu_cores) cmake_jobs = max(1, cpu_cores // nvcc_threads) - print(f"Detected {cpu_cores} CPU cores. " - f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.") + print( + f"Detected {cpu_cores} CPU cores. " + f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}." 
+ ) # Get vLLM project root (assuming this script is in vllm/tools/) - project_root = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..")) + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) print(f"VLLM project root detected as: {project_root}") # Ensure python_executable path is absolute or resolvable if not os.path.isabs(python_executable) and which(python_executable): python_executable = os.path.abspath(which(python_executable)) elif not os.path.isabs(python_executable): - print(f"Warning: Python executable '{python_executable}' is not an " - "absolute path and not found in PATH. CMake might not find it.") + print( + f"Warning: Python executable '{python_executable}' is not an " + "absolute path and not found in PATH. CMake might not find it." + ) cache_variables = { "CMAKE_CUDA_COMPILER": nvcc_path, @@ -120,50 +124,57 @@ def generate_presets(output_path="CMakeUserPresets.json"): configure_preset["generator"] = "Ninja" cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}" else: - print("Ninja not found, using default generator. " - "Build may be slower.") + print("Ninja not found, using default generator. Build may be slower.") presets = { - "version": - 6, + "version": 6, # Keep in sync with CMakeLists.txt and requirements/build.txt - "cmakeMinimumRequired": { - "major": 3, - "minor": 26, - "patch": 1 - }, + "cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1}, "configurePresets": [configure_preset], - "buildPresets": [{ - "name": "release", - "configurePreset": "release", - "jobs": cmake_jobs, - }], + "buildPresets": [ + { + "name": "release", + "configurePreset": "release", + "jobs": cmake_jobs, + } + ], } output_file_path = os.path.join(project_root, output_path) if os.path.exists(output_file_path): - overwrite = input( - f"'{output_file_path}' already exists. Overwrite? (y/N): ").strip( - ).lower() - if overwrite != 'y': - print("Generation cancelled.") - return + if force_overwrite: + print(f"Overwriting existing file '{output_file_path}'") + else: + overwrite = ( + input(f"'{output_file_path}' already exists. Overwrite? (y/N): ") + .strip() + .lower() + ) + if overwrite != "y": + print("Generation cancelled.") + return try: with open(output_file_path, "w") as f: json.dump(presets, f, indent=4) print(f"Successfully generated '{output_file_path}'") print("\nTo use this preset:") - print( - f"1. Ensure you are in the vLLM root directory: cd {project_root}") + print(f"1. Ensure you are in the vLLM root directory: cd {project_root}") print("2. Initialize CMake: cmake --preset release") - print("3. Build+install: cmake --build --preset release " - "--target install") + print("3. 
Build+install: cmake --build --preset release --target install") except OSError as e: print(f"Error writing file: {e}") if __name__ == "__main__": - generate_presets() + parser = argparse.ArgumentParser() + parser.add_argument( + "--force-overwrite", + action="store_true", + help="Force overwrite existing CMakeUserPresets.json without prompting", + ) + + args = parser.parse_args() + generate_presets(force_overwrite=args.force_overwrite) diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh index 33849581d2c0e..4f2cd302c3eff 100755 --- a/tools/install_deepgemm.sh +++ b/tools/install_deepgemm.sh @@ -6,7 +6,7 @@ set -e # Default values DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" -DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" +DEEPGEMM_GIT_REF="594953acce41793ae00a1233eb516044d604bcb6" # Parse command line arguments while [[ $# -gt 0 ]]; do @@ -105,4 +105,4 @@ fi popd -echo "✅ DeepGEMM installation completed successfully" \ No newline at end of file +echo "✅ DeepGEMM installation completed successfully" diff --git a/tools/install_gdrcopy.sh b/tools/install_gdrcopy.sh new file mode 100755 index 0000000000000..481723320c63b --- /dev/null +++ b/tools/install_gdrcopy.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Usage: install_gdrcopy.sh <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch> +# uuarch must be "x64" or "aarch64" +# Optional: set GDRCOPY_VERSION to override the libgdrapi package version (default: 2.5.1-1) +# Requires: curl, apt-get, root privileges +if [[ $(id -u) -ne 0 ]]; then + echo "Must be run as root" >&2 + + exit 1 +fi +if [[ $# -ne 3 ]]; then + echo "Usage: $0 <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch(x64|aarch64)>" >&2 + exit 1 +fi + +OS_VER="$1" +CUDA_VER="$2" +UUARCH_RAW="$3" + +# Normalize/validate arch +case "${UUARCH_RAW,,}" in + aarch64|arm64) + URL_ARCH="aarch64" + DEB_ARCH="arm64" + ;; + x64|x86_64|amd64) + URL_ARCH="x64" + DEB_ARCH="amd64" + ;; + *) + echo "Unsupported uuarch: ${UUARCH_RAW}. Use 'x64' or 'aarch64'." 
>&2 + exit 1 + ;; +esac + +OS_VER_LOWER="$(tr '[:upper:]' '[:lower:]' <<<"$OS_VER")" +GDRCOPY_PKG_VER="${GDRCOPY_VERSION:-2.5.1-1}" + +DEB_NAME="libgdrapi_${GDRCOPY_PKG_VER}_${DEB_ARCH}.${OS_VER}.deb" +BASE_URL="https://developer.download.nvidia.com/compute/redist/gdrcopy" +URL="${BASE_URL}/CUDA%20${CUDA_VER}/${OS_VER_LOWER}/${URL_ARCH}/${DEB_NAME}" + +echo "Downloading: ${URL}" +TMPDIR="$(mktemp -d)" +trap 'rm -rf "${TMPDIR}"' EXIT + +curl -fSL "${URL}" -o "${TMPDIR}/${DEB_NAME}" + +export DEBIAN_FRONTEND=noninteractive +apt-get update +apt-get install -y "${TMPDIR}/${DEB_NAME}" +apt-get clean +rm -rf /var/lib/apt/lists/* + +echo "Installed ${DEB_NAME}" diff --git a/tools/install_nixl.sh b/tools/install_nixl.sh deleted file mode 100644 index 56717cfb77f7b..0000000000000 --- a/tools/install_nixl.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash -# Usage: ./install_nixl.sh [--force] - -FORCE=false -if [ "$1" == "--force" ]; then - FORCE=true -fi - -SUDO=false -if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then - SUDO=true -fi - -ARCH=$(uname -m) - -ROOT_DIR="/usr/local" -mkdir -p "$ROOT_DIR" -GDR_HOME="$ROOT_DIR/gdrcopy" -UCX_HOME="$ROOT_DIR/ucx" -NIXL_HOME="$ROOT_DIR/nixl" -CUDA_HOME=/usr/local/cuda - -export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH" -export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH" - -TEMP_DIR="nixl_installer" -mkdir -p "$TEMP_DIR" -cd "$TEMP_DIR" - -pip install meson ninja pybind11 - -if [ ! -e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then - echo "Installing gdrcopy\n" - wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz - tar xzf v2.5.tar.gz; rm v2.5.tar.gz - cd gdrcopy-2.5 - make prefix=$GDR_HOME CUDA=$CUDA_HOME all install - - if $SUDO; then - echo "Running insmod.sh with sudo" - sudo ./insmod.sh - else - echo "Skipping insmod.sh - sudo not available" - echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed" - fi - - cd .. -else - echo "Found /dev/gdrdrv. Skipping gdrcopy installation" -fi - -if ! command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing UCX" - wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz - tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz - cd ucx-1.18.0 - - # Checking Mellanox NICs - MLX_OPTS="" - if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then - echo "Mellanox NIC detected, adding Mellanox-specific options" - MLX_OPTS="--with-rdmacm \ - --with-mlx5-dv \ - --with-ib-hw-tm" - fi - - ./configure --prefix=$UCX_HOME \ - --enable-shared \ - --disable-static \ - --disable-doxygen-doc \ - --enable-optimizations \ - --enable-cma \ - --enable-devel-headers \ - --with-cuda=$CUDA_HOME \ - --with-dm \ - --with-gdrcopy=$GDR_HOME \ - --with-verbs \ - --enable-mt \ - $MLX_OPTS - make -j - make -j install-strip - - if $SUDO; then - echo "Running ldconfig with sudo" - sudo ldconfig - else - echo "Skipping ldconfig - sudo not available" - echo "Please run 'sudo ldconfig' manually if needed" - fi - - cd .. -else - echo "Found existing UCX. Skipping UCX installation" -fi - -if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing NIXL" - wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz - tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz - cd nixl-0.2.0 - meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME - cd build - ninja - ninja install - - cd ../.. -else - echo "Found existing NIXL. 
Skipping NIXL installation" -fi diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py new file mode 100644 index 0000000000000..c808b01d2e94b --- /dev/null +++ b/tools/install_nixl_from_source_ubuntu.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# install_prerequisites.py +import argparse +import glob +import os +import subprocess +import sys + +# --- Configuration --- +WHEELS_CACHE_HOME = os.environ.get("WHEELS_CACHE_HOME", "/tmp/wheels_cache") +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +UCX_DIR = os.path.join("/tmp", "ucx_source") +NIXL_DIR = os.path.join("/tmp", "nixl_source") +UCX_INSTALL_DIR = os.path.join("/tmp", "ucx_install") +UCX_REPO_URL = "https://github.com/openucx/ucx.git" +NIXL_REPO_URL = "https://github.com/ai-dynamo/nixl.git" + + +# --- Helper Functions --- +def run_command(command, cwd=".", env=None): + """Helper function to run a shell command and check for errors.""" + print(f"--> Running command: {' '.join(command)} in '{cwd}'", flush=True) + subprocess.check_call(command, cwd=cwd, env=env) + + +def is_pip_package_installed(package_name): + """Checks if a package is installed via pip without raising an exception.""" + result = subprocess.run( + [sys.executable, "-m", "pip", "show", package_name], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return result.returncode == 0 + + +def find_nixl_wheel_in_cache(cache_dir): + """Finds a nixl wheel file in the specified cache directory.""" + # The repaired wheel will have a 'manylinux' tag, but this glob still works. + search_pattern = os.path.join(cache_dir, "nixl-*.whl") + wheels = glob.glob(search_pattern) + if wheels: + # Sort to get the most recent/highest version if multiple exist + wheels.sort() + return wheels[-1] + return None + + +def install_system_dependencies(): + """Installs required system packages using apt-get if run as root.""" + if os.geteuid() != 0: + print("\n---", flush=True) + print( + "WARNING: Not running as root. \ + Skipping system dependency installation.", + flush=True, + ) + print( + "Please ensure the listed packages are installed on your system:", + flush=True, + ) + print( + " patchelf build-essential git cmake ninja-build \ + autotools-dev automake meson libtool libtool-bin", + flush=True, + ) + print("---\n", flush=True) + return + + print("--- Running as root. Installing system dependencies... ---", flush=True) + apt_packages = [ + "patchelf", # <-- Add patchelf here + "build-essential", + "git", + "cmake", + "ninja-build", + "autotools-dev", + "automake", + "meson", + "libtool", + "libtool-bin", + ] + run_command(["apt-get", "update"]) + run_command(["apt-get", "install", "-y"] + apt_packages) + print("--- System dependencies installed successfully. ---\n", flush=True) + + +def build_and_install_prerequisites(args): + """Builds UCX and NIXL from source, creating a self-contained wheel.""" + + if not args.force_reinstall and is_pip_package_installed("nixl"): + print("--> NIXL is already installed. 
Nothing to do.", flush=True) + return + + cached_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not args.force_reinstall and cached_wheel: + print( + f"\n--> Found self-contained wheel: \ + {os.path.basename(cached_wheel)}.", + flush=True, + ) + print("--> Installing from cache, skipping all source builds.", flush=True) + install_command = [sys.executable, "-m", "pip", "install", cached_wheel] + run_command(install_command) + print("\n--- Installation from cache complete. ---", flush=True) + return + + print( + "\n--> No installed package or cached wheel found. \ + Starting full build process...", + flush=True, + ) + print("\n--> Installing auditwheel...", flush=True) + run_command([sys.executable, "-m", "pip", "install", "auditwheel"]) + install_system_dependencies() + ucx_install_path = os.path.abspath(UCX_INSTALL_DIR) + print(f"--> Using wheel cache directory: {WHEELS_CACHE_HOME}", flush=True) + os.makedirs(WHEELS_CACHE_HOME, exist_ok=True) + + # -- Step 1: Build UCX from source -- + print("\n[1/3] Configuring and building UCX from source...", flush=True) + if not os.path.exists(UCX_DIR): + run_command(["git", "clone", UCX_REPO_URL, UCX_DIR]) + ucx_source_path = os.path.abspath(UCX_DIR) + run_command(["git", "checkout", "v1.19.x"], cwd=ucx_source_path) + run_command(["./autogen.sh"], cwd=ucx_source_path) + configure_command = [ + "./configure", + f"--prefix={ucx_install_path}", + "--enable-shared", + "--disable-static", + "--disable-doxygen-doc", + "--enable-optimizations", + "--enable-cma", + "--enable-devel-headers", + "--with-verbs", + "--enable-mt", + "--with-ze=no", + ] + run_command(configure_command, cwd=ucx_source_path) + run_command(["make", "-j", str(os.cpu_count() or 1)], cwd=ucx_source_path) + run_command(["make", "install"], cwd=ucx_source_path) + print("--- UCX build and install complete ---", flush=True) + + # -- Step 2: Build NIXL wheel from source -- + print("\n[2/3] Building NIXL wheel from source...", flush=True) + if not os.path.exists(NIXL_DIR): + run_command(["git", "clone", NIXL_REPO_URL, NIXL_DIR]) + + build_env = os.environ.copy() + build_env["PKG_CONFIG_PATH"] = os.path.join(ucx_install_path, "lib", "pkgconfig") + ucx_lib_path = os.path.join(ucx_install_path, "lib") + ucx_plugin_path = os.path.join(ucx_lib_path, "ucx") + existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + build_env["LD_LIBRARY_PATH"] = ( + f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":") + ) + print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True) + + temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse") + run_command( + [ + sys.executable, + "-m", + "pip", + "wheel", + ".", + "--no-deps", + f"--wheel-dir={temp_wheel_dir}", + ], + cwd=os.path.abspath(NIXL_DIR), + env=build_env, + ) + + # -- Step 3: Repair the wheel by copying UCX libraries -- + print("\n[3/3] Repairing NIXL wheel to include UCX libraries...", flush=True) + unrepaired_wheel = find_nixl_wheel_in_cache(temp_wheel_dir) + if not unrepaired_wheel: + raise RuntimeError("Failed to find the NIXL wheel after building it.") + + # We tell auditwheel to ignore the plugin that mesonpy already handled. 
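+    # Note: `auditwheel repair` copies the external shared libraries that the
+    # built extension links against (here, the UCX libraries from step 1) into
+    # the wheel and patches their RPATHs with patchelf, which is what makes the
+    # resulting wheel self-contained. `--exclude` skips vendoring one library.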
+ auditwheel_command = [ + "auditwheel", + "repair", + "--exclude", + "libplugin_UCX.so", # <-- Exclude because mesonpy already includes it + unrepaired_wheel, + f"--wheel-dir={WHEELS_CACHE_HOME}", + ] + run_command(auditwheel_command, env=build_env) + + # --- CLEANUP --- + # No more temporary files to remove, just the temp wheelhouse + run_command(["rm", "-rf", temp_wheel_dir]) + # --- END CLEANUP --- + + newly_built_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not newly_built_wheel: + raise RuntimeError("Failed to find the repaired NIXL wheel.") + + print( + f"--> Successfully built self-contained wheel: \ + {os.path.basename(newly_built_wheel)}. Now installing...", + flush=True, + ) + install_command = [sys.executable, "-m", "pip", "install", newly_built_wheel] + if args.force_reinstall: + install_command.insert(-1, "--force-reinstall") + + run_command(install_command) + print("--- NIXL installation complete ---", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Build and install UCX and NIXL dependencies." + ) + parser.add_argument( + "--force-reinstall", + action="store_true", + help="Force rebuild and reinstall of UCX and NIXL \ + even if they are already installed.", + ) + args = parser.parse_args() + build_and_install_prerequisites(args) diff --git a/tools/mypy.sh b/tools/mypy.sh deleted file mode 100755 index 781d8fc02884b..0000000000000 --- a/tools/mypy.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -CI=${1:-0} -PYTHON_VERSION=${2:-local} - -if [ "$CI" -eq 1 ]; then - set -e -fi - -if [ $PYTHON_VERSION == "local" ]; then - PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') -fi - -run_mypy() { - echo "Running mypy on $1" - if [ "$CI" -eq 1 ] && [ -z "$1" ]; then - mypy --python-version "${PYTHON_VERSION}" "$@" - return - fi - mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@" -} - -run_mypy # Note that this is less strict than CI -run_mypy tests -run_mypy vllm/attention -run_mypy vllm/compilation -run_mypy vllm/distributed -run_mypy vllm/engine -run_mypy vllm/executor -run_mypy vllm/inputs -run_mypy vllm/lora -run_mypy vllm/model_executor -run_mypy vllm/plugins -run_mypy vllm/worker -run_mypy vllm/v1 diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py new file mode 100644 index 0000000000000..7944b7c9b275c --- /dev/null +++ b/tools/pre_commit/check_pickle_imports.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys + +import regex as re + +# List of files (relative to repo root) that are allowed to import pickle or +# cloudpickle +# +# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST: +# The pickle and cloudpickle modules are known to be unsafe when deserializing +# data from potentially untrusted parties. They have resulted in multiple CVEs +# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly. +# Before adding new uses of pickle/cloudpickle, please consider safer +# alternatives like msgpack or pydantic that are already in use in vLLM. Only +# add to this list if absolutely necessary and after careful security review. 
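+#
+# As an illustrative sketch of the safer alternatives mentioned above (this
+# assumes the third-party `msgpack` package is installed), a plain-data
+# round-trip looks like:
+#
+#     import msgpack
+#     payload = msgpack.packb({"prompt": "hi", "max_tokens": 8})
+#     data = msgpack.unpackb(payload)
+#
+# Unlike pickle, deserializing untrusted msgpack input only rebuilds plain
+# data types and cannot execute arbitrary code.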
+ALLOWED_FILES = { + # pickle + "vllm/v1/serial_utils.py", + "vllm/v1/executor/multiproc_executor.py", + "vllm/multimodal/hasher.py", + "vllm/transformers_utils/config.py", + "vllm/model_executor/models/registry.py", + "vllm/compilation/caching.py", + "tests/utils_/test_utils.py", + "tests/tokenization/test_cached_tokenizer.py", + "vllm/distributed/utils.py", + "vllm/distributed/parallel_state.py", + "vllm/distributed/device_communicators/all_reduce_utils.py", + "vllm/distributed/device_communicators/shm_broadcast.py", + "vllm/distributed/device_communicators/shm_object_storage.py", + "benchmarks/kernels/graph_machete_bench.py", + "benchmarks/kernels/benchmark_lora.py", + "benchmarks/kernels/benchmark_machete.py", + "benchmarks/fused_kernels/layernorm_rms_benchmarks.py", + "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py", + "benchmarks/cutlass_benchmarks/sparse_benchmarks.py", + # cloudpickle + "vllm/executor/mp_distributed_executor.py", + "vllm/executor/ray_distributed_executor.py", + "vllm/entrypoints/llm.py", + "tests/utils.py", + # pickle and cloudpickle + "vllm/utils/__init__.py", +} + +PICKLE_RE = re.compile( + r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" + r"|from\s+(pickle|cloudpickle)\s+import\b)" +) + + +def scan_file(path: str) -> int: + with open(path, encoding="utf-8") as f: + for i, line in enumerate(f, 1): + if PICKLE_RE.match(line): + print( + f"{path}:{i}: " + "\033[91merror:\033[0m " # red color + "Found pickle/cloudpickle import" + ) + return 1 + return 0 + + +def main(): + returncode = 0 + for filename in sys.argv[1:]: + if filename in ALLOWED_FILES: + continue + returncode |= scan_file(filename) + return returncode + + +def test_regex(): + test_cases = [ + # Should match + ("import pickle", True), + ("import cloudpickle", True), + ("import pickle as pkl", True), + ("import cloudpickle as cpkl", True), + ("from pickle import *", True), + ("from cloudpickle import dumps", True), + ("from pickle import dumps, loads", True), + ("from cloudpickle import (dumps, loads)", True), + (" import pickle", True), + ("\timport cloudpickle", True), + ("from pickle import loads", True), + # Should not match + ("import somethingelse", False), + ("from somethingelse import pickle", False), + ("# import pickle", False), + ("print('import pickle')", False), + ("import pickleas as asdf", False), + ] + for i, (line, should_match) in enumerate(test_cases): + result = bool(PICKLE_RE.match(line)) + assert result == should_match, ( + f"Test case {i} failed: '{line}' (expected {should_match}, got {result})" + ) + print("All regex tests passed.") + + +if __name__ == "__main__": + if "--test-regex" in sys.argv: + test_regex() + else: + sys.exit(main()) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py new file mode 100755 index 0000000000000..22ee08535bddb --- /dev/null +++ b/tools/pre_commit/mypy.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Run mypy on changed files. + +This script is designed to be used as a pre-commit hook. It runs mypy +on files that have been changed. It groups files into different mypy calls +based on their directory to avoid import following issues. + +Usage: + python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...> + +Args: + ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to + "silent" for the main group of files. + python_version: Python version to use (e.g., "3.10") or "local" to use + the local Python version. 
+ changed_files: List of changed files to check. +""" + +import subprocess +import sys +from typing import Optional + +import regex as re + +FILES = [ + "vllm/*.py", + "vllm/assets", + "vllm/entrypoints", + "vllm/inputs", + "vllm/logging_utils", + "vllm/multimodal", + "vllm/platforms", + "vllm/transformers_utils", + "vllm/triton_utils", + "vllm/usage", +] + +# After fixing errors resulting from changing follow_imports +# from "skip" to "silent", move the following directories to FILES +SEPARATE_GROUPS = [ + "tests", + "vllm/attention", + "vllm/compilation", + "vllm/distributed", + "vllm/engine", + "vllm/executor", + "vllm/inputs", + "vllm/lora", + "vllm/model_executor", + "vllm/plugins", + "vllm/worker", + "vllm/v1", +] + +# TODO(woosuk): Include the code from Megatron and HuggingFace. +EXCLUDE = [ + "vllm/model_executor/parallel_utils", + "vllm/model_executor/models", + "vllm/model_executor/layers/fla/ops", + # Ignore triton kernels in ops. + "vllm/attention/ops", +] + + +def group_files(changed_files: list[str]) -> dict[str, list[str]]: + """ + Group changed files into different mypy calls. + + Args: + changed_files: List of changed files. + + Returns: + A dictionary mapping file group names to lists of changed files. + """ + exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") + files_pattern = re.compile(f"^({'|'.join(FILES)}).*") + file_groups = {"": []} + file_groups.update({k: [] for k in SEPARATE_GROUPS}) + for changed_file in changed_files: + # Skip files which should be ignored completely + if exclude_pattern.match(changed_file): + continue + # Group files by mypy call + if files_pattern.match(changed_file): + file_groups[""].append(changed_file) + continue + else: + for directory in SEPARATE_GROUPS: + if re.match(f"^{directory}.*", changed_file): + file_groups[directory].append(changed_file) + break + return file_groups + + +def mypy( + targets: list[str], + python_version: Optional[str], + follow_imports: Optional[str], + file_group: str, +) -> int: + """ + Run mypy on the given targets. + + Args: + targets: List of files or directories to check. + python_version: Python version to use (e.g., "3.10") or None to use + the default mypy version. + follow_imports: Value for the --follow-imports option or None to use + the default mypy behavior. + file_group: The file group name for logging purposes. + + Returns: + The return code from mypy. 
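+
+    Example (illustrative; the file path is just a placeholder):
+        mypy(["vllm/engine/arg_utils.py"], "3.10", "skip", "vllm/engine")
+        # runs: mypy --python-version 3.10 --follow-imports skip
+        #       vllm/engine/arg_utils.py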
+ """ + args = ["mypy"] + if python_version is not None: + args += ["--python-version", python_version] + if follow_imports is not None: + args += ["--follow-imports", follow_imports] + print(f"$ {' '.join(args)} {file_group}") + return subprocess.run(args + targets, check=False).returncode + + +def main(): + ci = sys.argv[1] == "1" + python_version = sys.argv[2] + file_groups = group_files(sys.argv[3:]) + + if python_version == "local": + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + returncode = 0 + for file_group, changed_files in file_groups.items(): + follow_imports = None if ci and file_group == "" else "skip" + if changed_files: + returncode |= mypy( + changed_files, python_version, follow_imports, file_group + ) + return returncode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py index 42dfede9e9870..fd237c0b214a4 100755 --- a/tools/profiler/nsys_profile_tools/gputrc2graph.py +++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - This generates gpu kernel analysis output from nsys rep. Will call nsys - stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate - csv and html output for analysis +This generates gpu kernel analysis output from nsys rep. Will call nsys +stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate +csv and html output for analysis """ + import argparse import logging import os @@ -16,13 +17,13 @@ logger = logging.getLogger(__name__) # helper data class for annotating kernels def load_engine_model(): - """ returns engine_model built from all json files in the current dir """ + """returns engine_model built from all json files in the current dir""" import glob import json + engine_model = {} - json_files = glob.glob( - os.path.join(os.path.dirname(__file__) or ".", "*.json")) + json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json")) for fname in json_files: with open(fname, encoding="utf-8") as f: engine_model.update(json.load(f)) @@ -30,54 +31,54 @@ def load_engine_model(): class GPUTrace2Graph: - """ - Parses output of nsys report, generates csv and bar chart output + """ + Parses output of nsys report, generates csv and bar chart output """ def __init__(self): import pandas as pd # avoid importing till needed + self.pd = pd self.pd.options.mode.copy_on_write = True # helper functions for generating trace->summary csvs def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): - logger.info('loading %s', in_file) + logger.info("loading %s", in_file) df = self.pd.read_csv( - in_file, - usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name']) - df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)'] + in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"] + ) + df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"] df = self.sum_non_overlapping_intervals(df) # get ready to print table with elapsed times per kernel - df['Instances'] = 1 - df_sum = df.groupby('Name', as_index=False).agg({ - 'Elapsed Time (ns)': 'sum', - 'Duration (ns)': 'sum', - 'Instances': 'size' - }) + df["Instances"] = 1 + df_sum = df.groupby("Name", as_index=False).agg( + {"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"} + ) # generate csv - df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 
1e9 - df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9 - df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False) - df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances', - 'Name']].to_csv(out_file, index=False) + df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9 + df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9 + df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False) + df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv( + out_file, index=False + ) def sum_non_overlapping_intervals(self, df): - """ - returns new sorted df with Elapsed Time (ns) column using - vectorized operations + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations """ logger.info("sorting %s trace records by start time", str(df.shape)) # Sort by start time and reset index - df = df.sort_values(by='Start (ns)').reset_index(drop=True) + df = df.sort_values(by="Start (ns)").reset_index(drop=True) # Initialize elapsed time as duration - df['Elapsed Time (ns)'] = df['Duration (ns)'] + df["Elapsed Time (ns)"] = df["Duration (ns)"] # Get numpy arrays for faster operations - starts = df['Start (ns)'].values - ends = df['End (ns)'].values + starts = df["Start (ns)"].values + ends = df["End (ns)"].values # Keep track of current interval end current_end = ends[0] @@ -85,16 +86,17 @@ class GPUTrace2Graph: # Update current_end for overlapping intervals for i in range(1, len(df)): if i % display_units == 0: - print(f'processing trace: {int(i/len(df) * 100)} %', end="\r") + print(f"processing trace: {int(i / len(df) * 100)} %", end="\r") if starts[i] <= current_end: if ends[i] > current_end: # Partial overlap - df.iloc[i, df.columns.get_loc('Elapsed Time (ns)' - )] = ends[i] - current_end + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = ( + ends[i] - current_end + ) current_end = ends[i] else: # Complete overlap - df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0 + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0 else: # No overlap current_end = ends[i] @@ -103,147 +105,167 @@ class GPUTrace2Graph: # functions for generating html files def make_html(self, df, output_dir, title): - """ make html graph from df """ + """make html graph from df""" import plotly.express as px + if df.empty: return - output_name = output_dir + '/result' + output_name = output_dir + "/result" if not title: - title = 'Model_Engine' - x = 'Model_Engine' - y = 'Elapsed Time (sec)' - color = 'Category' + title = "Model_Engine" + x = "Model_Engine" + y = "Elapsed Time (sec)" + color = "Category" """ generate kernel mapping table """ # Sort Model_Engine categories by last field after underscore - df['Model_Engine'] = self.pd.Categorical( - df['Model_Engine'], - sorted(df['Model_Engine'].unique(), - key=lambda x: x.split('_')[-1])) - df[['Model_Engine', color, 'Instances', 'Name', - y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False) - graph = px.histogram(df.round(2), - x=x, - y=y, - title=(f'{y} for {title}'), - color=color, - text_auto=True) + df["Model_Engine"] = self.pd.Categorical( + df["Model_Engine"], + sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]), + ) + df[["Model_Engine", color, "Instances", "Name", y]].sort_values( + by=color + ).to_csv(f"{output_name}.csv", index=False) + graph = px.histogram( + df.round(2), + x=x, + y=y, + title=(f"{y} for {title}"), + color=color, + text_auto=True, + ) # wrap x axis labels graph.update_xaxes(automargin=True) - 
graph.write_html(f'{output_name}.html') + graph.write_html(f"{output_name}.html") """ Generate data table with columns per Model_Engine into result.html """ - pivot_df = df.pivot_table(values='Elapsed Time (sec)', - index='Category', - columns='Model_Engine', - aggfunc='sum', - observed=False).round(2) + pivot_df = df.pivot_table( + values="Elapsed Time (sec)", + index="Category", + columns="Model_Engine", + aggfunc="sum", + observed=False, + ).round(2) # Add sum row at bottom - pivot_df.loc['total_elapsed_sec'] = pivot_df.sum() - pivot_df.fillna('').to_html('temp.html') - with (open(f'{output_name}.html', 'a', encoding='utf-8') as - outfile, open('temp.html', encoding='utf-8') as infile): + pivot_df.loc["total_elapsed_sec"] = pivot_df.sum() + pivot_df.fillna("").to_html("temp.html") + with ( + open(f"{output_name}.html", "a", encoding="utf-8") as outfile, + open("temp.html", encoding="utf-8") as infile, + ): outfile.write(infile.read()) - os.remove('temp.html') + os.remove("temp.html") - print(f'Finished generating: \n' - f' {output_name}.html for stack bar chart \n' - f' {output_name}.csv for Kernel-Category mapping') + print( + f"Finished generating: \n" + f" {output_name}.html for stack bar chart \n" + f" {output_name}.csv for Kernel-Category mapping" + ) def anno_gpu_kernname(self, df, mapping): - """ add "Category" column """ + """add "Category" column""" def anno_gpu_kernname_helper(name): for kern_name, val in mapping.items(): if re.search(kern_name, name): return val - df['Category'] = df['Name'].apply(anno_gpu_kernname_helper) + df["Category"] = df["Name"].apply(anno_gpu_kernname_helper) def make_nongpu_row(self, df, nongpu_sec): - """ this will append non-gpu time entry at end of df """ + """this will append non-gpu time entry at end of df""" nongpu_row = self.pd.DataFrame([df.iloc[-1]]) - nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)' - nongpu_row['Instances'] = 1 - nongpu_row['Elapsed Time (sec)'] = nongpu_sec - return (nongpu_row) + nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)" + nongpu_row["Instances"] = 1 + nongpu_row["Elapsed Time (sec)"] = nongpu_sec + return nongpu_row def is_valid_file(self, base_file): - """ asserts if base_file is non-existent or is empty """ - assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \ - f"{base_file} doesn't exist or is empty" + """asserts if base_file is non-existent or is empty""" + assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, ( + f"{base_file} doesn't exist or is empty" + ) def should_gen_file(self, new_file, base_file): - """ figure out if new file should be generated from base_file """ + """figure out if new file should be generated from base_file""" self.is_valid_file(base_file) - if (os.path.exists(new_file) - and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) - and (os.path.getsize(base_file) > 0)): - logger.info('reusing %s', new_file) + if ( + os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0) + ): + logger.info("reusing %s", new_file) return False else: - logger.info('generating %s', new_file) + logger.info("generating %s", new_file) return True def gen_sum_file(self, file, nsys_cmd): - """ - generates sum file from nsys trace with times per kernel and - returns the name of the sum file + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file """ import subprocess + file_dir = os.path.dirname(file) file_name = 
os.path.basename(file) if not file_dir: - file_dir = '.' + file_dir = "." # Walk through trace and get the total non-overlapped time - nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv' - sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv' + nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv" + sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv" if self.should_gen_file(nsys_stats_file, file): cmd = [ - nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', - f'{file_dir}/{file_name}' + nsys_cmd, + "stats", + "-r", + "cuda_gpu_trace", + file, + "-o", + f"{file_dir}/{file_name}", ] - cmd_str = ' '.join(cmd) - logger.info('+ %s', cmd_str) + cmd_str = " ".join(cmd) + logger.info("+ %s", cmd_str) # estimate time based on calibrated 240M/min file_size_mb = os.path.getsize(file) / 1e6 logger.info( - 'nsys stats for %.2f MB file expected to take %.2f min', - file_size_mb, file_size_mb / 240) + "nsys stats for %.2f MB file expected to take %.2f min", + file_size_mb, + file_size_mb / 240, + ) try: subprocess.run(cmd, check=True) except Exception: - logger.error("%s failed; Use --nsys_cmd to specify nsys path", - cmd_str) + logger.error("%s failed; Use --nsys_cmd to specify nsys path", cmd_str) exit(1) - logger.info('generating non-overalapped sum %s', sum_file) + logger.info("generating non-overalapped sum %s", sum_file) self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) self.is_valid_file(sum_file) - logger.info('Finished generating %s', sum_file) + logger.info("Finished generating %s", sum_file) return sum_file def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): - """ generates graph and csv file from in_file into out_dir """ + """generates graph and csv file from in_file into out_dir""" # Initialize an empty DataFrame to store combined data combined_df = self.pd.DataFrame() for idx, (file, engine, model, total_sec) in enumerate(in_file): file_dir = os.path.dirname(file) file_name = os.path.basename(file) if not file_dir: - file_dir = '.' + file_dir = "." sum_file = self.gen_sum_file(file, nsys_cmd) # read kernel summary file df = self.pd.read_csv(sum_file) # annotate kernel to their categories - assert engine_model.get(engine), f'engine {engine} unknown' - assert engine_model[engine].get(model), f'model {model} unknown' + assert engine_model.get(engine), f"engine {engine} unknown" + assert engine_model[engine].get(model), f"model {model} unknown" # remove nsys-rep from file_name for shorter x-label - file_name = file_name.replace('.nsys-rep', '') - df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}' + file_name = file_name.replace(".nsys-rep", "") + df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}" self.anno_gpu_kernname(df, engine_model[engine][model]) # patch in non-gpu time - gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1) + gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1) total_sec = round(float(total_sec), 1) if total_sec < gpu_sec: logger.warning( @@ -256,7 +278,7 @@ class GPUTrace2Graph: df = self.pd.concat([df, nongpu_row], ignore_index=True) combined_df = self.pd.concat([combined_df, df], ignore_index=True) if out_dir is None: - out_dir = '.' + out_dir = "." 
else: os.makedirs(out_dir, exist_ok=True) # generate html file @@ -264,50 +286,59 @@ class GPUTrace2Graph: def parse_tuple(s): - return tuple(s.split(',')) + return tuple(s.split(",")) def main(): - logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'), - level=logging.INFO) + logging.basicConfig( + format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO + ) parser = argparse.ArgumentParser( description=( - 'Process nsys rep and generate kernel non-overlapped cycles. \n' - 'Example:\n' + "Process nsys rep and generate kernel non-overlapped cycles. \n" + "Example:\n" "gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n" "d2.nsys-rep,vllm,gpt-oss,102 " - "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""), - formatter_class=argparse.RawDescriptionHelpFormatter) + '--out_dir results/ --title "Model=gpt-oss vLLM chart"' + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) # load supported engine_model engine_model_supported = load_engine_model() # Get a string representation of supported engine/model combinations - engine_model_supported_str = ', '.join( + engine_model_supported_str = ", ".join( f"{engine}:[{', '.join(models.keys())}]" - for engine, models in engine_model_supported.items()) + for engine, models in engine_model_supported.items() + ) parser.add_argument( - '--in_file', + "--in_file", type=parse_tuple, - nargs='+', + nargs="+", help=( - 'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) ' - 'separated by space. Elapsed_nonprofiled_sec is runtime without ' - 'profiling used to calculate non-gpu time. Specify 0 to use ' - 'elapsed time from nsys-rep but that might inflate non-gpu time. ' - f'Available engine:[model] are: {engine_model_supported_str} ' - f'Example: --infile d1.nsys-rep,vllm,llama,100 ' - 'd2.nsys-rep,vllm,gpt-oss,102'), - required=True) - parser.add_argument('--out_dir', help=('output dir for result.csv/html')) - parser.add_argument('--title', help=('title for html chart')) - parser.add_argument('--nsys_cmd', - help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'), - default="nsys") + "list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) " + "separated by space. Elapsed_nonprofiled_sec is runtime without " + "profiling used to calculate non-gpu time. Specify 0 to use " + "elapsed time from nsys-rep but that might inflate non-gpu time. " + f"Available engine:[model] are: {engine_model_supported_str} " + f"Example: --infile d1.nsys-rep,vllm,llama,100 " + "d2.nsys-rep,vllm,gpt-oss,102" + ), + required=True, + ) + parser.add_argument("--out_dir", help=("output dir for result.csv/html")) + parser.add_argument("--title", help=("title for html chart")) + parser.add_argument( + "--nsys_cmd", + help=("nsys cmd, e.g. 
/usr/bin/nsys, Default: nsys"), + default="nsys", + ) args = parser.parse_args() gputrace = GPUTrace2Graph() - gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, - engine_model_supported) + gputrace.gen_graph( + args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index 209c3a576aeed..d7a24a598593d 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -29,48 +29,50 @@ def flatten_entries(entry_cls, profile_dict: dict): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by " - "examples/offline_inference/profiling.py") - parser.add_argument("--phase", - type=str, - required=True, - help="The phase to print the table for. This is either" - "prefill or decode_n, where n is the decode step " - "number") - parser.add_argument("--table", - type=str, - choices=["summary", "model"], - default="summary", - help="Which table to print, the summary table or the " - "layerwise model table") + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--phase", + type=str, + required=True, + help="The phase to print the table for. This is either" + "prefill or decode_n, where n is the decode step " + "number", + ) + parser.add_argument( + "--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the layerwise model table", + ) args = parser.parse_args() with open(args.json_trace) as f: profile_data = json.load(f) - assert args.phase in profile_data, \ - (f"Cannot find phase {args.phase} in profile data. Choose one among" - f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa + assert args.phase in profile_data, ( + f"Cannot find phase {args.phase} in profile data. 
Choose one among" + f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}" + ) # noqa if args.table == "summary": entries_and_depths = flatten_entries( - SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) - column_widths = dict(name=80, - cuda_time_us=12, - pct_cuda_time=12, - invocations=15) + SummaryStatsEntry, profile_data[args.phase]["summary_stats"] + ) + column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15) elif args.table == "model": entries_and_depths = flatten_entries( - ModelStatsEntry, profile_data[args.phase]["model_stats"]) - column_widths = dict(name=60, - cpu_time_us=12, - cuda_time_us=12, - pct_cuda_time=12, - trace=60) + ModelStatsEntry, profile_data[args.phase]["model_stats"] + ) + column_widths = dict( + name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 + ) # indent entry names based on the depth entries = [] @@ -78,7 +80,8 @@ if __name__ == "__main__": entry.name = indent_string( entry.name, indent=depth, - indent_style=lambda indent: "|" + "-" * indent + " ") + indent_style=lambda indent: "|" + "-" * indent + " ", + ) entries.append(entry) TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index 038d3c44f043a..cdab004366f9d 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -18,17 +18,18 @@ import pandas as pd def largest_dist_from_leaf(node: dict, depth: int = 0): if len(node["children"]) == 0: return depth - return max([ - largest_dist_from_leaf(child, depth=depth + 1) - for child in node["children"] - ]) + return max( + [largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]] + ) -def get_entries_at_depth(depth: int, - entries_and_traces: list[tuple[Any, Any]], - node: dict, - curr_depth: int = 0, - trace=()): +def get_entries_at_depth( + depth: int, + entries_and_traces: list[tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=(), +): # assert that the query is at kernel or module level assert depth == -1 or depth == -2 @@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int, if largest_dist_from_leaf(node) == (abs(depth) - 1): entries_and_traces.append((node["entry"], trace)) - trace = (node["entry"]["name"], ) + trace + trace = (node["entry"]["name"],) + trace for child in node["children"]: - get_entries_at_depth(depth, - entries_and_traces, - child, - curr_depth=curr_depth + 1, - trace=trace) + get_entries_at_depth( + depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace + ) def fold_nodes(root: dict, nodes_to_fold: list[str]): - stack: list[dict] = [root] while len(stack) != 0: node = stack.pop() - if node['entry']['name'] in nodes_to_fold: + if node["entry"]["name"] in nodes_to_fold: node["children"] = [] continue for child in node["children"]: @@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str: def shorten_plot_legend_strings(legend, max_char_len: int): for t in legend.get_texts(): - t.set_text( - trim_string_back(abbreviate_known_names(t.get_text()), - max_char_len)) + t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len)) def abbreviate_known_names(name: str) -> str: @@ -108,50 +104,54 @@ def attempt_to_make_names_unique(entries_and_traces): names.add(entry["name"]) for name in non_unique_names: - entries_and_traces_with_name = [(entry, trace) - for entry, trace in entries_and_traces - if entry["name"] == name] 
+ entries_and_traces_with_name = [ + (entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name + ] - zipped_traces = list( - zip(*[trace for _, trace in entries_and_traces_with_name])) + zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name])) first_trace_difference = next( - (i for i, trace_eles in enumerate(zipped_traces) - if not all_the_same(trace_eles)), None) + ( + i + for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles) + ), + None, + ) if first_trace_difference is None: - # can't create a unique name, leave them names as the + # can't create a unique name, leave the names as they # are they will get aggregated by the pivot_table call continue for entry, trace in entries_and_traces_with_name: - entry["name"] = " <- ".join((entry["name"], ) + - trace[:first_trace_difference + 1]) + entry["name"] = " <- ".join( + (entry["name"],) + trace[: first_trace_difference + 1] + ) ## Operation grouping utils #### -''' +""" Group operations in the given dataframe by some high-level ops like, - gemms - attention - rms_norm etc. -''' +""" def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: - def is_rms_norm(op_name: str): if "rms_norm_kernel" in op_name: return True def is_attention_block(op_name: str): - if "flash_fwd" in op_name or \ - "reshape_and_cache_flash_kernel" in op_name: + if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name: return True def is_quant(op_name: str): - if "scaled_fp8_quant" in op_name or \ - "scaled_int8_quant" in op_name: + if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name: return True # LoRA ops @@ -168,24 +168,27 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: return "bgmv_expand" in op_name def is_cutlass_gemm_op(op_name: str): - return "void cutlass::Kernel" in op_name or \ - "void cutlass::device_kernel" in op_name + return ( + "void cutlass::Kernel" in op_name + or "void cutlass::device_kernel" in op_name + ) def is_gemm_op(op_name: str): if is_quant(op_name): return False - return is_cutlass_gemm_op(op_name) or \ - "xmma_gemm" in op_name or \ - "gemv2T_kernel" in op_name or \ - "splitKreduce" in op_name or \ - "s16816gemm" in op_name + return ( + is_cutlass_gemm_op(op_name) + or "xmma_gemm" in op_name + or "gemv2T_kernel" in op_name + or "splitKreduce" in op_name + or "s16816gemm" in op_name + ) def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name def is_mem_op(op_name: str): - return "memcpy" in op_name.lower() or \ - "memset" in op_name.lower() + return "memcpy" in op_name.lower() or "memset" in op_name.lower() def is_vocab_embedding_op(op_name: str): return "vocabparallelembed" in op_name.lower() @@ -195,17 +198,15 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: return "nccl" in op_name.lower() def is_nccl_all_reduce(op_name: str): - return is_nccl_op(op_name) and \ - ("all_reduce" in op_name.lower() or \ - "allreduce" in op_name.lower()) + return is_nccl_op(op_name) and ( + "all_reduce" in op_name.lower() or "allreduce" in op_name.lower() + ) def is_nccl_gather(op_name: str): - return is_nccl_op(op_name) and \ - "gather" in op_name.lower() + return is_nccl_op(op_name) and "gather" in op_name.lower() def is_nccl_broadcast(op_name: str): - return is_nccl_op(op_name) and \ - "broadcast" in op_name.lower() + return is_nccl_op(op_name) and "broadcast" in op_name.lower() # Reduce ops types def is_cross_device_reduce_1stage(op_name: str): @@ -269,114 +270,122 @@ 
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: ops = list(filter(lambda x: x not in nccl_other_ops, ops)) cross_device_reduce_1stage_ops = list( - filter(lambda x: is_cross_device_reduce_1stage(x), ops)) + filter(lambda x: is_cross_device_reduce_1stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops)) cross_device_reduce_2stage_ops = list( - filter(lambda x: is_cross_device_reduce_2stage(x), ops)) + filter(lambda x: is_cross_device_reduce_2stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) - custom_ar_all_reduce_ops = list( - filter(lambda x: is_custom_ar_all_reduce(x), ops)) + custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops)) ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops)) reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) if len(attention_ops): - trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1) if len(quant_ops): - trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1) if len(sgmv_shrink_ops): - trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum", - axis=1) + trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1) if len(sgmv_expand_ops): - trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum", - axis=1) + trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1) if len(bgmv_shrink_ops): - trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum", - axis=1) + trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1) if len(bgmv_expand_ops): - trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum", - axis=1) + trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1) if len(cutlass_gemm_ops): - trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum", - axis=1) + trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1) if len(gemm_ops): - trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1) if len(rms_norm_ops): - trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1) if len(vocab_embed_ops): - trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", - axis=1) + trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1) if len(mem_ops): - trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1) if len(elementwise_ops): - trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", - axis=1) + trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1) if len(nccl_all_reduce_ops): - trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg( - "sum", axis=1) + trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg( + "sum", axis=1 + ) if len(nccl_gather_ops): - trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum", - axis=1) + trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1) if len(nccl_broadcast_ops): - trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg( - "sum", axis=1) + trace_df["nccl_broadcast_ops"] = 
trace_df[nccl_broadcast_ops].agg("sum", axis=1) if len(nccl_other_ops): - trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum", - axis=1) + trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1) if len(cross_device_reduce_1stage_ops): - trace_df['cross_device_reduce_1stage_ops'] = trace_df[ - cross_device_reduce_1stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_1stage_ops"] = trace_df[ + cross_device_reduce_1stage_ops + ].agg("sum", axis=1) if len(cross_device_reduce_2stage_ops): - trace_df['cross_device_reduce_2stage_ops'] = trace_df[ - cross_device_reduce_2stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_2stage_ops"] = trace_df[ + cross_device_reduce_2stage_ops + ].agg("sum", axis=1) if len(custom_ar_all_reduce_ops): - trace_df['custom_ar_all_reduce_ops'] = trace_df[ - custom_ar_all_reduce_ops].agg("sum", axis=1) + trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg( + "sum", axis=1 + ) if len(reduce_kernel_ops): - trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", - axis=1) + trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1) - trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops + - sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops + - cutlass_gemm_ops + gemm_ops + rms_norm_ops + - vocab_embed_ops + mem_ops + elementwise_ops + - nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + - nccl_other_ops + cross_device_reduce_1stage_ops + - cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops + - reduce_kernel_ops, - axis=1, - inplace=True) + trace_df.drop( + attention_ops + + quant_ops + + sgmv_shrink_ops + + sgmv_expand_ops + + bgmv_shrink_ops + + bgmv_expand_ops + + cutlass_gemm_ops + + gemm_ops + + rms_norm_ops + + vocab_embed_ops + + mem_ops + + elementwise_ops + + nccl_all_reduce_ops + + nccl_gather_ops + + nccl_broadcast_ops + + nccl_other_ops + + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + + custom_ar_all_reduce_ops + + reduce_kernel_ops, + axis=1, + inplace=True, + ) return trace_df ## Data plotting utils #### -def plot_trace_df(traces_df: pd.DataFrame, - plot_metric: str, - plot_title: str, - output: Optional[Path] = None): - +def plot_trace_df( + traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None, +): def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: phase_df = traces_df.query(f'phase == "{phase}"') - descs = phase_df['phase_desc'].to_list() + descs = phase_df["phase_desc"].to_list() assert all([desc == descs[0] for desc in descs]) return descs[0] - phases = traces_df['phase'].unique() + phases = traces_df["phase"].unique() phase_descs = [get_phase_description(traces_df, p) for p in phases] - traces_df = traces_df.pivot_table(index="phase", - columns="name", - values=plot_metric, - aggfunc="sum") + traces_df = traces_df.pivot_table( + index="phase", columns="name", values=plot_metric, aggfunc="sum" + ) traces_df = group_trace_by_operations(traces_df) @@ -396,20 +405,19 @@ def plot_trace_df(traces_df: pd.DataFrame, # Write the values as text on the bars for bar in ax.patches: if bar.get_height() != 0: - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() / 2 + bar.get_y(), - f"{round(bar.get_height(), 2)}", - ha='center', - color='w', - weight='bold', - size=5) + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha="center", + color="w", + weight="bold", + size=5, + ) 
# Setup legend handles, labels = plt.gca().get_legend_handles_labels() - legend = fig.legend(handles, - labels, - loc='center left', - bbox_to_anchor=(1, 1)) + legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1)) shorten_plot_legend_strings(legend, 50) # Setup labels and title @@ -417,21 +425,20 @@ def plot_trace_df(traces_df: pd.DataFrame, ax.set_ylabel(plot_metric) plt.suptitle(plot_title) - plt.savefig(output, bbox_inches='tight') + plt.savefig(output, bbox_inches="tight") print("Created: ", output) def main( - json_trace: Path, - output_directory: Path, - depth: int, # Fetch/Plot operations at this depth of the Json tree - plot_metric: str, - make_names_unique: bool, - top_k: int, - json_nodes_to_fold: list[str]): - + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: list[str], +): def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame: - def get_entries_and_traces(key: str): entries_and_traces: list[tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: @@ -441,16 +448,14 @@ def main( get_entries_at_depth(depth, entries_and_traces, root) return entries_and_traces - def keep_only_top_entries(df: pd.DataFrame, - metric: str, - top_k: int = 9) -> pd.DataFrame: - df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, - ["name"]] = "others" + def keep_only_top_entries( + df: pd.DataFrame, metric: str, top_k: int = 9 + ) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others" return df def get_phase_description(key: str) -> str: - num_running_seqs = profile_json[key]['metadata'][ - 'num_running_seqs'] + num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"] if num_running_seqs is not None: return f"{key}-seqs-{num_running_seqs}" else: @@ -466,20 +471,24 @@ def main( # To pandas dataframe trace_dfs = list( - map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), - traces)) + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces) + ) # Respect top_k if top_k: trace_dfs = list( map( lambda trace_df: keep_only_top_entries( - trace_df, "cuda_time_us", top_k), trace_dfs)) + trace_df, "cuda_time_us", top_k + ), + trace_dfs, + ) + ) # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): - trace_df['phase'] = step_key - trace_df['phase_desc'] = get_phase_description(step_key) + trace_df["phase"] = step_key + trace_df["phase_desc"] = get_phase_description(step_key) # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) @@ -492,17 +501,23 @@ def main( def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] - sparsity = context.get('sparsity', None) - run_type = \ - f'Run {context["num_steps"]} steps' if context['num_steps'] else \ - (f'Complete {context["complete_num_requests_per_step"]} per ' - f'step; Run till completion') - return (f"{context['engine_args']['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['engine_args']['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}\n" - f"Run Type: {run_type}") + sparsity = context.get("sparsity", None) + run_type = ( + f"Run {context['num_steps']} steps" + if context["num_steps"] + else ( + f"Complete {context['complete_num_requests_per_step']} per " + f"step; Run till 
completion" + ) + ) + return ( + f"{context['engine_args']['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['engine_args']['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}\n" + f"Run Type: {run_type}" + ) profile_json = None with open(json_trace) as f: @@ -511,14 +526,14 @@ def main( # Get all `llm.generate.step()` profile step_traces = list(profile_json.keys()) - assert (step_traces[0] == 'context') + assert step_traces[0] == "context" step_traces = step_traces[1:] # have only prefill and decodes prefills = list(filter(lambda x: "prefill" in x, step_traces)) all_decodes = list(filter(lambda x: "decode" in x, step_traces)) assert len(prefills) + len(all_decodes) == len(step_traces) assert len(prefills) == 1 - decodes = all_decodes[::args.step_plot_interval] + decodes = all_decodes[:: args.step_plot_interval] if decodes[-1] != all_decodes[-1]: # Always have the last decode decodes.append(all_decodes[-1]) @@ -528,48 +543,63 @@ def main( plot_title_suffix = make_plot_title_suffix(profile_json) - plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, - output_directory / Path("prefill.png")) - plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, - output_directory / Path("decode_steps.png")) + plot_trace_df( + prefill_traces, + plot_metric, + "prefill " + plot_title_suffix, + output_directory / Path("prefill.png"), + ) + plot_trace_df( + decode_traces, + plot_metric, + "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png"), + ) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by \ - examples/offline_inference/profiling.py") - parser.add_argument("--output-directory", - type=str, - required=False, - help="Directory to output plots") - parser.add_argument("--level", - type=str, - default="module", - choices=["module", "kernel"]) - parser.add_argument("--top-k", - type=int, - default=12, - help="Only graph the top `top_k` entries by time.") - parser.add_argument("--fold-json-node", - nargs='+', - default=['Sampler', 'LogitsProcessor'], - help='Do not plot the children of these nodes. Let, \ + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by \ + examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--output-directory", type=str, required=False, help="Directory to output plots" + ) + parser.add_argument( + "--level", type=str, default="module", choices=["module", "kernel"] + ) + parser.add_argument( + "--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.", + ) + parser.add_argument( + "--fold-json-node", + nargs="+", + default=["Sampler", "LogitsProcessor"], + help="Do not plot the children of these nodes. Let, \ the node represent the aggregate of all its \ - children') - parser.add_argument("--plot-metric", - type=str, - default="cuda_time_ms", - help='Metric to plot. some options are cuda_time_ms, \ - pct_cuda_time') + children", + ) + parser.add_argument( + "--plot-metric", + type=str, + default="cuda_time_ms", + help="Metric to plot. 
some options are cuda_time_ms, \ + pct_cuda_time", + ) parser.add_argument( "--step-plot-interval", type=int, default=4, - help="For every `step_plot_interval` steps, plot 1 step") + help="For every `step_plot_interval` steps, plot 1 step", + ) args = parser.parse_args() @@ -583,11 +613,19 @@ if __name__ == "__main__": else: raise Exception(f"Unexpected level value ({args.level})") - output_directory = args.output_directory if args.output_directory else Path( - args.json_trace).parent + output_directory = ( + args.output_directory if args.output_directory else Path(args.json_trace).parent + ) if not os.path.exists(output_directory): os.makedirs(output_directory) - main(Path(args.json_trace), output_directory, depth, args.plot_metric, - make_names_unique, args.top_k, args.fold_json_node) + main( + Path(args.json_trace), + output_directory, + depth, + args.plot_metric, + make_names_unique, + args.top_k, + args.fold_json_node, + ) diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py index 7386cdd9f7245..fe3f352fe153e 100644 --- a/tools/report_build_time_ninja.py +++ b/tools/report_build_time_ninja.py @@ -83,9 +83,9 @@ class Target: """ # Allow for modest floating-point errors epsilon = 0.000002 - if (self.weighted_duration > self.Duration() + epsilon): - print('{} > {}?'.format(self.weighted_duration, self.Duration())) - assert (self.weighted_duration <= self.Duration() + epsilon) + if self.weighted_duration > self.Duration() + epsilon: + print("{} > {}?".format(self.weighted_duration, self.Duration())) + assert self.weighted_duration <= self.Duration() + epsilon return self.weighted_duration def DescribeTargets(self): @@ -93,10 +93,10 @@ class Target: # Some build steps generate dozens of outputs - handle them sanely. # The max_length was chosen so that it can fit most of the long # single-target names, while minimizing word wrapping. - result = ', '.join(self.targets) + result = ", ".join(self.targets) max_length = 65 if len(result) > max_length: - result = result[:max_length] + '...' + result = result[:max_length] + "..." return result @@ -106,12 +106,13 @@ def ReadTargets(log, show_all): The result is a list of Target objects.""" header = log.readline() - assert header == '# ninja log v5\n', \ - 'unrecognized ninja log version {!r}'.format(header) + assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format( + header + ) targets_dict = {} last_end_seen = 0.0 for line in log: - parts = line.strip().split('\t') + parts = line.strip().split("\t") if len(parts) != 5: # If ninja.exe is rudely halted then the .ninja_log file may be # corrupt. Silently continue. @@ -150,17 +151,17 @@ def ReadTargets(log, show_all): def GetExtension(target, extra_patterns): """Return the file extension that best represents a target. - For targets that generate multiple outputs it is important to return a - consistent 'canonical' extension. Ultimately the goal is to group build steps - by type.""" + For targets that generate multiple outputs it is important to return a + consistent 'canonical' extension. Ultimately the goal is to group build steps + by type.""" for output in target.targets: if extra_patterns: - for fn_pattern in extra_patterns.split(';'): - if fnmatch.fnmatch(output, '*' + fn_pattern + '*'): + for fn_pattern in extra_patterns.split(";"): + if fnmatch.fnmatch(output, "*" + fn_pattern + "*"): return fn_pattern # Not a true extension, but a good grouping. 
- if output.endswith('type_mappings'): - extension = 'type_mappings' + if output.endswith("type_mappings"): + extension = "type_mappings" break # Capture two extensions if present. For example: file.javac.jar should @@ -170,26 +171,26 @@ def GetExtension(target, extra_patterns): extension = ext2 + ext1 # Preserve the order in the file name. if len(extension) == 0: - extension = '(no extension found)' + extension = "(no extension found)" - if ext1 in ['.pdb', '.dll', '.exe']: - extension = 'PEFile (linking)' + if ext1 in [".pdb", ".dll", ".exe"]: + extension = "PEFile (linking)" # Make sure that .dll and .exe are grouped together and that the # .dll.lib files don't cause these to be listed as libraries break - if ext1 in ['.so', '.TOC']: - extension = '.so (linking)' + if ext1 in [".so", ".TOC"]: + extension = ".so (linking)" # Attempt to identify linking, avoid identifying as '.TOC' break # Make sure .obj files don't get categorized as mojo files - if ext1 in ['.obj', '.o']: + if ext1 in [".obj", ".o"]: break # Jars are the canonical output of java targets. - if ext1 == '.jar': + if ext1 == ".jar": break # Normalize all mojo related outputs to 'mojo'. - if output.count('.mojom') > 0: - extension = 'mojo' + if output.count(".mojom") > 0: + extension = "mojo" break return extension @@ -214,8 +215,8 @@ def SummarizeEntries(entries, extra_step_types): if target.end > latest: latest = target.end total_cpu_time += target.Duration() - task_start_stop_times.append((target.start, 'start', target)) - task_start_stop_times.append((target.end, 'stop', target)) + task_start_stop_times.append((target.start, "start", target)) + task_start_stop_times.append((target.end, "stop", target)) length = latest - earliest weighted_total = 0.0 @@ -241,10 +242,10 @@ def SummarizeEntries(entries, extra_step_types): if num_running > 0: # Update the total weighted time up to this moment. last_weighted_time += (time - last_time) / float(num_running) - if action_name == 'start': + if action_name == "start": # Record the total weighted task time when this task starts. running_tasks[target] = last_weighted_time - if action_name == 'stop': + if action_name == "stop": # Record the change in the total weighted task time while this task # ran. weighted_duration = last_weighted_time - running_tasks[target] @@ -252,13 +253,16 @@ def SummarizeEntries(entries, extra_step_types): weighted_total += weighted_duration del running_tasks[target] last_time = time - assert (len(running_tasks) == 0) + assert len(running_tasks) == 0 # Warn if the sum of weighted times is off by more than half a second. if abs(length - weighted_total) > 500: - print('Warning: Possible corrupt ninja log, results may be ' - 'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format( - length, weighted_total)) + print( + "Warning: Possible corrupt ninja log, results may be " + "untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format( + length, weighted_total + ) + ) entries_by_ext = defaultdict(list) for target in entries: @@ -266,32 +270,38 @@ def SummarizeEntries(entries, extra_step_types): entries_by_ext[extension].append(target) for key, values in entries_by_ext.items(): - print(' Longest build steps for {}:'.format(key)) + print(" Longest build steps for {}:".format(key)) values.sort(key=lambda x: x.WeightedDuration()) for target in values[-long_count:]: print( - ' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'. 
- format(target.WeightedDuration(), target.DescribeTargets(), - target.Duration())) + " {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format( + target.WeightedDuration(), + target.DescribeTargets(), + target.Duration(), + ) + ) - print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x ' - 'parallelism)'.format(length, total_cpu_time, - total_cpu_time * 1.0 / length)) - print(' {} build steps completed, average of {:1.2f}/s'.format( - len(entries), - len(entries) / (length))) + print( + " {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x " + "parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length) + ) + print( + " {} build steps completed, average of {:1.2f}/s".format( + len(entries), len(entries) / (length) + ) + ) def main(): - log_file = '.ninja_log' + log_file = ".ninja_log" parser = argparse.ArgumentParser() - parser.add_argument('-C', dest='build_directory', help='Build directory.') + parser.add_argument("-C", dest="build_directory", help="Build directory.") parser.add_argument( - '-s', - '--step-types', - help='semicolon separated fnmatch patterns for build-step grouping') - parser.add_argument('--log-file', - help="specific ninja log file to analyze.") + "-s", + "--step-types", + help="semicolon separated fnmatch patterns for build-step grouping", + ) + parser.add_argument("--log-file", help="specific ninja log file to analyze.") args, _extra_args = parser.parse_known_args() if args.build_directory: log_file = os.path.join(args.build_directory, log_file) @@ -300,17 +310,16 @@ def main(): if args.step_types: # Make room for the extra build types. global long_ext_count - long_ext_count += len(args.step_types.split(';')) + long_ext_count += len(args.step_types.split(";")) try: with open(log_file) as log: entries = ReadTargets(log, False) SummarizeEntries(entries, args.step_types) except OSError: - print('Log file {!r} not found, no build summary created.'.format( - log_file)) + print("Log file {!r} not found, no build summary created.".format(log_file)) return errno.ENOENT -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/tools/validate_config.py b/tools/validate_config.py index 8b1e955c653d7..fb6f0e6a92850 100644 --- a/tools/validate_config.py +++ b/tools/validate_config.py @@ -8,6 +8,9 @@ and that each field has a docstring. import ast import inspect import sys +from itertools import pairwise + +import regex as re def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: @@ -18,28 +21,17 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: https://davidism.com/mit-license/ """ - def pairwise(iterable): - """ - Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - - Can be removed when Python 3.9 support is dropped. - """ - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b - out = {} # Consider each pair of nodes. for a, b in pairwise(cls_node.body): # Must be an assignment then a constant string. 
- if (not isinstance(a, (ast.Assign, ast.AnnAssign)) - or not isinstance(b, ast.Expr) - or not isinstance(b.value, ast.Constant) - or not isinstance(b.value.value, str)): + if ( + not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str) + ): continue doc = inspect.cleandoc(b.value.value) @@ -59,25 +51,27 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: class ConfigValidator(ast.NodeVisitor): - - def __init__(self): - ... + def __init__(self): ... def visit_ClassDef(self, node): # Validate class with both @config and @dataclass decorators decorators = [ - id for d in node.decorator_list if (isinstance(d, ast.Name) and ( - (id := d.id) == 'config' or id == 'dataclass')) or - (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and - (id := d.func.id) == 'dataclass')) + id + for d in node.decorator_list + if ( + isinstance(d, ast.Name) + and ((id := d.id) == "config" or id == "dataclass") + ) + or ( + isinstance(d, ast.Call) + and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass") + ) ] - if set(decorators) == {'config', 'dataclass'}: + if set(decorators) == {"config", "dataclass"}: validate_class(node) - elif set(decorators) == {'config'}: - fail( - f"Class {node.name} with config decorator must be a dataclass.", - node) + elif set(decorators) == {"config"}: + fail(f"Class {node.name} with config decorator must be a dataclass.", node) self.generic_visit(node) @@ -88,11 +82,14 @@ def validate_class(class_node: ast.ClassDef): for stmt in class_node.body: # A field is defined as a class variable that has a type annotation. if isinstance(stmt, ast.AnnAssign): - # Skip ClassVar + # Skip ClassVar and InitVar # see https://docs.python.org/3/library/dataclasses.html#class-variables - if isinstance(stmt.annotation, ast.Subscript) and isinstance( - stmt.annotation.value, - ast.Name) and stmt.annotation.value.id == "ClassVar": + # and https://docs.python.org/3/library/dataclasses.html#init-only-variables + if ( + isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id in {"ClassVar", "InitVar"} + ): continue if isinstance(stmt.target, ast.Name): @@ -100,22 +97,30 @@ def validate_class(class_node: ast.ClassDef): if stmt.value is None: fail( f"Field '{field_name}' in {class_node.name} must have " - "a default value.", stmt) + "a default value.", + stmt, + ) if field_name not in attr_docs: fail( f"Field '{field_name}' in {class_node.name} must have " - "a docstring.", stmt) + "a docstring.", + stmt, + ) - if isinstance(stmt.annotation, ast.Subscript) and \ - isinstance(stmt.annotation.value, ast.Name) \ - and stmt.annotation.value.id == "Union" and \ - isinstance(stmt.annotation.slice, ast.Tuple): + if ( + isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id == "Union" + and isinstance(stmt.annotation.slice, ast.Tuple) + ): args = stmt.annotation.slice.elts literal_args = [ - arg for arg in args - if isinstance(arg, ast.Subscript) and isinstance( - arg.value, ast.Name) and arg.value.id == "Literal" + arg + for arg in args + if isinstance(arg, ast.Subscript) + and isinstance(arg.value, ast.Name) + and arg.value.id == "Literal" ] if len(literal_args) > 1: fail( @@ -123,7 +128,9 @@ def validate_class(class_node: ast.ClassDef): "use a single " "Literal type. 
Please use 'Literal[Literal1, " "Literal2]' instead of 'Union[Literal1, Literal2]'" - ".", stmt) + ".", + stmt, + ) def validate_ast(tree: ast.stmt): @@ -132,7 +139,7 @@ def validate_ast(tree: ast.stmt): def validate_file(file_path: str): try: - print(f"validating {file_path} config dataclasses ", end="") + print(f"Validating {file_path} config dataclasses ", end="") with open(file_path, encoding="utf-8") as f: source = f.read() @@ -140,7 +147,7 @@ def validate_file(file_path: str): validate_ast(tree) except ValueError as e: print(e) - SystemExit(2) + raise SystemExit(1) from e else: print("✅") @@ -151,7 +158,13 @@ def fail(message: str, node: ast.stmt): def main(): for filename in sys.argv[1:]: - validate_file(filename) + # Only run for Python files in vllm/ or tests/ + if not re.match(r"^(vllm|tests)/.*\.py$", filename): + continue + # Only run if the file contains @config + with open(filename, encoding="utf-8") as f: + if "@config" in f.read(): + validate_file(filename) if __name__ == "__main__": diff --git a/use_existing_torch.py b/use_existing_torch.py index a9f79e16981c4..fd4caa69ec9c1 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -3,7 +3,7 @@ import glob -requires_files = glob.glob('requirements/*.txt') +requires_files = glob.glob("requirements/*.txt") requires_files += ["pyproject.toml"] for file in requires_files: print(f">>> cleaning {file}") @@ -11,9 +11,9 @@ for file in requires_files: lines = f.readlines() if "torch" in "".join(lines).lower(): print("removed:") - with open(file, 'w') as f: + with open(file, "w") as f: for line in lines: - if 'torch' not in line.lower(): + if "torch" not in line.lower(): f.write(line) else: print(line.strip()) diff --git a/vllm/__init__.py b/vllm/__init__.py index 7b90fd3a241bd..b9c868de68868 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -14,6 +14,8 @@ import typing import vllm.env_override # noqa: F401 MODULE_ATTRS = { + "bc_linter_skip": "._bc_linter:bc_linter_skip", + "bc_linter_include": "._bc_linter:bc_linter_include", "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs", "EngineArgs": ".engine.arg_utils:EngineArgs", "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", @@ -46,14 +48,22 @@ if typing.TYPE_CHECKING: from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry - from vllm.outputs import (ClassificationOutput, - ClassificationRequestOutput, CompletionOutput, - EmbeddingOutput, EmbeddingRequestOutput, - PoolingOutput, PoolingRequestOutput, - RequestOutput, ScoringOutput, - ScoringRequestOutput) + from vllm.outputs import ( + ClassificationOutput, + ClassificationRequestOutput, + CompletionOutput, + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, + ScoringOutput, + ScoringRequestOutput, + ) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + + from ._bc_linter import bc_linter_include, bc_linter_skip else: def __getattr__(name: str) -> typing.Any: @@ -64,12 +74,13 @@ else: module = import_module(module_name, __package__) return getattr(module, attr_name) else: - raise AttributeError( - f'module {__package__} has no attribute {name}') + raise AttributeError(f"module {__package__} has no attribute {name}") __all__ = [ "__version__", + "bc_linter_skip", + "bc_linter_include", "__version_tuple__", "LLM", "ModelRegistry", diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py new file mode 
100644 index 0000000000000..af68396af0b5a --- /dev/null +++ b/vllm/_bc_linter.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# vllm/_bc_linter.py +from __future__ import annotations + +from typing import Any, Callable, TypeVar, overload + +T = TypeVar("T") + + +@overload +def bc_linter_skip(obj: T) -> T: ... + + +@overload +def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ... + + +def bc_linter_skip(obj: Any = None, *, reason: str | None = None): + """ + No-op decorator to mark symbols/files for BC-linter suppression. + + Usage: + @bc_linter_skip + def legacy_api(...): ... + """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +@overload +def bc_linter_include(obj: T) -> T: ... + + +@overload +def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ... + + +def bc_linter_include(obj: Any = None, *, reason: str | None = None): + """ + Usage: + @bc_linter_include + def public_api(...): ... + """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +__all__ = ["bc_linter_skip", "bc_linter_include"] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0043456e0009a..eac0a5009e81f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import torch @@ -13,16 +12,7 @@ from vllm.scalar_type import ScalarType logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_xpu(): - try: - import vllm._C - except ImportError as e: - logger.warning("Failed to import from vllm._C with %r", e) - -supports_moe_ops = False -with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - supports_moe_ops = True +current_platform.import_kernels() if TYPE_CHECKING: @@ -58,11 +48,26 @@ def paged_attention_v1( blocksparse_head_sliding_step: int = 0, ) -> None: torch.ops._C.paged_attention_v1( - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step) + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) def paged_attention_v2( @@ -90,11 +95,29 @@ def paged_attention_v2( blocksparse_head_sliding_step: int = 0, ) -> None: torch.ops._C.paged_attention_v2( - out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + 
blocksparse_head_sliding_step, + ) def paged_attention_rocm( @@ -117,13 +140,30 @@ def paged_attention_rocm( k_scale: torch.Tensor, v_scale: torch.Tensor, fp8_out_scale: Optional[torch.Tensor] = None, + mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16", ) -> None: - torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, seq_lens, - query_start_loc, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, - v_scale, fp8_out_scale) + torch.ops._rocm_C.paged_attention( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + query_start_loc, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + fp8_out_scale, + mfma_type, + ) def mla_decode_kvcache_cpu( @@ -134,19 +174,23 @@ def mla_decode_kvcache_cpu( block_tables: torch.Tensor, seq_lens: torch.Tensor, ) -> None: - torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale, - block_tables, seq_lens) + torch.ops._C_cpu.mla_decode_kvcache( + out, query, kv_cache, scale, block_tables, seq_lens + ) # merge attn states ops -def merge_attn_states(output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None) -> None: - torch.ops._C.merge_attn_states(output, output_lse, prefix_output, - prefix_lse, suffix_output, suffix_lse) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + torch.ops._C.merge_attn_states( + output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse + ) def convert_vertical_slash_indexes( @@ -165,33 +209,43 @@ def convert_vertical_slash_indexes( nnz_vertical = vertical_indexes.size(2) num_rows = (context_size + block_size_M - 1) // block_size_M - block_count = torch.zeros(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - block_offset = torch.zeros(batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_count = torch.zeros(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_index = torch.zeros(batch_size, - num_heads, - num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device) + block_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) torch.ops._C.convert_vertical_slash_indexes( - block_count, block_offset, column_count, column_index, q_seqlens, - kv_seqlens, vertical_indexes, slash_indexes, context_size, - block_size_M, block_size_N, causal) + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + context_size, + block_size_M, + block_size_N, + causal, + ) return block_count, block_offset, column_count, column_index @@ -214,33 
+268,45 @@ def convert_vertical_slash_indexes_mergehead( nnz_vertical = vertical_indexes.size(2) num_rows = (context_size + block_size_M - 1) // block_size_M - block_count = torch.empty(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - block_offset = torch.empty(batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_count = torch.empty(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_index = torch.empty(batch_size, - num_heads, - num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device) + block_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) torch.ops._C.convert_vertical_slash_indexes_mergehead( - block_count, block_offset, column_count, column_index, q_seqlens, - kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, - slash_indices_count, context_size, block_size_M, block_size_N, causal) + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + vertical_indices_count, + slash_indices_count, + context_size, + block_size_M, + block_size_N, + causal, + ) return block_count, block_offset, column_count, column_index @@ -253,56 +319,71 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - torch.ops._C.rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox) - - -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor], head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) + torch.ops._C.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) # layer norm ops -def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: +def rms_norm( + out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input input_contiguous = input.contiguous() torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) -def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: +def fused_add_rms_norm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def poly_norm( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + epsilon: float, +) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) + + def apply_repetition_penalties_torch( - logits: torch.Tensor, prompt_mask: torch.Tensor, 
- output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, logits.size(1)) + 1, logits.size(1) + ) # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, - 1.0) + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0) # If logits are positive, divide by penalty, otherwise multiply by penalty. scaling = torch.where(logits > 0, 1.0 / penalties, penalties) logits *= scaling def apply_repetition_penalties_cuda( - logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: - torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask, - repetition_penalties) + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: + torch.ops._C.apply_repetition_penalties_( + logits, prompt_mask, output_mask, repetition_penalties + ) -def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor) -> None: +def apply_repetition_penalties( + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: """Apply repetition penalties to logits in-place. Args: @@ -312,11 +393,13 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties: The repetition penalties of shape (num_seqs, ). """ if logits.is_cuda and logits.is_contiguous(): - apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_cuda( + logits, prompt_mask, output_mask, repetition_penalties + ) else: - apply_repetition_penalties_torch(logits, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits, prompt_mask, output_mask, repetition_penalties + ) # fused quant layer norm ops @@ -326,128 +409,172 @@ def rms_norm_dynamic_per_token_quant( epsilon: float, quant_dtype: torch.dtype, scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None + residual: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=quant_dtype) - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) - torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, - scales, epsilon, scale_ub, - residual) + torch.ops._C.rms_norm_dynamic_per_token_quant( + output, input, weight, scales, epsilon, scale_ub, residual + ) return output, scales # quantization ops # awq -def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: +def awq_dequantize( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: int, + thx: int, + thy: int, +) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import ( - awq_dequantize_triton) + awq_dequantize_triton, + ) + return awq_dequantize_triton(qweight, scales, zeros) - 
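# Illustrative sketch (not part of the patch): the repetition-penalty rule used by
# apply_repetition_penalties_torch above, on a toy batch. Tokens seen in the prompt or
# output get their logit divided by the penalty when positive and multiplied when
# negative; unseen tokens are left untouched.
import torch

logits = torch.tensor([[2.0, -2.0, 1.0]])
seen = torch.tensor([[True, True, False]])            # prompt_mask | output_mask
penalty = torch.tensor([1.5]).unsqueeze(1).repeat(1, 3)
penalties = torch.where(seen, penalty, 1.0)
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
print(logits * scaling)                               # tensor([[ 1.3333, -3.0000,  1.0000]])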
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, - thx, thy) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: +def awq_gemm( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_gemm_triton) + from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, - bit: int) -> torch.Tensor: - return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, +) -> torch.Tensor: + return torch.ops._C.gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + ) if hasattr(torch.ops._C, "gptq_gemm"): @register_fake("_C::gptq_gemm") - def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, - use_exllama: bool, bit: int) -> torch.Tensor: - return torch.empty((a.size(0), b_q_weight.size(1)), - dtype=a.dtype, - device=a.device) + def _gptq_gemm_fake( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, + ) -> torch.Tensor: + return torch.empty( + (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device + ) -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, - bit: int) -> None: +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # marlin_24 -def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, b_q_type.id, size_m, - size_n, size_k) +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @register_fake("_C::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: + def _gptq_marlin_24_gemm_fake( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + 
b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + ) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake(a: torch.Tensor, - c: Optional[torch.Tensor], - b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - b_q_type_id: int, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + def _gptq_marlin_gemm_fake( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, + ) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::awq_dequantize") - def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: torch.SymInt, - thx: int, thy: int) -> torch.Tensor: + def _awq_dequantize_fake( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: torch.SymInt, + thx: int, + thy: int, + ) -> torch.Tensor: in_c = qweight.size(0) qout_c = qweight.size(1) out_c = qout_c * 8 - return torch.empty((in_c, out_c), - dtype=scales.dtype, - device=scales.device) + return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) @register_fake("_C::awq_gemm") - def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, - split_k_iters: torch.SymInt) -> torch.Tensor: + def _awq_gemm_fake( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: torch.SymInt, + ) -> torch.Tensor: num_in_feats = input.size(0) - return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device).sum(0) + return torch.empty( + (split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device, + ).sum(0) @register_fake("_C::machete_mm") def machete_mm_fake( @@ -469,24 +596,55 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @register_fake("_C::machete_prepack_B") def machete_prepack_B_fake( - b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - return torch.empty_like(b_q_weight, - memory_format=torch.contiguous_format) + b_q_weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + ) -> torch.Tensor: + return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + 
a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @register_fake("_C::allspark_w8a16_gemm") - def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - n: torch.SymInt, group_size: torch.SymInt, - sm_count: torch.SymInt, - sm_version: torch.SymInt, - CUBLAS_M_THRESHOLD: torch.SymInt, - has_zp: bool, - n32k16_reorder: bool) -> torch.Tensor: + def _allspark_w8a16_gemm_fake( + a: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, + group_size: torch.SymInt, + sm_count: torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, + n32k16_reorder: bool, + ) -> torch.Tensor: m = a.size(0) return torch.empty((m, n), device=a.device, dtype=a.dtype) @@ -495,11 +653,12 @@ if hasattr(torch.ops._C, "ggml_dequantize"): @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake( - W: torch.Tensor, - quant_type: int, - m: torch.SymInt, - n: torch.SymInt, - dtype: Optional[torch.dtype] = None) -> torch.Tensor: + W: torch.Tensor, + quant_type: int, + m: torch.SymInt, + n: torch.SymInt, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) @register_fake("_C::ggml_mul_mat_vec_a8") @@ -534,9 +693,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"): tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) - return torch.empty((tokens * top_k, row), - dtype=torch.float16, - device=W.device) + return torch.empty((tokens * top_k, row), dtype=torch.float16, device=W.device) if hasattr(torch.ops._C, "ggml_moe_a8_vec"): @@ -552,9 +709,7 @@ if hasattr(torch.ops._C, "ggml_moe_a8_vec"): tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) - return torch.empty((tokens * top_k, row), - dtype=X.dtype, - device=W.device) + return torch.empty((tokens * top_k, row), dtype=X.dtype, device=W.device) # cutlass @@ -571,20 +726,23 @@ def cutlass_blockwise_scaled_grouped_mm( problem_sizes: torch.Tensor, expert_offsets: torch.Tensor, ): - torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, - scales_b, problem_sizes, - expert_offsets) + torch.ops._C.cutlass_blockwise_scaled_grouped_mm( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets + ) -def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, alpha: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: +def cutlass_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 m, n = a.shape[0], b.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - 
torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, - alpha) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha) return out @@ -593,16 +751,17 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_block_fp8( - cuda_device_capability) + return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ `cutlass_scaled_mm` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` @@ -625,69 +784,65 @@ def cutlass_scaled_mm(a: torch.Tensor, scale_a.shape * [1, 128] == a.shape scale_b.shape * [128, 128] == b.shape """ - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.numel( - ) == b.shape[1] and bias.dtype == out_dtype + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # Massage the input to be 2D target_shape = (*a.shape[:-1], b.shape[1]) a = a.view(-1, a.shape[-1]) - cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 if current_platform.is_rocm() or not cutlass_compatible_b: from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa - triton_scaled_mm) + triton_scaled_mm, + ) + out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) else: - out = torch.empty((a.shape[0], b.shape[1]), - dtype=out_dtype, - device=a.device) + out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) return out.view(*target_shape) -def cutlass_scaled_mm_azp(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ :param azp_adj: In the per-tensor case, this should include the azp. Always per-channel. :param azp: Only set in the per-token case. Per-token if set. 
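# Illustrative sketch (not part of the patch): the unfused reference computation that the
# cutlass_scaled_mm docstring above describes, written in plain PyTorch with hypothetical
# per-tensor scales and toy shapes; the kernel fuses this into a single pass.
import torch

a = torch.randint(-8, 8, (4, 16), dtype=torch.int8)   # stand-in for quantized A
b = torch.randint(-8, 8, (16, 8), dtype=torch.int8)   # stand-in for quantized B
scale_a = torch.tensor(0.05)
scale_b = torch.tensor(0.10)
reference = torch.mm(scale_a * a.float(), scale_b * b.float()).to(torch.bfloat16)
print(reference.shape)                                 # torch.Size([4, 8])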
""" - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.numel( - ) == b.shape[1] and bias.dtype == out_dtype + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # Massage the input to be 2D target_shape = (*a.shape[:-1], b.shape[1]) a = a.view(-1, a.shape[-1]) assert azp is None or azp.numel() == a.shape[0] - out = torch.empty((a.shape[0], b.shape[1]), - dtype=out_dtype, - device=a.device) - torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, - azp, bias) + out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) return out.view(*target_shape) def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_sparse_scaled_mm_supported( - cuda_device_capability) + return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) -def cutlass_sparse_compress(a: torch.Tensor) \ - -> tuple[torch.Tensor, torch.Tensor]: + +def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Compresses a sparse matrix for use with Cutlass sparse operations. @@ -718,26 +873,25 @@ def cutlass_sparse_compress(a: torch.Tensor) \ - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. """ - assert (a.dtype in [ - torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16 - ]) - assert (a.is_contiguous()) + assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16] + assert a.is_contiguous() # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 elemsPerMetaElem = 4 - assert (a.shape[1] % (2 * elemsPerMetaElem) == 0) + assert a.shape[1] % (2 * elemsPerMetaElem) == 0 return torch.ops._C.cutlass_sparse_compress(a) def cutlass_scaled_sparse_mm( - a: torch.Tensor, - bt_nzs: torch.Tensor, - bt_meta: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + a: torch.Tensor, + bt_nzs: torch.Tensor, + bt_meta: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ Performs a scaled sparse matrix multiplication using Cutlass. @@ -761,31 +915,33 @@ def cutlass_scaled_sparse_mm( Returns: - The result of the scaled sparse matrix multiplication. 
""" - assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == bt_nzs.shape[0] \ - and bias.dtype == out_dtype + assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype m = a.shape[0] n = bt_nzs.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, - scale_b, bias) + torch.ops._C.cutlass_scaled_sparse_mm( + out, a, bt_nzs, bt_meta, scale_a, scale_b, bias + ) return out -def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, - expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - input_permutation: torch.Tensor, - output_permutation: torch.Tensor, - num_experts: int, - n: int, - k: int, - blockscale_offsets: Optional[torch.Tensor] = None): +def get_cutlass_moe_mm_data( + topk_ids: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, + output_permutation: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None, +): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. @@ -809,22 +965,29 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, computed with expert E is blockscale_offsets[E + 1] - blockscale_offsets[E] """ - return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, - input_permutation, - output_permutation, - num_experts, n, k, - blockscale_offsets) + return torch.ops._C.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k, + blockscale_offsets, + ) def get_cutlass_moe_mm_problem_sizes( - topk_ids: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - num_experts: int, - n: int, - k: int, - blockscale_offsets: Optional[torch.Tensor] = None): + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None, +): """ Compute only the per-expert problem sizes needed by the two grouped matrix multiplications used in CUTLASS-based fused MoE. @@ -835,8 +998,8 @@ def get_cutlass_moe_mm_problem_sizes( used in the fused MoE operation. """ return torch.ops._C.get_cutlass_moe_mm_problem_sizes( - topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, - blockscale_offsets) + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets + ) def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): @@ -845,25 +1008,31 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): This is used in MoE to permute the input tensor before performing grouped matrix multiplications. 
""" num_tokens_permuted = dst2src_map.shape[0] - output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]), - device=input_tensor.device, - dtype=input_tensor.dtype) + output_tensor = torch.empty( + (num_tokens_permuted, input_tensor.shape[1]), + device=input_tensor.device, + dtype=input_tensor.dtype, + ) torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) return output_tensor -def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - expert_num_tokens: torch.Tensor, - num_local_experts: int, padded_m: int, n: int, - k: int): +def get_cutlass_pplx_moe_mm_data( + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + expert_num_tokens: torch.Tensor, + num_local_experts: int, + padded_m: int, + n: int, + k: int, +): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. The function takes in expert_num_tokens (token count per expert) and - non_zero_expert_idxs (consecutive indices of experts with non-zero token + non_zero_expert_idxs (consecutive indices of experts with non-zero token counts) and uses them to compute: - expert_offsets: Indices that mark at which token index each expert begins its computation. @@ -872,16 +1041,31 @@ def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, the fused MoE operation. """ return torch.ops._C.get_cutlass_pplx_moe_mm_data( - expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, - num_local_experts, padded_m, n, k) + expert_offsets, + problem_sizes1, + problem_sizes2, + expert_num_tokens, + num_local_experts, + padded_m, + n, + k, + ) -def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, expert_offsets: torch.Tensor, - problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor, - per_act_token: bool, per_out_ch: bool): +def cutlass_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, + a_strides: torch.Tensor, + b_strides: torch.Tensor, + c_strides: torch.Tensor, + per_act_token: bool, + per_out_ch: bool, +): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. @@ -893,17 +1077,33 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
""" - return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, - a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch) + return torch.ops._C.cutlass_moe_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + c_strides, + per_act_token, + per_out_ch, + ) -def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, alphas: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, sf_offsets: torch.Tensor): +def cutlass_fp4_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + alphas: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + sf_offsets: torch.Tensor, +): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. @@ -920,123 +1120,216 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. """ - return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, - a_scales, b_scales, alphas, - problem_sizes, expert_offsets, - sf_offsets) + return torch.ops._C.cutlass_fp4_group_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + alphas, + problem_sizes, + expert_offsets, + sf_offsets, + ) # gptq_marlin -def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) # gptq_marlin -def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def gptq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) for e in range(num_experts): - output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], - size_k, size_n, num_bits) + output[e] = torch.ops._C.gptq_marlin_repack( + b_q_weight[e], perm[e], size_k, size_n, num_bits + ) return output -def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def awq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: 
num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) for e in range(num_experts): - output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, - size_n, num_bits) + output[e] = torch.ops._C.awq_marlin_repack( + b_q_weight[e], size_k, size_n, num_bits + ) return output -def gptq_marlin_gemm(a: torch.Tensor, - c: Optional[torch.Tensor], - b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, - global_scale, b_zeros, g_idx, perm, - workspace, b_q_type.id, size_m, - size_n, size_k, is_k_full, - use_atomic_add, use_fp32_reduce, - is_zp_float) +def gptq_marlin_gemm( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm( + a, + c, + b_q_weight, + b_bias, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) # machete def machete_supported_schedules( - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype], - group_zeros_type: Optional[torch.dtype] = None, - channel_scales_type: Optional[torch.dtype] = None, - token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None) -> list[str]: + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None, +) -> list[str]: return torch.ops._C.machete_supported_schedules( - a_type, b_type.id, group_scales_type, group_zeros_type, - channel_scales_type, token_scales_type, out_type) + a_type, + b_type.id, + group_scales_type, + group_zeros_type, + channel_scales_type, + token_scales_type, + out_type, + ) def machete_mm( - a: torch.Tensor, - # b_q Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - out_type: Optional[torch.dtype] = None, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - schedule: Optional[str] = None) -> torch.Tensor: - return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, 
b_group_scales, - b_group_zeros, b_group_size, - b_channel_scales, a_token_scales, schedule) + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_mm( + a, + b_q, + b_type.id, + out_type, + b_group_scales, + b_group_zeros, + b_group_size, + b_channel_scales, + a_token_scales, + schedule, + ) def machete_prepack_B( - b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, - group_scales_type) + b_q_weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], +) -> torch.Tensor: + return torch.ops._C.machete_prepack_B( + b_q_weight, a_type, b_type.id, group_scales_type + ) + + +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm( + a, + b_q, + b_group_scales, + b_group_size, + b_channel_scales, + a_token_scales, + out_type, + maybe_schedule, + ) + + +def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_encode_and_reorder_int4b(b) if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") - def _permute_cols_fake(a: torch.Tensor, - perm: torch.Tensor) -> torch.Tensor: + def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -1046,8 +1339,8 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: # fp4 def scaled_fp4_quant( - input: torch.Tensor, - input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, input_global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. @@ -1067,18 +1360,17 @@ def scaled_fp4_quant( in the sizzled layout. """ assert not current_platform.is_rocm() - assert input.ndim >= 1, ( - f'input.ndim needs to be >= 1, but got {input.ndim}.') + assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." other_dims = 1 if input.ndim == 1 else -1 input = input.reshape(other_dims, input.shape[-1]) m, n = input.shape block_size = 16 device = input.device - assert n % block_size == 0, ( - f'last dim has to be multiple of 16, but got {n}.') + assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." assert input.dtype in (torch.float16, torch.bfloat16), ( - f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + ) # Two fp4 values will be packed into an uint8. 
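    # Worked example (illustrative, not part of the patch), assuming a hypothetical
    # bf16 input of shape (m, n) = (32, 64):
    #   packed data  : (m, n // 2)                 = (32, 32)  uint8, two fp4 values per byte
    #   rounded_m    : round_up(32, 128)           = 128
    #   scale_n      : n // block_size = 64 // 16  = 4, rounded up to a multiple of 4 -> 4
    #   scale buffer : (rounded_m, rounded_n // 4) = (128, 1) int32,
    #                  later viewed as float8_e4m3fn -> (128, 4)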
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) @@ -1092,12 +1384,11 @@ def scaled_fp4_quant( rounded_m = round_up(m, 128) scale_n = n // block_size rounded_n = round_up(scale_n, 4) - output_scale = torch.empty((rounded_m, rounded_n // 4), - device=device, - dtype=torch.int32) + output_scale = torch.zeros( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) - torch.ops._C.scaled_fp4_quant(output, input, output_scale, - input_global_scale) + torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) output_scale = output_scale.view(torch.float8_e4m3fn) return output, output_scale @@ -1123,7 +1414,8 @@ def scaled_fp4_experts_quant( """ assert not current_platform.is_rocm() assert input_tensor.ndim == 2, ( - f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + f"input.ndim needs to be == 2, but got {input_tensor.ndim}." + ) # Control the maximum number of tokens per expert supported by the # NVFP4 MoE Expert Quantization. This is used to prevent the kernel @@ -1132,26 +1424,33 @@ def scaled_fp4_experts_quant( MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE m_numtopk, k = input_tensor.shape - assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" f"{MAX_TOKENS_PER_EXPERT})" f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" - f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.") + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value." + ) scales_k = k // 16 padded_k = (scales_k + (4 - 1)) // 4 # output is uint8 and packed fp4 values - output = torch.empty(m_numtopk, - k // 2, - device=input_tensor.device, - dtype=torch.uint8) - output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, - padded_k, - dtype=torch.int32, - device=input_tensor.device) - torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, - input_global_scale, expert_offsets, - blockscale_offsets) + output = torch.empty( + m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 + ) + output_scales = torch.empty( + MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device, + ) + torch.ops._C.scaled_fp4_experts_quant( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) output_scales = output_scales.view(torch.float8_e4m3fn) return output, output_scales @@ -1189,7 +1488,7 @@ def scaled_fp8_quant( scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) + assert input.ndim == 2 shape: Union[tuple[int, int], torch.Size] = input.shape # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = current_platform.fp8_dtype() @@ -1198,17 +1497,15 @@ def scaled_fp8_quant( if output is None: output = torch.empty(shape, device=input.device, dtype=out_dtype) else: - assert num_token_padding is None, \ - "padding not supported if output passed in" + assert num_token_padding is None, "padding not supported if output passed in" assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) + output, input, scale, scale_ub + ) else: scale = torch.empty(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) @@ -1221,10 +1518,10 @@ def scaled_fp8_quant( # gptq allspark def allspark_repack_weight( - qweight: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor] = None, - has_zp: bool = False + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format @@ -1246,38 +1543,61 @@ def allspark_repack_weight( N = qweight.shape[1] N_32align = (N + 32 - 1) // 32 * 32 - qweight_reorder = torch.empty((N_32align, K), - device=qweight.device, - dtype=qweight.dtype) - scale_reorder = torch.empty((1, N_32align), - device=scale.device, - dtype=scale.dtype) + qweight_reorder = torch.empty( + (N_32align, K), device=qweight.device, dtype=qweight.dtype + ) + scale_reorder = torch.empty((1, N_32align), device=scale.device, dtype=scale.dtype) zero_point_reorder = None if has_zp: assert zero_point is not None, ( - "zero_point must be provided for asymmetric quantization.") - zero_point_reorder = torch.empty((1, N_32align), - device=zero_point.device, - dtype=zero_point.dtype) + "zero_point must be provided for asymmetric quantization." 
+ ) + zero_point_reorder = torch.empty( + (1, N_32align), device=zero_point.device, dtype=zero_point.dtype + ) torch.ops._C.rearrange_kn_weight_as_n32k16_order( - qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, - zero_point_reorder, K, N, N_32align) + qweight, + scale, + zero_point, + has_zp, + qweight_reorder, + scale_reorder, + zero_point_reorder, + K, + N, + N_32align, + ) return qweight_reorder, scale_reorder, zero_point_reorder -def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], n: int, - group_size: int, sm_count: int, sm_version: int, - CUBLAS_M_THRESHOLD: int, has_zp: bool, - n32k16_reorder: bool) -> torch.Tensor: - - return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, - n, group_size, sm_count, - sm_version, CUBLAS_M_THRESHOLD, - has_zp, n32k16_reorder) +def allspark_w8a16_gemm( + a: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: int, + group_size: int, + sm_count: int, + sm_version: int, + CUBLAS_M_THRESHOLD: int, + has_zp: bool, + n32k16_reorder: bool, +) -> torch.Tensor: + return torch.ops._C.allspark_w8a16_gemm( + a, + b_qweight, + b_scales, + b_qzeros, + n, + group_size, + sm_count, + sm_version, + CUBLAS_M_THRESHOLD, + has_zp, + n32k16_reorder, + ) # int8 @@ -1285,7 +1605,7 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True + symmetric: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -1304,26 +1624,27 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - assert symmetric == ( - azp - is None), "azp must only be provided for asymmetric quantization." + assert symmetric == (azp is None), ( + "azp must only be provided for asymmetric quantization." + ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
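    # Conceptually (illustrative, not part of the patch), the dynamic path below derives
    # one scale per token/row before quantizing, roughly:
    #   row_scale = x[i].abs().max() / 127                       # symmetric int8 range
    #   q[i]      = torch.round(x[i] / row_scale).clamp(-128, 127).to(torch.int8)
    # with an optional per-row zero point (input_azp) in the asymmetric case.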
- input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(), - input_scales, input_azp) + input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant( + output, input.contiguous(), input_scales, input_azp + ) return output, input_scales, input_azp # gguf -def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, - dtype: Optional[torch.dtype]) -> torch.Tensor: +def ggml_dequantize( + W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype] +) -> torch.Tensor: return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) @@ -1356,9 +1677,17 @@ def ggml_moe_a8( top_k: int, tokens: int, ) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, - num_tokens_post_padded, quant_type, row, - top_k, tokens) + return torch.ops._C.ggml_moe_a8( + X, + W, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + quant_type, + row, + top_k, + tokens, + ) def ggml_moe_a8_vec( @@ -1370,8 +1699,7 @@ def ggml_moe_a8_vec( row: torch.SymInt, tokens: torch.SymInt, ) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, - tokens) + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, tokens) def ggml_moe_get_block_size(quant_type: int) -> int: @@ -1379,38 +1707,62 @@ def ggml_moe_get_block_size(quant_type: int) -> int: # mamba -def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, - D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - ssm_states: torch.Tensor, pad_slot_id: int): - torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, - delta_softplus, query_start_loc, - cache_indices, has_initial_state, - ssm_states, pad_slot_id) +def selective_scan_fwd( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, + pad_slot_id: int, +): + torch.ops._C.selective_scan_fwd( + u, + delta, + A, + B, + C, + D_, + z_, + delta_bias_, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ) # ROCm skinny gemms -def LLMM1(a: torch.Tensor, b: torch.Tensor, - rows_per_block: int) -> torch.Tensor: +def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) -def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitK(a, b, cu_count) +def wvSplitK( + a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None +) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) -def wvSplitKQ(a: torch.Tensor, b: 
torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cu_count: int) -> torch.Tensor: - out = torch.empty((b.shape[0], a.shape[0]), - dtype=out_dtype, - device=b.device) - torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) +def wvSplitKQ( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None, +) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out @@ -1419,107 +1771,212 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: - torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) +def moe_align_block_size( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) -def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, - b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], - sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, top_k: int, - BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, - bit: int) -> torch.Tensor: +def moe_wna16_gemm( + input: torch.Tensor, + output: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, + top_k: int, + BLOCK_SIZE_M: int, + BLOCK_SIZE_N: int, + BLOCK_SIZE_K: int, + bit: int, +) -> torch.Tensor: if not current_platform.is_cuda(): raise NotImplementedError( - "The optimized moe_wna16_gemm kernel is only " - "available on CUDA platforms") - torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, - b_qzeros, topk_weights, sorted_token_ids, - experts_ids, num_tokens_post_pad, top_k, - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, - bit) + "The optimized moe_wna16_gemm kernel is only available on CUDA platforms" + ) + torch.ops._moe_C.moe_wna16_gemm( + input, + output, + b_qweight, + b_scales, + b_qzeros, + topk_weights, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + top_k, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + bit, + ) -def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor) -> None: - torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, token_expert_indices, - gating_output) +def topk_softmax( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, +) -> None: + torch.ops._moe_C.topk_softmax( + topk_weights, topk_ids, token_expert_indices, gating_output + ) -def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], - b_qweight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, 
- global_scale: Optional[torch.Tensor], - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, moe_block_size: int, - top_k: int, mul_topk_weights: bool, is_ep: bool, - b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, use_atomic_add: bool, - use_fp32_reduce: bool, - is_zp_float: bool) -> torch.Tensor: +def grouped_topk( + scores: torch.Tensor, + scores_with_bias: torch.Tensor, + num_expert_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): + if not current_platform.is_cuda(): + raise NotImplementedError( + "The fused grouped_topk kernel is only available on CUDA platforms" + ) + return torch.ops._moe_C.grouped_topk( + scores, + scores_with_bias, + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + +def moe_wna16_marlin_gemm( + input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, + top_k: int, + mul_topk_weights: bool, + is_ep: bool, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool, +) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros, - g_idx, perm, workspace, sorted_token_ids, expert_ids, - num_tokens_past_padded, topk_weights, moe_block_size, top_k, - mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k, - is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) + input, + output, + b_qweight, + b_bias, + b_scales, + global_scale, + b_qzeros, + g_idx, + perm, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_past_padded, + topk_weights, + moe_block_size, + top_k, + mul_topk_weights, + is_ep, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) -if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): +if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") - def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, b_scales: torch.Tensor, - b_zero_points: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, size_k: torch.SymInt, - is_k_full: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, - apply_weights: bool) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), - dtype=a.dtype, - device=a.device) + def marlin_gemm_moe_fake( + a: torch.Tensor, + b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + b_scales: torch.Tensor, + b_zero_points: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + 
b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + num_experts: int, + topk: int, + moe_block_size: int, + replicate_input: bool, + apply_weights: bool, + ) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device) @register_fake("_moe_C::moe_wna16_marlin_gemm") - def moe_wna16_marlin_gemm_fake(input: torch.Tensor, - output: Optional[torch.Tensor], - b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, - moe_block_size: int, top_k: int, - mul_topk_weights: bool, is_ep: bool, - b_q_type: ScalarType, size_m: int, - size_n: int, size_k: int, is_k_full: bool, - use_atomic_add: bool, use_fp32_reduce: bool, - is_zp_float: bool) -> torch.Tensor: - return torch.empty((size_m * top_k, size_n), - dtype=input.dtype, - device=input.device) + def moe_wna16_marlin_gemm_fake( + input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, + top_k: int, + mul_topk_weights: bool, + is_ep: bool, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool, + ) -> torch.Tensor: + return torch.empty( + (size_m * top_k, size_n), dtype=input.dtype, device=input.device + ) def reshape_and_cache( @@ -1532,9 +1989,16 @@ def reshape_and_cache( k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + torch.ops._C_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) def reshape_and_cache_flash( @@ -1547,10 +2011,16 @@ def reshape_and_cache_flash( k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, - v_scale) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) def concat_and_cache_mla( @@ -1561,46 +2031,92 @@ def concat_and_cache_mla( kv_cache_dtype: str, scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, - slot_mapping, kv_cache_dtype, - scale) + torch.ops._C_cache_ops.concat_and_cache_mla( + kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale + ) -def copy_blocks(key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: +def copy_blocks( + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) -def copy_blocks_mla(kv_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: +def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: 
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) -def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, - input: torch.Tensor, - scale: float = 1.0, - kv_dtype: str = "fp8") -> None: +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) def gather_and_maybe_dequant_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - kv_cache_dtype: str, - scale: torch.Tensor, - seq_starts: Optional[torch.Tensor] = None) -> None: + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None, +) -> None: torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( - src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, - scale, seq_starts) + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + seq_starts, + ) + + +def cp_gather_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None, +) -> None: + torch.ops._C_cache_ops.cp_gather_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts + ) + + +def indexer_k_quant_and_cache( + k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + kv_cache_dtype: str, +) -> None: + torch.ops._C_cache_ops.indexer_k_quant_and_cache( + k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype + ) + + +def cp_gather_indexer_k_quant_cache( + kv_cache: torch.Tensor, + dst_k: torch.Tensor, + dst_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens + ) def get_device_attribute(attribute: int, device: int) -> int: @@ -1610,20 +2126,30 @@ def get_device_attribute(attribute: int, device: int) -> int: def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # ruff: noqa: E501 return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( - device) + device + ) # custom ar -def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, - rank: int, fully_connected: bool) -> int: - return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, - fully_connected) +def init_custom_ar( + ipc_tensors: list[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + fully_connected: bool, +) -> int: + return torch.ops._C_custom_ar.init_custom_ar( + ipc_tensors, rank_data, rank, fully_connected + ) -def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, - reg_buffer_sz_bytes: int) -> None: - torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, - reg_buffer_sz_bytes) +def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, +) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) def dispose(fa: int) -> None: @@ -1642,8 +2168,9 @@ 
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: list[list[int]], - offsets: list[list[int]]) -> None: +def register_graph_buffers( + fa: int, handles: list[list[int]], offsets: list[list[int]] +) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) @@ -1660,9 +2187,9 @@ def free_shared_buffer(ptr: int) -> None: # quick all reduce -def init_custom_qr(rank: int, - world_size: int, - qr_max_size: Optional[int] = None) -> int: +def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None +) -> int: return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) @@ -1670,13 +2197,14 @@ def qr_destroy(fa: int) -> None: torch.ops._C_custom_ar.qr_destroy(fa) -def qr_all_reduce(fa: int, - inp: torch.Tensor, - out: torch.Tensor, - quant_level: int, - cast_bf2half: bool = False) -> None: - torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, - cast_bf2half) +def qr_all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool = False, +) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) def qr_get_handle(fa: int) -> torch.Tensor: @@ -1706,9 +2234,9 @@ def get_flash_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._C.get_flash_mla_metadata(cache_seqlens, - num_heads_per_head_k, - num_heads_k) + return torch.ops._C.get_flash_mla_metadata( + cache_seqlens, num_heads_per_head_k, num_heads_k + ) def flash_mla_with_kvcache( @@ -1739,7 +2267,7 @@ def flash_mla_with_kvcache( softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( q, k_cache, @@ -1755,44 +2283,53 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, page_table: torch.Tensor, - scale: float) -> torch.Tensor: - torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale) +def sm100_cutlass_mla_decode( + out: torch.Tensor, + lse: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + workspace: torch.Tensor, + scale: float, + num_kv_splits: int, +) -> torch.Tensor: + torch.ops._C.sm100_cutlass_mla_decode( + out, + lse, + q_nope, + q_pe, + kv_c_and_k_pe_cache, + seq_lens, + page_table, + workspace, + scale, + num_kv_splits, + ) return out -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, page_table: torch.Tensor, - workspace: torch.Tensor, scale: float, - num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, - kv_c_and_k_pe_cache, seq_lens, - page_table, workspace, scale, - num_kv_splits) - return out - - -def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int, - sm_count: int, - num_kv_splits: int) -> int: +def sm100_cutlass_mla_get_workspace_size( + max_seq_len: int, num_batches: int, sm_count: int, num_kv_splits: int +) -> int: return torch.ops._C.sm100_cutlass_mla_get_workspace_size( - max_seq_len, num_batches, sm_count, num_kv_splits) + max_seq_len, num_batches, sm_count, num_kv_splits + ) if hasattr(torch.ops._C, "weight_packed_linear"): @register_fake("_C::weight_packed_linear") - def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, - bias: Optional[torch.Tensor], - is_vnni: bool) -> torch.Tensor: - return torch.empty((mat1.size(0), mat2.size(0)), - dtype=mat1.dtype, - device=mat2.device) + def weight_packed_linear_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty( + (mat1.size(0), mat2.size(0)), dtype=mat1.dtype, device=mat2.device + ) if hasattr(torch.ops._C, "fused_experts_cpu"): @@ -1834,7 +2371,6 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): class CPUDNNLGEMMHandler: - def __init__(self) -> None: self.handler: Optional[int] = None self.n = -1 @@ -1845,6 +2381,38 @@ class CPUDNNLGEMMHandler: torch.ops._C.release_dnnl_matmul_handler(self.handler) +_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) + + +def is_onednn_acl_supported(): + return torch.ops._C.is_onednn_acl_supported() + + +def create_onednn_mm( + weight: torch.Tensor, # [K, N] + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_mm_handler( + weight, primitive_cache_size + ) + return handler + + +def onednn_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) + torch.ops._C.onednn_mm( + output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler + ) + + return output + + 
def create_onednn_scaled_mm( weight: torch.Tensor, # [K, N] weight_scales: torch.Tensor, @@ -1856,15 +2424,17 @@ def create_onednn_scaled_mm( handler = CPUDNNLGEMMHandler() handler.k, handler.n = weight.size() handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( - weight, weight_scales, output_type, dynamic_quant, use_azp, - primitive_cache_size) + weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size + ) return handler -def onednn_scaled_int8_quant(input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True): +def onednn_scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +): """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -1884,20 +2454,16 @@ def onednn_scaled_int8_quant(input: torch.Tensor, input = input.view((token_num, input.shape[-1])) if scale is not None: # static-per-tensor quantization. - assert symmetric == ( - azp - is None), "azp must only be provided for asymmetric quantization." + assert symmetric == (azp is None), ( + "azp must only be provided for asymmetric quantization." + ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. - input_scales = torch.empty((token_num, 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) + input_scales = torch.empty((token_num, 1), device=input.device, dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp @@ -1910,7 +2476,169 @@ def onednn_scaled_mm( input_zp_adj: Optional[torch.Tensor], bias: Optional[torch.Tensor], ) -> torch.Tensor: - torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, - input_zp_adj, bias, dnnl_handler.handler) + torch.ops._C.onednn_scaled_mm( + output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler + ) return output + + +if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_mxf4_bf16_tn") + def _fake_matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn") + def _fake_matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +if hasattr(torch.ops._qutlass_C, 
"fusedQuantizeMxQuest"): + + @register_fake("_qutlass_C::fusedQuantizeMxQuest") + def _fake_fused_quantize_mx_quest( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"): + + @register_fake("_qutlass_C::fusedQuantizeMxAbsMax") + def _fake_fused_quantize_mx_absmax( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +def fusedQuantizeMx( + a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest" +) -> tuple[torch.Tensor, torch.Tensor]: + if a.dim() == 0: + raise ValueError("`a` must have at least 1 dimension.") + if a.size(-1) % 32 != 0: + raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.") + if b.device != a.device: + raise ValueError("`a` and `b` must be on the same device.") + + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 32 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device + ) + + if not hasattr(torch.ops, "_qutlass_C"): + raise RuntimeError( + "The `_qutlass_C` extension is not loaded. " + "Make sure your custom op library is imported before calling fusedQuantizeMx." + ) + + if method == "quest": + return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0) + elif method == "abs_max": + return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0) + else: + raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'") + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"): + + @register_fake("_qutlass_C::fusedQuantizeNv") + def _fake_fused_quantize_nv( + a: torch.Tensor, + b: torch.Tensor, + xh_e2m1: torch.Tensor, + xh_e4m3: torch.Tensor, + global_scale: torch.Tensor, + ): + return xh_e2m1, xh_e4m3 + + +def fusedQuantizeNv( + a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 16 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + xh_e4m3 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device + ) + + return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale) + + +def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: + """ + Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832) + kernels. Note that these kernels exploit the recursive properties of + Sylvester Hadamards, and therefore do not require transform weight data + + Note that sylvester hadamard transforms are also symmetric, which means that + this function is also applies the (transpose <=> inverse) transform. 
+ + :param x: value to be transformed inplace + :param inplace: modify value in place + :return: value after transformation + """ + return torch.ops._C.hadacore_transform(x, inplace) + + +if hasattr(torch.ops._C, "hadacore_transform"): + + @register_fake("_C::hadacore_transform") + def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor: + return torch.empty_like(x) if not inplace else x diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 79e3e448cada3..1f458f940a289 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,25 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) try: import intel_extension_for_pytorch as ipex except ImportError as e: - logger.warning("Import error msg: %s", e.msg) + logger.debug("Import error msg: %s", e.msg) class ipex_ops: - @staticmethod def _reshape_activation_tensor( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: num = x.size(0) d = x.size(1) // 2 x = x.reshape(num, 2, d) @@ -143,31 +144,26 @@ class ipex_ops: is_neox: bool, ) -> None: rot_dim = cos_sin_cache.size(1) - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim) + ipex.llm.functional.rotary_embedding_batched( + positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim + ) @staticmethod - def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim, - cos_sin_cache_offsets) - - @staticmethod - def rms_norm(input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> torch.Tensor: + def rms_norm( + input: torch.Tensor, weight: torch.Tensor, epsilon: float + ) -> torch.Tensor: return ipex.llm.functional.rms_norm(input, weight, epsilon) @staticmethod - def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: - tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, - epsilon, True) + def fused_add_rms_norm( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + ) -> None: + tmp = ipex.llm.functional.add_rms_norm( + residual, input, weight, None, epsilon, True + ) input.copy_(tmp) @staticmethod @@ -196,22 +192,43 @@ class ipex_ops: raise ValueError("IPEX CPU does not support logits_soft_cap") assert alibi_slopes is None assert window_size_left < 0 and window_size_right < 0 - ipex.llm.functional.varlen_attention(query.contiguous(), - key.contiguous(), - value.contiguous(), out, - seqlen_q.int(), - seqlen_k.int(), max_seqlen_q, - max_seqlen_k, pdropout, - softmax_scale, zero_tensors, - is_causal, return_softmax, - gen_) + ipex.llm.functional.varlen_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + out, + seqlen_q.int(), + seqlen_k.int(), + max_seqlen_q, + max_seqlen_k, + pdropout, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + ) else: # XPU build ipex.llm.functional.varlen_attention( - query.contiguous(), key.contiguous(), value.contiguous(), 
out, - seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q, - max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal, - return_softmax, gen_, window_size_left, window_size_right, - logits_soft_cap) + query.contiguous(), + key.contiguous(), + value.contiguous(), + out, + seqlen_q.int(), + seqlen_k.int(), + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + pdropout, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + window_size_left, + window_size_right, + logits_soft_cap, + ) @staticmethod def reshape_and_cache( @@ -226,7 +243,8 @@ class ipex_ops: ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slot_mapping) + key, value, key_cache, value_cache, slot_mapping + ) @staticmethod def reshape_and_cache_flash( @@ -241,10 +259,16 @@ class ipex_ops: k_scale_float: float = 1.0, v_scale_float: float = 1.0, ) -> None: - assert kv_cache_dtype == "auto" - # TODO: support FP8 kv cache. ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, value, key_cache, value_cache, slot_mapping) + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale_float, + v_scale_float, + ) @staticmethod def flash_attn_varlen_func( @@ -276,10 +300,12 @@ class ipex_ops: if cu_seqlens_k is None: # cu_seqlens_k is not used in ipex kernel. cu_seqlens_k = torch.cumsum(seqused_k, dim=0) - cu_seqlens_k = torch.cat([ - torch.tensor([0], device=seqused_k.device, dtype=torch.int32), - cu_seqlens_k - ]).to(torch.int32) + cu_seqlens_k = torch.cat( + [ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cu_seqlens_k, + ] + ).to(torch.int32) real_window_size: tuple[int, int] if window_size is None: @@ -309,36 +335,38 @@ class ipex_ops: @staticmethod def get_scheduler_metadata( - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads_q, - num_heads_kv, - headdim, - cache_seqlens: torch.Tensor, - qkv_dtype=torch.bfloat16, - headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, - max_seqlen_k_new=0, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - has_softcap=False, - num_splits=0, # Can be tuned for speed - pack_gqa=None, # Can be tuned for speed - sm_margin=0, # Can be tuned if some SMs are used for communication + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication ) -> None: logger.warning_once( - "get_scheduler_metadata is not implemented for ipex_ops, " - "returning None.") + "get_scheduler_metadata is not implemented for ipex_ops, returning None." 
+ ) return None @staticmethod - def copy_blocks(key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: + def copy_blocks( + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor, + ) -> None: torch.xpu.copy_blocks( # type: ignore key_caches, value_caches, @@ -346,6 +374,62 @@ class ipex_ops: ) @staticmethod - def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: + def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor + ) -> None: torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore + + @staticmethod + def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, + output: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function is designed for both static and dynamic quantization: + If you provide the scale, it will use static scaling and if you omit + it, the scale will be determined dynamically. Currently, XPU platform + only supports dynamic quantization. The function also allows optional + padding of the output tensors for downstream kernels that will benefit + from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[tuple[int, int], torch.Size] = input.shape + out_dtype: torch.dtype = current_platform.fp8_dtype() + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, ( + "padding not supported if output passed in" + ) + assert output.dtype == out_dtype + assert scale is None, "only dynamic fp8 quantization supported on XPU" + assert not use_per_token_if_dynamic, ( + "per token dynamic fp8 quantization not supported on XPU" + ) + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale) + + return output, scale diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py deleted file mode 100644 index 9753a08806565..0000000000000 --- a/vllm/adapter_commons/layers.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - - -@dataclass -class AdapterMapping: - # Per every token in input_ids: - index_mapping: tuple[int, ...] - # Per sampled token: - prompt_mapping: tuple[int, ...] 
- - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py deleted file mode 100644 index 7b685880a9e6c..0000000000000 --- a/vllm/adapter_commons/models.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, TypeVar - -from torch import nn - -from vllm.logger import init_logger -from vllm.utils import LRUCache - -logger = init_logger(__name__) - - -class AdapterModel(ABC): - - def __init__(self, model_id=None): - self.id = model_id - - @abstractmethod - def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): - # Common initialization code - # Load weights or embeddings from local checkpoint - raise NotImplementedError("Subclasses must implement this method.") - - -T = TypeVar('T') - - -class AdapterLRUCache(LRUCache[int, T]): - - def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): - super().__init__(capacity) - self.deactivate_fn = deactivate_fn - - def _on_remove(self, key: int, value: Optional[T]): - logger.debug("Removing adapter int id: %d", key) - self.deactivate_fn(key) - return super()._on_remove(key, value) - - -class AdapterModelManager(ABC): - - def __init__( - self, - model: nn.Module, - ): - """Create a AdapterModelManager and adapter for a given model. - Args: - model: the model to be adapted. - """ - self.model: nn.Module = model - self._registered_adapters: dict[int, Any] = {} - # Dict instead of a Set for compatibility with LRUCache. - self._active_adapters: dict[int, None] = {} - self.adapter_type = 'Adapter' - self._last_mapping = None - - def __len__(self) -> int: - return len(self._registered_adapters) - - @property - @abstractmethod - def adapter_slots(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def capacity(self) -> int: - raise NotImplementedError - - @abstractmethod - def activate_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def deactivate_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def add_adapter(self, adapter: Any) -> bool: - raise NotImplementedError - - @abstractmethod - def set_adapter_mapping(self, mapping: Any) -> None: - raise NotImplementedError - - @abstractmethod - def remove_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_all_adapters(self) -> None: - raise NotImplementedError - - @abstractmethod - def get_adapter(self, adapter_id: int) -> Optional[Any]: - raise NotImplementedError - - @abstractmethod - def list_adapters(self) -> dict[int, Any]: - raise NotImplementedError - - @abstractmethod - def pin_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py deleted file mode 100644 index 8135b54ba19f6..0000000000000 --- a/vllm/adapter_commons/request.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod - - -class AdapterRequest(ABC): - """ - Base class for adapter requests. 
- """ - - @property - @abstractmethod - def adapter_id(self) -> int: - raise NotImplementedError - - def __post_init__(self) -> None: - if self.adapter_id < 1: - raise ValueError(f"id must be > 0, got {self.adapter_id}") - - def __eq__(self, value: object) -> bool: - return isinstance( - value, self.__class__) and self.adapter_id == value.adapter_id - - def __hash__(self) -> int: - return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py deleted file mode 100644 index a1a56b6bbd4ba..0000000000000 --- a/vllm/adapter_commons/utils.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Optional - - -## model functions -def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], - deactivate_func: Callable) -> bool: - if adapter_id in active_adapters: - deactivate_func(adapter_id) - active_adapters.pop(adapter_id) - return True - return False - - -def add_adapter(adapter: Any, registered_adapters: dict[int, Any], - capacity: int, add_func: Callable) -> bool: - if adapter.id not in registered_adapters: - if len(registered_adapters) >= capacity: - raise RuntimeError('No free adapter slots.') - add_func(adapter) - registered_adapters[adapter.id] = adapter - return True - return False - - -def set_adapter_mapping(mapping: Any, last_mapping: Any, - set_mapping_func: Callable) -> Any: - if last_mapping != mapping: - set_mapping_func(mapping) - return mapping - return last_mapping - - -def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], - deactivate_func: Callable) -> bool: - deactivate_func(adapter_id) - return bool(registered_adapters.pop(adapter_id, None)) - - -def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: - return dict(registered_adapters) - - -def get_adapter(adapter_id: int, - registered_adapters: dict[int, Any]) -> Optional[Any]: - return registered_adapters.get(adapter_id) - - -## worker functions -def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], - apply_adapters_func, - set_adapter_mapping_func) -> None: - apply_adapters_func(requests) - set_adapter_mapping_func(mapping) - - -def add_adapter_worker(adapter_request: Any, list_adapters_func, - load_adapter_func, add_adapter_func, - activate_adapter_func) -> bool: - if adapter_request.adapter_id in list_adapters_func(): - return False - loaded_adapter = load_adapter_func(adapter_request) - loaded = add_adapter_func(loaded_adapter) - activate_adapter_func(loaded_adapter.id) - return loaded - - -def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, - adapter_slots: int, remove_adapter_func, - add_adapter_func) -> None: - models_that_exist = list_adapters_func() - models_map = { - adapter_request.adapter_id: adapter_request - for adapter_request in adapter_requests if adapter_request - } - if len(models_map) > adapter_slots: - raise RuntimeError( - f"Number of requested models ({len(models_map)}) is greater " - f"than the number of GPU model slots " - f"({adapter_slots}).") - new_models = set(models_map) - models_to_add = new_models - models_that_exist - models_to_remove = models_that_exist - new_models - for adapter_id in models_to_remove: - remove_adapter_func(adapter_id) - for adapter_id in models_to_add: - add_adapter_func(models_map[adapter_id]) - - -def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: - return 
set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py deleted file mode 100644 index 07e85d138ac50..0000000000000 --- a/vllm/adapter_commons/worker_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Any, Optional - -import torch - - -class AbstractWorkerManager(ABC): - - def __init__(self, device: torch.device): - self.device = device - - @property - @abstractmethod - def is_enabled(self) -> bool: - raise NotImplementedError - - @abstractmethod - def set_active_adapters(self, requests: set[Any], - mapping: Optional[Any]) -> None: - raise NotImplementedError - - @abstractmethod - def add_adapter(self, adapter_request: Any) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_all_adapters(self) -> None: - raise NotImplementedError - - @abstractmethod - def list_adapters(self) -> set[int]: - raise NotImplementedError diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 1c16230849bca..61c2dbf55fe31 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -32,13 +32,11 @@ class AudioAsset: @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: - audio_path = get_vllm_public_assets(filename=self.filename, - s3_prefix=ASSET_DIR) + audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) def get_local_path(self) -> Path: - return get_vllm_public_assets(filename=self.filename, - s3_prefix=ASSET_DIR) + return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) @property def url(self) -> str: diff --git a/vllm/assets/base.py b/vllm/assets/base.py index 31cde431b5b6a..409bfc18ff8cf 100644 --- a/vllm/assets/base.py +++ b/vllm/assets/base.py @@ -20,8 +20,7 @@ def get_cache_dir() -> Path: @lru_cache -def get_vllm_public_assets(filename: str, - s3_prefix: Optional[str] = None) -> Path: +def get_vllm_public_assets(filename: str, s3_prefix: Optional[str] = None) -> Path: """ Download an asset file from ``s3://vllm-public-assets`` and return the path to the downloaded file. 
@@ -36,6 +35,7 @@ def get_vllm_public_assets(filename: str, global_http_connection.download_file( f"{VLLM_S3_BUCKET_URL}/{filename}", asset_path, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py index c977242a3d484..c1a0f2b9cc294 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from pathlib import Path from typing import Literal import torch @@ -11,17 +12,38 @@ from .base import get_vllm_public_assets VLM_IMAGES_DIR = "vision_model_images" -ImageAssetName = Literal["stop_sign", "cherry_blossom"] +ImageAssetName = Literal[ + "stop_sign", + "cherry_blossom", + "hato", + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk", + "Grayscale_8bits_palette_sample_image", + "1280px-Venn_diagram_rgb", + "RGBA_comp", + "237-400x300", + "231-200x300", + "27-500x500", + "17-150x600", + "handelsblatt-preview", + "paper-11", +] @dataclass(frozen=True) class ImageAsset: name: ImageAssetName + def get_path(self, ext: str) -> Path: + """ + Return s3 path for given image. + """ + return get_vllm_public_assets( + filename=f"{self.name}.{ext}", s3_prefix=VLM_IMAGES_DIR + ) + @property - def pil_image(self) -> Image.Image: - image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", - s3_prefix=VLM_IMAGES_DIR) + def pil_image(self, ext="jpg") -> Image.Image: + image_path = self.get_path(ext) return Image.open(image_path) @property @@ -29,6 +51,9 @@ class ImageAsset: """ Image embeddings, only used for testing purposes with llava 1.5. """ - image_path = get_vllm_public_assets(filename=f"{self.name}.pt", - s3_prefix=VLM_IMAGES_DIR) + image_path = self.get_path("pt") return torch.load(image_path, map_location="cpu", weights_only=True) + + def read_bytes(self, ext: str) -> bytes: + p = Path(self.get_path(ext)) + return p.read_bytes() diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 8ab0e9760be87..6b2ca8f867e03 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -65,18 +65,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: frames = np.stack(frames) if len(frames) < num_frames: - raise ValueError(f"Could not read enough frames from video file {path}" - f" (expected {num_frames} frames, got {len(frames)})") + raise ValueError( + f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})" + ) return frames -def video_to_pil_images_list(path: str, - num_frames: int = -1) -> list[Image.Image]: +def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]: frames = video_to_ndarrays(path, num_frames) return [Image.fromarray(frame) for frame in frames] -def video_get_metadata(path: str) -> dict[str, Any]: +def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]: cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -85,11 +86,18 @@ def video_get_metadata(path: str) -> dict[str, Any]: fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames / fps if fps > 0 else 0 + if num_frames == -1 or num_frames > total_frames: + num_frames = total_frames + metadata = { - "total_num_frames": total_frames, + "total_num_frames": num_frames, "fps": fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to 
control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames, } return metadata @@ -110,29 +118,29 @@ class VideoAsset: def filename(self) -> str: return self._NAME_TO_FILE[self.name] + @property + def video_path(self) -> str: + return download_video_asset(self.filename) + @property def pil_images(self) -> list[Image.Image]: - video_path = download_video_asset(self.filename) - ret = video_to_pil_images_list(video_path, self.num_frames) + ret = video_to_pil_images_list(self.video_path, self.num_frames) return ret @property def np_ndarrays(self) -> npt.NDArray: - video_path = download_video_asset(self.filename) - ret = video_to_ndarrays(video_path, self.num_frames) + ret = video_to_ndarrays(self.video_path, self.num_frames) return ret @property def metadata(self) -> dict[str, Any]: - video_path = download_video_asset(self.filename) - ret = video_get_metadata(video_path) + ret = video_get_metadata(self.video_path, self.num_frames) return ret def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: """ Read audio data from the video asset, used in Qwen2.5-Omni examples. - + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ - video_path = download_video_asset(self.filename) - return librosa.load(video_path, sr=sampling_rate)[0] + return librosa.load(self.video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index dcb2aa68fbee9..dd35165d5415e 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -13,7 +14,5 @@ __all__ = [ "AttentionBackend", "AttentionMetadata", "AttentionType", - "AttentionMetadataBuilder", - "AttentionState", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0b9c625533cb7..3f23d4ef7d2c1 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,20 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from contextlib import contextmanager -from dataclasses import dataclass, fields -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, - Protocol, Set, Tuple, Type, TypeVar) +from typing import Generic, Optional, Protocol, TypeVar, Union import torch +from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey -from vllm.multimodal import MultiModalPlaceholderMap - -if TYPE_CHECKING: - from vllm.worker.model_runner_base import (ModelRunnerBase, - ModelRunnerInputBase, - ModelRunnerInputBuilderBase) class AttentionType: @@ -23,23 +15,40 @@ class AttentionType: Attention type. Use string to be compatible with `torch.compile`. 
""" - # Decoder attention between previous layer Q/K/V + DECODER = "decoder" - # Encoder attention between previous layer Q/K/V for encoder-decoder + """Decoder attention between previous layer Q/K/V.""" ENCODER = "encoder" - # Encoder attention between previous layer Q/K/V + """Encoder attention between previous layer Q/K/V for encoder-decoder.""" ENCODER_ONLY = "encoder_only" - # Attention between dec. Q and enc. K/V for encoder-decoder + """Encoder attention between previous layer Q/K/V.""" ENCODER_DECODER = "encoder_decoder" + """Attention between dec. Q and enc. K/V for encoder-decoder.""" + + +class MultipleOf: + base: int + + def __init__(self, base: int): + self.base = base class AttentionBackend(ABC): """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. accept_output_buffer: bool = False + # Whether this backend supports receiving pre-quantized query input. + # If True, the attention layer will handle query quantization instead + # of the backend, allowing torch.compile to fuse quantization with + # previous operations. + # Needs to be worked through for all backends + # https://github.com/vllm-project/vllm/issues/25584 + supports_quant_query_input: bool = False + @staticmethod @abstractmethod def get_name() -> str: @@ -47,18 +56,17 @@ class AttentionBackend(ABC): @staticmethod @abstractmethod - def get_impl_cls() -> Type["AttentionImpl"]: + def get_impl_cls() -> type["AttentionImpl"]: raise NotImplementedError @staticmethod @abstractmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError - @staticmethod - @abstractmethod - def get_state_cls() -> Type["AttentionState"]: - raise NotImplementedError + @classmethod + def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]: + return cls.get_impl_cls().get_supported_kernel_block_size() @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": @@ -66,7 +74,7 @@ class AttentionBackend(ABC): @staticmethod @abstractmethod - def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: raise NotImplementedError @staticmethod @@ -76,28 +84,12 @@ class AttentionBackend(ABC): block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: raise NotImplementedError @staticmethod - def get_kv_cache_stride_order() -> Tuple[int, ...]: - raise NotImplementedError - - @staticmethod - @abstractmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise NotImplementedError - - @staticmethod - @abstractmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: + def get_kv_cache_stride_order() -> tuple[int, ...]: raise NotImplementedError @classmethod @@ -105,141 +97,18 @@ class AttentionBackend(ABC): return (cls.__module__, cls.__qualname__) -@dataclass class AttentionMetadata: - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # (num_tokens,). 
The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - # The index maps that relate multi-modal embeddings to the corresponding - # placeholders. - # - # N.B. These aren't really related to attention and don't belong on this - # type -- this is just a temporary solution to make them available to - # `model_executable`. - multi_modal_placeholder_index_maps: Optional[Dict[ - str, MultiModalPlaceholderMap.IndexMap]] - - # Enable/disable KV scales calculation. This is so that we can disable the - # calculation until after prefill and cuda graph capture. - enable_kv_scales_calculation: bool - - @property - @abstractmethod - def prefill_metadata(self) -> Optional["AttentionMetadata"]: - """Return the attention metadata that's required to run prefill - attention.""" - pass - - @property - @abstractmethod - def decode_metadata(self) -> Optional["AttentionMetadata"]: - """Return the attention metadata that's required to run decode - attention.""" - pass - - def asdict_zerocopy(self, - skip_fields: Optional[Set[str]] = None - ) -> Dict[str, Any]: - """Similar to dataclasses.asdict, but avoids deepcopying.""" - if skip_fields is None: - skip_fields = set() - # Note that if we add dataclasses as fields, they will need - # similar handling. - return { - field.name: getattr(self, field.name) - for field in fields(self) if field.name not in skip_fields - } + pass T = TypeVar("T", bound=AttentionMetadata) -class AttentionState(ABC, Generic[T]): - """Holds attention backend-specific objects reused during the - lifetime of the model runner.""" - - @abstractmethod - def __init__(self, runner: "ModelRunnerBase"): - ... - - @abstractmethod - @contextmanager - def graph_capture(self, max_batch_size: int): - """Context manager used when capturing CUDA graphs.""" - yield - - @abstractmethod - def graph_clone(self, batch_size: int) -> "AttentionState[T]": - """Clone attention state to save in CUDA graph metadata.""" - ... - - @abstractmethod - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> T: - """Get attention metadata for CUDA graph capture of batch_size.""" - ... - - @abstractmethod - def get_graph_input_buffers( - self, - attn_metadata: T, - is_encoder_decoder_model: bool = False) -> Dict[str, Any]: - """Get attention-specific input buffers for CUDA graph capture.""" - ... - - @abstractmethod - def prepare_graph_input_buffers( - self, - input_buffers: Dict[str, Any], - attn_metadata: T, - is_encoder_decoder_model: bool = False) -> None: - """In-place modify input buffers dict for CUDA graph replay.""" - ... - - @abstractmethod - def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: - """Prepare state for forward pass.""" - ... 
- - -class AttentionMetadataBuilder(ABC, Generic[T]): - """Abstract class for attention metadata builders.""" - - @abstractmethod - def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: - """Create the builder, remember some configuration and parameters.""" - raise NotImplementedError - - @abstractmethod - def prepare(self) -> None: - """Prepare for one batch.""" - raise NotImplementedError - - @abstractmethod - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int) -> T: - """Build attention metadata with on-device tensors.""" - raise NotImplementedError - - class AttentionLayer(Protocol): - _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor + _q_scale_float: float _k_scale_float: float _v_scale_float: float _prob_scale: torch.Tensor @@ -251,11 +120,37 @@ class AttentionLayer(Protocol): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... class AttentionImpl(ABC, Generic[T]): + # Whether the attention impl can return the softmax lse for decode. + # Some features like decode context parallelism require the softmax lse. + can_return_lse_for_decode: bool = False + + # some attention backends might not always want to return lse + # even if they can return lse (for efficiency reasons) + need_to_return_lse_for_decode: bool = False + + dcp_world_size: int + dcp_rank: int + + def __new__(cls, *args, **kwargs): + # use __new__ so that all subclasses will call this + self = super().__new__(cls) + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.need_to_return_lse_for_decode = ( + self.dcp_world_size > 1 and self.can_return_lse_for_decode + ) + return self @abstractmethod def __init__( @@ -264,7 +159,7 @@ class AttentionImpl(ABC, Generic[T]): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", logits_soft_cap: Optional[float] = None, @@ -273,6 +168,11 @@ class AttentionImpl(ABC, Generic[T]): ) -> None: raise NotImplementedError + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + # TODO: implement this function for all backends. 
+ return [MultipleOf(1)] + @abstractmethod def forward( self, @@ -301,6 +201,30 @@ class AttentionImpl(ABC, Generic[T]): class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer: Optional[object] = None, + ) -> None: + raise NotImplementedError @abstractmethod def forward( diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py deleted file mode 100644 index ce9467efd23c7..0000000000000 --- a/vllm/attention/backends/differential_flash_attn.py +++ /dev/null @@ -1,928 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""" An implementation of https://arxiv.org/pdf/2410.05258 """ -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -import torch -from einops import rearrange - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.flash_attn import FlashAttentionBackend -# yapf: enable -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, - is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - -logger = init_logger(__name__) - - -class DifferentialFlashAttentionBackend(AttentionBackend): - accept_output_buffer = False - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" - return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) - - @staticmethod - def get_name() -> str: - return "DIFFERENTIAL_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: - return DifferentialFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: - return DifferentialFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: - return 
DifferentialFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class DifferentialFlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
- seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - # Cross-layer shared attention block tables - cross_layer_shared_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[self.num_prefills:]) - self._cached_decode_metadata = DifferentialFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class DifferentialFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.cross_layer_shared_block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - # TODO: add support for chunked prefill and prefix caching. - assert not chunked_prefill_enabled, \ - "chunked prefill is not supported for now" - assert not prefix_cache_hit, "prefix caching is not supported for now" - - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. 
- block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - cross_layer_shared_block_table = [] - if prefix_cache_hit: - cross_layer_shared_block_table = block_tables[seq_id] - elif block_tables is not None: - if curr_sliding_window_block == 0: - cross_layer_shared_block_table = block_tables[seq_id] - else: - cross_layer_shared_block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.cross_layer_shared_block_tables.append( - cross_layer_shared_block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables(self, num_seqs: int, - block_tables: List[List[int]], - graph_block_tables) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - # max_batch_size, max_blocks = self.runner.graph_block_tables.shape - max_batch_size, max_blocks = graph_block_tables.shape - assert max_batch_size >= num_seqs - - # graph_block_tables = self.runner.graph_block_tables[:num_seqs] - graph_block_tables = graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - - self.cross_layer_shared_block_tables.extend([] * - cuda_graph_pad_size) - - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables, self.runner.graph_block_tables) - cross_layer_shared_block_tables = \ - self._get_graph_runner_block_tables( - num_seqs, self.cross_layer_shared_block_tables, - self.runner.cross_layer_shared_graph_block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - cross_layer_shared_block_tables = make_tensor_with_pad( - self.cross_layer_shared_block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class DifferentialFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - 
|<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - differential_flash_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if differential_flash_attention_config is None: - differential_flash_attention_config = {} - self.differential_flash_attention_config = \ - differential_flash_attention_config - self.used_shared_kv_cache = kv_sharing_target_layer_name is not None - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - self.lambda_full = None - self.subln = self.differential_flash_attention_config["subln"] - - def split_heads(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. - x = rearrange(x, "... (H two) D -> ... H two D", two=2) - x1 = x[..., 0, :] - x2 = x[..., 1, :] - return x1.contiguous(), x2.contiguous() - - def split_kv_cache(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. 
- if x.numel() == 0: - return torch.empty(0), torch.empty(0) - - x1, x2 = x[0], x[1] - return x1, x2 - - def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, - value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata): - if kv_cache.numel() > 0 and key is not None and value is not None: - updated_slot_mapping = attn_metadata.slot_mapping - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - def forward_generate_kv_cache( - self, query: torch.Tensor, key: Optional[torch.Tensor], - value: Optional[torch.Tensor], k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: - - head_size = self.head_size - num_heads = self.num_heads // 2 - num_kv_heads = self.num_kv_heads // 2 - - query = query.view(-1, num_heads, head_size) - if key is not None: - assert value is not None - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - else: - assert value is None - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" - assert value.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens, "query shape mismatch" - assert decode_query.shape[ - 0] == num_decode_tokens, "decode query shape mismatch" - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if k_cache.numel() == 0 \ - or prefill_meta.block_tables is None \ - or prefill_meta.block_tables.numel() == 0: - # normal attention - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - assert prefill_output.shape == output[: - num_prefill_tokens].shape - output[:num_prefill_tokens] = prefill_output - else: - raise Exception("prefix caching not supported") - - if decode_meta := attn_metadata.decode_metadata: - block_tables_arg = decode_meta.block_tables - try: - output[num_prefill_tokens:] = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - except Exception as e: - logger.error("Error in PagedAttention.forward_decode: %s", - str(e)) - raise e - - # Reshape the output tensor. 
- return output.view(-1, num_heads, head_size) - - def forward_with_kv_cache_only( - self, - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - ): - if not attn_metadata.decode_metadata: - block_tables_arg = attn_metadata.cross_layer_shared_block_tables - else: - block_tables_arg = attn_metadata.block_tables - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=attn_metadata.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - return output - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for DifferentialFlashAttentionImpl") - - if self.lambda_full is None: - self.lambda_init = self.differential_flash_attention_config[ - "lambda_init"] - lambda_q1 = self.differential_flash_attention_config["lambda_q1"] - lambda_k1 = self.differential_flash_attention_config["lambda_k1"] - lambda_q2 = self.differential_flash_attention_config["lambda_q2"] - lambda_k2 = self.differential_flash_attention_config["lambda_k2"] - lambda_1 = torch.exp( - torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp( - torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) - self.lambda_full = lambda_1 - lambda_2 + self.lambda_init - - if not self.used_shared_kv_cache: # need to generate kv-cache - q = q.view(-1, self.num_heads, self.head_size) - k = k.view(-1, self.num_kv_heads, self.head_size) - v = v.view(-1, self.num_kv_heads, self.head_size) - - q1, q2 = self.split_heads(q) - k1, k2 = self.split_heads(k) - v1, v2 = self.split_heads(v) - - # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 - # Split by half along the first dimension. 
- kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" - assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - - if kv_cache1.numel() != 0: - self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) - self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) - - key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) - key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) - else: - key_cache1, value_cache1 = torch.empty(0), torch.empty(0) - key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - - else: # reuse the kv cache, full attention - q = q.view(-1, self.num_heads, self.head_size) - q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] - key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - - attn11 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - attn_output = attn_output.view(-1, self.num_heads * self.head_size) - return attn_output diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py deleted file mode 100644 index 85957bea1e26d..0000000000000 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ /dev/null @@ -1,1499 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with Dual chunk flash attention and sparse attention. 
-""" -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -import torch -import torch.distributed -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionImpl, - FlashAttentionMetadata, - FlashAttentionMetadataBuilder) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.logger import init_logger -from vllm.utils import async_tensor_h2d -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache, sparse_attn_func) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - -logger = init_logger(__name__) - - -class DualChunkFlashAttentionBackend(FlashAttentionBackend): - - accept_output_buffer: bool = False - - @staticmethod - def get_name() -> str: - return "DUAL_CHUNK_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: - return DualChunkFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: - return DualChunkFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: - return DualChunkFlashAttentionMetadataBuilder - - -@dataclass -class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): - # Block size of the paged kv cache. - block_size: int = 16 - - # Original max position embeddings. - original_max_position_embeddings: int = 0 - - # Chunk size - chunk_size: int = 8192 - - # Local size - local_size: int = 1024 - - # (batch_size,). The orig sequence length per sequence. - orig_seq_lens: Optional[List[int]] = None - - # orig_seq_lens stored as a tensor. - orig_seq_lens_tensor: Optional[torch.Tensor] = None - - # Length scaling factor - scaling_factor: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for intra attention. - seq_lens_intra: Optional[torch.Tensor] = None - - # Max sequence length for intra attention. - max_seq_len_intra: Optional[int] = None - - # (batch_size, num_blocks). Block table for intra attention. - block_tables_intra: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for succ attention. - seq_lens_succ: Optional[torch.Tensor] = None - - # Max sequence length for succ attention. - max_seq_len_succ: Optional[int] = None - - # (batch_size, num_blocks). Block table for succ attention. - block_tables_succ: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for inter attention. - seq_lens_inter: Optional[torch.Tensor] = None - - # Max sequence length for inter attention. 
- max_seq_len_inter: Optional[int] = None - - _cached_prefill_metadata: Optional[ - "DualChunkFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - prefill_metadata = super().prefill_metadata - if prefill_metadata is None: - return None - - prefill_metadata = DualChunkFlashAttentionMetadata( - **prefill_metadata.asdict_zerocopy()) - - prefill_metadata.orig_seq_lens = ( - None if self.orig_seq_lens is None else - self.orig_seq_lens[:self.num_prefills]) - prefill_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[:self.num_prefills]) - - if self.original_max_position_embeddings > 0: - assert prefill_metadata.orig_seq_lens_tensor is not None - prefill_metadata.scaling_factor = ( - 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / - self.original_max_position_embeddings) + - 1.0).clip(min=1) - - self._cached_prefill_metadata = prefill_metadata - return prefill_metadata - - @property - def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - - decode_metadata = super().decode_metadata - if decode_metadata is None: - return None - - decode_metadata = DualChunkFlashAttentionMetadata( - **decode_metadata.asdict_zerocopy()) - - decode_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[self.num_prefills:]) - - assert decode_metadata.orig_seq_lens_tensor is not None - assert decode_metadata.block_tables is not None - - cache_seq_lens = decode_metadata.orig_seq_lens_tensor - chunk_len = self.chunk_size - self.local_size - chunk_num_curr = (cache_seq_lens - 1) // chunk_len - batch_size = decode_metadata.num_decode_tokens - - if self.original_max_position_embeddings > 0: - decode_metadata.scaling_factor = (0.1 * torch.log( - cache_seq_lens / self.original_max_position_embeddings) + - 1.0).clip(min=1) - - seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len - max_seq_len_intra = seq_lens_intra.max().item() - decode_metadata.seq_lens_intra = seq_lens_intra - decode_metadata.max_seq_len_intra = max_seq_len_intra - - block_tables_intra = torch.zeros( - batch_size, - (max_seq_len_intra - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - st = chunk_num_curr[i] * chunk_len // self.block_size - ed = min( - st + (max_seq_len_intra - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_intra[i, :ed - - st] = decode_metadata.block_tables[i, st:ed] - decode_metadata.block_tables_intra = block_tables_intra - - seq_lens_succ = (chunk_num_curr - - (chunk_num_curr - 1).clip(min=0)) * chunk_len - max_seq_len_succ = seq_lens_succ.max().item() - decode_metadata.seq_lens_succ = seq_lens_succ - decode_metadata.max_seq_len_succ = max_seq_len_succ - if max_seq_len_succ: - block_tables_succ = torch.zeros( - batch_size, - (max_seq_len_succ - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - start = ((chunk_num_curr[i] - 
1).clip(min=0) * chunk_len // - self.block_size) - end = min( - start + (max_seq_len_succ - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_succ[ - i, :end - start] = decode_metadata.block_tables[i, - start:end] - decode_metadata.block_tables_succ = block_tables_succ - - seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len - max_seq_len_inter = seq_lens_inter.max().item() - decode_metadata.seq_lens_inter = seq_lens_inter - decode_metadata.max_seq_len_inter = max_seq_len_inter - - self._cached_decode_metadata = decode_metadata - return decode_metadata - - -class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): - - def prepare(self): - super().prepare() - self.orig_seq_lens: List[int] = [] - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - super()._add_seq_group(inter_data, chunked_prefill_enabled, - prefix_cache_hit) - for prompt_len, seq_len in zip(inter_data.prompt_lens, - inter_data.seq_lens): - self.orig_seq_lens.append(max(prompt_len, seq_len)) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - attn_metadata = super().build(seq_lens, query_lens, - cuda_graph_pad_size, batch_size) - attn_metadata = DualChunkFlashAttentionMetadata( - **attn_metadata.asdict_zerocopy()) - - device = self.runner.device - attn_metadata.orig_seq_lens = self.orig_seq_lens - attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( - self.orig_seq_lens, torch.int, device, self.runner.pin_memory) - - attn_metadata.block_size = self.runner.block_size - dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, - "dual_chunk_attention_config", {}) - attn_metadata.original_max_position_embeddings = \ - dual_chunk_attn_config.get("original_max_position_embeddings", 0) - attn_metadata.chunk_size = dual_chunk_attn_config.get( - "chunk_size", 8192) - attn_metadata.local_size = dual_chunk_attn_config.get( - "local_size", 1024) - - return attn_metadata - - -class DualChunkFlashAttentionImpl(FlashAttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - The prompts might have different lengths, while the generation tokens - always have length 1. - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. 
- """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - layer_idx: int = -1, - dual_chunk_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "DUAL_CHUNK_FLASH_ATTN backend.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - - support_head_sizes = ( - DualChunkFlashAttentionBackend.get_supported_head_sizes()) - - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - - assert dual_chunk_attention_config is not None - self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) - self.local_size = dual_chunk_attention_config.get("local_size", 1024) - self.original_max_position_embeddings = dual_chunk_attention_config.get( - "original_max_position_embeddings", 0) - self.sparse_attention_config = dual_chunk_attention_config.get( - "sparse_attention_config", None) - if not self.sparse_attention_config: - logger.warning_once("Sparse attention will not be enabled as " - "sparse attention config is not provided.") - self.sparse_attention_enabled = dual_chunk_attention_config.get( - "sparse_attention_enabled", self.sparse_attention_config - is not None) - self.sparse_attention_threshold = dual_chunk_attention_config.get( - "sparse_attention_threshold", 32768) - self.sparse_attention_last_q = dual_chunk_attention_config.get( - "sparse_attention_last_q", 64) - self.layer_idx = layer_idx - self.dual_chunk_attention_config = dual_chunk_attention_config - - if self.sparse_attention_config: - self.sparse_attention_config = { - int(i): j - for i, j in self.sparse_attention_config[ - self.layer_idx].items() - } - start_head = self.num_heads * get_tensor_model_parallel_rank() - end_head = start_head + self.num_heads - self.sparse_attention_config = [ - self.sparse_attention_config[i] - for i in range(start_head, end_head) - ] - - if self.sparse_attention_enabled: - self.arange = torch.arange(self.sparse_attention_last_q, - device="cuda") - self.last_q_mask = (self.arange[None, None, :, None] - >= self.arange[None, None, None, :]) - - def forward( # type: ignore - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DualChunkFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with DualChunkFlashAttention. 
- Args: - query: shape = [num_tokens, num_heads * head_size] - query_succ: shape = [num_tokens, num_heads * head_size] - query_inter: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is None, "Output tensor not supported for DualChunk" - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - ( - query, - query_succ, - query_inter, - query_succ_critical, - query_inter_critical, - ) = torch.split(query, query.shape[-1] // 5, dim=-1) - - assert ( - query_succ is not None and query_inter is not None - ), "query_succ and query_inter are required in Dual Chunk Attention." - - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - query_succ = query_succ.view(-1, self.num_heads, self.head_size) - query_inter = query_inter.view(-1, self.num_heads, self.head_size) - query_succ_critical = query_succ_critical.view(-1, self.num_heads, - self.head_size) - query_inter_critical = query_inter_critical.view( - -1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.original_max_position_embeddings > 0: - if prefill_meta := attn_metadata.prefill_metadata: - assert prefill_meta.scaling_factor is not None - assert prefill_meta.query_start_loc is not None - assert prefill_meta.orig_seq_lens is not None - current_start = 0 - query_start_loc_cpu = prefill_meta.query_start_loc.cpu() - for i in range(len(prefill_meta.orig_seq_lens)): - current_end = (current_start + - (query_start_loc_cpu[i + 1] - - query_start_loc_cpu[i]).item()) - key[current_start:current_end].mul_( - prefill_meta.scaling_factor[i]) - current_start = current_end - assert current_end <= attn_metadata.num_prefill_tokens - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - key[attn_metadata.num_prefill_tokens:].mul_( - scaling_factor.unsqueeze(-1).unsqueeze(-1)) - - if kv_cache is not None and kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - output = torch.empty_like(query) - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - decode_query_succ = query_succ[num_prefill_tokens:] - decode_query_inter = query_inter[num_prefill_tokens:] - - # QKV for prefill. 
- query = query[:num_prefill_tokens] - query_succ = query_succ[:num_prefill_tokens] - query_inter = query_inter[:num_prefill_tokens] - query_succ_critical = query_succ_critical[:num_prefill_tokens] - query_inter_critical = query_inter_critical[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention, called during the profiling run. - out = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - assert prefill_meta.orig_seq_lens is not None - output[:num_prefill_tokens] = ( - self._dual_chunk_flash_attn_prefill( - q=query, - q_succ=query_succ, - q_inter=query_inter, - q_succ_critical=query_succ_critical, - q_inter_critical=query_inter_critical, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - orig_seq_lens=prefill_meta.orig_seq_lens, - scaling_factor=prefill_meta.scaling_factor, - softmax_scale=self.scale, - causal=True, - window_size=(-1, -1), - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - chunk_size=self.chunk_size, - local_size=self.local_size, - )) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = ( - self._dual_chunk_flash_attn_decoding( - decode_query.unsqueeze(1), - decode_query_succ.unsqueeze(1), - decode_query_inter.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - chunk_size=self.chunk_size, - local_size=self.local_size, - original_max_position_embeddings=self. - original_max_position_embeddings, - decode_meta=decode_meta, - ).squeeze(1)) - # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) - - def _dual_chunk_flash_attn_prefill( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - orig_seq_lens: List[int], - scaling_factor: torch.Tensor, - softmax_scale: float, - causal: Optional[bool] = True, - window_size: Tuple[int, int] = (-1, -1), - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - chunk_size: int = 8192, - local_size: int = 1024, - ): - if alibi_slopes is not None: - raise ValueError( - "Dual Chunk Attention does not support alibi_slopes") - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - if window_size != (-1, -1): - raise ValueError( - "Dual Chunk Attention does not support window_size") - - cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() - cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() - all_outputs = [] - - for i in range(0, len(cu_seqlens_q_cpu) - 1): - qs = cu_seqlens_q_cpu[i] - qe = cu_seqlens_q_cpu[i:i + 2][-1] - ks = cu_seqlens_k_cpu[i] - ke = cu_seqlens_k_cpu[i:i + 2][-1] - - current_q = q[qs:qe] - current_q_succ = q_succ[qs:qe] - current_q_inter = q_inter[qs:qe] - current_q_succ_critical = q_succ_critical[qs:qe] - current_q_inter_critical = q_inter_critical[qs:qe] - - if block_table is None: - current_k = k[ks:ke] - current_v = v[ks:ke] - current_block_table = None - current_orig_seq_len = orig_seq_lens[i] - else: - current_block_table = block_table[i] - current_orig_seq_len = orig_seq_lens[i] - current_k = k - current_v = v - sparse_attn_enabled = (self.sparse_attention_enabled - and current_orig_seq_len - > self.sparse_attention_threshold) - - if current_q.shape[0] == 0: - continue - - if current_k.shape[0] == 0: - all_outputs.append( - torch.zeros( - (current_q.shape[0], current_q.shape[1], v.shape[2]), - device=q.device, - dtype=q.dtype, - )) - continue - - current_output = torch.empty_like(current_q) - group_size = int(current_q.size(-2) / current_k.size(-2)) - - if sparse_attn_enabled: - num_device_q_heads = current_q.size(-2) - heads_vertical_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - heads_slash_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - for head_id in range(current_q.size(-2)): - ( - ty, - vertical_size, - slash_size, - _, - ) = self.sparse_attention_config[head_id] - assert ty == "vertical_and_slash", "only support slash mode" - - if vertical_size == 30: - vertical_size += 100 - heads_vertical_size[head_id] = vertical_size - heads_slash_size[head_id] = slash_size - - current_output = self._dual_chunk_flash_attn_prefill_func( - current_q, # allheads - current_q_succ, - current_q_inter, - current_q_succ_critical, - current_q_inter_critical, - current_k, - current_v, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - heads_vertical_size=heads_vertical_size, - heads_slash_size=heads_slash_size, - group_size=group_size) - else: - for head_id in range(current_q.size(-2)): - # (seq_len, num_heads, head_size) - current_q_head = current_q[:, head_id, :].unsqueeze(1) - current_q_succ_head = \ - current_q_succ[:, head_id, :].unsqueeze(1) - current_q_inter_head = \ - current_q_inter[:, head_id, :].unsqueeze(1) - current_q_succ_head_critical = \ - current_q_succ_critical[:, head_id, :].unsqueeze(1) - current_q_inter_head_critical = \ - current_q_inter_critical[:, head_id, :].unsqueeze(1) - if block_table is not 
None: - current_k_head = current_k[..., head_id // - group_size, :].unsqueeze(2) - current_v_head = current_v[..., head_id // - group_size, :].unsqueeze(2) - - else: - current_k_head = current_k[:, head_id, :].unsqueeze(1) - current_v_head = current_v[:, head_id, :].unsqueeze(1) - - current_out = self._dual_chunk_flash_attn_prefill_func( - current_q_head, - current_q_succ_head, - current_q_inter_head, - current_q_succ_head_critical, - current_q_inter_head_critical, - current_k_head, - current_v_head, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - ) - current_output[:, head_id:head_id + 1, :] = current_out - all_outputs.append(current_output) - return torch.cat(all_outputs, dim=0) - - def _dual_chunk_flash_attn_prefill_func( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - block_table, - softmax_scale: float, - chunk_size: int, - local_size: int, - scaling_factor: float, - k_length: int, - sparse_attn_enabled: Optional[bool] = True, - heads_vertical_size=None, - heads_slash_size=None, - group_size=None, - ): - flash_results = [] - chunk_len = chunk_size - local_size - - if block_table is not None: - block_size = v.shape[1] - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - else: - block_size = 1 - - if self.original_max_position_embeddings > 0: - softmax_scale = softmax_scale * scaling_factor - - begin = k_length - q.shape[0] - while begin < k_length: - flash_per_chunk = [] - - prev_chunk_end_pos = (begin // chunk_len) * chunk_len - next_chunk_end_pos = prev_chunk_end_pos + chunk_len - end = min(next_chunk_end_pos, k_length) - qbegin = begin - (k_length - q.shape[0]) - qend = end - (k_length - q.shape[0]) - - qk_chunks = [] - q_states_intra = q[qbegin:qend] - # choose critical token - if block_table is not None: - block_tables_intra = _get_block(block_table, block_size, - prev_chunk_end_pos, end) - k_states_intra = k[block_tables_intra].view( - -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] - v_states_intra = v[block_tables_intra].view( - -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] - else: - block_tables_intra = None - k_states_intra = k[prev_chunk_end_pos:end] - v_states_intra = v[prev_chunk_end_pos:end] - - if sparse_attn_enabled: - last_q_size = min(qend - qbegin, self.sparse_attention_last_q) - _, num_device_k_heads, head_dim = k_states_intra.shape - k_states_intra = (k_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - v_states_intra = (v_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - qk_chunks.append( - (q_states_intra.transpose(0, 1)[:, -last_q_size:] * - softmax_scale) @ k_states_intra.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len >= 0: - q_states_succ = q_succ[qbegin:qend] - q_states_succ_critical = q_succ_critical[qbegin:qend] - if block_table is not None: - block_tables_succ = _get_block( - block_table, block_size, - prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) - k_states_succ = k[block_tables_succ].view( - -1, *k.shape[-2:])[:chunk_len] - v_states_succ = v[block_tables_succ].view( - -1, *v.shape[-2:])[:chunk_len] - else: - k_states_succ = k[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - v_states_succ = v[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - - if sparse_attn_enabled: - k_states_succ = 
(k_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_succ = (v_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_succ_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_succ.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - q_states_inter = q_inter[qbegin:qend] - q_states_inter_critical = q_inter_critical[qbegin:qend] - if block_table is not None: - block_tables_inter = _get_block( - block_table, block_size, 0, - prev_chunk_end_pos - chunk_len) - k_states_inter = k[block_tables_inter].view( - -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - v_states_inter = v[block_tables_inter].view( - -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - else: - k_states_inter = k[:prev_chunk_end_pos - chunk_len] - v_states_inter = v[:prev_chunk_end_pos - chunk_len] - - if sparse_attn_enabled: - k_states_inter = (k_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_inter = (v_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_inter_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_inter.permute(1, 2, 0)) - - if sparse_attn_enabled: - reversed_qk = qk_chunks[::-1] - qk = torch.cat(reversed_qk, dim=-1) - - qk[:, :, -last_q_size:] = torch.where( - self.last_q_mask[..., -last_q_size:, - -last_q_size:].to(qk.device), - qk[:, :, -last_q_size:], -torch.inf) - qk = F.softmax(qk, dim=-1, dtype=torch.float32) - - vertical = qk.sum(-2, keepdim=True) - vertical[..., :30] = torch.inf - - # Avoid sorting by using the min/max ints to fill the indexer - # buffers. 
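# The vertical index buffers are padded with int32_max and the slash index
# buffers with int32_min: _vertical_slash_sparse_attention later sorts v_idx
# ascending and s_idx descending, so the sentinel padding falls to the tail
# and the per-head *_size_buffer counts below mark how many leading entries
# are valid.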
- int32_max = torch.iinfo(torch.int32).max - int32_min = torch.iinfo(torch.int32).min - n_heads = qk.size()[0] - max_slash_topk = torch.max(heads_slash_size).item() - max_vertical_topk = torch.max(heads_vertical_size).item() - # store each head's slash topk, vertical topk - vertical = vertical.reshape((n_heads, -1)) - # prevent out of range when prompt size < max_vertical_topk - max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) - vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, - -1).indices - slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), - dtype=torch.int64, - device=qk.device) - for head_i in range(n_heads): - # (nqheads=1, lastq, k_len) - head_score = qk[head_i:head_i + 1, :, :] - slash_scores = _sum_all_diagonal_matrix(head_score) - if head_score.size(1) != 1: - # drop right up corner - slash_scores = slash_scores[..., :-last_q_size + 1] - slash_scores[..., -100:] = torch.inf - - head_slash_size = heads_slash_size[head_i] - head_slash_size = min(head_slash_size, vertical.size(-1)) - slash_topk = torch.topk(slash_scores, head_slash_size, - -1).indices - #(nheads, max_topk) - slash_topk_buffer[head_i, :head_slash_size] = slash_topk - - # reset heads topk - heads_slash_size[head_i] = head_slash_size - heads_vertical_size[head_i] = min( - heads_vertical_size[head_i], max_vertical_topk) - - # store - vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - succ_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - inter_vertical_buffer = torch.full( - (n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - inter_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - - vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - - for head_i in range(n_heads): - vertical_topk = vertical_topk_buffer[ - head_i, :heads_vertical_size[head_i]] - # intra - intra_vertical_indices = vertical_topk[ - vertical_topk >= - prev_chunk_end_pos] - prev_chunk_end_pos - if intra_vertical_indices.nelement() == 0: - intra_vertical_indices = torch.cat([ - intra_vertical_indices, - torch.arange(0, - k_states_intra.size(0), - max(1, - k_states_intra.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - slash_topk = slash_topk_buffer[ - head_i, :heads_slash_size[head_i]] - intra_slash_indices = ( - (qk.size(-1) - 1) - - slash_topk[slash_topk >= prev_chunk_end_pos]) - # fill buffer - v_count = intra_vertical_indices.nelement() - s_count = intra_slash_indices.nelement() - vertical_size_buffer[head_i] = v_count - slash_sizes_buffer[head_i] = s_count - vertical_buffer[head_i, :v_count].copy_( - intra_vertical_indices) - 
slash_buffer[head_i, :s_count].copy_(intra_slash_indices) - # succ - if prev_chunk_end_pos - chunk_len >= 0: - succ_vertical_indices = vertical_topk[ - (vertical_topk < prev_chunk_end_pos) - & (vertical_topk >= prev_chunk_end_pos - - chunk_len)] - (prev_chunk_end_pos - chunk_len) - # TODO: support no vertical - if succ_vertical_indices.nelement() == 0: - succ_vertical_indices = torch.cat([ - succ_vertical_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - succ_slash_indices = ( - (prev_chunk_end_pos + (qend - qbegin) - 1) - - slash_topk[((slash_topk >= - (prev_chunk_end_pos - chunk_len)) & - (slash_topk < (prev_chunk_end_pos + - (qend - qbegin))))]) - if succ_slash_indices.nelement() == 0: - succ_slash_indices = torch.cat([ - succ_slash_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = succ_vertical_indices.nelement() - s_count = succ_slash_indices.nelement() - succ_vertical_size_buffer[head_i] = v_count - succ_slash_sizes_buffer[head_i] = s_count - succ_vertical_buffer[head_i, :v_count].copy_( - succ_vertical_indices) - succ_slash_buffer[head_i, :s_count].copy_( - succ_slash_indices) - - if prev_chunk_end_pos - 2 * chunk_len >= 0: - inter_vertical_indices = vertical_topk[ - vertical_topk < prev_chunk_end_pos - chunk_len] - - if inter_vertical_indices.nelement() == 0: - inter_vertical_indices = torch.cat([ - inter_vertical_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - inter_slash_indices = ( - (prev_chunk_end_pos - chunk_len + - (qend - qbegin) - 1) - - slash_topk[slash_topk < (prev_chunk_end_pos - - chunk_len + - (qend - qbegin))]) - if inter_slash_indices.nelement() == 0: - inter_slash_indices = torch.cat([ - inter_slash_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = inter_vertical_indices.nelement() - s_count = inter_slash_indices.nelement() - inter_vertical_size_buffer[head_i] = v_count - inter_slash_sizes_buffer[head_i] = s_count - inter_vertical_buffer[head_i, :v_count].copy_( - inter_vertical_indices) - inter_slash_buffer[head_i, :s_count].copy_( - inter_slash_indices) - else: - intra_vertical_indices, intra_slash_indices = None, None - succ_vertical_indices, succ_slash_indices = None, None - inter_vertical_indices, inter_slash_indices = None, None - - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=vertical_buffer, - slash_indices=slash_buffer, - vertical_indices_count=vertical_size_buffer, - slash_indices_count=slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=intra_vertical_indices, - slash_indices=intra_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len >= 0: - if sparse_attn_enabled: - flash_result = 
self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_buffer, - slash_indices=succ_slash_buffer, - vertical_indices_count=succ_vertical_size_buffer, - slash_indices_count=succ_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_indices, - slash_indices=succ_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_buffer, - slash_indices=inter_slash_buffer, - vertical_indices_count=inter_vertical_size_buffer, - slash_indices_count=inter_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_indices, - slash_indices=inter_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - flash_results.append(flash_per_chunk) - begin = end - - attn_output = self._merge_attn_outputs(flash_results) - del flash_results - return attn_output - - def _do_flash_attn( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - softmax_scale: float, - causal: bool = True, - max_seqlen_k: Optional[int] = None, - stage: str = "intra", - vertical_indices: Optional[torch.Tensor] = None, - slash_indices: Optional[torch.Tensor] = None, - vertical_indices_count: Optional[torch.Tensor] = None, - slash_indices_count: Optional[torch.Tensor] = None, - mergehead_softmax_scale: Optional[float] = None, - sparse_attn_enabled: Optional[bool] = False, - ): - if max_seqlen_k is None: - max_seqlen_k = key_states.shape[0] - - q_len = query_states.shape[0] - q_heads = query_states.shape[1] - h_dim = query_states.shape[-1] - - if sparse_attn_enabled: - assert slash_indices is not None - if stage == "intra": - assert causal - else: - assert not causal - - query_states = query_states.unsqueeze(0).transpose(1, 2) - key_states = key_states.unsqueeze(0).transpose(1, 2) - value_states = value_states.unsqueeze(0).transpose(1, 2) - - q = query_states - k = key_states - v = value_states - - if (vertical_indices_count is not None and \ - slash_indices_count is not None): - assert mergehead_softmax_scale is not None - - res, s_lse = _vertical_slash_sparse_attention( - q, - k, - v, - vertical_indices, - slash_indices, - mergehead_softmax_scale, - causal=causal, - stage=stage, - vertical_indices_count=vertical_indices_count, - slash_indices_count=slash_indices_count) - res = res.view(q_heads, q_len, - h_dim).transpose(0, 1) # (qlen,nhead,h_dim) - s_lse = s_lse.view( - q_heads, q_len, - 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) - else: - res, s_lse = _vertical_slash_sparse_attention(q, - k, - v, - vertical_indices, - slash_indices, - softmax_scale, - causal=causal, - stage=stage) - res = res.view(q_len, q_heads, h_dim) - s_lse = s_lse.view(q_len, q_heads, 
1).transpose(0, 2).float() - return res, s_lse - - output, softmax_lse = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - softmax_scale=softmax_scale, - cu_seqlens_q=torch.tensor([0, query_states.shape[0]], - dtype=torch.int32, - device=query_states.device), - max_seqlen_q=query_states.shape[0], - cu_seqlens_k=torch.tensor([0, max_seqlen_k], - dtype=torch.int32, - device=query_states.device), - max_seqlen_k=max_seqlen_k, - causal=causal, - return_softmax_lse=True, - ) - softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, - 2).float() - return output, softmax_lse - - def _merge_attn_outputs( - self, - flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], - return_lse: Optional[bool] = False, - ) -> torch.Tensor: - attn_outputs_all = [] - logits_all = [] - - for flash_per_chunk in flash_results: - if len(flash_per_chunk) == 1: - attn_outputs_all.append(flash_per_chunk[0][0]) - if return_lse: - logits_all.append(flash_per_chunk[0][1]) - continue - - attn_outputs = torch.stack([ - flash_attn_output[0] for flash_attn_output in flash_per_chunk - ]) - logits = torch.stack([ - flash_attn_output[1] for flash_attn_output in flash_per_chunk - ]) - logits = logits.to(torch.float32) - - if return_lse: - max_val = torch.max(logits, dim=0).values - diff = torch.abs(logits[0] - logits[1]) - log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) - logits_all.append(log_sum_exp) - - max_logits = torch.max(logits, dim=0).values - stable_logits = logits - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) - attn_outputs_all.append(attn_outputs.sum(dim=0)) - - if return_lse: - return (torch.cat(attn_outputs_all, - dim=0), torch.cat(logits_all, dim=-1)) - else: - return torch.cat(attn_outputs_all, dim=0) - - def _dual_chunk_flash_attn_decoding( - self, - query: torch.Tensor, - query_succ: torch.Tensor, - query_inter: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - causal: bool, - alibi_slopes: Optional[torch.Tensor], - chunk_size: int, - local_size: int, - original_max_position_embeddings: int, - decode_meta: DualChunkFlashAttentionMetadata, - ): - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - - block_size = value_cache.shape[1] - chunk_len = chunk_size - local_size - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - if original_max_position_embeddings > 0: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - query = (query * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype - ) # possible for numerical issue, need to fused in the kernel - query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - outputs_list = [] - softmax_lses_list = [] - - # intra-attention - intra_output, intra_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query, - key_cache, - value_cache, - decode_meta.block_tables_intra, - decode_meta.seq_lens_intra, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(intra_output) - softmax_lses_list.append(intra_softmax_lse) - - # succ-attention - if decode_meta.max_seq_len_succ: - succ_output, succ_softmax_lse = ( - 
self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_succ, - key_cache, - value_cache, - decode_meta.block_tables_succ, - decode_meta.seq_lens_succ, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(succ_output) - softmax_lses_list.append(succ_softmax_lse) - - # inter-attention - if decode_meta.max_seq_len_inter: - inter_output, inter_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_inter, - key_cache, - value_cache, - block_table[:, :decode_meta.max_seq_len_inter], - decode_meta.seq_lens_inter, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(inter_output) - softmax_lses_list.append(inter_softmax_lse) - outputs = torch.stack(outputs_list, dim=0) - del outputs_list - softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) - del softmax_lses_list - max_logits = torch.max(softmax_lses, dim=0).values - stable_logits = softmax_lses - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - outputs *= lse_s.unsqueeze(-1).transpose(2, 3) - return outputs.sum(0) - - def _dual_chunk_flash_attn_decoding_with_exp_sums( - self, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - ): - out, softmax_lse = flash_attn_with_kvcache( - q=query, - k_cache=key_cache, - v_cache=value_cache, - block_table=block_table, - cache_seqlens=cache_seqlens, - softmax_scale=softmax_scale, - alibi_slopes=alibi_slopes, - causal=causal, - return_softmax_lse=True, - ) - mask = (cache_seqlens == 0) - out[mask] = 0 - softmax_lse[mask] = -float("inf") - return out, softmax_lse - - -def _vertical_slash_sparse_attention( - query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] - key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] - s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] - softmax_scale: float, - causal: bool = True, - stage: str = "intra", - block_size_M: int = 64, - block_size_N: int = 64, - vertical_indices_count: torch.Tensor = None, # [N_HEADS,] - slash_indices_count: torch.Tensor = None, -): - if stage == "intra": - assert causal - else: - assert not causal - - batch_size, num_heads, context_size, head_dim = query.shape - _, _, kv_seq_len, _ = key.shape - - if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim - query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) - key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) - value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) - - v_idx = v_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] - q_seqlens = torch.tensor([context_size], - dtype=torch.int32, - device=query.device) - kv_seqlens = torch.tensor([kv_seq_len], - dtype=torch.int32, - device=query.device) - - if vertical_indices_count is not None and slash_indices_count is not None: - ( - block_count, - block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes_mergehead( - q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, - slash_indices_count, context_size, block_size_M, block_size_N, - causal) - else: - ( - block_count, - 
block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, - s_idx, context_size, - block_size_M, block_size_N, - causal) - - q = query.transpose(1, 2).contiguous() - k = key.transpose(1, 2).contiguous() - v = value.transpose(1, 2).contiguous() - out, lse = sparse_attn_func( - q, - k, - v, - block_count, - block_offset, - column_count, - column_index, - causal=causal, - softmax_scale=softmax_scale, - return_softmax_lse=True, - ) - out = out.transpose(1, 2).contiguous() - softmax_lse = lse.reshape(*lse.shape, 1) - return (out[..., :context_size, :head_dim], - softmax_lse[..., :context_size, :]) - - -def _sum_all_diagonal_matrix(mat: torch.tensor): - h, n, m = mat.shape - # Zero matrix used for padding - zero_mat = torch.zeros((h, n, n), device=mat.device) - # pads the matrix on left and right - mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) - # Change the strides - mat_strided = mat_padded.as_strided((1, n, n + m), - (n * (2 * n + m), 2 * n + m + 1, 1)) - # Sums the resulting matrix's columns - sum_diags = torch.sum(mat_strided, 1) - return sum_diags[:, 1:] # drop left bottom corner - - -def _get_block(block_table: torch.Tensor, block_size: int, begin: int, - end: int): - begin_block = begin // block_size - end_block = (end - 1) // block_size + 1 - return block_table[begin_block:end_block] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py deleted file mode 100755 index ba7a9afe86782..0000000000000 --- a/vllm/attention/backends/flash_attn.py +++ /dev/null @@ -1,932 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type - -import torch - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -# yapf: enable -from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, - get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - -logger = init_logger(__name__) - - -class FlashAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_name() -> str: - return "FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: - return FlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return FlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> 
Type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class FlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). 
The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = FlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = FlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. 
- """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return FlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class FlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. 
- - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "FLASH_ATTN backend.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. 
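# k/v scales other than 1.0 are only usable when FlashAttention 3 with FP8
# support is available and the output dtype is bfloat16; on any other
# configuration the assertion below requires unit scales.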
- if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: - assert ( - layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( - "key/v_scale is only supported in FlashAttention 3 with " - "base dtype bfloat16") - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes - logits_soft_cap: Optional[float] = self.logits_soft_cap - fp8_attention = kv_cache_dtype.startswith("fp8") - - if fp8_attention and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support FP8 kv-cache on this device.") - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the - # KV cache is already populated with the cross-attention - # tensor. Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if fp8_attention: - kv_cache = kv_cache.view(torch.float8_e4m3fn) - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) - - if fp8_attention: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - decode_output = output[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - prefill_output = output[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. 
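# Two prefill paths follow: plain varlen attention over the freshly computed
# key/value tensors when no block tables are available (e.g. the profiling
# run), and paged attention against the KV cache when prefix caching or
# chunked prefill supplies block tables.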
- if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - if fp8_attention: - num_kv_tokens, num_kv_heads, head_size = key.shape - - key, _ = ops.scaled_fp8_quant( - key.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._k_scale) - key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) - - value, _ = ops.scaled_fp8_quant( - value.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._v_scale) - value = value.reshape( - (num_kv_tokens, num_kv_heads, head_size)) - - descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) - flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - assert prefill_meta.query_start_loc is not None - max_seq_len = max(prefill_meta.seq_lens) - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens_tensor, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. 
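# When max_decode_query_len == 1 (ordinary one-token-per-sequence decoding),
# the else branch below uses the cheaper flash_attn_with_kvcache path instead.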
- - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1" - ) - assert decode_meta.query_start_loc is not None - descale_shape = (decode_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - seqused_k=decode_meta.seq_lens_tensor, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - out=decode_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) - flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=decode_output.unsqueeze(1), - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - return output - - -def _get_query_key_seq_metadata( - attn_metadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - """ - Returns sequence metadata for key and query based on the specified - attention type and whether input is a prompt. - - This function computes the starting locations and maximum sequence lengths - for key and query sequences for different attention types. - - Args: - attn_metadata: The attention metadata object - is_prompt (bool): A flag indicating if the input is a prompt - attn_type (AttentionType): The type of attention being used. - - Returns: - tuple: A tuple containing four integers: - - Starting location for the query sequence. - - Maximum sequence length for the query sequence. - - Starting location for the key sequence. - - Maximum sequence length for the key sequence. - - Raises: - AttributeError: If an invalid attention type is provided. - """ - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.seq_start_loc, max_seq_len) - - elif attn_type == AttentionType.ENCODER_DECODER: - # This is cross attention between the where the key - # is the precomputed encoder attention and query - # is the input sequence. - # Choose query max length based on whether it is prompt - # or not. 
- if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e the - # encoder sequence. - return (attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, - attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _get_causal_option(attn_type: str) -> bool: - """ - Determine whether the given attention type is suitable for causal - attention mechanisms. - - Args: - attn_type (AttentionType): The type of attention being evaluated - - Returns: - bool: Returns `True` if the attention type is suitable for causal - attention (i.e., not encoder, encoder-only, or encoder-decoder), - otherwise returns `False`. - """ - return not (attn_type == AttentionType.ENCODER - or attn_type == AttentionType.ENCODER_ONLY - or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py deleted file mode 100644 index f23c096952ce0..0000000000000 --- a/vllm/attention/backends/flashmla.py +++ /dev/null @@ -1,227 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) - - -class FlashMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "FLASHMLA" - - @staticmethod - def get_impl_cls() -> Type["FlashMLAImpl"]: - return FlashMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["FlashMLAMetadata"]: - return FlashMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: - return FlashMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashMLAState"]: - return FlashMLAState - - -@dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - # TODO: cache assignment? 
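# Attach the FlashMLA-specific tile-scheduler metadata and split counts
# (computed in the metadata builder) to the decode-only metadata view
# returned by the MLACommon base class.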
- if decode_metadata is not None: - decode_metadata.decode_tile_scheduler_metadata=\ - self.decode_tile_scheduler_metadata - decode_metadata.decode_num_splits=\ - self.decode_num_splits - return decode_metadata - - -class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - m = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - - if m.num_decode_tokens > 0: - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ - get_mla_metadata( - m.seq_lens_tensor[m.num_prefills:], - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m - - -class FlashMLAState(MLACommonState[FlashMLAMetadata]): - - def __init__(self, *args, **kwds): - super().__init__(*args, **kwds) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - @contextmanager - def graph_capture(self, max_batch_size: int): - # Run a dummy `get_mla_metadata` so we can get the right shapes - self._graph_decoder_tile_scheduler_metadata, \ - self._graph_decode_num_splits = get_mla_metadata( - torch.ones( - max_batch_size, dtype=torch.int32, device=self.runner.device), - self.num_q_heads, - 1, # MQA for the decode path - ) - - with super().graph_capture(max_batch_size): - yield - - del self._graph_decoder_tile_scheduler_metadata - del self._graph_decode_num_splits - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - assert metadata.num_decode_tokens > 0 - - decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( - self._graph_seq_lens[:batch_size], - self.num_q_heads, - 1, # MQA for the decode path - ) - - self._graph_decoder_tile_scheduler_metadata.copy_( - decoder_tile_scheduler_metadata) - self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) - - metadata.decode_tile_scheduler_metadata=\ - self._graph_decoder_tile_scheduler_metadata - metadata.decode_num_splits=\ - self._graph_decode_num_splits[:batch_size + 1] - - return metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers["decode_tile_scheduler_metadata"] = \ - attn_metadata.decode_metadata.decode_tile_scheduler_metadata - input_buffers["decode_num_splits"] = \ - attn_metadata.decode_metadata.decode_num_splits - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - input_buffers["decode_tile_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_tile_scheduler_metadata) - input_buffers["decode_num_splits"].copy_( - attn_metadata.decode_metadata.decode_num_splits) - - -class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - 
attn_type: str, - kv_sharing_target_layer_name: Optional[str] = None, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - assert is_flashmla_supported(), \ - "FlashMLA is not supported on this device" - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) - - o, _ = flash_mla_with_kvcache( - q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, - num_splits=decode_meta.decode_num_splits, - softmax_scale=self.scale, - causal=True, - ) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py deleted file mode 100644 index c5ed4c6e40326..0000000000000 --- a/vllm/attention/backends/mla/common.py +++ /dev/null @@ -1,1310 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -# MLA Common Components - -This file implements common components for MLA implementations. - -First we define: - -Sq as Q sequence length -Skv as KV sequence length - -MLA has two possible ways of computing, a data-movement friendly approach and a -compute friendly approach, we generally want to use the compute friendly -approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) -and the data-movement friendly approach for "decode" (i.e. the ratio -Sq / Skv is "large"). - -NOTE what we deem small and large is currently determined by if its labelled -prefill or decode by the scheduler, but this is something we should probably -tune. - -Main reference: DeepseekV2 paper, and FlashInfer Implementation -(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - -Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. -* For decode (i.e. the memory friendly approach) the attention "simulates" a -multi-head attention, while the compute is similar to multi-query attention. - -Below is example of both paths assuming batchsize = 1 - -## More Extent Definitions: - -C Context length, `Skv - Sq` -H hidden size -N number of attention heads -Lq latent dimension for Q 1536 in DSV3 -Lkv latent dimension for K/V 512 in DSV3 -P nope dimension, no rope. 128 in DSV3 -R rope dimension, goes through rope. 64 in DSV3 -V V head dim. 
128 in DSV3 - -## Vector/Matrix Definitions - -h_t hidden states (input to attention) shape [Sq, H] -q_c latent/compressed Q shape [Sq, Lq] -q_nope uncompressed Q (no-rope) shape [Sq, N, P] -q_pe uncompressed Q (rope) shape [Sq, N, R] -kv_c latent/compressed KV shape [Skv, Lkv] -k_pe decoupled k position embeddings shape [Skv, R] -new_kv_c new kv_c from current iter shape [Sq, Lkv] -new_k_pe new k_pe from current iter shape [Sq, R] -cache_kv_c cached k_c from previous iters shape [C, Lkv] -cache_k_pe cached k_pe from previous iters shape [C, R] -W_DQ project h_t to q_c shape [H, Lq] -W_UQ project q_c to q_nope shape [Lq, N * P] -W_QR project q_c to q_pe shape [Lq, N * R] -W_DKV project h_t to kv_c shape [H, Lkv] -W_UK project kv_c to k_nope shape [Lkv, N, P] -W_KR project h_t to k_pe shape [H, R] -W_UV project kv_c to v shape [Lkv, N, V] -W_O project v to h_t shape [N * V, H] - - -## Compute Friendly Approach (i.e. "_forward_prefill"): - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) -k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) -v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) - -// MHA with QK headdim = P + R -// V headdim = V -// spda_o shape [Sq, N, V] -spda_o = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - v -) -return spda_o @ W_O - -NOTE: in the actual code, - `kv_b_proj` is [W_UK; W_UV] concatenated per head - `q_b_proj` is [W_UQ; W_QR] concatenated per head - `out_proj` is W_O - - -## Data-Movement Friendly Approach (i.e. "_forward_decode"): - -Runtime -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(-1, N, P) -ql_nope = einsum("snh,lnh->snl", q, W_UK) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) - -// MQA with QK headdim = Lkv + R -// V headdim = Lkv -// spda_o shape [Sq, N, Lkv] -// NOTE: this is less compute-friendly since Lkv > P -// but is more data-movement friendly since its MQA vs MHA -spda_o = scaled_dot_product_attention( - torch.cat([ql_nope, q_pe], dim=-1), - torch.cat([kv_c, k_pe], dim=-1), - kv_c -) - -o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) -return o.view(-1, N * V) @ self.num_heads @ W_O - - -## Chunked Prefill - -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to -the data-movement friendly approach if the chunk (i.e. `Sq`) is small. - -However, the compute-friendly approach can potentially run out of memory if Skv -is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` - -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a -fixed workspace size. 
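A per-token KV-cache entry therefore stores the latent kv_c (Lkv = 512 values in DSV3) concatenated with k_pe (R = 64 values), which is where the 576 head size used by the MLA cache layout comes from. Two of the claims in this docstring, that absorbing W_UK into the query turns the MHA formulation into an MQA over the latent cache, and that the chunked-prefill loop shown next can merge per-chunk attention results through their log-sum-exps, can be sanity-checked with a small self-contained sketch. The shapes below are toy values rather than the DSV3 ones, and the helpers `attn` and `merge` are ad-hoc references for the math only, not the removed implementation (the real merge is the fused merge_attn_states op used in the loop that follows).

    import torch

    torch.manual_seed(0)
    Sq, Skv, N, P, Lkv, V = 3, 7, 2, 16, 32, 8   # toy sizes, not DSV3

    # (1) W_UK absorption: MHA over materialized k_nope gives the same logits
    #     as MQA over the latent kv_c with W_UK folded into the query.
    q_nope = torch.randn(Sq, N, P, dtype=torch.float64)
    kv_c = torch.randn(Skv, Lkv, dtype=torch.float64)
    W_UK = torch.randn(Lkv, N, P, dtype=torch.float64)

    k_nope = torch.einsum("tl,lnp->tnp", kv_c, W_UK)          # materialize keys
    logits_mha = torch.einsum("snp,tnp->nst", q_nope, k_nope)

    ql_nope = torch.einsum("snp,lnp->snl", q_nope, W_UK)      # absorb W_UK
    logits_mqa = torch.einsum("snl,tl->nst", ql_nope, kv_c)   # keys stay latent

    assert torch.allclose(logits_mha, logits_mqa)

    # (2) LSE merge: attention over a key set split in two can be recombined
    #     from the per-part outputs and log-sum-exps. No max-subtraction here
    #     for brevity; a real kernel would do this in a numerically stable way.
    def attn(q, k, v):
        s = torch.einsum("snd,tnd->nst", q, k)
        o = torch.einsum("nst,tnv->snv", s.softmax(dim=-1), v)
        return o, torch.logsumexp(s, dim=-1)                  # lse: [N, Sq]

    def merge(o_a, lse_a, o_b, lse_b):
        w_a = lse_a.exp().transpose(0, 1).unsqueeze(-1)       # [Sq, N, 1]
        w_b = lse_b.exp().transpose(0, 1).unsqueeze(-1)
        return (w_a * o_a + w_b * o_b) / (w_a + w_b)

    k = torch.randn(Skv, N, P, dtype=torch.float64)
    v = torch.randn(Skv, N, V, dtype=torch.float64)
    o_ref, _ = attn(q_nope, k, v)
    o_a, lse_a = attn(q_nope, k[:4], v[:4])
    o_b, lse_b = attn(q_nope, k[4:], v[4:])
    assert torch.allclose(merge(o_a, lse_a, o_b, lse_b), o_ref)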
- -The chunked prefill approach is as follows: - -MCC Max chunk of context to process per iter, computed dynamically, - used to bound the memory usage - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) -new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) - -// MHA between queries and new KV -// with QK headdim = P + R -// V headdim = V -// curr_o shape [Sq, N, V] -// curr_lse shape [N, Sq], this is just order FA returns -curr_o, curr_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - new_v, - casual=True, - return_softmax_lse=True -) - -// Compute attention with the already existing context -for chunk_idx in range(cdiv(C, MCC)): - chunk_start = chunk_idx * MCC - chunk_end = min(chunk_start + MCC, C) - Sc = chunk_end - chunk_start - cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] - cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] - cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) - cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) - - chunk_o, chunk_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([cache_k_nope_chunk, - cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], - dim=-1), - cache_v_chunk, - casual=False, - return_softmax_lse=True - ) - - curr_o, curr_lse = merge_attn_states( - suffix_output=curr_o, - suffix_lse=curr_lse, - prefix_output=chunk_o, - prefix_lse=chunk_lse, - ) - -return curr_o @ W_O -""" - -import functools -from abc import abstractmethod -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar) - -import torch - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, MLAAttentionImpl) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import get_flash_attn_version -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON -from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down - -if HAS_TRITON: - from vllm.attention.ops.triton_flash_attention import triton_attention -else: - triton_attention = None - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func - is_vllm_fa = True -except ImportError: - is_vllm_fa = False - try: - # For rocm use upstream flash attention - from flash_attn import flash_attn_varlen_func - except ImportError: - flash_attn_varlen_func = None - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - -is_hip = current_platform.is_rocm() - - -class MLACommonBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return MLACommonMetadata - - 
@staticmethod - def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: - return MLACommonMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["MLACommonState"]: - return MLACommonState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -T = TypeVar("T", bound="MLACommonMetadata") - - -class MLACommonState(AttentionState, Generic[T]): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - scheduler_config = runner.scheduler_config - self.model_config = runner.model_config - cache_config = runner.cache_config - - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - self.context_chunk_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.context_chunk_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> T: - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - use_cuda_graph=True, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - 
max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - if self.chunked_prefill_enabled or self.enable_prefix_caching: - if not hasattr(self, "context_chunk_workspace"): - # not self.runner.device does not return the correct device - # for this process, (init_device sets the correct device but - # only on the Worker). The only way Ive figured out to get the - # correct device is to allocate the workspace on the first call - # to begin_forward and use the device of the input tokens - assert model_input.input_tokens is not None - self.context_chunk_workspace = torch.empty( - (self.context_chunk_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=model_input.input_tokens.device, - ) - - model_input.attn_metadata.context_chunk_workspace = \ - self.context_chunk_workspace - - -@dataclass -class MLACommonMetadata(AttentionMetadata): - """Metadata for MLACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. 
- max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[Any] = None - _cached_decode_metadata: Optional[Any] = None - - num_prefill_tokens: int - - # The dimension of the attention heads - head_dim: Optional[int] = None - - # Used when chunked prefill is enabled to simulate worst case workspace - # allocations, hopefully to avoid going OOM - is_profile_run: bool = False - - # New for MLA (compared to FlashAttention) - # For chunked prefill - context_chunk_cu_seq_lens: Optional[torch.Tensor] = None - context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[List[int]] = None - context_chunk_max_seq_lens: Optional[List[int]] = None - # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted - context_chunk_workspace: Optional[torch.Tensor] = None - - def __post_init__(self): - supported_head_sizes = MLACommonBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - @property - def prefill_metadata(self): - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=False, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - 
num_decode_tokens=0, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, - context_chunk_starts=self.context_chunk_starts, - context_chunk_seq_tot=self.context_chunk_seq_tot, - context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=self.use_cuda_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. 
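Spelled out with the toy numbers from the comment above (an illustration only, not part of the removed code):

    import torch

    # 3 prefill tokens followed by 6 decode tokens, with the cumulative query
    # lengths collapsed to [0, 3, 9] as in the comment above and num_prefills=1.
    query_start_loc = torch.tensor([0, 3, 9])
    num_prefills = 1
    decode_query_start_loc = (query_start_loc[num_prefills:]
                              - query_start_loc[num_prefills])
    # -> tensor([0, 6]): the [3, 9] slice shifted so indexing starts at the
    #    first decode token, which is exactly what the expression below does.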
- query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run) - return self._cached_decode_metadata - - -class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - BLOCK_TABLE_EXTENDER: list[list[int]] = [] - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - self.chunked_prefill_enabled = \ - self.runner.scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = \ - self.runner.cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - attn_state = self.input_builder.runner.attn_state - self.context_chunk_workspace_size = \ - attn_state.context_chunk_workspace_size - self.page_size = self.runner.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. 
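For orientation, here is a simplified reference for what the slot mapping computed below encodes: each new token position is mapped through the block table to a physical slot in the paged KV cache. This is a hypothetical sketch that ignores the sliding-window start index and the profile-run padding that compute_slot_mapping also handles; it is not the removed helper itself.

    def naive_slot_mapping(block_table, context_len, seq_len, block_size):
        # slot of the token at position `pos` = physical block id * block_size
        # + offset inside that block; only the new tokens in
        # [context_len, seq_len) are appended in this step.
        return [
            block_table[pos // block_size] * block_size + pos % block_size
            for pos in range(context_len, seq_len)
        ]

    # e.g. block_table=[7, 2], block_size=16, 20 cached tokens + 3 new tokens
    assert naive_slot_mapping([7, 2], 20, 23, 16) == [36, 37, 38]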
- is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * - cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, 
self.runner.pin_memory) - - context_chunk_cu_seq_lens = None - context_chunk_starts = None - context_chunk_seq_tot = None - context_chunk_max_seq_lens = None - - if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ - and self.num_prefills > 0 \ - and context_lens_tensor is not None \ - and context_lens_tensor[:self.num_prefills].max() > 0: - - # NOTE: it is recommend you read the `Chunked Prefill` section in - # the comment at the top of the file before trying to understand - # the following code - - num_prefills_with_context = \ - (context_lens_tensor[:self.num_prefills] > 0).sum().item() - - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = \ - self.context_chunk_workspace_size // num_prefills_with_context - - # align max_context_chunk to page_size by rounding down, - # currently the `gather_and_maybe_dequant_cache` kernel cannot - # handle `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, self.page_size) - assert max_context_chunk > 0 - num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) - - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - context_chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32)\ - .unsqueeze(1).expand(-1, self.num_prefills)\ - * max_context_chunk - chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ - .unsqueeze(0), context_chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) - _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ - .unsqueeze(-1) - context_chunk_cu_seq_lens = \ - torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) - context_chunk_max_seq_lens = \ - chunk_seq_lens.max(dim=1).values.tolist() - context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() - assert max(context_chunk_seq_tot) <= \ - self.context_chunk_workspace_size - - return self.runner.attn_backend.make_metadata( - # Required by ModelRunner - use_cuda_graph=use_captured_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, # Not Attention Related - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.runner.model_config.get_head_size(), - is_profile_run=self.runner.in_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, - context_chunk_starts=context_chunk_starts, - context_chunk_seq_tot=context_chunk_seq_tot, - context_chunk_max_seq_lens=context_chunk_max_seq_lens, - ) - - -class MLACommonImpl(MLAAttentionImpl[T], 
Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - q_lora_rank: Optional[int], - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - kv_b_proj: ColumnParallelLinear, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing not supported in V0.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - self.kv_b_proj = kv_b_proj - - self.triton_fa_func = triton_attention - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) - - def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, - return_softmax_lse, **kwargs): - maybe_padded_v = v - if self._pad_v: - maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ - and not return_softmax_lse: - attn_out = self.triton_fa_func( - q, - k, - maybe_padded_v, - None, # output - kwargs["cu_seqlens_q"], - kwargs["cu_seqlens_k"], - kwargs["max_seqlen_q"], - kwargs["max_seqlen_k"], - kwargs["causal"], - softmax_scale, - None, # bias - ) - elif is_vllm_fa: - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - else: - # Use return_attn_probs instead of return_softmax_lse for RoCM - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_attn_probs=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - - # Unpack the output if there is multiple results, - # triton always returns (output, softmax_lse), - # vllm_flash_attn returns (output, softmax_lse) when - # `return_softmax_lse = True` - # flash_attn (RoCM) returns (output, softmax_lse, ...) when - # `return_attn_probs = True` - rest = None - if isinstance(attn_out, tuple): - attn_out, *rest = attn_out - - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. 
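Why the zero-padding of v in `_pad_v` above is lossless: applying the attention weights to a value matrix whose extra columns are all zero produces zeros in exactly those columns, so slicing the output back to v_head_dim (as `_forward_prefill` later does with `output[..., :v.shape[-1]]`) recovers the unpadded result. A toy check, purely illustrative and with made-up shapes:

    import torch

    S, T, D, Dv = 4, 6, 16, 8                  # Dv < D, as with MLA's v_head_dim
    q = torch.randn(S, D, dtype=torch.float64)
    k = torch.randn(T, D, dtype=torch.float64)
    v = torch.randn(T, Dv, dtype=torch.float64)

    p = torch.softmax(q @ k.T / D**0.5, dim=-1)
    o_ref = p @ v                              # [S, Dv], unpadded attention

    v_pad = torch.nn.functional.pad(v, [0, D - Dv])   # zero-pad head dim to D
    o_pad = p @ v_pad                          # [S, D]

    # the padded columns of the output come out as zeros, and slicing them off
    # recovers the unpadded result
    assert torch.allclose(o_pad[:, Dv:], torch.zeros_like(o_pad[:, Dv:]))
    assert torch.allclose(o_pad[:, :Dv], o_ref)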
- if return_softmax_lse: - assert rest is not None - return attn_out, rest[0] - return attn_out - - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - - def _compute_prefill_context( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ): - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - assert prefill_metadata.context_chunk_seq_tot is not None - assert prefill_metadata.context_chunk_cu_seq_lens is not None - assert prefill_metadata.context_chunk_starts is not None - assert prefill_metadata.context_chunk_max_seq_lens is not None - assert prefill_metadata.context_lens_tensor is not None - - output = None - iters = len(prefill_metadata.context_chunk_seq_tot) - - # Fetch from attn_metadata directly, since it late bound by - # MLAAttentionState, grabbing it directly `attn_metadata` can avoid - # any weirdness around prefill_metadata caching - assert attn_metadata.context_chunk_workspace is not None - workspace = attn_metadata.context_chunk_workspace - - for i in range(iters): - toks = prefill_metadata.context_chunk_seq_tot[i] - - ops.gather_and_maybe_dequant_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], - batch_size=prefill_metadata.num_prefills, - kv_cache_dtype=self.kv_cache_dtype, - scale=k_scale, - seq_starts=prefill_metadata.context_chunk_starts[i], - ) - - kv_c_normed = workspace[:toks]\ 
- [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp - - return output, output_lse - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - output = self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - - if has_context: - # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, - ) - - # unpad if necessary - if self._pad_v: - output = output[..., :v.shape[-1]] - - return output.flatten(start_dim=-2) - - @abstractmethod - def _forward_decode( - self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, # query in unified attn - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: T, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if output is not None: - raise NotImplementedError( - "output is not yet supported for 
MLAImplBase") - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLAImplBase") - - if attn_metadata.is_profile_run and \ - attn_metadata.context_chunk_workspace is not None: - # During the profile run try to simulate to worse case output size - # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` - # since this can be large - _ = torch.empty( - (attn_metadata.context_chunk_workspace.shape[0], - self.num_heads, self.qk_nope_head_dim + self.v_head_dim), - device=k_c_normed.device, - dtype=k_c_normed.dtype, - ) - - has_decode = attn_metadata.decode_metadata is not None - has_prefill = attn_metadata.prefill_metadata is not None - - num_prefill_tokens: int = attn_metadata.num_prefill_tokens - q = q.view(-1, self.num_heads, self.qk_head_dim) - - decode_q = q[num_prefill_tokens:] - - prefill_q = q[:num_prefill_tokens] - prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_k_c_normed = k_c_normed[:num_prefill_tokens] - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - output = torch.empty(attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens, - self.v_head_dim * self.num_heads, - device=q.device, - dtype=q.dtype) - if has_prefill: - output[:num_prefill_tokens] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) - - if has_decode: - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - - output[num_prefill_tokens:] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) - - return output diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py deleted file mode 100644 index e630a6c6de8c4..0000000000000 --- a/vllm/attention/backends/placeholder_attn.py +++ /dev/null @@ -1,340 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataBuilder) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.multimodal import MultiModalPlaceholderMap - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder) -from vllm.utils import async_tensor_h2d - -# Placeholder attention backend for models like Mamba and pooling models that -# lack attention. 
- - -class PlaceholderAttentionBackend(AttentionBackend): - """Placeholder backend for when no attention is needed.""" - - @staticmethod - def get_name() -> str: - return "NO_ATTENTION" - - @staticmethod - def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: - return PlaceholderAttentionImpl - - @staticmethod - def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: - return PlaceholderAttentionMetadataBuilder - - @staticmethod - def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: - return PlaceholderAttentionMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (1, 1, 1, 1, 1) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - return - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - return - - -@dataclass -class PlaceholderAttentionMetadata(AttentionMetadata): - """Attention metadata for prefill and decode batched together.""" - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # Placeholder. 
- block_tables: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None - _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - - # Placeholders - slot_mapping = torch.empty(0) - block_tables = torch.empty(0) - - self._cached_prefill_metadata = PlaceholderAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=0, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Placeholders - slot_mapping = torch.empty(0) - block_tables = torch.empty(0) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - - self._cached_decode_metadata = PlaceholderAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - - -class PlaceholderAttentionMetadataBuilder( - AttentionMetadataBuilder[PlaceholderAttentionMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - - self.input_builder = input_builder - self.runner = input_builder.runner - - def prepare(self): - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: 
Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - """ - is_prompt = inter_data.is_prompt - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - - # Some input builders such as ModelInputForCPUBuilder do not have the - # "inter_data_list" attribute. - # Let's check inter_data_list exists before we reference it. 
- if hasattr(self.input_builder, "inter_data_list"): - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - if use_captured_graph: - num_decode_tokens = batch_size - self.num_prefill_tokens - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - # Placeholders - slot_mapping_tensor = torch.empty(0) - block_tables = torch.empty(0) - - return PlaceholderAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class PlaceholderAttentionImpl(AttentionImpl): - - def __init__(self, *args, **kwargs) -> None: - return - - def forward(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py new file mode 100644 index 0000000000000..313f941ebf934 --- /dev/null +++ b/vllm/attention/backends/registry.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention backend registry""" + +import enum +from typing import Optional + +from vllm.utils import resolve_obj_by_qualname + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + TRITON_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_ATTN = enum.auto() + ROCM_AITER_MLA = enum.auto() + ROCM_AITER_FA = enum.auto() # used for ViT attn backend + TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() + FLASHINFER_MLA = enum.auto() + TRITON_MLA = enum.auto() + CUTLASS_MLA = enum.auto() + FLASHMLA = enum.auto() + FLASHMLA_SPARSE = enum.auto() + FLASH_ATTN_MLA = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + NO_ATTENTION = enum.auto() + FLEX_ATTENTION = enum.auto() + TREE_ATTN = enum.auto() + ROCM_AITER_UNIFIED_ATTN 
= enum.auto()
+
+
+BACKEND_MAP = {
+    _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",  # noqa: E501
+    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
+    _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
+    _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend",  # noqa: E501
+    _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend",  # noqa: E501
+    _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend",  # noqa: E501
+    _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",  # noqa: E501
+    _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
+    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
+    _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
+    _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
+    _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend",  # noqa: E501
+    _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
+    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
+    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
+    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
+}
+
+
+def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
+    """
+    Decorator: register a custom attention backend into BACKEND_MAP.
+    - If class_path is provided, use it.
+    - Otherwise, auto-generate it from the class object.
+    Validation: only checks that 'backend' is a valid _Backend enum member.
+    Overwriting existing mappings is allowed. This enables other hardware
+    platforms to plug in custom out-of-tree backends.
+    """
+    if not isinstance(backend, _Backend):
+        raise ValueError(f"{backend} is not a valid _Backend enum value.")
+
+    def decorator(cls):
+        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
+        BACKEND_MAP[backend] = path
+        return cls
+
+    return decorator
+
+
+def backend_to_class_str(backend: _Backend) -> str:
+    """Get the backend class string.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class string
+    """
+    return BACKEND_MAP[backend]
+
+
+def backend_to_class(backend: _Backend) -> type:
+    """Get the backend class.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class
+    """
+    backend_class_name = backend_to_class_str(backend)
+    return resolve_obj_by_qualname(backend_class_name)
+
+
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+        _Backend: enum value if backend_name is a valid in-tree type
+        None: otherwise, i.e. the name is not a valid in-tree backend or an
+        out-of-tree platform is loaded.
+ """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else None diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py deleted file mode 100644 index a2e9710437d95..0000000000000 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ /dev/null @@ -1,410 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Type, Union - -import torch - -import vllm.envs as envs -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.backends.utils import (compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, - get_aiter_mla_metadata) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - - -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA - - -class AiterMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "ROCM_AITER_MLA" - - @staticmethod - def get_impl_cls() -> Type["AiterMLAImpl"]: - return AiterMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["AiterMLAMetadata"]: - return AiterMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: - return AiterMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["AiterMLAState"]: - return AiterMLAState - - -@dataclass -class AiterMLAMetadata(MLACommonMetadata): - # The following 5 tensors are for current version of AITER MLA - block_table_bound: Optional[torch.Tensor] = None - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_lens: Optional[torch.Tensor] = None - - # This is just to make new AITER MLA API work - # -- MTP support is not added yet. 
- qo_indptr: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self): - prefill_metadata = super().prefill_metadata - self._cached_prefill_metadata = prefill_metadata - - if prefill_metadata is not None: - prefill_metadata.paged_kv_indptr = self.paged_kv_indptr - prefill_metadata.paged_kv_indices = self.paged_kv_indices - prefill_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - prefill_metadata.block_table_bound = self.block_table_bound - prefill_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_prefill_metadata = self.__class__( - **prefill_metadata.__dict__) - - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - - self._cached_decode_metadata = decode_metadata - - if decode_metadata is not None: - decode_metadata.paged_kv_indptr = self.paged_kv_indptr - decode_metadata.paged_kv_indices = self.paged_kv_indices - decode_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - decode_metadata.block_table_bound = self.block_table_bound - decode_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_decode_metadata = self.__class__( - **decode_metadata.__dict__) - - return self._cached_decode_metadata - - -class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - super().__init__(input_builder) - assert self.block_size == 1, "AITER MLA requires only block size 1." - - def prepare(self): - super().prepare() - self.paged_kv_indices: list[int] = [] - self.paged_kv_indptr: list[int] = [0] - self.paged_kv_last_page_lens: list[int] = [] - self.total_blocks = 0 - self.qo_indptr: list[int] = [0] - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. 
- is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - if is_profile_run: - return - - # Update paged_kv_* tensors only for non-profile run - block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - self.qo_indptr.append(self.qo_indptr[-1] + 1) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_lens.append(last_page_len) - - def build(self, seq_lens: list[int], query_lens: list[int], - cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: - metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - if use_captured_graph: - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) - last_qo_indptr = self.qo_indptr[-1] - self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) - - # For current version of AITER MLA - if len(self.paged_kv_indptr) > 0: - # extend to the maximum number of blocks as returned by the - # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device=device, - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device=device, - dtype=torch.int) - paged_kv_last_page_lens_tensor = torch.tensor( - self.paged_kv_last_page_lens, device=device, dtype=torch.int) - block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - - 1, - device=device, - dtype=torch.int) - - qo_indptr = torch.tensor(self.qo_indptr, - device=device, - dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_lens_tensor = None - block_table_bound_tensor = None - qo_indptr = None - - metadata.paged_kv_indptr = paged_kv_indptr_tensor - metadata.paged_kv_indices = paged_kv_indices_tensor - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor - metadata.block_table_bound = block_table_bound_tensor - metadata.qo_indptr = qo_indptr - - return metadata - - -class AiterMLAState(MLACommonState[AiterMLAMetadata]): - - @contextmanager - def graph_capture(self, max_batch_size: int): - kv_indices, kv_indptr, last_page_lens, qo_indptr = \ - get_aiter_mla_metadata( - max_batch_size=max_batch_size, - block_size=self.runner.block_size, - max_block_per_batch=\ - self.runner.get_max_block_per_batch(), - device=self.runner.device) - self._paged_kv_indices_tensor = kv_indices - 
self._paged_kv_indptr_tensor = kv_indptr - self._paged_kv_last_page_lens_tensor = last_page_lens - self._qo_indptr_tensor = qo_indptr - - with super().graph_capture(max_batch_size): - yield - - del self._paged_kv_indices_tensor - del self._paged_kv_indptr_tensor - del self._paged_kv_last_page_lens_tensor - del self._qo_indptr_tensor - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: - - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - - paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] - paged_kv_indices = self._paged_kv_indices_tensor - paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: - batch_size] - qo_indptr = self._qo_indptr_tensor[:batch_size + 1] - - metadata.paged_kv_indptr = paged_kv_indptr - metadata.paged_kv_indices = paged_kv_indices - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens - metadata.qo_indptr = qo_indptr - - return metadata - - def get_graph_input_buffers(self, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers[ - 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr - input_buffers[ - "paged_kv_indices"] = attn_metadata.\ - decode_metadata.paged_kv_indices - input_buffers[ - "paged_kv_last_page_lens"] = attn_metadata.\ - decode_metadata.paged_kv_last_page_lens - input_buffers['qo_indptr'] = attn_metadata.qo_indptr - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ - 0] - input_buffers["paged_kv_indptr"].copy_( - attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) - input_buffers["paged_kv_indices"][:num_total_blocks].copy_( - attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) - input_buffers["paged_kv_last_page_lens"].copy_( - attn_metadata.decode_metadata.paged_kv_last_page_lens, - non_blocking=True) - input_buffers["qo_indptr"].copy_( - attn_metadata.decode_metadata.qo_indptr, non_blocking=True) - - -class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - from aiter import flash_attn_varlen_func - self.flash_attn_varlen_func = flash_attn_varlen_func - - def _flash_attn_varlen_diff_headdims( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - softmax_scale: float, return_softmax_lse: bool, - **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: - output = 
self.flash_attn_varlen_func( - q, - k, - v, - **kwargs, - ) - - return output - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: AiterMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.empty(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.qo_indptr, - attn_metadata.max_query_len, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_lens) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py deleted file mode 100644 index e4c27a0ef36e9..0000000000000 --- a/vllm/attention/backends/rocm_flash_attn.py +++ /dev/null @@ -1,952 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer ROCm GPUs.""" -import itertools -from dataclasses import dataclass -from functools import cache -from typing import List, Optional, Tuple, Type - -import torch - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) -from vllm.platforms import current_platform - -logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 256 - - -@cache -def is_rocm_aiter_paged_attn_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ - and envs.VLLM_ROCM_USE_AITER \ - - -@cache -def _get_paged_attn_module() -> PagedAttention: - """ - Initializes the appropriate PagedAttention module from `attention/ops`, - which is used as helper function - by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. - - The choice of attention module depends on whether - AITER paged attention is enabled: - - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - - Otherwise, it defaults to using the original `PagedAttention`. 
- """ - if is_rocm_aiter_paged_attn_enabled(): - # Import AITERPagedAttention only when the flag is enabled - from vllm.attention.ops.rocm_aiter_paged_attn import ( - AITERPagedAttention) - return AITERPagedAttention() - return PagedAttention() - - -class ROCmFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ROCM_FLASH" - - @staticmethod - def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: - return ROCmFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return ROCmFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: - return ROCmFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - paged_attn = _get_paged_attn_module() - return paged_attn.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
- seq_start_loc: Optional[torch.Tensor] = None - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.block_tables is not None - - self._cached_prefill_metadata = ROCmFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = ROCmFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -class ROCmFlashAttentionMetadataBuilder( - CommonMetadataBuilder[ROCmFlashAttentionMetadata]): - - _metadata_cls = ROCmFlashAttentionMetadata - - -def _make_alibi_bias(alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: Optional[List[int]], - make_attn_mask: bool = True) -> List[torch.Tensor]: - attn_biases = [] - if seq_lens: - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat( - (num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( - alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) - - return attn_biases - - -def _get_seq_len_block_table_args( - attn_metadata: ROCmFlashAttentionMetadata, - attn_type: str, -) -> tuple: - ''' - The particular choice of sequence-length - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths - Encoder attn -> select encoder sequence lengths fields - Encoder-only attn -> select prefill sequence lengths with - bidirectional attention - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention, encoder-only - - Returns: - - * Appropriate sequence-lengths tensors for query and key - * Appropriate max sequence-length scalar - * Causal masking flag - ''' - - if attn_type == AttentionType.ENCODER: - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - causal_mask = False - - # No block tables associated with encoder attention - return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, - query_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_lens, causal_mask) - - elif attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, we use the prefill sequence lengths - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - # Encoder-only models typically use bidirectional attention - causal_mask = False - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - - elif attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - causal_mask = True - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - key_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - causal_mask = False - - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (query_start_loc, attn_metadata.max_prefill_seq_len, - key_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.seq_lens, causal_mask) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class 
ROCmFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| - |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "ROCM_FLASH backend.") - if use_irope: - logger.warning_once( - "Using irope in ROCm Flash Attention is not supported yet, it " - "will fail back to global attention for long context.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - self.logits_soft_cap = 0.0 - else: - self.logits_soft_cap = logits_soft_cap - self.attn_type = attn_type - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.paged_attn_module = _get_paged_attn_module() - supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( - ) - - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.use_naive_attn = False - # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN - if self.use_triton_flash_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Triton FlashAttention does not support attention" - " logits soft capping." 
- " please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - - from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 - triton_attention) - self.triton_attn_func = triton_attention - logger.debug("Using Triton FA in ROCmBackend") - if self.sliding_window != (-1, -1): - logger.warning("ROCm Triton FA does not currently support " - "sliding window attention. If using half " - "precision, please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - else: - # if not using triton, navi3x/navi21/navi10 do not use flash-attn - # either - if not current_platform.has_device_capability(90): - self.use_naive_attn = True - else: - try: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.fa_attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") - except ModuleNotFoundError: - self.use_naive_attn = True - - if self.use_naive_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Naive FlashAttention does not support " - "attention logits soft capping.") - - self.sdpa_attn_func = _sdpa_attention - logger.debug("Using naive (SDPA) attention in ROCmBackend") - - self.aiter_kv_scales_initialized = False - self.force_fp8_attention = ( - get_current_vllm_config() is not None - and get_current_vllm_config().model_config.override_attention_dtype - == "fp8") - - def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - tokens, n_kv_heads, head_dim = x.shape - return (x[:, :, - None, :].expand(tokens, n_kv_heads, n_rep, - head_dim).reshape(tokens, n_kv_heads * n_rep, - head_dim)) - - def fused_output_quant_supported(self, quant_key: QuantKey): - if self.use_triton_flash_attn: - return quant_key == kFp8StaticTensorSym - - # Only supported in the Triton backend - return False - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * ROCmFlashAttentionImpl.forward() may be invoked for both self- and - cross-attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. 
- * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - * ENCODER_ONLY: bidirectional attention with no KV caching; - use prefill sequence attributes - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None and not self.use_triton_flash_attn: - raise NotImplementedError( - "fused output quantization only supported for Triton" - " implementation in ROCMFlashAttentionImpl for now") - - if output_block_scale is not None: - raise NotImplementedError( - "fused nvfp4 output quantization is not supported" - " for ROCMFlashAttentionImpl") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - paged_attn = self.paged_attn_module - - # Reshaping kv tensors is required for AITER paged attention kernel - # because it works on a different tensor shape, - # when the size of one element is one byte (int8/fp8 dtypes). - # This reshaping is only required on the first forward call - # and the kv cache must not be empty. 
- if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 - and not self.aiter_kv_scales_initialized - and kv_cache.shape != torch.Size([0])): - num_blocks = kv_cache.shape[1] - block_size = kv_cache.shape[2] // (self.num_kv_heads * - self.head_size) - k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - self.aiter_kv_scales_initialized = True - k_scale.fill_(layer._k_scale.item()) - v_scale.fill_(layer._v_scale.item()) - layer._k_scale = k_scale - layer._v_scale = v_scale - - # Only update KV cache for decoder self-attention - # and encoder-decoder cross-attention - if self.attn_type not in [ - AttentionType.ENCODER, AttentionType.ENCODER_ONLY - ] and kv_cache.numel() > 0: - key_cache, value_cache = paged_attn.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if key is not None and value is not None: - # Reshape the input keys and values and store them in the - # cache. If kv_cache is not provided, the new key and value - # tensors are not cached. This happens during the initial - # memory profiling run. - paged_attn.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping - if self.attn_type != AttentionType.ENCODER_DECODER else - attn_metadata.cross_slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.attn_type != AttentionType.ENCODER: - num_prefill_tokens = attn_metadata.num_prefill_tokens - elif self.attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, all tokens are processed in one go - num_prefill_tokens = query.shape[0] - else: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - - # For encoder-only and encoder models, - # we process all tokens at once - # For decoder and encoder-decoder, - # we may need to limit key/value to prefill tokens - if key is not None and value is not None \ - and self.attn_type not in [AttentionType.ENCODER_DECODER, - AttentionType.ENCODER_ONLY]: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - # normal attention and DECODER - if self.attn_type == AttentionType.DECODER and ( - kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = (prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - attn_metadata.seq_lens, True) - # prefix-enabled attention and ENCODER/ENCODER_DECODER - else: - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = _get_seq_len_block_table_args( - prefill_meta, self.attn_type) - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # triton attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- attn_masks = None - if self.use_triton_flash_attn: - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - seq_lens, - make_attn_mask=causal_mask) # type: ignore - - use_fp8_scales = (layer._q_scale and layer._k_scale - and layer._v_scale and layer._prob_scale - and (self.kv_cache_dtype == "fp8" - or self.force_fp8_attention)) - - full_scales = ( - layer._q_scale.item(), layer._k_scale.item(), - layer._v_scale.item(), - layer._prob_scale.item()) if use_fp8_scales else None - self.triton_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - key_seq_start_loc, - query_max_seq_len, - key_max_seq_len, - causal_mask, - self.scale, - attn_masks[0][None] - if attn_masks is not None else None, - full_scales, - output_scale, - ) - elif self.use_naive_attn: - if self.num_kv_heads != self.num_heads: - # Interleave for MQA workaround. - key = self.repeat_kv(key, self.num_queries_per_kv) - value = self.repeat_kv(value, self.num_queries_per_kv) - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - attn_metadata.seq_lens, - make_attn_mask=causal_mask) # type: ignore - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - # sdpa math backend attention - self.sdpa_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - num_prefill_tokens, - self.num_heads, - self.head_size, - self.scale, - attn_masks, - ) - else: - # upstream FA does not support an output arg, copy - output[:num_prefill_tokens] = self.fa_attn_func( - q=query, - k=key, - v=value, - cu_seqlens_q=query_seq_start_loc, - cu_seqlens_k=key_seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=key_max_seq_len, - softmax_scale=self.scale, - causal=causal_mask, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - - else: - # prefix-enabled attention - - # not applicable for encoder-only models - if self.attn_type != AttentionType.ENCODER_ONLY: - output[:num_prefill_tokens] = paged_attn.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) - # Skip decode phase for encoder-only models - if (decode_meta := attn_metadata.decode_metadata) and ( - self.attn_type != AttentionType.ENCODER_ONLY): - # Decoding run. 
- # Whether to use rocm custom paged attention or not - num_seqs, num_heads, head_size = decode_query.shape - block_size = value_cache.shape[3] - gqa_ratio = num_heads // self.num_kv_heads - from vllm.platforms.rocm import use_rocm_custom_paged_attention - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window, - self.kv_cache_dtype, self.alibi_slopes) - - if use_custom: - max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type - != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len) - assert max_seq_len is not None - max_num_partitions = ( - (max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) - assert _PARTITION_SIZE_ROCM % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=query.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - - query_start_loc = None - ops.paged_attention_rocm( - output[num_prefill_tokens:], - exp_sums, - max_logits, - tmp_output, - decode_query, - key_cache, - value_cache, - self.num_kv_heads, - self.scale, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - query_start_loc, - block_size, - max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - output_scale, - ) - else: - # PagedAttention does not support fused quant, manually quantize - if output_scale is None: - out_pa = output[num_prefill_tokens:] - else: - out_pa = torch.empty_like(output[num_prefill_tokens:], - dtype=query.dtype) - - out_pa[:] = paged_attn.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - decode_meta.max_decode_seq_len - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Manually perform quantization - if output_scale is not None: - out_uq = out_pa.view(-1, self.num_heads * self.head_size) - out_q = output.view(-1, self.num_heads * self.head_size) - ops.scaled_fp8_quant(out_uq, - output_scale, - output=out_q[num_prefill_tokens:]) - - # Reshape the output tensor. 
- return output.view(-1, self.num_heads * self.head_size) - - -def _sdpa_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - seq_lens: torch.Tensor, - num_tokens: int, - num_heads: int, - head_size: int, - scale: float, - attn_masks: Optional[List[torch.Tensor]] = None, -) -> torch.Tensor: - start = 0 - assert output.shape == (num_tokens, num_heads, head_size) - assert output.dtype == query.dtype - assert output.device == query.device - - for i, seq_len in enumerate(seq_lens): - end = start + seq_len - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH): - sub_out = torch.nn.functional.scaled_dot_product_attention( - query[:, start:end, :], - key[:, start:end, :], - value[:, start:end, :], - dropout_p=0.0, - is_causal=attn_masks is None, - attn_mask=attn_masks[i] if attn_masks else None, - scale=scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out - start = end - - return output diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py deleted file mode 100644 index fba5b5f6bca86..0000000000000 --- a/vllm/attention/backends/triton_mla.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) -from vllm.attention.ops.triton_decode_attention import decode_attention_fwd - - -class TritonMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_impl_cls() -> Type["TritonMLAImpl"]: - return TritonMLAImpl - - -class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - num_kv_splits = 4 # TODO: heuristic - - # TODO(lucas) Allocate ahead of time - attn_logits = 
torch.empty( - ( - B, - self.num_heads, - num_kv_splits, - # NOTE(lucas) idk why the +1 is here but sglang has it so we - # just mirror that - self.kv_lora_rank + 1, - ), - dtype=torch.float32, - device=q.device, - ) - - # Add a head dim of 1 - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - - # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 34e059067d84d..46a87bdd1f7e1 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,593 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention backend utils""" -from collections import defaultdict -from contextlib import contextmanager + from dataclasses import dataclass -from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import Optional -import numpy as np -import torch - -from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, - AttentionState) -from vllm.attention.backends.abstract import AttentionType from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.worker.model_runner_base import ModelRunnerBase - -# Error string(s) for encoder/decoder -# unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " - "with encoder/decoder models.") - PAD_SLOT_ID = -1 -# Switch to numpy implementation of compute_slot_mapping -# if we have at least this many elements. Could be tuned further. -_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - - -def is_block_tables_empty(block_tables: Union[None, Dict]): - """ - Check if block_tables is None or a dictionary with all None values. - """ - if block_tables is None: - return True - return (isinstance(block_tables, dict) - and all(value is None for value in block_tables.values())) - - -def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, - context_len: int, sliding_window: int): - """ - Compute the start index of slot mapping. 
- """ - start_idx = 0 - if is_prompt and sliding_window is not None: - start_idx = max(0, query_len - sliding_window) - return start_idx - - -def _compute_slot_mapping_python(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - for i in range(range_start, range_end): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - -def _compute_slot_mapping_numpy(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - block_table_array = np.array(block_table) - idx = np.arange(range_start, range_end) - block_offset = idx % block_size - idx //= block_size - seq_slot_mapping_array = block_table_array[idx] - seq_slot_mapping_array *= block_size - seq_slot_mapping_array += block_offset - slot_mapping.extend(seq_slot_mapping_array) - - -def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], - seq_id: int, seq_len: int, context_len: int, - start_idx: int, block_size: int, - block_tables: Dict[int, List[int]]): - """ - Compute slot mapping. - """ - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([PAD_SLOT_ID] * seq_len) - return - - # Mask the [0, start_idx) tokens of the prompt with - # PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - padding_mask_len = max(0, start_idx - context_len) - slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) - - range_start = max(start_idx, context_len) - range_end = seq_len - numel = range_end - range_start - block_table = block_tables[seq_id] - - # numpy implementation will be faster than python if we have - # many elements, otherwise it will be slower. 
- if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: - _compute_slot_mapping_python(slot_mapping, block_table, range_start, - range_end, block_size) - else: - _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, - range_end, block_size) - - -TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') - - -class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): - - _metadata_cls: Type[TAttentionMetadata] - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if inter_data.prefix_cache_hit: - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - for i, block_table in enumerate(self.block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.from_numpy(input_block_tables).to( - device, non_blocking=True) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, "query_lens: {}".format(query_lens) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return self._metadata_cls( # type: ignore - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class CommonAttentionState(AttentionState): - - def __init__(self, runner: "ModelRunnerBase"): - self.runner = runner - self._is_graph_capturing = False - - @contextmanager - def graph_capture(self, max_batch_size: int): - - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - - def graph_clone(self, batch_size: int) -> "CommonAttentionState": - assert self._is_graph_capturing - return self.__class__(self.runner) - - def 
graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - use_cuda_graph=True, - ) - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" - self._update_captured_metadata_for_enc_dec_model( - batch_size=batch_size, attn_metadata=attn_metadata) - - return attn_metadata - - def get_graph_input_buffers( - self, - attn_metadata, - is_encoder_decoder_model: bool = False) -> Dict[str, Any]: - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" - self._add_additional_input_buffers_for_enc_dec_model( - attn_metadata=attn_metadata, input_buffers=input_buffers) - return input_buffers - - def prepare_graph_input_buffers( - self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False) -> None: - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers and - # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ - f"'FLASH_ATTN', but "\ - f"got '{self.runner.attn_backend.get_name()}'" - self._prepare_input_buffers_for_enc_dec_model( - attn_metadata, input_buffers) - - def begin_forward(self, model_input) -> None: - return - - def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, - attn_metadata): - """ - Updates the attention metadata parameters for CUDA graph capture in an - encoder-decoder model. - - This method modifies attention-related tensors and metadata required - for CUDA graph capture in encoder-decoder models. Specifically, it - updates the cross-attention and encoder sequence tensors in the - AttentionMetadata object. - """ - # During decode phase the cross_slot_mapping will be empty. Hence set - # an empty tensor for CUDA Graph capture. 
- attn_metadata.cross_slot_mapping = torch.tensor( - [], dtype=torch.int).cuda() - attn_metadata.cross_block_tables = torch.full( - (batch_size, self.runner.get_max_block_per_batch()), - 1, - dtype=torch.int).cuda() - attn_metadata.encoder_seq_lens = torch.full((batch_size, ), - 1, - dtype=torch.int).cuda() - attn_metadata.encoder_seq_lens_tensor = torch.full( - (batch_size, ), 1, dtype=torch.int).cuda() - attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture - attn_metadata.num_encoder_tokens = 0 - - def _add_additional_input_buffers_for_enc_dec_model( - self, attn_metadata, input_buffers: Dict[str, Any]): - """ - Saves additional input buffers specific to the encoder-decoder model - from the attention metadata. - - This method extracts and stores encoder-decoder related input buffers - from the `attn_metadata` into the `input_buffers` dictionary. The - buffers include encoder sequence lengths, cross-slot mappings, and - cross-block tables, which are essential for the encoder-decoder model - during CUDA graph replay. - """ - input_buffers["encoder_seq_lens_tensor"] = ( - attn_metadata.decode_metadata.encoder_seq_lens_tensor) - input_buffers["cross_slot_mapping"] = ( - attn_metadata.decode_metadata.cross_slot_mapping) - input_buffers["cross_block_tables"] = ( - attn_metadata.decode_metadata.cross_block_tables) - - def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, - input_buffers: Dict[str, - Any]): - """ - Populates input buffers with data from the encoder-decoder model's - attention metadata. - - This method fills the input buffers with encoder-decoder specific - tensors. It copies data from the `attn_metadata` and keyword arguments - (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. - The copied data includes attention-related metadata as well as input - IDs and positional information for the encoder. - """ - input_buffers["encoder_seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.encoder_seq_lens_tensor, - non_blocking=True) - input_buffers["cross_slot_mapping"].copy_( - attn_metadata.decode_metadata.cross_slot_mapping, - non_blocking=True) - input_buffers["cross_block_tables"].copy_( - attn_metadata.decode_metadata.cross_block_tables, - non_blocking=True) - - -def is_all_encoder_attn_metadata_set(attn_metadata): - ''' - All attention metadata required for encoder attention is set. - ''' - return ((attn_metadata.encoder_seq_lens is not None) - and (attn_metadata.encoder_seq_lens_tensor is not None) - and (attn_metadata.max_encoder_seq_len is not None)) - - -def is_all_cross_attn_metadata_set(attn_metadata): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return (attn_metadata.is_all_encoder_attn_metadata_set - and (attn_metadata.cross_slot_mapping is not None) - and (attn_metadata.cross_block_tables is not None)) - - -def get_seq_len_block_table_args( - attn_metadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def get_num_prefill_decode_query_kv_tokens( - attn_metadata, - attn_type: str, -) -> Tuple[int, int, int]: - """ - Calculate the number of prefill and decode tokens for query, key/value - based on the attention metadata and the specified attention type. - - Args: - attn_metadata (AttentionMetadata): Attention Metadata object. - attn_type (AttentionType): The type of attention being used. - Returns: - Tuple[int, int, int]: A tuple containing three integers: - - The number of prefill query tokens. - - The number of prefill key/value tokens. - - The number of decode query tokens. - - Raises: - AssertionError: If the number of encoder tokens in `attn_metadata` - is `None` when required for the calculations. - """ - num_prefill_query_tokens = 0 - num_decode_query_tokens = 0 - num_prefill_kv_tokens = 0 - if attn_type == AttentionType.ENCODER: - # Encoder attention is only invoked during prefill phase. - # The same input servers a both query and key. - assert attn_metadata.num_encoder_tokens is not None - num_prefill_query_tokens = attn_metadata.num_encoder_tokens - num_prefill_kv_tokens = attn_metadata.num_encoder_tokens - num_decode_query_tokens = 0 - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_query_tokens = attn_metadata.num_prefill_tokens - # The key is the encoder/cross-attention. 
- num_prefill_kv_tokens = attn_metadata.num_encoder_tokens - num_decode_query_tokens = attn_metadata.num_decode_tokens - else: # attn_type == AttentionType.DECODER or - # attn_type == AttentionType.ENCODER_ONLY - num_prefill_query_tokens = attn_metadata.num_prefill_tokens - num_prefill_kv_tokens = attn_metadata.num_prefill_tokens - num_decode_query_tokens = attn_metadata.num_decode_tokens - - return (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) - @dataclass class MLADims: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py deleted file mode 100644 index c1213f7620a7a..0000000000000 --- a/vllm/attention/backends/xformers.py +++ /dev/null @@ -1,805 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with xFormers and PagedAttention.""" -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type - -import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMaskWithTensorBias) - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import ( - CommonAttentionState, CommonMetadataBuilder, - get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class XFormersBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "XFORMERS" - - @staticmethod - def get_impl_cls() -> Type["XFormersImpl"]: - return XFormersImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return XFormersMetadata - - @staticmethod - def get_builder_cls() -> Type["XFormersMetadataBuilder"]: - return XFormersMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for XFormersbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # seq_lens stored as a tensor. 
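To make the iteration-layout comment above concrete, a small illustrative sketch (plain Python; the literal values are made up) of how context_len, query_len and seq_len relate, and of the cumulative *_start_loc lists that the metadata builder derives from them with itertools.accumulate:

from itertools import accumulate

# Per sequence: seq_len = context_len (tokens already computed) + query_len (new tokens).
context_lens = [3, 0]       # a chunked-prefill continuation and a fresh prompt
query_lens = [1, 6]
seq_lens = [c + q for c, q in zip(context_lens, query_lens)]   # [4, 6]

# Cumulative start offsets, e.g. seq_lens [4, 6] -> [0, 4, 10], matching the
# query_start_loc / seq_start_loc construction in the builder's build() method.
seq_start_loc = list(accumulate(seq_lens, initial=0))          # [0, 4, 10]
query_start_loc = list(accumulate(query_lens, initial=0))      # [0, 1, 7]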
- seq_lens_tensor: Optional[torch.Tensor] - - # FIXME: It is for flash attn. - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] = None - - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - - # Self-attention prefill/decode metadata cache - _cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cached_decode_metadata: Optional["XFormersMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[AttentionBias]] = None - self.encoder_attn_bias: Optional[List[AttentionBias]] = None - self.cross_attn_bias: Optional[List[AttentionBias]] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. 
- ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - # Recover cached prefill-phase attention - # metadata structure - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - # Construct & cache prefill-phase attention metadata structure - self._cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - # Recover cached decode-phase attention - # metadata structure - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - # Construct & cache decode-phase attention metadata structure - self._cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -def _get_attn_bias( - attn_metadata: XFormersMetadata, - attn_type: str, -) -> Optional[AttentionBias]: - ''' - Extract appropriate attention bias from attention metadata - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate attention bias value given the attention type - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - return attn_metadata.attn_bias - elif attn_type == AttentionType.ENCODER: - return attn_metadata.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return attn_metadata.cross_attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _set_attn_bias( - attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]], - attn_type: str, -) -> None: - ''' - Update appropriate attention bias field of attention metadata, - according to attention type. 
- - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_bias: The desired attention bias value - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - attn_metadata.attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER: - attn_metadata.encoder_attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - attn_metadata.cross_attn_bias = attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): - - _metadata_cls = XFormersMetadata - - -class XFormersImpl(AttentionImpl[XFormersMetadata]): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "XFORMERS backend.") - if logits_soft_cap is not None: - logger.warning_once("XFormers does not support logits soft cap. " - "Outputs may be slightly off.") - if use_irope: - logger.warning_once( - "Using irope in XFormers is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. 
" - f"Supported head sizes are: {supported_head_sizes}.") - - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: "XFormersMetadata", - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with xFormers and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * XFormersImpl.forward() may be invoked for both self- and cross- - attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. - * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). - Used for encoder branch of encoder-decoder models. - * ENCODER_ONLY: no kv_caching, uses the normal attention - attributes (seq_lens/seq_lens_tensor/max_seq_len). - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. 
Defaults to decoder self-attention, - which is the vLLM default generally - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for XFormersImpl") - - attn_type = self.attn_type - # Check that appropriate attention metadata attributes are - # selected for the desired attention type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - # Self-attention vs. cross-attention will impact - # which KV cache memory-mapping & which - # seqlen datastructures we utilize - - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): - # KV-cache during decoder-self- or - # encoder-decoder-cross-attention, but not - # during encoder attention. - # - # Even if there are no new key/value pairs to cache, - # we still need to break out key_cache and value_cache - # i.e. for later use by paged attention - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if (key is not None) and (value is not None): - - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - PagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - if key is not None and value is not None: - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # normal attention. - # block tables are empty if the prompt does not have a cached - # prefix. 
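As a toy illustration (shapes invented for exposition) of the prefill/decode split performed just above: the flattened token batch is laid out as [prefill tokens | decode tokens], so slicing at num_prefill_query_tokens separates the prompt-phase queries from the single-token decode queries.

import torch

num_prefill_query_tokens = 7    # e.g. two prompts of 3 and 4 tokens
num_decode_query_tokens = 2     # e.g. two sequences each decoding one token
num_heads, head_size = 8, 64

query = torch.randn(
    num_prefill_query_tokens + num_decode_query_tokens, num_heads, head_size
)

prefill_query = query[:num_prefill_query_tokens]   # attends over fresh K/V
decode_query = query[num_prefill_query_tokens:]    # attends against the paged KV cache
assert prefill_query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens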
- out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_query_tokens].shape - output[:num_prefill_query_tokens] = out - else: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have prefix attention.") - - assert prefill_meta.query_start_loc is not None - assert prefill_meta.max_query_len is not None - - # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - out = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window, - layer._k_scale, - layer._v_scale, - ) - assert output[:num_prefill_query_tokens].shape == out.shape - output[:num_prefill_query_tokens] = out - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") - - ( - seq_lens_arg, - max_seq_len_arg, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - - output[num_prefill_query_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Reshape the output tensor. - return output.view(-1, self.num_heads * self.head_size) - - def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: str = AttentionType.DECODER, - ) -> torch.Tensor: - """Attention for 1D query of multiple prompts. Multiple prompt - tokens are flattened in to `query` input. - - See https://facebookresearch.github.io/xformers/components/ops.html - for API spec. - - Args: - output: shape = [num_prefill_tokens, num_heads, head_size] - query: shape = [num_prefill_tokens, num_heads, head_size] - key: shape = [num_prefill_tokens, num_kv_heads, head_size] - value: shape = [num_prefill_tokens, num_kv_heads, head_size] - attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally - """ - - original_query = query - if self.num_kv_heads != self.num_heads: - # GQA/MQA requires the shape [B, M, G, H, K]. - # Note that the output also has the same shape (which is different - # from a spec from the doc). - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. 
- attn_bias = _get_attn_bias(attn_metadata, attn_type) - if attn_bias is None: - if self.alibi_slopes is None: - - # Cross attention block of decoder branch of encoder-decoder - # model uses seq_lens for dec / encoder_seq_lens for enc - if (attn_type == AttentionType.ENCODER_DECODER): - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens is not None - - # Cross-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, - attn_metadata.encoder_seq_lens, - device=query.device) - - # Encoder branch of encoder-decoder model uses - # attn_metadata.encoder_seq_lens - elif attn_type == AttentionType.ENCODER: - - assert attn_metadata.encoder_seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens, device=query.device) - - # Self-attention block of encoder-only model just - # uses the seq_lens directly. - elif attn_type == AttentionType.ENCODER_ONLY: - assert attn_metadata.seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - - # Self-attention block of decoder branch just - # uses the seq_lens directly - elif attn_type == AttentionType.DECODER: - assert attn_metadata.seq_lens is not None - - # Decoder self-attention mask is causal - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - else: - raise ValueError("Unknown AttentionType: %s", attn_type) - - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - attn_bias = [attn_bias] - else: - assert attn_type == AttentionType.DECODER - assert attn_metadata.seq_lens is not None - attn_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) - - _set_attn_bias(attn_metadata, attn_bias, attn_type) - - # No alibi slopes. - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - # Add the batch dimension. - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias[0], - p=0.0, - scale=self.scale) - return out.view_as(original_query) - - # Attention with alibi slopes. - # FIXME(woosuk): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - assert attn_metadata.seq_lens is not None - output = torch.empty_like(original_query) - start = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens): - end = start + seq_len - out = xops.memory_efficient_attention_forward( - query[None, start:end], - key[None, start:end], - value[None, start:end], - attn_bias=attn_bias[i], - p=0.0, - scale=self.scale) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.view_as(original_query[start:end])) - start += seq_len - return output - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[AttentionBias]: - attn_biases: List[AttentionBias] = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. 
We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - # Calculate a matrix where each element represents ith element- jth - # element. - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) - - return attn_biases diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9fbead31782a9..9f43cb31218f7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import List, Optional + +from typing import Callable, Optional, cast import torch import torch.nn as nn @@ -9,24 +10,37 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.platforms import _Backend, current_platform -from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.platforms import current_platform +from vllm.utils import GiB_bytes, direct_register_custom_op logger = init_logger(__name__) USE_XFORMERS_OPS = None +try: + tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,) +except AttributeError: + tag_cudagraph_unsafe = () # type: ignore[assignment] def check_xformers_availability(): @@ -34,8 +48,7 @@ def check_xformers_availability(): if USE_XFORMERS_OPS is not None: return USE_XFORMERS_OPS - if current_platform.is_cuda() and current_platform.has_device_capability( - 100): + if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: @@ -54,7 +67,51 @@ def 
check_xformers_availability(): return USE_XFORMERS_OPS -class Attention(nn.Module): +def check_upstream_fa_availability(dtype: torch.dtype): + if ( + dtype in (torch.float16, torch.bfloat16) + and current_platform.is_cuda() + and current_platform.has_device_capability(80) + ): + from transformers.utils import is_flash_attn_2_available + + return is_flash_attn_2_available() + if current_platform.is_rocm(): + from importlib.util import find_spec + + return find_spec("flash_attn") is not None + return False + + +def maybe_get_vit_flash_attn_backend( + attn_backend: _Backend, use_upstream_fa: bool +) -> tuple[_Backend, Callable]: + if ( + attn_backend != _Backend.FLASH_ATTN + and attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True + + if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + else: + flash_attn_varlen_func = None + + return attn_backend, flash_attn_varlen_func + + +class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. The input tensors @@ -72,12 +129,11 @@ class Attention(nn.Module): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, - use_mla: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -101,18 +157,16 @@ class Attention(nn.Module): if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size - is_attention_free = cache_config.is_attention_free calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 - is_attention_free = False calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads - assert num_heads % num_kv_heads == 0, \ - f"num_heads ({num_heads}) is not " \ - f"divisible by num_kv_heads ({num_kv_heads})" + assert num_heads % num_kv_heads == 0, ( + f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" + ) # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with @@ -139,23 +193,25 @@ class Attention(nn.Module): # the quant op after this attention layer. 
self._o_scale_float: Optional[float] = None - self.use_mla = use_mla self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None - quant_method = quant_config.get_quant_method( - self, prefix=prefix) if quant_config else None + quant_method = ( + quant_config.get_quant_method(self, prefix=prefix) if quant_config else None + ) if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod): + quant_method, UnquantizedLinearMethod + ): assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError("fp8_e5m2 kv-cache is not supported with " - "fp8 checkpoints.") + raise ValueError( + "fp8_e5m2 kv-cache is not supported with fp8 checkpoints." + ) # If quantization is enabled, we make "k_scale" and "v_scale" # parameters so that it can be loaded from the model checkpoint. # The k/v_scale will then be converted back to native float32 @@ -167,21 +223,31 @@ class Attention(nn.Module): # weight and activation dtype. dtype = torch.get_default_dtype() if attn_backend is None: - self.attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - is_attention_free, - use_mla=use_mla, - has_sink=self.has_sink) + self.attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=False, + has_sink=self.has_sink, + ) else: self.attn_backend = attn_backend impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **extra_impl_args) + self.impl = impl_cls( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **extra_impl_args, + ) self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype @@ -189,8 +255,7 @@ class Attention(nn.Module): # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. 
- self.use_direct_call = not current_platform.is_cuda_alike( - ) and not current_platform.is_cpu() + self.use_direct_call = not current_platform.opaque_attention_op() self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config @@ -212,13 +277,39 @@ class Attention(nn.Module): # by bind_kv_cache # this variable will not be accessed if use_direct_call is True self.kv_cache = [ - torch.tensor([]) for _ in range(get_current_vllm_config( - ).parallel_config.pipeline_parallel_size) + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) ] - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + try: + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + except torch.cuda.OutOfMemoryError as e: + logger.error("Failed to initialize attention q/k/v range constants: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug( + "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes + ) + logger.debug( + "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes + ) + raise RuntimeError( + "Failed to initialize q/k/v range constants. " + "This may be caused by insufficient memory to allocate " + "kv cache." + ) from e + + # for attn backends supporting query quantization + self.query_quant = None + if ( + self.kv_cache_dtype.startswith("fp8") + and self.attn_backend.supports_quant_query_input + ): + self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) def forward( self, @@ -240,45 +331,44 @@ class Attention(nn.Module): `vllm.forward_context.get_forward_context().attn_metadata`. """ if self.calculate_kv_scales: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(query, key, value) + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) + + output_dtype = query.dtype + if self.query_quant is not None: + # quantizing with a simple torch operation enables + # torch.compile to fuse this into previous ops + # which reduces overheads during decoding. + # Otherwise queries are quantized using custom ops + # which causes decoding overheads + assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} + query, _ = self.query_quant(query, self._q_scale) + if self.use_output: - output_shape = (output_shape - if output_shape is not None else query.shape) - output = torch.zeros(output_shape, - dtype=query.dtype, - device=query.device) + output_shape = output_shape if output_shape is not None else query.shape + output = torch.zeros(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] - # We skip reshaping query, key and value tensors for the MLA - # backend since these tensors have different semantics and are - # processed differently. - if not self.use_mla: - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. 
- query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output) + self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata, output=output + ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + query, key, value, output, self.layer_name + ) return output.view(-1, hidden_size) else: if self.use_direct_call: @@ -287,11 +377,13 @@ class Attention(nn.Module): if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, - self_kv_cache, attn_metadata) + return self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata + ) else: return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name) + query, key, value, self.layer_name + ) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) @@ -316,12 +408,11 @@ class Attention(nn.Module): self.impl.process_weights_after_loading(act_dtype) # FlashInfer requires attention sinks to be float32 - if (self.backend == _Backend.FLASHINFER_VLLM_V1 - and hasattr(self.impl, 'sinks')): + if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"): from vllm.v1.attention.backends.flashinfer import FlashInferImpl + assert isinstance(self.impl, FlashInferImpl) - if (self.impl.sinks is not None - and self.impl.sinks.dtype != torch.float32): + if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32: self.impl.sinks = self.impl.sinks.to(torch.float32) def get_attn_backend(self) -> type[AttentionBackend]: @@ -337,50 +428,88 @@ class MultiHeadAttention(nn.Module): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - ): + # This has no effect, it is only here to make it easier to swap + # between Attention and MultiHeadAttention + prefix: str = "", + ) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.layer_name = prefix - assert self.num_heads % self.num_kv_heads == 0, \ - f"num_heads ({self.num_heads}) is not " \ + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads ({self.num_heads}) is not " f"divisible by num_kv_heads ({self.num_kv_heads})" + ) self.num_queries_per_kv = self.num_heads // self.num_kv_heads + # During model initialization, the default dtype is set as the model + # weight and 
activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype=None, - block_size=16, - is_attention_free=False) - backend = backend_name_to_enum(attn_backend.get_name()) - if current_platform.is_rocm(): - # currently, only torch_sdpa is supported on rocm + + # Determine the attention backend + backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) + + # Some auto-selected backends can be upgraded + # to upstream flash attention if available. + # If vllm native fa is selected, we use it directly. + use_upstream_fa = False + + if current_platform.is_xpu(): + # currently, only torch_sdpa is supported on xpu self.attn_backend = _Backend.TORCH_SDPA else: - if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, - _Backend.FLEX_ATTENTION): - backend = _Backend.XFORMERS + self.attn_backend = ( + backend + if backend + in { + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.PALLAS, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + } + else _Backend.TORCH_SDPA + ) - self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 - } else _Backend.TORCH_SDPA + self.attn_backend, self._flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + use_upstream_fa, + ) + ) - if (self.attn_backend == _Backend.XFORMERS - and not check_xformers_availability()): + if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): self.attn_backend = _Backend.TORCH_SDPA + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } + + # this condition is just to make sure that the + # use_upstream_fa in the log is correct + if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True + + logger.info_once( + f"MultiHeadAttention attn_backend: {self.attn_backend}, " + f"use_upstream_fa: {use_upstream_fa}" + ) + def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, ) -> torch.Tensor: - """Input shape: batch_size x seq_len x hidden_size""" - # TODO(Isotr0py): Use existing backend implementations and support FA3 - bsz, q_len, _ = query.size() + """Input shape: + (batch_size x seq_len x hidden_size) or + (batch_size x seq_len x num_heads x head_size) + """ + bsz, q_len = query.size()[:2] kv_len = key.size(1) query = query.view(bsz, q_len, self.num_heads, self.head_size) @@ -392,31 +521,261 @@ class MultiHeadAttention(nn.Module): key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.XFORMERS: + if self.is_flash_attn_backend: + cu_seqlens_q = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device + ) + cu_seqlens_k = torch.arange( + 0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device + ) + + out = self._flash_attn_varlen_func( + query.flatten(0, 1), + key.flatten(0, 1), + value.flatten(0, 1), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_len, + max_seqlen_k=kv_len, + softmax_scale=self.scale, + ) + elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops - out = xops.memory_efficient_attention_forward(query, - key, - value, - scale=self.scale) + out = xops.memory_efficient_attention_forward( + query, key, value, scale=self.scale + ) elif self.attn_backend == _Backend.TORCH_SDPA: - query, key, value = (x.transpose(1, 2) - for x in (query, key, 
value)) - out = F.scaled_dot_product_attention(query, - key, - value, - scale=self.scale) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS_VLLM_V1: - query, key, value = (x.transpose(1, 2) - for x in (query, key, value)) + elif self.attn_backend == _Backend.PALLAS: + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention + out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) + else: + # ViT attention hasn't supported this backend yet + raise NotImplementedError( + f"ViT attention hasn't supported {self.attn_backend} backend yet." + ) return out.reshape(bsz, q_len, -1) +class MLAAttention(nn.Module, AttentionLayerBase): + """Multi-Head Latent Attention layer. + + This class takes query, and compressed key/value tensors as input. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + kv_b_proj: ColumnParallelLinear, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_sparse: bool = False, + indexer: Optional[object] = None, + ): + super().__init__() + self.num_heads = num_heads + self.scale = scale + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.head_size = kv_lora_rank + qk_rope_head_dim + self.layer_name = prefix + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + calculate_kv_scales = False + self.kv_cache_dtype = kv_cache_dtype + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + self.head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=True, + use_sparse=use_sparse, + ) + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) + self.impl = impl_cls( + num_heads=self.num_heads, + head_size=self.head_size, + scale=self.scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=self.kv_cache_dtype, + logits_soft_cap=None, + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=kv_b_proj, + indexer=indexer, + ) + + self.use_direct_call = not current_platform.opaque_attention_op() + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + self.kv_cache = [ + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) + ] + + # 
Align with Attention's scale attributes for MLA backends. + + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # Host-side mirrors used by some attention backends + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + self._o_scale_float: Optional[float] = None + + self.use_sparse = use_sparse + + # Initialize q/k/v range constants. + try: + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + except torch.cuda.OutOfMemoryError: + # Keep defaults if allocation fails; not critical for init. + pass + + def forward( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + # Mirror Attention.forward scale calculation path + if self.calculate_kv_scales and getattr( + attn_metadata, "enable_kv_scales_calculation", False + ): + self.calc_kv_scales(q, kv_c_normed, k_pe) + + if self.attn_backend.accept_output_buffer: + output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata, + output=output, + ) + return output + else: + return self.impl.forward( + self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata + ) + else: + if self.attn_backend.accept_output_buffer: + output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) + torch.ops.vllm.unified_mla_attention_with_output( + q, + kv_c_normed, + k_pe, + output, + self.layer_name, + ) + return output + else: + # We can still access forward context to check calculation flag + if self.calculate_kv_scales: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + if getattr(attn_metadata, "enable_kv_scales_calculation", False): + self.calc_kv_scales(q, kv_c_normed, k_pe) + return torch.ops.vllm.unified_mla_attention( + q, + kv_c_normed, + k_pe, + self.layer_name, + ) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + def calc_kv_scales( + self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor + ) -> None: + """Optional scale calculation for MLA inputs. + + Mirrors Attention.calc_kv_scales. 
Not all MLA backends require this + """ + # Use safe defaults if ranges are not present + q_range = getattr(self, "q_range", torch.tensor(1.0)) + k_range = getattr(self, "k_range", torch.tensor(1.0)) + v_range = getattr(self, "v_range", torch.tensor(1.0)) + + self._q_scale.copy_(torch.abs(q).max() / q_range) + # kv_c_normed is the compressed KV representation; use it for k/v + kv_abs_max = torch.abs(kv_c_normed).max() + self._k_scale.copy_(kv_abs_max / k_range) + self._v_scale.copy_(kv_abs_max / v_range) + self._q_scale_float = self._q_scale.item() + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + self.calculate_kv_scales = False + + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + + def wait_for_kv_layer_from_connector(layer_name: str): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -433,7 +792,7 @@ def wait_for_kv_layer_from_connector(layer_name: str): def maybe_save_kv_layer_to_connector( layer_name: str, - kv_cache_layer: List[torch.Tensor], + kv_cache_layer: list[torch.Tensor], ): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -445,8 +804,45 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata[layer_name]) + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name]) + + +def maybe_calc_kv_scales( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + + if attn_metadata is None or not getattr( + attn_metadata, "enable_kv_scales_calculation", False + ): + return + + self = forward_context.no_compile_layers[layer_name] + self.calc_kv_scales(query, key, value) + + +def maybe_calc_kv_scales_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="maybe_calc_kv_scales", + op_func=maybe_calc_kv_scales, + mutates_args=["query", "key", "value"], + fake_impl=maybe_calc_kv_scales_fake, +) def unified_attention( @@ -463,8 +859,7 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - output = self.impl.forward(self, query, key, value, kv_cache, - attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output @@ -482,9 +877,8 @@ def unified_attention_fake( direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, - mutates_args=[], fake_impl=unified_attention_fake, - dispatch_key=current_platform.dispatch_key, + tags=tag_cudagraph_unsafe, ) @@ -504,15 +898,17 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + 
output_block_scale=output_block_scale, + ) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -534,5 +930,95 @@ direct_register_custom_op( op_func=unified_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, + tags=tag_cudagraph_unsafe, +) + + +def unified_mla_attention( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + +def unified_mla_attention_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(q).contiguous() + + +direct_register_custom_op( + op_name="unified_mla_attention", + op_func=unified_mla_attention, + mutates_args=[], + fake_impl=unified_mla_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_mla_attention_with_output( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +def unified_mla_attention_with_output_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_mla_attention_with_output", + op_func=unified_mla_attention_with_output, + mutates_args=["output", "output_block_scale"], + fake_impl=unified_mla_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 087c5004bde06..3d37e901605f9 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import List, Optional +from typing import ClassVar, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.selector import get_attn_backend -from vllm.config import 
CacheConfig, QuantizationConfig +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend) + AttentionCGSupport, + CommonAttentionMetadata, + make_local_attention_virtual_batches, + subclass_attention_backend, +) from ..layer import Attention @@ -28,37 +31,42 @@ def create_chunked_local_attention_backend( underlying_builder = underlying_attn_backend.get_builder_cls() class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: common_attn_metadata = make_local_attention_virtual_batches( - attention_chunk_size, common_attn_metadata, block_size) - return super().build(common_prefix_len, common_attn_metadata, - fast_build) + attention_chunk_size, common_attn_metadata, block_size + ) + return super().build(common_prefix_len, common_attn_metadata, fast_build) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=ChunkedLocalAttentionBuilder) + builder_cls=ChunkedLocalAttentionBuilder, + ) return attn_backend class ChunkedLocalAttention(Attention): - - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - attention_chunk_size: int, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - kv_sharing_target_layer_name: Optional[str] = None, - prefix: str = ""): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + attention_chunk_size: int, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[list[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + kv_sharing_target_layer_name: Optional[str] = None, + prefix: str = "", + ): dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype @@ -68,12 +76,13 @@ class ChunkedLocalAttention(Attention): block_size = 16 if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend(head_size, dtype, - kv_cache_dtype, - block_size) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size) + underlying_attn_backend, attention_chunk_size, block_size + ) else: # in v0 the local attention is handled inside the backends attn_backend = None @@ -88,4 +97,5 @@ class ChunkedLocalAttention(Attention): quant_config=quant_config, prefix=prefix, kv_sharing_target_layer_name=kv_sharing_target_layer_name, - attn_backend=attn_backend) + attn_backend=attn_backend, + ) diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py new file mode 100644 index 0000000000000..fb7004f86538f --- /dev/null +++ b/vllm/attention/layers/cross_attention.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
+import functools +from copy import copy +from typing import Optional + +import numpy as np +import torch + +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) +from vllm.v1.kv_cache_interface import CrossAttentionSpec + +logger = init_logger(__name__) + + +def _get_max_encoder_len(vllm_config: "VllmConfig") -> int: + """Gets the max number of encoder input tokens from the config.""" + sc = vllm_config.scheduler_config + assert sc and isinstance(sc.max_num_encoder_input_tokens, int), ( + "max_num_encoder_input_tokens must be int for enc-dec models" + ) + return sc.max_num_encoder_input_tokens + + +def _get_cross_slot_mapping( + encoder_seq_lens: np.ndarray, + block_table_tensor: torch.Tensor, + kv_cache_spec: CrossAttentionSpec, + device: torch.device, +) -> torch.Tensor: + """Get cross-attention slot mappings.""" + + block_size = kv_cache_spec.block_size + slot_mappings = [] + + # Find indices with non-zero encoder sequence lengths + # The majority of parallel requests will be running the + # decoder, so this list should be relatively small. + active_indices = np.nonzero(encoder_seq_lens)[0] + + for req_index in active_indices: + encoder_seq_len = encoder_seq_lens[req_index].item() + + # Calculate the number of blocks needed for this request + num_blocks_needed = cdiv(encoder_seq_len, block_size) + + # Get the block IDs for this request from the tensor + req_block_ids = block_table_tensor[req_index] + + # Get only the blocks we need (first num_blocks_needed blocks) + needed_block_ids = req_block_ids[:num_blocks_needed] + + # All needed blocks are allocated + i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device) + block_indices = i_values // block_size + block_offsets = i_values % block_size + block_numbers = needed_block_ids[block_indices] + slot_mapping = block_numbers * block_size + block_offsets + + slot_mappings.append(slot_mapping) + + if slot_mappings: + return torch.cat(slot_mappings) + else: + return torch.empty(0, dtype=torch.int64, device=device) + + +@functools.lru_cache +def create_cross_attention_backend( + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: + prefix = "CrossAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class CrossAttentionBuilder(underlying_builder): # type: ignore + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_metadata = copy(common_attn_metadata) + new_metadata.causal = False + max_encoder_len = _get_max_encoder_len(self.vllm_config) + new_metadata.max_seq_len = max_encoder_len + + new_metadata.seq_lens = torch.full( + (new_metadata.num_reqs,), + max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + new_metadata.seq_lens_cpu = torch.full( + (new_metadata.num_reqs,), + max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + new_metadata.slot_mapping = _get_cross_slot_mapping( + new_metadata.encoder_seq_lens, + new_metadata.block_table_tensor, + self.kv_cache_spec, + self.device, + ) + return super().build(common_prefix_len, new_metadata, fast_build) + + attn_backend 
= subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=CrossAttentionBuilder, + ) + + return attn_backend + + +class CrossAttention(Attention): + """ + Cross-attention for encoder-decoder models. + Handles attention between decoder queries and encoder keys/values. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs, + ): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + + attn_backend = create_cross_attention_backend(underlying_attn_backend) + else: + # in v0 cross attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_DECODER, ( + "CrossAttention only supports AttentionType.ENCODER_DECODER" + ) + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_DECODER, + **kwargs, + ) diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index cea05df5b96d2..f49f195563dca 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -7,36 +7,45 @@ from typing import Optional import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - subclass_attention_backend) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) @functools.lru_cache def create_encoder_only_attention_backend( - underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: prefix = "EncoderOnlyAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: new_common_attn_metadata = copy(common_attn_metadata) new_common_attn_metadata.causal = False - return super().build(common_prefix_len, new_common_attn_metadata, - fast_build) + return super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=EncoderOnlyAttentionBuilder) + builder_cls=EncoderOnlyAttentionBuilder, + ) return attn_backend @@ -46,13 +55,15 @@ class EncoderOnlyAttention(Attention): Encoder attention is a special case that doesn't need a KV Cache. 
""" - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - cache_config: Optional[CacheConfig] = None, - attn_type: Optional[str] = None, - **kwargs): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs, + ): dtype = torch.get_default_dtype() if cache_config is not None: @@ -63,24 +74,28 @@ class EncoderOnlyAttention(Attention): block_size = 16 if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend(head_size, dtype, - kv_cache_dtype, - block_size) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) attn_backend = create_encoder_only_attention_backend( - underlying_attn_backend) + underlying_attn_backend + ) else: # in v0 encoder only attention is handled inside the backends attn_backend = None if attn_type is not None: - assert attn_type == AttentionType.ENCODER_ONLY, \ + assert attn_type == AttentionType.ENCODER_ONLY, ( "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" + ) - super().__init__(num_heads=num_heads, - head_size=head_size, - scale=scale, - cache_config=cache_config, - attn_backend=attn_backend, - attn_type=AttentionType.ENCODER_ONLY, - **kwargs) + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs, + ) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index e5b90a8b27558..aa791fe970063 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.jit def cdiv_fn(x, y): @@ -23,69 +25,73 @@ def cdiv_fn(x, y): @triton.jit def kernel_paged_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - num_queries_per_kv_padded: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - x: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.int64, # int - stride_k_cache_4: tl.int64, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.int64, # int - filter_by_query_len: tl.constexpr, # bool - query_start_len_ptr, # [num_seqs+1] - USE_SINKS: tl.constexpr, # bool + output_ptr, # 
[num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale_inv, + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] + USE_SINKS: tl.constexpr, # bool + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) if filter_by_query_len: cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + - 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if cur_batch_query_len > 1: return else: cur_batch_in_all_start_index = seq_idx query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( - 0, num_queries_per_kv_padded) + 0, num_queries_per_kv_padded + ) - query_offset = (cur_batch_in_all_start_index * query_stride_0 + - query_head_idx[:, None] * query_stride_1) + query_offset = ( + cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1 + ) head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv head_mask = head_mask & (query_head_idx < num_query_heads) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) # Q : (num_queries_per_kv, HEAD_SIZE,) Q = tl.load( @@ -97,9 +103,7 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride if not USE_SINKS: - M = tl.full([num_queries_per_kv_padded], - float("-inf"), - dtype=tl.float32) + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) else: M = tl.load( sink_ptr + query_head_idx, @@ -108,43 +112,43 @@ def kernel_paged_attention_2d( ).to(dtype=tl.float32) L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) - acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], - dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) # sequence len for this particular sequence 
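One detail of the kernel above worth spelling out is the grouped-query mapping: each Triton program handles one KV head, and query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(0, num_queries_per_kv_padded) enumerates the query heads that share it, with head_mask discarding the padded rows. A small worked example with made-up head counts:

# Hypothetical sizes: 32 query heads sharing 8 KV heads (grouped-query attention).
num_query_heads = 32
num_kv_heads = 8
num_queries_per_kv = num_query_heads // num_kv_heads  # 4 query heads per KV head

kv_head_idx = 2  # the program launched for KV head 2
query_heads = [kv_head_idx * num_queries_per_kv + i for i in range(num_queries_per_kv)]
assert query_heads == [8, 9, 10, 11]
# The launcher later pads num_queries_per_kv up to at least 16 (a power of two)
# for tl.arange; head_mask masks off the extra rows so they contribute nothing.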
seq_len = tl.load(seq_lens_ptr + seq_idx) # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, - mask=head_mask, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0 + ) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles for j in range(0, num_blocks): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) offs_n = tl.arange(0, BLOCK_SIZE) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_1 + - offs_d[None, :] * stride_v_cache_2 + - offs_n[:, None] * stride_v_cache_3) + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3 + ) - k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_1 + - (offs_d[:, None] // x) * stride_k_cache_2 + - offs_n[None, :] * stride_k_cache_3 + - (offs_d[:, None] % x) * stride_k_cache_4) + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4 + ) # K : (HEAD_SIZE, BLOCK_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], - other=0.0) + K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) if K_load.dtype.is_fp8(): K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) @@ -152,9 +156,7 @@ def kernel_paged_attention_2d( K = K_load # V : (BLOCK_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], - other=0.0) + V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) if V_load.dtype.is_fp8(): V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) @@ -166,15 +168,13 @@ def kernel_paged_attention_2d( seq_mask = seq_offset[None, :] < boundary # S : (num_queries_per_kv, BLOCK_SIZE,) - S = tl.where(head_mask[:, None] & seq_mask, 0.0, - float("-inf")).to(tl.float32) + S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32) S += scale * tl.dot(Q, K) context_len = seq_len - 1 if SLIDING_WINDOW > 0: - S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, - -10000) + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -204,13 +204,17 @@ def kernel_paged_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - output_offset = (cur_batch_in_all_start_index * output_stride_0 + - query_head_idx * output_stride_1) + output_offset = ( + cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1 + ) tl.store( - output_ptr + output_offset[:, None] + - tl.arange(0, HEAD_SIZE_PADDED)[None, :], + output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :], acc, mask=dim_mask[None, :] & head_mask[:, None], ) @@ -234,12 +238,12 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + output_scale=None, # Optional tensor for sinks sinks=None, ): - if sm_scale is None: - sm_scale = 1.0 / (query.shape[1]**0.5) + sm_scale = 1.0 / (query.shape[1] ** 0.5) use_alibi_slopes = alibi_slopes is not None @@ -266,6 +270,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, 
sm_scale=sm_scale, skip_decode=True, + fp8_out_scale=output_scale, sinks=sinks, ) @@ -292,10 +297,10 @@ def chunked_prefill_paged_decode( key_cache = key_cache.view(target_dtype) value_cache = value_cache.view(target_dtype) - num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), - 16) + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) from vllm.platforms.rocm import use_rocm_custom_paged_attention + use_custom = use_rocm_custom_paged_attention( query.dtype, head_size, @@ -309,14 +314,14 @@ def chunked_prefill_paged_decode( ) if use_custom: _PARTITION_SIZE_ROCM = 256 - max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM assert _PARTITION_SIZE_ROCM % block_size == 0 total_num_seq = block_table.shape[0] tmp_output = torch.empty( - size=(total_num_seq, num_query_heads, max_num_partitions, - head_size), - dtype=output.dtype, + size=(total_num_seq, num_query_heads, max_num_partitions, head_size), + dtype=query.dtype, device=output.device, ) exp_sums = torch.empty( @@ -345,12 +350,15 @@ def chunked_prefill_paged_decode( kv_cache_dtype=kv_cache_dtype, k_scale=k_scale, v_scale=v_scale, + fp8_out_scale=output_scale, ) else: - kernel_paged_attention_2d[( - num_seqs, - num_kv_heads, - )]( + kernel_paged_attention_2d[ + ( + num_seqs, + num_kv_heads, + ) + ]( output_ptr=output, query_ptr=query, key_cache_ptr=key_cache, @@ -362,6 +370,7 @@ def chunked_prefill_paged_decode( scale=sm_scale, k_scale=k_scale, v_scale=v_scale, + out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, num_queries_per_kv_padded=num_queries_per_kv_padded, @@ -388,4 +397,5 @@ def chunked_prefill_paged_decode( filter_by_query_len=True, query_start_len_ptr=query_start_loc, USE_SINKS=sinks is not None, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py new file mode 100644 index 0000000000000..097fbae68cda5 --- /dev/null +++ b/vllm/attention/ops/common.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.triton_utils import tl, triton + + +@triton.jit +def _correct_attn_cp_out_kernel( + outputs_ptr, + new_output_ptr, + lses_ptr, + vlse_ptr, + outputs_stride_B, + outputs_stride_H, + outputs_stride_D, + lses_stride_N, + lses_stride_B, + lses_stride_H, + lse_idx, + HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr, +): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. 
+ + Args: + outputs_ptr (triton.PointerType): + Pointer to input tensor of shape [ B, H, D ] + lses_ptr (triton.PointerType): + Pointer to input tensor of shape [ N, B, H ] + new_output_ptr (triton.PointerType): + Pointer to output tensor of shape [ B, H, D ] + vlse_ptr (triton.PointerType): + Pointer to output tensor of shape [ B, H ] + """ + batch_idx = tl.program_id(axis=0).to(tl.int64) + head_idx = tl.program_id(axis=1).to(tl.int64) + d_offsets = tl.arange(0, HEAD_DIM) + num_n_offsets = tl.arange(0, N_ROUNDED) + + # shape = [N] + lse_offsets = ( + num_n_offsets * lses_stride_N + + batch_idx * lses_stride_B + + head_idx * lses_stride_H + ) + + # calc final lse + lse = tl.load(lses_ptr + lse_offsets) + lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse) + lse_max = tl.max(lse, axis=0) + lse -= lse_max + lse_exp = tl.exp(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log(lse_acc) + lse += lse_max + + lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H + tl.store(vlse_ptr + lse_offsets, lse) + + # shape = [D] + output_offsets = ( + batch_idx * outputs_stride_B + + head_idx * outputs_stride_H + + d_offsets * outputs_stride_D + ) + + # correct output + lse_offset = ( + lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H + ) + lse_tmp = tl.load(lses_ptr + lse_offset) + lse_finally = lse_tmp - lse + lse_finally = tl.where( + (lse_finally != lse_finally) | (lse_finally == float("inf")), + -float("inf"), + lse_finally, + ) + factor = tl.exp(lse_finally) + output = tl.load(outputs_ptr + output_offsets) + output = output * factor + + tl.store(new_output_ptr + output_offsets, output) + + +class CPTritonContext: + """The CPTritonContext is used to avoid recompilation of the Triton JIT.""" + + def __init__(self): + self.inner_kernel = None + + def call_kernel(self, kernel, grid, *regular_args, **const_args): + if self.inner_kernel is None: + self.inner_kernel = kernel[grid](*regular_args, **const_args) + else: + self.inner_kernel[grid](*regular_args) + + +def correct_attn_out( + out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext +) -> tuple[torch.Tensor, torch.Tensor]: + """Correct the attention output using the all-gathered lses. + + Args: + out: Tensor of shape [ B, H, D ] + lses: Tensor of shape [ N, B, H ] + cp_rank: Current rank in the context-parallel group + ctx: Triton context to avoid recompilation + + Returns: + Tuple of (out, lse) with corrected attention and final log-sum-exp. 
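The kernel and correct_attn_out implement the standard log-sum-exp merge for partial attention results: each rank computed softmax only over its own KV shard, so its output must be rescaled by exp(lse_rank - lse_global) before the cross-rank reduction. A plain PyTorch reference of the same correction, for intuition only (the Triton kernel above is what actually runs):

import torch


def correct_attn_out_reference(out: torch.Tensor, lses: torch.Tensor, rank: int):
    # out:  [B, H, D]  this rank's partial attention output
    # lses: [N, B, H]  all-gathered per-rank log-sum-exp values
    # Treat NaN/inf entries as empty shards, mirroring the kernel's handling.
    lses = torch.where(
        torch.isnan(lses) | torch.isinf(lses),
        torch.full_like(lses, float("-inf")),
        lses,
    )
    lse = torch.logsumexp(lses, dim=0)  # global log-sum-exp, [B, H]
    # Rescale this rank's output from its local softmax denominator
    # exp(lse_rank) to the global one exp(lse); summing the rescaled
    # outputs over ranks then reproduces full attention.
    factor = torch.exp(lses[rank] - lse).unsqueeze(-1)  # [B, H, 1]
    return out * factor, lse

cp_lse_ag_out_rs below then all-gathers the per-rank LSEs, applies this correction, and reduce-scatters the corrected outputs across the context-parallel group.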
+ """ + if ctx is None: + ctx = CPTritonContext() + + lse = torch.empty_like(lses[0]) + + grid = (out.shape[0], out.shape[1], 1) + regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank) + const_args = { + "HEAD_DIM": out.shape[-1], + "N_ROUNDED": lses.shape[0], + } + + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) + return out, lse + + +def cp_lse_ag_out_rs( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + if cp_group.world_size == 1: + return cp_attn_out + + if ctx is None: + ctx = CPTritonContext() + + lses = torch.empty( + (cp_group.world_size,) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device, + ) + + cp_attn_lse = cp_attn_lse.contiguous() + lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + out = cp_group.reduce_scatter(out, dim=1) + return out + + +@triton.jit +def _pack_seq_kernel( + x_ptr, # [N, D] + out_ptr, # [B, Lmax, D] + lengths_ptr, # *i32, [B] + N: tl.constexpr, + D: tl.constexpr, + Lmax: tl.constexpr, + PAD_VALUE: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr, # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # Compute start index and sequence length from cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) + seq_len = tl.load(lengths_ptr + pid_b) + + # valid time positions for this block + t_mask = off_t < Lmax + + # compute input row indices for valid (b, t) + in_row = in_start + off_t + valid_row = (off_t < seq_len) & t_mask + + # Pointers + # x_ptr: row-major [N, D] + x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :] + + # out_ptr: row-major [B, Lmax, D] + out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # Initialize with PAD (cast will occur as needed based on out_ptr dtype) + d_mask = off_d[None, :] < D + pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32) + tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask) + + # Load & write only where within seq_len + x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) + + +def pack_seq_triton( + x: torch.Tensor, + lengths: torch.Tensor, + pad_value: float = -float("inf"), + block_t: int = 64, + block_d: int = 64, +) -> torch.Tensor: + """ + Pack sequences of different lengths into a batched tensor. + + Args: + x: [N, ...] - input tensor where N is total number of tokens + lengths: [B] - sequence lengths for each batch + pad_value: value to use for padding + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + packed: [B, Lmax, ...] 
- packed tensor + """ + + # Handle multi-dimensional input by reshaping to (N, -1) + original_shape = x.shape + if len(original_shape) > 2: + N = original_shape[0] + x_reshaped = x.reshape(N, -1) + D = x_reshaped.shape[1] + else: + N, D = x.shape + x_reshaped = x + + B = lengths.numel() + Lmax = int(lengths.max().item()) + + # Starts are computed inside the kernel from lengths + + out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _pack_seq_kernel[grid]( + x_reshaped, + out, + lengths.int(), + N, + D, + Lmax, + PAD_VALUE=float(pad_value), + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 2: + output_shape = (B, Lmax) + original_shape[1:] + out = out.reshape(output_shape) + + return out + + +@triton.jit +def _unpack_seq_triton_kernel( + packed_ptr, # [B, Lmax, D] + out_ptr, # [N, D] + lengths_ptr, # *i32, [B] + B: tl.constexpr, + Lmax: tl.constexpr, + D: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr, # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # bounds: compute start from cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) + seq_len = tl.load(lengths_ptr + pid_b) + + # valid time positions for this block + t_mask = off_t < Lmax + valid_row = (off_t < seq_len) & t_mask + + # compute output row indices for valid (b, t) + out_row = in_start + off_t + + # Pointers + # packed_ptr: row-major [B, Lmax, D] + packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # out_ptr: row-major [N, D] + out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :] + + # Load from packed tensor and store to output + d_mask = off_d[None, :] < D + packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) + + +def unpack_seq_triton( + packed_tensor: torch.Tensor, + lengths: torch.Tensor, + block_t: int = 64, + block_d: int = 64, +) -> torch.Tensor: + """ + Unpack a packed decode query tensor back to the original format. + Efficient Triton implementation. + + Args: + packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton + lengths: [B] - sequence lengths for each batch + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + unpacked_tensor: [N, ...] 
where N = sum(lengths) + """ + + # Handle multi-dimensional input by reshaping to (B, Lmax, -1) + original_shape = packed_tensor.shape + if len(original_shape) > 3: + B, Lmax = original_shape[:2] + packed_reshaped = packed_tensor.reshape(B, Lmax, -1) + D = packed_reshaped.shape[2] + else: + B, Lmax, D = packed_tensor.shape + packed_reshaped = packed_tensor + + # Calculate total number of elements + N = int(lengths.sum().item()) + + out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _unpack_seq_triton_kernel[grid]( + packed_reshaped, + out, + lengths.int(), + B, + Lmax, + D, + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 3: + output_shape = (N,) + original_shape[2:] + out = out.reshape(output_shape) + + return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 564042cf8eb12..0bf354a95b1ca 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py -from typing import Optional, Tuple +from typing import Optional import torch @@ -13,48 +13,104 @@ logger = init_logger(__name__) if current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True except ImportError: _flashmla_C_AVAILABLE = False else: _flashmla_C_AVAILABLE = False +if current_platform.is_cuda(): + try: + import vllm._flashmla_extension_C # noqa: F401 -def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + _flashmla_extension_C_AVAILABLE = True + except ImportError: + _flashmla_extension_C_AVAILABLE = False +else: + _flashmla_extension_C_AVAILABLE = False + + +def _is_flashmla_available() -> tuple[bool, Optional[str]]: + if not _flashmla_C_AVAILABLE: + return ( + False, + "vllm._flashmla_C is not available, likely was not " + "compiled due to insufficient nvcc version or a supported arch " + "was not in the list of target arches to compile for.", + ) + if not _flashmla_extension_C_AVAILABLE: + return ( + False, + "vllm._flashmla_extension_C is not available, likely " + "was not compiled due to a build error.", + ) + + return True, None + + +def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]: """ Return: is_supported_flag, unsupported_reason (optional). """ - if not current_platform.is_cuda(): - return False, "FlashMLA is only supported on CUDA devices." + is_availble, maybe_reason = _is_flashmla_available() + if not is_availble: + return False, maybe_reason if current_platform.get_device_capability()[0] != 9: - return False, "FlashMLA is only supported on Hopper devices." - if not _flashmla_C_AVAILABLE: - return False, "vllm._flashmla_C is not available, likely was not "\ - "compiled due to insufficient nvcc version or a supported arch "\ - "(only sm90a currently) was not in the list of target arches to "\ - "compile for." + return False, "FlashMLA Dense is only supported on Hopper devices." + return True, None + + +def is_flashmla_sparse_supported() -> tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). 
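Putting the pieces of this flashmla.py hunk together, a decode-time call site would first check availability, then build the tile-scheduler metadata, and finally run the paged MLA kernel. The sketch below follows the argument names and ordering given in the docstrings of this hunk (some of which appear just below); the concrete sizes, dtypes, and the zero-filled block table are made up purely for illustration:

import torch

from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
    is_flashmla_dense_supported,
)

ok, reason = is_flashmla_dense_supported()
if not ok:
    raise RuntimeError(f"FlashMLA dense path unavailable: {reason}")

# Hypothetical MLA decode shapes: one query token per sequence, a single
# latent KV head, head_dim = kv_lora_rank (512) + rope dim (64).
batch, seq_len_q, num_heads_q, num_heads_k = 4, 1, 16, 1
head_dim, head_dim_v = 576, 512
num_blocks, page_block_size = 64, 64

q = torch.randn(batch, seq_len_q, num_heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      dtype=torch.bfloat16, device="cuda")
block_table = torch.zeros(batch, 8, dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([100, 37, 500, 8], dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,  # num_q_tokens_per_head_k
    num_heads_k,
)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits,
    causal=True,
)
# out: (batch, seq_len_q, num_heads_q, head_dim_v)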
+ """ + is_availble, maybe_reason = _is_flashmla_available() + if not is_availble: + return False, maybe_reason + if current_platform.get_device_capability()[0] not in (9, 10): + return ( + False, + "FlashMLA Sparse is only supported on Hopper and Blackwell devices.", + ) return True, None def get_mla_metadata( cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, + num_q_tokens_per_head_k: int, num_heads_k: int, -) -> Tuple[torch.Tensor, torch.Tensor]: + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. + - cache_seqlens: (batch_size), dtype torch.int32. + - num_q_tokens_per_head_k: + Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + - num_heads_k: The number of k heads. + - num_heads_q: + The number of q heads. + This argument is optional when sparse attention is not enabled + - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + - topk: If not None, sparse attention will be enabled, + and only tokens in the `indices` array + passed to `flash_mla_with_kvcache_sm90` will be attended to. - Return: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. + Returns: + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + - num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, - num_heads_per_head_k, - num_heads_k) + return torch.ops._flashmla_C.get_mla_decoding_metadata( + cache_seqlens, + num_q_tokens_per_head_k, + num_heads_k, + num_heads_q, + is_fp8_kvcache, + topk, + ) def flash_mla_with_kvcache( @@ -69,45 +125,114 @@ def flash_mla_with_kvcache( causal: bool = False, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - torch.int32, return by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - descale_q: (batch_size), torch.float32. Descaling factors for Q. - descale_k: (batch_size), torch.float32. Descaling factors for K. + - q: (batch_size, seq_len_q, num_heads_q, head_dim). + - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + - cache_seqlens: (batch_size), torch.int32. + - head_dim_v: Head dimension of v. + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, + returned by get_mla_metadata. + - num_splits: + (batch_size + 1), torch.int32, returned by get_mla_metadata. + - softmax_scale: float. + The scale of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + - causal: bool. 
Whether to apply causal attention mask. + - descale_q: (batch_size), + torch.float32. Descaling factors for Q, used for fp8 quantization. + - descale_k: (batch_size), + torch.float32. Descaling factors for K, used for fp8 quantization. + - is_fp8_kvcache: bool. + Whether the k_cache and v_cache are in fp8 format. + For the format of the FP8 KV cache, please refer to README.md. + - indices: (batch_size, seq_len_q, topk), torch.int32. + If not None, sparse attention will be enabled, + and only tokens in the `indices` array will be attended to. + Invalid indices should be set to -1 or numbers >= total_seq_len_kv. + For details about how to set up `indices`, please refer to README.md. - Return: - out: (batch_size, seq_len_q, num_heads_q, head_dim_v). - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + Returns: + - out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) - out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( - q, - k_cache, - head_dim_v, - cache_seqlens, - block_table, - softmax_scale, - causal, - tile_scheduler_metadata, - num_splits, - descale_q, - descale_k, + softmax_scale = q.shape[-1] ** (-0.5) + if indices is not None: + # NOTE (zyongye): sparse attention is also causal, + # since it only attends to the tokens before, + # but here `causal` should not be specified + assert not causal, "causal must be `False` if sparse attention is enabled." + assert (descale_q is None) == (descale_k is None), ( + "descale_q and descale_k should be both None or both not None" ) + + if indices is None and q.element_size() == 1: + out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8( + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + descale_q, + descale_k, + ) + else: + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + is_fp8_kvcache, + indices, + ) return out, softmax_lse +def flash_mla_sparse_prefill( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel. + + Args: + - q: [s_q, h_q, d_qk], bfloat16 + - kv: [s_kv, h_kv, d_qk], bfloat16 + - indices: [s_q, h_kv, topk], int32. + Invalid indices should be set to -1 or numbers >= s_kv. + - sm_scale: float + - d_v: The dimension of value vectors. Can only be 512. + + Returns: + - (output, max_logits, lse) + For the definitions of output, + max_logits, and lse, please refer to README.md. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v) + return results + + # # TODO: Add fake functions # diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py index 5cb1a47394cf6..79800eb40766c 100644 --- a/vllm/attention/ops/merge_attn_states.py +++ b/vllm/attention/ops/merge_attn_states.py @@ -15,7 +15,6 @@ def merge_attn_states( suffix_lse: torch.Tensor, output_lse: Optional[torch.Tensor] = None, ) -> None: - # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel # is not support for FP8 dtype, fallback to use Triton kernel.
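# ---------------------------------------------------------------------------
# Illustrative usage sketch for the FlashMLA hunks above (not part of the
# patch): it shows how the extended get_mla_metadata() and
# flash_mla_with_kvcache() signatures fit together on the sparse decode path.
# The import path, the page block size of 64 and the head dims 576 / 512 are
# assumptions made for the example, not taken from this diff, and running it
# requires a Hopper or Blackwell GPU with the _flashmla_C extension built.
import torch

from vllm.attention.ops.flashmla import (  # assumed module path
    flash_mla_with_kvcache,
    get_mla_metadata,
)

batch, seq_q, h_q, h_k, d, d_v, topk = 2, 1, 16, 1, 576, 512, 64
q = torch.randn(batch, seq_q, h_q, d, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(16, 64, h_k, d, dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(16, dtype=torch.int32, device="cuda").view(batch, 8)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")
# topk token positions to attend to per query token; invalid slots would be -1
indices = torch.arange(topk, dtype=torch.int32, device="cuda").repeat(batch, seq_q, 1)

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_q * h_q // h_k,  # num_q_tokens_per_head_k
    h_k,
    num_heads_q=h_q,
    is_fp8_kvcache=False,
    topk=topk,
)
out, lse = flash_mla_with_kvcache(
    q,
    k_cache,
    block_table,
    cache_seqlens,
    head_dim_v=d_v,
    tile_scheduler_metadata=tile_scheduler_metadata,
    num_splits=num_splits,
    indices=indices,  # causal is left False, as required when indices is set
)
# ---------------------------------------------------------------------------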
def supported_dtypes(o: torch.Tensor) -> bool: @@ -31,13 +30,19 @@ def merge_attn_states( return headdim % 4 == 0 return headdim % 8 == 0 - if (current_platform.is_cuda() and supported_dtypes(output) - and supported_headdim(output)): + if ( + current_platform.is_cuda() + and supported_dtypes(output) + and supported_headdim(output) + ): from vllm._custom_ops import merge_attn_states - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) else: - from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states) - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py deleted file mode 100644 index 29fa432017616..0000000000000 --- a/vllm/attention/ops/nki_flash_attn.py +++ /dev/null @@ -1,903 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.isa as nisa -import neuronxcc.nki.language as nl -import numpy as np -import torch -from neuronxcc import nki -from neuronxcc.nki.language import par_dim - -from vllm.utils import cdiv - - -def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 - - -@nki.jit -def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): - """ - Load block tables from HBM into SRAM - - `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. - In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. - """ - B_P_SIZE = 128 - - # reshape as `(num_tiles, num_blocks_per_tile)` - assert len(block_tables_hbm.shape) == 1 - (num_total_blocks, ) = block_tables_hbm.shape - assert num_blocks_per_tile * num_tiles == num_total_blocks - block_tables_hbm = block_tables_hbm.reshape( - (num_tiles, num_blocks_per_tile)) - - block_tables_sbuf = nl.zeros( - (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), - dtype=nl.int32, - ) - for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(num_blocks_per_tile)[None, :] - block_tables_sbuf[i, i_p, i_f] = nl.load( - block_tables_hbm[i_p + i * B_P_SIZE, i_f], - dtype=nl.int32, - mask=(i_p + i * B_P_SIZE < num_tiles), - ) - return block_tables_sbuf - - -@nki.jit -def transform_block_tables_for_indirect_load( - block_tables, - block_size_tiling_factor, - num_head, - head_id, -): - """ - This function does two things: - 1. calculate new `block_tables` for a `head_id` after flattening - `num_block`, `num_head`, and `block_size_tiling_factor` dimensions - 2. transpose the result so that `block_table` for each tile is mapped to - SBUF Partition dimension for vectorized DMA - - Tiling trick to further improve DMA performance: - Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M - blocks of a given `head_id` from HBM, the load `cache[block_tables, - head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not - fully utilize hardware parallelization. The solution is to tile `block_size` - into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * - block_size_tiling_factor = B_P_SIZE`. 
After tiling, KV cache has shape - `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. - - Note: - We don't further tile D dimension as small DMA size also hurts performance. - """ - B_P_SIZE = 128 - num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( - block_tables.shape) - assert num_tiles_per_partition == B_P_SIZE - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" - - num_loads = cdiv(num_blocks_per_tile, B_P_SIZE) - block_tables_transposed = nl.ndarray( - ( - num_loads, - par_dim(B_P_SIZE), - num_partitions * num_tiles_per_partition, - ), - dtype=nl.int32, - ) - - # prepare iota ahead of time to avoid repeatedly using Gpsimd - if num_head > 1: - head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) - head_id = nl.transpose( - head_id.broadcast_to((1, num_tiles_per_partition))) - if num_blocks_per_tile > 1: - head_id = head_id.broadcast_to( - (num_tiles_per_partition, num_blocks_per_tile)) - - if block_size_tiling_factor > 1: - broadcast_shape = ( - num_tiles_per_partition, - num_blocks_per_tile, - block_size_tiling_factor, - ) - offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], - dtype=nl.int32).broadcast_to(broadcast_shape) - - for partition_id in nl.affine_range(num_partitions): - block_tables_partition = block_tables[partition_id] - if num_head > 1: - # fuse num_block and num_head dimension - block_tables_partition = block_tables_partition * num_head + head_id - - if block_size_tiling_factor > 1: - # need to apply block size tiling trick - assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE - block_tables_partition = ((block_tables_partition * - block_size_tiling_factor).reshape( - (num_tiles_per_partition, - num_blocks_per_tile, - 1)).broadcast_to(broadcast_shape)) - new_block_tables = block_tables_partition + offset - new_block_tables = new_block_tables.reshape( - (num_tiles_per_partition, B_P_SIZE)) - else: - new_block_tables = block_tables_partition - - # transpose the block table so that it can be used by vector DGE - for i in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = (partition_id * num_tiles_per_partition + - nl.arange(num_tiles_per_partition)[None, :]) - block_tables_transposed[i, i_p, i_f] = nl.transpose( - new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) - return block_tables_transposed - - -@nki.jit -def load_kv_tile_from_cache( - cur_k_tile, - cur_v_tile, - kv_cache, - block_tables, - large_k_tile_idx, - num_blocks_per_large_tile, - tiled_block_size, - B_P_SIZE, - B_D_SIZE, -): - """ - Load KV cache and transform Key and Value into layout required by Matmul - - Vectorized DMA Load layout: - Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - - Layout used by attention matmuls: - Key: (par_dim(B_D_SIZE), seqlen_kv) - Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) - equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - """ - # load key cache - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - for load_idx in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_k_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) - # Transpose SBUF tensor using PE - for tb_i in nl.affine_range(tiled_block_size): - cur_k_tile[ - :, - nl.ds( - load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, - 
B_P_SIZE, - ), - ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) - - # load value cache - for load_idx in nl.affine_range(num_loads): - loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - cur_v_tile[ - :, - nl.ds( - load_idx * tiled_block_size * B_D_SIZE, - tiled_block_size * B_D_SIZE, - ), - ] = loaded - - -@nki.jit -def transpose_p_local(p_local_transposed, - p_local, - LARGE_TILE_SZ, - B_F_SIZE=512): - for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.sbuf, - dtype=p_local.dtype) - else: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.psum, - dtype=np.float32) - - for j in nl.affine_range(B_F_SIZE // 128): - j_128_slice = nl.ds(j * 128, 128) - i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) - - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( - p_local[:, i_j_128_slice]) - else: - p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( - p_local[:, i_j_128_slice]) - - p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( - p_local_t_tmp, dtype=p_local_transposed.dtype) - - -@nki.jit -def _flash_attention_core( - q_local_tile, - k, - v, - o_buffer, - l_buffer, - m_buffer, - kernel_dtype, - acc_type, - tile_mask, - use_causal_mask, - q_tile_idx=None, - initialize=False, - LARGE_TILE_SZ=2048, - B_P_SIZE=128, - B_F_SIZE=512, - B_D_SIZE=128, - qk_res_buffer=None, -): - """ - The flash attention core function to calculate self attention between a tile - of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_D_SIZE) - The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will - be split into size B_F_SIZE tiles - - The results are stored in the following three buffers - o_buffer: (B_P_SIZE, d) - l_buffer: (B_P_SIZE, 1) - m_buffer: (B_P_SIZE, 1) - - All IO buffers are in SBUF. 
- """ - num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - buffer=nl.sbuf, - dtype=acc_type) - max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), - dtype=acc_type) - for k_i in nl.affine_range(num_k_tile_per_large_tile): - k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) - - if use_causal_mask: - # mask are used to only apply computation to the lower half of the - # matrix, which reduce the arithmetic intensity by up to 50% - multiplication_required_selection = (q_tile_idx * B_P_SIZE - >= k_i * B_F_SIZE) - else: - multiplication_required_selection = True - - if multiplication_required_selection: - qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), - dtype=np.float32, - buffer=nl.psum) # (128, 512) - qk_psum[:, :] = nl.matmul(q_local_tile, - k[:, k_i_b_f_slice], - transpose_x=True) # (p(128), 512) - qk_res_buf[:, k_i_b_f_slice] = nl.where( - tile_mask[:, k_i_b_f_slice], - qk_psum[:, nl.ds(0, B_F_SIZE)], - -9984.0, - dtype=acc_type, - ) - else: - qk_res_buf[:, k_i_b_f_slice] = -9984.0 - - # Calculate max of the current tile - max_local[:, k_i] = nisa.tensor_reduce( - np.max, - qk_res_buf[:, k_i_b_f_slice], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - if qk_res_buffer is not None: - qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) - - max_ = nisa.tensor_reduce( - np.max, - max_local[:, :], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), - dtype=o_buffer.dtype) - - if initialize: - m_buffer[:, 0] = nl.copy(max_) - m_current = max_ - else: - m_previous = nl.copy(m_buffer[:, 0]) - m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) - - m_current = m_buffer[:, 0] - # Compute scaling factor - alpha = nisa.activation( - np.exp, - m_previous, - bias=-1 * m_current, - scale=1.0, - ) - o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) - - p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) - - p_partial_sum = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), - dtype=acc_type, - ) - - for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): - k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) - - # compute exp(qk - max) - # Compute partial row - tile sum of exp(qk - max)) - # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
- p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( - np.exp, - qk_res_buf[:, k_r_i_reduce_slice], - bias=-1 * m_current, - scale=1.0, - reduce_op=nl.add, - reduce_res=p_partial_sum[:, k_r_i], - dtype=kernel_dtype, - ) - - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) - - p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - transpose_p_local( - p_local_transposed=p_local_transposed, - p_local=p_local, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_F_SIZE=B_F_SIZE, - ) - - pv_psum = nl.zeros( - (par_dim(B_P_SIZE), B_D_SIZE), - dtype=np.float32, - buffer=nl.psum, - ) - for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): - pv_psum[:, :] += nl.matmul( - p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], - v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], - transpose_x=True, - ) # (128, 128) (p(Br), d) - - if initialize: - o_buffer[:, :] = nl.copy(pv_psum[:, :]) - l_buffer[:, 0] = nl.add(nl.log(ps), max_) - else: - o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) - - l_prev = l_buffer[:, 0] - l_exp = nl.add( - nl.exp(nl.subtract(l_prev, m_current)), - ps, - ) - l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) - - -@nki.jit -def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): - B_P_SIZE = 128 - B_D_SIZE = v_hbm_tile.shape[-1] - loaded = nl.load(v_hbm_tile[ - nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), - :, - ]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded - - -@nki.jit -def flash_paged_attention( - query, - key, - value, - kv_cache, - block_tables, - mask, - softmax_scale=None, - mixed_precision=True, - LARGE_TILE_SZ=2048, - return_debug_tensors=False, -): - """ - Flash PagedAttention Forward Kernel. - - IO tensor layouts: - - query: shape (1, n_heads, d, seq_q) - - key: shape (1, n_kv_heads, d, seq_k) - - value: shape (1, n_kv_heads, seq_v, d) - - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) - - block_tables: (num_active_blocks, ) - - mask: (seq_q, num_active_blocks * block_size + seq_q) - - o: shape (1, n_heads, seq_q, d) - - - This kernel requires seq_k == seq_v - - We use continuous batching by default, so the batch dimension is - always 1, and different requests are concatenated along sequence - dimension. - - We use paged cache blocks (kv_cache) to store KV cache. - - IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype except for - block_tables (int32) and mask (int32) - - If mixed_precision is True, then all Tensor Engine operation will be - performed in bfloat16 and accumulation will be performed in float32. - Otherwise the intermediates will be in the same type as the inputs. 
- - Compile-time Constants: - - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - - mixed_precision: flag to set non-matmul ops in fp32 precision, default - is set to `true`, if false, we use same precision as input types - - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention - computation reduction - - GQA support Notes: - the spmd kernel for launching kernel should be on kv_heads instead of - nheads - - Example usage: - MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] - usage: `flash_fwd[b, h](q, k, v, ...)` - GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] - usage: `flash_fwd[b, kv_h](q, k, v, ...)` - """ - B_F_SIZE = 512 - B_P_SIZE = 128 - b, h, d, seqlen_q = query.shape - B_D_SIZE = d - n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine - _, num_blocks, k_h, block_size, _ = kv_cache.shape - q_h_per_k_h = h // k_h - assert b == 1, f"invalid batch size {b=}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" - cache_shape = (2, num_blocks, k_h, block_size, d) - assert (tuple(kv_cache.shape) == cache_shape - ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" - assert key is None or tuple(key.shape) == ( - 1, - k_h, - d, - seqlen_q, - ), f"key shape {key.shape} mismatch!" - assert value is None or tuple(value.shape) == ( - 1, - k_h, - seqlen_q, - d, - ), f"value shape {value.shape} mismatch!" - - assert ( - nl.program_ndim() == 2 - ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - - (num_active_blocks, ) = block_tables.shape - context_kv_len = num_active_blocks * block_size - assert ( - LARGE_TILE_SZ % B_F_SIZE == 0 - ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" - assert (context_kv_len % LARGE_TILE_SZ == 0 - ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" - - num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert is_power_of_2( - num_blocks_per_large_tile - ), f"{num_blocks_per_large_tile=} is expected of be power of 2" - if seqlen_q > B_F_SIZE: - MAX_REDUCTION_TILE = 2048 - if seqlen_q // 2 > MAX_REDUCTION_TILE: - assert ( - seqlen_q % MAX_REDUCTION_TILE == 0 - ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" - else: - assert (seqlen_q % B_F_SIZE == 0 - ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" - - kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype - acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype - softmax_scale = softmax_scale or (1.0 / (d**0.5)) - num_large_k_tile = context_kv_len // LARGE_TILE_SZ - - o = nl.ndarray((b, h, seqlen_q, d), - dtype=query.dtype, - buffer=nl.shared_hbm) - hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( - None, - None, - None, - None, - ) - if return_debug_tensors: - hbm_l_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_m_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - qk_res_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - block_tables_sbuf = load_block_tables( - block_tables_hbm=block_tables, - num_tiles=num_large_k_tile, - num_blocks_per_tile=num_blocks_per_large_tile, - ) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - if 
num_blocks_per_large_tile < B_P_SIZE: - # we checked num_blocks_per_tile is a power of 2 - assert B_P_SIZE % num_blocks_per_large_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile - # We assume block_size >= block_size_tiling_factor - assert block_size % block_size_tiling_factor == 0 - else: - block_size_tiling_factor = 1 - tiled_block_size = block_size // block_size_tiling_factor - - # Indirect DMA load must be placed along Partition Dimension - block_tables_sbuf = transform_block_tables_for_indirect_load( - block_tables_sbuf, - block_size_tiling_factor=block_size_tiling_factor, - num_head=k_h, - head_id=head_id, - ) - - # Flatten KV cache to be 3D for loading into SBUF - new_cache_shape = ( - 2, - num_blocks * k_h * block_size_tiling_factor, - tiled_block_size * d, - ) - kv_cache = kv_cache.reshape(new_cache_shape) - - # Global Flash Attention accumulators - o_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - l_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - m_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - - for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - cur_k_tile = nl.ndarray( - (par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype, - ) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), - dtype=kernel_dtype, - ) - load_kv_tile_from_cache( - cur_k_tile=cur_k_tile, - cur_v_tile=cur_v_tile, - kv_cache=kv_cache, - block_tables=block_tables_sbuf, - large_k_tile_idx=large_k_tile_idx, - num_blocks_per_large_tile=num_blocks_per_large_tile, - tiled_block_size=tiled_block_size, - B_P_SIZE=B_P_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=False, - q_tile_idx=i, - initialize=large_k_tile_idx == 0, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - # compute attention between input query, key and value - if key is not None and value is not None: - B_F_SIZE = min(seqlen_q, B_F_SIZE) - LARGE_TILE_SZ = seqlen_q - - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), - dtype=kernel_dtype, - ) - - loaded = nl.load(key[batch_id, head_id, :, :]) - if loaded.dtype != kernel_dtype: - loaded = nl.copy(loaded, dtype=kernel_dtype) - cur_k_tile[:, :] = loaded - - v_hbm_tile = value[batch_id, head_id] - for v_i in nl.affine_range(LARGE_TILE_SZ // 
B_P_SIZE): - load_v_tile( - v_hbm_tile=v_hbm_tile, - cur_v_tile=cur_v_tile, - large_tile_idx=0, - v_i=v_i, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(context_kv_len, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=True, - q_tile_idx=i, - initialize=False, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - qk_res_buffer=(qk_res_buffer[i, i_q_h] - if qk_res_buffer is not None else None), - ) - - # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # - for i_q_h in nl.affine_range(q_h_per_k_h): - for i in nl.affine_range(n_tile_q): - out = nl.multiply( - o_buffer[i, i_q_h], - nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), - dtype=kernel_dtype, - ) - - nl.store( - o[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - :, - ], - out, - ) - # maximum and summation statistics - if return_debug_tensors: - nl.store( - hbm_m_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - m_buffer[i, i_q_h, :, :], - ) - nl.store( - hbm_l_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - l_buffer[i, i_q_h], - ) - nl.store( - hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], - qk_res_buffer[batch_id, i_q_h, :, :], - ) - - if return_debug_tensors: - return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res - return o - - -def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): - """ - Reorder the mask to make it compatible with the flash attention kernel. - - We vectorize KV cache read to improve DMA utilization. However, the layout - that maximizes DMA bandwidth changes the order tokens are consumed. - - The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, - tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And - each step the engine consumes a column (rather than a row) of B_P_SIZE - tokens. Therefore, the tokens are visited in a strided way. - - To make sure mask matches the order tokens are consumed, we need to properly - transpose mask. 
- """ - total_query_len, total_seq_len = mask.shape - context_kv_len = total_seq_len - total_query_len - - B_P_SIZE = 128 - assert (LARGE_TILE_SZ - >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" - num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) - tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks - if tiled_block_size > 1: - # Mask reordering is needed when tiled_block_size > 1 - device = mask.device - mask = mask.cpu() - context_mask = mask[:, :context_kv_len] - context_mask = context_mask.view( - total_query_len, - context_kv_len // LARGE_TILE_SZ, - num_tiled_blocks // B_P_SIZE, - B_P_SIZE, - tiled_block_size, - ) - context_mask = context_mask.transpose(3, 4).reshape( - total_query_len, context_kv_len) - new_mask = mask[:, context_kv_len:] - return torch.concat([context_mask, new_mask], dim=1).to(device) - else: - return mask - - -def flash_attn_varlen_nkifunc( - query, - key, - value, - kv_cache, - block_table, - attn_mask, - n_kv_head=None, - head_size=None, - LARGE_TILE_SZ=2048, - mixed_precision=True, -): - """ - Compute flash paged attention for variable length sequences. - - This function is a wrapper around the flash attention NKI kernel. It takes - in the following arguments: - - query: (1, n_heads, d, seq_q) - - key: (1, n_kv_heads, d, seq_k) - - value: (1, n_kv_heads, seq_v, d) - - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) - - block_tables: (n_active_blocks, ) - - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) - - Notes: - - attn_mask must be reordered outside using `reorder_context_mask` - - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) - for better DMA throughput - """ - if n_kv_head is None: - n_kv_head = kv_cache.shape[2] - assert kv_cache.shape[0] == 2 - assert kv_cache.shape[2] == n_kv_head - if head_size is None: - head_size = kv_cache.shape[-1] - - kwargs = dict( - query=query, - key=key, - value=value, - kv_cache=kv_cache, - block_tables=block_table, - mask=attn_mask, - softmax_scale=1.0 / (head_size**0.5), - mixed_precision=mixed_precision, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - o = flash_paged_attention[1, n_kv_head](**kwargs) - return o - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, -) -> None: - """ - Writes key-value pairs to the KV cache at specified positions. 
- - Args: - key (torch.Tensor): Key tensor with shape - (num_tokens, n_kv_head, d_head) - value (torch.Tensor): Value tensor with shape - (num_tokens, n_kv_head, d_head) - kv_cache (torch.Tensor): Key/value cache tensor with shape - (2, num_blocks, n_kv_head, block_size, d_head) - slot_mapping (torch.Tensor): Mapping tensor indicating cache positions - with shape (num_tokens) - - Returns: - None: Updates the kv_cache tensor in-place - """ - block_size = kv_cache.size(3) - n_kv_head = key.size(1) - - # Calculate indices with explicit floor division - block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - block_offsets = slot_mapping % block_size - - # Create the head indices tensor - head_indices = torch.arange(n_kv_head, device=key.device) - - # Update caches using index_put_ - kv_cache.index_put_( - (torch.tensor([0], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), key) - - kv_cache.index_put_( - (torch.tensor([1], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), value) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c6d1501e27578..4db7d1a3a3258 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -2,13 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional import torch -from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd @@ -19,6 +24,7 @@ _PARTITION_SIZE = 512 @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. 
seq_lens_tensor: Optional[torch.Tensor] @@ -34,9 +40,8 @@ class PagedAttentionMetadata: class PagedAttention: - @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod @@ -45,7 +50,8 @@ class PagedAttention: block_size: int, num_kv_heads: int, head_size: int, - ) -> Tuple[int, ...]: + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) @staticmethod @@ -53,13 +59,12 @@ class PagedAttention: kv_cache: torch.Tensor, num_kv_heads: int, head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: x = 16 // kv_cache.element_size() num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @@ -109,16 +114,17 @@ class PagedAttention: if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) - assert (blocksparse_block_size > 0 and - blocksparse_block_size % block_size == 0), \ - (f"{blocksparse_block_size=} needs to be a multiple of" - f"{block_size=} used in block_tables.") + assert ( + blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0 + ), ( + f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables." + ) output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -126,8 +132,9 @@ class PagedAttention: # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_seq_len <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + use_v1 = max_seq_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 + ) if use_v1: # Run PagedAttention V1. 
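# ---------------------------------------------------------------------------
# Side note on the reformatted V1/V2 heuristic above (not part of the patch):
# the decode path picks the V1 kernel only when the context fits in a single
# 512-token partition or when num_seqs * num_heads already gives enough
# parallelism, and always uses V2 beyond 8192 tokens. A minimal standalone
# sketch of that arithmetic; the helper name is hypothetical.
def _would_use_v1(max_seq_len: int, num_seqs: int, num_heads: int,
                  partition_size: int = 512) -> bool:
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
    return max_seq_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512
    )

assert _would_use_v1(512, 1, 8)            # single partition        -> V1
assert _would_use_v1(4096, 128, 8)         # 128 * 8 = 1024 > 512    -> V1
assert not _would_use_v1(4096, 4, 8)       # 8 partitions, 32 <= 512 -> V2
assert not _would_use_v1(16384, 256, 32)   # context > 8192          -> V2
# ---------------------------------------------------------------------------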
@@ -248,7 +255,7 @@ class PagedAttention: @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index d75983bd407d0..d0d836cc6aa5e 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -33,10 +33,12 @@ def _kv_cache_update_kernel( # Copy from new_kv_hbm_ref to scratch for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - new_kv_start = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[1, offset_i], 0) - length = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[2, offset_i], 0) + new_kv_start = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0 + ) + length = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0 + ) async_copy = pltpu.make_async_copy( new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], scratch.at[i, pl.ds(0, length), ...], @@ -52,10 +54,12 @@ def _kv_cache_update_kernel( async_copies.clear() for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[0, offset_i], 0) - length = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[2, offset_i], 0) + kv_cache_start = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0 + ) + length = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0 + ) async_copy = pltpu.make_async_copy( scratch.at[i, pl.ds(0, length), ...], kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], @@ -72,12 +76,14 @@ def _kv_cache_update_kernel( static_argnames=["page_size", "num_slices_per_block"], ) def kv_cache_update( - new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] - slices: jax. - Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) - kv_cache: jax. - Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] - num_kv_update_slices: jax.Array, # [1] + # [total_num_token, num_combined_kv_heads, head_dim] + new_kv: jax.Array, + # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) + slices: jax.Array, + # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + kv_cache: jax.Array, + # [1] + num_kv_update_slices: jax.Array, *, page_size: int = 32, num_slices_per_block: int = 8, @@ -114,7 +120,7 @@ def kv_cache_update( num_scalar_prefetch=len(scalar_prefetches), in_specs=in_specs, out_specs=out_specs, - grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ), + grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),), scratch_shapes=scratch_shapes, ), out_shape=out_shape, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index e1d41930f6231..addf1d9dea73e 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) +float8_info = torch.finfo(current_platform.fp8_dtype()) # Here's an example autotuner config for this kernel. 
This config does provide @@ -33,58 +34,63 @@ IS_TURING = current_platform.get_device_capability() == (7, 5) # key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] # ) @triton.jit -def _fwd_kernel(Q, - K, - V, - K_cache, - V_cache, - sink_ptr, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, - SKIP_DECODE: tl.constexpr, - USE_SINKS: tl.constexpr, - MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): - +def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + sink_ptr, + B_Loc, + sm_scale, + k_scale, + v_scale, + out_scale_inv, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + USE_SINKS: tl.constexpr, + USE_FP8: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -94,8 +100,7 @@ def _fwd_kernel(Q, cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len if SKIP_DECODE and cur_batch_query_len == 1: @@ -115,17 +120,21 @@ def _fwd_kernel(Q, # [M]; starts at current position in query offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # [M,D] - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] + dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to( + tl.int1 + ) # [D] - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < 
cur_batch_query_len), - other=0.0) # [M,D] + q = tl.load( + Q + off_q, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len), + other=0.0, + ) # [M,D] # initialize pointer to m and l if not USE_SINKS: @@ -141,32 +150,43 @@ def _fwd_kernel(Q, acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): + for start_n in tl.range( + 0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache + ): start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s + ).to(tl.int64) # [D,BLOCK_SIZE] off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl + ) - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + if ( + start_n + BLOCK_SIZE > cur_batch_ctx_len + or BLOCK_DMODEL != BLOCK_DMODEL_PADDED + ): k_load = tl.load( K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + mask=dim_mask[:, None] + & ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0, + ) # [D,N] else: k_load = tl.load(K_cache + off_k) @@ -177,8 +197,9 @@ def _fwd_kernel(Q, qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") + ) qk *= sm_scale if SLIDING_WINDOW > 0: # (cur_batch_ctx_len + offs_m[:, None]) are the positions of @@ -192,9 +213,12 @@ def _fwd_kernel(Q, # sliding window may lead to the entire row being masked. # This then makes m_ij contain -inf, which causes NaNs in # exp(). 
- qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) + qk = tl.where( + (cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :]) + < SLIDING_WINDOW, + qk, + -10000, + ) # compute running maximum m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) @@ -204,13 +228,16 @@ def _fwd_kernel(Q, acc = acc * alpha[:, None] # update acc - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + if ( + start_n + BLOCK_SIZE > cur_batch_ctx_len + or BLOCK_DMODEL != BLOCK_DMODEL_PADDED + ): v_load = tl.load( V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] + mask=dim_mask[None, :] + & ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0, + ) # [N,D] else: v_load = tl.load(V_cache + off_v) @@ -225,10 +252,16 @@ def _fwd_kernel(Q, l_i = l_i * alpha + l_ij m_i = m_ij - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -236,27 +269,32 @@ def _fwd_kernel(Q, block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): + for start_n in tl.range( + 0, + block_mask * (start_m + 1) * BLOCK_M, + BLOCK_N, + loop_unroll_factor=num_unroll_request, + ): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] + & ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk *= sm_scale # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) if SLIDING_WINDOW > 0: qk = tl.where( offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) + qk, + -10000, + ) # compute running maximum m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) @@ -266,11 +304,12 @@ def _fwd_kernel(Q, acc = acc * alpha[:, None] # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] + & ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0, + ) p = p.to(v.dtype) acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) @@ -281,12 +320,18 @@ def _fwd_kernel(Q, acc = acc / l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * 
stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + tl.store( + out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len) + ) return @@ -349,12 +394,17 @@ def _fwd_kernel_flash_attn_v2( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) - q = tl.load(Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -364,26 +414,36 @@ def _fwd_kernel_flash_attn_v2( for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0, + ).to(tl.int64) off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl + ) + k = tl.load( + K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") + ) qk *= sm_scale # -- compute m_ij, p, l_ij @@ -402,9 +462,11 @@ def _fwd_kernel_flash_attn_v2( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) + v = tl.load( + V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -412,30 +474,34 @@ def _fwd_kernel_flash_attn_v2( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = 
(offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -453,11 +519,11 @@ def _fwd_kernel_flash_attn_v2( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -467,12 +533,15 @@ def _fwd_kernel_flash_attn_v2( # acc /= l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + tl.store( + out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len + ) return @@ -537,8 +606,7 @@ def _fwd_kernel_alibi( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len if SKIP_DECODE and cur_batch_query_len == 1: @@ -550,16 +618,22 @@ def _fwd_kernel_alibi( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to( + tl.int1 + ) - q = tl.load(Q + off_q, 
- mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + q = tl.load( + Q + off_q, + mask=dim_mask[None, :] + & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -572,23 +646,31 @@ def _fwd_kernel_alibi( for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0, + ).to(tl.int64) off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl + ) + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0, + ) # [D,N] if k_load.dtype.is_fp8(): k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) @@ -597,16 +679,20 @@ def _fwd_kernel_alibi( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") + ) qk *= sm_scale # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope + alibi = ( + tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None] + ) * alibi_slope alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, + float("-inf"), + ) qk += alibi alibi_start_k += BLOCK_N @@ -626,30 +712,36 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0, + ) if v_load.dtype.is_fp8(): v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) - acc = tl.dot(p, v, acc=acc, input_precision='ieee') + acc = tl.dot(p, v, acc=acc, input_precision="ieee") # update m_i and l_i l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + 
cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) # init alibi alibi_slope = tl.load(Alibi_slopes + cur_head) @@ -664,22 +756,25 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + mask=dim_mask[:, None] + & ((start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk = tl.dot(q, k, acc=qk, input_precision="ieee") qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope + alibi = ( + tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None] + ) * alibi_slope alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, + float("-inf"), + ) qk += alibi alibi_start_k += BLOCK_N @@ -701,12 +796,13 @@ def _fwd_kernel_alibi( # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + mask=dim_mask[None, :] + & ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) p = p.to(v.dtype) - acc = tl.dot(p, v, acc=acc, input_precision='ieee') + acc = tl.dot(p, v, acc=acc, input_precision="ieee") # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -714,44 +810,51 @@ def _fwd_kernel_alibi( acc = acc / l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + tl.store( + out_ptrs, + acc, + mask=dim_mask[None, :] + & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + ) return @torch.inference_mode() -def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False, - sinks=None): - +def context_attention_fwd( + q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + 
sliding_window=None, + sm_scale=None, + skip_decode=False, + fp8_out_scale=None, + sinks=None, +): q_dtype_is_f32 = q.dtype is torch.float32 # Turing does have tensor core for float32 multiplication # use ieee as fallback for triton kernels work. There is also # warning on vllm/config.py to inform users this fallback # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + IN_PRECISION = "ieee" if IS_TURING and q_dtype_is_f32 else None # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton @@ -769,10 +872,15 @@ def context_attention_fwd(q, k_cache = k_cache.view(target_dtype) v_cache = v_cache.view(target_dtype) - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") + if ( + k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 + and kv_cache_dtype == "auto" + ): + raise ValueError( + "kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel" + ) # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -793,6 +901,7 @@ def context_attention_fwd(q, if alibi_slopes is not None: assert sinks is None, "Sinks arg is not supported with alibi" + assert fp8_out_scale is None, "FP8 output not supported with alibi" # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -833,13 +942,11 @@ def context_attention_fwd(q, k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x] v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] + v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, @@ -857,8 +964,7 @@ def context_attention_fwd(q, if current_platform.is_rocm(): extra_kargs = {"kpack": 1, "waves_per_eu": 2} - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) + grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid]( q, k, @@ -870,6 +976,7 @@ def context_attention_fwd(q, sm_scale, k_scale, v_scale, + 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0, b_start_loc, b_seq_len, k_cache.shape[4], @@ -892,12 +999,11 @@ def context_attention_fwd(q, k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x] v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), - v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size] BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, @@ -905,6 +1011,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, + USE_FP8=fp8_out_scale is not None, BLOCK_M=128, BLOCK_N=64, num_unroll_cache=4, @@ -912,5 +1019,6 @@ def context_attention_fwd(q, num_warps=4, num_stages=1, USE_SINKS=sinks is not None, - **extra_kargs) + **extra_kargs, + ) return diff --git a/vllm/attention/ops/rocm_aiter_mla.py 
b/vllm/attention/ops/rocm_aiter_mla.py index d91cda255ff31..c358b5971f865 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -9,18 +9,16 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -def get_aiter_mla_metadata(max_batch_size: int, block_size: int, - max_block_per_batch: int, - device: torch.device) -> tuple[torch.Tensor, ...]: - paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, - dtype=torch.int32, - device=device) - paged_kv_indptr = torch.zeros(max_batch_size + 1, - dtype=torch.int32, - device=device) - paged_kv_last_page_lens = torch.full((max_batch_size, ), - block_size, - dtype=torch.int32) +def get_aiter_mla_metadata( + max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device +) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros( + max_batch_size * max_block_per_batch, dtype=torch.int32, device=device + ) + paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) + paged_kv_last_page_lens = torch.full( + (max_batch_size,), block_size, dtype=torch.int32 + ) qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr @@ -37,18 +35,18 @@ def aiter_mla_decode_fwd( kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): - - torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, - kv_buffer.view( - -1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap) + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) def mla_decode_fwd_impl( @@ -65,16 +63,18 @@ def mla_decode_fwd_impl( ) -> None: from aiter.mla import mla_decode_fwd - mla_decode_fwd(q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap) + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) def mla_decode_fwd_fake( @@ -96,9 +96,11 @@ if current_platform.is_rocm(): if is_torch_equal_or_newer("2.7.0"): tags = () else: - tags = (torch.Tag.needs_fixed_stride_order, ), - direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags) + tags = ((torch.Tag.needs_fixed_stride_order,),) + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=tags, + ) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index ad97152e208b8..069cfcaf00aaf 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -13,7 +13,6 @@ FP8_DTYPE = current_platform.fp8_dtype() class AITERPagedAttention(PagedAttention): - @staticmethod def write_to_paged_cache( key: torch.Tensor, @@ -26,19 +25,31 @@ class AITERPagedAttention(PagedAttention): v_scale: torch.Tensor, ) -> None: if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: - 
PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, - v_scale) + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) else: - kv_cache_torch_dtype = (FP8_DTYPE - if "fp8" in kv_cache_dtype else torch.int8) + kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8 key_cache = key_cache.view(kv_cache_torch_dtype) value_cache = value_cache.view(kv_cache_torch_dtype) rocm_aiter.reshape_and_cache_with_pertoken_quant( - key, value, key_cache, value_cache, k_scale, v_scale, - slot_mapping.flatten(), True) + key, + value, + key_cache, + value_cache, + k_scale, + v_scale, + slot_mapping.flatten(), + True, + ) @staticmethod def forward_decode( @@ -78,25 +89,36 @@ class AITERPagedAttention(PagedAttention): blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step) + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) if "fp8" in kv_cache_dtype: - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) - assert (blocksparse_block_size > 0 and - blocksparse_block_size % block_size == 0), \ - (f"{blocksparse_block_size=} needs to be a multiple of" - f"{block_size=} used in block_tables.") + assert ( + blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0 + ), ( + f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables." + ) output = torch.empty_like(query) block_size = value_cache.shape[3] max_num_blocks_per_seq = cdiv(max_seq_len, block_size) - rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, - seq_lens, max_num_blocks_per_seq, k_scale, - v_scale, output) + rocm_aiter.pa_fwd_asm( + query, + key_cache, + value_cache, + block_tables, + seq_lens, + max_num_blocks_per_seq, + k_scale, + v_scale, + output, + ) return output diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index f82ce5b4d4b67..aebc2e63cff69 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -42,10 +42,11 @@ logger = logging.getLogger(__name__) # Only print the following warnings when triton version < 3.2.0. # The issue won't affect performance or accuracy. -if version.parse(triton.__version__) < version.parse('3.2.0'): +if version.parse(triton.__version__) < version.parse("3.2.0"): logger.warning( "The following error message 'operation scheduled before its operands' " - "can be ignored.") + "can be ignored." 
+ ) @triton.jit @@ -101,8 +102,7 @@ def _fwd_kernel_stage1( kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = -float("inf") e_sum = 0.0 @@ -112,14 +112,18 @@ def _fwd_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[None, :]) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), @@ -133,8 +137,11 @@ def _fwd_kernel_stage1( qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -150,8 +157,12 @@ def _fwd_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 0) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) tl.store( Att_Out + offs_mid_o, @@ -159,8 +170,12 @@ def _fwd_kernel_stage1( mask=(mask_dv), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -282,25 +297,22 @@ def _fwd_grouped_kernel_stage1( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = cur_batch - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ - None, :] - q = tl.load(Q + offs_q, - mask=(mask_h[:, None]) & (mask_d[None, :]), - other=0.0) + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) mask_dpe = offs_dpe < Lk - off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + - offs_dpe[None, :]) - qpe = tl.load(Q + off_qpe, - mask=(mask_h[:, None]) & (mask_dpe[None, :]), - other=0.0) + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) @@ -310,14 +322,18 @@ def 
_fwd_grouped_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[:, None]) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), @@ -325,13 +341,14 @@ def _fwd_grouped_kernel_stage1( ) qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: - offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + - offs_dpe[:, None]) + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) kpe = tl.load( K_Buffer + offs_buf_kpe, - mask=(offs_n[None, :] < split_kv_end) & - (mask_dpe[:, None]), + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) qk += tl.dot(qpe, kpe.to(qpe.dtype)) @@ -340,11 +357,15 @@ def _fwd_grouped_kernel_stage1( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), - qk, float("-inf")) + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -360,9 +381,12 @@ def _fwd_grouped_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + - cur_head[:, None] * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv[None, :]) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) tl.store( Att_Out + offs_mid_o, @@ -370,8 +394,12 @@ def _fwd_grouped_kernel_stage1( mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -427,11 +455,7 @@ def _decode_grouped_att_m_fwd( if is_hip_: # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = { - "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} num_stages = 1 _fwd_grouped_kernel_stage1[grid]( @@ -474,12 +498,14 @@ def _decode_grouped_att_m_fwd( def _fwd_kernel_stage2( Mid_O, o, + lse, B_Seqlen, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_obs, stride_oh, + stride_lse_bs, NUM_KV_SPLITS: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, @@ -502,13 +528,12 @@ def _fwd_kernel_stage2( for split_kv_id in range(0, NUM_KV_SPLITS): kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) 
split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) if split_kv_end > split_kv_start: - tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, - mask=mask_d, - other=0.0) + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) n_e_max = tl.maximum(tlogic, e_max) @@ -525,12 +550,18 @@ def _fwd_kernel_stage2( acc / e_sum, mask=mask_d, ) + lse_val = e_max + tl.log(e_sum) + tl.store( + lse + cur_batch * stride_lse_bs + cur_head, + lse_val, + ) def _decode_softmax_reducev_fwd( logits, q, o, + lse, v_buffer, b_seq_len, num_kv_splits, @@ -545,22 +576,20 @@ def _decode_softmax_reducev_fwd( if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = { - "waves_per_eu": 4, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} grid = (batch, head_num) _fwd_kernel_stage2[grid]( logits, o, + lse, b_seq_len, logits.stride(0), logits.stride(1), logits.stride(2), o.stride(0), o.stride(1), + lse.stride(0), NUM_KV_SPLITS=NUM_KV_SPLITS, BLOCK_DV=BLOCK_DV, Lv=Lv, @@ -575,6 +604,7 @@ def decode_attention_fwd_normal( k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, @@ -595,8 +625,9 @@ def decode_attention_fwd_normal( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd_grouped( @@ -604,6 +635,7 @@ def decode_attention_fwd_grouped( k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, @@ -624,8 +656,9 @@ def decode_attention_fwd_grouped( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd( @@ -633,6 +666,7 @@ def decode_attention_fwd( k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, @@ -651,6 +685,7 @@ def decode_attention_fwd( k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, @@ -666,6 +701,7 @@ def decode_attention_fwd( k_buffer, v_buffer, o, + lse, req_to_token, b_seq_len, attn_logits, diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 49070e4c7ae6a..c0ab35d07b1fe 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -55,16 +55,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) + rng_offsets = dropout_offsets( + philox_seed, philox_offset, dropout_p, m, n, stride + ).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) + rng_output = dropout_rng(philox_seed, 
philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep @@ -74,9 +74,9 @@ def load_fn(block_ptr, first, second, pad): if first and second: tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) else: tensor = tl.load(block_ptr) return tensor @@ -145,9 +145,7 @@ def _attn_fwd_inner( # if not is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], - actual_seqlen_k, - dtype=tl.int32) + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -160,8 +158,9 @@ def _attn_fwd_inner( if USE_FP8: qk *= qk_scale if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") + bias = load_fn( + bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" + ) # While bias is added after multiplying qk with sm_scale, our # optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. @@ -173,9 +172,12 @@ def _attn_fwd_inner( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) + philox_offset = ( + batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + + start_n + - BLOCK_N + ) keep = dropout_mask( philox_seed, philox_offset, @@ -187,8 +189,7 @@ def _attn_fwd_inner( if RETURN_ENCODED_SOFTMAX: tl.store( encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), + tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), ) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: @@ -221,89 +222,57 @@ def _attn_fwd_inner( if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, BLOCK_N) + ) return acc, l_i, m_i def get_cdna_autotune_configs(): return [ triton.Config( - { - 'BLOCK_M': 256, - 'BLOCK_N': 64, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 128, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 256, - 'BLOCK_N': 128, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 1, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 128, - 
'BLOCK_N': 64, - 'waves_per_eu': 3, - 'PRE_LOAD_V': True - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": True}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 3, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 64, - 'BLOCK_N': 64, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), # TODO: This config fails with head_size not pow2 with data mismatches. # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: # triton.Config( # { @@ -315,47 +284,31 @@ def get_cdna_autotune_configs(): # num_stages=1, # num_warps=4, # ), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] def get_rdna_autotune_configs(): return [ triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 16, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 16, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: # triton.Config( # { @@ -385,7 +338,7 @@ def get_rdna_autotune_configs(): # }, # num_stages=1, # num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] def get_autotune_configs(): @@ -501,15 +454,17 @@ def attn_fwd( # This captures the decrease in n_blocks if we have a rectangular attn # matrix n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N + ) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) # If we have no blocks after adjusting for seqlen deltas, this WG is # part of the blocks that are all 0. We exit early. 
if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) + o_offset = ( + off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + ) O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), @@ -545,8 +500,7 @@ def attn_fwd( padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), @@ -555,8 +509,7 @@ def attn_fwd( block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn K_block_ptr = tl.make_block_ptr( base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), @@ -565,8 +518,7 @@ def attn_fwd( block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1), ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk V_block_ptr = tl.make_block_ptr( base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), @@ -587,9 +539,9 @@ def attn_fwd( else: bias_ptr = None if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k + batch_philox_offset = ( + philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k + ) else: batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. 
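Editor's aside on the autotune helpers reformatted above: get_cdna_autotune_configs() and get_rdna_autotune_configs() each return a (configs, key) pair, and the usual destination for such a pair is a triton.autotune decorator, which benchmarks every triton.Config and re-runs the search whenever any argument named in key changes value. The sketch below is a minimal, self-contained illustration of that pattern only; the toy kernel, config values, and names are illustrative assumptions, not code from this diff. The key list used above ('IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8') plays the same role for the attention kernel: a new tuning pass is triggered when one of those arguments changes.

import torch
import triton
import triton.language as tl

# Illustrative config list; the real kernels in this file also tune AMD-specific
# options such as waves_per_eu and PRE_LOAD_V, omitted here to keep the sketch portable.
_configs = [
    triton.Config({"BLOCK": 128}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK": 256}, num_warps=8, num_stages=1),
]

@triton.autotune(configs=_configs, key=["n_elements"])  # re-tune when n_elements changes
@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Load a tile from x and store it to y; masked lanes are never written.
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

def autotuned_copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n_elements = x.numel()
    # BLOCK is supplied by whichever triton.Config wins the autotune,
    # so it is not passed explicitly at the call site.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK"]),)
    _copy_kernel[grid](x, y, n_elements)
    return y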
@@ -692,8 +644,9 @@ def attn_fwd( if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, n_full_blocks) + ) acc, l_i, m_i = _attn_fwd_inner( acc, l_i, @@ -749,13 +702,12 @@ def attn_fwd( acc = acc.to(Out.type.element_ty) if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) + out_mask_boundary = tl.full( + (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 + ) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = tl.zeros((1, ), tl.float32) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = tl.zeros((1,), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m @@ -772,8 +724,7 @@ def attn_fwd( # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), @@ -821,7 +772,6 @@ def check_args( class _attention(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -847,8 +797,7 @@ class _attention(torch.autograd.Function): def check_and_convert(t, scale): if t.dtype != float8: descale = 1.0 / scale - ts = (t * descale).clamp(min=float8_info.min, - max=float8_info.max) + ts = (t * descale).clamp(min=float8_info.min, max=float8_info.max) return ts.to(float8) else: return t @@ -923,8 +872,7 @@ class _attention(torch.autograd.Function): bias_strides = (0, 0, 0, 0) p_descale = 1.0 / p_scale - o_descale = 1.0 / fp8_out_scale.item( - ) if fp8_out_scale is not None else 1.0 + o_descale = 1.0 / fp8_out_scale.item() if fp8_out_scale is not None else 1.0 arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 56d78ed5ea6ee..d29f92f8cecb2 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -61,8 +61,8 @@ def merge_attn_states_kernel( # If we see an inf assume FA2 and convert inf to -inf for consistency # and correctness. Inf generally doesn't make sense in this context outside # of undefined-behavior/FA2-case, so I think this a safe assumption. - p_lse = float('-inf') if p_lse == float('inf') else p_lse - s_lse = float('-inf') if s_lse == float('inf') else s_lse + p_lse = float("-inf") if p_lse == float("inf") else p_lse + s_lse = float("-inf") if s_lse == float("inf") else s_lse max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse @@ -70,7 +70,7 @@ def merge_attn_states_kernel( # Will reuse precomputed Exp values for scale factor computation. 
p_se = tl.exp(p_lse) s_se = tl.exp(s_lse) - out_se = (p_se + s_se) + out_se = p_se + s_se if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse @@ -78,12 +78,20 @@ def merge_attn_states_kernel( head_arange = tl.arange(0, PADDED_HEAD_SIZE) head_mask = head_arange < HEAD_SIZE - p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) + p_out = tl.load( + prefix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + s_out = tl.load( + suffix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. @@ -91,7 +99,8 @@ def merge_attn_states_kernel( p_scale = p_se / out_se s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale - tl.store(output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - out, - mask=head_mask) + tl.store( + output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask, + ) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py new file mode 100644 index 0000000000000..bbcd560ad56e3 --- /dev/null +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + + +@triton.jit +def reshape_and_cache_kernel_flash( + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + key_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + value_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + key_stride: tl.int64, + value_stride: tl.int64, + block_stride: tl.int64, + page_stride: tl.int64, + num_heads: tl.constexpr, + head_size: tl.constexpr, + block_size: tl.constexpr, + # FP8 flags + FP8_KV_CACHE: tl.constexpr, + # tune parameters + TILE_SIZE: tl.constexpr, +): + token_idx = tl.program_id(axis=0) + slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) + if slot_idx < 0: + # Padding token that should be ignored. 
+ return + + tile_i = tl.program_id(axis=1) + tile_offs = tl.arange(0, TILE_SIZE) + tile_pos = tile_i * TILE_SIZE + tile_offs + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + src_key_idx = token_idx * key_stride + src_value_idx = token_idx * value_stride + + tgt_idx = block_idx * block_stride + block_offset * page_stride + + # [TILE_SIZE] + key_load = tl.load( + key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size) + ) + if FP8_KV_CACHE: + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) + else: + key_tile = key_load + + # [TILE_SIZE] + value_load = tl.load( + value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size) + ) + if FP8_KV_CACHE: + if value_load.dtype.is_fp8(): + value_tile = value_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the value_cache_ptr.dtype.element_ty + value_tile = value_load / tl.load(v_scale) + else: + value_tile = value_load + + tl.store( + key_cache_ptr + tgt_idx + tile_pos, + key_tile, + mask=tile_pos < (num_heads * head_size), + ) + tl.store( + value_cache_ptr + tgt_idx + tile_pos, + value_tile, + mask=tile_pos < (num_heads * head_size), + ) + return + + +def triton_reshape_and_cache_flash( + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size] + # [num_blocks, block_size, num_heads, head_size] + key_cache: torch.Tensor, + # [num_blocks, block_size, num_heads, head_size] + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 +): + num_tokens = key.shape[0] + num_heads = key.shape[1] + head_size = key.shape[2] + block_size = key_cache.shape[1] + n = num_heads * head_size + + key_stride = key.stride()[0] + value_stride = value.stride()[0] + block_stride = key_cache.stride()[0] + page_stride = key_cache.stride()[1] + + head_stride = key_cache.stride()[2] + assert head_stride == head_size, "only contiguous heads are supported" + + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." + ) + kv_cache_torch_dtype = ( + current_platform.fp8_dtype() + if kv_cache_dtype.startswith("fp8") + else key_cache.dtype + ) + + if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + # to avoid erroneous implicit cast in triton kernel (tl.store to uint8) + # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + assert kv_cache_dtype != torch.uint8, ( + "explicit fp8 cast and store to " + "uint8 is not supported by triton reshape_and_cache_flash" + ) + + FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + torch.float8_e4m3fnuz, + ], ( + "unsupported dtype of KV cache tensor, got " + f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." 
+ ) + + # heuristics instead of autotuning + TILE_SIZE = min(2048, triton.next_power_of_2(n)) + if current_platform.is_rocm() or current_platform.is_xpu(): + num_stages = 4 + num_warps = 8 + else: # cuda + num_stages = 10 + num_warps = 16 + if torch.cuda.get_device_capability(key.device)[0] < 9: + TILE_SIZE = min(512, TILE_SIZE) + + # TODO(ngl): maybe replace with static launch grid to avoid overhead if + # using cudagraphs + grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"])) + + reshape_and_cache_kernel_flash[grid]( + key_ptr=key, + value_ptr=value, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + slot_mapping_ptr=slot_mapping, + k_scale=k_scale, + v_scale=v_scale, + # strides + key_stride=key_stride, + value_stride=value_stride, + block_stride=block_stride, + page_stride=page_stride, + num_heads=num_heads, + head_size=head_size, + block_size=block_size, + # FP8 flags + FP8_KV_CACHE=FP8_KV_CACHE, + # autotune parameters + TILE_SIZE=TILE_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 56ebed0f52448..565be1c39bec1 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -10,9 +10,11 @@ import torch from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton logger = init_logger(__name__) +float8_info = torch.finfo(current_platform.fp8_dtype()) @triton.jit @@ -29,8 +31,13 @@ def apply_softcap(S, x): @triton.jit -def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, - BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): left: tl.int32 = 0 right = num_seqs while left < right: @@ -48,77 +55,84 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - 
query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < 
cur_batch_query_len, 1, 0).to(tl.int1) @@ -153,50 +167,85 @@ def kernel_unified_attention_2d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) # adjust for potential padding in the last q_block by considering the # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles (blocks) that need to be processed to - # cover the longest sequence prefix (due to causal masking, blocks beyond + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond # this prefix can be skipped) - num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) - # iterate through tiles - for j in range(0, num_blocks): + # ---- Sliding-window tile pruning -------------------- + # Default: keep previous global behavior + tile_start = 0 + tile_end = num_tiles + if SLIDING_WINDOW > 0: + # Query rows covered by this Q-block + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, + cur_batch_query_len - 1, + ) + # For sliding window, each query position q can only attend to + # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs] + # where q_abs = context_len + q + # The union of allowed key positions for this Q-block is: + # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi] + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + # Convert to tile indices and clamp + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + # iterate through tiles (now limited to the sliding window range) + for j in range(tile_start, tile_end): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - offs_n = tl.arange(0, BLOCK_SIZE) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * 
stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) - # K : (HEAD_SIZE, BLOCK_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], - other=0.0) + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -206,10 +255,12 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], - other=0.0) + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -219,24 +270,26 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -256,11 +309,12 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. 
In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE) + # P : (BLOCK_M, TILE_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -281,10 +335,15 @@ def kernel_unified_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - output_offset = (query_offset_0[:, None] * output_stride_0 + - query_offset_1[:, None] * output_stride_1 + - offs_d[None, :]) + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) tl.store( output_ptr + output_offset, @@ -295,67 +354,67 @@ def kernel_unified_attention_2d( @triton.jit def kernel_unified_attention_3d( - segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE: 
tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -365,22 +424,23 @@ def kernel_unified_attention_3d( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) - if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -415,40 +475,66 @@ def kernel_unified_attention_3d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len 
+ + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - offs_n = tl.arange(0, BLOCK_SIZE) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) - # K : (HEAD_SIZE, BLOCK_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], - other=0.0) + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -458,10 +544,12 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], - other=0.0) + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -471,24 +559,25 @@ def kernel_unified_attention_3d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -508,11 +597,12 @@ def kernel_unified_attention_3d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a 
chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE,) + # P : (BLOCK_M, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -532,88 +622,93 @@ def kernel_unified_attention_3d( acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - query_offset_0[:, None].to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) tl.store( segm_output_ptr + segm_output_offset, acc, mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) - segm_offset = (query_offset_0.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) - tl.store(segm_expsum_ptr + segm_offset, - L, - mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] - segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] - seq_lens_ptr, # [num_seqs] - num_seqs, # int - num_query_heads: tl.constexpr, # int - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + # [num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + TILE_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, - BLOCK_Q, False) + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) # sequence len for this particular sequence seq_len = 
tl.load(seq_lens_ptr + seq_idx) # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( - [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) # load segment maxima - segm_offset = (query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_head_idx * NUM_SEGMENTS_PER_SEQ + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)) - segm_max = tl.load(segm_max_ptr + segm_offset, - mask=segm_mask, - other=float("-inf")) + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) overall_max = tl.max(segm_max) # load and rescale segment exp sums - segm_expsum = tl.load(segm_expsum_ptr + segm_offset, - mask=segm_mask, - other=0.0) + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) overall_expsum = tl.sum(segm_expsum) # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + - tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) segm_output = tl.load( segm_output_ptr + segm_output_offset, mask=segm_mask[:, None] & dim_mask[None, :], @@ -624,10 +719,16 @@ def reduce_segments( # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + # write result - output_offset = (query_token_idx * output_stride_0 + - query_head_idx * output_stride_1 + - tl.arange(0, HEAD_SIZE_PADDED)) + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -649,6 +750,7 @@ def unified_attention( k_descale, v_descale, alibi_slopes=None, + output_scale=None, qq_bias=None, # Optional tensor for sinks sinks=None, @@ -656,13 +758,8 @@ def unified_attention( assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" - block_size = v.shape[1] - assert q.element_size() >= 2 or block_size >= 32, \ - "Block size must be at least 32 for fp8" - if sinks is not None: - assert sinks.shape[0] == q.shape[1], \ - "Sinks must be num_query_heads size" + assert sinks.shape[0] == 
q.shape[1], "Sinks must be num_query_heads size" use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -674,7 +771,9 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = 16 + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) BLOCK_Q = BLOCK_M // num_queries_per_kv # Ideally we would launch with kernel with: @@ -688,12 +787,20 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + # Assigning default tile sizes for prefill and decode. + # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) + # and at least 16 for all other data types. + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: - kernel_unified_attention_2d[( - total_num_q_blocks, - num_kv_heads, - )]( + kernel_unified_attention_2d[ + ( + total_num_q_blocks, + num_kv_heads, + ) + ]( output_ptr=out, query_ptr=q, key_cache_ptr=k, @@ -706,6 +813,7 @@ def unified_attention( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, @@ -716,6 +824,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -735,6 +844,7 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + USE_FP8=output_scale is not None, ) else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default @@ -764,52 +874,51 @@ def unified_attention( device=q.device, ) - kernel_unified_attention_3d[( - total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) - + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, 
+ block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -818,13 +927,15 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), - BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3a235ba6e0b42..7dfe6ffda6a80 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -2,38 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass from functools import cache -from typing import Generator, Optional, Union +from typing import Optional, Union import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname logger = init_logger(__name__) -def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: - """ - Convert a string backend name to a _Backend enum value. - - Returns: - * _Backend: enum value if backend_name is a valid in-tree type - * None: otherwise it's an invalid in-tree type or an out-of-tree platform is - loaded. - """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else \ - None - - def get_env_variable_attn_backend() -> Optional[_Backend]: - ''' + """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. 
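For reference, the backend_name_to_enum helper removed above (it now lives in vllm.attention.backends.registry) is a plain enum lookup that returns None for unknown names. A self-contained sketch of that pattern, with stand-in enum members rather than vLLM's actual backend list:

```python
import enum
from typing import Optional

class _Backend(enum.Enum):          # stand-in; the real enum lives in the registry module
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    FLASHINFER = enum.auto()

def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    # Valid in-tree names resolve to an enum member; anything else yields None,
    # letting the caller raise an error that lists the valid choices.
    return _Backend[backend_name] if backend_name in _Backend.__members__ else None

assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
assert backend_name_to_enum("NOT_A_BACKEND") is None
```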
@@ -41,10 +28,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: * _Backend enum value if an override is specified * None otherwise - ''' + """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return (None - if backend_name is None else backend_name_to_enum(backend_name)) + return None if backend_name is None else backend_name_to_enum(backend_name) # Global state allows a particular choice of backend @@ -58,7 +44,7 @@ forced_attn_backend: Optional[_Backend] = None def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: - ''' + """ Force all attention operations to use a specified backend. Passing `None` for the argument re-enables automatic @@ -67,16 +53,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: Arguments: * attn_backend: backend selection (None to revert to auto) - ''' + """ global forced_attn_backend forced_attn_backend = attn_backend def get_global_forced_attn_backend() -> Optional[_Backend]: - ''' + """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. - ''' + """ return forced_attn_backend @@ -109,26 +95,27 @@ def is_attn_backend_supported( assert isinstance(attn_backend, type) # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr(attn_backend, - "get_supported_head_sizes", None): + if get_supported_head_sizes := getattr( + attn_backend, "get_supported_head_sizes", None + ): is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", - None): + elif validate_head_size := getattr(attn_backend, "validate_head_size", None): try: validate_head_size(head_size) is_head_size_supported = True except Exception: is_head_size_supported = False else: - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "head size validation") + raise NotImplementedError( + f"{attn_backend.__name__} does not support head size validation" + ) - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", - None): + if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): is_dtype_supported = dtype in get_supported_dtypes() else: - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "dtype validation") + raise NotImplementedError( + f"{attn_backend.__name__} does not support dtype validation" + ) return _IsSupported( can_import=True, @@ -142,9 +129,9 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -156,10 +143,10 @@ def get_attn_backend( dtype=dtype, kv_cache_dtype=kv_cache_dtype, block_size=block_size, - is_attention_free=is_attention_free, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, + use_sparse=use_sparse, ) @@ -169,52 +156,66 @@ def _cached_get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: - # If there are no attention layers (e.g. 
we are running Mamba), - # use the placeholder NO_ATTENTION - if is_attention_free: - from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionBackend) - return PlaceholderAttentionBackend - # Check whether a particular choice of backend was # previously forced. # # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: Optional[_Backend] = ( - get_global_forced_attn_backend()) + backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend() if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: # Check the environment variable and override if specified backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: + if backend_by_env_var.endswith("_VLLM_V1"): + logger.warning( + "The suffix '_VLLM_V1' in the environment variable " + "%s is no longer necessary as V0 backends have been " + "deprecated. Please remove this suffix from your " + "environment variable setting.", + STR_BACKEND_ENV_VAR, + ) + backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: raise ValueError( f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}") + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) # get device-specific attn_backend + from vllm.platforms import current_platform + attention_cls = current_platform.get_attn_backend_cls( - selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla, has_sink) + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) if not attention_cls: raise ValueError( - f"Invalid attention backend for {current_platform.device_name}") + f"Invalid attention backend for {current_platform.device_name}" + ) return resolve_obj_by_qualname(attention_cls) @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend) -> Generator[None, None, None]: - ''' + attn_backend: _Backend, +) -> Generator[None, None, None]: + """ Globally force a vLLM attention backend override within a context manager, reverting the global attention backend override to its prior state upon exiting the context @@ -227,7 +228,7 @@ def global_force_attn_backend_context_manager( Returns: * Generator - ''' + """ # Save the current state of the global backend override (if any) original_value = get_global_forced_attn_backend() @@ -241,3 +242,4 @@ def global_force_attn_backend_context_manager( finally: # Revert the original global backend override, if any global_force_attn_backend(original_value) + _cached_get_attn_backend.cache_clear() diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f0517..e13afd46ee96b 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -10,11 +10,12 @@ logger = init_logger(__name__) if current_platform.is_cuda(): from vllm import _custom_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) + from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash flash_attn_varlen_func = ops.flash_attn_varlen_func 
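The get_flash_attn_version changes in the following hunk keep the existing decision logic: default to FA3 on Hopper when supported, honor the environment override, and fall back to FA2 on Blackwell or when ALiBi is required. A condensed standalone sketch of that decision logic (names here are illustrative; is_supported stands in for is_fa_version_supported):

```python
from typing import Callable, Optional

def pick_fa_version(
    capability_major: int,                 # e.g. 9 = Hopper, 10 = Blackwell
    is_supported: Callable[[int], bool],   # stand-in for is_fa_version_supported
    env_override: Optional[int] = None,    # VLLM_FLASH_ATTN_VERSION, if set
    requires_alibi: bool = False,
) -> int:
    # 1. default: FA3 on Hopper when available, otherwise FA2
    fa_version = 3 if (capability_major == 9 and is_supported(3)) else 2
    # 2. explicit override wins
    if env_override is not None:
        fa_version = env_override
    # 3. fall back to FA2 where FA3 cannot be used
    if capability_major == 10 and fa_version == 3:
        fa_version = 2
    if requires_alibi and fa_version == 3:
        fa_version = 2
    assert is_supported(fa_version)
    return fa_version

assert pick_fa_version(9, lambda v: True) == 3
assert pick_fa_version(9, lambda v: True, requires_alibi=True) == 2
```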
get_scheduler_metadata = ops.get_scheduler_metadata @@ -23,18 +24,23 @@ elif current_platform.is_xpu(): def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + if current_platform.is_xpu(): return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import ( - fa_version_unsupported_reason, is_fa_version_supported) + fa_version_unsupported_reason, + is_fa_version_supported, + ) + device_capability = current_platform.get_device_capability() assert device_capability is not None # 1. default version depending on platform - fa_version = 3 if (device_capability.major == 9 - and is_fa_version_supported(3)) else 2 + fa_version = ( + 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 + ) # 2. override if passed by environment if envs.VLLM_FLASH_ATTN_VERSION is not None: @@ -45,17 +51,22 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: if device_capability.major == 10 and fa_version == 3: logger.warning_once( "Cannot use FA version 3 on Blackwell platform " - "defaulting to FA version 2.") + "defaulting to FA version 2." + ) fa_version = 2 if requires_alibi and fa_version == 3: - logger.warning_once("Cannot use FA version 3 with ALiBi, " - "defaulting to FA version 2.") + logger.warning_once( + "Cannot use FA version 3 with ALiBi, defaulting to FA version 2." + ) fa_version = 2 if not is_fa_version_supported(fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - fa_version, fa_version_unsupported_reason(fa_version)) + logger.error( + "Cannot use FA version %d is not supported due to %s", + fa_version, + fa_version_unsupported_reason(fa_version), + ) assert is_fa_version_supported(fa_version) return fa_version @@ -64,8 +75,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: - return get_flash_attn_version() == 3 and \ - current_platform.get_device_capability().major == 9 + return ( + get_flash_attn_version() == 3 + and current_platform.get_device_capability().major == 9 + ) + + +def flash_attn_supports_mla(): + from vllm.platforms import current_platform + + if current_platform.is_cuda(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported, + ) + + return ( + is_fa_version_supported(3) + and current_platform.get_device_capability()[0] == 9 + ) + except (ImportError, AssertionError): + pass + return False def is_flash_attn_varlen_func_available() -> bool: diff --git a/vllm/attention/utils/kv_sharing_utils.py b/vllm/attention/utils/kv_sharing_utils.py index b4ae8bdf4d762..93af5bf7e13fe 100644 --- a/vllm/attention/utils/kv_sharing_utils.py +++ b/vllm/attention/utils/kv_sharing_utils.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -def validate_kv_sharing_target(current_layer_name, target_layer_name, - static_forward_context): - error_msg = (f"Specified KV sharing target layer for {current_layer_name} " - f"is not valid: target layer {target_layer_name} ") +def validate_kv_sharing_target( + current_layer_name, target_layer_name, static_forward_context +): + error_msg = ( + f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} " + ) if current_layer_name == target_layer_name: - raise ValueError(error_msg + - "cannot be the same as the current layer.") + raise 
ValueError(error_msg + "cannot be the same as the current layer.") if target_layer_name not in static_forward_context: from vllm.model_executor.models.utils import extract_layer_index @@ -20,14 +22,12 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name, if current_layer_idx <= target_layer_idx: raise ValueError(error_msg + "must come before the current layer.") else: - raise ValueError(error_msg + - "is not a valid Attention layer in the model.") + raise ValueError(error_msg + "is not a valid Attention layer in the model.") # Currently KV sharing is only supported between layers of the same type - target_layer_attn_type = static_forward_context[ - target_layer_name].attn_type + target_layer_attn_type = static_forward_context[target_layer_name].attn_type expected = static_forward_context[current_layer_name].attn_type if target_layer_attn_type != expected: raise ValueError( - error_msg + - f"must be the same type as the current layer ({expected}).") + error_msg + f"must be the same type as the current layer ({expected})." + ) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 5a2e79e1b5c74..e0ba863b9210e 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -4,8 +4,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union +from vllm.logprobs import Logprob from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -18,6 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ + # The tokens include the prompt. tokens: list[int] logprobs: list[dict[int, Logprob]] @@ -36,11 +37,11 @@ class BeamSearchOutput: It contains the list of the best beam search sequences. The length of the list is equal to the beam width. """ + sequences: list[BeamSearchSequence] class BeamSearchInstance: - def __init__( self, prompt_tokens: list[int], @@ -79,9 +80,9 @@ def get_beam_search_score( def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, - length_penalty) + return get_beam_search_score( + x.tokens, x.cum_logprob, eos_token_id, length_penalty + ) return sort_beams_key diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 920d21bda3c5b..8e71a7bfb1293 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,18 +11,23 @@ generation. 
Supported dataset types include: - HuggingFace - VisionArena """ + +import argparse +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image @@ -33,7 +38,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict from vllm.multimodal.image import convert_image_mode -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import PlaceholderModule try: @@ -70,12 +75,10 @@ class SampleRequest: Represents a single inference request for benchmarking. """ - prompt: Union[str, Any] + prompt: Union[str, list[str]] prompt_len: int expected_output_len: int - multi_modal_data: Optional[ - Union[MultiModalDataDict, dict, list[dict]] - ] = None + multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None request_id: Optional[str] = None @@ -93,28 +96,31 @@ class BenchmarkDataset(ABC): self, dataset_path: Optional[str] = None, random_seed: int = DEFAULT_SEED, + disable_shuffle: bool = False, + **kwargs, ) -> None: """ Initialize the BenchmarkDataset with an optional dataset path and random - seed. - + seed. + Args: dataset_path (Optional[str]): Path to the dataset. If None, it - indicates that a default or random dataset might be used. + indicates that a default or random dataset might be used. random_seed (int): Seed value for reproducible shuffling or - sampling. Defaults to DEFAULT_SEED. + sampling. Defaults to DEFAULT_SEED. """ self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the # default seed. - self.random_seed = (random_seed - if random_seed is not None else self.DEFAULT_SEED) + self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED + self.disable_shuffle = disable_shuffle self.data = None def apply_multimodal_chat_transformation( - self, - prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + self, + prompt: str, + mm_content: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None, + ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -122,7 +128,15 @@ class BenchmarkDataset(ABC): """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -136,39 +150,31 @@ class BenchmarkDataset(ABC): NotImplementedError: If a subclass does not implement this method. 
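To make the subclassing contract concrete, here is a hypothetical minimal subclass (not part of this patch). It assumes BenchmarkDataset and SampleRequest are importable from vllm.benchmarks.datasets as defined in this file, and that the passed tokenizer provides the usual encode() method:

```python
from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest

class InlinePromptsDataset(BenchmarkDataset):
    """Toy dataset serving a fixed in-memory list of prompts (illustrative only)."""

    def __init__(self, prompts: list[str], **kwargs) -> None:
        super().__init__(**kwargs)
        self.data = prompts

    def load_data(self) -> None:
        pass  # prompts were supplied in-memory, nothing to load

    def sample(self, tokenizer, num_requests, request_id_prefix="",
               no_oversample=False, output_len=128, **kwargs):
        requests = [
            SampleRequest(
                prompt=p,
                prompt_len=len(tokenizer.encode(p)),
                expected_output_len=output_len,
                request_id=request_id_prefix + str(i),
            )
            for i, p in enumerate(self.data[:num_requests])
        ]
        # Reuse the base-class helper to pad up to num_requests if needed.
        self.maybe_oversample_requests(
            requests, num_requests, request_id_prefix, no_oversample)
        return requests
```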
""" # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError( - "load_data must be implemented in subclasses.") + raise NotImplementedError("load_data must be implemented in subclasses.") def get_random_lora_request( self, - tokenizer: PreTrainedTokenizerBase, max_loras: Optional[int] = None, lora_path: Optional[str] = None, - ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + ) -> Optional[LoRARequest]: """ - Optionally select a random LoRA request and return its associated - tokenizer. + Optionally select a random LoRA request. This method is used when LoRA parameters are provided. It randomly - selects a LoRA based on max_loras and retrieves a cached tokenizer for - that LoRA if available. Otherwise, it returns the base tokenizer. + selects a LoRA based on max_loras. Args: - tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no - LoRA is selected. max_loras (Optional[int]): The maximum number of LoRAs available. If `None`, LoRA is not used. lora_path (Optional[str]): Path to the LoRA parameters on disk. If `None`, LoRA is not used. Returns: - A tuple with the following elements: - - A new [LoRARequest][] (or `None` if not applicable). - - The tokenizer associated with the LoRA request - (or the base tokenizer). + A new [`LoRARequest`][vllm.lora.request.LoRARequest] + (or `None` if not applicable). """ if max_loras is None or lora_path is None: - return None, tokenizer + return None # Generate a random LoRA ID in the range [1, max_loras]. lora_id = random.randint(1, max_loras) @@ -177,16 +183,16 @@ class BenchmarkDataset(ABC): lora_int_id=lora_id, lora_path=lora_path_on_disk(lora_path), ) - if lora_id not in lora_tokenizer_cache: - lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) - # Return lora_request and the cached tokenizer if available; otherwise, - # return the base tokenizer - return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + return lora_request @abstractmethod - def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "") -> list[SampleRequest]: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -197,8 +203,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. - request_id_prefix (str) The prefix of request_id. - + request_id_prefix (str): The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -211,6 +216,7 @@ class BenchmarkDataset(ABC): requests: list[SampleRequest], num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -220,20 +226,32 @@ class BenchmarkDataset(ABC): requests (List[SampleRequest]): The current list of sampled requests. num_requests (int): The target number of requests. - request_id_prefix (str) The prefix of the request ids. + request_id_prefix (str): The prefix applied to generated request + identifiers. """ + if no_oversample: + logger.info("Skipping oversampling. 
Total samples: %d.", len(requests)) + return + if len(requests) < num_requests: random.seed(self.random_seed) - additional = deepcopy( - random.choices(requests, k=num_requests - len(requests)) - ) - for i in range(len(additional)): - req = additional[i] + needed = num_requests - len(requests) + additional = [] + for i in range(needed): + req = deepcopy(random.choice(requests)) req.request_id = request_id_prefix + str(len(requests) + i) + additional.append(req) requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", - num_requests) + logger.info("Oversampled requests to reach %d total samples.", num_requests) + + ids = [req.request_id for req in requests] + if len(ids) != len(set(ids)): + raise ValueError( + "Duplicate request_id found in the sampled " + "requests. Please ensure that each request_id " + "is unique." + ) # ----------------------------------------------------------------------------- @@ -258,14 +276,14 @@ def is_valid_sequence( """ # Check for invalid conditions prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len - < min_len) + output_too_short = (not skip_min_output_len_check) and (output_len < min_len) prompt_too_long = prompt_len > max_prompt_len combined_too_long = (prompt_len + output_len) > max_total_len # Return True if none of the invalid conditions are met - return not (prompt_too_short or output_too_short or prompt_too_long - or combined_too_long) + return not ( + prompt_too_short or output_too_short or prompt_too_long or combined_too_long + ) @cache @@ -297,28 +315,30 @@ def process_image(image: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. """ - if isinstance(image, dict) and 'bytes' in image: - image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, dict) and "bytes" in image: + image = Image.open(BytesIO(image["bytes"])) if isinstance(image, Image.Image): image = convert_image_mode(image, "RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") return { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, } if isinstance(image, str): - image_url = (image if image.startswith( - ("http://", "file://")) else f"file://{image}") + image_url = ( + image + if image.startswith(("http://", "https://", "file://")) + else f"file://{image}" + ) return {"type": "image_url", "image_url": {"url": image_url}} - raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes.") + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes." + ) def process_video(video: Any) -> Mapping[str, Any]: @@ -337,115 +357,667 @@ def process_video(video: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. 
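A quick reference for the three input forms process_image accepts, assuming vLLM is importable; the URLs and paths below are made up:

```python
from PIL import Image

from vllm.benchmarks.datasets import process_image

# 1) PIL image (or a {"bytes": ...} dict) -> inline base64 data URL
payload = process_image(Image.new("RGB", (8, 8)))
assert payload["image_url"]["url"].startswith("data:image/jpeg;base64,")

# 2) http(s)/file URLs are passed through unchanged
assert process_image("https://example.com/cat.jpg")["image_url"]["url"].startswith("https://")

# 3) a bare local path is turned into a file:// URL
assert process_image("/tmp/cat.jpg")["image_url"]["url"] == "file:///tmp/cat.jpg"
```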
""" - if isinstance(video, dict) and 'bytes' in video: - video_bytes = video['bytes'] + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] video_base64 = base64.b64encode(video_bytes).decode("utf-8") return { "type": "video_url", - "video_url": { - "url": f"data:video/mp4;base64,{video_base64}" - }, + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, } if isinstance(video, str): - video_url = (video if video.startswith( - ("http://", "file://")) else f"file://{video}") + video_url = ( + video + if video.startswith(("http://", "https://", "file://")) + else f"file://{video}" + ) return {"type": "video_url", "video_url": {"url": video_url}} raise ValueError( f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 ) + +def gen_prompt_decode_to_target_len( + tokenizer: PreTrainedTokenizerBase, + token_sequence: list[int], + target_token_len: int, + max_retry: int = 10, + add_special_tokens: bool = False, + rng: Optional[np.random.Generator] = None, +) -> tuple[str, list[int]]: + """ + Ensure decoded-then-encoded prompt length matches the target token length. + + This function decodes an initial token sequence to text and re-encodes it + , iteratively adjusting the token sequence length to match a target. + This is necessary because some tokenizers do not guarantee a 1:1 mapping + between consecutive tokens and the decoded-then-encoded sequence length. + For example, for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + + Returns a tuple of the final prompt string and the adjusted token sequence. + """ + remain_num_try = max_retry + token_mismatch = 0 + while True: + prompt = tokenizer.decode(token_sequence) + token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + if remain_num_try <= 0: + if len(token_sequence) != target_token_len: + token_mismatch = len(token_sequence) - target_token_len + break + + if len(token_sequence) == target_token_len: + break + elif len(token_sequence) < target_token_len: + if rng is not None: + extra_tokens = rng.integers( + 0, + tokenizer.vocab_size, + size=target_token_len - len(token_sequence), + ).tolist() + else: + extra_tokens = np.random.randint( + 0, + tokenizer.vocab_size, + size=target_token_len - len(token_sequence), + ).tolist() + token_sequence.extend(extra_tokens) + elif len(token_sequence) > target_token_len: + token_sequence = token_sequence[:target_token_len] + + remain_num_try -= 1 + + return prompt, token_sequence, token_mismatch + + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ + # Default values copied from benchmark_serving.py for the random dataset. 
DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. + self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", + batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info( - "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] + token_mismatch_total = 0 for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. 
- total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) + prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + token_mismatch_total += token_mismatch requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), request_id=request_id_prefix + str(i), - )) + ) + ) + # only used for embeddings benchmark. + if batchsize > 1: + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + batch_requests.append( + SampleRequest( + prompt=[req.prompt for req in batch], + prompt_len=sum(req.prompt_len for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + requests = batch_requests + + if token_mismatch_total != 0: + sign = "more" if token_mismatch_total > 0 else "fewer" + logger.warning( + "Across all generated prompts, there were %d %s tokens " + "than expected after decoding and re-encoding. This is " + "expected due to the imperfect nature of the sampling " + "procedure.", + abs(token_mismatch_total), + sign, + ) + return requests + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. + """ + return ( + self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) + + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. 
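Numerically, for a mean length X and range_ratio r the inclusive sampling interval works out as follows (example values only):

```python
import math

X, r = 128, 0.25
low, high = math.floor(X * (1 - r)), math.ceil(X * (1 + r))
low = max(low, 1)          # output lengths are never allowed to reach 0
print(low, high)           # 96 160 -> lengths drawn uniformly from [96, 160]
```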
+ output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + f"Invalid input sampling interval: low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) + + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, + input_high, + output_low, + output_high, + ) + + input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) + return input_lens, output_lens, offsets + + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decoded again. + """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + total_input_len = prefix_len + int(input_len) + prompt, adjusted_token_sequence, token_mismatch = ( + gen_prompt_decode_to_target_len( + tokenizer=tokenizer, + token_sequence=token_sequence, + target_token_len=total_input_len, + add_special_tokens=False, + rng=self._rng, + ) + ) + total_input_len = len(adjusted_token_sequence) + return prompt, total_input_len, token_mismatch + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). 
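For example, normalize_bucket_config (defined below) drops zero-probability buckets and renormalizes the remainder; with a made-up configuration:

```python
bucket_config = {(256, 256, 1): 0.6, (720, 1280, 1): 0.3, (720, 1280, 16): 0.0}
bucket_config = {k: v for k, v in bucket_config.items() if v > 0}   # drop zero-probability buckets
total = sum(bucket_config.values())
print({k: round(v / total, 3) for k, v in bucket_config.items()})
# {(256, 256, 1): 0.667, (720, 1280, 1): 0.333}
```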
+ OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. + """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config( + self, bucket_config: dict[tuple[int, int, int], float] + ) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError( + "Got invalid bucket config. Bucket config values must be non-zero." + ) + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + def generate_mm_item( + self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image( + self.generate_synthetic_image(mm_item_config[1], mm_item_config[0]) + ) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video( + self.generate_synthetic_video( + mm_item_config[1], mm_item_config[0], mm_item_config[2] + ) + ) + else: + raise ValueError(f"Invalid multimodal item configuration: {mm_item_config}") + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. 
+ """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError( + f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}" + ) + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", + bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() if k in allowed_modalities + } + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", + limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)), + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError( + f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}" + ) + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, + max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int, int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. 
+ """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice( + list(bucket_config_copy.keys()), p=list(bucket_config_copy.values()) + ) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield (mm_item_config) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning( + "Exhausted all multimodal items of modality %s", modality + ) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config(bucket_config_copy) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[ + tuple[int, int, int], float + ] = DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. + if any( + self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items() + ): + raise NotImplementedError( + "Video sampling not implemented; set its probability to 0." 
+ ) + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] + token_mismatch_total = 0 + for i in range(num_requests): + prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + token_mismatch_total += token_mismatch + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast( + list[dict[str, Any]], + [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ], + ) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content + ) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + + if token_mismatch_total != 0: + sign = "more" if token_mismatch_total > 0 else "fewer" + logger.warning( + "Across all generated prompts, there were %d %s tokens " + "than expected after decoding and re-encoding. This is " + "expected due to the imperfect nature of the sampling " + "procedure.", + abs(token_mismatch_total), + sign, + ) + + return mm_requests + # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -470,11 +1042,13 @@ class ShareGPTDataset(BenchmarkDataset): self.data = json.load(f) # Filter entries with at least two conversation turns. 
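For reviewers, here is a minimal, self-contained sketch of the bucket-sampling scheme described in the `RandomMultiModalDataset` docstring above. It is illustrative only and not part of this patch; the helper names `normalize_buckets` and `sample_mm_items` are invented, and only the normalize / draw / per-modality-cap logic mirrors the implementation.

```python
# Illustrative only -- not part of this patch. Mirrors the bucket sampling
# described above: normalize the bucket probabilities, then draw
# (height, width, num_frames) buckets until the per-request item count or
# the per-modality caps are exhausted.
import numpy as np


def normalize_buckets(buckets):
    buckets = {k: v for k, v in buckets.items() if v > 0}
    total = sum(buckets.values())
    return {k: v / total for k, v in buckets.items()}


def sample_mm_items(rng, buckets, limits, num_items):
    buckets = normalize_buckets(buckets)
    counts = {"image": 0, "video": 0}
    picked = []
    while len(picked) < num_items and buckets:
        keys = list(buckets)
        idx = rng.choice(len(keys), p=list(buckets.values()))
        height, width, frames = keys[idx]
        modality = "image" if frames == 1 else "video"
        if counts[modality] < limits.get(modality, 0):
            counts[modality] += 1
            picked.append((height, width, frames))
        else:
            # Cap reached: drop this modality's buckets and renormalize.
            buckets = {
                k: v
                for k, v in buckets.items()
                if ("image" if k[2] == 1 else "video") != modality
            }
            if buckets:
                buckets = normalize_buckets(buckets)
    return picked


rng = np.random.default_rng(0)
print(
    sample_mm_items(
        rng,
        buckets={(256, 256, 1): 0.5, (720, 1280, 1): 0.5},
        limits={"image": 3, "video": 0},
        num_items=2,
    )
)
```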
self.data = [ - entry for entry in self.data + entry + for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] random.seed(self.random_seed) - random.shuffle(self.data) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) def sample( self, @@ -485,6 +1059,7 @@ class ShareGPTDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: samples: list = [] @@ -497,27 +1072,27 @@ class ShareGPTDataset(BenchmarkDataset): entry["conversations"][1]["value"], ) - lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + lora_request = self.get_random_lora_request( + max_loras=max_loras, lora_path=lora_path + ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) - if not is_valid_sequence(prompt_len, - new_output_len, - skip_min_output_len_check=output_len - is not None): + new_output_len = len(completion_ids) if output_len is None else output_len + if not is_valid_sequence( + prompt_len, + new_output_len, + skip_min_output_len_check=output_len is not None, + ): continue - if image_path := entry.get("image"): - mm_content = process_image(image_path) - elif video_path := entry.get("video"): + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): mm_content = process_video(video_path) - else: + else: mm_content = None if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) samples.append( SampleRequest( prompt=prompt, @@ -526,12 +1101,35 @@ class ShareGPTDataset(BenchmarkDataset): lora_request=lora_request, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests( + samples, num_requests, request_id_prefix, no_oversample + ) return samples +class _ValidateDatasetArgs(argparse.Action): + """Argparse action to validate dataset name and path compatibility.""" + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + # Get current values of both dataset_name and dataset_path + dataset_name = getattr(namespace, "dataset_name", "random") + dataset_path = getattr(namespace, "dataset_path", None) + + # Validate the combination + if dataset_name == "random" and dataset_path is not None: + parser.error( + "Cannot use 'random' dataset with --dataset-path. 
" + "Please specify the appropriate --dataset-name (e.g., " + "'sharegpt', 'custom', 'sonnet') for your dataset file: " + f"{dataset_path}" + ) + + def add_dataset_parser(parser: FlexibleArgumentParser): parser.add_argument("--seed", type=int, default=0) parser.add_argument( @@ -544,9 +1142,17 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", + action=_ValidateDatasetArgs, choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", - "prefix_repetition" + "sharegpt", + "burstgpt", + "sonnet", + "random", + "random-mm", + "hf", + "custom", + "prefix_repetition", + "spec_bench", ], help="Name of the dataset to benchmark on.", ) @@ -559,9 +1165,25 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-path", type=str, default=None, + action=_ValidateDatasetArgs, help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.", ) + parser.add_argument( + "--no-oversample", + action="store_true", + help="Do not oversample if the dataset has fewer samples than num-prompts.", + ) + parser.add_argument( + "--skip-chat-template", + action="store_true", + help="Skip applying chat template to prompt for datasets that support it.", + ) + parser.add_argument( + "--disable-shuffle", + action="store_true", + help="Disable shuffling of dataset samples for deterministic ordering.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") @@ -569,14 +1191,21 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--custom-output-len", type=int, default=256, - help= - "Number of output tokens per request, used only for custom dataset.", + help="Number of output tokens per request, used only for custom dataset.", ) - custom_group.add_argument( - "--custom-skip-chat-template", - action="store_true", - help= - "Skip applying chat template to prompt, used only for custom dataset.", + + spec_bench_group = parser.add_argument_group("spec bench dataset options") + spec_bench_group.add_argument( + "--spec-bench-output-len", + type=int, + default=256, + help="Num of output tokens per request, used only for spec bench dataset.", + ) + spec_bench_group.add_argument( + "--spec-bench-category", + type=str, + default=None, + help="Category for spec bench dataset. 
If None, use all categories.", ) sonnet_group = parser.add_argument_group("sonnet dataset options") @@ -584,22 +1213,19 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--sonnet-input-len", type=int, default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", + help="Number of input tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-output-len", type=int, default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", + help="Number of output tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-prefix-len", type=int, default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", + help="Number of prefix tokens per request, used only for sonnet dataset.", ) sharegpt_group = parser.add_argument_group("sharegpt dataset options") @@ -611,20 +1237,32 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the ShareGPT dataset.", ) + blazedit_group = parser.add_argument_group("blazedit dataset options") + blazedit_group.add_argument( + "--blazedit-min-distance", + type=float, + default=0.0, + help="Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + blazedit_group.add_argument( + "--blazedit-max-distance", + type=float, + default=1.0, + help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", type=int, default=1024, - help= - "Number of input tokens per request, used only for random sampling.", + help="Number of input tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-output-len", type=int, default=128, - help= - "Number of output tokens per request, used only for random sampling.", + help="Number of output tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-range-ratio", @@ -639,23 +1277,133 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--random-prefix-len", type=int, default=0, - help=("Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]."), + help=( + "Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]." + ), + ) + random_group.add_argument( + "--random-batch-size", + type=int, + default=1, + help=("Batch size for random sampling. Only used for embeddings benchmark."), + ) + + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset" + ) + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." 
+        ),
+    )
+    random_mm_group.add_argument(
+        "--random-mm-num-mm-items-range-ratio",
+        type=float,
+        default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
+        help=(
+            "Range ratio r in [0, 1] for sampling items per request. "
+            "We sample uniformly from the closed integer range "
+            "[floor(n*(1-r)), ceil(n*(1+r))] "
+            "where n is the base items per request. "
+            "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
+            "to the sum of per-modality limits from "
+            "--random-mm-limit-mm-per-prompt. "
+            "An error is raised if the computed min exceeds the max."
+        ),
+    )
+    random_mm_group.add_argument(
+        "--random-mm-limit-mm-per-prompt",
+        type=json.loads,
+        default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
+        help=(
+            "Per-modality hard caps for items attached per request, e.g. "
+            '\'{"image": 3, "video": 0}\'. The sampled per-request item '
+            "count is clamped to the sum of these limits. When a modality "
+            "reaches its cap, its buckets are excluded and probabilities are "
+            "renormalized. "
+            "OBS.: Only image sampling is supported for now."
+        ),
+    )
+
+    def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
+        # If already a dict (e.g., programmatic call), normalize keys
+        def normalize(d: dict) -> dict[tuple[int, int, int], float]:
+            out: dict[tuple[int, int, int], float] = {}
+            for k, val in d.items():
+                key = k
+                if isinstance(key, str):
+                    with suppress(Exception):
+                        key = ast.literal_eval(key)
+                if not (
+                    isinstance(key, tuple)
+                    and len(key) == 3
+                    and all(isinstance(x, int) for x in key)
+                ):
+                    raise ValueError(
+                        f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
+                    )
+                out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
+            return out
+
+        if isinstance(v, dict):
+            return normalize(v)
+        if isinstance(v, str):
+            # Python literal (supports tuple keys)
+            parsed = ast.literal_eval(v)
+            if not isinstance(parsed, dict):
+                raise ValueError("Bucket config must parse to a dict.")
+            return normalize(parsed)
+        raise ValueError("Unsupported value for --random-mm-bucket-config.")
+
+    random_mm_group.add_argument(
+        "--random-mm-bucket-config",
+        type=_parse_mm_bucket_config,
+        default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
+        help=(
+            "The bucket config is a dictionary mapping a multimodal item "
+            "sampling configuration to a probability. "
+            "Currently allows for 2 modalities: images and videos. "
+            "A bucket key is a tuple of (height, width, num_frames). "
+            "The value is the probability of sampling that specific item. "
+            "Example: "
+            "--random-mm-bucket-config "
+            "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
+            "First item: images with resolution 256x256 w.p. 0.5. "
+            "Second item: images with resolution 720x1280 w.p. 0.4. "
+            "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1. "
+            "OBS.: If the probabilities do not sum to 1, they are normalized. "
+            "OBS bis.: Only image sampling is supported for now."
+        ),
     )
 
     hf_group = parser.add_argument_group("hf dataset options")
-    hf_group.add_argument("--hf-subset",
-                          type=str,
-                          default=None,
-                          help="Subset of the HF dataset.")
-    hf_group.add_argument("--hf-split",
-                          type=str,
-                          default=None,
-                          help="Split of the HF dataset.")
+    hf_group.add_argument(
+        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
+    )
+    hf_group.add_argument(
+        "--hf-split", type=str, default=None, help="Split of the HF dataset."
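To make the `--random-mm-bucket-config` value format concrete, here is a hedged, standalone sketch of how such a Python-literal string with tuple keys can be parsed; `parse_bucket_config` is an illustrative name, not the helper added in this patch.

```python
# Illustrative only -- not part of this patch. Shows why a Python-literal
# parser is needed for --random-mm-bucket-config: the value is a dict literal
# with tuple keys, which json.loads cannot represent.
import ast


def parse_bucket_config(text):
    parsed = ast.literal_eval(text)
    if not isinstance(parsed, dict):
        raise ValueError("Bucket config must parse to a dict.")
    out = {}
    for key, prob in parsed.items():
        if not (
            isinstance(key, tuple)
            and len(key) == 3
            and all(isinstance(x, int) for x in key)
        ):
            raise ValueError(f"Invalid bucket key {key!r}. Expected (H, W, T).")
        out[key] = float(prob)
    return out


cfg = parse_bucket_config(
    "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}"
)
print(cfg)  # {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
```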
+ ) + hf_group.add_argument( + "--hf-name", + type=str, + default=None, + help=( + "Name of the dataset on HuggingFace " + "(e.g., 'lmarena-ai/VisionArena-Chat'). " + "Specify this if your dataset-path is a local path." + ), + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -665,7 +1413,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ) prefix_repetition_group = parser.add_argument_group( - "prefix repetition dataset options") + "prefix repetition dataset options" + ) prefix_repetition_group.add_argument( "--prefix-repetition-prefix-len", type=int, @@ -697,20 +1446,28 @@ def add_dataset_parser(parser: FlexibleArgumentParser): def get_samples(args, tokenizer) -> list[SampleRequest]: + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + if args.dataset_name == "custom": - dataset = CustomDataset(dataset_path=args.dataset_path) + dataset = CustomDataset( + dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle + ) input_requests = dataset.sample( num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.custom_output_len, - skip_chat_template=args.custom_skip_chat_template, + skip_chat_template=args.skip_chat_template, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) + dataset = SonnetDataset( + dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle + ) # For the "sonnet" dataset, formatting depends on the backend. - if args.endpoint_type == "openai-chat": + if args.backend == "openai-chat": input_requests = dataset.sample( num_requests=args.num_prompts, input_len=args.sonnet_input_len, @@ -719,10 +1476,12 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=False, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." 
+ ) input_requests = dataset.sample( num_requests=args.num_prompts, input_len=args.sonnet_input_len, @@ -731,88 +1490,166 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=True, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "hf": # all following datasets are implemented from the # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + hf_kwargs = {} + if ( + args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = VisionArenaDataset args.hf_split = "train" args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMVUDataset + args.hf_split = "validation" + args.hf_subset = None + elif ( + args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = InstructCoderDataset args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MTBenchDataset args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS + or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS + ): dataset_class = AIMODataset args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + elif ( + args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 + or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = NextEditPredictionDataset args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ASRDataset args.hf_split = "train" - elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS: + elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS: + dataset_class = BlazeditDataset + args.hf_split = "train" + hf_kwargs = { + "min_distance": args.blazedit_min_distance, + "max_distance": args.blazedit_max_distance, + } + elif ( + args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MLPerfDataset args.hf_split = "train" + elif ( + args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMStarDataset + args.hf_split = "val" + args.hf_subset = None else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) + supported_datasets = set( + [ + dataset_name + for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in 
cls.SUPPORTED_DATASET_PATHS + ] + ) raise ValueError( f"Unsupported dataset path: {args.dataset_path}. " "Huggingface dataset only supports dataset_path" f" from one of following: {supported_datasets}. " "Please consider contributing if you would " - "like to add support for additional dataset formats.") + "like to add support for additional dataset formats." + ) - if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [ - "openai-chat", - "openai-audio", - ]: + if dataset_class.IS_MULTIMODAL and not ( + args.backend in ("openai-chat", "openai-audio") + or "embeddings-" in args.backend + ): # multi-modal benchmark is only available on OpenAI Chat # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' endpoint-type.") + "'openai-audio' backends." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, dataset_split=args.hf_split, random_seed=args.seed, no_stream=args.no_stream, + hf_name=args.hf_name, + disable_shuffle=args.disable_shuffle, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + skip_chat_template=args.skip_chat_template, + **hf_kwargs, ) else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - request_id_prefix=args.request_id_prefix, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts, - request_id_prefix=args.request_id_prefix,), - "random": - lambda: RandomDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( + "spec_bench": lambda: SpecBench( + dataset_path=args.dataset_path, + category=args.spec_bench_category, + disable_shuffle=args.disable_shuffle, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.spec_bench_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "random": lambda: RandomDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, + ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, @@ -820,10 +1657,31 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: output_len=args.random_output_len, range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + no_oversample=args.no_oversample, ), - "prefix_repetition": - lambda: PrefixRepetitionRandomDataset( - random_seed=args.seed, dataset_path=args.dataset_path + "random-mm": lambda: 
RandomMultiModalDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), + "prefix_repetition": lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -832,10 +1690,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_prefixes=args.prefix_repetition_num_prefixes, output_len=args.prefix_repetition_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.backend not in ["openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err @@ -876,8 +1741,7 @@ class CustomDataset(BenchmarkDataset): # Load the JSONL file if self.dataset_path.endswith(".jsonl"): - jsonl_data = pd.read_json(path_or_buf=self.dataset_path, - lines=True) + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) # check if the JSONL file has a 'prompt' column if "prompt" not in jsonl_data.columns: @@ -891,10 +1755,12 @@ class CustomDataset(BenchmarkDataset): self.data.append(row.to_dict()) else: raise NotImplementedError( - "Only JSONL format is supported for CustomDataset.") + "Only JSONL format is supported for CustomDataset." 
+ ) random.seed(self.random_seed) - random.shuffle(self.data) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) def sample( self, @@ -906,8 +1772,19 @@ class CustomDataset(BenchmarkDataset): enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: + # load all data if needed + self.num_available_samples = len(self.data) + if num_requests <= 0: + num_requests = self.num_available_samples + logger.info( + "num_requests is set to 0 or negative, " + "so using all available samples: %d", + num_requests, + ) + sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -917,10 +1794,7 @@ class CustomDataset(BenchmarkDataset): # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], + [{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False, ) @@ -932,17 +1806,65 @@ class CustomDataset(BenchmarkDataset): prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests +# ----------------------------------------------------------------------------- +# Spec Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SpecBench(CustomDataset): + """ + Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench + Download the dataset using: + wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + """ # noqa: E501 + + def __init__(self, **kwargs) -> None: + self.category = kwargs.pop("category", None) + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = [] + + # Load the JSONL file + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) + + # check if the JSONL file has a 'turns' column + if "turns" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'turns' column.") + + for _, row in jsonl_data.iterrows(): + # sample only from a specific category if specified + if (not self.category) or (self.category == row["category"]): + prompt = row["turns"][0] + self.data.append({"prompt": prompt}) + + random.seed(self.random_seed) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) + + def sample(self, **kwargs) -> list: + # leverage CustomDataset sample + return super().sample(**kwargs) + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- + @deprecated( "SonnetDataset is deprecated and will be removed in a future version.", ) @@ -979,24 +1901,25 @@ class SonnetDataset(BenchmarkDataset): output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: # Calculate average token length for a poem line. 
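As a rough usage sketch (not part of the patch), the SpecBench loader above boils down to reading `question.jsonl` and keeping the first turn of each question, optionally filtered by category. The function name and the example category string below are illustrative placeholders.

```python
# Illustrative only -- not part of this patch. Condenses SpecBench.load_data():
# read Spec-Bench's question.jsonl and keep the first turn of each question,
# optionally restricted to a single category.
import pandas as pd


def load_spec_bench(path, category=None):
    rows = pd.read_json(path, lines=True)
    if "turns" not in rows.columns:
        raise ValueError("JSONL file must contain a 'turns' column.")
    return [
        {"prompt": row["turns"][0]}
        for _, row in rows.iterrows()
        if category is None or row["category"] == category
    ]


# Example call (path and category are placeholders):
# prompts = load_spec_bench("question.jsonl", category="summarization")
```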
tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) # Build the base prompt. base_prompt = "Pick as many lines as you can from these poem lines:\n" base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template(base_msg, - add_generation_prompt=True, - tokenize=False) + base_fmt = tokenizer.apply_chat_template( + base_msg, add_generation_prompt=True, tokenize=False + ) base_offset = len(tokenizer(base_fmt).input_ids) if input_len <= base_offset: raise ValueError( f"'input_len' must be higher than the base prompt length " - f"({base_offset}).") + f"({base_offset})." + ) # Determine how many poem lines to use. num_input_lines = round((input_len - base_offset) / avg_len) @@ -1006,22 +1929,24 @@ class SonnetDataset(BenchmarkDataset): samples = [] ind = 0 while len(samples) < num_requests: - extra_lines = random.choices(self.data, - k=num_input_lines - num_prefix_lines) + extra_lines = random.choices( + self.data, k=num_input_lines - num_prefix_lines + ) prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" msg = [{"role": "user", "content": prompt}] prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False) + msg, add_generation_prompt=True, tokenize=False + ) prompt_len = len(tokenizer(prompt_formatted).input_ids) if prompt_len <= input_len: samples.append( SampleRequest( - prompt=prompt_formatted - if return_prompt_formatted else prompt, + prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, - request_id=request_id_prefix + str(ind), - )) + request_id=request_id_prefix + str(ind), + ) + ) ind += 1 return samples @@ -1042,7 +1967,9 @@ class BurstGPTDataset(BenchmarkDataset): super().__init__(**kwargs) self.load_data() - def load_data(self, ): + def load_data( + self, + ): if self.dataset_path is None: raise ValueError("dataset_path must be provided for loading data.") @@ -1056,8 +1983,7 @@ class BurstGPTDataset(BenchmarkDataset): def _sample_loaded_data(self, num_requests: int) -> list: if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, - random_state=self.random_seed) + data = self.data.sample(n=num_requests, random_state=self.random_seed) else: data = self.data.sample( n=num_requests, @@ -1074,6 +2000,7 @@ class BurstGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, lora_path: Optional[str] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: samples = [] @@ -1081,8 +2008,9 @@ class BurstGPTDataset(BenchmarkDataset): for i in range(num_requests): input_len = int(data[i][2]) output_len = int(data[i][3]) - lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + lora_req = self.get_random_lora_request( + max_loras=max_loras, lora_path=lora_path + ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. 
@@ -1095,7 +2023,8 @@ class BurstGPTDataset(BenchmarkDataset): expected_output_len=output_len, lora_request=lora_req, request_id=request_id_prefix + str(i), - )) + ) + ) return samples @@ -1113,6 +2042,7 @@ class HuggingFaceDataset(BenchmarkDataset): dataset_split: str, no_stream: bool = False, dataset_subset: Optional[str] = None, + hf_name: Optional[str] = None, **kwargs, ) -> None: super().__init__(dataset_path=dataset_path, **kwargs) @@ -1120,6 +2050,7 @@ class HuggingFaceDataset(BenchmarkDataset): self.dataset_split = dataset_split self.dataset_subset = dataset_subset self.load_stream = not no_stream + self.hf_name = hf_name or dataset_path self.load_data() def load_data(self) -> None: @@ -1130,7 +2061,8 @@ class HuggingFaceDataset(BenchmarkDataset): split=self.dataset_split, streaming=self.load_stream, ) - self.data = self.data.shuffle(seed=self.random_seed) + if not getattr(self, "disable_shuffle", False): + self.data = self.data.shuffle(seed=self.random_seed) # ----------------------------------------------------------------------------- @@ -1140,21 +2072,25 @@ class HuggingFaceDataset(BenchmarkDataset): class ConversationDataset(HuggingFaceDataset): """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { - 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + "lmms-lab/LLaVA-OneVision-Data", + "Aeala/ShareGPT_Vicuna_unfiltered", } IS_MULTIMODAL = True - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: # Filter examples with at least 2 conversations - filtered_data = self.data.filter( - lambda x: len(x["conversations"]) >= 2) + filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] ind = 0 dynamic_output = output_len is None @@ -1171,17 +2107,14 @@ class ConversationDataset(HuggingFaceDataset): completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len): + if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - mm_content = process_image( - item["image"]) if "image" in item else None + mm_content = process_image(item["image"]) if "image" in item else None if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, @@ -1189,10 +2122,12 @@ class ConversationDataset(HuggingFaceDataset): expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1208,10 +2143,8 @@ class 
VisionArenaDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": - lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": - lambda x: x["turns"][0][0]["content"] + "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], } IS_MULTIMODAL = True @@ -1222,18 +2155,17 @@ class VisionArenaDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") prompt = parser_fn(item) mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -1241,8 +2173,7 @@ class VisionArenaDataset(HuggingFaceDataset): # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, @@ -1250,9 +2181,65 @@ class VisionArenaDataset(HuggingFaceDataset): expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests + + +class MMVUDataset(HuggingFaceDataset): + """ + MMVU Dataset. 
+ https://huggingface.co/datasets/yale-nlp/MMVU + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "yale-nlp/MMVU": lambda x: x["question"] + + " " + + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())), + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN + sampled_requests = [] + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) + if parser_fn is None: + raise ValueError(f"Unsupported dataset path: {self.hf_name}") + prompt = parser_fn(item) + mm_content = process_video(item["video"]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1276,15 +2263,18 @@ class InstructCoderDataset(HuggingFaceDataset): "likaixin/InstructCoder", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -1295,14 +2285,12 @@ class InstructCoderDataset(HuggingFaceDataset): ) # apply template - prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True, - tokenize=False, - ) + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( @@ -1311,9 +2299,11 @@ class InstructCoderDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1343,11 +2333,12 @@ class MTBenchDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len 
is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): @@ -1356,14 +2347,12 @@ class MTBenchDataset(HuggingFaceDataset): prompt = item["turns"][0] # apply template - prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True, - tokenize=False, - ) + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( @@ -1372,9 +2361,100 @@ class MTBenchDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Blazedit Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BlazeditDataset(HuggingFaceDataset): + """ + Blazedit Dataset. + https://github.com/ise-uiuc/blazedit + + 5k char version: vdaita/edit_5k_char + 10k char version: vdaita/edit_10k_char + """ # noqa: E501 + + # 5k char version will have output as ~5k chars + # 10k char version will have output as ~10k chars + # Assuming 3 char per token, 10k chars will be 3333 tokens + # We set default to 4000 to be safe + DEFAULT_OUTPUT_LEN = 4000 + SUPPORTED_DATASET_PATHS = { + "vdaita/edit_5k_char", + "vdaita/edit_10k_char", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + skip_chat_template: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + min_distance: float = 0.0, + max_distance: float = 1.0, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN + sampled_requests = [] + + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + code = item["code"] + change_request = item["change_request"] + norm_distance = item["norm_distance"] + + # compare the levenshtein distance normalized by code length + if norm_distance < min_distance or norm_distance > max_distance: + continue + + # template copied from + # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501 + prompt = f"""Given a code file, please apply the change requests and generate the new file. 
+ +Original file: +```python +{code} +``` + +Change request: +{change_request} + +Please generate the new code file in the "New file" section below.""" # noqa: E501 + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + request_id=request_id_prefix + str(i), + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests @@ -1387,17 +2467,22 @@ class AIMODataset(HuggingFaceDataset): """ Dataset class for processing a AIMO dataset with reasoning questions. """ + SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT" + "AI-MO/aimo-validation-aime", + "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: sampled_requests = [] ind = 0 dynamic_output = output_len is None @@ -1405,7 +2490,7 @@ class AIMODataset(HuggingFaceDataset): for item in self.data: if len(sampled_requests) >= num_requests: break - prompt, completion = item['problem'], item["solution"] + prompt, completion = item["problem"], item["solution"] prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -1413,10 +2498,9 @@ class AIMODataset(HuggingFaceDataset): completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, - completion_len, - max_prompt_len=2048, - max_total_len=32000): + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 + ): continue sampled_requests.append( SampleRequest( @@ -1425,11 +2509,12 @@ class AIMODataset(HuggingFaceDataset): expected_output_len=output_len, multi_modal_data=None, request_id=request_id_prefix + str(ind), - - )) + ) + ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1451,12 +2536,12 @@ You are a code completion assistant and your task is to analyze user edits and t ### Response: -""" # noqa: E501 +""" # noqa: E501 def _format_zeta_prompt( - sample: dict, - original_start_marker: str = "<|editable_region_start|>") -> dict: + sample: dict, original_start_marker: str = "<|editable_region_start|>" +) -> dict: """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. 
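Because the Blazedit distance filter is easy to miss next to the long prompt template above, here is a tiny illustrative check of the same condition; the function name is made up and not part of the patch.

```python
# Illustrative only -- not part of this patch. The Blazedit sampler keeps a
# row only when its normalized edit distance lies inside the window given by
# --blazedit-min-distance / --blazedit-max-distance (both in [0, 1]).
def keep_blazedit_row(norm_distance, min_distance=0.0, max_distance=1.0):
    return min_distance <= norm_distance <= max_distance


print(keep_blazedit_row(0.35, 0.2, 0.8))  # True
print(keep_blazedit_row(0.95, 0.2, 0.8))  # False
```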
This function formats examples from the NEP dataset @@ -1499,13 +2584,17 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - request_id_prefix: str = "", - **kwargs): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( - self.dataset_path) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name) if formatting_prompt_func is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") samples = [] for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) @@ -1514,12 +2603,16 @@ class NextEditPredictionDataset(HuggingFaceDataset): prompt=sample["prompt"], prompt_len=len(tokenizer(sample["prompt"]).input_ids), expected_output_len=len( - tokenizer(sample["expected_output"]).input_ids), + tokenizer(sample["expected_output"]).input_ids + ), request_id=request_id_prefix + str(i), - )) + ) + ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests( + samples, num_requests, request_id_prefix, no_oversample + ) return samples @@ -1560,8 +2653,7 @@ class ASRDataset(HuggingFaceDataset): IS_MULTIMODAL = True # TODO Whisper-specific. Abstract interface when more models are supported. - TRANSCRIPTION_PREAMBLE = ( - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>") + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" skip_long_audios: bool = True def sample( @@ -1570,10 +2662,10 @@ class ASRDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] @@ -1598,7 +2690,8 @@ class ASRDataset(HuggingFaceDataset): expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 if skipped: logger.warning( @@ -1607,8 +2700,9 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1646,6 +2740,7 @@ class MLPerfDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. 
@@ -1691,8 +2786,9 @@ class MLPerfDataset(HuggingFaceDataset):
             )
             ind += 1
 
-        self.maybe_oversample_requests(sampled_requests, num_requests,
-                                       request_id_prefix)
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
         return sampled_requests
@@ -1702,7 +2798,7 @@ class MLPerfDataset(HuggingFaceDataset):
 class PrefixRepetitionRandomDataset(BenchmarkDataset):
-    # Default values copied from benchmark_serving.py for the repeated prefix
+    # Default values copied from benchmark_serving.py for the repeated prefix
     # dataset.
     DEFAULT_PREFIX_LEN = 256
     DEFAULT_SUFFIX_LEN = 256
@@ -1726,6 +2822,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
         num_prefixes: int = DEFAULT_NUM_PREFIXES,
         output_len: int = DEFAULT_OUTPUT_LEN,
         request_id_prefix: str = "",
+        no_oversample: bool = False,
         **kwargs,
     ) -> list[SampleRequest]:
         vocab_size = tokenizer.vocab_size
@@ -1740,29 +2837,26 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
             """Generate tokens that decode and re-encode to exactly target_length."""
             # Generate random tokens
-            tokens = np.random.randint(
-                0, vocab_size, size=target_length).tolist()
-            text = tokenizer.decode(tokens)
-            re_encoded = tokenizer.encode(text, add_special_tokens=False)
+            tokens = np.random.randint(0, vocab_size, size=target_length).tolist()
-            if len(re_encoded) == target_length:
-                return re_encoded
-            elif len(re_encoded) < target_length:
-                # Recursively generate additional consistent tokens
-                needed = target_length - len(re_encoded)
-                extra_tokens = _generate_exact_length_tokens(needed)
-                return re_encoded + extra_tokens
-            else:
-                # Truncate to target length
-                return re_encoded[:target_length]
+            _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(  # noqa: E501
+                tokenizer=tokenizer,
+                token_sequence=tokens,
+                target_token_len=target_length,
+                add_special_tokens=False,
+            )
+            return adjusted_tokens, token_mismatch
         requests = []
+        token_mismatch_total = 0
         for _ in range(num_prefixes):
-            prefix_tokens = _generate_exact_length_tokens(prefix_len)
+            prefix_tokens, token_mismatch = _generate_exact_length_tokens(prefix_len)
+            token_mismatch_total += token_mismatch
             for _ in range(prompts_per_prefix):
-                suffix_tokens = _generate_exact_length_tokens(suffix_len)
-
+                suffix_tokens, token_mismatch = _generate_exact_length_tokens(
+                    suffix_len
+                )
+                token_mismatch_total += token_mismatch
                 combined_tokens = prefix_tokens + suffix_tokens
                 prompt = tokenizer.decode(combined_tokens)
                 prompt_len = len(combined_tokens)
@@ -1774,5 +2868,89 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
                 )
             )
-        random.shuffle(requests)
+        if token_mismatch_total != 0:
+            sign = "more" if token_mismatch_total > 0 else "fewer"
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. 
This is " + "expected due to the imperfect nature of the sampling " + "procedure.", + abs(token_mismatch_total), + sign, + ) + if not getattr(self, "disable_shuffle", False): + random.shuffle(requests) return requests + + +# ----------------------------------------------------------------------------- +# MMStar Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MMStarDataset(HuggingFaceDataset): + """ + Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar + refer to: https://github.com/sgl-project/SpecForge/pull/106 + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list[SampleRequest]: + # If --hf-output-len is not set, use the default output length. + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN + sampled_requests: list[SampleRequest] = [] + + for ind, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + # Split the question text from options + # (keep only the part before "Options:"). + full_q: str = item.get("question", "") + question_text = full_q.split("Options:", 1)[0].strip() + + # Multimodal image content. + mm_content = process_image(item["image"]) + + # Compute prompt token length (note: this is plain text length + # if enable_multimodal_chat is False). + prompt_len = len(tokenizer(question_text).input_ids) + + if enable_multimodal_chat: + # If multimodal content should be embedded in the chat message, + # convert to [{"role":"user","content":[...]}] + prompt = self.apply_multimodal_chat_transformation( + question_text, mm_content + ) + mm_for_request = None # Already embedded in chat content. + else: + # Default: prompt is plain text, + # image is in mm_content for the bench to assemble. 
+ prompt = question_text + mm_for_request = mm_content + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_for_request, + request_id=request_id_prefix + str(ind), + ) + ) + + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index 05378ec74d2fa..7692697fe768a 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -13,20 +13,20 @@ import numpy as np from tqdm import tqdm import vllm.envs as envs -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.sampling_params import BeamSearchParams -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, - extra_info={k: results[k] - for k in ["avg_latency", "percentiles"]}) + extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, + ) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" write_to_json(pt_file, pt_records) @@ -49,10 +49,9 @@ def add_cli_args(parser: argparse.ArgumentParser): default=10, help="Number of iterations to run for warmup.", ) - parser.add_argument("--num-iters", - type=int, - default=30, - help="Number of iterations to run.") + parser.add_argument( + "--num-iters", type=int, default=30, help="Number of iterations to run." + ) parser.add_argument( "--profile", action="store_true", @@ -67,8 +66,10 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) @@ -81,7 +82,8 @@ def main(args: argparse.Namespace): if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: raise OSError( "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler.") + "Please set it to a valid path to use torch profiler." + ) engine_args = EngineArgs.from_cli_args(args) # Lazy import to avoid importing LLM when the bench command is not selected. @@ -91,9 +93,11 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + - args.output_len), ("Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + args.output_len + ), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len." 
+ ) sampling_params = SamplingParams( n=args.n, @@ -103,18 +107,16 @@ def main(args: argparse.Namespace): max_tokens=args.output_len, detokenize=not args.disable_detokenize, ) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def llm_generate(): if not args.use_beam_search: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) else: llm.beam_search( dummy_prompts, diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 76beded4d5189..28146ce6200d1 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -8,18 +8,62 @@ import os import sys import time import traceback +from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import Optional, Union +from typing import Any, Literal, Optional, Protocol, Union import aiohttp +import regex as re from tqdm.asyncio import tqdm AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +class StreamedResponseHandler: + """Handles streaming HTTP responses by accumulating chunks until complete + messages are available.""" + + def __init__(self): + self.buffer = "" + + def add_chunk(self, chunk_bytes: bytes) -> list[str]: + """Add a chunk of bytes to the buffer and return any complete + messages.""" + chunk_str = chunk_bytes.decode("utf-8") + self.buffer += chunk_str + + messages = [] + + # Split by double newlines (SSE message separator) + while "\n\n" in self.buffer: + message, self.buffer = self.buffer.split("\n\n", 1) + message = message.strip() + if message: + messages.append(message) + + # if self.buffer is not empty, check if it is a complete message + # by removing data: prefix and check if it is a valid JSON + if self.buffer.startswith("data: "): + message_content = self.buffer.removeprefix("data: ").strip() + if message_content == "[DONE]": + messages.append(self.buffer.strip()) + self.buffer = "" + elif message_content: + try: + json.loads(message_content) + messages.append(self.buffer.strip()) + self.buffer = "" + except json.JSONDecodeError: + # Incomplete JSON, wait for more chunks. 
+ pass + + return messages + + @dataclass class RequestFuncInput: """The input for the request function.""" + prompt: str api_url: str prompt_len: int @@ -27,6 +71,7 @@ class RequestFuncInput: model: str model_name: Optional[str] = None logprobs: Optional[int] = None + extra_headers: Optional[dict] = None extra_body: Optional[dict] = None multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False @@ -37,16 +82,60 @@ class RequestFuncInput: @dataclass class RequestFuncOutput: """The output of the request function including metrics.""" + generated_text: str = "" success: bool = False latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: list[float] = field( - default_factory=list) # list of inter-token latencies + itl: list[float] = field(default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" + start_time: float = 0.0 + + +class RequestFunc(Protocol): + def __call__( + self, + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, + ) -> Awaitable[RequestFuncOutput]: ... + + +def _validate_api_url( + api_url: str, + api_name: str, + expected_suffixes: Union[str, set[str]], +) -> None: + if isinstance(expected_suffixes, str): + expected_suffixes = {expected_suffixes} + + expected_suffixes = {*expected_suffixes, "profile"} + + if not api_url.endswith(tuple(expected_suffixes)): + raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.") + + +def _update_payload_common( + payload: dict[str, Any], + request_func_input: RequestFuncInput, +) -> None: + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + + +def _update_headers_common( + headers: dict[str, Any], + request_func_input: RequestFuncInput, +) -> None: + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id async def async_request_openai_completions( @@ -64,13 +153,12 @@ async def async_request_openai_completions( The output of the request function. """ api_url = request_func_input.api_url - assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
+ _validate_api_url(api_url, "OpenAI Completions API", "completions") payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -81,74 +169,74 @@ async def async_request_openai_completions( "include_usage": True, }, } - if request_func_input.ignore_eos: - payload["ignore_eos"] = request_func_input.ignore_eos - if request_func_input.extra_body: - payload.update(request_func_input.extra_body) + _update_payload_common(payload, request_func_input) + headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } - if request_func_input.request_id: - headers["x-request-id"] = request_func_input.request_id + _update_headers_common(headers, request_func_input) output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len generated_text = "" st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: first_chunk_received = False - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if message.startswith(":"): + continue - if chunk != "[DONE]": - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated - if choices := data.get("choices"): - # Note that text could be empty here - # e.g. for special tokens - text = choices[0].get("text") - timestamp = time.perf_counter() - # First token - if not first_chunk_received: - first_chunk_received = True - ttft = time.perf_counter() - st - output.ttft = ttft + if chunk != "[DONE]": + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. 
for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft - most_recent_timestamp = timestamp - generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True else: output.success = False output.error = ( "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") + "This response will be marked as failed!" + ) output.generated_text = generated_text output.latency = most_recent_timestamp - st else: @@ -164,57 +252,62 @@ async def async_request_openai_completions( return output +def _get_chat_content( + request_func_input: RequestFuncInput, + mm_position: Literal["first", "last"] = "last", +) -> list[dict[str, Any]]: + text_contents = [{"type": "text", "text": request_func_input.prompt}] + + mm_contents = [] + if request_func_input.multi_modal_content: + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + mm_contents.extend(request_func_input.multi_modal_content) + elif isinstance(mm_content, dict): + mm_contents.append(request_func_input.multi_modal_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] for openai-chat" + ) + + if mm_position == "first": + return mm_contents + text_contents + + return text_contents + mm_contents + + async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, pbar: Optional[tqdm] = None, + mm_position: Literal["first", "last"] = "last", ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith(("chat/completions", "profile")), ( - "OpenAI Chat Completions API URL must end with 'chat/completions'.") + _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions") + + content = _get_chat_content(request_func_input, mm_position=mm_position) - content = [{"type": "text", "text": request_func_input.prompt}] - if request_func_input.multi_modal_content: - mm_content = request_func_input.multi_modal_content - if isinstance(mm_content, list): - content.extend(mm_content) - elif isinstance(mm_content, dict): - content.append(mm_content) - else: - raise TypeError( - "multi_modal_content must be a dict or list[dict] " - "for openai-chat" - ) payload = { - "model": - request_func_input.model_name - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "messages": [ - { - "role": "user", - "content": content - }, + {"role": "user", "content": content}, ], - "temperature": - 0.0, - "max_completion_tokens": - request_func_input.output_len, - "stream": - True, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, "stream_options": { "include_usage": True, }, } - if request_func_input.ignore_eos: - payload["ignore_eos"] = request_func_input.ignore_eos - if request_func_input.extra_body: - payload.update(request_func_input.extra_body) + _update_payload_common(payload, request_func_input) + headers = { 
"Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } - if request_func_input.request_id: - headers["x-request-id"] = request_func_input.request_id + _update_headers_common(headers, request_func_input) output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -222,46 +315,47 @@ async def async_request_openai_chat_completions( generated_text = "" ttft = 0.0 st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if message.startswith(":"): + continue - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) - most_recent_timestamp = timestamp + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -288,36 +382,27 @@ async def async_request_openai_audio( import soundfile api_url = request_func_input.api_url - assert api_url.endswith(("transcriptions", "translations")), ( - "OpenAI Chat Completions API URL must end with 'transcriptions' ") - "or `translations`." 
+ _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"}) content = [{"type": "text", "text": request_func_input.prompt}] payload = { - "model": - request_func_input.model_name - if request_func_input.model_name else request_func_input.model, - "temperature": - 0.0, - "max_completion_tokens": - request_func_input.output_len, - "stream": - True, - "language": - "en", + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", # Flattened due to multipart/form-data - "stream_include_usage": - True, - "stream_continuous_usage_stats": - True, + "stream_include_usage": True, + "stream_continuous_usage_stats": True, } - if request_func_input.extra_body: - payload.update(request_func_input.extra_body) + _update_payload_common(payload, request_func_input) + headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } - if request_func_input.request_id: - headers["x-request-id"] = request_func_input.request_id + _update_headers_common(headers, request_func_input) # Send audio file def to_bytes(y, sr): @@ -341,42 +426,47 @@ async def async_request_openai_audio( generated_text = "" ttft = 0.0 st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, - data=form, - headers=headers) as response: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + messages = handler.add_chunk(chunk_bytes) + for message in messages: + chunk = message.removeprefix("data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - # Decoding phase - else: - output.itl.append( - timestamp - most_recent_timestamp) + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp + ) - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens" + ) - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -394,16 +484,248 @@ async def async_request_openai_audio( return output +async def _run_openai_embeddings( + session: aiohttp.ClientSession, + api_url: str, + payload: dict[str, Any], + headers: dict[str, Any], + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + output = RequestFuncOutput() + st = time.perf_counter() + output.start_time = st + try: + async with session.post(url=api_url, headers=headers, json=payload) as response: + if response.status == 200: + output.latency = 
time.perf_counter() - st + data = await response.json() + output.success = True + output.generated_text = "" + output.prompt_len = data.get("usage", {}).get("prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "input": request_func_input.prompt, + } + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + return await _run_openai_embeddings( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_openai_embeddings_chat( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, + mm_position: Literal["first", "last"] = "last", +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") + + content = _get_chat_content(request_func_input, mm_position=mm_position) + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + } + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + return await _run_openai_embeddings( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +def _try_extract_request_idx(request_func_input: RequestFuncInput): + if request_func_input.request_id: + match = re.search(r"(\d+)$", request_func_input.request_id) + if match: + try: + return int(match.group(1)) + except ValueError: + pass + + return None + + +def _preprocess_clip(request_func_input: RequestFuncInput): + if request_func_input.multi_modal_content: + # Image input + request_func_input.prompt = "" + + # max_model_len=77 is too short for most datasets, + # so by default we truncate the prompt to max_model_len + if request_func_input.extra_body is None: + request_func_input.extra_body = {} + if "truncate_prompt_tokens" not in request_func_input.extra_body: + request_func_input.extra_body["truncate_prompt_tokens"] = -1 + + +def _preprocess_vlm2vec(request_func_input: RequestFuncInput): + if request_func_input.multi_modal_content: + request_idx = _try_extract_request_idx(request_func_input) + + # Adjust the ratio manually if needed. + use_image_only_prompt = request_idx is None or request_idx % 2 == 0 + + if use_image_only_prompt: + # Image input + request_func_input.prompt = "Represent the given image." 
+ else: + # Text+Image input + request_func_input.prompt = ( + f"Represent the given image with the following question: " + f"{request_func_input.prompt}" + ) + + +async def async_request_openai_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_openai_embeddings_chat( + request_func_input, + session, + pbar=pbar, + ) + + +async def async_request_openai_embeddings_vlm2vec( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + _preprocess_vlm2vec(request_func_input) + + return await async_request_openai_embeddings_chat( + request_func_input, + session, + pbar=pbar, + mm_position="first", + ) + + +async def async_request_infinity_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "Infinity Embeddings API", "embeddings") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + } + + if request_func_input.prompt: + payload["input"] = request_func_input.prompt + else: + mm_content = request_func_input.multi_modal_content + assert isinstance(mm_content, dict) + + mm_type = mm_content["type"] + payload["input"] = mm_content[mm_type]["url"] + payload["modality"] = mm_type.split("_", 1)[0] + + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + return await _run_openai_embeddings( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_infinity_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_infinity_embeddings( + request_func_input, + session, + pbar=pbar, + ) + + # TODO: Add more request functions for different API protocols. 
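Note: async_request_infinity_embeddings above sends the raw media URL as the embedding input and derives the modality from the OpenAI-style content type. A small self-contained sketch of that mapping follows; the content dict shape and model name are assumptions for illustration only, not values taken from this diff.

    # Illustrative mapping only; the content dict shape and model name are
    # assumptions, not values taken from this diff.
    mm_content = {
        "type": "image_url",
        "image_url": {"url": "https://example.com/cat.png"},
    }

    mm_type = mm_content["type"]               # "image_url"
    payload = {
        "model": "example-embedding-model",    # hypothetical model name
        "input": mm_content[mm_type]["url"],   # raw URL is sent as the input
        "modality": mm_type.split("_", 1)[0],  # "image_url" -> "image"
    }
    print(payload)  # {'model': ..., 'input': 'https://example.com/cat.png', 'modality': 'image'}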
-ASYNC_REQUEST_FUNCS = { +ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, + "openai-embeddings-chat": async_request_openai_embeddings_chat, + "openai-embeddings-clip": async_request_openai_embeddings_clip, + "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec, + # Infinity embedding server: https://github.com/michaelfeil/infinity + "infinity-embeddings": async_request_infinity_embeddings, + "infinity-embeddings-clip": async_request_infinity_embeddings_clip, + # (Infinity embedding server does not support vlm2vec) } OPENAI_COMPATIBLE_BACKENDS = [ - k for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, - async_request_openai_chat_completions) + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) ] diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 7e836158386a9..5649faf055976 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,11 +8,11 @@ import time import aiohttp from tqdm.asyncio import tqdm -from .endpoint_request_func import RequestFuncInput, RequestFuncOutput +from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput async def wait_for_endpoint( - request_func, + request_func: RequestFunc, test_input: RequestFuncInput, session: aiohttp.ClientSession, timeout_seconds: int = 600, @@ -20,30 +20,29 @@ async def wait_for_endpoint( ) -> RequestFuncOutput: """ Wait for an endpoint to become available before starting benchmarks. 
- + Args: request_func: The async request function to call test_input: The RequestFuncInput to test with timeout_seconds: Maximum time to wait in seconds (default: 10 minutes) retry_interval: Time between retries in seconds (default: 5 seconds) - + Returns: RequestFuncOutput: The successful response - + Raises: ValueError: If the endpoint doesn't become available within the timeout """ deadline = time.perf_counter() + timeout_seconds output = RequestFuncOutput(success=False) print(f"Waiting for endpoint to become up in {timeout_seconds} seconds") - + with tqdm( - total=timeout_seconds, + total=timeout_seconds, bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining", unit="s", ) as pbar: - - while True: + while True: # update progress bar remaining = deadline - time.perf_counter() elapsed = timeout_seconds - remaining @@ -57,16 +56,17 @@ async def wait_for_endpoint( # ping the endpoint using request_func try: output = await request_func( - request_func_input=test_input, session=session) + request_func_input=test_input, session=session + ) if output.success: pbar.close() return output except aiohttp.ClientConnectorError: pass - + # retry after a delay sleep_duration = min(retry_interval, remaining) if sleep_duration > 0: await asyncio.sleep(sleep_duration) - + return output diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py index 5f95fdcc75829..32e9db4990078 100644 --- a/vllm/benchmarks/lib/utils.py +++ b/vllm/benchmarks/lib/utils.py @@ -8,9 +8,9 @@ import os from typing import Any -def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: dict[str, list], - extra_info: dict[str, Any]) -> list: +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: """ Save the benchmark results in the format used by PyTorch OSS benchmark with on metric per record @@ -38,12 +38,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, }, } - tp = record["benchmark"]["extra_info"]["args"].get( - "tensor_parallel_size") + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") # Save tensor_parallel_size parameter if it's part of the metadata if not tp and "tensor_parallel_size" in extra_info: - record["benchmark"]["extra_info"]["args"][ - "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) records.append(record) @@ -51,10 +51,14 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): if isinstance(o, dict): - return {k: self.clear_inf(v) for k, v in o.items()} + return { + str(k) + if not isinstance(k, (str, int, float, bool, type(None))) + else k: self.clear_inf(v) + for k, v in o.items() + } elif isinstance(o, list): return [self.clear_inf(v) for v in o] elif isinstance(o, float) and math.isinf(o): diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 79f2c475cbe5d..c3c45f05f800b 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -4,28 +4,33 @@ r"""Benchmark online serving throughput. On the server side, run one of the following commands to launch the vLLM OpenAI API server: - vllm serve <your_model> <engine arguments> + vllm serve <your_model> <engine arguments> On the client side, run: vllm bench serve \ - --endpoint-type <endpoint_type. Default 'openai'> \ - --label <benchmark result label. 
Default using endpoint_type> \ + --backend <backend or endpoint type. Default 'openai'> \ + --label <benchmark result label. Default using backend> \ --model <your_model> \ --dataset-name <dataset_name. Default 'random'> \ --request-rate <request_rate. Default inf> \ --num-prompts <num_prompts. Default 1000> """ + import argparse import asyncio +import contextlib import gc +import importlib.util import json import os import random +import shutil import time import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime +from enum import Enum from typing import Any, Literal, Optional import aiohttp @@ -33,18 +38,28 @@ import numpy as np from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser, - get_samples) +from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples from vllm.benchmarks.lib.endpoint_request_func import ( - ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, - RequestFuncOutput) + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, +) from vllm.benchmarks.lib.ready_checker import wait_for_endpoint -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) and ( + shutil.which("gnuplot") is not None +) + + +class TaskType(Enum): + GENERATION = "generation" + EMBEDDING = "embedding" + @dataclass class BenchmarkMetrics: @@ -74,6 +89,21 @@ class BenchmarkMetrics: median_e2el_ms: float std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] + # Max output tokens per second and concurrent requests at that peak + max_output_tokens_per_s: float + max_concurrent_requests: int + + +@dataclass +class EmbedBenchmarkMetrics: + completed: int + total_input: int + request_throughput: float + total_token_throughput: float + mean_e2el_ms: float + std_e2el_ms: float + median_e2el_ms: float + percentiles_e2el_ms: float def _get_current_request_rate( @@ -84,8 +114,11 @@ def _get_current_request_rate( total_requests: int, request_rate: float, ) -> float: - if (ramp_up_strategy and ramp_up_start_rps is not None - and ramp_up_end_rps is not None): + if ( + ramp_up_strategy + and ramp_up_start_rps is not None + and ramp_up_end_rps is not None + ): progress = request_index / max(total_requests - 1, 1) if ramp_up_strategy == "linear": increase = (ramp_up_end_rps - ramp_up_start_rps) * progress @@ -123,7 +156,7 @@ async def get_request( A lower burstiness value (0 < burstiness < 1) results in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. - ramp_up_strategy (optional): + ramp_up_strategy (optional): The ramp-up strategy. Can be "linear" or "exponential". If None, uses constant request rate (specified by request_rate). ramp_up_start_rps (optional): @@ -132,10 +165,10 @@ async def get_request( The ending request rate for ramp-up. """ assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." 
+ ) # Convert to list to get length for ramp-up calculations - if isinstance(input_requests, Iterable) and not isinstance( - input_requests, list): + if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): input_requests = list(input_requests) total_requests = len(input_requests) @@ -145,12 +178,14 @@ async def get_request( request_rates = [] delay_ts = [] for request_index, request in enumerate(input_requests): - current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + current_request_rate = _get_current_request_rate( + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate, + ) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -160,7 +195,7 @@ async def get_request( # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) - + # Calculate the cumulative delay time from the first sent out requests. for i in range(1, len(delay_ts)): delay_ts[i] += delay_ts[i - 1] @@ -170,11 +205,11 @@ async def get_request( # logic would re-scale delay time to ensure the final delay_ts # align with target_total_delay_s. # - # NOTE: If we simply accumulate the random delta values - # from the gamma distribution, their sum would have 1-2% gap + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data - # from different random seeds. + # close the gap for stabilizing the throughput data + # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] delay_ts = [delay * normalize_factor for delay in delay_ts] @@ -189,6 +224,49 @@ async def get_request( yield request, request_rates[request_index] +def calculate_metrics_for_embeddings( + outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: + """Calculate the metrics for the embedding requests. + + Args: + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + selected_percentiles: The percentiles to select. + + Returns: + The calculated benchmark metrics. + """ + total_input = 0 + completed = 0 + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + e2els.append(outputs[i].latency) + completed += 1 + total_input += outputs[i].prompt_len + + if completed == 0: + warnings.warn( + "All requests failed. 
This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2, + ) + metrics = EmbedBenchmarkMetrics( + completed=completed, + total_input=total_input, + request_throughput=completed / dur_s, + total_token_throughput=total_input / dur_s, + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], + ) + return metrics + + def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -230,8 +308,10 @@ def calculate_metrics( # bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -254,16 +334,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -274,7 +357,74 @@ def calculate_metrics( warnings.warn( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) + + # Calculate max output tokens per second metric + max_output_tokens_per_s = 0.0 + max_concurrent_requests = 0 + + # Find the time range across all successful requests + successful_outputs = [output for output in outputs if output.success] + if successful_outputs: + min_start_time = min(output.start_time for output in successful_outputs) + max_end_time = max( + output.start_time + output.latency for output in successful_outputs + ) + + # Create second buckets (ceiling to ensure we capture all time) + duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1 + tokens_per_second = np.zeros(duration_seconds) + concurrent_requests_per_second = np.zeros(duration_seconds) + + for i, output in enumerate(successful_outputs): + # Calculate token generation timestamp using + # start_time, ttft, and itl + token_times = [output.start_time + output.ttft] + current_time = token_times[0] + for itl_value in output.itl: + current_time += itl_value + token_times.append(current_time) + + # Add tokens to second buckets + for token_time in token_times: + second_bucket = int(token_time - min_start_time) + if 0 <= second_bucket < duration_seconds: + tokens_per_second[second_bucket] += 1 + + # Track concurrent requests for each second this request was active + request_start_second = int(output.start_time - min_start_time) + request_end_second = int( + (output.start_time + output.latency) - min_start_time + ) + + for second in range(request_start_second, request_end_second + 1): + concurrent_requests_per_second[second] += 1 + + # Find the maximum tokens per second and corresponding + # concurrent requests + if len(tokens_per_second) > 0: + max_output_tokens_per_s = float(np.max(tokens_per_second)) + max_concurrent_requests = int(np.max(concurrent_requests_per_second)) + + if TERM_PLOTLIB_AVAILABLE: + import termplotlib as tpl + + fig = tpl.figure() + fig.plot( + np.arange(len(tokens_per_second)), + tokens_per_second, + title="Output tokens per second", + ) + fig.plot( + np.arange(len(concurrent_requests_per_second)), + concurrent_requests_per_second, + title="Concurrent requests per second", + ) + fig.show() + else: + print("tip: install termplotlib and gnuplot to plot the metrics") + metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -283,33 +433,40 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by the endpoint + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by the endpoint std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - 
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], + max_output_tokens_per_s=max_output_tokens_per_s, + max_concurrent_requests=max_concurrent_requests, ) return metrics, actual_output_lens async def benchmark( + task_type: TaskType, endpoint_type: str, api_url: str, base_url: str, @@ -328,16 +485,17 @@ async def benchmark( goodput_config_dict: dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], + extra_headers: Optional[dict], extra_body: Optional[dict], ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, ramp_up_start_rps: Optional[int] = None, ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): - if endpoint_type in ASYNC_REQUEST_FUNCS: + try: request_func = ASYNC_REQUEST_FUNCS[endpoint_type] - else: - raise ValueError(f"Unknown endpoint_type: {endpoint_type}") + except KeyError: + raise ValueError(f"Unknown backend: {endpoint_type}") from None # Reuses connections across requests to reduce TLS handshake overhead. connector = aiohttp.TCPConnector( @@ -383,51 +541,63 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, ) - test_output = await wait_for_endpoint( - request_func, - test_input, - session, - timeout_seconds=ready_check_timeout_sec, - ) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + if ready_check_timeout_sec > 0: + test_output = await wait_for_endpoint( + request_func, + test_input, + session, + timeout_seconds=ready_check_timeout_sec, + ) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark " + "arguments are correctly specified. " + f"Error: {test_output.error}" + ) + else: + print("Initial test run completed. Starting main benchmark run...") else: - print("Initial test run completed. Starting main benchmark run...") + print("Skipping endpoint ready check.") if lora_modules: # For each input request, choose a LoRA module at random. 
lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))]) + [random.choice(lora_modules) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + ) profile_output = await request_func( - request_func_input=profile_input, session=session) + request_func_input=profile_input, session=session + ) if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") - print(f"Will increase RPS from {ramp_up_start_rps} to " - f"{ramp_up_end_rps} RPS over the duration of the benchmark.") + print( + f"Will increase RPS from {ramp_up_start_rps} to " + f"{ramp_up_end_rps} RPS over the duration of the benchmark." + ) else: print(f"Traffic request rate: {request_rate}") @@ -436,22 +606,17 @@ async def benchmark( pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. 
- # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = ( + asyncio.Semaphore(max_concurrency) + if max_concurrency + else contextlib.nullcontext() + ) async def limited_request_func(request_func_input, session, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, - session=session, - pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - session=session, - pbar=pbar) + return await request_func( + request_func_input=request_func_input, session=session, pbar=pbar + ) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] @@ -460,23 +625,27 @@ async def benchmark( last_int_rps = -1 if ramp_up_strategy is not None and ramp_up_start_rps is not None: last_int_rps = ramp_up_start_rps - rps_change_events.append({ - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - }) + rps_change_events.append( + { + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + } + ) async for request, current_request_rate in get_request( - input_requests, request_rate, burstiness, ramp_up_strategy, - ramp_up_start_rps, ramp_up_end_rps): + input_requests, + request_rate, + burstiness, + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + ): if ramp_up_strategy is not None: current_int_rps = int(current_request_rate) if current_int_rps > last_int_rps: timestamp = datetime.now().isoformat() for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({ - "rps": rps_val, - "timestamp": timestamp - }) + rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, @@ -490,22 +659,27 @@ async def benchmark( req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - request_id=request_id,) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + request_id=request_id, + ) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - session=session, - pbar=pbar))) + limited_request_func( + request_func_input=request_func_input, session=session, pbar=pbar + ) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if pbar is not None: @@ -513,55 +687,95 @@ async def benchmark( benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) + if task_type == TaskType.GENERATION: + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + 
selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + else: + metrics = calculate_metrics_for_embeddings( + outputs=outputs, + dur_s=benchmark_duration, + selected_percentiles=selected_percentiles, + ) + actual_output_lens = 0 - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) if max_concurrency is not None: - print("{:<40} {:<10}".format("Maximum request concurrency:", - max_concurrency)) - if request_rate != float('inf'): - print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate )) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) + if request_rate != float("inf"): + print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + if isinstance(metrics, BenchmarkMetrics): + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + if isinstance(metrics, BenchmarkMetrics): + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Peak concurrent requests:", metrics.max_concurrent_requests + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } + if isinstance(metrics, BenchmarkMetrics): + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": metrics.request_goodput if goodput_config_dict 
else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "max_output_tokens_per_s": metrics.max_output_tokens_per_s, + "max_concurrent_requests": metrics.max_concurrent_requests, + } + else: + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "request_throughput": metrics.request_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "errors": [output.error for output in outputs], + } if rps_change_events: result["rps_change_events"] = rps_change_events @@ -578,30 +792,37 @@ async def benchmark( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") + if task_type == TaskType.GENERATION: + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") print("=" * 50) @@ -617,7 +838,8 @@ async def benchmark( logprobs=logprobs, ) profile_output = await request_func( - request_func_input=profile_input, session=session) + request_func_input=profile_input, session=session + ) if profile_output.success: print("Profiler stopped") @@ -636,12 +858,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. 
" "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -654,31 +878,42 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any], - file_name: str) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any], file_name: str +) -> None: metrics = [ - "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", - "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", - "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + "median_ttft_ms", + "mean_ttft_ms", + "std_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "median_tpot_ms", + "std_tpot_ms", + "p99_tpot_ms", + "median_itl_ms", + "mean_itl_ms", + "std_itl_ms", + "p99_itl_ms", ] # These raw data might be useful, but they are rather big. They can be added # later if needed ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] pt_records = convert_to_pytorch_benchmark_format( args=args, - metrics={k: [results[k]] - for k in metrics if k in results}, + metrics={k: [results[k]] for k in metrics if k in results}, extra_info={ k: results[k] - for k in results if k not in metrics and k not in ignored_metrics - }) + for k in results + if k not in metrics and k not in ignored_metrics + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" @@ -687,24 +922,19 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, def add_cli_args(parser: argparse.ArgumentParser): add_dataset_parser(parser) - parser.add_argument( - "--endpoint-type", - type=str, - default="openai", - choices=list(ASYNC_REQUEST_FUNCS.keys()), - ) parser.add_argument( "--label", type=str, default=None, help="The label (prefix) of the benchmark results. If not specified, " - "the endpoint type will be used as the label.", + "the value of '--backend' will be used as the label.", ) parser.add_argument( "--backend", type=str, - default="vllm", + default="openai", choices=list(ASYNC_REQUEST_FUNCS.keys()), + help="The type of backend or endpoint to use for the benchmark.", ) parser.add_argument( "--base-url", @@ -721,6 +951,15 @@ def add_cli_args(parser: argparse.ArgumentParser): default="/v1/completions", help="API endpoint.", ) + parser.add_argument( + "--header", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --header x-additional-info=0.3.3) " + "for headers to be passed with each request. These headers override " + "per backend constants and values set via environment variable, and " + "will be overriden by other arguments (such as request ids).", + ) parser.add_argument( "--max-concurrency", type=int, @@ -732,7 +971,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "initiated, this argument will control how many are actually allowed " "to execute at a time. 
This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -743,19 +983,20 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( "--logprobs", type=int, default=None, - help=("Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed"), + help=( + "Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed" + ), ) parser.add_argument( "--request-rate", @@ -838,32 +1079,34 @@ def add_cli_args(parser: argparse.ArgumentParser): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\"." - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99".' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) @@ -875,28 +1118,24 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Specify the prefix of request id.", ) - sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( "--top-p", type=float, default=None, - help="Top-p sampling parameter. 
Only has effect on " - "openai-compatible backends.", + help="Top-p sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--top-k", type=int, default=None, - help="Top-k sampling parameter. Only has effect on " - "openai-compatible backends.", + help="Top-k sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--min-p", type=float, default=None, - help="Min-p sampling parameter. Only has effect on " - "openai-compatible backends.", + help="Min-p sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--temperature", @@ -906,31 +1145,57 @@ def add_cli_args(parser: argparse.ArgumentParser): "openai-compatible backends. If not specified, default to greedy " "decoding (i.e. temperature==0.0).", ) + sampling_group.add_argument( + "--frequency-penalty", + type=float, + default=None, + help="Frequency penalty sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--presence-penalty", + type=float, + default=None, + help="Presence penalty sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--repetition-penalty", + type=float, + default=None, + help="Repetition penalty sampling parameter. Only has effect on " + "openai-compatible backends.", + ) parser.add_argument( - '--tokenizer-mode', + "--tokenizer-mode", type=str, default="auto", - choices=['auto', 'slow', 'mistral', 'custom'], + choices=["auto", "slow", "mistral", "custom"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' + "always use the slow tokenizer. \n* " '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.') + '"custom" will use --tokenizer to select the preregistered tokenizer.', + ) - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ") + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.") + parser.add_argument( + "--lora-modules", + nargs="+", + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.", + ) parser.add_argument( "--ramp-up-strategy", @@ -940,7 +1205,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The ramp-up strategy. This would be used to " "ramp up the request rate from initial RPS to final " "RPS rate (specified by --ramp-up-start-rps and " - "--ramp-up-end-rps.) over the duration of the benchmark." + "--ramp-up-end-rps.) 
over the duration of the benchmark.", ) parser.add_argument( "--ramp-up-start-rps", @@ -961,13 +1226,15 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, default=600, help="Maximum time to wait for the endpoint to become ready " - "in seconds (default: 600 seconds / 10 minutes).", + "in seconds (default: 600 seconds / 10 minutes). If set to 0, " + "the ready check will be skipped.", ) def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) + async def main_async(args: argparse.Namespace) -> dict[str, Any]: print(args) random.seed(args.seed) @@ -990,12 +1257,9 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: raise ValueError("Ramp-up start and end RPS must be non-negative") if args.ramp_up_start_rps > args.ramp_up_end_rps: raise ValueError("Ramp-up start RPS must be less than end RPS") - if (args.ramp_up_strategy == "exponential" - and args.ramp_up_start_rps == 0): - raise ValueError( - "For exponential ramp-up, the start RPS cannot be 0.") + if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: + raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") - endpoint_type = args.endpoint_type label = args.label model_id = args.model model_name = args.served_model_name @@ -1009,69 +1273,94 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" - tokenizer = get_tokenizer(tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code) + # Headers + headers = None + if args.header: + headers = {} + for item in args.header: + if "=" in item: + kvstring = item.split("=", 1) + headers[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError("Invalid header format. Please use KEY=VALUE format.") + + tokenizer = get_tokenizer( + tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required.") + "'--dataset-path' if required." + ) # Load the dataset. input_requests = get_samples(args, tokenizer) goodput_config_dict = check_goodput_args(args) + backend = args.backend + task_type = TaskType.EMBEDDING if "embeddings" in backend else TaskType.GENERATION + # Collect the sampling parameters. - sampling_params = { - k: v - for k, v in { - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "temperature": args.temperature, - }.items() if v is not None - } + if task_type == TaskType.GENERATION: + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature, + "frequency_penalty": args.frequency_penalty, + "presence_penalty": args.presence_penalty, + "repetition_penalty": args.repetition_penalty, + }.items() + if v is not None + } - # Sampling parameters are only supported by openai-compatible backend. - if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: - raise ValueError("Sampling parameters are only supported by " - "openai-compatible backends.") + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError( + "Sampling parameters are only supported by openai-compatible backends." 
+ ) - if "temperature" not in sampling_params: - sampling_params["temperature"] = 0.0 # Default to greedy decoding. + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + else: + sampling_params = {} # Avoid GC processing "static" data - reduce pause times. gc.collect() gc.freeze() benchmark_result = await benchmark( - endpoint_type=args.endpoint_type, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ready_check_timeout_sec=args.ready_check_timeout_sec, - ) + task_type=task_type, + endpoint_type=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_headers=headers, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + ready_check_timeout_sec=args.ready_check_timeout_sec, + ) # Save config and results to json result_json: dict[str, Any] = {} @@ -1079,7 +1368,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") result_json["date"] = current_dt - result_json["endpoint_type"] = args.endpoint_type + result_json["endpoint_type"] = args.backend # for backward compatibility + result_json["backend"] = args.backend result_json["label"] = label result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id @@ -1089,7 +1379,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.metadata: for item in args.metadata: if "=" in item: - kvstring = item.split("=") + kvstring = item.split("=", 1) result_json[kvstring[0].strip()] = kvstring[1].strip() else: raise ValueError( @@ -1097,8 +1387,9 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ) # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1113,12 +1404,12 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if not args.save_detailed: # Remove fields with too many data points for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", 
- "errors", + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", ]: if field in result_json: del result_json[field] @@ -1128,11 +1419,14 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Save to file if args.save_result or args.append_result: base_model_id = model_id.split("/")[-1] - max_concurrency_str = (f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None else "") - label = label or endpoint_type + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) + label = label or args.backend if args.ramp_up_strategy is not None: - file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: @@ -1140,13 +1434,13 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.result_dir: os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) - with open(file_name, - mode="a+" if args.append_result else "w", - encoding="utf-8") as outfile: + with open( + file_name, mode="a+" if args.append_result else "w", encoding="utf-8" + ) as outfile: # Append a newline. if args.append_result and outfile.tell() != 0: outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) - return result_json \ No newline at end of file + return result_json diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index f022a55e625f5..b0f63fd2c7227 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline inference throughput.""" + import argparse import dataclasses import json @@ -13,18 +14,21 @@ from typing import Any, Optional, Union import torch import uvloop from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase -from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, - ConversationDataset, - InstructCoderDataset, - PrefixRepetitionRandomDataset, - RandomDataset, SampleRequest, - ShareGPTDataset, SonnetDataset, - VisionArenaDataset) -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from vllm.benchmarks.datasets import ( + AIMODataset, + BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, + PrefixRepetitionRandomDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest @@ -37,26 +41,34 @@ def run_vllm( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, 
SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] for request in requests: - prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + prompt = ( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"]) + if "prompt_token_ids" in request.prompt + else TextPrompt(prompt=request.prompt) + ) + if request.multi_modal_data: + assert isinstance(request.multi_modal_data, dict) + prompt["multi_modal_data"] = request.multi_modal_data + prompts.append(prompt) + sampling_params.append( SamplingParams( n=n, @@ -65,7 +77,8 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -75,10 +88,13 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + if do_profile: + llm.start_profile() + outputs = llm.generate( + prompts, sampling_params, lora_request=lora_requests, use_tqdm=True + ) + if do_profile: + llm.stop_profile() end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -88,36 +104,46 @@ def run_vllm( for request in requests: assert request.expected_output_len == output_len start = time.perf_counter() + if do_profile: + llm.start_profile() llm.beam_search( prompts, BeamSearchParams( beam_width=n, max_tokens=output_len, ignore_eos=True, - )) + ), + ) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + do_profile: bool, + disable_detokenize: bool = False, +) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking multimodal models as it properly handles multimodal inputs and chat formatting. For non-multimodal models, use run_vllm() instead. 
""" from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests." + ) prompts = [] sampling_params: list[SamplingParams] = [] @@ -131,9 +157,14 @@ def run_vllm_chat( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() + if do_profile: + llm.start_profile() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs @@ -142,36 +173,44 @@ async def run_vllm_async( requests: list[SampleRequest], n: int, engine_args: AsyncEngineArgs, + do_profile: bool, disable_frontend_multiprocessing: bool = False, disable_detokenize: bool = False, ) -> float: from vllm import SamplingParams from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, + ) async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing=disable_frontend_multiprocessing, ) as llm: - model_config = await llm.get_model_config() + model_config = llm.model_config assert all( - model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. 
prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] lora_requests: list[Optional[LoRARequest]] = [] for request in requests: - prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + prompt = ( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"]) + if "prompt_token_ids" in request.prompt + else TextPrompt(prompt=request.prompt) + ) + + if request.multi_modal_data: + assert isinstance(request.multi_modal_data, dict) + prompt["multi_modal_data"] = request.multi_modal_data + sampling_params.append( SamplingParams( n=n, @@ -180,21 +219,24 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp, - lr) in enumerate(zip(prompts, sampling_params, lora_requests)): - generator = llm.generate(prompt, - sp, - lora_request=lr, - request_id=f"test{i}") + if do_profile: + await llm.start_profile() + for i, (prompt, sp, lr) in enumerate( + zip(prompts, sampling_params, lora_requests) + ): + generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass + if do_profile: + await llm.stop_profile() end = time.perf_counter() return end - start @@ -209,7 +251,8 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -232,14 +275,15 @@ def run_hf( # Check if we can add more requests to the batch. next_prompt_len = requests[i + 1].prompt_len next_output_len = requests[i + 1].expected_output_len - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: + if ( + max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len) + ) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. 
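To make the packing condition above concrete, here is a small illustration with made-up lengths: the HF path keeps adding requests to the batch as long as the worst-case padded prompt length plus the worst-case output length stays within the 2048-token budget, and only then runs generation.

# Illustrative numbers only; this mirrors the <= 2048 check used above.
max_prompt_len, max_output_len = 900, 700     # worst case in current batch
next_prompt_len, next_output_len = 1100, 200  # candidate request

fits = (max(max_prompt_len, next_prompt_len)
        + max(max_output_len, next_output_len)) <= 2048
# max(900, 1100) + max(700, 200) = 1100 + 700 = 1800 <= 2048,
# so the candidate joins the batch instead of triggering generation.
assert fits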
- input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=True, @@ -262,8 +306,9 @@ def run_hf( return end - start -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -271,9 +316,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, "tokens_per_second": [results["tokens_per_second"]], }, extra_info={ - k: results[k] - for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }) + k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" @@ -305,7 +350,8 @@ def get_requests(args, tokenizer): sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_name == "sonnet": assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." + ) dataset_cls = SonnetDataset sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["return_prompt_formatted"] = True @@ -314,21 +360,21 @@ def get_requests(args, tokenizer): elif args.dataset_name == "hf": if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_cls = InstructCoderDataset - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_split"] = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_cls = ConversationDataset - common_kwargs['dataset_subset'] = args.hf_subset - common_kwargs['dataset_split'] = args.hf_split + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_cls = AIMODataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" elif args.dataset_name == "prefix_repetition": dataset_cls = PrefixRepetitionRandomDataset sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len @@ -339,7 +385,26 @@ def get_requests(args, tokenizer): raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} - return dataset_cls(**common_kwargs).sample(**sample_kwargs) + requests = dataset_cls(**common_kwargs).sample(**sample_kwargs) + requests = filter_requests_for_dp(requests, args.data_parallel_size) + return requests + + +def filter_requests_for_dp(requests, data_parallel_size): + # Note(zhuohan): The way we get data_parallel_rank is hacky and only + # works for external launcher mode. Should be cleaned up and deprecated + # in the future with a better vLLM distributed process design. 
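As a concrete illustration of the rank arithmetic described in the note above (values are made up): torchrun sets RANK and WORLD_SIZE, every data-parallel replica spans WORLD_SIZE // data_parallel_size consecutive ranks, and each replica then takes every data_parallel_size-th request.

# Illustrative only: 8 ranks with data_parallel_size=2 gives two replicas.
world_size, data_parallel_size = 8, 2
for global_rank in range(world_size):
    data_parallel_rank = global_rank // (world_size // data_parallel_size)
    # Ranks 0-3 form replica 0 and ranks 4-7 form replica 1.
    assert data_parallel_rank == (0 if global_rank < 4 else 1)

# Replica 0 serves requests 0, 2, 4, ... and replica 1 serves 1, 3, 5, ...
requests = list(range(10))
shards = [
    [r for i, r in enumerate(requests) if i % data_parallel_size == rank]
    for rank in range(data_parallel_size)
]
assert shards[0] == [0, 2, 4, 6, 8] and shards[1] == [1, 3, 5, 7, 9]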
+ if data_parallel_size == 1: + return requests + + global_rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + data_parallel_rank = global_rank // (world_size // data_parallel_size) + return [ + r + for i, r in enumerate(requests) + if i % data_parallel_size == data_parallel_rank + ] def validate_args(args): @@ -352,7 +417,8 @@ def validate_args(args): warnings.warn( "The '--dataset' argument will be deprecated in the next release. " "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2) + stacklevel=2, + ) args.dataset_path = args.dataset if not getattr(args, "tokenizer", None): @@ -369,9 +435,8 @@ def validate_args(args): and not args.dataset_path and args.dataset_name not in {"prefix_repetition"} ): - print( - "When dataset path is not set, it will default to random dataset") - args.dataset_name = 'random' + print("When dataset path is not set, it will default to random dataset") + args.dataset_name = "random" if args.input_len is None: raise ValueError("input_len must be provided for a random dataset") @@ -379,41 +444,55 @@ def validate_args(args): # --hf-subset and --hf-split: only used # when dataset_name is 'hf' if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None): - warnings.warn("--hf-subset and --hf-split will be ignored \ + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None + ): + warnings.warn( + "--hf-subset and --hf-split will be ignored \ since --dataset-name is not 'hf'.", - stacklevel=2) + stacklevel=2, + ) elif args.dataset_name == "hf": if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm-chat", ( + f"{args.dataset_path} needs to use vllm-chat as the backend." + ) + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm", ( + f"{args.dataset_path} needs to use vllm as the backend." + ) else: - raise ValueError( - f"{args.dataset_path} is not supported by hf dataset.") + raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != 'random' and args.random_range_ratio is not None: - warnings.warn("--random-range-ratio will be ignored since \ + if args.dataset_name != "random" and args.random_range_ratio is not None: + warnings.warn( + "--random-range-ratio will be ignored since \ --dataset-name is not 'random'.", - stacklevel=2) + stacklevel=2, + ) # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # set. 
- if args.dataset_name not in {"random", "sonnet", None - } and args.prefix_len is not None: - warnings.warn("--prefix-len will be ignored since --dataset-name\ + if ( + args.dataset_name not in {"random", "sonnet", None} + and args.prefix_len is not None + ): + warnings.warn( + "--prefix-len will be ignored since --dataset-name\ is not 'random', 'sonnet', or not set.", - stacklevel=2) + stacklevel=2, + ) # === LoRA Settings === if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError( - "LoRA benchmarking is only supported for vLLM backend") + raise ValueError("LoRA benchmarking is only supported for vLLM backend") if getattr(args, "enable_lora", False) and args.lora_path is None: raise ValueError("LoRA path must be provided when enable_lora is True") @@ -423,8 +502,10 @@ def validate_args(args): if args.backend != "hf" and args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") - if args.backend in {"hf", "mii"} and getattr(args, "quantization", - None) is not None: + if ( + args.backend in {"hf", "mii"} + and getattr(args, "quantization", None) is not None + ): raise ValueError("Quantization is only for vLLM backend.") if args.backend == "mii" and args.dtype != "auto": @@ -432,32 +513,36 @@ def validate_args(args): if args.backend == "mii" and args.n != 1: raise ValueError("n must be 1 for MII backend.") if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII backend.") + + if args.data_parallel_size > 1 and ( + args.distributed_executor_backend != "external_launcher" or args.async_engine + ): + # --data-parallel is not supported fully. + # Old issue: https://github.com/vllm-project/vllm/issues/16222 + # Currently we only support data parallel with external launcher + # mode (i.e., launch with toruchrun). raise ValueError( - "Tokenizer must be the same as the model for MII backend.") - - # --data-parallel is not supported currently. - # https://github.com/vllm-project/vllm/issues/16222 - if args.data_parallel_size > 1: - raise ValueError( - "Data parallel is not supported in offline benchmark, " + "Data parallel is only supported with external launcher mode " + "with synchronous engine in offline benchmark, " "please use benchmark serving instead" ) def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm") + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm", + ) parser.add_argument( "--dataset-name", type=str, - choices=[ - "sharegpt", "random", "sonnet", "burstgpt", "hf", - "prefix_repetition" - ], + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"], help="Name of the dataset to benchmark on.", - default="sharegpt") + default="sharegpt", + ) parser.add_argument( "--dataset", type=str, @@ -465,57 +550,70 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Path to the ShareGPT dataset, will be deprecated in\ the next release. 
The dataset is expected to " "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: <prompt_or_response>]]]]") - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") + "list[dict[..., value: <prompt_or_response>]]]]", + ) parser.add_argument( - '--output-json', + "--dataset-path", type=str, default=None, help="Path to the dataset" + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=1000, help="Number of prompts to process." + ) + parser.add_argument( + "--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.", + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--async-engine", + action="store_true", + default=False, + help="Use vLLM async engine rather than LLM class.", + ) + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + default=False, + help="Disable decoupled async engine frontend.", + ) parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)")) + help=( + "Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)" + ), + ) # LoRA parser.add_argument( "--lora-path", type=str, default=None, help="Path to the lora adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.") + "a relative path, or a Hugging Face model identifier.", + ) parser.add_argument( "--prefix-len", type=int, @@ -535,18 +633,24 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # hf dtaset - parser.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - parser.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + parser.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + parser.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." 
+ ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Use Torch Profiler. The env variable " + "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.", + ) # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( - "prefix repetition dataset options") + "prefix repetition dataset options" + ) prefix_repetition_group.add_argument( "--prefix-repetition-prefix-len", type=int, @@ -588,10 +692,10 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None - for request in requests) + is_multi_modal = any(request.multi_modal_data is not None for request in requests) request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": if args.async_engine: @@ -600,22 +704,40 @@ def main(args: argparse.Namespace): requests, args.n, AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, - )) + disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, + ) + ) else: elapsed_time, request_outputs = run_vllm( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, + args.n, + EngineArgs.from_cli_args(args), + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, + ) elif args.backend == "hf": assert args.tensor_parallel_size == 1 - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.hf_max_batch_size, args.trust_remote_code, - args.disable_detokenize) + if args.profile: + raise NotImplementedError("Profiling not implemented yet for backend='hf'.") + elapsed_time = run_hf( + requests, + args.model, + tokenizer, + args.n, + args.hf_max_batch_size, + args.trust_remote_code, + args.disable_detokenize, + ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, + args.n, + EngineArgs.from_cli_args(args), + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, + ) else: raise ValueError(f"Unknown backend: {args.backend}") @@ -627,28 +749,31 @@ def main(args: argparse.Namespace): for ro in request_outputs: if not isinstance(ro, RequestOutput): continue - total_prompt_tokens += len( - ro.prompt_token_ids) if ro.prompt_token_ids else 0 - total_output_tokens += sum( - len(o.token_ids) for o in ro.outputs if o) + total_prompt_tokens += ( + len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 + ) + total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) total_num_tokens = total_prompt_tokens + total_output_tokens else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len - for r in requests) + total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests) total_prompt_tokens = total_num_tokens - total_output_tokens if is_multi_modal and args.backend != "vllm-chat": - print("\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. 
See vllm-project/vllm/issues/9778 for details.") + print( + "\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details." + ) # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # vllm-chat backend counts the image tokens now - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s" + ) print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") diff --git a/vllm/collect_env.py b/vllm/collect_env.py index ee43ad12e8a5e..4ca0852e3998f 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -9,6 +9,7 @@ import locale import os import subprocess import sys + # Unlike the rest of the PyTorch this file must be python2 compliant. # This script outputs relevant system environment info # Run it with `python collect_env.py` or `python -m torch.utils.collect_env` @@ -20,46 +21,47 @@ from vllm.envs import environment_variables try: import torch + TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information SystemEnv = namedtuple( - 'SystemEnv', + "SystemEnv", [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', - 'rocm_version', # vllm specific field - 'neuron_sdk_version', # vllm specific field - 'vllm_version', # vllm specific field - 'vllm_build_flags', # vllm specific field - 'gpu_topo', # vllm specific field - 'env_vars', - ]) + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + "rocm_version", # vllm specific field + "vllm_version", # vllm specific field + "vllm_build_flags", # vllm specific field + "gpu_topo", # vllm specific field + "env_vars", + ], +) DEFAULT_CONDA_PATTERNS = { "torch", @@ -75,6 +77,7 @@ DEFAULT_CONDA_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } DEFAULT_PIP_PATTERNS = { @@ -90,6 +93,7 @@ DEFAULT_PIP_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } @@ -97,18 +101,17 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False try: - p = subprocess.Popen(command, - 
stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=shell) + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' + if get_platform() == "win32": + enc = "oem" else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) - if command == 'nvidia-smi topo -m': + if command == "nvidia-smi topo -m": # don't remove the leading whitespace of `nvidia-smi topo -m` # because they are meaningful output = output.rstrip() @@ -119,7 +122,7 @@ def run(command): except FileNotFoundError: cmd_str = command if isinstance(command, str) else command[0] - return 127, '', f"Command not found: {cmd_str}" + return 127, "", f"Command not found: {cmd_str}" def run_and_read_all(run_lambda, command): @@ -146,49 +149,54 @@ def run_and_return_first_line(run_lambda, command): rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split('\n')[0] + return out.split("\n")[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = DEFAULT_CONDA_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') - out = run_and_read_all(run_lambda, [conda, 'list']) + conda = os.environ.get("CONDA_EXE", "conda") + out = run_and_read_all(run_lambda, [conda, "list"]) if out is None: return out - return "\n".join(line for line in out.splitlines() - if not line.startswith("#") and any(name in line - for name in patterns)) + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") and any(name in line for name in patterns) + ) def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', - r'clang version (.*)') + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', - r'cmake (.*)') + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, - r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) 
") def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( - torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -201,43 +209,42 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) + return re.sub(uuid_regex, "", out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', - r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
- cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') + l = os.environ.get("CUDNN_LIBRARY") if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split('\n'): + for fn in out.split("\n"): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -247,20 +254,20 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux - smi = 'nvidia-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', - 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', - 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -271,17 +278,9 @@ def get_nvidia_smi(): def get_rocm_version(run_lambda): """Returns the ROCm version if available, otherwise 'N/A'.""" - return run_and_parse_first_match(run_lambda, 'hipcc --version', - r'HIP version: (\S+)') - - -def get_neuron_sdk_version(run_lambda): - # Adapted from your install script - try: - result = run_lambda(["neuron-ls"]) - return result if result[0] == 0 else 'N/A' - except Exception: - return 'N/A' + return run_and_parse_first_match( + run_lambda, "hipcc --version", r"HIP version: (\S+)" + ) def get_vllm_version(): @@ -290,12 +289,12 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" version_str = __version_tuple__[-1] - if isinstance(version_str, str) and version_str.startswith('g'): + if isinstance(version_str, str) and version_str.startswith("g"): # it's a dev build - if '.' in version_str: + if "." in version_str: # it's a dev build containing local changes - git_sha = version_str.split('.')[0][1:] - date = version_str.split('.')[-1][1:] + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] return f"{__version__} (git sha: {git_sha}, date: {date})" else: # it's a dev build without local changes @@ -306,20 +305,19 @@ def get_vllm_version(): def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
- return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( - os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), - 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', - 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', + return "CUDA Archs: {}; ROCm: {}".format( + os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), + "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", ) def get_gpu_topo(run_lambda): output = None - if get_platform() == 'linux': - output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if get_platform() == "linux": + output = run_and_read_all(run_lambda, "nvidia-smi topo -m") if output is None: - output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + output = run_and_read_all(run_lambda, "rocm-smi --showtopo") return output @@ -401,17 +399,17 @@ def get_gpu_topo(run_lambda): def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": rc, out, err = run_lambda( - 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" ) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' + cpu_info = "None" if rc == 0: cpu_info = out else: @@ -420,67 +418,69 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', - r'(.*)') + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") def get_windows_version(run_lambda): - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') - findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") + findstr_cmd = os.path.join(system_root, "System32", "findstr") return run_and_read_all( - run_lambda, - '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd) + ) def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', - r'Description:\t(.*)') + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) def get_os(run_lambda): from platform import machine + 
platform = get_platform() - if platform == 'win32' or platform == 'cygwin': + if platform == "win32" or platform == "cygwin": return get_windows_version(run_lambda) - if platform == 'darwin': + if platform == "darwin": version = get_mac_version(run_lambda) if version is None: return None - return 'macOS {} ({})'.format(version, machine()) + return "macOS {} ({})".format(version, machine()) - if platform == 'linux': + if platform == "linux": # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) - return '{} ({})'.format(platform, machine()) + return "{} ({})".format(platform, machine()) # Unknown platform return platform @@ -488,14 +488,26 @@ def get_os(run_lambda): def get_python_platform(): import platform + return platform.platform() def get_libc_version(): import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) + + +def is_uv_venv(): + if os.environ.get("UV"): + return True + pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg") + if os.path.exists(pyvenv_cfg_path): + with open(pyvenv_cfg_path, "r") as f: + return any(line.startswith("uv = ") for line in f) + return False def get_pip_packages(run_lambda, patterns=None): @@ -506,14 +518,15 @@ def get_pip_packages(run_lambda, patterns=None): def run_with_pip(): try: import importlib.util - pip_spec = importlib.util.find_spec('pip') + + pip_spec = importlib.util.find_spec("pip") pip_available = pip_spec is not None except ImportError: pip_available = False if pip_available: - cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] - elif os.environ.get("UV") is not None: + cmd = [sys.executable, "-mpip", "list", "--format=freeze"] + elif is_uv_venv(): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] else: @@ -522,23 +535,24 @@ def get_pip_packages(run_lambda, patterns=None): ) out = run_and_read_all(run_lambda, cmd) - return "\n".join(line for line in out.splitlines() - if any(name in line for name in patterns)) + return "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) + ) - pip_version = 'pip3' if sys.version[0] == '3' else 'pip' + pip_version = "pip3" if sys.version[0] == "3" else "pip" out = run_with_pip() return pip_version, out def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') + config = os.environ.get("CUDA_MODULE_LOADING", "") return config else: return "N/A" @@ -547,17 +561,26 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack - return str( - torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" def get_env_vars(): - env_vars = '' - secret_terms = ('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", "NVIDIA") + env_vars = "" + 
secret_terms = ("secret", "token", "api", "access", "password") + report_prefix = ( + "TORCH", + "NCCL", + "PYTORCH", + "CUDA", + "CUBLAS", + "CUDNN", + "OMP_", + "MKL_", + "NVIDIA", + ) for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -578,30 +601,30 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda - if not hasattr(torch.version, - 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" else: # HIP version def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' + return _lst[0] if _lst else "N/A" - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" sys_version = sys.version.replace("\n", " ") conda_packages = get_conda_packages(run_lambda) rocm_version = get_rocm_version(run_lambda) - neuron_sdk_version = get_neuron_sdk_version(run_lambda) vllm_version = get_vllm_version() vllm_build_flags = summarize_vllm_build_flags() gpu_topo = get_gpu_topo(run_lambda) @@ -609,9 +632,9 @@ def get_env_info(): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format( - sys_version, - sys.maxsize.bit_length() + 1), + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -635,7 +658,6 @@ def get_env_info(): is_xnnpack_available=is_xnnpack_available(), cpu_info=get_cpu_info(run_lambda), rocm_version=rocm_version, - neuron_sdk_version=neuron_sdk_version, vllm_version=vllm_version, vllm_build_flags=vllm_build_flags, gpu_topo=gpu_topo, @@ -702,7 +724,6 @@ env_info_fmt += """ vLLM Info ============================== ROCM Version : {rocm_version} -Neuron SDK Version : {neuron_sdk_version} vLLM Version : {vllm_version} vLLM Build Flags: {vllm_build_flags} @@ -717,15 +738,14 @@ GPU Topology: def pretty_str(envinfo): - - def replace_nones(dct, replacement='Could not collect'): + def replace_nones(dct, replacement="Could not collect"): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true='Yes', false='No'): + def replace_bools(dct, true="Yes", false="No"): for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -733,43 +753,48 @@ def pretty_str(envinfo): dct[key] = false return dct - def prepend(text, tag='[prepend]'): - lines = text.split('\n') + def prepend(text, 
tag="[prepend]"): + lines = text.split("\n") updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) + return "\n".join(updated_lines) - def replace_if_empty(text, replacement='No relevant packages'): + def replace_if_empty(text, replacement="No relevant packages"): if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] - all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None - for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available( - ) and all_dynamic_cuda_fields_missing: + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' + mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' + mutable_dict["cuda_compiled_version"] = "None" # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -778,20 +803,20 @@ def pretty_str(envinfo): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty( - mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty( - mutable_dict['conda_packages']) + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend( - mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend( - mutable_dict['conda_packages'], '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -804,22 +829,29 @@ def main(): output = get_pretty_env_info() print(output) - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( - torch.utils, '_crash_handler'): + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") 
+ and hasattr(torch.utils, "_crash_handler") + ): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): dumps = [ - os.path.join(minidump_dir, dump) - for dump in os.listdir(minidump_dir) + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) ] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) creation_time = datetime.datetime.fromtimestamp(ctime).strftime( - '%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) print(msg, file=sys.stderr) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index ce4e50a2b02d1..7448bb122152d 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -1,57 +1,179 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, - register_replacement) +from torch._inductor.pattern_matcher import ( + PatternMatcherPass, + fwd_only, + register_replacement, +) +from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform -from .vllm_inductor_pass import VllmInductorPass +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 -def silu_mul_pattern_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - return at2[1] +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + +FUSED_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 +} +silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "silu_and_mul_nvfp4_quant" +) +if silu_and_mul_nvfp4_quant_supported: + FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -def silu_mul_replacement_static(result: torch.Tensor, - result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, - result=result, - input=input, - scale=scale) - return at[1] +class ActivationQuantPattern(ABC): + """ + The base class for Activation+Quant fusions. + Should not be used directly. 
+ """ + + def __init__( + self, + quant_key: QuantKey, + ): + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + + assert self.quant_key in QUANT_OPS, ( + f"unsupported quantization scheme {self.quant_key}" + ) + self.QUANT_OP = QUANT_OPS[self.quant_key] + + assert self.quant_key in FUSED_OPS, ( + f"unsupported fusion scheme {self.quant_key}" + ) + self.FUSED_OP = FUSED_OPS[self.quant_key] + + def empty_quant(self, *args, **kwargs): + kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} + return torch.empty(*args, **kwargs) + + @abstractmethod + def register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError -def empty_bf16(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") +class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Fp8StaticQuant Pattern + """ + + def __init__(self, symmetric: bool = True): + quant_key = QuantKey( + dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric + ) + super().__init__(quant_key) + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale + ) + return at2[1] + + def replacement( + result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, result=result, input=input, scale=scale + ) + return at[1] + + inputs = [ + self.empty_quant(5, 4), # result + empty_bf16(5, 4), # result_silu_mul + empty_bf16(5, 4), # input + empty_fp32(1, 1), # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -def empty_fp8(*args, **kwargs): - fp8 = current_platform.fp8_dtype() - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") +class SiluMulNvfp4QuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+Nvfp4Quant Pattern + """ + + def __init__(self): + super().__init__(kNvfp4Quant) + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + result: torch.Tensor, + output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) + at2 = auto_functionalized( + self.QUANT_OP, + output=result, + input=at1[1], + output_scale=output_scale, + input_scale=scale, + ) + return at2[1], at2[2] + + def replacement( + result: torch.Tensor, + output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + result_block_scale=output_scale, + input=input, + input_global_scale=scale, + ) + return at[1], at[2] + + inputs = [ + self.empty_quant(5, 32), # result + empty_i32(128, 4), # output_scale + empty_bf16(5, 64), # result_silu_mul + empty_bf16(5, 64), # input + empty_fp32(1, 1), # scale + ] + + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) -def empty_fp32(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") - - -class ActivationQuantFusionPass(VllmInductorPass): +class ActivationQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses a pre-defined set of custom ops into fused ops. 
It uses the torch pattern matcher to find the patterns and replace them. @@ -61,29 +183,32 @@ class ActivationQuantFusionPass(VllmInductorPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="activation_quant_fusion_pass") + pass_name="activation_quant_fusion_pass" + ) - inputs = [ - empty_fp8(5, 4), # Quant output - empty_bf16(5, 4), # Silu_and_mul output - empty_bf16(5, 4), # Input - empty_fp32(1, 1) # Scale - ] - register_replacement(silu_mul_pattern_static, - silu_mul_replacement_static, inputs, fwd_only, - self.patterns) + pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern() + pattern_silu_mul_fp8.register(self.patterns) + if silu_and_mul_nvfp4_quant_supported: + pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() + pattern_silu_mul_nvfp4.register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_act_quant_fusion") + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns in ActivationQuantFusionPass", - count) - - self.dump_graph(graph, "after_act_quant_fusion") - self.end_and_log() + def uuid(self): + return VllmInductorPass.hash_source( + self, + ActivationQuantPattern, + SiluMulFp8StaticQuantPattern, + SiluMulNvfp4QuantPattern, + ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 56494dffc96b3..826ab42462c3b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -3,6 +3,7 @@ import ast import dataclasses +import hashlib import os import pprint import time @@ -15,13 +16,23 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs +from vllm.compilation.inductor_pass import pass_context +from vllm.compilation.partition_rules import ( + inductor_partition_rule_context, + resolve_defined_ops, +) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname -from .compiler_interface import (CompilerInterface, EagerAdaptor, - InductorAdaptor, InductorStandaloneAdaptor) +from .caching import VllmSerializableFunction +from .compiler_interface import ( + CompilerInterface, + EagerAdaptor, + InductorAdaptor, + InductorStandaloneAdaptor, +) from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager @@ -31,8 +42,13 @@ logger = init_logger(__name__) def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: if compilation_config.use_inductor: - if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer( - "2.8.0.dev"): + # Use standalone compile only if requested, version is new enough, + # and the symbol actually exists in this PyTorch build. 
+ if ( + envs.VLLM_USE_STANDALONE_COMPILE + and is_torch_equal_or_newer("2.8.0.dev") + and hasattr(torch._inductor, "standalone_compile") + ): logger.debug("Using InductorStandaloneAdaptor") return InductorStandaloneAdaptor() else: @@ -67,10 +83,24 @@ class CompilerManager: def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + @contextmanager + def compile_context(self, runtime_shape: Optional[int] = None): + """Provide compilation context for the duration of compilation to set + any torch global properties we want to scope to a single Inductor + compilation (e.g. partition rules, pass context).""" + with pass_context(runtime_shape): + if self.compilation_config.use_inductor_graph_partition: + inductor_partition_ops = resolve_defined_ops( + self.compilation_config.splitting_ops + ) + with inductor_partition_rule_context(inductor_partition_ops): + yield + else: + yield + + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): """ Initialize the cache directory for the compiler. @@ -98,9 +128,9 @@ class CompilerManager: # do not use eval(), it is unsafe. self.cache = ast.literal_eval(f.read()) - self.compiler.initialize_cache(cache_dir=cache_dir, - disable_cache=disable_cache, - prefix=prefix) + self.compiler.initialize_cache( + cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix + ) def save_to_file(self): if self.disable_cache or not self.is_cache_updated: @@ -110,35 +140,46 @@ class CompilerManager: with open(self.cache_file_path, "w") as f: f.write(data) - def load(self, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Optional[Callable]: + def load( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Optional[Callable]: if (runtime_shape, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, runtime_shape) + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, runtime_shape + ) if runtime_shape is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "Directly load the %s-th graph for dynamic shape from %s via handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via " - "handle %s", graph_index, str(runtime_shape), - self.compiler.name, handle) + "Directly load the %s-th graph for shape %s from %s via handle %s", + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) return compiled_graph - def compile(self, - graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None) -> Any: + def compile( + self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time global compilation_start_time @@ -149,23 +190,27 
@@ class CompilerManager: compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, - runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if runtime_shape is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", elapsed) + "from the cache, took %.3f s", + elapsed, + ) else: logger.info( "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", str(runtime_shape), - elapsed) + "from the cache, took %.3f s", + str(runtime_shape), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -174,37 +219,47 @@ class CompilerManager: # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape, - maybe_key) + maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + + with self.compile_context(runtime_shape): + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + runtime_shape, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, - self.compiler.name)] = handle + self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph if runtime_shape is None: - logger.info( - "Cache the graph for dynamic shape for later use") + logger.info("Cache the graph for dynamic shape for later use") else: - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) + logger.info( + "Cache the graph of shape %s for later use", str(runtime_shape) + ) if runtime_shape is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "Store the %s-th graph for dynamic shape from %s via handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( "Store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), self.compiler.name, - handle) + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: @@ -212,11 +267,13 @@ class CompilerManager: elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed if runtime_shape is None: - logger.info("Compiling a graph for dynamic shape takes %.2f s", - elapsed) + logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + logger.info( + "Compiling a graph for shape %s takes %.2f s", + runtime_shape, + elapsed, + ) return compiled_graph @@ -229,8 +286,9 @@ class SplitItem: graph: fx.GraphModule -def split_graph(graph: 
fx.GraphModule, - ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]: +def split_graph( + graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload] +) -> tuple[fx.GraphModule, list[SplitItem]]: # split graph by ops subgraph_id = 0 node_to_subgraph_id = {} @@ -238,7 +296,12 @@ def split_graph(graph: fx.GraphModule, for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue - if node.op == 'call_function' and str(node.target) in ops: + # Match node.target against resolved_ops + # node.target can be OpOverloadPacket, need to check .default + if node.op == "call_function" and ( + node.target in resolved_ops + or (hasattr(node.target, "default") and node.target.default in resolved_ops) + ): subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) @@ -251,10 +314,8 @@ def split_graph(graph: fx.GraphModule, # the semantics of the graph will change when we # have mutations in the graph split_gm = torch.fx.passes.split_module.split_module( - graph, - None, - lambda node: node_to_subgraph_id[node], - keep_original_order=True) + graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True + ) outputs = [] @@ -268,10 +329,9 @@ def split_graph(graph: fx.GraphModule, module = getattr(split_gm, name) graph_id = int(name.replace("submod_", "")) - outputs.append( - SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) - # sort by intetger graph_id, rather than string name + # sort by integer graph_id, rather than string name outputs.sort(key=lambda x: x.graph_id) return split_gm, outputs @@ -292,15 +352,19 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): has some special cudagraph output handling. """ - def __init__(self, module: torch.fx.GraphModule, - compile_submod_names: list[str], vllm_config: VllmConfig, - graph_pool, vllm_backend: "VllmBackend"): + def __init__( + self, + module: torch.fx.GraphModule, + compile_submod_names: list[str], + vllm_config: VllmConfig, + vllm_backend: "VllmBackend", + ): super().__init__(module) from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.vllm_config = vllm_config self.vllm_backend = vllm_backend # When True, it annoyingly dumps the torch.fx.Graph on errors. 
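# --- Illustrative sketch (editor's addition, not part of the patch) ---
# The split_graph() change above matches node.target against a list of
# already-resolved torch._ops.OpOverload objects instead of comparing
# str(node.target) with op-name strings; resolve_defined_ops() is assumed
# to perform that name -> OpOverload resolution. The toy helper below
# (toy_split_ids, a hypothetical name) mirrors only the matching logic on
# a hand-built FX graph so both branches of the check are visible.
import torch
import torch.fx as fx


def toy_split_ids(gm: fx.GraphModule, resolved_ops: list) -> dict:
    """Assign a subgraph id to each node, starting a new one at every splitting op."""
    subgraph_id = 0
    node_to_id = {}
    for node in gm.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        is_split_op = node.op == "call_function" and (
            node.target in resolved_ops  # target is already an OpOverload
            or (  # target is an OpOverloadPacket; check its .default overload
                hasattr(node.target, "default")
                and node.target.default in resolved_ops
            )
        )
        if is_split_op:
            # splitting ops get a subgraph of their own, as in split_graph()
            subgraph_id += 1
            node_to_id[node] = subgraph_id
            subgraph_id += 1
        else:
            node_to_id[node] = subgraph_id
    return node_to_id


if __name__ == "__main__":
    g = fx.Graph()
    x = g.placeholder("x")
    a = g.call_function(torch.ops.aten.sin.default, (x,))  # OpOverload target
    b = g.call_function(torch.ops.aten.relu, (a,))  # OpOverloadPacket target
    c = g.call_function(torch.ops.aten.cos.default, (b,))
    g.output(c)
    gm = fx.GraphModule(torch.nn.Module(), g)

    # Split around aten.relu: sin and cos stay in ordinary subgraphs (0 and 2),
    # while the relu node lands in its own splitting subgraph (1) even though
    # its target is the overload packet, thanks to the .default fallback.
    for node, sid in toy_split_ids(gm, [torch.ops.aten.relu.default]).items():
        print(f"{node.op} {node.target} -> subgraph {sid}")
# --- end of illustrative sketch ---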
@@ -314,9 +378,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): with self.fake_mode, enable_python_dispatcher(): return super().run(*fake_args) - def call_module(self, target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, - ...], kwargs: dict[str, Any]) -> Any: + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: assert isinstance(target, str) output = super().call_module(target, args, kwargs) @@ -327,29 +394,44 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_dynamic_shape = self.vllm_backend.\ - compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None) + + compiled_graph_for_dynamic_shape = ( + self.vllm_backend.compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None, + ) + ) # Lazy import here to avoid circular import - from .cuda_graph import CUDAGraphOptions - from .cuda_piecewise_backend import PiecewiseBackend + from .piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( - submod, self.vllm_config, index, - len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend) + submod, + self.vllm_config, + index, + len(self.compile_submod_names), + sym_shape_indices, + compiled_graph_for_dynamic_shape, + self.vllm_backend, + ) + + if ( + self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and not self.compilation_config.use_inductor_graph_partition + ): + # We're using Dynamo-based piecewise splitting, so we wrap + # the whole subgraph with a static graph wrapper. + from .cuda_graph import CUDAGraphOptions - if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: # resolve the static graph wrapper class (e.g. CUDAGraphWrapper # class) as platform dependent. static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) + current_platform.get_static_graph_wrapper_cls() + ) # Always assign PIECEWISE runtime mode to the # CUDAGraphWrapper for piecewise_backend, to distinguish @@ -359,11 +441,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): runnable=piecewise_backend, vllm_config=self.vllm_config, runtime_mode=CUDAGraphMode.PIECEWISE, - graph_pool=self.graph_pool, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph)) + weak_ref_output=piecewise_backend.is_last_graph, + ), + ) else: self.module.__dict__[target] = piecewise_backend @@ -381,8 +464,9 @@ model_tag: str = "backbone" def set_model_tag(tag: str): """Context manager to set the model tag.""" global model_tag - assert tag != model_tag, \ + assert tag != model_tag, ( f"Model tag {tag} is the same as the current tag {model_tag}." 
+ ) old_tag = model_tag model_tag = tag try: @@ -405,7 +489,6 @@ class VllmBackend: vllm_config: VllmConfig compilation_config: CompilationConfig - graph_pool: Any _called: bool = False # the graph we compiled graph: fx.GraphModule @@ -424,22 +507,14 @@ class VllmBackend: vllm_config: VllmConfig, prefix: str = "", ): - # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, - # e.g. launguage_model, vision_model, etc. + # e.g. language_model, vision_model, etc. # when multiple parts are initialized as independent # models, we need to use the model_tag to distinguish # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global_graph_pool = current_platform.get_global_graph_pool() - - # TODO: in the future, if we want to use multiple - # streams, it might not be safe to share a global pool. - # only investigate this when we use multiple streams - self.graph_pool = global_graph_pool - # Passes to run on the graph post-grad. self.post_grad_pass_manager = PostGradPassManager() @@ -450,7 +525,8 @@ class VllmBackend: self.compilation_config = vllm_config.compilation_config self.compiler_manager: CompilerManager = CompilerManager( - self.compilation_config) + self.compilation_config + ) # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -464,16 +540,22 @@ class VllmBackend: inductor_config = config.inductor_compile_config PASS_KEY = "post_grad_custom_post_pass" if PASS_KEY in inductor_config: - # Config should automatically wrap all inductor passes if isinstance(inductor_config[PASS_KEY], PostGradPassManager): - assert (inductor_config[PASS_KEY].uuid() == - self.post_grad_pass_manager.uuid()) + # PassManager already added to config, make sure it's correct + assert ( + inductor_config[PASS_KEY].uuid() + == self.post_grad_pass_manager.uuid() + ) else: + # Config should automatically wrap all inductor passes assert isinstance(inductor_config[PASS_KEY], InductorPass) self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager - def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + def __call__( + self, graph: fx.GraphModule, example_inputs + ) -> VllmSerializableFunction: + from .caching import _compute_code_hash, compilation_config_hash_factors vllm_config = self.vllm_config if not self.compilation_config.cache_dir: @@ -482,37 +564,11 @@ class VllmBackend: # the cache dir will be the same so that we can reuse the compiled # graph. - factors = [] - # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affect the computation graph. - env_hash = envs.compute_hash() - factors.append(env_hash) - - # 1. factors come from the vllm_config (it mainly summarizes how the - # model is created) - config_hash = vllm_config.compute_hash() - factors.append(config_hash) - + factors = compilation_config_hash_factors(vllm_config) # 2. 
factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) - forward_code_files = list( - sorted(self.compilation_config.traced_files)) + code_hash = _compute_code_hash(self.compilation_config.traced_files) self.compilation_config.traced_files.clear() - logger.debug( - "Traced files (to be considered for compilation cache):\n%s", - "\n".join(forward_code_files)) - hash_content = [] - for filepath in forward_code_files: - hash_content.append(filepath) - if filepath == "<string>": - # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. - continue - with open(filepath) as f: - hash_content.append(f.read()) - import hashlib - code_hash = hashlib.md5("\n".join(hash_content).encode(), - usedforsecurity=False).hexdigest() factors.append(code_hash) # 3. compiler hash @@ -520,8 +576,9 @@ class VllmBackend: factors.append(compiler_hash) # combine all factors to generate the cache dir - hash_key = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_key = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, @@ -535,8 +592,7 @@ class VllmBackend: self.compilation_config.cache_dir = cache_dir rank = vllm_config.parallel_config.rank dp_rank = vllm_config.parallel_config.data_parallel_rank - local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", - self.prefix) + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix) os.makedirs(local_cache_dir, exist_ok=True) self.compilation_config.local_cache_dir = local_cache_dir @@ -545,16 +601,19 @@ class VllmBackend: if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: - logger.info("Using cache directory: %s for vLLM's torch.compile", - local_cache_dir) + logger.info( + "Using cache directory: %s for vLLM's torch.compile", local_cache_dir + ) - self.compiler_manager.initialize_cache(local_cache_dir, disable_cache, - self.prefix) + self.compiler_manager.initialize_cache( + local_cache_dir, disable_cache, self.prefix + ) # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) self.compilation_config.compilation_time += dynamo_time @@ -566,8 +625,14 @@ class VllmBackend: self.graph = graph self.configure_post_pass() - self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_config.splitting_ops) + if self.compilation_config.use_inductor_graph_partition: + # Let Inductor decide partitioning; avoid FX-level pre-splitting. 
+ fx_split_ops: list[str] = [] + else: + fx_split_ops = self.compilation_config.splitting_ops or [] + + resolved_split_ops = resolve_defined_ops(fx_split_ops) + self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops) from torch._dynamo.utils import lazy_format_graph_code @@ -576,25 +641,27 @@ class VllmBackend: lazy_format_graph_code("before split", self.graph) lazy_format_graph_code("after split", self.split_gm) - compilation_counter.num_piecewise_graphs_seen += len( - self.piecewise_graphs) + compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs) submod_names_to_compile = [ - item.submod_name for item in self.piecewise_graphs + item.submod_name + for item in self.piecewise_graphs if not item.is_splitting_graph ] # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes - PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, self.graph_pool, - self).run(*example_inputs) + PiecewiseCompileInterpreter( + self.split_gm, submod_names_to_compile, self.vllm_config, self + ).run(*example_inputs) graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa # use `print_readable` because it can include submodules - src = "from __future__ import annotations\nimport torch\n" + \ - self.split_gm.print_readable(print_output=False) + src = ( + "from __future__ import annotations\nimport torch\n" + + self.split_gm.print_readable(print_output=False) + ) src = src.replace("<lambda>", "GraphModule") with open(graph_path, "w") as f: f.write(src) @@ -603,12 +670,17 @@ class VllmBackend: self._called = True - if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ - not self.compilation_config.cudagraph_copy_inputs: - return self.split_gm + if ( + self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + or not self.compilation_config.cudagraph_copy_inputs + ): + return VllmSerializableFunction( + graph, example_inputs, self.prefix, self.split_gm + ) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() fake_args = [ fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t @@ -619,10 +691,12 @@ class VllmBackend: # for weights and static buffers, they will have concrete shapes. # symbolic shape only happens for input tensors. from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ - i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ - any(is_symbolic(d) for d in x.size()) + i + for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + and any(is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers @@ -647,4 +721,6 @@ class VllmBackend: list_args[index] = static_tensor return self.split_gm(*list_args) - return copy_and_call + return VllmSerializableFunction( + graph, example_inputs, self.prefix, copy_and_call + ) diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 1c3f52c533b13..6ee82e74963d9 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -12,8 +12,13 @@ class AbstractStaticGraphWrapper(Protocol): to be captured as a static graph. 
""" - def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs): + def __init__( + self, + runnable: Callable[..., Any], + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + **kwargs: Any, + ) -> None: """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -25,16 +30,13 @@ class AbstractStaticGraphWrapper(Protocol): graph runtime. See CUDAGraphMode in vllm/config.py. Note that only the subset enum `NONE`, `PIECEWISE` and `FULL` are used as concrete runtime mode for cudagraph dispatching. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. Keyword Args: kwargs: Additional keyword arguments for platform-specific configurations. """ raise NotImplementedError - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Executes the wrapped callable. diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py new file mode 100644 index 0000000000000..fc930e9b4f143 --- /dev/null +++ b/vllm/compilation/caching.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import inspect +import pickle +from unittest.mock import patch + +import torch +from torch.utils import _pytree as pytree + +import vllm.envs as envs +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +try: + from torch._dynamo.aot_compile import SerializableCallable +except ImportError: + SerializableCallable = object + +assert isinstance(SerializableCallable, type) + +logger = init_logger(__name__) + + +class VllmSerializableFunction(SerializableCallable): + """ + A wrapper around a compiled function by vllm. It will forward the tensor + inputs to the compiled function and return the result. + It also implements a serialization interface to support PyTorch's precompile + with custom backend, so that we can save and load the compiled function on + disk. There's no need to wrap around the compiled function if we don't want + to serialize them in particular cases. + Right now serialization for the custom backend is done via + serializing the Dynamo fx graph plus example inputs. 
+ """ + + def __init__(self, graph_module, example_inputs, prefix, optimized_call): + assert isinstance(graph_module, torch.fx.GraphModule) + self.graph_module = graph_module + self.example_inputs = example_inputs + self.prefix = prefix + self.optimized_call = optimized_call + self.shape_env = None + sym_input = next( + (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None + ) + if sym_input is not None: + self.shape_env = sym_input.node.shape_env + + def __call__(self, *args, **kwargs): + return self.optimized_call(*args, **kwargs) + + @classmethod + def serialize_compile_artifacts( + cls, compiled_fn: "VllmSerializableFunction" + ) -> bytes: + import sympy + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler, Options + + state = compiled_fn.__dict__.copy() + state.pop("optimized_call") + state.pop("shape_env") + for node in state["graph_module"].graph.nodes: + node.meta.pop("source_fn_stack", None) + node.meta.pop("nn_module_stack", None) + + graph_reducer_override = GraphPickler.reducer_override + + def _graph_reducer_override(self, obj): + if ( + inspect.isclass(obj) + and issubclass(obj, sympy.Function) + and hasattr(obj, "_torch_unpickler") + ): + return obj._torch_unpickler, (obj._torch_handler_name,) + if isinstance(obj, FakeTensorMode): + return type(None), () + return graph_reducer_override(self, obj) + + # Mask off tensor inputs since they are large and not needed. + state["example_inputs"] = pytree.tree_map_only( + torch.Tensor, lambda _: None, state["example_inputs"] + ) + with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): + state["graph_module"] = GraphPickler.dumps( + state["graph_module"], Options(ops_filter=None) + ) + state["example_inputs"] = GraphPickler.dumps(state["example_inputs"]) + return pickle.dumps(state) + + @classmethod + def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction": + from torch._guards import TracingContext, tracing + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from vllm.compilation.backends import VllmBackend + + state = pickle.loads(data) + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) + state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) + vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) + + def optimized_call(*example_inputs): + """ + On the first run of the optimized call, we rerun the compiler + backend which should result in a cache hit. After the backend + call returns, we just do a one-time replacement of the optimized + call with the compiled function, so that subsequent calls are on + the AOT compiled path. + """ + compile_inputs = [ + inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs) + ] + with tracing(TracingContext(fake_mode)): + fn.optimized_call = vllm_backend( + state["graph_module"], compile_inputs + ).optimized_call + return fn.optimized_call(*example_inputs) + + fn = cls(**state, optimized_call=optimized_call) + return fn + + @property + def co_name(self): + """ + Used for depyf debugging. + """ + return "VllmSerializableFunction" + + +def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affect the computation graph. 
+ env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + return factors + + +def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: + items = list(sorted(file_contents.items(), key=lambda x: x[0])) + hash_content = [] + for filepath, content in items: + hash_content.append(filepath) + if filepath == "<string>": + # This means the function was dynamically generated, with + # e.g. exec(). We can't actually check these. + continue + hash_content.append(content) + return hashlib.md5( + "\n".join(hash_content).encode(), usedforsecurity=False + ).hexdigest() + + +def _compute_code_hash(files: set[str]) -> str: + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", "\n".join(files) + ) + file_contents = {} + for filepath in files: + if filepath == "<string>": + file_contents[filepath] = "" + else: + with open(filepath) as f: + file_contents[filepath] = f.read() + return _compute_code_hash_with_content(file_contents) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 6ae50245ed3a8..988a1069cd9e7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,23 +10,31 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from .vllm_inductor_pass import VllmInductorPass +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() if find_spec("flashinfer"): try: import flashinfer.comm as flashinfer_comm - flashinfer_comm = (flashinfer_comm if hasattr( - flashinfer_comm, "trtllm_allreduce_fusion") else None) + + flashinfer_comm = ( + flashinfer_comm + if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") + else None + ) except ImportError: flashinfer_comm = None else: @@ -42,7 +50,6 @@ STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: - def __init__(self, dtype: torch.dtype, device: str): self.dtype = dtype self.device = device @@ -51,14 +58,12 @@ class BasePattern: class GEMMReduceScatterPattern(BasePattern): - def get_inputs(self): mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [mul, mm_weight] def register(self, pm_pass: PatternMatcherPass): - def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): mm = torch.ops.aten.mm.default(mul, mm_weight) reduce_scatter = torch.ops.vllm.reduce_scatter.default( @@ -80,12 +85,12 @@ class GEMMReduceScatterPattern(BasePattern): return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class 
AllGatherGEMMPattern(BasePattern): - def get_inputs(self): x = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -93,7 +98,6 @@ class AllGatherGEMMPattern(BasePattern): return [x, weight] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -108,8 +112,8 @@ class AllGatherGEMMPattern(BasePattern): return torch.ops.aten.mm.default(all_gather, weight) def replacement( - x: torch.Tensor, - weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( x, [weight], @@ -118,65 +122,87 @@ class AllGatherGEMMPattern(BasePattern): ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class ScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + mm_weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) return [input, mm_weight, scale_a, scale_b] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: - scaled_mm = torch.ops.aten._scaled_mm.default(input, - mat2=mat2, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype) + def pattern( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + scaled_mm = torch.ops.aten._scaled_mm.default( + input, + mat2=mat2, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype, + ) reduce_scatter = torch.ops.vllm.reduce_scatter.default( scaled_mm, dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return reduce_scatter - def replacement(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: - gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + def replacement( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 + gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllGatherScaledMMPattern(BasePattern): - def get_inputs(self): x = 
torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) s1 = x.shape[0] * self.tp_size @@ -186,7 +212,6 @@ class AllGatherScaledMMPattern(BasePattern): return [x, weight, scale_a, scale_b] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -194,22 +219,25 @@ class AllGatherScaledMMPattern(BasePattern): scale_b: torch.Tensor, ) -> torch.Tensor: all_gather = torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name + ) - return torch.ops.aten._scaled_mm.default(all_gather, - mat2=weight, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype) + return torch.ops.aten._scaled_mm.default( + all_gather, + mat2=weight, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype, + ) - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: + def replacement( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa x, [weight], @@ -224,29 +252,33 @@ class AllGatherScaledMMPattern(BasePattern): ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class CutlassScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + mm_weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - cutlass_mm_output = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) + cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype) return [input, mm_weight, scale_a, scale_b, cutlass_mm_output] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor) -> torch.Tensor: + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor, + ) -> torch.Tensor: cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( torch.ops._C.cutlass_scaled_mm.default, out=cutlass_mm_output, @@ -254,41 +286,58 @@ class CutlassScaledMMReduceScatterPattern(BasePattern): b=weight, a_scales=scale_a, b_scales=scale_b, - bias=None) + bias=None, + ) reduce_scatter = torch.ops.vllm.reduce_scatter.default( cutlass_scaled_mm[1], dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return reduce_scatter - def replacement(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor) -> torch.Tensor: - gemm_rs = 
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + def replacement( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor, + ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 + gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllGatherCutlassScaledMMPattern(BasePattern): - def get_inputs(self): x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) s1 = x.shape[0] * self.tp_size @@ -301,7 +350,6 @@ class AllGatherCutlassScaledMMPattern(BasePattern): return [x, weight, scale_a, scale_b, output] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -310,10 +358,8 @@ class AllGatherCutlassScaledMMPattern(BasePattern): output: torch.Tensor, ) -> torch.Tensor: all_gather = torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name + ) cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( torch.ops._C.cutlass_scaled_mm.default, @@ -322,12 +368,17 @@ class AllGatherCutlassScaledMMPattern(BasePattern): b=weight, a_scales=scale_a, b_scales=scale_b, - bias=None) + bias=None, + ) return cutlass_scaled_mm[1] - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - output: torch.Tensor) -> torch.Tensor: + def replacement( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa x, [weight], @@ -342,51 +393,54 @@ class AllGatherCutlassScaledMMPattern(BasePattern): ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) -class AsyncTPPass(VllmInductorPass): - +class AsyncTPPass(VllmPatternMatcherPass): + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) # Enable symmetric memory for the TP process group enable_symm_mem_for_group(get_tp_group().device_group.group_name) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="async_tp_pass") - GEMMReduceScatterPattern(self.model_dtype, - self.device).register(self.patterns) + pass_name="async_tp_pass" + ) + GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) - AllGatherGEMMPattern(self.model_dtype, - self.device).register(self.patterns) + AllGatherGEMMPattern(self.model_dtype, 
self.device).register(self.patterns) # These fusions are enabled only for bfloat16 models because # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling # only supports bfloat16 as the output dtype. if self.model_dtype == torch.bfloat16: - ScaledMMReduceScatterPattern(self.model_dtype, - self.device).register(self.patterns) - AllGatherScaledMMPattern(self.model_dtype, - self.device).register(self.patterns) + ScaledMMReduceScatterPattern(self.model_dtype, self.device).register( + self.patterns + ) + AllGatherScaledMMPattern(self.model_dtype, self.device).register( + self.patterns + ) - CutlassScaledMMReduceScatterPattern( - self.model_dtype, self.device).register(self.patterns) - AllGatherCutlassScaledMMPattern( - self.model_dtype, self.device).register(self.patterns) + CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register( + self.patterns + ) + AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register( + self.patterns + ) + + self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: # only do replace for specific shapes tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_async_tp_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with async TP pass.", count) - self.dump_graph(graph, "after_async_tp_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) if flashinfer_comm is not None: @@ -401,6 +455,19 @@ if flashinfer_comm is not None: 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB } + + try: + _FI_MAX_SIZES.update( + { + int(k): int(float(v) * MiB) + for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() + } + ) + except Exception as e: + raise ValueError( + "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) + ) from e + # opt for a more conservative default value # when world size is not in _FI_MAX_SIZES _DEFAULT_FI_MAX_SIZE = MiB // 2 @@ -432,8 +499,9 @@ if flashinfer_comm is not None: max_fusion_size, ) if use_flashinfer: - assert (_FI_WORKSPACE_TENSOR is not None - ), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -465,60 +533,67 @@ if flashinfer_comm is not None: quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=scale_factor, ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None - and fuse_rms_quant): + if scale_factor is not None and scale_out is None and fuse_rms_quant: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, allreduce_out, residual, rms_gamma, - scale_factor, rms_eps) + quant_out, + allreduce_out, + residual, + rms_gamma, + scale_factor, + rms_eps, + ) else: torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, - rms_eps) + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) else: if norm_out is None: - 
torch.ops._C.fused_add_rms_norm(allreduce_out, residual, - rms_gamma, rms_eps) + torch.ops._C.fused_add_rms_norm( + allreduce_out, residual, rms_gamma, rms_eps + ) norm_out = allreduce_out else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, - rms_eps) + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) if scale_factor is not None: if scale_out is not None: - torch.ops._C.scaled_fp4_quant(quant_out, norm_out, - scale_out, scale_factor) + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) else: torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor) + quant_out, norm_out, scale_factor + ) if scale_factor is None or norm_out is not None: - # we need to return allreduce outpput + # we need to return allreduce output # in cases of non quant fused AR + RMS norm # and fused AR + RMS norm + quant without fused add allreduce_in.copy_(allreduce_out) def call_trtllm_fused_allreduce_norm_fake( - allreduce_in: torch.Tensor, - residual: torch.Tensor, - rms_gamma: torch.Tensor, - rms_eps: float, - world_rank: int, - world_size: int, - launch_with_pdl: bool, - trigger_completion_at_end: bool, - fp32_acc: bool, - max_token_num: int, - pattern_code: int, - fuse_rms_quant: bool, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, - scale_out: Optional[torch.Tensor] = None, - scale_factor: Optional[torch.Tensor] = None) -> None: + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, + ) -> None: pass direct_register_custom_op( @@ -532,10 +607,10 @@ if flashinfer_comm is not None: "scale_out", ], fake_impl=call_trtllm_fused_allreduce_norm_fake, - dispatch_key=current_platform.dispatch_key, ) flashinfer_trtllm_fused_allreduce_norm = ( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default + ) class FlashInferFusedAllReduceParams: @@ -573,7 +648,7 @@ class FlashInferFusedAllReduceParams: class AllReduceRMSNormPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) with fused flashinfer implementation. Applies to allreduce + rmsnorm before attn in the first Transformer block. 
""" @@ -591,17 +666,15 @@ class AllReduceRMSNormPattern(BasePattern): def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) + rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype) return [input, rms_result, weight] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, rms_result: torch.Tensor, - weight: torch.Tensor): + def pattern( + input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor + ): allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_OP, @@ -613,8 +686,9 @@ class AllReduceRMSNormPattern(BasePattern): # rms_result, allreduce_output return rms[1], allreduce_output - def replacement(input: torch.Tensor, rms_result: torch.Tensor, - weight: torch.Tensor): + def replacement( + input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -625,20 +699,20 @@ class AllReduceRMSNormPattern(BasePattern): scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNorm, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # rms_result, allreduce_in return allreduce[3], allreduce[1] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (with residual) + This pattern replaces the allreduce + rms norm (with residual) with fused flashinfer implementation. Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. """ @@ -665,9 +739,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): ] def register(self, pm_pass: PatternMatcherPass): - - def pattern(residual: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor): + def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_ADD_OP, @@ -679,8 +751,9 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): # input, residual return rms[1], rms[2] - def replacement(residual: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor): + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -690,44 +763,46 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNorm, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # allreduce_in, residual return allreduce[1], allreduce[2] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) + static fp8 quant with fused flashinfer implementation. - Applies to allreduce + rmsnorm + quant before attn + Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.zeros([1, 8, 4], - device=self.device, - dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.quant_dtype) + input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty( + [1, 8, 4], device=self.device, dtype=self.dtype + ) + quant_result = torch.empty( + [1, 8, 4], device=self.device, dtype=self.quant_dtype + ) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) return [input, rmsnorm_result, quant_result, weight, scale] @@ -740,23 +815,31 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) + rmsnorm_out_tuple = auto_functionalized( + RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon, + ) - quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale) + quant_out_tuple = auto_functionalized( + STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out_tuple[1], + scale=scale, + ) # quant_out, allreduce_output return quant_out_tuple[1], all_reduce - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement( + input: torch.Tensor, + result_rms: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -767,8 +850,10 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant + ), scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) @@ -776,40 +861,41 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): # quant_out, allreduce_output return allreduce[4], allreduce[1] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): """ This pattern replaces the allreduce + rms norm (with residual) + static fp8 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and + Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([4, 4], - device=self.device, - dtype=self.quant_dtype) - scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) + quant_result = torch.empty( + [4, 4], device=self.device, dtype=self.quant_dtype + ) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) return [ quant_result, @@ -828,25 +914,30 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): ): allreduce_output = tensor_model_parallel_all_reduce(input) - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( + fused_add_rmsnorm_out_tuple = auto_functionalized( RMS_ADD_OP, input=allreduce_output, residual=residual, weight=weight, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP8_QUANT_OP, result=quant_result, input=fused_add_rmsnorm_out_tuple[1], - scale=scale) + scale=scale, + ) # quant_out, allreduce_output return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] - def replacement(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -856,56 +947,61 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant + ), scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - # # quant_out, rms_norm_residual + # quant_out, rms_norm_residual return allreduce[4], allreduce[2] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) + static nvfp4 quant with fused flashinfer implementation. - Applies to allreduce + rmsnorm + quant before attn + Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) + input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) + rmsnorm_result = torch.empty( + [1, 16, 16], device=self.device, dtype=self.dtype + ) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) weight = torch.empty([16], device=self.device, dtype=self.dtype) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) return [ - input, rmsnorm_result, quant_result, weight, - input_global_scale, output_scale + input, + rmsnorm_result, + quant_result, + weight, + input_global_scale, + output_scale, ] def pattern( @@ -917,26 +1013,33 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): output_scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) + rmsnorm_out_tuple = auto_functionalized( + RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, input=rmsnorm_out_tuple[1], output_scale=output_scale, - input_scale=input_global_scale) + input_scale=input_global_scale, + ) # quant_out, allreduce_output, output_scale return quant_out_tuple[1], all_reduce, quant_out_tuple[2] - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - input_global_scale: torch.Tensor, - output_scale: torch.Tensor): + def replacement( + input: torch.Tensor, + result_rms: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + 
input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -947,8 +1050,10 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): scale_out=output_scale, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant + ), scale_factor=input_global_scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) @@ -956,44 +1061,41 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): # quant_out, allreduce_output, output_scale return allreduce[4], allreduce[1], allreduce[5] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): """ This pattern replaces the allreduce + rms norm (with residual) + static nvfp4 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and + Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): input = torch.empty([16, 16], device=self.device, dtype=self.dtype) - residual = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) - weight = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) + residual = torch.empty([16, 16], device=self.device, dtype=self.dtype) + weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) return [ quant_result, @@ -1004,33 +1106,46 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): input_global_scale, ] - def pattern(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, output_scale: torch.Tensor, - weight: torch.Tensor, input_global_scale: torch.Tensor): + def pattern( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + ): allreduce_output = tensor_model_parallel_all_reduce(input) - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( + fused_add_rmsnorm_out_tuple = auto_functionalized( RMS_ADD_OP, input=allreduce_output, residual=residual, weight=weight, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, 
input=fused_add_rmsnorm_out_tuple[1], output_scale=output_scale, - input_scale=input_global_scale) + input_scale=input_global_scale, + ) # quant_out, allreduce_output, output_scale - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[ - 2], quant_out_tuple[2] + return ( + quant_out_tuple[1], + fused_add_rmsnorm_out_tuple[2], + quant_out_tuple[2], + ) - def replacement(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, output_scale: torch.Tensor, - weight: torch.Tensor, - input_global_scale: torch.Tensor): + def replacement( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1040,20 +1155,22 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): scale_out=output_scale, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant + ), scale_factor=input_global_scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # quant_out, rms_norm_residual, output_scale return allreduce[4], allreduce[2], allreduce[5] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) -class AllReduceFusionPass(VllmInductorPass): - +class AllReduceFusionPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig): super().__init__(config) self.disabled = True @@ -1061,7 +1178,8 @@ class AllReduceFusionPass(VllmInductorPass): if self.tp_size <= 1: return self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="all_reduce_fusion_pass") + pass_name="all_reduce_fusion_pass" + ) if config.model_config is None: return self.hidden_dim = config.model_config.get_hidden_size() @@ -1071,21 +1189,21 @@ class AllReduceFusionPass(VllmInductorPass): if flashinfer_comm is None: logger.warning( "Flashinfer is not installed or comm module not found, " - "skipping allreduce fusion pass") + "skipping allreduce fusion pass" + ) return # Check if the world size is supported if self.tp_size not in _FI_MAX_SIZES: logger.warning( - "Flashinfer allreduce fusion is not " - "supported for world size %s", + "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // - (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config. 
- fi_allreduce_fusion_max_token_num) + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) + // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, @@ -1094,7 +1212,8 @@ class AllReduceFusionPass(VllmInductorPass): hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, - )) + ) + ) global _FI_WORKSPACE_TENSOR _FI_WORKSPACE_TENSOR = workspace_tensor @@ -1105,8 +1224,14 @@ class AllReduceFusionPass(VllmInductorPass): max_token_num=max_num_token, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + ) + self.register_patterns() + self.dump_patterns(config, self.patterns) + + @enable_fake_mode + def register_patterns(self): for epsilon in [1e-5, 1e-6]: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, @@ -1152,19 +1277,19 @@ class AllReduceFusionPass(VllmInductorPass): self.disabled = False + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: + logger.debug("AllReduceFusionPass disabled") return - self.begin() - self.dump_graph(graph, "before_all_reduce_fusion_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_all_reduce_fusion_pass") - self.end_and_log() + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) def __del__(self): - if self.disabled: + if getattr(self, "disabled", True): return if flashinfer_comm is not None: flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( - self.ipc_handles, self.group) + self.ipc_handles, self.group + ) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7158fd685964f..e5fa2518b87be 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -17,21 +17,19 @@ from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig from vllm.utils import is_torch_equal_or_newer -from .inductor_pass import pass_context - class CompilerInterface: """ The interface for a compiler that can be used by vLLM. """ + # The name of the compiler, e.g. inductor. # This is a class-level attribute. name: str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): """ when the vLLM process uses `cache_dir` as the cache directory, the compiler should initialize itself with the cache directory, @@ -93,12 +91,14 @@ class CompilerInterface: """ return None, None - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: """ Load the compiled function from the handle. Raises an error if the handle is invalid. 
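A rough sketch of the contract above, using only the signatures visible in this hunk; `EagerToyCompiler` is hypothetical and assumes, like the real adaptors in this file, that `compile()` returns a `(callable, handle)` pair whose handle `load()` can consume later:

    class EagerToyCompiler(CompilerInterface):
        """Hypothetical minimal adaptor: no real compilation, no cache."""

        name = "eager_toy"

        def initialize_cache(self, cache_dir, disable_cache=False, prefix=""):
            # A real adaptor would redirect its artifact cache into cache_dir;
            # an eager adaptor has nothing to persist.
            self.cache_dir = cache_dir

        def load(self, handle, graph, example_inputs, graph_index, runtime_shape=None):
            # `handle` is whatever this adaptor's compile() returned as its
            # second element; an eager adaptor can ignore it and return the FX
            # GraphModule itself, which is already callable.
            return graph
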
@@ -150,11 +150,13 @@ def get_inductor_factors() -> list[Any]: factors: list[Any] = [] # summarize system state from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() factors.append(system_factors) # summarize pytorch state from torch._inductor.codecache import torch_key + torch_factors = torch_key() factors.append(torch_factors) return factors @@ -169,18 +171,19 @@ class InductorStandaloneAdaptor(CompilerInterface): Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off. """ + name = "inductor_standalone" def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): self.cache_dir = cache_dir def compile( @@ -196,6 +199,7 @@ class InductorStandaloneAdaptor(CompilerInterface): if compiler_config is not None: current_config.update(compiler_config) set_inductor_config(current_config, runtime_shape) + set_functorch_config() if isinstance(runtime_shape, int): dynamic_shapes = "from_example_inputs" @@ -203,12 +207,13 @@ class InductorStandaloneAdaptor(CompilerInterface): dynamic_shapes = "from_tracing_context" from torch._inductor import standalone_compile - with pass_context(runtime_shape): - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}) + + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None @@ -218,19 +223,23 @@ class InductorStandaloneAdaptor(CompilerInterface): compilation_counter.num_compiled_artifacts_saved += 1 return compiled_graph, (key, path) - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) path = handle[1] inductor_compiled_graph = torch._inductor.CompiledArtifact.load( - path=path, format="unpacked") + path=path, format="unpacked" + ) from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) def compiled_graph_wrapper(*args): @@ -250,21 +259,22 @@ class InductorAdaptor(CompilerInterface): """ The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. 
""" + name = "inductor" def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): self.cache_dir = cache_dir self.prefix = prefix - self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir + self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir if disable_cache: return # redirect the cache directory to a sub-directory @@ -288,6 +298,7 @@ class InductorAdaptor(CompilerInterface): ) -> tuple[Optional[Callable], Optional[Any]]: compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx + current_config = {} if compiler_config is not None: current_config.update(compiler_config) @@ -297,6 +308,7 @@ class InductorAdaptor(CompilerInterface): current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, runtime_shape) + set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -308,8 +320,8 @@ class InductorAdaptor(CompilerInterface): # it to get the hash of the compiled graph directly. hash_str, file_path = None, None - from torch._inductor.codecache import (FxGraphCache, - compiled_fx_graph_hash) + from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash + if torch.__version__.startswith("2.5"): original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" @@ -319,14 +331,18 @@ class InductorAdaptor(CompilerInterface): nonlocal file_path compiled_fn = inductor_compiled_graph.current_callable file_path = compiled_fn.__code__.co_filename # noqa - if not file_path.startswith(self.base_cache_dir): + if ( + not file_path.startswith(self.base_cache_dir) + and compiled_fn.__closure__ is not None + ): # hooked in the align_inputs_from_check_idxs function # in torch/_inductor/utils.py for cell in compiled_fn.__closure__: if not callable(cell.cell_contents): continue if cell.cell_contents.__code__.co_filename.startswith( - self.base_cache_dir): + self.base_cache_dir + ): # this is the real file path compiled from Inductor file_path = cell.cell_contents.__code__.co_filename break @@ -338,23 +354,24 @@ class InductorAdaptor(CompilerInterface): original_load_name = None def hijacked_compile_fx_inner(*args, **kwargs): - output = torch._inductor.compile_fx.compile_fx_inner( - *args, **kwargs) + output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) nonlocal hash_str inductor_compiled_graph = output if inductor_compiled_graph is not None: nonlocal file_path compiled_fn = inductor_compiled_graph.current_callable file_path = compiled_fn.__code__.co_filename # noqa - if not file_path.startswith(self.base_cache_dir): + if ( + not file_path.startswith(self.base_cache_dir) + and compiled_fn.__closure__ is not None + ): # hooked in the align_inputs_from_check_idxs function # in torch/_inductor/utils.py for cell in compiled_fn.__closure__: if not callable(cell.cell_contents): continue code = cell.cell_contents.__code__ - if code.co_filename.startswith( - self.base_cache_dir): + if code.co_filename.startswith(self.base_cache_dir): # this is the real 
file path # compiled from Inductor file_path = code.co_filename @@ -387,29 +404,38 @@ class InductorAdaptor(CompilerInterface): # for hijacking the hash of the compiled graph stack.enter_context( - patch("torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash)) + patch( + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash, + ) + ) # for providing a dummy shape environment stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env)) + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env, + ) + ) - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache) + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache if hasattr(AOTAutogradCache, "_get_shape_env"): stack.enter_context( patch( "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", - _get_shape_env)) + _get_shape_env, + ) + ) # for forcing the graph to be cached stack.enter_context( patch( "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache)) + _check_can_cache, + ) + ) # Dynamo metrics context, see method for more details. stack.enter_context(self.metrics_context()) @@ -422,23 +448,25 @@ class InductorAdaptor(CompilerInterface): # standalone_compile sometime. if is_torch_equal_or_newer("2.6"): stack.enter_context( - torch._inductor.config.patch(fx_graph_remote_cache=False)) + torch._inductor.config.patch(fx_graph_remote_cache=False) + ) # InductorAdaptor (unfortunately) requires AOTAutogradCache # to be turned off to run. It will fail to acquire the hash_str # and error if not. # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem. stack.enter_context( - torch._functorch.config.patch(enable_autograd_cache=False)) + torch._functorch.config.patch(enable_autograd_cache=False) + ) stack.enter_context( - torch._functorch.config.patch( - enable_remote_autograd_cache=False)) + torch._functorch.config.patch(enable_remote_autograd_cache=False) + ) - with pass_context(runtime_shape): - compiled_graph = compile_fx( - graph, - example_inputs, - inner_compile=hijacked_compile_fx_inner, - config_patches=current_config) + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config, + ) # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch # compilation cache. So turn off the checks if we disable the @@ -451,52 +479,63 @@ class InductorAdaptor(CompilerInterface): "failed, leading to a corrupted compilation artifact. " "We recommend trying to " "remove ~/.cache/vllm/torch_compile_cache and try again " - "to see the real issue. ") + "to see the real issue. 
" + ) assert file_path is not None, ( - "failed to get the file path of the compiled graph") + "failed to get the file path of the compiled graph" + ) return compiled_graph, (hash_str, file_path) - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) hash_str = handle[0] - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache) + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._inductor.codecache import FxGraphCache + with ExitStack() as exit_stack: exit_stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv())) + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache if hasattr(AOTAutogradCache, "_get_shape_env"): exit_stack.enter_context( patch( "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv())) + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) # Dynamo metrics context, see method for more details. exit_stack.enter_context(self.metrics_context()) if torch.__version__.startswith("2.5"): inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) + hash_str, example_inputs, True, False + ) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" f"the cache directory and try again." # noqa ) elif torch.__version__ >= "2.6": - from torch._inductor.output_code import ( - CompiledFxGraphConstantsWithGm) + from torch._inductor.output_code import CompiledFxGraphConstantsWithGm + constants = CompiledFxGraphConstantsWithGm(graph) inductor_compiled_graph, _ = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, None, constants) + hash_str, example_inputs, True, None, constants + ) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" f"the cache directory and try again." 
# noqa @@ -509,6 +548,7 @@ class InductorAdaptor(CompilerInterface): # need to know if the graph returns a tuple from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) # this is the callable we return to Dynamo to run @@ -542,6 +582,7 @@ class InductorAdaptor(CompilerInterface): """ if is_torch_equal_or_newer("2.6"): import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() else: return contextlib.nullcontext() @@ -551,8 +592,14 @@ def set_inductor_config(config, runtime_shape): if isinstance(runtime_shape, int): # for a specific batchsize, tuning triton kernel parameters # can be beneficial - config["max_autotune"] = True - config["coordinate_descent_tuning"] = True + config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE + config["coordinate_descent_tuning"] = ( + envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING + ) + + +def set_functorch_config(): + torch._functorch.config.bundled_autograd_cache = False class EagerAdaptor(CompilerInterface): diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index e01dd3915a3a1..9e8de831bcb29 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -41,7 +41,8 @@ class CompilationCounter: assert getattr(self, k) - getattr(old, k) == v, ( f"{k} not as expected, before it is {getattr(old, k)}" f", after it is {getattr(self, k)}, " - f"expected diff is {v}") + f"expected diff is {v}" + ) compilation_counter = CompilationCounter() diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 65a38197ad4e2..4c3ac9e56a377 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform @@ -44,10 +45,10 @@ class CUDAGraphWrapper: The workflow of this wrapper in the cudagraph dispatching is as follows: 1. At initialization, a runtime mode is assigned to the wrapper (FULL or - PIECEWISE). - 2. At runtime, the wrapper receives a runtime_mode and a + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a batch_descriptor(key) from the forward context and blindly trust them - for cudagraph dispatching. + for cudagraph dispatching. 3. If runtime_mode is NONE or runtime_mode does not match the mode of the wrapper, just call the runnable directly. 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, @@ -56,22 +57,22 @@ class CUDAGraphWrapper: Note: CUDAGraphWrapper does not store persistent buffers or copy any runtime inputs into that buffers for replay. We assume implementing them - is done outside of the wrapper. That is because we do not make any + is done outside of the wrapper. That is because we do not make any assumption on the dynamic shape (batch size) of the runtime inputs, as a - trade-off for staying orthogonal to compilation logic. Nevertheless, + trade-off for staying orthogonal to compilation logic. Nevertheless, tracing and checking the input addresses to be consistent during replay is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". 
""" - def __init__(self, - runnable: Callable, - vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, - graph_pool: Any = None, - cudagraph_options: Optional[CUDAGraphOptions] = None): + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + cudagraph_options: Optional[CUDAGraphOptions] = None, + ): self.runnable = runnable self.vllm_config = vllm_config - self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -81,23 +82,26 @@ class CUDAGraphWrapper: # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't # need to initialize a CUDAGraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE - if self.graph_pool is None: - self.graph_pool = current_platform.get_global_graph_pool() + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() self.cudagraph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # cudagraphs for. - self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\ - = {} + self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError(f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}") + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}" + ) def unwrap(self) -> Callable: # in case we need to access the original runnable. @@ -108,8 +112,10 @@ class CUDAGraphWrapper: batch_descriptor = forward_context.batch_descriptor cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode - if cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode != self.runtime_mode: + if ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode != self.runtime_mode + ): # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without cudagraphs. # We do not trigger capture/replay if the runtime mode is not @@ -120,8 +126,9 @@ class CUDAGraphWrapper: if batch_descriptor not in self.concrete_cudagraph_entries: # create a new entry for this batch descriptor - self.concrete_cudagraph_entries[batch_descriptor] = \ - CUDAGraphEntry(batch_descriptor=batch_descriptor) + self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry( + batch_descriptor=batch_descriptor + ) entry = self.concrete_cudagraph_entries[batch_descriptor] @@ -131,8 +138,11 @@ class CUDAGraphWrapper: # capturing is fast, we don't need to log it for every # shape. E.g. we only log it for the first subgraph in # piecewise mode. - logger.debug("Capturing a cudagraph on (%s,%s)", - self.runtime_mode.name, entry.batch_descriptor) + logger.debug( + "Capturing a cudagraph on (%s,%s)", + self.runtime_mode.name, + entry.batch_descriptor, + ) # validate that cudagraph capturing is legal at this point. validate_cudagraph_capturing_enabled() @@ -151,9 +161,12 @@ class CUDAGraphWrapper: # therefore, we only run gc for the first graph, # and disable gc for the rest of the graphs. 
stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) + stack.enter_context(patch("torch.cuda.empty_cache", lambda: None)) + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) # mind-exploding: carefully manage the reference and memory. with torch.cuda.graph(cudagraph, pool=self.graph_pool): # `output` is managed by pytorch's cudagraph pool @@ -187,7 +200,8 @@ class CUDAGraphWrapper: assert new_input_addresses == entry.input_addresses, ( f"Input addresses for cudagraphs are different " f"during replay. Expected {entry.input_addresses}, " - f"got {new_input_addresses}") + f"got {new_input_addresses}" + ) entry.cudagraph.replay() return entry.output diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 41d9fcb824b01..20bf63c804012 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,20 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import hashlib import inspect +import os +import sys from typing import Callable, Optional, TypeVar, Union, overload from unittest.mock import patch import torch import torch.nn as nn +from packaging import version from torch._dynamo.symbolic_convert import InliningInstructionTranslator +import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.utils import supports_dynamo +from vllm.utils import resolve_obj_by_qualname, supports_dynamo from .monitor import start_monitoring_torch_compile @@ -32,11 +38,11 @@ def ignore_torch_compile(cls: _T) -> _T: a support_torch_compile decorator, but we don't want to compile the class `cls` that inherits the parent class. This only ignores compiling the forward of the class the - decorator is applied to. + decorator is applied to. If the parent has ignore_torch_compile but the child has support_torch_compile, the child will still be compiled. - + If the class has one or more submodules that have support_torch_compile decorator applied, compile will not be ignored for those submodules. @@ -56,21 +62,18 @@ def _should_ignore_torch_compile(cls) -> bool: def support_torch_compile( *, enable_if: Optional[Callable[[VllmConfig], bool]] = None, -) -> Callable[[_T], _T]: - ... +) -> Callable[[_T], _T]: ... @overload def support_torch_compile( *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], -) -> Callable[[_T], _T]: - ... +) -> Callable[[_T], _T]: ... @overload -def support_torch_compile(cls: _T) -> _T: - ... +def support_torch_compile(cls: _T) -> _T: ... def support_torch_compile( @@ -87,8 +90,7 @@ def support_torch_compile( ```python @support_torch_compile class MyModel(nn.Module): - def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): - ... + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... ``` Usage 2: use as a decorator with arguments: @@ -96,8 +98,7 @@ def support_torch_compile( ```python @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) class MyModel(nn.Module): - def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): - ... 
+ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... ``` `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic @@ -137,7 +138,7 @@ def support_torch_compile( def cls_decorator_helper(cls: _T) -> _T: # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile`` - if not hasattr(cls, 'forward'): + if not hasattr(cls, "forward"): raise TypeError("decorated class should have a forward method.") sig = inspect.signature(cls.forward) inferred_dynamic_arg_dims = dynamic_arg_dims @@ -145,26 +146,31 @@ def support_torch_compile( inferred_dynamic_arg_dims = {} for k, v in sig.parameters.items(): if v.annotation in [ - torch.Tensor, Optional[torch.Tensor], - IntermediateTensors, Optional[IntermediateTensors] + torch.Tensor, + Optional[torch.Tensor], + IntermediateTensors, + Optional[IntermediateTensors], ]: inferred_dynamic_arg_dims[k] = 0 - logger.debug(("Inferred dynamic dimensions for " - "forward method of %s: %s"), cls, - list(inferred_dynamic_arg_dims.keys())) + logger.debug( + ("Inferred dynamic dimensions for forward method of %s: %s"), + cls, + list(inferred_dynamic_arg_dims.keys()), + ) if len(inferred_dynamic_arg_dims) == 0: raise ValueError( "No dynamic dimensions found in the forward method of " - f"{cls}. Please provide dynamic_arg_dims explicitly.") + f"{cls}. Please provide dynamic_arg_dims explicitly." + ) for k in inferred_dynamic_arg_dims: if k not in sig.parameters: raise ValueError( - f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims, - enable_if) + f"Argument {k} not found in the forward method of {cls}" + ) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -174,6 +180,33 @@ def support_torch_compile( return cls_decorator_helper +def _model_hash_key(fn) -> str: + import vllm + + sha256_hash = hashlib.sha256() + sha256_hash.update(vllm.__version__.encode()) + sha256_hash.update(fn.__qualname__.encode()) + sha256_hash.update(str(fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + +def _verify_source_unchanged(source_info, vllm_config) -> None: + from .caching import _compute_code_hash, _compute_code_hash_with_content + + file_contents = {} + for source in source_info.inlined_sources: + module = sys.modules[source.module] + file = inspect.getfile(module) + vllm_config.compilation_config.traced_files.add(file) + file_contents[file] = source.content + expected_checksum = _compute_code_hash_with_content(file_contents) + actual_checksum = _compute_code_hash(set(file_contents.keys())) + if expected_checksum != actual_checksum: + raise RuntimeError( + "Source code has changed since the last compilation. Recompiling the model." 
+ ) + + def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], @@ -189,29 +222,32 @@ def _support_torch_compile( # take care of method resolution order # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher - cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,) old_init = cls.__init__ setattr(cls, IGNORE_COMPILE_KEY, False) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = \ - vllm_config.compilation_config.level in [ - CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) or not enable_compile + self.do_not_compile = ( + vllm_config.compilation_config.level + in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS] + or not supports_dynamo() + or _should_ignore_torch_compile(self.__class__) + or not enable_compile + ) if self.do_not_compile: return compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_level=vllm_config.compilation_config.level) + self, compilation_level=vllm_config.compilation_config.level + ) cls.__init__ = __init__ @@ -222,6 +258,64 @@ def _support_torch_compile( if self.do_not_compile or torch.compiler.is_compiling(): return self.forward(*args, **kwargs) + if getattr(self, "aot_compiled_fn", None) is not None: + return self.aot_compiled_fn(self, *args, **kwargs) + + cache_dir = None + aot_compilation_path = None + if envs.VLLM_USE_AOT_COMPILE: + """ + When using torch.compile in AOT mode, we store the cache artifacts + under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash} + contains all of the factors except for the source files being + traced through, because we don't actually know which source files + to check at this point (before dynamo runs). + On loading we will actually look at the source files being traced + through. If any source file have changed (compared with the + serialized backend artifacts), then we need to generate a new AOT + compile artifact from scratch. 
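For readers following the cache layout described in the docstring above, here is a minimal, self-contained sketch of the same idea: derive the cache directory from a hash of the invalidation factors, and rebuild whenever the traced source files no longer match the contents recorded at save time. The helper names and the per-file comparison are illustrative stand-ins, not the actual `caching` module implementation.

```python
import hashlib
import os


def aot_cache_dir(cache_root: str, factors: list[str], rank: int, dp_rank: int) -> str:
    """Build <cache_root>/torch_aot_compile/<hash>/rank_<i>_<j> from the factors."""
    hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
    return os.path.join(
        cache_root, "torch_aot_compile", hash_key, f"rank_{rank}_{dp_rank}"
    )


def sources_unchanged(recorded: dict[str, str]) -> bool:
    """Compare file contents recorded at serialization time with the files on disk.

    Any mismatch (or a missing file) means the AOT artifact must be rebuilt.
    """
    for path, old_content in recorded.items():
        try:
            with open(path) as f:
                if f.read() != old_content:
                    return False
        except OSError:
            return False
    return True
```

Under this scheme, changing any factor (vLLM version, compilation config, the model's forward) lands in a fresh directory, while edits to already-traced source files are caught by the content check when the artifact is loaded.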
+ """ + from .caching import compilation_config_hash_factors + + factors: list[str] = compilation_config_hash_factors(self.vllm_config) + + factors.append(_model_hash_key(self.forward)) + hash_key = hashlib.sha256(str(factors).encode()).hexdigest() + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_aot_compile", + hash_key, + ) + + rank = self.vllm_config.parallel_config.rank + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") + aot_compilation_path = os.path.join(cache_dir, "model") + try: + with ( + set_current_vllm_config(self.vllm_config), + open(aot_compilation_path, "rb") as f, + ): + start_monitoring_torch_compile(self.vllm_config) + loaded_fn = torch.compiler.load_compiled_function(f) + _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) + self.aot_compiled_fn = loaded_fn + except Exception as e: + if os.path.exists(aot_compilation_path): + logger.warning( + "Cannot load aot compilation from path %s, error: %s", + aot_compilation_path, + str(e), + ) + if envs.VLLM_FORCE_AOT_LOAD: + raise e + if getattr(self, "aot_compiled_fn", None) is not None: + logger.info( + "Directly load AOT compilation from path %s", aot_compilation_path + ) + return self.aot_compiled_fn(self, *args, **kwargs) + # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) @@ -233,26 +327,23 @@ def _support_torch_compile( dims = [dims] if isinstance(dims, int) else dims if isinstance(arg, torch.Tensor): # In case dims is specified with negative indexing - dims = [ - arg.ndim + dim if dim < 0 else dim for dim in dims - ] + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] torch._dynamo.mark_dynamic(arg, dims) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): # In case dims is specified with negative indexing dims = [ - tensor.ndim + dim if dim < 0 else dim - for dim in dims + tensor.ndim + dim if dim < 0 else dim for dim in dims ] torch._dynamo.mark_dynamic(tensor, dims) else: raise ValueError( "Unsupported dynamic dimensions" - f" {dims} for argument {k} with type {type(arg)}.") + f" {dims} for argument {k} with type {type(arg)}." + ) # here, it is the starting point of the `torch.compile` process start_monitoring_torch_compile(self.vllm_config) - logger.debug("Start compiling function %s", - self.original_code_object) + logger.debug("Start compiling function %s", self.original_code_object) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, @@ -261,8 +352,7 @@ def _support_torch_compile( # it seems Dynamo reuse the compilation across instances, # while we need to make sure the compiled code is not reused. # we need to control all the compilation of the model. - torch._dynamo.eval_frame.remove_from_cache( - self.original_code_object) + torch._dynamo.eval_frame.remove_from_cache(self.original_code_object) # collect all relevant files traced by Dynamo, # so that the compilation cache can trigger re-compilation @@ -270,19 +360,19 @@ def _support_torch_compile( # 1. the file containing the top-level forward function self.vllm_config.compilation_config.traced_files.add( - self.original_code_object.co_filename) + self.original_code_object.co_filename + ) # 2. 
every time Dynamo sees a function call, it will inline - # the function by calling InliningInstructionTranslator.inline_call + # the function by calling InliningInstructionTranslator.inline_call_ # we hijack this function to know all the functions called # during Dynamo tracing, and their corresponding files - inline_call = InliningInstructionTranslator.inline_call + inline_call = InliningInstructionTranslator.inline_call_ - def patched_inline_call(parent, func, args, kwargs): - code = func.get_code() - self.vllm_config.compilation_config.traced_files.add( - code.co_filename) - return inline_call(parent, func, args, kwargs) + def patched_inline_call(self_): + code = self_.f_code + self.vllm_config.compilation_config.traced_files.add(code.co_filename) + return inline_call(self_) # Disable the C++ compilation of symbolic shape guards. C++-fication # of symbolic shape guards can improve guard overhead. But, since @@ -291,18 +381,29 @@ def _support_torch_compile( dynamo_config_patches = {} try: _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards - dynamo_config_patches[ - "enable_cpp_symbolic_shape_guards"] = False + dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False except AttributeError: # Note: this config is not available in torch 2.6, we can skip # if the config doesn't exist - logger.debug( - "enable_cpp_symbolic_shape_guards config not available") + logger.debug("enable_cpp_symbolic_shape_guards config not available") - with patch.object(InliningInstructionTranslator, 'inline_call', - patched_inline_call), torch._dynamo.config.patch( - **dynamo_config_patches): - output = self.compiled_callable(*args, **kwargs) + with ( + patch.object( + InliningInstructionTranslator, "inline_call_", patched_inline_call + ), + torch._dynamo.config.patch(**dynamo_config_patches), + maybe_use_cudagraph_partition_wrapper(self.vllm_config), + _torch27_patch_tensor_subclasses(), + ): + if envs.VLLM_USE_AOT_COMPILE: + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + assert aot_compilation_path is not None + assert cache_dir is not None + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + else: + output = self.compiled_callable(*args, **kwargs) return output # usually, capturing the model once is enough, and then we can @@ -314,3 +415,97 @@ def _support_torch_compile( cls.__call__ = __call__ return cls + + +@contextlib.contextmanager +def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): + """ + Context manager to set/unset customized cudagraph partition wrappers. + + If we're using Inductor-based graph partitioning, we currently have the + whole `fx.Graph` before Inductor lowering and and the piecewise + splitting happens after all graph passes and fusions. Here, we add + a custom hook for Inductor to wrap each partition with our static + graph wrapper class to maintain more control over static graph + capture and replay. 
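To make the shape of that hook concrete, below is a toy partition-wrapper factory with the same call signature as the one registered further down: it receives the partition callable plus metadata (assumed here to expose `partition_index` and `num_partitions`, the two fields the real factory uses) and returns a replacement callable. It only counts invocations and stands in for the static graph wrapper class; it is not part of vLLM.

```python
from typing import Any, Callable


def make_counting_partition_wrapper(runnable: Callable, metadata: Any) -> Callable:
    """Toy stand-in for the static-graph wrapper installed by the partition hook."""
    # The real factory uses these two fields to decide which partition logs,
    # which ones skip gc, and which one weak-refs its output.
    first = metadata.partition_index == 0
    last = metadata.partition_index == metadata.num_partitions - 1
    calls = 0

    def wrapped(*args, **kwargs):
        nonlocal calls
        calls += 1
        if calls == 1 and first:
            print(f"first call of partition 0 (is also last partition: {last})")
        # A real wrapper would capture a CUDA graph on the first call and replay
        # it afterwards; here we simply delegate to the original partition.
        return runnable(*args, **kwargs)

    return wrapped
```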
+ """ + from vllm.config import CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and compilation_config.use_inductor_graph_partition + ): + from torch._inductor.utils import CUDAGraphWrapperMetadata + + from vllm.compilation.cuda_graph import CUDAGraphOptions + from vllm.platforms import current_platform + + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls() + ) + + def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata): + partition_id = metadata.partition_index + num_partitions = metadata.num_partitions + return static_graph_wrapper_class( + runnable=f, + vllm_config=vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=partition_id == 0, + gc_disable=partition_id != 0, + weak_ref_output=partition_id == num_partitions - 1, + ), + ) + + torch._inductor.utils.set_customized_partition_wrappers( + customized_cudagraph_wrapper + ) + + yield + + if ( + compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and compilation_config.use_inductor_graph_partition + ): + torch._inductor.utils.set_customized_partition_wrappers(None) + + +@contextlib.contextmanager +def _torch27_patch_tensor_subclasses(): + """ + Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when + using torch 2.7.0. This enables using weight_loader_v2 and the use of + `BasevLLMParameters` without having to replace them with regular tensors + before `torch.compile`-time. + """ + from vllm.model_executor.parameter import ( + BasevLLMParameter, + ModelWeightParameter, + RowvLLMParameter, + _ColumnvLLMParameter, + ) + + def return_false(*args, **kwargs): + return False + + if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"): + yield + return + + with ( + torch._dynamo.config.patch( + "traceable_tensor_subclasses", + [ + BasevLLMParameter, + ModelWeightParameter, + _ColumnvLLMParameter, + RowvLLMParameter, + ], + ), + patch( + "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false + ), + ): + yield diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 286221d32c1ee..0dffb343f9a28 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -9,6 +9,7 @@ import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger +from vllm.platforms import current_platform from .fx_utils import is_func from .vllm_inductor_pass import VllmInductorPass @@ -25,9 +26,15 @@ class FixFunctionalizationPass(VllmInductorPass): To add new nodes to defunctionalize, add to the if-elif chain in __call__. """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_fix_functionalization") + # XPU does not support auto-functionalization yet. + # Will enable this when switch to vllm-xpu-kernels. + if current_platform.is_xpu(): + logger.debug( + "XPU platform does not support fix functionalizationpass currently." 
+ ) + return self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 @@ -39,75 +46,111 @@ class FixFunctionalizationPass(VllmInductorPass): at_target = node.args[0] if at_target == torch.ops._C.rotary_embedding.default: - query = kwargs['query'] - mm_node = query.args[0].args[0] + query = kwargs["query"] + key = kwargs["key"] + getitem_nodes = self.getitem_users(node) - # rotary_embedding is a special case: the two mutating inputs - # are query and key, which are slices of mm_node. - # While functionalized, results at[1] and at[2] are scattered - # back into mm_node. After de-functionalization, we can just - # use mm_node directly. - for idx, user in self.getitem_users(node).items(): - for user_of_getitem in user.users: - if is_func(user_of_getitem, - torch.ops.aten.slice_scatter.default): - user_of_getitem.replace_all_uses_with(mm_node) - self._remove(user_of_getitem) - self._remove(user) + if ( + is_func(query, operator.getitem) + and is_func(key, operator.getitem) + and query.args[0] == key.args[0] + and is_func(query.args[0], torch.ops.aten.split_with_sizes.default) + and all( + is_func(user, torch.ops.aten.slice_scatter.default) + for getitem_node in getitem_nodes.values() + for user in getitem_node.users + ) + ): + # Pattern where query and key are slices of an mm_node. + # While functionalized, results at [1] and [2] are scattered + # back into mm_node. So after de-functionalization, we can + # just use mm_node directly. - self.insert_defunctionalized(graph, node) - self._remove(node) + mm_node = query.args[0].args[0] + for user in getitem_nodes.values(): + for user_of_getitem in user.users: + if is_func( + user_of_getitem, torch.ops.aten.slice_scatter.default + ): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + else: + # Directly replace the auto_functionalize(rotary_embedding) + # with the inplace rotary_embedding. In theory, we shouldn't + # do this blindly, but in practice in vLLM it's ok. The best + # solution is to use auto_functionalization_v2 and then use + # inductor's builtin defunctionalization (reinplacing) pass. + mutated_args = {1: "query", 2: "key"} + self.defunctionalize(graph, node, mutated_args) # rms_norm replacements avoid the most copies for LLaMa. elif at_target == torch.ops._C.fused_add_rms_norm.default: - mutated_args = {1: 'input', 2: 'residual'} + mutated_args = {1: "input", 2: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'residual'} + mutated_args = {1: "result", 2: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + mutated_args = {1: "result", 2: "scale", 3: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target in [ - torch.ops._C.rms_norm.default, - torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default, ]: - mutated_args = {1: 'result'} + mutated_args = {1: "result"} self.defunctionalize(graph, node, mutated_args) # For some reason we need to specify the args for both # silu_and_mul and silu_and_mul_quant. The kwargs # pathway gets the wrong answer. 
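As a rough illustration of the rewrite this pass performs, written independently of the pass machinery and using the kwarg names that appear in these patterns: an `auto_functionalized` node plus its getitem readers collapses into a direct call to the in-place op, with the readers rebound to the mutated inputs. The helper below sketches only the rebinding step and is not the pass itself.

```python
import operator

import torch.fx

# Before (functionalized):
#     at = auto_functionalized(torch.ops._C.rms_norm.default,
#                              result=result, input=x, weight=w, epsilon=eps)
#     out = at[1]            # getitem user reading the mutated `result`
#
# After (defunctionalized):
#     torch.ops._C.rms_norm.default(result=result, input=x, weight=w, epsilon=eps)
#     out = result           # user rebound to the mutated argument


def rebind_getitem_users(
    node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node]
) -> None:
    """Point every `node[idx]` reader at the corresponding mutated input node."""
    for user in list(node.users):
        if user.op == "call_function" and user.target == operator.getitem:
            idx = user.args[1]
            if idx in mutated_args:
                user.replace_all_uses_with(mutated_args[idx])
```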
elif at_target == torch.ops._C.silu_and_mul.default: - mutated_args = {1: 'result'} - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'input')) + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input") + ) elif at_target == torch.ops._C.silu_and_mul_quant.default: - mutated_args = {1: 'result'} - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'input', 'scale')) + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input", "scale") + ) + elif ( + hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant") + and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default + ): + mutated_args = {1: "result", 2: "result_block_scale"} + self.defunctionalize( + graph, + node, + mutated_args, + args=( + "result", + "result_block_scale", + "input", + "input_global_scale", + ), + ) else: continue # skip the count count += 1 - self.dump_graph(graph, "before_fix_functionalization_cleanup") + self.dump_graph(graph, "before_cleanup") # Remove the nodes all at once count_removed = len(self.nodes_to_remove) for node in self.nodes_to_remove: graph.erase_node(node) - logger.debug("De-functionalized %s nodes, removed %s nodes", count, - count_removed) - self.dump_graph(graph, "after_fix_functionalization") - self.end_and_log() + logger.debug( + "De-functionalized %s nodes, removed %s nodes", count, count_removed + ) + self.nodes_to_remove.clear() - def _remove(self, node_or_nodes: Union[torch.fx.Node, - Iterable[torch.fx.Node]]): + def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): """ Stage a node (or nodes) for removal at the end of the pass. """ @@ -116,12 +159,13 @@ class FixFunctionalizationPass(VllmInductorPass): else: self.nodes_to_remove.extend(node_or_nodes) - def defunctionalize(self, - graph: torch.fx.Graph, - node: torch.fx.Node, - mutated_args: dict[int, Union[torch.fx.Node, str]], - args: Optional[tuple[Union[torch.fx.Node, str], - ...]] = None): + def defunctionalize( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): """ De-functionalize a node by replacing it with a call to the original. It also replaces the getitem users with the mutated arguments. @@ -131,10 +175,9 @@ class FixFunctionalizationPass(VllmInductorPass): self.insert_defunctionalized(graph, node, args=args) self._remove(node) - def replace_users_with_mutated_args(self, node: torch.fx.Node, - mutated_args: dict[int, - Union[torch.fx.Node, - str]]): + def replace_users_with_mutated_args( + self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]] + ): """ Replace all getitem users of the auto-functionalized node with the mutated arguments. @@ -160,11 +203,12 @@ class FixFunctionalizationPass(VllmInductorPass): users[idx] = user return users - def insert_defunctionalized(self, - graph: torch.fx.Graph, - node: torch.fx.Node, - args: Optional[tuple[Union[torch.fx.Node, str], - ...]] = None): + def insert_defunctionalized( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): """ Insert a new defunctionalized node into the graph before node. If one of the kwargs is 'out', provide args directly, @@ -176,8 +220,9 @@ class FixFunctionalizationPass(VllmInductorPass): :param args: If we cannot use kwargs, specify args directly. 
If an arg is a string, `node.kwargs[arg]` is used. """ # noqa: E501 - assert is_func(node, auto_functionalized), \ + assert is_func(node, auto_functionalized), ( f"node must be auto-functionalized, is {node} instead" + ) # Create a new call to the original function with graph.inserting_before(node): @@ -186,6 +231,7 @@ class FixFunctionalizationPass(VllmInductorPass): graph.call_function(function, kwargs=node.kwargs) else: # Args passed as strings refer to items in node.kwargs - args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg - for arg in args) + args = tuple( + node.kwargs[arg] if isinstance(arg, str) else arg for arg in args + ) graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 413948799de35..df54e94a03db4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, NamedTuple, Optional +from typing import Any, NamedTuple import torch import torch._inductor.pattern_matcher as pm @@ -12,13 +12,19 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, - kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) + GroupShape, + QuantKey, + ScaleDesc, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform -from .fx_utils import find_getitem_maybe -from .multi_output_match import MultiOutputMatch -from .vllm_inductor_pass import VllmInductorPass +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() @@ -41,14 +47,12 @@ RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 - kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default class FusedRMSQuantKey(NamedTuple): @@ -57,142 +61,93 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ + quant: QuantKey fused_add: bool def __str__(self): - return (f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)") + return ( + f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)" + ) FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey(kFp8StaticTensorSym, False): - torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8StaticTensorSym, True): - 
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, False + ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, True + ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, False + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, True + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } -class QuantMultiOutputMatch(MultiOutputMatch): - - def __init__(self, match: pm.Match, quant_op, fused_op): - super().__init__(match) - assert isinstance(quant_op, OpOverload) - assert isinstance(fused_op, OpOverload) - self.QUANT_OP = quant_op # in-place quant op - self.FUSED_OP = fused_op # in-place fused quant op - - def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, - int]], - **kwargs): - """ - This utility function inserts an auto-functionalized node for FUSED_OP. - It also correctly sets its meta value and rebinds the users of the - unfused nodes to use the fused node instead. - - :param fused_return_mapping: A dictionary, mapping from getitem indices - of the fused node result to a tuple of the old node and a getitem index. - :param kwargs: kwargs that get directly forwarded to the auto_fn node - - Example: - If we want to replace this graph: - _, x1, x2 = auto_fn(op1) - _, y1, y2 = auto_fn(op2) - - with - _, x1, y2, x2 = auto_fn(FUSED_OP) - - we would call: - insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} - - Note that the 0th element is None for auto-functionalized in-place ops. - Hence, others appear 1-indexed. - """ - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) - indices = fused_return_mapping.keys() - getitem_nodes = self.insert_getitems(fused_node, indices) - - # Prepare the meta value, use a list so it's mutable - meta_val = [None] * (max(indices) + 1) - - # Iterate through elements of the tuple produced by fused_node - for idx, getitem_node in zip(indices, getitem_nodes): - old_node, old_idx = fused_return_mapping[idx] - - # If the old value was never used, the old_getitem might not exist - old_getitem = find_getitem_maybe(old_node, old_idx) - if old_getitem is not None: - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. 
- old_getitem.replace_all_uses_with(getitem_node) - getitem_node.meta["val"] = old_getitem.meta["val"] - - # Extract the appropriate meta value - # It is present even if the getitem node does not exist - meta_val[idx] = old_node.meta["val"][old_idx] - - # Fix the meta value on the new fused node - fused_node.meta["val"] = tuple(meta_val) - - class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, \ - f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, \ - f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] class RMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + fused_key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale) + def pattern( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale + ) # result return at2[1] - def replacement(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon) + def replacement( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result return at[1] @@ -202,54 +157,60 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, - pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + key = 
FusedRMSQuantKey( + fused_add=True, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): - - def pattern(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) - at1 = auto_functionalized(self.QUANT_OP, - result=result, - input=at[1], - scale=scale) + def register(self, pm_pass: PatternMatcherPass): + def pattern( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + at1 = auto_functionalized( + self.QUANT_OP, result=result, input=at[1], scale=scale + ) # result, residual return at1[1], at[2] - def replacement(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon) + def replacement( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result, residual return at[1], at[2] @@ -259,7 +220,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -268,83 +229,63 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and residual. - # The auto_fn node returns a tuple of (None, result, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) 
# noqa - # result_node_new = at[1] - # residual_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - # 0 is always None - fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} - self.insert_fused_node(fused_return_mapping, - **kwargs, - epsilon=rms_node.kwargs["epsilon"]) + ) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): - - def pattern(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale, - scale_ub=None) + def register(self, pm_pass: PatternMatcherPass): + def pattern( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None + ) # result, scale return at2[1], at2[2] - def replacement(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None) + def replacement( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + ) # result, scale return at[1], at[2] @@ -354,7 +295,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -363,86 +304,63 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_OP) - quant_node = self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 1 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract the result and scale. - # The auto_fn node returns a tuple of (None, result, scale). 
- # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa - # result_node_new = at[1] - # scale_node_new = at[2] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - del kwargs["result_rms"] # not used in the fused op - - fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - residual=None, # not used but required - **kwargs) + ) class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) - def register(self, pm_pass: PatternMatcherPass, - record_match: Callable[[MultiOutputMatch], bool]): - - def pattern(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) - at1 = auto_functionalized(self.QUANT_OP, - result=result, - input=at[1], - scale=scale, - scale_ub=None) + def register(self, pm_pass: PatternMatcherPass): + def pattern( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + at1 = auto_functionalized( + self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None + ) # result, residual, scale return at1[1], at[2], at1[2] - def replacement(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual) + def replacement( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + ) # result, residual, scale return at[1], at[3], at[2] @@ -452,7 +370,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -461,136 +379,53 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match( - self.Match(m, self.QUANT_OP, self.FUSED_OP))) - - class Match(QuantMultiOutputMatch): - - def process(self): - # Find the nodes in the match that we need to rebind - rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = 
self.find_auto_fn(self.QUANT_OP) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 2 - - # First, insert a new auto_functionalized node for the fused op, - # as well as getitem nodes to extract result, scale, and residual. - # The auto_fn node returns a tuple (None, result, scale, residual). - # - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa - # result_node_new = at[1] - # scale_node_new = at[2] - # residual_node_new = at[3] - with self.inserting_after_match(): - # Missing epsilon, scalars cannot be inputs to the pattern - kwargs = self.match.kwargs.copy() - - fused_return_mapping = { - 1: (quant_node, 1), # result - 2: (quant_node, 2), # scale - 3: (rms_node, 2), # residual - } - self.insert_fused_node( - fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - scale_ub=None, # not used but required - **kwargs) + ) -class FusionPass(VllmInductorPass): +class RMSNormQuantFusionPass(VllmPatternMatcherPass): """ - This pass fuses a pre-defined set of custom ops into fused ops. - It uses the torch pattern matcher to find the patterns and replace them. - It also manually processes multi-output matches, as those are broken in - the torch pattern matcher. - - Because patterns can only be registered once, the pass is a singleton. - This will be addressed in a future version of PyTorch: - https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. """ - _instance: 'Optional[FusionPass]' = None - - @classmethod - def instance(cls, config: VllmConfig): - """ - Get the singleton instance of the FusionPass. - If the instance exists, the config is updated but - initialization is not repeated. - """ - if cls._instance is None: - cls._instance = FusionPass(config) - else: - cls._instance.pass_config = config.compilation_config.pass_config - return cls._instance - + @enable_fake_mode def __init__(self, config: VllmConfig): - assert self.__class__._instance is None, \ - "FusionPass singleton instance already exists" super().__init__(config) - self.matches: list[MultiOutputMatch] = [] self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="fusion_pass") + pass_name="rmsnorm_quant_fusion_pass" + ) for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Matches for patterns below have 2 or more outputs, - # so we need to process them manually (see process_matches) - - # Fuse rms_norm + static fp8 quant + # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + self.patterns + ) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns, self.record_match) + self.patterns + ) - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. 
- torch._inductor.pattern_matcher._seen_patterns.clear() - - def record_match(self, match: MultiOutputMatch) -> bool: - # Hijack the extra_check to record the match and - # save it for post-processing. - self.matches.append(match) - - # Return False to prevent automatic replacement. - return False - - def process_matches(self, graph: fx.Graph): - """ - Manually process multi-output matches and replace them with fused nodes. - See MultiOutputMatch for more details. - """ - for match in self.matches: - match.process() - - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in self.matches - for node in match.match.nodes) + self.dump_patterns(config, self.patterns) + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_fusion") + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", count) - self.dump_graph(graph, "after_pattern_match") - - # Manually process multi-output matches (and run DCE) - self.process_matches(graph) - logger.debug("Post-processed %s matches", len(self.matches)) - self.dump_graph(graph, "after_fusion") - self.matches.clear() - self.end_and_log() + def uuid(self) -> Any: + return self.hash_source( + self, + RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern, + ) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index f942afe6a28ee..ae36cef926539 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -7,19 +7,21 @@ import torch import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._subclasses.fake_tensor import (FakeTensorMode, - unset_fake_temporarily) from vllm.attention import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kNvfp4Quant, kStaticTensorScale) + QuantKey, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 -from .vllm_inductor_pass import VllmInductorPass +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -40,6 +42,7 @@ class AttentionQuantPattern(ABC): self, layer: Attention, quant_key: QuantKey, + dtype: torch.dtype, ): self.layer = layer self.layer_name = layer.layer_name @@ -47,18 +50,23 @@ class AttentionQuantPattern(ABC): self.head_size = layer.head_size self.quant_key = quant_key self.quant_dtype = quant_key.dtype + self.dtype = dtype - assert self.quant_key in QUANT_OPS, \ + assert self.quant_key in QUANT_OPS, ( f"unsupported quantization scheme {self.quant_key}" + ) self.QUANT_OP = QUANT_OPS[self.quant_key] + def empty(self, *args, **kwargs): + kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs} + return torch.empty(*args, **kwargs) + def empty_quant(self, *args, **kwargs): - kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return 
torch.empty(*args, **kwargs) @staticmethod def wrap_trace_fn(process_fx, trace_fn): - def wrapped(*args, **kwargs): return process_fx(trace_fn(*args, **kwargs)) @@ -67,6 +75,7 @@ class AttentionQuantPattern(ABC): @staticmethod def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) return gm @@ -92,71 +101,88 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): def __init__( self, layer: Attention, + dtype: torch.dtype, symmetric: bool = True, ): - quant_key = QuantKey(dtype=FP8_DTYPE, - scale=kStaticTensorScale, - symmetric=symmetric) - super().__init__(layer, quant_key) + quant_key = QuantKey( + dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric + ) + super().__init__(layer, quant_key, dtype) def _register(self, pm_pass: PatternMatcherPass): - - def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=None, - output_block_scale=None) + def pattern( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None, + ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size]) - at2 = auto_functionalized(self.QUANT_OP, - result=output_quant, - input=attn_out_view, - scale=scale) + at1[1], [q.shape[0], self.num_heads * self.head_size] + ) + at2 = auto_functionalized( + self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale + ) return at2[1] - def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): + def replacement( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + scale: torch.Tensor, + ): # attn output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size], 0.0, dtype=self.quant_dtype, - device=q.device) - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=scale, - output_block_scale=None) + device=q.device, + ) + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=scale, + output_block_scale=None, + ) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) - # Need custom fake mode, otherwise tracing happens with real tensors. - # That would not work for the unified_attention custom op. 
- with unset_fake_temporarily(), FakeTensorMode(): - inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # attn_output - self.empty_quant(5, self.num_heads * - self.head_size), # quant_output - empty_fp32(1, 1) # scale - ] + inputs = [ + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v + self.empty( + 5, self.num_heads, self.head_size, dtype=self.dtype + ), # attn_output + self.empty_quant(5, self.num_heads * self.head_size), # quant_output + empty_fp32(1, 1), # scale + ] - pm.register_replacement( - pattern, replacement, inputs, - AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), - pm_pass) + pm.register_replacement( + pattern, + replacement, + inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + ), + pm_pass, + ) class AttentionNvfp4QuantPattern(AttentionQuantPattern): @@ -169,80 +195,97 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. """ - def __init__(self, layer: Attention): - super().__init__(layer, kNvfp4Quant) + def __init__(self, layer: Attention, dtype: torch.dtype): + super().__init__(layer, kNvfp4Quant, dtype) def _register(self, pm_pass: PatternMatcherPass): - - def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - output_scale: torch.Tensor, input_scale: torch.Tensor): - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=None, - output_block_scale=None) + def pattern( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + output_scale: torch.Tensor, + input_scale: torch.Tensor, + ): + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None, + ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size]) - at2 = auto_functionalized(self.QUANT_OP, - output=output_quant, - input=attn_out_view, - output_scale=output_scale, - input_scale=input_scale) + at1[1], [q.shape[0], self.num_heads * self.head_size] + ) + at2 = auto_functionalized( + self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale, + ) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view - def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - output_scale: torch.Tensor, input_scale: torch.Tensor): + def replacement( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + output_scale: torch.Tensor, + input_scale: torch.Tensor, + ): # attention output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size // 2], 0.0, dtype=self.quant_dtype, - device=q.device) + device=q.device, + ) # attention output block scale - output_scale_view = torch.ops.aten.view.dtype( - output_scale, FP8_DTYPE) - at2 = 
auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=input_scale, - output_block_scale=output_scale_view) - output = RESHAPE_OP(at2[1], - [-1, self.num_heads * self.head_size // 2]) + output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) + at2 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view, + ) + output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2]) return output, at2[2] - # Need custom fake mode, otherwise tracing happens with real tensors. - # That would not work for the unified_attention custom op. - with unset_fake_temporarily(), FakeTensorMode(): - inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # output_attn - self.empty_quant(5, self.num_heads * self.head_size // - 2), # output_quant - empty_i32(128, - round_up(self.num_heads * self.head_size // 16, - 4)), # output_scale - empty_fp32(1, 1), # input_scale - ] + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant + empty_i32( + 128, round_up(self.num_heads * self.head_size // 16, 4) + ), # output_scale + empty_fp32(1, 1), # input_scale + ] - pm.register_replacement( - pattern, replacement, inputs, - AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), - pm_pass) + pm.register_replacement( + pattern, + replacement, + inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + ), + pm_pass, + ) -class AttnFusionPass(VllmInductorPass): +class AttnFusionPass(VllmPatternMatcherPass): """ This pass fuses post-attention quantization onto attention if supported. @@ -255,6 +298,7 @@ class AttnFusionPass(VllmInductorPass): support are attention kernels, which need to support fusing output quant. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -262,34 +306,35 @@ class AttnFusionPass(VllmInductorPass): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8 = AttentionFp8StaticQuantPattern( + layer, config.model_config.dtype + ) pattern_fp8.register_if_supported(self.patterns) - pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) - pattern_nvfp4.register_if_supported(self.patterns) + if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + pattern_nvfp4 = AttentionNvfp4QuantPattern( + layer, config.model_config.dtype + ) + pattern_nvfp4.register_if_supported(self.patterns) if len(attn_layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " "were found in CompilationConfig.static_forward_context " - "so no fusion patterns were registered.") + "so no fusion patterns were registered." 
+ ) + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.graph.Graph) -> None: - self.begin() - self.dump_graph(graph, "before_attn_fusion") - - count = self.patterns.apply(graph) - - # TODO: Move this to pass_manager.py after the fx graph broken issue - # has been resolved. - # see https://github.com/vllm-project/vllm/issues/23091 - graph.eliminate_dead_code() - - logger.debug("Fused quantization onto %s attention nodes", count) - self.dump_graph(graph, "after_attn_fusion") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused quant onto %s attention nodes", self.matched_count) def uuid(self): - return VllmInductorPass.hash_source(self, AttentionQuantPattern, - AttentionFp8StaticQuantPattern, - AttentionNvfp4QuantPattern) + return VllmInductorPass.hash_source( + self, + AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern, + ) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 2db8b5441bd6f..114b53c74c48f 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -19,8 +19,9 @@ def is_auto_func(node: fx.Node, op: OpOverload) -> bool: # Returns the first specified node with the given op (if it exists) -def find_specified_fn_maybe(nodes: Iterable[fx.Node], - op: OpOverload) -> Optional[fx.Node]: +def find_specified_fn_maybe( + nodes: Iterable[fx.Node], op: OpOverload +) -> Optional[fx.Node]: for node in nodes: if node.target == op: return node @@ -35,8 +36,7 @@ def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: # Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[fx.Node], - op: OpOverload) -> Optional[fx.Node]: +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: for node in nodes: if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 2a149c65b3877..9085448d23978 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import hashlib import inspect import json @@ -10,6 +11,7 @@ from typing import Any, Callable, Optional, Union import torch from torch import fx +from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from vllm.utils import is_torch_equal_or_newer @@ -17,14 +19,14 @@ if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass else: # CustomGraphPass is not present in 2.5 or lower, import our version - from .torch25_custom_graph_pass import ( # noqa: E501 - Torch25CustomGraphPass as CustomGraphPass) + from .torch25_custom_graph_pass import ( + Torch25CustomGraphPass as CustomGraphPass, + ) _pass_context = None class PassContext: - def __init__(self, runtime_shape: Optional[int]): self.runtime_shape = runtime_shape @@ -103,9 +105,9 @@ class CallableInductorPass(InductorPass): implementation of the UUID. 
""" - def __init__(self, - callable: Callable[[fx.Graph], None], - uuid: Optional[Any] = None): + def __init__( + self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None + ): self.callable = callable self._uuid = self.hash_source(callable) if uuid is None else uuid @@ -114,3 +116,19 @@ class CallableInductorPass(InductorPass): def uuid(self) -> Any: return self._uuid + + +def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Applies a FakeTensorMode context. This is useful when you don't want to + create or run things with real tensors. + """ + + @functools.wraps(fn) + def fn_new(*args, **kwargs) -> Any: + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn(*args, **kwargs) + + return result + + return fn_new diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 9047bf3cbf8e8..d3c437795fabb 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import time from vllm.config import CompilationConfig, CompilationLevel, VllmConfig @@ -18,21 +17,22 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): torch_compile_start_time = time.time() compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.PIECEWISE and \ - compilation_config.debug_dump_path: + path = vllm_config.compile_debug_dump_path() + if compilation_config.level == CompilationLevel.PIECEWISE and path: import depyf - path = os.path.join(compilation_config.debug_dump_path, - f"rank_{vllm_config.parallel_config.rank}") + + path.mkdir(parents=True, exist_ok=True) global context_manager - context_manager = depyf.prepare_debug(path) + context_manager = depyf.prepare_debug(path.as_posix()) context_manager.__enter__() def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info("torch.compile takes %.2f s in total", - compilation_config.compilation_time) + logger.info( + "torch.compile takes %.2f s in total", compilation_config.compilation_time + ) global context_manager if context_manager is not None: context_manager.__exit__(None, None, None) @@ -43,13 +43,15 @@ cudagraph_capturing_enabled: bool = True def validate_cudagraph_capturing_enabled(): - # used to monitor whether an cudagraph capturing is legal at runtime. + # used to monitor whether a cudagraph capturing is legal at runtime. # should be called before any cudagraph capturing. # if an illegal cudagraph capturing happens, raise an error. global cudagraph_capturing_enabled if not cudagraph_capturing_enabled: - raise RuntimeError("CUDA graph capturing detected at an inappropriate " - "time. This operation is currently disabled.") + raise RuntimeError( + "CUDA graph capturing detected at an inappropriate " + "time. This operation is currently disabled." 
+ ) def set_cudagraph_capturing_enabled(enabled: bool): diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py deleted file mode 100644 index 6d1893777cec6..0000000000000 --- a/vllm/compilation/multi_output_match.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import abc -import operator -from abc import abstractmethod -from collections.abc import Iterable - -from torch import fx -from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor import pattern_matcher as pm -from torch._ops import OpOverload -from torch.fx import Node - -from vllm.compilation.fx_utils import find_auto_fn - - -class MultiOutputMatch(abc.ABC): - """ - This class provides utilities to process multi-output matches and - manually insert replacements. - - This is necessary because the automatic replacement for multi-output - matches is broken: https://github.com/pytorch/pytorch/issues/137280 - """ - - def __init__(self, match: pm.Match): - self.match = match - - @abstractmethod - def process(self): - """ - Process a multi-output match and manually insert the replacement. - - This method should: - 1. Insert the replacement nodes after the last node in the match. - 2. Rebind the users of nodes in the match to use the new nodes. - 3. Set meta["val"] for de-functionalization. - - The result of an auto-functionalized node is a tuple of tensors. - The first element is the return value of the function, usually None. - The remaining elements are the mutated args of the function. - - All auto-functionalized nodes must contain a proper meta["val"], - as it is used by de-functionalization. meta["val"] has to contain the - value of the node (tuple of tensors) that would be returned by the - functionalized node during tracing. - - Existing nodes in the graph all have this property set, but we have - to set it manually for new nodes we insert. - - Example: - # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None - at = auto_functionalized(torch.ops._C.foo.default, a, b, c) - # at.meta["val"] = (None, a, c) - """ - raise NotImplementedError - - @property - def nodes(self) -> list[fx.Node]: - return self.match.nodes - - @property - def graph(self) -> fx.Graph: - return self.match.graph - - def find_auto_fn(self, op) -> fx.Node: - """ - Find the first auto_functionalized node with the given op in the match. - """ - return find_auto_fn(self.nodes, op) - - def inserting_after_match(self): - """ - Insert nodes after the last node in the match. - This is done to avoid use-before-definition errors after inserting - replacement nodes. - """ - - # match.nodes is not guaranteed to be sorted. - # Find the last node in the match. - for last_node_in_match in reversed(self.graph.nodes): - if last_node_in_match in self.match.nodes: - break - else: - raise ValueError("No nodes in graph") - - return self.graph.inserting_after(last_node_in_match) - - def insert_getitems(self, tuple_node: fx.Node, - indices: Iterable[int]) -> tuple[fx.Node, ...]: - """ - Insert operator.getitem nodes to extract elements from a tuple node. - - :param tuple_node: The tuple node to extract elements from. - :param indices: The indices of the elements to extract. - :return: Tuple of the new getitem nodes, corresponding to the indices. 
- """ - with self.graph.inserting_after(tuple_node): - return tuple( - self.graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices) - - def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: - """ - Insert an auto_functionalized node with the given op and kwargs. - """ - return self.graph.call_function(auto_functionalized, (op, ), - kwargs=kwargs) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 4888d4d1298e3..45668c7af3151 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -62,14 +62,10 @@ class NoOpEliminationPass(VllmInductorPass): scaled_mm: "f16[s0, 4096]" = ... at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) out: "f16[s0, 4096]" = at[1] - - TODO(luka): This is currently tested in test_fusion, - but separate tests could be good. """ + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - self.dump_graph(graph, "before_noop_elimination") count = 0 # Remove no-op reshapes/views: for node in graph.nodes: @@ -85,81 +81,57 @@ class NoOpEliminationPass(VllmInductorPass): graph.erase_node(input) count += 1 - # Case 2: remove this reshape if it produces the original shape - input, shape = node.args[:2] + # remove reshape/slice if it produces the original shape + if is_func(node, torch.ops.aten.reshape.default) or is_func( + node, torch.ops.aten.slice.Tensor + ): + input = node.args[0] input_shape = input.meta["val"].shape - if len(shape) != len(input_shape): - # Reshape changing rank, skip - continue - - if shape.count(-1) > 1: - # Invalid reshape args, skip - continue - - if self.all_dims_equivalent(shape, input_shape): + output_shape = node.meta["val"].shape + if self.all_dims_equivalent(input_shape, output_shape): node.replace_all_uses_with(input) graph.erase_node(node) count += 1 - - elif is_func(node, torch.ops.aten.slice.Tensor): - input, dim_index, start, end = node.args[:4] - input_shape = input.meta["val"].shape - i_dim = input_shape[dim_index] - - if start == 0 and self.dims_equivalent(end, i_dim): - node.replace_all_uses_with(input) - graph.erase_node(node) - count += 1 - elif is_func(node, torch.ops.aten.slice_scatter.default): base, view, dim_index, start, end = node.args[:5] base_shape = base.meta["val"].shape view_shape = view.meta["val"].shape - view_dim = view_shape[dim_index] - - # Check that view fully covers base and the full view is used - # (if the view fully covered the base after slicing but was not - # fully used, we could replace slice_scatter with a simple slice - # but that's a niche case). - if (base_shape == view_shape and start == 0 - and self.dims_equivalent(end, view_dim)): + if self.all_dims_equivalent(base_shape, view_shape): node.replace_all_uses_with(view) graph.erase_node(node) count += 1 logger.debug("Removed %s no-op reshapes and slices", count) - self.dump_graph(graph, "after_noop_elimination") - self.end_and_log() - def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], - i_dims: Iterable[Union[int, SymInt]]): - return all( - self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) - - def dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: + # ---------------------- Shape comparison helpers ---------------------- + def dims_equivalent( + self, dim: Union[int, SymInt], i_dim: Union[int, SymInt] + ) -> bool: """ This function checks if two dimensions are equivalent. 
:param dim: The dimension arg to reshape/slice :param i_dim: The corresponding dimension in the input tensor :return: Are the dimensions equivalent? - There are three cases in which the dimensions are equivalent: + There are two cases in which the dimensions are equivalent: 1. The dimensions are equal (both integers) - 2. The reshape dimension is -1 (i.e. inferred) - 3. The dimensions both correspond to the same SymInt - - While case 2 does not guarantee the dimensions are equal, - they are equal if all other dimensions are equal. - - In case 3, the reshape dimension is a torch.fx.Node, - and its value is a SymInt. That value is equal to the - input dimension. - + 2. The dimensions both correspond to the same SymInt """ - # Case 1 and 2 - if dim == i_dim or dim == -1: - return True - # Case 3 - return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim + # Case 1 + if isinstance(i_dim, int) and isinstance(dim, int): + return dim == i_dim + # Case 2 + if isinstance(i_dim, SymInt) and isinstance(dim, SymInt): + return dim == i_dim + return False + + def all_dims_equivalent( + self, dims: Iterable[Union[int, SymInt]], i_dims: Iterable[Union[int, SymInt]] + ) -> bool: + dims_ = list(dims) + i_dims_ = list(i_dims) + if len(dims_) != len(i_dims_): + # Different ranks can't be equivalent + return False + return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py new file mode 100644 index 0000000000000..c17a5bd4480c9 --- /dev/null +++ b/vllm/compilation/partition_rules.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import contextlib +from typing import TYPE_CHECKING + +from torch._library.utils import lookup_op + +from vllm.logger import init_logger + +if TYPE_CHECKING: + import torch + +logger = init_logger(__name__) + + +def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]: + """Resolve operator names to OpOverload objects. + + Skips operators that fail to resolve (e.g., operators not registered or + model-specific operators not present in the current model). + + Note: Users should inspect the operator graph before lowering and ensure + the specified operators are present in the final graph. Built-in PyTorch + operators (aten::*, torch::*) may be decomposed, fused, or transformed + during Inductor's compilation passes, so use them with caution. + + Args: + op_names: List of operator names in PyTorch format + (e.g., "vllm::unified_attention") + + Returns: + List of successfully resolved operator overloads + """ + resolved = [] + for op_name in op_names: + try: + resolved.append(lookup_op(op_name)) + except Exception: + # Skip operators that don't exist (e.g., model-specific ops) + logger.warning( + "Failed to resolve operator for Inductor partition: %s", op_name + ) + continue + + return resolved + + +@contextlib.contextmanager +def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]): + """Context manager to temporarily register Inductor partition rules. + + Registers custom partition rules for specified operators, forcing the + Inductor scheduler to partition the graph at these operators. The rules + are automatically restored to their previous state on exit. + + Note: Callers should use resolve_defined_ops() to convert operator names + to OpOverload objects before calling this function. 
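# Illustrative aside: a usage sketch for `resolve_defined_ops` above. Names that
# fail to resolve are skipped with a warning rather than raising, so a config
# may safely list ops that only exist for some models. The second name below is
# deliberately made up to show the skip path.
from vllm.compilation.partition_rules import resolve_defined_ops

overloads = resolve_defined_ops(
    [
        "vllm::unified_attention",  # example name from the docstring above
        "vllm::op_that_does_not_exist",  # hypothetical; logged and skipped
    ]
)
print([str(op) for op in overloads])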
+ + Args: + overloads: List of resolved operator overload objects. + """ + if not overloads: + logger.debug("No partition ops provided; skipping rule registration.") + yield + return + + from torch._inductor.scheduler import ( # type: ignore + _custom_should_partition_fns, + register_should_partition_rule, + ) + + def _always_partition(*_args, **_kwargs): + return True + + # Save current state before registering + saved_rules = _custom_should_partition_fns.copy() + + for overload in overloads: + register_should_partition_rule( + overload, + _always_partition, + ) + + logger.debug("Registered inductor partition rules for %d operators", len(overloads)) + + try: + yield + finally: + # Clear and restore previous state + _custom_should_partition_fns.clear() + _custom_should_partition_fns.update(saved_rules) + logger.debug("Restored previous partition rules state.") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e07e52be9fdf6..e323fa1f77349 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,29 +1,52 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from torch import fx as fx +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import set_env_var + +from .post_cleanup import PostCleanupPass +from .vllm_inductor_pass import VllmInductorPass if current_platform.is_cuda_alike(): - from .fusion import FusionPass + from .activation_quant_fusion import ActivationQuantFusionPass + from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass -from .activation_quant_fusion import ActivationQuantFusionPass from .fix_functionalization import FixFunctionalizationPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass from .sequence_parallelism import SequenceParallelismPass -from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) +def with_pattern_match_debug(fn): + """ + Function decorator that turns on inductor pattern match debug + for the duration of the call. + Used to avoid logging builtin Inductor pattern matching. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None: + # optionally check rank here + with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val): + return fn(*args, **kwargs) + return fn(*args, **kwargs) + + return wrapper + + class PostGradPassManager(CustomGraphPass): """ The pass manager for post-grad passes. 
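# Illustrative aside: a sketch combining the two helpers in this new file.
# Partition rules are registered only for the duration of the `with` block, so
# they apply to compilations triggered inside it and are restored afterwards.
# This assumes a torch build that exposes register_should_partition_rule.
import torch

from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)

def compile_with_partition_rules(fn, example_inputs, op_names):
    overloads = resolve_defined_ops(op_names)
    with inductor_partition_rule_context(overloads):
        compiled = torch.compile(fn, backend="inductor")
        compiled(*example_inputs)  # compilation happens while rules are active
    return compiled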
@@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: list[VllmInductorPass] = [] + self.passes: list[InductorPass] = [] + @with_pattern_match_debug def __call__(self, graph: fx.Graph): + VllmInductorPass.dump_prefix = 0 # reset dump index + shape = get_pass_context().runtime_shape for pass_ in self.passes: if pass_.is_applicable_for_shape(shape): pass_(graph) + VllmInductorPass.dump_prefix += 1 + + # post-cleanup goes before fix_functionalization + # because it requires a functional graph + self.post_cleanup(graph) + VllmInductorPass.dump_prefix += 1 # always run fix_functionalization last self.fix_functionalization(graph) + VllmInductorPass.dump_prefix = None # Cleanup index def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config @@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] + if self.pass_config.enable_fusion: - self.passes += [FusionPass.instance(config)] + self.passes += [RMSNormQuantFusionPass(config)] self.passes += [ActivationQuantFusionPass(config)] if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/piecewise_backend.py similarity index 86% rename from vllm/compilation/cuda_piecewise_backend.py rename to vllm/compilation/piecewise_backend.py index ae26e9f1bf2b6..61551766a1c52 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -23,15 +23,19 @@ class ConcreteSizeEntry: class PiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - piecewise_compile_index: int, total_piecewise_compiles: int, - sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): + def __init__( + self, + graph: fx.GraphModule, + vllm_config: VllmConfig, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, + ): """ The backend for piecewise compilation. - It mainly handles the compilation of static shapes and + It mainly handles the compilation of static shapes and dispatching based on runtime shape. 
We will compile `self.graph` once for the general shape, @@ -46,13 +50,11 @@ class PiecewiseBackend: self.vllm_backend = vllm_backend self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) + self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set( - self.compilation_config.compile_sizes) + self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) self.first_run_finished = False @@ -108,7 +110,8 @@ class PiecewiseBackend: self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) + runtime_shape=runtime_shape, + ) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: diff --git a/vllm/compilation/post_cleanup.py b/vllm/compilation/post_cleanup.py new file mode 100644 index 0000000000000..55117516838ca --- /dev/null +++ b/vllm/compilation/post_cleanup.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from torch import fx + +from vllm.compilation.vllm_inductor_pass import VllmInductorPass + + +class PostCleanupPass(VllmInductorPass): + """ + This pass performs cleanup after custom passes. + It topologically sorts the graph and removes unused nodes. + This is needed because the pattern matcher does not guarantee producing + a topologically sorted graph, and there may be unused nodes left around. + """ + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + from torch._inductor.pattern_matcher import stable_topological_sort + + stable_topological_sort(graph) + graph.eliminate_dead_code() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index ebc025cba71ed..2bc705c3b9a9c 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -9,12 +9,12 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.platforms import current_platform -from .vllm_inductor_pass import VllmInductorPass +from .inductor_pass import enable_fake_mode +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -22,12 +22,14 @@ logger = init_logger(__name__) class _RMSNormAndQuantOpHelper: """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" - def __init__(self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: Optional[torch._ops.OpOverload] = None, - **kwargs): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs, + ): self.epsilon = epsilon self.dtype = dtype self.device = device @@ -39,60 +41,78 @@ class _RMSNormAndQuantOpHelper: result=result_buffer, input=input_tensor, weight=weight_tensor, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) - def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, - weight_tensor): + def _functional_fused_add_rmsnorm( + self, input_tensor, 
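# Illustrative aside: what PostCleanupPass (new file above) does to a graph,
# shown on a toy symbolically-traced module instead of a real post-grad graph.
import torch
from torch.fx import symbolic_trace
from torch._inductor.pattern_matcher import stable_topological_sort

class Toy(torch.nn.Module):
    def forward(self, x):
        unused = x + 1  # dead: the result is never used
        return x * 2

gm = symbolic_trace(Toy())
stable_topological_sort(gm.graph)  # the pattern matcher may leave nodes unsorted
gm.graph.eliminate_dead_code()     # drops the unused `x + 1` node
gm.recompile()
print(gm.code)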
residual_tensor, weight_tensor + ): return torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, input=input_tensor, residual=residual_tensor, weight=weight_tensor, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) - def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, - quant_result_buffer, input_tensor, - weight_tensor, scale_tensor): + def _functional_rmsnorm_then_quant( + self, + rmsnorm_result_buffer, + quant_result_buffer, + input_tensor, + weight_tensor, + scale_tensor, + ): if self.quant_op is None: raise RuntimeError( "_RMSNormAndQuantOpHelper was not initialized with a quant_op." ) - rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer, - input_tensor, - weight_tensor) + rmsnorm_out_tuple = self._functional_rmsnorm( + rmsnorm_result_buffer, input_tensor, weight_tensor + ) quant_out_tuple = torch.ops.higher_order.auto_functionalized( self.quant_op, result=quant_result_buffer, input=rmsnorm_out_tuple[1], - scale=scale_tensor) + scale=scale_tensor, + ) return quant_out_tuple - def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, - input_tensor, residual_tensor, - weight_tensor, scale_tensor): + def _functional_fused_add_rmsnorm_then_quant( + self, + quant_result_buffer, + input_tensor, + residual_tensor, + weight_tensor, + scale_tensor, + ): if self.quant_op is None: raise RuntimeError( "_RMSNormAndQuantOpHelper was not initialized with a quant_op." ) fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( - input_tensor, residual_tensor, weight_tensor) + input_tensor, residual_tensor, weight_tensor + ) quant_out_tuple = torch.ops.higher_order.auto_functionalized( self.quant_op, result=quant_result_buffer, input=fused_add_rmsnorm_out_tuple[1], - scale=scale_tensor) + scale=scale_tensor, + ) return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): """Helper for sequence parallelism patterns.""" - def __init__(self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: Optional[torch._ops.OpOverload] = None, - **kwargs): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs, + ): super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) self.tp_group = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() @@ -102,21 +122,16 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.reduce_scatter.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp_group.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name + ) def _all_gather(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp_group.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name + ) class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) @@ -125,7 +140,6 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): return [input, permute, arg3_1] def register(self, pm_pass: PatternMatcherPass): - def pattern( input: torch.Tensor, permute: torch.Tensor, @@ -144,26 +158,23 @@ class 
FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): reduce_scatter = self._reduce_scatter(input) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, - arg3_1) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [ residual, @@ -172,7 +183,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -180,7 +190,8 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights) + all_reduce, residual, rms_norm_weights + ) return rmsnorm[1], rmsnorm[2] def replacement( @@ -190,23 +201,22 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights) + reduce_scatter, residual, rms_norm_weights + ) all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [ residual, @@ -215,7 +225,6 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -223,7 +232,8 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights) + all_reduce, residual, rms_norm_weights + ) return rmsnorm[1] def replacement( @@ -233,37 +243,34 @@ class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights) + reduce_scatter, residual, rms_norm_weights + ) normalized = self._all_gather(rmsnorm[1]) return normalized - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, 
self.get_inputs(), pm.fwd_only, pm_pass + ) FP8_DTYPE = current_platform.fp8_dtype() class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], - device=self.device, - dtype=FP8_DTYPE) + rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) return [input, rmsnorm_result, quant_result, weight, scale] def register(self, pm_pass: PatternMatcherPass): - def pattern( input: torch.Tensor, rmsnorm_result: torch.Tensor, @@ -273,7 +280,8 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ): all_reduce = self._all_reduce(input) static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, all_reduce, weight, scale) + rmsnorm_result, quant_result, all_reduce, weight, scale + ) return static_fp8[1], all_reduce def replacement( @@ -285,34 +293,36 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ): reduce_scatter = self._reduce_scatter(input) - rmsnorm_result = torch.empty_like(reduce_scatter, - dtype=rmsnorm_result.dtype) + rmsnorm_result = torch.empty_like( + reduce_scatter, dtype=rmsnorm_result.dtype + ) quant_result = torch.empty_like( rmsnorm_result, # Output of RMSNorm - dtype=quant_result.dtype) + dtype=quant_result.dtype, + ) static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, reduce_scatter, weight, scale) + rmsnorm_result, quant_result, reduce_scatter, weight, scale + ) all_gather = self._all_gather(static_fp8[1]) return all_gather, reduce_scatter - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) @@ -325,7 +335,6 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( result: torch.Tensor, residual: torch.Tensor, @@ -334,8 +343,11 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: 
all_reduce = self._all_reduce(mm_1) - static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - result, all_reduce, residual, rms_norm_weights, scale) + static_fp8, rmsnorm_residual_out = ( + self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale + ) + ) return static_fp8[1], rmsnorm_residual_out def replacement( @@ -346,31 +358,31 @@ class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, - dtype=result.dtype) - static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - quant_result_buf, reduce_scatter, residual, rms_norm_weights, - scale) + quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) + static_fp8, rmsnorm_residual_out = ( + self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale + ) + ) all_gather = self._all_gather(static_fp8[1]) return all_gather, rmsnorm_residual_out - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) @@ -383,7 +395,6 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( result: torch.Tensor, residual: torch.Tensor, @@ -393,7 +404,8 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - result, all_reduce, residual, rms_norm_weights, scale) + result, all_reduce, residual, rms_norm_weights, scale + ) return static_fp8[1] def replacement( @@ -404,19 +416,19 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, - dtype=result.dtype) + quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - quant_result_buf, reduce_scatter, residual, rms_norm_weights, - scale) + quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale + ) normalized = self._all_gather(static_fp8[1]) return normalized - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, 
self.get_inputs(), pm.fwd_only, pm_pass + ) -class SequenceParallelismPass(VllmInductorPass): +class SequenceParallelismPass(VllmPatternMatcherPass): """ This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by @@ -436,47 +448,46 @@ class SequenceParallelismPass(VllmInductorPass): performance. """ + @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="sequence_parallelism_pass") + pass_name="sequence_parallelism_pass" + ) for epsilon in [1e-5, 1e-6]: # RMSNorm + Static FP8 quantization patterns fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default FirstAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) # Normal RMSNorm patterns - FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + FirstAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) - MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + MiddleAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) - LastAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) - - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() + LastAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) + self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): - self.begin() - self.dump_graph(graph, "before_sequence_parallelism_pass") - count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns with sequence parallelism", count) - self.dump_graph(graph, "after_sequence_parallelism_pass") - self.end_and_log() + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py index cd3970657522e..ea8b56cf9d6ac 100644 --- a/vllm/compilation/torch25_custom_graph_pass.py +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -37,6 +37,8 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition) return self.uuid() def __setstate__(self, state): - raise ValueError("Cannot unpickle CustomGraphPass because pickling" - " is used for cache key uuid. Use torch>=2.6 with" - " native uuid support for custom passes.") + raise ValueError( + "Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes." 
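# Illustrative aside: a single-process sketch of why the rewrite above is valid.
# AllReduce followed by a row-wise op (RMSNorm) equals ReduceScatter, the op on
# the local shard, then AllGather, because the sum can be formed shard by shard.
# Two "ranks" are simulated here with tensor chunks; no distributed setup needed.
import torch

def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

tp = 2
xs = [torch.randn(8, 16) for _ in range(tp)]  # per-rank partial outputs
w = torch.randn(16)

# Original pattern: AllReduce, then RMSNorm on the full tensor.
ref = rmsnorm(sum(xs), w)

# Rewritten pattern: ReduceScatter (each rank sums only its slice of rows),
# RMSNorm on the local shard, then AllGather to rebuild the full tensor.
shards = [sum(x.chunk(tp, dim=0)[r] for x in xs) for r in range(tp)]
out = torch.cat([rmsnorm(s, w) for s in shards], dim=0)

torch.testing.assert_close(out, ref)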
+ ) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index b822b05b0f1ec..5aa08220bc2d7 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import functools +import operator import time +from typing import ClassVar, Optional +import regex as re import torch from torch._dynamo.utils import lazy_format_graph_code +from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter from vllm.config import VllmConfig from vllm.logger import init_logger @@ -20,16 +24,33 @@ class VllmInductorPass(InductorPass): It provides timing, logging, and dumping utilities. """ + dump_prefix: ClassVar[Optional[int]] = None + """Keep track of pass index for debug dump ordering.""" + def __init__(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - self.model_dtype = config.model_config.dtype if config.model_config \ - else None - self.device = config.device_config.device if config.device_config \ - else None + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None self.pass_name = self.__class__.__name__ + @staticmethod + def time_and_log(call_fn): + @functools.wraps(call_fn) + def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before") + call_fn(self, graph) + self.dump_graph(graph, "after") + self.end_and_log() + + return wrapped + def dump_graph(self, graph: torch.fx.Graph, stage: str): - lazy_format_graph_code(stage, graph.owning_module) + i = VllmInductorPass.dump_prefix + i_str = "" if i is None else f".{i}" + lazy_format_graph_code( + f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module + ) def begin(self): self._start_time = time.perf_counter_ns() @@ -40,8 +61,96 @@ class VllmInductorPass(InductorPass): logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) -class PrinterInductorPass(VllmInductorPass): +class VllmPatternMatcherPass(VllmInductorPass): + """ + A VllmInductorPass that uses the Inductor pattern matcher. + Its main use is providing the dump_patterns utility that dumps the + Inductor pattern matcher patterns into a file, which greatly aids debugging. + TODO(luka) move more utilities to this pass. + """ + + matched_count: int = 0 + """The number of matched patterns in the pass.""" + + _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( + r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>" + ) + + def _replace_op_overloads(self, string: str) -> str: + """Replace <OpOverload(..., ...)> with nicer formulations""" + return self._OP_OVERLOAD_PATTERN.sub( + lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", + string, + ) + + def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): + """ + If debug dumping is enabled, dump the Inductor pattern-matcher patterns + into the debug_dump_path folder next to the dumped fx graphs. + + This method does its best to print something that looks like Python code + for easier debugging and potentially navigation. If any errors appear in + the output, please add to this method. 
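# Illustrative aside: what the OpOverload repr rewriting used by dump_patterns
# produces. The input string below is hand-written in the repr format the regex
# targets, not taken from a live OpOverload object.
import regex as re

op_overload_pattern = re.compile(r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")

s = "result = <OpOverload(op='_C.static_scaled_fp8_quant', overload='default')>(...)"
print(op_overload_pattern.sub(lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", s))
# -> result = torch.ops._C.static_scaled_fp8_quant.default(...)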
+ + TODO(luka): use pattern object to manually produce pattern graph + """ + debug_dump_path = config.compile_debug_dump_path() + if not debug_dump_path: + return + + debug_dump_path.mkdir(parents=True, exist_ok=True) + + from vllm.utils import unique_filepath + + file_path = unique_filepath( + lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py" + ) + + with file_path.open("w") as f: + print( + f"# This file was produced by VllmPatternMatcherPass." + f"dump_patterns for {self.pass_name}.\n" + f"# It does its best to produce valid-Python-looking code but" + f" please add to dump_patterns if there are any errors.\n\n" + f"from torch._higher_order_ops.auto_functionalize import " + f"auto_functionalized as auto_functionalized\n" + f"from torch._inductor.pattern_matcher import *", + file=f, + ) + + for node, patterns in pm_pass.patterns.items(): + # fix the operator.getitem repr + if node[1] == operator.getitem: + node_repr = f"({repr(node[0])}, operator.getitem)" + else: + node_repr = repr(node) + + node_repr = self._replace_op_overloads(node_repr) + + print(f"\n\n# Patterns for op: {node_repr}", file=f) + for i, pattern in enumerate(patterns): + # reserve auto_functionalized ahead of time + pp = PatternPrettyPrinter() + pp.namespace.create_name("auto_functionalized", None) + + # Assemble pattern + out_node = pp.pretty_print(pattern.pattern) + pattern_repr = "\n".join( + [f"def pattern_{i}():"] + + [ + f"{pp.memoized_objs_names[key]} = " + f"{pp.memoized_objs_pp[key]}" + for key in pp.memoized_objs_names + ] + + [f"return {out_node}"] + ).replace("\n", "\n ") + + pattern_repr = self._replace_op_overloads(pattern_repr) + print(f"{pattern_repr}\n", file=f) + + +class PrinterInductorPass(VllmInductorPass): def __init__(self, name: str, config: VllmConfig): super().__init__(config) self.name = name diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 96d4eae2ee9aa..2007b655e2642 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,8 +11,7 @@ from typing import Callable, Optional import torch import vllm.envs as envs -from vllm.config import (CompilationLevel, CUDAGraphMode, - get_current_vllm_config) +from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -31,10 +30,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, - compiled_callable: Optional[Callable] = None, - compilation_level: int = 0): - + def __init__( + self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0 + ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config if compiled_callable is None: @@ -44,14 +42,26 @@ class TorchCompileWrapperWithCustomDispatcher: backend = vllm_config.compilation_config.init_backend(vllm_config) options = None if isinstance(backend, str) and backend == "inductor": - options = get_current_vllm_config( - ).compilation_config.inductor_compile_config + options = ( + get_current_vllm_config().compilation_config.inductor_compile_config + ) + if envs.VLLM_USE_AOT_COMPILE: + options = options or {} + # This effectively drop all the guards. + # We need this because bytecode hook is not used any more to + # drop guards in the AOT compile mode. 
+ options["guard_filter_fn"] = lambda guards: [False for _ in guards] + if hasattr(torch._dynamo.config, "enable_aot_compile"): + torch._dynamo.config.enable_aot_compile = True + else: + msg = "torch._dynamo.config.enable_aot_compile is not " + msg += "available. AOT compile is disabled and please " + msg += "upgrade PyTorch version to use AOT compile." + logger.warning(msg) compiled_callable = torch.compile( - self.forward, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend, - options=options) + self.forward, fullgraph=True, backend=backend, options=options + ) self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ @@ -61,8 +71,18 @@ class TorchCompileWrapperWithCustomDispatcher: # read the env var to determine whether to use the custom dispatcher # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. - self.use_custom_dispatcher: bool = \ + self.use_custom_dispatcher: bool = ( compilation_level >= CompilationLevel.DYNAMO_ONCE + ) + + def aot_compile(self, *args, **kwargs): + if not hasattr(self.compiled_callable, "aot_compile"): + raise RuntimeError( + "aot_compile is not supported by the current configuration. " + + "Please make sure torch.compile is enabled with the latest " + + f"version of PyTorch (current using torch: {torch.__version__})" + ) + return self.compiled_callable.aot_compile((args, kwargs)) def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. @@ -72,8 +92,7 @@ class TorchCompileWrapperWithCustomDispatcher: return self.compiled_callable(*args, **kwargs) @abstractmethod - def forward(self, *args, **kwargs): - ... + def forward(self, *args, **kwargs): ... def bytecode_hook(self, old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" @@ -94,33 +113,41 @@ class TorchCompileWrapperWithCustomDispatcher: return self.compiled_codes.append(new_code) - debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path - if isinstance(debug_dump_dir, str) and debug_dump_dir != "": - rank = self.vllm_config.parallel_config.rank - decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}", - "transformed_code.py") - if not os.path.exists(decompiled_file): + + path = self.vllm_config.compile_debug_dump_path() + if path: + decompiled_file = path / "transformed_code.py" + if not decompiled_file.exists(): try: # usually the decompilation will succeed for most models, # as we guarantee a full-graph compilation in Dynamo. # but there's no 100% guarantee, since decompliation is # not a reversible process. import depyf + src = depyf.decompile(new_code) with open(decompiled_file, "w") as f: f.write(src) - logger.debug("Dynamo transformed code saved to %s", - decompiled_file) + logger.debug("Dynamo transformed code saved to %s", decompiled_file) except Exception: pass - if self.vllm_config.compilation_config.cudagraph_mode != \ - CUDAGraphMode.NONE and "update" in new_code.co_names: + if ( + self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and "update" in new_code.co_names + ): import depyf + src = depyf.decompile(new_code) - msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. 
The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + msg = ( + "Assigning / modifying buffers of nn.Module during forward pass is not " + "allowed when using cudagraph inside the compiler because it will " + "cause silent errors. Please use eager mode or fix the code. The " + "following code contains clues about which buffer is being modified " + f"(please search for the usage of the function `update`):\n{src}" + ) raise RuntimeError(msg) @contextmanager @@ -131,8 +158,9 @@ class TorchCompileWrapperWithCustomDispatcher: variables as the original code. Therefore we can directly switch the code object in the function and call it. - See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. - """ # noqa + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 + for more details. + """ self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 6ce40626b3a81..6a0197d044dcd 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,4109 +1,99 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# ruff: noqa: F401 -import ast -import copy -import enum -import hashlib -import inspect -import json -import textwrap -import uuid -import warnings -from collections.abc import Mapping -from contextlib import contextmanager -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace -from functools import cached_property, lru_cache -from importlib.util import find_spec -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args) +from vllm.config.cache import CacheConfig +from vllm.config.compilation import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + PassConfig, +) +from vllm.config.device import DeviceConfig +from vllm.config.kv_events import KVEventsConfig +from vllm.config.kv_transfer import KVTransferConfig +from vllm.config.load import LoadConfig +from vllm.config.lora import LoRAConfig +from vllm.config.model import ( + ModelConfig, + iter_architecture_defaults, + try_match_architecture_defaults, +) +from vllm.config.multimodal import MultiModalConfig +from vllm.config.observability import ObservabilityConfig +from vllm.config.parallel import EPLBConfig, ParallelConfig +from vllm.config.pooler import PoolerConfig +from vllm.config.scheduler import SchedulerConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.config.speech_to_text import SpeechToTextConfig +from vllm.config.structured_outputs import StructuredOutputsConfig +from vllm.config.utils import ( + ConfigType, + SupportsMetricsInfo, + config, + get_attr_docs, + is_init_field, + update_config, +) +from vllm.config.vllm import ( + VllmConfig, + get_cached_compilation_config, + get_current_vllm_config, + get_layers_from_vllm_config, + set_current_vllm_config, +) -import regex as re -import torch -from pydantic import (ConfigDict, SkipValidation, field_validator, - model_validator) -from pydantic.dataclasses import dataclass -from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE 
-from typing_extensions import Self, assert_never, runtime_checkable - -import vllm.envs as envs -from vllm import version -from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, - PrefixCachingHashAlgo) -from vllm.config.compilation import (CompilationConfig, CompilationLevel, - CUDAGraphMode, PassConfig) -from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, - ParallelConfig) -from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy -from vllm.config.utils import ConfigType, config -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.platforms import current_platform -from vllm.transformers_utils.config import ( - ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, - is_interleaved, maybe_override_with_speculators_target_model, - try_get_generation_config, try_get_safetensors_metadata, - try_get_tokenizer_config, uses_mrope) -from vllm.transformers_utils.s3_utils import S3Model -from vllm.transformers_utils.utils import is_s3, maybe_model_redirect -from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType, - LazyLoader, common_broadcastable_dtype, random_uuid) - -if TYPE_CHECKING: - from _typeshed import DataclassInstance - from transformers.configuration_utils import PretrainedConfig - - import vllm.model_executor.layers.quantization as me_quant - import vllm.model_executor.models as me_models - from vllm.model_executor.layers.quantization import QuantizationMethods - from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - from vllm.model_executor.model_loader import LoadFormats - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig - from vllm.v1.sample.logits_processor import LogitsProcessor - - HfOverrides = Union[dict, Callable[[type], type]] -else: - DataclassInstance = Any - PretrainedConfig = Any - QuantizationConfig = Any - QuantizationMethods = Any - BaseModelLoader = Any - LoadFormats = Any - TensorizerConfig = Any - LogitsProcessor = Any - HfOverrides = Union[dict[str, Any], Callable[[type], type]] - - me_quant = LazyLoader("model_executor", globals(), - "vllm.model_executor.layers.quantization") - me_models = LazyLoader("model_executor", globals(), - "vllm.model_executor.models") - -logger = init_logger(__name__) -DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance) - -TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward", "transcription", "draft"] - -_ResolvedTask = Literal["generate", "transcription", "encode", "embed", - "classify", "reward", "draft"] - -RunnerOption = Literal["auto", "generate", "pooling", "draft"] - -RunnerType = Literal["generate", "pooling", "draft"] - -ConvertOption = Literal["auto", "none", "embed", "classify", "reward"] - -ConvertType = Literal["none", "embed", "classify", "reward"] - -_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { - "generate": ["generate", "transcription"], - "pooling": ["embedding", "embed", "classify", "score", "reward"], - "draft": ["draft"], -} - -_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { - "generate": [], - "pooling": ["embed", "classify", "reward"], - "draft": [], -} - -# Some model suffixes are based on auto classes from Transformers: -# https://huggingface.co/docs/transformers/en/model_doc/auto -# NOTE: Items higher on this list priority 
over lower ones -_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ - ("ForCausalLM", ("generate", "none")), - ("ForConditionalGeneration", ("generate", "none")), - ("ChatModel", ("generate", "none")), - ("LMHeadModel", ("generate", "none")), - ("ForTextEncoding", ("pooling", "embed")), - ("EmbeddingModel", ("pooling", "embed")), - ("ForSequenceClassification", ("pooling", "classify")), - ("ForAudioClassification", ("pooling", "classify")), - ("ForImageClassification", ("pooling", "classify")), - ("ForVideoClassification", ("pooling", "classify")), - ("ClassificationModel", ("pooling", "classify")), - ("ForRewardModeling", ("pooling", "reward")), - ("RewardModel", ("pooling", "reward")), - # Let other `*Model`s take priority - ("Model", ("pooling", "embed")), +# __all__ should only contain classes and functions. +# Types and globals should be imported from their respective modules. +__all__ = [ + # From vllm.config.cache + "CacheConfig", + # From vllm.config.compilation + "CompilationConfig", + "CompilationLevel", + "CUDAGraphMode", + "PassConfig", + # From vllm.config.device + "DeviceConfig", + # From vllm.config.kv_events + "KVEventsConfig", + # From vllm.config.kv_transfer + "KVTransferConfig", + # From vllm.config.load + "LoadConfig", + # From vllm.config.lora + "LoRAConfig", + # From vllm.config.model + "ModelConfig", + "iter_architecture_defaults", + "try_match_architecture_defaults", + # From vllm.config.multimodal + "MultiModalConfig", + # From vllm.config.observability + "ObservabilityConfig", + # From vllm.config.parallel + "EPLBConfig", + "ParallelConfig", + # From vllm.config.pooler + "PoolerConfig", + # From vllm.config.scheduler + "SchedulerConfig", + # From vllm.config.speculative + "SpeculativeConfig", + # From vllm.config.speech_to_text + "SpeechToTextConfig", + # From vllm.config.structured_outputs + "StructuredOutputsConfig", + # From vllm.config.utils + "ConfigType", + "SupportsMetricsInfo", + "config", + "get_attr_docs", + "is_init_field", + "update_config", + # From vllm.config.vllm + "VllmConfig", + "get_cached_compilation_config", + "get_current_vllm_config", + "set_current_vllm_config", + "get_layers_from_vllm_config", ] - - -def iter_architecture_defaults(): - yield from _SUFFIX_TO_DEFAULTS - - -def try_match_architecture_defaults( - architecture: str, - *, - runner_type: Optional[RunnerType] = None, - convert_type: Optional[ConvertType] = None, -) -> Optional[tuple[str, tuple[RunnerType, ConvertType]]]: - for suffix, (default_runner_type, - default_convert_type) in iter_architecture_defaults(): - if ((runner_type is None or runner_type == default_runner_type) and - (convert_type is None or convert_type == default_convert_type) - and architecture.endswith(suffix)): - return suffix, (default_runner_type, default_convert_type) - - return None - - -@runtime_checkable -class SupportsHash(Protocol): - - def compute_hash(self) -> str: - ... - - -class SupportsMetricsInfo(Protocol): - - def metrics_info(self) -> dict[str, str]: - ... - - -class ModelImpl(str, enum.Enum): - AUTO = "auto" - VLLM = "vllm" - TRANSFORMERS = "transformers" - - -def get_attr_docs(cls: type[Any]) -> dict[str, str]: - """ - Get any docstrings placed after attribute assignments in a class body. - - https://davidism.com/mit-license/ - """ - - def pairwise(iterable): - """ - Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - - Can be removed when Python 3.9 support is dropped. 
- """ - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b - - try: - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] - except (OSError, KeyError, TypeError): - # HACK: Python 3.13+ workaround - set missing __firstlineno__ - # Workaround can be removed after we upgrade to pydantic==2.12.0 - with open(inspect.getfile(cls)) as f: - for i, line in enumerate(f): - if f"class {cls.__name__}" in line and ":" in line: - cls.__firstlineno__ = i + 1 - break - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] - - if not isinstance(cls_node, ast.ClassDef): - raise TypeError("Given object was not a class.") - - out = {} - - # Consider each pair of nodes. - for a, b in pairwise(cls_node.body): - # Must be an assignment then a constant string. - if (not isinstance(a, (ast.Assign, ast.AnnAssign)) - or not isinstance(b, ast.Expr) - or not isinstance(b.value, ast.Constant) - or not isinstance(b.value.value, str)): - continue - - doc = inspect.cleandoc(b.value.value) - - # An assignment can have multiple targets (a = b = v), but an - # annotated assignment only has one target. - targets = a.targets if isinstance(a, ast.Assign) else [a.target] - - for target in targets: - # Must be assigning to a plain name. - if not isinstance(target, ast.Name): - continue - - out[target.id] = doc - - return out - - -def get_field(cls: ConfigType, name: str) -> Field: - """Get the default factory field of a dataclass by name. Used for getting - default factory fields in `EngineArgs`.""" - if not is_dataclass(cls): - raise TypeError("The given class is not a dataclass.") - cls_fields = {f.name: f for f in fields(cls)} - if name not in cls_fields: - raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields[name] - if (default_factory := named_field.default_factory) is not MISSING: - return field(default_factory=default_factory) - if (default := named_field.default) is not MISSING: - return field(default=default) - raise ValueError( - f"{cls.__name__}.{name} must have a default value or default factory.") - - -def is_init_field(cls: ConfigType, name: str) -> bool: - return next(f for f in fields(cls) if f.name == name).init - - -TokenizerMode = Literal["auto", "slow", "mistral", "custom"] -ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -MMEncoderTPMode = Literal["weights", "data"] - - -class LogprobsMode(enum.Enum): - RAW_LOGITS = "raw_logits" - RAW_LOGPROBS = "raw_logprobs" - PROCESSED_LOGITS = "processed_logits" - PROCESSED_LOGPROBS = "processed_logprobs" - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class ModelConfig: - """Configuration for the model.""" - - model: str = "Qwen/Qwen3-0.6B" - """Name or path of the Hugging Face model to use. It is also used as the - content for `model_name` tag in metrics output when `served_model_name` is - not specified.""" - runner: RunnerOption = "auto" - """The type of model runner to use. Each vLLM instance only supports one - model runner, even if the same model can be used for multiple types.""" - convert: ConvertOption = "auto" - """Convert the model using adapters defined in - [vllm.model_executor.models.adapters][]. The most common use case is to - adapt a text generation model to be used for pooling tasks.""" - task: Optional[TaskOption] = None - """[DEPRECATED] The task to use the model for. If the model supports more - than one model runner, this is used to select which model runner to run. 
- - Note that the model may support other tasks using the same model runner. - """ - tokenizer: SkipValidation[str] = None # type: ignore - """Name or path of the Hugging Face tokenizer to use. If unspecified, model - name or path will be used.""" - tokenizer_mode: TokenizerMode = "auto" - """Tokenizer mode:\n - - "auto" will use the fast tokenizer if available.\n - - "slow" will always use the slow tokenizer.\n - - "mistral" will always use the tokenizer from `mistral_common`.\n - - "custom" will use --tokenizer to select the preregistered tokenizer.""" - trust_remote_code: bool = False - """Trust remote code (e.g., from HuggingFace) when downloading the model - and tokenizer.""" - dtype: Union[ModelDType, torch.dtype] = "auto" - """Data type for model weights and activations:\n - - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 - precision for BF16 models.\n - - "half" for FP16. Recommended for AWQ quantization.\n - - "float16" is the same as "half".\n - - "bfloat16" for a balance between precision and range.\n - - "float" is shorthand for FP32 precision.\n - - "float32" for FP32 precision.""" - seed: Optional[int] = None - """Random seed for reproducibility. Initialized to None in V0, but - initialized to 0 in V1.""" - hf_config_path: Optional[str] = None - """Name or path of the Hugging Face config to use. If unspecified, model - name or path will be used.""" - allowed_local_media_path: str = "" - """Allowing API requests to read local images or videos from directories - specified by the server file system. This is a security risk. Should only - be enabled in trusted environments.""" - revision: Optional[str] = None - """The specific model version to use. It can be a branch name, a tag name, - or a commit id. If unspecified, will use the default version.""" - code_revision: Optional[str] = None - """The specific revision to use for the model code on the Hugging Face Hub. - It can be a branch name, a tag name, or a commit id. If unspecified, will - use the default version.""" - rope_scaling: dict[str, Any] = field(default_factory=dict) - """RoPE scaling configuration. For example, - `{"rope_type":"dynamic","factor":2.0}`.""" - rope_theta: Optional[float] = None - """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE - theta improves the performance of the scaled model.""" - tokenizer_revision: Optional[str] = None - """The specific revision to use for the tokenizer on the Hugging Face Hub. - It can be a branch name, a tag name, or a commit id. If unspecified, will - use the default version.""" - max_model_len: SkipValidation[int] = None # type: ignore - """Model context length (prompt and output). If unspecified, will be - automatically derived from the model config. - - When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable - format. Examples:\n - - 1k -> 1000\n - - 1K -> 1024\n - - 25.6k -> 25,600""" - spec_target_max_model_len: Optional[int] = None - """Specify the maximum length for spec decoding draft models.""" - quantization: SkipValidation[Optional[QuantizationMethods]] = None - """Method used to quantize the weights. If `None`, we first check the - `quantization_config` attribute in the model config file. If that is - `None`, we assume the model weights are not quantized and use `dtype` to - determine the data type of the weights.""" - enforce_eager: bool = False - """Whether to always use eager-mode PyTorch. If True, we will disable CUDA - graph and always execute the model in eager mode. 
If False, we will use - CUDA graph and eager execution in hybrid for maximal performance and - flexibility.""" - max_seq_len_to_capture: int = 8192 - """Maximum sequence len covered by CUDA graphs. When a sequence has context - length larger than this, we fall back to eager mode. Additionally for - encoder-decoder models, if the sequence length of the encoder input is - larger than this, we fall back to the eager mode.""" - max_logprobs: int = 20 - """Maximum number of log probabilities to return when `logprobs` is - specified in `SamplingParams`. The default value comes the default for the - OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * - vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS - """Indicates the content returned in the logprobs and prompt_logprobs. - Supported mode: - 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. - Raw means the values before applying any logit processors, like bad words. - Processed means the values after applying all processors, including - temperature and top_k/top_p. - """ - disable_sliding_window: bool = False - """Whether to disable sliding window. If True, we will disable the sliding - window functionality of the model, capping to sliding window size. If the - model does not support sliding window, this argument is ignored.""" - disable_cascade_attn: bool = False - """Disable cascade attention for V1. While cascade attention does not - change the mathematical correctness, disabling it could be useful for - preventing potential numerical issues. Note that even if this is set to - False, cascade attention will be only used when the heuristic tells that - it's beneficial.""" - skip_tokenizer_init: bool = False - """Skip initialization of tokenizer and detokenizer. Expects valid - `prompt_token_ids` and `None` for prompt from the input. The generated - output will contain token ids.""" - enable_prompt_embeds: bool = False - """If `True`, enables passing text embeddings as inputs via the - `prompt_embeds` key. Note that enabling this will double the time required - for graph compilation.""" - served_model_name: Optional[Union[str, list[str]]] = None - """The model name(s) used in the API. If multiple names are provided, the - server will respond to any of the provided names. The model name in the - model field of a response will be the first name in this list. If not - specified, the model name will be the same as the `--model` argument. Noted - that this name(s) will also be used in `model_name` tag content of - prometheus metrics, if multiple names provided, metrics tag will take the - first one.""" - limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) - """Maximum number of data items per modality per prompt. Only applicable - for multimodal models.""" - interleave_mm_strings: bool = False - """Enable fully interleaved support for multimodal prompts, while using - --chat-template-content-format=string. Defaults to False.""" - skip_mm_profiling: bool = False - """When enabled, skips multimodal memory profiling and only profiles with - language backbone model during engine initialization. - """ - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. 
- For example, to set num_frames for video, set - `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ - use_async_output_proc: bool = True - """Whether to use async output processor.""" - config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value - """The format of the model config to load:\n - - "auto" will try to load the config in hf format if available else it - will try to load in mistral format.\n - - "hf" will load the config in hf format.\n - - "mistral" will load the config in mistral format.""" - hf_token: Optional[Union[bool, str]] = None - """The token to use as HTTP bearer authorization for remote files . If - `True`, will use the token generated when running `huggingface-cli login` - (stored in `~/.huggingface`).""" - hf_overrides: HfOverrides = field(default_factory=dict) - """If a dictionary, contains arguments to be forwarded to the Hugging Face - config. If a callable, it is called to update the HuggingFace config.""" - mm_processor_kwargs: Optional[dict[str, Any]] = None - """Arguments to be forwarded to the model's processor for multi-modal data, - e.g., image processor. Overrides for the multi-modal processor obtained - from `AutoProcessor.from_pretrained`. The available overrides depend on the - model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. - """ - mm_processor_cache_gb: int = 4 - """The size (in GiB) of the multi-modal processor cache, which is used to - avoid re-processing past multi-modal inputs. - - This cache is duplicated for each API process and engine core process, - resulting in a total memory usage of - `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. - - Set to `0` to disable this cache completely (not recommended).""" - mm_encoder_tp_mode: MMEncoderTPMode = "weights" - """Indicates how to optimize multi-modal encoder inference using - tensor parallelism (TP). - - - `"weights"`: Within the same vLLM engine, split the weights of - each layer across TP ranks. (default TP behavior) - - `"data"`: Within the same vLLM engine, split the batched input data - across TP ranks to process the data in parallel, while hosting - the full weights on each TP rank. - This batch-level DP is not to be confused with API request-level - DP (which is controlled by `--data-parallel-size`). - This is only supported on a per-model basis and falls back to - `"weights"` if the encoder does not support DP.""" - override_neuron_config: dict[str, Any] = field(default_factory=dict) - """Initialize non-default neuron config or override default neuron config - that are specific to Neuron devices, this argument will be used to - configure the neuron config that can not be gathered from the vllm - arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" - pooler_config: Optional["PoolerConfig"] = field(init=False) - """Pooler config which controls the behaviour of output pooling in pooling - models.""" - override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None - """Initialize non-default pooling config or override default pooling config - for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. - """ - logits_processor_pattern: Optional[str] = None - """Optional regex pattern specifying valid logits processor qualified names - that can be passed with the `logits_processors` extra completion argument. - Defaults to `None`, which allows no processors.""" - generation_config: str = "auto" - """The folder path to the generation config. 
Defaults to `"auto"`, the - generation config will be loaded from model path. If set to `"vllm"`, no - generation config is loaded, vLLM defaults will be used. If set to a folder - path, the generation config will be loaded from the specified folder path. - If `max_new_tokens` is specified in generation config, then it sets a - server-wide limit on the number of output tokens for all requests.""" - override_generation_config: dict[str, Any] = field(default_factory=dict) - """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If - used with `--generation-config auto`, the override parameters will be - merged with the default config from the model. If used with - `--generation-config vllm`, only the override parameters are used.""" - enable_sleep_mode: bool = False - """Enable sleep mode for the engine (only cuda platform is supported).""" - model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value - """Which implementation of the model to use:\n - - "auto" will try to use the vLLM implementation, if it exists, and fall - back to the Transformers implementation if no vLLM implementation is - available.\n - - "vllm" will use the vLLM model implementation.\n - - "transformers" will use the Transformers model implementation.""" - override_attention_dtype: Optional[str] = None - """Override dtype for attention""" - logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None - """One or more logits processors' fully-qualified class names or class - definitions""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.model) - factors.append(self.dtype) - factors.append(self.quantization) - factors.append(self.revision) - factors.append(self.code_revision) - factors.append(self.max_model_len) - factors.append(self.max_logprobs) - factors.append(self.disable_sliding_window) - factors.append(self.trust_remote_code) - factors.append(self.generation_config) - factors.append(self.model_impl) - factors.append(self.override_generation_config) - factors.append(self.rope_scaling) - factors.append(self.rope_theta) - # hf_config can control how the model looks! - factors.append(self.hf_config.to_json_string()) - str_factors = str(factors) - assert_hashable(str_factors) - return hashlib.sha256(str(factors).encode()).hexdigest() - - def __post_init__(self) -> None: - # Set the default seed to 0 in V1. - # NOTE(woosuk): In V0, we set the default seed to None because the - # driver worker shares the same process as the user process, and thus - # setting a seed affects the user process as well. - # In V1, we use separate processes for workers (unless - # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here - # doesn't affect the user process. However, without a consistent seed, - # different tensor parallel workers would sample different tokens, - # leading to inconsistent results. - if envs.VLLM_USE_V1 and self.seed is None: - self.seed = 0 - if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: - logger.warning( - "The global random seed is set to %d. 
Since " - "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " - "affect the random state of the Python process that " - "launched vLLM.", self.seed) - - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - - # Keep set served_model_name before maybe_model_redirect(self.model) - self.served_model_name = get_served_model_name(self.model, - self.served_model_name) - self.model = maybe_model_redirect(self.model) - # The tokenizer is consistent with the model by default. - if self.tokenizer is None: - self.tokenizer = self.model - if self.tokenizer_revision is None: - self.tokenizer_revision = self.revision - self.tokenizer = maybe_model_redirect(self.tokenizer) - - if isinstance(self.hf_config_path, str): - self.hf_config_path = maybe_model_redirect(self.hf_config_path) - - if callable(self.hf_overrides): - hf_overrides_kw = {} - hf_overrides_fn = self.hf_overrides - else: - hf_overrides_kw = self.hf_overrides - hf_overrides_fn = None - - if self.rope_scaling: - hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling} - hf_overrides_kw.update(hf_override) - hf_overrides_str = json.dumps(hf_overrides_kw) - msg = ( - "`--rope-scaling` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") - warnings.warn(DeprecationWarning(msg), stacklevel=2) - if self.rope_theta is not None: - hf_override = {"rope_theta": self.rope_theta} - hf_overrides_kw.update(hf_override) - hf_overrides_str = json.dumps(hf_overrides_kw) - msg = ( - "`--rope-theta` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) - - if (backend := envs.VLLM_ATTENTION_BACKEND - ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: - raise ValueError( - "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " - "module was not found. 
See " - "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 - "for instructions on how to install it.") - - from vllm.platforms import current_platform - - if (self.override_attention_dtype is not None - and not current_platform.is_rocm()): - warnings.warn( - "override-attention-dtype is set but not using ROCm platform", - stacklevel=2) - - if (self.enable_sleep_mode - and not current_platform.is_sleep_mode_available()): - raise ValueError( - "Sleep mode is not supported on current platform.") - - if isinstance(self.config_format, str): - self.config_format = ConfigFormat(self.config_format) - - hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, - self.revision, - self.code_revision, - self.config_format, - hf_overrides_kw=hf_overrides_kw, - hf_overrides_fn=hf_overrides_fn) - - self.hf_config = hf_config - self.hf_text_config = get_hf_text_config(self.hf_config) - self.attention_chunk_size = getattr(self.hf_text_config, - "attention_chunk_size", None) - self.encoder_config = self._get_encoder_config() - self.hf_image_processor_config = get_hf_image_processor_config( - self.model, hf_token=self.hf_token, revision=self.revision) - - architectures = self.architectures - registry = self.registry - is_generative_model = registry.is_text_generation_model( - architectures, self) - is_pooling_model = registry.is_pooling_model(architectures, self) - - def _task_to_convert(task: TaskOption) -> ConvertType: - if task == "embedding" or task == "embed": - return "embed" - if task == "classify": - return "classify" - if task == "reward": - return "reward" - if task == "score": - new_task = self._get_default_pooling_task(architectures) - return "classify" if new_task == "classify" else "embed" - - return "none" - - if self.task is not None: - runner: RunnerOption = "auto" - convert: ConvertOption = "auto" - msg_prefix = ("The 'task' option has been deprecated and will be " - "removed in v0.13.0 or v1.0, whichever comes first.") - msg_hint = "Please remove this option." 
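        # Illustrative mapping of the deprecated flag: `--task embedding` on a
        # pooling-only model resolves to `--runner pooling --convert embed`
        # (with a deprecation warning suggesting `--convert embed`), while
        # `--task generate` on a generative-only model simply asks the user to
        # drop the flag.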
- - is_generative_task = self.task in _RUNNER_TASKS["generate"] - is_pooling_task = self.task in _RUNNER_TASKS["pooling"] - - if is_generative_model and is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = ("Please replace this option with `--runner " - "generate` to continue using this model " - "as a generative model.") - elif is_pooling_task: - runner = "pooling" - convert = "auto" - msg_hint = ("Please replace this option with `--runner " - "pooling` to continue using this model " - "as a pooling model.") - else: # task == "auto" - pass - elif is_generative_model or is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = "Please remove this option" - elif is_pooling_task: - runner = "pooling" - convert = _task_to_convert(self.task) - msg_hint = ("Please replace this option with `--convert " - f"{convert}` to continue using this model " - "as a pooling model.") - else: # task == "auto" - pass - else: - raise AssertionError("The model should be a generative or " - "pooling model when task is set to " - f"{self.task!r}.") - - self.runner = runner - self.convert = convert - - msg = f"{msg_prefix} {msg_hint}" - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - self.runner_type = self._get_runner_type(architectures, self.runner) - self.convert_type = self._get_convert_type(architectures, - self.runner_type, - self.convert) - - if self.runner_type == "generate" and not is_generative_model: - generate_converts = _RUNNER_CONVERTS["generate"] - if self.convert_type not in generate_converts: - # Currently we don't have any converters for generative models - raise ValueError( - "This model does not support `--runner generate`.") - if self.runner_type == "pooling" and not is_pooling_model: - pooling_converts = _RUNNER_CONVERTS["pooling"] - if self.convert_type not in pooling_converts: - convert_option = "<" + "|".join(pooling_converts) + ">" - raise ValueError( - "This model does not support `--runner pooling`. " - f"You can pass `--convert {convert_option} to adapt " - "it into a pooling model.") - - self.supported_tasks = self._get_supported_tasks( - architectures, self.runner_type, self.convert_type) - - # Note: Initialize these attributes early because transformers fallback - # may fail to load dynamic modules in child processes - model_info, arch = registry.inspect_model_cls(architectures, self) - self._model_info = model_info - self._architecture = arch - logger.info("Resolved architecture: %s", arch) - - self.pooler_config = self._init_pooler_config() - - self.dtype = _get_and_verify_dtype( - self.model, - self.hf_config, - self.dtype, - is_pooling_model=self.runner_type == "pooling", - revision=self.revision, - ) - - # Interleaved attention is not supported by some backends in V0 - if (not self.disable_sliding_window - and is_interleaved(self.hf_text_config) - and not envs.VLLM_USE_V1 - and (backend := envs.VLLM_ATTENTION_BACKEND) - in ("XFORMERS", "FLASHINFER")): - logger.warning_once( - "%s has interleaved attention, which is currently not " - "supported by the %s backend. 
Disabling sliding window and " - "capping the max length to the sliding window size (%d).", - self.hf_text_config.model_type, - backend, - self.hf_text_config.sliding_window, - ) - self.disable_sliding_window = True - - self.original_max_model_len = self.max_model_len - self.max_model_len = self.get_and_verify_max_len(self.max_model_len) - self.multimodal_config = self._init_multimodal_config() - - if self.disable_sliding_window: - # Set after get_and_verify_max_len to ensure that max_model_len - # can be correctly capped to sliding window size - self.hf_text_config.sliding_window = None - - if not self.skip_tokenizer_init: - self._verify_tokenizer_mode() - - if (not current_platform.is_neuron() and self.override_neuron_config): - raise ValueError( - "`override_neuron_config` is only supported on Neuron.") - - # Avoid running try_verify_and_update_config multiple times - self.config_updated = False - - self._verify_quantization() - self._verify_cuda_graph() - self._verify_bnb_config() - - @field_validator("quantization", mode="before") - @classmethod - def validate_quantization_before(cls, value: Any) -> Any: - if isinstance(value, str): - return value.lower() - return value - - @model_validator(mode="after") - def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": - if not isinstance(self.tokenizer, str): - raise ValueError("tokenizer must be a string after __post_init__.") - if not isinstance(self.max_model_len, int): - raise ValueError( - "max_model_len must be an integer after __post_init__.") - return self - - def _get_transformers_backend_cls(self) -> str: - """Determine which Transformers backend class will be used if - `model_impl` is set to `transformers` or `auto`.""" - if getattr(self, "runner_type", self.runner) == "pooling": - return "TransformersModel" - if self.hf_config != self.hf_text_config: - # If 'hf_text_config' is the same as 'hf_config'. If not, it is - # probably a composite config, i.e. multimodal - return "TransformersForMultimodalLM" - return "TransformersForCausalLM" - - def using_transformers_backend(self) -> bool: - """Check if the model is using the Transformers backend class.""" - return self.architecture == self._get_transformers_backend_cls() - - @property - def registry(self): - return me_models.ModelRegistry - - @property - def architectures(self) -> list[str]: - return getattr(self.hf_config, "architectures", []) - - @property - def architecture(self) -> str: - """The architecture vllm actually used.""" - return self._architecture - - def maybe_pull_model_tokenizer_for_s3(self, model: str, - tokenizer: str) -> None: - """Pull model/tokenizer from S3 to temporary directory when needed. 
- - Args: - model: Model name or path - tokenizer: Tokenizer name or path - """ - if not (is_s3(model) or is_s3(tokenizer)): - return - - if is_s3(model): - s3_model = S3Model() - s3_model.pull_files(model, - allow_pattern=["*.model", "*.py", "*.json"]) - self.model_weights = model - self.model = s3_model.dir - - # If tokenizer is same as model, download to same directory - if model == tokenizer: - s3_model.pull_files(model, - ignore_pattern=[ - "*.pt", "*.safetensors", "*.bin", - "*.tensors" - ]) - self.tokenizer = s3_model.dir - return - - # Only download tokenizer if needed and not already handled - if is_s3(tokenizer): - s3_tokenizer = S3Model() - s3_tokenizer.pull_files( - model, - ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"]) - self.tokenizer = s3_tokenizer.dir - - def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: - if self._model_info.supports_multimodal: - return MultiModalConfig( - limit_per_prompt=self.limit_mm_per_prompt, - media_io_kwargs=self.media_io_kwargs, - mm_processor_kwargs=self.mm_processor_kwargs, - mm_processor_cache_gb=self.mm_processor_cache_gb, - mm_encoder_tp_mode=self.mm_encoder_tp_mode, - interleave_mm_strings=self.interleave_mm_strings, - skip_mm_profiling=self.skip_mm_profiling, - ) - - return None - - def set_mm_processor_cache_gb(self, value: int) -> None: - mm_config = self.get_multimodal_config() - - self.mm_processor_cache_gb = value - mm_config.mm_processor_cache_gb = value - - def _get_encoder_config(self): - return get_sentence_transformer_tokenizer_config( - self.model, self.revision) - - def _init_pooler_config(self) -> Optional["PoolerConfig"]: - if self.runner_type == "pooling": - if isinstance(self.override_pooler_config, dict): - self.override_pooler_config = PoolerConfig( - **self.override_pooler_config) - - pooler_config = self.override_pooler_config or PoolerConfig() - - base_config = get_pooling_config(self.model, self.revision) - if base_config is not None: - # Only set values that are not overridden by the user - for k, v in base_config.items(): - if getattr(pooler_config, k) is None: - setattr(pooler_config, k, v) - - default_pooling_type = self._model_info.default_pooling_type - if pooler_config.pooling_type is None: - pooler_config.pooling_type = default_pooling_type - - return pooler_config - - return None - - def _verify_tokenizer_mode(self) -> None: - tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) - if tokenizer_mode not in get_args(TokenizerMode): - raise ValueError( - f"Unknown tokenizer mode: {self.tokenizer_mode}. 
Must be " - f"one of {get_args(TokenizerMode)}.") - self.tokenizer_mode = tokenizer_mode - - def _get_default_runner_type( - self, - architectures: list[str], - ) -> RunnerType: - registry = self.registry - - # Some Sentence Transformers models use *ForCausalLM archs - if get_pooling_config(self.model, self.revision): - return "pooling" - - for arch in architectures: - if arch in registry.get_supported_archs(): - if registry.is_pooling_model(architectures, self): - return "pooling" - if registry.is_text_generation_model(architectures, self): - return "generate" - - match = try_match_architecture_defaults(arch) - if match: - _, (runner_type, _) = match - return runner_type - - return "generate" - - def _get_runner_type( - self, - architectures: list[str], - runner: RunnerOption, - ) -> RunnerType: - if runner != "auto": - return runner - - runner_type = self._get_default_runner_type(architectures) - - # Don't log the most common case - if runner_type != "generate": - logger.info( - "Resolved `--runner auto` to `--runner %s`. " - "Pass the value explicitly to silence this message.", - runner_type) - - return runner_type - - def _get_default_convert_type( - self, - architectures: list[str], - runner_type: RunnerType, - ) -> ConvertType: - registry = self.registry - - for arch in architectures: - if arch in registry.get_supported_archs(): - if (runner_type == "generate" - and registry.is_text_generation_model( - architectures, self)): - return "none" - if (runner_type == "pooling" - and registry.is_pooling_model(architectures, self)): - return "none" - - match = try_match_architecture_defaults(arch, - runner_type=runner_type) - if match: - _, (_, convert_type) = match - return convert_type - - # This is to handle Sentence Transformers models that use *ForCausalLM - # and also multi-modal pooling models which are not defined as - # Sentence Transformers models - if runner_type == "pooling": - return "embed" - - return "none" - - def _get_convert_type( - self, - architectures: list[str], - runner_type: RunnerType, - convert: ConvertOption, - ) -> ConvertType: - if convert != "auto": - return convert - - convert_type = self._get_default_convert_type(architectures, - runner_type) - - # Don't log the most common case - if convert_type != "none": - logger.info( - "Resolved `--convert auto` to `--convert %s`. 
" - "Pass the value explicitly to silence this message.", - convert_type) - - return convert_type - - def _get_supported_generation_tasks( - self, - architectures: list[str], - convert_type: ConvertType, - ) -> list[_ResolvedTask]: - registry = self.registry - - if registry.is_transcription_only_model(architectures, self): - return ["transcription"] - - # TODO: Use get_supported_generation_tasks once V0 is removed - supported_tasks = list[_ResolvedTask]() - if (registry.is_text_generation_model(architectures, self) - or convert_type in _RUNNER_CONVERTS["generate"]): - supported_tasks.append("generate") - - if registry.is_transcription_model(architectures, self): - supported_tasks.append("transcription") - - return supported_tasks - - def _get_default_pooling_task( - self, - architectures: list[str], - ) -> Literal["embed", "classify", "reward"]: - if self.registry.is_cross_encoder_model(architectures, self): - return "classify" - - for arch in architectures: - match = try_match_architecture_defaults(arch, - runner_type="pooling") - if match: - _, (_, convert_type) = match - assert convert_type != "none" - return convert_type - - return "embed" - - def _get_supported_pooling_tasks( - self, - architectures: list[str], - convert_type: ConvertType, - ) -> list[_ResolvedTask]: - registry = self.registry - - # TODO: Use get_supported_pooling_tasks once V0 is removed - supported_tasks = list[_ResolvedTask]() - if (registry.is_pooling_model(architectures, self) - or convert_type in _RUNNER_CONVERTS["pooling"]): - supported_tasks.append("encode") - - extra_task = (self._get_default_pooling_task(architectures) - if convert_type == "none" else convert_type) - supported_tasks.append(extra_task) - - return supported_tasks - - def _get_supported_tasks( - self, - architectures: list[str], - runner_type: RunnerType, - convert_type: ConvertType, - ) -> list[_ResolvedTask]: - if runner_type == "generate": - return self._get_supported_generation_tasks( - architectures, convert_type) - if runner_type == "pooling": - return self._get_supported_pooling_tasks(architectures, - convert_type) - if runner_type == "draft": - return ["draft"] - - assert_never(runner_type) - - def _parse_quant_hf_config(self): - quant_cfg = getattr(self.hf_config, "quantization_config", None) - if quant_cfg is None: - # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(self.hf_config, "compression_config", None) - - else: - # Set quant_method for ModelOpt models. - producer_name = quant_cfg.get("producer", {}).get("name") - if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", - {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError( - f"Unknown ModelOpt quant algo: {quant_algo}") - - return quant_cfg - - def _verify_quantization(self) -> None: - supported_quantization = me_quant.QUANTIZATION_METHODS - optimized_quantization_methods = [ - "fp8", - "modelopt", - "gptq_marlin_24", - "gptq_marlin", - "awq_marlin", - "fbgemm_fp8", - "compressed-tensors", - "experts_int8", - "quark", - "modelopt_fp4", - "bitblas", - "gptq_bitblas", - "inc", - "petit_nvfp4", - ] - if self.quantization is not None: - self.quantization = cast(me_quant.QuantizationMethods, - self.quantization) - - # Parse quantization method from the HF model config, if available. 
- quant_cfg = self._parse_quant_hf_config() - - if quant_cfg is not None: - # Use the community standard 'quant_method' - quant_method = quant_cfg.get("quant_method", "").lower() - - # Normalize library names - quant_method = quant_method.replace("compressed_tensors", - "compressed-tensors") - - quant_cfg["quant_method"] = quant_method - - # Quantization methods which are overrides (i.e. they have a - # `override_quantization_method` method) must be checked in order - # of preference (this is particularly important for GPTQ). - overrides = [ - "bitblas", - "gptq_marlin_24", - "gptq_marlin", - "gptq_bitblas", - "awq_marlin", - "ipex", - "moe_wna16", - "modelopt", - "modelopt_fp4", - "petit_nvfp4", - ] - quantization_methods = [ - q for q in supported_quantization if q not in overrides - ] - # Any custom overrides will be in quantization_methods so we place - # them at the start of the list so custom overrides have preference - # over the built in ones. - quantization_methods = quantization_methods + overrides - - # Detect which checkpoint is it - for name in quantization_methods: - method = me_quant.get_quantization_config(name) - quantization_override = method.override_quantization_method( - quant_cfg, self.quantization) - if quantization_override is not None: - # Raise error if the override is not custom (custom would - # be in QUANTIZATION_METHODS but not QuantizationMethods) - # and hasn't been added to the overrides list. - if (name in get_args(me_quant.QuantizationMethods) - and name not in overrides): - raise ValueError( - f"Quantization method {name} is an override but " - "is has not been added to the `overrides` list " - "above. This is necessary to ensure that the " - "overrides are checked in order of preference.") - quant_method = quantization_override - self.quantization = quantization_override - break - - # Verify quantization configurations. - if self.quantization is None: - self.quantization = quant_method - elif self.quantization != quant_method: - raise ValueError( - "Quantization method specified in the model config " - f"({quant_method}) does not match the quantization " - f"method specified in the `quantization` argument " - f"({self.quantization}).") - - if self.quantization is not None: - if self.quantization not in supported_quantization: - raise ValueError( - f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") - from vllm.platforms import current_platform - current_platform.verify_quantization(self.quantization) - if self.quantization not in optimized_quantization_methods: - logger.warning( - "%s quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.", self.quantization) - - def _verify_cuda_graph(self) -> None: - # The `max_seq_len_to_capture` was incorrectly - # based on the encoder's input length (448) - # but not the decoder's larger input length (1500). - # This change ensures the CUDA Graph captures the correct, - # larger sequence length, allowing it to work as intended. 
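        # Illustrative example: for Whisper, max_source_positions is 1500 while
        # max_model_len resolves to 448, so the max() below raises the CUDA
        # graph capture length to 1500.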
- effective_max_seq_len = self.max_model_len - if self.is_encoder_decoder: - effective_max_seq_len = max( - effective_max_seq_len, - getattr(self.hf_config, "max_source_positions", 0)) - self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, - effective_max_seq_len) - # CUDAGraph capture not supported for enc-dec models and mllama on ROCm - ROCM_UNSUPPORTED_MODELS = ['mllama'] - unsupported_rocm = (self.hf_config.model_type - in ROCM_UNSUPPORTED_MODELS - or self.is_encoder_decoder) - - if (unsupported_rocm and not self.enforce_eager - and current_platform.is_rocm()): - logger.warning( - "CUDA graph is not supported for %s on ROCm yet, fallback " - "to eager mode.", self.hf_config.model_type) - self.enforce_eager = True - - def _verify_bnb_config(self) -> None: - """ - The current version of bitsandbytes (0.46.1) with 8-bit models does not - yet support CUDA graph. - # TODO Remove this when bitsandbytes supports. - """ - is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = (getattr(self.hf_config, - "quantization_config", None) - is not None) - is_8bit = (self.hf_config.quantization_config.get( - "load_in_8bit", False) if has_quantization_config else False) - if all([ - is_bitsandbytes, - has_quantization_config, - is_8bit, - not self.enforce_eager, - ]): - logger.warning( - "CUDA graph is not supported on BitsAndBytes 8bit yet, " - "fallback to the eager mode.") - - self.enforce_eager = True - - def _verify_with_expert_parallelism(self) -> None: - num_expert_names = [ - "moe_num_experts", # Dbrx - "num_experts", # Jamba - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = 0 - for name in num_expert_names: - num_experts = getattr(self.hf_text_config, name, 0) - if num_experts > 0: - break - if num_experts < 1: - raise ValueError( - "Number of experts in the model must be greater than 0 " - "when expert parallelism is enabled.") - - def verify_dual_chunk_attention_config( - self, - load_config: "LoadConfig", - ) -> None: - if hasattr(self.hf_config, "dual_chunk_attention_config"): - # Try loading the sparse attention config - from vllm.model_executor.model_loader.weight_utils import ( - get_sparse_attention_config) - sparse_attn_config = get_sparse_attention_config(self, load_config) - if sparse_attn_config: - self.hf_config.dual_chunk_attention_config[ - "sparse_attention_config"] = sparse_attn_config - if "sparse_attention_enabled" not in \ - self.hf_config.dual_chunk_attention_config: - self.hf_config.dual_chunk_attention_config[ - "sparse_attention_enabled"] = True - - def verify_async_output_proc(self, parallel_config, speculative_config, - device_config) -> None: - if not self.use_async_output_proc: - # Nothing to check - return - - if parallel_config.pipeline_parallel_size > 1: - self.use_async_output_proc = False - return - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - from vllm.platforms import current_platform - if not current_platform.is_async_output_supported(self.enforce_eager): - self.use_async_output_proc = False - return - - if envs.VLLM_USE_RAY_SPMD_WORKER: - self.use_async_output_proc = False - return - - # Async postprocessor is not necessary for pooling models - # since there is no token generation - if self.runner_type == "pooling": - self.use_async_output_proc = False - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - if speculative_config: - self.use_async_output_proc = False - - def 
verify_with_parallel_config( - self, - parallel_config: "ParallelConfig", - ) -> None: - - if parallel_config.distributed_executor_backend == "external_launcher": - assert self.seed is not None, ( - "Seed must be set when using external launcher backend to " - "make sure sampling results are the same across workers.") - - total_num_attention_heads = getattr(self.hf_text_config, - "num_attention_heads", 0) - tensor_parallel_size = parallel_config.tensor_parallel_size - if total_num_attention_heads % tensor_parallel_size != 0: - raise ValueError( - f"Total number of attention heads ({total_num_attention_heads})" - " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") - - if parallel_config.enable_expert_parallel: - self._verify_with_expert_parallelism() - - pipeline_parallel_size = parallel_config.pipeline_parallel_size - if pipeline_parallel_size > 1: - if not self.registry.is_pp_supported_model(self.architectures, - self): - raise NotImplementedError( - "Pipeline parallelism is not supported for this model. " - "Supported models implement the `SupportsPP` interface.") - - if self.use_async_output_proc: - self.use_async_output_proc = False - - def get_sliding_window(self) -> Optional[int]: - """Get the sliding window size from the HF text config if present.""" - return getattr(self.hf_text_config, "sliding_window", None) - - def get_vocab_size(self) -> int: - return getattr(self.hf_text_config, "vocab_size", 0) - - def get_hidden_size(self) -> int: - return getattr(self.hf_text_config, "hidden_size", 0) - - @property - def is_deepseek_mla(self) -> bool: - if not hasattr(self.hf_text_config, "model_type"): - return False - elif self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'): - return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == 'eagle': - # if the model is an EAGLE module, check for the - # underlying architecture - return self.hf_text_config.model.model_type in \ - ('deepseek_v2', 'deepseek_v3') \ - and self.hf_text_config.kv_lora_rank is not None - return False - - def get_head_size(self) -> int: - # TODO remove hard code - if self.is_deepseek_mla: - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", - 0) - if self.use_mla: - return self.hf_text_config.kv_lora_rank + qk_rope_head_dim - else: - qk_nope_head_dim = getattr(self.hf_text_config, - "qk_nope_head_dim", 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim - - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - == "zamba2"): - return self.hf_text_config.attention_head_dim - - if self.is_attention_free: - return 0 - - # NOTE: Some configs may set head_dim=None in the config - if getattr(self.hf_text_config, "head_dim", None) is not None: - return self.hf_text_config.head_dim - - # FIXME(woosuk): This may not be true for all models. - return (self.hf_text_config.hidden_size // - self.hf_text_config.num_attention_heads) - - def get_total_num_kv_heads(self) -> int: - """Returns the total number of KV heads.""" - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. 
- falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_text_config, - "multi_query", False): - # Multi-query attention, only one KV head. - # Currently, tensor parallelism is not supported in this case. - return 1 - - # For DBRX and MPT - if self.hf_config.model_type == "mpt": - if "kv_n_heads" in self.hf_config.attn_config: - return self.hf_config.attn_config["kv_n_heads"] - return self.hf_config.num_attention_heads - if self.hf_config.model_type == "dbrx": - return getattr(self.hf_config.attn_config, "kv_n_heads", - self.hf_config.num_attention_heads) - - if self.hf_config.model_type == "nemotron-nas": - for block in self.hf_config.block_configs: - if not block.attention.no_op: - return self.hf_config.num_attention_heads \ - // block.attention.n_heads_in_group - - raise RuntimeError("Couldn't determine number of kv heads") - - if self.is_attention_free: - return 0 - - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. - return self.hf_text_config.num_attention_heads - - def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: - """Returns the number of KV heads per GPU.""" - if self.use_mla: - # When using MLA during decode it becomes MQA - return 1 - - total_num_kv_heads = self.get_total_num_kv_heads() - # If tensor parallelism is used, we divide the number of KV heads by - # the tensor parallel size. We will replicate the KV heads in the - # case where the number of KV heads is smaller than the tensor - # parallel size so each GPU has at least one KV head. 
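        # Illustrative arithmetic: with 8 total KV heads and TP=2 this gives
        # 8 // 2 = 4 heads per rank; with 2 KV heads and TP=8 the floor
        # division gives 0, so max(1, ...) keeps one replicated KV head on
        # every rank.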
- return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) - - def get_num_attention_heads(self, - parallel_config: "ParallelConfig") -> int: - num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) - return num_heads // parallel_config.tensor_parallel_size - - def get_layers_start_end_indices( - self, parallel_config: "ParallelConfig") -> tuple[int, int]: - from vllm.distributed.utils import get_pp_indices - if (self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp"): - total_num_hidden_layers = getattr(self.hf_text_config, - "num_nextn_predict_layers", 0) - else: - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) - # the layout order is: DP x PP x TP - pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size - ) % parallel_config.pipeline_parallel_size - pp_size = parallel_config.pipeline_parallel_size - start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) - return start, end - - def get_num_layers(self, parallel_config: "ParallelConfig") -> int: - start, end = self.get_layers_start_end_indices(parallel_config) - return end - start - - def get_num_layers_by_block_type( - self, - parallel_config: "ParallelConfig", - block_type: LayerBlockType = LayerBlockType.attention, - ) -> int: - # This function relies on 'layers_block_type' in hf_config, - # for w/o this attribute, we will need to have workarounds like so - attn_block_type = block_type == LayerBlockType.attention - is_transformer = not self.is_hybrid and \ - not self.has_noops and \ - not self.is_attention_free - start, end = self.get_layers_start_end_indices(parallel_config) - - if is_transformer: - # Handle the basic case first - return end - start if attn_block_type else 0 - elif self.is_attention_free: - # Attention free - # Note that this code assumes there - # is only one type of attention-free block type. - return 0 if attn_block_type else end - start - elif self.has_noops: - block_configs = self.hf_config.block_configs - return sum(not bc.attention.no_op - for bc in block_configs[start:end]) - else: - # Hybrid model Jamba - layers_block_type_value = getattr(self.hf_config, - "layers_block_type", None) - if layers_block_type_value is not None: - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - == "zamba2"): - if attn_block_type: - return sum(t == "hybrid" - for t in layers_block_type_value[start:end]) - else: - return self.get_num_layers(parallel_config) - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) - - # Hybrid model Minimax - attn_type_list = getattr(self.hf_config, "attn_type_list", None) - if attn_type_list: - return sum(t == 1 for t in attn_type_list[start:end]) - - if layers_block_type_value is None and attn_type_list is None: - raise ValueError( - "The model is an hybrid without a" - "layers_block_type or an attn_type_list in the hf_config," - "cannot determine the num of " - f"{block_type.value} layers") - - return sum(t == 1 for t in attn_type_list[start:end]) - - def get_mamba_chunk_size(self) -> Optional[int]: - """ - Returns the mamba chunk size if it exists - """ - # used by e.g. Bamba, FalconH1, Granite, PLaMo2 - chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None) - if chunk_size is None: - # used by e.g. 
Mamba2, NemotronH, Zamba - chunk_size = getattr(self.hf_text_config, "chunk_size", None) - return chunk_size - - def get_multimodal_config(self) -> "MultiModalConfig": - """ - Get the multimodal configuration of the model. - - Raises: - ValueError: If the model is not multimodal. - """ - if self.multimodal_config is None: - raise ValueError("The model is not multimodal.") - - return self.multimodal_config - - def try_get_generation_config(self) -> dict[str, Any]: - """ - This method attempts to retrieve the non-default values of the - generation config for this model. - - The generation config can contain information about special tokens, as - well as sampling parameters. Which is why this method exists separately - to `get_diff_sampling_param`. - - Returns: - A dictionary containing the non-default generation config. - """ - if self.generation_config in {"auto", "vllm"}: - config = try_get_generation_config( - self.hf_config_path or self.model, - trust_remote_code=self.trust_remote_code, - revision=self.revision, - ) - else: - config = try_get_generation_config( - self.generation_config, - trust_remote_code=self.trust_remote_code, - ) - - if config is None: - return {} - - return config.to_diff_dict() - - def get_diff_sampling_param(self) -> dict[str, Any]: - """ - This method returns a dictionary containing the non-default sampling - parameters with `override_generation_config` applied. - - The default sampling parameters are: - - - vLLM's neutral defaults if `self.generation_config="vllm"` - - the model's defaults if `self.generation_config="auto"` - - as defined in `generation_config.json` if - `self.generation_config="path/to/generation_config/dir"` - - Returns: - A dictionary containing the non-default sampling parameters. - """ - if self.generation_config == "vllm": - config = {} - else: - config = self.try_get_generation_config() - - # Overriding with given generation config - config.update(self.override_generation_config) - - available_params = [ - "repetition_penalty", - "temperature", - "top_k", - "top_p", - "min_p", - "max_new_tokens", - ] - if any(p in config for p in available_params): - diff_sampling_param = { - p: config.get(p) - for p in available_params if config.get(p) is not None - } - # Huggingface definition of max_new_tokens is equivalent - # to vLLM's max_tokens - if "max_new_tokens" in diff_sampling_param: - diff_sampling_param["max_tokens"] = diff_sampling_param.pop( - "max_new_tokens") - else: - diff_sampling_param = {} - - if diff_sampling_param: - logger.warning_once( - "Default sampling parameters have been overridden by the " - "model's Hugging Face generation config recommended from the " - "model creator. 
If this is not intended, please relaunch " - "vLLM instance with `--generation-config vllm`.") - return diff_sampling_param - - @property - def is_encoder_decoder(self) -> bool: - """Extract the HF encoder/decoder model flag.""" - """ - For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to - True to enable cross-attention - Neuron needs all multimodal data to be in the decoder and does not - need to explicitly enable cross-attention - """ - if (current_platform.is_neuron() - and self.hf_config.model_type == "mllama"): - return False - - return is_encoder_decoder(self.hf_config) - - @property - def uses_mrope(self) -> bool: - return uses_mrope(self.hf_config) - - @property - def is_multimodal_model(self) -> bool: - return self.multimodal_config is not None - - @property - def enable_mm_processor_cache(self) -> bool: - """Whether the multi-modal processor cache should be enabled.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - def get_mm_input_cache_gb(self) -> int: - mm_config = self.multimodal_config - if mm_config is None: - return 0 - - return envs.VLLM_MM_INPUT_CACHE_GIB - - @property - def is_cross_encoder(self) -> bool: - return (self._model_info.supports_cross_encoding - or self.convert_type == "classify") - - @property - def is_pp_supported(self) -> bool: - return self._model_info.supports_pp - - @property - def is_multimodal_raw_input_supported(self) -> bool: - return self._model_info.supports_multimodal_raw_input - - @property - def is_attention_free(self) -> bool: - return self._model_info.is_attention_free - - @property - def is_hybrid(self) -> bool: - return self._model_info.is_hybrid - - @property - def has_noops(self) -> bool: - return self._model_info.has_noops - - @property - def has_inner_state(self): - return self._model_info.has_inner_state - - @property - def is_v1_compatible(self) -> bool: - return not self._model_info.supports_v0_only - - @property - def use_mla(self) -> bool: - return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE - - @property - def is_matryoshka(self) -> bool: - return (bool(getattr(self.hf_config, "matryoshka_dimensions", None)) - or getattr(self.hf_config, "is_matryoshka", False)) - - @property - def matryoshka_dimensions(self): - return getattr(self.hf_config, "matryoshka_dimensions", None) - - @property - def use_pad_token(self) -> bool: - # cross_encoder models defaults to using pad_token. - # `llm as reranker` models defaults to not using pad_token. - return getattr(self.hf_config, "use_pad_token", True) - - def get_and_verify_max_len(self, max_model_len: int): - # Consider max_model_len in tokenizer_config only when - # pooling models use absolute position_embedding. 
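The `get_diff_sampling_param` filtering removed above amounts to the following sketch; the helper name and the standalone parameter tuple are illustrative only:

```python
AVAILABLE_PARAMS = (
    "repetition_penalty", "temperature", "top_k", "top_p", "min_p",
    "max_new_tokens",
)


def diff_sampling_params(generation_config: dict) -> dict:
    # Keep only the parameters the generation config actually sets.
    diff = {
        p: generation_config[p]
        for p in AVAILABLE_PARAMS
        if generation_config.get(p) is not None
    }
    # Hugging Face's max_new_tokens maps onto vLLM's max_tokens.
    if "max_new_tokens" in diff:
        diff["max_tokens"] = diff.pop("max_new_tokens")
    return diff


assert diff_sampling_params({"temperature": 0.6, "max_new_tokens": 128}) == {
    "temperature": 0.6,
    "max_tokens": 128,
}
```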
- tokenizer_config = None - if (self.runner_type == "pooling" and getattr( - self.hf_config, "position_embedding_type", "") == "absolute"): - tokenizer_config = try_get_tokenizer_config( - self.tokenizer, - trust_remote_code=self.trust_remote_code, - revision=self.tokenizer_revision) - max_model_len = _get_and_verify_max_len( - hf_config=self.hf_text_config, - tokenizer_config=tokenizer_config, - max_model_len=max_model_len, - disable_sliding_window=self.disable_sliding_window, - sliding_window=self.get_sliding_window(), - spec_target_max_model_len=self.spec_target_max_model_len, - encoder_config=self.encoder_config) - logger.info("Using max model len %s", max_model_len) - return max_model_len - - -@config -@dataclass -class LoadConfig: - """Configuration for loading the model weights.""" - - load_format: Union[str, LoadFormats] = "auto" - """The format of the model weights to load:\n - - "auto" will try to load the weights in the safetensors format and fall - back to the pytorch bin format if safetensors format is not available.\n - - "pt" will load the weights in the pytorch bin format.\n - - "safetensors" will load the weights in the safetensors format.\n - - "npcache" will load the weights in pytorch format and store a numpy cache - to speed up the loading.\n - - "dummy" will initialize the weights with random values, which is mainly - for profiling.\n - - "tensorizer" will use CoreWeave's tensorizer library for fast weight - loading. See the Tensorize vLLM Model script in the Examples section for - more information.\n - - "runai_streamer" will load the Safetensors weights using Run:ai Model - Streamer.\n - - "bitsandbytes" will load the weights using bitsandbytes quantization.\n - - "sharded_state" will load weights from pre-sharded checkpoint files, - supporting efficient loading of tensor-parallel models.\n - - "gguf" will load weights from GGUF format files (details specified in - https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n - - "mistral" will load weights from consolidated safetensors files used by - Mistral models. - - Other custom values can be supported via plugins.""" - download_dir: Optional[str] = None - """Directory to download and load the weights, default to the default - cache directory of Hugging Face.""" - model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict) - """Extra config for model loader. This will be passed to the model loader - corresponding to the chosen load_format.""" - device: Optional[str] = None - """Device to which model weights will be loaded, default to - device_config.device""" - ignore_patterns: Optional[Union[list[str], str]] = None - """The list of patterns to ignore when loading the model. Default to - "original/**/*" to avoid repeated loading of llama's checkpoints.""" - use_tqdm_on_load: bool = True - """Whether to enable tqdm for showing progress bar when loading model - weights.""" - pt_load_map_location: Union[str, dict[str, str]] = "cpu" - """ - pt_load_map_location: the map location for loading pytorch checkpoint, to - support loading checkpoints can only be loaded on certain devices like - "cuda", this is equivalent to {"": "cuda"}. Another supported format is - mapping from different devices like from GPU 1 to GPU 0: - {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings - in dictionary needs to be double quoted for json parsing. 
For more details, - see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - self.load_format = self.load_format.lower() - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info( - "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - -Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class DeviceConfig: - """Configuration for the device to use for vLLM execution.""" - - device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" - """Device type for vLLM execution. - This parameter is deprecated and will be - removed in a future release. - It will now be set automatically based - on the current platform.""" - device_type: str = field(init=False) - """Device type from the current platform. This is set in - `__post_init__`.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # the device/platform information will be summarized - # by torch/vllm automatically. 
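As a usage note for the `pt_load_map_location` field documented above: both accepted forms are passed straight through to `torch.load`'s `map_location` argument. The checkpoint path below is hypothetical:

```python
import torch

torch.save({"w": torch.zeros(2)}, "/tmp/demo_ckpt.pt")  # hypothetical checkpoint

# Equivalent to the default pt_load_map_location="cpu".
ckpt = torch.load("/tmp/demo_ckpt.pt", map_location="cpu")

# Equivalent to a device mapping such as {"cuda:1": "cuda:0"} (double-quoted
# on the command line so it parses as JSON): tensors saved on cuda:1 are
# remapped to cuda:0 at load time.
# ckpt = torch.load("/tmp/demo_ckpt.pt", map_location={"cuda:1": "cuda:0"})
```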
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if self.device == "auto": - # Automated device type detection - from vllm.platforms import current_platform - self.device_type = current_platform.device_type - if not self.device_type: - raise RuntimeError( - "Failed to infer device type, please set " - "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " - "to turn on verbose logging to help debug the issue.") - else: - # Device type is assigned explicitly - if isinstance(self.device, str): - self.device_type = self.device - elif isinstance(self.device, torch.device): - self.device_type = self.device.type - - # Some device types require processing inputs on CPU - if self.device_type in ["neuron"]: - self.device = torch.device("cpu") - elif self.device_type in ["tpu"]: - self.device = None - else: - # Set device with device type - self.device = torch.device(self.device_type) - - -SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp"] - - -@config -@dataclass -class SpeculativeConfig: - """Configuration for speculative decoding.""" - - # General speculative decoding control - num_speculative_tokens: SkipValidation[int] = None # type: ignore - """The number of speculative tokens, if provided. It will default to the - number in the draft model config if present, otherwise, it is required.""" - model: Optional[str] = None - """The name of the draft model, eagle head, or additional weights, if - provided.""" - method: Optional[SpeculativeMethod] = None - """The name of the speculative method to use. If users provide and set the - `model` param, the speculative method type will be detected automatically - if possible, if `model` param is not provided, the method name must be - provided. - - If using `ngram` method, the related configuration `prompt_lookup_max` and - `prompt_lookup_min` should be considered.""" - draft_tensor_parallel_size: Optional[int] = None - """The degree of the tensor parallelism for the draft model. Can only be 1 - or the same as the target model's tensor parallel size.""" - disable_logprobs: bool = True - """If set to True, token log probabilities are not returned during - speculative decoding. If set to False, token log probabilities are returned - according to the log probability settings in SamplingParams.""" - - # Draft model configuration - quantization: Optional[me_quant.QuantizationMethods] = None - """Quantization method that was used to quantize the draft model weights. - If `None`, we assume the model weights are not quantized. Note that it only - takes effect when using the draft model-based speculative method.""" - max_model_len: Optional[int] = None - """The maximum model length of the draft model. Used when testing the - ability to skip speculation for some sequences.""" - revision: Optional[str] = None - """The specific model version to use for the draft model. It can be a - branch name, a tag name, or a commit id. If unspecified, will use the - default version.""" - code_revision: Optional[str] = None - """The specific revision to use for the draft model code on Hugging Face - Hub. It can be a branch name, a tag name, or a commit id. 
If unspecified, - will use the default version.""" - - # Advanced control - disable_by_batch_size: Optional[int] = None - """Disable speculative decoding for new incoming requests when the number - of enqueued requests is larger than this value, if provided.""" - - # Ngram proposer configuration - prompt_lookup_max: Optional[int] = None - """Maximum size of ngram token window when using Ngram proposer, required - when method is set to ngram.""" - prompt_lookup_min: Optional[int] = None - """Minimum size of ngram token window when using Ngram proposer, if - provided. Defaults to 1.""" - - speculative_token_tree: Optional[str] = None - """Specifies the tree structure for speculative token generation. - """ - # required configuration params passed from engine - target_model_config: SkipValidation[ModelConfig] = None # type: ignore - """The configuration of the target model.""" - target_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore - """The parallel configuration for the target model.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore - """Whether vLLM is configured to use chunked prefill or not. Used for - raising an error since it's not yet compatible with speculative decode.""" - disable_log_stats: SkipValidation[bool] = None # type: ignore - """Whether to disable the periodic printing of stage times in speculative - decoding.""" - - # params generated in the post-init stage - draft_model_config: SkipValidation[ModelConfig] = None # type: ignore - """The configuration of the draft model initialized internal.""" - draft_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore - """The parallel configuration for the draft model initialized internal.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - # Eagle3 affects the computation graph because it returns intermediate - # hidden states in addition to the final hidden state. 
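Every `compute_hash` in this module follows the same factor-list pattern; a minimal standalone sketch, with `config_hash` as an illustrative name:

```python
import hashlib
from typing import Any


def config_hash(factors: list[Any]) -> str:
    # Hash the string form of the factor list, as the configs above do.
    return hashlib.md5(str(factors).encode(),
                       usedforsecurity=False).hexdigest()


# For SpeculativeConfig only the "is this EAGLE3?" bit matters, because
# EAGLE3 additionally returns intermediate hidden states.
assert config_hash([True]) != config_hash([False])
assert config_hash([]) == config_hash([])
```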
- factors.append(self.method == "eagle3") - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - @staticmethod - def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: - if hf_config.model_type == "deepseek_v3": - hf_config.model_type = "deepseek_mtp" - if hf_config.model_type == "deepseek_mtp": - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["DeepSeekMTPModel"] - }) - - if hf_config.architectures[0] == "MiMoForCausalLM": - hf_config.model_type = "mimo_mtp" - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["MiMoMTPModel"] - }) - - if hf_config.architectures[0] == "Glm4MoeForCausalLM": - hf_config.model_type = "glm4_moe_mtp" - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["Glm4MoeMTPModel"] - }) - - if hf_config.model_type == "ernie4_5_moe": - hf_config.model_type = "ernie_mtp" - if hf_config.model_type == "ernie_mtp": - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["ErnieMTPModel"] - }) - return hf_config - - return hf_config - - def __post_init__(self): - - # Note: "method" is a new parameter that helps to extend the - # configuration of non-model-based proposers, and the "model" parameter - # will be used to set the draft model, eagle head, or additional weight - # when needed. If users do not specify "method", the speculative method - # will be detected automatically if possible. If the speculative method - # can not be detected, it will be considered as the "draft_model" by - # default. - - if self.model is None and self.num_speculative_tokens is not None: - # TODO(Shangming): Refactor mtp configuration logic when supporting - # mtp acceleration for more models besides deepseek_v3 - if self.target_model_config and \ - (self.target_model_config.hf_text_config.model_type \ - == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type in - ("mimo","ernie4_5_moe")): - # use the draft model from the same model: - self.model = self.target_model_config.model - elif self.method in ("ngram", "[ngram]"): - self.model = "ngram" - else: - raise ValueError("num_speculative_tokens was provided without " - "speculative model.") - - # Automatically configure the method for ngram when "model" is used - # instead of "method" - if self.method is None and (self.model is not None - and self.model in ("ngram", "[ngram]")): - self.method = "ngram" - - if self.method in ("ngram", "[ngram]"): - # Unified to "ngram" internally - self.method = "ngram" - # Set default values if not provided - if (self.prompt_lookup_min is None - and self.prompt_lookup_max is None): - # TODO(woosuk): Tune these values. They are arbitrarily chosen. 
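The MTP remapping in `hf_config_override` above can be sketched with a plain dict standing in for `transformers.PretrainedConfig`; the `mtp_override` helper is illustrative and covers only the DeepSeek branch:

```python
def mtp_override(hf_config: dict) -> dict:
    # DeepSeek-V3 checkpoints are re-labelled as the MTP draft model type.
    if hf_config.get("model_type") == "deepseek_v3":
        hf_config["model_type"] = "deepseek_mtp"
    if hf_config.get("model_type") == "deepseek_mtp":
        # The number of next-token-prediction layers becomes n_predict.
        hf_config["n_predict"] = hf_config.get("num_nextn_predict_layers")
        hf_config["architectures"] = ["DeepSeekMTPModel"]
    return hf_config


cfg = {"model_type": "deepseek_v3", "num_nextn_predict_layers": 1}
assert mtp_override(cfg)["architectures"] == ["DeepSeekMTPModel"]
assert mtp_override(cfg)["n_predict"] == 1
```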
- self.prompt_lookup_min = 5 - self.prompt_lookup_max = 5 - elif self.prompt_lookup_min is None: - assert self.prompt_lookup_max is not None - self.prompt_lookup_min = self.prompt_lookup_max - elif self.prompt_lookup_max is None: - assert self.prompt_lookup_min is not None - self.prompt_lookup_max = self.prompt_lookup_min - - # Validate values - if self.prompt_lookup_min < 1: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") - if self.prompt_lookup_max < 1: - raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") - if self.prompt_lookup_min > self.prompt_lookup_max: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must " - f"be <= prompt_lookup_max={self.prompt_lookup_max}") - - # TODO: current we still need extract vocab_size from target model - # config, in future, we may try refactor it out, and set - # draft related config as None here. - self.draft_model_config = self.target_model_config - self.draft_parallel_config = self.target_parallel_config - else: - self.prompt_lookup_max = 0 - self.prompt_lookup_min = 0 - - if self.model is not None: - self.draft_model_config = ModelConfig( - model=self.model, - runner="draft", - tokenizer=self.target_model_config.tokenizer, - tokenizer_mode=self.target_model_config.tokenizer_mode, - trust_remote_code=self.target_model_config. - trust_remote_code, - allowed_local_media_path=self.target_model_config. - allowed_local_media_path, - dtype=self.target_model_config.dtype, - seed=self.target_model_config.seed, - revision=self.revision, - code_revision=self.code_revision, - tokenizer_revision=self.target_model_config. - tokenizer_revision, - spec_target_max_model_len=self.target_model_config. - max_model_len, - quantization=self.quantization, - enforce_eager=self.target_model_config.enforce_eager, - max_seq_len_to_capture=self.target_model_config. - max_seq_len_to_capture, - max_logprobs=self.target_model_config.max_logprobs, - hf_overrides=SpeculativeConfig.hf_config_override, - ) - - # Automatically detect the method - if self.method in ('eagle', 'eagle3'): - pass - elif "eagle-" in self.draft_model_config.model.lower() or \ - "eagle3-" in self.draft_model_config.model.lower(): - self.method = "eagle" - elif self.draft_model_config.hf_config.model_type == "medusa": - self.method = "medusa" - elif (self.draft_model_config.hf_config.model_type == - "mlp_speculator"): - self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type - in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): - self.method = "deepseek_mtp" - if self.num_speculative_tokens > 1: - logger.warning( - "All Deepseek MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) - elif (self.draft_model_config.hf_config.model_type == - "ernie_mtp"): - self.method = "ernie_mtp" - if self.num_speculative_tokens > 1: - logger.warning( - "All Ernie MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) - else: - self.method = "draft_model" - raise NotImplementedError( - "Speculative decoding with draft model is not " - "supported yet. 
Please consider using other " - "speculative decoding methods such as ngram, medusa, " - "eagle, or deepseek_mtp.") - - # Replace hf_config for EAGLE draft_model - if self.method in ("eagle", "eagle3"): - if self.enable_chunked_prefill and not envs.VLLM_USE_V1: - raise ValueError( - "Chunked prefill and EAGLE are not compatible " - "when using V0.") - - from vllm.transformers_utils.configs import ( - SpeculatorsConfig) - from vllm.transformers_utils.configs.eagle import ( - EAGLEConfig) - - if isinstance(self.draft_model_config.hf_config, - (EAGLEConfig, SpeculatorsConfig)): - pass - else: - eagle_config = EAGLEConfig( - self.draft_model_config.hf_config, - method=self.method, - model_type="eagle") - self.draft_model_config.hf_config = eagle_config - - if (self.num_speculative_tokens is not None - and hasattr(self.draft_model_config.hf_config, - "num_lookahead_tokens")): - self.draft_model_config.hf_config.num_lookahead_tokens = \ - self.num_speculative_tokens - - n_predict = getattr(self.draft_model_config.hf_config, - "n_predict", None) - if n_predict is not None: - if self.num_speculative_tokens is None: - # Default to max value defined in draft model config. - self.num_speculative_tokens = n_predict - elif self.num_speculative_tokens > n_predict and \ - self.num_speculative_tokens % n_predict != 0: - # Ensure divisibility for MTP module reuse. - raise ValueError( - f"num_speculative_tokens:{self.num_speculative_tokens}" - f" must be divisible by {n_predict=}") - - if self.speculative_token_tree is None: - # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) - else: - # Sort the token tree breadth-first. - tree_choices = ast.literal_eval( - self.speculative_token_tree) - self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) - - self.draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_tp( - self.target_parallel_config, - self.draft_tensor_parallel_size, - self.draft_model_config.hf_config - ) - - self.draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - self.max_model_len, - self.draft_model_config.max_model_len, - self.target_model_config.max_model_len, - )) - - self.draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, - self.draft_tensor_parallel_size)) - - @staticmethod - def _maybe_override_draft_max_model_len( - speculative_max_model_len: Optional[int], - draft_max_model_len: int, - target_max_model_len: int, - ) -> int: - """Determine the max sequence len for the draft model. This is usually - the draft_max_model_len, but may be the target_max_model_len if it is - less than the draft_max_model_len, or may be speculative_max_model_len - if it is specified. - - This is necessary so that sequences do not exceed the capacity of the - draft model or the target model. - - speculative_max_model_len is mainly used for testing that sequences can - skip speculation. 
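The resolution rule in the docstring above, as a standalone sketch (`resolve_draft_max_model_len` is an illustrative name): an explicit `speculative_max_model_len` wins after bounds checks, otherwise the draft is capped by both its own and the target's maximum length.

```python
from typing import Optional


def resolve_draft_max_model_len(speculative_max_model_len: Optional[int],
                                draft_max_model_len: int,
                                target_max_model_len: int) -> int:
    if speculative_max_model_len is not None:
        if speculative_max_model_len > draft_max_model_len:
            raise ValueError("speculative_max_model_len exceeds draft limit")
        if speculative_max_model_len > target_max_model_len:
            raise ValueError("speculative_max_model_len exceeds target limit")
        return speculative_max_model_len
    # Default: never exceed either model's capacity.
    return min(draft_max_model_len, target_max_model_len)


assert resolve_draft_max_model_len(None, 4096, 8192) == 4096
assert resolve_draft_max_model_len(1024, 4096, 8192) == 1024
```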
- """ - - if speculative_max_model_len is not None: - - if speculative_max_model_len > draft_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {draft_max_model_len=}") - - if speculative_max_model_len > target_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {target_max_model_len=}") - - return speculative_max_model_len - - return min( - draft_max_model_len, - target_max_model_len, - ) - - @staticmethod - def _verify_and_get_draft_tp( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: - """ - Verifies and adjusts the tensor parallel size for a draft model - specified using speculative_draft_tensor_parallel_size. - """ - # If speculative_draft_tensor_parallel_size is unset then set it - # appropriately else verify that it is set correctly. - if speculative_draft_tensor_parallel_size is None: - if draft_hf_config.model_type == "mlp_speculator": - speculative_draft_tensor_parallel_size = 1 - if target_parallel_config.tensor_parallel_size > 1: - logger.warning( - "%s cannot currently be run with tp>1; " - "setting speculative_draft_tensor_parallel_size=1", - draft_hf_config.model_type) - else: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size - elif speculative_draft_tensor_parallel_size not in ( - 1, target_parallel_config.tensor_parallel_size): - raise ValueError( - f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1 or target model tensor_parallel_size") - return speculative_draft_tensor_parallel_size - - @staticmethod - def create_draft_parallel_config( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: int, - ) -> ParallelConfig: - """Create a parallel config for use by the draft worker. - - This is mostly a copy of the target parallel config, except the tp_size. - """ - draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config. - pipeline_parallel_size, - tensor_parallel_size=speculative_draft_tensor_parallel_size, - distributed_executor_backend=target_parallel_config. - distributed_executor_backend, - max_parallel_loading_workers=target_parallel_config. - max_parallel_loading_workers, - disable_custom_all_reduce=target_parallel_config. - disable_custom_all_reduce, - ray_workers_use_nsight=target_parallel_config. 
- ray_workers_use_nsight, - placement_group=target_parallel_config.placement_group, - ) - - return draft_parallel_config - - @model_validator(mode='after') - def _verify_args(self) -> Self: - if self.num_speculative_tokens is None: - raise ValueError( - "num_speculative_tokens must be provided with " - "speculative model unless the draft model config contains an " - "n_predict parameter.") - - if self.num_speculative_tokens <= 0: - raise ValueError("Expected num_speculative_tokens to be greater " - f"than zero ({self.num_speculative_tokens}).") - - if self.draft_model_config: - self.draft_model_config.verify_with_parallel_config( - self.draft_parallel_config) - - if (self.disable_by_batch_size is not None - and self.disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{self.disable_by_batch_size=}") - - eagle3_target_supported = ["llama", "qwen"] - if self.method == "eagle3" and self.target_model_config and not any( - supported_model in - self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported): - raise ValueError( - f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}") - - return self - - @property - def num_lookahead_slots(self) -> int: - """The number of additional slots the scheduler should allocate per - step, in addition to the slots allocated for each known token. - - This is equal to the number of speculative tokens, as each speculative - token must be scored. - """ - return self.num_speculative_tokens - - def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") - - def __repr__(self) -> str: - method = self.method - model = None if method == "ngram" else self.draft_model_config.model - num_spec_tokens = self.num_speculative_tokens - return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" - - -LoRADType = Literal["auto", "float16", "bfloat16"] - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class LoRAConfig: - """Configuration for LoRA.""" - - max_lora_rank: int = 16 - """Max LoRA rank.""" - max_loras: int = 1 - """Max number of LoRAs in a single batch.""" - fully_sharded_loras: bool = False - """By default, only half of the LoRA computation is sharded with tensor - parallelism. Enabling this will use the fully sharded layers. At high - sequence length, max rank or tensor parallel size, this is likely faster. - """ - max_cpu_loras: Optional[int] = None - """Maximum number of LoRAs to store in CPU memory. Must be >= than - `max_loras`.""" - lora_dtype: Union[torch.dtype, LoRADType] = "auto" - """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: int = 256 - """Maximum size of extra vocabulary that can be present in a LoRA adapter - (added to the base model vocabulary).""" - lora_vocab_padding_size: ClassVar[int] = current_platform\ - .get_lora_vocab_padding_size() - - default_mm_loras: Optional[dict[str, str]] = None - """Dictionary mapping specific modalities to LoRA model paths; this field - is only applicable to multimodal models and should be leveraged when a - model always expects a LoRA to be active when a given modality is present. 
- Note that currently, if a request provides multiple additional - modalities, each of which have their own LoRA, we do NOT apply - default_mm_loras because we currently only support one lora adapter - per prompt. When run in offline mode, the lora IDs for n modalities - will be automatically assigned to 1-n with the names of the modalities - in alphabetic order.""" - bias_enabled: bool = False - """Enable bias for LoRA adapters.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.max_lora_rank) - factors.append(self.max_loras) - factors.append(self.fully_sharded_loras) - factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) - factors.append(self.bias_enabled) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - # Setting the maximum rank to 512 should be able to satisfy the vast - # majority of applications. - possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) - possible_lora_extra_vocab_size = (256, 512) - if self.max_lora_rank not in possible_max_ranks: - raise ValueError( - f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") - if self.max_loras < 1: - raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") - if self.max_cpu_loras is None: - self.max_cpu_loras = self.max_loras - elif self.max_cpu_loras < self.max_loras: - raise ValueError( - f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") - - def verify_with_cache_config(self, cache_config: CacheConfig): - if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: - raise ValueError( - "V0 LoRA does not support CPU offload, please use V1.") - - def verify_with_model_config(self, model_config: ModelConfig): - if self.lora_dtype in (None, "auto"): - self.lora_dtype = model_config.dtype - elif isinstance(self.lora_dtype, str): - self.lora_dtype = getattr(torch, self.lora_dtype) - - -@config -@dataclass -class MultiModalConfig: - """Controls the behavior of multimodal models.""" - - limit_per_prompt: dict[str, int] = \ - cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) - """ - The maximum number of input items allowed per prompt for each modality. - Defaults to 1 (V0) or 999 (V1) for each modality. - - For example, to allow up to 16 images and 2 videos per prompt: - `{"image": 16, "video": 2}` - """ - - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. - For example, to set num_frames for video, set - `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ - - mm_processor_kwargs: Optional[dict[str, object]] = None - """ - Overrides for the multi-modal processor obtained from - `transformers.AutoProcessor.from_pretrained`. 
- - The available overrides depend on the model that is being run. - - For example, for Phi-3-Vision: - `{"num_crops": 4}`. - """ - - mm_processor_cache_gb: int = 4 - """ - The size (in GiB) of the multi-modal processor cache, which is used to - - This cache is duplicated for each API process and engine core process, - resulting in a total memory usage of - `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. - - Set to `0` to disable this cache completely (not recommended). - """ - - mm_encoder_tp_mode: MMEncoderTPMode = "weights" - """ - Indicates how to optimize multi-modal encoder inference using - tensor parallelism (TP). - - - `"weights"`: Within the same vLLM engine, split the weights of - each layer across TP ranks. (default TP behavior) - - `"data"`: Within the same vLLM engine, split the batched input data - across TP ranks to process the data in parallel, while hosting - the full weights on each TP rank. - This batch-level DP is not to be confused with API request-level - DP (which is controlled by `--data-parallel-size`). - This is only supported on a per-model basis and falls back to - `"weights"` if the encoder does not support DP. - """ - - interleave_mm_strings: bool = False - """ - Enable fully interleaved support for multimodal prompts. - """ - - skip_mm_profiling: bool = False - """ - When enabled, skips multimodal memory profiling and only profiles with - language backbone model during engine initialization. - - This reduces engine startup time but shifts the responsibility to users for - estimating the peak memory usage of the activation of multimodal encoder and - embedding cache. - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def get_limit_per_prompt(self, modality: str) -> int: - """ - Get the maximum number of input items allowed per prompt - for the given modality. - """ - return self.limit_per_prompt.get( - modality, - 999 if envs.VLLM_USE_V1 else 1, - ) - - def merge_mm_processor_kwargs( - self, - inference_kwargs: Mapping[str, object], - ) -> dict[str, object]: - """ - Get the keyword arguments to pass to the multi-modal processor - according to the extra arguments passed during inference. - """ - kwargs = self.mm_processor_kwargs or {} - return kwargs | dict(inference_kwargs) - - -@config -@dataclass -class PoolerConfig: - """Controls the behavior of output pooling in pooling models.""" - - pooling_type: Optional[str] = None - """ - The pooling method of the pooling model. This should be a key in - [`vllm.model_executor.layers.pooler.PoolingType`][]. - """ - - ## for embeddings models - normalize: Optional[bool] = None - """ - Whether to normalize the embeddings outputs. - """ - dimensions: Optional[int] = None - """ - Reduce the dimensions of embeddings if model - support matryoshka representation. 
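The `merge_mm_processor_kwargs` helper removed above is a dict union in which per-request kwargs win over config-level defaults; a minimal sketch under that assumption:

```python
from typing import Optional


def merge_mm_processor_kwargs(config_kwargs: Optional[dict],
                              inference_kwargs: dict) -> dict:
    # Dict union: the right-hand (per-request) side takes precedence.
    return (config_kwargs or {}) | dict(inference_kwargs)


assert merge_mm_processor_kwargs({"num_crops": 4}, {"num_crops": 8}) == {
    "num_crops": 8,
}
assert merge_mm_processor_kwargs(None, {"num_frames": 40}) == {
    "num_frames": 40,
}
```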
- """ - - ## for classification models - activation: Optional[bool] = None - """ - Whether to apply activation function to the classification outputs. - """ - - ## for reward models - softmax: Optional[bool] = None - """ - Whether to apply softmax to the reward outputs. - """ - step_tag_id: Optional[int] = None - """ - If set, only the score corresponding to the ``step_tag_id`` in the - generated sentence should be returned. Otherwise, the scores for all tokens - are returned. - """ - returned_token_ids: Optional[list[int]] = None - """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the - ``math-shepherd-mistral-7b-prm`` model. - """ - - enable_chunked_processing: Optional[bool] = None - """ - Whether to enable chunked processing for long inputs that exceed the model's - maximum position embeddings. When enabled, long inputs will be split into - chunks, processed separately, and then aggregated using weighted averaging. - This allows embedding models to handle arbitrarily long text without CUDA - errors. Defaults to False. - """ - - max_embed_len: Optional[int] = None - """ - Maximum input length allowed for embedding generation. When set, allows - inputs longer than max_embed_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring - VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds - max_embed_len, it will be handled according to the original max_model_len - validation logic. Defaults to None (i.e. set to max_model_len). - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - -_STR_DTYPE_TO_TORCH_DTYPE = { - "half": torch.float16, - "float16": torch.float16, - "float": torch.float32, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - -# model_type -> reason -_FLOAT16_NOT_SUPPORTED_MODELS = { - "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", - "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", - "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", - "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", -} - - -def _is_valid_dtype(model_type: str, dtype: torch.dtype): - if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 - return False - - return True - - -def _check_valid_dtype(model_type: str, dtype: torch.dtype): - if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: - reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] - raise ValueError(f"The model type {model_type!r} " - f"does not support float16. Reason: {reason}") - - return True - - -def _find_dtype( - model_id: str, - config: PretrainedConfig, - *, - revision: Optional[str], -): - # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct - # because config.torch_dtype can be None. 
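The float16 guard above reduces to a membership test against the table of unstable model types; `check_valid_dtype` here is an illustrative standalone version:

```python
import torch

_FLOAT16_NOT_SUPPORTED = {"gemma2", "gemma3", "plamo2", "glm4"}


def check_valid_dtype(model_type: str, dtype: torch.dtype) -> None:
    if model_type in _FLOAT16_NOT_SUPPORTED and dtype == torch.float16:
        raise ValueError(
            f"The model type {model_type!r} does not support float16 due to "
            "numerical instability; use bfloat16 or float32 instead.")


check_valid_dtype("llama", torch.float16)      # fine
check_valid_dtype("gemma2", torch.bfloat16)    # fine
# check_valid_dtype("gemma2", torch.float16)   # would raise ValueError
```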
- config_dtype = getattr(config, "torch_dtype", None) - - # Fallbacks for multi-modal models if the root config - # does not define torch_dtype - if config_dtype is None: - config_dtype = getattr(config.get_text_config(), "torch_dtype", None) - if config_dtype is None and hasattr(config, "vision_config"): - config_dtype = getattr(config.vision_config, "torch_dtype", None) - if config_dtype is None and hasattr(config, "encoder_config"): - config_dtype = getattr(config.encoder_config, "torch_dtype", None) - - # Try to read the dtype of the weights if they are in safetensors format - if config_dtype is None: - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) - - if repo_mt and (files_mt := repo_mt.files_metadata): - param_dtypes: set[torch.dtype] = { - _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE - } - - if param_dtypes: - return common_broadcastable_dtype(param_dtypes) - - if config_dtype is None: - config_dtype = torch.float32 - - return config_dtype - - -def _resolve_auto_dtype( - model_type: str, - config_dtype: torch.dtype, - *, - is_pooling_model: bool, -): - from vllm.platforms import current_platform - - supported_dtypes = [ - dtype for dtype in current_platform.supported_dtypes - if _is_valid_dtype(model_type, dtype) - ] - - if is_pooling_model and torch.float16 in supported_dtypes: - preferred_dtype = torch.float16 - else: - preferred_dtype = supported_dtypes[0] - - # Downcast for float32 models - if config_dtype == torch.float32: - config_dtype = preferred_dtype - - if config_dtype in supported_dtypes: - return config_dtype - - # Ensure device compatibility - device_name = current_platform.get_device_name() - device_capability = current_platform.get_device_capability() - - if device_capability is None: - device_str = f"{device_name!r}" - else: - version_str = device_capability.as_version_str() - device_str = f"{device_name!r} (with compute capability {version_str})" - - logger.warning( - "Your device %s doesn't support %s. " - "Falling back to %s for compatibility.", - device_str, - config_dtype, - preferred_dtype, - ) - - return preferred_dtype - - -def _get_and_verify_dtype( - model_id: str, - config: PretrainedConfig, - dtype: Union[str, torch.dtype], - *, - is_pooling_model: bool, - revision: Optional[str] = None, -) -> torch.dtype: - config_dtype = _find_dtype(model_id, config, revision=revision) - model_type = config.model_type - - if isinstance(dtype, str): - dtype = dtype.lower() - if dtype == "auto": - # Set default dtype from model config - torch_dtype = _resolve_auto_dtype( - model_type, - config_dtype, - is_pooling_model=is_pooling_model, - ) - else: - if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: - raise ValueError(f"Unknown dtype: {dtype!r}") - torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] - elif isinstance(dtype, torch.dtype): - torch_dtype = dtype - else: - raise ValueError(f"Unknown dtype: {dtype}") - - _check_valid_dtype(model_type, torch_dtype) - - if torch_dtype != config_dtype: - if torch_dtype == torch.float32: - # Upcasting to float32 is allowed. - logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) - elif config_dtype == torch.float32: - # Downcasting from float32 to float16 or bfloat16 is allowed. - logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) - else: - # Casting between float16 and bfloat16 is allowed with a warning. 
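The casting policy applied in `_get_and_verify_dtype` can be summarised as follows; `cast_severity` is an illustrative helper, not part of vLLM:

```python
import torch


def cast_severity(config_dtype: torch.dtype,
                  requested_dtype: torch.dtype) -> str:
    if requested_dtype == config_dtype:
        return "none"
    if requested_dtype == torch.float32 or config_dtype == torch.float32:
        return "info"       # up/downcast involving float32 is expected
    return "warning"        # e.g. float16 <-> bfloat16 can change numerics


assert cast_severity(torch.bfloat16, torch.bfloat16) == "none"
assert cast_severity(torch.float32, torch.float16) == "info"
assert cast_severity(torch.float16, torch.bfloat16) == "warning"
```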
- logger.warning("Casting %s to %s.", config_dtype, torch_dtype) - - return torch_dtype - - -def _get_and_verify_max_len( - hf_config: PretrainedConfig, - tokenizer_config: Optional[dict], - max_model_len: Optional[int], - disable_sliding_window: bool, - sliding_window: Optional[int], - spec_target_max_model_len: Optional[int] = None, - encoder_config: Optional[Any] = None, -) -> int: - """Get and verify the model's maximum length.""" - derived_max_model_len = float("inf") - possible_keys = [ - # OPT - "max_position_embeddings", - # GPT-2 - "n_positions", - # MPT - "max_seq_len", - # ChatGLM2 - "seq_length", - # Command-R - "model_max_length", - # Whisper - "max_target_positions", - # Others - "max_sequence_length", - "max_seq_length", - "seq_len", - ] - # Choose the smallest "max_length" from the possible keys - max_len_key = None - for key in possible_keys: - max_len = getattr(hf_config, key, None) - if max_len is not None: - max_len_key = key if max_len < derived_max_model_len \ - else max_len_key - derived_max_model_len = min(derived_max_model_len, max_len) - # For Command-R / Cohere, Cohere2 / Aya Vision models - if tmp_max_len := getattr(hf_config, "model_max_length", None): - max_len_key = "model_max_length" - derived_max_model_len = tmp_max_len - - # If sliding window is manually disabled, max_length should be less - # than the sliding window length in the model config. - if (disable_sliding_window and sliding_window is not None - and sliding_window < derived_max_model_len): - max_len_key = "sliding_window" - derived_max_model_len = sliding_window - - # Consider model_max_length in tokenizer_config - if tokenizer_config: - tokenizer_model_max_length = tokenizer_config.get( - "model_max_length", derived_max_model_len) - derived_max_model_len = min(derived_max_model_len, - tokenizer_model_max_length) - - # If none of the keys were found in the config, use a default and - # log a warning. - if derived_max_model_len == float("inf"): - if max_model_len is not None: - # If max_model_len is specified, we use it. - return max_model_len - - if spec_target_max_model_len is not None: - # If this is a speculative draft model, we use the max model len - # from the target model. - return spec_target_max_model_len - - default_max_len = 2048 - logger.warning( - "The model's config.json does not contain any of the following " - "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", possible_keys, - default_max_len) - derived_max_model_len = default_max_len - - rope_scaling = getattr(hf_config, "rope_scaling", None) - # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE - # scaling, so we skip applying the scaling factor again. - if rope_scaling is not None and "gemma3" not in hf_config.model_type: - # No need to consider "type" key because of patch_rope_scaling when - # loading HF config - rope_type = rope_scaling["rope_type"] - - if rope_type not in ("su", "longrope", "llama3"): - if disable_sliding_window: - # TODO(robertgshaw): Find a model that supports rope_scaling - # with sliding window to see if this case should be allowed. - raise NotImplementedError( - "Disabling sliding window is not supported for models " - "with rope_scaling. 
Please raise an issue so we can " - "investigate.") - - # NOTE: rope_type == "default" does not define factor - # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py - scaling_factor = rope_scaling.get("factor", 1.0) - - if rope_type == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] - derived_max_model_len *= scaling_factor - - if encoder_config and "max_seq_length" in encoder_config: - derived_max_model_len = encoder_config["max_seq_length"] - - # If the user specified a max length, make sure it is smaller than the - # derived length from the HF model config. - if max_model_len is None: - max_model_len = int(derived_max_model_len) - if current_platform.is_tpu(): - logger.warning( - "--max-model-len is not specified, " - "it's currently using model's default length %s, " - "which might be too large." - "Please input with --max-model-len based on your " - "request input length and output length, to avoid " - "unnecessary degradation.", max_model_len) - elif max_model_len > derived_max_model_len: - # Some models might have a separate key for specifying model_max_length - # that will be bigger than derived_max_model_len. We compare user input - # with model_max_length and allow this override when it's smaller. - model_max_length = getattr(hf_config, "model_max_length", None) - if model_max_length is not None and max_model_len <= model_max_length: - if disable_sliding_window: - # TODO(robertgshaw): Find a model that has model_max_length - # with sliding window to see if this case should be allowed. - raise NotImplementedError( - "Disabling sliding window is not supported for models " - "model_max_length in the config. Please raise an issue " - "so we can investigate.") - else: - msg = ( - f"User-specified max_model_len ({max_model_len}) is greater " - f"than the derived max_model_len ({max_len_key}=" - f"{derived_max_model_len} or model_max_length=" - f"{model_max_length} in model's config.json). This may lead " - "to incorrect model outputs or CUDA errors.") - if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: - logger.warning( - "%s Make sure the value is correct and within the " - "model context size.", msg) - else: - raise ValueError( - f"{msg} To allow overriding this maximum, set " - "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") - return int(max_model_len) - - -def get_served_model_name(model: str, - served_model_name: Optional[Union[str, list[str]]]): - """ - If the input is a non-empty list, the first model_name in - `served_model_name` is taken. - If the input is a non-empty string, it is used directly. - For cases where the input is either an empty string or an - empty list, the fallback is to use `self.model`. - """ - if not served_model_name: - return model - if isinstance(served_model_name, list): - return served_model_name[0] - return served_model_name - - -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"] - - -@config -@dataclass -class DecodingConfig: - """Dataclass which contains the decoding strategy of the engine.""" - - backend: GuidedDecodingBackend = "auto" - """Which engine will be used for guided decoding (JSON schema / regex etc) - by default. 
With "auto", we will make opinionated choices based on request - contents and what the backend libraries currently support, so the behavior - is subject to change in each release.""" - - disable_fallback: bool = False - """If `True`, vLLM will not fallback to a different backend on error.""" - - disable_any_whitespace: bool = False - """If `True`, the model will not generate any whitespace during guided - decoding. This is only supported for xgrammar and guidance backends.""" - - disable_additional_properties: bool = False - """If `True`, the `guidance` backend will not use `additionalProperties` - in the JSON schema. This is only supported for the `guidance` backend and - is used to better align its behaviour with `outlines` and `xgrammar`.""" - - reasoning_backend: str = "" - """Select the reasoning parser depending on the model that you're using. - This is used to parse the reasoning content into OpenAI API format.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if (self.disable_any_whitespace - and self.backend not in ("xgrammar", "guidance")): - raise ValueError("disable_any_whitespace is only supported for " - "xgrammar and guidance backends.") - if (self.disable_additional_properties and self.backend != "guidance"): - raise ValueError("disable_additional_properties is only supported " - "for the guidance backend.") - - -DetailedTraceModules = Literal["model", "worker", "all"] - - -@config -@dataclass -class ObservabilityConfig: - """Configuration for observability - metrics and tracing.""" - - show_hidden_metrics_for_version: Optional[str] = None - """Enable deprecated Prometheus metrics that have been hidden since the - specified version. For example, if a previously deprecated metric has been - hidden since the v0.7.0 release, you use - `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while - you migrate to new metrics. The metric is likely to be removed completely - in an upcoming release.""" - - @cached_property - def show_hidden_metrics(self) -> bool: - """Check if the hidden metrics should be shown.""" - if self.show_hidden_metrics_for_version is None: - return False - return version._prev_minor_version_was( - self.show_hidden_metrics_for_version) - - otlp_traces_endpoint: Optional[str] = None - """Target URL to which OpenTelemetry traces will be sent.""" - - collect_detailed_traces: Optional[list[DetailedTraceModules]] = None - """It makes sense to set this only if `--otlp-traces-endpoint` is set. If - set, it will collect detailed traces for the specified modules. This - involves use of possibly costly and or blocking operations and hence might - have a performance impact. 
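A hedged sketch of how the `collect_detailed_traces` option is consumed: a single comma-separated value is split into modules, and the model/worker timing switches are membership tests (helper and variable names are illustrative):

```python
from typing import Optional


def parse_detailed_traces(value: Optional[list[str]]) -> list[str]:
    # A single "model,worker" entry is split into its module names.
    if value is not None and len(value) == 1 and "," in value[0]:
        return value[0].split(",")
    return value or []


modules = parse_detailed_traces(["model,worker"])
collect_model_forward_time = "model" in modules or "all" in modules
collect_model_execute_time = "worker" in modules or "all" in modules
assert collect_model_forward_time and collect_model_execute_time
```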
- - Note that collecting detailed timing information for each request can be - expensive.""" - - @cached_property - def collect_model_forward_time(self) -> bool: - """Whether to collect model forward time for the request.""" - return (self.collect_detailed_traces is not None - and ("model" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) - - @cached_property - def collect_model_execute_time(self) -> bool: - """Whether to collect model execute time for the request.""" - return (self.collect_detailed_traces is not None - and ("worker" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if (self.collect_detailed_traces is not None - and len(self.collect_detailed_traces) == 1 - and "," in self.collect_detailed_traces[0]): - self._parse_collect_detailed_traces() - - from vllm.tracing import is_otel_available, otel_import_error_traceback - if not is_otel_available() and self.otlp_traces_endpoint is not None: - raise ValueError( - "OpenTelemetry is not available. Unable to configure " - "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " - f"installed. Original error:\n{otel_import_error_traceback}") - - def _parse_collect_detailed_traces(self): - assert isinstance(self.collect_detailed_traces, list) - self.collect_detailed_traces = cast( - list[DetailedTraceModules], - self.collect_detailed_traces[0].split(",")) - - -KVProducer = Literal["kv_producer", "kv_both"] -KVConsumer = Literal["kv_consumer", "kv_both"] -KVRole = Literal[KVProducer, KVConsumer] - - -@config -@dataclass -class KVTransferConfig: - """Configuration for distributed KV cache transfer.""" - - kv_connector: Optional[str] = None - """The KV connector for vLLM to transmit KV caches between vLLM instances. - """ - - engine_id: Optional[str] = None - """The engine id for KV transfers.""" - - kv_buffer_device: Optional[str] = "cuda" - """The device used by kv connector to buffer the KV cache. - Currently only support 'cuda'.""" - - kv_buffer_size: float = 1e9 - """The buffer size for TorchDistributedConnector. Measured in number of - bytes. Recommended value: 1e9 (about 1GB).""" - - kv_role: Optional[KVRole] = None - """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - - kv_rank: Optional[int] = None - """The rank of this vLLM instance in the KV cache transfer. Typical value: - 0 for prefill instance, 1 for decode instance. - Currently only 1P1D is supported.""" - - kv_parallel_size: int = 1 - """The number of parallel instances for KV cache transfer. 
For - PyNcclConnector, this should be 2.""" - - kv_ip: str = "127.0.0.1" - """The KV connector ip, used to build distributed connection.""" - - kv_port: int = 14579 - """The KV connector port, used to build distributed connection.""" - - kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) - """any extra config that the connector may need.""" - - kv_connector_module_path: Optional[str] = None - """The Python module path to dynamically load the KV connector from. - Only supported in V1.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - if self.engine_id is None: - self.engine_id = str(uuid.uuid4()) - - if self.kv_role is not None and self.kv_role not in get_args(KVRole): - raise ValueError(f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are {get_args(KVRole)}") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - f"is set, supported roles are {get_args(KVRole)}") - - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVRole) - - @property - def is_kv_producer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVProducer) - - @property - def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVConsumer) - - def get_from_extra_config(self, key, default) -> Any: - return self.kv_connector_extra_config.get(key, default) - - -@config -@dataclass -class KVEventsConfig: - """Configuration for KV event publishing.""" - - enable_kv_cache_events: bool = False - """If True, enable KV cache events for tracking block storage and removal. - Events can be published externally by zmq using the event publisher config. - """ - - publisher: str = "null" - """The publisher to use for publishing kv events. Can be "null", "zmq". - """ - - endpoint: str = "tcp://*:5557" - """The zmq endpoint to use for publishing kv events. - """ - - replay_endpoint: Optional[str] = None - """The zmq endpoint to use for replaying kv events. - """ - - buffer_steps: int = 10_000 - """The number of steps to cache for replay endpoint. Will only save - events from the last N steps for the replay endpoint. - """ - - hwm: int = 100_000 - """The zmq high water mark for the event publisher. After queueing N events, - events will start dropping if the consumer is not keeping up. - """ - - max_queue_size: int = 100_000 - """The maximum number of events to queue while waiting for publishing. - """ - - topic: str = "" - """The topic to use for the event publisher. Consumers can subscribe to - this topic to receive events. - """ - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class VllmConfig: - """Dataclass which contains all vllm-related configuration. 
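# Editor's illustrative sketch (not part of the original diff): the role
# fields of `KVTransferConfig` described above, for a hypothetical prefill
# instance in a 1P1D disaggregated setup. The connector name is only an
# example; assumes the class is importable from `vllm.config`.
from vllm.config import KVTransferConfig

prefill_side = KVTransferConfig(
    kv_connector="SharedStorageConnector",
    kv_role="kv_producer",
    kv_rank=0,
)
assert prefill_side.is_kv_transfer_instance
assert prefill_side.is_kv_producer and not prefill_side.is_kv_consumer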
This - simplifies passing around the distinct configurations in the codebase. - """ - - # TODO: use default_factory once default constructing ModelConfig doesn't - # try to download a model - model_config: ModelConfig = None # type: ignore - """Model configuration.""" - cache_config: CacheConfig = field(default_factory=CacheConfig) - """Cache configuration.""" - parallel_config: ParallelConfig = field(default_factory=ParallelConfig) - """Parallel configuration.""" - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) - """Scheduler configuration.""" - device_config: DeviceConfig = field(default_factory=DeviceConfig) - """Device configuration.""" - load_config: LoadConfig = field(default_factory=LoadConfig) - """Load configuration.""" - lora_config: Optional[LoRAConfig] = None - """LoRA configuration.""" - speculative_config: Optional[SpeculativeConfig] = None - """Speculative decoding configuration.""" - decoding_config: DecodingConfig = field(default_factory=DecodingConfig) - """Decoding configuration.""" - observability_config: Optional[ObservabilityConfig] = None - """Observability configuration.""" - quant_config: Optional[QuantizationConfig] = None - """Quantization configuration.""" - compilation_config: CompilationConfig = field( - default_factory=CompilationConfig) - """`torch.compile` and cudagraph capture configuration for the model. - - As a shorthand, `-O<n>` can be used to directly specify the compilation - level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). - Currently, -O <n> and -O=<n> are supported as well but this will likely be - removed in favor of clearer -O<n> syntax in the future. - - NOTE: level 0 is the default level without any optimization. level 1 and 2 - are for internal testing only. level 3 is the recommended level for - production, also default in V1. - - You can specify the full compilation config like so: - `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` - """ - kv_transfer_config: Optional[KVTransferConfig] = None - """The configurations for distributed KV cache transfer.""" - kv_events_config: Optional[KVEventsConfig] = None - """The configurations for event publishing.""" - # some opaque config, only used to provide additional information - # for the hash computation, mainly used for testing, debugging or out of - # tree config registration. - additional_config: Union[dict, SupportsHash] = field(default_factory=dict) - """Additional config for specified platform. Different platforms may - support different configs. Make sure the configs are valid for the platform - you are using. Contents must be hashable.""" - instance_id: str = "" - """The ID of the vLLM instance.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. 
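# Editor's illustrative sketch (not part of the original diff): the `-O`
# shorthand documented above for `compilation_config` corresponds roughly to
# constructing the config programmatically. Names are as defined in this
# codebase; treat the exact values as an example, not a recommendation.
from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,        # what `-O3` selects
    cudagraph_capture_sizes=[1, 2, 4, 8],
)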
- """ - factors: list[Any] = [] - - # summarize vllm config - vllm_factors: list[Any] = [] - from vllm import __version__ - vllm_factors.append(__version__) - vllm_factors.append(envs.VLLM_USE_V1) - if self.model_config: - vllm_factors.append(self.model_config.compute_hash()) - else: - vllm_factors.append("None") - if self.cache_config: - vllm_factors.append(self.cache_config.compute_hash()) - else: - vllm_factors.append("None") - if self.parallel_config: - vllm_factors.append(self.parallel_config.compute_hash()) - else: - vllm_factors.append("None") - if self.scheduler_config: - vllm_factors.append(self.scheduler_config.compute_hash()) - else: - vllm_factors.append("None") - if self.device_config: - vllm_factors.append(self.device_config.compute_hash()) - else: - vllm_factors.append("None") - if self.load_config: - vllm_factors.append(self.load_config.compute_hash()) - else: - vllm_factors.append("None") - if self.lora_config: - vllm_factors.append(self.lora_config.compute_hash()) - # LoRA creates static buffers based on max_num_batched_tokens. - # The tensor sizes and strides get captured in the torch.compile - # graph explicitly. - vllm_factors.append( - str(self.scheduler_config.max_num_batched_tokens)) - else: - vllm_factors.append("None") - if self.speculative_config: - vllm_factors.append(self.speculative_config.compute_hash()) - else: - vllm_factors.append("None") - if self.decoding_config: - vllm_factors.append(self.decoding_config.compute_hash()) - else: - vllm_factors.append("None") - if self.observability_config: - vllm_factors.append(self.observability_config.compute_hash()) - else: - vllm_factors.append("None") - if self.quant_config: - pass # should be captured by model_config.quantization - if self.compilation_config: - vllm_factors.append(self.compilation_config.compute_hash()) - else: - vllm_factors.append("None") - if self.kv_transfer_config: - vllm_factors.append(self.kv_transfer_config.compute_hash()) - else: - vllm_factors.append("None") - if self.additional_config: - if isinstance(additional_config := self.additional_config, dict): - additional_config_hash = hashlib.md5( - json.dumps(additional_config, sort_keys=True).encode(), - usedforsecurity=False, - ).hexdigest() - else: - additional_config_hash = additional_config.compute_hash() - vllm_factors.append(additional_config_hash) - else: - vllm_factors.append("None") - factors.append(vllm_factors) - - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] - return hash_str - - def pad_for_cudagraph(self, batch_size: int) -> int: - # if batch_size > self.compilation_config.max_capture_size, - # it should raise an IndexError. 
- # the caller should make sure the batch_size is within the range, - # i.e., batch_size <= self.compilation_config.max_capture_size - return self.compilation_config.bs_to_padded_graph_size[batch_size] - - @staticmethod - def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - from vllm.platforms import current_platform - if model_config.quantization is not None: - from vllm.model_executor.model_loader.weight_utils import ( - get_quant_config) - quant_config = get_quant_config(model_config, load_config) - capability_tuple = current_platform.get_device_capability() - - if capability_tuple is not None: - capability = capability_tuple.to_int() - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. Minimum " - f"capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - return quant_config - return None - - @staticmethod - def get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - import copy - - # For some reason, the _ version of this modifies the model_config - # object, so using deepcopy to avoid this problem. - return VllmConfig._get_quantization_config(copy.deepcopy(model_config), - load_config) - - def with_hf_config( - self, - hf_config: PretrainedConfig, - architectures: Optional[list[str]] = None, - ) -> "VllmConfig": - if architectures is not None: - hf_config = copy.deepcopy(hf_config) - hf_config.architectures = architectures - - model_config = copy.deepcopy(self.model_config) - model_config.hf_config = hf_config - - return replace(self, model_config=model_config) - - def __post_init__(self): - """Verify configs are valid & consistent with each other. - """ - - self.try_verify_and_update_config() - - if self.model_config is not None: - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) - self.model_config.verify_with_parallel_config(self.parallel_config) - self.model_config.verify_dual_chunk_attention_config( - self.load_config) - - self.cache_config.verify_with_parallel_config(self.parallel_config) - - if self.lora_config is not None: - self.lora_config.verify_with_cache_config(self.cache_config) - self.lora_config.verify_with_model_config(self.model_config) - - if self.quant_config is None and self.model_config is not None: - self.quant_config = VllmConfig._get_quantization_config( - self.model_config, self.load_config) - - from vllm.platforms import current_platform - if self.model_config is not None and \ - self.scheduler_config.chunked_prefill_enabled and \ - self.model_config.dtype == torch.float32 and \ - current_platform.get_device_capability() == (7, 5): - logger.warning_once( - "Turing devices tensor cores do not support float32 matmul. " - "To workaround this limitation, vLLM will set 'ieee' input " - "precision for chunked prefill triton kernels.") - - # If the user does not explicitly set a compilation level, then - # we use the default level. The default level depends on other - # settings (see the below code). 
- if self.compilation_config.level is None: - if envs.VLLM_USE_V1: - if (self.model_config is not None - and not self.model_config.enforce_eager): - self.compilation_config.level = CompilationLevel.PIECEWISE - else: - self.compilation_config.level = \ - CompilationLevel.NO_COMPILATION - - else: - # NB: Passing both --enforce-eager and a compilation level - # in V0 means the compilation level wins out. - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - # async tp is built on top of sequence parallelism - # and requires it to be enabled. - if self.compilation_config.pass_config.enable_async_tp: - self.compilation_config.pass_config.enable_sequence_parallelism = \ - True - if self.compilation_config.pass_config.enable_sequence_parallelism: - self.compilation_config.custom_ops.append("+rms_norm") - - if current_platform.is_cuda_alike() or current_platform.is_xpu(): - # if cudagraph_mode is not explicitly set by users, set default - # value - if self.compilation_config.cudagraph_mode is None: - if envs.VLLM_USE_V1 and self.compilation_config.level \ - == CompilationLevel.PIECEWISE: - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - # disable cudagraph when enforce eager execution - if self.model_config is not None and \ - self.model_config.enforce_eager: - logger.info("Cudagraph is disabled under eager mode") - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - elif envs.VLLM_USE_V1: - self.compilation_config.cudagraph_num_of_warmups = 1 - - self._set_cudagraph_sizes() - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - if self.cache_config.cpu_offload_gb > 0 and \ - self.compilation_config.level != CompilationLevel.NO_COMPILATION \ - and not envs.VLLM_USE_V1: - logger.warning( - "CPU offload is not supported with `torch.compile` in v0 yet." - " Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - if ((not envs.VLLM_USE_V1) and self.lora_config is not None - and self.compilation_config.level - != CompilationLevel.NO_COMPILATION): - logger.warning( - "LoRA for V0 is not supported with `torch.compile` yet. " - "Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION - - disable_chunked_prefill_reasons: list[str] = [] - - if self.model_config and self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": - disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") - elif not getattr(self.model_config.hf_config, "is_causal", True): - disable_chunked_prefill_reasons.append( - "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") - - if disable_chunked_prefill_reasons: - for reason in disable_chunked_prefill_reasons: - logger.info(reason) - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.long_prefill_token_threshold = 0 - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False - - if (self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events - and not self.cache_config.enable_prefix_caching): - logger.warning( - "KV cache events are on, but prefix caching is not enabled." 
- "Use --enable-prefix-caching to enable.") - if (self.kv_events_config is not None - and self.kv_events_config.publisher != "null" - and not self.kv_events_config.enable_kv_cache_events): - logger.warning("KV cache events are disabled," - "but the scheduler is configured to publish them." - "Modify KVEventsConfig.enable_kv_cache_events" - "to True to enable.") - current_platform.check_and_update_config(self) - - # final check of cudagraph mode after platform-specific update - if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): - if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \ - and self.model_config is not None and \ - not self.model_config.disable_cascade_attn: - logger.info("CUDAGraphMode.FULL is not supported with " - "cascade attention currently. Disabling cascade" - "attention.") - self.model_config.disable_cascade_attn = True - - if self.compilation_config.cudagraph_mode\ - .requires_piecewise_compilation(): - assert self.compilation_config.level == \ - CompilationLevel.PIECEWISE, \ - "Compilation level should be CompilationLevel.PIECEWISE "\ - "when cudagraph_mode piecewise cudagraphs is used, "\ - f"cudagraph_mode={self.compilation_config.cudagraph_mode}" - - if not self.instance_id: - self.instance_id = random_uuid()[:5] - - # Do this after all the updates to compilation_config.level - if envs.VLLM_USE_V1 and \ - self.compilation_config.level == CompilationLevel.PIECEWISE: - self.compilation_config.set_splitting_ops_for_v1() - - if (envs.VLLM_USE_V1 - and not self.scheduler_config.disable_hybrid_kv_cache_manager): - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not (current_platform.is_cuda() or current_platform.is_rocm()): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_transfer_config is not None: - # Hybrid KV cache manager is not compatible with KV transfer. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.model_config is not None and \ - self.model_config.attention_chunk_size is not None: - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif \ - not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." - ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. 
- self.scheduler_config.disable_hybrid_kv_cache_manager = True - - def update_sizes_for_sequence_parallelism(self, - possible_sizes: list) -> list: - # remove the sizes that are not multiples of tp_size when - # sequence parallelism is enabled - removed_sizes = [ - size for size in possible_sizes - if size % self.parallel_config.tensor_parallel_size != 0 - ] - if removed_sizes: - logger.warning( - "Batch sizes %s are removed because they are not " - "multiples of tp_size %d when " - "sequence parallelism is enabled", removed_sizes, - self.parallel_config.tensor_parallel_size) - - return [ - size for size in possible_sizes - if size % self.parallel_config.tensor_parallel_size == 0 - ] - - def _set_cudagraph_sizes(self): - """ - Cudagraph batch size padding logic: - - `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible - batch sizes that cudagraph will capture. - - Depending on the engine's configuration of `max_num_seqs`, the - candidate batch sizes to capture cudagraph will shrink to the subset - that just covers the range `[1, max_num_seqs]`. In the common case, - `max_num_seqs` is 256, and the cudagraph batch sizes will be - `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. - - However, if users specify the cudagraph capture sizes through the - compilation config, we will use the specified sizes instead. - - In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` - will be the final sizes to capture cudagraph (in descending order). - - At runtime, if the batch size is larger than the largest size in - `vllm_config.compilation_config.cudagraph_capture_sizes`, - no cudagraph will be used. - If the batch size is no larger than the largest capture size, - we can quickly find the padded graph size for a given batch size by - looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. 
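# Editor's worked example of the padding logic described above (values
# assumed, not taken from the diff): with max_num_seqs = 20, the candidates
# [1, 2, 4, 8, 16, 24, ...] shrink to [1, 2, 4, 8, 16, 24], because 24 is
# the smallest candidate covering 20; a runtime batch of 13 is padded to 16.
def pad_batch_size(batch_size: int, capture_sizes: list[int]) -> int:
    """Smallest captured size that can hold `batch_size`."""
    return next(size for size in sorted(capture_sizes) if size >= batch_size)

assert pad_batch_size(13, [1, 2, 4, 8, 16, 24]) == 16
assert pad_batch_size(20, [1, 2, 4, 8, 16, 24]) == 24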
- """ - - # calculate the default `batch_size_capture_list` - if not envs.VLLM_USE_V1: - batch_size_capture_list = [] - if self.scheduler_config is not None and \ - self.model_config is not None and \ - not self.model_config.enforce_eager: - - possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] - if self.parallel_config.tensor_parallel_size > 1 and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - possible_sizes = self.update_sizes_for_sequence_parallelism( - possible_sizes) - - # find the minimum size that is larger than max_num_seqs, - # which then becomes the max_batchsize_to_capture - larger_sizes = [ - x for x in possible_sizes - if x >= self.scheduler_config.max_num_seqs - ] - if larger_sizes: - max_batchsize_to_capture = larger_sizes[0] - else: - max_batchsize_to_capture = possible_sizes[-1] - - # filter out the sizes that are - # larger than max_batchsize_to_capture - batch_size_capture_list = [ - size for size in possible_sizes - if size <= max_batchsize_to_capture - ] - else: - batch_size_capture_list = [] - if self.model_config is not None and \ - not self.model_config.enforce_eager: - cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes - if len(cuda_graph_sizes) == 1: - batch_size_capture_list = [1, 2, 4] + [ - i for i in range(8, cuda_graph_sizes[0] + 1, 8) - ] - elif len(cuda_graph_sizes) > 1: - batch_size_capture_list = sorted(cuda_graph_sizes) - else: - raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") - if self.parallel_config.tensor_parallel_size > 1 and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - batch_size_capture_list = \ - self.update_sizes_for_sequence_parallelism(batch_size_capture_list) - max_num_tokens = self.scheduler_config.max_num_batched_tokens - batch_size_capture_list = [ - size for size in batch_size_capture_list - if size <= max_num_tokens - ] - - self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) - - def recalculate_max_model_len(self, max_model_len: int): - # Can only be called in try_verify_and_update_config - model_config = self.model_config - max_model_len = model_config.get_and_verify_max_len(max_model_len) - self.model_config.max_model_len = max_model_len - self.scheduler_config.max_model_len = max_model_len - - def try_verify_and_update_config(self): - if self.model_config is None: - return - - # Avoid running try_verify_and_update_config multiple times - if getattr(self.model_config, "config_updated", False): - return - self.model_config.config_updated = True - - architecture = self.model_config.architecture - if architecture is None: - return - - from vllm.model_executor.models.config import ( - MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) - cls = MODELS_CONFIG_MAP.get(architecture, None) - if cls is not None: - cls.verify_and_update_config(self) - - if self.model_config.is_hybrid: - HybridAttentionMambaModelConfig.verify_and_update_config(self) - - if self.model_config.convert_type == "classify": - # Maybe convert ForCausalLM into ForSequenceClassification model. 
- from vllm.model_executor.models.adapters import ( - SequenceClassificationConfig) - SequenceClassificationConfig.verify_and_update_config(self) - - def __str__(self): - return ( - f"model={self.model_config.model!r}, " - f"speculative_config={self.speculative_config!r}, " - f"tokenizer={self.model_config.tokenizer!r}, " - f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " - f"tokenizer_mode={self.model_config.tokenizer_mode}, " - f"revision={self.model_config.revision}, " - f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa - f"tokenizer_revision={self.model_config.tokenizer_revision}, " - f"trust_remote_code={self.model_config.trust_remote_code}, " - f"dtype={self.model_config.dtype}, " - f"max_seq_len={self.model_config.max_model_len}, " - f"download_dir={self.load_config.download_dir!r}, " - f"load_format={self.load_config.load_format}, " - f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa - f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa - f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa - f"quantization={self.model_config.quantization}, " - f"enforce_eager={self.model_config.enforce_eager}, " - f"kv_cache_dtype={self.cache_config.cache_dtype}, " - f"device_config={self.device_config.device}, " - f"decoding_config={self.decoding_config!r}, " - f"observability_config={self.observability_config!r}, " - f"seed={self.model_config.seed}, " - f"served_model_name={self.model_config.served_model_name}, " - f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " - f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa - f"use_async_output_proc={self.model_config.use_async_output_proc}, " - f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}") - - -_current_vllm_config: Optional[VllmConfig] = None -_current_prefix: Optional[str] = None - - -@contextmanager -def set_current_vllm_config(vllm_config: VllmConfig, - check_compile=False, - prefix: Optional[str] = None): - """ - Temporarily set the current vLLM config. - Used during model initialization. - We save the current vLLM config in a global variable, - so that all modules can access it, e.g. custom ops - can access the vLLM config to determine how to dispatch. - """ - global _current_vllm_config, _current_prefix - old_vllm_config = _current_vllm_config - old_prefix = _current_prefix - from vllm.compilation.counter import compilation_counter - num_models_seen = compilation_counter.num_models_seen - try: - _current_vllm_config = vllm_config - _current_prefix = prefix - yield - except Exception: - raise - else: - logger.debug("enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops) - logger.debug("disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops) - if check_compile and \ - vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ - and compilation_counter.num_models_seen == num_models_seen: - # If the model supports compilation, - # compilation_counter.num_models_seen should be increased - # by at least 1. - # If it is not increased, it means the model does not support - # compilation (does not have @support_torch_compile decorator). - logger.warning( - "`torch.compile` is turned on, but the model %s" - " does not support it. 
Please open an issue on GitHub" - " if you want it to be supported.", - vllm_config.model_config.model) - finally: - _current_vllm_config = old_vllm_config - _current_prefix = old_prefix - # Clear the compilation config cache when context changes - get_cached_compilation_config.cache_clear() - - -@lru_cache(maxsize=1) -def get_cached_compilation_config(): - """Cache config to avoid repeated calls to get_current_vllm_config()""" - return get_current_vllm_config().compilation_config - - -def get_current_vllm_config() -> VllmConfig: - if _current_vllm_config is None: - # in ci, usually when we test custom ops/modules directly, - # we don't set the vllm config. In that case, we set a default - # config. - logger.warning("Current vLLM config is not set.") - from vllm.config import VllmConfig - return VllmConfig() - return _current_vllm_config - - -def get_current_model_prefix() -> str: - """ - Get the prefix of the model that's currently being initialized. - """ - assert _current_prefix is not None, \ - "Current model prefix is not set. " - return _current_prefix - - -def contains_object_print(text): - """ - Check if the text looks like a printed Python object, e.g. - contains any substring matching the pattern: "at 0xFFFFFFF>" - We match against 0x followed by 2-16 hex chars (there's - a max of 16 on a 64 bit system). - - Args: - text (str): The text to check - - Returns: - result (bool): `True` if a match is found, `False` otherwise. - """ - pattern = r'at 0x[a-fA-F0-9]{2,16}>' - match = re.search(pattern, text) - return match is not None - - -def assert_hashable(text): - if not contains_object_print(text): - return True - raise AssertionError( - f"vLLM tried to hash some configs that may have Python objects ids " - f"in them. This is a bug, please file an issue. " - f"Text being hashed: {text}") - - -T = TypeVar("T") - - -def get_layers_from_vllm_config( - vllm_config: VllmConfig, - layer_type: type[T], - layer_names: Optional[list[str]] = None) -> dict[str, T]: - """ - Get layers from the vLLM config. - - Args: - vllm_config: The vLLM config. - layer_type: The type of the layer to get. - layer_names: The names of the layers to get. If None, return all layers. - """ - - if layer_names is None: - layer_names = list( - vllm_config.compilation_config.static_forward_context.keys()) - - forward_context = vllm_config.compilation_config.static_forward_context - - return { - layer_name: forward_context[layer_name] - for layer_name in layer_names - if isinstance(forward_context[layer_name], layer_type) - } - - -@config -@dataclass -class SpeechToTextConfig: - """Configuration for speech-to-text models.""" - - sample_rate: float = 16_000 - """Sample rate (Hz) to resample input audio to. Most speech models expect - 16kHz audio input. The input audio will be automatically resampled to this - rate before processing.""" - - max_audio_clip_s: int = 30 - """Maximum duration in seconds for a single audio clip without chunking. - Audio longer than this will be split into smaller chunks if - `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" - - overlap_chunk_second: int = 1 - """Overlap duration in seconds between consecutive audio chunks when - splitting long audio. This helps maintain context across chunk boundaries - and improves transcription quality at split points.""" - - min_energy_split_window_size: Optional[int] = 1600 - """Window size in samples for finding low-energy (quiet) regions to split - audio chunks. 
The algorithm looks for the quietest moment within this - window to minimize cutting through speech. Default 1600 samples ≈ 100ms - at 16kHz. If None, no chunking will be done.""" - - @property - def allow_audio_chunking(self) -> bool: - return self.min_energy_split_window_size is not None - - -def update_config(config: DataclassInstanceT, - overrides: dict[str, Any]) -> DataclassInstanceT: - processed_overrides = {} - for field_name, value in overrides.items(): - assert hasattr( - config, field_name), f"{type(config)} has no field `{field_name}`" - current_value = getattr(config, field_name) - if is_dataclass(current_value) and not is_dataclass(value): - assert isinstance(value, dict), ( - f"Overrides to {type(config)}.{field_name} must be a dict" - f" or {type(current_value)}, but got {type(value)}") - value = update_config( - current_value, # type: ignore[type-var] - value) - processed_overrides[field_name] = value - return replace(config, **processed_overrides) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae11dec3ca5e2..fd47d5c8f976f 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -3,13 +3,11 @@ import hashlib from dataclasses import field -from typing import TYPE_CHECKING, Any, Literal, Optional, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional -from pydantic import SkipValidation, model_validator +from pydantic import Field, SkipValidation, field_validator from pydantic.dataclasses import dataclass -from typing_extensions import Self -import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger from vllm.utils import GiB_bytes, get_cpu_memory @@ -22,9 +20,9 @@ else: logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @config @@ -33,14 +31,13 @@ class CacheConfig: """Configuration for the KV cache.""" block_size: SkipValidation[BlockSize] = None # type: ignore - """Size of a contiguous cache block in number of tokens. This is ignored on - neuron devices and set to `--max-model-len`. On CUDA devices, only block - sizes up to 32 are supported. On HPU devices, block size defaults to 128. + """Size of a contiguous cache block in number of tokens. On CUDA devices, + only block sizes up to 32 are supported. This config has no static default. If left unspecified by the user, it will be set in `Platform.check_and_update_config()` based on the current platform.""" - gpu_memory_utilization: float = 0.9 + gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9. This is a @@ -48,12 +45,16 @@ class CacheConfig: not matter if you have another vLLM instance running on the same GPU. For example, if you have two vLLM instances running on the same GPU, you can set the GPU memory utilization to 0.5 for each instance.""" - swap_space: float = 4 + swap_space: float = Field(default=4, ge=0) """Size of the CPU swap space per GPU (in GiB).""" cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. 
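# Editor's illustrative sketch (not part of the original diff): the pydantic
# `Field` bounds introduced above are enforced at construction time. Assumes
# `CacheConfig` is importable from `vllm.config`, as elsewhere in vLLM.
from pydantic import ValidationError
from vllm.config import CacheConfig

CacheConfig(gpu_memory_utilization=0.5, swap_space=4)  # within bounds: OK
try:
    CacheConfig(gpu_memory_utilization=1.5)  # rejected: gt=0, le=1
except ValidationError as err:
    print(err.error_count(), "validation error(s)")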
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports - fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" + fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc). + Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use + bfloat16 instead, this is an invalid option for models that do not default + to fp8. + """ is_attention_free: bool = False """Whether the model is attention-free. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -64,18 +65,13 @@ class CacheConfig: """Sliding window size for the KV cache. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" enable_prefix_caching: Optional[bool] = None - """Whether to enable prefix caching. Disabled by default for V0. Enabled by - default for V1.""" - prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Whether to enable prefix caching. Enabled by default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads. - This option uses Pickle for object serialization before hashing.\n - - "sha256_cbor_64bit" provides a reproducible, cross-language compatible - hash. It serializes objects using canonical CBOR and hashes them with - SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 - digest.""" - cpu_offload_gb: float = 0 + - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256_cbor" provides a reproducible, cross-language compatible hash. It + serializes objects using canonical CBOR and hashes them with SHA-256.""" + cpu_offload_gb: float = Field(default=0, ge=0) """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and @@ -93,7 +89,8 @@ class CacheConfig: mamba_page_size_padded: Optional[int] = None """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" - + mamba_block_size: Optional[int] = None + """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). If set to 'auto', the data type will be inferred from the model @@ -115,10 +112,19 @@ class CacheConfig: In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables - attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) + attention metadata for eligible layers to be overridden with metadata + necessary for implementing this optimization in some models (e.g. Gemma3n) """ + kv_cache_memory_bytes: Optional[int] = None + """Size of KV Cache per GPU in bytes. By default, this is set to None + and vllm can automatically infer the kv cache size based on + gpu_memory_utilization. However, users may want to manually specify + the kv cache memory size. kv_cache_memory_bytes allows more fine-grain + control of how much memory gets used when compared with using + gpu_memory_utilization. 
Note that kv_cache_memory_bytes + (when not-None) ignores gpu_memory_utilization""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -136,80 +142,42 @@ class CacheConfig: factors.append(self.mamba_cache_dtype) factors.append(self.mamba_ssm_cache_dtype) # `cpu_offload_gb` does not use `torch.compile` yet. - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self) -> None: - self.swap_space_bytes = self.swap_space * GiB_bytes - - self._verify_cache_dtype() - self._verify_prefix_caching() - def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} - @model_validator(mode='after') - def _verify_args(self) -> Self: - if self.cpu_offload_gb < 0: - raise ValueError("CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") - - if self.gpu_memory_utilization > 1.0: - raise ValueError( - "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") - - if self.kv_sharing_fast_prefill: - logger.warning_once( - "--kv-sharing-fast-prefill is currently work in progress " - "and not functional yet (i.e. no prefill savings)") - - return self - - def _verify_cache_dtype(self) -> None: - if self.cache_dtype == "auto": - pass - elif self.cache_dtype in get_args(CacheDType): + @field_validator("cache_dtype", mode="after") + @classmethod + def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: + if cache_dtype.startswith("fp8"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor.") - else: - raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") - - def _verify_prefix_caching(self) -> None: - if not self.enable_prefix_caching: - return - - if self.sliding_window is not None and not envs.VLLM_USE_V1: - raise NotImplementedError( - "Prefix caching is not supported with sliding window. " - "Run with --disable-sliding-window to use prefix caching.") - - if (self.enable_prefix_caching and self.prefix_caching_hash_algo - not in get_args(PrefixCachingHashAlgo)): - raise ValueError( - "Unknown prefix caching hash algorithm: " - f"{self.prefix_caching_hash_algo}. Must be one of " - f"{get_args(PrefixCachingHashAlgo)}.") + "scaling factor." + ) + return cache_dtype def verify_with_parallel_config( self, parallel_config: ParallelConfig, ) -> None: + swap_space_bytes = self.swap_space * GiB_bytes total_cpu_memory = get_cpu_memory() # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel # group are in the same node. However, the GPUs may span multiple nodes. num_gpus_per_node = parallel_config.tensor_parallel_size - cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + cpu_memory_usage = swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " - f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " - "is allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. 
" + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index e2785e7602e45..e65728ba7f4e1 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -5,12 +5,12 @@ import enum import hashlib from collections import Counter from dataclasses import asdict, field +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union from pydantic import TypeAdapter, field_validator from pydantic.dataclasses import dataclass -import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import config from vllm.logger import init_logger @@ -33,38 +33,50 @@ class CompilationLevel: class CUDAGraphMode(enum.Enum): - """ Constants for the cudagraph mode in CompilationConfig. + """Constants for the cudagraph mode in CompilationConfig. Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also treated as concrete runtime mode for cudagraph runtime dispatching. """ + NONE = 0 PIECEWISE = 1 FULL = 2 FULL_DECODE_ONLY = (FULL, NONE) FULL_AND_PIECEWISE = (FULL, PIECEWISE) - def decode_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(self.value[0]) if \ - self.separate_routine() else self + def decode_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(self.value[0]) if self.separate_routine() else self - def mixed_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(self.value[1]) if \ - self.separate_routine() else self + def mixed_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(self.value[1]) if self.separate_routine() else self + + def has_mode(self, mode: "CUDAGraphMode") -> bool: + assert not mode.separate_routine() + if self.separate_routine(): + return mode.value in self.value + return self == mode def requires_piecewise_compilation(self) -> bool: - return (self.decode_mode() == CUDAGraphMode.PIECEWISE - or self.mixed_mode() == CUDAGraphMode.PIECEWISE) + return self.has_mode(CUDAGraphMode.PIECEWISE) - def max_cudagraph_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(max( - self.value)) if self.separate_routine() else self + def max_cudagraph_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(max(self.value)) if self.separate_routine() else self def has_full_cudagraphs(self) -> bool: return self.max_cudagraph_mode() == CUDAGraphMode.FULL + def has_piecewise_cudagraphs(self) -> bool: + return self.requires_piecewise_compilation() + def separate_routine(self) -> bool: return isinstance(self.value, tuple) + def valid_runtime_modes(self) -> bool: + return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] + + def __str__(self) -> str: + return self.name + @config @dataclass @@ -75,11 +87,11 @@ class PassConfig: don't all have access to full configuration - that would create a cycle as the `PassManager` is set as a property of config.""" - enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + enable_fusion: bool = False """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" enable_attn_fusion: bool = False """Whether to enable the custom attention+quant fusion pass.""" - enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + enable_noop: bool = False """Whether to enable the custom no-op elimination pass.""" enable_sequence_parallelism: bool = False """Whether to enable sequence parallelism.""" @@ -105,11 +117,13 @@ class PassConfig: if self.enable_fusion: logger.warning_once( "Fusion enabled but reshape elimination 
disabled. " - "RMSNorm/SiluMul + quant (fp8) fusion might not work") + "RMSNorm/SiluMul + quant (fp8) fusion might not work" + ) if self.enable_attn_fusion: logger.warning_once( "Fusion enabled but reshape elimination disabled. " - "Attention + quant (fp8) fusion might not work") + "Attention + quant (fp8) fusion might not work" + ) @config @@ -152,6 +166,7 @@ class CompilationConfig: sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. """ + # Top-level Compilation control level: Optional[int] = None """The level of compilation: @@ -162,7 +177,7 @@ class CompilationConfig: - 1: dynamo as is. - 2: dynamo once. - 3: piecewise compilation.""" - debug_dump_path: str = "" + debug_dump_path: Optional[Path] = None """The path to dump the debug information.""" cache_dir: str = "" """The directory to store the compiled graph, to accelerate Inductor @@ -194,8 +209,23 @@ class CompilationConfig: disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" splitting_ops: Optional[list[str]] = None - """A list of ops to split the full graph into subgraphs, used in piecewise - compilation.""" + """A list of ops to exclude from cudagraphs, used in piecewise compilation. + + The behavior depends on use_inductor_graph_partition: + + - When use_inductor_graph_partition=False (default): + These ops are used for Dynamo FX-level graph splitting. The graph is + split at these ops before Inductor compilation, creating separate + subgraphs for cudagraph capture. + + - When use_inductor_graph_partition=True: + These ops are used to register Inductor partition rules. The graph + partitioning happens at Inductor codegen time after all passes and + fusions are finished, allowing compilation and custom passes to operate + on the full graph while still excluding these ops from cudagraphs. + + If None, defaults to attention ops for piecewise cudagraphs. + If empty list [], no ops are excluded (suitable for full cudagraphs).""" # Inductor capture use_inductor: bool = True @@ -225,17 +255,17 @@ class CompilationConfig: # CudaGraph compilation cudagraph_mode: Optional[CUDAGraphMode] = None """ - The mode of the cudagraph. + The mode of the cudagraph: + - NONE, no cudagraph capture. - - PIECEWISE. (v1 default) + - PIECEWISE. - FULL. - FULL_DECODE_ONLY. - - FULL_AND_PIECEWISE. + - FULL_AND_PIECEWISE. (v1 default) PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph - incompatiable ops (i.e. some attention ops) outside the cudagraph + incompatible ops (i.e. some attention ops) outside the cudagraph for general flexibility. - This is the default mode. FULL mode: Capture full cudagraph for all batches. Can be good for small models or workloads with small prompts; not supported by many backends. @@ -248,7 +278,7 @@ class CompilationConfig: FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. - This is like the most performant mode for most models. + This is the most performant mode for most models and is the default. Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the @@ -270,7 +300,8 @@ class CompilationConfig: Note that this is orthogonal to the cudagraph capture logic outside of compilation. Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. 
v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE + instead. """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. @@ -295,7 +326,28 @@ class CompilationConfig: flag cannot be used together with splitting_ops. This may provide performance benefits for smaller models. Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode= + FULL_AND_PIECEWISE instead. + """ + + use_inductor_graph_partition: bool = False + """Use inductor graph partition to split the graph at cudagraph_unsafe ops. + This partition happens at inductor codegen time after all passes and fusions + are finished. It generates a single `call` function which wraps + cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops + outside the partition functions. For a graph with N cudagraph-unsafe ops + (e.g., Attention), there would be N+1 partitions. To mark an op as + cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when + register the custom op. + + This config supports both full cudagraph and piecewise cudagraph without + compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper + to each partition. For N+1 partitions, there would be N+1 + CUDAGraph wrapper instances. + + For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the + inductor `call` function in the model runner. The top-level full cudagraph + capture ignores all partitioning. """ pass_config: PassConfig = field(default_factory=PassConfig) @@ -307,37 +359,42 @@ class CompilationConfig: """local cache dir for each rank""" bs_to_padded_graph_size: list[int] = field( default=None, # type: ignore - init=False) + init=False, + ) """optimization: Intuitively, bs_to_padded_graph_size should be dict[int, int]. 
since we know all keys are in a range [0, max_capture_size], we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) + enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are enabled""" - disabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) + disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are disabled""" traced_files: set[str] = field(default_factory=set, init=False) """files that are traced for compilation""" compilation_time: float = field(default=0.0, init=False) """time taken for compilation""" - static_forward_context: dict[str, Any] = field(default_factory=dict, - init=False) + static_forward_context: dict[str, Any] = field(default_factory=dict, init=False) """Per-model forward context Map from layer name to layer objects that need to be accessed outside model code, e.g., Attention, FusedMOE when dp_size>1.""" # Attention ops; used for piecewise cudagraphs + # Use PyTorch operator format: "namespace::name" _attention_ops: ClassVar[list[str]] = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - "vllm.mamba_mixer2", - "vllm.mamba_mixer", - "vllm.short_conv", + "vllm::unified_attention", + "vllm::unified_attention_with_output", + "vllm::unified_mla_attention", + "vllm::unified_mla_attention_with_output", + "vllm::mamba_mixer2", + "vllm::mamba_mixer", + "vllm::short_conv", + "vllm::linear_attention", + "vllm::plamo2_mamba_mixer", + "vllm::gdn_attention", + "vllm::sparse_attn_indexer", ] def compute_hash(self) -> str: @@ -384,13 +441,11 @@ class CompilationConfig: if pass_config_exclude: exclude["pass_config"] = pass_config_exclude - # The cast to string is necessary because Pydantic is mocked in docs - # builds and sphinx-argparse doesn't know the return type of decode() - return str( - TypeAdapter(CompilationConfig).dump_json( - self, - exclude=exclude, # type: ignore[arg-type] - exclude_unset=True).decode()) + config = TypeAdapter(CompilationConfig).dump_python( + self, exclude=exclude, exclude_unset=True + ) + + return str(config) __str__ = __repr__ @@ -418,16 +473,16 @@ class CompilationConfig: # https://github.com/vllm-project/vllm/issues/14703 if is_torch_equal_or_newer("2.6"): - KEY = 'enable_auto_functionalized_v2' + KEY = "enable_auto_functionalized_v2" if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False for k, v in self.inductor_passes.items(): if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be callable or a qualified name") - self.inductor_compile_config[k] = v if isinstance( - v, InductorPass) else CallableInductorPass(v) + assert callable(v), f"pass {k} should be callable or a qualified name" + self.inductor_compile_config[k] = ( + v if isinstance(v, InductorPass) else CallableInductorPass(v) + ) continue # resolve function from qualified name @@ -435,40 +490,68 @@ class CompilationConfig: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func if isinstance( - func, InductorPass) else CallableInductorPass(func) + self.inductor_compile_config[k] = ( + func if isinstance(func, InductorPass) else CallableInductorPass(func) + ) if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) # migrate the deprecated flags if not 
self.use_cudagraph: - logger.warning("use_cudagraph is deprecated, use " - "cudagraph_mode=NONE instead.") - if self.cudagraph_mode is not None: + logger.warning( + "use_cudagraph is deprecated, use cudagraph_mode=NONE instead." + ) + if ( + self.cudagraph_mode is not None + and self.cudagraph_mode != CUDAGraphMode.NONE + ): raise ValueError( "use_cudagraph and cudagraph_mode are mutually" " exclusive, prefer cudagraph_mode since " - "use_cudagraph is deprecated.") + "use_cudagraph is deprecated." + ) self.cudagraph_mode = CUDAGraphMode.NONE if self.full_cuda_graph: - logger.warning("full_cuda_graph is deprecated, use " - "cudagraph_mode=FULL instead.") - if self.cudagraph_mode is not None: - raise ValueError("full_cuda_graph and cudagraph_mode are " - "mutually exclusive, prefer cudagraph_mode " - "since full_cuda_graph is deprecated.") + logger.warning( + "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead." + ) + if ( + self.cudagraph_mode is not None + and not self.cudagraph_mode.has_full_cudagraphs() + ): + raise ValueError( + "full_cuda_graph and cudagraph_mode are " + "mutually exclusive, prefer cudagraph_mode " + "since full_cuda_graph is deprecated." + ) self.cudagraph_mode = CUDAGraphMode.FULL + if self.use_inductor_graph_partition and not is_torch_equal_or_newer( + "2.9.0.dev" + ): + raise ValueError( + "use_inductor_graph_partition is only " + "supported with torch>=2.9.0.dev. Set " + "use_inductor_graph_partition=False instead." + ) + + for op in self.custom_ops: + if op[0] not in {"+", "-"} and op not in {"all", "none"}: + raise ValueError( + f"Invalid syntax '{op}' for custom op, " + "must be 'all', 'none', '+op' or '-op' " + "(where 'op' is the registered op name)" + ) + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) - if self.level in [ - CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE - ]: + if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: if self.backend == "": return "eager" if self.backend in torch_backends: @@ -480,10 +563,10 @@ class CompilationConfig: assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, - cudagraph_capture_sizes: list[int]) -> None: + def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" @@ -493,9 +576,14 @@ class CompilationConfig: # de-duplicate the sizes provided by the config dedup_sizes = list(set(self.cudagraph_capture_sizes)) if len(dedup_sizes) < len(self.cudagraph_capture_sizes): - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - cudagraph_capture_sizes, dedup_sizes) + logger.info( + ( + "cudagraph sizes specified by model runner" + " %s is overridden by config %s" + ), + cudagraph_capture_sizes, + dedup_sizes, + ) self.cudagraph_capture_sizes = dedup_sizes computed_compile_sizes = [] @@ -504,9 +592,10 @@ class CompilationConfig: self.compile_sizes = list(set(self.compile_sizes)) for x in self.compile_sizes: if isinstance(x, str): - assert x == "cudagraph_capture_sizes", \ - "Unrecognized size type in compile_sizes, " \ - f"expect 'cudagraph_capture_sizes', got {x}" + assert x == 
"cudagraph_capture_sizes", ( + "Unrecognized size type in compile_sizes, " + f"expect 'cudagraph_capture_sizes', got {x}" + ) computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) @@ -515,52 +604,157 @@ class CompilationConfig: # sort to make sure cudagraph capture sizes are in descending order self.cudagraph_capture_sizes.sort(reverse=True) - self.max_capture_size = self.cudagraph_capture_sizes[ - 0] if self.cudagraph_capture_sizes else 0 + self.max_capture_size = ( + self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0 + ) # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_capture_size + 1) - ] - for end, start in zip(self.cudagraph_capture_sizes, - self.cudagraph_capture_sizes[1:] + [0]): + self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)] + for end, start in zip( + self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0] + ): for bs in range(start, end): if bs == start: self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end - self.bs_to_padded_graph_size[ - self.max_capture_size] = self.max_capture_size + self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called only when level is # CompilationLevel.PIECEWISE assert self.level == CompilationLevel.PIECEWISE, ( "set_splitting_ops_for_v1 should only be called when " - "level is CompilationLevel.PIECEWISE") + "level is CompilationLevel.PIECEWISE" + ) + + if self.use_inductor_graph_partition: + self.set_splitting_ops_for_inductor_graph_partition() + return + + if self.pass_config.enable_attn_fusion: + # here use_inductor_graph_partition is False + self.set_splitting_ops_for_attn_fusion() + return if self.splitting_ops is None: # NOTE: When using full cudagraph, instead of setting an empty # list and capture the full cudagraph inside the flattened fx - # graph, we keep the piecewise fx graph structure but capture the - # full cudagraph outside the fx graph. This reduces some cpu - # overhead when the runtime batch_size is not cudagraph captured. - # see https://github.com/vllm-project/vllm/pull/20059 for details. - self.splitting_ops = self._attention_ops + # graph, we keep the piecewise fx graph structure but capture + # the full cudagraph outside the fx graph. This reduces some + # cpu overhead when the runtime batch_size is not cudagraph + # captured. see https://github.com/vllm-project/vllm/pull/20059 + # for details. Make a copy to avoid mutating the class-level + # list via reference. + self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: - logger.warning_once("Using piecewise compilation with empty " - "splitting_ops.") + logger.warning_once("Using piecewise compilation with empty splitting_ops") if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.warning_once( - "When compilation level is piecewise with empty " - "splitting_ops, PIECEWISE cudagraph_mode will be " - "treated as FULL cudagraph_mode. Please ensure you are " - "using attention backends that support cudagraph or set " - "cudagraph_mode to NONE explicitly if encountering " - "any problems.") + "Piecewise compilation with empty splitting_ops do not" + "contains piecewise cudagraph. Setting cudagraph_" + "mode to NONE. 
Hint: If you are using attention backends " + "that support cudagraph, consider manually setting " + "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable " + "full cudagraphs." + ) + self.cudagraph_mode = CUDAGraphMode.NONE + elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: + logger.warning_once( + "Piecewise compilation with empty splitting_ops does not " + "contain piecewise cudagraphs. Setting cudagraph_mode " + "to FULL." + ) self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] + def set_splitting_ops_for_inductor_graph_partition(self): + assert self.use_inductor_graph_partition + if self.splitting_ops is None: + self.splitting_ops = list(self._attention_ops) + + def set_splitting_ops_for_attn_fusion(self): + assert self.pass_config.enable_attn_fusion + # For dynamo-partition (non-inductor) attention fusion, + # set splitting_ops to empty to avoid splitting at attention ops + self.splitting_ops = [] + if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.warning_once( + "enable_attn_fusion is incompatible with piecewise " + "cudagraph when use_inductor_graph_partition is off. " + "In this case, splitting_ops will be set to an empty " + "list, and cudagraph_mode will be set to FULL. " + "Please ensure you are using attention backends that " + "support cudagraph or set cudagraph_mode to NONE " + "explicitly if encountering any problems." + ) + self.cudagraph_mode = CUDAGraphMode.FULL + + assert not self.splitting_ops_contain_attention(), ( + "attention ops should not be in splitting_ops " + "when enable_attn_fusion is True" + ) + def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( - op in self.splitting_ops for op in self._attention_ops) + op in self.splitting_ops for op in self._attention_ops + ) + + def is_attention_compiled_piecewise(self) -> bool: + if not self.splitting_ops_contain_attention(): + return False + + if not self.use_inductor_graph_partition: + # Dynamo-level FX split case + return self.level == CompilationLevel.PIECEWISE + + # Inductor partition case + return ( + self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor" + ) + + def custom_op_log_check(self): + """ + This method logs the enabled/disabled custom ops and checks that the + passed custom_ops field only contains relevant ops. + It is called at the end of set_current_vllm_config, + after the custom ops have been instantiated. + """ + + if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0: + logger.debug("No custom ops found in model.") + return + + logger.debug("enabled custom ops: %s", self.enabled_custom_ops) + logger.debug("disabled custom ops: %s", self.disabled_custom_ops) + + all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops + for op in self.custom_ops: + if op in {"all", "none"}: + continue + + assert op[0] in {"+", "-"}, ( + "Invalid custom op syntax (should be checked during init)" + ) + + # check if op name exists in model + op_name = op[1:] + if op_name not in all_ops_in_model: + from vllm.model_executor.custom_op import CustomOp + + # Does op exist at all or is it just not present in this model? + # Note: Only imported op classes appear in the registry.
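# Illustrative sketch of the custom_ops syntax validated in __post_init__ and
# reported on by custom_op_log_check(); "rms_norm" and "silu_and_mul" are example
# registered op names and may not be present in every model.
from vllm.config import CompilationConfig

CompilationConfig(custom_ops=["none", "+rms_norm"])     # disable all, then re-enable rms_norm
CompilationConfig(custom_ops=["all", "-silu_and_mul"])  # enable all, then disable silu_and_mul
# CompilationConfig(custom_ops=["rms_norm"]) raises ValueError: every entry must be
# "all", "none", or an op name prefixed with "+" or "-".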
+ missing_str = ( + "doesn't exist (or wasn't imported/registered)" + if op_name not in CustomOp.op_registry + else "not present in model" + ) + + enable_str = "enabling" if op[0] == "+" else "disabling" + logger.warning_once( + "Op '%s' %s, %s with '%s' has no effect", + op_name, + missing_str, + enable_str, + op, + ) diff --git a/vllm/config/device.py b/vllm/config/device.py new file mode 100644 index 0000000000000..4b66424795413 --- /dev/null +++ b/vllm/config/device.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import Any, Literal, Optional, Union + +import torch +from pydantic import ConfigDict, SkipValidation +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class DeviceConfig: + """Configuration for the device to use for vLLM execution.""" + + device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" + device_type: str = field(init=False) + """Device type from the current platform. This is set in + `__post_init__`.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.device == "auto": + # Automated device type detection + from vllm.platforms import current_platform + + self.device_type = current_platform.device_type + if not self.device_type: + raise RuntimeError( + "Failed to infer device type, please set " + "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " + "to turn on verbose logging to help debug the issue." 
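# Rough sketch of how DeviceConfig resolves device_type in __post_init__ (assumed
# usage): "auto" defers to the detected platform, while explicit strings or
# torch.device objects are used as given.
import torch

from vllm.config.device import DeviceConfig

auto_cfg = DeviceConfig()                # device="auto": device_type from current_platform
cuda_cfg = DeviceConfig(device="cuda")   # explicit device string
dev_cfg = DeviceConfig(device=torch.device("cuda:0"))  # torch.device is also accepted
assert cuda_cfg.device_type == "cuda"
assert cuda_cfg.device == torch.device("cuda")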
+ ) + else: + # Device type is assigned explicitly + if isinstance(self.device, str): + self.device_type = self.device + elif isinstance(self.device, torch.device): + self.device_type = self.device.type + + # Some device types require processing inputs on CPU + if self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py new file mode 100644 index 0000000000000..1c6bdffa1281d --- /dev/null +++ b/vllm/config/kv_events.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py new file mode 100644 index 0000000000000..b33294fd66f78 --- /dev/null +++ b/vllm/config/kv_transfer.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import uuid +from dataclasses import field +from typing import Any, Literal, Optional, get_args + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. Choices are + 'cuda' and 'cpu'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. 
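# Minimal sketch wiring up the KVEventsConfig defined above; the ZMQ endpoints and
# topic are placeholder values.
from vllm.config.kv_events import KVEventsConfig

kv_events_config = KVEventsConfig(
    enable_kv_cache_events=True,
    publisher="zmq",
    endpoint="tcp://*:5557",
    replay_endpoint="tcp://*:5558",  # optional; serves missed events on request
    buffer_steps=10_000,
    topic="kv-events",
)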
Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + P2pNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}" + ) + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError( + "Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}" + ) + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) diff --git a/vllm/config/load.py b/vllm/config/load.py new file mode 100644 index 0000000000000..aa35bc63d5d10 --- /dev/null +++ b/vllm/config/load.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic import Field, field_validator +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.model_executor.model_loader import LoadFormats + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +else: + LoadFormats = Any + TensorizerConfig = Any + +logger = init_logger(__name__) + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormats] = "auto" + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the 
weights in the safetensors format.\n + - "npcache" will load the weights in pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models. + - Other custom values can be supported via plugins.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + safetensors_load_strategy: str = "lazy" + """Specifies the loading strategy for safetensors weights. + - "lazy" (default): Weights are memory-mapped from the file. This enables + on-demand loading and is highly efficient for models on local storage. + - "eager": The entire file is read into CPU memory upfront before loading. + This is recommended for models on network filesystems (e.g., Lustre, NFS) + as it avoids inefficient random reads, significantly speeding up model + initialization. However, it uses more CPU RAM. + - "torchao": Weights are loaded in upfront and then reconstructed + into torchao tensor subclasses. This is used when the checkpoint + was quantized using torchao and saved using safetensors. + Needs torchao >= 0.14.0 + """ + model_loader_extra_config: Union[dict, TensorizerConfig] = Field( + default_factory=dict + ) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + device: Optional[str] = None + """Device to which model weights will be loaded, default to + device_config.device""" + ignore_patterns: Union[list[str], str] = Field( + default_factory=lambda: ["original/**/*"] + ) + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
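# Illustrative LoadConfig combining the fields documented above; the paths and the
# device mapping are placeholders. pt_load_map_location mirrors torch.load's
# map_location argument.
from vllm.config.load import LoadConfig

load_config = LoadConfig(
    load_format="safetensors",
    safetensors_load_strategy="eager",          # e.g. for weights on NFS/Lustre
    download_dir="/tmp/vllm-weights",           # placeholder path
    pt_load_map_location={"cuda:1": "cuda:0"},  # remap devices for .pt checkpoints
)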
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + @field_validator("load_format", mode="after") + def _lowercase_load_format(cls, load_format: str) -> str: + return load_format.lower() + + @field_validator("ignore_patterns", mode="after") + def _validate_ignore_patterns( + cls, ignore_patterns: Union[list[str], str] + ) -> Union[list[str], str]: + if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + ignore_patterns, + ) + + return ignore_patterns diff --git a/vllm/config/lora.py b/vllm/config/lora.py new file mode 100644 index 0000000000000..c531618a186d9 --- /dev/null +++ b/vllm/config/lora.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union + +import torch +from pydantic import ConfigDict, Field, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.config.cache import CacheConfig +else: + ModelConfig = Any + CacheConfig = Any + +logger = init_logger(__name__) + +LoRADType = Literal["auto", "float16", "bfloat16"] +MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] +LoRAExtraVocabSize = Literal[256, 512] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class LoRAConfig: + """Configuration for LoRA.""" + + max_lora_rank: MaxLoRARanks = 16 + """Max LoRA rank.""" + max_loras: int = Field(default=1, ge=1) + """Max number of LoRAs in a single batch.""" + fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ + max_cpu_loras: Optional[int] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" + lora_extra_vocab_size: LoRAExtraVocabSize = Field( + default=256, + deprecated=( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out." + ), + ) + """(Deprecated) Maximum size of extra vocabulary that can be present in a + LoRA adapter. Will be removed in v0.12.0.""" + lora_vocab_padding_size: ClassVar[int] = ( + current_platform.get_lora_vocab_padding_size() + ) + default_mm_loras: Optional[dict[str, str]] = None + """Dictionary mapping specific modalities to LoRA model paths; this field + is only applicable to multimodal models and should be leveraged when a + model always expects a LoRA to be active when a given modality is present. 
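# Illustrative LoRAConfig using the fields above; the adapter paths passed to
# default_mm_loras are placeholders for a multimodal model.
from vllm.config.lora import LoRAConfig

lora_config = LoRAConfig(
    max_lora_rank=32,
    max_loras=2,
    default_mm_loras={
        "image": "/path/to/image-lora",  # hypothetical adapter paths
        "audio": "/path/to/audio-lora",
    },
)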
+ Note that currently, if a request provides multiple additional + modalities, each of which have their own LoRA, we do NOT apply + default_mm_loras because we currently only support one lora adapter + per prompt. When run in offline mode, the lora IDs for n modalities + will be automatically assigned to 1-n with the names of the modalities + in alphabetic order.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) + + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + @model_validator(mode="after") + def _validate_lora_config(self) -> Self: + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})" + ) + + return self + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError("V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) diff --git a/vllm/config/model.py b/vllm/config/model.py new file mode 100644 index 0000000000000..d0c027e47675c --- /dev/null +++ b/vllm/config/model.py @@ -0,0 +1,2130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import json +import warnings +from dataclasses import InitVar, field +from importlib.util import find_spec +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, + cast, + get_args, +) + +import torch +from pydantic import ConfigDict, SkipValidation, field_validator, model_validator +from pydantic.dataclasses import dataclass +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE + +import vllm.envs as envs +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig +from vllm.config.pooler import PoolerConfig +from vllm.config.scheduler import RunnerType +from vllm.config.utils import assert_hashable, config, getattr_iter +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.transformers_utils.config import ( + ConfigFormat, + get_config, + get_hf_image_processor_config, + get_hf_text_config, + get_pooling_config, + get_sentence_transformer_tokenizer_config, + is_encoder_decoder, + is_interleaved, + try_get_generation_config, + try_get_safetensors_metadata, + try_get_tokenizer_config, + uses_mrope, +) +from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri +from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.utils import 
LayerBlockType, LazyLoader, common_broadcastable_dtype + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + import vllm.model_executor.layers.quantization as me_quant + import vllm.model_executor.models as me_models + from vllm.config.load import LoadConfig + from vllm.config.parallel import ParallelConfig + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.v1.sample.logits_processor import LogitsProcessor +else: + PretrainedConfig = Any + + me_quant = LazyLoader( + "model_executor", globals(), "vllm.model_executor.layers.quantization" + ) + me_models = LazyLoader("model_executor", globals(), "vllm.model_executor.models") + LoadConfig = Any + ParallelConfig = Any + QuantizationMethods = Any + LogitsProcessor = Any + +logger = init_logger(__name__) + +RunnerOption = Literal["auto", RunnerType] +ConvertType = Literal["none", "embed", "classify", "reward"] +ConvertOption = Literal["auto", ConvertType] +TaskOption = Literal[ + "auto", + "generate", + "embedding", + "embed", + "classify", + "score", + "reward", + "transcription", + "draft", +] +TokenizerMode = Literal["auto", "slow", "mistral", "custom"] +ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] +LogprobsMode = Literal[ + "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" +] +HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig]] +ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] + +_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { + "generate": ["generate", "transcription"], + "pooling": ["embedding", "embed", "classify", "score", "reward"], + "draft": ["draft"], +} + +_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { + "generate": [], + "pooling": ["embed", "classify", "reward"], + "draft": [], +} + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class ModelConfig: + """Configuration for the model.""" + + model: str = "Qwen/Qwen3-0.6B" + """Name or path of the Hugging Face model to use. It is also used as the + content for `model_name` tag in metrics output when `served_model_name` is + not specified.""" + runner: RunnerOption = "auto" + """The type of model runner to use. Each vLLM instance only supports one + model runner, even if the same model can be used for multiple types.""" + convert: ConvertOption = "auto" + """Convert the model using adapters defined in + [vllm.model_executor.models.adapters][]. The most common use case is to + adapt a text generation model to be used for pooling tasks.""" + task: Optional[TaskOption] = None + """[DEPRECATED] The task to use the model for. If the model supports more + than one model runner, this is used to select which model runner to run. + + Note that the model may support other tasks using the same model runner. + """ + tokenizer: SkipValidation[str] = None # type: ignore + """Name or path of the Hugging Face tokenizer to use. 
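# Sketch of the runner/convert pair that replaces the deprecated task option,
# following the _RUNNER_TASKS/_RUNNER_CONVERTS tables above. Constructing a
# ModelConfig resolves and downloads the Hugging Face config, so this is only a
# shape-of-the-API illustration.
from vllm.config.model import ModelConfig

# Previously: ModelConfig(model="Qwen/Qwen3-0.6B", task="embed")
model_config = ModelConfig(
    model="Qwen/Qwen3-0.6B",
    runner="pooling",  # run the pooling model runner
    convert="embed",   # adapt the generative model into an embedding model
)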
If unspecified, model + name or path will be used.""" + tokenizer_mode: TokenizerMode = "auto" + """Tokenizer mode:\n + - "auto" will use the fast tokenizer if available.\n + - "slow" will always use the slow tokenizer.\n + - "mistral" will always use the tokenizer from `mistral_common`.\n + - "custom" will use --tokenizer to select the preregistered tokenizer.""" + trust_remote_code: bool = False + """Trust remote code (e.g., from HuggingFace) when downloading the model + and tokenizer.""" + dtype: Union[ModelDType, torch.dtype] = "auto" + """Data type for model weights and activations:\n + - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 + precision for BF16 models.\n + - "half" for FP16. Recommended for AWQ quantization.\n + - "float16" is the same as "half".\n + - "bfloat16" for a balance between precision and range.\n + - "float" is shorthand for FP32 precision.\n + - "float32" for FP32 precision.""" + seed: Optional[int] = None + """Random seed for reproducibility. Initialized to None in V0, but + initialized to 0 in V1.""" + hf_config_path: Optional[str] = None + """Name or path of the Hugging Face config to use. If unspecified, model + name or path will be used.""" + allowed_local_media_path: str = "" + """Allowing API requests to read local images or videos from directories + specified by the server file system. This is a security risk. Should only + be enabled in trusted environments.""" + allowed_media_domains: Optional[list[str]] = None + """If set, only media URLs that belong to this domain can be used for + multi-modal inputs. """ + revision: Optional[str] = None + """The specific model version to use. It can be a branch name, a tag name, + or a commit id. If unspecified, will use the default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the model code on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + rope_scaling: dict[str, Any] = field(default_factory=dict) + """RoPE scaling configuration. For example, + `{"rope_type":"dynamic","factor":2.0}`.""" + rope_theta: Optional[float] = None + """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE + theta improves the performance of the scaled model.""" + tokenizer_revision: Optional[str] = None + """The specific revision to use for the tokenizer on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + max_model_len: SkipValidation[int] = None # type: ignore + """Model context length (prompt and output). If unspecified, will be + automatically derived from the model config. + + When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable + format. Examples:\n + - 1k -> 1000\n + - 1K -> 1024\n + - 25.6k -> 25,600""" + spec_target_max_model_len: Optional[int] = None + """Specify the maximum length for spec decoding draft models.""" + quantization: SkipValidation[Optional[QuantizationMethods]] = None + """Method used to quantize the weights. If `None`, we first check the + `quantization_config` attribute in the model config file. If that is + `None`, we assume the model weights are not quantized and use `dtype` to + determine the data type of the weights.""" + enforce_eager: bool = False + """Whether to always use eager-mode PyTorch. If True, we will disable CUDA + graph and always execute the model in eager mode. 
If False, we will use + CUDA graph and eager execution in hybrid for maximal performance and + flexibility.""" + max_logprobs: int = 20 + """Maximum number of log probabilities to return when `logprobs` is + specified in `SamplingParams`. The default value comes the default for the + OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * + vocab_size) logprobs are allowed to be returned and it may cause OOM.""" + logprobs_mode: LogprobsMode = "raw_logprobs" + """Indicates the content returned in the logprobs and prompt_logprobs. + Supported mode: + 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. + """ + disable_sliding_window: bool = False + """Whether to disable sliding window. If True, we will disable the sliding + window functionality of the model, capping to sliding window size. If the + model does not support sliding window, this argument is ignored.""" + disable_cascade_attn: bool = False + """Disable cascade attention for V1. While cascade attention does not + change the mathematical correctness, disabling it could be useful for + preventing potential numerical issues. Note that even if this is set to + False, cascade attention will be only used when the heuristic tells that + it's beneficial.""" + skip_tokenizer_init: bool = False + """Skip initialization of tokenizer and detokenizer. Expects valid + `prompt_token_ids` and `None` for prompt from the input. The generated + output will contain token ids.""" + enable_prompt_embeds: bool = False + """If `True`, enables passing text embeddings as inputs via the + `prompt_embeds` key. Note that enabling this will double the time required + for graph compilation.""" + served_model_name: Optional[Union[str, list[str]]] = None + """The model name(s) used in the API. If multiple names are provided, the + server will respond to any of the provided names. The model name in the + model field of a response will be the first name in this list. If not + specified, the model name will be the same as the `--model` argument. Noted + that this name(s) will also be used in `model_name` tag content of + prometheus metrics, if multiple names provided, metrics tag will take the + first one.""" + config_format: Union[str, ConfigFormat] = "auto" + """The format of the model config to load:\n + - "auto" will try to load the config in hf format if available else it + will try to load in mistral format.\n + - "hf" will load the config in hf format.\n + - "mistral" will load the config in mistral format.""" + hf_token: Optional[Union[bool, str]] = None + """The token to use as HTTP bearer authorization for remote files . If + `True`, will use the token generated when running `huggingface-cli login` + (stored in `~/.huggingface`).""" + hf_overrides: HfOverrides = field(default_factory=dict) + """If a dictionary, contains arguments to be forwarded to the Hugging Face + config. If a callable, it is called to update the HuggingFace config.""" + logits_processor_pattern: Optional[str] = None + """Optional regex pattern specifying valid logits processor qualified names + that can be passed with the `logits_processors` extra completion argument. + Defaults to `None`, which allows no processors.""" + generation_config: str = "auto" + """The folder path to the generation config. 
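# Sketch of the two hf_overrides forms accepted above (dict or callable); the
# overridden attribute is only an example.
from vllm.config.model import ModelConfig

# Dict form: keys are merged into the loaded Hugging Face config.
cfg_a = ModelConfig(
    model="Qwen/Qwen3-0.6B",
    hf_overrides={"max_position_embeddings": 8192},
)

# Callable form: receives the PretrainedConfig and returns the modified config.
def _patch_hf_config(hf_config):
    hf_config.max_position_embeddings = 8192
    return hf_config

cfg_b = ModelConfig(model="Qwen/Qwen3-0.6B", hf_overrides=_patch_hf_config)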
Defaults to `"auto"`, the + generation config will be loaded from model path. If set to `"vllm"`, no + generation config is loaded, vLLM defaults will be used. If set to a folder + path, the generation config will be loaded from the specified folder path. + If `max_new_tokens` is specified in generation config, then it sets a + server-wide limit on the number of output tokens for all requests.""" + override_generation_config: dict[str, Any] = field(default_factory=dict) + """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If + used with `--generation-config auto`, the override parameters will be + merged with the default config from the model. If used with + `--generation-config vllm`, only the override parameters are used.""" + enable_sleep_mode: bool = False + """Enable sleep mode for the engine (only cuda platform is supported).""" + model_impl: Union[str, ModelImpl] = "auto" + """Which implementation of the model to use:\n + - "auto" will try to use the vLLM implementation, if it exists, and fall + back to the Transformers implementation if no vLLM implementation is + available.\n + - "vllm" will use the vLLM model implementation.\n + - "transformers" will use the Transformers model implementation.\n + - "terratorch" will use the TerraTorch model implementation. + """ + override_attention_dtype: Optional[str] = None + """Override dtype for attention""" + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + """One or more logits processors' fully-qualified class names or class + definitions""" + io_processor_plugin: Optional[str] = None + """IOProcessor plugin name to load at model startup""" + + # Pooler config + pooler_config: Optional[PoolerConfig] = None + """Pooler config which controls the behaviour of output pooling in pooling + models.""" + override_pooler_config: Optional[Union[dict, PoolerConfig]] = None + """[DEPRECATED] Use `pooler_config` instead. This field will be removed in + v0.12.0 or v1.0.0, whichever is sooner.""" + + # Multimodal config and init vars + multimodal_config: Optional[MultiModalConfig] = None + """Configuration for multimodal model. If `None`, this will be inferred + from the architecture of `self.model`.""" + limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int, dict[str, int]]]]] = None + media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None + mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None + mm_processor_cache_gb: InitVar[Optional[float]] = None + mm_processor_cache_type: InitVar[Optional[MMCacheType]] = None + mm_shm_cache_max_object_size_mb: InitVar[Optional[int]] = None + mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None + interleave_mm_strings: InitVar[Optional[bool]] = None + skip_mm_profiling: InitVar[Optional[bool]] = None + video_pruning_rate: InitVar[Optional[float]] = None + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.max_model_len) + factors.append(self.max_logprobs) + factors.append(self.disable_sliding_window) + factors.append(self.trust_remote_code) + factors.append(self.generation_config) + factors.append(self.model_impl) + factors.append(self.override_generation_config) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + factors.append(self.video_pruning_rate) + + # hf_config can control how the model looks! + try: + hf_config_json = self.hf_config.to_json_string(use_diff=False) + except TypeError: + from transformers import PretrainedConfig + + from vllm.utils.jsontree import json_map_leaves + + # Handle nested HF configs with unserializable values gracefully + hf_config_json = ( + json.dumps( + json_map_leaves( + lambda v: v.to_dict() + if isinstance(v, PretrainedConfig) + else str(v), + self.hf_config.to_dict(), + ), + indent=2, + sort_keys=True, + ) + + "\n" + ) + + factors.append(hf_config_json) + + str_factors = str(factors) + assert_hashable(str_factors) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def _update_nested( + self, + target: Union["PretrainedConfig", dict[str, Any]], + updates: dict[str, Any], + ) -> None: + """Recursively updates a config or dict with nested updates.""" + for key, value in updates.items(): + if isinstance(value, dict): + # Get the nested target + if isinstance(target, dict): + nested_target = target.get(key) + else: + nested_target = getattr(target, key, None) + + # If nested target exists and can be updated recursively + if nested_target is not None and ( + isinstance(nested_target, dict) + or hasattr(nested_target, "__dict__") + ): + self._update_nested(nested_target, value) + continue + + # Set the value (base case) + if isinstance(target, dict): + target[key] = value + else: + setattr(target, key, value) + + def _apply_dict_overrides( + self, + config: "PretrainedConfig", + overrides: dict[str, Any], + ) -> None: + """Apply dict overrides, handling both nested configs and dict values.""" + from transformers import PretrainedConfig + + for key, value in overrides.items(): + attr = getattr(config, key, None) + if attr is not None and isinstance(attr, PretrainedConfig): + # It's a nested config - recursively update it + self._update_nested(attr, value) + else: + # It's a dict-valued parameter - set it directly + setattr(config, key, value) + + def __post_init__( + self, + # Multimodal config init vars + limit_mm_per_prompt: Optional[dict[str, int]], + media_io_kwargs: Optional[dict[str, dict[str, Any]]], + mm_processor_kwargs: Optional[dict[str, Any]], + mm_processor_cache_gb: Optional[float], + mm_processor_cache_type: Optional[MMCacheType], + mm_shm_cache_max_object_size_mb: Optional[int], + mm_encoder_tp_mode: Optional[MMEncoderTPMode], + interleave_mm_strings: Optional[bool], + skip_mm_profiling: Optional[bool], + video_pruning_rate: Optional[float], + ) -> None: + # Set the default seed to 0 in V1. + # NOTE(woosuk): In V0, we set the default seed to None because the + # driver worker shares the same process as the user process, and thus + # setting a seed affects the user process as well. + # In V1, we use separate processes for workers (unless + # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here + # doesn't affect the user process. 
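# Sketch of the nested-override behaviour implemented by _update_nested and
# _apply_dict_overrides above: dict-valued hf_overrides entries are merged into an
# existing nested config rather than replacing it wholesale. The model name is
# hypothetical (any model whose HF config carries a nested text_config).
from vllm.config.model import ModelConfig

cfg = ModelConfig(
    model="org/some-multimodal-model",  # hypothetical
    hf_overrides={"text_config": {"rope_scaling": {"rope_type": "dynamic", "factor": 2.0}}},
)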
However, without a consistent seed, + # different tensor parallel workers would sample different tokens, + # leading to inconsistent results. + if envs.VLLM_USE_V1 and self.seed is None: + self.seed = 0 + if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.warning( + "The global random seed is set to %d. Since " + "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " + "affect the random state of the Python process that " + "launched vLLM.", + self.seed, + ) + + # Keep set served_model_name before maybe_model_redirect(self.model) + self.served_model_name = get_served_model_name( + self.model, self.served_model_name + ) + self.model = maybe_model_redirect(self.model) + # The tokenizer is consistent with the model by default. + if self.tokenizer is None: + self.tokenizer = self.model + if self.tokenizer_revision is None: + self.tokenizer_revision = self.revision + self.tokenizer = maybe_model_redirect(self.tokenizer) + + if isinstance(self.hf_config_path, str): + self.hf_config_path = maybe_model_redirect(self.hf_config_path) + + if callable(self.hf_overrides): + hf_overrides_kw = {} + hf_overrides_fn = self.hf_overrides + dict_overrides: dict[str, Any] = {} + else: + # Separate dict overrides from flat ones + # We'll determine how to apply dict overrides after loading the config + hf_overrides_kw = {} + dict_overrides = {} + for key, value in self.hf_overrides.items(): + if isinstance(value, dict): + dict_overrides[key] = value + else: + hf_overrides_kw[key] = value + hf_overrides_fn = None + + if self.rope_scaling: + hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`" + ) + warnings.warn(DeprecationWarning(msg), stacklevel=2) + if self.rope_theta is not None: + hf_override = {"rope_theta": self.rope_theta} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-theta` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`" + ) + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) + + if ( + (backend := envs.VLLM_ATTENTION_BACKEND) + and backend == "FLASHINFER" + and find_spec("flashinfer") is None + ): + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found. See " + "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 + "for instructions on how to install it." 
+ ) + + from vllm.platforms import current_platform + + if self.override_attention_dtype is not None and not current_platform.is_rocm(): + warnings.warn( + "override-attention-dtype is set but not using ROCm platform", + stacklevel=2, + ) + + if self.enable_sleep_mode and not current_platform.is_sleep_mode_available(): + raise ValueError("Sleep mode is not supported on current platform.") + + hf_config = get_config( + self.hf_config_path or self.model, + self.trust_remote_code, + self.revision, + self.code_revision, + self.config_format, + hf_overrides_kw=hf_overrides_kw, + hf_overrides_fn=hf_overrides_fn, + ) + + self.hf_config = hf_config + if dict_overrides: + self._apply_dict_overrides(hf_config, dict_overrides) + self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr( + self.hf_text_config, "attention_chunk_size", None + ) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision + ) + + architectures = self.architectures + registry = self.registry + is_generative_model = registry.is_text_generation_model(architectures, self) + is_pooling_model = registry.is_pooling_model(architectures, self) + + def _task_to_convert(task: TaskOption) -> ConvertType: + if task == "embedding" or task == "embed": + return "embed" + if task == "classify": + return "classify" + if task == "reward": + return "reward" + if task == "score": + new_task = self._get_default_pooling_task(architectures) + return "classify" if new_task == "classify" else "embed" + + return "none" + + if self.task is not None: + runner: RunnerOption = "auto" + convert: ConvertOption = "auto" + msg_prefix = ( + "The 'task' option has been deprecated and will be " + "removed in v0.13.0 or v1.0, whichever comes first." + ) + msg_hint = "Please remove this option." + + is_generative_task = self.task in _RUNNER_TASKS["generate"] + is_pooling_task = self.task in _RUNNER_TASKS["pooling"] + + if is_generative_model and is_pooling_model: + if is_generative_task: + runner = "generate" + convert = "auto" + msg_hint = ( + "Please replace this option with `--runner " + "generate` to continue using this model " + "as a generative model." + ) + elif is_pooling_task: + runner = "pooling" + convert = "auto" + msg_hint = ( + "Please replace this option with `--runner " + "pooling` to continue using this model " + "as a pooling model." + ) + else: # task == "auto" + pass + elif is_generative_model or is_pooling_model: + if is_generative_task: + runner = "generate" + convert = "auto" + msg_hint = "Please remove this option" + elif is_pooling_task: + runner = "pooling" + convert = _task_to_convert(self.task) + msg_hint = ( + "Please replace this option with `--convert " + f"{convert}` to continue using this model " + "as a pooling model." + ) + else: # task == "auto" + pass + else: + debug_info = { + "architectures": architectures, + "is_generative_model": is_generative_model, + "is_pooling_model": is_pooling_model, + } + raise AssertionError( + "The model should be a generative or " + "pooling model when task is set to " + f"{self.task!r}. 
Found: {debug_info}" + ) + + self.runner = runner + self.convert = convert + + msg = f"{msg_prefix} {msg_hint}" + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + self.runner_type = self._get_runner_type(architectures, self.runner) + self.convert_type = self._get_convert_type( + architectures, self.runner_type, self.convert + ) + + if self.runner_type == "generate" and not is_generative_model: + generate_converts = _RUNNER_CONVERTS["generate"] + if self.convert_type not in generate_converts: + # Currently we don't have any converters for generative models + raise ValueError("This model does not support `--runner generate`.") + if self.runner_type == "pooling" and not is_pooling_model: + pooling_converts = _RUNNER_CONVERTS["pooling"] + if self.convert_type not in pooling_converts: + convert_option = "<" + "|".join(pooling_converts) + ">" + raise ValueError( + "This model does not support `--runner pooling`. " + f"You can pass `--convert {convert_option} to adapt " + "it into a pooling model." + ) + + # Note: Initialize these attributes early because transformers fallback + # may fail to load dynamic modules in child processes + model_info, arch = registry.inspect_model_cls(architectures, self) + self._model_info = model_info + self._architecture = arch + logger.info("Resolved architecture: %s", arch) + + # Init pooler config if needed + if self.runner_type == "pooling": + if self.override_pooler_config is not None: + logger.warning_once( + "`override_pooler_config` is deprecated and will be " + "removed in v0.12.0 or v1.0.0, whichever is sooner. " + "Please use `pooler_config` instead." + ) + + if isinstance(self.override_pooler_config, dict): + self.pooler_config = PoolerConfig(**self.override_pooler_config) + else: + self.pooler_config = self.override_pooler_config + + if self.pooler_config is None: + self.pooler_config = PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(self.pooler_config, k) is None: + setattr(self.pooler_config, k, v) + + default_pooling_type = self._model_info.default_pooling_type + if self.pooler_config.pooling_type is None: + self.pooler_config.pooling_type = default_pooling_type + + self.dtype: torch.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) + + # Interleaved attention is not supported by some backends in V0 + if ( + not self.disable_sliding_window + and is_interleaved(self.hf_text_config) + and not envs.VLLM_USE_V1 + and (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER") + ): + logger.warning_once( + "%s has interleaved attention, which is currently not " + "supported by the %s backend. Disabling sliding window and " + "capping the max length to the sliding window size (%d).", + self.hf_text_config.model_type, + backend, + self.hf_text_config.sliding_window, + ) + self.disable_sliding_window = True + + self.original_max_model_len = self.max_model_len + self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + # Init multimodal config if needed + if self._model_info.supports_multimodal: + if ( + mm_encoder_tp_mode == "data" + and not self._model_info.supports_multimodal_encoder_tp_data + ): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`." 
+ ) + mm_encoder_tp_mode = "weights" + + mm_config_kwargs = dict( + limit_per_prompt=limit_mm_per_prompt, + media_io_kwargs=media_io_kwargs, + mm_processor_kwargs=mm_processor_kwargs, + mm_processor_cache_gb=mm_processor_cache_gb, + mm_processor_cache_type=mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, + mm_encoder_tp_mode=mm_encoder_tp_mode, + interleave_mm_strings=interleave_mm_strings, + skip_mm_profiling=skip_mm_profiling, + video_pruning_rate=video_pruning_rate, + ) + + mm_config_kwargs = { + k: v for k, v in mm_config_kwargs.items() if v is not None + } + + self.multimodal_config = MultiModalConfig(**mm_config_kwargs) + + if self.disable_sliding_window: + # Set after get_and_verify_max_len to ensure that max_model_len + # can be correctly capped to sliding window size + self.hf_text_config.sliding_window = None + + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + # Avoid running try_verify_and_update_config multiple times + self.config_updated = False + + self._verify_quantization() + self._verify_cuda_graph() + self._verify_bnb_config() + + @field_validator("quantization", mode="before") + @classmethod + def validate_quantization_before(cls, value: Any) -> Any: + if isinstance(value, str): + return value.lower() + return value + + @model_validator(mode="after") + def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + if not isinstance(self.tokenizer, str): + raise ValueError("tokenizer must be a string after __post_init__.") + if not isinstance(self.max_model_len, int): + raise ValueError("max_model_len must be an integer after __post_init__.") + return self + + def _get_transformers_backend_cls(self) -> str: + """Determine which Transformers backend class will be used if + `model_impl` is set to `transformers` or `auto`.""" + prefix = "Transformers" + prefix += "MoE" if self.get_num_experts() > 1 else "" + # Check if the architecture we're wrapping has defaults + runner = None + convert = None + if defaults := try_match_architecture_defaults(self.architectures[0]): + _, (runner, convert) = defaults + # Overwrite with user-specified values + if self.runner != "auto": + runner = self.runner + if self.convert not in {"auto", "none"}: + convert = self.convert + # Fall back to default values if still not set + if runner is None: + runner = "generate" + if convert in {None, "none"}: + convert = "embed" + # Resolve Transformers backend pooling classes + if runner == "pooling": + if convert == "embed": + return prefix + "EmbeddingModel" + if convert == "classify": + return prefix + "ForSequenceClassification" + # Resolve Transformers backend generate classes + if self.hf_config != self.hf_text_config: + # If 'hf_text_config' is the same as 'hf_config'. If not, it is + # probably a composite config, i.e. 
multimodal + return prefix + "ForMultimodalLM" + return prefix + "ForCausalLM" + + def using_transformers_backend(self) -> bool: + """Check if the model is using the Transformers backend class.""" + used_cls = self._model_info.architecture + transformers_backend_cls = self._get_transformers_backend_cls() + return used_cls == transformers_backend_cls + + @property + def registry(self): + return me_models.ModelRegistry + + @property + def architectures(self) -> list[str]: + return getattr(self.hf_config, "architectures", []) + + @property + def architecture(self) -> str: + """The architecture vllm actually used.""" + return self._architecture + + def maybe_pull_model_tokenizer_for_runai(self, model: str, tokenizer: str) -> None: + """Pull model/tokenizer from Object Storage to temporary + directory when needed. + + Args: + model: Model name or path + tokenizer: Tokenizer name or path + """ + + if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)): + return + + if is_runai_obj_uri(model): + object_storage_model = ObjectStorageModel(url=model) + object_storage_model.pull_files( + model, allow_pattern=["*.model", "*.py", "*.json"] + ) + self.model_weights = model + self.model = object_storage_model.dir + + # If tokenizer is same as model, download to same directory + if model == tokenizer: + object_storage_model.pull_files( + model, + ignore_pattern=[ + "*.pt", + "*.safetensors", + "*.bin", + "*.tensors", + "*.pth", + ], + ) + self.tokenizer = object_storage_model.dir + return + + # Only download tokenizer if needed and not already handled + if is_runai_obj_uri(tokenizer): + object_storage_tokenizer = ObjectStorageModel(url=tokenizer) + object_storage_tokenizer.pull_files( + model, + ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors", "*.pth"], + ) + self.tokenizer = object_storage_tokenizer.dir + + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config(self.model, self.revision) + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) + if tokenizer_mode not in get_args(TokenizerMode): + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + f"one of {get_args(TokenizerMode)}." + ) + self.tokenizer_mode = tokenizer_mode + + def _get_default_runner_type( + self, + architectures: list[str], + ) -> RunnerType: + registry = self.registry + + # Some Sentence Transformers models use *ForCausalLM archs + if get_pooling_config(self.model, self.revision): + return "pooling" + + for arch in architectures: + if arch in registry.get_supported_archs(): + if registry.is_pooling_model(architectures, self): + return "pooling" + if registry.is_text_generation_model(architectures, self): + return "generate" + + match = try_match_architecture_defaults(arch) + if match: + _, (runner_type, _) = match + return runner_type + + return "generate" + + def _get_runner_type( + self, + architectures: list[str], + runner: RunnerOption, + ) -> RunnerType: + if runner != "auto": + return runner + + runner_type = self._get_default_runner_type(architectures) + + # Don't log the most common case + if runner_type != "generate": + logger.info( + "Resolved `--runner auto` to `--runner %s`. 
" + "Pass the value explicitly to silence this message.", + runner_type, + ) + + return runner_type + + def _get_default_convert_type( + self, + architectures: list[str], + runner_type: RunnerType, + ) -> ConvertType: + registry = self.registry + + for arch in architectures: + if arch in registry.get_supported_archs(): + if runner_type == "generate" and registry.is_text_generation_model( + architectures, self + ): + return "none" + if runner_type == "pooling" and registry.is_pooling_model( + architectures, self + ): + return "none" + + match = try_match_architecture_defaults(arch, runner_type=runner_type) + if match: + _, (_, convert_type) = match + return convert_type + + # This is to handle Sentence Transformers models that use *ForCausalLM + # and also multi-modal pooling models which are not defined as + # Sentence Transformers models + if runner_type == "pooling": + return "embed" + + return "none" + + def _get_convert_type( + self, + architectures: list[str], + runner_type: RunnerType, + convert: ConvertOption, + ) -> ConvertType: + if convert != "auto": + return convert + + convert_type = self._get_default_convert_type(architectures, runner_type) + + # Don't log the most common case + if convert_type != "none": + logger.info( + "Resolved `--convert auto` to `--convert %s`. " + "Pass the value explicitly to silence this message.", + convert_type, + ) + + return convert_type + + def _get_default_pooling_task( + self, + architectures: list[str], + ) -> Literal["embed", "classify", "reward"]: + if self.registry.is_cross_encoder_model(architectures, self): + return "classify" + + for arch in architectures: + match = try_match_architecture_defaults(arch, runner_type="pooling") + if match: + _, (_, convert_type) = match + assert convert_type != "none" + return convert_type + + return "embed" + + def _parse_quant_hf_config(self, hf_config: PretrainedConfig): + quant_cfg = getattr(hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(hf_config, "compression_config", None) + + else: + # Set quant_method for ModelOpt models. + producer_name = quant_cfg.get("producer", {}).get("name") + if producer_name == "modelopt": + quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") + if quant_algo == "FP8": + quant_cfg["quant_method"] = "modelopt" + elif quant_algo == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + elif quant_algo is not None: + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") + + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = me_quant.QUANTIZATION_METHODS + if self.quantization is not None: + self.quantization = cast(me_quant.QuantizationMethods, self.quantization) + + # Parse quantization method from the HF model config, if available. + quant_cfg = self._parse_quant_hf_config(self.hf_config) + if quant_cfg is None and ( + text_config := getattr(self.hf_config, "text_config", None) + ): + # Check the text config as well for multi-modal models. + quant_cfg = self._parse_quant_hf_config(text_config) + + if quant_cfg is not None: + # Use the community standard 'quant_method' + quant_method = quant_cfg.get("quant_method", "").lower() + + # Normalize library names + quant_method = quant_method.replace( + "compressed_tensors", "compressed-tensors" + ) + + quant_cfg["quant_method"] = quant_method + + # Quantization methods which are overrides (i.e. 
they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + "modelopt", + "modelopt_fp4", + "petit_nvfp4", + # Ensure heavy backends are probed last to avoid unnecessary + # imports during override detection (e.g., MXFP4 imports Triton) + "mxfp4", + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built-in ones. + quantization_methods = quantization_methods + overrides + + # Detect which checkpoint is it + for name in quantization_methods: + method = me_quant.get_quantization_config(name) + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization + ) + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. + if ( + name in get_args(me_quant.QuantizationMethods) + and name not in overrides + ): + raise ValueError( + f"Quantization method {name} is an override but " + "is has not been added to the `overrides` list " + "above. This is necessary to ensure that the " + "overrides are checked in order of preference." + ) + quant_method = quantization_override + self.quantization = quantization_override + break + + quant_method = quant_method if quant_method != "" else None + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}." + ) + from vllm.platforms import current_platform + + current_platform.verify_quantization(self.quantization) + + def _verify_cuda_graph(self) -> None: + # CUDAGraph capture not supported for encoder-decoder models on ROCm + unsupported_rocm = self.is_encoder_decoder + if unsupported_rocm and not self.enforce_eager and current_platform.is_rocm(): + logger.warning( + "CUDA graph is not supported for %s on ROCm yet, fallback " + "to eager mode.", + self.hf_config.model_type, + ) + self.enforce_eager = True + + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.46.1) with 8-bit models does not + yet support CUDA graph. + # TODO Remove this when bitsandbytes supports. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = ( + getattr(self.hf_config, "quantization_config", None) is not None + ) + is_8bit = ( + self.hf_config.quantization_config.get("load_in_8bit", False) + if has_quantization_config + else False + ) + if all( + [ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ] + ): + logger.warning( + "CUDA graph is not supported on BitsAndBytes 8bit yet, " + "fallback to the eager mode." 
+ ) + + self.enforce_eager = True + + def _verify_with_expert_parallelism(self) -> None: + num_experts = self.get_num_experts() + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled." + ) + + def verify_dual_chunk_attention_config( + self, + load_config: LoadConfig, + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config, + ) + + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config" + ] = sparse_attn_config + if ( + "sparse_attention_enabled" + not in self.hf_config.dual_chunk_attention_config + ): + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled" + ] = True + + def verify_with_parallel_config( + self, + parallel_config: ParallelConfig, + ) -> None: + if parallel_config.distributed_executor_backend == "external_launcher": + assert self.seed is not None, ( + "Seed must be set when using external launcher backend to " + "make sure sampling results are the same across workers." + ) + + total_num_attention_heads = getattr( + self.hf_text_config, "num_attention_heads", 0 + ) + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size})." + ) + + if parallel_config.enable_expert_parallel: + self._verify_with_expert_parallelism() + + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if pipeline_parallel_size > 1 and not self.registry.is_pp_supported_model( + self.architectures, self + ): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface." 
+ ) + + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size from the HF text config if present.""" + return getattr(self.hf_text_config, "sliding_window", None) + + def get_vocab_size(self) -> int: + return getattr(self.hf_text_config, "vocab_size", 0) + + def get_hidden_size(self) -> int: + return getattr(self.hf_text_config, "hidden_size", 0) + + @property + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in ( + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + "deepseek_mtp", + "kimi_k2", + "longcat_flash", + ): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == "eagle": + # if the model is an EAGLE module, check for the + # underlying architecture + return ( + self.hf_text_config.model.model_type + in ("deepseek_v2", "deepseek_v3", "deepseek_v32") + and self.hf_text_config.kv_lora_rank is not None + ) + return False + + def get_head_size(self) -> int: + # TODO remove hard code + if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) + if self.use_mla: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + if hasattr(self.hf_text_config, "model_type") and ( + self.hf_text_config.model_type == "zamba2" + ): + return self.hf_text_config.attention_head_dim + + if self.is_attention_free: + return 0 + + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: + return self.hf_text_config.hidden_size_per_head + + # FIXME(woosuk): This may not be true for all models. + return ( + self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads + ) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) + + if self.hf_config.model_type == "nemotron-nas": + for block in self.hf_config.block_configs: + if not block.attention.no_op: + return ( + self.hf_config.num_attention_heads + // block.attention.n_heads_in_group + ) + + raise RuntimeError("Couldn't determine number of kv heads") + + if self.is_attention_free: + return 0 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: + """Returns the number of KV heads per GPU.""" + if self.use_mla: + # When using MLA during decode it becomes MQA + return 1 + + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads // parallel_config.tensor_parallel_size + + def get_num_experts(self) -> int: + """Returns the number of experts in the model.""" + num_expert_names = [ + "num_experts", # Jamba + "moe_num_experts", # Dbrx + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + if isinstance(num_experts, list): + # Ernie VL's remote code uses list[int]... + # The values are always the same so we just take the first one. 
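+ # e.g. (hypothetical values) num_experts == [64, 64, 64] resolves to 64.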
+ return num_experts[0] + return num_experts + + def get_layers_start_end_indices( + self, parallel_config: ParallelConfig + ) -> tuple[int, int]: + from vllm.distributed.utils import get_pp_indices + + if ( + self.hf_text_config.model_type == "deepseek_mtp" + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp" + or self.hf_config.model_type == "qwen3_next_mtp" + ): + total_num_hidden_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", 0 + ) + elif self.hf_config.model_type == "longcat_flash_mtp": + total_num_hidden_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", 1 + ) + else: + total_num_hidden_layers = getattr( + self.hf_text_config, "num_hidden_layers", 0 + ) + # the layout order is: DP x PP x TP + pp_rank = ( + parallel_config.rank // parallel_config.tensor_parallel_size + ) % parallel_config.pipeline_parallel_size + pp_size = parallel_config.pipeline_parallel_size + start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) + return start, end + + def get_num_layers(self, parallel_config: ParallelConfig) -> int: + start, end = self.get_layers_start_end_indices(parallel_config) + return end - start + + def get_num_layers_by_block_type( + self, + parallel_config: ParallelConfig, + block_type: LayerBlockType = LayerBlockType.attention, + ) -> int: + # This function relies on 'layers_block_type' in hf_config, + # for w/o this attribute, we will need to have workarounds like so + attn_block_type = block_type == LayerBlockType.attention + is_transformer = ( + not self.is_hybrid and not self.has_noops and not self.is_attention_free + ) + start, end = self.get_layers_start_end_indices(parallel_config) + + if is_transformer: + # Handle the basic case first + return end - start if attn_block_type else 0 + elif self.is_attention_free: + # Attention free + # Note that this code assumes there + # is only one type of attention-free block type. 
+ return 0 if attn_block_type else end - start + elif self.has_noops: + block_configs = self.hf_config.block_configs + return sum(not bc.attention.no_op for bc in block_configs[start:end]) + else: + # Hybrid model Jamba + layers_block_type_value = getattr( + self.hf_text_config, "layers_block_type", None + ) + if layers_block_type_value is not None: + if hasattr(self.hf_text_config, "model_type") and ( + self.hf_text_config.model_type == "zamba2" + ): + if attn_block_type: + return sum( + t == "hybrid" for t in layers_block_type_value[start:end] + ) + else: + return self.get_num_layers(parallel_config) + return sum( + t == block_type.value for t in layers_block_type_value[start:end] + ) + + # Hybrid model Minimax + attn_type_list = getattr(self.hf_config, "attn_type_list", None) + if attn_type_list: + return sum(t == 1 for t in attn_type_list[start:end]) + + # Hybrid model Qwen3Next + layer_types_value = getattr(self.hf_config, "layer_types", None) + if layer_types_value is not None: + if getattr(block_type, "value", block_type) == "attention": + return sum( + t == "full_attention" for t in layer_types_value[start:end] + ) + elif getattr(block_type, "value", block_type) == "linear_attention": + return sum( + t == "linear_attention" for t in layer_types_value[start:end] + ) + else: + return sum( + t == getattr(block_type, "value", block_type) + for t in layer_types_value[start:end] + ) + + if ( + layers_block_type_value is None + and attn_type_list is None + and layer_types_value is None + ): + raise ValueError( + "The model is an hybrid without a" + "layers_block_type or an attn_type_list, or a layer_types " + "in the hf_config, cannot determine the num of " + f"{block_type.value} layers" + ) + + def get_mamba_chunk_size(self) -> Optional[int]: + """ + Returns the mamba chunk size if it exists + """ + # used by e.g. Bamba, FalconH1, Granite, PLaMo2 + chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None) + if chunk_size is None: + # used by e.g. Mamba2, NemotronH, Zamba + chunk_size = getattr(self.hf_text_config, "chunk_size", None) + return chunk_size + + def get_multimodal_config(self) -> MultiModalConfig: + """ + Get the multimodal configuration of the model. + + Raises: + ValueError: If the model is not multimodal. + """ + if self.multimodal_config is None: + raise ValueError("The model is not multimodal.") + + return self.multimodal_config + + def try_get_generation_config(self) -> dict[str, Any]: + """ + This method attempts to retrieve the non-default values of the + generation config for this model. + + The generation config can contain information about special tokens, as + well as sampling parameters. Which is why this method exists separately + to `get_diff_sampling_param`. + + Returns: + A dictionary containing the non-default generation config. + """ + if self.generation_config in {"auto", "vllm"}: + config = try_get_generation_config( + self.hf_config_path or self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + config_format=self.config_format, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + config_format=self.config_format, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + def get_diff_sampling_param(self) -> dict[str, Any]: + """ + This method returns a dictionary containing the non-default sampling + parameters with `override_generation_config` applied. 
+ + The default sampling parameters are: + + - vLLM's neutral defaults if `self.generation_config="vllm"` + - the model's defaults if `self.generation_config="auto"` + - as defined in `generation_config.json` if + `self.generation_config="path/to/generation_config/dir"` + + Returns: + A dictionary containing the non-default sampling parameters. + """ + if self.generation_config == "vllm": + config = {} + else: + config = self.try_get_generation_config() + + # Overriding with given generation config + config.update(self.override_generation_config) + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + "max_new_tokens", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p) for p in available_params if config.get(p) is not None + } + # Huggingface definition of max_new_tokens is equivalent + # to vLLM's max_tokens + if "max_new_tokens" in diff_sampling_param: + diff_sampling_param["max_tokens"] = diff_sampling_param.pop( + "max_new_tokens" + ) + else: + diff_sampling_param = {} + + if diff_sampling_param: + logger.warning_once( + "Default sampling parameters have been overridden by the " + "model's Hugging Face generation config recommended from the " + "model creator. If this is not intended, please relaunch " + "vLLM instance with `--generation-config vllm`." + ) + return diff_sampling_param + + @property + def is_encoder_decoder(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) + + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + + @property + def is_multimodal_raw_input_only_model(self) -> bool: + return self._model_info.supports_multimodal_raw_input_only + + @property + def is_cross_encoder(self) -> bool: + return ( + self._model_info.supports_cross_encoding or self.convert_type == "classify" + ) + + @property + def is_pp_supported(self) -> bool: + return self._model_info.supports_pp + + @property + def is_attention_free(self) -> bool: + return self._model_info.is_attention_free + + @property + def is_hybrid(self) -> bool: + return self._model_info.is_hybrid + + @property + def has_noops(self) -> bool: + return self._model_info.has_noops + + @property + def has_inner_state(self): + return self._model_info.has_inner_state + + @property + def is_v1_compatible(self) -> bool: + return not self._model_info.supports_v0_only + + @property + def use_mla(self) -> bool: + return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE + + @property + def is_matryoshka(self) -> bool: + return bool(getattr(self.hf_config, "matryoshka_dimensions", None)) or getattr( + self.hf_config, "is_matryoshka", False + ) + + @property + def matryoshka_dimensions(self): + return getattr(self.hf_config, "matryoshka_dimensions", None) + + @property + def use_pad_token(self) -> bool: + # cross_encoder models defaults to using pad_token. + # `llm as reranker` models defaults to not using pad_token. + return getattr(self.hf_config, "use_pad_token", True) + + @property + def head_dtype(self) -> torch.dtype: + """ + "head" refers to the last Linear layer(s) of an LLM, + such as the lm_head in a generation model, + or the score or classifier in a classification model. + + `head_dtype` currently only supports pooling models.\n + - The pooling model defaults to using fp32 head, + you can use --hf-overrides '{"head_dtype": "model"}' to disable it. 
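+ For example, a pooling model served with --dtype float16 will, by default, + run its score/classification head in float32 (assuming the platform + supports float32).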
+ """ + + head_dtype = _get_head_dtype( + config=self.hf_config, dtype=self.dtype, runner_type=self.runner_type + ) + + if self.runner_type != "pooling" and head_dtype != self.dtype: + logger.warning_once( + "`head_dtype` currently only supports pooling models." + "fallback to model dtype [%s].", + self.dtype, + ) + return self.dtype + + if head_dtype not in current_platform.supported_dtypes: + logger.warning_once( + "The current platform does not support [%s] head dtype, " + "fallback to model dtype [%s].", + head_dtype, + self.dtype, + ) + return self.dtype + + logger.debug_once("head dtype: %s", head_dtype) + return head_dtype + + def get_and_verify_max_len(self, max_model_len: int): + # Consider max_model_len in tokenizer_config only when + # pooling models use absolute position_embedding. + tokenizer_config = None + if ( + self.runner_type == "pooling" + and getattr(self.hf_config, "position_embedding_type", "") == "absolute" + ): + tokenizer_config = try_get_tokenizer_config( + self.tokenizer, + trust_remote_code=self.trust_remote_code, + revision=self.tokenizer_revision, + ) + max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + tokenizer_config=tokenizer_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window=self.get_sliding_window(), + spec_target_max_model_len=self.spec_target_max_model_len, + encoder_config=self.encoder_config, + ) + logger.info("Using max model len %s", max_model_len) + return max_model_len + + +def get_served_model_name( + model: str, served_model_name: Optional[Union[str, list[str]]] +): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. 
+ """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +# Some model suffixes are based on auto classes from Transformers: +# https://huggingface.co/docs/transformers/en/model_doc/auto +# NOTE: Items higher on this list priority over lower ones +_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ + ("ForCausalLM", ("generate", "none")), + ("ForConditionalGeneration", ("generate", "none")), + ("ChatModel", ("generate", "none")), + ("LMHeadModel", ("generate", "none")), + ("ForTextEncoding", ("pooling", "embed")), + ("EmbeddingModel", ("pooling", "embed")), + ("ForSequenceClassification", ("pooling", "classify")), + ("ForAudioClassification", ("pooling", "classify")), + ("ForImageClassification", ("pooling", "classify")), + ("ForVideoClassification", ("pooling", "classify")), + ("ClassificationModel", ("pooling", "classify")), + ("ForRewardModeling", ("pooling", "reward")), + ("RewardModel", ("pooling", "reward")), + # Let other `*Model`s take priority + ("Model", ("pooling", "embed")), +] + + +def iter_architecture_defaults(): + yield from _SUFFIX_TO_DEFAULTS + + +def try_match_architecture_defaults( + architecture: str, + *, + runner_type: Optional[RunnerType] = None, + convert_type: Optional[ConvertType] = None, +) -> Optional[tuple[str, tuple[RunnerType, ConvertType]]]: + for suffix, ( + default_runner_type, + default_convert_type, + ) in iter_architecture_defaults(): + if ( + (runner_type is None or runner_type == default_runner_type) + and (convert_type is None or convert_type == default_convert_type) + and architecture.endswith(suffix) + ): + return suffix, (default_runner_type, default_convert_type) + + return None + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3_text": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} + + +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError( + f"The model type {model_type!r} does not support float16. Reason: {reason}" + ) + + return True + + +def _find_dtype( + model_id: str, + config: PretrainedConfig, + *, + revision: Optional[str], +): + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. 
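+ # Resolution order (sketch): root config -> text_config -> vision_config -> + # encoder_config -> safetensors metadata -> float32 as the final fallback.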
+ config_dtype = getattr(config, "torch_dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define torch_dtype + if config_dtype is None: + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) + if config_dtype is None and hasattr(config, "vision_config"): + config_dtype = getattr(config.vision_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "encoder_config"): + config_dtype = getattr(config.encoder_config, "torch_dtype", None) + + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + + if config_dtype is None: + config_dtype = torch.float32 + + return config_dtype + + +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): + from vllm.platforms import current_platform + + supported_dtypes = [ + dtype + for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] + + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] + + # Downcast for float32 models + if config_dtype == torch.float32: + config_dtype = preferred_dtype + + if config_dtype in supported_dtypes: + return config_dtype + + # Ensure device compatibility + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. Falling back to %s for compatibility.", + device_str, + config_dtype, + preferred_dtype, + ) + + return preferred_dtype + + +def _get_and_verify_dtype( + model_id: str, + config: PretrainedConfig, + dtype: Union[str, torch.dtype], + *, + is_pooling_model: bool, + revision: Optional[str] = None, +) -> torch.dtype: + config_dtype = _find_dtype(model_id, config, revision=revision) + model_type = config.model_type + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + # Set default dtype from model config + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype!r}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + _check_valid_dtype(model_type, torch_dtype) + + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + else: + # Casting between float16 and bfloat16 is allowed with a warning. 
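+ # e.g. (hypothetical case) a bfloat16 checkpoint explicitly served with + # --dtype float16 falls into this branch and proceeds with just a warning.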
+ logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def _get_head_dtype( + config: PretrainedConfig, dtype: torch.dtype, runner_type: str +) -> torch.dtype: + head_dtype: Optional[Union[str, torch.dtype]] = getattr(config, "head_dtype", None) + + if head_dtype == "model": + return dtype + elif isinstance(head_dtype, str): + head_dtype = head_dtype.lower() + if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {head_dtype!r}") + return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype] + elif isinstance(head_dtype, torch.dtype): + return head_dtype + elif head_dtype is None: + if torch.float32 not in current_platform.supported_dtypes: + return dtype + if runner_type == "pooling": + return torch.float32 + return dtype + else: + raise ValueError(f"Unknown dtype: {head_dtype}") + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + tokenizer_config: Optional[dict], + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window: Optional[int], + spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + # For Command-R / Cohere, Cohere2 / Aya Vision models + if tmp_max_len := getattr(hf_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if ( + disable_sliding_window + and sliding_window is not None + and sliding_window < derived_max_model_len + ): + max_len_key = "sliding_window" + derived_max_model_len = sliding_window + + # Consider model_max_length in tokenizer_config + if tokenizer_config: + tokenizer_model_max_length = tokenizer_config.get( + "model_max_length", derived_max_model_len + ) + derived_max_model_len = min(derived_max_model_len, tokenizer_model_max_length) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", + possible_keys, + default_max_len, + ) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. 
+ if rope_scaling is not None and "gemma3" not in hf_config.model_type: + # No need to consider "type" key because of patch_rope_scaling when + # loading HF config + rope_type = rope_scaling["rope_type"] + + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate." + ) + + # NOTE: rope_type == "default" does not define factor + # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py + scaling_factor = rope_scaling.get("factor", 1.0) + + if rope_type == "yarn": + derived_max_model_len = rope_scaling["original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if encoder_config and "max_seq_length" in encoder_config: + derived_max_model_len = encoder_config["max_seq_length"] + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + if current_platform.is_tpu(): + logger.warning( + "--max-model-len is not specified; currently using the " + "model's default length %s, which might be too large. " + "Please set --max-model-len based on your request input " + "and output lengths to avoid unnecessary degradation.", + max_model_len, + ) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with model_max_length in the config. Please raise an issue " + "so we can investigate." + ) + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json)." + ) + warning = ( + "VLLM_ALLOW_LONG_MAX_MODEL_LEN must be used with extreme " + "caution. If the model uses relative position encoding (RoPE), " + "positions exceeding derived_max_model_len lead to nan. If the " + "model uses absolute position encoding, positions exceeding " + "derived_max_model_len will cause a CUDA array out-of-bounds " + "error." + ) + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning_once("%s %s", msg, warning) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + f"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
{warning}" + ) + return int(max_model_len) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py new file mode 100644 index 0000000000000..fc8d2262dcb40 --- /dev/null +++ b/vllm/config/multimodal.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from collections.abc import Mapping +from dataclasses import field +from typing import Any, Literal, Optional, Union + +from pydantic import ConfigDict, Field, field_validator +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@dataclass +class BaseDummyOptions: + """Base options for generating dummy data during profiling.""" + + count: int = Field(999, ge=0) + + +@dataclass(config=ConfigDict(extra="forbid")) +class VideoDummyOptions(BaseDummyOptions): + """Options for generating dummy video data during profiling.""" + + num_frames: Optional[int] = Field(None, gt=0) + width: Optional[int] = Field(None, gt=0) + height: Optional[int] = Field(None, gt=0) + + +@dataclass(config=ConfigDict(extra="forbid")) +class ImageDummyOptions(BaseDummyOptions): + """Options for generating dummy image data during profiling.""" + + width: Optional[int] = Field(None, gt=0) + height: Optional[int] = Field(None, gt=0) + + +@dataclass(config=ConfigDict(extra="forbid")) +class AudioDummyOptions(BaseDummyOptions): + """Options for generating dummy audio data during profiling.""" + + length: Optional[int] = Field(None, gt=0) + + +MMEncoderTPMode = Literal["weights", "data"] +MMCacheType = Literal["shm", "lru"] +DummyOptions = Union[ + BaseDummyOptions, VideoDummyOptions, ImageDummyOptions, AudioDummyOptions +] + + +@config +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict) + """The maximum number of input items and options allowed per + prompt for each modality. + Defaults to 999 for each modality. + + Legacy format (count only): + {"image": 16, "video": 2} + + Configurable format (with options): + {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512}, + "image": {"count": 5, "width": 512, "height": 512}} + + Mixed format (combining both): + {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, + "height": 512}} + """ + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + """Additional args passed to process media inputs, keyed by modalities. + For example, to set num_frames for video, set + `--media-io-kwargs '{"video": {"num_frames": 40} }'`""" + mm_processor_kwargs: Optional[dict[str, object]] = None + """Arguments to be forwarded to the model's processor for multi-modal data, + e.g., image processor. Overrides for the multi-modal processor obtained + from `transformers.AutoProcessor.from_pretrained`. + + The available overrides depend on the model that is being run. + + For example, for Phi-3-Vision: + `{"num_crops": 4}`.""" + mm_processor_cache_gb: float = 4 + """The size (in GiB) of the multi-modal processor cache, which is used to + avoid re-processing past multi-modal inputs. + + This cache is duplicated for each API process and engine core process, + resulting in a total memory usage of + `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + + Set to `0` to disable this cache completely (not recommended).""" + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. 
If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. Only effective when `mm_processor_cache_type` is + `"shm"`.""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using tensor + parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior)\n + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" + interleave_mm_strings: bool = False + """Enable fully interleaved support for multimodal prompts, while using + --chat-template-content-format=string.""" + skip_mm_profiling: bool = False + """When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + + This reduces engine startup time but shifts the responsibility to users for + estimating the peak memory usage of the activation of multimodal encoder and + embedding cache.""" + video_pruning_rate: Optional[float] = None + """Sets pruning rate for video pruning via Efficient Video Sampling. + Value sits in range [0;1) and determines fraction of media tokens + from each video to be pruned. + """ + + @field_validator("limit_per_prompt", mode="before") + @classmethod + def _validate_limit_per_prompt( + cls, value: dict[str, Union[int, dict[str, int]]] + ) -> dict[str, DummyOptions]: + for k, v in value.items(): + # Handle legacy format where only count is specified + if isinstance(v, int): + v = {"count": v} + # Convert to the appropriate DummyOptions subclass + if k == "video": + value[k] = VideoDummyOptions(**v) + elif k == "image": + value[k] = ImageDummyOptions(**v) + elif k == "audio": + value[k] = AudioDummyOptions(**v) + else: + value[k] = BaseDummyOptions(**v) + return value + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def get_limit_per_prompt(self, modality: str) -> int: + """ + Get the maximum number of input items allowed per prompt + for the given modality (backward compatible). + """ + limit_data = self.limit_per_prompt.get(modality) + + if limit_data is None: + # Unspecified modality is set to 999 by default + return 999 + return limit_data.count + + def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]: + """ + Get the configurable dummy data options for a modality. + Returns None if no options are configured for this modality. 
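+ For example (hypothetical values), with + limit_per_prompt == {"video": {"count": 1, "num_frames": 32}}, + get_dummy_options("video") returns VideoDummyOptions(count=1, num_frames=32), + while get_dummy_options("image") returns None.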
+ """ + # All values are now DummyOptions after normalization + return self.limit_per_prompt.get(modality) + + def merge_mm_processor_kwargs( + self, + inference_kwargs: Mapping[str, object], + ) -> dict[str, object]: + """ + Get the keyword arguments to pass to the multi-modal processor + according to the extra arguments passed during inference. + """ + kwargs = self.mm_processor_kwargs or {} + return kwargs | dict(inference_kwargs) + + def is_multimodal_pruning_enabled(self): + return self.video_pruning_rate is not None and self.video_pruning_rate > 0 diff --git a/vllm/config/observability.py b/vllm/config/observability.py new file mode 100644 index 0000000000000..6c7b5fbbee477 --- /dev/null +++ b/vllm/config/observability.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from functools import cached_property +from typing import Any, Literal, Optional, cast + +from pydantic.dataclasses import dataclass + +from vllm import version +from vllm.config.utils import config + +DetailedTraceModules = Literal["model", "worker", "all"] + + +@config +@dataclass +class ObservabilityConfig: + """Configuration for observability - metrics and tracing.""" + + show_hidden_metrics_for_version: Optional[str] = None + """Enable deprecated Prometheus metrics that have been hidden since the + specified version. For example, if a previously deprecated metric has been + hidden since the v0.7.0 release, you use + `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while + you migrate to new metrics. The metric is likely to be removed completely + in an upcoming release.""" + + @cached_property + def show_hidden_metrics(self) -> bool: + """Check if the hidden metrics should be shown.""" + if self.show_hidden_metrics_for_version is None: + return False + return version._prev_minor_version_was(self.show_hidden_metrics_for_version) + + otlp_traces_endpoint: Optional[str] = None + """Target URL to which OpenTelemetry traces will be sent.""" + + collect_detailed_traces: Optional[list[DetailedTraceModules]] = None + """It makes sense to set this only if `--otlp-traces-endpoint` is set. If + set, it will collect detailed traces for the specified modules. This + involves use of possibly costly and or blocking operations and hence might + have a performance impact. + + Note that collecting detailed timing information for each request can be + expensive.""" + + @cached_property + def collect_model_forward_time(self) -> bool: + """Whether to collect model forward time for the request.""" + return self.collect_detailed_traces is not None and ( + "model" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces + ) + + @cached_property + def collect_model_execute_time(self) -> bool: + """Whether to collect model execute time for the request.""" + return self.collect_detailed_traces is not None and ( + "worker" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces + ) + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. 
+ # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if ( + self.collect_detailed_traces is not None + and len(self.collect_detailed_traces) == 1 + and "," in self.collect_detailed_traces[0] + ): + self._parse_collect_detailed_traces() + + from vllm.tracing import is_otel_available, otel_import_error_traceback + + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}" + ) + + def _parse_collect_detailed_traces(self): + assert isinstance(self.collect_detailed_traces, list) + self.collect_detailed_traces = cast( + list[DetailedTraceModules], self.collect_detailed_traces[0].split(",") + ) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index f7b8b1d0a5658..88bee9e2d42ee 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from dataclasses import field +import os from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from pydantic import TypeAdapter, model_validator +from pydantic import Field, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -29,7 +29,9 @@ else: logger = init_logger(__name__) +ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +DataParallelBackend = Literal["ray", "mp"] @config @@ -47,7 +49,7 @@ class EPLBConfig: of the last `lb_window_size` steps will be used for rearranging experts. """ - num_redundant_experts: int = 0 + num_redundant_experts: int = Field(default=0, ge=0) """Number of redundant experts to use for expert parallelism.""" log_balancedness: bool = False @@ -56,13 +58,6 @@ class EPLBConfig: This is turned off by default since it will cause communication overhead. """ - @classmethod - def from_cli(cls, cli_value: str) -> "EPLBConfig": - """Parse the CLI value for the compilation config. - -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. - """ - return TypeAdapter(EPLBConfig).validate_json(cli_value) - @config @dataclass @@ -89,12 +84,12 @@ class ParallelConfig: """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" - data_parallel_backend: str = "mp" + data_parallel_backend: DataParallelBackend = "mp" """Backend to use for data parallel, either "mp" or "ray".""" data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. Applies only to online serving and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" - wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank + wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank is provided explicitly to vllm serve.""" data_parallel_hybrid_lb: bool = False """Whether to use "hybrid" DP LB mode. 
Applies only to online serving @@ -107,8 +102,17 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + eplb_config: EPLBConfig = Field(default_factory=EPLBConfig) """Expert parallelism configuration.""" + expert_placement_strategy: ExpertPlacementStrategy = "linear" + """The expert placement strategy for MoE layers:\n + - "linear": Experts are placed in a contiguous manner. For example, with 4 + experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have + experts [2, 3].\n + - "round_robin": Experts are placed in a round-robin manner. For example, + with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 + will have experts [1, 3]. This strategy can help improve load balancing + for grouped expert models with no redundant experts.""" num_redundant_experts: Optional[int] = None """`num_redundant_experts` is deprecated and has been replaced with `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. @@ -134,6 +138,24 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_dbo: bool = False + """Enable dual batch overlap for the model executor.""" + + dbo_decode_token_threshold: int = 32 + """The threshold for dual batch overlap for batches only containing decodes. + If the number of tokens in the request is greater than this threshold, + microbatching will be used. Otherwise, the request will be processed in a + single batch.""" + dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune + """The threshold for dual batch overlap for batches that contain one or more + prefills. If the number of tokens in the request is greater than this + threshold, microbatching will be used. Otherwise, the request will be + processed in a single batch.""" + + disable_nccl_for_dp_synchronization: bool = False + """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py + to use Gloo instead of NCCL for its all reduce""" + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -143,9 +165,9 @@ class ParallelConfig: placement_group: Optional[PlacementGroup] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[str, - DistributedExecutorBackend, - type[ExecutorBase]]] = None + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, type[ExecutorBase]] + ] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size is less than @@ -166,17 +188,86 @@ class ParallelConfig: new attributes and methods to the worker class for use in collective_rpc calls.""" - world_size: int = field(init=False) + world_size: int = Field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" rank: int = 0 """Global rank in distributed setup.""" - _data_parallel_master_port_list: list[int] = field(default_factory=list) + _data_parallel_master_port_list: list[int] = Field(default_factory=list) """List of open port auto-queried for data parallel messaging. Set to be private as it's not intended to be configured by users. 
""" + decode_context_parallel_size: int = 1 + """Number of decode context parallel groups, because the world size does + not change by dcp, it simply reuse the GPUs of TP group, and tp_size + needs to be divisible by dcp_size.""" + + _api_process_count: int = Field(default=1, gt=0) + """ + The number of API processes initialized. + + Note: + This is an internal config that is only valid for and + should only be set by API server scale-out. + """ + + _api_process_rank: int = Field(default=0, ge=-1) + """ + The rank of this API process, or `-1` for engine core processes + under API server scale-out. + + Note: + This is an internal config that is only valid for and + should only be set by API server scale-out. + """ + + @model_validator(mode="after") + def _validate_parallel_config(self) -> Self: + if self._api_process_rank >= self._api_process_count: + raise ValueError( + "Invalid value of `_api_process_rank`. " + f"Expected to be `-1` or `[0, {self._api_process_count})`, " + f"but found: {self._api_process_rank}" + ) + + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})" + ) + + if self.data_parallel_size <= 1 and self.data_parallel_external_lb: + raise ValueError( + "data_parallel_external_lb can only be set when data_parallel_size > 1" + ) + + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now." + ) + if not self.enable_expert_parallel: + raise ValueError("enable_expert_parallel must be True to use EPLB.") + if self.tensor_parallel_size * self.data_parallel_size <= 1: + raise ValueError( + "EPLB requires tensor_parallel_size or data_parallel_size " + f"to be greater than 1, but got " + f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." + ) + else: + if self.eplb_config.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts is set to " + f"{self.eplb_config.num_redundant_experts} but EPLB is not " + "enabled. Either enable EPLB or unset " + "num_redundant_experts." + ) + + return self + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world @@ -211,7 +302,8 @@ class ParallelConfig: from torch.distributed import DistNetworkError from vllm.distributed.utils import ( - stateless_init_torch_distributed_process_group) + stateless_init_torch_distributed_process_group, + ) max_retries = 5 last_exc: Optional[Exception] = None @@ -223,12 +315,12 @@ class ParallelConfig: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend="gloo") + backend="gloo", + ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. if "EADDRINUSE" in str(e): - logger.warning( - "Address already in use. Retrying with a new port.") + logger.warning("Address already in use. Retrying with a new port.") last_exc = e continue # try again with a new port raise e @@ -237,12 +329,33 @@ class ParallelConfig: assert last_exc is not None raise last_exc + # The all_reduce at the end of attention (during o_proj) means that + # inputs are replicated across each rank of the tensor parallel group. + # If using expert-parallelism with DeepEP All2All ops, replicated + # tokens results in useless duplicate computation and communication. 
+ # + # In this case, ensure the input to the experts is sequence parallel + # to avoid the excess work. + # + # Not needed for pplx-kernels as it can handle duplicate input tokens. + @property + def use_sequence_parallel_moe(self) -> bool: + return ( + envs.VLLM_ALL2ALL_BACKEND + in ( + "allgather_reducescatter", + "naive", + "deepep_high_throughput", + "deepep_low_latency", + ) + and self.enable_expert_parallel + and self.tensor_parallel_size > 1 + and self.data_parallel_size > 1 + ) + @staticmethod - def has_unfinished_dp(dp_group: ProcessGroup, - has_unfinished: bool) -> bool: - tensor = torch.tensor([has_unfinished], - dtype=torch.int32, - device="cpu") + def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu") # dp rank 0: has_unfinished_seqs=True # dp rank 1: has_unfinished_seqs=False # aggregated: has_unfinished_seqs=True @@ -252,13 +365,10 @@ class ParallelConfig: return aggregated_has_unfinished @staticmethod - def sync_kv_cache_memory_size(dp_group: ProcessGroup, - kv_cache_memory: int) -> int: + def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int: if kv_cache_memory == -1: kv_cache_memory = torch.iinfo(torch.int64).max - tensor = torch.tensor([kv_cache_memory], - dtype=torch.int64, - device="cpu") + tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu") # we cannot use broadcast for stateless dp group since it depends # on global rank torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) @@ -271,6 +381,9 @@ class ParallelConfig: graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states. + + This hash is also used for DP worker configuration validation + to prevent hangs from mismatched collective communication patterns. """ factors: list[Any] = [] factors.append(self.pipeline_parallel_size) @@ -278,60 +391,77 @@ class ParallelConfig: factors.append(self.enable_expert_parallel) factors.append(self.data_parallel_size) factors.append(envs.VLLM_ALL2ALL_BACKEND) + factors.append(self.enable_eplb) + if self.enable_eplb: + factors.append(self.eplb_config.log_balancedness) + factors.append(self.eplb_config.window_size) + factors.append(self.eplb_config.step_interval) + factors.append(self.eplb_config.num_redundant_experts) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: # Forward deprecated fields to their new location if self.num_redundant_experts is not None: - self.eplb_config.num_redundant_experts = ( - self.num_redundant_experts) + self.eplb_config.num_redundant_experts = self.num_redundant_experts logger.warning_once( "num_redundant_experts is deprecated and has been replaced " "with eplb_config.num_redundant_experts. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_window_size is not None: self.eplb_config.window_size = self.eplb_window_size logger.warning_once( "eplb_window_size is deprecated and has been replaced " "with eplb_config.window_size. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_step_interval is not None: self.eplb_config.step_interval = self.eplb_step_interval logger.warning_once( "eplb_step_interval is deprecated and has been replaced " "with eplb_config.step_interval. 
This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_log_balancedness is not None: self.eplb_config.log_balancedness = self.eplb_log_balancedness logger.warning_once( "eplb_log_balancedness is deprecated and has been replaced " "with eplb_config.log_balancedness. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) # Continue with the rest of the initialization - self.world_size = self.pipeline_parallel_size * \ - self.tensor_parallel_size + self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size - if self.data_parallel_size_local > self.data_parallel_size: - raise ValueError( - f"data_parallel_size_local ({self.data_parallel_size_local}) " - f"must be <= data_parallel_size ({self.data_parallel_size})") + if self.distributed_executor_backend == "external_launcher": + logger.info("Using external launcher for distributed inference.") + self.world_size *= self.data_parallel_size if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. + if self.distributed_executor_backend == "external_launcher": + # For external launcher, + # we need to set the data parallel rank automatically + self.data_parallel_rank = int(os.environ["RANK"]) // ( + self.world_size // self.data_parallel_size + ) + logger.info( + "Set data_parallel_rank to %d automatically.", + self.data_parallel_rank, + ) if not self._data_parallel_master_port_list: self._data_parallel_master_port_list = get_open_ports_list(5) - self.data_parallel_master_port = \ - self._data_parallel_master_port_list.pop() + self.data_parallel_master_port = self._data_parallel_master_port_list.pop() if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( f"data_parallel_rank ({self.data_parallel_rank})" - f" must be in the range [0, {self.data_parallel_size})") + f" must be in the range [0, {self.data_parallel_size})" + ) else: # Otherwise fall back to env vars (e.g. for offline SPMD case). self.data_parallel_size = envs.VLLM_DP_SIZE @@ -340,75 +470,52 @@ class ParallelConfig: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - if self.data_parallel_external_lb: - raise ValueError("data_parallel_external_lb can only " - "be set when data_parallel_size > 1") - if self.distributed_executor_backend == "external_launcher": - import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") - if self.enable_eplb: - if not current_platform.is_cuda(): - raise ValueError( - "Expert parallelism load balancing is only supported on " - "CUDA devices now.") - if self.eplb_config.num_redundant_experts < 0: - raise ValueError( - "num_redundant_experts must be non-negative, but got " - f"{self.eplb_config.num_redundant_experts}.") - if not self.enable_expert_parallel: - raise ValueError( - "enable_expert_parallel must be True to use EPLB.") - if self.tensor_parallel_size * self.data_parallel_size <= 1: - raise ValueError( - "EPLB requires tensor_parallel_size or data_parallel_size " - f"to be greater than 1, but got " - f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." - ) - else: - if self.eplb_config.num_redundant_experts != 0: - raise ValueError( - "num_redundant_experts should be used with EPLB." 
- f"{self.eplb_config.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. from vllm.executor import ray_utils + backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() - if current_platform.is_neuron(): - # neuron uses single process to control multiple devices + if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: backend = "uni" - elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: - backend = "uni" - elif (current_platform.is_cuda() - and cuda_device_count_stateless() < self.world_size): + elif ( + current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size + ): if not ray_found: - raise ValueError("Unable to load Ray: " - f"{ray_utils.ray_import_err}. Ray is " - "required for multi-node inference, " - "please install Ray with `pip install " - "ray`.") + raise ValueError( + "Unable to load Ray: " + f"{ray_utils.ray_import_err}. Ray is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`." + ) backend = "ray" elif self.data_parallel_backend == "ray": - logger.info("Using ray distributed inference because " - "data_parallel_backend is ray") + logger.info( + "Using ray distributed inference because " + "data_parallel_backend is ray" + ) backend = "ray" elif ray_found: if self.placement_group: backend = "ray" else: from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): from ray.util import get_current_placement_group + if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend - logger.debug("Defaulting to use %s for distributed inference", - backend) + logger.debug("Defaulting to use %s for distributed inference", backend) if self.distributed_executor_backend is None and self.world_size == 1: self.distributed_executor_backend = "uni" @@ -417,33 +524,43 @@ class ParallelConfig: def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( isinstance(self.distributed_executor_backend, type) - and getattr(self.distributed_executor_backend, "uses_ray", False)) + and getattr(self.distributed_executor_backend, "uses_ray", False) + ) - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform - if self.distributed_executor_backend is not None and not isinstance( - self.distributed_executor_backend, str) and not (isinstance( - self.distributed_executor_backend, type) and issubclass( - self.distributed_executor_backend, ExecutorBase)): + + if ( + self.distributed_executor_backend is not None + and not isinstance(self.distributed_executor_backend, str) + and not ( + isinstance(self.distributed_executor_backend, type) + and issubclass(self.distributed_executor_backend, ExecutorBase) + ) + ): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " "values are 'ray', 'mp' 'uni', 'external_launcher', " - " custom ExecutorBase subclass or its import path.") + " custom ExecutorBase subclass or its import path." 
+ ) if self.use_ray: from vllm.executor import ray_utils + ray_utils.assert_ray_available() if not current_platform.use_custom_allreduce(): self.disable_custom_all_reduce = True logger.debug( "Disabled the custom all-reduce kernel because it is not " - "supported on current platform.") + "supported on current platform." + ) if self.ray_workers_use_nsight and not self.use_ray: - raise ValueError("Unable to use nsight profiling unless workers " - "run with Ray.") + raise ValueError( + "Unable to use nsight profiling unless workers run with Ray." + ) return self diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py new file mode 100644 index 0000000000000..8b10992faa022 --- /dev/null +++ b/vllm/config/pooler.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any, Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. This should be a key in + [`vllm.model_executor.layers.pooler.PoolingType`][]. + """ + + ## for embeddings models + normalize: Optional[bool] = None + """ + Whether to normalize the embeddings outputs. Defaults to True. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. Defaults to None. + """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + When an input exceeds max_embed_len, it will be handled according to + the original max_model_len validation logic. + Defaults to None (i.e. set to max_model_len). + """ + + ## for classification models + activation: Optional[bool] = None + """ + Whether to apply activation function to the classification outputs. + Defaults to True. + """ + logit_bias: Optional[float] = None + """ + If provided, apply classification logit biases. Defaults to None. + """ + + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + Defaults to True. + """ + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. + """ + returned_token_ids: Optional[list[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 93002012799ab..396258aac287b 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from dataclasses import field -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from dataclasses import InitVar, field +from typing import Any, Literal, Union from pydantic import SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -11,18 +11,15 @@ from typing_extensions import Self from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS) - -if TYPE_CHECKING: - from vllm.config import RunnerType -else: - RunnerType = Any +from vllm.utils import ( + DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, +) logger = init_logger(__name__) -PreemptionMode = Literal["swap", "recompute"] +RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority"] @@ -82,10 +79,6 @@ class SchedulerConfig: 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" - delay_factor: float = 0.0 - """Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -93,6 +86,13 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" + is_encoder_decoder: InitVar[bool] = False + """True if the model is an encoder-decoder model. + + Note: This is stored in the ModelConfig, and is used only here to + disable chunked prefill and prefix caching for encoder-decoder models. + """ + # TODO (ywang96): Make this configurable. max_num_encoder_input_tokens: int = field(init=False) """Multimodal encoder compute budget, only used in V1. @@ -107,14 +107,6 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - preemption_mode: Optional[PreemptionMode] = None - """Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead.""" - send_delta_data: bool = False """Private API. If used, scheduler sends delta data to workers instead of an entire data. It should be enabled only @@ -174,17 +166,27 @@ class SchedulerConfig: # no factors to consider. # this config will not affect the computation graph. 
factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self) -> None: + def __post_init__(self, is_encoder_decoder: bool) -> None: if self.max_model_len is None: self.max_model_len = 8192 if self.max_num_seqs is None: self.max_num_seqs = 128 + if is_encoder_decoder: + # Chunked prefill should be disabled for encoder-decoder models. + self.disable_chunked_mm_input = True + self.chunked_prefill_enabled = False + self.enable_chunked_prefill = False + self.long_prefill_token_threshold = 0 + logger.info( + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both." + ) + if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -193,7 +195,8 @@ class SchedulerConfig: # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value # for higher throughput. self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS + ) if self.runner_type == "pooling": # Choose specific value for higher throughput @@ -212,8 +215,8 @@ class SchedulerConfig: # Ensure max_num_batched_tokens does not exceed model limit. # Some models (e.g., Whisper) have embeddings tied to max length. self.max_num_batched_tokens = min( - self.max_num_seqs * self.max_model_len, - self.max_num_batched_tokens) + self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens + ) self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -221,20 +224,22 @@ class SchedulerConfig: if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", - self.max_num_batched_tokens) + self.max_num_batched_tokens, + ) self.chunked_prefill_enabled = self.enable_chunked_prefill if self.max_num_partial_prefills > 1: if self.long_prefill_token_threshold == 0: - self.long_prefill_token_threshold = int(self.max_model_len * - 0.04) + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) logger.info( "Concurrent partial prefills enabled with " "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " "long_prefill_token_threshold=%d", - self.max_num_partial_prefills, self.max_long_partial_prefills, - self.long_prefill_token_threshold) + self.max_num_partial_prefills, + self.max_long_partial_prefills, + self.long_prefill_token_threshold, + ) # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. # This avoids OOM in tight memory scenarios with small max_num_seqs, @@ -244,61 +249,71 @@ class SchedulerConfig: self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] if self.async_scheduling: - self.scheduler_cls = ( - "vllm.v1.core.sched.async_scheduler.AsyncScheduler") + self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler" - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: - if (self.max_num_batched_tokens < self.max_model_len - and not self.chunked_prefill_enabled): + if ( + self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled + ): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). 
" "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + "decrease max_model_len." + ) if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + f"({self.max_num_seqs})." + ) if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: logger.warning( "max_num_batched_tokens (%d) exceeds max_num_seqs " "* max_model_len (%d). This may lead to unexpected behavior.", self.max_num_batched_tokens, - self.max_num_seqs * self.max_model_len) + self.max_num_seqs * self.max_model_len, + ) if self.num_lookahead_slots < 0: raise ValueError( "num_lookahead_slots " f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0.") + "equal to 0." + ) if self.max_num_partial_prefills < 1: raise ValueError( f"max_num_partial_prefills ({self.max_num_partial_prefills}) " - "must be greater than or equal to 1.") + "must be greater than or equal to 1." + ) elif self.max_num_partial_prefills > 1: if not self.chunked_prefill_enabled: - raise ValueError("Chunked prefill must be enabled to set " - "max_num_partial_prefills > 1.") + raise ValueError( + "Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1." + ) if self.long_prefill_token_threshold > self.max_model_len: raise ValueError( "long_prefill_token_threshold " f"({self.long_prefill_token_threshold}) cannot be greater " - f"than the max_model_len ({self.max_model_len}).") + f"than the max_model_len ({self.max_model_len})." + ) - if (self.max_long_partial_prefills - < 1) or (self.max_long_partial_prefills - > self.max_num_partial_prefills): + if (self.max_long_partial_prefills < 1) or ( + self.max_long_partial_prefills > self.max_num_partial_prefills + ): raise ValueError( f"max_long_partial_prefills ({self.max_long_partial_prefills}) " "must be greater than or equal to 1 and less than or equal to " - f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + f"max_num_partial_prefills ({self.max_num_partial_prefills})." 
+ ) return self diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py new file mode 100644 index 0000000000000..aa0c07cf62a36 --- /dev/null +++ b/vllm/config/speculative.py @@ -0,0 +1,604 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import hashlib +from typing import TYPE_CHECKING, Any, Literal, Optional + +from pydantic import SkipValidation, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.parallel import ParallelConfig +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + import vllm.model_executor.layers.quantization as me_quant + from vllm.config import ModelConfig +else: + PretrainedConfig = Any + ModelConfig = Any + + me_quant = LazyLoader( + "model_executor", globals(), "vllm.model_executor.layers.quantization" + ) + +logger = init_logger(__name__) + +SpeculativeMethod = Literal[ + "ngram", + "eagle", + "eagle3", + "medusa", + "mlp_speculator", + "draft_model", + "deepseek_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "mimo_mtp", + "longcat_flash_mtp", + "mtp", +] +MTP_MODEL_TYPES = ( + "deepseek_mtp", + "mimo_mtp", + "glm4_moe_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "longcat_flash_mtp", +) + + +@config +@dataclass +class SpeculativeConfig: + """Configuration for speculative decoding.""" + + enforce_eager: Optional[bool] = None + """Override the default enforce_eager from model_config""" + # General speculative decoding control + num_speculative_tokens: SkipValidation[int] = None # type: ignore + """The number of speculative tokens, if provided. It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. + + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" + disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" + + # Draft model configuration + quantization: Optional[me_quant.QuantizationMethods] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" + max_model_len: Optional[int] = None + """The maximum model length of the draft model. Used when testing the + ability to skip speculation for some sequences.""" + revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. 
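A usage sketch of the new module, assuming a vLLM build that already contains this `vllm/config/speculative.py` file; values are illustrative. It shows that the ngram method needs no draft weights and how the prompt-lookup defaults resolve per the `__post_init__` logic further below:

from vllm.config.speculative import SpeculativeConfig

spec = SpeculativeConfig(method="ngram", num_speculative_tokens=3, prompt_lookup_max=4)
assert spec.model == "ngram"           # ngram proposing needs no draft model weights
assert spec.prompt_lookup_min == 4     # defaults to prompt_lookup_max when unset
assert spec.num_lookahead_slots == 3   # one extra scheduler slot per speculative token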
If unspecified, will use the + default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + + # Advanced control + disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + disable_padded_drafter_batch: bool = False + """Disable input padding for speculative decoding. If set to True, + speculative input batches can contain sequences of different lengths, + which may only be supported by certain attention backends. This currently + only affects the EAGLE method of speculation.""" + + # Ngram proposer configuration + prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" + prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ + # required configuration params passed from engine + target_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the target model.""" + target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore + """The parallel configuration for the target model.""" + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """Whether vLLM is configured to use chunked prefill or not. Used for + raising an error since it's not yet compatible with speculative decode.""" + disable_log_stats: SkipValidation[bool] = None # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" + + # params generated in the post-init stage + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the draft model initialized internal.""" + draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore + """The parallel configuration for the draft model initialized internal.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. 
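Operationally, `prompt_lookup_min`/`prompt_lookup_max` bound the size of the n-gram suffix that is matched against earlier context to propose draft tokens. An illustrative, self-contained restatement of the idea (not vLLM's actual ngram proposer):

def ngram_propose(token_ids, lookup_min, lookup_max, num_speculative_tokens):
    """Find the most recent earlier occurrence of the current suffix (n-gram of
    size lookup_max down to lookup_min) and propose the tokens that followed it."""
    for n in range(lookup_max, lookup_min - 1, -1):
        if len(token_ids) < n:
            continue
        suffix = token_ids[-n:]
        # Search earlier positions, most recent first, excluding the suffix itself.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                follow = token_ids[start + n:start + n + num_speculative_tokens]
                if follow:
                    return follow
    return []

history = [1, 2, 3, 4, 2, 3]
assert ngram_propose(history, lookup_min=1, lookup_max=2, num_speculative_tokens=2) == [4, 2]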
+ factors.append(self.method == "eagle3") + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} + ) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"], + } + ) + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"], + } + ) + + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]} + ) + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]} + ) + if hf_config.model_type == "longcat_flash": + hf_config.model_type = "longcat_flash_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update( + {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} + ) + + return hf_config + + def __post_init__(self): + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.method in MTP_MODEL_TYPES: + logger.warning( + "method `%s` is deprecated and replaced with mtp.", self.method + ) + self.method = "mtp" + + if self.model is None and self.num_speculative_tokens is not None: + if self.method == "mtp": + assert self.target_model_config is not None, ( + "target_model_config must be present for mtp" + ) + if self.target_model_config.hf_text_config.model_type == "deepseek_v32": + # FIXME(luccafong): cudgraph with v32 MTP is not supported, + # remove this when the issue is fixed. + self.enforce_eager = True + # use the draft model from the same model: + self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. + if not self.quantization: + self.quantization = self.target_model_config.quantization + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError( + "num_speculative_tokens was provided but without speculative model." 
+ ) + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and ( + self.model is not None and self.model in ("ngram", "[ngram]") + ): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if self.prompt_lookup_min is None and self.prompt_lookup_max is None: + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0" + ) + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0" + ) + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}" + ) + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + # TODO: Move this import to the top once `ModelConfig` + # lives in `vllm.config.model`. + from vllm.config import ModelConfig + + self.draft_model_config = ModelConfig( + model=self.model, + runner="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config.allowed_local_media_path, + allowed_media_domains=self.target_model_config.allowed_media_domains, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config.max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ("eagle", "eagle3"): + pass + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif self.draft_model_config.hf_config.model_type == "mlp_speculator": + self.method = "mlp_speculator" + elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: + self.method = "mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "Enabling num_speculative_tokens > 1 will run" + "multiple times of forward on same MTP layer" + ",which may result in lower acceptance rate" + ) + elif 
self.draft_model_config.hf_config.model_type in ( + "longcat_flash_mtp" + ): + self.method = "longcat_flash_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "LongCat MTP models only have " + "one layer. Might need some code changes " + "to support multiple layers." + ) + else: + self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or mtp." + ) + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0." + ) + + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig + + if isinstance( + self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig), + ): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle", + ) + self.draft_model_config.hf_config = eagle_config + + if self.num_speculative_tokens is not None and hasattr( + self.draft_model_config.hf_config, "num_lookahead_tokens" + ): + self.draft_model_config.hf_config.num_lookahead_tokens = ( + self.num_speculative_tokens + ) + + n_predict = getattr( + self.draft_model_config.hf_config, "n_predict", None + ) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif ( + self.num_speculative_tokens > n_predict + and self.num_speculative_tokens % n_predict != 0 + ): + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}" + ) + + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str( + [(i + 1) * (0,) for i in range(self.num_speculative_tokens)] + ) + else: + # Sort the token tree breadth-first. + tree_choices = ast.literal_eval(self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t)) + ) + + self.draft_tensor_parallel_size = ( + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config, + ) + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + ) + ) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, self.draft_tensor_parallel_size + ) + ) + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. 
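For concreteness, the default `speculative_token_tree` produced above is a simple chain (one child per depth), and a user-provided tree is re-sorted breadth-first. A pure-Python illustration of those two expressions:

import ast

num_speculative_tokens = 3
# Default "chain" tree: tuples of tree-path choices, all zeros for a chain.
chain = str([(i + 1) * (0,) for i in range(num_speculative_tokens)])
assert chain == "[(0,), (0, 0), (0, 0, 0)]"

# A custom tree is re-sorted breadth-first (by depth, then lexicographically).
tree = ast.literal_eval("[(0, 0), (0,), (1,)]")
assert sorted(tree, key=lambda t: (len(t), t)) == [(0,), (1,), (0, 0)]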
+ """ + + if speculative_max_model_len is not None: + if speculative_max_model_len > draft_max_model_len: + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}" + ) + + if speculative_max_model_len > target_max_model_len: + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}" + ) + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def _verify_and_get_draft_tp( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. + """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. + if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type, + ) + else: + speculative_draft_tensor_parallel_size = ( + target_parallel_config.tensor_parallel_size + ) + elif speculative_draft_tensor_parallel_size not in ( + 1, + target_parallel_config.tensor_parallel_size, + ): + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1 or target model tensor_parallel_size" + ) + return speculative_draft_tensor_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config.distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + @model_validator(mode="after") + def _verify_args(self) -> Self: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter." + ) + + if self.num_speculative_tokens <= 0: + raise ValueError( + "Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens})." 
+ ) + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config + ) + + if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2: + raise ValueError( + "Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}" + ) + + eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"] + if ( + self.method == "eagle3" + and self.target_model_config + and not any( + supported_model in self.target_model_config.hf_text_config.model_type + for supported_model in eagle3_target_supported + ) + ): + raise ValueError( + f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 + f"Got {self.target_model_config.hf_text_config.model_type=}" + ) + + return self + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3", "mtp") + + def __repr__(self) -> str: + method = self.method + model = None if method == "ngram" else self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/config/speech_to_text.py b/vllm/config/speech_to_text.py new file mode 100644 index 0000000000000..de9f525efe185 --- /dev/null +++ b/vllm/config/speech_to_text.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class SpeechToTextConfig: + """Configuration for speech-to-text models.""" + + sample_rate: float = 16_000 + """Sample rate (Hz) to resample input audio to. Most speech models expect + 16kHz audio input. The input audio will be automatically resampled to this + rate before processing.""" + + max_audio_clip_s: int = 30 + """Maximum duration in seconds for a single audio clip without chunking. + Audio longer than this will be split into smaller chunks if + `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" + + overlap_chunk_second: int = 1 + """Overlap duration in seconds between consecutive audio chunks when + splitting long audio. This helps maintain context across chunk boundaries + and improves transcription quality at split points.""" + + min_energy_split_window_size: Optional[int] = 1600 + """Window size in samples for finding low-energy (quiet) regions to split + audio chunks. The algorithm looks for the quietest moment within this + window to minimize cutting through speech. Default 1600 samples ≈ 100ms + at 16kHz. 
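How the three chunking settings above interact can be seen with a fixed-stride sketch. This is an illustration only, under the assumption of uniform 30 s chunks with 1 s overlap; the real splitter additionally searches the min-energy window (1600 samples, about 100 ms at 16 kHz) for a quiet split point:

SAMPLE_RATE = 16_000   # Hz, SpeechToTextConfig.sample_rate
MAX_CLIP_S = 30        # max_audio_clip_s
OVERLAP_S = 1          # overlap_chunk_second

def naive_chunk_spans(duration_s: float) -> list[tuple[float, float]]:
    """Fixed-stride chunk boundaries in seconds, with OVERLAP_S of overlap."""
    spans, start = [], 0.0
    stride = MAX_CLIP_S - OVERLAP_S
    while start < duration_s:
        spans.append((start, min(start + MAX_CLIP_S, duration_s)))
        start += stride
    return spans

# A 75 s clip becomes three overlapping chunks of at most 30 s each.
assert naive_chunk_spans(75.0) == [(0.0, 30.0), (29.0, 59.0), (58.0, 75.0)]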
If None, no chunking will be done.""" + + @property + def allow_audio_chunking(self) -> bool: + return self.min_energy_split_window_size is not None diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py new file mode 100644 index 0000000000000..5111c9c77d90e --- /dev/null +++ b/vllm/config/structured_outputs.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any, Literal + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +StructuredOutputsBackend = Literal[ + "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer" +] + + +@config +@dataclass +class StructuredOutputsConfig: + """Dataclass which contains structured outputs config for the engine.""" + + backend: StructuredOutputsBackend = "auto" + """Which engine will be used for structured outputs (e.g. JSON schema, + regex, etc) by default. With "auto", we will make opinionated choices + based on request contents and what the backend libraries currently support, + so the behavior is subject to change in each release.""" + disable_fallback: bool = False + """If `True`, vLLM will not fallback to a different backend on error.""" + disable_any_whitespace: bool = False + """If `True`, the model will not generate any whitespace during structured + outputs. This is only supported for xgrammar and guidance backends.""" + disable_additional_properties: bool = False + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + reasoning_parser: str = "" + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): + raise ValueError( + "disable_any_whitespace is only supported for " + "xgrammar and guidance backends." + ) + if self.disable_additional_properties and self.backend != "guidance": + raise ValueError( + "disable_additional_properties is only supported " + "for the guidance backend." 
+ ) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 98fbeb1fa86aa..5e7e7580c5a9e 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -1,15 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions for vLLM config dataclasses.""" -from typing import TYPE_CHECKING, TypeVar +import ast +import inspect +import textwrap +from collections.abc import Iterable +from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from itertools import pairwise +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +import regex as re +from pydantic.fields import FieldInfo +from typing_extensions import runtime_checkable if TYPE_CHECKING: from _typeshed import DataclassInstance - - ConfigType = type[DataclassInstance] else: - ConfigType = type + DataclassInstance = Any +ConfigType = type[DataclassInstance] ConfigT = TypeVar("ConfigT", bound=ConfigType) @@ -27,3 +37,142 @@ def config(cls: ConfigT) -> ConfigT: script, which is invoked during the pre-commit checks. """ return cls + + +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields[name] + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + if isinstance(default, FieldInfo): + # Handle pydantic.Field defaults + if default.default_factory is not None: + return field(default_factory=default.default_factory) + else: + default = default.default + return field(default=default) + + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory." + ) + + +def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: + """ + A helper function that retrieves an attribute from an object which may + have multiple possible names. This is useful when fetching attributes from + arbitrary `transformers.PretrainedConfig` instances. + """ + for name in names: + if hasattr(object, name): + return getattr(object, name) + return default + + +def contains_object_print(text: str) -> bool: + """ + Check if the text looks like a printed Python object, e.g. + contains any substring matching the pattern: "at 0xFFFFFFF>" + We match against 0x followed by 2-16 hex chars (there's + a max of 16 on a 64-bit system). + + Args: + text (str): The text to check + + Returns: + result (bool): `True` if a match is found, `False` otherwise. + """ + pattern = r"at 0x[a-fA-F0-9]{2,16}>" + match = re.search(pattern, text) + return match is not None + + +def assert_hashable(text: str) -> bool: + if not contains_object_print(text): + return True + raise AssertionError( + f"vLLM tried to hash some configs that may have Python objects ids " + f"in them. This is a bug, please file an issue. " + f"Text being hashed: {text}" + ) + + +def get_attr_docs(cls: type[Any]) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. 
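A short usage sketch of two helpers added in this hunk, assuming the `vllm.config.utils` import path from this diff is available. `contains_object_print` guards config hashing against memory-address reprs, and `getattr_iter` probes alternative attribute names on HF-style configs:

from types import SimpleNamespace

from vllm.config.utils import contains_object_print, getattr_iter

class Opaque:
    pass

# A default __repr__ leaks a memory address, which would make config hashes
# unstable across processes; contains_object_print flags exactly that pattern.
assert contains_object_print(repr(Opaque())) is True
assert contains_object_print("CacheConfig(block_size=16)") is False

# getattr_iter: return the first attribute that exists, else the default.
hf_like = SimpleNamespace(n_embd=4096)
assert getattr_iter(hf_like, ("hidden_size", "n_embd"), default=None) == 4096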
+ + https://davidism.com/mit-license/ + """ + + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + + if not isinstance(cls_node, ast.ClassDef): + raise TypeError("Given object was not a class.") + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if ( + not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str) + ): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. + if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + +@runtime_checkable +class SupportsHash(Protocol): + def compute_hash(self) -> str: ... + + +class SupportsMetricsInfo(Protocol): + def metrics_info(self) -> dict[str, str]: ... + + +def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: + processed_overrides = {} + for field_name, value in overrides.items(): + assert hasattr(config, field_name), ( + f"{type(config)} has no field `{field_name}`" + ) + current_value = getattr(config, field_name) + if is_dataclass(current_value) and not is_dataclass(value): + assert isinstance(value, dict), ( + f"Overrides to {type(config)}.{field_name} must be a dict" + f" or {type(current_value)}, but got {type(value)}" + ) + value = update_config( + current_value, # type: ignore[type-var] + value, + ) + processed_overrides[field_name] = value + return replace(config, **processed_overrides) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py new file mode 100644 index 0000000000000..833581035a318 --- /dev/null +++ b/vllm/config/vllm.py @@ -0,0 +1,876 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import hashlib +import json +import os +from contextlib import contextmanager +from dataclasses import field, replace +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.transformers_utils.runai_utils import is_runai_obj_uri +from vllm.utils import random_uuid + +from .cache import CacheConfig +from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode +from .device import DeviceConfig +from .kv_events import KVEventsConfig +from .kv_transfer import KVTransferConfig +from .load import LoadConfig +from .lora import LoRAConfig +from .model import ModelConfig +from .observability import ObservabilityConfig +from .parallel import ParallelConfig +from .scheduler import SchedulerConfig +from .speculative import SpeculativeConfig +from .structured_outputs import StructuredOutputsConfig +from .utils import SupportsHash, config + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +else: + PretrainedConfig = Any + + QuantizationConfig = Any + +logger = init_logger(__name__) + + +@config 
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + # TODO: use default_factory once default constructing ModelConfig doesn't + # try to download a model + model_config: ModelConfig = None # type: ignore + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" + lora_config: Optional[LoRAConfig] = None + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" + structured_outputs_config: StructuredOutputsConfig = field( + default_factory=StructuredOutputsConfig + ) + """Structured outputs configuration.""" + observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" + quant_config: Optional[QuantizationConfig] = None + """Quantization configuration.""" + compilation_config: CompilationConfig = field(default_factory=CompilationConfig) + """`torch.compile` and cudagraph capture configuration for the model. + + As a shorthand, `-O<n>` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O <n> and -O=<n> are supported as well but this will likely be + removed in favor of clearer -O<n> syntax in the future. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production, also default in V1. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" + kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" + instance_id: str = "" + """The ID of the vLLM instance.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
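+
+        Fields that do not affect the computation graph (for example,
+        `instance_id`) are left out of the hash factors, so two configs
+        that differ only in such fields hash identically:
+
+        ```python
+        # Illustrative sketch; assumes VllmConfig() is constructible with
+        # its default sub-configs in the current environment.
+        cfg_a = VllmConfig()
+        cfg_b = VllmConfig(instance_id="demo")
+        assert cfg_a.compute_hash() == cfg_b.compute_hash()
+        ```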
+ """ + factors: list[Any] = [] + + # summarize vllm config + vllm_factors: list[Any] = [] + from vllm import __version__ + + vllm_factors.append(__version__) + vllm_factors.append(envs.VLLM_USE_V1) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + # LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens)) + else: + vllm_factors.append("None") + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") + if self.structured_outputs_config: + vllm_factors.append(self.structured_outputs_config.compute_hash()) + else: + vllm_factors.append("None") + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + else: + vllm_factors.append("None") + if self.additional_config: + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) + else: + vllm_factors.append("None") + factors.append(vllm_factors) + + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] + return hash_str + + def pad_for_cudagraph(self, batch_size: int) -> int: + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. 
+ # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + from vllm.platforms import current_platform + + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import get_quant_config + + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}" + ) + quant_config.maybe_update_config(model_config.model) + return quant_config + return None + + @staticmethod + def get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config( + copy.deepcopy(model_config), load_config + ) + + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + + def __post_init__(self): + """Verify configs are valid & consistent with each other.""" + + self.try_verify_and_update_config() + + if self.model_config is not None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config(self.load_config) + + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config is not None: + self.lora_config.verify_with_cache_config(self.cache_config) + self.lora_config.verify_with_model_config(self.model_config) + + if self.quant_config is None and self.model_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config + ) + + from vllm.platforms import current_platform + + if ( + self.model_config is not None + and self.scheduler_config.chunked_prefill_enabled + and self.model_config.dtype == torch.float32 + and current_platform.get_device_capability() == (7, 5) + ): + logger.warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels." + ) + + # If the user does not explicitly set a compilation level, then + # we use the default level. The default level depends on other + # settings (see the below code). 
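+        # Roughly: V1 with a model set and eager execution not forced defaults
+        # to CompilationLevel.PIECEWISE; every other case (including V0)
+        # defaults to CompilationLevel.NO_COMPILATION.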
+ if self.compilation_config.level is None: + if envs.VLLM_USE_V1: + if ( + self.model_config is not None + and not self.model_config.enforce_eager + ): + self.compilation_config.level = CompilationLevel.PIECEWISE + else: + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + else: + # NB: Passing both --enforce-eager and a compilation level + # in V0 means the compilation level wins out. + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + # async tp is built on top of sequence parallelism + # and requires it to be enabled. + if self.compilation_config.pass_config.enable_async_tp: + self.compilation_config.pass_config.enable_sequence_parallelism = True + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") + + if current_platform.support_static_graph_mode(): + # if cudagraph_mode is not explicitly set by users, set default + # value + if self.compilation_config.cudagraph_mode is None: + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): + # default to full and piecewise for most models + self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_AND_PIECEWISE + ) + + # pooling models and encoder-decoder models + # do not support full cudagraphs + if self.model_config is not None and ( + self.model_config.pooler_config is not None + or self.model_config.is_encoder_decoder + ): + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + # disable cudagraph when enforce eager execution + if self.model_config is not None and self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + elif envs.VLLM_USE_V1: + self.compilation_config.cudagraph_num_of_warmups = 1 + + self._set_cudagraph_sizes() + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if self.cache_config.kv_sharing_fast_prefill: + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not " + "compatible with EAGLE as EAGLE requires correct logits " + "for all tokens while fast prefill gives incorrect logits " + "for prompt tokens." + ) + + logger.warning_once( + "--kv-sharing-fast-prefill requires changes on model side for " + "correctness and to realize prefill savings. " + ) + + disable_chunked_prefill_reasons: list[str] = [] + + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + 'Only "last" pooling supports chunked ' + "prefill and prefix caching; disabling both." + ) + if not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention supports chunked " + "prefill and prefix caching; disabling both." 
+ ) + elif self.model_config.is_encoder_decoder: + from vllm.multimodal import MULTIMODAL_REGISTRY + + self.scheduler_config.max_num_encoder_input_tokens = ( + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + ) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens, + ) + if ( + self.model_config.architecture == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" + ): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'." + ) + + # Final off-switch for CP/APC: + # Disable for (a) collected blockers, (b) encoder–decoder, or + # (c) explicit CP=False when APC wasn't requested. + # Do NOT disable merely because the resolved CP flag is False. + apc_requested = ( + self.cache_config is not None and self.cache_config.enable_prefix_caching + ) + if ( + disable_chunked_prefill_reasons + or (self.model_config is not None and self.model_config.is_encoder_decoder) + or ( + self.scheduler_config.enable_chunked_prefill is False + and not apc_requested + ) + ): + for reason in disable_chunked_prefill_reasons: + logger.info(reason) + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.long_prefill_token_threshold = 0 + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + + if ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching + ): + logger.warning( + "KV cache events are on, but prefix caching is not enabled." + "Use --enable-prefix-caching to enable." + ) + if ( + self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events + ): + logger.warning( + "KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable." + ) + current_platform.check_and_update_config(self) + + # Do this after all the updates to compilation_config.level + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): + self.compilation_config.set_splitting_ops_for_v1() + + # final check of cudagraph mode after all possible updates + if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and self.model_config is not None + and not self.model_config.disable_cascade_attn + and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501 + ): + logger.warning_once( + "No piecewise cudagraph for executing cascade attention." 
+ " Will fall back to eager execution if a batch runs " + "into cascade attentions" + ) + + if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): + assert self.compilation_config.level == CompilationLevel.PIECEWISE, ( + "Compilation level should be CompilationLevel.PIECEWISE " + "when cudagraph_mode piecewise cudagraphs is used, " + f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + ) + + # final migrate the deprecated flags + self.compilation_config.use_cudagraph = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) + self.compilation_config.full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + + if self.parallel_config.enable_dbo: + a2a_backend = envs.VLLM_ALL2ALL_BACKEND + assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + "Microbatching currently only supports the deepep_low_latency and " + f"deepep_high_throughput all2all backend. {a2a_backend} is not " + "supported. To fix set the VLLM_ALL2ALL_BACKEND environment " + "variable to deepep_low_latency or deepep_high_throughput and " + "install the DeepEP kernels." + ) + + if not self.model_config.disable_cascade_attn: + self.model_config.disable_cascade_attn = True + logger.warning_once("Disabling cascade attention when DBO is enabled.") + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + if ( + envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager + ): + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_transfer_config is not None: + # Hybrid KV cache manager is not compatible with KV transfer. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." + ) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. 
+ self.scheduler_config.disable_hybrid_kv_cache_manager = True + + if self.compilation_config.debug_dump_path: + self.compilation_config.debug_dump_path = ( + self.compilation_config.debug_dump_path.absolute().expanduser() + ) + if envs.VLLM_DEBUG_DUMP_PATH is not None: + env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser() + if self.compilation_config.debug_dump_path: + logger.warning( + "Config-specified debug dump path is overridden" + " by VLLM_DEBUG_DUMP_PATH to %s", + env_path, + ) + self.compilation_config.debug_dump_path = env_path + + def has_blocked_weights(): + if self.quant_config is not None: + if hasattr(self.quant_config, "weight_block_size"): + return self.quant_config.weight_block_size is not None + elif hasattr(self.quant_config, "has_blocked_weights"): + return self.quant_config.has_blocked_weights() + return False + + # Enable quant_fp8 CUDA ops (TODO disable in follow up) + # On H100 the CUDA kernel is faster than + # native implementation + # https://github.com/vllm-project/vllm/issues/25094 + if has_blocked_weights(): + custom_ops = self.compilation_config.custom_ops + if "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") + + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size + for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", + removed_sizes, + self.parallel_config.tensor_parallel_size, + ) + + return [ + size + for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + + def _set_cudagraph_sizes(self): + """ + vLLM defines the default candidate list of batch sizes for CUDA graph + capture as: + + ```python + max_graph_size = min(max_num_seqs * 2, 512) + # 1, 2, 4, then multiples of 8 up to max_graph_size + cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). + + These sizes are used to capture and reuse CUDA graphs for + performance-critical paths (e.g., decoding). Capturing enables + significantly faster kernel dispatch by avoiding Python overhead. The + list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on + most GPUs), which controls the total allowed number of tokens in a + batch. Since each sequence may have a variable number of tokens, the + maximum usable batch size will depend on actual sequence lengths. + + Example: + With `max_num_batched_tokens = 8192`, and typical sequences + averaging ~32 tokens, most practical batch sizes fall below 256. + However, the system will still allow capture sizes up to 512 if + shape and memory permit. + + Note: + If users explicitly specify cudagraph capture sizes in the + compilation config, those will override this default logic. + At runtime: + + - If batch size <= one of the `cudagraph_capture_sizes`, the closest + padded CUDA graph will be used. + - If batch size > largest `cudagraph_capture_sizes`, cudagraph will + not be used. 
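+
+        For instance, with a default `cuda_graph_sizes` derived from a
+        hypothetical `max_num_seqs = 32` (so `max_graph_size = 64`), the
+        candidate list above works out to:
+
+        ```python
+        [1, 2, 4] + list(range(8, 64 + 1, 8))
+        # -> [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
+        ```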
+ """ + + # calculate the default `batch_size_capture_list` + batch_size_capture_list = [] + if self.model_config is not None and not self.model_config.enforce_eager: + cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + max_graph_size = cuda_graph_sizes[0] + assert max_graph_size >= 1, ( + "Maximum cudagraph size should be greater than or equal to 1." + ) + batch_size_capture_list = [ + i for i in [1, 2, 4] if i <= max_graph_size + ] + list(range(8, max_graph_size + 1, 8)) + elif len(cuda_graph_sizes) > 1: + batch_size_capture_list = sorted(cuda_graph_sizes) + else: + raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + if ( + self.parallel_config.tensor_parallel_size > 1 + and self.compilation_config.pass_config.enable_sequence_parallelism + ): + batch_size_capture_list = self.update_sizes_for_sequence_parallelism( + batch_size_capture_list + ) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + batch_size_capture_list = [ + size for size in batch_size_capture_list if size <= max_num_tokens + ] + + self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list) + + def recalculate_max_model_len(self, max_model_len: int): + # Can only be called in try_verify_and_update_config + model_config = self.model_config + max_model_len = model_config.get_and_verify_max_len(max_model_len) + self.model_config.max_model_len = max_model_len + self.scheduler_config.max_model_len = max_model_len + + def try_verify_and_update_config(self): + if self.model_config is None: + return + + # Avoid running try_verify_and_update_config multiple times + if getattr(self.model_config, "config_updated", False): + return + self.model_config.config_updated = True + + architecture = self.model_config.architecture + if architecture is None: + return + + from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, + HybridAttentionMambaModelConfig, + ) + + cls = MODELS_CONFIG_MAP.get(architecture, None) + if cls is not None: + cls.verify_and_update_config(self) + + if self.model_config.is_hybrid: + HybridAttentionMambaModelConfig.verify_and_update_config(self) + + if self.model_config.convert_type == "classify": + # Maybe convert ForCausalLM into ForSequenceClassification model. + from vllm.model_executor.models.adapters import SequenceClassificationConfig + + SequenceClassificationConfig.verify_and_update_config(self) + + if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( + self.model_config.model_weights + ): + if self.load_config.load_format == "auto": + logger.info( + "Detected Run:ai model config. " + "Overriding `load_format` to 'runai_streamer'" + ) + self.load_config.load_format = "runai_streamer" + elif self.load_config.load_format != "runai_streamer": + raise ValueError( + f"To load a model from S3, 'load_format' " + f"must be 'runai_streamer', " + f"but got '{self.load_config.load_format}'. " + f"Model: {self.model_config.model}" + ) + + def compile_debug_dump_path(self) -> Optional[Path]: + """Returns a rank-aware path for dumping + torch.compile debug information. 
+ """ + if self.compilation_config.debug_dump_path is None: + return None + tp_rank = self.parallel_config.rank + dp_rank = self.parallel_config.data_parallel_rank + data_parallel_size = self.parallel_config.data_parallel_size + append_path = ( + f"rank_{tp_rank}" + if data_parallel_size == 1 + else f"rank_{tp_rank}_dp_{dp_rank}" + ) + path = self.compilation_config.debug_dump_path / append_path + return path + + def __str__(self): + return ( + f"model={self.model_config.model!r}, " + f"speculative_config={self.speculative_config!r}, " + f"tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " + f"tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}, " + f"download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa + f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f"device_config={self.device_config.device}, " + f"structured_outputs_config={self.structured_outputs_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"pooler_config={self.model_config.pooler_config!r}, " + f"compilation_config={self.compilation_config!r}" + ) + + +_current_vllm_config: Optional[VllmConfig] = None +_current_prefix: Optional[str] = None + + +@contextmanager +def set_current_vllm_config( + vllm_config: VllmConfig, check_compile=False, prefix: Optional[str] = None +): + """ + Temporarily set the current vLLM config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. + """ + global _current_vllm_config, _current_prefix + old_vllm_config = _current_vllm_config + old_prefix = _current_prefix + from vllm.compilation.counter import compilation_counter + + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + _current_prefix = prefix + yield + except Exception: + raise + else: + if check_compile: + vllm_config.compilation_config.custom_op_log_check() + + if ( + check_compile + and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and compilation_counter.num_models_seen == num_models_seen + ): + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. 
Please open an issue on GitHub" + " if you want it to be supported.", + vllm_config.model_config.model, + ) + finally: + _current_vllm_config = old_vllm_config + _current_prefix = old_prefix + # Clear the compilation config cache when context changes + get_cached_compilation_config.cache_clear() + + +@lru_cache(maxsize=1) +def get_cached_compilation_config(): + """Cache config to avoid repeated calls to get_current_vllm_config()""" + return get_current_vllm_config().compilation_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current vLLM config is not set.") + return VllmConfig() + return _current_vllm_config + + +T = TypeVar("T") + + +def get_layers_from_vllm_config( + vllm_config: VllmConfig, + layer_type: type[T], + layer_names: Optional[list[str]] = None, +) -> dict[str, T]: + """ + Get layers from the vLLM config. + + Args: + vllm_config: The vLLM config. + layer_type: The type of the layer to get. + layer_names: The names of the layers to get. If None, return all layers. + """ + + if layer_names is None: + layer_names = list(vllm_config.compilation_config.static_forward_context.keys()) + + forward_context = vllm_config.compilation_config.static_forward_context + + return { + layer_name: forward_context[layer_name] + for layer_name in layer_names + if isinstance(forward_context[layer_name], layer_type) + } diff --git a/vllm/connections.py b/vllm/connections.py index 103505eb3d81f..8d5e0e5cbf5d0 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -41,8 +41,9 @@ class HTTPConnection: parsed_url = urlparse(url) if parsed_url.scheme not in ("http", "https"): - raise ValueError("Invalid HTTP URL: A valid HTTP URL " - "must have scheme 'http' or 'https'.") + raise ValueError( + "Invalid HTTP URL: A valid HTTP URL must have scheme 'http' or 'https'." 
+ ) def _headers(self, **extras: str) -> MutableMapping[str, str]: return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} @@ -54,16 +55,20 @@ class HTTPConnection: stream: bool = False, timeout: Optional[float] = None, extra_headers: Optional[Mapping[str, str]] = None, + allow_redirects: bool = True, ): self._validate_http_url(url) client = self.get_sync_client() extra_headers = extra_headers or {} - return client.get(url, - headers=self._headers(**extra_headers), - stream=stream, - timeout=timeout) + return client.get( + url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout, + allow_redirects=allow_redirects, + ) async def get_async_response( self, @@ -71,18 +76,26 @@ class HTTPConnection: *, timeout: Optional[float] = None, extra_headers: Optional[Mapping[str, str]] = None, + allow_redirects: bool = True, ): self._validate_http_url(url) client = await self.get_async_client() extra_headers = extra_headers or {} - return client.get(url, - headers=self._headers(**extra_headers), - timeout=timeout) + return client.get( + url, + headers=self._headers(**extra_headers), + timeout=timeout, + allow_redirects=allow_redirects, + ) - def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: - with self.get_response(url, timeout=timeout) as r: + def get_bytes( + self, url: str, *, timeout: Optional[float] = None, allow_redirects: bool = True + ) -> bytes: + with self.get_response( + url, timeout=timeout, allow_redirects=allow_redirects + ) as r: r.raise_for_status() return r.content @@ -92,8 +105,11 @@ class HTTPConnection: url: str, *, timeout: Optional[float] = None, + allow_redirects: bool = True, ) -> bytes: - async with await self.get_async_response(url, timeout=timeout) as r: + async with await self.get_async_response( + url, timeout=timeout, allow_redirects=allow_redirects + ) as r: r.raise_for_status() return await r.read() diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py deleted file mode 100644 index 444bb25f2830a..0000000000000 --- a/vllm/core/block/block_table.py +++ /dev/null @@ -1,399 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from typing import List, Optional - -from vllm.core.block.common import BlockList -from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator -from vllm.utils import Device, cdiv, chunk_list - - -class BlockTable: - """A class to manage blocks for a specific sequence. - - The BlockTable maps a sequence of tokens to a list of blocks, where each - block represents a contiguous memory allocation for a portion of the - sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is - responsible for allocating and freeing memory for the blocks. - - Args: - block_size (int): The maximum number of tokens that can be stored in a - single block. - block_allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]], optional): An optional list of existing - blocks to initialize the BlockTable with. If not provided, an empty - BlockTable is created. - max_block_sliding_window (Optional[int], optional): The number of - blocks to keep around for each sequence. If None, all blocks - are kept (eg., when sliding window is not used). - It should at least fit the sliding window size of the model. - - Attributes: - _block_size (int): The maximum number of tokens that can be stored in a - single block. 
- _allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]]): The list of blocks managed by this - BlockTable. - _num_full_slots (int): The number of tokens currently stored in the - blocks. - """ - - def __init__( - self, - block_size: int, - block_allocator: DeviceAwareBlockAllocator, - _blocks: Optional[List[Block]] = None, - max_block_sliding_window: Optional[int] = None, - ): - self._block_size = block_size - self._allocator = block_allocator - if _blocks is None: - _blocks = [] - self._blocks: BlockList = BlockList(_blocks) - - self._max_block_sliding_window = max_block_sliding_window - self._num_full_slots = self._get_num_token_ids() - - @staticmethod - def get_num_required_blocks(token_ids: List[int], - block_size: int, - num_lookahead_slots: int = 0) -> int: - """Calculates the minimum number of blocks required to store a given - sequence of token IDs along with any look-ahead slots that may be - required (like in multi-step + chunked-prefill). - - This assumes worst-case scenario, where every block requires a new - allocation (e.g. ignoring prefix caching). - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - block_size (int): The maximum number of tokens that can be stored in - a single block. - num_lookahead_slots (int): look-ahead slots that the sequence may - require. - - Returns: - int: The minimum number of blocks required to store the given - sequence of token IDs along with any required look-ahead slots. - """ - return cdiv(len(token_ids) + num_lookahead_slots, block_size) - - def allocate(self, - token_ids: List[int], - device: Device = Device.GPU, - extra_hash: Optional[int] = None) -> None: - """Allocates memory blocks for storing the given sequence of token IDs. - - This method allocates the required number of blocks to store the given - sequence of token IDs. - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - device (Device, optional): The device on which the blocks should be - allocated. Defaults to Device.GPU. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefixcaching block. - """ - assert not self._is_allocated - assert token_ids - blocks = self._allocate_blocks_for_token_ids(prev_block=None, - token_ids=token_ids, - device=device, - extra_hash=extra_hash) - self.update(blocks) - self._num_full_slots = len(token_ids) - - def update(self, blocks: List[Block]) -> None: - """Resets the table to the newly provided blocks - (with their corresponding block ids) - """ - self._blocks.update(blocks) - - def append_token_ids(self, - token_ids: List[int], - num_lookahead_slots: int = 0, - num_computed_slots: Optional[int] = None, - extra_hash: Optional[int] = None) -> None: - """Appends a sequence of token IDs to the existing blocks in the - BlockTable. - - This method appends the given sequence of token IDs to the existing - blocks in the BlockTable. If there is not enough space in the existing - blocks, new blocks are allocated using the `ensure_num_empty_slots` - method to accommodate the additional tokens. - - The token IDs are divided into chunks of size `block_size` (except for - the first chunk, which may be smaller), and each chunk is appended to a - separate block. - - Args: - token_ids (List[int]): The sequence of token IDs to be appended. - num_computed_slots (Optional[int]): The number of KV cache slots - that are already filled (computed). 
- When sliding window is enabled, this is used to compute how many - blocks to drop at the front of the sequence. - Without sliding window, None can be passed. - Without chunked prefill, it should be the same as - _num_full_slots. - extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. - """ - assert self._is_allocated, "no blocks have been allocated" - assert len(self._blocks) > 0 - - # Drop blocks that are no longer needed due to sliding window - if self._max_block_sliding_window is not None: - null_block = self._allocator.allocate_or_get_null_block() - assert num_computed_slots is not None - end_block_idx = (num_computed_slots // - self._block_size) - self._max_block_sliding_window - for idx in range(0, end_block_idx): - b = self._blocks[idx] - if b is not null_block: - self._allocator.free(b) - self._blocks[idx] = null_block - - # Ensure there are enough empty slots for the new tokens plus - # lookahead slots - self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + - num_lookahead_slots, - extra_hash=extra_hash) - - # Update the blocks with the new tokens - first_block_idx = self._num_full_slots // self._block_size - token_blocks = self._chunk_token_blocks_for_append(token_ids) - - for i, token_block in enumerate(token_blocks): - self._blocks.append_token_ids(first_block_idx + i, token_block) - - self._num_full_slots += len(token_ids) - - def ensure_num_empty_slots(self, - num_empty_slots: int, - extra_hash: Optional[int] = None) -> None: - """Ensures that the BlockTable has at least the specified number of - empty slots available. - - This method checks if the BlockTable has enough empty slots (i.e., - available space) to accommodate the requested number of tokens. If not, - it allocates additional blocks on the GPU to ensure that the required - number of empty slots is available. - - Args: - num_empty_slots (int): The minimum number of empty slots required. - extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. - """ - # Currently the block table only supports - # appending tokens to GPU blocks. - device = Device.GPU - assert self._is_allocated - - if self._num_empty_slots >= num_empty_slots: - return - - slots_to_allocate = num_empty_slots - self._num_empty_slots - blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) - - for _ in range(blocks_to_allocate): - assert len(self._blocks) > 0 - self._blocks.append( - self._allocator.allocate_mutable_block( - prev_block=self._blocks[-1], - device=device, - extra_hash=extra_hash)) - - def fork(self) -> "BlockTable": - """Creates a new BlockTable instance with a copy of the blocks from the - current instance. - - This method creates a new BlockTable instance with the same block size, - block allocator, and a copy of the blocks from the current instance. The - new BlockTable has its own independent set of blocks, but shares the - same underlying memory allocation with the original BlockTable. - - Returns: - BlockTable: A new BlockTable instance with a copy of the blocks from - the current instance. 
- """ - assert self._is_allocated - assert len(self._blocks) > 0 - forked_blocks = self._allocator.fork(self._blocks[-1]) - return BlockTable( - block_size=self._block_size, - block_allocator=self._allocator, - _blocks=forked_blocks, - max_block_sliding_window=self._max_block_sliding_window, - ) - - def free(self) -> None: - """Frees the memory occupied by the blocks in the BlockTable. - - This method iterates over all the blocks in the `_blocks` list and calls - the `free` method of the `_allocator` object to release the memory - occupied by each block. After freeing all the blocks, the `_blocks` list - is set to `None`. - """ - for block in self.blocks: - self._allocator.free(block) - self._blocks.reset() - - @property - def physical_block_ids(self) -> List[int]: - """Returns a list of physical block indices for the blocks in the - BlockTable. - - This property returns a list of integers, where each integer represents - the physical block index of a corresponding block in the `_blocks` list. - The physical block index is a unique identifier for the memory location - occupied by the block. - - Returns: - List[int]: A list of physical block indices for the blocks in the - BlockTable. - """ - return self._blocks.ids() - - def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: - """Get the number of "unseen" tokens in the sequence. - - Unseen tokens are tokens in the sequence corresponding to this block - table, but are not yet appended to this block table. - - Args: - sequence_token_ids (List[int]): The list of token ids in the - sequence. - - Returns: - List[int]: The postfix of sequence_token_ids that has not yet been - appended to the block table. - """ - - # Since the block table is append-only, the unseen token ids are the - # ones after the appended ones. - return sequence_token_ids[self.num_full_slots:] - - def _allocate_blocks_for_token_ids( - self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> List[Block]: - blocks: List[Block] = [] - - block_token_ids = [] - tail_token_ids = [] - for cur_token_ids in chunk_list(token_ids, self._block_size): - if len(cur_token_ids) == self._block_size: - block_token_ids.append(cur_token_ids) - else: - tail_token_ids.append(cur_token_ids) - - if block_token_ids: - blocks.extend( - self._allocator.allocate_immutable_blocks( - prev_block, - block_token_ids=block_token_ids, - device=device, - extra_hash=extra_hash)) - prev_block = blocks[-1] - - if tail_token_ids: - assert len(tail_token_ids) == 1 - cur_token_ids = tail_token_ids[0] - - block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device, extra_hash=extra_hash) - block.append_token_ids(cur_token_ids) - - blocks.append(block) - - return blocks - - def _get_all_token_ids(self) -> List[int]: - # NOTE: This function is O(seq_len); use sparingly. 
- token_ids: List[int] = [] - - if not self._is_allocated: - return token_ids - - for block in self.blocks: - token_ids.extend(block.token_ids) - - return token_ids - - def _get_num_token_ids(self) -> int: - res = 0 - for block in self.blocks: - res += len(block.token_ids) - - return res - - @property - def _is_allocated(self) -> bool: - return len(self._blocks) > 0 - - @property - def blocks(self) -> List[Block]: - return self._blocks.list() - - @property - def _num_empty_slots(self) -> int: - assert self._is_allocated - return len(self._blocks) * self._block_size - self._num_full_slots - - @property - def num_full_slots(self) -> int: - """Returns the total number of tokens currently stored in the - BlockTable. - - Returns: - int: The total number of tokens currently stored in the BlockTable. - """ - return self._num_full_slots - - def get_num_blocks_touched_by_append_slots( - self, token_ids: List[int], num_lookahead_slots: int) -> int: - """Determine how many blocks will be "touched" by appending the token - ids. - - This is required for the scheduler to determine whether a sequence can - continue generation, or if it must be preempted. - """ - # Math below is equivalent to: - # all_token_ids = token_ids + [-1] * num_lookahead_slots - # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) - # return len(token_blocks) - - num_token_ids = len(token_ids) + num_lookahead_slots - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - num_token_blocks = (1 + math.ceil( - (num_token_ids - first_chunk_size) / self._block_size)) - return num_token_blocks - - def _chunk_token_blocks_for_append( - self, token_ids: List[int]) -> List[List[int]]: - """Split the token ids into block-sized chunks so they can be easily - appended to blocks. The first such "token block" may have less token ids - than the block size, since the last allocated block may be partially - full. - - If no token ids are provided, then no chunks are returned. - """ - - if not token_ids: - return [] - - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] - token_blocks.extend( - chunk_list(token_ids[first_chunk_size:], self._block_size)) - return token_blocks diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py deleted file mode 100644 index a337007a9eaa6..0000000000000 --- a/vllm/core/block/common.py +++ /dev/null @@ -1,371 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections import deque -from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple - -from vllm.core.block.interfaces import Block, BlockAllocator - -BlockId = int -RefCount = int - - -class RefCounterProtocol(Protocol): - - def incr(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - def decr(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - def get(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - -class RefCounter(RefCounterProtocol): - """A class for managing reference counts for a set of block indices. - - The RefCounter class maintains a dictionary that maps block indices to their - corresponding reference counts. It provides methods to increment, decrement, - and retrieve the reference count for a given block index. - - Args: - all_block_indices (Iterable[BlockId]): An iterable of block indices - to initialize the reference counter with. 
- """ - - def __init__(self, all_block_indices: Iterable[BlockId]): - deduped = set(all_block_indices) - self._refcounts: Dict[BlockId, RefCount] = { - index: 0 - for index in deduped - } - - def incr(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - pre_incr_refcount = self._refcounts[block_id] - - assert pre_incr_refcount >= 0 - - post_incr_refcount = pre_incr_refcount + 1 - self._refcounts[block_id] = post_incr_refcount - return post_incr_refcount - - def decr(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - refcount = self._refcounts[block_id] - - assert refcount > 0 - refcount -= 1 - - self._refcounts[block_id] = refcount - - return refcount - - def get(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - return self._refcounts[block_id] - - def as_readonly(self) -> "ReadOnlyRefCounter": - return ReadOnlyRefCounter(self) - - -class ReadOnlyRefCounter(RefCounterProtocol): - """A read-only view of the RefCounter class. - - The ReadOnlyRefCounter class provides a read-only interface to access the - reference counts maintained by a RefCounter instance. It does not allow - modifications to the reference counts. - - Args: - refcounter (RefCounter): The RefCounter instance to create a read-only - view for. - """ - - def __init__(self, refcounter: RefCounter): - self._refcounter = refcounter - - def incr(self, block_id: BlockId) -> RefCount: - raise ValueError("Incr not allowed") - - def decr(self, block_id: BlockId) -> RefCount: - raise ValueError("Decr not allowed") - - def get(self, block_id: BlockId) -> RefCount: - return self._refcounter.get(block_id) - - -class CopyOnWriteTracker: - """A class for tracking and managing copy-on-write operations for blocks. - - The CopyOnWriteTracker class maintains a mapping of source block indices to - their corresponding copy-on-write destination block indices. It works in - conjunction with a RefCounter. - - Args: - refcounter (RefCounter): The reference counter used to track block - reference counts. - """ - - def __init__(self, refcounter: RefCounterProtocol): - self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] - self._refcounter = refcounter - - def is_appendable(self, block: Block) -> bool: - """Checks if the block is shared or not. If shared, then it cannot - be appended and needs to be duplicated via copy-on-write - """ - block_id = block.block_id - if block_id is None: - return True - - refcount = self._refcounter.get(block_id) - return refcount <= 1 - - def record_cow(self, src_block_id: Optional[BlockId], - trg_block_id: Optional[BlockId]) -> None: - """Records a copy-on-write operation from source to target block id - Args: - src_block_id (BlockId): The source block id from which to copy - the data - trg_block_id (BlockId): The target block id to which the data - is copied - """ - assert src_block_id is not None - assert trg_block_id is not None - self._copy_on_writes.append((src_block_id, trg_block_id)) - - def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: - """Clears the copy-on-write tracking information and returns the current - state. - - This method returns a list mapping source block indices to - destination block indices for the current copy-on-write operations. - It then clears the internal tracking information. - - Returns: - List[Tuple[BlockId, BlockId]]: A list mapping source - block indices to destination block indices for the - current copy-on-write operations. 
- """ - cows = self._copy_on_writes - self._copy_on_writes = [] - return cows - - -class BlockPool: - """Used to pre-allocate block objects, in order to avoid excessive python - object allocations/deallocations. - The pool starts from "pool_size" objects and will increase to more objects - if necessary - - Note that multiple block objects may point to the same physical block id, - which is why this pool is needed, so that it will be easier to support - prefix caching and more complicated sharing of physical blocks. - """ - - def __init__(self, block_size: int, create_block: Block.Factory, - allocator: BlockAllocator, pool_size: int): - self._block_size = block_size - self._create_block = create_block - self._allocator = allocator - self._pool_size = pool_size - assert self._pool_size >= 0 - - self._free_ids: Deque[int] = deque(range(self._pool_size)) - self._pool = [] - for i in range(self._pool_size): - self._pool.append( - self._create_block(prev_block=None, - token_ids=[], - block_size=self._block_size, - allocator=self._allocator, - block_id=None, - extra_hash=None)) - - def increase_pool(self): - """Doubles the internal pool size - """ - cur_pool_size = self._pool_size - new_pool_size = cur_pool_size * 2 - self._pool_size = new_pool_size - - self._free_ids += deque(range(cur_pool_size, new_pool_size)) - - for i in range(cur_pool_size, new_pool_size): - self._pool.append( - self._create_block(prev_block=None, - token_ids=[], - block_size=self._block_size, - allocator=self._allocator, - block_id=None, - extra_hash=None)) - - def init_block(self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - physical_block_id: Optional[int], - extra_hash: Optional[int] = None) -> Block: - if len(self._free_ids) == 0: - self.increase_pool() - assert len(self._free_ids) > 0 - - pool_id = self._free_ids.popleft() - - block = self._pool[pool_id] - block.__init__( # type: ignore[misc] - prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - allocator=block._allocator, # type: ignore[attr-defined] - block_id=physical_block_id, - extra_hash=extra_hash) - block.pool_id = pool_id # type: ignore[attr-defined] - return block - - def free_block(self, block: Block) -> None: - self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] - - -class BlockList: - """This class is an optimization to allow fast-access to physical - block ids. 
It maintains a block id list that is updated with the - block list and this avoids the need to reconstruct the block id - list on every iteration of the block manager - """ - - def __init__(self, blocks: List[Block]): - self._blocks: List[Block] = [] - self._block_ids: List[int] = [] - - self.update(blocks) - - def _add_block_id(self, block_id: Optional[BlockId]) -> None: - assert block_id is not None - self._block_ids.append(block_id) - - def _update_block_id(self, block_index: int, - new_block_id: Optional[BlockId]) -> None: - assert new_block_id is not None - self._block_ids[block_index] = new_block_id - - def update(self, blocks: List[Block]): - self._blocks = blocks - - # Cache block ids for fast query - self._block_ids = [] - for block in self._blocks: - self._add_block_id(block.block_id) - - def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: - block = self._blocks[block_index] - prev_block_id = block.block_id - - block.append_token_ids(token_ids) - - # CoW or promotion may update the internal block_id - if prev_block_id != block.block_id: - self._update_block_id(block_index, block.block_id) - - def append(self, new_block: Block): - self._blocks.append(new_block) - self._add_block_id(new_block.block_id) - - def __len__(self) -> int: - return len(self._blocks) - - def __getitem__(self, block_index: int) -> Block: - return self._blocks[block_index] - - def __setitem__(self, block_index: int, new_block: Block) -> None: - self._blocks[block_index] = new_block - self._update_block_id(block_index, new_block.block_id) - - def reset(self): - self._blocks = [] - self._block_ids = [] - - def list(self) -> List[Block]: - return self._blocks - - def ids(self) -> List[int]: - return self._block_ids - - -@dataclass -class CacheMetricData: - """A utility dataclass to maintain cache metric. - To avoid overflow, we maintain the hit rate in block granularity, so that - we can maintain a single hit rate for n_completed_block x block_size, - and calculate the real time hit rate by the following: - BS = The number of queries per block. - nB = The number of completed blocks. - HR = hit rate of (nB x BS) queries. - Q = current number of queries (< BS). - H = current number of hits (< BS). - hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) - """ - num_completed_blocks: int = 0 - completed_block_cache_hit_rate: float = 0.0 - num_incompleted_block_queries: int = 0 - num_incompleted_block_hit: int = 0 - block_size: int = 1000 - - def query(self, hit: bool): - self.num_incompleted_block_queries += 1 - self.num_incompleted_block_hit += 1 if hit else 0 - - # When a block is completed, update the cache hit rate - # and reset the incomplete numbers. 
- if self.num_incompleted_block_queries == self.block_size: - hit_rate = (self.num_incompleted_block_hit / - self.num_incompleted_block_queries) - self.completed_block_cache_hit_rate = ( - self.completed_block_cache_hit_rate * self.num_completed_blocks - + hit_rate) / (self.num_completed_blocks + 1) - self.num_incompleted_block_queries = 0 - self.num_incompleted_block_hit = 0 - self.num_completed_blocks += 1 - - def get_hit_rate(self): - incomplete_ratio = self.num_incompleted_block_queries / self.block_size - total_blocks = self.num_completed_blocks + incomplete_ratio - if total_blocks == 0: - return 0.0 - - completed_block_hit, incompleted_block_hit = 0.0, 0.0 - if self.num_completed_blocks > 0: - completed_block_hit = (self.completed_block_cache_hit_rate * - self.num_completed_blocks) - if self.num_incompleted_block_queries > 0: - incompleted_hit_rate = (self.num_incompleted_block_hit / - self.num_incompleted_block_queries) - incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) - return (completed_block_hit + incompleted_block_hit) / total_blocks - - -def get_all_blocks_recursively(last_block: Block) -> List[Block]: - """Retrieves all the blocks in a sequence starting from the last block. - - This function recursively traverses the sequence of blocks in reverse order, - starting from the given last block, and returns a list of all the blocks in - the sequence. - - Args: - last_block (Block): The last block in the sequence. - - Returns: - List[Block]: A list of all the blocks in the sequence, in the order they - appear. - """ - - def recurse(block: Block, lst: List[Block]) -> None: - if block.prev_block is not None: - recurse(block.prev_block, lst) - lst.append(block) - - all_blocks: List[Block] = [] - recurse(last_block, all_blocks) - return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py deleted file mode 100644 index 92bc5e157e148..0000000000000 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ /dev/null @@ -1,439 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Dict, FrozenSet, List, Optional, Tuple - -from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, - DeviceAwareBlockAllocator) -from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator -from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -from vllm.utils import Device - - -class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - """A block allocator that can allocate blocks on both CPU and GPU memory. - - This class implements the `DeviceAwareBlockAllocator` interface and provides - functionality for allocating and managing blocks of memory on both CPU and - GPU devices. - - The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU - blocks, and allows for allocation, deallocation, forking, and swapping of - blocks across these memory pools. - """ - - @staticmethod - def create( - allocator_type: str, - num_gpu_blocks: int, - num_cpu_blocks: int, - block_size: int, - ) -> DeviceAwareBlockAllocator: - """Creates a CpuGpuBlockAllocator instance with the specified - configuration. - - This static method creates and returns a CpuGpuBlockAllocator instance - based on the provided parameters. It initializes the CPU and GPU block - allocators with the specified number of blocks, block size, and - allocator type. 
- - Args: - allocator_type (str): The type of block allocator to use for CPU - and GPU blocks. Currently supported values are "naive" and - "prefix_caching". - num_gpu_blocks (int): The number of blocks to allocate for GPU - memory. - num_cpu_blocks (int): The number of blocks to allocate for CPU - memory. - block_size (int): The size of each block in number of tokens. - - Returns: - DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the - specified configuration. - - Notes: - - The block IDs are assigned contiguously, with GPU block IDs coming - before CPU block IDs. - """ - reserved_blocks = 0 - block_ids = list( - range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) - num_gpu_blocks -= reserved_blocks - gpu_block_ids = block_ids[:num_gpu_blocks] - cpu_block_ids = block_ids[num_gpu_blocks:] - - if allocator_type == "naive": - gpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, - ) - - cpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, - ) - elif allocator_type == "prefix_caching": - gpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, - ) - - cpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, - ) - else: - raise ValueError(f"Unknown allocator type {allocator_type=}") - - return CpuGpuBlockAllocator( - cpu_block_allocator=cpu_allocator, - gpu_block_allocator=gpu_allocator, - ) - - def __init__(self, cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator): - assert not ( - cpu_block_allocator.all_block_ids - & gpu_block_allocator.all_block_ids - ), "cpu and gpu block allocators can't have intersection of block ids" - - self._allocators = { - Device.CPU: cpu_block_allocator, - Device.GPU: gpu_block_allocator, - } - - self._swap_mapping: Dict[int, int] = {} - self._null_block: Optional[Block] = None - - self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} - for _, allocator in self._allocators.items(): - for block_id in allocator.all_block_ids: - self._block_ids_to_allocator[block_id] = allocator - - def allocate_or_get_null_block(self) -> Block: - if self._null_block is None: - self._null_block = NullBlock( - self.allocate_mutable_block(None, Device.GPU)) - return self._null_block - - def allocate_mutable_block(self, - prev_block: Optional[Block], - device: Device, - extra_hash: Optional[int] = None) -> Block: - """Allocates a new mutable block on the specified device. - - Args: - prev_block (Optional[Block]): The previous block to in the sequence. - Used for prefix hashing. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - Block: The newly allocated mutable block. - """ - return self._allocators[device].allocate_mutable_block( - prev_block, extra_hash=extra_hash) - - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device, - extra_hash: Optional[int] = None) -> List[Block]: - """Allocates a new group of immutable blocks with the provided block - token IDs on the specified device. 
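The "contiguous block IDs, GPU before CPU" note in `CpuGpuBlockAllocator.create()` above is easy to picture with a toy example; the pool sizes here are hypothetical.

```python
# Toy illustration (hypothetical sizes) of the contiguous ID split noted in
# CpuGpuBlockAllocator.create() above: GPU ids come first, CPU ids follow.
num_gpu_blocks, num_cpu_blocks = 4, 2

block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
gpu_block_ids = block_ids[:num_gpu_blocks]   # [0, 1, 2, 3]
cpu_block_ids = block_ids[num_gpu_blocks:]   # [4, 5]

# The two pools never overlap, which is what lets a single absolute id be
# routed back to the allocator that owns it.
assert set(gpu_block_ids).isdisjoint(cpu_block_ids)
```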
- - Args: - prev_block (Optional[Block]): The previous block in the sequence. - Used for prefix hashing. - block_token_ids (List[int]): The list of block token IDs to be - stored in the new blocks. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - List[Block]: The newly allocated list of immutable blocks - containing the provided block token IDs. - """ - return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids, extra_hash=extra_hash) - - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> Block: - """Allocates a new immutable block with the provided token IDs on the - specified device. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. - Used for prefix hashing. - token_ids (List[int]): The list of token IDs to be stored in the new - block. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - Block: The newly allocated immutable block containing the provided - token IDs. - """ - return self._allocators[device].allocate_immutable_block( - prev_block, token_ids, extra_hash=extra_hash) - - def free(self, block: Block) -> None: - """Frees the memory occupied by the given block. - - Args: - block (Block): The block to be freed. - """ - # Null block should never be freed - if isinstance(block, NullBlock): - return - block_id = block.block_id - assert block_id is not None - allocator = self._block_ids_to_allocator[block_id] - allocator.free(block) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. - - Returns: - List[Block]: A new list of blocks that shares the same memory as the - original sequence. - """ - # do not attempt to fork the null block - assert not isinstance(last_block, NullBlock) - block_id = last_block.block_id - assert block_id is not None - allocator = self._block_ids_to_allocator[block_id] - return allocator.fork(last_block) - - def get_num_free_blocks(self, device: Device) -> int: - """Returns the number of free blocks available on the specified device. - - Args: - device (Device): The device for which to query the number of free - blocks. AssertionError is raised if None is passed. - - Returns: - int: The number of free blocks available on the specified device. - """ - return self._allocators[device].get_num_free_blocks() - - def get_num_total_blocks(self, device: Device) -> int: - return self._allocators[device].get_num_total_blocks() - - def get_physical_block_id(self, device: Device, absolute_id: int) -> int: - """Returns the zero-offset block id on certain device given the - absolute block id. - - Args: - device (Device): The device for which to query relative block id. - absolute_id (int): The absolute block id for the block in - whole allocator. - - Returns: - int: The zero-offset block id on certain device. 
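A hedged sketch of the `get_physical_block_id()` contract just described: the per-device ("zero-offset") id is the rank of the absolute id within the owning device's sorted id list, mirroring how the per-device allocators in this file set implement it. The id pools below are hypothetical.

```python
# Hedged sketch of the absolute-id -> zero-offset-id mapping described above.
gpu_block_ids = [0, 1, 2, 3]   # hypothetical pools, matching the split above
cpu_block_ids = [4, 5]

def physical_block_id(device_ids: list[int], absolute_id: int) -> int:
    # The per-device id is the position of the absolute id in the sorted pool.
    return sorted(device_ids).index(absolute_id)

assert physical_block_id(gpu_block_ids, 2) == 2   # GPU ids already start at 0
assert physical_block_id(cpu_block_ids, 4) == 0   # absolute id 4 is CPU block 0
```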
- """ - return self._allocators[device].get_physical_block_id(absolute_id) - - def swap(self, blocks: List[Block], src_device: Device, - dst_device: Device) -> Dict[int, int]: - """Execute the swap for the given blocks from source_device - on to dest_device, save the current swap mapping and append - them to the accumulated `self._swap_mapping` for each - scheduling move. - - Args: - blocks: List of blocks to be swapped. - src_device (Device): Device to swap the 'blocks' from. - dst_device (Device): Device to swap the 'blocks' to. - - Returns: - Dict[int, int]: Swap mapping from source_device - on to dest_device. - """ - src_block_ids = [block.block_id for block in blocks] - self._allocators[src_device].swap_out(blocks) - self._allocators[dst_device].swap_in(blocks) - dst_block_ids = [block.block_id for block in blocks] - - current_swap_mapping: Dict[int, int] = {} - for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): - if src_block_id is not None and dst_block_id is not None: - self._swap_mapping[src_block_id] = dst_block_id - current_swap_mapping[src_block_id] = dst_block_id - return current_swap_mapping - - def get_num_full_blocks_touched(self, blocks: List[Block], - device: Device) -> int: - """Returns the number of full blocks that will be touched by - swapping in/out the given blocks on to the 'device'. - - Args: - blocks: List of blocks to be swapped. - device (Device): Device to swap the 'blocks' on. - - Returns: - int: the number of full blocks that will be touched by - swapping in/out the given blocks on to the 'device'. - Non full blocks are ignored when deciding the number - of blocks to touch. - """ - return self._allocators[device].get_num_full_blocks_touched(blocks) - - def clear_copy_on_writes(self) -> List[Tuple[int, int]]: - """Clears the copy-on-write (CoW) state and returns the mapping of - source to destination block IDs. - - Returns: - List[Tuple[int, int]]: A list mapping source block IDs to - destination block IDs. - """ - # CoW only supported on GPU - device = Device.GPU - return self._allocators[device].clear_copy_on_writes() - - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - """Mark blocks as accessed, only use for prefix caching.""" - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].mark_blocks_as_accessed(block_ids, now) - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - """Mark blocks as accessed, only use for prefix caching.""" - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].mark_blocks_as_computed(block_ids) - - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].get_common_computed_block_ids( - computed_seq_block_ids) - - @property - def all_block_ids(self) -> FrozenSet[int]: - return frozenset(self._block_ids_to_allocator.keys()) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - """Prefix cache hit rate. 
-1 means not supported or disabled.""" - assert device in self._allocators - return self._allocators[device].get_prefix_cache_hit_rate() - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache for specified or all devices.""" - if device: - return self._allocators[device].reset_prefix_cache() - success = True - for allocator in self._allocators.values(): - success = success and allocator.reset_prefix_cache() - return success - - def get_and_reset_swaps(self) -> List[Tuple[int, int]]: - """Returns and clears the mapping of source to destination block IDs. - Will be called after every swapping operations for now, and after every - schedule when BlockManagerV2 become default. Currently not useful. - - Returns: - List[Tuple[int, int]]: A mapping of source to destination block IDs. - """ - mapping = self._swap_mapping.copy() - self._swap_mapping.clear() - return list(mapping.items()) - - def find_cached_blocks_prefix( - self, - block_hashes: List[int], - device: Device = Device.GPU, - ) -> List[int]: - return self._allocators[device].find_cached_blocks_prefix(block_hashes) - - -class NullBlock(Block): - """ - Null blocks are used as a placeholders for KV cache blocks that have - been dropped due to sliding window. - This implementation just wraps an ordinary block and prevents it from - being modified. It also allows for testing if a block is NullBlock - via isinstance(). - """ - - def __init__(self, proxy: Block): - super().__init__() - self._proxy = proxy - - def append_token_ids(self, token_ids: List[BlockId]): - raise ValueError("null block should not be modified") - - @property - def block_id(self): - return self._proxy.block_id - - @block_id.setter - def block_id(self, value: Optional[BlockId]): - raise ValueError("null block should not be modified") - - @property - def token_ids(self) -> List[BlockId]: - return self._proxy.token_ids - - @property - def num_tokens_total(self) -> int: - raise NotImplementedError( - "num_tokens_total is not used for null block") - - @property - def num_empty_slots(self) -> BlockId: - return self._proxy.num_empty_slots - - @property - def is_full(self): - return self._proxy.is_full - - @property - def prev_block(self): - return self._proxy.prev_block - - @property - def extra_hash(self): - return None - - @property - def computed(self): - return self._proxy.computed - - @computed.setter - def computed(self, value): - self._proxy.computed = value - - @property - def last_accessed(self) -> float: - return self._proxy.last_accessed - - @last_accessed.setter - def last_accessed(self, last_accessed_ts: float): - self._proxy.last_accessed = last_accessed_ts - - @property - def content_hash(self): - return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py deleted file mode 100644 index 1a05881f7c005..0000000000000 --- a/vllm/core/block/interfaces.py +++ /dev/null @@ -1,319 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple - -from vllm.utils import Device - -BlockId = int - - -class Block(ABC): - - @abstractmethod - def append_token_ids(self, token_ids: List[int]) -> None: - pass - - @property - @abstractmethod - def block_id(self) -> Optional[int]: - pass - - @block_id.setter - @abstractmethod - def block_id(self, value: Optional[int]) -> None: - """NOTE: Do not use this API outside Block.""" - 
self._block_id = value - - @property - @abstractmethod - def token_ids(self) -> List[int]: - pass - - @property - @abstractmethod - def num_tokens_total(self) -> int: - """The number of tokens till the current block (inclusive) - """ - pass - - @property - @abstractmethod - def num_empty_slots(self) -> int: - pass - - @property - @abstractmethod - def is_full(self) -> bool: - pass - - @property - @abstractmethod - def prev_block(self) -> Optional["Block"]: - pass - - @property - @abstractmethod - def extra_hash(self) -> Optional[int]: - return None - - @property - @abstractmethod - def computed(self) -> bool: - raise NotImplementedError - - @computed.setter - @abstractmethod - def computed(self, value) -> bool: - """Should be only used by PrefixCacingAllocator""" - raise NotImplementedError - - @property - @abstractmethod - def last_accessed(self) -> float: - raise NotImplementedError - - @last_accessed.setter - @abstractmethod - def last_accessed(self, last_accessed_ts: float): - raise NotImplementedError - - class Factory(Protocol): - - @abstractmethod - def __call__( - self, - prev_block: Optional["Block"], - token_ids: List[int], - block_size: int, - allocator: "BlockAllocator", - block_id: Optional[int] = None, - computed: bool = False, - extra_hash: Optional[int] = None, - ) -> "Block": - pass - - @property - @abstractmethod - def content_hash(self) -> Optional[int]: - """Return the content-based hash of the current block, or None if it is - not yet defined or not supported. - - For the content-based hash to be defined, the current block must be - full. - """ - return None - - -class BlockAllocator(ABC): - - @abstractmethod - def allocate_mutable_block(self, prev_block: Optional[Block], - extra_hash: Optional[int]) -> Block: - pass - - @abstractmethod - def allocate_immutable_block(self, prev_block: Optional[Block], - token_ids: List[int], - extra_hash: Optional[int]) -> Block: - pass - - @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - extra_hash: Optional[int]) -> List[Block]: - pass - - @abstractmethod - def free(self, block: Block) -> None: - pass - - @abstractmethod - def fork(self, last_block: Block) -> List[Block]: - pass - - @abstractmethod - def get_num_total_blocks(self) -> int: - pass - - @abstractmethod - def get_num_free_blocks(self) -> int: - pass - - @abstractmethod - def get_physical_block_id(self, absolute_id: int) -> int: - pass - - @abstractmethod - def swap_out(self, blocks: List[Block]) -> None: - pass - - @abstractmethod - def swap_in(self, blocks: List[Block]) -> None: - pass - - @property - @abstractmethod - def all_block_ids(self) -> FrozenSet[int]: - pass - - @abstractmethod - def clear_copy_on_writes(self) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - pass - - @abstractmethod - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - pass - - @abstractmethod - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - pass - - @abstractmethod - def cow_block_if_not_appendable(self, block: Block) -> BlockId: - """NOTE: This should not be used besides Block""" - pass - - @abstractmethod - def promote_to_immutable_block(self, block: Block) -> BlockId: - """NOTE: This should not be used besides Block""" - pass - - @abstractmethod - def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: - pass - - @abstractmethod - def 
get_prefix_cache_hit_rate(self) -> float: - """Prefix cache hit rate. -1 means not supported or disabled.""" - pass - - @abstractmethod - def reset_prefix_cache(self) -> bool: - """Reset prefix cache.""" - pass - - class NoFreeBlocksError(ValueError): - pass - - @abstractmethod - def find_cached_blocks_prefix( - self, - block_hashes: List[int], - ) -> List[int]: - pass - - -class DeviceAwareBlockAllocator(ABC): - - @abstractmethod - def allocate_mutable_block(self, - prev_block: Optional[Block], - device: Device, - extra_hash: Optional[int] = None) -> Block: - pass - - @abstractmethod - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> Block: - pass - - @abstractmethod - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device, - extra_hash: Optional[int] = None, - ) -> List[Block]: - pass - - @abstractmethod - def get_num_free_blocks(self, device: Device) -> int: - pass - - @abstractmethod - def get_num_total_blocks(self, device: Device) -> int: - pass - - @abstractmethod - def free(self, block: Block) -> None: - pass - - @abstractmethod - def fork(self, last_block: Block) -> List[Block]: - pass - - @property - @abstractmethod - def all_block_ids(self) -> FrozenSet[int]: - pass - - @abstractmethod - def clear_copy_on_writes(self) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - pass - - @abstractmethod - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - pass - - @abstractmethod - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - pass - - @abstractmethod - def get_num_full_blocks_touched(self, blocks: List[Block], - device: Device) -> int: - pass - - @abstractmethod - def swap(self, blocks: List[Block], src_device: Device, - dst_device: Device) -> Dict[int, int]: - pass - - @abstractmethod - def get_physical_block_id(self, device: Device, absolute_id: int) -> int: - pass - - @abstractmethod - def allocate_or_get_null_block(self) -> Block: - """ - Null blocks are used as a placeholders for KV cache blocks that have - been dropped due to sliding window. - There is at most one null block per allocator. - """ - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self, device: Device) -> float: - """Prefix cache hit rate. 
-1 means not supported or disabled.""" - pass - - @abstractmethod - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache.""" - pass - - @abstractmethod - def find_cached_blocks_prefix( - self, - block_hashes: List[int], - device: Device = Device.GPU, - ) -> List[int]: - pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py deleted file mode 100644 index dae6ead04e9c9..0000000000000 --- a/vllm/core/block/naive_block.py +++ /dev/null @@ -1,466 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections import deque -from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union - -from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, - get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device - -Refcount = int - - -class NaiveBlockAllocator(BlockAllocator): - """A simple block allocator that manages blocks of memory without prefix - caching. - - Args: - create_block (Block.Factory): A factory function for creating new - blocks. This is used when a NaiveBlockAllocator is composed within - a prefix caching allocator -- the naive block allocator must - construct prefix caching blocks (but shouldn't know anything else - about them). - num_blocks (int): The total number of blocks to manage. - block_size (int): The size of each block in tokens. - block_ids (Optional[Iterable[int]], optional): An optional iterable of - block IDs. If not provided, block IDs will be assigned sequentially - from 0 to num_blocks - 1. - """ - - def __init__( - self, - create_block: Block.Factory, - num_blocks: int, - block_size: int, - block_ids: Optional[Iterable[int]] = None, - block_pool: Optional[BlockPool] = None, - ): - if block_ids is None: - block_ids = range(num_blocks) - - self._free_block_indices: Deque[BlockId] = deque(block_ids) - self._all_block_indices = frozenset(block_ids) - assert len(self._all_block_indices) == num_blocks - - self._refcounter = RefCounter( - all_block_indices=self._free_block_indices) - self._block_size = block_size - - self._cow_tracker = CopyOnWriteTracker( - refcounter=self._refcounter.as_readonly()) - - if block_pool is None: - extra_factor = 4 - # Pre-allocate "num_blocks * extra_factor" block objects. - # The "* extra_factor" is a buffer to allow more block objects - # than physical blocks - self._block_pool = BlockPool(self._block_size, create_block, self, - num_blocks * extra_factor) - else: - # In this case, the block pool is provided by the caller, - # which means that there is most likely a need to share - # a block pool between allocators - self._block_pool = block_pool - - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates a new immutable block with the given token IDs, linked to - the previous block. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. If - None, then the block to be allocated is the first block in the - sequence. - token_ids (List[int]): The token IDs to be stored in the new block. - - Returns: - Block: The newly allocated immutable block. 
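The core bookkeeping behind `NaiveBlockAllocator` above is a free-list of ids plus a reference counter. The following is a standalone sketch of that idea, not the vLLM classes; the pool size and error type are placeholders.

```python
# Standalone sketch of the free-list + refcount idea behind the naive
# allocator: ids are handed out from a deque and returned to it once their
# refcount drops back to zero.
from collections import deque

free_ids = deque(range(4))            # e.g. four physical blocks
refcount = {i: 0 for i in range(4)}

def allocate_block_id() -> int:
    if not free_ids:
        raise RuntimeError("no free blocks")   # stands in for NoFreeBlocksError
    block_id = free_ids.popleft()
    refcount[block_id] += 1
    return block_id

def free_block_id(block_id: int) -> None:
    refcount[block_id] -= 1
    if refcount[block_id] == 0:
        free_ids.appendleft(block_id)

b = allocate_block_id()
free_block_id(b)
assert list(free_ids) == [0, 1, 2, 3]   # the id went back to the front
```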
- """ - assert device is None - block = self.allocate_mutable_block(prev_block=prev_block) - block.append_token_ids(token_ids) - return block - - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> List[Block]: - assert device is None - num_blocks = len(block_token_ids) - - block_ids = [] - for i in range(num_blocks): - block_ids.append(self._allocate_block_id()) - - blocks = [] - for i in range(num_blocks): - prev_block = self._block_pool.init_block( - prev_block=prev_block, - token_ids=block_token_ids[i], - block_size=self._block_size, - physical_block_id=block_ids[i]) - blocks.append(prev_block) - - return blocks - - def allocate_mutable_block(self, - prev_block: Optional[Block], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates a new mutable block, linked to the previous block. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. If - None, then the block to be allocated is the first block in the - sequence. - - Returns: - Block: The newly allocated mutable block. - """ - assert device is None - block_id = self._allocate_block_id() - block = self._block_pool.init_block(prev_block=prev_block, - token_ids=[], - block_size=self._block_size, - physical_block_id=block_id) - return block - - def _allocate_block_id(self) -> BlockId: - if not self._free_block_indices: - raise BlockAllocator.NoFreeBlocksError() - - block_id = self._free_block_indices.popleft() - self._refcounter.incr(block_id) - return block_id - - def _free_block_id(self, block: Union[Block, BlockId]) -> None: - if isinstance(block, Block): - block_id = block.block_id - block.block_id = None - else: - block_id = block - assert block_id is not None - - refcount = self._refcounter.decr(block_id) - if refcount == 0: - self._free_block_indices.appendleft(block_id) - - def free(self, block: Block, keep_block_object: bool = False) -> None: - # Release the physical block id - self._free_block_id(block) - - # Release the block object - if not keep_block_object: - self._block_pool.free_block(block) - - def free_block_id(self, block_id: BlockId) -> None: - self._free_block_id(block_id) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. - - Returns: - List[Block]: The new sequence of blocks that shares the same memory - as the original sequence. - """ - source_blocks = get_all_blocks_recursively(last_block) - - forked_blocks: List[Block] = [] - prev_block = None - for block in source_blocks: - - # Increment refcount for each block. - assert block.block_id is not None - refcount = self._refcounter.incr(block.block_id) - assert refcount != 1, "can't fork free'd block" - - forked_block = self._block_pool.init_block( - prev_block=prev_block, - token_ids=block.token_ids, - block_size=self._block_size, - physical_block_id=block.block_id) - - forked_blocks.append(forked_block) - prev_block = forked_blocks[-1] - - return forked_blocks - - def get_num_free_blocks(self) -> int: - return len(self._free_block_indices) - - def get_num_total_blocks(self) -> int: - return len(self._all_block_indices) - - def get_physical_block_id(self, absolute_id: int) -> int: - """Returns the zero-offset block id on certain block allocator - given the absolute block id. 
- - Args: - absolute_id (int): The absolute block id for the block - in whole allocator. - - Returns: - int: The zero-offset block id on certain device. - """ - return sorted(self._all_block_indices).index(absolute_id) - - @property - def refcounter(self): - return self._refcounter - - @property - def all_block_ids(self) -> FrozenSet[int]: - return self._all_block_indices - - def cow_block_if_not_appendable(self, block: Block) -> BlockId: - """Performs a copy-on-write operation on the given block if it is not - appendable. - - Args: - block (Block): The block to check for copy-on-write. - - Returns: - BlockId: The block index of the new block if a copy-on-write - operation was performed, or the original block index if - no copy-on-write was necessary. - """ - src_block_id = block.block_id - assert src_block_id is not None - - if self._cow_tracker.is_appendable(block): - return src_block_id - - self._free_block_id(block) - trg_block_id = self._allocate_block_id() - - self._cow_tracker.record_cow(src_block_id, trg_block_id) - - return trg_block_id - - def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: - """Returns the copy-on-write source->destination mapping and clears it. - - Returns: - List[Tuple[BlockId, BlockId]]: A list mapping source - block indices to destination block indices. - """ - return self._cow_tracker.clear_cows() - - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - """Mark blocks as accessed, used in prefix caching. - - Since the naive allocator does not implement prefix caching, we do - nothing. - """ - pass - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - """Mark blocks as computed, used in prefix caching. - - Since the naive allocator does not implement prefix caching, we do - nothing. - """ - pass - - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - """Determine blocks that can be skipped in prefill. - - Since the naive allocator does not support prefix caching, always return - an empty list. - """ - return [] - - def promote_to_immutable_block(self, block: Block) -> BlockId: - raise NotImplementedError("There is no promotion for naive blocks") - - def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: - """Returns the number of full blocks that will be touched by - swapping in/out. - - Args: - blocks: List of blocks to be swapped. - Returns: - int: the number of full blocks that will be touched by - swapping in/out the given blocks. Non full blocks are ignored - when deciding the number of blocks to touch. - """ - # NOTE: for naive block, we use set to eliminate common blocks among - # seqs, also we compare the empty slots in the mutable blocks with - # lookahead slots to get the number of unique new block that are - # needed. - old_block_set = set() - for block in blocks: - if block.is_full: - old_block_set.add(block) - return len(old_block_set) - - def swap_out(self, blocks: List[Block]) -> None: - for block in blocks: - self._free_block_id(block) - - def swap_in(self, blocks: List[Block]) -> None: - for block in blocks: - # Here we allocate either immutable or mutable block and then - # extract its block_id. 
Note that the block object is released - # and the block_id is assigned to "block" to allow reusing the - # existing "block" object - if block.is_full: - tmp_block = self.allocate_immutable_block( - prev_block=block.prev_block, token_ids=block.token_ids) - else: - tmp_block = self.allocate_mutable_block( - prev_block=block.prev_block) - tmp_block.append_token_ids(block.token_ids) - - block_id = tmp_block.block_id - tmp_block.block_id = None - self._block_pool.free_block(tmp_block) - - block.block_id = block_id # Assign block_id - - def get_prefix_cache_hit_rate(self) -> float: - return -1 - - def reset_prefix_cache(self) -> bool: - """No prefix cache for naive block allocator.""" - return True - - def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: - # Not applicable for naive block allocator. - return [] - - -class NaiveBlock(Block): - """An implementation of the Block class that does not support prefix - caching. - - The NaiveBlock class represents a block of token IDs with a fixed size. It - provides methods for appending token IDs to the block and manages copy-on - -write operations when necessary. - - Args: - prev_block (Block): The previous block in the sequence. - token_ids (List[int]): The initial token IDs to be stored in the block. - block_size (int): The maximum number of token IDs that can be stored in - the block. - allocator (BlockAllocator): The block allocator associated with this - block. - block_id (Optional[int], optional): The physical block index - of this block. Defaults to None, which means no allocation has been - made. - _cow_target (Optional[Block], optional): The copy-on-write target block. - If not provided, it defaults to self. - """ - - def __init__(self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - allocator: BlockAllocator, - block_id: Optional[int] = None, - _cow_target: Optional[Block] = None, - extra_hash: Optional[int] = None): - self._token_ids: List[int] = [] - self._block_size = block_size - self._prev_block = prev_block - self._block_id = block_id - self._allocator = allocator - self._cow_target = _cow_target if _cow_target is not None else self - - self._append_token_ids_no_cow(token_ids) - - def append_token_ids(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block and performs a - copy-on-write if necessary. - - Args: - token_ids (Optional[List[int]]): The token IDs to be appended - to the block. - """ - self._append_token_ids_no_cow(token_ids) - - if self._block_id is not None: - self._block_id = (self._allocator.cow_block_if_not_appendable( - self._cow_target)) - - def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block - - Args: - token_ids (List[int]): The token IDs to be appended to the block. 
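The copy-on-write check that `append_token_ids()` relies on can be sketched in isolation. This is hypothetical state, not the `CopyOnWriteTracker` API: a block whose physical id is shared gets a fresh id, and the `(src, dst)` pair is recorded so the engine can copy the KV data later.

```python
# Minimal copy-on-write sketch under the contract described above.
refcount = {0: 2, 1: 1}                  # block 0 is shared by two sequences
free_ids = [7]
cow_records: list[tuple[int, int]] = []

def cow_if_not_appendable(block_id: int) -> int:
    if refcount[block_id] == 1:          # sole owner: append in place
        return block_id
    refcount[block_id] -= 1              # drop our reference to the shared id
    new_id = free_ids.pop()
    refcount[new_id] = 1
    cow_records.append((block_id, new_id))
    return new_id

assert cow_if_not_appendable(1) == 1     # unshared block keeps its id
assert cow_if_not_appendable(0) == 7     # shared block is re-mapped
assert cow_records == [(0, 7)]
```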
- """ - if len(token_ids) == 0: - return - - assert len(token_ids) <= self.num_empty_slots - - self._token_ids.extend(token_ids) - - @property - def computed(self) -> bool: - raise NotImplementedError - - @computed.setter - def computed(self, value) -> None: - raise NotImplementedError - - @property - def last_accessed(self) -> float: - raise NotImplementedError - - @last_accessed.setter - def last_accessed(self, last_accessed_ts: float): - raise NotImplementedError - - @property - def block_id(self) -> Optional[int]: - return self._block_id - - @block_id.setter - def block_id(self, value: Optional[int]) -> None: - self._block_id = value - - @property - def is_full(self) -> bool: - return self.num_empty_slots == 0 - - @property - def num_empty_slots(self) -> int: - return self._block_size - len(self.token_ids) - - @property - def token_ids(self) -> List[int]: - return self._token_ids - - @property - def num_tokens_total(self) -> int: - raise NotImplementedError( - "num_tokens_total is not used for naive block") - - @property - def block_size(self) -> int: - return self._block_size - - @property - def prev_block(self) -> Optional["Block"]: - return self._prev_block - - @property - def extra_hash(self): - return None - - @property - def content_hash(self) -> Optional[int]: - return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py deleted file mode 100644 index 2913a01bf34a5..0000000000000 --- a/vllm/core/block/prefix_caching_block.py +++ /dev/null @@ -1,1135 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Token blocks.""" -import sys -from bisect import bisect_left -from os.path import commonprefix -from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) - -from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, - get_all_blocks_recursively) -from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, - DeviceAwareBlockAllocator) -from vllm.core.block.naive_block import (BlockPool, NaiveBlock, - NaiveBlockAllocator) -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor -from vllm.logger import init_logger -from vllm.sequence import Sequence - -PrefixHash = int - -# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME -# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, -# then we know this block hasn't been accessed yet. -_DEFAULT_LAST_ACCESSED_TIME = -1 - -logger = init_logger(__name__) - - -class BlockTracker: - """Used to track the status of a block inside the prefix caching allocator - """ - __slots__ = ("active", "last_accessed", "computed") - - def reset(self): - self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME - self.computed: bool = False - - def __init__(self): - self.active: bool = False - self.reset() - - def enable(self): - assert not self.active - self.active = True - self.reset() - - def disable(self): - assert self.active - self.active = False - self.reset() - - -class PrefixCachingBlockAllocator(BlockAllocator): - """A block allocator that implements prefix caching. - - The PrefixCachingBlockAllocator maintains a cache of blocks based on their - content hash. It reuses blocks with the same content hash to avoid redundant - memory allocation. The allocator also supports copy-on-write operations. - - Args: - num_blocks (int): The total number of blocks to manage. - block_size (int): The size of each block in tokens. 
- block_ids(Optional[Iterable[int]], optional): An optional iterable of - block IDs. If not provided, block IDs will be assigned sequentially - from 0 to num_blocks - 1. - """ - - # Note that we use 'None' as a string here instead of None because - # as of Python 3.12, hash(None) returns a constant predictable value. - # This could possibly make it easier to find and exploit hash - # collisions. 'None' as a string will be hashed differently per process, - # but consistently within the same process. This is the same as the - # behavior of None prior to Python 3.12. - _none_hash: int = hash('None') - - # Implements Block.Factory. - def __init__( - self, - num_blocks: int, - block_size: int, - block_ids: Optional[Iterable[int]] = None, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU, - ): - if block_ids is None: - block_ids = range(num_blocks) - - self._block_size = block_size - - # A mapping of prefix hash to block index. All blocks which have a - # prefix hash will be in this dict, even if they have refcount 0. - self._cached_blocks: Dict[PrefixHash, BlockId] = {} - - # A list of immutable block IDs that have been touched by scheduler - # and should be marked as computed after an entire batch of sequences - # are scheduled. - self._touched_blocks: Set[BlockId] = set() - - # Used to track status of each physical block id - self._block_tracker: Dict[BlockId, BlockTracker] = {} - for block_id in block_ids: - self._block_tracker[block_id] = BlockTracker() - - # Pre-allocate "num_blocks * extra_factor" block objects. - # The "* extra_factor" is a buffer to allow more block objects - # than physical blocks - extra_factor = 4 - self._block_pool = BlockPool(self._block_size, self._create_block, - self, num_blocks * extra_factor) - - # An allocator for blocks that do not have prefix hashes. - self._hashless_allocator = NaiveBlockAllocator( - create_block=self._create_block, # type: ignore - num_blocks=num_blocks, - block_size=block_size, - block_ids=block_ids, - block_pool=self._block_pool, # Share block pool here - ) - - # Evitor used to maintain how we want to handle those computed blocks - # if we find memory pressure is high. - self.eviction_policy = eviction_policy - self.evictor: Evictor = make_evictor(self.eviction_policy) - - # We share the refcounter between allocators. This allows us to promote - # blocks originally allocated in the hashless allocator to immutable - # blocks. - self._refcounter = self._hashless_allocator.refcounter - - self._cow_tracker = CopyOnWriteTracker( - refcounter=self._refcounter.as_readonly()) - - self.metric_data = CacheMetricData() - - def _create_block( - self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - allocator: BlockAllocator, - block_id: Optional[int] = None, - computed: bool = False, - extra_hash: Optional[int] = None, - ) -> Block: - # Bind block to self. - allocator = self - - return PrefixCachingBlock( - prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - block_id=block_id, - allocator=allocator, - computed=computed, - extra_hash=extra_hash, - ) - - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates an immutable block with the given token IDs, reusing cached - blocks if possible. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. - token_ids (List[int]): The token IDs to be stored in the block. 
- - Returns: - Block: The allocated immutable block. - """ - assert device is None - assert_prefix_caching_block_or_none(prev_block) - - # First, try to create a block that points to cached data - block = self._block_pool.init_block(prev_block=prev_block, - token_ids=token_ids, - block_size=self._block_size, - physical_block_id=None, - extra_hash=extra_hash) - assert block.content_hash is not None - - cached_block_id = self._cached_blocks.get(block.content_hash, None) - if cached_block_id is not None: - self.metric_data.query(hit=True) - block.block_id = cached_block_id - self._incr_refcount_cached_block(block) - return block - self.metric_data.query(hit=False) - self._block_pool.free_block(block) - - # No cached block => Allocate a new block - block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash) - block.append_token_ids(token_ids) - return block - - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> List[Block]: - blocks = [] - for token_ids in block_token_ids: - prev_block = self.allocate_immutable_block(prev_block=prev_block, - token_ids=token_ids, - device=device, - extra_hash=extra_hash) - blocks.append(prev_block) - return blocks - - def allocate_mutable_block(self, - prev_block: Optional[Block], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates a mutable block. If there are no free blocks, this will - evict unused cached blocks. - - Args: - prev_block (Block): The previous block in the sequence. - None is not allowed unlike it is super class. - - Returns: - Block: The allocated mutable block. - """ - assert device is None - assert_prefix_caching_block_or_none(prev_block) - - block_id = self._allocate_block_id() - block = self._block_pool.init_block(prev_block=prev_block, - token_ids=[], - block_size=self._block_size, - physical_block_id=block_id, - extra_hash=extra_hash) - assert not block.computed - assert block.content_hash is None - return block - - def _incr_refcount_cached_block(self, block: Block) -> None: - # Set this block to be "computed" since it is pointing to a - # cached block id (which was already computed) - block.computed = True - - block_id = block.block_id - assert block_id is not None - - refcount = self._refcounter.incr(block_id) - if refcount == 1: - # In case a cached block was evicted, restore its tracking - if block_id in self.evictor: - self.evictor.remove(block_id) - - self._track_block_id(block_id, computed=True) - - def _decr_refcount_cached_block(self, block: Block) -> None: - # Ensure this is immutable/cached block - assert block.content_hash is not None - - block_id = block.block_id - assert block_id is not None - - refcount = self._refcounter.decr(block_id) - if refcount > 0: - block.block_id = None - return - else: - assert refcount == 0 - - # No longer used - assert block.content_hash in self._cached_blocks - - # Add the cached block to the evictor - # (This keeps the cached block around so it can be reused) - self.evictor.add(block_id, block.content_hash, block.num_tokens_total, - self._block_tracker[block_id].last_accessed) - - # Stop tracking the block - self._untrack_block_id(block_id) - - block.block_id = None - - def _decr_refcount_hashless_block(self, block: Block) -> None: - block_id = block.block_id - assert block_id is not None - - # We may have a fork case where block is shared, - # in which case, we cannot remove it from tracking - refcount = 
self._refcounter.get(block_id) - if refcount == 1: - self._untrack_block_id(block_id) - - # Decrement refcount of the block_id, but do not free the block object - # itself (will be handled by the caller) - self._hashless_allocator.free(block, keep_block_object=True) - - def _allocate_block_id(self) -> BlockId: - """First tries to allocate a block id from the hashless allocator, - and if there are no blocks, then tries to evict an unused cached block. - """ - hashless_block_id = self._maybe_allocate_hashless_block_id() - if hashless_block_id is not None: - return hashless_block_id - - evicted_block_id = self._maybe_allocate_evicted_block_id() - if evicted_block_id is not None: - return evicted_block_id - - # No block available in hashless allocator, nor in unused cache blocks. - raise BlockAllocator.NoFreeBlocksError() - - def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: - try: - # Allocate mutable block and extract its block_id - block = self._hashless_allocator.allocate_mutable_block( - prev_block=None) - block_id = block.block_id - self._block_pool.free_block(block) - - self._track_block_id(block_id, computed=False) - return block_id - except BlockAllocator.NoFreeBlocksError: - return None - - def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: - if self.evictor.num_blocks == 0: - return None - - # Here we get an evicted block, which is only added - # into evictor if its ref counter is 0 - # and since its content would be changed, we need - # to remove it from _cached_blocks's tracking list - block_id, content_hash_to_evict = self.evictor.evict() - - # Sanity checks - assert content_hash_to_evict in self._cached_blocks - _block_id = self._cached_blocks[content_hash_to_evict] - assert self._refcounter.get(_block_id) == 0 - assert _block_id == block_id - - self._cached_blocks.pop(content_hash_to_evict) - - self._refcounter.incr(block_id) - self._track_block_id(block_id, computed=False) - - return block_id - - def _free_block_id(self, block: Block) -> None: - """Decrements the refcount of the block. The block may be in two - possible states: (1) immutable/cached or (2) mutable/hashless. - In the first case, the refcount is decremented directly and the block - may be possibly added to the evictor. In other case, hashless - allocator free(..) with keep_block_object=True is called to only free - the block id (since the block object may be reused by the caller) - """ - block_id = block.block_id - assert block_id is not None, "Freeing unallocated block is undefined" - - if block.content_hash is not None: - # Immutable: This type of block is always cached, and we want to - # keep it in the evictor for future reuse - self._decr_refcount_cached_block(block) - else: - # Mutable: This type of block is not cached, so we release it - # directly to the hashless allocator - self._decr_refcount_hashless_block(block) - - assert block.block_id is None - - def free(self, block: Block, keep_block_object: bool = False) -> None: - """Release the block (look at free_block_id(..) docs) - """ - # Release the physical block index - self._free_block_id(block) - - # Release the block object to the pool - if not keep_block_object: - self._block_pool.free_block(block) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. 
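The `_allocate_block_id()` fallback order described above (hashless pool first, then evict an unreferenced cached block, otherwise fail) can be illustrated with a hedged sketch; the ids and hash values are hypothetical.

```python
# Hedged sketch of the allocation fallback order described above.
hashless_free = [3]
evictable = {9: 0xBEEF}                # block_id -> content hash, refcount 0
cached_blocks = {0xBEEF: 9}            # content hash -> block_id

def allocate_block_id() -> int:
    if hashless_free:
        return hashless_free.pop()
    if evictable:
        block_id, content_hash = evictable.popitem()
        del cached_blocks[content_hash]   # its contents are about to change
        return block_id
    raise RuntimeError("no free blocks")  # stands in for NoFreeBlocksError

assert allocate_block_id() == 3           # hashless pool wins first
assert allocate_block_id() == 9           # then the evicted cached block
assert not cached_blocks                  # the evicted hash is forgotten
```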
- - Returns: - List[Block]: The new sequence of blocks that shares the same memory - as the original sequence. - """ - source_blocks = get_all_blocks_recursively(last_block) - - forked_blocks: List[Block] = [] - prev_block = None - for block in source_blocks: - block_id = block.block_id - assert block_id is not None - - refcount = self._refcounter.incr(block_id) - assert refcount != 1, "can't fork free'd block_id = {}".format( - block_id) - - forked_block = self._block_pool.init_block( - prev_block=prev_block, - token_ids=block.token_ids, - block_size=self._block_size, - physical_block_id=block_id, - extra_hash=block.extra_hash) - - forked_blocks.append(forked_block) - prev_block = forked_blocks[-1] - - return forked_blocks - - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: - assert device is None - # The number of free blocks is the number of hashless free blocks - # plus the number of blocks evictor could free from its list. - return self._hashless_allocator.get_num_free_blocks( - ) + self.evictor.num_blocks - - def get_num_total_blocks(self) -> int: - return self._hashless_allocator.get_num_total_blocks() - - def get_physical_block_id(self, absolute_id: int) -> int: - """Returns the zero-offset block id on certain block allocator - given the absolute block id. - - Args: - absolute_id (int): The absolute block id for the block - in whole allocator. - - Returns: - int: The rzero-offset block id on certain device. - """ - return sorted(self.all_block_ids).index(absolute_id) - - @property - def all_block_ids(self) -> FrozenSet[int]: - return self._hashless_allocator.all_block_ids - - def get_prefix_cache_hit_rate(self) -> float: - return self.metric_data.get_hit_rate() - - def reset_prefix_cache(self) -> bool: - """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, - or used for resetting prefix caching status for benchmarking. - - Returns: - bool: True if the prefix cache is successfully reset, - False otherwise. - """ - num_used_blocks = (self.get_num_total_blocks() - - self.get_num_free_blocks()) - if num_used_blocks > 0: - logger.warning( - "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) - return False - - # Free all blocks in the evictor. - while (block_id := - self._maybe_allocate_evicted_block_id()) is not None: - self._hashless_allocator.free_block_id(block_id) - - # Should not have any cached blocks because all blocks are evicted. - assert not self._cached_blocks - - # Reset the evictor. - self.evictor = make_evictor(self.eviction_policy) - - # Reset the block tracker. - for block_id in self._block_tracker: - self._block_tracker[block_id] = BlockTracker() - - # Reset the metrics. - self.metric_data = CacheMetricData() - - logger.info("Successfully reset prefix cache") - return True - - def is_block_cached(self, block: Block) -> bool: - assert block.content_hash is not None - return block.content_hash in self._cached_blocks - - def promote_to_immutable_block(self, block: Block) -> BlockId: - """Once a mutable block is full, it can be promoted to an immutable - block. This means that its content can be referenced by future blocks - having the same prefix. - - Note that if we already have a cached block with the same content, we - will replace the newly-promoted block's mapping with the existing cached - block id. - - Args: - block: The mutable block to be promoted. 
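The promotion rule described above is essentially a deduplication step on the content hash. A standalone sketch, assuming integer content hashes and ignoring refcount handling:

```python
# Minimal sketch of promote_to_immutable_block's dedup rule: a newly full
# block either claims a cache entry or is re-pointed at the block that
# already holds the same content.
cached_blocks: dict[int, int] = {}        # content_hash -> block_id

def promote(block_id: int, content_hash: int) -> int:
    if content_hash not in cached_blocks:
        cached_blocks[content_hash] = block_id   # first writer becomes cached
        return block_id
    return cached_blocks[content_hash]           # reuse the cached copy

assert promote(5, content_hash=123) == 5
assert promote(6, content_hash=123) == 5         # deduplicated onto block 5
```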
- - Returns: - BlockId: Either the original block index, or the block index of - the previously cached block matching the same content. - """ - # Ensure block can be promoted - assert block.content_hash is not None - assert block.block_id is not None - assert self._refcounter.get(block.block_id) > 0 - - if block.content_hash not in self._cached_blocks: - # No cached content hash => Set this block as cached. - # Note that this block cannot be marked as computed yet - # because other sequences in the same batch cannot reuse - # this block. - self._cached_blocks[block.content_hash] = block.block_id - # Mark this block as touched so that it can be marked as - # computed after the entire batch of sequences are scheduled. - self._touched_blocks.add(block.block_id) - return block.block_id - - # Reuse the cached content hash - self._decr_refcount_hashless_block(block) - block.block_id = self._cached_blocks[block.content_hash] - - # Increment refcount of the cached block and (possibly) restore - # it from the evictor. - # Note that in this case, the block is marked as computed - self._incr_refcount_cached_block(block) - - return block.block_id - - def cow_block_if_not_appendable(self, block: Block) -> BlockId: - """Performs a copy-on-write operation on the given block if it is not - appendable. - - Args: - block (Block): The block to check for copy-on-write. - - Returns: - BlockId: The block index of the new block if a copy-on-write - operation was performed, or the original block index if - no copy-on-write was necessary. - """ - src_block_id = block.block_id - assert src_block_id is not None - - if self._cow_tracker.is_appendable(block): - return src_block_id - - self._free_block_id(block) - trg_block_id = self._allocate_block_id() - - self._cow_tracker.record_cow(src_block_id, trg_block_id) - - return trg_block_id - - def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: - """Returns the copy-on-write source->destination mapping and clears it. - - Returns: - List[Tuple[BlockId, BlockId]]: A list mapping source - block indices to destination block indices. - """ - return self._cow_tracker.clear_cows() - - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - """Mark blocks as accessed, used in prefix caching. - - If the block is added into evictor, we need to update corresponding - info in evictor's metadata. - """ - - for block_id in block_ids: - if self._block_tracker[block_id].active: - self._block_tracker[block_id].last_accessed = now - elif block_id in self.evictor: - self.evictor.update(block_id, now) - else: - raise ValueError( - "Mark block as accessed which is not belonged to GPU") - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - # Mark all touched blocks as computed. 
- for block_id in self._touched_blocks: - self._block_tracker[block_id].computed = True - self._touched_blocks.clear() - - def _track_block_id(self, block_id: Optional[BlockId], - computed: bool) -> None: - assert block_id is not None - self._block_tracker[block_id].enable() - self._block_tracker[block_id].computed = computed - - def _untrack_block_id(self, block_id: Optional[BlockId]) -> None: - assert block_id is not None - self._block_tracker[block_id].disable() - - def block_is_computed(self, block_id: int) -> bool: - if self._block_tracker[block_id].active: - return self._block_tracker[block_id].computed - else: - return block_id in self.evictor - - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - """Return the block ids that are common for a given sequence group. - - Only those blocks that are immutable and already be marked - compyted would be taken consideration. - """ - - # NOTE We exclude the last block to avoid the case where the entire - # prompt is cached. This would cause erroneous behavior in model - # runner. - - # It returns a list of int although type annotation says list of string. - if len(computed_seq_block_ids) == 1: - return computed_seq_block_ids[0] - - return commonprefix([ - ids for ids in computed_seq_block_ids # type: ignore - if ids - ]) - - def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: - """Returns the number of full blocks that will be touched by - swapping in/out. - - Args: - blocks: List of blocks to be swapped. - Returns: - int: the number of full blocks that will be touched by - swapping in/out the given blocks. Non full blocks are ignored - when deciding the number of blocks to touch. - """ - num_touched_blocks: int = 0 - for block in blocks: - # If the block has a match in the cache and the cached - # block is not referenced, then we still count it as a - # touched block - if block.is_full and (not self.is_block_cached(block) or \ - (block.content_hash is not None and \ - self._cached_blocks[block.content_hash] in \ - self.evictor)): - num_touched_blocks += 1 - return num_touched_blocks - - def swap_out(self, blocks: List[Block]) -> None: - """Execute the swap out actions. Basically just free the - given blocks. - - Args: - blocks: List of blocks to be swapped out. - """ - for block in blocks: - self._free_block_id(block) - - def swap_in(self, blocks: List[Block]) -> None: - """Execute the swap in actions. Change the block id from - old allocator to current allocator for each block to finish - the block table update. - - Args: - blocks: List of blocks to be swapped in. - """ - for block in blocks: - # Here we allocate either immutable or mutable block and then - # extract its block_id. Note that the block object is released - # and the block_id is assigned to "block" to allow reusing the - # existing "block" object - if block.is_full: - tmp_block = self.allocate_immutable_block( - prev_block=block.prev_block, - token_ids=block.token_ids, - extra_hash=block.extra_hash) - else: - tmp_block = self.allocate_mutable_block( - prev_block=block.prev_block, extra_hash=block.extra_hash) - tmp_block.append_token_ids(block.token_ids) - - block_id = tmp_block.block_id - self._block_pool.free_block(tmp_block) - - block.block_id = block_id # Assign block_id - - def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: - """ - Given a list of block hashes, return the prefix of the block hashes that - are all cached. 
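A standalone illustration of the cached-prefix lookup this method performs, using hypothetical hash values: because "cached" is a prefix property of the hash chain, mapping each hash to cached→False / uncached→True yields a sorted sequence that `bisect` can search (the `key` parameter needs Python 3.10+).

```python
# Standalone illustration of find_cached_blocks_prefix's binary search.
from bisect import bisect_left

cached = {101, 102, 103}                   # hashes of cached, computed blocks
block_hashes = [101, 102, 103, 104, 105]   # one sequence's hashes, in order

idx = bisect_left(block_hashes, True, key=lambda h: h not in cached)
assert block_hashes[:idx] == [101, 102, 103]
```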
- - Since a block's block hash includes the hashes of all previous blocks, - and we only allocate/deallocate blocks in the entire sequence, so if a - block is cached, then all previous blocks are also cached. With this - property, we can use binary search to find the prefix of cached blocks. - - Args: - block_hashes (List[int]): The list of block hashes. - - Returns: - List[int]: The prefix of the `block_hashes` that are cached. - """ - - def _block_is_cached(block_hash: PrefixHash) -> bool: - if block_hash not in self._cached_blocks: - return False - - cached_block_id = self._cached_blocks[block_hash] - # We only consider the blocks that are marked as computed. - return self.block_is_computed(cached_block_id) - - def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: - - # python <= 3.10 don't have the key argument - if sys.version_info < (3, 10): - a = [key(e) for e in a] - return bisect_left(a, x) - else: - return bisect_left(a, x, key=key) - - # Look for the first block that's not cached, and returns the prefix - # i.e. blocks that are cached. - idx = _bisect_left(block_hashes, - True, - key=lambda x: not _block_is_cached(x)) - return block_hashes[:idx] - - -class PrefixCachingBlock(Block): - """A block implementation that supports prefix caching. - - The PrefixCachingBlock class represents a block of token IDs with prefix - caching capabilities. It wraps a NaiveBlock internally and provides - additional functionality for content hashing and promoting immutable blocks - with the prefix caching allocator. - - Args: - prev_block (Optional[PrefixCachingBlock]): The previous block in the - sequence. - token_ids (List[int]): The initial token IDs to be stored in the block. - block_size (int): The maximum number of token IDs that can be stored in - the block. - allocator (BlockAllocator): The prefix - caching block allocator associated with this block. - block_id (Optional[int], optional): The physical block index - of this block. Defaults to None. - extra_hash (Optional[int]): The hash value of additional factors - such as adapters that influence the block, apart from the token_ids. - """ - - # Note that we use 'None' as a string here instead of None because - # as of Python 3.12, hash(None) returns a constant predictable value. - # This could possibly make it easier to find and exploit hash - # collisions. 'None' as a string will be hashed differently per process, - # but consistently within the same process. This is the same as the - # behavior of None prior to Python 3.12. - _none_hash: int = hash('None') - - def __init__( - self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - allocator: BlockAllocator, - block_id: Optional[int] = None, - computed: bool = False, - extra_hash: Optional[int] = None, - ): - assert isinstance(allocator, PrefixCachingBlockAllocator), ( - "Currently this class is only tested with " - "PrefixCachingBlockAllocator. 
Got instead allocator = {}".format( - allocator)) - assert_prefix_caching_block_or_none(prev_block) - - self._prev_block = prev_block - self._cached_content_hash: Optional[int] = None - self._cached_num_tokens_total: int = 0 - self._allocator = allocator - self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME - self._computed = computed - self._extra_hash = extra_hash - - # On the first time, we create the block object, and next we only - # reinitialize it - if hasattr(self, "_block"): - self._block.__init__( # type: ignore[has-type] - prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - block_id=block_id, - allocator=self._allocator) - else: - self._block = NaiveBlock(prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - block_id=block_id, - allocator=self._allocator) - - self._update_num_tokens_total() - - def _update_num_tokens_total(self): - """Incrementally computes the number of tokens that there is - till the current block (included) - """ - res = 0 - - # Add all previous blocks - if self._prev_block is not None: - res += self._prev_block.num_tokens_total - - # Add current block - res += len(self.token_ids) - - self._cached_num_tokens_total = res - - @property - def computed(self) -> bool: - return self._computed - - @computed.setter - def computed(self, value) -> None: - self._computed = value - - @property - def last_accessed(self) -> float: - return self._last_accessed - - @last_accessed.setter - def last_accessed(self, last_accessed_ts: float): - self._last_accessed = last_accessed_ts - - def append_token_ids(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block and registers the block as - immutable if the block becomes full. - - Args: - token_ids (List[int]): The token IDs to be appended to the block. - """ - # Ensure this is mutable block (not promoted) - assert self.content_hash is None - assert not self.computed - - if len(token_ids) == 0: - return - - # Ensure there are input tokens - assert token_ids, "Got token_ids = {}".format(token_ids) - - # Naive block handles CoW. - self._block.append_token_ids(token_ids) - self._update_num_tokens_total() - - # If the content hash is present, then the block can be made immutable. - # Register ourselves with the allocator, potentially replacing the - # physical block index. - if self.content_hash is not None: - self.block_id = self._allocator.promote_to_immutable_block(self) - - @property - def block_id(self) -> Optional[int]: - return self._block.block_id - - @block_id.setter - def block_id(self, value) -> None: - self._block.block_id = value - - @property - def is_full(self) -> bool: - return self._block.is_full - - @property - def num_empty_slots(self) -> int: - return self._block.num_empty_slots - - @property - def num_tokens_total(self) -> int: - return self._cached_num_tokens_total - - @property - def block_size(self) -> int: - return self._block.block_size - - @property - def token_ids(self) -> List[int]: - return self._block.token_ids - - @property - def prev_block(self) -> Optional[Block]: - return self._prev_block - - @property - def extra_hash(self) -> Optional[int]: - return self._extra_hash - - @property - def content_hash(self) -> Optional[int]: - """Return the content-based hash of the current block, or None if it is - not yet defined. - - For the content-based hash to be defined, the current block must be - full. - """ - # If the hash is already computed, return it. 
- if self._cached_content_hash is not None: - return self._cached_content_hash - - # We cannot compute a hash for the current block because it is not full. - if not self.is_full: - return None - - is_first_block = self._prev_block is None - prev_block_hash = ( - self._none_hash if is_first_block else - self._prev_block.content_hash # type: ignore - ) - - # Previous block exists but does not yet have a hash. - # Return no hash in this case. - if prev_block_hash == self._none_hash and not is_first_block: - return None - - self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( - is_first_block, - prev_block_hash, - cur_block_token_ids=self.token_ids, - extra_hash=self._extra_hash) - return self._cached_content_hash - - @classmethod - def hash_block_tokens(cls, - is_first_block: bool, - prev_block_hash: Optional[int], - cur_block_token_ids: List[int], - extra_hash: Optional[int] = None) -> int: - """Computes a hash value corresponding to the contents of a block and - the contents of the preceding block(s). The hash value is used for - prefix caching. - - Parameters: - - is_first_block (bool): A flag indicating if the block is the first in - the sequence. - - prev_block_hash (Optional[int]): The hash of the previous block. None - if this is the first block. - - cur_block_token_ids (List[int]): A list of token ids in the current - block. The current block is assumed to be full. - - extra_hash (Optional[int]): The hash value of additional factors - such as adapters that influence the block, apart from the token_ids. - - Returns: - - int: The computed hash value for the block. - """ - if is_first_block and prev_block_hash is None: - prev_block_hash = cls._none_hash - return hash((is_first_block, prev_block_hash, *cur_block_token_ids, - extra_hash)) - - -class ComputedBlocksTracker: - """ - Tracks the computed blocks for each sequence. - - Internally, it maintains a map from sequence id to the list of block hashes - for the sequence. We cache the hashes of the full blocks for each sequence, - and make sure the hash is calculated in the same way as the allocator. - When a sequence is being decoded, we also update the sequence's hash - accordingly and incrementally. - - From the sequence hash, with prefix caching enabled, we could also calculate - the number of cached tokens for the sequence by looking up the number of - cached block hashes in the allocator. - """ - - # Note that we use 'None' as a string here instead of None because - # as of Python 3.12, hash(None) returns a constant predictable value. - # This could possibly make it easier to find and exploit hash - # collisions. 'None' as a string will be hashed differently per process, - # but consistently within the same process. This is the same as the - # behavior of None prior to Python 3.12. - _none_hash: int = hash('None') - - def __init__( - self, - allocator: DeviceAwareBlockAllocator, - block_size: int, - enable_caching: bool, - ): - self._allocator = allocator - self._block_size = block_size - self._enable_caching = enable_caching - - # A map from seq_id to the list of block hashes for the - # sequence. This is so that we don't have to recompute the block hashes - # for the sequence when we need to check if the sequence is cached. - # Note a block that's not full will not have its hash calculated and - # recorded. - self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} - - # A map from seq_id to the number of tokens that are cached for the - # sequence. 
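To illustrate the chaining performed by `hash_block_tokens` above: each block hash folds in the previous block's hash, so two sequences that share a token prefix produce identical hashes for the shared full blocks. The helper below is a hypothetical stand-in, not the class method itself:

```python
NONE_HASH = hash('None')  # seed for the first block, as in the class above

def chain_hashes(token_ids, block_size, extra_hash=None):
    hashes = []
    prev = NONE_HASH
    # Only full blocks are hashed; trailing tokens of a partial block wait.
    for start in range(0, len(token_ids) - len(token_ids) % block_size,
                       block_size):
        block = token_ids[start:start + block_size]
        prev = hash((start == 0, prev, *block, extra_hash))
        hashes.append(prev)
    return hashes

a = chain_hashes([1, 2, 3, 4, 5, 6, 7, 8, 9], block_size=4)
b = chain_hashes([1, 2, 3, 4, 5, 6, 7, 8, 42], block_size=4)
# The shared 8-token prefix yields identical chained hashes for both.
assert a == b and len(a) == 2
```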
- # We need this so that a sequence in continuous prefill doesn't - # accidentally see its cached token count change. See comments in - # `get_num_cached_tokens` for more details. - self._seq_id_to_num_tokens_computed: Dict[int, int] = {} - - def _update_seq_hashes(self, seq: Sequence) -> None: - """Incrementally update the sequence's block hashes and record them.""" - assert self._enable_caching - - block_hashes_recorded = self._seq_id_to_blocks_hashes.get( - seq.seq_id, []) - cur_num_blocks_recorded = len(block_hashes_recorded) - token_ids = seq.get_token_ids() - assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( - f"The sequence has {len(token_ids)} tokens, but" - f" already recorded {cur_num_blocks_recorded} blocks. " - "This should not happen since we assume blocks are " - "only appended other than recomputation. When the sequence is " - "recomputed, we should have removed the info of the old blocks.") - # Update the computed block hashes for the sequence. Since only full - # blocks are considered as "computed", we take floor here. - num_computed_blocks = len(token_ids) // self._block_size - - # We need to know the hash of the previous block to compute the hash of - # the current block so that blocks could be uniquely identified across - # sequences of prefixes. - prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else - block_hashes_recorded[-1]) - # Only update the computed block hashes for the new blocks - for i in range(cur_num_blocks_recorded, num_computed_blocks): - assert len(token_ids) >= (i + 1) * self._block_size - block_token_ids = token_ids[i * self._block_size:(i + 1) * - self._block_size] - - # NOTE: If there are any factors affecting the block besides - # token_ids, they should be added as input to extra_hash. - extra_hash = seq.extra_hash() - - # This has to be kept in sync with the allocator's hash - # calculation. - block_hash = PrefixCachingBlock.hash_block_tokens( - is_first_block=prev_block_hash == self._none_hash, - prev_block_hash=prev_block_hash, - cur_block_token_ids=block_token_ids, - extra_hash=extra_hash, - ) - block_hashes_recorded.append(block_hash) - prev_block_hash = block_hash - - self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded - - def get_num_cached_tokens(self, seq: Sequence) -> int: - if not self._enable_caching: - return 0 - - # We always try to update the sequence hashes on the fly. - # This is to ensure that we don't miss any cached tokens for the - # sequence during decode. - # This routine should only update hash for any new blocks too. - self._update_seq_hashes(seq) - - num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( - seq.seq_id, None) - - # TODO(rickyx): This hack could be removed once we mark blocks as - # computed correctly with chunked prefills. - if num_computed_tokens_prev is not None and seq.is_prefill(): - # For a sequence that is still in prefill, we don't - # recompute the number of cached tokens. - # This also handles correctly chunked prefill since currently - # we mark blocks as computed even if the sequence is still partially - # prefilled. So a continuously prefilled sequence should not - # see its cached token count change while running. - return num_computed_tokens_prev - - block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] - - # This is O(logN), where N is the number of blocks. 
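A compressed sketch of the incremental bookkeeping in `_update_seq_hashes` above: keep the hashes already recorded for a sequence and hash only the newly completed full blocks on each call. The names, block size, and token ids below are illustrative only:

```python
from typing import Dict, List

NONE_HASH = hash('None')
BLOCK_SIZE = 4
recorded_hashes: Dict[int, List[int]] = {}   # seq_id -> block hashes

def update_hashes(seq_id: int, token_ids: List[int]) -> List[int]:
    hashes = recorded_hashes.setdefault(seq_id, [])
    num_full_blocks = len(token_ids) // BLOCK_SIZE   # floor: partial blocks wait
    prev = hashes[-1] if hashes else NONE_HASH
    # Hash only the blocks that became full since the last call.
    for i in range(len(hashes), num_full_blocks):
        block = token_ids[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        prev = hash((i == 0, prev, *block, None))
        hashes.append(prev)
    return hashes

update_hashes(7, [1, 2, 3, 4, 5])                   # one full block so far
h = update_hashes(7, [1, 2, 3, 4, 5, 6, 7, 8, 9])   # sequence grew -> 2 blocks
assert len(h) == 2
```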
- num_cached_blocks = len( - self._allocator.find_cached_blocks_prefix(block_hashes)) - num_cached_tokens = num_cached_blocks * self._block_size - self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens - return num_cached_tokens - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking the sequence.""" - if not self._enable_caching: - return - assert seq_id in self._seq_id_to_blocks_hashes - del self._seq_id_to_blocks_hashes[seq_id] - - assert seq_id in self._seq_id_to_num_tokens_computed - del self._seq_id_to_num_tokens_computed[seq_id] - - -class LastAccessBlocksTracker: - """Manages the last access time of the tracked sequences, in order to allow - an efficient update of allocator's block last access times - """ - - def __init__(self, allocator): - self._allocator = allocator - self._seq_last_access: Dict[int, Optional[float]] = {} - - def add_seq(self, seq_id: int) -> None: - """Start tracking seq_id - """ - assert seq_id not in self._seq_last_access - self._seq_last_access[seq_id] = None - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking seq_id - """ - assert seq_id in self._seq_last_access - del self._seq_last_access[seq_id] - - def update_last_access(self, seq_id: int, time: float) -> None: - assert seq_id in self._seq_last_access - self._seq_last_access[seq_id] = time - - def update_seq_blocks_last_access(self, seq_id: int, - block_ids: List[int]) -> None: - assert seq_id in self._seq_last_access - - ts = self._seq_last_access[seq_id] - - if ts is None: - # No last access was recorded, no need to update. - return - - self._allocator.mark_blocks_as_accessed(block_ids, ts) - - -def assert_prefix_caching_block_or_none(block: Optional[Block]): - if block is None: - return - assert isinstance(block, - PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py deleted file mode 100644 index e933c6ee7c8bd..0000000000000 --- a/vllm/core/block/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Block manager utils.""" -from vllm.sequence import SequenceGroup -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) - - -def check_no_caching_or_swa_for_blockmgr_encdec( - block_mgr, seq_group: SequenceGroup) -> None: - ''' - Enforce that prefix caching & sliding-window attention (SWA) - are currently unsupported *specifically* for encoder/decoder models. - - Raises NotImplementedError if unsupported scenario is detected. 
- - Arguments: - - * block_mgr: BlockSpaceManager instance - * seq_group: SequenceGroup passed to block_mgr - ''' - - if seq_group.is_encoder_decoder(): - if block_mgr.max_block_sliding_window is not None: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if block_mgr.enable_caching: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py deleted file mode 100644 index 4ec5a775f465c..0000000000000 --- a/vllm/core/block_manager.py +++ /dev/null @@ -1,525 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A block manager that manages token blocks.""" -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple - -from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - LastAccessBlocksTracker) -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -SeqId = int -EncoderSeqId = str - - -class SelfAttnBlockSpaceManager(BlockSpaceManager): - """BlockSpaceManager which manages the allocation of KV cache. - - It owns responsibility for allocation, swapping, allocating memory for - autoregressively-generated tokens, and other advanced features such as - prefix caching, forking/copy-on-write, and sliding-window memory allocation. - - This class implements the design described in - https://github.com/vllm-project/vllm/pull/3492. - - Lookahead slots - The block manager has the notion of a "lookahead slot". These are slots - in the KV cache that are allocated for a sequence. Unlike the other - allocated slots, the content of these slots is undefined -- the worker - may use the memory allocations in any way. - - In practice, a worker could use these lookahead slots to run multiple - forward passes for a single scheduler invocation. Each successive - forward pass would write KV activations to the corresponding lookahead - slot. This allows low inter-token latency use-cases, where the overhead - of continuous batching scheduling is amortized over >1 generated tokens. - - Speculative decoding uses lookahead slots to store KV activations of - proposal tokens. - - See https://github.com/vllm-project/vllm/pull/3250 for more information - on lookahead scheduling. - - Args: - block_size (int): The size of each memory block. - num_gpu_blocks (int): The number of memory blocks allocated on GPU. - num_cpu_blocks (int): The number of memory blocks allocated on CPU. - watermark (float, optional): The threshold used for memory swapping. - Defaults to 0.01. - sliding_window (Optional[int], optional): The size of the sliding - window. Defaults to None. - enable_caching (bool, optional): Flag indicating whether caching is - enabled. Defaults to False. 
- """ - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - self.sliding_window = sliding_window - # max_block_sliding_window is the max number of blocks that need to be - # allocated - self.max_block_sliding_window = None - if sliding_window is not None: - # +1 here because // rounds down - num_blocks = sliding_window // block_size + 1 - # +1 here because the last block may not be full, - # and so the sequence stretches one more block at the beginning - # For example, if sliding_window is 3 and block_size is 4, - # we may need 2 blocks when the second block only holds 1 token. - self.max_block_sliding_window = num_blocks + 1 - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - self.block_tables: Dict[SeqId, BlockTable] = {} - self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} - - self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator, self.block_size, self.enable_caching) - self._last_access_blocks_tracker = LastAccessBlocksTracker( - self.block_allocator) - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), - block_size=self.block_size, - num_lookahead_slots=num_lookahead_slots, - ) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), - block_size=self.block_size, - ) - - if self.max_block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.max_block_sliding_window) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU) - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks - < self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, seq: Sequence) -> BlockTable: - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) - if seq.get_token_ids(): - # NOTE: If there are any factors affecting the block besides - # token_ids, they should be added as input to extra_hash. - extra_hash = seq.extra_hash() - - # Add blocks to the block table only if the sequence is non empty. 
- block_table.allocate(token_ids=seq.get_token_ids(), - extra_hash=extra_hash) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - - # Allocate self-attention block tables for decoder sequences - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert not (set(seq.seq_id for seq in waiting_seqs) - & self.block_tables.keys()), "block table already exists" - - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = waiting_seqs[0] - block_table: BlockTable = self._allocate_sequence(seq) - self.block_tables[seq.seq_id] = block_table - - # Track seq - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Assign the block table for each sequence. - for seq in waiting_seqs[1:]: - self.block_tables[seq.seq_id] = block_table.fork() - - # Track seq - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Allocate cross-attention block table for encoder sequence - # - # NOTE: Here we assume that all sequences in the group have the same - # encoder prompt. - request_id = seq_group.request_id - - assert (request_id - not in self.cross_block_tables), \ - "block table already exists" - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - block_table = self._allocate_sequence(encoder_seq) - self.cross_block_tables[request_id] = block_table - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - """Determine if there is enough space in the GPU KV cache to continue - generation of the specified sequence group. - - We use a worst-case heuristic: assume each touched block will require a - new allocation (either via CoW or new block). We can append slots if the - number of touched blocks is less than the number of free blocks. - - "Lookahead slots" are slots that are allocated in addition to the slots - for known tokens. The contents of the lookahead slots are not defined. - This is used by speculative decoding when speculating future tokens. - """ - - num_touched_blocks = 0 - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - block_table = self.block_tables[seq.seq_id] - - num_touched_blocks += ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( - seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - )) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - Device.GPU) - return num_touched_blocks <= num_free_gpu_blocks - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - - block_table = self.block_tables[seq.seq_id] - - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - num_computed_slots=seq.data.get_num_computed_tokens(), - extra_hash=seq.extra_hash(), - ) - # Return any new copy-on-writes. - new_cows = self.block_allocator.clear_copy_on_writes() - return new_cows - - def free(self, seq: Sequence) -> None: - seq_id = seq.seq_id - - if seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. 
- return - - # Update seq block ids with the latest access time - self._last_access_blocks_tracker.update_seq_blocks_last_access( - seq_id, self.block_tables[seq.seq_id].physical_block_ids) - - # Untrack seq - self._last_access_blocks_tracker.remove_seq(seq_id) - self._computed_blocks_tracker.remove_seq(seq_id) - - # Free table/blocks - self.block_tables[seq_id].free() - del self.block_tables[seq_id] - - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - seq_id = seq.seq_id - self._computed_blocks_tracker.remove_seq(seq_id) - - def free_cross(self, seq_group: SequenceGroup) -> None: - request_id = seq_group.request_id - if request_id not in self.cross_block_tables: - # Already freed or hasn't been scheduled yet. - return - self.cross_block_tables[request_id].free() - del self.cross_block_tables[request_id] - - def get_block_table(self, seq: Sequence) -> List[int]: - block_ids = self.block_tables[seq.seq_id].physical_block_ids - return block_ids # type: ignore - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - request_id = seq_group.request_id - assert request_id in self.cross_block_tables - block_ids = self.cross_block_tables[request_id].physical_block_ids - assert all(b is not None for b in block_ids) - return block_ids # type: ignore - - def access_all_blocks_in_seq(self, seq: Sequence, now: float): - if self.enable_caching: - # Record the latest access time for the sequence. The actual update - # of the block ids is deferred to the sequence free(..) call, since - # only during freeing of block ids, the blocks are actually added to - # the evictor (which is when the most updated time is required) - # (This avoids expensive calls to mark_blocks_as_accessed(..)) - self._last_access_blocks_tracker.update_last_access( - seq.seq_id, now) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - # If prefix caching is enabled, mark immutable blocks as computed - # right after they have been scheduled (for prefill). This assumes - # the scheduler is synchronous so blocks are actually computed when - # scheduling the next batch. - self.block_allocator.mark_blocks_as_computed([]) - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Determine which blocks for which we skip prefill. - - With prefix caching we can skip prefill for previously-generated blocks. - Currently, the attention implementation only supports skipping cached - blocks if they are a contiguous prefix of cached blocks. - - This method determines which blocks can be safely skipped for all - sequences in the sequence group. - """ - computed_seq_block_ids = [] - for seq in seqs: - all_blocks = self.block_tables[seq.seq_id].physical_block_ids - num_cached_tokens = ( - self._computed_blocks_tracker.get_num_cached_tokens(seq)) - assert num_cached_tokens % self.block_size == 0 - num_cached_blocks = num_cached_tokens // self.block_size - computed_block_ids = all_blocks[:num_cached_blocks] - computed_seq_block_ids.append(computed_block_ids) - - # NOTE(sang): This assumes seq_block_ids doesn't contain any None. - return self.block_allocator.get_common_computed_block_ids( - computed_seq_block_ids) # type: ignore - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed. 
- return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.fork() - - # Track child seq - self._last_access_blocks_tracker.add_seq(child_seq.seq_id) - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - """Returns the AllocStatus for the given sequence_group - with num_lookahead_slots. - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for the given sequence group. - """ - return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, - num_lookahead_slots) - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from CPU to GPU) generated by - swapping in the given seq_group with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap in. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from CPU - to GPU. - """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.CPU, - dst_device=Device.GPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id): - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id) - for cpu_block_id, gpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - """Returns whether we can swap out the given sequence_group - with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap out. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - bool: Whether it's possible to swap out current sequence group. - """ - alloc_status = self._can_swap(seq_group, Device.CPU, - SequenceStatus.RUNNING) - return alloc_status == AllocStatus.OK - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from GPU to CPU) generated by - swapping out the given sequence_group with num_lookahead_slots. - - Args: - sequence_group (SequenceGroup): The sequence group to swap out. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from - GPU to CPU. 
- """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.GPU, - dst_device=Device.CPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id): - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id) - for gpu_block_id, cpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def get_num_free_gpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.GPU) - - def get_num_free_cpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.CPU) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_allocator.get_prefix_cache_hit_rate(device) - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - return self.block_allocator.reset_prefix_cache(device) - - def _can_swap(self, - seq_group: SequenceGroup, - device: Device, - status: SequenceStatus, - num_lookahead_slots: int = 0) -> AllocStatus: - """Returns the AllocStatus for swapping in/out the given sequence_group - on to the 'device'. - - Args: - sequence_group (SequenceGroup): The sequence group to swap in/out. - device (Device): device to swap the 'seq_group' on. - status (SequenceStatus): The status of sequence which is needed - for action. RUNNING for swap out and SWAPPED for swap in - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for swapping in/out the given - sequence_group on to the 'device'. - """ - # First determine the number of blocks that will be touched by this - # swap. Then verify if there are available blocks in the device - # to perform the swap. - num_blocks_touched = 0 - blocks: List[Block] = [] - for seq in seq_group.get_seqs(status=status): - block_table = self.block_tables[seq.seq_id] - if block_table.blocks is not None: - # Compute the number blocks to touch for the tokens to be - # appended. This does NOT include the full blocks that need - # to be touched for the swap. - num_blocks_touched += \ - block_table.get_num_blocks_touched_by_append_slots( - block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots) - blocks.extend(block_table.blocks) - # Compute the number of full blocks to touch and add it to the - # existing count of blocks to touch. - num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( - blocks, device=device) - - watermark_blocks = 0 - if device == Device.GPU: - watermark_blocks = self.watermark_blocks - - if self.block_allocator.get_num_total_blocks( - device) < num_blocks_touched: - return AllocStatus.NEVER - elif self.block_allocator.get_num_free_blocks( - device) - num_blocks_touched >= watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def get_num_cached_tokens(self, seq: Sequence) -> int: - """Get the number of tokens in blocks that are already computed and - cached in the block manager for the sequence. 
- """ - return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py deleted file mode 100644 index 7ec4768e90b1a..0000000000000 --- a/vllm/core/evictor.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -import heapq -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple - - -class EvictionPolicy(enum.Enum): - """Enum for eviction policy used by make_evictor to instantiate the correct - Evictor subclass. - """ - LRU = enum.auto() - - -class Evictor(ABC): - """The Evictor subclasses should be used by the BlockAllocator class to - handle eviction of freed Blocks. - """ - - @abstractmethod - def __init__(self): - pass - - @abstractmethod - def __contains__(self, block_id: int) -> bool: - pass - - @abstractmethod - def evict(self) -> Tuple[int, int]: - """Runs the eviction algorithm and returns the evicted block's - content hash along with physical block id along with physical block id - """ - pass - - @abstractmethod - def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - """Adds block to the evictor, making it a candidate for eviction""" - pass - - @abstractmethod - def update(self, block_id: int, last_accessed: float): - """Update corresponding block's access time in metadata""" - pass - - @abstractmethod - def remove(self, block_id: int): - """Remove a given block id from the cache.""" - pass - - @property - @abstractmethod - def num_blocks(self) -> int: - pass - - -class BlockMetaData: - """Data structure for storing key data describe cached block, so that - evitor could use to make its decision which one to choose for eviction - - Here we use physical block id as the dict key, as there maybe several - blocks with the same content hash, but their physical id is unique. - """ - - def __init__(self, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - self.content_hash = content_hash - self.num_hashed_tokens = num_hashed_tokens - self.last_accessed = last_accessed - - -class LRUEvictor(Evictor): - """Evicts in a least-recently-used order using the last_accessed timestamp - that's recorded in the Block. If there are multiple blocks with - the same last_accessed time, then the one with the largest num_hashed_tokens - will be evicted. If two blocks each have the lowest last_accessed time and - highest num_hashed_tokens value, then one will be chose arbitrarily - """ - - # CLEANUP_THRESHOLD determines the maximum allowable size of the priority - # queue relative to the free table size. When this threshold is exceeded, - # a cleanup operation is triggered to reduce memory usage. - CLEANUP_THRESHOLD = 50 - - def __init__(self): - self.free_table: Dict[int, BlockMetaData] = {} - self.priority_queue = [] - - def __contains__(self, block_id: int) -> bool: - return block_id in self.free_table - - def evict(self) -> Tuple[int, int]: - if len(self.free_table) == 0: - raise ValueError("No usable cache memory left") - - while self.priority_queue: - # We do not remove outdated entries from the priority queue at the - # time of updating the last_accessed timestamp. Instead, outdated - # entries are filtered out here during eviction. Outdated entries - # would either not in the free table, or have older last accessed - # time. 
- last_accessed, _, block_id, content_hash = heapq.heappop( - self.priority_queue) - if (block_id in self.free_table and - self.free_table[block_id].last_accessed == last_accessed): - self.free_table.pop(block_id) - return block_id, content_hash - - raise ValueError("No usable cache memory left") - - def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - self.free_table[block_id] = BlockMetaData(content_hash, - num_hashed_tokens, - last_accessed) - heapq.heappush( - self.priority_queue, - (last_accessed, -num_hashed_tokens, block_id, content_hash)) - self._cleanup_if_necessary() - - def update(self, block_id: int, last_accessed: float): - self.free_table[block_id].last_accessed = last_accessed - - def _cleanup_if_necessary(self): - if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len( - self.free_table): - self._cleanup() - - def _cleanup(self): - new_priority_queue: List[Tuple[float, int, int, int]] = [] - - for block_id, block in self.free_table.items(): - new_priority_queue.append( - (block.last_accessed, -block.num_hashed_tokens, block_id, - block.content_hash)) - heapq.heapify(new_priority_queue) - - self.priority_queue = new_priority_queue - - def remove(self, block_id: int): - if block_id not in self.free_table: - raise ValueError( - "Attempting to remove block that's not in the evictor") - self.free_table.pop(block_id) - - @property - def num_blocks(self) -> int: - return len(self.free_table) - - -def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: - if eviction_policy == EvictionPolicy.LRU: - return LRUEvictor() - else: - raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py deleted file mode 100644 index 69b9169ddd8a9..0000000000000 --- a/vllm/core/interfaces.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -from abc import ABC, abstractmethod -from typing import List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple - -from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device - - -class AllocStatus(enum.Enum): - """Result for BlockSpaceManager.can_allocate - - 1. Ok: seq_group can be allocated now. - 2. Later: seq_group cannot be allocated. - The capacity of allocator is larger than seq_group required. - 3. Never: seq_group can never be allocated. - The seq_group is too large to allocated in GPU. 
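A compact sketch of the lazy-invalidation pattern used by `LRUEvictor` above: timestamp updates never touch the heap, and stale heap entries are skipped at pop time. Unlike the original, this sketch pushes a fresh heap entry after an update and omits the periodic cleanup that caps heap growth:

```python
import heapq

free_table = {}        # block_id -> last_accessed (authoritative)
heap = []              # (last_accessed, -num_hashed_tokens, block_id)

def add(block_id, num_hashed_tokens, last_accessed):
    free_table[block_id] = last_accessed
    heapq.heappush(heap, (last_accessed, -num_hashed_tokens, block_id))

def update(block_id, last_accessed):
    free_table[block_id] = last_accessed          # heap entry becomes stale

def evict():
    while heap:
        last_accessed, _, block_id = heapq.heappop(heap)
        if free_table.get(block_id) == last_accessed:   # entry still current
            del free_table[block_id]
            return block_id
    raise ValueError("No usable cache memory left")

add(1, num_hashed_tokens=16, last_accessed=1.0)
add(2, num_hashed_tokens=32, last_accessed=1.0)
add(3, num_hashed_tokens=8, last_accessed=2.0)
assert evict() == 2   # tie on last_accessed: the larger prefix is evicted first

update(3, last_accessed=9.0)           # old heap entry for block 3 is now stale
heapq.heappush(heap, (9.0, -8, 3))     # simplified vs. the original class
assert evict() == 1                    # block 1 (t=1.0) is older than 3 (t=9.0)
assert evict() == 3
```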
- """ - OK = enum.auto() - LATER = enum.auto() - NEVER = enum.auto() - - -class BlockSpaceManager(ABC): - - @staticmethod - def get_block_space_manager_class(version: str): - version = version.lower() - - if version == "selfattn": - from vllm.core.block_manager import SelfAttnBlockSpaceManager - return SelfAttnBlockSpaceManager - - if version == "placeholder": - from vllm.core.placeholder_block_space_manager import ( - PlaceholderBlockSpaceManager) - return PlaceholderBlockSpaceManager - - raise ValueError(f"Unknown version {version=}") - - @abstractmethod - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - pass - - @abstractmethod - def allocate(self, seq_group: SequenceGroup) -> None: - pass - - @abstractmethod - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - pass - - @abstractmethod - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - @abstractmethod - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - pass - - @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - pass - - @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def free(self, seq: Sequence) -> None: - pass - - @abstractmethod - def get_block_table(self, seq: Sequence) -> List[int]: - pass - - @abstractmethod - def get_num_free_gpu_blocks(self) -> int: - pass - - @abstractmethod - def get_num_free_cpu_blocks(self) -> int: - pass - - @abstractmethod - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - pass - - @abstractmethod - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - pass - - @abstractmethod - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self, device: Device) -> float: - """Prefix cache hit rate. -1 means not supported or disabled.""" - pass - - @abstractmethod - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache for specified or all devices.""" - pass - - @abstractmethod - def get_num_cached_tokens(self, seq: Sequence) -> int: - pass - - @abstractmethod - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - pass \ No newline at end of file diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py deleted file mode 100644 index 679515924e85d..0000000000000 --- a/vllm/core/placeholder_block_space_manager.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Tuple - -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device - - -class PlaceholderBlockSpaceManager(BlockSpaceManager): - """A version of BlockSpaceManager for use in environments - where block management is not required. - For example: pooling models or attention-free models like Mamba. 
- - This class provides the same interface as BlockSpaceManager, but its - methods perform no actions or return simple values like True in specific - actions. It's designed to be used in scenarios where the overhead of - block management is unnecessary, such as in an embedding environment. - """ - - def __init__( - self, - **kwargs, - ) -> None: - pass - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # Always return OK for dummy purposes - return AllocStatus.OK - - def allocate(self, seq_group: SequenceGroup) -> None: - # No actual allocation logic needed - pass - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return True - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - return [] - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - return AllocStatus.OK - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - return True - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def free(self, seq: Sequence) -> None: - # No operation on free - return - - def get_block_table(self, seq: Sequence) -> List[int]: - return None # type: ignore - - def get_num_free_gpu_blocks(self) -> int: - return 1 - - def get_num_free_cpu_blocks(self) -> int: - return 1 - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - pass - - def get_common_computed_block_ids(self, - seq_group: List[Sequence]) -> List[int]: - return [] - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - pass - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return -1 - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - return True - - def get_num_cached_tokens(self, seq: Sequence) -> int: - return 0 - - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - return diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py deleted file mode 100644 index 63894e7f5dc8b..0000000000000 --- a/vllm/core/scheduler.py +++ /dev/null @@ -1,2028 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -import os -import random -import time -from collections import deque -from dataclasses import dataclass, field -from typing import Callable, Deque, Dict, Iterable, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union - -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupMetadataDelta, SequenceStage, - SequenceStatus) -from vllm.utils import Device, PyObjectCache - -logger = init_logger(__name__) - -# Test-only. If configured, decode is preempted with -# ARTIFICIAL_PREEMPTION_PROB% probability. 
-ENABLE_ARTIFICIAL_PREEMPT = bool( - os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa -ARTIFICIAL_PREEMPTION_PROB = 0.5 -ARTIFICIAL_PREEMPTION_MAX_CNT = 500 - - -class PreemptionMode(enum.Enum): - """Preemption modes. - - 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory - and swap them back in when the sequences are resumed. - 2. Recomputation: Discard the blocks of the preempted sequences and - recompute them when the sequences are resumed, treating the sequences as - new prompts. - """ - - SWAP = enum.auto() - RECOMPUTE = enum.auto() - - -@dataclass -class SchedulingBudget: - """The available slots for scheduling. - - TODO(sang): Right now, the budget is request_id-aware meaning it can ignore - budget update from the same request_id. It is because in normal scheduling - path, we update RUNNING num_seqs ahead of time, meaning it could be - updated more than once when scheduling RUNNING requests. Since this won't - happen if we only have chunked prefill scheduling, we can remove this - feature from the API when chunked prefill is enabled by default. - """ - - token_budget: int - max_num_seqs: int - _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) - _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) - # Number of cached tokens in the batch. - _num_cached_tokens: int = 0 - # Number of actual non-cached tokens in the batch. - _num_batched_tokens: int = 0 - _num_curr_seqs: int = 0 - - def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - # We allow num_new_tokens to be 0 when the entire sequence has - # been cached. - assert num_new_tokens >= 0 - assert num_new_seqs != 0 - return (self.num_batched_tokens + num_new_tokens <= self.token_budget - and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) - - def remaining_token_budget(self): - return self.token_budget - self.num_batched_tokens - - def add_num_batched_tokens(self, - req_id: str, - num_batched_tokens: int, - num_cached_tokens: int = 0): - if req_id in self._request_ids_num_batched_tokens: - return - assert num_cached_tokens >= 0 - assert num_batched_tokens >= 0 - - self._request_ids_num_batched_tokens.add(req_id) - self._num_batched_tokens += num_batched_tokens - self._num_cached_tokens += num_cached_tokens - - def subtract_num_batched_tokens(self, req_id: str, - num_batched_tokens: int): - if req_id in self._request_ids_num_batched_tokens: - self._request_ids_num_batched_tokens.remove(req_id) - self._num_batched_tokens -= num_batched_tokens - - def add_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - return - - self._request_ids_num_curr_seqs.add(req_id) - self._num_curr_seqs += num_curr_seqs - - def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - self._request_ids_num_curr_seqs.remove(req_id) - self._num_curr_seqs -= num_curr_seqs - - @property - def num_batched_tokens(self): - return self._num_batched_tokens - - @property - def num_curr_seqs(self): - return self._num_curr_seqs - - @property - def num_cached_tokens(self): - return self._num_cached_tokens - - -@dataclass -class ScheduledSequenceGroup: - # A sequence group that's scheduled. - seq_group: SequenceGroup - # The total chunk size (number of tokens) to process for next iteration. - # 1 for decoding. Same as prompt tokens for prefill, but if prefill is - # chunked, it can be smaller than that. 
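A minimal sketch of the arithmetic behind `SchedulingBudget.can_schedule` above; the budget values are illustrative:

```python
token_budget, max_num_seqs = 2048, 256
num_batched_tokens, num_curr_seqs = 1900, 10

def can_schedule(num_new_tokens: int, num_new_seqs: int) -> bool:
    # A request fits only if it stays within both the token budget and the
    # maximum number of sequences.
    assert num_new_tokens >= 0 and num_new_seqs != 0
    return (num_batched_tokens + num_new_tokens <= token_budget
            and num_curr_seqs + num_new_seqs <= max_num_seqs)

assert can_schedule(num_new_tokens=100, num_new_seqs=1)      # 2000 <= 2048
assert not can_schedule(num_new_tokens=200, num_new_seqs=1)  # 2100 >  2048
```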
- token_chunk_size: int - - -@dataclass -class SchedulerOutputs: - """The scheduling decision made from a scheduler.""" - - # Scheduled sequence groups. - scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] - # Number of prefill groups scheduled. - num_prefill_groups: int - # Total number of batched tokens. - num_batched_tokens: int - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] - # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] - # Sequence groups that are going to be ignored. - ignored_seq_groups: List[SequenceGroup] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # The number of requests in the running queue - running_queue_size: int - preempted: int - - def __post_init__(self): - # Swap in and swap out should never happen at the same time. - assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) - - self.num_loras: int = len(self.lora_requests) - if self.num_loras > 0: - self._sort_by_lora_ids() - - def is_empty(self) -> bool: - # NOTE: We do not consider the ignored sequence groups. - return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) - - def _sort_by_lora_ids(self): - assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups) - - def key_fn(group: ScheduledSequenceGroup): - key = (group.seq_group.lora_int_id, group.seq_group.request_id) - if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups): - # Sort sequence groups so that all prefills come before all - # decodes as required by chunked prefill. - return (not group.seq_group.is_prefill(), *key) - return key - - self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, - key=key_fn) - - @property - def lora_requests(self) -> Set[LoRARequest]: - return { - g.seq_group.lora_request - for g in self.scheduled_seq_groups - if g.seq_group.lora_request is not None - } - - -@dataclass -class SchedulerRunningOutputs: - """The requests that are scheduled from a running queue. - - Could contain prefill (prefill that's chunked) or decodes. If there's not - enough memory, it can be preempted (for recompute) or swapped out. - """ - - # Selected sequences that are running and in a decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are running and in a prefill phase. - # I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The preempted sequences. - preempted: List[SequenceGroup] - # Sequences that are swapped out. - swapped_out: List[SequenceGroup] - # The blocks to swap out. - blocks_to_swap_out: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - - # Optimization for fast-access to seq_group lists - decode_seq_groups_list: List[SequenceGroup] - prefill_seq_groups_list: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerRunningOutputs": - return SchedulerRunningOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - decode_seq_groups_list=[], - prefill_seq_groups_list=[], - ) - - -@dataclass -class SchedulerSwappedInOutputs: - """The requests that are scheduled from a swap queue. 
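To illustrate the ordering `_sort_by_lora_ids` above enforces when prefills and decodes are mixed: prefills sort first (`False < True`), then groups are keyed by LoRA id and request id. The tuples below are simple stand-ins for scheduled groups:

```python
# (is_prefill, lora_int_id, request_id) -- illustrative values only.
groups = [
    (False, 2, "req-3"),   # decode, LoRA 2
    (True, 1, "req-1"),    # chunked prefill, LoRA 1
    (False, 1, "req-2"),   # decode, LoRA 1
    (True, 2, "req-4"),    # chunked prefill, LoRA 2
]

ordered = sorted(groups, key=lambda g: (not g[0], g[1], g[2]))
assert [g[2] for g in ordered] == ["req-1", "req-4", "req-2", "req-3"]
```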
- - Could contain prefill (prefill that's chunked) or decodes. - """ - - # Selected sequences that are going to be swapped in and is in a - # decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are going to be swapped in and in a prefill - # phase. I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The blocks to swap in. - blocks_to_swap_in: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # Infeasible sequence groups. - infeasible_seq_groups: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerSwappedInOutputs": - return SchedulerSwappedInOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - blocks_to_swap_in=[], - blocks_to_copy=[], - num_lookahead_slots=0, - infeasible_seq_groups=[], - ) - - -@dataclass -class SchedulerPrefillOutputs: - """The requests that are scheduled from a waiting queue. - - Could contain a fresh prefill requests or preempted requests that need - to be recomputed from scratch. - """ - - # Selected sequences for prefill. - seq_groups: List[ScheduledSequenceGroup] - # Ignored sequence groups. - ignored_seq_groups: List[SequenceGroup] - num_lookahead_slots: int - - @classmethod - def create_empty(cls) -> "SchedulerPrefillOutputs": - return SchedulerPrefillOutputs( - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=0, - ) - - -def seq_group_metadata_builder(): - return SequenceGroupMetadata(request_id="", - is_prompt=False, - seq_data={}, - sampling_params=None, - block_tables={}) - - -def scheduler_running_outputs_builder(): - return SchedulerRunningOutputs(decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - prefill_seq_groups_list=[], - decode_seq_groups_list=[]) - - -def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), - token_chunk_size=0) - # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) - - -@dataclass -class PartialPrefillMetadata: - """Holds information about the partial prefills that are currently running - during a single iteration of the Scheduler. - When chunked prefill is enabled, we allow a certain number of seqs to be - partially prefilled during each iteration. Having multiple partial prefills - in flight allows us to minimize TTFT and avoid decode starvation in cases - where a single sequence group with a very large prompt blocks the queue for - too many iterations. - The number of long prefill requests is limited so that smaller - requests may jump the queue in front of them and get to the decode - phase faster. 
- """ - - # A minimum bound on the total number of prefills to be scheduled during - # this iteration - schedulable_prefills: int - - # The number of long prefill requests currently running - long_prefills: int - - scheduler_config: SchedulerConfig - - def can_schedule(self, seq_group: SequenceGroup) -> bool: - """When concurrent partial prefills are enabled, - we limit the number of long requests and only accept - shorter requests from the queue while running them - concurrently""" - return not (seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold - and self.long_prefills - >= self.scheduler_config.max_long_partial_prefills - and self.scheduler_config.max_num_partial_prefills > 1) - - def maybe_increment_partial_prefills(self, - seq_group: SequenceGroup) -> None: - # When a new prefill is scheduled, we need to know if it is a - # long request - if (seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold): - self.long_prefills += 1 - - @classmethod - def from_queues( - cls, - running: Deque[SequenceGroup], - waiting: Deque[SequenceGroup], - scheduler_config: SchedulerConfig, - ) -> "PartialPrefillMetadata": - """Create a PartialPrefillMetadata object from the current state of - the scheduler's queues. - This accounts for the currently running prefill requests, and peeks into - the waiting queue to see if there are more prefills to potentially be - scheduled during this iteration.""" - prefills = 0 - long_prefills = 0 - - waiting_long_prefills = 0 - - for sg in running: - if sg.first_seq.data.stage == SequenceStage.PREFILL: - prefills += 1 - if (sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold): - long_prefills += 1 - - for sg in waiting: - # Don't bother looping through the rest of the queue if we know - # there are already at - # least max_partial_prefills requests to fill - if prefills >= scheduler_config.max_num_partial_prefills: - break - - # Don't count long requests from the waiting queue if we aren't - # going to schedule them anyway - if (sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold): - if (long_prefills + waiting_long_prefills - >= scheduler_config.max_long_partial_prefills): - continue - waiting_long_prefills += 1 - prefills += 1 - - # NB: long_prefills and waiting_long_prefills are tracked separately. - # We don't account for the waiting requests here because we need to use - # this metadata to track how many have actually been scheduled. - return PartialPrefillMetadata( - schedulable_prefills=min( - prefills, scheduler_config.max_num_partial_prefills), - long_prefills=long_prefills, - scheduler_config=scheduler_config, - ) - - -class Scheduler: - - def __init__( - self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely - # simple and NOT fair. It can lead to starvation of some - # LoRAs. This should be improved in the future. 
- self.lora_config = lora_config - - version = "selfattn" - if (self.scheduler_config.runner_type == "pooling" - or self.cache_config.is_attention_free): - version = "placeholder" - - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version) - - num_gpu_blocks = cache_config.num_gpu_blocks - if num_gpu_blocks: - num_gpu_blocks //= pipeline_parallel_size - - num_cpu_blocks = cache_config.num_cpu_blocks - if num_cpu_blocks: - num_cpu_blocks //= pipeline_parallel_size - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching, - ) - - # Sequence groups in the WAITING state. - # Contain new prefill or preempted requests. - self.waiting: Deque[SequenceGroup] = deque() - # Sequence groups in the RUNNING state. - # Contain decode requests. - self.running: Deque[SequenceGroup] = deque() - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: Deque[SequenceGroup] = deque() - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. - # This is used to evict the finished requests from the Mamba cache. - self._finished_requests_ids: List[str] = list() - # Time at previous scheduling step - self.prev_time = 0.0 - # Did we schedule a prompt at previous step? - self.prev_prompt = False - # Latency of the last prompt step - self.last_prompt_latency = 0.0 - # preemption mode, RECOMPUTE or SWAP - self.user_specified_preemption_mode = scheduler_config.preemption_mode - - # The following field is test-only. It is used to inject artificial - # preemption. - self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT - self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT - if self.enable_artificial_preemption - else 0) - self.num_cumulative_preemption: int = 0 - - # Used to cache python objects - self._seq_group_metadata_cache: List[PyObjectCache] = [] - self._scheduler_running_outputs_cache: List[PyObjectCache] = [] - self._scheduled_seq_group_cache: List[PyObjectCache] = [] - - # For async output processing, we need to swap cache buffers between - # iterations. I.e. since the output processing is lagged one step, - # we cannot reuse the cached objects immediately when the schedule() - # is called again, but only when schedule() is called the second time. - self.output_proc_callback = output_proc_callback - self.use_async_output_proc = self.output_proc_callback is not None - self.num_cache_iters = 2 if self.use_async_output_proc else 1 - - self.cache_id = 0 - for i in range(self.num_cache_iters): - self._seq_group_metadata_cache.append( - PyObjectCache(seq_group_metadata_builder)) - self._scheduler_running_outputs_cache.append( - PyObjectCache(scheduler_running_outputs_builder)) - self._scheduled_seq_group_cache.append( - PyObjectCache(scheduled_seq_group_builder)) - - # For async postprocessor, the extra decode run cannot be done - # when the request reaches max_model_len. 
In this case, the request - # will be stopped during schedule() call and added to this stop list - # for processing and deallocation by the free_finished_seq_groups() - self._async_stopped: List[SequenceGroup] = [] - - # List with the chunk sizes to hand out to each sequence depending - # on how many partial prefills are running. This is slightly faster than - # running an integer division every time a prefill is scheduled. - # This splits the budget evenly among all prefills. - self.partial_prefill_budget_lookup_list = [0] * ( - self.scheduler_config.max_num_partial_prefills + 1) - self.partial_prefill_budget_lookup_list[0] = ( - scheduler_config.max_num_batched_tokens) - for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): - self.partial_prefill_budget_lookup_list[i] = ( - scheduler_config.max_num_batched_tokens // i) - - @property - def next_cache_id(self): - return (self.cache_id + 1) % self.num_cache_iters - - @property - def lora_enabled(self) -> bool: - return bool(self.lora_config) - - @property - def num_decoding_tokens_per_seq(self) -> int: - """The number of new tokens.""" - return 1 - - def add_seq_group(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the waiting queue. - self.waiting.append(seq_group) - - def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the running queue. - # Only for testing purposes. - self.running.append(seq_group) - - def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the swapped queue. - # Only for testing purposes. - self.swapped.append(seq_group) - - def abort_seq_group( - self, - request_id: Union[str, Iterable[str]], - seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, - ) -> None: - """Aborts a sequence group with the given ID. - - Check if the sequence group with the given ID - is present in any of the state queue. - If present, remove the sequence group from the state queue. - Also, if any of the sequences in the sequence group is not finished, - free the sequence with status `FINISHED_ABORTED`. - Otherwise, do nothing. - - Args: - request_id: The ID(s) of the sequence group to abort. - seq_id_to_seq_group: helper for groups with n>1 - """ - if isinstance(request_id, str): - request_id = (request_id, ) - request_ids = set(request_id) - seq_id_to_seq_group = seq_id_to_seq_group or {} - for state_queue in [self.waiting, self.running, self.swapped]: - aborted_groups: List[SequenceGroup] = [] - for seq_group in state_queue: - # When n>1, seq_group.request_id looks like - # foo_parallel_sample_0, while request_ids is just foo, and we - # should resolve it as real_request_id to match. - if seq_group.request_id in seq_id_to_seq_group: - real_request_id = seq_id_to_seq_group[ - seq_group.request_id].group_id - else: - real_request_id = seq_group.request_id - if real_request_id in request_ids: - # Appending aborted group into pending list. - aborted_groups.append(seq_group) - # We can't remove real_request_id in request_ids here, - # because there may be other seq groups sharing the same - # real_request_id - for aborted_group in aborted_groups: - # Remove the sequence group from the state queue. - state_queue.remove(aborted_group) - # Remove the aborted request from the Mamba cache. 
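# --- Editor's note: illustrative sketch, not part of this diff. ---
# How the precomputed partial_prefill_budget_lookup_list built in __init__
# above splits the token budget evenly across concurrent partial prefills,
# avoiding an integer division on every scheduling step. Numbers below are
# hypothetical.

def build_budget_lookup(max_num_batched_tokens: int,
                        max_num_partial_prefills: int) -> list[int]:
    lookup = [0] * (max_num_partial_prefills + 1)
    lookup[0] = max_num_batched_tokens           # no partial prefills yet
    for i in range(1, max_num_partial_prefills + 1):
        lookup[i] = max_num_batched_tokens // i  # even split of the budget
    return lookup


if __name__ == "__main__":
    # With a 2048-token budget and up to 4 partial prefills, four concurrent
    # prefills each receive a 512-token chunk per step.
    assert build_budget_lookup(2048, 4) == [2048, 2048, 1024, 682, 512]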
- self._finished_requests_ids.append(aborted_group.request_id) - for seq in aborted_group.get_seqs(): - if seq.is_finished(): - continue - seq.status = SequenceStatus.FINISHED_ABORTED - self.free_seq(seq) - if aborted_group.request_id in seq_id_to_seq_group: - del seq_id_to_seq_group[aborted_group.request_id] - - self._free_seq_group_cross_attn_blocks(aborted_group) - - def _free_seq_group_cross_attn_blocks( - self, - seq_group: SequenceGroup, - ) -> None: - """ - Free a sequence group from a cross-attention block table. - Has no effect on decoder-only models. - """ - if seq_group.is_encoder_decoder(): - self.block_manager.free_cross(seq_group) - - def has_unfinished_seqs(self) -> bool: - return (len(self.waiting) != 0 or len(self.running) != 0 - or len(self.swapped) != 0) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_manager.get_prefix_cache_hit_rate(device) - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - return self.block_manager.reset_prefix_cache(device) - - def get_num_unfinished_seq_groups(self) -> int: - return len(self.waiting) + len(self.running) + len(self.swapped) - - def get_and_reset_finished_requests_ids(self) -> List[str]: - """Flushes the list of request ids of previously finished seq_groups.""" - finished_requests_ids = self._finished_requests_ids - self._finished_requests_ids = list() - return finished_requests_ids - - def _schedule_running( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> SchedulerRunningOutputs: - """Schedule sequence groups that are running. - - Running queue should include decode and chunked prefill requests. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any decodes are preempted. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any decodes are preempted. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - partial_prefill_metadata: information about the partial prefills - that are currently running - - Returns: - SchedulerRunningOutputs. - """ - ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ - self.cache_id].get_object() - ret.blocks_to_swap_out.clear() - ret.blocks_to_copy.clear() - ret.decode_seq_groups.clear() - ret.prefill_seq_groups.clear() - ret.preempted.clear() - ret.swapped_out.clear() - - ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking) - - ret.decode_seq_groups_list.clear() - ret.prefill_seq_groups_list.clear() - - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out - blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy - - decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups - prefill_seq_groups: List[ - ScheduledSequenceGroup] = ret.prefill_seq_groups - preempted: List[SequenceGroup] = ret.preempted - swapped_out: List[SequenceGroup] = ret.swapped_out - - running_queue = self.running - assert len(self._async_stopped) == 0 - while running_queue: - seq_group = running_queue[0] - # We discard the cached tokens info here because we don't need it - # for running sequence: - # 1. 
If a sequence is running with chunked prefill, the cached - # tokens info was already used for the first prefill. - # 2. If a sequence is running with non-chunked prefill, then - # there it's a decoding sequence, and the cached tokens info is - # irrelevant. - num_uncached_new_tokens, _ = \ - self._get_num_new_uncached_and_cached_tokens( - seq_group, - SequenceStatus.RUNNING, - enable_chunking, - budget, - partial_prefill_metadata, - ) - - num_running_tokens = num_uncached_new_tokens - if num_running_tokens == 0: - # No budget => Stop - break - - running_queue.popleft() - - # With async postprocessor, an extra decode run is done - # to process the final tokens. The check below avoids this extra - # decode run when the model max len is reached, in order to avoid - # a memory overflow. - if (self.use_async_output_proc and seq_group.seqs[0].get_len() - > self.scheduler_config.max_model_len): - self._async_stopped.append(seq_group) - continue - - # NOTE(woosuk): Preemption happens only when there is no available - # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group, enable_chunking): - budget.subtract_num_batched_tokens(seq_group.request_id, - num_running_tokens) - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(seq_group.request_id, - num_running_seqs) - - if (curr_loras is not None and seq_group.lora_int_id > 0 - and seq_group.lora_int_id in curr_loras): - curr_loras.remove(seq_group.lora_int_id) - - # Determine victim sequence - cont_loop = True - if running_queue: - # Preempt the lowest-priority sequence group. - victim_seq_group = running_queue.pop() - else: - # No other sequence group can be preempted. - # Preempt the current sequence group. - # Note: This is also where we stop this loop - # (since there is nothing else to preempt) - victim_seq_group = seq_group - cont_loop = False - - # With async postprocessor, before preempting a sequence - # we need to ensure it has no pending async postprocessor - do_preempt = True - if self.use_async_output_proc: - assert self.output_proc_callback is not None - self.output_proc_callback( - request_id=victim_seq_group.request_id) - - # It may be that the async pending "victim_seq_group" - # becomes finished, in which case we simply free it. - if victim_seq_group.is_finished(): - self._free_finished_seq_group(victim_seq_group) - do_preempt = False - - # Do preemption - if do_preempt: - preempted_mode = self._preempt(victim_seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(victim_seq_group) - else: - swapped_out.append(victim_seq_group) - - if not cont_loop: - break - else: - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - is_prefill = seq_group.is_prefill() - - scheduled_seq_group: ScheduledSequenceGroup = ( - self._scheduled_seq_group_cache[ - self.cache_id].get_object()) - scheduled_seq_group.seq_group = seq_group - if is_prefill: - scheduled_seq_group.token_chunk_size = num_running_tokens - prefill_seq_groups.append(scheduled_seq_group) - ret.prefill_seq_groups_list.append(seq_group) - else: - scheduled_seq_group.token_chunk_size = 1 - decode_seq_groups.append(scheduled_seq_group) - ret.decode_seq_groups_list.append(seq_group) - - budget.add_num_batched_tokens(seq_group.request_id, - num_running_tokens) - # OPTIMIZATION: Note that get_max_num_running_seqs is - # expensive. 
For the default scheduling chase where - # enable_chunking is False, num_seqs are updated before running - # this method, so we don't have to update it again here. - if enable_chunking: - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.add_num_seqs(seq_group.request_id, num_running_seqs) - if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.add(seq_group.lora_int_id) - - self._scheduler_running_outputs_cache[self.next_cache_id].reset() - self._scheduled_seq_group_cache[self.next_cache_id].reset() - - return ret - - def _schedule_swapped( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - ) -> SchedulerSwappedInOutputs: - """Schedule sequence groups that are swapped out. - - It schedules swapped requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are swapped in. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are swapped in. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - - Returns: - SchedulerSwappedInOutputs. - """ - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: List[Tuple[int, int]] = [] - blocks_to_copy: List[Tuple[int, int]] = [] - decode_seq_groups: List[ScheduledSequenceGroup] = [] - prefill_seq_groups: List[ScheduledSequenceGroup] = [] - infeasible_seq_groups: List[SequenceGroup] = [] - - swapped_queue = self.swapped - - leftover_swapped: Deque[SequenceGroup] = deque() - while swapped_queue: - seq_group = swapped_queue[0] - - # If the sequence group cannot be swapped in, stop. - is_prefill = seq_group.is_prefill() - alloc_status = self.block_manager.can_swap_in( - seq_group, - self._get_num_lookahead_slots(is_prefill, enable_chunking)) - if alloc_status == AllocStatus.LATER: - break - elif alloc_status == AllocStatus.NEVER: - logger.warning( - "Failing the request %s because there's not enough kv " - "cache blocks to run the entire sequence.", - seq_group.request_id, - ) - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - infeasible_seq_groups.append(seq_group) - swapped_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (lora_int_id > 0 and (lora_int_id not in curr_loras) - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_swapped.appendleft(seq_group) - swapped_queue.popleft() - continue - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. 
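# --- Editor's note: illustrative sketch, not part of this diff. ---
# A simplified stand-in for the SchedulingBudget object that the code above
# mutates in place: it tracks how many batched tokens and sequences have been
# claimed in this step, and the claims are rolled back when a running group
# is preempted. This is a hedged approximation, not the real class.

from dataclasses import dataclass


@dataclass
class BudgetSketch:
    token_budget: int
    max_num_seqs: int
    _num_batched_tokens: int = 0
    _num_curr_seqs: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int) -> bool:
        return (self._num_batched_tokens + num_new_tokens <= self.token_budget
                and self._num_curr_seqs + num_new_seqs <= self.max_num_seqs)

    def remaining_token_budget(self) -> int:
        return self.token_budget - self._num_batched_tokens

    def add(self, num_tokens: int, num_seqs: int) -> None:
        self._num_batched_tokens += num_tokens
        self._num_curr_seqs += num_seqs

    def subtract(self, num_tokens: int, num_seqs: int) -> None:
        # Mirrors the subtract_* calls made when a victim group is preempted.
        self._num_batched_tokens -= num_tokens
        self._num_curr_seqs -= num_seqs


if __name__ == "__main__":
    budget = BudgetSketch(token_budget=2048, max_num_seqs=8)
    budget.add(num_tokens=1024, num_seqs=4)
    assert budget.remaining_token_budget() == 1024
    assert budget.can_schedule(num_new_tokens=1024, num_new_seqs=4)
    assert not budget.can_schedule(num_new_tokens=1025, num_new_seqs=1)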
- num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.SWAPPED, enable_chunking, - budget)) - - if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ): - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.SWAPPED) - break - - if lora_int_id > 0 and curr_loras is not None: - curr_loras.add(lora_int_id) - swapped_queue.popleft() - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - if is_prefill: - prefill_seq_groups.append( - ScheduledSequenceGroup( - seq_group, - token_chunk_size=num_new_tokens_uncached + - num_new_tokens_cached, - )) - else: - decode_seq_groups.append( - ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens( - seq_group.request_id, - num_batched_tokens=num_new_tokens_uncached, - num_cached_tokens=num_new_tokens_cached, - ) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - swapped_queue.extendleft(leftover_swapped) - - return SchedulerSwappedInOutputs( - decode_seq_groups=decode_seq_groups, - prefill_seq_groups=prefill_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking), - infeasible_seq_groups=infeasible_seq_groups, - ) - - def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled: - prompt_limit = self.scheduler_config.max_model_len - else: - prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens, - ) - - # Model is fine tuned with long context. Return the fine tuned max_len. - if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: - assert prompt_limit <= seq_group.lora_request.long_lora_max_len - return seq_group.lora_request.long_lora_max_len - else: - return prompt_limit - - def _get_priority(self, - seq_group: SequenceGroup) -> Tuple[Optional[int], float]: - """Get the priority of the sequence group. - Highest preference to user-defined priority, followed by arrival time. - Args: - seq_group: The sequence group input. - Returns: - The priority of the sequence group. - """ - return seq_group.priority, seq_group.arrival_time - - def _schedule_priority_preemption( - self, - budget: SchedulingBudget, - ) -> int: - """Sorts waiting and running queue. Also, force preempt requests - from the running queue if their priority is lower. - Priority-based preemption is used with the priority policy. - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are scheduled. - Returns: - A count of priority-based preemptions. 
- """ - - waiting_queue = self.waiting - - running_queue = deque(sorted(self.running, key=self._get_priority)) - - blocks_to_swap_out: List[Tuple[int, int]] = [] - force_preemption_count = 0 - - if waiting_queue: - seq_group = waiting_queue.popleft() - num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, _ = \ - self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, False, budget) - - # Only preempt if priority inversion exists - while running_queue and self._get_priority( - running_queue[-1]) > self._get_priority(seq_group): - # Only preempt if waiting sequence cannot be allocated - can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens_uncached > 0 - and can_allocate == AllocStatus.OK - and budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - )): - break - - # Adjust budget to remove the victim sequence group - vseq_group = running_queue.pop() - num_running_tokens_uncached, _ = ( - self._get_num_new_uncached_and_cached_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget)) - budget.subtract_num_batched_tokens( - vseq_group.request_id, num_running_tokens_uncached) - num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(vseq_group.request_id, - num_running_seqs) - - # Preempt out the victim sequence group - self._preempt(vseq_group, blocks_to_swap_out) - waiting_queue.appendleft(vseq_group) - force_preemption_count += 1 - # Put the sequence back into the waiting queue - waiting_queue.appendleft(seq_group) - - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - - waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) - - self.waiting = waiting_queue - self.running = running_queue - return force_preemption_count - - def _schedule_prefills( - self, - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> SchedulerPrefillOutputs: - """Schedule sequence groups that are in prefill stage. - - Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE - as a new prefill (that starts from beginning -> most recently generated - tokens). - - It schedules waiting requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are scheduled. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are scheduled. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - partial_prefill_metadata: information about the partial prefills - that are currently running - - Returns: - SchedulerPrefillOutputs. 
- """ - if budget.remaining_token_budget() == 0: - # Do nothing: Can't add any more prefill anyway - return SchedulerPrefillOutputs( - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), - ) - ignored_seq_groups: List[SequenceGroup] = [] - seq_groups: List[ScheduledSequenceGroup] = [] - using_prompt_embeds: bool = False - - waiting_queue = self.waiting - - leftover_waiting_sequences: Deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: - seq_group = waiting_queue[0] - - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - if (partial_prefill_metadata is not None - and not partial_prefill_metadata.can_schedule(seq_group)): - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_uncached_and_cached_tokens( - seq_group, - SequenceStatus.WAITING, - enable_chunking, - budget, - partial_prefill_metadata=partial_prefill_metadata, - )) - num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - - if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() - assert num_new_tokens == num_prompt_tokens - - prompt_limit = self._get_prompt_limit(seq_group) - if num_new_tokens > prompt_limit: - logger.warning( - "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", - num_new_tokens, - prompt_limit, - ) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.FINISHED_IGNORED) - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - num_lookahead_slots: int = 0 - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate( - seq_group, num_lookahead_slots=num_lookahead_slots) - if can_allocate == AllocStatus.LATER: - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - "Input prompt (%d tokens) + lookahead slots (%d) is " - "too long and exceeds the capacity of block_manager", - num_new_tokens, - num_lookahead_slots, - ) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.FINISHED_IGNORED) - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - # We cannot mix sequence groups that use prompt embeds and - # those that do not. - if len(seq_groups) == 0: - using_prompt_embeds = seq_group.uses_prompt_embeds() - if using_prompt_embeds != seq_group.uses_prompt_embeds(): - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (self.lora_enabled and lora_int_id > 0 - and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. 
- self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - - if (budget.num_batched_tokens - >= self.scheduler_config.max_num_batched_tokens): - # We've reached the budget limit - since there might be - # continuous prefills in the running queue, we should break - # to avoid scheduling any new prefills. - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - break - - num_new_seqs = seq_group.get_max_num_running_seqs() - if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ): - self.remove_seq_from_computed_blocks_tracker( - seq_group, SequenceStatus.WAITING) - break - - # Can schedule this request. - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() - self._allocate_and_set_running(seq_group) - - if partial_prefill_metadata is not None: - partial_prefill_metadata.maybe_increment_partial_prefills( - seq_group) - - seq_groups.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens( - seq_group.request_id, - num_batched_tokens=num_new_tokens_uncached, - num_cached_tokens=num_new_tokens_cached, - ) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - # Queue requests that couldn't be scheduled. - waiting_queue.extendleft(leftover_waiting_sequences) - if len(seq_groups) > 0: - self.prev_prompt = True - - return SchedulerPrefillOutputs( - seq_groups=seq_groups, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), - ) - - def _schedule_default(self) -> SchedulerOutputs: - """Schedule queued requests. - - The current policy is designed to optimize the throughput. First, - it batches as many prefill requests as possible. And it schedules - decodes. If there's a pressure on GPU memory, decode requests can - be swapped or preempted. - """ - # Include running requests to the budget. - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - # Make sure we include num running seqs before scheduling prefill, - # so that we don't schedule beyond max_num_seqs for prefill. - for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) - curr_loras = (set( - seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None) - - prefills = SchedulerPrefillOutputs.create_empty() - running_scheduled = SchedulerRunningOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # If any requests are swapped, prioritized swapped requests. - if not self.swapped: - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=False) - - if len(prefills.seq_groups - ) == 0 and self.scheduler_config.policy == "priority": - self._schedule_priority_preemption(budget) - - # Don't schedule decodes if prefills are scheduled. - # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running - # only contains decode requests, not chunked prefills. - if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=False) - - # If any sequence group is preempted, do not swap in any sequence - # group. 
because it means there's no slot for new running requests. - if (len(running_scheduled.preempted) + - len(running_scheduled.swapped_out) == 0): - swapped_in = \ - self._schedule_swapped(budget, curr_loras) - - assert (budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - # Update new running requests. - if len(prefills.seq_groups) > 0: - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - self.running.extend(running_scheduled.decode_seq_groups_list) - - if len(swapped_in.decode_seq_groups) > 0: - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - preempted = len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) - - # There should be no prefill from running queue because this policy - # doesn't allow chunked prefills. - assert len(running_scheduled.prefill_seq_groups) == 0 - assert len(swapped_in.prefill_seq_groups) == 0 - - # Merge lists - num_prefill_groups = len(prefills.seq_groups) - ignored_seq_groups_for_embeds = list[SequenceGroup]() - if num_prefill_groups > 0: - scheduled_seq_groups = prefills.seq_groups - scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) - ignored_seq_groups_for_embeds.clear() - else: - scheduled_seq_groups = running_scheduled.decode_seq_groups - if len(scheduled_seq_groups) > 0: - using_prompt_embeds = scheduled_seq_groups[ - 0].seq_group.uses_prompt_embeds() - ignored_seq_groups_for_embeds.clear() - indices_ignored = list[int]() - for i, schedule_seq_group in enumerate(scheduled_seq_groups): - if using_prompt_embeds !=\ - schedule_seq_group.seq_group.uses_prompt_embeds(): - ignored_seq_groups_for_embeds.append( - schedule_seq_group.seq_group) - indices_ignored.append(i) - if len(ignored_seq_groups_for_embeds) > 0: - scheduled_seq_groups = [ - group for i, group in enumerate(scheduled_seq_groups) - if i not in indices_ignored - ] - else: - ignored_seq_groups_for_embeds.clear() - - scheduled_seq_groups.extend(swapped_in.decode_seq_groups) - - blocks_to_copy = running_scheduled.blocks_to_copy - blocks_to_copy.extend(swapped_in.blocks_to_copy) - - ignored_seq_groups = prefills.ignored_seq_groups - ignored_seq_groups.extend(ignored_seq_groups_for_embeds) - ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) - - return SchedulerOutputs( - scheduled_seq_groups=scheduled_seq_groups, - num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + - budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, - running_queue_size=len(self.running), - preempted=preempted, - ) - - def _schedule_chunked_prefill(self) -> SchedulerOutputs: - """Schedule queued requests. - - Chunked prefill allows to chunk prefill requests, batch them together - with decode requests. This policy 1. schedule as many decoding requests - as possible. 2. schedule chunked prefill requests that are not - finished. 3. schedule swapped request. 4. schedule new prefill - requests. 
- - The policy can sustain the high GPU utilization because it can put - prefill and decodes requests to the same batch, while it improves - inter token latency because decodes requests don't need to be blocked - by prefill requests. - """ - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - curr_loras: Set[int] = set() - - prefills = SchedulerPrefillOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # Create partial prefill metadata - partial_prefill_metadata = PartialPrefillMetadata.from_queues( - running=self.running, - waiting=self.waiting, - scheduler_config=self.scheduler_config, - ) - - # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running( - budget, - curr_loras, - enable_chunking=True, - partial_prefill_metadata=partial_prefill_metadata, - ) - - # Schedule swapped out requests. - # If preemption happens, it means we don't have space for swap-in. - if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: - swapped_in = self._schedule_swapped(budget, curr_loras) - - prefills = self._schedule_prefills( - budget, - curr_loras, - enable_chunking=True, - partial_prefill_metadata=partial_prefill_metadata, - ) - - assert (budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - - # Update new running requests. - # By default, vLLM scheduler prioritizes prefills. - # Once chunked prefill is enabled, - # the policy is changed to prioritize decode requests. - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) - # Because multiple prefills may be running concurrently, we need to - # make sure that prefills which are scheduled to finish are listed - # before those that won't. This is so that on the next scheduling - # iteration when they have transitioned to the decode stage, they are - # properly prioritized over sequences that are still in the prefill - # stage. - self.running.extend( - self._order_finishing_prefills_first( - running_scheduled.prefill_seq_groups)) - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - # Put prefills first due to Attention backend ordering assumption. 
- scheduled_seq_groups = (prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups) - num_prefill_groups = (len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)) - return SchedulerOutputs( - scheduled_seq_groups=scheduled_seq_groups, - num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + - budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups + - swapped_in.infeasible_seq_groups, - num_lookahead_slots=0, - running_queue_size=len(self.running), - preempted=(len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)), - ) - - def _order_finishing_prefills_first( - self, scheduled_prefill_seqs: List[ScheduledSequenceGroup] - ) -> List[SequenceGroup]: - """Returns a list of prefilling SequenceGroups where sequences that are - scheduled to finish prefilling are listed first""" - finishing = [ - s.seq_group for s in scheduled_prefill_seqs - if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size - ] - not_finishing = [ - s.seq_group for s in scheduled_prefill_seqs - if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size - ] - return finishing + not_finishing - - def _schedule(self) -> SchedulerOutputs: - """Schedule queued requests.""" - if self.scheduler_config.chunked_prefill_enabled: - return self._schedule_chunked_prefill() - else: - return self._schedule_default() - - def _can_append_slots(self, seq_group: SequenceGroup, - enable_chunking: bool) -> bool: - """Determine whether or not we have enough space in the KV cache to - continue generation of the sequence group. - """ - # It is True only for testing case to trigger artificial preemption. - if (self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0): - self.artificial_preempt_cnt -= 1 - return False - - is_prefill = seq_group.is_prefill() - num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill, enable_chunking) - - return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) - - def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: - # async_output_proc is allowed only when we have a single sequence - # in the sequence group - no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1) - return no_single_seq - - def schedule( - self - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. - scheduler_start_time = time.perf_counter() - - scheduler_outputs: SchedulerOutputs = self._schedule() - now = time.time() - - if not self.cache_config.enable_prefix_caching: - common_computed_block_nums = [] - - allow_async_output_proc: bool = self.use_async_output_proc - - # Create input data structures. 
- seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - token_chunk_size = scheduled_seq_group.token_chunk_size - seq_group.maybe_set_first_scheduled_time(now) - - seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id].get_object() - seq_group_metadata.seq_data.clear() - seq_group_metadata.block_tables.clear() - - # seq_id -> SequenceData - seq_data: Dict[int, SequenceData] = {} - # seq_id -> physical block numbers - block_tables: Dict[int, List[int]] = {} - - if seq_group.is_encoder_decoder(): - # Encoder associated with SequenceGroup - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - encoder_seq_data = encoder_seq.data - # Block table for cross-attention - # Also managed at SequenceGroup level - cross_block_table = self.block_manager.get_cross_block_table( - seq_group) - else: - encoder_seq_data = None - cross_block_table = None - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq_id = seq.seq_id - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) - - if self.cache_config.enable_prefix_caching: - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) - - do_sample = True - is_prompt = seq_group.is_prefill() - # We should send the metadata to workers when the first prefill - # is sent. Subsequent requests could be chunked prefill or decode. - is_first_prefill = False - if is_prompt: - seqs = seq_group.get_seqs() - # Prefill has only 1 sequence. - assert len(seqs) == 1 - num_computed_tokens = seqs[0].data.get_num_computed_tokens() - is_first_prefill = num_computed_tokens == 0 - # In the next iteration, all prompt tokens are not computed. - # It means the prefill is chunked, and we don't need sampling. - # NOTE: We use get_len instead of get_prompt_len because when - # a sequence is preempted, prefill includes previous generated - # output tokens. - if (token_chunk_size + num_computed_tokens - < seqs[0].data.get_len()): - do_sample = False - - # It assumes the scheduled_seq_groups is ordered by - # prefill < decoding. - if is_first_prefill or not self.scheduler_config.send_delta_data: - seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group.request_id, - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - do_sample=do_sample, - pooling_params=seq_group.pooling_params, - token_chunk_size=token_chunk_size, - lora_request=seq_group.lora_request, - computed_block_nums=common_computed_block_nums, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - state=seq_group.state, - token_type_ids=seq_group.token_type_ids, - # `multi_modal_data` will only be present for the 1st comm - # between engine and worker. - # the subsequent comms can still use delta, but - # `multi_modal_data` will be None. - multi_modal_data=(seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups - > 0 else None), - multi_modal_placeholders=( - seq_group.multi_modal_placeholders - if scheduler_outputs.num_prefill_groups > 0 else None), - ) - else: - # When SPMD mode is enabled, we only send delta data except for - # the first request to reduce serialization cost. 
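# --- Editor's note: illustrative sketch, not part of this diff. ---
# The do_sample decision made in schedule() above, in isolation: a chunked
# prefill only samples a token on the step that completes the prompt; all
# earlier chunks just populate the KV cache. Numbers are hypothetical.

def needs_sampling(token_chunk_size: int,
                   num_computed_tokens: int,
                   seq_len: int) -> bool:
    # Sampling is skipped while part of the prompt would still be
    # uncomputed after this step, i.e. the prefill is being chunked.
    return token_chunk_size + num_computed_tokens >= seq_len


if __name__ == "__main__":
    # A 1000-token prompt prefilling in 512-token chunks:
    assert not needs_sampling(512, 0, 1000)   # first chunk, no sampling
    assert needs_sampling(488, 512, 1000)     # final chunk completes the prompt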
- seq_data_delta = {} - for id, data in seq_data.items(): - seq_data_delta[id] = data.get_delta_and_reset() - seq_group_metadata = SequenceGroupMetadataDelta( - seq_data_delta, - seq_group.request_id, - block_tables, - is_prompt, - do_sample=do_sample, - token_chunk_size=token_chunk_size, - computed_block_nums=common_computed_block_nums, - ) - seq_group_metadata_list.append(seq_group_metadata) - - if allow_async_output_proc: - allow_async_output_proc = self._allow_async_output_proc( - seq_group) - - # Now that the batch has been created, we can assume all blocks in the - # batch will have been computed before the next scheduling invocation. - # This is because the engine assumes that a failure in model execution - # will crash the vLLM instance / will not retry. - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group, - scheduled_seq_group.token_chunk_size) - - self._seq_group_metadata_cache[self.next_cache_id].reset() - - scheduler_time = time.perf_counter() - scheduler_start_time - # Add this to scheduler time to all the sequences that are currently - # running. This will help estimate if the scheduler is a significant - # component in the e2e latency. - for seq_group in self.running: - if seq_group is not None and seq_group.metrics is not None: - if seq_group.metrics.scheduler_time is not None: - seq_group.metrics.scheduler_time += scheduler_time - else: - seq_group.metrics.scheduler_time = scheduler_time - - # Move to next cache (if exists) - self.cache_id = self.next_cache_id - - # Return results - return (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) - - def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: - self.block_manager.fork(parent_seq, child_seq) - - def free_seq(self, seq: Sequence) -> None: - """Free a sequence from a block table.""" - self.block_manager.free(seq) - - def remove_seq_from_computed_blocks_tracker( - self, seq_group: SequenceGroup, - status: Optional[SequenceStatus]) -> None: - seqs = seq_group.get_seqs(status=status) - for seq in seqs: - self._remove_seq_from_computed_blocks_tracker(seq) - - def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - """ - Free a sequence computed blocks tracker _seq_id_to_blocks_hashes - and _seq_id_to_num_tokens_computed. - """ - self.block_manager.remove_seq_from_computed_blocks_tracker(seq) - - def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: - """Free finished seqs in a sequence group.""" - for seq in seq_group.get_seqs(): - if seq.is_finished(): - self.free_seq(seq) - - def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. 
- self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - def free_finished_seq_groups(self) -> None: - remaining: Deque[SequenceGroup] = deque() - for seq_group in self.running: - self._free_finished_seq_group(seq_group) - if not seq_group.is_finished(): - remaining.append(seq_group) - - self.running = remaining - - # Handle async stopped sequence groups - # (ones that reached max model len) - if self._async_stopped: - for seq_group in self._async_stopped: - self._free_seq_group_cross_attn_blocks(seq_group) - self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - self._async_stopped.clear() - - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - seq.status = SequenceStatus.RUNNING - - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - enable_chunking: bool = False, - ) -> None: - """Appends new slots to the sequences in the given sequence group. - - Args: - seq_group (SequenceGroup): The sequence group containing the - sequences to append slots to. - blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two - ints, the first int is the source block index, and the second - int is the destination block index. This list is updated with - the new source and destination block indices for the appended - slots. - enable_chunking (bool): True if chunked prefill is enabled. - """ - is_prefill: bool = seq_group.is_prefill() - num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking) - - seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING - for seq in seq_group.get_seqs(status=seq_status): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) - if len(cows) > 0: - blocks_to_copy.extend(cows) - - def _preempt(self, seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: - # If preemption mode is not specified, we determine the mode as follows: - # We use recomputation by default since it incurs lower overhead than - # swapping. However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not currently supported. In - # such a case, we use swapping instead. - # FIXME(woosuk): This makes our scheduling policy a bit bizarre. - # As swapped sequences are prioritized over waiting sequences, - # sequence groups with multiple sequences are implicitly prioritized - # over sequence groups with a single sequence. - # TODO(woosuk): Support recomputation for sequence groups with multiple - # sequences. This may require a more sophisticated CUDA kernel. - if self.user_specified_preemption_mode is None: - if seq_group.get_max_num_running_seqs() == 1: - preemption_mode = PreemptionMode.RECOMPUTE - else: - preemption_mode = PreemptionMode.SWAP - - elif self.user_specified_preemption_mode == "swap": - preemption_mode = PreemptionMode.SWAP - else: - preemption_mode = PreemptionMode.RECOMPUTE - - if self.num_cumulative_preemption % 50 == 0: - logger.warning( - "Sequence group %s is preempted by %s mode because there is " - "not enough KV cache space. This can affect the end-to-end " - "performance. Increase gpu_memory_utilization or " - "tensor_parallel_size to provide more KV cache memory. 
" - "total_num_cumulative_preemption=%d", - seq_group.request_id, - preemption_mode, - self.num_cumulative_preemption + 1, - ) - self.num_cumulative_preemption += 1 - - if preemption_mode == PreemptionMode.RECOMPUTE: - self._preempt_by_recompute(seq_group) - elif preemption_mode == PreemptionMode.SWAP: - self._preempt_by_swap(seq_group, blocks_to_swap_out) - else: - raise AssertionError("Invalid preemption mode.") - return preemption_mode - - def _preempt_by_recompute( - self, - seq_group: SequenceGroup, - ) -> None: - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - assert len(seqs) == 1 - for seq in seqs: - seq.status = SequenceStatus.WAITING - self.free_seq(seq) - seq.reset_state_for_recompute() - self._free_seq_group_cross_attn_blocks(seq_group) - - def _preempt_by_swap( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - ) -> None: - self._swap_out(seq_group, blocks_to_swap_out) - - def _swap_in( - self, - seq_group: SequenceGroup, - blocks_to_swap_in: List[Tuple[int, int]], - ) -> None: - mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - seq.status = SequenceStatus.RUNNING - - def _swap_out( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - ) -> None: - if not self.block_manager.can_swap_out(seq_group): - # FIXME(woosuk): Abort the sequence group instead of aborting the - # entire engine. - raise RuntimeError( - "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error.") - mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.SWAPPED - - def _passed_delay(self, now: float) -> bool: - if self.prev_prompt: - self.last_prompt_latency = now - self.prev_time - self.prev_time, self.prev_prompt = now, False - # Delay scheduling prompts to let waiting queue fill up - if self.scheduler_config.delay_factor > 0 and self.waiting: - earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting]) - passed_delay = ((now - earliest_arrival_time) - > (self.scheduler_config.delay_factor * - self.last_prompt_latency) or not self.running) - else: - passed_delay = True - return passed_delay - - def _get_num_lookahead_slots(self, is_prefill: bool, - enable_chunking: bool) -> int: - """The number of slots to allocate per sequence per step, beyond known - token ids. Speculative decoding uses these slots to store KV activations - of tokens which may or may not be accepted. - """ - return 0 - - def _get_num_new_uncached_and_cached_tokens( - self, - seq_group: SequenceGroup, - status: SequenceStatus, - enable_chunking: bool, - budget: SchedulingBudget, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> Tuple[int, int]: - """ - Returns the number of new uncached and cached tokens to schedule for a - given sequence group that's in a given `status`. - - The API could chunk the number of tokens to compute based on `budget` - if `enable_chunking` is True. If a sequence group has multiple - sequences (e.g., running beam search), it means it is in decoding - phase, so chunking doesn't happen. - - Returns (0, 0) if the new token cannot be computed due to token budget. - - The cached tokens's blocks are already computed, and the attention - backend will reuse the cached blocks rather than recomputing them. 
So - the scheduler could schedule these cached tokens "for free". - - Args: - seq_group: The sequence group to get the number of new tokens to - schedule. - status: The status of the sequences to get the number of new tokens - to schedule. - enable_chunking: Whether to chunk the number of tokens to compute. - budget: The budget to chunk the number of tokens to compute. - partial_prefill_metadata: information about the partial prefills - that are currently running - - - Returns: - A tuple of two ints. The first int is the number of new uncached - tokens to schedule. The second int is the number of cached tokens. - If no more new tokens can be scheduled, returns (0, 0). - """ - num_cached_new_tokens = 0 - num_uncached_new_tokens = 0 - - seqs = seq_group.get_seqs(status=status) - # Compute the number of new uncached and cached tokens for - # each sequence. - for seq in seqs: - if not seq.is_prefill(): - # Decode sequences should always just have 1 uncached token - # TODO(rickyx): Actually is this still correct for multi-step? - num_uncached_new_tokens += 1 - continue - - num_computed_tokens_seq = seq.get_num_computed_tokens() - all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq - if not self.cache_config.enable_prefix_caching: - # If prefix caching is not enabled, all new tokens are uncached. - num_uncached_new_tokens += all_num_new_tokens_seq - continue - - # NOTE: the cache token might be currently in a block that's in an - # evictor meaning that it's not yet allocated. However, we don't - # exclude such tokens in the cache count because it will be - # guaranteed to be allocated later if the sequence can be allocated. - num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( - seq) - - # Sanity check. - if num_cached_tokens_seq < num_computed_tokens_seq: - # This should only happen with chunked prefill, and - # the seq is still in prefill. The `num_cached_tokens_seq` - # is the value we calculated on scheduling the first prefill. - # For subsequent continuous prefill steps, we cached the - # number of cache tokens for the sequence so the cached token - # count could be less than the number of computed tokens. - # See comments on `ComputedBlocksTracker` for more details. - assert ( - seq.is_prefill() and seq.status == SequenceStatus.RUNNING - and self.scheduler_config.chunked_prefill_enabled - ), ("Number of cached tokens should not be less than the " - "number of computed tokens for a sequence that's still " - f"in prefill. But there are {num_cached_tokens_seq} cached " - f"tokens and {num_computed_tokens_seq} computed tokens " - f"for sequence {seq.seq_id}.") - - num_cached_new_tokens_seq = max( - 0, num_cached_tokens_seq - num_computed_tokens_seq) - num_uncached_new_tokens_seq = (all_num_new_tokens_seq - - num_cached_new_tokens_seq) - - num_uncached_new_tokens += num_uncached_new_tokens_seq - num_cached_new_tokens += num_cached_new_tokens_seq - - if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: - # For a fully cached hit sequence, we actually need to recompute the - # last token. So we need at least 1 uncached token to schedule. - # See ModelRunner._compute_for_prefix_cache_hit for more details. - num_uncached_new_tokens = 1 - num_cached_new_tokens -= 1 - - if enable_chunking and len(seqs) == 1: - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. 
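# --- Editor's note: illustrative sketch, not part of this diff. ---
# The cached/uncached split computed above for a single prefill sequence,
# including the rule that a fully prefix-cached prompt still schedules one
# uncached token so the last token is recomputed. A simplified model with
# hypothetical numbers, not the real method.

def split_new_tokens(seq_len: int,
                     num_computed: int,
                     num_prefix_cached: int) -> tuple[int, int]:
    all_new = seq_len - num_computed
    cached_new = max(0, num_prefix_cached - num_computed)
    uncached_new = all_new - cached_new
    if uncached_new == 0 and cached_new > 0:
        # Fully cached hit: recompute the last token instead of skipping it.
        uncached_new, cached_new = 1, cached_new - 1
    return uncached_new, cached_new


if __name__ == "__main__":
    # 1024-token prompt, nothing computed yet, 512 tokens hit the prefix cache.
    assert split_new_tokens(1024, 0, 512) == (512, 512)
    # The entire prompt is cached: still schedule one uncached token.
    assert split_new_tokens(1024, 0, 1024) == (1, 1023)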
- num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( - self.scheduler_config, - self.cache_config, - budget, - self._get_prompt_limit(seq_group), - num_uncached_new_tokens, - self.partial_prefill_budget_lookup_list, - partial_prefill_metadata, - ) - - return num_uncached_new_tokens, num_cached_new_tokens - - @staticmethod - def _chunk_new_tokens_to_schedule( - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - budget: SchedulingBudget, - prompt_limit: int, - num_new_tokens: int, - partial_prefill_budget_lookup_list: List[int], - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> int: - """ - Chunks the number of new tokens to schedule based on the budget when - chunked prefill is enabled. - - Args: - scheduler_config: The scheduler config. - cache_config: The cache config. - budget: The budget to chunk the number of tokens to compute. - prompt_limit: The maximum number of tokens allowed in a prompt. - num_new_tokens: The number of new tokens to schedule. - - Returns: - The number of new tokens to schedule after chunking. - """ - remaining_token_budget = budget.remaining_token_budget() - - # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = ( - remaining_token_budget if partial_prefill_metadata is None else - partial_prefill_budget_lookup_list[ - partial_prefill_metadata.schedulable_prefills]) - - if cache_config.enable_prefix_caching: - # When prefix caching is enabled and we're partially prefilling - # a sequence, we always allocate a number of new tokens that is - # divisible by the block size to avoid partial block matching. - block_size = cache_config.block_size - # Don't exceed either the total budget or slot budget. - # Take min of those and get the next lowest multiple of the - # block size: - remaining_token_budget = ( - min(remaining_token_budget, prefill_slot_budget) // - block_size) * block_size - # NB: In the case where num_new_tokens < budget, we are - # finishing prefill for this sequence, so we do not need to - # allocate a full block. - - num_new_tokens = min(num_new_tokens, remaining_token_budget, - prefill_slot_budget) - - return num_new_tokens diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed97ee..97c6654385b35 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -16,8 +16,11 @@ from typing import Any, Callable, Optional, Union import torch +from vllm.logger import init_logger from vllm.utils import is_pin_memory_available +logger = init_logger(__name__) + def find_loaded_library(lib_name) -> Optional[str]: """ @@ -25,7 +28,7 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found_line = None with open("/proc/self/maps") as f: for line in f: @@ -40,17 +43,21 @@ def find_loaded_library(lib_name) -> Optional[str]: start = found_line.index("/") path = found_line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path cumem_available = False try: - from vllm.cumem_allocator import (init_module, python_create_and_map, - python_unmap_and_release) - from vllm.distributed.device_communicators.cuda_wrapper import ( - CudaRTLibrary) + from vllm.cumem_allocator import ( + init_module, + python_create_and_map, + python_unmap_and_release, + ) + from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + lib_name = find_loaded_library("cumem_allocator") libcudart = CudaRTLibrary() cumem_available = True @@ -83,20 +90,19 @@ def unmap_and_release(allocation_handle: HandleType) -> None: def get_pluggable_allocator( - python_malloc_fn: Callable[[int], - int], python_free_func: Callable[[int, int], - None] + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] ) -> torch.cuda.memory.CUDAPluggableAllocator: init_module(python_malloc_fn, python_free_func) new_alloc = torch.cuda.memory.CUDAPluggableAllocator( - lib_name, 'my_malloc', 'my_free') + lib_name, "my_malloc", "my_free" + ) return new_alloc @contextmanager def use_memory_pool_with_allocator( - python_malloc_fn: Callable[[int], int], - python_free_func: Callable[[int, int], None]) -> None: + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] +) -> None: new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) with torch.cuda.memory.use_mem_pool(mem_pool): @@ -127,6 +133,7 @@ class CuMemAllocator: the global variable will be overwritten and the free callback will not work as expected. """ + instance: "CuMemAllocator" = None default_tag: str = "default" @@ -144,37 +151,53 @@ class CuMemAllocator: def __init__(self): conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - assert "expandable_segments:True" not in conf, \ - ("Expandable segments are not compatible with memory pool. " + assert "expandable_segments:True" not in conf, ( + "Expandable segments are not compatible with memory pool. " "Please track https://github.com/pytorch/pytorch/issues/147851 " - "for the latest updates.") + "for the latest updates." + ) self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. 
+ # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( - allocation_handle, self.current_tag) + allocation_handle, self.current_tag + ) + logger.debug( + "Allocated %s bytes for %s with address %s from cumem allocator", + allocation_handle[1], + self.current_tag, + py_d_mem, + ) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" data = self.pointer_to_data.pop(ptr) if data.cpu_backup_tensor is not None: data.cpu_backup_tensor = None + logger.debug( + "Freed %s bytes for %s with address %s from cumem allocator", + data.handle[1], + data.tag, + ptr, + ) return data.handle - def sleep( - self, - offload_tags: Optional[Union[tuple[str, ...], - str]] = None) -> None: + def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be @@ -186,35 +209,50 @@ class CuMemAllocator: if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps - offload_tags = (CuMemAllocator.default_tag, ) + offload_tags = (CuMemAllocator.default_tag,) elif isinstance(offload_tags, str): - offload_tags = (offload_tags, ) + offload_tags = (offload_tags,) assert isinstance(offload_tags, tuple) + total_bytes = 0 + backup_bytes = 0 + for ptr, data in self.pointer_to_data.items(): handle = data.handle + total_bytes += handle[1] if data.tag in offload_tags: + backup_bytes += handle[1] size_in_bytes = handle[1] cpu_backup_tensor = torch.empty( size_in_bytes, dtype=torch.uint8, - device='cpu', - pin_memory=is_pin_memory_available()) + device="cpu", + pin_memory=is_pin_memory_available(), + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) + logger.info( + "CuMemAllocator: sleep freed %.2f GiB memory in total, of which " + "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " + "directly.", + total_bytes / 1024**3, + backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3, + ) + gc.collect() torch.cuda.empty_cache() def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. 
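The strong-reference assignments above matter because every attribute access on an instance builds a new bound-method object; if only such an ephemeral object is handed to the C extension, nothing on the Python side keeps it alive. A small standard-library demonstration of that behavior (the `Worker` class is a toy, unrelated to the allocator):

import weakref


class Worker:
    def callback(self) -> None:
        pass


w = Worker()

# Each access to `w.callback` creates a fresh bound method; with no strong
# reference anywhere, a weak reference to it dies immediately.
print(weakref.ref(w.callback)())      # None

# Pinning the bound method on the instance keeps it alive with the object.
w.callback_strong = w.callback
print(weakref.ref(w.callback_strong)())  # <bound method Worker.callback of ...>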
@@ -226,8 +264,9 @@ class CuMemAllocator: if data.cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor if cpu_backup_tensor is not None: - size_in_bytes = cpu_backup_tensor.numel( - ) * cpu_backup_tensor.element_size() + size_in_bytes = ( + cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) data.cpu_backup_tensor = None @@ -249,8 +288,9 @@ class CuMemAllocator: old_tag = self.current_tag self.current_tag = tag - with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback) as data: + with use_memory_pool_with_allocator( + self.python_malloc_callback, self.python_free_callback + ) as data: # start to hit another PyTorch bug in PyTorch 2.6, # possibly because of gc-related issue w.r.t. the allocator and # the memory pool. @@ -262,12 +302,17 @@ class CuMemAllocator: # when using pluggable allocator, see # https://github.com/pytorch/pytorch/issues/145168 . # if we have some memory allocated and then freed, - # the memory will not be released. - # right now it is fine, because we only use this allocator - # during weight loading and kv cache creation, where we only - # allocate memory. - # TODO: we need to find a way to release the memory, - # i.e. calling torch.cuda.empty_cache() + # the memory will not be released, e.g. in online quantization, + # where the model is created in higher precision, and then + # quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory pool. + # TODO: ask for help from PyTorch team to expose this method. + allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) self.current_tag = old_tag def get_current_usage(self) -> int: diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0a5a95176f7c3..46a735f22ed85 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -14,28 +14,30 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) -def tensor_model_parallel_all_gather(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" return get_tp_group().all_gather(input_, dim) -def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: +def tensor_model_parallel_reduce_scatter( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: """Reduce-Scatter the input tensor across model parallel group.""" return get_tp_group().reduce_scatter(input_, dim) -def tensor_model_parallel_gather(input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: +def tensor_model_parallel_gather( + input_: torch.Tensor, dst: int = 0, dim: int = -1 +) -> Optional[torch.Tensor]: """Gather the input tensor across model parallel group.""" return get_tp_group().gather(input_, dst, dim) -def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, - Any]]] = None, - src: int = 0): +def broadcast_tensor_dict( + tensor_dict: Optional[dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 +): if not torch.distributed.is_initialized(): return 
tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 85f87cb21edcd..a67405f44206a 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,22 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any +from typing import Any, Optional import torch import torch.distributed as dist +import vllm.envs as envs +from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import has_deep_ep, has_pplx +from vllm.utils.flashinfer import has_flashinfer_all2all from .base_device_communicator import All2AllManagerBase, Cache -logger = init_logger(__name__) +if has_flashinfer_all2all(): + from flashinfer.comm import Mapping + from flashinfer.comm.mnnvl import MnnvlConfig + from flashinfer.comm.trtllm_alltoall import MnnvlMoe -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.layer import FusedMoE -else: - FusedMoE = None +logger = init_logger(__name__) class NaiveAll2AllManager(All2AllManagerBase): @@ -30,43 +33,61 @@ class NaiveAll2AllManager(All2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) + def naive_multicast( + self, + x: torch.Tensor, + cu_tokens_across_sp_cpu: torch.Tensor, + is_sequence_parallel: bool, + ) -> torch.Tensor: + assert len(x.shape) == 2 + buffer = torch.empty( + (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype + ) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] + rank = self.rank if is_sequence_parallel else self.dp_rank + world_size = self.world_size if is_sequence_parallel else self.dp_world_size + + start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1] + end = cu_tokens_across_sp_cpu[rank] buffer[start:end, :].copy_(x) - for idx in range(self.dp_world_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - self.dp_group.broadcast(buffer[start:end, :], idx) + for idx in range(world_size): + start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1] + end = cu_tokens_across_sp_cpu[idx] + get_ep_group().broadcast(buffer[start:end, :], idx) return buffer - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + dp_metadata = get_forward_context().dp_metadata + cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_cpu) + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + router_logits = 
self.naive_multicast( + router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel + ) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: + ep_rank = self.rank if is_sequence_parallel else self.dp_rank - all_hidden_states = self.dp_group.all_reduce(hidden_states) + dp_metadata = get_forward_context().dp_metadata + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) + + start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1] + end = cu_tokens_across_sp_cpu[ep_rank] + + all_hidden_states = get_ep_group().all_reduce(hidden_states) hidden_states = all_hidden_states[start:end, :] return hidden_states @@ -74,30 +95,87 @@ class NaiveAll2AllManager(All2AllManagerBase): pass +class AgRsAll2AllManager(All2AllManagerBase): + """ + An implementation of all2all communication based on + all-gather (dispatch) and reduce-scatter (combine). + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Gather hidden_states and router_logits from all dp ranks. + """ + sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] + hidden_states, router_logits = dist_group.all_gatherv( + [hidden_states, router_logits], + dim=0, + sizes=sizes, + ) + return hidden_states, router_logits + + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: + """ + Reduce-scatter hidden_states across all dp ranks. + """ + sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes) + return hidden_states + + def destroy(self): + pass + + class PPLXAll2AllManager(All2AllManagerBase): """ All2All communication based on PPLX kernels. """ def __init__(self, cpu_group): - assert has_pplx( - ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa + assert has_pplx(), ( + "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" + " to install pplx_kernels." 
+ ) super().__init__(cpu_group) if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init, + ) + logger.debug( - "Initialize NVSHMEM for pplx_kernels: " - "rank=%d, world size=%d", self.rank, self.world_size) - uid = nvshmem_get_unique_id( - ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() - dist.broadcast(uid, - src=dist.get_process_group_ranks(self.cpu_group)[0], - group=self.cpu_group) + "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d", + self.rank, + self.world_size, + ) + uid = ( + nvshmem_get_unique_id() + if self.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) + dist.broadcast( + uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group, + ) logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) @@ -105,15 +183,23 @@ class PPLXAll2AllManager(All2AllManagerBase): def get_handle(self, kwargs): import pplx_kernels as pplx - return self.handle_cache.get_or_create( - kwargs, pplx.AllToAll.internode - if self.internode else pplx.AllToAll.intranode) - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + return self.handle_cache.get_or_create( + kwargs, + pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, + ) + + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -123,6 +209,7 @@ class PPLXAll2AllManager(All2AllManagerBase): if self.internode: from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") nvshmem_finalize() @@ -133,8 +220,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): """ def __init__(self, cpu_group): - assert has_deep_ep( - ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + assert has_deep_ep(), ( + "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" + " to install DeepEP kernels." + ) # noqa super().__init__(cpu_group) self.handle_cache = Cache() @@ -145,11 +234,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): def get_handle(self, kwargs): raise NotImplementedError - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -166,12 +261,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. 
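The naive dispatch/combine paths above locate each rank's rows inside the gathered buffer from cumulative token counts. A self-contained sketch of just that indexing; the per-rank token counts and the `rank_rows` helper are made up and no process group is involved:

import torch

# Hypothetical token counts for 4 ranks.
tokens_per_rank = torch.tensor([3, 5, 2, 4])
cu_tokens = torch.cumsum(tokens_per_rank, dim=0)  # tensor([ 3,  8, 10, 14])


def rank_rows(rank: int) -> tuple[int, int]:
    # Same start/end arithmetic as the naive dispatch/combine above.
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    return start, end


hidden_size = 8
buffer = torch.empty(int(cu_tokens[-1]), hidden_size)
for rank in range(len(tokens_per_rank)):
    start, end = rank_rows(rank)
    buffer[start:end] = float(rank)  # stand-in for the broadcast from `rank`

print(rank_rows(2))     # (8, 10)
print(buffer[8:10, 0])  # tensor([2., 2.])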
- num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_rdma_bytes = None num_qps_per_rank = None if self.internode: - num_rdma_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = self.num_sms // 2 else: num_rdma_bytes = 0 @@ -179,30 +274,39 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): assert num_rdma_bytes is not None assert num_qps_per_rank is not None - return dict(group=self.cpu_group, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=False, - num_qps_per_rank=num_qps_per_rank) + return dict( + group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=False, + num_qps_per_rank=num_qps_per_rank, + ) def get_handle(self, kwargs): - assert len(kwargs) == 0, ( "DeepEPHTAll2AllManager expects no arguments. All the required " - "args are computed in the Manager itself.") + "args are computed in the Manager itself." + ) import deep_ep + buffer_kwargs = self._make_all2all_kwargs() logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( - buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. - handle.set_num_sms(self.num_sms) + buffer_kwargs, deep_ep.Buffer + ) return handle + def set_num_sms(self, num_sms: int): + import deep_ep + + # Right now the buffers are sized for only what the kernels were + # created with. So we can only reduce the number of SMS used + # but not increase it. + if num_sms > self.num_sms: + num_sms = self.num_sms + deep_ep.Buffer.set_num_sms(num_sms) + class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): """ @@ -231,20 +335,23 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): import deep_ep # Defaults for internode and intranode are taken from DeepEP tests. - num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = num_local_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, hidden=token_hidden_size, num_ranks=num_ep_ranks, - num_experts=num_global_experts) + num_experts=num_global_experts, + ) assert num_rdma_bytes is not None - return dict(group=self.cpu_group, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=num_qps_per_rank) + return dict( + group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_qps_per_rank, + ) def get_handle(self, kwargs): """ @@ -252,13 +359,118 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): _make_all2all_kwargs. """ import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( - buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. 
- handle.set_num_sms(self.num_sms) + buffer_kwargs, deep_ep.Buffer + ) return handle + + # DeepEP LL uses RDMA so no SMs are used for communication + def max_sms_used(self) -> Optional[int]: + return 0 + + +class FlashInferAllToAllManager(All2AllManagerBase): + """ + All2All communication based on flashinfer kernels. + """ + + # This type lint could be removed after all of the work in + # https://github.com/vllm-project/vllm/issues/26533 done. + rank: int + world_size: int + + def __init__(self, cpu_group): + assert has_flashinfer_all2all(), ( + "flashinfer all2all module not found. Please install/check flashinfer" + ) # noqa + super().__init__(cpu_group) + logger.debug( + "Initialize for flashinfer All2All rank=%d, world size=%d", + self.rank, + self.world_size, + ) + self.initialized = False + self.alltoall_info = None + + def initialize( + self, + world_size: int, + rank: int, + gpus_per_node: int, + ): + """Initialize workspace""" + if self.initialized: + return + + self.cleanup() + logger.debug("making map: rank=%d, world size=%d", rank, world_size) + self.mapping = Mapping( + world_size, + rank, + gpus_per_node, + tp_size=world_size, + ) + + from vllm.distributed.device_communicators.mnnvl_compat import ( + CustomCommunicator, + ) + + dp_config = MnnvlConfig( + comm_backend=CustomCommunicator(get_dp_group().cpu_group), + fabric_page_size=1 << 29, # 512MB + allocation_granularity=0, # Auto-detect + ) + + self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config) + self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace( + self.mapping, dp_config + ) + + self.world_size = world_size + self.rank = rank + self.gpus_per_node = gpus_per_node + self.initialized = True + + logger.info( + "FlashInfer All2All initialized for rank %s, size %s", rank, world_size + ) + + def ensure_alltoall_workspace_initialized(self): + """Ensure workspace is initialized""" + if not has_flashinfer_all2all(): + return False + + if self.world_size <= 1: + return False + + if not self.initialized: + self.initialize( + world_size=self.world_size, + rank=self.rank, + gpus_per_node=torch.cuda.device_count, + ) + return self.initialized + + def get_handle(self, kwargs): + return self + + def cleanup(self): + """Clean up workspace""" + if ( + self.initialized + and self.workspace_tensor is not None + and self.prepare_workspace_tensor is not None + ): + try: + del self.workspace_tensor + del self.prepare_workspace_tensor + except Exception as e: + logger.warning("Failed to cleanup FlashInfer workspace: %s", e) + finally: + self.workspace_tensor = None + self.prepare_workspace_tensor = None + self.mapping = None + self.initialized = False diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 5c64e7d5c4ba3..dabb48320be45 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -10,16 +10,16 @@ import sys import tempfile from collections.abc import Sequence from itertools import product -from typing import Optional +from typing import Any, Optional +import torch import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger -from vllm.utils import (cuda_device_count_stateless, - update_environment_variables) +from vllm.utils import cuda_device_count_stateless, update_environment_variables 
logger = init_logger(__name__) @@ -36,9 +36,9 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = { "10.0": { 2: 2 * MiB, # 2 MB 4: 2 * MiB, # 2 MB - 6: 2 * MiB, # 2 MB - 8: 2 * MiB, # 2 MB - } + 6: 1 * MiB, # 1 MB + 8: 1 * MiB, # 1 MB + }, } SYMM_MEM_ALL_REDUCE_MAX_SIZES = { @@ -53,18 +53,43 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = { 4: 32 * MiB, # 32 MB 6: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB - } + }, +} + +NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = { + "min_world_size": 4, + "thresholds": { + 4: 2 * MiB, # 2 MB + 8: 1 * MiB, # 1 MB + }, + "always_use_above_world_size": 8, # Always use symm mem for world_size > 8 } -def producer(batch_src: Sequence[int], - producer_queue, - consumer_queue, - result_queue, - cuda_visible_devices: Optional[str] = None): +def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool: + from vllm.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_enabled, + ) + + if not is_symmetric_memory_enabled(): + return False + if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: + return False + threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size) + if threshold is not None and input_tensor.nbytes >= threshold: + return True + return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"] + + +def producer( + batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): if cuda_visible_devices is not None: - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for i in batch_src: @@ -90,14 +115,15 @@ def producer(batch_src: Sequence[int], lib.cudaDeviceReset() -def consumer(batch_tgt: Sequence[int], - producer_queue, - consumer_queue, - result_queue, - cuda_visible_devices: Optional[str] = None): +def consumer( + batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): if cuda_visible_devices is not None: - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for j in batch_tgt: @@ -173,12 +199,26 @@ def can_actually_p2p( producer_queue = smp.Queue() consumer_queue = smp.Queue() result_queue = smp.Queue() - p_src = smp.Process(target=producer, - args=(batch_src, producer_queue, consumer_queue, - result_queue, cuda_visible_devices)) - p_tgt = smp.Process(target=consumer, - args=(batch_tgt, producer_queue, consumer_queue, - result_queue, cuda_visible_devices)) + p_src = smp.Process( + target=producer, + args=( + batch_src, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_tgt = smp.Process( + target=consumer, + args=( + batch_tgt, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) p_src.start() p_tgt.start() p_src.join() @@ -191,7 +231,10 @@ def can_actually_p2p( if a != b: logger.warning( "Two processes do not agree on the P2P access" - " status on %d -> %d, treat as disabled.", src, tgt) + " status on %d -> %d, treat as disabled.", + src, + tgt, + ) result.append(False) else: result.append(a) @@ -230,12 +273,14 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) path = os.path.join( - envs.VLLM_CACHE_ROOT, - 
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) os.makedirs(os.path.dirname(path), exist_ok=True) from vllm.distributed.parallel_state import get_world_group - if ((not is_distributed or get_world_group().local_rank == 0) - and (not os.path.exists(path))): + + if (not is_distributed or get_world_group().local_rank == 0) and ( + not os.path.exists(path) + ): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache logger.info("generating GPU P2P access cache in %s", path) @@ -254,11 +299,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # we don't use the output of the subprocess directly, # because the subprocess might produce logging output with tempfile.NamedTemporaryFile() as output_file: - input_bytes = pickle.dumps( - (batch_src, batch_tgt, output_file.name)) - returned = subprocess.run([sys.executable, __file__], - input=input_bytes, - capture_output=True) + input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name)) + returned = subprocess.run( + [sys.executable, __file__], input=input_bytes, capture_output=True + ) # check if the subprocess is successful try: returned.check_returncode() @@ -267,7 +311,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: raise RuntimeError( f"Error happened when batch testing " f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" - f"{returned.stderr.decode()}") from e + f"{returned.stderr.decode()}" + ) from e with open(output_file.name, "rb") as f: result = pickle.load(f) for _i, _j, r in zip(batch_src, batch_tgt, result): diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 9e5aa4e4c2a89..c32be0bec55c0 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -10,7 +10,6 @@ from torch.distributed import ProcessGroup class Cache: - def __init__(self): self._cache: WeakValueDictionary = WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety @@ -28,18 +27,23 @@ class Cache: class All2AllManagerBase: + rank: int + world_size: int def __init__(self, cpu_group): self.cpu_group = cpu_group # compute some common properties - from vllm.distributed.parallel_state import (get_dp_group, - get_tp_group, - in_the_same_node_as) + from vllm.distributed.parallel_state import ( + get_dp_group, + get_tp_group, + in_the_same_node_as, + ) # all2all lives in ep group, which is merged from dp and tp group self.dp_group = get_dp_group() self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction # when we create this object self.dp_rank = self.dp_group.rank_in_group @@ -60,11 +64,21 @@ class All2AllManagerBase: # and reuse it for the same config. 
raise NotImplementedError - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ): raise NotImplementedError - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def set_num_sms(self, num_sms: int): + pass + + def max_sms_used(self) -> Optional[int]: + return None # None means it could use the whole GPU + + def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False): raise NotImplementedError def destroy(self): @@ -79,11 +93,13 @@ class DeviceCommunicatorBase: communication backend), the `device_group` will also be given. """ - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group @@ -93,11 +109,11 @@ class DeviceCommunicatorBase: self.ranks = dist.get_process_group_ranks(cpu_group) self.global_rank = dist.get_rank() self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, - self.global_rank) + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False from vllm.config import get_current_vllm_config + config = get_current_vllm_config() if config is not None: # as long as we use data parallel (coupled data parallel @@ -121,41 +137,39 @@ class DeviceCommunicatorBase: # NOTE: we have to use concat-style all-gather here, # stack-style all-gather has compatibility issues with # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * self.world_size, ) + input_size[1:] + output_size = (input_size[0] * self.world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. - dist.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) # Reshape - output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.reshape((self.world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) return output_tensor def all_gatherv( self, input_: Union[torch.Tensor, list[torch.Tensor]], dim: int = 0, - sizes: Optional[list[int]] = None + sizes: Optional[list[int]] = None, ) -> Union[torch.Tensor, list[torch.Tensor]]: raise NotImplementedError - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. @@ -167,30 +181,28 @@ class DeviceCommunicatorBase: assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output_tensor = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output_tensor = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) # Perform reduce-scatter operation - torch.distributed.reduce_scatter_tensor(output_tensor, - input_tensor, - group=self.device_group) + torch.distributed.reduce_scatter_tensor( + output_tensor, input_tensor, group=self.device_group + ) # Reshape before returning return output_tensor.movedim(0, dim).contiguous() - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None) -> torch.Tensor: + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ) -> torch.Tensor: raise NotImplementedError - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -198,7 +210,8 @@ class DeviceCommunicatorBase: """ world_size = self.world_size assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -209,10 +222,9 @@ class DeviceCommunicatorBase: else: gather_list = None # Gather. - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) else: @@ -226,10 +238,9 @@ class DeviceCommunicatorBase: dst = (self.rank_in_group + 1) % self.world_size torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if src is None: @@ -242,8 +253,7 @@ class DeviceCommunicatorBase: def destroy(self): pass - def prepare_communication_buffer_for_model(self, - model: torch.nn.Module) -> None: + def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None: """ Prepare the communication buffer for the model. """ @@ -251,22 +261,33 @@ class DeviceCommunicatorBase: return moe_modules = [ - module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" + module + for module in model.modules() + # TODO(bnell): Should use isinstance but can't. Maybe search for + # presence of quant_method.init_prepare_finalize? 
+ if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) ] for module in moe_modules: - module.quant_method.init_prepare_finalize() + module.quant_method.init_prepare_finalize(module) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. """ return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: """ Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index bda567f8489c5..c09b3ba9ceba6 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -15,30 +15,30 @@ from .base_device_communicator import DeviceCommunicatorBase class CpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) self.dist_module = torch.distributed - if (current_platform.get_cpu_architecture() - == CpuArchEnum.X86) and hasattr( - torch.ops._C, - "init_shm_manager") and (unique_name.startswith("tp") - or unique_name.startswith("pp")): + if ( + (current_platform.get_cpu_architecture() == CpuArchEnum.X86) + and hasattr(torch.ops._C, "init_shm_manager") + and (unique_name.startswith("tp") or unique_name.startswith("pp")) + ): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): self.dist_module.all_reduce(input_, group=self.device_group) return input_ - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -46,7 +46,8 @@ class CpuCommunicator(DeviceCommunicatorBase): """ world_size = self.world_size assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -58,10 +59,9 @@ class CpuCommunicator(DeviceCommunicatorBase): gather_list = None # Gather. - self.dist_module.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) + self.dist_module.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) @@ -77,23 +77,24 @@ class CpuCommunicator(DeviceCommunicatorBase): # NOTE: we have to use concat-style all-gather here, # stack-style all-gather has compatibility issues with # torch.compile . 
see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * self.world_size, ) + input_size[1:] + output_size = (input_size[0] * self.world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. - self.dist_module.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + self.dist_module.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) # Reshape - output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.reshape((self.world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) return output_tensor def send_tensor_dict( @@ -111,7 +112,6 @@ class CpuCommunicator(DeviceCommunicatorBase): class _CPUSHMDistributed: - def __init__(self, communicator: CpuCommunicator): instance_identifier = os.environ["VLLM_DIST_IDENT"] unique_name = communicator.unique_name @@ -139,24 +139,32 @@ class _CPUSHMDistributed: return handle - def all_reduce(self, - input: torch.Tensor, - group: Optional[ProcessGroup] = None) -> None: + def all_reduce( + self, input: torch.Tensor, group: Optional[ProcessGroup] = None + ) -> None: torch.ops._C.shm_allreduce(self.handle, input) - def gather(self, - input: torch.Tensor, - gather_list: Optional[list[torch.Tensor]], - dst: int = -1, - group: Optional[ProcessGroup] = None) -> None: + def gather( + self, + input: torch.Tensor, + gather_list: Optional[list[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None, + ) -> None: # Note: different from the torch gather, here we use local dst rank. 
- torch.ops._C.shm_gather(self.handle, input, gather_list, - torch.distributed.get_group_rank(group, dst)) + torch.ops._C.shm_gather( + self.handle, + input, + gather_list, + torch.distributed.get_group_rank(group, dst), + ) - def all_gather_into_tensor(self, - output: torch.Tensor, - input: torch.Tensor, - group: Optional[ProcessGroup] = None) -> None: + def all_gather_into_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None, + ) -> None: torch.ops._C.shm_all_gather(self.handle, input, output) def send_tensor_dict( @@ -169,11 +177,11 @@ class _CPUSHMDistributed: size_list = [] for v in value_list: if not isinstance(v, torch.Tensor): - raise RuntimeError( - "CpuCommunicator only supports sending tensors.") + raise RuntimeError("CpuCommunicator only supports sending tensors.") size_list.append(v.size()) - key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]), - dtype=torch.uint8) + key_size_tensor = torch.frombuffer( + pickle.dumps([key_list, size_list]), dtype=torch.uint8 + ) value_list.append(key_size_tensor) torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 0ea8de2f36f4b..45096dffb5b63 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -7,6 +7,13 @@ import torch from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.device_communicators.all_reduce_utils import ( + should_nccl_symm_mem_allreduce, +) +from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops +from vllm.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_enabled, +) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -16,52 +23,63 @@ logger = init_logger(__name__) class CudaCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) if "tp" not in unique_name: - # only tp uses custom allreduce + # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False + use_torch_symm_mem = False else: - from vllm.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE) + from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM - # ep does not use pynccl - use_pynccl = "ep" not in unique_name - - self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce + self.use_torch_symm_mem = use_torch_symm_mem # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) + CustomAllreduce, + ) + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce) - from vllm.distributed.device_communicators.symm_mem import ( - 
SymmMemCommunicator) + QuickAllReduce, + ) + from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator self.pynccl_comm: Optional[PyNcclCommunicator] = None - if use_pynccl and self.world_size > 1: + if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) + if is_symmetric_memory_enabled(): + register_nccl_symmetric_ops(self.pynccl_comm) self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None + if use_torch_symm_mem and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, + symm_mem_enabled=( + self.symm_mem_comm is not None and not self.symm_mem_comm.disabled + ), ) if current_platform.is_rocm(): @@ -70,57 +88,82 @@ class CudaCommunicator(DeviceCommunicatorBase): # Based on quickreduce (https://github.com/mk1-project/quickreduce). # If it's a rocm, 'use_custom_allreduce==True' means it must # currently be an MI300 series. - self.qr_comm = QuickAllReduce(group=self.cpu_group, - device=self.device) - if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): - self.symm_mem_comm = SymmMemCommunicator( - group=self.cpu_group, - device=self.device, - ) + self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") + elif all2all_backend == "allgather_reducescatter": + from .all2all import AgRsAll2AllManager + + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AllGather-ReduceScatter all2all manager.") elif all2all_backend == "pplx": from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) logger.info("Using PPLX all2all manager.") elif all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) logger.info("Using DeepEP High-Throughput all2all manager.") elif all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) logger.info("Using DeepEP Low-Latency all2all manager.") + elif all2all_backend == "flashinfer_all2allv": + from .all2all import FlashInferAllToAllManager + + self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) + logger.info("Using Flashinfer all2allv manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): + # since currently we perform copy input -> symm_input -> out-of-place AR + # return symm_output, we don't need to check if input is symmetric + if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce( + self.pynccl_comm.world_size, input_ + ): + out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) + if out is not None: + return out # always try quick reduce first, then custom allreduce, # and then pynccl. 
(quick reduce just for ROCM MI3*) qr_comm = self.qr_comm - if qr_comm is not None and not qr_comm.disabled and \ - qr_comm.should_quick_allreduce(input_): + if ( + qr_comm is not None + and not qr_comm.disabled + and qr_comm.should_quick_allreduce(input_) + ): out = qr_comm.quick_all_reduce(input_) assert out is not None return out ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and \ - ca_comm.should_custom_ar(input_): + if ( + ca_comm is not None + and not ca_comm.disabled + and ca_comm.should_custom_ar(input_) + ): out = ca_comm.custom_all_reduce(input_) assert out is not None return out symm_mem_comm = self.symm_mem_comm - if symm_mem_comm is not None and \ - symm_mem_comm.should_use_symm_mem(input_): + if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_): out = symm_mem_comm.all_reduce(input_) assert out is not None return out pynccl_comm = self.pynccl_comm + if pynccl_comm is None or pynccl_comm.disabled: + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) if out is None: @@ -146,21 +189,20 @@ class CudaCommunicator(DeviceCommunicatorBase): assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None): + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ): world_size = self.world_size pynccl_comm = self.pynccl_comm assert pynccl_comm is not None @@ -179,16 +221,16 @@ class CudaCommunicator(DeviceCommunicatorBase): else: assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) if sizes is not None: - pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) else: - pynccl_comm.reduce_scatter(output, input_) + pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -205,10 +247,9 @@ class CudaCommunicator(DeviceCommunicatorBase): else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if src is None: @@ -231,10 +272,12 @@ class CudaCommunicator(DeviceCommunicatorBase): self.all2all_manager.destroy() self.all2all_manager = None - def all_gatherv(self, - input_: Union[torch.Tensor, 
list[torch.Tensor]], - dim: int = 0, - sizes: Optional[list[int]] = None): + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None, + ): if dim != 0: raise NotImplementedError("only dim 0 all-gatherv is supported") world_size = self.world_size @@ -246,20 +289,20 @@ class CudaCommunicator(DeviceCommunicatorBase): if sizes is not None and all(s == sizes[0] for s in sizes): sizes = None - def _all_gather_single(input_: torch.Tensor, - sizes: Optional[list[int]] = None): + def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None): input_size = input_.size() if sizes is not None: assert len(sizes) == world_size assert input_.shape[dim] == sizes[self.rank_in_group], ( - f"{input_.shape[dim]} != {sizes[self.rank_in_group]}") - output_size = (sum(sizes), ) + input_size[1:] + f"{input_.shape[dim]} != {sizes[self.rank_in_group]}" + ) + output_size = (sum(sizes),) + input_size[1:] else: - output_size = (input_size[0] * world_size, ) + input_size[1:] + output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) if sizes is not None: pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes) else: @@ -278,14 +321,22 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) + hidden_states, router_logits, is_sequence_parallel + ) return hidden_states, router_logits - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) + hidden_states = self.all2all_manager.combine( + hidden_states, is_sequence_parallel + ) return hidden_states diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 2c38e8ed21d7d..a77d2666e2ce3 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -42,7 +42,7 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found = False with open("/proc/self/maps") as f: for line in f: @@ -57,8 +57,9 @@ def find_loaded_library(lib_name) -> Optional[str]: start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path @@ -70,30 +71,38 @@ class CudaRTLibrary: Function("cudaDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), - # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), - # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) - Function("cudaMalloc", cudaError_t, - [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + Function( + "cudaMalloc", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], + ), # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) - Function("cudaMemset", cudaError_t, - [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + Function( + "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] + ), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa - Function("cudaMemcpy", cudaError_t, [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind - ]), - + Function( + "cudaMemcpy", + cudaError_t, + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], + ), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa - Function("cudaIpcGetMemHandle", cudaError_t, - [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + Function( + "cudaIpcGetMemHandle", + cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p], + ), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa - Function("cudaIpcOpenMemHandle", cudaError_t, [ - ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint - ]), + Function( + "cudaIpcOpenMemHandle", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint], + ), ] # class attribute to store the mapping from the path to the library @@ -109,11 +118,10 @@ class CudaRTLibrary: so_file = find_loaded_library("libcudart") if so_file is None: so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var - assert so_file is not None, \ - ( - "libcudart is not loaded in the current process, " - "try setting VLLM_CUDART_SO_PATH" - ) + assert so_file is not None, ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib @@ -154,27 +162,29 @@ class CudaRTLibrary: def cudaFree(self, devPtr: ctypes.c_void_p) -> None: self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) - def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, - count: int) -> None: + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) - def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, - count: int) -> None: + def cudaMemcpy( + self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int + ) -> None: cudaMemcpyDefault = 4 kind 
= cudaMemcpyDefault self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) - def cudaIpcGetMemHandle(self, - devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: handle = cudaIpcMemHandle_t() - self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( - ctypes.byref(handle), devPtr)) + self.CUDART_CHECK( + self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr) + ) return handle - def cudaIpcOpenMemHandle(self, - handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudaIpcMemLazyEnablePeerAccess = 1 devPtr = ctypes.c_void_p() - self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( - ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + self.CUDART_CHECK( + self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess + ) + ) return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 80aca81234eb0..fd5c5dfd9da0e 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,7 +11,9 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.device_communicators.all_reduce_utils import ( - CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) + CUSTOM_ALL_REDUCE_MAX_SIZES, + gpu_p2p_access_check, +) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,8 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool: if i == rank: continue if envs.VLLM_SKIP_P2P_CHECK: - logger.info( - "Skipping P2P check and trusting the driver's P2P report.") + logger.info("Skipping P2P check and trusting the driver's P2P report.") return torch.cuda.can_device_access_peer(rank, i) if not gpu_p2p_access_check(rank, i): return False @@ -41,26 +42,29 @@ def _can_p2p(rank: int, world_size: int) -> bool: def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or (inp.storage().nbytes() - - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size()) + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) class CustomAllreduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] # max_size: max supported allreduce size - def __init__(self, - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024, + symm_mem_enabled=False, + ) -> None: """ Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -71,20 +75,24 @@ class CustomAllreduce: if not custom_ar: # disable because of missing custom allreduce library # e.g. 
in a non-GPU environment - logger.info("Custom allreduce is disabled because " - "of missing custom allreduce library") + logger.info( + "Custom allreduce is disabled because " + "of missing custom allreduce library" + ) return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "CustomAllreduce should be attached to a non-NCCL group.") + "CustomAllreduce should be attached to a non-NCCL group." + ) if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom allreduce for multi-node case. logger.warning( "Custom allreduce is disabled because this process group" - " spans across nodes.") + " spans across nodes." + ) return rank = dist.get_rank(group=self.group) @@ -99,7 +107,9 @@ class CustomAllreduce: "Custom allreduce is disabled due to an unsupported world" " size: %d. Supported world sizes: %s. To silence this " "warning, specify disable_custom_all_reduce=True explicitly.", - world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + world_size, + str(CustomAllreduce._SUPPORTED_WORLD_SIZES), + ) return if isinstance(device, int): @@ -109,13 +119,15 @@ class CustomAllreduce: # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - device_capability = current_platform.get_device_capability( - ).as_version_str() - if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM - and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + device_capability = current_platform.get_device_capability().as_version_str() + if ( + current_platform.is_cuda() + and symm_mem_enabled + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES + ): max_size = min( - CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], - max_size) + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size + ) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) @@ -123,12 +135,9 @@ class CustomAllreduce: device_ids = list(range(cuda_device_count_stateless())) physical_device_id = device_ids[device.index] - tensor = torch.tensor([physical_device_id], - dtype=torch.int, - device="cpu") + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") gather_list = [ - torch.tensor([0], dtype=torch.int, device="cpu") - for _ in range(world_size) + torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) ] dist.all_gather(gather_list, tensor, group=self.group) physical_device_ids = [t.item() for t in gather_list] @@ -137,13 +146,13 @@ class CustomAllreduce: # where custom allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda_alike() - fully_connected = current_platform.is_fully_connected( - physical_device_ids) + fully_connected = current_platform.is_fully_connected(physical_device_ids) if world_size > 2 and not fully_connected: logger.warning( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " - "specify disable_custom_all_reduce=True explicitly.") + "specify disable_custom_all_reduce=True explicitly." + ) return # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time @@ -153,16 +162,17 @@ class CustomAllreduce: logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. 
To silence this "
- "warning, specify disable_custom_all_reduce=True explicitly.")
+ "warning, specify disable_custom_all_reduce=True explicitly."
+ )
 return
 self.disabled = False
 # Buffers memory are owned by this Python class and passed to C++.
- # Meta data composes of two parts: meta data for synchronization and a
+ # Metadata consists of two parts: metadata for synchronization and a
 # temporary buffer for storing intermediate allreduce results.
- self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
- group=group,
- uncached=True)
+ self.meta_ptrs = self.create_shared_buffer(
+ ops.meta_size() + max_size, group=group, uncached=True
+ )
 # This is a pre-registered IPC buffer. In eager mode, input tensors
 # are first copied into this buffer before allreduce is performed
 self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
@@ -171,21 +181,22 @@
 # 8*world_size bytes where world_size is at most 8. Allocating 8MB
 # is enough for 131072 such tuples. The largest model I've seen only
 # needs less than 10000 of registered tuples.
- self.rank_data = torch.empty(8 * 1024 * 1024,
- dtype=torch.uint8,
- device=self.device)
+ self.rank_data = torch.empty(
+ 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+ )
 self.max_size = max_size
 self.rank = rank
 self.world_size = world_size
 self.fully_connected = fully_connected
- self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
- self.fully_connected)
+ self._ptr = ops.init_custom_ar(
+ self.meta_ptrs, self.rank_data, rank, self.fully_connected
+ )
 ops.register_buffer(self._ptr, self.buffer_ptrs)
 @contextmanager
 def capture(self):
 """
- The main responsibility of this context manager is the 
+ The main responsibility of this context manager is the
 `register_graph_buffers` call at the end of the context.
 It records all the buffer addresses used in the CUDA graph.
 """
@@ -203,15 +214,13 @@
 # We cannot directly use `dist.all_gather_object` here
 # because it is incompatible with `gloo` backend under inference mode.
 # see https://github.com/pytorch/pytorch/issues/126032 for details.
- all_data = [[None, None]
- for _ in range(dist.get_world_size(group=self.group))]
+ all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
 all_data[self.rank] = [handle, offset]
 ranks = sorted(dist.get_process_group_ranks(group=self.group))
 for i, rank in enumerate(ranks):
- dist.broadcast_object_list(all_data[i],
- src=rank,
- group=self.group,
- device="cpu")
+ dist.broadcast_object_list(
+ all_data[i], src=rank, group=self.group, device="cpu"
+ )
 # Unpack list of tuples to tuple of lists.
 handles = [d[0] for d in all_data] # type: ignore
 offsets = [d[1] for d in all_data] # type: ignore
@@ -232,13 +241,11 @@
 return inp_size < self.max_size
 return False
- def all_reduce(self,
- inp: torch.Tensor,
- *,
- out: torch.Tensor = None,
- registered: bool = False):
+ def all_reduce(
+ self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+ ):
 """Performs an out-of-place all reduce.
-
+
 If registered is True, this assumes inp's pointer is already
 IPC-registered. Otherwise, inp is first copied into a pre-registered
 buffer.
@@ -248,8 +255,9 @@ class CustomAllreduce: if registered: ops.all_reduce(self._ptr, inp, out, 0, 0) else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], - self.max_size) + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -282,9 +290,11 @@ class CustomAllreduce: self.close() @staticmethod - def create_shared_buffer(size_in_bytes: int, - group: Optional[ProcessGroup] = None, - uncached: Optional[bool] = False) -> list[int]: + def create_shared_buffer( + size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False, + ) -> list[int]: pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) world_size = dist.get_world_size(group=group) @@ -301,9 +311,11 @@ class CustomAllreduce: return pointers @staticmethod - def free_shared_buffer(pointers: list[int], - group: Optional[ProcessGroup] = None, - rank: Optional[int] = None) -> None: + def free_shared_buffer( + pointers: list[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = None, + ) -> None: if rank is None: rank = dist.get_rank(group=group) if ops is not None: diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py new file mode 100644 index 0000000000000..61aee2db46b84 --- /dev/null +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch.distributed as dist +from flashinfer.comm.mnnvl import CommBackend as CommBackend + +from vllm.utils.flashinfer import has_flashinfer_all2all + +assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found" + + +class CustomCommunicator(CommBackend): + def __init__(self, group): + self._group = group + + def Get_rank(self) -> int: + return self._group.rank() + + def Get_size(self) -> int: + return self._group.size() + + def allgather(self, data: int): + gathered = [None] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + + def Split(self, color: int, key: int) -> "CustomCommunicator": + return self diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py deleted file mode 100644 index 5b61a1687a016..0000000000000 --- a/vllm/distributed/device_communicators/neuron_communicator.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) -from vllm.platforms import current_platform - -if current_platform.is_neuron(): - import torch_xla.core.xla_model as xm - - -class NeuronCommunicator(DeviceCommunicatorBase): - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - return xm.all_reduce(xm.REDUCE_SUM, x) - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "Neuron only supports dim=-1 for all-gather." 
- return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 502bfd39005ad..59fa3f9c449b0 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -8,18 +8,55 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp +import vllm.envs as envs from vllm.distributed.device_communicators.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, - ncclRedOpTypeEnum, ncclUniqueId) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import current_stream logger = init_logger(__name__) +_NCCL_SYMM_OPS_REGISTERED = False + + +def register_nccl_symmetric_ops(pynccl_comm): + from vllm.distributed.device_communicators.pynccl_allocator import ( + nccl_symm_mem_context, + ) + from vllm.utils import direct_register_custom_op + + global _NCCL_SYMM_OPS_REGISTERED + if _NCCL_SYMM_OPS_REGISTERED: + return + _NCCL_SYMM_OPS_REGISTERED = True + + def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor: + with nccl_symm_mem_context(pynccl_comm): + symm_input = torch.empty_like(input_tensor) + symm_output = torch.empty_like(input_tensor) + symm_input.copy_(input_tensor) + symm_output = pynccl_comm.all_reduce(symm_input, symm_output) + return symm_output + + def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor: + return torch.empty_like(input_tensor) + + direct_register_custom_op( + op_name="all_reduce_symmetric_with_copy", + op_func=all_reduce_symmetric_with_copy_impl, + fake_impl=all_reduce_symmetric_with_copy_fake, + ) + class PyNcclCommunicator: - def __init__( self, group: Union[ProcessGroup, StatelessProcessGroup], @@ -31,7 +68,7 @@ class PyNcclCommunicator: group: the process group to work on. If None, it will use the default process group. device: the device to bind the PyNcclCommunicator to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. It is the caller's responsibility to make sure each communicator @@ -40,7 +77,8 @@ class PyNcclCommunicator: if not isinstance(group, StatelessProcessGroup): assert dist.is_initialized() assert dist.get_backend(group) != dist.Backend.NCCL, ( - "PyNcclCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group." 
+ ) # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) @@ -51,7 +89,7 @@ class PyNcclCommunicator: self.group = group # if world_size == 1, no need to create communicator - if self.world_size == 1: + if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL: self.available = False self.disabled = True return @@ -67,6 +105,7 @@ class PyNcclCommunicator: self.available = True self.disabled = False + self.nccl_version = self.nccl.ncclGetRawVersion() logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) if self.rank == 0: @@ -98,7 +137,8 @@ class PyNcclCommunicator: # current cuda device to the specified one with torch.cuda.device(device): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( - self.world_size, self.unique_id, self.rank) + self.world_size, self.unique_id, self.rank + ) stream = current_stream() # A small all_reduce for warmup. @@ -107,10 +147,13 @@ class PyNcclCommunicator: stream.synchronize() del data - def all_reduce(self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: + def all_reduce( + self, + in_tensor: torch.Tensor, + out_tensor: torch.Tensor = None, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ) -> torch.Tensor: if self.disabled: return None # nccl communicator created on a specific device @@ -118,24 +161,28 @@ class PyNcclCommunicator: # otherwise it will cause "illegal memory access" assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") + f"but the input tensor is on {in_tensor.device}" + ) - out_tensor = torch.empty_like(in_tensor) + if out_tensor is None: + out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() - self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) return out_tensor - def all_gather(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - stream=None): + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): if self.disabled: return # nccl communicator created on a specific device @@ -143,14 +190,18 @@ class PyNcclCommunicator: # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def all_gatherv( self, @@ -166,14 +217,15 @@ class PyNcclCommunicator: # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this 
nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() assert output_tensor.shape[0] == sum(sizes) split_offset = 0 self.nccl.ncclGroupStart() for root, split_size in enumerate(sizes): - dst_slice = output_tensor[split_offset:split_offset + split_size] + dst_slice = output_tensor[split_offset : split_offset + split_size] self.nccl.ncclBroadcast( buffer_type(input_tensor.data_ptr()), buffer_type(dst_slice.data_ptr()), @@ -186,11 +238,13 @@ class PyNcclCommunicator: split_offset += split_size self.nccl.ncclGroupEnd() - def reduce_scatter(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): if self.disabled: return # nccl communicator created on a specific device @@ -198,15 +252,19 @@ class PyNcclCommunicator: # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def reduce_scatterv( self, @@ -223,20 +281,25 @@ class PyNcclCommunicator: # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() split_offset = 0 self.nccl.ncclGroupStart() for root, split_size in enumerate(sizes): - chunk = input_tensor[split_offset:split_offset + split_size, ...] + chunk = input_tensor[split_offset : split_offset + split_size, ...] 
self.nccl.ncclReduce( buffer_type(chunk.data_ptr()), - buffer_type(output_tensor.data_ptr()), chunk.numel(), + buffer_type(output_tensor.data_ptr()), + chunk.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), root, self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) split_offset += split_size self.nccl.ncclGroupEnd() @@ -245,31 +308,44 @@ class PyNcclCommunicator: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() if src == self.rank: @@ -279,12 +355,32 @@ class PyNcclCommunicator: else: sendbuff = buffer_type() recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def group_start(self): self.nccl.ncclGroupStart() def group_end(self): self.nccl.ncclGroupEnd() + + def register_comm_window(self, tensor: torch.Tensor): + return self.nccl.ncclCommWindowRegister( + self.comm, + buffer_type(tensor.data_ptr()), + tensor.numel() * tensor.element_size(), + 1, + ) + + def register_comm_window_raw(self, ptr: int, size: int): + return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1) + + def deregister_comm_window(self, window): + return self.nccl.ncclCommWindowDeregister(self.comm, window) diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py new file mode 100644 index 0000000000000..3fe4fd744d77a --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_allocator.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import atexit 
+import contextlib
+import tempfile
+from typing import Any, Optional
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+from torch.utils.cpp_extension import load_inline
+
+from vllm import envs
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import find_nccl_include_paths
+
+logger = init_logger(__name__)
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+ void* ptr;
+ ncclResult_t err = ncclMemAlloc(&ptr, size);
+ return ptr;
+
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+ ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_allocator_wrapper = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+_nccl_allocator_failed_to_compile = False
+_cached_pool_snapshot = None
+
+
+def is_symmetric_memory_enabled():
+ global _nccl_allocator_failed_to_compile
+ return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
+
+
+def is_symmetric_memory_tensor(tensor: torch.Tensor):
+ if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
+ return False
+ for segment in _cached_pool_snapshot:
+ for block in segment["blocks"]:
+ if block["address"] == tensor.untyped_storage().data_ptr():
+ return True
+ return False
+
+
+def set_graph_pool_id(graph_pool_id):
+ global _graph_pool_id
+ _graph_pool_id = graph_pool_id
+
+
+def compile_nccl_allocator():
+ global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
+ if not current_platform.is_cuda():
+ _nccl_allocator_failed_to_compile = True
+ return
+ try:
+ out_dir = tempfile.gettempdir()
+ nccl_allocator_libname = "nccl_allocator"
+ nccl_include_paths = find_nccl_include_paths()
+ load_inline(
+ name=nccl_allocator_libname,
+ cpp_sources=nccl_allocator_source,
+ with_cuda=True,
+ extra_ldflags=["-lnccl"],
+ verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
+ is_python_module=False,
+ build_directory=out_dir,
+ extra_include_paths=nccl_include_paths,
+ )
+ _allocator_wrapper = CUDAPluggableAllocator(
+ f"{out_dir}/{nccl_allocator_libname}.so",
+ "nccl_alloc_plug",
+ "nccl_free_plug",
+ )
+ _allocator = _allocator_wrapper.allocator()
+ except Exception as e:
+ _nccl_allocator_failed_to_compile = True
+ logger.warning(
+ "Failed to compile NCCL memory allocator. "
+ "Symmetric memory will be disabled. "
+ "This is expected if NCCL headers are not available. "
+ "Optionally, set VLLM_NCCL_INCLUDE_PATH to point to a directory "
+ "containing the NCCL header. 
" + "Error: %s", + str(e), + ) + + +def get_nccl_mem_pool(): + global _mem_pool, _nccl_allocator_failed_to_compile + if _mem_pool is None and not _nccl_allocator_failed_to_compile: + compile_nccl_allocator() + if _allocator is not None: + _mem_pool = torch.cuda.MemPool(_allocator) + return _mem_pool + + +def _cleanup_nccl_mem_pool(): + global _mem_pool + _mem_pool = None + + +def _cleanup_nccl_allocator_wrapper(): + global _allocator_wrapper + _allocator_wrapper = None + + +atexit.register(_cleanup_nccl_mem_pool) +atexit.register(_cleanup_nccl_allocator_wrapper) + + +class nccl_symm_mem_context: + def __init__( + self, + pynccl_comm: PyNcclCommunicator, + disabled: bool = False, + ): + self.disabled = ( + disabled + or not is_symmetric_memory_enabled() + or pynccl_comm.world_size == 1 + or not current_platform.is_cuda() + or get_nccl_mem_pool() is None + or version.parse(torch.__version__) < version.parse("2.8.0.a0") + ) + if self.disabled: + self.pynccl_comm: Optional[PyNcclCommunicator] = None + self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = ( + contextlib.nullcontext() + ) + self.is_graph_capture = None + self.device = None + else: + self.pynccl_comm = pynccl_comm + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) + self.is_graph_capture = torch.cuda.is_current_stream_capturing() + self.device = torch.cuda.current_device() + + def __enter__(self): + if self.disabled: + return self + assert self.pynccl_comm is not None, ( + "Symmetric memory requires pynccl to be initalized" + ) + assert self.pynccl_comm.nccl_version >= 22703, ( + "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + ) + if self.is_graph_capture: + assert _graph_pool_id is not None, ( + "graph_pool_id is not set under graph capture" + ) + # Pause graph memory pool to use symmetric memory with cuda graph + torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) + self._mem_pool_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disabled: + return + global _cached_pool_snapshot + global _registered_base_addrs + self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) + _pool = get_nccl_mem_pool() + assert _pool is not None + _cached_pool_snapshot = _pool.snapshot() + assert self.pynccl_comm is not None + for segment in _cached_pool_snapshot: + if segment["address"] not in _registered_base_addrs: + self.pynccl_comm.register_comm_window_raw( + segment["address"], segment["total_size"] + ) + _registered_base_addrs.add(segment["address"]) + if self.is_graph_capture: + torch._C._cuda_beginAllocateCurrentThreadToPool(self.device, _graph_pool_id) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index a930b63bc26ff..e4d7b0f8fb85a 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -30,7 +30,9 @@ from typing import Any, Optional import torch from torch.distributed import ReduceOp +from vllm import envs from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import find_nccl_library logger = init_logger(__name__) @@ -41,6 +43,7 @@ logger = init_logger(__name__) ncclResult_t = ctypes.c_int ncclComm_t = ctypes.c_void_p +ncclWindow_t = ctypes.c_void_p class ncclUniqueId(ctypes.Structure): @@ -130,88 +133,141 @@ class NCCLLibrary: # const char* ncclGetErrorString(ncclResult_t result) Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), 
# ncclResult_t ncclGetVersion(int *version); - Function("ncclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("ncclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument # is a pointer to a pointer - Function("ncclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, int root, # ncclComm_t comm, cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllGather", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); - Function("ncclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + 
cudaStream_t, + ], + ), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); - Function("ncclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclBroadcast( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, int root, ncclComm_t comm, # cudaStream_t stream); - Function("ncclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -222,6 +278,23 @@ class NCCLLibrary: Function("ncclGroupStart", ncclResult_t, []), # ncclResult_t ncclGroupEnd(); Function("ncclGroupEnd", ncclResult_t, []), + # ncclResult_t ncclCommWindowRegister( + # ncclComm_t comm, void* buff, size_t size, + # ncclWindow_t* win, int winFlags); + Function( + "ncclCommWindowRegister", + ncclResult_t, + [ + ncclComm_t, + buffer_type, + ctypes.c_size_t, + ctypes.POINTER(ncclWindow_t), + ctypes.c_int, + ], + ), + # ncclResult_t ncclCommWindowDeregister( + # ncclComm_t comm, ncclWindow_t win); + Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]), ] # class attribute to store the mapping from the path to the library @@ -233,7 +306,6 @@ class NCCLLibrary: path_to_dict_mapping: dict[str, dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): - so_file = so_file or find_nccl_library() try: @@ -249,17 +321,39 @@ class NCCLLibrary: "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: _funcs: dict[str, Any] = {} for func in NCCLLibrary.exported_functions: - f = getattr(self.lib, func.name) - f.restype = func.restype - f.argtypes = func.argtypes - _funcs[func.name] = f + try: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + except AttributeError: + if func.name in [ + "ncclCommWindowRegister", + "ncclCommWindowDeregister", + ]: + if envs.VLLM_USE_NCCL_SYMM_MEM: + logger.warning_once( + "The symbol %s is not found in the NCCL " + "library %s. 
To enable VLLM_USE_NCCL_SYMM_MEM " + " please update your NCCL version to >= " + "2.27.03.", + func.name, + so_file, + ) + if current_platform.is_rocm(): + # Having an exception here on ROCm platform is + # not allowed during graph capturing + continue + raise NCCLLibrary.path_to_dict_mapping[so_file] = _funcs self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] @@ -271,10 +365,14 @@ class NCCLLibrary: error_str = self.ncclGetErrorString(result) raise RuntimeError(f"NCCL error: {error_str}") - def ncclGetVersion(self) -> str: + def ncclGetRawVersion(self) -> int: version = ctypes.c_int() self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) - version_str = str(version.value) + # something like 21903 + return version.value + + def ncclGetVersion(self) -> str: + version_str = str(self.ncclGetRawVersion()) # something like 21903 --> "2.19.3" major = version_str[0].lstrip("0") minor = version_str[1:3].lstrip("0") @@ -283,88 +381,153 @@ class NCCLLibrary: def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( - ctypes.byref(unique_id))) + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: if len(data) != 128: raise ValueError( - f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes" + ) unique_id = ncclUniqueId() ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) return unique_id - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) + self.NCCL_CHECK( + self._funcs["ncclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) return comm - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) + self.NCCL_CHECK( + self._funcs["ncclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) - def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, root: int, - comm: ncclComm_t, stream: cudaStream_t) -> None: + def ncclReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count, - datatype, op, root, comm, - stream)) + self.NCCL_CHECK( + self._funcs["ncclReduce"]( + sendbuff, recvbuff, 
count, datatype, op, root, comm, stream + ) + ) - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # which is an aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, - dest, comm, stream)) + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, - comm, stream)) + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) - def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) @@ -375,8 +538,27 @@ class NCCLLibrary: def ncclGroupEnd(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + def ncclCommWindowRegister( + self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int + ) -> ncclWindow_t: + window = ncclWindow_t() + self.NCCL_CHECK( + self._funcs["ncclCommWindowRegister"]( + comm, buff, size, ctypes.byref(window), 
win_flags + ) + ) + return window + + def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + __all__ = [ - "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", - "ncclComm_t", "cudaStream_t", "buffer_type" + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", ] diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index c61231e2d33f4..16b6b6c28ea3a 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -27,9 +27,10 @@ except Exception: def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or (inp.storage().nbytes() - - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size()) + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) class QuickReduceRegime(Enum): @@ -44,7 +45,6 @@ MB = 1024 * 1024 class QuickAllReduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 8] _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] # The following data is based on kernel tests. @@ -58,27 +58,28 @@ class QuickAllReduce: (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], } - def __init__(self, group: ProcessGroup, - device: Union[int, str, torch.device]) -> None: + def __init__( + self, group: ProcessGroup, device: Union[int, str, torch.device] + ) -> None: """ - Custom allreduce provides non-destructive acceleration and is + Custom allreduce provides non-destructive acceleration and is available for CUDA and ROCm MI300 series. - Custom quick allreduce leverages quantization for further - acceleration on ROCm. It currently supports Q8, Q6, and Q4 + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 quantization formats and FP(float16, bfloat16). - Quick allreduce is designed as a complement to custom allreduce. - Its initialization requires even stricter conditions. + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. - Only the ROCm MI300 series is supported for quick allreduce at + Only the ROCm MI300 series is supported for quick allreduce at this time. Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -93,18 +94,23 @@ class QuickAllReduce: if not quick_ar: # disable because of missing quick reduce library # e.g. in a cuda environment - logger.info("Custom quick allreduce is disabled because " - "of missing custom quick allreduce library") + logger.info( + "Custom quick allreduce is disabled because " + "of missing custom quick allreduce library" + ) return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "Custom quick allreduce should be attached to a non-NCCL group.") + "Custom quick allreduce should be attached to a non-NCCL group." 
+ )
 if not all(in_the_same_node_as(group, source_rank=0)):
 # No need to initialize custom quick allreduce for
 # multi-node case.
- logger.warning("Custom quick allreduce is disabled because this "
- "process group spans across nodes.")
+ logger.warning(
+ "Custom quick allreduce is disabled because this "
+ "process group spans across nodes."
+ )
 return
 rank = dist.get_rank(group=self.group)
 world_size = dist.get_world_size(group=self.group)
@@ -118,7 +124,9 @@
 logger.warning(
 "Custom quick allreduce is disabled due to an "
 "unsupported world size: %d. Supported world sizes: %s.",
- world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
+ world_size,
+ str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
+ )
 return
 if isinstance(device, int):
@@ -134,9 +142,7 @@
 else:
 device_ids = list(range(cuda_device_count_stateless()))
 physical_device_id = device_ids[device.index]
- tensor = torch.tensor([physical_device_id],
- dtype=torch.int,
- device="cpu")
+ tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
 gather_list = [
 torch.tensor([0], dtype=torch.int, device="cpu")
 for _ in range(self.world_size)
@@ -148,12 +154,12 @@
 # where custom quick allreduce is not supported
 # this checks hardware and driver support for NVLink
 assert current_platform.is_cuda_alike()
- self.fully_connected = current_platform.is_fully_connected(
- physical_device_ids)
+ self.fully_connected = current_platform.is_fully_connected(physical_device_ids)
 if self.world_size > 2 and not self.fully_connected:
 logger.debug(
 "Custom quick allreduce is disabled because it's not supported "
- "on more than two PCIe-only GPUs. ")
+ "on more than two PCIe-only GPUs. "
+ )
 return
 self.init_quick_all_reduce()
@@ -169,24 +175,31 @@
 "Custom quick allreduce:",
 f"Invalid quantization level: {regime_str}. "
 "Supported levels: "
- f"{list(QuickReduceRegime.__members__.keys())}")
+ f"{list(QuickReduceRegime.__members__.keys())}",
+ )
 return
 if regime_str == "NONE":
- logger.debug("Custom quick allreduce is disabled based "
- "on env variable "
- "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
+ logger.debug(
+ "Custom quick allreduce is disabled based "
+ "on env variable "
+ "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
+ )
 return
 self.qr_quant_level = QuickReduceRegime[regime_str]
 vllm_config = get_current_vllm_config()
- if vllm_config is not None and \
- hasattr(vllm_config, "model_config") and \
- hasattr(vllm_config.model_config, "dtype"):
+ if (
+ vllm_config is not None
+ and hasattr(vllm_config, "model_config")
+ and hasattr(vllm_config.model_config, "dtype")
+ ):
 dtype = vllm_config.model_config.dtype
 if dtype not in [torch.float16, torch.bfloat16]:
 logger.debug(
 "Custom quick allreduce disabled: only supports "
- "float16 and float16, but get %s.", dtype)
+ "float16 and bfloat16, but got %s.",
+ dtype,
+ )
 return
 if dtype == torch.bfloat16 and self.use_fp16_kernels:
@@ -194,7 +207,8 @@
 "Custom quick allreduce: BF16 inputs will be converted "
 "to FP16 to improve performance. set "
 "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
- "to turn off.")
+ "to turn off."
+ ) # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB @@ -206,8 +220,7 @@ class QuickAllReduce: ) qr_max_size = qr_max_size * MB self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) - self.qr_max_size = qr_max_size if qr_max_size is not None \ - else ops.qr_max_size() + self.qr_max_size = qr_max_size if qr_max_size is not None else ops.qr_max_size() self.create_shared_buffer() self.disabled = False @@ -217,16 +230,15 @@ class QuickAllReduce: try: props = torch.cuda.get_device_properties(0) gcn_arch = getattr(props, "gcnArchName", "") - supported_archs = ['gfx94', 'gfx95'] + supported_archs = ["gfx94", "gfx95"] return any(gfx in gcn_arch for gfx in supported_archs) except Exception as e: - logger.warning("Failed to determine ROCm for quick allreduce: %s", - e) + logger.warning("Failed to determine ROCm for quick allreduce: %s", e) return False def create_shared_buffer(self): """ - Creates a shared buffer for quickreduce. + Creates a shared buffer for quickreduce. Has to be called after init_custom_qr """ handle = ops.qr_get_handle(self._ptr) @@ -253,9 +265,11 @@ class QuickAllReduce: dtype = inp.dtype if self.use_fp16_kernels: dtype = torch.float16 - return inp_size <= self.qr_max_size and \ - inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ - [self.qr_quant_level.value] + return ( + inp_size <= self.qr_max_size + and inp_size + >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value] + ) def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): """Performs an out-of-place custom quick all reduce.""" @@ -263,8 +277,9 @@ class QuickAllReduce: # as QR uses static IPC buffer. if out is None: out = torch.empty_like(inp) - ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value, - self.use_fp16_kernels) + ops.qr_all_reduce( + self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels + ) return out def close(self): diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 46cc1c2f52d67..da79afc7ac145 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -6,12 +6,12 @@ from typing import Any, Optional import ray import torch from ray.exceptions import RayChannelError -from ray.experimental.channel.communicator import (Communicator, - TorchTensorAllocator) +from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator from torch.distributed import ReduceOp from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) + DeviceCommunicatorBase, +) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.utils import current_stream @@ -59,11 +59,11 @@ class RayPPCommunicator(Communicator): self._rank: Optional[int] = None self._actor_handles = actor_handles if use_communication_streams: - raise NotImplementedError( - "use_communication_streams is not supported") + raise NotImplementedError("use_communication_streams is not supported") if cuda_stream is not None and cuda_stream != current_stream(): raise ValueError( - "cuda_stream other than the current stream is not supported") + "cuda_stream other than the current stream is not supported" + ) if rank is not None: # Rank is not None, this is Ray worker @@ -99,13 +99,14 @@ class RayPPCommunicator(Communicator): # Ray actor IDs are 
32-character hex strings (128 bits) ACTOR_ID_LEN = 32 - actor_id_bytes = actor_id_str.encode('utf-8') - assert len( - actor_id_bytes - ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}" + actor_id_bytes = actor_id_str.encode("utf-8") + assert len(actor_id_bytes) == ACTOR_ID_LEN, ( + f"Unexpected actor ID length: {len(actor_id_bytes)}" + ) - actor_id_tensor = torch.frombuffer( - actor_id_bytes, dtype=torch.uint8).to(self._comm.device) + actor_id_tensor = torch.frombuffer(actor_id_bytes, dtype=torch.uint8).to( + self._comm.device + ) # All-gather full actor IDs from all actors gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0) @@ -115,9 +116,8 @@ class RayPPCommunicator(Communicator): for rank in range(self._world_size): start_idx = rank * ACTOR_ID_LEN end_idx = (rank + 1) * ACTOR_ID_LEN - actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy( - ).tobytes() - actor_id = actor_bytes.decode('utf-8') + actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy().tobytes() + actor_id = actor_bytes.decode("utf-8") self._actor_id_to_rank[actor_id] = rank def initialize(self, rank: int) -> None: @@ -131,9 +131,10 @@ class RayPPCommunicator(Communicator): """ Return the given actor's rank using device communicator collective ops. """ - assert hasattr(self, '_actor_id_to_rank'), ( + assert hasattr(self, "_actor_id_to_rank"), ( "Actor rank mapping not built. " - "This should have been done during initialization.") + "This should have been done during initialization." + ) actor_id_str = actor._actor_id.hex() @@ -178,7 +179,7 @@ class RayPPCommunicator(Communicator): def recv( self, - shape: tuple[int], + shape: tuple[int, ...], dtype: "torch.dtype", peer_rank: int, allocator: TorchTensorAllocator, @@ -186,7 +187,7 @@ class RayPPCommunicator(Communicator): """ Receive a torch.Tensor from a peer and synchronize the current stream. - After this call returns, the receive buffer is safe to read from from + After this call returns, the receive buffer is safe to read from any stream. An RayChannelError will be raised if an error occurred (e.g., remote actor died), and the buffer is not safe to read. 
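The rank-mapping logic reformatted above works by having each Ray actor encode its fixed-length 32-character hex ID into a uint8 tensor, all-gathering those tensors, and decoding the concatenated result into an actor-ID-to-rank table. As a standalone illustration only (not part of this patch: the actor IDs and the helper name `build_actor_id_to_rank` are invented here, and the real code runs inside `RayPPCommunicator` against a live communicator), the decode step could look like this:

```python
# Sketch of the gather/decode idea, with fabricated actor IDs and no Ray
# or collective calls involved.
import torch

ACTOR_ID_LEN = 32  # Ray actor IDs are 32-character hex strings (128 bits)


def build_actor_id_to_rank(gathered_ids: torch.Tensor, world_size: int) -> dict[str, int]:
    """Decode a flat uint8 tensor of concatenated actor IDs into id -> rank."""
    mapping: dict[str, int] = {}
    for rank in range(world_size):
        start, end = rank * ACTOR_ID_LEN, (rank + 1) * ACTOR_ID_LEN
        actor_id = gathered_ids[start:end].cpu().numpy().tobytes().decode("utf-8")
        mapping[actor_id] = rank
    return mapping


# Simulate what an all_gather along dim=0 would produce for two actors.
fake_ids = [
    "0123456789abcdef0123456789abcdef",
    "fedcba9876543210fedcba9876543210",
]
flat = torch.cat(
    [torch.frombuffer(bytearray(i, "utf-8"), dtype=torch.uint8) for i in fake_ids]
)
print(build_actor_id_to_rank(flat, world_size=2))
```

Keeping the IDs fixed-length is what makes the flat gathered tensor trivially sliceable per rank.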
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index c7810043b81e8..4cec601027284 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -14,14 +14,24 @@ import torch import torch.distributed as dist import zmq from torch.distributed import ProcessGroup -from zmq import IPV6 # type: ignore -from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore +from zmq import ( # type: ignore + IPV6, # type: ignore + SUB, + SUBSCRIBE, + XPUB, + XPUB_VERBOSE, + Context, +) import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger -from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, - is_valid_ipv6_address) +from vllm.utils import ( + get_ip, + get_open_port, + get_open_zmq_ipc_path, + is_valid_ipv6_address, +) VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -29,7 +39,6 @@ logger = init_logger(__name__) class SpinTimer: - def record_activity(self): pass @@ -66,12 +75,13 @@ class SpinSleepTimer(SpinTimer): class ShmRingBuffer: - - def __init__(self, - n_reader: int, - max_chunk_bytes: int, - max_chunks: int, - name: Optional[str] = None): + def __init__( + self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None, + ): """ A shared memory ring buffer implementation for broadcast communication. Essentially, it is a queue where only one will `enqueue` and multiple @@ -120,13 +130,14 @@ class ShmRingBuffer: created object to other processes by pickling it. The other processes will get the name of the shared memory and open it, so that they can access the same shared memory buffer. - """# noqa + """ # noqa self.n_reader = n_reader self.metadata_size = 1 + n_reader self.max_chunk_bytes = max_chunk_bytes self.max_chunks = max_chunks - self.total_bytes_of_buffer = (self.max_chunk_bytes + - self.metadata_size) * self.max_chunks + self.total_bytes_of_buffer = ( + self.max_chunk_bytes + self.metadata_size + ) * self.max_chunks self.data_offset = 0 self.metadata_offset = self.max_chunk_bytes * self.max_chunks @@ -134,10 +145,10 @@ class ShmRingBuffer: # we are creating a buffer self.is_creator = True self.shared_memory = shared_memory.SharedMemory( - create=True, size=self.total_bytes_of_buffer) + create=True, size=self.total_bytes_of_buffer + ) # initialize the metadata section to 0 - with memoryview(self.shared_memory.buf[self.metadata_offset:] - ) as metadata_buffer: + with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer: torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) else: # we are opening an existing buffer @@ -145,8 +156,10 @@ class ShmRingBuffer: # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. - with patch("multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None): + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): try: self.shared_memory = shared_memory.SharedMemory(name=name) # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa @@ -154,8 +167,7 @@ class ShmRingBuffer: # so the shared memory block size may be larger or equal # to the requested size. The size parameter is ignored # when attaching to an existing block. 
- assert (self.shared_memory.size - >= self.total_bytes_of_buffer) + assert self.shared_memory.size >= self.total_bytes_of_buffer except FileNotFoundError: # we might deserialize the object in a different node # in this case, this object is not used, @@ -163,8 +175,12 @@ class ShmRingBuffer: pass def handle(self): - return (self.n_reader, self.max_chunk_bytes, self.max_chunks, - self.shared_memory.name) + return ( + self.n_reader, + self.max_chunk_bytes, + self.max_chunks, + self.shared_memory.name, + ) def __reduce__(self): return ( @@ -182,14 +198,14 @@ class ShmRingBuffer: def get_data(self, current_idx: int): start = self.data_offset + current_idx * self.max_chunk_bytes end = start + self.max_chunk_bytes - with memoryview(self.shared_memory.buf[start:end]) as buf: + with self.shared_memory.buf[start:end] as buf: yield buf @contextmanager def get_metadata(self, current_idx: int): start = self.metadata_offset + current_idx * self.metadata_size end = start + self.metadata_size - with memoryview(self.shared_memory.buf[start:end]) as buf: + with self.shared_memory.buf[start:end] as buf: yield buf @@ -204,7 +220,6 @@ class Handle: class MessageQueue: - def __init__( self, n_reader, # number of all readers @@ -228,8 +243,7 @@ class MessageQueue: # for local readers, we will: # 1. create a shared memory ring buffer to communicate small data # 2. create a publish-subscribe socket to communicate large data - self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, - max_chunks) + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks) # XPUB is very similar to PUB, # except that it can receive subscription messages @@ -279,8 +293,7 @@ class MessageQueue: self.handle = Handle( local_reader_ranks=local_reader_ranks, - buffer_handle=self.buffer.handle() - if self.buffer is not None else None, + buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, @@ -315,8 +328,9 @@ class MessageQueue: self.remote_socket = None - self._read_spin_timer = SpinSleepTimer( - ) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + self._read_spin_timer = ( + SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + ) else: self.buffer = None # type: ignore self.current_idx = -1 @@ -387,21 +401,22 @@ class MessageQueue: # Release the processor to other threads sched_yield() + # if we time out, raise an exception + elapsed = time.monotonic() - start_time + if timeout is not None and elapsed > timeout: + raise TimeoutError + # if we wait for a long time, log a message - if (time.monotonic() - start_time - > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): - logger.debug( - ("No available shared memory broadcast block found" - " in %s second."), + if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: + logger.info( + "No available shared memory broadcast block found" + " in %s seconds. This typically happens when some" + " processes are hanging or doing some" + " time-consuming work (e.g. 
compilation)", VLLM_RINGBUFFER_WARNING_INTERVAL, ) n_warning += 1 - # if we time out, raise an exception - if (timeout is not None - and time.monotonic() - start_time > timeout): - raise TimeoutError - continue # found a block that is either # (1) not written @@ -423,14 +438,16 @@ class MessageQueue: metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @contextmanager - def acquire_read(self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None): + def acquire_read( + self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None, + indefinite: bool = False, + ): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -450,24 +467,27 @@ class MessageQueue: # Release the processor to other threads self._read_spin_timer.spin() - # if we wait for a long time, log a message - if (time.monotonic() - start_time - > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): - logger.debug( - ("No available shared memory broadcast block found" - " in %s second."), - VLLM_RINGBUFFER_WARNING_INTERVAL, - ) - n_warning += 1 - if cancel is not None and cancel.is_set(): raise RuntimeError("cancelled") # if we time out, raise an exception - if (timeout is not None - and time.monotonic() - start_time > timeout): + elapsed = time.monotonic() - start_time + if timeout is not None and elapsed > timeout: raise TimeoutError + # if we wait for a long time, log a message + if not indefinite and ( + elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + ): + logger.info( + "No available shared memory broadcast block found" + " in %s seconds. This typically happens when some" + " processes are hanging or doing some" + " time-consuming work (e.g. 
compilation).", + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + continue # found a block that is not read by this reader # let caller read from the buffer @@ -477,14 +497,13 @@ class MessageQueue: # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity() break def enqueue(self, obj, timeout: Optional[float] = None): - """ Write to message queue with optional timeout (in seconds) """ + """Write to message queue with optional timeout (in seconds)""" assert self._is_writer, "Only writers can enqueue" serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) if self.n_local_reader > 0: @@ -495,16 +514,19 @@ class MessageQueue: else: with self.acquire_write(timeout) as buf: buf[0] = 0 # not overflow - buf[1:len(serialized_obj) + 1] = serialized_obj + buf[1 : len(serialized_obj) + 1] = serialized_obj if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None): - """ Read from message queue with optional timeout (in seconds) """ + def dequeue( + self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None, + indefinite: bool = False, + ): + """Read from message queue with optional timeout (in seconds)""" if self._is_local_reader: - with self.acquire_read(timeout, cancel) as buf: + with self.acquire_read(timeout, cancel, indefinite) as buf: overflow = buf[0] == 1 if not overflow: # no need to know the size of serialized object @@ -535,11 +557,12 @@ class MessageQueue: return self.dequeue() @staticmethod - def create_from_process_group(pg: Union[ProcessGroup, - StatelessProcessGroup], - max_chunk_bytes, - max_chunks, - writer_rank=0) -> "MessageQueue": + def create_from_process_group( + pg: Union[ProcessGroup, StatelessProcessGroup], + max_chunk_bytes, + max_chunks, + writer_rank=0, + ) -> "MessageQueue": if isinstance(pg, ProcessGroup): group_rank = dist.get_rank(pg) group_world_size = dist.get_world_size(pg) @@ -550,6 +573,7 @@ class MessageQueue: global_ranks = list(range(pg.world_size)) from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) same_node_ranks = [i for i, s in enumerate(status) if s] n_reader = group_world_size - 1 @@ -566,17 +590,17 @@ class MessageQueue: ) handle = buffer_io.export_handle() if isinstance(pg, ProcessGroup): - dist.broadcast_object_list([handle], - src=global_ranks[writer_rank], - group=pg) + dist.broadcast_object_list( + [handle], src=global_ranks[writer_rank], group=pg + ) else: pg.broadcast_obj(handle, writer_rank) else: if isinstance(pg, ProcessGroup): recv = [None] - dist.broadcast_object_list(recv, - src=global_ranks[writer_rank], - group=pg) + dist.broadcast_object_list( + recv, src=global_ranks[writer_rank], group=pg + ) handle = recv[0] # type: ignore else: handle = pg.broadcast_obj(None, writer_rank) diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py new file mode 100644 index 0000000000000..a5486c30edf29 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -0,0 +1,656 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +from abc import ABC, 
abstractmethod +from collections.abc import Iterable +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from multiprocessing import shared_memory +from multiprocessing.synchronize import Lock as LockType +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SingleWriterShmRingBuffer: + """ + A single-writer, multiple-reader ring buffer implementation using shared + memory. This class provides a thread-safe ring buffer where one process + can write data while multiple processes/threads can read from it. + + Architecture: + - Uses shared memory for cross-process communication + - Maintains metadata for each allocated buffer chunk in the writer process + - Supports custom "is_free_fn" functions to determine when buffers can be + reused + - Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]` + + Key Concepts: + - monotonic_id_start/end: Track the range of active buffer IDs + - data_buffer_start/end: Track the physical memory range in use + - Automatic wraparound when reaching buffer end + - Lazy garbage collection based on is_free_fn checks + + Example Usage Scenarios: + + Scenario 1: Simple Linear Allocation + ``` + Buffer size: 100 bytes + Initial state: [................................................. ] + ^start=end(0) + + After allocating 20 bytes (id=0): + [id:0|size:20|data........][...................................] + ^start(0) ^end(28) + + After allocating 30 bytes (id=1): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + ``` + + Scenario 2: Memory Reclamation + ``` + Before freeing (both buffers still in use): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + + After id:0 is marked free by readers: + [FREED.................... ][id:1|size:30|data..............][..] + ^start(28) ^end(66) + + After both are freed: + [FREED..............................................][..] + ^start=end(66) + ``` + + Scenario 3: Wraparound Allocation (continuing from Scenario 2) + ``` + Starting from after memory reclamation in Scenario 2: + [FREED..............................................][..] + ^start=end(66) + + Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound: + [id:2|size:40|data........................][FREED.............][..] + ^end(148) ^start(66) + ``` + + Scenario 4: Error Handling - Out of Space + ``` + Starting from after wraparound allocation in Scenario 3: + [id:2|size:40|data........................][FREED.............][..] + ^end(148) ^start(66) + + Trying to allocate 20 more bytes: + occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100) + -> Raises MemoryError: "Not enough space in the data buffer" + ``` + + Thread Safety: + - Single writer: Only one process/thread should write (allocate_buf) + - Multiple readers: Multiple processes/threads can read (access_buf) + - Reader synchronization handled by is_free_fn callback + - Writer handles garbage collection (free_buf) based on reader feedback + + Memory Layout per Buffer Chunk: + `[4-byte monotonic_id][4-byte chunk_size][actual_data...]` + ^metadata_start ^data_start + + The monotonic_id ensures data integrity - readers can verify they're + accessing the correct data even after buffer wraparound or reuse. 
+    """
+
+    def __init__(
+        self,
+        data_buffer_size: int,
+        name: Optional[str] = None,
+        create: bool = False,
+    ):
+        self.data_buffer_size = data_buffer_size
+        self.is_writer = create
+
+        self.ID_NBYTES = 4
+        self.ID_MAX = 2**31  # exclusive, so 2**31 - 1 is the max value
+        self.SIZE_NBYTES = 4
+        # 4 bytes for id, 4 bytes for buffer size
+        self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES
+        self.monotonic_id_end = 0
+        self.monotonic_id_start = 0
+        self.data_buffer_start = 0
+        self.data_buffer_end = 0
+
+        if create:
+            # we are creating a buffer
+            self.metadata = {
+                self.monotonic_id_end: self.data_buffer_end
+            }  # monotonic_id -> start address
+            self.shared_memory = shared_memory.SharedMemory(
+                create=True, size=self.data_buffer_size, name=name
+            )
+        else:
+            # we are opening an existing buffer
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch(
+                "multiprocessing.resource_tracker.register",
+                lambda *args, **kwargs: None,
+            ):
+                self.shared_memory = shared_memory.SharedMemory(name=name)
+            # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
+            # Some platforms allocate memory based on page size,
+            # so the shared memory block size may be larger or equal
+            # to the requested size. The size parameter is ignored
+            # when attaching to an existing block.
+            assert self.shared_memory.size >= self.data_buffer_size
+
+        logger.debug(
+            "Shared memory created/opened with name: %s, size: %d",
+            self.shared_memory.name,
+            self.data_buffer_size,
+        )
+
+    def handle(self):
+        return (
+            self.data_buffer_size,
+            self.shared_memory.name,
+        )
+
+    def clear(self) -> None:
+        """Clear the ring buffer."""
+        assert self.is_writer, "Only the writer can clear the buffer."
+        self.metadata.clear()
+        self.monotonic_id_end = 0
+        self.monotonic_id_start = 0
+        self.data_buffer_start = 0
+        self.data_buffer_end = 0
+
+    def __del__(self):
+        if hasattr(self, "shared_memory"):
+            self.shared_memory.close()
+            if self.is_writer:
+                self.shared_memory.unlink()
+
+    def int2byte(self, integer: int) -> bytes:
+        """Convert an integer to bytes."""
+        return integer.to_bytes(self.ID_NBYTES, "little", signed=True)
+
+    def byte2int(self, byte_data: bytes) -> int:
+        """Convert bytes back to an integer."""
+        return int.from_bytes(byte_data, "little", signed=True)
+
+    def allocate_buf(self, size: int) -> tuple[int, int]:
+        """
+        Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
+        Memory layout:
+        `[4-byte monotonic_id][4-byte size][buffer data...]`
+        """
+        assert self.is_writer, "Only the writer can allocate buffers."
+        assert size > 0, "Size must be greater than 0"
+        size += self.MD_SIZE  # add metadata size to the buffer size
+        # reset to the beginning if the buffer does not have enough
+        # contiguous space at the end
+        buffer_end_reset = self.data_buffer_end % self.data_buffer_size
+        if buffer_end_reset + size > self.data_buffer_size:
+            buffer_end_reset = (
+                self.data_buffer_end // self.data_buffer_size + 1
+            ) * self.data_buffer_size
+        else:  # no reset needed
+            buffer_end_reset = self.data_buffer_end
+
+        # check if we have enough space in the data buffer
+        # i.e.
if the new end (self.data_buffer_end + size) + # exceeds the start of the data buffer + occupied_size_new = buffer_end_reset + size - self.data_buffer_start + if occupied_size_new > self.data_buffer_size: + raise MemoryError( + "Not enough space in the data buffer, " + "try calling free_buf() to free up space" + ) + self.data_buffer_end = buffer_end_reset + + # first 4 bytes as the monotonic id + buf_idx = self.data_buffer_end % self.data_buffer_size + self.shared_memory.buf[buf_idx : buf_idx + self.ID_NBYTES] = self.int2byte( + self.monotonic_id_end + ) + # next 4 bytes as the size of the data buffer + self.shared_memory.buf[buf_idx + self.ID_NBYTES : buf_idx + self.MD_SIZE] = ( + self.int2byte(size) + ) + + # record metadata + self.metadata[self.monotonic_id_end % self.ID_MAX] = self.data_buffer_end + # update buffer and monotonic id indices + current_buffer_end = self.data_buffer_end + current_id_end = self.monotonic_id_end + self.data_buffer_end += size + self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX + return current_buffer_end, current_id_end + + @contextmanager + def access_buf(self, address: int): + buf_idx = address % self.data_buffer_size + + # read metadata + metadata_buff = self.shared_memory.buf[buf_idx : buf_idx + self.MD_SIZE] + id = self.byte2int(metadata_buff[: self.ID_NBYTES]) + size = self.byte2int(metadata_buff[self.ID_NBYTES : self.MD_SIZE]) + + # yield the data buffer and metadata + data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE : buf_idx + size] + with ( + memoryview(data_buff) as data_view, + ): + yield data_view, (id, size) + + def free_buf( + self, + is_free_fn: Callable[[int, memoryview], bool], + nbytes: Optional[int] = None, + ) -> Iterable[int]: + """ + Free a buffer of the given size. This is a no-op in shared memory, + but we need to keep track of the metadata. + + If freed memory spreads across the end and start of the ring buffer, + the actual freed memory will be in two segments. In this case there + still might not be a contiguous space of `nbytes` available. + + Args: + nbytes (int, optional): The size of the buffer to free. If None, + frees the maximum size of the ring buffer. + """ + + assert self.is_writer, "Only the writer can free buffers." 
+ logger.debug( + "Freeing up space in the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", + self.monotonic_id_start, + self.monotonic_id_end, + ) + monotonic_id_before = self.monotonic_id_start + # if nbytes is None, free up the maximum size of the ring buffer + if nbytes is None: + nbytes = self.data_buffer_size + freed_bytes = 0 + while self.monotonic_id_start in self.metadata and freed_bytes < nbytes: + address = self.metadata[self.monotonic_id_start] + with self.access_buf(address) as (data_buff, metadata): + if is_free_fn(self.monotonic_id_start, data_buff): + # check passed, we can free the buffer + del self.metadata[self.monotonic_id_start] + self.monotonic_id_start = ( + self.monotonic_id_start + 1 + ) % self.ID_MAX + self.data_buffer_start = address + freed_bytes += metadata[1] + else: + # there are still readers, we cannot free the buffer + break + + logger.debug( + "Freed %d bytes from the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", + freed_bytes, + self.monotonic_id_start, + self.monotonic_id_end, + ) + + # buffer wrap around + if self.data_buffer_start >= self.data_buffer_size: + self.data_buffer_start -= self.data_buffer_size + self.data_buffer_end -= self.data_buffer_size + + monotonic_id_after = self.monotonic_id_start + # id wrap around + if monotonic_id_after >= monotonic_id_before: + return range(monotonic_id_before, monotonic_id_after) + else: + return chain( + range(monotonic_id_before, self.ID_MAX), range(0, monotonic_id_after) + ) + + +class ObjectSerde(ABC): + @abstractmethod + def serialize(self, value: Any) -> tuple[Any, int, bytes, int]: + """Serialize an object to bytes.""" + raise NotImplementedError + + @abstractmethod + def deserialize(self, data: memoryview) -> Any: + """Deserialize bytes back to an object.""" + raise NotImplementedError + + +class MsgpackSerde(ObjectSerde): + def __init__(self): + # Delayed import to avoid circular dependency + from vllm.multimodal.inputs import MultiModalKwargsItem + from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + + self.encoder = MsgpackEncoder() + self.tensor_decoder = MsgpackDecoder(torch.Tensor) + self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) + self._mm_kwargs_item_cls = MultiModalKwargsItem + + def serialize( + self, value: Any + ) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + len_arr = None + if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): + type_name = type(value).__name__ + value = self.encoder.encode(value) + len_arr = [len(s) for s in value] + nbytes = sum(len_arr) + else: + value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + type_name = type(value).__name__ + nbytes = len(value) + + object_metadata = (type_name, nbytes, len_arr) + serialized_metadata = pickle.dumps( + object_metadata, protocol=pickle.HIGHEST_PROTOCOL + ) + return value, nbytes, serialized_metadata, len(serialized_metadata) + + def deserialize(self, data_view: memoryview) -> Any: + # pickle.loads do not read past the end of a pickled object + # within a large buffer, so we can skip storing the metadata size + type_name, nbytes, len_arr = pickle.loads(data_view) + serialized_data = bytearray(data_view[-nbytes:]) + + if type_name == torch.Tensor.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + item_bytes = serialized_data[start_idx : start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.tensor_decoder.decode(obj) + elif type_name == self._mm_kwargs_item_cls.__name__: + obj = [] + start_idx = 0 + for length 
in len_arr: + item_bytes = serialized_data[start_idx : start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.mm_decoder.decode(obj) + elif type_name == bytes.__name__: + obj = pickle.loads(serialized_data) + else: + raise ValueError(f"Unsupported object type '{type_name}' in metadata") + + return obj + + +@dataclass +class ShmObjectStorageHandle: + max_object_size: int + n_readers: int + ring_buffer_handle: tuple[int, str] + serde_class: type[ObjectSerde] + reader_lock: Optional[LockType] + + +class SingleWriterShmObjectStorage: + """ + A single-writer, multiple-reader object storage system built on top of a + shared memory ring buffer. Provides key-value storage with automatic memory + management and cross-process serialization support. + + This storage system follows a FIFO (First-In-First-Out) eviction policy + where the oldest objects are automatically freed when memory runs low. + Memory is reclaimed based on reader reference counting - objects are only + freed when all readers have finished accessing them. + + Architecture: + - Single writer process can put(key, value) objects + - Multiple reader processes can get(address, monotonic_id) objects + - Built on SingleWriterShmRingBuffer for efficient shared memory management + - Thread-safe operations with reader synchronization via locks + + Key Features: + - FIFO Eviction: Oldest objects are evicted first when memory is full + - Reference Counting: Objects are only freed when no readers are + accessing them + - Duplicate Key Handling: Existing keys are not overwritten, just + re-referenced + - Customized Serialization: By default uses Msgpack for efficient + serialization of Python objects, but can be extended for custom types + - Cross-Process Safety: Uses shared memory with proper synchronization + - Automatic Cleanup: Garbage collection happens transparently during + allocation + + Memory Layout per Object: + `[4-byte reference_count][metadata_size][serialized_object_data]` + + Thread Safety: + - Writer operations (put, clear) are single-threaded by design + - Reader operations (get) are thread-safe with lock-based reference + counting + - Memory reclamation is handled exclusively by the writer process + """ + + def __init__( + self, + max_object_size: int, + n_readers: int, + ring_buffer: SingleWriterShmRingBuffer, + serde_class: type[ObjectSerde] = MsgpackSerde, + reader_lock: Optional[LockType] = None, + ): + """ + Initialize the object storage. + + Args: + max_object_size: Maximum size for a single object in bytes. + n_readers: Number of reader processes that can access the storage. + ring_buffer: The shared memory ring buffer for storing objects. + serde_class: Serializer/deserializer for objects. + reader_lock: Optional lock for synchronizing reader access. + Raises: + ValueError: If reader_lock is None for readers. 
+ """ + + self.max_object_size = max_object_size + self.n_readers = n_readers + self.serde_class = serde_class + self.ser_de = serde_class() + self.ring_buffer = ring_buffer + self.is_writer = self.ring_buffer.is_writer + + self.flag_bytes = 4 # for in-use flag + + if self.is_writer: + # Key-value mapping: key -> (address, monotonic_id) + self.key_index: dict[str, tuple[int, int]] = {} + # Reverse mapping: monotonic_id -> key + self.id_index: dict[int, str] = {} + # Writer flag to track in-use status: monotonic_id -> count + self.writer_flag: dict[int, int] = {} + else: + if reader_lock is None: + raise ValueError("Lock must be provided for readers.") + + self._reader_lock = reader_lock + + def clear(self) -> None: + """Clear the object storage.""" + if self.is_writer: + self.ring_buffer.clear() + self.key_index.clear() + self.id_index.clear() + self.writer_flag.clear() + logger.debug("Object storage cleared and reinitialized.") + + def copy_to_buffer( + self, + data: Union[bytes, list[bytes]], + data_bytes: int, + metadata: bytes, + md_bytes: int, + data_view: memoryview, + ) -> None: + data_view[self.flag_bytes : self.flag_bytes + md_bytes] = metadata + if isinstance(data, bytes): + data_view[-data_bytes:] = data + elif isinstance(data, list): + start_idx = self.flag_bytes + md_bytes + for item_bytes in data: + item_size = len(item_bytes) + data_view[start_idx : start_idx + item_size] = item_bytes + start_idx += item_size + else: + raise ValueError(f"Unsupported data type for serialization: {type(data)}") + + def increment_writer_flag(self, id: int) -> None: + """Set the in-use flag for the writer.""" + self.writer_flag[id] = self.writer_flag.get(id, 0) + 1 + + def increment_reader_flag(self, data_view: memoryview) -> None: + """Set the in-use flag for the reader.""" + # >0 for in-use flag + reader_count = self.ring_buffer.byte2int(data_view) + data_view[:] = self.ring_buffer.int2byte(reader_count + 1) + + def free_unused(self) -> None: + """Free unused buffers in the ring buffer.""" + # try to free up 2*max_object_size bytes of space in the ring buffer, + # since the buffer might be fragmented + freed_ids = self.ring_buffer.free_buf( + self.default_is_free_check, 2 * self.max_object_size + ) + # update the metadata after freeing up space + for freed_id in freed_ids: + key_to_free = self.id_index[freed_id] + del self.key_index[key_to_free] + del self.id_index[freed_id] + del self.writer_flag[freed_id] + + def is_cached(self, key: str) -> bool: + """ + Check if the object with the given key is cached. + """ + return key in self.key_index + + def get_cached(self, key: str) -> tuple[int, int]: + """ + Get the cached object by key if it exists. + """ + address, monotonic_id = self.key_index[key] + self.increment_writer_flag(monotonic_id) + return address, monotonic_id + + def put(self, key: str, value: Any) -> tuple[int, int]: + """ + Store a key-value pair in the object storage. + Attempts to free max_object_size bytes using FIFO order + when the ring buffer runs out of space during a put() operation. 
+
+        Args:
+            key: String key to identify the object
+            value: Any serializable Python object
+
+        Raises:
+            MemoryError: If there's not enough space in the buffer
+            ValueError: If the serialized object is too large
+            ValueError: If the key already exists in the storage
+        """
+        if key in self.key_index:
+            raise ValueError(f"Key '{key}' already exists in the storage.")
+
+        object_data, data_bytes, object_metadata, md_bytes = self.ser_de.serialize(
+            value
+        )
+        buffer_size = self.flag_bytes + data_bytes + md_bytes
+
+        # Sanity checks
+        if buffer_size > self.max_object_size:
+            raise ValueError(
+                f"Serialized object size ({buffer_size} bytes) exceeds "
+                f"max object size ({self.max_object_size} bytes)"
+            )
+
+        # Allocate new buffer
+        try:
+            address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
+        except MemoryError:
+            self.free_unused()
+            # try again after freeing up space
+            address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
+
+        # Write data to buffer
+        with self.ring_buffer.access_buf(address) as (data_view, metadata):
+            data_view[: self.flag_bytes] = self.ring_buffer.int2byte(0)
+            self.copy_to_buffer(
+                object_data, data_bytes, object_metadata, md_bytes, data_view
+            )
+        self.increment_writer_flag(monotonic_id)
+
+        # Update key index
+        self.key_index[key] = (address, monotonic_id)
+        self.id_index[monotonic_id] = key
+        return address, monotonic_id
+
+    def get(self, address: int, monotonic_id: int) -> Any:
+        # Read data from buffer
+        with self.ring_buffer.access_buf(address) as (data_view, buf_metadata):
+            # check id from metadata
+            if buf_metadata[0] != monotonic_id:
+                raise ValueError(
+                    f"Data for address:id '{address}:{monotonic_id}'"
+                    " has been modified or is invalid."
+                )
+
+            obj = self.ser_de.deserialize(data_view[self.flag_bytes :])
+
+            # increment the in-use flag to record this reader's read
+            if self._reader_lock is not None:
+                with self._reader_lock:
+                    self.increment_reader_flag(data_view[: self.flag_bytes])
+            else:
+                # if self._reader_lock is None, it means we are the writer
+                # in this case, we do not need to track the reader count
+                assert self.is_writer
+
+        return obj
+
+    def handle(self):
+        """Get handle for sharing across processes."""
+        return ShmObjectStorageHandle(
+            max_object_size=self.max_object_size,
+            n_readers=self.n_readers,
+            ring_buffer_handle=self.ring_buffer.handle(),
+            serde_class=self.serde_class,
+            reader_lock=self._reader_lock,
+        )
+
+    @staticmethod
+    def create_from_handle(
+        handle: ShmObjectStorageHandle,
+    ) -> "SingleWriterShmObjectStorage":
+        logger.debug("Creating storage from handle: %s", handle)
+        ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
+        return SingleWriterShmObjectStorage(
+            max_object_size=handle.max_object_size,
+            n_readers=handle.n_readers,
+            ring_buffer=ring_buffer,
+            serde_class=handle.serde_class,
+            reader_lock=handle.reader_lock,
+        )
+
+    def default_is_free_check(self, id: int, buf: memoryview) -> bool:
+        """
+        Default is_free function: the buffer is considered free once the
+        reader count stored in its first 4 bytes shows that every reader
+        has consumed each value the writer published for it.
+ """ + reader_count = int.from_bytes(buf[0:4], "little", signed=True) + writer_count = self.writer_flag[id] + return reader_count >= writer_count * self.n_readers diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index d907e1b833d04..88451f9552c13 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -7,7 +7,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.all_reduce_utils import ( - SYMM_MEM_ALL_REDUCE_MAX_SIZES) + SYMM_MEM_ALL_REDUCE_MAX_SIZES, +) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -27,16 +28,21 @@ class SymmMemCommunicator: "10.0": [6, 8], } - def __init__(self, group: ProcessGroup, device: Union[int, str, - torch.device]): + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + # add options for testing + force_multimem: Optional[bool] = None, + max_size_override: Optional[int] = None, + ): self.disabled = True if not symm_mem_available: return if not current_platform.is_cuda(): - logger.warning("SymmMemCommunicator: symmetric " - "memory is not available.") + logger.warning("SymmMemCommunicator: symmetric memory is not available.") return if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -47,8 +53,9 @@ class SymmMemCommunicator: self.device = device self.group = group self.world_size = dist.get_world_size(self.group) - self.device_capability = current_platform.get_device_capability( - ).as_version_str() + self.device_capability = ( + current_platform.get_device_capability().as_version_str() + ) if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: logger.warning( "SymmMemCommunicator: Device capability %s not supported, " @@ -56,16 +63,25 @@ class SymmMemCommunicator: self.device_capability, ) return - if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ - self.device_capability]: + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]: logger.warning( "SymmMemCommunicator: World size %d not supported, " "communicator is not available.", self.world_size, ) return - self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ - self.world_size] + # Use override max_size if provided, otherwise use default + if max_size_override is not None: + self.max_size = max_size_override + logger.info( + "SymmMemCommunicator: Using override max_size: %s bytes", + self.max_size, + ) + else: + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size + ] + self.buffer = torch_symm_mem.empty( self.max_size // self.dtype.itemsize, device=self.device, @@ -73,9 +89,12 @@ class SymmMemCommunicator: ) handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) if handle.multicast_ptr == 0: - logger.warning("SymmMemCommunicator: symmetric memory " - "multicast operations are not supported.") + logger.warning( + "SymmMemCommunicator: symmetric memory " + "multicast operations are not supported." 
+ ) return + self.force_multimem = force_multimem self.disabled = False def should_use_symm_mem(self, inp: torch.Tensor): @@ -89,23 +108,32 @@ class SymmMemCommunicator: return inp_size < self.max_size def all_reduce( - self, - inp: torch.Tensor, - *, - out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None + ) -> Optional[torch.Tensor]: if not self.should_use_symm_mem(inp): return None if out is None: out = torch.empty_like(inp) - self.buffer[:inp.numel()].copy_(inp.view(-1)) - if self.world_size in self._WORLD_SIZES_MULTIMEM[ - self.device_capability]: - torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], - "sum", - self.group.group_name) + self.buffer[: inp.numel()].copy_(inp.view(-1)) + + # Determine which algorithm to use + use_multimem = False + if self.force_multimem is not None: + # Test override: use forced setting + use_multimem = self.force_multimem else: - torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], - "sum", - self.group.group_name) - out.copy_(self.buffer[:inp.numel()].view(out.shape)) + # Normal logic: use multimem for supported world sizes + use_multimem = ( + self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability] + ) + + if use_multimem: + torch.ops.symm_mem.multimem_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) + else: + torch.ops.symm_mem.two_shot_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) + out.copy_(self.buffer[: inp.numel()].view(out.shape)) return out diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 942dd67f065dc..b2faea512791a 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,35 +10,39 @@ from torch.distributed import ProcessGroup from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.platforms.tpu import USE_TPU_COMMONS +from vllm.platforms.tpu import USE_TPU_INFERENCE from .base_device_communicator import DeviceCommunicatorBase -USE_RAY = parallel_config = get_current_vllm_config( -).parallel_config.distributed_executor_backend == "ray" +USE_RAY = parallel_config = ( + get_current_vllm_config().parallel_config.distributed_executor_backend == "ray" +) logger = init_logger(__name__) -if not USE_TPU_COMMONS: - logger.info("tpu_commons not found, using vLLM's TpuCommunicator") +if not USE_TPU_INFERENCE: + logger.info("tpu_inference not found, using vLLM's TpuCommunicator") if current_platform.is_tpu(): import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) + create_optimized_replica_groups, + ) + if USE_RAY: from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node @@ -96,7 +100,9 @@ 
class TpuCommunicator(DeviceCommunicatorBase): return xm.all_gather(input_, dim=dim) -if USE_TPU_COMMONS: - from tpu_commons.distributed.device_communicators import ( - TpuCommunicator as TpuCommonsCommunicator) - TpuCommunicator = TpuCommonsCommunicator # type: ignore +if USE_TPU_INFERENCE: + from tpu_inference.distributed.device_communicators import ( + TpuCommunicator as TpuInferenceCommunicator, + ) + + TpuCommunicator = TpuInferenceCommunicator # type: ignore diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index dee5ed7a28830..33d5b2cf1d879 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -7,28 +7,48 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +import vllm.envs as envs +from vllm.logger import init_logger + from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class XpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend != "naive": + logger.warning( + "`%s` all2all manager is not supported on XPU." + "Falling back to `naive` all2all manager for XPU.", + all2all_backend, + ) + all2all_backend = "naive" + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -36,23 +56,43 @@ class XpuCommunicator(DeviceCommunicatorBase): # cluster so we use all_gather instead for now. input_size = input_.size() # Allocate output tensor. - output_tensor = torch.empty((self.world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) # All-gather. 
- dist.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) if self.rank_in_group == dst: # Reshape output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) else: output_tensor = None return output_tensor def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: dist.broadcast(input_, src=src, group=self.device_group) + + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits, is_sequence_parallel + ) + return hidden_states, router_logits + + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine( + hidden_states, is_sequence_parallel + ) + return hidden_states diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py index 80511024b9304..4cd51dd384ad2 100644 --- a/vllm/distributed/eplb/__init__.py +++ b/vllm/distributed/eplb/__init__.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -''' +""" Expert parallelism load balancer (EPLB). -''' +""" from .eplb_state import * from .rebalance_algo import * diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 042acf40d67c2..663f040270461 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -35,8 +35,11 @@ import torch from torch.distributed import ProcessGroup, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import (get_ep_group, get_node_count, - in_the_same_node_as) +from vllm.distributed.parallel_state import ( + get_ep_group, + get_node_count, + in_the_same_node_as, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -190,11 +193,10 @@ class EplbState: """ Build the initial EPLB state. 
""" - physical_to_logical_map_list = ( - cls.build_initial_global_physical_to_logical_map( - model.num_routed_experts, - model.num_redundant_experts, - )) + physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + ) physical_to_logical_map = torch.tensor( physical_to_logical_map_list, device=device, @@ -205,7 +207,8 @@ class EplbState: MAX_EXPERT_REDUNDANCY = 1023 assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, ( f"num_redundant_experts {model.num_redundant_experts} " - f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}") + f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}" + ) max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1 logical_to_physical_map = torch.full( (model.num_logical_experts, max_slots_per_logical_expert), @@ -213,31 +216,42 @@ class EplbState: device=device, ) logical_replica_count = torch.zeros( - (model.num_logical_experts, ), + (model.num_logical_experts,), device=device, dtype=torch.long, ) for i in range(model.num_physical_experts): logical_idx = physical_to_logical_map[i] - logical_to_physical_map[logical_idx, - logical_replica_count[logical_idx]] = i + logical_to_physical_map[logical_idx, logical_replica_count[logical_idx]] = i logical_replica_count[logical_idx] += 1 # Duplicate initial mapping for all layers - physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( - model.num_moe_layers, - -1, - ).contiguous() - logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( - model.num_moe_layers, - -1, - -1, - ).contiguous() - logical_replica_count = logical_replica_count.unsqueeze(0).expand( - model.num_moe_layers, - -1, - ).contiguous() + physical_to_logical_map = ( + physical_to_logical_map.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + ) + .contiguous() + ) + logical_to_physical_map = ( + logical_to_physical_map.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + -1, + ) + .contiguous() + ) + logical_replica_count = ( + logical_replica_count.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + ) + .contiguous() + ) expert_load_pass = torch.zeros( (model.num_moe_layers, model.num_physical_experts), @@ -246,21 +260,21 @@ class EplbState: ) expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( - (expert_load_window_size, model.num_moe_layers, - model.num_physical_experts), + (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), dtype=torch.int32, device=device, ) # Set the initial progress of rearrangement to 3/4 eplb_step_interval = parallel_config.eplb_config.step_interval - expert_rearrangement_step = max( - 0, eplb_step_interval - eplb_step_interval // 4) + expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4) if global_expert_load is not None: ep_group = get_ep_group().device_group - assert global_expert_load.shape == (model.num_moe_layers, - model.num_logical_experts) + assert global_expert_load.shape == ( + model.num_moe_layers, + model.num_logical_experts, + ) assert global_expert_load.dtype == torch.int64 num_replicas = model.num_physical_experts @@ -273,20 +287,21 @@ class EplbState: logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}") + f"{num_gpus=}, {num_nodes=}" + ) # Get new expert mappings ( new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = (rebalance_experts( + ) = 
rebalance_experts( global_expert_load, num_replicas, num_groups, num_nodes, num_gpus, - )) + ) max_physical_slots = new_logical_to_physical_map.shape[-1] assert max_physical_slots <= logical_to_physical_map.shape[-1] @@ -326,22 +341,25 @@ class EplbState: expert_rearrangement_step_interval=eplb_step_interval, ) - def step(self, - model: MixtureOfExperts, - is_dummy: bool = False, - is_profile: bool = False, - log_stats: bool = False) -> None: + def step( + self, + model: MixtureOfExperts, + is_dummy: bool = False, + is_profile: bool = False, + log_stats: bool = False, + ) -> None: """ Step the EPLB state. Args: model (MixtureOfExperts): The MoE model. is_dummy (bool): If `True`, this is a dummy step and the load - metrics recorded in this forward pass will not count. Defaults - to `False`. + metrics recorded in this forward pass will not count. + Defaults to `False`. is_profile (bool): If `True`, perform a dummy rearrangement - with maximum communication cost. This is used in `profile_run` - to reserve enough memory for the communication buffer. + with maximum communication cost. This is used in + `profile_run` to reserve enough memory + for the communication buffer. log_stats (bool): If `True`, log the expert load metrics. # Stats @@ -368,32 +386,40 @@ class EplbState: all_reduce(total_expert_load_pass, group=ep_group) # num_tokens_per_rank: (num_moe_layers, num_ranks) - num_tokens_per_rank = total_expert_load_pass.reshape( - total_expert_load_pass.shape[0], ep_group.size(), - -1).sum(dim=-1).float() + num_tokens_per_rank = ( + total_expert_load_pass.reshape( + total_expert_load_pass.shape[0], ep_group.size(), -1 + ) + .sum(dim=-1) + .float() + ) # Compute balancedness ratio: # for each layer: # (mean load across ranks) / (max load across ranks) avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) - max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( - dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) # Just to make type checker happy tokens_tensors: list[float] = torch.stack( - [avg_tokens_tensor, max_tokens_tensor]).tolist() + [avg_tokens_tensor, max_tokens_tensor] + ).tolist() avg_tokens, max_tokens = tokens_tensors balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 if ep_group.rank() == 0: logger.info( - "EPLB step: avg_tokens=%.2f, max_tokens=%d, " - "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + "EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f", + avg_tokens, + max_tokens, + balancedness, + ) # Update the expert load sliding window if not is_dummy: self.expert_load_window[self.expert_load_window_step] = ( - self.expert_load_pass.clone()) + self.expert_load_pass.clone() + ) self.expert_load_window_step += 1 if self.expert_load_window_step >= self.expert_load_window_size: self.expert_load_window_step = 0 @@ -404,17 +430,18 @@ class EplbState: # rearrangement step and perform rearrangement to ensure all ranks are # performing collective communication. 
self.expert_rearrangement_step += 1 - if (self.expert_rearrangement_step - >= self.expert_rearrangement_step_interval): + if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: self.expert_rearrangement_step = 0 self.rearrange(model) - def rearrange(self, - model: MixtureOfExperts, - is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_load: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None) -> None: + def rearrange( + self, + model: MixtureOfExperts, + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None, + ) -> Optional[torch.Tensor]: """ Rearrange the experts according to the current load. """ @@ -427,8 +454,7 @@ class EplbState: if is_main_rank: torch.cuda.synchronize() time_start = time.perf_counter() - logger.info("Rearranging experts %s...", - "(profile)" if is_profile else "") + logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") if global_expert_load is None: # Map the physical expert load to global logical experts @@ -441,23 +467,25 @@ class EplbState: ) logical_expert_load_window.scatter_add_( dim=-1, - index=self.physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), + index=self.physical_to_logical_map.unsqueeze(0) + .expand_as(self.expert_load_window) + .long(), src=self.expert_load_window, ) if not execute_shuffle: metadata = torch.tensor( [ - model.num_moe_layers, model.num_logical_experts, - self.physical_to_logical_map.shape[1] + model.num_moe_layers, + model.num_logical_experts, + self.physical_to_logical_map.shape[1], ], dtype=torch.int32, device="cpu", ) - torch.distributed.broadcast(metadata, - group=get_ep_group().cpu_group, - group_src=0) + torch.distributed.broadcast( + metadata, group=get_ep_group().cpu_group, group_src=0 + ) # Perform all-reduce to get the expert load across all ranks global_expert_load_window = logical_expert_load_window.sum(dim=0) @@ -466,9 +494,9 @@ class EplbState: if not execute_shuffle: # (num_moe_layers, old_num_physical_experts) old_global_expert_indices = self.physical_to_logical_map - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group, - group_src=0) + torch.distributed.broadcast( + old_global_expert_indices, group=ep_group, group_src=0 + ) return global_expert_load_window else: assert execute_shuffle @@ -483,10 +511,10 @@ class EplbState: # the GPUs to be released. 
cpu_group = get_ep_group().cpu_group num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) - num_gpus = sum(new_rank != -1 - for new_rank in rank_mapping.values()) - num_replicas = num_replicas // ep_group.size( - ) * num_gpus # handle num replicas change + num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) + num_replicas = ( + num_replicas // ep_group.size() * num_gpus + ) # handle num replicas change else: num_nodes = get_node_count() num_gpus = ep_group.size() @@ -496,20 +524,21 @@ class EplbState: logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}") + f"{num_gpus=}, {num_nodes=}" + ) # Get new expert mappings ( new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = (rebalance_experts( + ) = rebalance_experts( global_expert_load_window, num_replicas, num_groups, num_nodes, num_gpus, - )) + ) # Update expert weights rearrange_expert_weights_inplace( @@ -522,18 +551,20 @@ class EplbState: ) if not is_profile: - if self.physical_to_logical_map.shape[ - 1] != new_physical_to_logical_map.shape[1]: + if ( + self.physical_to_logical_map.shape[1] + != new_physical_to_logical_map.shape[1] + ): self.physical_to_logical_map = new_physical_to_logical_map.to( - self.physical_to_logical_map.device) + self.physical_to_logical_map.device + ) else: self.physical_to_logical_map.copy_(new_physical_to_logical_map) max_physical_slots = new_logical_to_physical_map.shape[-1] assert max_physical_slots <= self.logical_to_physical_map.shape[-1] new_logical_to_physical_map = torch.nn.functional.pad( new_logical_to_physical_map, - (0, - self.logical_to_physical_map.shape[-1] - max_physical_slots), + (0, self.logical_to_physical_map.shape[-1] - max_physical_slots), value=-1, ) self.logical_to_physical_map.copy_(new_logical_to_physical_map) @@ -548,6 +579,7 @@ class EplbState: " (profile) " if is_profile else " ", time_end - time_start, ) + return None @staticmethod def recv_state() -> tuple[torch.Tensor, torch.Tensor]: @@ -556,11 +588,10 @@ class EplbState: """ ep_group = get_ep_group() metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, - group=ep_group.cpu_group, - group_src=0) + torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist()) + metadata.tolist() + ) global_expert_load = torch.zeros( (num_moe_layers, num_logical_experts), dtype=torch.int64, @@ -572,9 +603,9 @@ class EplbState: dtype=torch.int64, device=ep_group.device, ) - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group.device_group, - group_src=0) + torch.distributed.broadcast( + old_global_expert_indices, group=ep_group.device_group, group_src=0 + ) return global_expert_load, old_global_expert_indices @@ -613,4 +644,4 @@ def _node_count_with_rank_mapping( if is_same_node and node_assignment[other_rank] == 0: node_assignment[other_rank] = next_node_id - return next_node_id \ No newline at end of file + return next_node_id diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index 879b5b9f18240..c9d30d6481ab6 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -15,8 +15,9 @@ on how the EPLB algorithm works. 
import torch -def balanced_packing(weight: torch.Tensor, - num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: +def balanced_packing( + weight: torch.Tensor, num_packs: int +) -> tuple[torch.Tensor, torch.Tensor]: """ Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible. @@ -34,25 +35,21 @@ def balanced_packing(weight: torch.Tensor, groups_per_pack = num_groups // num_packs if groups_per_pack == 1: - pack_index = torch.arange(weight.size(-1), - dtype=torch.int64, - device=weight.device).expand(weight.shape) + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=weight.device + ).expand(weight.shape) rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) return pack_index, rank_in_pack indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, - fill_value=-1, - dtype=torch.int64, - device="cpu") + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") rank_in_pack = torch.full_like(pack_index, fill_value=-1) for i in range(num_layers): pack_weights = [0] * num_packs pack_items = [0] * num_packs for group in indices[i]: pack = min( - (i - for i in range(num_packs) if pack_items[i] < groups_per_pack), + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__, ) assert pack_items[pack] < groups_per_pack @@ -64,8 +61,8 @@ def balanced_packing(weight: torch.Tensor, def replicate_experts( - weight: torch.Tensor, - num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + weight: torch.Tensor, num_phy: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. 
@@ -83,8 +80,7 @@ def replicate_experts( num_redundant = num_phy - num_log assert num_redundant >= 0 device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, - device=device).repeat(n, 1) + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device) @@ -102,20 +98,23 @@ def rebalance_experts_hierarchical( num_groups: int, num_nodes: int, num_gpus: int, -): +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: weight: [num_moe_layers, num_logical_experts] num_physical_experts: number of physical experts after replication num_groups: number of expert groups num_nodes: number of server nodes, where the intra-node network - (e.g, NVLink) is faster + (e.g., NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: - physical_to_logical_map: [num_moe_layers, num_physical_experts] - logical_to_physical_map: [num_moe_layers, num_logical_experts, X] - logical_count: [num_moe_layers, num_logical_experts] + physical_to_logical_map (torch.Tensor): + [num_moe_layers, num_physical_experts] + logical_to_physical_map (torch.Tensor): + [num_moe_layers, num_logical_experts, X] + logical_count (torch.Tensor): + [num_moe_layers, num_logical_experts] """ num_layers, num_logical_experts = weight.shape assert num_logical_experts % num_groups == 0 @@ -131,45 +130,51 @@ def rebalance_experts_hierarchical( inv.scatter_( 1, perm, - torch.arange(perm.size(1), dtype=torch.int64, - device=perm.device).expand(perm.shape), + torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( + perm.shape + ), ) return inv # Step 1: pack groups to nodes tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = balanced_packing( - tokens_per_group, num_nodes) - log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * - group_size).unsqueeze(-1) + - torch.arange(group_size, - dtype=torch.int64, - device=group_pack_index.device)).flatten(-2) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) + ).flatten(-2) mlog2log = inverse(log2mlog) # Step 2: construct redundant experts within nodes # [num_layers * num_nodes, num_logical_experts // num_nodes] tokens_per_mlog = weight.gather(-1, mlog2log).view( - -1, num_logical_experts // num_nodes) + -1, num_logical_experts // num_nodes + ) phy2mlog, phyrank, mlogcnt = replicate_experts( - tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_mlog, num_physical_experts // num_nodes + ) # Step 3: pack physical_experts to GPUs # [num_layers * num_nodes, num_physical_experts // num_nodes] tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - pack_index, rank_in_pack = balanced_packing(tokens_per_phy, - num_gpus // num_nodes) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack pphy2phy = inverse(phy2pphy) pphy2mlog = phy2mlog.gather( - -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( - 0, - num_logical_experts, - 
num_logical_experts // num_nodes, - device=group_pack_index.device, - ).view(1, -1, 1)).flatten(-2) + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) @@ -197,11 +202,13 @@ def rebalance_experts( num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: - physical_to_logical_map: [layers, num_replicas], the expert index of - each replica - logical_to_physical_map: [layers, num_logical_experts, X], the replica - indices for each expert - expert_count: [layers, num_logical_experts], number of physical + physical_to_logical_map: + [layers, num_replicas], the expert index of each replica + logical_to_physical_map: + [layers, num_logical_experts, X], the replica indices for each + expert + expert_count: + [layers, num_logical_experts], number of physical replicas for each logical expert """ num_layers, num_logical_experts = weight.shape @@ -209,11 +216,13 @@ def rebalance_experts( if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus) + weight, num_replicas, num_groups, num_nodes, num_gpus + ) else: # use global load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus) + weight, num_replicas, 1, 1, num_gpus + ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 log2phy: torch.Tensor = torch.full( @@ -225,8 +234,9 @@ def rebalance_experts( log2phy.view(num_layers, -1).scatter_( -1, phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, - device=log2phy.device).expand(num_layers, -1), + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), ) return phy2log, log2phy, logcnt diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8a7d1170bb01..344fae457c9b5 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -11,8 +11,13 @@ from functools import partial from typing import Optional import torch -from torch.distributed import (P2POp, ProcessGroup, all_gather, - batch_isend_irecv, get_global_rank) +from torch.distributed import ( + P2POp, + ProcessGroup, + all_gather, + batch_isend_irecv, + get_global_rank, +) def idx_local_to_global( @@ -132,8 +137,7 @@ def shuffle_layer( continue if old_indices[src_global] == new_indices[dst_global]: is_received_locally[dst] = True - for weight, buffer in zip(expert_weights, - expert_weights_buffer): + for weight, buffer in zip(expert_weights, expert_weights_buffer): buffer[dst].copy_(weight[src]) p2p_ops: list[P2POp] = [] @@ -177,7 +181,8 @@ def shuffle_layer( torch.distributed.isend, weight[src], dst_global, - ) for weight in expert_weights + ) + for weight in expert_weights ] # 3. Initiate receiving of weights. @@ -216,7 +221,8 @@ def shuffle_layer( torch.distributed.irecv, weight[dst], src_global, - ) for weight in expert_weights_buffer + ) + for weight in expert_weights_buffer ] # 4. Execute the P2P operations. The real communication happens here. 
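As a quick orientation for the rebalance_experts API whose signature and docstring are reflowed in the hunks above, here is a minimal usage sketch. The toy load tensor and the parameter values are illustrative assumptions; only the import path, the argument names, and the three returned mappings follow the code in vllm/distributed/eplb/rebalance_algo.py as shown in this diff.

import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts

# Toy per-layer logical expert load: 2 MoE layers x 4 logical experts.
weight = torch.tensor(
    [[10, 170, 30, 90],
     [80, 20, 160, 40]],
    dtype=torch.int64,
)

phy2log, log2phy, logcnt = rebalance_experts(
    weight,
    num_replicas=8,  # physical experts after replication
    num_groups=4,    # expert groups; num_groups % num_nodes == 0 -> hierarchical path
    num_nodes=2,
    num_gpus=4,      # must be a multiple of num_nodes
)

# phy2log: (2, 8)    logical expert behind each physical slot
# log2phy: (2, 4, X) physical slots serving each logical expert
# logcnt:  (2, 4)    replica count per logical expert
print(phy2log.shape, log2phy.shape, logcnt)

With num_groups=4 and num_nodes=2 the hierarchical policy is used; a group count that is not divisible by the node count falls back to the global load-balance policy, as the final hunk of rebalance_algo.py shows.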
@@ -271,29 +277,25 @@ def rearrange_expert_weights_inplace( if rank_mapping is not None: if len(rank_mapping) == ep_group.size(): # scale down - new_global_expert_indices = \ - _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( new_global_expert_indices, rank_mapping, ) else: # scale up - old_global_expert_indices = \ - _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( old_global_expert_indices, rank_mapping, ep_group.size(), ) - assert old_global_expert_indices.shape[ - 1] == new_global_expert_indices.shape[1] + assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers num_local_physical_experts = next(iter(expert_weights[0])).shape[0] - assert new_global_expert_indices.shape == (num_moe_layers, - num_physical_experts) + assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) ep_rank = ep_group.rank() ep_size = ep_group.size() @@ -342,13 +344,13 @@ def _map_old_expert_indices_with_rank_mapping( ) -> torch.Tensor: """ Map the old global expert indices to the new global expert indices. - + Args: old_global_expert_indices: Shape (num_layers, old_ep_size * num_local_physical_experts). rank_mapping: Mapping from old rank to new rank. new_ep_size: New expert parallelism size. - + Returns: Mapped expert indices with shape (num_layers, new_ep_size * num_local_physical_experts). @@ -379,8 +381,9 @@ def _map_old_expert_indices_with_rank_mapping( new_start_idx = new_rank * num_local_physical_experts new_end_idx = (new_rank + 1) * num_local_physical_experts - mapped_expert_indices[:, new_start_idx:new_end_idx] = \ + mapped_expert_indices[:, new_start_idx:new_end_idx] = ( old_global_expert_indices[:, old_start_idx:old_end_idx] + ) # If new_rank is None or >= new_ep_size, the experts remain -1 # (scale down case) @@ -415,8 +418,9 @@ def _map_new_expert_indices_with_rank_mapping( new_start_idx = new_rank * num_local_physical_experts new_end_idx = (new_rank + 1) * num_local_physical_experts - mapped_expert_indices[:, old_start_idx:old_end_idx] = \ + mapped_expert_indices[:, old_start_idx:old_end_idx] = ( new_global_expert_indices[:, new_start_idx:new_end_idx] + ) return mapped_expert_indices diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 2d7935773dd9f..d93ae63e0eb4d 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -14,17 +14,18 @@ from typing import Any, Callable, Optional, Union import msgspec import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import ExternalBlockHash logger = init_logger(__name__) class EventBatch( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, # type: ignore[call-arg] + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] ): ts: float events: list[Any] @@ -32,24 +33,30 @@ class EventBatch( class KVCacheEvent( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, # type: ignore[call-arg] - tag=True): + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # 
type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True, +): """Base class for all KV cache-related events""" +MEDIUM_GPU = "GPU" + + class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[ExternalBlockHash] + parent_block_hash: Optional[ExternalBlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] + medium: Optional[str] class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[ExternalBlockHash] + medium: Optional[str] class AllBlocksCleared(KVCacheEvent): @@ -63,14 +70,14 @@ class KVEventBatch(EventBatch): class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism support. - + In data parallel setups, each DP rank runs its own EventPublisher instance to avoid duplicate events and ensure proper event attribution: - + - Each DP rank creates a separate publisher - Publishers automatically annotate events with their data_parallel_rank - This allows consumers to distinguish events from different DP ranks - + The publisher is responsible for adding DP metadata since the scheduler operates independently of DP topology and shouldn't need DP awareness. """ @@ -124,6 +131,7 @@ class ZmqEventPublisher(EventPublisher): topic: Topic to publish events to. """ + SHUTDOWN_TIMEOUT: float = 1.0 END_SEQ = (-1).to_bytes(8, "big", signed=True) @@ -150,21 +158,22 @@ class ZmqEventPublisher(EventPublisher): self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) self._replay_endpoint = self.offset_endpoint_port( - replay_endpoint, self._dp_rank) + replay_endpoint, self._dp_rank + ) self._hwm = hwm self._socket_setup() # Payload self._seq_gen = count() - self._topic_bytes = topic.encode('utf-8') + self._topic_bytes = topic.encode("utf-8") # Thread self._running = True logger.info("Starting ZMQ publisher thread") - self._thread = threading.Thread(target=self._publisher_thread, - daemon=True, - name="zmq-publisher") + self._thread = threading.Thread( + target=self._publisher_thread, daemon=True, name="zmq-publisher" + ) self._thread.start() def publish(self, events: EventBatch) -> None: @@ -214,10 +223,12 @@ class ZmqEventPublisher(EventPublisher): self._pub.set_hwm(self._hwm) # Heuristic: bind if wildcard / * present, else connect. 
# bind stable, connect volatile convention - if (self._endpoint is not None - and ("*" in self._endpoint or "::" in self._endpoint - or self._endpoint.startswith("ipc://") - or self._endpoint.startswith("inproc://"))): + if self._endpoint is not None and ( + "*" in self._endpoint + or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://") + ): self._pub.bind(self._endpoint) elif self._endpoint is not None: self._pub.connect(self._endpoint) @@ -257,8 +268,7 @@ class ZmqEventPublisher(EventPublisher): payload = self._pack.encode(event) seq_bytes = seq.to_bytes(8, "big") - self._pub.send_multipart( - (self._topic_bytes, seq_bytes, payload)) + self._pub.send_multipart((self._topic_bytes, seq_bytes, payload)) self._buffer.append((seq, payload)) self._event_queue.task_done() @@ -285,24 +295,26 @@ class ZmqEventPublisher(EventPublisher): # (identity, empty_delim) are stripped off by the router # receiving payload is (seq_bytes, payload) self._replay.send_multipart( - (client_id, b"", seq.to_bytes(8, "big"), buf)) + (client_id, b"", seq.to_bytes(8, "big"), buf) + ) # Send end of sequence marker # receiving payload is (-1, b""") self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) @staticmethod - def offset_endpoint_port(endpoint: Optional[str], - data_parallel_rank: int) -> Optional[str]: - """Helper function to offset the port in an endpoint by + def offset_endpoint_port( + endpoint: Optional[str], data_parallel_rank: int + ) -> Optional[str]: + """Helper function to offset the port in an endpoint by the data parallel rank. Args: - endpoint: The endpoint string + endpoint: The endpoint string (e.g., "tcp://*:5557" or "inproc://cache") data_parallel_rank: The data parallel rank to offset by Returns: - The endpoint with the port offset by data_parallel_rank + The endpoint with the port offset by data_parallel_rank or suffix appended """ # Do nothing if input is None or data_parallel_rank is 0 @@ -316,7 +328,7 @@ class ZmqEventPublisher(EventPublisher): # Get everything after the last colon (the port) last_colon_idx = endpoint.rfind(":") base_addr = endpoint[:last_colon_idx] - base_port = int(endpoint[last_colon_idx + 1:]) + base_port = int(endpoint[last_colon_idx + 1 :]) new_port = base_port + data_parallel_rank return f"{base_addr}:{new_port}" return endpoint @@ -330,16 +342,15 @@ class EventPublisherFactory: } @classmethod - def register_publisher(cls, name: str, - ctor: Callable[..., EventPublisher]) -> None: + def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None: if name in cls._registry: raise KeyError(f"publisher '{name}' already registered") cls._registry[name] = ctor @classmethod - def create(cls, - config: Optional[KVEventsConfig], - data_parallel_rank: int = 0) -> EventPublisher: + def create( + cls, config: Optional[KVEventsConfig], data_parallel_rank: int = 0 + ) -> EventPublisher: """Create publisher from a config mapping.""" if not config: return NullEventPublisher() @@ -352,5 +363,4 @@ class EventPublisherFactory: constructor = cls._registry[kind] except KeyError as exc: raise ValueError(f"Unknown event publisher '{kind}'") from exc - return constructor(data_parallel_rank=data_parallel_rank, - **config_dict) + return constructor(data_parallel_rank=data_parallel_rank, **config_dict) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index 349d3dfbd84fc..39377aabcce3a 100644 --- a/vllm/distributed/kv_transfer/README.md +++ 
b/vllm/distributed/kv_transfer/README.md @@ -2,7 +2,7 @@ # Distributed KV cache transfer This folder implements distributed KV cache transfer across vLLM instances. -Currently the main usecase is for disaggregated prefilling. +Currently the main use case is for disaggregated prefilling. ## Abstractions @@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed communication service already supports key-value-based lookup (like redis or RDMA database). diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index fa9b7e4f14c02..2bf4e1feb7034 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,11 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, - has_kv_transfer_group, is_v1_kv_transfer_group) + KVConnectorBaseType, + ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) __all__ = [ - "get_kv_transfer_group", "has_kv_transfer_group", - "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "KVConnectorBaseType" + "get_kv_transfer_group", + "has_kv_transfer_group", + "is_v1_kv_transfer_group", + "ensure_kv_transfer_initialized", + "ensure_kv_transfer_shutdown", + "KVConnectorBaseType", ] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 584fc1d655951..395a4e20e0ba3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -4,17 +4,17 @@ import importlib from typing import TYPE_CHECKING, Callable -# yapf: disable import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.base import ( - KVConnectorBase, KVConnectorBaseType) + KVConnectorBase, + KVConnectorBaseType, +) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger -# yapf: enable - if TYPE_CHECKING: - from vllm.config import KVTransferConfig, VllmConfig + from vllm.config import VllmConfig + from vllm.config.kv_transfer import KVTransferConfig logger = init_logger(__name__) @@ -23,8 +23,7 @@ class KVConnectorFactory: _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {} @classmethod - def register_connector(cls, name: str, module_path: str, - class_name: str) -> None: + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: """Register a connector with a lazy-loading module and class name.""" if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") @@ -42,13 +41,18 @@ class KVConnectorFactory: role: KVConnectorRole, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: - raise ValueError("Attempting to 
initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}") + raise ValueError( + "Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}" + ) kv_transfer_config = config.kv_transfer_config connector_cls = cls.get_connector_class(kv_transfer_config) - logger.info("Creating v1 connector with name: %s and engine_id: %s", - connector_cls.__name__, kv_transfer_config.engine_id) + logger.info( + "Creating v1 connector with name: %s and engine_id: %s", + connector_cls.__name__, + kv_transfer_config.engine_id, + ) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. # Scheduler connector: # - Co-locate with scheduler process @@ -61,7 +65,7 @@ class KVConnectorFactory: @classmethod def get_connector_class( - cls, kv_transfer_config: "KVTransferConfig" + cls, kv_transfer_config: "KVTransferConfig" ) -> type[KVConnectorBaseType]: """Get the connector class by name.""" connector_name = kv_transfer_config.kv_connector @@ -70,8 +74,7 @@ class KVConnectorFactory: else: connector_module_path = kv_transfer_config.kv_connector_module_path if connector_module_path is None: - raise ValueError( - f"Unsupported connector type: {connector_name}") + raise ValueError(f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) return connector_cls @@ -84,24 +87,35 @@ class KVConnectorFactory: KVConnectorFactory.register_connector( "SharedStorageConnector", "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", - "SharedStorageConnector") + "SharedStorageConnector", +) KVConnectorFactory.register_connector( "P2pNcclConnector", "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector", - "P2pNcclConnector") + "P2pNcclConnector", +) KVConnectorFactory.register_connector( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", - "LMCacheConnectorV1") + "LMCacheConnectorV1", +) KVConnectorFactory.register_connector( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", - "NixlConnector") + "NixlConnector", +) KVConnectorFactory.register_connector( "MultiConnector", "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", - "MultiConnector") + "MultiConnector", +) + +KVConnectorFactory.register_connector( + "OffloadingConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", + "OffloadingConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 2364400b3d350..056ece60e84dd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,18 +3,18 @@ """ KV cache helper for store. 
""" + from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import Optional, cast +from typing import Literal, Optional, Union, cast import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -22,14 +22,12 @@ logger = init_logger(__name__) class model_aware_kv_ops_helper: - def __init__(self, config: VllmConfig): self.is_deepseek_mla = config.model_config.is_deepseek_mla self.use_mla_opt = not envs.VLLM_MLA_DISABLE self.tp_size = config.parallel_config.tensor_parallel_size def get_model_args(self, model_executable: torch.nn.Module): - model_config = model_executable.model.config self.model_executable = model_executable num_heads = int(model_config.num_key_value_heads / self.tp_size) @@ -44,14 +42,12 @@ class model_aware_kv_ops_helper: # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading # to a kv_cache shape of [2, num_blks, blk_size, # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. + # For more details, see vllm/v1/attention/backends/mla/common.py. if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + \ - model_config.qk_rope_head_dim + head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim num_heads = 1 elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + \ - model_config.qk_rope_head_dim + head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim else: head_size = getattr(model_config, "head_dim", None) if head_size is None: @@ -68,16 +64,24 @@ class model_aware_kv_ops_helper: value_cache = kv_cache[1].reshape(-1, num_heads, head_size) return key_cache, value_cache - def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, - layer, kv_cache, slot_mapping, start_pos, end_pos): - + def put_kv_to_cache( + self, + model_executable: torch.nn.Module, + keys, + values, + layer, + kv_cache, + slot_mapping, + start_pos, + end_pos, + ): model_config = model_executable.model.config if self.is_deepseek_mla and self.use_mla_opt: layer.self_attn.attn = layer.self_attn.mla_attn k_c_normed_k_pe = keys.squeeze(1) - k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :] ops.concat_and_cache_mla( k_c_normed.to(kv_cache.device), k_pe.to(kv_cache.device), @@ -107,17 +111,17 @@ def get_kv_connector_cache_layout(): kv_config = vllm_config.kv_transfer_config if kv_config is not None: connector_cls = KVConnectorFactory.get_connector_class(kv_config) - required_kvcache_layout = connector_cls.get_required_kvcache_layout( - vllm_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config) if required_kvcache_layout is not None: return required_kvcache_layout - logger.info_once("Connectors do not specify a " \ - "kv cache layout, defaulting to NHD.") + logger.info_once( + "Connectors do not specify a kv cache layout, defaulting to NHD." 
+ ) return "NHD" class KVOutputAggregator: - """Utility class to aggregate the output of all workers into a single + """Utility class to aggregate the output of all workers into a single output corresponding to Rank 0 for scheduler.""" def __init__(self, world_size: int): @@ -126,14 +130,16 @@ class KVOutputAggregator: self._recv_remaining_count = defaultdict[str, int](lambda: world_size) self._send_remaining_count = defaultdict[str, int](lambda: world_size) - def aggregate(self, - outputs: list[ModelRunnerOutput], - output_rank: int = 0) -> ModelRunnerOutput: - # aggregate kv_connector_output from all workers + def aggregate( + self, outputs: list[ModelRunnerOutput], output_rank: int = 0 + ) -> ModelRunnerOutput: + # Aggregate kv_connector_output from all workers - def update_finished_set(req_ids: Optional[set[str]], - remaining_count_dict: dict[str, int], - finished_set: set[str]) -> None: + def update_finished_set( + req_ids: Optional[set[str]], + remaining_count_dict: dict[str, int], + finished_set: set[str], + ) -> None: for req_id in req_ids or (): remaining_count_dict[req_id] -= 1 if remaining_count_dict[req_id] == 0: @@ -142,14 +148,35 @@ class KVOutputAggregator: finished_sending = set[str]() finished_recving = set[str]() - for output in outputs: - output = output.kv_connector_output + aggregated_kv_connector_stats = None + invalid_block_ids = set[int]() + for model_runner_output in outputs: + output = model_runner_output.kv_connector_output if not output: continue - update_finished_set(output.finished_sending, - self._send_remaining_count, finished_sending) - update_finished_set(output.finished_recving, - self._recv_remaining_count, finished_recving) + update_finished_set( + output.finished_sending, self._send_remaining_count, finished_sending + ) + update_finished_set( + output.finished_recving, self._recv_remaining_count, finished_recving + ) + + # Aggregate kv_connector_stats from all workers. + if aggregated_kv_connector_stats is None: + # Use the first worker's kv_connector_stats as accumulator. 
+ aggregated_kv_connector_stats = output.kv_connector_stats + elif kv_connector_stats := output.kv_connector_stats: + if aggregated_kv_connector_stats is None: + aggregated_kv_connector_stats = kv_connector_stats + else: + assert isinstance( + aggregated_kv_connector_stats, type(kv_connector_stats) + ) + aggregated_kv_connector_stats = ( + aggregated_kv_connector_stats.aggregate(kv_connector_stats) + ) + + invalid_block_ids |= output.invalid_block_ids # select output of the worker specified by output_rank output = outputs[output_rank] @@ -157,22 +184,22 @@ class KVOutputAggregator: output.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending or None, finished_recving=finished_recving or None, + kv_connector_stats=aggregated_kv_connector_stats or None, + invalid_block_ids=invalid_block_ids, ) return output - def async_aggregate(self, - output_futures: Sequence[Future[ModelRunnerOutput]], - output_rank: int = 0) -> Future[ModelRunnerOutput]: + def async_aggregate( + self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0 + ) -> Future[ModelRunnerOutput]: """Takes a list of futures and returns a single future which resolves to the respective list of outputs.""" result_future: Future[ModelRunnerOutput] = Future() - outputs: list[Optional[ModelRunnerOutput]] = [None - ] * len(output_futures) + outputs: list[Optional[ModelRunnerOutput]] = [None] * len(output_futures) def make_callback(idx): - def callback(fut): if result_future.done(): return @@ -187,8 +214,10 @@ class KVOutputAggregator: # this check assumes io_thread_pool uses a single thread if all(outputs): result_future.set_result( - self.aggregate(cast(list[ModelRunnerOutput], outputs), - output_rank)) + self.aggregate( + cast(list[ModelRunnerOutput], outputs), output_rank + ) + ) return callback @@ -196,3 +225,53 @@ class KVOutputAggregator: output_future.add_done_callback(make_callback(i)) return result_future + + +def _make_src_and_dst_indices( + src_block_ids: list[int], + dst_block_ids: list[int], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> tuple[torch.Tensor, torch.Tensor]: + src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64) + return src_indices, dst_indices + + +def copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: Literal["h2d", "d2h"], +) -> None: + """Copy kv blocks between different buffers.""" + if ( + not src_kv_caches + or not dst_kv_caches + or not src_block_ids + or not dst_block_ids + or len(src_block_ids) != len(dst_block_ids) + ): + return + + src_device = next(iter(src_kv_caches.values())).device + dst_device = next(iter(dst_kv_caches.values())).device + + src_indices, dst_indices = _make_src_and_dst_indices( + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device, + ) + + from vllm.platforms import current_platform + + if direction == "h2d": + copy_fn = current_platform.insert_blocks_to_device + else: + copy_fn = current_platform.swap_out_blocks_to_host + for layer_name in src_kv_caches: + src_tensor = src_kv_caches[layer_name] + dst_tensor = dst_kv_caches[layer_name] + copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py 
index f00f31dde915a..034c7afe97a48 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorRole, +) __all__ = ["KVConnectorRole", "KVConnectorBase_V1"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 5601ee74be110..e871b3017d8bb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -7,18 +7,21 @@ communication in vLLM v1 The class provides the following primitives: Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. - get_num_new_matched_tokens() - get number of new tokens + get_num_new_matched_tokens() - get number of new tokens that exist in the remote KV cache. Might be called multiple times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. update_connector_output() - update KVConnector state after output is received from worker-side connectors. - request_finished() - called when a request is finished, with - the computed kv cache blocks for the request. - Returns whether KV cache should be freed now or will be - freed asynchronously and optionally returns KV transfer - params. + request_finished() - called once when a request is finished, + with the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or if the + connector now assumes responsibility for freeing + the blocks asynchronously. Also optionally returns KV + transfer params. + take_events() - returns new KV events that were collected + by the connector since the last call. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -34,6 +37,7 @@ The class provides the following primitives: import enum from abc import ABC, abstractmethod +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -45,15 +49,23 @@ from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig + from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request # s_tensor_list, d_tensor_list, s_indices, d_indices, direction -CopyBlocksOp = Callable[[ - dict[str, torch.Tensor], dict[ - str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"] -], None] +CopyBlocksOp = Callable[ + [ + dict[str, torch.Tensor], + dict[str, torch.Tensor], + list[int], + list[int], + Literal["h2d", "d2h"], + ], + None, +] logger = init_logger(__name__) @@ -71,15 +83,16 @@ class KVConnectorMetadata(ABC): # noqa: B024 Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. """ + pass class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. 
This API is experimental and " - "subject to change in the future as we iterate the design.") + "subject to change in the future as we iterate the design." + ) self._connector_metadata: Optional[KVConnectorMetadata] = None self._vllm_config = vllm_config self._role = role @@ -92,11 +105,10 @@ class KVConnectorBase_V1(ABC): # Worker-side methods # ============================== - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. - This function should be called by the model runner every time + This function should be called by the model runner every time before the model execution. The metadata will be used for runtime KV cache loading and saving. @@ -108,7 +120,7 @@ class KVConnectorBase_V1(ABC): def clear_connector_metadata(self) -> None: """Clear the connector metadata. - This function should be called by the model runner every time + This function should be called by the model runner every time after the model execution. """ self._connector_metadata = None @@ -131,7 +143,7 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: + Args: kv_caches: dictionary of layer names, kv cache """ return @@ -144,8 +156,7 @@ class KVConnectorBase_V1(ABC): return @abstractmethod - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -156,9 +167,9 @@ class KVConnectorBase_V1(ABC): **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -168,7 +179,7 @@ class KVConnectorBase_V1(ABC): Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -177,16 +188,21 @@ class KVConnectorBase_V1(ABC): pass @abstractmethod - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """ - Start saving a layer of KV cache from vLLM's paged buffer + Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -222,6 +238,40 @@ class KVConnectorBase_V1(ABC): """ return None, None + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + + Returns: + Set of block IDs that encountered load errors. + Empty set if no load errors occurred. + + Notes: + - Applies to both sync- and async-loading requests. 
+ - Async loading: failed blocks may be reported in any forward pass + up to and including the pass where the request ID is returned by + `get_finished()`. Even if failures occur, the request must still + be reported via `get_finished()`, and the failed block IDs must + appear here no later than that same pass. + - Sync loading: failed blocks should be reported in the forward + pass in which they are detected. + """ + return set() + + def shutdown(self): + """ + Shutdown the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + return None + + def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: + """ + Get the KV connector stats collected during the last interval. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -231,11 +281,11 @@ self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally @@ -243,18 +293,28 @@ class KVConnectorBase_V1(ABC): Returns: A tuple with the following elements: - - The number of tokens that can be loaded from the + - An optional number of tokens that can be loaded from the external KV cache beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. - `True` if external KV cache tokens will be loaded asynchronously (between scheduler steps). Must be 'False' if the first element is 0. + + Notes: + The connector should only consider the largest prefix of prompt + tokens for which KV cache is actually available at the time of the + call. If the cache cannot be loaded for some tokens (e.g., due to + connectivity issues or eviction), those tokens must not be taken + into account. """ pass @abstractmethod - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. @@ -274,7 +334,8 @@ class KVConnectorBase_V1(ABC): @abstractmethod def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -302,7 +363,11 @@ class KVConnectorBase_V1(ABC): block_ids: list[int], ) -> tuple[bool, Optional[dict[str, Any]]]: """ - Called when a request has finished, before its blocks are freed. + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assume responsibility for freeing the blocks + asynchronously by returning True. Returns: True if the request is being saved/sent asynchronously and blocks @@ -313,9 +378,17 @@ class KVConnectorBase_V1(ABC): """ return False, None + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call.
+ """ + return () + @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: """ Get the required KV cache layout for this connector. Args: @@ -327,6 +400,30 @@ class KVConnectorBase_V1(ABC): """ if cls is KVConnectorBase_V1: - raise TypeError("get_required_kvcache_layout should not be called " - "on the abstract base class") + raise TypeError( + "get_required_kvcache_layout should not be called " + "on the abstract base class" + ) + return None + + def get_finished_count(self) -> Optional[int]: + """ + Get the count of requests expected to complete send/receive operations + via this connector. + + Returns: + int: expected sending or receiving completion count. + """ + + return None + + @classmethod + def build_kv_connector_stats( + cls, data: Optional[dict[str, Any]] = None + ) -> Optional["KVConnectorStats"]: + """ + KVConnectorStats resolution method. This method allows dynamically + registered connectors to return their own KVConnectorStats object, + which can implement custom aggregation logic on the data dict. + """ return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e838ac2499c04..b50cc3ab30fa9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -7,7 +7,10 @@ from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -21,7 +24,6 @@ logger = init_logger(__name__) class LMCacheConnectorV1(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) @@ -29,8 +31,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -41,9 +42,9 @@ class LMCacheConnectorV1(KVConnectorBase_V1): **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ self._lmcache_engine.start_load_kv(forward_context, **kwargs) @@ -52,7 +53,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. 
Args: @@ -60,22 +61,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1): """ self._lmcache_engine.wait_for_layer_load(layer_name) - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """ - Start saving the a layer of KV cache from vLLM's paged buffer + Start saving the a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. """ - self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, - **kwargs) + self._lmcache_engine.save_kv_layer( + layer_name, kv_layer, attn_metadata, **kwargs + ) def wait_for_save(self): """ @@ -110,34 +117,35 @@ class LMCacheConnectorV1(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens), False + request, num_computed_tokens + ), False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. """ - self._lmcache_engine.update_state_after_alloc(request, - num_external_tokens) + self._lmcache_engine.update_state_after_alloc(request, num_external_tokens) def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: """ Build the connector metadata for this step. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py new file mode 100644 index 0000000000000..879cc9a23581a --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, field +from typing import Any, Optional, Union + +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class KVConnectorStats: + """ + Base class for KV Connector Stats, a container for transfer performance + metrics or otherwise important telemetry from the connector. + All sub-classes need to be serializable as stats are sent from worker to + logger process. 
+ """ + + data: dict[str, Any] = field(default_factory=dict) + + def reset(self): + """Reset the stats, clear the state.""" + raise NotImplementedError + + def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats": + """ + Aggregate stats with another `KVConnectorStats` object. + """ + raise NotImplementedError + + def reduce(self) -> dict[str, Union[int, float]]: + """ + Reduce the observations collected during a time interval to one or + more representative values (eg avg/median/sum of the series). + This is meant to be called by the logger to produce a summary of the + stats for the last time interval. + """ + raise NotImplementedError + + def is_empty(self) -> bool: + """Return True if the stats are empty.""" + raise NotImplementedError + + +class KVConnectorLogging: + def __init__(self, kv_tranfer_config: KVTransferConfig): + # This should be called on frontend process. + assert not has_kv_transfer_group() + # Instantiate the connector's stats class. + if kv_tranfer_config and kv_tranfer_config.kv_connector: + self.connector_cls = KVConnectorFactory.get_connector_class( + kv_tranfer_config + ) + self.reset() + + def reset(self): + self.transfer_stats_accumulator: Optional[KVConnectorStats] = None + + def observe(self, transfer_stats_data: dict[str, Any]): + # Should not be called when a KVConnector is not configured. + assert self.connector_cls is not None + # Called periodically when connector syncs with the scheduler. + # Note that this is not the same as the logging interval. + # We expect transfer_stats_data to be aggregated across all workers and + # consist of observations from a single connector or a MultiConnector. + transfer_stats = self.connector_cls.build_kv_connector_stats( + transfer_stats_data + ) + if transfer_stats is None: + logger.warning_once( + "The connector %s is collecting stats but " + "does not implement the " + "`build_kv_connector_stats` method. " + "Stats will not be logged.", + self.connector_cls, + ) + return + + if self.transfer_stats_accumulator is None: + self.transfer_stats_accumulator = transfer_stats + else: + # Accumulate last interval stats. + self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate( + transfer_stats + ) + + def log(self, log_fn=logger.info): + """Log transfer metrics periodically, similar to throughput logging""" + if ( + self.transfer_stats_accumulator + and not self.transfer_stats_accumulator.is_empty() + ): + # Produce a single cumulative stats object for the last time + # interval from the recorded observations. 
+ xfer_metrics = self.transfer_stats_accumulator.reduce() + xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items()) + log_fn("KV Transfer metrics: %s", xfer_metrics_str) + + # Reset metrics for next interval + self.reset() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d3f6a226dc72c..e48d4ccd1d6c0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,24 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Iterable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.config import VllmConfig +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger -from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -30,6 +36,42 @@ class MultiKVConnectorMetadata(KVConnectorMetadata): extra_async_saves: Optional[dict[str, int]] = None +@dataclass +class MultiKVConnectorStats(KVConnectorStats): + """ + Maintain a dict of KVConnectorStats objects, one for each connector. + This is used to aggregate the stats from all connectors separately. + """ + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + for connector_id, stats in other.data.items(): + if connector_id not in self.data: + self[connector_id] = stats + else: + assert isinstance(stats, type(self.data[connector_id])) + self[connector_id] = self[connector_id].aggregate(stats) + return self + + def reset(self): + for stats in self.data.values(): + stats.reset() + + def reduce(self) -> dict[str, Any]: + # TODO (NickLucche) Adjust for logging on separate lines + return { + connector_id: stats.reduce() for connector_id, stats in self.data.items() + } + + def is_empty(self) -> bool: + return all(stats.is_empty() for stats in self.data.values()) + + def __getitem__(self, connector_id: str) -> KVConnectorStats: + return self.data[connector_id] + + def __setitem__(self, connector_id: str, stats: KVConnectorStats): + self.data[connector_id] = stats + + class MultiConnector(KVConnectorBase_V1): """ A wrapper for using multiple KVConnectors at the same time. 
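To illustrate how MultiKVConnectorStats, added above, keeps one stats object per wrapped connector and delegates aggregation to it, here is a small sketch. The CountingStats subclass, its num_xfers counter, and the "nixl" key are hypothetical stand-ins invented for this example; only KVConnectorStats and MultiKVConnectorStats come from this diff.

from typing import Union

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
    MultiKVConnectorStats,
)


class CountingStats(KVConnectorStats):
    # Hypothetical per-connector stats: just counts finished transfers.
    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
        self.data["num_xfers"] = self.data.get("num_xfers", 0) + other.data.get(
            "num_xfers", 0
        )
        return self

    def reduce(self) -> dict[str, Union[int, float]]:
        return dict(self.data)

    def is_empty(self) -> bool:
        return self.data.get("num_xfers", 0) == 0


# One entry per wrapped connector, keyed by a connector id chosen by the caller.
interval_a = MultiKVConnectorStats(data={"nixl": CountingStats(data={"num_xfers": 3})})
interval_b = MultiKVConnectorStats(data={"nixl": CountingStats(data={"num_xfers": 2})})

interval_a.aggregate(interval_b)
print(interval_a["nixl"].data)  # {'num_xfers': 5}

Because MultiKVConnectorStats.aggregate asserts isinstance(stats, type(self.data[connector_id])), mixing different stats types under the same connector id trips the assertion, which is the intended guard.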
@@ -43,17 +85,21 @@ class MultiConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._connectors: list[KVConnectorBase_V1] = [] + self._ktc_kv_transfer_config = [] ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") + "connectors" + ) assert ktcs is not None for ktc in ktcs: temp_config = copy.copy(vllm_config) - engine_id = ktc.get("engine_id", - vllm_config.kv_transfer_config.engine_id) + engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id) temp_config.kv_transfer_config = KVTransferConfig( - **ktc, engine_id=engine_id) + **ktc, engine_id=engine_id + ) self._connectors.append( - KVConnectorFactory.create_connector(temp_config, role)) + KVConnectorFactory.create_connector(temp_config, role) + ) + self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to # load the request from (if any). @@ -72,12 +118,10 @@ class MultiConnector(KVConnectorBase_V1): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: - self._extra_async_saves.update( - connector_metadata.extra_async_saves) + self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) @@ -85,11 +129,23 @@ class MultiConnector(KVConnectorBase_V1): for c in self._connectors: c.clear_connector_metadata() + def shutdown(self): + exception: Optional[Exception] = None + for c in self._connectors: + try: + c.shutdown() + except Exception as e: + logger.exception( + "Exception during connector %s shutdown.", c.__class__.__name__ + ) + exception = e + if exception: + raise exception + # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: for c in self._connectors: c.start_load_kv(forward_context, **kwargs) @@ -97,8 +153,13 @@ class MultiConnector(KVConnectorBase_V1): for c in self._connectors: c.wait_for_layer_load(layer_name) - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: for c in self._connectors: c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) @@ -133,6 +194,12 @@ class MultiConnector(KVConnectorBase_V1): return finished_sending or None, finished_recving or None + def get_block_ids_with_load_errors(self) -> set[int]: + agg_block_ids: set[int] = set() + for c in self._connectors: + agg_block_ids |= c.get_block_ids_with_load_errors() + return agg_block_ids + # ============================== # Scheduler-side methods # ============================== @@ -140,11 +207,16 @@ class MultiConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> 
tuple[Optional[int], bool]: to_return = (0, False) for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) + # If there is a connector still looking up the matches, + # we return None to indicate that we are not done yet. + if toks is None: + return (None, False) # The first connector that has new matched tokens will be assigned # to this request. if to_return[0] == 0 and toks > 0: @@ -152,27 +224,27 @@ class MultiConnector(KVConnectorBase_V1): to_return = (toks, load_async) return to_return - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - chosen_connector = self._requests_to_connector.get( - request.request_id, -1) + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + chosen_connector = self._requests_to_connector.get(request.request_id, -1) empty_blocks = blocks.new_empty() for i, c in enumerate(self._connectors): if i == chosen_connector: # Forward call to the chosen connector (if any). - c.update_state_after_alloc(request, blocks, - num_external_tokens) + c.update_state_after_alloc(request, blocks, num_external_tokens) else: # Call with empty blocks for other connectors. c.update_state_after_alloc(request, empty_blocks, 0) def build_connector_meta( - self, - scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - metadata = MultiKVConnectorMetadata(metadata=tuple( - c.build_connector_meta(scheduler_output) - for c in self._connectors)) + self, scheduler_output: SchedulerOutput + ) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata( + metadata=tuple( + c.build_connector_meta(scheduler_output) for c in self._connectors + ) + ) if self._extra_async_saves: metadata.extra_async_saves = self._extra_async_saves self._extra_async_saves = {} @@ -198,7 +270,8 @@ class MultiConnector(KVConnectorBase_V1): # TODO we can probably change this to merge the dicts here, # checking for key clashes. raise RuntimeError( - "Only one connector can produce KV transfer params") + "Only one connector can produce KV transfer params" + ) kv_txfer_params = txfer_params if async_saves > 1: self._extra_async_saves[request.request_id] = async_saves - 1 @@ -208,9 +281,12 @@ class MultiConnector(KVConnectorBase_V1): return async_saves > 0, kv_txfer_params + def take_events(self) -> Iterable["KVCacheEvent"]: + for c in self._connectors: + yield from c.take_events() + @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: """ Get the required KV cache layout for this connector. Args: @@ -221,23 +297,49 @@ class MultiConnector(KVConnectorBase_V1): None if the connector does not require a specific layout. 
""" ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") + "connectors" + ) assert ktcs is not None layouts: set[str] = set() temp_vllm_config = copy.copy(vllm_config) for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config - connector_cls = KVConnectorFactory.get_connector_class( - kv_transfer_config) - required_kvcache_layout = ( - connector_cls.get_required_kvcache_layout(temp_vllm_config)) + connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout( + temp_vllm_config + ) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) if len(layouts) > 1: - raise ValueError(f"KV cache layout mismatch: " - f"found {len(layouts)} different layouts " - f"({', '.join(layouts) })." - f"All connectors must use the same layout.") + raise ValueError( + f"KV cache layout mismatch: " + f"found {len(layouts)} different layouts " + f"({', '.join(layouts)})." + f"All connectors must use the same layout." + ) return next(iter(layouts), None) + + @classmethod + def build_kv_connector_stats( + cls, data: Optional[dict[str, Any]] = None + ) -> Optional[KVConnectorStats]: + return ( + MultiKVConnectorStats(data=data) + if data is not None + else MultiKVConnectorStats() + ) + + def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]: + # Group connector stats by connector type. + stats_by_connector: Optional[MultiKVConnectorStats] = None + for c in self._connectors: + stats = c.get_kv_connector_stats() + if stats is None: + continue + if stats_by_connector is None: + # Lazy init to allow optional return value. + stats_by_connector = MultiKVConnectorStats() + stats_by_connector[c.__class__.__name__] = stats + return stats_by_connector diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6608d2a4a9e09..365d1a1ff280c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy import logging import math +import os import queue import threading import time @@ -11,28 +13,36 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec +import numpy as np import torch import zmq from vllm import envs -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + CopyBlocksOp, + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) 
from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import RequestStatus if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -50,29 +60,45 @@ logger = init_logger(__name__) # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: from nixl._api import nixl_agent as NixlWrapper + from nixl._bindings import nixlXferTelemetry + logger.info("NIXL is available") except ImportError: logger.warning("NIXL is not available") NixlWrapper = None + nixlXferTelemetry = None -# Supported xPUs and types of kv transfer buffer. -# {xPU: tuple of supported kv buffer types} -_NIXL_SUPPORTED_XPUS = { - "cuda": ("cuda", ), - "tpu": ("cpu", ), +try: + from nixl._api import nixl_agent_config +except ImportError: + nixl_agent_config = None + logger.warning("NIXL agent config is not available") + +# Supported platforms and types of kv transfer buffer. +# {device: tuple of supported kv buffer types} +_NIXL_SUPPORTED_DEVICE = { + "cuda": ( + "cuda", + "cpu", + ), + "tpu": ("cpu",), + "xpu": ("cpu",), } +# support for oot platform by providing mapping in current_platform +_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int - block_len: int + block_lens: list[int] attn_backend_name: str kv_cache_layout: str @@ -88,11 +114,12 @@ class ReqMeta: class NixlConnectorMetadata(KVConnectorMetadata): - def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_send: dict[ReqId, float] = {} + self.reqs_in_batch: set[ReqId] = set() + self.reqs_not_processed: set[ReqId] = set() def add_new_req( self, @@ -120,20 +147,19 @@ class NixlConnectorMetadata(KVConnectorMetadata): class NixlConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[NixlConnectorScheduler] = \ + self.connector_scheduler: Optional[NixlConnectorScheduler] = ( NixlConnectorScheduler(vllm_config, self.engine_id) + ) self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker( - vllm_config, self.engine_id) + self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id) ############################################################ # Class Methods @@ -141,8 +167,10 @@ class NixlConnector(KVConnectorBase_V1): @classmethod def get_required_kvcache_layout(cls, vllm_config: VllmConfig): if vllm_config.model_config is None: - logger.warning_once("Unable to detect current VLLM config. 
" - "Fallback to default kv cache layout.") + logger.warning_once( + "Unable to detect current VLLM config. " + "Fallback to default kv cache layout." + ) return None use_mla = vllm_config.model_config.use_mla if use_mla: @@ -150,8 +178,9 @@ class NixlConnector(KVConnectorBase_V1): # as the layout should not matter in that case, # which fallback to the default behavior. return None - logger.info_once("NixlConnector setting KV cache " - "layout to HND for better xfer performance.") + logger.info_once( + "NixlConnector setting KV cache layout to HND for better xfer performance." + ) return "HND" ############################################################ @@ -159,18 +188,20 @@ class NixlConnector(KVConnectorBase_V1): ############################################################ def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[Optional[int], bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) + request, blocks, num_external_tokens + ) def build_connector_meta( self, @@ -198,14 +229,26 @@ class NixlConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + assert self.connector_worker is not None + return self.connector_worker.get_kv_connector_stats() + + @classmethod + def build_kv_connector_stats( + cls, data: Optional[dict[str, Any]] = None + ) -> Optional[KVConnectorStats]: + return ( + NixlKVConnectorStats(data=data) + if data is not None + else NixlKVConnectorStats() + ) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) @@ -214,18 +257,26 @@ class NixlConnector(KVConnectorBase_V1): """NixlConnector does not do layerwise saving.""" pass - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: """NixlConnector does not save explicitly.""" pass def wait_for_save(self): assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) - if self.connector_worker.use_host_buffer and \ - self.connector_worker.copy_blocks: + if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks: 
self.connector_worker.save_kv_to_host(self._connector_metadata) + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + class NixlConnectorScheduler: """Implementation of Scheduler side methods""" @@ -236,11 +287,11 @@ class NixlConnectorScheduler: self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.side_channel_port = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) - self.use_host_buffer = \ - vllm_config.kv_transfer_config.kv_buffer_device == "cpu" + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv/send. @@ -250,10 +301,14 @@ class NixlConnectorScheduler: self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} + self._reqs_in_batch: set[ReqId] = set() + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. + self._reqs_not_processed: set[ReqId] = set() def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: """ For remote prefill, pull all prompt blocks from remote asynchronously relative to engine execution. @@ -273,7 +328,9 @@ class NixlConnectorScheduler: logger.debug( "NIXLConnector get_num_new_matched_tokens: " "num_computed_tokens=%s, kv_transfer_params=%s", - num_computed_tokens, params) + num_computed_tokens, + params, + ) if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. @@ -284,18 +341,22 @@ class NixlConnectorScheduler: # No remote prefill for this request. return 0, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): params = request.kv_transfer_params logger.debug( "NIXLConnector update_state_after_alloc: " "num_external_tokens=%s, kv_transfer_params=%s", - num_external_tokens, params) + num_external_tokens, + params, + ) if not params: return + + if params.get("do_remote_decode"): + self._reqs_in_batch.add(request.request_id) if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. @@ -308,25 +369,33 @@ class NixlConnectorScheduler: # block is not overwritten; and it will be safe to skip saving them # to host xfer buffer. if block_ids: - self._reqs_need_save[request.request_id] = \ - (request, block_ids) + self._reqs_need_save[request.request_id] = (request, block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): - if all(p in params for p in ("remote_engine_id", "remote_host", - "remote_port")): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. 
- local_block_ids = (blocks.get_unhashed_block_ids() - if num_external_tokens > 0 else []) + local_block_ids = ( + blocks.get_unhashed_block_ids() + if num_external_tokens > 0 + else [] + ) # Get unhashed blocks to pull from remote. self._reqs_need_recv[request.request_id] = ( - request, local_block_ids) + request, + local_block_ids, + ) else: logger.warning( "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", params) + "request will not utilize KVTransfer", + params, + ) else: assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. @@ -345,6 +414,8 @@ class NixlConnectorScheduler: request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, + load_remote_cache=True, + save_to_host=False, ) for req_id, (req, block_ids) in self._reqs_need_save.items(): @@ -358,10 +429,14 @@ class NixlConnectorScheduler: ) meta.reqs_to_send = self._reqs_need_send + meta.reqs_in_batch = self._reqs_in_batch + meta.reqs_not_processed = self._reqs_not_processed # Clear the list once workers start the transfers self._reqs_need_recv.clear() self._reqs_need_save.clear() + self._reqs_in_batch = set() + self._reqs_not_processed = set() self._reqs_need_send = {} return meta @@ -375,11 +450,14 @@ class NixlConnectorScheduler: Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. """ + from vllm.v1.request import RequestStatus params = request.kv_transfer_params logger.debug( - "NIXLConnector request_finished, request_status=%s, " - "kv_transfer_params=%s", request.status, params) + "NIXLConnector request_finished, request_status=%s, kv_transfer_params=%s", + request.status, + params, + ) if not params: return False, None @@ -394,8 +472,12 @@ class NixlConnectorScheduler: params["do_remote_prefill"] = False return False, None - if (not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + if not params.get("do_remote_decode"): + return False, None + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + # Also include the case of a P/D Prefill request with immediate + # block free (eg abort). Stop tracking this request. + self._reqs_not_processed.add(request.request_id) return False, None # TODO: check whether block_ids actually ever be 0. If not we could @@ -404,8 +486,9 @@ class NixlConnectorScheduler: if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[request.request_id] = time.perf_counter( - ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) return delay_free_blocks, dict( do_remote_prefill=True, @@ -414,7 +497,8 @@ class NixlConnectorScheduler: remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) class NixlConnectorWorker: @@ -431,8 +515,24 @@ class NixlConnectorWorker: self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size + self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config( + "backends", ["UCX"] + ) + # TODO temporary, once nixl allows for telemetry flag in config + # (next release), we can remove this env var. + os.environ["NIXL_TELEMETRY_ENABLE"] = "1" # Agent. 
- self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] + if nixl_agent_config is None: + config = None + else: + config = ( + nixl_agent_config(backends=self.nixl_backends) + if len(non_ucx_backends) > 0 + else nixl_agent_config(num_threads=8) + ) + + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -441,9 +541,10 @@ class NixlConnectorWorker: # base port (which is sent in the KVTransferParams). # Each TP rank listens/queries on the base_port + tp_rank. self.side_channel_port: int = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) # Metadata. self.engine_id: EngineId = engine_id @@ -454,29 +555,33 @@ class NixlConnectorWorker: # KV Caches and nixl tracking data. self.device_type = current_platform.device_type - self.kv_buffer_device: str = \ - vllm_config.kv_transfer_config.kv_buffer_device - if self.device_type not in _NIXL_SUPPORTED_XPUS: + self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device + if self.device_type not in _NIXL_SUPPORTED_DEVICE: raise RuntimeError(f"{self.device_type} is not supported.") - elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[ - self.device_type]: + elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " - "is not supported.") + "is not supported." + ) self.device_kv_caches: dict[str, torch.Tensor] = {} # cpu kv buffer for xfer - # used when xPU memory can not be registered under nixl + # used when device memory can not be registered under nixl self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.use_host_buffer = self.kv_buffer_device == "cpu" - if self.kv_buffer_device == "cuda": - self.nixl_memory_type = "VRAM" - elif self.kv_buffer_device == "cpu": - self.nixl_memory_type = "DRAM" - else: + # support for oot platform which can't register nixl memory + # type based on kv_buffer_device + self.nixl_memory_type = current_platform.get_nixl_memory_type() + if self.nixl_memory_type is None: + if self.kv_buffer_device == "cuda": + self.nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + self.nixl_memory_type = "DRAM" + if self.nixl_memory_type is None: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " - "is not supported.") + "is not supported." + ) # Note: host xfer buffer ops when use_host_buffer is True self.copy_blocks: Optional[CopyBlocksOp] = None @@ -506,6 +611,8 @@ class NixlConnectorWorker: self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} + # Set of requests that have been part of a batch, regardless of status. + self._reqs_to_process: set[ReqId] = set() # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: Optional[threading.Thread] = None @@ -513,7 +620,8 @@ class NixlConnectorWorker: self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. 
max_workers=1, - thread_name_prefix="vllm-nixl-handshake-initiator") + thread_name_prefix="vllm-nixl-handshake-initiator", + ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} # Protects _handshake_futures and _remote_agents. @@ -530,16 +638,17 @@ class NixlConnectorWorker: self.block_window_per_layer: list[Optional[int]] = [] self.use_mla = self.model_config.use_mla - backend = get_attn_backend(self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.use_mla) + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) self.backend_name = backend.get_name() attn_backend = backend_name_to_enum(self.backend_name) - self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 - self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1 + self._use_flashinfer = attn_backend == _Backend.FLASHINFER + self._use_pallas = attn_backend == _Backend.PALLAS self.kv_cache_layout = get_kv_cache_layout() logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) @@ -548,17 +657,15 @@ class NixlConnectorWorker: # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) - - def __del__(self): - """Cleanup background threads on destruction.""" - self._handshake_initiation_executor.shutdown(wait=False) - if self._nixl_handshake_listener_t: - self._nixl_handshake_listener_t.join(timeout=0) + self.xfer_stats = NixlKVConnectorStats() @staticmethod - def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, base_port: int, - tp_rank: int): + def _nixl_handshake_listener( + metadata: NixlAgentMetadata, + ready_event: threading.Event, + base_port: int, + tp_rank: int, + ): """Background thread for getting new NIXL handshakes.""" # NOTE(rob): this is a simple implementation. We will move # to a better approach via HTTP endpoint soon. @@ -566,8 +673,7 @@ class NixlConnectorWorker: encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", - str(size_in_bytes)) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes)) # Listen for new requests for metadata. host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST @@ -578,8 +684,7 @@ class NixlConnectorWorker: while True: identity, _, msg = sock.recv_multipart() if msg != GET_META_MSG: - logger.warning( - "Connection listener got unexpected message %s", msg) + logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) def _nixl_handshake( @@ -602,8 +707,9 @@ class NixlConnectorWorker: tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug("Querying metadata on path: %s at remote rank %s", path, - p_remote_rank) + logger.debug( + "Querying metadata on path: %s at remote rank %s", path, p_remote_rank + ) # Send query for the request. 
with zmq_ctx(zmq.REQ, path) as sock: @@ -612,27 +718,32 @@ class NixlConnectorWorker: decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) + logger.debug( + "NIXL handshake: get metadata took: %s", got_metadata_time - start_time + ) # Ensure engine id matches. if metadata.engine_id != expected_engine_id: - raise RuntimeError(f"Remote NIXL agent engine ID mismatch. " - f"Expected {expected_engine_id}," - f"received {metadata.engine_id}.") + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) # Register Remote agent. - remote_agent_name = self.add_remote_agent(metadata, p_remote_rank, - remote_tp_size) + remote_agent_name = self.add_remote_agent( + metadata, p_remote_rank, remote_tp_size + ) setup_agent_time = time.perf_counter() - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + logger.debug( + "NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) # Remote rank -> agent name. return {p_remote_rank: remote_agent_name} - def initialize_host_xfer_buffer( - self, kv_caches: dict[str, torch.Tensor]) -> None: + def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ Initialize transfer buffer in CPU mem for accelerators NOT directly supported by NIXL (e.g., tpu) @@ -642,9 +753,9 @@ class NixlConnectorWorker: for layer_name, kv_cache in kv_caches.items(): kv_shape = kv_cache.shape kv_dtype = kv_cache.dtype - xfer_buffers[layer_name] = torch.empty(kv_shape, - dtype=kv_dtype, - device="cpu") + xfer_buffers[layer_name] = torch.empty( + kv_shape, dtype=kv_dtype, device="cpu" + ) except MemoryError as e: logger.error("NIXLConnectorWorker gets %s.", e) raise @@ -653,17 +764,25 @@ class NixlConnectorWorker: def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """Assign copy (d2h, h2d) operations when host buffer is used.""" + # Set a no-op if the host buffer is not cpu. + if self.kv_buffer_device != "cpu": + return assert self.use_host_buffer self.copy_blocks = copy_operation - def _background_nixl_handshake(self, req_id: str, - remote_engine_id: EngineId, meta: ReqMeta): + def _background_nixl_handshake( + self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta + ): # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: fut = self._handshake_initiation_executor.submit( - self._nixl_handshake, meta.remote_host, meta.remote_port, - meta.tp_size, remote_engine_id) + self._nixl_handshake, + meta.remote_host, + meta.remote_port, + meta.tp_size, + remote_engine_id, + ) self._handshake_futures[remote_engine_id] = fut def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): @@ -690,24 +809,27 @@ class NixlConnectorWorker: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( f"host_buffer: {len(self.host_xfer_buffers)}, " - f"kv_caches: {len(kv_caches)}") + f"kv_caches: {len(kv_caches)}" + ) xfer_buffers = self.host_xfer_buffers else: xfer_buffers = kv_caches assert not self.host_xfer_buffers, ( "host_xfer_buffer should not be initialized when " - f"kv_buffer_device is {self.kv_buffer_device}") + f"kv_buffer_device is {self.kv_buffer_device}" + ) logger.info( "Registering KV_Caches. 
use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, - self.use_host_buffer) + "use_host_buffer: %s", + self.use_mla, + self.kv_buffer_device, + self.use_host_buffer, + ) caches_data = [] # With hybrid allocator, layers can share a kv cache tensor seen_base_addresses = [] - xfer_buffers = (self.host_xfer_buffers - if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -715,15 +837,15 @@ class NixlConnectorWorker: # are non-contiguous (it's not locally guaranteed that they will be) # Disadvantage is that the encoded NixlAgentMetadata is now larger # (roughly 8KB vs 5KB). - # Conversely for FlashInfer, K and V are transferred in the same tensor + # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = not (self.use_mla or self._use_pallas_v1 - or self._use_flashinfer) + split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) tensor_size_bytes = None + # Enable different block lengths for different layers when MLA is used. + self.block_len_per_layer = list[int]() + self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): - cache_list = cache_or_caches if split_k_and_v else [ - cache_or_caches - ] + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] for cache in cache_list: base_addr = cache.data_ptr() @@ -737,62 +859,103 @@ class NixlConnectorWorker: tensor_size_bytes = curr_tensor_size_bytes self.num_blocks = cache.shape[0] - assert tensor_size_bytes == curr_tensor_size_bytes, \ - "All kv cache tensors must have the same size" + assert cache.shape[0] == self.num_blocks, ( + "All kv cache tensors must have the same number of blocks" + ) + + self.block_len_per_layer.append( + curr_tensor_size_bytes // self.num_blocks + ) + self.slot_size_per_layer.append( + self.block_len_per_layer[-1] // self.block_size + ) + + if not self.use_mla: + # Different kv cache shape is not supported by HeteroTP + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All kv cache tensors must have the same size" + ) caches_data.append( - (base_addr, tensor_size_bytes, self.tp_rank, "")) + (base_addr, curr_tensor_size_bytes, self.tp_rank, "") + ) + + logger.debug( + "Different block lengths collected: %s", set(self.block_len_per_layer) + ) + assert len(self.block_len_per_layer) == len(seen_base_addresses) + assert self.num_blocks != 0 self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - descs = self.nixl_wrapper.get_reg_descs(caches_data, - self.nixl_memory_type) + descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) + self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends) logger.debug("Done registering descs") self._registered_descs.append(descs) - assert tensor_size_bytes is not None - assert self.num_blocks != 0 - assert tensor_size_bytes % self.num_blocks == 0 - self.block_len = tensor_size_bytes // self.num_blocks - self.slot_size_bytes = self.block_len // self.block_size - if self._use_flashinfer: - assert self.slot_size_bytes % 2 == 0 - self.slot_size_bytes /= 2 self.device_kv_caches = kv_caches 
self.dst_num_blocks[self.engine_id] = self.num_blocks + if self._use_flashinfer: + for i in range(len(self.slot_size_per_layer)): + assert self.slot_size_per_layer[i] % 2 == 0 + self.slot_size_per_layer[i] //= 2 + + # NOTE (NickLucche) When FlashInfer is used, memory is registered + # with joint KV for each block. This minimizes the overhead in + # registerMem allowing faster descs queries. In order to be able to + # split on kv_heads dim as required by heterogeneous TP, one must + # be able to index K/V separately. Hence we double the number + # of 'virtual' regions here and halve `block_len` below. + self.num_regions *= 2 # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in seen_base_addresses: + for i, base_addr in enumerate(seen_base_addresses): + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to # select agent_meta.num_blocks instead of self.num_blocks for # local descr, and that makes handling regular flow less clean. for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len + block_offset = block_id * self.block_len_per_layer[i] addr = base_addr + block_offset # (addr, len, device id) - # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len, self.tp_rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.tp_rank) + blocks_data.append((addr, kv_block_len, self.tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, - self.nixl_memory_type) + if self._use_flashinfer: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len_per_layer[i] + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.tp_rank)) + logger.debug( + "Created %s blocks for src engine %s and rank %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + ) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + "NIXL_INIT_AGENT", descs + ) - # TODO(mgoin): Hybrid memory allocator is currently diabled for + # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. 
if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) + + assert isinstance( + self.vllm_config.model_config.hf_text_config, Llama4TextConfig + ) llama4_config = self.vllm_config.model_config.hf_text_config no_rope_layers = llama4_config.no_rope_layers chunk_size = llama4_config.attention_chunk_size @@ -803,8 +966,10 @@ class NixlConnectorWorker: is_local_attention = no_rope_layers[layer_idx] != 0 block_window = chunk_block_size if is_local_attention else None self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) + logger.debug( + "Llama 4 block window per layer mapping: %s", + self.block_window_per_layer, + ) assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. @@ -813,35 +978,39 @@ class NixlConnectorWorker: agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, - block_len=self.block_len, + block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, - kv_cache_layout=self.kv_cache_layout) + kv_cache_layout=self.kv_cache_layout, + ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, args=(metadata, ready_event, self.side_channel_port, self.tp_rank), daemon=True, - name="nixl_handshake_listener") + name="nixl_handshake_listener", + ) self._nixl_handshake_listener_t.start() ready_event.wait() # Wait for listener ZMQ socket to be ready. - def add_remote_agent(self, - nixl_agent_meta: NixlAgentMetadata, - remote_tp_rank: int = 0, - remote_tp_size: int = 1) -> str: + def add_remote_agent( + self, + nixl_agent_meta: NixlAgentMetadata, + remote_tp_rank: int = 0, + remote_tp_size: int = 1, + ) -> str: """ Add the remote NIXL agent and prepare the descriptors for reading cache blocks from remote. In particular, handle both homogeneous and heterogeneous TP. The former - requires local rank_i to read from remote rank_i. - The latter, assuming D.world_size > P.world_size, requires that two or + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or more local TP worker share the xfer from a single TP worker. - Here's an example: + Here's an example (non-MLA case): rank_offset p_remote_tp_rank - (kv split no) + (kv split no) -------------------------------- 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] / @@ -854,19 +1023,19 @@ class NixlConnectorWorker: Decoder TP workers Prefix TP workers (world_size=4) (world_size=2) - tp_ratio = 4 // 2 = 2 - - Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. - Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio first heads from all the slots of all the blocks. 
D-Worker1 will do the same, but reading the second split - along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. - + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 so that the whole cache is shared by "tp_ratio" D TP workers. - """ # noqa: E501 + """ # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, {}): @@ -880,43 +1049,55 @@ class NixlConnectorWorker: assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) + nixl_agent_meta.agent_metadata + ) # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. - tp_ratio = divide(self._tp_size[self.engine_id], - self._tp_size[engine_id]) + tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert not self._use_pallas_v1 or tp_ratio == 1, \ - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + assert not self._use_pallas or tp_ratio == 1, ( + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + ) # Handle tp_size>num_kv_heads: replicate KV cache. total_num_kv_heads = self.model_config.get_total_num_kv_heads() is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or is_kv_replicated: - # With MLA the only difference is in the number of blocks. - remote_block_size = nixl_agent_meta.block_len // ( - self.slot_size_bytes) - assert self.block_len == nixl_agent_meta.block_len + # With replicated KV cache, only the number of blocks can differ. + assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( + "KV cache sizes must match between P and D when replicated" + ) + remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) else: - remote_block_size = nixl_agent_meta.block_len // ( - self.slot_size_bytes * tp_ratio) + # When MLA is not used, this is a list of the same block length + for block_len in nixl_agent_meta.block_lens: + assert block_len == remote_block_len, ( + "All remote layers must have the same block size" + ) + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) if self._use_flashinfer: - # Account for joint KV in FlashInfer. + # With flashinfer, KV are sent in the same message. remote_block_size //= 2 if tp_ratio > 1: # Heterogeneous TP expects same kv_cache_layout. assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout + if self.device_type == "xpu": + raise ValueError("Heterogeneous TP is not supported on XPU") - assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( + assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." 
) assert self.block_size == remote_block_size, ( - "Remote P worker with different block size is not supported " - f"{self.block_size=} {remote_block_size=}") + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}" + ) # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -929,32 +1110,47 @@ class NixlConnectorWorker: # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - # Only register the remote's descriptors if current rank pulls from it. - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ - if not (self.use_mla or is_kv_replicated) else 0 + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. - for base_addr in nixl_agent_meta.kv_caches_base_addr: + for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + rank_offset = ( + self.tp_rank % tp_ratio * kv_block_len + if not (self.use_mla or is_kv_replicated) + else 0 + ) for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_len + block_offset = block_id * nixl_agent_meta.block_lens[i] # For each block, grab the heads chunk belonging to rank_i # of size remote_nheads // tp_ratio, which correspond to # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, remote_tp_rank)) + + if self._use_flashinfer: + # With FlashInfer index V separately to allow head splitting. + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_lens[i] + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_lens[i] // 2 + blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and " - "local rank %s", len(blocks_data), engine_id, remote_tp_rank, - self.tp_rank) + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, - self.nixl_memory_type) - self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - remote_agent_name, descs) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs + ) return remote_agent_name @@ -964,13 +1160,20 @@ class NixlConnectorWorker: assert self.copy_blocks is not None local_block_ids = meta.local_block_ids - self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, - local_block_ids, local_block_ids, "h2d") + self.copy_blocks( + self.host_xfer_buffers, + self.device_kv_caches, + local_block_ids, + local_block_ids, + "h2d", + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "synced recved kv of request[%s] to device kv buffer," - "local_block_ids: %s. 
", req_id, - ",".join(map(str, meta.local_block_ids))) + "local_block_ids: %s. ", + req_id, + ",".join(map(str, meta.local_block_ids)), + ) def save_kv_to_host(self, metadata: NixlConnectorMetadata): """copy kv from device to host buffer.""" @@ -981,11 +1184,18 @@ class NixlConnectorWorker: if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) + "local_block_ids: %s. ", + req_id, + ",".join(map(str, meta.local_block_ids)), + ) # blocking - self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, - meta.local_block_ids, meta.local_block_ids, "d2h") + self.copy_blocks( + self.device_kv_caches, + self.host_xfer_buffers, + meta.local_block_ids, + meta.local_block_ids, + "d2h", + ) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -998,8 +1208,11 @@ class NixlConnectorWorker: if len(done_sending) > 0 or len(done_recving) > 0: logger.debug( "Rank %s, get_finished: %s requests done sending " - "and %s requests done recving", self.tp_rank, - len(done_sending), len(done_recving)) + "and %s requests done recving", + self.tp_rank, + len(done_sending), + len(done_recving), + ) if self.use_host_buffer: for req_id in done_recving: @@ -1017,8 +1230,12 @@ class NixlConnectorWorker: count = self.consumer_notification_counts_by_req.pop(req_id, 0) logger.warning( "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", req_id, - count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) + "retrieved by %d decode worker(s) within %d seconds.", + req_id, + count, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + self._reqs_to_process.remove(req_id) del self._reqs_to_send[req_id] done_sending.add(req_id) @@ -1034,24 +1251,30 @@ class NixlConnectorWorker: for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) - if req_id not in self._reqs_to_send: + if ( + req_id not in self._reqs_to_send + and req_id not in self._reqs_to_process + ): logger.error( "Potentially invalid KV blocks for " "unrecognized request %s were retrieved by " - "a decode worker. They may have expired.", req_id) + "a decode worker. They may have expired.", + req_id, + ) continue self.consumer_notification_counts_by_req[req_id] += 1 # Wait all consumers (D) to be done reading before freeing. - if self.consumer_notification_counts_by_req[req_id] == int( - tp_ratio): + if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] - del self._reqs_to_send[req_id] + self._reqs_to_process.remove(req_id) + self._reqs_to_send.pop(req_id, None) return notified_req_ids def _pop_done_transfers( - self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]: + self, transfers: dict[str, list[tuple[int, float]]] + ) -> set[str]: """ Pop completed xfers by checking for DONE state. 
Args: @@ -1065,13 +1288,15 @@ class NixlConnectorWorker: for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": + # Get telemetry from NIXL + res = self.nixl_wrapper.get_xfer_telemetry(handle) + self.xfer_stats.record_transfer(res) self.nixl_wrapper.release_xfer_handle(handle) elif xfer_state == "PROC": in_progress = True continue else: - raise RuntimeError("Transfer failed with state %s", - xfer_state) + raise RuntimeError("Transfer failed with state %s", xfer_state) if not in_progress: done_req_ids.add(req_id) del transfers[req_id] @@ -1086,17 +1311,19 @@ class NixlConnectorWorker: remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - remote_engine_id, len(meta.local_block_ids), - len(meta.remote_block_ids)) + "Num local_block_ids: %s. Num remote_block_ids: %s. ", + req_id, + remote_engine_id, + len(meta.local_block_ids), + len(meta.remote_block_ids), + ) if self.use_host_buffer: self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: if remote_engine_id not in self._remote_agents: - self._background_nixl_handshake( - req_id, remote_engine_id, meta) + self._background_nixl_handshake(req_id, remote_engine_id, meta) continue # Handshake already completed, start async read xfer. @@ -1106,13 +1333,32 @@ class NixlConnectorWorker: while not self._ready_requests.empty(): self._read_blocks_for_req(*self._ready_requests.get_nowait()) + # Keep around the requests that have been part of a batch. This is + # needed because async scheduling pushes the misalignment between the + # moment in which requests expiration is set (P side) and the moment in + # which blocks are read from D. As P can now more easily lag behind D + # while processing the next batch, we make sure to only set an + # expiration for requests that have not been read from D yet. + for req_id in metadata.reqs_in_batch: + self._reqs_to_process.add(req_id) + + # Remove all requests that are not to be processed (eg aborted). + for req_id in metadata.reqs_not_processed: + self._reqs_to_process.discard(req_id) + # We should never get an abort after setting an expiry timer + assert req_id not in self._reqs_to_send + # Add to requests that are waiting to be read and track expiration. - self._reqs_to_send.update(metadata.reqs_to_send) + for req_id, expiration_time in metadata.reqs_to_send.items(): + if req_id in self._reqs_to_process: + self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, req_id) + meta.remote_engine_id, + req_id, + ) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -1120,9 +1366,13 @@ class NixlConnectorWorker: remote_block_ids=meta.remote_block_ids, ) - def _read_blocks(self, local_block_ids: list[int], - remote_block_ids: list[int], dst_engine_id: str, - request_id: str): + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + ): # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). 
# after we detect the txn is complete (which means we cannot make the @@ -1135,8 +1385,7 @@ class NixlConnectorWorker: # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self._tp_size[ - self.engine_id] // self._tp_size[dst_engine_id] + tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, @@ -1163,19 +1412,22 @@ class NixlConnectorWorker: # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - local_block_descs_ids: list[int] = [] - remote_block_descs_ids: list[int] = [] + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids) + dst_engine_id, remote_block_ids + ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids) + self.engine_id, local_block_ids + ) else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) - for layer_idx, block_window in enumerate( - self.block_window_per_layer): + local_descs_list = [] + remote_descs_list = [] + for layer_idx, block_window in enumerate(self.block_window_per_layer): # For each layer: if block_window is None: # If not chunked, we just use the @@ -1189,12 +1441,17 @@ class NixlConnectorWorker: # Get descs ids for the layer. layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_local_block_ids, layer_idx) + self.engine_id, layer_local_block_ids, layer_idx + ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_remote_block_ids, layer_idx) + dst_engine_id, layer_remote_block_ids, layer_idx + ) - local_block_descs_ids.extend(layer_local_desc_ids) - remote_block_descs_ids.extend(layer_remote_desc_ids) + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) + + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate(remote_descs_list) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -1212,21 +1469,18 @@ class NixlConnectorWorker: self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). - # TODO (NickLucche) surface xfer elapsed time - self._recving_transfers[request_id].append( - (handle, time.perf_counter())) + self._recving_transfers[request_id].append((handle, time.perf_counter())) - def _get_block_descs_ids(self, - engine_id: str, - block_ids: list[int], - layer_idx: Optional[int] = None) -> list[int]: + def _get_block_descs_ids( + self, engine_id: str, block_ids: list[int], layer_idx: Optional[int] = None + ) -> np.ndarray: """ Get the descs ids for a set of block ids. If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ if layer_idx is None: - region_ids = range(self.num_regions) + region_ids = np.arange(self.num_regions) else: assert layer_idx < self.num_layers if self.num_layers < self.num_regions: @@ -1234,20 +1488,68 @@ class NixlConnectorWorker: # the regions are organized as [K0, V0, K1, V1, ...] 
# and we select K_i and V_i assert 2 * self.num_layers == self.num_regions - region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) else: # Otherwise, we assume we have MLA and select i-th layer assert self.num_layers == self.num_regions - region_ids = range(layer_idx, layer_idx + 1) + region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. - descs_ids: list[int] = [] - for reg_id in region_ids: - for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) - return descs_ids + region_ids = region_ids[:, None] + block_ids = np.array(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids + return descs_ids.flatten() + + def get_backend_aware_kv_block_len(self, layer_idx: int): + """ + Get the block length for one K/V element (K and V have the same size). + + For FA and other backends, this is equal to the length of the whole + block, as K and V are in separate regions. + For FlashInfer, this is half the length of the whole block, as K and V + share the same region. + """ + if self._use_flashinfer: + # For indexing only half (either just the K or V part). + block_len = self.block_len_per_layer[layer_idx] // 2 + else: + block_len = self.block_len_per_layer[layer_idx] + return block_len + + def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + """ + Get the KV transfer stats for the connector. + """ + # Clear stats for next iteration + if not self.xfer_stats.is_empty(): + return self.xfer_stats.clone_and_reset() + return None + + def shutdown(self): + """Shutdown the connector worker.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t is not None: + self._nixl_handshake_listener_t.join(timeout=0) + self._nixl_handshake_listener_t = None + for handles in self._recving_transfers.values(): + for handle, _ in handles: + self.nixl_wrapper.release_xfer_handle(handle) + self._recving_transfers.clear() + if self.src_xfer_side_handle: + self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) + self.src_xfer_side_handle = 0 + for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + self.dst_xfer_side_handles.clear() + for remote_agents in self._remote_agents.values(): + for agent_name in remote_agents.values(): + self.nixl_wrapper.remove_remote_agent(agent_name) + self._remote_agents.clear() + for desc in self._registered_descs: + self.nixl_wrapper.deregister_memory(desc) + self._registered_descs.clear() @contextlib.contextmanager @@ -1260,10 +1562,94 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: ctx: Optional[zmq.Context] = None try: ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) finally: if ctx is not None: ctx.destroy(linger=0) + + +@dataclass +class NixlKVConnectorStats(KVConnectorStats): + """Container for transfer performance metrics""" + + def __post_init__(self): + if not self.data: + # Empty container init, no data is passed in. 
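The `_get_block_descs_ids` rewrite above swaps a nested Python loop for a NumPy broadcast (`region_ids[:, None] * num_blocks + block_ids[None, :]`, flattened region-major). A quick standalone check that the two forms agree; the concrete sizes are illustrative:

```python
# Check that the broadcasted form matches the old nested loop:
# desc id = region_id * num_blocks + block_id, iterated region-major.
import numpy as np

num_blocks = 1024                      # illustrative dst_num_blocks value
region_ids = np.arange(4)              # e.g. regions [K0, V0, K1, V1]
block_ids = [7, 42, 511]               # blocks to transfer

# Loop version (what the code used to do).
loop_ids = [r * num_blocks + b for r in region_ids for b in block_ids]

# Broadcast version (what the code does now): (R, 1) * scalar + (1, B) -> (R, B).
vec_ids = (region_ids[:, None] * num_blocks + np.array(block_ids)[None, :]).flatten()

assert vec_ids.tolist() == loop_ids
```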
+ self.reset() + + def reset(self): + # Must be serializable + self.data: dict[str, list[float]] = { + "transfer_duration": [], + "post_duration": [], + "bytes_transferred": [], + "num_descriptors": [], + } + + def record_transfer(self, res: nixlXferTelemetry): + # Keep metrics units consistent with rest of the code: time us->s + self.data["transfer_duration"].append(res.xferDuration / 1e6) + self.data["post_duration"].append(res.postDuration / 1e6) + self.data["bytes_transferred"].append(res.totalBytes) + self.data["num_descriptors"].append(res.descCount) + + def clone_and_reset(self) -> "NixlKVConnectorStats": + old = copy.copy(self) + self.reset() + return old + + def is_empty(self) -> bool: + return self.num_successful_transfers == 0 + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + if not other.is_empty(): + for k, v in other.data.items(): + accumulator = self.data[k] + assert isinstance(accumulator, list) + accumulator.extend(v) + return self + + def reduce(self) -> dict[str, Union[int, float]]: + # Compute compact representative stats suitable for CLI logging + if self.is_empty(): + return { + "Num successful transfers": 0, + "Avg xfer time (ms)": 0, + "P90 xfer time (ms)": 0, + "Avg post time (ms)": 0, + "P90 post time (ms)": 0, + "Avg MB per transfer": 0, + "Throughput (MB/s)": 0, + "Avg number of descriptors": 0, + } + + xfer_time = np.asarray(self.data["transfer_duration"]) + post_time = np.asarray(self.data["post_duration"]) + # Convert to MB for CLI logging. + mb = np.asarray(self.data["bytes_transferred"]) / 2**20 + descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32) + n = len(descs) + assert n == self.num_successful_transfers + + total_mb = mb.sum() + avg_mb = total_mb / n + + total_time_seconds = xfer_time.sum() + throughput_mb_s = total_mb / total_time_seconds + + return { + "Num successful transfers": n, + "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3), + "P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3), + "Avg post time (ms)": round(post_time.mean() * 1e3, 3), + "P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3), + "Avg MB per transfer": round(avg_mb, 3), + "Throughput (MB/s)": round(throughput_mb_s, 3), + "Avg number of descriptors": round(descs.mean(), 1), + } + + @property + def num_successful_transfers(self) -> int: + return len(self.data["transfer_duration"]) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py new file mode 100644 index 0000000000000..745af0efba180 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from itertools import islice +from typing import Any, Optional + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils 
import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_offload.abstract import OffloadingManager +from vllm.v1.kv_offload.factory import OffloadingSpecFactory +from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.request import Request + +ReqId = str + +logger = init_logger(__name__) + + +@dataclass +class OffloadingConnectorMetadata(KVConnectorMetadata): + reqs_to_load: dict[ReqId, TransferSpec] + reqs_to_store: dict[ReqId, TransferSpec] + + +class OffloadingConnector(KVConnectorBase_V1): + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config, role) + + spec = OffloadingSpecFactory.create_spec(vllm_config) + + self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None + self.connector_worker: Optional[OffloadingConnectorWorker] = None + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = OffloadingConnectorScheduler(spec) + elif role == KVConnectorRole.WORKER: + self.connector_worker = OffloadingConnectorWorker(spec) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) + self.connector_worker.start_store_kv(self._connector_metadata) + + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens + ) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def update_connector_output(self, connector_output: KVConnectorOutput): + assert self.connector_scheduler is not None + self.connector_scheduler.update_connector_output(connector_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + def take_events(self) -> Iterable[KVCacheEvent]: + assert self.connector_scheduler is not None 
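The `OffloadingConnector` defined above is a thin facade: it constructs exactly one of the scheduler-side or worker-side implementations based on its role and forwards each call to it, asserting that the call arrived on the correct side. A condensed sketch of that pattern with simplified, hypothetical class names (not vLLM APIs):

```python
# Condensed sketch of the role-split facade pattern: one connector class,
# constructed with a role, delegates to exactly one role-specific object.
from enum import Enum, auto
from typing import Optional


class Role(Enum):
    SCHEDULER = auto()
    WORKER = auto()


class SchedulerSide:
    def build_metadata(self) -> dict:
        # Scheduler decides what to load/store this step.
        return {"reqs_to_load": {}, "reqs_to_store": {}}


class WorkerSide:
    def start_load(self, metadata: dict) -> None:
        # Worker executes the transfers described by the metadata.
        print("loading", metadata["reqs_to_load"])


class Connector:
    def __init__(self, role: Role):
        self.scheduler: Optional[SchedulerSide] = None
        self.worker: Optional[WorkerSide] = None
        if role == Role.SCHEDULER:
            self.scheduler = SchedulerSide()
        else:
            self.worker = WorkerSide()

    def build_metadata(self) -> dict:
        assert self.scheduler is not None  # scheduler-side entry point
        return self.scheduler.build_metadata()

    def start_load(self, metadata: dict) -> None:
        assert self.worker is not None  # worker-side entry point
        self.worker.start_load(metadata)


meta = Connector(Role.SCHEDULER).build_metadata()
Connector(Role.WORKER).start_load(meta)
```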
+ return self.connector_scheduler.take_events() + + +class OffloadingConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, spec: OffloadingSpec): + self.gpu_block_size = spec.gpu_block_size + self.offloaded_block_size = spec.offloaded_block_size + self.block_size_factor = self.offloaded_block_size // self.gpu_block_size + self.manager: OffloadingManager = spec.get_manager() + + self._requests: dict[ReqId, Request] = {} + # list of GPU block IDs per request + self._request_block_ids: dict[ReqId, list[int]] = {} + # requests to load for the current scheduler step + self._reqs_to_load: dict[ReqId, TransferSpec] = {} + # request blocks are stored in order + # index of next block (of size offloaded_block_size) to offload + self._next_stored_block_idx: dict[ReqId, int] = {} + + # request ID -> set(block hashes being stored/load) + self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set) + + def _get_block_hashes( + self, + req: Request, + start_idx: int = 0, + end_idx: Optional[int] = None, + ) -> Iterable[BlockHash]: + return islice( + req.block_hashes, + self.block_size_factor * start_idx + self.block_size_factor - 1, + self.block_size_factor * end_idx if end_idx else None, + self.block_size_factor, + ) + + def get_num_new_matched_tokens( + self, request: Request, num_computed_tokens: int + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded beyond the + num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded beyond what is + already computed. + - `True` if tokens will be loaded asynchronously + (between scheduler steps). 
+ """ + num_blocks = request.num_tokens // self.offloaded_block_size + + assert len(request.block_hashes) // self.block_size_factor == num_blocks + block_hashes = self._get_block_hashes(request) + + self.manager.touch(block_hashes) + + full_block_tokens = self.offloaded_block_size * num_blocks + if full_block_tokens - num_computed_tokens < self.offloaded_block_size: + # we can load less than a block, skip + return 0, False + + start_block_idx = num_computed_tokens // self.offloaded_block_size + hits = self.manager.lookup( + self._get_block_hashes(request, start_idx=start_block_idx) + ) + if hits == 0: + return 0, False + + num_hit_tokens = ( + self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens + ) + logger.debug( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + request.request_id, + num_hit_tokens, + num_computed_tokens, + ) + if num_hit_tokens < self.offloaded_block_size: + return 0, False + + return num_hit_tokens, True + + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): + self._requests[request.request_id] = request + # the block ids are updated in _get_reqs_to_store + self._request_block_ids[request.request_id] = [] + + if num_external_tokens == 0: + return + + block_groups = blocks.get_block_ids() + block_ids = block_groups[0] + + num_computed_gpu_blocks = sum( + block.block_hash is not None for block in blocks.blocks[0] + ) + num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size + full_block_tokens = num_computed_tokens + num_external_tokens + assert full_block_tokens % self.offloaded_block_size == 0 + + num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks + assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size + + start_block_idx = num_computed_tokens // self.offloaded_block_size + num_blocks = full_block_tokens // self.offloaded_block_size + + assert len(request.block_hashes) // self.block_size_factor >= num_blocks + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) + + src_spec = self.manager.prepare_load(block_hashes) + dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:]) + + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) + + self._reqs_to_load[request.request_id] = (src_spec, dst_spec) + self._reqs_being_loaded[request.request_id].update(block_hashes) + self._next_stored_block_idx[request.request_id] = num_blocks + + def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): + reqs_to_store: dict[ReqId, TransferSpec] = {} + # iterate over both new and cached requests + for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): + if preempted: + self._request_block_ids[req_id] = [] + + if new_block_id_groups: + new_block_ids = new_block_id_groups[0] + self._request_block_ids[req_id] += new_block_ids + + block_ids = self._request_block_ids[req_id] + + req = self._requests[req_id] + new_tokens = scheduler_output.num_scheduled_tokens[req_id] + total_tokens = req.num_computed_tokens + new_tokens + num_blocks = total_tokens // self.offloaded_block_size + start_block_idx = self._next_stored_block_idx.get(req_id, 0) + num_new_blocks = num_blocks - start_block_idx + + if num_new_blocks <= 0: + continue + + num_gpu_blocks = num_blocks * self.block_size_factor + assert len(req.block_hashes) >= num_gpu_blocks + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, 
end_idx=num_blocks + ) + store_output = self.manager.prepare_store(new_block_hashes) + if store_output is None: + logger.warning( + "Request %s: cannot store %s blocks", req_id, num_new_blocks + ) + continue + + self._next_stored_block_idx[req_id] = num_blocks + + if not store_output.block_hashes_to_store: + continue + block_hashes_to_store = set(store_output.block_hashes_to_store) + + block_hashes = self._get_block_hashes(req, end_idx=num_blocks) + self.manager.touch(block_hashes) + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks + ) + dst_spec = store_output.store_spec + src_block_ids: list[int] = [] + for idx, blk_hash in enumerate(new_block_hashes): + if blk_hash not in block_hashes_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * self.block_size_factor + for i in range(self.block_size_factor): + src_block_ids.append(block_ids[gpu_block_idx + i]) + src_spec = GPULoadStoreSpec(src_block_ids) + + reqs_to_store[req_id] = (src_spec, dst_spec) + self._reqs_being_stored[req_id] |= block_hashes_to_store + + logger.debug( + "Request %s offloading %s blocks starting from block #%d", + req_id, + len(block_hashes_to_store), + start_block_idx, + ) + + return reqs_to_store + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + meta = OffloadingConnectorMetadata( + reqs_to_load=self._reqs_to_load, + reqs_to_store=self._get_reqs_to_store(scheduler_output), + ) + self._reqs_to_load = {} + return meta + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + for req_id in connector_output.finished_sending or []: + block_hashes = self._reqs_being_stored.pop(req_id, None) + if block_hashes: + self.manager.complete_store(block_hashes) + + for req_id in connector_output.finished_recving or []: + block_hashes = self._reqs_being_loaded.pop(req_id, None) + if block_hashes: + self.manager.complete_load(block_hashes) + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + req_id = request.request_id + self._requests.pop(req_id, None) + self._request_block_ids.pop(req_id, None) + self._next_stored_block_idx.pop(req_id, None) + + request_being_stored = req_id in self._reqs_being_stored + return request_being_stored, None + + def take_events(self) -> Iterable[KVCacheEvent]: + """Take the KV cache events from the connector. + + Returns: + A list of KV cache events. 
+ """ + for event in self.manager.take_events(): + if event.removed: + yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium) + else: + yield BlockStored( + block_hashes=event.block_hashes, + parent_block_hash=None, + token_ids=[], + lora_id=None, + block_size=event.block_size, + medium=event.medium, + ) + + +class OffloadingConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, spec: OffloadingSpec): + self.spec = spec + self.worker = OffloadingWorker() + + self._job_counter = 0 + + # req_id -> (job_id, store) + self._jobs: dict[int, tuple[ReqId, bool]] = {} + # req_id -> active job IDs + self._load_job: dict[ReqId, int] = {} + # req_id -> set(active job IDs) + self._store_jobs = defaultdict[ReqId, set[int]](set) + + self._finished_reqs_waiting_for_store: set[ReqId] = set() + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter = job_id + 1 + return job_id + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches): + self.worker.register_handler(src_cls, dst_cls, handler) + + def start_load_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_load.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, False) + assert req_id not in self._load_job + self._load_job[req_id] = job_id + assert self.worker.transfer_async(job_id, transfer_spec) + + def start_store_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_store.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, True) + self._store_jobs[req_id].add(job_id) + assert self.worker.transfer_async(job_id, transfer_spec) + + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + Returns a list of request IDs that finished loading or storing. + + Returns: + ids of requests that have finished asynchronous transfer + tuple of (sending/saving ids, recving/loading ids). 
+ """ + finished_sending = set() + finished_recving = set() + for job_id, success in self.worker.get_finished(): + # we currently do not support job failures + assert success + req_id, store = self._jobs.pop(job_id) + if store: + req_jobs = self._store_jobs[req_id] + req_jobs.remove(job_id) + if req_jobs: + continue + + if req_id in self._finished_reqs_waiting_for_store: + self._finished_reqs_waiting_for_store.remove(req_id) + finished_sending.add(req_id) + del self._store_jobs[req_id] + else: + req_job = self._load_job[req_id] + assert job_id == req_job + del self._load_job[req_id] + finished_recving.add(req_id) + + for req_id in finished_req_ids: + pending_req_jobs = self._store_jobs.get(req_id) + if pending_req_jobs: + self._finished_reqs_waiting_for_store.add(req_id) + elif pending_req_jobs is not None: + finished_sending.add(req_id) + del self._store_jobs[req_id] + + return finished_sending, finished_recving + + +def yield_req_data( + scheduler_output, +) -> Iterator[tuple[str, tuple[list[int], ...], bool]]: + """ + Yields: + (req_id, new_block_id_groups, preempted) + """ + # new requests + for req_data in scheduler_output.scheduled_new_reqs: + yield req_data.req_id, req_data.block_ids, False + + # cached requests + cached_reqs = scheduler_output.scheduled_cached_reqs + yield from zip( + cached_reqs.req_ids, + cached_reqs.new_block_ids, + cached_reqs.resumed_from_preemption, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 32d0e43d71afe..0e6693db5cd24 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -9,9 +9,13 @@ import torch from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( - P2pNcclEngine) + P2pNcclEngine, +) from vllm.distributed.parallel_state import get_world_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import MLACommonMetadata @@ -30,27 +34,20 @@ logger = init_logger(__name__) class ReqMeta: # Request Id request_id: str - # Request tokens - token_ids: torch.Tensor - # Slot mappings, should have the same length as token_ids - slot_mapping: torch.Tensor + # Request block ids + block_ids: torch.Tensor + # Request num tokens + num_tokens: int @staticmethod - def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], - block_size: int) -> "ReqMeta": - valid_num_tokens = len(token_ids) - token_ids_tensor = torch.tensor(token_ids) + def make_meta( + request_id: str, token_ids: list[int], block_ids: list[int], block_size: int + ) -> "ReqMeta": block_ids_tensor = torch.tensor(block_ids) - num_blocks = block_ids_tensor.shape[0] - block_offsets = torch.arange(0, block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size - slot_mapping = slot_mapping.flatten()[:valid_num_tokens] - return ReqMeta( request_id=request_id, - token_ids=token_ids_tensor, - slot_mapping=slot_mapping, + block_ids=block_ids_tensor, + num_tokens=len(token_ids), ) @@ -69,11 +66,11 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata): block_size: int, ) -> None: self.requests.append( - ReqMeta.make_meta(request_id, 
token_ids, block_ids, block_size)) + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size) + ) class P2pNcclConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._block_size = vllm_config.cache_config.block_size @@ -82,24 +79,27 @@ class P2pNcclConnector(KVConnectorBase_V1): self.is_producer = self.config.is_kv_producer self.chunked_prefill: dict[str, Any] = {} - self._rank = get_world_group().rank \ - if role == KVConnectorRole.WORKER else 0 - self._local_rank = get_world_group().local_rank \ - if role == KVConnectorRole.WORKER else 0 + self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0 + self._local_rank = ( + get_world_group().local_rank if role == KVConnectorRole.WORKER else 0 + ) - self.p2p_nccl_engine = P2pNcclEngine( - local_rank=self._local_rank, - config=self.config, - hostname="", - port_offset=self._rank, - ) if role == KVConnectorRole.WORKER else None + self.p2p_nccl_engine = ( + P2pNcclEngine( + local_rank=self._local_rank, + config=self.config, + hostname="", + port_offset=self._rank, + ) + if role == KVConnectorRole.WORKER + else None + ) # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -123,67 +123,68 @@ class P2pNcclConnector(KVConnectorBase_V1): return def inject_kv_into_layer( - dst_kv_cache_layer: torch.Tensor, - src_kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, + layer: torch.Tensor, + kv_cache: torch.Tensor, + block_ids: torch.Tensor, request_id: str, ) -> None: - """Inject the KV cache into the layer. + """ + Inject KV cache data into a given attention layer tensor. + + This function updates `layer` in-place with values from `kv_cache`, + handling different backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + If the number of provided block IDs does not match the number of KV + blocks, only the overlapping portion is updated, and a warning is + logged. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not - using MLA, [num_pages, page_size, xxx] otherwise. - src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] - otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape - [num_tokens]. - request_id (str): request id for log + layer (torch.Tensor): The attention layer KV tensor to update. + kv_cache (torch.Tensor): The KV cache tensor to inject. + block_ids (torch.Tensor): Indices of the blocks to update. + request_id (str): Request identifier used for logging. + + Returns: + None. The function modifies `layer` in-place. 
""" - dst_kv_cache_layer_shape = dst_kv_cache_layer.shape - if isinstance(attn_metadata, MLACommonMetadata): - num_pages = dst_kv_cache_layer_shape[0] - page_size = dst_kv_cache_layer_shape[1] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 0) - num_token = src_kv_cache.shape[0] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + if ( + isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2 + ): # MLA or FlashInfer + num_block = kv_cache.shape[0] + self.check_tensors_except_dim(layer, kv_cache, 0) + if len(block_ids) == num_block: + layer[block_ids, ...] = kv_cache else: - dst_kv_cache_layer[slot_mapping[:num_token], - ...] = src_kv_cache + layer[block_ids[:num_block], ...] = kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) - else: - num_pages = dst_kv_cache_layer_shape[1] - page_size = dst_kv_cache_layer_shape[2] - dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) - self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, - 1) - num_token = src_kv_cache.shape[1] - if len(slot_mapping) == num_token: - dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + elif layer.shape[0] == 2: # FlashAttention + num_block = kv_cache.shape[1] + self.check_tensors_except_dim(layer, kv_cache, 1) + if len(block_ids) == num_block: + layer[:, block_ids, ...] = kv_cache else: - dst_kv_cache_layer[:, slot_mapping[:num_token], - ...] = src_kv_cache + layer[:, block_ids[:num_block], ...] 
= kv_cache logger.warning( - "🚧src_kv_cache does not match, num_slot:%d, " - "num_token:%d, request_id:%s", len(slot_mapping), - num_token, request_id) - - dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + "🚧kv_cache does not match, block_ids:%d, " + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, P2pNcclConnectorMetadata) if metadata is None: @@ -191,29 +192,32 @@ class P2pNcclConnector(KVConnectorBase_V1): # Load the KV for each request each layer for request in metadata.requests: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self._rank) for layer_name in forward_context.no_compile_layers: layer = forward_context.no_compile_layers[layer_name] # Only process layers that have kv_cache # attribute (attention layers) Skip non-attention # layers like FusedMoE - kv_cache = getattr(layer, 'kv_cache', None) + kv_cache = getattr(layer, "kv_cache", None) if kv_cache is None: continue - kv_cache_layer = kv_cache[ \ - forward_context.virtual_engine] + layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( - request.request_id + "#" + layer_name) + request.request_id + "#" + layer_name, remote_address + ) if kv_cache is None: - logger.warning("🚧src_kv_cache is None, %s", - request.request_id) + logger.warning("🚧kv_cache is None, %s", request.request_id) continue - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping, request.request_id) + inject_kv_into_layer( + layer, kv_cache, request.block_ids, request.request_id + ) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's @@ -226,8 +230,13 @@ class P2pNcclConnector(KVConnectorBase_V1): """ return - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -245,16 +254,48 @@ class P2pNcclConnector(KVConnectorBase_V1): assert self.p2p_nccl_engine is not None + def extract_kv_from_layer( + layer: torch.Tensor, + block_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Extract KV cache slices from a given attention layer tensor. + + This function handles multiple backend layouts: + - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are + indexed along the first dimension. + - FlashAttention: KV tensors are indexed along the second + dimension. + + Args: + layer (torch.Tensor): The KV cache from the attention layer. + block_ids (torch.Tensor): Indices of blocks to extract. + + Returns: + torch.Tensor: A tensor containing the extracted KV slices. + Returns None if the layout is unsupported. + """ + if ( + isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2 + ): # MLA or FlashInfer + return layer[block_ids, ...] + + if layer.shape[0] == 2: # FlashAttention + return layer[:, block_ids, ...] 
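`extract_kv_from_layer` and `inject_kv_into_layer` above index the paged KV cache by block id instead of a per-token slot mapping, with the cache layout deciding which dimension carries the block axis. A toy round trip under assumed shapes (a FlashAttention-style `[2, num_blocks, block_size, head_dim]` cache, and an MLA/FlashInfer-style cache with blocks on the first dimension); the sizes are illustrative:

```python
# Toy round trip mirroring the block-id indexing above. Shapes are illustrative.
import torch

block_ids = torch.tensor([1, 3])

# FlashAttention-style layout: [2, num_blocks, block_size, head_dim]
src_fa = torch.randn(2, 4, 16, 8)
dst_fa = torch.zeros_like(src_fa)
extracted = src_fa[:, block_ids, ...]          # extract: blocks on dim 1
dst_fa[:, block_ids, ...] = extracted          # inject into the same blocks
assert torch.equal(dst_fa[:, block_ids], src_fa[:, block_ids])
assert dst_fa[:, 0].abs().sum() == 0           # untouched blocks stay zero

# MLA / FlashInfer-style layout: blocks on the first dimension.
src_mla = torch.randn(4, 16, 8)
dst_mla = torch.zeros_like(src_mla)
dst_mla[block_ids, ...] = src_mla[block_ids, ...]
assert torch.equal(dst_mla[block_ids], src_mla[block_ids])
```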
+ + return None + connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: request_id = request.request_id ip, port = self.parse_request_id(request_id, True) remote_address = ip + ":" + str(port + self._rank) + + kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) self.p2p_nccl_engine.send_tensor( - request_id + "#" + layer_name, kv_layer, remote_address, - request.slot_mapping, - isinstance(attn_metadata, MLACommonMetadata)) + request_id + "#" + layer_name, kv_cache, remote_address + ) def wait_for_save(self): if self.is_producer: @@ -262,8 +303,8 @@ class P2pNcclConnector(KVConnectorBase_V1): self.p2p_nccl_engine.wait_for_sent() def get_finished( - self, finished_req_ids: set[str], - **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]: + self, finished_req_ids: set[str], **kwargs: Any + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. @@ -277,10 +318,8 @@ class P2pNcclConnector(KVConnectorBase_V1): assert self.p2p_nccl_engine is not None - no_compile_layers = ( - self._vllm_config.compilation_config.static_forward_context) - return self.p2p_nccl_engine.get_finished(finished_req_ids, - no_compile_layers) + no_compile_layers = self._vllm_config.compilation_config.static_forward_context + return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers) # ============================== # Scheduler-side methods @@ -307,23 +346,24 @@ class P2pNcclConnector(KVConnectorBase_V1): if self.is_producer: return 0, False - num_external_tokens = (len(request.prompt_token_ids) - 1 - - num_computed_tokens) + num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens if num_external_tokens < 0: num_external_tokens = 0 return num_external_tokens, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. 
""" if not self.is_producer and num_external_tokens > 0: self._requests_need_load[request.request_id] = ( - request, blocks.get_block_ids()[0]) + request, + blocks.get_block_ids()[0], + ) def build_connector_meta( self, @@ -342,26 +382,33 @@ class P2pNcclConnector(KVConnectorBase_V1): for new_req in scheduler_output.scheduled_new_reqs: if self.is_producer: - num_scheduled_tokens = ( - scheduler_output.num_scheduled_tokens)[new_req.req_id] + num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[ + new_req.req_id + ] num_tokens = num_scheduled_tokens + new_req.num_computed_tokens # the request's prompt is chunked prefill if num_tokens < len(new_req.prompt_token_ids): # 'CachedRequestData' has no attribute 'prompt_token_ids' self.chunked_prefill[new_req.req_id] = ( - new_req.block_ids[0], new_req.prompt_token_ids) + new_req.block_ids[0], + new_req.prompt_token_ids, + ) continue # the request's prompt is not chunked prefill - meta.add_request(request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size) + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) continue if new_req.req_id in self._requests_need_load: - meta.add_request(request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size) + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) self._requests_need_load.pop(new_req.req_id) cached_reqs = scheduler_output.scheduled_cached_reqs @@ -371,24 +418,24 @@ class P2pNcclConnector(KVConnectorBase_V1): resumed_from_preemption = cached_reqs.resumed_from_preemption[i] if self.is_producer: - num_scheduled_tokens = ( - scheduler_output.num_scheduled_tokens)[req_id] - num_tokens = (num_scheduled_tokens + num_computed_tokens) + num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id] + num_tokens = num_scheduled_tokens + num_computed_tokens assert req_id in self.chunked_prefill block_ids = new_block_ids[0] if not resumed_from_preemption: - block_ids = (self.chunked_prefill[req_id][0] + block_ids) + block_ids = self.chunked_prefill[req_id][0] + block_ids prompt_token_ids = self.chunked_prefill[req_id][1] # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): - self.chunked_prefill[req_id] = (block_ids, - prompt_token_ids) + self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) continue # the request's prompt is all prefilled finally - meta.add_request(request_id=req_id, - token_ids=prompt_token_ids, - block_ids=block_ids, - block_size=self._block_size) + meta.add_request( + request_id=req_id, + token_ids=prompt_token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) self.chunked_prefill.pop(req_id, None) continue @@ -405,10 +452,12 @@ class P2pNcclConnector(KVConnectorBase_V1): # of the block_ids for the request. 
block_ids = new_block_ids[0] - meta.add_request(request_id=req_id, - token_ids=token_ids, - block_ids=block_ids, - block_size=self._block_size) + meta.add_request( + request_id=req_id, + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) self._requests_need_load.clear() return meta @@ -453,8 +502,7 @@ class P2pNcclConnector(KVConnectorBase_V1): port = int(match.group(2)) return ip, port - raise ValueError( - f"Request id {request_id} does not contain hostname and port") + raise ValueError(f"Request id {request_id} does not contain hostname and port") @staticmethod def check_tensors_except_dim(tensor1, tensor2, dim): @@ -462,8 +510,9 @@ class P2pNcclConnector(KVConnectorBase_V1): shape2 = tensor2.size() if len(shape1) != len(shape2) or not all( - s1 == s2 - for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim): + s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim + ): raise NotImplementedError( "Currently, only symmetric TP is supported. Asymmetric TP, PP," - "and others will be supported in future PRs.") + "and others will be supported in future PRs." + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index b94f2296dcb36..cff68818ca70b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -15,11 +15,17 @@ import msgpack import torch import zmq -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, +) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 - TensorMemoryPool) + TensorMemoryPool, +) from vllm.utils import current_stream, get_ip logger = logging.getLogger(__name__) @@ -31,12 +37,12 @@ DEFAULT_MEM_POOL_SIZE_GB = 32 def set_p2p_nccl_context(num_channels: str): original_values: dict[str, Any] = {} env_vars = [ - 'NCCL_MAX_NCHANNELS', - 'NCCL_MIN_NCHANNELS', - 'NCCL_CUMEM_ENABLE', - 'NCCL_BUFFSIZE', - 'NCCL_PROTO', # LL,LL128,SIMPLE - 'NCCL_ALGO', # RING,TREE + "NCCL_MAX_NCHANNELS", + "NCCL_MIN_NCHANNELS", + "NCCL_CUMEM_ENABLE", + "NCCL_BUFFSIZE", + "NCCL_PROTO", # LL,LL128,SIMPLE + "NCCL_ALGO", # RING,TREE ] for var in env_vars: @@ -45,9 +51,9 @@ def set_p2p_nccl_context(num_channels: str): logger.info("set_p2p_nccl_context, original_values: %s", original_values) try: - os.environ['NCCL_MAX_NCHANNELS'] = num_channels - os.environ['NCCL_MIN_NCHANNELS'] = num_channels - os.environ['NCCL_CUMEM_ENABLE'] = '1' + os.environ["NCCL_MAX_NCHANNELS"] = num_channels + os.environ["NCCL_MIN_NCHANNELS"] = num_channels + os.environ["NCCL_CUMEM_ENABLE"] = "1" yield finally: for var in env_vars: @@ -62,18 +68,17 @@ class SendQueueItem: tensor_id: str remote_address: str tensor: torch.Tensor - slot_mapping: torch.Tensor - is_mla: bool class P2pNcclEngine: - - def __init__(self, - local_rank: int, - config: KVTransferConfig, - hostname: str = "", - port_offset: int = 0, - library_path: Optional[str] = None) -> None: + def __init__( + self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None, + ) -> None: self.config = config self.rank = port_offset self.local_rank = 
local_rank @@ -93,8 +98,8 @@ class P2pNcclEngine: # The `http_port` must be consistent with the port of OpenAI. self.http_address = ( - f"{self._hostname}:" - f"{self.config.kv_connector_extra_config['http_port']}") + f"{self._hostname}:{self.config.kv_connector_extra_config['http_port']}" + ) # If `proxy_ip` or `proxy_port` is `""`, # then the ping thread will not be enabled. @@ -120,15 +125,17 @@ class P2pNcclEngine: self.recv_stream = torch.cuda.Stream() mem_pool_size_gb = float( - self.config.get_from_extra_config("mem_pool_size_gb", - DEFAULT_MEM_POOL_SIZE_GB)) - self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb * - 1024**3)) # GB + self.config.get_from_extra_config( + "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB + ) + ) + self.pool = TensorMemoryPool( + max_block_size=int(mem_pool_size_gb * 1024**3) + ) # GB # The sending type includes tree mutually exclusive options: # PUT, GET, PUT_ASYNC. - self.send_type = self.config.get_from_extra_config( - "send_type", "PUT_ASYNC") + self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC") if self.send_type == "GET": # tensor_id: torch.Tensor self.send_store: dict[str, torch.Tensor] = {} @@ -136,15 +143,16 @@ class P2pNcclEngine: # PUT or PUT_ASYNC # tensor_id: torch.Tensor self.send_queue: deque[SendQueueItem] = deque() - self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} if self.send_type == "PUT_ASYNC": - self._send_thread = threading.Thread(target=self.send_async, - daemon=True) + self._send_thread = threading.Thread( + target=self.send_async, daemon=True + ) self._send_thread.start() # tensor_id: torch.Tensor/(addr, dtype, shape) self.recv_store: dict[str, Any] = {} self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {} + self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.socks: dict[str, Any] = {} # remote_address: client socket self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank) @@ -152,10 +160,12 @@ class P2pNcclEngine: self.buffer_size_threshold = float(self.config.kv_buffer_size) self.nccl_num_channels = self.config.get_from_extra_config( - "nccl_num_channels", "8") + "nccl_num_channels", "8" + ) self._listener_thread = threading.Thread( - target=self.listen_for_requests, daemon=True) + target=self.listen_for_requests, daemon=True + ) self._listener_thread.start() self._ping_thread = None @@ -166,9 +176,16 @@ class P2pNcclEngine: logger.info( "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, " "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_" - "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank, - self.http_address, self.zmq_address, self.proxy_address, - self.send_type, self.buffer_size_threshold, self.nccl_num_channels) + "threshold:%.2f, nccl_num_channels:%s", + self.rank, + self.local_rank, + self.http_address, + self.zmq_address, + self.proxy_address, + self.send_type, + self.buffer_size_threshold, + self.nccl_num_channels, + ) def create_connect(self, remote_address: typing.Optional[str] = None): assert remote_address is not None @@ -178,8 +195,11 @@ class P2pNcclEngine: sock.connect(f"tcp://{remote_address}") self.socks[remote_address] = sock if remote_address in self.comms: - logger.info("👋comm exists, remote_address:%s, comms:%s", - remote_address, self.comms) + logger.info( + "👋comm exists, remote_address:%s, comms:%s", + remote_address, + self.comms, + ) return sock, self.comms[remote_address] unique_id = self.nccl.ncclGetUniqueId() @@ -189,11 +209,14 @@ class P2pNcclEngine: with 
torch.cuda.device(self.device): rank = 0 with set_p2p_nccl_context(self.nccl_num_channels): - comm: ncclComm_t = self.nccl.ncclCommInitRank( - 2, unique_id, rank) + comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank) self.comms[remote_address] = (comm, rank) - logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s", - self.zmq_address, remote_address, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👉%s, MyRank:%s", + self.zmq_address, + remote_address, + rank, + ) return self.socks[remote_address], self.comms[remote_address] @@ -202,8 +225,6 @@ class P2pNcclEngine: tensor_id: str, tensor: torch.Tensor, remote_address: typing.Optional[str] = None, - slot_mapping: torch.Tensor = None, - is_mla: bool = False, ) -> bool: if remote_address is None: with self.recv_store_cv: @@ -211,11 +232,9 @@ class P2pNcclEngine: self.recv_store_cv.notify() return True - item = SendQueueItem(tensor_id=tensor_id, - remote_address=remote_address, - tensor=tensor, - slot_mapping=slot_mapping, - is_mla=is_mla) + item = SendQueueItem( + tensor_id=tensor_id, remote_address=remote_address, tensor=tensor + ) if self.send_type == "PUT": return self.send_sync(item) @@ -229,27 +248,49 @@ class P2pNcclEngine: # GET with self.send_store_cv: tensor_size = tensor.element_size() * tensor.numel() - while (self.buffer_size + tensor_size - > self.buffer_size_threshold): - oldest_tenser_id = next(iter(self.send_store)) - oldest_tenser = self.send_store.pop(oldest_tenser_id) - oldest_tenser_size = oldest_tenser.element_size( - ) * oldest_tenser.numel() - self.buffer_size -= oldest_tenser_size - logger.info( + if tensor_size > self.buffer_size_threshold: + logger.warning( + "❗[GET]tensor_id:%s, tensor_size:%d, is greater than" + "buffer size threshold :%d, skip send to %s, rank:%d", + tensor_id, + tensor_size, + self.buffer_size_threshold, + remote_address, + self.rank, + ) + return False + while self.buffer_size + tensor_size > self.buffer_size_threshold: + assert len(self.send_store) > 0 + oldest_tensor_id = next(iter(self.send_store)) + oldest_tensor = self.send_store.pop(oldest_tensor_id) + oldest_tensor_size = ( + oldest_tensor.element_size() * oldest_tensor.numel() + ) + self.buffer_size -= oldest_tensor_size + logger.debug( "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," - " buffer_size:%d, oldest_tenser_size:%d, rank:%d", - remote_address, tensor_id, tensor_size, self.buffer_size, - oldest_tenser_size, self.rank) + " buffer_size:%d, oldest_tensor_size:%d, rank:%d", + remote_address, + tensor_id, + tensor_size, + self.buffer_size, + oldest_tensor_size, + self.rank, + ) self.send_store[tensor_id] = tensor self.buffer_size += tensor_size logger.debug( "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " - "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address, - tensor_id, tensor_size, tensor.shape, self.rank, + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, + tensor_id, + tensor_size, + tensor.shape, + self.rank, self.buffer_size, - self.buffer_size / self.buffer_size_threshold * 100) + self.buffer_size / self.buffer_size_threshold * 100, + ) return True def recv_tensor( @@ -267,17 +308,18 @@ class P2pNcclEngine: if tensor is not None: if isinstance(tensor, tuple): addr, dtype, shape = tensor - tensor = self.pool.load_tensor(addr, dtype, shape, - self.device) + tensor = self.pool.load_tensor(addr, dtype, shape, self.device) else: - self.buffer_size -= (tensor.element_size() * - tensor.numel()) + self.buffer_size -= tensor.element_size() * tensor.numel() else: duration = time.time() - 
start_time logger.warning( - "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " - "rank:%d", remote_address, tensor_id, duration * 1000, - self.rank) + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d", + remote_address, + tensor_id, + duration * 1000, + self.rank, + ) return tensor # GET @@ -296,14 +338,18 @@ class P2pNcclEngine: message = sock.recv() data = msgpack.loads(message) if data["ret"] != 0: - logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", - remote_address, tensor_id, data["ret"]) + logger.warning( + "🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, + tensor_id, + data["ret"], + ) return None with torch.cuda.stream(self.recv_stream): - tensor = torch.empty(data["shape"], - dtype=getattr(torch, data["dtype"]), - device=self.device) + tensor = torch.empty( + data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device + ) self.recv(comm, tensor, rank ^ 1, self.recv_stream) @@ -318,38 +364,45 @@ class P2pNcclEngine: remote_address, message = self.router_socket.recv_multipart() data = msgpack.loads(message) if data["cmd"] == "NEW": - unique_id = self.nccl.unique_id_from_bytes( - bytes(data["unique_id"])) + unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"])) with torch.cuda.device(self.device): rank = 1 with set_p2p_nccl_context(self.nccl_num_channels): comm: ncclComm_t = self.nccl.ncclCommInitRank( - 2, unique_id, rank) + 2, unique_id, rank + ) self.comms[remote_address.decode()] = (comm, rank) - logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", - self.zmq_address, remote_address.decode(), - rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, + remote_address.decode(), + rank, + ) elif data["cmd"] == "PUT": tensor_id = data["tensor_id"] try: with torch.cuda.stream(self.recv_stream): - tensor = torch.empty(data["shape"], - dtype=getattr( - torch, data["dtype"]), - device=self.device) + tensor = torch.empty( + data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device, + ) self.router_socket.send_multipart([remote_address, b"0"]) comm, rank = self.comms[remote_address.decode()] self.recv(comm, tensor, rank ^ 1, self.recv_stream) tensor_size = tensor.element_size() * tensor.numel() - if (self.buffer_size + tensor_size - > self.buffer_size_threshold): + if self.buffer_size + tensor_size > self.buffer_size_threshold: # Store Tensor in memory pool addr = self.pool.store_tensor(tensor) tensor = (addr, tensor.dtype, tensor.shape) logger.warning( "🔴[PUT]Recv Tensor, Out Of Threshold, " - "%s👈%s, data:%s, addr:%d", self.zmq_address, - remote_address.decode(), data, addr) + "%s👈%s, data:%s, addr:%d", + self.zmq_address, + remote_address.decode(), + data, + addr, + ) else: self.buffer_size += tensor_size @@ -357,9 +410,11 @@ class P2pNcclEngine: self.router_socket.send_multipart([remote_address, b"1"]) tensor = None logger.warning( - "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " - "data:%s", self.zmq_address, remote_address.decode(), - data) + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s", + self.zmq_address, + remote_address.decode(), + data, + ) with self.recv_store_cv: self.recv_store[tensor_id] = tensor @@ -374,7 +429,7 @@ class P2pNcclEngine: data = { "ret": 0, "shape": tensor.shape, - "dtype": str(tensor.dtype).replace("torch.", "") + "dtype": str(tensor.dtype).replace("torch.", ""), } # LRU self.send_store[tensor_id] = tensor @@ -382,26 +437,26 @@ class P2pNcclEngine: else: data = {"ret": 1} - self.router_socket.send_multipart( - [remote_address, 
msgpack.dumps(data)]) + self.router_socket.send_multipart([remote_address, msgpack.dumps(data)]) if data["ret"] == 0: comm, rank = self.comms[remote_address.decode()] - self.send(comm, tensor.to(self.device), rank ^ 1, - self.send_stream) + self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) else: logger.warning( "🚧Unexpected, Received message from %s, data:%s", - remote_address, data) + remote_address, + data, + ) def have_sent_tensor_id(self, tensor_id: str): - request_id = tensor_id.split('#')[0] + request_id = tensor_id.split("#")[0] if request_id not in self.send_request_id_to_tensor_ids: self.send_request_id_to_tensor_ids[request_id] = set() self.send_request_id_to_tensor_ids[request_id].add(tensor_id) def have_received_tensor_id(self, tensor_id: str): - request_id = tensor_id.split('#')[0] + request_id = tensor_id.split("#")[0] if request_id not in self.recv_request_id_to_tensor_ids: self.recv_request_id_to_tensor_ids[request_id] = set() self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) @@ -425,7 +480,10 @@ class P2pNcclEngine: duration = time.time() - start_time logger.debug( "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" - " to be empty, rank:%d", duration * 1000, self.rank) + " to be empty, rank:%d", + duration * 1000, + self.rank, + ) def send_sync(self, item: SendQueueItem) -> bool: if item.remote_address is None: @@ -433,9 +491,7 @@ class P2pNcclEngine: if item.remote_address not in self.socks: self.create_connect(item.remote_address) - with self.send_stream: - tensor = self.extract_kv_from_layer(item.is_mla, item.tensor, - item.slot_mapping) + tensor = item.tensor sock = self.socks[item.remote_address] comm, rank = self.comms[item.remote_address] @@ -443,7 +499,7 @@ class P2pNcclEngine: "cmd": "PUT", "tensor_id": item.tensor_id, "shape": tensor.shape, - "dtype": str(tensor.dtype).replace("torch.", "") + "dtype": str(tensor.dtype).replace("torch.", ""), } sock.send(msgpack.dumps(data)) @@ -452,10 +508,14 @@ class P2pNcclEngine: logger.error( "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", - self.zmq_address, item.remote_address, rank, data, + self.zmq_address, + item.remote_address, + rank, + data, tensor.shape, tensor.element_size() * tensor.numel() / 1024**3, - response.decode()) + response.decode(), + ) return False self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) @@ -466,7 +526,7 @@ class P2pNcclEngine: return True def get_finished( - self, finished_req_ids: set[str], no_compile_layers + self, finished_req_ids: set[str], no_compile_layers ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have @@ -486,10 +546,8 @@ class P2pNcclEngine: if tensor_id in self.recv_store: with self.recv_store_cv: tensor = self.recv_store.pop(tensor_id, None) - self.send_request_id_to_tensor_ids.pop( - request_id, None) - self.recv_request_id_to_tensor_ids.pop( - request_id, None) + self.send_request_id_to_tensor_ids.pop(request_id, None) + self.recv_request_id_to_tensor_ids.pop(request_id, None) if isinstance(tensor, tuple): addr, _, _ = tensor self.pool.free(addr) @@ -510,7 +568,7 @@ class P2pNcclEngine: data = { "type": "P" if self.config.is_kv_producer else "D", "http_address": self.http_address, - "zmq_address": self.zmq_address + "zmq_address": self.zmq_address, } while True: sock.send(msgpack.dumps(data)) @@ -519,27 +577,39 @@ class P2pNcclEngine: def send(self, comm, tensor: torch.Tensor, dst: int, stream=None): 
assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() with torch.cuda.stream(stream): - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + comm, + cudaStream_t(stream.cuda_stream), + ) stream.synchronize() def recv(self, comm, tensor: torch.Tensor, src: int, stream=None): assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() with torch.cuda.stream(stream): - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + comm, + cudaStream_t(stream.cuda_stream), + ) stream.synchronize() def close(self) -> None: @@ -548,21 +618,3 @@ class P2pNcclEngine: self._send_thread.join() if self._ping_thread is not None: self._ping_thread.join() - - @staticmethod - def extract_kv_from_layer( - is_mla: bool, - layer: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> torch.Tensor: - """Extract the KV cache from the layer. - Assume the shape of the layer is (2, num_pages, page_size, xxx) - if MLA is not used, and (num_pages, page_size, xxx) otherwise. - """ - if is_mla: - num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] - - num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] 
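For reference, the helper removed above (and the `SharedStorageConnector` code later in this diff) extracts per-token KV from a paged cache layer by flattening the page dimension and gathering the rows named by `slot_mapping`. Below is a minimal, self-contained sketch of that pattern; the function body mirrors the removed helper, while the toy shapes and sizes are illustrative only, not vLLM defaults.

```python
import torch


def extract_kv_from_layer(
    is_mla: bool,
    layer: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> torch.Tensor:
    """Gather per-token KV rows from a paged layer, mirroring the removed helper.

    Assumed layout: (2, num_pages, page_size, ...) without MLA,
    (num_pages, page_size, ...) with MLA.
    """
    if is_mla:
        num_pages, page_size = layer.shape[0], layer.shape[1]
        return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
    num_pages, page_size = layer.shape[1], layer.shape[2]
    return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]


# Toy example: 4 pages of 16 slots each, hidden size 8; gather 5 scattered tokens.
layer = torch.randn(2, 4, 16, 8)
slot_mapping = torch.tensor([0, 3, 17, 18, 40])
kv = extract_kv_from_layer(is_mla=False, layer=layer, slot_mapping=slot_mapping)
assert kv.shape == (2, 5, 8)
```

With the extraction moved out of the engine, `send_sync` above simply takes `item.tensor` as-is and ships it over NCCL.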
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index 02e3bc6274f60..899f1eae86d27 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -67,8 +67,7 @@ class TensorMemoryPool: if max_block_size <= 0 or min_block_size <= 0: raise ValueError("Block sizes must be positive") if max_block_size < min_block_size: - raise ValueError( - "Max block size must be greater than min block size") + raise ValueError("Max block size must be greater than min block size") self.max_block_size = self._round_to_power_of_two(max_block_size) self.min_block_size = self._round_to_power_of_two(min_block_size) @@ -91,16 +90,18 @@ class TensorMemoryPool: size //= 2 def _allocate_pinned_memory(self): - self.base_tensor = torch.empty(self.max_block_size // 4, - dtype=torch.float32, - pin_memory=True) + self.base_tensor = torch.empty( + self.max_block_size // 4, dtype=torch.float32, pin_memory=True + ) self.base_address = self.base_tensor.data_ptr() - initial_block = MemoryBlock(size=self.max_block_size, - addr=self.base_address) - self.free_lists[self.max_block_size][ - initial_block.addr] = initial_block - logger.debug("TensorMemoryPool, base_address:", self.base_address, - self.base_address % self.max_block_size) + initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address) + self.free_lists[self.max_block_size][initial_block.addr] = initial_block + + logger.debug( + "TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, + self.max_block_size, + ) def allocate(self, size: int) -> int: """Allocates a memory block of at least the requested size. 
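`TensorMemoryPool` is a buddy allocator over a single pinned host buffer: a requested size is rounded to a power of two (the `allocate` path needs this to round up so the block fits, clamped by `min_block_size`/`max_block_size`), larger free blocks are split in half until one fits, and on free a block is re-merged with its buddy, whose address is derived from the block's offset from `base_address` (see the `_split_block` and merge hunks below). The following standalone sketch shows those two computations; the helper names and toy numbers are illustrative, not the module's actual private API.

```python
def round_up_to_power_of_two(size: int) -> int:
    # Smallest power of two >= size -- the rounding the pool applies to
    # block sizes so an allocated block is never smaller than the request.
    power = 1
    while power < size:
        power *= 2
    return power


def buddy_address(addr: int, size: int, base_address: int) -> int:
    # A block whose offset is an even multiple of 2*size is the "left"
    # child of its parent, so its buddy lies `size` bytes above it;
    # otherwise the buddy lies `size` bytes below (same rule as the
    # buddy_offset computation in the merge hunk below).
    offset = size if (addr - base_address) % (2 * size) == 0 else -size
    return addr + offset


base = 0x1000
assert round_up_to_power_of_two(300) == 512          # requests round up
assert buddy_address(base, 256, base) == base + 256  # left child -> buddy above
assert buddy_address(base + 256, 256, base) == base  # right child -> buddy below
```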
@@ -117,8 +118,7 @@ class TensorMemoryPool: if size <= 0: raise ValueError("Allocation size must be positive") - required_size = self._round_to_power_of_two( - max(size, self.min_block_size)) + required_size = self._round_to_power_of_two(max(size, self.min_block_size)) if required_size > self.max_block_size: raise ValueError("Requested size exceeds maximum block size") @@ -134,8 +134,7 @@ class TensorMemoryPool: raise ValueError("Insufficient memory") def _split_block(self, block: MemoryBlock, required_size: int): - while (block.size > required_size - and block.size // 2 >= self.min_block_size): + while block.size > required_size and block.size // 2 >= self.min_block_size: buddy_size = block.size // 2 buddy_addr = block.addr + buddy_size @@ -164,8 +163,11 @@ class TensorMemoryPool: depth = 0 while depth < MAX_MERGE_DEPTH: - buddy_offset = block.size if (block.addr - self.base_address) % ( - 2 * block.size) == 0 else -block.size + buddy_offset = ( + block.size + if (block.addr - self.base_address) % (2 * block.size) == 0 + else -block.size + ) buddy_addr = block.addr + buddy_offset buddy = self.free_lists[block.size].get(buddy_addr) if buddy: @@ -201,14 +203,14 @@ class TensorMemoryPool: self.free(addr) raise ValueError( f"Allocated block size {block.size} is smaller than " - f"required size {size}") + f"required size {size}" + ) try: buffer = (ctypes.c_byte * block.size).from_address(block.addr) - cpu_tensor = torch.frombuffer(buffer, - dtype=tensor.dtype, - count=tensor.numel()).reshape( - tensor.shape) + cpu_tensor = torch.frombuffer( + buffer, dtype=tensor.dtype, count=tensor.numel() + ).reshape(tensor.shape) except ValueError as err: self.free(addr) raise ValueError(f"Failed to create tensor view: {err}") from err @@ -217,8 +219,13 @@ class TensorMemoryPool: return addr - def load_tensor(self, addr: int, dtype: torch.dtype, - shape: tuple[int, ...], device) -> torch.Tensor: + def load_tensor( + self, + addr: int, + dtype: torch.dtype, + shape: tuple[int, ...], + device: torch.device, + ) -> torch.Tensor: """Loads a tensor from pinned host memory to the specified device. 
Args: @@ -245,8 +252,9 @@ class TensorMemoryPool: raise ValueError("Requested tensor size exceeds block size") buffer = (ctypes.c_byte * block.size).from_address(block.addr) - cpu_tensor = torch.frombuffer(buffer, dtype=dtype, - count=num_elements).reshape(shape) + cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape( + shape + ) cuda_tensor = torch.empty(shape, dtype=dtype, device=device) @@ -258,7 +266,7 @@ class TensorMemoryPool: """Cleans up all memory resources and resets the pool state.""" self.free_lists.clear() self.allocated_blocks.clear() - if hasattr(self, 'base_tensor'): + if hasattr(self, "base_tensor"): del self.base_tensor def __del__(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index fd79387269d56..a1bab4e061455 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -2,15 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib import os -from dataclasses import dataclass -from typing import TYPE_CHECKING +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -35,15 +38,22 @@ class ReqMeta: mm_hashes: list[str] @staticmethod - def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, - is_store: bool, mm_hashes: list[str]) -> "ReqMeta": + def make_meta( + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + mm_hashes: list[str], + ) -> "ReqMeta": valid_num_tokens = align_to_block_size(len(token_ids), block_size) token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] block_ids_tensor = torch.tensor(block_ids) num_blocks = block_ids_tensor.shape[0] block_offsets = torch.arange(0, block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) slot_mapping = slot_mapping.flatten()[:valid_num_tokens] return ReqMeta( token_ids=token_ids_tensor, @@ -55,10 +65,7 @@ class ReqMeta: @dataclass class SharedStorageConnectorMetadata(KVConnectorMetadata): - requests: list[ReqMeta] - - def __init__(self): - self.requests = [] + requests: list[ReqMeta] = field(default_factory=list) def add_request( self, @@ -69,8 +76,8 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata): mm_hashes: list[str], ) -> None: self.requests.append( - ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, - mm_hashes)) + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes) + ) class SharedStorageConnector(KVConnectorBase_V1): @@ -85,13 +92,13 @@ class SharedStorageConnector(KVConnectorBase_V1): self._requests_need_load: dict[str, Request] = {} transfer_config = vllm_config.kv_transfer_config self._storage_path = transfer_config.get_from_extra_config( - "shared_storage_path", "/tmp") + 
"shared_storage_path", "/tmp" + ) logger.info(vllm_config.kv_transfer_config) logger.info("Shared storage path is %s", self._storage_path) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - """Start loading the KV cache from the connector buffer to vLLM's + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. Args: @@ -99,7 +106,7 @@ class SharedStorageConnector(KVConnectorBase_V1): **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. """ attn_metadata = forward_context.attn_metadata @@ -112,13 +119,13 @@ class SharedStorageConnector(KVConnectorBase_V1): """Inject the KV cache into the layer. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not using MLA, [num_pages, page_size, xxx] otherwise. src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape + slot_mapping (torch.Tensor): the slot mapping. In shape [num_tokens]. """ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape @@ -126,14 +133,16 @@ class SharedStorageConnector(KVConnectorBase_V1): num_pages = dst_kv_cache_layer_shape[0] page_size = dst_kv_cache_layer_shape[1] dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) + num_pages * page_size, -1 + ) dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) else: num_pages = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[2] dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) + 2, num_pages * page_size, -1 + ) dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) @@ -149,40 +158,39 @@ class SharedStorageConnector(KVConnectorBase_V1): attn_metadata = forward_context.attn_metadata if attn_metadata is None: - logger.warning( - "In connector.start_load_kv, but the attn_metadata is None") + logger.warning("In connector.start_load_kv, but the attn_metadata is None") return # Load the KV for each request each layer for request in metadata.requests: if request.is_store: continue - logger.info("Inject KV cache of %d tokens to the paged memory", - len(request.slot_mapping)) + logger.info( + "Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping), + ) for layer_name in forward_context.no_compile_layers: layer = forward_context.no_compile_layers[layer_name] # Only process layers that have kv_cache # attribute (attention layers) Skip non-attention # layers like FusedMoE/MLP etc. 
- kv_cache_attr = getattr(layer, 'kv_cache', None) + kv_cache_attr = getattr(layer, "kv_cache", None) if kv_cache_attr is None: continue - kv_cache_layer = kv_cache_attr[ \ - forward_context.virtual_engine] + kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] filename = self._generate_filename_debug( - layer_name, request.token_ids, request.mm_hashes) - kv_cache = safetensors.torch.load_file( - filename)["kv_cache"].cuda() - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping) + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's - paged buffer. - + paged buffer. + This interface will be useful for layer-by-layer pipelining. Args: @@ -190,14 +198,19 @@ class SharedStorageConnector(KVConnectorBase_V1): """ return - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """Start saving the KV cache of the layer from vLLM's paged buffer + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -214,20 +227,18 @@ class SharedStorageConnector(KVConnectorBase_V1): """ if isinstance(attn_metadata, MLACommonMetadata): num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, - ...] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, SharedStorageConnectorMetadata) for request in connector_metadata.requests: if request.is_store: filename = self._generate_filename_debug( - layer_name, request.token_ids, request.mm_hashes) - kv_cache = extract_kv_from_layer(kv_layer, - request.slot_mapping) + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) tensors = {"kv_cache": kv_cache.detach().cpu()} safetensors.torch.save_file(tensors, filename) @@ -238,18 +249,18 @@ class SharedStorageConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. 
""" # NOTE: in this debug implementation, we assume that the prompt is @@ -267,13 +278,14 @@ class SharedStorageConnector(KVConnectorBase_V1): # Now, first num_tokens_to_check tokens are hit, we need to prepare # the metadata for the worker connector to correctly load the KV num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size) + len(request.prompt_token_ids) - 1, self._block_size + ) return num_tokens_to_check - num_computed_tokens, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. @@ -300,11 +312,13 @@ class SharedStorageConnector(KVConnectorBase_V1): total_need_load = 0 for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=False, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in new_req.mm_features], + ) total_need_load += 1 else: # NOTE: here, we set the store and load being exclusive, @@ -312,11 +326,13 @@ class SharedStorageConnector(KVConnectorBase_V1): # NOTE(rob): for this debug implementation, we only cache # the original prompt tokens. if not self._found_match_for_request(new_req): - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=True, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True, + mm_hashes=[f.identifier for f in new_req.mm_features], + ) cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(cached_reqs.req_ids): @@ -341,11 +357,13 @@ class SharedStorageConnector(KVConnectorBase_V1): # of the block_ids for the request. block_ids = new_block_ids[0] - meta.add_request(token_ids=token_ids, - block_ids=block_ids, - block_size=self._block_size, - is_store=False, - mm_hashes=request.mm_hashes) + meta.add_request( + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in request.mm_features], + ) total_need_load += 1 assert total_need_load == len(self._requests_need_load) @@ -360,14 +378,15 @@ class SharedStorageConnector(KVConnectorBase_V1): self, request: "Request", ) -> bool: - """Check if the cache is hit for the request. 
- """ + """Check if the cache is hit for the request.""" num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size) - foldername = self._generate_foldername_debug(torch.tensor( - request.prompt_token_ids)[:num_tokens_to_check], - request.mm_hashes, - create_folder=False) + len(request.prompt_token_ids) - 1, self._block_size + ) + foldername = self._generate_foldername_debug( + torch.tensor(request.prompt_token_ids)[:num_tokens_to_check], + [f.identifier for f in request.mm_features], + create_folder=False, + ) return os.path.exists(foldername) def _generate_foldername_debug( @@ -376,7 +395,7 @@ class SharedStorageConnector(KVConnectorBase_V1): mm_hashes: list[str], create_folder=False, ) -> str: - """Generate a folder name based on the hash of the bytes of the input + """Generate a folder name based on the hash of the bytes of the input ids. """ token_bytes = token_ids.numpy().tobytes() @@ -384,9 +403,8 @@ class SharedStorageConnector(KVConnectorBase_V1): # to create a canonical key. if mm_hashes: mm_str = "-".join(mm_hashes) - token_bytes += mm_str.encode('utf-8') - input_ids_hash = hashlib.md5(token_bytes, - usedforsecurity=False).hexdigest() + token_bytes += mm_str.encode("utf-8") + input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest() foldername = os.path.join(self._storage_path, input_ids_hash) if create_folder: @@ -399,16 +417,15 @@ class SharedStorageConnector(KVConnectorBase_V1): token_ids: torch.Tensor, mm_hashes: list[str], ) -> str: - """Generate a file name based on the layer name and the hash + """Generate a file name based on the layer name and the hash of the bytes of the input ids. """ - foldername = self._generate_foldername_debug(token_ids, - mm_hashes=mm_hashes, - create_folder=True) + foldername = self._generate_foldername_debug( + token_ids, mm_hashes=mm_hashes, create_folder=True + ) return os.path.join(foldername, f"{layer_name}.safetensors") def align_to_block_size(num_tokens: int, block_size) -> int: - """Align the number of tokens to the block size. - """ + """Align the number of tokens to the block size.""" return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index eef14269f1961..08b683bfe23f5 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -42,39 +42,44 @@ class KVLookupBufferBase(KVCacheBufferBase): Abstract base class for a KVCache lookup buffer. This class provides an abstraction for a key-value (KV) cache lookup buffer. - + The key of the lookup buffer: - input_tokens: token IDs of the request - roi: a binary mask on top of input_tokens. - - Purpose of roi: Since KV cache may only be available for a subset of - tokens in the input (for example, when vLLM is connected to an external - KV cache service), roi specifies the subset of tokens that the KV cache + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache is associated with. - - NOTE: roi can be further extended to describe which part of KV the - current process is holding (each process may only hold a part of KV + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV due to TP and PP). 
This is not implemented for now. - + The value of the lookup buffer: - key: the key tensor in the KV cache - value: the value tensor in the KV cache - - hidden: the final hidden state generated by model forwarding. This allows + - hidden: the final hidden state generated by model forwarding. This allows vLLM to bypass further model forwarding by transmitting the hidden state. """ @abstractmethod - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: + def insert( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ) -> None: """Insert into the lookup buffer. - + The functionality is similar to the following python statement ``` buffer[input_tokens, roi] = [key, value, hidden] ``` - + FIXME: in the future, we should only have two arguments, key and value, where key is a tensor dict and value is a tensor dict. - + FIXME: we should transmit both sampler outputs and the hidden states. Args: @@ -82,8 +87,8 @@ class KVLookupBufferBase(KVCacheBufferBase): roi (torch.Tensor): A binary mask on top of the input tokens key (torch.Tensor): The key tensor in the KV cache. value (torch.Tensor): The value tensor in the KV cache. - hidden (torch.Tensor): The final hidden state tensor generated - during model forwarding to bypass model + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model forwarding. Raises: @@ -93,16 +98,16 @@ class KVLookupBufferBase(KVCacheBufferBase): @abstractmethod def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] + ) -> list[Optional[torch.Tensor]]: """Select and *drop* KV cache entries from the lookup buffer. - + The functionality is similar to the following python statements ``` ret = buffer.pop(input_tokens, roi) return ret ``` - + If `input_tokens` and `roi` is `None`, it means selecting any of the KV caches in the buffer, return, and remove it from the buffer, useful when offloading KV cache to KV cache storage service. diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py index 4381aad1e9956..44fc6d8ac5ad3 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -6,6 +6,7 @@ think of KV cache transfer operations as putting new KV cache entries into a remote KVStore-based lookup buffer and getting existing KV caches from this remote lookup buffer. 
""" + import json import os from dataclasses import dataclass @@ -16,8 +17,7 @@ from safetensors.torch import load as safetensors_load from safetensors.torch import save as safetensors_save from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVStoreBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase from vllm.logger import init_logger DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB @@ -37,65 +37,69 @@ class MooncakeStoreConfig: master_server_address: str @staticmethod - def from_file(file_path: str) -> 'MooncakeStoreConfig': + def from_file(file_path: str) -> "MooncakeStoreConfig": """Load the config from a JSON file.""" with open(file_path) as fin: config = json.load(fin) return MooncakeStoreConfig( local_hostname=config.get("local_hostname"), metadata_server=config.get("metadata_server"), - global_segment_size=config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE), - local_buffer_size=config.get("local_buffer_size", - DEFAULT_LOCAL_BUFFER_SIZE), + global_segment_size=config.get( + "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE + ), + local_buffer_size=config.get( + "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE + ), protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", ""), master_server_address=config.get("master_server_address"), ) @staticmethod - def load_from_env() -> 'MooncakeStoreConfig': + def load_from_env() -> "MooncakeStoreConfig": """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") if config_file_path is None: raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." + ) return MooncakeStoreConfig.from_file(config_file_path) class MooncakeStore(KVStoreBufferBase): - def __init__( self, config: VllmConfig, ): - try: from mooncake.store import MooncakeDistributedStore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e + "to run vLLM with MooncakeConnector." 
+ ) from e try: self.store = MooncakeDistributedStore() self.config = MooncakeStoreConfig.load_from_env() logger.info("Mooncake Configuration loaded successfully.") - self.store.setup(self.config.local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, self.config.device_name, - self.config.master_server_address) + self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) except ValueError as e: logger.error("Configuration loading failed: %s", e) raise except Exception as exc: - logger.error( - "An error occurred while loading the configuration: %s", exc) + logger.error("An error occurred while loading the configuration: %s", exc) raise def close(self): @@ -126,12 +130,9 @@ class MooncakeStore(KVStoreBufferBase): value: torch.Tensor, ) -> None: """Put KVCache to Mooncake Store""" - device_id = value.device.index if value.device.type == 'cuda' else -1 + device_id = value.device.index if value.device.type == "cuda" else -1 device_tensor = torch.tensor(device_id, dtype=torch.int32) - value_bytes = safetensors_save({ - "tensor": value, - "device_id": device_tensor - }) + value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor}) try: self.store.put(key, value_bytes) except TypeError as err: @@ -154,8 +155,11 @@ class MooncakeStore(KVStoreBufferBase): tensor = loaded_tensors["tensor"] device_id_tensor = loaded_tensors["device_id"] device_id = int(device_id_tensor.item()) - device = torch.device( - 'cuda', device_id) if device_id >= 0 else torch.device('cpu') + device = ( + torch.device("cuda", device_id) + if device_id >= 0 + else torch.device("cpu") + ) return tensor.to(device) return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index a0ff7c320f61e..cd58ec2e76398 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -1,23 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - Implements a distributed key-value (KV) cache transfer mechanism. +Implements a distributed key-value (KV) cache transfer mechanism. - Key Features: - - Distributed KV cache transmission using PyNccl pipes. - - Non-blocking `insert`, blocking `drop_select`. - - Use CPU signal pipe to avoid racing condition - - Handles buffer size constraints and provide backpressure mechanism to - stop the prefill instance when the decode instance is slow. +Key Features: +- Distributed KV cache transmission using PyNccl pipes. +- Non-blocking `insert`, blocking `drop_select`. +- Use CPU signal pipe to avoid racing condition +- Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. 
""" + import threading from collections import deque from typing import Optional, Union import torch -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger @@ -25,9 +25,9 @@ logger = init_logger(__name__) class SimpleBuffer(KVLookupBufferBase): - - def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, - buffer_size_thresh: float): + def __init__( + self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float + ): """ signal_pipe: on CPU @@ -51,9 +51,11 @@ class SimpleBuffer(KVLookupBufferBase): self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None - def _matches(self, tokens_roi_sender: list[torch.Tensor], - tokens_roi_recver: list[torch.Tensor]): - + def _matches( + self, + tokens_roi_sender: list[torch.Tensor], + tokens_roi_recver: list[torch.Tensor], + ): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) @@ -74,15 +76,12 @@ class SimpleBuffer(KVLookupBufferBase): # simple common prefix matching min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], - tokens_recver[:min_length]): + if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): return min_length return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: - + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: @@ -90,7 +89,6 @@ class SimpleBuffer(KVLookupBufferBase): self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): - if isinstance(data, torch.Tensor): return data.element_size() * data.numel() if not data: @@ -100,10 +98,14 @@ class SimpleBuffer(KVLookupBufferBase): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - + def _add_to_buffer( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ): if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() if isinstance(roi, torch.Tensor): @@ -134,9 +136,7 @@ class SimpleBuffer(KVLookupBufferBase): return signal is None def drop_select_handler(self): - try: - while True: signal = self.signal_pipe.recv_tensor() if self._is_end_signal(signal): @@ -146,20 +146,21 @@ class SimpleBuffer(KVLookupBufferBase): input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) + assert roi is not None, ( + "Please provide the roi when sending drop-select request" + ) + roi = roi > 0.5 tokens_roi_recver = [input_tokens, roi] def is_buffer_available( - tokens_roi_recver: list[torch.Tensor], ) -> bool: + tokens_roi_recver: list[torch.Tensor], + ) -> bool: # perform input tokens and roi matching # FIXME: this matching is O(n), ideally it should be O(1) # but this buffer size won't (and shouldn't) be too large so # the fix is not urgent. 
for _ in range(len(self.buffer)): - if self._matches(self.buffer[0], - tokens_roi_recver) > 0: + if self._matches(self.buffer[0], tokens_roi_recver) > 0: return True # rotate the element we just accessed to the end self.buffer.rotate(-1) @@ -167,8 +168,7 @@ class SimpleBuffer(KVLookupBufferBase): with self.buffer_cv: while not is_buffer_available(tokens_roi_recver): - logger.debug( - "KV transfer buffer is not available. Waiting...") + logger.debug("KV transfer buffer is not available. Waiting...") self.buffer_cv.wait() # need to clone the tensor # in case the tensor is freed before sending finishes @@ -178,18 +178,18 @@ class SimpleBuffer(KVLookupBufferBase): self.buffer_cv.notify() except RuntimeError as e: - if 'Connection closed by peer' not in str(e): + if "Connection closed by peer" not in str(e): raise e logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: - - assert self.request_handling_thread is None, \ - "drop_select should be called by the KV cache consumer "\ + self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] + ) -> list[Optional[torch.Tensor]]: + assert self.request_handling_thread is None, ( + "drop_select should be called by the KV cache consumer " "(e.g. the decode vLLM instance)" + ) if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -205,30 +205,36 @@ class SimpleBuffer(KVLookupBufferBase): if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor - roi = (roi > 0.5) + roi = roi > 0.5 key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() return [input_tokens, roi, key, value, hidden] - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - + def insert( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ) -> None: self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. if self.request_handling_thread is None: self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) + target=self.drop_select_handler + ) self.request_handling_thread.start() def close(self): - - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: + if ( + hasattr(self, "request_handling_thread") + and self.request_handling_thread is not None + ): self.request_handling_thread.join() else: diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 1423fd032477e..e27c6b2101b84 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -26,11 +26,11 @@ class KVPipeBase(ABC): @abstractmethod def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """Send a tensor, or None, via the pipe. - + Need to support sending None -- important for error handling. - - TODO: add a `key` argument so that we can use traditional - key-value database as the distributed communication mechanism behind + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind the pipe. 
Args: @@ -46,7 +46,7 @@ class KVPipeBase(ABC): """Receive a tensor (can be None) from the pipeline. Returns: - Optional[torch.Tensor]: The tensor received from the pipeline. Can + Optional[torch.Tensor]: The tensor received from the pipeline. Can be None. Raises: @@ -58,7 +58,7 @@ class KVPipeBase(ABC): def close(self) -> None: """Close the pipeline and release resources. - This method is responsible for closing the communication pipeline + This method is responsible for closing the communication pipeline and releasing any resources associated with it. Raises: diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 0b560d1b3b3ce..65858f86aa235 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -13,7 +13,7 @@ import zmq from safetensors.torch import load as safetensors_load from safetensors.torch import save as safetensors_save -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger from vllm.utils import join_host_port, make_zmq_path, split_host_port @@ -32,7 +32,7 @@ class MooncakeTransferEngineConfig: device_name: str @staticmethod - def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + def from_file(file_path: str) -> "MooncakeTransferEngineConfig": """Load the config from a JSON file.""" with open(file_path) as fin: config = json.load(fin) @@ -46,12 +46,13 @@ class MooncakeTransferEngineConfig: ) @staticmethod - def load_from_env() -> 'MooncakeTransferEngineConfig': + def load_from_env() -> "MooncakeTransferEngineConfig": """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") if config_file_path is None: raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." + ) return MooncakeTransferEngineConfig.from_file(config_file_path) @@ -65,7 +66,8 @@ class MooncakeTransferEngine: raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e + "to run vLLM with MooncakeConnector." 
+ ) from e self.engine = TransferEngine() self.local_rank = local_rank @@ -77,16 +79,13 @@ class MooncakeTransferEngine: logger.error(e) raise except Exception as exc: - logger.error( - "An error occurred while loading the configuration: %s", exc) + logger.error("An error occurred while loading the configuration: %s", exc) raise - prefill_host, base_prefill_port = split_host_port( - self.config.prefill_url) + prefill_host, base_prefill_port = split_host_port(self.config.prefill_url) decode_host, base_decode_port = split_host_port(self.config.decode_url) # Avoid ports conflict when running prefill and decode on the same node - if prefill_host == decode_host and \ - base_prefill_port == base_decode_port: + if prefill_host == decode_host and base_prefill_port == base_decode_port: base_decode_port = base_decode_port + 100 prefill_port = base_prefill_port + self.local_rank @@ -94,12 +93,15 @@ class MooncakeTransferEngine: self.prefill_url = join_host_port(prefill_host, prefill_port) self.decode_url = join_host_port(decode_host, decode_port) - self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, - self.config.metadata_server, self.config.protocol, - self.config.device_name, self.config.metadata_backend) + self.initialize( + self.prefill_url if kv_rank == 0 else self.decode_url, + self.config.metadata_server, + self.config.protocol, + self.config.device_name, + self.config.metadata_backend, + ) - self.remote_url = (self.decode_url - if kv_rank == 0 else self.prefill_url) + self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url # Initialize ZeroMQ context and sockets self.context = zmq.Context() # type: ignore[attr-defined] @@ -109,51 +111,57 @@ class MooncakeTransferEngine: self.receiver_ack = self.context.socket(zmq.constants.PUSH) self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) - self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, - decode_host, base_decode_port) + self._setup_metadata_sockets( + kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port + ) - def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int, - d_host: str, d_port: int) -> None: + def _setup_metadata_sockets( + self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int + ) -> None: """Set up ZeroMQ sockets for sending and receiving data.""" # Offsets < 8 are left for initialization in case tp and pp are enabled p_rank_offset = p_port + 8 + self.local_rank * 2 d_rank_offset = d_port + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind( - make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1)) self.receiver_socket.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.sender_ack.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.receiver_ack.bind( - make_zmq_path("tcp", p_host, p_rank_offset + 2)) + make_zmq_path("tcp", d_host, d_rank_offset + 1) + ) + self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2)) else: self.receiver_socket.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 1)) - self.sender_socket.bind( - make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.receiver_ack.bind( - make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.sender_ack.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 2)) + make_zmq_path("tcp", p_host, p_rank_offset + 1) + ) + self.sender_socket.bind(make_zmq_path("tcp", 
d_host, d_rank_offset + 1)) + self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2)) - def initialize(self, local_hostname: str, metadata_server: str, - protocol: str, device_name: str, - metadata_backend: Union[str, None]) -> None: + def initialize( + self, + local_hostname: str, + metadata_server: str, + protocol: str, + device_name: str, + metadata_backend: Union[str, None], + ) -> None: """Initialize the mooncake instance.""" if metadata_backend is None: - self.engine.initialize(local_hostname, metadata_server, protocol, - device_name) + self.engine.initialize( + local_hostname, metadata_server, protocol, device_name + ) else: supported_backend = ["etcd", "redis"] metadata_backend = metadata_backend.lower() if metadata_backend not in supported_backend: raise ValueError( "Mooncake Configuration error. `metadata_backend`" - f" should be one of {supported_backend}.") + f" should be one of {supported_backend}." + ) - self.engine.initialize_ext(local_hostname, metadata_server, - protocol, device_name, metadata_backend) + self.engine.initialize_ext( + local_hostname, metadata_server, protocol, device_name, metadata_backend + ) def allocate_managed_buffer(self, length: int) -> int: """Allocate a managed buffer of the specified length.""" @@ -167,18 +175,17 @@ class MooncakeTransferEngine: """Free a previously allocated managed buffer.""" return self.engine.free_managed_buffer(buffer, length) - def transfer_sync(self, buffer: int, peer_buffer_address: int, - length: int) -> int: + def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" - ret = self.engine.transfer_sync_read(self.remote_url, buffer, - peer_buffer_address, length) + ret = self.engine.transfer_sync_read( + self.remote_url, buffer, peer_buffer_address, length + ) if ret < 0: logger.error("Transfer Return Error") raise Exception("Transfer Return Error") return ret - def write_bytes_to_buffer(self, buffer: int, user_data: bytes, - length: int) -> int: + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: """Write bytes to the allocated buffer.""" return self.engine.write_bytes_to_buffer(buffer, user_data, length) @@ -189,7 +196,7 @@ class MooncakeTransferEngine: def wait_for_ack(self, src_ptr: int, length: int) -> None: """Asynchronously wait for ACK from the receiver.""" ack = self.sender_ack.recv() - if ack != b'ACK': + if ack != b"ACK": logger.error("Failed to receive ACK from the receiver") self.free_managed_buffer(src_ptr, length) @@ -200,8 +207,8 @@ class MooncakeTransferEngine: src_ptr = self.allocate_managed_buffer(length) self.write_bytes_to_buffer(src_ptr, user_data, length) self.sender_socket.send_multipart( - [struct.pack("!Q", src_ptr), - struct.pack("!Q", length)]) + [struct.pack("!Q", src_ptr), struct.pack("!Q", length)] + ) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) def recv_bytes(self) -> bytes: @@ -214,7 +221,7 @@ class MooncakeTransferEngine: ret = self.read_bytes_from_buffer(dst_ptr, length) # Buffer cleanup - self.receiver_ack.send(b'ACK') + self.receiver_ack.send(b"ACK") self.free_managed_buffer(dst_ptr, length) return ret @@ -223,10 +230,9 @@ class MooncakeTransferEngine: class MooncakePipe(KVPipeBase): """MooncakeTransferEngine based Pipe implementation.""" - def __init__(self, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None): + def __init__( + self, 
local_rank: int, config: KVTransferConfig, device: Optional[str] = None + ): """Initialize the mooncake pipe and set related parameters.""" self.config = config self.local_rank = local_rank @@ -236,8 +242,7 @@ class MooncakePipe(KVPipeBase): else: self.device = self._select_device(device) - self.transfer_engine = MooncakeTransferEngine(self.kv_rank, - self.local_rank) + self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank) self.transport_thread: Optional[ThreadPoolExecutor] = None self.none_tensor = torch.tensor([NONE_INT], device=self.device) @@ -267,7 +272,7 @@ class MooncakePipe(KVPipeBase): if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) tensor = tensor if tensor is not None else self.none_tensor - assert (len(tensor.shape) > 0) + assert len(tensor.shape) > 0 self.transport_thread.submit(self._send_impl, tensor) def recv_tensor(self) -> Optional[torch.Tensor]: diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 09de0b682efca..c79b7e7e50303 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - This module implements a PyNccl pipe for sending and receiving - Optional[torch.Tensor] between distributed ranks with advanced - communication features. +This module implements a PyNccl pipe for sending and receiving +Optional[torch.Tensor] between distributed ranks with advanced +communication features. - Key Features: - - Supports sending and receiving tensors with metadata - - Handles both CUDA and CPU device communications - - Implements a non-blocking tensor transfer mechanism - - Manages buffer size and provides backpressure control - - Supports distributed process groups with configurable parameters +Key Features: +- Supports sending and receiving tensors with metadata +- Handles both CUDA and CPU device communications +- Implements a non-blocking tensor transfer mechanism +- Manages buffer size and provides backpressure control +- Supports distributed process groups with configurable parameters """ import threading @@ -20,7 +20,7 @@ from typing import Callable, Optional import torch -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.utils import StatelessProcessGroup @@ -30,7 +30,6 @@ logger = init_logger(__name__) class BrokenPipeException(Exception): - def __init__(self, message): self.message = message super().__init__(self.message) @@ -40,16 +39,17 @@ Metadata = dict[str, Optional[torch.Tensor]] class PyNcclPipe(KVPipeBase): - METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__(self, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None, - port_offset: int = 0): + def __init__( + self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0, + ): self.config = config self.local_rank = local_rank self.kv_rank = self.config.kv_rank @@ -84,9 +84,9 @@ class PyNcclPipe(KVPipeBase): def _get_device_send_recv_impl( self, group: StatelessProcessGroup - ) -> tuple[Callable[[torch.Tensor, int], None], Callable[ - [torch.Tensor, int], None]]: - + ) 
-> tuple[ + Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None] + ]: send: Callable[[torch.Tensor, int], None] recv: Callable[[torch.Tensor, int], None] if self.device.type == "cuda": @@ -144,9 +144,9 @@ class PyNcclPipe(KVPipeBase): buffer: A tensor of the specified type and shape, allocated on `self.device`. """ - return torch.empty(metadata["shape"], - dtype=metadata["dtype"], - device=self.device) + return torch.empty( + metadata["shape"], dtype=metadata["dtype"], device=self.device + ) def _send_metadata(self, metadata: Metadata): """ @@ -179,8 +179,7 @@ class PyNcclPipe(KVPipeBase): metadata = self._make_metadata(tensor) self._send_metadata(metadata) if tensor is not None: - self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + self.device_send_func(tensor.to(self.device), self.target_rank_for_send) def _recv_impl(self) -> Optional[torch.Tensor]: """ @@ -198,8 +197,9 @@ class PyNcclPipe(KVPipeBase): return buffer - def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + def send_tensor_wrapper( + self, tensor: Optional[torch.Tensor], tensor_size: int + ) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ @@ -209,9 +209,14 @@ class PyNcclPipe(KVPipeBase): with self.buffer_size_lock: self.buffer_size -= tensor_size except Exception as e: - logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), str(tensor), str(e)) + logger.error( + "[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), + str(tensor), + str(e), + ) import traceback + traceback.print_exc() def block_if_full(self): @@ -244,15 +249,14 @@ class PyNcclPipe(KVPipeBase): with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size) def recv_tensor(self) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. - Args: - tensor: The received tensor, or `None` if no tensor is received. + Returns: + The received tensor, or `None` if no tensor is received. """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -266,6 +270,7 @@ class PyNcclPipe(KVPipeBase): logger.error("%s", e) logger.error("My device: %s", self.device) import traceback + traceback.print_exc() raise e @@ -275,6 +280,5 @@ class PyNcclPipe(KVPipeBase): """ Close the pipe and release associated resources. 
""" - if hasattr(self, - "transport_thread") and self.transport_thread is not None: + if hasattr(self, "transport_thread") and self.transport_thread is not None: self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 5e0f64fca220c..f8f65f28ff6d7 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, Optional from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,7 +18,8 @@ _KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None def get_kv_transfer_group() -> KVConnectorBaseType: assert _KV_CONNECTOR_AGENT is not None, ( - "disaggregated KV cache transfer parallel group is not initialized") + "disaggregated KV cache transfer parallel group is not initialized" + ) return _KV_CONNECTOR_AGENT @@ -25,8 +27,7 @@ def has_kv_transfer_group() -> bool: return _KV_CONNECTOR_AGENT is not None -def is_v1_kv_transfer_group( - connector: Optional[KVConnectorBaseType] = None) -> bool: +def is_v1_kv_transfer_group(connector: Optional[KVConnectorBaseType] = None) -> bool: """Check if the KV connector is the v1 connector. If the argument is None, it will check the global KV connector @@ -57,10 +58,20 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if vllm_config.kv_transfer_config is None: return - if (vllm_config.kv_transfer_config.is_kv_transfer_instance - and _KV_CONNECTOR_AGENT is None): + if ( + vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None + ): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, role=KVConnectorRole.WORKER) + config=vllm_config, role=KVConnectorRole.WORKER + ) else: raise ValueError("V0 is no longer supported") + + +def ensure_kv_transfer_shutdown() -> None: + global _KV_CONNECTOR_AGENT + if _KV_CONNECTOR_AGENT is not None: + _KV_CONNECTOR_AGENT.shutdown() + _KV_CONNECTOR_AGENT = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b89aee99c8d46..cb5a75c59f096 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,7 @@ If you only need to use the distributed environment without model/pipeline parallelism, you can skip the model parallel initialization and destruction steps. 
""" + import contextlib import gc import pickle @@ -29,22 +30,30 @@ import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass +from datetime import timedelta from multiprocessing import shared_memory from typing import Any, Callable, Optional, Union from unittest.mock import patch import torch import torch.distributed +import torch.distributed._functional_collectives as funcol +import torch.distributed._symmetric_memory from torch.distributed import Backend, ProcessGroup from typing_extensions import deprecated import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) + DeviceCommunicatorBase, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import (direct_register_custom_op, get_distributed_init_method, - resolve_obj_by_qualname, supports_custom_op) +from vllm.utils import ( + direct_register_custom_op, + get_distributed_init_method, + resolve_obj_by_qualname, + supports_custom_op, +) @dataclass @@ -56,7 +65,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) def _split_tensor_dict( - tensor_dict: dict[str, Union[torch.Tensor, Any]] + tensor_dict: dict[str, Union[torch.Tensor, Any]], ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced @@ -73,7 +82,8 @@ def _split_tensor_dict( # receiving side will set the device index. device = value.device.type metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) + (key, TensorMetadata(device, value.dtype, value.size())) + ) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -115,8 +125,9 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) -def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def reduce_scatter( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: @@ -124,15 +135,17 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, return group._reduce_scatter_out_place(tensor, dim) -def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def reduce_scatter_fake( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: new_shape = list(tensor.shape) new_shape[dim] = tensor.shape[dim] // world_size return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) -def all_gather(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def all_gather( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() if group is None: @@ -140,37 +153,124 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int, return group._all_gather_out_place(tensor, dim) -def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def all_gather_fake( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: new_shape = list(tensor.shape) new_shape[dim] = tensor.shape[dim] * world_size return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) +def patched_fused_scaled_matmul_reduce_scatter_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + # Copied from + # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189 + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def patched_fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + + if supports_custom_op(): - from vllm.platforms import current_platform direct_register_custom_op( op_name="all_reduce", op_func=all_reduce, - mutates_args=[], fake_impl=all_reduce_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="reduce_scatter", op_func=reduce_scatter, - mutates_args=[], fake_impl=reduce_scatter_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="all_gather", op_func=all_gather, - mutates_args=[], fake_impl=all_gather_fake, - dispatch_key=current_platform.dispatch_key, + ) + + # TODO: Remove this once the pytorch fix + # (https://github.com/pytorch/pytorch/pull/165086) gets released, + # in either 2.9.1 or 2.10 + direct_register_custom_op( + op_name="patched_fused_scaled_matmul_reduce_scatter", + op_func=patched_fused_scaled_matmul_reduce_scatter, + 
fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake, ) @@ -224,7 +324,8 @@ class GroupCoordinator: for ranks in group_ranks: device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) + ranks, backend=torch_distributed_backend + ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") @@ -248,8 +349,7 @@ class GroupCoordinator: elif current_platform.is_xpu(): self.device = torch.device(f"xpu:{local_rank}") elif current_platform.is_out_of_tree(): - self.device = torch.device( - f"{current_platform.device_name}:{local_rank}") + self.device = torch.device(f"{current_platform.device_name}:{local_rank}") else: self.device = torch.device("cpu") @@ -257,7 +357,8 @@ class GroupCoordinator: self.device_communicator = None if use_device_communicator and self.world_size > 1: device_comm_cls = resolve_obj_by_qualname( - current_platform.get_device_communicator_cls()) + current_platform.get_device_communicator_cls() + ) self.device_communicator = device_comm_cls( cpu_group=self.cpu_group, device=self.device, @@ -265,19 +366,23 @@ class GroupCoordinator: unique_name=self.unique_name, ) - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + self.mq_broadcaster: Optional[MessageQueue] = None if use_message_queue_broadcaster and self.world_size > 1: self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) + self.cpu_group, 1 << 22, 6 + ) from vllm.platforms import current_platform - self.use_custom_op_call = (current_platform.is_cuda_alike() - or current_platform.is_tpu()) - self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr( - torch.ops._C, "init_shm_manager")) + self.use_custom_op_call = ( + current_platform.is_cuda_alike() or current_platform.is_tpu() + ) + + self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr( + torch.ops._C, "init_shm_manager" + ) @property def first_rank(self): @@ -315,7 +420,8 @@ class GroupCoordinator: @contextmanager def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None): + self, graph_capture_context: Optional[GraphCaptureContext] = None + ): if graph_capture_context is None: stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) @@ -326,7 +432,9 @@ class GroupCoordinator: # so we don't abstract it into the base class maybe_ca_context = nullcontext() from vllm.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) + CudaCommunicator, + ) + if self.device_communicator is not None: assert isinstance(self.device_communicator, CudaCommunicator) ca_comm = self.device_communicator.ca_comm @@ -362,8 +470,7 @@ class GroupCoordinator: return input_ if self.use_custom_op_call: - return torch.ops.vllm.all_reduce(input_, - group_name=self.unique_name) + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) else: return self._all_reduce_out_place(input_) @@ -378,66 +485,62 @@ class GroupCoordinator: if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if self.use_custom_op_call: - return torch.ops.vllm.all_gather(input_, - dim, - world_size, - group_name=self.unique_name) + return torch.ops.vllm.all_gather( + 
input_, dim, world_size, group_name=self.unique_name + ) else: return self._all_gather_out_place(input_, dim) - def _all_gather_out_place(self, input_: torch.Tensor, - dim: int) -> torch.Tensor: + def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.all_gather(input_, dim) - def all_gatherv(self, - input_: Union[torch.Tensor, list[torch.Tensor]], - dim: int = 0, - sizes: Optional[list[int]] = None): + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None, + ): if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.all_gatherv(input_, dim, sizes) - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if self.use_custom_op_call: - return torch.ops.vllm.reduce_scatter(input_, - dim, - world_size, - group_name=self.unique_name) + return torch.ops.vllm.reduce_scatter( + input_, dim, world_size, group_name=self.unique_name + ) else: return self._reduce_scatter_out_place(input_, dim) - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None) -> torch.Tensor: + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.reduce_scatterv(input_, dim, sizes) - def _reduce_scatter_out_place(self, input_: torch.Tensor, - dim: int) -> torch.Tensor: + def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.reduce_scatter(input_, dim) - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -461,9 +564,9 @@ class GroupCoordinator: if self.world_size == 1: return input_ # Broadcast. 
- torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) return input_ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): @@ -479,21 +582,20 @@ class GroupCoordinator: assert src == 0, "Message queue broadcaster only supports src=0" return self.mq_broadcaster.broadcast_object(obj) if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], - src=self.ranks[src], - group=self.cpu_group) + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) return obj else: recv = [None] - torch.distributed.broadcast_object_list(recv, - src=self.ranks[src], - group=self.cpu_group) + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) return recv[0] - def broadcast_object_list(self, - obj_list: list[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): + def broadcast_object_list( + self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ @@ -503,9 +605,9 @@ class GroupCoordinator: if self.world_size == 1: return obj_list # Broadcast. - torch.distributed.broadcast_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) return obj_list def send_object(self, obj: Any, dst: int) -> None: @@ -516,25 +618,22 @@ class GroupCoordinator: assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " - "as the current rank.") + "as the current rank." + ) # Serialize object to tensor and get the size as well object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - size_tensor = torch.tensor([object_tensor.numel()], - dtype=torch.long, - device="cpu") + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) # Send object size - torch.distributed.send(size_tensor, - dst=self.ranks[dst], - group=self.cpu_group) + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) # Send object - torch.distributed.send(object_tensor, - dst=self.ranks[dst], - group=self.cpu_group) + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) return None @@ -551,22 +650,24 @@ class GroupCoordinator: size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size - rank_size = torch.distributed.recv(size_tensor, - src=self.ranks[src], - group=self.cpu_group) + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) # Tensor to receive serialized objects into. object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device="cpu") + device="cpu", + ) - rank_object = torch.distributed.recv(object_tensor, - src=self.ranks[src], - group=self.cpu_group) + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) assert rank_object == rank_size, ( - "Received object sender rank does not match the size sender rank.") + "Received object sender rank does not match the size sender rank." 
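# --- Illustrative sketch (single process; plain tensor copies stand in for
# torch.distributed.send/recv): send_object/recv_object move a picklable
# object as two messages: a fixed-size length tensor first, then a uint8
# payload tensor that the receiver allocates from that length.
import pickle
import torch

obj = {"layer": 3, "done": False}                    # assumed example payload
payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
size = torch.tensor([payload.numel()], dtype=torch.long)

recv_buf = torch.empty(int(size.item()), dtype=torch.uint8)
recv_buf.copy_(payload)                              # stands in for the recv
assert pickle.loads(recv_buf.numpy().tobytes()) == obj
# --- end of sketch ---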
+ ) obj = pickle.loads(object_tensor.numpy().tobytes()) @@ -577,13 +678,13 @@ class GroupCoordinator: tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None + metadata_group: Optional[ProcessGroup] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. - if (not torch.distributed.is_initialized() or self.world_size == 1): + if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict group = self.device_group @@ -593,9 +694,9 @@ class GroupCoordinator: rank_in_group = self.rank_in_group if rank_in_group == src: metadata_list: list[tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `broadcast_object_list` has serialization & deserialization, @@ -608,16 +709,14 @@ class GroupCoordinator: continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=metadata_group, async_op=True + ) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -628,9 +727,9 @@ class GroupCoordinator: async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor @@ -641,14 +740,13 @@ class GroupCoordinator: tensor, src=self.ranks[src], group=metadata_group, - async_op=True) + async_op=True, + ) else: # use group for GPU tensors handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) tensor_dict[key] = tensor else: @@ -662,18 +760,33 @@ class GroupCoordinator: tensor_dict: dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. + + all_gather_group: The group for the all-gather operation. If provided, + an optimization is enabled where each rank in the group sends a + slice of a tensor and the receiver reconstructs it using an + all-gather, which can improve performance. This is typically the + tensor-parallel group. + all_gather_tensors: A dictionary to specify which tensors should use + the all-gather optimization, which is only effective when + `all_gather_group` is provided. 
By default, this optimization is + on for any tensor whose size is divisible by the + `all_gather_group`'s world size. However, it should be disabled + for tensors that are not fully replicated across the group (e.g., + the residual tensor when sequence parallelism is enabled). This + dictionary allows overriding the default behavior on a per-tensor + basis. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) group = self.device_group metadata_group = self.cpu_group @@ -686,56 +799,81 @@ class GroupCoordinator: if self.device_communicator is None: raise ValueError("No device communicator found") self.device_communicator.send_tensor_dict( # type: ignore - tensor_dict, dst) + tensor_dict, dst + ) return None metadata_list: list[tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), f"Expecting a dictionary, got {type(tensor_dict)}" + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: + + tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)] + assert len(tensor_keys) == len(tensor_list) + + for key, tensor in zip(tensor_keys, tensor_list): if tensor.numel() == 0: # Skip sending empty tensors. continue # send-allgather: send only a slice, then do allgather. - if (all_gather_group is not None - and tensor.numel() % all_gather_size == 0): + use_all_gather = ( + all_gather_group is not None and tensor.numel() % all_gather_size == 0 + ) + use_all_gather = ( + all_gather_tensors.get(key, use_all_gather) + if all_gather_tensors + else use_all_gather + ) + if use_all_gather: tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=metadata_group) + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) else: # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None def recv_tensor_dict( self, src: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. + + all_gather_group: The group for the all-gather operation. If provided, + an optimization is enabled where each rank in the group sends a + slice of a tensor and the receiver reconstructs it using an + all-gather, which can improve performance. This is typically the + tensor-parallel group. + all_gather_tensors: A dictionary to specify which tensors should use + the all-gather optimization, which is only effective when + `all_gather_group` is provided. 
By default, this optimization is + on for any tensor whose size is divisible by the + `all_gather_group`'s world size. However, it should be disabled + for tensors that are not fully replicated across the group (e.g., + the residual tensor when sequence parallelism is enabled). This + dictionary allows overriding the default behavior on a per-tensor + basis. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return None - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) group = self.device_group metadata_group = self.cpu_group @@ -748,43 +886,47 @@ class GroupCoordinator: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.recv_tensor_dict( # type: ignore - src) + src + ) recv_metadata_list = self.recv_object(src=src) tensor_dict: dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue # send-allgather: send only a slice, then do allgather. - use_all_gather = (all_gather_group is not None - and tensor.numel() % all_gather_size == 0) + use_all_gather = ( + all_gather_group is not None + and tensor.numel() % all_gather_size == 0 + ) + use_all_gather = ( + all_gather_tensors.get(key, use_all_gather) + if all_gather_tensors + else use_all_gather + ) if use_all_gather: orig_shape = tensor.shape - tensor = tensor.reshape(all_gather_size, - -1)[all_gather_rank] + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=metadata_group) + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) else: # use group for GPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) + torch.distributed.recv(tensor, src=self.ranks[src], group=group) if use_all_gather: # do the allgather tensor = all_gather_group.all_gather( # type: ignore - tensor, dim=0) + tensor, dim=0 + ) tensor = tensor.reshape(orig_shape) tensor_dict[key] = tensor @@ -808,10 +950,9 @@ class GroupCoordinator: raise ValueError("No device communicator found") self.device_communicator.send(tensor, dst) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if self.device_communicator is None: @@ -832,21 +973,26 @@ class GroupCoordinator: def prepare_communication_buffer_for_model(self, model: torch.nn.Module): if self.device_communicator is not None: - self.device_communicator.prepare_communication_buffer_for_model( - model) + self.device_communicator.prepare_communication_buffer_for_model(model) def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, 
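# --- Illustrative sketch (single process; a list of slices stands in for the
# ranks of the all-gather group): the send-allgather optimization transmits
# only a 1/world_size slice of a tensor that is replicated across the group,
# and the receiver rebuilds it with an all_gather along dim 0 plus a reshape.
# Assumed sizes: a group of 4 and a replicated (8, 4) tensor.
import torch

all_gather_size = 4
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
orig_shape = full.shape

# what each rank r would actually send:
slices = [full.reshape(all_gather_size, -1)[r] for r in range(all_gather_size)]
# receiver side: the result an all_gather(dim=0) would produce, then restore shape
rebuilt = torch.cat(slices, dim=0).reshape(orig_shape)
assert torch.equal(rebuilt, full)

# Tensors that are not replicated across the group (e.g. the residual under
# sequence parallelism) must opt out, e.g. all_gather_tensors={"residual": False}.
# --- end of sketch ---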
torch.Tensor]: + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is not None: - return self.device_communicator.dispatch(hidden_states, - router_logits) + return self.device_communicator.dispatch( + hidden_states, router_logits, is_sequence_parallel + ) else: return hidden_states, router_logits - def combine(self, hidden_states) -> torch.Tensor: + def combine( + self, hidden_states, is_sequence_parallel: bool = False + ) -> torch.Tensor: if self.device_communicator is not None: - return self.device_communicator.combine(hidden_states) + return self.device_communicator.combine(hidden_states, is_sequence_parallel) else: return hidden_states @@ -856,12 +1002,13 @@ _NODE_COUNT: Optional[int] = None def get_world_group() -> GroupCoordinator: - assert _WORLD is not None, ("world group is not initialized") + assert _WORLD is not None, "world group is not initialized" return _WORLD -def init_world_group(ranks: list[int], local_rank: int, - backend: str) -> GroupCoordinator: +def init_world_group( + ranks: list[int], local_rank: int, backend: str +) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], local_rank=local_rank, @@ -878,7 +1025,6 @@ def init_model_parallel_group( use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ) -> GroupCoordinator: - return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -893,24 +1039,37 @@ _TP: Optional[GroupCoordinator] = None def get_tp_group() -> GroupCoordinator: - assert _TP is not None, ("tensor model parallel group is not initialized") + assert _TP is not None, "tensor model parallel group is not initialized" return _TP -@deprecated("`get_tensor_model_parallel_group` has been replaced with " - "`get_tp_group` and may be removed after v0.12. Please use " - "`get_tp_group` instead.") +@deprecated( + "`get_tensor_model_parallel_group` has been replaced with " + "`get_tp_group` and may be removed after v0.12. Please use " + "`get_tp_group` instead." +) def get_tensor_model_parallel_group(): return get_tp_group() +_DCP: Optional[GroupCoordinator] = None + + +def get_dcp_group() -> GroupCoordinator: + assert _DCP is not None, "decode context model parallel group is not initialized" + return _DCP + + +# kept for backward compatibility +get_context_model_parallel_group = get_dcp_group + _PP: Optional[GroupCoordinator] = None _DP: Optional[GroupCoordinator] = None def get_dp_group() -> GroupCoordinator: - assert _DP is not None, ("data parallel group is not initialized") + assert _DP is not None, "data parallel group is not initialized" return _DP @@ -918,19 +1077,20 @@ _EP: Optional[GroupCoordinator] = None def get_ep_group() -> GroupCoordinator: - assert _EP is not None, ("expert parallel group is not initialized") + assert _EP is not None, "expert parallel group is not initialized" return _EP def get_pp_group() -> GroupCoordinator: - assert _PP is not None, ( - "pipeline model parallel group is not initialized") + assert _PP is not None, "pipeline model parallel group is not initialized" return _PP -@deprecated("`get_pipeline_model_parallel_group` has been replaced with " - "`get_pp_group` and may be removed in v0.12. Please use " - "`get_pp_group` instead.") +@deprecated( + "`get_pipeline_model_parallel_group` has been replaced with " + "`get_pp_group` and may be removed in v0.12. Please use " + "`get_pp_group` instead." 
+) def get_pipeline_model_parallel_group(): return get_pp_group() @@ -939,8 +1099,8 @@ def get_pipeline_model_parallel_group(): def graph_capture(device: torch.device): """ `graph_capture` is a context manager which should surround the code that - is capturing the CUDA graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph + is capturing the CUDA graph. Its main purpose is to ensure that some + operations will be run after the graph is captured, before the graph is replayed. It returns a `GraphCaptureContext` object which contains the necessary data for the graph capture. Currently, it only contains the stream that the graph capture is running on. This stream is set to the @@ -951,8 +1111,7 @@ def graph_capture(device: torch.device): from other kernels possibly launched on background in the default stream. """ context = GraphCaptureContext(torch.cuda.Stream(device=device)) - with get_tp_group().graph_capture(context), get_pp_group().graph_capture( - context): + with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context): yield context @@ -972,14 +1131,24 @@ def init_distributed_environment( distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", + timeout: Optional[timedelta] = None, ): logger.debug( - "world_size=%d rank=%d local_rank=%d " - "distributed_init_method=%s backend=%s", world_size, rank, local_rank, - distributed_init_method, backend) + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) from vllm.config import get_current_vllm_config + config = get_current_vllm_config() - if config is not None and config.parallel_config.data_parallel_size > 1: + if ( + config is not None + and config.parallel_config.data_parallel_size > 1 + and config.parallel_config.distributed_executor_backend != "external_launcher" + ): parallel_config = config.parallel_config # adjust to take into account data parallelism # offset the rank by the data parallel rank @@ -991,49 +1160,55 @@ def init_distributed_environment( distributed_init_method = get_distributed_init_method(ip, port) logger.info( "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, rank, distributed_init_method) + world_size, + rank, + distributed_init_method, + ) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " - "distributed environment") + "distributed environment" + ) if not torch.distributed.is_backend_available(backend): logger.warning( - "Distributed backend %s is not available; " - "falling back to gloo.", backend) + "Distributed backend %s is not available; falling back to gloo.", + backend, + ) assert torch.distributed.is_gloo_available(), ( - "Fallback Gloo backend is not available.") + "Fallback Gloo backend is not available." 
+ ) backend = "gloo" # this backend is used for WORLD torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, world_size=world_size, - rank=rank) + rank=rank, + timeout=timeout, + ) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 if local_rank == -1: # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank - if distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK - else: - local_rank = rank + local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank global _WORLD, _NODE_COUNT if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) _NODE_COUNT = _node_count(_WORLD.cpu_group) - logger.debug("Detected %d nodes in the distributed environment", - _NODE_COUNT) + logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( - "world group already initialized with a different world size") + "world group already initialized with a different world size" + ) def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """ @@ -1063,11 +1238,11 @@ def initialize_model_parallel( assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + backend = backend or torch.distributed.get_backend(get_world_group().device_group) data_parallel_size = 1 from vllm.config import get_current_vllm_config + config = get_current_vllm_config() if config is not None: data_parallel_size = config.parallel_config.data_parallel_size @@ -1082,88 +1257,115 @@ def initialize_model_parallel( # to get group_ranks for each dimension, transpose that dimension to the # last dimension, then reshape to 2D, then unbind the last dimension all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, - tensor_model_parallel_size) # noqa + -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size + ) # noqa # Build the tensor model-parallel groups. global _TP - assert _TP is None, ("tensor model parallel group is already initialized") + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="tp") + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + ) + + # Build the DCP model-parallel groups. + global _DCP + assert _DCP is None, "decode context model parallel group is already initialized" + # Note(hc): In the current implementation of decode context parallel, + # dcp_size must not exceed tp_size, because the world size does not + # change by DCP, it simply reuses the GPUs of TP group, and split one + # TP group into tp_size//dcp_size DCP groups. 
+ group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DCP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp", + ) # Build the pipeline model-parallel groups. global _PP - assert _PP is None, ( - "pipeline model parallel group is already initialized") - group_ranks = all_ranks.transpose(2, 3).reshape( - -1, pipeline_model_parallel_size).unbind(0) + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] - _PP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="pp") + _PP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="pp" + ) global _DP - assert _DP is None, ("data parallel group is already initialized") - group_ranks = all_ranks.transpose(1, - 3).reshape(-1, - data_parallel_size).unbind(0) + assert _DP is None, "data parallel group is already initialized" + group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="dp") + _DP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dp" + ) global _EP - assert _EP is None, ("expert parallel group is already initialized") - group_ranks = all_ranks.transpose(1, 2).reshape( - -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + assert _EP is None, "expert parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(1, 2) + .reshape(-1, data_parallel_size * tensor_model_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") + _EP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="ep" + ) logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, - _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, - _EP.rank_in_group) + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", + rank, + world_size, + _DP.rank_in_group, + _PP.rank_in_group, + _TP.rank_in_group, + _EP.rank_in_group, + ) def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. 
""" - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + backend = backend or torch.distributed.get_backend(get_world_group().device_group) if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, backend) + initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + decode_context_model_parallel_size, + backend, + ) return - assert ( - get_tensor_model_parallel_world_size() == tensor_model_parallel_size - ), ("tensor parallel group already initialized, but of unexpected size. " + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size. " f"got: {get_tensor_model_parallel_world_size()=} vs. " - f"wanted: {tensor_model_parallel_size=}") + f"wanted: {tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size - assert (pp_world_size == pipeline_model_parallel_size), ( + assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size. " f"got: {pp_world_size=} vs. " - f"wanted: {pipeline_model_parallel_size=}") + f"wanted: {pipeline_model_parallel_size=}" + ) def prepare_communication_buffer_for_model(model: torch.nn.Module): @@ -1185,7 +1387,7 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP is not None and _PP is not None) + return _TP is not None and _PP is not None _TP_STATE_PATCHED = False @@ -1226,10 +1428,19 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_decode_context_model_parallel_world_size(): + """Return world size for the decode context model parallel group.""" + return get_dcp_group().world_size + + +def get_decode_context_model_parallel_rank(): + """Return my rank for the decode context model parallel group.""" + return get_dcp_group().rank_in_group + + def get_node_count() -> int: - """Return the total number of nodes in the distributed environment. 
""" - assert _NODE_COUNT is not None, ( - "distributed environment is not initialized") + """Return the total number of nodes in the distributed environment.""" + assert _NODE_COUNT is not None, "distributed environment is not initialized" return _NODE_COUNT @@ -1246,6 +1457,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DCP + if _DCP: + _DCP.destroy() + _DCP = None + global _DP if _DP: _DP.destroy() @@ -1272,9 +1488,11 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): destroy_distributed_environment() if shutdown_ray: import ray # Lazy import Ray + ray.shutdown() gc.collect() from vllm.platforms import current_platform + empty_cache = current_platform.empty_cache if empty_cache is not None: empty_cache() @@ -1282,21 +1500,21 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): if not current_platform.is_cpu(): torch._C._host_emptyCache() except AttributeError: - logger.warning( - "torch._C._host_emptyCache() only available in Pytorch >=2.5") + logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5") -def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], - source_rank: int = 0) -> list[bool]: +def in_the_same_node_as( + pg: Union[ProcessGroup, StatelessProcessGroup], source_rank: int = 0 +) -> list[bool]: """ This is a collective operation that returns if each rank is in the same node as the source rank. It tests if processes are attached to the same memory system (shared access to shared memory). """ if isinstance(pg, ProcessGroup): - assert torch.distributed.get_backend( - pg) != torch.distributed.Backend.NCCL, ( - "in_the_same_node_as should be tested with a non-NCCL group.") + assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group." + ) # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) @@ -1319,10 +1537,11 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], if rank == source_rank: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) - shm.buf[:len(magic_message)] = magic_message + shm.buf[: len(magic_message)] = magic_message if isinstance(pg, ProcessGroup): torch.distributed.broadcast_object_list( - [shm.name], src=ranks[source_rank], group=pg) + [shm.name], src=ranks[source_rank], group=pg + ) else: pg.broadcast_obj(shm.name, src=source_rank) is_in_the_same_node[rank] = 1 @@ -1331,17 +1550,20 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], if isinstance(pg, ProcessGroup): recv = [None] torch.distributed.broadcast_object_list( - recv, src=ranks[source_rank], group=pg) + recv, src=ranks[source_rank], group=pg + ) name = recv[0] else: name = pg.broadcast_obj(None, src=source_rank) # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. 
- with patch("multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None): + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): shm = shared_memory.SharedMemory(name=name) - if shm.buf[:len(magic_message)] == magic_message: + if shm.buf[: len(magic_message)] == magic_message: is_in_the_same_node[rank] = 1 except Exception as e: logger.error("Error ignored in is_in_the_same_node: %s", e) diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py index 0a786b4a1708f..3db25d1a19641 100644 --- a/vllm/distributed/tpu_distributed_utils.py +++ b/vllm/distributed/tpu_distributed_utils.py @@ -10,18 +10,17 @@ import torch_xla.distributed.spmd as xs from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) logger = init_logger(__name__) class XlaQKVParallelLinear(nn.Module): - - def __init__(self, - qkv_linear: nn.Module, - mesh: Optional["xs.Mesh"] = None): + def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None): super().__init__() assert isinstance(qkv_linear, QKVParallelLinear) self.skip_bias_add = qkv_linear.skip_bias_add @@ -39,21 +38,22 @@ class XlaQKVParallelLinear(nn.Module): self._shard_weight(mesh) def _shard_weight(self, mesh: "xs.Mesh"): - self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False) - self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False) - self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False) - xs.mark_sharding(self.q_weight, mesh, ('x', None)) - xs.mark_sharding(self.k_weight, mesh, ('x', None)) - xs.mark_sharding(self.v_weight, mesh, ('x', None)) + self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False) + self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False) + self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False) + xs.mark_sharding(self.q_weight, mesh, ("x", None)) + xs.mark_sharding(self.k_weight, mesh, ("x", None)) + xs.mark_sharding(self.v_weight, mesh, ("x", None)) if self.q_bias is not None: - assert self.k_bias is not None and self.v_bias is not None, \ + assert self.k_bias is not None and self.v_bias is not None, ( "QKVParallelLinear should have q, k, and v biases together." - self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.q_bias, mesh, ('x', )) - self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.k_bias, mesh, ('x', )) - self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.v_bias, mesh, ('x', )) + ) + self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.q_bias, mesh, ("x",)) + self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.k_bias, mesh, ("x",)) + self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.v_bias, mesh, ("x",)) def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): q_proj_size, k_proj_size, _ = qkv_linear.output_sizes @@ -61,22 +61,25 @@ class XlaQKVParallelLinear(nn.Module): # along the output dimension. 
qkv_weight = qkv_linear.weight.data.cpu() q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) - k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size], - requires_grad=False) - v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:], - requires_grad=False) + k_weight = Parameter( + qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False + ) + v_weight = Parameter( + qkv_weight[q_proj_size + k_proj_size :], requires_grad=False + ) self.register_parameter("q_weight", q_weight) self.register_parameter("k_weight", k_weight) self.register_parameter("v_weight", v_weight) if qkv_linear.bias is not None: - q_bias = Parameter(qkv_linear.bias[:q_proj_size], - requires_grad=False) - k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size + - k_proj_size], - requires_grad=False) - v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:], - requires_grad=False) + q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False) + k_bias = Parameter( + qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size], + requires_grad=False, + ) + v_bias = Parameter( + qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False + ) self.register_parameter("q_bias", q_bias) self.register_parameter("k_bias", k_bias) self.register_parameter("v_bias", v_bias) @@ -102,42 +105,48 @@ class XlaQKVParallelLinear(nn.Module): # The concat and the following split will be noop, and should be # optimized away by the compiler. qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) - output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \ - self.skip_bias_add else None + output_bias = ( + torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None + ) if not self.return_bias: return qkv_proj return qkv_proj, output_bias -def partition_column_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_column_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, ColumnParallelLinear) - xs.mark_sharding(layer.weight, mesh, ('x', None)) + xs.mark_sharding(layer.weight, mesh, ("x", None)) logger.debug("Applied column-parallel sharding to %s", layer) return layer -def partition_row_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_row_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, RowParallelLinear) - xs.mark_sharding(layer.weight, mesh, (None, 'x')) + xs.mark_sharding(layer.weight, mesh, (None, "x")) logger.debug("Applied row-parallel sharding to %s", layer) return layer -def partition_qkv_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_qkv_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, QKVParallelLinear) xla_layer = XlaQKVParallelLinear(layer, mesh) logger.debug("Applied qkv parallel sharding to %s", layer) return xla_layer -MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([ - ("QKVParallelLinear", partition_qkv_parallel_linear), - ("ColumnParallelLinear", partition_column_parallel_linear), - ("RowParallelLinear", partition_row_parallel_linear), -]) +MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict( + [ + ("QKVParallelLinear", partition_qkv_parallel_linear), + ("ColumnParallelLinear", partition_column_parallel_linear), + ("RowParallelLinear", partition_row_parallel_linear), + ] +) def get_fqn(module): @@ -147,9 +156,9 @@ def get_fqn(module): def shard_model(model: 
torch.nn.Module, mesh: "xs.Mesh") -> None: """ - Recursively check a PyTorch model and apply appropriate sharding based on + Recursively check a PyTorch model and apply appropriate sharding based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping. - + Args: model: torch.nn.Module to process mesh: An XLA SPMD mesh object used for sharding @@ -161,7 +170,8 @@ def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: wrapped_module = wrapping_func(module, mesh) assert parent is not None and name is not None, ( - "Top Level module is not expected to be wrapped.") + "Top Level module is not expected to be wrapped." + ) if wrapped_module is not module: # Wrapped module and module are different py object. # The original module should be replaced by the diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 67f71643d039c..a35f28c25385a 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -19,9 +19,12 @@ from typing import Any, Optional import torch from torch.distributed import ProcessGroup, TCPStore -from torch.distributed.distributed_c10d import (Backend, PrefixStore, - _get_default_timeout, - _unregister_process_group) +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _get_default_timeout, + _unregister_process_group, +) from torch.distributed.rendezvous import rendezvous import vllm.envs as envs @@ -33,9 +36,9 @@ logger = init_logger(__name__) # We prefer to use os.sched_yield as it results in tighter polling loops, # measured to be around 3e-7 seconds. However on earlier versions of Python # os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) -USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) - or (sys.version_info[:2] == (3, 10) - and sys.version_info[2] >= 8)) +USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)) or ( + sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8 +) def sched_yield(): @@ -48,7 +51,8 @@ def sched_yield(): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format( - numerator, denominator) + numerator, denominator + ) def divide(numerator, denominator): @@ -63,16 +67,16 @@ def split_tensor_along_last_dim( num_partitions: int, contiguous_split_chunks: bool = False, ) -> Sequence[torch.Tensor]: - """ Split a tensor along its last dimension. + """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. - Returns: - A list of Tensors + Returns: + A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 @@ -86,8 +90,9 @@ def split_tensor_along_last_dim( return tensor_list -def get_pp_indices(num_hidden_layers: int, pp_rank: int, - pp_size: int) -> tuple[int, int]: +def get_pp_indices( + num_hidden_layers: int, pp_rank: int, pp_size: int +) -> tuple[int, int]: """Try to evenly distribute layers across partitions. 
If the number of layers is not divisible by the number of partitions, @@ -104,17 +109,15 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, partition_list_str = envs.VLLM_PP_LAYER_PARTITION if partition_list_str is not None: try: - partitions = [ - int(layer) for layer in partition_list_str.split(",") - ] + partitions = [int(layer) for layer in partition_list_str.split(",")] except ValueError as err: - raise ValueError("Invalid partition string: {}".format( - partition_list_str)) from err + raise ValueError( + "Invalid partition string: {}".format(partition_list_str) + ) from err if len(partitions) != pp_size: raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") if sum(partitions) != num_hidden_layers: - raise ValueError( - f"{sum(partitions)=} does not match {num_hidden_layers=}.") + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") else: layers_per_partition = num_hidden_layers // pp_size partitions = [layers_per_partition for _ in range(pp_size)] @@ -126,7 +129,8 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, "Hidden layers were unevenly partitioned: [%s]. " "This can be manually overridden using the " "VLLM_PP_LAYER_PARTITION environment variable", - ",".join(str(p) for p in partitions)) + ",".join(str(p) for p in partitions), + ) start_layer = sum(partitions[:pp_rank]) end_layer = start_layer + partitions[pp_rank] @@ -140,6 +144,7 @@ class StatelessProcessGroup: group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects. """ + rank: int world_size: int store: torch._C._distributed_c10d.Store @@ -154,21 +159,16 @@ class StatelessProcessGroup: # src rank -> counter recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) broadcast_send_counter: int = 0 - broadcast_recv_src_counter: dict[int, int] = dataclasses.field( - default_factory=dict) + broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) # A deque to store the data entries, with key and timestamp. 
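# --- Illustrative sketch (one process; a dict stands in for the TCP store;
# the send-side key format mirrors the one visible in recv_obj): each
# send/recv pair derives its store key from a per-peer counter, so repeated
# sends between the same ranks never collide.
import pickle

store: dict[str, bytes] = {}
send_dst_counter = {1: 0}     # held by rank 0
recv_src_counter = {0: 0}     # held by rank 1

def send_obj(obj, dst: int) -> None:                 # rank 0's side
    key = f"send_to/{dst}/{send_dst_counter[dst]}"
    store[key] = pickle.dumps(obj)
    send_dst_counter[dst] += 1

def recv_obj(rank: int, src: int):                   # rank 1's side
    key = f"send_to/{rank}/{recv_src_counter[src]}"
    recv_src_counter[src] += 1
    return pickle.loads(store[key])

send_obj({"step": 1}, dst=1)
assert recv_obj(rank=1, src=0) == {"step": 1}
# --- end of sketch ---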
- entries: deque[tuple[str, - float]] = dataclasses.field(default_factory=deque) + entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque) def __post_init__(self): assert self.rank < self.world_size self.send_dst_counter = {i: 0 for i in range(self.world_size)} self.recv_src_counter = {i: 0 for i in range(self.world_size)} - self.broadcast_recv_src_counter = { - i: 0 - for i in range(self.world_size) - } + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} def send_obj(self, obj: Any, dst: int): """Send an object to a destination rank.""" @@ -192,8 +192,8 @@ class StatelessProcessGroup: def recv_obj(self, src: int) -> Any: """Receive an object from a source rank.""" obj = pickle.loads( - self.store.get( - f"send_to/{self.rank}/{self.recv_src_counter[src]}")) + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) self.recv_src_counter[src] += 1 return obj @@ -204,15 +204,13 @@ class StatelessProcessGroup: """ if self.rank == src: self.expire_data() - key = (f"broadcast_from/{src}/" - f"{self.broadcast_send_counter}") + key = f"broadcast_from/{src}/{self.broadcast_send_counter}" self.store.set(key, pickle.dumps(obj)) self.broadcast_send_counter += 1 self.entries.append((key, time.time())) return obj else: - key = (f"broadcast_from/{src}/" - f"{self.broadcast_recv_src_counter[src]}") + key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}" recv_obj = pickle.loads(self.store.get(key)) self.broadcast_recv_src_counter[src] += 1 return recv_obj @@ -278,8 +276,7 @@ class StatelessProcessGroup: # Check for timeout cur_time = time.time() if cur_time - start_time > timeout: - raise RuntimeError("Barrier timed out after %f seconds", - timeout) + raise RuntimeError("Barrier timed out after %f seconds", timeout) # Check for each process for i in range(self.world_size): @@ -326,8 +323,7 @@ class StatelessProcessGroup: while len(processes_departed) < self.world_size: # Check for timeout if time.time() - start_time > timeout: - raise RuntimeError("Barrier departure timed out after %f s", - timeout) + raise RuntimeError("Barrier departure timed out after %f s", timeout) # Check for each process for i in range(self.world_size): @@ -356,14 +352,12 @@ class StatelessProcessGroup: try: self.store.delete_key(f"arrival_{barrier_id}_{i}") except Exception: - logger.debug("Error deleting key: %s", - f'arrival_{barrier_id}_{i}') + logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}") try: self.store.delete_key(f"departure_{barrier_id}_{i}") except Exception: - logger.debug("Error deleting key: %s", - f'departure_{barrier_id}_{i}') + logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}") @staticmethod def create( @@ -388,7 +382,7 @@ class StatelessProcessGroup: used for exchanging metadata. With this function, process A and process B can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. 
- """ # noqa + """ # noqa launch_server = rank == 0 if launch_server: # listen on the specified interface (instead of 0.0.0.0) @@ -416,14 +410,19 @@ class StatelessProcessGroup: world_size=world_size, store=store, socket=listen_socket, - data_expiration_seconds=data_expiration_seconds) + data_expiration_seconds=data_expiration_seconds, + ) -def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, - group_rank: int, group_size: int, - timeout: timedelta) -> ProcessGroup: +def init_gloo_process_group( + backend: Backend, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, +) -> ProcessGroup: """ - Stateless init ProcessGroup with gloo backend compatible with + Stateless init ProcessGroup with gloo backend compatible with different torch versions. """ if is_torch_equal_or_newer("2.6"): @@ -441,10 +440,10 @@ def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, options, ) from torch.distributed.distributed_c10d import ProcessGroupGloo - backend_class = ProcessGroupGloo(prefix_store, - group_rank, - group_size, - timeout=timeout) + + backend_class = ProcessGroupGloo( + prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.GLOO device = torch.device("cpu") if is_torch_equal_or_newer("2.6"): @@ -457,8 +456,8 @@ def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, - backend: str) -> ProcessGroup: + host: str, port: int, rank: int, world_size: int, backend: str +) -> ProcessGroup: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. The created ProcessGroup object can be used for @@ -495,7 +494,8 @@ def stateless_init_torch_distributed_process_group( timeout = _get_default_timeout(backend) store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout)) + rendezvous(init_method, rank, world_size, timeout=timeout) + ) store.set_timeout(timeout) group_rank = rank @@ -506,22 +506,25 @@ def stateless_init_torch_distributed_process_group( prefix_store = PrefixStore(init_method, store) if backend == "gloo": - return init_gloo_process_group(backend=backend, - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout) + return init_gloo_process_group( + backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout, + ) from vllm.platforms import current_platform + return current_platform.stateless_init_device_torch_dist_pg( backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, - timeout=timeout) + timeout=timeout, + ) -def stateless_destroy_torch_distributed_process_group( - pg: ProcessGroup) -> None: +def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: """ Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). @@ -531,6 +534,7 @@ def stateless_destroy_torch_distributed_process_group( else: # Lazy import for non-CUDA backends. 
from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) _unregister_process_group(pg.group_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 965264ee3097a..cb47e439fc733 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable import argparse import copy import dataclasses @@ -10,43 +9,80 @@ import json import sys from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, - Literal, Optional, Type, TypeVar, Union, cast, get_args, - get_origin) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Literal, + Optional, + TypeVar, + Union, + cast, + get_args, + get_origin, +) import huggingface_hub import regex as re import torch from pydantic import TypeAdapter, ValidationError +from pydantic.fields import FieldInfo from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigFormat, ConfigType, ConvertOption, - DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, EPLBConfig, - GuidedDecodingBackend, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ModelImpl, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, - SchedulerPolicy, SpeculativeConfig, TaskOption, - TokenizerMode, VllmConfig, get_attr_docs, get_field) +from vllm.config import ( + CacheConfig, + CompilationConfig, + ConfigType, + DeviceConfig, + EPLBConfig, + KVEventsConfig, + KVTransferConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ObservabilityConfig, + ParallelConfig, + PoolerConfig, + SchedulerConfig, + SpeculativeConfig, + StructuredOutputsConfig, + VllmConfig, + get_attr_docs, +) +from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo +from vllm.config.device import Device +from vllm.config.model import ( + ConvertOption, + HfOverrides, + LogprobsMode, + ModelDType, + RunnerOption, + TaskOption, + TokenizerMode, +) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode +from vllm.config.observability import DetailedTraceModules +from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy +from vllm.config.scheduler import SchedulerPolicy +from vllm.config.utils import get_field from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import get_model_path, is_interleaved +from vllm.transformers_utils.config import ( + get_model_path, + is_interleaved, + maybe_override_with_speculators, +) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, - GiB_bytes, get_ip, is_in_ray_actor) +from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor from vllm.v1.sample.logits_processor import LogitsProcessor -# yapf: 
enable - if TYPE_CHECKING: from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods @@ -67,20 +103,18 @@ TypeHintT = Union[type[T], object] def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _parse_type(val: str) -> T: try: return return_type(val) except ValueError as e: raise argparse.ArgumentTypeError( - f"Value {val} cannot be converted to {return_type}.") from e + f"Value {val} cannot be converted to {return_type}." + ) from e return _parse_type -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: - +def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: def _optional_type(val: str) -> Optional[T]: if val == "" or val == "None": return None @@ -121,7 +155,8 @@ def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: if not all(isinstance(option, option_type) for option in options): raise ValueError( "All options must be of the same type. " - f"Got {options} with types {[type(c) for c in options]}") + f"Got {options} with types {[type(c) for c in options]}" + ) kwarg = "metavar" if contains_type(type_hints, str) else "choices" return {"type": option_type, kwarg: sorted(options)} @@ -152,9 +187,17 @@ def is_online_quantization(quantization: Any) -> bool: return quantization in ["inc"] +NEEDS_HELP = ( + any("--help" in arg for arg in sys.argv) # vllm SUBCOMMAND --help + or (argv0 := sys.argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND +) + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: - cls_docs = get_attr_docs(cls) + # Save time only getting attr docs if we're generating help text + cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} for field in fields(cls): # Get the set of possible types for the field @@ -167,12 +210,19 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the default value of the field if field.default is not MISSING: default = field.default + # Handle pydantic.Field defaults + if isinstance(default, FieldInfo): + default = ( + default.default + if default.default_factory is None + else default.default_factory() + ) elif field.default_factory is not MISSING: default = field.default_factory() # Get the help text for the field name = field.name - help = cls_docs[name].strip() + help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -180,8 +230,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = ("Should either be a valid JSON string or JSON keys passed " - "individually.") + json_tip = ( + "Should either be a valid JSON string or JSON keys passed individually." + ) if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: @@ -203,7 +254,8 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: tuple_type = types[0] assert all(t is tuple_type for t in types if t is not Ellipsis), ( "All non-Ellipsis tuple elements must be of the same " - f"type. Got {types}.") + f"type. Got {types}." 
+ ) kwargs[name]["type"] = tuple_type kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) elif contains_type(type_hints, list): @@ -219,23 +271,30 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: elif contains_type(type_hints, int): kwargs[name]["type"] = int # Special case for large integers - if name in {"max_model_len", "max_num_batched_tokens"}: + human_readable_ints = { + "max_model_len", + "max_num_batched_tokens", + "kv_cache_memory_bytes", + } + if name in human_readable_ints: kwargs[name]["type"] = human_readable_int + kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif (contains_type(type_hints, dict) - and (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints))): + elif contains_type(type_hints, dict) and ( + contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints) + ): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += f"\n\n{json_tip}" - elif (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints)): + elif contains_type(type_hints, str) or any( + is_not_builtin(th) for th in type_hints + ): kwargs[name]["type"] = str else: - raise ValueError( - f"Unsupported type {type_hints} for argument {name}.") + raise ValueError(f"Unsupported type {type_hints} for argument {name}.") # If the type hint was a sequence of literals, use the helper function # to update the type and choices @@ -254,6 +313,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: def get_kwargs(cls: ConfigType) -> dict[str, Any]: """Return argparse kwargs for the given Config dataclass. + If `--help` or `mkdocs` are not present in the command line command, the + attribute documentation will not be included in the help output. + The heavy computation is cached via functools.lru_cache, and a deep copy is returned so callers can mutate the dictionary without affecting the cached version. @@ -264,9 +326,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str = ModelConfig.model - served_model_name: Optional[Union[ - str, List[str]]] = ModelConfig.served_model_name + served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name tokenizer: Optional[str] = ModelConfig.tokenizer hf_config_path: Optional[str] = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner @@ -277,24 +339,26 @@ class EngineArgs: tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path + allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains download_dir: Optional[str] = LoadConfig.download_dir + safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = ModelConfig.seed max_model_len: Optional[int] = ModelConfig.max_model_len - cuda_graph_sizes: list[int] = get_field(SchedulerConfig, - "cuda_graph_sizes") + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. 
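(Illustrative aside, not part of the diff.) A hedged sketch of what get_kwargs derives from a config dataclass's fields and attribute docstrings; ToyConfig and its fields are hypothetical, and the shown outputs are approximate (help text is only populated when NEEDS_HELP detects --help or mkdocs, per the change above).

from dataclasses import dataclass
from typing import Literal

@dataclass
class ToyConfig:
    block_size: int = 16
    """Token block size."""
    policy: Literal["fcfs", "priority"] = "fcfs"
    """Scheduling policy."""

toy_kwargs = get_kwargs(ToyConfig)
# Approximately:
#   toy_kwargs["block_size"] -> {"default": 16, "help": "...", "type": int}
#   toy_kwargs["policy"]     -> {"default": "fcfs", "help": "...",
#                                "type": str, "choices": ["fcfs", "priority"]}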
The API may change without # notice. - distributed_executor_backend: Optional[Union[ - str, DistributedExecutorBackend, - Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, type[ExecutorBase]] + ] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -304,29 +368,41 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_dbo: bool = ParallelConfig.enable_dbo + dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold + disable_nccl_for_dp_synchronization: bool = ( + ParallelConfig.disable_nccl_for_dp_synchronization + ) eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb + expert_placement_strategy: ExpertPlacementStrategy = ( + ParallelConfig.expert_placement_strategy + ) + _api_process_count: int = ParallelConfig._api_process_count + _api_process_rank: int = ParallelConfig._api_process_rank num_redundant_experts: int = EPLBConfig.num_redundant_experts eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval eplb_log_balancedness: bool = EPLBConfig.log_balancedness - max_parallel_loading_workers: Optional[ - int] = ParallelConfig.max_parallel_loading_workers + max_parallel_loading_workers: Optional[int] = ( + ParallelConfig.max_parallel_loading_workers + ) block_size: Optional[BlockSize] = CacheConfig.block_size enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching - prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo + ) disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization - max_num_batched_tokens: Optional[ - int] = SchedulerConfig.max_num_batched_tokens + kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes + max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills - long_prefill_token_threshold: int = \ - SchedulerConfig.long_prefill_token_threshold + long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode @@ -340,77 +416,79 @@ class EngineArgs: tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision quantization: Optional[QuantizationMethods] = ModelConfig.quantization enforce_eager: bool = 
ModelConfig.enforce_eager - max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - limit_mm_per_prompt: dict[str, int] = \ - get_field(MultiModalConfig, "limit_per_prompt") + limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field( + MultiModalConfig, "limit_per_prompt" + ) interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings - media_io_kwargs: dict[str, dict[str, - Any]] = get_field(MultiModalConfig, - "media_io_kwargs") - mm_processor_kwargs: Optional[Dict[str, Any]] = \ - MultiModalConfig.mm_processor_kwargs + media_io_kwargs: dict[str, dict[str, Any]] = get_field( + MultiModalConfig, "media_io_kwargs" + ) + mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED - mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_type: Optional[MMCacheType] = ( + MultiModalConfig.mm_processor_cache_type + ) + mm_shm_cache_max_object_size_mb: int = ( + MultiModalConfig.mm_shm_cache_max_object_size_mb + ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode + io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling + video_pruning_rate: float = MultiModalConfig.video_pruning_rate # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: Optional[Dict[str, str]] = \ - LoRAConfig.default_mm_loras + default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[ - int] = CacheConfig.num_gpu_blocks_override + num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots - model_loader_extra_config: dict = \ - get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, - List[str]]] = LoadConfig.ignore_patterns - preemption_mode: Optional[str] = SchedulerConfig.preemption_mode + model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns") - scheduler_delay_factor: float = SchedulerConfig.delay_factor - enable_chunked_prefill: Optional[ - bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( - SchedulerConfig.disable_hybrid_kv_cache_manager) + SchedulerConfig.disable_hybrid_kv_cache_manager + ) - guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend - guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback - guided_decoding_disable_any_whitespace: bool = \ - DecodingConfig.disable_any_whitespace - guided_decoding_disable_additional_properties: bool = \ - DecodingConfig.disable_additional_properties - 
logits_processor_pattern: Optional[ - str] = ModelConfig.logits_processor_pattern + structured_outputs_config: StructuredOutputsConfig = get_field( + VllmConfig, "structured_outputs_config" + ) + reasoning_parser: str = StructuredOutputsConfig.reasoning_parser + # Deprecated guided decoding fields + guided_decoding_backend: Optional[str] = None + guided_decoding_disable_fallback: Optional[bool] = None + guided_decoding_disable_any_whitespace: Optional[bool] = None + guided_decoding_disable_additional_properties: Optional[bool] = None - speculative_config: Optional[Dict[str, Any]] = None + logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern - show_hidden_metrics_for_version: Optional[str] = \ + speculative_config: Optional[dict[str, Any]] = None + + show_hidden_metrics_for_version: Optional[str] = ( ObservabilityConfig.show_hidden_metrics_for_version - otlp_traces_endpoint: Optional[str] = \ - ObservabilityConfig.otlp_traces_endpoint - collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ) + otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = ( ObservabilityConfig.collect_detailed_traces - disable_async_output_proc: bool = not ModelConfig.use_async_output_proc + ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy - scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls - override_neuron_config: dict[str, Any] = \ - get_field(ModelConfig, "override_neuron_config") - override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config + override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( ModelConfig.override_pooler_config - compilation_config: CompilationConfig = \ - get_field(VllmConfig, "compilation_config") + ) + compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -419,8 +497,9 @@ class EngineArgs: generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode - override_generation_config: dict[str, Any] = \ - get_field(ModelConfig, "override_generation_config") + override_generation_config: dict[str, Any] = get_field( + ModelConfig, "override_generation_config" + ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype @@ -428,9 +507,7 @@ class EngineArgs: mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype - additional_config: dict[str, Any] = \ - get_field(VllmConfig, "additional_config") - reasoning_parser: str = DecodingConfig.reasoning_backend + additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location @@ -438,35 +515,36 @@ class EngineArgs: # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False - logits_processors: Optional[list[Union[ - str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = ( + ModelConfig.logits_processors + ) """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - 
kv_sharing_fast_prefill: bool = \ - CacheConfig.kv_sharing_fast_prefill + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object if isinstance(self.compilation_config, dict): - self.compilation_config = CompilationConfig( - **self.compilation_config) + self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.eplb_config, dict): - self.eplb_config = EPLBConfig.from_cli(json.dumps( - self.eplb_config)) + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() # when use hf offline,replace model id to local model path if huggingface_hub.constants.HF_HUB_OFFLINE: model_id = self.model self.model = get_model_path(self.model, self.revision) logger.info( - "HF_HUB_OFFLINE is True, replace model_id [%s] " \ - "to model_path [%s]",model_id, self.model) + "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]", + model_id, + self.model, + ) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -478,95 +556,92 @@ class EngineArgs: title="ModelConfig", description=ModelConfig.__doc__, ) - if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]): model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--convert", **model_kwargs["convert"]) - model_group.add_argument("--task", - **model_kwargs["task"], - deprecated=True) + model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) - model_group.add_argument("--tokenizer-mode", - **model_kwargs["tokenizer_mode"]) - model_group.add_argument("--trust-remote-code", - **model_kwargs["trust_remote_code"]) + model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"]) + model_group.add_argument( + "--trust-remote-code", **model_kwargs["trust_remote_code"] + ) model_group.add_argument("--dtype", **model_kwargs["dtype"]) model_group.add_argument("--seed", **model_kwargs["seed"]) - model_group.add_argument("--hf-config-path", - **model_kwargs["hf_config_path"]) - model_group.add_argument("--allowed-local-media-path", - **model_kwargs["allowed_local_media_path"]) - model_group.add_argument("--revision", **model_kwargs["revision"]) - model_group.add_argument("--code-revision", - **model_kwargs["code_revision"]) - model_group.add_argument("--rope-scaling", - **model_kwargs["rope_scaling"]) - model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) - model_group.add_argument("--tokenizer-revision", - **model_kwargs["tokenizer_revision"]) - model_group.add_argument("--max-model-len", - **model_kwargs["max_model_len"]) - model_group.add_argument("--quantization", "-q", - **model_kwargs["quantization"]) - model_group.add_argument("--enforce-eager", - **model_kwargs["enforce_eager"]) - model_group.add_argument("--max-seq-len-to-capture", - **model_kwargs["max_seq_len_to_capture"]) - model_group.add_argument("--max-logprobs", - **model_kwargs["max_logprobs"]) - model_group.add_argument("--logprobs-mode", - choices=[f.value for f in LogprobsMode], - **model_kwargs["logprobs_mode"]) - model_group.add_argument("--disable-sliding-window", - **model_kwargs["disable_sliding_window"]) - 
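(Illustrative aside, not part of the diff.) The dict coercion in __post_init__ above lets nested configs be passed as plain dictionaries; a hedged example, where the model name and the CompilationConfig "level" field are assumptions, while EPLBConfig's window_size/step_interval are the fields referenced earlier in this file.

args = EngineArgs(
    model="facebook/opt-125m",                      # hypothetical model
    compilation_config={"level": 3},                # assumed CompilationConfig field
    eplb_config={"window_size": 1000, "step_interval": 3000},
)
# __post_init__ converts the dicts into the corresponding config objects.
assert isinstance(args.compilation_config, CompilationConfig)
assert isinstance(args.eplb_config, EPLBConfig)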
model_group.add_argument("--disable-cascade-attn", - **model_kwargs["disable_cascade_attn"]) - model_group.add_argument("--skip-tokenizer-init", - **model_kwargs["skip_tokenizer_init"]) - model_group.add_argument("--enable-prompt-embeds", - **model_kwargs["enable_prompt_embeds"]) - model_group.add_argument("--served-model-name", - **model_kwargs["served_model_name"]) - # This one is a special case because it is the - # opposite of ModelConfig.use_async_output_proc + model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"]) model_group.add_argument( - "--disable-async-output-proc", - action="store_true", - default=EngineArgs.disable_async_output_proc, - help="Disable async output processing. This may result in " - "lower performance.") - model_group.add_argument("--config-format", - choices=[f.value for f in ConfigFormat], - **model_kwargs["config_format"]) + "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"] + ) + model_group.add_argument( + "--allowed-media-domains", **model_kwargs["allowed_media_domains"] + ) + model_group.add_argument("--revision", **model_kwargs["revision"]) + model_group.add_argument("--code-revision", **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"]) + model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) + model_group.add_argument( + "--tokenizer-revision", **model_kwargs["tokenizer_revision"] + ) + model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"]) + model_group.add_argument( + "--disable-sliding-window", **model_kwargs["disable_sliding_window"] + ) + model_group.add_argument( + "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"] + ) + model_group.add_argument( + "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"] + ) + model_group.add_argument( + "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"] + ) + model_group.add_argument( + "--served-model-name", **model_kwargs["served_model_name"] + ) + model_group.add_argument("--config-format", **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. 
TODO: Handle this in get_kwargs - model_group.add_argument("--hf-token", - type=str, - nargs="?", - const=True, - default=model_kwargs["hf_token"]["default"], - help=model_kwargs["hf_token"]["help"]) - model_group.add_argument("--hf-overrides", - **model_kwargs["hf_overrides"]) - model_group.add_argument("--override-neuron-config", - **model_kwargs["override_neuron_config"]) - model_group.add_argument("--override-pooler-config", - **model_kwargs["override_pooler_config"]) - model_group.add_argument("--logits-processor-pattern", - **model_kwargs["logits_processor_pattern"]) - model_group.add_argument("--generation-config", - **model_kwargs["generation_config"]) - model_group.add_argument("--override-generation-config", - **model_kwargs["override_generation_config"]) - model_group.add_argument("--enable-sleep-mode", - **model_kwargs["enable_sleep_mode"]) - model_group.add_argument("--model-impl", - choices=[f.value for f in ModelImpl], - **model_kwargs["model_impl"]) - model_group.add_argument("--override-attention-dtype", - **model_kwargs["override_attention_dtype"]) - model_group.add_argument("--logits-processors", - **model_kwargs["logits_processors"]) + model_group.add_argument( + "--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"], + ) + model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) + model_group.add_argument( + "--override-pooler-config", + **model_kwargs["override_pooler_config"], + deprecated=True, + ) + model_group.add_argument( + "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"] + ) + model_group.add_argument( + "--generation-config", **model_kwargs["generation_config"] + ) + model_group.add_argument( + "--override-generation-config", **model_kwargs["override_generation_config"] + ) + model_group.add_argument( + "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"] + ) + model_group.add_argument("--model-impl", **model_kwargs["model_impl"]) + model_group.add_argument( + "--override-attention-dtype", **model_kwargs["override_attention_dtype"] + ) + model_group.add_argument( + "--logits-processors", **model_kwargs["logits_processors"] + ) + model_group.add_argument( + "--io-processor-plugin", **model_kwargs["io_processor_plugin"] + ) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -575,39 +650,44 @@ class EngineArgs: description=LoadConfig.__doc__, ) load_group.add_argument("--load-format", **load_kwargs["load_format"]) - load_group.add_argument("--download-dir", - **load_kwargs["download_dir"]) - load_group.add_argument("--model-loader-extra-config", - **load_kwargs["model_loader_extra_config"]) - load_group.add_argument("--ignore-patterns", - **load_kwargs["ignore_patterns"]) - load_group.add_argument("--use-tqdm-on-load", - **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument('--pt-load-map-location', - **load_kwargs["pt_load_map_location"]) - - # Guided decoding arguments - guided_decoding_kwargs = get_kwargs(DecodingConfig) - guided_decoding_group = parser.add_argument_group( - title="DecodingConfig", - description=DecodingConfig.__doc__, + load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument( + "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"] ) - guided_decoding_group.add_argument("--guided-decoding-backend", - **guided_decoding_kwargs["backend"]) - 
guided_decoding_group.add_argument( - "--guided-decoding-disable-fallback", - **guided_decoding_kwargs["disable_fallback"]) - guided_decoding_group.add_argument( - "--guided-decoding-disable-any-whitespace", - **guided_decoding_kwargs["disable_any_whitespace"]) - guided_decoding_group.add_argument( - "--guided-decoding-disable-additional-properties", - **guided_decoding_kwargs["disable_additional_properties"]) - guided_decoding_group.add_argument( + load_group.add_argument( + "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"] + ) + load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--pt-load-map-location", **load_kwargs["pt_load_map_location"] + ) + + # Structured outputs arguments + structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) + structured_outputs_group = parser.add_argument_group( + title="StructuredOutputsConfig", + description=StructuredOutputsConfig.__doc__, + ) + structured_outputs_group.add_argument( "--reasoning-parser", # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), - **guided_decoding_kwargs["reasoning_backend"]) + **structured_outputs_kwargs["reasoning_parser"], + ) + # Deprecated guided decoding arguments + for arg, type in [ + ("--guided-decoding-backend", str), + ("--guided-decoding-disable-fallback", bool), + ("--guided-decoding-disable-any-whitespace", bool), + ("--guided-decoding-disable-additional-properties", bool), + ]: + structured_outputs_group.add_argument( + arg, + type=type, + help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."), + deprecated=True, + ) # Parallel arguments parallel_kwargs = get_kwargs(ParallelConfig) @@ -617,97 +697,132 @@ class EngineArgs: ) parallel_group.add_argument( "--distributed-executor-backend", - **parallel_kwargs["distributed_executor_backend"]) + **parallel_kwargs["distributed_executor_backend"], + ) parallel_group.add_argument( - "--pipeline-parallel-size", "-pp", - **parallel_kwargs["pipeline_parallel_size"]) - parallel_group.add_argument("--tensor-parallel-size", "-tp", - **parallel_kwargs["tensor_parallel_size"]) - parallel_group.add_argument("--data-parallel-size", "-dp", - **parallel_kwargs["data_parallel_size"]) + "--pipeline-parallel-size", + "-pp", + **parallel_kwargs["pipeline_parallel_size"], + ) parallel_group.add_argument( - '--data-parallel-rank', - '-dpn', + "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] + ) + parallel_group.add_argument( + "--decode-context-parallel-size", + "-dcp", + **parallel_kwargs["decode_context_parallel_size"], + ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) + parallel_group.add_argument( + "--data-parallel-rank", + "-dpn", type=int, - help='Data parallel rank of this instance. 
' - 'When set, enables external load balancer mode.') - parallel_group.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - help='Starting data parallel rank ' - 'for secondary nodes.') - parallel_group.add_argument('--data-parallel-size-local', - '-dpl', - type=int, - help='Number of data parallel replicas ' - 'to run on this node.') - parallel_group.add_argument('--data-parallel-address', - '-dpa', - type=str, - help='Address of data parallel cluster ' - 'head-node.') - parallel_group.add_argument('--data-parallel-rpc-port', - '-dpp', - type=int, - help='Port for data parallel RPC ' - 'communication.') - parallel_group.add_argument('--data-parallel-backend', - '-dpb', - type=str, - default='mp', - help='Backend for data parallel, either ' - '"mp" or "ray".') + help="Data parallel rank of this instance. " + "When set, enables external load balancer mode.", + ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", - **parallel_kwargs["data_parallel_hybrid_lb"]) + "--data-parallel-start-rank", + "-dpr", + type=int, + help="Starting data parallel rank for secondary nodes.", + ) parallel_group.add_argument( - "--enable-expert-parallel", - **parallel_kwargs["enable_expert_parallel"]) - parallel_group.add_argument("--enable-eplb", - **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--eplb-config", - **parallel_kwargs["eplb_config"]) + "--data-parallel-size-local", + "-dpl", + type=int, + help="Number of data parallel replicas to run on this node.", + ) + parallel_group.add_argument( + "--data-parallel-address", + "-dpa", + type=str, + help="Address of data parallel cluster head-node.", + ) + parallel_group.add_argument( + "--data-parallel-rpc-port", + "-dpp", + type=int, + help="Port for data parallel RPC communication.", + ) + parallel_group.add_argument( + "--data-parallel-backend", + "-dpb", + type=str, + default="mp", + help='Backend for data parallel, either "mp" or "ray".', + ) + parallel_group.add_argument( + "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + ) + parallel_group.add_argument( + "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] + ) + parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--dbo-decode-token-threshold", + **parallel_kwargs["dbo_decode_token_threshold"], + ) + parallel_group.add_argument( + "--dbo-prefill-token-threshold", + **parallel_kwargs["dbo_prefill_token_threshold"], + ) + parallel_group.add_argument( + "--disable-nccl-for-dp-synchronization", + **parallel_kwargs["disable_nccl_for_dp_synchronization"], + ) + parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--expert-placement-strategy", + **parallel_kwargs["expert_placement_strategy"], + ) parallel_group.add_argument( "--num-redundant-experts", type=int, - help= - "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-window-size", type=int, help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", - deprecated=True) + deprecated=True, + ) parallel_group.add_argument( "--eplb-step-interval", type=int, - help= - "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-step-interval will be removed 
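(Illustrative aside, not part of the diff.) A hedged sketch of how the data-parallel flags above compose for a two-node deployment; the model, addresses, ports, and rank layout are hypothetical values chosen to match the help text ("starting data parallel rank for secondary nodes", "replicas to run on this node").

# Head node: hosts the DP RPC endpoint and runs DP ranks 0-1 locally.
head = EngineArgs(
    model="facebook/opt-125m",
    data_parallel_size=4,
    data_parallel_size_local=2,
    data_parallel_address="10.0.0.1",
    data_parallel_rpc_port=13345,
)
# Secondary node: runs DP ranks 2-3 and points back at the head node.
secondary = EngineArgs(
    model="facebook/opt-125m",
    data_parallel_size=4,
    data_parallel_size_local=2,
    data_parallel_start_rank=2,
    data_parallel_address="10.0.0.1",
    data_parallel_rpc_port=13345,
)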
in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-log-balancedness", action=argparse.BooleanOptionalAction, - help= - "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--max-parallel-loading-workers", - **parallel_kwargs["max_parallel_loading_workers"]) + **parallel_kwargs["max_parallel_loading_workers"], + ) parallel_group.add_argument( - "--ray-workers-use-nsight", - **parallel_kwargs["ray_workers_use_nsight"]) + "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"] + ) parallel_group.add_argument( "--disable-custom-all-reduce", - **parallel_kwargs["disable_custom_all_reduce"]) - parallel_group.add_argument("--worker-cls", - **parallel_kwargs["worker_cls"]) - parallel_group.add_argument("--worker-extension-cls", - **parallel_kwargs["worker_extension_cls"]) + **parallel_kwargs["disable_custom_all_reduce"], + ) + parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) + parallel_group.add_argument( + "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] + ) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", action="store_true", - deprecated=True) + deprecated=True, + ) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -716,27 +831,36 @@ class EngineArgs: description=CacheConfig.__doc__, ) cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) - cache_group.add_argument("--gpu-memory-utilization", - **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument( + "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"] + ) + cache_group.add_argument( + "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"] + ) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) - cache_group.add_argument("--kv-cache-dtype", - **cache_kwargs["cache_dtype"]) - cache_group.add_argument("--num-gpu-blocks-override", - **cache_kwargs["num_gpu_blocks_override"]) - cache_group.add_argument("--enable-prefix-caching", - **cache_kwargs["enable_prefix_caching"]) - cache_group.add_argument("--prefix-caching-hash-algo", - **cache_kwargs["prefix_caching_hash_algo"]) - cache_group.add_argument("--cpu-offload-gb", - **cache_kwargs["cpu_offload_gb"]) - cache_group.add_argument("--calculate-kv-scales", - **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument("--kv-sharing-fast-prefill", - **cache_kwargs["kv_sharing_fast_prefill"]) - cache_group.add_argument("--mamba-cache-dtype", - **cache_kwargs["mamba_cache_dtype"]) - cache_group.add_argument("--mamba-ssm-cache-dtype", - **cache_kwargs["mamba_ssm_cache_dtype"]) + cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) + cache_group.add_argument( + "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] + ) + cache_group.add_argument( + "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"] + ) + cache_group.add_argument( + "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"] + ) + cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument( + "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"] + ) + cache_group.add_argument( + "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"] + ) + cache_group.add_argument( + "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"] + ) + 
cache_group.add_argument( + "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -744,26 +868,41 @@ class EngineArgs: title="MultiModalConfig", description=MultiModalConfig.__doc__, ) - multimodal_group.add_argument("--limit-mm-per-prompt", - **multimodal_kwargs["limit_per_prompt"]) - multimodal_group.add_argument("--media-io-kwargs", - **multimodal_kwargs["media_io_kwargs"]) multimodal_group.add_argument( - "--mm-processor-kwargs", - **multimodal_kwargs["mm_processor_kwargs"]) + "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] + ) multimodal_group.add_argument( - "--mm-processor-cache-gb", - **multimodal_kwargs["mm_processor_cache_gb"]) - multimodal_group.add_argument("--disable-mm-preprocessor-cache", - action="store_true", - deprecated=True) + "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] + ) multimodal_group.add_argument( - "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) + "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"] + ) multimodal_group.add_argument( - "--interleave-mm-strings", - **multimodal_kwargs["interleave_mm_strings"]) - multimodal_group.add_argument("--skip-mm-profiling", - **multimodal_kwargs["skip_mm_profiling"]) + "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"] + ) + multimodal_group.add_argument( + "--disable-mm-preprocessor-cache", action="store_true", deprecated=True + ) + multimodal_group.add_argument( + "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"] + ) + multimodal_group.add_argument( + "--mm-shm-cache-max-object-size-mb", + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"], + ) + multimodal_group.add_argument( + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] + ) + multimodal_group.add_argument( + "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"] + ) + multimodal_group.add_argument( + "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"] + ) + + multimodal_group.add_argument( + "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -774,24 +913,22 @@ class EngineArgs: lora_group.add_argument( "--enable-lora", action=argparse.BooleanOptionalAction, - help="If True, enable handling of LoRA adapters.") - lora_group.add_argument("--enable-lora-bias", - **lora_kwargs["bias_enabled"]) + help="If True, enable handling of LoRA adapters.", + ) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) - lora_group.add_argument("--max-lora-rank", - **lora_kwargs["max_lora_rank"]) - lora_group.add_argument("--lora-extra-vocab-size", - **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) + lora_group.add_argument( + "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] + ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], ) - lora_group.add_argument("--max-cpu-loras", - **lora_kwargs["max_cpu_loras"]) - lora_group.add_argument("--fully-sharded-loras", - **lora_kwargs["fully_sharded_loras"]) - lora_group.add_argument("--default-mm-loras", - **lora_kwargs["default_mm_loras"]) + lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument( + "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] + ) + lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) # 
Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) @@ -801,21 +938,22 @@ class EngineArgs: ) observability_group.add_argument( "--show-hidden-metrics-for-version", - **observability_kwargs["show_hidden_metrics_for_version"]) + **observability_kwargs["show_hidden_metrics_for_version"], + ) observability_group.add_argument( - "--otlp-traces-endpoint", - **observability_kwargs["otlp_traces_endpoint"]) + "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"] + ) # TODO: generalise this special case choices = observability_kwargs["collect_detailed_traces"]["choices"] metavar = f"{{{','.join(choices)}}}" observability_kwargs["collect_detailed_traces"]["metavar"] = metavar observability_kwargs["collect_detailed_traces"]["choices"] += [ - ",".join(p) - for p in permutations(get_args(DetailedTraceModules), r=2) + ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2) ] observability_group.add_argument( "--collect-detailed-traces", - **observability_kwargs["collect_detailed_traces"]) + **observability_kwargs["collect_detailed_traces"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -824,44 +962,49 @@ class EngineArgs: description=SchedulerConfig.__doc__, ) scheduler_group.add_argument( - "--max-num-batched-tokens", - **scheduler_kwargs["max_num_batched_tokens"]) - scheduler_group.add_argument("--max-num-seqs", - **scheduler_kwargs["max_num_seqs"]) + "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"] + ) scheduler_group.add_argument( - "--max-num-partial-prefills", - **scheduler_kwargs["max_num_partial_prefills"]) + "--max-num-seqs", **scheduler_kwargs["max_num_seqs"] + ) + scheduler_group.add_argument( + "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"] + ) scheduler_group.add_argument( "--max-long-partial-prefills", - **scheduler_kwargs["max_long_partial_prefills"]) - scheduler_group.add_argument('--cuda-graph-sizes', - **scheduler_kwargs["cuda_graph_sizes"]) + **scheduler_kwargs["max_long_partial_prefills"], + ) + scheduler_group.add_argument( + "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"] + ) scheduler_group.add_argument( "--long-prefill-token-threshold", - **scheduler_kwargs["long_prefill_token_threshold"]) - scheduler_group.add_argument("--num-lookahead-slots", - **scheduler_kwargs["num_lookahead_slots"]) - scheduler_group.add_argument("--scheduler-delay-factor", - **scheduler_kwargs["delay_factor"]) - scheduler_group.add_argument("--preemption-mode", - **scheduler_kwargs["preemption_mode"]) + **scheduler_kwargs["long_prefill_token_threshold"], + ) + scheduler_group.add_argument( + "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"] + ) # multi-step scheduling has been removed; corresponding arguments # are no longer supported. 
- scheduler_group.add_argument("--scheduling-policy", - **scheduler_kwargs["policy"]) scheduler_group.add_argument( - "--enable-chunked-prefill", - **scheduler_kwargs["enable_chunked_prefill"]) + "--scheduling-policy", **scheduler_kwargs["policy"] + ) scheduler_group.add_argument( - "--disable-chunked-mm-input", - **scheduler_kwargs["disable_chunked_mm_input"]) - scheduler_group.add_argument("--scheduler-cls", - **scheduler_kwargs["scheduler_cls"]) + "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"] + ) + scheduler_group.add_argument( + "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"] + ) + scheduler_group.add_argument( + "--scheduler-cls", **scheduler_kwargs["scheduler_cls"] + ) scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", - **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) - scheduler_group.add_argument("--async-scheduling", - **scheduler_kwargs["async_scheduling"]) + **scheduler_kwargs["disable_hybrid_kv_cache_manager"], + ) + scheduler_group.add_argument( + "--async-scheduling", **scheduler_kwargs["async_scheduling"] + ) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -873,21 +1016,29 @@ class EngineArgs: # create_engine_config. So we set the type to a JSON string here to # delay the Pydantic validation that comes with SpeculativeConfig. vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) - vllm_group.add_argument("--speculative-config", - **vllm_kwargs["speculative_config"]) - vllm_group.add_argument("--kv-transfer-config", - **vllm_kwargs["kv_transfer_config"]) - vllm_group.add_argument('--kv-events-config', - **vllm_kwargs["kv_events_config"]) - vllm_group.add_argument("--compilation-config", "-O", - **vllm_kwargs["compilation_config"]) - vllm_group.add_argument("--additional-config", - **vllm_kwargs["additional_config"]) + vllm_group.add_argument( + "--speculative-config", **vllm_kwargs["speculative_config"] + ) + vllm_group.add_argument( + "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] + ) + vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument( + "--compilation-config", "-O", **vllm_kwargs["compilation_config"] + ) + vllm_group.add_argument( + "--additional-config", **vllm_kwargs["additional_config"] + ) + vllm_group.add_argument( + "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] + ) # Other arguments - parser.add_argument('--disable-log-stats', - action='store_true', - help='Disable logging statistics.') + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable logging statistics.", + ) return parser @@ -896,7 +1047,9 @@ class EngineArgs: # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
- engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + engine_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return engine_args def create_model_config(self) -> ModelConfig: @@ -905,16 +1058,20 @@ class EngineArgs: self.quantization = self.load_format = "gguf" # NOTE: This is to allow model loading from S3 in CI - if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 - and self.model in MODELS_ON_S3 and self.load_format == "auto"): + if ( + not isinstance(self, AsyncEngineArgs) + and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == "auto" + ): self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" - self.load_format = "runai_streamer" if self.disable_mm_preprocessor_cache: logger.warning( "`--disable-mm-preprocessor-cache` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb 0` instead.", ) + "Please use `--mm-processor-cache-gb 0` instead.", + ) self.mm_processor_cache_gb = 0 elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: @@ -931,7 +1088,8 @@ class EngineArgs: logger.warning( "--enable-multimodal-encoder-data-parallel` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-encoder-tp-mode data` instead.") + "Please use `--mm-encoder-tp-mode data` instead." + ) self.mm_encoder_tp_mode = "data" @@ -945,6 +1103,7 @@ class EngineArgs: tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, allowed_local_media_path=self.allowed_local_media_path, + allowed_media_domains=self.allowed_media_domains, dtype=self.dtype, seed=self.seed, revision=self.revision, @@ -957,7 +1116,6 @@ class EngineArgs: max_model_len=self.max_model_len, quantization=self.quantization, enforce_eager=self.enforce_eager, - max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, logprobs_mode=self.logprobs_mode, disable_sliding_window=self.disable_sliding_window, @@ -969,12 +1127,13 @@ class EngineArgs: interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, skip_mm_profiling=self.skip_mm_profiling, - use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, - override_neuron_config=self.override_neuron_config, + pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config, @@ -983,35 +1142,39 @@ class EngineArgs: model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, + video_pruning_rate=self.video_pruning_rate, + io_processor_plugin=self.io_processor_plugin, ) def validate_tensorizer_args(self): - from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig) + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + for key in self.model_loader_extra_config: if key in TensorizerConfig._fields: - self.model_loader_extra_config["tensorizer_config"][ - key] = self.model_loader_extra_config[key] + self.model_loader_extra_config["tensorizer_config"][key] = ( + self.model_loader_extra_config[key] + ) def create_load_config(self) -> LoadConfig: - if self.quantization == "bitsandbytes": 
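(Illustrative aside, not part of the diff.) The typical round trip through add_cli_args/from_cli_args shown above; the model name is hypothetical and all flags used are ones registered in this file.

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--max-model-len", "8192",
    "--tensor-parallel-size", "2",
])
engine_args = EngineArgs.from_cli_args(args)
# usage_context is assumed to have a sensible default here.
vllm_config = engine_args.create_engine_config()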
self.load_format = "bitsandbytes" if self.load_format == "tensorizer": if hasattr(self.model_loader_extra_config, "to_serializable"): self.model_loader_extra_config = ( - self.model_loader_extra_config.to_serializable()) + self.model_loader_extra_config.to_serializable() + ) self.model_loader_extra_config["tensorizer_config"] = {} - self.model_loader_extra_config["tensorizer_config"][ - "tensorizer_dir"] = self.model + self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = ( + self.model + ) self.validate_tensorizer_args() return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, - device="cpu" - if is_online_quantization(self.quantization) else None, + safetensors_load_strategy=self.safetensors_load_strategy, + device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1033,38 +1196,20 @@ class EngineArgs: provided as a JSON string input via CLI arguments or directly as a dictionary from the engine. """ - - from vllm.transformers_utils.config import get_config - from vllm.transformers_utils.configs.speculators.base import ( - SpeculatorsConfig) - if self.speculative_config is None: - hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, self.revision, - self.code_revision, self.config_format) - - # if loading a SpeculatorsConfig, load the specualtive_config - # details from the config directly - # no user input required / expected - if isinstance(hf_config, SpeculatorsConfig): - # We create one since we don't create one - self.speculative_config = {} - self.speculative_config[ - "num_speculative_tokens"] = hf_config.num_lookahead_tokens - self.speculative_config["model"] = self.model - self.speculative_config["method"] = hf_config.method - else: - return None + return None # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. - self.speculative_config.update({ - "target_model_config": target_model_config, - "target_parallel_config": target_parallel_config, - "enable_chunked_prefill": enable_chunked_prefill, - "disable_log_stats": disable_log_stats, - }) + self.speculative_config.update( + { + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + } + ) return SpeculativeConfig(**self.speculative_config) def create_engine_config( @@ -1087,9 +1232,21 @@ class EngineArgs: """ current_platform.pre_register_and_update() - device_config = DeviceConfig( - device=cast(Device, current_platform.device_type)) + device_config = DeviceConfig(device=cast(Device, current_platform.device_type)) + model_config = self.create_model_config() + self.model = model_config.model + self.tokenizer = model_config.tokenizer + + (self.model, self.tokenizer, self.speculative_config) = ( + maybe_override_with_speculators( + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + vllm_speculative_config=self.speculative_config, + ) + ) # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" # and fall back to V0 for experimental or unsupported features. @@ -1108,33 +1265,23 @@ class EngineArgs: else: envs.set_vllm_use_v1(use_v1) - # Set default arguments for V0 or V1 Engine. 
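(Illustrative aside, not part of the diff.) As the create_speculative_config docstring above says, --speculative-config accepts a JSON string on the CLI or a plain dict in-process. The "method", "model", and "num_speculative_tokens" keys are ones the surrounding code itself reads or writes; the chosen values, and "ngram" as a method, are illustrative assumptions.

args = EngineArgs(
    model="facebook/opt-125m",                  # hypothetical model
    speculative_config={
        "method": "ngram",                      # assumed supported method
        "num_speculative_tokens": 3,
    },
)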
- if use_v1: - self._set_default_args_v1(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM/s390x CPUs in V1 - if current_platform.is_cpu( - ) and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM): - logger.info( - "Chunked prefill is not supported for ARM and POWER " - "and S390X CPUs; " - "disabling it for V1 backend.") - self.enable_chunked_prefill = False - else: - self._set_default_args_v0(model_config) + # Set default arguments for V1 Engine. + self._set_default_args(usage_context, model_config) + # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 + if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, + CpuArchEnum.S390X, + CpuArchEnum.ARM, + CpuArchEnum.RISCV, + ): + logger.info( + "Chunked prefill is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) + self.enable_chunked_prefill = False assert self.enable_chunked_prefill is not None - if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: - assert self.enforce_eager, ( - "Cuda graph is not supported with DualChunkFlashAttention. " - "To run the model in eager mode, set 'enforce_eager=True' " - "or use '--enforce-eager' in the CLI.") - assert current_platform.is_cuda(), ( - "DualChunkFlashAttention is only supported on CUDA platform.") - assert not use_v1, ( - "DualChunkFlashAttention is not supported on V1 engine. " - "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") - sliding_window: Optional[int] = None if not is_interleaved(model_config.hf_text_config): # Only set CacheConfig.sliding_window if the model is all sliding @@ -1142,9 +1289,20 @@ class EngineArgs: # global layers in interleaved sliding window models. sliding_window = model_config.get_sliding_window() + # Note(hc): In the current implementation of decode context + # parallel(DCP), tp_size needs to be divisible by dcp_size, + # because the world size does not change by dcp, it simply + # reuses the GPUs of TP group, and split one TP group into + # tp_size//dcp_size DCP groups. + assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, ( + f"tp_size={self.tensor_parallel_size} must be divisible by" + f"dcp_size={self.decode_context_parallel_size}." + ) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, + kv_cache_memory_bytes=self.kv_cache_memory_bytes, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, @@ -1165,8 +1323,15 @@ class EngineArgs: # of a Ray task, therefore we check is_ray_initialized() # as opposed to is_in_ray_actor(). import ray + ray_runtime_env = ray.get_runtime_context().runtime_env - logger.info("Using ray runtime env: %s", ray_runtime_env) + # Avoid logging sensitive environment variables + sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {} + if "env_vars" in sanitized_env: + sanitized_env["env_vars"] = { + k: "***" for k in sanitized_env["env_vars"] + } + logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env) # Get the current placement group if Ray is initialized and # we are in a Ray actor. 
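# Illustrative, self-contained sketch of the redaction step above: values
# under "env_vars" in a Ray runtime env are masked before logging so secrets
# never reach the logs. The dict literal is a stand-in for the real runtime
# env object.
def sanitize_runtime_env(runtime_env: dict) -> dict:
    sanitized = dict(runtime_env)
    if "env_vars" in sanitized:
        sanitized["env_vars"] = {k: "***" for k in sanitized["env_vars"]}
    return sanitized


env = {"pip": ["numpy"], "env_vars": {"HF_TOKEN": "hf_secret", "HOME": "/root"}}
print(sanitize_runtime_env(env))
# {'pip': ['numpy'], 'env_vars': {'HF_TOKEN': '***', 'HOME': '***'}}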
If so, then the placement group will be @@ -1180,15 +1345,15 @@ class EngineArgs: placement_group = ray.util.get_current_placement_group() assert not headless or not self.data_parallel_hybrid_lb, ( - "data_parallel_hybrid_lb is not applicable in " - "headless mode") + "data_parallel_hybrid_lb is not applicable in headless mode" + ) data_parallel_external_lb = self.data_parallel_rank is not None # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank " - "is set") + "data_parallel_size_local must be 1 when data_parallel_rank is set" + ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. self.data_parallel_hybrid_lb = False @@ -1211,8 +1376,8 @@ class EngineArgs: self.data_parallel_rank = self.data_parallel_start_rank or 0 else: assert not self.data_parallel_hybrid_lb, ( - "data_parallel_size_local must be set to use " - "data_parallel_hybrid_lb.") + "data_parallel_size_local must be set to use data_parallel_hybrid_lb." + ) # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1223,42 +1388,46 @@ class EngineArgs: if self.data_parallel_backend == "ray": host_ip = get_ip() logger.info( - "Using host IP %s as ray-based data parallel address", - host_ip) + "Using host IP %s as ray-based data parallel address", host_ip + ) data_parallel_address = host_ip else: assert self.data_parallel_backend == "mp", ( "data_parallel_backend can only be ray or mp, got %s", - self.data_parallel_backend) + self.data_parallel_backend, + ) data_parallel_address = ParallelConfig.data_parallel_master_ip else: data_parallel_address = self.data_parallel_address # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. - data_parallel_rpc_port = self.data_parallel_rpc_port if ( + data_parallel_rpc_port = ( self.data_parallel_rpc_port - is not None) else ParallelConfig.data_parallel_rpc_port + if (self.data_parallel_rpc_port is not None) + else ParallelConfig.data_parallel_rpc_port + ) if self.async_scheduling: # Async scheduling does not work with the uniprocess backend. if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Using mp-based distributed executor backend " - "for async scheduling.") - if self.distributed_executor_backend == "uni": - raise ValueError("Async scheduling is not supported with " - "uni-process backend.") + logger.info( + "Defaulting to mp-based distributed executor " + "backend for async scheduling." + ) if self.pipeline_parallel_size > 1: - raise ValueError("Async scheduling is not supported with " - "pipeline-parallel-size > 1.") + raise ValueError( + "Async scheduling is not supported with pipeline-parallel-size > 1." + ) # Currently, async scheduling does not support speculative decoding. # TODO(woosuk): Support it. if self.speculative_config is not None: raise ValueError( "Currently, speculative decoding is not supported with " - "async scheduling.") + "async scheduling." + ) # Forward the deprecated CLI args to the EPLB config. 
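# Illustrative sketch of the async-scheduling guardrails enforced above:
# default the distributed executor backend to "mp" and reject configurations
# that async scheduling cannot serve yet. The helper and its parameters are
# simplified stand-ins for the real EngineArgs fields.
from typing import Optional


def resolve_async_scheduling(
    distributed_executor_backend: Optional[str],
    pipeline_parallel_size: int,
    speculative_config: Optional[dict],
) -> str:
    if distributed_executor_backend is None:
        distributed_executor_backend = "mp"  # uniprocess backend is not usable here
    if pipeline_parallel_size > 1:
        raise ValueError(
            "Async scheduling is not supported with pipeline-parallel-size > 1."
        )
    if speculative_config is not None:
        raise ValueError(
            "Currently, speculative decoding is not supported with async scheduling."
        )
    return distributed_executor_backend


print(resolve_async_scheduling(None, 1, None))  # -> "mp"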
if self.num_redundant_experts is not None: @@ -1282,8 +1451,13 @@ class EngineArgs: data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + enable_dbo=self.enable_dbo, + dbo_decode_token_threshold=self.dbo_decode_token_threshold, + dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, + disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, + expert_placement_strategy=self.expert_placement_strategy, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, @@ -1292,20 +1466,11 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, + decode_context_parallel_size=self.decode_context_parallel_size, + _api_process_count=self._api_process_count, + _api_process_rank=self._api_process_rank, ) - if model_config.is_multimodal_model: - dp_supports_mm_processor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_processor_cache - and model_config.mm_processor_cache_gb > 0): - logger.warning( - "Multi-modal processor cache is disabled because " - "it is not compatible with data parallelism when " - "there does not exist a one-to-one correspondance " - "between API and engine core processes.") - model_config.set_mm_processor_cache_gb(0) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -1326,38 +1491,41 @@ class EngineArgs: max_model_len=model_config.max_model_len, cuda_graph_sizes=self.cuda_graph_sizes, num_lookahead_slots=num_lookahead_slots, - delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, - preemption_mode=self.preemption_mode, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER - and parallel_config.use_ray), + is_encoder_decoder=model_config.is_encoder_decoder, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, - disable_hybrid_kv_cache_manager=self. 
- disable_hybrid_kv_cache_manager, + disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, ) if not model_config.is_multimodal_model and self.default_mm_loras: raise ValueError( "Default modality-specific LoRA(s) were provided for a " - "non multimodal model") + "non multimodal model" + ) - lora_config = LoRAConfig( - bias_enabled=self.enable_lora_bias, - max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - default_mm_loras=self.default_mm_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras - and self.max_cpu_loras > 0 else None) if self.enable_lora else None + lora_config = ( + LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + default_mm_loras=self.default_mm_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras and self.max_cpu_loras > 0 + else None, + ) + if self.enable_lora + else None + ) # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": @@ -1365,18 +1533,29 @@ class EngineArgs: load_config = self.create_load_config() - decoding_config = DecodingConfig( - backend=self.guided_decoding_backend, - disable_fallback=self.guided_decoding_disable_fallback, - disable_any_whitespace=self.guided_decoding_disable_any_whitespace, - disable_additional_properties=\ - self.guided_decoding_disable_additional_properties, - reasoning_backend=self.reasoning_parser - ) + # Pass reasoning_parser into StructuredOutputsConfig + if self.reasoning_parser: + self.structured_outputs_config.reasoning_parser = self.reasoning_parser + + # Forward the deprecated CLI args to the StructuredOutputsConfig + so_config = self.structured_outputs_config + if self.guided_decoding_backend is not None: + so_config.guided_decoding_backend = self.guided_decoding_backend + if self.guided_decoding_disable_fallback is not None: + so_config.guided_decoding_disable_fallback = ( + self.guided_decoding_disable_fallback + ) + if self.guided_decoding_disable_any_whitespace is not None: + so_config.guided_decoding_disable_any_whitespace = ( + self.guided_decoding_disable_any_whitespace + ) + if self.guided_decoding_disable_additional_properties is not None: + so_config.guided_decoding_disable_additional_properties = ( + self.guided_decoding_disable_additional_properties + ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=( - self.show_hidden_metrics_for_version), + show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version), otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) @@ -1390,7 +1569,7 @@ class EngineArgs: lora_config=lora_config, speculative_config=speculative_config, load_config=load_config, - decoding_config=decoding_config, + structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, @@ -1406,218 +1585,105 @@ class EngineArgs: ############################################################# # Unsupported Feature Flags on V1. 
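# Illustrative sketch of the "forward deprecated flags" pattern above: a
# deprecated top-level option is copied into the nested structured-outputs
# config only when the user explicitly set it (None means "not set").
# StructuredOutputsSketch and its field names are hypothetical.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StructuredOutputsSketch:
    backend: str = "auto"
    disable_fallback: bool = False


def forward_deprecated_flags(
    cfg: StructuredOutputsSketch,
    guided_decoding_backend: Optional[str],
    guided_decoding_disable_fallback: Optional[bool],
) -> StructuredOutputsSketch:
    if guided_decoding_backend is not None:
        cfg.backend = guided_decoding_backend
    if guided_decoding_disable_fallback is not None:
        cfg.disable_fallback = guided_decoding_disable_fallback
    return cfg


print(forward_deprecated_flags(StructuredOutputsSketch(), "xgrammar", None))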
- if self.load_format == "sharded_state": + if self.logits_processor_pattern != EngineArgs.logits_processor_pattern: _raise_or_fallback( - feature_name=f"--load_format {self.load_format}", - recommend_to_remove=False) - return False - - if (self.logits_processor_pattern - != EngineArgs.logits_processor_pattern): - _raise_or_fallback(feature_name="--logits-processor-pattern", - recommend_to_remove=False) - return False - - if self.preemption_mode != SchedulerConfig.preemption_mode: - _raise_or_fallback(feature_name="--preemption-mode", - recommend_to_remove=True) - return False - - if (self.disable_async_output_proc - != EngineArgs.disable_async_output_proc): - _raise_or_fallback(feature_name="--disable-async-output-proc", - recommend_to_remove=True) - return False - - if self.scheduler_delay_factor != SchedulerConfig.delay_factor: - _raise_or_fallback(feature_name="--scheduler-delay-factor", - recommend_to_remove=True) - return False - - # Need at least Ampere for now (FA support required). - # Skip this check if we are running on a non-GPU platform, - # or if the device capability is not available - # (e.g. in a Ray actor without GPUs). - if (current_platform.is_cuda() - and current_platform.get_device_capability() - and current_platform.get_device_capability().major < 8): - _raise_or_fallback(feature_name="Compute Capability < 8.0", - recommend_to_remove=False) - return False - - if self.kv_cache_dtype != "auto": - supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype, model_config) - if not supported: - _raise_or_fallback(feature_name="--kv-cache-dtype", - recommend_to_remove=False) - return False - - # No text embedding inputs so far. - if self.enable_prompt_embeds: - _raise_or_fallback(feature_name="--enable-prompt-embeds", - recommend_to_remove=False) + feature_name="--logits-processor-pattern", recommend_to_remove=False + ) return False # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: - _raise_or_fallback(feature_name=model_config.architectures, - recommend_to_remove=False) - return False - - # V1 mamba models are unoptimized. - if model_config.has_inner_state and _warn_or_fallback( - feature_name="Mamba"): + _raise_or_fallback( + feature_name=model_config.architectures, recommend_to_remove=False + ) return False # No Concurrent Partial Prefills so far. - if (self.max_num_partial_prefills - != SchedulerConfig.max_num_partial_prefills - or self.max_long_partial_prefills - != SchedulerConfig.max_long_partial_prefills): - _raise_or_fallback(feature_name="Concurrent Partial Prefill", - recommend_to_remove=False) - return False - - # No OTLP observability so far. - if (self.otlp_traces_endpoint or self.collect_detailed_traces): - _raise_or_fallback(feature_name="--otlp-traces-endpoint", - recommend_to_remove=False) + if ( + self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills + ): + _raise_or_fallback( + feature_name="Concurrent Partial Prefill", recommend_to_remove=False + ) return False # V1 supports N-gram, Medusa, and Eagle speculative decoding. - if (self.speculative_config is not None - and self.speculative_config.get("method") == "draft_model"): - raise NotImplementedError( - "Speculative decoding with draft model is not supported yet. 
" - "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or deepseek_mtp.") + if self.speculative_config is not None: + # speculative_config could still be a dict at this point + if isinstance(self.speculative_config, dict): + method = self.speculative_config.get("method", None) + else: + method = self.speculative_config.method + + if method == "draft_model": + raise NotImplementedError( + "Draft model speculative decoding is not supported yet. " + "Please consider using other speculative decoding methods " + "such as ngram, medusa, eagle, or mtp." + ) V1_BACKENDS = [ - "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", - "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", + "TRITON_ATTN", "TRITON_MLA", "CUTLASS_MLA", "FLASHMLA", + "FLASH_ATTN_MLA", "FLASHINFER", - "FLASHINFER_VLLM_V1", + "FLASHINFER_MLA", "ROCM_AITER_MLA", - "TORCH_SDPA_VLLM_V1", + "TORCH_SDPA", "FLEX_ATTENTION", "TREE_ATTN", - "XFORMERS_VLLM_V1", + "XFORMERS", + "ROCM_ATTN", + "ROCM_AITER_UNIFIED_ATTN", ] - if (envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS + ): name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False - # Platforms must decide if they can support v1 for this model - if not current_platform.supports_v1(model_config=model_config): - _raise_or_fallback( - feature_name=f"device type={current_platform.device_type}", - recommend_to_remove=False) - return False ############################################################# # Experimental Features - allow users to opt in. if self.pipeline_parallel_size > 1: - supports_pp = getattr(self.distributed_executor_backend, - 'supports_pp', False) + supports_pp = getattr( + self.distributed_executor_backend, "supports_pp", False + ) if not supports_pp and self.distributed_executor_backend not in ( - ParallelConfig.distributed_executor_backend, "ray", "mp", - "external_launcher"): - name = "Pipeline Parallelism without Ray distributed " \ - "executor or multiprocessing executor or external " \ - "launcher" - _raise_or_fallback(feature_name=name, - recommend_to_remove=False) + ParallelConfig.distributed_executor_backend, + "ray", + "mp", + "external_launcher", + ): + name = ( + "Pipeline Parallelism without Ray distributed " + "executor or multiprocessing executor or external " + "launcher" + ) + _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - # The platform may be supported on V1, but off by default for now. 
- if not current_platform.default_v1( # noqa: SIM103 - model_config=model_config) and _warn_or_fallback( - current_platform.device_name): - return False - - if (current_platform.is_cpu() - and model_config.get_sliding_window() is not None): - _raise_or_fallback(feature_name="sliding window (CPU backend)", - recommend_to_remove=False) + if current_platform.is_cpu() and model_config.get_sliding_window() is not None: + _raise_or_fallback( + feature_name="sliding window (CPU backend)", recommend_to_remove=False + ) return False ############################################################# return True - def _set_default_args_v0(self, model_config: ModelConfig) -> None: - """Set Default Arguments for V0 Engine.""" - - max_model_len = model_config.max_model_len - use_long_context = max_model_len > 32768 - if self.enable_chunked_prefill is None: - # Chunked prefill not supported for Multimodal or MLA in V0. - if model_config.is_multimodal_model or model_config.use_mla: - self.enable_chunked_prefill = False - - # Enable chunked prefill by default for long context (> 32K) - # models to avoid OOM errors in initial memory profiling phase. - elif use_long_context: - is_gpu = current_platform.is_cuda() - use_sliding_window = (model_config.get_sliding_window() - is not None) - use_spec_decode = self.speculative_config is not None - - if (is_gpu and not use_sliding_window and not use_spec_decode - and not self.enable_lora - and model_config.runner_type != "pooling"): - self.enable_chunked_prefill = True - logger.warning( - "Chunked prefill is enabled by default for models " - "with max_model_len > 32K. Chunked prefill might " - "not work with some features or models. If you " - "encounter any issues, please disable by launching " - "with --enable-chunked-prefill=False.") - - if self.enable_chunked_prefill is None: - self.enable_chunked_prefill = False - - if not self.enable_chunked_prefill and use_long_context: - logger.warning( - "The model has a long context length (%s). This may cause" - "OOM during the initial memory profiling phase, or result " - "in low performance due to small KV cache size. Consider " - "setting --max-model-len to a smaller value.", max_model_len) - elif (self.enable_chunked_prefill - and model_config.runner_type == "pooling"): - msg = "Chunked prefill is not supported for pooling models" - raise ValueError(msg) - - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching: - # Disable prefix caching for multimodal models for VLLM_V0. - if model_config.is_multimodal_model: - logger.warning( - "--enable-prefix-caching is not supported for multimodal " - "models in V0 and has been disabled.") - self.enable_prefix_caching = False - - # VLLM_V0 only supports builtin hash algo for prefix caching. - if self.prefix_caching_hash_algo == "sha256": - raise ValueError( - "sha256 is not supported for prefix caching in V0 engine. " - "Please use 'builtin'.") - - # Set max_num_seqs to 256 for VLLM_V0. 
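# Illustrative sketch of the V0 default being deleted above: chunked prefill
# used to be switched on automatically only for long-context (> 32K) GPU
# models without sliding windows, speculative decoding, or LoRA. A simplified
# stand-in for the removed heuristic, not the vLLM code verbatim.
def v0_default_chunked_prefill(
    max_model_len: int,
    is_gpu: bool,
    uses_sliding_window: bool,
    uses_spec_decode: bool,
    enable_lora: bool,
) -> bool:
    use_long_context = max_model_len > 32768
    return (
        use_long_context
        and is_gpu
        and not uses_sliding_window
        and not uses_spec_decode
        and not enable_lora
    )


print(v0_default_chunked_prefill(131072, True, False, False, False))  # True
print(v0_default_chunked_prefill(4096, True, False, False, False))    # False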
- if self.max_num_seqs is None: - self.max_num_seqs = 256 - - def _set_default_args_v1(self, usage_context: UsageContext, - model_config: ModelConfig) -> None: + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: """Set Default Arguments for V1 Engine.""" # V1 always uses chunked prefills and prefix caching @@ -1625,18 +1691,34 @@ class EngineArgs: # For pooling tasks the default is False if model_config.runner_type != "pooling": self.enable_chunked_prefill = True - if self.enable_prefix_caching is None: - self.enable_prefix_caching = True - else: + # TODO: When prefix caching supports prompt embeds inputs, this + # check can be removed. + if self.enable_prompt_embeds and self.enable_prefix_caching is not False: + logger.warning( + "--enable-prompt-embeds and --enable-prefix-caching " + "are not supported together in V1. Prefix caching has " + "been disabled." + ) + self.enable_prefix_caching = False + + if self.enable_prefix_caching is None: + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + if model_config.is_hybrid: + self.enable_prefix_caching = False + else: + self.enable_prefix_caching = True + else: pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) - incremental_prefill_supported = (pooling_type is not None - and pooling_type.lower() == "last" - and is_causal) + incremental_prefill_supported = ( + pooling_type is not None + and pooling_type.lower() == "last" + and is_causal + ) - action = "Enabling" if \ - incremental_prefill_supported else "Disabling" + action = "Enabling" if incremental_prefill_supported else "Disabling" if self.enable_chunked_prefill is None: self.enable_chunked_prefill = incremental_prefill_supported @@ -1670,6 +1752,7 @@ class EngineArgs: # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. from vllm.usage.usage_lib import UsageContext + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { @@ -1695,15 +1778,15 @@ class EngineArgs: if current_platform.is_tpu(): default_max_num_batched_tokens_tpu = { UsageContext.LLM_CLASS: { - 'V6E': 2048, - 'V5E': 1024, - 'V5P': 512, + "V6E": 2048, + "V5E": 1024, + "V5P": 512, }, UsageContext.OPENAI_API_SERVER: { - 'V6E': 1024, - 'V5E': 512, - 'V5P': 256, - } + "V6E": 1024, + "V5E": 512, + "V5P": 256, + }, } # cpu specific default values. 
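# Illustrative sketch of the pooling-model default above: chunked prefill (and
# prefix caching) is only enabled when incremental prefill is safe, i.e. the
# pooler uses last-token pooling on a causal model. Simplified stand-in for
# the check in _set_default_args.
from typing import Optional


def incremental_prefill_supported(pooling_type: Optional[str], is_causal: bool) -> bool:
    return pooling_type is not None and pooling_type.lower() == "last" and is_causal


for pooling_type, is_causal in [("LAST", True), ("MEAN", True), ("LAST", False)]:
    action = "Enabling" if incremental_prefill_supported(pooling_type, is_causal) else "Disabling"
    print(f"{action} chunked prefill for pooling_type={pooling_type}, is_causal={is_causal}")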
@@ -1719,47 +1802,58 @@ class EngineArgs: } use_context_value = usage_context.value if usage_context else None - if (self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens): + if ( + self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens + ): if current_platform.is_tpu(): chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[ - usage_context]: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens_tpu[ - usage_context][chip_name] + if chip_name in default_max_num_batched_tokens_tpu[usage_context]: + self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ + usage_context + ][chip_name] else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] else: if not self.enable_chunked_prefill: self.max_num_batched_tokens = model_config.max_model_len else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", - self.max_num_batched_tokens, use_context_value) + self.max_num_batched_tokens, + use_context_value, + ) - if (self.max_num_seqs is None - and usage_context in default_max_num_seqs): - self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize) + if self.max_num_seqs is None and usage_context in default_max_num_seqs: + self.max_num_seqs = min( + default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize, + ) - logger.debug("Setting max_num_seqs to %d for %s usage context.", - self.max_num_seqs, use_context_value) + logger.debug( + "Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, + use_context_value, + ) @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + enable_log_requests: bool = False @property @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self) -> bool: return not self.enable_log_requests @@ -1767,28 +1861,34 @@ class AsyncEngineArgs(EngineArgs): @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self, value: bool): self.enable_log_requests = not value @staticmethod - def add_cli_args(parser: FlexibleArgumentParser, - async_args_only: bool = False) -> FlexibleArgumentParser: + def add_cli_args( + parser: FlexibleArgumentParser, async_args_only: bool = False + ) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may # add a new kind of quantization method to --quantization argument or # a new device to --device argument. 
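# Illustrative sketch of how the TPU-specific defaults above are resolved:
# use the per-chip value for the current usage context when one exists,
# otherwise fall back to a generic table. The TPU values are copied from the
# hunk above; the generic values and the helper are placeholders, not vLLM's
# actual defaults.
from typing import Optional

DEFAULT_MAX_NUM_BATCHED_TOKENS = {"LLM_CLASS": 8192, "OPENAI_API_SERVER": 2048}  # placeholder values
DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU = {
    "LLM_CLASS": {"V6E": 2048, "V5E": 1024, "V5P": 512},
    "OPENAI_API_SERVER": {"V6E": 1024, "V5E": 512, "V5P": 256},
}


def default_batched_tokens(usage_context: str, tpu_chip: Optional[str]) -> int:
    if tpu_chip is not None:
        per_chip = DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU[usage_context]
        if tpu_chip in per_chip:
            return per_chip[tpu_chip]
    return DEFAULT_MAX_NUM_BATCHED_TOKENS[usage_context]


print(default_batched_tokens("OPENAI_API_SERVER", "V6E"))  # 1024
print(default_batched_tokens("LLM_CLASS", None))           # 8192 (placeholder)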
load_general_plugins() if not async_args_only: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--enable-log-requests', - action=argparse.BooleanOptionalAction, - default=AsyncEngineArgs.enable_log_requests, - help='Enable logging requests.') - parser.add_argument('--disable-log-requests', - action=argparse.BooleanOptionalAction, - default=not AsyncEngineArgs.enable_log_requests, - help='[DEPRECATED] Disable logging requests.', - deprecated=True) + parser.add_argument( + "--enable-log-requests", + action=argparse.BooleanOptionalAction, + default=AsyncEngineArgs.enable_log_requests, + help="Enable logging requests.", + ) + parser.add_argument( + "--disable-log-requests", + action=argparse.BooleanOptionalAction, + default=not AsyncEngineArgs.enable_log_requests, + help="[DEPRECATED] Disable logging requests.", + deprecated=True, + ) current_platform.pre_register_and_update(parser) return parser @@ -1796,7 +1896,8 @@ class AsyncEngineArgs(EngineArgs): def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: raise NotImplementedError( - f"VLLM_USE_V1=1 is not supported with {feature_name}.") + f"VLLM_USE_V1=1 is not supported with {feature_name}." + ) msg = f"{feature_name} is not supported by the V1 Engine. " msg += "Falling back to V0. " if recommend_to_remove: @@ -1805,21 +1906,6 @@ def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): logger.warning(msg) -def _warn_or_fallback(feature_name: str) -> bool: - if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: - logger.warning( - "Detected VLLM_USE_V1=1 with %s. Usage should " - "be considered experimental. Please report any " - "issues on Github.", feature_name) - should_exit = False - else: - logger.info( - "%s is experimental on VLLM_USE_V1=1. " - "Falling back to V0 Engine.", feature_name) - should_exit = True - return should_exit - - def human_readable_int(value): """Parse human-readable integers like '1k', '2M', etc. Including decimal values with decimal multipliers. @@ -1830,17 +1916,17 @@ def human_readable_int(value): - '25.6k' -> 25,600 """ value = value.strip() - match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value) if match: decimal_multiplier = { - 'k': 10**3, - 'm': 10**6, - 'g': 10**9, + "k": 10**3, + "m": 10**6, + "g": 10**9, } binary_multiplier = { - 'K': 2**10, - 'M': 2**20, - 'G': 2**30, + "K": 2**10, + "M": 2**20, + "G": 2**30, } number, suffix = match.groups() @@ -1853,9 +1939,11 @@ def human_readable_int(value): try: return int(number) * mult except ValueError as e: - raise argparse.ArgumentTypeError("Decimals are not allowed " \ - f"with binary suffixes like {suffix}. Did you mean to use " \ - f"{number}{suffix.lower()} instead?") from e + raise argparse.ArgumentTypeError( + "Decimals are not allowed " + f"with binary suffixes like {suffix}. Did you mean to use " + f"{number}{suffix.lower()} instead?" + ) from e # Regular plain number. 
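# Illustrative re-statement of the parsing rules implemented by
# human_readable_int above: lowercase suffixes are decimal multipliers,
# uppercase suffixes are binary, and decimal values are rejected for binary
# suffixes. Simplified (k/m/g only); the real helper is in the hunk above.
import re


def parse_human_int(value: str) -> int:
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgG])", value.strip())
    if not match:
        return int(value)
    number, suffix = match.groups()
    decimal = {"k": 10**3, "m": 10**6, "g": 10**9}
    binary = {"K": 2**10, "M": 2**20, "G": 2**30}
    if suffix in decimal:
        return int(float(number) * decimal[suffix])
    if "." in number:
        raise ValueError(f"Decimals are not allowed with binary suffix {suffix!r}.")
    return int(number) * binary[suffix]


print(parse_human_int("25.6k"))  # 25600
print(parse_human_int("1M"))     # 1048576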
return int(value) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 84ad2299b0655..ede027759a8b2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,1126 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import time -import weakref -from functools import partial -from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, - Mapping, Optional, Set, Tuple, Type, Union) -from weakref import ReferenceType +from vllm.v1.engine.async_llm import AsyncLLM -import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VllmConfig) -from vllm.core.scheduler import SchedulerOutputs -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine -from vllm.engine.metrics_types import StatLoggerBase -from vllm.engine.protocol import EngineClient -from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, deprecate_kwargs, weak_bind - -logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S - - -class AsyncEngineDeadError(RuntimeError): - pass - - -def _log_task_completion(task: asyncio.Task, - error_callback: Callable[[Exception], None]) -> None: - """This function is only intended for the `engine.run_engine_loop()` task. - - In particular, that task runs a `while True` loop that can only exit if - there is an exception. - """ - - exception = None - try: - return_value = task.result() - raise AssertionError( - f"The engine background task should never finish without an " - f"exception. {return_value}") - except asyncio.exceptions.CancelledError: - # We assume that if the task is cancelled, we are gracefully shutting - # down. This should only happen on program exit. - logger.info("Engine is gracefully shutting down.") - except Exception as e: - exception = e - logger.error("Engine background task failed", exc_info=e) - error_callback(exception) - raise AsyncEngineDeadError( - "Task finished unexpectedly. This should never happen! " - "Please open an issue on GitHub. 
See stack trace above for the " - "actual cause.") from e - - -STOP_ITERATION = Exception() # Sentinel - - -class AsyncStream: - """A stream of RequestOutputs or PoolingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" - - def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: - self.request_id = request_id - self._cancel = cancel - self._queue: asyncio.Queue = asyncio.Queue() - self._finished = False - - def put(self, item: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: - if not self._finished: - self._queue.put_nowait(item) - - def finish( - self, - exception: Optional[Union[BaseException, Type[BaseException]]] = None, - ) -> None: - if not self._finished: - self._finished = True - self._queue.put_nowait( - exception if self._is_raisable(exception) else STOP_ITERATION) - - @property - def finished(self) -> bool: - return self._finished - - async def generator( - self - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: - try: - while True: - result = await self._queue.get() - if self._is_raisable(result): - if result == STOP_ITERATION: - return - raise result - yield result - except GeneratorExit: - self._cancel(self.request_id) - raise asyncio.CancelledError from None - - @staticmethod - def _is_raisable(value: Any): - return isinstance(value, BaseException) or \ - (isinstance(value, type) and \ - issubclass(value, BaseException)) - - -class RequestTracker: - """Synchronous abstraction for tracking requests.""" - - def __init__(self) -> None: - self._request_streams: Dict[str, AsyncStream] = {} - self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, - dict]] = asyncio.Queue() - self.new_requests_event = asyncio.Event() - - def __contains__(self, item): - return item in self._request_streams - - def __len__(self) -> int: - return len(self._request_streams) - - def propagate_exception(self, - exc: Exception, - request_id: Optional[str] = None) -> None: - """Propagate an exception to request streams - (all if request_id is None).""" - if request_id is not None: - self.abort_request(request_id, exception=exc) - else: - # NB: tuple() used here because self.abort_request pops the stream - # out of self._request_streams, so we can't iterate on it directly - for rid in tuple(self._request_streams.keys()): - self.abort_request(rid, exception=exc) - - def process_request_output(self, - request_output: Union[RequestOutput, - PoolingRequestOutput], - *, - verbose: bool = False) -> None: - """Process a request output from the engine.""" - request_id = request_output.request_id - finished = request_output.finished - - if finished: - stream = self._request_streams.pop(request_id, None) - else: - stream = self._request_streams.get(request_id) - # Guard against a KeyError which can occur if the request was aborted - # while the output was generated - if stream is not None: - stream.put(request_output) - if finished: - stream.finish() - - if verbose and finished: - logger.info("Finished request %s.", request_id) - - def process_exception(self, - request_id: str, - exception: BaseException, - *, - verbose: bool = False) -> None: - """Propagate an exception from the engine.""" - if verbose: - logger.info("Finished request %s.", request_id) - self.abort_request(request_id, exception=exception) - - def add_request(self, - request_id: str, - *, - verbose: bool = False, - **engine_add_request_kwargs) -> AsyncStream: - """Add a request to be 
sent to the engine on the next background - loop iteration.""" - if request_id in self._request_streams: - raise KeyError(f"Request {request_id} already exists.") - - abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) - self._new_requests.put_nowait((stream, { - "request_id": request_id, - **engine_add_request_kwargs - })) - - self.new_requests_event.set() - - if verbose: - logger.info("Added request %s.", request_id) - - return stream - - def abort_request(self, - request_id: str, - *, - exception: Optional[Union[BaseException, - Type[BaseException]]] = None, - verbose: bool = False) -> None: - """Abort a request during next background loop iteration.""" - if verbose: - logger.info("Aborted request %s.", request_id) - - self._aborted_requests.put_nowait(request_id) - - stream = self._request_streams.pop(request_id, None) - if stream is not None: - stream.finish(exception=exception) - - def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: - """Get the new requests and finished requests to be - sent to the engine.""" - new_requests: List[Dict] = [] - finished_requests: Set[str] = set() - - while not self._aborted_requests.empty(): - request_id = self._aborted_requests.get_nowait() - finished_requests.add(request_id) - - while not self._new_requests.empty(): - stream, new_request = self._new_requests.get_nowait() - request_id = stream.request_id - if request_id in finished_requests: - # The request has already been aborted. - stream.finish(asyncio.CancelledError) - finished_requests.discard(request_id) - else: - self._request_streams[request_id] = stream - new_requests.append(new_request) - - return new_requests, finished_requests - - async def wait_for_new_requests(self): - if not self.has_new_requests(): - await self.new_requests_event.wait() - self.new_requests_event.clear() - - def has_new_requests(self): - return not self._new_requests.empty() - - -class _AsyncLLMEngine(LLMEngine): - """Extension of LLMEngine to add async methods.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def step_async( - self, virtual_engine: int - ) -> List[Union[RequestOutput, PoolingRequestOutput]]: - """Performs one decoding iteration and returns newly generated results. - The workers are ran asynchronously if possible. - - This function performs one decoding iteration of the engine. It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. - """ - # these are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. 
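# Illustrative, self-contained sketch of the producer/consumer pattern used by
# the AsyncStream class deleted above: request outputs are pushed onto an
# asyncio queue and streamed back through an async generator until a sentinel
# marks completion. Greatly simplified, not the removed class itself.
import asyncio

_STOP = object()  # sentinel marking the end of a stream


class MiniStream:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(_STOP)

    async def generator(self):
        while True:
            item = await self._queue.get()
            if item is _STOP:
                return
            yield item


async def main() -> None:
    stream = MiniStream()
    for token in ["Hello", ",", " world"]:
        stream.put(token)
    stream.finish()
    async for out in stream.generator():
        print(out)


asyncio.run(main())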
- if not self._has_remaining_steps(seq_group_metadata_list): - - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - if not scheduler_outputs.is_empty(): - # this will cause mamba_cache/minimax_cache failed - # to release finished_requests_ids of the last steps - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - else: - finished_requests_ids = list() - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - - if not scheduler_outputs.is_empty(): - - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = \ - self._get_last_sampled_token_ids(virtual_engine) - - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - virtual_engine=virtual_engine, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) - - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - - # Execute the model. - outputs = await self.model_executor.execute_model_async( - execute_model_req) - - else: - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - outputs = [] - - if not self._has_remaining_steps(seq_group_metadata_list): - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. - is_first_step_output: bool = False if not seq_group_metadata_list \ - else seq_group_metadata_list[0].state.num_steps == 1 - - ctx.append_output(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True, - is_first_step_output=is_first_step_output) - - if outputs and allow_async_output_proc: - assert len( - outputs - ) == 1, "Async postprocessor expects only a single output set" - self._advance_to_next_step( - outputs[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) - - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. 
- self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - return ctx.request_outputs - - async def stop_remote_worker_execution_loop_async(self) -> None: - """Stop the remote worker execution loop.""" - await self.model_executor.stop_remote_worker_execution_loop_async() - - async def get_tokenizer_async(self, - lora_request: Optional[LoRARequest] = None - ) -> AnyTokenizer: - return await ( - self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) - - async def add_request_async( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - data_parallel_rank: Optional[int] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> None: - """ - Async version of - [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request]. - """ - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - if priority != 0 and not self.scheduler_config.policy == "priority": - raise ValueError(f"Got priority {priority} but " - "Priority scheduling is not enabled.") - if arrival_time is None: - arrival_time = time.time() - - if data_parallel_rank is not None: - raise ValueError("Targeting data_parallel_rank only supported " - "in v1 client.") - - if (isinstance(prompt, dict) - and prompt.get("prompt_embeds", None) is not None - and not prompt.get("prompt_token_ids", None)): - # We use the -2 dimension (instead of 0) in case a batched input - # of batch size 1 is passed in. - prompt["prompt_token_ids"] = [0 - ] * prompt["prompt_embeds"].shape[-2] - - processed_inputs = await self.input_preprocessor.preprocess_async( - prompt, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ) - - async def check_health_async(self) -> None: - self.model_executor.check_health() - - async def collective_rpc_async(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): - raise NotImplementedError - - -class AsyncLLMEngine(EngineClient): - """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine]. - - This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to - make it asynchronous. It uses asyncio to create a background loop that keeps - processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked - by the generate method when there are requests in the waiting queue. The - generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine] - to the caller. - - Args: - log_requests: Whether to log the requests. - start_engine_loop: If True, the background task to run the engine - will be automatically started in the generate call. - *args: Arguments for [`LLMEngine`][vllm.LLMEngine]. - **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine]. 
- """ - - _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine - - def __init__(self, - *args, - log_requests: bool = True, - start_engine_loop: bool = True, - **kwargs) -> None: - if envs.VLLM_USE_V1: - raise ValueError( - "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " - "This should not happen. As a workaround, try using " - "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") - - self.log_requests = log_requests - self.engine = self._engine_class(*args, **kwargs) - - # This ensures quick processing of request outputs - # so the append to asyncio queues is not delayed, - # especially for multi-step. - self.use_process_request_outputs_callback = ( - self.engine.model_config.use_async_output_proc) - - if self.use_process_request_outputs_callback: - self.engine.process_request_outputs_callback = \ - weak_bind(self.process_request_outputs) - - self.background_loop: Optional[asyncio.Future] = None - # We need to keep a reference to unshielded - # task as well to prevent it from being garbage - # collected - self._background_loop_unshielded: Optional[asyncio.Task] = None - self.start_engine_loop = start_engine_loop - self._errored_with: Optional[BaseException] = None - - # Lazy initialized fields - self._request_tracker: RequestTracker - - def __del__(self): - if rt := getattr(self, "request_tracker", None): - # Wake up engine loop so that it will exit cleanly - rt.new_requests_event.set() - - @classmethod - def _get_executor_cls(cls, - engine_config: VllmConfig) -> Type[ExecutorBase]: - return LLMEngine._get_executor_cls(engine_config) - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "AsyncLLMEngine": - """Create an AsyncLLMEngine from the EngineArgs.""" - - return cls( - vllm_config=vllm_config, - executor_class=cls._get_executor_cls(vllm_config), - start_engine_loop=start_engine_loop, - log_requests=enable_log_requests, - log_stats=not disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - - @classmethod - def from_engine_args( - cls, - engine_args: AsyncEngineArgs, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - - vllm_config = engine_args.create_engine_config(usage_context) - - async_engine_cls = cls - if envs.VLLM_USE_V1: - from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine - async_engine_cls = V1AsyncLLMEngine - - return async_engine_cls.from_vllm_config( - vllm_config=vllm_config, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - disable_log_stats=engine_args.disable_log_stats, - enable_log_requests=engine_args.enable_log_requests, - ) - - @property - def is_running(self) -> bool: - return (self.background_loop is not None - and self._background_loop_unshielded is not None - and not self._background_loop_unshielded.done()) - - @property - def 
is_stopped(self) -> bool: - return self.errored or (self.background_loop is not None and - self._background_loop_unshielded is not None - and self._background_loop_unshielded.done()) - - @property - def errored(self) -> bool: - return self._errored_with is not None - - @property - def dead_error(self) -> BaseException: - return AsyncEngineDeadError( - "Background loop is not running. If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError).") - - def set_errored(self, exc: Exception) -> None: - self._errored_with = exc - - def _error_callback(self, exc: Exception) -> None: - self.set_errored(exc) - self._request_tracker.propagate_exception(exc) - - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.engine.input_preprocessor - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self.engine.get_tokenizer_async(lora_request) - - def start_background_loop(self) -> None: - """Start the background loop.""" - if self.errored: - raise AsyncEngineDeadError( - "Background loop has errored already.") from self._errored_with - if self.is_running: - raise RuntimeError("Background loop is already running.") - # Initialize the RequestTracker here so it uses the right event loop. - self._request_tracker = RequestTracker() - - self._background_loop_unshielded = asyncio.get_event_loop( - ).create_task(self.run_engine_loop(weakref.ref(self))) - self._background_loop_unshielded.add_done_callback( - partial(_log_task_completion, error_callback=self._error_callback)) - self.background_loop = asyncio.shield(self._background_loop_unshielded) - - def shutdown_background_loop(self) -> None: - """ - Shut down the background loop. - - This method needs to be called during cleanup to remove - references to `self` and properly GC the resources held - by the async LLM engine (e.g., the executors as well as - their resources). - """ - if self._background_loop_unshielded is not None: - self._background_loop_unshielded.cancel() - self._background_loop_unshielded = None - self.background_loop = None - - async def engine_step(self, virtual_engine: int) -> bool: - """Kick the engine to process the waiting requests. - - Returns True if there are in-progress requests.""" - - new_requests, aborted_requests = ( - self._request_tracker.get_new_and_aborted_requests()) - - for new_request in new_requests: - # Add the request into the vLLM engine's waiting queue. - try: - await self.engine.add_request_async(**new_request) - except ValueError as e: - # TODO: use a vLLM specific error for failed validation - self._request_tracker.process_exception( - new_request["request_id"], - e, - verbose=self.log_requests, - ) - - if aborted_requests: - await self._engine_abort(aborted_requests) - - request_outputs = await self.engine.step_async(virtual_engine) - - # Put the outputs into the corresponding streams. - # If used as a callback, then already invoked inside - # LLMEngine's _process_model_outputs - if not self.use_process_request_outputs_callback: - all_finished = self.process_request_outputs(request_outputs) - else: - # For callback case, we only need to detect when all - # requests are finished - all_finished = all(request_output.finished - for request_output in request_outputs) - - return not all_finished - - def process_request_outputs(self, request_outputs) -> bool: - # Put the outputs into the corresponding streams. 
- all_finished = True - for request_output in request_outputs: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests) - all_finished = all_finished and request_output.finished - - return all_finished - - async def _engine_abort(self, request_ids: Iterable[str]): - self.engine.abort_request(request_ids) - - @staticmethod - async def run_engine_loop(engine_ref: ReferenceType): - """We use a weakref to the engine so that the running loop - doesn't prevent the engine being garbage collected.""" - engine: Optional[AsyncLLMEngine] = engine_ref() - if not engine: - return - - pipeline_parallel_size = \ - engine.engine.parallel_config.pipeline_parallel_size - has_requests_in_progress = [False] * pipeline_parallel_size - while True: - if not any(has_requests_in_progress): - logger.debug("Waiting for new requests...") - # Stop the execute model loop in parallel workers until there - # are more requests to process. This avoids waiting - # indefinitely in torch.distributed ops which may otherwise - # timeout, and unblocks the RPC thread in the workers so that - # they can process any other queued control plane messages, - # such as add/remove lora adapters. - await engine.engine.stop_remote_worker_execution_loop_async() - request_tracker = engine._request_tracker - # Allow engine to be garbage collected while - # waiting for new requests - del engine - await asyncio.sleep(0) - if engine_ref() is None: - return - await request_tracker.wait_for_new_requests() - engine = engine_ref() - if not engine: - return - logger.debug("Got new requests!") - requests_in_progress = [ - asyncio.create_task(engine.engine_step(ve)) - for ve in range(pipeline_parallel_size) - ] - has_requests_in_progress = [True] * pipeline_parallel_size - - # Abort if iteration takes too long due to unrecoverable errors - # (eg. NCCL timeouts). - try: - async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): - done, _ = await asyncio.wait( - requests_in_progress, - return_when=asyncio.FIRST_COMPLETED) - for _ in range(pipeline_parallel_size): - await asyncio.sleep(0) - for task in done: - result = task.result() - virtual_engine = requests_in_progress.index(task) - has_unfinished_requests = ( - engine.engine. - has_unfinished_requests_for_virtual_engine( - virtual_engine)) - if result or has_unfinished_requests: - requests_in_progress[virtual_engine] = ( - asyncio.create_task( - engine.engine_step(virtual_engine))) - has_requests_in_progress[virtual_engine] = True - else: - has_requests_in_progress[virtual_engine] = False - except asyncio.TimeoutError as exc: - logger.error( - "Engine iteration timed out. This should never happen!") - engine.set_errored(exc) - raise - await asyncio.sleep(0) - - async def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - data_parallel_rank: Optional[int] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: - if not self.is_running: - if self.start_engine_loop: - self.start_background_loop() - else: - raise AsyncEngineDeadError( - "Background loop is not running. 
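# Illustrative sketch of the weak-reference trick in the deleted
# run_engine_loop: the background loop keeps only a weakref to the engine, so
# dropping the last strong reference lets the engine be garbage collected and
# the loop exit cleanly instead of keeping it alive forever. Simplified
# stand-in with hypothetical names, not the removed vLLM code.
import asyncio
import weakref


class TinyEngine:
    async def step(self) -> None:
        await asyncio.sleep(0.01)


async def run_loop(engine_ref: "weakref.ReferenceType[TinyEngine]") -> None:
    while True:
        engine = engine_ref()
        if engine is None:
            print("engine collected, loop exiting")
            return
        await engine.step()
        del engine            # drop the strong reference between iterations
        await asyncio.sleep(0)


async def main() -> None:
    engine = TinyEngine()
    task = asyncio.create_task(run_loop(weakref.ref(engine)))
    await asyncio.sleep(0.05)
    del engine                # last strong reference gone (CPython refcounting)
    await asyncio.wait_for(task, timeout=1.0)


asyncio.run(main())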
If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError).") - - if (priority != 0 - and not self.engine.scheduler_config.policy == "priority"): - raise ValueError(f"Got priority {priority} but " - "Priority scheduling is not enabled.") - - stream = self._request_tracker.add_request( - request_id, - verbose=self.log_requests, - prompt=prompt, - params=params, - arrival_time=arrival_time or time.time(), - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - tokenization_kwargs=tokenization_kwargs, - ) - - return stream.generator() - - async def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - data_parallel_rank: Optional[int] = None, - ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - sampling_params: The sampling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - data_parallel_rank: The (global) data parallel rank that must - handle this request. Only applicable if DP is enabled. - Yields: - The output `RequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step] - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - >>> # Please refer to entrypoints/api_server.py for - >>> # the complete example. - >>> - >>> # initialize the engine and the example input - >>> # note that engine_args here is AsyncEngineArgs instance - >>> engine = AsyncLLMEngine.from_engine_args(engine_args) - >>> example_input = { - >>> "prompt": "What is LLM?", - >>> "stream": False, # assume the non-streaming case - >>> "temperature": 0.0, - >>> "request_id": 0, - >>> } - >>> - >>> # start the generation - >>> results_generator = engine.generate( - >>> example_input["prompt"], - >>> SamplingParams(temperature=example_input["temperature"]), - >>> example_input["request_id"]) - >>> - >>> # get the results - >>> final_output = None - >>> async for request_output in results_generator: - >>> if await request.is_disconnected(): - >>> # Abort the request if the client disconnects. - >>> await engine.abort(request_id) - >>> # Return or raise an error - >>> ... - >>> final_output = request_output - >>> - >>> # Process and return the final output - >>> ... 
- """ - try: - async for output in await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - ): - yield LLMEngine.validate_output(output, RequestOutput) - except asyncio.CancelledError: - await self.abort(request_id) - raise - - async def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - ``` - # Please refer to entrypoints/api_server.py for - # the complete example. - - # initialize the engine and the example input - # note that engine_args here is AsyncEngineArgs instance - engine = AsyncLLMEngine.from_engine_args(engine_args) - example_input = { - "input": "What is LLM?", - "request_id": 0, - } - - # start the generation - results_generator = engine.encode( - example_input["input"], - PoolingParams(), - example_input["request_id"]) - - # get the results - final_output = None - async for request_output in results_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - # Return or raise an error - ... - final_output = request_output - - # Process and return the final output - ... - ``` - """ - try: - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - tokenization_kwargs=tokenization_kwargs, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) - except asyncio.CancelledError: - await self.abort(request_id) - raise - - async def abort(self, request_id: Union[str, Iterable[str]]) -> None: - """Abort a request. - - Abort a submitted request. If the request is finished or not found, - this method will be a no-op. - - Args: - request_id: The unique id of the request. 
- """ - if not isinstance(request_id, str): - raise RuntimeError("Only single-request abort supported in" - " deprecated V0") - if not self.is_running: - raise AsyncEngineDeadError( - "Background loop is not running. If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError).") - - return self._abort(request_id) - - def _abort(self, request_id: str) -> None: - """Abort a request. - - Abort a submitted request. If the request is finished or not found, - this method will be a no-op. - - Args: - request_id: The unique id of the request. - """ - self._request_tracker.abort_request(request_id, - exception=asyncio.CancelledError, - verbose=self.log_requests) - - async def get_vllm_config(self) -> VllmConfig: - """Get the vllm configuration of the vLLM engine.""" - return self.engine.get_vllm_config() - - async def get_model_config(self) -> ModelConfig: - """Get the model configuration of the vLLM engine.""" - return self.engine.get_model_config() - - async def get_parallel_config(self) -> ParallelConfig: - """Get the parallel configuration of the vLLM engine.""" - return self.engine.get_parallel_config() - - async def get_decoding_config(self) -> DecodingConfig: - """Get the decoding configuration of the vLLM engine.""" - return self.engine.get_decoding_config() - - async def get_scheduler_config(self) -> SchedulerConfig: - """Get the scheduling configuration of the vLLM engine.""" - return self.engine.get_scheduler_config() - - async def get_lora_config(self) -> LoRAConfig: - """Get the lora configuration of the vLLM engine.""" - return self.engine.get_lora_config() - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None) -> None: - self.engine.do_log_stats() - - async def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - t = time.perf_counter() - logger.debug("Starting health check...") - if self.is_stopped: - raise AsyncEngineDeadError("Background loop is stopped.") - - await self.engine.check_health_async() - logger.debug("Health check took %fs", time.perf_counter() - t) - - async def is_tracing_enabled(self) -> bool: - return self.engine.is_tracing_enabled() - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - self.engine.add_logger(logger_name=logger_name, logger=logger) - - def remove_logger(self, logger_name: str) -> None: - self.engine.remove_logger(logger_name=logger_name) - - async def start_profile(self) -> None: - self.engine.start_profile() - - async def stop_profile(self) -> None: - self.engine.stop_profile() - - async def reset_mm_cache(self) -> None: - self.engine.reset_mm_cache() - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - self.engine.reset_prefix_cache(device) - - async def sleep(self, level: int = 1) -> None: - await self.reset_prefix_cache() - self.engine.sleep(level) - - async def wake_up(self, tags: Optional[list[str]] = None) -> None: - self.engine.wake_up(tags) - - async def is_sleeping(self) -> bool: - return self.engine.is_sleeping() - - async def add_lora(self, lora_request: LoRARequest) -> None: - self.engine.add_lora(lora_request) - - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): - """ - Perform a collective RPC call to the given path. 
- """ - return await self.engine.collective_rpc_async(method, timeout, args, - kwargs) - - -# TODO(v1): Remove this class proxy when V1 goes default. -if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: - from vllm.v1.engine.async_llm import AsyncLLM - - AsyncLLMEngine = AsyncLLM # type: ignore +AsyncLLMEngine = AsyncLLM # type: ignore diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py deleted file mode 100644 index 28a023a71ef52..0000000000000 --- a/vllm/engine/async_timeout.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Workaround for https://github.com/python/cpython/issues/86296 -# -# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -# Licensed under the Apache License (Apache-2.0) - -import asyncio -import enum -import sys -from types import TracebackType -from typing import Any, Optional, Type - -if sys.version_info[:2] >= (3, 11): - from asyncio import timeout as asyncio_timeout -else: - - def asyncio_timeout(delay: Optional[float]) -> "Timeout": - """timeout context manager. - Useful in cases when you want to apply timeout logic around block - of code or in cases when asyncio.wait_for is not suitable. For example: - >>> async with timeout(0.001): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - delay - value in seconds or None to disable timeout logic - """ - loop = asyncio.get_running_loop() - deadline = loop.time() + delay if delay is not None else None - return Timeout(deadline, loop) - - class _State(enum.Enum): - INIT = "INIT" - ENTER = "ENTER" - TIMEOUT = "TIMEOUT" - EXIT = "EXIT" - - class Timeout: - # Internal class, please don't instantiate it directly - # Use timeout() and timeout_at() public factories instead. - # - # Implementation note: `async with timeout()` is preferred - # over `with timeout()`. - # While technically the Timeout class implementation - # doesn't need to be async at all, - # the `async with` statement explicitly points that - # the context manager should be used from async function context. - # - # This design allows to avoid many silly misusages. - # - # TimeoutError is raised immediately when scheduled - # if the deadline is passed. - # The purpose is to time out as soon as possible - # without waiting for the next await expression. - - __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") - - def __init__(self, deadline: Optional[float], - loop: asyncio.AbstractEventLoop) -> None: - self._loop = loop - self._state = _State.INIT - - self._timeout_handler = None # type: Optional[asyncio.Handle] - if deadline is None: - self._deadline = None # type: Optional[float] - else: - self.update(deadline) - - async def __aenter__(self) -> "Timeout": - self._do_enter() - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - self._do_exit(exc_type) - return None - - @property - def expired(self) -> bool: - """Is timeout expired during execution?""" - return self._state == _State.TIMEOUT - - @property - def deadline(self) -> Optional[float]: - return self._deadline - - def reject(self) -> None: - """Reject scheduled timeout if any.""" - # cancel is maybe better name but - # task.cancel() raises CancelledError in asyncio world. 
- if self._state not in (_State.INIT, _State.ENTER): - raise RuntimeError(f"invalid state {self._state.value}") - self._reject() - - def _reject(self) -> None: - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._timeout_handler = None - - def shift(self, delay: float) -> None: - """Advance timeout on delay seconds. - The delay can be negative. - Raise RuntimeError if shift is called when deadline is not scheduled - """ - deadline = self._deadline - if deadline is None: - raise RuntimeError( - "cannot shift timeout if deadline is not scheduled") - self.update(deadline + delay) - - def update(self, deadline: float) -> None: - """Set deadline to absolute value. - deadline argument points on the time in the same clock system - as loop.time(). - If new deadline is in the past the timeout is raised immediately. - Please note: it is not POSIX time but a time with - undefined starting base, e.g. the time of the system power on. - """ - if self._state == _State.EXIT: - raise RuntimeError( - "cannot reschedule after exit from context manager") - if self._state == _State.TIMEOUT: - raise RuntimeError("cannot reschedule expired timeout") - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._deadline = deadline - if self._state != _State.INIT: - self._reschedule() - - def _reschedule(self) -> None: - assert self._state == _State.ENTER - deadline = self._deadline - if deadline is None: - return - - now = self._loop.time() - if self._timeout_handler is not None: - self._timeout_handler.cancel() - - task = asyncio.current_task() - if deadline <= now: - self._timeout_handler = self._loop.call_soon( - self._on_timeout, task) - else: - self._timeout_handler = self._loop.call_at( - deadline, self._on_timeout, task) - - def _do_enter(self) -> None: - if self._state != _State.INIT: - raise RuntimeError(f"invalid state {self._state.value}") - self._state = _State.ENTER - self._reschedule() - - def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: - if exc_type is asyncio.CancelledError and \ - self._state == _State.TIMEOUT: - self._timeout_handler = None - raise asyncio.TimeoutError - # timeout has not expired - self._state = _State.EXIT - self._reject() - return None - - def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: - if task: - task.cancel() - self._state = _State.TIMEOUT - # drop the reference early - self._timeout_handler = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351e87c..a0fe38eb320d6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,1895 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from collections import Counter as collectionsCounter -from collections import deque -from contextlib import contextmanager -from dataclasses import dataclass -from functools import partial -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Literal, Mapping, NamedTuple, Optional) -from typing import Sequence as GenericSequence -from typing import Set, Type, Union, cast +from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine -import torch -from typing_extensions import TypeVar - -import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) -from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs -from vllm.engine.arg_utils 
import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase, Stats -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.entrypoints.openai.logits_processors import ( - get_logits_processors as get_openai_logits_processors) -from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs -from vllm.inputs.parse import split_enc_dec_inputs -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.logits_process import get_bad_words_logits_processors -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.processing import EncDecMultiModalProcessor -from vllm.outputs import (PoolingRequestOutput, RequestOutput, - RequestOutputFactory) -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, - PoolingSequenceGroupOutput, Sequence, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupOutput, SequenceStatus) -from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, - init_tracer) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import ( - TokenizerGroup, init_tokenizer_from_configs) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind -from vllm.version import __version__ as VLLM_VERSION -from vllm.worker.model_runner_base import InputProcessingError - -logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - -_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) -_R = TypeVar("_R", default=Any) - - -@dataclass -class SchedulerOutputState: - """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - allow_async_output_proc: bool = False - last_output: Optional[SamplerOutput] = None - - -class OutputData(NamedTuple): - outputs: List[SamplerOutput] - seq_group_metadata_list: List[SequenceGroupMetadata] - scheduler_outputs: SchedulerOutputs - is_async: bool - is_last_step: bool - # Indicates if this output is from the first step of the - # multi-step. When multi-step is disabled, this is always - # set to True. - # is_first_step_output is invalid when `outputs` has - # outputs from multiple steps. 
- is_first_step_output: Optional[bool] - skip: List[int] - - -class SchedulerContext: - - def __init__(self) -> None: - self.output_queue: Deque[OutputData] = deque() - self.request_outputs: List[Union[RequestOutput, - PoolingRequestOutput]] = [] - self.seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None - self.scheduler_outputs: Optional[SchedulerOutputs] = None - - def append_output(self, outputs: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool, - is_first_step_output: Optional[bool]): - self.output_queue.append( - OutputData(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=is_async, - is_last_step=is_last_step, - is_first_step_output=is_first_step_output, - skip=[])) - - -class LLMEngine: - """An LLM engine that receives requests and generates texts. - - This is the main class for the vLLM engine. It receives requests - from clients and generates texts from the LLM. It includes a tokenizer, a - language model (possibly distributed across multiple GPUs), and GPU memory - space allocated for intermediate states (aka KV cache). This class utilizes - iteration-level scheduling and efficient memory management to maximize the - serving throughput. - - The [`LLM`][vllm.LLM] class wraps this class for offline batched inference - and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine] - class wraps this class for online serving. - - The config arguments are derived from [`EngineArgs`][vllm.EngineArgs]. - - Args: - vllm_config: The configuration for initializing and running vLLM. - executor_class: The model executor class for managing distributed - execution. - log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection. - """ - - DO_VALIDATE_OUTPUT: ClassVar[bool] = False - """A flag to toggle whether to validate the type of request output.""" - - @classmethod - @contextmanager - def enable_output_validation(cls): - cls.DO_VALIDATE_OUTPUT = True - - yield - - cls.DO_VALIDATE_OUTPUT = False - - @classmethod - def validate_output( - cls, - output: object, - output_type: Type[_O], - ) -> _O: - do_validate = cls.DO_VALIDATE_OUTPUT - - if ((TYPE_CHECKING or do_validate) - and not isinstance(output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - return cast(_O, output) - - @classmethod - def validate_outputs( - cls, - outputs: GenericSequence[object], - output_type: Type[_O], - ) -> List[_O]: - do_validate = cls.DO_VALIDATE_OUTPUT - - outputs_: List[_O] - if TYPE_CHECKING or do_validate: - outputs_ = [] - for output in outputs: - if not isinstance(output, output_type): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - outputs_.append(output) - else: - outputs_ = outputs - - return outputs_ - - tokenizer: Optional[TokenizerGroup] - - def __init__( - self, - vllm_config: VllmConfig, - executor_class: Type[ExecutorBase], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - use_cached_outputs: bool = False, - ) -> None: - if envs.VLLM_USE_V1: - raise ValueError( - "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. " - "This should not happen. 
As a workaround, try using " - "LLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") - - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config # noqa - self.load_config = vllm_config.load_config - self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa - ) - self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa - ) - - logger.info( - "Initializing a V0 LLM engine (v%s) with config: %s, " - "use_cached_outputs=%s, ", - VLLM_VERSION, - vllm_config, - use_cached_outputs, - ) - - self.log_stats = log_stats - self.use_cached_outputs = use_cached_outputs - - if self.model_config.skip_tokenizer_init: - self.tokenizer = None - self.detokenizer = None - tokenizer_group = None - else: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - - # Ensure that the function doesn't contain a reference to self, - # to avoid engine GC issues - def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, ("tokenizer_group cannot be None, " - "make sure skip_tokenizer_init is False") - return tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - self.seq_counter = Counter() - self.generation_config_fields = ( - self.model_config.try_get_generation_config()) - - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - - self.model_executor = executor_class(vllm_config=vllm_config) - - if self.model_config.runner_type != "pooling": - self._initialize_kv_caches() - - # If usage stat is enabled, collect relevant info. 
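The `get_tokenizer_for_seq` closure above deliberately captures `tokenizer_group` rather than `self`, so handing it to long-lived components does not pin the engine in memory. A tiny illustration of why that matters; the names below are illustrative, not vLLM's:

```
import weakref


class Engine:
    def __init__(self) -> None:
        tokenizer_group = {"lora-0": "tokenizer"}  # stand-in resource

        # Capturing only tokenizer_group keeps the callback from holding a
        # strong reference back to the engine itself.
        def get_tokenizer(key: str) -> str:
            return tokenizer_group[key]

        self.get_tokenizer = get_tokenizer


engine = Engine()
callback = engine.get_tokenizer     # e.g. stored by an output processor
engine_ref = weakref.ref(engine)
del engine
# On CPython, reference counting collects the engine immediately because
# the surviving callback does not reference it.
print(engine_ref() is None)         # True
```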
- if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import ( - get_architecture_class_name) - usage_message.report_usage( - get_architecture_class_name(self.model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": - str(self.model_config.dtype), - "tensor_parallel_size": - self.parallel_config.tensor_parallel_size, - "block_size": - self.cache_config.block_size, - "gpu_memory_utilization": - self.cache_config.gpu_memory_utilization, - - # Quantization - "quantization": - self.model_config.quantization, - "kv_cache_dtype": - str(self.cache_config.cache_dtype), - - # Feature flags - "enable_lora": - bool(self.lora_config), - "enable_prefix_caching": - self.cache_config.enable_prefix_caching, - "enforce_eager": - self.model_config.enforce_eager, - "disable_custom_all_reduce": - self.parallel_config.disable_custom_all_reduce, - }) - - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - if self.model_config.use_async_output_proc: - process_model_outputs = weak_bind(self._process_model_outputs) - - self.async_callbacks = [ - partial(process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - else: - self.async_callbacks = [] - - # Currently used by AsyncLLMEngine to ensure quick append - # of request outputs to asyncio queues - self.process_request_outputs_callback: Optional[Callable] = None - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): - Scheduler = resolve_obj_by_qualname( - self.vllm_config.scheduler_config.scheduler_cls) - else: - Scheduler = self.vllm_config.scheduler_config.scheduler_cls - self.scheduler = [ - Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, - self.parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if self.model_config.use_async_output_proc else None) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from vllm.engine.metrics import (LoggingStatLogger, - PrometheusStatLogger) - - self.stat_loggers = { - "logging": - LoggingStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - vllm_config=vllm_config), - "prometheus": - PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict( - model_name=self.model_config.served_model_name), - vllm_config=vllm_config), - } - self.stat_loggers["prometheus"].info("cache_config", - self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. 
- self.output_processor = ( - SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - get_tokenizer_for_seq, - stop_checker=StopChecker(self.scheduler_config.max_model_len, - get_tokenizer_for_seq), - )) - - self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} - - # Flag to set when an input fails to process and the engine should run - # the next step without re-scheduling. - self._skip_scheduling_next_step = False - - # Don't keep the dummy data in memory - self.reset_mm_cache() - - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). - - The workers will determine the number of blocks in both the GPU cache - and the swap CPU cache. - """ - start = time.time() - num_gpu_blocks, num_cpu_blocks = ( - self.model_executor.determine_num_available_blocks()) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) - elapsed = time.time() - start - logger.info(("init engine (profile, create kv cache, " - "warmup model) took %.2f seconds"), elapsed) - - @classmethod - def _get_executor_cls(cls, - engine_config: VllmConfig) -> Type[ExecutorBase]: - # distributed_executor_backend must be set in VllmConfig.__post_init__ - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") - executor_class = distributed_executor_backend - elif distributed_executor_backend == "ray": - from vllm.executor.ray_distributed_executor import ( - RayDistributedExecutor) - executor_class = RayDistributedExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.mp_distributed_executor import ( - MultiprocessingDistributedExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingDistributedExecutor - elif distributed_executor_backend == "uni": - # JAX-style, single-process, multi-device executor. 
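For intuition about the block counts reported by `_initialize_kv_caches` above, a back-of-the-envelope sketch; every number below is illustrative, and in the engine the real values come from the executor's profiling pass and `cache_config`:

```
# Illustrative paged-KV-cache capacity math.
num_gpu_blocks = 4096      # profiled, or num_gpu_blocks_override
block_size = 16            # tokens stored per KV block
max_model_len = 8192

total_cacheable_tokens = num_gpu_blocks * block_size          # 65536 tokens
full_length_sequences = total_cacheable_tokens // max_model_len
print(total_cacheable_tokens, full_length_sequences)          # 65536 8
```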
- from vllm.executor.uniproc_executor import UniProcExecutor - executor_class = UniProcExecutor - elif distributed_executor_backend == "external_launcher": - # executor with external launcher - from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher) - executor_class = ExecutorWithExternalLauncher - else: - raise ValueError("unrecognized distributed_executor_backend: " - f"{distributed_executor_backend}") - return executor_class - - @classmethod - def from_vllm_config( - cls, - vllm_config: VllmConfig, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - disable_log_stats: bool = False, - ) -> "LLMEngine": - return cls( - vllm_config=vllm_config, - executor_class=cls._get_executor_cls(vllm_config), - log_stats=(not disable_log_stats), - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - - @classmethod - def from_engine_args( - cls, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - vllm_config = engine_args.create_engine_config(usage_context) - - engine_cls = cls - if envs.VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - engine_cls = V1LLMEngine - - return engine_cls.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - stat_loggers=stat_loggers, - disable_log_stats=engine_args.disable_log_stats, - ) - - def __reduce__(self): - # This is to ensure that the LLMEngine is not referenced in - # the closure used to initialize Ray worker actors - raise RuntimeError("LLMEngine should not be pickled!") - - def __del__(self): - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): - model_executor.shutdown() - - def get_tokenizer_group(self) -> TokenizerGroup: - if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") - - return self.tokenizer - - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - - def _init_tokenizer(self) -> TokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - lora_config=self.lora_config) - - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - - def _add_processed_request( - self, - request_id: str, - processed_inputs: ProcessorInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> Optional[SequenceGroup]: - """Add a processed request to the engine's request pool. - return the created sequence group. 
- """ - if isinstance(params, SamplingParams) and params.n > 1: - ParallelSampleSequenceGroup.add_request( - request_id, - self, - params, - processed_inputs=processed_inputs, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ) - return None - - self._validate_model_inputs(processed_inputs, lora_request) - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - - seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request) - - encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) - - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - encoder_seq=encoder_seq, - priority=priority) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - encoder_seq=encoder_seq, - priority=priority) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") - - # Add the sequence group to the scheduler with least unfinished seqs. - costs = [ - scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler - ] - min_cost_scheduler = self.scheduler[costs.index(min(costs))] - min_cost_scheduler.add_seq_group(seq_group) - - return seq_group - - def stop_remote_worker_execution_loop(self) -> None: - self.model_executor.stop_remote_worker_execution_loop() - - def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> None: - """Add a request to the engine's request pool. - - The request is added to the request pool and will be processed by the - scheduler as `engine.step()` is called. The exact scheduling policy is - determined by the scheduler. - - Args: - request_id: The unique ID of the request. - prompt: The prompt to the LLM. See - [PromptType][vllm.inputs.PromptType] - for more details about the format of each input. - params: Parameters for sampling or pooling. - [SamplingParams][vllm.SamplingParams] for text generation. - [PoolingParams][vllm.PoolingParams] for pooling. - arrival_time: The arrival time of the request. If None, we use - the current monotonic time. - lora_request: The LoRA request to add. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Details: - - Set arrival_time to the current time if it is None. - - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of [Sequence][vllm.Sequence] objects. - - Create a [SequenceGroup][vllm.SequenceGroup] object - from the list of [Sequence][vllm.Sequence]. - - Add the [SequenceGroup][vllm.SequenceGroup] object to the - scheduler. 
- - Example: - >>> # initialize engine - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> # set request arguments - >>> example_prompt = "Who is the president of the United States?" - >>> sampling_params = SamplingParams(temperature=0.0) - >>> request_id = 0 - >>> - >>> # add the request to the engine - >>> engine.add_request( - >>> str(request_id), - >>> example_prompt, - >>> SamplingParams(temperature=0.0)) - >>> # continue the request processing - >>> ... - """ - if not isinstance(request_id, str): - raise TypeError( - f"request_id must be a string, got {type(request_id)}") - - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - - if priority != 0 and not self.scheduler_config.policy == "priority": - raise ValueError(f"Got priority {priority} but " - "Priority scheduling is not enabled.") - - if isinstance(params, SamplingParams) \ - and params.logits_processors: - raise ValueError( - "Logits processors are not supported in multi-step decoding") - - if arrival_time is None: - arrival_time = time.time() - - if (isinstance(prompt, dict) - and prompt.get("prompt_embeds", None) is not None - and not prompt.get("prompt_token_ids", None)): - seq_len = prompt["prompt_embeds"].shape[0] - prompt["prompt_token_ids"] = [0] * seq_len - - processed_inputs = self.input_preprocessor.preprocess( - prompt, - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ) - - def _create_sequence_group_with_sampling( - self, - request_id: str, - seq: Sequence, - sampling_params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") - - sampling_params = self._build_logits_processors( - sampling_params, lora_request) - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - - sampling_params.update_from_generation_config( - self.generation_config_fields, seq.eos_token_id) - - # Create the sequence group. 
- draft_size = 1 - if self.vllm_config.speculative_config is not None: - draft_size = \ - self.vllm_config.speculative_config.num_speculative_tokens + 1 - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - encoder_seq=encoder_seq, - priority=priority, - draft_size=draft_size) - - return seq_group - - def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with PoolingParams.""" - # Defensive copy of PoolingParams, which are used by the pooler - pooling_params = pooling_params.clone() - # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - encoder_seq=encoder_seq, - priority=priority) - return seq_group - - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - """Aborts a request(s) with the given ID. - - Args: - request_id: The ID(s) of the request to abort. - - Details: - - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][]. - - Example: - >>> # initialize engine and add a request with request_id - >>> request_id = str(0) - >>> # abort the request - >>> engine.abort_request(request_id) - """ - for scheduler in self.scheduler: - scheduler.abort_seq_group( - request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) - - def get_vllm_config(self) -> VllmConfig: - """Gets the vllm configuration.""" - return self.vllm_config - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config - - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config - - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config - - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config - - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return sum(scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler) - - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return any(scheduler.has_unfinished_seqs() - for scheduler in self.scheduler) - - def has_unfinished_requests_for_virtual_engine( - self, virtual_engine: int) -> bool: - """ - Returns True if there are unfinished requests for the virtual engine. 
- """ - return self.scheduler[virtual_engine].has_unfinished_seqs() - - def reset_mm_cache(self) -> bool: - """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache( - self.model_config) - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache for all devices.""" - - success = True - for scheduler in self.scheduler: - success = success and scheduler.reset_prefix_cache(device) - return success - - @staticmethod - def _process_sequence_group_outputs( - seq_group: SequenceGroup, - outputs: List[PoolingSequenceGroupOutput], - ) -> None: - seq_group.pooled_data = outputs[0].data - - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_STOPPED - - return - - def _process_model_outputs(self, - ctx: SchedulerContext, - request_id: Optional[str] = None) -> None: - """Apply the model output to the sequences in the scheduled seq groups - and return responses. - - ctx: The virtual engine context to work on - request_id: If provided, then only this request is going to be processed - """ - - now = time.time() - - if len(ctx.output_queue) == 0: - return None - - # Get pending async postprocessor - if request_id: - # When we process only one request, no pop is required - # (since later we will process all of the rest) - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, skip) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, - skip) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( - scheduler_outputs.scheduled_seq_groups) - - has_multiple_outputs: bool = len(outputs) > 1 - outputs_by_sequence_group: List[List[SequenceGroupOutput]] - assert not has_multiple_outputs - outputs_by_sequence_group = outputs - - # Determine the requests we need to operate on - if request_id: - indices = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): - if seq_group_meta.request_id == request_id: - assert i not in skip # Cannot be called twice - indices.append(i) - break - - # If the request_id was not found, then it means that - # this is a new request that has no pending async - # postprocessor - if not indices: - return - else: - indices = range(len(seq_group_metadata_list)) # type: ignore - - finished_before: List[int] = [] - finished_now: List[int] = [] - for i in indices: - if i in skip: - continue - - seq_group_meta = seq_group_metadata_list[i] - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group: SequenceGroup = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - finished_before.append(i) - continue - - output: List[SequenceGroupOutput] - if has_multiple_outputs: - output = outputs_by_sequence_group[i] - else: - output = [outputs_by_sequence_group[0][i]] - - if not is_async: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size or 0) - - if outputs: - for o in outputs: - if (isinstance(o, SamplerOutput) - and seq_group.metrics is not None): - if seq_group.metrics.model_forward_time is not None: - seq_group.metrics.model_forward_time += ( - o.model_forward_time or 0) - else: - seq_group.metrics.model_forward_time = ( - o.model_forward_time) - if seq_group.metrics.model_execute_time is not None: - seq_group.metrics.model_execute_time += ( - o.model_execute_time or 0) - else: - seq_group.metrics.model_execute_time = ( - o.model_execute_time) - - if 
self.model_config.runner_type == "pooling": - self._process_sequence_group_outputs(seq_group, output) - else: - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs( - seq_group, output, is_async) - - if seq_group.is_finished(): - finished_now.append(i) - - # Generate outputs for the requests that finished this iteration - for i in finished_now: - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - if not seq_group.is_prefill(): - seq_group.set_last_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # When we process a single request, we skip it for the next time, - # and invoke the request output callback (if there was final output) - if request_id: - assert len(indices) == 1 - skip.append(indices[0]) - - if (finished_now - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Free currently finished requests - if finished_now: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() - - # Create the outputs - for i in indices: - if i in skip or i in finished_before or i in finished_now: - continue # Avoids double processing - - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - if not seq_group.is_prefill(): - seq_group.set_last_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # Create outputs only after processing the scheduler's results - - for seq_group in scheduler_outputs.ignored_seq_groups: - params = seq_group.sampling_params - if params is not None and params.output_kind == ( - RequestOutputKind.DELTA) and not seq_group.is_finished(): - continue - - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs, - ) - if request_output: - ctx.request_outputs.append(request_output) - - # Immediately process request outputs here (if callback is given) - if (ctx.request_outputs - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - - # For async case, we need to record the stats here. - # For non-async case, the stats are done in the - # LLMEngine/AsyncLLMEngine directly - if is_async: - # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before, - skip) - - # Tracing - self.do_tracing(scheduler_outputs, finished_before) - - return None - - def _advance_to_next_step( - self, output: SamplerOutput, - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done inside output processor, but it is - required if the worker is to perform async forward pass to next step. 
- """ - for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ - zip(seq_group_metadata_list, output, scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - continue - - token_chunk_size = (seq_group_metadata.token_chunk_size - if seq_group_metadata.token_chunk_size - is not None else 0) - seq_group.update_num_computed_tokens(token_chunk_size) - - if seq_group_metadata.do_sample: - assert len(sequence_group_outputs.samples) == 1, ( - "Async output processor expects a single sample" - " (i.e sampling_params.n == 1)") - sample = sequence_group_outputs.samples[0] - - assert len(seq_group.seqs) == 1 - seq = seq_group.seqs[0] - - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) - - def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: - """Performs one decoding iteration and returns newly generated results. - - <figure markdown="span"> - ![Overview of the step function](https://i.imgur.com/sv2HssD.png) - <figcaption>Overview of the step function</figcaption> - </figure> - - Details: - - Step 1: Schedules the sequences to be executed in the next - iteration and the token blocks to be swapped in/out/copy. - - - Depending on the scheduling policy, - sequences may be `preempted/reordered`. - - A Sequence Group (SG) refer to a group of sequences - that are generated from the same prompt. - - - Step 2: Calls the distributed executor to execute the model. - - Step 3: Processes the model output. This mainly includes: - - - Decodes the relevant outputs. - - Updates the scheduled sequence groups with model outputs - based on its `sampling parameters` (`use_beam_search` or not). - - Frees the finished sequence groups. - - - Finally, it creates and returns the newly generated results. - - Example: - ``` - # Please see the example/ folder for more detailed examples. - - # initialize engine and request arguments - engine = LLMEngine.from_engine_args(engine_args) - example_inputs = [(0, "What is LLM?", - SamplingParams(temperature=0.0))] - - # Start the engine with an event loop - while True: - if example_inputs: - req_id, prompt, sampling_params = example_inputs.pop(0) - engine.add_request(str(req_id),prompt,sampling_params) - - # continue the request processing - request_outputs = engine.step() - for request_output in request_outputs: - if request_output.finished: - # return or show the request output - - if not (engine.has_unfinished_requests() or example_inputs): - break - ``` - """ - if self.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported through AsyncLLMEngine " - "as performance will be severely degraded otherwise.") - - # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0. - virtual_engine = 0 - - # These are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. 
- # The scheduler is also skipped if a single request caused the last - # engine step to fail, and the previous schedule needs to be rerun. - if not self._has_remaining_steps( - seq_group_metadata_list - ) and not self._skip_scheduling_next_step: - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() - # When n>1, elements in self.seq_id_to_seq_group should be deleted - # here, otherwise memory leaks. - for finished_request_id in finished_requests_ids: - if finished_request_id in self.seq_id_to_seq_group: - del self.seq_id_to_seq_group[finished_request_id] - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - else: - finished_requests_ids = list() - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - - if not scheduler_outputs.is_empty(): - - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = \ - self._get_last_sampled_token_ids(virtual_engine) - - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) - - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - - try: - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) - self._skip_scheduling_next_step = False - except InputProcessingError as e: - # The input for this request cannot be processed, so we must - # abort it. If there are remaining requests in the batch that - # have been scheduled, they will be retried on the next step. - invalid_request_id = e.request_id - self._abort_and_cache_schedule( - request_id=invalid_request_id, - virtual_engine=virtual_engine, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - allow_async_output_proc=allow_async_output_proc) - # Raise so the caller is notified that this request failed - raise - - else: - # Nothing scheduled => If there is pending async postprocessor, - # then finish it here. - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case - outputs = [] - - if not self._has_remaining_steps(seq_group_metadata_list): - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. 
- is_first_step_output: bool = False if not seq_group_metadata_list \ - else seq_group_metadata_list[0].state.num_steps == 1 - - # Add results to the output_queue - ctx.append_output(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True, - is_first_step_output=is_first_step_output) - - if outputs and allow_async_output_proc: - assert len(outputs) == 1, ( - "Async postprocessor expects only a single output set") - - self._advance_to_next_step( - outputs[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) - - # Check if need to run the usual non-async path - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. - self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise timeout, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() - - return ctx.request_outputs - - def _abort_and_cache_schedule( - self, request_id: str, virtual_engine: int, - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - """Aborts a single request, and caches the scheduler outputs minus that - request. This allows the next step to continue processing the remaining - requests without having to re-run the scheduler.""" - - # Abort the request and remove its sequence group from the current - # schedule - self.abort_request(request_id) - for i, metadata in enumerate(seq_group_metadata_list): - if metadata.request_id == request_id: - del seq_group_metadata_list[i] - break - for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): - if group.seq_group.request_id == request_id: - del scheduler_outputs.scheduled_seq_groups[i] - break - - # If there are still other sequence groups left in the schedule, cache - # them and flag the engine to reuse the schedule. 
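[Editor's note] The abort-and-cache path above amounts to filtering one request_id out of the two schedule-aligned lists and only reusing the cached schedule if anything remains. A minimal sketch of that filtering, using plain dicts instead of vLLM's metadata objects:

```python
def drop_request(
    request_id: str,
    metadata: list[dict],
    scheduled: list[dict],
) -> tuple[list[dict], list[dict], bool]:
    """Filter one failed request out of both schedule-aligned lists."""
    metadata = [m for m in metadata if m["request_id"] != request_id]
    scheduled = [s for s in scheduled if s["request_id"] != request_id]
    # Only flag the cached schedule for reuse if something is left to run.
    reuse_schedule = len(metadata) > 0
    return metadata, scheduled, reuse_schedule


meta = [{"request_id": "a"}, {"request_id": "b"}]
sched = [{"request_id": "a"}, {"request_id": "b"}]
meta, sched, reuse = drop_request("a", meta, sched)
print(meta, reuse)  # -> [{'request_id': 'b'}] True
```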
- if len(seq_group_metadata_list) > 0: - self._skip_scheduling_next_step = True - # Reuse multi-step caching logic - self._cache_scheduler_outputs_for_multi_step( - virtual_engine=virtual_engine, - scheduler_outputs=scheduler_outputs, - seq_group_metadata_list=seq_group_metadata_list, - allow_async_output_proc=allow_async_output_proc) - - def _has_remaining_steps( - self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: - return False - - def _cache_scheduler_outputs_for_multi_step( - self, virtual_engine: int, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - co = self.cached_scheduler_outputs[virtual_engine] - - co.seq_group_metadata_list = seq_group_metadata_list - co.scheduler_outputs = scheduler_outputs - co.allow_async_output_proc = allow_async_output_proc - co.last_output = None - - def _update_cached_scheduler_output( - self, virtual_engine: int, - output: List[Optional[SamplerOutput]]) -> None: - if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 - and output[0] is not None): - last_output = output[-1] - assert last_output is not None - assert last_output.sampled_token_ids_cpu is not None - assert last_output.sampled_token_ids is None - assert last_output.sampled_token_probs is None - self.cached_scheduler_outputs[ - virtual_engine].last_output = last_output - - def _get_last_sampled_token_ids( - self, virtual_engine: int) -> Optional[torch.Tensor]: - return None - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} already exists.") - self.stat_loggers[logger_name] = logger - - def remove_logger(self, logger_name: str) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name not in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} does not exist.") - del self.stat_loggers[logger_name] - - def do_log_stats(self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> None: - """Forced log when no requests active.""" - if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output, - finished_before, skip) - for logger in self.stat_loggers.values(): - logger.log(stats) - - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> Stats: - """Get Stats to be Logged to Prometheus. - - Args: - scheduler_outputs: Optional, used to populate metrics related to - the scheduled batch, - model_output: Optional, used to emit speculative decoding metrics - which are created by the workers. - finished_before: Optional, indices of sequences that were finished - before. These sequences will be ignored. - skip: Optional, indices of sequences that were preempted. These - sequences will be ignored. 
- """ - now = time.time() - - # System State - # Scheduler State - num_running_sys = sum( - len(scheduler.running) for scheduler in self.scheduler) - num_swapped_sys = sum( - len(scheduler.swapped) for scheduler in self.scheduler) - num_waiting_sys = sum( - len(scheduler.waiting) for scheduler in self.scheduler) - - # KV Cache Usage in % - num_total_gpu = self.cache_config.num_gpu_blocks - gpu_cache_usage_sys = 0. - if num_total_gpu: # Guard against both None and 0 - num_free_gpu = sum( - scheduler.block_manager.get_num_free_gpu_blocks() - for scheduler in self.scheduler) - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) - - num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage_sys = 0. - if num_total_cpu: # Guard against both None and 0 - num_free_cpu = sum( - scheduler.block_manager.get_num_free_cpu_blocks() - for scheduler in self.scheduler) - cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) - - # Prefix Cache Hit Rate. Note that we always use - # the cache hit rate of the first virtual engine. - cpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.CPU) - gpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.GPU) - - # Exchange the uasge and cache hit stats between gpu and cpu when - # running on cpu because the cpu_worker.py intentionally reports the - # number of cpu blocks as gpu blocks in favor of cache management. - if self.device_config.device_type == "cpu": - num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu - gpu_cache_usage_sys, cpu_cache_usage_sys = ( - cpu_cache_usage_sys, - gpu_cache_usage_sys, - ) - gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = ( - cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate, - ) - - # Iteration stats - num_prompt_tokens_iter = 0 - num_generation_tokens_iter = 0 - num_tokens_iter = 0 - time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] - num_preemption_iter = (0 if scheduler_outputs is None else - scheduler_outputs.preempted) - - # Request stats - # Latency - time_e2e_requests: List[float] = [] - time_queue_requests: List[float] = [] - time_inference_requests: List[float] = [] - time_prefill_requests: List[float] = [] - time_decode_requests: List[float] = [] - # Metadata - num_prompt_tokens_requests: List[int] = [] - num_generation_tokens_requests: List[int] = [] - n_requests: List[int] = [] - max_num_generation_tokens_requests: List[int] = [] - max_tokens_requests: List[int] = [] - finished_reason_requests: List[str] = [] - - # LoRA requests - running_lora_adapters = dict( - collectionsCounter([ - running_request.lora_request.lora_name - for scheduler in self.scheduler - for running_request in scheduler.running - if running_request.lora_request - ])) - waiting_lora_adapters = dict( - collectionsCounter([ - waiting_request.lora_request.lora_name - for scheduler in self.scheduler - for waiting_request in scheduler.waiting - if waiting_request.lora_request - ])) - max_lora_stat = "0" - if self.lora_config: - max_lora_stat = str(self.lora_config.max_loras) - - # NOTE: This loop assumes prefill seq_groups are before - # decode seq_groups in scheduled_seq_groups. 
- if scheduler_outputs is not None: - # For async postprocessor, already finished sequences need to be - # not counted (to avoid double counting) - actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - - num_generation_tokens_from_prefill_groups = 0 - # NOTE: if scheduler_outputs.num_prefill_groups > 0 and - # the len of scheduler_outputs.scheduled_seq_groups is != - # scheduler_outputs.num_prefill_groups, this means that - # chunked prefills have been detected. - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double logging when using async output proc - if finished_before and idx in finished_before: - actual_num_batched_tokens -= 1 - continue - - # Currently, skip == preempted sequences, so we need to skip - # their log stats - if skip and idx in skip: - continue - - group_was_prefill = idx < scheduler_outputs.num_prefill_groups - seq_group = scheduled_seq_group.seq_group - - # NOTE: a seq_group that completed all of its prefill tokens - # in the last iteration will have seq_group.is_prefill() = False - # with group_was_prefill = True - if group_was_prefill: - # Number of prompt tokens. - num_prompt_tokens_iter += ( - scheduled_seq_group.token_chunk_size) - - # If the seq_group just finished the prefill state - # get TTFT. - if not seq_group.is_prefill(): - latency = seq_group.get_last_token_latency() - time_to_first_tokens_iter.append(latency) - - # One generation token per finished prefill. - num_generation_tokens_from_prefill_groups += ( - seq_group.num_seqs()) - else: - # TPOTs. - latency = seq_group.get_last_token_latency() - time_per_output_tokens_iter.append(latency) - if seq_group.state.current_step == 0: - # For async_output_proc, the do_log_stats() - # is called following init_multi_step(), which - # sets the current_step to zero. - actual_num_batched_tokens +=\ - seq_group.state.num_steps - 1 - else: - actual_num_batched_tokens +=\ - seq_group.state.current_step - 1 - - # Because of chunked prefill, we can have a single sequence - # group that does multiple prompt_runs. To prevent logging - # the same metadata more than once per request, we standardize - # on logging request level information for finished requests, - # which can only happen once. - if seq_group.is_finished(): - # Latency timings - time_e2e_requests.append(now - - seq_group.metrics.arrival_time) - if (seq_group.metrics.first_scheduled_time is not None and - seq_group.metrics.first_token_time is not None): - time_queue_requests.append( - seq_group.metrics.first_scheduled_time - - seq_group.metrics.arrival_time) - time_prefill_requests.append( - seq_group.metrics.first_token_time - - seq_group.metrics.first_scheduled_time) - time_decode_requests.append( - now - seq_group.metrics.first_token_time) - time_inference_requests.append( - now - seq_group.metrics.first_scheduled_time) - # Metadata - num_prompt_tokens_requests.append( - len(seq_group.prompt_token_ids)) - num_generation_tokens_requests.extend([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ]) - max_num_generation_tokens_requests.append( - max(seq.get_output_len() - for seq in seq_group.get_seqs())) - if seq_group.sampling_params is not None: - n_requests.append(seq_group.sampling_params.n) - max_tokens_requests.append( - seq_group.sampling_params.max_tokens) - finished_reason_requests.extend([ - SequenceStatus.get_finished_reason(seq.status) - for seq in seq_group.get_finished_seqs() - ]) - - # Number of generation tokens. 
- # num_batched_tokens equals the number of prompt_tokens plus the - # number of decode_tokens in a single iteration. So, - # num_generation_tokens = num_batched_tokens - num_prompt_tokens - # + num_generation_tokens_from_prefill_groups (since we generate - # one token on prefills on iters where the prefill finishes). - num_generation_tokens_iter = ( - actual_num_batched_tokens - num_prompt_tokens_iter + - num_generation_tokens_from_prefill_groups) - num_tokens_iter = (num_generation_tokens_iter + - num_prompt_tokens_iter) - - return Stats( - now=now, - # System stats - # Scheduler State - num_running_sys=num_running_sys, - num_swapped_sys=num_swapped_sys, - num_waiting_sys=num_waiting_sys, - # KV Cache Usage in % - gpu_cache_usage_sys=gpu_cache_usage_sys, - cpu_cache_usage_sys=cpu_cache_usage_sys, - # Prefix Cache Hit Rate - cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, - - # Iteration stats - num_prompt_tokens_iter=num_prompt_tokens_iter, - num_generation_tokens_iter=num_generation_tokens_iter, - num_tokens_iter=num_tokens_iter, - time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, - num_preemption_iter=num_preemption_iter, - - # Request stats - # Latency - time_e2e_requests=time_e2e_requests, - time_queue_requests=time_queue_requests, - time_inference_requests=time_inference_requests, - time_prefill_requests=time_prefill_requests, - time_decode_requests=time_decode_requests, - # Metadata - num_prompt_tokens_requests=num_prompt_tokens_requests, - num_generation_tokens_requests=num_generation_tokens_requests, - max_num_generation_tokens_requests= - max_num_generation_tokens_requests, - n_requests=n_requests, - max_tokens_requests=max_tokens_requests, - finished_reason_requests=finished_reason_requests, - max_lora=str(max_lora_stat), - waiting_lora_adapters=list(waiting_lora_adapters.keys()), - running_lora_adapters=list(running_lora_adapters.keys())) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_executor.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_executor.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_executor.list_loras() - - def pin_lora(self, lora_id: int) -> bool: - return self.model_executor.pin_lora(lora_id) - - def start_profile(self) -> None: - self.model_executor.start_profile() - - def stop_profile(self) -> None: - self.model_executor.stop_profile() - - def sleep(self, level: int = 1) -> None: - assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleep mode is not enabled in the model config") - self.model_executor.sleep(level=level) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleep mode is not enabled in the model config") - self.model_executor.wake_up(tags) - - def is_sleeping(self) -> bool: - return self.model_executor.is_sleeping - - def check_health(self) -> None: - self.model_executor.check_health() - - def is_tracing_enabled(self) -> bool: - return self.tracer is not None - - def do_tracing(self, - scheduler_outputs: SchedulerOutputs, - finished_before: Optional[List[int]] = None) -> None: - if self.tracer is None: - return - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double tracing when using async output proc - if finished_before and idx in finished_before: - continue - - seq_group = 
scheduled_seq_group.seq_group - if seq_group.is_finished(): - self.create_trace_span(seq_group) - - def create_trace_span(self, seq_group: SequenceGroup) -> None: - if self.tracer is None or seq_group.sampling_params is None: - return - arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) - - trace_context = extract_trace_context(seq_group.trace_headers) - - with self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as seq_span: - metrics = seq_group.metrics - - # Handle potential None values for cancelled/aborted requests - ttft = (metrics.first_token_time - metrics.arrival_time - if metrics.first_token_time is not None else None) - - e2e_time = (metrics.finished_time - metrics.arrival_time - if metrics.finished_time is not None else None) - - seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL, - self.model_config.model) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, - seq_group.request_id) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, - seq_group.sampling_params.temperature) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, - seq_group.sampling_params.top_p) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, - seq_group.sampling_params.max_tokens) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, - seq_group.sampling_params.n) - seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES, - seq_group.num_seqs()) - seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - len(seq_group.prompt_token_ids)) - seq_span.set_attribute( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, - sum([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ])) - - # Only set timing attributes if the values are available - if metrics.time_in_queue is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, - metrics.time_in_queue) - if ttft is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft) - if e2e_time is not None: - seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, - e2e_time) - if metrics.scheduler_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER, - metrics.scheduler_time) - if metrics.model_forward_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD, - metrics.model_forward_time / 1000.0) - if metrics.model_execute_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, - metrics.model_execute_time) - - def _validate_model_inputs(self, inputs: ProcessorInputs, - lora_request: Optional[LoRARequest]): - encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) - - if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, - lora_request, - prompt_type="encoder") - - self._validate_model_input(decoder_inputs, - lora_request, - prompt_type="decoder") - - def _validate_model_input( - self, - prompt_inputs: SingletonInputs, - lora_request: Optional[LoRARequest], - *, - prompt_type: Literal["encoder", "decoder"], - ): - model_config = self.model_config - tokenizer = (None if self.tokenizer is None else - self.tokenizer.get_lora_tokenizer(lora_request)) - - prompt_ids = prompt_inputs.get("prompt_token_ids", []) - if not prompt_ids: - if prompt_type == "encoder" and model_config.is_multimodal_model: - pass # Mllama may have empty encoder inputs for text-only data - elif 
prompt_inputs["type"] == "embeds": - pass - else: - raise ValueError(f"The {prompt_type} prompt cannot be empty") - - if tokenizer is not None: - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") - - max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) > max_prompt_len: - if prompt_type == "encoder" and model_config.is_multimodal_model: - mm_registry = self.input_preprocessor.mm_registry - mm_processor = mm_registry.create_processor( - model_config, - tokenizer=tokenizer or object(), # Dummy if no tokenizer - ) - assert isinstance(mm_processor, EncDecMultiModalProcessor) - - if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper - - if model_config.is_multimodal_model: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - else: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens.") - - raise ValueError( - f"The {prompt_type} prompt (length {len(prompt_ids)}) is " - f"longer than the maximum model length of {max_prompt_len}. " - f"{suggestion}") - - # TODO: Find out how many placeholder tokens are there so we can - # check that chunked prefill does not truncate them - # max_batch_len = self.scheduler_config.max_num_batched_tokens - - def _build_logits_processors( - self, sampling_params: SamplingParams, - lora_request: Optional[LoRARequest]) -> SamplingParams: - """Constructs logits processors based on the logits_bias, and - allowed_token_ids fields in sampling_params. Deletes those fields and - adds the constructed logits processors to the logits_processors field. 
- Returns the modified sampling params.""" - - logits_processors = [] - - if (sampling_params.logit_bias or sampling_params.allowed_token_ids): - tokenizer = self.get_tokenizer(lora_request=lora_request) - - processors = get_openai_logits_processors( - logit_bias=sampling_params.logit_bias, - allowed_token_ids=sampling_params.allowed_token_ids, - tokenizer=tokenizer) - logits_processors.extend(processors) - - # Unset so these don't get passed down to the model - sampling_params.logit_bias = None - sampling_params.allowed_token_ids = None - - if len(sampling_params.bad_words) > 0: - tokenizer = self.get_tokenizer(lora_request) - processors = get_bad_words_logits_processors( - bad_words=sampling_params.bad_words, tokenizer=tokenizer) - logits_processors.extend(processors) - - if logits_processors: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = logits_processors - else: - sampling_params.logits_processors.extend(logits_processors) - - return sampling_params - - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.model_executor.collective_rpc(method, timeout, args, - kwargs) - - -if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - LLMEngine = V1LLMEngine # type: ignore +LLMEngine = V1LLMEngine # type: ignore diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ba8dbd1fad791..45b798ed96cb2 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import Counter as CollectionsCounter -from typing import Dict, List, Optional, Type, Union, cast +from collections import Counter as CollectionsCounter +from typing import Optional, Union, cast import numpy as np import prometheus_client @@ -43,7 +43,7 @@ class Metrics: _counter_cls = prometheus_client.Counter _histogram_cls = prometheus_client.Histogram - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): # Unregister any existing vLLM collectors (for CI/CD) self._unregister_vllm_metrics() @@ -51,8 +51,7 @@ class Metrics: # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future - self.show_hidden_metrics = \ - vllm_config.observability_config.show_hidden_metrics + self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics # System stats # Scheduler State @@ -60,12 +59,14 @@ class Metrics: name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) self.gauge_lora_info = self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", @@ -82,81 +83,173 @@ class Metrics: name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) # Iteration stats self.counter_num_preemption = self._counter_cls( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", labelnames=labelnames, - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 - ]) + buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + ) self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labelnames, buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, - 2560.0 - ]) + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 640.0, + 2560.0, + ], + ) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." + "DEPRECATED: Use vllm:inter_token_latency_seconds instead." 
+ ), labelnames=labelnames, buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 - ]) + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + ) + self.histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter token latency in seconds.", + labelnames=labelnames, + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + ) # Request stats # Latency request_latency_buckets = [ - 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, - 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 120.0, + 240.0, + 480.0, + 960.0, + 1920.0, + 7680.0, ] self.histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_queue_time_request = self._histogram_cls( name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", + documentation="Histogram of time spent in WAITING phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_inference_time_request = self._histogram_cls( name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", + documentation="Histogram of time spent in RUNNING phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_prefill_time_request = self._histogram_cls( name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", + documentation="Histogram of time spent in PREFILL phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_decode_time_request = self._histogram_cls( name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", + documentation="Histogram of time spent in DECODE phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) # Metadata self.histogram_num_prompt_tokens_request = self._histogram_cls( @@ -165,19 +258,18 @@ class Metrics: labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) - self.histogram_num_generation_tokens_request = \ - self._histogram_cls( - name="vllm:request_generation_tokens", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) + self.histogram_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) self.histogram_max_num_generation_tokens_request = self._histogram_cls( name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", + 
documentation="Histogram of maximum number of requested generation tokens.", labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len)) + buckets=build_1_2_5_buckets(max_model_len), + ) self.histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", @@ -193,10 +285,10 @@ class Metrics: self.counter_request_success = self._counter_cls( name="vllm:request_success_total", documentation="Count of successfully processed requests.", - labelnames=labelnames + [Metrics.labelname_finish_reason]) + labelnames=labelnames + [Metrics.labelname_finish_reason], + ) - -# --8<-- [end:metrics-definitions] + # --8<-- [end:metrics-definitions] def _unregister_vllm_metrics(self) -> None: for collector in list(prometheus_client.REGISTRY._collector_to_names): @@ -208,16 +300,18 @@ class _RayGaugeWrapper: """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None, - multiprocess_mode: str = ""): + def __init__( + self, + name: str, + documentation: str = "", + labelnames: Optional[list[str]] = None, + multiprocess_mode: str = "", + ): del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None - self._gauge = ray_metrics.Gauge(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self._gauge = ray_metrics.Gauge( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def labels(self, **labels): self._gauge.set_default_tags(labels) @@ -235,14 +329,13 @@ class _RayCounterWrapper: """Wraps around ray.util.metrics.Counter to provide same API as prometheus_client.Counter""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None): + def __init__( + self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None + ): labelnames_tuple = tuple(labelnames) if labelnames else None - self._counter = ray_metrics.Counter(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self._counter = ray_metrics.Counter( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def labels(self, **labels): self._counter.set_default_tags(labels) @@ -258,17 +351,21 @@ class _RayHistogramWrapper: """Wraps around ray.util.metrics.Histogram to provide same API as prometheus_client.Histogram""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None, - buckets: Optional[List[float]] = None): + def __init__( + self, + name: str, + documentation: str = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None boundaries = buckets if buckets else [] - self._histogram = ray_metrics.Histogram(name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries) + self._histogram = ray_metrics.Histogram( + name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries, + ) def labels(self, **labels): self._histogram.set_default_tags(labels) @@ -283,14 +380,18 @@ class RayMetrics(Metrics): RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. Provides the same metrics as Metrics but uses Ray's util.metrics library. 
""" - _gauge_cls: Type[prometheus_client.Gauge] = cast( - Type[prometheus_client.Gauge], _RayGaugeWrapper) - _counter_cls: Type[prometheus_client.Counter] = cast( - Type[prometheus_client.Counter], _RayCounterWrapper) - _histogram_cls: Type[prometheus_client.Histogram] = cast( - Type[prometheus_client.Histogram], _RayHistogramWrapper) - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + _gauge_cls: type[prometheus_client.Gauge] = cast( + type[prometheus_client.Gauge], _RayGaugeWrapper + ) + _counter_cls: type[prometheus_client.Counter] = cast( + type[prometheus_client.Counter], _RayCounterWrapper + ) + _histogram_cls: type[prometheus_client.Histogram] = cast( + type[prometheus_client.Histogram], _RayHistogramWrapper + ) + + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): if ray_metrics is None: raise ImportError("RayMetrics requires Ray to be installed.") super().__init__(labelnames, vllm_config) @@ -300,14 +401,14 @@ class RayMetrics(Metrics): pass -def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: +def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by mantissa values until the value exceeds the specified maximum. """ exponent = 0 - buckets: List[int] = [] + buckets: list[int] = [] while True: for m in mantissa_lst: value = m * 10**exponent @@ -318,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: exponent += 1 -def build_1_2_5_buckets(max_value: int) -> List[int]: +def build_1_2_5_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_5_buckets(100) @@ -327,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 5], max_value) -def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: +def build_1_2_3_5_8_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_3_5_8_buckets(100) @@ -336,14 +437,12 @@ def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 3, 5, 8], max_value) -def local_interval_elapsed(now: float, last_log: float, - local_interval: float) -> bool: +def local_interval_elapsed(now: float, last_log: float, local_interval: float) -> bool: elapsed_time = now - last_log return elapsed_time > local_interval -def get_throughput(tracked_stats: List[int], now: float, - last_log: float) -> float: +def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float: return float(np.sum(tracked_stats) / (now - last_log)) @@ -357,29 +456,32 @@ class LoggingStatLogger(StatLoggerBase): def log(self, stats: Stats) -> None: """Called by LLMEngine. - Logs to Stdout every self.local_interval seconds.""" + Logs to Stdout every self.local_interval seconds.""" # Save tracked stats for token counters. self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, - self.local_interval): + if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Compute summary metrics for tracked stats (and log them - # to promethus if applicable). - prompt_throughput = get_throughput(self.num_prompt_tokens, - now=stats.now, - last_log=self.last_local_log) + # to prometheus if applicable). 
+ prompt_throughput = get_throughput( + self.num_prompt_tokens, now=stats.now, last_log=self.last_local_log + ) generation_throughput = get_throughput( - self.num_generation_tokens, - now=stats.now, - last_log=self.last_local_log) + self.num_generation_tokens, now=stats.now, last_log=self.last_local_log + ) log_fn = logger.info - if not any((prompt_throughput, generation_throughput, - self.last_prompt_throughput, - self.last_generation_throughput)): + if not any( + ( + prompt_throughput, + generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput, + ) + ): # Avoid log noise on an idle production system log_fn = logger.debug @@ -397,8 +499,10 @@ class LoggingStatLogger(StatLoggerBase): stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) - if (stats.cpu_prefix_cache_hit_rate >= 0 - or stats.gpu_prefix_cache_hit_rate >= 0): + if ( + stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0 + ): log_fn( "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", stats.gpu_prefix_cache_hit_rate * 100, @@ -420,17 +524,20 @@ class LoggingStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase): - """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + """PrometheusStatLogger is used LLMEngine to log to Prometheus.""" + _metrics_cls = Metrics _gauge_cls = prometheus_client.Gauge - def __init__(self, local_interval: float, labels: Dict[str, str], - vllm_config: VllmConfig) -> None: + def __init__( + self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig + ) -> None: super().__init__(local_interval, vllm_config) # Prometheus metrics self.labels = labels - self.metrics = self._metrics_cls(labelnames=list(labels.keys()), - vllm_config=vllm_config) + self.metrics = self._metrics_cls( + labelnames=list(labels.keys()), vllm_config=vllm_config + ) def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. @@ -440,88 +547,106 @@ class PrometheusStatLogger(StatLoggerBase): # Convenience function for logging to counter. # Prevent ValueError from negative increment if data < 0: - logger.warning("Skipping negative increment of %g to %s", data, - counter) + logger.warning("Skipping negative increment of %g to %s", data, counter) return counter.labels(**self.labels).inc(data) - def _log_counter_labels(self, counter, data: CollectionsCounter, - label_key: str) -> None: + def _log_counter_labels( + self, counter, data: CollectionsCounter, label_key: str + ) -> None: # Convenience function for collection counter of labels. for label, count in data.items(): counter.labels(**{**self.labels, label_key: label}).inc(count) - def _log_histogram(self, histogram, data: Union[List[int], - List[float]]) -> None: + def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None: # Convenience function for logging list to histogram. 
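[Editor's note] Numerically, the logging path above is just "tokens accumulated since the last log, divided by elapsed seconds, emitted once the interval has passed". A toy walk-through with made-up numbers:

```python
import time

last_log = time.time() - 10.0          # pretend the last interval started 10 s ago
local_interval = 5.0
num_prompt_tokens = [512, 256, 1024]   # per-iteration counts tracked since then

now = time.time()
if (now - last_log) > local_interval:                  # local_interval_elapsed()
    tput = sum(num_prompt_tokens) / (now - last_log)   # get_throughput()
    print(f"Avg prompt throughput: {tput:.1f} tokens/s")
    num_prompt_tokens = []                             # reset for the next interval
    last_log = now
```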
for datum in data: histogram.labels(**self.labels).observe(datum) - def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + def _log_gauge_string(self, gauge, data: dict[str, str]) -> None: gauge.labels(**data).set_to_current_time() def _log_prometheus(self, stats: Stats) -> None: # System state data - self._log_gauge(self.metrics.gauge_scheduler_running, - stats.num_running_sys) - self._log_gauge(self.metrics.gauge_scheduler_waiting, - stats.num_waiting_sys) - self._log_gauge(self.metrics.gauge_gpu_cache_usage, - stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_scheduler_running, stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, stats.gpu_cache_usage_sys) # Including max-lora in metric, in future this property of lora # config maybe extended to be dynamic. lora_info = { - self.metrics.labelname_running_lora_adapters: - ",".join(stats.running_lora_adapters), - self.metrics.labelname_waiting_lora_adapters: - ",".join(stats.waiting_lora_adapters), - self.metrics.labelname_max_lora: - stats.max_lora, + self.metrics.labelname_running_lora_adapters: ",".join( + stats.running_lora_adapters + ), + self.metrics.labelname_waiting_lora_adapters: ",".join( + stats.waiting_lora_adapters + ), + self.metrics.labelname_max_lora: stats.max_lora, } self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) # Iteration level data - self._log_counter(self.metrics.counter_num_preemption, - stats.num_preemption_iter) - self._log_counter(self.metrics.counter_prompt_tokens, - stats.num_prompt_tokens_iter) - self._log_counter(self.metrics.counter_generation_tokens, - stats.num_generation_tokens_iter) - self._log_histogram(self.metrics.histogram_iteration_tokens, - [stats.num_tokens_iter]) - self._log_histogram(self.metrics.histogram_time_to_first_token, - stats.time_to_first_tokens_iter) - self._log_histogram(self.metrics.histogram_time_per_output_token, - stats.time_per_output_tokens_iter) + self._log_counter( + self.metrics.counter_num_preemption, stats.num_preemption_iter + ) + self._log_counter( + self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter + ) + self._log_counter( + self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter + ) + self._log_histogram( + self.metrics.histogram_iteration_tokens, [stats.num_tokens_iter] + ) + self._log_histogram( + self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter + ) + self._log_histogram( + self.metrics.histogram_time_per_output_token, + stats.inter_token_latencies_iter, + ) + self._log_histogram( + self.metrics.histogram_inter_token_latency, stats.inter_token_latencies_iter + ) # Request level data # Latency - self._log_histogram(self.metrics.histogram_e2e_time_request, - stats.time_e2e_requests) - self._log_histogram(self.metrics.histogram_queue_time_request, - stats.time_queue_requests) - self._log_histogram(self.metrics.histogram_inference_time_request, - stats.time_inference_requests) - self._log_histogram(self.metrics.histogram_prefill_time_request, - stats.time_prefill_requests) - self._log_histogram(self.metrics.histogram_decode_time_request, - stats.time_decode_requests) + self._log_histogram( + self.metrics.histogram_e2e_time_request, stats.time_e2e_requests + ) + self._log_histogram( + self.metrics.histogram_queue_time_request, stats.time_queue_requests + ) + self._log_histogram( + self.metrics.histogram_inference_time_request, stats.time_inference_requests + ) + 
self._log_histogram( + self.metrics.histogram_prefill_time_request, stats.time_prefill_requests + ) + self._log_histogram( + self.metrics.histogram_decode_time_request, stats.time_decode_requests + ) # Metadata - finished_reason_counter = CollectionsCounter( - stats.finished_reason_requests) - self._log_counter_labels(self.metrics.counter_request_success, - finished_reason_counter, - Metrics.labelname_finish_reason) - self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, - stats.num_prompt_tokens_requests) + finished_reason_counter = CollectionsCounter(stats.finished_reason_requests) + self._log_counter_labels( + self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason, + ) + self._log_histogram( + self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests, + ) self._log_histogram( self.metrics.histogram_num_generation_tokens_request, - stats.num_generation_tokens_requests) + stats.num_generation_tokens_requests, + ) self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) self._log_histogram( self.metrics.histogram_max_num_generation_tokens_request, - stats.max_num_generation_tokens_requests) - self._log_histogram(self.metrics.histogram_max_tokens_request, - stats.max_tokens_requests) + stats.max_num_generation_tokens_requests, + ) + self._log_histogram( + self.metrics.histogram_max_tokens_request, stats.max_tokens_requests + ) def log(self, stats: Stats): """Logs to prometheus and tracked stats every iteration.""" @@ -533,9 +658,7 @@ class PrometheusStatLogger(StatLoggerBase): self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, - self.local_interval): - + if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Reset tracked stats for next interval. self.num_prompt_tokens = [] self.num_generation_tokens = [] @@ -551,12 +674,14 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:cache_config_info", documentation="Information of the LLMEngine CacheConfig", labelnames=metrics_info.keys(), - multiprocess_mode="mostrecent") + multiprocess_mode="mostrecent", + ) info_gauge.labels(**metrics_info).set(1) class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics def info(self, type: str, obj: SupportsMetricsInfo) -> None: diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 3281a9121a9df..ac796f4e1c758 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -4,7 +4,7 @@ These types are defined in this file to avoid importing vllm.engine.metrics and therefore importing prometheus_client. -This is required due to usage of Prometheus multiprocess mode to enable +This is required due to usage of Prometheus multiprocess mode to enable metrics after splitting out the uvicorn process from the engine process. Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR @@ -16,7 +16,6 @@ do this in Python code and lazily import prometheus_client. 
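[Editor's note] The `info()` method above uses the common config-as-labels idiom: a gauge pinned at 1 whose labels carry the configuration values, so it can be joined against other series in PromQL. A standalone prometheus_client sketch; the gauge name and config values are made up, and `multiprocess_mode` is omitted for simplicity.

```python
import prometheus_client

# Made-up cache-config values; the real code pulls these from the engine config.
metrics_info = {"block_size": "16", "cache_dtype": "auto"}

info_gauge = prometheus_client.Gauge(
    name="demo_cache_config_info",
    documentation="Information of the LLMEngine CacheConfig",
    labelnames=list(metrics_info.keys()),
)
# The gauge value is always 1; the configuration travels in the labels.
info_gauge.labels(**metrics_info).set(1)
```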
import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List from vllm.config import SupportsMetricsInfo, VllmConfig @@ -24,6 +23,7 @@ from vllm.config import SupportsMetricsInfo, VllmConfig @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" + now: float # System stats (should have _sys suffix) @@ -42,26 +42,26 @@ class Stats: num_prompt_tokens_iter: int num_generation_tokens_iter: int num_tokens_iter: int - time_to_first_tokens_iter: List[float] - time_per_output_tokens_iter: List[float] + time_to_first_tokens_iter: list[float] + inter_token_latencies_iter: list[float] num_preemption_iter: int # Request stats (should have _requests suffix) # Latency - time_e2e_requests: List[float] - time_queue_requests: List[float] - time_inference_requests: List[float] - time_prefill_requests: List[float] - time_decode_requests: List[float] + time_e2e_requests: list[float] + time_queue_requests: list[float] + time_inference_requests: list[float] + time_prefill_requests: list[float] + time_decode_requests: list[float] # Metadata - num_prompt_tokens_requests: List[int] - num_generation_tokens_requests: List[int] - n_requests: List[int] - max_num_generation_tokens_requests: List[int] - max_tokens_requests: List[int] - finished_reason_requests: List[str] - waiting_lora_adapters: List[str] - running_lora_adapters: List[str] + num_prompt_tokens_requests: list[int] + num_generation_tokens_requests: list[int] + n_requests: list[int] + max_num_generation_tokens_requests: list[int] + max_tokens_requests: list[int] + finished_reason_requests: list[str] + waiting_lora_adapters: list[str] + running_lora_adapters: list[str] max_lora: str @@ -70,8 +70,8 @@ class StatLoggerBase(ABC): def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: # Tracked stats over current local logging interval. 
- self.num_prompt_tokens: List[int] = [] - self.num_generation_tokens: List[int] = [] + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] self.last_local_log = time.time() self.local_interval = local_interval diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py deleted file mode 100644 index ff0405d2f843e..0000000000000 --- a/vllm/engine/multiprocessing/__init__.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import uuid -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Mapping, Optional, Union - -from vllm import PoolingParams -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.utils import Device - -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -IPC_INPUT_EXT = "_input_socket" -IPC_OUTPUT_EXT = "_output_socket" -IPC_HEALTH_EXT = "_health_socket" -IPC_DATA_EXT = "_data_socket" - - -class MQEngineDeadError(RuntimeError): - pass - - -@dataclass -class RPCProcessRequest: - prompt: PromptType - params: Union[SamplingParams, PoolingParams] - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - priority: int = 0 - - def __init__( - self, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> None: - super().__init__() - - self.prompt = prompt - self.params = params - self.request_id = request_id - self.lora_request = lora_request - self.trace_headers = trace_headers - self.priority = priority - - -@dataclass -class RPCError: - request_id: Optional[str] - is_engine_errored: bool - exception: BaseException - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCStartupRequest(Enum): - IS_SERVER_READY = 1 - - -@dataclass -class RPCStartupResponse: - tracing_enabled: bool - - -class RPCUProfileRequest(Enum): - START_PROFILE = 1 - STOP_PROFILE = 2 - - -class RPCResetMultiModalCacheRequest(Enum): - RESET = 1 - - -@dataclass -class RPCResetPrefixCacheRequest: - device: Device - - -class RPCSleepRequest(Enum): - SLEEP_LEVEL_1 = 1 - SLEEP_LEVEL_2 = 2 - - -@dataclass -class RPCWakeUpRequest: - tags: Optional[list[str]] = None - - -@dataclass -class RPCIsSleepingRequest: - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCIsSleepingResponse: - request_id: str - is_sleeping: bool - - -@dataclass -class RPCLoadAdapterRequest: - lora_request: LoRARequest - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCAdapterLoadedResponse: - request_id: str - - -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, - RPCUProfileRequest, RPCLoadAdapterRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, RPCSleepRequest, - RPCWakeUpRequest, RPCIsSleepingRequest] - -REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, - RPCIsSleepingResponse, RPCError] - - -def ENGINE_DEAD_ERROR( - error: Optional[BaseException] = None) -> MQEngineDeadError: - if error is None: - return MQEngineDeadError( - "Engine loop is not 
running. Inspect the stacktrace to " - "find the original error") - - return MQEngineDeadError( - "Engine loop is not running. Inspect the stacktrace to " - f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py deleted file mode 100644 index 0bb11328b1db5..0000000000000 --- a/vllm/engine/multiprocessing/client.py +++ /dev/null @@ -1,668 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import copy -import pickle -from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, - Mapping, Optional, Union, cast) - -import cloudpickle -import psutil -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import PoolingParams -from vllm.config import DecodingConfig, ModelConfig, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -from vllm.engine.protocol import EngineClient -# yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import Device - -logger = init_logger(__name__) - - -class MQClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class MQLLMEngineClient(EngineClient): - """A client wrapper for MQLLMEngine that conforms to the - EngineClient protocol. - - MQLLMEngine and MQLLMEngineClient are intended to run in separate - processes communicating via zeromq ipc sockets. - - The entrypoint to MQLLMEngineClient is through the generate() - method. On generate() MQLLMEngine does three things: - - Creates an asyncio output queue - - Sends a RPCGenerateRequest to the MQLLMEngine via zmq - - Pulls RequestOutputs from its queue and yields them - - MQLLMEngine runs two background loops: - - output_loop: the output loop pulls List[RequestOutput] - from the MQLLMEngine via zmq (each list is the output - of one engine_step in the LLMEngine). It then parses - the list and pushes individual request_outputs into - the corresponding output_queue such that they can be - consumed by the .generate() method. 
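[Editor's note] The client docstring above describes one background output loop fanning results out to per-request queues that `generate()` consumes. A toy asyncio sketch of that dispatch pattern; a plain list stands in for the zmq socket, and the class and field names are made up.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeOutput:
    request_id: str
    text: str
    finished: bool


async def output_loop(batches, queues: dict[str, asyncio.Queue]) -> None:
    # Stand-in for pulling batched request outputs off the zmq socket.
    for batch in batches:
        for output in batch:
            queue = queues.get(output.request_id)
            if queue is not None:
                queue.put_nowait(output)
        await asyncio.sleep(0)  # yield so consumers can drain their queues


async def generate(request_id: str, queues: dict[str, asyncio.Queue]):
    queue: asyncio.Queue = asyncio.Queue()
    queues[request_id] = queue
    try:
        while True:
            output: FakeOutput = await queue.get()
            yield output
            if output.finished:
                break
    finally:
        queues.pop(request_id, None)


async def main() -> None:
    queues: dict[str, asyncio.Queue] = {}
    batches = [
        [FakeOutput("req-0", "Hel", finished=False)],
        [FakeOutput("req-0", "lo", finished=True)],
    ]
    producer = asyncio.create_task(output_loop(batches, queues))
    async for out in generate("req-0", queues):
        print(out.text, out.finished)
    await producer


asyncio.run(main())
```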
- - health_loop: the health loop queries the health socket - every N seconds, confirming the engine is healthy - """ - - def __init__(self, ipc_path: str, engine_config: VllmConfig, - engine_pid: int): - self.context = zmq.asyncio.Context() - self._errored_with: Optional[BaseException] = None - - # Get the configs. - self.vllm_config = engine_config - self.model_config = engine_config.model_config - self.decoding_config = engine_config.decoding_config - - if self.vllm_config.model_config.skip_tokenizer_init: - self.tokenizer = None - - else: - # Create the tokenizer group. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=engine_config.scheduler_config, - lora_config=engine_config.lora_config) - - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer) - - # Send RPCGenerateRequest to the MQLLMEngine. - self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") - - # Receive streams of RequestOutput from the MQLLMEngine. - self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # IPC path for acking heartbeats. - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Stream for each individual request. - self.output_queues: Dict[str, asyncio.Queue] = {} - - # Loop to handle output of the LLMEngine periodically. - # Started after the MQLLMEngine is ready so that we can - # build the Client in an executor to enable clean shutdown. - self.output_loop: Optional[asyncio.Task] = None - - # Loop to check health of the LLMEngine periodically. - # Started after the MQLLMEngine is ready. - self.health_loop: Optional[asyncio.Task] = None - self._engine_process = psutil.Process(engine_pid) - - @staticmethod - def is_unsupported_config(vllm_config: VllmConfig): - # Pipeline parallel not yet supported - return vllm_config.parallel_config.pipeline_parallel_size > 1 - - @contextmanager - def get_data_socket(self) -> Iterator[Socket]: - socket = self.context.socket(zmq.constants.DEALER) - try: - socket.connect(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - async def run_heartbeat_loop(self, timeout: int): - """Background loop that continually checks to ensure the engine process - is still alive. 
- """ - try: - while True: - # Check if the engine process is running: - if not self._engine_process.is_running() or ( - self._engine_process.status() == psutil.STATUS_ZOMBIE): - # NB: is_running() returns True for zombies - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) " - "died.")) - break - - if await self.heartbeat_socket.poll(timeout=timeout): - # Heartbeat received- check the message - await self._check_success( - error_message="Heartbeat failed.", - socket=self.heartbeat_socket) - - logger.debug("Heartbeat successful.") - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient check health loop.") - - except psutil.NoSuchProcess: - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) died.")) - - except Exception as e: - self._set_errored(e) - - async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to Request Queues""" - - try: - while True: - # Poll, checking for ENGINE_DEAD - while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT - ) == 0: - logger.debug("Waiting for output from MQLLMEngine.") - - # If errored, alert all running requests. - if self.errored: - for queue_j in tuple(self.output_queues.values()): - queue_j.put_nowait( - ENGINE_DEAD_ERROR(self._errored_with)) - return - - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - is_error = isinstance(request_outputs, - (BaseException, RPCError)) - if is_error: - if isinstance(request_outputs, RPCError): - rpc_error: RPCError = request_outputs - request_id = rpc_error.request_id - exception = rpc_error.exception - is_engine_errored = rpc_error.is_engine_errored - else: - # MPLLMEngine should always return an RPCError to - # the output_socket when an issue arises. - # If we are here, we are in a bad state and - # should shut down the server. - error: BaseException = request_outputs - logger.error( - "Received Exception %s rather than RPCError from " - "MPLLMEngine. This should never happen.", error) - request_id = None - exception = error - is_engine_errored = True - - # Set to error state only on engine critical error - # (and record only the first one) - if is_engine_errored and not self._errored_with: - self._errored_with = exception - # If engine is errored, no matter the type of exception - # it will no longer be able to receive new requests, - # therefore we have to inform that the current - # processed requests failed as well. Send back a dead - # engine error give this feedback and also give a - # 'hint' to the server to shutdown next. - exception = self.dead_error - - if request_id is None: - # If request_id is None, then the engine raised an - # exception for a batch, and we may not know the - # request that caused it, neither if it was actually - # caused by any of them (e.g. CUDA OOM). Therefore we - # broadcast the same exception for all requests. - for queue_i in tuple(self.output_queues.values()): - queue_i.put_nowait(exception) - else: - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(exception) - # Put each output into the appropriate queue. 
- elif isinstance( - request_outputs, - (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): - self._add_output(request_outputs) - else: - for request_output in request_outputs: - self._add_output(request_output) - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient output handler.") - - def _add_output(self, request_output: Union[RequestOutput, - RPCAdapterLoadedResponse, - RPCIsSleepingResponse]): - queue = self.output_queues.get(request_output.request_id) - if queue is not None: - queue.put_nowait(request_output) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Start output_loop - if self.output_loop is None: - # only generate once to avoid multiple concurrent output_loops - # this will lead to race conditions and wrong orders of tokens - # returned by the engine - # setup will be called multiple times during the startup of - # the engine - self.output_loop = asyncio.create_task( - self.run_output_handler_loop()) - - with self.get_data_socket() as socket: - # Wait until server is ready. - response = await self._wait_for_server_rpc(socket) - - self.tracing_flag = response.tracing_enabled - - # Start health_loop. - if self.health_loop is None: - self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets and terminate the context. - self.context.destroy(linger=0) - - # Cancel background tasks. - if self.health_loop is not None: - self.health_loop.cancel() - if self.output_loop is not None: - self.output_loop.cancel() - - def _set_errored(self, e: BaseException): - logger.exception(repr(e)) - if self._errored_with is None: - self._errored_with = e - - @staticmethod - async def _send_get_data_rpc_request(request: RPCStartupRequest, - expected_type: Any, - error_message: str, - socket: Socket) -> Any: - """Send an RPC request that is expecting data back.""" - - # Ping RPCServer with a request. - await socket.send_multipart((pickle.dumps(request), ), copy=False) - - # Make sure the server responds in time. - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("RPCServer didn't reply within " - f"{VLLM_RPC_TIMEOUT} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, BaseException): - raise data - elif not isinstance(data, expected_type): - raise ValueError(error_message) - - return data - - @staticmethod - async def _send_one_way_rpc_request(request: RPC_REQUEST_T, - socket: Socket): - """Send one-way RPC request to trigger an action.""" - - if socket.closed: - raise MQClientClosedError() - - await socket.send_multipart((pickle.dumps(request), )) - - async def _await_ack(self, error_message: str, socket: Socket): - """Await acknowledgement that a request succeeded.""" - - if socket.closed: - raise MQClientClosedError() - - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("MQLLMEngine didn't reply within " - f"{VLLM_RPC_TIMEOUT}ms") - - await self._check_success(error_message, socket) - - @staticmethod - async def _check_success(error_message: str, socket: Socket): - """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" - - if socket.closed: - raise MQClientClosedError() - - frame = await socket.recv(copy=False) - response = pickle.loads(frame.buffer) - - # Raise error if unsuccessful - if isinstance(response, BaseException): - raise response - elif (not isinstance(response, str) - or response != VLLM_RPC_SUCCESS_STR): - raise ValueError(error_message) - - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.input_preprocessor - - async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): - if self.tokenizer is None: - return None - else: - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_vllm_config(self) -> VllmConfig: - return self.vllm_config - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: - """Wait for the RPCServer to start up.""" - - return await self._send_get_data_rpc_request( - request=RPCStartupRequest.IS_SERVER_READY, - expected_type=RPCStartupResponse, - error_message="Unable to start RPC Server", - socket=socket) - - async def abort(self, request_id: Union[str, Iterable[str]]): - """Send an ABORT_REQUEST signal to the RPC Server""" - - if not isinstance(request_id, str): - raise RuntimeError("Only single-request abort supported in" - " deprecated V0") - - with suppress(MQClientClosedError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), socket=self.input_socket) - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - ) -> None: - """ - Ignore do_log_stats (handled on MQLLMEngine polling) - """ - pass - - async def check_health(self): - """ - The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with - if the engine is unhealthy. 
- """ - if self._errored_with is not None: - raise self._errored_with - - @property - def is_running(self) -> bool: - return not self.errored - - @property - def is_stopped(self) -> bool: - return self.errored - - @property - def errored(self) -> bool: - return self._errored_with is not None - - @property - def dead_error(self) -> BaseException: - return ENGINE_DEAD_ERROR(self._errored_with) - - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - sampling_params: The sampling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the - scheduling policy is not "priority". - """ - return cast( - AsyncGenerator[RequestOutput, None], - self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, priority)) - - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - """Generate outputs for a request from a pooling model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - - Yields: - The output `PoolingRequestOutput` objects from the LLMEngine - for the request. - """ - return cast( - AsyncGenerator[PoolingRequestOutput, None], - self._process_request(prompt, - pooling_params, - request_id, - lora_request, - trace_headers, - priority=priority)) - - async def _process_request( - self, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - PoolingRequestOutput, None]]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - # If already dead, error out. - if self._errored_with is not None: - raise ENGINE_DEAD_ERROR(self._errored_with) - - # Ensure the request id is unique among running requests - if request_id in self.output_queues: - raise ValueError(f"Request {request_id} already exists") - - # 1) Create output queue for this request. 
- queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() - self.output_queues[request_id] = queue - - try: - # 2) Detach logits processors so that they can be pickled - # separately (may require cloudpickle which is slower) - if isinstance(params, SamplingParams) and params.logits_processors: - # Defensive shallow copy - params = copy.copy(params) - logits_processors = params.logits_processors - params.logits_processors = None - lp_bytes = cloudpickle.dumps(logits_processors) - else: - lp_bytes = None - - request_bytes = pickle.dumps( - RPCProcessRequest( - prompt=prompt, - params=params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - )) - - # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) - await self.input_socket.send_multipart(parts, copy=False) - - # 4) Stream the RequestOutputs from the output queue. Note - # that the output_loop pushes RequestOutput objects to this - # queue after pulling them from the zmq socket. - finished = False - try: - while not finished: - request_output = await queue.get() - - if isinstance(request_output, BaseException): - raise request_output - - finished = request_output.finished - yield request_output - finally: - # Request was canceled by the client. - if not finished and not self.errored: - await self.abort(request_id) - finally: - self.output_queues.pop(request_id) - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) - - async def reset_mm_cache(self) -> None: - """Reset the multi-modal cache""" - - await self._send_one_way_rpc_request( - request=RPCResetMultiModalCacheRequest.RESET, - socket=self.input_socket) - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - """Reset the prefix cache""" - - await self._send_one_way_rpc_request( - request=RPCResetPrefixCacheRequest(device), - socket=self.input_socket) - - async def sleep(self, level: int = 1) -> None: - """Sleep the engine for a given level""" - return await self._send_one_way_rpc_request( - request=RPCSleepRequest(level), socket=self.input_socket) - - async def wake_up(self, tags: Optional[list[str]] = None) -> None: - """Wake up the engine""" - return await self._send_one_way_rpc_request( - request=RPCWakeUpRequest(tags), socket=self.input_socket) - - async def is_sleeping(self) -> bool: - """Check whether the engine is sleeping""" - request = RPCIsSleepingRequest() - - queue: asyncio.Queue[Union[BaseException, - RPCIsSleepingResponse]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - if isinstance(request_output, BaseException): - raise request_output - return request_output.is_sleeping - - async def add_lora(self, lora_request: LoRARequest) -> None: - """Load a new LoRA adapter into the engine for future requests.""" - # Uses the same I/O as generate requests - request = RPCLoadAdapterRequest(lora_request) - - # Create output queue for this request. 
- queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - # Send the request - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - # Wait for the response - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - # Raise on error, otherwise happily return None - if isinstance(request_output, BaseException): - raise request_output diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py deleted file mode 100644 index 903f3fd71ebcd..0000000000000 --- a/vllm/engine/multiprocessing/engine.py +++ /dev/null @@ -1,469 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pickle -import signal -from contextlib import contextmanager -from typing import Iterator, List, Optional, Union - -import cloudpickle -import zmq - -from vllm import AsyncEngineArgs, SamplingParams -from vllm.config import VllmConfig -from vllm.engine.llm_engine import LLMEngine -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -# yapf: enable -from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs -from vllm.worker.model_runner_base import InputProcessingError - -logger = init_logger(__name__) - -POLLING_TIMEOUT_MS = 10000 -HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) - - -class MQLLMEngine: - """A multiprocessing wrapper for - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - - This class is used to wrap the - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use - in concurrnet manner. It runs a background loop and uses zeromq to - receive new requests and stream outputs incrementally via ipc. - - The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode - process is kicked off when a new RPCProcessRequest is received by the - input_socket. - - The self.engine_loop checks the input_socket for new requests, - adds them to the LLMEngine if there are any, calls the internal - [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends - the RequestOutputs back over the output_socket. - - If use_async_sockets is set, the logic associated with reading new - requests from the socket and sending data to the socket is passed - as a callback to the llm_engine, which calls the logic asynchronously - such that the IPC can be overlapped with the GPU. - - Args: - ipc_path: Base path for zeromq interprocess messaging - use_async_sockets: Whether to make send/recv async with GPU - log_requests: Whether to log the requests. - *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. 
- """ - - def __init__(self, - ipc_path: str, - use_async_sockets: bool, - *args, - log_requests: bool = True, - **kwargs) -> None: - # For MQLLMEngine, we can use cached outputs, since each new request - # output is immediately pickled and send over the socket, which frees - # the python object to be reused again. - kwargs['use_cached_outputs'] = True - - self.engine = LLMEngine(*args, **kwargs) - self.log_requests = log_requests - - self.use_async_sockets = use_async_sockets - if self.use_async_sockets: - self.engine.process_request_outputs_callback = \ - self._async_socket_engine_callback - - self.ctx = zmq.Context() # type: ignore[attr-defined] - - # Receive input from the client. - self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") - - # Send output stream back to client. - self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # Send heartbeats back to client. - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Error state. - self._errored_with: Optional[BaseException] = None - - @property - def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return ENGINE_DEAD_ERROR() - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - usage_context: UsageContext, - enable_log_requests: bool, - disable_log_stats: bool, - ipc_path: str, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "MQLLMEngine": - # Setup plugins for each process - from vllm.plugins import load_general_plugins - load_general_plugins() - - use_async_sockets = vllm_config.model_config.use_async_output_proc - - return cls( - vllm_config=vllm_config, - executor_class=LLMEngine._get_executor_cls(vllm_config), - ipc_path=ipc_path, - usage_context=usage_context, - use_async_sockets=use_async_sockets, - log_requests=enable_log_requests, - log_stats=(not disable_log_stats), - ) - - @staticmethod - def from_engine_args(engine_args: AsyncEngineArgs, - usage_context: UsageContext, ipc_path: str): - """Creates an MQLLMEngine from the engine arguments.""" - - vllm_config = engine_args.create_engine_config(usage_context) - return MQLLMEngine.from_vllm_config( - ipc_path=ipc_path, - vllm_config=vllm_config, - usage_context=usage_context, - enable_log_requests=engine_args.enable_log_requests, - disable_log_stats=engine_args.disable_log_stats, - ) - - def start(self): - try: - try: - logger.debug("Starting Startup Loop.") - self.run_startup_loop() - logger.debug("Starting Engine Loop.") - self.run_engine_loop() - except Exception as e: - logger.exception(repr(e)) - except KeyboardInterrupt: - logger.debug("Shutting down MQLLMEngine.") - finally: - logger.debug("MQLLMEngine is shut down.") - self.cleanup() - - def cleanup(self): - """Cleanup zeromq state on shutdown.""" - # Closes all sockets and destroys context. 
- self.ctx.destroy(linger=0) - del self.engine - - @contextmanager - def make_data_socket( - self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] - socket = self.ctx.socket(zmq.constants.ROUTER) - try: - socket.bind(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - def run_startup_loop(self) -> None: - """Startup loop for sending data from Engine -> Client.""" - - with self.make_data_socket() as socket: - response: Union[RPCStartupResponse, BaseException] - try: - identity, message = socket.recv_multipart(copy=False) - request: RPCStartupRequest = pickle.loads(message.buffer) - - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) - - except Exception as e: - response = e - - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) - - def run_engine_loop(self): - """Core busy loop of the LLMEngine.""" - - while True: - if not self.engine.has_unfinished_requests(): - # Poll until there is work to do. - while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - # When there's no work, check on engine health and send - # health status back to client - self._health_check() - self.engine.do_log_stats() - logger.debug("Waiting for new requests in engine loop.") - - # Handle any input from the client. - self.handle_new_input() - - # Engine step. - request_outputs = self.engine_step() - - # Send request outputs (if async, done in engine_step callback). - if not self.use_async_sockets: - self._send_outputs(request_outputs) - - def engine_step(self) -> List[RequestOutput]: - """Engine step wrapper with error handling.""" - try: - return self.engine.step() - except SystemExit: - raise - except InputProcessingError as e: - # Special case where we handle an error preparing the inputs for - # a single request in the batch - rpc_err = RPCError(request_id=e.request_id, - is_engine_errored=False, - exception=e.__cause__) - self._send_outputs(rpc_err) - return [] - except BaseException as e: - self._set_errored(e) - rpc_err = RPCError(request_id=None, - is_engine_errored=True, - exception=e) - self._send_outputs(rpc_err) - raise e - - def handle_new_input(self): - """Handle new input from the socket""" - try: - while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) - - if isinstance(request, RPCProcessRequest): - if len(frames) > 1: - # Use cloudpickle for logits processors - assert isinstance(request.params, SamplingParams) - lprocs = cloudpickle.loads(frames[1].buffer) - request.params.logits_processors = lprocs - self._handle_process_request(request) - elif isinstance(request, RPCAbortRequest): - self._handle_abort_request(request) - elif isinstance(request, RPCUProfileRequest): - if request == RPCUProfileRequest.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(request, RPCLoadAdapterRequest): - self._handle_load_adapter_request(request) - elif isinstance(request, RPCResetMultiModalCacheRequest): - self.reset_mm_cache() - elif isinstance(request, RPCResetPrefixCacheRequest): - self.reset_prefix_cache() - elif isinstance(request, RPCSleepRequest): - self.sleep(request.value) - elif isinstance(request, RPCWakeUpRequest): - self.wake_up(request.tags) - elif isinstance(request, RPCIsSleepingRequest): - self._handle_is_sleeping_request(request) - else: - raise 
ValueError("Unknown RPCRequest Type: " - f"{type(request)}") - - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - raise e from None - - def _handle_process_request(self, request: RPCProcessRequest): - """Handle RPCProcessRequest by adding it to the LLMEngine.""" - request_id = request.request_id - - if self._errored_with is not None: - rpc_err = RPCError(request_id=request_id, - is_engine_errored=True, - exception=ENGINE_DEAD_ERROR(self._errored_with)) - self._send_outputs(rpc_err) - - try: - self.engine.add_request(request_id=request_id, - prompt=request.prompt, - params=request.params, - lora_request=request.lora_request, - trace_headers=request.trace_headers, - priority=request.priority) - - if self.log_requests: - logger.info("Added request %s.", request.request_id) - - except Exception as e: - # We do not set self._errored = True here, since the error - # is due to an issue adding this request to the engine, - # rather than an issue with the engine itself. - logger.debug("Failed to add request %s to engine. %s", - request.request_id, e) - is_errored = self._errored_with is not None - rpc_err = RPCError(request_id=request_id, - is_engine_errored=is_errored, - exception=e) - self._send_outputs(rpc_err) - - # Remove request from the engine. - self.engine.abort_request(request_id) - - def _handle_abort_request(self, request: RPCAbortRequest): - self.engine.abort_request(request.request_id) - if self.log_requests: - logger.info("Aborted request %s.", request.request_id) - - def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): - try: - self.engine.add_lora(request.lora_request) - except BaseException as e: - # Send back an error if the adater fails to load - rpc_err = RPCError(request_id=request.request_id, - is_engine_errored=False, - exception=e) - self._send_outputs(rpc_err) - return - # Otherwise, send back the successful load message - self._send_outputs( - RPCAdapterLoadedResponse(request_id=request.request_id)) - - def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): - is_sleeping = self.is_sleeping() - self._send_outputs( - RPCIsSleepingResponse(request_id=request.request_id, - is_sleeping=is_sleeping)) - - def _health_check(self): - # Send unhealthy if engine has already errored - if self._errored_with is not None: - self._send_unhealthy(self._errored_with) - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - - def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): - """Send outputs back to the engine client. These can be: - - Exceptions - - A list of generation outputs - - A response from loading a lora adapter - """ - if outputs: - try: - from ray.exceptions import RayTaskError - - # RayTaskError might not pickelable here. We need to unpack the - # underlying exception as the real exception in the output. 
- if (isinstance(outputs, RPCError) - and isinstance(outputs.exception, RayTaskError)): - outputs.exception = outputs.exception.cause - except ImportError: - pass - - output_bytes = pickle.dumps(outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) - - def _send_healthy(self): - """Send HEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - - def _send_unhealthy(self, error: BaseException): - """Send UNHEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - error_bytes = pickle.dumps(error) - self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) - - def _async_socket_engine_callback(self, - request_outputs: REQUEST_OUTPUTS_T): - """Callback used by engine to make socket handling async with GPU.""" - self._send_outputs(request_outputs) - self.handle_new_input() - - def _set_errored(self, e: BaseException): - """Log and set errored status if this is the first issue.""" - if self._errored_with is None: - self._errored_with = e - - def start_profile(self) -> None: - self.engine.start_profile() - - def stop_profile(self) -> None: - self.engine.stop_profile() - - def reset_mm_cache(self) -> bool: - return self.engine.reset_mm_cache() - - def reset_prefix_cache(self) -> bool: - return self.engine.reset_prefix_cache() - - def sleep(self, level: int = 1) -> None: - self.engine.sleep(level) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - self.engine.wake_up(tags) - - def is_sleeping(self) -> bool: - return self.engine.is_sleeping() - - -def signal_handler(*_) -> None: - raise KeyboardInterrupt("MQLLMEngine terminated") - - -def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, - ipc_path: str, disable_log_stats: bool, - enable_log_requests: bool, engine_alive): - try: - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - engine = MQLLMEngine.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - disable_log_stats=disable_log_stats, - enable_log_requests=enable_log_requests, - ipc_path=ipc_path) - - signal.signal(signal.SIGTERM, signal_handler) - - engine.start() - - except BaseException as e: - logger.exception(e) - engine_alive.value = False - raise e from None diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py deleted file mode 100644 index 4d75719c1719b..0000000000000 --- a/vllm/engine/output_processor/interfaces.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Callable, List - -from vllm.config import SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import Counter - - -class SequenceGroupOutputProcessor(ABC): - """Interface for logic that processes new token ids in sequence groups, - managing detokenization, stop checking, and freeing/forking sequences with - the scheduler. - - This is highly coupled with the LLMEngine and should be seen as an extension - of it. 
The logic is separated to simplify the LLMEngine class and allow - separate implementations for single-step decoding (which supports beam - search sequence forking) and multi-step decoding (which does not support - beam search, but does support speculative decoding). - """ - - @staticmethod - def create_output_processor( - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], - stop_checker: "StopChecker", - ): - """Create an output processor. - - Multi-step scheduling is no longer supported. Always return a - single-step output processor. - """ - from vllm.engine.output_processor.single_step import ( - SingleStepOutputProcessor) - return SingleStepOutputProcessor(scheduler_config, detokenizer, - scheduler, seq_counter, stop_checker) - - @abstractmethod - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput], - is_async: bool) -> None: - """Process new token ids for the sequence group. Handles logic such as - detokenization, stop checking, and freeing/forking sequences in the - scheduler. - """ - pass - - @abstractmethod - def process_prompt_logprob(self, seq_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: - """Update prompt logprobs received from outputs to seq_group.""" - pass diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py deleted file mode 100644 index dbf6a371d050a..0000000000000 --- a/vllm/engine/output_processor/single_step.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List - -from vllm.config import SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, - SequenceGroupOutput) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.utils import Counter - -logger = init_logger(__name__) - - -def single_step_process_prompt_logprob( - sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, - output: CompletionSequenceGroupOutput) -> None: - """Process prompt logprobs associated with the - [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step. - - Do nothing if the output has no prompt logprobs. - - Account for the fact that transformers do not compute first-token logprobs. - - Args: - sg_output_proc: - [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor] - instance - seq_group: the output is associated with this - [`SequenceGroup`][vllm.sequence.SequenceGroup] - output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] - for a single scheduler step - """ - prompt_logprobs = output.prompt_logprobs - - # If this is the first (or only) "chunk" of the prefill, we need - # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not - # have a logprob associated with it. 
- if prompt_logprobs is not None: - if not seq_group.prompt_logprobs: - prompt_logprobs = [None] + prompt_logprobs - seq_group.prompt_logprobs = [] - - assert hasattr(sg_output_proc, 'detokenizer') - if (seq_group.sampling_params.detokenize - and sg_output_proc.detokenizer): - sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( - seq_group, - prompt_logprobs, - position_offset=len(seq_group.prompt_logprobs)) - - seq_group.prompt_logprobs.extend(prompt_logprobs) - - -class SingleStepOutputProcessor(SequenceGroupOutputProcessor): - """SequenceGroupOutputProcessor which handles "output processing" logic, - which happens after the model returns generated token ids and before - scheduling of the next batch. Output processing logic includes - detokenization, and determining if a sequence is finished (e.g. via max len - or eos token). - - The SingleStepOutputProcessor is specialized to the case where the model - emits at most a single token per invocation, which precludes configurations - such as speculative decoding or multi-step decoding. This enables beam - search sampling, which requires forking/finishing/freeing sequences in a way - that is currently difficult to schedule multiple steps ahead of time. - """ - - def __init__(self, scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, scheduler: List[Scheduler], - seq_counter: Counter, stop_checker: StopChecker): - self.scheduler_config = scheduler_config - self.detokenizer = detokenizer - self.scheduler = scheduler - self.seq_counter = seq_counter - self.stop_checker = stop_checker - - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput], - is_async: bool) -> None: - """Append all new tokens to sequences in the sequence group. Fork any - surviving beam candidates; free any unsurviving ones. - - Invokes detokenizer to detokenize new tokens, and also marks sequences - as finished if they meet stop conditions. - - is_async - Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) - """ - assert (len(outputs) == 1 - ), f"{type(self)} does not support multiple outputs per step" - return self._process_sequence_group_outputs(sequence_group, outputs[0], - is_async) - - def process_prompt_logprob(self, seq_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: - """Process prompt logprobs associated with one step of a single-step- - scheduled computation. - - Args: - seq_group: the output is associated with this - [`SequenceGroup`][vllm.sequence.SequenceGroup] - outputs: the - [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] - for a single scheduler step - """ - assert len(outputs) == 1, "Single step should only have 1 output." 
- output = outputs[0] - assert isinstance(output, CompletionSequenceGroupOutput) - single_step_process_prompt_logprob(self, seq_group, output) - - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput, - is_async: bool) -> None: - sampling_params = seq_group.sampling_params - - sample = outputs.samples[0] - seq = seq_group.first_seq - if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py deleted file mode 100644 index 3fb2f71b5e999..0000000000000 --- a/vllm/engine/output_processor/stop_checker.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Callable, List, Optional, Tuple - -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams -from vllm.sequence import Sequence, SequenceStatus -from vllm.transformers_utils.tokenizer import AnyTokenizer - - -class StopChecker: - """LLMEngine helper class which separates out the logic involving stop - checking. This checks things such as: whether the eos token was emitted, - whether the max_tokens has been consumed, whether a stop string has been - emitted, or if we have exceeded the max model len. - """ - - def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): - # Do not use it directly, but use `self._get_max_model_len`. - self._max_model_len = max_model_len - self.get_tokenizer_for_seq = get_tokenizer_for_seq - - def _get_max_model_len(self, lora_req: Optional[LoRARequest]): - if lora_req and lora_req.long_lora_max_len: - return lora_req.long_lora_max_len - else: - return self._max_model_len - - def maybe_stop_sequence( - self, - seq: Sequence, - new_char_count: int, - sampling_params: SamplingParams, - lora_req: Optional[LoRARequest] = None, - ) -> None: - """Stop the finished sequences. - - new_char_count is the number of chars added to the - sequence's output text for the newly generated token - """ - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - # Remove the last EOS token unless explicitly specified - # This prevents unintended exposure of the EOS token - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - return - - # Check if a stop token was encountered. - # This assumes a single token produced per step. 
- last_token_id = seq.get_last_token_id() - if last_token_id in (sampling_params.stop_token_ids or ()): - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - # Remove last token - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if any stop strings are matched. - stop = self.check_stop_strings( - seq.output_text, new_char_count, sampling_params.stop, - sampling_params.include_stop_str_in_output) - if stop is not None: - stop_str, truncate_to = stop - if truncate_to != -1: - seq.output_text = seq.output_text[:truncate_to] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - - # Check if the sequence has reached max_model_len. - if seq.get_len() >= self._get_max_model_len(lora_req): - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - @staticmethod - def check_stop_strings( - output_text: str, - new_char_count: int, - stop: List[str], - include_in_output: bool, - ) -> Optional[Tuple[str, int]]: - """Check if any stop strings are matched and truncate sequence - output text accordingly. - - Returns tuple (stop_string, offset) if matched or else None. - - Where stop_string is the matched stop string and offset is the - length to which output_text should be truncated, or -1 for no - truncation. - """ - if not new_char_count or not stop: - return None - - for stop_str in stop: - stop_string_len = len(stop_str) - # Avoid searching already-searched text. - stop_index = output_text.find(stop_str, - 1 - new_char_count - stop_string_len) - if stop_index == -1: - continue - - if include_in_output: - # Truncate to end of stop string. - stop_index += stop_string_len - if stop_index >= len(output_text): - # No truncation required. - return stop_str, -1 - - # Truncate the output text to either the beginning - # or end of the stop string. - return stop_str, stop_index - return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py deleted file mode 100644 index 1e127eb982425..0000000000000 --- a/vllm/engine/output_processor/util.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List -from typing import Sequence as GenericSequence -from typing import cast - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput - - -def create_output_by_sequence_group( - outputs: GenericSequence[SamplerOutput], - num_seq_groups: int) -> List[List[SequenceGroupOutput]]: - """Helper method which transforms a 2d list organized by - [step][sequence group] into [sequence group][step]. - """ - output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ - [] for _ in range(num_seq_groups) - ] - for step in outputs: - sequence_group_output: CompletionSequenceGroupOutput - for i, sequence_group_output in enumerate(step): - output_by_sequence_group[i].append(sequence_group_output) - - # Cast to the more generic type that CompletionSequenceGroupOutput - # inherits from. 
- return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 5e8ac9c0b3987..e7d957d7b684e 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,24 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, Iterable, Mapping, Optional, Union +from collections.abc import AsyncGenerator, Iterable, Mapping +from typing import Any, Optional, Union -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.config import DecodingConfig, ModelConfig, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType, TokensPrompt -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt -from vllm.inputs.preprocess import InputPreprocessor +from vllm.config import ModelConfig, VllmConfig +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams +from vllm.tasks import SupportedTask from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import Device, collect_from_async_generator, random_uuid +from vllm.utils import Device +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.processor import Processor logger = init_logger(__name__) @@ -26,195 +25,44 @@ logger = init_logger(__name__) class EngineClient(ABC): """Protocol class for Clients to Engine""" - @property - @abstractmethod - def is_running(self) -> bool: - ... + vllm_config: VllmConfig + model_config: ModelConfig + processor: Processor + io_processor: Optional[IOProcessor] @property @abstractmethod - def is_stopped(self) -> bool: - ... + def is_running(self) -> bool: ... @property @abstractmethod - def errored(self) -> bool: - ... + def is_stopped(self) -> bool: ... @property @abstractmethod - def dead_error(self) -> BaseException: - ... + def errored(self) -> bool: ... + + @property + @abstractmethod + def dead_error(self) -> BaseException: ... @abstractmethod def generate( self, - prompt: PromptType, + prompt: Union[EngineCoreRequest, PromptType], sampling_params: SamplingParams, request_id: str, + *, + prompt_text: Optional[str] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + data_parallel_rank: Optional[int] = None, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request.""" ... 
- async def beam_search( - self, - prompt: PromptType, - request_id: str, - params: BeamSearchParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncGenerator[RequestOutput, None]: - - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - include_stop_str_in_output = params.include_stop_str_in_output - - preprocessor = await self.get_input_preprocessor() - tokenizer_group = preprocessor.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async() - - if is_explicit_encoder_decoder_prompt(prompt): - raise NotImplementedError - else: - processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) - - if processed_inputs["type"] == "embeds": - raise NotImplementedError - - # This is a workaround to fix multimodal beam search; this is a - # bandaid fix for 2 small problems: - # 1. Multi_modal_data on the processed_inputs currently resolves to - # `None`. - # 2. preprocessing above expands the multimodal placeholders. However, - # this happens again in generation, so the double expansion causes - # a mismatch. - # TODO - would be ideal to handle this more gracefully. - prompt_token_ids = prompt.get("prompt_token_ids") - multi_modal_data = prompt.get("multi_modal_data") - - prompt_text = processed_inputs.get("prompt") - mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") - - tokenized_length = len(prompt_token_ids) - - sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) - - beam_search_params = SamplingParams( - logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature, - ) - all_beams = [ - BeamSearchSequence(tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - lora_request=lora_request) - ] - completed = [] - - for _ in range(max_tokens): - prompts_batch, lora_req_batch = zip(*[( - TokensPrompt(prompt_token_ids=beam.tokens, - multi_modal_data=beam.multi_modal_data, - mm_processor_kwargs=beam.mm_processor_kwargs), - beam.lora_request, - ) for beam in all_beams]) - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, (individual_prompt, - lora_req) in enumerate(zip(prompts_batch, lora_req_batch)): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate(individual_prompt, - beam_search_params, - request_id_item, - lora_request=lora_req))) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - completed.append( - BeamSearchSequence( - tokens=current_beam.tokens + - [token_id] if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + - [logprobs], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - finish_reason="stop", - stop_reason=tokenizer.eos_token_id)) - else: - new_beams.append( - BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + - [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam. 
- multi_modal_data, - mm_processor_kwargs=current_beam. - mm_processor_kwargs)) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): - # Skip the eos token in the text. - tokens = beam.tokens[tokenized_length:-1] - else: - tokens = beam.tokens[tokenized_length:] - beam.text = tokenizer.decode(tokens) - - beam_search_output = RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput(text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason if - beam.finish_reason is not None else "length", - stop_reason=beam.stop_reason) - for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None) - - yield beam_search_output - @abstractmethod def encode( self, @@ -224,6 +72,7 @@ class EngineClient(ABC): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" ... @@ -239,44 +88,15 @@ class EngineClient(ABC): ... @abstractmethod - async def get_vllm_config(self) -> VllmConfig: - """Get the vllm configuration of the vLLM engine.""" + async def get_tokenizer(self) -> AnyTokenizer: + """Get the tokenizer""" ... @abstractmethod - async def get_model_config(self) -> ModelConfig: - """Get the model configuration of the vLLM engine.""" - ... + async def is_tracing_enabled(self) -> bool: ... @abstractmethod - async def get_decoding_config(self) -> DecodingConfig: - """Get the decoding configuration of the vLLM engine.""" - ... - - @abstractmethod - async def get_input_preprocessor(self) -> InputPreprocessor: - """Get the input processor of the vLLM engine.""" - ... - - @abstractmethod - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get the appropriate tokenizer for the request""" - ... - - @abstractmethod - async def is_tracing_enabled(self) -> bool: - ... - - @abstractmethod - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[list[SamplerOutput]] = None, - ) -> None: - ... + async def do_log_stats(self) -> None: ... @abstractmethod async def check_health(self) -> None: @@ -299,8 +119,7 @@ class EngineClient(ABC): ... @abstractmethod - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: """Reset the prefix cache""" ... @@ -320,20 +139,26 @@ class EngineClient(ABC): ... @abstractmethod - async def add_lora(self, lora_request: LoRARequest) -> None: + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" ... 
- async def scale_elastic_ep(self, - new_data_parallel_size: int, - drain_timeout: int = 300) -> None: + async def scale_elastic_ep( + self, new_data_parallel_size: int, drain_timeout: int = 300 + ) -> None: """Scale the engine""" raise NotImplementedError - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): + async def collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): """Perform a collective RPC call to the given path.""" raise NotImplementedError + + async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + """Get supported tasks""" + raise NotImplementedError diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 3d1e5dc14d2f3..c31d15ddac4f5 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -7,6 +7,7 @@ For production use, we recommend using our OpenAI compatible server. We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. """ + import asyncio import json import ssl @@ -68,9 +69,7 @@ async def _generate(request_dict: dict, raw_request: Request) -> Response: async for request_output in results_generator: prompt = request_output.prompt assert prompt is not None - text_outputs = [ - prompt + output.text for output in request_output.outputs - ] + text_outputs = [prompt + output.text for output in request_output.outputs] ret = {"text": text_outputs} yield (json.dumps(ret) + "\n").encode("utf-8") @@ -109,16 +108,20 @@ async def init_app( global engine engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER)) + engine = ( + llm_engine + if llm_engine is not None + else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER + ) + ) app.state.engine_client = engine return app -async def run_server(args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, - **uvicorn_kwargs: Any) -> None: +async def run_server( + args: Namespace, llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs: Any +) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) @@ -151,26 +154,27 @@ if __name__ == "__main__": parser.add_argument("--port", type=parser.check_port, default=8000) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) - parser.add_argument("--ssl-ca-certs", - type=str, - default=None, - help="The CA certificates file") + parser.add_argument( + "--ssl-ca-certs", type=str, default=None, help="The CA certificates file" + ) parser.add_argument( "--enable-ssl-refresh", action="store_true", default=False, - help="Refresh SSL Context when SSL certificate files change") + help="Refresh SSL Context when SSL certificate files change", + ) parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)" + help="Whether client certificate is required (see stdlib ssl module's)", ) parser.add_argument( "--root-path", type=str, default=None, - help="FastAPI root_path when app is behind a path based routing proxy") + help="FastAPI root_path when app is behind a path based routing proxy", + ) parser.add_argument("--log-level", type=str, 
default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 7b11a50642de9..e548554dca734 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -8,48 +8,49 @@ from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, - cast) +from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast +import jinja2 +import jinja2.ext +import jinja2.meta import jinja2.nodes +import jinja2.parser +import jinja2.sandbox import transformers.utils.chat_template_utils as hf_chat_utils -# yapf conflicts with isort for this block -# yapf: disable -from openai.types.chat import (ChatCompletionAssistantMessageParam, - ChatCompletionContentPartImageParam, - ChatCompletionContentPartInputAudioParam) from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import (ChatCompletionContentPartRefusalParam, - ChatCompletionContentPartTextParam) + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam, +) from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) -from openai.types.chat import (ChatCompletionMessageToolCallParam, - ChatCompletionToolMessageParam) -from openai.types.chat.chat_completion_content_part_input_audio_param import ( - InputAudio) + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam, +) +from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio from openai.types.responses import ResponseInputImageParam from openai_harmony import Message as OpenAIHarmonyMessage from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter -# yapf: enable -from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin) +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin + # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MediaConnector -# yapf: disable -from vllm.transformers_utils.chat_templates import ( - get_chat_template_fallback_path) -# yapf: enable +from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid +from vllm.utils import random_uuid, supports_kw logger = init_logger(__name__) @@ -75,7 +76,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): class 
ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): - image_embeds: Required[Union[str, dict[str, str]]] + image_embeds: Optional[Union[str, dict[str, str]]] """ The image embeddings. It can be either: - A single base64 string. @@ -83,6 +84,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """ type: Required[Literal["image_embeds"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class VideoURL(TypedDict, total=False): @@ -103,6 +109,7 @@ class PILImage(BaseModel): """ A PIL.Image.Image object. """ + image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) @@ -115,7 +122,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): "image_pil": ImageAsset('cherry_blossom').pil_image } """ - image_pil: Required[PILImage] + + image_pil: Optional[PILImage] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): @@ -127,7 +140,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): "image_url": "https://example.com/image.jpg" } """ - image_url: Required[str] + + image_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): @@ -138,7 +157,8 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): "audio_url": "https://example.com/audio.mp3" } """ - audio_url: Required[str] + + audio_url: Optional[str] class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): @@ -149,7 +169,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): "video_url": "https://example.com/video.mp4" } """ - video_url: Required[str] + + video_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. 
+ """ class CustomThinkCompletionContentParam(TypedDict, total=False): @@ -174,19 +200,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartVideoParam, + ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str, - CustomThinkCompletionContentParam] + CustomChatCompletionContentSimpleVideoParam, + str, + CustomThinkCompletionContentParam, +] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" + role: Required[str] """The role of the message's author.""" @@ -207,9 +238,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" -ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam, - OpenAIHarmonyMessage] +ChatCompletionMessageParam = Union[ + OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam, + OpenAIHarmonyMessage, +] # TODO: Make fields ReadOnly once mypy supports it @@ -246,9 +279,11 @@ def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: if isinstance(node, jinja2.nodes.Getitem): - return (_is_var_access(node.node, varname) - and isinstance(node.arg, jinja2.nodes.Const) - and node.arg.value == key) + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) if isinstance(node, jinja2.nodes.Getattr): return _is_var_access(node.node, varname) and node.attr == key @@ -262,20 +297,18 @@ def _is_var_or_elems_access( key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): - return (node.node is not None - and _is_var_or_elems_access(node.node, varname, key)) + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) - if (isinstance(node, jinja2.nodes.Getitem) - and isinstance(node.arg, jinja2.nodes.Slice)): + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): return _is_var_or_elems_access(node.node, varname, key) - # yapf: disable - return ( - _is_attr_access(node, varname, key) if key - else _is_var_access(node, varname) - ) # yapf: enable + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): @@ -304,8 +337,7 @@ def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # the scope in which each variable is defined, but that is too complicated def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): messages_varnames = [ - varname - for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") ] # Search for {%- for message in messages -%} loops @@ 
-371,17 +403,57 @@ def resolve_mistral_chat_template( chat_template: Optional[str], **kwargs: Any, ) -> Optional[str]: - if chat_template is not None: - logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer.") - if "add_generation_prompt" in kwargs: - logger.warning_once( - "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored.") - if "continue_final_message" in kwargs: - logger.warning_once( - "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored.") + if chat_template is not None or kwargs.get("chat_template_kwargs") is not None: + raise ValueError( + "'chat_template' or 'chat_template_kwargs' cannot be overridden " + "for mistral tokenizer." + ) + + return None + + +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. +""" + + +def _try_get_processor_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + model_config: ModelConfig, +) -> Optional[str]: + cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=model_config.trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None return None @@ -398,26 +470,19 @@ def resolve_hf_chat_template( # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: - try: - processor = cached_get_processor( - tokenizer.name_or_path, - processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin), - trust_remote_code=model_config.trust_remote_code, - ) - if isinstance(processor, ProcessorMixin) and \ - hasattr(processor, 'chat_template') and \ - processor.chat_template is not None: - return processor.chat_template - except Exception: - logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + chat_template = _try_get_processor_chat_template(tokenizer, model_config) + if chat_template is not None: + return chat_template # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: - logger.debug("Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, exc_info=True) + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( @@ -425,12 +490,16 @@ def resolve_hf_chat_template( tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: - logger.info("Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", tokenizer.name_or_path) + logger.info_once( + "Loading chat template 
fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) chat_template = load_chat_template(path) else: - logger.debug("There is no chat template fallback for %s", - tokenizer.name_or_path) + logger.debug_once( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) return chat_template @@ -452,11 +521,17 @@ def _resolve_chat_template_content_format( else: hf_chat_template = None - jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True)) + jinja_text = ( + hf_chat_template + if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) - detected_format = ("string" if jinja_text is None else - _detect_content_format(jinja_text, default="string")) + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) return detected_format @@ -512,7 +587,6 @@ def resolve_chat_template_content_format( return detected_format - ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") @@ -530,7 +604,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._model_config = model_config self._tokenizer = tokenizer - self._items_by_modality = defaultdict[str, list[_T]](list) + self._items_by_modality = defaultdict[str, list[Optional[_T]]](list) + self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) @property def model_config(self) -> ModelConfig: @@ -539,6 +614,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): @cached_property def model_cls(self) -> type[SupportsMultiModal]: from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsMultiModal], model_cls) @@ -546,6 +622,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def allowed_local_media_path(self): return self._model_config.allowed_local_media_path + @property + def allowed_media_domains(self): + return self._model_config.allowed_media_domains + @property def mm_registry(self): return MULTIMODAL_REGISTRY @@ -554,10 +634,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def mm_processor(self): return self.mm_registry.create_processor(self.model_config) - def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + def add( + self, + modality: ModalityStr, + item: Optional[_T], + uuid: Optional[str] = None, + ) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. + + An optional uuid can be added which serves as a unique identifier of the + media. 
""" input_modality = modality.replace("_embeds", "") num_items = len(self._items_by_modality[modality]) + 1 @@ -565,37 +653,56 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) + self._uuids_by_modality[modality].append(uuid) return self.model_cls.get_placeholder_str(modality, num_items) + def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: + if not self._items_by_modality: + return None + mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) + if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: + raise ValueError("Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in uuids_by_modality: + image_embeds_uuids = uuids_by_modality["image_embeds"] + if len(image_embeds_uuids) > 1: + raise ValueError("Only one message can have {'type': 'image_embeds'}") + mm_uuids["image"] = uuids_by_modality["image_embeds"] + if "image" in uuids_by_modality: + mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images + if "audio" in uuids_by_modality: + mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios + if "video" in uuids_by_modality: + mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids + @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError class MultiModalItemTracker(BaseMultiModalItemTracker[object]): - def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError(\ - "Mixing raw image and embedding inputs is not allowed") + raise ValueError("Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError(\ - "Only one message can have {'type': 'image_embeds'}") + raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -603,32 +710,34 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): - async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} - items_by_modality = { - modality: await asyncio.gather(*items) - for modality, items in self._items_by_modality.items() - } + items_by_modality = {} + for modality, items in self._items_by_modality.items(): + coros = [] + for item in items: + if item is not None: + coros.append(item) + else: + coros.append(asyncio.sleep(0)) + items_by_modality[modality] = await asyncio.gather(*coros) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError( - 
"Mixing raw image and embedding inputs is not allowed") + raise ValueError("Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError( - "Only one message can have {'type': 'image_embeds'}") + raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -636,11 +745,10 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class BaseMultiModalContentParser(ABC): - def __init__(self) -> None: super().__init__() - # stores model placehodlers list with corresponding + # stores model placeholders list with corresponding # general MM placeholder: # { # "<##IMAGE##>": ["<image>", "<image>", "<image>"], @@ -648,8 +756,7 @@ class BaseMultiModalContentParser(ABC): # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder(self, modality: ModalityStr, - placeholder: Optional[str]): + def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: self._placeholder_storage[mod_placeholder].append(placeholder) @@ -658,108 +765,138 @@ class BaseMultiModalContentParser(ABC): return dict(self._placeholder_storage) @abstractmethod - def parse_image(self, image_url: str) -> None: + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: raise NotImplementedError @abstractmethod - def parse_image_pil(self, image_pil: Image.Image) -> None: + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_audio(self, audio_url: str) -> None: + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_video(self, video_url: str) -> None: + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError class MultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: MultiModalItemTracker) -> None: super().__init__() self._tracker = tracker - + multimodal_config = self._tracker.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) self._connector = MediaConnector( - media_io_kwargs=self._tracker._model_config.media_io_kwargs, + 
media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, + allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image(self, image_url: str) -> None: - image = self._connector.fetch_image(image_url) + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + image = self._connector.fetch_image(image_url) if image_url else None - placeholder = self._tracker.add("image", image) + placeholder = self._tracker.add("image", image, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) for k, v in image_embeds.items() } - placeholder = self._tracker.add("image_embeds", embeds) + placeholder = self._tracker.add("image_embeds", embeds, uuid) if isinstance(image_embeds, str): embedding = self._connector.fetch_image_embedding(image_embeds) - placeholder = self._tracker.add("image_embeds", embedding) + placeholder = self._tracker.add("image_embeds", embedding, uuid) + + if image_embeds is None: + placeholder = self._tracker.add("image_embeds", None, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - placeholder = self._tracker.add("image", image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio = self._connector.fetch_audio(audio_url) + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + audio = self._connector.fetch_audio(audio_url) if audio_url else None - placeholder = self._tracker.add("audio", audio) + placeholder = self._tracker.add("audio", audio, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. 
+ audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video(video_url=video_url) + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + video = self._connector.fetch_video(video_url=video_url) if video_url else None - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) class AsyncMultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker + multimodal_config = self._tracker.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) self._connector = MediaConnector( - media_io_kwargs=self._tracker._model_config.media_io_kwargs, - allowed_local_media_path=tracker.allowed_local_media_path + media_io_kwargs=media_io_kwargs, + allowed_local_media_path=tracker.allowed_local_media_path, + allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image(self, image_url: str) -> None: - image_coro = self._connector.fetch_image_async(image_url) + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + image_coro = self._connector.fetch_image_async(image_url) if image_url else None - placeholder = self._tracker.add("image", image_coro) + placeholder = self._tracker.add("image", image_coro, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: - future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: + future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future() if isinstance(image_embeds, dict): embeds = { @@ -769,37 +906,57 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): future.set_result(embeds) if isinstance(image_embeds, str): - embedding = self._connector.\ - fetch_image_embedding(image_embeds) + embedding = self._connector.fetch_image_embedding(image_embeds) future.set_result(embedding) - placeholder = self._tracker.add("image_embeds", future) + if image_embeds is None: + future.set_result(None) + + placeholder = self._tracker.add("image_embeds", future, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - future: asyncio.Future[Image.Image] = asyncio.Future() - future.set_result(image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + future: asyncio.Future[Optional[Image.Image]] = asyncio.Future() + if image_pil: + future.set_result(image_pil) + else: + future.set_result(None) - placeholder = self._tracker.add("image", future) + placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio_coro = self._connector.fetch_audio_async(audio_url) + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None - placeholder = self._tracker.add("audio", audio_coro) + placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) - def 
parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. + audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video_async(video_url=video_url) + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + video = ( + self._connector.fetch_video_async(video_url=video_url) + if video_url + else None + ) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -809,20 +966,21 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): return elif isinstance(chat_template, Path) and not chat_template.exists(): - raise FileNotFoundError( - "the supplied chat template path doesn't exist") + raise FileNotFoundError("the supplied chat template path doesn't exist") elif isinstance(chat_template, str): JINJA_CHARS = "{}\n" - if not any(c in chat_template - for c in JINJA_CHARS) and not Path(chat_template).exists(): + if ( + not any(c in chat_template for c in JINJA_CHARS) + and not Path(chat_template).exists() + ): raise ValueError( f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!") + f"appears path-like, but doesn't exist!" + ) else: - raise TypeError( - f"{type(chat_template)} is not a valid chat template type") + raise TypeError(f"{type(chat_template)} is not a valid chat template type") def _load_chat_template( @@ -835,8 +993,9 @@ def _load_chat_template( if is_literal: if isinstance(chat_template, Path): - raise TypeError("chat_template is expected to be read directly " - "from its value") + raise TypeError( + "chat_template is expected to be read directly from its value" + ) return chat_template @@ -849,9 +1008,11 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. 
Reason: {e}" + ) raise ValueError(msg) from e # If opening a file fails, set chat template to be args to @@ -870,8 +1031,9 @@ def load_chat_template( return _cached_load_chat_template(chat_template, is_literal=is_literal) -def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], - texts: list[str]) -> str: +def _get_interleaved_text_prompt( + placeholder_storage: dict[str, list], texts: list[str] +) -> str: for idx, elem in enumerate(texts): if elem in placeholder_storage: texts[idx] = placeholder_storage[elem].pop(0) @@ -881,10 +1043,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], - texts: list[str], - interleave_strings: bool - ) -> str: +def _get_full_multimodal_text_prompt( + placeholder_storage: dict[str, list], + texts: list[str], + interleave_strings: bool, +) -> str: """Combine multimodal prompts for a multimodal language model.""" # flatten storage to make it looks like @@ -907,7 +1070,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], # Look through the text prompt to check for missing placeholders missing_placeholders: list[str] = [] for placeholder in placeholder_counts: - # For any existing placeholder in the text prompt, we leave it as is placeholder_counts[placeholder] -= text_prompt.count(placeholder) @@ -916,15 +1078,16 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], "Placeholder count is negative! " "Ensure that the 'interleave_strings' flag is disabled " "(current value: %s) " - "when manually placing image placeholders.", interleave_strings + "when manually placing image placeholders.", + interleave_strings, ) logger.debug("Input prompt: %s", text_prompt) raise ValueError( f"Found more '{placeholder}' placeholders in input prompt than " - "actual multimodal data items.") + "actual multimodal data items." + ) - missing_placeholders.extend([placeholder] * - placeholder_counts[placeholder]) + missing_placeholders.extend([placeholder] * placeholder_counts[placeholder]) # NOTE: Default behaviour: we always add missing placeholders # at the front of the prompt, if interleave_strings=False @@ -943,8 +1106,7 @@ _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python -_ResponsesInputImageParser = TypeAdapter( - ResponseInputImageParam).validate_python +_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. 
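The `chat_utils` changes in this file add an optional `uuid` field to several multimodal content parts and thread it through the content parsers into the item trackers. A hedged sketch of what a request message could look like with the new field (the URL and UUID values are made-up examples, not taken from this diff):

```python
# Illustrative only: a chat message whose image part carries the new optional
# "uuid" field. The UUID is caller-generated and must be unique per media item.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/cat.png"},
                "uuid": "img-123e4567",  # hypothetical identifier
            },
        ],
    }
]
```

When a UUID is supplied, `_parse_chat_message_content_mm_part` takes the fallback branch that reads the URL (or `None`) directly from the part, and the tracker records the UUID alongside the media item.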
@@ -952,32 +1114,23 @@ MM_PARSER_MAP: dict[ str, Callable[[ChatCompletionContentPartParam], _ContentPart], ] = { - "text": - lambda part: _TextParser(part).get("text", None), - "thinking": - lambda part: _ThinkParser(part).get("thinking", None), - "input_text": - lambda part: _TextParser(part).get("text", None), - "input_image": - lambda part: _ResponsesInputImageParser(part).get("image_url", None), - "image_url": - lambda part: _ImageParser(part).get("image_url", {}).get("url", None), - "image_embeds": - lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + "text": lambda part: _TextParser(part).get("text", None), + "thinking": lambda part: _ThinkParser(part).get("thinking", None), + "input_text": lambda part: _TextParser(part).get("text", None), + "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None), + "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None), + "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), - "audio_url": - lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), - "input_audio": - lambda part: _InputAudioParser(part).get("input_audio", None), - "refusal": - lambda part: _RefusalParser(part).get("refusal", None), - "video_url": - lambda part: _VideoParser(part).get("video_url", {}).get("url", None), + "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), + "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None), + "refusal": lambda part: _RefusalParser(part).get("refusal", None), + "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None), } def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + part: ChatCompletionContentPartParam, +) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -993,38 +1146,67 @@ def _parse_chat_message_content_mm_part( ValueError: If the 'type' field is missing and no direct URL is found. """ assert isinstance( - part, dict) # This is needed to avoid mypy errors: part.get() from str + part, dict + ) # This is needed to avoid mypy errors: part.get() from str part_type = part.get("type", None) + uuid = part.get("uuid", None) - if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501 content = MM_PARSER_MAP[part_type](part) # Special case for 'image_url.detail' # We only support 'auto', which is the default if part_type == "image_url" and part.get("detail", "auto") != "auto": - logger.warning("'image_url.detail' is currently not supported " - "and will be ignored.") + logger.warning( + "'image_url.detail' is currently not supported and will be ignored." + ) return part_type, content # Handle missing 'type' but provided direct URL fields. 
# 'type' is required field by pydantic - if part_type is None: - if part.get("image_url") is not None: - image_params = cast(CustomChatCompletionContentSimpleImageParam, - part) - return "image_url", image_params.get("image_url", "") - if part.get("audio_url") is not None: - audio_params = cast(CustomChatCompletionContentSimpleAudioParam, - part) - return "audio_url", audio_params.get("audio_url", "") + if part_type is None or uuid is not None: + if "image_url" in part: + image_params = cast(CustomChatCompletionContentSimpleImageParam, part) + image_url = image_params.get("image_url", None) + if isinstance(image_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + image_url = image_url.get("url", None) + return "image_url", image_url + if "image_pil" in part: + # "image_pil" could be None if UUID is provided. + image_params = cast( # type: ignore + CustomChatCompletionContentPILImageParam, part + ) + image_pil = image_params.get("image_pil", None) + return "image_pil", image_pil + if "image_embeds" in part: + # "image_embeds" could be None if UUID is provided. + image_params = cast( # type: ignore + ChatCompletionContentPartImageEmbedsParam, part + ) + image_embeds = image_params.get("image_embeds", None) + return "image_embeds", image_embeds + if "audio_url" in part: + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part) + audio_url = audio_params.get("audio_url", None) + if isinstance(audio_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + audio_url = audio_url.get("url", None) + return "audio_url", audio_url if part.get("input_audio") is not None: input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params - if part.get("video_url") is not None: - video_params = cast(CustomChatCompletionContentSimpleVideoParam, - part) - return "video_url", video_params.get("video_url", "") + if "video_url" in part: + video_params = cast(CustomChatCompletionContentSimpleVideoParam, part) + video_url = video_params.get("video_url", None) + if isinstance(video_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + video_url = video_url.get("url", None) + return "video_url", video_url # Raise an error if no 'type' or direct URL is found. 
raise ValueError("Missing 'type' field in multimodal part.") @@ -1033,9 +1215,10 @@ def _parse_chat_message_content_mm_part( return part_type, "unknown part_type content" -VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", - "image_embeds", "image_pil", - "audio_url", "input_audio", "video_url") +PART_TYPES_TO_SKIP_NONE_CONTENT = ( + "text", + "refusal", +) def _parse_chat_message_content_parts( @@ -1055,21 +1238,20 @@ def _parse_chat_message_content_parts( part, mm_parser, wrap_dicts=wrap_dicts, - interleave_strings=interleave_strings + interleave_strings=interleave_strings, ) if parse_res: content.append(parse_res) if wrap_dicts: # Parsing wraps images and texts as interleaved dictionaries - return [ConversationMessage(role=role, - content=content)] # type: ignore + return [ConversationMessage(role=role, content=content)] # type: ignore texts = cast(list[str], content) mm_placeholder_storage = mm_parser.mm_placeholder_storage() if mm_placeholder_storage: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, - texts, - interleave_strings) + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_storage, texts, interleave_strings + ) else: text_prompt = "\n".join(texts) @@ -1096,49 +1278,63 @@ def _parse_chat_message_content_part( part_type, content = _parse_chat_message_content_mm_part(part) # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # content is None, log a warning and skip - if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: + if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None: logger.warning( "Skipping multimodal part '%s' (type: '%s') " - "with empty / unparsable content.", part, part_type) + "with empty / unparsable content.", + part, + part_type, + ) return None if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: - return {'type': 'text', 'text': str_content} + return {"type": "text", "text": str_content} else: return str_content + # For media items, if a user has provided one, use it. Otherwise, insert + # a placeholder empty uuid. 
+ uuid = part.get("uuid", None) + if uuid is not None: + uuid = str(uuid) + modality = None if part_type == "image_pil": - image_content = cast(Image.Image, content) - mm_parser.parse_image_pil(image_content) + image_content = cast(Image.Image, content) if content is not None else None + mm_parser.parse_image_pil(image_content, uuid) modality = "image" elif part_type in ("image_url", "input_image"): str_content = cast(str, content) - mm_parser.parse_image(str_content) + mm_parser.parse_image(str_content, uuid) modality = "image" elif part_type == "image_embeds": - content = cast(Union[str, dict[str, str]], content) - mm_parser.parse_image_embeds(content) + if content is not None: + content = cast(Union[str, dict[str, str]], content) + else: + content = None + mm_parser.parse_image_embeds(content, uuid) modality = "image" elif part_type == "audio_url": str_content = cast(str, content) - mm_parser.parse_audio(str_content) + mm_parser.parse_audio(str_content, uuid) modality = "audio" elif part_type == "input_audio": dict_content = cast(InputAudio, content) - mm_parser.parse_input_audio(dict_content) + mm_parser.parse_input_audio(dict_content, uuid) modality = "audio" elif part_type == "video_url": str_content = cast(str, content) - mm_parser.parse_video(str_content) + mm_parser.parse_video(str_content, uuid) modality = "video" else: raise NotImplementedError(f"Unknown part type: {part_type}") - return {'type': modality} if wrap_dicts else ( - MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + return ( + {"type": modality} + if wrap_dicts + else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None) ) @@ -1159,9 +1355,7 @@ def _parse_chat_message_content( if content is None: content = [] elif isinstance(content, str): - content = [ - ChatCompletionContentPartTextParam(type="text", text=content) - ] + content = [ChatCompletionContentPartTextParam(type="text", text=content)] result = _parse_chat_message_content_parts( role, content, # type: ignore @@ -1171,14 +1365,13 @@ def _parse_chat_message_content( ) for result_msg in result: - if role == 'assistant': + if role == "assistant": parsed_msg = _AssistantParser(message) # The 'tool_calls' is not None check ensures compatibility. # It's needed only if downstream code doesn't strictly # follow the OpenAI spec. 
- if ("tool_calls" in parsed_msg - and parsed_msg["tool_calls"] is not None): + if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1198,12 +1391,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if (message["role"] == "assistant" and "tool_calls" in message - and isinstance(message["tool_calls"], list)): - + if ( + message["role"] == "assistant" + and "tool_calls" in message + and isinstance(message["tool_calls"], list) + ): for item in message["tool_calls"]: - item["function"]["arguments"] = json.loads( - item["function"]["arguments"]) + # if arguments is None or empty string, set to {} + if content := item["function"].get("arguments"): + item["function"]["arguments"] = json.loads(content) + else: + item["function"]["arguments"] = {} def parse_chat_messages( @@ -1211,7 +1409,11 @@ def parse_chat_messages( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: +) -> tuple[ + list[ConversationMessage], + Optional[MultiModalDataDict], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -1224,14 +1426,14 @@ def parse_chat_messages( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def parse_chat_messages_futures( @@ -1239,7 +1441,11 @@ def parse_chat_messages_futures( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -1252,14 +1458,63 @@ def parse_chat_messages_futures( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() + + +# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 +# only preserve the parse function used to resolve chat template kwargs +class AssistantTracker(jinja2.ext.Extension): + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + call = self.call_method("_generation_support") + call_block = jinja2.nodes.CallBlock(call, [], [], body) + return call_block.set_lineno(lineno) + + +def _resolve_chat_template_kwargs( + chat_template: str, +): + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + 
extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + return template_vars + + +_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) + + +def resolve_chat_template_kwargs( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: str, + chat_template_kwargs: dict[str, Any], +) -> dict[str, Any]: + fn_kw = { + k + for k in chat_template_kwargs + if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) + } + + template_vars = _cached_resolve_chat_template_kwargs(chat_template) + + # We exclude chat_template from kwargs here, because + # chat template has been already resolved at this stage + unexpected_vars = {"chat_template"} + accept_vars = (fn_kw | template_vars) - unexpected_vars + return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} def apply_hf_chat_template( @@ -1283,28 +1538,34 @@ def apply_hf_chat_template( raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one.") + "does not define one." + ) try: - + resolved_kwargs = resolve_chat_template_kwargs( + tokenizer=tokenizer, + chat_template=hf_chat_template, + chat_template_kwargs=kwargs, + ) return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] chat_template=hf_chat_template, tokenize=tokenize, - **kwargs, + **resolved_kwargs, ) # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. logger.exception( - "An error occurred in `transformers` while applying chat template") + "An error occurred in `transformers` while applying chat template" + ) raise ValueError(str(e)) from e + def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], @@ -1337,26 +1598,26 @@ def apply_mistral_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. 
logger.exception( - "An error occurred in `mistral_common` while applying chat " - "template") + "An error occurred in `mistral_common` while applying chat template" + ) raise ValueError(str(e)) from e + def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): idx = 0 for msg in conversation: - if msg['role'] == 'assistant': - tool_calls = msg.get('tool_calls') - idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + if msg["role"] == "assistant": + tool_calls = msg.get("tool_calls") + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa return idx -def make_tool_call_id(id_type:str='random', func_name=None, idx=None): - if id_type=='kimi_k2': - return f'functions.{func_name}:{idx}' +def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): + if id_type == "kimi_k2": + return f"functions.{func_name}:{idx}" else: # by default return random return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py index 41671b5b98abb..211e157fc7c82 100644 --- a/vllm/entrypoints/cli/__init__.py +++ b/vllm/entrypoints/cli/__init__.py @@ -2,11 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand -from vllm.entrypoints.cli.benchmark.throughput import ( - BenchmarkThroughputSubcommand) +from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand __all__: list[str] = [ "BenchmarkLatencySubcommand", "BenchmarkServingSubcommand", "BenchmarkThroughputSubcommand", -] \ No newline at end of file +] diff --git a/vllm/entrypoints/cli/benchmark/base.py b/vllm/entrypoints/cli/benchmark/base.py index 0c22bc75105e6..3263459fd6810 100644 --- a/vllm/entrypoints/cli/benchmark/base.py +++ b/vllm/entrypoints/cli/benchmark/base.py @@ -6,7 +6,7 @@ from vllm.entrypoints.cli.types import CLISubcommand class BenchmarkSubcommandBase(CLISubcommand): - """ The base class of subcommands for vllm bench. """ + """The base class of subcommands for vllm bench.""" help: str diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py index 3e68963cfd44e..548ddf4d603e7 100644 --- a/vllm/entrypoints/cli/benchmark/latency.py +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): - """ The `latency` subcommand for vllm bench. """ + """The `latency` subcommand for vllm bench.""" name = "latency" help = "Benchmark the latency of a single batch of requests." diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index 87fb9f3514645..d7455daa1a6b7 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -8,15 +8,14 @@ import typing from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, - show_filtered_argument_or_group_from_help) +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG if typing.TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser class BenchmarkSubcommand(CLISubcommand): - """ The `bench` subcommand for the vLLM CLI. 
""" + """The `bench` subcommand for the vLLM CLI.""" name = "bench" help = "vLLM bench subcommand." @@ -29,28 +28,27 @@ class BenchmarkSubcommand(CLISubcommand): pass def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: bench_parser = subparsers.add_parser( self.name, - help=self.help, description=self.help, - usage="vllm bench <bench_type> [options]") - bench_subparsers = bench_parser.add_subparsers(required=True, - dest="bench_type") + usage=f"vllm {self.name} <bench_type> [options]", + ) + bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type") for cmd_cls in BenchmarkSubcommandBase.__subclasses__(): cmd_subparser = bench_subparsers.add_parser( cmd_cls.name, help=cmd_cls.help, description=cmd_cls.help, - usage=f"vllm bench {cmd_cls.name} [options]", + usage=f"vllm {self.name} {cmd_cls.name} [options]", ) cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd) cmd_cls.add_cli_args(cmd_subparser) - show_filtered_argument_or_group_from_help(cmd_subparser, - ["bench", cmd_cls.name]) - cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG + cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format( + subcmd=f"{self.name} {cmd_cls.name}" + ) return bench_parser diff --git a/vllm/entrypoints/cli/benchmark/serve.py b/vllm/entrypoints/cli/benchmark/serve.py index 3dd7a46d6284b..b085f52afb3b3 100644 --- a/vllm/entrypoints/cli/benchmark/serve.py +++ b/vllm/entrypoints/cli/benchmark/serve.py @@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase class BenchmarkServingSubcommand(BenchmarkSubcommandBase): - """ The `serve` subcommand for vllm bench. """ + """The `serve` subcommand for vllm bench.""" name = "serve" help = "Benchmark the online serving throughput." diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py index d5d43ad4a3591..c25be75ec11e2 100644 --- a/vllm/entrypoints/cli/benchmark/throughput.py +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): - """ The `throughput` subcommand for vllm bench. """ + """The `throughput` subcommand for vllm bench.""" name = "throughput" help = "Benchmark offline inference throughput." diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py index 785c18812adb7..e79a7efec6bac 100644 --- a/vllm/entrypoints/cli/collect_env.py +++ b/vllm/entrypoints/cli/collect_env.py @@ -14,7 +14,8 @@ if typing.TYPE_CHECKING: class CollectEnvSubcommand(CLISubcommand): - """The `collect-env` subcommand for the vLLM CLI. 
""" + """The `collect-env` subcommand for the vLLM CLI.""" + name = "collect-env" @staticmethod @@ -23,13 +24,14 @@ class CollectEnvSubcommand(CLISubcommand): collect_env_main() def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: return subparsers.add_parser( "collect-env", help="Start collecting environment information.", description="Start collecting environment information.", - usage="vllm collect-env") + usage="vllm collect-env", + ) def cmd_init() -> list[CLISubcommand]: diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index fed3ea6504050..cb15952f0d2de 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -'''The CLI entrypoints of vLLM +"""The CLI entrypoints of vLLM Note that all future modules must be lazily loaded within main -to avoid certain eager import breakage.''' +to avoid certain eager import breakage.""" + from __future__ import annotations import importlib.metadata +import sys + +from vllm.logger import init_logger + +logger = init_logger(__name__) def main(): @@ -28,23 +34,38 @@ def main(): cli_env_setup() + # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default + if len(sys.argv) > 1 and sys.argv[1] == "bench": + logger.debug( + "Bench command detected, must ensure current platform is not " + "UnspecifiedPlatform to avoid device type inference error" + ) + from vllm import platforms + + if platforms.current_platform.is_unspecified(): + from vllm.platforms.cpu import CpuPlatform + + platforms.current_platform = CpuPlatform() + logger.info( + "Unspecified platform detected, switching to CPU Platform instead." 
+ ) + parser = FlexibleArgumentParser( description="vLLM CLI", - epilog=VLLM_SUBCMD_PARSER_EPILOG, + epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), ) parser.add_argument( - '-v', - '--version', - action='version', - version=importlib.metadata.version('vllm'), + "-v", + "--version", + action="version", + version=importlib.metadata.version("vllm"), ) subparsers = parser.add_subparsers(required=False, dest="subparser") cmds = {} for cmd_module in CMD_MODULES: new_cmds = cmd_module.cmd_init() for cmd in new_cmds: - cmd.subparser_init(subparsers).set_defaults( - dispatch_function=cmd.cmd) + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) cmds[cmd.name] = cmd args = parser.parse_args() if args.subparser in cmds: diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 7c01de94a3436..5372210bbf55c 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: def _register_signal_handlers(): - def signal_handler(sig, frame): sys.exit(0) @@ -45,6 +44,28 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: return model_name, openai_client +def _print_chat_stream(stream) -> str: + output = "" + for chunk in stream: + delta = chunk.choices[0].delta + if delta.content: + output += delta.content + print(delta.content, end="", flush=True) + print() + return output + + +def _print_completion_stream(stream) -> str: + output = "" + for chunk in stream: + text = chunk.choices[0].text + if text is not None: + output += text + print(text, end="", flush=True) + print() + return output + + def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: conversation: list[ChatCompletionMessageParam] = [] if system_prompt is not None: @@ -58,29 +79,29 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: break conversation.append({"role": "user", "content": input_message}) - chat_completion = client.chat.completions.create(model=model_name, - messages=conversation) - - response_message = chat_completion.choices[0].message - output = response_message.content - - conversation.append(response_message) # type: ignore - print(output) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) + output = _print_chat_stream(stream) + conversation.append({"role": "assistant", "content": output}) -def _add_query_options( - parser: FlexibleArgumentParser) -> FlexibleArgumentParser: +def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--url", type=str, default="http://localhost:8000/v1", - help="url of the running OpenAI-Compatible RESTful API server") + help="url of the running OpenAI-Compatible RESTful API server", + ) parser.add_argument( "--model-name", type=str, default=None, - help=("The model name used in prompt completion, default to " - "the first model in list models API call.")) + help=( + "The model name used in prompt completion, default to " + "the first model in list models API call." + ), + ) parser.add_argument( "--api-key", type=str, @@ -88,12 +109,14 @@ def _add_query_options( help=( "API key for OpenAI services. If provided, this api key " "will overwrite the api key obtained through environment variables." - )) + ), + ) return parser class ChatCommand(CLISubcommand): - """The `chat` subcommand for the vLLM CLI. 
""" + """The `chat` subcommand for the vLLM CLI.""" + name = "chat" @staticmethod @@ -108,9 +131,11 @@ class ChatCommand(CLISubcommand): if args.quick: conversation.append({"role": "user", "content": args.quick}) - chat_completion = client.chat.completions.create( - model=model_name, messages=conversation) - print(chat_completion.choices[0].message.content) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) + output = _print_chat_stream(stream) + conversation.append({"role": "assistant", "content": output}) return print("Please enter a message for the chat model:") @@ -121,14 +146,11 @@ class ChatCommand(CLISubcommand): break conversation.append({"role": "user", "content": input_message}) - chat_completion = client.chat.completions.create( - model=model_name, messages=conversation) - - response_message = chat_completion.choices[0].message - output = response_message.content - - conversation.append(response_message) # type: ignore - print(output) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) + output = _print_chat_stream(stream) + conversation.append({"role": "assistant", "content": output}) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -138,39 +160,46 @@ class ChatCommand(CLISubcommand): "--system-prompt", type=str, default=None, - help=("The system prompt to be added to the chat template, " - "used for models that support system prompts.")) - parser.add_argument("-q", - "--quick", - type=str, - metavar="MESSAGE", - help=("Send a single prompt as MESSAGE " - "and print the response, then exit.")) + help=( + "The system prompt to be added to the chat template, " + "used for models that support system prompts." + ), + ) + parser.add_argument( + "-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE and print the response, then exit."), + ) return parser def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: parser = subparsers.add_parser( "chat", help="Generate chat completions via the running API server.", description="Generate chat completions via the running API server.", - usage="vllm chat [options]") + usage="vllm chat [options]", + ) return ChatCommand.add_cli_args(parser) class CompleteCommand(CLISubcommand): - """The `complete` subcommand for the vLLM CLI. 
""" - name = 'complete' + """The `complete` subcommand for the vLLM CLI.""" + + name = "complete" @staticmethod def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) if args.quick: - completion = client.completions.create(model=model_name, - prompt=args.quick) - print(completion.choices[0].text) + stream = client.completions.create( + model=model_name, prompt=args.quick, stream=True + ) + _print_completion_stream(stream) return print("Please enter prompt to complete:") @@ -179,10 +208,10 @@ class CompleteCommand(CLISubcommand): input_prompt = input("> ") except EOFError: break - completion = client.completions.create(model=model_name, - prompt=input_prompt) - output = completion.choices[0].text - print(output) + stream = client.completions.create( + model=model_name, prompt=input_prompt, stream=True + ) + _print_completion_stream(stream) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -193,20 +222,25 @@ class CompleteCommand(CLISubcommand): "--quick", type=str, metavar="PROMPT", - help= - "Send a single prompt and print the completion output, then exit.") + help="Send a single prompt and print the completion output, then exit.", + ) return parser def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: parser = subparsers.add_parser( "complete", - help=("Generate text completions based on the given prompt " - "via the running API server."), - description=("Generate text completions based on the given prompt " - "via the running API server."), - usage="vllm complete [options]") + help=( + "Generate text completions based on the given prompt " + "via the running API server." + ), + description=( + "Generate text completions based on the given prompt " + "via the running API server." + ), + usage="vllm complete [options]", + ) return CompleteCommand.add_cli_args(parser) diff --git a/vllm/entrypoints/cli/run_batch.py b/vllm/entrypoints/cli/run_batch.py index 86491678d7d24..6e7a15ada49cf 100644 --- a/vllm/entrypoints/cli/run_batch.py +++ b/vllm/entrypoints/cli/run_batch.py @@ -9,8 +9,7 @@ import importlib.metadata import typing from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, - show_filtered_argument_or_group_from_help) +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger if typing.TYPE_CHECKING: @@ -21,14 +20,16 @@ logger = init_logger(__name__) class RunBatchSubcommand(CLISubcommand): """The `run-batch` subcommand for vLLM CLI.""" + name = "run-batch" @staticmethod def cmd(args: argparse.Namespace) -> None: from vllm.entrypoints.openai.run_batch import main as run_batch_main - logger.info("vLLM batch processing API version %s", - importlib.metadata.version("vllm")) + logger.info( + "vLLM batch processing API version %s", importlib.metadata.version("vllm") + ) logger.info("args: %s", args) # Start the Prometheus metrics server. 
@@ -45,23 +46,21 @@ class RunBatchSubcommand(CLISubcommand): asyncio.run(run_batch_main(args)) def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: from vllm.entrypoints.openai.run_batch import make_arg_parser run_batch_parser = subparsers.add_parser( - "run-batch", + self.name, help="Run batch prompts and write results to file.", description=( "Run batch prompts using vLLM's OpenAI-compatible API.\n" - "Supports local or HTTP input/output files."), - usage= - "vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>", + "Supports local or HTTP input/output files." + ), + usage="vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>", ) run_batch_parser = make_arg_parser(run_batch_parser) - show_filtered_argument_or_group_from_help(run_batch_parser, - ["run-batch"]) - run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG + run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return run_batch_parser diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 803a3e004656a..b3960b74cf019 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -10,34 +10,47 @@ import uvloop import vllm import vllm.envs as envs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, - setup_server) -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) -from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, - show_filtered_argument_or_group_from_help) +from vllm.entrypoints.openai.api_server import ( + run_server, + run_server_worker, + setup_server, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, decorate_logs, get_tcp_uri, - set_process_title) +from vllm.utils import ( + FlexibleArgumentParser, + decorate_logs, + get_tcp_uri, + set_process_title, +) from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus -from vllm.v1.utils import (APIServerProcessManager, - wait_for_completion_or_failure) +from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure logger = init_logger(__name__) +DESCRIPTION = """Launch a local OpenAI-compatible API server to serve LLM +completions via HTTP. Defaults to Qwen/Qwen3-0.6B if no model is specified. + +Search by using: `--help=<ConfigGroup>` to explore options by section (e.g., +--help=ModelConfig, --help=Frontend) + Use `--help=all` to show all available flags at once. +""" + class ServeSubcommand(CLISubcommand): - """The `serve` subcommand for the vLLM CLI. 
""" + """The `serve` subcommand for the vLLM CLI.""" + name = "serve" @staticmethod def cmd(args: argparse.Namespace) -> None: # If model is specified in CLI (as positional arg), it takes precedence - if hasattr(args, 'model_tag') and args.model_tag is not None: + if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag if args.headless or args.api_server_count < 1: @@ -53,17 +66,14 @@ class ServeSubcommand(CLISubcommand): validate_parsed_serve_args(args) def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: serve_parser = subparsers.add_parser( - "serve", - help="Start the vLLM OpenAI Compatible API server.", - description="Start the vLLM OpenAI Compatible API server.", - usage="vllm serve [model_tag] [options]") + self.name, description=DESCRIPTION, usage="vllm serve [model_tag] [options]" + ) serve_parser = make_arg_parser(serve_parser) - show_filtered_argument_or_group_from_help(serve_parser, ["serve"]) - serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG + serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return serve_parser @@ -72,29 +82,27 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): - if args.api_server_count > 1: raise ValueError("api_server_count can't be set in headless mode") # Create the EngineConfig. engine_args = vllm.AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = engine_args.create_engine_config(usage_context=usage_context, - headless=True) + vllm_config = engine_args.create_engine_config( + usage_context=usage_context, headless=True + ) if not envs.VLLM_USE_V1: raise ValueError("Headless mode is only supported for V1") if engine_args.data_parallel_hybrid_lb: - raise ValueError("data_parallel_hybrid_lb is not applicable in " - "headless mode") + raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode") parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local if local_engine_count <= 0: - raise ValueError("data_parallel_size_local must be > 0 in " - "headless mode") + raise ValueError("data_parallel_size_local must be > 0 in headless mode") host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too @@ -110,7 +118,10 @@ def run_headless(args: argparse.Namespace): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, handshake_address) + "with head node address %s.", + local_engine_count, + handshake_address, + ) # Create the engines. 
engine_manager = CoreEngineProcManager( @@ -133,37 +144,31 @@ def run_headless(args: argparse.Namespace): def run_multi_api_server(args: argparse.Namespace): - assert not args.headless - num_api_servers = args.api_server_count + num_api_servers: int = args.api_server_count assert num_api_servers > 0 - orig_mm_processor_cache_gb = args.mm_processor_cache_gb - if num_api_servers > 1: setup_multiprocess_prometheus() - # Not compatible with API server scale-out - args.mm_processor_cache_gb = 0 - listen_address, sock = setup_server(args) engine_args = vllm.AsyncEngineArgs.from_cli_args(args) + engine_args._api_process_count = num_api_servers + engine_args._api_process_rank = -1 + usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) - model_config = vllm_config.model_config if num_api_servers > 1: if not envs.VLLM_USE_V1: raise ValueError("api_server_count > 1 is only supported for V1") if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " - "with api_server_count > 1") - - if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0: - logger.warning("Multi-modal processor cache is disabled because " - "it is not compatible with `api_server_count > 1`.") + raise ValueError( + "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " + "with api_server_count > 1" + ) executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats @@ -176,10 +181,9 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager: Optional[APIServerProcessManager] = None - with launch_core_engines(vllm_config, executor_class, log_stats, - num_api_servers) as (local_engine_manager, - coordinator, addresses): - + with launch_core_engines( + vllm_config, executor_class, log_stats, num_api_servers + ) as (local_engine_manager, coordinator, addresses): # Construct common args for the APIServerProcessManager up-front. api_server_manager_kwargs = dict( target_server_fn=run_api_server_worker_proc, @@ -190,7 +194,9 @@ def run_multi_api_server(args: argparse.Namespace): input_addresses=addresses.inputs, output_addresses=addresses.outputs, stats_update_address=coordinator.get_stats_publish_address() - if coordinator else None) + if coordinator + else None, + ) # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the # start of the API servers until the local engine is started @@ -199,34 +205,34 @@ def run_multi_api_server(args: argparse.Namespace): # via the handshake with the local engine. if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb): # Start API servers using the manager. - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) + api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Start API servers now if they weren't already started. 
if api_server_manager is None: api_server_manager_kwargs["stats_update_address"] = ( - addresses.frontend_stats_publish_address) - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) + addresses.frontend_stats_publish_address + ) + api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Wait for API servers - wait_for_completion_or_failure(api_server_manager=api_server_manager, - engine_manager=local_engine_manager, - coordinator=coordinator) + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator, + ) -def run_api_server_worker_proc(listen_address, - sock, - args, - client_config=None, - **uvicorn_kwargs) -> None: +def run_api_server_worker_proc( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: """Entrypoint for individual API server worker processes.""" + client_config = client_config or {} + server_index = client_config.get("client_index", 0) # Set process title and add process-specific prefix to stdout and stderr. - server_index = client_config.get("client_index", 0) if client_config else 0 set_process_title("APIServer", str(server_index)) decorate_logs() uvloop.run( - run_server_worker(listen_address, sock, args, client_config, - **uvicorn_kwargs)) + run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs) + ) diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py index b88f094b302ad..6194f421a1bb4 100644 --- a/vllm/entrypoints/cli/types.py +++ b/vllm/entrypoints/cli/types.py @@ -24,6 +24,6 @@ class CLISubcommand: pass def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: raise NotImplementedError("Subclasses should implement this method") diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index f70e1fc207f86..f410ee9c40456 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -1,16 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import contextlib import json import logging from abc import ABC, abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Union +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Optional, Union +from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm.entrypoints.harmony_utils import ( - get_encoding, get_streamable_parser_for_assistant, render_for_completion) + get_encoding, + get_streamable_parser_for_assistant, + render_for_completion, +) from vllm.entrypoints.tool import Tool +from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput if TYPE_CHECKING: @@ -18,9 +25,44 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# This is currently needed as the tool type doesn't 1:1 match the +# tool namespace, which is what is used to look up the +# connection to the tool server +_TOOL_NAME_TO_TYPE_MAP = { + "browser": "web_search_preview", + "python": "code_interpreter", + "container": "container", +} + + +def _map_tool_name_to_tool_type(tool_name: str) -> str: + if tool_name not in _TOOL_NAME_TO_TYPE_MAP: + available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys()) + raise ValueError( + f"Built-in tool name '{tool_name}' not defined in mapping. 
" + f"Available tools: {available_tools}" + ) + return _TOOL_NAME_TO_TYPE_MAP[tool_name] + + +class TurnTokens: + """Tracks token counts for a single conversation turn.""" + + def __init__(self, input_tokens=0, output_tokens=0): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + + def reset(self): + """Reset counters for a new turn.""" + self.input_tokens = 0 + self.output_tokens = 0 + + def copy(self): + """Create a copy of this turn's token counts.""" + return TurnTokens(self.input_tokens, self.output_tokens) + class ConversationContext(ABC): - @abstractmethod def append_output(self, output) -> None: pass @@ -37,14 +79,37 @@ class ConversationContext(ABC): def render_for_completion(self) -> list[int]: pass + @abstractmethod + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ) -> None: + pass + + @abstractmethod + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class SimpleContext(ConversationContext): - def __init__(self): self.last_output = None + self.num_prompt_tokens = 0 + self.num_output_tokens = 0 + self.num_cached_tokens = 0 + # todo num_reasoning_tokens is not implemented yet. + self.num_reasoning_tokens = 0 def append_output(self, output) -> None: self.last_output = output + if not isinstance(output, RequestOutput): + raise ValueError("SimpleContext only supports RequestOutput.") + self.num_prompt_tokens = len(output.prompt_token_ids or []) + self.num_cached_tokens = output.num_cached_tokens or 0 + self.num_output_tokens += len(output.outputs[0].token_ids or []) def need_builtin_tool_call(self) -> bool: return False @@ -55,49 +120,161 @@ class SimpleContext(ConversationContext): def render_for_completion(self) -> list[int]: raise NotImplementedError("Should not be called.") + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ) -> None: + pass + + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class HarmonyContext(ConversationContext): - def __init__( self, messages: list, - tool_sessions: dict[str, Tool], + available_tools: list[str], ): self._messages = messages - self.tool_sessions = tool_sessions + self.finish_reason: Optional[str] = None + self.available_tools = available_tools + self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} + self.called_tools: set[str] = set() self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) self.num_prompt_tokens = 0 self.num_output_tokens = 0 - # TODO(woosuk): Implement the following fields. self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + self.num_tool_output_tokens = 0 - def _update_num_prompt_tokens(self, output: RequestOutput): - if output.prompt_token_ids and len(output.prompt_token_ids) > 0: - # NOTE: with built-in tools, there might be multiple rounds in - # the conversation, with the full conversation being resent - # as new prompt each time. Hence the sum. 
- self.num_prompt_tokens += len(output.prompt_token_ids) + # Turn tracking - replaces multiple individual tracking variables + self.current_turn = TurnTokens() + self.previous_turn = TurnTokens() + self.is_first_turn = True + self.first_tok_of_message = True # For streaming support - def _update_num_output_tokens(self, token_ids: Sequence[int]): - self.num_output_tokens += len(token_ids) + def _update_num_reasoning_tokens(self): + # Count all analysis and commentary channels as reasoning tokens + if self.parser.current_channel in {"analysis", "commentary"}: + self.num_reasoning_tokens += 1 - def append_output(self, output) -> None: + def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: if isinstance(output, RequestOutput): - self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids - self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() + self._update_prefill_token_usage(output) + # Reset current turn output tokens for this turn + self.current_turn.output_tokens = 0 + self._update_decode_token_usage(output) + # Move current turn to previous turn for next turn's calculations + self.previous_turn = self.current_turn.copy() + # append_output is called only once before tool calling + # in non-streaming case + # so we can append all the parser messages to _messages output_msgs = self.parser.messages + # The responses finish reason is set in the last message + self.finish_reason = output.outputs[0].finish_reason else: # Tool output. output_msgs = output self._messages.extend(output_msgs) + def _update_prefill_token_usage(self, output: RequestOutput) -> None: + """Update token usage statistics for the prefill phase of generation. + + The prefill phase processes the input prompt tokens. This method: + 1. Counts the prompt tokens for this turn + 2. Calculates tool output tokens for multi-turn conversations + 3. Updates cached token counts + 4. Tracks state for next turn calculations + + Tool output tokens are calculated as: + current_prompt_tokens - last_turn_prompt_tokens - + last_turn_output_tokens + This represents tokens added between turns (typically tool responses). + + Args: + output: The RequestOutput containing prompt token information + """ + if output.prompt_token_ids is not None: + this_turn_input_tokens = len(output.prompt_token_ids) + else: + this_turn_input_tokens = 0 + logger.error("RequestOutput appended contains no prompt_token_ids.") + + # Update current turn input tokens + self.current_turn.input_tokens = this_turn_input_tokens + self.num_prompt_tokens += this_turn_input_tokens + + # Calculate tool tokens (except on first turn) + if self.is_first_turn: + self.is_first_turn = False + else: + # start counting tool after first turn + # tool tokens = this turn prefill - last turn prefill - + # last turn decode + this_turn_tool_tokens = ( + self.current_turn.input_tokens + - self.previous_turn.input_tokens + - self.previous_turn.output_tokens + ) + + # Handle negative tool token counts (shouldn't happen in normal + # cases) + if this_turn_tool_tokens < 0: + logger.error( + "Negative tool output tokens calculated: %d " + "(current_input=%d, previous_input=%d, " + "previous_output=%d). 
Setting to 0.", + this_turn_tool_tokens, + self.current_turn.input_tokens, + self.previous_turn.input_tokens, + self.previous_turn.output_tokens, + ) + this_turn_tool_tokens = 0 + + self.num_tool_output_tokens += this_turn_tool_tokens + + # Update cached tokens + if output.num_cached_tokens is not None: + self.num_cached_tokens += output.num_cached_tokens + + def _update_decode_token_usage(self, output: RequestOutput) -> int: + """Update token usage statistics for the decode phase of generation. + + The decode phase processes the generated output tokens. This method: + 1. Counts output tokens from all completion outputs + 2. Updates the total output token count + 3. Tracks tokens generated in the current turn + + In streaming mode, this is called for each token generated. + In non-streaming mode, this is called once with all output tokens. + + Args: + output: The RequestOutput containing generated token information + + Returns: + int: Number of output tokens processed in this call + """ + updated_output_token_count = 0 + if output.outputs: + for completion_output in output.outputs: + # only keep last round + updated_output_token_count += len(completion_output.token_ids) + self.num_output_tokens += updated_output_token_count + self.current_turn.output_tokens += updated_output_token_count + return updated_output_token_count + @property def messages(self) -> list: return self._messages @@ -105,8 +282,11 @@ class HarmonyContext(ConversationContext): def need_builtin_tool_call(self) -> bool: last_msg = self.messages[-1] recipient = last_msg.recipient - return recipient is not None and (recipient.startswith("browser.") - or recipient.startswith("python")) + return recipient is not None and ( + recipient.startswith("browser.") + or recipient.startswith("python") + or recipient.startswith("container.") + ) async def call_tool(self) -> list[Message]: if not self.messages: @@ -116,18 +296,25 @@ class HarmonyContext(ConversationContext): if recipient is not None: if recipient.startswith("browser."): return await self.call_search_tool( - self.tool_sessions["browser"], last_msg) + self._tool_sessions["browser"], last_msg + ) elif recipient.startswith("python"): return await self.call_python_tool( - self.tool_sessions["python"], last_msg) + self._tool_sessions["python"], last_msg + ) + elif recipient.startswith("container."): + return await self.call_container_tool( + self._tool_sessions["container"], last_msg + ) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: return render_for_completion(self.messages) - async def call_search_tool(self, tool_session: Union["ClientSession", - Tool], - last_msg: Message) -> list[Message]: + async def call_search_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: + self.called_tools.add("browser") if isinstance(tool_session, Tool): return await tool_session.get_result(self) tool_name = last_msg.recipient.split(".")[1] @@ -137,12 +324,18 @@ class HarmonyContext(ConversationContext): content = TextContent(text=result_str) author = Author(role=Role.TOOL, name=last_msg.recipient) return [ - Message(author=author, content=[content], recipient=Role.ASSISTANT) + Message( + author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel, + ) ] - async def call_python_tool(self, tool_session: Union["ClientSession", - Tool], - last_msg: Message) -> list[Message]: + async def call_python_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> 
list[Message]: + self.called_tools.add("python") if isinstance(tool_session, Tool): return await tool_session.get_result(self) param = { @@ -155,15 +348,91 @@ class HarmonyContext(ConversationContext): author = Author(role=Role.TOOL, name="python") return [ - Message(author=author, - content=[content], - channel=last_msg.channel, - recipient=Role.ASSISTANT) + Message( + author=author, + content=[content], + channel=last_msg.channel, + recipient=Role.ASSISTANT, + ) ] + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ): + if tool_server: + for tool_name in self.available_tools: + if tool_name not in self._tool_sessions: + tool_type = _map_tool_name_to_tool_type(tool_name) + headers = ( + mcp_tools[tool_type].headers if tool_type in mcp_tools else None + ) + tool_session = await exit_stack.enter_async_context( + tool_server.new_session(tool_name, request_id, headers) + ) + self._tool_sessions[tool_name] = tool_session + exit_stack.push_async_exit(self.cleanup_session) + + async def call_container_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: + """ + Call container tool. Expect this to be run in a stateful docker + with command line terminal. + The official container tool would at least + expect the following format: + - for tool name: exec + - args: + { + "cmd":List[str] "command to execute", + "workdir":optional[str] "current working directory", + "env":optional[object/dict] "environment variables", + "session_name":optional[str] "session name", + "timeout":optional[int] "timeout in seconds", + "user":optional[str] "user name", + } + """ + self.called_tools.add("container") + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1].split(" ")[0] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [ + Message( + author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel, + ) + ] + + async def cleanup_session(self, *args, **kwargs) -> None: + """Can be used as coro to used in __aexit__""" + + async def cleanup_tool_session(tool_session): + if not isinstance(tool_session, Tool): + logger.info( + "Cleaning up tool session for %s", tool_session._client_info + ) + with contextlib.suppress(Exception): + await tool_session.call_tool("cleanup_session", {}) + + await asyncio.gather( + *( + cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools + ) + ) + class StreamingHarmonyContext(HarmonyContext): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.last_output = None @@ -175,23 +444,34 @@ class StreamingHarmonyContext(HarmonyContext): @property def messages(self) -> list: - return self.parser.messages + return self._messages - def append_output(self, output) -> None: + def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: if isinstance(output, RequestOutput): # append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message. 
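As a concrete illustration of the `call_container_tool` contract documented above: the model addresses the tool via a recipient such as `container.exec`, and the message body is a JSON object matching the schema in the docstring. The values below are illustrative only; the parsing mirrors the code above.

```python
import json

# Hypothetical message fields, shaped the way call_container_tool parses them:
# recipient "container.<tool> ..." plus a JSON args payload in the message text.
recipient = "container.exec"
content_text = json.dumps(
    {
        "cmd": ["ls", "-la"],               # required: command to execute
        "workdir": "/workspace",            # optional: working directory
        "env": {"PYTHONUNBUFFERED": "1"},   # optional: environment variables
        "timeout": 30,                      # optional: timeout in seconds
    }
)

# Mirrors the parsing in call_container_tool:
tool_name = recipient.split(".")[1].split(" ")[0]  # -> "exec"
args = json.loads(content_text)
print(tool_name, args["cmd"])
```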
if self.first_tok_of_message: - self._update_num_prompt_tokens(output) + self._update_prefill_token_usage(output) + self.current_turn.output_tokens = 0 # Reset self.first_tok_of_message if needed: # if the current token is the last one of the current message # (finished=True), then the next token processed will mark the # beginning of a new message self.first_tok_of_message = output.finished - tok = output.outputs[0].token_ids[0] - self.parser.process(tok) - self._update_num_output_tokens(output.outputs[0].token_ids) + for tok in output.outputs[0].token_ids: + self.parser.process(tok) + self._update_decode_token_usage(output) + + # For streaming, update previous turn when message is complete + if output.finished: + self.previous_turn = self.current_turn.copy() + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() self.last_tok = tok + if len(self._messages) - self.num_init_messages < len(self.parser.messages): + self._messages.extend( + self.parser.messages[len(self._messages) - self.num_init_messages :] + ) else: # Handle the case of tool output in direct message format assert len(output) == 1, "Tool output should be a single message" @@ -204,13 +484,13 @@ class StreamingHarmonyContext(HarmonyContext): for tok in toks: self.parser.process(tok) self.last_tok = toks[-1] + # TODO: add tool_output messages to self._messages def is_expecting_start(self) -> bool: return self.parser.state == StreamState.EXPECT_START def is_assistant_action_turn(self) -> bool: - return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( - ) + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions() def render_for_completion(self) -> list[int]: # now this list of tokens as next turn's starting tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index bc810f683f4a4..53a08b1a4485c 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -1,24 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + import datetime import json from collections.abc import Iterable, Sequence -from typing import Literal, Optional, Union +from typing import Literal, Union -from openai.types.responses import (ResponseFunctionToolCall, - ResponseOutputItem, ResponseOutputMessage, - ResponseOutputText, ResponseReasoningItem) +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputItem, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) from openai.types.responses.response_function_web_search import ( - ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) + ActionFind, + ActionOpenPage, + ActionSearch, + ResponseFunctionWebSearch, +) from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) from openai.types.responses.tool import Tool -from openai_harmony import (Author, Conversation, DeveloperContent, - HarmonyEncodingName, Message, ReasoningEffort, - Role, StreamableParser, SystemContent, TextContent, - ToolDescription, load_harmony_encoding) +from openai_harmony import ( + Author, + ChannelConfig, + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + ReasoningEffort, + Role, + StreamableParser, + SystemContent, + TextContent, + ToolDescription, + load_harmony_encoding, +) -from vllm.entrypoints.openai.protocol import 
ResponseInputOutputItem +from vllm import envs +from vllm.entrypoints.openai.protocol import ( + ChatCompletionToolsParam, + ResponseInputOutputItem, +) from vllm.utils import random_uuid REASONING_EFFORT = { @@ -29,28 +56,51 @@ REASONING_EFFORT = { _harmony_encoding = None +# Builtin tools that should be included in the system message when +# they are available and requested by the user. +# Tool args are provided by MCP tool descriptions. Output +# of the tools are stringified. +BUILTIN_TOOLS = { + "web_search_preview", + "code_interpreter", + "container", +} + + +def has_custom_tools(tool_types: list[str]) -> bool: + return not set(tool_types).issubset(BUILTIN_TOOLS) + def get_encoding(): global _harmony_encoding if _harmony_encoding is None: - _harmony_encoding = load_harmony_encoding( - HarmonyEncodingName.HARMONY_GPT_OSS) + _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) return _harmony_encoding def get_system_message( - model_identity: Optional[str] = None, - reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, - start_date: Optional[str] = None, - browser_description: Optional[str] = None, - python_description: Optional[str] = None, + model_identity: str | None = None, + reasoning_effort: Literal["high", "medium", "low"] | None = None, + start_date: str | None = None, + browser_description: str | None = None, + python_description: str | None = None, + container_description: str | None = None, + instructions: str | None = None, + with_custom_tools: bool = False, ) -> Message: sys_msg_content = SystemContent.new() if model_identity is not None: sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: + current_identity = sys_msg_content.model_identity + new_identity = ( + f"{current_identity}\n{instructions}" if current_identity else instructions + ) + sys_msg_content = sys_msg_content.with_model_identity(new_identity) if reasoning_effort is not None: sys_msg_content = sys_msg_content.with_reasoning_effort( - REASONING_EFFORT[reasoning_effort]) + REASONING_EFFORT[reasoning_effort] + ) if start_date is None: # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. 
start_date = datetime.datetime.now().strftime("%Y-%m-%d") @@ -59,35 +109,65 @@ def get_system_message( sys_msg_content = sys_msg_content.with_tools(browser_description) if python_description is not None: sys_msg_content = sys_msg_content.with_tools(python_description) + if container_description is not None: + sys_msg_content = sys_msg_content.with_tools(container_description) + if not with_custom_tools: + channel_config = sys_msg_content.channel_config + invalid_channel = "commentary" + new_config = ChannelConfig.require_channels( + [c for c in channel_config.valid_channels if c != invalid_channel] + ) + sys_msg_content = sys_msg_content.with_channel_config(new_config) sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) return sys_msg -def get_developer_message(instructions: Optional[str] = None, - tools: Optional[list[Tool]] = None) -> Message: +def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): + if isinstance(tool, ChatCompletionToolsParam): + return ToolDescription.new( + name=tool.function.name, + description=tool.function.description, + parameters=tool.function.parameters, + ) + return ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + + +def get_developer_message( + instructions: str | None = None, + tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None, +) -> Message: dev_msg_content = DeveloperContent.new() - if instructions is not None: + if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: dev_msg_content = dev_msg_content.with_instructions(instructions) if tools is not None: - function_tools = [] + function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] for tool in tools: - if tool.type in ("web_search_preview", "code_interpreter"): + if tool.type in ( + "web_search_preview", + "code_interpreter", + "container", + "mcp", + ): # These are built-in tools that are added to the system message. 
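To illustrate the built-in vs. custom tool split used by `has_custom_tools` and `get_developer_message` above: only tool types outside `BUILTIN_TOOLS` count as custom, and only then is the `commentary` channel kept in the system message. A self-contained mirror of that check:

```python
# Self-contained mirror of the has_custom_tools helper above.
BUILTIN_TOOLS = {"web_search_preview", "code_interpreter", "container"}


def has_custom_tools(tool_types: list[str]) -> bool:
    return not set(tool_types).issubset(BUILTIN_TOOLS)


# Only built-in tools requested: no custom tools, so the "commentary"
# channel can be dropped from the system message.
assert has_custom_tools(["web_search_preview", "code_interpreter"]) is False

# A "function" tool makes the request custom, so its description is
# rendered into the developer message instead.
assert has_custom_tools(["code_interpreter", "function"]) is True
```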
+ # Adding in MCP for now until we support MCP tools executed + # server side pass + elif tool.type == "function": function_tools.append(tool) else: raise ValueError(f"tool type {tool.type} not supported") if function_tools: function_tool_descriptions = [ - ToolDescription.new( - name=tool.name, - description=tool.description, - parameters=tool.parameters, - ) for tool in function_tools + create_tool_definition(tool) for tool in function_tools ] dev_msg_content = dev_msg_content.with_function_tools( - function_tool_descriptions) + function_tool_descriptions + ) dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) return dev_msg @@ -98,7 +178,7 @@ def get_user_message(content: str) -> Message: def parse_response_input( response_msg: ResponseInputOutputItem, - prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]], ) -> Message: if not isinstance(response_msg, dict): response_msg = response_msg.model_dump() @@ -116,30 +196,32 @@ def parse_response_input( if isinstance(content, str): msg = Message.from_role_and_content(role, text_prefix + content) else: - contents = [ - TextContent(text=text_prefix + c["text"]) for c in content - ] + contents = [TextContent(text=text_prefix + c["text"]) for c in content] msg = Message.from_role_and_contents(role, contents) + if role == "assistant": + msg = msg.with_channel("final") elif response_msg["type"] == "function_call_output": call_id = response_msg["call_id"] - call_response: Optional[ResponseFunctionToolCall] = None + call_response: ResponseFunctionToolCall | None = None for prev_response in reversed(prev_responses): - if isinstance(prev_response, ResponseFunctionToolCall - ) and prev_response.call_id == call_id: + if ( + isinstance(prev_response, ResponseFunctionToolCall) + and prev_response.call_id == call_id + ): call_response = prev_response break if call_response is None: raise ValueError(f"No call message found for {call_id}") msg = Message.from_author_and_content( Author.new(Role.TOOL, f"functions.{call_response.name}"), - response_msg["output"]) + response_msg["output"], + ) elif response_msg["type"] == "reasoning": content = response_msg["content"] assert len(content) == 1 msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) elif response_msg["type"] == "function_call": - msg = Message.from_role_and_content(Role.ASSISTANT, - response_msg["arguments"]) + msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"]) msg = msg.with_channel("commentary") msg = msg.with_recipient(f"functions.{response_msg['name']}") msg = msg.with_content_type("json") @@ -148,22 +230,62 @@ def parse_response_input( return msg -def parse_chat_input(chat_msg) -> Message: - role = chat_msg["role"] - content = chat_msg["content"] +def parse_chat_input(chat_msg) -> list[Message]: + if not isinstance(chat_msg, dict): + # Handle Pydantic models + chat_msg = chat_msg.model_dump(exclude_none=True) + + role = chat_msg.get("role") + + # Assistant message with tool calls + tool_calls = chat_msg.get("tool_calls") + if role == "assistant" and tool_calls: + msgs: list[Message] = [] + for call in tool_calls: + func = call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "") or "" + msg = Message.from_role_and_content(Role.ASSISTANT, arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{name}") + msg = msg.with_content_type("json") + msgs.append(msg) + return msgs + 
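For reference, these are the two Chat Completions message shapes the tool-call branches of `parse_chat_input` handle; the function name and arguments below are made up for the example. An assistant `tool_calls` entry becomes a `commentary`-channel Harmony message addressed to `functions.<name>`, and the matching `tool` message becomes the corresponding tool output message.

```python
# Illustrative Chat Completions messages for the tool-call branches of
# parse_chat_input (name and arguments are invented for the example).
assistant_tool_call_msg = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            },
        }
    ],
}
# -> one Harmony ASSISTANT message on the "commentary" channel with
#    recipient "functions.get_weather" and the JSON arguments as content.

tool_result_msg = {
    "role": "tool",
    "name": "get_weather",
    "content": "Sunny, 22°C",
}
# -> one Harmony TOOL message authored as "functions.get_weather",
#    also routed to the "commentary" channel.
```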
+ # Tool role message (tool output) + if role == "tool": + name = chat_msg.get("name", "") + content = chat_msg.get("content", "") or "" + if isinstance(content, list): + # Handle array format for tool message content + # by concatenating all text parts. + content = "".join( + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ) + + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{name}"), content + ).with_channel("commentary") + return [msg] + + # Default: user/assistant/system messages with content + content = chat_msg.get("content", "") if isinstance(content, str): contents = [TextContent(text=content)] else: # TODO: Support refusal. - contents = [TextContent(text=c["text"]) for c in content] + contents = [TextContent(text=c.get("text", "")) for c in content] msg = Message.from_role_and_contents(role, contents) - return msg + return [msg] def render_for_completion(messages: list[Message]) -> list[int]: conversation = Conversation.from_messages(messages) token_ids = get_encoding().render_conversation_for_completion( - conversation, Role.ASSISTANT) + conversation, Role.ASSISTANT + ) return token_ids @@ -187,14 +309,18 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: # TODO: translate to url properly! if recipient == "browser.search": action = ActionSearch( - query=f"cursor:{browser_call.get('query', '')}", type="search") + query=f"cursor:{browser_call.get('query', '')}", type="search" + ) elif recipient == "browser.open": action = ActionOpenPage( - url=f"cursor:{browser_call.get('url', '')}", type="open_page") + url=f"cursor:{browser_call.get('url', '')}", type="open_page" + ) elif recipient == "browser.find": - action = ActionFind(pattern=browser_call["pattern"], - url=f"cursor:{browser_call.get('url', '')}", - type="find") + action = ActionFind( + pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find", + ) else: raise ValueError(f"Unknown browser action: {recipient}") web_search_item = ResponseFunctionWebSearch( @@ -211,15 +337,16 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=content.text, - type="reasoning_text") + ResponseReasoningTextContent( + text=content.text, type="reasoning_text" + ) ], status=None, ) output_items.append(reasoning_item) elif message.channel == "commentary": - if message.recipient.startswith("functions."): - function_name = message.recipient.split(".")[-1] + if recipient is not None and recipient.startswith("functions."): + function_name = recipient.split(".")[-1] for content in message.content: random_id = random_uuid() response_item = ResponseFunctionToolCall( @@ -227,25 +354,29 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: call_id=f"call_{random_id}", type="function_call", name=function_name, - id=f"ft_{random_id}", + id=f"fc_{random_id}", ) output_items.append(response_item) - elif message.recipient.startswith( - "python") or message.recipient.startswith("browser"): + elif recipient is not None and ( + recipient.startswith("python") + or recipient.startswith("browser") + or recipient.startswith("container") + ): for content in message.content: reasoning_item = ResponseReasoningItem( id=f"rs_{random_uuid()}", summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=content.text, - type="reasoning_text") + ResponseReasoningTextContent( + text=content.text, 
type="reasoning_text" + ) ], status=None, ) output_items.append(reasoning_item) else: - raise ValueError(f"Unknown recipient: {message.recipient}") + raise ValueError(f"Unknown recipient: {recipient}") elif message.channel == "final": contents = [] for content in message.content: @@ -269,15 +400,13 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: return output_items -def parse_remaining_state( - parser: StreamableParser) -> list[ResponseOutputItem]: +def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]: if not parser.current_content: return [] if parser.current_role != Role.ASSISTANT: return [] current_recipient = parser.current_recipient - if (current_recipient is not None - and current_recipient.startswith("browser.")): + if current_recipient is not None and current_recipient.startswith("browser."): return [] if parser.current_channel == "analysis": @@ -286,8 +415,9 @@ def parse_remaining_state( summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=parser.current_content, - type="reasoning_text") + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) ], status=None, ) @@ -303,7 +433,9 @@ def parse_remaining_state( id=f"msg_{random_uuid()}", content=[output_text], role="assistant", - status="completed", + # if the parser still has messages (ie if the generator got cut + # abruptly), this should be incomplete + status="incomplete", type="message", ) return [text_item] @@ -326,7 +458,8 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: def parse_chat_output( - token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: + token_ids: Sequence[int], +) -> tuple[str | None, str | None, bool]: parser = parse_output_into_messages(token_ids) output_msgs = parser.messages is_tool_call = False # TODO: update this when tool call is supported @@ -341,7 +474,6 @@ def parse_chat_output( else: reasoning_msg = output_msgs[:-1] final_msg = output_msgs[-1] - reasoning_content = "\n".join( - [msg.content[0].text for msg in reasoning_msg]) + reasoning_content = "\n".join([msg.content[0].text for msg in reasoning_msg]) final_content = final_msg.content[0].text return reasoning_content, final_content, is_tool_call diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 4e852ba594930..349437363c5b8 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -11,11 +11,11 @@ import uvicorn from fastapi import FastAPI, Request, Response from vllm import envs -from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient -from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, - H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) +from vllm.entrypoints.constants import ( + H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, +) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -24,10 +24,12 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) -async def serve_http(app: FastAPI, - sock: Optional[socket.socket], - enable_ssl_refresh: bool = False, - **uvicorn_kwargs: Any): +async def serve_http( + app: FastAPI, + sock: Optional[socket.socket], + enable_ssl_refresh: bool = False, + **uvicorn_kwargs: Any, +): """ Start a FastAPI app using Uvicorn, with support for 
custom Uvicorn config options. Supports http header limits via h11_max_incomplete_event_size and @@ -41,11 +43,12 @@ async def serve_http(app: FastAPI, if methods is None or path is None: continue - logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + logger.info("Route: %s, Methods: %s", path, ", ".join(methods)) # Extract header limit options if present h11_max_incomplete_event_size = uvicorn_kwargs.pop( - "h11_max_incomplete_event_size", None) + "h11_max_incomplete_event_size", None + ) h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None) # Set safe defaults if not provided @@ -64,16 +67,19 @@ async def serve_http(app: FastAPI, loop = asyncio.get_running_loop() - watchdog_task = loop.create_task( - watchdog_loop(server, app.state.engine_client)) - server_task = loop.create_task( - server.serve(sockets=[sock] if sock else None)) + watchdog_task = loop.create_task(watchdog_loop(server, app.state.engine_client)) + server_task = loop.create_task(server.serve(sockets=[sock] if sock else None)) - ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher( - ssl_context=config.ssl, - key_path=config.ssl_keyfile, - cert_path=config.ssl_certfile, - ca_path=config.ssl_ca_certs) + ssl_cert_refresher = ( + None + if not enable_ssl_refresh + else SSLCertRefresher( + ssl_context=config.ssl, + key_path=config.ssl_keyfile, + cert_path=config.ssl_certfile, + ca_path=config.ssl_ca_certs, + ) + ) def signal_handler() -> None: # prevents the uvicorn signal handler to exit early @@ -95,9 +101,12 @@ async def serve_http(app: FastAPI, port = uvicorn_kwargs["port"] process = find_process_using_port(port) if process is not None: - logger.debug( + logger.warning( "port %s is used by process %s launched with command:\n%s", - port, process, " ".join(process.cmdline())) + port, + process, + " ".join(process.cmdline()), + ) logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() finally: @@ -133,14 +142,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """ VLLM V1 AsyncLLM catches exceptions and returns only two types: EngineGenerateError and EngineDeadError. - + EngineGenerateError is raised by the per request generate() method. This error could be request specific (and therefore recoverable - e.g. if there is an error in input processing). - + EngineDeadError is raised by the background output_handler method. This error is global and therefore not recoverable. - + We register these @app.exception_handlers to return nice responses to the end user if they occur and shut down if needed. 
See https://fastapi.tiangolo.com/tutorial/handling-errors/ @@ -155,8 +164,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """ @app.exception_handler(RuntimeError) - @app.exception_handler(AsyncEngineDeadError) - @app.exception_handler(MQEngineDeadError) @app.exception_handler(EngineDeadError) @app.exception_handler(EngineGenerateError) async def runtime_exception_handler(request: Request, __): diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 728ed8328d36d..8f47c20f27e0a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,49 +9,75 @@ import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar +from typing_extensions import TypeVar, deprecated -import vllm.envs as envs -from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, - BeamSearchSequence, - create_sort_beams_key_function) -from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, - is_init_field) -from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, - PoolerConfig, RunnerOption) -from vllm.engine.llm_engine import LLMEngine -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages, - resolve_chat_template_content_format) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam, - _cosine_similarity, - _validate_score_input_lens, - compress_token_type_ids, - get_score_prompt) -# yapf: enable -from vllm.entrypoints.utils import (_validate_truncation_size, - log_non_default_args) -from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt +from vllm.beam_search import ( + BeamSearchInstance, + BeamSearchOutput, + BeamSearchSequence, + create_sort_beams_key_function, +) +from vllm.config import ( + CompilationConfig, + PoolerConfig, + StructuredOutputsConfig, + is_init_field, +) +from vllm.config.model import ( + ConvertOption, + HfOverrides, + ModelDType, + RunnerOption, + TokenizerMode, +) +from vllm.engine.arg_utils import EngineArgs +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages, + resolve_chat_template_content_format, +) +from vllm.entrypoints.score_utils import ( + ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + compress_token_type_ids, + get_score_prompt, +) +from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args +from vllm.inputs import ( + DataPrompt, + PromptType, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) +from vllm.inputs.parse import get_prompt_components from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, - PoolingRequestOutput, RequestOutput, - ScoringRequestOutput) +from vllm.outputs import ( + ClassificationRequestOutput, + EmbeddingRequestOutput, + PoolingRequestOutput, + RequestOutput, + ScoringRequestOutput, +) from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, - SamplingParams) +from vllm.sampling_params import 
BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask -from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, - get_cached_tokenizer) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + MistralTokenizer, + get_cached_tokenizer, +) from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, Device, is_list_of +from vllm.utils import Counter, Device, as_iter, is_list_of +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: @@ -85,6 +111,8 @@ class LLM: or videos from directories specified by the server file system. This is a security risk. Should only be enabled in trusted environments. + allowed_media_domains: If set, only media URLs that belong to this + domain can be used for multi-modal inputs. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, @@ -108,6 +136,14 @@ class LLM: values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, it may cause out-of- memory (OOM) errors. + kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default, + this is set to None and vllm can automatically infer the kv cache + size based on gpu_memory_utilization. However, users may want to + manually specify the kv cache memory size. kv_cache_memory_bytes + allows more fine-grain control of how much memory gets used when + compared with using gpu_memory_utilization. Note that + kv_cache_memory_bytes (when not-None) ignores + gpu_memory_utilization swap_space: The size (GiB) of CPU memory per GPU to use as swap space. This can be used for temporarily storing the states of the requests when their `best_of` sampling parameters are larger than 1. If all @@ -121,15 +157,8 @@ class LLM: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. Additionally for encoder-decoder models, if the - sequence length of the encoder input is larger than this, we fall - back to the eager mode. disable_custom_all_reduce: See [ParallelConfig][vllm.config.ParallelConfig]. - disable_async_output_proc: Disable async output processing. - This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -141,9 +170,11 @@ class LLM: multi-modal processor obtained from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. - override_pooler_config: Initialize non-default pooling config or - override default pooling config for the pooling model. - e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. + pooler_config: Initialize non-default pooling config for the pooling + model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. + override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This + argument is deprecated and will be removed in v0.12.0 or v1.0.0, + whichever is sooner. 
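
A usage sketch for the constructor arguments documented above; the checkpoint name, byte budget, and domain are placeholders. `kv_cache_memory_bytes` pins the KV cache to a fixed size instead of deriving it from `gpu_memory_utilization`, and `allowed_media_domains` restricts where multi-modal URLs may be fetched from.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",          # placeholder multi-modal model
    kv_cache_memory_bytes=4 * 1024**3,          # ~4 GiB of KV cache per GPU
    allowed_media_domains=["upload.wikimedia.org"],
)
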
compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. @@ -165,6 +196,7 @@ class LLM: skip_tokenizer_init: bool = False, trust_remote_code: bool = False, allowed_local_media_path: str = "", + allowed_media_domains: Optional[list[str]] = None, tensor_parallel_size: int = 1, dtype: ModelDType = "auto", quantization: Optional[QuantizationMethods] = None, @@ -175,18 +207,21 @@ class LLM: swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: bool = False, - max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, + pooler_config: Optional[PoolerConfig] = None, override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, dict[str, Any], - CompilationConfig]] = None, - logits_processors: Optional[list[Union[str, - type[LogitsProcessor]]]] = None, - **kwargs, + structured_outputs_config: Optional[ + Union[dict[str, Any], StructuredOutputsConfig] + ] = None, + kv_cache_memory_bytes: Optional[int] = None, + compilation_config: Optional[ + Union[int, dict[str, Any], CompilationConfig] + ] = None, + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None, + **kwargs: Any, ) -> None: """LLM constructor.""" @@ -201,21 +236,23 @@ class LLM: kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) if "kv_transfer_config" in kwargs and isinstance( - kwargs["kv_transfer_config"], dict): - from vllm.config import KVTransferConfig + kwargs["kv_transfer_config"], dict + ): + from vllm.config.kv_transfer import KVTransferConfig + raw_config_dict = kwargs["kv_transfer_config"] try: - kwargs["kv_transfer_config"] = KVTransferConfig( - **raw_config_dict) + kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict) except ValidationError as e: logger.error( "Failed to convert 'kv_transfer_config' dict to " "KVTransferConfig object. Dict: %s. Error: %s", - raw_config_dict, e) + raw_config_dict, + e, + ) # Consider re-raising a more specific vLLM error or ValueError # to provide better context to the user. 
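
The `compilation_config` / `structured_outputs_config` handling above filters user-supplied dicts down to constructor fields before building the config object. A simplified, self-contained version of that pattern, with a hypothetical `ExampleConfig` standing in for the real config classes:

from dataclasses import dataclass, fields
from typing import Any


@dataclass
class ExampleConfig:
    level: int = 0
    backend: str = "inductor"


def keep_init_fields(cls: type, raw: dict[str, Any]) -> dict[str, Any]:
    # Drop keys the dataclass constructor would reject.
    init_names = {f.name for f in fields(cls) if f.init}
    return {k: v for k, v in raw.items() if k in init_names}


cfg = ExampleConfig(**keep_init_fields(ExampleConfig, {"level": 3, "unknown": 1}))
assert cfg.level == 3 and cfg.backend == "inductor"
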
- raise ValueError( - f"Invalid 'kv_transfer_config' provided: {e}") from e + raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e if hf_overrides is None: hf_overrides = {} @@ -223,16 +260,35 @@ class LLM: if compilation_config is not None: if isinstance(compilation_config, int): compilation_config_instance = CompilationConfig( - level=compilation_config) + level=compilation_config + ) elif isinstance(compilation_config, dict): - predicate = lambda x: is_init_field(CompilationConfig, x[0]) compilation_config_instance = CompilationConfig( - **dict(filter(predicate, compilation_config.items()))) + **{ + k: v + for k, v in compilation_config.items() + if is_init_field(CompilationConfig, k) + } + ) else: compilation_config_instance = compilation_config else: compilation_config_instance = CompilationConfig() + if structured_outputs_config is not None: + if isinstance(structured_outputs_config, dict): + structured_outputs_instance = StructuredOutputsConfig( + **{ + k: v + for k, v in structured_outputs_config.items() + if is_init_field(StructuredOutputsConfig, k) + } + ) + else: + structured_outputs_instance = structured_outputs_config + else: + structured_outputs_instance = StructuredOutputsConfig() + engine_args = EngineArgs( model=model, runner=runner, @@ -242,6 +298,7 @@ class LLM: skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, allowed_local_media_path=allowed_local_media_path, + allowed_media_domains=allowed_media_domains, tensor_parallel_size=tensor_parallel_size, dtype=dtype, quantization=quantization, @@ -249,16 +306,17 @@ class LLM: tokenizer_revision=tokenizer_revision, seed=seed, gpu_memory_utilization=gpu_memory_utilization, + kv_cache_memory_bytes=kv_cache_memory_bytes, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, - max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, + pooler_config=pooler_config, override_pooler_config=override_pooler_config, + structured_outputs_config=structured_outputs_instance, compilation_config=compilation_config_instance, logits_processors=logits_processors, **kwargs, @@ -268,44 +326,41 @@ class LLM: # Create the Engine (autoselects V0 vs V1) self.llm_engine = LLMEngine.from_engine_args( - engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS + ) self.engine_class = type(self.llm_engine) self.request_counter = Counter() self.default_sampling_params: Union[dict[str, Any], None] = None - if envs.VLLM_USE_V1: - supported_tasks = self.llm_engine \ - .get_supported_tasks() # type: ignore - else: - supported_tasks = self.llm_engine.model_config.supported_tasks - - logger.info("Supported_tasks: %s", supported_tasks) - + supported_tasks = self.llm_engine.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) self.supported_tasks = supported_tasks - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.llm_engine.get_tokenizer_group().get_lora_tokenizer( - lora_request) + self.model_config = self.llm_engine.model_config + self.processor = self.llm_engine.processor + self.io_processor = self.llm_engine.io_processor + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer() + + @deprecated("`set_tokenizer` is deprecated and 
will be removed in v0.13.") def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: - tokenizer_group = self.llm_engine.get_tokenizer_group() - # While CachedTokenizer is dynamic, have no choice but # compare class name. Misjudgment will arise from # user-defined tokenizer started with 'Cached' if tokenizer.__class__.__name__.startswith("Cached"): - tokenizer_group.tokenizer = tokenizer + self.llm_engine.tokenizer = tokenizer else: - tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) + + def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + self.llm_engine.reset_mm_cache() def get_default_sampling_params(self) -> SamplingParams: if self.default_sampling_params is None: - self.default_sampling_params = ( - self.llm_engine.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() @@ -313,8 +368,9 @@ class LLM: def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, @@ -329,7 +385,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -352,36 +408,27 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "generate": raise ValueError( "LLM.generate() is only supported for generative models. " "Try passing `--runner generate` to use the model as a " - "generative model.") + "generative model." + ) if sampling_params is None: # Use default sampling params. 
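
For reference, the basic `generate()` call served by the code above; the checkpoint is a placeholder, and omitting `sampling_params` falls back to the model's default sampling parameters as shown in the branch that follows.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder checkpoint
params = SamplingParams(temperature=0.8, max_tokens=64)

for out in llm.generate(["The capital of France is"], params):
    print(out.outputs[0].text)
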
sampling_params = self.get_default_sampling_params() - tokenization_kwargs: dict[str, Any] = {} - truncate_prompt_tokens = None - if isinstance(sampling_params, SamplingParams): - truncate_prompt_tokens = sampling_params.truncate_prompt_tokens - - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) - # Add any modality specific loras to the corresponding prompts - lora_request = self._get_modality_specific_lora_reqs( - prompts, lora_request) + lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request) self._validate_and_add_requests( prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, priority=priority, ) @@ -389,46 +436,57 @@ class LLM: return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, prompts: Union[PromptType, Sequence[PromptType]], - lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): + self, + prompts: Union[PromptType, Sequence[PromptType]], + lora_request: Optional[Union[list[LoRARequest], LoRARequest]], + ): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. lora_config = self.llm_engine.vllm_config.lora_config # If there's no lora config / default_mm_loras, or the model # isn't multimodal, leave the lora as is. - if (lora_config is None - or not self.llm_engine.model_config.is_multimodal_model - or (lora_config and lora_config.default_mm_loras is None)): + if ( + lora_config is None + or not self.model_config.is_multimodal_model + or (lora_config and lora_config.default_mm_loras is None) + ): return lora_request if not isinstance(prompts, Sequence): prompts = [prompts] - optional_loras = ([lora_request] * len(prompts) - if not isinstance(lora_request, Sequence) else - lora_request) + optional_loras = ( + [lora_request] * len(prompts) + if not isinstance(lora_request, Sequence) + else lora_request + ) return [ self._resolve_single_prompt_mm_lora( prompt, opt_lora_req, lora_config.default_mm_loras, - ) for prompt, opt_lora_req in zip(prompts, optional_loras) + ) + for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, prompt: PromptType, - lora_request: Optional[LoRARequest], - default_mm_loras: Optional[dict[str, - str]]): - if (not default_mm_loras or not isinstance(prompt, dict) - or "multi_modal_data" not in prompt): + def _resolve_single_prompt_mm_lora( + self, + prompt: PromptType, + lora_request: Optional[LoRARequest], + default_mm_loras: Optional[dict[str, str]], + ): + if ( + not default_mm_loras + or not isinstance(prompt, dict) + or not (mm_data := prompt.get("multi_modal_data") or {}) + ): return lora_request - prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - - intersection = set(prompt["multi_modal_data"].keys()) \ - .intersection(default_mm_loras.keys()) + intersection = set( + mm_data.keys() # type: ignore + ).intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -438,7 +496,9 @@ class LLM: " used by a single prompt consuming several modalities; " " currently we only support one lora per request; as such," " lora(s) registered with modalities: %s" - " will be skipped", intersection) + " will be skipped", + intersection, + ) return lora_request # Build the LoRA request; the ID of the default mm lora is the @@ -454,7 +514,8 @@ class LLM: logger.warning( "A modality with a registered lora and a 
lora_request " "with a different ID were provided; falling back to the " - "lora_request as we only apply one LoRARequest per prompt") + "lora_request as we only apply one LoRARequest per prompt" + ) return lora_request return LoRARequest( @@ -463,11 +524,13 @@ class LLM: modality_lora_path, ) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: """ Execute an RPC call on all workers. @@ -497,9 +560,14 @@ class LLM: """ Run a function directly on the model inside each worker, returning the result for each of them. + + !!! warning + To reduce the overhead of data transfer, avoid returning large + arrays or tensors from this method. If you must return them, + make sure you move them to CPU first to avoid taking up additional + VRAM! """ - executor = self.llm_engine.model_executor - return executor.apply_model(func) + return self.llm_engine.apply_model(func) def _get_beam_search_lora_requests( self, @@ -507,10 +575,10 @@ class LLM: prompts: list[Union[TokensPrompt, TextPrompt]], ) -> list[Optional[LoRARequest]]: """Get the optional lora request corresponding to each prompt.""" - if isinstance(lora_request, - Sequence) and len(lora_request) != len(prompts): + if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts): raise ValueError( - "Lora request list should be the same length as the prompts") + "Lora request list should be the same length as the prompts" + ) if lora_request is None or isinstance(lora_request, LoRARequest): return [lora_request] * len(prompts) @@ -523,6 +591,7 @@ class LLM: params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, use_tqdm: bool = False, + concurrency_limit: Optional[int] = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -533,6 +602,8 @@ class LLM: params: The beam search parameters. lora_request: LoRA request to use for generation, if any. use_tqdm: Whether to use tqdm to display the progress bar. + concurrency_limit: The maximum number of concurrent requests. + If None, the number of concurrent requests is unlimited. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -542,8 +613,7 @@ class LLM: ignore_eos = params.ignore_eos length_penalty = params.length_penalty - lora_requests = self._get_beam_search_lora_requests( - lora_request, prompts) + lora_requests = self._get_beam_search_lora_requests(lora_request, prompts) tokenizer = self.get_tokenizer() sort_beams_key = create_sort_beams_key_function( @@ -551,25 +621,31 @@ class LLM: length_penalty, ) - def create_tokens_prompt_from_beam( - beam: BeamSearchSequence) -> TokensPrompt: - token_prompt_kwargs: TokensPrompt = { - "prompt_token_ids": beam.tokens - } + if use_tqdm and concurrency_limit is not None: + logger.warning( + "Progress bar is not supported when using concurrency_limit. " + "Disabling progress bar." 
+ ) + use_tqdm = False + + if concurrency_limit is None: + concurrency_limit = len(prompts) + + def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt: + token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens} if beam.multi_modal_data is not None: token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data if beam.mm_processor_kwargs is not None: - token_prompt_kwargs[ - "mm_processor_kwargs"] = beam.mm_processor_kwargs + token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs return TokensPrompt(**token_prompt_kwargs) # generate 2 * beam_width candidates at each step # following the huggingface transformers implementation # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) + beam_search_params = SamplingParams( + logprobs=2 * beam_width, max_tokens=1, temperature=temperature + ) instances: list[BeamSearchInstance] = [] for lora_req, prompt in zip(lora_requests, prompts): @@ -578,8 +654,7 @@ class LLM: if "multi_modal_data" in prompt: mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] if "mm_processor_kwargs" in prompt: - mm_kwargs["mm_processor_kwargs"] = prompt[ - "mm_processor_kwargs"] + mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"] if "prompt_token_ids" in prompt: prompt = cast(TokensPrompt, prompt) # Needed for mypy @@ -593,82 +668,98 @@ class LLM: lora_request=lora_req, logprobs=None, **mm_kwargs, - ), ) + ), + ) - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. It does not " - "reflect instance-level progress.") + for prompt_start in range(0, len(prompts), concurrency_limit): + instances_batch = instances[prompt_start : prompt_start + concurrency_limit] - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances), [])) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances)) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm( + token_iter, desc="Beam search", unit="token", unit_scale=False + ) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress." 
+ ) + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances_batch), []) + ) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances_batch + ) + ) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:]) + ) - if len(all_beams) == 0: - break + if len(all_beams) == 0: + break - # create the corresponding batch entries for prompt & optional lora - prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) + # create corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[ + (create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams + ] + ) - # only runs for one step - # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) + # only runs for one step + # we don't need to use tqdm here + output = self.generate( + prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch, + ) - for (start, end), instance in zip(instance_start_and_end, - instances): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] + for (start, end), instance in zip( + instance_start_and_end, instances_batch + ): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the max-model-len - # or abortion. we don't need to add it to the new beams. - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam. - mm_processor_kwargs) + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the + # max-model-len or abortion. we don't need to add + # it to the new beams. 
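
A sketch of driving beam search with the `concurrency_limit` batching implemented above; the prompt, model, and limits are placeholders.

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder checkpoint
params = BeamSearchParams(beam_width=4, max_tokens=32)

# Prompts are processed in slices of at most `concurrency_limit` requests,
# bounding how many beams run through the engine at once.
outputs = llm.beam_search(
    [{"prompt": "The capital of France is"}],
    params,
    concurrency_limit=8,
)
best = outputs[0].sequences[0]
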
+ logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) - instance.beams = sorted_beams[:beam_width] + if ( + token_id == tokenizer.eos_token_id + and not ignore_eos + ): + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted( + instance_new_beams, key=sort_beams_key, reverse=True + ) + instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: instance.completed.extend(instance.beams) - sorted_completed = sorted(instance.completed, - key=sort_beams_key, - reverse=True) + sorted_completed = sorted( + instance.completed, key=sort_beams_key, reverse=True + ) best_beams = sorted_completed[:beam_width] for beam in best_beams: @@ -677,12 +768,110 @@ class LLM: return outputs + def preprocess_chat( + self, + messages: Union[ + list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] + ], + chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[list[dict[str, Any]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[TokensPrompt]: + """ + Generate prompt for a chat conversation. The pre-processed + prompt can then be used as input for the other LLM methods. + + Refer to `chat` for a complete description of the arguments. + Returns: + A list of `TokensPrompts` objects containing the tokenized + prompt after chat template interpolation, and the + pre-processed multi-modal inputs. + """ + list_of_messages: list[list[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages) + else: + # messages is list[...] + list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] + + tokenizer = self.get_tokenizer() + model_config = self.model_config + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tools, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + prompts: list[TokensPrompt] = [] + + for msgs in list_of_messages: + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. 
+ conversation, mm_data, mm_uuids = parse_chat_messages( + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + if isinstance(tokenizer, MistralTokenizer): + prompt_token_ids = apply_mistral_chat_template( + tokenizer, + messages=msgs, + **_chat_template_kwargs, + ) + else: + prompt_str = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. + prompt_token_ids = tokenizer.encode( + prompt_str, add_special_tokens=False + ) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return prompts + def chat( self, - messages: Union[list[ChatCompletionMessageParam], - list[list[ChatCompletionMessageParam]]], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, + messages: Union[ + list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] + ], + sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, @@ -697,8 +886,8 @@ class LLM: Generate responses for a chat conversation. The chat conversation is converted into a text prompt using the - tokenizer and calls the [generate][] method to generate the - responses. + tokenizer and calls the [generate][vllm.LLM.generate] method to generate + the responses. Multi-modal inputs can be passed in the same way you would pass them to the OpenAI API. @@ -743,77 +932,17 @@ class LLM: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ - list_of_messages: list[list[ChatCompletionMessageParam]] - # Handle multi and single conversations - if is_list_of(messages, list): - # messages is list[list[...]] - list_of_messages = cast(list[list[ChatCompletionMessageParam]], - messages) - else: - # messages is list[...] - list_of_messages = [ - cast(list[ChatCompletionMessageParam], messages) - ] - - tokenizer = self.get_tokenizer(lora_request) - model_config = self.llm_engine.get_model_config() - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tools, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) - - _chat_template_kwargs: dict[str, Any] = dict( + prompts = self.preprocess_chat( + messages=messages, chat_template=chat_template, + chat_template_content_format=chat_template_content_format, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, tools=tools, + chat_template_kwargs=chat_template_kwargs, + mm_processor_kwargs=mm_processor_kwargs, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) - - prompts: list[Union[TokensPrompt, TextPrompt]] = [] - - for msgs in list_of_messages: - # NOTE: _parse_chat_message_content_parts() currently doesn't - # handle mm_processor_kwargs, since there is no implementation in - # the chat message parsing for it. 
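
A sketch of the split introduced above: `preprocess_chat()` renders chat messages into `TokensPrompt` objects that `generate()` (or `chat()`, which now wraps both steps) can consume. Model and message content are placeholders.

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder chat model
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain KV-cache reuse in one sentence."},
]

prompts = llm.preprocess_chat(messages)  # chat template applied here
outputs = llm.generate(prompts, SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)
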
- conversation, mm_data = parse_chat_messages( - msgs, - model_config, - tokenizer, - content_format=resolved_content_format, - ) - - if isinstance(tokenizer, MistralTokenizer): - prompt_token_ids = apply_mistral_chat_template( - tokenizer, - messages=msgs, - **_chat_template_kwargs, - ) - else: - prompt_str = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - # Special tokens are already included in chat templates so - # should not be added by the tokenizer in this case. - prompt_token_ids = tokenizer.encode(prompt_str, - add_special_tokens=False) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - - if mm_data is not None: - prompt["multi_modal_data"] = mm_data - - if mm_processor_kwargs is not None: - prompt["mm_processor_kwargs"] = mm_processor_kwargs - - prompts.append(prompt) return self.generate( prompts, @@ -824,9 +953,8 @@ class LLM: def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -844,7 +972,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -853,6 +981,8 @@ class LLM: If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. pooling_task: Override the pooling task to use. + tokenization_kwargs: overrides tokenization_kwargs set in + pooling_params Returns: A list of `PoolingRequestOutput` objects containing the @@ -863,64 +993,96 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ + + if self.supported_tasks == ["encode"] and pooling_task is None: + pooling_task = "encode" + if pooling_task is None: - if "embed" in self.supported_tasks: - pooling_task = "embed" - else: - pooling_task = "encode" + pooling_task = "embed" if "embed" in self.supported_tasks else "encode" logger.warning_once( "`LLM.encode` is currently using `pooling_task = %s`.\n" "Please use one of the more specific methods or set the " "task directly when using `LLM.encode`:\n" " - For embeddings, use `LLM.embed(...)` " - "or `pooling_task=\"embed\"`.\n" + 'or `pooling_task="embed"`.\n' " - For classification logits, use `LLM.classify(...)` " - "or `pooling_task=\"classify\"`.\n" + 'or `pooling_task="classify"`.\n' " - For rewards, use `LLM.reward(...)` " - "or `pooling_task=\"reward\"`\n" + 'or `pooling_task="reward"`\n' " - For similarity scores, use `LLM.score(...)`.", - pooling_task) + pooling_task, + ) - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "pooling": raise ValueError( "LLM.encode() is only supported for pooling models. " "Try passing `--runner pooling` to use the model as a " - "pooling model.") + "pooling model." 
+ ) if pooling_task not in self.supported_tasks: - raise ValueError( - f"pooling_task must be one of {self.supported_tasks}.") + raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - if isinstance(pooling_params, PoolingParams): - pooling_params.verify(pooling_task, model_config) - else: - for pooling_param in pooling_params: - pooling_param.verify(pooling_task, model_config) + for param in as_iter(pooling_params): + param.verify(pooling_task, model_config) + # for backwards compatibility + if truncate_prompt_tokens is not None: + param.truncate_prompt_tokens = truncate_prompt_tokens - if tokenization_kwargs is None: - tokenization_kwargs = dict[str, Any]() - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, - tokenization_kwargs) + io_processor_prompt = False + if isinstance(prompts, dict) and "data" in prompts: + io_processor_prompt = True + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details." + ) + + # Validate the request data is valid for the loaded plugin + validated_prompt = self.io_processor.parse_request(prompts) + + # obtain the actual model prompts from the pre-processor + prompts = self.io_processor.pre_process(prompt=validated_prompt) self._validate_and_add_requests( prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, ) outputs = self._run_engine(use_tqdm=use_tqdm) - return self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + + model_outputs = self.engine_class.validate_outputs( + outputs, PoolingRequestOutput + ) + + if io_processor_prompt: + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process( + model_output=model_outputs + ) + + return [ + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + prompt_token_ids=[], + finished=True, + ) + ] + else: + return model_outputs def embed( self, @@ -928,8 +1090,7 @@ class LLM: *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[EmbeddingRequestOutput]: """ @@ -942,7 +1103,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: If `True`, shows a tqdm progress bar. @@ -958,7 +1119,8 @@ class LLM: if "embed" not in self.supported_tasks: raise ValueError( "Embedding API is not supported by this model. " - "Try converting the model using `--convert embed`.") + "Try converting the model using `--convert embed`." 
+ ) items = self.encode( prompts, @@ -976,8 +1138,7 @@ class LLM: prompts: Union[PromptType, Sequence[PromptType]], *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ClassificationRequestOutput]: """ @@ -990,7 +1151,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. @@ -1005,7 +1166,8 @@ class LLM: if "classify" not in self.supported_tasks: raise ValueError( "Classification API is not supported by this model. " - "Try converting the model using `--convert classify`.") + "Try converting the model using `--convert classify`." + ) items = self.encode( prompts, @@ -1024,8 +1186,7 @@ class LLM: *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[PoolingRequestOutput]: """ @@ -1034,7 +1195,7 @@ class LLM: Args: prompts: The prompts to the LLM. You may pass a sequence of prompts for batch inference. See [PromptType][vllm.inputs.PromptType] - for more details about the format of each prompts. + for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. 
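
A usage sketch for the pooling entry points above; the checkpoint name is a placeholder for any embedding-capable model.

from vllm import LLM

emb_llm = LLM(model="intfloat/e5-small-v2", runner="pooling")  # placeholder
items = emb_llm.embed(["vLLM makes LLM serving easy."])
vector = items[0].outputs.embedding  # one list[float] per input prompt
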
@@ -1066,7 +1227,6 @@ class LLM: pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: - encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, truncate_prompt_tokens=truncate_prompt_tokens, @@ -1076,20 +1236,17 @@ class LLM: pooling_task="embed", ) - encoded_output_1: list[PoolingRequestOutput] = encoded_output[ - 0:len(text_1)] - encoded_output_2: list[PoolingRequestOutput] = encoded_output[ - len(text_1):] + encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] + encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :] if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - scores = _cosine_similarity(tokenizer=tokenizer, - embed_1=encoded_output_1, - embed_2=encoded_output_2) + scores = _cosine_similarity( + tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2 + ) - items = self.engine_class.validate_outputs(scores, - PoolingRequestOutput) + items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] def _cross_encoding_score( @@ -1102,11 +1259,10 @@ class LLM: pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: - model_config = self.llm_engine.model_config + model_config = self.model_config if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "Score API is not supported for Mistral tokenizer") + raise ValueError("Score API is not supported for Mistral tokenizer") if len(data_1) == 1: data_1 = data_1 * len(data_2) @@ -1114,21 +1270,19 @@ class LLM: if pooling_params is None: pooling_params = PoolingParams(task="score") - model_config = self.llm_engine.model_config pooling_params.verify("score", model_config) pooling_params_list = list[PoolingParams]() tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) + _validate_truncation_size( + model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs + ) prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - model_config = self.llm_engine.model_config - for q, d in input_pairs: _, engine_prompt = get_score_prompt( model_config=model_config, @@ -1138,8 +1292,7 @@ class LLM: tokenization_kwargs=tokenization_kwargs, ) - if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( - "token_type_ids", None)): + if token_type_ids := engine_prompt.pop("token_type_ids", None): params = pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) params.extra_kwargs = {"compressed_token_type_ids": compressed} @@ -1157,17 +1310,14 @@ class LLM: ) outputs = self._run_engine(use_tqdm=use_tqdm) - items = self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] def score( self, - data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], - ScoreMultiModalParam], - data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], - ScoreMultiModalParam], + data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], + data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], /, *, truncate_prompt_tokens: Optional[int] = None, 
@@ -1210,22 +1360,27 @@ class LLM: A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. """ - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "pooling": raise ValueError( "LLM.score() is only supported for pooling models. " "Try passing `--runner pooling` to use the model as a " - "pooling model.") + "pooling model." + ) supported_tasks = self.supported_tasks if all(t not in supported_tasks for t in ("embed", "classify")): - raise ValueError("Score API is not supported by this model. " - "Try converting the model using " - "`--convert embed` or `--convert classify`.") + raise ValueError( + "Score API is not supported by this model. " + "Try converting the model using " + "`--convert embed` or `--convert classify`." + ) - if (model_config.is_cross_encoder - and getattr(model_config.hf_config, "num_labels", 0) != 1): + if ( + model_config.is_cross_encoder + and getattr(model_config.hf_config, "num_labels", 0) != 1 + ): raise ValueError("Score API is only enabled for num_labels == 1.") # the tokenizer for models such as @@ -1235,12 +1390,16 @@ class LLM: if not model_config.is_multimodal_model: - def check_data_type(data: Union[SingletonPrompt, - Sequence[SingletonPrompt], - ScoreMultiModalParam]): + def check_data_type( + data: Union[ + SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam + ], + ): if isinstance(data, dict) and "content" in data: - raise ValueError("ScoreMultiModalParam is not supported " - f"for {model_config.architecture}") + raise ValueError( + "ScoreMultiModalParam is not supported " + f"for {model_config.architecture}" + ) check_data_type(data_1) check_data_type(data_2) @@ -1248,11 +1407,13 @@ class LLM: def ensure_str(prompt: SingletonPrompt): if isinstance(prompt, dict): if "multi_modal_data" in prompt: - raise ValueError("Multi-modal prompt is not " - "supported for scoring") + raise ValueError( + "Multi-modal prompt is not supported for scoring" + ) elif "prompt_token_ids" in prompt: prompt = tokenizer.decode( - cast(TokensPrompt, prompt)["prompt_token_ids"]) + cast(TokensPrompt, prompt)["prompt_token_ids"] + ) elif "prompt" in prompt: prompt = cast(TextPrompt, prompt)["prompt"] assert type(prompt) is str @@ -1290,7 +1451,8 @@ class LLM: truncate_prompt_tokens, use_tqdm, pooling_params, - lora_request) + lora_request, + ) else: return self._embedding_score( tokenizer, @@ -1299,7 +1461,8 @@ class LLM: truncate_prompt_tokens, use_tqdm, pooling_params, - lora_request) + lora_request, + ) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -1334,8 +1497,8 @@ class LLM: def wake_up(self, tags: Optional[list[str]] = None): """ - Wake up the engine from sleep mode. See the [sleep][] method - for more details. + Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep] + method for more details. Args: tags: An optional list of tags to reallocate the engine memory @@ -1356,35 +1519,35 @@ class LLM: Note: This method is only available with the V1 LLM engine. 
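
A sketch of the scoring path validated above, assuming a cross-encoder style reranker checkpoint (the name is a placeholder):

from vllm import LLM

score_llm = LLM(model="BAAI/bge-reranker-base", runner="pooling")  # placeholder
outputs = score_llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Nile flows through Egypt."],
)
scores = [o.outputs.score for o in outputs]
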
""" - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], - params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, - Sequence[PoolingParams]], + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], + params: Union[ + SamplingParams, + Sequence[SamplingParams], + PoolingParams, + Sequence[PoolingParams], + ], *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], - tokenization_kwargs: Optional[dict[str, Any]] = None, priority: Optional[list[int]] = None, ) -> None: if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + prompts = [prompts] # type: ignore[list-item] num_requests = len(prompts) if isinstance(params, Sequence) and len(params) != num_requests: - raise ValueError("The lengths of prompts and params " - "must be the same.") - if isinstance(lora_request, - Sequence) and len(lora_request) != num_requests: - raise ValueError("The lengths of prompts and lora_request " - "must be the same.") + raise ValueError("The lengths of prompts and params must be the same.") + if isinstance(lora_request, Sequence) and len(lora_request) != num_requests: + raise ValueError( + "The lengths of prompts and lora_request must be the same." + ) - for sp in params if isinstance(params, Sequence) else (params, ): + for sp in params if isinstance(params, Sequence) else (params,): if isinstance(sp, SamplingParams): # We only care about the final output sp.output_kind = RequestOutputKind.FINAL_ONLY @@ -1396,37 +1559,125 @@ class LLM: it = tqdm_func(it, desc="Adding requests") for i, prompt in enumerate(it): + if isinstance(prompt, dict): + self._validate_mm_data_and_uuids( + prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids") + ) + self._add_request( prompt, params[i] if isinstance(params, Sequence) else params, - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request[i] if isinstance( - lora_request, Sequence) else lora_request, + lora_request=lora_request[i] + if isinstance(lora_request, Sequence) + else lora_request, priority=priority[i] if priority else 0, ) - def _add_request( + def _validate_mm_data_and_uuids( self, - prompt: PromptType, + multi_modal_data: Optional[Any], # MultiModalDataDict + multi_modal_uuids: Optional[Any], # MultiModalUUIDDict + ): + """ + Validate that if any multi-modal data is skipped (i.e. None), + then its corresponding UUID must be set. 
+ """ + if multi_modal_data is None: + return + + for modality, data in multi_modal_data.items(): + if isinstance(data, list): + for i, d in enumerate(data): + if d is None: + if ( + multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[ # noqa: E501 + modality + ] + is None + ): + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided" + ) + else: + if ( + len(multi_modal_uuids[modality]) <= i + or multi_modal_uuids[modality][i] is None + ): + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided" + ) + else: + if data is None and ( + multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[modality] is None + ): + raise ValueError( + f"Multi-modal data for {modality} is None" + f" but UUID is not provided" + ) + + def _process_inputs( + self, + request_id: str, + engine_prompt: PromptType, params: Union[SamplingParams, PoolingParams], - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - priority: int = 0, - ) -> None: - request_id = str(next(self.request_counter)) - self.llm_engine.add_request( + *, + lora_request: Optional[LoRARequest], + priority: int, + ) -> tuple[EngineCoreRequest, dict[str, Any]]: + """Use the Processor to process inputs for LLMEngine.""" + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size( + self.model_config.max_model_len, + params.truncate_prompt_tokens, + tokenization_kwargs, + ) + + engine_request = self.processor.process_inputs( request_id, - prompt, + engine_prompt, params, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, priority=priority, ) + return engine_request, tokenization_kwargs + + def _add_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest] = None, + priority: int = 0, + ) -> None: + prompt_text, _, _ = get_prompt_components(prompt) + request_id = str(next(self.request_counter)) + + engine_request, tokenization_kwargs = self._process_inputs( + request_id, + prompt, + params, + lora_request=lora_request, + priority=priority, + ) + + self.llm_engine.add_request( + request_id, + engine_request, + params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + priority=priority, + prompt_text=prompt_text, + ) def _run_engine( - self, - *, - use_tqdm: Union[bool, Callable[..., tqdm]] = True + self, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True ) -> list[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -1436,8 +1687,7 @@ class LLM: total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s"), + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), ) # Run the engine. @@ -1457,12 +1707,13 @@ class LLM: total_in_toks += len(output.prompt_token_ids) * n in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs) - out_spd = (total_out_toks / - pbar.format_dict["elapsed"]) + len(stp.token_ids) for stp in output.outputs + ) + out_spd = total_out_toks / pbar.format_dict["elapsed"] pbar.postfix = ( f"est. 
speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") + f"output: {out_spd:.2f} toks/s" + ) pbar.update(n) else: pbar.update(1) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 152d11c84ea02..96a84668e92b3 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -15,7 +15,6 @@ logger = init_logger(__name__) class RequestLogger: - def __init__(self, *, max_log_len: Optional[int]) -> None: self.max_log_len = max_log_len @@ -25,8 +24,7 @@ class RequestLogger: prompt: Optional[str], prompt_token_ids: Optional[list[int]], prompt_embeds: Optional[torch.Tensor], - params: Optional[Union[SamplingParams, PoolingParams, - BeamSearchParams]], + params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], ) -> None: max_log_len = self.max_log_len @@ -41,9 +39,14 @@ class RequestLogger: "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " "prompt_embeds shape: %s, " - "lora_request: %s.", request_id, prompt, params, prompt_token_ids, + "lora_request: %s.", + request_id, + prompt, + params, + prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, - lora_request) + lora_request, + ) def log_outputs( self, @@ -65,8 +68,7 @@ class RequestLogger: stream_info = "" if is_streaming: - stream_info = (" (streaming delta)" - if delta else " (streaming complete)") + stream_info = " (streaming delta)" if delta else " (streaming complete)" logger.info( "Generated response %s%s: output: %r, " diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 14ba8aa641837..5d5baad00da16 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,30 +2,30 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import atexit import gc +import hashlib import importlib import inspect import json import multiprocessing import multiprocessing.forkserver as forkserver import os +import secrets import signal import socket import tempfile import uuid from argparse import Namespace -from collections.abc import AsyncIterator, Awaitable +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable from contextlib import asynccontextmanager -from functools import partial from http import HTTPStatus -from typing import Annotated, Any, Callable, Optional +from typing import Annotated, Any, Callable, Literal, Optional import prometheus_client import pydantic import regex as re import uvloop -from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -40,80 +40,92 @@ from typing_extensions import assert_never import vllm.envs as envs from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (load_chat_template, - resolve_hf_chat_template, - resolve_mistral_chat_template) +from vllm.entrypoints.chat_utils import ( + load_chat_template, + resolve_hf_chat_template, + 
resolve_mistral_chat_template, +) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, - ClassificationRequest, - ClassificationResponse, - CompletionRequest, - CompletionResponse, - DetokenizeRequest, - DetokenizeResponse, - EmbeddingRequest, - EmbeddingResponse, ErrorInfo, - ErrorResponse, - LoadLoRAAdapterRequest, - PoolingRequest, PoolingResponse, - RerankRequest, RerankResponse, - ResponsesRequest, - ResponsesResponse, ScoreRequest, - ScoreResponse, TokenizeRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest, - TranslationResponse, - UnloadLoRAAdapterRequest) -# yapf: enable +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, + EmbeddingResponse, + ErrorInfo, + ErrorResponse, + IOProcessorResponse, + LoadLoRAAdapterRequest, + PoolingRequest, + PoolingResponse, + RerankRequest, + RerankResponse, + ResponsesRequest, + ResponsesResponse, + ScoreRequest, + ScoreResponse, + StreamingResponsesResponse, + TokenizeRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, + TranslationResponse, + UnloadLoRAAdapterRequest, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_classification import ( - ServingClassification) +from vllm.entrypoints.openai.serving_classification import ServingClassification from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - LoRAModulePath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import ( + BaseModelPath, + LoRAModulePath, + OpenAIServingModels, +) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores -from vllm.entrypoints.openai.serving_tokenization import ( - OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.entrypoints.openai.serving_transcription import ( - OpenAIServingTranscription, OpenAIServingTranslation) + OpenAIServingTranscription, + OpenAIServingTranslation, +) from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer, - ToolServer) -from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, - log_non_default_args, with_cancellation) +from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer +from vllm.entrypoints.utils import ( + cli_env_setup, + load_aware_call, + log_non_default_args, + with_cancellation, +) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from 
vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs, - get_open_zmq_ipc_path, is_valid_ipv6_address, - set_ulimit) +from vllm.utils import ( + Device, + FlexibleArgumentParser, + decorate_logs, + is_valid_ipv6_address, + set_ulimit, +) +from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger('vllm.entrypoints.openai.api_server') +logger = init_logger("vllm.entrypoints.openai.api_server") _running_tasks: set[asyncio.Task] = set() @@ -157,12 +169,11 @@ async def build_async_engine_client( disable_frontend_multiprocessing: Optional[bool] = None, client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: - if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": # The executor is expected to be mp. # Pre-import heavy modules in the forkserver process logger.debug("Setup forkserver with pre-imports") - multiprocessing.set_start_method('forkserver') + multiprocessing.set_start_method("forkserver") multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"]) forkserver.ensure_running() logger.debug("Forkserver setup complete!") @@ -170,16 +181,18 @@ async def build_async_engine_client( # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) + if client_config: + engine_args._api_process_count = client_config.get("client_count", 1) + engine_args._api_process_rank = client_config.get("client_index", 0) if disable_frontend_multiprocessing is None: - disable_frontend_multiprocessing = bool( - args.disable_frontend_multiprocessing) + disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing) async with build_async_engine_client_from_engine_args( - engine_args, - usage_context=usage_context, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, - client_config=client_config, + engine_args, + usage_context=usage_context, + disable_frontend_multiprocessing=disable_frontend_multiprocessing, + client_config=client_config, ) as engine: yield engine @@ -204,150 +217,50 @@ async def build_async_engine_client_from_engine_args( vllm_config = engine_args.create_engine_config(usage_context=usage_context) # V1 AsyncLLM. - if envs.VLLM_USE_V1: - if disable_frontend_multiprocessing: - logger.warning( - "V1 is enabled, but got --disable-frontend-multiprocessing. 
" - "To disable frontend multiprocessing, set VLLM_USE_V1=0.") + assert envs.VLLM_USE_V1 - from vllm.v1.engine.async_llm import AsyncLLM - async_llm: Optional[AsyncLLM] = None - client_count = client_config.pop( - "client_count") if client_config else 1 - client_index = client_config.pop( - "client_index") if client_config else 0 - try: - async_llm = AsyncLLM.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - enable_log_requests=engine_args.enable_log_requests, - disable_log_stats=engine_args.disable_log_stats, - client_addresses=client_config, - client_count=client_count, - client_index=client_index) + if disable_frontend_multiprocessing: + logger.warning( + "V1 is enabled, but got --disable-frontend-multiprocessing. " + "To disable frontend multiprocessing, set VLLM_USE_V1=0." + ) - # Don't keep the dummy data in memory - await async_llm.reset_mm_cache() + from vllm.v1.engine.async_llm import AsyncLLM - yield async_llm - finally: - if async_llm: - async_llm.shutdown() + async_llm: Optional[AsyncLLM] = None - # V0 AsyncLLM. - elif (MQLLMEngineClient.is_unsupported_config(vllm_config) - or disable_frontend_multiprocessing): + # Don't mutate the input client_config + client_config = dict(client_config) if client_config else {} + client_count = client_config.pop("client_count", 1) + client_index = client_config.pop("client_index", 0) - engine_client: Optional[EngineClient] = None - try: - engine_client = AsyncLLMEngine.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - enable_log_requests=engine_args.enable_log_requests, - disable_log_stats=engine_args.disable_log_stats) - yield engine_client - finally: - if engine_client and hasattr(engine_client, "shutdown"): - engine_client.shutdown() + try: + async_llm = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + enable_log_requests=engine_args.enable_log_requests, + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_count=client_count, + client_index=client_index, + ) - # V0MQLLMEngine. - else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - # Make TemporaryDirectory for prometheus multiprocessing - # Note: global TemporaryDirectory will be automatically - # cleaned up upon exit. - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - else: - logger.warning( - "Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") + # Don't keep the dummy data in memory + await async_llm.reset_mm_cache() - # Select random path for IPC. - ipc_path = get_open_zmq_ipc_path() - logger.debug("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) - - # Start RPCServer in separate process (holds the LLMEngine). - # the current process might have CUDA context, - # so we need to spawn a new process - context = multiprocessing.get_context("spawn") - - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - # The Process can raise an exception during startup, which may - # not actually result in an exitcode being reported. As a result - # we use a shared variable to communicate the information. 
- engine_alive = multiprocessing.Value('b', True, lock=False) - engine_process = context.Process( - target=run_mp_engine, - args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, - engine_args.disable_log_stats, - engine_args.enable_log_requests, engine_alive)) - engine_process.start() - engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start." - logger.info("Started engine process with PID %d", engine_pid) - - def _cleanup_ipc_path(): - socket_path = ipc_path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - # Ensure we clean up the local IPC socket file on exit. - atexit.register(_cleanup_ipc_path) - - # Build RPCClient, which conforms to EngineClient Protocol. - build_client = partial(MQLLMEngineClient, ipc_path, vllm_config, - engine_pid) - mq_engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_client) - try: - while True: - try: - await mq_engine_client.setup() - break - except TimeoutError: - if (not engine_process.is_alive() - or not engine_alive.value): - raise RuntimeError( - "Engine process failed to start. See stack " - "trace for the root cause.") from None - - yield mq_engine_client # type: ignore[misc] - finally: - # Ensure rpc server process was terminated - engine_process.terminate() - - # Close all open connections to the backend - mq_engine_client.close() - - # Wait for engine process to join - engine_process.join(4) - if engine_process.exitcode is None: - # Kill if taking longer than 5 seconds to stop - engine_process.kill() - - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import multiprocess - multiprocess.mark_process_dead(engine_process.pid) + yield async_llm + finally: + if async_llm: + async_llm.shutdown() async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() media_type = content_type.split(";", maxsplit=1)[0] if media_type != "application/json": - raise RequestValidationError(errors=[ - "Unsupported Media Type: Only 'application/json' is allowed" - ]) + raise RequestValidationError( + errors=["Unsupported Media Type: Only 'application/json' is allowed"] + ) router = APIRouter() @@ -446,8 +359,11 @@ def engine_client(request: Request) -> EngineClient: @router.get("/health", response_class=Response) async def health(raw_request: Request) -> Response: """Health check.""" - await engine_client(raw_request).check_health() - return Response(status_code=200) + try: + await engine_client(raw_request).check_health() + return Response(status_code=200) + except EngineDeadError: + return Response(status_code=503) @router.get("/load") @@ -466,8 +382,7 @@ async def get_server_load_metrics(request: Request): # - /rerank # - /v1/rerank # - /v2/rerank - return JSONResponse( - content={'server_load': request.app.state.server_load_metrics}) + return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) @router.get("/ping", response_class=Response) @@ -477,22 +392,16 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - 
HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_IMPLEMENTED.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -500,34 +409,33 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): try: generator = await handler.create_tokenize(request, raw_request) except NotImplementedError as e: - raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) + ) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -537,12 +445,14 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): except OverflowError as e: raise RequestValidationError(errors=[str(e)]) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) @@ -551,15 +461,18 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): def maybe_register_tokenizer_info_endpoint(args): """Conditionally register the tokenizer info endpoint if enabled.""" - if getattr(args, 'enable_tokenizer_info_endpoint', False): + if getattr(args, "enable_tokenizer_info_endpoint", False): @router.get("/tokenizer_info") async def get_tokenizer_info(raw_request: Request): """Get comprehensive tokenizer information.""" result = await tokenization(raw_request).get_tokenizer_info() - return 
JSONResponse(content=result.model_dump(), - status_code=result.error.code if isinstance( - result, ErrorResponse) else 200) + return JSONResponse( + content=result.model_dump(), + status_code=result.error.code + if isinstance(result, ErrorResponse) + else 200, + ) @router.get("/v1/models") @@ -576,61 +489,88 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/responses", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +async def _convert_stream_to_sse_events( + generator: AsyncGenerator[StreamingResponsesResponse, None], +) -> AsyncGenerator[str, None]: + """Convert the generator to a stream of events in SSE format""" + async for event in generator: + event_type = getattr(event, "type", "unknown") + # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + event_data = ( + f"event: {event_type}\ndata: {event.model_dump_json(indent=None)}\n\n" + ) + yield event_data + + +@router.post( + "/v1/responses", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def create_responses(request: ResponsesRequest, raw_request: Request): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: generator = await handler.create_responses(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ResponsesResponse): return JSONResponse(content=generator.model_dump()) - return StreamingResponse(content=generator, media_type="text/event-stream") + + return StreamingResponse( + content=_convert_stream_to_sse_events(generator), media_type="text/event-stream" + ) @router.get("/v1/responses/{response_id}") -async def retrieve_responses(response_id: str, raw_request: Request): +async def retrieve_responses( + response_id: str, + raw_request: Request, + starting_after: Optional[int] = None, + stream: Optional[bool] = False, +): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: - response = await handler.retrieve_responses(response_id) + response = await handler.retrieve_responses( + response_id, + starting_after=starting_after, + stream=stream, + ) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) - return JSONResponse(content=response.model_dump()) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + elif isinstance(response, ResponsesResponse): + return JSONResponse(content=response.model_dump()) + return StreamingResponse( + content=_convert_stream_to_sse_events(response), media_type="text/event-stream" + ) @router.post("/v1/responses/{response_id}/cancel") @@ -638,54 +578,51 @@ async def cancel_responses(response_id: str, raw_request: Request): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: response = await handler.cancel_responses(response_id) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) return JSONResponse(content=response.model_dump()) -@router.post("/v1/chat/completions", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - } - }) +@router.post( + "/v1/chat/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_chat_completion(request: ChatCompletionRequest, - raw_request: Request): +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Chat Completions API") + message="The model does not support Chat Completions API" + ) try: generator = await handler.create_chat_completion(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -693,109 +630,107 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions", - dependencies=[Depends(validate_json_request)], - responses={ - 
HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Completions API") + message="The model does not support Completions API" + ) try: generator = await handler.create_completion(request, raw_request) except OverflowError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e) + ) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/embeddings", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Embeddings API") + message="The model does not support Embeddings API" + ) try: generator = await handler.create_embedding(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/pooling", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + 
"/pooling", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Pooling API") + message="The model does not support Pooling API" + ) try: generator = await handler.create_pooling(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) - elif isinstance(generator, PoolingResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): return JSONResponse(content=generator.model_dump()) assert_never(generator) @@ -804,21 +739,23 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): @router.post("/classify", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call -async def create_classify(request: ClassificationRequest, - raw_request: Request): +async def create_classify(request: ClassificationRequest, raw_request: Request): handler = classify(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Classification API") + message="The model does not support Classification API" + ) try: generator = await handler.create_classify(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ClassificationResponse): return JSONResponse(content=generator.model_dump()) @@ -826,96 +763,90 @@ async def create_classify(request: ClassificationRequest, assert_never(generator) -@router.post("/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Score API") + message="The model does not support Score API" + ) try: generator = await handler.create_score(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ScoreResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/v1/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( "To indicate that Score API is not part of standard OpenAI API, we " - "have moved it to `/score`. Please update your client accordingly.") + "have moved it to `/score`. Please update your client accordingly." + ) return await create_score(request, raw_request) -@router.post("/v1/audio/transcriptions", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/transcriptions", + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_transcriptions(raw_request: Request, - request: Annotated[TranscriptionRequest, - Form()]): +async def create_transcriptions( + raw_request: Request, request: Annotated[TranscriptionRequest, Form()] +): handler = transcription(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Transcriptions API") + message="The model does not support Transcriptions API" + ) audio_data = await request.file.read() try: - generator = await handler.create_transcription(audio_data, request, - raw_request) + generator = await handler.create_transcription(audio_data, request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TranscriptionResponse): return JSONResponse(content=generator.model_dump()) @@ -923,44 +854,38 @@ async def create_transcriptions(raw_request: Request, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/audio/translations", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - 
HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/translations", + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_translations(request: Annotated[TranslationRequest, - Form()], - raw_request: Request): +async def create_translations( + request: Annotated[TranslationRequest, Form()], raw_request: Request +): handler = translation(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Translations API") + message="The model does not support Translations API" + ) audio_data = await request.file.read() try: - generator = await handler.create_translation(audio_data, request, - raw_request) + generator = await handler.create_translation(audio_data, request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TranslationResponse): return JSONResponse(content=generator.model_dump()) @@ -968,79 +893,90 @@ async def create_translations(request: Annotated[TranslationRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Rerank (Score) API") + message="The model does not support Rerank (Score) API" + ) try: generator = await handler.do_rerank(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, RerankResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/v1/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/rerank", + 
dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( "To indicate that the rerank API is not part of the standard OpenAI" " API, we have located it at `/rerank`. Please update your client " - "accordingly. (Note: Conforms to JinaAI rerank API)") + "accordingly. (Note: Conforms to JinaAI rerank API)" + ) return await do_rerank(request, raw_request) -@router.post("/v2/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v2/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) if envs.VLLM_SERVER_DEV_MODE: - logger.warning("SECURITY WARNING: Development endpoints are enabled! " - "This should NOT be used in production!") + logger.warning( + "SECURITY WARNING: Development endpoints are enabled! " + "This should NOT be used in production!" + ) + + PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig) @router.get("/server_info") - async def show_server_info(raw_request: Request): - server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} + async def show_server_info( + raw_request: Request, + config_format: Annotated[Literal["text", "json"], Query()] = "text", + ): + vllm_config: VllmConfig = raw_request.app.state.vllm_config + server_info = { + "vllm_config": str(vllm_config) + if config_format == "text" + else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str) + # fallback=str is needed to handle e.g. torch.dtype + } return JSONResponse(content=server_info) @router.post("/reset_prefix_cache") @@ -1089,19 +1025,24 @@ if envs.VLLM_SERVER_DEV_MODE: try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e method = body.get("method") if method is None: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'method' in request body") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'method' in request body", + ) # For security reason, only serialized string args/kwargs are passed. - # User-defined `method` is responsible for deseralization if needed. + # User-defined `method` is responsible for deserialization if needed. 
args: list[str] = body.get("args", []) kwargs: dict[str, str] = body.get("kwargs", {}) timeout: Optional[float] = body.get("timeout") results = await engine_client(raw_request).collective_rpc( - method=method, timeout=timeout, args=tuple(args), kwargs=kwargs) + method=method, timeout=timeout, args=tuple(args), kwargs=kwargs + ) if results is None: return Response(status_code=200) response: list[Any] = [] @@ -1113,45 +1054,39 @@ if envs.VLLM_SERVER_DEV_MODE: return JSONResponse(content={"results": response}) -@router.post("/scale_elastic_ep", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "model": dict - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.REQUEST_TIMEOUT.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) async def scale_elastic_ep(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=400, - detail="Invalid JSON format") from e # noqa: B904 + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 new_data_parallel_size = body.get("new_data_parallel_size") drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes if new_data_parallel_size is None: - raise HTTPException(status_code=400, - detail="new_data_parallel_size is required") - - if not isinstance(new_data_parallel_size, - int) or new_data_parallel_size <= 0: raise HTTPException( - status_code=400, - detail="new_data_parallel_size must be a positive integer") + status_code=400, detail="new_data_parallel_size is required" + ) + + if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, detail="new_data_parallel_size must be a positive integer" + ) if not isinstance(drain_timeout, int) or drain_timeout <= 0: - raise HTTPException(status_code=400, - detail="drain_timeout must be a positive integer") + raise HTTPException( + status_code=400, detail="drain_timeout must be a positive integer" + ) # Set scaling flag to prevent new requests global _scaling_elastic_ep @@ -1159,15 +1094,17 @@ async def scale_elastic_ep(raw_request: Request): client = engine_client(raw_request) try: await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) - return JSONResponse({ - "message": - f"Scaled to {new_data_parallel_size} " - "data parallel engines", - }) + return JSONResponse( + { + "message": f"Scaled to {new_data_parallel_size} data parallel engines", + } + ) except TimeoutError as e: - raise HTTPException(status_code=408, - detail="Scale failed due to request drain timeout " - f"after {drain_timeout} seconds") from e + raise HTTPException( + status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds", + ) from e except Exception as e: logger.error("Scale failed: %s", e) raise HTTPException(status_code=500, detail="Scale failed") from e @@ -1204,31 +1141,29 @@ INVOCATION_VALIDATORS = [ ] -@router.post("/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - 
}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) async def invocations(raw_request: Request): """For SageMaker, routes requests based on the request type.""" try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" + ) from e - valid_endpoints = [(validator, endpoint) - for validator, (get_handler, - endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None] + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] for request_validator, endpoint in valid_endpoints: try: @@ -1242,8 +1177,7 @@ async def invocations(raw_request: Request): t.__name__ if isinstance(t := validator._type, type) else str(t) for validator, _ in valid_endpoints ] - msg = ("Cannot find suitable handler for request. " - f"Expected one of: {type_names}") + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" res = base(raw_request).create_error_response(message=msg) return JSONResponse(content=res.model_dump(), status_code=res.error.code) @@ -1251,7 +1185,8 @@ async def invocations(raw_request: Request): if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!") + "used for local development!" + ) @router.post("/start_profile") async def start_profile(raw_request: Request): @@ -1271,29 +1206,32 @@ if envs.VLLM_TORCH_PROFILER_DIR: if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: logger.warning( "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!") + "This should ONLY be used for local development!" 
+    )
 
-    @router.post("/v1/load_lora_adapter",
-                 dependencies=[Depends(validate_json_request)])
-    async def load_lora_adapter(request: LoadLoRAAdapterRequest,
-                                raw_request: Request):
+    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
+    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
         handler = models(raw_request)
         response = await handler.load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.error.code)
+            return JSONResponse(
+                content=response.model_dump(), status_code=response.error.code
+            )
 
         return Response(status_code=200, content=response)
 
-    @router.post("/v1/unload_lora_adapter",
-                 dependencies=[Depends(validate_json_request)])
-    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
-                                  raw_request: Request):
+    @router.post(
+        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
+    )
+    async def unload_lora_adapter(
+        request: UnloadLoRAAdapterRequest, raw_request: Request
+    ):
         handler = models(raw_request)
         response = await handler.unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.error.code)
+            return JSONResponse(
+                content=response.model_dump(), status_code=response.error.code
+            )
 
         return Response(status_code=200, content=response)
 
@@ -1305,15 +1243,16 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
         with open(log_config_file) as f:
             return json.load(f)
     except Exception as e:
-        logger.warning("Failed to load log config from file %s: error %s",
-                       log_config_file, e)
+        logger.warning(
+            "Failed to load log config from file %s: error %s", log_config_file, e
+        )
         return None
 
 
 class AuthenticationMiddleware:
     """
     Pure ASGI middleware that authenticates each request by checking
-    if the Authorization header exists and equals "Bearer {api_key}".
+    if the Authorization Bearer token exists and equals any of the configured "{api_key}" values.
 
     Notes
     -----
@@ -1324,12 +1263,27 @@ class AuthenticationMiddleware:
 
     def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
         self.app = app
-        self.api_tokens = {f"Bearer {token}" for token in tokens}
+        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
 
-    def __call__(self, scope: Scope, receive: Receive,
-                 send: Send) -> Awaitable[None]:
-        if scope["type"] not in ("http",
-                                 "websocket") or scope["method"] == "OPTIONS":
+    def verify_token(self, headers: Headers) -> bool:
+        authorization_header_value = headers.get("Authorization")
+        if not authorization_header_value:
+            return False
+
+        scheme, _, param = authorization_header_value.partition(" ")
+        if scheme.lower() != "bearer":
+            return False
+
+        param_hash = hashlib.sha256(param.encode("utf-8")).digest()
+
+        token_match = False
+        for token_hash in self.api_tokens:
+            token_match |= secrets.compare_digest(param_hash, token_hash)
+
+        return token_match
+
+    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
+        if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
             # scope["type"] can be "lifespan" or "startup" for example,
             # in which case we don't need to do anything
             return self.app(scope, receive, send)
@@ -1337,10 +1291,8 @@ class AuthenticationMiddleware:
         url_path = URL(scope=scope).path.removeprefix(root_path)
         headers = Headers(scope=scope)
         # Type narrow to satisfy mypy. 
- if url_path.startswith("/v1") and headers.get( - "Authorization") not in self.api_tokens: - response = JSONResponse(content={"error": "Unauthorized"}, - status_code=401) + if url_path.startswith("/v1") and not self.verify_token(headers): + response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) return response(scope, receive, send) return self.app(scope, receive, send) @@ -1355,8 +1307,7 @@ class XRequestIdMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app - def __call__(self, scope: Scope, receive: Receive, - send: Send) -> Awaitable[None]: + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: if scope["type"] not in ("http", "websocket"): return self.app(scope, receive, send) @@ -1370,8 +1321,7 @@ class XRequestIdMiddleware: """ if message["type"] == "http.response.start": response_headers = MutableHeaders(raw=message["headers"]) - request_id = request_headers.get("X-Request-Id", - uuid.uuid4().hex) + request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex) response_headers.append("X-Request-Id", request_id) await send(message) @@ -1394,8 +1344,7 @@ class ScalingMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app - def __call__(self, scope: Scope, receive: Receive, - send: Send) -> Awaitable[None]: + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: if scope["type"] != "http": return self.app(scope, receive, send) @@ -1403,11 +1352,12 @@ class ScalingMiddleware: global _scaling_elastic_ep if _scaling_elastic_ep: # Return 503 Service Unavailable response - response = JSONResponse(content={ - "error": - "The model is currently scaling. Please try again later." - }, - status_code=503) + response = JSONResponse( + content={ + "error": "The model is currently scaling. Please try again later." 
+ }, + status_code=503, + ) return response(scope, receive, send) return self.app(scope, receive, send) @@ -1417,28 +1367,27 @@ def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: from vllm.entrypoints.openai.protocol import ( - ChatCompletionStreamResponse, CompletionStreamResponse) + ChatCompletionStreamResponse, + CompletionStreamResponse, + ) # Try using Completion types for type-safe parsing - if chunk_data.get('object') == 'chat.completion.chunk': - chat_response = ChatCompletionStreamResponse.model_validate( - chunk_data) + if chunk_data.get("object") == "chat.completion.chunk": + chat_response = ChatCompletionStreamResponse.model_validate(chunk_data) if chat_response.choices and chat_response.choices[0].delta.content: return chat_response.choices[0].delta.content - elif chunk_data.get('object') == 'text_completion': - completion_response = CompletionStreamResponse.model_validate( - chunk_data) - if completion_response.choices and completion_response.choices[ - 0].text: + elif chunk_data.get("object") == "text_completion": + completion_response = CompletionStreamResponse.model_validate(chunk_data) + if completion_response.choices and completion_response.choices[0].text: return completion_response.choices[0].text except pydantic.ValidationError: # Fallback to manual parsing - if 'choices' in chunk_data and chunk_data['choices']: - choice = chunk_data['choices'][0] - if 'delta' in choice and choice['delta'].get('content'): - return choice['delta']['content'] - elif choice.get('text'): - return choice['text'] + if "choices" in chunk_data and chunk_data["choices"]: + choice = chunk_data["choices"][0] + if "delta" in choice and choice["delta"].get("content"): + return choice["delta"]["content"] + elif choice.get("text"): + return choice["text"] return "" @@ -1454,7 +1403,7 @@ class SSEDecoder: import json try: - chunk_str = chunk.decode('utf-8') + chunk_str = chunk.decode("utf-8") except UnicodeDecodeError: # Skip malformed chunks return [] @@ -1463,18 +1412,18 @@ class SSEDecoder: events = [] # Process complete lines - while '\n' in self.buffer: - line, self.buffer = self.buffer.split('\n', 1) - line = line.rstrip('\r') # Handle CRLF + while "\n" in self.buffer: + line, self.buffer = self.buffer.split("\n", 1) + line = line.rstrip("\r") # Handle CRLF - if line.startswith('data: '): + if line.startswith("data: "): data_str = line[6:].strip() - if data_str == '[DONE]': - events.append({'type': 'done'}) + if data_str == "[DONE]": + events.append({"type": "done"}) elif data_str: try: event_data = json.loads(data_str) - events.append({'type': 'data', 'data': event_data}) + events.append({"type": "data", "data": event_data}) except json.JSONDecodeError: # Skip malformed JSON continue @@ -1492,7 +1441,7 @@ class SSEDecoder: def get_complete_content(self) -> str: """Get the complete buffered content.""" - return ''.join(self.content_buffer) + return "".join(self.content_buffer) def _log_streaming_response(response, response_body: list) -> None: @@ -1513,10 +1462,10 @@ def _log_streaming_response(response, response_body: list) -> None: events = sse_decoder.decode_chunk(chunk) for event in events: - if event['type'] == 'data': - content = sse_decoder.extract_content(event['data']) + if event["type"] == "data": + content = sse_decoder.extract_content(event["data"]) sse_decoder.add_content(content) - elif event['type'] == 'done': + elif event["type"] == "done": # Log complete content when done full_content = 
sse_decoder.get_complete_content() if full_content: @@ -1525,19 +1474,20 @@ def _log_streaming_response(response, response_body: list) -> None: full_content = full_content[:2048] + "" "...[truncated]" logger.info( - "response_body={streaming_complete: " \ + "response_body={streaming_complete: " "content='%s', chunks=%d}", - full_content, chunk_count) + full_content, + chunk_count, + ) else: logger.info( - "response_body={streaming_complete: " \ - "no_content, chunks=%d}", - chunk_count) + "response_body={streaming_complete: no_content, chunks=%d}", + chunk_count, + ) return response.body_iterator = iterate_in_threadpool(buffered_iterator()) - logger.info("response_body={streaming_started: chunks=%d}", - len(response_body)) + logger.info("response_body={streaming_started: chunks=%d}", len(response_body)) def _log_non_streaming_response(response_body: list) -> None: @@ -1551,10 +1501,9 @@ def _log_non_streaming_response(response_body: list) -> None: def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: - app = FastAPI(openapi_url=None, - docs_url=None, - redoc_url=None, - lifespan=lifespan) + app = FastAPI( + openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan + ) else: app = FastAPI(lifespan=lifespan) app.include_router(router) @@ -1573,14 +1522,16 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(HTTPException) async def http_exception_handler(_: Request, exc: HTTPException): err = ErrorResponse( - error=ErrorInfo(message=exc.detail, - type=HTTPStatus(exc.status_code).phrase, - code=exc.status_code)) + error=ErrorInfo( + message=exc.detail, + type=HTTPStatus(exc.status_code).phrase, + code=exc.status_code, + ) + ) return JSONResponse(err.model_dump(), status_code=exc.status_code) @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_: Request, - exc: RequestValidationError): + async def validation_exception_handler(_: Request, exc: RequestValidationError): exc_str = str(exc) errors_str = str(exc.errors()) @@ -1589,11 +1540,14 @@ def build_app(args: Namespace) -> FastAPI: else: message = exc_str - err = ErrorResponse(error=ErrorInfo(message=message, - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST)) - return JSONResponse(err.model_dump(), - status_code=HTTPStatus.BAD_REQUEST) + err = ErrorResponse( + error=ErrorInfo( + message=message, + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST, + ) + ) + return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]: @@ -1606,16 +1560,16 @@ def build_app(args: Namespace) -> FastAPI: app.add_middleware(ScalingMiddleware) if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: - logger.warning("CAUTION: Enabling log response in the API Server. " - "This can include sensitive information and should be " - "avoided in production.") + logger.warning( + "CAUTION: Enabling log response in the API Server. " + "This can include sensitive information and should be " + "avoided in production." 
+ ) @app.middleware("http") async def log_response(request: Request, call_next): response = await call_next(request) - response_body = [ - section async for section in response.body_iterator - ] + response_body = [section async for section in response.body_iterator] response.body_iterator = iterate_in_threadpool(iter(response_body)) # Check if this is a streaming response by looking at content-type content_type = response.headers.get("content-type", "") @@ -1638,18 +1592,20 @@ def build_app(args: Namespace) -> FastAPI: elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: - raise ValueError(f"Invalid middleware {middleware}. " - f"Must be a function or a class.") + raise ValueError( + f"Invalid middleware {middleware}. Must be a function or a class." + ) return app async def init_app_state( engine_client: EngineClient, - vllm_config: VllmConfig, state: State, args: Namespace, ) -> None: + vllm_config = engine_client.vllm_config + if args.served_model_name is not None: served_model_names = args.served_model_name else: @@ -1661,22 +1617,15 @@ async def init_app_state( request_logger = None base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names + BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config - model_config = vllm_config.model_config - if envs.VLLM_USE_V1: - supported_tasks = await engine_client \ - .get_supported_tasks() # type: ignore - else: - supported_tasks = model_config.supported_tasks - - logger.info("Supported_tasks: %s", supported_tasks) + supported_tasks = await engine_client.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) resolved_chat_template = load_chat_template(args.chat_template) if resolved_chat_template is not None: @@ -1686,7 +1635,8 @@ async def init_app_state( if isinstance(tokenizer, MistralTokenizer): # The warning is logged in resolve_mistral_chat_template. resolved_chat_template = resolve_mistral_chat_template( - chat_template=resolved_chat_template) + chat_template=resolved_chat_template + ) else: hf_chat_template = resolve_hf_chat_template( tokenizer=tokenizer, @@ -1700,10 +1650,14 @@ async def init_app_state( "Using supplied chat template: %s\n" "It is different from official chat template '%s'. 
" "This discrepancy may lead to performance degradation.", - resolved_chat_template, args.model) + resolved_chat_template, + args.model, + ) if args.tool_server == "demo": tool_server: Optional[ToolServer] = DemoToolServer() + assert isinstance(tool_server, DemoToolServer) + await tool_server.init_and_validate() elif args.tool_server: tool_server = MCPToolServer() await tool_server.add_tool_server(args.tool_server) @@ -1711,8 +1665,11 @@ async def init_app_state( tool_server = None # Merge default_mm_loras into the static lora_modules - default_mm_loras = (vllm_config.lora_config.default_mm_loras - if vllm_config.lora_config is not None else {}) + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) lora_modules = args.lora_modules if default_mm_loras: @@ -1720,7 +1677,8 @@ async def init_app_state( LoRAModulePath( name=modality, path=lora_path, - ) for modality, lora_path in default_mm_loras.items() + ) + for modality, lora_path in default_mm_loras.items() ] if args.lora_modules is None: lora_modules = default_mm_lora_paths @@ -1729,104 +1687,140 @@ async def init_app_state( state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, - model_config=model_config, base_model_paths=base_model_paths, lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() - state.openai_serving_responses = OpenAIServingResponses( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - tool_server=tool_server, - reasoning_parser=args.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - ) if "generate" in supported_tasks else None - state.openai_serving_chat = OpenAIServingChat( - engine_client, - model_config, - state.openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args. 
- exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - ) if "generate" in supported_tasks else None - state.openai_serving_completion = OpenAIServingCompletion( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - ) if "generate" in supported_tasks else None - state.openai_serving_pooling = OpenAIServingPooling( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - ) if "encode" in supported_tasks else None - state.openai_serving_embedding = OpenAIServingEmbedding( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - ) if "embed" in supported_tasks else None - state.openai_serving_classification = ServingClassification( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - ) if "classify" in supported_tasks else None - - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) - state.openai_serving_scores = ServingScores( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None - + state.openai_serving_responses = ( + OpenAIServingResponses( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + tool_server=tool_server, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_chat = ( + OpenAIServingChat( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_completion = ( + 
OpenAIServingCompletion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_pooling = ( + OpenAIServingPooling( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "encode" in supported_tasks + else None + ) + state.openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "embed" in supported_tasks + else None + ) + state.openai_serving_classification = ( + ServingClassification( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "classify" in supported_tasks + else None + ) + state.openai_serving_scores = ( + ServingScores( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if ("embed" in supported_tasks or "score" in supported_tasks) + else None + ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + state.openai_serving_transcription = ( + OpenAIServingTranscription( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "transcription" in supported_tasks + else None + ) + state.openai_serving_translation = ( + OpenAIServingTranslation( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "transcription" in supported_tasks + else None ) - state.openai_serving_transcription = OpenAIServingTranscription( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - ) if "transcription" in supported_tasks else None - state.openai_serving_translation = OpenAIServingTranslation( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - ) if "transcription" in supported_tasks else None state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -1853,17 +1847,20 @@ def create_server_unix_socket(path: str) -> socket.socket: def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valid_tool_parses)} }})") + if args.enable_auto_tool_choice and args.tool_call_parser not in 
valid_tool_parses: + raise KeyError( + f"invalid tool call parser: {args.tool_call_parser} " + f"(choose from {{ {','.join(valid_tool_parses)} }})" + ) valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() - if args.reasoning_parser \ - and args.reasoning_parser not in valid_reasoning_parses: + if ( + reasoning_parser := args.structured_outputs_config.reasoning_parser + ) and reasoning_parser not in valid_reasoning_parses: raise KeyError( - f"invalid reasoning parser: {args.reasoning_parser} " - f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + f"invalid reasoning parser: {reasoning_parser} " + f"(choose from {{ {','.join(valid_reasoning_parses)} }})" + ) def setup_server(args): @@ -1902,8 +1899,7 @@ def setup_server(args): else: addr, port = sock_addr is_ssl = args.ssl_keyfile and args.ssl_certfile - host_part = f"[{addr}]" if is_valid_ipv6_address( - addr) else addr or "0.0.0.0" + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" return listen_address, sock @@ -1918,35 +1914,33 @@ async def run_server(args, **uvicorn_kwargs) -> None: await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) -async def run_server_worker(listen_address, - sock, - args, - client_config=None, - **uvicorn_kwargs) -> None: +async def run_server_worker( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: """Run a single API server worker.""" if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) - server_index = client_config.get("client_index", 0) if client_config else 0 - # Load logging config for uvicorn if specified log_config = load_log_config(args.log_config_file) if log_config is not None: - uvicorn_kwargs['log_config'] = log_config + uvicorn_kwargs["log_config"] = log_config async with build_async_engine_client( - args, - client_config=client_config, + args, + client_config=client_config, ) as engine_client: maybe_register_tokenizer_info_endpoint(args) app = build_app(args) - vllm_config = await engine_client.get_vllm_config() - await init_app_state(engine_client, vllm_config, app.state, args) + await init_app_state(engine_client, app.state, args) - logger.info("Starting vLLM API server %d on %s", server_index, - listen_address) + logger.info( + "Starting vLLM API server %d on %s", + engine_client.vllm_config.parallel_config._api_process_rank, + listen_address, + ) shutdown_task = await serve_http( app, sock=sock, @@ -1980,7 +1974,8 @@ if __name__ == "__main__": # entrypoints. cli_env_setup() parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server."
+ ) parser = make_arg_parser(parser) args = parser.parse_args() validate_parsed_serve_args(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 6e4eff5c80243..1f16646db63b8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -18,10 +18,14 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type -from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - validate_chat_template) -from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, - H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, + validate_chat_template, +) +from vllm.entrypoints.constants import ( + H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, +) from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger @@ -31,7 +35,6 @@ logger = init_logger(__name__) class LoRAParserAction(argparse.Action): - def __call__( self, parser: argparse.ArgumentParser, @@ -57,8 +60,7 @@ class LoRAParserAction(argparse.Action): lora = LoRAModulePath(**lora_dict) lora_list.append(lora) except json.JSONDecodeError: - parser.error( - f"Invalid JSON format for --lora-modules: {item}") + parser.error(f"Invalid JSON format for --lora-modules: {item}") except TypeError as e: parser.error( f"Invalid fields for --lora-modules: {item} - {str(e)}" @@ -70,14 +72,16 @@ class LoRAParserAction(argparse.Action): @dataclass class FrontendArgs: """Arguments for the OpenAI-compatible frontend server.""" + host: Optional[str] = None """Host name.""" port: int = 8000 """Port number.""" uds: Optional[str] = None """Unix domain socket path. If set, host and port arguments are ignored.""" - uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", - "trace"] = "info" + uvicorn_log_level: Literal[ + "debug", "info", "warning", "error", "critical", "trace" + ] = "info" """Log level for uvicorn.""" disable_uvicorn_access_log: bool = False """Disable uvicorn access log.""" @@ -103,9 +107,13 @@ class FrontendArgs: chat_template_content_format: ChatTemplateContentFormatOption = "auto" """The format to render message content within a chat template. -* "string" will render the content as a string. Example: `"Hello World"` -* "openai" will render the content as a list of dictionaries, similar to OpenAI -schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" + * "string" will render the content as a string. Example: `"Hello World"` + * "openai" will render the content as a list of dictionaries, similar to + OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" + trust_request_chat_template: bool = False + """Whether to trust the chat template provided in the request. If False, + the server will always use the chat template specified by `--chat-template` + or the ones from tokenizer.""" response_role: str = "assistant" """The role name to return if `request.add_generation_prompt=true`.""" ssl_keyfile: Optional[str] = None @@ -134,14 +142,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """If specified, will run the OpenAI frontend server in the same process as the model serving engine.""" enable_request_id_headers: bool = False - """If specified, API server will add X-Request-Id header to responses. 
- Caution: this hurts performance at high QPS.""" + """If specified, API server will add X-Request-Id header to responses.""" enable_auto_tool_choice: bool = False - """If specified, exclude tool definitions in prompts when - tool_choice='none'.""" - exclude_tools_when_tool_choice_none: bool = False """Enable auto tool choice for supported models. Use `--tool-call-parser` to specify which parser to use.""" + exclude_tools_when_tool_choice_none: bool = False + """If specified, exclude tool definitions in prompts when + tool_choice='none'.""" tool_call_parser: Optional[str] = None """Select the tool call parser depending on the model that you're using. This is used to parse the model-generated tool call into OpenAI API format. @@ -172,14 +179,16 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """Enable the /get_tokenizer_info endpoint. May expose chat templates and other tokenizer configuration.""" enable_log_outputs: bool = False - """If set to True, enable logging of model outputs (generations) - in addition to the input logging that is enabled by default.""" + """If True, log model outputs (generations). + Requires --enable-log-requests.""" h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT """Maximum size (bytes) of an incomplete HTTP event (header or body) for h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT """Maximum number of HTTP headers allowed in a request for h11 parser. Helps mitigate header abuse. Default: 256.""" + log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE + """If set to True, log the stack trace of error responses""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -202,7 +211,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["action"] = LoRAParserAction - # Special case: Middleware needs append action + # Special case: Middleware needs the append action frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["type"] = str if "nargs" in frontend_kwargs["middleware"]: @@ -213,7 +222,8 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" valid_tool_parsers = list(ToolParserManager.tool_parsers.keys()) parsers_str = ",".join(valid_tool_parsers) frontend_kwargs["tool_call_parser"]["metavar"] = ( - f"{{{parsers_str}}} or name registered in --tool-parser-plugin") + f"{{{parsers_str}}} or name registered in --tool-parser-plugin" + ) frontend_group = parser.add_argument_group( title="Frontend", @@ -233,27 +243,32 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: register all arguments instead of manually enumerating them here. This avoids code duplication and keeps the argument definitions in one place. """ - parser.add_argument("model_tag", - type=str, - nargs="?", - help="The model tag to serve " - "(optional if specified in config)") + parser.add_argument( + "model_tag", + type=str, + nargs="?", + help="The model tag to serve (optional if specified in config)", + ) parser.add_argument( "--headless", action="store_true", default=False, help="Run in headless mode. 
See multi-node data parallel " - "documentation for more details.") - parser.add_argument("--api-server-count", - "-asc", - type=int, - default=1, - help="How many API server processes to run.") + "documentation for more details.", + ) + parser.add_argument( + "--api-server-count", + "-asc", + type=int, + default=1, + help="How many API server processes to run.", + ) parser.add_argument( "--config", help="Read CLI options from a config file. " "Must be a YAML with the following options: " - "https://docs.vllm.ai/en/latest/configuration/serve_args.html") + "https://docs.vllm.ai/en/latest/configuration/serve_args.html", + ) parser = FrontendArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) @@ -270,11 +285,13 @@ def validate_parsed_serve_args(args: argparse.Namespace): # Enable auto tool needs a tool call parser to be valid if args.enable_auto_tool_choice and not args.tool_call_parser: - raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser") + if args.enable_log_outputs and not args.enable_log_requests: + raise TypeError("Error: --enable-log-outputs requires --enable-log-requests") def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( - prog="-m vllm.entrypoints.openai.api_server") + prog="-m vllm.entrypoints.openai.api_server" + ) return make_arg_parser(parser_for_docs) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 29d72256cf70b..2ea9fbf386ba1 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -19,12 +19,11 @@ class AllowedTokenIdsLogitsProcessor: self.allowed_ids: Optional[list[int]] = list(allowed_ids) self.mask: Optional[torch.Tensor] = None - def __call__(self, token_ids: list[int], - logits: torch.Tensor) -> torch.Tensor: + def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor: if self.mask is None: - self.mask = torch.ones((logits.shape[-1], ), - dtype=torch.bool, - device=logits.device) + self.mask = torch.ones( + (logits.shape[-1],), dtype=torch.bool, device=logits.device + ) self.mask[self.allowed_ids] = False self.allowed_ids = None logits.masked_fill_(self.mask, float("-inf")) @@ -39,8 +38,7 @@ def _get_allowed_token_ids_logits_processor( if not allowed_token_ids: raise ValueError("Empty allowed_token_ids provided") if not all(0 <= tid < vocab_size for tid in allowed_token_ids): - raise ValueError("allowed_token_ids contains " - "out-of-vocab token id") + raise ValueError("allowed_token_ids contains out-of-vocab token id") return AllowedTokenIdsLogitsProcessor(allowed_token_ids) @@ -71,20 +69,25 @@ def get_logits_processors( except ValueError as exc: raise ValueError( "Found token_id in logit_bias that is not " - "an integer or string representing an integer") from exc + "an integer or string representing an integer" + ) from exc # Check if token_id is within the vocab size for token_id, bias in clamped_logit_bias.items(): if token_id < 0 or token_id >= len(tokenizer): - raise ValueError(f"token_id {token_id} in logit_bias contains " - "out-of-vocab token id") + raise ValueError( + f"token_id {token_id} in logit_bias contains out-of-vocab token id" + ) logits_processors.append( - partial(logit_bias_logits_processor, clamped_logit_bias)) + partial(logit_bias_logits_processor, clamped_logit_bias) + ) if allowed_token_ids is not None: logits_processors.append( 
_get_allowed_token_ids_logits_processor( - frozenset(allowed_token_ids), len(tokenizer))) + frozenset(allowed_token_ids), len(tokenizer) + ) + ) return logits_processors diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a3d7b78cf4552..6ff7ceef48055 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,46 +6,81 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Literal, Optional, Union +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union import regex as re import torch from fastapi import HTTPException, UploadFile -# yapf: disable from openai.types.chat.chat_completion_audio import ( - ChatCompletionAudio as OpenAIChatCompletionAudio) -from openai.types.chat.chat_completion_message import ( - Annotation as OpenAIAnnotation) -# yapf: enable -from openai.types.responses import (ResponseFunctionToolCall, - ResponseInputItemParam, ResponseOutputItem, - ResponsePrompt, ResponseReasoningItem, - ResponseStatus) + ChatCompletionAudio as OpenAIChatCompletionAudio, +) +from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation +from openai.types.responses import ( + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallCompletedEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseFunctionToolCall, + ResponseInputItemParam, + ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponsePrompt, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseStatus, + ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, +) +from openai.types.responses import ( + ResponseCompletedEvent as OpenAIResponseCompletedEvent, +) +from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent +from openai.types.responses import ( + ResponseInProgressEvent as OpenAIResponseInProgressEvent, +) +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent, +) # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) from openai.types.responses import ResponseTextConfig except ImportError: # For newer openai versions (>= 1.100.0) - from openai.types.responses import (ResponseFormatTextConfig as - ResponseTextConfig) + from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig -from openai.types.responses.response import ToolChoice +from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning -from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, - ValidationInfo, field_validator, model_validator) +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + ValidationInfo, + field_serializer, + field_validator, + model_validator, +) from typing_extensions import TypeAlias from vllm import envs -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - make_tool_call_id) -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam) +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, 
make_tool_call_id +from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, - RequestOutputKind, SamplingParams) -from vllm.sequence import Logprob +from vllm.sampling_params import ( + BeamSearchParams, + RequestOutputKind, + SamplingParams, + StructuredOutputsParams, +) from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -79,8 +114,7 @@ class OpenAIBaseModel(BaseModel): # Compare against both field names and aliases if any(k not in field_names for k in data): logger.warning( - "The following fields were present in the request " - "but ignored: %s", + "The following fields were present in the request but ignored: %s", data.keys() - field_names, ) return result @@ -149,7 +183,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") strict: Optional[bool] = None @@ -157,8 +191,9 @@ class StructuralTag(OpenAIBaseModel): begin: str # schema is the field, but that causes conflicts with pydantic so # instead use structural_tag_schema with an alias - structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, - alias="schema") + structural_tag_schema: Optional[dict[str, Any]] = Field( + default=None, alias="schema" + ) end: str @@ -215,18 +250,19 @@ class LogitsProcessorConstructor(BaseModel): LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] -def get_logits_processors(processors: Optional[LogitsProcessors], - pattern: Optional[str]) -> Optional[list[Any]]: +def get_logits_processors( + processors: Optional[LogitsProcessors], pattern: Optional[str] +) -> Optional[list[Any]]: if processors and pattern: logits_processors = [] for processor in processors: - qualname = processor if isinstance(processor, - str) else processor.qualname + qualname = processor if isinstance(processor, str) else processor.qualname if not re.match(pattern, qualname): raise ValueError( f"Logits processor '{qualname}' is not allowed by this " "server. See --logits-processor-pattern engine argument " - "for more information.") + "for more information." + ) try: logits_processor = resolve_obj_by_qualname(qualname) except Exception as e: @@ -234,37 +270,41 @@ def get_logits_processors(processors: Optional[LogitsProcessors], f"Logits processor '{qualname}' could not be resolved: {e}" ) from e if isinstance(processor, LogitsProcessorConstructor): - logits_processor = logits_processor(*processor.args or [], - **processor.kwargs or {}) + logits_processor = logits_processor( + *processor.args or [], **processor.kwargs or {} + ) logits_processors.append(logits_processor) return logits_processors elif processors: raise ValueError( "The `logits_processors` argument is not supported by this " - "server. See --logits-processor-pattern engine argugment " - "for more information.") + "server. See --logits-processor-pattern engine argument " + "for more information." 
+ ) return None -ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, - ResponseReasoningItem, - ResponseFunctionToolCall] +ResponseInputOutputItem: TypeAlias = Union[ + ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall +] class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create background: Optional[bool] = False - include: Optional[list[ - Literal[ - "code_interpreter_call.outputs", - "computer_call_output.output.image_url", - "file_search_call.results", - "message.input_image.image_url", - "message.output_text.logprobs", - "reasoning.encrypted_content", - ], - ]] = None + include: Optional[ + list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ] + ] = None input: Union[str, list[ResponseInputOutputItem]] instructions: Optional[str] = None max_output_tokens: Optional[int] = None @@ -275,8 +315,7 @@ class ResponsesRequest(OpenAIBaseModel): previous_response_id: Optional[str] = None prompt: Optional[ResponsePrompt] = None reasoning: Optional[Reasoning] = None - service_tier: Literal["auto", "default", "flex", "scale", - "priority"] = "auto" + service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" store: Optional[bool] = True stream: Optional[bool] = False temperature: Optional[float] = None @@ -294,7 +333,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -305,7 +345,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) cache_salt: Optional[str] = Field( default=None, @@ -315,7 +356,18 @@ class ResponsesRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) + + enable_response_messages: bool = Field( + default=False, + description=( + "Dictates whether or not to return messages as part of the " + "response object. Currently only supported for non-streaming " + "non-background and gpt-oss only. 
" + ), + ) # --8<-- [end:responses-extra-params] _DEFAULT_SAMPLING_PARAMS = { @@ -336,19 +388,25 @@ class ResponsesRequest(OpenAIBaseModel): default_sampling_params = default_sampling_params or {} if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output - guided_decoding = None + structured_outputs = None if self.text is not None and self.text.format is not None: response_format = self.text.format - if response_format.type == "json_schema": - guided_decoding = GuidedDecodingParams.from_optional( - json=response_format.schema_) + if ( + response_format.type == "json_schema" + and response_format.schema_ is not None + ): + structured_outputs = StructuredOutputsParams( + json=response_format.schema_ + ) elif response_format.type == "json_object": raise NotImplementedError("json_object is not supported") @@ -357,29 +415,29 @@ class ResponsesRequest(OpenAIBaseModel): temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs - if self.is_include_output_logprobs() else None, + logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, - output_kind=(RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY), - guided_decoding=guided_decoding, + output_kind=( + RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY + ), + structured_outputs=structured_outputs, ) def is_include_output_logprobs(self) -> bool: """Check if the request includes output logprobs.""" if self.include is None: return False - return isinstance( - self.include, - list) and "message.output_text.logprobs" in self.include + return ( + isinstance(self.include, list) + and "message.output_text.logprobs" in self.include + ) @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): return data if not data.get("store", True): - raise ValueError( - "background can only be used when `store` is true") + raise ValueError("background can only be used when `store` is true") return data @model_validator(mode="before") @@ -394,11 +452,12 @@ class ResponsesRequest(OpenAIBaseModel): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." 
+ ) return data @@ -413,8 +472,9 @@ class ChatCompletionRequest(OpenAIBaseModel): top_logprobs: Optional[int] = 0 max_tokens: Optional[int] = Field( default=None, - deprecated= - 'max_tokens is deprecated in favor of the max_completion_tokens field') + deprecated="max_tokens is deprecated in favor of " + "the max_completion_tokens field", + ) max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -426,12 +486,14 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = None top_p: Optional[float] = None tools: Optional[list[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[ - Literal["none"], - Literal["auto"], - Literal["required"], - ChatCompletionNamedToolChoiceParam, - ]] = "none" + tool_choice: Optional[ + Union[ + Literal["none"], + Literal["auto"], + Literal["required"], + ChatCompletionNamedToolChoiceParam, + ] + ] = "none" reasoning_effort: Optional[Literal["low", "medium", "high"]] = None include_reasoning: bool = True @@ -452,7 +514,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) @@ -463,23 +525,26 @@ class ChatCompletionRequest(OpenAIBaseModel): default=False, description=( "If true, the new message will be prepended with the last message " - "if they belong to the same role."), + "if they belong to the same role." + ), ) add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -488,16 +553,18 @@ class ChatCompletionRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) documents: Optional[list[dict[str, str]]] = Field( default=None, - description= - ("A list of dicts representing documents that will be accessible to " - "the model if it is performing RAG (retrieval-augmented generation)." - " If the template does not support RAG, this argument will have no " - "effect. 
We recommend that each document should be a dict containing " - "\"title\" and \"text\" keys."), + description=( + "A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + '"title" and "text" keys.' + ), ) chat_template: Optional[str] = Field( default=None, @@ -505,68 +572,95 @@ class ChatCompletionRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) + structured_outputs: Optional[StructuredOutputsParams] = Field( + default=None, + description="Additional kwargs for structured outputs", + ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), + description=( + "`guided_json` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( - "If specified, the output will follow the regex pattern."), + "`guided_regex` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( - "If specified, the output will be exactly one of the choices."), + "`guided_choice` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( - "If specified, the output will follow the context free grammar."), + "`guided_grammar` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `grammar` to `structured_outputs` instead." + ), ) structural_tag: Optional[str] = Field( default=None, description=( - "If specified, the output will follow the structural tag schema."), + "`structural_tag` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `structural_tag` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'"), + "`guided_decoding_backend` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding."), + "`guided_whitespace_pattern` is deprecated. 
" + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `whitespace_pattern` to `structured_outputs` instead." + ), ) priority: int = Field( default=0, description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -578,13 +672,17 @@ class ChatCompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." + ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -592,7 +690,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, description=( @@ -601,15 +701,20 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." 
+ ), ) # --8<-- [end:chat-completion-extra-params] @@ -624,13 +729,13 @@ class ChatCompletionRequest(OpenAIBaseModel): } def to_beam_search_params( - self, max_tokens: int, - default_sampling_params: dict) -> BeamSearchParams: - + self, max_tokens: int, default_sampling_params: dict + ) -> BeamSearchParams: n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) return BeamSearchParams( beam_width=n, @@ -647,7 +752,6 @@ class ChatCompletionRequest(OpenAIBaseModel): logits_processor_pattern: Optional[str], default_sampling_params: dict, ) -> SamplingParams: - # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -656,46 +760,66 @@ class ChatCompletionRequest(OpenAIBaseModel): ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs - guided_json_object = None - if self.response_format is not None: - if self.response_format.type == "json_object": - guided_json_object = True - elif self.response_format.type == "json_schema": - json_schema = self.response_format.json_schema - assert json_schema is not None - self.guided_json = json_schema.json_schema - elif self.response_format.type == "structural_tag": - structural_tag = self.response_format - assert structural_tag is not None and isinstance( - structural_tag, StructuralTagResponseFormat) - s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structural_tag = json.dumps(s_tag_obj) + # Forward deprecated guided_* parameters to structured_outputs + if self.structured_outputs is None: + kwargs = dict[str, Any]( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + whitespace_pattern=self.guided_whitespace_pattern, + structural_tag=self.structural_tag, + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + self.structured_outputs = StructuredOutputsParams(**kwargs) - guided_decoding = GuidedDecodingParams.from_optional( - json=self._get_guided_json_from_tool() or self.guided_json, - regex=self.guided_regex, - choice=self.guided_choice, - grammar=self.guided_grammar, - json_object=guided_json_object, - backend=self.guided_decoding_backend, - whitespace_pattern=self.guided_whitespace_pattern, - structural_tag=self.structural_tag, - ) + response_format = self.response_format + json_schema_from_tool = self._get_json_schema_from_tool() + if response_format is not None or json_schema_from_tool is not None: + # If structured outputs wasn't already enabled, + # we must enable it for 
these features to work + if self.structured_outputs is None: + self.structured_outputs = StructuredOutputsParams() + + # Set structured output params for response format + if response_format is not None: + if response_format.type == "json_object": + self.structured_outputs.json_object = True + elif response_format.type == "json_schema": + json_schema = response_format.json_schema + assert json_schema is not None + self.structured_outputs.json = json_schema.json_schema + elif response_format.type == "structural_tag": + structural_tag = response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat + ) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structured_outputs.structural_tag = json.dumps(s_tag_obj) + + # Set structured output params for tool calling + if json_schema_from_tool is not None: + self.structured_outputs.json = json_schema_from_tool extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: @@ -721,21 +845,22 @@ class ChatCompletionRequest(OpenAIBaseModel): min_tokens=self.min_tokens, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, - guided_decoding=guided_decoding, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, - bad_words= self.bad_words, + bad_words=self.bad_words, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) - def _get_guided_json_from_tool( - self) -> Optional[Union[str, dict, BaseModel]]: + def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: # user has chosen to not use any tool if self.tool_choice == "none" or self.tools is None: return None @@ -745,8 +870,7 @@ class ChatCompletionRequest(OpenAIBaseModel): tool_name = self.tool_choice.function.name tools = {tool.function.name: tool.function for tool in self.tools} if tool_name not in tools: - raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") tool = tools[tool_name] return tool.parameters @@ -758,37 +882,31 @@ class ChatCompletionRequest(OpenAIBaseModel): def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: return { "properties": { - "name": { - "type": "string", - "enum": [tool.function.name] - }, + "name": {"type": "string", "enum": [tool.function.name]}, # parameters are always generated as '{}' in the final # output if they are missing from the request # (i.e. 
are None or '{}') so the schema is # updated to produce an empty object in that case "parameters": tool.function.parameters - if tool.function.parameters else { - "type": "object", - "properties": {} - } + if tool.function.parameters + else {"type": "object", "properties": {}}, }, - "required": ["name", "parameters"] + "required": ["name", "parameters"], } - def get_tool_schema_defs( - tools: list[ChatCompletionToolsParam]) -> dict: + def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict: all_defs = dict[str, dict[str, Any]]() for tool in tools: if tool.function.parameters is None: continue defs = tool.function.parameters.pop("$defs", {}) for def_name, def_schema in defs.items(): - if def_name in all_defs and all_defs[ - def_name] != def_schema: + if def_name in all_defs and all_defs[def_name] != def_schema: raise ValueError( f"Tool definition '{def_name}' has " "multiple schemas, which is not " - "supported.") + "supported." + ) else: all_defs[def_name] = def_schema return all_defs @@ -798,8 +916,8 @@ class ChatCompletionRequest(OpenAIBaseModel): "minItems": 1, "items": { "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools] - } + "anyOf": [get_tool_schema(tool) for tool in self.tools], + }, } json_schema_defs = get_tool_schema_defs(self.tools) if json_schema_defs: @@ -812,8 +930,7 @@ class ChatCompletionRequest(OpenAIBaseModel): @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -821,18 +938,22 @@ class ChatCompletionRequest(OpenAIBaseModel): @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and prompt_logprobs > 0: + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") - - if prompt_logprobs < 0: - raise ValueError("`prompt_logprobs` must be a positive value.") + "`prompt_logprobs` are not available when `stream=True`." + ) + if prompt_logprobs < 0 and prompt_logprobs != -1: + raise ValueError("`prompt_logprobs` must be a positive value or -1.") + if prompt_logprobs == -1 and not envs.VLLM_USE_V1: + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." + ) if (top_logprobs := data.get("top_logprobs")) is not None: - if top_logprobs < 0: - raise ValueError("`top_logprobs` must be a positive value.") + if top_logprobs < 0 and top_logprobs != -1: + raise ValueError("`top_logprobs` must be a positive value or -1.") - if top_logprobs > 0 and not data.get("logprobs"): + if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." 
) @@ -841,34 +962,39 @@ class ChatCompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod - def check_guided_decoding_count(cls, data): + def check_structured_outputs_count(cls, data): if isinstance(data, ValueError): raise data - guide_count = sum([ - "guided_json" in data and data["guided_json"] is not None, - "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None - ]) - # you can only use one kind of guided decoding - if guide_count > 1: + if data.get("structured_outputs", None) is None: + return data + + structured_outputs_kwargs = data["structured_outputs"] + count = sum( + structured_outputs_kwargs.get(k) is not None + for k in ("json", "regex", "choice") + ) + # you can only use one kind of constraints for structured outputs + if count > 1: raise ValueError( - "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice').") - # you can only either use guided decoding or tools, not both - if guide_count > 1 and data.get("tool_choice", "none") not in ( - "none", - "auto", - "required", + "You can only use one kind of constraints for structured " + "outputs ('json', 'regex' or 'choice')." + ) + # you can only either use structured outputs or tools, not both + if count > 1 and data.get("tool_choice", "none") not in ( + "none", + "auto", + "required", ): raise ValueError( - "You can only either use guided decoding or tools, not both.") + "You can only either use constraints for structured outputs " + "or tools, not both." + ) return data @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, # default to "auto" tool_choice if "tool_choice" not in data and data.get("tools"): @@ -880,52 +1006,58 @@ class ChatCompletionRequest(OpenAIBaseModel): # if "tool_choice" is specified -- validation if "tool_choice" in data and data["tool_choice"] is not None: - # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: - raise ValueError( - "When using `tool_choice`, `tools` must be set.") + raise ValueError("When using `tool_choice`, `tools` must be set.") # make sure that tool choice is either a named tool # OR that it's set to "auto" or "required" - if data["tool_choice"] not in [ - "auto", "required" - ] and not isinstance(data["tool_choice"], dict): + if data["tool_choice"] not in ["auto", "required"] and not isinstance( + data["tool_choice"], dict + ): raise ValueError( - f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ - 'Only named tools, "none", "auto" or "required" '\ - 'are supported.' + f"Invalid value for `tool_choice`: {data['tool_choice']}! " + 'Only named tools, "none", "auto" or "required" ' + "are supported." ) # if tool_choice is "required" but the "tools" list is empty, # override the data to behave like "none" to align with # OpenAI’s behavior. 
- if data["tool_choice"] == "required" and isinstance( - data["tools"], list) and len(data["tools"]) == 0: + if ( + data["tool_choice"] == "required" + and isinstance(data["tools"], list) + and len(data["tools"]) == 0 + ): data["tool_choice"] = "none" del data["tools"] return data # ensure that if "tool_choice" is specified as an object, # it matches a valid tool - correct_usage_message = 'Correct usage: `{"type": "function",' \ + correct_usage_message = ( + 'Correct usage: `{"type": "function",' ' "function": {"name": "my_function"}}`' + ) if isinstance(data["tool_choice"], dict): valid_tool = False function = data["tool_choice"].get("function") if not isinstance(function, dict): raise ValueError( f"Invalid value for `function`: `{function}` in " - f"`tool_choice`! {correct_usage_message}") + f"`tool_choice`! {correct_usage_message}" + ) if "name" not in function: - raise ValueError(f"Expected field `name` in `function` in " - f"`tool_choice`! {correct_usage_message}") + raise ValueError( + f"Expected field `name` in `function` in " + f"`tool_choice`! {correct_usage_message}" + ) function_name = function["name"] - if not isinstance(function_name, - str) or len(function_name) == 0: + if not isinstance(function_name, str) or len(function_name) == 0: raise ValueError( f"Invalid `name` in `function`: `{function_name}`" - f" in `tool_choice`! {correct_usage_message}") + f" in `tool_choice`! {correct_usage_message}" + ) for tool in data["tools"]: if tool["function"]["name"] == function_name: valid_tool = True @@ -933,16 +1065,18 @@ class ChatCompletionRequest(OpenAIBaseModel): if not valid_tool: raise ValueError( "The tool specified in `tool_choice` does not match any" - " of the specified `tools`") + " of the specified `tools`" + ) return data @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @model_validator(mode="before") @@ -952,11 +1086,12 @@ class ChatCompletionRequest(OpenAIBaseModel): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." 
+ ) return data @@ -965,7 +1100,6 @@ class CompletionRequest(OpenAIBaseModel): # https://platform.openai.com/docs/api-reference/completions/create model: Optional[str] = None prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None - prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 @@ -995,17 +1129,19 @@ class CompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None allowed_token_ids: Optional[list[int]] = None prompt_logprobs: Optional[int] = None # --8<-- [end:completion-sampling-params] # --8<-- [start:completion-extra-params] + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None add_special_tokens: bool = Field( default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) response_format: Optional[AnyResponseFormat] = Field( default=None, @@ -1015,51 +1151,73 @@ class CompletionRequest(OpenAIBaseModel): ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." ), ) + structured_outputs: Optional[StructuredOutputsParams] = Field( + default=None, + description="Additional kwargs for structured outputs", + ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description="If specified, the output will follow the JSON schema.", + description=( + "`guided_json` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( - "If specified, the output will follow the regex pattern."), + "`guided_regex` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( - "If specified, the output will be exactly one of the choices."), + "`guided_choice` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( - "If specified, the output will follow the context free grammar."), + "`guided_grammar` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `grammar` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be one of " - "'outlines' / 'lm-format-enforcer'"), + "`guided_decoding_backend` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding."), + "`guided_whitespace_pattern` is deprecated. " + "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " + "Please pass `whitespace_pattern` to `structured_outputs` instead." 
+ ), ) priority: int = Field( default=0, description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -1071,14 +1229,18 @@ class CompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." + ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -1086,7 +1248,9 @@ class CompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, @@ -1096,16 +1260,21 @@ class CompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." 
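# NOTE (editorial sketch, not part of the diff): `vllm_xargs` only accepts
# string or numeric values and is intended for custom extensions. With the
# official OpenAI Python client it would typically be passed via `extra_body`;
# the endpoint URL, API key, model name and extension key below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="my-model",
    prompt="Hello",
    extra_body={"vllm_xargs": {"my_extension_knob": 3}},  # hypothetical key
)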
+ ), ) # --8<-- [end:completion-extra-params] @@ -1124,7 +1293,6 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: int, default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: - if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 @@ -1147,7 +1315,6 @@ class CompletionRequest(OpenAIBaseModel): logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: - if default_sampling_params is None: default_sampling_params = {} @@ -1159,16 +1326,20 @@ class CompletionRequest(OpenAIBaseModel): ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -1176,20 +1347,25 @@ class CompletionRequest(OpenAIBaseModel): echo_without_generation = self.echo and self.max_tokens == 0 - guided_json_object = None - if (self.response_format is not None - and self.response_format.type == "json_object"): - guided_json_object = True + # Forward deprecated guided_* parameters to structured_outputs + if self.structured_outputs is None: + kwargs = dict[str, Any]( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + whitespace_pattern=self.guided_whitespace_pattern, + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + if len(kwargs) > 0: + self.structured_outputs = StructuredOutputsParams(**kwargs) - guided_decoding = GuidedDecodingParams.from_optional( - json=self.guided_json, - regex=self.guided_regex, - choice=self.guided_choice, - grammar=self.guided_grammar, - json_object=guided_json_object, - backend=self.guided_decoding_backend, - whitespace_pattern=self.guided_whitespace_pattern, - ) + if ( + self.structured_outputs is not None + and self.response_format is not None + and self.response_format.type == "json_object" + ): + self.structured_outputs.json_object = True extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: @@ -1216,42 +1392,52 @@ class CompletionRequest(OpenAIBaseModel): skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, - guided_decoding=guided_decoding, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, 
allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, - ) + ) @model_validator(mode="before") @classmethod - def check_guided_decoding_count(cls, data): - guide_count = sum([ - "guided_json" in data and data["guided_json"] is not None, - "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None - ]) - if guide_count > 1: + def check_structured_outputs_count(cls, data): + if data.get("structured_outputs", None) is None: + return data + + structured_outputs_kwargs = data["structured_outputs"] + count = sum( + structured_outputs_kwargs.get(k) is not None + for k in ("json", "regex", "choice") + ) + if count > 1: raise ValueError( - "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice').") + "You can only use one kind of constraints for structured " + "outputs ('json', 'regex' or 'choice')." + ) return data @model_validator(mode="before") @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and prompt_logprobs > 0: + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") - - if prompt_logprobs < 0: - raise ValueError("`prompt_logprobs` must be a positive value.") + "`prompt_logprobs` are not available when `stream=True`." + ) + if prompt_logprobs < 0 and prompt_logprobs != -1: + raise ValueError("`prompt_logprobs` must be a positive value or -1.") + if prompt_logprobs == -1 and not envs.VLLM_USE_V1: + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." + ) if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1261,17 +1447,26 @@ class CompletionRequest(OpenAIBaseModel): @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @model_validator(mode="before") @classmethod def validate_prompt_and_prompt_embeds(cls, data): - if data.get("prompt") is None and data.get("prompt_embeds") is None: + prompt = data.get("prompt") + prompt_embeds = data.get("prompt_embeds") + + prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") + embeds_is_empty = prompt_embeds is None or ( + isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 + ) + + if prompt_is_empty and embeds_is_empty: raise ValueError( - "At least one of `prompt` or `prompt_embeds` must be set.") + "Either prompt or prompt_embeds must be provided and non-empty." + ) + return data @model_validator(mode="before") @@ -1281,11 +1476,12 @@ class CompletionRequest(OpenAIBaseModel): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." 
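# NOTE (editorial sketch, not part of the diff): the `cache_salt` description
# above recommends a random, unpredictable value of roughly 256 bits (about 43
# base64 characters). One way a client could generate such a salt:
import secrets

cache_salt = secrets.token_urlsafe(32)  # 32 random bytes -> 43-char URL-safe string
assert isinstance(cache_salt, str) and cache_salt  # satisfies the validator above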
+ ) return data @@ -1304,29 +1500,35 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) priority: int = Field( default=0, description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize, + ) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1339,6 +1541,15 @@ class EmbeddingChatRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # --8<-- [start:chat-embedding-extra-params] + add_generation_prompt: bool = Field( + default=False, + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), + ) + add_special_tokens: bool = Field( default=False, description=( @@ -1346,7 +1557,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -1354,13 +1566,15 @@ class EmbeddingChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -1371,14 +1585,16 @@ class EmbeddingChatRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." 
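# NOTE (editorial sketch, not part of the diff): after this change the
# embedding requests (and the score/rerank/classification requests further
# below) forward `truncate_prompt_tokens` into PoolingParams instead of
# handling truncation elsewhere. Assuming the PoolingParams signature used in
# this diff, the object built by `to_pooling_params` is simply:
from vllm.pooling_params import PoolingParams  # import path assumed

params = PoolingParams(
    truncate_prompt_tokens=128,  # truncate over-long prompts to 128 tokens
    dimensions=None,
    normalize=True,
)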
+ ), ) normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @@ -1386,22 +1602,64 @@ class EmbeddingChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + dimensions=self.dimensions, + normalize=self.normalize, + ) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] PoolingCompletionRequest = EmbeddingCompletionRequest PoolingChatRequest = EmbeddingChatRequest -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + +T = TypeVar("T") + + +class IOProcessorRequest(OpenAIBaseModel, Generic[T]): + model: Optional[str] = None + + priority: int = Field(default=0) + """ + The priority of the request (lower means earlier handling; + default: 0). Any priority other than 0 will raise an error + if the served model does not use priority scheduling. + """ + data: T + """ + When using plugins IOProcessor plugins, the actual input is processed + by the plugin itself. Hence, we use a generic type for the request data + """ + softmax: bool = True + + def to_pooling_params(self): + return PoolingParams(task="encode", softmax=self.softmax) + + +class IOProcessorResponse(OpenAIBaseModel, Generic[T]): + request_id: Optional[str] = None + """ + The request_id associated with this response + """ + created_at: int = Field(default_factory=lambda: int(time.time())) + + data: T + """ + When using plugins IOProcessor plugins, the actual output is generated + by the plugin itself. Hence, we use a generic type for the response data + """ + + +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest] class ScoreRequest(OpenAIBaseModel): @@ -1422,7 +1680,8 @@ class ScoreRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) activation: Optional[bool] = None @@ -1430,7 +1689,10 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation, + ) class RerankRequest(OpenAIBaseModel): @@ -1452,7 +1714,8 @@ class RerankRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." 
+ ), ) activation: Optional[bool] = None @@ -1460,7 +1723,10 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation, + ) class RerankDocument(BaseModel): @@ -1489,8 +1755,7 @@ class CompletionLogProbs(OpenAIBaseModel): text_offset: list[int] = Field(default_factory=list) token_logprobs: list[Optional[float]] = Field(default_factory=list) tokens: list[str] = Field(default_factory=list) - top_logprobs: list[Optional[dict[str, - float]]] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list) class CompletionResponseChoice(OpenAIBaseModel): @@ -1503,7 +1768,8 @@ class CompletionResponseChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None @@ -1516,14 +1782,16 @@ class CompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." + ) class CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1536,7 +1804,8 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) # not part of the OpenAI spec but for tracing the tokens # prompt tokens is put into choice to align with CompletionResponseChoice @@ -1610,7 +1879,8 @@ class ClassificationRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." 
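# NOTE (editorial sketch, not part of the diff): like the score request above,
# RerankRequest now routes `truncate_prompt_tokens` into PoolingParams. A
# hypothetical rerank payload; the `query`/`documents` fields come from the
# wider protocol and are not shown in this hunk, and the model name is a
# placeholder:
rerank_payload = {
    "model": "my-reranker",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    "truncate_prompt_tokens": 256,  # forwarded to PoolingParams per this diff
}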
+ ), ) activation: Optional[bool] = None @@ -1618,7 +1888,10 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams( + truncate_prompt_tokens=self.truncate_prompt_tokens, + activation=self.activation, + ) class ClassificationData(OpenAIBaseModel): @@ -1722,8 +1995,9 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo @@ -1731,7 +2005,8 @@ class ChatCompletionResponse(OpenAIBaseModel): prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." + ) class DeltaMessage(OpenAIBaseModel): @@ -1782,7 +2057,8 @@ class InputTokensDetails(OpenAIBaseModel): class OutputTokensDetails(OpenAIBaseModel): - reasoning_tokens: int + reasoning_tokens: int = 0 + tool_output_tokens: int = 0 class ResponseUsage(OpenAIBaseModel): @@ -1797,7 +2073,7 @@ class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) # error: Optional[ResponseError] = None - # incomplete_details: Optional[IncompleteDetails] = None + incomplete_details: Optional[IncompleteDetails] = None instructions: Optional[str] = None metadata: Optional[Metadata] = None model: str @@ -1822,6 +2098,49 @@ class ResponsesResponse(OpenAIBaseModel): usage: Optional[ResponseUsage] = None user: Optional[str] = None + # --8<-- [start:responses-extra-params] + # These are populated when enable_response_messages is set to True + # NOTE: custom serialization is needed + # see serialize_input_messages and serialize_output_messages + input_messages: Optional[list[ChatCompletionMessageParam]] = None + output_messages: Optional[list[ChatCompletionMessageParam]] = None + # --8<-- [end:responses-extra-params] + + # NOTE: openAI harmony doesn't serialize TextContent properly, + # TODO: this fixes for TextContent, but need to verify for tools etc + # https://github.com/openai/harmony/issues/78 + @field_serializer("output_messages", when_used="json") + def serialize_output_messages(self, msgs, _info): + if msgs: + serialized = [] + for m in msgs: + if isinstance(m, dict): + serialized.append(m) + elif hasattr(m, "__dict__"): + serialized.append(m.to_dict()) + else: + # fallback to pyandic dump + serialized.append(m.model_dump_json()) + return serialized + return None + + # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it + # https://github.com/openai/harmony/issues/78 + @field_serializer("input_messages", when_used="json") + def serialize_input_messages(self, msgs, _info): + if msgs: + serialized = [] + for m in msgs: + if isinstance(m, dict): + serialized.append(m) + elif hasattr(m, "__dict__"): + serialized.append(m.to_dict()) + else: + # fallback to pyandic dump + serialized.append(m.model_dump_json()) + return serialized + return None + @classmethod def from_request( cls, @@ -1832,14 +2151,25 @@ class 
ResponsesResponse(OpenAIBaseModel): output: list[ResponseOutputItem], status: ResponseStatus, usage: Optional[ResponseUsage] = None, + input_messages: Optional[list[ChatCompletionMessageParam]] = None, + output_messages: Optional[list[ChatCompletionMessageParam]] = None, ) -> "ResponsesResponse": + incomplete_details: Optional[IncompleteDetails] = None + if status == "incomplete": + incomplete_details = IncompleteDetails(reason="max_output_tokens") + # TODO: implement the other reason for incomplete_details, + # which is content_filter + # incomplete_details = IncompleteDetails(reason='content_filter') return cls( id=request.request_id, created_at=created_time, + incomplete_details=incomplete_details, instructions=request.instructions, metadata=request.metadata, model=model_name, output=output, + input_messages=input_messages, + output_messages=output_messages, parallel_tool_calls=request.parallel_tool_calls, temperature=sampling_params.temperature, tool_choice=request.tool_choice, @@ -1861,8 +2191,89 @@ class ResponsesResponse(OpenAIBaseModel): ) -BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, - ScoreRequest, RerankRequest] +# TODO: this code can be removed once +# https://github.com/openai/openai-python/issues/2634 has been resolved +class ResponseReasoningPartDoneEvent(OpenAIBaseModel): + content_index: int + """The index of the content part that is done.""" + + item_id: str + """The ID of the output item that the content part was added to.""" + + output_index: int + """The index of the output item that the content part was added to.""" + + part: ResponseReasoningTextContent + """The content part that is done.""" + + sequence_number: int + """The sequence number of this event.""" + + type: Literal["response.reasoning_part.done"] + """The type of the event. Always `response.reasoning_part.done`.""" + + +# TODO: this code can be removed once +# https://github.com/openai/openai-python/issues/2634 has been resolved +class ResponseReasoningPartAddedEvent(OpenAIBaseModel): + content_index: int + """The index of the content part that is done.""" + + item_id: str + """The ID of the output item that the content part was added to.""" + + output_index: int + """The index of the output item that the content part was added to.""" + + part: ResponseReasoningTextContent + """The content part that is done.""" + + sequence_number: int + """The sequence number of this event.""" + + type: Literal["response.reasoning_part.added"] + """The type of the event. 
Always `response.reasoning_part.added`.""" + + +# vLLM Streaming Events +# Note: we override the response type with the vLLM ResponsesResponse type +class ResponseCompletedEvent(OpenAIResponseCompletedEvent): + response: ResponsesResponse # type: ignore[override] + + +class ResponseCreatedEvent(OpenAIResponseCreatedEvent): + response: ResponsesResponse # type: ignore[override] + + +class ResponseInProgressEvent(OpenAIResponseInProgressEvent): + response: ResponsesResponse # type: ignore[override] + + +StreamingResponsesResponse: TypeAlias = Union[ + "ResponseCreatedEvent", + "ResponseInProgressEvent", + "ResponseCompletedEvent", + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseReasoningPartAddedEvent, + ResponseReasoningPartDoneEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, + ResponseWebSearchCallCompletedEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterCallCompletedEvent, +] + +BatchRequestInputBody = Union[ + ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest +] class BatchRequestInput(OpenAIBaseModel): @@ -1887,7 +2298,7 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: BatchRequestInputBody - @field_validator('body', mode='plain') + @field_validator("body", mode="plain") @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models @@ -1911,8 +2322,9 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. - body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse, RerankResponse]] = None + body: Optional[ + Union[ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse] + ] = None class BatchRequestOutput(OpenAIBaseModel): @@ -1941,12 +2353,14 @@ class TokenizeCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) @@ -1956,24 +2370,27 @@ class TokenizeChatRequest(OpenAIBaseModel): add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. 
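# NOTE (editorial sketch, not part of the diff): `BatchRequestInput` models one
# line of the JSONL batch file, and `check_type_for_url` uses the `url` field
# to decide how to parse `body`. A hypothetical input line for the chat
# endpoint; the "method" key follows the OpenAI batch format and is assumed
# here, and the model name is a placeholder:
batch_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
}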
" - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -1982,7 +2399,8 @@ class TokenizeChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -1990,13 +2408,15 @@ class TokenizeChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -2010,10 +2430,11 @@ class TokenizeChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @@ -2038,7 +2459,7 @@ class DetokenizeResponse(OpenAIBaseModel): class TokenizerInfoResponse(OpenAIBaseModel): """ - Response containing tokenizer configuration + Response containing tokenizer configuration equivalent to tokenizer_config.json """ @@ -2057,8 +2478,7 @@ class UnloadLoRAAdapterRequest(BaseModel): ## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", - "vtt"] +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] class TranscriptionRequest(OpenAIBaseModel): @@ -2100,7 +2520,8 @@ class TranscriptionRequest(OpenAIBaseModel): ## TODO (varun) : Support if set to 0, certain thresholds are met !! timestamp_granularities: list[Literal["word", "segment"]] = Field( - alias="timestamp_granularities[]", default=[]) + alias="timestamp_granularities[]", default=[] + ) """The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. @@ -2120,11 +2541,20 @@ class TranscriptionRequest(OpenAIBaseModel): vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." 
+ ), ) # --8<-- [end:transcription-extra-params] + to_language: Optional[str] = None + """The language of the output audio we transcribe to. + + Please note that this is not currently used by supported models at this + time, but it is a placeholder for future use, matching translation api. + """ + # --8<-- [start:transcription-sampling-params] temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -2171,10 +2601,8 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2183,35 +2611,42 @@ class TranscriptionRequest(OpenAIBaseModel): # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY, - extra_args=self.vllm_xargs) + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs, + ) @model_validator(mode="before") @classmethod @@ -2225,16 +2660,21 @@ class TranscriptionRequest(OpenAIBaseModel): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data # Transcription response objects +class TranscriptionUsageAudio(OpenAIBaseModel): + type: Literal["duration"] = "duration" + seconds: int + + class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + usage: TranscriptionUsageAudio class TranscriptionWord(OpenAIBaseModel): @@ -2352,6 +2792,9 @@ class TranslationRequest(OpenAIBaseModel): # TODO support additional 
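# NOTE (editorial sketch, not part of the diff): TranscriptionResponse now
# reports audio duration through a `usage` object, and TranslationRequest
# (below) gains `seed` and `to_language`. A hedged example of calling the
# transcription endpoint over plain HTTP; host, model name and audio file are
# placeholders, and the multipart field names follow the request models shown
# in this diff:
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={"model": "my-whisper-model", "response_format": "json"},
    )
# e.g. {"text": "...", "usage": {"type": "duration", "seconds": 3}}
print(resp.json())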
sampling parameters # --8<-- [start:translation-sampling-params] + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -2371,6 +2814,14 @@ class TranslationRequest(OpenAIBaseModel): will improve accuracy. """ + to_language: Optional[str] = None + """The language of the input audio we translate to. + + Please note that this is not supported by all models, refer to the specific + model documentation for more details. + For instance, Whisper only supports `to_language=en`. + """ + stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -2387,10 +2838,8 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2398,13 +2847,17 @@ class TranslationRequest(OpenAIBaseModel): # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY) + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + ) @model_validator(mode="before") @classmethod @@ -2412,8 +2865,7 @@ class TranslationRequest(OpenAIBaseModel): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 01551a8c7f04a..e394f24f8793f 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -14,23 +14,22 @@ import torch from prometheus_client import start_http_server from tqdm import tqdm -import vllm.envs as envs -from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -# yapf: disable -from vllm.entrypoints.openai.protocol import (BatchRequestInput, - BatchRequestOutput, - BatchResponseData, - ChatCompletionResponse, - EmbeddingResponse, ErrorResponse, - RerankResponse, ScoreResponse) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, + ErrorResponse, + RerankResponse, + ScoreResponse, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) 
+from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_score import ServingScores from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser, random_uuid @@ -45,10 +44,10 @@ def make_arg_parser(parser: FlexibleArgumentParser): "--input-file", required=True, type=str, - help= - "The path or url to a single input file. Currently supports local file " + help="The path or url to a single input file. Currently supports local file " "paths, or the http protocol (http or https). If a URL is specified, " - "the file should be available via HTTP GET.") + "the file should be available via HTTP GET.", + ) parser.add_argument( "-o", "--output-file", @@ -56,7 +55,8 @@ def make_arg_parser(parser: FlexibleArgumentParser): type=str, help="The path or url to a single output file. Currently supports " "local file paths, or web (http or https) urls. If a URL is specified," - " the file should be available via HTTP PUT.") + " the file should be available via HTTP PUT.", + ) parser.add_argument( "--output-tmp-dir", type=str, @@ -64,24 +64,27 @@ def make_arg_parser(parser: FlexibleArgumentParser): help="The directory to store the output file before uploading it " "to the output URL.", ) - parser.add_argument("--response-role", - type=optional_type(str), - default="assistant", - help="The role name to return if " - "`request.add_generation_prompt=True`.") + parser.add_argument( + "--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if `request.add_generation_prompt=True`.", + ) parser = AsyncEngineArgs.add_cli_args(parser) - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - '\n\nDefault: Unlimited') + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="Max number of prompt characters or prompt " + "ID numbers being printed in log." 
+ "\n\nDefault: Unlimited", + ) - parser.add_argument("--enable-metrics", - action="store_true", - help="Enable Prometheus metrics") + parser.add_argument( + "--enable-metrics", action="store_true", help="Enable Prometheus metrics" + ) parser.add_argument( "--url", type=str, @@ -98,16 +101,16 @@ def make_arg_parser(parser: FlexibleArgumentParser): ) parser.add_argument( "--enable-prompt-tokens-details", - action='store_true', + action="store_true", default=False, - help="If set to True, enable prompt_tokens_details in usage.") + help="If set to True, enable prompt_tokens_details in usage.", + ) return parser def parse_args(): - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible batch runner.") + parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.") return make_arg_parser(parser).parse_args() @@ -119,7 +122,6 @@ _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elap class BatchProgressTracker: - def __init__(self): self._total = 0 self._pbar: Optional[tqdm] = None @@ -132,43 +134,45 @@ class BatchProgressTracker: self._pbar.update() def pbar(self) -> tqdm: - enable_tqdm = not torch.distributed.is_initialized( - ) or torch.distributed.get_rank() == 0 - self._pbar = tqdm(total=self._total, - unit="req", - desc="Running batch", - mininterval=5, - disable=not enable_tqdm, - bar_format=_BAR_FORMAT) + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + self._pbar = tqdm( + total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ) return self._pbar async def read_file(path_or_url: str) -> str: if path_or_url.startswith("http://") or path_or_url.startswith("https://"): - async with aiohttp.ClientSession() as session, \ - session.get(path_or_url) as resp: + async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp: return await resp.text() else: with open(path_or_url, encoding="utf-8") as f: return f.read() -async def write_local_file(output_path: str, - batch_outputs: list[BatchRequestOutput]) -> None: +async def write_local_file( + output_path: str, batch_outputs: list[BatchRequestOutput] +) -> None: """ Write the responses to a local file. output_path: The path to write the responses to. batch_outputs: The list of batch outputs to write. """ # We should make this async, but as long as run_batch runs as a - # standalone program, blocking the event loop won't effect performance. + # standalone program, blocking the event loop won't affect performance. with open(output_path, "w", encoding="utf-8") as f: for o in batch_outputs: print(o.model_dump_json(), file=f) -async def upload_data(output_url: str, data_or_file: str, - from_file: bool) -> None: +async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None: """ Upload a local file to a URL. output_url: The URL to upload the file to. @@ -185,23 +189,26 @@ async def upload_data(output_url: str, data_or_file: str, try: # We increase the timeout to 1000 seconds to allow # for large files (default is 300). 
- async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( - total=1000)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1000) + ) as session: if from_file: with open(data_or_file, "rb") as file: - async with session.put(output_url, - data=file) as response: + async with session.put(output_url, data=file) as response: if response.status != 200: - raise Exception(f"Failed to upload file.\n" - f"Status: {response.status}\n" - f"Response: {response.text()}") + raise Exception( + f"Failed to upload file.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) else: - async with session.put(output_url, - data=data_or_file) as response: + async with session.put(output_url, data=data_or_file) as response: if response.status != 200: - raise Exception(f"Failed to upload data.\n" - f"Status: {response.status}\n" - f"Response: {response.text()}") + raise Exception( + f"Failed to upload data.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) except Exception as e: if attempt < max_retries: @@ -218,8 +225,9 @@ async def upload_data(output_url: str, data_or_file: str, ) from e -async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], - output_tmp_dir: str) -> None: +async def write_file( + path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str +) -> None: """ Write batch_outputs to a file or upload to a URL. path_or_url: The path or URL to write batch_outputs to. @@ -243,14 +251,13 @@ async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], else: # Write responses to a temporary file and then upload it to the URL. with tempfile.NamedTemporaryFile( - mode="w", - encoding="utf-8", - dir=output_tmp_dir, - prefix="tmp_batch_output_", - suffix=".jsonl", + mode="w", + encoding="utf-8", + dir=output_tmp_dir, + prefix="tmp_batch_output_", + suffix=".jsonl", ) as f: - logger.info("Writing outputs to temporary local file %s", - f.name) + logger.info("Writing outputs to temporary local file %s", f.name) await write_local_file(f.name, batch_outputs) logger.info("Uploading outputs to %s", path_or_url) await upload_data(path_or_url, f.name, from_file=True) @@ -259,8 +266,9 @@ async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], await write_local_file(path_or_url, batch_outputs) -def make_error_request_output(request: BatchRequestInput, - error_msg: str) -> BatchRequestOutput: +def make_error_request_output( + request: BatchRequestInput, error_msg: str +) -> BatchRequestOutput: batch_output = BatchRequestOutput( id=f"vllm-{random_uuid()}", custom_id=request.custom_id, @@ -274,25 +282,28 @@ def make_error_request_output(request: BatchRequestInput, async def make_async_error_request_output( - request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + request: BatchRequestInput, error_msg: str +) -> BatchRequestOutput: return make_error_request_output(request, error_msg) -async def run_request(serving_engine_func: Callable, - request: BatchRequestInput, - tracker: BatchProgressTracker) -> BatchRequestOutput: +async def run_request( + serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker, +) -> BatchRequestOutput: response = await serving_engine_func(request.body) if isinstance( - response, - (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, - RerankResponse), + response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse), ): batch_output = BatchRequestOutput( 
id=f"vllm-{random_uuid()}", custom_id=request.custom_id, response=BatchResponseData( - body=response, request_id=f"vllm-batch-{random_uuid()}"), + body=response, request_id=f"vllm-batch-{random_uuid()}" + ), error=None, ) elif isinstance(response, ErrorResponse): @@ -301,12 +312,14 @@ async def run_request(serving_engine_func: Callable, custom_id=request.custom_id, response=BatchResponseData( status_code=response.error.code, - request_id=f"vllm-batch-{random_uuid()}"), + request_id=f"vllm-batch-{random_uuid()}", + ), error=response, ) else: batch_output = make_error_request_output( - request, error_msg="Request must not be sent in stream mode") + request, error_msg="Request must not be sent in stream mode" + ) tracker.completed() return batch_output @@ -314,7 +327,6 @@ async def run_request(serving_engine_func: Callable, async def run_batch( engine_client: EngineClient, - vllm_config: VllmConfig, args: Namespace, ) -> None: if args.served_model_name is not None: @@ -328,55 +340,58 @@ async def run_batch( request_logger = None base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names + BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] - model_config = vllm_config.model_config - - if envs.VLLM_USE_V1: - supported_tasks = await engine_client \ - .get_supported_tasks() # type: ignore - else: - supported_tasks = model_config.supported_tasks - - logger.info("Supported_tasks: %s", supported_tasks) + model_config = engine_client.model_config + supported_tasks = await engine_client.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) # Create the openai serving objects. openai_serving_models = OpenAIServingModels( engine_client=engine_client, - model_config=model_config, base_model_paths=base_model_paths, lora_modules=None, ) - openai_serving_chat = OpenAIServingChat( - engine_client, - model_config, - openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=None, - chat_template_content_format="auto", - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if "generate" in supported_tasks else None - openai_serving_embedding = OpenAIServingEmbedding( - engine_client, - model_config, - openai_serving_models, - request_logger=request_logger, - chat_template=None, - chat_template_content_format="auto", - ) if "embed" in supported_tasks else None + openai_serving_chat = ( + OpenAIServingChat( + engine_client, + openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) + if "generate" in supported_tasks + else None + ) + openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + openai_serving_models, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + ) + if "embed" in supported_tasks + else None + ) - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) + enable_serving_reranking = ( + "classify" in supported_tasks + and getattr(model_config.hf_config, "num_labels", 0) == 1 + ) - openai_serving_scores = ServingScores( - engine_client, - model_config, - openai_serving_models, - request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None + openai_serving_scores = ( + ServingScores( + engine_client, + openai_serving_models, + 
request_logger=request_logger, + ) + if ("embed" in supported_tasks or enable_serving_reranking) + else None + ) tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) @@ -393,61 +408,72 @@ async def run_batch( # Determine the type of request and run it. if request.url == "/v1/chat/completions": - chat_handler_fn = openai_serving_chat.create_chat_completion if \ - openai_serving_chat is not None else None + chat_handler_fn = ( + openai_serving_chat.create_chat_completion + if openai_serving_chat is not None + else None + ) if chat_handler_fn is None: response_futures.append( make_async_error_request_output( request, - error_msg= - "The model does not support Chat Completions API", - )) + error_msg="The model does not support Chat Completions API", + ) + ) continue - response_futures.append( - run_request(chat_handler_fn, request, tracker)) + response_futures.append(run_request(chat_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/embeddings": - embed_handler_fn = openai_serving_embedding.create_embedding if \ - openai_serving_embedding is not None else None + embed_handler_fn = ( + openai_serving_embedding.create_embedding + if openai_serving_embedding is not None + else None + ) if embed_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Embeddings API", - )) + ) + ) continue - response_futures.append( - run_request(embed_handler_fn, request, tracker)) + response_futures.append(run_request(embed_handler_fn, request, tracker)) tracker.submitted() elif request.url.endswith("/score"): - score_handler_fn = openai_serving_scores.create_score if \ - openai_serving_scores is not None else None + score_handler_fn = ( + openai_serving_scores.create_score + if openai_serving_scores is not None + else None + ) if score_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Scores API", - )) + ) + ) continue - response_futures.append( - run_request(score_handler_fn, request, tracker)) + response_futures.append(run_request(score_handler_fn, request, tracker)) tracker.submitted() elif request.url.endswith("/rerank"): - rerank_handler_fn = openai_serving_scores.do_rerank if \ - openai_serving_scores is not None else None + rerank_handler_fn = ( + openai_serving_scores.do_rerank + if openai_serving_scores is not None + else None + ) if rerank_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Rerank API", - )) + ) + ) continue - response_futures.append( - run_request(rerank_handler_fn, request, tracker)) + response_futures.append(run_request(rerank_handler_fn, request, tracker)) tracker.submitted() else: response_futures.append( @@ -458,7 +484,8 @@ async def run_batch( " /score, /rerank ." 
"See vllm/entrypoints/openai/api_server.py for supported " "score/rerank versions.", - )) + ) + ) with tracker.pbar(): responses = await asyncio.gather(*response_futures) @@ -471,13 +498,11 @@ async def main(args: Namespace): from vllm.usage.usage_lib import UsageContext async with build_async_engine_client( - args, - usage_context=UsageContext.OPENAI_BATCH_RUNNER, - disable_frontend_multiprocessing=False, + args, + usage_context=UsageContext.OPENAI_BATCH_RUNNER, + disable_frontend_multiprocessing=False, ) as engine_client: - vllm_config = await engine_client.get_vllm_config() - - await run_batch(engine_client, vllm_config, args) + await run_batch(engine_client, args) if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 65aac23ee618e..94c24ce9b307a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -6,7 +6,7 @@ import json import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence -from typing import Callable, Final, Optional, Union +from typing import Final, Optional, Union import jinja2 import partial_json_parser @@ -15,59 +15,77 @@ from fastapi import Request from openai_harmony import Message as OpenAIMessage from pydantic import TypeAdapter -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - ConversationMessage, - get_history_tool_calls_cnt, - make_tool_call_id) +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, + ConversationMessage, + get_history_tool_calls_cnt, + make_tool_call_id, +) from vllm.entrypoints.harmony_utils import ( - get_developer_message, get_stop_tokens_for_assistant_actions, - get_streamable_parser_for_assistant, get_system_message, parse_chat_input, - parse_chat_output, render_for_completion) + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, + get_system_message, + parse_chat_input, + parse_chat_output, + render_for_completion, +) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - ChatCompletionLogProb, ChatCompletionLogProbs, - ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, - ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, - PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ErrorResponse, + FunctionCall, + FunctionDefinition, + PromptTokenUsageInfo, + RequestResponseMetadata, + ToolCall, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager -from 
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( - MistralToolCall) +from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput -from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, - truncate_tool_call_ids, - validate_request_params) +from vllm.transformers_utils.tokenizers import ( + maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params, +) from vllm.utils import as_list logger = init_logger(__name__) class OpenAIServingChat(OpenAIServing): - def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, response_role: str, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, return_tokens_as_token_ids: bool = False, reasoning_parser: str = "", enable_auto_tools: bool = False, @@ -76,75 +94,57 @@ class OpenAIServingChat(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, + ) self.response_role = response_role self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template self.enable_log_outputs = enable_log_outputs + # set up reasoning parser + self.reasoning_parser = self._get_reasoning_parser( + reasoning_parser_name=reasoning_parser + ) # set up tool use self.enable_auto_tools: bool = enable_auto_tools - if self.enable_auto_tools: - logger.info( - "\"auto\" tool choice has been enabled please note that while" - " the parallel_tool_calls client option is preset for " - "compatibility reasons, it will be ignored.") - - self.reasoning_parser: Optional[Callable[[AnyTokenizer], - ReasoningParser]] = None - if reasoning_parser: - try: - self.reasoning_parser = ( - ReasoningParserManager.get_reasoning_parser( - reasoning_parser)) - assert self.reasoning_parser is not None - except Exception as e: - raise TypeError( - f"{reasoning_parser=} has not been registered") from e - self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None - if self.enable_auto_tools: - try: - if (tool_parser == "pythonic" and - model_config.model.startswith("meta-llama/Llama-3.2")): - logger.warning( - "Llama3.2 models may struggle to emit valid pythonic" - " tool calls") - self.tool_parser 
= ToolParserManager.get_tool_parser( - tool_parser) - except Exception as e: - raise TypeError("Error: --enable-auto-tool-choice requires " - f"tool_parser:'{tool_parser}' which has not " - "been registered") from e - self.exclude_tools_when_tool_choice_none = ( - exclude_tools_when_tool_choice_none) + self.tool_parser = self._get_tool_parser( + tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools + ) + self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) - if self.model_config.hf_config.model_type == 'kimi_k2': - self.tool_call_id_type = 'kimi_k2' + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) + if self.model_config.hf_config.model_type == "kimi_k2": + self.tool_call_id_type = "kimi_k2" else: - self.tool_call_id_type = 'random' + self.tool_call_id_type = "random" - self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: if "stop_token_ids" not in self.default_sampling_params: self.default_sampling_params["stop_token_ids"] = [] self.default_sampling_params["stop_token_ids"].extend( - get_stop_tokens_for_assistant_actions()) + get_stop_tokens_for_assistant_actions() + ) # NOTE(woosuk): While OpenAI's chat completion API supports browsing # for some models, currently vLLM doesn't support it. Please use the @@ -160,8 +160,7 @@ class OpenAIServingChat(OpenAIServing): self, request: ChatCompletionRequest, raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, - ErrorResponse]: + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: """ Chat Completion API similar to OpenAI's API. 
@@ -182,11 +181,12 @@ class OpenAIServingChat(OpenAIServing): try: lora_request = self._maybe_get_adapters( - request, supports_default_mm_loras=True) + request, supports_default_mm_loras=True + ) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() tool_parser = self.tool_parser @@ -198,26 +198,36 @@ class OpenAIServingChat(OpenAIServing): truncate_tool_call_ids(request) validate_request_params(request) - if (request.tool_choice == "auto" and - not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer) - and not self.use_harmony): + if ( + request.tool_choice == "auto" + and not (self.enable_auto_tools and tool_parser is not None) + and not isinstance(tokenizer, MistralTokenizer) + and not self.use_harmony + ): # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( - "\"auto\" tool choice requires " + '"auto" tool choice requires ' "--enable-auto-tool-choice and --tool-call-parser to be set" ) - if (request.tools is None - or (request.tool_choice == "none" - and self.exclude_tools_when_tool_choice_none)): + if request.tools is None or ( + request.tool_choice == "none" + and self.exclude_tools_when_tool_choice_none + ): tool_dicts = None else: tool_dicts = [tool.model_dump() for tool in request.tools] if not self.use_harmony: # Common case. + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( conversation, request_prompts, @@ -227,15 +237,13 @@ class OpenAIServingChat(OpenAIServing): tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. 
- chat_template_content_format, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, tool_dicts=tool_dicts, documents=request.documents, chat_template_kwargs=request.chat_template_kwargs, tool_parser=tool_parser, - truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) else: @@ -245,13 +253,13 @@ class OpenAIServingChat(OpenAIServing): request_prompts, engine_prompts, ) = self._make_request_with_harmony(request) - except (ValueError, TypeError, RuntimeError, - jinja2.TemplateError) as e: + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_id = "chatcmpl-" \ - f"{self._base_request_id(raw_request, request.request_id)}" + request_id = ( + f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" + ) request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: @@ -261,7 +269,7 @@ class OpenAIServingChat(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - sampling_params: Union[SamplingParams, BeamSearchParams] + prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) if self.default_sampling_params is None: self.default_sampling_params = {} @@ -270,39 +278,60 @@ class OpenAIServingChat(OpenAIServing): max_model_len=self.max_model_len, request=request, input_length=len(engine_prompt["prompt_token_ids"]), - default_sampling_params=self.default_sampling_params) + default_sampling_params=self.default_sampling_params, + ) + sampling_params: Union[SamplingParams, BeamSearchParams] if request.use_beam_search: sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params + ) else: sampling_params = request.to_sampling_params( - max_tokens, self.model_config.logits_processor_pattern, - self.default_sampling_params) + max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params, + ) - self._log_inputs(request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request) + self._log_inputs( + request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) if isinstance(sampling_params, BeamSearchParams): - generator = self.engine_client.beam_search( + generator = self.beam_search( prompt=engine_prompt, request_id=request_id, params=sampling_params, lora_request=lora_request, ) else: - generator = self.engine_client.generate( + engine_request, tokenization_kwargs = await self._process_inputs( + request_id, engine_prompt, sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generator = self.engine_client.generate( + engine_request, + sampling_params, request_id, lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, + prompt_text=prompt_text, + tokenization_kwargs=tokenization_kwargs, ) generators.append(generator) @@ -311,7 +340,7 @@ class OpenAIServingChat(OpenAIServing): return self.create_error_response(str(e)) assert 
len(generators) == 1 - result_generator, = generators + (result_generator,) = generators # Streaming response if request.stream: @@ -323,12 +352,19 @@ class OpenAIServingChat(OpenAIServing): conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage) + enable_force_include_usage=self.enable_force_include_usage, + ) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, model_name, - conversation, tokenizer, request_metadata) + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -339,7 +375,7 @@ class OpenAIServingChat(OpenAIServing): return request.messages[-1]["role"] @staticmethod - def _bracket_level(s: str, opening='{', closing='}') -> int: + def _bracket_level(s: str, opening="{", closing="}") -> int: """ Calculate the current level of nested brackets in a given string. """ @@ -352,8 +388,7 @@ class OpenAIServingChat(OpenAIServing): return level @staticmethod - def _filter_delta_text(delta_text: str, - previous_text: str) -> tuple[str, bool]: + def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]: # remove last '},' of the tool definition stemming from the # "name"/"parameters" outer object or closing ']' of the tool list # count occurrences of opening and closing curly braces and @@ -363,10 +398,10 @@ class OpenAIServingChat(OpenAIServing): bracket_level = OpenAIServingChat._bracket_level(previous_text) updated_delta, passed_zero = "", False for c in delta_text: - if c == '{': + if c == "{": bracket_level += 1 passed_zero = bracket_level == 0 - elif c == '}': + elif c == "}": bracket_level -= 1 passed_zero = bracket_level == 0 @@ -374,7 +409,7 @@ class OpenAIServingChat(OpenAIServing): updated_delta += c else: # if a comma is reached at level 0 we can stop - if c == ',': + if c == ",": break return updated_delta, passed_zero @@ -384,7 +419,7 @@ class OpenAIServingChat(OpenAIServing): current_text: Optional[str], delta_text: str, function_name_returned: bool, - tool_call_idx: Optional[int] = None + tool_call_idx: Optional[int] = None, ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -392,7 +427,7 @@ class OpenAIServingChat(OpenAIServing): try: obj = partial_json_parser.loads(current_text) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") obj = None # check if the current text is a valid array @@ -403,60 +438,72 @@ class OpenAIServingChat(OpenAIServing): delta_message = None else: _, finishes_previous_tool = OpenAIServingChat._filter_delta_text( - delta_text, previous_text) + delta_text, previous_text + ) # take the last tool call from the generated list current_tool_call = obj[-1] # once parameters have been generated the name is complete as well - if not finishes_previous_tool and ("name" not in current_tool_call - or "parameters" - not in current_tool_call): + if not finishes_previous_tool and ( + "name" not in current_tool_call or "parameters" not in current_tool_call + ): function_name_returned = False delta_message = None else: if not function_name_returned: # get partly generated arguments from the latest tool call - param_match = re.search(r'.*"parameters":\s*(.*)', - current_text) 
+ param_match = re.search( + r'.*"parameters":\s*(.*)', current_text, re.DOTALL + ) arguments = param_match.group(1) if param_match else "" arguments, _ = OpenAIServingChat._filter_delta_text( - arguments, previous_text) + arguments, previous_text + ) # if this iteration finishes a previous tool call but a # new incomplete tool is already generated, take the # previous from the list - if (finishes_previous_tool - and "parameters" not in current_tool_call): + if finishes_previous_tool and "parameters" not in current_tool_call: current_tool_call = obj[-2] function_name_returned = True tool_call_id = make_tool_call_id( id_type=self.tool_call_id_type, func_name=current_tool_call["name"], - idx=tool_call_idx) - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=tool_call_id, - function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments), - index=len(obj) - 1, - type="function") - ]) + idx=tool_call_idx, + ) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=tool_call_id, + function=DeltaFunctionCall( + name=current_tool_call["name"], arguments=arguments + ), + index=len(obj) - 1, + type="function", + ) + ] + ) else: delta_text, _ = OpenAIServingChat._filter_delta_text( - delta_text, previous_text) + delta_text, previous_text + ) if delta_text != "": - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - function=DeltaFunctionCall( - # OpenAI API returns None - # instead of name every time - name=None, - arguments=delta_text), - index=len(obj) - 1) - ]) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + function=DeltaFunctionCall( + # OpenAI API returns None + # instead of name every time + name=None, + arguments=delta_text, + ), + index=len(obj) - 1, + ) + ] + ) else: delta_message = None @@ -485,9 +532,10 @@ class OpenAIServingChat(OpenAIServing): num_cached_tokens = None if self.use_harmony: harmony_parsers = [ - get_streamable_parser_for_assistant() - for _ in range(num_choices) + get_streamable_parser_for_assistant() for _ in range(num_choices) ] + harmony_tools_streamed = [False] * num_choices + tools_streamed = [False] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -497,11 +545,12 @@ class OpenAIServingChat(OpenAIServing): # Determine whether tools are in use with "auto" tool choice tool_choice_auto = ( not tool_choice_function_name - and self._should_stream_with_auto_tool_parsing(request)) + and self._should_stream_with_auto_tool_parsing(request) + ) all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices - if self.tool_call_id_type == 'kimi_k2': + if self.tool_call_id_type == "kimi_k2": history_tool_call_cnt = get_history_tool_calls_cnt(conversation) else: history_tool_call_cnt = 0 @@ -548,10 +597,10 @@ class OpenAIServingChat(OpenAIServing): stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage \ - or enable_force_include_usage - include_continuous_usage = include_usage and \ - stream_options.continuous_usage_stats + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = ( + include_usage and stream_options.continuous_usage_stats + ) else: include_usage, include_continuous_usage = False, False @@ -581,7 +630,8 @@ class OpenAIServingChat(OpenAIServing): content="", ), logprobs=None, - finish_reason=None) + finish_reason=None, + ) # return prompt_token_ids at the first 
chunk ever chunk = ChatCompletionStreamResponse( @@ -590,16 +640,20 @@ class OpenAIServingChat(OpenAIServing): created=created_time, choices=[choice_data], model=model_name, - prompt_token_ids=(res.prompt_token_ids - if request.return_token_ids else - None)) + prompt_token_ids=( + res.prompt_token_ids + if request.return_token_ids + else None + ), + ) # if continuous usage stats are requested, add it if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -608,33 +662,36 @@ class OpenAIServingChat(OpenAIServing): # last message if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if conversation and "content" in conversation[ - -1] and conversation[-1].get("role") == role: + if ( + conversation + and "content" in conversation[-1] + and conversation[-1].get("role") == role + ): last_msg_content = conversation[-1]["content"] or "" if last_msg_content: for i in range(num_choices): - choice_data = ( - ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage( - content=last_msg_content), - logprobs=None, - finish_reason=None)) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + logprobs=None, + finish_reason=None, + ) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) - data = chunk.model_dump_json( - exclude_unset=True) + data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" first_iteration = False @@ -646,42 +703,41 @@ class OpenAIServingChat(OpenAIServing): continue if request.logprobs and request.top_logprobs is not None: - assert output.logprobs is not None, ( - "Did not output logprobs") + assert output.logprobs is not None, "Did not output logprobs" logprobs = self._create_chat_logprobs( token_ids=output.token_ids, top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, - return_as_token_id=request. - return_tokens_as_token_ids, + return_as_token_id=request.return_tokens_as_token_ids, ) else: logprobs = None if self.use_harmony: harmony_parser = harmony_parsers[i] + prev_recipient = harmony_parser.current_recipient + delta_text = "" for token_id in output.token_ids: harmony_parser.process(token_id) - # FIXME(woosuk): Support function calling - is_final = harmony_parser.current_channel == "final" - if not (request.include_reasoning or is_final): - # Skip the reasoning content. 
- continue - delta_text = harmony_parser.last_content_delta or "" + delta_text += harmony_parser.last_content_delta or "" + cur_channel = harmony_parser.current_channel + cur_recipient = harmony_parser.current_recipient else: delta_text = output.text - if not delta_text and not output.token_ids and \ - not previous_num_tokens[i]: + if ( + not delta_text + and not output.token_ids + and not previous_num_tokens[i] + ): # Chunked prefill case, don't return empty chunks continue delta_message: Optional[DeltaMessage] # just update previous_texts and previous_token_ids - if ((tool_choice_auto or self.reasoning_parser) - and not self.use_harmony): + if tool_choice_auto or self.reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -690,42 +746,102 @@ class OpenAIServingChat(OpenAIServing): # avoid the None + list error. if previous_token_ids: current_token_ids = previous_token_ids + as_list( - output.token_ids) + output.token_ids + ) else: current_token_ids = as_list(output.token_ids) if self.use_harmony: - if is_final: + if cur_channel == "final": delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if request.include_reasoning: + delta_message = DeltaMessage( + reasoning_content=delta_text + ) + else: + delta_message = None + elif ( + cur_channel == "commentary" + and cur_recipient + and cur_recipient.startswith("functions.") + ): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split("functions.", 1)[1] + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ] + ) + elif delta_text: + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall( + arguments=delta_text + ), + ) + ] + ) + else: + delta_message = None + + if delta_message is not None: + harmony_tools_streamed[i] = True else: - delta_message = DeltaMessage( - reasoning_content=delta_text) + delta_message = None # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: - if (self.reasoning_parser and not reasoning_end_arr[i] - and not reasoning_parser.is_reasoning_end( - previous_token_ids)): + if ( + self.reasoning_parser + and not reasoning_end_arr[i] + and not reasoning_parser.is_reasoning_end( + previous_token_ids + ) + ): assert reasoning_parser is not None delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( + reasoning_parser.extract_reasoning_content_streaming( previous_text, current_text, delta_text, previous_token_ids, current_token_ids, output.token_ids, - )) + ) + ) # When encountering think end id in delta_token_ids # or think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. # Only keep 'content', remove 'reasoning_content'. 
if reasoning_parser.is_reasoning_end( - as_list(output.token_ids)) or ( - res.prompt_token_ids - and reasoning_parser.is_reasoning_end( - res.prompt_token_ids)): + as_list(output.token_ids) + ) or ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids + ) + ): reasoning_end_arr[i] = True if delta_message and delta_message.content: # This need to be added to next `delta_text` @@ -741,22 +857,27 @@ class OpenAIServingChat(OpenAIServing): if function_name_returned[i]: delta_tool_call = DeltaToolCall( - function=DeltaFunctionCall( - arguments=delta_text), - index=i) + function=DeltaFunctionCall(arguments=delta_text), + index=i, + ) else: delta_tool_call = DeltaToolCall( id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, - arguments=delta_text), - index=i) + arguments=delta_text, + ), + index=i, + ) function_name_returned[i] = True - delta_message = DeltaMessage(tool_calls=[ - delta_tool_call, - ]) + delta_message = DeltaMessage( + tool_calls=[ + delta_tool_call, + ] + ) + tools_streamed[i] = True elif request.tool_choice == "required": assert previous_texts is not None @@ -765,11 +886,9 @@ class OpenAIServingChat(OpenAIServing): fn_name_returned = function_name_returned[i] if self.reasoning_parser: - _, content = \ - reasoning_parser.extract_reasoning_content( - current_text, - request - ) + _, content = reasoning_parser.extract_reasoning_content( + current_text, request + ) else: content = current_text delta_message, function_name_returned[i] = ( @@ -778,13 +897,16 @@ class OpenAIServingChat(OpenAIServing): current_text=content, delta_text=delta_text, function_name_returned=fn_name_returned, - tool_call_idx=history_tool_call_cnt)) - if (delta_message and delta_message.tool_calls and - delta_message.tool_calls[0].id is not None): + tool_call_idx=history_tool_call_cnt, + ) + ) + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): history_tool_call_cnt += 1 - - # update the previous values for the next iteration - previous_texts[i] = current_text + tools_streamed[i] = True # handle streaming deltas for tools with "auto" tool choice # and reasoning parser @@ -796,23 +918,26 @@ class OpenAIServingChat(OpenAIServing): output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( + reasoning_parser.extract_reasoning_content_streaming( previous_text, current_text, delta_text, previous_token_ids, current_token_ids, output_token_ids, - )) + ) + ) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. - if res.prompt_token_ids and \ - reasoning_parser.is_reasoning_end( - res.prompt_token_ids): + if ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids + ) + ): reasoning_end_arr[i] = True current_token_ids = output_token_ids if delta_message and delta_message.content: @@ -824,12 +949,13 @@ class OpenAIServingChat(OpenAIServing): # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. 
- if reasoning_parser.is_reasoning_end( - output_token_ids): + if reasoning_parser.is_reasoning_end(output_token_ids): reasoning_end_arr[i] = True - current_token_ids = \ + current_token_ids = ( reasoning_parser.extract_content_ids( - output_token_ids) + output_token_ids + ) + ) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -849,45 +975,52 @@ class OpenAIServingChat(OpenAIServing): delta_text = current_text delta_token_ids = current_token_ids - delta_message = ( - tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request)) - # when only tool calls - elif tool_choice_auto: - assert tool_parser is not None - delta_message = ( - tool_parser.extract_tool_calls_streaming( + delta_message = tool_parser.extract_tool_calls_streaming( previous_text=previous_text, current_text=current_text, delta_text=delta_text, previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, - delta_token_ids=output.token_ids, - request=request)) + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True # when only reasoning elif self.reasoning_parser: - delta_message = (reasoning_parser. 
- extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - )) + delta_message = ( + reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + ) # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if tool_choice_auto or self.reasoning_parser: + if ( + tool_choice_auto or self.reasoning_parser + ) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -905,7 +1038,10 @@ class OpenAIServingChat(OpenAIServing): # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - continue + if output.finish_reason is None: + continue + else: + delta_message = DeltaMessage() # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: @@ -916,7 +1052,8 @@ class OpenAIServingChat(OpenAIServing): delta_content = "".join( tc.function.arguments for tc in delta_message.tool_calls - if tc.function and tc.function.arguments) + if tc.function and tc.function.arguments + ) if delta_content: self.request_logger.log_outputs( @@ -935,71 +1072,101 @@ class OpenAIServingChat(OpenAIServing): delta=delta_message, logprobs=logprobs, finish_reason=None, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None)) + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), + ) # if the model is finished generating else: # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously # matched by partial json parsing - # only happens if we are NOT using guided decoding + # only happens if we are NOT using structured outputs auto_tools_called = False if tool_parser: - auto_tools_called = len( - tool_parser.prev_tool_call_arr) > 0 - index = len(tool_parser.prev_tool_call_arr - ) - 1 if auto_tools_called else 0 + auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0 + index = ( + len(tool_parser.prev_tool_call_arr) - 1 + if auto_tools_called + else 0 + ) else: index = 0 - if self._should_check_for_unstreamed_tool_arg_tokens( - delta_message, output) and tool_parser: + if ( + self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output + ) + and tool_parser + ): latest_delta_len = 0 - if ((isinstance( + if ( + isinstance( delta_message.tool_calls[0].function, - DeltaFunctionCall)) and isinstance( - delta_message.tool_calls[0].function. - arguments, str)): + DeltaFunctionCall, + ) + ) and isinstance( + delta_message.tool_calls[0].function.arguments, str + ): latest_delta_len = len( - delta_message.tool_calls[0].function. 
- arguments) + delta_message.tool_calls[0].function.arguments + ) # get the expected call based on partial JSON # parsing which "autocompletes" the JSON expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( - "arguments", {}), - ensure_ascii=False) + "arguments", {} + ), + ensure_ascii=False, + ) # get what we've streamed so far for arguments # for the current tool - actual_call = tool_parser.streamed_args_for_tool[ - index] - if (latest_delta_len > 0): + actual_call = tool_parser.streamed_args_for_tool[index] + if latest_delta_len > 0: actual_call = actual_call[:-latest_delta_len] # check to see if there's anything left to stream - remaining_call = expected_call.replace( - actual_call, "", 1) + remaining_call = expected_call.replace(actual_call, "", 1) # set that as a delta message - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall( - arguments=remaining_call). - model_dump(exclude_none=True)) - ]) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=index, + function=DeltaFunctionCall( + arguments=remaining_call + ).model_dump(exclude_none=True), + ) + ] + ) # Send the finish response for each request.n only once + if ( + auto_tools_called + or tools_streamed[i] + or (self.use_harmony and harmony_tools_streamed[i]) + ): + finish_reason_ = "tool_calls" + else: + finish_reason_ = ( + output.finish_reason if output.finish_reason else "stop" + ) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason - if not auto_tools_called else "tool_calls", + finish_reason=finish_reason_, stop_reason=output.stop_reason, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None)) + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), + ) finish_reason_sent[i] = True @@ -1008,7 +1175,8 @@ class OpenAIServingChat(OpenAIServing): object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -1026,13 +1194,15 @@ class OpenAIServingChat(OpenAIServing): # is sent, send the usage if include_usage: completion_tokens = sum(previous_num_tokens) - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=num_cached_tokens) + cached_tokens=num_cached_tokens + ) final_usage_chunk = ChatCompletionStreamResponse( id=request_id, @@ -1040,9 +1210,11 @@ class OpenAIServingChat(OpenAIServing): created=created_time, choices=[], model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -1059,14 +1231,13 @@ class OpenAIServingChat(OpenAIServing): for i in range(num_choices): full_text = ( previous_texts[i] - if previous_texts and i < len(previous_texts) else - f"<streaming_complete: 
{previous_num_tokens[i]} tokens>" + if previous_texts and i < len(previous_texts) + else f"<streaming_complete: {previous_num_tokens[i]} tokens>" ) self.request_logger.log_outputs( request_id=request_id, outputs=full_text, - output_token_ids= - None, # Consider also logging all token IDs + output_token_ids=None, # Consider also logging all token IDs finish_reason="streaming_complete", is_streaming=True, delta=False, @@ -1090,7 +1261,6 @@ class OpenAIServingChat(OpenAIServing): tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: - created_time = int(time.time()) final_res: Optional[RequestOutput] = None @@ -1106,7 +1276,7 @@ class OpenAIServingChat(OpenAIServing): assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] - if self.tool_call_id_type == 'kimi_k2': + if self.tool_call_id_type == "kimi_k2": history_tool_call_cnt = get_history_tool_calls_cnt(conversation) else: history_tool_call_cnt = 0 @@ -1115,6 +1285,7 @@ class OpenAIServingChat(OpenAIServing): for output in final_res.outputs: token_ids = output.token_ids out_logprobs = output.logprobs + tool_call_info = None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, "Did not output logprobs" @@ -1129,32 +1300,41 @@ class OpenAIServingChat(OpenAIServing): logprobs = None if self.use_harmony: - reasoning_content, final_content, is_tool_call = ( - parse_chat_output(token_ids)) + reasoning_content, content, _ = parse_chat_output(token_ids) if not request.include_reasoning: reasoning_content = None - if is_tool_call: - # TODO(woosuk): Implement tool call for gpt-oss. - # For now, only Responses API supports tool call for - # gpt-oss. - raise NotImplementedError( - "Tool call in Chat Completion API is not supported " - "for gpt-oss yet. Please use Responses API instead.") - else: - # Normal message + if self.tool_parser is not None: + tool_parser = self.tool_parser(tokenizer) + # NOTE: We use token_ids for openai tool parser + tool_call_info = tool_parser.extract_tool_calls( + "", + request=request, + token_ids=token_ids, # type: ignore + ) + content = tool_call_info.content message = ChatMessage( role=role, reasoning_content=reasoning_content, - content=final_content, + content=content, + tool_calls=tool_call_info.tool_calls, + ) + else: + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, ) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason="tool_calls" if is_tool_call else - output.finish_reason if output.finish_reason else "stop", + finish_reason="tool_calls" + if (tool_call_info is not None and tool_call_info.tools_called) + else output.finish_reason + if output.finish_reason + else "stop", stop_reason=output.stop_reason, ) choices.append(choice_data) @@ -1168,9 +1348,9 @@ class OpenAIServingChat(OpenAIServing): return self.create_error_response(str(e)) # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. 
- reasoning_content, content = ( - reasoning_parser.extract_reasoning_content( - output.text, request=request)) + reasoning_content, content = reasoning_parser.extract_reasoning_content( + output.text, request=request + ) if not request.include_reasoning: reasoning_content = None else: @@ -1180,76 +1360,93 @@ class OpenAIServingChat(OpenAIServing): auto_tools_called = False # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if (not self.enable_auto_tools or not self.tool_parser) and \ - (not isinstance(request.tool_choice, - ChatCompletionNamedToolChoiceParam - ) and request.tool_choice != "required"): - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required" + ): + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) # if the request uses tools and specified a tool choice - elif request.tool_choice and type( - request.tool_choice) is ChatCompletionNamedToolChoiceParam: - - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + elif ( + request.tool_choice + and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam + ): + tool_call_class = ( + MistralToolCall + if isinstance(tokenizer, MistralTokenizer) + else ToolCall + ) message = ChatMessage( role=role, reasoning_content=reasoning_content, content="", tool_calls=[ - tool_call_class(function=FunctionCall( - name=request.tool_choice.function.name, - arguments=content, - )) + tool_call_class( + function=FunctionCall( + name=request.tool_choice.function.name, + arguments=content, + ) + ) ], ) elif request.tool_choice and request.tool_choice == "required": - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + tool_call_class = ( + MistralToolCall + if isinstance(tokenizer, MistralTokenizer) + else ToolCall + ) # the fields of FunctionDefinition are a superset of the # tool call outputs and can be used for parsing assert content is not None - tool_calls = TypeAdapter( - list[FunctionDefinition]).validate_json(content) + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json( + content + ) tool_call_ids = [] for tool_call in tool_calls: tool_call_ids.append( - make_tool_call_id(id_type=self.tool_call_id_type, - func_name=tool_call.name, - idx=history_tool_call_cnt)) + make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt, + ) + ) history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", tool_calls=[ - tool_call_class(id=tool_call_ids[i], - function=FunctionCall( - name=tool_call.name, - arguments=json.dumps( - tool_call.parameters, - ensure_ascii=False))) + tool_call_class( + id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, ensure_ascii=False + ), + ), + ) for i, tool_call in enumerate(tool_calls) ], - reasoning_content=reasoning_content) + reasoning_content=reasoning_content, + ) # if the request doesn't use tool choice # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) # handle when 
there are tools and tool choice is auto - elif request.tools and ( - request.tool_choice == "auto" - or request.tool_choice is None) and self.enable_auto_tools \ - and self.tool_parser: - + elif ( + request.tools + and (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools + and self.tool_parser + ): try: tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: @@ -1257,16 +1454,19 @@ class OpenAIServingChat(OpenAIServing): return self.create_error_response(str(e)) tool_call_info = tool_parser.extract_tool_calls( - content if content is not None else "", request=request) + content if content is not None else "", request=request + ) # In the OpenAI API the finish_reason is "tools_called" # if the tool choice is auto and the model produced a tool # call. The same is not true for named function calls auto_tools_called = tool_call_info.tools_called if tool_call_info.tools_called: - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=tool_call_info.content, - tool_calls=tool_call_info.tool_calls) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) else: # FOR NOW make it a chat message; we will have to detect @@ -1275,48 +1475,55 @@ class OpenAIServingChat(OpenAIServing): # try to use content return from tool parser first, # tool parser may do some modify for the content. - if (tool_call_info.content - and len(tool_call_info.content) > 0): + if tool_call_info.content and len(tool_call_info.content) > 0: ret_content = tool_call_info.content - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=ret_content) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=ret_content, + ) # undetermined case that is still important to handle else: logger.error( "Error in chat_completion_full_generator - cannot determine" " if tools should be extracted. Returning a standard chat " - "completion.") - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + "completion." 
+ ) + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason="tool_calls" if auto_tools_called else - output.finish_reason if output.finish_reason else "stop", + finish_reason="tool_calls" + if auto_tools_called + else output.finish_reason + if output.finish_reason + else "stop", stop_reason=output.stop_reason, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None), + token_ids=( + as_list(output.token_ids) if request.return_token_ids else None + ), ) choices.append(choice_data) if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if (conversation and "content" in conversation[-1] - and conversation[-1].get("role") == role): + if ( + conversation + and "content" in conversation[-1] + and conversation[-1].get("role") == role + ): last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): - last_msg_content = "\n".join(msg['text'] - for msg in last_msg_content) + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) for choice in choices: - full_message = last_msg_content + (choice.message.content - or "") + full_message = last_msg_content + (choice.message.content or "") choice.message.content = full_message assert final_res.prompt_token_ids is not None @@ -1324,14 +1531,17 @@ class OpenAIServingChat(OpenAIServing): if final_res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(final_res.encoder_prompt_token_ids) num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + - num_generated_tokens) + len(output.token_ids) for output in final_res.outputs + ) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) if self.enable_prompt_tokens_details and final_res.num_cached_tokens: usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=final_res.num_cached_tokens) + cached_tokens=final_res.num_cached_tokens + ) request_metadata.final_usage_info = usage @@ -1342,8 +1552,9 @@ class OpenAIServingChat(OpenAIServing): choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), - prompt_token_ids=(final_res.prompt_token_ids - if request.return_token_ids else None), + prompt_token_ids=( + final_res.prompt_token_ids if request.return_token_ids else None + ), kv_transfer_params=final_res.kv_transfer_params, ) @@ -1358,9 +1569,11 @@ class OpenAIServingChat(OpenAIServing): tool_call_descriptions = [] for tc in choice.message.tool_calls: if hasattr(tc.function, "name") and hasattr( - tc.function, "arguments"): + tc.function, "arguments" + ): tool_call_descriptions.append( - f"{tc.function.name}({tc.function.arguments})") + f"{tc.function.name}({tc.function.arguments})" + ) tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" @@ -1368,8 +1581,7 @@ class OpenAIServingChat(OpenAIServing): # Get the corresponding output token IDs output_token_ids = None if choice.index < len(final_res.outputs): - output_token_ids = final_res.outputs[ - choice.index].token_ids + output_token_ids = final_res.outputs[choice.index].token_ids self.request_logger.log_outputs( request_id=request_id, @@ -1383,21 +1595,27 @@ class 
OpenAIServingChat(OpenAIServing): return response def _get_top_logprobs( - self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: AnyTokenizer, - should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: + self, + logprobs: dict[int, Logprob], + top_logprobs: Optional[int], + tokenizer: AnyTokenizer, + should_return_as_token_id: bool, + ) -> list[ChatCompletionLogProb]: return [ ChatCompletionLogProb( - token=(token := self._get_decoded_token( - p[1], - p[0], - tokenizer, - return_as_token_id=should_return_as_token_id, - )), + token=( + token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + ) + ), logprob=max(p[1].logprob, -9999.0), bytes=list(token.encode("utf-8", errors="replace")), - ) for i, p in enumerate(logprobs.items()) - if top_logprobs and i < top_logprobs + ) + for i, p in enumerate(logprobs.items()) + if (top_logprobs and i < top_logprobs or top_logprobs == -1) ] def _create_chat_logprobs( @@ -1411,21 +1629,25 @@ class OpenAIServingChat(OpenAIServing): """Create OpenAI-style logprobs.""" logprobs_content: list[ChatCompletionLogProbsContent] = [] - should_return_as_token_id = return_as_token_id if \ - return_as_token_id is not None else self.return_tokens_as_token_ids + should_return_as_token_id = ( + return_as_token_id + if return_as_token_id is not None + else self.return_tokens_as_token_ids + ) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None or step_top_logprobs.get( - token_id) is None: - token = tokenizer.decode(token_id) + if step_top_logprobs is None or step_top_logprobs.get(token_id) is None: if should_return_as_token_id: token = f"token_id:{token_id}" + else: + token = tokenizer.decode(token_id) logprobs_content.append( ChatCompletionLogProbsContent( token=token, bytes=list(token.encode("utf-8", errors="replace")), - )) + ) + ) else: step_token = step_top_logprobs[token_id] step_decoded = step_token.decoded_token @@ -1439,17 +1661,21 @@ class OpenAIServingChat(OpenAIServing): should_return_as_token_id, ), logprob=max(step_token.logprob, -9999.0), - bytes=None if step_decoded is None else list( - step_decoded.encode("utf-8", errors="replace")), + bytes=None + if step_decoded is None + else list(step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs, - tokenizer, should_return_as_token_id), - )) + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + should_return_as_token_id, + ), + ) + ) return ChatCompletionLogProbs(content=logprobs_content) - def _should_stream_with_auto_tool_parsing(self, - request: ChatCompletionRequest): + def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest): """ Utility function to check if streamed tokens should go through the tool call parser that was configured. @@ -1458,8 +1684,12 @@ class OpenAIServingChat(OpenAIServing): is configured, "auto" tool choice is enabled, and the request's tool choice field indicates that "auto" tool choice should be used. """ - return (request.tools and self.tool_parser and self.enable_auto_tools - and request.tool_choice in ['auto', None]) + return ( + request.tools + and self.tool_parser + and self.enable_auto_tools + and request.tool_choice in ["auto", None] + ) def _should_check_for_unstreamed_tool_arg_tokens( self, @@ -1472,13 +1702,15 @@ class OpenAIServingChat(OpenAIServing): is a tool call with arguments. 
""" - # yapf: disable return bool( # if there is a delta message that includes tool calls which # include a function that has arguments output.finish_reason is not None - and self.enable_auto_tools and self.tool_parser and delta_message - and delta_message.tool_calls and delta_message.tool_calls[0] + and self.enable_auto_tools + and self.tool_parser + and delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) @@ -1497,16 +1729,18 @@ class OpenAIServingChat(OpenAIServing): sys_msg = get_system_message( reasoning_effort=request.reasoning_effort, browser_description=None, - python_description=None) + python_description=None, + with_custom_tools=request.tools is not None, + ) messages.append(sys_msg) # Add developer message. - dev_msg = get_developer_message() + dev_msg = get_developer_message(tools=request.tools) messages.append(dev_msg) # Add user message. for chat_msg in request.messages: - messages.append(parse_chat_input(chat_msg)) + messages.extend(parse_chat_input(chat_msg)) # Render prompt token ids. prompt_token_ids = render_for_completion(messages) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f6847179..0e9a5846276bc 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -8,18 +8,22 @@ import numpy as np from fastapi import Request from typing_extensions import override -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ClassificationData, - ClassificationRequest, - ClassificationResponse, - ErrorResponse, UsageInfo) -# yapf: enable -from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, - OpenAIServing, - ServeContext) +from vllm.entrypoints.openai.protocol import ( + ClassificationData, + ClassificationRequest, + ClassificationResponse, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import ( + ClassificationServeContext, + OpenAIServing, + ServeContext, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.pooling_params import PoolingParams @@ -28,7 +32,6 @@ logger = init_logger(__name__) class ClassificationMixin(OpenAIServing): - @override async def _preprocess( self, @@ -49,19 +52,12 @@ class ClassificationMixin(OpenAIServing): return None try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) + ctx.tokenizer = await self.engine_client.get_tokenizer() - ctx.tokenizer = await self.engine_client.get_tokenizer( - ctx.lora_request) - - ( - ctx.request_prompts, - ctx.engine_prompts, - ) = await self._preprocess_completion( - ctx.request, - ctx.tokenizer, - ctx.request.input, - truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, + renderer = self._get_renderer(ctx.tokenizer) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + config=self._build_render_config(ctx.request), ) return None @@ -83,16 +79,16 @@ class ClassificationMixin(OpenAIServing): items: list[ClassificationData] = [] num_prompt_tokens = 0 - final_res_batch_checked = cast(list[PoolingRequestOutput], - ctx.final_res_batch) + 
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) for idx, final_res in enumerate(final_res_batch_checked): classify_res = ClassificationOutput.from_base(final_res.outputs) probs = classify_res.probs predicted_index = int(np.argmax(probs)) - label = getattr(self.model_config.hf_config, "id2label", - {}).get(predicted_index) + label = getattr(self.model_config.hf_config, "id2label", {}).get( + predicted_index + ) item = ClassificationData( index=idx, @@ -118,6 +114,12 @@ class ClassificationMixin(OpenAIServing): usage=usage, ) + def _build_render_config(self, request: ClassificationRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + ) + class ServingClassification(ClassificationMixin): request_id_prefix = "classify" @@ -125,16 +127,16 @@ class ServingClassification(ClassificationMixin): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, + log_error_stack=log_error_stack, ) async def create_classify( @@ -142,9 +144,8 @@ class ServingClassification(ClassificationMixin): request: ClassificationRequest, raw_request: Request, ) -> Union[ClassificationResponse, ErrorResponse]: - model_name = self._get_model_name(request.model) - request_id = (f"{self.request_id_prefix}-" - f"{self._base_request_id(raw_request)}") + model_name = self.models.model_name() + request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" ctx = ClassificationServeContext( request=request, @@ -155,18 +156,6 @@ class ServingClassification(ClassificationMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ClassificationServeContext, - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a0ce654094039..d18301103e475 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -9,38 +9,30 @@ from typing import Optional, Union, cast import jinja2 from fastapi import Request -from typing_extensions import assert_never -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (CompletionLogProbs, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - ErrorResponse, - PromptTokenUsageInfo, - RequestResponseMetadata, - UsageInfo) -from vllm.entrypoints.openai.serving_engine import ( - EmbedsPrompt as ServingEngineEmbedsPrompt) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - TextTokensPrompt, - clamp_prompt_logprobs, - is_text_tokens_prompt) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + 
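The classification response above takes the arg-max over the pooled probabilities and maps it through the model config's `id2label` when that attribute exists. A small self-contained version of the same step (the label mapping below is made up for illustration):

```python
from typing import Optional

import numpy as np


def classify(probs: list[float],
             id2label: Optional[dict[int, str]] = None) -> dict:
    """Pick the highest-probability class and attach a label when a
    mapping is available, mirroring the id2label lookup above."""
    predicted_index = int(np.argmax(probs))
    label = (id2label or {}).get(predicted_index)
    return {"index": predicted_index, "label": label, "probs": probs}


print(classify([0.1, 0.7, 0.2], {0: "negative", 1: "neutral", 2: "positive"}))
```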
PromptTokenUsageInfo, + RequestResponseMetadata, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens -from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, - is_tokens_prompt) +from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import as_list, merge_async_iterators @@ -48,29 +40,27 @@ logger = init_logger(__name__) class OpenAIServingCompletion(OpenAIServing): - def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source @@ -106,16 +96,17 @@ class OpenAIServingCompletion(OpenAIServing): # Return error for unsupported features. if request.suffix is not None: - return self.create_error_response( - "suffix is not currently supported") + return self.create_error_response("suffix is not currently supported") if request.echo and request.prompt_embeds is not None: - return self.create_error_response( - "Echo is unsupported with prompt embeds.") + return self.create_error_response("Echo is unsupported with prompt embeds.") - request_id = ( - f"cmpl-" - f"{self._base_request_id(raw_request, request.request_id)}") + if request.prompt_logprobs is not None and request.prompt_embeds is not None: + return self.create_error_response( + "prompt_logprobs is not compatible with prompt embeds." 
+ ) + + request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}" created_time = int(time.time()) request_metadata = RequestResponseMetadata(request_id=request_id) @@ -125,14 +116,16 @@ class OpenAIServingCompletion(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) - request_prompts, engine_prompts = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, + engine_prompts = await renderer.render_prompt_and_embeds( + prompt_or_prompts=request.prompt, + prompt_embeds=request.prompt_embeds, + config=self._build_render_config(request), ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") @@ -151,23 +144,17 @@ class OpenAIServingCompletion(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - sampling_params: Union[SamplingParams, BeamSearchParams] - # Mypy does not infer that engine_prompt will have only one of - # "prompt_token_ids" or "prompt_embeds" defined, and both of - # these as Union[object, the expected type], where it infers - # object if engine_prompt is a subclass of one of the - # typeddicts that defines both keys. Worse, because of - # https://github.com/python/mypy/issues/8586, mypy does not - # infer the type of engine_prompt correctly because of the - # enumerate. So we need an unnecessary cast here. - engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], - engine_prompt) - if is_embeds_prompt(engine_prompt): - input_length = len(engine_prompt["prompt_embeds"]) - elif is_tokens_prompt(engine_prompt): - input_length = len(engine_prompt["prompt_token_ids"]) + prompt_text, prompt_token_ids, prompt_embeds = ( + self._get_prompt_components(engine_prompt) + ) + + input_length = None + if prompt_token_ids is not None: + input_length = len(prompt_token_ids) + elif prompt_embeds is not None: + input_length = len(prompt_embeds) else: - assert_never(engine_prompt) + raise NotImplementedError if self.default_sampling_params is None: self.default_sampling_params = {} @@ -179,9 +166,11 @@ class OpenAIServingCompletion(OpenAIServing): default_sampling_params=self.default_sampling_params, ) + sampling_params: Union[SamplingParams, BeamSearchParams] if request.use_beam_search: sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params + ) else: sampling_params = request.to_sampling_params( max_tokens, @@ -193,34 +182,47 @@ class OpenAIServingCompletion(OpenAIServing): self._log_inputs( request_id_item, - request_prompts[i], + engine_prompt, params=sampling_params, lora_request=lora_request, ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) # Mypy inconsistently requires this second cast in different # environments. It shouldn't be necessary (redundant from above) # but pre-commit in CI fails without it. 
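Before building sampling parameters, the loop above measures the input length (token ids or embeddings) and derives an effective `max_tokens` from it via `get_max_tokens`. The exact policy of that helper is not shown in the diff, so the sketch below only illustrates the basic idea of capping the completion budget by the remaining context window:

```python
from typing import Optional


def effective_max_tokens(max_model_len: int,
                         input_length: int,
                         requested_max_tokens: Optional[int]) -> int:
    """Cap the completion budget by the tokens left in the context window.

    This only illustrates the intent of the get_max_tokens(...) call in the
    diff; the real helper may also consult model generation-config defaults.
    """
    remaining = max(max_model_len - input_length, 0)
    if requested_max_tokens is None:
        return remaining
    return min(requested_max_tokens, remaining)


assert effective_max_tokens(4096, 1000, None) == 3096
assert effective_max_tokens(4096, 1000, 256) == 256
```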
- engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], - engine_prompt) + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt) if isinstance(sampling_params, BeamSearchParams): - generator = self.engine_client.beam_search( + generator = self.beam_search( prompt=engine_prompt, request_id=request_id, params=sampling_params, lora_request=lora_request, ) else: - generator = self.engine_client.generate( + engine_request, tokenization_kwargs = await self._process_inputs( + request_id_item, engine_prompt, sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generator = self.engine_client.generate( + engine_request, + sampling_params, request_id_item, lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, + prompt_text=prompt_text, + tokenization_kwargs=tokenization_kwargs, ) generators.append(generator) @@ -230,21 +232,23 @@ class OpenAIServingCompletion(OpenAIServing): result_generator = merge_async_iterators(*generators) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. Noting that best_of is only supported in V0. In addition, # we do not stream the results when use beam search. - stream = (request.stream - and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search) + stream = ( + request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search + ) # Streaming response if stream: return self.completion_stream_generator( request, - request_prompts, + engine_prompts, result_generator, request_id, created_time, @@ -268,14 +272,14 @@ class OpenAIServingCompletion(OpenAIServing): # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - request_prompt = request_prompts[i] - if is_text_tokens_prompt(request_prompt): - final_res.prompt = request_prompt["prompt"] - else: - final_res.prompt = None + engine_prompt = engine_prompts[i] + final_res.prompt = ( + None + if is_embeds_prompt(engine_prompt) + else engine_prompt.get("prompt") + ) - final_res_batch_checked = cast(list[RequestOutput], - final_res_batch) + final_res_batch_checked = cast(list[RequestOutput], final_res_batch) response = self.request_output_to_completion_response( final_res_batch_checked, @@ -308,8 +312,7 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, - request_prompts: list[Union[TextTokensPrompt, - ServingEngineEmbedsPrompt]], + engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -329,10 +332,10 @@ class OpenAIServingCompletion(OpenAIServing): stream_options = request.stream_options if stream_options: - include_usage = (stream_options.include_usage - or enable_force_include_usage) - include_continuous_usage = (include_usage and - stream_options.continuous_usage_stats) + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = ( + include_usage and stream_options.continuous_usage_stats + ) else: include_usage, include_continuous_usage = False, False @@ -345,22 +348,21 @@ class OpenAIServingCompletion(OpenAIServing): num_cached_tokens = res.num_cached_tokens first_iteration = False 
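One generator is created per prompt and the results are interleaved with `merge_async_iterators`, which tags every output with the index of the prompt it belongs to. A minimal asyncio sketch of that pattern (this is not vLLM's implementation, just the general shape):

```python
import asyncio
from collections.abc import AsyncGenerator, AsyncIterator


async def merge_tagged(*gens: AsyncGenerator) -> AsyncIterator[tuple[int, object]]:
    """Yield (index, item) pairs from several async generators as soon as
    each item is ready, similar in spirit to merge_async_iterators."""
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    async def pump(i: int, gen: AsyncGenerator) -> None:
        async for item in gen:
            await queue.put((i, item))
        await queue.put((i, sentinel))

    tasks = [asyncio.create_task(pump(i, g)) for i, g in enumerate(gens)]
    remaining = len(tasks)
    while remaining:
        i, item = await queue.get()
        if item is sentinel:
            remaining -= 1
        else:
            yield i, item
    await asyncio.gather(*tasks)


async def demo() -> None:
    async def numbers(name: str, n: int) -> AsyncIterator[str]:
        for k in range(n):
            await asyncio.sleep(0)
            yield f"{name}-{k}"

    async for idx, item in merge_tagged(numbers("a", 2), numbers("b", 1)):
        print(idx, item)


asyncio.run(demo())
```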
- if res.prompt is not None: - prompt_text = res.prompt - else: - request_prompt = request_prompts[prompt_idx] - if is_text_tokens_prompt(request_prompt): - prompt_text = request_prompt["prompt"] - else: - prompt_text = None + prompt_text = res.prompt + if prompt_text is None: + engine_prompt = engine_prompts[prompt_idx] + prompt_text = ( + None + if is_embeds_prompt(engine_prompt) + else engine_prompt.get("prompt") + ) # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[ - int, Logprob]]]] + out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] for output in res.outputs: i = output.index + prompt_idx * num_choices @@ -373,6 +375,8 @@ class OpenAIServingCompletion(OpenAIServing): assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: # only return the prompt @@ -404,22 +408,23 @@ class OpenAIServingCompletion(OpenAIServing): prompt_token_ids_to_return = prompt_token_ids has_echoed[i] = True - if (not delta_text and not delta_token_ids - and not previous_num_tokens[i]): + if ( + not delta_text + and not delta_token_ids + and not previous_num_tokens[i] + ): # Chunked prefill case, don't return empty chunks continue if request.logprobs is not None: - assert out_logprobs is not None, ( - "Did not output logprobs") + assert out_logprobs is not None, "Did not output logprobs" logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, initial_text_offset=previous_text_lens[i], - return_as_token_id=request. 
- return_tokens_as_token_ids, + return_as_token_id=request.return_tokens_as_token_ids, ) else: logprobs = None @@ -441,8 +446,11 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason=finish_reason, stop_reason=stop_reason, prompt_token_ids=prompt_token_ids_to_return, - token_ids=(as_list(output.token_ids) if - request.return_token_ids else None), + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), ) ], ) @@ -468,7 +476,8 @@ class OpenAIServingCompletion(OpenAIServing): if self.enable_prompt_tokens_details and num_cached_tokens: final_usage_info.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=num_cached_tokens) + cached_tokens=num_cached_tokens + ) if include_usage: final_usage_chunk = CompletionStreamResponse( @@ -479,7 +488,8 @@ class OpenAIServingCompletion(OpenAIServing): usage=final_usage_info, ) final_usage_data = final_usage_chunk.model_dump_json( - exclude_unset=False, exclude_none=True) + exclude_unset=False, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -514,12 +524,13 @@ class OpenAIServingCompletion(OpenAIServing): prompt_text = final_res.prompt token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[int, - Logprob]]]] + out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] for output in final_res.outputs: assert request.max_tokens is not None if request.echo: + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: token_ids = prompt_token_ids @@ -563,10 +574,12 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason=output.finish_reason, stop_reason=output.stop_reason, prompt_logprobs=final_res.prompt_logprobs, - prompt_token_ids=(prompt_token_ids - if request.return_token_ids else None), - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None), + prompt_token_ids=( + prompt_token_ids if request.return_token_ids else None + ), + token_ids=( + as_list(output.token_ids) if request.return_token_ids else None + ), ) choices.append(choice_data) @@ -580,10 +593,14 @@ class OpenAIServingCompletion(OpenAIServing): total_tokens=num_prompt_tokens + num_generated_tokens, ) - if (self.enable_prompt_tokens_details and last_final_res - and last_final_res.num_cached_tokens): + if ( + self.enable_prompt_tokens_details + and last_final_res + and last_final_res.num_cached_tokens + ): usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=last_final_res.num_cached_tokens) + cached_tokens=last_final_res.num_cached_tokens + ) request_metadata.final_usage_info = usage if final_res_batch: @@ -614,9 +631,11 @@ class OpenAIServingCompletion(OpenAIServing): last_token_len = 0 - should_return_as_token_id = (return_as_token_id - if return_as_token_id is not None else - self.return_tokens_as_token_ids) + should_return_as_token_id = ( + return_as_token_id + if return_as_token_id is not None + else self.return_tokens_as_token_ids + ) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: @@ -645,19 +664,20 @@ class OpenAIServingCompletion(OpenAIServing): # logprobs, as defined in the openai API # (cf. 
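The streaming path emits OpenAI-style SSE chunks and, depending on `stream_options`, either attaches usage to every chunk (`continuous_usage_stats`) or appends a single usage-only chunk at the end before `[DONE]`. A toy generator showing that envelope (token counting is simplified to one token per delta):

```python
import json
from collections.abc import Iterator


def sse_stream(deltas: list[str],
               include_usage: bool,
               continuous_usage_stats: bool) -> Iterator[str]:
    """Emit OpenAI-style SSE lines: usage rides along with every chunk when
    continuous stats are requested, otherwise it arrives once in a final
    usage-only chunk, and the stream always ends with [DONE]."""
    completion_tokens = 0
    for delta in deltas:
        completion_tokens += 1
        chunk = {"choices": [{"index": 0, "text": delta}]}
        if include_usage and continuous_usage_stats:
            chunk["usage"] = {"completion_tokens": completion_tokens}
        yield f"data: {json.dumps(chunk)}\n\n"
    if include_usage:
        final = {"choices": [], "usage": {"completion_tokens": completion_tokens}}
        yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"


for line in sse_stream(["Hel", "lo"], include_usage=True,
                       continuous_usage_stats=False):
    print(line, end="")
```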
https://github.com/openai/openai-openapi/blob/ # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) - out_top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - self._get_decoded_token( - top_lp[1], - top_lp[0], - tokenizer, - return_as_token_id=should_return_as_token_id, - ): - max(top_lp[1].logprob, -9999.0) - for i, top_lp in enumerate(step_top_logprobs.items()) - if num_output_top_logprobs >= i - }) + out_top_logprobs.append( + { + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token( + top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + ): max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + } + ) if len(out_text_offset) == 0: out_text_offset.append(initial_text_offset) @@ -671,3 +691,17 @@ class OpenAIServingCompletion(OpenAIServing): tokens=out_tokens, top_logprobs=out_top_logprobs, ) + + def _build_render_config( + self, + request: CompletionRequest, + max_input_length: Optional[int] = None, + ) -> RenderConfig: + max_input_tokens_len = self.max_model_len - (request.max_tokens or 0) + return RenderConfig( + max_length=max_input_tokens_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + cache_salt=request.cache_salt, + needs_detokenization=bool(request.echo and not request.return_token_ids), + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9dcad8e391c68..e0c9d9aa812f6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -10,30 +10,35 @@ import torch from fastapi import Request from typing_extensions import assert_never, override -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this docstring -# yapf: disable -from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, - EmbeddingCompletionRequest, - EmbeddingRequest, - EmbeddingResponse, - EmbeddingResponseData, - ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, - OpenAIServing, - RequestPrompt, - ServeContext, - TextTokensPrompt) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import ( + EmbeddingServeContext, + OpenAIServing, + ServeContext, + TextTokensPrompt, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger -from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingOutput, PoolingRequestOutput, RequestOutput) +from vllm.outputs import ( + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.pooling_params import PoolingParams from vllm.utils import chunk_list @@ -56,7 +61,6 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): - def __init__(self, *args, 
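The new `_build_render_config` reserves room for the completion by budgeting the prompt at `max_model_len - max_tokens`, and only requests detokenization when `echo` actually needs the prompt text back. A stand-in illustration of that logic (`RenderSettings` below is hypothetical, not the real `RenderConfig`):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RenderSettings:
    """Hypothetical stand-in for the RenderConfig used in the diff."""
    max_length: Optional[int]
    truncate_prompt_tokens: Optional[int]
    add_special_tokens: bool
    needs_detokenization: bool


def completion_render_settings(max_model_len: int,
                               max_tokens: Optional[int],
                               truncate_prompt_tokens: Optional[int],
                               add_special_tokens: bool,
                               echo: bool,
                               return_token_ids: bool) -> RenderSettings:
    return RenderSettings(
        # Leave room for the completion: the prompt may use at most
        # max_model_len - max_tokens tokens.
        max_length=max_model_len - (max_tokens or 0),
        truncate_prompt_tokens=truncate_prompt_tokens,
        add_special_tokens=add_special_tokens,
        # Echo needs the prompt text back, unless token ids are returned
        # instead, in which case detokenization can be skipped.
        needs_detokenization=bool(echo and not return_token_ids),
    )


print(completion_render_settings(8192, 512, None, True,
                                 echo=True, return_token_ids=False))
```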
**kwargs): super().__init__(*args, **kwargs) @@ -64,9 +68,13 @@ class EmbeddingMixin(OpenAIServing): # Avoid repeated attribute lookups self.supports_chunked_processing = bool( - pooler_config and pooler_config.enable_chunked_processing) - self.max_embed_len = (pooler_config.max_embed_len if pooler_config - and pooler_config.max_embed_len else None) + pooler_config and pooler_config.enable_chunked_processing + ) + self.max_embed_len = ( + pooler_config.max_embed_len + if pooler_config and pooler_config.max_embed_len + else None + ) @override async def _preprocess( @@ -77,43 +85,47 @@ class EmbeddingMixin(OpenAIServing): try: ctx.lora_request = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request - ) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) if isinstance(ctx.request, EmbeddingChatRequest): ( _, - ctx.request_prompts, + _, ctx.engine_prompts, ) = await self._preprocess_chat( ctx.request, tokenizer, ctx.request.messages, - chat_template=ctx.request.chat_template - or ctx.chat_template, - chat_template_content_format=ctx. - chat_template_content_format, - # In embedding requests, we are not generating tokens, - # so there is no need to append extra tokens to the input - add_generation_prompt=False, + chat_template=ctx.request.chat_template or ctx.chat_template, + chat_template_content_format=ctx.chat_template_content_format, + add_generation_prompt=ctx.request.add_generation_prompt, continue_final_message=False, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, add_special_tokens=ctx.request.add_special_tokens, ) else: - (ctx.request_prompts, - ctx.engine_prompts) = await self._preprocess_completion( - ctx.request, - tokenizer, - ctx.request.input, - truncate_prompt_tokens=ctx.truncate_prompt_tokens, - add_special_tokens=ctx.request.add_special_tokens, - ) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + config=self._build_render_config(ctx.request), + ) return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig: + # Set max_length based on chunked processing capability + if self._should_use_chunked_processing(request): + max_length = None + else: + max_length = self.max_embed_len or self.max_model_len + + return RenderConfig( + max_length=max_length, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + @override def _build_response( self, @@ -122,16 +134,16 @@ class EmbeddingMixin(OpenAIServing): items: list[EmbeddingResponseData] = [] num_prompt_tokens = 0 - final_res_batch_checked = cast(list[PoolingRequestOutput], - ctx.final_res_batch) + final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) for idx, final_res in enumerate(final_res_batch_checked): embedding_res = EmbeddingRequestOutput.from_base(final_res) item = EmbeddingResponseData( index=idx, - embedding=_get_embedding(embedding_res.outputs, - ctx.request.encoding_format), + embedding=_get_embedding( + embedding_res.outputs, ctx.request.encoding_format + ), ) prompt_token_ids = final_res.prompt_token_ids @@ -157,10 +169,10 @@ class EmbeddingMixin(OpenAIServing): def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" - return isinstance( - request, - 
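For embeddings, the render config drops the length cap entirely when chunked processing will split the input later, and otherwise prefers the pooler's `max_embed_len` over `max_model_len`. That decision in isolation:

```python
from typing import Optional


def embedding_max_length(use_chunked_processing: bool,
                         max_embed_len: Optional[int],
                         max_model_len: int) -> Optional[int]:
    """None means "no cap at render time" because chunked processing will
    split the input later; otherwise prefer the pooler's max_embed_len."""
    if use_chunked_processing:
        return None
    return max_embed_len or max_model_len


assert embedding_max_length(True, None, 4096) is None
assert embedding_max_length(False, 8192, 4096) == 8192
assert embedding_max_length(False, None, 4096) == 4096
```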
(EmbeddingCompletionRequest, - EmbeddingChatRequest)) and self.supports_chunked_processing + return ( + isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)) + and self.supports_chunked_processing + ) async def _process_chunked_request( self, @@ -178,25 +190,27 @@ class EmbeddingMixin(OpenAIServing): max_pos_embeddings = self._get_max_position_embeddings() # Process all chunks for MEAN aggregation for chunk_idx, chunk_tokens in enumerate( - chunk_list(token_ids, max_pos_embeddings)): + chunk_list(token_ids, max_pos_embeddings) + ): # Create a request ID for this chunk - chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" - f"chunk-{chunk_idx}") + chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" # Create engine prompt for this chunk - chunk_engine_prompt = EngineTokensPrompt( - prompt_token_ids=chunk_tokens) + chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens) # Create chunk request prompt for logging chunk_text = "" chunk_request_prompt = TextTokensPrompt( - prompt=chunk_text, prompt_token_ids=chunk_tokens) + prompt=chunk_text, prompt_token_ids=chunk_tokens + ) # Log the chunk - self._log_inputs(chunk_request_id, - chunk_request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) + self._log_inputs( + chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) # Create generator for this chunk and wrap it to return indices original_generator = self.engine_client.encode( @@ -222,8 +236,7 @@ class EmbeddingMixin(OpenAIServing): token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingCompletionRequest, EmbeddingChatRequest)): + if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)): # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) @@ -243,13 +256,15 @@ class EmbeddingMixin(OpenAIServing): validation_error_msg = ( "This model's {length_type} is {max_length_value} tokens. " "However, you requested {token_num} tokens in the input for " - "embedding generation. Please reduce the length of the input.") + "embedding generation. Please reduce the length of the input." + ) chunked_processing_error_msg = ( "This model's {length_type} is {max_length_value} tokens. " "However, you requested {token_num} tokens in the input for " "embedding generation. Please reduce the length of the input " - "or enable chunked processing.") + "or enable chunked processing." 
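`_process_chunked_request` splits a long token sequence into `max_position_embeddings`-sized chunks and derives one request id per chunk. The sketch below mimics that splitting and the `-prompt-<i>-chunk-<j>` id scheme; the `chunk_list` helper here is a local stand-in for the one imported from `vllm.utils`:

```python
from collections.abc import Iterator


def chunk_list(lst: list[int], size: int) -> Iterator[list[int]]:
    """Yield consecutive slices of at most `size` elements."""
    for start in range(0, len(lst), size):
        yield lst[start:start + size]


def chunk_request_ids(request_id: str, prompt_idx: int,
                      token_ids: list[int],
                      max_pos_embeddings: int) -> list[str]:
    """Reproduce the '<request>-prompt-<i>-chunk-<j>' naming used above."""
    return [
        f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
        for chunk_idx, _ in enumerate(chunk_list(token_ids, max_pos_embeddings))
    ]


print(chunk_request_ids("embd-42", 0, list(range(10)), max_pos_embeddings=4))
# ['embd-42-prompt-0-chunk-0', 'embd-42-prompt-0-chunk-1', 'embd-42-prompt-0-chunk-2']
```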
+ ) # Check if input exceeds max length if token_num > max_length_value: @@ -257,7 +272,9 @@ class EmbeddingMixin(OpenAIServing): validation_error_msg.format( length_type=length_type, max_length_value=max_length_value, - token_num=token_num)) + token_num=token_num, + ) + ) # Check for chunked processing # when exceeding max_position_embeddings @@ -266,31 +283,36 @@ class EmbeddingMixin(OpenAIServing): # Allow long inputs when chunked processing is enabled logger.info( "Input length %s exceeds max_position_embeddings " - "%s, will use chunked processing", token_num, - max_pos_embeddings) + "%s, will use chunked processing", + token_num, + max_pos_embeddings, + ) else: raise ValueError( chunked_processing_error_msg.format( length_type="maximum position embeddings length", max_length_value=max_pos_embeddings, - token_num=token_num)) + token_num=token_num, + ) + ) - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # For other request types, use the parent's implementation return super()._validate_input(request, input_ids, input_text) def _is_text_tokens_prompt(self, prompt) -> bool: """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], - request_prompt: RequestPrompt, + engine_prompt: EngineTokensPrompt, pooling_params: PoolingParams, trace_headers: Optional[Mapping[str, str]], prompt_index: int, @@ -298,16 +320,12 @@ class EmbeddingMixin(OpenAIServing): """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" - self._log_inputs(request_id_item, - request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) - - # Mypy has an existing bug related to inferring the variance - # of TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) # Return the original generator without wrapping return self.engine_client.encode( @@ -335,13 +353,16 @@ class EmbeddingMixin(OpenAIServing): return await super()._prepare_generators(ctx) # Custom logic for chunked processing - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], - None]] = [] + generators: list[ + AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + ] = [] try: - trace_headers = (None if ctx.raw_request is None else await - self._get_trace_headers(ctx.raw_request.headers)) + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) pooling_params = self._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): @@ -354,43 +375,32 @@ class EmbeddingMixin(OpenAIServing): return self.create_error_response(str(e)) if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") - - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") + return 
self.create_error_response("Engine prompts not available") max_pos_embeddings = self._get_max_position_embeddings() for i, engine_prompt in enumerate(ctx.engine_prompts): - request_prompt = ctx.request_prompts[i] - # Check if this specific prompt needs chunked processing - if self._is_text_tokens_prompt(request_prompt): + if self._is_text_tokens_prompt(engine_prompt): # Cast to TextTokensPrompt since we've verified # prompt_token_ids - text_tokens_prompt = cast(TextTokensPrompt, request_prompt) - if (len(text_tokens_prompt["prompt_token_ids"]) - > max_pos_embeddings): + text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) + if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings: # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( - ctx, text_tokens_prompt, pooling_params, - trace_headers, i) + ctx, text_tokens_prompt, pooling_params, trace_headers, i + ) generators.extend(chunk_generators) continue # Normal processing for short prompts or non-token prompts - # Cast engine_prompt to the expected type for mypy - engine_prompt_typed = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) generator = await self._create_single_prompt_generator( - ctx, engine_prompt_typed, request_prompt, pooling_params, - trace_headers, i) + ctx, engine_prompt, pooling_params, trace_headers, i + ) generators.append(generator) from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) return None @@ -406,16 +416,15 @@ class EmbeddingMixin(OpenAIServing): ) -> Optional[ErrorResponse]: """Collect and aggregate batch results with support for chunked processing. - - For chunked requests, performs online aggregation to + + For chunked requests, performs online aggregation to minimize memory usage. For regular requests, collects results normally. 
""" ctx = cast(EmbeddingServeContext, ctx) try: if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") # Check if we used chunked processing use_chunked = self._should_use_chunked_processing(ctx.request) @@ -423,13 +432,8 @@ class EmbeddingMixin(OpenAIServing): if not use_chunked: return await super()._collect_batch(ctx=ctx) - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - if ctx.result_generator is None: - return self.create_error_response( - "Result generator not available") + return self.create_error_response("Result generator not available") # Online aggregation for chunked requests to # minimize memory usage @@ -450,10 +454,10 @@ class EmbeddingMixin(OpenAIServing): # Initialize aggregator for this prompt if needed if prompt_idx not in prompt_aggregators: prompt_aggregators[prompt_idx] = { - 'weighted_sum': None, - 'total_weight': 0, - 'chunk_count': 0, - 'request_id': result.request_id.split("-chunk-")[0] + "weighted_sum": None, + "total_weight": 0, + "chunk_count": 0, + "request_id": result.request_id.split("-chunk-")[0], } aggregator = prompt_aggregators[prompt_idx] @@ -465,44 +469,45 @@ class EmbeddingMixin(OpenAIServing): return self.create_error_response( f"Expected PoolingRequestOutput for " f"chunked embedding, got " - f"{type(result).__name__}") + f"{type(result).__name__}" + ) # Handle both PoolingOutput and # EmbeddingOutput types - if hasattr(result.outputs, 'data'): + if hasattr(result.outputs, "data"): # PoolingOutput case embedding_data = result.outputs.data - elif hasattr(result.outputs, 'embedding'): + elif hasattr(result.outputs, "embedding"): # EmbeddingOutput case - # convert embedding list to tensor embedding_data = result.outputs.embedding else: return self.create_error_response( - f"Unsupported output type: " - f"{type(result.outputs).__name__}") + f"Unsupported output type: {type(result.outputs).__name__}" + ) if not isinstance(embedding_data, torch.Tensor): - embedding_data = torch.tensor(embedding_data, - dtype=torch.float32) + embedding_data = torch.tensor( + embedding_data, dtype=torch.float32 + ) if result.prompt_token_ids is None: return self.create_error_response( - "prompt_token_ids cannot be None for " - "chunked processing") + "prompt_token_ids cannot be None for chunked processing" + ) weight = len(result.prompt_token_ids) - weighted_embedding = embedding_data.to( - dtype=torch.float32) * weight + weighted_embedding = embedding_data.to(dtype=torch.float32) * weight - if aggregator['weighted_sum'] is None: + if aggregator["weighted_sum"] is None: # First chunk - aggregator['weighted_sum'] = weighted_embedding + aggregator["weighted_sum"] = weighted_embedding else: # Accumulate - aggregator['weighted_sum'] += weighted_embedding + aggregator["weighted_sum"] += weighted_embedding - aggregator['total_weight'] += weight - aggregator['chunk_count'] += 1 + aggregator["total_weight"] += weight + aggregator["chunk_count"] += 1 else: # Non-chunked result - extract prompt_idx from request_id parts = result.request_id.split("-") @@ -513,11 +518,13 @@ class EmbeddingMixin(OpenAIServing): prompt_idx = result_idx # Fallback to result_idx short_prompts_results[prompt_idx] = cast( - PoolingRequestOutput, result) + PoolingRequestOutput, result + ) # Finalize aggregated results - final_res_batch: list[Union[PoolingRequestOutput, - EmbeddingRequestOutput]] = [] + final_res_batch: list[ + 
Union[PoolingRequestOutput, EmbeddingRequestOutput] + ] = [] num_prompts = len(ctx.engine_prompts) for prompt_idx in range(num_prompts): @@ -525,55 +532,57 @@ class EmbeddingMixin(OpenAIServing): # Finalize MEAN aggregation for this chunked prompt aggregator = prompt_aggregators[prompt_idx] - weighted_sum = aggregator['weighted_sum'] - total_weight = aggregator['total_weight'] - - if (weighted_sum is not None - and isinstance(weighted_sum, torch.Tensor) - and isinstance(total_weight, - (int, float)) and total_weight > 0): + weighted_sum = aggregator["weighted_sum"] + total_weight = aggregator["total_weight"] + if ( + weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0 + ): # Compute final mean embedding final_embedding = weighted_sum / total_weight # Create a PoolingRequestOutput # for the aggregated result - pooling_output_data = PoolingOutput( - data=final_embedding) + pooling_output_data = PoolingOutput(data=final_embedding) # Get original prompt token IDs for this prompt - original_prompt = ctx.request_prompts[prompt_idx] + original_prompt = ctx.engine_prompts[prompt_idx] if not self._is_text_tokens_prompt(original_prompt): return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a " - f"TextTokensPrompt") + f"Chunked prompt {prompt_idx} is not a TextTokensPrompt" + ) - original_token_ids = cast( - TextTokensPrompt, - original_prompt)["prompt_token_ids"] + original_token_ids = cast(TextTokensPrompt, original_prompt)[ + "prompt_token_ids" + ] pooling_request_output = PoolingRequestOutput( - request_id=aggregator['request_id'], + request_id=aggregator["request_id"], prompt_token_ids=original_token_ids, outputs=pooling_output_data, - finished=True) + finished=True, + ) final_res_batch.append(pooling_request_output) else: return self.create_error_response( - f"Failed to aggregate chunks " - f"for prompt {prompt_idx}") + f"Failed to aggregate chunks for prompt {prompt_idx}" + ) elif prompt_idx in short_prompts_results: final_res_batch.append( - cast(PoolingRequestOutput, - short_prompts_results[prompt_idx])) + cast(PoolingRequestOutput, short_prompts_results[prompt_idx]) + ) else: return self.create_error_response( - f"Result not found for prompt {prompt_idx}") + f"Result not found for prompt {prompt_idx}" + ) ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_res_batch) + list[Union[RequestOutput, PoolingRequestOutput]], final_res_batch + ) return None @@ -587,20 +596,24 @@ class OpenAIServingEmbedding(EmbeddingMixin): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, + log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template async def create_embedding( self, @@ -613,10 +626,11 @@ class OpenAIServingEmbedding(EmbeddingMixin): See https://platform.openai.com/docs/api-reference/embeddings/create for the 
API specification. This API mimics the OpenAI Embedding API. """ - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = ( f"{self.request_id_prefix}-" - f"{self._base_request_id(raw_request, request.request_id)}") + f"{self._base_request_id(raw_request, request.request_id)}" + ) ctx = EmbeddingServeContext( request=request, @@ -629,18 +643,6 @@ class OpenAIServingEmbedding(EmbeddingMixin): return await super().handle(ctx) # type: ignore - @override - def _validate_request( - self, - ctx: ServeContext[EmbeddingRequest], - ) -> Optional[ErrorResponse]: - if error := super()._validate_request(ctx): - return error - - ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - - return None - @override def _create_pooling_params( self, @@ -656,3 +658,17 @@ class OpenAIServingEmbedding(EmbeddingMixin): return self.create_error_response(str(e)) return pooling_params + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + if isinstance(ctx.request, EmbeddingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=ctx.request.chat_template, + chat_template_kwargs=ctx.request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + return await super()._preprocess(ctx) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0f4a7c0186b65..0d1a525c6d3da 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import io import json import sys import time +import traceback from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus -from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union, cast, overload) +from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union -import pybase64 import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field @@ -24,71 +22,107 @@ else: from typing_extensions import TypedDict import vllm.envs as envs -from vllm.config import ModelConfig +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.engine.protocol import EngineClient -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format) +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format, +) from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, - ClassificationRequest, - ClassificationResponse, - CompletionRequest, - CompletionResponse, - DetokenizeRequest, - EmbeddingChatRequest, - EmbeddingCompletionRequest, - EmbeddingRequest, - EmbeddingResponse, ErrorInfo, - 
ErrorResponse, PoolingResponse, - RerankRequest, ResponsesRequest, - ScoreRequest, ScoreResponse, - TokenizeChatRequest, - TokenizeCompletionRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + ErrorInfo, + ErrorResponse, + IOProcessorRequest, + PoolingResponse, + RerankRequest, + ResponsesRequest, + ScoreRequest, + ScoreResponse, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser -# yapf: enable -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.utils import _validate_truncation_size +from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs.parse import ( + PromptComponents, + get_prompt_components, + is_explicit_encoder_decoder_prompt, +) from vllm.logger import init_logger +from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict) -from vllm.outputs import PoolingRequestOutput, RequestOutput + MultiModalDataDict, + MultiModalUUIDDict, +) +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob, PromptLogprobs -from vllm.tracing import (contains_trace_headers, extract_trace_headers, - log_tracing_disabled_warning) +from vllm.tracing import ( + contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, - merge_async_iterators, random_uuid) +from vllm.utils import ( + AsyncMicrobatchTokenizer, + collect_from_async_generator, + is_list_of, + make_async, + merge_async_iterators, + random_uuid, +) +from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, - EmbeddingCompletionRequest, RerankRequest, - ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest] +CompletionLikeRequest = Union[ + CompletionRequest, + DetokenizeRequest, + EmbeddingCompletionRequest, + RerankRequest, + ClassificationRequest, + ScoreRequest, + TokenizeCompletionRequest, +] -ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, - TokenizeChatRequest] +ChatLikeRequest = Union[ + ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest +] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[CompletionLikeRequest, 
ChatLikeRequest, SpeechToTextRequest, - ResponsesRequest] +AnyRequest = Union[ + CompletionLikeRequest, + ChatLikeRequest, + SpeechToTextRequest, + ResponsesRequest, + IOProcessorRequest, +] AnyResponse = Union[ CompletionResponse, @@ -115,13 +149,19 @@ RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt + ) RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -132,9 +172,9 @@ class RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. """ + request_prompts: Optional[Sequence[RequestPrompt]] = [] - engine_prompts: Optional[Union[list[EngineTokensPrompt], - list[EngineEmbedsPrompt]]] = [] + engine_prompts: Optional[list[EngineTokensPrompt]] = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -144,16 +184,23 @@ class ResponseGenerationMixin(BaseModel): Mixin for response generation, managing result generators and final batch results. """ - result_generator: Optional[AsyncGenerator[tuple[int, Union[ - RequestOutput, PoolingRequestOutput]], None]] = None + + result_generator: Optional[ + AsyncGenerator[tuple[int, Union[RequestOutput, PoolingRequestOutput]], None] + ] = None final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( - default_factory=list) + default_factory=list + ) model_config = ConfigDict(arbitrary_types_allowed=True) -class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, - Generic[RequestT]): +class ServeContext( + RequestProcessingMixin, + ResponseGenerationMixin, + BaseModel, + Generic[RequestT], +): # Shared across all requests request: RequestT raw_request: Optional[Request] = None @@ -164,7 +211,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, # Shared across most requests tokenizer: Optional[AnyTokenizer] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # `protected_namespaces` resolves Pydantic v2's warning # on conflict with protected namespace "model_" @@ -199,18 +245,16 @@ class OpenAIServing: def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__() self.engine_client = engine_client - self.model_config = model_config - self.max_model_len = model_config.max_model_len self.models = models @@ -219,9 +263,273 @@ class OpenAIServing: self.enable_force_include_usage = enable_force_include_usage self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + self._apply_mistral_chat_template_async = make_async( + apply_mistral_chat_template, executor=self._tokenizer_executor + ) - self._async_tokenizer_pool: dict[AnyTokenizer, - AsyncMicrobatchTokenizer] = {} + self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} + self.log_error_stack = log_error_stack + + 
self.processor = self.models.processor + self.io_processor = self.models.io_processor + self.model_config = self.models.model_config + self.max_model_len = self.model_config.max_model_len + + def _get_tool_parser( + self, tool_parser_name: Optional[str] = None, enable_auto_tools: bool = False + ) -> Optional[Callable[[AnyTokenizer], ToolParser]]: + """Get the tool parser based on the name.""" + parser = None + if not enable_auto_tools or tool_parser_name is None: + return parser + logger.info( + '"auto" tool choice has been enabled please note that while' + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored." + ) + + try: + if tool_parser_name == "pythonic" and self.model_config.model.startswith( + "meta-llama/Llama-3.2" + ): + logger.warning( + "Llama3.2 models may struggle to emit valid pythonic tool calls" + ) + parser = ToolParserManager.get_tool_parser(tool_parser_name) + except Exception as e: + raise TypeError( + "Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser_name}' which has not " + "been registered" + ) from e + return parser + + def _get_reasoning_parser( + self, + reasoning_parser_name: str, + ) -> Optional[Callable[[AnyTokenizer], ReasoningParser]]: + """Get the reasoning parser based on the name.""" + parser = None + if not reasoning_parser_name: + return None + try: + parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name) + assert parser is not None + except Exception as e: + raise TypeError(f"{reasoning_parser_name=} has not been registered") from e + return parser + + async def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + await self.engine_client.reset_mm_cache() + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + lora_request: Optional[LoRARequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output + + processor = self.processor + tokenizer = processor.tokenizer + if tokenizer is None: + raise ValueError( + "You cannot use beam search when `skip_tokenizer_init` is True" + ) + + eos_token_id: int = tokenizer.eos_token_id # type: ignore + + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs( + prompt + ) + + if processed_inputs["type"] == "embeds": + raise NotImplementedError + + # This is a workaround to fix multimodal beam search; this is a + # bandaid fix for 2 small problems: + # 1. Multi_modal_data on the processed_inputs currently resolves to + # `None`. + # 2. preprocessing above expands the multimodal placeholders. However, + # this happens again in generation, so the double expansion causes + # a mismatch. + # TODO - would be ideal to handle this more gracefully. 
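`_get_tool_parser` and `_get_reasoning_parser` above are thin wrappers around name-based registries that raise a descriptive `TypeError` when a parser was never registered. The toy registry below only illustrates that lookup-and-raise pattern; it is not vLLM's `ToolParserManager` or `ReasoningParserManager`:

```python
from typing import Callable


class ParserRegistry:
    """Toy name -> factory registry; illustrates the lookup-and-raise
    pattern used above, nothing more."""

    _parsers: dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str) -> Callable:
        def wrap(factory: Callable) -> Callable:
            cls._parsers[name] = factory
            return factory
        return wrap

    @classmethod
    def get(cls, name: str) -> Callable:
        try:
            return cls._parsers[name]
        except KeyError as e:
            raise TypeError(f"parser '{name}' has not been registered") from e


@ParserRegistry.register("example")
def make_example_parser(tokenizer):
    return ("example-parser", tokenizer)


print(ParserRegistry.get("example")("dummy-tokenizer"))
```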
+ prompt_text: Optional[str] + prompt_token_ids: list[int] + multi_modal_data: Optional[MultiModalDataDict] + if isinstance(prompt, str): + prompt_text = prompt + prompt_token_ids = [] + multi_modal_data = None + else: + prompt_text = prompt.get("prompt") # type: ignore + prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore + multi_modal_data = prompt.get("multi_modal_data") # type: ignore + + mm_processor_kwargs: Optional[dict[str, Any]] = processed_inputs.get( + "mm_processor_kwargs" + ) # type: ignore + + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence( + tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + lora_request=lora_request, + ) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch, lora_req_batch = zip( + *[ + ( + EngineTokensPrompt( + prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs, + ), + beam.lora_request, + ) + for beam in all_beams + ] + ) + + tasks = [] + request_id_batch = f"{request_id}-{random_uuid()}" + + for i, (individual_prompt, lora_req) in enumerate( + zip(prompts_batch, lora_req_batch) + ): + request_id_item = f"{request_id_batch}-beam-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.engine_client.generate( + individual_prompt, + beam_search_params, + request_id_item, + lora_request=lora_req, + ) + ) + ) + tasks.append(task) + + output = [x[0] for x in await asyncio.gather(*tasks)] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + if token_id == eos_token_id and not ignore_eos: + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id] + if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=eos_token_id, + ) + ) + else: + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) + ) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + if beam.tokens[-1] == eos_token_id and not ignore_eos: + # Skip the eos token in the text. 
+ tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) + + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + text=beam.text, # type: ignore + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason + if beam.finish_reason is not None + else "length", + stop_reason=beam.stop_reason, + ) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + + def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: + """ + Get a Renderer instance with the provided tokenizer. + Uses shared async tokenizer pool for efficiency. + """ + return CompletionRenderer( + model_config=self.model_config, + tokenizer=tokenizer, + async_tokenizer_pool=self._async_tokenizer_pool, + ) + + def _build_render_config( + self, + request: Any, + ) -> RenderConfig: + """ + Build and return a `RenderConfig` for an endpoint. + + Used by the renderer to control how prompts are prepared + (e.g., tokenization and length handling). Endpoints should + implement this with logic appropriate to their request type. + """ + raise NotImplementedError def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ @@ -291,17 +599,17 @@ class OpenAIServing: yield self._build_response(ctx) def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: - truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - if truncate_prompt_tokens is not None: - if truncate_prompt_tokens <= self.max_model_len: - ctx.truncate_prompt_tokens = truncate_prompt_tokens - else: - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len + ): + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." 
+ ) return None def _create_pooling_params( @@ -310,7 +618,8 @@ class OpenAIServing: ) -> Union[PoolingParams, ErrorResponse]: if not hasattr(ctx.request, "to_pooling_params"): return self.create_error_response( - "Request type does not support pooling parameters") + "Request type does not support pooling parameters" + ) return ctx.request.to_pooling_params() @@ -319,40 +628,34 @@ class OpenAIServing: ctx: ServeContext, ) -> Optional[ErrorResponse]: """Schedule the request and get the result generator.""" - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], - None]] = [] + generators: list[ + AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + ] = [] try: - trace_headers = (None if ctx.raw_request is None else await - self._get_trace_headers(ctx.raw_request.headers)) + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) pooling_params = self._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): return pooling_params if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) - self._log_inputs(request_id_item, - ctx.request_prompts[i], - params=pooling_params, - lora_request=ctx.lora_request) - - # Mypy has an existing bug related to inferring the variance of - # TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -379,28 +682,24 @@ class OpenAIServing: """Collect batch results from the result generator.""" try: if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") num_prompts = len(ctx.engine_prompts) - final_res_batch: list[Optional[Union[RequestOutput, - PoolingRequestOutput]]] + final_res_batch: list[Optional[Union[RequestOutput, PoolingRequestOutput]]] final_res_batch = [None] * num_prompts if ctx.result_generator is None: - return self.create_error_response( - "Result generator not available") + return self.create_error_response("Result generator not available") async for i, res in ctx.result_generator: final_res_batch[i] = res if None in final_res_batch: return self.create_error_response( - "Failed to generate results for all prompts") + "Failed to generate results for all prompts" + ) - ctx.final_res_batch = [ - res for res in final_res_batch if res is not None - ] + ctx.final_res_batch = [res for res in final_res_batch if res is not None] return None @@ -408,50 +707,66 @@ class OpenAIServing: return self.create_error_response(str(e)) def create_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(error=ErrorInfo( - message=message, type=err_type, code=status_code.value)) + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) 
-> ErrorResponse: + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() + return ErrorResponse( + error=ErrorInfo(message=message, type=err_type, code=status_code.value) + ) def create_streaming_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) -> str: json_str = json.dumps( - self.create_error_response(message=message, - err_type=err_type, - status_code=status_code).model_dump()) + self.create_error_response( + message=message, err_type=err_type, status_code=status_code + ).model_dump() + ) return json_str async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: - error_response = None if self._is_model_supported(request.model): return None if request.model in self.models.lora_requests: return None - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( - load_result := await self.models.resolve_lora(request.model)): + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): if isinstance(load_result, LoRARequest): return None - if isinstance(load_result, ErrorResponse) and \ - load_result.error.code == HTTPStatus.BAD_REQUEST.value: + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): error_response = load_result return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) def _get_active_default_mm_loras( - self, request: AnyRequest) -> Optional[LoRARequest]: + self, request: AnyRequest + ) -> Optional[LoRARequest]: """Determine if there are any active default multimodal loras.""" # TODO: Currently this is only enabled for chat completions # to be better aligned with only being enabled for .generate @@ -479,7 +794,6 @@ class OpenAIServing: request: AnyRequest, supports_default_mm_loras: bool = False, ) -> Optional[LoRARequest]: - if request.model in self.models.lora_requests: return self.models.lora_requests[request.model] @@ -507,8 +821,11 @@ class OpenAIServing: return message_types for message in request.messages: - if (isinstance(message, dict) and "content" in message - and isinstance(message["content"], list)): + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): for content_dict in message["content"]: if "type" in content_dict: message_types.add(content_dict["type"].split("_")[0]) @@ -517,34 +834,39 @@ class OpenAIServing: async def _normalize_prompt_text_to_input( self, request: AnyRequest, - tokenizer: AnyTokenizer, prompt: str, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], + tokenizer: AnyTokenizer, add_special_tokens: bool, ) -> TextTokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if ( + self.model_config.encoder_config is not None + and self.model_config.encoder_config.get("do_lower_case", False) + ): prompt = prompt.lower() + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) + if truncate_prompt_tokens is None: encoded = 
await async_tokenizer( - prompt, add_special_tokens=add_special_tokens) + prompt, add_special_tokens=add_special_tokens + ) elif truncate_prompt_tokens < 0: # Negative means we cap at the model's max length encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=self.max_model_len) + max_length=self.max_model_len, + ) else: encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=truncate_prompt_tokens) + max_length=truncate_prompt_tokens, + ) input_ids = encoded.input_ids input_text = prompt @@ -554,20 +876,23 @@ class OpenAIServing: async def _normalize_prompt_tokens_to_input( self, request: AnyRequest, - tokenizer: AnyTokenizer, prompt_ids: list[int], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + tokenizer: Optional[AnyTokenizer], ) -> TextTokensPrompt: - async_tokenizer = self._get_async_tokenizer(tokenizer) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) if truncate_prompt_tokens is None: input_ids = prompt_ids elif truncate_prompt_tokens < 0: - input_ids = prompt_ids[-self.max_model_len:] + input_ids = prompt_ids[-self.max_model_len :] else: input_ids = prompt_ids[-truncate_prompt_tokens:] - input_text = await async_tokenizer.decode(input_ids) + if tokenizer is None: + input_text = "" + else: + async_tokenizer = self._get_async_tokenizer(tokenizer) + input_text = await async_tokenizer.decode(input_ids) return self._validate_input(request, input_ids, input_text) @@ -581,33 +906,39 @@ class OpenAIServing: # Note: EmbeddingRequest, ClassificationRequest, # and ScoreRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest, - ScoreRequest, RerankRequest, ClassificationRequest)): - + if isinstance( + request, + ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ScoreRequest, + RerankRequest, + ClassificationRequest, + ), + ): # Note: input length can be up to the entire model context length # since these requests don't generate tokens. if token_num > self.max_model_len: operations: dict[type[AnyRequest], str] = { ScoreRequest: "score", - ClassificationRequest: "classification" + ClassificationRequest: "classification", } - operation = operations.get(type(request), - "embedding generation") + operation = operations.get(type(request), "embedding generation") raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the input for {operation}. " - f"Please reduce the length of the input.") - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + f"Please reduce the length of the input." + ) + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation - if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, - DetokenizeRequest)): - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + if isinstance( + request, + (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), + ): + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # chat completion endpoint supports max_completion_tokens if isinstance(request, ChatCompletionRequest): @@ -623,16 +954,17 @@ class OpenAIServing: f"This model's maximum context length is " f"{self.max_model_len} tokens. 
However, your request has " f"{token_num} input tokens. Please reduce the length of " - "the input messages.") + "the input messages." + ) - if max_tokens is not None and \ - token_num + max_tokens > self.max_model_len: + if max_tokens is not None and token_num + max_tokens > self.max_model_len: raise ValueError( "'max_tokens' or 'max_completion_tokens' is too large: " f"{max_tokens}. This model's maximum context length is " f"{self.max_model_len} tokens and your request has " f"{token_num} input tokens ({max_tokens} > {self.max_model_len}" - f" - {token_num}).") + f" - {token_num})." + ) return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) @@ -641,20 +973,16 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_input: Union[str, list[int]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: """ - A simpler implementation of - [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] - that assumes single input. + A simpler implementation that tokenizes a single prompt input. """ async for result in self._tokenize_prompt_inputs_async( - request, - tokenizer, + request, + tokenizer, [prompt_input], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, + add_special_tokens=add_special_tokens, ): return result raise ValueError("No results yielded from tokenization") @@ -664,181 +992,45 @@ class OpenAIServing: request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, ) -> AsyncGenerator[TextTokensPrompt, None]: """ - A simpler implementation of - [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] - that assumes multiple inputs. + A simpler implementation that tokenizes multiple prompt inputs. """ - for text in prompt_inputs: - if isinstance(text, str): + for prompt in prompt_inputs: + if isinstance(prompt, str): yield await self._normalize_prompt_text_to_input( request, - tokenizer, - prompt=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt=prompt, + tokenizer=tokenizer, add_special_tokens=add_special_tokens, ) else: yield await self._normalize_prompt_tokens_to_input( request, - tokenizer, - prompt_ids=text, - truncate_prompt_tokens=truncate_prompt_tokens, + prompt_ids=prompt, + tokenizer=tokenizer, ) - async def _tokenize_prompt_input_or_inputs_async( + def _validate_chat_template( self, - request: AnyRequest, - tokenizer: AnyTokenizer, - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, - add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: - """ - Tokenize/detokenize depending on the input format. - - According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_ - , each input can be a string or array of tokens. Note that each request - can pass one or more inputs. 
- """ - inputs_embeds = list[EmbedsPrompt]() - inputs_text = list[TextTokensPrompt]() - - if (isinstance(request, CompletionRequest) - and request.prompt_embeds is not None): - inputs_embeds.extend( - self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens)) - - # Empty prompts are okay as long as there are prompt embeddings - if input_or_inputs is None or (inputs_embeds - and input_or_inputs == ""): - return [], inputs_embeds - - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is False" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - - # Parse and batch the input prompts - batch_inputs = parse_and_batch_prompt(input_or_inputs) - - # Process each input in the batch concurrently - tasks = [] - for prompt_input in batch_inputs: - if prompt_input["is_tokens"] is False: - task = self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens) - else: - task = self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) - tasks.append(task) - - # Wait for all tokenization tasks to complete - results = await asyncio.gather(*tasks) - inputs_text.extend(results) - - return inputs_text, inputs_embeds - - @overload - async def _preprocess_completion( - self, - request: Union[DetokenizeRequest, EmbeddingCompletionRequest, - RerankRequest, ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest], - tokenizer: AnyTokenizer, - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., - add_special_tokens: bool = ..., - ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: - ... - - @overload - async def _preprocess_completion( - self, - request: CompletionRequest, - tokenizer: AnyTokenizer, - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., - add_special_tokens: bool = ..., - ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ - EngineTokensPrompt, EngineEmbedsPrompt]]]: - ... 
- - async def _preprocess_completion( - self, - request: CompletionLikeRequest, - tokenizer: AnyTokenizer, - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, - add_special_tokens: bool = True, - ) -> tuple[Union[list[TextTokensPrompt], list[Union[ - TextTokensPrompt, EmbedsPrompt]]], Union[ - list[EngineTokensPrompt], list[Union[EngineTokensPrompt, - EngineEmbedsPrompt]]]]: - if not isinstance(request, - CompletionRequest) and input_or_inputs is None: - raise ValueError( - "Prompt embeds with non-completion requests is not" - " currently supported.") - - (request_prompts_text, request_prompts_embeds - ) = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - - engine_prompts_text = [ - EngineTokensPrompt( - prompt_token_ids=request_prompt_text["prompt_token_ids"]) - for request_prompt_text in request_prompts_text - ] - cache_salt = request.cache_salt if ( - hasattr(request, "cache_salt") - and request.cache_salt is not None) else None - if cache_salt: - for prompt_text in engine_prompts_text: - prompt_text["cache_salt"] = cache_salt - - # This check is equivalent to simply checking if - # `request_prompts_embeds` is empty, but it's difficult to propagate - # overloads to the private helper functions to enable this check. - # This overload is needed because only TextPrompts are allowed for - # non-completion requests and if we don't add the overload here, - # everywhere this function is used outside of serving_completion will - # need logic asserting that only text prompts are in the request. - if not isinstance(request, - CompletionRequest) and input_or_inputs is not None: - return request_prompts_text, engine_prompts_text - - engine_prompts_embeds = [ - EngineEmbedsPrompt( - prompt_embeds=request_prompt_embeds["prompt_embeds"]) - for request_prompt_embeds in request_prompts_embeds - ] - if cache_salt: - for prompt_embed in engine_prompts_embeds: - prompt_embed["cache_salt"] = cache_salt - - request_prompts = request_prompts_embeds + request_prompts_text - engine_prompts = engine_prompts_embeds + engine_prompts_text - return request_prompts, engine_prompts + request_chat_template: Optional[str], + chat_template_kwargs: Optional[dict[str, Any]], + trust_request_chat_template: bool, + ) -> Optional[ErrorResponse]: + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + return self.create_error_response( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." 
+ ) + return None async def _preprocess_chat( self, @@ -853,10 +1045,12 @@ class OpenAIServing: documents: Optional[list[dict[str, str]]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = False, - ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[EngineTokensPrompt]]: + ) -> tuple[ + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], + ]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -866,7 +1060,7 @@ class OpenAIServing: tokenizer, model_config=model_config, ) - conversation, mm_data_future = parse_chat_messages_futures( + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, tokenizer, @@ -887,7 +1081,7 @@ class OpenAIServing: if tokenizer is None: request_prompt = "placeholder" elif isinstance(tokenizer, MistralTokenizer): - request_prompt = apply_mistral_chat_template( + request_prompt = await self._apply_mistral_chat_template_async( tokenizer, messages=messages, **_chat_template_kwargs, @@ -905,8 +1099,9 @@ class OpenAIServing: # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM - should_parse_tools = tool_parser is not None and (hasattr( - request, "tool_choice") and request.tool_choice != "none") + should_parse_tools = tool_parser is not None and ( + hasattr(request, "tool_choice") and request.tool_choice != "none" + ) if should_parse_tools: if not isinstance(request, ChatCompletionRequest): @@ -914,35 +1109,43 @@ class OpenAIServing: raise NotImplementedError(msg) request = tool_parser(tokenizer).adjust_request( # type: ignore - request=request) + request=request + ) if tokenizer is None: assert isinstance(request_prompt, str), ( - "Prompt has to be a string", \ - "when the tokenizer is not initialised" + "Prompt has to be a string", + "when the tokenizer is not initialised", + ) + prompt_inputs = TextTokensPrompt( + prompt=request_prompt, prompt_token_ids=[1] ) - prompt_inputs = TextTokensPrompt(prompt=request_prompt, - prompt_token_ids=[1]) elif isinstance(request_prompt, str): prompt_inputs = await self._tokenize_prompt_input_async( request, tokenizer, request_prompt, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) else: # For MistralTokenizer assert is_list_of(request_prompt, int), ( - "Prompt has to be either a string or a list of token ids") + "Prompt has to be either a string or a list of token ids" + ) prompt_inputs = TextTokensPrompt( prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt) + prompt_token_ids=request_prompt, + ) engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"]) + prompt_token_ids=prompt_inputs["prompt_token_ids"] + ) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -951,6 +1154,33 @@ class OpenAIServing: return conversation, [request_prompt], [engine_prompt] + async def _process_inputs( + self, + request_id: str, + engine_prompt: PromptType, + params: Union[SamplingParams, 
PoolingParams], + *, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]], + priority: int, + ) -> tuple[EngineCoreRequest, dict[str, Any]]: + """Use the Processor to process inputs for AsyncLLM.""" + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size( + self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs + ) + + engine_request = self.processor.process_inputs( + request_id, + engine_prompt, + params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + ) + return engine_request, tokenization_kwargs + async def _generate_with_builtin_tools( self, request_id: str, @@ -962,6 +1192,7 @@ class OpenAIServing: priority: int = 0, **kwargs, ): + prompt_text, _, _ = self._get_prompt_components(request_prompt) orig_priority = priority while True: self._log_inputs( @@ -970,14 +1201,27 @@ class OpenAIServing: params=sampling_params, lora_request=lora_request, ) - generator = self.engine_client.generate( + trace_headers = kwargs.get("trace_headers") + engine_request, tokenization_kwargs = await self._process_inputs( + request_id, engine_prompt, sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ) + + generator = self.engine_client.generate( + engine_request, + sampling_params, request_id, lora_request=lora_request, priority=priority, + prompt_text=prompt_text, + tokenization_kwargs=tokenization_kwargs, **kwargs, ) + async for res in generator: context.append_output(res) # NOTE(woosuk): The stop condition is handled by the engine. @@ -997,69 +1241,33 @@ class OpenAIServing: # Create inputs for the next turn. # Render the next prompt token ids. prompt_token_ids = context.render_for_completion() - engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_token_ids) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) request_prompt = prompt_token_ids # Update the sampling params. 
- sampling_params.max_tokens = (self.max_model_len - - len(prompt_token_ids)) + sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) # OPTIMIZATION priority = orig_priority - 1 - @staticmethod - def _load_prompt_embeds( - prompt_embeds: Optional[Union[bytes, list[bytes]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - ) -> list[EmbedsPrompt]: + def _get_prompt_components( + self, + prompt: Union[RequestPrompt, PromptType], + ) -> PromptComponents: + if isinstance(prompt, list): + return PromptComponents(token_ids=prompt) - def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load(io.BytesIO( - pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu")) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() - if tensor.dim() > 2: - tensor = tensor.squeeze(0) - assert tensor.dim() == 2 - if truncate_prompt_tokens is not None: - tensor = tensor[-truncate_prompt_tokens:] - return {"prompt_embeds": tensor} - - if prompt_embeds: - if isinstance(prompt_embeds, list): - return [ - _load_and_validate_embed(embed) for embed in prompt_embeds - ] - else: - return [_load_and_validate_embed(prompt_embeds)] - else: - return [] + return get_prompt_components(prompt) # type: ignore[arg-type] def _log_inputs( self, request_id: str, - inputs: RequestPrompt, - params: Optional[Union[SamplingParams, PoolingParams, - BeamSearchParams]], + inputs: Union[RequestPrompt, PromptType], + params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], ) -> None: if self.request_logger is None: return - prompt, prompt_token_ids, prompt_embeds = None, None, None - if isinstance(inputs, str): - prompt = inputs - elif isinstance(inputs, list): - prompt_token_ids = inputs - elif 'prompt_embeds' in inputs: - prompt_embeds = inputs.get("prompt_embeds") - else: - prompt = inputs["prompt"] - prompt_token_ids = inputs["prompt_token_ids"] + + prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs) self.request_logger.log_inputs( request_id, @@ -1085,8 +1293,9 @@ class OpenAIServing: return None @staticmethod - def _base_request_id(raw_request: Optional[Request], - default: Optional[str] = None) -> Optional[str]: + def _base_request_id( + raw_request: Optional[Request], default: Optional[str] = None + ) -> Optional[str]: """Pulls the request id to use from a header, if provided""" default = default or random_uuid() if raw_request is None: @@ -1095,10 +1304,12 @@ class OpenAIServing: return raw_request.headers.get("X-Request-Id", default) @staticmethod - def _get_decoded_token(logprob: Logprob, - token_id: int, - tokenizer: AnyTokenizer, - return_as_token_id: bool = False) -> str: + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False, + ) -> str: if return_as_token_id: return f"token_id:{token_id}" @@ -1111,19 +1322,10 @@ class OpenAIServing: return True return self.models.is_base_model(model_name) - def _get_model_name(self, - model_name: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> str: - if lora_request: - return lora_request.lora_name - if not model_name: - return self.models.base_model_paths[0].name - return model_name - def clamp_prompt_logprobs( - prompt_logprobs: Union[PromptLogprobs, - None]) -> Union[PromptLogprobs, None]: + prompt_logprobs: 
Union[PromptLogprobs, None], +) -> Union[PromptLogprobs, None]: if prompt_logprobs is None: return prompt_logprobs @@ -1131,6 +1333,6 @@ def clamp_prompt_logprobs( if logprob_dict is None: continue for logprob_values in logprob_dict.values(): - if logprob_values.logprob == float('-inf'): + if logprob_values.logprob == float("-inf"): logprob_values.logprob = -9999.0 return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index a4efa0815b4e7..1aaac60f29933 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -7,13 +7,16 @@ from dataclasses import dataclass from http import HTTPStatus from typing import Optional, Union -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse, - LoadLoRAAdapterRequest, - ModelCard, ModelList, - ModelPermission, - UnloadLoRAAdapterRequest) +from vllm.entrypoints.openai.protocol import ( + ErrorInfo, + ErrorResponse, + LoadLoRAAdapterRequest, + ModelCard, + ModelList, + ModelPermission, + UnloadLoRAAdapterRequest, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry @@ -47,40 +50,43 @@ class OpenAIServingModels: def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, base_model_paths: list[BaseModelPath], *, lora_modules: Optional[list[LoRAModulePath]] = None, ): super().__init__() - self.base_model_paths = base_model_paths - - self.max_model_len = model_config.max_model_len self.engine_client = engine_client - self.model_config = model_config + self.base_model_paths = base_model_paths self.static_lora_modules = lora_modules self.lora_requests: dict[str, LoRARequest] = {} self.lora_id_counter = AtomicCounter(0) self.lora_resolvers: list[LoRAResolver] = [] - for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers( - ): + for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(): self.lora_resolvers.append( - LoRAResolverRegistry.get_resolver(lora_resolver_name)) + LoRAResolverRegistry.get_resolver(lora_resolver_name) + ) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) + self.processor = self.engine_client.processor + self.io_processor = self.engine_client.io_processor + self.model_config = self.engine_client.model_config + self.max_model_len = self.model_config.max_model_len + async def init_static_loras(self): """Loads all static LoRA modules. Raises if any fail to load""" if self.static_lora_modules is None: return for lora in self.static_lora_modules: - load_request = LoadLoRAAdapterRequest(lora_path=lora.path, - lora_name=lora.name) + load_request = LoadLoRAAdapterRequest( + lora_path=lora.path, lora_name=lora.name + ) load_result = await self.load_lora_adapter( - request=load_request, base_model_name=lora.base_model_name) + request=load_request, base_model_name=lora.base_model_name + ) if isinstance(load_result, ErrorResponse): raise ValueError(load_result.error.message) @@ -100,47 +106,48 @@ class OpenAIServingModels: return self.base_model_paths[0].name async def show_available_models(self) -> ModelList: - """Show available models. This includes the base model and all + """Show available models. 
This includes the base model and all adapters""" model_cards = [ - ModelCard(id=base_model.name, - max_model_len=self.max_model_len, - root=base_model.model_path, - permission=[ModelPermission()]) + ModelCard( + id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()], + ) for base_model in self.base_model_paths ] lora_cards = [ - ModelCard(id=lora.lora_name, - root=lora.local_path, - parent=lora.base_model_name if lora.base_model_name else - self.base_model_paths[0].name, - permission=[ModelPermission()]) + ModelCard( + id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name + if lora.base_model_name + else self.base_model_paths[0].name, + permission=[ModelPermission()], + ) for lora in self.lora_requests.values() ] model_cards.extend(lora_cards) return ModelList(data=model_cards) async def load_lora_adapter( - self, - request: LoadLoRAAdapterRequest, - base_model_name: Optional[str] = None + self, request: LoadLoRAAdapterRequest, base_model_name: Optional[str] = None ) -> Union[ErrorResponse, str]: lora_name = request.lora_name # Ensure atomicity based on the lora name async with self.lora_resolver_lock[lora_name]: - error_check_ret = await self._check_load_lora_adapter_request( - request) + error_check_ret = await self._check_load_lora_adapter_request(request) if error_check_ret is not None: return error_check_ret lora_path = request.lora_path unique_id = self.lora_id_counter.inc(1) - lora_request = LoRARequest(lora_name=lora_name, - lora_int_id=unique_id, - lora_path=lora_path) - if base_model_name is not None and self.is_base_model( - base_model_name): + lora_request = LoRARequest( + lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path + ) + if base_model_name is not None and self.is_base_model(base_model_name): lora_request.base_model_name = base_model_name # Validate that the adapter can be loaded into the engine @@ -154,24 +161,24 @@ class OpenAIServingModels: error_type = "NotFoundError" status_code = HTTPStatus.NOT_FOUND - return create_error_response(message=str(e), - err_type=error_type, - status_code=status_code) + return create_error_response( + message=str(e), err_type=error_type, status_code=status_code + ) self.lora_requests[lora_name] = lora_request - logger.info("Loaded new LoRA adapter: name '%s', path '%s'", - lora_name, lora_path) + logger.info( + "Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path + ) return f"Success: LoRA adapter '{lora_name}' added successfully." async def unload_lora_adapter( - self, - request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]: + self, request: UnloadLoRAAdapterRequest + ) -> Union[ErrorResponse, str]: lora_name = request.lora_name # Ensure atomicity based on the lora name async with self.lora_resolver_lock[lora_name]: - error_check_ret = await self._check_unload_lora_adapter_request( - request) + error_check_ret = await self._check_unload_lora_adapter_request(request) if error_check_ret is not None: return error_check_ret @@ -181,48 +188,49 @@ class OpenAIServingModels: return f"Success: LoRA adapter '{lora_name}' removed successfully." 
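Editor's note on the LoRA hot-load path above: both `load_lora_adapter` and `unload_lora_adapter` guard their critical sections with a per-adapter-name `asyncio.Lock` drawn from `self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)`, so concurrent requests for the same adapter are serialized while unrelated adapters proceed in parallel. The sketch below is a minimal, self-contained illustration of that locking pattern only; the registry dict, the `asyncio.sleep` stand-in for engine-side validation, and the function names are hypothetical and are not vLLM APIs.

```python
import asyncio
from collections import defaultdict

# One lock per adapter name: same-name operations are serialized,
# while different adapters can be loaded/unloaded concurrently.
_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
_registry: dict[str, str] = {}  # adapter name -> adapter path (illustrative)


async def load_adapter(name: str, path: str) -> str:
    async with _locks[name]:
        if name in _registry:
            return f"The lora adapter '{name}' has already been loaded."
        await asyncio.sleep(0.05)  # stand-in for engine-side validation
        _registry[name] = path
        return f"Success: LoRA adapter '{name}' added successfully."


async def unload_adapter(name: str) -> str:
    async with _locks[name]:
        if name not in _registry:
            return f"The lora adapter '{name}' cannot be found."
        del _registry[name]
        return f"Success: LoRA adapter '{name}' removed successfully."


async def main() -> None:
    # The two same-name loads contend on one lock; the loser observes the
    # adapter as already loaded instead of double-registering it.
    print(await asyncio.gather(
        load_adapter("sql-lora", "/adapters/sql"),
        load_adapter("sql-lora", "/adapters/sql"),
        load_adapter("chat-lora", "/adapters/chat"),
    ))
    print(await unload_adapter("sql-lora"))


if __name__ == "__main__":
    asyncio.run(main())
```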
async def _check_load_lora_adapter_request( - self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]: + self, request: LoadLoRAAdapterRequest + ) -> Optional[ErrorResponse]: # Check if both 'lora_name' and 'lora_path' are provided if not request.lora_name or not request.lora_path: return create_error_response( message="Both 'lora_name' and 'lora_path' must be provided.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) # Check if the lora adapter with the given name already exists if request.lora_name in self.lora_requests: return create_error_response( - message= - f"The lora adapter '{request.lora_name}' has already been " + message=f"The lora adapter '{request.lora_name}' has already been " "loaded.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) return None async def _check_unload_lora_adapter_request( - self, - request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]: + self, request: UnloadLoRAAdapterRequest + ) -> Optional[ErrorResponse]: # Check if 'lora_name' is not provided return an error if not request.lora_name: return create_error_response( - message= - "'lora_name' needs to be provided to unload a LoRA adapter.", + message="'lora_name' needs to be provided to unload a LoRA adapter.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) # Check if the lora adapter with the given name exists if request.lora_name not in self.lora_requests: return create_error_response( - message= - f"The lora adapter '{request.lora_name}' cannot be found.", + message=f"The lora adapter '{request.lora_name}' cannot be found.", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) return None - async def resolve_lora( - self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: """Attempt to resolve a LoRA adapter using available resolvers. Args: @@ -244,8 +252,7 @@ class OpenAIServingModels: # Try to resolve using available resolvers for resolver in self.lora_resolvers: - lora_request = await resolver.resolve_lora( - base_model_name, lora_name) + lora_request = await resolver.resolve_lora(base_model_name, lora_name) if lora_request is not None: found_adapter = True @@ -256,33 +263,43 @@ class OpenAIServingModels: self.lora_requests[lora_name] = lora_request logger.info( "Resolved and loaded LoRA adapter '%s' using %s", - lora_name, resolver.__class__.__name__) + lora_name, + resolver.__class__.__name__, + ) return lora_request except BaseException as e: logger.warning( "Failed to load LoRA '%s' resolved by %s: %s. " - "Trying next resolver.", lora_name, - resolver.__class__.__name__, e) + "Trying next resolver.", + lora_name, + resolver.__class__.__name__, + e, + ) continue if found_adapter: # An adapter was found, but all attempts to load it failed. return create_error_response( - message=(f"LoRA adapter '{lora_name}' was found " - "but could not be loaded."), + message=( + f"LoRA adapter '{lora_name}' was found but could not be loaded." 
+ ), err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) else: # No adapter was found return create_error_response( message=f"LoRA adapter {lora_name} does not exist", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(error=ErrorInfo( - message=message, type=err_type, code=status_code.value)) + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> ErrorResponse: + return ErrorResponse( + error=ErrorInfo(message=message, type=err_type, code=status_code.value) + ) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 38745d001ade6..964655fb7f65e 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -13,16 +13,23 @@ import torch from fastapi import Request from typing_extensions import assert_never -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ErrorResponse, - PoolingChatRequest, - PoolingRequest, PoolingResponse, - PoolingResponseData, UsageInfo) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + IOProcessorRequest, + IOProcessorResponse, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingRequest, + PoolingResponse, + PoolingResponseData, + UsageInfo, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput @@ -48,30 +55,33 @@ def _get_data( class OpenAIServingPooling(OpenAIServing): - def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, + log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template async def create_pooling( self, request: PoolingRequest, raw_request: Optional[Request] = None, - ) -> Union[PoolingResponse, ErrorResponse]: + ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]: """ See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. 
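Editor's note: the next hunk wires the new `trust_request_chat_template` constructor flag into the `_validate_chat_template` helper added earlier in `serving_engine.py`, so a per-request `chat_template` (passed directly or via `chat_template_kwargs`) is refused unless the server was started with `--trust-request-chat-template`. As a quick reference, here is a standalone restatement of that gate as a boolean predicate; `chat_template_is_trusted` is a hypothetical name used only for illustration, and the real method returns an `ErrorResponse` or `None` rather than a bool.

```python
from typing import Any, Optional


def chat_template_is_trusted(
    request_chat_template: Optional[str],
    chat_template_kwargs: Optional[dict[str, Any]],
    trust_request_chat_template: bool,
) -> bool:
    """True if the request may proceed under the chat-template trust gate."""
    # A template counts as "supplied by the request" if it arrives either as
    # the top-level field or inside chat_template_kwargs.
    supplied = request_chat_template is not None or bool(
        chat_template_kwargs
        and chat_template_kwargs.get("chat_template") is not None
    )
    return trust_request_chat_template or not supplied


# Refused: the request ships its own template but the server did not opt in.
assert not chat_template_is_trusted("{{ messages }}", None, False)
# Accepted: kwargs without a template are fine even without trust.
assert chat_template_is_trusted(None, {"enable_thinking": True}, False)
# Accepted: --trust-request-chat-template makes per-request templates legal.
assert chat_template_is_trusted("{{ messages }}", None, True)
```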
@@ -80,56 +90,77 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is not None: return error_check_ret - encoding_format = request.encoding_format - if request.dimensions is not None: - return self.create_error_response( - "dimensions is currently not supported") + model_name = self.models.model_name() - model_name = self._get_model_name(request.model) request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - truncate_prompt_tokens = request.truncate_prompt_tokens - + is_io_processor_request = isinstance(request, IOProcessorRequest) try: - truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) lora_request = self._maybe_get_adapters(request) if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = await self.engine_client.get_tokenizer(lora_request - ) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) - if isinstance(request, PoolingChatRequest): + if getattr(request, "dimensions", None) is not None: + return self.create_error_response( + "dimensions is currently not supported" + ) + + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens + ) + + if is_io_processor_request: + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details." + ) + + validated_prompt = self.io_processor.parse_request(request) + + engine_prompts = await self.io_processor.pre_process_async( + prompt=validated_prompt, request_id=request_id + ) + + elif isinstance(request, PoolingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( _, - request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. 
- chat_template_content_format, + chat_template_content_format=self.chat_template_content_format, # In pooling requests, we are not generating tokens, # so there is no need to append extra tokens to the input add_generation_prompt=False, continue_final_message=False, - truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) + elif isinstance(request, PoolingCompletionRequest): + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.input, + config=self._build_render_config(request), + ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + raise ValueError(f"Unsupported request of type {type(request)}") except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -147,13 +178,18 @@ class OpenAIServingPooling(OpenAIServing): for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - request_prompts[i], - params=pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=lora_request, + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) generator = self.engine_client.encode( engine_prompt, @@ -171,6 +207,15 @@ class OpenAIServingPooling(OpenAIServing): result_generator = merge_async_iterators(*generators) + if is_io_processor_request: + assert self.io_processor is not None + output = await self.io_processor.post_process_async( + model_output=result_generator, + request_id=request_id, + ) + return self.io_processor.output_to_response(output) + + assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) num_prompts = len(engine_prompts) # Non-streaming response @@ -182,15 +227,14 @@ class OpenAIServingPooling(OpenAIServing): assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(list[PoolingRequestOutput], - final_res_batch) + final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch) response = self.request_output_to_pooling_response( final_res_batch_checked, request_id, created_time, model_name, - encoding_format, + request.encoding_format, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -233,3 +277,10 @@ class OpenAIServingPooling(OpenAIServing): data=items, usage=usage, ) + + def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 6b131bbb04d19..60f8b78ed1757 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -4,65 +4,99 @@ import asyncio import json import time +import uuid +from collections import deque from collections.abc import AsyncGenerator, AsyncIterator, Sequence from contextlib import AsyncExitStack from copy import copy from http import HTTPStatus -from 
typing import Any, Callable, Final, Optional, Union +from typing import Callable, Final, Optional, Union import jinja2 -import openai.types.responses as openai_responses_types from fastapi import Request -from openai import BaseModel -# yapf conflicts with isort for this block -# yapf: disable -from openai.types.responses import (ResponseCreatedEvent, - ResponseFunctionToolCall, - ResponseInProgressEvent, - ResponseOutputItem, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, ResponseOutputText, - ResponseReasoningItem, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent) -from openai.types.responses.response_output_text import (Logprob, - LogprobTopLogprob) -# yapf: enable +from openai.types.responses import ( + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallCompletedEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterToolCallParam, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseFunctionToolCall, + ResponseFunctionWebSearch, + ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseStatus, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, + response_function_web_search, + response_text_delta_event, +) +from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) from openai_harmony import Message as OpenAIHarmonyMessage from vllm import envs -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption) -from vllm.entrypoints.context import (ConversationContext, HarmonyContext, - SimpleContext, StreamingHarmonyContext) +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.context import ( + ConversationContext, + HarmonyContext, + SimpleContext, + StreamingHarmonyContext, +) from vllm.entrypoints.harmony_utils import ( - get_developer_message, get_stop_tokens_for_assistant_actions, - get_system_message, get_user_message, parse_output_message, - parse_remaining_state, parse_response_input, render_for_completion) + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_system_message, + get_user_message, + has_custom_tools, + parse_output_message, + parse_remaining_state, + parse_response_input, + render_for_completion, +) from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (ErrorResponse, - InputTokensDetails, - OutputTokensDetails, - RequestResponseMetadata, - ResponsesRequest, - ResponsesResponse, ResponseUsage) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, + ErrorResponse, + InputTokensDetails, + OutputTokensDetails, + RequestResponseMetadata, + ResponseCompletedEvent, + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseReasoningPartAddedEvent, + 
ResponseReasoningPartDoneEvent, + ResponsesRequest, + ResponsesResponse, + ResponseUsage, + StreamingResponsesResponse, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.tool_server import MCPToolServer, ToolServer +from vllm.entrypoints.tool_server import ToolServer from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob as SampleLogprob +from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput -from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob as SampleLogprob -from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -70,11 +104,9 @@ logger = init_logger(__name__) class OpenAIServingResponses(OpenAIServing): - def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], @@ -88,41 +120,35 @@ class OpenAIServingResponses(OpenAIServing): enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format self.enable_log_outputs = enable_log_outputs - self.reasoning_parser: Optional[Callable[[AnyTokenizer], - ReasoningParser]] = None - if reasoning_parser: - try: - self.reasoning_parser = ( - ReasoningParserManager.get_reasoning_parser( - reasoning_parser)) - assert self.reasoning_parser is not None - except Exception as e: - raise TypeError( - f"{reasoning_parser=} has not been registered") from e - + self.reasoning_parser = self._get_reasoning_parser( + reasoning_parser_name=reasoning_parser + ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) # If False (default), the "store" option is (silently) ignored and the # response is not stored. If True, the response is stored in memory. @@ -134,26 +160,31 @@ class OpenAIServingResponses(OpenAIServing): logger.warning_once( "`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may " "cause a memory leak since we never remove responses from " - "the store.") + "the store." 
+ ) - self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: - logger.warning("For gpt-oss, we ignore --enable-auto-tool-choice " - "and always enable tool use.") + logger.warning( + "For gpt-oss, we ignore --enable-auto-tool-choice " + "and always enable tool use." + ) # OpenAI models have two EOS-like tokens: <|return|> and <|call|>. # We need to add them to the stop token ids. if "stop_token_ids" not in self.default_sampling_params: self.default_sampling_params["stop_token_ids"] = [] self.default_sampling_params["stop_token_ids"].extend( - get_stop_tokens_for_assistant_actions()) + get_stop_tokens_for_assistant_actions() + ) # set up tool use self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( - "\"auto\" tool choice has been enabled please note that while" + '"auto" tool choice has been enabled please note that while' " the parallel_tool_calls client option is preset for " - "compatibility reasons, it will be ignored.") + "compatibility reasons, it will be ignored." + ) # HACK(woosuk): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we @@ -166,15 +197,44 @@ class OpenAIServingResponses(OpenAIServing): # never remove messages from the store. self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} + # HACK(wuhang): This is a hack. We should use a better store. + # FIXME: If enable_store=True, this may cause a memory leak since we + # never remove events from the store. + self.event_store: dict[ + str, tuple[deque[StreamingResponsesResponse], asyncio.Event] + ] = {} + self.background_tasks: dict[str, asyncio.Task] = {} self.tool_server = tool_server + def _validate_generator_input( + self, engine_prompt: EngineTokensPrompt + ) -> Optional[ErrorResponse]: + """Add validations to the input to the generator here.""" + if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): + error_message = ( + "The engine prompt length" + f" {len(engine_prompt['prompt_token_ids'])} " + f"exceeds the max_model_len {self.max_model_len}. " + "Please reduce prompt." + ) + return self.create_error_response( + err_type="invalid_request_error", + message=error_message, + status_code=HTTPStatus.BAD_REQUEST, + ) + return None + async def create_responses( self, request: ResponsesRequest, raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]: + ) -> Union[ + AsyncGenerator[StreamingResponsesResponse, None], + ResponsesResponse, + ErrorResponse, + ]: error_check_ret = await self._check_model(request) if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) @@ -195,7 +255,8 @@ class OpenAIServingResponses(OpenAIServing): "therefore does not support the background mode. To " "enable these features, set the environment variable " "`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching " - "the vLLM server."), + "the vLLM server." + ), status_code=HTTPStatus.BAD_REQUEST, ) # Disable the store option. @@ -215,8 +276,6 @@ class OpenAIServingResponses(OpenAIServing): # Handle the previous response ID. 
prev_response_id = request.previous_response_id if prev_response_id is not None: - if not prev_response_id.startswith("resp_"): - return self._make_invalid_id_error(prev_response_id) async with self.response_store_lock: prev_response = self.response_store.get(prev_response_id) if prev_response is None: @@ -226,36 +285,32 @@ class OpenAIServingResponses(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - model_name = self._get_model_name(request.model, lora_request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + model_name = self.models.model_name(lora_request) + tokenizer = await self.engine_client.get_tokenizer() if self.use_harmony: messages, request_prompts, engine_prompts = ( - self._make_request_with_harmony(request, prev_response)) + self._make_request_with_harmony(request, prev_response) + ) else: - messages, request_prompts, engine_prompts = ( - await self._make_request(request, prev_response, - tokenizer)) + messages, request_prompts, engine_prompts = await self._make_request( + request, prev_response, tokenizer + ) - except (ValueError, TypeError, RuntimeError, jinja2.TemplateError, - NotImplementedError) as e: + except ( + ValueError, + TypeError, + RuntimeError, + jinja2.TemplateError, + NotImplementedError, + ) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_metadata = RequestResponseMetadata( - request_id=request.request_id) + request_metadata = RequestResponseMetadata(request_id=request.request_id) if raw_request: raw_request.state.request_metadata = request_metadata - if self.tool_server is not None and isinstance( - self.tool_server, MCPToolServer - ) and (request.background or request.stream) and request.tools and any( - tool.type in ["web_search_preview", "code_interpreter"] - for tool in request.tools): - return self.create_error_response( - "MCP tool server is not supported in background mode and " - "streaming mode") - # Schedule the request and get the result generator. generators: list[AsyncGenerator[ConversationContext, None]] = [] @@ -265,80 +320,94 @@ class OpenAIServingResponses(OpenAIServing): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): builtin_tool_list.append("python") - async with AsyncExitStack() as exit_stack: - try: - if self.tool_server is not None: - # TODO: initialize tool sessions lazily when the session - # is actually used. 
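# A minimal, self-contained sketch of the prompt-length guard this patch adds
# (_validate_generator_input): requests whose tokenized prompt already fills the
# model context window are rejected before scheduling. The plain dict error
# payload below is an assumption for illustration only, not the vLLM type.
from typing import Optional


def validate_prompt_length(prompt_token_ids: list[int], max_model_len: int) -> Optional[dict]:
    """Return an error payload if the prompt cannot fit in the context window."""
    if max_model_len <= len(prompt_token_ids):
        return {
            "type": "invalid_request_error",
            "message": (
                f"The engine prompt length {len(prompt_token_ids)} exceeds "
                f"the max_model_len {max_model_len}. Please reduce the prompt."
            ),
        }
    return None  # prompt fits; the caller may proceed to build sampling params


assert validate_prompt_length([1, 2, 3], max_model_len=8) is None
assert validate_prompt_length([1, 2, 3], max_model_len=3) is not None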
- tool_session_ctxs: dict[str, Any] = { - tool_name: - exit_stack.enter_async_context( - self.tool_server.new_session(tool_name)) - for tool_name in builtin_tool_list - } - tool_sessions = {} - for tool_name in builtin_tool_list: - tool_sessions[tool_name] = ( - await tool_session_ctxs[tool_name]) - else: - assert len(builtin_tool_list) == 0 - tool_sessions = {} - for i, engine_prompt in enumerate(engine_prompts): - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) + if self.tool_server.has_tool("container"): + builtin_tool_list.append("container") - trace_headers = (None if raw_request is None else await - self._get_trace_headers( - raw_request.headers)) + if self.tool_server is not None: + available_tools = builtin_tool_list + else: + assert len(builtin_tool_list) == 0 + available_tools = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + maybe_error = self._validate_generator_input(engine_prompt) + if maybe_error is not None: + return maybe_error - context: ConversationContext - if self.use_harmony: - if request.stream: - context = StreamingHarmonyContext( - messages, tool_sessions) - else: - context = HarmonyContext(messages, tool_sessions) - else: - context = SimpleContext() - generator = self._generate_with_builtin_tools( - request_id=request.request_id, - request_prompt=request_prompts[i], - engine_prompt=engine_prompt, - sampling_params=sampling_params, - context=context, - lora_request=lora_request, - priority=request.priority, - trace_headers=trace_headers, - ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - assert len(generators) == 1 - result_generator, = generators - - # Store the input messages. - if request.store: - self.msg_store[request.request_id] = messages - - if request.background: - created_time = int(time.time()) - response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="queued", - usage=None, + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"] ) - async with self.response_store_lock: - self.response_store[response.id] = response - # Run the request in the background. + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params + ) + + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext(messages, available_tools) + else: + context = HarmonyContext(messages, available_tools) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + request_prompt=request_prompts[i], + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, + lora_request=lora_request, + priority=request.priority, + trace_headers=trace_headers, + ) + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + (result_generator,) = generators + + # Store the input messages. 
+ if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response + + # Run the request in the background. + if request.stream: + task = asyncio.create_task( + self._run_background_request_stream( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{request.request_id}", + ) + else: task = asyncio.create_task( self._run_background_request( request, @@ -353,37 +422,40 @@ class OpenAIServingResponses(OpenAIServing): name=f"create_{response.id}", ) - # For cleanup. - response_id = response.id - self.background_tasks[response_id] = task - task.add_done_callback( - lambda _: self.background_tasks.pop(response_id, None)) - return response + # For cleanup. + response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None) + ) if request.stream: - return self.responses_stream_generator( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - ) + return self.responses_background_stream_generator(request.request_id) + return response - try: - return await self.responses_full_generator( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - ) - except Exception as e: - return self.create_error_response(str(e)) - return self.create_error_response("Should not reach here") + if request.stream: + return self.responses_stream_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + + try: + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + except Exception as e: + return self.create_error_response(str(e)) async def _make_request( self, @@ -393,7 +465,8 @@ class OpenAIServingResponses(OpenAIServing): ): if len(request.tools) > 0: raise NotImplementedError( - "Tool use is not supported in Responses API without Harmony") + "Tool use is not supported in Responses API without Harmony" + ) # Construct the input messages. 
        messages = self._construct_input_messages(request, prev_response)
        _, request_prompts, engine_prompts = await self._preprocess_chat(
@@ -412,10 +485,9 @@
    ):
        if request.tool_choice != "auto":
            raise NotImplementedError(
-                "Only 'auto' tool_choice is supported in "
-                "response API with Harmony")
-        messages = self._construct_input_messages_with_harmony(
-            request, prev_response)
+                "Only 'auto' tool_choice is supported in response API with Harmony"
+            )
+        messages = self._construct_input_messages_with_harmony(request, prev_response)

        prompt_token_ids = render_for_completion(messages)
        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
@@ -425,6 +497,22 @@

        return messages, [prompt_token_ids], [engine_prompt]

+    async def _initialize_tool_sessions(
+        self,
+        request: ResponsesRequest,
+        context: ConversationContext,
+        exit_stack: AsyncExitStack,
+    ):
+        # we should only initialize the tool session if the request needs tools
+        if len(request.tools) == 0:
+            return
+        mcp_tools = {
+            tool.server_label: tool for tool in request.tools if tool.type == "mcp"
+        }
+        await context.init_tool_sessions(
+            self.tool_server, exit_stack, request.request_id, mcp_tools
+        )
+
    async def responses_full_generator(
        self,
        request: ResponsesRequest,
@@ -439,23 +527,38 @@
        if created_time is None:
            created_time = int(time.time())

-        try:
-            async for _ in result_generator:
-                pass
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+        async with AsyncExitStack() as exit_stack:
+            try:
+                await self._initialize_tool_sessions(request, context, exit_stack)
+                async for _ in result_generator:
+                    pass
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
+        # NOTE: Implementation of status is still WIP, but for now
+        # we guarantee that if the status is not "completed", it is accurate.
+        # "completed" is implemented as the "catch-all" for now.
+        status: ResponseStatus = "completed"
+
+        input_messages = None
+        output_messages = None

        if self.use_harmony:
            assert isinstance(context, HarmonyContext)
            output = self._make_response_output_items_with_harmony(context)
-            # TODO: these are all 0 for now!
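# A simplified sketch of the AsyncExitStack pattern this patch wraps around
# generation: tool sessions entered on the stack are closed automatically when
# the block exits, even if generation raises. `open_session` and the dict
# "session" are stand-ins for illustration, not vLLM or MCP APIs.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def open_session(tool_name: str):
    session = {"tool": tool_name, "open": True}  # placeholder for a real session
    try:
        yield session
    finally:
        session["open"] = False  # always reached via the exit stack


async def generate_with_tools(tool_names: list[str]) -> None:
    async with AsyncExitStack() as stack:
        sessions = {
            name: await stack.enter_async_context(open_session(name))
            for name in tool_names
        }
        # ... run generation with `sessions` attached to the conversation context ...
        assert all(s["open"] for s in sessions.values())
    # here every session has been closed, in reverse order of entry


asyncio.run(generate_with_tools(["browser", "python"]))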
- num_prompt_tokens = context.num_prompt_tokens - num_generated_tokens = context.num_output_tokens - num_cached_tokens = context.num_cached_tokens - num_reasoning_tokens = context.num_reasoning_tokens + if request.enable_response_messages: + input_messages = context.messages[: context.num_init_messages] + output_messages = context.messages[context.num_init_messages :] + num_tool_output_tokens = context.num_tool_output_tokens + if len(output) > 0: + if context.finish_reason == "length": + status = "incomplete" + elif context.finish_reason == "abort": + status = "cancelled" + else: + status = "incomplete" else: assert isinstance(context, SimpleContext) final_res = context.last_output @@ -463,32 +566,43 @@ class OpenAIServingResponses(OpenAIServing): assert len(final_res.outputs) == 1 final_output = final_res.outputs[0] - output = self._make_response_output_items(request, final_output, - tokenizer) + output = self._make_response_output_items(request, final_output, tokenizer) + # TODO: context for non-gptoss models doesn't use messages + # so we can't get them out yet + if request.enable_response_messages: + raise NotImplementedError( + "enable_response_messages is currently only supported for gpt-oss" + ) # Calculate usage. assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = len(final_output.token_ids) - num_cached_tokens = final_res.num_cached_tokens - num_reasoning_tokens = 0 + num_tool_output_tokens = 0 + + assert isinstance(context, (SimpleContext, HarmonyContext)) + num_prompt_tokens = context.num_prompt_tokens + num_generated_tokens = context.num_output_tokens + num_cached_tokens = context.num_cached_tokens + num_reasoning_tokens = context.num_reasoning_tokens usage = ResponseUsage( input_tokens=num_prompt_tokens, output_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, - input_tokens_details=InputTokensDetails( - cached_tokens=num_cached_tokens), + input_tokens_details=InputTokensDetails(cached_tokens=num_cached_tokens), output_tokens_details=OutputTokensDetails( - reasoning_tokens=num_reasoning_tokens), + reasoning_tokens=num_reasoning_tokens, + tool_output_tokens=num_tool_output_tokens, + ), ) response = ResponsesResponse.from_request( request, sampling_params, + input_messages=input_messages, + output_messages=output_messages, model_name=model_name, created_time=created_time, output=output, - status="completed", + status=status, usage=usage, ) @@ -496,56 +610,96 @@ class OpenAIServingResponses(OpenAIServing): async with self.response_store_lock: stored_response = self.response_store.get(response.id) # If the response is already cancelled, don't update it. 
- if (stored_response is None - or stored_response.status != "cancelled"): + if stored_response is None or stored_response.status != "cancelled": self.response_store[response.id] = response return response - def _topk_logprobs(self, logprobs: dict[int, - SampleLogprob], top_logprobs: int, - tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]: + def _topk_logprobs( + self, + logprobs: dict[int, SampleLogprob], + top_logprobs: int, + tokenizer: AnyTokenizer, + ) -> list[LogprobTopLogprob]: """Returns the top-k logprobs from the logprobs dictionary.""" out = [] for i, (token_id, _logprob) in enumerate(logprobs.items()): if i >= top_logprobs: break - text = _logprob.decoded_token if _logprob.decoded_token \ - is not None else tokenizer.decode([token_id]) + text = ( + _logprob.decoded_token + if _logprob.decoded_token is not None + else tokenizer.decode([token_id]) + ) out.append( LogprobTopLogprob( token=text, logprob=max(_logprob.logprob, -9999.0), bytes=list(text.encode("utf-8", errors="replace")), - )) + ) + ) return out def _create_response_logprobs( - self, - token_ids: Sequence[int], - logprobs: Optional[SampleLogprobs], - tokenizer: AnyTokenizer, - top_logprobs: Optional[int] = None) -> list[Logprob]: + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None, + ) -> list[Logprob]: assert logprobs is not None, "logprobs must be provided" assert len(token_ids) == len(logprobs), ( - "token_ids and logprobs.token_ids must have the same length") + "token_ids and logprobs.token_ids must have the same length" + ) out = [] for i, token_id in enumerate(token_ids): logprob = logprobs[i] token_logprob = logprob[token_id] - text = token_logprob.decoded_token if token_logprob.decoded_token \ - is not None else tokenizer.decode([token_id]) + text = ( + token_logprob.decoded_token + if token_logprob.decoded_token is not None + else tokenizer.decode([token_id]) + ) out.append( Logprob( token=text, logprob=max(token_logprob.logprob, -9999.0), bytes=list(text.encode("utf-8", errors="replace")), - top_logprobs=self._topk_logprobs(logprob, - top_logprobs=top_logprobs, - tokenizer=tokenizer) - if top_logprobs else [], - )) + top_logprobs=self._topk_logprobs( + logprob, top_logprobs=top_logprobs, tokenizer=tokenizer + ) + if top_logprobs + else [], + ) + ) return out + def _create_stream_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None, + ) -> list[response_text_delta_event.Logprob]: + lgs = self._create_response_logprobs( + token_ids=token_ids, + logprobs=logprobs, + tokenizer=tokenizer, + top_logprobs=top_logprobs, + ) + return [ + response_text_delta_event.Logprob( + token=lg.token, + logprob=lg.logprob, + top_logprobs=[ + response_text_delta_event.LogprobTopLogprob( + token=tl.token, logprob=tl.logprob + ) + for tl in lg.top_logprobs + ], + ) + for lg in lgs + ] + def _make_response_output_items( self, request: ResponsesRequest, @@ -559,9 +713,9 @@ class OpenAIServingResponses(OpenAIServing): logger.exception("Error in reasoning parser creation.") raise e - reasoning_content, content = ( - reasoning_parser.extract_reasoning_content(final_output.text, - request=request)) + reasoning_content, content = reasoning_parser.extract_reasoning_content( + final_output.text, request=request + ) else: reasoning_content = None content = final_output.text @@ -591,8 +745,9 @@ class OpenAIServingResponses(OpenAIServing): summary=[], 
type="reasoning", content=[ - ResponseReasoningTextContent(text=reasoning_content, - type="reasoning_text") + ResponseReasoningTextContent( + text=reasoning_content, type="reasoning_text" + ) ], status=None, # NOTE: Only the last output item has status. ) @@ -607,7 +762,9 @@ class OpenAIServingResponses(OpenAIServing): logprobs=final_output.logprobs, tokenizer=tokenizer, top_logprobs=request.top_logprobs, - ) if request.is_include_output_logprobs() else None, + ) + if request.is_include_output_logprobs() + else None, ) message = ResponseOutputMessage( id=f"msg_{random_uuid()}", @@ -623,7 +780,7 @@ class OpenAIServingResponses(OpenAIServing): self, context: HarmonyContext, ) -> list[ResponseOutputItem]: - output_items = [] + output_items: list[ResponseOutputItem] = [] num_init_messages = context.num_init_messages for msg in context.messages[num_init_messages:]: output_items.extend(parse_output_message(msg)) @@ -640,10 +797,12 @@ class OpenAIServingResponses(OpenAIServing): ) -> list[ChatCompletionMessageParam]: messages: list[ChatCompletionMessageParam] = [] if request.instructions: - messages.append({ - "role": "system", - "content": request.instructions, - }) + messages.append( + { + "role": "system", + "content": request.instructions, + } + ) # Prepend the conversation history. if prev_response is not None: @@ -656,10 +815,12 @@ class OpenAIServingResponses(OpenAIServing): # NOTE: We skip the reasoning output. if isinstance(output_item, ResponseOutputMessage): for content in output_item.content: - messages.append({ - "role": "assistant", - "content": content.text, - }) + messages.append( + { + "role": "assistant", + "content": content.text, + } + ) # Append the new input. # Responses API supports simple text inputs without chat format. @@ -677,28 +838,55 @@ class OpenAIServingResponses(OpenAIServing): messages: list[OpenAIHarmonyMessage] = [] if prev_response is None: # New conversation. 
- reasoning_effort = (request.reasoning.effort - if request.reasoning else None) + reasoning_effort = request.reasoning.effort if request.reasoning else None tool_types = [tool.type for tool in request.tools] - enable_browser = ("web_search_preview" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("browser")) - enable_code_interpreter = ("code_interpreter" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("python")) + + # Allow the MCP Tool type to enable built in tools if the + # server_label is allowlisted in + # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS + if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS: + for tool in request.tools: + if ( + tool.type == "mcp" + and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS + ): + tool_types.append(tool.server_label) + enable_browser = ( + "web_search_preview" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("browser") + ) + enable_code_interpreter = ( + "code_interpreter" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("python") + ) + enable_container = ( + "container" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("container") + ) + with_custom_tools = has_custom_tools(tool_types) sys_msg = get_system_message( reasoning_effort=reasoning_effort, - browser_description=self.tool_server.get_tool_description( - "browser") - if enable_browser and self.tool_server is not None else None, - python_description=self.tool_server.get_tool_description( - "python") if enable_code_interpreter - and self.tool_server is not None else None, + browser_description=self.tool_server.get_tool_description("browser") + if enable_browser and self.tool_server is not None + else None, + python_description=self.tool_server.get_tool_description("python") + if enable_code_interpreter and self.tool_server is not None + else None, + container_description=self.tool_server.get_tool_description("container") + if enable_container and self.tool_server is not None + else None, + instructions=request.instructions, + with_custom_tools=with_custom_tools, ) messages.append(sys_msg) - dev_msg = get_developer_message(request.instructions, - request.tools) - messages.append(dev_msg) + if with_custom_tools: + dev_msg = get_developer_message( + instructions=request.instructions, tools=request.tools + ) + messages.append(dev_msg) else: # Continue the previous conversation. # FIXME(woosuk): Currently, request params like reasoning and @@ -718,15 +906,15 @@ class OpenAIServingResponses(OpenAIServing): if prev_msg_i.channel == "final": prev_final_msg_idx = i break - recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1:] - del prev_msgs[prev_final_msg_idx + 1:] + recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1 :] + del prev_msgs[prev_final_msg_idx + 1 :] for msg in recent_turn_msgs: assert isinstance(msg, OpenAIHarmonyMessage) if msg.channel != "analysis": prev_msgs.append(msg) messages.extend(prev_msgs) # Append the new input. - # Reponses API supports simple text inputs without chat format. + # Responses API supports simple text inputs without chat format. if isinstance(request.input, str): messages.append(get_user_message(request.input)) else: @@ -735,9 +923,8 @@ class OpenAIServingResponses(OpenAIServing): else: prev_outputs = [] for response_msg in request.input: - messages.append( - parse_response_input(response_msg, prev_outputs)) - # User passes in a a tool call request and its output. 
We need + messages.append(parse_response_input(response_msg, prev_outputs)) + # User passes in a tool call request and its output. We need # to add the tool call request to prev_outputs so that the # parse_response_input can find the tool call request when # parsing the tool call output. @@ -745,6 +932,36 @@ class OpenAIServingResponses(OpenAIServing): prev_outputs.append(response_msg) return messages + async def _run_background_request_stream( + self, + request: ResponsesRequest, + *args, + **kwargs, + ): + event_deque: deque[StreamingResponsesResponse] = deque() + new_event_signal = asyncio.Event() + self.event_store[request.request_id] = (event_deque, new_event_signal) + response = None + try: + generator = self.responses_stream_generator(request, *args, **kwargs) + async for event in generator: + event_deque.append(event) + new_event_signal.set() # Signal new event available + except Exception as e: + logger.exception("Background request failed for %s", request.request_id) + response = self.create_error_response(str(e)) + finally: + new_event_signal.set() + + if response is not None and isinstance(response, ErrorResponse): + # If the request has failed, update the status to "failed". + response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + async def _run_background_request( self, request: ResponsesRequest, @@ -752,11 +969,9 @@ class OpenAIServingResponses(OpenAIServing): **kwargs, ): try: - response = await self.responses_full_generator( - request, *args, **kwargs) + response = await self.responses_full_generator(request, *args, **kwargs) except Exception as e: - logger.exception("Background request failed for %s", - request.request_id) + logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) if isinstance(response, ErrorResponse): @@ -768,27 +983,58 @@ class OpenAIServingResponses(OpenAIServing): if stored_response.status not in ("completed", "cancelled"): stored_response.status = "failed" + async def responses_background_stream_generator( + self, + response_id: str, + starting_after: Optional[int] = None, + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + if response_id not in self.event_store: + raise ValueError(f"Unknown response_id: {response_id}") + + event_deque, new_event_signal = self.event_store[response_id] + start_index = 0 if starting_after is None else starting_after + 1 + current_index = start_index + + while True: + new_event_signal.clear() + + # Yield existing events from start_index + while current_index < len(event_deque): + event = event_deque[current_index] + yield event + if getattr(event, "type", "unknown") == "response.completed": + return + current_index += 1 + + await new_event_signal.wait() + async def retrieve_responses( self, response_id: str, - ) -> Union[ErrorResponse, ResponsesResponse]: - if not response_id.startswith("resp_"): - return self._make_invalid_id_error(response_id) - + starting_after: Optional[int], + stream: Optional[bool], + ) -> Union[ + ErrorResponse, + ResponsesResponse, + AsyncGenerator[StreamingResponsesResponse, None], + ]: async with self.response_store_lock: response = self.response_store.get(response_id) if response is None: return self._make_not_found_error(response_id) + + if stream: + return self.responses_background_stream_generator( + response_id, 
+ starting_after, + ) return response async def cancel_responses( self, response_id: str, ) -> Union[ErrorResponse, ResponsesResponse]: - if not response_id.startswith("resp_"): - return self._make_invalid_id_error(response_id) - async with self.response_store_lock: response = self.response_store.get(response_id) if response is None: @@ -805,22 +1051,14 @@ class OpenAIServingResponses(OpenAIServing): response.status = "cancelled" # Abort the request. - if (task := self.background_tasks.get(response_id)): + if task := self.background_tasks.get(response_id): task.cancel() try: await task except asyncio.CancelledError: - logger.exception("Background task for %s was cancelled", - response_id) + logger.exception("Background task for %s was cancelled", response_id) return response - def _make_invalid_id_error(self, response_id: str) -> ErrorResponse: - return self.create_error_response( - err_type="invalid_request_error", - message=(f"Invalid 'response_id': '{response_id}'. " - "Expected an ID that begins with 'resp'."), - ) - def _make_not_found_error(self, response_id: str) -> ErrorResponse: return self.create_error_response( err_type="invalid_request_error", @@ -831,14 +1069,16 @@ class OpenAIServingResponses(OpenAIServing): def _make_store_not_supported_error(self) -> ErrorResponse: return self.create_error_response( err_type="invalid_request_error", - message=("`store=True` (default) is not supported. Please set " - "`store=False` in Responses API or set " - "`VLLM_ENABLE_RESPONSES_API_STORE=1` in the env var when " - "starting the vLLM server."), + message=( + "`store=True` (default) is not supported. Please set " + "`store=False` in Responses API or set " + "`VLLM_ENABLE_RESPONSES_API_STORE=1` in the env var when " + "starting the vLLM server." + ), status_code=HTTPStatus.BAD_REQUEST, ) - async def responses_stream_generator( + async def _process_simple_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, @@ -847,60 +1087,320 @@ class OpenAIServingResponses(OpenAIServing): model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - created_time: Optional[int] = None, - ) -> AsyncGenerator[str, None]: - # TODO: - # 1. Handle disconnect - - if not isinstance(context, StreamingHarmonyContext): - raise NotImplementedError( - "Streaming is not supported for responses API without Harmony." 
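# A minimal sketch of the background streaming store used above: the producer
# appends each event to a deque and sets an asyncio.Event, and a consumer can
# replay the deque from any `starting_after` offset, then block until more
# events arrive. The string events and the "response.completed" terminator
# stand in for the real StreamingResponsesResponse objects.
import asyncio
from collections import deque


async def produce(events: deque, signal: asyncio.Event) -> None:
    for i in range(3):
        events.append(f"response.output_text.delta:{i}")
        signal.set()  # wake any consumer waiting for new events
        await asyncio.sleep(0)
    events.append("response.completed")
    signal.set()


async def replay(events: deque, signal: asyncio.Event, starting_after: int = -1):
    index = starting_after + 1
    while True:
        signal.clear()
        while index < len(events):
            event = events[index]
            yield event
            if event == "response.completed":
                return
            index += 1
        await signal.wait()  # block until the producer appends more events


async def main() -> None:
    events: deque = deque()
    signal = asyncio.Event()
    producer = asyncio.create_task(produce(events, signal))
    async for event in replay(events, signal):
        print(event)
    await producer


asyncio.run(main())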
- ) - - created_time = created_time or int(time.time()) - - sequence_number = 0 - - def _send_event(event: BaseModel): - nonlocal sequence_number - # Set sequence_number if the event has this attribute - if hasattr(event, 'sequence_number'): - event.sequence_number = sequence_number - sequence_number += 1 - # Get event type from the event's type field if it exists - event_type = getattr(event, 'type', 'unknown') - return (f"event: {event_type}\n" - f"data: {event.model_dump_json(indent=None)}\n\n") - - current_content_index = 0 # FIXME: this number is never changed + created_time: int, + _increment_sequence_number_and_return: Callable[ + [StreamingResponsesResponse], StreamingResponsesResponse + ], + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + current_content_index = 0 current_output_index = 0 - current_item_id = "" # FIXME: this number is never changed + current_item_id = "" + reasoning_parser = None + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser(tokenizer) + previous_text = "" + previous_token_ids: list[int] = [] + first_delta_sent = False + previous_delta_messages: list[DeltaMessage] = [] + async for ctx in result_generator: + assert isinstance(ctx, SimpleContext) + if ctx.last_output is None: + continue + if ctx.last_output.outputs: + output = ctx.last_output.outputs[0] + if reasoning_parser: + delta_message = ( + reasoning_parser.extract_reasoning_content_streaming( + previous_text=previous_text, + current_text=previous_text + output.text, + delta_text=output.text, + previous_token_ids=previous_token_ids, + current_token_ids=previous_token_ids + output.token_ids, + delta_token_ids=output.token_ids, + ) + ) + else: + delta_message = DeltaMessage( + content=output.text, + ) + previous_text += output.text + previous_token_ids += output.token_ids + if not delta_message: + continue + if not first_delta_sent: + current_item_id = str(uuid.uuid4()) + if delta_message.reasoning_content: + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + ) + ) + else: + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + ) + ) + yield _increment_sequence_number_and_return( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + current_content_index += 1 + first_delta_sent = True + # todo(kebe7jun) tool call support + + # check delta message and previous delta message are + # same as content or reasoning content + if ( + previous_delta_messages + and previous_delta_messages[-1].reasoning_content is not None + and delta_message.content is not None + ): + # from reasoning to normal content, send done + # event for reasoning + reason_content = "".join( + pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None + ) + yield _increment_sequence_number_and_return( + ResponseReasoningTextDoneEvent( + 
type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + ) + ) + current_content_index = 0 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + ) + ) + current_output_index += 1 + current_item_id = str(uuid.uuid4()) + yield _increment_sequence_number_and_return( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + current_content_index += 1 + # reset previous delta messages + previous_delta_messages = [] + + if delta_message.reasoning_content is not None: + yield _increment_sequence_number_and_return( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.reasoning_content, + ) + ) + elif delta_message.content is not None: + yield _increment_sequence_number_and_return( + ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.content, + logprobs=self._create_stream_response_logprobs( + token_ids=output.token_ids, + logprobs=output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) + if request.is_include_output_logprobs() + else [], + ) + ) + current_content_index += 1 + + previous_delta_messages.append(delta_message) + if previous_delta_messages: + if previous_delta_messages[-1].reasoning_content is not None: + reason_content = "".join( + pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None + ) + yield _increment_sequence_number_and_return( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + ) + ) + current_content_index += 1 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + ) + ) + elif previous_delta_messages[-1].content is not None: + final_content = "".join( + pm.content + for pm in previous_delta_messages + if pm.content is 
not None + ) + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=final_content, + logprobs=[], + item_id=current_item_id, + ) + ) + current_content_index += 1 + part = ResponseOutputText( + text=final_content, + type="output_text", + annotations=[], + ) + yield _increment_sequence_number_and_return( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=part, + ) + ) + current_content_index += 1 + item = ResponseOutputMessage( + type="message", + role="assistant", + content=[ + part, + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=item, + ) + ) + + async def _process_harmony_streaming_events( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: int, + _increment_sequence_number_and_return: Callable[ + [StreamingResponsesResponse], StreamingResponsesResponse + ], + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + current_content_index = -1 + current_output_index = 0 + current_item_id: str = "" sent_output_item_added = False - initial_response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="in_progress", - usage=None, - ).model_dump() - yield _send_event( - ResponseCreatedEvent( - type="response.created", - sequence_number=-1, - response=initial_response, - )) - yield _send_event( - ResponseInProgressEvent( - type="response.in_progress", - sequence_number=-1, - response=initial_response, - )) - async for ctx in result_generator: - assert isinstance(ctx, StreamingHarmonyContext) if ctx.is_expecting_start(): @@ -913,19 +1413,18 @@ class OpenAIServingResponses(OpenAIServing): # Deal with tool call here pass elif previous_item.channel == "analysis": + content = ResponseReasoningTextContent( + text=previous_item.content[0].text, + type="reasoning_text", + ) reasoning_item = ResponseReasoningItem( type="reasoning", - content=[ - ResponseReasoningTextContent( - text=previous_item.content[0].text, - type="reasoning_text", - ), - ], + content=[content], status="completed", id=current_item_id, summary=[], ) - yield _send_event( + yield _increment_sequence_number_and_return( ResponseReasoningTextDoneEvent( type="response.reasoning_text.done", item_id=current_item_id, @@ -933,22 +1432,34 @@ class OpenAIServingResponses(OpenAIServing): output_index=current_output_index, content_index=current_content_index, text=previous_item.content[0].text, - )) - yield _send_event( + ) + ) + yield _increment_sequence_number_and_return( + ResponseReasoningPartDoneEvent( + type="response.reasoning_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=content, + ) + ) + yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, 
output_index=current_output_index, item=reasoning_item, - )) + ) + ) elif previous_item.channel == "final": text_content = ResponseOutputText( type="output_text", text=previous_item.content[0].text, annotations=[], ) - yield _send_event( - openai_responses_types.ResponseTextDoneEvent( + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( type="response.output_text.done", sequence_number=-1, output_index=current_output_index, @@ -956,9 +1467,9 @@ class OpenAIServingResponses(OpenAIServing): text=previous_item.content[0].text, logprobs=[], item_id=current_item_id, - )) - yield _send_event( - openai_responses_types. + ) + ) + yield _increment_sequence_number_and_return( ResponseContentPartDoneEvent( type="response.content_part.done", sequence_number=-1, @@ -966,9 +1477,10 @@ class OpenAIServingResponses(OpenAIServing): output_index=current_output_index, content_index=current_content_index, part=text_content, - )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, @@ -979,45 +1491,50 @@ class OpenAIServingResponses(OpenAIServing): content=[text_content], status="completed", ), - )) + ) + ) + # stream the output of a harmony message if ctx.parser.last_content_delta: - if (ctx.parser.current_channel == "final" - and ctx.parser.current_recipient is None): + if ( + ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None + ): if not sent_output_item_added: sent_output_item_added = True - yield _send_event( - openai_responses_types. + current_item_id = f"msg_{random_uuid()}" + yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseOutputMessage( + item=ResponseOutputMessage( id=current_item_id, type="message", role="assistant", content=[], status="in_progress", ), - )) - yield _send_event( - openai_responses_types. + ) + ) + current_content_index += 1 + yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( type="response.content_part.added", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, content_index=current_content_index, - part=openai_responses_types.ResponseOutputText( + part=ResponseOutputText( type="output_text", text="", annotations=[], logprobs=[], ), - )) - yield _send_event( - openai_responses_types.ResponseTextDeltaEvent( + ) + ) + yield _increment_sequence_number_and_return( + ResponseTextDeltaEvent( type="response.output_text.delta", sequence_number=-1, content_index=current_content_index, @@ -1026,41 +1543,43 @@ class OpenAIServingResponses(OpenAIServing): delta=ctx.parser.last_content_delta, # TODO, use logprobs from ctx.last_request_output logprobs=[], - )) - elif (ctx.parser.current_channel == "analysis" - and ctx.parser.current_recipient is None): + ) + ) + elif ( + ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None + ): if not sent_output_item_added: sent_output_item_added = True - yield _send_event( - openai_responses_types. + current_item_id = f"msg_{random_uuid()}" + yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. 
- ResponseReasoningItem( + item=ResponseReasoningItem( type="reasoning", id=current_item_id, summary=[], status="in_progress", ), - )) - yield _send_event( - openai_responses_types. - ResponseContentPartAddedEvent( - type="response.content_part.added", + ) + ) + current_content_index += 1 + yield _increment_sequence_number_and_return( + ResponseReasoningPartAddedEvent( + type="response.reasoning_part.added", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, content_index=current_content_index, - part=openai_responses_types.ResponseOutputText( - type="output_text", + part=ResponseReasoningTextContent( text="", - annotations=[], - logprobs=[], + type="reasoning_text", ), - )) - yield _send_event( + ) + ) + yield _increment_sequence_number_and_return( ResponseReasoningTextDeltaEvent( type="response.reasoning_text.delta", item_id=current_item_id, @@ -1068,158 +1587,176 @@ class OpenAIServingResponses(OpenAIServing): content_index=current_content_index, delta=ctx.parser.last_content_delta, sequence_number=-1, - )) + ) + ) + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif ( + ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + current_item_id = f"tool_{random_uuid()}" + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + ) + ) + yield _increment_sequence_number_and_return( + ResponseCodeInterpreterCallInProgressEvent( + type="response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + ) + ) + yield _increment_sequence_number_and_return( + ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + ) + ) + # stream tool call outputs if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] - if (self.tool_server is not None - and self.tool_server.has_tool("browser") - and previous_item.recipient is not None - and previous_item.recipient.startswith("browser.")): - function_name = previous_item.recipient[len("browser."):] + if ( + self.tool_server is not None + and self.tool_server.has_tool("browser") + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.") + ): + function_name = previous_item.recipient[len("browser.") :] action = None parsed_args = json.loads(previous_item.content[0].text) if function_name == "search": - action = (openai_responses_types. - response_function_web_search.ActionSearch( - type="search", - query=parsed_args["query"], - )) + action = response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + ) elif function_name == "open": - action = ( - openai_responses_types. 
- response_function_web_search.ActionOpenPage( - type="open_page", - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - )) + action = response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) elif function_name == "find": - action = ( - openai_responses_types. - response_function_web_search.ActionFind( - type="find", - pattern=parsed_args["pattern"], - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - )) + action = response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) else: - raise ValueError( - f"Unknown function name: {function_name}") + raise ValueError(f"Unknown function name: {function_name}") - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( + current_item_id = f"tool_{random_uuid()}" + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - response_function_web_search. - ResponseFunctionWebSearch( + item=response_function_web_search.ResponseFunctionWebSearch( # TODO: generate a unique id for web search call type="web_search_call", id=current_item_id, action=action, status="in_progress", ), - )) - yield _send_event( - openai_responses_types. + ) + ) + yield _increment_sequence_number_and_return( ResponseWebSearchCallInProgressEvent( type="response.web_search_call.in_progress", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) - yield _send_event( - openai_responses_types. + ) + ) + yield _increment_sequence_number_and_return( ResponseWebSearchCallSearchingEvent( type="response.web_search_call.searching", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) # enqueue - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseWebSearchCallCompletedEvent( type="response.web_search_call.completed", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseFunctionWebSearch( + item=ResponseFunctionWebSearch( type="web_search_call", id=current_item_id, action=action, status="completed", ), - )) + ) + ) - if (self.tool_server is not None - and self.tool_server.has_tool("python") - and previous_item.recipient is not None - and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code="", - container_id="auto", - outputs=[], - status="in_progress", - ), - )) - yield _send_event( - openai_responses_types. - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - )) - # TODO: do we need to add delta event here? 
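# A small sketch of the sequence-number stamping closure threaded through the
# streaming code here (_increment_sequence_number_and_return): every event that
# carries a `sequence_number` attribute is stamped with a monotonically
# increasing counter before being yielded. The Event dataclass below is a
# stand-in for the OpenAI response event types.
from dataclasses import dataclass


@dataclass
class Event:
    type: str
    sequence_number: int = -1


def make_sequencer():
    sequence_number = 0

    def stamp(event: Event) -> Event:
        nonlocal sequence_number
        if hasattr(event, "sequence_number"):
            event.sequence_number = sequence_number
            sequence_number += 1
        return event

    return stamp


stamp = make_sequencer()
first = stamp(Event(type="response.created"))
second = stamp(Event(type="response.output_text.delta"))
assert (first.sequence_number, second.sequence_number) == (0, 1)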
- yield _send_event( - openai_responses_types. + if ( + self.tool_server is not None + and self.tool_server.has_tool("python") + and previous_item.recipient is not None + and previous_item.recipient.startswith("python") + ): + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCodeDoneEvent( type="response.code_interpreter_call_code.done", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - code=previous_item.content[0].text)) - yield _send_event( - openai_responses_types. + code=previous_item.content[0].text, + ) + ) + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallInterpretingEvent( type="response.code_interpreter_call.interpreting", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) - yield _send_event( - openai_responses_types. + ) + ) + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCompletedEvent( type="response.code_interpreter_call.completed", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( + item=ResponseCodeInterpreterToolCallParam( type="code_interpreter_call", id=current_item_id, code=previous_item.content[0].text, @@ -1228,27 +1765,104 @@ class OpenAIServingResponses(OpenAIServing): outputs=[], status="completed", ), - )) + ) + ) - async def empty_async_generator(): - # A hack to trick Python to think this is a generator but in fact - # it immediately returns. - if False: - yield + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + # TODO: + # 1. 
Handle disconnect - final_response = await self.responses_full_generator( - request, - sampling_params, - empty_async_generator(), - context, - model_name, - tokenizer, - request_metadata, - created_time=created_time, - ) - yield _send_event( - openai_responses_types.ResponseCompletedEvent( - type="response.completed", - sequence_number=-1, - response=final_response.model_dump(), - )) + created_time = created_time or int(time.time()) + + sequence_number = 0 + + def _increment_sequence_number_and_return( + event: StreamingResponsesResponse, + ) -> StreamingResponsesResponse: + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, "sequence_number"): + event.sequence_number = sequence_number + sequence_number += 1 + return event + + async with AsyncExitStack() as exit_stack: + processer = None + if self.use_harmony: + # TODO: in streaming, we noticed this bug: + # https://github.com/vllm-project/vllm/issues/25697 + await self._initialize_tool_sessions(request, context, exit_stack) + processer = self._process_harmony_streaming_events + else: + processer = self._process_simple_streaming_events + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _increment_sequence_number_and_return( + ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + ) + ) + yield _increment_sequence_number_and_return( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + ) + ) + + async for event_data in processer( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + _increment_sequence_number_and_return, + ): + yield event_data + + async def empty_async_generator(): + # A hack to trick Python to think this is a generator but + # in fact it immediately returns. 
+ if False: + yield + + final_response = await self.responses_full_generator( + request, + sampling_params, + empty_async_generator(), + context, + model_name, + tokenizer, + request_metadata, + created_time=created_time, + ) + yield _increment_sequence_number_and_return( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response, + ) + ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index c246274514dbf..84ea33a07fa58 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -7,26 +7,30 @@ from typing import Any, Optional, Union from fastapi import Request -from vllm import envs -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, - RerankRequest, RerankResponse, - RerankResult, RerankUsage, - ScoreRequest, ScoreResponse, - ScoreResponseData, UsageInfo) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + RerankDocument, + RerankRequest, + RerankResponse, + RerankResult, + RerankUsage, + ScoreRequest, + ScoreResponse, + ScoreResponseData, + UsageInfo, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam, - _cosine_similarity, - _validate_score_input_lens, - compress_token_type_ids, - get_score_prompt) -# yapf: enable +from vllm.entrypoints.score_utils import ( + ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + compress_token_type_ids, + get_score_prompt, +) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger @@ -39,19 +43,20 @@ logger = init_logger(__name__) class ServingScores(OpenAIServing): - def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) async def _embedding_score( self, @@ -67,24 +72,23 @@ class ServingScores(OpenAIServing): input_texts = texts_1 + texts_2 engine_prompts: list[TokensPrompt] = [] - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) + tokenize_async = make_async( + tokenizer.__call__, executor=self._tokenizer_executor + ) tokenization_kwargs = tokenization_kwargs or {} tokenized_prompts = await asyncio.gather( - *(tokenize_async(t, **tokenization_kwargs) for t in input_texts)) + *(tokenize_async(t, **tokenization_kwargs) for t in input_texts) + ) for tok_result, input_text in zip(tokenized_prompts, input_texts): - - text_token_prompt = \ - self._validate_input( - request, - tok_result["input_ids"], - input_text) + text_token_prompt = self._validate_input( + request, tok_result["input_ids"], input_text + ) engine_prompts.append( - TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"])) + 
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) + ) # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] @@ -96,13 +100,14 @@ class ServingScores(OpenAIServing): return self.create_error_response(str(e)) for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - input_texts[i], - params=pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + ) generators.append( self.engine_client.encode( @@ -112,15 +117,15 @@ class ServingScores(OpenAIServing): lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, - )) + ) + ) result_generator = merge_async_iterators(*generators) # Non-streaming response final_res_batch: list[PoolingRequestOutput] = [] - embeddings: list[Optional[PoolingRequestOutput]] =\ - [None] * len(engine_prompts) + embeddings: list[Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) async for i, res in result_generator: embeddings[i] = res @@ -139,9 +144,9 @@ class ServingScores(OpenAIServing): if len(emb_texts_1) == 1: emb_texts_1 = emb_texts_1 * len(emb_texts_2) - final_res_batch = _cosine_similarity(tokenizer=tokenizer, - embed_1=emb_texts_1, - embed_2=emb_texts_2) + final_res_batch = _cosine_similarity( + tokenizer=tokenizer, embed_1=emb_texts_1, embed_2=emb_texts_2 + ) return final_res_batch @@ -153,7 +158,6 @@ class ServingScores(OpenAIServing): data_1: Union[str, ScoreContentPartParam], data_2: Union[str, ScoreContentPartParam], ) -> tuple[str, TokensPrompt]: - model_config = self.model_config full_prompt, engine_prompt = get_score_prompt( @@ -163,8 +167,7 @@ class ServingScores(OpenAIServing): tokenizer=tokenizer, tokenization_kwargs=tokenization_kwargs, ) - self._validate_input(request, engine_prompt["prompt_token_ids"], - full_prompt) + self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt) if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -188,22 +191,28 @@ class ServingScores(OpenAIServing): data_1 = data_1 * len(data_2) if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "MistralTokenizer not supported for cross-encoding") + raise ValueError("MistralTokenizer not supported for cross-encoding") tokenization_kwargs = tokenization_kwargs or {} input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - preprocess_async = make_async(self._preprocess_score, - executor=self._tokenizer_executor) + preprocess_async = make_async( + self._preprocess_score, executor=self._tokenizer_executor + ) preprocessed_prompts = await asyncio.gather( - *(preprocess_async(request=request, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - data_1=t1, - data_2=t2) for t1, t2 in input_pairs)) + *( + preprocess_async( + request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2, + ) + for t1, t2 in input_pairs + ) + ) for full_prompt, engine_prompt in preprocessed_prompts: request_prompts.append(full_prompt) @@ -222,20 +231,19 @@ class ServingScores(OpenAIServing): for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - request_prompts[i], - params=default_pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + request_prompts[i], + 
params=default_pooling_params, + lora_request=lora_request, + ) - if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( - "token_type_ids", None)): + if token_type_ids := engine_prompt.pop("token_type_ids", None): pooling_params = default_pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) - pooling_params.extra_kwargs = { - "compressed_token_type_ids": compressed - } + pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed} else: - pooling_params = (default_pooling_params) + pooling_params = default_pooling_params generator = self.engine_client.encode( engine_prompt, @@ -251,8 +259,9 @@ class ServingScores(OpenAIServing): result_generator = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: list[ - Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) + final_res_batch: list[Optional[PoolingRequestOutput]] = [None] * len( + engine_prompts + ) async for i, res in result_generator: final_res_batch[i] = res @@ -266,21 +275,27 @@ class ServingScores(OpenAIServing): request: Union[ScoreRequest, RerankRequest], request_id: str, raw_request: Optional[Request] = None, - truncate_prompt_tokens: Optional[int] = None, ) -> Union[list[PoolingRequestOutput], ErrorResponse]: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() + + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, - tokenization_kwargs) + _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens, tokenization_kwargs + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - if not self.model_config.is_multimodal_model and (isinstance( - data_1, dict) or isinstance(data_2, dict)): + if not self.model_config.is_multimodal_model and ( + isinstance(data_1, dict) or isinstance(data_2, dict) + ): raise ValueError( f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501 ) @@ -306,7 +321,8 @@ class ServingScores(OpenAIServing): request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - trace_headers=trace_headers) + trace_headers=trace_headers, + ) else: return await self._embedding_score( @@ -317,7 +333,8 @@ class ServingScores(OpenAIServing): request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - trace_headers=trace_headers) + trace_headers=trace_headers, + ) async def create_score( self, @@ -343,7 +360,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch @@ -352,7 +368,7 @@ class ServingScores(OpenAIServing): final_res_batch, request_id, created_time, - self._get_model_name(request.model), + self.models.model_name(), ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -361,9 +377,7 @@ class ServingScores(OpenAIServing): return self.create_error_response(str(e)) async def do_rerank( - self, - request: RerankRequest, - raw_request: Optional[Request] = None + self, request: RerankRequest, raw_request: Optional[Request] = None ) -> Union[RerankResponse, ErrorResponse]: 
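For orientation (not part of this diff), a minimal client call against the Jina-style rerank endpoint handled by do_rerank below might look like the following sketch; the URL path, port, and model name are assumptions, while the request fields and the "results" shape match RerankRequest and request_output_to_rerank_response in this file.

    import requests

    # Assumed local vLLM deployment serving a reranker model.
    response = requests.post(
        "http://localhost:8000/v1/rerank",
        json={
            "model": "BAAI/bge-reranker-base",  # placeholder model name
            "query": "What is the capital of France?",
            "documents": [
                "Paris is the capital of France.",
                "Berlin is the capital of Germany.",
            ],
            "top_n": 1,
        },
    )
    for result in response.json()["results"]:
        print(result["index"], result["relevance_score"])
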
""" Rerank API based on JinaAI's rerank API; implements the same @@ -380,9 +394,15 @@ class ServingScores(OpenAIServing): request_id = f"rerank-{self._base_request_id(raw_request)}" documents = request.documents - top_n = request.top_n if request.top_n > 0 else ( - len(documents) - if isinstance(documents, list) else len(documents["content"])) + top_n = ( + request.top_n + if request.top_n > 0 + else ( + len(documents) + if isinstance(documents, list) + else len(documents["content"]) + ) + ) try: final_res_batch = await self._run_scoring( @@ -391,7 +411,6 @@ class ServingScores(OpenAIServing): request, request_id, raw_request, - request.truncate_prompt_tokens, ) if isinstance(final_res_batch, ErrorResponse): return final_res_batch @@ -399,7 +418,7 @@ class ServingScores(OpenAIServing): return self.request_output_to_rerank_response( final_res_batch, request_id, - self._get_model_name(request.model), + self.models.model_name(), documents, top_n, ) @@ -445,9 +464,13 @@ class ServingScores(OpenAIServing): ) def request_output_to_rerank_response( - self, final_res_batch: list[PoolingRequestOutput], request_id: str, - model_name: str, documents: Union[list[str], ScoreMultiModalParam], - top_n: int) -> RerankResponse: + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + model_name: str, + documents: Union[list[str], ScoreMultiModalParam], + top_n: int, + ) -> RerankResponse: """ Convert the output of do_rank to a RerankResponse """ @@ -458,9 +481,9 @@ class ServingScores(OpenAIServing): result = RerankResult( index=idx, - document=RerankDocument(text=documents[idx]) if isinstance( - documents, list) else RerankDocument( - multi_modal=documents["content"][idx]), + document=RerankDocument(text=documents[idx]) + if isinstance(documents, list) + else RerankDocument(multi_modal=documents["content"][idx]), relevance_score=classify_res.outputs.score, ) results.append(result) @@ -476,4 +499,5 @@ class ServingScores(OpenAIServing): id=request_id, model=model_name, results=results, - usage=RerankUsage(total_tokens=num_prompt_tokens)) + usage=RerankUsage(total_tokens=num_prompt_tokens), + ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 58d720474768b..fb16d5ac690f1 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -6,22 +6,21 @@ from typing import Any, Final, Optional, Union import jinja2 from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (DetokenizeRequest, - DetokenizeResponse, - ErrorResponse, - TokenizeChatRequest, - TokenizeRequest, - TokenizeResponse, - TokenizerInfoResponse) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse, + TokenizerInfoResponse, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -29,24 +28,27 @@ logger = init_logger(__name__) class OpenAIServingTokenization(OpenAIServing): 
- def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, + log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template async def create_tokenize( self, @@ -62,14 +64,25 @@ class OpenAIServingTokenization(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) if isinstance(request, TokenizeChatRequest): - tool_dicts = (None if request.tools is None else - [tool.model_dump() for tool in request.tools]) + tool_dicts = ( + None + if request.tools is None + else [tool.model_dump() for tool in request.tools] + ) + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( _, - request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, @@ -77,44 +90,40 @@ class OpenAIServingTokenization(OpenAIServing): request.messages, tool_dicts=tool_dicts, chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. 
- chat_template_content_format, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, chat_template_kwargs=request.chat_template_kwargs, add_special_tokens=request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.prompt, + config=self._build_render_config(request), + ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] - for i, engine_prompt in enumerate(engine_prompts): - self._log_inputs(request_id, - request_prompts[i], - params=None, - lora_request=lora_request) + for engine_prompt in engine_prompts: + self._log_inputs( + request_id, engine_prompt, params=None, lora_request=lora_request + ) - if isinstance(engine_prompt, - dict) and "prompt_token_ids" in engine_prompt: + if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt: input_ids.extend(engine_prompt["prompt_token_ids"]) token_strs = None if request.return_token_strs: token_strs = tokenizer.convert_ids_to_tokens(input_ids) - return TokenizeResponse(tokens=input_ids, - token_strs=token_strs, - count=len(input_ids), - max_model_len=self.max_model_len) + return TokenizeResponse( + tokens=input_ids, + token_strs=token_strs, + count=len(input_ids), + max_model_len=self.max_model_len, + ) async def create_detokenize( self, @@ -129,12 +138,11 @@ class OpenAIServingTokenization(OpenAIServing): lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() - self._log_inputs(request_id, - request.tokens, - params=None, - lora_request=lora_request) + self._log_inputs( + request_id, request.tokens, params=None, lora_request=lora_request + ) prompt_input = await self._tokenize_prompt_input_async( request, @@ -146,15 +154,18 @@ class OpenAIServingTokenization(OpenAIServing): return DetokenizeResponse(prompt=input_text) async def get_tokenizer_info( - self, ) -> Union[TokenizerInfoResponse, ErrorResponse]: + self, + ) -> Union[TokenizerInfoResponse, ErrorResponse]: """Get comprehensive tokenizer information.""" try: tokenizer = await self.engine_client.get_tokenizer() info = TokenizerInfo(tokenizer, self.chat_template).to_dict() return TokenizerInfoResponse(**info) except Exception as e: - return self.create_error_response( - f"Failed to get tokenizer info: {str(e)}") + return self.create_error_response(f"Failed to get tokenizer info: {str(e)}") + + def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: + return RenderConfig(add_special_tokens=request.add_special_tokens) @dataclass diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0d6989fe91bfa..f6b08bf11aacf 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -5,14 +5,20 @@ from typing import Optional, Union from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - ErrorResponse, 
RequestResponseMetadata, TranscriptionRequest, - TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, TranslationRequest, TranslationResponse, - TranslationResponseStreamChoice, TranslationStreamResponse) + ErrorResponse, + RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, + TranslationRequest, + TranslationResponse, + TranslationResponseStreamChoice, + TranslationStreamResponse, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText from vllm.logger import init_logger @@ -27,24 +33,24 @@ class OpenAIServingTranscription(OpenAISpeechToText): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe") + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="transcribe", + log_error_stack=log_error_stack, + ) async def create_transcription( - self, audio_data: bytes, request: TranscriptionRequest, - raw_request: Request - ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], - ErrorResponse]: + self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request + ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], ErrorResponse]: """Transcription API similar to OpenAI's API. 
See https://platform.openai.com/docs/api-reference/audio/createTranscription @@ -59,10 +65,13 @@ class OpenAIServingTranscription(OpenAISpeechToText): ) async def transcription_stream_generator( - self, request: TranscriptionRequest, - result_generator: list[AsyncGenerator[RequestOutput, None]], - request_id: str, request_metadata: RequestResponseMetadata, - audio_duration_s: float) -> AsyncGenerator[str, None]: + self, + request: TranscriptionRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + ) -> AsyncGenerator[str, None]: generator = self._speech_to_text_stream_generator( request=request, list_result_generator=result_generator, @@ -83,22 +92,23 @@ class OpenAIServingTranslation(OpenAISpeechToText): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate") + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="translate", + log_error_stack=log_error_stack, + ) async def create_translation( - self, audio_data: bytes, request: TranslationRequest, - raw_request: Request + self, audio_data: bytes, request: TranslationRequest, raw_request: Request ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: """Translation API similar to OpenAI's API. 
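As a rough usage sketch (not part of this diff): assuming a vLLM server exposing these OpenAI-compatible routes at http://localhost:8000/v1 with an audio-capable model loaded, create_transcription above can be exercised with the OpenAI Python client. The base URL, API key, and model name are placeholders; the response_format restriction to "text" or "json" comes from the speech-to-text base class later in this diff.

    from openai import OpenAI

    # Placeholder server and model; adjust to your deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    with open("sample.wav", "rb") as audio_file:
        result = client.audio.transcriptions.create(
            model="openai/whisper-large-v3",  # assumed model name
            file=audio_file,
            response_format="json",  # only "text" or "json" are accepted
            language="en",
        )
    print(result.text)
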
@@ -114,10 +124,13 @@ class OpenAIServingTranslation(OpenAISpeechToText): ) async def translation_stream_generator( - self, request: TranslationRequest, - result_generator: list[AsyncGenerator[RequestOutput, None]], - request_id: str, request_metadata: RequestResponseMetadata, - audio_duration_s: float) -> AsyncGenerator[str, None]: + self, + request: TranslationRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + ) -> AsyncGenerator[str, None]: generator = self._speech_to_text_stream_generator( request=request, list_result_generator=result_generator, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 01140a4bfea7e..2f518574242bf 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -12,16 +12,21 @@ import numpy as np from fastapi import Request import vllm.envs as envs -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - DeltaMessage, ErrorResponse, RequestResponseMetadata, - TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, TranslationResponse, - TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - SpeechToTextRequest) + DeltaMessage, + ErrorResponse, + RequestResponseMetadata, + TranscriptionResponse, + TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, + TranslationResponse, + TranslationResponseStreamChoice, + TranslationStreamResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -41,42 +46,46 @@ logger = init_logger(__name__) class OpenAISpeechToText(OpenAIServing): - """Base class for speech-to-text operations like transcription and + """Base class for speech-to-text operations like transcription and translation.""" def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", + log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + super().__init__( + engine_client=engine_client, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack, + ) - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() self.task_type = task_type self.asr_config = self.model_cls.get_speech_to_text_config( - model_config, task_type) + self.model_config, task_type + ) self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB if self.default_sampling_params: logger.info( "Overwriting default completion sampling param with: %s", - self.default_sampling_params) + self.default_sampling_params, + ) @cached_property def model_cls(self) -> type[SupportsTranscription]: from 
vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsTranscription], model_cls) @@ -87,6 +96,12 @@ class OpenAISpeechToText(OpenAIServing): ) -> tuple[list[PromptType], float]: # Validate request language = self.model_cls.validate_language(request.language) + # Skip to_language validation to avoid extra logging for Whisper. + to_language = ( + self.model_cls.validate_language(request.to_language) + if request.to_language + else None + ) if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: raise ValueError("Maximum file size exceeded.") @@ -97,8 +112,10 @@ class OpenAISpeechToText(OpenAIServing): y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) duration = librosa.get_duration(y=y, sr=sr) - do_split_audio = (self.asr_config.allow_audio_chunking - and duration > self.asr_config.max_audio_clip_s) + do_split_audio = ( + self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s + ) chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) prompts = [] for chunk in chunks: @@ -110,7 +127,9 @@ class OpenAISpeechToText(OpenAIServing): model_config=self.model_config, language=language, task_type=self.task_type, - request_prompt=request.prompt) + request_prompt=request.prompt, + to_language=to_language, + ) prompts.append(prompt) return prompts, duration @@ -122,7 +141,7 @@ class OpenAISpeechToText(OpenAIServing): response_class: type[T], stream_generator_method: Callable[..., AsyncGenerator[str, None]], ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: - """Base method for speech-to-text operations like transcription and + """Base method for speech-to-text operations like transcription and translation.""" error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -134,9 +153,10 @@ class OpenAISpeechToText(OpenAIServing): if self.engine_client.errored: raise self.engine_client.dead_error - if request.response_format not in ['text', 'json']: + if request.response_format not in ["text", "json"]: return self.create_error_response( - "Currently only support response_format `text` or `json`") + "Currently only support response_format `text` or `json`" + ) request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" @@ -149,8 +169,8 @@ class OpenAISpeechToText(OpenAIServing): if lora_request: return self.create_error_response( - "Currently do not support LoRA for " - f"{self.task_type.title()}.") + f"Currently do not support LoRA for {self.task_type.title()}." + ) prompts, duration_s = await self._preprocess_speech_to_text( request=request, @@ -161,38 +181,42 @@ class OpenAISpeechToText(OpenAIServing): logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - list_result_generator: Optional[list[AsyncGenerator[RequestOutput, - None]]] = None + list_result_generator: Optional[list[AsyncGenerator[RequestOutput, None]]] = ( + None + ) try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a # fixed-size log-mel-spectogram. 
default_max_tokens = self.model_config.max_model_len sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) + default_max_tokens, self.default_sampling_params + ) self._log_inputs( request_id, # It will not display special tokens like <|startoftranscript|> request.prompt, params=sampling_params, - lora_request=None) + lora_request=None, + ) list_result_generator = [ self.engine_client.generate( prompt, sampling_params, request_id, - ) for prompt in prompts + ) + for prompt in prompts ] except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) if request.stream: - return stream_generator_method(request, list_result_generator, - request_id, request_metadata, - duration_s) + return stream_generator_method( + request, list_result_generator, request_id, request_metadata, duration_s + ) # Non-streaming response. try: assert list_result_generator is not None @@ -200,7 +224,20 @@ class OpenAISpeechToText(OpenAIServing): for result_generator in list_result_generator: async for op in result_generator: text += op.outputs[0].text - return cast(T, response_class(text=text)) + + if self.task_type == "transcribe": + # add usage in TranscriptionResponse. + usage = { + "type": "duration", + # rounded up as per openAI specs + "seconds": int(math.ceil(duration_s)), + } + final_response = cast(T, response_class(text=text, usage=usage)) + else: + # no usage in response for translation task + final_response = cast(T, response_class(text=text)) # type: ignore[call-arg] + + return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: @@ -217,9 +254,11 @@ class OpenAISpeechToText(OpenAIServing): chunk_object_type: Literal["translation.chunk", "transcription.chunk"], response_stream_choice_class: Union[ type[TranscriptionResponseStreamChoice], - type[TranslationResponseStreamChoice]], - stream_response_class: Union[type[TranscriptionStreamResponse], - type[TranslationStreamResponse]], + type[TranslationResponseStreamChoice], + ], + stream_response_class: Union[ + type[TranscriptionStreamResponse], type[TranslationStreamResponse] + ], ) -> AsyncGenerator[str, None]: created_time = int(time.time()) model_name = request.model @@ -227,11 +266,14 @@ class OpenAISpeechToText(OpenAIServing): completion_tokens = 0 num_prompt_tokens = 0 - include_usage = request.stream_include_usage \ - if request.stream_include_usage else False - include_continuous_usage = request.stream_continuous_usage_stats\ - if include_usage and request.stream_continuous_usage_stats\ + include_usage = ( + request.stream_include_usage if request.stream_include_usage else False + ) + include_continuous_usage = ( + request.stream_continuous_usage_stats + if include_usage and request.stream_continuous_usage_stats else False + ) try: for result_generator in list_result_generator: @@ -240,8 +282,8 @@ class OpenAISpeechToText(OpenAIServing): if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) if audio_tokens := self.model_cls.get_num_audio_tokens( - audio_duration_s, self.asr_config, - self.model_config): + audio_duration_s, self.asr_config, self.model_config + ): num_prompt_tokens += audio_tokens # We need to do it here, because if there are exceptions in @@ -257,20 +299,22 @@ class OpenAISpeechToText(OpenAIServing): if output.finish_reason is None: # Still generating, send delta update. 
- choice_data = response_stream_choice_class( - delta=delta_message) + choice_data = response_stream_choice_class(delta=delta_message) else: # Model is finished generating. choice_data = response_stream_choice_class( delta=delta_message, finish_reason=output.finish_reason, - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + ) - chunk = stream_response_class(id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) + chunk = stream_response_class( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + ) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -286,10 +330,11 @@ class OpenAISpeechToText(OpenAIServing): # Once the final token is handled, if stream_options.include_usage # is sent, send the usage. if include_usage: - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) final_usage_chunk = stream_response_class( id=request_id, @@ -297,16 +342,19 @@ class OpenAISpeechToText(OpenAIServing): created=created_time, choices=[], model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices request_metadata.final_usage_info = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens) + total_tokens=num_prompt_tokens + completion_tokens, + ) except Exception as e: # TODO: Use a vllm-specific Validation Error @@ -316,8 +364,9 @@ class OpenAISpeechToText(OpenAIServing): # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" - def _split_audio(self, audio_data: np.ndarray, - sample_rate: int) -> list[np.ndarray]: + def _split_audio( + self, audio_data: np.ndarray, sample_rate: int + ) -> list[np.ndarray]: chunk_size = sample_rate * self.asr_config.max_audio_clip_s overlap_size = sample_rate * self.asr_config.overlap_chunk_second chunks = [] @@ -331,17 +380,15 @@ class OpenAISpeechToText(OpenAIServing): # Find the best split point in the overlap region search_start = i + chunk_size - overlap_size search_end = min(i + chunk_size, audio_data.shape[-1]) - split_point = self._find_split_point(audio_data, search_start, - search_end) + split_point = self._find_split_point(audio_data, search_start, search_end) # Extract chunk up to the split point chunks.append(audio_data[..., i:split_point]) i = split_point return chunks - def _find_split_point(self, wav: np.ndarray, start_idx: int, - end_idx: int) -> int: - """Find the best point to split audio by + def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int: + """Find the best point to split audio by looking for silence or low amplitude. 
Args: wav: Audio tensor [1, T] @@ -358,8 +405,8 @@ class OpenAISpeechToText(OpenAIServing): min_energy_window = self.asr_config.min_energy_split_window_size assert min_energy_window is not None for i in range(0, len(segment) - min_energy_window, min_energy_window): - window = segment[i:i + min_energy_window] - energy = (window**2).mean()**0.5 + window = segment[i : i + min_energy_window] + energy = (window**2).mean() ** 0.5 if energy < min_energy: quietest_idx = i + start_idx min_energy = energy diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 44aa1208a54c7..2c5a0a6af23f0 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -14,11 +14,14 @@ from .jamba_tool_parser import JambaToolParser from .kimi_k2_tool_parser import KimiK2ToolParser from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser +from .longcat_tool_parser import LongcatFlashToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser +from .openai_tool_parser import OpenAIToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser +from .qwen3xml_tool_parser import Qwen3XMLToolParser from .seed_oss_tool_parser import SeedOssToolParser from .step3_tool_parser import Step3ToolParser from .xlam_tool_parser import xLAMToolParser @@ -34,6 +37,7 @@ __all__ = [ "Llama3JsonToolParser", "JambaToolParser", "Llama4PythonicToolParser", + "LongcatFlashToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser", @@ -44,6 +48,8 @@ __all__ = [ "HunyuanA13BToolParser", "Glm4MoeModelToolParser", "Qwen3CoderToolParser", + "Qwen3XMLToolParser", "SeedOssToolParser", "Step3ToolParser", + "OpenAIToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 02aeab6136316..e6ee2fa777f81 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -6,9 +6,11 @@ from collections.abc import Sequence from functools import cached_property from typing import Callable, Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import import_from_path, is_list_of @@ -38,16 +40,15 @@ class ToolParser: # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: """ Static method that used to adjust the request parameters. """ return request def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. 
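For orientation, here is a minimal sketch of how the ToolParser / ToolParserManager API in this file is consumed. The parser name, its single-JSON-object convention, and the class itself are hypothetical; only the registration pattern and the return types mirror the concrete parsers elsewhere in this diff.

    import json

    from vllm.entrypoints.openai.protocol import (
        ChatCompletionRequest,
        ExtractedToolCallInformation,
        FunctionCall,
        ToolCall,
    )
    from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
        ToolParser,
        ToolParserManager,
    )


    @ToolParserManager.register_module("demo_json")  # decorator form
    class DemoJsonToolParser(ToolParser):
        def extract_tool_calls(
            self, model_output: str, request: ChatCompletionRequest
        ) -> ExtractedToolCallInformation:
            # Hypothetical convention: the model emits one JSON object
            # {"name": ..., "arguments": {...}} when calling a tool.
            try:
                call = json.loads(model_output)
                tool_call = ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        arguments=json.dumps(call["arguments"]),
                    ),
                )
                return ExtractedToolCallInformation(
                    tools_called=True, tool_calls=[tool_call], content=None
                )
            except (json.JSONDecodeError, KeyError, TypeError):
                # Not a tool call; pass the text through unchanged.
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

    # The same class could be registered without the decorator:
    # ToolParserManager.register_module(name="demo_json", module=DemoJsonToolParser)
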
@@ -56,7 +57,8 @@ class ToolParser: Static because it's stateless. """ raise NotImplementedError( - "AbstractToolParser.extract_tool_calls has not been implemented!") + "AbstractToolParser.extract_tool_calls has not been implemented!" + ) def extract_tool_calls_streaming( self, @@ -76,8 +78,8 @@ class ToolParser: previously been parsed and extracted (see constructor) """ raise NotImplementedError( - "AbstractToolParser.extract_tool_calls_streaming has not been " - "implemented!") + "AbstractToolParser.extract_tool_calls_streaming has not been implemented!" + ) class ToolParserManager: @@ -96,13 +98,15 @@ class ToolParserManager: raise KeyError(f"tool helper: '{name}' not found in tool_parsers") @classmethod - def _register_module(cls, - module: type, - module_name: Optional[Union[str, list[str]]] = None, - force: bool = True) -> None: + def _register_module( + cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True, + ) -> None: if not issubclass(module, ToolParser): raise TypeError( - f'module must be subclass of ToolParser, but got {type(module)}' + f"module must be subclass of ToolParser, but got {type(module)}" ) if module_name is None: module_name = module.__name__ @@ -111,30 +115,32 @@ class ToolParserManager: for name in module_name: if not force and name in cls.tool_parsers: existed_module = cls.tool_parsers[name] - raise KeyError(f'{name} is already registered ' - f'at {existed_module.__module__}') + raise KeyError( + f"{name} is already registered at {existed_module.__module__}" + ) cls.tool_parsers[name] = module @classmethod def register_module( - cls, - name: Optional[Union[str, list[str]]] = None, - force: bool = True, - module: Union[type, None] = None) -> Union[type, Callable]: + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None, + ) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a - decoder(with module as None) or normal function(with module as not + decoder(with module as None) or normal function(with module as not None). 
""" if not isinstance(force, bool): - raise TypeError(f'force must be a boolean, but got {type(force)}') + raise TypeError(f"force must be a boolean, but got {type(force)}") # raise the error ahead of time - if not (name is None or isinstance(name, str) - or is_list_of(name, str)): + if not (name is None or isinstance(name, str) or is_list_of(name, str)): raise TypeError( - 'name must be None, an instance of str, or a sequence of str, ' - f'but got {type(name)}') + "name must be None, an instance of str, or a sequence of str, " + f"but got {type(name)}" + ) # use it as a normal method: x.register_module(module=SomeClass) if module is not None: @@ -159,6 +165,7 @@ class ToolParserManager: try: import_from_path(module_name, plugin_path) except Exception: - logger.exception("Failed to load module '%s' from %s.", - module_name, plugin_path) + logger.exception( + "Failed to load module '%s' from %s.", module_name, plugin_path + ) return diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py index 2656db9c6238b..c6e8f1686e245 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -7,13 +7,19 @@ from typing import Union import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,15 +28,15 @@ logger = init_logger(__name__) @ToolParserManager.register_module("deepseek_v31") class DeepSeekV31ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" self.tool_calls_end_token: str = "<|tool▁calls▁end|>" @@ -39,45 +45,47 @@ class DeepSeekV31ToolParser(ToolParser): self.tool_call_end_token: str = "<|tool▁call▁end|>" self.tool_call_regex = re.compile( - r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>" + r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>" ) self.stream_tool_call_portion_regex = re.compile( - r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)") + r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)" + ) self.stream_tool_call_name_regex = re.compile( - r"(?P<function_name>.*)<|tool▁sep|>") + r"(?P<function_name>.*)<|tool▁sep|>" + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - 
self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( - "DeepSeek-V3 Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "DeepSeek-V3.1 Tool parser could not locate tool call " + "start/end tokens in the tokenizer!" + ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -85,8 +93,7 @@ class DeepSeekV31ToolParser(ToolParser): # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) tool_calls = [] for match in function_call_tuples: @@ -94,12 +101,13 @@ class DeepSeekV31ToolParser(ToolParser): tool_calls.append( ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=function_args), - )) + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -107,11 +115,10 @@ class DeepSeekV31ToolParser(ToolParser): ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -123,55 +130,58 @@ class DeepSeekV31ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! 
skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -185,27 +195,29 @@ class DeepSeekV31ToolParser(ToolParser): logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -216,13 +228,16 @@ class DeepSeekV31ToolParser(ToolParser): diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -233,17 +248,17 @@ class DeepSeekV31ToolParser(ToolParser): current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: tool_name, tool_args = current_tool_call_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: tool_name = current_tool_call_name_matches.groups() current_tool_call["name"] = tool_name @@ -260,16 +275,18 @@ class DeepSeekV31ToolParser(ToolParser): function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -279,15 +296,19 @@ class DeepSeekV31ToolParser(ToolParser): if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. 
- logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -297,7 +318,8 @@ class DeepSeekV31ToolParser(ToolParser): # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -311,52 +333,56 @@ class DeepSeekV31ToolParser(ToolParser): # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments - - # last case -- we have an update to existing arguments. - elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] - logger.debug("got diff %s", delta_text) - - delta = DeltaMessage(tool_calls=[ + delta = DeltaMessage( + tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), + arguments=cur_arguments + ).model_dump(exclude_none=True), ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. 
+ elif cur_arguments and prev_arguments: + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index ac272b0c3b205..e8a5d2e6dc133 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -7,13 +7,19 @@ from typing import Union import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,15 +28,15 @@ logger = init_logger(__name__) @ToolParserManager.register_module("deepseek_v3") class DeepSeekV3ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" self.tool_calls_end_token: str = "<|tool▁calls▁end|>" @@ -47,38 +53,39 @@ class DeepSeekV3ToolParser(ToolParser): ) self.stream_tool_call_name_regex = re.compile( - r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n") + r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n" + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) + "constructor during construction." 
+ ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "DeepSeek-V3 Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -86,8 +93,7 @@ class DeepSeekV3ToolParser(ToolParser): # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) tool_calls = [] for match in function_call_tuples: @@ -95,12 +101,13 @@ class DeepSeekV3ToolParser(ToolParser): tool_calls.append( ToolCall( type=tool_type, - function=FunctionCall(name=function_name, - arguments=function_args), - )) + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -108,11 +115,10 @@ class DeepSeekV3ToolParser(ToolParser): ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -124,55 +130,58 @@ class DeepSeekV3ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! 
skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -186,27 +195,29 @@ class DeepSeekV3ToolParser(ToolParser): logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -217,13 +228,16 @@ class DeepSeekV3ToolParser(ToolParser): diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -234,21 +248,19 @@ class DeepSeekV3ToolParser(ToolParser): current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: - tool_type, tool_name, tool_args = ( - current_tool_call_matches.groups()) + tool_type, tool_name, tool_args = current_tool_call_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: - tool_type, tool_name = ( - current_tool_call_name_matches.groups()) + tool_type, tool_name = current_tool_call_name_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = "" else: @@ -263,16 +275,18 @@ class DeepSeekV3ToolParser(ToolParser): function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -282,15 +296,19 @@ class DeepSeekV3ToolParser(ToolParser): if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. 
- logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -300,7 +318,8 @@ class DeepSeekV3ToolParser(ToolParser): # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -314,52 +333,56 @@ class DeepSeekV3ToolParser(ToolParser): # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments - - # last case -- we have an update to existing arguments. - elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] - logger.debug("got diff %s", delta_text) - - delta = DeltaMessage(tool_calls=[ + delta = DeltaMessage( + tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), + arguments=cur_arguments + ).model_dump(exclude_none=True), ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. 
+ elif cur_arguments and prev_arguments: + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 8fd14f171d0af..1d7d7d3f8629d 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -8,14 +8,20 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -24,7 +30,6 @@ logger = init_logger(__name__) @ToolParserManager.register_module("glm45") class Glm4MoeModelToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent = False @@ -36,20 +41,20 @@ class Glm4MoeModelToolParser(ToolParser): self.tool_calls_start_token = self.tool_call_start_token - self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", - re.DOTALL) + self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL) self.func_detail_regex = re.compile( - r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL) + r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL + ) self.func_arg_regex = re.compile( - r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", - re.DOTALL) + r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." 
+ ) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) self._buffer = "" @@ -58,18 +63,22 @@ class Glm4MoeModelToolParser(ToolParser): model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - def _is_string_type( - tool_name: str, arg_name: str, - tools: Optional[list[ChatCompletionToolsParam]]) -> bool: + tool_name: str, + arg_name: str, + tools: Optional[list[ChatCompletionToolsParam]], + ) -> bool: if tools is None: return False for tool in tools: if tool.function.name == tool_name: if tool.function.parameters is None: return False - arg_type = tool.function.parameters.get( - "properties", {}).get(arg_name, {}).get("type", None) + arg_type = ( + tool.function.parameters.get("properties", {}) + .get(arg_name, {}) + .get("type", None) + ) return arg_type == "string" logger.warning("No tool named '%s'.", tool_name) return False @@ -101,28 +110,30 @@ class Glm4MoeModelToolParser(ToolParser): arg_val = value.strip() if not _is_string_type(tc_name, arg_key, request.tools): arg_val = _deserialize(arg_val) - logger.debug("arg_key = %s, arg_val = %s", arg_key, - arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val) arg_dct[arg_key] = arg_val tool_calls.append( - ToolCall(type="function", - function=FunctionCall( - name=tc_name, arguments=json.dumps(arg_dct)))) + ToolCall( + type="function", + function=FunctionCall( + name=tc_name, arguments=json.dumps(arg_dct) + ), + ) + ) except Exception: logger.exception("Failed to extract tool call spec") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: if len(tool_calls) > 0: - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] - return ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=content) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + content = model_output[: model_output.find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -155,7 +166,8 @@ class Glm4MoeModelToolParser(ToolParser): self.streamed_args_for_tool.append("") extracted_tool_calls = self.extract_tool_calls( - cur_text[:end_idx + len(self.tool_call_end_token)], request) + cur_text[: end_idx + len(self.tool_call_end_token)], request + ) if len(extracted_tool_calls.tool_calls) == 0: logger.warning("Failed to extract any tool calls.") @@ -163,22 +175,27 @@ class Glm4MoeModelToolParser(ToolParser): tool_call = extracted_tool_calls.tool_calls[0] self.prev_tool_call_arr[self.current_tool_id] = { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments) + "arguments": json.loads(tool_call.function.arguments), } - self.streamed_args_for_tool[ - self.current_tool_id] = tool_call.function.arguments + self.streamed_args_for_tool[self.current_tool_id] = ( + tool_call.function.arguments + ) delta = DeltaMessage( content=extracted_tool_calls.content, tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - id=tool_call.id, - type=tool_call.type, - function=DeltaFunctionCall( - name=tool_call.function.name, - arguments=tool_call.function.arguments)) - ]) + DeltaToolCall( + index=self.current_tool_id, + id=tool_call.id, + type=tool_call.type, + function=DeltaFunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ], + ) self.current_tool_id += 1 - self._buffer = cur_text[end_idx + len(self.tool_call_end_token):] + self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :] return delta self._buffer = cur_text[start_idx:] diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 824b100f357b5..c42b358b1e34b 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -11,17 +11,25 @@ import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, - find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -47,12 +55,12 @@ class Granite20bFCToolParser(ToolParser): 
self.tool_call_regex = re.compile(r"<function_call>\s*") def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: if self.tool_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) dec = JSONDecoder() try: @@ -66,13 +74,15 @@ class Granite20bFCToolParser(ToolParser): start_of_json = match.end() # end_index == the start of the next function call # (if exists) - next_function_call_start = (matches[i + 1].start() if i + - 1 < len(matches) else None) + next_function_call_start = ( + matches[i + 1].start() if i + 1 < len(matches) else None + ) raw_function_calls.append( dec.raw_decode( - model_output[start_of_json:next_function_call_start]) - [0]) + model_output[start_of_json:next_function_call_start] + )[0] + ) logger.debug("Extracted %d tool calls", len(raw_function_calls)) tool_calls = [ @@ -81,13 +91,15 @@ class Granite20bFCToolParser(ToolParser): function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), ), - ) for function_call in raw_function_calls + ) + for function_call in raw_function_calls ] - content = model_output[:model_output.find(self.bot_token)] + content = model_output[: model_output.find(self.bot_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -96,9 +108,9 @@ class Granite20bFCToolParser(ToolParser): except Exception as e: logger.error("Error in extracting tool call from response %s", e) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -110,9 +122,9 @@ class Granite20bFCToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - - if len(current_text) < len( - self.bot_token) and self.bot_token.startswith(current_text): + if len(current_text) < len(self.bot_token) and self.bot_token.startswith( + current_text + ): return None if not current_text.startswith(self.bot_token): @@ -122,8 +134,7 @@ class Granite20bFCToolParser(ToolParser): # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
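A brief aside on the `Allow` bit mask referenced in the comment above (the Granite and Hermes parsers in this diff use the same pattern): `partial_json_parser.loads` takes a flag set describing which JSON value types may still be incomplete in the returned object. The following is a minimal, illustrative sketch only; the sample payload and the expected results in the comments are assumptions, not taken from this patch.

```python
import partial_json_parser
from partial_json_parser.core.options import Allow

# A hypothetical, partially streamed tool call: the object is still open and
# the last string value is unterminated.
chunk = '{"name": "get_weather", "arguments": {"city": "Par'

# Before the function name has been streamed, incomplete strings are masked
# out (Allow.ALL & ~Allow.STR), so the truncated "Par" should not be surfaced.
guarded = partial_json_parser.loads(chunk, Allow.ALL & ~Allow.STR)

# Once the complete name has been sent, Allow.ALL lets the partially
# generated argument string through for incremental streaming.
permissive = partial_json_parser.loads(chunk, Allow.ALL)

print(guarded)     # expected: no unfinished "city" value surfaced yet
print(permissive)  # expected: {'name': 'get_weather', 'arguments': {'city': 'Par'}}
```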
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = [] is_complete = [] @@ -132,24 +143,23 @@ class Granite20bFCToolParser(ToolParser): start_idx = consume_space(start_idx, current_text) while start_idx < len(current_text): - (obj, - end_idx) = partial_json_loads(current_text[start_idx:], - flags) + (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags) is_complete.append( - is_complete_json(current_text[start_idx:start_idx + - end_idx])) + is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) start_idx += end_idx start_idx = consume_space(start_idx, current_text) start_idx += len(self.bot_token) start_idx = consume_space(start_idx, current_text) tool_call_arr.append(obj) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -158,9 +168,9 @@ class Granite20bFCToolParser(ToolParser): # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -168,21 +178,24 @@ class Granite20bFCToolParser(ToolParser): if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) else: delta = None else: @@ -199,15 +212,18 @@ class Granite20bFCToolParser(ToolParser): elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -219,34 +235,35 @@ class Granite20bFCToolParser(ToolParser): delta = None if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -254,6 +271,6 @@ class Granite20bFCToolParser(ToolParser): except Exception as e: logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index ac517616a95b4..989973923ae58 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -9,17 +9,25 @@ import partial_json_parser from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, - find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -45,21 +53,24 @@ class GraniteToolParser(ToolParser): self.bot_string = "<tool_call>" def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: - stripped = model_output.strip()\ - .removeprefix(self.bot_token)\ - .removeprefix(self.bot_string)\ - .lstrip() - if not stripped or stripped[0] != '[': - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: + stripped = ( + model_output.strip() + .removeprefix(self.bot_token) + .removeprefix(self.bot_string) + .lstrip() + ) + if not stripped or stripped[0] != "[": + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: raw_function_calls = json.loads(stripped) if not isinstance(raw_function_calls, list): raise Exception( - f"Expected dict or list, got {type(raw_function_calls)}") + f"Expected dict or list, got {type(raw_function_calls)}" + ) logger.debug("Extracted %d tool calls", len(raw_function_calls)) tool_calls = [ @@ -68,10 +79,12 @@ class GraniteToolParser(ToolParser): function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), ), - ) for 
function_call in raw_function_calls + ) + for function_call in raw_function_calls ] return ExtractedToolCallInformation( @@ -82,9 +95,9 @@ class GraniteToolParser(ToolParser): except Exception as e: logger.error("Error in extracting tool call from response %s", e) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -96,41 +109,40 @@ class GraniteToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_token): - start_idx = consume_space(start_idx + len(self.bot_token), - current_text) + start_idx = consume_space(start_idx + len(self.bot_token), current_text) if current_text[start_idx:].startswith(self.bot_string): - start_idx = consume_space(start_idx + len(self.bot_string), - current_text) - if not current_text or start_idx >= len(current_text)\ - or current_text[start_idx] != '[': + start_idx = consume_space(start_idx + len(self.bot_string), current_text) + if ( + not current_text + or start_idx >= len(current_text) + or current_text[start_idx] != "[" + ): return DeltaMessage(content=delta_text) # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = None is_complete = None try: tool_calls, end_idx = partial_json_loads( - current_text[start_idx:], flags) + current_text[start_idx:], flags + ) if type(tool_calls) is list: tool_call_arr = tool_calls else: return DeltaMessage(content=delta_text) is_complete = [True] * len(tool_calls) - if not is_complete_json( - current_text[start_idx:start_idx + end_idx]): + if not is_complete_json(current_text[start_idx : start_idx + end_idx]): is_complete[-1] = False except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # case -- if no tokens have been streamed for the tool, e.g. @@ -145,7 +157,6 @@ class GraniteToolParser(ToolParser): # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor if len(tool_call_arr) > self.current_tool_id + 1: - # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -153,21 +164,24 @@ class GraniteToolParser(ToolParser): if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 @@ -181,15 +195,18 @@ class GraniteToolParser(ToolParser): elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True # now we know we're on the same tool call and we're streaming @@ -198,33 +215,35 @@ class GraniteToolParser(ToolParser): cur_arguments = current_tool_call.get("arguments") if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -232,6 +251,6 @@ class GraniteToolParser(ToolParser): except Exception as e: logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index a6ce33af6bd00..4529eb51796e1 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -10,13 +10,19 @@ import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -25,37 +31,41 @@ logger = init_logger(__name__) @ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) if isinstance(self.model_tokenizer, MistralTokenizer): - logger.error( - "Detected Mistral tokenizer when using a Hermes model") + logger.error("Detected Mistral tokenizer when using a Hermes model") self.model_tokenizer = self.model_tokenizer.tokenizer self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_call_start_token: str = "<tool_call>" self.tool_call_end_token: str = "</tool_call>" self.tool_call_regex = re.compile( - r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL) + r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL + ) self.scratch_pad_regex = re.compile( - r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL) + r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." 
+ ) self.tool_call_start_token_ids = self.model_tokenizer.encode( - self.tool_call_start_token, add_special_tokens=False) + self.tool_call_start_token, add_special_tokens=False + ) self.tool_call_end_token_ids = self.model_tokenizer.encode( - self.tool_call_end_token, add_special_tokens=False) + self.tool_call_end_token, add_special_tokens=False + ) self.tool_call_start_token_array = [ self.model_tokenizer.decode([token_id]) @@ -77,13 +87,17 @@ class Hermes2ProToolParser(ToolParser): def tool_call_delta_buffer(self, delta_text: str): # If the sequence of tool_call_start or tool_call_end tokens is not yet # complete, fill the buffer with the token and return "". - if (delta_text in self.tool_call_start_token_array - or delta_text in self.tool_call_end_token_array): + if ( + delta_text in self.tool_call_start_token_array + or delta_text in self.tool_call_end_token_array + ): # If delta_text is the last token of tool_call_start_token or # tool_call_end_token, empty the buffer and return # the buffered text + delta_text. - if (delta_text == self.tool_call_start_token_array[-1] - or delta_text == self.tool_call_end_token_array[-1]): + if ( + delta_text == self.tool_call_start_token_array[-1] + or delta_text == self.tool_call_end_token_array[-1] + ): buffered_text = self.buffered_delta_text self.buffered_delta_text = "" return buffered_text + delta_text @@ -98,27 +112,32 @@ class Hermes2ProToolParser(ToolParser): else: return delta_text + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": + # do not skip special tokens because the tool_call tokens are + # marked "special" in some models. Since they are skipped + # prior to the call to the tool parser, it breaks tool calling. + request.skip_special_tokens = False + return request + def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: - try: # there are two possible captures - between tags, or between a # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = ( - self.tool_call_regex.findall(model_output)) + function_call_tuples = self.tool_call_regex.findall(model_output) # load the JSON, and then use it to build the Function and # Tool Call @@ -132,24 +151,26 @@ class Hermes2ProToolParser(ToolParser): function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False))) + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) for function_call in raw_function_calls ] - content = model_output[:model_output. 
- find(self.tool_call_start_token)] + content = model_output[: model_output.find(self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if content else None) + content=content if content else None, + ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -168,10 +189,12 @@ class Hermes2ProToolParser(ToolParser): delta_text = self.tool_call_delta_buffer(delta_text) # If the last characters of previous_text # match self.buffered_delta_text, remove only the matching part. - if (len(previous_text) >= len(self.buffered_delta_text) - and previous_text[-len(self.buffered_delta_text):] - == self.buffered_delta_text): - previous_text = previous_text[:-len(self.buffered_delta_text)] + if ( + len(previous_text) >= len(self.buffered_delta_text) + and previous_text[-len(self.buffered_delta_text) :] + == self.buffered_delta_text + ): + previous_text = previous_text[: -len(self.buffered_delta_text)] current_text = previous_text + delta_text logger.debug("delta_text: %s", delta_text) @@ -182,50 +205,51 @@ class Hermes2ProToolParser(ToolParser): return DeltaMessage(content=delta_text) try: - # figure out where we are in the parsing by counting tool call # start & end tags - prev_tool_start_count = previous_text.count( - self.tool_call_start_token) + prev_tool_start_count = previous_text.count(self.tool_call_start_token) prev_tool_end_count = previous_text.count(self.tool_call_end_token) - cur_tool_start_count = current_text.count( - self.tool_call_start_token) + cur_tool_start_count = current_text.count(self.tool_call_start_token) cur_tool_end_count = current_text.count(self.tool_call_end_token) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case: if tool open & close tag counts don't match, we're doing # imaginary "else" block here # something with tools with this diff. # flags for partial JSON parting. 
exported constants from # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -239,42 +263,49 @@ class Hermes2ProToolParser(ToolParser): logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if (self.prev_tool_call_arr is None - or len(self.prev_tool_call_arr) == 0): - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = diff.encode('utf-8').decode( - 'unicode_escape') if diff is str else diff - if ('"}' not in delta_text): + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) + if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " - "been streamed yet: %s", diff) - self.streamed_args_for_tool[self.current_tool_id] \ - += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -284,13 +315,14 @@ class Hermes2ProToolParser(ToolParser): return delta try: - - current_tool_call = partial_json_parser.loads( - tool_call_portion or "{}", - flags) if tool_call_portion else None + current_tool_call = ( + partial_json_parser.loads(tool_call_portion or "{}", flags) + if tool_call_portion + else None + ) logger.debug("Parsed tool call %s", current_tool_call) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return 
None except json.decoder.JSONDecodeError: logger.debug("unable to parse JSON") @@ -299,19 +331,23 @@ class Hermes2ProToolParser(ToolParser): # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. if not self.current_tool_name_sent: - if (current_tool_call is None): + if current_tool_call is None: return None function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None # case -- otherwise, send the tool call delta @@ -320,15 +356,19 @@ class Hermes2ProToolParser(ToolParser): if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = DeltaMessage(content=delta_text) \ - if text_portion is not None else None + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. - logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -337,8 +377,9 @@ class Hermes2ProToolParser(ToolParser): # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON - prev_arguments = ( - self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -352,62 +393,99 @@ class Hermes2ProToolParser(ToolParser): # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." 
+ ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: + # extract the content after {"name": ..., "arguments": + # directly from tool_call_portion as cur_arguments_json, + # since cur_arguments may differ from the original text + # due to partial JSON parsing + # for example, tool_call_portion = + # {"name": "search", "arguments": {"search_request": {" + # but cur_arguments = + # {"search_request": {}} + function_name = current_tool_call.get("name") + match = re.search( + r'\{"name":\s*"' + + re.escape(function_name) + + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), + re.DOTALL, + ) + if match: + cur_arguments_json = match.group(1) + else: + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - logger.debug("finding %s in %s", delta_text, - cur_arguments_json) + logger.debug("finding %s in %s", delta_text, cur_arguments_json) - # get the location where previous args differ from current - if (delta_text not in cur_arguments_json[:-2]): + # get the location where previous args differ from current. + if delta_text not in cur_arguments_json: return None - args_delta_start_loc = cur_arguments_json[:-2]. \ - rindex(delta_text) + \ - len(delta_text) + args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len( + delta_text + ) # use that to find the actual delta arguments_delta = cur_arguments_json[:args_delta_start_loc] - logger.debug("First tokens in arguments received: %s", - arguments_delta) + logger.debug("First tokens in arguments received: %s", arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] \ - += arguments_delta + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta # last case -- we have an update to existing arguments. 
elif cur_arguments and prev_arguments: - if isinstance(delta_text, str) and len(delta_text.rstrip( - )) >= 1 and delta_text.rstrip()[-1] == '}': + # judge whether the tool_call_portion is a complete JSON + try: + json.loads(tool_call_portion) + is_complete_json = True + except Exception: + is_complete_json = False + + # if the delta_text ends with a '}' and tool_call_portion is a + # complete JSON, then the last '}' does not belong to the + # arguments, so we should trim it off + if ( + isinstance(delta_text, str) + and len(delta_text.rstrip()) >= 1 + and delta_text.rstrip()[-1] == "}" + and is_complete_json + ): delta_text = delta_text.rstrip()[:-1] logger.debug("got diff %s", delta_text) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_text).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] \ - += delta_text + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=delta_text).model_dump( + exclude_none=True + ), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += delta_text # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[self.current_tool_id] = \ - current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py index 2b65f2579fb43..1855d69adb217 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -8,13 +8,19 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +31,6 @@ logger = init_logger(__name__) @ToolParserManager.register_module("hunyuan_a13b") class HunyuanA13BToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -33,8 +38,7 @@ class HunyuanA13BToolParser(ToolParser): self.prev_tool_calls: list[dict] = [] self.current_tool_id = -1 self.current_tool_name_sent = False - self.streamed_args: list[str] = [ - ] # Track arguments sent for each tool + self.streamed_args: list[str] = [] # Track arguments sent for each tool # For backward compatibility with tests self.current_tools_sent: list[bool] = [] @@ -44,12 +48,14 @@ class HunyuanA13BToolParser(ToolParser): # Regex patterns for preprocessing self.answer_tool_calls_pattern = re.compile( - r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL) + r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL + ) self.tool_name_reg = 
re.compile(r'"name"\s*:\s*"([^"]+)"') self.tool_empty_arg_reg = re.compile( - r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}' + ) # TODO: not support nested json object in fc arguments. self.tool_non_empty_arg_reg = re.compile( @@ -66,15 +72,21 @@ class HunyuanA13BToolParser(ToolParser): } def preprocess_model_output( - self, model_output: str) -> tuple[Optional[str], Optional[str]]: + self, model_output: str + ) -> tuple[Optional[str], Optional[str]]: # find the location tool call for match in self.answer_tool_calls_pattern.finditer(model_output): start, end = match.span() # check tool_calls whether in side of <think> - think_regions = [(m.start(), m.end()) for m in re.finditer( - r"<think>(.*?)</think>", model_output, flags=re.DOTALL)] - in_think = any(start > t_start and end < t_end - for t_start, t_end in think_regions) + think_regions = [ + (m.start(), m.end()) + for m in re.finditer( + r"<think>(.*?)</think>", model_output, flags=re.DOTALL + ) + ] + in_think = any( + start > t_start and end < t_end for t_start, t_end in think_regions + ) if not in_think: content = model_output[:start] tool_calls_content = match.group(1).strip() @@ -86,24 +98,23 @@ class HunyuanA13BToolParser(ToolParser): return model_output, None def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract tool calls from a complete model output. """ try: # Preprocess the model output - content, potential_tool_calls = self.preprocess_model_output( - model_output) + content, potential_tool_calls = self.preprocess_model_output(model_output) if not potential_tool_calls: # some text should be filtered out for no function call # this text is in a13b's chat template. 
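The Hunyuan preprocessing above only accepts a `<tool_calls>...</tool_calls>` block when it is not nested inside a `<think>...</think>` region. A simplified, self-contained sketch of that span check (the sample string is invented):

```python
import re

ANSWER_TOOL_CALLS = re.compile(r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
THINK = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def split_content_and_tool_calls(model_output: str):
    """Return (content, tool_calls_json) for the first <tool_calls> block that
    is not nested inside a <think> region; otherwise (model_output, None)."""
    think_regions = [(m.start(), m.end()) for m in THINK.finditer(model_output)]
    for m in ANSWER_TOOL_CALLS.finditer(model_output):
        start, end = m.span()
        in_think = any(start > ts and end < te for ts, te in think_regions)
        if not in_think:
            return model_output[:start], m.group(1).strip()
    return model_output, None

sample = (
    '<think>maybe <tool_calls>[{"name": "x"}]</tool_calls> here</think>'
    'Done. <tool_calls>[{"name": "search", "arguments": {}}]</tool_calls>'
)
# The block inside <think> is ignored; the later one is returned.
print(split_content_and_tool_calls(sample))
```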
if content: content = content.replace("助手:", "", 1) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=content + ) # Parse the potential tool calls as JSON tool_calls_data = json.loads(potential_tool_calls) @@ -120,8 +131,11 @@ class HunyuanA13BToolParser(ToolParser): tool_calls: list[ToolCall] = [] for idx, call in enumerate(tool_calls_data): - if (not isinstance(call, dict) or "name" not in call - or "arguments" not in call): + if ( + not isinstance(call, dict) + or "name" not in call + or "arguments" not in call + ): continue tool_call = ToolCall( @@ -129,8 +143,11 @@ class HunyuanA13BToolParser(ToolParser): type="function", function=FunctionCall( name=call["name"], - arguments=(json.dumps(call["arguments"]) if isinstance( - call["arguments"], dict) else call["arguments"]), + arguments=( + json.dumps(call["arguments"]) + if isinstance(call["arguments"], dict) + else call["arguments"] + ), ), ) tool_calls.append(tool_call) @@ -146,9 +163,9 @@ class HunyuanA13BToolParser(ToolParser): ) except Exception: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -166,10 +183,12 @@ class HunyuanA13BToolParser(ToolParser): start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_string): - start_idx = consume_space(start_idx + len(self.bot_string), - current_text) - if not current_text or start_idx >= len( - current_text) or current_text[start_idx] != '[': + start_idx = consume_space(start_idx + len(self.bot_string), current_text) + if ( + not current_text + or start_idx >= len(current_text) + or current_text[start_idx] != "[" + ): return DeltaMessage(content=delta_text) self._try_parse_json_tools(current_text[start_idx:]) @@ -185,13 +204,15 @@ class HunyuanA13BToolParser(ToolParser): self._ensure_state_arrays(tool_count) current_idx = self.streaming_state["current_tool_index"] - name_delta = self._handle_tool_name_streaming(current_idx, tool_count, - name_matches) + name_delta = self._handle_tool_name_streaming( + current_idx, tool_count, name_matches + ) if name_delta: return name_delta - args_delta = self._handle_tool_args_streaming(current_text, - current_idx, tool_count) + args_delta = self._handle_tool_args_streaming( + current_text, current_idx, tool_count + ) if args_delta: return args_delta @@ -207,166 +228,195 @@ class HunyuanA13BToolParser(ToolParser): def _handle_test_compatibility(self, current_text: str): if len(self.current_tools_sent) > 0: - if (len(self.current_tools_sent) == 1 - and self.current_tools_sent[0] is False): + if ( + len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False + ): name_match = self.tool_name_reg.search(current_text) if name_match: function_name = name_match.group(1) tool_id = f"chatcmpl-tool-{random_uuid()}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=0, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tools_sent = [True] self.current_tool_id = 0 self.streaming_state["current_tool_index"] = 
0 if len(self.streaming_state["sent_tools"]) == 0: - self.streaming_state["sent_tools"].append({ - "sent_name": - True, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": True, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) else: - self.streaming_state["sent_tools"][0][ - "sent_name"] = True + self.streaming_state["sent_tools"][0]["sent_name"] = True self.current_tool_name_sent = True return delta return None def _ensure_state_arrays(self, tool_count: int): while len(self.streaming_state["sent_tools"]) < tool_count: - self.streaming_state["sent_tools"].append({ - "sent_name": False, - "sent_arguments_prefix": False, - "sent_arguments": "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) while len(self.streaming_state["tool_ids"]) < tool_count: self.streaming_state["tool_ids"].append(None) - def _handle_tool_name_streaming(self, current_idx: int, tool_count: int, - name_matches): + def _handle_tool_name_streaming( + self, current_idx: int, tool_count: int, name_matches + ): if current_idx == -1 or current_idx < tool_count - 1: next_idx = current_idx + 1 - if (next_idx < tool_count - and not self.streaming_state["sent_tools"][next_idx] - ["sent_name"]): + if ( + next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx]["sent_name"] + ): self.streaming_state["current_tool_index"] = next_idx self.current_tool_id = next_idx current_idx = next_idx tool_name = name_matches[current_idx].group(1) tool_id = f"call_{current_idx}_{random_uuid()}" self.streaming_state["tool_ids"][current_idx] = tool_id - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall(name=tool_name).model_dump( - exclude_none=True), - ) - ]) - self.streaming_state["sent_tools"][current_idx][ - "sent_name"] = True + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), + ) + ] + ) + self.streaming_state["sent_tools"][current_idx]["sent_name"] = True self.current_tool_name_sent = True while len(self.streamed_args) <= current_idx: self.streamed_args.append("") return delta return None - def _handle_tool_args_streaming(self, current_text: str, current_idx: int, - tool_count: int): - + def _handle_tool_args_streaming( + self, current_text: str, current_idx: int, tool_count: int + ): if current_idx >= 0 and current_idx < tool_count: empty_args_match = self.tool_empty_arg_reg.search(current_text) if empty_args_match and empty_args_match.start() > 0: for i in range(tool_count): if i == current_idx: if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"]: + "sent_arguments_prefix" + ]: self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{}" + "sent_arguments" + ] = "{}" while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{}").model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + 
arguments="{}" + ).model_dump(exclude_none=True), + ) + ] + ) if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return delta - args_matches = list( - self.tool_non_empty_arg_reg.finditer(current_text)) + args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text)) if current_idx < len(args_matches): args_text = args_matches[current_idx].group(1) is_last_tool = current_idx == tool_count - 1 if not is_last_tool: next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) + "},{", args_matches[current_idx].start() + ) if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1) + args_end_pos = next_tool_pos + 1 args_text = ( - current_text[args_matches[current_idx].start( - ):args_end_pos].split('"arguments":')[1].strip()) + current_text[ + args_matches[current_idx].start() : args_end_pos + ] + .split('"arguments":')[1] + .strip() + ) sent_args = self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] + "sent_arguments" + ] if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] and args_text.startswith("{"): + "sent_arguments_prefix" + ] and args_text.startswith("{"): self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{" + "sent_arguments" + ] = "{" while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{").model_dump(exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall(arguments="{").model_dump( + exclude_none=True + ), + ) + ] + ) return delta if args_text.startswith(sent_args): - args_diff = args_text[len(sent_args):] + args_diff = args_text[len(sent_args) :] if args_diff: self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = args_text + "sent_arguments" + ] = args_text while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += args_diff - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments=args_diff).model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff + ).model_dump(exclude_none=True), + ) + ] + ) return delta if args_text.endswith("}") and args_text == sent_args: if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return None diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 6ef8fadf59ac5..9adaea297b05f 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -9,15 +9,20 @@ import partial_json_parser from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - 
ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -26,16 +31,14 @@ logger = init_logger(__name__) @ToolParserManager.register_module(["internlm"]) class Internlm2ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.position = 0 - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": # do not skip special tokens because internlm use the special - # tokens to indicated the start and end of the tool calls + # tokens to indicate the start and end of the tool calls # information. request.skip_special_tokens = False return request @@ -57,45 +60,43 @@ class Internlm2ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if '<|action_start|>' not in current_text: + if "<|action_start|>" not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - # if the tool call is sended, return a empty delta message - # to make sure the finish_reason will be send correctly. + # if the tool call is sent, return an empty delta message + # to make sure the finish_reason will be sent correctly. if self.current_tool_id > 0: - return DeltaMessage(content='') + return DeltaMessage(content="") last_pos = self.position - if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + if "<|action_start|><|plugin|>" not in current_text[last_pos:]: return None new_delta = current_text[last_pos:] - text, action = new_delta.split('<|action_start|><|plugin|>') + text, action = new_delta.split("<|action_start|><|plugin|>") if len(text) > 0: self.position = self.position + len(text) return DeltaMessage(content=text) action = action.strip() - action = action.split('<|action_end|>'.strip())[0] + action = action.split("<|action_end|>".strip())[0] # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
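Since the `Allow` flag trick recurs in several of these parsers, a tiny self-contained illustration may help; the input string is invented, and the printed values are only what the surrounding comments lead us to expect:

```python
import partial_json_parser
from partial_json_parser.core.options import Allow
from partial_json_parser.core.exceptions import MalformedJSON

chunk = '{"name": "get_wea'   # made-up, half-generated tool call JSON

for label, flags in [
    ("name not sent yet", Allow.ALL & ~Allow.STR),  # withhold partial strings
    ("name already sent", Allow.ALL),               # partial strings are fine
]:
    try:
        parsed = partial_json_parser.loads(chunk, flags)
    except MalformedJSON:
        parsed = None  # not enough tokens to parse into JSON yet
    # With ~Allow.STR the half-finished name should be withheld (or parsing
    # fails); with Allow.ALL it comes back as-is for argument streaming.
    print(label, "->", parsed)
```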
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: parsable_arr = action - # tool calls are generated in an object in inernlm2 + # tool calls are generated in an object in internlm2 # it's not support parallel tool calls try: - tool_call_arr: dict = partial_json_parser.loads( - parsable_arr, flags) + tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # if the current tool name hasn't been sent, send if available @@ -104,14 +105,18 @@ class Internlm2ToolParser(ToolParser): function_name = tool_call_arr.get("name") if function_name: self.current_tool_id = self.current_tool_id + 1 - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True self.streamed_args_for_tool.append("") else: @@ -120,7 +125,8 @@ class Internlm2ToolParser(ToolParser): # arguments else: prev_arguments = self.get_arguments( - self.prev_tool_call_arr[self.current_tool_id]) + self.prev_tool_call_arr[self.current_tool_id] + ) cur_arguments = self.get_arguments(tool_call_arr) # not arguments generated @@ -129,43 +135,47 @@ class Internlm2ToolParser(ToolParser): # will never happen elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None # first time to get parameters elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) - arguments_delta = cur_arguments_json[:cur_arguments_json. - index(delta_text) + - len(delta_text)] - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + arguments_delta = cur_arguments_json[ + : cur_arguments_json.index(delta_text) + len(delta_text) + ] + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta # both prev and cur parameters, send the increase parameters elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) + cur_args_json, prev_args_json + ) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff # check to see if the name is defined and has been sent. if so, # stream the name - otherwise keep waiting @@ -176,8 +186,8 @@ class Internlm2ToolParser(ToolParser): except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None def extract_tool_calls( @@ -187,30 +197,33 @@ class Internlm2ToolParser(ToolParser): ) -> ExtractedToolCallInformation: text = model_output tools = request.tools - if '<|action_start|><|plugin|>' in text: - text, action = text.split('<|action_start|><|plugin|>') - action = action.split('<|action_end|>'.strip())[0] - action = action[action.find('{'):] + if "<|action_start|><|plugin|>" in text: + text, action = text.split("<|action_start|><|plugin|>") + action = action.split("<|action_end|>".strip())[0] + action = action[action.find("{") :] action_dict = json.loads(action) - name, parameters = action_dict['name'], json.dumps( - action_dict.get('parameters', action_dict.get('arguments', - {})), - ensure_ascii=False) + name, parameters = ( + action_dict["name"], + json.dumps( + action_dict.get("parameters", action_dict.get("arguments", {})), + ensure_ascii=False, + ), + ) if not tools or name not in [t.function.name for t in tools]: - ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=text) + ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=text + ) tool_calls = [ - ToolCall( - function=FunctionCall(name=name, arguments=parameters)) + ToolCall(function=FunctionCall(name=name, arguments=parameters)) ] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=text if len(text) > 0 else None) + content=text if len(text) > 0 else None, + ) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=text) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=text + ) diff --git 
a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 3b41f6034704c..1ae3e0da33513 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -10,14 +10,17 @@ import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer @@ -27,7 +30,6 @@ logger = init_logger(__name__) @ToolParserManager.register_module("jamba") class JambaToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -39,33 +41,35 @@ class JambaToolParser(ToolParser): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<tool_calls>" self.tool_calls_end_token: str = "</tool_calls>" self.tool_calls_regex = re.compile( - rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", - re.DOTALL) + rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "Jamba Tool parser could not locate tool calls start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": # do not skip special tokens because jamba use the special # tokens to indicate the start and end of the tool calls # information. 
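To summarize the Jamba extraction path in isolation: the non-streaming code finds the `<tool_calls>...</tool_calls>` block, parses the JSON array inside it, and re-serializes each call's arguments back to a JSON string. A condensed sketch under those assumptions, with an invented sample output:

```python
import json
import re

TOOL_CALLS_RE = re.compile(r"<tool_calls>(.*?)</tool_calls>", re.DOTALL)

def extract_jamba_style_tool_calls(model_output: str):
    """Return (content, [{"name": ..., "arguments": <JSON string>}, ...])."""
    match = TOOL_CALLS_RE.search(model_output)
    if match is None:
        return model_output, []
    raw_calls = json.loads(match.group(1))
    calls = [
        {
            "name": call["name"],
            # arguments stay a JSON *string*, mirroring FunctionCall.arguments
            "arguments": json.dumps(call["arguments"], ensure_ascii=False),
        }
        for call in raw_calls
    ]
    content = model_output[: match.start()] or None
    return content, calls

sample = (
    'Sure.<tool_calls>'
    '[{"name": "get_weather", "arguments": {"city": "Paris"}}]'
    '</tool_calls>'
)
print(extract_jamba_style_tool_calls(sample))
```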
@@ -73,17 +77,15 @@ class JambaToolParser(ToolParser): return request def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: - + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: - try: # use a regex to find the tool call between the tags function_calls = self.tool_calls_regex.findall(model_output)[0] @@ -97,25 +99,26 @@ class JambaToolParser(ToolParser): function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), - )) for function_call in raw_function_calls + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) + for function_call in raw_function_calls ] - content = model_output[:model_output. - find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if - (len(content) > 0 and content != " ") else None) + content=content if (len(content) > 0 and content != " ") else None, + ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -127,7 +130,6 @@ class JambaToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.tool_calls_start_token not in current_text: @@ -138,8 +140,10 @@ class JambaToolParser(ToolParser): # handle if we detected the start of tool calls token which means # the start of tool calling - if (self.tool_calls_start_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if ( + self.tool_calls_start_token_id in delta_token_ids + and len(delta_token_ids) == 1 + ): # if it's the only token, return None, so we don't send a chat # completion and don't send a control token return None @@ -148,28 +152,28 @@ class JambaToolParser(ToolParser): # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # Extract the tool calls between the special tool call tokens - parsable_arr = current_text.split( - self.tool_calls_start_token)[-1].split( - self.tool_calls_end_token)[0] + parsable_arr = current_text.split(self.tool_calls_start_token)[-1].split( + self.tool_calls_end_token + )[0] # tool calls are generated in an array, so do partial JSON # parsing on the entire array try: tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) + parsable_arr, flags + ) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -178,9 +182,9 @@ class JambaToolParser(ToolParser): # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -190,16 +194,19 @@ class JambaToolParser(ToolParser): if diff: diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff + self.streamed_args_for_tool[self.current_tool_id], "" + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += diff else: delta = None else: @@ -218,15 +225,18 @@ class JambaToolParser(ToolParser): if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -234,60 +244,66 @@ class JambaToolParser(ToolParser): # now we know we're on the same tool call and we're streaming # arguments else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) cur_arguments = current_tool_call.get("arguments") - new_text = delta_text.replace("\'", "\"") + new_text = delta_text.replace("'", '"') 
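One detail from the Jamba streaming hunk worth calling out: when the partially parsed array grows past the current tool index, any argument text that the partial JSON parser auto-completed but that was never streamed is flushed before moving on to the next tool. A simplified stand-in for that bookkeeping:

```python
import json

def flush_previous_tool(prev_tool_args: dict, already_streamed: str) -> str:
    """When the model starts a new tool call, emit whatever tail of the
    previous call's arguments was auto-completed by partial JSON parsing
    but never streamed (simplified; the real code diffs against its own
    streamed_args_for_tool state)."""
    full_json = json.dumps(prev_tool_args, ensure_ascii=False)
    # Drop what was already sent; whatever remains is the unstreamed tail.
    return full_json.replace(already_streamed, "", 1)

print(flush_previous_tool({"city": "Paris"}, '{"city": "Par'))   # -> is"}
```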
if not cur_arguments and not prev_arguments: - delta = None elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - logger.debug("finding %s in %s", new_text, - cur_arguments_json) + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) + logger.debug("finding %s in %s", new_text, cur_arguments_json) - arguments_delta = cur_arguments_json[:cur_arguments_json. - index(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + arguments_delta = cur_arguments_json[ + : cur_arguments_json.index(new_text) + len(new_text) + ] + logger.debug( + "First tokens in arguments received: %s", arguments_delta + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) + logger.debug( + "Searching for diff between \n%s\n%s", + cur_args_json, + prev_args_json, + ) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) + cur_args_json, prev_args_json + ) logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff else: # try parsing it with regular JSON - if it works we're # at the end, and we need to send the difference between @@ -303,6 +319,6 @@ class JambaToolParser(ToolParser): except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 834b33052b45d..a2eff21a44667 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -7,13 +7,19 @@ from typing import Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - 
ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,14 +28,14 @@ logger = init_logger(__name__) @ToolParserManager.register_module(["kimi_k2"]) class KimiK2ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool_calls_section_begin|>" self.tool_calls_end_token: str = "<|tool_calls_section_end|>" @@ -45,39 +51,38 @@ class KimiK2ToolParser(ToolParser): r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)" ) - self.stream_tool_call_name_regex = re.compile( - r"(?P<tool_call_id>.+:\d+)\s*") + self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*") if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "Kimi-K2 Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" 
+ ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -85,8 +90,7 @@ class KimiK2ToolParser(ToolParser): # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) logger.debug("function_call_tuples: %s", function_call_tuples) @@ -94,17 +98,18 @@ class KimiK2ToolParser(ToolParser): for match in function_call_tuples: function_id, function_args = match # function_id: functions.get_weather:0 - function_name = function_id.split('.')[1].split(':')[0] + function_name = function_id.split(".")[1].split(":")[0] tool_calls.append( ToolCall( id=function_id, - type='function', - function=FunctionCall(name=function_name, - arguments=function_args), - )) + type="function", + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. - find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -112,11 +117,10 @@ class KimiK2ToolParser(ToolParser): ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -128,55 +132,58 @@ class KimiK2ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and 
self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -190,27 +197,29 @@ class KimiK2ToolParser(ToolParser): logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
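The Kimi-K2 streaming path keys its state machine off counts of the tool-call start and end token ids in the previous versus current token sequences. A rough, simplified sketch of that classification (token ids 7 and 8 are invented, and the branch ordering is condensed relative to the real parser):

```python
def classify_chunk(prev_ids, cur_ids, start_id, end_id):
    """Rough sketch: compare how many tool-call start/end markers were seen
    before this chunk vs. including it."""
    prev_starts, prev_ends = prev_ids.count(start_id), prev_ids.count(end_id)
    cur_starts, cur_ends = cur_ids.count(start_id), cur_ids.count(end_id)

    if cur_starts == cur_ends == prev_starts == prev_ends:
        return "plain text"                      # no new tool-call markers
    if cur_starts > cur_ends and cur_starts > prev_starts:
        return "starting a new tool call"
    if cur_starts > cur_ends and cur_starts == prev_starts:
        return "updating the current tool call"
    return "closing the current tool call"

# Hypothetical ids: 7 = <|tool_call_begin|>, 8 = <|tool_call_end|>
print(classify_chunk([1, 2, 7], [1, 2, 7, 3, 4], start_id=7, end_id=8))
# -> updating the current tool call
```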
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -221,13 +230,16 @@ class KimiK2ToolParser(ToolParser): diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -238,23 +250,23 @@ class KimiK2ToolParser(ToolParser): current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: - tool_id, tool_args = (current_tool_call_matches.groups()) - tool_name = tool_id.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id + tool_id, tool_args = current_tool_call_matches.groups() + tool_name = tool_id.split(".")[1].split(":")[0] + current_tool_call["id"] = tool_id current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: - tool_id_str, = current_tool_call_name_matches.groups() - tool_name = tool_id_str.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id_str + (tool_id_str,) = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split(".")[1].split(":")[0] + current_tool_call["id"] = tool_id_str current_tool_call["name"] = tool_name current_tool_call["arguments"] = "" else: @@ -270,16 +282,18 @@ class KimiK2ToolParser(ToolParser): tool_id = current_tool_call.get("id") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -289,15 +303,19 @@ class KimiK2ToolParser(ToolParser): if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if 
text_portion is not None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. - logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -307,7 +325,8 @@ class KimiK2ToolParser(ToolParser): # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -321,52 +340,56 @@ class KimiK2ToolParser(ToolParser): # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments - - # last case -- we have an update to existing arguments. - elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] - logger.debug("got diff %s", delta_text) - - delta = DeltaMessage(tool_calls=[ + delta = DeltaMessage( + tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), + arguments=cur_arguments + ).model_dump(exclude_none=True), ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. 
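Because Kimi-K2 keeps `arguments` as a raw string captured by the regex rather than a parsed dict, the incremental-update case below reduces to a prefix check, roughly:

```python
def argument_suffix_delta(prev_args: str, cur_args: str):
    """Kimi-K2 style update: arguments arrive as a raw (possibly partial)
    JSON string, so the newly generated part is simply the suffix that
    extends what was already streamed. Returns None if nothing safe to send."""
    if cur_args != prev_args and cur_args.startswith(prev_args):
        return cur_args[len(prev_args):]
    return None

print(argument_suffix_delta('{"city": "Par', '{"city": "Paris", "unit'))
# -> is", "unit
```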
+ elif cur_arguments and prev_arguments: + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 6bf44a4345a9d..162675efbc9a7 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -9,13 +9,19 @@ import regex as re from transformers import PreTrainedTokenizerBase import vllm.envs as envs -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -31,6 +37,7 @@ class Llama4PythonicToolParser(ToolParser): Toolcall parser for Llama4 that produce tool calls in a pythonic style Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic """ + # TODO(mdepinet): Possible future improvements: # 1. Support text + tools separated by either <|python_tag|> or \n\n # 2. Support tools outside of a list (or separated by a semicolon). @@ -40,7 +47,8 @@ class Llama4PythonicToolParser(ToolParser): TOOL_CALL_REGEX = re.compile( r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", - re.DOTALL) + re.DOTALL, + ) def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) @@ -55,8 +63,8 @@ class Llama4PythonicToolParser(ToolParser): self.current_tool_id = value def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. 
""" @@ -64,46 +72,52 @@ class Llama4PythonicToolParser(ToolParser): # remove <|python_start|> and <|python_end|> # as Llama 4 model sometime will output those tokens if model_output.startswith("<|python_start|>"): - model_output = model_output[len("<|python_start|>"):] + model_output = model_output[len("<|python_start|>") :] model_output = model_output.replace("<|python_end|>", "") is_tool_call_pattern = False try: - is_tool_call_pattern = self.TOOL_CALL_REGEX.match( - model_output, - timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ) + is not None + ) except TimeoutError: - logger.warning( - "Regex timeout occurred when matching tool call pattern.") - logger.debug("Regex timeout occurred when matching user input: %s", - model_output) + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) if not is_tool_call_pattern: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: module = ast.parse(model_output) parsed = getattr(module.body[0], "value", None) if isinstance(parsed, ast.List) and all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): return ExtractedToolCallInformation( tools_called=True, tool_calls=[ _handle_single_tool(e) # type: ignore for e in parsed.elts ], - content=None) + content=None, + ) else: raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -115,18 +129,17 @@ class Llama4PythonicToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if not current_text.startswith("[") and not current_text.startswith( - "<|python_start|>"): + "<|python_start|>" + ): return DeltaMessage(content=delta_text) try: # remove <|python_start|> and <|python_end|> if current_text.startswith("<|python_start|>"): - current_text = current_text[len("<|python_start|>"):] + current_text = current_text[len("<|python_start|>") :] if current_text.endswith("<|python_end|>"): - current_text = current_text[:current_text. 
- rfind("<|python_end|>")] + current_text = current_text[: current_text.rfind("<|python_end|>")] valid_and_added_text = _make_valid_python(current_text) if valid_and_added_text is None: return None @@ -135,9 +148,11 @@ class Llama4PythonicToolParser(ToolParser): module = ast.parse(valid_text) parsed = getattr(module.body[0], "value", None) if not isinstance(parsed, ast.List) or not all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) tool_calls = [ _handle_single_tool(e) # type: ignore for e in parsed.elts @@ -152,34 +167,36 @@ class Llama4PythonicToolParser(ToolParser): if len(self.streamed_args_for_tool) == index: self.streamed_args_for_tool.append("") - new_call_complete = index < len( - tool_calls) - 1 or ")]" not in added_text + new_call_complete = ( + index < len(tool_calls) - 1 or ")]" not in added_text + ) if new_call_complete: self.current_tool_index += 1 - withheld_suffix = (added_text[:-2] - if not new_call_complete else "") + withheld_suffix = added_text[:-2] if not new_call_complete else "" if not new_call_complete and added_text[-2] == ")": # Function call is incomplete. Withhold the closing bracket. withheld_suffix = withheld_suffix + "}" # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta(self.streamed_args_for_tool[index], - new_call, index, withheld_suffix) + delta = _compute_tool_delta( + self.streamed_args_for_tool[index], new_call, index, withheld_suffix + ) if delta is not None: tool_deltas.append(delta) - if (delta.function is not None - and delta.function.arguments is not None): - self.streamed_args_for_tool[ - index] += delta.function.arguments + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments - # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically - # adding autocompleted JSON. - # These two lines avoid that nonsense while ensuring finish_reason - # is set to tool_calls when at least one tool is called. + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining its final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. if tool_deltas and not self.prev_tool_call_arr: self.prev_tool_call_arr = [{"arguments": {}}] @@ -188,14 +205,14 @@ class Llama4PythonicToolParser(ToolParser): elif not added_text and self.current_tool_id > 0: # Return an empty DeltaMessage once the tool calls are all done # so that finish_reason gets set. 
- return DeltaMessage(content='') + return DeltaMessage(content="") else: return None except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None @@ -204,8 +221,7 @@ def _get_parameter_value(val: ast.expr) -> Any: return val.value elif isinstance(val, ast.Dict): if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError( - "Dict tool call arguments must have literal keys") + raise _UnexpectedAstError("Dict tool call arguments must have literal keys") return { k.value: _get_parameter_value(v) # type: ignore for k, v in zip(val.keys, val.values) @@ -223,9 +239,10 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments = {} for keyword in call.keywords: arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall(type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, arguments=json.dumps(arguments)), + ) def _make_valid_python(text: str) -> Union[tuple[str, str], None]: @@ -261,21 +278,25 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: # we can't fill in a valid value. return None if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[:text.rfind("{")] + trailing_dict_text = text[: text.rfind("{")] num_keys = trailing_dict_text.count(":") num_values = trailing_dict_text.count(",") if num_keys <= num_values: return None # Incomplete property name within parameter value if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[:text.rfind("(")] + trailing_params_text = text[: text.rfind("(")] num_full_param_names = trailing_params_text.count("=") num_full_param_values = trailing_params_text.count(",") if num_full_param_names <= num_full_param_values: return None # Incomplete parameter name if text.endswith(","): text = text[:-1] - if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( - "[") and not text.endswith(")"): + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): return None # Incomplete function name added_text = "" @@ -294,23 +315,29 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: return text + added_text, added_text -def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, - index: int, - withheld_suffix: str) -> Union[DeltaToolCall, None]: +def _compute_tool_delta( + previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str +) -> Union[DeltaToolCall, None]: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[:-len(withheld_suffix)] + new_call_args = new_call_args[: -len(withheld_suffix)] if not previously_sent_args: - return DeltaToolCall(id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - )) + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) - arg_diff = new_call_args[len(previously_sent_args):] - return DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall( - arguments=arg_diff)) if arg_diff else None + arg_diff = 
new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) + ) + if arg_diff + else None + ) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 31b19c8db4163..4d5ef5ed64aa2 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -11,16 +11,24 @@ from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,7 +41,7 @@ class Llama3JsonToolParser(ToolParser): Tool call parser for Llama 3.x and 4 models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser llama3_json or + Used when --enable-auto-tool-choice --tool-call-parser llama3_json or llama4_json are set. """ @@ -45,42 +53,45 @@ class Llama3JsonToolParser(ToolParser): self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token = "<|python_tag|>" - self.bot_token_id = tokenizer.encode(self.bot_token, - add_special_tokens=False)[0] + self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[ + 0 + ] # Updated regex to match multiple JSONs separated by semicolons # This pattern is more robust and can handle nested JSON objects self.tool_call_regex = re.compile( - r'{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*', - re.DOTALL) + r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*", + re.DOTALL, + ) def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Only extracts JSON content and ignores any surrounding plain text. Supports both single JSON and multiple JSONs separated by semicolons. 
""" # Quick check before running regex - if not (self.bot_token in model_output or '{' in model_output): - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + if not (self.bot_token in model_output or "{" in model_output): + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # Find JSON object(s) in the text using regex match = self.tool_call_regex.search(model_output) if not match: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: json_str = match.group(0) # Split by semicolon and strip whitespace - json_objects = [obj.strip() for obj in json_str.split(';')] + json_objects = [obj.strip() for obj in json_str.split(";")] tool_calls: list[ToolCall] = [] for json_obj in json_objects: @@ -95,19 +106,24 @@ class Llama3JsonToolParser(ToolParser): # function call args are JSON but as a string arguments=json.dumps( obj["arguments"] - if "arguments" in obj else obj["parameters"], - ensure_ascii=False)))) + if "arguments" in obj + else obj["parameters"], + ensure_ascii=False, + ), + ), + ) + ) - return ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=None) + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=None + ) except Exception: logger.exception("Error in extracting tool call from response.") # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -119,47 +135,49 @@ class Llama3JsonToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - - if not (current_text.startswith(self.bot_token) - or current_text.startswith('{')): + if not ( + current_text.startswith(self.bot_token) or current_text.startswith("{") + ): return DeltaMessage(content=delta_text) # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = [] is_complete = [] try: # depending on the prompt format the Llama model may or may not # prefix the output with the <|python_tag|> token - start_idx = len(self.bot_token) if current_text.startswith( - self.bot_token) else 0 + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) while start_idx < len(current_text): - (obj, - end_idx) = partial_json_loads(current_text[start_idx:], - flags) + (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags) is_complete.append( - is_complete_json(current_text[start_idx:start_idx + - end_idx])) - start_idx += end_idx + len('; ') + is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") # depending on the prompt Llama can use # either arguments or parameters if "parameters" in obj: - assert "arguments" not in obj, \ + assert "arguments" not in obj, ( "model generated both parameters and arguments" + ) obj["arguments"] = obj["parameters"] tool_call_arr.append(obj) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -168,9 +186,9 @@ class Llama3JsonToolParser(ToolParser): # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -178,21 +196,24 @@ class Llama3JsonToolParser(ToolParser): if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) else: delta = None else: @@ -209,15 +230,18 @@ class Llama3JsonToolParser(ToolParser): elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -229,34 +253,35 @@ class Llama3JsonToolParser(ToolParser): delta = None if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -264,6 +289,6 @@ class Llama3JsonToolParser(ToolParser): except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py new file mode 100644 index 0000000000000..1dc1a0290c8d9 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import regex as re + +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParserManager +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +@ToolParserManager.register_module("longcat") +class LongcatFlashToolParser(Hermes2ProToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.tool_call_start_token: str = "<longcat_tool_call>" + self.tool_call_end_token: str = "</longcat_tool_call>" + + self.tool_call_regex = re.compile( + r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)", + re.DOTALL, + ) + + self.tool_call_start_token_ids = self.model_tokenizer.encode( + self.tool_call_start_token, add_special_tokens=False + ) + self.tool_call_end_token_ids = self.model_tokenizer.encode( + self.tool_call_end_token, add_special_tokens=False + ) + + self.tool_call_start_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_start_token_ids + ] + + self.tool_call_end_token_array = [ + self.model_tokenizer.decode([token_id]) + for token_id in self.tool_call_end_token_ids + ] diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 283e6095013d6..0b83fd237a6a7 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -8,15 +8,20 @@ from typing import Any, Optional, Union import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import 
AnyTokenizer @@ -25,7 +30,6 @@ logger = init_logger(__name__) @ToolParserManager.register_module("minimax") class MinimaxToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -40,7 +44,8 @@ class MinimaxToolParser(ToolParser): self.tool_call_start_token = "<tool_calls>" self.tool_call_end_token = "</tool_calls>" self.tool_call_regex = re.compile( - r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL) + r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL + ) self.thinking_tag_pattern = r"<think>(.*?)</think>" self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"') self.tool_args_pattern = re.compile(r'"arguments":\s*') @@ -52,50 +57,51 @@ class MinimaxToolParser(ToolParser): if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." + ) # Get token IDs for tool call start/end tokens - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: logger.warning( "Minimax Tool parser could not locate tool call start/end " - "tokens in the tokenizer. Falling back to string matching.") + "tokens in the tokenizer. Falling back to string matching." + ) def preprocess_model_output(self, model_output: str) -> str: """ Preprocess model output by removing tool calls from thinking tags. - + Args: model_output: Raw model output string - + Returns: Preprocessed model output with tool calls removed from thinking tags """ def remove_tool_calls_from_think(match): think_content = match.group(1) - cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>", - "", - think_content, - flags=re.DOTALL) + cleaned_content = re.sub( + r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL + ) return f"<think>{cleaned_content}</think>" - return re.sub(self.thinking_tag_pattern, - remove_tool_calls_from_think, - model_output, - flags=re.DOTALL) + return re.sub( + self.thinking_tag_pattern, + remove_tool_calls_from_think, + model_output, + flags=re.DOTALL, + ) def _clean_duplicate_braces(self, args_text: str) -> str: """ Clean duplicate closing braces from arguments text. - + Args: args_text: Raw arguments text - + Returns: Cleaned arguments text with proper JSON formatting """ @@ -109,7 +115,7 @@ class MinimaxToolParser(ToolParser): except json.JSONDecodeError: pass - while args_text.endswith('}}'): + while args_text.endswith("}}"): candidate = args_text[:-1] try: json.loads(candidate) @@ -122,10 +128,10 @@ class MinimaxToolParser(ToolParser): def _clean_delta_braces(self, delta_text: str) -> str: """ Clean delta text by removing excessive closing braces. 
- + Args: delta_text: Delta text to clean - + Returns: Cleaned delta text """ @@ -134,10 +140,10 @@ class MinimaxToolParser(ToolParser): delta_stripped = delta_text.strip() - if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped): - brace_count = delta_stripped.count('}') + if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped): + brace_count = delta_stripped.count("}") if brace_count > 1: - return '}\n' if delta_text.endswith('\n') else '}' + return "}\n" if delta_text.endswith("\n") else "}" return delta_text @@ -148,34 +154,32 @@ class MinimaxToolParser(ToolParser): ) -> ExtractedToolCallInformation: """ Extract tool calls from model output for non-streaming mode. - + Args: model_output: Complete model output request: Chat completion request - + Returns: ExtractedToolCallInformation containing tool calls and content """ processed_output = self.preprocess_model_output(model_output) if self.tool_call_start_token not in processed_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: - function_call_tuples = self.tool_call_regex.findall( - processed_output) + function_call_tuples = self.tool_call_regex.findall(processed_output) raw_function_calls = [] for match in function_call_tuples: tool_call_content = match[0] if match[0] else match[1] if tool_call_content.strip(): - lines = tool_call_content.strip().split('\n') + lines = tool_call_content.strip().split("\n") for line in lines: line = line.strip() - if line and line.startswith('{') and line.endswith( - '}'): + if line and line.startswith("{") and line.endswith("}"): try: parsed_call = json.loads(line) raw_function_calls.append(parsed_call) @@ -186,25 +190,29 @@ class MinimaxToolParser(ToolParser): for function_call in raw_function_calls: if "name" in function_call and "arguments" in function_call: tool_calls.append( - ToolCall(type="function", - function=FunctionCall( - name=function_call["name"], - arguments=json.dumps( - function_call["arguments"], - ensure_ascii=False)))) + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) + ) processed_pos = processed_output.find(self.tool_call_start_token) if processed_pos != -1: processed_content = processed_output[:processed_pos].strip() if processed_content: - lines = processed_content.split('\n') + lines = processed_content.split("\n") for line in reversed(lines): line = line.strip() if line: pos = model_output.find(line) if pos != -1: - content = model_output[:pos + len(line)] + content = model_output[: pos + len(line)] break else: content = "" @@ -216,68 +224,74 @@ class MinimaxToolParser(ToolParser): return ExtractedToolCallInformation( tools_called=len(tool_calls) > 0, tool_calls=tool_calls, - content=content.strip() if content.strip() else None) + content=content.strip() if content.strip() else None, + ) except Exception: logger.exception( - "An unexpected error occurred during tool call extraction.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + "An unexpected error occurred during tool call extraction." + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def _update_thinking_state(self, text: str) -> None: """ Update the thinking tag state based on text content. 
- + Args: text: Text to analyze for thinking tags """ open_count = text.count("<think>") close_count = text.count("</think>") self.in_thinking_tag = open_count > close_count or ( - open_count == close_count and text.endswith("</think>")) + open_count == close_count and text.endswith("</think>") + ) def _is_potential_tag_start(self, text: str) -> bool: """ Check if text might be the start of a tool call tag. - + Args: text: Text to check - + Returns: True if text could be the start of a tool call tag """ for tag in [self.tool_call_start_token, self.tool_call_end_token]: if any( - tag.startswith(text[-i:]) - for i in range(1, min(len(text) + 1, len(tag)))): + tag.startswith(text[-i:]) + for i in range(1, min(len(text) + 1, len(tag))) + ): return True return False def _should_buffer_content(self, delta_text: str) -> bool: """ Determine if content should be buffered for later processing. - + Args: delta_text: Delta text to check - + Returns: True if content should be buffered """ if self.in_thinking_tag: return False - return bool(self.pending_buffer - or self.tool_call_start_token in delta_text - or self.tool_call_end_token in delta_text - or delta_text.startswith('<')) + return bool( + self.pending_buffer + or self.tool_call_start_token in delta_text + or self.tool_call_end_token in delta_text + or delta_text.startswith("<") + ) def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]: """ Split delta text into safe content and potential tag content. - + Args: delta_text: Delta text to split - + Returns: Tuple of (safe_content, potential_tag_content) """ @@ -295,10 +309,10 @@ class MinimaxToolParser(ToolParser): def _process_buffer(self, new_content: str) -> str: """ Process buffered content and return output content. - + Args: new_content: New content to add to buffer - + Returns: Processed output content """ @@ -326,7 +340,7 @@ class MinimaxToolParser(ToolParser): break output_content += self.pending_buffer[:tag_pos] - self.pending_buffer = self.pending_buffer[tag_pos + tag_len:] + self.pending_buffer = self.pending_buffer[tag_pos + tag_len :] return output_content @@ -340,13 +354,14 @@ class MinimaxToolParser(ToolParser): def _advance_to_next_tool(self) -> None: """Advance to the next tool in the streaming sequence.""" - self.streaming_state["current_tool_index"] = int( - self.streaming_state["current_tool_index"]) + 1 + self.streaming_state["current_tool_index"] = ( + int(self.streaming_state["current_tool_index"]) + 1 + ) def _set_current_tool_index(self, index: int) -> None: """ Set the current tool index. - + Args: index: Tool index to set """ @@ -355,7 +370,7 @@ class MinimaxToolParser(ToolParser): def _get_current_tool_index(self) -> int: """ Get the current tool index. - + Returns: Current tool index """ @@ -364,10 +379,10 @@ class MinimaxToolParser(ToolParser): def _get_next_unsent_tool_index(self, tool_count: int) -> int: """ Get the index of the next unsent tool. - + Args: tool_count: Total number of tools - + Returns: Index of next unsent tool, or -1 if all tools sent """ @@ -383,7 +398,7 @@ class MinimaxToolParser(ToolParser): def _ensure_state_arrays(self, tool_count: int) -> None: """ Ensure state arrays have sufficient capacity for tool_count tools. 
- + Args: tool_count: Number of tools to prepare for """ @@ -391,11 +406,13 @@ class MinimaxToolParser(ToolParser): tool_ids = list(self.streaming_state["tool_ids"]) while len(sent_tools) < tool_count: - sent_tools.append({ - "sent_name": False, - "sent_arguments": "", - "id": make_tool_call_id(), - }) + sent_tools.append( + { + "sent_name": False, + "sent_arguments": "", + "id": make_tool_call_id(), + } + ) while len(tool_ids) < tool_count: tool_ids.append(None) @@ -406,10 +423,10 @@ class MinimaxToolParser(ToolParser): def _detect_tools_in_text(self, text: str) -> int: """ Detect the number of tools in text by counting name patterns. - + Args: text: Text to analyze - + Returns: Number of tools detected """ @@ -419,26 +436,26 @@ class MinimaxToolParser(ToolParser): def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: """ Find the boundaries of tool calls in text. - + Args: text: Text to analyze - + Returns: List of (start, end) positions for tool calls """ boundaries = [] i = 0 while i < len(text): - if text[i] == '{': + if text[i] == "{": start = i depth = 0 has_name = False has_arguments = False while i < len(text): - if text[i] == '{': + if text[i] == "{": depth += 1 - elif text[i] == '}': + elif text[i] == "}": depth -= 1 if depth == 0: end = i + 1 @@ -447,10 +464,9 @@ class MinimaxToolParser(ToolParser): boundaries.append((start, end)) break - if not has_name and '"name"' in text[start:i + 1]: + if not has_name and '"name"' in text[start : i + 1]: has_name = True - if not has_arguments and '"arguments"' in text[start:i + - 1]: + if not has_arguments and '"arguments"' in text[start : i + 1]: has_arguments = True i += 1 @@ -461,46 +477,46 @@ class MinimaxToolParser(ToolParser): i += 1 return boundaries - def _extract_tool_args(self, tool_content: str, args_match) -> str: + def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str: """ Extract tool arguments from tool content. - + Args: tool_content: Tool call content args_match: Regex match for arguments pattern - + Returns: Extracted arguments as string """ args_start_pos = args_match.end() remaining_content = tool_content[args_start_pos:] - if remaining_content.strip().startswith('{'): + if remaining_content.strip().startswith("{"): depth = 0 for i, char in enumerate(remaining_content): - if char == '{': + if char == "{": depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0: - return remaining_content[:i + 1] + return remaining_content[: i + 1] else: - args_end = remaining_content.find('}') + args_end = remaining_content.find("}") if args_end > 0: return remaining_content[:args_end].strip() - return remaining_content.rstrip('}').strip() + return remaining_content.rstrip("}").strip() def _get_current_tool_content( - self, text: str, - tool_index: int) -> tuple[Optional[str], Optional[str]]: + self, text: str, tool_index: int + ) -> tuple[Optional[str], Optional[str]]: """ Get the content of a specific tool by index. 
- + Args: text: Text containing tool calls tool_index: Index of tool to extract - + Returns: Tuple of (tool_name, tool_arguments) or (None, None) if not found """ @@ -521,22 +537,22 @@ class MinimaxToolParser(ToolParser): args_text = self._extract_tool_args(tool_content, args_match) return name, args_text except Exception: - remaining_content = tool_content[args_match.end():] - args_text = remaining_content.rstrip('}').strip() + remaining_content = tool_content[args_match.end() :] + args_text = remaining_content.rstrip("}").strip() return name, args_text return name, None def _handle_tool_name_streaming( - self, tool_content: str, - tool_count: int) -> Union[DeltaMessage, None]: + self, tool_content: str, tool_count: int + ) -> Union[DeltaMessage, None]: """ Handle streaming of tool names. - + Args: tool_content: Content containing tool calls tool_count: Total number of tools - + Returns: DeltaMessage with tool name or None if no tool to stream """ @@ -564,24 +580,29 @@ class MinimaxToolParser(ToolParser): self.streaming_state["sent_tools"] = sent_tools self.streaming_state["tool_ids"] = tool_ids - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=next_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=tool_name).model_dump(exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=next_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), + ) + ] + ) def _handle_tool_args_streaming( - self, tool_content: str, - tool_count: int) -> Union[DeltaMessage, None]: + self, tool_content: str, tool_count: int + ) -> Union[DeltaMessage, None]: """ Handle streaming of tool arguments. - + Args: tool_content: Content containing tool calls tool_count: Total number of tools - + Returns: DeltaMessage with tool arguments or None if no arguments to stream """ @@ -590,8 +611,7 @@ class MinimaxToolParser(ToolParser): if current_idx < 0 or current_idx >= tool_count: return None - tool_name, tool_args = self._get_current_tool_content( - tool_content, current_idx) + tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx) if not tool_name or tool_args is None: return None @@ -611,29 +631,37 @@ class MinimaxToolParser(ToolParser): sent_tools[current_idx]["sent_arguments"] = clean_args self.streaming_state["sent_tools"] = sent_tools - if clean_args.endswith('}'): + if clean_args.endswith("}"): self._advance_to_next_tool() - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=current_idx, - function=DeltaFunctionCall( - arguments=args_delta).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_delta + ).model_dump(exclude_none=True), + ) + ] + ) elif not sent_args and clean_args: clean_args_delta = self._clean_delta_braces(clean_args) sent_tools[current_idx]["sent_arguments"] = clean_args self.streaming_state["sent_tools"] = sent_tools - if clean_args.endswith('}'): + if clean_args.endswith("}"): self._advance_to_next_tool() - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=current_idx, - function=DeltaFunctionCall( - arguments=clean_args_delta).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=clean_args_delta + ).model_dump(exclude_none=True), + ) + ] + ) return None @@ -651,14 +679,15 @@ class MinimaxToolParser(ToolParser): search_start = pos + 
1 think_regions = [] - for match in re.finditer(self.thinking_tag_pattern, - current_text, - flags=re.DOTALL): + for match in re.finditer( + self.thinking_tag_pattern, current_text, flags=re.DOTALL + ): think_regions.append((match.start(), match.end())) for pos in end_token_positions: - in_think = any(pos >= t_start and pos < t_end - for t_start, t_end in think_regions) + in_think = any( + pos >= t_start and pos < t_end for t_start, t_end in think_regions + ) if not in_think: return True @@ -681,14 +710,12 @@ class MinimaxToolParser(ToolParser): if self._should_buffer_content(delta_text): buffered_output = self._process_buffer(delta_text) - return DeltaMessage( - content=buffered_output) if buffered_output else None + return DeltaMessage(content=buffered_output) if buffered_output else None if self._is_end_tool_calls(current_text): return DeltaMessage(content=delta_text) - safe_content, potential_tag = self._split_content_for_buffering( - delta_text) + safe_content, potential_tag = self._split_content_for_buffering(delta_text) if potential_tag: self.pending_buffer += potential_tag return DeltaMessage(content=safe_content) if safe_content else None @@ -696,35 +723,39 @@ class MinimaxToolParser(ToolParser): processed_current_text = self.preprocess_model_output(current_text) if self.tool_call_start_token not in processed_current_text: - if (self.tool_call_end_token in delta_text - and self.tool_call_start_token in current_text): + if ( + self.tool_call_end_token in delta_text + and self.tool_call_start_token in current_text + ): return None - if delta_text.strip( - ) == '' and self.tool_call_start_token in current_text: + if delta_text.strip() == "" and self.tool_call_start_token in current_text: return None - if (self._get_current_tool_index() != -1 - and self.tool_call_end_token in current_text): + if ( + self._get_current_tool_index() != -1 + and self.tool_call_end_token in current_text + ): self._reset_streaming_state() return DeltaMessage(content=delta_text) - if (self.tool_call_start_token_id is not None - and self.tool_call_start_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if ( + self.tool_call_start_token_id is not None + and self.tool_call_start_token_id in delta_token_ids + and len(delta_token_ids) == 1 + ): return None - original_tool_start = self._find_tool_start_outside_thinking( - current_text) + original_tool_start = self._find_tool_start_outside_thinking(current_text) if original_tool_start is None: return None content_before_tools = self._extract_content_before_tools( - current_text, delta_text, original_tool_start) + current_text, delta_text, original_tool_start + ) if content_before_tools: return DeltaMessage(content=content_before_tools) try: - tool_content = self._extract_tool_content(current_text, - original_tool_start) + tool_content = self._extract_tool_content(current_text, original_tool_start) current_tools_count = self._detect_tools_in_text(tool_content) if current_tools_count == 0: @@ -735,24 +766,23 @@ class MinimaxToolParser(ToolParser): self._ensure_state_arrays(current_tools_count) - return (self._handle_tool_name_streaming(tool_content, - current_tools_count) - or self._handle_tool_args_streaming( - tool_content, current_tools_count)) + return self._handle_tool_name_streaming( + tool_content, current_tools_count + ) or self._handle_tool_args_streaming(tool_content, current_tools_count) except Exception: - logger.exception("An unexpected error occurred ", - "during streaming tool call handling.") + logger.exception( + "An unexpected error 
occurred ", "during streaming tool call handling." + ) return None - def _find_tool_start_outside_thinking(self, - current_text: str) -> Optional[int]: + def _find_tool_start_outside_thinking(self, current_text: str) -> Optional[int]: """ Find the start position of tool calls outside of thinking tags. - + Args: current_text: Current text to search - + Returns: Position of tool call start or None if not found """ @@ -762,26 +792,32 @@ class MinimaxToolParser(ToolParser): if pos == -1: return None - think_regions = [(m.start(), m.end()) for m in re.finditer( - r"<think>(.*?)</think>", current_text, flags=re.DOTALL)] - in_think = any(pos >= t_start and pos < t_end - for t_start, t_end in think_regions) + think_regions = [ + (m.start(), m.end()) + for m in re.finditer( + r"<think>(.*?)</think>", current_text, flags=re.DOTALL + ) + ] + in_think = any( + pos >= t_start and pos < t_end for t_start, t_end in think_regions + ) if not in_think: return pos search_start = pos + 1 - def _extract_content_before_tools(self, current_text: str, delta_text: str, - tool_start: int) -> Optional[str]: + def _extract_content_before_tools( + self, current_text: str, delta_text: str, tool_start: int + ) -> Optional[str]: """ Extract content that appears before tool calls. - + Args: current_text: Current text delta_text: Delta text tool_start: Start position of tools - + Returns: Content before tools or None """ @@ -790,18 +826,18 @@ class MinimaxToolParser(ToolParser): if delta_start_pos < tool_start: content_part = delta_text if delta_start_pos + len(delta_text) > tool_start: - content_part = delta_text[:tool_start - delta_start_pos] + content_part = delta_text[: tool_start - delta_start_pos] return content_part if content_part else None return None def _extract_tool_content(self, current_text: str, tool_start: int) -> str: """ Extract tool content from current text starting at tool_start. 
- + Args: current_text: Current text tool_start: Start position of tool calls - + Returns: Extracted tool content """ diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f122904e..b3b8960276bcc 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -12,15 +12,20 @@ import regex as re from partial_json_parser.core.options import Allow from pydantic import Field -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -30,8 +35,7 @@ ALPHANUMERIC = ascii_letters + digits class MistralToolCall(ToolCall): - id: str = Field( - default_factory=lambda: MistralToolCall.generate_random_id()) + id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) @staticmethod def generate_random_id(): @@ -45,8 +49,9 @@ class MistralToolCall(ToolCall): def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: - return isinstance(model_tokenizer, MistralTokenizer) \ - and model_tokenizer.version >= 11 + return ( + isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 + ) @ToolParserManager.register_module("mistral") @@ -63,35 +68,38 @@ class MistralToolParser(ToolParser): super().__init__(tokenizer) if not isinstance(self.model_tokenizer, MistralTokenizer): - logger.info("Non-Mistral tokenizer detected when using a Mistral " - "model...") + logger.info("Non-Mistral tokenizer detected when using a Mistral model...") # initialize properties used for state when parsing tool calls in # streaming mode self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if _is_fn_name_regex_support(self.model_tokenizer): self.fn_name_regex = re.compile( - r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL) + r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL + ) else: self.fn_name_regex = None if self.bot_token_id is None: raise RuntimeError( "Mistral Tool Parser could not locate the tool call token in " - "the tokenizer!") + "the tokenizer!" 
+ ) - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if not isinstance( - self.model_tokenizer, MistralTokenizer - ) and request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if ( + not isinstance(self.model_tokenizer, MistralTokenizer) + and request.tools + and request.tool_choice != "none" + ): # Do not skip special tokens when using chat template # with Mistral parser as TOOL_CALL token is needed # for tool detection. @@ -113,9 +121,9 @@ class MistralToolParser(ToolParser): # case -- if a tool call token is not present, return a text response if self.bot_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # first remove the BOT token tool_content = model_output.replace(self.bot_token, "").strip() @@ -134,16 +142,15 @@ class MistralToolParser(ToolParser): # fn_name is encoded outside serialized json dump # only arguments are serialized - function_call_arr.append({ - "name": fn_name, - "arguments": json.loads(args) - }) + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: # use a regex to find the part corresponding to the tool call. # NOTE: This use case should not happen if the model is trained - # correctly. It's a easy possible fix so it's included, but + # correctly. It's an easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls raw_tool_call = self.tool_call_regex.findall(tool_content)[0] function_call_arr = json.loads(raw_tool_call) @@ -155,8 +162,11 @@ class MistralToolParser(ToolParser): function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(raw_function_call["arguments"], - ensure_ascii=False))) + arguments=json.dumps( + raw_function_call["arguments"], ensure_ascii=False + ), + ), + ) for raw_function_call in function_call_arr ] @@ -165,14 +175,15 @@ class MistralToolParser(ToolParser): return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if len(content) > 0 else None) + content=content if len(content) > 0 else None, + ) except Exception: logger.exception("Error in extracting tool call from response.") # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=tool_content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=tool_content + ) def extract_tool_calls_streaming( self, @@ -184,7 +195,6 @@ class MistralToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.bot_token not in current_text: @@ -195,8 +205,7 @@ class MistralToolParser(ToolParser): # handle if we detected the BOT token which means the start of tool # calling - if (self.bot_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1: # if it's the only token, return None, so we don't send a chat # completion any don't send a control token return None @@ -205,10 +214,8 @@ 
class MistralToolParser(ToolParser): # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls @@ -218,15 +225,17 @@ class MistralToolParser(ToolParser): # parsing on the entire array try: tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) + parsable_arr, flags + ) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -235,9 +244,9 @@ class MistralToolParser(ToolParser): # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -247,16 +256,19 @@ class MistralToolParser(ToolParser): if diff: diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff + self.streamed_args_for_tool[self.current_tool_id], "" + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += diff else: delta = None else: @@ -275,15 +287,18 @@ class MistralToolParser(ToolParser): if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -291,64 +306,72 @@ class MistralToolParser(ToolParser): # now we know we're on the same tool call and we're streaming # arguments else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) 
cur_arguments = current_tool_call.get("arguments") - new_text = delta_text.replace("\'", "\"") - if ('"}' in new_text): - new_text = new_text[:new_text.rindex('"}')] + new_text = delta_text.replace("'", '"') + if '"}' in new_text: + new_text = new_text[: new_text.rindex('"}')] if not cur_arguments and not prev_arguments: - delta = None elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False)[:-2] - logger.debug("finding %s in %s", new_text, - cur_arguments_json) + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[ + :-2 + ] + logger.debug("finding %s in %s", new_text, cur_arguments_json) - if (new_text not in cur_arguments_json): + if new_text not in cur_arguments_json: return None - arguments_delta = cur_arguments_json[:cur_arguments_json. - rindex(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + arguments_delta = cur_arguments_json[ + : cur_arguments_json.rindex(new_text) + len(new_text) + ] + logger.debug( + "First tokens in arguments received: %s", arguments_delta + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) + logger.debug( + "Searching for diff between \n%s\n%s", + cur_args_json, + prev_args_json, + ) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) + cur_args_json, prev_args_json + ) logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff else: # try parsing it with regular JSON - if it works we're # at the end, and we need to send the difference between @@ -364,6 +387,6 @@ class MistralToolParser(ToolParser): except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py 
b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py new file mode 100644 index 0000000000000..8d7cbbfba649d --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import json +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from vllm.entrypoints.harmony_utils import parse_output_into_messages +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("openai") +class OpenAIToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + token_ids: Sequence[int] | None = None, + ) -> ExtractedToolCallInformation: + if token_ids is None: + raise NotImplementedError( + "OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501 + ) + + parser = parse_output_into_messages(token_ids) + tool_calls = [] + final_content = None + + if len(parser.messages) > 0: + for msg in parser.messages: + if len(msg.content) < 1: + continue + msg_text = msg.content[0].text + if msg.recipient and msg.recipient.startswith("functions."): + # If no content-type is given assume JSON, as that's the + # most common case with gpt-oss models. + if not msg.content_type or "json" in msg.content_type: + # load and dump the JSON text to check validity and + # remove any extra newlines or other odd formatting + try: + tool_args = json.dumps(json.loads(msg_text)) + except json.JSONDecodeError: + logger.exception( + "Error decoding JSON tool call from response." 
+ ) + tool_args = msg_text + else: + tool_args = msg_text + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=msg.recipient.split("functions.")[1], + arguments=tool_args, + ), + ) + ) + elif msg.channel == "final": + final_content = msg_text + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=final_content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + raise NotImplementedError( + "Not being used, manual parsing in serving_chat.py" # noqa: E501 + ) diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 85dd56213c6ac..114987e5600b2 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -9,12 +9,17 @@ import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,7 +31,7 @@ class Phi4MiniJsonToolParser(ToolParser): Tool call parser for phi-4-mini models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json + Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json are all set """ @@ -38,39 +43,42 @@ class Phi4MiniJsonToolParser(ToolParser): self.prev_tool_call_arr: list[dict[str, Any]] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token: str = "functools" def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. 
""" logger.debug("Model output: %s", model_output) - pattern = r'functools\[(.*?)\]' + pattern = r"functools\[(.*?)\]" matches = re.search(pattern, model_output, re.DOTALL) if not matches: logger.debug("No function calls found") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: function_call_arr: list[dict[str, Any]] = [] try: - json_content = '[' + matches.group(1) + ']' + json_content = "[" + matches.group(1) + "]" function_call_arr = json.loads(json_content) - logger.debug("Successfully extracted %d function calls", - len(function_call_arr)) + logger.debug( + "Successfully extracted %d function calls", len(function_call_arr) + ) except json.JSONDecodeError as e: logger.error( - "Failed to parse function calls from model output. " - "Error: %s", str(e)) + "Failed to parse function calls from model output. Error: %s", + str(e), + ) tool_calls: list[ToolCall] = [ ToolCall( @@ -81,22 +89,25 @@ class Phi4MiniJsonToolParser(ToolParser): # function call args are JSON but as a string arguments=json.dumps( raw_function_call["arguments"] - if "arguments" in raw_function_call else - raw_function_call["parameters"], - ensure_ascii=False), - )) for raw_function_call in function_call_arr + if "arguments" in raw_function_call + else raw_function_call["parameters"], + ensure_ascii=False, + ), + ), + ) + for raw_function_call in function_call_arr ] # get any content before the tool call - ret = ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=None) + ret = ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=None + ) return ret except Exception: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -108,5 +119,4 @@ class Phi4MiniJsonToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Optional[DeltaMessage]: - return None diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 73329cdf701d6..272068a6f0ac7 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -10,13 +10,19 @@ import regex as re from transformers import PreTrainedTokenizerBase import vllm.envs as envs -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -34,6 +40,7 @@ class PythonicToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set """ + # TODO(mdepinet): Possible future improvements: # 1. Support text + tools separated by either <|python_tag|> or \n\n # 2. Support tools outside of a list (or separated by a semicolon). 
@@ -43,7 +50,8 @@ class PythonicToolParser(ToolParser): TOOL_CALL_REGEX = re.compile( r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", - re.DOTALL) + re.DOTALL, + ) def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) @@ -58,48 +66,54 @@ class PythonicToolParser(ToolParser): self.current_tool_id = value def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. """ is_tool_call_pattern = False try: - is_tool_call_pattern = self.TOOL_CALL_REGEX.match( - model_output, - timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ) + is not None + ) except TimeoutError: - logger.warning( - "Regex timeout occurred when matching tool call pattern.") - logger.debug("Regex timeout occurred when matching user input: %s", - model_output) + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) if not is_tool_call_pattern: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: module = ast.parse(model_output) parsed = getattr(module.body[0], "value", None) if isinstance(parsed, ast.List) and all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): return ExtractedToolCallInformation( tools_called=True, tool_calls=[ _handle_single_tool(e) # type: ignore for e in parsed.elts ], - content=None) + content=None, + ) else: raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -111,7 +125,6 @@ class PythonicToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if not current_text.startswith("["): return DeltaMessage(content=delta_text) @@ -124,9 +137,11 @@ class PythonicToolParser(ToolParser): module = ast.parse(valid_text) parsed = getattr(module.body[0], "value", None) if not isinstance(parsed, ast.List) or not all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) tool_calls = [ _handle_single_tool(e) # type: ignore for e in parsed.elts @@ -141,31 +156,33 @@ class PythonicToolParser(ToolParser): if len(self.streamed_args_for_tool) == index: self.streamed_args_for_tool.append("") - new_call_complete = index < len( - tool_calls) - 1 or ")]" not in added_text + new_call_complete = ( + index < len(tool_calls) - 1 or ")]" not in added_text + ) if new_call_complete: self.current_tool_index 
+= 1 - withheld_suffix = (added_text[:-2] - if not new_call_complete else "") + withheld_suffix = added_text[:-2] if not new_call_complete else "" if not new_call_complete and added_text[-2] == ")": # Function call is incomplete. Withhold the closing bracket. withheld_suffix = withheld_suffix + "}" # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta(self.streamed_args_for_tool[index], - new_call, index, withheld_suffix) + delta = _compute_tool_delta( + self.streamed_args_for_tool[index], new_call, index, withheld_suffix + ) if delta is not None: tool_deltas.append(delta) - if (delta.function is not None - and delta.function.arguments is not None): - self.streamed_args_for_tool[ - index] += delta.function.arguments + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically + # when determining its final streaming delta, automatically # adding autocompleted JSON. # These two lines avoid that nonsense while ensuring finish_reason # is set to tool_calls when at least one tool is called. @@ -177,14 +194,14 @@ class PythonicToolParser(ToolParser): elif not added_text and self.current_tool_id > 0: # Return an empty DeltaMessage once the tool calls are all done # so that finish_reason gets set. - return DeltaMessage(content='') + return DeltaMessage(content="") else: return None except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None @@ -193,8 +210,7 @@ def _get_parameter_value(val: ast.expr) -> Any: return val.value elif isinstance(val, ast.Dict): if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError( - "Dict tool call arguments must have literal keys") + raise _UnexpectedAstError("Dict tool call arguments must have literal keys") return { k.value: _get_parameter_value(v) # type: ignore for k, v in zip(val.keys, val.values) @@ -214,9 +230,9 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments[keyword.arg] = _get_parameter_value(keyword.value) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(arguments, ensure_ascii=False) + ), ) @@ -253,21 +269,25 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: # we can't fill in a valid value. 
return None if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[:text.rfind("{")] + trailing_dict_text = text[: text.rfind("{")] num_keys = trailing_dict_text.count(":") num_values = trailing_dict_text.count(",") if num_keys <= num_values: return None # Incomplete property name within parameter value if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[:text.rfind("(")] + trailing_params_text = text[: text.rfind("(")] num_full_param_names = trailing_params_text.count("=") num_full_param_values = trailing_params_text.count(",") if num_full_param_names <= num_full_param_values: return None # Incomplete parameter name if text.endswith(","): text = text[:-1] - if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( - "[") and not text.endswith(")"): + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): return None # Incomplete function name added_text = "" @@ -286,23 +306,29 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: return text + added_text, added_text -def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, - index: int, - withheld_suffix: str) -> Union[DeltaToolCall, None]: +def _compute_tool_delta( + previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str +) -> Union[DeltaToolCall, None]: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[:-len(withheld_suffix)] + new_call_args = new_call_args[: -len(withheld_suffix)] if not previously_sent_args: - return DeltaToolCall(id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - )) + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) - arg_diff = new_call_args[len(previously_sent_args):] - return DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall( - arguments=arg_diff)) if arg_diff else None + arg_diff = new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) + ) + if arg_diff + else None + ) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index 2501d6739e8f6..a41ca30bf5276 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import ast import json import uuid from collections.abc import Sequence @@ -8,28 +8,35 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger 
from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] + # Override base class type - we use string IDs for tool calls + self.current_tool_id: Optional[str] = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -42,50 +49,42 @@ class Qwen3CoderToolParser(ToolParser): self.is_tool_call_started: bool = False self.failed_count: int = 0 - # Streaming state variables - self.current_tool_index: int = 0 - self.header_sent: bool = False - self.current_tool_string_id: Optional[str] = None - self.current_function_name: Optional[str] = None - self.current_param_name: Optional[str] = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.accumulated_text: str = "" - self.json_started: bool = False - self.json_closed: bool = False - # Enhanced streaming state - reset for each new message self._reset_streaming_state() # Regex patterns self.tool_call_complete_regex = re.compile( - r"<tool_call>(.*?)</tool_call>", re.DOTALL) + r"<tool_call>(.*?)</tool_call>", re.DOTALL + ) self.tool_call_regex = re.compile( - r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL) + r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL + ) self.tool_call_function_regex = re.compile( - r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL + ) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)", + re.DOTALL, + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." + ) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: raise RuntimeError( "Qwen3 XML Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" 
+ ) - logger.debug("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -96,7 +95,7 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_index = 0 self.is_tool_call_started = False self.header_sent = False - self.current_tool_string_id = None + self.current_tool_id = None self.current_function_name = None self.current_param_name = None self.current_param_value = "" @@ -106,138 +105,167 @@ class Qwen3CoderToolParser(ToolParser): self.accumulated_text = "" self.json_started = False self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + + def _get_arguments_config( + self, func_name: str, tools: Optional[list[ChatCompletionToolsParam]] + ) -> dict: + """Extract argument configuration for a function.""" + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") and hasattr(config.function, "name") + ): + continue + if config.type == "function" and config.function.name == func_name: + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", func_name) + return {} + + def _convert_param_value( + self, param_value: str, param_name: str, param_config: dict, func_name: str + ) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", + param_name, + func_name, + ) + return param_value + + if ( + isinstance(param_config[param_name], dict) + and "type" in param_config[param_name] + ): + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, + param_name, + func_name, + ) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value = float(param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", + param_value, + param_name, + func_name, + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` or `false`) in tool 
'%s', degenerating to " + "false.", + param_value, + param_name, + func_name, + ) + return param_value == "true" + else: + if ( + param_type in ["object", "array", "arr"] + or param_type.startswith("dict") + or param_type.startswith("list") + ): + try: + param_value = json.loads(param_value) + return param_value + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "parsed with json.loads in tool '%s', will try " + "other methods to parse it.", + param_value, + param_name, + func_name, + ) + try: + param_value = ast.literal_eval(param_value) # safer + except (ValueError, SyntaxError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `ast.literal_eval()` in tool " + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) + return param_value def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] + self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - - def get_arguments_config(func_name: str) -> dict: - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): - continue - if (config.type == "function" - and config.function.name == func_name): - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) - return {} - - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - # Handle null value for any type - if param_value.lower() == "null": - return None - - converted_value: Any - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: - return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): - try: - converted_value = int(param_value) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not an " - "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) - return param_value - elif (param_type.startswith("num") - or param_type.startswith("float")): - try: - float_param_value = float(param_value) - converted_value = (float_param_value if float_param_value - - int(float_param_value) != 0 else - int(float_param_value)) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) - return param_value - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if 
param_value not in ["true", "false"]: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "boolean (`true` of `false`) in tool '%s', " - "degenerating to false.", param_value, param_name, - func_name) - return param_value == "true" - else: - if param_type == "object" or param_type.startswith("dict"): - try: - converted_value = json.loads(param_value) - return converted_value - except json.JSONDecodeError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "valid JSON object in tool '%s', will try other " - "methods to parse it.", param_value, param_name, - func_name) - logger.warning( - "Parameter '%s' has unknown type '%s'. " - "The value will be treated as a string.", param_name, - param_type) - return param_value - # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] - param_config = get_arguments_config(function_name) - parameters = function_call_str[end_index + 1:] + param_config = self._get_arguments_config(function_name, tools) + parameters = function_call_str[end_index + 1 :] param_dict = {} - for match in self.tool_call_parameter_regex.findall(parameters): - match_text = match[0] if match[0] else match[1] + for match_text in self.tool_call_parameter_regex.findall(parameters): idx = match_text.index(">") param_name = match_text[:idx] - param_value = str(match_text[idx + 1:]) + param_value = str(match_text[idx + 1 :]) # Remove prefix and trailing \n if param_value.startswith("\n"): param_value = param_value[1:] if param_value.endswith("\n"): param_value = param_value[:-1] - param_dict[param_name] = convert_param_value( - param_value, param_name, param_config, function_name) + param_dict[param_name] = self._convert_param_value( + param_value, param_name, param_config, function_name + ) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(param_dict, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False) + ), ) def _get_function_calls(self, model_output: str) -> list[str]: @@ -253,8 +281,7 @@ class Qwen3CoderToolParser(ToolParser): raw_function_calls = [] for tool_call in raw_tool_calls: - raw_function_calls.extend( - self.tool_call_function_regex.findall(tool_call)) + raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call)) function_calls = [ match[0] if match[0] else match[1] for match in raw_function_calls @@ -268,38 +295,37 @@ class Qwen3CoderToolParser(ToolParser): ) -> ExtractedToolCallInformation: # Quick check to avoid unnecessary processing if self.tool_call_prefix not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: function_calls = self._get_function_calls(model_output) if len(function_calls) == 0: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_calls = [ self._parse_xml_function_call(function_call_str, request.tools) for function_call_str in function_calls ] - # Populate prev_tool_call_arr for serving layer to set - # finish_reason + # Populate prev_tool_call_arr for serving layer to set finish_reason self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: - 
self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) + self.prev_tool_call_arr.append( + { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + ) # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) - content_index = (content_index if content_index >= 0 else - model_output.find(self.tool_call_prefix)) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx content = model_output[:content_index] # .rstrip() return ExtractedToolCallInformation( @@ -310,9 +336,9 @@ class Qwen3CoderToolParser(ToolParser): except Exception: logger.exception("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -324,39 +350,37 @@ class Qwen3CoderToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # If no delta text, return None unless it's an EOS token after tool - # calls + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools if not delta_text: # Check if this is an EOS token after all tool calls are complete - # We check for tool calls in the text even if is_tool_call_started - # is False because it might have been reset after processing all - # tools - if (delta_token_ids - and self.tool_call_end_token_id not in delta_token_ids): + # Check for tool calls in text even if is_tool_call_started + # is False (might have been reset after processing all tools) + if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids: # Count complete tool calls complete_calls = len( - self.tool_call_complete_regex.findall(current_text)) + self.tool_call_complete_regex.findall(current_text) + ) # If we have completed tool calls and populated # prev_tool_call_arr - if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed - open_calls = ( - current_text.count(self.tool_call_start_token) - - current_text.count(self.tool_call_end_token)) + open_calls = current_text.count( + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) if open_calls == 0: - # Return empty delta message to allow finish_reason - # processing + # Return empty delta for finish_reason processing return DeltaMessage(content="") elif not self.is_tool_call_started and current_text: # This is a regular content response that's now complete return DeltaMessage(content="") return None - # Check if this is the first call (reset state if needed) - if not previous_text: - self._reset_streaming_state() - # Update accumulated text self.accumulated_text = current_text @@ -371,11 +395,11 @@ class Qwen3CoderToolParser(ToolParser): self.param_count = 0 self.json_started = False self.json_closed = False + self.accumulated_params = {} # Check if there are more tool calls - tool_starts_count = current_text.count( - self.tool_call_start_token) - if self.current_tool_index >= tool_starts_count: + tool_starts = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts: 
# No more tool calls self.is_tool_call_started = False # Continue processing next tool @@ -384,20 +408,25 @@ class Qwen3CoderToolParser(ToolParser): # Handle normal content before tool calls if not self.is_tool_call_started: # Check if tool call is starting - if (self.tool_call_start_token_id in delta_token_ids - or self.tool_call_start_token in delta_text): + if ( + self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text + ): self.is_tool_call_started = True # Return any content before the tool call if self.tool_call_start_token in delta_text: - content_before = delta_text[:delta_text.index( - self.tool_call_start_token)] + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] if content_before: return DeltaMessage(content=content_before) return None else: # Check if we're between tool calls - skip whitespace - if (current_text.rstrip().endswith(self.tool_call_end_token) - and delta_text.strip() == ""): + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): # We just ended a tool call, skip whitespace return None # Normal content, no tool call @@ -412,80 +441,87 @@ class Qwen3CoderToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index - tool_starts: list[int] = [] + tool_start_positions: list[int] = [] idx = 0 while True: idx = current_text.find(self.tool_call_start_token, idx) if idx == -1: break - tool_starts.append(idx) + tool_start_positions.append(idx) idx += len(self.tool_call_start_token) - if self.current_tool_index >= len(tool_starts): + if self.current_tool_index >= len(tool_start_positions): # No more tool calls to process yet return None - tool_start_idx = tool_starts[self.current_tool_index] + tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) - tool_end_idx = current_text.find(self.tool_call_end_token, - tool_start_idx) + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) if tool_end_idx == -1: tool_text = current_text[tool_start_idx:] else: - tool_text = current_text[tool_start_idx:tool_end_idx + - len(self.tool_call_end_token)] + tool_text = current_text[ + tool_start_idx : tool_end_idx + len(self.tool_call_end_token) + ] # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_string_id = self._generate_tool_call_id() + self.current_tool_id = self._generate_tool_call_id() self.header_sent = True self.in_function = True - # IMPORTANT: Add to prev_tool_call_arr immediately when we - # detect a tool call. This ensures + # IMPORTANT: Add to prev_tool_call_arr immediately when + # we detect a tool call. 
This ensures # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name - for tool in self.prev_tool_call_arr) + for tool in self.prev_tool_call_arr + ) if not already_added: - self.prev_tool_call_arr.append({ - "name": self.current_function_name, - "arguments": - "{}", # Placeholder, will be updated later - }) + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) # Send header with function info - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_string_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments=""), - type="function", - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) return None # We've sent header, now handle function body if self.in_function: # Send opening brace if not sent yet - if (not self.json_started - and self.parameter_prefix not in delta_text): + if not self.json_started and self.parameter_prefix not in delta_text: self.json_started = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) # Make sure json_started is set if we're processing parameters if not self.json_started: @@ -496,107 +532,161 @@ class Qwen3CoderToolParser(ToolParser): # Close JSON self.json_closed = True - # Extract the complete tool call to update prev_tool_call_arr - # with final arguments. 
Find the function content - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) - func_content_end = tool_text.find(self.function_end_token, - func_start) + # Extract complete tool call to update + # prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix + ) + func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: func_content = tool_text[func_start:func_content_end] # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, + self.streaming_request.tools + if self.streaming_request + else None, + ) if parsed_tool: - # Update existing entry in prev_tool_call_arr with - # complete arguments + # Update existing entry in + # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if (tool.get("name") == - parsed_tool.function.name): - self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + if tool.get("name") == parsed_tool.function.name: + args = parsed_tool.function.arguments + self.prev_tool_call_arr[i]["arguments"] = args break except Exception: pass # Ignore parsing errors during streaming - result = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ]) + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) # Reset state for next tool self.in_function = False self.json_closed = True + self.accumulated_params = {} return result # Look for parameters - # Count how many complete parameters we have processed - complete_params = tool_text.count(self.parameter_end_token) + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) # Check if we should start a new parameter - if not self.in_param and self.param_count < complete_params: - # Find the unprocessed parameter - # Count parameter starts - param_starts = [] - idx = 0 - while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) + if ( + not self.in_param + and self.param_count < len(param_starts) + and len(param_starts) > self.param_count + ): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] - if len(param_starts) > self.param_count: - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - self.current_param_name = remaining[:name_end] + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if 
value_text.startswith("\n"): - value_text = value_text[1:] + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or + # function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.function_end_token) - # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Build complete JSON fragment for this parameter - if self.param_count == 0: - json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + if next_param_idx != -1 and ( + func_end_idx == -1 or next_param_idx < func_end_idx + ): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.tool_call_end_token in tool_text: + # Tool call is complete, so parameter + # must be complete too. Use all + # remaining text before function end + param_end_idx = len(value_text) else: - json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + # Still streaming, wait for more content + return None - self.param_count += 1 + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] - return DeltaMessage(tool_calls=[ + # Store raw value for later processing + self.accumulated_params[self.current_param_name] = param_value + + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request + else None, + ) + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, + self.current_param_name, + param_config, + self.current_function_name or "", + ) + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) + + if self.param_count == 0: + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) + else: + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) + + self.param_count += 1 + + return DeltaMessage( + tool_calls=[ DeltaToolCall( index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), + function=DeltaFunctionCall(arguments=json_fragment), ) - ]) + ] + ) - # Continue parameter value + # Continue parameter value - Not used in the current implementation + # since we process complete parameters above if self.in_param: if self.parameter_end_token in delta_text: # End of parameter @@ -606,29 +696,51 @@ class Qwen3CoderToolParser(ToolParser): # Skip past > if at start if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] - # Calculate incremental JSON + # Store complete value full_value = self.current_param_value + value_chunk - prev_escaped = 
(json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") - full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + self.accumulated_params[self.current_param_name] = full_value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request + else None, + ) + + # Convert the parameter value to the appropriate type + converted_value = self._convert_param_value( + full_value, + self.current_param_name or "", + param_config, + self.current_function_name or "", + ) + + # Serialize the converted value + serialized_value = json.dumps(converted_value, ensure_ascii=False) + + # Since we've been streaming the quoted version, + # we need to close it properly + # This is complex - for now just complete the value self.in_param = False self.current_param_value = "" - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped + '"'), - ) - ]) + # Just close the current parameter string + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments='"' + ), # Close the string quote + ) + ] + ) else: # Continue accumulating value value_chunk = delta_text @@ -636,29 +748,36 @@ class Qwen3CoderToolParser(ToolParser): # Handle first chunk after param name if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = ( + json.dumps(self.current_param_value, ensure_ascii=False)[ + 1:-1 + ] + if self.current_param_value + else "" + ) self.current_param_value += value_chunk full_escaped = json.dumps( - self.current_param_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + self.current_param_value, ensure_ascii=False + )[1:-1] + delta_escaped = full_escaped[len(prev_escaped) :] if delta_escaped: - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + ), + ) + ] + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py new file mode 100644 index 0000000000000..1b7e4fec316eb --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py @@ -0,0 +1,1251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union +from xml.parsers.expat import ParserCreate + +import regex as re + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from 
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +class StreamingXMLToolCallParser: + """ + Simplified streaming XML tool call parser + Supports streaming input, parsing, and output + """ + + def __init__(self): + self.reset_streaming_state() + + # Tool configuration information + self.tools: Union[list[ChatCompletionToolsParam], None] = None + self.tool_call_start_token: str = "<tool_call>" + self.tool_call_end_token: str = "</tool_call>" + self.function_start_token: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_start_token: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" + + def reset_streaming_state(self): + """Reset streaming parsing state""" + + self.deltas = [] + # state for streaming + self.tool_call_index = 0 + self.current_call_id = None + self.last_completed_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + + self.streaming_buffer = "" + self.last_processed_pos = 0 + + self.text_content_buffer = "" + + # state for preprocessing and deferred parsing + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + # recreate parser + self.parser = ParserCreate() + self.setup_parser() + + def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: + """ + Parse single streaming XML chunk and return Delta response + This is the actual streaming interface that receives chunks + one by one and maintains internal state + + Args: + xml_chunk: Single XML chunk string + Returns: + DeltaMessage: Contains delta information generated by this chunk, + returns empty response if no complete elements + """ + # Record delta count before processing + initial_delta_count = len(self.deltas) + + self.streaming_buffer += xml_chunk + + found_elements = self._process_complete_xml_elements() + + if found_elements: + # If complete elements found, check if end events were missed + # some tags may not have been triggered + try: + new_deltas = self.deltas[initial_delta_count:] + # If this chunk contains </function> + # but didn't generate '}', then complete it + if ( + self.current_call_id is not None + and self.function_end_token in xml_chunk + ): + # - Added '}' (non-empty parameter ending) + # - Added '{}' (empty parameter function) + has_function_close = any( + ( + td.tool_calls + and any( + ( + tc.function + and tc.id == self.current_call_id + and isinstance(tc.function.arguments, str) + and (tc.function.arguments in ("}", "{}")) + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_function_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + # If this chunk contains </tool_call> + # but didn't generate final empty delta, then complete it + if ( + self.current_call_id is not None + and self.tool_call_end_token in xml_chunk + ): + has_toolcall_close = any( + ( + td.tool_calls + and any( + ( + tc.type == 
"function" + and tc.function + and tc.function.arguments == "" + and tc.id == self.current_call_id + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_toolcall_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + self._end_element("tool_call") + except Exception as e: + logger.warning("Error with fallback parsing: %s", e) + # Merge newly generated deltas into single response + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + else: + # No complete elements, check if there's unoutput text content + if self.text_content_buffer and self.tool_call_index == 0: + # Has text content but no tool_call yet, output text content + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + # Clear buffer to avoid duplicate output + self.text_content_buffer = "" + return text_delta + + # If this chunk contains end tags but wasn't triggered by parser, + # manually complete end events + # Only execute when still on the same call as when entered, + # to prevent accidentally closing new calls + # in multi <tool_call> scenarios + if self.current_call_id is not None and ( + self.function_end_token in xml_chunk + or self.tool_call_end_token in xml_chunk + ): + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.function_end_token in xml_chunk and self.current_function_name: + self._end_element("function") + if self.tool_call_end_token in xml_chunk: + self._end_element("tool_call") + # Return the merged delta result generated by this fallback + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + + # No complete elements, return empty response + return DeltaMessage(content=None) + + def _escape_xml_special_chars(self, text: str) -> str: + """ + Escape XML special characters + Args: + text: Original text + Returns: + Escaped text + """ + xml_escapes = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", + } + + for char, escape in xml_escapes.items(): + text = text.replace(char, escape) + + return text + + def _process_complete_xml_elements(self) -> bool: + """ + Process complete XML elements in buffer + + Returns: + bool: Whether complete elements were found and processed + """ + found_any = False + + while self.last_processed_pos < len(self.streaming_buffer): + # Find next complete xml element + element, end_pos = self._find_next_complete_element(self.last_processed_pos) + if element is None: + # No complete element found, wait for more data + break + + # Check if this element should be skipped + if self._should_skip_element(element): + self.last_processed_pos = end_pos + continue + + # Found complete XML element, process it + try: + preprocessed_element = self._preprocess_xml_chunk(element) + # Check if this is the first tool_call start + if ( + ( + preprocessed_element.strip().startswith("<tool_call>") + or preprocessed_element.strip().startswith("<function name=") + ) + and self.tool_call_index == 0 + ) and self.text_content_buffer: + # First tool_call starts, + # output previously collected text content first + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + # Clear buffer for potential subsequent text content + self.text_content_buffer = "" + + # If a new tool_call starts and + # there are already 
completed tool_calls + if ( + preprocessed_element.strip().startswith("<tool_call>") + and self.tool_call_index > 0 + and self.current_call_id + ): + # Reset parser state but preserve generated deltas + if self.current_param_name: + self._end_element("parameter") + if self.current_function_open or self.current_function_name: + self._end_element("function") + # Output final tool_call tail delta + final_delta = DeltaMessage( + role=None, + content=None, + reasoning_content=None, + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ], + ) + self._emit_delta(final_delta) + # Reset XML parser and current call state + self._reset_xml_parser_after_tool_call() + # Parse preprocessed element + self.parser.Parse(preprocessed_element, False) + found_any = True + + except Exception as e: + logger.warning("Error when parsing XML elements: %s", e) + + # Update processed position + self.last_processed_pos = end_pos + + return found_any + + def _should_skip_element(self, element: str) -> bool: + """ + Determine whether an element should be skipped + + Args: + element: Element to evaluate + + Returns: + bool: True means should skip, False means should process + """ + + # If it's a tool_call XML tag, don't skip + if ( + element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token) + ): + return False + + # If currently not parsing tool calls and not blank, + # collect this text instead of skipping + # Only process other XML elements after tool_call appears, + # otherwise treat as plain text + if self.current_call_id is None and element: + # Collect text content to buffer + self.text_content_buffer += element + return True # Still skip, but content has been collected + + # If currently parsing tool calls, + # this might be parameter value, don't skip + if self.current_call_id is not None: + return False + + # Skip blank content + return not element + + def _find_next_complete_element(self, start_pos: int) -> tuple[Optional[str], int]: + """ + Find next complete XML element from specified position + + Args: + start_pos: Position to start searching + + Returns: + (Complete element string, element end position), + returns (None, start_pos) if no complete element found + """ + buffer = self.streaming_buffer[start_pos:] + + if not buffer: + return None, start_pos + + if buffer.startswith("<"): + # Need to ensure no new < appears, + # find the nearest one between < and > + tag_end = buffer.find("<", 1) + tag_end2 = buffer.find(">", 1) + if tag_end != -1 and tag_end2 != -1: + # Next nearest is < + if tag_end < tag_end2: + return buffer[:tag_end], start_pos + tag_end + # Next nearest is >, means found XML element + else: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + elif tag_end != -1: + return buffer[:tag_end], start_pos + tag_end + elif tag_end2 != -1: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + else: + # If currently not parsing tool calls (entering a tool_call), + # check if starts with <tool_call> + if self.current_call_id is None: + # Check if might be start of <tool_call> + if buffer == "<tool_call>"[: len(buffer)]: + # Might be start of <tool_call>, wait for more data + return None, start_pos + else: + # Not start of <tool_call>, treat as text + return buffer, start_pos + len(buffer) + else: + # When parsing tool calls, + # wait for more data to get complete tag + return 
None, start_pos + else: + # Find text content (until next < or buffer end) + next_tag_pos = buffer.find("<") + if next_tag_pos != -1: + # Found text content + text_content = buffer[:next_tag_pos] + return text_content, start_pos + next_tag_pos + else: + # Buffer end is all text, process + # (no longer wait for more data) + remaining = buffer + return remaining, start_pos + len(remaining) + + def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage: + """ + Merge newly generated deltas from this processing + into a single DeltaMessage + + Args: + initial_count: Delta count before processing + + Returns: + Merged DeltaMessage containing all newly generated delta information + """ + if len(self.deltas) <= initial_count: + return DeltaMessage(content=None) + + # Get newly generated deltas + new_deltas = self.deltas[initial_count:] + + if len(new_deltas) == 1: + # Only one new delta, return directly + return new_deltas[0] + + # Merge multiple new deltas + merged_tool_calls: list[DeltaToolCall] = [] + merged_content: str = "" + + for delta in new_deltas: + if delta.content: + merged_content += delta.content + if delta.tool_calls: + # For tool_calls, we need to intelligently merge arguments + for tool_call in delta.tool_calls: + # Find if there's already a tool_call with the same call_id + existing_call = None + for existing in merged_tool_calls: + if existing.id == tool_call.id: + existing_call = existing + break + + if existing_call and existing_call.function: + # Merge to existing tool_call + if tool_call.function and tool_call.function.name: + existing_call.function.name = tool_call.function.name + if ( + tool_call.function + and tool_call.function.arguments is not None + ): + if existing_call.function.arguments is None: + existing_call.function.arguments = "" + + # For streaming JSON parameters, + # simply concatenate in order + new_args = tool_call.function.arguments + existing_call.function.arguments += new_args + if tool_call.type: + existing_call.type = tool_call.type + else: + # Add new tool_call + merged_tool_calls.append(tool_call) + + return DeltaMessage( + content=merged_content if merged_content else None, + tool_calls=merged_tool_calls, + ) + + def _preprocess_xml_chunk(self, chunk: str) -> str: + """ + Preprocess XML chunk, handle non-standard formats, + and escape special characters + + Args: + chunk: Original XML chunk + + Returns: + Processed XML chunk + """ + + # Check if this is a tool_call related element + is_tool_call = False + if chunk.startswith(self.tool_call_start_token) or chunk.startswith( + self.tool_call_end_token + ): + is_tool_call = True + if chunk.startswith(self.function_start_token) or chunk.startswith( + self.function_end_token + ): + is_tool_call = True + if chunk.startswith(self.parameter_start_token) or chunk.startswith( + self.parameter_end_token + ): + is_tool_call = True + # Handle <function=name> format -> <function name="name"> + processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk) + # Handle <parameter=name> format -> <parameter name="name"> + processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed) + + original_chunk = chunk + # If in parameter value accumulation mode + if self._pre_inside_parameter: + # Parameter end: output accumulated raw text + # safely then return </parameter> + if processed.startswith("</parameter>"): + body_text = self._pre_param_buffer + # Trigger deferred parsing mode + # literal_eval+json output in end_element + self.defer_current_parameter = True + 
self.deferred_param_raw_value = body_text + # Clean up state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + safe_text = self._escape_xml_special_chars(body_text) + return f"{safe_text}</parameter>" + else: + # If this is the first block of content after entering parameter + # evaluate if deferred parsing is needed; + # If not needed, exit accumulation mode + # and pass through directly + if self._pre_param_buffer == "": + # Get current parameter type + param_type = ( + self._get_param_type(self._pre_current_param_name) + if self._pre_current_param_name + else "string" + ) + # Only these types need deferred parsing to + # handle Python literals containing single quotes + is_object_type = param_type in ["object"] + is_complex_type = ( + param_type in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + + # Only delay when contains container symbols + # and has single quotes and is complex type + has_container_hint = ( + ("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk) + ) + + # Determine if deferred parsing is needed + need_defer = False + if is_complex_type: + # Complex type, always need deferred parsing + need_defer = True + elif ( + is_object_type + and has_container_hint + and ("'" in original_chunk) + ): + # Object type with container symbols + # and single quotes, need deferred parsing + need_defer = True + + if not need_defer: + # No need for deferred parsing, + # exit parameter mode directly + self._pre_inside_parameter = False + return self._escape_xml_special_chars(original_chunk) + self._pre_param_buffer += original_chunk + return "" + + # Parameter start: enable accumulation + if processed.startswith("<parameter name="): + m = re.match(r'<parameter name="([^"]+)">', processed) + if m: + self._pre_current_param_name = m.group(1) + self._pre_inside_parameter = True + self._pre_param_buffer = "" + return processed + + # If processed doesn't contain special_token, escape processed + # This is because XML parsing encounters special characters + # and reports errors, so escaping is needed + if not is_tool_call: + processed = self._escape_xml_special_chars(processed) + return processed + + def _emit_delta(self, delta: DeltaMessage): + """Emit Delta response (streaming output)""" + self.deltas.append(delta) + + def _auto_close_open_parameter_if_needed(self, incoming_tag: Optional[str] = None): + """Before starting to process new elements, + if there are unclosed tags from before, + automatically complete their endings to the parser. + - If there are unclosed parameters, + it's equivalent to feeding `</parameter>` + - When about to start a new function or tool_call, + if there are unclosed functions, complete `</function>`. + - When about to start a new tool_call, + if there are unclosed tool_calls, complete `</tool_call>`. 
+ """ + # First close unclosed parameters + if self.current_param_name: + self._end_element("parameter") + + # If about to start new function or tool_call, + # and there are unclosed functions, close function first + if incoming_tag in ("function", "tool_call") and self.current_function_name: + self._end_element("function") + + # If about to start new tool_call, + # and there are unclosed tool_calls, close tool_call first + if incoming_tag == "tool_call" and self.current_call_id: + self._end_element("tool_call") + + def _start_element(self, name: str, attrs: dict[str, str]): + """Handle XML start element events""" + + if name == "root": + return + + if name == "tool_call": + # Before opening new tool_call, + # automatically complete previous unclosed tags + self._auto_close_open_parameter_if_needed("tool_call") + + self.parameters = {} + self.current_call_id = self._get_next_call_id() + self.current_param_is_first = True + self.tool_call_index += 1 + elif name.startswith("function") or (name == "function"): + # If missing tool_call, manually complete + if not self.current_call_id: + self._start_element("tool_call", {}) + # Before opening new function, + # automatically complete previous unclosed tags (parameter/function) + self._auto_close_open_parameter_if_needed("function") + function_name = self._extract_function_name(name, attrs) + self.current_function_name = function_name + self.current_function_open = True + if function_name: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=function_name, arguments="" + ), + ) + ] + ) + self._emit_delta(delta) + elif name.startswith("parameter") or (name == "parameter"): + # If previous parameter hasn't ended normally, + # complete its end first, then start new parameter + self._auto_close_open_parameter_if_needed("parameter") + param_name = self._extract_parameter_name(name, attrs) + self.current_param_name = param_name + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False # Reset start quote flag + + # Only output parameter name and colon, + # don't output quotes + # decide after parameter value type is determined + if param_name: + if not self.parameters: + # First parameter + # start JSON, only output parameter name and colon + json_start = f'{{"{param_name}": ' + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_start + ), + ) + ] + ) + self._emit_delta(delta) + self.current_param_is_first = True + else: + # Subsequent parameters + # add comma and parameter name, no quotes + json_continue = f', "{param_name}": ' + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_continue + ), + ) + ] + ) + self._emit_delta(delta) + self.current_param_is_first = False + + def _char_data(self, data: str): + """Handle XML character data events""" + if data and self.current_param_name: + # If preprocessing stage determines deferred parsing is needed, + # only cache character data, no streaming output + if self.defer_current_parameter: + original_data = data + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + 
self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + return + + param_type = self._get_param_type(self.current_param_name) + + # Check if this is the first time receiving data for this parameter + # If this is the first packet of data and starts with \n, remove \n + if not self.current_param_value and data.startswith("\n"): + data = data[1:] + + # Output start quote for string type (if not already output) + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + and not self.start_quote_emitted + ): + quote_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) + self._emit_delta(quote_delta) + self.start_quote_emitted = True + + if not data: + return + + original_data = data + # Delay output of trailing newline + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + + # convert parameter value by param_type + converted_value = self._convert_param_value( + self.current_param_value, param_type + ) + output_data = self._convert_for_json_streaming(converted_value, param_type) + + delta_data = output_data[len(self.current_param_value_converted) :] + self.current_param_value_converted = output_data + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=delta_data), + ) + ] + ) + self._emit_delta(delta) + + def _end_element(self, name: str): + """Handle XML end element events""" + + if name == "root": + return + + # If function or tool_call ends and there are still unclosed parameters, + # complete parameter end first + if ( + name.startswith("function") or name == "function" or name == "tool_call" + ) and self.current_param_name: + self._auto_close_open_parameter_if_needed() + + if ( + name.startswith("parameter") or name == "parameter" + ) and self.current_param_name: + # End current parameter + param_name = self.current_param_name + param_value = self.current_param_value + + # If in deferred parsing mode, + # perform overall parsing on raw content + # accumulated in preprocessing stage and output once + if self.defer_current_parameter: + raw_text = ( + self.deferred_param_raw_value + if self.deferred_param_raw_value + else param_value + ) + parsed_value = None + output_arguments = None + try: + # If previously delayed trailing newline, + # add it back before parsing + if self.should_emit_end_newline: + raw_for_parse = raw_text + "\n" + else: + raw_for_parse = raw_text + parsed_value = ast.literal_eval(raw_for_parse) + output_arguments = json.dumps(parsed_value, ensure_ascii=False) + except Exception: + # Fallback: output as string as-is + output_arguments = json.dumps(raw_text, ensure_ascii=False) + parsed_value = raw_text + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=output_arguments + ), + ) + ] + ) + self._emit_delta(delta) + + # Clean up and store + self.should_emit_end_newline = False + self.parameters[param_name] = parsed_value + self.current_param_name = None + 
self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + return + + param_type = self._get_param_type(param_name) + + # convert complete parameter value by param_type + converted_value = self._convert_param_value(param_value, param_type) + + # Decide whether to add end quote based on parameter type + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # For empty string parameters, need special handling + if not param_value and not self.start_quote_emitted: + # No start quote output, + # directly output complete empty string + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='""'), + ) + ] + ) + self._emit_delta(delta) + else: + # Non-empty parameter value, output end quote + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) + self._emit_delta(delta) + + self.should_emit_end_newline = False + # Store converted value + self.parameters[param_name] = converted_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + + elif name.startswith("function") or name == "function": + # if there are parameters, close JSON object + if self.parameters: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ] + ) + self._emit_delta(delta) + # return empty object + else: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ] + ) + self._emit_delta(delta) + self.current_function_open = False + + elif name == "tool_call": + # Before ending tool_call, + # ensure function is closed to complete missing right brace + if self.current_function_open: + # If there are still unclosed parameters, close them first + if self.current_param_name: + self._end_element("parameter") + # Close function, ensure output '}' or '{}' + self._end_element("function") + # Final Delta + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ] + ) + self._emit_delta(delta) + + # Check if there's text content to output (between tool_calls) + if self.text_content_buffer.strip(): + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + + self._reset_xml_parser_after_tool_call() + + def setup_parser(self): + """Set up XML parser event handlers""" + self.parser.buffer_text = True + self.parser.StartElementHandler = self._start_element + self.parser.EndElementHandler = self._end_element + self.parser.CharacterDataHandler = self._char_data + + def set_tools(self, tools: Union[list[ChatCompletionToolsParam], None]): + """Set tool configuration information""" + self.tools = tools + + def _get_next_call_id(self): + """Generate unique call ID""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> 
Optional[str]: + """Extract function name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "function": + return parts[1] + + return None + + def _extract_parameter_name( + self, name: str, attrs: dict[str, str] + ) -> Optional[str]: + """Extract parameter name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "parameter": + return parts[1] + + return None + + def _get_param_type(self, param_name: str) -> str: + """Get parameter type based on tool configuration, defaults to string + Args: + param_name: Parameter name + + Returns: + Parameter type + """ + if not self.tools or not self.current_function_name: + return "string" + + for tool in self.tools: + if not hasattr(tool, "type") or not ( + hasattr(tool, "function") and hasattr(tool.function, "name") + ): + continue + if ( + tool.type == "function" + and tool.function.name == self.current_function_name + ): + if not hasattr(tool.function, "parameters"): + return "string" + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + properties = params["properties"] + if param_name in properties and isinstance( + properties[param_name], dict + ): + return self.repair_param_type( + str(properties[param_name].get("type", "string")) + ) + elif isinstance(params, dict) and param_name in params: + param_config = params[param_name] + if isinstance(param_config, dict): + return self.repair_param_type( + str(param_config.get("type", "string")) + ) + break + return "string" + + def repair_param_type(self, param_type: str) -> str: + """Repair unknown parameter types by treating them as string + Args: + param_type: Parameter type + + Returns: + Repaired parameter type + """ + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or ( + param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + ): + return param_type + else: + return "string" + + def _convert_param_value(self, param_value: str, param_type: str) -> Any: + """Convert value based on parameter type + Args: + param_value: Parameter value + param_type: Parameter type + + Returns: + Converted value + """ + if param_value.lower() == "null": + return None + + param_type = param_type.strip().lower() + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer " + "in tool '%s', degenerating to string.", + param_value, + ) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value: float = float(param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) + except 
(ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", + param_value, + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + return param_value == "true" + else: + return param_value + + def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: + """Convert converted_value based on + whether it's empty and if type is string + Args: + converted_value: Converted value + param_type: Parameter type + + Returns: + Converted string for streaming output + """ + # Check if value is empty, but exclude numeric 0 + if converted_value is None or converted_value == "": + return "" + + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # String type, remove double quotes + return json.dumps(converted_value, ensure_ascii=False)[1:-1] + else: + # Non-string type, return complete JSON string + if not isinstance(converted_value, str): + return json.dumps(converted_value, ensure_ascii=False) + else: + return converted_value + + def _reset_xml_parser_after_tool_call(self): + """ + Each tool_call is treated as a separate XML document, + so we need to reset the parser after each tool_call. + """ + + # recreate XML parser + self.parser = ParserCreate() + self.setup_parser() + + # Reset current tool_call state + if self.current_call_id: + self.last_completed_call_id = self.current_call_id + self.current_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + self.text_content_buffer = "" + + # Reset preprocessing and deferred parsing state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + +@ToolParserManager.register_module("qwen3_xml") +class Qwen3XMLToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.parser = StreamingXMLToolCallParser() + + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + self.parser.reset_streaming_state() + if request: + self.parser.set_tools(request.tools) + result = self.parser.parse_single_streaming_chunks(model_output) + if not result.tool_calls: + return ExtractedToolCallInformation( + tool_calls=[], + tools_called=False, + content=result.content, + ) + else: + tool_calls = [] + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_calls.append( + ToolCall( + id=tool_call.id, + type=tool_call.type, + function=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ) + return ExtractedToolCallInformation( + tool_calls=tool_calls, + tools_called=len(tool_calls) > 0, + content=result.content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if not 
previous_text: + self.parser.reset_streaming_state() + if request: + self.parser.set_tools(request.tools) + + # Model sometimes outputs separately causing delta_text to be empty. + # If there were tool_calls before and all current tool_calls have ended, + # return an empty tool_call for outer streaming output + # to correctly output tool_call field + if not delta_text and delta_token_ids: + open_calls = current_text.count( + self.parser.tool_call_start_token + ) - current_text.count(self.parser.tool_call_end_token) + if open_calls == 0 and self.parser.tool_call_index > 0: + # If current_call_id is None, use last_completed_call_id + call_id = ( + self.parser.current_call_id or self.parser.last_completed_call_id + ) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.parser.tool_call_index - 1, + id=call_id, + function=DeltaFunctionCall(arguments=""), + type="function", + ) + ] + ) + + return self.parser.parse_single_streaming_chunks(delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py index 69cf2e68f7c41..2e7bd0d1d344d 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -11,14 +11,20 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -51,33 +57,36 @@ class SeedOssToolParser(ToolParser): self.failed_count: int = 0 self._reset_streaming_state() - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: raise RuntimeError( "Seed_Oss XML parser: tokenizer did not include " - "<seed:tool_call> or its closing tag.") + "<seed:tool_call> or its closing tag." 
+ ) tool_start_re = re.escape(self.tool_call_start_token) tool_end_re = re.escape(self.tool_call_end_token) self.tool_call_complete_regex = re.compile( - rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL + ) self.tool_call_regex = re.compile( - rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", - re.DOTALL) + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", re.DOTALL + ) self.tool_call_function_regex = re.compile( - r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL + ) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL + ) - logger.info("vLLM Seed-Oss XML tool parser loaded (%s).", - self.__class__.__name__) + logger.info( + "vLLM Seed-Oss XML tool parser loaded (%s).", self.__class__.__name__ + ) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -100,20 +109,17 @@ class SeedOssToolParser(ToolParser): self.json_closed = False def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] + self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - def get_arguments_config(func_name: str) -> dict: if tools is None: return {} for config in tools: if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): + hasattr(config, "function") and hasattr(config.function, "name") + ): continue - if (config.type == "function" - and config.function.name == func_name): + if config.type == "function" and config.function.name == func_name: if not hasattr(config.function, "parameters"): return {} params = config.function.parameters @@ -123,12 +129,12 @@ class SeedOssToolParser(ToolParser): return params else: return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) + logger.warning("Tool '%s' is not defined in the tools list.", func_name) return {} - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: + def convert_param_value( + param_value: str, param_name: str, param_config: dict, func_name: str + ) -> Any: # Handle null value for any type if param_value.lower() == "null": return None @@ -138,44 +144,55 @@ class SeedOssToolParser(ToolParser): logger.warning( "Parsed parameter '%s' is not defined in " "the tool parameters for tool '%s', " - "directly returning the string value.", param_name, - func_name) + "directly returning the string value.", + param_name, + func_name, + ) return param_value - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() + if ( + isinstance(param_config[param_name], dict) + and "type" in param_config[param_name] + ): + param_type = str(param_config[param_name]["type"]).strip().lower() else: param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or 
param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): try: param_value = int(param_value) # type: ignore except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not an integer in tool " - "'%s', degenerating to string.", param_value, - param_name, func_name) + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value - elif param_type.startswith("num") or param_type.startswith( - "float"): + elif param_type.startswith("num") or param_type.startswith("float"): try: float_param_value = float(param_value) - param_value = float_param_value if float_param_value - int( - float_param_value) != 0 else int( - float_param_value) # type: ignore + param_value = ( + float_param_value # type: ignore + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) # type: ignore + ) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not a float in tool " - "'%s', degenerating to string.", param_value, - param_name, func_name) + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value elif param_type in ["boolean", "bool", "binary"]: param_value = param_value.lower() @@ -183,7 +200,10 @@ class SeedOssToolParser(ToolParser): logger.warning( "Parsed value '%s' of parameter '%s' is not a boolean " "(`true` of `false`) in tool '%s', degenerating to false.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) return param_value == "true" else: if param_type == "object" or param_type.startswith("dict"): @@ -194,27 +214,33 @@ class SeedOssToolParser(ToolParser): logger.warning( "Parsed value '%s' of parameter '%s' is not a valid JSON " "object in tool '%s', will try other methods to parse it.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) try: param_value = ast.literal_eval(param_value) except (ValueError, SyntaxError): logger.warning( "Parsed value '%s' of parameter '%s' cannot be converted via " "Python `ast.literal_eval()` in tool '%s', degenerating to string.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) return param_value # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] param_config = get_arguments_config(function_name) - parameters = function_call_str[end_index + 1:] + parameters = function_call_str[end_index + 1 :] param_dict = {} for match in self.tool_call_parameter_regex.findall(parameters): match_text = match[0] if match[0] else match[1] idx = match_text.index(">") param_name = match_text[:idx] - param_value = str(match_text[idx + 1:]) + param_value = str(match_text[idx + 1 :]) # Remove prefix and trailing \n if param_value.startswith("\n"): param_value = param_value[1:] @@ -222,12 +248,13 @@ class SeedOssToolParser(ToolParser): param_value = param_value[:-1] param_dict[param_name] = convert_param_value( - param_value, param_name, param_config, function_name) + param_value, param_name, param_config, function_name + ) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(param_dict, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False) + ), ) def _get_function_calls(self, model_output: str) -> list[str]: @@ -243,8 +270,7 @@ class SeedOssToolParser(ToolParser): raw_function_calls = [] for tool_call 
in raw_tool_calls: - raw_function_calls.extend( - self.tool_call_function_regex.findall(tool_call)) + raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call)) function_calls = [ match[0] if match[0] else match[1] for match in raw_function_calls @@ -258,26 +284,32 @@ class SeedOssToolParser(ToolParser): ) -> ExtractedToolCallInformation: # Quick check to avoid unnecessary processing if self.tool_call_prefix not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # Check if both think start and end tokens are present - if (self.think_start_token in model_output - and self.think_end_token in model_output): + if ( + self.think_start_token in model_output + and self.think_end_token in model_output + ): # Find the position of think end token think_end_index = model_output.find(self.think_end_token) + len( - self.think_end_token) + self.think_end_token + ) # Extract content after think end token result_content = model_output[think_end_index:] thinking_content = model_output[:think_end_index] + else: + thinking_content = "" + result_content = model_output try: function_calls = self._get_function_calls(result_content) if len(function_calls) == 0: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_calls = [ self._parse_xml_function_call(function_call_str, request.tools) @@ -288,19 +320,20 @@ class SeedOssToolParser(ToolParser): self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: - self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) + self.prev_tool_call_arr.append( + { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + ) # Extract content before tool calls - tool_call_start_index = result_content.find( - self.tool_call_start_token) + tool_call_start_index = result_content.find(self.tool_call_start_token) tool_call_start_index = ( - tool_call_start_index if tool_call_start_index >= 0 else - result_content.find(self.tool_call_prefix)) + tool_call_start_index + if tool_call_start_index >= 0 + else result_content.find(self.tool_call_prefix) + ) content = thinking_content + result_content[:tool_call_start_index] return ExtractedToolCallInformation( @@ -311,9 +344,9 @@ class SeedOssToolParser(ToolParser): except Exception: logger.exception("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -331,18 +364,18 @@ class SeedOssToolParser(ToolParser): # Check if this is an EOS token after all tool calls are complete # We check for tool calls in the text even if is_tool_call_started # is False because it might have been reset after processing all tools - if (delta_token_ids - and self.tool_call_end_token_id not in delta_token_ids): + if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids: # Count complete tool calls complete_calls = len( - self.tool_call_complete_regex.findall(current_text)) + self.tool_call_complete_regex.findall(current_text) + ) # If we have completed tool 
calls and populated prev_tool_call_arr if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed open_calls = current_text.count( - self.tool_call_start_token) - current_text.count( - self.tool_call_end_token) + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) if open_calls == 0: # Return empty delta message to allow finish_reason processing return DeltaMessage(content="") @@ -372,16 +405,18 @@ class SeedOssToolParser(ToolParser): # Check if there are more tool calls if self.current_tool_index >= current_text.count( - self.tool_call_start_token): + self.tool_call_start_token + ): # No more tool calls self.is_tool_call_started = False # Continue processing next tool return None # Check if end thinking - if (not self.is_thinking_end - and (self.think_end_token_id in delta_token_ids - or self.think_end_token in delta_text)): + if not self.is_thinking_end and ( + self.think_end_token_id in delta_token_ids + or self.think_end_token in delta_text + ): self.is_thinking_end = True # If thinking hasn't ended yet, don't process any tool calls @@ -391,20 +426,25 @@ class SeedOssToolParser(ToolParser): # Handle normal content before tool calls if not self.is_tool_call_started: # Check if tool call is starting - if (self.tool_call_start_token_id in delta_token_ids - or self.tool_call_start_token in delta_text): + if ( + self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text + ): self.is_tool_call_started = True # Return any content before the tool call if self.tool_call_start_token in delta_text: - content_before = delta_text[:delta_text.index( - self.tool_call_start_token)] + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] if content_before: return DeltaMessage(content=content_before) return None else: # Check if we're between tool calls - skip whitespace - if (current_text.rstrip().endswith(self.tool_call_end_token) - and delta_text.strip() == ""): + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): # We just ended a tool call, skip whitespace return None # Normal content, no tool call @@ -420,9 +460,11 @@ class SeedOssToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index # Only process tool calls after think_end_token - think_end_index = current_text.find(self.think_end_token) + len( - self.think_end_token - ) if self.think_end_token in current_text else 0 + think_end_index = ( + current_text.find(self.think_end_token) + len(self.think_end_token) + if self.think_end_token in current_text + else 0 + ) tool_starts: list[int] = [] idx = think_end_index while True: @@ -438,26 +480,26 @@ class SeedOssToolParser(ToolParser): tool_start_idx = tool_starts[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) - tool_end_idx = current_text.find(self.tool_call_end_token, - tool_start_idx) + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) if tool_end_idx == -1: tool_text = current_text[tool_start_idx:] else: - tool_text = current_text[tool_start_idx:tool_end_idx + - len(self.tool_call_end_token)] + tool_text = current_text[ + tool_start_idx : tool_end_idx + len(self.tool_call_end_token) + ] # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: func_start = tool_text.find(self.tool_call_prefix) + len( - 
self.tool_call_prefix) + self.tool_call_prefix + ) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_id = self._generate_tool_call_id( - ) # type: ignore + self.current_tool_id = self._generate_tool_call_id() # type: ignore self.header_sent = True self.in_function = True @@ -465,38 +507,44 @@ class SeedOssToolParser(ToolParser): # This ensures finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name - for tool in self.prev_tool_call_arr) + for tool in self.prev_tool_call_arr + ) if not already_added: - self.prev_tool_call_arr.append({ - "name": self.current_function_name, - "arguments": - "{}", # Placeholder, will be updated later - }) + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) # Send header with function info - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments=""), - type="function", - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) return None # We've sent header, now handle function body if self.in_function: # Send opening brace if not sent yet - if (not self.json_started - and self.parameter_prefix not in delta_text): + if not self.json_started and self.parameter_prefix not in delta_text: self.json_started = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) # Make sure json_started is set if we're processing parameters if not self.json_started: @@ -510,34 +558,38 @@ class SeedOssToolParser(ToolParser): # Extract the complete tool call to update prev_tool_call_arr with final arguments # Find the function content func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix) - func_content_end = tool_text.find(self.function_end_token, - func_start) + self.tool_call_prefix + ) + func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: func_content = tool_text[func_start:func_content_end] # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, request.tools if request else None + ) if parsed_tool: # Update existing entry in prev_tool_call_arr with complete arguments for i, tool in enumerate(self.prev_tool_call_arr): - if tool.get( - "name") == parsed_tool.function.name: + if tool.get("name") == parsed_tool.function.name: self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + parsed_tool.function.arguments + ) break except Exception: logger.warning( "Failed to parse tool arguments during streaming.", - exc_info=True) + exc_info=True, + ) - result = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ]) + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + 
function=DeltaFunctionCall(arguments="}"), + ) + ] + ) # Reset state for next tool self.in_function = False @@ -580,8 +632,7 @@ class SeedOssToolParser(ToolParser): value_text = value_text[1:] # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) + param_end_idx = value_text.find(self.parameter_end_token) if param_end_idx != -1: # Complete parameter found param_value = value_text[:param_end_idx] @@ -591,22 +642,33 @@ class SeedOssToolParser(ToolParser): # Build complete JSON fragment for this parameter if self.param_count == 0: json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + '"' + + self.current_param_name + + '": "' + + json.dumps(param_value)[1:-1] + + '"' + ) else: json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + ', "' + + self.current_param_name + + '": "' + + json.dumps(param_value)[1:-1] + + '"' + ) self.param_count += 1 - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment + ), + ) + ] + ) # Continue parameter value if self.in_param: @@ -618,29 +680,34 @@ class SeedOssToolParser(ToolParser): # Skip past > if at start if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] # Calculate incremental JSON full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = ( + json.dumps(self.current_param_value)[1:-1] + if self.current_param_value + else "" + ) full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + delta_escaped = full_escaped[len(prev_escaped) :] self.in_param = False self.current_param_value = "" - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped + '"'), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"' + ), + ) + ] + ) else: # Continue accumulating value value_chunk = delta_text @@ -648,29 +715,32 @@ class SeedOssToolParser(ToolParser): # Handle first chunk after param name if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = ( + json.dumps(self.current_param_value)[1:-1] + if self.current_param_value + else "" + ) self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + full_escaped = json.dumps(self.current_param_value)[1:-1] + 
delta_escaped = full_escaped[len(prev_escaped) :] if delta_escaped: - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + ), + ) + ] + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index a20d18eb52544..34bd372b2060b 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -8,13 +8,19 @@ from typing import Any, Optional, Union import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -35,9 +41,7 @@ class Step3ToolParser(ToolParser): TOOL_CALL_BEGIN = "<|tool_call_begin|>" TOOL_CALL_END = "<|tool_call_end|>" TOOL_SEP = "<|tool_sep|>" - SPECIAL_TOKENS = [ - TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END - ] + SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END] def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -46,18 +50,16 @@ class Step3ToolParser(ToolParser): self.tool_block_started = False self.tool_block_finished = False - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": request.skip_special_tokens = False return request @staticmethod def _parse_steptml_invoke( - action_text: str + action_text: str, ) -> tuple[Optional[str], Optional[dict[str, str]]]: - func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', - action_text) + func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text) if not func_name_match: return None, None func_name = func_name_match.group(1) @@ -65,7 +67,8 @@ class Step3ToolParser(ToolParser): params: dict[str, str] = {} param_matches = re.findall( r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', - action_text) + action_text, + ) for name, value in param_matches: params[name] = value.strip() return func_name, params @@ -95,11 +98,13 @@ class Step3ToolParser(ToolParser): params[key] = float(value) elif typ == "boolean": lower_val = value.lower() - params[key] = lower_val == "true" if lower_val in ( - "true", "false") else value + params[key] = ( + lower_val == "true" + if lower_val in ("true", "false") + else value + ) elif typ == "null": - params[key] = None if value.lower( - ) == "null" else value + params[key] = None if value.lower() == "null" else value break return params @@ -113,13 +118,12 @@ class Step3ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> 
Union[DeltaMessage, None]: - # The main loop processes the stream from the last known position. while True: if self.position >= len(current_text): return None # We've processed the entire stream. - unprocessed_text = current_text[self.position:] + unprocessed_text = current_text[self.position :] # STATE: After all tools are done, all subsequent text is content. if self.tool_block_finished: @@ -135,8 +139,10 @@ class Step3ToolParser(ToolParser): start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN) if start_pos == -1: - if self.TOOL_CALLS_BEGIN.startswith( - unprocessed_text.strip()) and unprocessed_text: + if ( + self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()) + and unprocessed_text + ): return None # It's a prefix, wait. self.position = len(current_text) return DeltaMessage(content=unprocessed_text) @@ -157,9 +163,9 @@ class Step3ToolParser(ToolParser): continue # Check if we are between tool calls. - tool_finished = ( - self.current_tool_id != -1 and - self.prev_tool_call_arr[self.current_tool_id].get("finished")) + tool_finished = self.current_tool_id != -1 and self.prev_tool_call_arr[ + self.current_tool_id + ].get("finished") if self.current_tool_id == -1 or tool_finished: if unprocessed_text.startswith(self.TOOL_CALL_BEGIN): self.position += len(self.TOOL_CALL_BEGIN) @@ -170,8 +176,7 @@ class Step3ToolParser(ToolParser): self.current_tool_name_sent = False while len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - self.prev_tool_call_arr[ - self.current_tool_id]["finished"] = False + self.prev_tool_call_arr[self.current_tool_id]["finished"] = False continue if self.TOOL_CALL_BEGIN.startswith(unprocessed_text): @@ -179,63 +184,65 @@ class Step3ToolParser(ToolParser): # STATE: Parsing an active tool call. if self.current_tool_id != -1 and not self.prev_tool_call_arr[ - self.current_tool_id].get("finished", False): + self.current_tool_id + ].get("finished", False): end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END) if end_tool_pos == -1: tool_body = unprocessed_text else: tool_body = unprocessed_text[:end_tool_pos] - if end_tool_pos == -1 and self.TOOL_CALL_END.startswith( - tool_body): + if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(tool_body): return None - function_name, arguments = self._parse_steptml_invoke( - tool_body) + function_name, arguments = self._parse_steptml_invoke(tool_body) if not function_name: return None - tool_call_arr = { - "name": function_name, - "parameters": arguments or {} - } + tool_call_arr = {"name": function_name, "parameters": arguments or {}} # Send the function name as soon as it's parsed. if not self.current_tool_name_sent: self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id].update( - tool_call_arr) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=f"chatcmpl-tool-{random_uuid()}", - function=DeltaFunctionCall( - name=function_name)) - ]) + self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall(name=function_name), + ) + ] + ) # Update our internal state with the latest parsed arguments. - self.prev_tool_call_arr[ - self.current_tool_id].update( # noqa: E501 - tool_call_arr) + self.prev_tool_call_arr[self.current_tool_id].update( # noqa: E501 + tool_call_arr + ) # Only send arguments when the tool call is complete. 
if end_tool_pos != -1: self.position += end_tool_pos + len(self.TOOL_CALL_END) - self.prev_tool_call_arr[ - self.current_tool_id]["finished"] = True + self.prev_tool_call_arr[self.current_tool_id]["finished"] = True final_args = self._cast_arguments( function_name, tool_call_arr.get("parameters", {}), # type: ignore - request) + request, + ) if final_args: - final_args_json = json.dumps(final_args, - ensure_ascii=False) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=final_args_json)) - ]) + final_args_json = json.dumps(final_args, ensure_ascii=False) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=final_args_json + ), + ) + ] + ) # If tool is not finished, return None to wait for more tokens. return None @@ -248,15 +255,15 @@ class Step3ToolParser(ToolParser): request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: if self.TOOL_CALLS_BEGIN not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1) if self.TOOL_CALLS_END not in rest: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1) content = (pre_text + post_text).strip() @@ -276,21 +283,22 @@ class Step3ToolParser(ToolParser): if type_part.strip() != "function": continue - function_name, params_dict = self._parse_steptml_invoke( - invoke_part) + function_name, params_dict = self._parse_steptml_invoke(invoke_part) if function_name and params_dict is not None: - params_dict = self._cast_arguments(function_name, params_dict, - request) + params_dict = self._cast_arguments(function_name, params_dict, request) params_str = json.dumps(params_dict, ensure_ascii=False) tool_calls.append( - ToolCall(function=FunctionCall(name=function_name, - arguments=params_str))) + ToolCall( + function=FunctionCall(name=function_name, arguments=params_str) + ) + ) if tool_calls: return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if content else None) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + content=content if content else None, + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index aa41cd6dc53ed..e076ab38e3364 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -22,7 +22,7 @@ def find_common_prefix(s1: str, s2: str) -> str: e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap' """ - prefix = '' + prefix = "" min_length = min(len(s1), len(s2)) for i in range(0, min_length): if s1[i] == s2[i]: @@ -40,7 +40,7 @@ def find_common_suffix(s1: str, s2: str) -> str: e.g. 
find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' """ - suffix = '' + suffix = "" min_length = min(len(s1), len(s2)) for i in range(1, min_length + 1): if s1[-i] == s2[-i] and not s1[-i].isalnum(): @@ -70,15 +70,15 @@ def extract_intermediate_diff(curr: str, old: str) -> str: """ suffix = find_common_suffix(curr, old) - old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + old = old[::-1].replace(suffix[::-1], "", 1)[::-1] prefix = find_common_prefix(curr, old) diff = curr if len(suffix): - diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1] if len(prefix): # replace the prefix only once in case it's mirrored - diff = diff.replace(prefix, '', 1) + diff = diff.replace(prefix, "", 1) return diff diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 87cd413b37200..c1f0d29cc0873 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -8,13 +8,19 @@ from typing import Any, Optional, Union import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -24,7 +30,6 @@ logger = init_logger(__name__) @ToolParserManager.register_module("xlam") class xLAMToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -32,8 +37,7 @@ class xLAMToolParser(ToolParser): self.prev_tool_calls: list[dict] = [] self.current_tool_id = -1 self.current_tool_name_sent = False - self.streamed_args: list[str] = [ - ] # Track arguments sent for each tool + self.streamed_args: list[str] = [] # Track arguments sent for each tool # For backward compatibility with tests self.current_tools_sent: list[bool] = [] @@ -57,7 +61,8 @@ class xLAMToolParser(ToolParser): } def preprocess_model_output( - self, model_output: str) -> tuple[Optional[str], Optional[str]]: + self, model_output: str + ) -> tuple[Optional[str], Optional[str]]: """ Preprocess the model output to extract content and potential tool calls. 
Returns: @@ -66,8 +71,7 @@ class xLAMToolParser(ToolParser): # Check for thinking tag thinking_match = re.search(self.thinking_tag_pattern, model_output) if thinking_match: - content = model_output[:thinking_match.start() + - len("</think>")].strip() + content = model_output[: thinking_match.start() + len("</think>")].strip() thinking_content = thinking_match.group(1).strip() # Try to parse the thinking content as JSON @@ -94,8 +98,7 @@ class xLAMToolParser(ToolParser): try: json.loads(json_str) # Extract content by removing the JSON code block - content = re.sub(json_pattern, "", - model_output).strip() + content = re.sub(json_pattern, "", model_output).strip() return content, json_str except json.JSONDecodeError: continue @@ -107,28 +110,30 @@ class xLAMToolParser(ToolParser): return None, model_output except json.JSONDecodeError: # Even if it's not valid JSON yet, it might be a tool call in progress - if ("{" in model_output and "name" in model_output - and "arguments" in model_output): + if ( + "{" in model_output + and "name" in model_output + and "arguments" in model_output + ): return None, model_output # If no tool calls found, return the original output as content return model_output, None def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract tool calls from a complete model output. """ try: # Preprocess the model output - content, potential_tool_calls = self.preprocess_model_output( - model_output) + content, potential_tool_calls = self.preprocess_model_output(model_output) if not potential_tool_calls: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=content + ) # Parse the potential tool calls as JSON tool_calls_data = json.loads(potential_tool_calls) @@ -145,8 +150,11 @@ class xLAMToolParser(ToolParser): tool_calls: list[ToolCall] = [] for idx, call in enumerate(tool_calls_data): - if (not isinstance(call, dict) or "name" not in call - or "arguments" not in call): + if ( + not isinstance(call, dict) + or "name" not in call + or "arguments" not in call + ): logger.debug("Invalid tool call format at index %d", idx) continue @@ -155,8 +163,11 @@ class xLAMToolParser(ToolParser): type="function", function=FunctionCall( name=call["name"], - arguments=(json.dumps(call["arguments"]) if isinstance( - call["arguments"], dict) else call["arguments"]), + arguments=( + json.dumps(call["arguments"]) + if isinstance(call["arguments"], dict) + else call["arguments"] + ), ), ) tool_calls.append(tool_call) @@ -169,9 +180,9 @@ class xLAMToolParser(ToolParser): except Exception as e: logger.exception("Error extracting tool calls: %s", str(e)) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -186,11 +197,41 @@ class xLAMToolParser(ToolParser): """ Extract tool calls for streaming mode. """ - # Simplify detection: if it begins with "[" treat it as a function call - is_function_call = (current_text.strip().startswith("[")) + # First, check for a definitive start of a tool call block. + # This prevents premature parsing of incomplete output. 
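The new detection logic above deliberately errs on the side of recognizing a tool-call block early, before the JSON is complete. A compact standalone sketch of that heuristic (marker strings copied from the hunk; the helper name and the sample strings are illustrative):

def looks_like_tool_call_block(text: str) -> bool:
    """Heuristic: does the accumulated text look like (the start of) a
    tool-call block rather than plain assistant content?"""
    stripped = text.strip()
    has_json_block = (
        "```json" in text
        or "```\n[" in text
        or "[TOOL_CALLS]" in text
        or "<tool_call>" in text
    )
    return (
        stripped.startswith("[")
        or stripped.startswith("<tool_call>")
        or stripped.startswith("[TOOL_CALLS]")
        or "</think>[" in text
        or (has_json_block and '"name"' in text and '"arguments"' in text)
    )

assert looks_like_tool_call_block('[{"name": "get_weather", "arguments": {}}]')
assert looks_like_tool_call_block('```json\n[{"name": "f", "arguments": {"x"')
assert not looks_like_tool_call_block("The weather in Paris is sunny today.")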
+ stripped_text = current_text.strip() + preprocessed_content, preprocessed_tool_calls = self.preprocess_model_output( + current_text + ) - # If not a function call, return normal content - if not is_function_call: + # For JSON code blocks, we need to detect them earlier, even if incomplete + has_potential_json_block = ( + "```json" in current_text + or "```\n[" in current_text + or "[TOOL_CALLS]" in current_text + or "<tool_call>" in current_text + ) + + is_tool_call_block = ( + stripped_text.startswith("[") + or stripped_text.startswith("<tool_call>") + or stripped_text.startswith("[TOOL_CALLS]") + or + # Check if we have thinking tags with JSON-like content following + ("</think>[" in current_text) + or + # Check if the text contains a JSON array after preprocessing + preprocessed_tool_calls is not None + or + # For JSON code blocks, detect early if we see enough structure + ( + has_potential_json_block + and '"name"' in current_text + and '"arguments"' in current_text + ) + ) + + if not is_tool_call_block: return DeltaMessage(content=delta_text) try: @@ -204,7 +245,11 @@ class xLAMToolParser(ToolParser): # Try parsing as JSON to check for complete tool calls try: - parsed_tools = json.loads(current_text) + # Use preprocessed tool calls if available + tool_calls_text = ( + preprocessed_tool_calls if preprocessed_tool_calls else current_text + ) + parsed_tools = json.loads(tool_calls_text) if isinstance(parsed_tools, list): # Update our tool array for next time self.prev_tool_call_arr = parsed_tools @@ -214,11 +259,15 @@ class xLAMToolParser(ToolParser): # Check for test-specific state setup (current_tools_sent) # This handles the case where tests manually set current_tools_sent - if (hasattr(self, "current_tools_sent") # type: ignore - and len(self.current_tools_sent) > 0): + if ( + hasattr(self, "current_tools_sent") # type: ignore + and len(self.current_tools_sent) > 0 + ): # If current_tools_sent is set to [False], it means the test wants us to send the name - if (len(self.current_tools_sent) == 1 - and self.current_tools_sent[0] is False): + if ( + len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False + ): # Extract the function name using regex name_pattern = r'"name"\s*:\s*"([^"]+)"' name_match = re.search(name_pattern, current_text) @@ -227,54 +276,81 @@ class xLAMToolParser(ToolParser): # The test expects us to send just the name first tool_id = make_tool_call_id() - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=0, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) # Update state to reflect that we've sent the name self.current_tools_sent = [True] self.current_tool_id = 0 self.streaming_state["current_tool_index"] = 0 if len(self.streaming_state["sent_tools"]) == 0: - self.streaming_state["sent_tools"].append({ - "sent_name": - True, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": True, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) else: - self.streaming_state["sent_tools"][0][ - "sent_name"] = True + self.streaming_state["sent_tools"][0]["sent_name"] = True self.current_tool_name_sent = True return delta # Use regex to identify tool calls in 
the output + # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks + search_text = ( + preprocessed_tool_calls if preprocessed_tool_calls else current_text + ) + + # For JSON code blocks that aren't complete yet, try to extract the JSON content + if not preprocessed_tool_calls and has_potential_json_block: + # Try to extract the JSON array from within the code block + json_match = re.search( + r"```(?:json)?\s*([\s\S]*?)(?:```|$)", current_text + ) + if json_match: + potential_json = json_match.group(1).strip() + # Use this as search text even if it's incomplete + if potential_json.startswith("[") and ( + '"name"' in potential_json and '"arguments"' in potential_json + ): + search_text = potential_json + + # Try to find complete tool names first name_pattern = r'"name"\s*:\s*"([^"]+)"' - name_matches = list(re.finditer(name_pattern, current_text)) + name_matches = list(re.finditer(name_pattern, search_text)) tool_count = len(name_matches) - # If no tools found yet, return + # If no complete tool names found, check for partial tool names if tool_count == 0: - return None + # Check if we're in the middle of parsing a tool name + partial_name_pattern = r'"name"\s*:\s*"([^"]*)' + partial_matches = list(re.finditer(partial_name_pattern, search_text)) + if partial_matches: + # We have a partial tool name - not ready to emit yet + return None + else: + # No tools found at all + return None # Ensure our state arrays are large enough while len(self.streaming_state["sent_tools"]) < tool_count: - self.streaming_state["sent_tools"].append({ - "sent_name": - False, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) while len(self.streaming_state["tool_ids"]) < tool_count: self.streaming_state["tool_ids"].append(None) @@ -287,14 +363,13 @@ class xLAMToolParser(ToolParser): next_idx = current_idx + 1 # If tool at next_idx has not been sent yet - if (next_idx < tool_count - and not self.streaming_state["sent_tools"][next_idx] - ["sent_name"]): + if ( + next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx]["sent_name"] + ): # Update indexes self.streaming_state["current_tool_index"] = next_idx - self.current_tool_id = ( - next_idx # For backward compatibility - ) + self.current_tool_id = next_idx # For backward compatibility current_idx = next_idx # Extract the tool name @@ -304,21 +379,20 @@ class xLAMToolParser(ToolParser): tool_id = f"call_{current_idx}_{random_uuid()}" self.streaming_state["tool_ids"][current_idx] = tool_id - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=tool_name).model_dump( - exclude_none=True), # type: ignore - ) - ]) - self.streaming_state["sent_tools"][current_idx][ - "sent_name"] = True - self.current_tool_name_sent = ( - True # For backward compatibility + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), # type: ignore + ) + ] ) + self.streaming_state["sent_tools"][current_idx]["sent_name"] = True + self.current_tool_name_sent = True # For backward compatibility # Keep track of streamed args for backward compatibility while len(self.streamed_args) <= current_idx: @@ -331,8 +405,9 @@ class xLAMToolParser(ToolParser): # 
Support both regular and empty argument objects # First, check for the empty arguments case: "arguments": {} empty_args_pattern = ( - r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') - empty_args_match = re.search(empty_args_pattern, current_text) + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}' + ) + empty_args_match = re.search(empty_args_pattern, search_text) # Check if this tool has empty arguments if empty_args_match and empty_args_match.start() > 0: @@ -341,42 +416,45 @@ class xLAMToolParser(ToolParser): for i in range(tool_count): if i == current_idx: # If this is our current tool and it has empty arguments - if not self.streaming_state["sent_tools"][ - current_idx]["sent_arguments_prefix"]: + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix" + ]: # Send empty object - self.streaming_state["sent_tools"][ - current_idx][ - "sent_arguments_prefix"] = True - self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] = "{}" + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix" + ] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments" + ] = "{}" # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{}"). - model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}" + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) # Move to next tool if available if current_idx < tool_count - 1: - self.streaming_state[ - "current_tool_index"] += 1 + self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return delta # Extract arguments for current tool using regex for non-empty arguments args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})' - args_matches = list(re.finditer(args_pattern, current_text)) + args_matches = list(re.finditer(args_pattern, search_text)) if current_idx < len(args_matches): args_text = args_matches[current_idx].group(1) @@ -384,69 +462,82 @@ class xLAMToolParser(ToolParser): # Handle transition between tools is_last_tool = current_idx == tool_count - 1 - # Find where the arguments for our current tool end - if not is_last_tool: - # If we have more tools after this one, try to find the complete argument block - next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) - if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1 - ) # +1 to include the '}' - args_text = (current_text[args_matches[current_idx] - .start():args_end_pos]. 
- split('"arguments":')[1].strip()) + # For multiple tools, extract only the arguments for the current tool + if tool_count > 1: + # Parse the entire JSON structure to properly extract arguments for each tool + try: + parsed_tools = json.loads(search_text) + if isinstance(parsed_tools, list) and current_idx < len( + parsed_tools + ): + current_tool = parsed_tools[current_idx] + if isinstance(current_tool.get("arguments"), dict): + args_text = json.dumps(current_tool["arguments"]) + else: + args_text = str(current_tool.get("arguments", "{}")) + except (json.JSONDecodeError, KeyError, IndexError): + # Fallback to regex-based extraction + pass # If arguments haven't been sent yet - sent_args = self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] + sent_args = self.streaming_state["sent_tools"][current_idx][ + "sent_arguments" + ] # If we haven't sent the opening bracket yet if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] and args_text.startswith( - "{"): + "sent_arguments_prefix" + ] and args_text.startswith("{"): self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{" + "sent_arguments" + ] = "{" # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{").model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{" + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) return delta # If we need to send more arguments if args_text.startswith(sent_args): # Calculate what part of arguments we need to send - args_diff = args_text[len(sent_args):] + args_diff = args_text[len(sent_args) :] if args_diff: # Update our state self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = args_text + "sent_arguments" + ] = args_text # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += args_diff - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments=args_diff).model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) return delta # If the tool's arguments are complete, check if we need to move to the next tool @@ -455,7 +546,8 @@ class xLAMToolParser(ToolParser): if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] # For compatibility + "current_tool_index" + ] # For compatibility # If we got here, we couldn't determine what to stream next return None diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py new file mode 100644 index 0000000000000..98c9cbbbd376e --- /dev/null +++ b/vllm/entrypoints/renderer.py @@ -0,0 +1,411 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import io +from abc import ABC, abstractmethod 
+from dataclasses import dataclass +from typing import Annotated, Optional, Union + +import pybase64 +import torch +from pydantic import Field + +from vllm.config import ModelConfig +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TextPrompt as EngineTextPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.parse import get_prompt_components, parse_raw_prompts +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AsyncMicrobatchTokenizer + + +@dataclass(frozen=True) +class RenderConfig: + """Configuration to control how prompts are prepared.""" + + max_length: Optional[int] = None + """Maximum allowable total input token length. If provided, + token inputs longer than this raise ``ValueError``.""" + + truncate_prompt_tokens: Optional[int] = None + """Number of tokens to keep. ``None`` means no truncation. + ``0`` yields an empty list (and skips embeds). + ``-1`` maps to ``model_config.max_model_len``.""" + + add_special_tokens: Optional[bool] = True + """Whether to add model-specific special tokens during tokenization.""" + + cache_salt: Optional[str] = None + """String to disambiguate prefix cache entries.""" + + needs_detokenization: Optional[bool] = False + """If True, detokenize IDs back to text for inclusion in outputs.""" + + def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> Optional[int]: + """Validate and normalize `truncate_prompt_tokens` parameter.""" + truncate_prompt_tokens = self.truncate_prompt_tokens + if truncate_prompt_tokens is None: + return None + + if truncate_prompt_tokens == 0: + return 0 + + if truncate_prompt_tokens < 0: + truncate_prompt_tokens = model_config.max_model_len + + max_length = self.max_length + if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] + raise ValueError( + f"{truncate_prompt_tokens=} cannot be greater than " + f"{max_length=}. Please select a smaller truncation size." + ) + + return truncate_prompt_tokens + + +class BaseRenderer(ABC): + """ + Base class for unified input processing and rendering. + + The Renderer serves as a unified input processor that consolidates + tokenization, chat template formatting, and multimodal input handling + into a single component. + It converts high-level API requests (OpenAI-style JSON) into token IDs and + multimodal features ready for engine consumption. + + Key responsibilities: + - Convert text prompts to token sequences with proper special tokens + - Apply chat templates and format conversations + - Handle multimodal inputs (images, audio, etc.) when applicable + - Manage prompt truncation and length validation + - Provide clean separation between API layer and engine core + """ + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + ): + super().__init__() + self.model_config = model_config + self.tokenizer = tokenizer + + @abstractmethod + async def render_prompt( + self, + *, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + config: RenderConfig, + ) -> list[EngineTokensPrompt]: + """ + Convert text or token inputs into engine-ready TokensPrompt objects. + + This method accepts text or token inputs and produces a + list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects + for the engine. + + Args: + prompt_or_prompts: One of: + - ``str``: Single text prompt. + - ``list[str]``: Batch of text prompts. + - ``list[int]``: Single pre-tokenized sequence. 
+ - ``list[list[int]]``: Batch of pre-tokenized sequences. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). + + Returns: + list[EngineTokensPrompt]: Engine-ready token prompts. + + Raises: + ValueError: If input formats are invalid or length limits exceeded. + """ + raise NotImplementedError + + @abstractmethod + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[ + Union[str, list[str], list[int], list[list[int]]] + ] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: RenderConfig, + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Convert text/token and/or base64-encoded embeddings inputs into + engine-ready prompt objects using a unified RenderConfig. + + At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be + provided and non-empty. If both are omitted or empty (e.g., empty + string and empty list), a ``ValueError`` is raised. + + Args: + prompt_or_prompts: Text or token inputs to include. + prompt_embeds: Base64-encoded bytes (or list thereof) containing a + torch-saved tensor to be used as prompt embeddings. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). + + Returns: + list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + Engine-ready prompt objects. + + Raises: + ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds`` + are omitted or empty (decoder prompt cannot be empty), or if + length limits are exceeded. + """ + raise NotImplementedError + + @classmethod + def load_prompt_embeds( + cls, + prompt_embeds: Union[bytes, list[bytes]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None, + cache_salt: Optional[str] = None, + ) -> list[EngineEmbedsPrompt]: + """Load and validate base64-encoded embeddings into prompt objects.""" + + def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) + if cache_salt is not None: + embeds_prompt["cache_salt"] = cache_salt + return embeds_prompt + + if isinstance(prompt_embeds, list): + return [_load_and_validate_embed(embed) for embed in prompt_embeds] + + return [_load_and_validate_embed(prompt_embeds)] + + +class CompletionRenderer(BaseRenderer): + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + async_tokenizer_pool: Optional[ + dict[AnyTokenizer, AsyncMicrobatchTokenizer] + ] = None, + ): + super().__init__(model_config, tokenizer) + self.async_tokenizer_pool = async_tokenizer_pool + self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None + + async def render_prompt( + self, + *, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + config: RenderConfig, + ) -> list[EngineTokensPrompt]: + """Implementation of prompt rendering for completion-style requests. + + Uses async tokenizer pooling for improved performance. See base class + for detailed parameter documentation. 
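load_prompt_embeds above expects each item to be a base64-encoded, torch.save-serialized 2-D tensor in float32/float16/bfloat16. A minimal client-side sketch of producing such a payload; the shape and dtype are arbitrary placeholders:

import base64
import io

import torch

# A fake [num_tokens, hidden_size] embedding tensor, purely for illustration.
embeds = torch.randn(8, 4096, dtype=torch.float16)

buf = io.BytesIO()
torch.save(embeds, buf)
payload = base64.b64encode(buf.getvalue())  # bytes, ready for prompt_embeds
# On the server side, load_prompt_embeds() base64-decodes this, torch.load()s
# it with weights_only=True, and checks that the result is a 2-D float tensor.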
+ """ + truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) + if truncate_prompt_tokens == 0: + return [] + + tasks = ( + self._create_prompt( + prompt_input, + config=config, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + for prompt_input in parse_raw_prompts(prompt_or_prompts) + ) + + return await asyncio.gather(*tasks) + + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[ + Union[str, list[str], list[int], list[list[int]]] + ] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: RenderConfig, + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Render text/token prompts and/or precomputed embedding prompts. At + least one of `prompt_or_prompts` or `prompt_embeds` must be provided. + """ + truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) + if truncate_prompt_tokens == 0: + return [] + + rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = [] + + if prompt_embeds is not None: + rendered.extend( + self.load_prompt_embeds( + prompt_embeds, truncate_prompt_tokens, config.cache_salt + ) + ) + if prompt_or_prompts is None or prompt_or_prompts == "": + return rendered + + token_prompts = await self.render_prompt( + prompt_or_prompts=prompt_or_prompts, + config=config, + ) + rendered.extend(token_prompts) + + return rendered + + def _maybe_apply_truncation( + self, token_ids: list[int], truncate_prompt_tokens: Optional[int] + ) -> list[int]: + """Apply truncation to token sequence.""" + if truncate_prompt_tokens is None: + return token_ids + if truncate_prompt_tokens >= len(token_ids): + return token_ids + + return token_ids[-truncate_prompt_tokens:] + + async def _create_prompt( + self, + prompt_input: Union[EngineTextPrompt, EngineTokensPrompt], + config: RenderConfig, + truncate_prompt_tokens: Optional[int], + ) -> EngineTokensPrompt: + prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) + + if prompt_token_ids is not None: + # NOTE: detokenization is needed when echo is enabled, + # where the input token IDs are decoded back to text. 
+ return await self._create_prompt_from_token_ids( + prompt_token_ids, + config.max_length, + truncate_prompt_tokens, + config.cache_salt, + config.needs_detokenization, + ) + + if prompt is not None: + return await self._create_prompt_from_text( + prompt, + config.max_length, + truncate_prompt_tokens, + config.add_special_tokens, + config.cache_salt, + ) + + # TODO: Also handle embeds prompt using this method + raise NotImplementedError + + async def _create_prompt_from_text( + self, + text: str, + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + add_special_tokens: Optional[bool], + cache_salt: Optional[str], + ) -> EngineTokensPrompt: + """Tokenize text input asynchronously.""" + async_tokenizer = self._get_async_tokenizer() + + # Handle encoder-specific preprocessing + if ( + self.model_config.encoder_config is not None + and self.model_config.encoder_config.get("do_lower_case", False) + ): + text = text.lower() + + # Tokenize texts + if truncate_prompt_tokens is None: + encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens) + else: + encoded = await async_tokenizer( + text, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens, + ) + + return self._create_tokens_prompt( + encoded.input_ids, max_length, cache_salt, text + ) + + async def _create_prompt_from_token_ids( + self, + token_ids: list[int], + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + cache_salt: Optional[str], + needs_detokenization: Optional[bool] = False, + ) -> EngineTokensPrompt: + """Optionally detokenize token IDs and build a tokens prompt.""" + token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) + + prompt = None + if needs_detokenization: + async_tokenizer = self._get_async_tokenizer() + prompt = await async_tokenizer.decode(token_ids) + + return self._create_tokens_prompt( + token_ids=token_ids, + max_length=max_length, + cache_salt=cache_salt, + prompt=prompt, + ) + + def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: + """Get or create async tokenizer using shared pool.""" + async_tokenizer = self.async_tokenizer + if async_tokenizer is not None: + return async_tokenizer + + tokenizer = self.tokenizer + if self.tokenizer is None: + raise ValueError("No tokenizer available for text input processing") + + if self.async_tokenizer_pool is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + else: + async_tokenizer = self.async_tokenizer_pool.get(tokenizer) + if async_tokenizer is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + self.async_tokenizer_pool[tokenizer] = async_tokenizer + self.async_tokenizer = async_tokenizer + return async_tokenizer + + def _create_tokens_prompt( + self, + token_ids: list[int], + max_length: Optional[int] = None, + cache_salt: Optional[str] = None, + prompt: Optional[str] = None, + ) -> EngineTokensPrompt: + """Create validated EngineTokensPrompt.""" + if max_length is not None and len(token_ids) > max_length: + raise ValueError( + f"This model's maximum context length is {max_length} tokens. " + f"However, your request has {len(token_ids)} input tokens. " + "Please reduce the length of the input messages." 
+ ) + + tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) + if cache_salt is not None: + tokens_prompt["cache_salt"] = cache_salt + if prompt is not None: + tokens_prompt["prompt"] = prompt + return tokens_prompt diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 642d6389539bc..1fb56d246debe 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -7,31 +7,39 @@ from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( - BaseMultiModalItemTracker, ChatCompletionContentPartImageEmbedsParam, - ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam, - MultiModalItemTracker, _ContentPart, _parse_chat_message_content_part) + BaseMultiModalItemTracker, + ChatCompletionContentPartImageEmbedsParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + MultiModalItemTracker, + _ContentPart, + _parse_chat_message_content_part, +) from vllm.inputs import TokensPrompt from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import PoolingRequestOutput -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - PreTrainedTokenizer, - PreTrainedTokenizerFast) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) ScoreContentPartParam: TypeAlias = Union[ - ChatCompletionContentPartImageParam, - ChatCompletionContentPartImageEmbedsParam] + ChatCompletionContentPartImageParam, ChatCompletionContentPartImageEmbedsParam +] class ScoreMultiModalParam(TypedDict, total=False): """ A specialized parameter type for scoring multimodal content - + The reasons why don't reuse `CustomChatCompletionMessageParam` directly: 1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions 2. Including chat-specific fields would confuse users about their purpose in scoring 3. 
This is a more focused interface that only exposes what's needed for scoring - """ # noqa: E501 + """ # noqa: E501 + content: Required[list[ScoreContentPartParam]] """The multimodal contents""" @@ -41,7 +49,6 @@ def _cosine_similarity( embed_1: list[PoolingRequestOutput], embed_2: list[PoolingRequestOutput], ) -> list[PoolingRequestOutput]: - scorer = CosineSimilarity(0) scores: Union[list[PoolingRequestOutput]] = [] @@ -49,8 +56,7 @@ def _cosine_similarity( pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) padding = [] - if (pad_token_id := getattr(tokenizer, "pad_token_id", - None)) is not None: + if (pad_token_id := getattr(tokenizer, "pad_token_id", None)) is not None: padding = [pad_token_id] tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids @@ -60,7 +66,9 @@ def _cosine_similarity( request_id=f"{emb_1.request_id}_{emb_2.request_id}", outputs=pair_score, prompt_token_ids=tokens, - finished=True)) + finished=True, + ) + ) return scores @@ -96,8 +104,7 @@ def parse_score_data( if content is not None and isinstance(content, str): return cast(str, content) else: - raise ValueError( - f"Only string content is supported, but got {content}.") + raise ValueError(f"Only string content is supported, but got {content}.") prompt_1 = ensure_str(content_1) prompt_2 = ensure_str(content_2) @@ -109,7 +116,6 @@ def _parse_score_content( data: Union[str, ScoreContentPartParam], mm_tracker: BaseMultiModalItemTracker, ) -> Optional[_ContentPart]: - if isinstance(data, str): data = ChatCompletionContentPartTextParam(type="text", text=data) @@ -127,8 +133,10 @@ def _parse_score_content( mm_placeholder_storage = mm_parser.mm_placeholder_storage() - if len(mm_placeholder_storage) != 1 or len( - next(iter(mm_placeholder_storage.values()))) != 1: + if ( + len(mm_placeholder_storage) != 1 + or len(next(iter(mm_placeholder_storage.values()))) != 1 + ): raise ValueError("Only one multi-modal item is supported") return next(iter(mm_placeholder_storage.values()))[0] @@ -149,8 +157,7 @@ def apply_score_template( raise ValueError("Get empty score template from model") return full_prompt - raise ValueError( - f"Unsupported model architecture: {model_config.architecture}") + raise ValueError(f"Unsupported model architecture: {model_config.architecture}") def post_process_tokens( @@ -159,7 +166,7 @@ def post_process_tokens( ) -> None: """ Perform architecture-specific manipulations on the input tokens. - + Note: This is an in-place operation. """ @@ -192,9 +199,9 @@ def get_score_prompt( prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) elif model_config.use_pad_token: # cross_encoder models defaults to using pad_token. - prompt_inputs = tokenizer(text=prompt_1, - text_pair=prompt_2, - **tokenization_kwargs) + prompt_inputs = tokenizer( + text=prompt_1, text_pair=prompt_2, **tokenization_kwargs + ) full_prompt = tokenizer.decode(prompt_inputs["input_ids"]) else: # `llm as reranker` models defaults to not using pad_token. @@ -219,8 +226,10 @@ def compress_token_type_ids(token_type_ids: list[int]) -> int: if not found. 
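compress_token_type_ids above assumes the token_type_ids form a run of zeros followed by a run of ones and returns the index where the ones start. A simplified standalone sketch of the same check (not the vLLM function itself), with a couple of examples:

def first_one_index(token_type_ids: list[int]) -> int:
    """Index where the run of ones starts; len(token_type_ids) if absent."""
    ones_start = (
        token_type_ids.index(1) if 1 in token_type_ids else len(token_type_ids)
    )
    expected = [0] * ones_start + [1] * (len(token_type_ids) - ones_start)
    if token_type_ids != expected:
        raise ValueError(
            "Token type ids are expected to be a sequence of zeros "
            "followed by a sequence of ones"
        )
    return ones_start

assert first_one_index([0, 0, 0, 1, 1]) == 3
assert first_one_index([0, 0]) == 2  # no ones: index just past the end
# first_one_index([0, 1, 0, 1]) raises ValueError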
""" first_one = len(token_type_ids) - err_msg = "Token type ids are expected to be a sequence"\ - " of zeros followed by a sequence of ones" + err_msg = ( + "Token type ids are expected to be a sequence" + " of zeros followed by a sequence of ones" + ) for i, type_id in enumerate(token_type_ids): if type_id == 0 and first_one < i: raise ValueError(err_msg) diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py index e3646a60a7cc1..ff0dd1bbfc6bb 100644 --- a/vllm/entrypoints/ssl.py +++ b/vllm/entrypoints/ssl.py @@ -17,11 +17,13 @@ class SSLCertRefresher: reloads them when they change. """ - def __init__(self, - ssl_context: SSLContext, - key_path: Optional[str] = None, - cert_path: Optional[str] = None, - ca_path: Optional[str] = None) -> None: + def __init__( + self, + ssl_context: SSLContext, + key_path: Optional[str] = None, + cert_path: Optional[str] = None, + ca_path: Optional[str] = None, + ) -> None: self.ssl = ssl_context self.key_path = key_path self.cert_path = cert_path @@ -36,8 +38,10 @@ class SSLCertRefresher: self.watch_ssl_cert_task = None if self.key_path and self.cert_path: self.watch_ssl_cert_task = asyncio.create_task( - self._watch_files([self.key_path, self.cert_path], - update_ssl_cert_chain)) + self._watch_files( + [self.key_path, self.cert_path], update_ssl_cert_chain + ) + ) # Setup CA files watcher def update_ssl_ca(change: Change, file_path: str) -> None: @@ -48,22 +52,21 @@ class SSLCertRefresher: self.watch_ssl_ca_task = None if self.ca_path: self.watch_ssl_ca_task = asyncio.create_task( - self._watch_files([self.ca_path], update_ssl_ca)) + self._watch_files([self.ca_path], update_ssl_ca) + ) - async def _watch_files(self, paths, fun: Callable[[Change, str], - None]) -> None: + async def _watch_files(self, paths, fun: Callable[[Change, str], None]) -> None: """Watch multiple file paths asynchronously.""" logger.info("SSLCertRefresher monitors files: %s", paths) async for changes in awatch(*paths): try: for change, file_path in changes: - logger.info("File change detected: %s - %s", change.name, - file_path) + logger.info("File change detected: %s - %s", change.name, file_path) fun(change, file_path) except Exception as e: logger.error( - "SSLCertRefresher failed taking action on file change. " - "Error: %s", e) + "SSLCertRefresher failed taking action on file change. Error: %s", e + ) def stop(self) -> None: """Stop watching files.""" diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index 758789a5e059d..c74ce1ee16de1 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -4,6 +4,8 @@ import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from openai_harmony import Author, Message, Role, TextContent + from vllm.logger import init_logger if TYPE_CHECKING: @@ -12,10 +14,12 @@ if TYPE_CHECKING: logger = init_logger(__name__) +MIN_GPT_OSS_VERSION = "0.0.7" + def validate_gpt_oss_install(): """ - Check if the gpt-oss is installed and its version is at least 0.0.3. + Check if the gpt-oss is installed and its version is at least 0.0.7. If not, raise an ImportError. 
""" from importlib.metadata import PackageNotFoundError, version @@ -23,29 +27,27 @@ def validate_gpt_oss_install(): from packaging.version import InvalidVersion, Version try: - pkg_version_str = version("gpt_oss") # e.g., "0.0.5" + pkg_version_str = version("gpt_oss") pkg_version = Version(pkg_version_str) except PackageNotFoundError: raise ImportError("Package 'gpt_oss' is not installed.") from None except InvalidVersion as e: - raise ImportError( - f"Invalid version string for 'gpt_oss': {e}") from None + raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None - if pkg_version < Version("0.0.3"): + if pkg_version < Version(MIN_GPT_OSS_VERSION): raise ImportError( - f"gpt_oss >= 0.0.3 is required, but {pkg_version} is installed." + f"gpt_oss >= {MIN_GPT_OSS_VERSION} is required, " + f"but {pkg_version} is installed." ) from None class Tool(ABC): - @abstractmethod async def get_result(self, context: "ConversationContext") -> Any: pass class HarmonyBrowserTool(Tool): - def __init__(self): self.enabled = True exa_api_key = os.getenv("EXA_API_KEY") @@ -61,8 +63,8 @@ class HarmonyBrowserTool(Tool): except ImportError as e: self.enabled = False logger.warning_once( - "gpt_oss is not installed properly (%s), browsing is disabled", - e) + "gpt_oss is not installed properly (%s), browsing is disabled", e + ) return browser_backend = ExaBackend(source="web", api_key=exa_api_key) @@ -71,6 +73,7 @@ class HarmonyBrowserTool(Tool): async def get_result(self, context: "ConversationContext") -> Any: from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) last_msg = context.messages[-1] tool_output_msgs = [] @@ -84,7 +87,6 @@ class HarmonyBrowserTool(Tool): class HarmonyPythonTool(Tool): - def __init__(self): self.enabled = True @@ -94,15 +96,41 @@ class HarmonyPythonTool(Tool): except ImportError as e: self.enabled = False logger.warning_once( - "gpt_oss is not installed properly (%s), code interpreter is " - "disabled", e) + "gpt_oss is not installed properly (%s), code interpreter is disabled", + e, + ) return self.python_tool = PythonTool() + + async def validate(self): + if not self.enabled: + return + try: + message = Message( + author=Author(role=Role.ASSISTANT), + content=[TextContent(text="print('Hello, world!')")], + channel="analysis", + recipient="python", + content_type="code", + ) + msgs = [] + async for msg in self.python_tool.process(message): + msgs.append(msg) + assert msgs[0].content[0].text == "Hello, world!\n" + except Exception as e: + self.enabled = False + logger.warning_once( + "Code interpreter tool failed to initialize (%s), code " + "interpreter is disabled", + e, + ) + return logger.info_once("Code interpreter tool initialized") async def get_result(self, context: "ConversationContext") -> Any: from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) last_msg = context.messages[-1] tool_output_msgs = [] diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 2f28595f27c6a..b3dceecc15834 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -19,8 +19,10 @@ async def list_server_and_tools(server_url: str): from mcp import ClientSession from mcp.client.sse import sse_client - async with sse_client(url=server_url) as streams, ClientSession( - *streams) as session: + async with ( + sse_client(url=server_url) as streams, + ClientSession(*streams) as session, + ): initialize_response = await session.initialize() 
list_tools_response = await session.list_tools() return initialize_response, list_tools_response @@ -38,21 +40,22 @@ def trim_schema(schema: dict) -> dict: # if there's more than 1 types, also remove "null" type as Harmony will # just ignore it types = [ - type_dict["type"] for type_dict in schema["anyOf"] - if type_dict["type"] != 'null' + type_dict["type"] + for type_dict in schema["anyOf"] + if type_dict["type"] != "null" ] schema["type"] = types del schema["anyOf"] if "properties" in schema: schema["properties"] = { - k: trim_schema(v) - for k, v in schema["properties"].items() + k: trim_schema(v) for k, v in schema["properties"].items() } return schema def post_process_tools_description( - list_tools_result: "ListToolsResult") -> "ListToolsResult": + list_tools_result: "ListToolsResult", +) -> "ListToolsResult": # Adapt the MCP tool result for Harmony for tool in list_tools_result.tools: tool.inputSchema = trim_schema(tool.inputSchema) @@ -60,7 +63,8 @@ def post_process_tools_description( # Some tools schema don't need to be part of the prompt (e.g. simple text # in text out for Python) list_tools_result.tools = [ - tool for tool in list_tools_result.tools + tool + for tool in list_tools_result.tools if getattr(tool.annotations, "include_in_prompt", True) ] @@ -68,7 +72,6 @@ def post_process_tools_description( class ToolServer(ABC): - @abstractmethod def has_tool(self, tool_name: str) -> bool: """ @@ -77,8 +80,7 @@ class ToolServer(ABC): pass @abstractmethod - def get_tool_description(self, - tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: """ Return the tool description for the given tool name. If the tool is not supported, return None. @@ -86,7 +88,9 @@ class ToolServer(ABC): pass @abstractmethod - def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: + def new_session( + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + ) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. """ @@ -94,14 +98,14 @@ class ToolServer(ABC): class MCPToolServer(ToolServer): - def __init__(self): try: import mcp # noqa: F401 except ImportError: raise ImportError( "mcp is not installed. Please run `pip install mcp` to use " - "MCPToolServer.") from None + "MCPToolServer." + ) from None self.harmony_tool_descriptions = {} async def add_tool_server(self, server_url: str): @@ -110,30 +114,35 @@ class MCPToolServer(ToolServer): self.urls: dict[str, str] = {} for url in tool_urls: url = f"http://{url}/sse" - initialize_response, list_tools_response = ( - await list_server_and_tools(url)) + initialize_response, list_tools_response = await list_server_and_tools(url) - list_tools_response = post_process_tools_description( - list_tools_response) + list_tools_response = post_process_tools_description(list_tools_response) tool_from_mcp = ToolNamespaceConfig( name=initialize_response.serverInfo.name, description=initialize_response.instructions, tools=[ - ToolDescription.new(name=tool.name, - description=tool.description, - parameters=tool.inputSchema) + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.inputSchema, + ) for tool in list_tools_response.tools - ]) + ], + ) self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp if tool_from_mcp.name not in self.urls: self.urls[tool_from_mcp.name] = url else: logger.warning( "Tool %s already exists. 
Ignoring duplicate tool server %s", - tool_from_mcp.name, url) - logger.info("MCPToolServer initialized with tools: %s", - list(self.harmony_tool_descriptions.keys())) + tool_from_mcp.name, + url, + ) + logger.info( + "MCPToolServer initialized with tools: %s", + list(self.harmony_tool_descriptions.keys()), + ) def has_tool(self, tool_name: str): return tool_name in self.harmony_tool_descriptions @@ -142,36 +151,46 @@ class MCPToolServer(ToolServer): return self.harmony_tool_descriptions.get(tool_name) @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session( + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + ): from mcp import ClientSession from mcp.client.sse import sse_client + url = self.urls.get(tool_name) + request_headers = {"x-session-id": session_id} + if headers is not None: + request_headers.update(headers) if not url: raise KeyError(f"Tool '{tool_name}' is not supported") - async with sse_client(url=url) as streams, ClientSession( - *streams) as session: + async with ( + sse_client(url=url, headers=request_headers) as streams, + ClientSession(*streams) as session, + ): await session.initialize() yield session class DemoToolServer(ToolServer): - def __init__(self): self.tools: dict[str, Tool] = {} + + async def init_and_validate(self): browser_tool = HarmonyBrowserTool() + python_tool = HarmonyPythonTool() + await python_tool.validate() if browser_tool.enabled: self.tools["browser"] = browser_tool - python_tool = HarmonyPythonTool() if python_tool.enabled: self.tools["python"] = python_tool - logger.info("DemoToolServer initialized with tools: %s", - list(self.tools.keys())) + logger.info( + "DemoToolServer initialized with tools: %s", list(self.tools.keys()) + ) def has_tool(self, tool_name: str) -> bool: return tool_name in self.tools - def get_tool_description(self, - tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: if tool_name not in self.tools: return None if tool_name == "browser": @@ -182,7 +201,9 @@ class DemoToolServer(ToolServer): raise ValueError(f"Unknown tool {tool_name}") @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session( + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + ): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") yield self.tools[tool_name] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index d8905fc141245..c97ca6538814d 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse import asyncio import dataclasses import functools import os -import subprocess -import sys +from argparse import Namespace from typing import Any, Optional, Union from fastapi import Request @@ -16,8 +14,7 @@ from starlette.background import BackgroundTask, BackgroundTasks from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -25,13 +22,11 @@ from vllm.utils import FlexibleArgumentParser logger = 
init_logger(__name__) VLLM_SUBCMD_PARSER_EPILOG = ( - "Tip: Use `vllm [serve|run-batch|bench <bench_type>] " - "--help=<keyword>` to explore arguments from help.\n" - " - To view a argument group: --help=ModelConfig\n" - " - To view a single argument: --help=max-num-seqs\n" - " - To search by keyword: --help=max\n" - " - To list all groups: --help=listgroup\n" - " - To view help with pager: --help=page") + "For full list: vllm {subcmd} --help=all\n" + "For a section: vllm {subcmd} --help=ModelConfig (case-insensitive)\n" # noqa: E501 + "For a flag: vllm {subcmd} --help=max-model-len (_ or - accepted)\n" # noqa: E501 + "Documentation: https://docs.vllm.ai\n" +) async def listen_for_disconnect(request: Request) -> None: @@ -42,9 +37,9 @@ async def listen_for_disconnect(request: Request) -> None: # If load tracking is enabled *and* the counter exists, decrement # it. Combines the previous nested checks into a single condition # to satisfy the linter rule. - if (getattr(request.app.state, "enable_server_load_tracking", - False) - and hasattr(request.app.state, "server_load_metrics")): + if getattr( + request.app.state, "enable_server_load_tracking", False + ) and hasattr(request.app.state, "server_load_metrics"): request.app.state.server_load_metrics -= 1 break @@ -75,15 +70,15 @@ def with_cancellation(handler_func): # normal route handler, with the correct request type hinting. @functools.wraps(handler_func) async def wrapper(*args, **kwargs): - # The request is either the second positional arg or `raw_request` request = args[1] if len(args) > 1 else kwargs["raw_request"] handler_task = asyncio.create_task(handler_func(*args, **kwargs)) cancellation_task = asyncio.create_task(listen_for_disconnect(request)) - done, pending = await asyncio.wait([handler_task, cancellation_task], - return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + [handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED + ) for task in pending: task.cancel() @@ -99,18 +94,16 @@ def decrement_server_load(request: Request): def load_aware_call(func): - @functools.wraps(func) async def wrapper(*args, **kwargs): - raw_request = kwargs.get("raw_request", - args[1] if len(args) > 1 else None) + raw_request = kwargs.get("raw_request", args[1] if len(args) > 1 else None) if raw_request is None: raise ValueError( - "raw_request required when server load tracking is enabled") + "raw_request required when server load tracking is enabled" + ) - if not getattr(raw_request.app.state, "enable_server_load_tracking", - False): + if not getattr(raw_request.app.state, "enable_server_load_tracking", False): return await func(*args, **kwargs) # ensure the counter exists @@ -126,18 +119,18 @@ def load_aware_call(func): if isinstance(response, (JSONResponse, StreamingResponse)): if response.background is None: - response.background = BackgroundTask(decrement_server_load, - raw_request) + response.background = BackgroundTask(decrement_server_load, raw_request) elif isinstance(response.background, BackgroundTasks): - response.background.add_task(decrement_server_load, - raw_request) + response.background.add_task(decrement_server_load, raw_request) elif isinstance(response.background, BackgroundTask): # Convert the single BackgroundTask to BackgroundTasks # and chain the decrement_server_load task to it tasks = BackgroundTasks() - tasks.add_task(response.background.func, - *response.background.args, - **response.background.kwargs) + tasks.add_task( + response.background.func, + *response.background.args, + 
**response.background.kwargs, + ) tasks.add_task(decrement_server_load, raw_request) response.background = tasks else: @@ -174,7 +167,6 @@ def _validate_truncation_size( truncate_prompt_tokens: Optional[int], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> Optional[int]: - if truncate_prompt_tokens is not None: if truncate_prompt_tokens <= -1: truncate_prompt_tokens = max_model_len @@ -183,7 +175,8 @@ def _validate_truncation_size( raise ValueError( f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " f"is greater than max_model_len ({max_model_len})." - f" Please, select a smaller truncation size.") + f" Please, select a smaller truncation size." + ) if tokenization_kwargs is not None: tokenization_kwargs["truncation"] = True @@ -196,116 +189,33 @@ def _validate_truncation_size( return truncate_prompt_tokens -def _output_with_pager(text: str): - """Output text using scrolling view if available and appropriate.""" - - pagers = ['less -R', 'more'] - for pager_cmd in pagers: - try: - proc = subprocess.Popen(pager_cmd.split(), - stdin=subprocess.PIPE, - text=True) - proc.communicate(input=text) - return - except (subprocess.SubprocessError, OSError, FileNotFoundError): - continue - - # No pager worked, fall back to normal print - print(text) - - -def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, - subcommand_name: list[str]): - - # Only handle --help=<keyword> for the current subcommand. - # Since subparser_init() runs for all subcommands during CLI setup, - # we skip processing if the subcommand name is not in sys.argv. - # sys.argv[0] is the program name. The subcommand follows. - # e.g., for `vllm bench latency`, - # sys.argv is `['vllm', 'bench', 'latency', ...]` - # and subcommand_name is "bench latency". 
- if len(sys.argv) <= len(subcommand_name) or sys.argv[ - 1:1 + len(subcommand_name)] != subcommand_name: - return - - for arg in sys.argv: - if arg.startswith('--help='): - search_keyword = arg.split('=', 1)[1] - - # Enable paged view for full help - if search_keyword == 'page': - help_text = parser.format_help() - _output_with_pager(help_text) - sys.exit(0) - - # List available groups - if search_keyword == 'listgroup': - output_lines = ["\nAvailable argument groups:"] - for group in parser._action_groups: - if group.title and not group.title.startswith( - "positional arguments"): - output_lines.append(f" - {group.title}") - if group.description: - output_lines.append(" " + - group.description.strip()) - output_lines.append("") - _output_with_pager("\n".join(output_lines)) - sys.exit(0) - - # For group search - formatter = parser._get_formatter() - for group in parser._action_groups: - if group.title and group.title.lower() == search_keyword.lower( - ): - formatter.start_section(group.title) - formatter.add_text(group.description) - formatter.add_arguments(group._group_actions) - formatter.end_section() - _output_with_pager(formatter.format_help()) - sys.exit(0) - - # For single arg - matched_actions = [] - - for group in parser._action_groups: - for action in group._group_actions: - # search option name - if any(search_keyword.lower() in opt.lower() - for opt in action.option_strings): - matched_actions.append(action) - - if matched_actions: - header = f"\nParameters matching '{search_keyword}':\n" - formatter = parser._get_formatter() - formatter.add_arguments(matched_actions) - _output_with_pager(header + formatter.format_help()) - sys.exit(0) - - print(f"\nNo group or parameter matching '{search_keyword}'") - print("Tip: use `--help=listgroup` to view all groups.") - sys.exit(1) - - -def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest, - CompletionRequest], - input_length: int, default_sampling_params: dict) -> int: - - max_tokens = getattr(request, "max_completion_tokens", - None) or request.max_tokens +def get_max_tokens( + max_model_len: int, + request: Union[ChatCompletionRequest, CompletionRequest], + input_length: int, + default_sampling_params: dict, +) -> int: + max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens default_max_tokens = max_model_len - input_length max_output_tokens = current_platform.get_max_output_tokens(input_length) - return min(val - for val in (default_max_tokens, max_tokens, max_output_tokens, - default_sampling_params.get("max_tokens")) - if val is not None) + return min( + val + for val in ( + default_max_tokens, + max_tokens, + max_output_tokens, + default_sampling_params.get("max_tokens"), + ) + if val is not None + ) -def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]): +def log_non_default_args(args: Union[Namespace, EngineArgs]): non_default_args = {} - # Handle argparse.Namespace - if isinstance(args, argparse.Namespace): + # Handle Namespace + if isinstance(args, Namespace): parser = make_arg_parser(FlexibleArgumentParser()) for arg, default in vars(parser.parse_args([])).items(): if default != getattr(args, arg): @@ -313,14 +223,17 @@ def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]): # Handle EngineArgs instance elif isinstance(args, EngineArgs): - default_args = EngineArgs() # Create default instance + default_args = EngineArgs(model=args.model) # Create default instance for field in dataclasses.fields(args): current_val = getattr(args, field.name) 
default_val = getattr(default_args, field.name) if current_val != default_val: non_default_args[field.name] = current_val + if default_args.model != EngineArgs.model: + non_default_args["model"] = default_args.model else: - raise TypeError("Unsupported argument type. " \ - "Must be argparse.Namespace or EngineArgs instance.") + raise TypeError( + "Unsupported argument type. Must be Namespace or EngineArgs instance." + ) logger.info("non-default args: %s", non_default_args) diff --git a/vllm/env_override.py b/vllm/env_override.py index ef425d433320d..7f9054e738463 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -13,29 +13,11 @@ logger = init_logger(__name__) # that interact with vllm workers. # they are executed whenever `import vllm` is called. -if os.environ.get('NCCL_CUMEM_ENABLE', '0') != '0': - logger.warning( - "NCCL_CUMEM_ENABLE is set to %s, skipping override. " - "This may increase memory overhead with cudagraph+allreduce: " - "https://github.com/NVIDIA/nccl/issues/1234", - os.environ['NCCL_CUMEM_ENABLE']) -elif not os.path.exists('/dev/nvidia-caps-imex-channels'): - # NCCL requires NCCL_CUMEM_ENABLE to work with - # multi-node NVLink, typically on GB200-NVL72 systems. - # The ultimate way to detect multi-node NVLink is to use - # NVML APIs, which are too expensive to call here. - # As an approximation, we check the existence of - # /dev/nvidia-caps-imex-channels, used by - # multi-node NVLink to communicate across nodes. - # This will still cost some GPU memory, but it is worthwhile - # because we can get very fast cross-node bandwidth with NVLink. - os.environ['NCCL_CUMEM_ENABLE'] = '0' - # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() -os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' +os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" # see https://github.com/vllm-project/vllm/issues/10480 -os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 diff --git a/vllm/envs.py b/vllm/envs.py index 5d0e972f43ad0..ab8548cf50661 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +import json import os import sys import tempfile -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union if TYPE_CHECKING: VLLM_HOST_IP: str = "" @@ -17,7 +18,6 @@ if TYPE_CHECKING: LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False - VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -31,14 +31,16 @@ if TYPE_CHECKING: VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" VLLM_NO_USAGE_STATS: bool = False + VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" VLLM_CONFIGURE_LOGGING: int = 1 VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" + VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None - VLLM_LOG_STATS_INTERVAL: float = 10. 
+ VLLM_LOG_STATS_INTERVAL: float = 10.0 VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None @@ -54,27 +56,30 @@ if TYPE_CHECKING: VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False - VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto" + VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_XLA_USE_SPMD: bool = False - VLLM_WORKER_MULTIPROC_METHOD: str = "fork" + VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") + VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" + VLLM_MAIN_CUDA_VERSION: str = "12.8" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False - CMAKE_BUILD_TYPE: Optional[str] = None + CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release", "RelWithDebInfo"]] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms @@ -84,12 +89,15 @@ if TYPE_CHECKING: VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False + VLLM_USE_AOT_COMPILE: bool = False + VLLM_FORCE_AOT_LOAD: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_DISABLE_PYNCCL: bool = False VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False @@ -98,6 +106,10 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False + VLLM_ROCM_USE_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -111,16 +123,19 @@ if TYPE_CHECKING: VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False + VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32 VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 + VLLM_USE_STANDALONE_COMPILE: bool = False VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False + VLLM_RAY_DP_PACK_STRATEGY: str = "strict" VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MXFP4_USE_MARLIN: Optional[bool] = None VLLM_V0_USE_OUTLINES_CACHE: bool = False @@ -128,41 +143,72 @@ if TYPE_CHECKING: 
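The `Literal[...]` annotations in this `TYPE_CHECKING` block only inform type checkers; at runtime each value comes from the `environment_variables` dict defined later in the file, presumably exposed through a module-level `__getattr__`. A minimal sketch of that pattern (not vLLM's exact code):

```python
import os
from typing import Any, Callable

# Trimmed-down illustration of lazy, per-access env var resolution.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_LOG_STATS_INTERVAL": lambda: float(
        os.getenv("VLLM_LOG_STATS_INTERVAL", "10.0")
    ),
}


def __getattr__(name: str) -> Any:
    # Invoked only for names not found as regular module attributes,
    # so each variable is read from the environment on access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```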
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_USING_PATHWAYS: bool = False - VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True + VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False - VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" + VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput" VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" - VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 - VLLM_ALL2ALL_BACKEND: str = "naive" + VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 + VLLM_ALL2ALL_BACKEND: Literal[ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ] = "allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 - VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False - VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[ + "FP", "INT8", "INT6", "INT4", "NONE" + ] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None - VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 + VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False - VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None + VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False + VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False + VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True + VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" + VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 + VLLM_DBO_COMM_SMS: int = 20 + GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] + VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None + VLLM_DEBUG_DUMP_PATH: Optional[str] = None + VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True + VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True + VLLM_USE_NCCL_SYMM_MEM: bool = False + VLLM_NCCL_INCLUDE_PATH: Optional[str] = None + VLLM_USE_FBGEMM: bool = False + VLLM_GC_DEBUG: str = "" def get_default_cache_root(): @@ -191,6 +237,113 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: return bool(int(value)) +def use_aot_compile() -> bool: + from vllm.utils import is_torch_equal_or_newer + + 
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" + return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" + + +def env_with_choices( + env_name: str, + default: Optional[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True, +) -> Callable[[], Optional[str]]: + """ + Create a lambda that validates environment variable against allowed choices + + Args: + env_name: Name of the environment variable + default: Default value if not set (can be None) + choices: List of valid string options or callable that returns list + case_sensitive: Whether validation should be case sensitive + + Returns: + Lambda function for environment_variables dict + """ + + def _get_validated_env() -> Optional[str]: + value = os.getenv(env_name) + if value is None: + return default + + # Resolve choices if it's a callable (for lazy loading) + actual_choices = choices() if callable(choices) else choices + + if not case_sensitive: + check_value = value.lower() + check_choices = [choice.lower() for choice in actual_choices] + else: + check_value = value + check_choices = actual_choices + + if check_value not in check_choices: + raise ValueError( + f"Invalid value '{value}' for {env_name}. " + f"Valid options: {actual_choices}." + ) + + return value + + return _get_validated_env + + +def env_list_with_choices( + env_name: str, + default: list[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True, +) -> Callable[[], list[str]]: + """ + Create a lambda that validates environment variable + containing comma-separated values against allowed choices + + Args: + env_name: Name of the environment variable + default: Default list of values if not set + choices: List of valid string options or callable that returns list + case_sensitive: Whether validation should be case sensitive + + Returns: + Lambda function for environment_variables + dict that returns list of strings + """ + + def _get_validated_env_list() -> list[str]: + value = os.getenv(env_name) + if value is None: + return default + + # Split comma-separated values and strip whitespace + values = [v.strip() for v in value.split(",") if v.strip()] + + if not values: + return default + + # Resolve choices if it's a callable (for lazy loading) + actual_choices = choices() if callable(choices) else choices + + # Validate each value + for val in values: + if not case_sensitive: + check_value = val.lower() + check_choices = [choice.lower() for choice in actual_choices] + else: + check_value = val + check_choices = actual_choices + + if check_value not in check_choices: + raise ValueError( + f"Invalid value '{val}' in {env_name}. " + f"Valid options: {actual_choices}." + ) + + return values + + return _get_validated_env_list + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -200,15 +353,16 @@ def get_vllm_port() -> Optional[int]: Raises: ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue. 
""" - if 'VLLM_PORT' not in os.environ: + if "VLLM_PORT" not in os.environ: return None - port = os.getenv('VLLM_PORT', '0') + port = os.getenv("VLLM_PORT", "0") try: return int(port) except ValueError as err: from urllib.parse import urlparse + parsed = urlparse(port) if parsed.scheme: raise ValueError( @@ -216,8 +370,7 @@ def get_vllm_port() -> Optional[int]: "This may be caused by a Kubernetes service discovery issue," "check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html" ) from None - raise ValueError( - f"VLLM_PORT '{port}' must be a valid integer") from err + raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err # The begin-* and end* here are used by the documentation generator @@ -226,293 +379,269 @@ def get_vllm_port() -> Optional[int]: # --8<-- [start:env-vars-definition] environment_variables: dict[str, Callable[[], Any]] = { - # ================== Installation Time Env Vars ================== - # Target device of vLLM, supporting [cuda (by default), - # rocm, neuron, cpu] - "VLLM_TARGET_DEVICE": - lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), - + # rocm, cpu] + "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), + # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], + # 12.8 is the default. This follows PyTorch but can be overridden. + "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() + or "12.8", # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs - "MAX_JOBS": - lambda: os.getenv("MAX_JOBS", None), - + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), # Number of threads to use for nvcc # By default this is 1. # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. - "NVCC_THREADS": - lambda: os.getenv("NVCC_THREADS", None), - + "NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None), # If set, vllm will use precompiled binaries (*.so) - "VLLM_USE_PRECOMPILED": - lambda: os.environ.get("VLLM_USE_PRECOMPILED", "").strip().lower() in - ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), - + "VLLM_USE_PRECOMPILED": lambda: os.environ.get("VLLM_USE_PRECOMPILED", "") + .strip() + .lower() + in ("1", "true") + or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), # Used to mark that setup.py is running in a Docker build context, # in order to force the use of precompiled binaries. - "VLLM_DOCKER_BUILD_CONTEXT": - lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "").strip().lower() in - ("1", "true"), - + "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") + .strip() + .lower() + in ("1", "true"), # Whether to force using nightly wheel in python build. # This is used for testing the nightly wheel in python build. 
- "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": - lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) - ), - + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool( + int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" - "CMAKE_BUILD_TYPE": - lambda: os.getenv("CMAKE_BUILD_TYPE"), - + "CMAKE_BUILD_TYPE": env_with_choices( + "CMAKE_BUILD_TYPE", None, ["Debug", "Release", "RelWithDebInfo"] + ), # If set, vllm will print verbose logs during installation - "VERBOSE": - lambda: bool(int(os.getenv('VERBOSE', '0'))), - + "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))), # Root directory for vLLM configuration files # Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set # Note that this not only affects how vllm finds its configuration files # during runtime, but also affects how vllm installs its configuration # files during **installation**. - "VLLM_CONFIG_ROOT": - lambda: os.path.expanduser( + "VLLM_CONFIG_ROOT": lambda: os.path.expanduser( os.getenv( "VLLM_CONFIG_ROOT", os.path.join(get_default_config_root(), "vllm"), - )), - + ) + ), # ================== Runtime Env Vars ================== - # Root directory for vLLM cache files # Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set - "VLLM_CACHE_ROOT": - lambda: os.path.expanduser( + "VLLM_CACHE_ROOT": lambda: os.path.expanduser( os.getenv( "VLLM_CACHE_ROOT", os.path.join(get_default_cache_root(), "vllm"), - )), - + ) + ), # used in distributed environment to determine the ip address # of the current node, when the node has multiple network interfaces. # If you are using multi-node inference, you should set this differently # on each node. - 'VLLM_HOST_IP': - lambda: os.getenv('VLLM_HOST_IP', ""), - + "VLLM_HOST_IP": lambda: os.getenv("VLLM_HOST_IP", ""), # used in distributed environment to manually set the communication port # Note: if VLLM_PORT is set, and some code asks for multiple ports, the # VLLM_PORT will be used as the first port, and the rest will be generated # by incrementing the VLLM_PORT value. - 'VLLM_PORT': - get_vllm_port, - + "VLLM_PORT": get_vllm_port, # path used for ipc when the frontend api server is running in # multi-processing mode to communicate with the backend engine process. - 'VLLM_RPC_BASE_PATH': - lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()), - + "VLLM_RPC_BASE_PATH": lambda: os.getenv( + "VLLM_RPC_BASE_PATH", tempfile.gettempdir() + ), # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers - "VLLM_USE_MODELSCOPE": - lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", - + "VLLM_USE_MODELSCOPE": lambda: os.environ.get( + "VLLM_USE_MODELSCOPE", "False" + ).lower() + == "true", # Interval in seconds to log a warning message when the ring buffer is full - "VLLM_RINGBUFFER_WARNING_INTERVAL": - lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), - + "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int( + os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60") + ), # path to cudatoolkit home directory, under which should be bin, include, # and lib directories. - "CUDA_HOME": - lambda: os.environ.get("CUDA_HOME", None), - + "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), # Path to the NCCL library file. 
It is needed because nccl>=2.19 brought # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 - "VLLM_NCCL_SO_PATH": - lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), - + "VLLM_NCCL_SO_PATH": lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl # library file in the locations specified by `LD_LIBRARY_PATH` - "LD_LIBRARY_PATH": - lambda: os.environ.get("LD_LIBRARY_PATH", None), - + "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), # flag to control if vllm should use triton flash attention - "VLLM_USE_TRITON_FLASH_ATTN": - lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in - ("true", "1")), - + "VLLM_USE_TRITON_FLASH_ATTN": lambda: ( + os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1") + ), # Use separate prefill and decode kernels for V1 attention instead of # the unified triton kernel. - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": - lambda: - (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in - ("true", "1")), - - # Use AITER triton unified attention for V1 attention - "VLLM_USE_AITER_UNIFIED_ATTENTION": - lambda: - (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in - ("true", "1")), - + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: ( + os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() + in ("true", "1") + ), # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. - "VLLM_FLASH_ATTN_VERSION": - lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), - - # Internal flag to enable Dynamo fullgraph capture - "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": - lambda: bool( - os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), - + "VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int( + os.environ.get("VLLM_FLASH_ATTN_VERSION", None) + ), # Feature flag to enable/disable Inductor standalone compile. # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is - # enabled by default. - "VLLM_USE_STANDALONE_COMPILE": - lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", - + # disabled by default. + "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( + "VLLM_USE_STANDALONE_COMPILE", "0" + ) + == "1", + # Debug pattern matching inside custom passes. + # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). + "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( + "VLLM_PATTERN_MATCH_DEBUG", None + ), + # Dump fx graphs to the given directory. + # It will override CompilationConfig.debug_dump_path if set. + "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), + # Feature flag to enable/disable AOT compilation. This will ensure + # compilation is done in warmup phase and the compilation will be + # reused in subsequent calls. + "VLLM_USE_AOT_COMPILE": use_aot_compile, + # Force vllm to always load AOT compiled models from disk. Failure + # to load will result in a hard error when this is enabled. + # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. 
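`VLLM_USE_AOT_COMPILE`, registered just above, is the one entry whose default is computed rather than constant: `use_aot_compile()` turns it on by default only when PyTorch is at least `2.10.0.dev`. A rough standalone approximation, assuming the `packaging` library instead of vLLM's `is_torch_equal_or_newer` helper:

```python
import os

import torch
from packaging.version import Version


def aot_compile_default_enabled() -> bool:
    # Approximation: default to "1" on torch >= 2.10.0.dev, otherwise "0",
    # then let an explicit VLLM_USE_AOT_COMPILE setting override it.
    new_enough = Version(torch.__version__) >= Version("2.10.0.dev0")
    default_value = "1" if new_enough else "0"
    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
```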
+ "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id - "LOCAL_RANK": - lambda: int(os.environ.get("LOCAL_RANK", "0")), - + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), # used to control the visible devices in the distributed setting - "CUDA_VISIBLE_DEVICES": - lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), - + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), # timeout for each iteration in the engine - "VLLM_ENGINE_ITERATION_TIMEOUT_S": - lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), - + "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int( + os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60") + ), # API key for vLLM API server - "VLLM_API_KEY": - lambda: os.environ.get("VLLM_API_KEY", None), - + "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), # Whether to log responses from API Server for debugging - "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": - lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" - ).lower() == "true", - + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": lambda: os.environ.get( + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" + ).lower() + == "true", # S3 access information, used for tensorizer to load model from S3 - "S3_ACCESS_KEY_ID": - lambda: os.environ.get("S3_ACCESS_KEY_ID", None), - "S3_SECRET_ACCESS_KEY": - lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), - "S3_ENDPOINT_URL": - lambda: os.environ.get("S3_ENDPOINT_URL", None), - + "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": lambda: os.environ.get("S3_ENDPOINT_URL", None), # Usage stats collection - "VLLM_USAGE_STATS_SERVER": - lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), - "VLLM_NO_USAGE_STATS": - lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", - "VLLM_DO_NOT_TRACK": - lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( - "DO_NOT_TRACK", None) or "0") == "1", - "VLLM_USAGE_SOURCE": - lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), - + "VLLM_USAGE_STATS_SERVER": lambda: os.environ.get( + "VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai" + ), + "VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DISABLE_FLASHINFER_PREFILL": lambda: os.environ.get( + "VLLM_DISABLE_FLASHINFER_PREFILL", "0" + ) + == "1", + "VLLM_DO_NOT_TRACK": lambda: ( + os.environ.get("VLLM_DO_NOT_TRACK", None) + or os.environ.get("DO_NOT_TRACK", None) + or "0" + ) + == "1", + "VLLM_USAGE_SOURCE": lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), # Logging configuration # If set to 0, vllm will not configure logging # If set to 1, vllm will configure logging using the default configuration # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH - "VLLM_CONFIGURE_LOGGING": - lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), - "VLLM_LOGGING_CONFIG_PATH": - lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), - + "VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), # this is used for configuring the default logging level - "VLLM_LOGGING_LEVEL": - lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), - + "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), + 
# this is used for configuring the default logging stream + "VLLM_LOGGING_STREAM": lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages - "VLLM_LOGGING_PREFIX": - lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), - + "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), # if set, vllm will call logits processors in a thread pool with this many # threads. This is useful when using custom logits processors that either # (a) launch additional CUDA kernels or (b) do significant CPU-bound work # while not holding the python GIL, or both. - "VLLM_LOGITS_PROCESSOR_THREADS": - lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) - if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, - + "VLLM_LOGITS_PROCESSOR_THREADS": lambda: int( + os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0") + ) + if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ + else None, # If set, vllm will log stats at this interval in seconds # If not set, vllm will log stats every 10 seconds. - "VLLM_LOG_STATS_INTERVAL": - lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) - > 0. else 10., - + "VLLM_LOG_STATS_INTERVAL": lambda: val + if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0 + else 10.0, # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging - "VLLM_TRACE_FUNCTION": - lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), - + "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), # Backend for attention computation - # Available options: + # Example options: # - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "FLASH_ATTN": use FlashAttention # - "XFORMERS": use XFormers - # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA - "VLLM_ATTENTION_BACKEND": - lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), - + # - "FLASH_ATTN_MLA": use FlashAttention for MLA + # - "FLASHINFER_MLA": use FlashInfer for MLA + # - "CUTLASS_MLA": use CUTLASS for MLA + # All possible options loaded dynamically from _Backend enum + "VLLM_ATTENTION_BACKEND": env_with_choices( + "VLLM_ATTENTION_BACKEND", + None, + lambda: list( + __import__( + "vllm.attention.backends.registry", fromlist=["_Backend"] + )._Backend.__members__.keys() + ), + ), # If set, vllm will use flashinfer sampler - "VLLM_USE_FLASHINFER_SAMPLER": - lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) - if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - + "VLLM_USE_FLASHINFER_SAMPLER": lambda: bool( + int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]) + ) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ + else None, # Pipeline stage partition strategy - "VLLM_PP_LAYER_PARTITION": - lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), - + "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), # (CPU backend only) CPU key-value cache space. # default is None and will be set as 4 GB - "VLLM_CPU_KVCACHE_SPACE": - lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) - if "VLLM_CPU_KVCACHE_SPACE" in os.environ else None, - + "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) + if "VLLM_CPU_KVCACHE_SPACE" in os.environ + else None, # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. 
- "VLLM_CPU_OMP_THREADS_BIND": - lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), - + "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), # (CPU backend only) CPU cores not used by OMP threads . # Those CPU cores will not be used by OMP threads of a rank. - "VLLM_CPU_NUM_OF_RESERVED_CPU": - lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")) - if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None, - + "VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: int( + os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0") + ) + if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ + else None, # (CPU backend only) whether to use prepack for MoE layer. This will be # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might # need to set this to "0" (False). - "VLLM_CPU_MOE_PREPACK": - lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), - + "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. - "VLLM_CPU_SGL_KERNEL": - lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - + "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. - "VLLM_USE_RAY_SPMD_WORKER": - lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), - + "VLLM_USE_RAY_SPMD_WORKER": lambda: bool( + int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0")) + ), # If the env var is set, it uses the Ray's Compiled Graph # (previously known as ADAG) API which optimizes the # control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Note that this variable is set to 1 in V1 by default # when ray distributed executor is used. - "VLLM_USE_RAY_COMPILED_DAG": - lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), - + "VLLM_USE_RAY_COMPILED_DAG": lambda: bool( + int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0")) + ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -521,63 +650,69 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "nccl": use NCCL for communication # - "shm": use shared memory and gRPC for communication # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. - "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": - lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"), - + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices( + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"] + ), # If the env var is set, it enables GPU communication overlap # (experimental feature) in Ray's Compiled Graph. This flag is ignored if # VLLM_USE_RAY_COMPILED_DAG is not set. - "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": - lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) - ), - + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool( + int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) + ), # If the env var is set, it uses a Ray Communicator wrapping # vLLM's pipeline parallelism communicator to interact with Ray's # Compiled Graph. Otherwise, it uses Ray's NCCL communicator. # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. 
- "VLLM_USE_RAY_WRAPPED_PP_COMM": - lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))), - + "VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool( + int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work - "VLLM_WORKER_MULTIPROC_METHOD": - lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork"), - + "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( + "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"] + ), # Path to the cache for storing downloaded assets - "VLLM_ASSETS_CACHE": - lambda: os.path.expanduser( + "VLLM_ASSETS_CACHE": lambda: os.path.expanduser( os.getenv( "VLLM_ASSETS_CACHE", os.path.join(get_default_cache_root(), "vllm", "assets"), - )), - + ) + ), + # If the env var is set, we will clean model file in + # this path $VLLM_ASSETS_CACHE/model_streamer/$model_name + "VLLM_ASSETS_CACHE_MODEL_CLEAN": lambda: bool( + int(os.getenv("VLLM_ASSETS_CACHE_MODEL_CLEAN", "0")) + ), # Timeout for fetching images when serving multimodal models # Default is 5 seconds - "VLLM_IMAGE_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), - + "VLLM_IMAGE_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), # Timeout for fetching videos when serving multimodal models # Default is 30 seconds - "VLLM_VIDEO_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")), - + "VLLM_VIDEO_FETCH_TIMEOUT": lambda: int( + os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30") + ), # Timeout for fetching audio when serving multimodal models # Default is 10 seconds - "VLLM_AUDIO_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), - + "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( + os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") + ), + # Whether to allow HTTP redirects when fetching from media URLs. + # Default to True + "VLLM_MEDIA_URL_ALLOW_REDIRECTS": lambda: bool( + int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1")) + ), # Max number of workers for the thread pool handling # media bytes loading. Set to 1 to disable parallel processing. # Default is 8 - "VLLM_MEDIA_LOADING_THREAD_COUNT": - lambda: int(os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8")), - + "VLLM_MEDIA_LOADING_THREAD_COUNT": lambda: int( + os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8") + ), # Maximum filesize in MB for a single audio file when processing # speech-to-text requests. Files larger than this will be rejected. # Default is 25 MB - "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB": - lambda: int(os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25")), - + "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB": lambda: int( + os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25") + ), # Backend for Video IO # - "opencv": Default backend that uses OpenCV stream buffered backend. # @@ -585,258 +720,250 @@ environment_variables: dict[str, Callable[[], Any]] = { # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. 
- "VLLM_VIDEO_LOADER_BACKEND": - lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - + "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv( + "VLLM_VIDEO_LOADER_BACKEND", "opencv" + ), # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache # Default is 4 GiB per API process + 4 GiB per engine core process - "VLLM_MM_INPUT_CACHE_GIB": - lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), - + "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. - "VLLM_XLA_CACHE_PATH": - lambda: os.path.expanduser( + "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( os.getenv( "VLLM_XLA_CACHE_PATH", os.path.join(get_default_cache_root(), "vllm", "xla_cache"), - )), - + ) + ), # If set, assert on XLA recompilation after each execution step. - "VLLM_XLA_CHECK_RECOMPILATION": - lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), - + "VLLM_XLA_CHECK_RECOMPILATION": lambda: bool( + int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0")) + ), # Enable SPMD mode for TPU backend. - "VLLM_XLA_USE_SPMD": - lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), - "VLLM_FUSED_MOE_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768") + ), # Control whether to use fused MoE activation chunking. Current chunking # logic is incompatible with torch.compile and causes IMA. See issue # https://github.com/vllm-project/vllm/issues/19631. - "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": - lambda: bool( - int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))), - + "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool( + int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1")) + ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests - "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": - lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)), - + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( + os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0) + ), # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows # the user to specify a max sequence length greater than # the max length derived from the model's config.json. # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": - lambda: - (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in - ("1", "true")), - + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": lambda: ( + os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() + in ("1", "true") + ), # If set, forces FP8 Marlin to be used for FP8 quantization regardless # of the hardware support for FP8 compute. 
- "VLLM_TEST_FORCE_FP8_MARLIN": - lambda: - (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in - ("1", "true")), - "VLLM_TEST_FORCE_LOAD_FORMAT": - lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), - + "VLLM_TEST_FORCE_FP8_MARLIN": lambda: ( + os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() + in ("1", "true") + ), + "VLLM_TEST_FORCE_LOAD_FORMAT": lambda: os.getenv( + "VLLM_TEST_FORCE_LOAD_FORMAT", "dummy" + ), # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_TIMEOUT": - lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), - + "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # Timeout in seconds for keeping HTTP connections alive in API server - "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": - lambda: int(os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")), - + "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int( + os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5") + ), # a list of plugin names to load, separated by commas. # if this is not set, it means all plugins will be loaded # if this is set to an empty string, no plugins will be loaded - "VLLM_PLUGINS": - lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ - "VLLM_PLUGINS"].split(","), - + "VLLM_PLUGINS": lambda: None + if "VLLM_PLUGINS" not in os.environ + else os.environ["VLLM_PLUGINS"].split(","), # a local directory to look in for unrecognized LoRA adapters. # only works if plugins are enabled and # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. - "VLLM_LORA_RESOLVER_CACHE_DIR": - lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), - + "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( + "VLLM_LORA_RESOLVER_CACHE_DIR", None + ), # Enables torch profiler if set. # Both AsyncLLM's CPU traces as well as workers' # traces (CPU & GPU) will be saved under this directory. # Note that it must be an absolute path. - "VLLM_TORCH_PROFILER_DIR": - lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os - .path.abspath(os.path.expanduser(os.getenv( - "VLLM_TORCH_PROFILER_DIR", ".")))), - + "VLLM_TORCH_PROFILER_DIR": lambda: ( + None + if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None + else os.path.abspath( + os.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", ".")) + ) + ), # Enable torch profiler to record shapes if set # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will # not record shapes. - "VLLM_TORCH_PROFILER_RECORD_SHAPES": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"), - + "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" + ), # Enable torch profiler to profile memory if set # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler # will not profile memory. - "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": - lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"), - + "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" + ), # Enable torch profiler to profile stack if set # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL # profile stack by default. 
- "VLLM_TORCH_PROFILER_WITH_STACK": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"), - + "VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0" + ), # Enable torch profiler to profile flops if set # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will # not profile flops. - "VLLM_TORCH_PROFILER_WITH_FLOPS": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"), - + "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0" + ), # If set, vLLM will use Triton implementations of AWQ. - "VLLM_USE_TRITON_AWQ": - lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), - + "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), # If set, allow loading or unloading lora adapters in runtime, - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": - lambda: - (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in - ("1", "true")), - + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": lambda: ( + os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() + in ("1", "true") + ), # We assume drivers can report p2p status correctly. # If the program hangs when using custom allreduce, # potantially caused by a bug in the driver (535 series), # if might be helpful to set VLLM_SKIP_P2P_CHECK=0 # so that vLLM can verify if p2p is actually working. # See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa - "VLLM_SKIP_P2P_CHECK": - lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "1") == "1", - + "VLLM_SKIP_P2P_CHECK": lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "1") == "1", # List of quantization kernels that should be disabled, used for testing # and performance comparisons. Currently only affects MPLinearKernel # selection # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) - "VLLM_DISABLED_KERNELS": - lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ - "VLLM_DISABLED_KERNELS"].split(","), - + "VLLM_DISABLED_KERNELS": lambda: [] + if "VLLM_DISABLED_KERNELS" not in os.environ + else os.environ["VLLM_DISABLED_KERNELS"].split(","), + # Disable pynccl (using torch.distributed instead) + "VLLM_DISABLE_PYNCCL": lambda: ( + os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") + ), # If set, use the V1 code path. - "VLLM_USE_V1": - lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), - + "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. - "VLLM_ROCM_USE_AITER": - lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1") + ), # Whether to use aiter paged attention. # By default is disabled. 
- "VLLM_ROCM_USE_AITER_PAGED_ATTN": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1") + ), # use aiter linear op if aiter ops are enabled # The following list of related ops # - scaled_mm (per-tensor / rowwise) - "VLLM_ROCM_USE_AITER_LINEAR": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") + ), # Whether to use aiter moe ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MOE": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MOE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1") + ), # use aiter rms norm op if aiter ops are enabled. - "VLLM_ROCM_USE_AITER_RMSNORM": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_RMSNORM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1") + ), # Whether to use aiter mla ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MLA": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MLA": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1") + ), # Whether to use aiter mha ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MHA": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MHA": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1") + ), + # Whether to use aiter fp4 gemm asm. + # By default is disabled. + "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") + ), + # Whether to use aiter rope. + # By default is disabled. + "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + ), + # Whether to use aiter triton fp8 bmm kernel + # By default is enabled. 
+ "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") + ), + # Use AITER triton unified attention for V1 attention + "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() + in ("true", "1") + ), # use rocm skinny gemms - "VLLM_ROCM_USE_SKINNY_GEMM": - lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( + os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") + ), # Pad the fp8 weights to 256 bytes for ROCm - "VLLM_ROCM_FP8_PADDING": - lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), - + "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Pad the weights for the moe kernel - "VLLM_ROCM_MOE_PADDING": - lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), - + "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), # custom paged attention kernel for MI3* cards - "VLLM_ROCM_CUSTOM_PAGED_ATTN": - lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce - "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": - lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(), - + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": env_with_choices( + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", + "NONE", + ["FP", "INT8", "INT6", "INT4", "NONE"], + ), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 # kernels are slower than fp16, # If environment variable is set to 1, the input is converted to fp16 - "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": - lambda: - (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": lambda: ( + os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() + in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards. # Controls the maximum allowed number of data bytes(MB) for custom quick # allreduce communication. # Default: 2048 MB. # Data exceeding this size will use either custom allreduce or RCCL # communication. - "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": - lambda: maybe_convert_int( - os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), - + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": lambda: maybe_convert_int( + os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None) + ), # Divisor for dynamic query scale factor calculation for FP8 KV Cache - "Q_SCALE_CONSTANT": - lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), + "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), # Divisor for dynamic key scale factor calculation for FP8 KV Cache - "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), # Divisor for dynamic value scale factor calculation for FP8 KV Cache - "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), - + "V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), # If set, enable multiprocessing in LLM for the V1 code path. 
- "VLLM_ENABLE_V1_MULTIPROCESSING": - lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), - "VLLM_LOG_BATCHSIZE_INTERVAL": - lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), - "VLLM_DISABLE_COMPILE_CACHE": - lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), - + "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool( + int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")) + ), + "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float( + os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1") + ), + "VLLM_DISABLE_COMPILE_CACHE": lambda: bool( + int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")) + ), # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, # e.g. `/reset_prefix_cache` - "VLLM_SERVER_DEV_MODE": - lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), - + "VLLM_SERVER_DEV_MODE": lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), # Controls the maximum number of requests to handle in a # single asyncio task when processing per-token outputs in the # V1 AsyncLLM interface. It is applicable when handling a high @@ -844,149 +971,163 @@ environment_variables: dict[str, Callable[[], Any]] = { # Setting this too high can result in a higher variance of # inter-message latencies. Setting it too low can negatively impact # TTFT and overall throughput. - "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")), - + "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128") + ), # If set, vLLM will disable the MLA attention optimizations. - "VLLM_MLA_DISABLE": - lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), - + "VLLM_MLA_DISABLE": lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), + # If set, vLLM will pick up the provided Flash Attention MLA + # max number splits for cuda graph decode + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": lambda: int( + os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "32") + ), # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. - "VLLM_RAY_PER_WORKER_GPUS": - lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), - + "VLLM_RAY_PER_WORKER_GPUS": lambda: float( + os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0") + ), # Bundle indices for Ray, if it is set, it can control precisely # which indices are used for the Ray bundle, for every worker. # Format: comma-separated list of integers, e.g. "0,1,2,3" - "VLLM_RAY_BUNDLE_INDICES": - lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), - + "VLLM_RAY_BUNDLE_INDICES": lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), # In some system, find_loaded_library() may not work. So we allow users to # specify the path through environment variable VLLM_CUDART_SO_PATH. - "VLLM_CUDART_SO_PATH": - lambda: os.getenv("VLLM_CUDART_SO_PATH", None), - + "VLLM_CUDART_SO_PATH": lambda: os.getenv("VLLM_CUDART_SO_PATH", None), # Rank of the process in the data parallel setting - "VLLM_DP_RANK": - lambda: int(os.getenv("VLLM_DP_RANK", "0")), - + "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), # Rank of the process in the data parallel setting. # Defaults to VLLM_DP_RANK when not set. 
- "VLLM_DP_RANK_LOCAL": - lambda: int( - os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)), - + "VLLM_DP_RANK_LOCAL": lambda: int( + os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK) + ), # World size of the data parallel setting - "VLLM_DP_SIZE": - lambda: int(os.getenv("VLLM_DP_SIZE", "1")), - + "VLLM_DP_SIZE": lambda: int(os.getenv("VLLM_DP_SIZE", "1")), # IP address of the master node in the data parallel setting - "VLLM_DP_MASTER_IP": - lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), - + "VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), # Port of the master node in the data parallel setting - "VLLM_DP_MASTER_PORT": - lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), - + "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), # In the context of executing MoE models with Data-Parallel, Expert-Parallel # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE # dictates the quantum of tokens that can be dispatched from a DP # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE # units. - "VLLM_MOE_DP_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), - + "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), # Randomize inputs during dummy runs when using Data Parallel - "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": - lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1", - + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" + ) + == "1", + # Strategy to pack the data parallel ranks for Ray. + # Available options: + # - "fill": + # for DP master node, allocate exactly data-parallel-size-local DP ranks, + # for non-master nodes, allocate as many DP ranks as can fit; + # - "strict": + # allocate exactly data-parallel-size-local DP ranks to each picked node; + # This environment variable is ignored if data-parallel-backend is not Ray. + "VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv( + "VLLM_RAY_DP_PACK_STRATEGY", "strict" + ), # Whether to use S3 path for model loading in CI via RunAI Streamer - "VLLM_CI_USE_S3": - lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", - + "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", # Use model_redirect to redirect the model name to a local folder. # `model_redirect` can be a json file mapping the model between # repo_id and local folder: # {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"} # or a space separated values table file: # meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B - "VLLM_MODEL_REDIRECT_PATH": - lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), - + "VLLM_MODEL_REDIRECT_PATH": lambda: os.environ.get( + "VLLM_MODEL_REDIRECT_PATH", None + ), # Whether to use atomicAdd reduce in gptq/awq marlin kernel. - "VLLM_MARLIN_USE_ATOMIC_ADD": - lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", - + "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get( + "VLLM_MARLIN_USE_ATOMIC_ADD", "0" + ) + == "1", # Whether to use marlin kernel in mxfp4 quantization method - "VLLM_MXFP4_USE_MARLIN": - lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), - + "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( + os.environ.get("VLLM_MXFP4_USE_MARLIN", None) + ), # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. 
- "VLLM_V0_USE_OUTLINES_CACHE": - lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", - + "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get( + "VLLM_V0_USE_OUTLINES_CACHE", "0" + ) + == "1", # Whether to turn on the outlines cache for V1 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. - "VLLM_V1_USE_OUTLINES_CACHE": - lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1", - + "VLLM_V1_USE_OUTLINES_CACHE": lambda: os.environ.get( + "VLLM_V1_USE_OUTLINES_CACHE", "0" + ) + == "1", # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. - "VLLM_TPU_BUCKET_PADDING_GAP": - lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) - if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, - "VLLM_TPU_MOST_MODEL_LEN": - lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)), - + "VLLM_TPU_BUCKET_PADDING_GAP": lambda: int( + os.environ["VLLM_TPU_BUCKET_PADDING_GAP"] + ) + if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ + else 0, + "VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int( + os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None) + ), # Whether using Pathways - "VLLM_TPU_USING_PATHWAYS": - lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()), - + "VLLM_TPU_USING_PATHWAYS": lambda: bool( + "proxy" in os.getenv("JAX_PLATFORMS", "").lower() + ), # Allow use of DeepGemm kernels for fused moe ops. - "VLLM_USE_DEEP_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), - + "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - # E8M0 is faster on B200 but may reduce accuracy. - "VLLM_USE_DEEP_GEMM_E8M0": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( + int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine # startup time by a couple of minutes. # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup. - "VLLM_SKIP_DEEP_GEMM_WARMUP": - lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), - + "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool( + int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")) + ), + # Whether to use fused grouped_topk used for MoE expert selection. + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( + int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) + ), # Allow use of FlashInfer MoE kernels for fused moe ops. - "VLLM_USE_FLASHINFER_MOE_FP8": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), - + "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) + ), + # Allow use of FlashInfer MoE kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")) + ), # Allow use of FlashInfer CUTLASS kernels for fused moe ops. - "VLLM_USE_FLASHINFER_MOE_FP4": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))), - + "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0")) + ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. 
- "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), - + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0")) + ), + # If set to 1, use the FlashInfer CUTLASS backend for + # MXFP8 (activation) x MXFP4 (weight) MoE. + # This is separate from the TRTLLMGEN path controlled by + # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8. + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")) + ), # If set to 1, use the FlashInfer # BF16 (activation) x MXFP4 (weight) MoE backend. - "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))), - + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0")) + ), # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. - "VLLM_XGRAMMAR_CACHE_MB": - lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), - + "VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), # Control the threshold for msgspec to use 'zero copy' for # serialization/deserialization of tensors. Tensors below # this limit will be encoded into the msgpack buffer, and @@ -994,82 +1135,97 @@ environment_variables: dict[str, Callable[[], Any]] = { # While the sending side still actually copies the tensor # in all cases, on the receiving side, tensors above this # limit will actually be zero-copy decoded. - "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": - lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), - + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int( + os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256") + ), # If set, allow insecure serialization using pickle. # This is useful for environments where it is deemed safe to use the # insecure method and it is needed for some reason. - "VLLM_ALLOW_INSECURE_SERIALIZATION": - lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), - + "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool( + int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")) + ), # IP address used for NIXL handshake between remote agents. - "VLLM_NIXL_SIDE_CHANNEL_HOST": - lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), - + "VLLM_NIXL_SIDE_CHANNEL_HOST": lambda: os.getenv( + "VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost" + ), # Port used for NIXL handshake between remote agents. - "VLLM_NIXL_SIDE_CHANNEL_PORT": - lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), - + "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( + os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") + ), # all2all backend for vllm's expert parallel communication # Available options: - # - "naive": naive all2all implementation using all-reduce + # - "naive": naive all2all implementation using broadcasts + # - "allgather_reducescatter": all2all implementation based on allgather and + # reducescatter # - "pplx": use pplx kernels # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels - "VLLM_ALL2ALL_BACKEND": - lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), - - # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both - # require compute capability 10.0 or above. 
+ # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl + "VLLM_ALL2ALL_BACKEND": env_with_choices( + "VLLM_ALL2ALL_BACKEND", + "allgather_reducescatter", + [ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ], + ), + # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. + # Both require compute capability 10.0 or above. # Available options: # - "throughput": [default] # Uses CUTLASS kernels optimized for high-throughput batch inference. # - "latency": # Uses TensorRT-LLM kernels optimized for low-latency inference. - # To set this backend, define the environment variable: - # export VLLM_FLASHINFER_MOE_BACKEND=latency. - # If not set, defaults to "throughput". - "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv( - "VLLM_FLASHINFER_MOE_BACKEND", "throughput" + "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( + "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"] ), - # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. # This is used to prevent the kernel from running out of memory. - "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": - lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), - + "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int( + os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840") + ), + # Specifies the thresholds of the communicated tensor sizes under which + # vllm should use flashinfer fused allreduce. The variable should be a + # JSON with the following format: + # { <world size>: <max size in mb> } + # Unspecified world sizes will fall back to + # { 2: 64, 4: 1, <everything else>: 0.5 } + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": lambda: json.loads( + os.getenv("VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}") + ), # MoE routing strategy selector. # See `RoutingSimulator.get_available_strategies()` # for available # strategies. # Cutstom routing strategies can be registered by # RoutingSimulator.register_strategy() # Note: custom strategies may not produce correct model outputs - "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": - lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(), - + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": lambda: os.environ.get( + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "" + ).lower(), # Regex timeout for use by the vLLM tool parsing plugins. - "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": - lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), - + "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( + os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") + ), # Reduce CPU usage when vLLM is idle. Enabling this will incur small # latency penalty when a request eventually comes. - "VLLM_SLEEP_WHEN_IDLE": - lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), - + "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. - "VLLM_MQ_MAX_CHUNK_BYTES_MB": - lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")), - + "VLLM_MQ_MAX_CHUNK_BYTES_MB": lambda: int( + os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16") + ), # Timeout in seconds for execute_model RPC calls in multiprocessing # executor (only applies when TP > 1). 
- "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": - lambda: int(os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "300")), - + "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": lambda: int( + os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "300") + ), # KV Cache layout used throughout vllm. # Some common values are: # - NHD @@ -1077,63 +1233,71 @@ environment_variables: dict[str, Callable[[], Any]] = { # Where N=num_blocks, H=num_heads and D=head_size. The default value will # leave the layout choice to the backend. Mind that backends may only # implement and support a subset of all possible layouts. - "VLLM_KV_CACHE_LAYOUT": - lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None), - + "VLLM_KV_CACHE_LAYOUT": env_with_choices( + "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"] + ), # Enable checking whether the generated logits contain NaNs, # indicating corrupted output. Useful for debugging low level bugs # or bad hardware but it may add compute overhead. - "VLLM_COMPUTE_NANS_IN_LOGITS": - lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))), - + "VLLM_COMPUTE_NANS_IN_LOGITS": lambda: bool( + int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0")) + ), # Controls whether or not emulations are used for NVFP4 # generations on machines < 100 for compressed-tensors # models - "VLLM_USE_NVFP4_CT_EMULATIONS": - lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), - + "VLLM_USE_NVFP4_CT_EMULATIONS": lambda: bool( + int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")) + ), # Time (in seconds) after which the KV cache on the producer side is # automatically cleared if no READ notification is received from the # consumer. This is only applicable when using NixlConnector in a # disaggregated decode-prefill setup. - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": - lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")), - + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( + os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") + ), # Controls whether or not to use cudnn prefill - "VLLM_USE_CUDNN_PREFILL": - lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - - # If set to 1, use the TRTLLM attention backend in flashinfer. - "VLLM_USE_TRTLLM_ATTENTION": - lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), - + "VLLM_USE_CUDNN_PREFILL": lambda: bool( + int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) + ), + # If set to 1/True, use the TRTLLM attention backend in flashinfer. + # If set to 0/False, use the default attention backend in flashinfer. + # If not set, auto-detect the attention backend in flashinfer. + "VLLM_USE_TRTLLM_ATTENTION": lambda: ( + None + if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ + else os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true") + ), + # If set to 1, when we use fp8 kv, we do not quantize Q to fp8 + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": lambda: bool( + int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")) + ), # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. - "VLLM_HAS_FLASHINFER_CUBIN": - lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), - + "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # Otherwise, uses the first available of: flashinfer cutlass GEMM, # vllm cutlass GEMM, marlin GEMM. 
- "VLLM_USE_TRTLLM_FP4_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))), - + "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool( + int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0")) + ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. - "VLLM_ENABLE_CUDAGRAPH_GC": - lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), - + "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool( + int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")) + ), + # Disable padding to CUDA graph capture batch sizes. + # TODO(wentao): https://github.com/vllm-project/vllm/issues/23378 + # After the issue is fixed, we can remove this flag. + "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": lambda: bool( + int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0")) + ), # Used to force set up loopback IP - "VLLM_LOOPBACK_IP": - lambda: os.getenv("VLLM_LOOPBACK_IP", ""), - + "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. # This is useful for debugging and monitoring purposes. # The default value is "VLLM". - "VLLM_PROCESS_NAME_PREFIX": - lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), - + "VLLM_PROCESS_NAME_PREFIX": lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), # Allow chunked local attention with hybrid kv cache manager. # Currently using the Hybrid KV cache manager with chunked local attention # in the Llama4 models (the only models currently using chunked local attn) @@ -1141,10 +1305,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # This flag is used to allow users to enable it if they want to (to save on # kv-cache memory usage and enable longer contexts) # TODO(lucas): Remove this flag once latency regression is resolved. - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": - lambda: bool(int(os.getenv(\ - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))), - + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( + int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) + ), # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output # messages for those requests in memory. By default, this is disabled (0), @@ -1154,17 +1317,80 @@ environment_variables: dict[str, Callable[[], Any]] = { # lost when the vLLM server shuts down. # 2. Enabling this option will cause a memory leak, as stored messages are # never removed from memory until the server terminates. - "VLLM_ENABLE_RESPONSES_API_STORE": - lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), - + "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool( + int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0")) + ), + # If set, use the fp8 mfma in rocm paged attention. 
+ "VLLM_ROCM_FP8_MFMA_PAGE_ATTN": lambda: bool( + int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0")) + ), # Whether to use pytorch symmetric memory for allreduce - "VLLM_ALLREDUCE_USE_SYMM_MEM": - lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), - + "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( + int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) + ), # Allows vllm to find tuned config under customized folder - "VLLM_TUNED_CONFIG_FOLDER": - lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), - + "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), + # Allows harmony instructions to be injected on system messages + "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( + int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0")) + ), + # Add optional custom scopes for profiling, disable to avoid overheads + "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool( + int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0")) + ), + # Add optional nvtx scopes for profiling, disable to avoid overheads + "VLLM_NVTX_SCOPES_FOR_PROFILING": lambda: bool( + int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0")) + ), + # Represent block hashes in KV cache events as 64-bit integers instead of + # raw bytes. Defaults to True for backward compatibility. + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": lambda: bool( + int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1")) + ), + # Name of the shared memory buffer used for object storage. + # Only effective when mm_config.mm_processor_cache_type == "shm". + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv( + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER" + ), + # The size in MB of the buffers (NVL and RDMA) used by DeepEP + "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( + os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024") + ), + # The number of SMs to allocate for communication kernels when running DBO + # the rest of the SMs on the device will be allocated to compute + "VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), + # Valid values are container,code_interpreter,web_search_preview + # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_list_with_choices( + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS", + [], + ["container", "code_interpreter", "web_search_preview"], + ), + # Enable max_autotune & coordinate_descent_tuning in inductor_config + # to compile static shapes passed from compile_sizes in compilation_config + # If set to 1, enable max_autotune; By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE": lambda: bool( + int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1")) + ), + # If set to 1, enable coordinate_descent_tuning; + # By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING": lambda: bool( + int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", "1")) + ), + # Flag to enable NCCL symmetric memory allocation and registration + "VLLM_USE_NCCL_SYMM_MEM": lambda: bool( + int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0")) + ), + # NCCL header path + "VLLM_NCCL_INCLUDE_PATH": lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None), + # Flag to enable FBGemm kernels on model execution + "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))), + # GC debug config + # - VLLM_GC_DEBUG=0: disable GC debugger + # - VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times + # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with + # top 5 collected 
objects + "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), } # --8<-- [end:env-vars-definition] @@ -1193,7 +1419,8 @@ def set_vllm_use_v1(use_v1: bool): raise ValueError( "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " "explicitly by the user. Please raise this as a Github " - "Issue and explicitly set VLLM_USE_V1=0 or 1.") + "Issue and explicitly set VLLM_USE_V1=0 or 1." + ) os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" @@ -1215,6 +1442,7 @@ def compute_hash() -> str: environment_variables_to_hash = [ "VLLM_PP_LAYER_PARTITION", "VLLM_MLA_DISABLE", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "VLLM_USE_TRITON_FLASH_ATTN", "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", @@ -1223,18 +1451,22 @@ def compute_hash() -> str: "VLLM_FUSED_MOE_CHUNK_SIZE", "VLLM_FLASHINFER_MOE_BACKEND", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION", - "VLLM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ATTENTION_BACKEND", "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", "VLLM_USE_TRTLLM_FP4_GEMM", + "VLLM_USE_FUSED_MOE_GROUPED_TOPK", + "VLLM_USE_FLASHINFER_MOE_FP16", "VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "VLLM_USE_CUDNN_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "VLLM_ROCM_USE_AITER", "VLLM_ROCM_USE_AITER_PAGED_ATTN", "VLLM_ROCM_USE_AITER_LINEAR", @@ -1242,6 +1474,10 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_RMSNORM", "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", + "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", + "VLLM_ROCM_USE_TRITON_ROPE", + "VLLM_ROCM_USE_AITER_FP8BMM", + "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_SKINNY_GEMM", "VLLM_ROCM_FP8_PADDING", "VLLM_ROCM_MOE_PADDING", @@ -1249,18 +1485,20 @@ def compute_hash() -> str: "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", + "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", + "VLLM_USE_FBGEMM", ] for key in environment_variables_to_hash: # if this goes out of sync with environment_variables, # it's not a user error, it's a bug - assert key in environment_variables, \ + assert key in environment_variables, ( "Please update environment_variables_to_hash in envs.py" + ) - factors = [ - environment_variables[key]() for key in environment_variables_to_hash - ] + factors = [environment_variables[key]() for key in environment_variables_to_hash] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 813232cd19281..7bdef5cbe748c 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,22 +4,22 @@ import asyncio import time from abc import ABC, abstractmethod +from collections.abc import Awaitable from functools import cached_property -from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, - Union) +from typing import Any, Callable, Optional, Union -import torch.nn as nn from typing_extensions import TypeVar import vllm.platforms from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from 
vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.sequence import ExecuteModelRequest from vllm.tasks import SupportedTask from vllm.utils import make_async -from vllm.worker.worker_base import WorkerBase +from vllm.v1.outputs import PoolerOutput, SamplerOutput +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -30,7 +30,7 @@ class ExecutorBase(ABC): """Base class for all executors. An executor is responsible for executing the model on one device, - or it can be a distributed executor + or it can be a distributed executor that can execute the model on multiple devices. """ @@ -54,17 +54,20 @@ class ExecutorBase(ABC): self._init_executor() self.is_sleeping = False self.sleeping_tags: set[str] = set() + self.kv_output_aggregator = None @abstractmethod def _init_executor(self) -> None: raise NotImplementedError @abstractmethod - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None) -> List[_R]: + def collective_rpc( + self, + method: Union[str, Callable[[WorkerBase], _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: """ Execute an RPC call on all workers. @@ -82,14 +85,14 @@ class ExecutorBase(ABC): Returns: A list containing the results from each worker. - + Note: It is recommended to use this API to only pass control messages, and set up data-plane communication to pass data. """ raise NotImplementedError - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. @@ -97,9 +100,10 @@ class ExecutorBase(ABC): ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where + `num_gpu_blocks` are blocks that are "active" on the device and can be + appended to. + `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be appended to. """ results = self.collective_rpc("determine_num_available_blocks") @@ -108,33 +112,29 @@ class ExecutorBase(ABC): return a, b def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ + """Initialize the KV cache by invoking the underlying worker.""" # NOTE: This is logged in the executor because there can be >1 workers. 
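The collective_rpc() contract documented above (dispatch one method to every worker, return the per-worker results as a list) can be pictured with a toy in-process executor; this is only a sketch, not vLLM's multiprocess or Ray implementation:

from typing import Any, Callable, Optional, Union

class ToyWorker:
    def __init__(self, rank: int) -> None:
        self.rank = rank

    def add_lora(self, lora_id: int) -> bool:
        # Pretend the adapter loads successfully for any positive id.
        return lora_id > 0

class ToyExecutor:
    def __init__(self, num_workers: int) -> None:
        self.workers = [ToyWorker(r) for r in range(num_workers)]

    def collective_rpc(
        self,
        method: Union[str, Callable[..., Any]],
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> list[Any]:
        kwargs = kwargs or {}
        results: list[Any] = []
        for worker in self.workers:
            if isinstance(method, str):
                results.append(getattr(worker, method)(*args, **kwargs))
            else:
                # Callables receive the worker itself, mirroring the
                # Callable[[WorkerBase], _R] signature above.
                results.append(method(worker, *args, **kwargs))
        return results

if __name__ == "__main__":
    executor = ToyExecutor(num_workers=2)
    # Aggregate per-worker results the way add_lora() does above.
    print(all(executor.collective_rpc("add_lora", args=(7,))))  # True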
- logger.info("# %s blocks: %d, # CPU blocks: %d", - vllm.platforms.current_platform.device_name, - num_gpu_blocks, num_cpu_blocks) - max_concurrency = (num_gpu_blocks * self.cache_config.block_size / - self.model_config.max_model_len) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, max_concurrency) + logger.info( + "# %s blocks: %d, # CPU blocks: %d", + vllm.platforms.current_platform.device_name, + num_gpu_blocks, + num_cpu_blocks, + ) + max_concurrency = ( + num_gpu_blocks + * self.cache_config.block_size + / self.model_config.max_model_len + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, + max_concurrency, + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self.collective_rpc("initialize_cache", - args=(num_gpu_blocks, num_cpu_blocks)) - - def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - """ - Run a function directly on the model inside each worker, - returning the result for each of them. - """ - - def rpc_func(worker: WorkerBase) -> _R: - return func(worker.get_model()) - - return self.collective_rpc(rpc_func) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) @cached_property # Avoid unnecessary RPC calls def supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -143,9 +143,8 @@ class ExecutorBase(ABC): def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.collective_rpc("execute_model", - args=(execute_model_req, )) + ) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]: + output = self.collective_rpc("execute_model", args=(execute_model_req,)) return output[0] def stop_remote_worker_execution_loop(self) -> None: @@ -154,22 +153,26 @@ class ExecutorBase(ABC): def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("add_lora", args=(lora_request, ))) + return all(self.collective_rpc("add_lora", args=(lora_request,))) def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("remove_lora", args=(lora_id, ))) + return all(self.collective_rpc("remove_lora", args=(lora_id,))) def pin_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("pin_lora", args=(lora_id, ))) + return all(self.collective_rpc("pin_lora", args=(lora_id,))) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: sets = self.collective_rpc("list_loras") for s in sets: assert s == sets[0], "All workers should have the same LORAs." 
return sets[0] + def reset_mm_cache(self) -> None: + """Reset the multi-modal cache in each worker.""" + self.collective_rpc("reset_mm_cache") + def start_profile(self) -> None: self.collective_rpc("start_profile") @@ -185,8 +188,9 @@ class ExecutorBase(ABC): time_after_sleep = time.perf_counter() self.sleeping_tags = {"weights", "kv_cache"} self.is_sleeping = True - logger.info("It took %.6f seconds to fall asleep.", - time_after_sleep - time_before_sleep) + logger.info( + "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep + ) def wake_up(self, tags: Optional[list[str]] = None): if not self.is_sleeping: @@ -195,15 +199,18 @@ class ExecutorBase(ABC): if tags: for tag in tags: if tag not in self.sleeping_tags: - logger.warning("Tag %s is not in sleeping tags %s", tag, - self.sleeping_tags) + logger.warning( + "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags + ) return time_before_wakeup = time.perf_counter() self.collective_rpc("wake_up", kwargs=dict(tags=tags)) time_after_wakeup = time.perf_counter() - logger.info("It took %.6f seconds to wake up tags %s.", - time_after_wakeup - time_before_wakeup, - tags if tags is not None else self.sleeping_tags) + logger.info( + "It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags, + ) if tags: for tag in tags: self.sleeping_tags.remove(tag) @@ -218,10 +225,10 @@ class ExecutorBase(ABC): pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.collective_rpc("save_sharded_state", - kwargs=dict(path=path, - pattern=pattern, - max_size=max_size)) + self.collective_rpc( + "save_sharded_state", + kwargs=dict(path=path, pattern=pattern, max_size=max_size), + ) @abstractmethod def check_health(self) -> None: @@ -231,14 +238,11 @@ class ExecutorBase(ABC): def shutdown(self) -> None: """Shutdown the executor.""" - return - - def __del__(self): - self.shutdown() + self.collective_rpc("shutdown") async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: """Executes one model step on the given sequences.""" output = await make_async(self.execute_model)(execute_model_req) return output @@ -252,6 +256,12 @@ class ExecutorBase(ABC): exception.""" self.check_health() + def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None: + """Init KVOutputAggregator""" + self.kv_output_aggregator = KVOutputAggregator( + finished_count or self.parallel_config.world_size + ) + class DistributedExecutorBase(ExecutorBase): """Abstract superclass of distributed executor implementations.""" @@ -266,12 +276,13 @@ class DistributedExecutorBase(ExecutorBase): def execute_model( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: # TODO: unify into collective_rpc if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", - async_run_tensor_parallel_workers_only=True) + async_run_tensor_parallel_workers_only=True, + ) # Only the driver worker returns the sampling results. 
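For the "maximum concurrency" log line above, the figure is simply num_gpu_blocks * block_size / max_model_len, i.e. how many requests fit if each one holds a full max_model_len of KV cache. A worked example with assumed sizes:

def max_concurrency(num_gpu_blocks: int, block_size: int, max_model_len: int) -> float:
    # Each request can occupy at most max_model_len tokens of KV cache.
    return num_gpu_blocks * block_size / max_model_len

if __name__ == "__main__":
    # e.g. 8192 blocks of 16 tokens serving 4096-token requests -> 32x
    print(f"{max_concurrency(8192, 16, 4096):.2f}x")  # 32.00x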
driver_outputs = self._driver_execute_model(execute_model_req) @@ -292,7 +303,7 @@ class DistributedExecutorBase(ExecutorBase): @abstractmethod def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop @@ -301,11 +312,13 @@ class DistributedExecutorBase(ExecutorBase): """ raise NotImplementedError - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[Any]: return self._run_workers(method, *args, **(kwargs or {})) @abstractmethod @@ -324,7 +337,7 @@ class DistributedExecutorBase(ExecutorBase): run only in the remote TP workers, not the driver worker. It will also be run asynchronously and return a list of futures rather than blocking on the results. - + # TODO: simplify and merge with collective_rpc """ raise NotImplementedError @@ -336,12 +349,13 @@ class DistributedExecutorBase(ExecutorBase): raise NotImplementedError async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if self.parallel_worker_tasks is None: # Start model execution loop running in the parallel workers self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop()) + self._start_worker_execution_loop() + ) # Only the driver worker returns the sampling results. return await self._driver_execute_model_async(execute_model_req) @@ -361,7 +375,7 @@ class DistributedExecutorBase(ExecutorBase): async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: """Execute the model asynchronously in the driver worker. Passing None will cause the driver to stop the model execution diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py deleted file mode 100644 index 4e8c6d79095f9..0000000000000 --- a/vllm/executor/mp_distributed_executor.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -from typing import Any, Callable, List, Optional, Union - -import cloudpickle - -from vllm.executor.executor_base import DistributedExecutorBase -from vllm.executor.multiproc_worker_utils import ( - ProcessWorkerWrapper, ResultHandler, WorkerMonitor, - set_multiprocessing_worker_envs) -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - get_distributed_init_method, get_ip, get_open_port, - make_async, run_method, update_environment_variables) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class MultiprocessingDistributedExecutor(DistributedExecutorBase): - """Python multiprocessing-based distributed executor""" - - uses_ray: bool = False - - def _check_cuda(self) -> None: - """Check that the number of GPUs is sufficient for the parallel - configuration. 
Separate from _init_executor to reduce the number of - indented blocks. - """ - parallel_config = self.parallel_config - world_size = parallel_config.world_size - tensor_parallel_size = parallel_config.tensor_parallel_size - - cuda_device_count = cuda_device_count_stateless() - # Use confusing message for more common TP-only case. - if tensor_parallel_size > cuda_device_count: - raise RuntimeError( - f"please set tensor_parallel_size ({tensor_parallel_size}) " - f"to less than max local gpu count ({cuda_device_count})") - - if world_size > cuda_device_count: - raise RuntimeError( - f"please ensure that world_size ({world_size}) " - f"is less than than max local gpu count ({cuda_device_count})") - - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - update_environment_variables({ - "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) - }) - - def _init_executor(self) -> None: - - from vllm.platforms import current_platform - if current_platform.is_cuda_alike(): - self._check_cuda() - - # Create the parallel GPU workers. - world_size = self.parallel_config.world_size - tensor_parallel_size = self.parallel_config.tensor_parallel_size - - # Set multiprocessing envs that are common to V0 and V1 - set_multiprocessing_worker_envs(self.parallel_config) - - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - self.workers: List[ProcessWorkerWrapper] = [] - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: List[ProcessWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: List[ProcessWorkerWrapper] = [] - - if world_size == 1: - self.worker_monitor = None - else: - result_handler = ResultHandler() - for rank in range(1, world_size): - worker = ProcessWorkerWrapper(result_handler, - WorkerWrapperBase, - self.vllm_config, rank) - self.workers.append(worker) - if rank % tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) - else: - self.non_driver_workers.append(worker) - - self.worker_monitor = WorkerMonitor(self.workers, result_handler) - result_handler.start() - self.worker_monitor.start() - - # Set up signal handlers to shutdown the executor cleanly - # sometimes gc does not work well - - self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) - - all_kwargs = [] - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - for i in range(world_size): - local_rank = i - rank = i - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=(not self.parallel_config) - or (rank % self.parallel_config.tensor_parallel_size == 0), - ) - all_kwargs.append(kwargs) - self._run_workers("init_worker", all_kwargs) - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. 
- max_parallel_loading_workers) - self.driver_exec_model = make_async(self.driver_worker.execute_model) - self.pp_locks: Optional[List[asyncio.Lock]] = None - - def shutdown(self): - if (worker_monitor := getattr(self, "worker_monitor", - None)) is not None: - worker_monitor.close() - - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_model(execute_model_req) - - def _run_workers( - self, - method: Union[str, Callable], - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> List[Any]: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - """ - if isinstance(method, str): - sent_method = method - else: - sent_method = cloudpickle.dumps(method) - del method - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - if async_run_tensor_parallel_workers_only: - # Run only non-driver workers and just return futures. - return [ - worker.execute_method(sent_method, *args, **kwargs) - for worker in self.non_driver_workers - ] - - # Start all remote workers first. - worker_outputs = [ - worker.execute_method(sent_method, *args, **kwargs) - for worker in self.workers - ] - - driver_worker_output = run_method(self.driver_worker, sent_method, - args, kwargs) - - # Get the results of the workers. - return [driver_worker_output - ] + [output.get() for output in worker_outputs] - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - if self.worker_monitor is not None and not self.worker_monitor.is_alive( - ): - raise RuntimeError("Worker processes are not running") - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - for result in parallel_worker_tasks: - result.get() - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - if not self.tp_driver_workers: - return await self.driver_exec_model(execute_model_req) - - if self.pp_locks is None: - # This locks each pipeline parallel stage so multiple virtual - # engines can't execute on the same stage at the same time - # We create the locks here to avoid creating them in the constructor - # which uses a different asyncio loop. - self.pp_locks = [ - asyncio.Lock() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - tasks = [ - asyncio.create_task( - _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], - execute_model_req)) - ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): - tasks.append( - asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method_async, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) - results = await asyncio.gather(*tasks) - - # Only the last PP stage has the final results. 
- return results[-1] - - async def _start_worker_execution_loop(self): - coros = [ - worker.execute_method_async("start_worker_execution_loop") - for worker in self.non_driver_workers - ] - return await asyncio.gather(*coros) diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 4ce6d8dfad2cc..ac16f06b160e1 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from array import array -from typing import Any, Type +from typing import Any from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE @@ -16,13 +16,14 @@ def encode_hook(obj: Any) -> Any: if isinstance(obj, array): assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " - f"Given array has a type code of {obj.typecode}.") + f"Given array has a type code of {obj.typecode}." + ) return obj.tobytes() if isinstance(obj, MultiModalKwargs): return dict(obj) -def decode_hook(type: Type, obj: Any) -> Any: +def decode_hook(type: type, obj: Any) -> Any: """Custom msgspec dec hook that supports array types and MultiModalKwargs. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py deleted file mode 100644 index 48b3479ed7997..0000000000000 --- a/vllm/executor/multiproc_worker_utils.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -import threading -import uuid -from dataclasses import dataclass -from multiprocessing import Queue -from multiprocessing.connection import wait -from multiprocessing.process import BaseProcess -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union - -import torch - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.utils import (_maybe_force_spawn, decorate_logs, get_mp_context, - run_method) - -logger = init_logger(__name__) - -T = TypeVar('T') - -_TERMINATE = "TERMINATE" # sentinel - -JOIN_TIMEOUT_S = 2 - - -@dataclass -class Result(Generic[T]): - """Result of task dispatched to worker""" - - task_id: uuid.UUID - value: Optional[T] = None - exception: Optional[BaseException] = None - - -class ResultFuture(threading.Event, Generic[T]): - """Synchronous future for non-async case""" - - def __init__(self): - super().__init__() - self.result: Optional[Result[T]] = None - - def set_result(self, result: Result[T]): - self.result = result - self.set() - - def get(self) -> T: - self.wait() - assert self.result is not None - if self.result.exception is not None: - raise self.result.exception - return self.result.value # type: ignore[return-value] - - -def _set_future_result(future: Union[ResultFuture, asyncio.Future], - result: Result): - if isinstance(future, ResultFuture): - future.set_result(result) - return - loop = future.get_loop() - if not loop.is_closed(): - if result.exception is not None: - loop.call_soon_threadsafe(future.set_exception, result.exception) - else: - loop.call_soon_threadsafe(future.set_result, result.value) - - -class ResultHandler(threading.Thread): - """Handle results from all workers (in background thread)""" - - def __init__(self) -> None: - super().__init__(daemon=True) - self.result_queue = get_mp_context().Queue() - self.tasks: Dict[uuid.UUID, 
Union[ResultFuture, asyncio.Future]] = {} - - def run(self): - for result in iter(self.result_queue.get, _TERMINATE): - future = self.tasks.pop(result.task_id) - _set_future_result(future, result) - # Ensure that all waiters will receive an exception - for task_id, future in self.tasks.items(): - _set_future_result( - future, - Result(task_id=task_id, - exception=ChildProcessError("worker died"))) - - def close(self): - self.result_queue.put(_TERMINATE) - - -class WorkerMonitor(threading.Thread): - """Monitor worker status (in background thread)""" - - def __init__(self, workers: List['ProcessWorkerWrapper'], - result_handler: ResultHandler): - super().__init__(daemon=True) - self.workers = workers - self.result_handler = result_handler - self._close = False - - def run(self) -> None: - # Blocks until any worker exits - dead_sentinels = wait([w.process.sentinel for w in self.workers]) - if not self._close: - self._close = True - - # Kill / cleanup all workers - for worker in self.workers: - process = worker.process - if process.sentinel in dead_sentinels: - process.join(JOIN_TIMEOUT_S) - if process.exitcode is not None and process.exitcode != 0: - logger.error("Worker %s pid %s died, exit code: %s", - process.name, process.pid, process.exitcode) - # Cleanup any remaining workers - if logger: - logger.info("Killing local vLLM worker processes") - for worker in self.workers: - worker.kill_worker() - # Must be done after worker task queues are all closed - self.result_handler.close() - - for worker in self.workers: - worker.process.join(JOIN_TIMEOUT_S) - - def close(self): - if self._close: - return - self._close = True - logger.info("Terminating local vLLM worker processes") - for worker in self.workers: - worker.terminate_worker() - # Must be done after worker task queues are all closed - self.result_handler.close() - - -class ProcessWorkerWrapper: - """Local process wrapper for vllm.worker.Worker, - for handling single-node multi-GPU tensor parallel.""" - - def __init__(self, result_handler: ResultHandler, - worker_factory: Callable[[VllmConfig, int], Any], - vllm_config: VllmConfig, rank: int) -> None: - self.mp = get_mp_context() - self._task_queue = self.mp.Queue() - self.result_queue = result_handler.result_queue - self.tasks = result_handler.tasks - self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] - target=_run_worker_process, - name="VllmWorkerProcess", - kwargs=dict( - worker_factory=worker_factory, - task_queue=self._task_queue, - result_queue=self.result_queue, - vllm_config=vllm_config, - rank=rank, - ), - daemon=True) - - self.process.start() - - def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], - method: Union[str, bytes], args, kwargs): - task_id = uuid.uuid4() - self.tasks[task_id] = future - try: - self._task_queue.put((task_id, method, args, kwargs)) - except SystemExit: - raise - except BaseException as e: - del self.tasks[task_id] - raise ChildProcessError("worker died") from e - - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - future: ResultFuture = ResultFuture() - self._enqueue_task(future, method, args, kwargs) - return future - - async def execute_method_async(self, method: Union[str, bytes], *args, - **kwargs): - future = asyncio.get_running_loop().create_future() - self._enqueue_task(future, method, args, kwargs) - return await future - - def terminate_worker(self): - try: - self._task_queue.put(_TERMINATE) - except ValueError: - self.process.kill() - self._task_queue.close() - - def 
kill_worker(self): - self._task_queue.close() - self.process.kill() - - -def _run_worker_process( - worker_factory: Callable[[VllmConfig, int], Any], - task_queue: Queue, - result_queue: Queue, - vllm_config: VllmConfig, - rank: int, -) -> None: - """Worker process event loop""" - - # Add process-specific prefix to stdout and stderr - process_name = get_mp_context().current_process().name - decorate_logs(process_name) - - # Initialize worker - worker = worker_factory(vllm_config, rank) - del worker_factory - - # Accept tasks from the engine in task_queue - # and return task output in result_queue - logger.info("Worker ready; awaiting tasks") - try: - for items in iter(task_queue.get, _TERMINATE): - output = None - exception = None - task_id, method, args, kwargs = items - try: - output = run_method(worker, method, args, kwargs) - except SystemExit: - raise - except KeyboardInterrupt: - break - except BaseException as e: - logger.exception( - "Exception in worker %s while processing method %s.", - process_name, method) - exception = e - result_queue.put( - Result(task_id=task_id, value=output, exception=exception)) - except KeyboardInterrupt: - pass - except Exception: - logger.exception("Worker failed") - - # Flush TunableOp results when TunableOp is enabled and - # online (in situ) tuning is enabled. - # Offline tuning API (record_untuned_is_enabled()) only - # available in PyTorch 2.6 or later. - if torch.cuda.is_available(): - import torch.cuda.tunable as tunable - if (tunable.is_enabled() and tunable.tuning_is_enabled() - and not tunable.record_untuned_is_enabled()): - tunable.write_file() - - logger.info("Worker exiting") - - -def set_multiprocessing_worker_envs(parallel_config): - """ Set up environment variables that should be used when there are workers - in a multiprocessing environment. This should be called by the parent - process before worker processes are created""" - - _maybe_force_spawn() - - # Configure thread parallelism if OMP_NUM_THREADS isn't set - # - # Helps to avoid CPU contention. The default of spawning a thread per - # core combined with multiprocessing for each GPU can have a negative - # impact on performance. The contention is amplified when running in a - # container where CPU limits can cause throttling. - default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: - logger.warning( - "Reducing Torch parallelism from %d threads to %d to avoid " - "unnecessary CPU contention. 
Set OMP_NUM_THREADS in the " - "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) - os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) - torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 37c3fe59c65dd..6a9608d70b69d 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -5,24 +5,27 @@ import asyncio import os from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import cloudpickle import msgspec import vllm.envs as envs -from vllm.executor.executor_base import ( - DistributedExecutorBase) # yapf: disable +from vllm.executor.executor_base import DistributedExecutorBase from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, - ray) +from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, make_async) +from vllm.utils import ( + _run_task_with_lock, + get_distributed_init_method, + get_ip, + get_open_port, + make_async, +) +from vllm.v1.outputs import SamplerOutput if ray is not None: from ray.actor import ActorHandle @@ -43,6 +46,7 @@ class RayWorkerMetaData: The order of ray worker creation can be random, and we need to reset the rank after creating all workers. """ + worker: ActorHandle created_rank: int adjusted_rank: int = -1 @@ -55,7 +59,10 @@ class RayDistributedExecutor(DistributedExecutorBase): # These env vars are worker-specific, therefore are NOT copied # from the driver to the workers WORKER_SPECIFIC_ENV_VARS = { - "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + "VLLM_HOST_IP", + "VLLM_HOST_PORT", + "LOCAL_RANK", + "CUDA_VISIBLE_DEVICES", } # These non-vLLM env vars are copied from the driver to workers @@ -86,13 +93,13 @@ class RayDistributedExecutor(DistributedExecutorBase): self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER if self.use_ray_compiled_dag: assert self.use_ray_spmd_worker, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires " - "VLLM_USE_RAY_SPMD_WORKER=1") + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1" + ) if self.use_ray_spmd_worker: # TODO: Support SPMD worker for non-DAG Ray executor. 
assert self.use_ray_compiled_dag, ( - "VLLM_USE_RAY_SPMD_WORKER=1 requires " - "VLLM_USE_RAY_COMPILED_DAG=1") + "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1" + ) assert self.uses_ray initialize_ray_cluster(self.parallel_config) @@ -107,39 +114,42 @@ class RayDistributedExecutor(DistributedExecutorBase): self._init_workers_ray(placement_group) self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder( - Optional[List[SamplerOutput]]) + self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]]) self.use_v1 = envs.VLLM_USE_V1 - self.pp_locks: Optional[List[asyncio.Lock]] = None + self.pp_locks: Optional[list[asyncio.Lock]] = None if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async( - self.driver_worker.execute_method) + self.driver_exec_method = make_async(self.driver_worker.execute_method) def shutdown(self) -> None: - logger.info( - "Shutting down Ray distributed executor. If you see error log " - "from logging.cc regarding SIGTERM received, please ignore because " - "this is the expected termination process in Ray.") + if logger: + # Somehow logger can be None here. + logger.info( + "Shutting down Ray distributed executor. If you see error log " + "from logging.cc regarding SIGTERM received, please ignore " + "because this is the expected termination process in Ray." + ) if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() import ray + for worker in self.workers: ray.kill(worker) self.forward_dag = None - def _configure_ray_workers_use_nsight(self, - ray_remote_kwargs) -> Dict[str, Any]: + def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]: # If nsight profiling is enabled, we need to set the profiling # configuration for the ray workers as runtime env. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) - runtime_env.update({ - "nsight": { - "t": "cuda,cudnn,cublas", - "o": "'worker_process_%p'", - "cuda-graph-trace": "node", + runtime_env.update( + { + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } } - }) + ) return ray_remote_kwargs @@ -147,49 +157,50 @@ class RayDistributedExecutor(DistributedExecutorBase): def _get_env_vars_to_be_updated(self): return self._env_vars_for_all_workers - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] + self.workers: list[RayWorkerWrapper] = [] # Used in ray compiled DAG: indexed first by PP rank, # and then TP rank. In other words, the inner list is # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + self.pp_tp_workers: list[list[RayWorkerWrapper]] = [] if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( - ray_remote_kwargs) + ray_remote_kwargs + ) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. - bundle_indices: List[int] + bundle_indices: list[int] if envs.VLLM_RAY_BUNDLE_INDICES: # Use the bundle indices specified by the user. 
- bundle_indices = list( - map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) - assert len(bundle_indices) == self.parallel_config.world_size, \ - ("VLLM_RAY_BUNDLE_INDICES must have the same size" - f" as the world size, but got {bundle_indices=} " - f"and {self.parallel_config.world_size=}") - assert len(set(bundle_indices)) == len(bundle_indices), \ - ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," - f" but got {bundle_indices=}") + bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, ( + "VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {self.parallel_config.world_size=}" + ) + assert len(set(bundle_indices)) == len(bundle_indices), ( + "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}" + ) else: # use the first N bundles that have GPU resources. bundle_indices = [] for bundle_id, bundle in enumerate(placement_group.bundle_specs): if bundle.get(current_platform.ray_device_key, 0): bundle_indices.append(bundle_id) - bundle_indices = bundle_indices[:self.parallel_config.world_size] + bundle_indices = bundle_indices[: self.parallel_config.world_size] - worker_metadata: List[RayWorkerMetaData] = [] + worker_metadata: list[RayWorkerMetaData] = [] driver_ip = get_ip() for rank, bundle_id in enumerate(bundle_indices): scheduling_strategy = PlacementGroupSchedulingStrategy( @@ -205,8 +216,7 @@ class RayDistributedExecutor(DistributedExecutorBase): num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, - rpc_rank=rank) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) else: worker = ray.remote( num_cpus=0, @@ -214,15 +224,15 @@ class RayDistributedExecutor(DistributedExecutorBase): resources={current_platform.ray_device_key: num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, - rpc_rank=rank) - worker_metadata.append( - RayWorkerMetaData(worker=worker, created_rank=rank)) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) + worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) - worker_ips = ray.get([ - each.worker.get_node_ip.remote() # type: ignore[attr-defined] - for each in worker_metadata - ]) + worker_ips = ray.get( + [ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ] + ) for each, ip in zip(worker_metadata, worker_ips): each.ip = ip @@ -237,7 +247,8 @@ class RayDistributedExecutor(DistributedExecutorBase): # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config, rpc_rank=0) + vllm_config=self.vllm_config, rpc_rank=0 + ) worker_metadata.pop(i) break @@ -248,9 +259,10 @@ class RayDistributedExecutor(DistributedExecutorBase): "Ray does not allocate any GPUs on the driver node." f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." "Consider adjusting the Ray placement group or running " - "the driver on a GPU node.") + "the driver on a GPU node." 
+ ) - ip_counts: Dict[str, int] = {} + ip_counts: dict[str, int] = {} for ip in worker_ips: ip_counts[ip] = ip_counts.get(ip, 0) + 1 @@ -270,15 +282,15 @@ class RayDistributedExecutor(DistributedExecutorBase): # After sorting, the workers on the same node will be # close to each other, and the workers on the driver # node will be placed first. - sorted_worker_metadata = sorted(worker_metadata, - key=sort_by_driver_then_worker_ip) + sorted_worker_metadata = sorted( + worker_metadata, key=sort_by_driver_then_worker_ip + ) start_rank = 0 if self.use_ray_spmd_worker else 1 for i, item in enumerate(sorted_worker_metadata): item.adjusted_rank = i + start_rank self.workers = [item.worker for item in sorted_worker_metadata] rerank_mapping = { - item.created_rank: item.adjusted_rank - for item in sorted_worker_metadata + item.created_rank: item.adjusted_rank for item in sorted_worker_metadata } self._run_workers("adjust_rank", rerank_mapping) @@ -289,8 +301,8 @@ class RayDistributedExecutor(DistributedExecutorBase): # driver_dummy_worker can be None when using ray spmd worker. continue worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore + ray.get(worker.get_node_and_gpu_ids.remote()) + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -318,20 +330,27 @@ class RayDistributedExecutor(DistributedExecutorBase): f"{n_ips} unique IP addresses {all_ips}. Please check your" " network configuration. If you set `VLLM_HOST_IP`" " environment variable, make sure it is unique for" - " each node.") + " each node." + ) # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [{ - current_platform.device_control_env_var: - ",".join(map(str, node_gpus[node_id])), - } for (node_id, _) in worker_node_and_gpu_ids] + all_args_to_update_environment_variables = [ + { + current_platform.device_control_env_var: ",".join( + map(str, node_gpus[node_id]) + ), + } + for (node_id, _) in worker_node_and_gpu_ids + ] # Environment variables to copy from driver to workers env_vars_to_copy = get_env_vars_to_copy( exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, additional_vars=set(current_platform.additional_env_vars).union( - self.ADDITIONAL_ENV_VARS), - destination="workers") + self.ADDITIONAL_ENV_VARS + ), + destination="workers", + ) # Copy existing env vars to each worker's args for args in all_args_to_update_environment_variables: @@ -340,11 +359,11 @@ class RayDistributedExecutor(DistributedExecutorBase): if name in os.environ: args[name] = os.environ[name] - self._env_vars_for_all_workers = ( - all_args_to_update_environment_variables) + self._env_vars_for_all_workers = all_args_to_update_environment_variables - self._run_workers("update_environment_variables", - self._get_env_vars_to_be_updated()) + self._run_workers( + "update_environment_variables", self._get_env_vars_to_be_updated() + ) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address. @@ -357,7 +376,8 @@ class RayDistributedExecutor(DistributedExecutorBase): # the node. driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port() + ) # Initialize the actual workers inside worker wrapper. 
all_kwargs = [] @@ -375,19 +395,20 @@ class RayDistributedExecutor(DistributedExecutorBase): self._run_workers("init_worker", all_kwargs) self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + ) if self.use_ray_spmd_worker: for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) - for tp_rank in range( - self.parallel_config.tensor_parallel_size): + for tp_rank in range(self.parallel_config.tensor_parallel_size): # PP=2, TP=4 # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = (pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank + rank = ( + pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank assert len(self.pp_tp_workers[pp_rank]) == tp_rank assert pp_rank < len(self.pp_tp_workers) self.pp_tp_workers[pp_rank].append(self.workers[rank]) @@ -395,11 +416,11 @@ class RayDistributedExecutor(DistributedExecutorBase): # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. These are the workers that will broadcast to the # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] + self.tp_driver_workers: list[RayWorkerWrapper] = [] # This is the list of workers that are not drivers and not the first # worker in a TP group. These are the workers that will be # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] + self.non_driver_workers: list[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. for index, worker in enumerate(self.workers): @@ -412,20 +433,20 @@ class RayDistributedExecutor(DistributedExecutorBase): def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") - return self.driver_worker.execute_method("execute_model", - execute_model_req) + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" + ) + return self.driver_worker.execute_method("execute_model", execute_model_req) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return super().execute_model(execute_model_req) @@ -437,10 +458,7 @@ class RayDistributedExecutor(DistributedExecutorBase): else: serialized_data = self.input_encoder.encode(execute_model_req) outputs = ray.get(self.forward_dag.execute(serialized_data)) - if self.use_v1: - output = outputs[0] - else: - output = self.output_decoder.decode(outputs[0]) + output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0]) return output def _run_workers( @@ -461,19 +479,15 @@ class RayDistributedExecutor(DistributedExecutorBase): rather than blocking on the results. 
- args/kwargs: All workers share the same args/kwargs """ - if isinstance(method, str): - sent_method = method - else: - sent_method = cloudpickle.dumps(method) + sent_method = method if isinstance(method, str) else cloudpickle.dumps(method) del method if self.use_ray_spmd_worker: assert not async_run_tensor_parallel_workers_only, ( - "async_run_tensor_parallel_workers_only is not supported for " - "spmd mode.") + "async_run_tensor_parallel_workers_only is not supported for spmd mode." + ) if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") + raise NotImplementedError("max_concurrent_workers is not supported yet.") # Start the ray workers first. ray_workers = self.workers @@ -517,23 +531,27 @@ class RayDistributedExecutor(DistributedExecutorBase): required_version = version.parse("2.43.0") current_version = version.parse(importlib.metadata.version("ray")) if current_version < required_version: - raise ValueError(f"Ray version {required_version} is " - f"required, but found {current_version}") + raise ValueError( + f"Ray version {required_version} is " + f"required, but found {current_version}" + ) import importlib.util - cgraph_spec = importlib.util.find_spec( - "ray.experimental.compiled_dag_ref") + + cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref") if cgraph_spec is None: - raise ValueError("Ray Compiled Graph is not installed. " - "Run `pip install ray[cgraph]` to install it.") + raise ValueError( + "Ray Compiled Graph is not installed. " + "Run `pip install ray[cgraph]` to install it." + ) cupy_spec = importlib.util.find_spec("cupy") - if (cupy_spec is None - and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"): + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl": raise ValueError( "cupy is not installed but required since " "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. " - "Run `pip install ray[cgraph]` and check cupy installation.") + "Run `pip install ray[cgraph]` and check cupy installation." + ) def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray @@ -547,18 +565,26 @@ class RayDistributedExecutor(DistributedExecutorBase): # ray.dag, otherwise it will not take effect. os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112 from ray.dag import InputNode, MultiOutputNode - logger.info("RAY_CGRAPH_get_timeout is set to %s", - os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112 - logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", - envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE) - logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", - envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + logger.info( + "RAY_CGRAPH_get_timeout is set to %s", + os.environ["RAY_CGRAPH_get_timeout"], # noqa: SIM112 + ) + logger.info( + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE, + ) + logger.info( + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM, + ) channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE if channel_type not in ("auto", "nccl", "shm"): raise ValueError( "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: " - f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.") + f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'." 
+ ) with InputNode() as input_data: # Example DAG: PP=2, TP=4 @@ -583,20 +609,24 @@ class RayDistributedExecutor(DistributedExecutorBase): # and the TP group executes in SPMD fashion. if self.use_v1: outputs = [ - worker.execute_model_ray. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) + worker.execute_model_ray.bind( # type: ignore[attr-defined] + outputs[i] + ) + for i, worker in enumerate(tp_group) ] else: outputs = [ - worker.execute_model_spmd. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) + worker.execute_model_spmd.bind( # type: ignore[attr-defined] + outputs[i] + ) + for i, worker in enumerate(tp_group) ] last_pp_rank = len(self.pp_tp_workers) - 1 - if (pp_rank < last_pp_rank and - envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"): + if ( + pp_rank < last_pp_rank + and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm" + ): # Specify how intermediate tensors should be passed # between pp stages, no need to specify for the last # pp stage or when using shared memory (the default). @@ -610,30 +640,37 @@ class RayDistributedExecutor(DistributedExecutorBase): if envs.VLLM_USE_RAY_WRAPPED_PP_COMM: from ray.experimental.channel.accelerator_context import ( - register_accelerator_context) + register_accelerator_context, + ) from vllm.distributed.device_communicators.ray_communicator import ( - RayPPCommunicator) - register_accelerator_context(torch_module_name="cuda", - communicator_cls=RayPPCommunicator) - logger.info("Using RayPPCommunicator " - "(which wraps vLLM _PP GroupCoordinator) " - "for Ray Compiled Graph communication.") + RayPPCommunicator, + ) + + register_accelerator_context( + torch_module_name="cuda", communicator_cls=RayPPCommunicator + ) + logger.info( + "Using RayPPCommunicator " + "(which wraps vLLM _PP GroupCoordinator) " + "for Ray Compiled Graph communication." + ) else: - logger.info("Using Ray's NCCL communicator for " - "Ray Compiled Graph communication.") + logger.info( + "Using Ray's NCCL communicator for Ray Compiled Graph communication." + ) return forward_dag.experimental_compile( enable_asyncio=enable_asyncio, - _overlap_gpu_communication=envs. 
- VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM, + ) def __del__(self): self.shutdown() async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return await super().execute_model_async(execute_model_req) @@ -646,14 +683,13 @@ class RayDistributedExecutor(DistributedExecutorBase): return self.output_decoder.decode(output) async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] = None + ) -> list[SamplerOutput]: assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" + ) if not self.tp_driver_workers: - return await self.driver_exec_method("execute_model", - execute_model_req) + return await self.driver_exec_method("execute_model", execute_model_req) if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -666,16 +702,25 @@ class RayDistributedExecutor(DistributedExecutorBase): tasks = [ asyncio.create_task( - _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], - "execute_model", execute_model_req)) + _run_task_with_lock( + self.driver_exec_method, + self.pp_locks[0], + "execute_model", + execute_model_req, + ) + ) ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1): tasks.append( asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method.remote, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) + _run_task_with_lock( + driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", + execute_model_req, + ) + ) + ) results = await asyncio.gather(*tasks) @@ -684,7 +729,8 @@ class RayDistributedExecutor(DistributedExecutorBase): async def _start_worker_execution_loop(self): assert not self.use_ray_spmd_worker, ( - "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1" + ) coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7abaffa54c089..c3c8a70678add 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -4,18 +4,20 @@ import os import time from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import msgspec import vllm.platforms from vllm.config import ParallelConfig +from vllm.distributed import get_pp_group from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.worker.worker_base import WorkerWrapperBase if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -28,11 +30,13 @@ try: import ray from ray.util import placement_group_table from 
ray.util.placement_group import PlacementGroup + try: from ray._private.state import available_resources_per_node except ImportError: # Ray 2.9.x doesn't expose `available_resources_per_node` from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node class RayWorkerWrapper(WorkerWrapperBase): @@ -47,27 +51,28 @@ try: # that thread. self.compiled_dag_cuda_device_set = False - self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) + self.input_decoder = msgspec.msgpack.Decoder( + ExecuteModelRequest, dec_hook=decode_hook + ) self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) def get_node_ip(self) -> str: return get_ip() - def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: node_id = ray.get_runtime_context().get_node_id() device_key = vllm.platforms.current_platform.ray_device_key if not device_key: - raise RuntimeError("current platform %s does not support ray.", - vllm.platforms.current_platform.device_name) - gpu_ids = ray.get_runtime_context().get_accelerator_ids( - )[device_key] + raise RuntimeError( + "current platform %s does not support ray.", + vllm.platforms.current_platform.device_name, + ) + gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] return node_id, gpu_ids def execute_model_spmd( - self, req_or_tuple: Union[bytes, - Tuple[bytes, - Optional[IntermediateTensors]]] + self, + req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]], ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. @@ -92,8 +97,9 @@ try: current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - output = self.worker._execute_model_spmd(execute_model_req, - intermediate_tensors) + output = self.worker._execute_model_spmd( + execute_model_req, intermediate_tensors + ) # Pipeline model request and output to the next pipeline stage. if isinstance(output, IntermediateTensors): output = serialized_req, output @@ -119,11 +125,12 @@ try: def execute_model_ray( self, - scheduler_output: Union["SchedulerOutput", - Tuple["SchedulerOutput", - "IntermediateTensors"]], - ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", - "IntermediateTensors"]]: + scheduler_output: Union[ + "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] + ], + ) -> Union[ + "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] + ]: # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() @@ -133,12 +140,23 @@ try: else: scheduler_output, intermediate_tensors = scheduler_output, None output = self.worker.model_runner.execute_model( - scheduler_output, intermediate_tensors) + scheduler_output, intermediate_tensors + ) if isinstance(output, IntermediateTensors): output = scheduler_output, output + elif not get_pp_group().is_last_rank: + # Case where there are no scheduled requests + # but may still be finished requests. + assert not output or not output.req_ids + output = scheduler_output, None + # Ensure outputs crossing Ray compiled DAG are serializable. + # AsyncModelRunnerOutput holds CUDA events and cannot be + # pickled. 
+ if isinstance(output, AsyncModelRunnerOutput): + output = output.get_output() return output - def override_env_vars(self, vars: Dict[str, str]): + def override_env_vars(self, vars: dict[str, str]): os.environ.update(vars) ray_import_err = None @@ -159,12 +177,15 @@ def ray_is_available() -> bool: def assert_ray_available(): """Raise an exception if Ray is not available.""" if ray is None: - raise ValueError(f"Failed to import Ray: {ray_import_err}." - "Please install Ray with `pip install ray`.") + raise ValueError( + f"Failed to import Ray: {ray_import_err}." + "Please install Ray with `pip install ray`." + ) -def _verify_bundles(placement_group: "PlacementGroup", - parallel_config: ParallelConfig, device_str: str): +def _verify_bundles( + placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str +): """Verify a given placement group has bundles located in the right place. There are 2 rules. @@ -172,14 +193,15 @@ def _verify_bundles(placement_group: "PlacementGroup", - Fail if driver node is not included in a placement group. """ assert ray.is_initialized(), ( - "Ray is not initialized although distributed-executor-backend is ray.") + "Ray is not initialized although distributed-executor-backend is ray." + ) pg_data = placement_group_table(placement_group) # bundle_idx -> node_id bundle_to_node_ids = pg_data["bundles_to_node_id"] # bundle_idx -> bundle (e.g., {"GPU": 1}) bundles = pg_data["bundles"] # node_id -> List of bundle (e.g., {"GPU": 1}) - node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list) for bundle_idx, node_id in bundle_to_node_ids.items(): node_id_to_bundle[node_id].append(bundles[bundle_idx]) @@ -205,8 +227,13 @@ def _verify_bundles(placement_group: "PlacementGroup", "unless you have fast interconnect across nodes, like " "Infiniband. To resolve this issue, make sure you have more " "than %d GPUs available at each node.", - parallel_config.tensor_parallel_size, device_str, len(bundles), - device_str, node_id, parallel_config.tensor_parallel_size) + parallel_config.tensor_parallel_size, + device_str, + len(bundles), + device_str, + node_id, + parallel_config.tensor_parallel_size, + ) def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): @@ -217,7 +244,7 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): """ # Wait until PG is ready - this will block until all - # requested resources are available, and will timeout + # requested resources are available, and will time out # if they cannot be provisioned. placement_group_specs = current_placement_group.bundle_specs @@ -238,7 +265,9 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): " and make sure the IP addresses used by ray cluster" " are the same as VLLM_HOST_IP environment variable" " specified in each node if you are running on a multi-node.", - int(time.time() - s), placement_group_specs) + int(time.time() - s), + placement_group_specs, + ) try: ray.get(pg_ready_ref, timeout=0) @@ -247,7 +276,8 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): "Cannot provide a placement group of " f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " "`ray status` and `ray list nodes` to make sure the cluster has " - "enough resources.") from None + "enough resources." 
+ ) from None def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): @@ -262,8 +292,9 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): # Exponential backoff for warning print. wait_interval *= 2 logger.info( - "Waiting for removing a placement group of specs for " - "%d seconds.", int(time.time() - s)) + "Waiting for removing a placement group of specs for %d seconds.", + int(time.time() - s), + ) time.sleep(wait_interval) @@ -294,19 +325,21 @@ def initialize_ray_cluster( except ConnectionError: logger.warning( "No existing RAY instance detected. " - "A new instance will be launched with current node resources.") - ray.init(address=ray_address, - num_gpus=parallel_config.world_size, - runtime_env=parallel_config.ray_runtime_env) + "A new instance will be launched with current node resources." + ) + ray.init( + address=ray_address, + num_gpus=parallel_config.world_size, + runtime_env=parallel_config.ray_runtime_env, + ) else: - ray.init(address=ray_address, - runtime_env=parallel_config.ray_runtime_env) + ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env) device_str = current_platform.ray_device_key if not device_str: raise ValueError( - f"current platform {current_platform.device_name} does not " - "support ray.") + f"current platform {current_platform.device_name} does not support ray." + ) # Create or get the placement group for worker processes if parallel_config.placement_group: @@ -325,8 +358,8 @@ def initialize_ray_cluster( bundle_devices = bundle.get(device_str, 0) if bundle_devices > 1: raise ValueError( - "Placement group bundle cannot have more than 1 " - f"{device_str}.") + f"Placement group bundle cannot have more than 1 {device_str}." + ) if bundle_devices: device_bundles += 1 if parallel_config.world_size > device_bundles: @@ -334,10 +367,10 @@ def initialize_ray_cluster( f"The number of required {device_str}s exceeds the total " f"number of available {device_str}s in the placement group. " f"Required number of devices: {parallel_config.world_size}. " - f"Total number of devices: {device_bundles}.") + f"Total number of devices: {device_bundles}." + ) else: - logger.info("No current placement group found. " - "Creating a new placement group.") + logger.info("No current placement group found. Creating a new placement group.") num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) # Log a warning message and delay resource allocation failure response. # Avoid immediate rejection to allow user-initiated placement group @@ -345,12 +378,14 @@ def initialize_ray_cluster( if parallel_config.world_size > num_devices_in_cluster: logger.warning( "The number of required %ss exceeds the total " - "number of available %ss in the placement group.", device_str, - device_str) + "number of available %ss in the placement group.", + device_str, + device_str, + ) # Create a new placement group - placement_group_specs: List[Dict[str, float]] = ([{ - device_str: 1.0 - } for _ in range(parallel_config.world_size)]) + placement_group_specs: list[dict[str, float]] = [ + {device_str: 1.0} for _ in range(parallel_config.world_size) + ] # vLLM engine is also a worker to execute model with an accelerator, # so it requires to have the device in a current node. Check if @@ -363,14 +398,16 @@ def initialize_ray_cluster( f"Current node has no {device_str} available. " f"{current_node_resource=}. vLLM engine cannot start without " f"{device_str}. 
Make sure you have at least 1 {device_str} " - f"available in a node {current_node_id=} {current_ip=}.") + f"available in a node {current_node_id=} {current_ip=}." + ) # This way, at least bundle is required to be created in a current # node. placement_group_specs[0][f"node:{current_ip}"] = 0.001 # By default, Ray packs resources as much as possible. current_placement_group = ray.util.placement_group( - placement_group_specs, strategy="PACK") + placement_group_specs, strategy="PACK" + ) _wait_until_pg_ready(current_placement_group) assert current_placement_group is not None @@ -381,6 +418,7 @@ def initialize_ray_cluster( def get_num_tpu_nodes() -> int: from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() total_tpus = int(cluster_resources["TPU"]) tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index aabc9ed9b80a2..612fd73c12b15 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from concurrent.futures import Future, ThreadPoolExecutor +from functools import cached_property +from multiprocessing import Lock +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -10,53 +12,78 @@ import torch.distributed as dist import vllm.envs as envs from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - run_method) +from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.outputs import AsyncModelRunnerOutput +from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) class UniProcExecutor(ExecutorBase): - uses_ray: bool = False def _init_executor(self) -> None: - """Initialize the worker and load the model. 
- """ - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - local_rank = 0 - # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") - if len(device_info) > 1: - local_rank = int(device_info[1]) - rank = 0 - is_driver_worker = True + """Initialize the worker and load the model.""" + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) + distributed_init_method, rank, local_rank = self._distributed_args() kwargs = dict( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, + is_driver_worker=True, + shared_worker_lock=Lock(), ) - self.collective_rpc("init_worker", args=([kwargs], )) + + self.async_output_thread: Optional[ThreadPoolExecutor] = None + if self.max_concurrent_batches > 1: + self.async_output_thread = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="WorkerAsyncOutput" + ) + + self.collective_rpc("init_worker", args=([kwargs],)) self.collective_rpc("init_device") self.collective_rpc("load_model") - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + def _distributed_args(self) -> tuple[str, int, int]: + """Return (distributed_init_method, rank, local_rank).""" + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split(":") + local_rank = int(device_info[1]) if len(device_info) > 1 else 0 + return distributed_init_method, 0, local_rank + + @cached_property + def max_concurrent_batches(self) -> int: + return 2 if self.scheduler_config.async_scheduling else 1 + + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + ) -> list[Any]: if kwargs is None: kwargs = {} - answer = run_method(self.driver_worker, method, args, kwargs) - return [answer] + + if not non_block: + return [run_method(self.driver_worker, method, args, kwargs)] + + try: + result = run_method(self.driver_worker, method, args, kwargs) + if isinstance(result, AsyncModelRunnerOutput): + if (async_thread := self.async_output_thread) is not None: + return [async_thread.submit(result.get_output)] + result = result.get_output() + future = Future[Any]() + future.set_result(result) + except Exception as e: + future = Future[Any]() + future.set_exception(e) + return [future] def check_health(self) -> None: # UniProcExecutor will always be healthy as long as @@ -64,13 +91,20 @@ class UniProcExecutor(ExecutorBase): return def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: self.driver_worker.reinitialize_distributed(reconfig_request) - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() return + def shutdown(self) -> None: + if worker := self.driver_worker: + worker.shutdown() + UniProcExecutorAsync = UniProcExecutor @@ -91,21 +125,19 @@ class 
ExecutorWithExternalLauncher(UniProcExecutor): deterministic, all the engines will generate the same outputs, and they don't need to synchronize the states with each other. """ + uses_ray: bool = False def _init_executor(self) -> None: - """Initialize the worker and load the model. - """ - assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ - ("ExecutorWithExternalLauncher needs deterministic " - "execution, so it" - "does not support delay_factor in scheduling") + """Initialize the worker and load the model.""" if envs.VLLM_USE_V1: - assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ - ("To get deterministic execution in V1, " - "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, ( + "To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" + ) + super()._init_executor() + + def _distributed_args(self) -> tuple[str, int, int]: # engines are launched in torchrun-compatible launchers # so we can use the env:// method. # required env vars: @@ -116,30 +148,21 @@ class ExecutorWithExternalLauncher(UniProcExecutor): distributed_init_method = "env://" rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) - is_driver_worker = True - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - ) - self.collective_rpc("init_worker", args=([kwargs], )) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + return distributed_init_method, rank, local_rank - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> tuple[int, int]: """ Determine the number of available KV blocks. Add an additional all_reduce to get the min across all ranks. - Note that even if we have the same `gpu_memory_utilization` and - `swap_space`, the available memory in every rank might still - differ because NCCL can take different amounts of memory in - different ranks. Therefore, it is necessary to test if all ranks + Note that even if we have the same `gpu_memory_utilization` and + `swap_space`, the available memory in every rank might still + differ because NCCL can take different amounts of memory in + different ranks. Therefore, it is necessary to test if all ranks agree on the same KV cache configuration. 
""" a, b = super().determine_num_available_blocks() from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64) b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c57c51d289ac8..36f3062a9e3a0 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -8,11 +8,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch -import torch.distributed as dist import vllm.envs as envs from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp +from vllm.v1.worker.ubatch_utils import UBatchSlices if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -32,6 +33,7 @@ class BatchDescriptor(NamedTuple): items as minimal as possible to properly and uniquely describe the padded batch for cudagraph. """ + num_tokens: int uniform_decode: bool = False """ @@ -47,16 +49,30 @@ class BatchDescriptor(NamedTuple): return BatchDescriptor(self.num_tokens, uniform_decode=False) -def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], - max_num_tokens: int, - chunk_idx: int) -> list[int]: - dp_size = len(num_tokens_across_dp_cpu) +def _compute_sp_num_tokens( + num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int +) -> list[int]: + sp_tokens = ( + num_tokens_across_dp_cpu + sequence_parallel_size - 1 + ) // sequence_parallel_size - local_size = [-1] * dp_size - for i in range(dp_size): - dp_tokens = num_tokens_across_dp_cpu[i] - local_size[i] = min(max_num_tokens, - dp_tokens - (max_num_tokens * chunk_idx)) + sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size) + return sp_tokens.tolist() + + +def _compute_chunked_local_num_tokens( + num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int, + max_num_tokens: int, + chunk_idx: int, +) -> list[int]: + sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size) + sp_size = len(sp_tokens) + + local_size = [-1] * sp_size + for i in range(sp_size): + # Take into account sharding if MoE activation is sequence parallel. + local_size[i] = min(max_num_tokens, sp_tokens[i] - (max_num_tokens * chunk_idx)) if local_size[i] <= 0: local_size[i] = 1 # ensure lockstep even if done return local_size @@ -65,58 +81,34 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor - cu_tokens_across_dp_cpu: torch.Tensor + num_tokens_across_dp_cpu: torch.Tensor + + # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: Optional[list[int]] = None - @staticmethod - def num_tokens_across_dp(num_tokens: int, dp_size: int, - dp_rank: int) -> torch.Tensor: - """ - Gather the num_tokens across all DP ranks and return results in a - CPU tensor of size dp_size. 
- """ - num_tokens_across_dp = [0] * dp_size - num_tokens_across_dp[dp_rank] = num_tokens - num_tokens_tensor = torch.tensor(num_tokens_across_dp, - device="cpu", - dtype=torch.int32) - from vllm.distributed.parallel_state import get_dp_group - dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - return num_tokens_tensor - @staticmethod def make( - parallel_config: ParallelConfig, - attn_metadata: Any, - num_tokens: int, - num_tokens_across_dp: Optional[torch.Tensor] = None + parallel_config: ParallelConfig, + num_tokens: int, + num_tokens_across_dp_cpu: torch.Tensor, ) -> "DPMetadata": - + assert num_tokens_across_dp_cpu is not None assert parallel_config.data_parallel_size > 1 - dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - if attn_metadata is not None and hasattr(attn_metadata, - "num_prefill_tokens"): - # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens - else: - # for v1 attention backends or no attn_metadata - batchsize = num_tokens + batchsize = num_tokens # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp is None - or num_tokens_across_dp[dp_rank] == batchsize) - if num_tokens_across_dp is None: - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - batchsize, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) - cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) - return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + assert num_tokens_across_dp_cpu[dp_rank] == batchsize, ( + f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" + ) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) @contextmanager - def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): + def chunked_sizes( + self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int + ): """ Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution. @@ -130,33 +122,57 @@ class DPMetadata: `chunk_idx`, this context manager sets `self.local_sizes` to the number of tokens to process in that chunk on each rank. - It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the - number of tokens per rank, and calls `_compute_chunked_local_num_tokens` - to determine the chunk-wise split. - `self.local_sizes` is only valid inside the context. Args: - max_chunk_size_per_rank: The max number of tokens each rank is + sequence_parallel_size: When Attn is TP and MoE layers are EP, + we use SP between the layers to avoid + redundant ops. We need this value to + compute the chunked sizes. + max_chunk_size_per_rank: The max number of tokens each rank is allowed to process in this chunk. chunk_idx: The index of the chunk to compute sizes for. 
""" - cu_sizes = self.cu_tokens_across_dp_cpu - num_tokens_across_dp_cpu = [ - (cu_sizes[i] - - cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item() - for i in range(len(cu_sizes)) - ] self.local_sizes = _compute_chunked_local_num_tokens( - num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx) + self.num_tokens_across_dp_cpu, + sequence_parallel_size, + max_chunk_size_per_rank, + chunk_idx, + ) + try: + yield self.local_sizes + finally: + self.local_sizes = None + + @contextmanager + def sp_local_sizes(self, sequence_parallel_size: int): + """ + Context mamager for setting self.local_sizes. Same as self.chunked_sizes + but without any chunking. + """ + self.local_sizes = _compute_sp_num_tokens( + self.num_tokens_across_dp_cpu, sequence_parallel_size + ) try: yield self.local_sizes finally: self.local_sizes = None def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: + assert self.local_sizes is not None return self.local_sizes + # Get the cumulative tokens across sequence parallel ranks. + # In this case the input to the MoEs will be distributed w.r.t both + # DP and TP rank. + # When sp_size==1, this is just the cummulative num tokens across DP. + def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: + num_tokens_across_sp_cpu = ( + self.num_tokens_across_dp_cpu - 1 + sp_size + ) // sp_size + num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size) + return torch.cumsum(num_tokens_across_sp_cpu, dim=0) + @dataclass class ForwardContext: @@ -166,9 +182,15 @@ class ForwardContext: Type AttentionMetadata for v0, Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata - set dynamically for each forward pass + Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one + for each microbatch. + Set dynamically for each forward pass """ - attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] + attn_metadata: Union[ + "AttentionMetadata", + dict[str, "AttentionMetadata"], + list[dict[str, "AttentionMetadata"]], + ] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -178,10 +200,12 @@ class ForwardContext: cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE batch_descriptor: Optional[BatchDescriptor] = None + ubatch_slices: Optional[UBatchSlices] = None + def __post_init__(self): - assert self.cudagraph_runtime_mode in [ - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" + ) _forward_context: Optional[ForwardContext] = None @@ -191,19 +215,57 @@ def get_forward_context() -> ForwardContext: """Get the current forward context.""" assert _forward_context is not None, ( "Forward context is not set. " - "Please use `set_forward_context` to set the forward context.") + "Please use `set_forward_context` to set the forward context." 
+ ) return _forward_context +def create_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + dp_metadata: Optional[DPMetadata] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None, +): + return ForwardContext( + no_compile_layers=vllm_config.compilation_config.static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ) + + +@contextmanager +def override_forward_context(forward_context: Optional[ForwardContext]): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. + """ + global _forward_context + prev_context = _forward_context + _forward_context = forward_context + try: + yield + finally: + _forward_context = prev_context + + @contextmanager def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None, +): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -212,34 +274,55 @@ def set_forward_context( need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() + dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1 and ( - attn_metadata is not None or num_tokens is not None): - dp_metadata = DPMetadata.make(vllm_config.parallel_config, - attn_metadata, num_tokens or 0, - num_tokens_across_dp) + attn_metadata is not None or num_tokens is not None + ): + # If num_tokens_across_dp hasn't already been initialized, then + # initialize it here. Both DP padding and Microbatching will be + # disabled. + if num_tokens_across_dp is None: + assert ubatch_slices is None + assert num_tokens is not None + _, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens, + parallel_config=vllm_config.parallel_config, + allow_microbatching=False, + allow_dp_padding=False, + ) + assert num_tokens_across_dp is not None + dp_metadata = DPMetadata.make( + vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp + ) - global _forward_context - prev_context = _forward_context - _forward_context = ForwardContext( - no_compile_layers=vllm_config.compilation_config. - static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + # Convenience: if cudagraph is used and num_tokens is given, we can just + # create a batch descriptor here if not given (there's no harm since if it + # doesn't match in the wrapper it'll fall through). 
+ if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: + batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) + + forward_context = create_forward_context( + attn_metadata, + vllm_config, + virtual_engine, + dp_metadata, + cudagraph_runtime_mode, + batch_descriptor, + ubatch_slices, ) try: - yield + with override_forward_context(forward_context): + yield finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: if hasattr(attn_metadata, "num_prefill_tokens"): # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens + batchsize = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) else: # for v1 attention backends batchsize = num_tokens @@ -247,13 +330,13 @@ def set_forward_context( # adding a sync point here should not affect # scheduling of the next batch from vllm.platforms import current_platform + synchronize = current_platform.synchronize if synchronize is not None: synchronize() now = time.perf_counter() # time measurement is in milliseconds - batchsize_forward_time[batchsize].append( - (now - forward_start_time) * 1000) + batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000) if now - last_logging_time > batchsize_logging_interval: last_logging_time = now forward_stats = [] @@ -266,8 +349,10 @@ def set_forward_context( forward_stats.append((bs, len(times), medium)) forward_stats.sort(key=lambda x: x[1], reverse=True) if forward_stats: - logger.info(("Batchsize forward time stats " - "(batchsize, count, median_time(ms)): %s"), - forward_stats) - - _forward_context = prev_context + logger.info( + ( + "Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s" + ), + forward_stats, + ) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index aef7841e71b71..d9aed70c9b979 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,23 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, embeds_inputs, - to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import (DummyData, InputContext, InputProcessingContext, - InputRegistry) - -INPUT_REGISTRY = InputRegistry() -""" -The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used -by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the -target model. 
-""" +from .data import ( + DataPrompt, + DecoderOnlyInputs, + EmbedsInputs, + EmbedsPrompt, + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokenInputs, + TokensPrompt, + build_explicit_enc_dec_prompt, + embeds_inputs, + to_enc_dec_tuple_list, + token_inputs, + zip_enc_dec_prompts, +) __all__ = [ + "DataPrompt", "TextPrompt", "TokensPrompt", "PromptType", @@ -35,9 +41,4 @@ __all__ = [ "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", - "INPUT_REGISTRY", - "DummyData", - "InputContext", - "InputProcessingContext", - "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 23cb5e5022f19..be14decb4ac9d 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,11 @@ import torch from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: - from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs + from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalInputs, + MultiModalUUIDDict, + ) class TextPrompt(TypedDict): @@ -16,13 +20,13 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: NotRequired["MultiModalDataDict"] + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[dict[str, Any]] + mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -30,6 +34,15 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching, and MUST be unique per + multimodal item. + """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -42,16 +55,19 @@ class TokensPrompt(TypedDict): prompt_token_ids: list[int] """A list of token IDs to pass to the model.""" + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" - multi_modal_data: NotRequired["MultiModalDataDict"] + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[dict[str, Any]] + mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -59,6 +75,14 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching. 
+ """ + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -77,6 +101,16 @@ class EmbedsPrompt(TypedDict): """ +class DataPrompt(TypedDict): + """Represents generic inputs handled by IO processor plugins.""" + + data: Any + """The input data""" + + data_format: str + """The input data format""" + + SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: @@ -103,23 +137,27 @@ more than one prompt, i.e. def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt + ) -_T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) -_T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) +_T1_co = TypeVar( + "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) +_T2_co = TypeVar( + "_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) # TODO: Make fields ReadOnly once mypy supports it @@ -174,14 +212,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - - prompt: NotRequired[str] - """ - The original prompt text corresponding to the token IDs, if available. - """ - cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. @@ -190,18 +220,12 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: list[int], - token_type_ids: Optional[list[int]] = None, - prompt: Optional[str] = None, cache_salt: Optional[str] = None, ) -> TokenInputs: """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) - if prompt is not None: - inputs["prompt"] = prompt - if token_type_ids is not None: - inputs["token_type_ids"] = token_type_ids if cache_salt is not None: inputs["cache_salt"] = cache_salt @@ -262,8 +286,8 @@ class EncoderDecoderInputs(TypedDict): SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ -A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be -passed to [`vllm.sequence.Sequence`][]. +A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be +passed to [`Sequence`][collections.abc.Sequence]. 
""" ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] @@ -292,8 +316,9 @@ def build_explicit_enc_dec_prompt( def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], dec_prompts: Iterable[Optional[_T2]], - mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]], - dict[str, Any]]] = None, + mm_processor_kwargs: Optional[ + Union[Iterable[dict[str, Any]], dict[str, Any]] + ] = None, ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of @@ -312,20 +337,21 @@ def zip_enc_dec_prompts( encoder_prompt, decoder_prompt, cast(dict[str, Any], mm_processor_kwargs), - ) for (encoder_prompt, - decoder_prompt) in zip(enc_prompts, dec_prompts) + ) + for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) ] return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, - mm_proc_kwargs) - for (encoder_prompt, decoder_prompt, mm_proc_kwargs - ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip( + enc_prompts, dec_prompts, mm_processor_kwargs + ) ] def to_enc_dec_tuple_list( enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], ) -> list[tuple[_T1, Optional[_T2]]]: - return [(enc_dec_prompt["encoder_prompt"], - enc_dec_prompt["decoder_prompt"]) - for enc_dec_prompt in enc_dec_prompts] + return [ + (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts + ] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 8c3700799e4ab..2f7bd50df022e 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,45 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Literal, Optional, TypedDict, Union, cast, overload +from typing import TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict, Union, cast from typing_extensions import TypeIs from vllm.utils import is_list_of -from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import ( + EmbedsPrompt, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) + +if TYPE_CHECKING: + import torch -class ParsedText(TypedDict): - content: str - is_tokens: Literal[False] - - -class ParsedTokens(TypedDict): - content: list[int] - is_tokens: Literal[True] - - -@overload -def parse_and_batch_prompt( - prompt: Union[str, list[str]], ) -> Sequence[ParsedText]: - ... - - -@overload -def parse_and_batch_prompt( - prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]: - ... 
- - -def parse_and_batch_prompt( +def parse_raw_prompts( prompt: Union[str, list[str], list[int], list[list[int]]], -) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: +) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]: if isinstance(prompt, str): # case 1: a string - return [ParsedText(content=prompt, is_tokens=False)] + return [TextPrompt(prompt=prompt)] if isinstance(prompt, list): if len(prompt) == 0: @@ -48,13 +36,11 @@ def parse_and_batch_prompt( if is_list_of(prompt, str): # case 2: array of strings prompt = cast(list[str], prompt) - return [ - ParsedText(content=elem, is_tokens=False) for elem in prompt - ] + return [TextPrompt(prompt=elem) for elem in prompt] if is_list_of(prompt, int): # case 3: array of tokens prompt = cast(list[int], prompt) - return [ParsedTokens(content=prompt, is_tokens=True)] + return [TokensPrompt(prompt_token_ids=prompt)] if is_list_of(prompt, list): prompt = cast(list[list[int]], prompt) if len(prompt[0]) == 0: @@ -62,13 +48,12 @@ def parse_and_batch_prompt( if is_list_of(prompt[0], int): # case 4: array of token arrays - return [ - ParsedTokens(content=elem, is_tokens=True) - for elem in prompt - ] + return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] - raise TypeError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError( + "prompt must be a string, array of strings, " + "array of tokens, or array of token arrays" + ) class ParsedStrPrompt(TypedDict): @@ -91,28 +76,9 @@ class ParsedEmbedsPrompt(TypedDict): content: EmbedsPrompt -ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt, - ParsedTokensPrompt, ParsedEmbedsPrompt] - - -@overload -def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: - ... +ParsedSingletonPrompt = Union[ + ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, ParsedEmbedsPrompt +] def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: @@ -122,19 +88,19 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: # Type ignores are because mypy does not correctly infer the TypedDicts # Pyright does succeed. 
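
The overload-based `parse_and_batch_prompt` is replaced above by `parse_raw_prompts`, which normalizes raw user input directly into the prompt TypedDicts. For illustration, and assuming a checkout that includes this refactor, the four accepted input shapes map as follows (the printed dict shapes follow directly from the function body shown above):

```python
# The four raw input shapes accepted by parse_raw_prompts
# (vllm/inputs/parse.py, shown above), each normalized into a list of
# TextPrompt or TokensPrompt dicts.
from vllm.inputs.parse import parse_raw_prompts

print(parse_raw_prompts("hello"))
# -> [{'prompt': 'hello'}]

print(parse_raw_prompts(["hello", "world"]))
# -> [{'prompt': 'hello'}, {'prompt': 'world'}]

print(parse_raw_prompts([1, 2, 3]))
# -> [{'prompt_token_ids': [1, 2, 3]}]

print(parse_raw_prompts([[1, 2], [3, 4]]))
# -> [{'prompt_token_ids': [1, 2]}, {'prompt_token_ids': [3, 4]}]
```
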
if "prompt_embeds" in prompt: - return ParsedEmbedsPrompt( - type="embeds", content=prompt) # type: ignore[typeddict-item] + return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item] elif "prompt_token_ids" in prompt: - return ParsedTokensPrompt( - type="tokens", content=prompt) # type: ignore[typeddict-item] + return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item] elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) raise TypeError( - "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt" + ) def is_explicit_encoder_decoder_prompt( - prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]: + prompt: PromptType, +) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt @@ -149,3 +115,23 @@ def split_enc_dec_inputs( ) return None, inputs + + +class PromptComponents(NamedTuple): + text: Optional[str] = None + token_ids: Optional[list[int]] = None + embeds: Optional["torch.Tensor"] = None + + +def get_prompt_components(prompt: PromptType) -> PromptComponents: + if isinstance(prompt, str): + return PromptComponents(text=prompt) + + if encoder_prompt := prompt.get("encoder_prompt"): + return get_prompt_components(encoder_prompt) # type: ignore[arg-type] + + return PromptComponents( + text=prompt.get("prompt"), # type: ignore[arg-type] + token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type] + embeds=prompt.get("prompt_embeds"), + ) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3f521012e82a2..809f6c8d83f01 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio from collections.abc import Mapping from typing import Any, Optional, Union, cast @@ -9,62 +8,82 @@ from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs) +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalInputs, + MultiModalUUIDDict, +) +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils.jsontree import json_iter_leaves +from vllm.v1.metrics.stats import MultiModalCacheStats -from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, embeds_inputs, token_inputs) +from .data import ( + DecoderOnlyInputs, + EmbedsInputs, + EmbedsPrompt, + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokenInputs, + TokensPrompt, + embeds_inputs, + token_inputs, +) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) class InputPreprocessor: - def __init__( self, model_config: ModelConfig, - tokenizer: Optional[TokenizerGroup], + 
tokenizer: Optional[AnyTokenizer], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache - def get_tokenizer_group(self) -> TokenizerGroup: + self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None + + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("You cannot pass text prompts when " - "`skip_tokenizer_init` is True") + raise ValueError( + "You cannot pass text prompts when `skip_tokenizer_init` is True" + ) return self.tokenizer - def get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_bos_token_id(self) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") + logger.warning_once( + "Using None for BOS token id because tokenizer is not initialized" + ) return None - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + return self.tokenizer.bos_token_id - def get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_eos_token_id(self) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") + logger.warning_once( + "Using None for EOS token id because tokenizer is not initialized" + ) return None - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + return self.tokenizer.eos_token_id def get_decoder_start_token_id(self) -> Optional[int]: """ @@ -76,22 +95,26 @@ class InputPreprocessor: if not self.model_config.is_encoder_decoder: logger.warning_once( "Using None for decoder start token id because " - "this is not an encoder/decoder model.") + "this is not an encoder/decoder model." + ) return None if self.model_config is None or self.model_config.hf_config is None: logger.warning_once( "Using None for decoder start token id because " - "model config is not available.") + "model config is not available." + ) return None - dec_start_token_id = getattr(self.model_config.hf_config, - "decoder_start_token_id", None) + dec_start_token_id = getattr( + self.model_config.hf_config, "decoder_start_token_id", None + ) if dec_start_token_id is None: logger.warning_once( "Falling back on <BOS> for decoder start token " "id because decoder start token id is not " - "available.") + "available." + ) dec_start_token_id = self.get_bos_token_id() return dec_start_token_id @@ -161,8 +184,10 @@ class InputPreprocessor: # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if ( + len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id + ): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -187,14 +212,13 @@ class InputPreprocessor: def _tokenize_prompt( self, prompt: str, - lora_request: Optional[LoRARequest], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: """ Apply the model's tokenizer to a text prompt, returning the corresponding token IDs. 
""" - tokenizer = self.get_tokenizer_group() + tokenizer = self.get_tokenizer() tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) encoder_config = self.model_config.encoder_config @@ -202,50 +226,28 @@ class InputPreprocessor: if encoder_config and encoder_config.get("do_lower_case", False): prompt = prompt.lower() - return tokenizer.encode(prompt=prompt, - lora_request=lora_request, - **tokenization_kwargs) + return tokenizer.encode(prompt, **tokenization_kwargs) - async def _tokenize_prompt_async( - self, - prompt: str, - lora_request: Optional[LoRARequest], - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[int]: - """ - Async version of - [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. - """ - tokenizer = self.get_tokenizer_group() - tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - - return await tokenizer.encode_async(prompt=prompt, - lora_request=lora_request, - **tokenization_kwargs) - - def _get_mm_tokenizer( - self, - lora_request: Optional[LoRARequest], - ) -> AnyTokenizer: + def _get_mm_tokenizer(self) -> AnyTokenizer: # PrithviGeoSpatialMAE needs to be initialized without a tokenizer # while using also multi-modal input if not self.tokenizer: return cast(AnyTokenizer, object()) # Dummy - tokenizer_group = self.get_tokenizer_group() - return tokenizer_group.get_lora_tokenizer(lora_request) + tokenizer = self.get_tokenizer() + return tokenizer - async def _get_mm_tokenizer_async( - self, - lora_request: Optional[LoRARequest], - ) -> AnyTokenizer: - # PrithviGeoSpatialMAE needs to be initialized without a tokenizer - # while using also multi-modal input - if not self.tokenizer: - return cast(AnyTokenizer, object()) # Dummy + def _get_mm_processor(self) -> BaseMultiModalProcessor: + if not hasattr(self, "_mm_processor"): + tokenizer = self._get_mm_tokenizer() - tokenizer_group = self.get_tokenizer_group() - return await tokenizer_group.get_lora_tokenizer_async(lora_request) + self._mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + + return self._mm_processor def _process_multimodal( self, @@ -253,56 +255,48 @@ class InputPreprocessor: mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. 
""" - tokenizer = self._get_mm_tokenizer(lora_request) - - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self._get_mm_processor() if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs) + mm_input = mm_processor.apply( + prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + mm_hashes = mm_input["mm_hashes"] - async def _process_multimodal_async( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - mm_processor_kwargs: Optional[Mapping[str, object]], - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - ) -> MultiModalInputs: - """ - Async version of - [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. - """ - tokenizer = await self._get_mm_tokenizer_async(lora_request) + # Validate that all mm items have a string as their hash + contains_only_strings = all( + isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes) + ) + if not contains_only_strings: + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. " + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method." + ) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - - return mm_processor.apply(prompt, - mm_data, - hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs) + return mm_input def _process_embeds( self, parsed_content: EmbedsPrompt, ) -> EmbedsInputs: if not self.model_config.enable_prompt_embeds: - raise ValueError("You must set `--enable-prompt-embeds` to input " - "`prompt_embeds`.") + raise ValueError( + "You must set `--enable-prompt-embeds` to input `prompt_embeds`." + ) prompt_embeds = parsed_content["prompt_embeds"] @@ -314,70 +308,59 @@ class InputPreprocessor: prompt_embeds = prompt_embeds.squeeze(dim=0) if prompt_embeds.ndim != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") + raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).") - return embeds_inputs(prompt_embeds=prompt_embeds, - cache_salt=parsed_content.get("cache_salt")) + # Tensors must be on CPU for serialization between processes + # in the MsgpackEncoder. Casting to CPU here ensures that there is no + # hidden device transfer in the critical path of generation. 
+ prompt_embeds = prompt_embeds.cpu() - async def _process_embeds_async( - self, - parsed_content: EmbedsPrompt, - ) -> EmbedsInputs: - return self._process_embeds(parsed_content) + return embeds_inputs( + prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt") + ) + + def _truncate_inputs( + self, inputs: list[int], tokenization_kwargs: Optional[dict[str, Any]] = None + ) -> list[int]: + if ( + not tokenization_kwargs + or "truncation" not in tokenization_kwargs + or self.tokenizer is None + ): + return inputs + + max_length = tokenization_kwargs["max_length"] + + if self.tokenizer.truncation_side == "left": + return inputs[-max_length:] + else: + return inputs[:max_length] def _process_tokens( self, parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") + prompt_token_ids = self._truncate_inputs( + parsed_content["prompt_token_ids"], tokenization_kwargs + ) inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): + if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_token_ids, - multi_modal_data, - parsed_content.get("mm_processor_kwargs"), + parsed_content.get("multi_modal_data") or {}, + parsed_content.get("mm_processor_kwargs") or {}, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + if parsed_content.get("multi_modal_data"): + raise ValueError("This model does not support multimodal inputs") - if cache_salt := parsed_content.get("cache_salt"): - inputs["cache_salt"] = cache_salt - - return inputs - - async def _process_tokens_async( - self, - parsed_content: TokensPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - ) -> Union[TokenInputs, MultiModalInputs]: - prompt_token_ids = parsed_content["prompt_token_ids"] - token_type_ids = parsed_content.get("token_type_ids") - - inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): - inputs = await self._process_multimodal_async( - prompt_token_ids, - multi_modal_data, - parsed_content.get("mm_processor_kwargs"), - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - else: - inputs = token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) + inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -388,62 +371,29 @@ class InputPreprocessor: self, parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): + if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_text, - multi_modal_data, - parsed_content.get("mm_processor_kwargs"), + parsed_content.get("multi_modal_data") or {}, + parsed_content.get("mm_processor_kwargs") or {}, 
tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) else: + if parsed_content.get("multi_modal_data"): + raise ValueError("This model does not support multimodal inputs") + prompt_token_ids = self._tokenize_prompt( prompt_text, - lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) - inputs = token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - if cache_salt := parsed_content.get("cache_salt"): - inputs["cache_salt"] = cache_salt - - return inputs - - async def _process_text_async( - self, - parsed_content: TextPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - ) -> Union[TokenInputs, MultiModalInputs]: - prompt_text = parsed_content["prompt"] - - inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): - inputs = await self._process_multimodal_async( - prompt_text, - multi_modal_data, - parsed_content.get("mm_processor_kwargs"), - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - else: - prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - inputs = token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) + inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): inputs["cache_salt"] = cache_salt @@ -454,7 +404,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. @@ -462,7 +413,6 @@ class InputPreprocessor: Arguments: * prompt: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts Returns: @@ -475,53 +425,19 @@ class InputPreprocessor: if parsed["type"] == "tokens": return self._process_tokens( parsed["content"], - lora_request=lora_request, + mm_uuids=mm_uuids, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - - assert_never(parsed) - - async def _prompt_to_llm_inputs_async( - self, - prompt: SingletonPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - ) -> SingletonInputs: - """ - Async version of - [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs]. 
- """ - parsed = parse_singleton_prompt(prompt) - - if parsed["type"] == "embeds": - return await self._process_embeds_async(parsed["content"]) - if parsed["type"] == "tokens": - return await self._process_tokens_async( - parsed["content"], - lora_request=lora_request, - ) - if parsed["type"] == "text": - return await self._process_text_async( - parsed["content"], - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - if parsed["type"] == "str": - return await self._process_text_async( - TextPrompt(prompt=parsed["content"]), - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) assert_never(parsed) @@ -531,16 +447,20 @@ class InputPreprocessor: encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if (encoder_inputs["type"] == "embeds" - or decoder_inputs and decoder_inputs["type"] == "embeds"): - raise ValueError("Embedding inputs are not supported for encoder-" - "decoder models") + if ( + encoder_inputs["type"] == "embeds" + or decoder_inputs + and decoder_inputs["type"] == "embeds" + ): + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) # Needed for mypy - encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], - encoder_inputs) - decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]], - decoder_inputs) + encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], encoder_inputs) + decoder_inputs = cast( + Optional[Union[TokenInputs, MultiModalInputs]], decoder_inputs + ) if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": @@ -550,16 +470,18 @@ class InputPreprocessor: # overridden by the audio features. dec_token_ids = encoder_inputs["prompt_token_ids"].copy() else: - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None) + dec_token_ids = self._prepare_decoder_input_ids_for_generation(None) decoder_inputs = token_inputs(dec_token_ids) else: if "multi_modal_data" in decoder_inputs: - raise ValueError("Multi-modal decoder inputs of encoder-" - "decoder models are not supported yet") + raise ValueError( + "Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet" + ) dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] + ) decoder_inputs["prompt_token_ids"] = dec_token_ids return EncoderDecoderInputs( @@ -576,10 +498,14 @@ class InputPreprocessor: For encoder/decoder models only: Separate Encoder/Decoder inputs from a MultiModalEncDecInputs """ - if (inputs["type"] == "embeds" or decoder_inputs_to_override - and decoder_inputs_to_override["type"] == "embeds"): - raise ValueError("Embedding inputs are not supported for encoder-" - "decoder models") + if ( + inputs["type"] == "embeds" + or decoder_inputs_to_override + and decoder_inputs_to_override["type"] == "embeds" + ): + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) # Needed for mypy inputs = cast( @@ -595,22 +521,19 @@ class InputPreprocessor: decoder_inputs: SingletonInputs if inputs["type"] == "multimodal": # Multimodal data inputs - if not ("encoder_prompt" in inputs - and "encoder_prompt_token_ids" in inputs): - raise RuntimeError("You should register an encoder-decoder " - "multi-modal processor for encoder-decoder " - "models.") + if "encoder_prompt_token_ids" not in inputs: + raise RuntimeError( + "You should register an encoder-decoder " + "multi-modal processor 
for encoder-decoder " + "models." + ) inputs = cast(MultiModalEncDecInputs, inputs) - encoder_inputs = token_inputs( - prompt=inputs["encoder_prompt"], - prompt_token_ids=inputs["encoder_prompt_token_ids"], - ) + encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"]) decoder_prompt_inputs = decoder_inputs_to_override or inputs decoder_inputs = MultiModalInputs( type="multimodal", - prompt=decoder_prompt_inputs.get("prompt", ""), prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], @@ -620,7 +543,7 @@ class InputPreprocessor: decoder_inputs["cache_salt"] = cache_salt elif inputs["type"] == "token": # Text-only inputs - encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) + encoder_inputs = token_inputs(prompt_token_ids=[]) decoder_inputs = decoder_inputs_to_override or inputs else: assert_never(inputs) # type: ignore[arg-type] @@ -631,6 +554,8 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -669,80 +594,33 @@ class InputPreprocessor: decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): + # `cast` is needed for mypy, but not pyright + prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) encoder_inputs = self._prompt_to_llm_inputs( - prompt["encoder_prompt"], + prompt_["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := prompt_["decoder_prompt"]) is None: decoder_inputs = None else: decoder_inputs = self._prompt_to_llm_inputs(decoder_input) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(encoder_inputs, - decoder_inputs)) - else: - inputs = self._prompt_to_llm_inputs( - prompt, - tokenization_kwargs=tokenization_kwargs, - ) - if self.model_config.is_multimodal_model: - # Encoder-Decoder Multimodal model - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(inputs)) - else: - encoder_inputs = inputs - decoder_inputs = None - - return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) - - async def _process_encoder_decoder_prompt_async( - self, - prompt: PromptType, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> EncoderDecoderInputs: - """ - Async version of - [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt]. - """ - encoder_inputs: SingletonInputs - decoder_inputs: Optional[SingletonInputs] - - if is_explicit_encoder_decoder_prompt(prompt): - encoder_task = self._prompt_to_llm_inputs_async( - prompt["encoder_prompt"], - tokenization_kwargs=tokenization_kwargs, - ) - - if (decoder_input := prompt["decoder_prompt"]) is None: - encoder_inputs = await encoder_task - decoder_inputs = None - else: - decoder_task = self._prompt_to_llm_inputs_async( - decoder_input, - tokenization_kwargs=tokenization_kwargs, + encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs( + encoder_inputs, decoder_inputs ) - - encoder_inputs, decoder_inputs = await asyncio.gather( - encoder_task, decoder_task) - - # For multimodal model, override decoder prompt from processor - # with explicit decoder prompt. 
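
The encoder/decoder path above ultimately relies on `_prepare_decoder_input_ids_for_generation`: if the decoder prompt is empty or does not start with the decoder start token, the start token is prepended. A standalone sketch of that rule, with made-up token IDs and without the model-specific fallback to a default decoder prompt:

```python
# Standalone sketch of the decoder-prompt rule used by
# InputPreprocessor._prepare_decoder_input_ids_for_generation (see above).
# Token values are illustrative only; the real code also resolves the start
# token from the HF config and may fall back to <BOS>.
from typing import Optional


def prepare_decoder_input_ids(
    decoder_input_ids: Optional[list[int]],
    decoder_start_token_id: int,
) -> list[int]:
    if decoder_input_ids is None:
        # Simplified default decoder prompt (empty here for brevity).
        decoder_input_ids = []

    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

    return decoder_input_ids


assert prepare_decoder_input_ids(None, 0) == [0]
assert prepare_decoder_input_ids([5, 6], 0) == [0, 5, 6]
assert prepare_decoder_input_ids([0, 7], 0) == [0, 7]
```
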
- if self.model_config.is_multimodal_model: - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(encoder_inputs, - decoder_inputs)) else: - inputs = await self._prompt_to_llm_inputs_async( - prompt, + # `cast` is needed for mypy, but not pyright + inputs = self._prompt_to_llm_inputs( + cast(SingletonPrompt, prompt), tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(inputs)) + encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(inputs) else: encoder_inputs = inputs decoder_inputs = None @@ -754,8 +632,9 @@ class InputPreprocessor: prompt_inputs: DecoderOnlyInputs, ) -> DecoderOnlyInputs: if "prompt_token_ids" in prompt_inputs: - prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], - prompt_inputs) # Needed for mypy + prompt_inputs = cast( + Union[TokenInputs, MultiModalInputs], prompt_inputs + ) # Needed for mypy return prompt_inputs @@ -763,7 +642,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -773,7 +653,6 @@ class InputPreprocessor: Arguments: * prompt: input prompt - * lora_request Returns: @@ -783,80 +662,74 @@ class InputPreprocessor: prompt_comps = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) return self._build_decoder_only_llm_inputs(prompt_comps) - async def _process_decoder_only_prompt_async( - self, - prompt: SingletonPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, - ) -> DecoderOnlyInputs: - """ - Async version of - [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt]. - """ - prompt_comps = await self._prompt_to_llm_inputs_async( - prompt, - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs(prompt_comps) - - def preprocess( + def _preprocess( self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: - """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder. return self._process_encoder_decoder_prompt( prompt, tokenization_kwargs, + mm_uuids=mm_uuids, ) if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") + raise ValueError( + "Cannot pass encoder-decoder prompt to decoder-only models" + ) # Decoder-only operation + # `cast` is needed for mypy, but not pyright return self._process_decoder_only_prompt( - prompt, + cast(SingletonPrompt, prompt), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) - async def preprocess_async( + def preprocess( self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: - """ - Async version of - [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. 
- """ - if self.model_config.is_encoder_decoder: - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder. - return await self._process_encoder_decoder_prompt_async( - prompt, - tokenization_kwargs, - ) - - if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - return await self._process_decoder_only_prompt_async( + """Preprocess the input prompt.""" + res = self._preprocess( prompt, - tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + tokenization_kwargs, + mm_uuids=mm_uuids, ) + + if self.mm_processor_cache and self.mm_cache_stats is not None: + delta = self.mm_processor_cache.make_stats(delta=True) + self.mm_cache_stats.requests += 1 + self.mm_cache_stats.queries += delta.total + self.mm_cache_stats.hits += delta.hits + + return res + + def stat_mm_cache(self) -> Optional[MultiModalCacheStats]: + mm_cache_stats = self.mm_cache_stats + if mm_cache_stats is None: + return None + + self.mm_cache_stats = MultiModalCacheStats() + + return mm_cache_stats + + def clear_mm_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() + + if self.mm_cache_stats is not None: + self.mm_cache_stats.reset = True diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py deleted file mode 100644 index ef146fdfbf97c..0000000000000 --- a/vllm/inputs/registry.py +++ /dev/null @@ -1,245 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union - -import torch -from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar - -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.utils import get_allowed_kwarg_only_overrides -from vllm.utils.jsontree import JSONTree, json_map_leaves - -if TYPE_CHECKING: - from vllm.config import ModelConfig - from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, - MultiModalRegistry) - from vllm.sequence import SequenceData - from vllm.transformers_utils.tokenizer import AnyTokenizer -else: - ModelConfig = Any - MultiModalDataDict = Any - MultiModalPlaceholderDict = Any - MultiModalRegistry = Any - SequenceData = Any - AnyTokenizer = Any - -_T = TypeVar("_T") -_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) -_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) - -logger = init_logger(__name__) - - -@dataclass(frozen=True) -class InputContext: - """ - Contains information about the model which may be used to - modify the inputs. - """ - - model_config: ModelConfig - """The configuration of the model.""" - - def get_hf_config( - self, - typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig, - /, - ) -> _C: - """ - Get the HuggingFace configuration - (`transformers.PretrainedConfig`) of the model, - additionally checking its type. - - Raises: - TypeError: If the configuration is not of the specified type. - """ - hf_config = self.model_config.hf_config - if not isinstance(hf_config, typ): - raise TypeError("Invalid type of HuggingFace config. 
" - f"Expected type: {typ}, but " - f"found type: {type(hf_config)}") - - return hf_config - - def get_hf_image_processor_config(self) -> dict[str, Any]: - """ - Get the HuggingFace image processor configuration of the model. - """ - return self.model_config.hf_image_processor_config - - def get_mm_config(self): - """ - Get the multimodal config of the model. - - Raises: - RuntimeError: If the model is not a multimodal model. - """ - mm_config = self.model_config.multimodal_config - if mm_config is None: - raise RuntimeError("Not a multimodal model") - - return mm_config - - def get_hf_processor( - self, - typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, - /, - **kwargs: object, - ) -> _P: - """ - Get the HuggingFace processor - (`transformers.ProcessorMixin`) of the model, - additionally checking its type. - - Raises: - TypeError: If the processor is not of the specified type. - """ - return cached_processor_from_config( - self.model_config, - processor_cls=typ, - **kwargs, - ) - - def init_processor( - self, - typ: type[_T], - /, - **kwargs: object, - ) -> _T: - """ - Initialize a HuggingFace-like processor class, merging the - keyword arguments with those in the model's configuration. - """ - mm_config = self.model_config.get_multimodal_config() - base_kwargs = mm_config.mm_processor_kwargs - if base_kwargs is None: - base_kwargs = {} - - merged_kwargs = {**base_kwargs, **kwargs} - - return typ(**merged_kwargs) - - -@dataclass(frozen=True) -class InputProcessingContext(InputContext): - tokenizer: AnyTokenizer - """The tokenizer used to tokenize the inputs.""" - - def get_hf_processor( - self, - typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, - /, - **kwargs: object, - ) -> _P: - return super().get_hf_processor( - typ, - tokenizer=self.tokenizer, - **kwargs, - ) - - def call_hf_processor( - self, - hf_processor: ProcessorMixin, - data: Mapping[str, object], - kwargs: Mapping[str, object] = {}, - ) -> Union[BatchFeature, JSONTree]: - """ - Call `hf_processor` on the prompt `data` - (text, image, audio...) with configurable options `kwargs`. - """ - assert callable(hf_processor) - - mm_config = self.model_config.get_multimodal_config() - merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) - - allowed_kwargs = get_allowed_kwarg_only_overrides( - hf_processor, - merged_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - - def maybe_cast_dtype(x): - # This mimics the behavior of transformers.BatchFeature - if isinstance(x, torch.Tensor) and x.is_floating_point(): - return x.to(dtype=self.model_config.dtype) - return x - - try: - output = hf_processor(**data, - **allowed_kwargs, - return_tensors="pt") - # this emulates output.to(dtype=self.model_config.dtype) - if isinstance(output, BatchFeature): - cast_output = json_map_leaves(maybe_cast_dtype, output.data) - return BatchFeature(cast_output) - - cast_output = json_map_leaves(maybe_cast_dtype, output) - - logger.warning_once( - f"{type(hf_processor).__name__} did not return `BatchFeature`. " - "Make sure to match the behaviour of `ProcessorMixin` when " - "implementing custom processors.") - return cast_output - - except Exception as exc: - msg = (f"Failed to apply {type(hf_processor).__name__} " - f"on data={data} with kwargs={allowed_kwargs}") - - raise ValueError(msg) from exc - - -class DummyData(NamedTuple): - """ - Dummy data used for profiling. - - Note: This is only used in V0. 
- """ - - seq_data: SequenceData - multi_modal_data: Optional[MultiModalDataDict] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - - -class InputRegistry: - """ - Note: This is only used in V0. - """ - - def dummy_data_for_profiling( - self, - model_config: ModelConfig, - seq_len: int, - mm_registry: MultiModalRegistry, - is_encoder_data: bool = False, - ) -> DummyData: - """ - Create dummy data for profiling the memory usage of a model. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.sequence import SequenceData - - if not model_config.is_multimodal_model: - seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - return DummyData(seq_data=seq_data) - - # Encoder dummy data does not contain multi-modal data - if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) - seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) - return DummyData(seq_data=seq_data) - - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) - - return DummyData( - seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data.get_data(), - multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) diff --git a/vllm/logger.py b/vllm/logger.py index 8f06eb03c7f93..37e8495768c04 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Logging configuration for vLLM.""" + import datetime import json import logging @@ -20,9 +21,12 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX +VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM -_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " - "[%(filename)s:%(lineno)d] %(message)s") +_FORMAT = ( + f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " + "[%(fileinfo)s:%(lineno)d] %(message)s" +) _DATE_FORMAT = "%m-%d %H:%M:%S" DEFAULT_LOGGING_CONFIG = { @@ -38,7 +42,7 @@ DEFAULT_LOGGING_CONFIG = { "class": "logging.StreamHandler", "formatter": "vllm", "level": VLLM_LOGGING_LEVEL, - "stream": "ext://sys.stdout", + "stream": VLLM_LOGGING_STREAM, }, }, "loggers": { @@ -49,7 +53,7 @@ DEFAULT_LOGGING_CONFIG = { }, }, "version": 1, - "disable_existing_loggers": False + "disable_existing_loggers": False, } @@ -118,7 +122,8 @@ def _configure_vllm_root_logger() -> None: "VLLM_CONFIGURE_LOGGING evaluated to false, but " "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " "implies VLLM_CONFIGURE_LOGGING. Please enable " - "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH." + ) if VLLM_CONFIGURE_LOGGING: logging_config = DEFAULT_LOGGING_CONFIG @@ -127,13 +132,16 @@ def _configure_vllm_root_logger() -> None: if not path.exists(VLLM_LOGGING_CONFIG_PATH): raise RuntimeError( "Could not load logging config. File does not exist: %s", - VLLM_LOGGING_CONFIG_PATH) + VLLM_LOGGING_CONFIG_PATH, + ) with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: custom_config = json.loads(file.read()) if not isinstance(custom_config, dict): - raise ValueError("Invalid logging config. Expected dict, got %s.", - type(custom_config).__name__) + raise ValueError( + "Invalid logging config. 
Expected dict, got %s.", + type(custom_config).__name__, + ) logging_config = custom_config for formatter in logging_config.get("formatters", {}).values(): @@ -167,7 +175,7 @@ logger = init_logger(__name__) def _trace_calls(log_path, root_dir, frame, event, arg=None): - if event in ['call', 'return']: + if event in ["call", "return"]: # Extract the filename, line number, function name, and the code object filename = frame.f_code.co_filename lineno = frame.f_lineno @@ -187,26 +195,29 @@ def _trace_calls(log_path, root_dir, frame, event, arg=None): last_filename = "" last_lineno = 0 last_func_name = "" - with open(log_path, 'a') as f: + with open(log_path, "a") as f: ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") - if event == 'call': - f.write(f"{ts} Call to" - f" {func_name} in {filename}:{lineno}" - f" from {last_func_name} in {last_filename}:" - f"{last_lineno}\n") + if event == "call": + f.write( + f"{ts} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) else: - f.write(f"{ts} Return from" - f" {func_name} in {filename}:{lineno}" - f" to {last_func_name} in {last_filename}:" - f"{last_lineno}\n") + f.write( + f"{ts} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) except NameError: # modules are deleted during shutdown pass return partial(_trace_calls, log_path, root_dir) -def enable_trace_function_call(log_file_path: str, - root_dir: Optional[str] = None): +def enable_trace_function_call(log_file_path: str, root_dir: Optional[str] = None): """ Enable tracing of every function call in code under `root_dir`. This is useful for debugging hangs or crashes. @@ -220,7 +231,8 @@ def enable_trace_function_call(log_file_path: str, logger.warning( "VLLM_TRACE_FUNCTION is enabled. It will record every" " function executed by Python. This will slow down the code. It " - "is suggested to be used for debugging hang or crashes only.") + "is suggested to be used for debugging hang or crashes only." 
+ ) logger.info("Trace frame log is saved to %s", log_file_path) if root_dir is None: # by default, this is the vllm root directory diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py index cf690a89ae9bc..7202259ca21aa 100644 --- a/vllm/logging_utils/__init__.py +++ b/vllm/logging_utils/__init__.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.logging_utils.formatter import NewLineFormatter +from vllm.logging_utils.log_time import logtime __all__ = [ "NewLineFormatter", + "logtime", ] diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index ad89638e10614..3a97000647d60 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -21,9 +21,10 @@ def prepare_object_to_dump(obj) -> str: if isinstance(obj, str): return f"'{obj}'" # Double quotes elif isinstance(obj, dict): - dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ - for k, v in obj.items()}) - return f'{{{dict_str}}}' + dict_str = ", ".join( + {f"{str(k)}: {prepare_object_to_dump(v)}" for k, v in obj.items()} + ) + return f"{{{dict_str}}}" elif isinstance(obj, list): return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" elif isinstance(obj, set): @@ -36,15 +37,14 @@ def prepare_object_to_dump(obj) -> str: elif isinstance(obj, torch.Tensor): # We only print the 'draft' of the tensor to not expose sensitive data # and to get some metadata in case of CUDA runtime crashed - return (f"Tensor(shape={obj.shape}, " - f"device={obj.device}," - f"dtype={obj.dtype})") - elif hasattr(obj, 'anon_repr'): + return f"Tensor(shape={obj.shape}, device={obj.device},dtype={obj.dtype})" + elif hasattr(obj, "anon_repr"): return obj.anon_repr() - elif hasattr(obj, '__dict__'): + elif hasattr(obj, "__dict__"): items = obj.__dict__.items() - dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \ - for k, v in items]) + dict_str = ", ".join( + [f"{str(k)}={prepare_object_to_dump(v)}" for k, v in items] + ) return f"{type(obj).__name__}({dict_str})" else: # Hacky way to make sure we can serialize the object in JSON format @@ -54,18 +54,22 @@ def prepare_object_to_dump(obj) -> str: return repr(obj) -def dump_engine_exception(config: VllmConfig, - scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats]): +def dump_engine_exception( + config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats], +): # NOTE: ensure we can log extra info without risking raises # unexpected errors during logging with contextlib.suppress(Exception): _dump_engine_exception(config, scheduler_output, scheduler_stats) -def _dump_engine_exception(config: VllmConfig, - scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats]): +def _dump_engine_exception( + config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats], +): logger.error( "Dumping input data for V1 LLM engine (v%s) with config: %s, ", VLLM_VERSION, @@ -73,8 +77,7 @@ def _dump_engine_exception(config: VllmConfig, ) try: dump_obj = prepare_object_to_dump(scheduler_output) - logger.error("Dumping scheduler output for model execution: %s", - dump_obj) + logger.error("Dumping scheduler output for model execution: %s", dump_obj) if scheduler_stats: logger.error("Dumping scheduler stats: %s", scheduler_stats) except Exception: diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py index 0affef10078dc..02ba308e18796 100644 
--- a/vllm/logging_utils/formatter.py +++ b/vllm/logging_utils/formatter.py @@ -2,16 +2,75 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +from pathlib import Path + +from vllm import envs class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" def __init__(self, fmt, datefmt=None, style="%"): - logging.Formatter.__init__(self, fmt, datefmt, style) + super().__init__(fmt, datefmt, style) + + self.use_relpath = envs.VLLM_LOGGING_LEVEL == "DEBUG" + if self.use_relpath: + self.root_dir = Path(__file__).resolve().parent.parent.parent def format(self, record): - msg = logging.Formatter.format(self, record) + def shrink_path(relpath: Path) -> str: + """ + Shortens a file path for logging display: + - Removes leading 'vllm' folder if present. + - If path starts with 'v1', + keeps the first two and last two levels, + collapsing the middle as '...'. + - Otherwise, keeps the first and last two levels, + collapsing the middle as '...'. + - If the path is short, returns it as-is. + - Examples: + vllm/model_executor/layers/quantization/utils/fp8_utils.py -> + model_executor/.../quantization/utils/fp8_utils.py + vllm/model_executor/layers/quantization/awq.py -> + model_executor/layers/quantization/awq.py + vllm/v1/attention/backends/mla/common.py -> + v1/attention/backends/mla/common.py + + Args: + relpath (Path): The relative path to be shortened. + Returns: + str: The shortened path string for display. + """ + parts = list(relpath.parts) + new_parts = [] + if parts and parts[0] == "vllm": + parts = parts[1:] + if parts and parts[0] == "v1": + new_parts += parts[:2] + parts = parts[2:] + elif parts: + new_parts += parts[:1] + parts = parts[1:] + if len(parts) > 2: + new_parts += ["..."] + parts[-2:] + else: + new_parts += parts + return "/".join(new_parts) + + if self.use_relpath: + abs_path = getattr(record, "pathname", None) + if abs_path: + try: + relpath = Path(abs_path).resolve().relative_to(self.root_dir) + except Exception: + relpath = Path(record.filename) + else: + relpath = Path(record.filename) + record.fileinfo = shrink_path(relpath) + else: + record.fileinfo = record.filename + + msg = super().format(record) if record.message != "": parts = msg.split(record.message) msg = msg.replace("\n", "\r\n" + parts[0]) diff --git a/vllm/logging_utils/log_time.py b/vllm/logging_utils/log_time.py new file mode 100644 index 0000000000000..9e94f463711d3 --- /dev/null +++ b/vllm/logging_utils/log_time.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Provides a timeslice logging decorator +""" + +import functools +import time + + +def logtime(logger, msg=None): + """ + Logs the execution time of the decorated function. + Always place it beneath other decorators. 
+ """ + + def _inner(func): + @functools.wraps(func) + def _wrapper(*args, **kwargs): + start = time.perf_counter() + result = func(*args, **kwargs) + elapsed = time.perf_counter() - start + + prefix = ( + f"Function '{func.__module__}.{func.__qualname__}'" + if msg is None + else msg + ) + logger.debug("%s: Elapsed time %.7f secs", prefix, elapsed) + return result + + return _wrapper + + return _inner diff --git a/vllm/logits_process.py b/vllm/logits_process.py index 5967d0836bd45..6ac30ae0028e9 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from collections.abc import Sequence from typing import Callable, Union import torch @@ -19,8 +19,8 @@ to sample from.""" def get_bad_words_logits_processors( - bad_words: list[str], - tokenizer: AnyTokenizer) -> list[LogitsProcessor]: + bad_words: list[str], tokenizer: AnyTokenizer +) -> list[LogitsProcessor]: bad_words_ids: list[list[int]] = list() for bad_word in bad_words: @@ -31,15 +31,15 @@ def get_bad_words_logits_processors( prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False) # If no space at the beginning # or if prefix space produces a new word token if (not add_prefix_space) or ( - add_prefix_space - and prompt_token_ids[0] != bad_words_ids[-1][0] - and len(prompt_token_ids) == len(bad_words_ids[-1])): + add_prefix_space + and prompt_token_ids[0] != bad_words_ids[-1][0] + and len(prompt_token_ids) == len(bad_words_ids[-1]) + ): bad_words_ids.append(prompt_token_ids) return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] @@ -55,7 +55,7 @@ class NoBadWordsLogitsProcessor: def __call__( self, - past_tokens_ids: Union[list[int], tuple[int]], + past_tokens_ids: Sequence[int], logits: torch.FloatTensor, ) -> torch.Tensor: if self.word_bias is None: @@ -78,8 +78,9 @@ class NoBadWordsLogitsProcessor: assert len(actual_prefix) == len(expected_prefix) is_match = tuple(actual_prefix) == tuple(expected_prefix) - last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match - else self._NEUTRAL_LOGIT) + last_token_bias[last_token_id] += ( + self._SMALLEST_LOGIT if is_match else self._NEUTRAL_LOGIT + ) logits = logits + self.word_bias + last_token_bias @@ -93,9 +94,9 @@ class NoBadWordsLogitsProcessor: self._check_token_ids_bounds(vocab_size=vocab_size) - self.word_bias = torch.zeros((vocab_size, ), - dtype=torch.float, - device=logits.device) + self.word_bias = torch.zeros( + (vocab_size,), dtype=torch.float, device=logits.device + ) for bad_word_ids in self.bad_words_ids: if len(bad_word_ids) == 1: @@ -116,4 +117,5 @@ class NoBadWordsLogitsProcessor: f" but the following tokens" f" were specified as bad: {invalid_token_ids}." f" All token id values should be integers satisfying:" - f" 0 <= token_id < {vocab_size}.") + f" 0 <= token_id < {vocab_size}." + ) diff --git a/vllm/logprobs.py b/vllm/logprobs.py new file mode 100644 index 0000000000000..2458e43c690f6 --- /dev/null +++ b/vllm/logprobs.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + + +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. 
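
The new `vllm/logging_utils/log_time.py` above adds a `logtime` decorator that reports wall-clock time at DEBUG level. A small usage sketch, with a hypothetical logger name and placeholder function; the timing line only appears when DEBUG logging is active for that logger (e.g. `VLLM_LOGGING_LEVEL=DEBUG`):

```python
# Usage sketch for the new `logtime` decorator shown above. The logger name
# "vllm.example" and the tokenize_batch function are placeholders.
from vllm.logger import init_logger
from vllm.logging_utils import logtime

logger = init_logger("vllm.example")  # hypothetical logger under the vllm namespace


@logtime(logger, msg="tokenize batch")  # place it beneath other decorators
def tokenize_batch(prompts: list[str]) -> list[list[int]]:
    # Placeholder "tokenization": hash whitespace-separated pieces.
    return [[hash(tok) % 32000 for tok in p.split()] for p in prompts]


tokenize_batch(["hello world", "vllm input preprocessing"])
# At DEBUG level this emits: "tokenize batch: Elapsed time 0.0000xxx secs"
```
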
+@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs and token ranks. + + Attributes: + logprob: The logprob of chosen token + rank: The vocab rank of chosen token (>=1) + decoded_token: The decoded chosen token index + """ + + logprob: float + rank: Optional[int] = None + decoded_token: Optional[str] = None + + +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. +PromptLogprobs = list[Optional[dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. +SampleLogprobs = list[dict[int, Logprob]] diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py deleted file mode 100644 index 7fc4cfe026aee..0000000000000 --- a/vllm/lora/fully_sharded_layers.py +++ /dev/null @@ -1,355 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import LoRAConfig -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - RowParallelLinearWithLoRA) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - pass - - -def _fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - return (can_replace(*args, **kwargs) - and kwargs["lora_config"].fully_sharded_loras) - - return dec - - -def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): - """ - For `ColumnParallelLinearWithLoRA` or classes that inherit from - `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. - """ - assert (layer.n_slices == len(layer.lora_a_stacked) == len( - layer.lora_b_stacked) == len(layer.output_slices)) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) - - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - - # Since communication is needed, the buffer is directly initialized as a - # tensor rather than a tuple of tensor. - buffers = torch.zeros( - (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( - buffers, x, layer.lora_a_stacked, 1.0) - - if not current_platform.can_update_inplace(): - buffers = shrunk_buffers - - buffers = tensor_model_parallel_all_gather(buffers) - - lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( - output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - -# these layers are based on the tensor parallelism strategy given in -# Y. 
Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, -# https://arxiv.org/abs/2311.03285. - - -class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): - """ - Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, - # their `lora_a` and `lora_b` have different sharding patterns. After - # completing the `lora_a` GEMM , a gather operation is performed. - # Therefore, the sharding of `lora_a` only needs to correspond with the - # gather operation. - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedColumnParallelLinearWithShardedLoRA( - MergedColumnParallelLinearWithLoRA): - """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - #NOTE: lora_a contains 2 subloras, and each sublora could be None. - output_shard_size = self.lora_a_stacked[0].shape[2] - output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, - lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): - """ - Differs from QKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. 
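The "slicing along the rank dim" that these fully sharded variants describe can be checked with a small standalone sketch. Shapes are toy values, and the (input_size, rank) orientation follows the slice_lora_a code shown above rather than any other part of the codebase.

import torch

tp_size = 4
full_rank = 16                            # LoRA rank r
hidden = 32                               # input size of the wrapped linear layer
lora_a = torch.randn(hidden, full_rank)   # (input_size, rank), as assumed above

shard = full_rank // tp_size              # each TP rank keeps r / tp_size of the rank dim
shards = [lora_a[:, r * shard:(r + 1) * shard] for r in range(tp_size)]

assert all(s.shape == (hidden, shard) for s in shards)
# Concatenating the per-rank slices along the rank dim recovers the full LoRA A,
# which is why the all-gather after the shrink GEMM in _mcp_apply above
# reconstructs the full intermediate activation.
assert torch.equal(torch.cat(shards, dim=1), lora_a)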
- """ - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): - """ - Differs from MergedQKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - # NOTE: lora_a contains 3 subloras, and each sublora could be None. - shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] - start_idx = [self.tp_rank * shard_size[i] for i in range(3)] - lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): - """ - Differs from RowParallelLinearWithLoRA by slicing the - LoRA B's also. - - Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base - layer and column partitioned output from the LoRA. 
- """ - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( - buffer, x, self.lora_a_stacked, 1.0) - if not current_platform.can_update_inplace(): - buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) - - # following S-LoRA, allows the fusing of all_gather and all_reduce - # by adding the column partitioned lora output to a slice of output - # tensor, which is a partial sum due to row parallel. All that - # remains is a standard all_reduce. User should be aware though that - # the output is not the same as a normal row_parallel, it should be - # reduced before being used - # NOTE offset are based on the rank. - shard_size = self.lora_b_stacked[0].shape[2] - offset_start = self.tp_rank * shard_size - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( - output, - buffer, - self.lora_b_stacked, - self.lora_bias_stacked, - self.output_slices, - offset_start=offset_start, - add_input=True, - ) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - return output - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py deleted file mode 100644 index 24a05d310d108..0000000000000 --- a/vllm/lora/layers.py +++ /dev/null @@ -1,1192 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PretrainedConfig - -from vllm.adapter_commons.layers import AdapterMapping -from vllm.config import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.distributed.utils import divide -# yapf: disable -from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - from vllm.lora.punica_wrapper import PunicaWrapperBase - - -def _get_lora_device(base_layer: nn.Module) -> torch.device: - # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 - """Returns the device for where to place the LoRA tensors.""" - # unquantizedLinear - if hasattr(base_layer, "weight"): - return base_layer.weight.device - # Compressed Tensor - elif hasattr(base_layer, "weight_packed"): - return base_layer.weight_packed.device - # GPTQ/AWQ - elif hasattr(base_layer, "qweight"): - return base_layer.qweight.device - # HQQ marlin - elif hasattr(base_layer, "W_q"): - return base_layer.W_q.device - else: - raise ValueError(f"Unsupported base layer: {base_layer}") - - -def _not_fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of not using fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - decorate = kwargs.pop("decorate") if "decorate" in kwargs else True - condition = (not kwargs["lora_config"].fully_sharded_loras - if decorate else True) - return can_replace(*args, **kwargs) and condition - - return dec - - -@dataclass -class LoRAMapping(AdapterMapping): - is_prefill: bool = False - - -class BaseLayerWithLoRA(nn.Module): - - def slice_lora_a( - self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora a if splitting for tensor parallelism.""" - ... - - def slice_lora_b( - self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora b if splitting with tensor parallelism.""" - ... - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """Initializes lora matrices.""" - ... - - def reset_lora(self, index: int): - """Resets the lora weights at index back to 0.""" - ... - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - """Overwrites lora tensors at index.""" - ... 
- - def set_mapping( - self, - punica_wrapper, - ): - self.punica_wrapper: PunicaWrapperBase = punica_wrapper - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - """Returns True if the layer can be replaced by this LoRA layer.""" - raise NotImplementedError - - -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: VocabParallelEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - self.embeddings_slice: Optional[tuple[int, int]] - self.embeddings_weights: Optional[torch.Tensor] - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: - - if self.base_layer.num_added_embeddings_per_partition > 0: - # We can start adding lora weights - self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. - base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] - self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) - self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) - else: - self.embeddings_slice = None - self.embeddings_weights = None - - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked = torch.zeros( - ( - max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - self.base_layer.embedding_dim, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked_2d = self.lora_a_stacked.view( - self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], - self.lora_a_stacked.shape[2], - ) - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - - def forward(self, x: torch.Tensor) -> 
torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, - 1, 0) - - # NB: Don't use torch.narrow here. torch.narrow triggers some - # Dynamic Shape specialization in torch.compile - num_tokens = x.shape[0] - indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] - - full_lora_a_embeddings = F.embedding( - x + indices_1, - self.lora_a_stacked_2d, - ) - full_output = self.base_layer.forward(x + - (indices_0 * added_tokens_mask)) - - full_output_org = full_output - if full_output.ndim == 3: - full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) - if full_lora_a_embeddings.ndim == 3: - full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], - -1, - ) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) - - if not current_platform.can_update_inplace(): - full_output = lora_output - - return full_output.view_as(full_output_org) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is VocabParallelEmbedding - - @property - def weight(self): - return self.base_layer.weight - - -class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: LinearBase): - super().__init__() - self.base_layer = base_layer - self.input_size = self.base_layer.input_size - self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - - self.output_slices: tuple[int, ...] 
- self.tp_size: int - self.output_size: int - self.n_slices: int - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - # - if isinstance(self.base_layer, ReplicatedLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, ColumnParallelLinear): - lora_a_out_size = (lora_config.max_lora_rank if - not lora_config.fully_sharded_loras else divide( - lora_config.max_lora_rank, self.tp_size)) - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, RowParallelLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = (self.output_size if - not lora_config.fully_sharded_loras else divide( - self.output_size, self.tp_size)) - else: - raise NotImplementedError - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_out_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_b_out_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) - - def reset_lora(self, index: int): - for s_index in range(self.n_slices): - self.lora_a_stacked[s_index][index] = 0 - self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[s_index][index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - # Except for QKVParallelLinearWithLoRA and - # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers - # store weights in a tuple of size 1. These two layers will - # override this function. - assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) - - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - # In transformers backend, x and output have extra batch dimension like - # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), - # therefore we need to flatten the batch dimensions. 
- if x.ndim == 3 and output.ndim == 3: - output = output.flatten(0, 1) - x = x.flatten(0, 1) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) - if not current_platform.can_update_inplace(): - output = lora_output - - return output - - @property - def weight(self) -> torch.Tensor: - - # unquantizedLinear - if hasattr(self.base_layer, "weight"): - return self.base_layer.weight - # Compressed Tensor - elif hasattr(self.base_layer, "weight_packed"): - return self.base_layer.weight_packed - # GPTQ/AWQ - elif hasattr(self.base_layer, "qweight"): - return self.base_layer.qweight - # marlin - elif hasattr(self.base_layer, "B"): - return self.base_layer.B - # HQQ marlin - elif hasattr(self.base_layer, "W_q"): - return self.base_layer.W_q - else: - raise ValueError(f"Unsupported base layer: {self.base_layer}") - - @property - def bias(self) -> Optional[torch.Tensor]: - if hasattr(self.base_layer, "bias"): - return self.base_layer.bias - else: - return None - - -class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: ReplicatedLinear) -> None: - super().__init__(base_layer, ) - # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 - self.output_size = self.base_layer.output_size - self.n_slices = 1 - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ReplicatedLinearWithLoRA - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output = self.apply(input_, bias) - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - # ReplicatedLinear should always be replaced, regardless of the fully - # sharded LoRAs setting, because it is, by definition, copied per GPU. - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ReplicatedLinear - - -class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - """ - LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. - There are two types for the `base_layer`: - 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. - 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. - """ - - def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__(base_layer) - # The base_layer type is ColumnParallelLinear or - # MergedColumnParallelLinear, their weight sharding logic is - # inconsistent when TP is greater than 1. - self.is_merged_col_linear = type( - base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size = self.base_layer.output_size_per_partition - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - # Applicable to cases where the base_layer is - # MergedColumnParallelLinear. 
- if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 - - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) - # Applicable to cases where the base_layer is - # ColumnParallelLinear. - else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. - if bias is None: - return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - - if not self.base_layer.return_bias: - return output - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 1) - - -class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (eg. gate_proj + up_proj -> gate_up_proj). - - This means we have 2 LoRAs, each applied to one half of the layer. - - Both slices must have the same size. - """ - - def __init__( - self, base_layer: Union[MergedColumnParallelLinear, - QKVParallelLinear]) -> None: - super().__init__(base_layer) - # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - # the output_sizes in MergedColumnParallelLinear is not sharded by tp - # we need to divide it by the tp_size to get correct slices size - output_sizes = self.base_layer.output_sizes - self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes) - self.n_slices = len(self.output_slices) - self.output_ids = (self.tp_rank, ) * self.n_slices - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overriding this function is to enhance code - maintainability. 
- """ - self.lora_config = lora_config - - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - sliced_lora_b = [None] * self.n_slices - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] - return sliced_lora_b - - def slice_bias( - self, bias: list[Union[torch.Tensor, - None]]) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id:shard_size * - (shard_id + 1)] - return bias - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - for i in range(self.n_slices): - if (lora_a_i := lora_a[i]) is not None: - self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) - if (lora_b_i := lora_b[i]) is not None: - self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, - non_blocking=True) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2) - - -class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chatglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. - - During inference with Tensor Parallel, the weights of lora_b - must be accurately partitioned according to the respective ranks. - - Q slice may have different shape than K and V slices (which both have - the same shape). 
- """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - self.q_proj_total_size = (self.base_layer.total_num_heads * - self.base_layer.head_size) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * - self.base_layer.head_size) - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - bias_k = bias[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 1 - - -class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): - """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) - packed together in qkv proj fashion - (q_proj + k_proj + v_proj -> qkv_proj). - - This means we have 3 LoRAs, each applied to one slice of the layer. - - Q slice may have different shape than K and V slices (which both have - the same shape). - """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - # There are three LoRA layer. 
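A standalone sketch of the q/k/v column arithmetic used by slice_lora_b above, with toy GQA sizes (all numbers illustrative). It only recomputes the offsets locally and does not touch any vLLM objects.

# Toy GQA configuration (illustrative numbers only).
head_size = 8
total_num_heads, total_num_kv_heads = 16, 4
tp_size = 8
num_heads = total_num_heads // tp_size                  # 2 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)    # 1 KV head per rank
num_kv_head_replicas = tp_size // total_num_kv_heads    # 2 ranks share one KV head

q_proj_total_size = total_num_heads * head_size         # 128
q_proj_shard_size = num_heads * head_size               # 16
kv_proj_total_size = total_num_kv_heads * head_size     # 32
kv_proj_shard_size = num_kv_heads * head_size           # 8


def qkv_column_ranges(tp_rank: int) -> list[tuple[int, int]]:
    """Column ranges of the full (unsharded) qkv LoRA B kept by this rank."""
    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas
    q = (q_proj_shard_size * q_shard_id, q_proj_shard_size * (q_shard_id + 1))
    k_off = q_proj_total_size
    k = (k_off + kv_proj_shard_size * kv_shard_id,
         k_off + kv_proj_shard_size * (kv_shard_id + 1))
    v_off = k_off + kv_proj_total_size
    v = (v_off + kv_proj_shard_size * kv_shard_id,
         v_off + kv_proj_shard_size * (kv_shard_id + 1))
    return [q, k, v]


# Ranks 0 and 1 read different Q columns but the same K/V columns,
# because two ranks replicate each KV head in this toy setup.
assert qkv_column_ranges(0) == [(0, 16), (128, 136), (160, 168)]
assert qkv_column_ranges(1) == [(16, 32), (128, 136), (160, 168)]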
- self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.q_shard_id = self.tp_rank - self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - - self.output_slices = ( - self.q_proj_shard_size, - self.kv_proj_shard_size, - self.kv_proj_shard_size, - ) - self.output_ids = ( - self.q_shard_id, - self.kv_shard_id, - self.kv_shard_id, - ) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - super().create_lora_weights(max_loras, lora_config, model_config) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) - - -#TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass - - -class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__(base_layer) - - self.tp_size = get_tensor_model_parallel_world_size() - # reset input_size - self.input_size = self.base_layer.input_size_per_partition - self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() - # There is only one LoRA layer. - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - - shard_size = self.input_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # set up backprop all-reduce. - if self.base_layer.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() - - # Matrix multiply. 
- output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.base_layer.skip_bias_add: - output = (output_ + self.base_layer.bias - if self.base_layer.bias is not None else output_) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is RowParallelLinear - - -class LogitsProcessorWithLoRA(BaseLayerWithLoRA): - """ - LoRA wrapper for LogitsProcessor, with extra logic to handle the - application of the LoRA adapter and added LoRA vocabulary. - - Args: - base_layer: LogitsProcessor layer - hidden_size: hidden size of the model - dtype: data type of the model - device: device of the model - sharded_to_full_mapping: index mapping from sharded vocab to full vocab - received from base_layer.get_sharded_to_full_mapping(). If None, - no reindexing will be done. - """ - - def __init__(self, base_layer: LogitsProcessor, hidden_size: int, - dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]]) -> None: - super().__init__() - self.base_layer = base_layer - self.hidden_size = hidden_size - self.dtype = dtype - self.device = device - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.sharded_to_full_mapping = sharded_to_full_mapping - - @property - def logits_as_input(self): - return self.base_layer.logits_as_input - - @property - def vocab_size(self): - return self.base_layer.vocab_size - - @property - def scale(self): - return self.base_layer.scale - - @property - def soft_cap(self): - return self.base_layer.soft_cap - - @property - def use_all_gather(self): - return self.base_layer.use_all_gather - - @property - def org_vocab_size(self): - return self.base_layer.org_vocab_size - - @property - def include_gpu_probs_tensor(self): - return self.base_layer.include_gpu_probs_tensor - - @property - def should_modify_greedy_probs_inplace(self): - return self.base_layer.should_modify_greedy_probs_inplace - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - # TODO: Verify if this condition can be further relaxed - if 32000 < self.base_layer.vocab_size > 257024: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 257024") - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.hidden_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) - if self.sharded_to_full_mapping is not None: - self.sharded_to_full_mapping_gpu = torch.tensor( 
- self.sharded_to_full_mapping, - device=self.device, - dtype=torch.long) - else: - self.sharded_to_full_mapping_gpu = None - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ] = embeddings_tensor - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - # Get the logits for the next tokens. - logits = lm_head.quant_method.apply(lm_head, hidden_states) - if embedding_bias is not None: - logits += embedding_bias - - # Gather logits for TP - logits = self.base_layer._gather_logits(logits) - - if logits is None: - return None - - if self.sharded_to_full_mapping_gpu is not None: - # Reindex full logits tensor to ensure 1:1 mapping between - # index and token_id - # Example for: - # org_vocab_size = 4 - # added_vocab_size = 2 - # pad_to_size = 8 - # tp_size = 2 - - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 4, -1, 2, 3, 5, -1] - - # Therefore, the mapping is expected to be: - # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, - # we get: - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 2, 3, 4, 5, -1, -1] - logits = logits[:, self.sharded_to_full_mapping_gpu] - - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu(): - indices_padded = indices_padded[:logits.size(0)] - - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) - - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - - if not current_platform.can_update_inplace(): - logits = lora_output - - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] - return logits - - def forward(self, *args, **kwargs): - return type(self.base_layer).forward(self, *args, **kwargs) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # Special handling for the LogitsProcessor. 
- return False diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py new file mode 100644 index 0000000000000..4915ef85f4f73 --- /dev/null +++ b/vllm/lora/layers/__init__.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, +) +from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, +) +from vllm.lora.layers.utils import LoRAMapping +from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA + +__all__ = [ + "BaseLayerWithLoRA", + "VocabParallelEmbeddingWithLoRA", + "LogitsProcessorWithLoRA", + "ColumnParallelLinearWithLoRA", + "ColumnParallelLinearWithShardedLoRA", + "MergedColumnParallelLinearWithLoRA", + "MergedColumnParallelLinearWithShardedLoRA", + "MergedQKVParallelLinearWithLoRA", + "MergedQKVParallelLinearWithShardedLoRA", + "QKVParallelLinearWithLoRA", + "QKVParallelLinearWithShardedLoRA", + "RowParallelLinearWithLoRA", + "RowParallelLinearWithShardedLoRA", + "ReplicatedLinearWithLoRA", + "LoRAMapping", +] diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py new file mode 100644 index 0000000000000..5279247a17594 --- /dev/null +++ b/vllm/lora/layers/base.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +class BaseLayerWithLoRA(nn.Module): + def slice_lora_a( + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... 
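Since the flat vllm/lora/layers.py module is replaced by the package above, existing imports of these classes should keep resolving through the re-exports in the new __init__.py. A quick sanity sketch, with class names taken from the __all__ list above:

# Both import styles below resolve to the same classes after the split:
# the package __init__ re-exports everything the old flat module defined.
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    LoRAMapping,
    RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.column_parallel_linear import (
    ColumnParallelLinearWithLoRA as ColumnParallelLinearWithLoRADirect,
)

assert ColumnParallelLinearWithLoRA is ColumnParallelLinearWithLoRADirect
assert issubclass(ColumnParallelLinearWithLoRA, BaseLayerWithLoRA)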
+ + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py new file mode 100644 index 0000000000000..da053f0923aba --- /dev/null +++ b/vllm/lora/layers/base_linear.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed.utils import divide +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA +from .utils import _get_lora_device + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + # Ensure tp_size and tp_rank consistency with the base_layer. + self.tp_size = self.base_layer.tp_size + self.tp_rank = self.base_layer.tp_rank + self.device = _get_lora_device(self.base_layer) + self.output_slices: tuple[int, ...] + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = ( + lora_config.max_lora_rank + if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size) + ) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = ( + self.output_size + if not lora_config.fully_sharded_loras + else divide(self.output_size, self.tp_size) + ) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + for _ in range(self.n_slices) + ) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + for _ in range(self.n_slices) + ) + self.output_slices = (self.lora_b_stacked[0].shape[2],) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. 
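A toy sketch of the stacked-buffer layout that create_lora_weights and set_lora above maintain for the single-slice case. Sizes are illustrative, and the check at the end only verifies that the zero padding beyond the adapter's rank leaves the LoRA delta unchanged.

import torch

# Toy sizes (illustrative): a ReplicatedLinear-like layer, no TP sharding.
max_loras, max_lora_rank = 4, 8
input_size, output_size = 32, 64

# One slice, matching the n_slices == 1 case handled by set_lora above.
lora_a_stacked = (torch.zeros(max_loras, 1, max_lora_rank, input_size),)
lora_b_stacked = (torch.zeros(max_loras, 1, output_size, max_lora_rank),)

# An adapter with a smaller rank than max_lora_rank still fits: only the
# leading rows/columns of the per-index slot are written, as in set_lora.
rank = 4
lora_a = torch.randn(rank, input_size)   # (rank, in), stored without transpose
lora_b = torch.randn(output_size, rank)  # (out, rank)

index = 2  # slot assigned to this adapter
lora_a_stacked[0][index, 0, :rank, :].copy_(lora_a)
lora_b_stacked[0][index, 0, :, :rank].copy_(lora_b)

# The zero padding beyond `rank` contributes nothing to the delta:
x = torch.randn(5, input_size)
delta_full = x @ lora_a_stacked[0][index, 0].T @ lora_b_stacked[0][index, 0].T
delta_lora = x @ lora_a.T @ lora_b.T
assert torch.allclose(delta_full, delta_lora, atol=1e-5)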
+ assert ( + len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1 + ) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( + lora_a, non_blocking=True + ) + self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices + ) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + @property + def weight(self) -> torch.Tensor: + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000000000..c49b90a80ceac --- /dev/null +++ b/vllm/lora/layers/column_parallel_linear.py @@ -0,0 +1,587 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import tensor_model_parallel_all_gather +from vllm.distributed.utils import divide +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, +) +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert ( + layer.n_slices + == len(layer.lora_a_stacked) + == len(layer.lora_b_stacked) + == len(layer.output_slices) + ) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. 
+ buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0 + ) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + + buffers = tensor_model_parallel_all_gather(buffers) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.output_slices, + offset_start=0, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + shard_size = self.output_size // 2 + offset = lora_b.shape[0] // 2 + + left_weight = lora_b[ + self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, : + ] + right_weight = lora_b[ + offset + self.tp_rank * shard_size : offset + + (self.tp_rank + 1) * shard_size, + :, + ] + lora_b = torch.cat([left_weight, right_weight], dim=0) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + shard_size = self.output_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[start_idx:end_idx, :] + return lora_b + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output and self.tp_size > 1: + # All-gather across the partitions. 
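An illustration (not part of the patch, toy sizes) of the two `slice_lora_b` paths above: a plain `ColumnParallelLinear` keeps one contiguous block of LoRA B rows per rank, while a `MergedColumnParallelLinear` such as `gate_up_proj` takes a block from each packed half and re-concatenates them.

```python
import torch

# Toy setup: full LoRA B has 8 output rows; 2-way tensor parallelism.
tp_size, tp_rank, rank = 2, 1, 4
lora_b = torch.arange(8 * rank, dtype=torch.float32).view(8, rank)

# Plain ColumnParallelLinear: each rank keeps a contiguous block of rows.
out_per_rank = 8 // tp_size
col_slice = lora_b[tp_rank * out_per_rank : (tp_rank + 1) * out_per_rank, :]

# MergedColumnParallelLinear (e.g. gate_up_proj): the full weight is the
# concatenation of two halves (gate, up), each sharded independently, so the
# per-rank slice is taken from both halves and re-concatenated.
shard_size = out_per_rank // 2          # per-rank rows of each half
offset = lora_b.shape[0] // 2           # where the second half starts
left = lora_b[tp_rank * shard_size : (tp_rank + 1) * shard_size, :]
right = lora_b[offset + tp_rank * shard_size : offset + (tp_rank + 1) * shard_size, :]
merged_slice = torch.cat([left, right], dim=0)

assert col_slice.shape == merged_slice.shape == (out_per_rank, rank)
```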
+ output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1 + ) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (e.g. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, QKVParallelLinear] + ) -> None: + super().__init__(base_layer) + # There are two LoRA layers + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes + ) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank,) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank + if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size) + ) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + for _ in range(self.n_slices) + ) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + for output_size in self.output_slices + ) + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + sliced_lora_b = [None] * self.n_slices + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices) + ): + if (lora_b_i := lora_b[i]) is not None: + sliced_lora_b[i] = lora_b_i[ + shard_size * shard_id : shard_size * (shard_id + 1), : + ] + return sliced_lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1] + ].copy_(lora_a_i, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] + ].copy_(lora_b_i, non_blocking=True) + + @classmethod + 
@_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2 + ) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = ( + self.base_layer.total_num_heads * self.base_layer.head_size + ) + self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size + self.kv_proj_shard_size = ( + self.base_layer.num_kv_heads * self.base_layer.head_size + ) + self.kv_proj_total_size = ( + self.base_layer.total_num_kv_heads * self.base_layer.head_size + ) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[ + self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size + * (self.q_shard_id + 1), + :, + ] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[ + k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1), + :, + ] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[ + v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1), + :, + ] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) + return lora_b + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
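A worked example (illustration only, toy head counts) of the Q/K/V offsets computed in `QKVParallelLinearWithLoRA.slice_lora_b` above, where a single LoRA B spans the packed `[Q | K | V]` output dimension and each rank extracts its own query shard plus its (possibly replicated) key/value shard.

```python
import torch

# Toy attention geometry, for illustration only.
head_size = 4
total_num_heads, total_num_kv_heads = 8, 4
tp_size, tp_rank = 2, 1
num_heads = total_num_heads // tp_size          # per-rank Q heads
num_kv_heads = total_num_kv_heads // tp_size    # per-rank KV heads
num_kv_head_replicas = 1
rank = 2

q_total = total_num_heads * head_size           # 32
kv_total = total_num_kv_heads * head_size       # 16
q_shard = num_heads * head_size                 # 16
kv_shard = num_kv_heads * head_size             # 8

# Full LoRA B covers the packed [Q | K | V] output dimension.
lora_b = torch.randn(q_total + 2 * kv_total, rank)

q_id = tp_rank
kv_id = tp_rank // num_kv_head_replicas

lora_b_q = lora_b[q_shard * q_id : q_shard * (q_id + 1)]
k_off = q_total
lora_b_k = lora_b[k_off + kv_shard * kv_id : k_off + kv_shard * (kv_id + 1)]
v_off = k_off + kv_total
lora_b_v = lora_b[v_off + kv_shard * kv_id : v_off + kv_shard * (kv_id + 1)]

sliced = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
assert sliced.shape == (q_shard + 2 * kv_shard, rank)   # (32, 2) on this rank
```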
+ self.n_slices = len(self.base_layer.output_sizes) + + self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size + self.kv_proj_shard_size = ( + self.base_layer.num_kv_heads * self.base_layer.head_size + ) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3 + + +# These following layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + lora_a = lora_a[start_idx : start_idx + shard_size, :] + return lora_a + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 2 subloras, and each sublora could be None. 
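Illustration only (hypothetical sizes): with `fully_sharded_loras`, LoRA A is sliced along the rank dimension as in `slice_lora_a` above, so each tensor-parallel rank holds `max_lora_rank // tp_size` rows and gathering the slices recovers the full matrix.

```python
import torch

max_rank, input_size, tp_size = 16, 64, 4
lora_a = torch.randn(max_rank, input_size)

rank_per_shard = max_rank // tp_size   # what lora_a_stacked[0].shape[2] holds
shards = [
    lora_a[r * rank_per_shard : (r + 1) * rank_per_shard, :]
    for r in range(tp_size)
]

# Concatenating the per-rank slices along the rank dim recovers the full LoRA A.
assert torch.equal(torch.cat(shards, dim=0), lora_a)
```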
+ output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][output_start_idx : output_start_idx + output_shard_size, :] + if lora_a[0] is not None + else None, + lora_a[1][output_start_idx : output_start_idx + output_shard_size, :] + if lora_a[1] is not None + else None, + ] + return lora_a + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + lora_a = lora_a[start_idx : start_idx + shard_size, :] + return lora_a + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :] + if lora_a[0] is not None + else None, + lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :] + if lora_a[1] is not None + else None, + lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :] + if lora_a[2] is not None + else None, + ] + return lora_a + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py new file mode 100644 index 0000000000000..f3ca60fb28d90 --- /dev/null +++ b/vllm/lora/layers/logits_processor.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. 
+ """ + + def __init__( + self, + base_layer: LogitsProcessor, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + sharded_to_full_mapping: Optional[list[int]], + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 257024" + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil( + self.base_layer.vocab_size / lora_config.lora_vocab_padding_size + ) + * lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, device=self.device, dtype=torch.long + ) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( + lora_a, non_blocking=True + ) + self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + : embeddings_tensor.shape[0], + : embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. 
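A quick sketch (illustration only, hypothetical numbers) of the padding arithmetic used for `lora_b_stacked` in `create_lora_weights` above, which rounds the vocabulary size up to a multiple of `lora_vocab_padding_size` for kernel compatibility.

```python
import math

vocab_size = 32_003                  # hypothetical vocabulary size
lora_vocab_padding_size = 256        # hypothetical padding granularity

padded = math.ceil(vocab_size / lora_vocab_padding_size) * lora_vocab_padding_size
assert padded >= vocab_size and padded % lora_vocab_padding_size == 0
print(padded)  # 32256
```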
+ logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) + + neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype) + + lora_logits[-1] = neg_inf + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu() or current_platform.is_xpu(): + indices_padded = indices_padded[: logits.size(0)] + + lora_logits = ( + lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ) + .index_select(0, indices_padded) + .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf) + ) + + logits[ + :, + self.base_layer.org_vocab_size : self.base_layer.org_vocab_size + + lora_logits.shape[1], + ] = lora_logits + + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 + ) + + if not current_platform.can_update_inplace(): + logits = lora_output + + # Remove paddings in vocab (if any). + logits = logits[:, : self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py new file mode 100644 index 0000000000000..18a35cd1e0f22 --- /dev/null +++ b/vllm/lora/layers/replicated_linear.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.linear import ReplicatedLinear + +from .base_linear import BaseLinearLayerWithLoRA + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__( + base_layer, + ) + # To ensure interface compatibility, set to 1 always. 
+ self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000000000..fff4fb38ead90 --- /dev/null +++ b/vllm/lora/layers/row_parallel_linear.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import ( + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[:, start_idx:end_idx] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size + ) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
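Illustration only (toy sizes): because `RowParallelLinear` partitions the input dimension, `slice_lora_a` above slices LoRA A along its columns while LoRA B is kept whole; the per-rank LoRA partials are then combined by the same all-reduce that combines the base layer's partial outputs.

```python
import torch

max_rank, input_size, tp_size, tp_rank = 8, 64, 4, 3
lora_a = torch.randn(max_rank, input_size)

in_per_rank = input_size // tp_size
# Each rank keeps the columns that correspond to its shard of the input.
lora_a_shard = lora_a[:, tp_rank * in_per_rank : (tp_rank + 1) * in_per_rank]
assert lora_a_shard.shape == (max_rank, in_per_rank)
```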
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +# The following layer is based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[start_idx:end_idx, :] + return lora_b + + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0 + ) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + if self.tp_size > 1: + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
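A single-process sketch (illustration only) of the S-LoRA trick described in the comment above: each rank writes its column slice of the LoRA output into the partial-sum tensor at `offset_start = tp_rank * shard_size`, and one all-reduce then finishes both the base layer and the LoRA. The shrink side is shared across ranks here for brevity; in the real layer each rank computes a partial low-rank activation that is all-reduced first.

```python
import torch

tp_size, out_size, rank, tokens, in_size = 2, 8, 4, 3, 16
shard = out_size // tp_size

x = torch.randn(tokens, in_size)
lora_a = torch.randn(rank, in_size)
lora_b = torch.randn(out_size, rank)
ref = x @ lora_a.T @ lora_b.T                       # full LoRA delta

h = x @ lora_a.T                                    # shrink (shared here)
partials = torch.zeros(tp_size, tokens, out_size)
for r in range(tp_size):
    b_r = lora_b[r * shard : (r + 1) * shard, :]    # slice_lora_b on rank r
    # write the column-partitioned LoRA output at this rank's offset
    partials[r, :, r * shard : (r + 1) * shard] = h @ b_r.T
# summing the per-rank buffers plays the role of the all-reduce
assert torch.allclose(partials.sum(dim=0), ref, atol=1e-5)
```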
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py new file mode 100644 index 0000000000000..2da90f180ee74 --- /dev/null +++ b/vllm/lora/layers/utils.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch +import torch.nn as nn + + +@dataclass +class LoRAMapping: + index_mapping: tuple[int, ...] + prompt_mapping: tuple[int, ...] + is_prefill: bool = False + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True + return can_replace(*args, **kwargs) and condition + + return dec + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return ( + can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras + ) + + return dec diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py new file mode 100644 index 0000000000000..0a252b425c4a8 --- /dev/null +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from 
vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition # noqa: E501 + + self.base_layer.num_added_embeddings_per_partition + ] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index + - self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index + - self.base_layer.org_vocab_size, + ) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition : + ].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, + # so we need transpose here + self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True + ) + self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + : embeddings_tensor.shape[0], + : embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0] : self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) + + # NB: Don't use torch.narrow here. 
torch.narrow triggers some + # Dynamic Shape specialization in torch.compile + num_tokens = x.shape[0] + indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] + indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] + + full_lora_a_embeddings = F.embedding( + x + indices_1, + self.lora_a_stacked_2d, + ) + full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1 + ) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], + -1, + ) + + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True + ) + + if not current_platform.can_update_inplace(): + full_output = lora_output + + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight diff --git a/vllm/lora/lora.py b/vllm/lora/lora_weights.py similarity index 71% rename from vllm/lora/lora.py rename to vllm/lora/lora_weights.py index 958364fca592f..b043a46f9e2a5 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora_weights.py @@ -21,7 +21,6 @@ class LoRALayerWeights: lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None, scaling: Optional[float] = None, ) -> None: @@ -30,7 +29,6 @@ class LoRALayerWeights: self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b - self.bias = bias self.embeddings_tensor = embeddings_tensor if scaling is None: @@ -48,11 +46,11 @@ class LoRALayerWeights: @property def input_dim(self) -> int: - return self.lora_a.shape[0] + return self.lora_a.shape[1] @property def output_dim(self) -> int: - return self.lora_b.shape[1] + return self.lora_b.shape[0] @property def is_packed(self) -> bool: @@ -60,8 +58,9 @@ class LoRALayerWeights: @property def extra_vocab_size(self) -> int: - return self.embeddings_tensor.shape[ - 0] if self.embeddings_tensor is not None else 0 + return ( + self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0 + ) @classmethod def from_config( @@ -70,51 +69,53 @@ class LoRALayerWeights: peft_helper: PEFTHelper, embeddings_tensor: Optional[torch.Tensor] = None, ) -> "LoRALayerWeights": - return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None, - None, None, embeddings_tensor, - peft_helper.vllm_lora_scaling_factor) + # lora_a and lora_b are set to None for config-based construction + return cls( + module_name, + peft_helper.r, + peft_helper.lora_alpha, + None, + None, + embeddings_tensor, + peft_helper.vllm_lora_scaling_factor, + ) @classmethod def create_dummy_lora_weights( - cls, - module_name: str, - input_dim: int, - output_dim: int, - rank: int, - dtype: torch.dtype, - device: torch.types.Device, - embeddings_tensor_dim: Optional[int] = None, - bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None, + 
) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() - lora_a = torch.zeros([input_dim, rank], - dtype=dtype, - device=device, - pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - if bias_enabled: - bias = torch.zeros([output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - else: - bias = None + lora_a = torch.zeros( + [rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory + ) + lora_b = torch.zeros( + [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory + ) - embeddings_tensor = torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory) if embeddings_tensor_dim else None + embeddings_tensor = ( + torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory, + ) + if embeddings_tensor_dim + else None + ) return cls( module_name, rank=rank, lora_alpha=1, lora_a=lora_a, lora_b=lora_b, - bias=bias, embeddings_tensor=embeddings_tensor, ) @@ -129,7 +130,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): lora_alphas: list[Optional[int]], lora_a: list[Optional[torch.Tensor]], lora_b: list[Optional[torch.Tensor]], - bias: Optional[list[Optional[torch.Tensor]]] = None, scaling: Optional[list[float]] = None, ) -> None: super().__init__( @@ -138,7 +138,6 @@ class PackedLoRALayerWeights(LoRALayerWeights): lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - bias=bias, scaling=scaling, # type: ignore embeddings_tensor=None, ) @@ -170,11 +169,11 @@ class PackedLoRALayerWeights(LoRALayerWeights): [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - [lora.bias if lora is not None else None for lora in loras], scaling=[ 1 if lora is not None else None # type: ignore for lora in loras - ]) + ], + ) return obj def optimize(self) -> "PackedLoRALayerWeights": diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3072047a2606c..cf9089eff1757 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,29 +3,27 @@ import math import os -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, TypeVar, Union import regex as re import safetensors.torch import torch from torch import nn -from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, - AdapterModelManager) -from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, - get_adapter, list_adapters, - remove_adapter, set_adapter_mapping) -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping -from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.lora.utils import (from_layer, from_layer_logits_processor, - get_supported_lora_modules, - is_regex_target_modules, - parse_fine_tuned_lora_name, replace_submodule) +from vllm.lora.utils import ( + from_layer, + from_layer_logits_processor, + get_supported_lora_modules, + is_regex_target_modules, + parse_fine_tuned_lora_name, + replace_submodule, +) from vllm.model_executor.layers.fused_moe import FusedMoE from 
vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal @@ -34,9 +32,24 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper from vllm.model_executor.utils import get_packed_modules_mapping from vllm.utils import is_pin_memory_available +from vllm.utils.cache import LRUCache logger = init_logger(__name__) +T = TypeVar("T") + + +class AdapterLRUCache(LRUCache[int, T]): + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + _GLOBAL_LORA_ID = 0 @@ -52,12 +65,13 @@ def is_moe_model(model: nn.Module) -> bool: logger.warning_once( "For MoE models, vLLM currently does not support fused MoE LoRA " "inference. Please ensure that the loaded LoRA model does not " - "contain expert weights.") + "contain expert weights." + ) return True return False -class LoRAModel(AdapterModel): +class LoRAModel: """A LoRA fine-tuned model.""" def __init__( @@ -75,9 +89,9 @@ class LoRAModel(AdapterModel): """ self.id = lora_model_id - assert ( - lora_model_id - > 0), f"a valid lora id should be greater than 0, got {self.id}" + assert lora_model_id > 0, ( + f"a valid lora id should be greater than 0, got {self.id}" + ) self.rank = rank self.loras: dict[str, LoRALayerWeights] = loras @@ -93,8 +107,11 @@ class LoRAModel(AdapterModel): @property def extra_vocab_size(self) -> int: - return max(lora.extra_vocab_size - for lora in self.loras.values()) if self.loras else 0 + return ( + max(lora.extra_vocab_size for lora in self.loras.values()) + if self.loras + else 0 + ) def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: """Get LoRA for a given module by name""" @@ -122,53 +139,45 @@ class LoRAModel(AdapterModel): pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( - tensor_name, weights_mapper) + module_name, is_lora_a = parse_fine_tuned_lora_name( + tensor_name, weights_mapper + ) if module_name not in loras: lora_embeddings_tensor = None if embeddings: assert embedding_modules is not None embeddings_module = next( - (k for k in embedding_modules if k in module_name), - None) + (k for k in embedding_modules if k in module_name), None + ) if embeddings_module: lora_embeddings_tensor = embeddings[ - embedding_modules[embeddings_module]].to( - device=device, dtype=dtype) + embedding_modules[embeddings_module] + ].to(device=device, dtype=dtype) if pin_memory: - lora_embeddings_tensor = ( - lora_embeddings_tensor.pin_memory()) + lora_embeddings_tensor = lora_embeddings_tensor.pin_memory() loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper, lora_embeddings_tensor) + module_name, peft_helper, lora_embeddings_tensor + ) - if is_bias: - loras[module_name].bias = tensor.to(device=device, - dtype=dtype).t() - bias = tensor.to(device=device, dtype=dtype).t() + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) if pin_memory: - bias = bias.pin_memory() - loras[module_name].bias = bias - elif is_lora_a: - loras[module_name].lora_a = tensor.to(device=device, - 
dtype=dtype).t() - if pin_memory: - loras[module_name].lora_a = loras[ - module_name].lora_a.pin_memory() + loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() else: - loras[module_name].lora_b = tensor.to(device=device, - dtype=dtype).t() + loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) assert embedding_padding_modules is not None - if any(name in module_name - for name in embedding_padding_modules - ) and target_embedding_padding is not None: + if ( + any(name in module_name for name in embedding_padding_modules) + and target_embedding_padding is not None + ): lora_b = loras[module_name].lora_b - assert target_embedding_padding >= lora_b.shape[1] - addition = target_embedding_padding - lora_b.shape[1] + assert target_embedding_padding >= lora_b.shape[0] + addition = target_embedding_padding - lora_b.shape[0] loras[module_name].lora_b = torch.nn.functional.pad( - lora_b, (0, addition)) + lora_b, (0, 0, 0, addition) + ) if pin_memory: - loras[module_name].lora_b = loras[ - module_name].lora_b.pin_memory() + loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() for lora in loras.values(): lora.optimize() @@ -177,19 +186,20 @@ class LoRAModel(AdapterModel): @classmethod def from_local_checkpoint( - cls, - lora_dir: str, - expected_lora_modules: list[str], - peft_helper: PEFTHelper, - *, - lora_model_id: Optional[int] = None, - device: str = "cuda", - dtype: Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, - weights_mapper: Optional[WeightsMapper] = None, - tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": + cls, + lora_dir: str, + expected_lora_modules: list[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + tensorizer_config_dict: Optional[dict] = None, + ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. Args: @@ -209,16 +219,15 @@ class LoRAModel(AdapterModel): lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") new_embeddings_tensor_path = os.path.join( - lora_dir, "new_embeddings.safetensors") - new_embeddings_bin_file_path = os.path.join(lora_dir, - "new_embeddings.bin") + lora_dir, "new_embeddings.safetensors" + ) + new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} unexpected_modules: list[Union[list[str], str]] = [] def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa - module_name, _, _ = parse_fine_tuned_lora_name( - lora_module, weights_mapper) + module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) part_name = module_name.split(".")[-1] if part_name not in expected_lora_modules: unexpected_modules.append(module_name) @@ -227,19 +236,22 @@ class LoRAModel(AdapterModel): f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." 
- f" Please verify that the loaded LoRA module is correct") + f" Please verify that the loaded LoRA module is correct" + ) if tensorizer_config_dict: from tensorizer import TensorDeserializer tensorizer_config = TensorizerConfig(**tensorizer_config_dict) - lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, - "adapter_model.tensors") + lora_tensor_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_model.tensors" + ) tensorizer_args = tensorizer_config._construct_tensorizer_args() tensors = TensorDeserializer( lora_tensor_path, dtype=tensorizer_config.dtype, - **tensorizer_args.deserialization_kwargs) + **tensorizer_args.deserialization_kwargs, + ) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): @@ -250,14 +262,12 @@ class LoRAModel(AdapterModel): # loraified. C won’t exist in the safetensor but it will exist in # the target_modules of the adapter_config.json. unexpected_modules = [] - with safetensors.safe_open(lora_tensor_path, - framework="pt") as f: # type: ignore + with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore # Load tensors if there are only expected modules. check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path) or os.path.isfile( - lora_pt_file_path): + elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): # When a bin/pt file is provided, we rely on config to find # unexpected modules. unexpected_modules = [] @@ -275,33 +285,33 @@ class LoRAModel(AdapterModel): # https://github.com/vllm-project/vllm/pull/5909. But there's no # other better mechanism. if unexpected_modules and not is_regex_target_modules( - peft_helper.target_modules, expected_lora_modules): + peft_helper.target_modules, expected_lora_modules + ): raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." 
- f" Please verify that the loaded LoRA module is correct") - lora_file_path = (lora_bin_file_path - if os.path.isfile(lora_bin_file_path) else - lora_pt_file_path) - tensors = torch.load(lora_file_path, - map_location=device, - weights_only=True) + f" Please verify that the loaded LoRA module is correct" + ) + lora_file_path = ( + lora_bin_file_path + if os.path.isfile(lora_bin_file_path) + else lora_pt_file_path + ) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) else: raise ValueError(f"{lora_dir} doesn't contain tensors") embeddings = None if os.path.isfile(new_embeddings_tensor_path): - embeddings = safetensors.torch.load_file( - new_embeddings_tensor_path) + embeddings = safetensors.torch.load_file(new_embeddings_tensor_path) elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load(new_embeddings_bin_file_path, - map_location=device, - weights_only=True) + embeddings = torch.load( + new_embeddings_bin_file_path, map_location=device, weights_only=True + ) return cls.from_lora_tensors( - lora_model_id=get_lora_id() - if lora_model_id is None else lora_model_id, + lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, tensors=tensors, peft_helper=peft_helper, device=device, @@ -310,10 +320,11 @@ class LoRAModel(AdapterModel): target_embedding_padding=target_embedding_padding, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, - weights_mapper=weights_mapper) + weights_mapper=weights_mapper, + ) -class LoRAModelManager(AdapterModelManager): +class LoRAModelManager: """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -336,6 +347,11 @@ class LoRAModelManager(AdapterModelManager): vocab_size: the vocab size of the model. lora_config: the LoRA configuration. """ + self.model: SupportsLoRA = model + self._registered_adapters: dict[int, LoRAModel] = {} + # Dict instead of a set for compatibility with LRUCache. + self._active_adapters: dict[int, None] = {} + self.adapter_type = "LoRA" self.lora_config = lora_config self.device = device self.max_num_seqs = max_num_seqs @@ -347,9 +363,8 @@ class LoRAModelManager(AdapterModelManager): max_num_batched_tokens, max_batches=self.max_num_seqs, device=self.device, - max_loras=self.lora_config.max_loras) - - super().__init__(model) + max_loras=self.lora_config.max_loras, + ) self.supported_lora_modules = get_supported_lora_modules(self.model) assert self.supported_lora_modules, "No supported LoRA modules found in" @@ -361,7 +376,8 @@ class LoRAModelManager(AdapterModelManager): supports_multimodal(self.model) # In case the model only supports LoRA for # text modules (e.g. 
ChatGLM) - and hasattr(self.model, "get_mm_mapping")) + and hasattr(self.model, "get_mm_mapping") + ) self.is_pooling_model = is_pooling_model(self.model) self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} @@ -370,7 +386,9 @@ class LoRAModelManager(AdapterModelManager): self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self - self.adapter_type = 'LoRA' + + def __len__(self) -> int: + return len(self._registered_adapters) @property def capacity(self) -> int: @@ -392,33 +410,32 @@ class LoRAModelManager(AdapterModelManager): if lora_id in self._active_adapters: return False first_free_slot = next( - ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) - if lora_id is None), None) + ( + (i, lora_id) + for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None + ), + None, + ) if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot self._active_adapters[lora_id] = None lora_model = self._registered_adapters[lora_id] - logger.debug("Activating LoRA. int id: %d, slot index: %d", - lora_model.id, index) + logger.debug( + "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index + ) self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: module_lora.optimize() - # Bias is not explicitly enabled with the flag enable_lora_bias. - bias = module_lora.bias - if ((torch.is_tensor(bias) or - (isinstance(bias, Sequence) and any(b is not None - for b in bias))) - and not self.lora_config.bias_enabled): - module_lora.bias = None - raise ValueError( - f"Adapter bias cannot be used for {module_name}" - " without --enable-lora-bias.") - module.set_lora(index, module_lora.lora_a, module_lora.lora_b, - module_lora.embeddings_tensor, - module_lora.bias) + module.set_lora( + index, + module_lora.lora_a, + module_lora.lora_b, + module_lora.embeddings_tensor, + ) else: module.reset_lora(index) return True @@ -438,7 +455,8 @@ class LoRAModelManager(AdapterModelManager): """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager. " - "Use LRUCacheLoRAModelManager for pinning") # type: ignore + "Use LRUCacheLoRAModelManager for pinning" + ) # type: ignore def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: # update lora states @@ -457,16 +475,14 @@ class LoRAModelManager(AdapterModelManager): self._active_adapters.clear() def _create_lora_modules(self): - def _parent_module(module_name: str) -> str: # module name is a dot separated name. 
# for example: # - given an input 'x.y.z' return 'x.y' # - given an input 'x' return '' - return module_name.rpartition('.')[0] + return module_name.rpartition(".")[0] - for module_name, module in self.model.named_modules( - remove_duplicate=False): + for module_name, module in self.model.named_modules(remove_duplicate=False): if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): @@ -483,35 +499,48 @@ class LoRAModelManager(AdapterModelManager): parts = module_name.split(".")[-1] packed_moduled_lst = self.packed_modules_mapping.get(parts, []) new_module = replace_submodule( - self.model, module_name, - from_layer(module, self.lora_slots, self.lora_config, - packed_moduled_lst, self.model.config)) + self.model, + module_name, + from_layer( + module, + self.lora_slots, + self.lora_config, + packed_moduled_lst, + self.model.config, + ), + ) # (yard1): TODO make this more robust if "lm_head" in module_name: - logits_processor_module_name = 'logits_processor' + logits_processor_module_name = "logits_processor" parent_module = _parent_module(module_name) if parent_module: logits_processor_module_name = ( - f"{parent_module}.{logits_processor_module_name}") + f"{parent_module}.{logits_processor_module_name}" + ) logits_processor_module = self.model.get_submodule( - logits_processor_module_name) + logits_processor_module_name + ) new_module = replace_submodule( - self.model, logits_processor_module_name, - from_layer_logits_processor(logits_processor_module, - module, self.lora_slots, - self.lora_config, - self.model.config)) + self.model, + logits_processor_module_name, + from_layer_logits_processor( + logits_processor_module, + module, + self.lora_slots, + self.lora_config, + self.model.config, + ), + ) # In some models, especially multimodal ones, layers with the same # name may have different types, such as nn.Linear and # ReplicatedLinear. The nn.Linear layers cannot be replaced with # LoRA layers, leading to assertion error. 
The following check # aims to prevent this error - if self.supports_mm and not isinstance(new_module, - BaseLayerWithLoRA): + if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) @@ -523,33 +552,40 @@ class LoRAModelManager(AdapterModelManager): self.modules[module_name] = module def create_dummy_lora( - self, - lora_id: int, - rank: int, - embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: + self, + lora_id: int, + rank: int, + embedding_modules: Optional[dict[str, str]] = None, + ) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): - bias_enabled = self.lora_config.bias_enabled - if (not self._match_target_modules(module_name) - or not isinstance(module, BaseLayerWithLoRA) - or self._filter_unsupported_mm_module(module_name)): + if ( + not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or self._filter_unsupported_mm_module(module_name) + ): continue parts = module_name.split(".") if module_name not in self.packed_modules: assert embedding_modules is not None if parts[-1] in embedding_modules: - input_dim = (module.base_layer.org_vocab_size + - self.lora_config.lora_extra_vocab_size if - hasattr(module.base_layer, "org_vocab_size") - else module.base_layer.weight.shape[1]) - output_dim = module.base_layer.embedding_dim if hasattr( - module.base_layer, - "embedding_dim") else module.base_layer.weight.shape[0] - embeddings_tensor_dim = (module.base_layer.embedding_dim if - hasattr(module.base_layer, - "embedding_dim") else - module.base_layer.weight.shape[1]) + input_dim = ( + module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size + if hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1] + ) + output_dim = ( + module.base_layer.embedding_dim + if hasattr(module.base_layer, "embedding_dim") + else module.base_layer.weight.shape[0] + ) + embeddings_tensor_dim = ( + module.base_layer.embedding_dim + if hasattr(module.base_layer, "embedding_dim") + else module.base_layer.weight.shape[1] + ) lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, @@ -558,7 +594,7 @@ class LoRAModelManager(AdapterModelManager): module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, - bias_enabled=bias_enabled) + ) else: lora = LoRALayerWeights.create_dummy_lora_weights( module_name, @@ -567,9 +603,7 @@ class LoRAModelManager(AdapterModelManager): rank, module.lora_a_stacked[0].dtype, "cpu", - bias_enabled=bias_enabled, ) - lora.optimize() else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] @@ -582,9 +616,7 @@ class LoRAModelManager(AdapterModelManager): rank, module.lora_a_stacked[i].dtype, "cpu", - bias_enabled=bias_enabled, ) - lora.optimize() subloras.append(lora) lora = PackedLoRALayerWeights.pack(subloras) model.loras[module_name] = lora @@ -593,9 +625,11 @@ class LoRAModelManager(AdapterModelManager): def _match_target_modules(self, module_name: str): return any( re.match( - r".*\.{target_module}$".format(target_module=target_module), - module_name) or target_module == module_name - for target_module in self.supported_lora_modules) + r".*\.{target_module}$".format(target_module=target_module), module_name + ) + or target_module == module_name + for target_module in self.supported_lora_modules + 
) def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ @@ -606,8 +640,7 @@ class LoRAModelManager(AdapterModelManager): if self.supports_mm: module_mapping: MultiModelKeys = self.model.get_mm_mapping() prefix_lst = module_mapping.connector + module_mapping.tower_model - return any( - [module_name.startswith(prefix) for prefix in prefix_lst]) + return any([module_name.startswith(prefix) for prefix in prefix_lst]) return False def _register_packed_modules(self, module_full_name: str) -> None: @@ -641,23 +674,22 @@ class LoRAModelManager(AdapterModelManager): continue replacement_loras[i] = None # HACK Temporary solution for the pool model. - if self.is_pooling_model and not lora_model.check_lora_name( - module_name): + if self.is_pooling_model and not lora_model.check_lora_name(module_name): replaced_module_name = module_name.replace("model.", "") if lora_model.check_lora_name(module_name): module_name = replaced_module_name lora_model.loras[module_name] = PackedLoRALayerWeights.pack( - replacement_loras) + replacement_loras + ) # Remove the modules that have been replaced. for module in replaced_module: lora_model.loras.pop(module, None) def _get_lora_layer_weights( - self, lora_model: LoRAModel, - module_name: str) -> Optional[LoRALayerWeights]: + self, lora_model: LoRAModel, module_name: str + ) -> Optional[LoRALayerWeights]: org_module_name = module_name - if self.is_pooling_model and not lora_model.check_lora_name( - module_name): + if self.is_pooling_model and not lora_model.check_lora_name(module_name): # If it's a pool model, and the layer name is not found, # remove the prefix 'model.' and search again. module_name = module_name.replace("model.", "") @@ -665,53 +697,71 @@ class LoRAModelManager(AdapterModelManager): org_module_name = module_name logger.info_once( "For the pool model, successfully loaded the LoRA weights " - "after removing the prefix 'model.'.") + "after removing the prefix 'model.'." + ) return lora_model.get_lora(org_module_name) def deactivate_adapter(self, adapter_id: int) -> bool: - return deactivate_adapter(adapter_id, self._active_adapters, - self._deactivate_adapter) + if adapter_id not in self._active_adapters: + return False + self._deactivate_adapter(adapter_id) + self._active_adapters.pop(adapter_id, None) + return True def add_adapter(self, adapter: LoRAModel) -> bool: - logger.debug("Adding lora. Model id: %d, " - "int id: %d", adapter.id, adapter.id) - return add_adapter(adapter, self._registered_adapters, self.capacity, - self._add_adapter) + logger.debug("Adding lora. 
Model id: %d, int id: %d", adapter.id, adapter.id) + if adapter.id in self._registered_adapters: + return False + if len(self._registered_adapters) >= self.capacity: + raise RuntimeError("No free adapter slots.") + self._add_adapter(adapter) + return True def set_adapter_mapping(self, mapping: LoRAMapping) -> None: - self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, - self._set_adapter_mapping) + if self._last_mapping != mapping: + self._set_adapter_mapping(mapping) + self._last_mapping = mapping def remove_adapter(self, adapter_id: int) -> bool: - return remove_adapter(adapter_id, self._registered_adapters, - self.deactivate_adapter) + self.deactivate_adapter(adapter_id) + if adapter_id not in self._registered_adapters: + return False + self._registered_adapters.pop(adapter_id, None) + return True - def list_adapters(self) -> dict[int, Any]: - return list_adapters(self._registered_adapters) + def list_adapters(self) -> dict[int, LoRAModel]: + return dict(self._registered_adapters) - def get_adapter(self, adapter_id: int) -> Optional[Any]: - return get_adapter(adapter_id, self._registered_adapters) + def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]: + return self._registered_adapters.get(adapter_id) class LoRALRUCache(AdapterLRUCache[LoRAModel]): - - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], - bool]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]): super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): """A model manager that manages multiple LoRAs with LRU cache.""" - def __init__(self, model: nn.Module, max_num_seqs: int, - max_num_batched_tokens: int, vocab_size: int, - lora_config: LoRAConfig, device: torch.device): - super().__init__(model, max_num_seqs, max_num_batched_tokens, - vocab_size, lora_config, device) + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + ): + super().__init__( + model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device + ) self._registered_adapters: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_adapter) + self.capacity, self.deactivate_adapter + ) self._active_adapters: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_adapter) + self.lora_slots, self._deactivate_adapter + ) def list_adapters(self) -> dict[int, LoRAModel]: """List all registered LoRAModels.""" @@ -719,8 +769,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager): def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - logger.debug("Adding lora. Model id: %d, " - "int id: %d", lora.id, lora.id) + logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id) if lora.id not in self._registered_adapters: self._add_adapter(lora) was_added = True @@ -734,8 +783,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager): self, lora_id: int, ) -> bool: - if lora_id not in self._active_adapters and len( - self._active_adapters) >= self.lora_slots: + if ( + lora_id not in self._active_adapters + and len(self._active_adapters) >= self.lora_slots + ): self._active_adapters.remove_oldest() result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order @@ -758,8 +809,9 @@ class LRUCacheLoRAModelManager(LoRAModelManager): try: self._registered_adapters.pin(lora_id) except ValueError as err: - raise ValueError("Pinning failed. 
" - f"LoRA {lora_id} is not registered.") from err + raise ValueError( + f"Pinning failed. LoRA {lora_id} is not registered." + ) from err def _pin_lora_in_gpu_cache(self, lora_id: int): if lora_id not in self._active_adapters: @@ -770,14 +822,15 @@ class LRUCacheLoRAModelManager(LoRAModelManager): def create_lora_manager( - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, - **kwargs) -> LoRAModelManager: + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, + **kwargs, +) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not isinstance(model, SupportsLoRA): raise ValueError(f"Model {type(model)} is not supported for LoRA.") @@ -788,5 +841,6 @@ def create_lora_manager( vocab_size=vocab_size, lora_config=lora_config, device=device, - **kwargs) + **kwargs, + ) return lora_manager diff --git a/vllm/lora/ops/ipex_ops/__init__.py b/vllm/lora/ops/ipex_ops/__init__.py index 5daa432493b19..f5a5e0e6f951f 100644 --- a/vllm/lora/ops/ipex_ops/__init__.py +++ b/vllm/lora/ops/ipex_ops/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.ipex_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink) +from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/ipex_ops/lora_ops.py b/vllm/lora/ops/ipex_ops/lora_ops.py index 7590c868ecb67..0767f90b2f9e7 100644 --- a/vllm/lora/ops/ipex_ops/lora_ops.py +++ b/vllm/lora/ops/ipex_ops/lora_ops.py @@ -13,32 +13,45 @@ except ImportError as e: raise e -def bgmv_shrink(inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0) -> None: - - ipex.llm.functional.bgmv_shrink(inputs, lora_a_weights, output_tensor, - lora_indices_tensor, scaling) +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + ipex.llm.functional.bgmv_shrink( + inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling + ) -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True) -> None: - ipex.llm.functional.bgmv_expand(inputs, lora_b_weights, output_tensor, - lora_indices_tensor, add_inputs) +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + ipex.llm.functional.bgmv_expand( + inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs + ) -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True) -> None: - ipex.llm.functional.bgmv_expand_slice(inputs, lora_b_weights, - output_tensor, lora_indices_tensor, - slice_offset, slice_size, add_inputs) +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: 
torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + ipex.llm.functional.bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs, + ) diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py index 22aa3c63dce19..89865af4e9b89 100644 --- a/vllm/lora/ops/torch_ops/__init__.py +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 -from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) +from vllm.lora.ops.torch_ops.lora_ops import ( + bgmv_expand, # noqa: F401 + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) __all__ = [ "bgmv_expand", diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index cba5baad86686..4fc6248d5448e 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -4,30 +4,31 @@ import torch -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) @@ -58,62 +59,70 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) -def bgmv_shrink(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + 
output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + output_tensor[:, : outputs.shape[1]] = scaling * outputs[:] -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, + add_inputs, + ) -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:] else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py index e93064d0c83ad..f6397a68ddb81 100644 --- a/vllm/lora/ops/triton_ops/kernel_utils.py +++ b/vllm/lora/ops/triton_ops/kernel_utils.py @@ -3,23 +3,35 @@ """ Utilities for Punica kernel construction. 
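Editorial note (illustrative, not part of this patch): the torch_ops implementations earlier in this diff can be read as the reference semantics for the Triton kernels that follow. For example, the einsum "bi, boi -> bo" in bgmv_shrink computes, per token t,

    out[t] = scaling * (A[lora_indices_tensor[t]] @ inputs[t])

where A is the stacked LoRA A weight of shape [num_loras, lora_rank, hidden_size], i.e. the shrink half of LoRA with a per-token adapter selection.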
""" + from vllm.triton_utils import tl, triton @triton.jit -def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, - b_dtype: tl.constexpr): +def mm_k( + a_ptr, + b_ptr, + ak_stride, + bk_stride, + offset_k, + K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr, +): """ Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of B (k x n), iterate, through the K dimension to compute the partial/complete matrix block product. If SPLIT_K == 1, the output m x n product is complete. If SPLIT_K > 1, the thread block computes partial outputs. The partial - outputs are then atomically summed in the caller code. + outputs are then atomically summed in the caller code. Args: - a_ptr: Array of pointers, identifying rows of A + a_ptr: Array of pointers, identifying rows of A b_ptr: Array of pointers, identifying columns of B ak_stride: K dimension stride of the A matrix bk_stride: K dimension stride of the B matrix @@ -29,7 +41,7 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, BLOCK_K: K dimension atom EVEN_K: True if the blocks of A and B can be loaded without any masking. - SPLIT_K: Parameter signifying parallelism in the K dimension. + SPLIT_K: Parameter signifying parallelism in the K dimension. CAST_TYPE: if True, cast the values from the A matrix to the B matrix dtype. b_dtype: datatype of the B matrix @@ -40,14 +52,12 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, tiled_a = tl.load(a_ptr) tiled_b = tl.load(b_ptr) else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] - < K - k * (BLOCK_K * SPLIT_K), - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] - < K - k * (BLOCK_K * SPLIT_K), - other=0) + tiled_a = tl.load( + a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0 + ) + tiled_b = tl.load( + b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0 + ) if CAST_TYPE: tiled_a = tiled_a.to(b_dtype) accumulator += tl.dot( @@ -121,7 +131,8 @@ def do_expand_kernel( else: cur_input_ptr = input_ptr + slice_id * input_d0_stride cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) + tl.pointer_type(out_ptr.dtype.element_ty) + ) # Identify the column indices of B to process. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -129,17 +140,35 @@ def do_expand_kernel( # Identify A and B block pointers offset_k = tl.arange(0, BLOCK_K) - a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride) - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + - rbn[None, :] * cur_lora_d1_stride) + a_ptr = ( + cur_input_ptr + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride + ) + b_ptr = ( + cur_lora_ptr + + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride + ) # Compute the block matrix product. 
SPLIT_K = 1 - accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, - offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, - CAST_TYPE, cur_lora_ptr.dtype.element_ty) + accumulator = mm_k( + a_ptr, + b_ptr, + input_d2_stride, + cur_lora_d2_stride, + offset_k, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + CAST_TYPE, + cur_lora_ptr.dtype.element_ty, + ) tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) if SLICE_NUM == 1: @@ -150,10 +179,12 @@ def do_expand_kernel( # Identify the C output pointers to store the results of the accumulator. offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start offset_cm = tl.arange(0, BLOCK_M) - c_ptr = (out_ptr + ram[:, None] * output_d0_stride + - offset_cn[None, :] * output_d1_stride) - c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] - < (cur_slice_start + N)) + c_ptr = ( + out_ptr + + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride + ) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) @@ -207,7 +238,8 @@ def do_shrink_kernel( else: # current lora ptr cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(input_ptr.dtype.element_ty)) + tl.pointer_type(input_ptr.dtype.element_ty) + ) # Identify the column indices of B to process. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -215,24 +247,42 @@ def do_shrink_kernel( # Identify A and B block pointers offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) - a_ptr = (input_ptr + ram[:, None] * input_d0_stride + - offset_k[None, :] * input_d1_stride) - b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + - rbn[None, :] * lora_d1_stride + - offset_k[:, None] * lora_d2_stride) + a_ptr = ( + input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride + ) + b_ptr = ( + cur_lora_ptr + + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride + ) # Compute partial/complete block matrix product. - accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, - K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, - cur_lora_ptr.dtype.element_ty) + accumulator = mm_k( + a_ptr, + b_ptr, + input_d1_stride, + lora_d2_stride, + offset_k, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + False, + cur_lora_ptr.dtype.element_ty, + ) # Identify the C output pointers to store the results of the accumulator. 
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_cm = tl.arange(0, BLOCK_M) - cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + - slice_id * output_d0_stride) - c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ - None, :] * output_d2_stride + cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride + c_ptr = ( + cur_out_ptr + + ram[:, None] * output_d1_stride + + offset_cn[None, :] * output_d2_stride + ) c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) accumulator *= scaling diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index b1ab84e08ba76..a7a552b9903d5 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -11,42 +11,41 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @triton.jit def _lora_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - M, - N, - K, - token_indices_sorted_by_lora_ids, - num_tokens_per_lora, - lora_token_start_loc, - lora_ids, - slice_start_loc, - input_d0_stride, - input_d1_stride, - input_d2_stride, # 1 - ls_d0_ptr, - ls_d1_ptr, - ls_d2_ptr, # 1 - output_d0_stride, - output_d1_stride, # 1 - output_hs_ptr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, - SLICE_NUM: tl.constexpr, - SAME_STRIDE: tl.constexpr): - + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr, +): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) @@ -82,8 +81,9 @@ def _lora_expand_kernel( # Identify all rows that this CTA should process. lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) - cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + - lora_m_indices_start + cta_m_offset) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) # Load all relevant row indices. offset_m = tl.arange(0, BLOCK_M) % cta_m_len @@ -120,22 +120,21 @@ def _lora_expand_kernel( SLICE_NUM, EVEN_K, CAST_TYPE, - ADD_INPUTS) + ADD_INPUTS, + ) @torch.inference_mode() def _lora_expand( inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] - lora_b_weights: list[ - torch.Tensor], # shape [num_lora, hidden_size, lora_rank] - output_tensor: torch. 
- Tensor, # shape [num_tokens, hidden_size * num_slices] + lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices] token_lora_mapping: torch.Tensor, # shape [num_tokens] token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] - no_lora_flag_cpu: torch.Tensor, # shape [1] + no_lora_flag_cpu: torch.Tensor, # shape [1] offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -150,7 +149,7 @@ def _lora_expand( token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from the A matrix grouped by LoRA IDs. num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number - of tokens that are to be processed by LoRA ID lora_ids[i] + of tokens that are to be processed by LoRA ID lora_ids[i] lora_token_start_loc (torch.Tensor): A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] @@ -159,9 +158,9 @@ def _lora_expand( lora_ids (torch.Tensor): LoRA ids to process. no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates if there are any requests that require LoRA. - offset_start (int, optional): Offset start for output_tensor. + offset_start (int, optional): Offset start for output_tensor. Defaults to 0. - add_inputs (bool, optional): Whether to add the input tensor to the + add_inputs (bool, optional): Whether to add the input tensor to the output tensor. Defaults to False. """ @@ -180,15 +179,20 @@ def _lora_expand( # metadata sanity check. M = inputs.size(1) assert token_lora_mapping.size(0) == M - assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( - 0) + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 - (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, - same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, - inputs.device) + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device) K = lora_b_weights[0].shape[-1] # K= rank ADD_INPUTS = add_inputs @@ -207,8 +211,8 @@ def _lora_expand( EVEN_K = K % BLOCK_K == 0 # type: ignore if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True @@ -283,7 +287,6 @@ try: op_func=_lora_expand, mutates_args=["output_tensor"], fake_impl=_lora_expand_fake, - dispatch_key=current_platform.dispatch_key, ) lora_expand = torch.ops.vllm.lora_expand diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 39e647b9b88a4..df343305d710d 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -30,39 +30,35 @@ class LoRAKernelMeta: no_lora_flag_cpu: torch.Tensor @staticmethod - def make(max_loras: int, max_num_tokens: int, - device: Union[torch.device, str]) -> "LoRAKernelMeta": + def make( + max_loras: int, 
max_num_tokens: int, device: Union[torch.device, str] + ) -> "LoRAKernelMeta": + token_lora_mapping = torch.empty( + max_num_tokens, dtype=torch.int32, device=device + ) - token_lora_mapping = torch.empty(max_num_tokens, - dtype=torch.int32, - device=device) - - token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, - dtype=torch.int32, - device=device) + token_indices_sorted_by_lora_ids = torch.empty( + max_num_tokens, dtype=torch.int32, device=device + ) # +1 because "no-lora" is also a possibility # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] # is a possibility. - active_lora_ids = torch.empty(max_loras + 1, - dtype=torch.int32, - device=device) + active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device) # using running example, [3, 10, 5, 2] is a possibility. - num_tokens_per_lora = torch.zeros(max_loras + 1, - dtype=torch.int32, - device=device) + num_tokens_per_lora = torch.zeros( + max_loras + 1, dtype=torch.int32, device=device + ) # +2 for this because, the first index is always 0. # using running example, lora_token_start_loc # is [0, 3, 13, 18, 20]. - lora_token_start_loc = torch.zeros(max_loras + 2, - dtype=torch.int32, - device=device) + lora_token_start_loc = torch.zeros( + max_loras + 2, dtype=torch.int32, device=device + ) - no_lora_flag_cpu = torch.tensor([False], - dtype=torch.bool, - device='cpu') + no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu") return LoRAKernelMeta( token_lora_mapping=token_lora_mapping, @@ -70,7 +66,8 @@ class LoRAKernelMeta: active_lora_ids=active_lora_ids, num_tokens_per_lora=num_tokens_per_lora, lora_token_start_loc=lora_token_start_loc, - no_lora_flag_cpu=no_lora_flag_cpu) + no_lora_flag_cpu=no_lora_flag_cpu, + ) def _reset(self): self.active_lora_ids.fill_(-1) @@ -83,8 +80,8 @@ class LoRAKernelMeta: Prepare kernel metadata tensors for the current forward pass. Args: - token_lora_tensor (torch.Tensor): Tensor containing lora indices - for each input token. + token_lora_mapping (torch.Tensor): Tensor containing lora indices + for each input token. 
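        Example (editorial illustration, not part of this patch): for a
        hypothetical token_lora_mapping of [1, 0, 1, -1, 1], the stable sort
        yields token_indices_sorted_by_lora_ids = [3, 1, 0, 2, 4], the unique
        ids/counts give active_lora_ids = [-1, 0, 1, ...] and
        num_tokens_per_lora = [1, 1, 3, ...], and the cumulative sum gives
        lora_token_start_loc = [0, 1, 2, 5, ...].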
""" self._reset() @@ -100,34 +97,44 @@ class LoRAKernelMeta: num_tokens = token_lora_mapping.size(0) # copy token lora mapping - self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, - non_blocking=True) + self.token_lora_mapping[:num_tokens].copy_( + token_lora_mapping, non_blocking=True + ) # token_indices_sorted_by_lora_ids - _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, - stable=True) + _, token_indices_sorted_by_lora_ids = torch.sort( + token_lora_mapping, stable=True + ) # start gpu transfer self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( - token_indices_sorted_by_lora_ids, non_blocking=True) + token_indices_sorted_by_lora_ids, non_blocking=True + ) # active_lora_ids, num_tokens_per_lora - lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, - sorted=True, - return_counts=True) - self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, - non_blocking=True) - self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( - num_tokens_per_lora, non_blocking=True) + lora_ids, num_tokens_per_lora = torch.unique( + token_lora_mapping, sorted=True, return_counts=True + ) + self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True) + self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True + ) # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) - self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( - lora_token_start_loc, non_blocking=True) + self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True + ) def meta_args( self, token_nums: int - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: """ This function returns the kernel metadata required for the current forward pass execution of the kernel. The function returns all the @@ -136,7 +143,7 @@ class LoRAKernelMeta: Args: token_nums (int): Number of input tokens in the current forward - pass. + pass of the kernel. """ return ( self.token_lora_mapping[:token_nums], diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 1e7075ab07151..1e7e43e30de78 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ @@ -11,22 +11,38 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @triton.jit -def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, - token_indices_sorted_by_lora_ids, num_tokens_per_lora, - lora_token_start_loc, lora_ids, scaling, - input_d0_stride, input_d1_stride, lora_d0_stride, - lora_d1_stride, lora_d2_stride, output_d0_stride, - output_d1_stride, output_d2_stride, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): - +def _lora_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + input_d0_stride, + input_d1_stride, + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + output_d0_stride, + output_d1_stride, + output_d2_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) @@ -55,8 +71,9 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, # Identify all rows that this CTA should process. lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) - cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + - lora_m_indices_start + cta_m_offset) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) # Load all relevant row indices. offset_m = tl.arange(0, BLOCK_M) % cta_m_len @@ -91,17 +108,17 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, BLOCK_K, EVEN_K, SPLIT_K, - SLICE_NUM) + SLICE_NUM, + ) @torch.inference_mode() def _lora_shrink( inputs: torch.Tensor, # shape [num_tokens, hidden_size] - lora_a_weights: list[ - torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + lora_a_weights: list[torch.Tensor], # shape [num_loras, lora_rank, hidden_size] output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] token_lora_mapping: torch.Tensor, # shape [num_tokens] - token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] @@ -119,7 +136,7 @@ def _lora_shrink( token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from the A matrix grouped by LoRA IDs. num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number - of tokens that are to be processed by LoRA ID lora_ids[i] + of tokens that are to be processed by LoRA ID lora_ids[i] lora_token_start_loc (torch.Tensor): A cumulative sum of num_tokens_per_lora. 
lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] @@ -148,13 +165,13 @@ def _lora_shrink( # metadata sanity check M = inputs.size(0) assert token_lora_mapping.size(0) == M - assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( - 0) + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 - (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, - lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = ( + _get_lora_a_ptr(lora_a_weights, inputs.device) + ) N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank NUM_SLICES = len(lora_a_weights) MAX_LORAS = lora_ids.size(0) @@ -237,7 +254,6 @@ try: op_func=_lora_shrink, mutates_args=["output_tensor"], fake_impl=_lora_shrink_fake, - dispatch_key=current_platform.dispatch_key, ) lora_shrink = torch.ops.vllm.lora_shrink diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 4c50fbd270516..3a3e8fc8931e8 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -9,9 +9,9 @@ _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): """ - `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. - Refer to: + Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py """ key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) @@ -35,14 +35,15 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) if len(lora_a_weights) > 1: - lora_ptr_tensor = torch.tensor(tensor_ptrs, - device=device, - dtype=torch.uint64) + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) else: lora_ptr_tensor = lora_a_weights[0] - if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 - or len(set(lora_strides_d2)) > 1): + if ( + len(set(lora_strides_d0)) > 1 + or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1 + ): raise ValueError("All LoRA weights must have the same stride.") _LORA_A_PTR_DICT[key] = ( @@ -54,12 +55,13 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): return _LORA_A_PTR_DICT.get(key) -def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, - device: torch.device): - """ - `_LORA_B_PTR_DICT` collects the required information during `profile_run`, +def _get_lora_b_ptr( + lora_weights: list[torch.Tensor], offset_start: int, device: torch.device +): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. 
- Refer to: + Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py """ @@ -91,20 +93,21 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, if len(lora_weights) > 1: # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, - device=device, - dtype=torch.uint64) - slice_start_tensor = torch.tensor(slice_offset_lst, - device=device, - dtype=torch.uint64) + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) + slice_start_tensor = torch.tensor( + slice_offset_lst, device=device, dtype=torch.uint64 + ) else: slice_start_tensor = slice_offset_lst[0] lora_ptr_tensor = lora_b_weight[0] # If each lora has the same stride, there's no need to use a # tensor for storage. - if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and - len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + if ( + len(set(lora_strides_d0)) == 1 + and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1 + ) and len(set(hidden_sizes)) == 1: lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] @@ -119,8 +122,14 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, same_stride = False # MAX_N is the maximum hidden size among all the lora_b weights MAX_N = max(hidden_sizes) - _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, - lora_strides_d0_tensor, lora_strides_d1_tensor, - lora_strides_d2_tensor, hidden_sizes_tensor, - same_stride, MAX_N) + _LORA_B_PTR_DICT[key] = ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) return _LORA_B_PTR_DICT.get(key) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 7e7c3c892457a..b5570ceca68ca 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 9118f3351ef0a..4924890b388cb 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -33,8 +33,7 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape if len(loras.shape) == 4: loras = loras.squeeze(axis=1) @@ -73,13 +72,12 @@ def bgmv_expand( limit = 1 if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad(outputs, - (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) + outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) if add_inputs: - return output_tensor + outputs[:limit, :output_tensor.shape[1]] + return output_tensor + outputs[:limit, : output_tensor.shape[1]] else: - return outputs[:limit, :output_tensor.shape[1]] + return outputs[:limit, : output_tensor.shape[1]] def bgmv_shrink( @@ -93,14 +91,12 
@@ def bgmv_shrink( inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. lora_b_weights (torch.Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. - output_tensor (torch.Tensor): (Unused) output tensor (placeholder). lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. scaling (float, optional): Scalar multiplier applied to the output. """ - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, - lora_indices_tensor) + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice( diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 8b8e5cb7d5fae..8f21a2570224e 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -9,7 +9,7 @@ import os from dataclasses import MISSING, dataclass, field, fields from typing import Literal, Optional, Union -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -18,9 +18,9 @@ logger = init_logger(__name__) @dataclass class PEFTHelper: - """ + """ A helper class for PEFT configurations, specifically designed for LoRA. - This class handles configuration validation, compatibility checks for + This class handles configuration validation, compatibility checks for various LoRA implementations. """ @@ -29,7 +29,7 @@ class PEFTHelper: lora_alpha: int target_modules: Union[list[str], str] - bias: Literal["none", "all", "lora_only"] = field(default="none") + bias: Literal["none"] = field(default="none") modules_to_save: Optional[list[str]] = field(default=None) # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) use_rslora: bool = field(default=False) @@ -71,37 +71,38 @@ class PEFTHelper: # Identify any missing required fields missing_fields = required_fields - set(config_dict.keys()) if missing_fields: - raise ValueError( - f"Missing required configuration fields: {missing_fields}") + raise ValueError(f"Missing required configuration fields: {missing_fields}") # Filter out fields that aren't defined in the class - filtered_dict = { - k: v - for k, v in config_dict.items() if k in class_fields - } + filtered_dict = {k: v for k, v in config_dict.items() if k in class_fields} return cls(**filtered_dict) @classmethod def from_local_dir( - cls, - lora_path: str, - max_position_embeddings: Optional[int], - tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper": + cls, + lora_path: str, + max_position_embeddings: Optional[int], + tensorizer_config_dict: Optional[dict] = None, + ) -> "PEFTHelper": lora_config_path = os.path.join(lora_path, "adapter_config.json") if tensorizer_config_dict: tensorizer_config = TensorizerConfig(**tensorizer_config_dict) tensorizer_args = tensorizer_config._construct_tensorizer_args() from tensorizer.stream_io import open_stream - lora_config_path = os.path.join(tensorizer_config.tensorizer_dir, - "adapter_config.json") - with open_stream(lora_config_path, - mode="rb", - **tensorizer_args.stream_kwargs) as f: + + lora_config_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_config.json" + ) + with open_stream( + lora_config_path, mode="rb", **tensorizer_args.stream_kwargs + ) as f: config = json.load(f) - logger.info("Successfully deserialized LoRA config from %s", - tensorizer_config.tensorizer_dir) + logger.info( + "Successfully deserialized LoRA config from %s", + 
tensorizer_config.tensorizer_dir, + ) else: with open(lora_config_path) as f: @@ -112,16 +113,16 @@ class PEFTHelper: def validate_legal(self, lora_config: LoRAConfig) -> None: """ - Validates the LoRA configuration settings against application + Validates the LoRA configuration settings against application constraints and requirements. """ error_msg = self._validate_features() if self.r > lora_config.max_lora_rank: error_msg.append( f"LoRA rank {self.r} is greater than max_lora_rank" - f" {lora_config.max_lora_rank}.") - if self.bias != "none" and not lora_config.bias_enabled: - error_msg.append( - "Adapter bias cannot be used without bias_enabled.") + f" {lora_config.max_lora_rank}." + ) + if self.bias != "none": + error_msg.append("Adapter bias is not supported.") if error_msg: raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b3413de1c8163..b803a482b1bca 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -60,14 +60,13 @@ class PunicaWrapperABC(ABC): y: torch.Tensor, x: Union[tuple[torch.Tensor, ...], torch.Tensor], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> Optional[torch.Tensor]: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. """ raise NotImplementedError @@ -81,39 +80,42 @@ class PunicaWrapperABC(ABC): **kwargs, ) -> Optional[torch.Tensor]: """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA, + Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. """ raise NotImplementedError @abstractmethod - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ - Applicable to linear-related lora. + Applicable to linear-related lora. """ raise NotImplementedError @abstractmethod - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. 
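        Editorial note (illustrative, not part of this patch): conceptually this
        is a shrink with lora_a_stacked followed by an expand with
        lora_b_stacked on the rows selected by sampler_indices, roughly

            y += ((x @ A.T) * scale) @ B.T

        using the [lora_rank, hidden_size] and [hidden_size, lora_rank] layouts
        documented for the shrink/expand ops elsewhere in this diff.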
""" @@ -122,41 +124,41 @@ class PunicaWrapperABC(ABC): class PunicaWrapperBase(PunicaWrapperABC): """ - PunicaWrapperBase is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica. """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + self._token_lora_indices = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._sampler_indices = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._sampler_indices_padded = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._embeddings_indices = torch.empty( + 2, max_num_batched_tokens, dtype=torch.long, device=device + ) # 4 is the number of indices tensors. # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: list[Optional[int]] = [None] * 4 # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) + self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) + self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) + self._lora_indices_per_batch = torch.empty( + max_batches, dtype=torch.long, device=device + ) self.device: torch.device = device self.max_length: int = 0 self.token_nums: int = 0 @@ -186,89 +188,66 @@ class PunicaWrapperBase(PunicaWrapperABC): extra_vocab_size, self.device, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. 
- shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded + ) + self._embeddings_indices[ + : embeddings_indices.shape[0], : embeddings_indices.shape[1] + ].copy_(embeddings_indices) self.indices_len[:] = indices_len - def _update_prefill_metadata(self, - token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: + ( + b_seq_start_tensor, + seq_length_tensor, + lora_indices_tensor, + batch_size, + max_length, + token_nums, + no_lora, + ) = compute_meta(token_lora_tensor) - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) + self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor) + self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor + ) self.batch_size = batch_size self.max_length = max_length self.token_nums = token_nums self.no_lora = no_lora - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - @property def prefill_metadata( - self + self, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ - This property provides a convenient way to access the necessary + This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. 1. seq_start_locs: Tensor of sequence start positions. 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of + 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. 4. batch_size: Batch size after clustering identical lora indices. 5. max_length: The maximum sequence length in the batch. 6. token_nums: The token numbers in the batch. 
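As an aside from the diff itself, the prefill metadata enumerated above comes from grouping consecutive tokens that share a LoRA index. A minimal sketch of that grouping, mirroring what compute_meta in vllm/lora/punica_wrapper/utils.py does with torch.unique_consecutive (the tensor values here are made up for illustration):

```python
import torch

# Hypothetical per-token LoRA ids for one prefill batch; -1 means "no LoRA".
token_lora_tensor = torch.tensor([0, 0, 0, 1, 1, -1, -1, -1, -1])

lora_indices, seq_lengths = torch.unique_consecutive(
    token_lora_tensor, return_counts=True
)
seq_start_locs = torch.zeros_like(seq_lengths)
seq_start_locs[1:].copy_(torch.cumsum(seq_lengths, dim=0)[:-1])

print(lora_indices)    # tensor([ 0,  1, -1]) -> lora_indices_per_batch
print(seq_lengths)     # tensor([3, 2, 4])    -> seq_lengths
print(seq_start_locs)  # tensor([0, 3, 5])    -> seq_start_locs
# batch_size = 3, max_length = 4, token_nums = 9
```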
""" - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) + return ( + self._seq_start_locs[: self.batch_size], + self._seq_lengths[: self.batch_size], + self._lora_indices_per_batch[: self.batch_size], + self.batch_size, + self.max_length, + self.token_nums, + ) @property def token_lora_indices(self) -> torch.Tensor: """ - This property provides the lora indices corresponding to each token + This property provides the lora indices corresponding to each token in the batch. An index of -1 means no lora should be applied. """ token_lora_len = self.indices_len[0] @@ -276,8 +255,8 @@ class PunicaWrapperBase(PunicaWrapperABC): @property def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for + """ + This property is used to access the lora indices specifically for LogitsProcessorWithLoRA. """ sampler_indices_len = self.indices_len[1] @@ -294,18 +273,24 @@ class PunicaWrapperBase(PunicaWrapperABC): @property def embeddings_indices(self) -> torch.Tensor: """ - This property provides access to the indices used for lora embeddings, + This property provides access to the indices used for lora embeddings, specifically for VocabParallelEmbeddingWithLoRA. """ embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] - def update_metadata(self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) if mapping.is_prefill: # Update metadata required for prefill-related operators. @@ -315,16 +300,21 @@ class PunicaWrapperBase(PunicaWrapperABC): self.is_prefill = False @abstractmethod - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, **kwargs) -> Optional[torch.Tensor]: + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. 
Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor @@ -336,32 +326,30 @@ class PunicaWrapperBase(PunicaWrapperABC): raise NotImplementedError @abstractmethod - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> Optional[torch.Tensor]: + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> Optional[torch.Tensor]: """ - Performs GEMM and bias addition for multiple slices of lora_b. - + Performs GEMM for multiple slices of lora_b. + Semantics: offset = offset_start for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. @@ -371,12 +359,14 @@ class PunicaWrapperBase(PunicaWrapperABC): raise NotImplementedError @abstractmethod - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -393,19 +383,20 @@ class PunicaWrapperBase(PunicaWrapperABC): raise NotImplementedError @abstractmethod - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ - Applicable to linear-related lora. + Applicable to linear-related lora. Semantics: for i in range(len(lora_a_stacked)): @@ -414,14 +405,13 @@ class PunicaWrapperBase(PunicaWrapperABC): @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. 
- lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. @@ -430,18 +420,20 @@ class PunicaWrapperBase(PunicaWrapperABC): raise NotImplementedError @abstractmethod - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 59049cccc8cbe..93e64eb6ba843 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -5,9 +5,14 @@ from typing import Callable, Optional, Union import torch -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.torch_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) from .punica_base import PunicaWrapperBase @@ -16,15 +21,19 @@ from .punica_base import PunicaWrapperBase # inherit this class class PunicaWrapperCPU(PunicaWrapperBase): """ - PunicaWrapperCPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. 
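To make the bias-free semantics spelled out in the docstrings above concrete, here is a small, self-contained PyTorch sketch (not part of the diff) of the shrink-then-expand computation that add_lora_linear describes; the shapes, slice sizes, and names are illustrative only:

```python
import torch

torch.manual_seed(0)
num_tokens, hidden, rank = 4, 16, 8
output_slices = (16, 16)  # e.g. two slices of a merged column-parallel layer
scale = 1.0

x = torch.randn(num_tokens, hidden)
lora_a = [torch.randn(hidden, rank) for _ in output_slices]
lora_b = [torch.randn(rank, size) for size in output_slices]
y = torch.zeros(num_tokens, sum(output_slices))

# Shrink: project the input down to the LoRA rank, one buffer per slice.
buffer = [(x @ a) * scale for a in lora_a]

# Expand: project each buffer back up and add it into its output slice.
offset = 0
for buf, b, size in zip(buffer, lora_b, output_slices):
    y[:, offset:offset + size] += buf @ b
    offset += size
```

The only behavioral difference this PR introduces in that computation is that no `+ lora_bias_stacked[i]` term is added during the expand step.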
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) def _shrink_prefill( self, @@ -33,7 +42,7 @@ class PunicaWrapperCPU(PunicaWrapperBase): w_t_all: torch.Tensor, scale: float, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_shrink( @@ -60,7 +69,7 @@ class PunicaWrapperCPU(PunicaWrapperBase): w_t_all: torch.Tensor, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand( @@ -89,7 +98,7 @@ class PunicaWrapperCPU(PunicaWrapperBase): y_slice_size: int, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand_slice( @@ -111,8 +120,9 @@ class PunicaWrapperCPU(PunicaWrapperBase): y_slice_size: int, add_inputs: bool, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + bgmv_expand_slice( + x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs + ) def _apply_expand( self, @@ -124,18 +134,19 @@ class PunicaWrapperCPU(PunicaWrapperBase): add_inputs: bool = True, ): """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` computation, which is suitable for the GEMM of lora'b. """ - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) + expand_slice_fun: Callable = ( + self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode + ) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): + def _apply_shrink( + self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float + ): """ Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. @@ -146,25 +157,31 @@ class PunicaWrapperCPU(PunicaWrapperBase): """ y_org = y y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) + shrink_fun: Callable = ( + self._shrink_prefill if self.is_prefill else self._shrink_decode + ) shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, **kwargs): + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the prefill stage, and the `_shrink_prefill` function should be called. Otherwise, it is the decode stage, and the _shrink_decode function should be called. 
- + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor @@ -175,43 +192,37 @@ class PunicaWrapperCPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. - + Performs GEMM for multiple slices of lora_b. + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -224,12 +235,14 @@ class PunicaWrapperCPU(PunicaWrapperBase): offset_left += output_slices[slice_idx] y = y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -244,23 +257,25 @@ class PunicaWrapperCPU(PunicaWrapperBase): """ # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) + expand_fun: Callable = ( + self._expand_prefill if self.is_prefill else self._expand_decode + ) expand_fun(y, x, lora_b_stacked, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> None: """ - Applicable to linear-related lora. + Applicable to linear-related lora. 
Semantics: for i in range(len(lora_a_stacked)): @@ -269,54 +284,47 @@ class PunicaWrapperCPU(PunicaWrapperBase): @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices)) + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + self.add_expand( + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs + ) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -336,14 +344,8 @@ class PunicaWrapperCPU(PunicaWrapperBase): if buffer is None: # We set the buffer to be float32 by default, consistent with the # triton op - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) # LogitsProcessorWithLoRA always using bgmv. bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + bgmv_expand(buffer, lora_b_stacked, y, self.sampler_indices, add_inputs=True) y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 2db0e9fee1420..8173fe99ea13d 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ @@ -11,13 +11,11 @@ from typing import Optional, Union, final import torch -import vllm.envs as envs from vllm.lora.layers import LoRAMapping from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, - lora_shrink) + from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink from .punica_base import PunicaWrapperBase @@ -25,54 +23,63 @@ from .punica_base import PunicaWrapperBase @final class PunicaWrapperGPU(PunicaWrapperBase): """ - PunicaWrapperGPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica triton kernel. """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - self.max_loras = kwargs['max_loras'] + self.max_loras = kwargs["max_loras"] - self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, - max_num_batched_tokens, - device=device) + self.token_mapping_meta = LoRAKernelMeta.make( + self.max_loras, max_num_batched_tokens, device=device + ) - # When cudagraph capture size is greater than max_num_seqs (max_batches, - # here), V0 captures the graph as if max_num_seqs is set to - # the capture size. - # V1 doesn't have this problem and always respects max_num_seqs. - max_num_prompts = (max_batches - if envs.VLLM_USE_V1 else max_num_batched_tokens) - self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, - max_num_prompts, - device=device) - - def update_metadata(self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): + self.prompt_mapping_meta = LoRAKernelMeta.make( + self.max_loras, max_batches, device=device + ) + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): self.is_prefill = mapping.is_prefill - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) - def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, - ...], scale: float, **kwargs): + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. 
- + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor @@ -89,41 +96,34 @@ class PunicaWrapperGPU(PunicaWrapperBase): scale, ) - def add_expand(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. - + Performs GEMM for multiple slices of lora_b. + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, - y.size(0)) - self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -140,12 +140,14 @@ class PunicaWrapperGPU(PunicaWrapperBase): y = y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -161,26 +163,27 @@ class PunicaWrapperGPU(PunicaWrapperBase): lora_expand( x.unsqueeze(dim=0), - (lora_b_stacked, ), + (lora_b_stacked,), y, *self.token_mapping_meta.meta_args(x.size(0)), offset_start=0, add_inputs=add_inputs, ) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ - Applicable to linear-related lora. + Applicable to linear-related lora. Semantics: for i in range(len(lora_a_stacked)): @@ -189,26 +192,18 @@ class PunicaWrapperGPU(PunicaWrapperBase): @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] - + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. 
- lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, - y.size(0)) - y = self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -224,28 +219,31 @@ class PunicaWrapperGPU(PunicaWrapperBase): x, lora_a_stacked, scale, - **kwargs) + **kwargs, + ) self.add_expand( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, - **kwargs) + **kwargs, + ) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -265,15 +263,21 @@ class PunicaWrapperGPU(PunicaWrapperBase): if buffer is None: # We set the buffer to be float32 by default, refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), - *self.prompt_mapping_meta.meta_args(x.size(0)), scale) + lora_shrink( + x, + [lora_a_stacked], + buffer.unsqueeze(dim=0), + *self.prompt_mapping_meta.meta_args(x.size(0)), + scale, + ) - lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], - y, - *self.prompt_mapping_meta.meta_args(buffer.size(0)), - add_inputs=True) + lora_expand( + buffer.unsqueeze(dim=0), + [lora_b_stacked], + y, + *self.prompt_mapping_meta.meta_args(buffer.size(0)), + add_inputs=True, + ) y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index c684ac77cc9ca..c017721803fe3 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -14,7 +14,8 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: punica_wrapper_qualname = current_platform.get_punica_wrapper() punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) punica_wrapper = punica_wrapper_cls(*args, **kwargs) - assert punica_wrapper is not None, \ + assert punica_wrapper is not None, ( "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." 
+ ) logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1]) return punica_wrapper diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 07dc337a1cc87..dff30d5d2a2d1 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union import torch import torch.nn.functional as F -import torch_xla.core.xla_model as xm +import torch_xla from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink from vllm.lora.punica_wrapper.utils import convert_mapping @@ -25,27 +25,29 @@ class PunicaWrapperTPU(PunicaWrapperBase): Multi-LoRA, and to provide the interface for the pytorch punica ops. """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) # PunicaWrapperBase defines some tensors with dtype=torch.int64, which # isn't supported by the TPU. So convert those tensors to int32. # Not all of them are used by the TPU so only convert the useful ones. - self._token_lora_indices = self._token_lora_indices.to( - dtype=torch.int32) + self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) self._sampler_indices_padded = self._sampler_indices_padded.to( - dtype=torch.int32) + dtype=torch.int32 + ) torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, - True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, - True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) @@ -77,21 +79,38 @@ class PunicaWrapperTPU(PunicaWrapperBase): ): return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) - def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - add_inputs: bool): - return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), - add_inputs) + def expand( + self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool + ): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs) - def expand_slice(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_inputs: bool) -> torch.Tensor: - return bgmv_expand_slice(x, w_t_all, y, - self._get_token_lora_indices(x), y_offset, - y_slice_size, add_inputs) + def expand_slice( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ) -> torch.Tensor: + return bgmv_expand_slice( + x, + w_t_all, + y, + self._get_token_lora_indices(x), + y_offset, + y_slice_size, + add_inputs, + ) - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: 
float, **kwargs) -> Optional[torch.Tensor]: + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -115,31 +134,29 @@ class PunicaWrapperTPU(PunicaWrapperBase): y[slice_idx, :, :] = y_s # type: ignore[index] return y - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> torch.Tensor: + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> torch.Tensor: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ @@ -147,25 +164,26 @@ class PunicaWrapperTPU(PunicaWrapperBase): y = y.view(-1, y.shape[-1]) offset_left = 0 - if lora_bias_stacked is not None: - y = self._apply_bias(self._get_token_lora_indices(y), y, - output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice(y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs) + y = self.expand_slice( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) offset_left += output_slices[slice_idx] return y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> torch.Tensor: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> torch.Tensor: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -182,17 +200,18 @@ class PunicaWrapperTPU(PunicaWrapperBase): # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> torch.Tensor: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> torch.Tensor: """ Applicable to linear-related lora. 
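The TPU wrapper above routes everything through the bgmv_* ops, which pick a LoRA per token via token_lora_indices, with -1 meaning the token gets no LoRA. A naive reference loop for that per-token gather is sketched below; the stacked-weight layout is assumed purely for illustration and is not taken from the kernels:

```python
import torch

num_tokens, hidden, rank, max_loras = 5, 16, 8, 3
x = torch.randn(num_tokens, hidden)
# Assumed layout: one (rank, hidden) lora_a matrix per adapter slot.
lora_a_stacked = torch.randn(max_loras, rank, hidden)
token_lora_indices = torch.tensor([0, 0, 2, -1, 1])  # -1 -> no LoRA
y = torch.zeros(num_tokens, rank)

for i in range(num_tokens):
    idx = int(token_lora_indices[i])
    if idx == -1:
        continue  # leave tokens without an adapter untouched
    y[i] += x[i] @ lora_a_stacked[idx].T
```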
@@ -203,24 +222,19 @@ class PunicaWrapperTPU(PunicaWrapperBase): @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self._get_token_lora_indices(y), y, - output_slices, lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -231,23 +245,21 @@ class PunicaWrapperTPU(PunicaWrapperBase): device=x.device, ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - return self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + return self.add_expand( + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs + ) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: """ Applies lora specifically for LogitsProcessorWithLoRA. 
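As documented in the Semantics blocks earlier in this diff, the logits path is a two-step matmul driven by the sampler indices rather than the per-token indices. A dense, index-free sketch of that computation with made-up shapes (illustrative only):

```python
import torch

num_seqs, hidden, rank, vocab = 3, 16, 8, 32
x = torch.randn(num_seqs, hidden)   # last hidden states fed to the LM head
lora_a = torch.randn(hidden, rank)
lora_b = torch.randn(rank, vocab)
y = torch.zeros(num_seqs, vocab)    # logits accumulator
scale = 1.0

buffer = (x @ lora_a) * scale       # shrink
y += buffer @ lora_b                # expand, added onto the base logits
```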
@@ -269,49 +281,9 @@ class PunicaWrapperTPU(PunicaWrapperBase): sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) - y = bgmv_expand(buffer, - lora_b_stacked, - y, - sampler_indices, - add_inputs=True) + y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias = torch.where(indices[:, None] == -1, 0, bias) - - bias = F.pad(bias, (offset_left, output.shape[1] - - (offset_left + slice), 0, 0)) - - output += bias - offset_left += slice - - return output.view_as(org_output) - # This performs the same tensor ops as the base method, except it does them # on the CPU then transfers the results to the TPU def _update_base_metadata( @@ -323,13 +295,12 @@ class PunicaWrapperTPU(PunicaWrapperBase): extra_vocab_size: int, ): # Make sure we don't accidentally collect outside operations - xm.mark_step() + torch_xla.sync() # Pad the prompt mapping to avoid running into recompiles on the TPU # TODO: Should this happen inside mapping internally? If so how can we # avoid having backend specific LoRAMapping classes? - mapping.prompt_mapping = self._pad_prompt_mapping( - mapping.prompt_mapping) + mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping) ( base_indices, @@ -346,35 +317,33 @@ class PunicaWrapperTPU(PunicaWrapperBase): "cpu", ) self._token_lora_indices = self._pad_to_shape( - base_indices, self._token_lora_indices.shape, - dims=1).to(self.device) - self._sampler_indices = self._pad_to_shape(sampler_indices, - self._sampler_indices.shape, - dims=1).to(self.device) + base_indices, self._token_lora_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices = self._pad_to_shape( + sampler_indices, self._sampler_indices.shape, dims=1 + ).to(self.device) self._sampler_indices_padded = self._pad_to_shape( - sampler_indices_padded, self._sampler_indices_padded.shape, - dims=1).to(self.device) + sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 + ).to(self.device) self._embeddings_indices = self._pad_to_shape( - embeddings_indices, self._embeddings_indices.shape, - dims=2).to(self.device) + embeddings_indices, self._embeddings_indices.shape, dims=2 + ).to(self.device) self.indices_len[:] = indices_len - def _update_prefill_metadata(self, - token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self. - batch_size] = token_lora_tensor[:self. 
- batch_size] + self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[ + : self.batch_size + ] - def _pad_prompt_mapping( - self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: + def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: num_reqs = len(prompt_mapping) # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular # import MIN_NUM_SEQS = 8 - padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) pad_len = padded_num_reqs - num_reqs padding = [-1] * pad_len @@ -387,5 +356,4 @@ class PunicaWrapperTPU(PunicaWrapperBase): else: pad_rows = target_shape[0] - src.shape[0] pad_cols = target_shape[1] - src.shape[1] - return F.pad(src, (0, pad_cols, 0, pad_rows), - value=0).to(torch.int32) + return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 572e39e0eced0..e3d03ac8dc2c2 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -21,25 +21,35 @@ from .punica_base import PunicaWrapperBase class PunicaWrapperXPU(PunicaWrapperBase): """ PunicaWrapperXPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica ipex kernel. 
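One detail worth illustrating from the TPU hunks just above: _pad_prompt_mapping pads the number of requests up to the next power of two, with a floor of MIN_NUM_SEQS = 8, and fills the tail with -1 so the TPU avoids recompiles. A quick standalone check of that rounding (the sample sizes are arbitrary):

```python
import math

MIN_NUM_SEQS = 8

def padded_len(num_reqs: int) -> int:
    # Same rounding rule as _pad_prompt_mapping.
    return max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)

for n in (1, 3, 8, 9, 20):
    pad = padded_len(n) - n
    print(n, "->", padded_len(n), "with", pad, "entries of -1 appended")
# 1 -> 8, 3 -> 8, 8 -> 8, 9 -> 16, 20 -> 32
```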
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) - def update_metadata(self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): - + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): self.is_prefill = mapping.is_prefill - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) @@ -63,19 +73,25 @@ class PunicaWrapperXPU(PunicaWrapperBase): add_inputs: bool, ): token_lora_indices = self._get_token_lora_indices(x) - bgmv_expand_slice(x, w_t_all, y, token_lora_indices, y_offset, - y_slice_size, add_inputs) + bgmv_expand_slice( + x, w_t_all, y, token_lora_indices, y_offset, y_slice_size, add_inputs + ) - def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, - ...], scale: float, **kwargs): + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. - + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor @@ -85,43 +101,36 @@ class PunicaWrapperXPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) - def add_expand(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. - + Performs GEMM for multiple slices of lora_b. + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. 
""" y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = self._get_token_lora_indices(y) - self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -139,12 +148,14 @@ class PunicaWrapperXPU(PunicaWrapperBase): offset_start += output_slices[slice_idx] y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -160,17 +171,18 @@ class PunicaWrapperXPU(PunicaWrapperBase): token_lora_indices = self._get_token_lora_indices(x) bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applicable to linear-related lora. @@ -181,25 +193,19 @@ class PunicaWrapperXPU(PunicaWrapperBase): @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = self._get_token_lora_indices(y) - y = self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -215,28 +221,38 @@ class PunicaWrapperXPU(PunicaWrapperBase): x, lora_a_stacked, scale, - **kwargs) + **kwargs, + ) self.add_expand( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, - **kwargs) + **kwargs, + ) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
- + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -256,14 +272,8 @@ class PunicaWrapperXPU(PunicaWrapperBase): if buffer is None: # We set the buffer to be float32 by default, refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale) + bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index d22c29da1c615..90d1614e674db 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: def compute_meta( - token_lora_tensor: torch.Tensor + token_lora_tensor: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: @@ -23,7 +23,8 @@ def compute_meta( """ lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) + token_lora_tensor, return_counts=True + ) cum_result = torch.cumsum(seq_length_tensor, dim=0) b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) @@ -36,8 +37,15 @@ def compute_meta( # does not need to launch the triton kernel, which can improve performance if batch_size == 1 and lora_indices_tensor == -1: no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) + return ( + b_seq_start_tensor, + seq_length_tensor, + lora_indices_tensor, + batch_size, + max_length, + token_nums, + no_lora, + ) # TODO see if this can be vectorized @@ -83,14 +91,16 @@ def convert_mapping( lora_indices = index_mapping_indices.copy() prompt_mapping: list[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping + lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None for i in range(len(index_mapping_indices)): # TODO index can be slow. 
optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) + lora_idx = ( + lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 + else -1 + ) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx @@ -101,23 +111,27 @@ def convert_mapping( ] indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, - embeddings_indices) + prompt_mapping_tensor = torch.tensor( + prompt_mapping, dtype=torch.long, device=device + ) + embeddings_indices = torch.stack( + [ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ] + ) + embeddings_indices = torch.where( + embeddings_indices == -1, max_loras - 1, embeddings_indices + ) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded = torch.where(sampler_indices_padded == -1, - max_loras - 1, sampler_indices_padded) + sampler_indices_padded = torch.where( + sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded + ) sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) + 0, len(sampler_indices_padded), device=device, dtype=torch.long + ) + (sampler_indices_padded * len(sampler_indices_padded)) # Contain length of indices tensors. Used to index into each tensor. indices_len = [ diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5bbba7830c1b1..650e060a5804d 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -6,13 +6,12 @@ from typing import Optional import msgspec -from vllm.adapter_commons.request import AdapterRequest - class LoRARequest( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, +): # type: ignore[call-arg] """ Request for a LoRA adapter. @@ -24,7 +23,6 @@ class LoRARequest( lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ - __metaclass__ = AdapterRequest lora_name: str lora_int_id: int @@ -35,13 +33,16 @@ class LoRARequest( tensorizer_config_dict: Optional[dict] = None def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError(f"id must be > 0, got {self.lora_int_id}") if self.lora_local_path: warnings.warn( "The 'lora_local_path' attribute is deprecated " "and will be removed in a future version. " "Please use 'lora_path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) if not self.lora_path: self.lora_path = self.lora_local_path or "" @@ -67,7 +68,8 @@ class LoRARequest( "and will be removed in a future version. " "Please use 'path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return self.lora_path @local_path.setter @@ -77,7 +79,8 @@ class LoRARequest( "and will be removed in a future version. " "Please use 'path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) self.lora_path = value def __eq__(self, value: object) -> bool: @@ -86,8 +89,7 @@ class LoRARequest( instances based on lora_name. 
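Outside the diff, a short sketch of how LoRARequest behaves after this change: the new __post_init__ check rejects non-positive ids, and equality keys on lora_name. The field names match the hunks above; the specific paths and ids are illustrative:

```python
from vllm.lora.request import LoRARequest

a = LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path="/tmp/sql_adapter")
b = LoRARequest(lora_name="sql_adapter", lora_int_id=2, lora_path="/tmp/another_copy")
assert a == b  # __eq__ compares lora_name only

try:
    LoRARequest(lora_name="bad", lora_int_id=0, lora_path="/tmp/bad")
except ValueError as err:
    print(err)  # id must be > 0, got 0
```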
This allows for identification and comparison lora adapter across engines. """ - return isinstance(value, - self.__class__) and self.lora_name == value.lora_name + return isinstance(value, self.__class__) and self.lora_name == value.lora_name def __hash__(self) -> int: """ diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py index 5808ae105e864..d366b94521cd8 100644 --- a/vllm/lora/resolver.py +++ b/vllm/lora/resolver.py @@ -22,8 +22,9 @@ class LoRAResolver(ABC): """ @abstractmethod - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: """Abstract method to resolve and fetch a LoRA model adapter. Implements logic to locate and download LoRA adapter based on the name. @@ -61,8 +62,10 @@ class _LoRAResolverRegistry: if resolver_name in self.resolvers: logger.warning( "LoRA resolver %s is already registered, and will be " - "overwritten by the new resolver instance %s.", resolver_name, - resolver) + "overwritten by the new resolver instance %s.", + resolver_name, + resolver, + ) self.resolvers[resolver_name] = resolver @@ -78,7 +81,8 @@ class _LoRAResolverRegistry: if resolver_name not in self.resolvers: raise KeyError( f"LoRA resolver '{resolver_name}' not found. " - f"Available resolvers: {list(self.resolvers.keys())}") + f"Available resolvers: {list(self.resolvers.keys())}" + ) return self.resolvers[resolver_name] diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ab0a9fbd255de..595c774e03be3 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -6,37 +6,40 @@ from typing import TYPE_CHECKING, Optional, Union import huggingface_hub import regex as re -from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, RepositoryNotFoundError) +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + HFValidationError, + RepositoryNotFoundError, +) from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) -# being imported for _all_lora_classes below -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LogitsProcessorWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) -from vllm.model_executor.layers.linear import LinearBase -# yapf: enable +# being imported for _all_lora_classes below +from vllm.lora.layers import ( + BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, +) +from vllm.model_executor.layers.linear import LinearBase if TYPE_CHECKING: from 
vllm.model_executor.layers.logits_processor import LogitsProcessor - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -58,20 +61,23 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { } -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig] = None) -> nn.Module: +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig] = None, +) -> nn.Module: for lora_cls in _all_lora_classes: # specifying kwargs so they can be easily accessed in decorator - if lora_cls.can_replace_layer(source_layer=layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config): + if lora_cls.can_replace_layer( + source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + ): instance_layer = lora_cls(layer) - instance_layer.create_lora_weights(max_loras, lora_config, - model_config) + instance_layer.create_lora_weights(max_loras, lora_config, model_config) return instance_layer return layer @@ -83,15 +89,20 @@ def from_layer_logits_processor( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> LogitsProcessorWithLoRA: - ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device, - lm_head.get_sharded_to_full_mapping()) + ret = LogitsProcessorWithLoRA( + layer, + lm_head.embedding_dim, + lm_head.weight.dtype, + lm_head.weight.device, + lm_head.get_sharded_to_full_mapping(), + ) ret.create_lora_weights(max_loras, lora_config, model_config) return ret -def replace_submodule(model: nn.Module, module_name: str, - new_module: nn.Module) -> nn.Module: +def replace_submodule( + model: nn.Module, module_name: str, new_module: nn.Module +) -> nn.Module: """Replace a submodule in a model with a new module.""" parent = model.get_submodule(".".join(module_name.split(".")[:-1])) target_name = module_name.split(".")[-1] @@ -100,9 +111,8 @@ def replace_submodule(model: nn.Module, module_name: str, def parse_fine_tuned_lora_name( - name: str, - weights_mapper: Optional["WeightsMapper"] = None -) -> tuple[str, bool, bool]: + name: str, weights_mapper: Optional["WeightsMapper"] = None +) -> tuple[str, bool]: """Parse the name of lora weights. args: @@ -114,7 +124,6 @@ def parse_fine_tuned_lora_name( tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. - is_bias whether the tensor is lora bias. 
""" # LoRA weight qualified name usually starts with `base_model.model.`, @@ -134,28 +143,24 @@ def parse_fine_tuned_lora_name( start_index = 2 if name.startswith("base_model.model.") else 0 parts = name.split(".") - if parts[-1] == "weight" and (parts[-2] == "lora_A" - or parts[-2] == "lora_B"): + if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): new_name = ".".join(parts[start_index:-2]) - return new_name, parts[-2] == "lora_A", False + return new_name, parts[-2] == "lora_A" if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": new_name = ".".join(parts[start_index:-1]) - return new_name, parts[-1] == "lora_embedding_A", False - - if parts[-1] == "bias": - new_name = ".".join(parts[start_index:-2]) - return new_name, False, True + return new_name, parts[-1] == "lora_embedding_A" raise ValueError(f"{name} is unsupported LoRA weight") -def is_regex_target_modules(load_modules: Union[str, list[str]], - expected_lora_modules: list[str]) -> bool: +def is_regex_target_modules( + load_modules: Union[str, list[str]], expected_lora_modules: list[str] +) -> bool: """ - PEFT supports passing `target_modules` in the form of regular expressions, - such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to - determine whether the suffix in the regular expression is present in the + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the `expected_lora_modules`. """ @@ -197,7 +202,7 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: supported_lora_modules.add(name) # get all the linear subfixes. - if isinstance(module, (LinearBase, )): + if isinstance(module, (LinearBase,)): supported_lora_modules.add(name.split(".")[-1]) return list(supported_lora_modules) @@ -225,7 +230,7 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path # If the path starts with ~, expand the user home directory. - if lora_path.startswith('~'): + if lora_path.startswith("~"): return os.path.expanduser(lora_path) # Check if the expanded relative path exists locally. @@ -234,12 +239,15 @@ def get_adapter_absolute_path(lora_path: str) -> str: # If the path does not exist locally, assume it's a Hugging Face repo. 
try: - local_snapshot_path = huggingface_hub.snapshot_download( - repo_id=lora_path) - except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, - HFValidationError): + local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path) + except ( + HfHubHTTPError, + RepositoryNotFoundError, + EntryNotFoundError, + HFValidationError, + ): # Handle errors that may occur during the download - # Return original path instead instead of throwing error here + # Return original path instead of throwing error here logger.exception("Error downloading the HuggingFace model") return lora_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 248d2954f1ef4..3ca819fb732cf 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -6,15 +6,14 @@ from typing import Any, Literal, Optional, Union import torch -from vllm.adapter_commons.utils import (add_adapter_worker, - apply_adapters_worker, - list_adapters_worker, - set_active_adapters_worker) -from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.config import LoRAConfig +from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.lora.models import (LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.models import ( + LoRAModel, + LoRAModelManager, + LRUCacheLoRAModelManager, + create_lora_manager, +) from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -22,7 +21,7 @@ from vllm.lora.utils import get_adapter_absolute_path logger = init_logger(__name__) -class WorkerLoRAManager(AbstractWorkerManager): +class WorkerLoRAManager: """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already @@ -32,26 +31,28 @@ class WorkerLoRAManager(AbstractWorkerManager): def __init__( self, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, + vllm_config: VllmConfig, device: torch.device, embedding_modules: dict[str, str], embedding_padding_modules: list[str], lora_model_cls: type[LoRAModel] = LoRAModel, - max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self.vocab_size = vocab_size - self.lora_config = lora_config - self.max_position_embeddings = max_position_embeddings - super().__init__(device) + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + ) + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.lora_config = vllm_config.lora_config + + # Use get_text_config() in case of multimodal models + text_config = vllm_config.model_config.hf_config.get_text_config() + + self.max_position_embeddings = text_config.max_position_embeddings + self.device = device # Lazily initialized by create_lora_manager. 
self._adapter_manager: LoRAModelManager @@ -85,15 +86,12 @@ class WorkerLoRAManager(AbstractWorkerManager): def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - supported_lora_modules = ( - self._adapter_manager.supported_lora_modules) - packed_modules_mapping = ( - self._adapter_manager.packed_modules_mapping) + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping expected_lora_modules: list[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: - expected_lora_modules.extend( - packed_modules_mapping[module]) + expected_lora_modules.extend(packed_modules_mapping[module]) else: expected_lora_modules.append(module) @@ -101,8 +99,10 @@ class WorkerLoRAManager(AbstractWorkerManager): lora_path = get_adapter_absolute_path(lora_request.lora_path) peft_helper = PEFTHelper.from_local_dir( - lora_path, self.max_position_embeddings, - lora_request.tensorizer_config_dict) + lora_path, + self.max_position_embeddings, + lora_request.tensorizer_config_dict, + ) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. @@ -120,12 +120,13 @@ class WorkerLoRAManager(AbstractWorkerManager): lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + - self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, tensorizer_config_dict=lora_request.tensorizer_config_dict, - weights_mapper=hf_to_vllm_mapper) + weights_mapper=hf_to_vllm_mapper, + ) except FileNotFoundError as e: # FileNotFoundError should be raised if both @@ -135,26 +136,29 @@ class WorkerLoRAManager(AbstractWorkerManager): # For NotFoundError raise ValueError( f"Loading lora {lora_request.lora_name} failed: No adapter " - f"found for {lora_request.lora_path}") from e + f"found for {lora_request.lora_path}" + ) from e except Exception as e: # For BadRequestError raise e if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}.") + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}." 
+ ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): - dummy_lora = self._cached_dummy_lora.clone( - lora_request.lora_int_id) + dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id) else: dummy_lora = self._adapter_manager.create_dummy_lora( - lora_request.lora_int_id, rank, self.embedding_modules) + lora_request.lora_int_id, rank, self.embedding_modules + ) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) @@ -162,21 +166,37 @@ class WorkerLoRAManager(AbstractWorkerManager): def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) - def set_active_adapters(self, requests: set[Any], - mapping: Optional[Any]) -> None: - set_active_adapters_worker(requests, mapping, self._apply_adapters, - self._adapter_manager.set_adapter_mapping) + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: + self._apply_adapters(requests) + if mapping is not None: + self._adapter_manager.set_adapter_mapping(mapping) def _apply_adapters(self, adapter_requests: set[Any]) -> None: - apply_adapters_worker(adapter_requests, self.list_adapters, - self._adapter_manager.adapter_slots, - self.remove_adapter, self.add_adapter) + existing_adapters = self.list_adapters() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests + if adapter_request + } + if len(models_map) > self._adapter_manager.adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + "than the number of GPU model slots " + f"({self._adapter_manager.adapter_slots})." + ) + requested_ids = set(models_map) + for adapter_id in existing_adapters - requested_ids: + self.remove_adapter(adapter_id) + for adapter_id in requested_ids - existing_adapters: + self.add_adapter(models_map[adapter_id]) def add_adapter(self, adapter_request: Any) -> bool: - return add_adapter_worker(adapter_request, self.list_adapters, - self._load_adapter, - self._adapter_manager.add_adapter, - self._adapter_manager.activate_adapter) + if adapter_request.adapter_id in self.list_adapters(): + return False + loaded_adapter = self._load_adapter(adapter_request) + loaded = self._adapter_manager.add_adapter(loaded_adapter) + self._adapter_manager.activate_adapter(loaded_adapter.id) + return loaded def remove_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.remove_adapter(adapter_id) @@ -185,7 +205,7 @@ class WorkerLoRAManager(AbstractWorkerManager): self._adapter_manager.remove_all_adapters() def list_adapters(self) -> set[int]: - return list_adapters_worker(self._adapter_manager.list_adapters) + return set(self._adapter_manager.list_adapters()) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -216,13 +236,15 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request + for lora_request in lora_requests + if lora_request } if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._adapter_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots})." 
+ ) for lora in loras_map.values(): self.add_adapter(lora) @@ -242,15 +264,15 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): # Loading succeeded, now check if we will exceed cache capacity and # evict if the oldest adapter if so if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: - assert isinstance(self._adapter_manager, - LRUCacheLoRAModelManager) + assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager) self._adapter_manager.remove_oldest_adapter() # Then add the new adapter to the cache loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._adapter_manager.get_adapter( - lora_request.lora_int_id) is not None + loaded = ( + self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None + ) self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 55dfe8088c8f3..b50f0cb3a61a2 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.parameter import (BasevLLMParameter, - PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingMetadataCache) +from vllm.model_executor.parameter import BasevLLMParameter, PackedvLLMParameter from vllm.model_executor.utils import set_random_seed __all__ = [ - "SamplingMetadata", - "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6b5a107396c92..ad5a09ca970d6 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -32,8 +32,11 @@ class CustomOp(nn.Module): op_cls_to_instantiate = cls else: op_cls_to_instantiate = cls.op_registry_oot[op_name] - logger.debug("Instantiating custom op: %s using %s", op_name, - str(op_cls_to_instantiate)) + logger.debug( + "Instantiating custom op: %s using %s", + op_name, + str(op_cls_to_instantiate), + ) return super().__new__(op_cls_to_instantiate) def __init__(self): @@ -73,11 +76,6 @@ class CustomOp(nn.Module): # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) - def forward_neuron(self, *args, **kwargs): - # By default, we assume that Neuron ops are compatible with the - # PyTorch-native implementation. - return self.forward_native(*args, **kwargs) - def forward_oot(self, *args, **kwargs): # By default, we assume that OOT ops are compatible with the # PyTorch-native implementation. 
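# (Editor's note) A minimal sketch of how the out-of-tree fallback behaves; "ToyReLU" and
# "toy_relu" are hypothetical names, not part of this diff:
#
#     @CustomOp.register("toy_relu")
#     class ToyReLU(CustomOp):
#         def forward_native(self, x: torch.Tensor) -> torch.Tensor:
#             return torch.relu(x)
#
# Without a forward_oot override, an out-of-tree platform dispatches to forward_oot, which
# simply reuses forward_native above; a platform plugin that wants its own kernel can instead
# register a replacement class via CustomOp.register_oot (defined further below).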
@@ -91,8 +89,7 @@ class CustomOp(nn.Module): if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name]) else: - compilation_config.disabled_custom_ops.update( - [self.__class__.name]) + compilation_config.disabled_custom_ops.update([self.__class__.name]) if not enabled: return self.forward_native @@ -105,8 +102,6 @@ class CustomOp(nn.Module): return self.forward_tpu elif current_platform.is_xpu(): return self.forward_xpu - elif current_platform.is_neuron(): - return self.forward_neuron elif current_platform.is_out_of_tree(): return self.forward_oot else: @@ -126,8 +121,7 @@ class CustomOp(nn.Module): enabled = f"+{cls.name}" in custom_ops disabled = f"-{cls.name}" in custom_ops - assert not (enabled - and disabled), f"Cannot enable and disable {cls.name}" + assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled @@ -138,9 +132,12 @@ class CustomOp(nn.Module): Specifying 'all' or 'none' in custom_op takes precedence. """ from vllm.config import CompilationLevel + compilation_config = get_cached_compilation_config() - default_on = (compilation_config.level < CompilationLevel.PIECEWISE - or not compilation_config.use_inductor) + default_on = ( + compilation_config.level < CompilationLevel.PIECEWISE + or not compilation_config.use_inductor + ) count_none = compilation_config.custom_ops.count("none") count_all = compilation_config.custom_ops.count("all") return default_on and not count_none > 0 or count_all > 0 @@ -150,13 +147,12 @@ class CustomOp(nn.Module): # Examples: # - MyOp.enabled() # - op_registry["my_op"].enabled() - op_registry: dict[str, type['CustomOp']] = {} - op_registry_oot: dict[str, type['CustomOp']] = {} + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} # Decorator to register custom ops. 
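# For example (this registration actually appears in the activation layers later in this
# diff), a subclass registers itself under a string key and can then be toggled through the
# compilation config's custom_ops list, e.g. "+gelu_new" / "-gelu_new":
#
#     @CustomOp.register("gelu_new")
#     class NewGELU(CustomOp):
#         ...
#
#     NewGELU.enabled()                  # default_on() combined with any +/- overrides
#     CustomOp.op_registry["gelu_new"]   # the class recorded by the decorator below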
@classmethod def register(cls, name: str): - def decorator(op_cls): assert name not in cls.op_registry, f"Duplicate op name: {name}" op_cls.name = name @@ -176,11 +172,9 @@ class CustomOp(nn.Module): # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod") @classmethod def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None): - def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, \ - f"Duplicate op name: {reg_name}" + assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name cls.op_registry_oot[reg_name] = op_cls return op_cls diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 86ab4f546d127..96745b99f7a7e 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom activation functions.""" + import math from typing import Optional @@ -8,13 +9,19 @@ import torch import torch.nn as nn import torch.nn.functional as F -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import LazyDict +logger = init_logger(__name__) + @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): @@ -29,7 +36,7 @@ class FatreluAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - def __init__(self, threshold: float = 0.): + def __init__(self, threshold: float = 0.0): super().__init__() self.threshold = threshold if current_platform.is_cuda_alike(): @@ -46,7 +53,7 @@ class FatreluAndMul(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x, self.threshold) return out @@ -69,6 +76,7 @@ class SiluAndMul(CustomOp): self.op = torch.ops._C.silu_and_mul elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -80,25 +88,18 @@ class SiluAndMul(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out - def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) - result = s * x_reshaped[:, d:] - return result.view(*x.shape[:-1], d) - @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): @@ -117,6 +118,7 @@ class MulAndSilu(CustomOp): self.op = torch.ops._C.mul_and_silu elif current_platform.is_xpu(): 
from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -128,7 +130,7 @@ class MulAndSilu(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out @@ -160,10 +162,8 @@ class GeluAndMulSparse(CustomOp): # Sparsity. if activation_sparsity == 0.0: - raise ValueError( - "activation_sparsity is 0.0. Please use GeluAndMul.") - target_sparsity_tensor = torch.tensor(activation_sparsity, - dtype=torch.float32) + raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.") + target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32) normal_dist = torch.distributions.normal.Normal(0, 1) self.std_multiplier = normal_dist.icdf(target_sparsity_tensor) @@ -211,6 +211,7 @@ class GeluAndMul(CustomOp): self.op = torch.ops._C.gelu_tanh_and_mul elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + if approximate == "none": self.op = ipex_ops.gelu_and_mul else: @@ -223,20 +224,20 @@ class GeluAndMul(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out def extra_repr(self) -> str: - return f'approximate={repr(self.approximate)}' + return f"approximate={repr(self.approximate)}" @CustomOp.register("swigluoai_and_mul") @@ -259,7 +260,7 @@ class SwigluOAIAndMul(CustomOp): def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit) return out @@ -270,20 +271,19 @@ class SwigluOAIAndMul(CustomOp): @CustomOp.register("gelu_new") class NewGELU(CustomOp): - def __init__(self): super().__init__() if current_platform.is_cuda_alike() or current_platform.is_cpu(): self.op = torch.ops._C.gelu_new elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_new def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" c = math.sqrt(2.0 / math.pi) - return 0.5 * x * (1.0 + torch.tanh(c * - (x + 0.044715 * torch.pow(x, 3.0)))) + return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) @@ -296,19 +296,18 @@ class NewGELU(CustomOp): @CustomOp.register("gelu_fast") class FastGELU(CustomOp): - def __init__(self): super().__init__() if current_platform.is_cuda_alike() or current_platform.is_cpu(): self.op = torch.ops._C.gelu_fast elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_fast def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" - return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * - (1.0 + 0.044715 * x * x))) + return 0.5 * x * 
(1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) @@ -328,6 +327,7 @@ class QuickGELU(CustomOp): self.op = torch.ops._C.gelu_quick elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_quick def forward_native(self, x: torch.Tensor) -> torch.Tensor: @@ -359,10 +359,122 @@ class ReLUSquaredActivation(CustomOp): return torch.square(F.relu(x)) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - #TODO : implement cuda kenrels + # TODO : implement cuda kernels return self.forward_native(x) +@CustomOp.register("xielu") +class XIELU(CustomOp): + """ + Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010 + If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA + Otherwise, we emit a single warning and use xIELU Python + """ + + def __init__( + self, + alpha_p_init: float = 0.8, + alpha_n_init: float = 0.8, + beta: float = 0.5, + eps: float = -1e-6, + dtype: torch.dtype = torch.bfloat16, + with_vector_loads: bool = False, + ): + super().__init__() + self.alpha_p = nn.Parameter( + torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze( + 0 + ) + ) + self.alpha_n = nn.Parameter( + torch.log( + torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1 + ).unsqueeze(0) + ) + self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) + self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) + self.with_vector_loads = with_vector_loads + # Temporary until xIELU CUDA fully implemented + self._beta_scalar = float(self.beta.detach().cpu().float().item()) + self._eps_scalar = float(self.eps.detach().cpu().float().item()) + + self._xielu_cuda_obj = None + try: + import xielu.ops # noqa: F401 + + self._xielu_cuda_obj = torch.classes.xielu.XIELU() + msg = "Using experimental xIELU CUDA." + try: + from torch._dynamo import allow_in_graph + + self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) + msg += " Enabled torch._dynamo for xIELU CUDA." + except Exception as err: + msg += ( + f" Could not enable torch._dynamo for xIELU ({err}) - " + "this may result in slower performance." + ) + self._xielu_cuda_fn = self._xielu_cuda + logger.warning_once(msg) + except Exception as err: + logger.warning_once( + "CUDA-fused xIELU not available (%s) –" + " falling back to a Python version.\n" + "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`", + str(err), + ) + + def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: + alpha_p = nn.functional.softplus(self.alpha_p) + alpha_n = self.beta + nn.functional.softplus(self.alpha_n) + return torch.where( + x > 0, + alpha_p * x * x + self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x, + ) + + def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: + """Firewall function to prevent torch.compile from seeing .item()""" + assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None" + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions" + " but got (shape: %s). 
Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p, + self.alpha_n, + # Temporary until xIELU CUDA fully implemented -> + # self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + def forward_native(self, input: torch.Tensor) -> torch.Tensor: + if self._xielu_cuda_obj is not None and input.is_cuda: + if not torch._dynamo.is_compiling(): + return self._xielu_cuda_fn(input) + else: + logger.warning_once( + "torch._dynamo is compiling, using Python version of xIELU." + ) + return self._xielu_python(input) + + def forward_cuda(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_native(input) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. @@ -381,14 +493,14 @@ class ScaledActivation(nn.Module): self.input_is_parallel = input_is_parallel if input_is_parallel: tp_size = get_tensor_model_parallel_world_size() - intermediate_size_per_partition = divide(intermediate_size, - tp_size) + intermediate_size_per_partition = divide(intermediate_size, tp_size) else: intermediate_size_per_partition = intermediate_size if params_dtype is None: params_dtype = torch.get_default_dtype() self.scales = nn.Parameter( - torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + torch.empty(intermediate_size_per_partition, dtype=params_dtype) + ) set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -405,53 +517,53 @@ class ScaledActivation(nn.Module): param_data.copy_(loaded_weight) -_ACTIVATION_REGISTRY = LazyDict({ - "gelu": - lambda: nn.GELU(), - "gelu_fast": - lambda: FastGELU(), - "gelu_new": - lambda: NewGELU(), - "gelu_pytorch_tanh": - lambda: nn.GELU(approximate="tanh"), - "relu": - lambda: nn.ReLU(), - "relu2": - lambda: ReLUSquaredActivation(), - "silu": - lambda: nn.SiLU(), - "quick_gelu": - lambda: QuickGELU(), -}) +_ACTIVATION_REGISTRY = LazyDict( + { + "gelu": lambda: nn.GELU(), + "gelu_fast": lambda: FastGELU(), + "gelu_new": lambda: NewGELU(), + "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"), + "relu": lambda: nn.ReLU(), + "relu2": lambda: ReLUSquaredActivation(), + "silu": lambda: nn.SiLU(), + "quick_gelu": lambda: QuickGELU(), + "tanh": lambda: nn.Tanh(), + "sigmoid": lambda: nn.Sigmoid(), + "xielu": lambda: XIELU(), + } +) def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("torch.nn.modules."): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + if act_fn_name not in _ACTIVATION_REGISTRY: - raise ValueError( - f"Activation function {act_fn_name!r} is not supported.") + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") return _ACTIVATION_REGISTRY[act_fn_name] -_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ - "gelu": - lambda: GeluAndMul(), - "silu": - lambda: SiluAndMul(), - "geglu": - lambda: GeluAndMul(), - "swigluoai": - lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), -}) +_ACTIVATION_AND_MUL_REGISTRY = LazyDict( + { + "gelu": lambda: GeluAndMul(), + "silu": lambda: SiluAndMul(), + "geglu": lambda: GeluAndMul(), + "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), + } +) def get_act_and_mul_fn(act_fn_name: str) -> nn.Module: """Get an activation-and-mul (i.e. 
SiluAndMul) function by name.""" act_fn_name = act_fn_name.lower() if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: - raise ValueError( - f"Activation function {act_fn_name!r} is not supported.") + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name] diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py new file mode 100644 index 0000000000000..fa74c20840da1 --- /dev/null +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for attention-like layers.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class AttentionLayerBase(ABC): + """ + Base class for attention-like layers (Attention, Mamba, etc.) + that support the v1 engine. + + This provides a common interface for getting attention backends + from different layer types. + """ + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this layer.""" + pass diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py new file mode 100644 index 0000000000000..9fd85d1e9e194 --- /dev/null +++ b/vllm/model_executor/layers/batch_invariant.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Union + +import torch + +from vllm.triton_utils import tl, triton + + +def _matmul_launch_metadata( + grid: Callable[..., Any], kernel: Any, args: dict[str, Any] +) -> dict[str, Any]: + ret = {} + m, n, k = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" + if "tiles_per_update" in args: + ret["name"] = ( + f"{kernel.name} [M={m}, N={n}, K={k}, " + f"tiles_per_update={args['tiles_per_update']:02}]" + ) + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k + ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) + return ret + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + bias_ptr, + M, + N, + K, # + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + + offs_k_for_mask = 
tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + a = tl.load( + a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if C_LARGE: + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_cn + bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) + accumulator += bias + if c_ptr.dtype.element_ty == tl.float8e4nv: + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul_persistent( + a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None +): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert bias is None or bias.dim() == 1, ( + "Currently assuming bias is 1D, let Horace know if you run into this" + ) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + # 1D launch kernel where each block gets its own program. 
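# (Editor's note) Worked sizing example with illustrative numbers that are not from this
# diff: for bfloat16 inputs with M = N = 4096, BLOCK_SIZE_M = BLOCK_SIZE_N = 128 gives
# 32 * 32 = 1024 output tiles, so on a GPU with, say, 132 SMs the grid below is
# min(132, 1024) = 132 persistent programs, each looping over about 8 tiles inside
# matmul_kernel_persistent.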
+ def grid(META): + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + + configs = { + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float32: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # print(a.device, b.device, c.device) + matmul_kernel_persistent[grid]( + a, + b, + c, # + bias, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + NUM_SMS=NUM_SMS, # + A_LARGE=a.numel() > 2**31, + B_LARGE=b.numel() > 2**31, + C_LARGE=c.numel() > 2**31, + HAS_BIAS=bias is not None, + **configs[dtype], + ) + return c + + +@triton.jit +def _log_softmax_kernel( + input_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute log_softmax along the last dimension of a 2D tensor. + Each block handles one row of the input tensor. + """ + # Get the row index for this block + row_idx = tl.program_id(0).to(tl.int64) + + # Compute base pointers for input and output rows + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Find maximum value in the row for numerical stability + max_val = -float("inf") + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf")) + + # Update maximum + max_val = tl.max(tl.maximum(vals, max_val)) + + # Step 2: Compute sum of exp(x - max_val) + sum_exp = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + + # Compute exp(x - max_val) and accumulate + exp_vals = tl.exp(vals - max_val) + sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0)) + + # Compute log(sum_exp) + log_sum_exp = tl.log(sum_exp) + + # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask) + + # Compute log_softmax + output = vals - max_val - log_sum_exp + + # Store results + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Compute log_softmax using Triton kernel. 
+
+    Args:
+        input: Input tensor
+        dim: Dimension along which to compute log_softmax
+            (only -1 or last dim supported)
+
+    Returns:
+        Tensor with log_softmax applied along the specified dimension
+    """
+    if dim != -1 and dim != input.ndim - 1:
+        raise ValueError(
+            "This implementation only supports log_softmax along the last dimension"
+        )
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    # Allocate output tensor
+    output = torch.empty_like(input_2d)
+
+    # Choose block size based on the number of columns
+    BLOCK_SIZE = 1024
+
+    # Launch kernel with one block per row
+    grid = (n_rows,)
+    _log_softmax_kernel[grid](
+        input_2d,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Reshape output back to original shape
+    return output.reshape(original_shape)
+
+
+@triton.jit
+def mean_kernel(
+    input_ptr,
+    output_ptr,
+    input_stride0,
+    input_stride1,
+    input_stride2,
+    output_stride0,
+    output_stride1,
+    M,  # size before reduction dim
+    N,  # size of reduction dim
+    K,  # size after reduction dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Kernel for computing mean along a single dimension.
+    Input is viewed as (M, N, K) where N is the dimension being reduced.
+    """
+    # Program ID gives us which output element we're computing
+    pid = tl.program_id(0)
+
+    # Compute output indices
+    m_idx = pid // K
+    k_idx = pid % K
+
+    # Bounds check
+    if m_idx >= M or k_idx >= K:
+        return
+
+    # Accumulate sum across reduction dimension
+    acc = 0.0
+    for n_start in range(0, N, BLOCK_SIZE):
+        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
+        mask = n_offsets < N
+
+        # Calculate input indices
+        input_idx = (
+            m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
+        )
+
+        # Load and accumulate
+        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
+        acc += tl.sum(vals)
+
+    # Compute mean and store
+    mean_val = acc / N
+    output_idx = m_idx * output_stride0 + k_idx * output_stride1
+    tl.store(output_ptr + output_idx, mean_val)
+
+
+def mean_dim(
+    input: torch.Tensor,
+    dim: int,
+    keepdim: bool = False,
+    dtype: Union[torch.dtype, None] = None,
+) -> torch.Tensor:
+    """
+    Triton implementation of torch.mean with single dimension reduction.
+
+    Args:
+        input: Input tensor
+        dim: Single dimension along which to compute mean
+        keepdim: Whether to keep the reduced dimension
+        dtype: Output dtype.
If None, uses input dtype + (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert input.is_cuda, "Input must be a CUDA tensor" + assert -input.ndim <= dim < input.ndim, ( + f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" + ) + + # Handle negative dim + if dim < 0: + dim = dim + input.ndim + + # Handle dtype + if dtype is None: + if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + dtype = torch.float32 + else: + dtype = input.dtype + + # Convert input to appropriate dtype if needed + if input.dtype != dtype: + input = input.to(dtype) + + # Get input shape and strides + shape = list(input.shape) + + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1 :] + + # Create output tensor + output = torch.empty(output_shape, dtype=dtype, device=input.device) + + # Reshape output for kernel + output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K) + + # Launch kernel + grid = (M * K,) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mm_batch_invariant(a, b): + return matmul_persistent(a, b) + + +def addmm_batch_invariant(bias, a, b): + return matmul_persistent(a, b, bias=bias) + + +def _log_softmax_batch_invariant(input, dim, _half_to_float): + assert not _half_to_float, "not implemented" + return log_softmax(input, dim=dim) + + +def mean_batch_invariant( + input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None +): + assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" + + result = input.to(torch.float32) + + # Sort dimensions to reduce from largest to smallest to handle shifting dims + # during iterative reduction. + sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) + + # Iteratively apply a deterministic mean. + for d in sorted_dims: + result = mean_dim(result, dim=d, keepdim=True) + + if not keepdim: + # Squeeze the reduced dimensions. 
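# For example (illustrative shapes, not taken from this diff): for an input of shape
# (2, 3, 4) with dim=[0, 2], sorted_dims is [2, 0]; the keepdim=True reductions above give
# shapes (2, 3, 1) and then (1, 3, 1), and the squeezes below (dim 2, then dim 0) leave a
# (3,)-shaped result, matching torch.mean(input, dim=[0, 2]).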
+ for d in sorted_dims: + result = result.squeeze(d) + + return result + + +_batch_invariant_MODE = False +_batch_invariant_LIB = None + + +def is_batch_invariant_mode_enabled(): + return _batch_invariant_MODE + + +def enable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_MODE: + return + + _batch_invariant_MODE = True + _batch_invariant_LIB = torch.library.Library("aten", "IMPL") + _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl( + "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" + ) + _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") + + +def disable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE = False + _batch_invariant_LIB = None + + +@contextlib.contextmanager +def set_batch_invariant_mode(enabled: bool = True): + global _batch_invariant_MODE, _batch_invariant_LIB + old_data = (_batch_invariant_MODE, _batch_invariant_LIB) + if enabled: + enable_batch_invariant_mode() + else: + disable_batch_invariant_mode() + yield + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE, _batch_invariant_LIB = old_data + + +AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"]) + + +def get_batch_invariant_attention_block_size() -> AttentionBlockSize: + return AttentionBlockSize(block_m=16, block_n=16) + + +def vllm_kernel_override_batch_invariant(): + env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT" + is_overridden = False + val = os.getenv(env_key, "0") + try: + is_overridden = int(val) != 0 + except ValueError: + is_overridden = False + return is_overridden + + +def init_batch_invariance(): + # this will hit all the csrc overrides as well + if vllm_kernel_override_batch_invariant(): + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + enable_batch_invariant_mode() diff --git a/vllm/model_executor/layers/fla/__init__.py b/vllm/model_executor/layers/fla/__init__.py new file mode 100644 index 0000000000000..0e89cf9f79439 --- /dev/null +++ b/vllm/model_executor/layers/fla/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py new file mode 100644 index 0000000000000..c19cc14ba6928 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule +from .layernorm_guard import RMSNormGated + +__all__ = [ + "RMSNormGated", + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py new file mode 100644 index 0000000000000..d65c87aba11cd --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
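# (Editor's note) A shape sketch consistent with the docstring of chunk_gated_delta_rule
# below; the sizes are hypothetical: with q, k of shape [B, T, H, K] and v of shape
# [B, T, H, V] for B=1, T=128, H=4, K=V=64, the chunk_local_cumsum above uses chunk_size=64,
# so the sequence is handled in T / 64 = 2 chunks; the steps below then build the per-chunk
# WY factors (A, then w and u), the chunked states h with v_new, and the output o of shape
# [B, T, H, V].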
+ A = chunk_scaled_dot_kkt_fwd( + k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 + ) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, ( + "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + ) + assert len(beta.shape) == 3, ( + "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + ) + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map( + lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g) + ) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + if head_first: + o = rearrange(o, "b t h ... 
-> b h t ...") + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py new file mode 100644 index 0000000000000..817962d9c9465 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_G"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr( + h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr( + h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 
1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr( + h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr( + v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_v_new = ( + tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + if SAVE_NEW_VALUE + else None + ) + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + tl.store( + p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1) + ) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr( + g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + 
b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = ( + k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + ) + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return h, v_new, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py new file mode 100644 index 0000000000000..ae404a3615f61 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
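+# It provides the chunkwise output kernel, which combines the inter-chunk state
+# contribution (q @ h) with the causally masked intra-chunk term (q @ k^T) @ v.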
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr( + q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) + ) + p_h = tl.make_block_ptr( + h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0) + ) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr( + v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_o = tl.make_block_ptr( + o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, 
+ h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000..0da3f243901fb --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "USE_G": lambda args: args["g_cumsum"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr( + g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, 
:]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py new file mode 100644 index 0000000000000..cfa2b3b48e709 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
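+# It provides chunk-local cumulative sums of the gate values, in scalar and
+# vector variants, with optional reversed accumulation.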
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,) + ) + p_o = tl.make_block_ptr( + o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,) + ) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({"BS": BS}, num_warps=num_warps) + for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, 
+ (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( + "chunk_size must be a power of 2" + ) + BT = chunk_size + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( + "chunk_size must be a power of 2" + ) + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + assert g.shape[0] == 1, ( + "Only batch size 1 is supported when cu_seqlens are provided" + ) + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g, chunk_size, reverse, cu_seqlens, head_first, output_dtype + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g, chunk_size, reverse, cu_seqlens, head_first, output_dtype + ) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. 
" + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000000..fa10bdb36caa3 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, 
mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ( + ht + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_final_state_token + ) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + 
initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py new file mode 100644 index 0000000000000..f023e1378bb88 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/index.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +from vllm.triton_utils import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, chunk_size: int +) -> torch.LongTensor: + indices = torch.cat( + [ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ] + ) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets( + cu_seqlens: torch.LongTensor, chunk_size: int +) -> torch.LongTensor: + return torch.cat( + [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)] + ).cumsum(-1) diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py new file mode 100644 index 0000000000000..315dd904523b8 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
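+# It provides a fused L2 normalization over the last dimension, used to
+# normalize q and k before the delta-rule kernels.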
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=["D"], +) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + for BT in BT_LIST + ], + key=["D"], +) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd( + x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None +): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py new file mode 100644 index 0000000000000..655cdb3f30eb1 --- /dev/null +++ 
b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Tri Dao +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2024, Tri Dao. + +# ruff: noqa: E501 +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from vllm.triton_utils import tl, triton + +from .utils import input_guard + + +def rms_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True, +): + dtype = x.dtype + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + } +) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
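+    # The launch grid is (M, ngroups): axis 0 selects the row, axis 1 selects the
+    # group of N columns (N is the group size passed by the wrapper), so each
+    # program normalizes one (row, group) tile independently.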
+ row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor = None, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + @input_guard + @staticmethod + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias 
= bias.contiguous() + y, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + +def layernorm_fn( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm + ) + + +def rmsnorm_fn( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, True + ) + + +class LayerNormGated(nn.Module): + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) + + +class RMSNormGated(nn.Module): + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return rmsnorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py new file mode 100644 index 0000000000000..ee2f4185a5df5 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +if not hasattr(tl, "gather"): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src +else: + gather = tl.gather diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py new file mode 100644 index 0000000000000..d30fea90aec38 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/solve_tril.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["BT"], +) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0) + ) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr( + A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) + p_Ad_11 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0) + ) + p_Ad_22 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) + p_Ai_11 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0) + ) + p_Ai_21 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) + + A_21 = 
tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee" + ) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_A_32 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) + ) + p_A_31 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_A_43 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) + ) + p_A_42 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) + ) + p_A_41 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) + p_Ad_11 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0) + ) + p_Ad_22 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_Ad_33 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_Ad_44 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee" + ) + Ai_32 = -tl.dot( + tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee" + ) + Ai_43 = -tl.dot( + tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee" + ) + + Ai_31 = -tl.dot( + 
Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) + + p_Ai_11 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0) + ) + p_Ai_33 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0) + ) + p_Ai_44 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0) + ) + p_Ai_21 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_Ai_31 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_Ai_32 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) + ) + p_Ai_41 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) + p_Ai_42 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) + ) + p_Ai_43 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) + ) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0) + ) + p_Ai_13 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0) + ) + p_Ai_14 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0) + ) + p_Ai_23 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0) + ) + p_Ai_24 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0) + ) + p_Ai_34 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0) + ) + tl.store( + p_Ai_12, + fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_13, + fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), + 
boundary_check=(0, 1), + ) + tl.store( + p_Ai_14, + fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_23, + fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_24, + fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_34, + fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty( + B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype + ) + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + ) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = ( + merge_16x16_to_32x32_inverse_kernel + if BT == 32 + else merge_16x16_to_64x64_inverse_kernel + ) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py new file mode 100644 index 0000000000000..07124f33f1e66 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +from vllm.triton_utils import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. 
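As a side note (editorial, not part of the patch), the contract in the solve_tril docstring above can be sanity-checked with dense PyTorch ops. The snippet below checks a single 64x64 chunk rather than the batched [B, T, H, BT] layout the kernel uses.

import torch

BT = 64
A = torch.randn(BT, BT, dtype=torch.float64).tril(-1)            # strictly lower triangular
I = torch.eye(BT, dtype=torch.float64)
Ai = torch.linalg.inv(I + A)                                      # what solve_tril computes per chunk
assert torch.allclose((I + A) @ Ai, I, atol=1e-8)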
+ The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all( + k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items() + ) + ): + cache_entries = ( + cache_entries[:i] + + cache_entries[i + 1 :] + + [(args, kwargs, last_result)] + ) + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = ( + i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args + ) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
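To make the intent of the two decorators above concrete, here is an illustrative usage sketch (not from the patch); prepare_offsets and scale_by are hypothetical helpers, and the decorators are assumed to be in scope from this module.

import torch

@tensor_cache
def prepare_offsets(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Calling this again with the *same* tensor object returns the cached result
    # instead of recomputing (identity-based matching, up to 4 recent entries).
    return torch.diff(cu_seqlens)

@input_guard
def scale_by(x: torch.Tensor, factor: float) -> torch.Tensor:
    # x arrives contiguous here and the current CUDA device is set from x,
    # so any kernel launched inside targets the right GPU.
    return x * factor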
+device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)[ + "max_shared_mem" + ] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py new file mode 100644 index 0000000000000..b628a90e843f8 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
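As a usage note (editorial, not part of the patch), the Backend table and check_shared_mem helper above are typically used to gate kernel configurations on available shared memory, for example:

# Illustrative: enable larger tiles only on parts with Hopper-class shared memory.
use_large_tiles = check_shared_mem("hopper")            # needs >= 232448 bytes on device 0
ampere_ok = check_shared_mem("ampere", tensor_idx=0)    # A100-class budget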
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git 
a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3007643d7a288..799f782848944 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -6,10 +6,17 @@ from typing import Any, Optional from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) + FusedMoEActivationFormat, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -36,6 +43,8 @@ __all__ = [ "FusedMoEPermuteExpertsUnpermute", "FusedMoEActivationFormat", "FusedMoEPrepareAndFinalize", + "SharedFusedMoE", + "activation_without_mul", "override_config", "get_config", ] @@ -43,26 +52,34 @@ __all__ = [ if HAS_TRITON: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa - import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, + ) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, + ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4, - cutlass_moe_fp8) - from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts) + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + cutlass_moe_fp4, + cutlass_moe_fp8, + ) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + BatchedTritonExperts, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + TritonExperts, + fused_experts, + fused_topk, + get_config_file_name, + grouped_topk, + ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, + ) __all__ += [ - "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", @@ -78,3 +95,11 @@ if HAS_TRITON: "TritonOrDeepGemmExperts", "BatchedTritonOrDeepGemmExperts", ] +else: + # Some model classes directly use the custom ops. Add placeholders + # to avoid import errors. 
+ def _raise_exception(method: str): + raise NotImplementedError(f"{method} is not implemented as lack of triton.") + + fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk") + fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts") diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c4d680af932f0..35d2dcb91d253 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,12 +7,14 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_e8m0_used) +from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used logger = init_logger(__name__) @@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm( y_q_ptr, # fp8 quantized activations (E, T, H) y_s_ptr, # 16-bit scales (E, T, G) counts_ptr, # int32 num tokens per expert (E) - # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) GROUP_SIZE: tl.constexpr, # elements per group (usually 128) - # Strides for input (elements) --------------------------------------- stride_i_e, stride_i_t, stride_i_h, - # Strides for y_q (elements) ----------------------------------------- stride_yq_e, stride_yq_t, stride_yq_h, - # Strides for y_s (elements) ----------------------------------------- stride_ys_e, stride_ys_t, stride_ys_g, - # Stride for counts (elements) stride_counts_e, - # Numeric params ------------------------------------------------------ eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, use_ue8m0: tl.constexpr, - # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -76,17 +71,14 @@ def _silu_mul_fp8_quant_deep_gemm( base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h base_gate_offset = base_input_offset + cols * stride_i_h base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h - base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + - cols * stride_yq_h) + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h base_ys_offset = e * stride_ys_e + g * stride_ys_g for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): - gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, - mask=mask, - other=0.0).to(tl.float32) - up = tl.load(input_ptr + base_up_offset + t * stride_i_t, - mask=mask, - other=0.0) + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) gate = gate * (1.0 / (1.0 + tl.exp(-gate))) y = gate * up @@ -101,120 +93,153 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) -def silu_mul_fp8_quant_deep_gemm( +def 
persistent_masked_m_silu_mul_quant( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens=16, group_size: int = 128, - eps: float = 1e-10, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales - - y has shape (E, T, 2*H). The first half of the last dimension is + y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. + We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2` + be a parallelization factor for persistent_masked_m_silu_mul_quant over the + hidden dimension. + + Let `expert_offsets = [0] + [num_tokens.cumsum()]` and + `total_tokens = expert_offsets[-1]`. + persistent_masked_m_silu_mul_quant launches `total_tokens x P2` number of + thread blocks. Each thread block contains `NUM_WARPS` warps. + + Every thread block needs to find it's corresponding expert by warp-parallel scanning + over the `expert_offsets` array. + + The i-th warp in the first thread block processes + `[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups + sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`, + pipelining loads and computes. + + The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2 + can is visualized like so: + + stage0 stage1 + ┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐ + │gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│ + └─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘ + + with the main difference between V1 and V2 being the global load + stride between warps, and between half-warps. Regarding the latter stride, + we assign the first half warp of every warp for `gate` loads and the second + half-warp to `up` loads. Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + Let NUM_WARPS be the number of warps in a single thread block and + `GROUP_SIZE = 128` be the size of the quantization group. """ assert y.ndim == 3, "y must be (E, T, 2*H)" E, T, H2 = y.shape assert H2 % 2 == 0, "last dim of y must be even (2*H)" H = H2 // 2 - G = H // group_size - assert H % group_size == 0, "H must be divisible by group_size" - assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ - "tokens_per_expert must be shape (E,)" - tokens_per_expert = tokens_per_expert.to(device=y.device, - dtype=torch.int32) + G = (H + group_size - 1) // group_size + assert H % 8 == 0, "H must be divisible by 8" + assert group_size == 128, "H must be divisible by 8" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E + + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) - # allocate outputs fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - # strides (elements) - stride_i_e, stride_i_t, stride_i_h = y.stride() - stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - - # desired scale strides (elements): (T*G, 1, T) stride_ys_e = T * G stride_ys_t = 1 stride_ys_g = T - y_s = torch.empty_strided((E, T, G), - (stride_ys_e, stride_ys_t, stride_ys_g), - dtype=torch.float32, - device=y.device) - - stride_cnt_e = tokens_per_expert.stride()[0] - - # Static grid over experts and H-groups. 
- # A loop inside the kernel handles the token dim - grid = (E * G, ) - - f_info = torch.finfo(fp8_dtype) - fp8_max = f_info.max - fp8_min = f_info.min - - _silu_mul_fp8_quant_deep_gemm[grid]( - y, - y_q, - y_s, - tokens_per_expert, - H, - group_size, - stride_i_e, - stride_i_t, - stride_i_h, - stride_yq_e, - stride_yq_t, - stride_yq_h, - stride_ys_e, - stride_ys_t, - stride_ys_g, - stride_cnt_e, - eps, - fp8_min, - fp8_max, - is_blackwell_deep_gemm_e8m0_used(), - BLOCK=group_size, - NUM_STAGES=4, - num_warps=1, + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, ) + use_ue8m0 = is_deep_gemm_e8m0_used() + + cuda_arch = current_platform.get_device_capability( + device_id=y.device.index + ).to_int() + + if cuda_arch >= 80: + torch.ops._C.persistent_masked_m_silu_mul_quant( + y, tokens_per_expert, y_q, y_s, use_ue8m0 + ) + else: + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G,) + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + eps: float = 1e-10 + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + return y_q, y_s class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - - # The Deep Gemm kernels only support block size of 128 - DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] - - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - block_shape: list[int], - per_act_token_quant=False): + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): """ max_num_tokens: Maximum number of tokens from a DP Rank num_dispatchers: The number of DP dispatchers. - block_shape: Block quantization block shape. - per_act_token_quant: Per activation token quantization flag. 
+ quant_config: Quantization configuration """ - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE + super().__init__(quant_config) + assert self.block_shape == deep_gemm_block_shape() self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -228,29 +253,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks # end up sending their tokens. This needs to be fixed. num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = (num_experts, max_num_tokens * num_dispatchers, - max(K, N)) + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) output = (num_experts, max_num_tokens * num_dispatchers, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, @@ -263,10 +283,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -285,8 +301,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w2.size(1) == K - E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( - hidden_states, w1, w2, topk_ids) + E, max_num_tokens, N, K, _ = self.moe_problem_size( + hidden_states, w1, w2, topk_ids + ) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) @@ -294,11 +311,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # for the M expectation of each batch, correctly setting this value # may lead to better performance. 
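For reference (an editorial sketch, not part of the patch): ignoring the per-expert masking and CUDA-graph plumbing, the numerics of persistent_masked_m_silu_mul_quant and _silu_mul_fp8_quant_deep_gemm, used just below between the two grouped GEMMs, reduce for one expert to SiLU gating followed by per-group FP8 quantization. The absmax scaling and the 1e-10 clamp are assumptions here.

import torch
import torch.nn.functional as F

def silu_mul_group_quant_reference(y: torch.Tensor, group_size: int = 128):
    # y: (T, 2*H) -> fp8 values (T, H) and float32 scales (T, H // group_size).
    T, H2 = y.shape
    H = H2 // 2
    gate, up = y[:, :H].float(), y[:, H:].float()
    act = F.silu(gate) * up                                  # silu(y[..., :H]) * y[..., H:]
    groups = act.view(T, H // group_size, group_size)
    f8 = torch.finfo(torch.float8_e4m3fn)
    scales = groups.abs().amax(dim=-1).clamp_min(1e-10) / f8.max   # assumed absmax scaling
    q = (groups / scales.unsqueeze(-1)).clamp(f8.min, f8.max)
    return q.view(T, H).to(torch.float8_e4m3fn), scales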
expected_m = max_num_tokens - fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), - workspace1, expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked( + (a1q, a1q_scale), + (w1, self.w1_scale), + workspace1, + expert_num_tokens, + expected_m, + ) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, - expert_num_tokens) + a2q, a2q_scale = persistent_masked_m_silu_mul_quant( + workspace1, expert_num_tokens + ) - fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, - expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked( + (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m + ) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 89d7412ee2236..09c4de0f87159 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -6,70 +6,60 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, - allow_deep_gemm: bool = False): - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - )) + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + allow_deep_gemm: bool = False, + ): + super().__init__(quant_config) self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape, + quant_config=self.quant_config, ) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 - and self.block_shape - == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.allow_deep_gemm = ( + allow_deep_gemm + and self.quant_config.use_fp8_w8a8 + and self.block_shape == deep_gemm_block_shape() + ) - self.batched_deep_gemm_experts = BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - block_shape=self.block_shape, # type: ignore[arg-type] - ) if self.allow_deep_gemm else None + self.batched_deep_gemm_experts = ( + BatchedDeepGemmExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + quant_config=self.quant_config, + ) + if self.allow_deep_gemm + else 
None + ) - assert (self.batched_deep_gemm_experts is not None - or self.batched_triton_experts is not None) + assert ( + self.batched_deep_gemm_experts is not None + or self.batched_triton_experts is not None + ) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: if self.batched_triton_experts is not None: - assert (self.batched_deep_gemm_experts is None - or self.batched_deep_gemm_experts.activation_formats - == self.batched_triton_experts.activation_formats) + assert ( + self.batched_deep_gemm_experts is None + or self.batched_deep_gemm_experts.activation_formats + == self.batched_triton_experts.activation_formats + ) return self.batched_triton_experts.activation_formats else: assert self.batched_deep_gemm_experts is not None @@ -78,14 +68,16 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_chunking(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts - return ((bdge is None or bdge.supports_chunking()) - and (bte is None or bte.supports_chunking())) + return (bdge is None or bdge.supports_chunking()) and ( + bte is None or bte.supports_chunking() + ) def supports_expert_map(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts - return ((bdge is None or bdge.supports_expert_map()) - and (bte is None or bte.supports_expert_map())) + return (bdge is None or bdge.supports_expert_map()) and ( + bte is None or bte.supports_expert_map() + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: bdge = self.batched_deep_gemm_experts @@ -98,7 +90,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): if is_bdge_war and is_bte_war: assert bdge_war == bte_war, ( "Both implementations should agree on WeightAndReduce impls. " - f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}") + f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}" + ) if bdge_war is not None: return bdge_war @@ -106,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert bte_war is not None return bte_war + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -117,20 +111,32 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
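# Worked example (editorial, with made-up sizes): for local_num_experts=32,
# max_num_tokens=256, num_dispatchers=2, K=7168, N=4096, the DeepGemm
# workspace_shapes above size the buffers as
#   workspace13 = (32, 512, 7168)   # (experts, tokens * dispatchers, max(K, N))
#   workspace2  = (32, 512, 2048)   # (experts, tokens * dispatchers, N // 2)
#   output      = (32, 512, 7168)   # (experts, tokens * dispatchers, K)
# which upper-bounds the Triton workspaces, so allocating for DeepGemm is safe.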
if self.allow_deep_gemm: assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_metadata) + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_metadata, + ) else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_metadata) + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_metadata, + ) def apply( self, @@ -143,10 +149,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -154,11 +156,26 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - experts = (self.batched_deep_gemm_experts - if self.allow_deep_gemm else self.batched_triton_experts) + experts = ( + self.batched_deep_gemm_experts + if self.allow_deep_gemm + else self.batched_triton_experts + ) assert experts is not None - experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, - activation, global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - workspace2, expert_tokens_meta, - apply_router_weight_on_input) + experts.apply( + output, + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + activation, + global_num_experts, + expert_map, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_tokens_meta, + apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 7c1a7b636a9c2..5780c969d273a 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -4,78 +4,183 @@ from dataclasses import dataclass from typing import Optional, Union import torch -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) import vllm.envs as envs from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.utils import cdiv +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( + OCP_MX_DTYPES, + OCP_MX_Scheme, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.utils import cdiv, has_triton_kernels from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) - -def _get_quant_config_quantization_args( - quant_config: Optional[QuantizationConfig], - prop_name: str, -) -> Optional[QuantizationArgs]: - if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') - and "Linear" in quant_config.target_scheme_map and - "input_activations" in quant_config.target_scheme_map["Linear"]): - return quant_config.target_scheme_map["Linear"].get(prop_name) - else: - return None +if has_triton_kernels(): + try: + from triton_kernels.matmul_ogs 
import PrecisionConfig + except ImportError: + logger.error( + "Failed to import Triton kernels. Please make sure your triton " + "version is compatible." + ) -def get_quant_config_input_quant( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return _get_quant_config_quantization_args(quant_config, - "input_activations") - - -def get_quant_config_weight_quant( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return _get_quant_config_quantization_args(quant_config, "weights") - - -def get_config_quant_dtype( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, -) -> Union[None, torch.dtype, str]: +def _get_config_dtype_str( + dtype: torch.dtype, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: Optional[str] = None, +) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. See + try_get_optimal_moe_config in fused_moe.py. + """ if use_fp8_w8a8: - return torch.float8_e4m3fn - elif use_int8_w8a8: - return torch.int8 - elif use_mxfp4_w4a4: - return "mxfp4" + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" + elif ocp_mx_scheme is not None: + # The output of this function is passed to `try_get_optimal_moe_config`, + # and as we only simulate OCP MX execution in fused_moe for now, + # we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now. + return None + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" return None +def _quant_flags_to_group_shape( + quant_dtype: Union[torch.dtype, str, None], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]], +) -> tuple[Optional[GroupShape], Optional[GroupShape]]: + """ + Convert MoE quantization flags into more generic GroupShapes. + """ + a_shape: Optional[GroupShape] + w_shape: Optional[GroupShape] + if block_shape is not None: + assert not per_act_token_quant + assert not per_out_ch_quant + # TODO(bnell): this is not quite right for activations since first + # dim should be 1. + a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + else: + w_shape = None + a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR + + if per_act_token_quant: + a_shape = GroupShape.PER_TOKEN + + if per_out_ch_quant: + w_shape = GroupShape.PER_TOKEN + + return a_shape, w_shape + + +@dataclass +class FusedMoEQuantDesc: + """ + A quantization descriptor for fused MoE ops. This class can describe + either activations or weights. + """ + + # The quantized type of this parameters. None means unquantized or + # already quantized. + # TODO (bnell): use scalar_type instead of Union. + dtype: Union[torch.dtype, str, None] = None + + # A field that describes the quantization group shape, from quant_utils.py. + # * (-1, -1) for per-tensor quantization + # * (1, -1) for per-row quantization + # * (-1, 1) for per-column quantization + # * (128, 128) for 128x128 deepseek style block quantization + # * (1, 128) for deepseek style activation quantization + # (i.e. per-token-per-group) + shape: Optional[GroupShape] = None + + # Quantization scales. + # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc? 
+ scale: Union[torch.Tensor, "PrecisionConfig", None] = None + + # Quantization alphas or gscales, used for nvfp4 types. + # TODO(bnell): put some of these in subclasses + alpha_or_gscale: Optional[torch.Tensor] = None + + # Zero points for int4/int8 types + zp: Optional[torch.Tensor] = None + + # Biases for GPT triton MoE + bias: Optional[torch.Tensor] = None + + +# TODO(bnell): have subclasses for specific moe methods? +# e.g. for specific arguments bias, precision, etc. @dataclass class FusedMoEQuantConfig: - # The post quantization activation type. - # TODO (bnell): use scalar_type instead of Union. - quant_dtype: Union[torch.dtype, str, None] = None - per_act_token_quant: bool = False - per_out_ch_quant: bool = False - block_shape: Optional[list[int]] = None + """ + The FusedMoEQuantConfig contains all the quantization parameters for + a single FusedMoEMethodBase operation. It consists of four + FusedMoEQuantDescs, one for each activation and set of weights. - # TODO: add col major flag? - # add detailed quant info for input, intermediates, weights, etc? + Each FusedMoEMethodBase must implement a get_fused_moe_quant_config + method to construct a FusedMoEQuantConfig for use with that class. + + FusedMoEQuant configs are only used for modular kernels, fused_experts + (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and + triton_kernel_moe_forward. Other MoE methods can ignore the + FusedMoEQuantConfig (for now) and hardcode it to None. + + There are currently some restrictions on what can be expressed: + - Most MoE ops only support similar quantization strategies for + each parameter, e.g. both weights must have the same GroupShape + and both activations must share the same GroupShape. One exception to + this is the cutlass moe which allows per channel quantization on the + outputs. Note: this restrictions are not always rigorously checked. + - Not all fused MoE functions support all the parameters, e.g. zero points, + global scales, alphas and biases are not universally supported. + - Fully general GroupShapes are not allowed. Activations only support + per token, per tensor or K-blocked. + - Weights are not required to have a GroupShape since they have already + been quantized. + + Other notes: + - PrecisionConfigs are specific to GPT OSS Triton. + - As a follow up it would probably make sense to subclass FusedMoEQuantDesc + or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses + so that only the required quantization parameters are used/stored. + """ + + # TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking + _a1: FusedMoEQuantDesc + _a2: FusedMoEQuantDesc + _w1: FusedMoEQuantDesc + _w2: FusedMoEQuantDesc def __post_init__(self): - assert (not self.per_act_token_quant - or self.block_shape is None), "illegal quantization" + assert not self.per_act_token_quant or self.block_shape is None, ( + "illegal quantization" + ) + + # + # Convenience accessors for various properties. 
+ # + + @property + def quant_dtype(self) -> Union[torch.dtype, str, None]: + return self._a1.dtype @property def is_quantized(self) -> bool: @@ -83,21 +188,163 @@ class FusedMoEQuantConfig: @property def is_per_act_token(self) -> bool: - return self.per_act_token_quant + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_act_token_quant(self) -> bool: + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_out_ch_quant(self) -> bool: + return self._w1.shape == GroupShape.PER_TOKEN + + @property + def is_per_tensor(self) -> bool: + return self._a1.shape == GroupShape.PER_TENSOR + + @property + def block_shape(self) -> Optional[list[int]]: + if ( + self._a1.shape is not None + and self._a1.shape != GroupShape.PER_TENSOR + and self._a1.shape != GroupShape.PER_TOKEN + ): + return [self._a1.shape.row, self._a1.shape.col] + else: + return None @property def is_block_quantized(self) -> bool: return self.block_shape is not None @property - def is_per_tensor(self) -> bool: - return not self.per_act_token_quant and self.block_shape is None + def a1_scale(self) -> Optional[torch.Tensor]: + assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor) + return self._a1.scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self._a1.alpha_or_gscale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor) + return self._a2.scale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self._a2.alpha_or_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor) + return self._w1.scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self._w1.zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self._w1.bias + + @property + def w1_precision(self) -> Optional["PrecisionConfig"]: + assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig) + return self._w1.scale + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self._w1.alpha_or_gscale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor) + return self._w2.scale + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self._w2.zp + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self._w2.bias + + @property + def w2_precision(self) -> Optional["PrecisionConfig"]: + assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig) + return self._w2.scale + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self._w2.alpha_or_gscale + + @property + def use_fp8_w8a8(self) -> bool: + return self.quant_dtype == torch.float8_e4m3fn + + @property + def use_int8_w8a8(self) -> bool: + return self.quant_dtype == torch.int8 + + @property + def use_int8_w8a16(self) -> bool: + return self._a1.dtype is None and self._w1.dtype == torch.int8 + + @property + def use_int4_w4a16(self) -> bool: + return self._a1.dtype is None and self._w1.dtype == "int4" + + @property + def ocp_mx_scheme(self) -> Union[str, None]: + if not hasattr(self, "_ocp_mx_scheme"): + if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or ( + self._w1.dtype is not None and not isinstance(self._w1.dtype, str) + ): + self._ocp_mx_scheme = None + else: + ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( + self._a1.dtype, 
self._w1.dtype + ) + + if ocp_mx_scheme is not None: + ocp_mx_scheme = ocp_mx_scheme.value + + self._ocp_mx_scheme = ocp_mx_scheme + + return self._ocp_mx_scheme + + @property + def use_mxfp4_w4a16(self) -> bool: + return self._a1.dtype is None and self._w1.dtype == "mxfp4" + + @property + def use_nvfp4_w4a4(self) -> bool: + return self.quant_dtype == "nvfp4" + + def config_name(self, dtype: torch.dtype) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. See + try_get_optimal_moe_config in fused_moe.py. + """ + return _get_config_dtype_str( + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + ocp_mx_scheme=self.ocp_mx_scheme, + dtype=dtype, + ) def scale_shape( self, max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int]]: + """ + Construct the proper activation scale shape for this + config. + """ if self.is_quantized: if self.is_block_quantized: assert self.block_shape is not None @@ -117,6 +364,10 @@ class FusedMoEQuantConfig: max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int, int]]: + """ + Construct the proper activation batched scale shape for this + config, e.g. (num experts, *scale_shape). + """ if self.is_quantized: scale_shape = self.scale_shape(max_tokens, hidden_dim) assert scale_shape is not None @@ -126,38 +377,258 @@ class FusedMoEQuantConfig: @staticmethod def make( - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + quant_dtype: Union[torch.dtype, str, None] = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, + w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + g1_alphas: Optional[torch.Tensor] = None, + g2_alphas: Optional[torch.Tensor] = None, + a1_gscale: Optional[torch.Tensor] = None, + a2_gscale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + weight_dtype: Union[torch.dtype, str, None] = None, ) -> "FusedMoEQuantConfig": - assert sum([ - int(flag) for flag in [ - use_fp8_w8a8, - use_int8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - use_mxfp4_w4a4, - ] - ]) <= 1, "Quantization flags are mutually exclusive." + """ + General builder function for a FusedMoEQuantConfig. + - quant_dtype: Optional quantization type. None if activations are + unquantized or quantized prior to calling. Note: "nvfp4", "mxfp4", + "mxfp6_e3m2", "mxfp6_e2m3" are the only valid string values + for quant_dtype. + - per_act_token_quant: Activations have per token quantization. + - per_out_ch_quant: Outputs have per channel quantization. (only + for cutlass). + - block_shape: Optional block size for block-wise quantization. + Incompatible with per_act_token and per_out_ch quant. + - w1_scale: Optional scale to be used for w1. + - w2_scale: Optional scale to be used for w2. + - a1_scale: Optional scale to be used for a1. + - a2_scale: Optional scale to be used for a2. + - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + - g2_alphas: Optional global quantization scales for w2 (for nvfp4). 
+ - a1_gscale: Optional global quantization scales for a1 (for nvfp4). + - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + - w1_bias: Optional biases for w1 (GPT OSS Triton). + - w2_bias: Optional biases for w2 (GPT OSS Triton). + - w1_zp: Optional w1 zero points for int4/int8 quantization. + - w2_zp: Optional w2 zero points for int4/int8 quantization. + """ + assert not isinstance(quant_dtype, str) or quant_dtype in { + "nvfp4", + "mxfp4", + "mxfp6_e3m2", + "mxfp6_e2m3", + } + assert not isinstance(weight_dtype, str) or weight_dtype in { + "nvfp4", + "mxfp4", + "mxfp6_e3m2", + "mxfp6_e2m3", + } - quant_dtype = get_config_quant_dtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, + if weight_dtype is None: + weight_dtype = quant_dtype + + a_shape, w_shape = _quant_flags_to_group_shape( + quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape ) - return FusedMoEQuantConfig( - quant_dtype, - per_act_token_quant, - per_out_ch_quant, - block_shape, + quant_config = FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale), + _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale), + _w1=FusedMoEQuantDesc( + weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias + ), + _w2=FusedMoEQuantDesc( + weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias + ), ) + assert quant_config.per_act_token_quant == per_act_token_quant + assert quant_config.per_out_ch_quant == per_out_ch_quant + assert quant_config.block_shape == block_shape + return quant_config + + +def fp8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and fp8 weights. + """ + return FusedMoEQuantConfig.make( + torch.float8_e4m3fn, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + + +def int8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + per_act_token_quant: bool = False, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for int8 activations and int8 weights. + """ + return FusedMoEQuantConfig.make( + torch.int8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=None, + ) + +
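For illustration, here is a minimal usage sketch of the constructor helpers above, roughly what a quantization method's get_fused_moe_quant_config() might return for a per-tensor fp8 scheme. The import path, expert count and scale shapes are assumptions for the sketch, not part of the patch.

import torch

from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config

# Hypothetical per-expert weight scales and per-tensor activation scales.
num_experts = 8
w1_scale = torch.ones(num_experts, dtype=torch.float32)
w2_scale = torch.ones(num_experts, dtype=torch.float32)
a1_scale = torch.tensor(1.0)
a2_scale = torch.tensor(1.0)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    a1_scale=a1_scale,
    a2_scale=a2_scale,
)

# Per-tensor fp8 quantization: no per-token scaling and no block shape.
assert quant_config.use_fp8_w8a8
assert quant_config.is_per_tensor
assert quant_config.block_shape is None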
+ """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + +def ocp_mx_moe_quant_config( + quant_dtype: str, + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + weight_dtype: Optional[str] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and mxfp4 weights. + """ + assert quant_dtype in OCP_MX_DTYPES + return FusedMoEQuantConfig.make( + quant_dtype=quant_dtype, + weight_dtype=weight_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=block_shape, + ) + + +def nvfp4_moe_quant_config( + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and nvp4 weights. + """ + return FusedMoEQuantConfig.make( + "nvfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=None, + ) + + +def int4_w4a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int4 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp), + ) + + +def int8_w8a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int8 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp), + ) + + +def biased_moe_quant_config( + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], +) -> FusedMoEQuantConfig: + """ + Construct a quant config for unquantized activations with biases. + """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc(bias=w1_bias), + _w2=FusedMoEQuantDesc(bias=w2_bias), + ) + + +# A FusedMoEQuantConfig constant for an unquantized MoE op. 
+FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make() @dataclass @@ -177,28 +648,26 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") + return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx" @property def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + return ( + self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + ) @property def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - - @property - def use_flashinfer_cutlass_kernels(self): - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") + return ( + self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" + ) @staticmethod - def make(tp_size_: int, dp_size_: int, - vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + def make( + tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig + ) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input `tp_size_`, `dp_size_` and vllm's parallel config, determine what @@ -278,34 +747,37 @@ class FusedMoEParallelConfig: tp_rank = dp_rank * tp_size_ + tp_rank return tp_size, tp_rank - use_ep = (dp_size_ * tp_size_ > 1 - and vllm_parallel_config.enable_expert_parallel) + use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: - return FusedMoEParallelConfig(tp_size=tp_size, - tp_rank=tp_rank, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=1, - ep_rank=0, - use_ep=False) + return FusedMoEParallelConfig( + tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False, + ) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size=1, - tp_rank=0, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - use_ep=True) + return FusedMoEParallelConfig( + tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True, + ) # Adapted from pplx-kernels tests/all_to_all_utils.py @@ -321,47 +793,18 @@ class FusedMoEConfig: # The activation type. 
in_dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] = None - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False def __post_init__(self): if self.dp_size > 1: - logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", - self.max_num_tokens) + logger.debug_once( + "Using FusedMoEConfig::max_num_tokens=%d", self.max_num_tokens + ) assert self.max_num_tokens > 0 - @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: - if self.quant_config is not None: - return self.quant_config.quant_dtype - else: - return None - - @property - def block_shape(self) -> Optional[list[int]]: - if self.quant_config is not None: - return self.quant_config.block_shape - else: - return None - - @property - def per_act_token_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_act_token_quant - else: - return False - - @property - def per_out_ch_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_out_ch_quant - else: - return False - @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -404,87 +847,11 @@ class FusedMoEConfig: @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels - - @staticmethod - def make( - num_experts: int, - experts_per_token: int, - hidden_dim: int, - num_local_experts: int, - moe_parallel_config: FusedMoEParallelConfig, - in_dtype: torch.dtype, - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None, - has_bias: bool = False, - ) -> "FusedMoEConfig": - - _quant_config: Optional[FusedMoEQuantConfig] = None - - if quant_config is not None and isinstance(quant_config, - QuantizationConfig): - if hasattr(quant_config, 'weight_block_size'): - block_shape = quant_config.weight_block_size - else: - block_shape = None - per_act_token_quant = False - per_out_ch_quant = False - quant_dtype: Union[torch.dtype, str, None] = None - - input_quant = get_quant_config_input_quant(quant_config) - weight_quant = get_quant_config_weight_quant(quant_config) - - if input_quant is not None: - per_act_token_quant = (input_quant.strategy - == QuantizationStrategy.TOKEN - if input_quant is not None else False) - - if input_quant.num_bits == 8: - if input_quant.type == QuantizationType.FLOAT: - quant_dtype = torch.float8_e4m3fn - elif input_quant.type == QuantizationType.INT: - quant_dtype = torch.int8 - - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if quant_dtype is None and isinstance(quant_config, Fp8Config): - quant_dtype = torch.float8_e4m3fn - - from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptNvFp4Config) - if quant_dtype is None and isinstance(quant_config, - ModelOptNvFp4Config): - quant_dtype = "nvfp4" - - if weight_quant is not None: - per_out_ch_quant = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL) - - if quant_dtype is not None: - _quant_config = FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - ) - else: - _quant_config = FusedMoEQuantConfig() - if moe_parallel_config.dp_size > 1: - logger.warning_once("MoE DP setup unable to determine " - "quantization scheme or unsupported " - "quantization type. 
This model will " - "not run with DP enabled.") - else: - _quant_config = quant_config - - return FusedMoEConfig( - num_experts=num_experts, - experts_per_token=experts_per_token, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=in_dtype, - quant_config=_quant_config, - max_num_tokens=max_num_tokens, - has_bias=has_bias, + """ + Whether to use FlashInfer cutlass kernels for NVFP4 MoE. + """ + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput" ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..99501df6f1764 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + } + } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json index 2c78bfaba7890..2e0dd7a4b9507 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -1,218 +1,146 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 8, + "num_warps": 4, "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 8, + "num_warps": 4, "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 + "num_warps": 4, + "num_stages": 5 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 + "num_warps": 4, + "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, "48": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2 + "num_stages": 4 }, "256": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "512": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 3 }, "1536": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, "num_stages": 3 }, "2048": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 }, "3072": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 }, "4096": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "5120": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "9216": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "13312": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "17408": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "25600": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "33792": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "41984": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "50176": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "58368": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 } } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json index 4da841e74a79f..4ea86340c3243 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json @@ -5,7 +5,7 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "2": { "BLOCK_SIZE_M": 16, @@ -13,7 +13,7 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, @@ -21,7 +21,7 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, @@ -29,7 +29,7 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "16": { "BLOCK_SIZE_M": 16, @@ -37,52 +37,52 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 + 
"num_warps": 8, + "num_stages": 3 }, "48": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "96": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, "num_stages": 4 }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 @@ -91,57 +91,57 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 }, "512": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "num_warps": 4, + "num_stages": 3 }, "2048": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, "3072": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "4096": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "5120": { "BLOCK_SIZE_M": 128, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..f3f1a562710b0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..19046fcf1d6a2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..5f9422fe6f7c4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..600bd4444535a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json @@ -0,0 +1,123 @@ +{ + "triton_version": "3.4.0", + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 
0000000000000..e5059358c91e3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..db1b6e98df469 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..b962d19506ce5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json index b9dc2d71f6dcf..1bbb8aa613996 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -9,16 +9,16 @@ }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -26,15 +26,15 @@ "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -42,7 +42,7 @@ "24": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -53,12 +53,12 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -82,10 +82,10 @@ "128": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "256": { "BLOCK_SIZE_M": 16, @@ -98,8 +98,8 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 
256, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -123,15 +123,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..8fb4947d62ab2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..6d0cdfd274293 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..de8eec366eca3 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } + } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..80fce79fb64c9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } + } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..54d3bf190ebec --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json index 26f9abd6b789e..6a4018195603a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -1,29 +1,5 @@ { "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, @@ -31,44 +7,68 @@ "num_warps": 4, "num_stages": 5 }, - "16": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, - "24": { + "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, 
"num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 }, "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, "96": { @@ -77,22 +77,22 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "128": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 + "num_warps": 8, + "num_stages": 3 }, "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 4 }, "512": { @@ -100,47 +100,47 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 8, + "num_warps": 4, "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 }, "4096": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 } } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..4f500d487c56d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000000000..ed8afa6b6db88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..5fea55a8000ff --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json index bbb2386046b11..1e3f46e0ba84a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -2,7 +2,7 @@ "1": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -20,78 +20,78 @@ "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 + "num_warps": 8, + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "64": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, "96": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 3 }, "256": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -100,47 +100,47 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 128, 
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "4096": { - "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 } } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..8239492d8f4f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bdbaf3811c939 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..6e17bcd214748 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 
0000000000000..aa7610cd75e77 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..df920e8b39ba8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e8fe8ea67f246 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..0baf13cb6a5c5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json index 307c9240938c5..c7998718dab4c 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -18,18 +18,18 @@ "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, @@ -58,7 +58,7 @@ "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 @@ -74,73 +74,73 @@ "96": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 4 }, "256": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, 
- "num_warps": 8, - "num_stages": 5 + "num_warps": 4, + "num_stages": 3 }, - "3072": { - "BLOCK_SIZE_M": 128, + "2048": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 }, "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc853947c19f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + 
"num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bf97f671477b3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..24f13cdeff4f8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ 
-0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..b4e736bec9b65 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bb71005a72bc5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..ac53df14ce846 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 
16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..f1ed617d6308f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e72282dc5bcd9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..4fc4868eaa85a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..d70adca05e779 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..0f5867fea5f89 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d104aa5167b22 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..22e3d09676d06 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} + diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..94408e279b656 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + 
"num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..9f4c3cbc9b8a9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..20146f53a6eba --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d0140252594f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..8bac7af0c2dac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc1427c139e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..68649395a23ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..2f0b45014e863 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..5d69efe9ed5f9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..5910027e17f9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..564ff499d43c4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..a68c83147eeb3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..e55df46b40269 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..a0855a921f3f6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..5dd1a8e19c2ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..d5b6d02123d71 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..40d86ff8ba324 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 
0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..6014d827d7417 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..3622659f3e915 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..311d2e829a050 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + 
}, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..91c4b916b8649 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 
1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..8fee30ec70660 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 256, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..d677d69c57a25 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index e67ff66882102..3592a88b0ef2f 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -3,14 +3,136 @@ from typing import Callable, Optional import torch +from torch.nn import functional as F from vllm import envs -class IPEXFusedMOE: +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + +def swigluoai_and_mul( + x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0 +) -> torch.Tensor: + d = x.shape[-1] // 2 + gate, up = x[..., :d], x[..., d:] + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(alpha * gate) + return (up + 1) * glu + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights, topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if use_grouped_topk: + assert 
topk_group is not None + assert num_expert_group is not None + return grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_logit_vals, topk_idx = torch.topk( + router_logits, k=top_k, dim=-1, sorted=False + ) + if renormalize: + topk_vals = torch.softmax(topk_logit_vals, dim=-1) + else: + logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True) + topk_vals = (topk_logit_vals - logZ).exp() + return topk_vals.to(torch.float32), topk_idx.to(torch.int32) + else: + return custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + + +class IPEXFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, @@ -31,12 +153,16 @@ class IPEXFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported." + ) return layer.ipex_fusion( x, use_grouped_topk, @@ -52,117 +178,9 @@ class IPEXFusedMOE: class SGLFusedMOE: - def __init__(self, layer: torch.nn.Module) -> None: pass - @staticmethod - def _grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - - gating_output = gating_output.float() - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.shape[0] - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. 
We use - # biased scores for expert selection but original scores for - # routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) - else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, - k=topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, - -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - - return topk_weights, topk_ids.to(torch.int32) - - @staticmethod - def _select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = SGLFusedMOE._grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - elif custom_routing_function is None: - assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) - if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - topk_ids = topk_ids.to(torch.int32) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - def __call__( self, layer: torch.nn.Module, @@ -177,13 +195,14 @@ class SGLFusedMOE: expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." 
assert not apply_router_weight_on_input - topk_weights, topk_ids = SGLFusedMOE._select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -193,6 +212,7 @@ class SGLFusedMOE: num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, ) @@ -213,3 +233,94 @@ class SGLFusedMOE: True, ) return x + + +class CPUFusedMOE: + def __init__(self, layer: torch.nn.Module) -> None: + pass + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation in {"silu", "swigluoai"}, f"{activation} is not supported." + assert not apply_router_weight_on_input + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 + len_experts = global_num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + has_w13_bias = hasattr(layer, "w13_bias") + has_w2_bias = hasattr(layer, "w2_bias") + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None + layer_w2_weight = layer.w2_weight[i] + layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None + + gate_up = F.linear( + tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias + ) + if activation == "swigluoai": + gate_up = swigluoai_and_mul(gate_up) + else: + gate_up = silu_and_mul(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weights.dtype) + .mul_(topk_weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 95d23ec0346c1..fa158287d418d 100644 --- 
a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" CUTLASS based Fused MoE kernels.""" +"""CUTLASS based Fused MoE kernels.""" + from typing import Callable, Optional import torch @@ -10,13 +11,17 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_unpermute) + moe_permute, + moe_unpermute, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) + TopKWeightAndReduceDelegate, + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, _resize_cache from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -56,20 +61,28 @@ def run_cutlass_moe_fp8( assert w2.dtype == torch.float8_e4m3fn assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1" assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.size( - 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.size( - 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert ( + w1_scale.dim() == 1 or w1_scale.size(1) == 1 or w1_scale.shape[1] == w1.size(1) + ), "W1 scale shape mismatch" + assert ( + w2_scale.dim() == 1 or w2_scale.size(1) == 1 or w2_scale.shape[1] == w2.size(1) + ), "W2 scale shape mismatch" assert w1.size(0) == w2.size(0), "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( - 0) == 1 or a1q_scale.size( - 0) == a1q.shape[0], "Input scale shape mismatch" + assert ( + a1q_scale is None + or a1q_scale.dim() == 0 + or a1q_scale.size(0) == 1 + or a1q_scale.size(0) == a1q.shape[0] + ), "Input scale shape mismatch" assert w1.size(0) == w2.size(0), "Weights expert number mismatch" assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( - 0) == 1 or a2_scale.size( - 0) == a1q.shape[0], "Intermediate scale shape mismatch" + assert ( + a2_scale is None + or a2_scale.dim() == 0 + or a2_scale.size(0) == 1 + or a2_scale.size(0) == a1q.shape[0] + ), "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -97,8 +110,9 @@ def run_cutlass_moe_fp8( if expert_map is not None: "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids] != -1, - expert_map[topk_ids], -1) + local_topk_ids = torch.where( + expert_map[topk_ids] != -1, expert_map[topk_ids], -1 + ) else: local_topk_ids = topk_ids @@ -108,35 +122,39 @@ def run_cutlass_moe_fp8( if use_batched_format: mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) act_out = _resize_cache(workspace2, (local_E * padded_M, N)) - quant_out = 
_resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (local_E * padded_M, N)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (local_E * padded_M, N) + ) mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) else: - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), - (M * topk, K)) + a1q_perm = _resize_cache( + workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K) + ) mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) act_out = _resize_cache(workspace2, (M * topk, N)) # original workspace are based on input hidden_states dtype (bf16) - quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (M * topk, N)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N) + ) mm2_out = _resize_cache(workspace2, (M * topk, K)) if use_batched_format: assert expert_num_tokens is not None - expert_offsets = torch.empty((local_E), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((local_E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((local_E, 3), - dtype=torch.int32, - device=device) + expert_offsets = torch.empty((local_E), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) - ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, - problem_sizes2, expert_num_tokens, - local_E, padded_M, N, K) + ops.get_cutlass_pplx_moe_mm_data( + expert_offsets, + problem_sizes1, + problem_sizes2, + expert_num_tokens, + local_E, + padded_M, + N, + K, + ) w1_scale = w1_scale.reshape(w1_scale.size(0), -1) w2_scale = w2_scale.reshape(w2_scale.size(0), -1) @@ -146,15 +164,14 @@ def run_cutlass_moe_fp8( # during offset calculations expert_offsets = expert_offsets.to(torch.int64) else: - problem_sizes1 = torch.empty((global_num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((global_num_experts, 3), - dtype=torch.int32, - device=device) + problem_sizes1 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + problem_sizes2 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) - num_expert = global_num_experts if expert_map is None \ - else expert_map.size(0) + num_expert = global_num_experts if expert_map is None else expert_map.size(0) # permuted a1q reuses workspace2 a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( a1q, @@ -163,12 +180,13 @@ def run_cutlass_moe_fp8( num_expert, local_E, expert_map, - permuted_hidden_states=a1q_perm) + permuted_hidden_states=a1q_perm, + ) expert_offsets = expert_offsets[:-1] - ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, - problem_sizes2, - global_num_experts, N, K) + ops.get_cutlass_moe_mm_problem_sizes( + local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K + ) if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by @@ -176,56 +194,70 @@ def run_cutlass_moe_fp8( # this rank handles only partial tokens, or when it is batched . 
mm1_out.fill_(0) - ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, - problem_sizes1, ab_strides1, ab_strides1, c_strides1, - per_act_token, per_out_ch) + ops.cutlass_moe_mm( + mm1_out, + a1q, + w1, + a1q_scale, + w1_scale, + expert_offsets, + problem_sizes1, + ab_strides1, + ab_strides1, + c_strides1, + per_act_token, + per_out_ch, + ) activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - act_out, - a2_scale, - use_per_token_if_dynamic=per_act_token, - output=quant_out) + act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out + ) if expert_map is not None: mm2_out.fill_(0) - ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, - problem_sizes2, ab_strides2, ab_strides2, c_strides2, - per_act_token, per_out_ch) + ops.cutlass_moe_mm( + mm2_out, + a2q, + w2, + a2q_scale, + w2_scale, + expert_offsets, + problem_sizes2, + ab_strides2, + ab_strides2, + c_strides2, + per_act_token, + per_out_ch, + ) if use_batched_format: output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: # for non-chunking mode the output is resized from workspace13 # so we need to make sure mm2_out uses workspace2. - moe_unpermute(out=output, - permuted_hidden_states=mm2_out, - topk_weights=topk_weights, - inv_permuted_idx=inv_perm) + moe_unpermute( + out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + assert quant_config.use_fp8_w8a8 + super().__init__(quant_config) self.out_dtype = out_dtype self.ab_strides1 = ab_strides1 self.ab_strides2 = ab_strides2 @@ -247,10 +279,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -258,8 +286,8 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" expert_num_tokens = None if expert_tokens_meta is not None: @@ -267,50 +295,66 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): activation_callable = lambda o, i: self.activation(activation, o, i) - use_batched_format = self.activation_formats[ - 0] == mk.FusedMoEActivationFormat.BatchedExperts + use_batched_format = ( + self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + ) in_dtype = hidden_states.dtype run_cutlass_moe_fp8( - output, hidden_states, w1, w2, topk_ids, activation_callable, - 
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, - self.c_strides2, workspace13, workspace2, expert_num_tokens, + output, + hidden_states, + w1, + w2, + topk_ids, + activation_callable, + global_num_experts, + expert_map, + self.w1_scale, + self.w2_scale, + a1q_scale, + a2_scale, + self.ab_strides1, + self.ab_strides2, + self.c_strides1, + self.c_strides2, + workspace13, + workspace2, + expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, - self.per_act_token_quant, self.per_out_ch_quant, - use_batched_format, topk_weights) + self.per_act_token_quant, + self.per_out_ch_quant, + use_batched_format, + topk_weights, + ) class CutlassExpertsFp8(CutlassExpertsFp8Base): - def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -322,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): # topk weights and reduction are fused in moe_unpermute cuda kernel return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -333,38 +378,32 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M * topk, max(N, K)) workspace2 = (M * topk, max(N // 2, K)) output = (M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + return (workspace1, workspace2, output) class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - def __init__( self, max_experts_per_worker: int, num_dispatchers: int, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker @@ -372,10 +411,12 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + 
mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -383,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): def supports_expert_map(self) -> bool: return False - # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -395,17 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - padded_M = aq.size(1) + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, - max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, - max(N // 2, K)) - output = (self.max_experts_per_worker, padded_M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) + workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) + output = (self.max_experts_per_worker, M, K) + return (workspace1, workspace2, output) def cutlass_moe_fp8( @@ -414,16 +451,12 @@ def cutlass_moe_fp8( w2_q: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - per_act_token: Optional[bool] = None, + quant_config: FusedMoEQuantConfig, activation: str = "silu", - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -475,24 +508,28 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. 
""" - if per_act_token is None: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.size(0) + assert quant_config is not None - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( - 0) + if quant_config.a1_scale is not None: + assert quant_config.per_act_token_quant == quant_config.a1_scale.numel() != 1 + if quant_config.a2_scale is not None: + assert quant_config.per_act_token_quant == quant_config.a2_scale.numel() != 1 + + assert quant_config.w1_scale is None or ( + quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) == w1_q.size(1)) + ) + + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( out_dtype=a.dtype, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_out_ch, ab_strides1=ab_strides1, ab_strides2=ab_strides2, c_strides1=c_strides1, c_strides2=c_strides2, + quant_config=quant_config, ), ) @@ -502,14 +539,9 @@ def cutlass_moe_fp8( w2_q, topk_weights, topk_ids, - False, - activation, - num_experts, - expert_map, - w1_scale, - w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -542,7 +574,7 @@ def run_cutlass_moe_fp4( ) -> None: """ MoE implementation for FP4 Inputs - + # Gemm 1 a: Input tensor: [m, k] (half/bfloat16) a1_gscale: Activation scale per expert: [e] (float32) @@ -552,16 +584,16 @@ def run_cutlass_moe_fp4( full precision) w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) (Block size = 16 for NVFP4) - + # Gemm 2 a2_gscale: Activation scale per expert: [e] w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 - + topk_weights: [m, topk] dtype: float8 topk_ids: [m, topk] dtype: float8 - + m, n, k: Unquantized weight shapes, dtype: int e: number of experts, dtype: int @@ -570,25 +602,30 @@ def run_cutlass_moe_fp4( assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" - assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 - and w2_blockscale.ndim - == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + assert ( + w1_fp4.ndim == 3 + and w2_fp4.ndim == 3 + and w1_blockscale.ndim == 3 + and w2_blockscale.ndim == 3 + ), "All Weights must be of rank 3 for cutlass_moe_fp4" m_a, k_a = a.shape e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert (e_w1 == e_w2 - and e_w1 == e), ("Number of experts must match", - f" between weights. {e_w1}, {e_w2}, {e}") - assert (k_a == half_k_w1 * 2 - and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " - "expected `n`") - assert (m == m_a), "input shape mismatch" + assert e_w1 == e_w2 and e_w1 == e, ( + "Number of experts must match", + f" between weights. 
{e_w1}, {e_w2}, {e}", + ) + assert k_a == half_k_w1 * 2 and k == k_w2, ( + "Hidden size mismatch between a, w1 and w2" + ) + assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`" + assert m == m_a, "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.size(0) == m and topk_ids.size(0) - == m), ("topk must be provided for each row of a") + assert topk_weights.size(0) == m and topk_ids.size(0) == m, ( + "topk must be provided for each row of a" + ) topk = topk_ids.size(1) out_dtype = a.dtype num_topk = topk_ids.size(1) @@ -605,15 +642,25 @@ def run_cutlass_moe_fp4( if apply_router_weight_on_input: # TODO: this only works for topK=1, will need to update for topK>1 - assert num_topk == 1, \ + assert num_topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a.mul_(topk_weights.to(out_dtype)) # problem shapes should have [m, n, k] # Note that problem sizes are based on logical number of elements. - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, e, n, k, - blockscale_offsets) + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + e, + n, + k, + blockscale_offsets, + ) a = ops.shuffle_rows(a, a_map) rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( @@ -626,17 +673,34 @@ def run_cutlass_moe_fp4( c1 = _resize_cache(workspace13, (m * topk, n * 2)) c2 = _resize_cache(workspace2, (m * topk, n)) c3 = _resize_cache(workspace13, (m * topk, k)) - ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, - w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1]) + ops.cutlass_fp4_moe_mm( + c1, + rep_a_fp4, + w1_fp4, + rep_a_blockscale, + w1_blockscale, + w1_alphas, + problem_sizes1, + expert_offsets[:-1], + blockscale_offsets[:-1], + ) del rep_a_fp4, rep_a_blockscale torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk + ) - ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, - w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1]) + ops.cutlass_fp4_moe_mm( + c3, + int_fp4, + w2_fp4, + int_blockscale, + w2_blockscale, + w2_alphas, + problem_sizes2, + expert_offsets[:-1], + blockscale_offsets[:-1], + ) del int_fp4, int_blockscale c3 = ops.shuffle_rows(c3, c_map) @@ -644,60 +708,45 @@ def run_cutlass_moe_fp4( assert output.dtype == out_dtype if not apply_router_weight_on_input: output.copy_( - (c3.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), - non_blocking=True) + ( + c3.view(m, num_topk, k) + * topk_weights.view(m, num_topk, 1).to(out_dtype) + ).sum(dim=1), + non_blocking=True, + ) else: output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True) return +# Split into batched and non-batched class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, - per_act_token_quant: bool, - per_out_ch_quant: bool, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, use_batched_format: bool = False, ): - super().__init__( - # NVFP4 
requires two levels of quantization, which involves - # computing some scaling factors dynamically. This makes it - # incompatible with the typical prepare -> MoE -> finalize - # pipeline. Move the quantization logic into the MoE body. - FusedMoEQuantConfig( - quant_dtype=None, # skip quantization in prepare/finalize - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + super().__init__(quant_config) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.use_batched_format = use_batched_format - # TODO(bnell): put this stuff into quant config? - self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale - @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: if self.use_batched_format: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) else: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_expert_map(self) -> bool: return False @@ -708,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -719,21 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] 
= () if self.use_batched_format: - padded_M = aq.size(1) - workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) + workspace1 = (self.max_experts_per_worker, M, max(N, K)) + workspace2 = (self.max_experts_per_worker, M, (N // 2)) + output = (self.max_experts_per_worker, M, K) else: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -746,18 +794,14 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, + a1q_scale: Optional[torch.Tensor], # unused + a2_scale: Optional[torch.Tensor], # unused workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids) + e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 run_cutlass_moe_fp4( @@ -765,11 +809,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): a=hidden_states, a1_gscale=self.a1_gscale, w1_fp4=w1, - w1_blockscale=w1_scale, + w1_blockscale=self.w1_scale, w1_alphas=self.g1_alphas, a2_gscale=self.a2_gscale, w2_fp4=w2, - w2_blockscale=w2_scale, + w2_blockscale=self.w2_scale, w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, @@ -785,37 +829,49 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def cutlass_moe_fp4( - a: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - m: int, - n: int, - k: int, - e: int, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False) -> torch.Tensor: - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_config: FusedMoEQuantConfig, + m: int, + n: int, + k: int, + e: int, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + assert expert_map is None, ( + "Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE's cutlass_moe_fp4." + ) + + # TODO(bnell): this feels a bit hacky + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. 
+ quant_config = FusedMoEQuantConfig.make( + quant_dtype=None, # skip quantization in prepare/finalize + per_act_token_quant=quant_config.per_act_token_quant, + per_out_ch_quant=quant_config.per_out_ch_quant, + block_shape=quant_config.block_shape, + g1_alphas=quant_config.g1_alphas, + g2_alphas=quant_config.g2_alphas, + a1_gscale=quant_config.a1_gscale, + a2_gscale=quant_config.a2_gscale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + ) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( - g1_alphas, - g2_alphas, - a1_gscale, - a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, - per_act_token_quant=False, - per_out_ch_quant=False, + quant_config=quant_config, use_batched_format=False, ), ) @@ -830,19 +886,18 @@ def cutlass_moe_fp4( activation="silu", global_num_experts=e, expert_map=None, - w1_scale=w1_blockscale, - w2_scale=w2_blockscale, - a1_scale=None, - a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, ) def _valid_cutlass_block_scaled_grouped_gemm( - w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str, - apply_router_weight_on_input: bool, - expert_map: Optional[torch.Tensor]) -> bool: - + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + apply_router_weight_on_input: bool, + expert_map: Optional[torch.Tensor], +) -> bool: def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): return N % 128 == 0 and K % 128 == 0 @@ -856,7 +911,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( ) return False - if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn: logger.debug_once( "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). " "w1.dtype: %s, w2.dtype: %s", @@ -867,19 +922,21 @@ def _valid_cutlass_block_scaled_grouped_gemm( if expert_map is not None: logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: expert_parallel is" - " not supported.") + "CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported." + ) return False if activation != "silu": logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: only activation silu is" - " supported.") + "CutlassBlockScaledGroupedGemm disabled: only activation silu is supported." + ) return False if apply_router_weight_on_input: - logger.debug_once("CutlassBlockScaledGroupedGemm disabled:" - " apply_router_weight_on_input is not supported.") + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled:" + " apply_router_weight_on_input is not supported." + ) return False if inplace: @@ -891,6 +948,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( return True +# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8. 
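For orientation (not part of the patch): after this refactor the scale tensors ride along in a FusedMoEQuantConfig instead of being passed into cutlass_moe_fp8 / cutlass_moe_fp4 and the apply() methods directly, and the experts classes read them back as self.w1_scale, self.w2_scale, self.g1_alphas and so on. A minimal construction sketch for the fp8 case, using the fp8_w8a8_moe_quant_config helper imported elsewhere in this series; the toy scale shapes and the helper's optional defaults are assumptions:

import torch

from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config

E = 8  # toy number of experts
# Per-expert, per-tensor weight scales (dim() == 1 is accepted by the cutlass asserts above).
w1_scale = torch.ones(E, dtype=torch.float32)
w2_scale = torch.ones(E, dtype=torch.float32)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    a1_scale=None,  # dynamic activation quantization
    a2_scale=None,
)

# The config is then handed to the kernels, e.g.
#   cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
#                   ab_strides1=..., ab_strides2=...,
#                   c_strides1=..., c_strides2=...,
#                   quant_config=quant_config)
# or CutlassExpertsFp8(..., quant_config=quant_config). DeepGemmExperts additionally
# asserts block_shape == deep_gemm_block_shape(), so its config must set block_shape.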
def run_cutlass_block_scaled_fused_experts( a: torch.Tensor, w1: torch.Tensor, @@ -906,17 +964,16 @@ def run_cutlass_block_scaled_fused_experts( w2_scale = w2_scale.transpose(1, 2) assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert a.shape[0] == topk_ids.shape[ - 0], "a and topk_ids must have the same batch size" + assert a.shape[0] == topk_ids.shape[0], ( + "a and topk_ids must have the same batch size" + ) assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1_scale expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2_scale expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch" assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" out_dtype = a.dtype @@ -927,21 +984,14 @@ def run_cutlass_block_scaled_fused_experts( topk = topk_ids.size(1) - a_q, a1_scale = _fp8_quantize(a, - A_scale=None, - per_act_token=False, - block_shape=[128, 128]) + a_q, a1_scale = _fp8_quantize( + a, A_scale=None, per_act_token=False, block_shape=[128, 128] + ) device = a_q.device - expert_offsets = torch.empty((num_experts + 1, ), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) + expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) @@ -977,10 +1027,9 @@ def run_cutlass_block_scaled_fused_experts( intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) torch.ops._C.silu_and_mul(intermediate, c1) - intermediate_q, a2_scale = _fp8_quantize(intermediate, - A_scale=None, - per_act_token=False, - block_shape=[128, 128]) + intermediate_q, a2_scale = _fp8_quantize( + intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128] + ) ops.cutlass_blockwise_scaled_grouped_mm( c2, @@ -992,5 +1041,6 @@ def run_cutlass_block_scaled_fused_experts( expert_offsets[:-1], ) - return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) + return ( + c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) + ).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 7b8467a5a0cf0..fc0cb5c530da6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import Optional import torch @@ -9,37 +8,40 @@ from tqdm import tqdm import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger 
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce) + compute_aligned_M, + deep_gemm_block_shape, + deepgemm_moe_permute, + deepgemm_unpermute_and_reduce, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.utils import has_deep_gemm, run_once from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous logger = init_logger(__name__) -@functools.cache -def deep_gemm_block_shape() -> list[int]: - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - block = dg.get_m_alignment_for_contiguous_layout() - return [block, block] - - def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: align = deep_gemm_block_shape()[0] return align <= M and N % align == 0 and K % align == 0 -def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def _valid_deep_gemm( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: """ Check if the given problem size is supported by the DeepGemm grouped gemm kernel. All of M, N, K and the quantization block_shape must be @@ -57,13 +59,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( "DeepGemm disabled due to unaligned problem size. " - "M: %s, N: %s, K: %s. M should >= align size " - "and N and K must be multiples of %s." + "M: %s, N: %s, K: %s. M should >= %s " + "and N and K must be multiples of %s. " "This is not an error and we will fall back to triton.", M, N, K, align, + align, ) return False elif N <= 512: @@ -77,17 +80,19 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, ) return False - if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn: logger.debug_once( - "DeepGemm disabled: invalid weight dtype(s). " - "w1.dtype: %s, w2.dtype: %s", + "DeepGemm disabled: invalid weight dtype(s). w1.dtype: %s, w2.dtype: %s", w1.dtype, w2.dtype, ) return False - if (not hidden_states.is_contiguous() or not w1.is_contiguous() - or not w2.is_contiguous()): + if ( + not hidden_states.is_contiguous() + or not w1.is_contiguous() + or not w2.is_contiguous() + ): logger.debug_once( "DeepGemm disabled: weights or activations not contiguous. " "hidden_states.is_contiguous(): %s, w1.is_contiguous(): %s, " @@ -102,10 +107,13 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @run_once -def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int): +def warmup_deepgemm_gg_contiguous_kernels( + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int, +): """ DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the input tensor shapes. 
In this function, we construct all possible input @@ -114,8 +122,7 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, call and not during actual model inference. """ - assert w1.size(0) == w2.size(0), ( - "w1 and w2 must have the same number of experts") + assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" block_m = deep_gemm_block_shape()[0] num_experts = w1.size(0) @@ -123,36 +130,39 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. - MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None) + MAX_M = compute_aligned_M( + env.VLLM_FUSED_MOE_CHUNK_SIZE, + num_topk, + num_experts, + block_m, + expert_tokens_meta=None, + ) # Distribute expert-ids evenly. MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint(low=0, - high=num_experts, - size=(MAX_BLOCKS, ), - device=device, - dtype=torch.int32) + expert_ids_block = torch.randint( + low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 + ) expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn) - a1q_scales = torch.empty((MAX_M, k // block_m), - device=device, - dtype=torch.float32) + a1q_scales = torch.empty( + (MAX_M, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=MAX_BLOCKS, - desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})") + pbar = tqdm( + total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})" + ) num_tokens = MAX_M while num_tokens > 0: m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), - out[:num_tokens], expert_ids[:num_tokens]) + (a1q[:num_tokens], a1q_scales[:num_tokens]), + (w, w_scale), + out[:num_tokens], + expert_ids[:num_tokens], + ) pbar.update(1) num_tokens = num_tokens - block_m @@ -161,21 +171,21 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - - def __init__(self): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=deep_gemm_block_shape(), - )) + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + assert quant_config.block_shape == deep_gemm_block_shape() + assert quant_config.quant_dtype == torch.float8_e4m3fn + assert not quant_config.per_act_token_quant + assert not quant_config.per_out_ch_quant @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -188,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -197,17 +205,18 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], 
tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.block_shape is not None block_m = self.block_shape[0] - M_sum = compute_aligned_M(M, topk, local_num_experts, block_m, - expert_tokens_meta) + M_sum = compute_aligned_M( + M, topk, local_num_experts, block_m, expert_tokens_meta + ) assert M_sum % block_m == 0 workspace1 = (M_sum, max(N, K)) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -220,10 +229,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -231,10 +236,11 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert self.block_shape is not None assert a1q_scale is not None - assert w1_scale is not None - assert w2_scale is not None + assert a2_scale is None + assert self.block_shape is not None + assert self.w1_scale is not None + assert self.w2_scale is not None a1q = hidden_states _, N, K = w1.size() @@ -245,18 +251,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w2.size(1) == K - M_sum = compute_aligned_M(M=topk_ids.size(0), - num_topk=topk_ids.size(1), - local_num_experts=local_num_experts, - alignment=deep_gemm_block_shape()[0], - expert_tokens_meta=expert_tokens_meta) + M_sum = compute_aligned_M( + M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=deep_gemm_block_shape()[0], + expert_tokens_meta=expert_tokens_meta, + ) - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), - (M_sum, K)) + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K)) mm1_out = _resize_cache(workspace13, (M_sum, N)) act_out = _resize_cache(workspace2, (M_sum, N // 2)) - quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (M_sum, N // 2)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + ) mm2_out = _resize_cache(workspace2, (M_sum, K)) a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( @@ -266,32 +274,36 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): local_num_experts=local_num_experts, expert_map=expert_map, expert_tokens_meta=expert_tokens_meta, - aq_out=a1q_perm) + aq_out=a1q_perm, + ) assert a1q.size(0) == M_sum - m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), - mm1_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous( + (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids + ) self.activation(activation, act_out, mm1_out.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(act_out, - self.block_shape[1], - column_major_scales=True, - out_q=quant_out) + a2q, a2q_scale = per_token_group_quant_fp8( + act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out + ) - m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), - mm2_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous( + (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids + ) if apply_router_weight_on_input: topk_weights = 
torch.ones_like(topk_weights) - deepgemm_unpermute_and_reduce(a=mm2_out, - topk_ids=topk_ids, - topk_weights=topk_weights, - inv_perm=inv_perm, - expert_map=expert_map, - output=output) + deepgemm_unpermute_and_reduce( + a=mm2_out, + topk_ids=topk_ids, + topk_weights=topk_weights, + inv_perm=inv_perm, + expert_map=expert_map, + output=output, + ) def deep_gemm_moe_fp8( @@ -347,9 +359,17 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. """ + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=deep_gemm_block_shape(), + ) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(), + DeepGemmExperts(quant_config), ) return fn( hidden_states, @@ -357,13 +377,9 @@ def deep_gemm_moe_fp8( w2, topk_weights, topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index c8469501af5db..2ac968a9b4ab4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -20,27 +20,33 @@ from vllm.utils import round_up def deep_gemm_block_shape() -> list[int]: # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() return [block, block] -def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor, - alignment: int) -> int: +def expert_num_tokens_round_up_and_sum( + expert_num_tokens: torch.Tensor, alignment: int +) -> int: # Round up each element in expert_num_tokens to the nearest multiple of # alignment. - ent = (expert_num_tokens.to(torch.int64) + - (alignment - 1)) // alignment * alignment + ent = (expert_num_tokens.to(torch.int64) + (alignment - 1)) // alignment * alignment return torch.sum(ent).item() -def compute_aligned_M(M: int, num_topk: int, local_num_experts: int, - alignment: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): - - if ((expert_tokens_meta is not None) - and (expert_tokens_meta.expert_num_tokens_cpu is not None)): +def compute_aligned_M( + M: int, + num_topk: int, + local_num_experts: int, + alignment: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], +): + if (expert_tokens_meta is not None) and ( + expert_tokens_meta.expert_num_tokens_cpu is not None + ): return expert_num_tokens_round_up_and_sum( - expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment) + expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment + ) # expert_num_tokens information is not available on the cpu. # compute the max required size. 
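A tiny worked example of the alignment logic above (illustrative; assumes the usual DeepGemm block alignment of 128 returned by deep_gemm_block_shape()): each local expert's token count is padded up to the block size independently before summing, and the resulting M_sum is what sizes the grouped-gemm workspaces.

import torch

alignment = 128
expert_num_tokens = torch.tensor([5, 0, 130, 200])  # tokens routed to 4 local experts

# Same arithmetic as expert_num_tokens_round_up_and_sum above.
ent = (expert_num_tokens.to(torch.int64) + (alignment - 1)) // alignment * alignment
print(ent.tolist(), int(ent.sum()))  # [128, 0, 256, 256] 640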
@@ -74,14 +80,14 @@ def _fwd_kernel_ep_scatter_1( cur_expert = tl.program_id(0) offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) - tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum, - mask=offset_cumsum < num_experts, - other=0) + tokens_per_expert = tl.load( + num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0, + ) tokens_per_expert = round_up_128(tokens_per_expert) cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert - tl.store(expert_start_loc + offset_cumsum, - cumsum, - mask=offset_cumsum < num_experts) + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) cur_expert_start = tl.load(expert_start_loc + cur_expert) cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) @@ -136,34 +142,31 @@ def _fwd_kernel_ep_scatter_2( mask_s = offset_in_s < SCALE_HIDDEN_SIZE for token_id in range(start_token_id, total_token_num, grid_num): - to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, - mask=mask) - to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + - offset_in_s, - mask=mask_s) + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + to_copy_s = tl.load( + recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s + ) for topk_index in tl.range(0, topk_num, 1, num_stages=4): - expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + - topk_index) + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) if HAS_EXPERT_MAP: expert_id = apply_expert_map(expert_id, expert_map) if expert_id >= 0: - dest_token_index = tl.atomic_add(expert_start_loc + expert_id, - 1) + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) tl.store( - output_index + token_id * output_index_stride0 + - topk_index, dest_token_index) - output_tensor_ptr = (output_tensor + - dest_token_index * output_tensor_stride0) + output_index + token_id * output_index_stride0 + topk_index, + dest_token_index, + ) + output_tensor_ptr = ( + output_tensor + dest_token_index * output_tensor_stride0 + ) output_tensor_scale_ptr = ( - output_tensor_scale + - dest_token_index * output_tensor_scale_stride0) + output_tensor_scale + dest_token_index * output_tensor_scale_stride0 + ) tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) - tl.store(output_tensor_scale_ptr + offset_in_s, - to_copy_s, - mask=mask_s) + tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s) @torch.no_grad() @@ -189,7 +192,7 @@ def ep_scatter( assert m_indices.shape[0] % BLOCK_E == 0 - _fwd_kernel_ep_scatter_1[(grid, )]( + _fwd_kernel_ep_scatter_1[(grid,)]( num_recv_tokens_per_expert, expert_start_loc, m_indices, @@ -201,7 +204,7 @@ def ep_scatter( grid = min(recv_topk.shape[0], 1024 * 8) - _fwd_kernel_ep_scatter_2[(grid, )]( + _fwd_kernel_ep_scatter_2[(grid,)]( recv_topk.shape[0], expert_start_loc, recv_x, @@ -265,27 +268,33 @@ def _fwd_kernel_ep_gather( off_d = tl.arange(0, BLOCK_D) accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) for topk_index in range(0, topk_num): - expert_id = tl.load(recv_topk_ids + - cur_token * recv_topk_ids_stride0 + topk_index) + expert_id = tl.load( + recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index + ) if HAS_EXPERT_MAP: expert_id = apply_expert_map(expert_id, expert_map) if expert_id >= 0: - source_token_index = tl.load(input_index + - cur_token * input_index_stride0 + - topk_index) - acc_weight = tl.load(recv_topk_weight + - cur_token * recv_topk_weight_stride0 + - topk_index) - 
tmp = tl.load(input_tensor + - source_token_index * input_tensor_stride0 + - cur_block * BLOCK_D + off_d) + source_token_index = tl.load( + input_index + cur_token * input_index_stride0 + topk_index + ) + acc_weight = tl.load( + recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index + ) + tmp = tl.load( + input_tensor + + source_token_index * input_tensor_stride0 + + cur_block * BLOCK_D + + off_d + ) accumulator += tmp.to(tl.float32) * acc_weight tl.store( - output_tensor + cur_token * output_tensor_stride0 + - cur_block * BLOCK_D + off_d, + output_tensor + + cur_token * output_tensor_stride0 + + cur_block * BLOCK_D + + off_d, accumulator.to(output_tensor.dtype.element_ty), ) @@ -332,44 +341,46 @@ def ep_gather( return -def deepgemm_moe_permute(aq: torch.Tensor, - aq_scale: torch.Tensor, - topk_ids: torch.Tensor, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - aq_out: Optional[torch.Tensor] = None): - +def deepgemm_moe_permute( + aq: torch.Tensor, + aq_scale: torch.Tensor, + topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + aq_out: Optional[torch.Tensor] = None, +): assert aq.ndim == 2 - assert topk_ids.dtype.is_signed, ( - "The kernel uses -1 to represent invalid topk_ids") + assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" H = aq.size(1) device = aq.device block_m = deep_gemm_block_shape()[0] block_k = deep_gemm_block_shape()[1] - M_sum = compute_aligned_M(M=topk_ids.size(0), - num_topk=topk_ids.size(1), - local_num_experts=local_num_experts, - alignment=block_m, - expert_tokens_meta=expert_tokens_meta) + M_sum = compute_aligned_M( + M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=block_m, + expert_tokens_meta=expert_tokens_meta, + ) - expert_start_loc = torch.empty((local_num_experts), - device=device, - dtype=torch.int32) + expert_start_loc = torch.empty( + (local_num_experts), device=device, dtype=torch.int32 + ) assert aq_out is None or aq_out.shape == (M_sum, H) if aq_out is None: aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype) - aq_scale_out = torch.empty((M_sum, H // block_k), - device=device, - dtype=torch.float32) + aq_scale_out = torch.empty( + (M_sum, H // block_k), device=device, dtype=torch.float32 + ) - maybe_has_empty_blocks = ((expert_tokens_meta is None) - or (expert_tokens_meta.expert_num_tokens_cpu - is None)) + maybe_has_empty_blocks = (expert_tokens_meta is None) or ( + expert_tokens_meta.expert_num_tokens_cpu is None + ) expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32) @@ -379,35 +390,39 @@ def deepgemm_moe_permute(aq: torch.Tensor, if expert_tokens_meta is not None: expert_num_tokens = expert_tokens_meta.expert_num_tokens else: - expert_num_tokens = count_expert_num_tokens(topk_ids, - local_num_experts, - expert_map) + expert_num_tokens = count_expert_num_tokens( + topk_ids, local_num_experts, expert_map + ) - ep_scatter(recv_x=aq, - recv_x_scale=aq_scale, - recv_topk=topk_ids, - num_recv_tokens_per_expert=expert_num_tokens, - expert_start_loc=expert_start_loc, - expert_map=expert_map, - output_tensor=aq_out, - output_tensor_scale=aq_scale_out, - m_indices=expert_ids, - output_index=inv_perm) + ep_scatter( + recv_x=aq, + recv_x_scale=aq_scale, + recv_topk=topk_ids, + 
num_recv_tokens_per_expert=expert_num_tokens, + expert_start_loc=expert_start_loc, + expert_map=expert_map, + output_tensor=aq_out, + output_tensor_scale=aq_scale_out, + m_indices=expert_ids, + output_index=inv_perm, + ) return aq_out, aq_scale_out, expert_ids, inv_perm def deepgemm_unpermute_and_reduce( - a: torch.Tensor, # Grouped gemm output - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: torch.Tensor, - expert_map: Optional[torch.Tensor], - output: torch.Tensor): - - return ep_gather(input_tensor=a, - recv_topk_ids=topk_ids, - recv_topk_weight=topk_weights, - input_index=inv_perm, - expert_map=expert_map, - output_tensor=output) + a: torch.Tensor, # Grouped gemm output + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: torch.Tensor, + expert_map: Optional[torch.Tensor], + output: torch.Tensor, +): + return ep_gather( + input_tensor=a, + recv_topk_ids=topk_ids, + recv_topk_weight=topk_weights, + input_index=inv_perm, + expert_map=expert_map, + output_tensor=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 437e569d3130d..85c4fd90dc6c1 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Callable, Optional, Union import deep_ep import torch @@ -8,9 +8,20 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils import round_up +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_switch_to_comm, + dbo_switch_to_compute, + dbo_switch_to_compute_sync, + dbo_yield_and_switch_from_comm_to_compute, + dbo_yield_and_switch_from_compute_to_comm, +) class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -18,17 +29,40 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): Prepare/Finalize using DeepEP High-Throughput kernels. """ - def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, - dp_size: int, rank_expert_offset: int): + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int, dtype: torch.dtype) -> int: + # Round up hidden size so it is compatible with DeepEP High Throughput + # kernels. + # DeepEP intranode kernels make copies in units of, + # 32(warp-size) int4 elements. Round up hidden size to respect this. + # For example, an input hidden size of 2880 with dtype torch.bfloat16 + # will be rounded up to 3072. 
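The hidden-size round-up above is plain byte alignment to the 512 B transfer unit (32 warp lanes x 16 B int4); a self-contained sketch reproducing the 2880 -> 3072 bfloat16 example from the comment (round_up here is a local stand-in for vllm.utils.round_up):

import torch


def round_up(x: int, m: int) -> int:
    # mirrors vllm.utils.round_up: pad x to the next multiple of m
    return (x + m - 1) // m * m


def roundup_hidden_size(hidden_size: int, dtype: torch.dtype) -> int:
    xfer_atom_size = 512  # 32 (warp size) * 16 bytes (int4)
    hidden_size_bytes = hidden_size * dtype.itemsize
    if hidden_size_bytes % xfer_atom_size == 0:
        return hidden_size
    return round_up(hidden_size_bytes, xfer_atom_size) // dtype.itemsize


assert roundup_hidden_size(2880, torch.bfloat16) == 3072  # 5760 B -> 6144 B
assert roundup_hidden_size(4096, torch.bfloat16) == 4096  # already 512 B aligned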
+ hidden_size_bytes = hidden_size * dtype.itemsize + xfer_atom_size = 512 # 32 * 16 (size(int4)) + if hidden_size_bytes % xfer_atom_size == 0: + return hidden_size + + hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) + return hidden_size_bytes // dtype.itemsize + + def __init__( + self, + buffer: deep_ep.Buffer, + num_dispatchers: int, + dp_size: int, + rank_expert_offset: int, + ): super().__init__() self.buffer = buffer self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset + self.async_prepare = True + # The dispatch function returns a handle that the combine function - # requires. We store the handle here so it is available to the - # combine function. - self.handle = None + # requires. Under DBO microbatching we must track one handle per + # micro-batch to avoid races between threads. + self.handles = [None, None] # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] @@ -36,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -47,38 +84,57 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return torch.int64 def _get_dispatch_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_dispatch_config(self.dp_size) + return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_) def _get_combine_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_combine_config(self.dp_size) - - def _do_dispatch(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, num_experts: int): + return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: has_scales = token_scales is not None - (num_tokens_per_rank, num_tokens_per_rdma_rank, - dispatch_expert_num_tokens, is_token_in_rank, - event) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, - num_experts=num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False) + # We yield before launching the dispatch kernel since the dispatch + # kernel will block the CPU so we want to queue up all the compute + # for the other ubatch before the dispatch kernel starts. 
+ dbo_yield_and_switch_from_compute_to_comm() + + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + dispatch_expert_num_tokens, + is_token_in_rank, + event, + ) = self.buffer.get_dispatch_layout( + topk_idx=rank_topk_ids, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ) token_data = tokens if has_scales: token_data = (tokens, token_scales) ( - token_data, expert_topk_ids, expert_topk_weights, - expert_num_tokens_per_expert_list, self.handle, event + token_data, + expert_topk_ids, + expert_topk_weights, + expert_num_tokens_per_expert_list, + handle, + event, ) = self.buffer.dispatch( x=token_data, handle=None, @@ -93,8 +149,42 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=False, - allocate_on_comm_stream=False) + async_finish=self.async_prepare and not dbo_enabled(), + allocate_on_comm_stream=False, + ) + + # record the handle for this ubatch + a2a_idx = dbo_current_ubatch_id() + self.handles[a2a_idx] = handle + + dbo_switch_to_compute_sync() + + return lambda: self._receiver( + event, + has_scales, + token_data, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if event.event is not None: + event.current_stream_wait() if has_scales: expert_x, expert_x_scale = token_data @@ -112,72 +202,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP's topk_ids output refers to the local experts directly. Offset # the topk_ids to move it back to the global experts space so it aligns # with existing vLLM interfaces. + assert expert_topk_ids is not None expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, - expert_topk_ids + self.rank_expert_offset) + expert_topk_ids + self.rank_expert_offset, + ) # Makes a GPU-CPU copy. # TODO (varun): Maybe it is better to re-compute the expert_num_tokens # on GPU. 
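# ---------------------------------------------------------------------------
# Editor's note: a tiny illustration of the local -> global expert-id
# remapping a few lines above, with hypothetical values (8 experts total,
# 4 per rank, rank_expert_offset = 4, i.e. rank 1). Slots that DeepEP marks
# with -1 are pointed at an expert owned by another rank, so the local
# expert computation ignores them.
import torch

num_experts = 8
rank_expert_offset = 4
local_topk_ids = torch.tensor([[0, 2], [-1, 3]])  # -1: token not routed here
global_topk_ids = torch.where(
    local_topk_ids == -1,
    num_experts - 1 if rank_expert_offset == 0 else 0,
    local_topk_ids + rank_expert_offset,
)
assert global_topk_ids.tolist() == [[4, 6], [0, 7]]
# ---------------------------------------------------------------------------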
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( - expert_num_tokens_per_expert_list, device=expert_x.device) + expert_num_tokens_per_expert_list, device=expert_x.device + ) - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) - - def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * topk_weights.to(a1.dtype) - - if quant_config.is_block_quantized: - # Quant and Dispatch - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - if a1q_scale is not None and a1q_scale.numel() == 1: - a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) - else: - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 - (expert_x, _, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1, - token_scales=None, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: # Quantize after dispatch. 
expert_x_scale = None if expert_x.numel() != 0: @@ -186,12 +229,87 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1_scale, quant_dtype=quant_config.quant_dtype, per_act_token_quant=False, - block_shape=quant_config.block_shape) + block_shape=quant_config.block_shape, + ) - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) + return ( + expert_x, + expert_x_scale, + expert_tokens_meta, + expert_topk_ids, + expert_topk_weights, + ) - def finalize( + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.ReceiverType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1 = a1 * topk_weights.to(a1.dtype) + + if quant_config.is_block_quantized: + # Quant and Dispatch + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + a1_post_scale = None + else: + a1q = a1 + a1q_scale = None + a1_post_scale = quant_config.a1_scale + + return self._do_dispatch( + tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config, + ) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + return receiver() + + def _finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -199,9 +317,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - - assert self.handle is not None + do_async: bool, + ) -> Optional[Callable]: + a2a_idx = dbo_current_ubatch_id() + handle = self.handles[a2a_idx] + assert handle is not None # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. @@ -215,14 +335,76 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, ) - + dbo_yield_and_switch_from_compute_to_comm() combined_x, _, event = self.buffer.combine( x=fused_expert_output, - handle=self.handle, + handle=handle, topk_weights=None, config=self._get_combine_config(), previous_event=None, - async_finish=False, - allocate_on_comm_stream=False) - # Respect inplace outputs. 
- output.copy_(combined_x, non_blocking=True) + async_finish=do_async and not dbo_enabled(), + allocate_on_comm_stream=False, + ) + + dbo_switch_to_compute() + + if do_async: + + def _receiver(): + if event.event is not None: + event.current_stream_wait() + dbo_switch_to_comm() + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) + + # TODO(lucas): refactor the modular kernel so this will be + # handled there + dbo_yield_and_switch_from_comm_to_compute() + + return _receiver + else: + # TODO(lucas): support this case with the refactored modular kernel + assert not dbo_enabled() + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) + return None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + receiver = self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + True, + ) + assert receiver is not None + return receiver + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + False, + ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 93ac11fb4bfbf..117bfe6e6b4d7 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import deep_ep import torch @@ -8,17 +8,26 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input, normalize_batched_scales_shape) + moe_kernel_quantize_input, + normalize_batched_scales_shape, +) +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_maybe_run_recv_hook, +) # DeepEP kernels quantize dispatch inputs in 128 element chunks. 
DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] -def dequant_fp8(expert_x_fp8: torch.Tensor, - expert_x_scales: torch.Tensor) -> torch.Tensor: +def dequant_fp8( + expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor +) -> torch.Tensor: """ Return dequantized tensor in fp32 """ @@ -28,7 +37,8 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, num_experts = expert_x_fp8.size(0) expert_x_fp32 = expert_x_fp8.to(torch.float32).view( - num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) + num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE + ) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) @@ -42,11 +52,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # specific hidden sizes. SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168] - def __init__(self, - buffer: deep_ep.Buffer, - max_tokens_per_rank: int, - num_dispatchers: int, - use_fp8_dispatch: bool = False): + def __init__( + self, + buffer: deep_ep.Buffer, + max_tokens_per_rank: int, + num_dispatchers: int, + use_fp8_dispatch: bool = False, + ): super().__init__() self.buffer = buffer @@ -55,12 +67,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. - self.handle = None + self.handles: list[Optional[tuple]] = [None, None] self.num_dispatchers_ = num_dispatchers def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts @@ -74,16 +89,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def _do_quant( self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], - per_act_token_quant: bool, - block_shape: Optional[list[int]], + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - - block_k = block_shape[1] if block_shape is not None else None if self.use_fp8_dispatch: + block_k = ( + quant_config.block_shape[1] + if quant_config.block_shape is not None + else None + ) if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. 
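# ---------------------------------------------------------------------------
# Editor's note: a small shape check for the dequant_fp8 helper defined just
# above. The tensor is viewed as 128-element blocks and each block is scaled
# by one fp32 scale; a float32 tensor stands in for the fp8 payload here so
# the snippet runs anywhere, only the broadcasting pattern matters.
import torch

E, T, H = 2, 4, 256                               # experts, tokens, hidden
x_fp8 = torch.randn(E, T, H)                      # stand-in for fp8 data
x_scales = torch.rand(E, T, H // 128)             # one scale per 128-wide block
x_fp32 = x_fp8.to(torch.float32).view(E, -1, 128)
out = (x_fp32 * x_scales.view(E, -1, 1)).view(x_fp8.size())
assert out.shape == (E, T, H) and out.dtype == torch.float32
# ---------------------------------------------------------------------------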
x, x_scales = x @@ -99,72 +113,186 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, - per_act_token_quant, - block_shape) + x, x_scales = moe_kernel_quantize_input( + x, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) x = x.view((num_experts, -1, hidden_dim)) - if quant_dtype is not None: + if quant_config.quant_dtype is not None: assert x_scales is not None x_scales = normalize_batched_scales_shape(x_scales, num_experts) return x, x_scales - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: - + ) -> tuple[Callable, mk.ReceiverType]: hidden_size = a1.size(1) - assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ - (f"Hidden Size {hidden_size} not in supported list of hidden sizes" - f"{self.SUPPORTED_HIDDEN_SIZES}") + assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( + f"Hidden Size {hidden_size} not in supported list of hidden sizes" + f"{self.SUPPORTED_HIDDEN_SIZES}" + ) + + a2a_idx = dbo_current_ubatch_id() if self.use_fp8_dispatch: - assert hidden_size % 128 == 0, \ - "DeepEP kernels quantize the inputs in blocks of shape 128" + assert hidden_size % 128 == 0, ( + "DeepEP kernels quantize the inputs in blocks of shape 128" + ) - has_per_token_scales = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + has_per_token_scales = ( + quant_config.a1_scale.numel() != 1 + if quant_config.a1_scale is not None + else ( + quant_config.a2_scale.numel() != 1 + if quant_config.a2_scale is not None + else False + ) + ) assert not has_per_token_scales, ( - "low_latency kernels doesn't support dispatching per-token scales") + "low_latency kernels doesn't support dispatching per-token scales" + ) if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") + "apply_router_weight_on_input is only implemented for topk=1" + ) a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, self.handle, event, hook = \ - self.buffer.low_latency_dispatch(a1, - topk_ids, - self.max_tokens_per_rank, - num_experts, - use_fp8=self.use_fp8_dispatch, - async_finish=False, - return_recv_hook=False) + expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( + a1, + topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + async_finish=False, + return_recv_hook=True, + ) + self.handles[a2a_idx] = handle - expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + return ( + hook, + lambda: self._receiver( + expert_x, + expert_num_tokens, + quant_config.a1_scale, + a1.dtype, + quant_config, + ), + ) + + def _receiver( + self, + expert_x: 
Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_num_tokens: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config) expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None + ) - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook, receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + hook() + return receiver() + + def _finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + do_async: bool, + ) -> tuple[Callable, Callable]: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( + "Weight application and reduction happens in the combine kernel." + ) + + a2a_idx = dbo_current_ubatch_id() + do_recv_hook = dbo_enabled() or do_async + handle = self.handles[a2a_idx] + assert handle is not None + + combine_topk_weights = topk_weights + if apply_router_weight_on_input: + # weights have already been applied. + combine_topk_weights = torch.ones_like(topk_weights) + + # TODO (varun) : Enable zero copy mode + dbo_maybe_run_recv_hook() + _, _, recv_hook = self.buffer.low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + handle, + async_finish=False, + zero_copy=False, + return_recv_hook=do_recv_hook, + out=output, + ) + + return recv_hook, lambda: None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> tuple[Callable, Callable]: + return self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=True, + ) def finalize( self, @@ -175,23 +303,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - assert isinstance( - weight_and_reduce_impl, TopKWeightAndReduceDelegate - ), ("Weight application and reduction happens in the combine kernel.") - assert self.handle is not None - - combine_topk_weights = topk_weights - if apply_router_weight_on_input: - # weights have already been applied. 
- combine_topk_weights = torch.ones_like(topk_weights) - - # TODO (varun) : Enable zero copy mode - _, event, hook = self.buffer.low_latency_combine( + self._finalize( + output, fused_expert_output, + topk_weights, topk_ids, - combine_topk_weights, - self.handle, - async_finish=False, - zero_copy=False, - return_recv_hook=False, - out=output) + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=False, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index feab3f74cac53..1b33c7075fb36 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional import torch @@ -8,77 +8,75 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) -from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, - has_flashinfer_cutlass_fused_moe) + TopKWeightAndReduceNoOP, +) +from vllm.utils.flashinfer import ( + flashinfer_cutlass_fused_moe, + has_flashinfer_cutlass_fused_moe, +) logger = init_logger(__name__) -def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def is_valid_flashinfer_cutlass_fused_moe( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: """ Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ if not has_flashinfer_cutlass_fused_moe(): - logger.debug_once("FlashInferExperts disabled: " - "flashinfer_cutlass_fused_moe not available.") + logger.debug_once( + "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." + ) return False # Data type checks - if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 - or hidden_states.dtype - not in [torch.float32, torch.float16, torch.bfloat16]): + if ( + w1.dtype != torch.uint8 + or w2.dtype != torch.uint8 + or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] + ): logger.debug_once( "FlashInferExperts disabled: w1/w2 must be torch.uint8 " f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype}).") + f"float32, float16, or bfloat16 (got {hidden_states.dtype})." 
+ ) return False return True class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, out_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], + quant_config: FusedMoEQuantConfig, ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=False, - block_shape=None, - )) - assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( - "Only nvfp4,fp8 quantization are currently supported.") + super().__init__(quant_config) + assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( + "Only nvfp4, fp8, bfloat16 and" + " float16 quantization are currently supported." + ) self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale self.out_dtype = out_dtype @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_expert_map(self) -> bool: return False @@ -92,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -101,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. """ @@ -120,15 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - aq_m, aq_n = aq.shape - workspace2 = () - output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ - torch.float8_e4m3fn else (aq_m, aq_n) - workspace_dtype = a.dtype - workspace1 = output_shape + workspace1 = (M, K) + workspace2 = (0,) + output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # The workspace is determined by `aq`, since it comes after any # potential communication op and is involved in the expert computation. 
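# ---------------------------------------------------------------------------
# Editor's note: worked example of the shapes workspace_shapes() returns
# below, with hypothetical M=16 tokens and K=128. K is the (possibly packed)
# hidden size of the quantized activations; nvfp4 packs two values per byte,
# so the unquantized output is 2*K wide, while the fp8/unquantized paths
# keep width K.
def _expected_output_shape(M: int, K: int, quant_dtype) -> tuple[int, int]:
    return (M, K * 2 if quant_dtype == "nvfp4" else K)

assert _expected_output_shape(16, 128, "nvfp4") == (16, 256)
assert _expected_output_shape(16, 128, None) == (16, 128)
# ---------------------------------------------------------------------------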
- return (workspace1, workspace2, output_shape, workspace_dtype) + return (workspace1, workspace2, output_shape) def apply( self, @@ -141,43 +134,51 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], # Not used + a2_scale: Optional[torch.Tensor], workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], ): + assert activation == "silu", ( + "Only activation silu is supported in FlashInferExperts" + ) + if self.quant_dtype == torch.float8_e4m3fn: quant_scales = [ - self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + self.g1_alphas, + self.a2_gscale, + self.g2_alphas, + self.a1_gscale, ] a1q_scale = None # not passing input_sf in fp8 fc1_expert_weights = w1 fc2_expert_weights = w2 - else: + elif self.quant_dtype == "nvfp4": # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. quant_scales = [ self.a1_gscale, - w1_scale.view(torch.int32), + self.w1_scale.view(torch.int32), self.g1_alphas, self.a2_gscale, - w2_scale.view(torch.int32), + self.w2_scale.view(torch.int32), self.g2_alphas, ] # FlashInfer API requires weight to be long for nvfp4 fc1_expert_weights = w1.view(torch.long) fc2_expert_weights = w2.view(torch.long) + else: + quant_scales = None + a1q_scale = None + fc1_expert_weights = w1 + fc2_expert_weights = w2 _ = flashinfer_cutlass_fused_moe( input=hidden_states, @@ -202,30 +203,64 @@ def flashinfer_cutlass_moe_fp4( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + quant_config: FusedMoEQuantConfig, inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, - a1_gscale=a1_gscale), + create_flashinfer_prepare_finalize(use_dp=False), FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=hidden_states.dtype, - quant_dtype="nvfp4", - )) + quant_config=quant_config, + ), + ) + + return fused_experts( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def flashinfer_cutlass_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_config: FusedMoEQuantConfig, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + 
expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + tp_rank: int = 0, + tp_size: int = 1, + ep_rank: int = 0, + ep_size: int = 1, + use_dp: bool = False, +) -> torch.Tensor: + fused_experts = mk.FusedMoEModularKernel( + create_flashinfer_prepare_finalize(use_dp=use_dp), + FlashInferExperts( + out_dtype=hidden_states.dtype, + quant_config=quant_config, + tp_rank=tp_rank, + tp_size=tp_size, + ep_rank=ep_rank, + ep_size=ep_size, + ), + ) return fused_experts( hidden_states=hidden_states, @@ -237,7 +272,5 @@ def flashinfer_cutlass_moe_fp4( activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 061b02172c446..4907b9ff5730b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -5,11 +5,16 @@ from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_dp_group +from vllm.distributed import get_dp_group, get_ep_group +from vllm.distributed.device_communicators.base_device_communicator import ( + All2AllManagerBase, +) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -18,17 +23,16 @@ def get_local_sizes(): class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( self, use_dp: bool, - a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() self.num_dispatchers_ = num_dispatchers self.use_dp = use_dp - self.a1_gscale = a1_gscale self.local_tokens = None @property @@ -44,56 +48,264 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + + def _apply_router_weight_on_input( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """Apply router weight on input if needed.""" + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) + + +class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): + """FlashInfer implementation using AllToAll communication.""" + + def __init__( + self, + use_dp: bool, + num_dispatchers: int = 1, + ): + super().__init__(use_dp, num_dispatchers) + self.alltoall_info = None + + # Initialize all2all_manager only for DP case + self.all2all_manager = None + if self.use_dp: + self.all2all_manager = get_ep_group().device_communicator.all2all_manager + def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], # Not 
used - a2_scale: Optional[torch.Tensor], # Not used topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - a1.mul_(topk_weights.to(a1.dtype)) - - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - self.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - # Swizzling after communication - is_fp4_scale_swizzled=not self.use_dp, + ) -> mk.PrepareResultType: + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input ) - if self.use_dp: - topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv( - [topk_weights, topk_ids, a1q, a1q_scale], - dim=0, - sizes=get_local_sizes(), + + if not self.use_dp: + # Non-DP case: standard quantization + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=not self.use_dp, + ) + else: + # DP case: use FlashInfer AllToAll + global_num_tokens_cpu = get_local_sizes() + top_k = topk_ids.size(1) + + (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = ( + flashinfer_alltoall_dispatch( + self.all2all_manager, + global_num_tokens_cpu, + a1, + quant_config.a1_gscale, + topk_ids, + topk_weights, + top_k, + num_experts, + quant_config, ) - a1_m, a1_n = a1q.shape - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + ) return a1q, a1q_scale, None, topk_ids, topk_weights - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if self.use_dp: + top_k = topk_ids.size(1) + token_count = output.shape[0] + fused_expert_output = flashinfer_alltoall_combine( + self.all2all_manager, + fused_expert_output, + top_k=top_k, + token_count=token_count, + alltoall_info=self.alltoall_info, + ) + output.copy_(fused_expert_output) + + +class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): + def __init__( + self, + use_dp: bool, + num_dispatchers: int = 1, + ): + super().__init__(use_dp, num_dispatchers) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input + ) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=not 
self.use_dp, + ) + if self.use_dp: + topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv( + [topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes(), + ) + if quant_config.quant_dtype == "nvfp4": + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, dim=0, sizes=get_local_sizes()) + fused_expert_output, dim=0, sizes=get_local_sizes() + ) output.copy_(fused_expert_output) + + +def flashinfer_alltoall_dispatch( + all2all_manager: All2AllManagerBase, + global_num_tokens_cpu: list[int], + x: torch.Tensor, + gs: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + top_k: int, + num_experts: int, + quant_config: FusedMoEQuantConfig, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) + + ep_rank = all2all_manager.rank + ep_size = all2all_manager.world_size + max_num_token = ( + max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0] + ) + alltoall_info, topk_ids, topk_weights, _ = ( + MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + topk_ids, + topk_weights, + None, + all2all_manager.prepare_workspace, + max_num_token, + ep_rank, + ep_size, + num_experts, + num_experts, + top_k, + ) + ) + + x, x_sf = moe_kernel_quantize_input( + x, + gs, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + + x_sf = MnnvlMoe.mnnvl_moe_alltoallv( + x_sf, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + x_sf = nvfp4_block_scale_interleave(x_sf) + return alltoall_info, topk_ids, topk_weights, x, x_sf + + +def flashinfer_alltoall_combine( + all2all_manager: All2AllManagerBase, + output: torch.Tensor, + top_k: int, + token_count: int, + alltoall_info, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) + return MnnvlMoe.mnnvl_moe_alltoallv_combine( + output, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank=all2all_manager.rank, + ep_size=all2all_manager.world_size, + top_k=top_k, + token_count=token_count, + ) + + +def create_flashinfer_prepare_finalize( + use_dp: bool, + use_nvfp4: bool = False, + enable_alltoallv: bool = False, +) -> FlashInferCutlassMoEPrepareAndFinalize: + """Factory function to create the appropriate FlashInfer implementation.""" + if use_nvfp4: + if enable_alltoallv: + return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) + else: + return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) + # Fp8 only supports AllGather + return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py new file mode 100644 
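# ---------------------------------------------------------------------------
# Editor's note (refers to create_flashinfer_prepare_finalize defined just
# above, before the next file's diff begins): a quick truth table of which
# prepare/finalize class the factory selects.
#
#   use_nvfp4=True,  enable_alltoallv=True  -> FlashInferAllToAllMoEPrepareAndFinalize
#   use_nvfp4=True,  enable_alltoallv=False -> FlashInferAllGatherMoEPrepareAndFinalize
#   use_nvfp4=False  (fp8 path)             -> FlashInferAllGatherMoEPrepareAndFinalize
# ---------------------------------------------------------------------------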
index 0000000000000..d12d05915566d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + calculate_tile_tokens_dim, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.utils import direct_register_custom_op + + +def flashinfer_fused_moe_blockscale_fp8( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0, +) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + + assert top_k <= global_num_experts + assert top_k <= 8 + assert topk_group <= 4 + assert global_num_experts > num_expert_group + assert global_num_experts % num_expert_group == 0 + assert global_num_experts % 4 == 0 + assert top_k < (topk_group * global_num_experts / num_expert_group) + assert block_shape == [128, 128] + # Routing kernel expects #experts <= #threads 256 + assert global_num_experts <= 256 + + a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() + return flashinfer_trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale_inv, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale_inv, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling, + tile_tokens_dim=calculate_tile_tokens_dim( + x.shape[0], top_k, global_num_experts + ), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) + + +def flashinfer_fused_moe_blockscale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0, +) -> torch.Tensor: + return torch.empty_like(x) + + +# TODO(bnell): Does this really need to be a torch.op? 
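# ---------------------------------------------------------------------------
# Editor's note: small shape illustration of the block-scale quantization
# step in flashinfer_fused_moe_blockscale_fp8 above (hypothetical sizes).
# per_token_group_quant_fp8 produces one scale per 128-element group per
# token, and the NOTE in that function is why the scale matrix is transposed
# to group-major layout before the kernel call.
import torch

T, H = 4, 512                         # tokens, hidden size
G = H // 128                          # groups per token
# a_q:  (T, H)  fp8 activations       (from per_token_group_quant_fp8)
# a_sf: (T, G)  per-group scales
a_sf = torch.rand(T, G)
a_sf_t = a_sf.t().contiguous()        # (G, T), the layout the kernel expects
assert a_sf_t.shape == (G, T)
# ---------------------------------------------------------------------------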
+direct_register_custom_op( + op_name="flashinfer_fused_moe_blockscale_fp8", + op_func=flashinfer_fused_moe_blockscale_fp8, + fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order,), +) + + +def flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0, +) -> torch.Tensor: + num_expert_group = num_expert_group if num_expert_group is not None else 0 + topk_group = topk_group if topk_group is not None else 0 + + quant_hidden_states, _ = moe_kernel_quantize_input( + hidden_states, + input_scale, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + ) + + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe + + return flashinfer_trtllm_fp8_per_tensor_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=quant_hidden_states, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + tile_tokens_dim=calculate_tile_tokens_dim( + hidden_states.shape[0], top_k, num_experts + ), + routing_method_type=routing_method_type, + ) + + +def flashinfer_fused_moe_per_tensor_scale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +# TODO(bnell): Does this really need to be a torch.op? 
+direct_register_custom_op( + op_name="flashinfer_fused_moe_per_tensor_scale_fp8", + op_func=flashinfer_fused_moe_per_tensor_scale_fp8, + mutates_args=["hidden_states"], + fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order,), +) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b46f4be4b912e..0c31684d23677 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,21 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" + from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) + TopKWeightAndReduceDelegate, + TopKWeightAndReduceNaiveBatched, +) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, - normalize_scales_shape) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + _resize_cache, + moe_kernel_quantize_input, + normalize_batched_scales_shape, + normalize_scales_shape, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.triton_utils import tl, triton @@ -56,12 +60,12 @@ def moe_mmk( use_w8a16: tl.constexpr, per_act_token_quant: tl.constexpr, ): - offs_k = tl.arange(0, BLOCK_K) if use_w8a16: - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) if use_w8a8: @@ -94,9 +98,11 @@ def moe_mmk( for k in range(0, tl.cdiv(K, BLOCK_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), - other=0.0) + a = tl.load( + a_ptrs, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0, + ) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) # We accumulate along the K dimension. 
if use_w8a16: @@ -105,13 +111,12 @@ def moe_mmk( if group_k > 0 and group_n > 0: k_start = k * BLOCK_K offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, - mask=mask_m, - other=0.0) + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0 + ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) @@ -137,9 +142,9 @@ def moe_mmk( @triton.jit def expert_triton_kernel( - a_ptr, #[max_tokens, K] - b_ptr, #[K, N] - c_ptr, #[max_tokens, N] + a_ptr, # [max_tokens, K] + b_ptr, # [K, N] + c_ptr, # [max_tokens, N] expert_id, compute_type: tl.constexpr, # Dimensions @@ -177,7 +182,6 @@ def expert_triton_kernel( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): - offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) % N offs_k = tl.arange(0, BLOCK_K) @@ -221,7 +225,8 @@ def expert_triton_kernel( compute_type, use_fp8_w8a8, use_int8_w8a16, - per_act_token_quant) + per_act_token_quant, + ) # store in C offs_cn = tl.arange(0, BLOCK_N) @@ -284,7 +289,7 @@ def batched_triton_kernel( # axis 1 is M_blocks * N_blocks pid_mn = tl.program_id(axis=1) - #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + # num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) pid_m = pid_mn // num_pid_n pid_n = pid_mn % num_pid_n @@ -300,8 +305,12 @@ def batched_triton_kernel( a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn - c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + - cta_n_start * stride_cn) + c_ptr = ( + c_ptr + + expert_id * stride_ce + + cta_m_start * stride_cm + + cta_n_start * stride_cn + ) offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N @@ -350,50 +359,54 @@ def batched_triton_kernel( # Kernel config BLOCK_M, BLOCK_N, - BLOCK_K) + BLOCK_K, + ) def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - per_act_token_quant: bool, - block_shape: Optional[list[int]] = None): - + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, N, K] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) N = C.size(2) - BLOCK_M = config['BLOCK_SIZE_M'] - BLOCK_N = config['BLOCK_SIZE_N'] - BLOCK_K = config['BLOCK_SIZE_K'] + BLOCK_M = config["BLOCK_SIZE_M"] + BLOCK_N = config["BLOCK_SIZE_N"] + BLOCK_K = config["BLOCK_SIZE_K"] - grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * - triton.cdiv(B.size(1), BLOCK_N)) + grid = ( + expert_num_tokens.size(0), + 
triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N), + ) - A_scale = normalize_batched_scales_shape(A_scale, - expert_num_tokens.shape[0]) + A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0]) if B_scale is not None and B_scale.ndim == 1: assert B_scale.numel() == expert_num_tokens.shape[0] B_scale = B_scale.view(-1, 1, 1) assert A_scale is None or A_scale.ndim == 3, ( - f"{0 if A_scale is None else A_scale.shape}") + f"{0 if A_scale is None else A_scale.shape}" + ) assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, ( - f"{0 if B_scale is None else B_scale.shape}") + f"{0 if B_scale is None else B_scale.shape}" + ) if B_scale is not None: if B_scale.ndim == 1: @@ -459,7 +472,8 @@ def invoke_moe_batched_triton_kernel( # Kernel config BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K) + BLOCK_K=BLOCK_K, + ) class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -495,20 +509,19 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -516,16 +529,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ + assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a1.mul_(topk_weights.to(a1.dtype)) num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) - tokens_per_expert = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a1.device) num_local_experts = self.num_local_experts @@ -537,24 +549,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), dtype=b_type, - device=a1.device) + device=a1.device, + ) if quant_config.is_quantized: scale_shape = quant_config.batched_scale_shape( - num_local_experts, self.max_num_tokens, hidden_dim) + num_local_experts, self.max_num_tokens, hidden_dim + ) - b_a1_scale = torch.empty(scale_shape, - dtype=torch.float32, - device=a1.device) + b_a1_scale = torch.empty(scale_shape, dtype=torch.float32, device=a1.device) else: - assert a1_scale is None + assert quant_config.a1_scale is None b_a1_scale = None first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - a1_scale = normalize_scales_shape(a1_scale) - a2_scale = normalize_scales_shape(a2_scale) + a1_scale = normalize_scales_shape(quant_config.a1_scale) for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() @@ -563,11 +574,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): continue idx = expert_id - first_expert tokens_per_expert[idx] = rows - rhs = a1[:topks.numel()][topks] + rhs = a1[: topks.numel()][topks] if 
quant_config.quant_dtype is not None: if a1_scale is not None: if quant_config.is_per_act_token: - rhs_a1_scale = a1_scale[:topks.numel()][topks] + rhs_a1_scale = a1_scale[: topks.numel()][topks] else: rhs_a1_scale = a1_scale else: @@ -583,14 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if quant_config.is_per_act_token: b_a1_scale[idx, :rows] = b_s[:rows] else: - b_a1_scale[idx, :b_s.shape[0]] = b_s + b_a1_scale[idx, : b_s.shape[0]] = b_s else: b_a1[idx, :rows, :] = rhs assert b_a1_scale is None or b_a1_scale.ndim == 3 expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None) + expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None + ) return b_a1, b_a1_scale, expert_tokens_meta, None, None @@ -625,37 +637,24 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert self.quant_config.ocp_mx_scheme is None, "NYI" self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -669,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -678,20 +675,18 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) output = workspace13 - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert self.quant_config.is_quantized f32 = torch.float32 - if (self.quant_config.is_per_act_token - or self.quant_config.is_per_tensor): + if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: return t.to(f32) * scale else: return t.to(f32) * group_broadcast(scale, t.shape) @@ -707,10 +702,6 @@ class 
NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -723,15 +714,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens = expert_tokens_meta.expert_num_tokens num_local_experts = w1.size(0) - assert num_local_experts == w1.size(0), ( - f"{num_local_experts} == {w1.size(0)}") + assert num_local_experts == w1.size(0), f"{num_local_experts} == {w1.size(0)}" N = w1.size(1) // 2 for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor - if (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if ( + torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing() + ): num = hidden_states.shape[1] else: num = int(expert_num_tokens[expert].item()) @@ -742,20 +734,18 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): tmp = _resize_cache(workspace2, (num, N)) if self.quant_config.is_quantized: - assert a1q_scale is not None and w1_scale is not None - input = self.dequant(hidden_states[expert, :, :], - a1q_scale[expert]) - w1_dq = self.dequant(w1[expert], w1_scale[expert]) + assert a1q_scale is not None and self.w1_scale is not None + input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) + w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) input = input[:num] @ w1_dq.transpose(0, 1) else: - input = hidden_states[expert, :num, :] @ w1[expert].transpose( - 0, 1) + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) self.activation(activation, tmp, input.to(tmp.dtype)) if self.quant_config.is_quantized: - assert w2_scale is not None - w2_dq = self.dequant(w2[expert], w2_scale[expert]) + assert self.w2_scale is not None + w2_dq = self.dequant(w2[expert], self.w2_scale[expert]) else: w2_dq = w2[expert] @@ -773,17 +763,16 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing(): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. 
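# A minimal sketch of what this compile/cudagraph-capture branch boils down to,
# assuming a PyTorch build with float8 dtypes: because expert_num_tokens cannot
# be read on the host here, the whole (E, max_tokens, hidden) buffer is
# flattened, quantized in one shot, and reshaped back.  `simple_quant_fp8` is a
# hypothetical stand-in for moe_kernel_quantize_input, not a vLLM API.
import torch

def simple_quant_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor symmetric quantization to float8_e4m3fn (illustration only).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale.reshape(1)

E, max_tokens, hidden = 4, 8, 16
a = torch.randn(E, max_tokens, hidden)
a_q, a_scale = simple_quant_fp8(a.view(-1, hidden))  # quantize everything ...
a_q = a_q.view(E, -1, hidden)                        # ... then restore the batched layout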
hidden_dim = A.size(-1) assert A_scale is None or A_scale.ndim <= 2, ( - f"{A_scale.shape if A_scale is not None else None}") - A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, - hidden_dim), A_scale, - qtype, per_act_token_quant, - block_shape) + f"{A_scale.shape if A_scale is not None else None}" + ) + A_q, A_q_scale = moe_kernel_quantize_input( + A.view(-1, hidden_dim), A_scale, qtype, per_act_token_quant, block_shape + ) A_q = A_q.view(E, -1, hidden_dim) A_q_scale = normalize_batched_scales_shape(A_q_scale, E) @@ -803,9 +792,7 @@ def batched_moe_kernel_quantize_input( else: scale_shape = (E, 1, 1) - A_q_scale = torch.zeros(scale_shape, - dtype=torch.float32, - device=A.device) + A_q_scale = torch.zeros(scale_shape, dtype=torch.float32, device=A.device) num_experts = expert_num_tokens.numel() @@ -815,7 +802,7 @@ def batched_moe_kernel_quantize_input( num_tokens = int(expert_num_tokens[e].item()) if num_tokens > 0: if A_scale is not None: - scales = A_scale[e, :min(num_tokens, A_scale.shape[1])] + scales = A_scale[e, : min(num_tokens, A_scale.shape[1])] else: scales = None A_q[e, :num_tokens], tmp_scale = moe_kernel_quantize_input( @@ -826,7 +813,7 @@ def batched_moe_kernel_quantize_input( block_shape, ) assert tmp_scale is not None - A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + A_q_scale[e, : tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale @@ -842,44 +829,26 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert self.quant_config.ocp_mx_scheme is None, "NYI" assert max_num_tokens > 0 assert num_dispatchers > 0 - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -893,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -902,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], 
tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, @@ -923,10 +889,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -935,36 +897,34 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): apply_router_weight_on_input: bool, ): # Check constraints. - if self.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), ( - "Hidden size mismatch") + if self.quant_config.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: assert hidden_states.size(-1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(-1)} " - f"!= {w1.size(2)}") + f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" + ) - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + torch.float32, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, ] assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens - E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( - hidden_states, w1, w2, topk_ids) + E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( + hidden_states, w1, w2, topk_ids + ) assert w1.size(0) == E assert w2.size(0) == E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) + config_dtype = self.quant_config.config_name(hidden_states.dtype) config = try_get_optimal_moe_config( w1.size(), @@ -984,17 +944,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, - (E, max_num_tokens, N)) - intermediate_cache2 = _resize_cache(workspace2, - (E, max_num_tokens, N // 2)) + intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - if self.use_fp8_w8a8: + # TODO(bnell): should this be done for any quantized type? 
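# A minimal sketch of the workspace reuse happening here, assuming a helper
# with the same intent as _resize_cache: view a leading slice of one flat
# scratch buffer with whatever shape the current stage needs.  The two caches
# may alias the same storage because the first one is no longer needed by the
# time the second is written.
import math
import torch

def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    numel = math.prod(shape)
    assert buf.numel() >= numel, "workspace too small"
    return buf.flatten()[:numel].view(shape)

E, max_tokens, N, K = 2, 4, 6, 8
workspace13 = torch.empty(E * max_tokens * max(K, N))
cache1 = resize_cache(workspace13, (E, max_tokens, N))  # first GEMM output
cache3 = resize_cache(workspace13, (E, max_tokens, K))  # reuses the same storage later
assert cache1.data_ptr() == cache3.data_ptr()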
+ if self.quant_config.use_fp8_w8a8: intermediate_cache1.fill_(0) a1q_scale = normalize_batched_scales_shape(a1q_scale, E) @@ -1007,25 +965,36 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a1q_scale, - B_scale=w1_scale, - B_zp=w1_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w1_scale, + B_zp=self.w1_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + ) intermediate_cache2.fill_(0) # TODO (bnell): use triton utility from batched deep gemm. - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.activation( + activation, + intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N), + ) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, - expert_num_tokens, self.quant_dtype, self.per_act_token_quant, - self.block_shape) + intermediate_cache2, + a2_scale, + max_num_tokens, + E, + N, + expert_num_tokens, + self.quant_dtype, + self.per_act_token_quant, + self.block_shape, + ) invoke_moe_batched_triton_kernel( A=qintermediate_cache2, @@ -1034,11 +1003,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a2q_scale, - B_scale=w2_scale, - B_zp=w2_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w2_scale, + B_zp=self.w2_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1e3ac6cd79f68..6412c3eaa1932 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,44 +1,60 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" + from typing import Optional import torch +from typing_extensions import override import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, maybe_warn_marlin_atomic_add) + marlin_make_workspace_new, + marlin_moe_intermediate_size, + maybe_warn_marlin_atomic_add, +) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import direct_register_custom_op -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - bias1: 
Optional[torch.Tensor], - bias2: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_type_id: int, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, - activation: Optional[str] = "silu", - expert_map: Optional[torch.Tensor] = None, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - is_k_full: bool = True, - inplace: bool = False) -> torch.Tensor: +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_type_id: int, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + activation: Optional[str] = "silu", + expert_map: Optional[torch.Tensor] = None, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, + is_k_full: bool = True, + output: Optional[torch.Tensor] = None, + inplace: bool = False, +) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -49,8 +65,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, - w2 (torch.Tensor): The second set of expert weights. - w1_scale (torch.Tensor): Scale to be used for w1. - w2_scale (torch.Tensor): Scale to be used for w2. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). + - gating_output (Optional[torch.Tensor]): The output of the gating + operation (before softmax). - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. - sort_indices1 (Optional[torch.Tensor]): The first act_order input @@ -68,22 +84,29 @@ def fused_marlin_moe(hidden_states: torch.Tensor, """ quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ - scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f + scalar_types.uint4, + scalar_types.uint8b128, + scalar_types.uint4b8, + scalar_types.float8_e4m3fn, + scalar_types.float4_e2m1f, ] bit4_scalar_types = [ - scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.float4_e2m1f, ] num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. 
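# The constraint checks that follow encode how the marlin kernels expect the
# quantized expert weights to be packed.  A small helper that makes the
# arithmetic explicit (derived from the assertions below; illustrative only,
# not a vLLM API):
def expected_marlin_dims(hidden_size: int, num_bits: int) -> tuple[int, int]:
    w1_dim1 = hidden_size // 16               # w1.shape[1]: K is packed 16-to-1
    w2_dim2 = hidden_size * (num_bits // 2)   # inverse of K == w2.shape[2] // (num_bits // 2)
    return w1_dim1, w2_dim2

# e.g. a 4-bit checkpoint with hidden size 4096:
assert expected_marlin_dims(4096, num_bits=4) == (256, 8192)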
- assert hidden_states.shape[0] == gating_output.shape[ - 0], "Number of tokens mismatch" - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // ( - num_bits // 2), "Hidden size mismatch w2" + if gating_output is not None: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch" + ) + assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), ( + "Hidden size mismatch w2" + ) assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" @@ -93,7 +116,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, M, K = hidden_states.shape E = w1.shape[0] - N = w2.shape[1] * 16 + N = marlin_moe_intermediate_size(w1, w2) topk = topk_ids.shape[1] # M block size selection logic @@ -104,31 +127,36 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if global_num_experts == -1: global_num_experts = E - sorted_token_ids, expert_ids, num_tokens_post_padded = \ - moe_align_block_size(topk_ids, block_size_m, global_num_experts, - expert_map) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, block_size_m, global_num_experts, expert_map + ) if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache13 = torch.empty( - (M * topk_ids.shape[1] * max(2 * N, K), ), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] - intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) - intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] - intermediate_cache3 = intermediate_cache3.view(-1, K) + if intermediate_cache2 is None: + intermediate_cache2 = torch.empty( + (M * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if intermediate_cache13 is None: + intermediate_cache13 = torch.empty( + (M * topk * max(2 * N, K),), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N)) + intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K)) + intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N)) maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) - use_atomic_add = hidden_states.dtype == torch.half or \ - torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + use_atomic_add = ( + hidden_states.dtype == torch.half + or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + ) intermediate_cache1 = ops.moe_wna16_marlin_gemm( hidden_states, @@ -156,18 +184,23 @@ def fused_marlin_moe(hidden_states: torch.Tensor, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, - is_zp_float=False) + is_zp_float=False, + ) if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + 
torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) else: - raise ValueError(f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported.") + raise ValueError( + f"Unsupported activation: {activation}. " + "Only silu and swigluoai activations are supported." + ) if expert_map is not None: intermediate_cache3.zero_() @@ -198,43 +231,178 @@ def fused_marlin_moe(hidden_states: torch.Tensor, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, - is_zp_float=False).view(-1, topk, K) + is_zp_float=False, + ).view(-1, topk, K) - output = hidden_states if inplace else torch.empty_like(hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=output) + if output is None: + if inplace and not disable_inplace(): + output = hidden_states + else: + output = torch.empty_like(hidden_states) + + return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) -def fused_marlin_moe_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_type_id: int, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - expert_map: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - is_k_full: bool = True, - inplace: bool = False) -> torch.Tensor: +def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_type_id: int, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, + is_k_full: bool = True, + output: Optional[torch.Tensor] = None, + inplace: bool = False, +) -> torch.Tensor: return torch.empty_like(hidden_states) direct_register_custom_op( op_name="fused_marlin_moe", op_func=fused_marlin_moe, - mutates_args=[], fake_impl=fused_marlin_moe_fake, ) + + +class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self, quant_config: FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" + super().__init__(quant_config) + + @override + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + assert w1.dim() == 3 and w2.dim() == 3 + + E = w1.size(0) 
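# Plain-PyTorch sketch of the final reduce used by fused_marlin_moe() above:
# the second GEMM leaves one row per (token, expert) pair, and the layer output
# is the sum over each token's top-k rows, written into a preallocated buffer.
import torch

M, topk, K = 5, 2, 16
per_pair_out = torch.randn(M * topk, K)   # stand-in for intermediate_cache3
out = torch.empty(M, K)
torch.sum(per_pair_out.view(-1, topk, K), dim=1, out=out)
assert torch.allclose(out, per_pair_out.view(M, topk, K).sum(dim=1))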
+ K = a1.size(-1) + N = marlin_moe_intermediate_size(w1, w2) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_chunking(self) -> bool: + return True + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # Modular Kernel provisions the output buffer from workspace1. However, + # in the fused_marlin_moe() function, the final torch.sum() is defined + # essentially as + # `torch.sum(workspace1, dim=1, out=output)` + # Having overlapping input and output tensors for torch.sum seems + # error-prone and depends on how torch.sum is implemented. + # For this reason we swap the two workspaces and let the output buffer + # be provisioned from workspace2. + + # Workspace/IntermediateCache allocation matching fused_marlin_moe() + # workspace1 = (M * topk * max(2 * N, K),) + # workspace2 = (M * topk, N) + + # Workspace/IntermediateCache allocation accounting for output buffer + # provisioning + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk * max(2 * N, K),) + output = (M, K) + + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert self.w1_scale is not None + assert self.w2_scale is not None + return fused_marlin_moe( + hidden_states=hidden_states, + w1=w1, + w2=w2, + bias1=self.w1_bias, + bias2=self.w2_bias, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale, + gating_output=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16 + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map, + output=output, + # Workspaces are swapped in workspace_shapes() to account for proper + # output buffer allocation. Please refer to workspace_shapes().
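# The swap described in the comments above, in numbers: since the modular
# kernel carves the final output out of the first workspace it is given (as
# noted in workspace_shapes()), the flat M * topk * max(2N, K) scratch that
# torch.sum() reads from is moved to the second slot, so the (M, K) output is
# provisioned from a buffer it does not overlap.  Sizes below are illustrative.
M, topk, N, K = 32, 4, 1024, 2048
workspace1 = (M * topk, max(N, K))          # holds intermediate_cache2 and the output
workspace2 = (M * topk * max(2 * N, K),)    # holds intermediate_cache1/3 for the GEMMs
output = (M, K)
assert M * K <= workspace1[0] * workspace1[1]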
+ intermediate_cache13=workspace2, + intermediate_cache2=workspace13, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 02b7b65f4a025..eda825ffcae1e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Fused MoE kernel.""" +"""Fused MoE Triton kernels.""" + import functools import json import os -# torch.compile needs typing.List. It will fail torch.library.infer_schema -# otherwise -from typing import List # noqa: UP035 -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -16,31 +14,41 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -# yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, get_config_quant_dtype) + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, + _get_config_dtype_str, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, - run_cutlass_block_scaled_fused_experts) -# yapf: enable + run_cutlass_block_scaled_fused_experts, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8) + _valid_deep_gemm, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - calculate_tile_tokens_dim) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - dequant_mxfp4) + _resize_cache, + activation_without_mul, + disable_inplace, + moe_kernel_quantize_input, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -48,64 +56,73 @@ logger = init_logger(__name__) @triton.jit -def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, - token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, - compute_type): +def write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, +): accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] 
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @triton.jit def fused_moe_kernel_gptq_awq( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - b_scale_ptr, - b_zp_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N: tl.constexpr, - K: tl.constexpr, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bse, - stride_bsk, - stride_bsn, - stride_bze, - stride_bzk, - stride_bzn, - block_k_diviable: tl.constexpr, - group_size: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - has_zp: tl.constexpr, - use_int4_w4a16: tl.constexpr, - use_int8_w8a16: tl.constexpr): + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. @@ -154,8 +171,7 @@ def fused_moe_kernel_gptq_awq( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -164,25 +180,41 @@ def fused_moe_kernel_gptq_awq( # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
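# Sketch of the expert-parallel convention this guard relies on: expert_map
# maps global expert ids owned by this rank to local ids and everything else to
# -1, so blocks whose expert id resolves to -1 simply write zeros and the later
# cross-rank reduction still adds up.  The map construction below is
# illustrative; only the "-1 means not local" convention is taken from the
# kernel.
import torch

global_num_experts, num_ranks, rank = 8, 2, 1
per_rank = global_num_experts // num_ranks
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[rank * per_rank:(rank + 1) * per_rank] = torch.arange(per_rank, dtype=torch.int32)
print(expert_map)  # tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)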
- write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) if use_int4_w4a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ - stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) if not has_zp and use_int4_w4a16: b_zp_num = 8 @@ -208,34 +240,43 @@ def fused_moe_kernel_gptq_awq( k_mask = None k_other = None - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) b = tl.load(b_ptrs) if use_int4_w4a16: b = (b >> b_shifter) & 0xF - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ - offs_bn[None, :] * stride_bsn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ - stride_bsk + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w4a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - (offs_bn[None, :] // 2) * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = (b_zp >> b_zp_shifter) & 0xF b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - offs_bn[None, :] * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = b_zp.to(tl.float32) @@ -254,17 +295,14 @@ def fused_moe_kernel_gptq_awq( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + 
stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @@ -370,8 +408,7 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -380,22 +417,35 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. - write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_int8_w8a16: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8 or use_int8_w8a8: @@ -403,17 +453,18 @@ def fused_moe_kernel( if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n - b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + - offs_bsn * stride_bsn) + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) # channel-wise elif per_channel_quant: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, - None] + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] # tensor-wise else: a_scale = tl.load(a_scale_ptr) @@ -431,13 +482,12 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. 
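# Rough sketch of the block-quantized path in this K loop: with a
# (group_k x group_n) block shape, every (token, k-block) has one activation
# scale and every (k-block, n-block) has one weight scale, and each partial dot
# product is rescaled by their product, mirroring the
# `tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]` term below.
import torch

M, K, N, group_k, group_n = 4, 32, 16, 8, 8
a_q = torch.randint(-8, 8, (M, K)).float()        # stand-in for quantized activations
b_q = torch.randint(-8, 8, (K, N)).float()        # stand-in for quantized weights
a_scale = torch.rand(M, K // group_k)             # one scale per (token, k-block)
b_scale = torch.rand(K // group_k, N // group_n)  # one scale per (k-block, n-block)

acc = torch.zeros(M, N)
for kb in range(K // group_k):
    a_blk = a_q[:, kb * group_k:(kb + 1) * group_k]
    b_blk = b_q[kb * group_k:(kb + 1) * group_k, :]
    partial = a_blk @ b_blk                               # (M, N) contribution of this k-block
    col_scale = b_scale[kb].repeat_interleave(group_n)    # expand n-block scales to columns
    acc += partial * a_scale[:, kb:kb + 1] * col_scale[None, :]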
if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -445,13 +495,12 @@ def fused_moe_kernel( if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, - mask=token_mask, - other=0.0) + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: if use_fp8_w8a8: # acc used to enable fp8_fast_accum @@ -466,9 +515,7 @@ def fused_moe_kernel( if HAS_BIAS: accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) @@ -483,43 +530,46 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) -def invoke_fused_moe_kernel(A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict[str, Any], - compute_type: tl.dtype, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[list[int]] = None, - B_bias: Optional[torch.Tensor] = None) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, + B_bias: Optional[torch.Tensor] = None, +) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None - or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) - assert (block_shape is None - or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) + assert block_shape is None or triton.cdiv( + B.size(-2), block_shape[0] + ) == B_scale.size(-2) + assert block_shape is None or triton.cdiv( + B.size(-1), block_shape[1] + ) == B_scale.size(-1) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -534,16 +584,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, EM = 
sorted_token_ids.size(0) if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. - # We assume that top_ids of each token is unique, so + # We assume that top_ids of each token is unique, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. - EM = min(sorted_token_ids.size(0), - A.size(0) * top_k * config['BLOCK_SIZE_M']) - grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.size(1), META['BLOCK_SIZE_N']), ) + EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) + * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), + ) HAS_BIAS = B_bias is not None - if (use_int8_w8a16 or use_int4_w4a16) and \ - block_shape is not None and block_shape[1] > 0: + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -551,27 +605,41 @@ def invoke_fused_moe_kernel(A: torch.Tensor, num_valid_tokens=num_tokens, group_size=block_shape[1], num_experts=B.size(0), - bit=4 if use_int4_w4a16 else 8) + bit=4 if use_int4_w4a16 else 8, + ) config = config.copy() config.update( - get_moe_wna16_block_config(config=config, - use_moe_wna16_cuda=use_moe_wna16_cuda, - num_valid_tokens=num_tokens, - size_k=A.size(1), - size_n=B.size(1), - num_experts=B.size(1), - group_size=block_shape[1], - real_top_k=top_k, - block_size_m=config["BLOCK_SIZE_M"])) + get_moe_wna16_block_config( + config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"], + ) + ) if use_moe_wna16_cuda: bit = 4 if use_int4_w4a16 else 8 - ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, - topk_weights if mul_routed_weight else None, - sorted_token_ids, expert_ids, - num_tokens_post_padded, top_k, - config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], bit) + ops.moe_wna16_gemm( + A, + C, + B, + B_scale, + B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + bit, + ) return fused_moe_kernel_gptq_awq[grid]( @@ -615,8 +683,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, config = config.copy() BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") if block_shape is not None: - BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], - block_shape[1])) + BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1])) fused_moe_kernel[grid]( A, B, @@ -639,16 +706,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - A_scale.stride(0) - if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) - if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) - if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) - if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) - if B_scale is not None and B_scale.ndim >= 2 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is 
not None and B_scale.ndim >= 2 else 0, B_bias.stride(0) if B_bias is not None else 0, B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], @@ -666,15 +728,93 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ) +@triton.jit +def compute_identity_kernel( + top_k: int, + hidden_states_ptr: tl.tensor, + expert_scales_ptr: tl.tensor, + num_tokens: int, + output_ptr: tl.tensor, + hidden_dim: int, + scales_stride: int, + BLOCK_SIZE: tl.constexpr, +) -> None: + pid = tl.program_id(0) + + batch_id = pid // (hidden_dim // BLOCK_SIZE) + dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE + + if batch_id >= num_tokens or dim_offset >= hidden_dim: + return + + h = tl.load( + hidden_states_ptr + + batch_id * hidden_dim + + dim_offset + + tl.arange(0, BLOCK_SIZE), + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) + + result = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for i in range(top_k): + scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i) + result += h * scale + + tl.store( + output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE), + result, + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) + + +def zero_experts_compute_triton( + expert_indices: torch.Tensor, + expert_scales: torch.Tensor, + num_experts: int, + zero_expert_type: str, + hidden_states: torch.Tensor, +) -> torch.Tensor: + N = expert_indices.numel() + top_k = expert_indices.size(-1) + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + + if zero_expert_type == "identity": + zero_expert_mask = expert_indices < num_experts + zero_expert_scales = expert_scales.clone() + zero_expert_scales[zero_expert_mask] = 0.0 + + normal_expert_mask = expert_indices >= num_experts + expert_indices[normal_expert_mask] = 0 + expert_scales[normal_expert_mask] = 0.0 + + output = torch.zeros_like(hidden_states).to(hidden_states.device) + hidden_dim = hidden_states.size(-1) + num_tokens = hidden_states.size(0) + + grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),) + compute_identity_kernel[grid]( + top_k, + hidden_states, + zero_expert_scales, + num_tokens, + output, + hidden_dim, + zero_expert_scales.stride(0), + BLOCK_SIZE=256, + ) + + return output + + # Adapted from: https://github.com/sgl-project/sglang/pull/2628 -def get_config_file_name(E: int, - N: int, - dtype: Optional[str], - block_shape: Optional[list[int]] = None) -> str: +def get_config_file_name( + E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - block_shape_selector = ("" if not block_shape or not all(block_shape) else - f",block_shape={block_shape}").replace(" ", "") + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ).replace(" ", "") return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 @@ -707,34 +847,50 @@ def get_moe_configs( user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: user_defined_config_file_path = os.path.join( - user_defined_config_folder, json_file_name) + user_defined_config_folder, json_file_name + ) config_file_paths.append(user_defined_config_file_path) default_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), 
"configs", json_file_name + ) config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) + logger.info( + "Using configuration from %s for MoE layer.", config_file_path + ) # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + tuned_config = json.load(f) + # Delete triton_version from tuned_config + tuned_config.pop("triton_version", None) + return {int(key): val for key, val in tuned_config.items()} # If no optimized configuration is available, we will use the default # configuration logger.warning( - ("Using default MoE config. Performance might be sub-optimal! " - "Config file not found at %s"), config_file_paths) + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_paths, + ) return None -def get_moe_wna16_block_config(config: dict[str, - int], use_moe_wna16_cuda: bool, - num_valid_tokens: int, size_k: int, size_n: int, - num_experts: int, group_size: int, - real_top_k: int, block_size_m: int): +def get_moe_wna16_block_config( + config: dict[str, int], + use_moe_wna16_cuda: bool, + num_valid_tokens: int, + size_k: int, + size_n: int, + num_experts: int, + group_size: int, + real_top_k: int, + block_size_m: int, +): if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: # optimal block config is set return {} @@ -756,20 +912,24 @@ def get_moe_wna16_block_config(config: dict[str, num_n_blocks = size_k // block_size_k num_k_blocks = size_n // block_size_k - num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ - num_experts + num_m_blocks = ( + num_valid_tokens + block_size_m - 1 + ) / block_size_m + num_experts if num_valid_tokens // real_top_k <= block_size_m: num_m_blocks = min(num_m_blocks, num_valid_tokens) num_blocks = num_m_blocks * num_n_blocks * num_k_blocks - if size_k % 256 == 0 and num_blocks >= 256 and \ - block_size_k < 256: + if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256: block_size_k = 256 num_blocks = num_blocks // (256 // block_size_k) - if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ - size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ - num_blocks >= 512: + if ( + num_m_blocks <= 16 + and size_k % (block_size_k * 2) == 0 + and size_k % (block_size_k * 2) == 0 + and block_size_k <= 512 + and num_blocks >= 512 + ): block_size_k = block_size_k * 2 num_blocks = num_blocks // 2 @@ -788,10 +948,15 @@ def get_moe_wna16_block_config(config: dict[str, return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} -def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, - num_experts: int, bit: int): - return current_platform.is_cuda() and bit == 4 and \ - group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6 +def should_moe_wna16_use_cuda( + num_valid_tokens: int, group_size: int, num_experts: int, bit: int +): + return ( + current_platform.is_cuda() + and bit == 4 + and group_size in [32, 64, 128] + and num_valid_tokens / num_experts <= 6 + ) def get_default_config( @@ -821,8 +986,7 @@ def get_default_config( # only set BLOCK_SIZE_M # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later bit = 4 if dtype == "int4_w4a16" else 8 - use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, - block_shape[1], E, bit) + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, 
block_shape[1], E, bit) if use_moe_wna16_cuda: config = {"BLOCK_SIZE_M": min(16, M)} elif M <= 20: @@ -857,6 +1021,7 @@ def try_get_optimal_moe_config( block_shape: Optional[list[int]] = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() if override_config: config = override_config @@ -875,15 +1040,17 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - block_shape) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) return config -def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: +def vllm_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: ops.topk_softmax( topk_weights, topk_indices, @@ -899,6 +1066,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: if is_rocm_aiter_moe_enabled(): from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax + return rocm_aiter_topk_softmax return vllm_topk_softmax @@ -910,35 +1078,52 @@ def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" M, _ = hidden_states.size() - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) topk_ids = torch.empty( M, topk, dtype=torch.int32 if indices_type is None else indices_type, - device=hidden_states.device) - token_expert_indices = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + device=hidden_states.device, + ) + token_expert_indices = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
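# Reference semantics of fused_topk in plain PyTorch (a sketch for intuition;
# the real path dispatches to an optimized topk_softmax op): softmax over the
# router logits, keep the top-k experts per token, and optionally renormalize
# the kept weights so they sum to 1.
import torch

def reference_fused_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(3, 8)  # (num_tokens, num_experts)
w, ids = reference_fused_topk(logits, topk=2, renormalize=True)
assert torch.allclose(w.sum(dim=-1), torch.ones(3))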
topk_func = dispatch_topk_func() - topk_weights, topk_ids = topk_func(topk_weights, topk_ids, - token_expert_indices, - gating_output_float, renormalize) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize + ) return topk_weights, topk_ids, token_expert_indices +def fused_topk_bias( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor, + topk: int, + renormalize: bool, +): + n_routed_experts = gating_output.shape[-1] + scores = gating_output.softmax(dim=-1) + scores_for_choice = scores.view( + -1, n_routed_experts + ) + e_score_correction_bias.unsqueeze(0) + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights.to(torch.float32), topk_indices.to(torch.int32) + + # This is used by the Deepseek-V2 and Deepseek-V3 model @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( @@ -949,11 +1134,29 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if ( + envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + and current_platform.is_cuda() + and num_expert_group <= 32 + and topk <= 32 + and e_score_correction_bias is not None + ): + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + ) - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) @@ -968,118 +1171,230 @@ def grouped_topk( # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # 
[n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -def get_config_dtype_str( - dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False, - use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a16: - return "int8_w8a16" - elif use_int4_w4a16: - return "int4_w4a16" - elif use_mxfp4_w4a4: - return "mxfp4_w4a4" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return "float32" - return None +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Map the logical expert ids to physical expert ids + and record the expert load metrics. + + This will select a pseudo-random replica for each logical expert. + Only used for EPLB. + + Args: + topk_ids: The logical expert ids. + expert_load_view: The expert load view. + logical_to_physical_map: The logical to physical map. + logical_replica_count: The logical replica count. + indices_type: The indices type. + + Returns: + The physical expert ids. + """ + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + # Use (token position) modulo (replica count) + # to deterministically choose a replica + replica_count = logical_replica_count[topk_ids_long] + # Flatten-position based index, reshaped back to `topk_ids` shape + pos_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.long + ).reshape_as(topk_ids) + # Compute pseudo-random indices by modulo + replica_indices = (pos_indices % replica_count).unsqueeze(-1) + physical_ids = ( + logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1) + ) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_physical_experts,) + + # `torch.bincount` is not compilable, so use `scatter_add_` instead. 
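A minimal sketch of the EPLB mapping and load accounting described above, using toy shapes (2 logical experts, 3 physical experts) chosen purely for illustration: each token picks replica `position % replica_count` of its logical expert, and the per-expert load is then accumulated with `scatter_add_`, which yields the same counts as `torch.bincount` while remaining traceable by `torch.compile`.

import torch

logical_to_physical_map = torch.tensor([[0, 1], [2, -1]])  # logical 0 -> {0, 1}, logical 1 -> {2}
logical_replica_count = torch.tensor([2, 1])
topk_ids = torch.tensor([[0, 1], [1, 0]])                  # logical expert ids per token

# Deterministic pseudo-random replica choice: flat token position modulo replica count.
pos = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica = (pos % logical_replica_count[topk_ids]).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids].gather(-1, replica).squeeze(-1)

# Record per-physical-expert load; equivalent to bincount but compile-friendly.
expert_load = torch.zeros(3, dtype=torch.long)
flat = physical_ids.flatten()
expert_load.scatter_add_(0, flat, torch.ones_like(flat))
assert torch.equal(expert_load, torch.bincount(flat, minlength=3))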
+ topk_ids_flatten = topk_ids.flatten() + expert_load_view.scatter_add_( + dim=0, + index=topk_ids_flatten.long(), + src=torch.ones_like(topk_ids_flatten).to(expert_load_view), + ) + + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + return topk_ids + + +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + topk_values, topk_indices = ops.grouped_topk( + scores, + scores_with_bias.to(scores.dtype), + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + return topk_values.to(torch.float32), topk_indices.to(torch.int32) def inplace_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: - fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, is_act_and_mul, - apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - use_mxfp4_w4a4, per_channel_quant, global_num_experts, - expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape, w1_bias, w2_bias) + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: Optional[str] = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + 
use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ocp_mx_scheme, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) -def inplace_fused_experts_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: Optional[str] = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: pass @@ -1088,177 +1403,11 @@ direct_register_custom_op( op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), -) - - -def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: List[int], #noqa: UP006 - routed_scaling: float = 1.0) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 - assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) - assert block_shape == [128, 128] - - a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) - # NOTE: scales of hidden states have to be transposed! 
- a_sf_t = a_sf.t().contiguous() - return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale_inv, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale_inv, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, - global_num_experts), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) - - -def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routed_scaling: float = 1.0) -> torch.Tensor: - return torch.empty_like(x) - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_blockscale_fp8", - op_func=flashinfer_fused_moe_blockscale_fp8, - mutates_args=[], - fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - -def flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - - quant_hidden_states, _ = moe_kernel_quantize_input( - hidden_states, - input_scale, - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False) - - from vllm.utils.flashinfer import ( - flashinfer_trtllm_fp8_per_tensor_scale_moe) - return flashinfer_trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=quant_hidden_states, - gemm1_weights=gemm1_weights, - output1_scales_scalar=output1_scales_scalar, - output1_scales_gate_scalar=output1_scales_gate_scalar, - gemm2_weights=gemm2_weights, - output2_scales_scalar=output2_scales_scalar, - num_experts=num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], - top_k, num_experts), - routing_method_type=routing_method_type) - - -def flashinfer_fused_moe_per_tensor_scale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: 
torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - pass - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_per_tensor_scale_fp8", - op_func=flashinfer_fused_moe_per_tensor_scale_fp8, - mutates_args=["hidden_states"], - fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) @@ -1269,13 +1418,12 @@ def outplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: Optional[str] = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1285,59 +1433,82 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 + block_shape: Optional[list[int]] = None, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return fused_experts_impl( - hidden_states, w1, w2, topk_weights, topk_ids, False, activation, - is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, - per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ocp_mx_scheme, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) def outplace_fused_experts_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + 
use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: Optional[str] = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return torch.empty_like(hidden_states) direct_register_custom_op( op_name="outplace_fused_experts", op_func=outplace_fused_experts, - mutates_args=[], fake_impl=outplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: torch.ops.vllm.inplace_fused_experts(**kwargs) - hidden_states = kwargs['hidden_states'] + hidden_states = kwargs["hidden_states"] return hidden_states @@ -1346,53 +1517,45 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if inplace: + if inplace and not disable_inplace(): return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, +) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + use_fp8_w8a8 = quant_config.use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. 
# However, on B200, we use DeepGemm for all cases because they only support # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. - if (allow_deep_gemm and use_fp8_w8a8 - and (is_blackwell_deep_gemm_e8m0_used() - or _valid_deep_gemm(hidden_states, w1, w2))): + if ( + allow_deep_gemm + and quant_config.use_fp8_w8a8 + and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) + ): + assert quant_config is not None assert apply_router_weight_on_input is False - assert is_act_and_mul, ( - "DeepGemm only supports is_act_and_mul=True for now.") return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, @@ -1403,24 +1566,29 @@ def fused_experts(hidden_states: torch.Tensor, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm( - w1, w2, inplace, activation, apply_router_weight_on_input, - expert_map)): + elif ( + allow_cutlass_block_scaled_grouped_gemm + and use_fp8_w8a8 + and _valid_cutlass_block_scaled_grouped_gemm( + w1, w2, inplace, activation, apply_router_weight_on_input, expert_map + ) + ): + assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1, w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, topk_weights=topk_weights, - topk_ids=topk_ids) + topk_ids=topk_ids, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1429,28 +1597,56 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - is_act_and_mul=is_act_and_mul, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, + use_fp8_w8a8=quant_config.use_fp8_w8a8, + use_int8_w8a8=quant_config.use_int8_w8a8, + use_int8_w8a16=quant_config.use_int8_w8a16, + use_int4_w4a16=quant_config.use_int4_w4a16, + ocp_mx_scheme=quant_config.ocp_mx_scheme, + per_channel_quant=quant_config.per_act_token_quant, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + w1_zp=quant_config.w1_zp, + w2_zp=quant_config.w2_zp, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, + block_shape=quant_config.block_shape, + w1_bias=quant_config.w1_bias, + w2_bias=quant_config.w2_bias, ) +SILU_NO_MUL: str = activation_without_mul("silu") +GELU_NO_MUL: str = activation_without_mul("gelu") + + +def _get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + ocp_mx_scheme: Optional[str], +) -> Union[None, torch.dtype, str]: + """ + Get the quantization type based on the quantization strategy flags. 
+ We don't have a quant_config at this point so we need to work backwards. + A return type of None means no quantization is required because the + input is unquantized or has been quantized prior to calling + fused_experts_impl. + """ + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": + return "mxfp4" + elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: + return "mxfp6_e3m2" + elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: + return "mxfp6_e2m3" + return None + + def fused_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1459,13 +1655,12 @@ def fused_experts_impl( topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: Optional[str] = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1481,22 +1676,34 @@ def fused_experts_impl( ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: - assert hidden_states.size(1) // 2 == w1.size(2), ( - "Hidden size mismatch") - elif use_mxfp4_w4a4: - # 16bit activation and fp4x2 packed weight - assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" + assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" + elif ocp_mx_scheme is not None: + if ocp_mx_scheme in { + "w_mxfp4_a_mxfp4", + "w_mxfp4_a_mxfp6_e3m2", + "w_mxfp4_a_mxfp6_e2m3", + }: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" + elif ocp_mx_scheme in { + "w_mxfp6_e3m2_a_mxfp6_e3m2", + "w_mxfp6_e2m3_a_mxfp6_e2m3", + }: + assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( + "hidden size mismatch" + ) + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") else: assert hidden_states.size(1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" + ) assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] num_tokens = hidden_states.size(0) E, N, _ = w1.size() @@ -1508,17 +1715,22 @@ def fused_experts_impl( # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) - config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - dtype=hidden_states.dtype) - qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4) + config_dtype = _get_config_dtype_str( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ocp_mx_scheme=ocp_mx_scheme, + dtype=hidden_states.dtype, + ) + + # Note: for use_int8_w8a16 or 
use_int4_w4a16, the activations are + # quantized prior to calling fused_experts. + quant_dtype = _get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + ocp_mx_scheme=ocp_mx_scheme, + ) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1533,16 +1745,18 @@ def fused_experts_impl( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) - intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + cache13 = torch.empty( + M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty( + (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype + ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 @@ -1553,22 +1767,51 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - if inplace: + if inplace and not disable_inplace(): out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) - if use_mxfp4_w4a4: - # Weight has to be dequantized for mxfp4 emulation. - w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) - w1_scale = None - w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) - w2_scale = None + if ocp_mx_scheme is not None: + # TODO: On platforms for which `current_platform.supports_mx()` is True + # and for which we have a native OCP mx fused MOE kernel, + # this dequantization step should not be done. + if ocp_mx_scheme in { + OCP_MX_Scheme.w_mxfp4_a_mxfp4, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, + }: + # Weight has to be dequantized for mxfp4 emulation. 
+ w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w2_scale = None + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.size() @@ -1581,8 +1824,9 @@ def fused_experts_impl( # so the cache size and config are already set correctly and # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.size(1)] + intermediate_cache2 = intermediate_cache2[ + : tokens_in_chunk * topk_ids.size(1) + ] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1591,257 +1835,117 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) - invoke_fused_moe_kernel(qcurr_hidden_states, - w1, - intermediate_cache1, - a1q_scale, - w1_scale, - w1_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - apply_router_weight_on_input, - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w1_bias) + invoke_fused_moe_kernel( + qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w1_bias, + ) # Activation function with multiplication - if activation == "silu" and is_act_and_mul: - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - elif activation == "gelu" and is_act_and_mul: - torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 
N)) - elif activation == "swigluoai" and is_act_and_mul: - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - # Activation function without multiplication - elif activation == "silu": - intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) + if activation == "silu": + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu": + torch.ops._C.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) + # Activation function without multiplication + elif activation == SILU_NO_MUL: + intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) + elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}, " - f"with is_act_and_mul={is_act_and_mul}.") + raise ValueError(f"Unsupported FusedMoe activation: {activation}.") qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w2_bias) + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w2_bias, + ) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.size()), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) return out_hidden_states -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - 
block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - is_act_and_mul (bool): If True, use activation-and-mul function for - activation (self-gated activation), otherwise use activation function - for activation (ungated activation). - - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseekv2 model uses grouped_topk - - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and - OCP MXFP4 activation to compute the inner products for w1 and w2. - Defaults to False. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. - - block_shape: (Optional[list[int]]): Optional block size for block-wise - quantization. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. 
- """ - if not is_act_and_mul: - assert inplace is False, ( - "is_act_and_mul=False is not supported with inplace=True") - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, - topk, renormalize, - num_expert_group, topk_group) - elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) - - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - is_act_and_mul=is_act_and_mul, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias) - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_mxfp4_w4a4 = use_mxfp4_w4a4 + super().__init__(quant_config) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -1854,8 +1958,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -1863,11 +1965,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -1880,10 +1982,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -1892,40 +1990,36 @@ class 
TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): apply_router_weight_on_input: bool, ): # Check constraints. - if self.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), ( - "Hidden size mismatch") + if self.quant_config.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: - assert hidden_states.size(-1) == w1.size(2), \ - (f"Hidden size mismatch {hidden_states.size(-1)} " - f"!= {w1.size(2)}") + assert hidden_states.size(-1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" + ) - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.dim() == 2 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + torch.float32, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, ] - E, num_tokens, N, K, top_k_num = mk._moe_problem_size( - hidden_states, w1, w2, topk_ids) + E, num_tokens, N, K, top_k_num = self.moe_problem_size( + hidden_states, w1, w2, topk_ids + ) if global_num_experts == -1: global_num_experts = E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, - config_dtype, + self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) @@ -1939,28 +2033,26 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # Note that the output tensor might be in workspace1 - intermediate_cache1 = _resize_cache(workspace2, - (num_tokens, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace13, - (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace2, - (num_tokens, top_k_num, K)) + intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache( + workspace13, (num_tokens * top_k_num, N // 2) + ) + intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) invoke_fused_moe_kernel( hidden_states, w1, intermediate_cache1, a1q_scale, - w1_scale, - w1_zp, + self.w1_scale, + self.w1_zp, None, # topk_weights sorted_token_ids, expert_ids, @@ -1969,31 +2061,36 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): top_k_num, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, 
per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w1_bias, ) - self.activation(activation, intermediate_cache2, - intermediate_cache1.view(-1, N)) + self.activation( + activation, intermediate_cache2, intermediate_cache1.view(-1, N) + ) a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.quant_dtype, - self.per_act_token_quant, self.block_shape) + intermediate_cache2, + a2_scale, + self.quant_dtype, + self.per_act_token_quant, + self.block_shape, + ) invoke_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, a2q_scale, - w2_scale, - w2_zp, + self.w2_scale, + self.w2_zp, topk_weights, sorted_token_ids, expert_ids, @@ -2002,36 +2099,22 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): 1, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w2_bias, ) ops.moe_sum(intermediate_cache3, output) def modular_triton_fused_moe( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, - per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ), + TritonExperts(quant_config), ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 312befe2c1d71..283ce80556d26 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,13 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceNoOP, +) +from vllm.triton_utils import tl, triton from vllm.utils import has_triton_kernels logger = init_logger(__name__) @@ -15,16 +21,55 @@ logger = init_logger(__name__) if has_triton_kernels(): try: import triton_kernels.swiglu - from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, - matmul_ogs) - from triton_kernels.routing import routing - except ModuleNotFoundError: + from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs + from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix + from triton_kernels.tensor import Bitmatrix 
+ except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " - "version is compatible.") + "version is compatible. Error: %s", + e, + ) -if TYPE_CHECKING: - from triton_kernels.matmul_ogs import PrecisionConfig + +@triton.jit +def pack_bitmatrix( + bitmatrix, + topk_ids, + n_rows, # n_rows in bitmatrix / topk_ids + bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix + n_expts_act, # num_topk + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + """ + Packs topk_ids into a bitmatrix. + code reference: + https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264 + """ + pid_m = tl.program_id(0) + offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offsets_k = tl.arange(0, BLOCK_SIZE_K) + offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] + mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] + indices = tl.load(topk_ids + offsets, mask=mask, other=-1) + div = indices // 32 + rem = indices % 32 + one = tl.cast(1, tl.uint32) + + # Iterate through all the relevant bitmatrix columns. + for i in range(bm_cols): + # When BLOCK_SIZE_K=32, offs is just the column index. + offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) + # All topks that need to go into this column has the correct bit set. + # Other bits are 0. x is a 2D tensor. + x = tl.where( + div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0 + ) + # Reduce x to get a single int32_t bitpack. + y = tl.reduce_or(x, axis=1) + bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_cols + offs[None, :] + tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) def triton_kernel_moe_forward( @@ -35,25 +80,14 @@ def triton_kernel_moe_forward( topk: int, renormalize: bool, activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: - - routing_data, gather_idx, scatter_idx = routing(gating_output, - topk, - sm_first=not renormalize) + routing_data, gather_idx, scatter_idx = routing( + gating_output, topk, sm_first=not renormalize + ) return triton_kernel_fused_experts( None, @@ -64,20 +98,11 @@ def triton_kernel_moe_forward( gather_idx, scatter_idx, activation=activation, + quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, - w1_precision=w1_precision, - w2_precision=w2_precision, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + ) # This is a triton implementation of the fused_experts function @@ -90,28 +115,21 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx activation: str = "silu", + quant_config: 
Optional[FusedMoEQuantConfig] = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + a1q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 - assert w1_bias is None or w1_bias.dtype == torch.float32 - assert w2_bias is None or w2_bias.dtype == torch.float32 + assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 + assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 # Shape check, only check non-mxfp4 assert hidden_states.shape[-1] == w1.shape[-2] @@ -124,82 +142,132 @@ def triton_kernel_fused_experts( act = FusedActivation( FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), - (swiglu_alpha, swiglu_limit), 2) + (swiglu_alpha, swiglu_limit), + 2, + ) gammas = routing_data.gate_scal if routing_data else None intermediate_cache1 = matmul_ogs( hidden_states, w1, - w1_bias, + quant_config.w1_bias, routing_data, gather_indx=gather_indx, - precision_config=w1_precision, + precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, - fused_activation=act) + fused_activation=act, + ) intermediate_cache3 = matmul_ogs( intermediate_cache1, w2, - w2_bias, + quant_config.w2_bias, routing_data, scatter_indx=scatter_indx, - precision_config=w2_precision, + precision_config=quant_config.w2_precision, gammas=None if apply_router_weight_on_input else gammas, y=output_tensor, ) return intermediate_cache3 -class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +def make_routing_data( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + num_local_experts: int, +) -> tuple["RoutingData", torch.Tensor, torch.Tensor]: + topk_ids = topk_ids.to(torch.int16) + topk_weights = topk_weights.to(torch.bfloat16) - def __init__( - self, - quant_config, - max_num_tokens: int, - num_dispatchers: int, - w1_precision: "PrecisionConfig", - w2_precision: "PrecisionConfig", - w1_bias: Optional[torch.Tensor], - w2_bias: Optional[torch.Tensor], - ): + n_rows, num_topk = topk_ids.size() + + BLOCK_SIZE_M = 512 + BLOCK_SIZE_K = 32 + + bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks + bitmatrix = torch.zeros( + (n_rows, bm_cols), dtype=torch.uint32, device=topk_ids.device + ) + + grid = (triton.cdiv(n_rows, BLOCK_SIZE_M),) + pack_bitmatrix[grid]( + bitmatrix, + topk_ids, + n_rows, + bm_cols, + num_topk, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + bitmatrix_shape = [n_rows, bm_cols * 32] + bitmatrix_shape_max = [n_rows, None] + bitmatrix = Bitmatrix( + bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None + ) + + # matmul_ogs expects invalid topk_weights to be -1s + topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) + routing_data, gather_indx, scatter_indx = routing_from_bitmatrix( + 
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk + ) + + return routing_data, gather_indx, scatter_indx + + +class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Weight application and reduction happens in the fused_experts kernel. + return TopKWeightAndReduceNoOP() + + def _make_routing_data( + self, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + num_local_experts: int, + ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]: + return make_routing_data(topk_ids, topk_weights, num_local_experts) + + +class OAITritonExperts(BaseOAITritonExperts): + def __init__(self, quant_config: FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - self.w1_precision = w1_precision - self.w2_precision = w2_precision - self.w1_bias = w1_bias - self.w2_bias = w2_bias @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: - return False - - def supports_expert_map(self) -> bool: - return False - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return True def workspace_shapes( - self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, - topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata] - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel - assert a.dim() == 2 - num_dp = self.num_dispatchers - num_experts = local_num_experts - max_num_tokens = self.max_num_tokens - workspace2 = (0, 0, 0) - output = (num_experts, max_num_tokens * num_dp, N) - return (output, workspace2, output, a.dtype) + workspace1 = (M, K) + workspace2 = (0, 0) + output = (M, K) + return (workspace1, workspace2, output) def apply( self, @@ -212,10 +280,6 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -223,25 +287,31 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - return triton_kernel_fused_experts( - output, + if expert_map is not None: + topk_ids = expert_map[topk_ids] + + local_num_experts = w1.size(0) + if global_num_experts == -1: + global_num_experts = local_num_experts + 
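[Editor's note] A toy illustration (made-up sizes and values) of the global-to-local expert-id remapping performed just above, together with the `-1` weighting that `make_routing_data` applies so that `matmul_ogs` treats entries routed to non-resident experts as invalid:

```python
import torch

# Hypothetical setup: 4 global experts, experts 2 and 3 live on this rank.
expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)
topk_ids = torch.tensor([[0, 2], [3, 1]])              # global expert ids per token
topk_weights = torch.tensor([[0.6, 0.4], [0.7, 0.3]])

local_ids = expert_map[topk_ids]                       # [[-1, 0], [1, -1]]
# Mirrors the `torch.where(topk_ids == -1, -1.0, topk_weights)` step in
# make_routing_data: matmul_ogs expects invalid entries to carry weight -1.
local_weights = torch.where(local_ids == -1, -1.0, topk_weights)
# local_weights -> [[-1.0, 0.4], [0.7, -1.0]]
```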
+ routing_data, gather_indx, scatter_indx = self._make_routing_data( + topk_ids, topk_weights, local_num_experts + ) + + experts_output = triton_kernel_fused_experts( + None, hidden_states, w1, w2, - None, - None, - None, + routing_data, + gather_indx, + scatter_indx, activation=activation, + quant_config=self.quant_config, apply_router_weight_on_input=False, - use_fp8_w8a8=False, - per_channel_quant=False, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=self.w1_bias, - w2_bias=self.w2_bias, - w1_precision=self.w1_precision, - w2_precision=self.w2_precision, - a1_scale=a1q_scale, - a2_scale=a2_scale) + global_num_experts=local_num_experts, + expert_map=None, # applied already + a1q_scale=a1q_scale, + ) + + output.copy_(experts_output, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fcc6987d26bb2..94a733aa03b93 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,8 +3,9 @@ from abc import abstractmethod from collections.abc import Iterable +from contextlib import nullcontext from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, get_args, overload import torch import torch.nn.functional as F @@ -12,51 +13,82 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs from vllm.config import get_current_vllm_config -from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.config.parallel import ExpertPlacementStrategy +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -# yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) -# yapf: enable + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + biased_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) + FusedMoEActivationFormat, + FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.fused_moe.routing_simulator import ( - RoutingSimulator) + is_rocm_aiter_moe_enabled, +) +from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, - round_up) +from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up +from 
vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts, fused_experts + from .fused_moe import TritonExperts, eplb_map_to_physical_and_record, fused_experts + if has_pplx(): - from .pplx_prepare_finalize import (PplxPrepareAndFinalize, - pplx_hidden_dim_scale_bytes) + from .pplx_prepare_finalize import ( + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, - DeepEPLLPrepareAndFinalize) + from .deepep_ll_prepare_finalize import ( + DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize, + ) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPrepareAndFinalize = None # type: ignore + + def _eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype], + ) -> torch.Tensor: + # CPU fallback: no EPLB so just return as is + return topk_ids + + eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record + if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk) -elif current_platform.is_cpu(): - pass + rocm_aiter_grouped_topk as grouped_topk, + ) else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): @@ -75,18 +107,23 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - - # TODO(bnell): also pass quant_config? 
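[Editor's note] Across this patch, the per-format boolean flags (`use_fp8_w8a8`, `use_int8_w8a16`, ...) are folded into a single `FusedMoEQuantConfig` that the method later stores as `self.moe_quant_config`. A minimal sketch of the new call shape for the unquantized case, using the import paths and signatures shown in the hunks above (usage is illustrative, not a complete forward pass):

```python
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.model_executor.layers.fused_moe.fused_moe import (
    TritonExperts,
    modular_triton_fused_moe,
)

# One config object now carries the quant dtype, per-act-token flag, block
# shape, scales and biases instead of a long list of keyword arguments.
mk_kernel = modular_triton_fused_moe(quant_config=FUSED_MOE_UNQUANTIZED_CONFIG)
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
```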
def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe - self.fused_experts: Optional[Callable] = None + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None + self.fused_experts: Optional[FusedMoEModularKernel] = None self.topk_indices_dtype = None @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): raise NotImplementedError def uses_weight_scale_2_pattern(self) -> bool: @@ -101,23 +138,27 @@ class FusedMoEMethodBase(QuantizeMethodBase): @staticmethod def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: + moe: FusedMoEConfig, + quant_config: Optional[FusedMoEQuantConfig], + ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None - assert not moe.use_flashinfer_cutlass_kernels, \ - "Must be created in modelopt.py" + # TODO: could allow this now + assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" if moe.use_pplx_kernels: + assert quant_config is not None + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, moe.in_dtype, - moe.quant_dtype, - per_act_token_quant=moe.per_act_token_quant, - block_shape=moe.block_shape, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, ) all_to_all_args = dict( @@ -133,13 +174,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): hidden_dim_scale_bytes=hidden_scale_bytes, ) - num_dispatchers = (all2all_manager.world_size // - all2all_manager.tp_group.world_size) + num_dispatchers = ( + all2all_manager.world_size // all2all_manager.tp_group.world_size + ) # Intranode pplx a2a takes a group name while internode does not. if not all2all_manager.internode: - all_to_all_args[ - "group_name"] = all2all_manager.cpu_group.group_name + all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name handle = all2all_manager.get_handle(all_to_all_args) @@ -158,27 +199,26 @@ class FusedMoEMethodBase(QuantizeMethodBase): handle, num_dispatchers=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, - rank_expert_offset=all2all_manager.rank * - moe.num_local_experts, + rank_expert_offset=all2all_manager.rank * moe.num_local_experts, ) elif moe.use_deepep_ll_kernels: + assert quant_config is not None all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, num_ep_ranks=all2all_manager.world_size, num_global_experts=moe.num_experts, - num_local_experts=moe.num_experts // - all2all_manager.world_size) + num_local_experts=moe.num_experts // all2all_manager.world_size, + ) handle = all2all_manager.get_handle(all_to_all_args) - # Note : We may want to use FP8 dispatch even otherwise just to - # reduce datamovement - use_fp8_dispatch = (moe.quant_config is not None - and moe.quant_config.quant_dtype - == current_platform.fp8_dtype() - and moe.quant_config.block_shape - == DEEPEP_QUANT_BLOCK_SHAPE) + # Note: We may want to use FP8 dispatch just to reduce + # data movement. 
+ use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE + ) prepare_finalize = DeepEPLLPrepareAndFinalize( handle, @@ -189,44 +229,63 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[FusedMoEPrepareAndFinalize]: - if moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize( + self.moe, self.moe_quant_config + ) else: return None # Note: init_prepare_finalize should only be called by # prepare_communication_buffer_for_model. - def init_prepare_finalize(self): + def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None - prepare_finalize = self.maybe_make_prepare_finalize(self.moe) + + # We must get the quant config here so that the layer is + # completely initialized, i.e. all weights loaded and post + # processed. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + prepare_finalize = self.maybe_make_prepare_finalize() if prepare_finalize is not None: - logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, - self, id(self)) + logger.debug( + "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) + ) assert self.topk_indices_dtype is None - assert self.fused_experts is None, \ + assert self.fused_experts is None, ( f"Attempt to override experts for {id(self)}!" + ) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe) + experts = self.select_gemm_impl(prepare_finalize, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation raise NotImplementedError( f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize") + "implementation based on the prepare_finalize" + ) + + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + raise NotImplementedError + + @property + def using_modular_kernel(self) -> bool: + return self.fused_experts is not None @abstractmethod def apply( @@ -243,6 +302,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -250,7 +310,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -260,77 +320,143 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.has_bias = self.moe.has_bias 
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts else: self.rocm_aiter_fused_experts = None # type: ignore + # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS + self.flashinfer_cutlass_moe_enabled = ( + has_flashinfer_cutlass_fused_moe() + and envs.VLLM_USE_FLASHINFER_MOE_FP16 + and self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + and current_platform.get_device_capability()[0] >= 9 + ) + if self.flashinfer_cutlass_moe_enabled: + logger.info_once( + "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" + ) + from functools import partial + + from .flashinfer_cutlass_moe import flashinfer_cutlass_moe + + self.flashinfer_cutlass_moe = partial( + flashinfer_cutlass_moe, + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + tp_rank=self.moe.moe_parallel_config.tp_rank, + tp_size=self.moe.moe_parallel_config.tp_size, + ep_rank=self.moe.moe_parallel_config.ep_rank, + ep_size=self.moe.moe_parallel_config.ep_size, + ) + else: + if ( + self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + ): + logger.info_once( + "FlashInfer CUTLASS MoE is available for EP" + " but not enabled, consider setting" + " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it." + ) + elif self.moe.moe_parallel_config.dp_size > 1: + logger.info_once( + "FlashInfer CUTLASS MoE is currently not available for DP." + ) + self.flashinfer_cutlass_moe = None # type: ignore + + def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - # TODO(bnell): Remove. Every layer should have an moe config object. 
- moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): + assert self.moe_quant_config is not None + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts() + return TritonExperts(self.moe_quant_config) - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - if self.has_bias: - w13_bias = torch.nn.Parameter(torch.zeros( + w13_weight = torch.nn.Parameter( + torch.empty( num_experts, 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype + ), + requires_grad=False, + ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - if self.has_bias: - w2_bias = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=params_dtype), - requires_grad=False) + if self.moe.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. 
This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): + if ( + envs.VLLM_ROCM_MOE_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): num_pad = 256 // weight.element_size() weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] torch.cuda.empty_cache() + return weight def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -341,48 +467,62 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights) + shuffle_weights, + ) if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) layer.w13_weight.data = shuffled_w13 layer.w2_weight.data = shuffled_w2 + if self.flashinfer_cutlass_moe_enabled: + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + layer.w13_weight.data = w13_weight_swapped.contiguous() + if current_platform.is_xpu(): import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, use_prepack=True, ) elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - from vllm.model_executor.layers.utils import ( - check_cpu_sgl_kernel) + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel + dtype_w13 = layer.w13_weight.dtype _, n_w13, k_w13 = layer.w13_weight.size() dtype_w2 = layer.w2_weight.dtype _, n_w2, k_w2 = layer.w2_weight.size() - if (envs.VLLM_CPU_SGL_KERNEL - and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) - and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): + if ( + envs.VLLM_CPU_SGL_KERNEL + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) + ): packed_w13_weight = torch.ops._C.convert_weight_packed( - layer.w13_weight) + layer.w13_weight + ) assert packed_w13_weight.size() == layer.w13_weight.size() layer.w13_weight.copy_(packed_w13_weight) del packed_w13_weight packed_w2_weight = torch.ops._C.convert_weight_packed( - layer.w2_weight) + layer.w2_weight + ) assert packed_w2_weight.size() == layer.w2_weight.size() layer.w2_weight.copy_(packed_w2_weight) layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) else: layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: - raise NotImplementedError("CPU MOE only supports x86 arch.") + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) def apply( self, @@ -398,6 +538,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -405,7 +546,7 @@ class 
UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -425,6 +566,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -434,6 +576,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logical_replica_count=logical_replica_count, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + def forward_cuda( self, layer: torch.nn.Module, @@ -448,6 +601,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -455,9 +609,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -467,16 +623,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, enable_eplb=enable_eplb, expert_map=expert_map, expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count) + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + ) if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts( + assert self.fused_experts is None + result = self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -484,12 +646,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_ids=topk_ids, expert_map=expert_map, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_cutlass_moe_enabled: + return self.flashinfer_cutlass_moe( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + 
topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + ) elif self.fused_experts is not None: - if self.has_bias: - raise ValueError( - "FusedMoEModularKernel does not support bias.") - return self.fused_experts( + if self.moe.has_bias: + raise ValueError("FusedMoEModularKernel does not support bias.") + result = self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -503,21 +675,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) else: assert fused_experts is not None - return fused_experts( + result = fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_bias=layer.w13_bias if self.has_bias else None, - w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, ) + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result + def forward_cpu( self, layer: torch.nn.Module, @@ -532,6 +711,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -539,12 +719,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): - if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not supported " - "for CPU.") + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for CPU.") return layer.cpu_fused_moe( layer, x, @@ -558,6 +740,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map, custom_routing_function, scoring_func, + routed_scaling_factor, e_score_correction_bias, apply_router_weight_on_input, activation, @@ -577,6 +760,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -584,12 +768,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): - if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not 
supported " - "for XPU.") + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for XPU.") assert custom_routing_function is None return layer.ipex_fusion( x, @@ -615,6 +801,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -622,7 +809,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -630,55 +817,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): assert apply_router_weight_on_input is False if scoring_func != "softmax": raise NotImplementedError( - "Only softmax scoring function is supported for TPU.") + "Only softmax scoring function is supported for TPU." + ) if e_score_correction_bias is not None: raise NotImplementedError( - "Expert score correction bias is not supported for TPU.") + "Expert score correction bias is not supported for TPU." + ) assert activation == "silu", f"{activation} is not supported for TPU." - if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not supported " - "for TPU.") - return fused_moe_pallas(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=top_k, - gating_output=router_logits, - global_num_experts=global_num_experts, - expert_map=expert_map, - renormalize=renormalize) + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." + ) + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for TPU.") + return fused_moe_pallas( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize, + ) if current_platform.is_tpu(): forward_native = forward_tpu elif current_platform.is_cpu(): forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu else: forward_native = forward_cuda def determine_expert_map( - ep_size: int, ep_rank: int, - global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]: + ep_size: int, + ep_rank: int, + global_num_experts: int, + expert_placement_strategy: ExpertPlacementStrategy = "linear", +) -> tuple[int, Optional[torch.Tensor]]: """ - Calculates how many experts should be assigned to each rank for EP and - creates a mapping from global to local expert index. Experts are - distributed evenly across ranks. Any remaining are assigned to the - last rank. 
+ Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. - Args: - ep_size (int): The size of the expert parallel group - global_num_experts (int): The total number of experts in the model. + Args: + ep_size: The size of the expert parallel group + ep_rank: The rank of the current process in the expert parallel + group + global_num_experts: The total number of experts in the model. + expert_placement_strategy: The expert placement strategy. - Returns: - tuple[int, Optional[torch.Tensor]]: A tuple containing: - - local_num_experts (int): The number of experts assigned - to the current rank. - - expert_map (Optional[torch.Tensor]): A tensor of shape - (global_num_experts,) mapping from global to local index. - Contains -1 for experts not assigned to the current rank. - Returns None if ep_size is 1. - """ + Returns: + tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ assert ep_size > 0 if ep_size == 1: return (global_num_experts, None) @@ -686,38 +890,101 @@ def determine_expert_map( # Distribute experts as evenly as possible to each rank. base_experts = global_num_experts // ep_size remainder = global_num_experts % ep_size - if ep_rank < remainder: - local_num_experts = base_experts + 1 - else: - local_num_experts = base_experts + local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts # Create a tensor of size num_experts filled with -1 - expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) - # Create a expert map for the local experts - start_idx = ep_rank * base_experts + min(ep_rank, remainder) - expert_map[start_idx:start_idx + local_num_experts] = torch.arange( - 0, local_num_experts, dtype=torch.int32) + expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32) + # Create an expert map for the local experts + if expert_placement_strategy == "linear": + start_idx = ep_rank * base_experts + min(ep_rank, remainder) + expert_map[start_idx : start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + elif expert_placement_strategy == "round_robin": + local_log_experts = torch.arange( + ep_rank, global_num_experts, ep_size, dtype=torch.int32 + ) + + expert_map[local_log_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + else: + raise ValueError( + "Unsupported expert placement strategy " + f"'{expert_placement_strategy}', expected one of " + f"{get_args(ExpertPlacementStrategy)}" + ) return (local_num_experts, expert_map) def get_compressed_expert_map(expert_map: torch.Tensor) -> str: """ - Compresses the expert map by removing any -1 entries. + Compresses the expert map by removing any -1 entries. - Args: - expert_map (torch.Tensor): A tensor of shape (global_num_experts,) - mapping from global to local index. Contains -1 for experts not - assigned to the current rank. + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. - Returns: - str: A string mapping from local to global index. 
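[Editor's note] A worked example of the placement logic in `determine_expert_map` above, with assumed sizes and the signature introduced in this patch: 10 experts over `ep_size=4` gives `base_experts=2` and `remainder=2`, so ranks 0 and 1 hold 3 experts and ranks 2 and 3 hold 2.

```python
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map

# Linear placement: rank 1 gets the contiguous block of experts 3, 4, 5.
n_local, linear_map = determine_expert_map(
    ep_size=4, ep_rank=1, global_num_experts=10,
    expert_placement_strategy="linear",
)
# n_local == 3; linear_map[3:6] == [0, 1, 2]; every other entry is -1.

# Round-robin placement: rank 1 gets every 4th expert, i.e. 1, 5, 9.
n_local, rr_map = determine_expert_map(
    ep_size=4, ep_rank=1, global_num_experts=10,
    expert_placement_strategy="round_robin",
)
# n_local == 3; rr_map[[1, 5, 9]] == [0, 1, 2]; every other entry is -1.
```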
- Using str to support hashing for logging once only. - """ + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ global_indices = torch.where(expert_map != -1)[0] local_indices = expert_map[global_indices] return ", ".join( f"{local_index.item()}->{global_index.item()}" - for local_index, global_index in zip(local_indices, global_indices)) + for local_index, global_index in zip(local_indices, global_indices) + ) + + +def maybe_roundup_hidden_size( + hidden_size: int, + act_dtype: torch.dtype, + quant_config: Optional[QuantizationConfig], + moe_parallel_config: FusedMoEParallelConfig, +) -> int: + """ + Given layer hidden size and MoE configurations, round up hidden_size + if necessary. + + Args: + hidden_size: Layer hidden-size + act_dtype: Data type of the layer activations. + quant_config: Fused MoE quantization configuration. + moe_parallel_config: Fused MoE parallelization strategy configuration. + + Return: + Rounded up hidden_size if rounding up is required based on the configs. + Original hidden size otherwise. + """ + + if moe_parallel_config.use_deepep_ht_kernels: + hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size, act_dtype + ) + + # we are padding globally so EP buffer allocation works + if quant_config and quant_config.get_name() == "mxfp4": + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, + get_mxfp4_backend, + ) + + current_mxfp4_backend = get_mxfp4_backend() + if ( + current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + ): + hidden_size = round_up(hidden_size, 128) + elif ( + current_platform.is_rocm() + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): + hidden_size = round_up(hidden_size, 256) + + return hidden_size @CustomOp.register("fused_moe") @@ -738,7 +1005,7 @@ class FusedMoE(CustomOp): intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel + renormalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. enable_eplb: Whether to enable expert parallelism load balancer. 
""" @@ -762,38 +1029,60 @@ class FusedMoE(CustomOp): prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, + is_sequence_parallel=False, + zero_expert_num: Optional[int] = 0, + zero_expert_type: Optional[str] = None, + expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, ): super().__init__() if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - tp_size_ = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - dp_size_ = (dp_size - if dp_size is not None else get_dp_group().world_size) - vllm_config = get_current_vllm_config() - self.moe_parallel_config: FusedMoEParallelConfig = ( - FusedMoEParallelConfig.make( - tp_size_=tp_size_, - dp_size_=dp_size_, - vllm_parallel_config=vllm_config.parallel_config)) + + # FIXME (varun): We should have a better way of inferring the activation + # datatype. This works for now as the tensor datatype entering the MoE + # operation is typically unquantized (i.e. float16/bfloat16). + if vllm_config.model_config is not None: + moe_in_dtype = vllm_config.model_config.dtype + else: + # TODO (bnell): This is a hack to get test_mixtral_moe to work + # since model_config is not set in the pytest test. + moe_in_dtype = params_dtype + + tp_size_ = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size + + self.is_sequence_parallel = is_sequence_parallel + self.sp_size = tp_size_ if is_sequence_parallel else 1 + + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=tp_size_, + dp_size_=dp_size_, + vllm_parallel_config=vllm_config.parallel_config, + ) self.global_num_experts = num_experts + num_redundant_experts + self.zero_expert_num = zero_expert_num + self.zero_expert_type = zero_expert_type - # we padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 - should_use_flashinfer_mxfp4) - if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): - hidden_size = round_up(hidden_size, 256) + # Expert mapping used in self.load_weights + self.expert_mapping = expert_mapping + + # Round up hidden size if needed. + hidden_size = maybe_roundup_hidden_size( + hidden_size, moe_in_dtype, quant_config, self.moe_parallel_config + ) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -810,25 +1099,58 @@ class FusedMoE(CustomOp): # Determine expert maps if self.use_ep: if self.enable_eplb: - assert self.global_num_experts % self.ep_size == 0, \ - "EPLB currently only supports even distribution of " \ + assert self.global_num_experts % self.ep_size == 0, ( + "EPLB currently only supports even distribution of " "experts across ranks." + ) else: - assert num_redundant_experts == 0, \ + assert num_redundant_experts == 0, ( "Redundant experts are only supported with EPLB." 
- self.local_num_experts, self.expert_map = determine_expert_map( + ) + + expert_placement_strategy = ( + vllm_config.parallel_config.expert_placement_strategy + ) + if expert_placement_strategy == "round_robin": + # TODO(Bruce): will support round robin expert placement with + # EPLB enabled in the future. + round_robin_supported = ( + (num_expert_group is not None and num_expert_group > 1) + and num_redundant_experts == 0 + and not self.enable_eplb + ) + + if not round_robin_supported: + logger.warning( + "Round-robin expert placement is only supported for " + "models with multiple expert groups and no redundant " + "experts. Falling back to linear expert placement." + ) + expert_placement_strategy = "linear" + + self.expert_map: Optional[torch.Tensor] + local_num_experts, expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + expert_placement_strategy=expert_placement_strategy, + ) + self.local_num_experts = local_num_experts + self.register_buffer("expert_map", expert_map) logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + "[EP Rank %s/%s] Expert parallelism is enabled. Expert " + "placement strategy: %s. Local/global" " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + " %s.", + self.ep_rank, + self.ep_size, + expert_placement_strategy, + self.local_num_experts, self.global_num_experts, - get_compressed_expert_map(self.expert_map)) + get_compressed_expert_map(self.expert_map), + ) else: - self.local_num_experts, self.expert_map = (self.global_num_experts, - None) + self.local_num_experts, self.expert_map = (self.global_num_experts, None) self.top_k = top_k @@ -844,48 +1166,49 @@ class FusedMoE(CustomOp): self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError("Only softmax scoring function is supported for " - "non-grouped topk.") + raise ValueError( + "Only softmax scoring function is supported for non-grouped topk." + ) - if vllm_config.model_config is not None: - model_dtype = vllm_config.model_config.dtype - else: - # TODO (bnell): This is a hack to get test_mixtral_moe to work - # since model_config is not set in the pytest test. 
- model_dtype = params_dtype - - moe = FusedMoEConfig.make(num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - has_bias=has_bias) + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=moe_in_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + has_bias=has_bias, + ) self.moe_config = moe + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. quant_method: Optional[QuantizeMethodBase] = None - quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None - else quant_config.get_quant_method(self, prefix)) + quant_method = ( + UnquantizedFusedMoEMethod(moe) + if quant_config is None + else quant_config.get_quant_method(self, prefix) + ) + if quant_method is None: + quant_method = UnquantizedFusedMoEMethod(moe) assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method if self.enable_eplb: - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8MoEMethod) - if not isinstance(quant_method, - (Fp8MoEMethod, UnquantizedFusedMoEMethod)): + from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod + + if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -893,22 +1216,23 @@ class FusedMoE(CustomOp): # quantization methods, so I'm leaving it for now. # If you plan to add support for more quantization methods, # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError("EPLB is only supported for FP8 " - "quantization for now.") + raise NotImplementedError( + "EPLB is only supported for FP8 quantization for now." 
+ ) moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, + "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod")): + if self.quant_method.__class__.__name__ in ( + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -916,19 +1240,30 @@ class FusedMoE(CustomOp): # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_parallel_config.use_flashinfer_cutlass_kernels): - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + + if self.use_dp_chunking: + states_shape: tuple[int, ...] + logits_shape: tuple[int, ...] # Note here we use `num_experts` which is logical expert count + if vllm_config.parallel_config.enable_dbo: + states_shape = (2, moe.max_num_tokens, self.hidden_size) + logits_shape = (2, moe.max_num_tokens, num_experts) + else: + states_shape = (moe.max_num_tokens, self.hidden_size) + logits_shape = (moe.max_num_tokens, num_experts) + + self.batched_hidden_states = torch.zeros( + states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) + self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None @property def tp_size(self): @@ -972,21 +1307,41 @@ class FusedMoE(CustomOp): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + return ( + self.moe_quant_config is not None + and self.moe_quant_config.quant_dtype == "nvfp4" + and self.moe_config.use_flashinfer_cutlass_kernels + ) + + @property + def use_dp_chunking(self) -> bool: + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) is enabled. 
+ return ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + ) def update_expert_map(self): # ep_size and ep_rank should already be updated assert self.expert_map is not None with self.expert_map.device: - self.local_num_experts, self.expert_map = determine_expert_map( + local_num_experts, expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + ) + self.local_num_experts = local_num_experts + self.register_buffer("expert_map", expert_map) - def _load_per_tensor_weight_scale(self, shard_id: str, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - expert_id: int): + def _load_per_tensor_weight_scale( + self, + shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int, + ): param_data = param.data # for per tensor weight quantization if shard_id in ("w1", "w3"): @@ -998,25 +1353,32 @@ class FusedMoE(CustomOp): elif shard_id == "w2": param_data[expert_id] = loaded_weight - def _load_combined_w13_weight_scale(self, shard_dim: int, - loaded_weight: torch.Tensor, - param: torch.Tensor, tp_rank: int): + def _load_combined_w13_weight_scale( + self, + shard_dim: int, + loaded_weight: torch.Tensor, + param: torch.Tensor, + tp_rank: int, + ): """ Load w13 weight scales assuming that w1 weight scales and w3 weight scales are stored in the same loaded_weight tensor. """ shard_size = param.shape[shard_dim] - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) param.copy_(loaded_weight) - def _load_model_weight_or_group_weight_scale(self, - shard_dim: int, - expert_data: torch.Tensor, - shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full_w2: bool = False): + def _load_model_weight_or_group_weight_scale( + self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full_w2: bool = False, + ): """ Load grouped weight scales for group quantization or model weights :param shard_dim: dimension to shard @@ -1029,47 +1391,58 @@ class FusedMoE(CustomOp): if shard_id == "w2": # In the case where we have actorder/g_idx, we do not partition the # w2 scales, as indicated by `load_full` argument, for all tp cases - self._load_w2(shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank, - load_full=load_full_w2) + self._load_w2( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2, + ) elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) - def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, - shard_dim: int, shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int): + def _load_per_channel_weight_scale( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + ): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) elif shard_id in ("w1", "w3"): - 
self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_w13(self, - expert_data: torch.Tensor, - shard_dim: int, - shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full: bool = False): + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + def _load_w13( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False, + ): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 if not load_full: - loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -1080,39 +1453,48 @@ class FusedMoE(CustomOp): expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) - def _load_w2(self, - expert_data: torch.Tensor, - shard_dim: int, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full: bool = False): - + def _load_w2( + self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False, + ): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] if not load_full: - loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) - def _load_single_value(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int): + def _load_single_value( + self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int + ): param_data = param.data # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight - def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): - + def _load_g_idx( + self, + shard_id: str, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + ): if shard_id == "w2": - self._load_w2(shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) + self._load_w2( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) else: assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) @@ -1123,27 +1505,36 @@ class FusedMoE(CustomOp): return self.expert_map[expert_id].item() @overload - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int, - return_success: Literal[False]) -> None: - ... + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: Literal[False], + ) -> None: ... 
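# A minimal sketch of how gate_proj ("w1") and up_proj ("w3") land in the two halves
# of the merged w13 parameter, as _load_w13 does above. Tensor-parallel narrowing is
# omitted (the load_full=True case); shapes are illustrative.
import torch

intermediate, hidden = 6, 4
w13 = torch.zeros(2 * intermediate, hidden)      # merged gate/up weight for one expert
gate_proj = torch.ones(intermediate, hidden)     # checkpoint "w1"
up_proj = 2 * torch.ones(intermediate, hidden)   # checkpoint "w3"

shard_dim = 0
shard_size = w13.shape[shard_dim] // 2
w13.narrow(shard_dim, 0, shard_size).copy_(gate_proj)         # w1 -> first half
w13.narrow(shard_dim, shard_size, shard_size).copy_(up_proj)  # w3 -> second half
assert torch.equal(w13[:intermediate], gate_proj) and torch.equal(w13[intermediate:], up_proj)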
@overload - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int, - return_success: Literal[True]) -> bool: - ... - - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - return_success: bool = False) -> Optional[bool]: + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: Literal[True], + ) -> bool: ... + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False, + ) -> Optional[bool]: if self.quant_config and self.quant_config.get_name() == "mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: @@ -1166,13 +1557,13 @@ class FusedMoE(CustomOp): # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality if self.quant_method.__class__.__name__ in ( - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod"): + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): loaded_weight = loaded_weight.t().contiguous() if shard_id not in ("w1", "w2", "w3"): - raise ValueError(f"shard_id must be ['w1','w2','w3'] but " - f"got {shard_id}.") + raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.") # Fetch the dim to shard the parameter/loaded weight # based on the shard id. This will be whatever @@ -1234,43 +1625,49 @@ class FusedMoE(CustomOp): # this is needed for compressed-tensors only loaded_weight = loaded_weight.to(param.data.device) - if ("compressed" in quant_method_name.lower() - and param.data[expert_id] != 1 - and (param.data[expert_id] - loaded_weight).abs() > 1e-5): + if ( + "compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5 + ): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") + f"vs. {loaded_weight}" + ) - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) return True if return_success else None # Case g_idx if "g_idx" in weight_name: - self._load_g_idx(shard_dim=0, - shard_id=shard_id, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank) + self._load_g_idx( + shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + ) return True if return_success else None # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: # Determine per-tensor weight scale patterns based on variant # Use the dedicated method instead of brittle string matching - uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( - ) + uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern() # Call _load_per_tensor_weight_scale() to load per-tensor (scalar) # weights scales. # Input scales are always per-tensor. # Weight scales: FP4 uses "weight_scale_2" and FP8 uses # "weight_scale" for per-tensor scales. 
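# A compact, self-contained sketch of the typing pattern used for weight_loader above:
# @overload with Literal[...] on return_success lets type checkers narrow the return
# type to None or bool depending on the flag the caller passes. The load() function
# and its matching rule are purely illustrative.
from typing import Literal, Optional, overload

@overload
def load(name: str, return_success: Literal[False]) -> None: ...
@overload
def load(name: str, return_success: Literal[True]) -> bool: ...

def load(name: str, return_success: bool = False) -> Optional[bool]:
    matched = "w13" in name   # stand-in for the real matching logic
    return matched if return_success else None

assert load("experts.w13_weight", return_success=True) is True
assert load("experts.w13_weight", return_success=False) is None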
- is_per_tensor = ("weight_scale_2" in weight_name - if uses_weight_scale_2 else "weight_scale" - in weight_name) or "input_scale" in weight_name + is_per_tensor = ( + "weight_scale_2" in weight_name + if uses_weight_scale_2 + else "weight_scale" in weight_name + ) or "input_scale" in weight_name if is_per_tensor: self._load_per_tensor_weight_scale( shard_id=shard_id, @@ -1305,12 +1702,12 @@ class FusedMoE(CustomOp): shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) return True if return_success else None # Case weight scales, zero_points and offset, weight/input global scales - if ("scale" in weight_name or "zero" in weight_name - or "offset" in weight_name): + if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported @@ -1323,10 +1720,11 @@ class FusedMoE(CustomOp): shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) elif quant_method in [ - FusedMoeWeightScaleSupported.GROUP.value, - FusedMoeWeightScaleSupported.BLOCK.value, + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, @@ -1334,26 +1732,28 @@ class FusedMoE(CustomOp): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank, - load_full_w2=getattr(param, "load_full_w2", False)) + load_full_w2=getattr(param, "load_full_w2", False), + ) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) else: - WEIGHT_SCALE_SUPPORTED = [ - e.value for e in FusedMoeWeightScaleSupported - ] + WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported] raise ValueError( - f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}" + ) return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: # only required by compressed-tensors - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) return True if return_success else None # Case model weights @@ -1363,11 +1763,45 @@ class FusedMoE(CustomOp): shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) return True if return_success else None return False if return_success else None + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[str]: + if (expert_mapping := self.expert_mapping) is None: + raise ValueError( + "`self.expert_mapping` must be provided to " + "load weights using `self.load_weights`." 
+ ) + for expert_name, loaded_weight in weights: + qual_name = f"{self.layer_name}.{expert_name}" + for param_name, weight_name, expert_id, shard_id in expert_mapping: + if weight_name not in qual_name: + continue + weight_name = qual_name.replace(weight_name, param_name) + param_name = weight_name.removeprefix(f"{self.layer_name}.") + param = getattr(self, param_name) + success = self.weight_loader( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + logger.debug( + "Loaded %s for expert %d into %s", + param_name, + expert_id, + self.layer_name, + ) + yield param_name + def get_expert_weights(self) -> Iterable[torch.Tensor]: weights = list(self.named_parameters()) assert all(weight.is_contiguous() for _, weight in weights) @@ -1380,8 +1814,11 @@ class FusedMoE(CustomOp): } return [ - weight.view(self.local_num_experts, -1) for name, weight in weights + weight.view(self.local_num_experts, -1) + for name, weight in weights if name not in NON_EXPERT_WEIGHTS + and weight.shape != torch.Size([]) + and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1401,6 +1838,12 @@ class FusedMoE(CustomOp): self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + def ensure_moe_quant_config(self): + if self.quant_method.moe_quant_config is None: + self.quant_method.moe_quant_config = ( + self.quant_method.get_fused_moe_quant_config(self) + ) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1412,6 +1855,7 @@ class FusedMoE(CustomOp): num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, @@ -1419,30 +1863,38 @@ class FusedMoE(CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + global_num_experts: Optional[int] = None, + zero_expert_num: Optional[int] = None, + zero_expert_type: Optional[str] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the router logits. Returns: - (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): - The weights and *global physical* expert ids of the top-k experts. + (topk_weights, topk_ids, zero_expert_result) + (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + The weights, expert ids, and zero expert computation result. **Compatibility**: When EPLB is not enabled, the returned ids are equivalent to global logical ids, so should be compatible with plain MoE implementations without redundant experts. 
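# A small sketch of the name remapping done in load_weights above: a checkpoint key
# such as "experts.0.gate_proj.weight" is matched against the
# (param_name, weight_name, expert_id, shard_id) entries of expert_mapping and rewritten
# to the fused parameter name "experts.w13_weight". The layer name and mapping entries
# below are illustrative.
layer_name = "model.layers.0.mlp"
expert_mapping = [
    ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
    ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
]

def remap(expert_name: str) -> list[tuple[str, int, str]]:
    qual_name = f"{layer_name}.{expert_name}"
    hits = []
    for param_name, weight_name, expert_id, shard_id in expert_mapping:
        if weight_name not in qual_name:
            continue
        fused = qual_name.replace(weight_name, param_name)
        hits.append((fused.removeprefix(f"{layer_name}."), expert_id, shard_id))
    return hits

assert remap("experts.0.gate_proj.weight") == [("experts.w13_weight", 0, "w1")]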
""" - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, + fused_topk_bias, + ) # Check if we should use a routing simulation strategy routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY if routing_strategy != "": - return RoutingSimulator.simulate_routing( + topk_weights, topk_ids = RoutingSimulator.simulate_routing( hidden_states=hidden_states, router_logits=router_logits, strategy_name=routing_strategy, top_k=top_k, - indices_type=indices_type) + indices_type=indices_type, + ) # DeepSeekv2 uses grouped_top_k if use_grouped_topk: @@ -1456,9 +1908,21 @@ class FusedMoE(CustomOp): num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + elif e_score_correction_bias is not None: + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=router_logits, + e_score_correction_bias=e_score_correction_bias.data, + topk=top_k, + renormalize=renormalize, + ) + if routed_scaling_factor is not None: + topk_weights *= routed_scaling_factor elif custom_routing_function is None: topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states=hidden_states, @@ -1472,7 +1936,8 @@ class FusedMoE(CustomOp): hidden_states=hidden_states, gating_output=router_logits, topk=top_k, - renormalize=renormalize) + renormalize=renormalize, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1481,59 +1946,33 @@ class FusedMoE(CustomOp): assert logical_to_physical_map is not None assert logical_replica_count is not None - # 1. Convert the logical expert ids to physical expert ids - # Directly select a random replica for each logical expert - - # TODO: maybe optimize this by using specified kernels, - # or compute pseudo-random indices by modulo - - # In case `indices_type` is not `torch.long` or `torch.int`, - # e.g. `torch.uint32` as required by dispatch/combine kernels - topk_ids_long = topk_ids.long() - replica_indices = ( - torch.rand_like(topk_ids, dtype=torch.float) * - logical_replica_count[topk_ids_long]).long().unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids_long].gather( - -1, replica_indices).squeeze(-1) - - topk_ids = physical_ids - - # 2. Record expert load metrics. - - # TODO(bowen): When using `FusedMoEModularKernel`, this - # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert - # token count, in some cases directly from the kernel. - # However, now there are many code paths not using - # the modular kernel, e.g. calling `fused_experts`, - # so we decide to keep the logic here. - # - # If later refactor moved all the MoE kernel calls - # to the modular kernel, we can move this logic there - # to achieve better efficiency. 
- - # `expert_load_view`: (num_physical_experts,) - - topk_ids_flatten = topk_ids.flatten() - - # Performance optimization: - # `masked_fill` is significantly faster than `masked_select` - invalid_mask = topk_ids_flatten < 0 - # Replace invalid expert ids with 0 (just a dummy position) - # to avoid out-of-bounds errors in scatter_add_ - index = topk_ids_flatten.masked_fill_(invalid_mask, 0) - # `src` is the valid mask, which is 1 for valid and 0 for invalid - src = ~invalid_mask - - expert_load_view.scatter_add_(dim=0, - index=index.long(), - src=src.to(expert_load_view)) - - topk_ids = topk_ids.to(dtype=indices_type) + topk_ids = eplb_map_to_physical_and_record( + topk_ids=topk_ids, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + indices_type=indices_type, + ) assert topk_ids.dtype == indices_type or indices_type is None - return topk_weights, topk_ids + # Compute zero expert result if needed + if ( + zero_expert_num is not None + and zero_expert_num > 0 + and zero_expert_type is not None + and global_num_experts is not None + ): + zero_expert_result = zero_experts_compute_triton( + expert_indices=topk_ids, + expert_scales=topk_weights, + num_experts=global_num_experts, + zero_expert_type=zero_expert_type, + hidden_states=hidden_states, + ) + else: + zero_expert_result = None + return topk_weights, topk_ids, zero_expert_result def must_reduce_shared_expert_outputs(self) -> bool: """ @@ -1548,67 +1987,129 @@ class FusedMoE(CustomOp): Therefore it is required that we reduce the shared_experts output early. """ - return (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels) + assert self.quant_method is not None + return ( + self.quant_method.fused_experts is not None + and self.quant_method.fused_experts.output_is_reduced() + ) - def maybe_all_reduce_tensor_model_parallel( - self, final_hidden_states: torch.Tensor): + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """ - The pplx combine kernel reduces across GPU ranks by default. + Some combine kernels reduce across GPU ranks by default. """ - if (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels): + if self.must_reduce_shared_expert_outputs(): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_native( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: - hidden_states = F.pad(hidden_states, - (0, self.hidden_size - og_hidden_states), - mode='constant', - value=0.0) - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. 
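# A toy-sized sketch of what eplb_map_to_physical_and_record does, mirroring the inline
# logic this diff factors out: pick a random replica of each routed logical expert and
# accumulate per-physical-expert token counts into expert_load_view. Shapes and values
# below are illustrative.
import torch

num_physical = 6
logical_to_physical_map = torch.tensor(
    [[0, 4], [1, 1], [2, 5], [3, 3]]          # logical expert -> physical replicas
)
logical_replica_count = torch.tensor([2, 1, 2, 1])
topk_ids = torch.tensor([[0, 2], [1, 3]])      # logical ids, shape (tokens, topk)
expert_load_view = torch.zeros(num_physical)

ids_long = topk_ids.long()
replica_idx = (
    torch.rand_like(topk_ids, dtype=torch.float) * logical_replica_count[ids_long]
).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[ids_long].gather(-1, replica_idx).squeeze(-1)

# Record load: invalid (< 0) ids are masked out rather than filtered, to stay vectorized.
flat = physical_ids.flatten()
valid = flat >= 0
expert_load_view.scatter_add_(0, flat.clamp_min(0), valid.to(expert_load_view))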
- if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward( - hidden_states, router_logits, - self.layer_name)[..., :og_hidden_states] + hidden_states = F.pad( + hidden_states, + (0, self.hidden_size - og_hidden_states), + mode="constant", + value=0.0, + ) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name + ) + return fused_output[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits + ) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name + ) + return ( + shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states], + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(hidden_states, router_logits) + + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype assert self.batched_router_logits.dtype == full_router_logits.dtype # Check size compatibility. - assert ( - self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)) - assert ( - self.batched_router_logits.size(-1) == full_router_logits.size(-1)) + assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1) + assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) - full_final_hidden_states = torch.empty_like(full_hidden_states) + self.ensure_moe_quant_config() + + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like(full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - assert (self.batched_hidden_states.size(0) # type: ignore - >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore - >= chunk_size) - staged_hidden_states = self.batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = self.batched_router_logits[: - chunk_size, :] # type: ignore + assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + # This is only true when DBO has been enabled in the config. 
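# A minimal sketch of the pad/unpad bracketing done in forward_native above: inputs
# narrower than the layer's hidden_size are zero-padded on the last dim before the
# fused op, and the padding is sliced off the result. Sizes are illustrative.
import torch
import torch.nn.functional as F

layer_hidden_size = 16
x = torch.randn(3, 12)                                   # og_hidden_states == 12
og = x.shape[-1]
if layer_hidden_size != og:
    x = F.pad(x, (0, layer_hidden_size - og), mode="constant", value=0.0)

out = x * 2.0                                            # stand-in for the fused MoE op
out = out[..., :og]                                      # drop the padding again
assert out.shape == (3, 12)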
+ # Both tensors will have an outer dimension for the ubatch id + if self.batched_hidden_states.dim() == 3: + assert self.batched_router_logits.dim() == 3 + batch_buffer_idx = dbo_current_ubatch_id() + batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[batch_buffer_idx, :] + else: + batched_hidden_states = self.batched_hidden_states + batched_router_logits = self.batched_router_logits + + assert ( + batched_hidden_states.size(0) # type: ignore + >= chunk_size + ) + assert ( + batched_router_logits.size(0) # type: ignore + >= chunk_size + ) + staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) + # If there are shared experts but we are not using a modular kernel, + # the shared experts must be called here + if ( + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + and self.shared_experts is not None + ): + shared_output = self.shared_experts(staged_hidden_states) + else: + shared_output = None + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1623,6 +2124,7 @@ class FusedMoE(CustomOp): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, enable_eplb=self.enable_eplb, @@ -1631,109 +2133,201 @@ class FusedMoE(CustomOp): logical_replica_count=self.logical_replica_count, ) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) + + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, tuple) + assert self.shared_experts is None + final_hidden_states, zero_expert_result = final_hidden_states + if zero_expert_result is not None: + final_hidden_states += zero_expert_result + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states, non_blocking=True + ) + else: + full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states[0], non_blocking=True + ) + full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states[1], non_blocking=True + ) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + # If the input to the MoE is sequence parallel then divide by sp_size + # to find the maximum number of tokens for any individual dispatcher. 
+ if self.is_sequence_parallel: + max_tokens_across_dispatchers = cdiv( + max_tokens_across_dispatchers, self.sp_size + ) + num_tokens = full_hidden_states.size(0) for chunk_idx, chunk_start_ in enumerate( - range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): + range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank) + ): chunk_start = chunk_start_ - chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dp) + chunk_end = min( + chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers + ) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank, - chunk_idx): - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= num_tokens) + with ctx.dp_metadata.chunked_sizes( + self.sp_size, moe_dp_chunk_size_per_rank, chunk_idx + ): + process_chunk( + chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens + ) - return full_final_hidden_states + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, full_fused_final_hidden_states) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None - # Route to the chunked forward path using the FlashInfer Cutlass kernel - # only when data parallelism (DP) is enabled. - use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 - and self.moe_parallel_config.use_flashinfer_cutlass_kernels) - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or use_flashinfer_cutlass_kernels): + + self.ensure_moe_quant_config() + + if self.use_dp_chunking: return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( - self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) - - # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, + self.dp_size > 1 and not self.quant_method.using_modular_kernel ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
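# A pure-Python sketch of the chunk scheduling in forward_impl_chunked above: every rank
# walks the same number of chunks, derived from the maximum token count across
# dispatchers, so the collectives inside each chunk stay matched; ranks that run out of
# local tokens keep issuing dummy chunks with skip_result_store=True. Numbers are
# illustrative.
max_tokens_across_dispatchers = 70     # largest per-rank token count in the DP group
chunk_size = 32
num_local_tokens = 40                  # this rank's token count

schedule = []
for chunk_start_ in range(0, max_tokens_across_dispatchers, chunk_size):
    chunk_start = min(chunk_start_, num_local_tokens - 1)
    chunk_end = min(chunk_start_ + chunk_size, max_tokens_across_dispatchers)
    chunk_end = min(chunk_end, num_local_tokens)
    schedule.append((chunk_start, chunk_end, chunk_start_ >= num_local_tokens))

# Three chunks on every rank; the last one on this rank is a dummy pass.
assert schedule == [(0, 32, False), (32, 40, False), (39, 40, True)]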
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if ( + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + and self.shared_experts is not None + ): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None - return final_hidden_states + ctx = get_forward_context() + sp_ctx = ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) + + with sp_ctx: + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits, self.is_sequence_parallel + ) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, + ) + + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, tuple) + final_hidden_states, zero_expert_result = final_hidden_states + + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states, self.is_sequence_parallel) + + if ( + not self.is_sequence_parallel + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is not None: + return ( + reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[1]), + ) + elif self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(final_hidden_states, torch.Tensor) + return reduce_output(final_hidden_states) + zero_expert_result + else: + return reduce_output(final_hidden_states) @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: - + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + num_redundant_experts: int = 0, + ) -> list[tuple[str, str, int, str]]: num_physical_experts = num_experts + num_redundant_experts # In the returned mapping: # - `expert_id` is the physical expert id # - `weight_name` contains the weight name of the logical expert # So that we should map the expert id to logical in `weight_name` - physical_to_logical_map = \ + physical_to_logical_map = ( 
EplbState.build_initial_global_physical_to_logical_map( - num_experts, num_redundant_experts) + num_experts, num_redundant_experts + ) + ) return [ # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", - expert_id, shard_id) for expert_id in range(num_physical_experts) + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_physical_experts) for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), @@ -1742,7 +2336,6 @@ class FusedMoE(CustomOp): ] def extra_repr(self) -> str: - s = ( f"global_num_experts={self.global_num_experts}, " f"local_num_experts={self.local_num_experts}, " @@ -1752,7 +2345,8 @@ class FusedMoE(CustomOp): f"ep_size={self.ep_size}, " f"reduce_results={self.reduce_results}, " f"renormalize={self.renormalize}, " - f"use_grouped_topk={self.use_grouped_topk}") + f"use_grouped_topk={self.use_grouped_topk}" + ) if self.use_grouped_topk: s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 @@ -1762,17 +2356,22 @@ class FusedMoE(CustomOp): return s -def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - assert self.quant_method is not None - + assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) -def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1781,8 +2380,37 @@ direct_register_custom_op( op_func=moe_forward, mutates_args=["hidden_states"], fake_impl=moe_forward_fake, - dispatch_key=current_platform.dispatch_key, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), +) + + +def moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.shared_experts is not None + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = torch.empty_like(hidden_states) + fused_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=moe_forward_shared, + mutates_args=["hidden_states"], + fake_impl=moe_forward_shared_fake, + tags=(torch.Tag.needs_fixed_stride_order,), ) # Mark the FusedMoE weight_loader as supporting MoE-specific parameters diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ea6383d5ae90..b5602a112ef13 100644 --- 
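# A toy expansion of the make_expert_params_mapping construction above, assuming no
# redundant experts so the physical-to-logical map is the identity. Each checkpoint
# weight name is paired with the fused parameter prefix ("experts.w13_" for gate/up,
# "experts.w2_" for down) plus its expert id and shard id.
def make_expert_params_mapping_demo(gate: str, down: str, up: str, num_experts: int):
    physical_to_logical = list(range(num_experts))   # identity when no redundancy
    return [
        (
            "experts.w13_" if weight_name in (gate, up) else "experts.w2_",
            f"experts.{physical_to_logical[expert_id]}.{weight_name}.",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id, weight_name in [("w1", gate), ("w2", down), ("w3", up)]
    ]

mapping = make_expert_params_mapping_demo("gate_proj", "down_proj", "up_proj", num_experts=2)
assert mapping[0] == ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
assert mapping[4] == ("experts.w2_", "experts.1.down_proj.", 1, "w2")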
a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,15 +4,25 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Callable, Optional, Union, final import torch import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable - _resize_cache, count_expert_num_tokens) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + count_expert_num_tokens, + disable_inplace, +) from vllm.utils import cdiv +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_maybe_run_recv_hook, + dbo_register_recv_hook, + dbo_yield, +) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -52,75 +62,38 @@ from vllm.utils import cdiv # -def _moe_problem_size( - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, -) -> tuple[int, int, int, int, int]: - """ - Extract the MoE problem size from the given tensor arguments: - - a: The hidden states, input to the MoE layer. - - w1: The first set of expert weights. - - w2: The second set of expert weights. - - topk_ids: The topk ids. - - Note: extracting the problem shape from the weight and activation tensors is - not obvious. It needs to be done this way specifically due to subtle issues - with particular kernels, e.g. the int4 kernels divide the trailing dimension - by two, so it's not "correct" to extract N or K from the trailing dimension - of w1 or w2. Similarly, some kernels transpose the weights, so this needs - to be kept in mind. - """ - assert w1.dim() == 3 and w2.dim() == 3 - E, N, _ = w1.size() - K = w2.size(1) - - if a1.dim() == 2: - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), \ - f"{topk_ids.size(0)} != {a1.size(0)}" - M = a1.size(0) - else: - assert a1.dim() == 3 - assert a1.size(0) == E, f"{a1.size(0)} == {E}" - M = a1.size(1) # This is max_num_tokens - - assert topk_ids.dim() == 2 - topk = topk_ids.size(1) - - return E, M, N, K, topk - - class FusedMoEActivationFormat(Enum): """ The standard activation format (num_tokens, hidden dim). """ - Standard = "standard", + + Standard = ("standard",) """ The batched experts format (num experts, max tokens per expert, hidden dim) """ - BatchedExperts = "batched_experts", + BatchedExperts = ("batched_experts",) @dataclass class ExpertTokensMetadata: """ - Metadata regarding expert-token routing. - """ + Metadata regarding expert-token routing. 
+ """ + expert_num_tokens: torch.Tensor expert_num_tokens_cpu: Optional[torch.Tensor] @staticmethod - def make_from_list(expert_num_tokens_list: list[int], - device: str) -> "ExpertTokensMetadata": - expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list, - device="cpu", - dtype=torch.int32) + def make_from_list( + expert_num_tokens_list: list[int], device: str + ) -> "ExpertTokensMetadata": + expert_num_tokens_cpu = torch.tensor( + expert_num_tokens_list, device="cpu", dtype=torch.int32 + ) return ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens_cpu.to(device, - non_blocking=True), - expert_num_tokens_cpu=expert_num_tokens_cpu) + expert_num_tokens=expert_num_tokens_cpu.to(device, non_blocking=True), + expert_num_tokens_cpu=expert_num_tokens_cpu, + ) class TopKWeightAndReduce(ABC): @@ -129,10 +102,14 @@ class TopKWeightAndReduce(ABC): """ @abstractmethod - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: """ Apply topk_weights to the fused_experts_outputs and/or reduce. If an output tensor is not passed, it will be created in the @@ -141,6 +118,29 @@ class TopKWeightAndReduce(ABC): raise NotImplementedError +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. +# - quantized + dispatched a1_scales. +# - Optional ExpertTokensMetadata containing gpu/cpu tensors +# as big as the number of local experts with the information about the +# number of tokens assigned to each local expert. +# - Optional dispatched expert topk IDs +# - Optional dispatched expert topk weight +# +# See `prepare` method below. +# +PrepareResultType = tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], +] + +ReceiverType = Callable[[], PrepareResultType] + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -152,24 +152,56 @@ class FusedMoEPrepareAndFinalize(ABC): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> PrepareResultType: """ - Perform any quantization (and/or) dispatching needed - for this kernel. + Perform any quantization (and/or) dispatching needed for this kernel. + - a1: The (unquantized) input to the MoE layer. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + - quant_config: Quantization info provided by the fused experts. + + Returns a tuple of: + - quantized + dispatched a. + - Optional quantized + dispatched a1_scales. 
+ - Optional ExpertTokensMetadata containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. + - Optional dispatched expert topk IDs + - Optional dispatched expert topk weight + """ + raise NotImplementedError + + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async and + finalize_async. + """ + return False + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> Union[tuple[Callable, ReceiverType], ReceiverType]: + """ + Perform any quantization (and/or) dispatching needed for this kernel + but do not wait for results from other workers. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make @@ -182,14 +214,26 @@ class FusedMoEPrepareAndFinalize(ABC): - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - Returns a tuple of: - - quantized + dispatched a. - - quantized + dispatched a1_scales. - - Optional ExpertTokensMetadata containing gpu/cpu tensors - as big as the number of local experts with the information about the - number of tokens assigned to each local expert. - - Optional dispatched expert topk IDs - - Optional dispatched expert topk weight + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as + `prepare`, if a hook is returned this is more lightweight check that + the recv is complete without doing extra work (used by DBO, will be + refactored in the very near future) + + e.g. + + ret = obj.prepare_async(...) + + if isinstance(ret, tuple): + hook, receiver = ret + hook() + + if hook is not None: + a, a_scales, expert_meta, topk_ids, topk_weights = receiver() + + is equivalent to: + + a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) """ raise NotImplementedError @@ -218,6 +262,48 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> Union[tuple[Callable, Callable], Callable]: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output but do not wait for results from other workers. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. + + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as + `finalize`, if a hook is returned this is more lightweight check that + the recv is complete without doing extra work (used by DBO, will be + refactored in the very near future) + + ret = obj.finalize_async(output, ...) + ... 
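# A self-contained sketch of the hook/receiver protocol documented for prepare_async and
# finalize_async above: the call may return either a bare receiver callable or a
# (hook, receiver) pair, where the hook is the cheap "wait for the recv to land" step and
# the receiver produces the same values as the synchronous prepare()/finalize(). The
# DummyPrepareFinalize class is purely illustrative.
from typing import Callable, Union

class DummyPrepareFinalize:
    def prepare_async(self) -> Union[tuple[Callable, Callable], Callable]:
        def hook():            # e.g. wait on an all-to-all recv
            pass
        def receiver():        # unpack the received buffers
            return "a1q", "a1q_scale", None, None, None
        return hook, receiver

def run_prepare(obj: DummyPrepareFinalize):
    ret = obj.prepare_async()
    if isinstance(ret, tuple):
        hook, receiver = ret
        hook()
    else:
        receiver = ret
    return receiver()

assert run_prepare(DummyPrepareFinalize())[0] == "a1q"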
output not valid yet ... + if isinstance(ret, tuple): + hook, receiver = ret + hook() + receiver() + ... output valid here ... + + is equivalent to: + + obj.finalize(output, ...) + """ + raise NotImplementedError + @property @abstractmethod def activation_format(self) -> FusedMoEActivationFormat: @@ -241,7 +327,7 @@ class FusedMoEPrepareAndFinalize(ABC): def max_num_tokens_per_rank(self) -> Optional[int]: """ Some PrepareFinalize All2All implementations are batched. Meaning, - they can processes only as set of tokens at a time. This + they can process only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens the implementation can process at a time. Return None if there are no such restrictions. @@ -252,7 +338,16 @@ class FusedMoEPrepareAndFinalize(ABC): def num_dispatchers(self) -> int: raise NotImplementedError + @abstractmethod + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of finalize is reduced across all + ranks. + """ + raise NotImplementedError + +# TODO: add supported activations method (return string) class FusedMoEPermuteExpertsUnpermute(ABC): """ An abstract base class for the [Permute-Experts-Unpermute] step described @@ -261,23 +356,72 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: FusedMoEQuantConfig, ): - if quant_config is not None: - self.quant_config = quant_config - else: - self.quant_config = FusedMoEQuantConfig() + """ + quant_config: Quantization parameters for this experts instance. + """ + self.quant_config = quant_config @property @abstractmethod def activation_formats( - self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + self, + ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: """ A property which is a tuple of the input and output activation formats for the 'apply' method. """ raise NotImplementedError + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation + tensors is not obvious. It needs to be done this way specifically + due to subtle issues with particular kernels, e.g. the int4 kernels + divide the trailing dimension by two, so it's not "correct" to + extract N or K from the trailing dimension of w1 or w2. Similarly, + some kernels transpose the weights, so this needs to be kept in mind. + + Note: This implementation covers most cases. However, if experts + require a specialized implementation, like MarlinExperts, they are free + to override this function. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = a1.size(-1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + + # + # Various helpers for accessing quantization parameters from the + # quant_config. 
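# A quick numerical check of moe_problem_size above for the standard 2-D activation case:
# E and N come from w1, K from the activation's trailing dim, M from the token dim, and
# topk from topk_ids. Shapes are arbitrary and assume the common unquantized weight
# layout (packed/transposed kernels differ, which is why K is not read from the weights).
import torch

E, N, K, M, topk = 8, 24, 16, 5, 2
a1 = torch.randn(M, K)
w1 = torch.randn(E, N, K)        # first expert GEMM weights
w2 = torch.randn(E, K, N // 2)   # second expert GEMM weights (layout varies by kernel)
topk_ids = torch.zeros(M, topk, dtype=torch.int64)

e, n, _ = w1.size()
k = a1.size(-1)
m = a1.size(0) if a1.dim() == 2 else a1.size(1)
assert (e, m, n, k, topk_ids.size(1)) == (E, M, N, K, topk)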
+ # + @property def quant_dtype(self) -> Optional[torch.dtype]: return self.quant_config.quant_dtype @@ -294,6 +438,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant + @property + def a1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_scale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_gscale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_scale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_zp + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_bias + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_bias + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g1_alphas + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g2_alphas + # TODO (bnell): make this return a CHUNK_SIZE or None instead? @abstractmethod def supports_chunking(self) -> bool: @@ -310,11 +502,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + """ + Workspace type: The dtype to use for the workspace tensors. + """ + return act_dtype + @abstractmethod def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -322,27 +518,39 @@ class FusedMoEPermuteExpertsUnpermute(ABC): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm. + Inputs: + - M: number of tokens. + - N: Row (or column) dimension of expert weights. + - K: hidden dimension + - topk: The number of top-k experts to select. + - global_num_experts: global number of experts. + - local_num_experts: local number of experts due to DP/EP. + - expert_tokens_meta: number of tokens per expert metadata for batched + format. + Returns a tuple of: - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - workspace2 shape tuple: must be large enough to hold the result of the activation function. - output shape tuple: must be exact size of the final gemm output. - - Workspace type: The dtype to use for the workspace tensors. - - Note: in order for activation chunking to work, the first dimension - of each tuple must be the number of tokens. + - Note: workspace shapes can be 0 if the workspace is not needed. + But in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens when the shape is + not 0. 
""" raise NotImplementedError - def activation(self, activation: str, output: torch.Tensor, - input: torch.Tensor) -> None: + def activation( + self, activation: str, output: torch.Tensor, input: torch.Tensor + ) -> None: assert output.size(-1) * 2 == input.size(-1) if activation == "silu": torch.ops._C.silu_and_mul(output, input) @@ -352,8 +560,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): raise ValueError(f"Unsupported FusedMoe activation: {activation}") def enable_chunking(self): - return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ - self.supports_chunking() + return ( + envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() + ) def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: raise NotImplementedError @@ -370,17 +579,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - ): + ) -> None: """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. @@ -392,7 +597,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights: A map of row to expert weights. Some implementations - choose to do weight application. + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. @@ -401,15 +606,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be - used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + used for a1. Result of quantization from prepare/finalize and not + from the FusedMoEQuantConfig. - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs must be large enough to hold output of either MoE gemm. 
- workspace2 (torch.Tensor): A scratch tensor used for the activation @@ -425,8 +624,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): raise NotImplementedError -def _chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def _slice_scales( + scales: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: return scales @@ -435,6 +635,25 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, return None +class SharedResizableBuffer: + def __init__(self): + self.buffer = None + + def get( + self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: + assert shape != () + shape_numel = prod(shape) + if ( + self.buffer is None + or self.buffer.numel() < shape_numel + or self.buffer.device != device + or self.buffer.dtype != dtype + ): + self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) + return self.buffer[:shape_numel].view(*shape) + + @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -449,94 +668,284 @@ class FusedMoEModularKernel(torch.nn.Module): objects. """ + class SharedBuffers: + def __init__(self) -> None: + self.fused_out = SharedResizableBuffer() + self.workspace13 = SharedResizableBuffer() + self.workspace2 = SharedResizableBuffer() + + # Persistent buffers that are shared across `FusedMoEModularKernel` + # instances (layers), to save memory and allocattions. + # + # We have two sets of buffers to support dual batch overlap (DBO) where each + # microbatch (ubatch) should use its own set of buffers to avoid + # cross-ubatch contimination. + # NOTE that memory is lazily allocated for these buffers, meaning that if + # DBO isn't being used, the second SharedBuffers will be empty. + shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] + def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, + shared_experts: Optional[torch.nn.Module] = None, ): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts - assert prepare_finalize.activation_format == \ - fused_experts.activation_formats[0], ( - f"{prepare_finalize.__class__.__name__}." - f"{prepare_finalize.activation_format} == " - f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_formats[0]}") + self.shared_experts = shared_experts + assert ( + prepare_finalize.activation_format == fused_experts.activation_formats[0] + ), ( + f"{prepare_finalize.__class__.__name__}." + f"{prepare_finalize.activation_format} == " + f"{fused_experts.__class__.__name__}." + f"{fused_experts.activation_formats[0]}" + ) - def _do_fused_experts( + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of fused MoE kernel + is reduced across all ranks. + """ + return self.prepare_finalize.output_is_reduced() + + def _chunk_info(self, M: int) -> tuple[int, int]: + """ + Compute number of chunks and chunk size for given M. + If chunking is not supported, set the CHUNK_SIZE to M so we + get num_chunks == 1. Take max(M, 1) to avoid divide by zero. + If there are no tokens to process, the number of chunks will be zero. + """ + CHUNK_SIZE = max( + 1, + ( + M + if not self.fused_experts.supports_chunking() + else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + ), + ) + num_chunks = cdiv(M, CHUNK_SIZE) + # If there are no tokens, then there should be no loop iterations. 
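# A short, self-contained usage demo of the SharedResizableBuffer pattern defined above:
# the flat backing storage only grows, and later requests that fit are served as views of
# the same allocation, which is what lets workspaces be shared across layers and chunks
# without reallocating.
import torch
from math import prod

class SharedResizableBuffer:
    def __init__(self):
        self.buffer = None

    def get(self, shape, device, dtype):
        numel = prod(shape)
        if (self.buffer is None or self.buffer.numel() < numel
                or self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(numel, device=device, dtype=dtype)
        return self.buffer[:numel].view(*shape)

buf = SharedResizableBuffer()
a = buf.get((4, 8), torch.device("cpu"), torch.float32)   # allocates 32 elements
b = buf.get((2, 8), torch.device("cpu"), torch.float32)   # reuses the same storage
assert a.data_ptr() == b.data_ptr()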
+ assert M > 0 or num_chunks == 0 + return num_chunks, CHUNK_SIZE + + def _allocate_buffers( self, - fused_out: Optional[torch.Tensor], - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, + out_dtype: torch.dtype, + device: torch.device, + M_chunk: int, + M_full: int, + N: int, + K: int, + top_k: int, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Allocate temporary and output buffers for the fused experts op. + Inputs: + - out_dtype: output type of workspace and output tensors. + - device: the device of the workspace and output tensors. + See `workspace_shapes` for a description of the remainder of arguments. + Returns a tuple of (workspace13, workspace2, output) tensors. + """ + assert M_full > 0 and M_chunk > 0 - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + num_chunks, _ = self._chunk_info(M_full) - (workspace13_shape, workspace2_shape, fused_out_shape, - workspace_dtype) = self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, - expert_tokens_meta) + # select per-ubatch buffers to avoid cross-ubatch reuse under DBO + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] + workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) + + # Get intermediate workspace shapes based off the chunked M size. + workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( + M_chunk, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + + # Get final output shape based on the full M size. + _, _, fused_out_shape = self.fused_experts.workspace_shapes( + M_full, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. 
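# Editor's illustration, not part of this patch: a standalone sketch of the
# grow-only reuse that SharedResizableBuffer.get() provides for the workspaces
# allocated below. Repeated requests that fit inside the persistent allocation
# return views over the same storage instead of allocating again.
import torch
from math import prod

class _ToyResizableBuffer:
    def __init__(self) -> None:
        self.buffer = None

    def get(self, shape, device, dtype):
        n = prod(shape)
        if (self.buffer is None or self.buffer.numel() < n
                or self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(n, device=device, dtype=dtype)
        return self.buffer[:n].view(*shape)

_buf = _ToyResizableBuffer()
a = _buf.get((4, 8), torch.device("cpu"), torch.float32)  # allocates 32 elements
b = _buf.get((2, 8), torch.device("cpu"), torch.float32)  # reuses that storage
assert a.data_ptr() == b.data_ptr()  # both views alias the same allocation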
- workspace13 = torch.empty(prod(workspace13_shape), - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device=a1.device, - dtype=workspace_dtype) - - assert fused_out is None or fused_out.shape == fused_out_shape, ( - f"fused_out {fused_out.shape} but expected {fused_out_shape}") - if fused_out is None: - # reuse workspace13 for the output - fused_out = _resize_cache(workspace13, fused_out_shape) - - self.fused_experts.apply( - fused_out, - a1q, - w1, - w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, + workspace13 = buffers.workspace13.get( + workspace13_shape, device=device, dtype=workspace_dtype + ) + workspace2 = buffers.workspace2.get( + workspace2_shape, device=device, dtype=workspace_dtype ) - return fused_out + # Construct the entire output that can then be processed in chunks. + # Reuse workspace13 for the output in the non-chunked case as long + # as it is large enough. This will not always be the case for standard + # format experts and with experts that have empty workspaces. + if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + fused_out = _resize_cache(workspace13, fused_out_shape) + else: + fused_out = buffers.fused_out.get( + fused_out_shape, device=device, dtype=out_dtype + ) - def _maybe_chunk_fused_experts( + return workspace13, workspace2, fused_out + + @staticmethod + def _slice_output_tensor( + fused_out: torch.Tensor, + chunk_idx: int, + num_chunks: int, + CHUNK_SIZE: int, + M: int, + ) -> torch.Tensor: + if num_chunks == 1: + return fused_out + + assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}" + factor = fused_out.size(0) // M + out_chunk_size = CHUNK_SIZE * factor + s = chunk_idx * out_chunk_size + e = min(s + out_chunk_size, fused_out.size(0)) + return fused_out[s:e] + + @staticmethod + def _slice_expert_tokens_metadata( + num_chunks: int, + full_expert_tokens_meta: Optional[ExpertTokensMetadata], + chunk_topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Optional[ExpertTokensMetadata]: + if num_chunks == 1 or full_expert_tokens_meta is None: + return full_expert_tokens_meta + + # The existing expert_num_tokens is for the entire a1q + # input. Chunking forces recomputation of the number + # of tokens assigned to each expert. 
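# Editor's illustration, not part of this patch: a standalone sketch of the
# per-chunk recount described just below, assuming count_expert_num_tokens is
# roughly a masked bincount: map global expert ids to local ones, drop ids
# routed to other EP ranks (-1), and count what remains for this chunk.
import torch

def _toy_count_expert_num_tokens(topk_ids, local_num_experts, expert_map=None):
    ids = topk_ids.reshape(-1)
    if expert_map is not None:
        ids = expert_map[ids]            # global -> local (-1 if not local)
    ids = ids[ids >= 0]
    return torch.bincount(ids, minlength=local_num_experts)

_topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]])
_expert_map = torch.tensor([0, 1, -1, -1])   # only experts 0 and 1 are local
assert torch.equal(
    _toy_count_expert_num_tokens(_topk_ids, 2, _expert_map), torch.tensor([2, 1])
)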
+ c_expert_num_tokens = count_expert_num_tokens( + chunk_topk_ids, local_num_experts, expert_map + ) + + c_expert_num_tokens_cpu = None + need_expert_num_tokens_cpu = ( + full_expert_tokens_meta.expert_num_tokens_cpu is not None + ) + if need_expert_num_tokens_cpu: + # This is blocking as some implementations need the count + # on the CPU to determine appropriate input/out fused-moe + # buffers + c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False) + + return ExpertTokensMetadata( + expert_num_tokens=c_expert_num_tokens, + expert_num_tokens_cpu=c_expert_num_tokens_cpu, + ) + + def _prepare( self, - a1: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + torch.Tensor, + torch.Tensor, + ]: + """ + The _prepare method is a wrapper around self.prepare_finalize.prepare + that handles DBO and async. + """ + if not self.prepare_finalize.supports_async(): + # We shouldn't be running an a2a kernel that doesn't + # support async prepare/finalize + # TODO(lucas): enable in follow-up + assert not dbo_enabled() + + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = self.prepare_finalize.prepare( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + dbo_maybe_run_recv_hook() + prepare_ret = self.prepare_finalize.prepare_async( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = ( + prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) + ) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = receiver() + + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
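# Editor's illustration, not part of this patch: a standalone sketch of the
# hook/receiver convention the async branch above relies on. prepare_async (and
# finalize_async) may return either a bare receiver callable or a
# (hook, receiver) pair, where the hook finishes the communication and the
# receiver materializes the results; DBO registers the hook instead of calling
# it inline.
from typing import Callable, Optional, Union

def _toy_run_split_phase(ret: Union[Callable, tuple[Callable, Callable]]):
    hook: Optional[Callable]
    hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
    if hook is not None:
        hook()          # e.g. complete the recv half of an all2all
    return receiver()   # e.g. (a1q, a1q_scale, expert_tokens_meta, ids, weights)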
+ topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids + topk_weights = ( + topk_weights if _expert_topk_weights is None else _expert_topk_weights + ) + + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights + + def _fused_experts( + self, + in_dtype: torch.dtype, a1q: torch.Tensor, + a1q_scale: Optional[torch.Tensor], w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -545,135 +954,154 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, + expert_tokens_meta: Optional[ExpertTokensMetadata], ) -> torch.Tensor: + _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( + a1q, w1, w2, topk_ids + ) - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + num_chunks, CHUNK_SIZE = self._chunk_info(M_full) - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - num_chunks = cdiv(M, CHUNK_SIZE) + def input_chunk_range(chunk_idx: int) -> tuple[int, int]: + if num_chunks == 1: + # Use a1q.size(0) here since batched format does not + # keep M in the first dimension. + return 0, a1q.size(0) + else: + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M_full) + return s, e - # TODO(bnell): get rid of one level here, update slice functions - # to nops on num_chunks==1 - - if not self.fused_experts.supports_chunking() or num_chunks == 1: - return self._do_fused_experts( - fused_out=None, - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, + # This happens when none of the tokens from the all2all reach this + # EP rank. Also, note that this is only relevant for CUDAGraph + # incompatible all2all kernels like the DeepEP high-throughput + # kernels. CUDAGraph compatible all2all kernels like the pplx + # kernels and the DeepEP low-latency kernels are always batched + # and can never run into the tensor.numel() == 0 case. + if M_full == 0: + assert num_chunks == 0 + workspace13 = None + workspace2 = None + fused_out = torch.empty_like(a1q) + else: + assert num_chunks > 0 + workspace13, workspace2, fused_out = self._allocate_buffers( + in_dtype, + a1q.device, + CHUNK_SIZE, + M_full, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, ) - # Chunking required case - assert num_chunks > 1 - - # Construct the entire output that can then be processed in chunks. 
- (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, - expert_tokens_meta) - fused_out = torch.empty(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) - - def slice_input_tensors( - chunk_idx: int - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, torch.Tensor]: - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - return (a1q[s:e], _chunk_scales(a1q_scale, s, e), - _chunk_scales(a2_scale, s, - e), topk_ids[s:e], topk_weights[s:e]) - - def slice_output_tensor(chunk_idx: int) -> torch.Tensor: - assert fused_out.size(0) % M == 0, ( - f"fused_out shape {fused_out.shape} vs M {M}") - factor = fused_out.size(0) // M - out_chunk_size = CHUNK_SIZE * factor - s = chunk_idx * out_chunk_size - e = min(s + out_chunk_size, fused_out.size(0)) - return fused_out[s:e] - - def slice_expert_tokens_metadata( - full_expert_tokens_meta: ExpertTokensMetadata, - chunk_topk_ids: torch.Tensor, local_num_experts: int, - expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata: - # The existing expert_num_tokens is for the entire a1q - # input. Chunking forces recomputation of the number - # of tokens assigned to each expert. - c_expert_num_tokens = count_expert_num_tokens( - chunk_topk_ids, local_num_experts, expert_map) - - c_expert_num_tokens_cpu = None - need_expert_num_tokens_cpu = ( - full_expert_tokens_meta.expert_num_tokens_cpu is not None) - if need_expert_num_tokens_cpu: - # This is blocking as some implementations need the count - # on the CPU to determine appropriate input/out fused-moe - # buffers - c_expert_num_tokens_cpu = c_expert_num_tokens.to( - "cpu", non_blocking=False) - - return ExpertTokensMetadata( - expert_num_tokens=c_expert_num_tokens, - expert_num_tokens_cpu=c_expert_num_tokens_cpu) - for chunk_idx in range(num_chunks): - c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( - slice_input_tensors(chunk_idx)) + s, e = input_chunk_range(chunk_idx) - c_expert_tokens_meta = None - if expert_tokens_meta is not None: - c_expert_tokens_meta = slice_expert_tokens_metadata( - expert_tokens_meta, c_topk_ids, local_num_experts, - expert_map) + c_expert_tokens_meta = self._slice_expert_tokens_metadata( + num_chunks, + expert_tokens_meta, + topk_ids[s:e], + local_num_experts, + expert_map, + ) - self._do_fused_experts( - fused_out=slice_output_tensor(chunk_idx), - a1=a1, - a1q=c_a1q, + c_fused_out = self._slice_output_tensor( + fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full + ) + + self.fused_experts.apply( + output=c_fused_out, + hidden_states=a1q[s:e], w1=w1, w2=w2, - topk_weights=c_topk_weights, - topk_ids=c_topk_ids, + topk_weights=topk_weights[s:e], + topk_ids=topk_ids[s:e], activation=activation, global_num_experts=global_num_experts, - local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, + a1q_scale=_slice_scales(a1q_scale, s, e), + a2_scale=_slice_scales(self.fused_experts.a2_scale, e, e), + workspace13=workspace13, + workspace2=workspace2, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) return fused_out + def _finalize( + self, + output: torch.Tensor, + fused_out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + 
""" + The _finalize method is a wrapper around self.prepare_finalize.finalize + that handles DBO, async and shared expert overlap. + """ + shared_output: Optional[torch.Tensor] = None + + if not self.prepare_finalize.supports_async(): + assert not dbo_enabled() + + self.prepare_finalize.finalize( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + else: + finalize_ret = self.prepare_finalize.finalize_async( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = ( + finalize_ret + if isinstance(finalize_ret, tuple) + else (None, finalize_ret) + ) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + receiver() + + if self.shared_experts is None: + return output + else: + assert shared_output is not None + return shared_output, output + def forward( self, hidden_states: torch.Tensor, @@ -685,14 +1113,8 @@ class FusedMoEModularKernel(torch.nn.Module): activation: str = "silu", global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -713,14 +1135,6 @@ class FusedMoEModularKernel(torch.nn.Module): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. @@ -729,70 +1143,45 @@ class FusedMoEModularKernel(torch.nn.Module): - torch.Tensor: The output tensor after applying the MoE layer. 
""" - a1 = hidden_states - output = a1 if inplace else torch.zeros_like(a1) + if inplace and self.shared_experts is None and not disable_inplace(): + output = hidden_states + else: + output = torch.zeros_like(hidden_states) local_num_experts = w1.size(0) if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + ) - # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. - topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids - topk_weights = (topk_weights if _expert_topk_weights is None else - _expert_topk_weights) + fused_out = self._fused_experts( + in_dtype=hidden_states.dtype, + a1q=a1q, + a1q_scale=a1q_scale, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_tokens_meta=expert_tokens_meta, + ) - fused_out = None - - if a1q.numel() == 0: - # This happens when none of the tokens from the all2all reach this - # EP rank. Also, note that this is only relevant for CUDAGraph - # incompatible all2all kernels like the DeepEP high-throughput - # kernels. CUDAGraph compatible all2all kernels like the pplx - # kernels and the DeepEP low-latency kernels are always batched - # and can never run into the tensor.numel() == 0 case. 
- fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) - else: - fused_out = self._maybe_chunk_fused_experts( - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - self.prepare_finalize.finalize( + return self._finalize( output, fused_out, + hidden_states, topk_weights, topk_ids, apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl(), ) - - return output diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index c7d7126bab3ad..9994088ca5d9a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -14,7 +14,7 @@ def moe_align_block_size( block_size: int, num_experts: int, expert_map: Optional[torch.Tensor] = None, - pad_sorted_ids: bool = False + pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -68,19 +68,18 @@ def moe_align_block_size( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + ops.moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) if expert_map is not None: expert_ids = expert_map[expert_ids] diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index 582ae3e12c289..66c00cf89873a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -7,18 +7,20 @@ import torch.nn.functional as F def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: """ - Compute the histogram of a int32 tensor. The bin edges are defined by the - min and max values, with step = 1. - """ + Compute the histogram of an int32 tensor. The bin edges are defined by the + min and max values, with step = 1. + """ assert input.dtype == torch.int32, "input must be of torch.int32 dtype." assert min <= max, "min must be less than or equal to max." 
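# Editor's illustration, not part of this patch: a standalone check that the
# equality-sum trick used by _histogram above matches an ordinary bincount for
# int32 values in [min, max].
import torch

def _toy_histogram(x: torch.Tensor, lo: int, hi: int) -> torch.Tensor:
    bin_edges = torch.arange(lo, hi + 1, dtype=x.dtype, device=x.device)
    # For each bin edge, count how many input values are exactly equal to it.
    return (bin_edges.unsqueeze(1) == x).sum(dim=1).to(torch.int32)

_x = torch.tensor([0, 2, 2, 3, 1, 2], dtype=torch.int32)
assert torch.equal(
    _toy_histogram(_x, 0, 3), torch.bincount(_x, minlength=4).to(torch.int32)
)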
- def searchsorted(sorted_sequence: torch.Tensor, - values_to_search: torch.Tensor) -> torch.Tensor: + def searchsorted( + sorted_sequence: torch.Tensor, values_to_search: torch.Tensor + ) -> torch.Tensor: return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) - bin_edges = torch.linspace(min, max, max - min + 1, - dtype=input.dtype).to(input.device) + bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to( + input.device + ) return searchsorted(bin_edges, input).to(torch.int32) @@ -41,6 +43,7 @@ def fused_moe( """ assert expert_map is None, "expert_map is not supported for pallas MoE." import torch_xla.experimental.custom_kernel # noqa: F401 + orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() @@ -50,7 +53,8 @@ def fused_moe( dtype = hidden_states.dtype assert (num_tokens * topk) % 16 == 0, ( "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " - f"16 but got {num_tokens * topk}") + f"16 but got {num_tokens * topk}" + ) hidden_states = hidden_states.view(num_tokens, hidden_size) gating_output = gating_output.view(num_tokens, num_experts) @@ -63,8 +67,7 @@ def fused_moe( topk_indices = topk_indices.flatten() topk_argsort_indices = topk_indices.argsort() topk_argsort_revert_indices = topk_argsort_indices.argsort() - token_indices = torch.arange(num_tokens, - device=device).repeat_interleave(topk) + token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk) token_indices = token_indices[topk_argsort_indices] group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 16a155e718478..698080f8aec6f 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -6,7 +6,8 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm @@ -17,8 +18,9 @@ def _moe_permute( global_num_experts: int, expert_map: Optional[torch.Tensor], block_m: int, -) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: +) -> tuple[ + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor +]: """ Determine the sorted_token_ids, expert_ids for the given problem size. Permute the hidden states and scales according to `sorted_token_ids`. @@ -27,12 +29,9 @@ def _moe_permute( tokens_in_chunk = curr_hidden_states.size(0) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True + ) inv_perm: Optional[torch.Tensor] = None @@ -43,14 +42,12 @@ def _moe_permute( # Permute according to sorted token ids. 
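# Editor's illustration, not part of this patch: a standalone sketch of the
# index arithmetic used just below. sorted_token_ids enumerates (token, topk)
# slots in expert-sorted order, so integer-dividing by top_k recovers the
# source row to gather from when permuting the hidden states.
import torch

_top_k = 2
_hidden = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)  # 3 tokens
# slots 0,1 belong to token 0; slots 2,3 to token 1; slots 4,5 to token 2
_sorted_token_ids = torch.tensor([4, 0, 2, 5, 1, 3])
_permuted = _hidden[_sorted_token_ids // _top_k]
assert _permuted.shape == (6, 4)
assert torch.equal(_permuted[0], _hidden[2])   # slot 4 came from token 2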
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) + curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num) if a1q_scale is not None: a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm) def _moe_unpermute_and_reduce( @@ -84,8 +81,9 @@ def moe_permute( align_block_size: Optional[int] = None, fill_invalid_expert: int = -1, permuted_hidden_states: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: +) -> tuple[ + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor +]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. @@ -117,13 +115,21 @@ def moe_permute( """ n_token, n_hidden = hidden_states.size() topk = topk_ids.size(1) - assert (n_hidden * hidden_states.element_size() - ) % 16 == 0, "permue kernel need hidden dim align to 16B" + assert (n_hidden * hidden_states.element_size()) % 16 == 0, ( + "permue kernel need hidden dim align to 16B" + ) permuted_row_size = n_token * topk if align_block_size is not None: - permuted_row_size = (permuted_row_size + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size + permuted_row_size = ( + ( + permuted_row_size + + n_expert * (align_block_size - 1) + + align_block_size + - 1 + ) + // align_block_size + * align_block_size + ) if n_local_expert == -1: n_local_expert = n_expert if permuted_hidden_states is None: @@ -134,40 +140,57 @@ def moe_permute( ) assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" - f" but got {permuted_hidden_states.size()}") + f" but got {permuted_hidden_states.size()}" + ) - token_expert_indices = torch.arange(0, - n_token * topk, - dtype=torch.int32, - device=hidden_states.device).reshape( - (n_token, topk)) + token_expert_indices = torch.arange( + 0, n_token * topk, dtype=torch.int32, device=hidden_states.device + ).reshape((n_token, topk)) - m_indices = torch.full((permuted_row_size, ), - fill_invalid_expert, - dtype=torch.int32, - device=hidden_states.device) - expert_first_token_offset = torch.empty(n_local_expert + 1, - dtype=torch.int64, - device=hidden_states.device) - permuted_idx = torch.full((permuted_row_size, ), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device) - inv_permuted_idx = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) + m_indices = torch.full( + (permuted_row_size,), + fill_invalid_expert, + dtype=torch.int32, + device=hidden_states.device, + ) + expert_first_token_offset = torch.empty( + n_local_expert + 1, dtype=torch.int64, device=hidden_states.device + ) + permuted_idx = torch.full( + (permuted_row_size,), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device, + ) + inv_permuted_idx = torch.empty( + (n_token, topk), dtype=torch.int32, device=hidden_states.device + ) topk_ids = topk_ids.to(torch.int32) - torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices, - expert_map, n_expert, n_local_expert, topk, - align_block_size, permuted_hidden_states, - expert_first_token_offset, inv_permuted_idx, - permuted_idx, m_indices) + 
torch.ops._moe_C.moe_permute( + hidden_states, + topk_ids, + token_expert_indices, + expert_map, + n_expert, + n_local_expert, + topk, + align_block_size, + permuted_hidden_states, + expert_first_token_offset, + inv_permuted_idx, + permuted_idx, + m_indices, + ) if a1q_scale is not None and a1q_scale.dim() > 1: - a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // - topk] - return (permuted_hidden_states, a1q_scale, expert_first_token_offset, - inv_permuted_idx.flatten(), m_indices) + a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] + return ( + permuted_hidden_states, + a1q_scale, + expert_first_token_offset, + inv_permuted_idx.flatten(), + m_indices, + ) def moe_unpermute( @@ -185,7 +208,7 @@ def moe_unpermute( - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. - - expert_first_token_offset (Optional[torch.Tensor]): offset of the first + - expert_first_token_offset (Optional[torch.Tensor]): offset of the first token of each expert for grouped gemm. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation @@ -193,12 +216,18 @@ def moe_unpermute( """ topk = topk_weights.size(1) n_hidden = permuted_hidden_states.size(-1) - assert (n_hidden * permuted_hidden_states.element_size() - ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + assert (n_hidden * permuted_hidden_states.element_size()) % 16 == 0, ( + "unpermue kernel need hidden dim align to 16B" + ) - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - inv_permuted_idx, expert_first_token_offset, - topk, out) + torch.ops._moe_C.moe_unpermute( + permuted_hidden_states, + topk_weights, + inv_permuted_idx, + expert_first_token_offset, + topk, + out, + ) def moe_permute_unpermute_supported(): diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py index 6160da7329518..f721d00d75ea7 100644 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -45,7 +45,7 @@ def fused_moe( for expert_idx in range(num_experts): expert_w1 = w1[expert_idx] expert_w2 = w2[expert_idx] - expert_mask = (selected_experts == expert_idx) + expert_mask = selected_experts == expert_idx expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) x = F.linear(hidden_states, expert_w1) gate = F.silu(x[:, :intermediate_size]) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 401f37922b7bb..e87953e34eaf2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import pplx_kernels as pplx import torch @@ -9,9 +9,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import ( - 
_validate_scale_shape, moe_kernel_quantize_input) + _validate_scale_shape, + moe_kernel_quantize_input, +) from vllm.utils import cdiv, round_up logger = init_logger(__name__) @@ -60,7 +63,6 @@ def pplx_hidden_dim_scale_bytes( class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): - def __init__( self, a2a: pplx.AllToAll, @@ -84,25 +86,27 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ - def prepare( + def output_is_reduced(self) -> bool: + return True + + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[Callable, mk.ReceiverType]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -114,8 +118,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if expert_map is not None: logger.warning_once( "The PPLX backend does not support expert mapping. " - "The provided `expert_map` will be ignored.") - expert_map = None #noqa: F841 + "The provided `expert_map` will be ignored." + ) + expert_map = None # noqa: F841 # Is this always going to be a1.device? device = a1.device @@ -124,19 +129,26 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") + "apply_router_weight_on_input is only implemented for topk=1" + ) a1 = a1 * topk_weights.to(a1.dtype) repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) + # TODO(bnell): always pass quant_config.a1_scale? a1q, a1q_scale = moe_kernel_quantize_input( - a1, (None if quant_config.per_act_token_quant else a1_scale), + a1, + (None if quant_config.per_act_token_quant else quant_config.a1_scale), quant_dtype=quant_config.quant_dtype, per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape) + block_shape=quant_config.block_shape, + ) - _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, - quant_config.block_shape) + _validate_scale_shape( + a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape + ) + + orig_a_scale_block_shape: Optional[int] = None if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -152,8 +164,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # TODO (bnell): use group_broadcast instead? 
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) - assert a1q_scale is None or a1q_scale.ndim == 2, \ + assert a1q_scale is None or a1q_scale.ndim == 2, ( f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" + ) expert_num_tokens = torch.empty( self.num_local_experts, @@ -162,8 +175,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) expert_x = torch.empty( - (self.num_local_experts, - self.max_num_tokens * self.num_dispatchers(), hidden_dim), + ( + self.num_local_experts, + self.max_num_tokens * self.num_dispatchers(), + hidden_dim, + ), dtype=a1q.dtype, device=device, ) @@ -179,14 +195,13 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): else: # (M x K_tiles) -> (E x M x K_tiles) assert quant_config.block_shape is not None - num_blocks = cdiv(expert_x.size(2), - quant_config.block_shape[1]) + num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) final_dim = num_blocks expert_x_scale_shape = ( self.num_local_experts, expert_x.size(1), - round_up(final_dim, 4) # round up for alignment + round_up(final_dim, 4), # round up for alignment ) expert_x_scale = torch.empty( @@ -205,19 +220,128 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids, bound_m=bound_m, + do_send=True, + do_recv=False, ) + hook = lambda: self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + + return ( + hook, + lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + orig_a_scale_block_shape, + ), + ) + + def _receiver( + self, + expert_num_tokens: torch.Tensor, + expert_x: torch.Tensor, + expert_x_scale: Optional[torch.Tensor], + orig_a_scale_block_shape: Optional[int], + ) -> mk.PrepareResultType: if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None + ) return expert_x, expert_x_scale, expert_tokens_meta, None, None + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook, receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + hook() + return receiver() + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( + "Weight application and reduction happens in the combine kernel." 
+ ) + + # This argument is optional + # There's not much point setting this unless it is != topk_ids.size(0) + bound_m: Optional[torch.Tensor] = None + + # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on + # num_tokens = output.size(0) # M + # assert topk_ids.size(0) == num_tokens, ( + # f"{topk_ids.size(0)} == {num_tokens}") + assert topk_ids.size() == topk_weights.size(), ( + f"{topk_ids.size()} == {topk_weights.size()}" + ) + assert output.size(0) <= self.max_num_tokens, ( + f"{output.size(0)} <= {self.max_num_tokens}" + ) + assert output.size(1) == fused_expert_output.size(-1) + + # Set weights to 1 if we did them in dispatch. This is hacky. + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + + topk_ids_u32 = topk_ids.view(dtype=torch.uint32) + + self.a2a.combine( + out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + return lambda: self.a2a.combine( + out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + def finalize( self, output: torch.Tensor, @@ -227,30 +351,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - assert isinstance( - weight_and_reduce_impl, TopKWeightAndReduceDelegate - ), ("Weight application and reduction happens in the combine kernel.") - - # This argument is optional - # There's not much point setting this unless it is != topk_ids.size(0) - bound_m: Optional[torch.Tensor] = None - - # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on - #num_tokens = output.size(0) # M - #assert topk_ids.size(0) == num_tokens, ( - # f"{topk_ids.size(0)} == {num_tokens}") - assert topk_ids.size() == topk_weights.size(), ( - f"{topk_ids.size()} == {topk_weights.size()}") - assert output.size(0) <= self.max_num_tokens, ( - f"{output.size(0)} <= {self.max_num_tokens}") - assert output.size(1) == fused_expert_output.size(-1) - - # Set weights to 1 if we did them in dispatch. This is hacky. 
- if apply_router_weight_on_input: - topk_weights = torch.ones_like(topk_weights) - - self.a2a.combine(out_tokens=output, - indices=topk_ids.view(dtype=torch.uint32), - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m) + receiver = self.finalize_async( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + ) + receiver() diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 567a0a88fec0a..1e572d2394781 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -7,13 +7,13 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): - @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -27,31 +27,34 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return 1 + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: - + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ + assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a1.mul_(topk_weights.to(a1.dtype)) a1q, a1q_scale = moe_kernel_quantize_input( - a1, a1_scale, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + a1, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) return a1q, a1q_scale, None, None, None @@ -71,4 +74,5 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): fused_expert_output=fused_expert_output, topk_weights=topk_weights, topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 93e20c3477bbe..801785b18fb9e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -7,6 +7,10 @@ from typing import Optional import torch from vllm import envs +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op 
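# Editor's illustration, not part of this patch: a standalone sketch of the
# in-place router weighting in MoEPrepareAndFinalizeNoEP.prepare above. With
# topk == 1 there is exactly one weight per token, so it can be folded into the
# activations before dispatch; the shapes below are made up for illustration.
import torch

_a1 = torch.ones(4, 8)                                       # 4 tokens, hidden 8
_topk_weights = torch.tensor([[0.5], [1.0], [2.0], [0.25]])  # topk == 1
_a1.mul_(_topk_weights.to(_a1.dtype))                        # (4, 1) broadcasts over (4, 8)
assert torch.allclose(_a1[2], torch.full((8,), 2.0))
# With topk > 1, topk_weights is (num_tokens, topk) and there is no single
# per-token scale to fold in, hence the topk == 1 assertion in prepare().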
@@ -36,138 +40,162 @@ class ActivationMethod(IntEnum): @cache def is_rocm_aiter_moe_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_MOE \ + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_MOE and envs.VLLM_ROCM_USE_AITER + ) def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, - activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: - + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, +) -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 activation = ActivationType(activation_method) - return asm_moe_tkw1(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation) + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, - activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, +) -> torch.Tensor: return torch.empty_like(hidden_states) -def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> None: +def rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: 
torch.Tensor, + renormalize: bool, +) -> None: from aiter import topk_softmax - topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output, renormalize) + + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) -def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> None: +def rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: pass def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: - from aiter import biased_grouped_topk - biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids, - num_expert_group, topk_group, need_renorm, - routed_scaling_factor) + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: - from aiter import grouped_topk - grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group, - topk_group, need_renorm, scoring_func, routed_scaling_factor) + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass @@ -193,9 +221,21 @@ def 
rocm_aiter_fused_moe_impl( activation = ActivationType(activation_method) quant_type = QuantType(quant_method) - return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, - activation, quant_type, doweight_stage1, w1_scale, - w2_scale, a1_scale, a2_scale) + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) def rocm_aiter_fused_moe_fake( @@ -217,21 +257,16 @@ def rocm_aiter_fused_moe_fake( if current_platform.is_rocm(): - direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=rocm_aiter_asm_moe_tkw1_impl, - mutates_args=[], fake_impl=rocm_aiter_asm_moe_tkw1_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_fused_moe", op_func=rocm_aiter_fused_moe_impl, - mutates_args=[], fake_impl=rocm_aiter_fused_moe_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -239,7 +274,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_topk_softmax_impl, mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], fake_impl=rocm_aiter_topk_softmax_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -247,7 +281,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_biased_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=rocm_aiter_biased_grouped_topk_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( @@ -255,7 +288,6 @@ if current_platform.is_rocm(): op_func=rocm_aiter_grouped_topk_impl, mutates_args=["topk_weights", "topk_ids"], fake_impl=rocm_aiter_grouped_topk_fake, - dispatch_key=current_platform.dispatch_key, ) @@ -267,19 +299,18 @@ def rocm_aiter_grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( gating_output, - e_score_correction_bias, + e_score_correction_bias.to(gating_output.dtype), topk_weights, topk_ids, num_expert_group, @@ -287,7 +318,7 @@ def rocm_aiter_grouped_topk( renormalize, ) else: - assert (scoring_func == "softmax" or scoring_func == "sigmoid") + assert scoring_func == "softmax" or scoring_func == "sigmoid" torch.ops.vllm.rocm_aiter_grouped_topk( gating_output, topk_weights, @@ -298,47 +329,49 @@ def rocm_aiter_grouped_topk( scoring_func, ) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor return topk_weights, topk_ids def rocm_aiter_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: 
Optional[list[int]] = None, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, +) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - activation_method = (ActivationMethod.SILU - if activation == "silu" else ActivationMethod.GELU) + activation_method = ( + ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU + ) # All AITER Fused MoE kernels are expecting the following datatypes topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - if expert_map is not None: - expert_mask = (expert_map > -1).to(torch.int32) - else: - expert_mask = None + expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None # w8a8 per-channel quantization - if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + if ( + quant_config.per_act_token_quant + and apply_router_weight_on_input + and quant_config.use_fp8_w8a8 + ): # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.dim() == 2, ( + "`topk_weights` should be in shape (num_tokens, topk)" + ) assert topk_weights.shape[-1] == 1, ( - "Only support topk=1 when" - " `apply_router_weight_on_input` is True") + "Only support topk=1 when `apply_router_weight_on_input` is True" + ) return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( hidden_states, @@ -346,37 +379,40 @@ def rocm_aiter_fused_experts( w2, topk_weights, topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, + fc1_scale=quant_config.w1_scale, + fc2_scale=quant_config.w2_scale, fc1_smooth_scale=None, fc2_smooth_scale=None, a16=False, per_tensor_quant_scale=None, expert_mask=expert_mask, - activation_method=activation_method) + activation_method=activation_method, + ) else: quant_method = QuantMethod.NO.value # w8a8 block-scaled - if block_shape is not None and use_fp8_w8a8: + if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is\ - not supported for block scaled moe") - assert w1_scale is not None - assert w2_scale is not None + not supported for block scaled moe" + ) + assert quant_config.w1_scale is not None + assert quant_config.w2_scale is not None quant_method = QuantMethod.BLOCK_128x128.value - elif use_fp8_w8a8: + elif quant_config.use_fp8_w8a8: # Currently only per tensor quantization method is enabled. 
quant_method = QuantMethod.PER_TENSOR.value if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.dim() == 2, ( + "`topk_weights` should be in shape (num_tokens, topk)" + ) _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" + assert topk == 1, ( + "Only support topk=1 when `apply_router_weight_on_input` is True" + ) return torch.ops.vllm.rocm_aiter_fused_moe( hidden_states, @@ -387,21 +423,24 @@ def rocm_aiter_fused_experts( expert_mask=expert_mask, quant_method=quant_method, activation_method=activation_method, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - doweight_stage1=apply_router_weight_on_input) + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, + doweight_stage1=apply_router_weight_on_input, + ) -def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, - token_expert_indices, gating_output, - renormalize) +def rocm_aiter_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) return topk_weights, topk_indices @@ -409,17 +448,16 @@ def shuffle_weights( *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> tuple[torch.Tensor, ...]: """ - Applies shuffle_weight function from AITER to each + Applies shuffle_weight function from AITER to each input tensor and returns them. - + Rearranges (shuffles) the input tensor/s into a specified block layout for optimized computation. Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the - block sizes used to divide the tensors during shuffling. - Default is (16, 16). + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). Returns: A Tuple of shuffled tensors. diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index c8b107f13cd0d..af20f4b7c1d2b 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -10,7 +10,7 @@ like uniform random routing. """ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional import torch @@ -50,7 +50,7 @@ class DistributionBasedRouting(RoutingStrategy): distributions for testing different routing patterns. """ - def __init__(self, distribution: str = "uniform", **distribution_params): + def __init__(self, distribution: str = "uniform", **distribution_params: Any): """ Initialize distribution-based routing. @@ -74,8 +74,10 @@ class DistributionBasedRouting(RoutingStrategy): valid_distributions = ["uniform", "normal"] if self.distribution not in valid_distributions: - raise ValueError(f"Unsupported distribution: {self.distribution}. 
" - f"Supported distributions: {valid_distributions}") + raise ValueError( + f"Unsupported distribution: {self.distribution}. " + f"Supported distributions: {valid_distributions}" + ) # Set default parameters if not provided if self.distribution == "normal": @@ -110,12 +112,12 @@ class DistributionBasedRouting(RoutingStrategy): indices_type = torch.long # Generate expert IDs based on the specified distribution - topk_ids = self._sample_expert_ids(num_tokens, num_experts, top_k, - hidden_states.device, indices_type) + topk_ids = self._sample_expert_ids( + num_tokens, num_experts, top_k, hidden_states.device, indices_type + ) # Generate weights based on the distribution - topk_weights = self._generate_weights(num_tokens, top_k, - hidden_states.device) + topk_weights = self._generate_weights(num_tokens, top_k, hidden_states.device) return topk_weights, topk_ids @@ -143,7 +145,8 @@ class DistributionBasedRouting(RoutingStrategy): # For normal distribution, sample continuous values and map to # expert IDs continuous_samples = self._sample_continuous_distribution( - num_tokens, top_k, device) + num_tokens, top_k, device + ) # Map continuous samples to expert indices # Normalize to [0, 1] range and scale to [0, num_experts) @@ -156,8 +159,9 @@ class DistributionBasedRouting(RoutingStrategy): else: raise ValueError(f"Unsupported distribution: {self.distribution}") - def _sample_continuous_distribution(self, num_tokens: int, top_k: int, - device: torch.device) -> torch.Tensor: + def _sample_continuous_distribution( + self, num_tokens: int, top_k: int, device: torch.device + ) -> torch.Tensor: """Sample from continuous distributions.""" shape = (num_tokens, top_k) @@ -168,7 +172,8 @@ class DistributionBasedRouting(RoutingStrategy): else: raise ValueError( - f"Unsupported continuous distribution: {self.distribution}") + f"Unsupported continuous distribution: {self.distribution}" + ) def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor: """Normalize samples to [0, 1] range.""" @@ -177,11 +182,13 @@ class DistributionBasedRouting(RoutingStrategy): return torch.sigmoid(samples) else: - raise ValueError(f"Unsupported distribution for normalization: " - f"{self.distribution}") + raise ValueError( + f"Unsupported distribution for normalization: {self.distribution}" + ) - def _generate_weights(self, num_tokens: int, top_k: int, - device: torch.device) -> torch.Tensor: + def _generate_weights( + self, num_tokens: int, top_k: int, device: torch.device + ) -> torch.Tensor: """Generate weights based on the distribution.""" if self.distribution == "uniform": # All-ones weights for uniform distribution @@ -195,7 +202,8 @@ class DistributionBasedRouting(RoutingStrategy): # For normal distribution, generate weights from the same # distribution continuous_weights = self._sample_continuous_distribution( - num_tokens, top_k, device) + num_tokens, top_k, device + ) # Normalize to positive values and sum to 1 weights = torch.abs(continuous_weights) weights = weights / weights.sum(dim=-1, keepdim=True) @@ -203,14 +211,14 @@ class DistributionBasedRouting(RoutingStrategy): else: raise ValueError( - f"Unsupported distribution for weight generation: " - f"{self.distribution}") + f"Unsupported distribution for weight generation: {self.distribution}" + ) def get_distribution_info(self) -> dict: """Get information about the current distribution configuration.""" return { "distribution": self.distribution, - "parameters": self.distribution_params.copy() + "parameters": self.distribution_params.copy(), } @@ 
-226,10 +234,12 @@ class RoutingSimulator: # Class-level registry of routing strategies _routing_strategies: dict[str, RoutingStrategy] = { # Basic routing strategies - "uniform_random": - DistributionBasedRouting(distribution="uniform", mean=0.0, std=1.0), - "normal_routing": - DistributionBasedRouting(distribution="normal", mean=0.0, std=1.0), + "uniform_random": DistributionBasedRouting( + distribution="uniform", mean=0.0, std=1.0 + ), + "normal_routing": DistributionBasedRouting( + distribution="normal", mean=0.0, std=1.0 + ), } @classmethod @@ -244,7 +254,7 @@ class RoutingSimulator: cls._routing_strategies[name] = strategy @classmethod - def get_available_strategies(cls): + def get_available_strategies(cls) -> list[str]: """ Get list of available routing strategy names. @@ -278,7 +288,8 @@ class RoutingSimulator: raise ValueError( f"Unknown routing strategy: {strategy_name}. " f"Available strategies: " - f"{list(RoutingSimulator._routing_strategies.keys())}") + f"{list(RoutingSimulator._routing_strategies.keys())}" + ) strategy = RoutingSimulator._routing_strategies[strategy_name] return strategy.route_tokens( diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py new file mode 100644 index 0000000000000..a678fdae8833e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. + """ + + def __init__( + self, + shared_experts: Optional[torch.nn.Module], + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + # Disable shared expert overlap if EP is disabled or we are not using + # flashinfer + DP since there is nothing to be gained in this case. + # Disabling the overlap optimization also prevents the shared experts + # from being hidden from torch.compile. + self.use_overlapped = ( + use_overlapped + and not (self.use_ep or self.use_flashinfer_cutlass_kernels) + and self._shared_experts is not None + ) + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + if self._shared_experts is not None: + shared_out = self._shared_experts(hidden_states) + + # Reduce shared expert outputs if necessary, since the MLP + # should have been created with reduce_results=False. 
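# A minimal standalone sketch of the overlap decision made in
# SharedFusedMoE.__init__ above: shared-expert overlap is only kept when it
# was requested, a shared expert actually exists, and the fused path uses
# neither expert parallelism nor the flashinfer CUTLASS kernels. The plain
# bool parameters stand in for the layer attributes; this is not a public
# vLLM API.
def shared_expert_overlap_enabled(
    requested: bool,
    has_shared_experts: bool,
    use_ep: bool,
    use_flashinfer_cutlass_kernels: bool,
) -> bool:
    return (
        requested
        and has_shared_experts
        and not (use_ep or use_flashinfer_cutlass_kernels)
    )


# Overlap requested but EP enabled -> fall back to the sequential path.
assert shared_expert_overlap_enabled(True, True, True, False) is False
assert shared_expert_overlap_enabled(True, True, False, False) is True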
+ if ( + self.reduce_results + and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs() + ): + shared_out = tensor_model_parallel_all_reduce(shared_out) + else: + shared_out = None + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py index fb398eec119fa..e725a0f00363e 100644 --- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -19,7 +19,7 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize does the weight-application + reduction as part of the pplx combine kernel. But the BatchedPrepareAndFinalize needs an implementation. To facilitate - this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate + this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate so the PrepareAndFinalize implementations could choose how to weight + reduce. """ @@ -27,12 +27,18 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceDelegate) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: - raise RuntimeError("The caller is expected to choose an appropriate " - "TopKWeightAndReduce implementation.") + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: + raise RuntimeError( + "The caller is expected to choose an appropriate " + "TopKWeightAndReduce implementation." + ) class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): @@ -44,10 +50,14 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceNoOP) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: # Weight application and reduction operations are already done. if output is None: return fused_expert_output @@ -57,7 +67,8 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): assert output.size() == fused_expert_output.size(), ( "output shape is expected to match the fused_expert_output shape. 
" f"But got output={output.size()}, " - f"used_expert_output={fused_expert_output.size()}") + f"used_expert_output={fused_expert_output.size()}" + ) output.copy_(fused_expert_output, non_blocking=True) return output @@ -71,11 +82,14 @@ class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceContiguous) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: - + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: m, num_topk = topk_ids.size() k = fused_expert_output.size(-1) if fused_expert_output.ndim == 2: @@ -83,17 +97,21 @@ class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): assert fused_expert_output.size() == (m, num_topk, k), ( f"Expected fused_expert_output size {(m, num_topk, k)}. But got " - f"{fused_expert_output.size()}") + f"{fused_expert_output.size()}" + ) if not apply_router_weight_on_input: fused_expert_output.mul_(topk_weights.view(m, -1, 1)) if output is None: - output = torch.empty((m, k), - device=fused_expert_output.device, - dtype=fused_expert_output.dtype) + output = torch.empty( + (m, k), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype, + ) assert output.size() == (m, k), ( - f"Expected output size {(m, k)}. But got {output.size()}") + f"Expected output size {(m, k)}. But got {output.size()}" + ) ops.moe_sum(fused_expert_output, output) return output @@ -109,27 +127,35 @@ class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce): self.rank = rank def __eq__(self, other): - return (isinstance(other, TopKWeightAndReduceNaiveBatched) - and (other.rank == self.rank)) + return isinstance(other, TopKWeightAndReduceNaiveBatched) and ( + other.rank == self.rank + ) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: assert fused_expert_output.ndim == 3 num_tokens = topk_ids.size(0) num_local_experts = fused_expert_output.size(0) K = fused_expert_output.size(-1) if output is None: - output = torch.zeros((num_tokens, K), - device=fused_expert_output.device, - dtype=fused_expert_output.dtype) + output = torch.zeros( + (num_tokens, K), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype, + ) else: output.fill_(0) assert output.size() == (num_tokens, K), ( - f"Expected output size {(num_tokens, K)}, but got {output.size()}") + f"Expected output size {(num_tokens, K)}, but got {output.size()}" + ) first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 486ca881df48c..94a3ba74e47fd 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -7,71 +7,59 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from 
vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, - deep_gemm_block_shape) + DeepGemmExperts, + _valid_deep_gemm, + _valid_deep_gemm_shape, +) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, allow_deep_gemm: bool = False, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - self.triton_expert = TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int4_w4a16=use_int4_w4a16, - use_int8_w8a16=use_int8_w8a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + super().__init__(quant_config) + + self.triton_expert = TritonExperts(quant_config) + + self.allow_deep_gemm = ( + allow_deep_gemm + and self.quant_config.use_fp8_w8a8 + and self.block_shape == deep_gemm_block_shape() ) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and - self.block_shape == deep_gemm_block_shape()) - - self.deep_gemm_expert = DeepGemmExperts( - ) if self.allow_deep_gemm else None + self.deep_gemm_expert = ( + DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None + ) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - assert (self.deep_gemm_expert is None - or self.triton_expert.activation_formats - == self.deep_gemm_expert.activation_formats) + assert ( + self.deep_gemm_expert is None + or self.triton_expert.activation_formats + == self.deep_gemm_expert.activation_formats + ) return self.triton_expert.activation_formats def supports_chunking(self) -> bool: dge = self.deep_gemm_expert te = self.triton_expert - return ((dge is None or dge.supports_chunking()) - and (te is None or te.supports_chunking())) + return (dge is None or dge.supports_chunking()) and ( + te is None or te.supports_chunking() + ) def supports_expert_map(self) -> bool: dge = self.deep_gemm_expert te = self.triton_expert - return ((dge is None or dge.supports_expert_map()) - and (te is None or te.supports_expert_map())) + return (dge is None or dge.supports_expert_map()) and ( + te is None or te.supports_expert_map() + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: dge = self.deep_gemm_expert @@ -84,7 +72,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): if is_dge_war and is_te_war: assert dge_war == te_war, ( "Both implementations should agree on WeightAndReduce impls. 
" - f"Got dge_war: {dge_war}, and te_war: {te_war}") + f"Got dge_war: {dge_war}, and te_war: {te_war}" + ) if dge_war is not None: return dge_war @@ -94,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -103,21 +90,33 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used() - or _valid_deep_gemm_shape(M, N, K)): + if self.allow_deep_gemm and ( + is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K) + ): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_meta) + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) else: - return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, - global_num_experts, - local_num_experts, - expert_tokens_meta) + return self.triton_expert.workspace_shapes( + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) def apply( self, @@ -130,10 +129,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -141,9 +136,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - use_deep_gemm = (self.allow_deep_gemm - and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_e8m0_used())) + use_deep_gemm = self.allow_deep_gemm and ( + is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2) + ) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None @@ -158,10 +153,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, a1q_scale, a2_scale, workspace13, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py new file mode 100644 index 0000000000000..c84d1afeb1f97 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.utils import next_power_of_2 + + +class 
TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + max_capture_size, + ): + super().__init__(quant_config) + self.moe = moe + self.gemm1_alpha = gemm1_alpha + self.gemm1_beta = gemm1_beta + self.gemm1_clamp_limit = gemm1_clamp_limit + self.max_capture_size = max_capture_size + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + output = (M, K) + return (workspace1, workspace2, output) + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # 1.0 means perfect expert distribution. + # > 1.0 means some experts have more tokens than the perfect + # distribution. + # < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert assuming perfect + # distribution. + num_tokens_per_expert = (num_tokens * top_k) // local_num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the + # kernel. 
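# A self-contained sketch of the tile-size heuristic implemented above:
# estimate tokens per expert under perfectly balanced routing, inflate by the
# assumed 1.3x imbalance factor, round up to a power of two, and clamp to the
# 8-64 range the kernel supports. The local _next_pow2 helper stands in for
# vllm.utils.next_power_of_2.
def _next_pow2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def estimate_tile_tokens_dim(num_tokens: int, top_k: int, local_num_experts: int) -> int:
    per_expert = (num_tokens * top_k) // local_num_experts
    per_expert = int(per_expert * 1.3)  # assumed imbalance factor
    return min(max(_next_pow2(per_expert), 8), 64)


# 256 tokens, top_k=8, 32 local experts: 64 per expert ideally, ~83 after the
# imbalance factor, rounded up to 128, then capped at 64.
assert estimate_tile_tokens_dim(256, 8, 32) == 64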
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + topk = topk_ids.size(-1) + local_num_experts = w1.size(0) + intermediate_size = w2.size(1) + local_expert_offset = self.moe.ep_rank * local_num_experts + + x_quant = hidden_states + x_scale = a1q_scale + if x_scale is not None: + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) + + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + assert self.w1_scale is not None + assert self.w2_scale is not None + kwargs = { + "topk_ids": packed_tensor, + "routing_bias": None, + "hidden_states": x_quant, + "hidden_states_scale": x_scale, + "gemm1_weights": w1, + "gemm1_weights_scale": self.w1_scale, + "gemm1_bias": self.w1_bias, + "gemm1_alpha": self.gemm1_alpha, + "gemm1_beta": self.gemm1_beta, + "gemm1_clamp_limit": self.gemm1_clamp_limit, + "gemm2_weights": w2, + "gemm2_weights_scale": self.w2_scale, + "gemm2_bias": self.w2_bias, + "output1_scale_scalar": None, + "output1_scale_gate_scalar": None, + "output2_scale_scalar": None, + "num_experts": global_num_experts, + "top_k": topk, + "n_group": None, + "topk_group": None, + "intermediate_size": intermediate_size, + "local_expert_offset": local_expert_offset, + "local_num_experts": local_num_experts, + "routed_scaling_factor": None, + "tile_tokens_dim": self._get_tile_tokens_dim( + x_quant, topk, local_num_experts + ), + "routing_method_type": 1, + "do_finalize": True, + "output": output, + "tune_max_num_tokens": self.max_capture_size, + } + + from flashinfer import trtllm_fp4_block_scale_routed_moe + + trtllm_fp4_block_scale_routed_moe(**kwargs) + return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 4c3e700ad3990..bd68d2ec884de 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -7,38 +7,49 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) + per_token_group_quant_int8, + per_token_quant_int8, +) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - quant_dequant_mxfp4) -from vllm.platforms import current_platform + quant_dequant_mxfp4, +) +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import ( + quant_dequant_mxfp6, +) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + mxfp8_e4m3_quantize, +) from vllm.triton_utils import tl, triton -from vllm.utils import cdiv -from vllm.utils.flashinfer import fp4_quantize +from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils.flashinfer import flashinfer_fp4_quantize @triton.jit -def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, - topk_numel, expert_map, - HAS_EXPERT_MAP: 
tl.constexpr, - BLOCK_SIZE: tl.constexpr): - +def _count_expert_num_tokens( + topk_ids_ptr, + expert_num_tokens_ptr, + num_experts, + topk_numel, + expert_map, + HAS_EXPERT_MAP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): curr_expert = tl.program_id(0) offsets = tl.arange(0, BLOCK_SIZE) topk_ids_ptrs = topk_ids_ptr + offsets - acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32) for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)): mask = offsets < (topk_numel - x * BLOCK_SIZE) expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1) if HAS_EXPERT_MAP: expert_map_ptrs = expert_map + expert_ids expert_map_mask = expert_ids >= 0 - expert_ids = tl.load(expert_map_ptrs, - mask=expert_map_mask, - other=-1) + expert_ids = tl.load(expert_map_ptrs, mask=expert_map_mask, other=-1) has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0) acc = acc + has_curr_expert @@ -49,8 +60,8 @@ def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, def count_expert_num_tokens( - topk_ids: torch.Tensor, num_local_experts: int, - expert_map: Optional[torch.Tensor]) -> torch.Tensor: + topk_ids: torch.Tensor, num_local_experts: int, expert_map: Optional[torch.Tensor] +) -> torch.Tensor: """ Count the number to tokens assigned to each expert. @@ -66,17 +77,16 @@ def count_expert_num_tokens( A tensor of size num_local_experts, where tensor[i] holds the number of tokens assigned to the ith expert. """ - assert topk_ids.dtype.is_signed, ( - "The kernel uses -1 to represent invalid topk_ids") - expert_num_tokens = torch.empty((num_local_experts), - device=topk_ids.device, - dtype=torch.int32) + assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" + expert_num_tokens = torch.empty( + (num_local_experts), device=topk_ids.device, dtype=torch.int32 + ) grid = num_local_experts BLOCK_SIZE = min(topk_ids.numel(), 1024) BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE) - _count_expert_num_tokens[(grid, )]( + _count_expert_num_tokens[(grid,)]( topk_ids, expert_num_tokens, num_local_experts, @@ -94,19 +104,20 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel( - ), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly? - return x.flatten()[:prod(v)].view(*v) + assert prod(v) <= x.numel(), ( + f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" + ) # CUDAGRAPH unfriendly? + return x.flatten()[: prod(v)].view(*v) -def _fp4_quantize( +def _nvfp4_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], is_sf_swizzled_layout: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - return fp4_quantize(A, - A_scale, - is_sf_swizzled_layout=is_sf_swizzled_layout) + return flashinfer_fp4_quantize( + A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) def _fp8_quantize( @@ -123,7 +134,8 @@ def _fp8_quantize( # TODO(luka): use QuantFP8 custom op # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_act_token) + A, A_scale, use_per_token_if_dynamic=per_act_token + ) else: assert not per_act_token assert len(block_shape) == 2 @@ -149,8 +161,7 @@ def _int8_quantize( # activations apply per-token quantization. 
Otherwise, assume # activation tensor-wise fp8/int8 quantization, dynamic or static if block_shape is None: - assert per_act_token, \ - "int8 quantization only supports block or channel-wise" + assert per_act_token, "int8 quantization only supports block or channel-wise" A, A_scale = per_token_quant_int8(A) else: assert not per_act_token @@ -169,10 +180,57 @@ def _mxfp4_quantize( block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, None]: assert block_shape is None - if not current_platform.supports_mx(): - A = quant_dequant_mxfp4(A) - else: - raise NotImplementedError() + # TODO: native mxfp4 is currently not integrated in vllm, + # so simulating even on devices supporting this data type natively. + # Once integrated, `current_platform.supports_mx()` should be used to + # control quantize+dequantize, or simply quantize here down to mxfp4. + A = quant_dequant_mxfp4(A) + + return A, None + + +def _mxfp8_e4m3_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert A_scale is None + assert not per_act_token_quant + assert block_shape is None + return mxfp8_e4m3_quantize(A) + + +def _mxfp6_e3m2_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, None]: + assert block_shape is None + + # TODO: native mxfp6 is currently not integrated in vllm, + # so simulating even on devices supporting this data type natively. + # Eventually, there should be a check based on + # `current_platform.supports_mx()` here. + A = quant_dequant_mxfp6(A, quant_dtype="fp6_e3m2") + + return A, None + + +def _mxfp6_e2m3_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, None]: + assert block_shape is None + + # TODO: native mxfp6 is currently not integrated in vllm, + # so simulating even on devices supporting this data type natively. + # Eventually, there should be a check based on + # `current_platform.supports_mx()` here. + A = quant_dequant_mxfp6(A, quant_dtype="fp6_e2m3") return A, None @@ -190,11 +248,17 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "nvfp4": - return _fp4_quantize(A, - A_scale, - is_sf_swizzled_layout=is_fp4_scale_swizzled) + return _nvfp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp8": + # TODO: `quant_dtype == "mxfp8"` is ambiguous, + # should be fp8_e4m3. OCP MX also defines `fp8_e5m2`. + return _mxfp8_e4m3_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp6_e3m2": + return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp6_e2m3": + return _mxfp6_e2m3_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale @@ -209,8 +273,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m[idx, ...] 
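# A toy dispatch-table rendering of moe_kernel_quantize_input above: torch
# dtypes select the fp8/int8 quantizers, string keys select the NVFP4/MX
# formats, and an unknown quant_dtype passes the activations through
# unchanged. The real helper uses an if/elif chain and the actual quantize
# kernels; the lambdas here are placeholders only.
from typing import Any, Callable, Optional

import torch

_QUANTIZERS: dict[Any, Callable] = {
    torch.float8_e4m3fn: lambda a, s: (a, s),  # stand-in for _fp8_quantize
    torch.int8: lambda a, s: (a, s),           # stand-in for _int8_quantize
    "mxfp4": lambda a, s: (a, None),           # stand-in for _mxfp4_quantize
}


def quantize_input(
    a: torch.Tensor, a_scale: Optional[torch.Tensor], quant_dtype: Any
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    fn = _QUANTIZERS.get(quant_dtype)
    # No registered quantizer: return the inputs untouched (unquantized path).
    return fn(a, a_scale) if fn is not None else (a, a_scale)


x = torch.randn(4, 8)
out, scale = quantize_input(x, None, None)
assert out is x and scale is None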
-def normalize_scales_shape( - scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: scales = scales.view(1, 1) @@ -226,8 +289,9 @@ def normalize_batched_scales_shape( if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) - scales = torch.repeat_interleave(scales, num_experts, - dim=0).view(num_experts, 1, 1) + scales = torch.repeat_interleave(scales, num_experts, dim=0).view( + num_experts, 1, 1 + ) else: scales = scales.view(num_experts, -1, scales.size(-1)) @@ -247,8 +311,20 @@ def _validate_scale_shape( assert a_scale.numel() == 1, f"{a_scale.shape}" elif per_act_token_quant: assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( - f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") + f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1" + ) else: assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def activation_without_mul(activation: str) -> str: + return activation + "_no_mul" + + +# Torch custom ops can't deal with outputs aliasing inputs so we need to +# disable inplace for torch >= 2.9. +# See https://github.com/vllm-project/vllm/issues/26378 +def disable_inplace() -> bool: + return is_torch_equal_or_newer("2.9") diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a5fc1db2dc10f..6a49ae42ca895 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -1,25 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom normalization layers.""" + from typing import Optional, Union import torch import torch.nn as nn +import torch.nn.functional as F import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_RMSNORM \ - and envs.VLLM_ROCM_USE_AITER + return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER -def rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: from vllm import _custom_ops as ops + out = torch.empty_like(x) ops.rms_norm( out, @@ -31,9 +34,13 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, def fused_add_rms_norm( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops + ops.fused_add_rms_norm( x, residual, @@ -43,9 +50,27 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from vllm import _custom_ops as ops + + out = torch.empty_like(x) + ops.poly_norm( + out, + x, + weight, + bias, + variance_epsilon, + ) + return out + + +def rocm_aiter_rms_norm_impl( + 
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: import aiter as rocm_aiter + if x.dim() > 2: x_original_shape = x.shape x = x.reshape(-1, x_original_shape[-1]) @@ -55,10 +80,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, return rocm_aiter.rms_norm(x, weight, variance_epsilon) -def rocm_aiter_fused_add_rms_norm( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: - +def rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: import aiter as rocm_aiter residual_out = torch.empty_like(residual) @@ -74,14 +101,49 @@ def rocm_aiter_fused_add_rms_norm( return output, residual_out -def dispatch_cuda_rmsnorm_func(add_residual: bool): - if add_residual: - if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_fused_add_rms_norm - return fused_add_rms_norm +def rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) - if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_rms_norm + +def rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=rocm_aiter_rms_norm_impl, + fake_impl=rocm_aiter_rms_norm_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, + fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, + ) + + +def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): + use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ + torch.float16, + torch.bfloat16, + ] + + if use_aiter and with_fused_add: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + if use_aiter: + return torch.ops.vllm.rocm_aiter_rms_norm + + # fall back to CUDA implementation + if with_fused_add: + return fused_add_rms_norm return rms_norm @@ -105,8 +167,9 @@ class RMSNorm(CustomOp): self.hidden_size = hidden_size self.variance_epsilon = eps - self.variance_size_override = (None if var_hidden_size == hidden_size - else var_hidden_size) + self.variance_size_override = ( + None if var_hidden_size == hidden_size else var_hidden_size + ) self.has_weight = has_weight if dtype is not None: self.weight = torch.ones(hidden_size, dtype=dtype) @@ -114,6 +177,15 @@ class RMSNorm(CustomOp): self.weight = torch.ones(hidden_size) if self.has_weight: self.weight = nn.Parameter(self.weight) + weight_dtype = self.weight.data.dtype + + if current_platform.is_rocm(): + self.rocm_norm_func = dispatch_rocm_rmsnorm_func( + with_fused_add=False, dtype=weight_dtype + ) + self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( + with_fused_add=True, dtype=weight_dtype + ) def forward_native( self, @@ -129,8 +201,10 @@ class RMSNorm(CustomOp): hidden_size = x.shape[-1] if hidden_size != self.hidden_size: - raise ValueError("Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}") + raise ValueError( + "Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}" + ) if self.variance_size_override is None: x_var = x @@ -138,9 +212,10 @@ class RMSNorm(CustomOp): if hidden_size < 
self.variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}") + f"{self.variance_size_override}, but found: {hidden_size}" + ) - x_var = x[:, :, :self.variance_size_override] + x_var = x[:, :, : self.variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) @@ -162,13 +237,28 @@ class RMSNorm(CustomOp): return self.forward_native(x, residual) add_residual = residual is not None - norm_func = dispatch_cuda_rmsnorm_func(add_residual) - if add_residual: - return norm_func(x, residual, self.weight.data, - self.variance_epsilon) + return fused_add_rms_norm( + x, residual, self.weight.data, self.variance_epsilon + ) else: - return norm_func(x, self.weight.data, self.variance_epsilon) + return rms_norm(x, self.weight.data, self.variance_epsilon) + + def forward_hip( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + add_residual = residual is not None + if add_residual: + return self.rocm_norm_func_with_add( + x, residual, self.weight.data, self.variance_epsilon + ) + else: + return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon) def forward_xpu( self, @@ -228,10 +318,7 @@ class GemmaRMSNorm(CustomOp): """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype if residual is not None: - if orig_dtype == torch.float16: - x = x + residual.float() - else: - x = x + residual + x = x + residual.float() if orig_dtype == torch.float16 else x + residual residual = x x = x.float() @@ -249,8 +336,7 @@ class GemmaRMSNorm(CustomOp): residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" - return self.forward_static(self.weight.data, self.variance_epsilon, x, - residual) + return self.forward_static(self.weight.data, self.variance_epsilon, x, residual) def forward_cuda( self, @@ -262,6 +348,72 @@ class GemmaRMSNorm(CustomOp): if not getattr(self, "_is_compiled", False): self.forward_static = torch.compile( # type: ignore - self.forward_static) + self.forward_static + ) self._is_compiled = True return self.forward_native(x, residual) + + +@CustomOp.register("poly_norm") +class PolyNorm(CustomOp): + """Polynomial normalization. + + Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b + where w_n is the learned weight and b is the bias. + Refer to https://arxiv.org/html/2411.03884v1 + """ + + def __init__( + self, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1)) + self.variance_epsilon = eps + + def _norm(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward(). 
+ + Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md + """ + + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = ( + self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + + self.bias + ) + return output.to(orig_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return poly_norm(x, self.weight, self.bias, self.variance_epsilon) + + +class LayerNorm(nn.Module): + """ + Layer Normalization. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm( + x.float(), (self.dim,), self.weight, self.bias, self.eps + ).type_as(x) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 8ffc700ca5cde..e874301b02c05 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from einops import rearrange @@ -7,9 +9,21 @@ from vllm.triton_utils import tl, triton @triton.jit -def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, - d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, - NUM_BLOCK, CBLOCK: tl.constexpr): +def _fwd_diag_kernel( + Q, + K, + V, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + CBLOCK: tl.constexpr, +): # This kernel computes the diagonal blocks of the attention matrix # Each diagonal block represents attention # where queries attend to keys in the same block @@ -37,18 +51,36 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, o_cblock_offset = cblock_offset * e # Calculate pointers to the query, key, value, and output tensors - Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset + - tl.arange(0, CBLOCK)[:, None] * d + - tl.arange(0, d)[None, :]) - K_trans_block_ptr = (K + qk_offset + qk_block_offset + - tl.arange(0, CBLOCK)[None, :] * d + - tl.arange(0, d)[:, None]) - V_block_ptr = (V + v_offset + v_block_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, e)[None, :]) - O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, e)[None, :]) + Q_block_ptr = ( + Q + + qk_offset + + qk_block_offset + + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :] + ) + K_trans_block_ptr = ( + K + + qk_offset + + qk_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, d)[:, None] + ) + V_block_ptr = ( + V + + v_offset + + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :] + ) + O_block_ptr = ( + Out + + o_offset + + o_block_offset + + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :] + ) # Load the decay rate for the current head S_block_ptr = S + off_h @@ -58,9 +90,9 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, q_index = tl.arange(0, CBLOCK) + i * CBLOCK # Load query values - q = tl.load(Q_block_ptr, - mask=block_offset + q_index[:, None] < n, - other=0.0).to(tl.float32) + q = 
tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to( + tl.float32 + ) # Initialize output accumulator qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) @@ -144,18 +176,30 @@ def _fwd_kv_parallel( kv_offset = off_bh * NUM_BLOCK * d * e # Calculate pointers to the key, value, and key-value tensors - K_trans_block_ptr = (K + k_offset + k_block_offset + - tl.arange(0, CBLOCK)[None, :] * d + - tl.arange(0, D_FBLOCK)[:, None]) - V_block_ptr = (V + v_offset + v_block_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) - KV_block_ptr = (KV + kv_offset + kv_block_offset + - tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + K_trans_block_ptr = ( + K + + k_offset + + k_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, D_FBLOCK)[:, None] + ) + V_block_ptr = ( + V + + v_offset + + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) + KV_block_ptr = ( + KV + + kv_offset + + kv_block_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay factors for the current head and block - k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]) + k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :] kv_index = tl.arange(0, CBLOCK) @@ -163,10 +207,7 @@ def _fwd_kv_parallel( kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) # Handle the last block which might be smaller than BLOCK - if off_block == NUM_BLOCK - 1: - split_n = n - (NUM_BLOCK - 1) * BLOCK - else: - split_n = BLOCK + split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK @@ -175,12 +216,16 @@ def _fwd_kv_parallel( for j in range(num_blocks): left_bound = (1 - j) * left_shift # Load key and value, handling boundary conditions - k_trans = tl.load(K_trans_block_ptr - left_shift * d, - mask=kv_index[None, :] >= left_bound, - other=0.0) - v = tl.load(V_block_ptr - left_shift * e, - mask=kv_index[:, None] >= left_bound, - other=0.0) + k_trans = tl.load( + K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0, + ) + v = tl.load( + V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0, + ) # Load decay factor and compute weighted key-value outer product k_decay = tl.load(k_decay_ptr) @@ -196,9 +241,20 @@ def _fwd_kv_parallel( @triton.jit -def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, - d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, - NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr): +def _fwd_kv_reduce( + S, + KV, + KV_HISTORY, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, +): # This kernel reduces the key-value outer products # across blocks and updates the KV history off_bh = tl.program_id(0) # batch-head index @@ -207,8 +263,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, kv_offset = off_bh * NUM_BLOCK * d * e # Calculate pointer to the key-value tensor - KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = ( + KV + + kv_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay rate for 
the current head s_ptrs = S + off_h @@ -216,9 +276,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, # Calculate pointer to the key-value history tensor kv_history_offset = off_bh * d * e - KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset + - tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + KV_HISTORY_block_ptr = ( + KV_HISTORY + + kv_history_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the previous key-value history kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) @@ -281,12 +344,18 @@ def _fwd_none_diag_kernel( kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset # Calculate pointers to the query, output, and key-value tensors - Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + - tl.arange(0, d)[None, :]) - O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) - KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + Q_block_ptr = ( + Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :] + ) + O_block_ptr = ( + Out + + o_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) + KV_block_ptr = ( + KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay rate for the current head S_block_ptr = S + off_h @@ -299,8 +368,7 @@ def _fwd_none_diag_kernel( q_index = block_offset + tl.arange(0, CBLOCK) # Load query values - q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, - other=0.).to(tl.float32) + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32) # Compute decay factors for the current sub-block q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) @@ -309,20 +377,18 @@ def _fwd_none_diag_kernel( qkv_none_diag = tl.dot(q, kv) * q_decay # Load diagonal attention output (computed by _fwd_diag_kernel) - qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, - other=0.).to(tl.float32) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32) # Combine diagonal and non-diagonal attention outputs qkv = qkv_diag + qkv_none_diag # Store the result - tl.store(O_block_ptr, - qkv.to(O_block_ptr.dtype.element_ty), - mask=q_index[:, None] < n) + tl.store( + O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n + ) class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, s, kv_history): # Forward pass of the lightning attention algorithm @@ -334,8 +400,10 @@ class _attention(torch.autograd.Function): # Check CUDA compute capability capability = torch.cuda.get_device_capability() if capability[0] < 8: - raise RuntimeError("Flash attention currently only supported", - "for compute capability >= 80") + raise RuntimeError( + "Flash attention currently only supported", + "for compute capability >= 80", + ) # Get input dimensions b, h, n, d = q.shape @@ -358,19 +426,21 @@ class _attention(torch.autograd.Function): # Step 1: Compute diagonal blocks of attention grid = (b * h * NUM_BLOCK, NUM_CBLOCK) - _fwd_diag_kernel[grid](q, - k, - v, - o, - s, - b, - h, - n, - d, - e, - BLOCK=BLOCK, - NUM_BLOCK=NUM_BLOCK, - CBLOCK=CBLOCK) + _fwd_diag_kernel[grid]( + q, + k, + v, + o, + s, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + CBLOCK=CBLOCK, + ) # Set feature block sizes NUM_FBLOCK = 1 @@ -384,9 +454,7 @@ class 
_attention(torch.autograd.Function): assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" # Step 2: Compute key-value outer products for each block in parallel - kv = torch.empty((b, h, NUM_BLOCK, d, e), - dtype=torch.float32, - device=q.device) + kv = torch.empty((b, h, NUM_BLOCK, d, e), dtype=torch.float32, device=q.device) grid = (b * h, NUM_BLOCK) _fwd_kv_parallel[grid]( k, @@ -410,18 +478,20 @@ class _attention(torch.autograd.Function): # Step 3: Reduce key-value outer products # across blocks and update KV history grid = (b * h, NUM_FBLOCK) - _fwd_kv_reduce[grid](s, - kv, - kv_history, - b, - h, - n, - d, - e, - BLOCK=BLOCK, - NUM_BLOCK=NUM_BLOCK, - D_FBLOCK=D_FBLOCK, - E_FBLOCK=E_FBLOCK) + _fwd_kv_reduce[grid]( + s, + kv, + kv_history, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK, + ) # Step 4: Compute non-diagonal blocks of attention grid = (b * h, NUM_BLOCK * NUM_CBLOCK) @@ -453,11 +523,18 @@ class _attention(torch.autograd.Function): lightning_attention_ = _attention.apply -def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): +def lightning_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ed: torch.Tensor, + block_size: int = 256, + kv_history: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ - Apply lightning attention algorithm + Apply lightning attention algorithm to compute attention efficiently. - + Args: q: Query tensor of shape [batch, heads, seq_len, dim] k: Key tensor of shape [batch, heads, seq_len, dim] @@ -465,7 +542,7 @@ def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): ed: Decay rate tensor of shape [heads] block_size: Size of blocks for block-sparse attention kv_history: Optional key-value history from previous computations - + Returns: output: Attention output kv: Updated key-value history @@ -487,9 +564,9 @@ def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): # Initialize or clone key-value history if kv_history is None: - kv_history = torch.zeros((q.shape[0], q.shape[1], d, e), - dtype=torch.float32, - device=q.device) + kv_history = torch.zeros( + (q.shape[0], q.shape[1], d, e), dtype=torch.float32, device=q.device + ) else: kv_history = kv_history.clone().contiguous() @@ -524,7 +601,7 @@ def _linear_attn_decode_kernel( ): """ Kernel for linear attention decoding with KV cache. - + This kernel computes attention for a single token using the KV cache. """ pid_b = tl.program_id(0) # batch index @@ -547,8 +624,9 @@ def _linear_attn_decode_kernel( # Calculate offsets for dimensions qk_d_offsets = tl.arange(0, D) v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE - cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[ - None, :] * cache_d1_stride + cache_d_offsets = ( + qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride + ) # Calculate offsets for the current batch and head q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride @@ -596,7 +674,7 @@ def linear_decode_forward_triton( ) -> torch.Tensor: """ Perform linear attention decoding using Triton kernels. 
- + Args: q: Query tensor of shape [B, H, 1, D] k: Key tensor of shape [B, H, 1, D] @@ -605,7 +683,7 @@ def linear_decode_forward_triton( slope_rate: Decay rate tensor slot_idx: Slot indices for batches BLOCK_SIZE: Size of blocks for processing - + Returns: output: Attention output tensor """ diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5725c841e5292..63358a0c07d89 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,38 +3,45 @@ import itertools from abc import abstractmethod -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, Union import torch -import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter -from vllm import envs -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm -# yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - BlockQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - PerTensorScaleParameter, - RowvLLMParameter) -# yapf: enable +from vllm.model_executor.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + ModelWeightParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.utils import GiB_bytes logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ + "UnquantizedLinearMethod", "CompressedTensorsLinearMethod", + "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", @@ -59,8 +66,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ def adjust_bitblas_shard(param, shard_size, shard_offset): bitblas_tile_size = getattr(param, "bitblas_tile_size", None) if bitblas_tile_size is not None: - return (shard_size // bitblas_tile_size, - shard_offset // bitblas_tile_size) + return (shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size) return shard_size, shard_offset @@ -73,9 +79,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def adjust_bitsandbytes_4bit_shard(param: Parameter, - shard_offsets: dict[str, tuple[int, int]], - loaded_shard_id: str) -> tuple[int, int]: +def adjust_bitsandbytes_4bit_shard( + param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str +) -> tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" total, _ = shard_offsets["total"] @@ -91,8 +97,8 @@ def adjust_bitsandbytes_4bit_shard(param: Parameter, def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): """For fused modules (QKV and MLP) we have an array of length N that holds 1 scale for each "logical" matrix. 
So the param - is an array of length N. The loaded_weight corresponds to - one of the shards on disk. Here, we slice the param based on + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on the shard_id for loading. """ qkv_idxs = {"q": 0, "k": 1, "v": 2} @@ -119,13 +125,13 @@ def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): For example, given bnb weight attributes as below: { - 'bnb_shard_offsets': array([0, 4, 8, 16]), + 'bnb_shard_offsets': array([0, 4, 8, 16]), 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, } The function will return: { - 'bnb_shard_offsets': array([0, 4]), + 'bnb_shard_offsets': array([0, 4]), 'bnb_quant_state': {0: ...}, } and @@ -140,8 +146,7 @@ def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} quant_state_r = { i - 1: bnb_weight_attrs["bnb_quant_state"][i] - for i in range(1, - len(shard_offsets) - 1) + for i in range(1, len(shard_offsets) - 1) } left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) @@ -152,18 +157,23 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - """Create weights for a linear layer. + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. The weights will be set as attributes of the layer. Args: layer: The layer that is using the LinearMethodBase factory. input_size_per_partition: Size of the weight input dim on rank X. - output_partition_sizes: Sizes of the output dim of each logical + output_partition_sizes: Sizes of the output dim of each logical weight on rank X. E.g., output_partition_sizes for QKVLinear is a list contains the width of Wq, Wk, Wv on rank X. input_size: Size of the input dim of the weight across all ranks. @@ -173,10 +183,12 @@ class LinearMethodBase(QuantizeMethodBase): raise NotImplementedError @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Apply the weights in layer to the input tensor. 
Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -185,45 +197,63 @@ class LinearMethodBase(QuantizeMethodBase): class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + try: + weight_loader = extra_weight_attrs.pop("weight_loader") + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + except torch.cuda.OutOfMemoryError as e: + logger.error("Failed to create unquantized linear weights: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug( + "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes + ) + logger.debug( + "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes + ) + raise RuntimeError( + "Failed to create unquantized linear weights. " + "This may be caused by insufficient memory to allocate " + "the weight." + ) from e + layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: - from vllm.model_executor.layers.utils import check_cpu_sgl_kernel - N, K = layer.weight.size() - dtype = layer.weight.dtype - if check_cpu_sgl_kernel(N, K, dtype): - packed_weight = torch.ops._C.convert_weight_packed( - layer.weight) - assert packed_weight.size() == layer.weight.size() - layer.weight.copy_(packed_weight) - if layer.bias is not None: - layer.bias = Parameter(layer.bias.to(torch.float32), - requires_grad=False) - layer.use_cpu_sgl = True - else: - logger.warning( - "CPU SGL kernels require Intel AMX support," - " bf16/fp16/int8 weight, IC and OC are divisible by " - "32 and 16.") - layer.use_cpu_sgl = False + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + dispatch_cpu_unquantized_gemm(layer, remove_weight=True) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -233,11 +263,12 @@ class LinearBase(CustomOp): Args: input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. - bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. 
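# [Editor's note] Illustrative sketch only -- not part of this diff.
# Rough size of the dense weight allocated by UnquantizedLinearMethod.create_weights
# above, matching its comment "sum(output_partition_sizes) * input_size_per_partition".
# `estimate_weight_bytes` is a hypothetical helper used only for illustration.
import torch

GiB_bytes = 1 << 30  # assumption: mirrors vllm.utils.GiB_bytes

def estimate_weight_bytes(output_partition_sizes, input_size_per_partition, params_dtype):
    numel = sum(output_partition_sizes) * input_size_per_partition
    return numel * torch.empty((), dtype=params_dtype).element_size()

# Example: a fused projection with output_partition_sizes=[14336, 14336],
# input_size_per_partition=4096, bf16 -> ~0.22 GiB for this parameter alone.
print(estimate_weight_bytes([14336, 14336], 4096, torch.bfloat16) / GiB_bytes)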
quant_config: Quantization configure. + prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, tensor parallelism will be disabled for this layer. """ def __init__( @@ -250,6 +281,7 @@ class LinearBase(CustomOp): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): super().__init__() @@ -263,12 +295,19 @@ class LinearBase(CustomOp): self.quant_config = quant_config self.prefix = prefix if quant_config is None: - self.quant_method: Optional[ - QuantizeMethodBase] = UnquantizedLinearMethod() + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() else: - self.quant_method = quant_config.get_quant_method(self, - prefix=prefix) + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.disable_tp = disable_tp + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 + + def update_param_tp_status(self): + for param in self.parameters(): + if isinstance(param, BasevLLMParameter): + param.tp_rank = self.tp_rank + param.tp_size = self.tp_size @CustomOp.register("replicated_linear") @@ -285,6 +324,7 @@ class ReplicatedLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: Take no effect for replicated linear layers. """ def __init__( @@ -298,6 +338,7 @@ class ReplicatedLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # If MergedReplicatedLinear, use output size of each partition. if hasattr(self, "output_sizes"): @@ -305,31 +346,40 @@ class ReplicatedLinear(LinearBase): else: self.output_partition_sizes = [output_size] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) # All the linear layer supports quant method. 
assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - self.output_partition_sizes, - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + self.input_size, + self.output_partition_sizes, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader, + ) if bias: self.bias = Parameter( - torch.empty(self.output_size, dtype=self.params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.output_size, dtype=self.params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) @@ -352,16 +402,20 @@ class ReplicatedLinear(LinearBase): assert param.size() == loaded_weight.size(), ( f"Tried to load weights of size {loaded_weight.size()}" - f"to a parameter of size {param.size()}") + f"to a parameter of size {param.size()}" + ) param.data.copy_(loaded_weight) def forward( - self, x: torch.Tensor + self, + x: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: return output return output, output_bias @@ -373,73 +427,6 @@ class ReplicatedLinear(LinearBase): return s -class MergedReplicatedLinear(ReplicatedLinear): - """Replicated linear layer. - - Args: - input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. 
model.layers.0.qkv_proj) - """ - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias) - - def weight_loader(self, - param: Union[Parameter, BasevLLMParameter], - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): - assert loaded_shard_id is not None - assert loaded_shard_id < len(self.output_sizes) - - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) - assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n) - elif isinstance(param, PerTensorScaleParameter): - shard_offset = loaded_shard_id - shard_size = 1 - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - - param.data[shard_offset:shard_offset + shard_size] = loaded_weight - - @CustomOp.register("column_parallel_linear") class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -462,7 +449,9 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -478,26 +467,30 @@ class ColumnParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the last dimension. - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. 
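# [Editor's note] Illustrative sketch only -- not part of this diff.
# How ColumnParallelLinear partitions the output dimension: each TP rank owns a
# contiguous row-slice of the full [output_size, input_size] weight, and the
# weight loader narrows the checkpoint tensor to that slice (cf. the
# `loaded_weight.narrow(output_dim, start_idx, shard_size)` call below).
# `shard_column_parallel` is a hypothetical helper for illustration.
import torch

def shard_column_parallel(full_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    output_size, _ = full_weight.shape
    assert output_size % tp_size == 0
    shard_size = output_size // tp_size            # output_size_per_partition
    start_idx = tp_rank * shard_size
    return full_weight.narrow(0, start_idx, shard_size)  # output_dim == 0

full = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)   # [output=8, input=4]
print(shard_column_parallel(full, tp_rank=1, tp_size=2).shape)  # torch.Size([4, 4])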
if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - divide(output_size, self.tp_size) - for output_size in self.output_sizes + divide(output_size, self.tp_size) for output_size in self.output_sizes ] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) self.gather_output = gather_output @@ -513,23 +506,27 @@ class ColumnParallelLinear(LinearBase): output_size=self.output_size, params_dtype=self.params_dtype, weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) if bias: self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) - - self.tp_rank = get_tensor_model_parallel_rank() + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - output_dim = getattr(param, "output_dim", None) is_sharded_weight = getattr(param, "is_sharded_weight", False) @@ -549,16 +546,14 @@ class ColumnParallelLinear(LinearBase): final_shape = list(loaded_weight.shape) if output_dim is not None: assert final_shape[output_dim] % self.tp_size == 0 - final_shape[output_dim] = (final_shape[output_dim] // - self.tp_size) + final_shape[output_dim] = final_shape[output_dim] // self.tp_size param.materialize(final_shape, dtype=loaded_weight.dtype) param_data = param.data if output_dim is not None and not is_sharded_weight: shard_size = param_data.shape[output_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -568,7 +563,7 @@ class ColumnParallelLinear(LinearBase): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: @@ -577,14 +572,16 @@ class ColumnParallelLinear(LinearBase): param.load_column_parallel_weight(loaded_weight=loaded_weight) def forward( - self, input_ + self, + input_, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: + + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. 
output = tensor_model_parallel_all_gather(output_parallel) else: @@ -598,7 +595,7 @@ class ColumnParallelLinear(LinearBase): s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -625,6 +622,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, all weights matrix won't be sharded, this layer + will be treated as a "Replicated" MergedLinear. """ def __init__( @@ -639,28 +638,32 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.output_sizes = output_sizes - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 - assert all(output_size % self.tp_size == 0 - for output_size in output_sizes) - super().__init__(input_size=input_size, - output_size=sum(output_sizes), - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - return_bias=return_bias) - - def weight_loader(self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + assert all(output_size % self.tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -671,20 +674,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: param.shard_weight_type = { - i: loaded_weight.item() - for i, _ in enumerate(self.output_sizes) + i: loaded_weight.item() for i, _ in enumerate(self.output_sizes) } return if is_gguf_weight: - output_dim = getattr(param, "output_dim", None) shard_size = loaded_weight.size(output_dim) // self.tp_size start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) @@ -701,14 +701,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0) + param_data, loaded_weight, 0 + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return current_shard_offset = 0 - use_bitsandbytes_4bit = getattr(param, 
"use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) shard_offsets: list[tuple[int, int, int]] = [] for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) @@ -719,14 +719,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) shard_size, shard_offset = adjust_bitblas_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) if use_bitsandbytes_4bit: index = list(itertools.accumulate([0] + self.output_sizes)) @@ -736,33 +738,35 @@ class MergedColumnParallelLinear(ColumnParallelLinear): } orig_offsets["total"] = (self.output_size, 0) shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_offsets, str(shard_id)) + param, orig_offsets, str(shard_id) + ) loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size) + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return assert loaded_shard_id < len(self.output_sizes) if output_dim is not None: - shard_offset = (sum(self.output_sizes[:loaded_shard_id]) // - self.tp_size) + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) shard_size, shard_offset = adjust_bitblas_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) # bitsandbytes loads the weights of the specific portion # no need to narrow @@ -770,19 +774,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for per-tensor scales in fused case. 
elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id + ) else: ignore_warning = getattr(param, "ignore_warning", False) @@ -790,17 +792,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear): logger.warning( "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") + "the same for all partitions." + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): """ Handle special case for models where MLP layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: @@ -817,25 +821,28 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: - shard_size, shard_offset = \ - param.adjust_shard_indexes_for_packing( - shard_size=shard_size, shard_offset=shard_offset) + if ( + isinstance(param, (PackedColumnParameter, PackedvLLMParameter)) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) - loaded_weight_shard = loaded_weight.narrow(param.output_dim, - shard_offset, - shard_size) + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) - def weight_loader_v2(self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=0) + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) @@ -846,30 +853,32 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // tp_size - 
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // tp_size) + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n + ) // self.tp_size + shard_size = ( + (self.output_sizes[loaded_shard_id] + block_n - 1) + // block_n + // self.tp_size + ) else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size) + param.load_merged_column_weight( + loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) class QKVParallelLinear(ColumnParallelLinear): @@ -897,6 +906,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -912,6 +922,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -920,40 +931,43 @@ class QKVParallelLinear(ColumnParallelLinear): total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() + tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(tp_size, - self.total_num_kv_heads) + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 input_size = self.hidden_size - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) self.output_sizes = [ self.num_heads * self.head_size * tp_size, # q_proj self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - return_bias=return_bias) + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { "q": 0, "k": self.num_heads * self.head_size, "v": (self.num_heads + self.num_kv_heads) * self.head_size, - "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, } return shard_offset_mapping.get(loaded_shard_id) @@ -965,12 +979,13 @@ class 
QKVParallelLinear(ColumnParallelLinear): } return shard_size_mapping.get(loaded_shard_id) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): """ - Handle special case for models where QKV layers are already + Handle special case for models where QKV layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: @@ -979,38 +994,49 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offsets = [ # (shard_id, shard_offset, shard_size) ("q", 0, self.total_num_heads * self.head_size), - ("k", self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - ("v", - (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), ] for shard_id, shard_offset, shard_size in shard_offsets: # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: - shard_size, shard_offset = \ - param.adjust_shard_indexes_for_packing( - shard_size=shard_size, shard_offset=shard_offset) + if ( + isinstance(param, (PackedColumnParameter, PackedvLLMParameter)) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) - loaded_weight_shard = loaded_weight.narrow(param.output_dim, - shard_offset, - shard_size) + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) - def weight_loader_v2(self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + param.load_qkv_weight( + loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank + ) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_qkv_weight(loaded_weight=loaded_weight) + param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -1024,23 +1050,29 @@ class QKVParallelLinear(ColumnParallelLinear): # Note(simon): This is needed for Qwen3's fp8 quantization. 
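# [Editor's note] Illustrative sketch only -- not part of this diff.
# Per-rank head partitioning in QKVParallelLinear: query heads are split across
# TP ranks; KV heads are split too unless tp_size >= total_num_kv_heads, in which
# case each rank keeps one KV head and heads are replicated.  The fused per-rank
# output is laid out as [q | k | v], which is what _get_shard_offset_mapping above
# encodes.  `qkv_layout` is a hypothetical helper for illustration.
def qkv_layout(total_num_heads, total_num_kv_heads, head_size, tp_size):
    num_heads = total_num_heads // tp_size
    if tp_size >= total_num_kv_heads:
        num_kv_heads, num_kv_head_replicas = 1, tp_size // total_num_kv_heads
    else:
        num_kv_heads, num_kv_head_replicas = total_num_kv_heads // tp_size, 1
    offsets = {
        "q": 0,
        "k": num_heads * head_size,
        "v": (num_heads + num_kv_heads) * head_size,
        "total": (num_heads + 2 * num_kv_heads) * head_size,
    }
    return num_heads, num_kv_heads, num_kv_head_replicas, offsets

# e.g. 32 query heads, 8 KV heads, head_size 128, tp_size 4:
print(qkv_layout(32, 8, 128, 4))  # (8, 2, 1, {'q': 0, 'k': 1024, 'v': 1280, 'total': 1536})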
if isinstance(param, BlockQuantScaleParameter): assert self.quant_method is not None - assert hasattr(self.quant_method, "quant_config") - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size + assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = (shard_offset + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n - param.load_qkv_weight(loaded_weight=loaded_weight, - num_heads=self.num_kv_head_replicas, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size) - - def weight_loader(self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1051,10 +1083,7 @@ class QKVParallelLinear(ColumnParallelLinear): param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: - param.shard_weight_type = { - k: loaded_weight.item() - for k in idx_map - } + param.shard_weight_type = {k: loaded_weight.item() for k in idx_map} return if is_gguf_weight: @@ -1063,8 +1092,7 @@ class QKVParallelLinear(ColumnParallelLinear): start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) @@ -1082,7 +1110,8 @@ class QKVParallelLinear(ColumnParallelLinear): if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0) + param_data, loaded_weight, 0 + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1090,13 +1119,18 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offsets = [ # (shard_id, shard_offset, shard_size) ("q", 0, self.total_num_heads * self.head_size), - ("k", self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - ("v", (self.total_num_heads + self.total_num_kv_heads) * - self.head_size, self.total_num_kv_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), ] - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: @@ -1104,32 +1138,40 @@ class QKVParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. 
if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.total_num_heads * self.head_size), - "k": (self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - "v": - ((self.total_num_heads + self.total_num_kv_heads) * - self.head_size, - self.total_num_kv_heads * self.head_size), - "total": - ((self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_size, 0) + "k": ( + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + "v": ( + (self.total_num_heads + self.total_num_kv_heads) + * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + "total": ( + (self.total_num_heads + 2 * self.total_num_kv_heads) + * self.head_size, + 0, + ), } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_qkv_offsets, shard_id) + param, orig_qkv_offsets, shard_id + ) loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size) + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return @@ -1144,23 +1186,22 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": - shard_offset = (self.num_heads + - self.num_kv_heads) * self.head_size + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) # bitsandbytes loads the weights of the specific portion # no need to narrow @@ -1169,41 +1210,46 @@ class QKVParallelLinear(ColumnParallelLinear): if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), - "k": (self.num_heads * self.head_size, - self.num_kv_heads * self.head_size), - "v": - ((self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size), - "total": - ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, - 0) + "k": ( + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "v": ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size, + ), + "total": ( + (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0, + ), } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_qkv_offsets, loaded_shard_id) + param, orig_qkv_offsets, loaded_shard_id + ) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = self.tp_rank + shard_rank = self.tp_rank else: - shard_id = self.tp_rank // self.num_kv_head_replicas - start_idx = shard_id * shard_size + shard_rank = self.tp_rank // self.num_kv_head_replicas + start_idx = shard_rank * shard_size if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id + ) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: logger.warning( "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " - "for all partitions.") + "for all partitions." + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1240,6 +1286,7 @@ class RowParallelLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.down_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -1255,21 +1302,25 @@ class RowParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the first dimension. 
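# [Editor's note] Illustrative sketch only -- not part of this diff.
# The math behind RowParallelLinear: the input (last) dimension of the weight is
# split across TP ranks, every rank produces a partial output from its slice of
# the activation, and the partials are summed (an all-reduce in the real layer).
# Bias is fused only on rank 0 so it is not added tp_size times -- see the
# `bias_ = None if (self.tp_rank > 0 or ...)` guard further below.
# `row_parallel_reference` is a hypothetical helper for illustration.
import torch

def row_parallel_reference(x, weight, bias, tp_size):
    # weight: [output_size, input_size]; x: [batch, input_size]
    x_shards = x.chunk(tp_size, dim=-1)
    w_shards = weight.chunk(tp_size, dim=-1)   # split along the input dim
    partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
    out = sum(partials)                        # stands in for tensor_model_parallel_all_reduce
    return out + bias if bias is not None else out

x = torch.randn(2, 8)
w = torch.randn(3, 8)
b = torch.randn(3)
assert torch.allclose(row_parallel_reference(x, w, b, tp_size=4), x @ w.t() + b, atol=1e-5)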
- self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1283,21 +1334,29 @@ class RowParallelLinear(LinearBase): output_size=self.output_size, params_dtype=self.params_dtype, weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") + raise ValueError( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) if bias: - self.bias = Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): input_dim = getattr(param, "input_dim", None) @@ -1317,16 +1376,14 @@ class RowParallelLinear(LinearBase): if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) if input_dim: - weight_shape[input_dim] = (weight_shape[input_dim] // - self.tp_size) + weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None and not is_sharded_weight: shard_size = param_data.shape[input_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -1336,9 +1393,7 @@ class RowParallelLinear(LinearBase): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): - + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
if len(loaded_weight.shape) == 0: @@ -1348,24 +1403,24 @@ class RowParallelLinear(LinearBase): param.load_row_parallel_weight(loaded_weight=loaded_weight) def forward( - self, input_ + self, + input_, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_, num_partitions=self.tp_size + ) + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. assert self.quant_method is not None # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) + output_parallel = self.quant_method.apply(self, input_parallel, bias_) + if self.reduce_results and self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) else: @@ -1378,225 +1433,9 @@ class RowParallelLinear(LinearBase): return output, output_bias def extra_repr(self) -> str: - s = f"input_features={self.input_size_per_partition}" + s = f"in_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s - - -@CustomOp.register("qkv_cross_parallel_linear") -class QKVCrossParallelLinear(LinearBase): - """Linear layers for efficient cross-attention's QKV transformation. - - Args: - hidden_size: input hidden state size of the transformer. - head_size: size of each attention head. - total_num_heads: total number of attention query heads. - total_num_kv_heads: total number of attention key/value heads. If - None, assume total_num_kv_heads = total_num_heads. - bias: If true, add bias. - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) - """ - - def __init__(self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - # input_size and output_size are not used, just for alignment - input_size = hidden_size - output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size - super().__init__(input_size=input_size, - output_size=output_size, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix) - - self.quant_config = quant_config - - # Empty placeholders for loading as a single module. - placeholder_size = 0 - assert self.quant_method is not None - self.quant_method.create_weights(self, - placeholder_size, [placeholder_size], - placeholder_size, - placeholder_size, - self.params_dtype, - weight_loader=self.weight_loader) - - # Use a dictionary to avoid submodules parameters auto-registration: - # drop-in replacement for a `QKVParallelLinear` module. 
- self.proj = dict() - self.proj["q_proj_decoder"] = ColumnParallelLinear( - input_size=hidden_size, - output_size=total_num_heads * head_size, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.q_proj_decoder") - - self.proj["kv_proj_encoder"] = QKVParallelLinear( - hidden_size=hidden_size, - head_size=head_size, - total_num_heads=0, - total_num_kv_heads=total_num_kv_heads, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.kv_proj_encoder") - - # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. - self.q_size = self.q_proj_decoder.output_size_per_partition - self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size - - if bias: - self.bias = torch.nn.Parameter() - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.bias = None - - def process_weights_after_loading(self): - for layer in self.proj.values(): - if self.quant_method is not None: - self.quant_method.process_weights_after_loading(layer) - - @property - def q_proj_decoder(self) -> ColumnParallelLinear: - layer = self.proj["q_proj_decoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, - target_param, - mode="q_proj_decoder") - return layer - - @property - def kv_proj_encoder(self) -> QKVParallelLinear: - layer = self.proj["kv_proj_encoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, - target_param, - mode="kv_proj_encoder") - return layer - - def sync_weight_attrs( - self, - src_param: nn.Parameter, - tgt_param: nn.Parameter, - mode: Literal["q_proj_decoder", "kv_proj_encoder"], - ): - missing_attrs_dict = { - k: getattr(src_param, k) - for k in (set(vars(src_param).keys()) - - set(vars(tgt_param).keys())) - } - # TODO(Isotr0py): handle bitsandbytes 8bit - use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", - False) - if (missing_attrs_dict and use_bitsandbytes_4bit): - q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( - missing_attrs_dict) - if mode == "q_proj_decoder": - set_weight_attrs(tgt_param, q_proj_attrs) - elif mode == "kv_proj_encoder": - set_weight_attrs(tgt_param, kv_proj_attrs) - else: - set_weight_attrs(tgt_param, missing_attrs_dict) - - def _is_same_param( - self, - src_param: torch.nn.Parameter, - map_param: torch.nn.Parameter, - ) -> bool: - """Check if two parameters are exactly pointing to same things.""" - # ignore weight_loader because it's always different - key_to_ignore = ["weight_loader", "_weight_loader"] - has_same_type_name = type(src_param) is type(map_param) - src_param_attrs = { - k: v - for k, v in src_param.__dict__.items() if k not in key_to_ignore - } - map_param_attrs = { - k: v - for k, v in map_param.__dict__.items() if k not in key_to_ignore - } - has_same_attrs = src_param_attrs == map_param_attrs - return has_same_type_name and has_same_attrs - - def select_proj_params( - self, - layer: nn.Module, - param: nn.Parameter, - ) -> nn.Parameter: - """ - Given the placeholder param, - return the corresponding param in the proj layers. 
- """ - target_param_list = [ - v for _, v in layer.named_parameters() - if self._is_same_param(param, v) - ] - assert len(target_param_list) == 1 - target_param = target_param_list[0] - return target_param - - def forward( # type: ignore[override] - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - q, _ = self.q_proj_decoder(decoder_hidden_states) - if encoder_hidden_states is None: - # Encoder KV already cached. - k = None - v = None - else: - # Prefill phase, encoder KV cached here. - kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) - # Split kv in half - k, v = kv_enc.split(self.kv_size, dim=-1) - return q, k, v - - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - layer = (self.q_proj_decoder - if loaded_shard_id == "q" else self.kv_proj_encoder) - target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () - if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED: - layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args) - else: - layer.weight_loader(target_param, loaded_weight, *shard_id_args) - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", q_size={self.q_size}" - s += f", kv_size={self.kv_size}" - s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" - s += ", gather_output=False" - return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e93be9bfb1657..3db5e0b325538 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,28 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" -import inspect -from concurrent.futures import ThreadPoolExecutor + from typing import Optional import torch -import torch.nn as nn -import vllm.envs as envs -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_gather) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_gather, +) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform -_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None -if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: - _logits_processor_threadpool = ThreadPoolExecutor( - envs.VLLM_LOGITS_PROCESSOR_THREADS) - -class LogitsProcessor(nn.Module): +@CustomOp.register("logits_processor") +class LogitsProcessor(CustomOp): """Process logits and apply logits processors from sampling metadata. This layer does the following: @@ -31,12 +25,14 @@ class LogitsProcessor(nn.Module): 3. Apply logits processors (if any). 
""" - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None) -> None: + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, + ) -> None: """ Args: scale: A scaling factor to apply to the logits. @@ -57,17 +53,11 @@ class LogitsProcessor(nn.Module): self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, - prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None and prune_hidden_states: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: @@ -78,12 +68,6 @@ class LogitsProcessor(nn.Module): if self.scale != 1.0: logits *= self.scale - - # Apply logits processors (if any). - if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - return logits def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: @@ -107,16 +91,14 @@ class LogitsProcessor(nn.Module): embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. - logits = lm_head.quant_method.apply(lm_head, - hidden_states, - bias=embedding_bias) + logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) # Gather logits for TP logits = self._gather_logits(logits) # Remove paddings in vocab (if any). if logits is not None: - logits = logits[..., :self.org_vocab_size] + logits = logits[..., : self.org_vocab_size] return logits def extra_repr(self) -> str: @@ -124,75 +106,3 @@ class LogitsProcessor(nn.Module): s += f", org_vocab_size={self.org_vocab_size}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" return s - - -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios - # (warmup, profile_run) we might not have selected_token_indices, - # so we skip pruning. 
- if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) - else: - return hidden_states - - -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - found_logits_processors = False - logits_processed = 0 - logits_row_ids_and_logits_row_futures = [] - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - if _logits_processor_threadpool is not None: - logits_row_ids_and_logits_row_futures.append( - (logits_row_idx, - _logits_processor_threadpool.submit( - _apply_logits_processors_single_seq, logits_row, - logits_processors, past_tokens_ids, - prompt_tokens_ids))) - else: - logits[logits_row_idx] = \ - _apply_logits_processors_single_seq( - logits_row, logits_processors, past_tokens_ids, - prompt_tokens_ids) - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - for logits_row_idx, future in logits_row_ids_and_logits_row_futures: - logits[logits_row_idx] = future.result() - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits - - -def _apply_logits_processors_single_seq(logits_row, logits_processors, - past_tokens_ids, - prompt_tokens_ids) -> torch.Tensor: - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, logits_row) - return logits_row diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f771..6da62b5426bb6 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -class MambaBase(ABC): +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class MambaBase(AttentionLayerBase): """ Base class for Mamba-like layers which support the v1 engine. Inherit from this class if you implement a custom layer. @@ -14,10 +20,7 @@ class MambaBase(ABC): # Contains the KV cache (mamba state) for the layer # in the shape specified by `self.get_state_shape`. - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - kv_cache: list[Iterable[torch.Tensor]] + kv_cache: tuple[torch.Tensor, ...] 
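The kv_cache attribute above becomes a flat (conv_state, ssm_state) pair instead of the v0-era list indexed by the PP virtual engine. A minimal sketch of that layout; the shapes and sizes below are illustrative assumptions, not vLLM's actual cache spec:

```python
# Minimal sketch (not vLLM API): an illustrative per-layer (conv_state, ssm_state) tuple.
import torch

num_blocks, dim, kernel_width, num_heads, head_dim, state_size = 8, 256, 4, 4, 64, 16

conv_state = torch.zeros(num_blocks, dim, kernel_width - 1)            # rolling conv window
ssm_state = torch.zeros(num_blocks, num_heads, head_dim, state_size)   # recurrent SSM state

kv_cache: tuple[torch.Tensor, ...] = (conv_state, ssm_state)
conv, ssm = kv_cache  # a layer unpacks both states from the same tuple
```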
     @abstractmethod
     def get_state_shape(self) -> Iterable[tuple[int, ...]]:
@@ -32,3 +35,8 @@
     @abstractmethod
     def mamba_type(self) -> str:
         pass
+
+    @abstractmethod
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        """Get the attention backend class for this Mamba layer."""
+        pass
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
new file mode 100644
index 0000000000000..99f05e2eca0e8
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -0,0 +1,406 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.lightning_attn import (
+    lightning_attention,
+    linear_decode_forward_triton,
+)
+from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateDtypeCalculator,
+    MambaStateShapeCalculator,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.utils import direct_register_custom_op
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
+
+
+class MiniMaxText01RMSNormTP(CustomOp):
+    name = "MiniMaxText01RMSNormTP"
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.tp_world = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
+
+        self.weight.weight_loader = self.weight_loader
+        self.variance_epsilon = eps
+        return
+
+    @staticmethod
+    def weight_loader(
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+    ) -> None:
+        tp_world = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
+        shard_size = loaded_weight.shape[0] // tp_world
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        param.data.copy_(loaded_weight[shard])
+        return
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
+        if self.tp_world > 1:
+            variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        return x
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert residual is None, "RMSNorm does not support
residual connection." + return self._forward(x) + + +class MiniMaxText01LinearKernel: + @staticmethod + def jit_linear_forward_prefix( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + slope_rate = slope_rate.to(torch.float32) + should_pad_dim = q.dim() == 3 + if should_pad_dim: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + b, h, n, d = q.shape + e = d + kv_history = kv_caches.reshape(1, h, d, e).contiguous() + output, kv_history = lightning_attention( + q, k, v, slope_rate, block_size=block_size, kv_history=kv_history + ) + kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) + assert output.shape[0] == 1, "batch size must be 1" + return rearrange(output.squeeze(0), "h n d -> n (h d)") + + +class MiniMaxText01LinearAttention(nn.Module, MambaBase): + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend + + return LinearAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim + ) + + def __init__( + self, + hidden_size: int, + hidden_inner_size: int, + num_heads: int, + head_dim: int, + max_position: int, + block_size: int, + num_hidden_layer: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + layer_idx: int = 0, + linear_layer_idx: int = 0, + prefix: str = "linear_attn", + ) -> None: + super().__init__() + + self.layer_idx = layer_idx + self.BLOCK = block_size + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads + self.hidden_inner_size = hidden_inner_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + self.qkv_size = self.num_heads * self.head_dim + self.tp_hidden = self.head_dim * self.tp_heads + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + self.qkv_proj = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size * 3, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.output_gate = ColumnParallelLinear( + hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_gate", + ) + self.out_proj = RowParallelLinear( + self.hidden_inner_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.norm = MiniMaxText01RMSNormTP( + self.hidden_inner_size, + eps=1e-5, + ) + + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads) + if num_hidden_layer <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * ( + 1 - layer_idx / (num_hidden_layer - 1) + 1e-5 + ) + self.tp_slope = 
self.slope_rate[ + self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads + ].contiguous() + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + @staticmethod + def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + return + + @staticmethod + def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = torch.tensor( + get_slopes(n_attention_heads), dtype=torch.float32 + ).reshape(n_attention_heads, 1, 1) + return slopes + + def _prefill_and_mix_infer( + self, q, k, v, kv_cache, state_indices_tensor, attn_metadata + ): + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break + offset = attn_metadata.num_decode_tokens + _start = attn_metadata.query_start_loc[offset + _prefill_idx] + _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] + slot_id = state_indices_tensor[offset + _prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slice_layer_cache = kv_cache[slot_id, ...] 
+ + out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( + qs, + ks, + vs, + slice_layer_cache, + self.tp_slope, + self.BLOCK, + layer_idx=self.layer_idx, + ) + hidden.append(out_slice.contiguous()) + if attn_metadata.num_decode_tokens > 0: + hidden_decode = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + hidden.insert(0, hidden_decode) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): + q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[: attn_metadata.num_decodes] + hidden = linear_decode_forward_triton( + q, k, v, kv_cache, self.tp_slope, slot_id, 32 + ) + return hidden + + def forward( + self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor + ) -> None: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward( + self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor + ) -> None: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) + qkv32 = qkv.to(torch.float32) + qkvact = torch.nn.functional.silu(qkv32) + qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) + q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor + + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills > 0: + num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx + ] + q_end = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx + 1 + ] + query_len = q_end - q_start + context_len = ( + attn_metadata.seq_lens[num_decode_tokens + prefill_idx] + - query_len + ) + if context_len == 0: + block_to_clear = state_indices_tensor[ + num_decode_tokens + prefill_idx + ] + kv_cache[block_to_clear, ...] 
= 0 + + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if attn_metadata is None: + hidden = torch.empty( + (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype + ) + else: + if not decode_only: + hidden = self._prefill_and_mix_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + else: + hidden = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + hidden = self.norm._forward(hidden) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) + hidden = F.sigmoid(gate) * hidden + hidden = hidden.to(hidden_states.dtype) + + output[:num_actual_tokens], _ = self.out_proj(hidden) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output, positions=positions) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, +) diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py deleted file mode 100644 index 3256ac034aa11..0000000000000 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ /dev/null @@ -1,186 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Optional, Union - -import numpy as np -import torch - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionMetadata) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba2_attn import ( - Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) - - -@dataclass -class Mamba2Metadata: - - has_initial_states: torch.Tensor - prep_initial_states: bool - - chunk_size: int - seq_idx: torch.Tensor - chunk_indices: torch.Tensor - chunk_offsets: torch.Tensor - """ - With continuous batching layout of `x` in vLLM, to enable a Triton program - to handle a request in parallel, two supporting tensors are used - (batch_ptr, token_chunk_offset_ptr) - BLOCK_M = the # tokens to be handled by a Triton program - (can be customized for different hardware) - - nums_dict: - tracks the data associated with a given value of BLOCK_M - BLOCK_M = #tokens handled by a Triton program - cu_seqlen: total tokens per batch - (used as flag to update other data at each new input) - batch_ptr: tracks batch-id handled by the Triton program - token_chunk_offset_ptr: tracks token group_idx handled by the Triton program - (Triton implementation of causal_conv1d handles parallelism in 3-axes - - feature-axis - - batch-axis - - sequence-axis) - """ - nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None - - -def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: - """Returns the appropriate metadata classes for the current platform.""" - if current_platform.is_rocm(): - from 
vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata) - return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) - elif current_platform.is_cuda(): - from vllm.attention.backends.flash_attn import FlashAttentionMetadata - from vllm.attention.backends.xformers import XFormersMetadata - return (FlashAttentionMetadata, XFormersMetadata, - PlaceholderAttentionMetadata) - raise ValueError( - f"Unsupported platform for Mamba2: {current_platform.device_type}") - - -def prepare_mamba2_metadata( - chunk_size: int, - attn_metadata: AttentionMetadata, - mamba2_metadata=None, -) -> Mamba2Metadata: - - # compute number of prefill and decode requests - # NOTE: in V0 we assume prefills are before decodes - num_prefills = attn_metadata.num_prefills - num_prefill_tokens = attn_metadata.num_prefill_tokens - - seq_idx = None - chunk_indices, chunk_offsets = None, None - # Need flags to indicate if there are initial states - # currently we really only support the FlashAttention backend - has_initial_states = None - prep_initial_states = False - - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if num_prefills > 0: - attn_metadata_instances = get_platform_metadata_classes() - if (isinstance(attn_metadata, attn_metadata_instances) - and attn_metadata.context_lens_tensor is not None): - # precompute flag to avoid device syncs later in mamba2 layer - # forwards - # prep is only needed for mamba2 ssd prefill processing - has_initial_states = attn_metadata.context_lens_tensor > 0 - prep_initial_states = torch.any( - has_initial_states[:num_prefills]).item() - query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] - seq_idx = torch.repeat_interleave(torch.arange( - num_prefills, dtype=torch.int32, device=query_start_loc.device), - query_start_loc.diff(), - output_size=num_prefill_tokens) - seq_idx.unsqueeze_(0) - - # We compute metadata for chunked prefill once at the top level model - # forward and reuse them in mamba layers. If not needed, they will be - # ignored inside mamba kernels. - if prep_initial_states: - chunk_indices, chunk_offsets = \ - _query_start_loc_to_chunk_indices_offsets( - query_start_loc, chunk_size, num_prefill_tokens) - - if mamba2_metadata is not None: - mamba2_metadata.has_initial_states = has_initial_states - mamba2_metadata.prep_initial_states = prep_initial_states - mamba2_metadata.chunk_size = chunk_size - mamba2_metadata.seq_idx = seq_idx - mamba2_metadata.chunk_indices = chunk_indices - mamba2_metadata.chunk_offsets = chunk_offsets - # We use 1 reset flag: - # * mamba2_metadata.cu_seqlen is None - # update config specific to (each input) - # (become available at first layer, e.g. 
conv_weights) - mamba2_metadata.cu_seqlen = None # suppose to be updated at each input - - return mamba2_metadata - return Mamba2Metadata(has_initial_states=has_initial_states, - prep_initial_states=prep_initial_states, - chunk_size=chunk_size, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets) - - -def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, - mamba2_metadata: Union[Mamba2Metadata, - Mamba2AttentionMetadata]): - """ - this is triggered upon handling a new input at the first layer - """ - dim, cu_seqlen = x.shape - mamba2_metadata.cu_seqlen = cu_seqlen - seqlens = np.diff(query_start_loc.to('cpu')) - nums_dict = {} # type: ignore - for BLOCK_M in [8]: # cover all BLOCK_M values - nums = -(-seqlens // BLOCK_M) - nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() - mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len - MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 - offsetlist = [] # type: ignore - for idx, num in enumerate(nums): - offsetlist.extend(range(num)) - offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist - - if mamba2_metadata.batch_ptr is None: - # Update default value after class definition - #mamba2_metadata.MAX_NUM_PROGRAMS *= 2 - mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - mamba2_metadata.token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device='cuda') - else: - if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS: - mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_( - PAD_SLOT_ID) - mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) - - mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist) - mamba2_metadata.token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr - nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = ( - mamba2_metadata.token_chunk_offset_ptr) # type: ignore - mamba2_metadata.nums_dict = nums_dict - return mamba2_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a24e72778b34b..8ab77965ae80a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,33 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn from torch.nn.parameter import Parameter -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm -from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) -from vllm.model_executor.models.mamba_cache import MambaCacheParams + selective_scan_fn, + selective_state_update, +) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -45,22 +54,24 @@ class MambaMixer(MambaBase, CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - time_step_rank: int, - use_conv_bias: bool, - use_bias: bool, - use_rms_norm: bool, - rms_norm_has_weight: bool = True, - rms_norm_eps: float = 1e-5, - activation="silu", - is_lora_enabled: bool = False, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + rms_norm_has_weight: bool = True, + rms_norm_eps: float = 1e-5, + activation="silu", + is_lora_enabled: bool = False, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size @@ -81,9 +92,9 @@ class MambaMixer(MambaBase, CustomOp): # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(hidden_size, - [intermediate_size] * 2, - bias=use_bias) + self.in_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=use_bias + ) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( @@ -94,17 +105,18 @@ class MambaMixer(MambaBase, CustomOp): # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. 
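The comment above is why dt_proj is constructed with skip_bias_add=True below: the projection runs without its bias, and the bias is deferred so it can be handed to the selective-scan kernel instead. A hedged sketch of the same idea using a plain torch Linear as a stand-in (names and sizes here are illustrative, not vLLM's parallel-linear API):

```python
# Sketch of deferring a projection's bias so a downstream kernel can add it.
import torch
import torch.nn as nn

time_step_rank, intermediate_size = 8, 32
dt_proj = nn.Linear(time_step_rank, intermediate_size, bias=True)

time_step = torch.randn(4, time_step_rank)
dt_no_bias = nn.functional.linear(time_step, dt_proj.weight)  # bias deliberately skipped
deferred_bias = dt_proj.bias  # handed to the scan kernel later, bias added there
```

This mirrors the time_proj_bias that the surrounding hunks pass to selective_scan_fn and selective_state_update.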
- self.dt_proj = ColumnParallelLinear(time_step_rank, - intermediate_size, - bias=True, - skip_bias_add=True) + self.dt_proj = ColumnParallelLinear( + time_step_rank, intermediate_size, bias=True, skip_bias_add=True + ) def weight_loader(param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[ + tp_rank + ] + ) def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): weight_loader(param, -torch.exp(loaded_weight.float())) @@ -115,7 +127,8 @@ class MambaMixer(MambaBase, CustomOp): intermediate_size // tp_size, ssm_state_size, dtype=torch.float32, - )) + ) + ) self.D = nn.Parameter(torch.ones(intermediate_size // tp_size)) set_weight_attrs(self.D, {"weight_loader": weight_loader}) @@ -128,41 +141,49 @@ class MambaMixer(MambaBase, CustomOp): input_is_parallel=True, ) - self.dt_layernorm = RMSNorm( - time_step_rank, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.dt_layernorm = ( + RMSNorm( + time_step_rank, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) - self.b_layernorm = RMSNorm( - ssm_state_size, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.b_layernorm = ( + RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) - self.c_layernorm = RMSNorm( - ssm_state_size, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.c_layernorm = ( + RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. - # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config self.prefix = prefix def _ssm_transform( - self, x: torch.Tensor + self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: # Lora kernel requires contiguous tensor. 
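For reference, the dt/B/C split that _ssm_transform performs in the next hunk amounts to a single torch.split over the x_proj output. A standalone sketch with illustrative sizes (the real code uses self.time_step_rank and self.ssm_state_size):

```python
import torch

time_step_rank, ssm_state_size, num_tokens = 8, 16, 4
ssm_params = torch.randn(num_tokens, time_step_rank + 2 * ssm_state_size)

# Split the projected parameters into the dt, B and C blocks along the last dim.
time_step, B, C = torch.split(
    ssm_params, [time_step_rank, ssm_state_size, ssm_state_size], dim=-1
)
assert time_step.shape == (num_tokens, time_step_rank)
assert B.shape == C.shape == (num_tokens, ssm_state_size)
```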
@@ -172,7 +193,8 @@ class MambaMixer(MambaBase, CustomOp): time_step, B, C = torch.split( ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1) + dim=-1, + ) if self.use_rms_norm: assert self.dt_layernorm is not None assert self.b_layernorm is not None @@ -183,29 +205,17 @@ class MambaMixer(MambaBase, CustomOp): discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C - def forward(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params) - else: - torch.ops.vllm.mamba_mixer( - hidden_states, - output, - self.prefix, - ) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor): + torch.ops.vllm.mamba_mixer( + hidden_states, + output, + self.prefix, + ) - def forward_native(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor): pass - def forward_cuda(self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): """ Run the Mamba-1 SSM pipeline. @@ -231,40 +241,28 @@ class MambaMixer(MambaBase, CustomOp): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba1_metadata = attn_metadata - assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) - query_start_loc = mamba1_metadata.query_start_loc - state_indices_tensor = mamba1_metadata.state_indices_tensor - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - has_initial_states = mamba1_metadata.has_initial_states - num_padded_decodes = mamba1_metadata.num_padded_decodes - else: - assert isinstance(attn_metadata, AttentionMetadata) - assert mamba_cache_params is not None - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - query_start_loc = attn_metadata.query_start_loc - context_lens_tensor = attn_metadata.context_lens_tensor - has_initial_states = None - if context_lens_tensor is not None: - has_initial_states = context_lens_tensor > 0 - num_padded_decodes = attn_metadata.num_decode_tokens + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + has_initial_states = mamba1_metadata.has_initial_states + num_padded_decodes = mamba1_metadata.num_padded_decodes # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) hidden_states_BC, gate = projected_states.chunk(2, dim=-2) - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) - if envs.VLLM_USE_V1 and attn_metadata is None: + if attn_metadata is None: # V1 profile run hidden_states_BC = hidden_states_BC.contiguous() return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] @@ -310,10 +308,12 @@ class MambaMixer(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - query_start_loc=query_start_loc_p) + query_start_loc=query_start_loc_p, + ) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( - conv_out_p.transpose(-2, -1)) + conv_out_p.transpose(-2, -1) + ) time_proj_bias = self._time_proj_bias() # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) @@ -330,7 +330,8 @@ class MambaMixer(MambaBase, CustomOp): delta_softplus=True, cache_indices=state_indices_tensor_p, has_initial_state=has_initial_states_p, - query_start_loc=query_start_loc_p) + query_start_loc=query_start_loc_p, + ) ssm_outputs.append(scan_out_p) if has_decode: @@ -341,42 +342,42 @@ class MambaMixer(MambaBase, CustomOp): conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d).transpose(0, 1) + conv_state_indices=state_indices_tensor_d, + ).transpose(0, 1) # 3. State Space Model sequence transformation. discrete_time_step_d, B_d, C_d = self._ssm_transform( - conv_out_d.transpose(-2, -1)) + conv_out_d.transpose(-2, -1) + ) time_proj_bias = self._time_proj_bias() # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) - scan_outputs_d = torch.empty_like( - hidden_states_BC_d.transpose(0, 1)) - selective_state_update(ssm_state, - conv_out_d.transpose(0, 1), - discrete_time_step_d.transpose(0, 1), - self.A, - B_d, - C_d, - self.D, - gate_d.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=state_indices_tensor_d, - out=scan_outputs_d) + scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1)) + selective_state_update( + ssm_state, + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), + self.A, + B_d, + C_d, + self.D, + gate_d.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d, + ) scan_outputs_d = scan_outputs_d.transpose(0, 1) - if envs.VLLM_USE_V1: - ssm_outputs.insert(0, scan_outputs_d) - else: - ssm_outputs.append(scan_outputs_d) + ssm_outputs.insert(0, scan_outputs_d) - scan_outputs_combined = ssm_outputs[0] if len( - ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + scan_outputs_combined = ( + ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + ) # 5. Final output projection if self.is_lora_enabled: # Lora kernel requires contiguous tensor. 
- scan_outputs_combined = scan_outputs_combined.transpose( - -2, -1).contiguous() + scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous() out = self.out_proj(scan_outputs_combined)[0] else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] @@ -404,6 +405,11 @@ class MambaMixer(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba1" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend + + return Mamba1AttentionBackend + def _time_proj_bias(self) -> Optional[torch.Tensor]: if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() @@ -435,38 +441,32 @@ def split_batch_to_prefill_and_decode( ) -> PrefillDecodeSplit: num_actual_tokens = num_prefill_tokens + num_padded_decodes - if envs.VLLM_USE_V1: - # In v1, decode tokens come first, then prefill tokens. - hidden_states_BC_d, hidden_states_BC_p = torch.split( - hidden_states_BC[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) - gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) + # In v1, decode tokens come first, then prefill tokens. + hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC[..., :num_actual_tokens], + [num_padded_decodes, num_prefill_tokens], + dim=-1, + ) + gate_d, gate_p = torch.split( + gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1 + ) - # num_padded_decodes accounts for CUDA graph padding when applicable - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_padded_decodes + num_prefills], - [num_padded_decodes, num_prefills], - dim=0) - query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_padded_decodes if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[-num_prefills:] if ( - has_initial_states is not None and num_prefills > 0) else None - else: - # In v0, prefill tokens come first, then decode tokens. 
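In the v1 layout handled above, decode tokens precede prefill tokens in the flattened batch, so the split is [num_padded_decodes, num_prefill_tokens] along the token dimension. A toy illustration with made-up sizes:

```python
import torch

num_padded_decodes, num_prefill_tokens = 3, 5
hidden_states_BC = torch.randn(64, num_padded_decodes + num_prefill_tokens)

# Decode tokens come first, so split decode-then-prefill along the token dim.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
    hidden_states_BC,
    [num_padded_decodes, num_prefill_tokens],
    dim=-1,
)
assert hidden_states_BC_d.shape[-1] == num_padded_decodes
assert hidden_states_BC_p.shape[-1] == num_prefill_tokens
```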
- hidden_states_BC_p, hidden_states_BC_d = torch.split( - hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) - gate_p, gate_d = torch.split(gate, - [num_prefill_tokens, num_decode_tokens], - dim=-1) - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, [num_prefills, num_decodes], dim=0) - query_start_loc_p = (query_start_loc[:num_prefills + - 1] if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[:num_prefills] if ( - has_initial_states is not None and num_prefills > 0) else None + # num_padded_decodes accounts for CUDA graph padding when applicable + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[: num_padded_decodes + num_prefills], + [num_padded_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + query_start_loc[-num_prefills - 1 :] - num_padded_decodes + if num_prefills > 0 + else None + ) + has_initial_states_p = ( + has_initial_states[-num_prefills:] + if (has_initial_states is not None and num_prefills > 0) + else None + ) return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -487,9 +487,7 @@ def mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def mamba_mixer_fake( @@ -505,5 +503,4 @@ direct_register_custom_op( op_func=mamba_mixer, mutates_args=["output"], fake_impl=mamba_mixer_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 743e520ec8ee1..7589905ac9277 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,40 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, - update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + 
causal_conv1d_fn, + causal_conv1d_update, +) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_state_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( - LoaderFunction, composed_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheParams + LoaderFunction, + composed_weight_loader, + sharded_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -44,12 +55,13 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): - - def __init__(self, - full_hidden_size: int, - full_n_groups: int, - use_rms_norm: bool = True, - eps: float = 1e-6): + def __init__( + self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6, + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -63,13 +75,13 @@ class Mixer2RMSNormGated(CustomOp): if self.use_rms_norm: # Register norm weight only if we're actually applying RMSNorm self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) - set_weight_attrs(self.weight, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) else: # Avoid checkpoint mismatch by skipping unused parameter self.register_parameter("weight", None) - assert (self.full_hidden_size % self.tp_size == 0 - ), "Tensor parallel world size must divide hidden size." + assert self.full_hidden_size % self.tp_size == 0, ( + "Tensor parallel world size must divide hidden size." 
+ ) def forward_native( self, @@ -112,8 +124,7 @@ class Mixer2RMSNormGated(CustomOp): group_count = hidden_dim // self.group_size x_grouped = x.view(*prefix_dims, group_count, self.group_size) variance = x_grouped.pow(2).mean(-1, keepdim=True) - x_grouped = x_grouped * torch.rsqrt(variance + - self.variance_epsilon) + x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) x = x_grouped.view(*prefix_dims, hidden_dim) if redundant_tp: @@ -131,18 +142,19 @@ class Mixer2RMSNormGated(CustomOp): input_dtype = x.dtype if not self.use_rms_norm: # Keep gate in float32 for numerical stability during silu - return x * nn.functional.silu(gate.to( - torch.float32)).to(input_dtype) + return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype) - if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1): + if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: return self.forward_native(x, gate) - return rms_norm_gated(x, - self.weight.data, - bias=None, - z=gate, - eps=self.variance_epsilon, - norm_before_gate=False) + return rms_norm_gated( + x, + self.weight.data, + bias=None, + z=gate, + eps=self.variance_epsilon, + norm_before_gate=False, + ) def mamba_v2_sharded_weight_loader( @@ -157,7 +169,6 @@ def mamba_v2_sharded_weight_loader( """ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - # - track boundary of (sharded) param, and loaded_weight, respectively boundary, loaded_boundary = 0, 0 @@ -192,11 +203,12 @@ def mamba_v2_sharded_weight_loader( # seem to handle slices well. # https://github.com/python/mypy/issues/2410 param.data[ - boundary:(boundary + take), - ... # type: ignore[misc] - ] = loaded_weight[loaded_start_idx:(loaded_start_idx + - take) # type: ignore[misc] - ] # type: ignore[misc] + boundary : (boundary + take), ... # type: ignore[misc] + ] = loaded_weight[ + loaded_start_idx : ( + loaded_start_idx + take + ) # type: ignore[misc] + ] # type: ignore[misc] # move indexing boundaries boundary += shard_size @@ -218,23 +230,25 @@ class MambaMixer2(MambaBase, CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # For TP, the sharding plan is as follows: @@ -254,16 +268,21 @@ class MambaMixer2(MambaBase, CustomOp): self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() - assert (num_heads % self.tp_size == 0 - ), "Tensor parallel world size must divide num heads." + assert num_heads % self.tp_size == 0, ( + "Tensor parallel world size must divide num heads." 
+ ) assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( - "If tensor parallel world size does not divide num_heads, " - "then num_groups must equal 1.") + "If tensor parallel world size does not divide num_groups, " + "then num_groups must equal 1." + ) assert ( - self.tp_size == 1 or quant_config is None - ), "Tensor parallel currently not supported for quantized models." + (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None + ), ( + "Tensor parallel currently supported for quantized models only " + "if tensor parallel world size divides num groups." + ) self.ssm_state_size = ssm_state_size self.conv_kernel_size = conv_kernel_size @@ -279,95 +298,86 @@ class MambaMixer2(MambaBase, CustomOp): # - but if n_groups cannot divide tp_size, we need to # extend some extra groups groups = MambaStateShapeCalculator.extra_groups_for_head_shards( - n_groups, self.tp_size) + n_groups, self.tp_size + ) self.n_groups = n_groups + groups - self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size - self.conv1d = ColumnParallelLinear( - input_size=conv_kernel_size, - output_size=self.conv_dim, - bias=use_conv_bias, - quant_config=None, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + self.groups_ssm_state_size = self.n_groups * self.ssm_state_size + self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config, - ) + if n_groups % self.tp_size == 0: + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - # - because in_proj is a concatenation of 3 weights, we - # need to interleave them before sharding - # - use the custom weight loader mamba_v2_sharded_weight_loader - # for conv1d.bias, covn1d.weight and in_proj.weight - # - need to set these settings, to assign the groups to the head shards - group_shard_settings = ( - self.n_groups * self.ssm_state_size, # expected model size - (self.n_groups - n_groups) * - self.ssm_state_size, # extra dims assigned - n_groups == 1, # if there was only one group - ) - intermediate_settings = (intermediate_size, 0, False) - head_settings = (self.num_heads, 0, False) + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + self.num_heads, + ], + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + else: + # This is the n_groups == 1 case, + # where we need to duplicate groups if TP>1. 
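When n_groups does not divide the tensor-parallel world size (in practice the n_groups == 1 branch above), extra groups are appended so every rank still holds whole groups. An illustrative stand-in for that padding arithmetic; the real helper is MambaStateShapeCalculator.extra_groups_for_head_shards, and the rounding rule below is an assumption made for the sketch:

```python
# Illustrative only: pad n_groups up to the next multiple of tp_size.
def extra_groups_for_tp(n_groups: int, tp_size: int) -> int:
    if n_groups % tp_size == 0:
        return 0  # groups already shard evenly, nothing to duplicate
    return tp_size - (n_groups % tp_size)

assert extra_groups_for_tp(1, 4) == 3  # the single group is replicated across 4 ranks
assert extra_groups_for_tp(8, 4) == 0  # divisible: no extra groups needed
```

The duplication exists because, with only one group, every rank needs a full copy of that group's B/C state to run its shard of heads.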
- # - the weight already has a "weight_loader" attribute - # which set_weight_attrs will raise if we do not - # delete before trying to override it - # - ditto for the otther two weights below - delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs( - self.conv1d.bias, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) - if quant_config is None: - # - quant layers do not have a weight loader - delattr(self.in_proj.weight, "weight_loader") + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups + # to the head shards + group_shard_settings = ( + self.groups_ssm_state_size, # expected model size + (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the other two weights below + delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( - self.in_proj.weight, + self.conv1d.bias, { - "weight_loader": - mamba_v2_sharded_weight_loader( + "weight_loader": mamba_v2_sharded_weight_loader( [ - intermediate_settings, # for gate intermediate_settings, group_shard_settings, group_shard_settings, - head_settings, # for dt ], self.tp_size, tp_rank, @@ -375,23 +385,66 @@ class MambaMixer2(MambaBase, CustomOp): }, ) + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) + + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + tp_rank, + ) + }, + ) + + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `MergedColumnParallelLinear`, + # and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( torch.empty( divide(num_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.use_rms_norm = use_rms_norm set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + sharded_weight_loader(0), lambda x: -torch.exp(x.float()) + ) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( intermediate_size, @@ -399,23 +452,19 @@ class MambaMixer2(MambaBase, CustomOp): bias=use_bias, input_is_parallel=True, quant_config=quant_config, + prefix=f"{prefix}.out_proj", ) - self.norm = Mixer2RMSNormGated(intermediate_size, - n_groups, - self.use_rms_norm, - eps=rms_norm_eps) + self.norm = Mixer2RMSNormGated( + intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps + ) - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. 
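The `A` parameter above shows why `composed_weight_loader` exists: checkpoints store `A_log`, while the kernels consume `-exp(A_log)` in float32. A hedged sketch of the composition pattern; the function and loader signatures here are illustrative stand-ins, not vLLM's exact API:

```python
import torch

def composed_loader(base_loader, transform):
    # run the normal (sharded) loader first, then rewrite the parameter
    # in place with the transform
    def loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
        base_loader(param, loaded_weight)
        param.data.copy_(transform(param.data))
    return loader

def copy_loader(param, loaded_weight):   # stand-in for sharded_weight_loader(0)
    param.data.copy_(loaded_weight)

A = torch.nn.Parameter(torch.empty(4, dtype=torch.float32))
a_loader = composed_loader(copy_loader, lambda x: -torch.exp(x.float()))
a_loader(A, torch.zeros(4))              # A_log == 0  ->  A == -1
assert torch.allclose(A, torch.full((4,), -1.0))
```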
- # The inner tuple is (conv_state, ssm_state) - self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) self.model_config = model_config self.cache_config = cache_config @@ -425,8 +474,6 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): pass @@ -435,64 +482,47 @@ class MambaMixer2(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): - if not envs.VLLM_USE_V1: - CustomOp.forward(self, hidden_states, output, mamba_cache_params, - mamba2_metadata, mup_vector) - else: - torch.ops.vllm.mamba_mixer2( - hidden_states, - output, - self.prefix, - mup_vector, - ) + torch.ops.vllm.mamba_mixer2( + hidden_states, + output, + self.prefix, + mup_vector, + ) def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): forward_context = get_forward_context() - # mamba2_metadata contains metadata necessary for the mamba2 triton + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - mamba2_metadata = attn_metadata - assert isinstance(attn_metadata, Mamba2AttentionMetadata) - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state = (..., dim, width-1) yet contiguous along 'dim' - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states_p - prep_initial_states = attn_metadata.prep_initial_states - chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p - else: - conv_state = mamba_cache_params.conv_state - ssm_state = mamba_cache_params.ssm_state - state_indices_tensor = mamba_cache_params.state_indices_tensor - has_initial_states_p = mamba2_metadata.has_initial_states - prep_initial_states = mamba2_metadata.prep_initial_states - chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx - chunk_indices_p = mamba2_metadata.chunk_indices - chunk_offsets_p = mamba2_metadata.chunk_offsets - groups_time_state_size = self.n_groups * self.ssm_state_size + assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size + prefix_caching_enabled = self.cache_config.enable_prefix_caching + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + 
assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + query_start_loc_p = attn_metadata.query_start_loc_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -510,30 +540,32 @@ class MambaMixer2(MambaBase, CustomOp): dim=-1, ) - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) # - get hidden_states, B and C after depthwise convolution. split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( hidden_states_B_C, [ self.intermediate_size // self.tp_size, - groups_time_state_size // self.tp_size, - groups_time_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, ], dim=-1, ) - if envs.VLLM_USE_V1 and attn_metadata is None: - # V1 profile run - hidden_states_B_C = (hidden_states_B_C.transpose( - 0, 1).clone().transpose(0, 1)).contiguous() - hidden_states, _B, _C = split_hidden_states_B_C_fn( - hidden_states_B_C) + if attn_metadata is None: + # profile run + hidden_states_B_C = ( + hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1) + ).contiguous() + hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C) hidden_states = self.norm(hidden_states, gate) out, _ = self.out_proj(hidden_states) return out + # NOTE: V0 put prefill before decode, v1 puts decode before prefill num_prefills = attn_metadata.num_prefills # request count num_decodes = attn_metadata.num_decode_tokens # token count (=request) num_prefill_tokens = attn_metadata.num_prefill_tokens # token count @@ -541,83 +573,89 @@ class MambaMixer2(MambaBase, CustomOp): has_decode = num_decodes > 0 num_actual_tokens = num_prefill_tokens + num_decodes - # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - if envs.VLLM_USE_V1: - hidden_states_B_C_d, hidden_states_B_C_p = torch.split( - hidden_states_B_C[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0, + hidden_states_B_C_d, hidden_states_B_C_p = torch.split( + hidden_states_B_C[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + dt_d, dt_p = torch.split( + dt[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor[:num_actual_tokens], + [num_decodes, num_prefills], + dim=0, + ) + + if prefix_caching_enabled: + # If prefix caching is enabled, retrieve the relevant variables + # for prefill and decode + block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( + torch.split( + attn_metadata.block_idx_last_computed_token, + [num_decodes, num_prefills], + dim=0, + ) ) - dt_d, dt_p = torch.split( - dt[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - 
dim=0, + block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = ( + torch.split( + attn_metadata.block_idx_last_scheduled_token, + [num_decodes, num_prefills], + dim=0, + ) ) - # Split along batch dimension - state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_actual_tokens], - [num_decodes, num_prefills], - dim=0, + # Prefill-only variables: + block_idx_first_scheduled_token_p = ( + attn_metadata.block_idx_first_scheduled_token_p ) - query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p else: - hidden_states_B_C_p, hidden_states_B_C_d = torch.split( - hidden_states_B_C, - [num_prefill_tokens, num_decodes], - dim=0, - ) - dt_p, dt_d = torch.split( - dt, - [num_prefill_tokens, num_decodes], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + - 1] - if has_prefill else None) + block_idx_last_computed_token_d = None + block_idx_last_computed_token_p = None + block_idx_last_scheduled_token_d = None + block_idx_last_scheduled_token_p = None + block_idx_first_scheduled_token_p = None + num_computed_tokens_p = None # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( [ num_prefill_tokens + num_decodes, - (self.num_heads // self.tp_size) * self.head_dim + (self.num_heads // self.tp_size) * self.head_dim, ], dtype=hidden_states.dtype, device=hidden_states.device, ) - if envs.VLLM_USE_V1: - preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( - preallocated_ssm_out, - [num_decodes, num_prefill_tokens], - dim=0, - ) - else: - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) # Process prefill requests if has_prefill: # 2. Convolution sequence transformation - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" + # - It will read the initial states for every sequence, + # that has "has_initial_states_p" == True, + # from "cache_indices", using "state_indices_tensor_p". + # - It updates the "conv_state" cache in positions pointed + # to by "state_indices_tensor_p". + # In particular, it will always write the state at the + # sequence end. + # In addition, "block_idx_first_scheduled_token_p" and + # "block_idx_last_scheduled_token_p" + # are provided (which are pointers into + # "state_indices_tensor_p"), it will write additional cache + # states aligned at "block_size_to_align". 
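Two details in the block above are easy to miss: the v1 varlen layout puts decode tokens before prefill tokens, so a single `torch.split` along the token dimension separates the two phases, and `preallocated_ssm_out` is split into views so the decode and prefill kernels write straight into the final buffer with no concatenation afterwards. A compact sketch with made-up sizes:

```python
import torch

num_decodes, num_prefill_tokens, inner = 2, 5, 8     # hypothetical counts
num_actual_tokens = num_decodes + num_prefill_tokens

hidden = torch.randn(num_actual_tokens + 1, inner)   # +1 padded token
decode_part, prefill_part = torch.split(
    hidden[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0)
assert decode_part.shape[0] == 2 and prefill_part.shape[0] == 5

out = torch.empty(num_actual_tokens, inner)          # preallocated output
out_d, out_p = torch.split(out, [num_decodes, num_prefill_tokens], dim=0)
out_d.fill_(0.0)   # stand-ins for the decode / prefill kernels writing
out_p.fill_(1.0)   # their results into the views in place
assert out[:num_decodes].eq(0).all() and out[num_decodes:].eq(1).all()
```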
x = hidden_states_B_C_p.transpose( - 0, 1) # this is the form that causal-conv see - if mamba2_metadata.cu_seqlen is None: - mamba2_metadata = update_metadata(x, query_start_loc_p, - mamba2_metadata) + 0, 1 + ) # this is the form that causal-conv see hidden_states_B_C_p = causal_conv1d_fn( x, conv_weights, @@ -626,60 +664,153 @@ class MambaMixer2(MambaBase, CustomOp): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - metadata=mamba2_metadata, - query_start_loc=query_start_loc_p).transpose( - 0, 1)[:num_prefill_tokens] + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, + num_computed_tokens=num_computed_tokens_p, + block_size_to_align=mamba_block_size, + metadata=attn_metadata, + query_start_loc=query_start_loc_p, + ).transpose(0, 1)[:num_prefill_tokens] - hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( - hidden_states_B_C_p) + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p) # 3. State Space Model sequence transformation initial_states = None - if (has_initial_states_p is not None and prep_initial_states): - # making a copy of the states - if envs.VLLM_USE_V1: - initial_states = torch.where( - has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) - else: - initial_states = torch.where( - has_initial_states_p[:num_prefills, None, None, None], - ssm_state[state_indices_tensor_p], 0) + if has_initial_states_p is not None and prep_initial_states: + kernel_ssm_indices = state_indices_tensor_p + if prefix_caching_enabled: + kernel_ssm_indices = state_indices_tensor_p.gather( + 1, block_idx_last_computed_token_p.unsqueeze(1) + ).squeeze(1) + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[kernel_ssm_indices], + 0, + ) # NOTE: final output is an in-place update of out tensor - varlen_state = mamba_chunk_scan_combined( - hidden_states_p.view(1, num_prefill_tokens, - self.num_heads // self.tp_size, - self.head_dim), - dt_p.unsqueeze(0), + varlen_states = mamba_chunk_scan_combined_varlen( + hidden_states_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), + dt_p, self.A, - B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, - -1), - C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, - -1), + B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), + C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), chunk_size=chunk_size, D=self.D, z=None, dt_bias=self.dt_bias, seq_idx=seq_idx_p, - chunk_indices=chunk_indices_p, - chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_varlen_states=True, - return_final_states=False, + return_intermediate_states=prefix_caching_enabled, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, - self.head_dim), - state_dtype=ssm_state.dtype) + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, + ) - # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor_p] = varlen_state + if prefix_caching_enabled: + # The chunk_stride is the number of chunks per mamba block + # e.g., if mamba_block_size = 512 and chunk_size = 
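With prefix caching enabled, each prefill row of `state_indices_tensor_p` lists several cache blocks, and `block_idx_last_computed_token_p` points at the one holding the most recently cached SSM state, so the initial states are gathered per sequence before the masked `torch.where`. A small sketch of that gather with hypothetical cache contents:

```python
import torch

ssm_state = torch.randn(10, 1, 2, 3)                     # fake cache pool
state_indices_tensor_p = torch.tensor([[5, 6, 7, 8],     # blocks of seq 0
                                       [1, 2, 3, 4]])    # blocks of seq 1
block_idx_last_computed_token_p = torch.tensor([2, 0])   # per-seq pointer
has_initial_states_p = torch.tensor([True, False])

# pick, for every prefill, the cache block holding its last computed state
kernel_ssm_indices = state_indices_tensor_p.gather(
    1, block_idx_last_computed_token_p.unsqueeze(1)).squeeze(1)   # -> [7, 1]
initial_states = torch.where(
    has_initial_states_p[:, None, None, None],
    ssm_state[kernel_ssm_indices], 0)
assert torch.equal(initial_states[0], ssm_state[7])
assert initial_states[1].eq(0).all()
```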
256, + # then chunk_stride = 2 + chunk_stride = mamba_block_size // chunk_size + + # Save state for sequences with more than just final state + for seq_idx in range(num_prefills): + # Block index for the first scheduled token + block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[ + seq_idx + ] + + # Block index for the last scheduled token + block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[ + seq_idx + ] + + # Number of blocks that need to be written + n_blocks_to_fill = ( + block_idx_last_scheduled_token - block_idx_first_scheduled_token + ) + + # Skip sequences that don't have any blocks to fill + if n_blocks_to_fill == 0: + continue + + # Look up the state indices + cache_blocks_to_fill = state_indices_tensor_p[ + seq_idx, + block_idx_first_scheduled_token:block_idx_last_scheduled_token, + ] + + # First chunk index for this sequence + if seq_idx == 0: + first_chunk = 0 + else: + first_chunk = 1 + last_chunk_indices_p[seq_idx - 1] + + # First chunk that is aligned on the mamba block boundary + first_aligned_chunk = first_chunk + chunk_stride - 1 + + # Calculate the number of computed tokens that were not + # already cached + num_unaligned_computed_tokens = ( + num_computed_tokens_p[seq_idx] % mamba_block_size + ) + + if num_unaligned_computed_tokens > 0: + # If the number of computed tokens is not block aligned, + # then we need to shift the index accordingly + first_aligned_chunk -= ( + num_unaligned_computed_tokens // chunk_size + ) + + # Get states to write + from_where = varlen_states[ + first_aligned_chunk : first_aligned_chunk + + n_blocks_to_fill * chunk_stride : chunk_stride + ] + + # Write the states + ssm_state[cache_blocks_to_fill] = from_where + + # For all seqs, store the last state (note: might be partial): + ssm_state[ + state_indices_tensor_p.gather( + 1, block_idx_last_scheduled_token_p.unsqueeze(1) + ).squeeze(1) + ] = varlen_states[last_chunk_indices_p] + + else: + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) + # tensor + ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests if has_decode: + if prefix_caching_enabled: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) + # for decode: + # block_idx_first_scheduled_token_d == + # block_idx_last_scheduled_token_d + # at block boundaries: + # block_idx_first_scheduled_token_d > + # block_idx_last_computed_token_d + else: + # Without caching, read and write in-place to the same blocks: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, @@ -687,22 +818,28 @@ class MambaMixer2(MambaBase, CustomOp): conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, + initial_state_idx=block_idx_last_computed_token_d, + ) - hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( - hidden_states_B_C_d) + hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d) # 3. 
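The intermediate-state writeback above is easiest to follow with concrete numbers. Assuming `mamba_block_size = 512` and `chunk_size = 256` (so every second chunk boundary coincides with a cache-block boundary), the strided slice over `varlen_states` picks exactly the block-aligned chunk states:

```python
chunk_size, mamba_block_size = 256, 512
chunk_stride = mamba_block_size // chunk_size            # 2 chunks per block

first_chunk = 10          # first chunk of this sequence in the varlen batch
num_computed_tokens = 0   # fresh request: nothing cached yet
n_blocks_to_fill = 2      # two full 512-token blocks become cacheable

first_aligned_chunk = first_chunk + chunk_stride - 1     # -> 11
if num_computed_tokens % mamba_block_size > 0:
    first_aligned_chunk -= (num_computed_tokens % mamba_block_size) // chunk_size

block_state_chunks = list(range(
    first_aligned_chunk,
    first_aligned_chunk + n_blocks_to_fill * chunk_stride,
    chunk_stride))
assert block_state_chunks == [11, 13]   # states after tokens 512 and 1024
# the (possibly partial) state at the sequence end is written separately
# via last_chunk_indices_p
```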
State Space Model sequence transformation n_groups = self.n_groups // self.tp_size - A_d = self.A[:, None, ...][:, :, None].expand( - -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + A_d = ( + self.A[:, None, ...][:, :, None] + .expand(-1, self.head_dim, self.ssm_state_size) + .to(dtype=torch.float32) + ) dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D_d = self.D[:, None, ...].expand(-1, self.head_dim) B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups) C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups) hidden_states_d = hidden_states_d.view( - -1, self.num_heads // self.tp_size, self.head_dim) + -1, self.num_heads // self.tp_size, self.head_dim + ) # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected @@ -719,17 +856,16 @@ class MambaMixer2(MambaBase, CustomOp): z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, - out=preallocated_ssm_out_d.view(num_decodes, -1, - self.head_dim), + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, + out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) # 4. gated MLP # GatedRMSNorm internally applying SiLU to the gate # SiLU is applied internally before normalization, unlike standard # norm usage - hidden_states = self.norm(preallocated_ssm_out, - gate[:num_actual_tokens]) + hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens]) # 5. Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) @@ -758,6 +894,11 @@ class MambaMixer2(MambaBase, CustomOp): def mamba_type(self) -> str: return "mamba2" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend + + return Mamba2AttentionBackend + def mamba_mixer2( hidden_states: torch.Tensor, @@ -767,11 +908,7 @@ def mamba_mixer2( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mamba_cache_params=None, - mamba2_metadata=None, - mup_vector=mup_vector) + self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector) def mamba_mixer2_fake( @@ -788,5 +925,4 @@ direct_register_custom_op( op_func=mamba_mixer2, mutates_args=["output"], fake_impl=mamba_mixer2_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 280a9e45e662e..0f160b2c924fb 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -4,13 +4,13 @@ from typing import Union import torch -from vllm.config import MambaDType, ModelDType +from vllm.config.cache import MambaDType +from vllm.config.model import ModelDType from vllm.distributed import divide from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype class MambaStateDtypeCalculator: - @classmethod def linear_attention_state_dtype( cls, @@ -21,7 +21,7 @@ class MambaStateDtypeCalculator: if mamba_cache_dtype == "float32": raise ValueError("fp32 state for minimax is not yet supported") state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) - return (state_dtype, ) + return (state_dtype,) @classmethod def mamba1_state_dtype( @@ -30,12 +30,9 @@ class 
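The `mamba_mixer2` custom op registered above only carries tensors plus the layer name; the actual module is recovered from the forward context at call time, which keeps the op signature simple for `torch.compile`. A stripped-down sketch of that indirection; the registry dict and names here are illustrative stand-ins, not vLLM's real helpers:

```python
import torch

_NO_COMPILE_LAYERS: dict[str, object] = {}   # stand-in for the forward context

class _FakeMixer:
    def forward_cuda(self, hidden_states, output, mup_vector=None):
        output.copy_(hidden_states)          # placeholder computation

def mamba_mixer2_sketch(hidden_states: torch.Tensor, output: torch.Tensor,
                        layer_name: str) -> None:
    layer = _NO_COMPILE_LAYERS[layer_name]   # resolve the module by prefix
    layer.forward_cuda(hidden_states=hidden_states, output=output)

_NO_COMPILE_LAYERS["model.layers.0.mixer"] = _FakeMixer()
x, out = torch.randn(4, 8), torch.empty(4, 8)
mamba_mixer2_sketch(x, out, "model.layers.0.mixer")
assert torch.equal(out, x)
```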
MambaStateDtypeCalculator: mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - # TODO (tdoublep) requires kernel changes - if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32": - raise ValueError("fp32 state for mamba1 is not yet supported") - else: - return MambaStateDtypeCalculator.mamba2_state_dtype( - model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype) + return cls._mamba_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype + ) @classmethod def mamba2_state_dtype( @@ -44,13 +41,22 @@ class MambaStateDtypeCalculator: mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, - model_dtype) + return cls._mamba_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype + ) + + @classmethod + def _mamba_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) if mamba_ssm_cache_dtype == "auto": temporal_state_dtype = conv_state_dtype else: - temporal_state_dtype = ( - STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]) + temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype] return (conv_state_dtype, temporal_state_dtype) @@ -60,13 +66,20 @@ class MambaStateDtypeCalculator: model_dtype: Union[ModelDType, torch.dtype], mamba_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, - model_dtype) - return (conv_state_dtype, ) + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (conv_state_dtype,) + + @classmethod + def gated_delta_net_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, torch.dtype]: + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, state_dtype) class MambaStateShapeCalculator: - @classmethod def linear_attention_state_shape( cls, @@ -74,9 +87,8 @@ class MambaStateShapeCalculator: tp_size: int, head_dim: int, ) -> tuple[tuple[int, int, int], ...]: - state_shape = (num_heads // tp_size, head_dim, head_dim) - return (state_shape, ) + return (state_shape,) @classmethod def mamba1_state_shape( @@ -85,19 +97,12 @@ class MambaStateShapeCalculator: intermediate_size: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (divide(intermediate_size, - tp_world_size), conv_kernel - 1) + conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) - temporal_state_shape = (divide(intermediate_size, - tp_world_size), state_size) + temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - # In V0, the conv_state shape was swapped during allocation in - # MambaCacheManager, but in V1 it needs to be determined here at the - # calculation level - if use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] + conv_state_shape = conv_state_shape[1], conv_state_shape[0] return conv_state_shape, temporal_state_shape @@ -111,25 +116,20 @@ class MambaStateShapeCalculator: head_dim: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # 
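`mamba1_state_dtype` and `mamba2_state_dtype` can now share one helper because the rule is identical: the conv-state dtype follows the cache dtype (or the model dtype on `"auto"`), and the temporal/SSM state dtype follows it unless `mamba_ssm_cache_dtype` overrides it. A toy version of that decision rule, with a hypothetical dtype table standing in for vLLM's lookup helpers:

```python
import torch

_STR_TO_DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16,
                 "float32": torch.float32}

def mamba_state_dtypes(model_dtype: torch.dtype, mamba_cache_dtype: str,
                       mamba_ssm_cache_dtype: str):
    conv_dtype = (model_dtype if mamba_cache_dtype == "auto"
                  else _STR_TO_DTYPE[mamba_cache_dtype])
    ssm_dtype = (conv_dtype if mamba_ssm_cache_dtype == "auto"
                 else _STR_TO_DTYPE[mamba_ssm_cache_dtype])
    return conv_dtype, ssm_dtype

# bf16 model, default conv cache, fp32 SSM state (now also legal for mamba1)
assert mamba_state_dtypes(torch.bfloat16, "auto", "float32") == (
    torch.bfloat16, torch.float32)
```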
to ensure all groups needed by a head is sharded along with it - n_groups = n_groups + cls.extra_groups_for_head_shards( - n_groups, tp_world_size) + n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size) # heads and n_groups are TP-ed conv_dim = intermediate_size + 2 * n_groups * state_size # contiguous along 'dim' axis conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = (divide(num_heads, - tp_world_size), head_dim, state_size) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape @classmethod @@ -138,13 +138,10 @@ class MambaStateShapeCalculator: tp_world_size: int, intermediate_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int]]: conv_dim = divide(intermediate_size, tp_world_size) conv_state_shape = (conv_kernel - 1, conv_dim) - if not use_v1: - conv_state_shape = conv_state_shape[1], conv_state_shape[0] - return (conv_state_shape, ) + return (conv_state_shape,) @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): @@ -157,3 +154,29 @@ class MambaStateShapeCalculator: # for n_groups == 1, this is exactly tp_size - n_groups return tp_size - ngroups + + @classmethod + def gated_delta_net_state_shape( + cls, + tp_world_size: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_kernel_size: int, + num_spec: int = 0, + ): + conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads + conv_state_shape = ( + divide(conv_dim, tp_world_size), + conv_kernel_size - 1 + num_spec, + ) + + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + temporal_state_shape = ( + divide(num_v_heads, tp_world_size), + head_k_dim, + head_v_dim, + ) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b8d4bbc37105d..ec486d3b92678 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -20,39 +20,41 @@ def _causal_conv1d_fwd_kernel( # continuous batching w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr + cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains + # the block indices relevant for each sequence + # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, + block_idx_first_scheduled_token, # (batch,) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + num_computed_tokens, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions - batch: tl.int32, # actually padded_batch dim: tl.constexpr, seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl. 
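For reference, the shapes produced by `mamba2_state_shape` with some hypothetical model dimensions (TP of 2, groups divisible by the world size, so no extra group padding):

```python
# hypothetical dimensions, for illustration only
tp, intermediate, n_groups, state_size = 2, 4096, 8, 128
num_heads, head_dim, conv_kernel = 64, 64, 4

extra = 0 if n_groups % tp == 0 else tp - n_groups  # extra_groups_for_head_shards
conv_dim = intermediate + 2 * (n_groups + extra) * state_size    # 6144
conv_state_shape = (conv_kernel - 1, conv_dim // tp)             # (3, 3072)
temporal_state_shape = (num_heads // tp, head_dim, state_size)   # (32, 64, 128)

assert conv_state_shape == (3, 3072)
assert temporal_state_shape == (32, 64, 128)
```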
- constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, - stride_o_seq: tl.constexpr, + stride_cache_indices: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, + stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -63,13 +65,15 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_conv_state_seq = stride_istate_seq stride_conv_state_dim = stride_istate_dim stride_conv_state_tok = stride_istate_token - state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + state_len = ( + KERNEL_WIDTH - 1 + ) # can be passed via argument if it's not the same as this value # one program handles one chunk in a single sequence # rather than mixing sequences - to make updating initial_states across sequences efficiently # single-sequence id - idx_seq = tl.load(batch_ptr + tl.program_id(0)) + idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64) chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) # BLOCK_N elements along the feature-dimension (channel) @@ -83,26 +87,62 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual sequence length seqlen = sequence_end_index - sequence_start_index + B_size: tl.constexpr = stride_block_m * BLOCK_M + + if IS_APC_ENABLED: + # Handle the case if prefix caching is enabled. + # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" + + # Get the length of the completed sequence so far and compute the offset. 
+ current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq) + current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) + sequence_completed_index = tl.load(num_computed_tokens + idx_seq) + + # Compute the offset where the first stride_block_m-aligned first full block is + # Value in "token-space" + sequence_completed_offset_token = sequence_completed_index % B_size + seq_completed_offset = B_size - sequence_completed_offset_token + seq_end_offset = (seqlen - seq_completed_offset) % B_size + last_full_block_token_index = sequence_end_index - seq_end_offset + # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one + if seq_end_offset == 0: + last_full_block_token_index = last_full_block_token_index - B_size + + # Get the number of blocks to be filled for the current sequence + # If n_block_to_fill = 0, then only the state at the sequence end is stored + n_block_to_fill = current_last_index - current_first_index + + # Get the index of the init block + conv_state_init_index = tl.load(initial_state_idx + idx_seq) + else: + n_block_to_fill = 0 + current_last_index = 0 + conv_state_init_index = 0 + current_first_index = 0 + last_full_block_token_index = 0 + token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence - x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + x_base = ( + x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + ) # [BLOCK_N,] + + # cache_idx + conv_states_input_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index + ).to(tl.int64) - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return - conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_states_base = ( + conv_states_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -111,14 +151,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching # 2. 
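The offsets computed in the APC branch above are pure index arithmetic; plugging in hypothetical numbers makes the intent clearer (`B_size` is the mamba block size in tokens, and `last_full_block_token_index` is the global token index where the last fully covered cache block ends):

```python
BLOCK_M, stride_block_m = 8, 64
B_size = stride_block_m * BLOCK_M               # block_size_to_align = 512

num_computed = 300          # tokens of this sequence already cached
seqlen = 1000               # new tokens processed this step
sequence_start_index = 4096

seq_completed_offset = B_size - num_computed % B_size      # 212 tokens until
                                                           # the next boundary
seq_end_offset = (seqlen - seq_completed_offset) % B_size  # 276
last_full_block_token_index = (sequence_start_index + seqlen) - seq_end_offset
# new-token offset 724, i.e. 300 + 724 = 1024 = 2 * 512 total tokens
assert last_full_block_token_index - sequence_start_index == 724
# (if seq_end_offset were 0, the kernel steps back one full B_size, since the
#  sequence-end state is written by the separate chunk_offset == 0 path)
```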
update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states - prior_tokens = conv_states_base + (state_len - - 1) * stride_conv_state_tok + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] @@ -148,40 +184,56 @@ def _causal_conv1d_fwd_kernel( # continuous batching # prior-tokens are zeros if KERNEL_WIDTH >= 2: # STRATEGY1 # first chunk and does not have prior-token, so just set to 0 - col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 3: # STRATEGY1 - col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 4: # STRATEGY1 - col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 5: # STRATEGY1 - col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) # STEP 2: # here prepare data for updating conv_state - if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + if ( + state_len <= seqlen + ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) # just read from 'x' # copy 'x' data to conv_state # load only 'x' data (and set 0 before 'x' if seqlen < state_len) idx_tokens_last = (seqlen - state_len) + tl.arange( - 0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ( - (sequence_start_index + idx_tokens_last) * - stride_x_token)[:, None] + ( - idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - mask_x = ((idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + 0, NP2_STATELEN + ) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] + # Compute the offset where the last block should be written in the conv_states + conv_states_output_coord = tl.load( + conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_last_index + ).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok + )[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, 
mask) + tl.store(conv_states_ptrs_target, loaded_x, mask) else: if load_init_state: @@ -189,39 +241,43 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] conv_states_ptrs_source = ( - conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, - None] + conv_states_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier( - ) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load new_conv_state = tl.where( mask, conv_state, loaded_x ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # load_init_state == False # update conv_state by shifting left, BUT @@ -230,21 +286,25 @@ def _causal_conv1d_fwd_kernel( # continuous batching VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, 
None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # chunk_offset > 0 @@ -254,37 +314,84 @@ def _causal_conv1d_fwd_kernel( # continuous batching mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 3: conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 4: conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 5: # ruff: noqa: F841 conv_states_ptrs = prior_tokens # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + # Store intermediate states aligned with stride_block_m + # The additional states are cached starting from the last stride_block_m. + # For example: + # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. + # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last + # stride_block_m are cached. 
+ # For example chunk_offset = n_block_to_fill stores the state at last_full_block + if (chunk_offset - 1) < n_block_to_fill: + # Store the states at the chunk boundaries from the start of the sequence + idx_tokens_last = ( + last_full_block_token_index + - (n_block_to_fill - chunk_offset) * B_size + - state_len + ) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = ( + x_ptr + + (idx_tokens_last * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + + mask_x = (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[ + None, : + ] # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # cache_idx + conv_states_output_coord = tl.load( + conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_first_index + + (chunk_offset - 1) + ).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok + )[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, loaded_x, mask) if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) x_base_1d = x_base + token_offset * stride_x_token # starting of chunk @@ -308,7 +415,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching matrix_w = w_col0 matrix_x = col0 for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: if j == 1: # KERNEL_WIDTH-1: matrix_w = w_col1 @@ -349,9 +455,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask_1d = (idx_token < segment_len) & ( - idx_feats < dim) # token-index # feature-index - o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token - ) * stride_o_token + (idx_feats * stride_o_dim) + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) @@ -366,6 +476,11 @@ def causal_conv1d_fn( has_initial_state: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", pad_slot_id: int = PAD_SLOT_ID, + block_idx_first_scheduled_token: Optional[torch.Tensor] = None, + block_idx_last_scheduled_token: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, + num_computed_tokens: Optional[torch.Tensor] = None, + block_size_to_align=0, metadata=None, validate_data=False, ): @@ -376,7 +491,7 @@ def causal_conv1d_fn( sequences are concatenated from left to right for varlen weight: (dim, width) conv_states: (...,dim,width - 1) itype - updated inplace if provided + updated inplace if cache_indices are not provided [it use `cache_indices` to get the index to the cache of conv_state for that sequence conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True @@ -408,37 +523,41 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, 
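Putting the two write paths of the conv kernel together: the `chunk_offset == 0` program stores the sequence-end state into `cache_indices[seq, current_last_index]`, while programs with `1 <= chunk_offset <= n_block_to_fill` store the block-aligned states counted back from `last_full_block_token_index`. A small enumeration with hypothetical numbers shows which input tokens feed which cache slot:

```python
# hypothetical numbers, for illustration only
n_block_to_fill, B_size, state_len = 3, 512, 3
last_full_block_token_index = 2048

for chunk_offset in range(1, n_block_to_fill + 1):
    end = last_full_block_token_index - (n_block_to_fill - chunk_offset) * B_size
    slot = f"cache_indices[seq, first + {chunk_offset - 1}]"
    print(f"chunk_offset={chunk_offset}: x[{end - state_len}:{end}) -> {slot}")
# chunk_offset=1: x[1021:1024) -> cache_indices[seq, first + 0]
# chunk_offset=2: x[1533:1536) -> cache_indices[seq, first + 1]
# chunk_offset=3: x[2045:2048) -> cache_indices[seq, first + 2]
```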
the kernel will not process entries at indices 0 and 3 - + block_idx_first_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the first cache block to be filled is located. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into cache_indices, where the cache block containing the initial state is located. + num_computed_tokens: (batch,), dtype int32 + The number of tokens already completed for each sequence + block_size_to_align: int + The block size to align the cached states to out: same shape as `x` """ if isinstance(activation, bool) and activation: activation = "silu" args = None + # Store original dtype to cast back at the end + original_x_dtype = x.dtype + x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: - cu_seqlen = metadata.cu_seqlen nums_dict = metadata.nums_dict - #x = metadata.x args = nums_dict batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: - seqlens = np.diff(query_start_loc.to('cpu')) + seqlens = query_start_loc.diff().to("cpu") args = seqlens MAX_NUM_PROGRAMS = 1024 batch_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking which seq-idx the Triton program is handling token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking BLOCK_M-based index in the sequence the Triton program is handling is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) @@ -448,7 +567,6 @@ def causal_conv1d_fn( np2_statelen = triton.next_power_of_2(state_len) padded_batch = query_start_loc.size(0) - 1 - stride_x_seq = 0 stride_x_dim = x.stride(0) stride_x_token = x.stride(1) stride_w_dim = weight.stride(0) @@ -457,6 +575,7 @@ def causal_conv1d_fn( stride_istate_dim = 0 stride_istate_token = 0 num_cache_lines = 0 + BLOCK_M = 8 if conv_states is not None: # extensions to support vLLM: # 1. conv_states is used to replaced initial_states @@ -464,19 +583,22 @@ def causal_conv1d_fn( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. 
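When no precomputed metadata is passed in, the per-sequence lengths are now derived directly from `query_start_loc` with `torch.diff` instead of round-tripping through numpy. For example:

```python
import torch

# cumulative start offsets of a 3-sequence varlen batch (hypothetical)
query_start_loc = torch.tensor([0, 4, 9, 12], dtype=torch.int32)
seqlens = query_start_loc.diff().to("cpu")   # -> tensor([4, 5, 3])
assert seqlens.tolist() == [4, 5, 3]
```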
computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines, dim, width - 1) == conv_states.shape + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) assert stride_istate_dim == 1 if out.dim() == 2: - stride_o_seq = 0 stride_o_dim = out.stride(0) stride_o_token = out.stride(1) else: - stride_o_seq = out.stride(0) stride_o_dim = out.stride(1) stride_o_token = out.stride(2) + stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 @@ -490,11 +612,19 @@ def causal_conv1d_fn( assert cache_indices.dim() == 1 assert padded_batch == cache_indices.size(0) if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch, ) - assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert has_initial_state.size() == (padded_batch,) + assert conv_states is not None, ( + "ERROR: `has_initial_state` is used, which needs also `conv_states`" + ) assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" + if block_size_to_align is not None and block_size_to_align > 0: + assert (block_size_to_align % BLOCK_M) == 0, ( + "The mamba block size needs to be divisible by the BLOCK_M" + ) + else: + block_size_to_align = BLOCK_M if metadata is None: @@ -516,44 +646,45 @@ def causal_conv1d_fn( if META["batch_ptr"].nelement() < len(mlist): newlen = len(mlist) + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= len(mlist): - META["batch_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(mlist))) - META["token_chunk_offset_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(offsetlist))) + META["batch_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(mlist)) + ) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist)) + ) META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( - META["x_ptr"].device) + META["x_ptr"].device + ) return tot else: def num_program(META, nums_dict): - tot = nums_dict[META["BLOCK_M"]]['tot'] + tot = nums_dict[META["BLOCK_M"]]["tot"] - mlist = nums_dict[META["BLOCK_M"]]['mlist'] - mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len'] + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] - offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist'] + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] - META["token_chunk_offset_ptr"] = nums_dict[ - META["BLOCK_M"]]["token_chunk_offset_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][ + "token_chunk_offset_ptr" + ] else: if META["batch_ptr"].nelement() < mlist_len: newlen = mlist_len + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if 
META["batch_ptr"].nelement() >= mlist_len: META["batch_ptr"][0:mlist_len].copy_(mlist) - META["token_chunk_offset_ptr"][0:mlist_len].copy_( - offsetlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) return tot def grid(META): @@ -577,14 +708,16 @@ def causal_conv1d_fn( query_start_loc, batch_ptr, token_chunk_offset_ptr, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, + num_computed_tokens, out, # Matrix dimensions - padded_batch, dim, cu_seqlen, num_cache_lines, # stride - stride_x_seq, stride_x_dim, stride_x_token, stride_w_dim, @@ -592,26 +725,25 @@ def causal_conv1d_fn( stride_istate_seq, stride_istate_dim, stride_istate_token, - stride_o_seq, + stride_cache_indices, stride_o_dim, stride_o_token, + block_size_to_align // BLOCK_M, # others pad_slot_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, - #launch_cooperative_grid=True - BLOCK_M=8, + # launch_cooperative_grid=True + BLOCK_M=BLOCK_M, BLOCK_N=256, num_stages=2, ) - return out + return out.to(original_x_dtype) @triton.jit() @@ -621,8 +753,11 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, - cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -639,6 +774,7 @@ def _causal_conv1d_update_kernel( stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -648,7 +784,9 @@ def _causal_conv1d_update_kernel( HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, @@ -661,24 +799,70 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + if IS_APC_ENABLED: + # Get the state from the initial_state_idx + conv_state_init = tl.load(initial_state_idx + idx_seq) + current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) else: - conv_state_batch_coord = idx_seq + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init + ).to(tl.int64) + if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - 
(query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+    if query_start_index == query_end_index:
+        return
+
+    if IS_SPEC_DECODING:
+        # The rolling of conv state:
+        #
+        # Before forward, the conv_state is:
+        # [history1, history2, ..., historyM].
+        #
+        # After forward, the conv_state becomes:
+        # [history2, ..., historyM, draft1, draft2, ..., draftN].
+        #
+        # After acceptance, it becomes:
+        #
+        # - accept 1 token: [history2, ..., historyM, draft1]
+        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
+        # - and so on.
+        conv_state_token_offset = (
+            tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1
+        )
+    else:
+        conv_state_token_offset = 0
+
     # STEP 1: READ init_state data
-    conv_states_base = (conv_state_ptr +
-                        (conv_state_batch_coord * stride_conv_state_seq) +
-                        (idx_feats * stride_conv_state_dim))
+    conv_states_base = (
+        conv_state_ptr
+        + (conv_states_input_coord * stride_conv_state_seq)
+        + (idx_feats * stride_conv_state_dim)
+    )
     mask_w = idx_feats < dim
-    prior_tokens = conv_states_base
+    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
     if KERNEL_WIDTH >= 2:
         conv_states_ptrs = prior_tokens  # [BLOCK_N]
         col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
@@ -688,43 +872,64 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
-    if KERNEL_WIDTH == 5:
+    if KERNEL_WIDTH >= 5:
         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 6:
+        conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
+        col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+    # With speculative decoding, the conv_state update works in a sliding-window
+    # manner: at each forward pass the tokens are shifted by 1, so we load
+    # starting from idx_tokens + 1.
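[Editor's aside, not part of the patch: the rolling conv_state behaviour described in the comment above can be reproduced on the host with a tiny sketch. The tensors, sizes, and variable names below are hypothetical toy values chosen only for illustration.]

    import torch

    width = 4                               # conv kernel width -> needs width - 1 = 3 tokens of context
    state = torch.tensor([1.0, 2.0, 3.0])   # [history1, history2, history3]
    drafts = torch.tensor([10.0, 11.0])     # draft tokens proposed in this step
    num_accepted = 1                        # draft tokens actually accepted

    # Keep the last (width - 1) tokens of history + accepted drafts, i.e. slide
    # the window forward by `num_accepted` positions; this is what the
    # `num_accepted_tokens - 1` offset combined with the `idx_tokens + 1` shift
    # implements inside the kernel.
    new_state = torch.cat([state, drafts[:num_accepted]])[-(width - 1):]
    print(new_state)  # tensor([ 2.,  3., 10.])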
conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] - x_ptrs = x_base[None, :] + ( - (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens - VAL >= 0)[:, None] & - (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) tl.debug_barrier() new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + # Get the state from the initial_state_idx + # cache_idx + conv_states_offset = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index + ).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -732,10 +937,11 @@ def _causal_conv1d_update_kernel( if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) # STEP 4: # PRE-LOAD WEIGHTS @@ -753,12 +959,18 @@ def _causal_conv1d_update_kernel( if KERNEL_WIDTH >= 4: w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) x_base_1d = x_base # starting of chunk [BLOCK_N] mask_x_1d = idx_feats < dim # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): + for idx_token in tl.range(seqlen): acc = 
acc_preload matrix_w = w_col0 @@ -788,6 +1000,37 @@ def _causal_conv1d_update_kernel( matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) acc += matrix_x * matrix_w # [BLOCK_N] @@ -800,14 +1043,26 @@ def _causal_conv1d_update_kernel( col0 = col1 col1 = col2 col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = o_ptr + ( - idx_seq) * stride_o_seq + idx_token * stride_o_token + ( - idx_feats * stride_o_dim) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) @@ -818,82 +1073,117 @@ def causal_conv1d_update( weight: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, - metadata=None, + block_idx_last_scheduled_token: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, validate_data=False, ): """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] + x: Input tensor which can take the following shapes: + + - `[batch, dim]` - single token prediction + - `[batch, dim, seqlen]` - single or multiple tokens prediction + - `[num_tokens, dim]` - continuous batching, where num_tokens is + the total tokens of all sequences in that batch + conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into conv_state_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into conv_state_indices, where the cache block containing the initial state is located. 
+    num_accepted_tokens: (batch,), dtype int32
+        If not None, it indicates the number of accepted tokens for each
+        sequence in the batch.
+        This is used in speculative decoding, where the conv_state is updated
+        in a sliding window manner.
+    query_start_loc: (batch + 1,) int32
+        If not None, the input is given in a varlen fashion and this indicates
+        the starting index of each sequence in the batch.
+    max_query_len: int
+        If query_start_loc is not None, this indicates the maximum query
+        length in the batch.
     pad_slot_id: int
-        if cache_indices is passed, lets the kernel identify padded
+        if conv_state_indices is passed, lets the kernel identify padded
         entries that will not be processed,
-        for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
+        for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim) or (batch, dim, seqlen)
+    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     if validate_data:
-        assert cache_seqlens is None  # not implemented yet - ok for vLLM
         assert pad_slot_id is not None
         assert x.stride(1) == 1
     if isinstance(activation, bool):
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
-    unsqueeze = x.dim() == 2
+
+    original_x_dtype = x.dtype
+    x = x.to(conv_state.dtype)
+    unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
         x = x.unsqueeze(-1)
-    batch, dim, seqlen = x.shape
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+    else:
+        assert conv_state_indices is not None
+        batch = conv_state_indices.size(0)
+        dim = x.size(1)
+        seqlen = max_query_len
     _, width = weight.shape
     # conv_state: (..., dim, state_len), where state_len >= width - 1
     num_cache_lines, _, state_len = conv_state.size()
     if validate_data:
         assert dim == weight.size(0)
-        assert conv_state.stride(
-            -2
-        ) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
+        assert conv_state.stride(-2) == 1, (
+            f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
+        )
         assert state_len >= width - 1
         # when above happens, we don't shift-left to keep any records in conv_state
         assert dim == conv_state.size(1)
         if conv_state_indices is None:
             assert conv_state.size(0) >= batch
         else:
-            assert (batch, ) == conv_state_indices.shape
+            assert (batch,) == conv_state_indices.shape
             assert num_cache_lines >= batch
         assert weight.stride(1) == 1  # Need this
-        assert cache_seqlens is None  # not needed for vLLM - circular buffer
     # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
     out = x
     stride_w_dim, stride_w_width = weight.stride()
-    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
-    )  # X (batch, dim, seqlen)
+    if query_start_loc is None:
+        # X (batch, dim, seqlen)
+        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
+        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    else:
+        # X (num_tokens, dim)
+        stride_x_token, stride_x_dim = x.stride()
+        stride_x_seq = 0
+        stride_o_token, stride_o_dim = out.stride()
+        stride_o_seq = 0
-    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
-
-    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
+    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
+    stride_state_indices = (
+        conv_state_indices.stride(0) if
conv_state_indices is not None else 0 ) - state_len = width - 1 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 np2_statelen = triton.next_power_of_2(state_len) def grid(META): @@ -908,8 +1198,11 @@ def causal_conv1d_update( weight, bias, conv_state, - cache_seqlens, conv_state_indices, + num_accepted_tokens, + query_start_loc, + block_idx_last_scheduled_token, + initial_state_idx, out, # Matrix dimensions batch, @@ -926,6 +1219,7 @@ def causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token, + stride_state_indices, stride_o_seq, stride_o_dim, stride_o_token, @@ -935,11 +1229,13 @@ def causal_conv1d_update( HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_VARLEN=query_start_loc is not None, + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=256, ) if unsqueeze: out = out.squeeze(-1) - return out + return out.to(original_x_dtype) diff --git a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py index f3a45ab097c34..b592906c6f130 100644 --- a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py +++ b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py @@ -46,17 +46,17 @@ def _layer_norm_fwd_1pass_kernel( B += group * N # Compute mean and variance cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=cols < N).to(tl.float32) x *= z * tl.sigmoid(z) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.) + xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: - xbar = tl.where(cols < N, x, 0.) 
+ xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) @@ -74,15 +74,17 @@ def _layer_norm_fwd_1pass_kernel( tl.store(Y + cols, y, mask=mask) -def _layer_norm_fwd(x, - weight, - bias, - eps, - z=None, - out=None, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): M, N = x.shape if group_size is None: group_size = N @@ -92,57 +94,57 @@ def _layer_norm_fwd(x, if z is not None: assert z.stride(-1) == 1 assert z.shape == (M, N) - assert weight.shape == (N, ) + assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 - assert bias.shape == (N, ) + assert bias.shape == (N,) # allocate output if out is not None: assert out.shape == x.shape else: out = torch.empty_like(x) assert out.stride(-1) == 1 - mean = torch.empty((ngroups * M, ), dtype=torch.float32, - device=x.device) if not is_rms_norm else None - rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[grid](x, - out, - weight, - bias, - z, - mean, - rstd, - x.stride(0), - out.stride(0), - z.stride(0) if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps) + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) return out, mean, rstd -def rms_norm_gated(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True): +def rms_norm_gated( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): x_shape_og = x.shape # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) @@ -156,13 +158,15 @@ def rms_norm_gated(x, weight = weight.contiguous() if bias is not None: bias = bias.contiguous() - y, _, _ = _layer_norm_fwd(x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=True) + y, _, _ = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=True, + ) return y.reshape(x_shape_og) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 838290a9f5fb2..8722eb9a7b22f 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -11,8 +11,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.utils import 
PAD_SLOT_ID from vllm.triton_utils import HAS_TRITON, tl, triton -TRITON3 = HAS_TRITON and (version.parse(triton.__version__) - >= version.parse("3.0.0")) +TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) if TRITON3: @@ -28,16 +27,18 @@ else: return dt -@triton.heuristics( - {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics({ - "HAS_STATE_BATCH_INDICES": - lambda args: args["state_batch_indices_ptr"] is not None -}) @triton.heuristics( - {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) + { + "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] + is not None + } +) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} +) @triton.jit def _selective_scan_update_kernel( # Pointers to matrices @@ -52,6 +53,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + dst_state_batch_indices_ptr, pad_slot_id, # Matrix dimensions batch, @@ -107,11 +109,18 @@ def _selective_scan_update_kernel( # is taken from the state_batch_indices_ptr Otherwise, the state coordinate # is the same as the batch id. if HAS_STATE_BATCH_INDICES: + dst_state_batch_indices_ptr += pid_b + dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) + dst_state_ptr = state_ptr + ( + dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) - state_ptr += (state_batch_idx * stride_state_batch + - pid_h * stride_state_head) + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: + dst_state_ptr = ( + state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head + ) state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head @@ -119,26 +128,29 @@ def _selective_scan_update_kernel( if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // - nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // - nheads_ngroups_ratio) * stride_C_group + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group if HAS_Z: z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + - offs_n[None, :] * stride_state_dstate) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + dst_state_ptrs = dst_state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + - offs_n[None, :] * stride_A_dstate) + A_ptrs = A_ptr + ( + offs_m[:, None] * stride_A_dim + offs_n[None, :] * 
stride_A_dstate + ) B_ptrs = B_ptr + offs_n * stride_B_dstate C_ptrs = C_ptr + offs_n * stride_C_dstate if HAS_D: @@ -148,20 +160,19 @@ def _selective_scan_update_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) + mask &= state_batch_idx != pad_slot_id state = tl.load(state_ptrs, mask=mask, other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if DT_SOFTPLUS: dt = softplus(dt) - A = tl.load(A_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) + A = tl.load( + A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) dA = tl.exp(A * dt[:, None]) else: dt = tl.load(dt_ptr).to(tl.float32) @@ -184,8 +195,8 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, state, mask=mask) + mask &= state_batch_idx != pad_slot_id + tl.store(dst_state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -194,19 +205,22 @@ def _selective_scan_update_kernel( tl.store(out_ptrs, out, mask=offs_m < dim) -def selective_state_update(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False, - state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID, - out=None): +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + dst_state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -219,12 +233,12 @@ def selective_state_update(state, z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 - out: Preallocated ssm output tensor. Assume same shape as x. + out: Preallocated ssm output tensor. Assume same shape as x. In-place updated. 
""" if state.dim() == 3: @@ -265,20 +279,33 @@ def selective_state_update(state, if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: - assert state_batch_indices.shape == (batch, ) + assert state_batch_indices.shape == (batch,) + if dst_state_batch_indices is not None: + assert dst_state_batch_indices.shape == (batch,) + else: + # revert to the default behavior of in-place state updates + dst_state_batch_indices = state_batch_indices assert out.shape == x.shape - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else - (0, 0, 0)) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) # We don't want autotune since it will overwrite the state # We instead tune by hand. - BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else - ((16, 4) if dstate <= 32 else - ((8, 4) if dstate <= 64 else - ((4, 4) if dstate <= 128 else ((4, 8)))))) - tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( - -1) == 0 and dt_bias.stride(-1) == 0 + BLOCK_SIZE_M, num_warps = ( + (32, 4) + if dstate <= 16 + else ( + (16, 4) + if dstate <= 32 + else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) + ) + ) + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and dt_bias.stride(-1) == 0 + ) with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, @@ -292,6 +319,7 @@ def selective_state_update(state, z, out, state_batch_indices, + dst_state_batch_indices, pad_slot_id, batch, nheads, @@ -308,8 +336,7 @@ def selective_state_update(state, dt.stride(0), dt.stride(1), dt.stride(2), - *(dt_bias.stride(0), - dt_bias.stride(1)) if dt_bias is not None else 0, + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, A.stride(0), A.stride(1), A.stride(2), @@ -333,54 +360,56 @@ def selective_state_update(state, ) -def selective_scan_fn(u, - ssm_states, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - query_start_loc=None, - cache_indices=None, - has_initial_state=None, - pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: +def selective_scan_fn( + u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID, +) -> torch.Tensor: """ - u: (dim, total_length) for varlen or (batch, dim, seqlen) + u: (dim, total_length) for varlen or (batch, dim, seqlen) applies changes in place. ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) applies changes in place. delta: (dim, total_length) for varlen or (batch, dim, seqlen) - A: (dim, dstate) - B: (ngroups, dstate, total_length) for varlen or + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or (batch,ngroups,dstate,seqlen) - C: (ngroups, dstate, total_length) for varlen or + C: (ngroups, dstate, total_length) for varlen or (batch,ngroups,dstate,seqlen) - D: (dim,) - z: (dim, total_length) for varlen or (batch, dim, seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) dt_bias: (dim,) or (dim) query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended with 0. 
- for example: query_start_loc = torch.Tensor([0,10,16,17]), + for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 - A tensor with each cell is a correspondent + A tensor with each cell is a correspondent input and output ssm_state index has_initial_state: (batch) bool - A tensor populated with ones and zeros, - indicate if the ssm_state at the corresponding index should be - used as initial state. Not providing argument assumes + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes there's no initial state pad_slot_id: int - if cache_indices is passed, lets the kernel identify padding entries - that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 returns - output: (dim, total_length) for varlen or (batch, dim, seqlen) + output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement """ if u.stride(-1) != 1: @@ -404,9 +433,22 @@ def selective_scan_fn(u, if C.dim() == 2 and query_start_loc is not None: C = C.unsqueeze(0) - ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, - query_start_loc, cache_indices, has_initial_state, - ssm_states, pad_slot_id) + ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ) if z is None: return delta # output written inplace to delta diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 11ca1255ebfb6..ac5ffc10f2950 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -6,8 +6,6 @@ # ruff: noqa: E501,SIM102 -import math - import torch from vllm.triton_utils import tl, triton @@ -16,79 +14,52 @@ from vllm.triton_utils import tl, triton @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 
'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['chunk_size', 'K', 'IS_CAUSAL'], + key=["chunk_size", "K", "IS_CAUSAL"], ) @triton.jit def _bmm_chunk_fwd_kernel( @@ -96,37 +67,30 @@ def _bmm_chunk_fwd_kernel( a_ptr, b_ptr, out_ptr, - seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions seqlen, - chunk_size, - K, - ngroups, - stride_a_batch, - stride_a_seqlen, - stride_a_head, - stride_ak, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_bk, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_outm, - stride_outn, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + chunk_size: tl.constexpr, + K: tl.constexpr, + ngroups: tl.constexpr, + stride_a_seqlen: tl.int64, + stride_a_head: tl.int64, + stride_ak: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_bk: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_outm: tl.int64, + stride_outn: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_ch = tl.program_id(axis=1).to(tl.int64) pid_c = pid_ch // ngroups pid_h = pid_ch - pid_c * ngroups num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) @@ -135,128 +99,113 @@ def _bmm_chunk_fwd_kernel( if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head + b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + - offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + - offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # compute a * b.T for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0).to(dot_dtype) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < K - k * 
BLOCK_SIZE_K) & - (offs_n[None, :] < chunk_size_limit), - other=0.0).to(dot_dtype) + a = tl.load( + a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ).to(dot_dtype) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) + & (offs_n[None, :] < chunk_size_limit), + other=0.0, + ).to(dot_dtype) acc += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - mask=offs_n < chunk_size_limit, - other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) - - out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + - offs_n[None, :] * stride_outn) - tl.store(out_ptrs, - out, - mask=(offs_m[:, None] < chunk_size) & - (offs_n[None, :] < chunk_size)) + out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), + ) -def _bmm_chunk_fwd(a, - b, - chunk_size, - seq_idx=None, - causal=False, - output_dtype=None): +def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None): """ Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + a: (seqlen, ngroups, k) + b: (seqlen, ngroups, k) + chunk_size: int + cu_chunk_seq_lens: (nchunks+1,) causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. Return: - out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + out: (nchunks, ngroups, chunk_size, chunk_size) """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape + seqlen, ngroups, k = a.shape assert b.shape == a.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if a.stride(-1) != 1 and a.stride(1) != 1: + if a.stride(-1) != 1 and a.stride(0) != 1: a = a.contiguous() - if b.stride(-1) != 1 and b.stride(1) != 1: + if b.stride(-1) != 1 and b.stride(0) != 1: b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) + + nchunks = len(cu_chunk_seqlens) - 1 # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty( - (batch, nchunks, chunk_size, chunk_size) if not has_groups else - (batch, nchunks, ngroups, chunk_size, chunk_size), - device=a.device, - dtype=out_dtype) - dot_dtype = (tl.bfloat16 - if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else - (tl.float16 if a.dtype == torch.float16 - or b.dtype == torch.float16 else tl.float32)) - grid = lambda META: (triton.cdiv( - chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - chunk_size, META['BLOCK_SIZE_N']), batch, nchunks - if not has_groups else nchunks * ngroups) + (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + nchunks * ngroups, + ) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( - a, - b, - out, - seq_idx, - seqlen, - chunk_size, - k, - ngroups if has_groups else 1, - a.stride(0), - a.stride(1), - 0 if not has_groups else a.stride(2), - a.stride(-1), - b.stride(0), - b.stride(1), - 0 if not has_groups else b.stride(2), - b.stride(-1), - out.stride(0), - out.stride(1), - 0 if not has_groups else out.stride(2), - out.stride(-2), - out.stride(-1), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - causal, - dot_dtype, - HAS_SEQ_IDX=seq_idx is not None, + a_ptr=a, + b_ptr=b, + out_ptr=out, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + chunk_size=chunk_size, + K=k, + ngroups=ngroups, + stride_a_seqlen=a.stride(0), + stride_a_head=a.stride(1), + stride_ak=a.stride(2), + stride_b_seqlen=b.stride(0), + stride_b_head=b.stride(1), + stride_bk=b.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_outm=out.stride(-2), + stride_outn=out.stride(-1), + IS_CAUSAL=causal, + dot_dtype=dot_dtype, ) return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 365139e237c66..e5a5c9dd6f712 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -6,106 +6,72 @@ # ruff: noqa: E501,SIM102 -import torch from packaging import version from vllm.triton_utils import tl, triton -TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 
'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], + key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], ) @triton.jit def _chunk_scan_fwd_kernel( @@ -114,7 +80,6 @@ def _chunk_scan_fwd_kernel( x_ptr, z_ptr, out_ptr, - out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, @@ -122,64 +87,51 @@ def _chunk_scan_fwd_kernel( states_ptr, D_ptr, initstates_ptr, - chunk_indices_ptr, - chunk_offsets_ptr, - chunk_meta_num, + cu_chunk_seqlens_ptr, # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, + chunk_size: tl.constexpr, + hdim: tl.constexpr, + dstate: tl.constexpr, seqlen, - nheads_ngroups_ratio, + nheads_ngroups_ratio: tl.constexpr, # Strides - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, - stride_D_head, + stride_cb_chunk: tl.int64, + stride_cb_head: tl.int64, + stride_cb_csize_m: tl.int64, + stride_cb_csize_k: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_z_seqlen: tl.int64, + stride_z_head: tl.int64, + stride_z_hdim: tl.constexpr, + stride_out_seqlen: tl.int64, + stride_out_head: tl.int64, + stride_out_hdim: tl.constexpr, + stride_dt_chunk: tl.int64, + stride_dt_head: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_head: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, + stride_C_seqlen: tl.int64, + 
stride_C_head: tl.int64, + stride_C_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + stride_D_head: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -187,256 +139,210 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - if not HAS_INITSTATES: - c_idx = pid_c - c_off = 0 - else: - c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) - c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) - + pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( - pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_C_head + cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += ( + chunk_seqlen_start * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_C_head + ) # M-block offsets and prev states # - logic in next block may override these if there is an active offset - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - prev_states_hdim = stride_states_hdim - prev_states_dstate = stride_states_dstate + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_c * stride_seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 + ) - # - we only need seq_idx_prev to be aligned to chunk boundary - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=c_idx >= 1, - other=0) + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states_ptr = ( + initstates_ptr + + seq_idx * stride_init_states_batch + + pid_h * stride_init_states_head + ) + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = ( + states_ptr + (pid_c - 1) * 
stride_states_chunk + pid_h * stride_states_head + ) + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr + - (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ((c_off == 0) and - (seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): - - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size, - other=0.0).to(tl.float32) - - # - handle chunk state limit - if HAS_INITSTATES: - - # have to split this if otherwise compilation will have problems - dA_cs_m_boundary = 0.0 - - # get the c_idx for the next (logica) chunk - c_idx_n = tl.load( - chunk_indices_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1 # to trigger different chunk - ) - - # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct - # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next - # (logical) chunk. - # An equivalent check for B is c_idx == c_idx_n, where there is repetition in - # (logical) chunk indices. - - if (c_idx == c_idx_n) or c_off > 0: - - # get the next offset - c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size) - - # in this case, adjust down the chunk_size_limit - if c_idx == c_idx_n: - chunk_size_limit = min(c_off_n, chunk_size_limit) - - # get the cs at the offset boundary - # - c_off == 0 is a passthrough - dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), - other=0.0).to(tl.float32) - - if HAS_SEQ_IDX: - # - handle seq idx when HAS_INITSTATES==False - if not HAS_INITSTATES: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) + dA_cs_m = tl.load( + dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 + ).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. 
- # With Triton 2.2.0, this works - if IS_TRITON_22 or c_idx > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + - offs_k_dstate[None, :] * stride_C_dstate) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * prev_states_hdim + - offs_k_dstate[:, None] * prev_states_dstate) - if HAS_SEQ_IDX: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), - 0.0) - else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. - scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + other=0.0, + ) + + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) else: - scale_m = tl.exp(dA_cs_m) - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate), - other=0.0) - - prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) + # otherwise read the previous state + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate - k), - other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + + acc = tl.dot(C, prev_states) * scale_m[:, None] + + else: + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty + ) + else: prev_states = tl.load( prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K - acc *= scale_m[:, None] + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs 
+= BLOCK_SIZE_K + acc *= scale_m[:, None] - offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + - offs_k[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + - offs_n[None, :] * stride_x_hdim) + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) dt_ptrs = dt_ptr + offs_k * stride_dt_csize dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit if not IS_CAUSAL else min( - (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) for k in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size) & - (offs_k[None, :] < chunk_size - k), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size - k, - other=0.0).to(tl.float32) + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, - other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: mask = offs_m[:, None] >= k + offs_k[None, :] cb = tl.where(mask, cb, 0.0) cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < hdim), - other=0.0) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) acc += tl.dot(cb, x) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k x_ptrs += BLOCK_SIZE_K * stride_x_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, - mask=offs_n < hdim, - other=0.0).to(tl.float32) + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + - offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_n[None, :] < hdim), - other=0.0).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) acc += x_residual * D if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + - offs_out_n[None, :]) - tl.store(out_x_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim)) - - z_ptr += pid_b * stride_z_batch + c_idx * 
chunk_size * stride_z_seqlen + pid_h * stride_z_head - z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + - stride_z_hdim * offs_out_n[None, :]) - z = tl.load(z_ptrs, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim), - other=0.0).to(tl.float32) + z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + - offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim)) + out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) def _chunk_scan_fwd( @@ -446,126 +352,105 @@ def _chunk_scan_fwd( dA_cumsum, C, states, + cu_chunk_seqlens, + out, + seq_idx, D=None, z=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, initial_states=None, - out=None, ): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape + assert seq_idx is not None, "this implementation requires seq_idx" + + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = C.shape assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert C.shape == (seqlen, ngroups, dstate) + assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) if z is not None: assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads, ) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert states.shape == (batch, nchunks, nheads, headdim, dstate) - - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - - if initial_states is not None: - # with initial states, we need to take care of how - # seq_idx crosses the boundaries - assert batch == 1, "chunk scan only supports initial states with batch 1" - assert chunk_indices is not None and chunk_offsets is not None, \ - "chunk_indices and chunk_offsets should have been set" - else: - chunk_indices, chunk_offsets = None, None - else: - chunk_indices, chunk_offsets = None, None - - assert out.shape == x.shape - - if z is not None: - out_x = torch.empty_like(x) - assert out_x.stride() == out.stride() - else: - out_x = None + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + assert states.shape == (nchunks, nheads, headdim, dstate) + assert seq_idx.shape == (nchunks,) grid = lambda META: ( - triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), batch * nchunks - if chunk_offsets is None else len(chunk_offsets), nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2), - z.stride(3)) 
if z is not None else (0, 0, 0, 0)) + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) + + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + _chunk_scan_fwd_kernel[grid]( - cb, - x, - z, - out, - out_x, - dt, - dA_cumsum, - seq_idx, - C, - states, - D, - initial_states, - chunk_indices, - chunk_offsets, - len(chunk_indices) if chunk_indices is not None else 0, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - cb.stride(0), - cb.stride(1), - cb.stride(2), - cb.stride(3), - cb.stride(4), - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - dt.stride(0), - dt.stride(2), - dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else - (0, 0)), - C.stride(0), - C.stride(1), - C.stride(2), - C.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) if initial_states is not None else - (0, 0, 0, 0)), - D.stride(0) if D is not None else 0, - True, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + cb_ptr=cb, + x_ptr=x, + z_ptr=z, + out_ptr=out, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + seq_idx_ptr=seq_idx, + C_ptr=C, + states_ptr=states, + D_ptr=D, + initstates_ptr=initial_states, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + chunk_size=chunk_size, + hdim=headdim, + dstate=dstate, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_cb_chunk=cb.stride(0), + stride_cb_head=cb.stride(1), + stride_cb_csize_m=cb.stride(2), + stride_cb_csize_k=cb.stride(3), + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_z_seqlen=z_strides[0], + stride_z_head=z_strides[1], + stride_z_hdim=z_strides[2], + stride_out_seqlen=out.stride(0), + stride_out_head=out.stride(1), + stride_out_hdim=out.stride(2), + stride_dt_chunk=dt.stride(1), + stride_dt_head=dt.stride(0), + stride_dt_csize=dt.stride(2), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_seq_idx_chunk=seq_idx.stride(0), + stride_C_seqlen=C.stride(0), + stride_C_head=C.stride(1), + stride_C_dstate=C.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + stride_D_head=D.stride(0) if D is not None else 0, + IS_CAUSAL=True, + HAS_D=D is not None, + D_HAS_HDIM=D.dim() == 2 if D is not None else True, HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), IS_TRITON_22=TRITON_22, 
HAS_INITSTATES=initial_states is not None, ) - return out_x + return diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ad58a9918f03c..11cc125bf219c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,8 +6,6 @@ # ruff: noqa: E501 -import math - import torch from vllm.triton_utils import tl, triton @@ -17,15 +15,14 @@ from .mamba_ssm import softplus @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), - triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), + triton.Config({"BLOCK_SIZE_H": 2}), + triton.Config({"BLOCK_SIZE_H": 4}), + triton.Config({"BLOCK_SIZE_H": 8}), + triton.Config({"BLOCK_SIZE_H": 16}), + triton.Config({"BLOCK_SIZE_H": 32}), + triton.Config({"BLOCK_SIZE_H": 64}), ], - key=['chunk_size', 'nheads'], + key=["chunk_size", "nheads"], ) @triton.jit def _chunk_cumsum_fwd_kernel( @@ -35,158 +32,137 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + cu_chunk_seqlens_ptr, # Matrix dimension - batch, seqlen, - nheads, - chunk_size, - dt_min, - dt_max, + nheads: tl.constexpr, + chunk_size: tl.constexpr, + dt_min: tl.constexpr, + dt_max: tl.constexpr, # Strides - stride_dt_batch, - stride_dt_seqlen, - stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_dt_out_batch, - stride_dt_out_chunk, - stride_dt_out_head, - stride_dt_out_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, + stride_dt_seqlen: tl.int64, + stride_dt_head: tl.constexpr, + stride_A_head: tl.constexpr, + stride_dt_bias_head: tl.constexpr, + stride_dt_out_head: tl.int64, + stride_dt_out_chunk: tl.int64, + stride_dt_out_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, ): - pid_b = tl.program_id(axis=0) - # if dt is long, may cause problems, so use 64 bit # https://github.com/triton-lang/triton/issues/1058 - pid_c = tl.program_id(axis=1).to(tl.int64) - pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_c = tl.program_id(axis=0).to(tl.int64) + pid_h = tl.program_id(axis=1) + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + dt_ptr += chunk_seqlen_start * stride_dt_seqlen + dt_out_ptr += pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + - offs_c[None, :] * stride_dt_seqlen) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + - offs_c[None, :] * stride_dt_out_csize) - dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + - offs_c[None, :] * stride_dA_cs_csize) - 
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - dt = tl.load(dt_ptrs, - mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_size_limit), - other=0.0).to(tl.float32) + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, - mask=offs_h < nheads, - other=0.0).to(tl.float32) + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) dt += dt_bias[:, None] if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + + dt = tl.clamp(dt, dt_min, dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, - 0.0) - tl.store(dt_out_ptrs, - dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) - tl.store(dA_cs_ptrs, - dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 
'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['hdim', 'dstate', 'chunk_size'], + key=["hdim", "dstate", "chunk_size"], ) @triton.jit def _chunk_state_fwd_kernel( @@ -196,118 +172,103 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions - hdim, - dstate, - chunk_size, - batch, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, seqlen, - nheads_ngroups_ratio, + nheads_ngroups_ratio: tl.constexpr, # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch + pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + b_ptr += ( + chunk_seqlen_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + - offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + - offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) dt_ptrs = dt_ptr + offs_k 
* stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + - (chunk_size_limit - 1) * stride_seq_idx_seqlen) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_k[None, :] < chunk_size_limit - k), - other=0.0) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, - mask=offs_k < chunk_size_limit - k, - other=-1) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - else: - scale = tl.where(seq_idx_k == seq_idx_last, - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) - states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + - offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) @@ -315,79 +276,52 @@ def _chunk_state_fwd_kernel( @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, 
num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['hdim', 'dstate', 'chunk_size'], + key=["hdim", "dstate", "chunk_size"], ) @triton.jit def _chunk_state_varlen_kernel( @@ -401,36 +335,35 @@ def _chunk_state_varlen_kernel( states_ptr, initstates_ptr, # Matrix dimensions - hdim, - dstate, - chunk_size, - seqlen, - nheads_ngroups_ratio, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + nheads_ngroups_ratio: tl.constexpr, # Strides - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_chunk_states_chunk, - stride_chunk_states_head, - stride_chunk_states_hdim, - stride_chunk_states_dstate, - stride_states_batch, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_chunk_states_chunk: tl.int64, + stride_chunk_states_head: tl.int64, + stride_chunk_states_hdim: tl.int64, + stride_chunk_states_dstate: tl.constexpr, + stride_states_batch: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -444,12 +377,16 @@ def _chunk_state_varlen_kernel( pid_n = tl.program_id(axis=0) % num_pid_n end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) pid_c = (end_idx - 1) // chunk_size - b_ptr += pid_c * chunk_size * stride_b_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += ( + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) x_ptr += pid_c * chunk_size * 
stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) if HAS_INITSTATES: # if there are init states provided, we differentiate between states (which @@ -460,13 +397,16 @@ def _chunk_state_varlen_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + - offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + - offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * - stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = end_idx - pid_c * chunk_size @@ -475,24 +415,31 @@ def _chunk_state_varlen_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_k[None, :] < chunk_size_limit - k) & - (offs_k[None, :] >= start_idx_cur - k), - other=0.0) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < dstate) & - (offs_k[:, None] >= start_idx_cur - k), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) scale = tl.where( (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + tl.exp(dA_cs_last - dA_cs_k) * dt_k, + 0.0, + ) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -502,42 +449,46 @@ def _chunk_state_varlen_kernel( dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - # If HAS_INITSTATES==True need to consider two possiblties + # If HAS_INITSTATES==True need to consider two possibilities # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs # - if state_idx >= pid * chunk_size, then we need to insert initstates - if ((start_idx < pid_c * chunk_size) # first chunk - or (HAS_INITSTATES)): - + if ( + (start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES) + 
): dA_cs_boundary = 0.0 # default if not HAS_INITSTATES: past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + - offs_n[None, :] * stride_chunk_states_dstate) + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) else: - # - this seems repetitive, buts its to help the compiler if start_idx < pid_c * chunk_size: past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + - offs_n[None, :] * stride_chunk_states_dstate) + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) else: past_states_ptrs = initstates_ptr + ( - pid_b * stride_init_states_batch + - offs_m[:, None] * stride_init_states_hdim + - offs_n[None, :] * stride_init_states_dstate) + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate + ) # need to adjust the boundary if start_idx > pid_c * chunk_size: - dA_cs_boundary = tl.load(dA_cumsum_ptr + - (start_idx - pid_c * chunk_size - - 1) * stride_dA_cs_csize).to( - tl.float32) + dA_cs_boundary = tl.load( + dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) - past_states = tl.load(past_states_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) + past_states = tl.load( + past_states_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) scale = tl.exp(dA_cs_last - dA_cs_boundary) acc += past_states * scale @@ -547,145 +498,125 @@ def _chunk_state_varlen_kernel( states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + - offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) -def _chunk_cumsum_fwd(dt, - A, - chunk_size, - dt_bias=None, - dt_softplus=False, - dt_limit=(0.0, float("inf"))): - batch, seqlen, nheads = dt.shape - assert A.shape == (nheads, ) +def _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), +): + seqlen, nheads = dt.shape + assert A.shape == (nheads,) if dt_bias is not None: - assert dt_bias.shape == (nheads, ) - nchunks = math.ceil(seqlen / chunk_size) - dt_out = torch.empty(batch, - nheads, - nchunks, - chunk_size, - device=dt.device, - dtype=torch.float32) - dA_cumsum = torch.empty(batch, - nheads, - nchunks, - chunk_size, - device=dt.device, - dtype=torch.float32) - grid_chunk_cs = lambda META: (batch, nchunks, - triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + assert dt_bias.shape == (nheads,) + nchunks = cu_chunk_seqlens.shape[0] - 1 + dt_out = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, - A, - dt_bias, - dt_out, - dA_cumsum, - batch, - seqlen, - nheads, - chunk_size, - dt_limit[0], - dt_limit[1], - dt.stride(0), - dt.stride(1), 
- dt.stride(2), - A.stride(0), - dt_bias.stride(0) if dt_bias is not None else 0, - dt_out.stride(0), - dt_out.stride(2), - dt_out.stride(1), - dt_out.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - dt_softplus, + dt_ptr=dt, + A_ptr=A, + dt_bias_ptr=dt_bias, + dt_out_ptr=dt_out, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + nheads=nheads, + chunk_size=chunk_size, + dt_min=dt_limit[0], + dt_max=dt_limit[1], + stride_dt_seqlen=dt.stride(0), + stride_dt_head=dt.stride(1), + stride_A_head=A.stride(0), + stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0, + stride_dt_out_head=dt_out.stride(0), + stride_dt_out_chunk=dt_out.stride(1), + stride_dt_out_csize=dt_out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + DT_SOFTPLUS=dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out -def _chunk_state_fwd(B, - x, - dt, - dA_cumsum, - seq_idx=None, - states=None, - states_in_fp32=True): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape +def _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True +): + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) + if states is not None: - assert states.shape == (batch, nchunks, nheads, headdim, dstate) + assert states.shape == (nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty((batch, nchunks, nheads, headdim, dstate), - device=x.device, - dtype=states_dtype) + states = torch.empty( + (nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype + ) + grid = lambda META: ( - triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( - dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( - x, - B, - states, - dt, - dA_cumsum, - seq_idx, - headdim, - dstate, - chunk_size, - batch, - seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - B.stride(0), - B.stride(1), - B.stride(2), - B.stride(-1), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - dt.stride(0), - dt.stride(2), - dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - HAS_SEQ_IDX=seq_idx is not None, + x_ptr=x, + b_ptr=B, + states_ptr=states, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + 
stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), ) return states -def chunk_state_varlen(B, - x, - dt, - dA_cumsum, - cu_seqlens, - chunk_states, - initial_states=None): +def chunk_state_varlen( + B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None +): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape @@ -700,52 +631,70 @@ def chunk_state_varlen(B, if initial_states is not None: assert initial_states.shape == (batch, nheads, headdim, dstate) - states = torch.empty(batch, - nheads, - headdim, - dstate, - dtype=chunk_states.dtype, - device=chunk_states.device) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. - cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + states = torch.empty( + batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device, + ) + + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch, + nheads, + ) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( - x, - B, - dt, - dA_cumsum, - chunk_states, - cu_seqlens, - states, - initial_states, - headdim, - dstate, - chunk_size, - total_seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - B.stride(0), - B.stride(1), - B.stride(2), - dt.stride(1), - dt.stride(0), - dt.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - chunk_states.stride(0), - chunk_states.stride(1), - chunk_states.stride(2), - chunk_states.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) if initial_states is not None else - (0, 0, 0, 0)), - HAS_INITSTATES=initial_states is not None) + x_ptr=x, + b_ptr=B, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + chunk_states_ptr=chunk_states, + cu_seqlens_ptr=cu_seqlens, + states_ptr=states, + initstates_ptr=initial_states, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_chunk_states_chunk=chunk_states.stride(0), + stride_chunk_states_head=chunk_states.stride(1), + stride_chunk_states_hdim=chunk_states.stride(2), + stride_chunk_states_dstate=chunk_states.stride(3), + stride_states_batch=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + 
stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + HAS_INITSTATES=initial_states is not None, + ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index d0b3e9e5235bf..ac905ada7229b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -14,67 +14,69 @@ from vllm.triton_utils import triton from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd -from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, - chunk_state_varlen) +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd -TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") def is_int_pow_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 -def _mamba_chunk_scan_combined_fwd(x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - state_dtype=None, - out=None): +def _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + return_intermediate_states=False, + seq_idx=None, + cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk_indices=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, +): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape + seqlen, nheads, headdim = x.shape + _, ngroups, dstate = B.shape assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, seqlen, nheads) - assert A.shape == (nheads, ) + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (seqlen, nheads) + assert A.shape == (nheads,) assert C.shape == B.shape if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert D.shape == (nheads, headdim) or D.shape == (nheads,) if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) + assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() - if x.stride(-1) != 1 and x.stride( - 1) != 1: # Either M or K dimension should be contiguous + if ( + x.stride(-1) != 1 and x.stride(0) != 1 + ): # Either M or K dimension should be contiguous x = x.contiguous() - if z is not None and z.stride(-1) != 1 and z.stride( - 1) != 1: # Either M or K dimension should be contiguous + if ( + z is not None and z.stride(-1) != 1 and z.stride(0) != 1 + ): # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() + assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" + if initial_states is not None: - if cu_seqlens is None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - else: - assert initial_states.shape == (len(cu_seqlens) - 1, nheads, - headdim, dstate) + assert initial_states.shape == (len(cu_seqlens) - 1, 
nheads, headdim, dstate) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -87,49 +89,42 @@ def _mamba_chunk_scan_combined_fwd(x, # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation - dA_cumsum, dt = _chunk_cumsum_fwd(dt, - A, - chunk_size, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - dt_limit=dt_limit) + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + ) # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) - states = _chunk_state_fwd(B, - x, - dt, - dA_cumsum, - seq_idx=seq_idx, - states_in_fp32=True) + states = _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True + ) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # - for handling chunked prefill, this requires i) initial_states and + # ii) seq_idx to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. - # - this will ensure that states will be updated with the rightmost flushed seq_idx - # of the previous chunk. This implies that the first chunk of states is either 0 - # or equal to init_states of the first example. - states, final_states = _state_passing_fwd( + states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, # (nheads, nchunks, chunk_size) + cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... (p n)") - if initial_states is not None else None, + if initial_states is not None + else None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, - chunk_size=chunk_size, out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None) - states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) - for t in [states, final_states]) + ) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms - CB = _bmm_chunk_fwd(C, - B, - chunk_size, - seq_idx=seq_idx, - output_dtype=torch.float32) + CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32) # 5. Scan and compute the diagonal blocks, taking into # account past causal states. @@ -141,105 +136,95 @@ def _mamba_chunk_scan_combined_fwd(x, # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from # init_states. 
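# ---------------------------------------------------------------------------
# Illustrative reference only (not part of this diff): a minimal single-head
# PyTorch sketch of the inter-chunk recurrence that step 3 implements in
# Triton, including the per-chunk seq_idx reset described above. Tensor names
# mirror the surrounding code; the helper itself is hypothetical.
# ---------------------------------------------------------------------------
import torch

def state_passing_reference(chunk_states, dA_cs_last, seq_idx, init_states=None):
    # chunk_states: (nchunks, dim)  intra-chunk states from step 2
    # dA_cs_last:   (nchunks,)      last cumsum entry of each chunk (step 1)
    # seq_idx:      (nchunks,)      which sequence each chunk belongs to
    # init_states:  (nseqs, dim) or None
    nchunks, dim = chunk_states.shape
    out = torch.empty_like(chunk_states)
    prev_seq = 0
    state = (init_states[0] if init_states is not None
             else torch.zeros(dim, dtype=chunk_states.dtype, device=chunk_states.device))
    for c in range(nchunks):
        if int(seq_idx[c]) != prev_seq:
            # a new sequence starts here: restart from its init state (or zeros)
            state = (init_states[int(seq_idx[c])] if init_states is not None
                     else torch.zeros(dim, dtype=chunk_states.dtype, device=chunk_states.device))
            prev_seq = int(seq_idx[c])
        # decay the carried state across this chunk, then add its contribution
        state = torch.exp(dA_cs_last[c]) * state + chunk_states[c]
        out[c] = state
    return out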
- out_x = _chunk_scan_fwd( + _chunk_scan_fwd( CB, x, dt, dA_cumsum, C, states, + cu_chunk_seqlens, + out, # in-place update + seq_idx, D=D, z=z, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, initial_states=initial_states, - out=out, ) - if cu_seqlens is None: - return out_x, dt, dA_cumsum, states, final_states + + if return_intermediate_states: + return states else: - assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen( - B.squeeze(0), - x.squeeze(0), - dt.squeeze(0), - dA_cumsum.squeeze(0), - cu_seqlens, - states.squeeze(0), - initial_states=initial_states, - ) - return out_x, dt, dA_cumsum, states, final_states, varlen_states + return states[last_chunk_indices] -def mamba_chunk_scan_combined(x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - out=None, - return_final_states=False, - return_varlen_states=False, - state_dtype=None): +def mamba_chunk_scan_combined_varlen( + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, +): """ Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) + x: (seqlen, nheads, headdim) + dt: (seqlen, nheads) A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) + B: (seqlen, ngroups, dstate) + C: (seqlen, ngroups, dstate) chunk_size: int + cu_seqlens: (batch + 1,) + cu_chunk_seqlens: (nchunks + 1,) + last_chunk_indices: (batch,) + seq_idx: (nchunks,) + out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) + z: (seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) - seq_idx: (batch, seqlen) - cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt - out: Preallocated output tensor + out: (seqlen, nheads, headdim) preallocated output tensor state_dtype: The data type of the ssm state + Return: + varlen_states: (batch, nheads, headdim, dstate) """ - if not return_varlen_states: - cu_seqlens = None - else: - assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input" + assert seq_idx is not None + + varlen_states = _mamba_chunk_scan_combined_fwd( x, dt, A, B, C, chunk_size, + out, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, + return_intermediate_states=return_intermediate_states, seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, - out=out, - state_dtype=state_dtype) - if not return_varlen_states: - if not return_final_states: - return - else: - return final_states - else: - varlen_states = rest[0] - return (varlen_states) if not return_final_states else (final_states, - varlen_states) + state_dtype=state_dtype, 
+ ) + + return varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a28fc9ffad71b..5481bab17e5a7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -13,122 +13,93 @@ from vllm.triton_utils import tl, triton @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE': 64}), - triton.Config({'BLOCK_SIZE': 128}), - triton.Config({'BLOCK_SIZE': 256}), - triton.Config({'BLOCK_SIZE': 512}), - triton.Config({'BLOCK_SIZE': 1024}), - triton.Config({'BLOCK_SIZE': 2048}), + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), ], - key=['dim'], + key=["dim"], ) @triton.jit def _state_passing_fwd_kernel( # Pointers to matrices states_ptr, out_ptr, - final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions - dim, + dim: tl.constexpr, nchunks, seqlen, - chunk_size, + chunk_size: tl.constexpr, # Strides - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_dim, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_initstates_batch, - stride_initstates_head, - stride_initstates_dim, - stride_seq_idx_batch, - stride_seq_idx_seqlen, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_dim: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_out_dim: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_initstates_batch: tl.int64, + stride_initstates_head: tl.int64, + stride_initstates_dim: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, # Meta-parameters HAS_INITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - IS_CONT_BATCHED: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) + pid_h = tl.program_id(axis=1) pid_m = tl.program_id(axis=0) - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head - if HAS_INITSTATES: - initstates_ptr += pid_h * stride_initstates_head - if not IS_CONT_BATCHED: - initstates_ptr += pid_b * stride_initstates_batch - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + states_ptr += pid_h * stride_states_head + dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize + out_ptr += pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim - # - states will be the past state of the sequence that continues on the current check - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + if HAS_INITSTATES: + initstates_ptrs = ( + initstates_ptr + + pid_h * stride_initstates_head + + offs_m * stride_initstates_dim + ) + + 
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: - initstates_ptr += offs_m * stride_initstates_dim - initstates_ptrs = initstates_ptr - # - for cont batches, for the first chunk mean it will be the first batch's - # init state - states = tl.load(initstates_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx = 0 for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) - if HAS_SEQ_IDX: - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) + # we have started a new sequence + if prev_seq_idx != seq_idx: if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: - # this means in the current chunk the rightmost flushed seq - # has changed. - # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch - - # - update state with seq_idx_new's init state - states = tl.load(initstates_ptrs, - mask=offs_m < dim, - other=0.0).to(tl.float32) + initstates_ptrs = ( + initstates_ptr + + seq_idx * stride_initstates_batch + + pid_h * stride_initstates_head + + offs_m * stride_initstates_dim + ) + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( + tl.float32 + ) else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + prev_seq_idx = seq_idx + states = tl.exp(dA_cs) * states + new_states + tl.store(out_ptrs, states, mask=offs_m < dim) - seq_idx = seq_idx_new - states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk dA_cs_ptr += stride_dA_cs_chunk out_ptrs += stride_out_chunk @@ -136,71 +107,51 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, + cu_chunk_seqlens, + seq_idx, initial_states=None, - seq_idx=None, - chunk_size=None, out_dtype=None, - is_cont_batched=False, ): - batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) - if initial_states is not None: - if is_cont_batched: - # - if cu_seqlens is provided, then the initial states - # are used for continuous batching. In which case we - # require seq_idx to be provided - assert seq_idx is not None, "" - else: - # - this is the regular batching case, where initial - # states are used are for each example of the batch. 
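# ---------------------------------------------------------------------------
# Illustrative only (hypothetical helper, not part of this diff): one way the
# per-chunk metadata consumed by these kernels (cu_chunk_seqlens, seq_idx,
# last_chunk_indices) could be built on the host for a packed varlen batch,
# assuming every sequence is split into chunk_size-sized pieces so that no
# chunk crosses a sequence boundary.
# ---------------------------------------------------------------------------
import torch

def build_chunk_metadata(cu_seqlens, chunk_size):
    cu_chunk_seqlens = [0]
    seq_idx = []             # one entry per chunk, not per token
    last_chunk_indices = []  # index of the final chunk of each sequence
    for s in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[s]), int(cu_seqlens[s + 1])
        for pos in range(start, end, chunk_size):
            cu_chunk_seqlens.append(min(pos + chunk_size, end))
            seq_idx.append(s)
        last_chunk_indices.append(len(seq_idx) - 1)
    return (torch.tensor(cu_chunk_seqlens, dtype=torch.int32),
            torch.tensor(seq_idx, dtype=torch.int32),
            torch.tensor(last_chunk_indices, dtype=torch.int32))

# Example: cu_seqlens = [0, 5, 12], chunk_size = 4 gives
#   cu_chunk_seqlens   = [0, 4, 5, 9, 12]
#   seq_idx            = [0, 0, 1, 1]
#   last_chunk_indices = [1, 3]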
- assert initial_states.shape == (batch, nheads, dim) - - if seq_idx is not None: - assert chunk_size is not None - seqlen = seq_idx.shape[-1] - assert seq_idx.shape == (batch, seqlen) + nchunks, nheads, dim = states.shape + chunk_size = dA_cumsum.shape[-1] + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + seqlen = seq_idx.shape[-1] out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty((batch, nchunks, nheads, dim), - device=states.device, - dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), - device=states.device, - dtype=torch.float32) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) + + initial_states_strides = ( + (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None + else (0, 0, 0) + ) + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( - states, - out, - final_states, - dA_chunk_cumsum, - initial_states, - seq_idx, - dim, - nchunks, - seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - final_states.stride(0), - final_states.stride(1), - final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2)) if initial_states is not None else - (0, 0, 0)), - *((seq_idx.stride(0), - seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + states_ptr=states, + out_ptr=out, + dA_cs_ptr=dA_cumsum, + initstates_ptr=initial_states, + seq_idx_ptr=seq_idx, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + dim=dim, + nchunks=nchunks, + seqlen=seqlen if seq_idx is not None else 0, + chunk_size=chunk_size if seq_idx is not None else 0, + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_dim=states.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_out_dim=out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_initstates_batch=initial_states_strides[0], + stride_initstates_head=initial_states_strides[1], + stride_initstates_dim=initial_states_strides[2], + stride_seq_idx_chunk=seq_idx.stride(0), HAS_INITSTATES=initial_states is not None, - HAS_SEQ_IDX=seq_idx is not None, - IS_CONT_BATCHED=is_cont_batched, ) - return out, final_states + return out diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index fead1e73e3450..32273d137eca2 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -1,41 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch -from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from 
vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.platforms import current_platform + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionMetadata) +from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata @CustomOp.register("short_conv") class ShortConv(MambaBase, CustomOp): - - def __init__(self, - config, - dim: int, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): + def __init__( + self, + config, + dim: int, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config self.layer_idx = layer_idx @@ -68,15 +74,11 @@ class ShortConv(MambaBase, CustomOp): prefix=f"{prefix}.out_proj", ) - assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1") compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine. Though this code path - # only runs for v1, we have to do this to unify with the interface - # of Attention + v0 PP. 
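# Illustrative sketch, not part of this patch and not vLLM's actual API: the
# surrounding hunks register each ShortConv instance in the static forward
# context under its `prefix` and reject duplicates, and the forward pass later
# looks up per-layer attention metadata by that same prefix. The pattern,
# reduced to its essentials with hypothetical names:
class LayerRegistrySketch:
    def __init__(self) -> None:
        self._layers: dict[str, object] = {}

    def register(self, prefix: str, layer: object) -> None:
        # Mirrors the duplicate-name guard in ShortConv.__init__ above.
        if prefix in self._layers:
            raise ValueError(f"Duplicate layer name: {prefix}")
        self._layers[prefix] = layer

    def get(self, prefix: str) -> object:
        # During forward, per-layer metadata/state is fetched by the same key.
        return self._layers[prefix]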
- self.kv_cache = [(torch.tensor([]), )] + self.kv_cache = (torch.tensor([]),) self.model_config = model_config self.cache_config = cache_config @@ -86,7 +88,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): return @@ -94,7 +95,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): torch.ops.vllm.short_conv( hidden_states, @@ -106,7 +106,6 @@ class ShortConv(MambaBase, CustomOp): self, hidden_states: torch.Tensor, output: torch.Tensor, - conv_metadata: ShortConvAttentionMetadata, ): forward_context = get_forward_context() # ShortConvAttentionMetadata contains metadata necessary for the @@ -118,19 +117,19 @@ class ShortConv(MambaBase, CustomOp): if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - conv_metadata = attn_metadata assert isinstance(attn_metadata, ShortConvAttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states + has_initial_states_p = attn_metadata.has_initial_states_p BCx, _ = self.in_proj(hidden_states) B, C, x = BCx.chunk(3, dim=-1) - conv_weights = self.conv.weight.view(self.conv.weight.size(0), - self.conv.weight.size(2)) + conv_weights = self.conv.weight.view( + self.conv.weight.size(0), self.conv.weight.size(2) + ) if attn_metadata is None: # V1 profile run @@ -171,26 +170,26 @@ class ShortConv(MambaBase, CustomOp): dim=0, ) query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) + attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes + if has_prefill + else None + ) conv_output_list = [] if has_prefill: Bx_p = (B_p * x_p).transpose(0, 1) - if conv_metadata.cu_seqlen is None: - conv_metadata = update_metadata(Bx_p, query_start_loc_p, - conv_metadata) - Bx = causal_conv1d_fn(Bx_p, - conv_weights, - self.conv.bias, - activation=None, - conv_states=conv_state, - has_initial_state=has_initial_states_p, - cache_indices=state_indices_tensor_p, - metadata=conv_metadata, - query_start_loc=query_start_loc_p).transpose( - 0, 1)[:num_prefill_tokens] + Bx = causal_conv1d_fn( + Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=attn_metadata, + query_start_loc=query_start_loc_p, + ).transpose(0, 1)[:num_prefill_tokens] y = C_p * Bx conv_output_list.append(y) @@ -203,7 +202,8 @@ class ShortConv(MambaBase, CustomOp): conv_weights, self.conv.bias, activation=None, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + ) y = C_d * Bx conv_output_list.insert(0, y) @@ -232,6 +232,11 @@ class ShortConv(MambaBase, CustomOp): def mamba_type(self) -> str: return "short_conv" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend + + return ShortConvAttentionBackend + def short_conv( hidden_states: torch.Tensor, @@ -240,9 +245,7 @@ def short_conv( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - 
conv_metadata=None) + self.forward_cuda(hidden_states=hidden_states, output=output) def short_conv_fake( @@ -258,5 +261,4 @@ direct_register_custom_op( op_func=short_conv, mutates_args=["output"], fake_impl=short_conv_fake, - dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py new file mode 100644 index 0000000000000..4b397a058dcd8 --- /dev/null +++ b/vllm/model_executor/layers/mla.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention.layer import MLAAttention +from vllm.config import CacheConfig +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig + + +@dataclass +class MLAModules: + """Modules used in MLA.""" + + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + rotary_emb: torch.nn.Module + o_proj: torch.nn.Module + fused_qkv_a_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_b_proj: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + indexer: Optional[torch.nn.Module] + is_sparse: bool + topk_indices_buffer: Optional[torch.Tensor] + + +@CustomOp.register("multi_head_latent_attention") +class MultiHeadLatentAttentionWrapper(CustomOp): + """MLA layer registered as CustomOp to allow OOT backends to add + custom implementations of the outer MLA layer (including rope & o_proj). + Note that currently MLA ignores the enable/disable mechanism of CustomOp + because there is only one in-tree implementation in forward_native. + TODO: implement this with a new PluggableLayer mechanism. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj + self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa + self.q_a_layernorm = mla_modules.q_a_layernorm + self.q_b_proj = mla_modules.q_b_proj + self.q_proj = mla_modules.q_proj + self.kv_a_layernorm = mla_modules.kv_a_layernorm + self.kv_b_proj = mla_modules.kv_b_proj + self.rotary_emb = mla_modules.rotary_emb + self.o_proj = mla_modules.o_proj + self.indexer = mla_modules.indexer + self.is_sparse = mla_modules.is_sparse + + if self.indexer is not None: + assert hasattr(self.indexer, "topk_tokens") + self.topk_tokens = self.indexer.topk_tokens + self.topk_indices_buffer = mla_modules.topk_indices_buffer + + self.mla_attn = MLAAttention( + num_heads=self.num_heads, + scale=scale, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + kv_b_proj=self.kv_b_proj, + use_sparse=self.is_sparse, + indexer=self.indexer, + ) + + self.prefix = prefix + + def forward_native( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + q_c = None + kv_lora = None + + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, ( + "fused_qkv_a_proj is required when q_lora_rank is not None" + ) + assert self.q_a_layernorm is not None, ( + "q_a_layernorm is required when q_lora_rank is not None" + ) + assert self.q_b_proj is not None, ( + "q_b_proj is required when q_lora_rank is not None" + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, ( + "kv_a_proj_with_mqa is required when q_lora_rank is None" + ) + assert self.q_proj is not None, ( + "q_proj is required when q_lora_rank is None" + ) + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) + + q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim :], k_pe + ) + + if self.indexer and self.is_sparse: + _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), + ) + return self.o_proj(attn_out)[0] + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, 
**kwargs) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d34fb58cb5cb2..979939ebc4686 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -12,26 +12,27 @@ import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.config import ModelConfig, PoolerConfig -from vllm.model_executor.pooling_metadata import ( # noqa: E501 - PoolingMetadata as V0PoolingMetadata) -from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.models.adapters import _load_st_projector from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.tasks import PoolingTask -from vllm.utils import current_stream, resolve_obj_by_qualname -from vllm.v1.pool.metadata import PoolingCursor -from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata +from vllm.utils import resolve_obj_by_qualname +from vllm.v1.outputs import PoolerOutput +from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata + +logger = init_logger(__name__) -PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] PoolingFn = Callable[ [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], - Union[torch.Tensor, list[torch.Tensor]]] + Union[torch.Tensor, list[torch.Tensor]], +] ClassifierFn = Callable[[torch.Tensor], torch.Tensor] class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" + LAST = 0 ALL = 1 CLS = 2 @@ -51,8 +52,7 @@ class ResolvedPoolingConfig: pooler_config: PoolerConfig, ) -> "ResolvedPoolingConfig": assert pooler_config.pooling_type is not None - return cls(task=task, - pooling_type=PoolingType[pooler_config.pooling_type]) + return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type]) @dataclass(frozen=True) @@ -72,8 +72,9 @@ class Pooler(nn.Module, ABC): if pooler_config.pooling_type == "STEP": return StepPooler() - resolved_config = ResolvedPoolingConfig(task="encode", - pooling_type=PoolingType.ALL) + resolved_config = ResolvedPoolingConfig( + task="encode", pooling_type=PoolingType.ALL + ) return SimplePooler.from_config(resolved_config) @@ -127,36 +128,22 @@ def get_prompt_lens( hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states[0].device).prompt_lens + return pooling_metadata.prompt_lens -def get_prompt_token_ids( - pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] +def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) ] -def 
get_pooling_params( - pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [p for _, p in pooling_metadata.seq_groups] - else: - pooling_params = pooling_metadata.pooling_params +def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: + pooling_params = pooling_metadata.pooling_params return pooling_params @@ -164,7 +151,8 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: pooling_params = get_pooling_params(pooling_metadata) tasks: list[PoolingTask] = [ - task for pooling_param in pooling_params + task + for pooling_param in pooling_params if (task := pooling_param.task) is not None ] assert len(pooling_params) == len(tasks) @@ -187,38 +175,29 @@ def get_classification_activation_function(config: PretrainedConfig): def get_cross_encoder_activation_function(config: PretrainedConfig): function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): + if ( + hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers + ): function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): + elif ( + hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None + ): function_name = config.sbert_ce_default_activation_function if function_name is not None: assert function_name.startswith("torch.nn.modules."), ( "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") + "torch.nn.modules for security reasons" + ) fn = resolve_obj_by_qualname(function_name)() return PoolerActivation.wraps(fn) - return PoolerScore() - - -def build_output( - all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: - # Pooling models D2H & synchronize occurs here - if isinstance(all_data, list): - all_data = [d.to("cpu", non_blocking=True) for d in all_data] - else: - all_data = all_data.to("cpu", non_blocking=True) - current_stream().synchronize() - - all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] - return PoolerOutput(outputs=all_outputs) + return PoolerClassify() class PoolingMethod(nn.Module, ABC): - @staticmethod def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": if pooling_type == PoolingType.LAST: @@ -257,7 +236,6 @@ class PoolingMethod(nn.Module, ABC): class CLSPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -266,14 +244,14 @@ class CLSPool(PoolingMethod): hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - assert not pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" + ) return hidden_states[pooling_cursor.first_token_indices_gpu] class LastPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -286,7 +264,6 @@ class LastPool(PoolingMethod): class AllPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode"} @@ -295,18 +272,17 @@ class AllPool(PoolingMethod): hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - - assert not 
pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with ALL pooling" + ) hidden_states_lst = list( - hidden_states.split( - pooling_cursor.num_scheduled_tokens_cpu.tolist())) + hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) + ) return [hidden_states_lst[i] for i in pooling_cursor.index] class MeanPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -315,12 +291,13 @@ class MeanPool(PoolingMethod): hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - - assert not pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" + ) - prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device, - non_blocking=True) + prompt_lens = pooling_cursor.prompt_lens_cpu.to( + hidden_states.device, non_blocking=True + ) # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. @@ -328,15 +305,15 @@ class MeanPool(PoolingMethod): start_indices = pooling_cursor.first_token_indices_gpu end_indices = pooling_cursor.last_token_indices_gpu - return (cumsum[end_indices] - cumsum[start_indices] + - hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + return ( + cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices] + ) / prompt_lens.unsqueeze(1) _T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) class BasePoolerActivation(nn.Module, ABC): - @abstractmethod def forward(self, pooled_data: _T) -> _T: # shape: @@ -347,7 +324,6 @@ class BasePoolerActivation(nn.Module, ABC): class PoolerActivation(BasePoolerActivation): - @staticmethod def wraps(module: nn.Module): if isinstance(module, nn.Identity): @@ -369,46 +345,50 @@ class PoolerActivation(BasePoolerActivation): class PoolerIdentity(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return pooled_data class PoolerNormalize(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - x = F.normalize(pooled_data.float(), p=2, dim=-1) - return x.to(pooled_data.dtype) + return F.normalize(pooled_data, p=2, dim=-1) class PoolerMultiLabelClassify(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) class PoolerClassify(PoolerActivation): + def __init__(self, *, static_num_labels: bool = True) -> None: + super().__init__() + + if static_num_labels: + vllm_config = get_current_vllm_config() + self.num_labels = getattr( + vllm_config.model_config.hf_config, "num_labels", 0 + ) + if self.num_labels == 0: + logger.warning( + "num_labels should be > 0 for classification" + "models, falling back to softmax. " + "Please check if the configuration is correct." 
+ ) + else: + self.num_labels = None def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] + num_labels = ( + self.num_labels if self.num_labels is not None else pooled_data.shape[-1] + ) + if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) - return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) - - -class PoolerScore(PoolerActivation): - - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = pooled_data.shape[-1] - if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) - - return pooled_data + return F.softmax(pooled_data, dim=-1) class LambdaPoolerActivation(PoolerActivation): - def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): super().__init__() @@ -419,40 +399,54 @@ class LambdaPoolerActivation(PoolerActivation): class PoolerHead(nn.Module): - def __init__(self, activation: PoolerActivation) -> None: super().__init__() self.activation = activation - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): - + def forward( + self, + pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): return self.activation(pooled_data) class EmbeddingPoolerHead(PoolerHead): - def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): + # Load ST projector if available + + vllm_config = get_current_vllm_config() + self.projector: Optional[nn.Module] = ( + _load_st_projector(vllm_config.model_config) if vllm_config else None + ) + self.head_dtype = vllm_config.model_config.head_dtype + + def forward( + self, + pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_dimension] + + pooled_data = pooled_data.to(self.head_dtype) + + # Apply ST projector + if self.projector is not None: + pooled_data = self.projector(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] pooling_params = get_pooling_params(pooling_metadata) - if isinstance(pooled_data, list): - pooled_data = torch.stack(pooled_data) - # pooled_data shape: [batchsize, embedding_dimension] - # for matryoshka representation - dimensions_list = [ - pooling_param.dimensions for pooling_param in pooling_params - ] + dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) - if len(set(dimensions_list)) == 1 and not isinstance( - pooled_data, list): + if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list): # if all dimensions are the same d = dimensions_list[0] pooled_data = pooled_data[..., :d] @@ -473,16 +467,27 @@ class EmbeddingPoolerHead(PoolerHead): for vecs, f in zip(pooled_data, flags) ] + # pooled_data shape: [batchsize, embedding_dimension] return pooled_data class RewardPoolerHead(PoolerHead): - def __init__(self) -> None: - super().__init__(activation=PoolerClassify()) + super().__init__(activation=PoolerClassify(static_num_labels=False)) + + vllm_config = get_current_vllm_config() + self.head_dtype = vllm_config.model_config.head_dtype + + def forward( + self, + pooled_data: 
Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): + if isinstance(pooled_data, list): + pooled_data = [p.to(self.head_dtype) for p in pooled_data] + else: + pooled_data = pooled_data.to(self.head_dtype) - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): pooling_params = get_pooling_params(pooling_metadata) # for softmax @@ -541,12 +546,13 @@ class SimplePooler(Pooler): ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata) - return build_output(pooled_data) + return pooled_data class StepPooler(Pooler): - - def __init__(self, ) -> None: + def __init__( + self, + ) -> None: super().__init__() self.pooling = AllPool() @@ -564,9 +570,9 @@ class StepPooler(Pooler): pooling_params = get_pooling_params(pooling_metadata) - for data, token_id, pooling_param in zip(pooled_data_lst, - prompt_token_ids, - pooling_params): + for data, token_id, pooling_param in zip( + pooled_data_lst, prompt_token_ids, pooling_params + ): step_tag_id = pooling_param.step_tag_id returned_token_ids = pooling_param.returned_token_ids @@ -592,7 +598,7 @@ class StepPooler(Pooler): ) -> PoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata) - return build_output(pooled_data) + return pooled_data class ClassifierPooler(Pooler): @@ -620,9 +626,15 @@ class ClassifierPooler(Pooler): ) -> None: super().__init__() + vllm_config = get_current_vllm_config() + self.pooling = pooling self.classifier = classifier self.act_fn = act_fn or PoolerClassify() + self.logit_bias: Optional[float] = ( + vllm_config.model_config.pooler_config.logit_bias + ) + self.head_dtype = vllm_config.model_config.head_dtype def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -633,19 +645,18 @@ class ClassifierPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_size] + pooled_data = pooled_data.to(self.head_dtype) + if self.classifier is not None: - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_data = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_data = self.classifier(torch.stack(pooled_data)) - else: - pooled_data = [self.classifier(data) for data in pooled_data] + pooled_data = self.classifier(pooled_data) + # pooled_data shape: [batchsize, num_labels] + + if self.logit_bias is not None: + pooled_data -= self.logit_bias pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -654,11 +665,11 @@ class ClassifierPooler(Pooler): scores = self.act_fn(pooled_data) if flags[0] else pooled_data else: scores = [ - self.act_fn(vecs) if f else vecs - for vecs, f in zip(pooled_data, flags) + self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) ] - return build_output(scores) + # scores shape: [batchsize, num_labels] + return scores class DispatchPooler(Pooler): @@ -671,7 +682,8 @@ class DispatchPooler(Pooler): if task not in pooler.get_supported_tasks(): raise ValueError( f"{pooler=} does not support {task=}. 
" - f"Supported tasks: {pooler.get_supported_tasks()}") + f"Supported tasks: {pooler.get_supported_tasks()}" + ) self.poolers_by_task = poolers_by_task @@ -688,21 +700,26 @@ class DispatchPooler(Pooler): ) -> PoolerOutput: poolers_by_task = self.poolers_by_task - outputs = list[PoolingSequenceGroupOutput]() + outputs = list[torch.Tensor]() offset = 0 for task, group in groupby(get_tasks(pooling_metadata)): if not (pooler := poolers_by_task.get(task)): raise ValueError( f"Unsupported task: {task} " - f"Supported tasks: {self.get_supported_tasks()}") + f"Supported tasks: {self.get_supported_tasks()}" + ) num_items = len(list(group)) group_output: PoolerOutput = pooler( hidden_states, - pooling_metadata[offset:offset + num_items], + pooling_metadata[offset : offset + num_items], ) - outputs.extend(group_output.outputs) + outputs.extend(group_output) offset += num_items - return PoolerOutput(outputs) + return outputs + + def extra_repr(self) -> str: + s = f"supported_task={self.get_supported_tasks()}" + return s diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index d73fcf368f261..b92fb8d266b73 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -3,8 +3,7 @@ from typing import Literal, get_args -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig QuantizationMethods = Literal[ "awq", @@ -13,6 +12,7 @@ QuantizationMethods = Literal[ "fp8", "ptpc_fp8", "fbgemm_fp8", + "fp_quant", "modelopt", "modelopt_fp4", "bitblas", @@ -26,7 +26,6 @@ QuantizationMethods = Literal[ "bitsandbytes", "hqq", "experts_int8", - "neuron_quant", "ipex", "quark", "moe_wna16", @@ -53,9 +52,13 @@ def register_quantization_config(quantization: str): quantization (str): The quantization method name. Examples: - >>> from vllm.model_executor.layers.quantization import register_quantization_config + >>> from vllm.model_executor.layers.quantization import ( + ... register_quantization_config, + ... ) >>> from vllm.model_executor.layers.quantization import get_quantization_config - >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + >>> from vllm.model_executor.layers.quantization.base_config import ( + ... QuantizationConfig, + ... ) >>> >>> @register_quantization_config("my_quant") ... class MyQuantConfig(QuantizationConfig): @@ -68,10 +71,12 @@ def register_quantization_config(quantization: str): def _wrapper(quant_config_cls): if quantization in QUANTIZATION_METHODS: raise ValueError( - f"The quantization method `{quantization}` is already exists.") + f"The quantization method `{quantization}` is already exists." + ) if not issubclass(quant_config_cls, QuantizationConfig): - raise ValueError("The quantization config must be a subclass of " - "`QuantizationConfig`.") + raise ValueError( + "The quantization config must be a subclass of `QuantizationConfig`." 
+ ) _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls QUANTIZATION_METHODS.append(quantization) return quant_config_cls @@ -91,12 +96,14 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .awq_marlin import AWQMarlinConfig from .bitblas import BitBLASConfig from .bitsandbytes import BitsAndBytesConfig - from .compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig) + from .compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) from .deepspeedfp import DeepSpeedFPConfig from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config from .fp8 import Fp8Config + from .fp_quant import FPQuantConfig from .gguf import GGUFConfig from .gptq import GPTQConfig from .gptq_bitblas import GPTQBitBLASConfig @@ -108,7 +115,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config - from .neuron_quant import NeuronQuantConfig from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig @@ -121,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, + "fp_quant": FPQuantConfig, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, "bitblas": BitBLASConfig, @@ -135,7 +142,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "ptpc_fp8": PTPCFp8Config, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, - "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index fb285413ba9ef..b7ebc6f272db5 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any, Optional, Union import torch from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -53,36 +53,45 @@ class AutoRoundConfig(QuantizationConfig): ) -> None: super().__init__() if weight_bits not in self.SUPPORTED_BITS: - raise ValueError(f"Unsupported weight_bits: {weight_bits}, " - f"currently only support {self.SUPPORTED_BITS}") + raise ValueError( + f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}" + ) if data_type not in self.SUPPORTED_DTYPES: raise ValueError( f"Unsupported data_type: {data_type}," - f" currently only support {self.SUPPORTED_DTYPES}") + f" currently only support {self.SUPPORTED_DTYPES}" + ) if packing_format not in self.SUPPORTED_FORMATS: raise ValueError( f"Unsupported packing_format: {packing_format}, " - f"currently only support {self.SUPPORTED_FORMATS}") + 
f"currently only support {self.SUPPORTED_FORMATS}" + ) if backend not in self.SUPPORTED_BACKENDS: raise ValueError( f"Unsupported backend: {backend}, " - f"currently only support {self.SUPPORTED_BACKENDS}") + f"currently only support {self.SUPPORTED_BACKENDS}" + ) self.weight_bits = weight_bits self.group_size = group_size self.sym = sym self.packing_format = packing_format - self.block_name_to_quantize = (block_name_to_quantize.split(",") if - isinstance(block_name_to_quantize, str) - else block_name_to_quantize) + self.block_name_to_quantize = ( + block_name_to_quantize.split(",") + if isinstance(block_name_to_quantize, str) + else block_name_to_quantize + ) self.extra_config = extra_config self.data_type = data_type self.backend = backend self.pack_factor = Fraction(32, weight_bits) def __repr__(self) -> str: - return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, sym={self.sym})") + return ( + f"AutoRoundConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, sym={self.sym})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -106,19 +115,18 @@ class AutoRoundConfig(QuantizationConfig): weight_bits=cls.get_from_keys(config, ["bits"]), group_size=cls.get_from_keys(config, ["group_size"]), sym=cls.get_from_keys(config, ["sym"]), - packing_format=cls.get_from_keys_or(config, ["packing_format"], - "auto_round:auto_gptq"), + packing_format=cls.get_from_keys_or( + config, ["packing_format"], "auto_round:auto_gptq" + ), block_name_to_quantize=cls.get_from_keys_or( - config, ["block_name_to_quantize", "to_quant_block_names"], - None), + config, ["block_name_to_quantize", "to_quant_block_names"], None + ), extra_config=cls.get_from_keys_or(config, ["extra_config"], None), data_type=cls.get_from_keys_or(config, ["data_type"], "int"), - backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], - "auto"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"), ) def get_layer_config(self, layer, layer_name: str): - def get_config(name: str, quantized: bool = True): cfg = self.extra_config.get(name, {}) if self.extra_config else {} return ( @@ -135,39 +143,38 @@ class AutoRoundConfig(QuantizationConfig): quantized = not isinstance(layer, ParallelLMHead) if self.block_name_to_quantize: quantized = any( - layer_name.startswith(name) - for name in self.block_name_to_quantize) + layer_name.startswith(name) for name in self.block_name_to_quantize + ) # 3. Handle fused MoE - if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower( - ): + if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(): moe_configs = [ - get_config(name, quantized) for name in self.extra_config + get_config(name, quantized) + for name in self.extra_config if name.startswith(layer_name) ] if moe_configs: if len(set(moe_configs)) == 1: return moe_configs[0] - raise ValueError(f"Fused MoE layer '{layer_name}' requires " - f"consistent quant config for all sub-layers") + raise ValueError( + f"Fused MoE layer '{layer_name}' requires " + f"consistent quant config for all sub-layers" + ) # 4. 
Handle fused QKV or other patterns if self.extra_config: for fusion_key, sub_keys in self.packed_modules_mapping.items(): - if fusion_key in layer_name and layer_name.count( - fusion_key) == 1: + if fusion_key in layer_name and layer_name.count(fusion_key) == 1: sub_names = [ - layer_name.replace(fusion_key, sub_key) - for sub_key in sub_keys - ] - sub_configs = [ - get_config(name, quantized) for name in sub_names + layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys ] + sub_configs = [get_config(name, quantized) for name in sub_names] if len(set(sub_configs)) == 1: return sub_configs[0] raise ValueError( f"Fused module '{layer_name}' requires " - f"consistent quant config for {sub_names}") + f"consistent quant config for {sub_names}" + ) # 5. Fallback return get_config(layer_name, quantized) @@ -178,14 +185,17 @@ class AutoRoundConfig(QuantizationConfig): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.block_name_to_quantize is not None: self.block_name_to_quantize = hf_to_vllm_mapper.apply_list( - self.block_name_to_quantize) + self.block_name_to_quantize + ) if self.extra_config is not None: self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config) def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, check_moe_marlin_supports_layer) + check_marlin_supported, + check_moe_marlin_supports_layer, + ) weight_bits, group_size, sym = self.get_layer_config(layer, prefix) if not self.check_quantized(weight_bits): @@ -207,19 +217,23 @@ class AutoRoundConfig(QuantizationConfig): 4: scalar_types.uint4, 8: scalar_types.uint8, } - use_marlin = (weight_bits - in AWQ_TYPE_MAP) and check_marlin_supported( - AWQ_TYPE_MAP[weight_bits], group_size, not sym) + use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported( + AWQ_TYPE_MAP[weight_bits], group_size, not sym + ) if isinstance(layer, FusedMoE): use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size) + layer, group_size + ) else: use_marlin = False if use_marlin: from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) + AWQMarlinConfig, + AWQMarlinLinearMethod, + AWQMoEMethod, + ) quant_args_marlin = AWQMarlinConfig( weight_bits=weight_bits, @@ -231,7 +245,9 @@ class AutoRoundConfig(QuantizationConfig): ) else: from vllm.model_executor.layers.quantization.awq import ( - AWQConfig, AWQLinearMethod) + AWQConfig, + AWQLinearMethod, + ) quant_args = AWQConfig( weight_bits=weight_bits, @@ -241,9 +257,8 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: - return AWQMoEMethod(quant_args_marlin, layer.moe) - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + return AWQMoEMethod(quant_args_marlin, layer.moe_config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config config = { "quant_method": "awq", @@ -252,8 +267,7 @@ class AutoRoundConfig(QuantizationConfig): "zero_point": not sym, "lm_head": False, } - return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: @@ -262,13 +276,12 @@ class AutoRoundConfig(QuantizationConfig): return AWQLinearMethod(quant_args) return None - 
def apply_gptq_quant_layer(self, - layer, - prefix: str, - backend: str = "auto"): + def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"): from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, check_moe_marlin_supports_layer) + check_marlin_supported, + check_moe_marlin_supports_layer, + ) weight_bits, group_size, sym = self.get_layer_config(layer, prefix) if not self.check_quantized(weight_bits): @@ -290,19 +303,21 @@ class AutoRoundConfig(QuantizationConfig): (4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128, } - use_marlin = (weight_bits, - sym) in GPTQ_TYPE_MAP and check_marlin_supported( - GPTQ_TYPE_MAP[(weight_bits, sym)], - group_size, - has_zp=not sym) + use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym + ) if isinstance(layer, FusedMoE): use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size) + layer, group_size + ) else: use_marlin = False if use_marlin: from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) + GPTQMarlinConfig, + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) quant_args_marlin = GPTQMarlinConfig( weight_bits=weight_bits, @@ -315,7 +330,9 @@ class AutoRoundConfig(QuantizationConfig): ) else: from vllm.model_executor.layers.quantization.gptq import ( - GPTQConfig, GPTQLinearMethod) + GPTQConfig, + GPTQLinearMethod, + ) quant_args = GPTQConfig( weight_bits=weight_bits, @@ -327,8 +344,11 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config) + else: from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + MoeWNA16Config, + ) config = { "quant_method": "gptq", @@ -338,8 +358,8 @@ class AutoRoundConfig(QuantizationConfig): "lm_head": False, } return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) + layer, prefix + ) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: @@ -357,29 +377,36 @@ class AutoRoundConfig(QuantizationConfig): else: return None from vllm.model_executor.layers.quantization.ipex_quant import ( - IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod) + IPEXAWQLinearMethod, + IPEXConfig, + IPEXGPTQLinearMethod, + ) if isinstance(layer, (LinearBase, ParallelLMHead)): if "awq" in self.packing_format: - config = IPEXConfig(method="awq", - weight_bits=weight_bits, - group_size=group_size) + config = IPEXConfig( + method="awq", weight_bits=weight_bits, group_size=group_size + ) return IPEXAWQLinearMethod(config) elif "gptq" in self.packing_format: - config = IPEXConfig(method="gptq", - weight_bits=weight_bits, - group_size=group_size) + config = IPEXConfig( + method="gptq", weight_bits=weight_bits, group_size=group_size + ) return IPEXGPTQLinearMethod(config) else: raise ValueError( f"ipex backend only supports awq " - f"and gtpq format,but got {self.packing_format}") + f"and gtpq format,but got {self.packing_format}" + ) else: return None def get_quant_method(self, layer: torch.nn.Module, prefix: str): - if (current_platform.is_cpu() or current_platform.is_xpu() - or self.backend == "ipex"): + if ( + current_platform.is_cpu() + or current_platform.is_xpu() + or self.backend == "ipex" + ): return 
self.apply_ipex_quant_layer(layer, prefix) if "gptq" in self.packing_format or "gptq" in self.backend: return self.apply_gptq_quant_layer(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index af602eb9aca38..d4f667564848c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -8,13 +8,17 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter logger = init_logger(__name__) @@ -41,14 +45,17 @@ class AWQConfig(QuantizationConfig): if self.weight_bits != 4: raise ValueError( "Currently, only 4-bit weight quantization is supported for " - f"AWQ, but got {self.weight_bits} bits.") + f"AWQ, but got {self.weight_bits} bits." + ) self.pack_factor = 32 // self.weight_bits def __repr__(self) -> str: - return (f"AWQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point}, " - f"modules_to_not_convert={self.modules_to_not_convert})") + return ( + f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) def get_name(self) -> QuantizationMethods: return "awq" @@ -75,7 +82,8 @@ class AWQConfig(QuantizationConfig): group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) + config, ["modules_to_not_convert"], None + ) return cls(weight_bits, group_size, zero_point, modules_to_not_convert) def get_quant_method( @@ -90,10 +98,12 @@ class AWQConfig(QuantizationConfig): from .awq_marlin import AWQMarlinConfig, AWQMoEMethod from .moe_wna16 import MoeWNA16Config from .utils.marlin_utils import check_moe_marlin_supports_layer + if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") + "Falling back to Moe WNA16 kernels." 
+ ) config = { "quant_method": "awq", "bits": self.weight_bits, @@ -102,7 +112,8 @@ class AWQConfig(QuantizationConfig): "lm_head": False, } return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + layer, prefix + ) marlin_compatible_config_dict = { "quant_method": "awq", "bits": self.weight_bits, @@ -112,7 +123,8 @@ class AWQConfig(QuantizationConfig): "modules_to_not_convert": self.modules_to_not_convert, } awq_marlin_config = AWQMarlinConfig.from_config( - marlin_compatible_config_dict) + marlin_compatible_config_dict + ) return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None @@ -131,11 +143,16 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size @@ -146,14 +163,16 @@ class AWQLinearMethod(LinearMethodBase): raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." 
+ ) weight_loader = extra_weight_attrs.get("weight_loader") qweight = PackedvLLMParameter( @@ -166,7 +185,8 @@ class AWQLinearMethod(LinearMethodBase): output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) num_groups = input_size_per_partition // group_size @@ -180,38 +200,40 @@ class AWQLinearMethod(LinearMethodBase): output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - scales = GroupQuantScaleParameter(data=torch.empty( - num_groups, - output_size_per_partition, - dtype=params_dtype, - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader) + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("qzeros", qzeros) layer.register_parameter("scales", scales) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.qweight = torch.nn.Parameter(layer.qweight.data, - requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, - requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, - requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) # num_tokens >= threshold @@ -221,8 +243,7 @@ class AWQLinearMethod(LinearMethodBase): out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) out = torch.matmul(reshaped_x, out) else: - out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, - pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 287d66b06d6e9..5d142387d4d9e 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from torch.nn import Parameter @@ -9,28 +9,46 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, - UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import 
(LinearBase, LinearMethodBase, - UnquantizedLinearMethod, - set_weight_attrs) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs, +) from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.awq import (AWQConfig, - is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - check_marlin_supports_layer, check_moe_marlin_supports_layer, - marlin_make_empty_g_idx, marlin_make_workspace_new, - marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales, - moe_awq_to_marlin_zero_points, verify_marlin_supported, - verify_marlin_supports_shape) + apply_awq_marlin_linear, + awq_to_marlin_zero_points, + check_marlin_supported, + check_marlin_supports_layer, + check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_moe_permute_scales, + marlin_permute_bias, + marlin_permute_scales, + moe_awq_to_marlin_zero_points, + verify_marlin_supported, + verify_marlin_supports_shape, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -46,10 +64,15 @@ class AWQMarlinConfig(QuantizationConfig): 8: scalar_types.uint8, } - def __init__(self, weight_bits: int, group_size: int, zero_point: bool, - lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], - full_config: dict[str, Any]) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: super().__init__() self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size @@ -60,21 +83,25 @@ class AWQMarlinConfig(QuantizationConfig): self.full_config = full_config if self.weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " - f"Supported num_bits = {self.TYPE_MAP.keys()}") + raise ValueError( + f"Unsupported num_bits = {self.weight_bits}. 
" + f"Supported num_bits = {self.TYPE_MAP.keys()}" + ) self.quant_type = self.TYPE_MAP[self.weight_bits] - verify_marlin_supported(self.quant_type, - group_size=self.group_size, - has_zp=self.zero_point) + verify_marlin_supported( + self.quant_type, group_size=self.group_size, has_zp=self.zero_point + ) def __repr__(self) -> str: - return (f"AWQMarlinConfig(quant_type={self.quant_type}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point}, " - f"lm_head_quantized={self.lm_head_quantized}, " - f"modules_to_not_convert={self.modules_to_not_convert})") + return ( + f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -97,37 +124,51 @@ class AWQMarlinConfig(QuantizationConfig): weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) - return cls(weight_bits, group_size, zero_point, lm_head_quantized, - modules_to_not_convert, config) + config, ["modules_to_not_convert"], None + ) + return cls( + weight_bits, + group_size, + zero_point, + lm_head_quantized, + modules_to_not_convert, + config, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "marlin" - or user_quant == "awq_marlin") + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "awq": - logger.info("Detected that the model can run with awq_marlin" - ", however you specified quantization=awq explicitly," - " so forcing awq. Use quantization=awq_marlin for" - " faster inference") + logger.info( + "Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() # Check if the layer is supported by AWQMarlin. @@ -136,21 +177,25 @@ class AWQMarlinConfig(QuantizationConfig): "Layer '%s' is not supported by AWQMarlin. 
Falling back to unoptimized AWQ kernels.", # noqa: E501 prefix, ) - return AWQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + if is_layer_skipped_awq( - prefix, getattr(self, "modules_to_not_convert", [])): + prefix, getattr(self, "modules_to_not_convert", []) + ): return UnquantizedFusedMoEMethod(layer.moe_config) if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") - return MoeWNA16Config.from_config( - self.full_config).get_quant_method(layer, prefix) + "Falling back to Moe WNA16 kernels." + ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) return AWQMoEMethod(self, layer.moe_config) return None @@ -169,15 +214,15 @@ class AWQMarlinConfig(QuantizationConfig): return False # If we cannot find the info needed in the config, cannot convert. - if (num_bits is None or group_size is None or zero_point is None): + if num_bits is None or group_size is None or zero_point is None: return False if num_bits not in cls.TYPE_MAP: return False - return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], - group_size=group_size, - has_zp=zero_point) + return check_marlin_supported( + quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point + ) class AWQMarlinLinearMethod(LinearMethodBase): @@ -214,7 +259,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) qweight = PackedvLLMParameter( data=torch.empty( @@ -226,7 +272,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) num_groups = input_size_per_partition // group_size @@ -240,16 +287,19 @@ class AWQMarlinLinearMethod(LinearMethodBase): output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - scales = GroupQuantScaleParameter(data=torch.empty( - num_groups, - output_size_per_partition, - dtype=params_dtype, - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader) + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("qzeros", qzeros) @@ -265,12 +315,9 @@ class AWQMarlinLinearMethod(LinearMethodBase): # Here, we handle the repacking def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.qweight.device - layer.qweight = torch.nn.Parameter(layer.qweight.data, - requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, - requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, - requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, 
requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) # Allocate marlin workspace layer.workspace = marlin_make_workspace_new(device) @@ -280,7 +327,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. @@ -288,7 +336,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.scales, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. @@ -296,7 +345,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.qzeros, size_k=layer.num_groups, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_parameter(layer, "qzeros", marlin_zp) # Not-used @@ -323,11 +373,11 @@ class AWQMarlinLinearMethod(LinearMethodBase): quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) class AWQMoEMethod(FusedMoEMethodBase): - def __init__( self, quant_config: AWQMarlinConfig, @@ -339,75 +389,93 @@ class AWQMoEMethod(FusedMoEMethodBase): raise ValueError("AWQMoEMethod only supports 4bit now.") self.quant_type = scalar_types.uint4 - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - extra_weight_attrs.update({ - "is_transposed": - True, - "quant_method": - FusedMoeWeightScaleSupported.GROUP.value, - }) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + extra_weight_attrs.update( + { + "is_transposed": True, + "quant_method": FusedMoeWeightScaleSupported.GROUP.value, + } + ) w13_qweight = Parameter( - torch.empty(num_experts, - hidden_size, - 2 * intermediate_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) - w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size_per_partition, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w2_qweight = Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) num_groups_w13 = hidden_size // self.quant_config.group_size - num_groups_w2 = (intermediate_size_per_partition // - self.quant_config.group_size) + num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 
respectively. - w13_scales = Parameter(torch.empty(num_experts, - num_groups_w13, - intermediate_size_per_partition * 2, - dtype=params_dtype), - requires_grad=False) + w13_scales = Parameter( + torch.empty( + num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scales = Parameter( + torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) # WEIGHT_ZERO_POINT # Allocate 2 zero points for w1 and w3 respectively. w13_qzeros = Parameter( - torch.empty(num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) - w2_qzeros = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w2_qzeros = Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) @@ -467,14 +535,16 @@ class AWQMoEMethod(FusedMoEMethodBase): layer.w13_qzeros, size_k=layer.w13_qzeros.shape[1], size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.weight_bits, + ) replace_parameter(layer, "w13_qzeros", marlin_w13_zp) marlin_w2_zp = moe_awq_to_marlin_zero_points( layer.w2_qzeros, size_k=layer.w2_qzeros.shape[1], size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.weight_bits, + ) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) if hasattr(layer, "w13_bias") and layer.w13_bias is not None: @@ -483,6 +553,11 @@ class AWQMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -497,6 +572,7 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -504,16 +580,15 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `AWQMoEMethod` yet.") 
+ raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.") assert activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -523,8 +598,10 @@ class AWQMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -543,4 +620,5 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, - workspace=layer.workspace) + workspace=layer.workspace, + ) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index ebc526d6db2f9..67b4dbbfd4d82 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -10,16 +10,17 @@ AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] @triton.jit def awq_dequantize_kernel( - qweight_ptr, # quantized matrix - scales_ptr, # scales, per group - zeros_ptr, # zeros, per group - group_size, # Should always be one of the supported group sizes - result_ptr, # Output matrix - num_cols, # input num cols in qweight - num_rows, # input num rows in qweight - BLOCK_SIZE_X: tl.constexpr, - BLOCK_SIZE_Y: tl.constexpr): - # Setup the pids. + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr, +): + # Set up the pids. pid_x = tl.program_id(axis=0) pid_y = tl.program_id(axis=1) @@ -35,10 +36,10 @@ def awq_dequantize_kernel( # Compute offsets and masks for result output ptr. result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) - result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( - 0, BLOCK_SIZE_X * 8) - result_offsets = (8 * num_cols * result_offsets_y[:, None] + - result_offsets_x[None, :]) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + result_offsets = ( + 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] + ) result_masks_y = result_offsets_y < num_rows result_masks_x = result_offsets_x < num_cols * 8 @@ -52,8 +53,9 @@ def awq_dequantize_kernel( # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. - reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + - tl.arange(0, 4)[:, None]).reshape(8) + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) # Use this to compute a set of shifts that can be used to unpack and # reorder the values in iweights and zeros. @@ -85,10 +87,8 @@ def awq_dequantize_kernel( # Compute scale offsets and masks. 
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) - scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + - tl.arange(0, BLOCK_SIZE_X * 8)) - scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + - scale_offsets_x[None, :]) + scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] scale_masks_y = scale_offsets_y < num_rows // group_size scale_masks_x = scale_offsets_x < num_cols * 8 scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] @@ -106,10 +106,21 @@ def awq_dequantize_kernel( @triton.jit -def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, - group_size, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - SPLIT_K: tl.constexpr): +def awq_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + zeros_ptr, + scales_ptr, + M, + N, + K, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) @@ -128,18 +139,17 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # (BLOCK_SIZE_M, BLOCK_SIZE_N)) # accumulator = accumulator & 0x0 # accumulator = accumulator.to(accumulator_dtype) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. - reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + - tl.arange(0, 4)[:, None]).reshape(8) + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) # Create the necessary shifts to use to unpack. shifts = reverse_awq_order_tensor * 4 - shifts = tl.broadcast_to(shifts[None, :], - (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) # Offsets and masks. @@ -178,8 +188,8 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # Dequantize b. 
offsets_szk = ( - (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + - tl.arange(0, 1)) + BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K + ) // group_size + tl.arange(0, 1) offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] @@ -220,11 +230,13 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # qweights - [K , M // 8], int32 # scales - [K // G, M ], float16 # zeros - [K // G, M // 8], int32 -def awq_dequantize_triton(qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - block_size_x: int = 32, - block_size_y: int = 32) -> torch.Tensor: +def awq_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32, +) -> torch.Tensor: K = qweight.shape[0] M = scales.shape[1] group_size = qweight.shape[0] // scales.shape[0] @@ -238,27 +250,31 @@ def awq_dequantize_triton(qweight: torch.Tensor, # Result tensor: # number of rows = same as input tensor # number of cols = 8 x input tensor num cols - result = torch.empty(qweight.shape[0], - qweight.shape[1] * 8, - device=qweight.device, - dtype=scales.dtype) + result = torch.empty( + qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype, + ) Y = qweight.shape[0] # num rows X = qweight.shape[1] # num cols grid = lambda META: ( - triton.cdiv(X, META['BLOCK_SIZE_X']), - triton.cdiv(Y, META['BLOCK_SIZE_Y']), + triton.cdiv(X, META["BLOCK_SIZE_X"]), + triton.cdiv(Y, META["BLOCK_SIZE_Y"]), + ) + awq_dequantize_kernel[grid]( + qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y, ) - awq_dequantize_kernel[grid](qweight, - scales, - zeros, - group_size, - result, - X, - Y, - BLOCK_SIZE_X=block_size_x, - BLOCK_SIZE_Y=block_size_y) return result @@ -268,14 +284,16 @@ def awq_dequantize_triton(qweight: torch.Tensor, # qzeros - [K // G, N // 8] # scales - [K // G, N] # split_k_iters - parallelism along K-dimension, int, power of 2. 
-def awq_gemm_triton(input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: int, - block_size_m: int = 32, - block_size_n: int = 32, - block_size_k: int = 32) -> torch.Tensor: +def awq_gemm_triton( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, +) -> torch.Tensor: M, K = input.shape N = qweight.shape[1] * 8 group_size = qweight.shape[0] // qzeros.shape[0] @@ -290,30 +308,29 @@ def awq_gemm_triton(input: torch.Tensor, assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - N, META['BLOCK_SIZE_N']), + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), split_k_iters, ) - result = torch.zeros((split_k_iters, M, N), - dtype=scales.dtype, - device=input.device) + result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device) # A = input, B = qweight, C = result # A = M x K, B = K x N, C = M x N - awq_gemm_kernel[grid](input, - qweight, - result, - qzeros, - scales, - M, - N, - K, - group_size, - BLOCK_SIZE_M=block_size_m, - BLOCK_SIZE_N=block_size_n, - BLOCK_SIZE_K=block_size_k, - SPLIT_K=split_k_iters) + awq_gemm_kernel[grid]( + input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters, + ) result = result.sum(0) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 4a43351260e9f..26f5e8bb6c7df 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -19,8 +19,9 @@ class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, *weight_args, - **extra_weight_attrs): + def create_weights( + self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs + ): """Create weights for a layer. The weights will be set as attributes of the layer.""" @@ -34,8 +35,7 @@ class QuantizeMethodBase(ABC): raise NotImplementedError # Not required functions - def embedding(self, layer: torch.nn.Module, *args, - **kwargs) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: """Gather embeddings in the layer based on indices in the input tensor. Expects create_weights to have been called before on the layer.""" @@ -49,19 +49,16 @@ class QuantizeMethodBase(ABC): return -def method_has_implemented_embedding( - method_class: type[QuantizeMethodBase]) -> bool: +def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool: """ Not all quant methods have embedding implemented, so we need to check that it exists for our given method. We check this by making sure the function has been changed from the base implementation. 
""" - base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", - None) + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None) class_embedding = inspect.getattr_static(method_class, "embedding", None) - return (class_embedding is not None - and class_embedding is not base_embedding) + return class_embedding is not None and class_embedding is not base_embedding class QuantizationConfig(ABC): @@ -107,12 +104,13 @@ class QuantizationConfig(ABC): @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """ - Detects if this quantization method can support a given checkpoint - format by overriding the user specified quantization method -- - this method should only be overwritten by subclasses in exceptional - circumstances + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances """ return None @@ -122,23 +120,24 @@ class QuantizationConfig(ABC): for key in keys: if key in config: return config[key] - raise ValueError(f"Cannot find any of {keys} in the model's " - "quantization config.") + raise ValueError( + f"Cannot find any of {keys} in the model's quantization config." + ) @staticmethod - def get_from_keys_or(config: dict[str, Any], keys: list[str], - default: Any) -> Any: - """Get a optional value from the model's quantization config.""" + def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: + """Get an optional value from the model's quantization config.""" try: return QuantizationConfig.get_from_keys(config, keys) except ValueError: return default @abstractmethod - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional[QuantizeMethodBase]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. - + Args: layer: The layer for the quant method. prefix: The full name of the layer in the state dict @@ -152,7 +151,8 @@ class QuantizationConfig(ABC): return None def apply_vllm_mapper( # noqa: B027 - self, hf_to_vllm_mapper: "WeightsMapper"): + self, hf_to_vllm_mapper: "WeightsMapper" + ): """ Interface for models to update module names referenced in quantization configs in order to reflect the vllm model structure @@ -162,3 +162,9 @@ class QuantizationConfig(ABC): """ # TODO (@kylesayrs): add implementations for all subclasses pass + + def maybe_update_config(self, model_name: str): # noqa: B027 + """ + Interface to update values after config initialization. 
+ """ + pass diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 39bd34d351f61..d2e0582be197e 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -7,17 +7,23 @@ from packaging import version from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, - BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) + BITBLAS_OPTIMIZE_FEATURES, + BITBLAS_SUPPORTED_NUM_BITS, + BITBLAS_SUPPORTED_SYM, + MINIMUM_BITBLAS_VERSION, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -28,6 +34,7 @@ class BitBLASConfig(QuantizationConfig): Reference: https://github.com/Microsoft/BitBLAS """ + TORCH_DTYPE = torch.float16 STORAGE_DTYPE = "int8" # assume int8 storage TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) @@ -46,11 +53,14 @@ class BitBLASConfig(QuantizationConfig): ) -> None: try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e raise ValueError( @@ -78,12 +88,14 @@ class BitBLASConfig(QuantizationConfig): raise ValueError( f"BitBLAS does not support weight_bits = {self.weight_bits}. " f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " - "are supported.") + "are supported." + ) if self.is_sym not in BITBLAS_SUPPORTED_SYM: raise ValueError( f"BitBLAS does not support is_sym = {self.is_sym}. " - f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported." 
+ ) storage_dtype = self.STORAGE_DTYPE storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) @@ -98,11 +110,13 @@ class BitBLASConfig(QuantizationConfig): self.zeros_mode = self.ZEROS_MODE def __repr__(self) -> str: - return (f"BitBLASConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}, " - f"is_sym={self.is_sym}, " - f"quant_method={self.quant_method})") + return ( + f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -122,9 +136,9 @@ class BitBLASConfig(QuantizationConfig): return ["quantize_config.json"] @staticmethod - def get_from_keys(config: dict[str, Any], - keys: list[str], - default: Any = None) -> Any: + def get_from_keys( + config: dict[str, Any], keys: list[str], default: Any = None + ) -> Any: """Get a value from the model's quantization config.""" for key in keys: if key in config: @@ -138,34 +152,40 @@ class BitBLASConfig(QuantizationConfig): desc_act = cls.get_from_keys(config, ["desc_act"], False) is_sym = cls.get_from_keys(config, ["sym"], False) quant_method = cls.get_from_keys(config, ["quant_method"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, quant_method, - lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls( + weight_bits, group_size, desc_act, is_sym, quant_method, lm_head_quantized + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_bitblas_format: bool - is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" - or hf_quant_cfg.get("is_bitblas_format", False)) + is_bitblas_format = hf_quant_cfg.get( + "checkpoint_format" + ) == "bitblas" or hf_quant_cfg.get("is_bitblas_format", False) - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "bitblas") + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "bitblas" + ) if is_bitblas_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". - format(cls.get_name(), cls.get_name())) + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) logger.info(msg) return cls.get_name() return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["BitBLASLinearMethod"]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): return BitBLASLinearMethod(self) return None @@ -176,6 +196,7 @@ class BitBLASLinearMethod(LinearMethodBase): Args: quant_config: The BitBLAS quantization config. 
""" + # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS # Instead of BITBLAS_OPTIMIZE_FEATURES # If you want to high contiguous batching @@ -202,45 +223,47 @@ class BitBLASLinearMethod(LinearMethodBase): output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, - ): + ) -> None: """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing quantized + The function initializes and returns a dictionary containing quantized weights, scales, and zeros for performing quantized matrix multiplication operations. Args: input_size_per_partition: The size of the input partition. - output_size_per_partition: The size of the output partition. + output_partition_sizes: List of output partition sizes. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: + params_dtype: The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the - input size per partition is not divisible by the group size in - `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. """ del input_size, output_size # Unused arguments. weight_loader = extra_weight_attrs["weight_loader"] if params_dtype not in self.quant_config.get_supported_act_dtypes(): - raise ValueError("Parameter data type must be torch.float16, " - f"but got {params_dtype}") + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) group_size = self.quant_config.group_size if group_size is None: group_size = -1 # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) - if (group_size != -1 and input_size_per_partition % group_size != 0): + if group_size != -1 and input_size_per_partition % group_size != 0: raise ValueError( f"Input size per partition ({input_size_per_partition}) must " - f"be divisible by group size ({group_size}).") + f"be divisible by group size ({group_size})." + ) # Initialize or retrieve the BitBLAS matrix multiplication operator. self._configure_bitblas_matmul( @@ -266,34 +289,33 @@ class BitBLASLinearMethod(LinearMethodBase): output_dim=0, packed_dim=1, packed_factor=self.quant_config.pack_factor, - bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2] - if self.bitblas_matmul.propagate_b else None), + bitblas_tile_size=( + self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b + else None + ), weight_loader=weight_loader, ) # Compute the number of input groups for channel-wise quantization. - input_groups = (1 if group_size == -1 else input_size_per_partition // - group_size) + input_groups = 1 if group_size == -1 else input_size_per_partition // group_size # Initialize scales and zeros for the quantized weights. 
weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( output_size_per_partition, input_groups, device="cuda", dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) else: - scales = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) if self.quant_config.zeros_mode == "quantized": zeros = PackedvLLMParameter( @@ -313,17 +335,22 @@ class BitBLASLinearMethod(LinearMethodBase): else: zeros = BasevLLMParameter( - torch.empty(output_size_per_partition, - input_groups, - device="cuda", - dtype=params_dtype), + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), weight_loader=weight_loader, ) # Set attributes to indicate how scales and zeros are applied. - set_weight_attrs(zeros, { - "input_dim": None if input_groups == 1 else 1, - "output_dim": 0, - }) + set_weight_attrs( + zeros, + { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0, + }, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("scales", scales) @@ -340,13 +367,19 @@ class BitBLASLinearMethod(LinearMethodBase): **extra_weight_attrs, ): if self.quant_config.quant_method == "gptq": - return self.create_weights_gptq(layer, input_size_per_partition, - output_partition_sizes, input_size, - output_size, params_dtype, - **extra_weight_attrs) + return self.create_weights_gptq( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) def _configure_bitblas_matmul( self, @@ -360,6 +393,7 @@ class BitBLASLinearMethod(LinearMethodBase): out_dtype="float16", ): from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] with_scaling = False @@ -375,7 +409,8 @@ class BitBLASLinearMethod(LinearMethodBase): W_dtype = f"int{bits}" else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) matmul_config = MatmulConfig( N=outfeatures, @@ -393,38 +428,40 @@ class BitBLASLinearMethod(LinearMethodBase): zeros_mode=zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning) + matmul_config, enable_tuning + ) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul, auto_detect_nvidia_target from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: - global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, - BITBLAS_TARGET) + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, - target=BITBLAS_TARGET, - enable_tuning=False) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) if enable_tuning: - TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...") + TUNING_MESSAGE = f"BitBLAS Operator 
{config} is tuning ..." logger.info(TUNING_MESSAGE) bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) TUNED_MESSAGE = ( - f"BitBLAS Operator {config} tuned and saved to database.") + f"BitBLAS Operator {config} tuned and saved to database." + ) logger.info(TUNED_MESSAGE) else: _message = f"BitBLAS Operator {config} created." logger.info(_message) else: - _message = ( - f"BitBLAS Operator {config} found in global_operator_cache.") + _message = f"BitBLAS Operator {config} found in global_operator_cache." logger.info(_message) return bitblas_matmul @@ -445,7 +482,7 @@ class BitBLASLinearMethod(LinearMethodBase): else: output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add @@ -461,4 +498,5 @@ class BitBLASLinearMethod(LinearMethodBase): return self.apply_gptq(*args, **kwargs) else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index b7897a43793c7..80ed121bd85b8 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,15 +6,21 @@ from typing import Any, Callable, Optional, Union import torch from packaging import version -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod, - set_weight_attrs) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -51,16 +57,19 @@ class BitsAndBytesConfig(QuantizationConfig): self.llm_int8_threshold = llm_int8_threshold if self.bnb_4bit_quant_storage not in ["uint8"]: - raise ValueError("Unsupported bnb_4bit_quant_storage: " - f"{self.bnb_4bit_quant_storage}") + raise ValueError( + f"Unsupported bnb_4bit_quant_storage: {self.bnb_4bit_quant_storage}" + ) def __repr__(self) -> str: - return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " - f"load_in_4bit={self.load_in_4bit}, " - f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " - f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " - f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " - f"llm_int8_skip_modules={self.llm_int8_skip_modules})") + return ( + f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " + f"load_in_4bit={self.load_in_4bit}, " + f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + 
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " + f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " + f"llm_int8_skip_modules={self.llm_int8_skip_modules})" + ) @classmethod def get_name(self) -> QuantizationMethods: @@ -80,7 +89,6 @@ class BitsAndBytesConfig(QuantizationConfig): @classmethod def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig": - def get_safe_value(config, keys, default_value=None): try: value = cls.get_from_keys(config, keys) @@ -88,30 +96,32 @@ class BitsAndBytesConfig(QuantizationConfig): except ValueError: return default_value - load_in_8bit = get_safe_value(config, ["load_in_8bit"], - default_value=False) - load_in_4bit = get_safe_value(config, ["load_in_4bit"], - default_value=True) - bnb_4bit_compute_dtype = get_safe_value(config, - ["bnb_4bit_compute_dtype"], - default_value="float32") - bnb_4bit_quant_storage = get_safe_value(config, - ["bnb_4bit_quant_storage"], - default_value="uint8") - bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], - default_value="fp4") + load_in_8bit = get_safe_value(config, ["load_in_8bit"], default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], default_value=True) + bnb_4bit_compute_dtype = get_safe_value( + config, ["bnb_4bit_compute_dtype"], default_value="float32" + ) + bnb_4bit_quant_storage = get_safe_value( + config, ["bnb_4bit_quant_storage"], default_value="uint8" + ) + bnb_4bit_quant_type = get_safe_value( + config, ["bnb_4bit_quant_type"], default_value="fp4" + ) bnb_4bit_use_double_quant = get_safe_value( - config, ["bnb_4bit_use_double_quant"], default_value=False) + config, ["bnb_4bit_use_double_quant"], default_value=False + ) llm_int8_enable_fp32_cpu_offload = get_safe_value( - config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) - llm_int8_has_fp16_weight = get_safe_value(config, - ["llm_int8_has_fp16_weight"], - default_value=False) - llm_int8_skip_modules = get_safe_value(config, - ["llm_int8_skip_modules"], - default_value=[]) - llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], - default_value=6.0) + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False + ) + llm_int8_has_fp16_weight = get_safe_value( + config, ["llm_int8_has_fp16_weight"], default_value=False + ) + llm_int8_skip_modules = get_safe_value( + config, ["llm_int8_skip_modules"], default_value=[] + ) + llm_int8_threshold = get_safe_value( + config, ["llm_int8_threshold"], default_value=6.0 + ) return cls( load_in_8bit=load_in_8bit, @@ -123,7 +133,8 @@ class BitsAndBytesConfig(QuantizationConfig): llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, llm_int8_skip_modules=llm_int8_skip_modules, - llm_int8_threshold=llm_int8_threshold) + llm_int8_threshold=llm_int8_threshold, + ) def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -139,15 +150,15 @@ class BitsAndBytesConfig(QuantizationConfig): def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): # Split the prefix into its dot-separated components - components = prefix.split('.') + components = prefix.split(".") # Check if any of the skip modules exactly matches any component - substr_check = any(module_name in components - for module_name in llm_int8_skip_modules) + substr_check = any( + module_name in components for module_name in llm_int8_skip_modules + ) # Allow certain layers to not be quantized - set_components = set(".".join(components[:i + 1]) - for i in range(len(components))) + 
set_components = set(".".join(components[: i + 1]) for i in range(len(components))) set_llm_int8_skip_modules = set(llm_int8_skip_modules) prefix_check = len(set_llm_int8_skip_modules & set_components) != 0 @@ -171,39 +182,53 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." + ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." + ) from err self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): from bitsandbytes.nn import Int8Params def create_qweight_for_8bit(): qweight = Int8Params( - data=torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8, + ), has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, - requires_grad=False) + requires_grad=False, + ) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 0, "output_dim": 0, "pack_factor": 1, "use_bitsandbytes_8bit": True, - "generation": 0 - }) + "generation": 0, + }, + ) return qweight def create_qweight_for_4bit(): @@ -212,20 +237,22 @@ class BitsAndBytesLinearMethod(LinearMethodBase): total_size = input_size_per_partition * sum(output_partition_sizes) if total_size % quant_ratio != 0: raise ValueError( - "The input size is not aligned with the quantized " - "weight shape.") + "The input size is not aligned with the quantized weight shape." 
+ ) - qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, - 1, - dtype=torch.uint8), - requires_grad=False) + qweight = torch.nn.Parameter( + torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8), + requires_grad=False, + ) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 0, "output_dim": 0, "pack_factor": quant_ratio, - "use_bitsandbytes_4bit": True - }) + "use_bitsandbytes_4bit": True, + }, + ) return qweight if self.quant_config.load_in_8bit: @@ -237,22 +264,23 @@ class BitsAndBytesLinearMethod(LinearMethodBase): layer.register_parameter("weight", qweight) set_weight_attrs(qweight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.quant_config.load_in_8bit: return self._apply_8bit_weight(layer, x, bias) else: return self._apply_4bit_weight(layer, x, bias) def _apply_8bit_weight( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # only load the bitsandbytes module when needed from bitsandbytes import MatmulLtState, matmul @@ -272,11 +300,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase): out_dim_0 = x.shape[0] out_dim_1 = sum( - [quant_state[1].shape[0] for quant_state in quant_states.items()]) - out = torch.empty(out_dim_0, - out_dim_1, - dtype=torch.float16, - device=x.device) + [quant_state[1].shape[0] for quant_state in quant_states.items()] + ) + out = torch.empty(out_dim_0, out_dim_1, dtype=torch.float16, device=x.device) current_index = 0 for i in range(len(quant_states)): @@ -286,33 +312,36 @@ class BitsAndBytesLinearMethod(LinearMethodBase): # create new matmul_states if generation == 0 or generation == 1: matmul_states[i] = MatmulLtState() - matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].CB = qweight[offsets[i] : offsets[i + 1]] matmul_states[i].SCB = quant_states[i].to(x.device) - matmul_states[i].threshold = ( - self.quant_config.llm_int8_threshold) - matmul_states[i].has_fp16_weights = ( - self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].threshold = self.quant_config.llm_int8_threshold + matmul_states[ + i + ].has_fp16_weights = self.quant_config.llm_int8_has_fp16_weight matmul_states[i].is_training = False - if matmul_states[i].threshold > 0.0 and not matmul_states[ - i].has_fp16_weights: + if ( + matmul_states[i].threshold > 0.0 + and not matmul_states[i].has_fp16_weights + ): matmul_states[i].use_pool = True new_x = bf_x.unsqueeze(0) - out[:, current_index:current_index + output_size] = matmul( - new_x, - qweight[offsets[i]:offsets[i + 1]], - state=matmul_states[i]) + out[:, current_index : current_index + output_size] = matmul( + new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i] + ) current_index += output_size # only update the matmul_states if it is not profile_run - if (generation > 0 - and not self.quant_config.llm_int8_has_fp16_weight - and matmul_states[i].CB is not None - and matmul_states[i].CxB is not None): + if ( + generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None + ): del matmul_states[i].CB - qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + qweight[offsets[i] : 
offsets[i + 1]] = matmul_states[i].CxB out = out.to(original_type) @@ -327,11 +356,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase): return out def _apply_4bit_weight( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: original_type = x.dtype original_shape = x.shape reshape_after_matmul = False @@ -346,11 +375,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase): out_dim_0 = x.shape[0] out_dim_1 = sum( - [quant_state[1].shape[0] for quant_state in quant_states.items()]) - out = torch.empty(out_dim_0, - out_dim_1, - dtype=torch.bfloat16, - device=x.device) + [quant_state[1].shape[0] for quant_state in quant_states.items()] + ) + out = torch.empty(out_dim_0, out_dim_1, dtype=torch.bfloat16, device=x.device) apply_bnb_4bit(bf_x, qweight, offsets, out) out = out.to(original_type) @@ -371,6 +398,7 @@ def _apply_bnb_4bit( ) -> None: # only load the bitsandbytes module when needed from bitsandbytes import matmul_4bit + quant_states = weight.bnb_quant_state current_index = 0 for i in range(len(quant_states)): @@ -379,8 +407,9 @@ def _apply_bnb_4bit( # matmul_4bit(..., out = ...). Infeasible now due to the bug # https://github.com/TimDettmers/bitsandbytes/issues/1235. # Need to change after the bug is fixed. - out[:, current_index:current_index + output_size] = matmul_4bit( - x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + out[:, current_index : current_index + output_size] = matmul_4bit( + x, weight[offsets[i] : offsets[i + 1]].t(), quant_states[i] + ) current_index += output_size @@ -394,11 +423,13 @@ def _apply_bnb_4bit_fake( try: - direct_register_custom_op(op_name="apply_bnb_4bit", - op_func=_apply_bnb_4bit, - mutates_args=["out"], - fake_impl=_apply_bnb_4bit_fake, - dispatch_key=current_platform.dispatch_key) + direct_register_custom_op( + op_name="apply_bnb_4bit", + op_func=_apply_bnb_4bit, + mutates_args=["out"], + fake_impl=_apply_bnb_4bit_fake, + dispatch_key=current_platform.dispatch_key, + ) apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit except AttributeError as error: @@ -420,14 +451,18 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): super().__init__(moe) try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." + ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." 
+ ) from err self.quant_config = quant_config def create_weights( @@ -452,6 +487,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): **extra_weight_attrs, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -466,6 +506,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -473,14 +514,16 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm.model_executor.layers.fused_moe import fused_experts + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `BitsAndBytesMoEMethod` yet.") - topk_weights, topk_ids = FusedMoE.select_experts( + "EPLB not supported for `BitsAndBytesMoEMethod` yet." + ) + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -490,8 +533,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) if self.quant_config.load_in_8bit: w13, w2 = self._apply_8bit_dequant(layer) else: @@ -507,6 +552,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + quant_config=self.moe_quant_config, ) def _create_weights_4bit( @@ -520,8 +566,9 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ): quant_ratio = calculate_quant_ratio(params_dtype) # Fused gate_up_proj (column parallel) - w13_total_size = (hidden_size * 2 * - intermediate_size_per_partition) // quant_ratio + w13_total_size = ( + hidden_size * 2 * intermediate_size_per_partition + ) // quant_ratio w13_qweight = torch.nn.Parameter( torch.empty( num_experts, @@ -536,26 +583,20 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): set_weight_attrs( w13_qweight, { - "num_experts": - num_experts, - "input_dim": - hidden_size, - "output_dim": - 2 * intermediate_size_per_partition, + "num_experts": num_experts, + "input_dim": hidden_size, + "output_dim": 2 * intermediate_size_per_partition, "experts_shape": ( num_experts, intermediate_size_per_partition * 2, hidden_size, ), - "pack_factor": - quant_ratio, - "use_bitsandbytes_4bit": - True, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True, }, ) # down_proj (row parallel) - w2_total_size = (hidden_size * - intermediate_size_per_partition) // quant_ratio + w2_total_size = (hidden_size * intermediate_size_per_partition) // quant_ratio w2_qweight = torch.nn.Parameter( torch.empty( num_experts, @@ -568,21 +609,16 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): set_weight_attrs( w2_qweight, { - "num_experts": - num_experts, - "input_dim": - intermediate_size_per_partition, - "output_dim": - 
hidden_size, + "num_experts": num_experts, + "input_dim": intermediate_size_per_partition, + "output_dim": hidden_size, "experts_shape": ( num_experts, hidden_size, intermediate_size_per_partition, ), - "pack_factor": - quant_ratio, - "use_bitsandbytes_4bit": - True, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True, }, ) layer.register_parameter("w2_weight", w2_qweight) @@ -600,8 +636,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): raise NotImplementedError def _apply_4bit_dequnt( - self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + self, layer: torch.nn.Module + ) -> tuple[torch.Tensor, torch.Tensor]: from bitsandbytes.functional import dequantize_4bit + w13 = dequantize_4bit( layer.w13_weight.reshape(-1, 1), layer.w13_weight.bnb_quant_state, @@ -615,5 +653,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): return w13, w2 def _apply_8bit_dequant( - self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + self, layer: torch.nn.Module + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 637a84372990a..e89d002078ac1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,37 +5,62 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, Literal, Optional, cast import torch -from compressed_tensors.config import (CompressionFormat, - SparsityCompressionConfig, - SparsityStructure) -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) -from pydantic import BaseModel +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) +from compressed_tensors.transform import TransformConfig import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 - CompressedTensorsMoEMethod) + CompressedTensorsMoEMethod, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, - CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + W4A16SPARSE24_SUPPORTED_BITS, + WNA16_SUPPORTED_BITS, + CompressedTensors24, + CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, + CompressedTensorsW4A8Fp8, + CompressedTensorsW4A8Int, + 
CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, +) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod, + get_linear_transform_schemes, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - find_matched_target, is_activation_quantization_format, - should_ignore_layer) + find_matched_target, + is_activation_quantization_format, + should_ignore_layer, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.platforms import current_platform if TYPE_CHECKING: @@ -50,7 +75,6 @@ QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] class CompressedTensorsConfig(QuantizationConfig): - def __init__( self, target_scheme_map: dict[str, Any], @@ -60,6 +84,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: list[str], kv_cache_scheme: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None, + transform_config: Optional[dict[str, Any]] = None, ): super().__init__() self.ignore = ignore @@ -71,6 +96,11 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_ignore_list = sparsity_ignore_list self.config = config + if transform_config: + self.transform_config = TransformConfig.model_validate(transform_config) + else: + self.transform_config = None + def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -85,16 +115,16 @@ class CompressedTensorsConfig(QuantizationConfig): return "compressed-tensors" def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - self.target_scheme_map = hf_to_vllm_mapper.apply_dict( - self.target_scheme_map) + self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map) self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( - self.sparsity_scheme_map) + self.sparsity_scheme_map + ) self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( - self.sparsity_ignore_list) + self.sparsity_ignore_list + ) if self.kv_cache_scheme is not None: - self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( - self.kv_cache_scheme) + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(self.kv_cache_scheme) def get_quant_method( self, @@ -103,18 +133,28 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import - # Check if the layer is skipped for quantization. 
- # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): - return UnquantizedLinearMethod() if isinstance(layer, LinearBase): - scheme = self.get_scheme(layer=layer, layer_name=prefix) - if scheme is None: - return UnquantizedLinearMethod() - layer.scheme = scheme - return CompressedTensorsLinearMethod(self) + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + input_tfms, output_tfms = get_linear_transform_schemes( + layer, prefix, self.transform_config, self.packed_modules_mapping + ) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + quant_method = CompressedTensorsLinearMethod(self) + + # choose transform method + if any((input_tfms, output_tfms)): + return CompressedTensorsLinearTransformMethod.from_schemes( + quant_method, quant_scheme, input_tfms, output_tfms + ) + + else: + return quant_method + if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): @@ -125,10 +165,11 @@ class CompressedTensorsConfig(QuantizationConfig): def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": ignore: list[str] = cast(list[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) - target_scheme_map = cls._quantization_scheme_map_from_config( - config=config) + target_scheme_map = cls._quantization_scheme_map_from_config(config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( - config=config) + config=config + ) + transform_config = config.get("transform_config") return cls( target_scheme_map=target_scheme_map, @@ -137,6 +178,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, + transform_config=transform_config, ) @classmethod @@ -153,18 +195,17 @@ class CompressedTensorsConfig(QuantizationConfig): if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)): return dict(), [] - sparsity_config = SparsityCompressionConfig.model_validate( - sparsity_config) + sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config) sparse_scheme_map: dict[str, SparsityCompressionConfig] = { - target: sparsity_config - for target in sparsity_config.targets or list() + target: sparsity_config for target in sparsity_config.targets or list() } sparsity_ignore_list = sparsity_config.ignore or list() return sparse_scheme_map, sparsity_ignore_list @classmethod def _quantization_scheme_map_from_config( - cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + cls, config: dict[str, Any] + ) -> QUANTIZATION_SCHEME_MAP_TYPE: """ :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding @@ -187,42 +228,47 @@ class CompressedTensorsConfig(QuantizationConfig): targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target][ - "weights"] = QuantizationArgs.model_validate( - quant_config.get("weights")) + target_scheme_map[target]["weights"] = QuantizationArgs.model_validate( + quant_config.get("weights") + ) target_scheme_map[target]["input_activations"] = None - target_scheme_map[target]["format"] = quant_config.get( - "format") + target_scheme_map[target]["format"] = quant_config.get("format") format = 
target_scheme_map[target].get("format") # If no per-config format defined, use global format in config - act_quant_format = is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - quant_format) - if act_quant_format: - input_activations = quant_config.get("input_activations") + act_quant_format = ( + is_activation_quantization_format(format) + if format is not None + else is_activation_quantization_format(quant_format) + ) + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run for cases where # there is an input_quant but it is ignored if not input_activations: - assert target_scheme_map[target][ - "weights"].type == QuantizationType.FLOAT + assert ( + target_scheme_map[target]["weights"].type + == QuantizationType.FLOAT + ) else: - target_scheme_map[target][ - "input_activations"] = QuantizationArgs.model_validate( # noqa: E501 - quant_config.get("input_activations")) + target_scheme_map[target]["input_activations"] = ( + QuantizationArgs.model_validate( + quant_config.get("input_activations") + ) + ) return target_scheme_map @classmethod def get_config_filenames(cls) -> list[str]: return [] - def _check_scheme_supported(self, - min_capability: int, - error: bool = True, - match_exact: bool = False) -> bool: + def _check_scheme_supported( + self, min_capability: int, error: bool = True, match_exact: bool = False + ) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: @@ -233,113 +279,155 @@ class CompressedTensorsConfig(QuantizationConfig): raise RuntimeError( "Quantization scheme is not supported for ", "the current GPU. Required capability: ", - f"{min_capability}. Current capability: {capability}.") + f"{min_capability}. Current capability: {capability}.", + ) else: supported = capability >= min_capability if error and not supported: raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. 
", - f"Current capability: {capability}.") + f"Current capability: {capability}.", + ) return supported else: return False - def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): - + def _is_fp4a4_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): if weight_quant is None or input_quant is None: return False - is_tensor_group_quant = (weight_quant.strategy - == QuantizationStrategy.TENSOR_GROUP.value - and input_quant.strategy - == QuantizationStrategy.TENSOR_GROUP.value) + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) is_symmetric = weight_quant.symmetric and input_quant.symmetric - is_group_size_16 = (weight_quant.group_size == 16 - and input_quant.group_size == 16) - is_float_type = (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT.value) + is_group_size_16 = ( + weight_quant.group_size == 16 and input_quant.group_size == 16 + ) + is_float_type = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT.value + ) is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 - return (is_tensor_group_quant and is_float_type and is_4_bits - and is_group_size_16 and is_symmetric) - - def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, - input_quant: BaseModel): + return ( + is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) + def _is_fp4a16_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): is_weight_only = weight_quant is not None and input_quant is None is_tensor_group_quant = ( - weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value) + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) is_symmetric = weight_quant.symmetric is_group_size_16 = weight_quant.group_size == 16 is_float_type = weight_quant.type == QuantizationType.FLOAT is_4_bits = weight_quant.num_bits == 4 - return (is_weight_only and is_tensor_group_quant and is_float_type - and is_4_bits and is_group_size_16 and is_symmetric) + return ( + is_weight_only + and is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) - def _is_static_tensor_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_static_tensor_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_tensor = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TENSOR.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_tensor = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TENSOR.value + ) is_static = not weight_quant.dynamic and not input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. 
return is_8_bits and is_tensor and weight_quant.symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_dynamic = not weight_quant.dynamic and input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. return is_8_bits and is_token and weight_quant.symmetric and is_dynamic - def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w4a8_int( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_weight_4_bits = weight_quant.num_bits == 4 is_activation_8_bits = input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.GROUP.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_dynamic = not weight_quant.dynamic and input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. - return (is_weight_4_bits and is_activation_8_bits and is_token - and weight_quant.symmetric and is_dynamic) + return ( + is_weight_4_bits + and is_activation_8_bits + and is_token + and weight_quant.symmetric + and is_dynamic + ) - def _is_fp8_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False # Confirm weight scheme is supported. - is_floating_point = (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT) + is_floating_point = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL - ]) - if not (is_floating_point and is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight): + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, + ] + if not ( + is_floating_point + and is_symmetric_weight + and is_static_weight + and is_tensor_or_channel_or_block_weight + ): return False # Dynamic quantization is always supported if weights supported. @@ -348,23 +436,56 @@ class CompressedTensorsConfig(QuantizationConfig): # Confirm activation scheme is supported. 
is_symmetric_activation = input_quant.symmetric - is_per_tensor_activation = ( - input_quant.strategy == QuantizationStrategy.TENSOR) + is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: - return (self._check_scheme_supported(90, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w4a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = weight_quant.strategy == QuantizationStrategy.GROUP.value + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. + return ( + is_weight_4_bits + and is_activation_8_bits + and is_token + and is_symmetric + and is_dynamic + ) - def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: - return (self._check_scheme_supported( - 100, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w4a8_sm90( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 90, error=False, match_exact=True + ) and self._is_fp8_w4a8(weight_quant, input_quant) - def _is_fp8_w8a16(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8_sm90( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 90, error=False, match_exact=True + ) and self._is_fp8_w8a8(weight_quant, input_quant) + + def _is_fp8_w8a8_sm100( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 100, error=False, match_exact=True + ) and self._is_fp8_w8a8(weight_quant, input_quant) + + def _is_fp8_w8a16( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -376,120 +497,142 @@ class CompressedTensorsConfig(QuantizationConfig): # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL - ]) - if not (is_symmetric_weight and is_static_weight # noqa: SIM103 - and is_per_tensor_or_channel_weight): - return False + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, + ] + return ( + is_symmetric_weight + and is_static_weight + and is_tensor_or_channel_or_block_weight + ) - # All conditions satisfied. 
- return True - - def _is_wNa16_group_channel(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_wNa16_group_channel( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: input_quant_none = input_quant is None is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value - or weight_quant.strategy == QuantizationStrategy.GROUP.value) + or weight_quant.strategy == QuantizationStrategy.GROUP.value + ) is_static = not weight_quant.dynamic - return (is_channel_group and input_quant_none and is_static) + return is_channel_group and input_quant_none and is_static def _get_scheme_from_parts( - self, - weight_quant: BaseModel, - input_quant: BaseModel, - format: Optional[str] = None) -> "CompressedTensorsScheme": + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + format: Optional[str] = None, + ) -> "CompressedTensorsScheme": + # use the per-layer format if defined, otherwise, use global format + format = format if format is not None else self.quant_format + # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder, + ) + if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + if ( + format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS + ): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value - and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + group_size=weight_quant.group_size, + ) + if ( + format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS + ): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, symmetric=weight_quant.symmetric, group_size=weight_quant.group_size, - actorder=weight_quant.actorder) + actorder=weight_quant.actorder, + ) - act_quant_format = is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - self.quant_format) + act_quant_format = is_activation_quantization_format(format) if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if cutlass_fp4_supported( - ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: + if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: logger.warning_once( "Current platform does not support cutlass NVFP4." - " Running CompressedTensorsW4A16Fp4.") - return CompressedTensorsW4A16Fp4( - has_input_global_scale=True) + " Running CompressedTensorsW4A16Fp4." 
+ ) + return CompressedTensorsW4A16Fp4(has_input_global_scale=True) if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + CompressedTensorsW8A8Fp8.get_min_capability(), error=False + ) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( - strategy=weight_quant.strategy, - is_static_input_scheme=(input_quant - and not input_quant.dynamic)) + weight_quant=weight_quant, + is_static_input_scheme=( + input_quant and not input_quant.dynamic + ), + ) else: # note: input_quant will be present for converted models; # will be ignored during inference post loading return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=not input_quant.dynamic) + is_static_input_scheme=not input_quant.dynamic, + ) # note: input_quant can be None if self._is_fp8_w8a16(weight_quant, input_quant): - is_static_input_scheme = (input_quant - and not input_quant.dynamic) + is_static_input_scheme = input_quant and not input_quant.dynamic return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=is_static_input_scheme) + is_static_input_scheme=is_static_input_scheme, + ) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( strategy=weight_quant.strategy, is_static_input_scheme=True, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( strategy=weight_quant.strategy, is_static_input_scheme=False, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) if self._is_dynamic_token_w4a8_int(weight_quant, input_quant): - is_static_input_scheme = (input_quant - and not input_quant.dynamic) + is_static_input_scheme = input_quant and not input_quant.dynamic return CompressedTensorsW4A8Int( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, group_size=weight_quant.group_size, is_static_input_scheme=is_static_input_scheme, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) - raise NotImplementedError( - "No compressed-tensors compatible scheme was found.") + raise NotImplementedError("No compressed-tensors compatible scheme was found.") - def get_scheme(self, - layer: torch.nn.Module, - layer_name: Optional[str] = None - ) -> Optional["CompressedTensorsScheme"]: + def get_scheme( + self, layer: torch.nn.Module, layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: """ compressed-tensors supports non uniform in the following way: @@ -505,9 +648,11 @@ class CompressedTensorsConfig(QuantizationConfig): # Find the "target" in the compressed-tensors config # that our layer conforms to. 
- # TODO (@robertgshaw): add compressed-tensors as dep - # so we do not have to re-write these functions - # need to make accelerate optional in ct to do this + # TODO (@kylesayrs): support ignore module names with ct matching utils + if should_ignore_layer( + layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping + ): + return None # Will be empty for models with only sparsity weight_quant = input_quant = None @@ -516,7 +661,8 @@ class CompressedTensorsConfig(QuantizationConfig): layer_name=layer_name, module=layer, targets=self.target_scheme_map.keys(), - fused_mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + ) scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") @@ -524,26 +670,32 @@ class CompressedTensorsConfig(QuantizationConfig): format = scheme_dict.get("format") # Find the sparsity scheme of the layer - # assume that fused layers inerhit first component's sparsity scheme - sparsity_targets = (self.sparsity_scheme_map.keys() - - set(self.sparsity_ignore_list)) + # assume that fused layers inherit first component's sparsity scheme + sparsity_targets = self.sparsity_scheme_map.keys() - set( + self.sparsity_ignore_list + ) sparsity_scheme: Optional[SparsityCompressionConfig] = None with suppress(ValueError): matched_target = find_matched_target( layer_name=layer_name, module=layer, targets=sparsity_targets, - fused_mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + ) sparsity_scheme = self.sparsity_scheme_map[matched_target] - if self.supports_cutlass_24(weight_quant=weight_quant, - input_quant=input_quant, - sparsity_scheme=sparsity_scheme): + if self.supports_cutlass_24( + weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme, + ): # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - model_compression_config = (None if sparsity_scheme is None - or sparsity_scheme.format == "dense" - else self.config) + model_compression_config = ( + None + if sparsity_scheme is None or sparsity_scheme.format == "dense" + else self.config + ) scheme = CompressedTensors24( quantized=weight_quant is not None or input_quant is not None, @@ -552,23 +704,23 @@ class CompressedTensorsConfig(QuantizationConfig): model_compression_config=model_compression_config, ) elif weight_quant is None: - logger.warning_once("Acceleration for non-quantized schemes is " - "not supported by Compressed Tensors. " - "Falling back to UnquantizedLinearMethod") + logger.warning_once( + "Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod" + ) return None else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore - weight_quant=weight_quant, - input_quant=input_quant, - format=format) + weight_quant=weight_quant, input_quant=input_quant, format=format + ) # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) - logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, - layer_name) + logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name) return scheme def get_cache_scale(self, name: str) -> Optional[str]: @@ -587,11 +739,21 @@ class CompressedTensorsConfig(QuantizationConfig): # If no matches, return None return None + def has_blocked_weights(self) -> bool: + for scheme in self.target_scheme_map.values(): + weight_quant = scheme.get("weights") + if ( + weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK + ): + return True + return False + @staticmethod def supports_cutlass_24( - weight_quant: Optional[QuantizationArgs], - input_quant: Optional[QuantizationArgs], - sparsity_scheme: Optional[SparsityCompressionConfig] = None + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None, ) -> bool: """ Check if the layer is supported by the Cutlass 2:4 Kernel @@ -601,7 +763,7 @@ class CompressedTensorsConfig(QuantizationConfig): - Weight only quantization is not-supported - Supported weight quantization strategies are TENSOR and CHANNEL - Supported input quantization strategies are TENSOR and TOKEN - - Only 8 bit quantization is supported + - Only 8 bit quantization is supported :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise @@ -610,16 +772,17 @@ class CompressedTensorsConfig(QuantizationConfig): return False is_valid_sparsity_structure: bool = ( - sparsity_scheme.sparsity_structure == - SparsityStructure.TWO_FOUR.value) + sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value + ) valid_compressors = { CompressionFormat.dense.value, - CompressionFormat.sparse_24_bitmask.value + CompressionFormat.sparse_24_bitmask.value, } - is_valid_sparsity = (is_valid_sparsity_structure - and sparsity_scheme.format in valid_compressors) + is_valid_sparsity = ( + is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors + ) if not is_valid_sparsity: return False @@ -634,7 +797,7 @@ class CompressedTensorsConfig(QuantizationConfig): supported_weight_quant_strategies = [ QuantizationStrategy.TENSOR.value, - QuantizationStrategy.CHANNEL.value + QuantizationStrategy.CHANNEL.value, ] assert weight_quant is not None @@ -643,7 +806,8 @@ class CompressedTensorsConfig(QuantizationConfig): return False supported_input_quant_strategies = [ - QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.TOKEN.value, ] if input_quant.strategy not in supported_input_quant_strategies: @@ -653,18 +817,22 @@ class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsLinearMethod(LinearMethodBase): - def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + 
): """ Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param @@ -678,19 +846,21 @@ class CompressedTensorsLinearMethod(LinearMethodBase): output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ - scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") @@ -724,18 +894,21 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): raise NotImplementedError( "Currently supported kv cache quantization is " "num_bits=8, type=float, however " - f"received num_bits={num_bits}, type={type_}") + f"received num_bits={num_bits}, type={type_}" + ) strategy = kv_cache_scheme.get("strategy") if strategy != "tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for compressed-tensors KV cache. " - f"Expected strategy: tensor, found strategy: {strategy}") + f"Expected strategy: tensor, found strategy: {strategy}" + ) is_symmetric = kv_cache_scheme.get("symmetric") if not is_symmetric: raise NotImplementedError( "Only support symmetric scaling factor " "for compressed-tensors KV cache. " - f"However found symmetric: {is_symmetric}") + f"However found symmetric: {is_symmetric}" + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7bc35cd81ac3f..41e7f1c7a4997 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,43 +3,78 @@ import enum from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import (ActivationOrdering, - QuantizationStrategy) +from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoeWeightScaleSupported) + FusedMoE, + FusedMoEActivationFormat, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + int4_w4a16_moe_quant_config, + int8_w8a8_moe_quant_config, + int8_w8a16_moe_quant_config, + nvfp4_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe) 
+ is_valid_flashinfer_cutlass_fused_moe, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa - WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) + WNA16_SUPPORTED_BITS, + WNA16_SUPPORTED_TYPES_MAP, +) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, - select_nvfp4_gemm_impl) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, + reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + expert_weight_is_col_major, + requant_weight_ue8m0_inplace, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_moe_marlin_supports_layer, marlin_make_workspace_new, - marlin_moe_permute_scales) + check_moe_marlin_supports_layer, + marlin_make_workspace_new, + marlin_moe_permute_scales, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin) + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - prepare_moe_fp8_layer_for_marlin) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - swizzle_blockscale) + prepare_moe_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) + all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types +from vllm.utils.deep_gemm import ( + get_col_major_tma_aligned_tensor, + is_deep_gemm_e8m0_used, +) logger = init_logger(__name__) @@ -50,15 +85,17 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", + "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod" + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod", + "CompressedTensorsW4A8Int8MoEMethod", ] class CompressedTensorsMoEMethod(FusedMoEMethodBase): - def __init_(self, moe: FusedMoEConfig): super().__init__(moe) @@ -69,60 +106,97 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
- weight_quant = quant_config.target_scheme_map["Linear"].get("weights") - input_quant = quant_config.target_scheme_map["Linear"].get( - "input_activations") + # Check if a using "Linear" to select schemes + if "Linear" in quant_config.target_scheme_map: + matched_target = "Linear" + else: + # May have instead defined the linear layers in the fused model + + fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"] + current_scheme = None + for fused_layer in fused_layers: + # Check if one of the fused layers are defined in quant_config + matched_target = find_matched_target( + layer_name=fused_layer, + module=layer, + targets=quant_config.target_scheme_map.keys(), + fused_mapping=quant_config.packed_modules_mapping, + ) + + # Only valid if down_proj, gate_proj, and up_proj + # are mapped to the same quant scheme in the quant_config + if current_scheme is None: + current_scheme = quant_config.target_scheme_map.get(matched_target) + else: + assert current_scheme == quant_config.target_scheme_map.get( + matched_target + ) + + weight_quant = quant_config.target_scheme_map[matched_target].get("weights") + input_quant = quant_config.target_scheme_map[matched_target].get( + "input_activations" + ) if quant_config._is_wNa16_group_channel(weight_quant, input_quant): # group_size=None means channelwise group_size = weight_quant.group_size or -1 # Prefer to use the MarlinMoE kernel when it is supported. if not check_moe_marlin_supports_layer(layer, group_size): - if (weight_quant.strategy in QuantizationStrategy.GROUP and - weight_quant.actorder in (ActivationOrdering.GROUP, - ActivationOrdering.DYNAMIC)): + if ( + weight_quant.strategy in QuantizationStrategy.GROUP + and weight_quant.actorder + in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC) + ): raise ValueError( "WNA16MoE is not supported with actorder=group/dynamic." 
) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config, - layer.moe_config) + return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod( - quant_config, layer.moe_config) + quant_config, layer.moe_config + ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) - elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) - or quant_config._is_fp8_w8a8(weight_quant, input_quant)): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config, - layer.moe_config) + return CompressedTensorsW4A4MoeMethod(layer.moe_config) + elif ( + quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) + or quant_config._is_fp8_w8a8(weight_quant, input_quant) + ): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config, - layer.moe_config) + return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) + elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): + return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config) else: raise RuntimeError( - f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - - def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): + def __init__(self, moe: FusedMoEConfig): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 - detect_nvfp4_moe_support) + detect_nvfp4_moe_support, + ) + super().__init__(moe) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - self.layer = layer - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.num_experts = num_experts layer.params_dtype = params_dtype @@ -133,8 +207,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): # 2 fp4 items are packed in the input dimension hidden_size // 2, requires_grad=False, - dtype=torch.uint8), - requires_grad=False) + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) @@ -144,8 +220,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, - dtype=torch.uint8), - requires_grad=False) + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -156,11 +234,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): 2 * 
intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.group_size, - dtype=torch.float8_e4m3fn), - requires_grad=False) + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) w2_weight_scale = torch.nn.Parameter( @@ -169,142 +250,168 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn), - requires_grad=False) + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # Weight Global Scales - w13_weight_scale_2 = torch.nn.Parameter(torch.empty( - num_experts, 2, dtype=torch.float32), - requires_grad=False) + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) - w2_weight_scale_2 = torch.nn.Parameter(torch.empty( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) # Input Global Scales - w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_global_scale", w13_input_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, - dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_global_scale", w2_input_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w2_input_scale, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # From packed to weight - layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) - layer.w2_weight = 
torch.nn.Parameter(layer.w2_weight_packed.data, - requires_grad=False) + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. if self.allow_flashinfer: - w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data, - layer.w13_weight_scale.data, - dim=-2) + w, s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 + ) layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) - if not torch.allclose(layer.w13_weight_global_scale[:, 0], - layer.w13_weight_global_scale[:, 1]): + if not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): logger.warning_once( "w1_weight_global_scale must match w3_weight_global_scale. " - "Accuracy may be affected.") + "Accuracy may be affected." + ) # Take inverse of global scale saved to disk layer.w13_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False + ) layer.w2_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w2_weight_global_scale.data, requires_grad=False) + 1 / layer.w2_weight_global_scale.data, requires_grad=False + ) if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) return # swizzle weight scales - layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( - layer.w13_weight_scale), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) - layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale( - layer.w2_weight_scale), - requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) # w13 - w13_input_global_scale = layer.w13_input_global_scale.max( - dim=1).values.to(torch.float32) + w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to( + torch.float32 + ) layer.g1_alphas = torch.nn.Parameter( ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), - requires_grad=False) + requires_grad=False, + ) layer.w13_input_scale_quant = torch.nn.Parameter( - (w13_input_global_scale), requires_grad=False) + (w13_input_global_scale), requires_grad=False + ) # w2 layer.g2_alphas = torch.nn.Parameter( ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( - torch.float32), - requires_grad=False) + torch.float32 + ), + requires_grad=False, + ) layer.w2_input_scale_quant = torch.nn.Parameter( - (layer.w2_input_global_scale), requires_grad=False) - - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if not self.allow_flashinfer: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, + (layer.w2_input_global_scale), requires_grad=False ) + + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin: + return None + elif not self.allow_flashinfer: + return super().maybe_make_prepare_finalize() + + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + 
layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None """Return the appropriate GEMM experts implementation.""" experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return nvfp4_moe_quant_config( + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + def apply( self, layer: torch.nn.Module, @@ -319,6 +426,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -326,15 +434,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert self.fused_experts is None - + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsW4A4MoeMethod` yet.") + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." + ) assert activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -344,11 +451,17 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin. + # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -365,13 +478,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): quant_type_id=scalar_types.float4_e2m1f.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace, + ) - # FlashInfer fused experts path - if self.fused_experts is not None: + elif self.fused_experts is not None: assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" 
return self.fused_experts( hidden_states=x, @@ -383,18 +497,20 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, apply_router_weight_on_input=apply_router_weight_on_input, ) + # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4) + flashinfer_cutlass_moe_fp4, + ) assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" + + assert self.moe_quant_config is not None return flashinfer_cutlass_moe_fp4( hidden_states=x, @@ -402,49 +518,42 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4MoeMethod.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) + assert expert_map is None, ( + "Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod." 
+ ) + assert self.moe_quant_config is not None - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - apply_router_weight_on_input=apply_router_weight_on_input).to( - x.dtype) + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO(bnell): derive these from arguments + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + ).to(x.dtype) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -452,52 +561,69 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ): super().__init__(moe) self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( - "weights") + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") + "input_activations" + ) - per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy - == QuantizationStrategy.TENSOR) + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN) + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) if not (per_tensor or per_channel): - raise ValueError( - "For FP8 Fused MoE layers, we require per tensor " - "or channelwise, dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}") + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None self.static_input_scales = not self.input_quant.dynamic if self.static_input_scales and per_channel: raise ValueError( "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization.") + "channelwise, dynamic per token quantization." 
+ ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + and not self.block_quant + ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, + ) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( - self.weight_quant, self.input_quant) - self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( - self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) + self.weight_quant, self.input_quant + ) + self.use_cutlass = not self.block_quant and ( + quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) + or self.is_fp8_w8a8_sm100 + ) self.disable_expert_map = False - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts @@ -506,22 +632,54 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." 
+ ) + # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -529,49 +687,83 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): if self.weight_quant.strategy == QuantizationStrategy.TENSOR: # Allocate 2 scales for w1 and w3 respectively. # They are combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, 2, dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-TENSOR quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
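# Aside: standalone sketch (illustrative helper and sizes, not part of this diff)
# of how the block-wise scale shapes registered in the BLOCK branch below follow
# from the weight shapes: one scale per (block_n x block_k) tile, with
# ceil-division covering a partial last tile.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

E, H, inter = 8, 4096, 1536          # experts, hidden_size, intermediate (per partition)
block_n, block_k = 128, 128
assert inter % block_n == 0          # gate/up output dim must stay block-aligned (checked above)
w13_scale_shape = (E, 2 * ceil_div(inter, block_n), ceil_div(H, block_k))
w2_scale_shape = (E, ceil_div(H, block_n), ceil_div(inter, block_k))
assert w13_scale_shape == (8, 24, 32) and w2_scale_shape == (8, 32, 12)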
extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: @@ -583,46 +775,53 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # We take the max of all the scales in case they differ. if self.static_input_scales: assert self.input_quant.strategy == QuantizationStrategy.TENSOR - if (layer.w13_input_scale is None or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer.") + "for each layer." 
+ ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, - requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) # For Per-TENSOR case, Fp8 moe kernel needs single weight scale # for w13 per expert. Use max then dequant and requant each expert. @@ -634,136 +833,186 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights) + rocm_aiter_fused_experts, + shuffle_weights, + ) # reshaping weights is required for aiter moe kernel. 
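# Aside: standalone numerical sketch of the per-tensor w13 rescaling above.
# merge_w13_shard_scales is a hypothetical helper; the real path stores
# float8_e4m3fn weights and uses per_tensor_dequantize + ops.scaled_fp8_quant,
# while this sketch uses plain float tensors to show only the scale math.
import torch

def merge_w13_shard_scales(w13_q: torch.Tensor, shard_scales: torch.Tensor):
    """w13_q: [2 * shard, k] quantized values, shard_scales: [2] (one per fused shard)."""
    shard = w13_q.shape[0] // 2
    max_scale = shard_scales.max()
    requant = []
    for i in range(2):
        dq = w13_q[i * shard:(i + 1) * shard] * shard_scales[i]  # dequantize with shard scale
        requant.append(dq / max_scale)                           # requantize with the max scale
    return torch.cat(requant, dim=0), max_scale

q, scale = merge_w13_shard_scales(torch.randn(8, 16), torch.tensor([0.25, 0.5]))
assert scale.item() == 0.5 and q.shape == (8, 16)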
shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - self.fused_experts_func = None - else: - from vllm.model_executor.layers.fused_moe import fused_experts - self.fused_experts_func = fused_experts if self.use_cutlass: + assert self.weight_quant.strategy != QuantizationStrategy.BLOCK device = layer.w13_weight.device # ab_strides1 and c_strides2 are the same self.ab_strides1_c_strides2 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), layer.hidden_size, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) self.ab_strides2 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), layer.intermediate_size_per_partition, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) self.c_strides1 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), 2 * layer.intermediate_size_per_partition, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) + + if is_deep_gemm_e8m0_used() and self.block_quant: + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM. 
+ if expert_weight_is_col_major(layer.w13_weight_scale): + layer.w13_weight_scale = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale + ) + if expert_weight_is_col_major(layer.w2_weight_scale): + layer.w2_weight_scale = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale + ) + + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin or self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, - prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path + assert self.moe_quant_config is not None if self.use_cutlass: from vllm.model_executor.layers.fused_moe import ( - CutlassBatchedExpertsFp8, CutlassExpertsFp8) + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + ) experts: FusedMoEPermuteExpertsUnpermute num_dispatchers = prepare_finalize.num_dispatchers() - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): - logger.debug("CutlassBatchedExpertsFp8(%s)", - self.__class__.__name__) + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) experts = CutlassBatchedExpertsFp8( - moe.num_local_experts, + self.moe.num_local_experts, num_dispatchers, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) experts = CutlassExpertsFp8( - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) - self.disable_expert_map = (num_dispatchers > 1 - or not experts.supports_expert_map()) + self.disable_expert_map = ( + num_dispatchers > 1 or not experts.supports_expert_map() + ) return experts # triton path - from vllm.model_executor.layers.fused_moe import TritonExperts - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts, + ) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, + ) assert not self.rocm_aiter_moe_enabled and not self.use_marlin - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( - ) + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - return BatchedTritonExperts( + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) + return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, 
num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), + quant_config=self.moe_quant_config, ) else: - return TritonExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), - ) + logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) + return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL + + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_channel_quant, + block_shape=layer.weight_block_size, + ) def apply( self, @@ -779,6 +1028,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -786,13 +1036,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Fp8MoEMethod` yet.") + "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet." 
+ ) - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -802,96 +1052,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) - # cutlass path - if self.use_cutlass: - per_act_token = ( - self.input_quant.strategy == QuantizationStrategy.TOKEN) - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL - # small-batch fallback on SM100 - if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) - - if self.fused_experts is None: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) - return cutlass_moe_fp8( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - per_act_token=per_act_token, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - else: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - expert_map=expert_map) + # + # Note: the order here is important. self.fused_experts can override + # cutlass fp8 or fused_experts but not marlin or rocm. + # if self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." 
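# Aside: pick_backend is a hypothetical helper (not part of this diff) that just
# restates the branch priority from the note above: Marlin and ROCm AITER always
# take precedence, while a supplied modular self.fused_experts overrides only the
# cutlass / triton paths.
def pick_backend(use_marlin: bool, rocm_aiter: bool,
                 has_modular_fused_experts: bool, use_cutlass: bool) -> str:
    if use_marlin:
        return "fused_marlin_moe"
    if rocm_aiter:
        return "rocm_aiter_fused_experts"
    if has_modular_fused_experts:
        return "self.fused_experts"
    if use_cutlass:
        return "cutlass_moe_fp8 (or fused_experts for small SM100 batches)"
    return "fused_experts"

assert pick_backend(True, False, True, True) == "fused_marlin_moe"
assert pick_backend(False, False, True, True) == "self.fused_experts"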
+ assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -906,32 +1081,108 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): quant_type_id=scalar_types.float8_e4m3fn.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace, + ) - assert self.fused_experts_func is not None + elif self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, + ) - return self.fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + assert self.fused_experts is None + return rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) + + elif self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ) + + # cutlass path + elif self.use_cutlass: + assert self.moe_quant_config is not None + + # small-batch fallback on SM100 + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: + from vllm.model_executor.layers.fused_moe import fused_experts + + assert per_act_token == per_channel_quant + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + quant_config=self.moe_quant_config, + ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8, + ) + + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_config=self.moe_quant_config, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) + + else: + from vllm.model_executor.layers.fused_moe import fused_experts + + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + 
global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -939,69 +1190,83 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ): super().__init__(moe) self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( - "weights") + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") + "input_activations" + ) per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN) + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) if not per_channel: raise ValueError( "For INT8 Fused MoE layers, we require channelwise, " "dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}") + f"{self.weight_quant}, {self.input_quant}" + ) self.static_input_scales = not self.input_quant.dynamic if self.static_input_scales: raise ValueError( "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales.") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + "dynamic per token quantization. Found static input scales." + ) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): params_dtype = torch.int8 # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - hidden_size, - 1, - dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) @@ -1013,6 +1278,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return int8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + def apply( self, layer: torch.nn.Module, @@ -1027,6 +1303,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1034,17 +1311,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Int8MoEMethod` yet.") + "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." + ) from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -1054,8 +1331,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( hidden_states=x, @@ -1066,18 +1345,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_int8_w8a8=True, - per_channel_quant=True, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + quant_config=self.moe_quant_config, + ) class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -1093,58 +1367,71 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder - assert config.symmetric, ( - "Only symmetric quantization is supported for MoE") + assert config.symmetric, "Only symmetric quantization is supported for MoE" - if not (self.quant_config.quant_format - == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): - raise 
ValueError("For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - - intermediate_size_full = extra_weight_attrs.pop( - "intermediate_size_full") + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") # Will transpose the loaded weight along the # intermediate and hidden dim sizes. Will # shard for TP along the transposed dims - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": self.strategy - }) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # In the case where we have actorder/g_idx, # we do not partition the w2 scales load_full_w2 = self.actorder and self.group_size != -1 - w2_scales_size = (intermediate_size_full - if load_full_w2 else intermediate_size_per_partition) + w2_scales_size = ( + intermediate_size_full if load_full_w2 else intermediate_size_per_partition + ) self.is_k_full = (not self.actorder) or ( - intermediate_size_per_partition == intermediate_size_full) + intermediate_size_per_partition == intermediate_size_full + ) if self.strategy == "channel": num_groups_w2 = num_groups_w13 = 1 @@ -1153,30 +1440,34 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_groups_w2 = w2_scales_size // self.group_size num_groups_w13 = hidden_size // self.group_size - w13_scale = torch.nn.Parameter(torch.ones( - num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_scale) set_weight_attrs(w13_scale, extra_weight_attrs) - w2_scale = 
torch.nn.Parameter(torch.ones(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_scale) set_weight_attrs(w2_scale, extra_weight_attrs) set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) - w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w2_weight_shape", w2_weight_shape) set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) @@ -1211,8 +1502,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( @@ -1223,8 +1513,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.a13_scale = None @@ -1244,41 +1533,37 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_weight_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_weight_g_idx[e]).to(torch.int32) + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( + torch.int32 + ) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( + torch.int32 + ) w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][ - w2_g_idx_sort_indices[e]] + w13_g_idx_sort_indices[e] + ] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: layer.w13_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + 
torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) @@ -1308,8 +1593,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_weight_scale, - size_k=layer.w2_weight_scale.shape[1] * - (self.group_size if self.group_size != -1 else self.packed_factor), + size_k=layer.w2_weight_scale.shape[1] + * (self.group_size if self.group_size != -1 else self.packed_factor), size_n=layer.w2_weight_scale.shape[2], group_size=self.group_size, ) @@ -1317,6 +1602,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.workspace = marlin_make_workspace_new(device, 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -1331,6 +1621,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1338,18 +1629,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet." + ) - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." 
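# Aside: toy, standalone illustration of the act-order handling in
# process_weights_after_loading above: per expert, g_idx is argsorted and both
# the sorted g_idx and the (int32) sort indices are handed to the Marlin kernel.
import torch

g_idx = torch.tensor([3, 0, 2, 1, 0, 3], dtype=torch.int32)  # toy per-expert group indices
order = torch.argsort(g_idx)                                 # int64 sort indices
sorted_g_idx = g_idx[order]
sort_indices = order.to(torch.int32)                         # dtype expected by the kernel
assert torch.equal(sorted_g_idx, torch.sort(g_idx).values)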
- topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -1359,8 +1649,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -1382,11 +1674,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, - is_k_full=self.is_k_full) + is_k_full=self.is_k_full, + ) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -1405,43 +1697,55 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): self.group_size = config.group_size # grouped actorder isn't supported by this kernel assert config.actorder != "group" - assert config.symmetric, ( - "Only symmetric quantization is supported for MoE") + assert config.symmetric, "Only symmetric quantization is supported for MoE" - if not (self.quant_config.quant_format - == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): - raise ValueError("For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Will transpose the loaded weight along the # intermediate and hidden dim sizes. 
Will # shard for TP along the transposed dims - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": self.strategy - }) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -1454,30 +1758,34 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_groups_w2 = w2_scales_size // self.group_size num_groups_w13 = hidden_size // self.group_size - w13_scale = torch.nn.Parameter(torch.ones( - num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_scale) set_weight_attrs(w13_scale, extra_weight_attrs) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_scale) set_weight_attrs(w2_scale, extra_weight_attrs) set_weight_attrs(w2_scale, {"load_full_w2": False}) - w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w2_weight_shape", w2_weight_shape) set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) @@ -1512,8 +1820,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( @@ -1524,8 +1831,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) 
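# Aside: standalone sketch of the packed weight sizes registered above.
# packed_shapes is a hypothetical helper, and packed_factor = 32 // num_bits
# (8 int4 or 4 int8 values per int32 word) is an assumption not shown in this
# hunk; the sizes are illustrative.
def packed_shapes(num_experts: int, hidden: int, inter: int, num_bits: int):
    packed_factor = 32 // num_bits
    w13 = (num_experts, hidden // packed_factor, 2 * inter)  # w13_weight_packed (transposed)
    w2 = (num_experts, inter // packed_factor, hidden)       # w2_weight_packed (transposed)
    return w13, w2

assert packed_shapes(8, 4096, 1536, num_bits=4) == ((8, 512, 3072), (8, 192, 4096))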
layer.a13_scale = None @@ -1534,19 +1840,37 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Reconfigure packed weights and scales to match moe_wna16 format layer.w13_weight_packed = torch.nn.Parameter( - layer.w13_weight_packed.transpose(1, 2).contiguous().view( - torch.uint8), - requires_grad=False) + layer.w13_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) layer.w2_weight_packed = torch.nn.Parameter( - layer.w2_weight_packed.transpose(1, - 2).contiguous().view(torch.uint8), - requires_grad=False) + layer.w2_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) layer.w13_weight_scale = torch.nn.Parameter( - layer.w13_weight_scale.transpose(1, 2).contiguous(), - requires_grad=False) + layer.w13_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) layer.w2_weight_scale = torch.nn.Parameter( - layer.w2_weight_scale.transpose(1, 2).contiguous(), - requires_grad=False) + layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + assert self.num_bits == 4 or self.num_bits == 8 + config_builder = ( + int4_w4a16_moe_quant_config + if self.num_bits == 4 + else int8_w8a16_moe_quant_config + ) + + return config_builder( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size], + ) def apply( self, @@ -1562,6 +1886,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1569,16 +1894,17 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsWNA16MoEMethod` yet.") + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet." 
+ ) from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -1588,8 +1914,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -1599,13 +1927,341 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, inplace=True, activation=activation, - use_int4_w4a16=self.num_bits == 4, - use_int8_w8a16=self.num_bits == 8, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, self.group_size]) + quant_config=self.moe_quant_config, + ) + + +class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): + """ + CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform + - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles) + - Scales: Fp32 for Channelwise , bf16 for groupwise quantization + - Bias: Same data type as original weights + - Activations: FP32/Bf16 dynamic per-token (A8 Int), + quantized inside the kernel + """ + + def __init__( + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.quant_config = quant_config + + # Validate scheme: weights=W4 (channel or group), + # activations=dynamic TOKEN (A8) + wq = self.quant_config.target_scheme_map["Linear"].get("weights") + aq = self.quant_config.target_scheme_map["Linear"].get("input_activations") + + # Must be dynamic per-token activations + if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: + raise ValueError( + "W4A8-int MoE needs dynamic per-token activation quantization." 
+ ) + + # Weight can be channel-wise (group_size=None) or group-wise + self.group_size = wq.group_size if (wq.group_size is not None) else -1 + if wq.num_bits != 4: + raise ValueError("This method only supports 4-bit weights (num_bits=4).") + + # CPU only + if not current_platform.is_cpu(): + raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.") + + # Arm: check _dyn ops availability + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + try: + _ = torch.ops.aten._dyn_quant_matmul_4bit + _ = torch.ops.aten._dyn_quant_pack_4bit_weight + except AttributeError as err: + raise RuntimeError( + f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; + install a newer build.""" + ) from err + self.static_input_scales = False # always dynamic per token + + # ---- parameter creation ---- + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Shapes per local rank (TP/EP): + # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) + # w2 : [E, H, I_local] int8 + # Scales: + # channel-wise: group_size=-1 -> per-output-row, single scale per row + # group-wise : group_size=g -> + # per-output-row, (in_features/g) scales + + E = num_experts + H = hidden_size + IN = intermediate_size_per_partition + g = self.group_size + + # Per-row scale columns + def _n_scale_cols(in_features: int) -> int: + return 1 if g == -1 else (in_features // g) + + # Register unpacked int4-as-int8 weights the loader will fill. + w13 = torch.nn.Parameter( + torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False + ) + set_weight_attrs(w13, extra_weight_attrs) + layer.register_parameter("w13_weight", w13) + + w2 = torch.nn.Parameter( + torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False + ) + set_weight_attrs(w2, extra_weight_attrs) + layer.register_parameter("w2_weight", w2) + + # Register scales + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + + w13_s = torch.nn.Parameter( + torch.ones(E, 2 * IN, _n_scale_cols(H), dtype=scale_dtype), + requires_grad=False, + ) + set_weight_attrs( + w13_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) + layer.register_parameter("w13_weight_scale", w13_s) + + w2_s = torch.nn.Parameter( + torch.ones(E, H, _n_scale_cols(IN), dtype=scale_dtype), requires_grad=False + ) + set_weight_attrs( + w2_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) + layer.register_parameter("w2_weight_scale", w2_s) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + # Placeholders for packed weights (will be replaced after packing) + layer.register_parameter( + "w13_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) + + layer.register_parameter( + "w2_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + set_weight_attrs(layer.w2_weight_packed, 
extra_weight_attrs) + + # dims for 4 bit fused matmuls + layer.w13_in_features = H + layer.w13_out_features = 2 * IN + layer.w2_in_features = IN + layer.w2_out_features = H + layer.group_size = g + + # post-load packing to dyn-4bit KleidiAI kernel's format + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E = layer.w13_weight.shape[0] + H = layer.w13_in_features + I2 = layer.w13_out_features + IN = layer.w2_in_features + g = layer.group_size + + def _pack_matrix( + int4_as_int8_2d: torch.Tensor, + scales_2d: torch.Tensor, + bias_1d: Optional[torch.Tensor], + in_features: int, + out_features: int, + ) -> torch.Tensor: + # int4 values are stored as int8 in [-8,7]. + # Shift to unsigned nibble and pack pairs along input-dim. + tmp = int4_as_int8_2d.add(8) # [out, in] + uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( + torch.uint8 + ) # [out, in//2] + + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + scales = scales_2d.to(scale_dtype) + bias = None if bias_1d is None else bias_1d.to(torch.float32) + return torch.ops.aten._dyn_quant_pack_4bit_weight( + uint8_nibbles, + scales, + bias, + g if g != -1 else in_features, + in_features, + out_features, + ) + + # Pack per expert + w13_packed_list = [] + w2_packed_list = [] + + has_w13_bias = hasattr(layer, "w13_bias") and layer.w13_bias is not None + has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None + + for e in range(E): + w13_packed_list.append( + _pack_matrix( + layer.w13_weight[e], # [2I, H] + layer.w13_weight_scale[e], # [2I, H/g or 1] + layer.w13_bias[e] if has_w13_bias else None, # [2I] + H, + I2, + ) + ) + w2_packed_list.append( + _pack_matrix( + # w2 shape is [H, IN]; we need [out, in] == [H, IN]. + layer.w2_weight[e], # [H, IN] + layer.w2_weight_scale[e], # [H, IN/g or 1] + layer.w2_bias[e] if has_w2_bias else None, # [H] + IN, + layer.w2_out_features, # in_features=IN, out_features=H + ) + ) + + # each packed tensor has identical shape per expert; stack on dim 0 + w13_packed = torch.stack(w13_packed_list, dim=0) + w2_packed = torch.stack(w2_packed_list, dim=0) + + replace_parameter( + layer, + "w13_weight_packed", + torch.nn.Parameter(w13_packed, requires_grad=False), + ) + replace_parameter( + layer, + "w2_weight_packed", + torch.nn.Parameter(w2_packed, requires_grad=False), + ) + + # free raw tensors/scales/bias now that they're packed into the payload. + replace_parameter( + layer, "w13_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + replace_parameter( + layer, "w2_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + replace_parameter( + layer, + "w13_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + replace_parameter( + layer, + "w2_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + if has_w13_bias: + replace_parameter( + layer, + "w13_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + if has_w2_bias: + replace_parameter( + layer, + "w2_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + # CPU dynamic 4-bit MoE path does not use modular kernels or + # fused_experts; quant config is not needed. 
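As an aside on the packing step in `_pack_matrix` above: the int4 values are shifted from [-8, 7] into unsigned nibbles and two adjacent input columns are folded into one uint8 (even column in the low nibble, odd column in the high nibble) before being handed to `_dyn_quant_pack_4bit_weight`. A minimal, self-contained round-trip sketch of just that nibble step (hypothetical sizes, not the KleidiAI payload itself):

import torch


def pack_int4_rows(w: torch.Tensor) -> torch.Tensor:
    # w: [out_features, in_features], int8 values in [-8, 7]
    u = (w + 8).to(torch.uint8)            # shift into unsigned nibble range [0, 15]
    return (u[:, 1::2] << 4) | u[:, ::2]   # [out_features, in_features // 2]


def unpack_int4_rows(packed: torch.Tensor) -> torch.Tensor:
    low = (packed & 0xF).to(torch.int8) - 8   # even input columns
    high = (packed >> 4).to(torch.int8) - 8   # odd input columns
    out = torch.empty(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    out[:, ::2] = low
    out[:, 1::2] = high
    return out


# Round-trip check on random int4-as-int8 data (made-up shape).
w = torch.randint(-8, 8, (4, 16), dtype=torch.int8)
assert torch.equal(unpack_int4_rows(pack_int4_rows(w)), w)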
+ return None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." + assert activation in ("silu", "swigluoai", "swiglu"), ( + "Only SiLU/SwiGLUGU/SwiGLUUG are supported." + ) + assert expert_map is None, """expert_map/EP not implemented + for CPU dyn-4bit MoE.""" + + def _act_kind(s: str) -> int: + # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU + if s == "swiglu": + return 0 + if s == "swigluoai": + return 1 + if s == "silu": + return 2 + raise ValueError(f"Unknown activation '{s}'") + + # Apply topk softmax on router output + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops._C.dynamic_4bit_int_moe( + x, + topk_ids.to(torch.long), + topk_weights, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w2_out_features, + layer.w2_in_features, + layer.w13_out_features, + layer.group_size, + apply_router_weight_on_input, + int(_act_kind(activation)), + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba7b9..fc0634394ece3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,23 +3,32 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int -from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, - CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w4a16_24 import ( + W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24, +) from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 -from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, - CompressedTensorsWNA16) +from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 from .compressed_tensors_24 import CompressedTensors24 # isort: skip __all__ = [ - "CompressedTensorsScheme", "CompressedTensorsWNA16", - 
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", - "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", - "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", - "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsScheme", + "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", + "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", + "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", + "W4A16SPARSE24_SUPPORTED_BITS", + "CompressedTensors24", + "CompressedTensorsW4A16Fp4", + "CompressedTensorsW4A4Fp4", + "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8", ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 168b221a9cfe9..93a50a377ee56 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -5,25 +5,33 @@ from typing import Any, Callable, Optional import torch from compressed_tensors import CompressionFormat, ModelCompressor -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, sparse_cutlass_supported) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + convert_to_channelwise, + sparse_cutlass_supported, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensors24"] @@ -31,7 +39,6 @@ from vllm.platforms import current_platform class CompressedTensors24(CompressedTensorsScheme): - def __init__( self, quantized: bool = False, @@ -42,16 +49,22 @@ class CompressedTensors24(CompressedTensorsScheme): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant - self.model_compressor = ( - ModelCompressor.from_compression_config(model_compression_config) - if model_compression_config is not None else None) + model_compressor = ModelCompressor.from_compression_config( + model_compression_config + ) self.do_sparse_decompress = ( - self.model_compressor is not None - and self.model_compressor.sparsity_config.format - == CompressionFormat.sparse_24_bitmask.value) + model_compressor is not None + and model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value + ) + if self.do_sparse_decompress: 
+ self.model_compressor = model_compressor - if quantized and input_quant is not None and \ - self._get_quant_dtype() == current_platform.fp8_dtype(): + if ( + quantized + and input_quant is not None + and self._get_quant_dtype() == current_platform.fp8_dtype() + ): static = not input_quant.dynamic g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN self.quant_fp8 = QuantFP8(static, g_shape) @@ -74,7 +87,8 @@ class CompressedTensors24(CompressedTensorsScheme): if not sparse_cutlass_supported(): raise ValueError( "Sparse CUTLASS not supported. vLLM must be built with " - "CUDA 12.2 or later to use this feature") + "CUDA 12.2 or later to use this feature" + ) layer.logical_widths = output_partition_sizes layer.input_size = input_size @@ -93,9 +107,9 @@ class CompressedTensors24(CompressedTensorsScheme): weight_loader=weight_loader, ) if self.do_sparse_decompress: - assert all(partition_size % 8 == 0 - for partition_size in output_partition_sizes - ), "All partitions must be divisible by 8 for " + assert all( + partition_size % 8 == 0 for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " "2:4 sparse compressed models" shape = BasevLLMParameter( @@ -130,20 +144,24 @@ class CompressedTensors24(CompressedTensorsScheme): # Check if quantized, not just 2:4 Sparse if self.quantized: - if (self.weight_quant and self.weight_quant.strategy - == QuantizationStrategy.CHANNEL.value): + if ( + self.weight_quant + and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ): weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32 + ), output_dim=0, weight_loader=weight_loader, ) else: - assert (self.weight_quant and self.weight_quant.strategy - == QuantizationStrategy.TENSOR.value) + assert ( + self.weight_quant + and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value + ) weight_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) @@ -152,8 +170,7 @@ class CompressedTensors24(CompressedTensorsScheme): # input quant will be non-none if self.input_quant and not self.input_quant.dynamic: # register input quant scale - assert (self.input_quant.strategy == - QuantizationStrategy.TENSOR.value) + assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader, @@ -163,12 +180,12 @@ class CompressedTensors24(CompressedTensorsScheme): else: # for sparse-only, pass in 1 for weight/input scales - weight_scale = torch.nn.Parameter(data=torch.ones( - 1, dtype=torch.float32), - requires_grad=False) - input_scale = torch.nn.Parameter(data=torch.ones( - 1, dtype=torch.float32), - requires_grad=False) + weight_scale = torch.nn.Parameter( + data=torch.ones(1, dtype=torch.float32), requires_grad=False + ) + input_scale = torch.nn.Parameter( + data=torch.ones(1, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("input_scale", input_scale) layer.register_parameter("weight_scale", weight_scale) @@ -199,8 +216,9 @@ class CompressedTensors24(CompressedTensorsScheme): # torch.compile workaround if hasattr(layer, "input_scale"): - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) + layer.input_scale = 
torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) if self.weight_quant: if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: @@ -214,11 +232,11 @@ class CompressedTensors24(CompressedTensorsScheme): else: # torch.compile workaround layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data, requires_grad=False) + layer.weight_scale.data, requires_grad=False + ) # Set all negative zero values to 0 prior to compression - if (layer.weight.dtype.is_floating_point - and layer.weight.dtype.itemsize >= 2): + if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2: layer.weight.data[layer.weight.data == -0.0] = 0.0 w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data) @@ -243,7 +261,7 @@ class CompressedTensors24(CompressedTensorsScheme): :return: The output tensor of the layer """ if self.quantized: - scale = getattr(layer, 'input_scale', None) + scale = getattr(layer, "input_scale", None) if self.weights_dtype == torch.int8: ops_output = ops.scaled_int8_quant(x, scale=scale) @@ -286,12 +304,16 @@ class CompressedTensors24(CompressedTensorsScheme): if not is_8_bits: raise ValueError("Cutlass only supports 8-bit quantization") - if (self.weight_quant.type == QuantizationType.FLOAT - and self.input_quant.type == QuantizationType.FLOAT): + if ( + self.weight_quant.type == QuantizationType.FLOAT + and self.input_quant.type == QuantizationType.FLOAT + ): return torch.float8_e4m3fn - if (self.weight_quant.type == QuantizationType.INT - and self.input_quant.type == QuantizationType.INT): + if ( + self.weight_quant.type == QuantizationType.INT + and self.input_quant.type == QuantizationType.INT + ): return torch.int8 raise ValueError("Quantization type not supported by Cutlass") @@ -317,7 +339,7 @@ class CompressedTensors24(CompressedTensorsScheme): :param bitmask: The 2:4 bitmask associated with the compressed weights, representing the positions of non-zero elements in the compressed tensor. - :param layer: The layer whose weights need to be processed after + :param layer: The layer whose weights need to be processed after loading. :return: The decompressed 2:4 sparse weight tensor. 
""" @@ -343,14 +365,16 @@ class CompressedTensors24(CompressedTensorsScheme): if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) split_bitmask = torch.split(bitmask, layer.logical_widths) - split_shape = [(out, layer.input_size_per_partition) - for out in layer.logical_widths] + split_shape = [ + (out, layer.input_size_per_partition) for out in layer.logical_widths + ] if split_weights: decompressed_shards = [ _process_split(compressed_weight, shape, bitmask) for compressed_weight, shape, bitmask in zip( - split_weights, split_shape, split_bitmask) + split_weights, split_shape, split_bitmask + ) ] decompressed = combine_shards(decompressed_shards) else: @@ -362,5 +386,6 @@ class CompressedTensors24(CompressedTensorsScheme): layer.input_size_per_partition, ), bitmask=bitmask, - )) + ) + ) return decompressed diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index a5d48f2356744..688621cbf79af 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -11,7 +11,7 @@ __all__ = ["CompressedTensorsScheme"] class CompressedTensorsScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass + Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by CompressedTensors. """ @@ -26,20 +26,21 @@ class CompressedTensorsScheme(ABC): @abstractmethod def create_weights(self, *args, **kwargs): """ - Weight creation for the particular scheme. Inputs to this function + Weight creation for the particular scheme. Inputs to this function """ raise NotImplementedError @abstractmethod - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]): + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): """ - Run the forward pass for the particular scheme. This is where + Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. 
:param x: input to the layer :param bias: bias parameter diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 3f3e7668fcf74..af06418c959da 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -8,13 +8,18 @@ from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] @@ -25,11 +30,7 @@ W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): - - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None): + def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): self.strategy = strategy self.group_size = group_size self.tile_size = 16 @@ -37,13 +38,13 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}. " - f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}" + ) self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] if self.strategy == "group" and self.group_size is None: - raise ValueError( - "group_size must be given when using strategy group") + raise ValueError("group_size must be given when using strategy group") @classmethod def get_min_capability(cls) -> int: @@ -52,18 +53,20 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # required by torch.compile to be torch.nn.Parameter - layer.weight_packed = Parameter(layer.weight_packed.data, - requires_grad=False) - layer.scale_packed = Parameter(layer.scale_packed.data, - requires_grad=False) + layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False) layer.meta = Parameter(layer.meta.data, requires_grad=False) - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): assert params_dtype == torch.float16, ( "float16 is required for marlin24 compressed models. 
Set dtype=torch.float16" # noqa: E501 ) @@ -71,55 +74,59 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) - qweight = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // self.tile_size // 2, - output_size_per_partition * self.tile_size // pack_factor, - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=pack_factor, - marlin_tile_size=self.tile_size, - weight_loader=weight_loader) + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader, + ) - input_groups = (1 if self.group_size is None else - input_size_per_partition // self.group_size) + input_groups = ( + 1 + if self.group_size is None + else input_size_per_partition // self.group_size + ) weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( input_groups, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if self.group_size is not None: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) else: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) - weight_shape = BasevLLMParameter(data=torch.empty(2, - dtype=torch.int64), - weight_loader=weight_loader) + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) - meta = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader) + meta = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", qweight) layer.register_parameter("weight_shape", weight_shape) @@ -127,16 +134,17 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): layer.register_parameter("meta", meta) max_workspace_size = ( - output_size_per_partition // - GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N + ) * GPTQ_MARLIN_24_MAX_PARALLEL - workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), - requires_grad=False) + workspace = Parameter( + torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False + ) layer.workspace = workspace - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: qweight = layer.weight_packed meta = layer.meta scales = layer.scale_packed @@ -148,11 +156,19 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = 
ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, self.quant_type, size_m, - size_n, size_k) + output_2d = ops.gptq_marlin_24_gemm( + x_2d, + qweight, + meta, + scales, + workspace, + self.quant_type, + size_m, + size_n, + size_k, + ) - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index 96dccf04d490f..a96f51538b38c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -6,18 +6,22 @@ import torch from torch.nn.parameter import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW4A16Fp4"] class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): - def __init__(self, has_input_global_scale: bool = False): self.has_input_global_scale = has_input_global_scale self.group_size = 16 @@ -27,49 +31,59 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): # dont restrict as emulations return 80 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=torch.uint8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", weight) # Global Weight Scale weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_global_scale", weight_global_scale) # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = 
GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) if self.has_input_global_scale: input_global_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader) + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: @@ -81,25 +95,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): # Rename weight_global_scale to weight_scale_2 that marlin expects # Note: ct stores the inverse of what is expected by the marlin kernel layer.weight_scale_2 = Parameter( - 1 / layer.weight_global_scale.max().to(torch.float32), - requires_grad=False) + 1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) del layer.weight_global_scale if self.has_input_global_scale: layer.input_global_scale = torch.nn.Parameter( - layer.input_global_scale.data, requires_grad=False) + layer.input_global_scale.data, requires_grad=False + ) prepare_fp4_layer_for_marlin(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return apply_fp4_marlin_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_scale_2=layer.weight_scale_2, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 49d76bbeaa3a1..676f4de6ee7b1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -9,14 +9,17 @@ import vllm.envs as envs from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 - run_nvfp4_emulations) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - swizzle_blockscale) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + run_nvfp4_emulations, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, 
has_flashinfer logger = init_logger(__name__) @@ -25,13 +28,24 @@ __all__ = ["CompressedTensorsW4A4Fp4"] class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): - def __init__(self): if envs.VLLM_USE_TRTLLM_FP4_GEMM: assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" self.backend = "flashinfer-trtllm" + logger.info_once("Using flashinfer-trtllm for FP4") + elif envs.VLLM_USE_FBGEMM: + self.backend = "fbgemm" + try: + import fbgemm_gpu # noqa: F401 + except ImportError as exc: + raise ImportError( + "Backend fbgemm requires fbgemm.f4f4bf16 operator, " + "Please install with: pip install fbgemm-gpu-genai" + ) from exc + logger.info_once("Using FGBEMM-GPU-GENAI for FP4") elif has_flashinfer(): self.backend = "flashinfer-cutlass" + logger.info_once("Using flashinfer-cutlass for FP4") else: self.backend = "cutlass" self.group_size = 16 @@ -42,58 +56,67 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return 80 return 100 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=torch.uint8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", weight) # Global Weight Scale weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_global_scale", weight_global_scale) # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) input_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: - global_input_scale = layer.input_global_scale.max().to(torch.float32) - layer.input_global_scale = Parameter(global_input_scale, - requires_grad=False) + layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) layer.weight_global_scale = Parameter( - layer.weight_global_scale.max().to(torch.float32), - requires_grad=False) + 
layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) if self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. @@ -106,38 +129,43 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): weight_scale = layer.weight_scale.data epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), - epilogue_tile_m) - weight_scale = (shuffle_matrix_sf_a(weight_scale.view( - torch.uint8), epilogue_tile_m).reshape( - weight_scale.shape).view(torch.float8_e4m3fn)) + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) - layer.weight_packed = Parameter(layer.weight_packed.data, - requires_grad=False) + if self.backend == "fbgemm": + swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) + layer.weight_packed = Parameter( + layer.weight_packed.data, requires_grad=False + ) layer.alpha = Parameter( 1 / (layer.input_global_scale * layer.weight_global_scale), - requires_grad=False) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + requires_grad=False, + ) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if envs.VLLM_USE_NVFP4_CT_EMULATIONS: out = run_nvfp4_emulations( x=x, input_global_scale=layer.input_global_scale, weight=layer.weight_packed, - weight_scale_swizzled=layer.weight_scale_swizzled, - weight_global_scale=layer.weight_global_scale) + weight_scale_swizzled=layer.weight_scale, + weight_global_scale=layer.weight_global_scale, + ) if bias is not None: out = out + bias return out @@ -148,12 +176,27 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) - mm_args = (x_fp4, layer.weight_packed, x_blockscale, - layer.weight_scale_swizzled, layer.alpha, output_dtype) + mm_args = ( + x_fp4, + layer.weight_packed, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) if self.backend == "flashinfer-trtllm": out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") elif self.backend == "flashinfer-cutlass": out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + elif self.backend == "fbgemm": + out = torch.ops.fbgemm.f4f4bf16( + x_fp4, + layer.weight_packed, + x_blockscale.view(-1).view(torch.uint8), + layer.weight_scale, + layer.alpha, + use_mx=False, + ).to(output_dtype) else: out = cutlass_scaled_fp4_mm(*mm_args) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 0000000000000..59d99e1e1c907 --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, + choose_mp_linear_kernel, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None, + ): + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError( + "W4A8 kernels require group quantization with group size 128" + ) + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}" + ) + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx, + out_type=params_dtype, + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. 
+ group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = input_size != input_size_per_partition + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel + ) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter( + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + ) + + # TODO(czhu): allocate the packed fp8 scales memory here? + # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": weight_loader, + "data": torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=torch.float8_e4m3fn, + ), + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) + + # per-channel scales + weight_chan_scale = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("weight_chan_scale", weight_chan_scale) + + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx", + ) + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
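For the parameters registered above, the shape bookkeeping is: with 4-bit weights, `pack_factor = 32 // 4 = 8` int4 values fit in one int32 word, and with the required group size of 128 there is one fp8 group scale per 128 input channels per output row. A small sketch of that arithmetic with made-up partition sizes (purely illustrative, not the kernel's physical layout):

num_bits = 4
pack_factor = 32 // num_bits        # 8 int4 values per int32 word
group_size = 128                    # the only group size this scheme accepts

# Hypothetical per-rank partition sizes, chosen only for illustration.
input_size_per_partition = 4096
output_size_per_partition = 6144

# Packed weight: one int32 column for every 8 input channels.
weight_packed_shape = (
    output_size_per_partition,
    input_size_per_partition // pack_factor,
)  # -> (6144, 512)

# Group scales: one fp8 (e4m3) scale per 128 input channels, per output row.
weight_scale_shape = (
    output_size_per_partition,
    input_size_per_partition // group_size,
)  # -> (6144, 32)

print(weight_packed_shape, weight_scale_shape)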
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py index f1fca85508a6b..61a9f6b75cb13 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py @@ -7,12 +7,17 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - ModelWeightParameter) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + ModelWeightParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -27,12 +32,14 @@ W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsW4A8Int(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None, - is_static_input_scheme: bool = False, - input_symmetric: bool = True): + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + is_static_input_scheme: bool = False, + input_symmetric: bool = True, + ): self.strategy = strategy self.group_size = -1 if group_size is None else group_size self.is_static_input_scheme = is_static_input_scheme @@ -41,42 +48,53 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme): if num_bits not in W4A8_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}." 
- f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}" + ) self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] @classmethod def get_min_capability(cls) -> int: return 1 - def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) - row_parallel = (input_size != input_size_per_partition) + row_parallel = input_size != input_size_per_partition # Compute effective group_size if self.group_size == -1: - effective_group_size = (input_size_per_partition - if row_parallel else input_size) + effective_group_size = ( + input_size_per_partition if row_parallel else input_size + ) else: effective_group_size = self.group_size # Ensure group_size divides input_size_per_partition assert input_size_per_partition % effective_group_size == 0, ( f"input_size_per_partition {input_size_per_partition}" - f" not divisible by group_size {effective_group_size}") + f" not divisible by group_size {effective_group_size}" + ) # Determine scale partitioning - is_channelwise = (self.group_size == -1) - repeat_scales = (is_channelwise and row_parallel) + is_channelwise = self.group_size == -1 + repeat_scales = is_channelwise and row_parallel partition_scales = not repeat_scales mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=(input_size_per_partition, - output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_type, act_type=params_dtype, group_size=effective_group_size, @@ -86,50 +104,50 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme): kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW4A8Int", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsW4A8Int", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) scales_and_zp_size = input_size_per_partition // effective_group_size - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) weight_scale_args = { - "weight_loader": - weight_loader, - "data": - torch.empty(output_size_per_partition, - scales_and_zp_size, - dtype=params_dtype) + "weight_loader": weight_loader, + "data": torch.empty( + output_size_per_partition, scales_and_zp_size, dtype=params_dtype + ), } if partition_scales: - weight_scale = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) else: - weight_scale = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + weight_scale 
= ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="weight_packed", - w_s_param_name="weight_scale", - w_zp_param_name=None, - w_gidx_param_name=None) + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name=None, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 01a87a0888996..709d2538e6ad0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -7,24 +7,27 @@ import torch from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + convert_to_channelwise, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW8A16Fp8"] -SUPPORTED_STRATEGIES = [ - QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR -] +SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR] class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): - def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme @@ -39,31 +42,36 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): # we expand each scale to its shard's channels. 
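Conceptually, that expansion takes the single per-tensor scale loaded for each shard of a fused module (e.g. QKV) and repeats it across that shard's output channels, so the Marlin path sees one channelwise scale tensor. The real work is done by `convert_to_channelwise`; the following is only a rough sketch of the idea, with made-up shard widths:

import torch

per_shard_scales = torch.tensor([0.02, 0.03, 0.03])  # one scale per fused shard
logical_widths = [1024, 256, 256]                     # e.g. Q, K, V output widths

# Repeat each shard's scale over that shard's output channels.
channelwise = torch.repeat_interleave(
    per_shard_scales, torch.tensor(logical_widths)
).unsqueeze(-1)                                       # [sum(widths), 1]

assert channelwise.shape == (sum(logical_widths), 1)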
def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: - ws_channelwise = convert_to_channelwise(layer.weight_scale, - layer.logical_widths) - layer.weight_scale = torch.nn.Parameter(ws_channelwise, - requires_grad=False) + ws_channelwise = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False) else: # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) # Weights must be transposed for marlin - layer.weight = torch.nn.Parameter(layer.weight.t(), - requires_grad=False) + layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False) if self.is_static_input_scheme: # required by torch.compile to be torch.nn.Parameter - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) + layer.input_scale = torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) prepare_fp8_layer_for_marlin(layer) - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -72,50 +80,59 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): layer.weight_block_size = None # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) elif self.strategy == QuantizationStrategy.TENSOR: - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) else: raise ValueError( f"Unsupported weight strategy={self.strategy}, " - f"supported strategies are {SUPPORTED_STRATEGIES}") + f"supported strategies are {SUPPORTED_STRATEGIES}" + ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE (to deal with converted checkpoints) if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = 
PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return apply_fp8_marlin_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index d984e89d9e02a..902c9c7bde97b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,153 +4,197 @@ from typing import Callable, Optional import torch -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + CompressedTensorsScheme, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, + check_aiter_fp8_linear_support, + create_fp8_input_scale, + create_fp8_scale_parameter, + create_fp8_weight_parameter, + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, + process_fp8_weight_channel_strategy, + process_fp8_weight_tensor_strategy, + validate_fp8_block_shape, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, - requantize_with_max_scale) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) -from vllm.platforms import current_platform + Fp8LinearOp, + cutlass_block_fp8_supported, + maybe_create_device_identity, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW8A8Fp8"] +strategy_to_parameter_type = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): + self.weight_quant = weight_quant + self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - 
self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) + + self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + self.act_q_group_shape = ( + GroupShape.PER_TENSOR + if is_static_input_scheme + else GroupShape.PER_TOKEN + ) + + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + + if self.weight_block_size is not None: + assert not self.is_static_input_scheme + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape, + ) @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 - def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor - if self.strategy == QuantizationStrategy.TENSOR: - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) - - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=max_w_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - - # If channelwise, scales are already lined up, so just transpose. 
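The removed per-tensor branch above requantizes fused shards (e.g. QKV) so that a single scale can be passed to the kernel; in the new code this logic lives in process_fp8_weight_tensor_strategy. A rough sketch of the requantization itself, with hypothetical shapes and scales (not the actual requantize_with_max_scale helper):

import torch

logical_widths = [256, 64, 64]                 # hypothetical fused shard widths
scales = torch.tensor([0.05, 0.02, 0.04])      # one FP8 scale per shard
weight = torch.randn(sum(logical_widths), 128).to(torch.float8_e4m3fn)

max_scale = scales.max()
requantized = torch.empty_like(weight)
start = 0
for width, scale in zip(logical_widths, scales):
    # Dequantize the shard with its own scale, requantize with the max scale.
    shard = weight[start:start + width].to(torch.float32) * scale
    requantized[start:start + width] = (shard / max_scale).to(torch.float8_e4m3fn)
    start += width
# The GEMM can now run per-tensor with the single max_scale.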
- elif self.strategy == QuantizationStrategy.CHANNEL: - weight = layer.weight - - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - else: - weight_scale = layer.weight_scale.data - - layer.weight = Parameter(weight.t(), requires_grad=False) - # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - - else: - raise ValueError(f"Unknown quantization strategy {self.strategy}") - - # INPUT SCALE - if self.is_static_input_scheme and hasattr(layer, 'input_scale'): - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) - else: - layer.input_scale = None - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): maybe_create_device_identity() output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes + layer.weight_block_size = None + layer.orig_dtype = params_dtype + + if self.strategy == QuantizationStrategy.BLOCK: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + # Validate block quantization shapes + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = create_fp8_weight_parameter( + output_size_per_partition, input_size_per_partition, weight_loader + ) layer.register_parameter("weight", weight) # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min + weight_scale = create_fp8_scale_parameter( + strategy_to_parameter_type[self.strategy], + output_partition_sizes, + input_size_per_partition, + layer.weight_block_size, + weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - input_scale[:] = torch.finfo(torch.float32).min + input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: 
torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def process_weights_after_loading(self, layer) -> None: + if self.strategy == QuantizationStrategy.TENSOR: + weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( + layer.weight, + layer.weight_scale, + layer.logical_widths, + getattr(layer, "input_scale", None), + ) + weight = weight.t() - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + elif self.strategy == QuantizationStrategy.CHANNEL: + weight, weight_scale, input_scale = process_fp8_weight_channel_strategy( + layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) + ) + weight = weight.t() + + elif self.strategy == QuantizationStrategy.BLOCK: + assert self.is_static_input_scheme is False + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale + ) + input_scale = None + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # required by torch.compile to be torch.nn.Parameter + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + if input_scale is not None: + layer.input_scale = Parameter(input_scale.data, requires_grad=False) + + # INPUT SCALE + if self.is_static_input_scheme and hasattr(layer, "input_scale"): + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) + else: + layer.input_scale = None + + if self.strategy == QuantizationStrategy.BLOCK: + maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.weight_block_size is not None: + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) + + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6189f0609d85d..70316a7553ca3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -8,13 +8,18 @@ from compressed_tensors.quantization import QuantizationStrategy from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + ScaledMMLinearLayerConfig, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) logger = init_logger(__name__) @@ -22,8 +27,9 @@ logger = init_logger(__name__) class 
CompressedTensorsW8A8Int8(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, strategy: str, is_static_input_scheme: bool, - input_symmetric: bool): + def __init__( + self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool + ): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme self.input_symmetric = input_symmetric @@ -33,56 +39,61 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): # turing and up return 75 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): layer.logical_widths = output_partition_sizes scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, - input_symmetric=self.input_symmetric) + input_symmetric=self.input_symmetric, + ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader + ) layer.register_parameter("input_scale", input_scale) if not self.input_symmetric: @@ -90,22 +101,25 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): # as the weights # AZP loaded as int8 but used as int32 input_zero_point = BasevLLMParameter( - data=torch.empty(1, dtype=torch.int8), - weight_loader=weight_loader) + data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader + ) layer.register_parameter("input_zero_point", 
input_zero_point) - self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj") + self.kernel = kernel_type( + c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj", + ) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 74787603e0029..188fc15fd9485 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -8,29 +8,29 @@ from compressed_tensors.quantization import ActivationOrdering from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_repeat_scales_on_all_ranks) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) -# yapf: enable + marlin_repeat_scales_on_all_ranks, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) __all__ = ["CompressedTensorsWNA16"] -WNA16_SUPPORTED_TYPES_MAP = { - 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128 -} +WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128} WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) @@ -38,13 +38,14 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsWNA16(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None, - symmetric: Optional[bool] = True, - actorder: Optional[ActivationOrdering] = None): - + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None, + ): self.pack_factor = 32 // num_bits self.strategy = strategy self.symmetric 
= symmetric @@ -52,55 +53,67 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": - raise ValueError("Marlin kernels require group quantization or " - "channelwise quantization, but found no group " - "size and strategy is not channelwise.") + raise ValueError( + "Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise." + ) if num_bits not in WNA16_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}. " - f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}" + ) - self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] - if not self.symmetric else - WNA16_SUPPORTED_TYPES_MAP[num_bits]) + self.quant_type = ( + WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric + else WNA16_SUPPORTED_TYPES_MAP[num_bits] + ) @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_type, act_type=params_dtype, group_size=self.group_size, zero_points=not self.symmetric, - has_g_idx=self.has_g_idx + has_g_idx=self.has_g_idx, ) kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsWNA16", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # If group_size is -1, we are in channelwise case. 
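The shape bookkeeping in the hunk that follows (packed weight width and scales_and_zp_size) reduces to integer arithmetic; a short sketch with hypothetical partition sizes:

num_bits = 4
pack_factor = 32 // num_bits                  # 8 int4 values per int32 word
output_size_per_partition = 4096
input_size_per_partition = 11008
group_size = 128

# Packed weight is stored as int32 with the input dim divided by pack_factor.
packed_weight_shape = (output_size_per_partition,
                       input_size_per_partition // pack_factor)
# When scales are partitioned along the input dim, one scale per group.
scales_and_zp_size = input_size_per_partition // group_size

assert packed_weight_shape == (4096, 1376)
assert scales_and_zp_size == 86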
group_size = self.group_size if self.group_size != -1 else input_size - row_parallel = (input_size != input_size_per_partition) + row_parallel = input_size != input_size_per_partition partition_scales = not marlin_repeat_scales_on_all_ranks( - self.has_g_idx, self.group_size, row_parallel) + self.has_g_idx, self.group_size, row_parallel + ) scales_and_zp_size = input_size // group_size @@ -108,65 +121,65 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): assert input_size_per_partition % group_size == 0 scales_and_zp_size = input_size_per_partition // group_size - weight = PackedvLLMParameter(input_dim=1, - output_dim=0, - weight_loader=weight_loader, - packed_factor=self.pack_factor, - packed_dim=1, - data=torch.empty( - output_size_per_partition, - input_size_per_partition // - self.pack_factor, - dtype=torch.int32, - )) + weight = PackedvLLMParameter( + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + ) weight_scale_args = { - "weight_loader": - weight_loader, - "data": - torch.empty( + "weight_loader": weight_loader, + "data": torch.empty( output_size_per_partition, scales_and_zp_size, dtype=params_dtype, - ) + ), } zeros_args = { - "weight_loader": - weight_loader, - "data": - torch.zeros( + "weight_loader": weight_loader, + "data": torch.zeros( output_size_per_partition // self.pack_factor, scales_and_zp_size, dtype=torch.int32, - ) + ), } if not partition_scales: - weight_scale = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) if not self.symmetric: - qzeros = PackedColumnParameter(output_dim=0, - packed_dim=0, - packed_factor=self.pack_factor, - **zeros_args) + qzeros = PackedColumnParameter( + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args, + ) else: - weight_scale = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) if not self.symmetric: - qzeros = PackedvLLMParameter(input_dim=1, - output_dim=0, - packed_dim=0, - packed_factor=self.pack_factor, - **zeros_args) + qzeros = PackedvLLMParameter( + input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args, + ) # A 2D array defining the original shape of the weights # before packing - weight_shape = BasevLLMParameter(data=torch.empty(2, - dtype=torch.int64), - weight_loader=weight_loader) + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) @@ -177,25 +190,30 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): # group index (for activation reordering) if self.has_g_idx: - weight_g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + weight_g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_g_idx", weight_g_idx) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="weight_packed", - w_s_param_name="weight_scale", - w_zp_param_name="weight_zero_point", - 
w_gidx_param_name="weight_g_idx") + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx", + ) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/adapter_commons/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py similarity index 100% rename from vllm/adapter_commons/__init__.py rename to vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py new file mode 100644 index 0000000000000..edd2706b470fd --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Generator +from itertools import accumulate +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformLocation, + TransformScheme, +) +from compressed_tensors.utils import is_match + +from vllm.model_executor.layers.linear import ( + WEIGHT_LOADER_V2_SUPPORTED, + LinearMethodBase, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsScheme, +) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 + HadamardTransform, +) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple, +) + + +class CompressedTensorsLinearTransformMethod(LinearMethodBase): + """ + Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds + input and output transforms to either side of the original apply method + """ + + @classmethod + def from_schemes( + cls, + quant_method: LinearMethodBase, + quant_scheme: Optional[CompressedTensorsScheme], + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple], + ) -> "CompressedTensorsLinearTransformMethod": + from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501 + QutlassNvFP4LinearMethod, + is_qutlass_fp4_scheme, + ) + + assert input_tfms or output_tfms + + if is_qutlass_fp4_scheme(quant_scheme, input_tfms): + return QutlassNvFP4LinearMethod(quant_method, input_tfms, output_tfms) + + # hadacore or dense gemm is selected by Transform module + + return cls(quant_method, input_tfms, output_tfms) + + def __init__( + self, + quant_method: LinearMethodBase, + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple], + ): + self.quant_method = quant_method + self.input_tfms = input_tfms + self.output_tfms = output_tfms + + self.input_transform: 
Optional[HadamardTransform] = None + self.output_transform: Optional[HadamardTransform] = None + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # get weight loader for transforms + weight_loader: Callable = extra_weight_attrs.get("weight_loader") # type: ignore[assignment] + + # HACK: UnquantizedLinearMethod does not support weight loader v2, but + # transforms (specifically SharedWeightParameter) requires + # weight loader v2. Until UnquantizedLinearMethod supports v2, we must + # hack around this by getting weight loader v1 so ULM can load correctly + quant_method_name = self.quant_method.__class__.__name__ + if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: + weight_loader_v1 = layer.weight_loader + extra_weight_attrs["weight_loader"] = weight_loader_v1 + + self.quant_method.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + input_size=input_size, + output_size=output_size, + params_dtype=params_dtype, + **extra_weight_attrs, + ) + + # validate schemes + num_partitions = len(output_partition_sizes) + self._validate_tfm_schemes(num_partitions) + + # create submodules for weight loading + if len(self.input_tfms) > 0: + scheme_name = list(self.input_tfms.values())[0].scheme_name + location = list(self.input_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform( + self.input_tfms, + layer, + weight_loader, + input_size_per_partition, + output_partition_sizes, + ) + layer.register_module(transform_name, transform) + self.input_transform = transform + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform( + self.output_tfms, + layer, + weight_loader, + input_size_per_partition, + output_partition_sizes, + ) + layer.register_module(transform_name, transform) + self.output_transform = transform + + # compute partition ranges for slicing activations + starts = [0] + list(accumulate(output_partition_sizes))[:-1] + self.partition_ranges = list(zip(starts, output_partition_sizes)) + + def process_weights_after_loading(self, layer): + self.quant_method.process_weights_after_loading(layer) + + for submodule in layer.children(): + if isinstance(submodule, HadamardTransform): + submodule.process_weights_after_loading() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.input_transform is not None: + x = self.input_transform(x) + + assert bias is None + x = self.quant_method.apply(layer, x, bias) + + # In most cases, input transforms are preferred over output transforms + # (@ksayers): confirm that this is done concurrently + if self.output_transform is not None: + for part_id, (start, length) in enumerate(self.partition_ranges): + x[:, start : start + length] = self.output_transform( + x[:, start : start + length].contiguous(), part_id=part_id + ) + + return x + + def _validate_tfm_schemes(self, num_partitions: int): + if len(self.input_tfms) > 0: + if 0 not in self.input_tfms: + raise ValueError("Must have same input") + + for part_index in range(num_partitions): + if self.input_tfms[part_index] != 
self.input_tfms[0]: + raise ValueError("Must have same input") + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + + for tfm in self.output_tfms.values(): + if tfm.scheme_name != scheme_name: + raise ValueError("Must have same scheme name") + if tfm.args.location != location: + raise ValueError("Must have same location") + + return self.input_tfms, self.output_tfms + + +def get_linear_transform_schemes( + layer: torch.nn.Module, + layer_name: str, + transform_config: Optional[TransformConfig], + packed_modules_mapping: dict[str, list[str]], +) -> tuple[ + dict[int, TransformTuple], dict[int, TransformTuple] +]: # [input_transform, [output_transform, ...]] + # there can only be one transform input scheme per (fused) module + input_tfms = {} + output_tfms = {} + + partition_names = get_layer_partition_names(layer_name, packed_modules_mapping) + + for scheme_name, scheme, args in get_schemes_args(transform_config): + for part_index, part_name in enumerate(partition_names): + if ( + is_match(part_name, layer, args.targets, args.ignore) + and args.is_online() + ): + if args.location == TransformLocation.INPUT: + input_tfms[part_index] = TransformTuple(scheme_name, scheme, args) + + elif args.location == TransformLocation.OUTPUT: + output_tfms[part_index] = TransformTuple(scheme_name, scheme, args) + + else: + raise ValueError( + f"Cannot apply `{args.location}` transform to `{layer_name}`" + ) + + return (input_tfms, output_tfms) + + +def get_schemes_args( + transform_config: Optional[TransformConfig], +) -> Generator[tuple[str, TransformScheme, TransformArgs]]: + if transform_config is None: + return + + for scheme_name, scheme in transform_config.config_groups.items(): + for args in scheme.apply: + yield (scheme_name, scheme, args) + + +def get_layer_partition_names( + layer_name: str, packed_modules_mapping: dict[str, list[str]] +) -> list[str]: + """ + Get all partition names associated with this layer. + Names are returned in order of their partition indices. 
+ + ```python + mapping = {"gate_up_proj", "gate_proj", "up_proj"} + + assert get_layer_partition_names("mlp.gate_up_proj", mapping) == [ + "gate_proj", + "up_proj", + ] + assert get_layer_partition_names("mlp.down_proj", mapping) == ["down_proj"]""" + for fused_suffix, part_suffixes in packed_modules_mapping.items(): + if layer_name.endswith(fused_suffix): + return [ + layer_name.removesuffix(fused_suffix) + part_suffix + for part_suffix in part_suffixes + ] + + return [layer_name] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py new file mode 100644 index 0000000000000..ecd798257fce2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Hashable +from typing import Callable + +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformLocation, + TransformScheme, +) +from torch import Tensor + +import vllm._custom_ops as ops +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple, +) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.parameter import SharedWeightParameter + + +class HadamardTransform(torch.nn.Module): + """ + Class which handles weight loading, postprocessing, and application of + transforms. 
Meant to be used with `CompressedTensorsLinearTransformMethod` + and attention transforms method (not implemented yet) + """ + + transforms: dict[int, TransformTuple] # info parsed from transforms config + weight: SharedWeightParameter # container for shared tensors + + scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) + + def __init__( + self, + transforms: dict[int, TransformTuple], + layer: torch.nn.Module, + weight_loader: Callable, + input_size_per_partition: int, + output_partition_sizes: list[int], + ): + super().__init__() + self.transforms = transforms + self.scales = {} + + if get_tensor_model_parallel_world_size() > 1: + raise NotImplementedError( + "Online transforms with tensor parallelism is not supported" + ) + + # Similar to row/col parallel params, but tensors are separate + # to allow for loading with shared memory + self.weight = SharedWeightParameter(weight_loader=weight_loader) + + # create shared partition data for each partition of the original weight + input_size = input_size_per_partition + for part_index, (_scheme_name, scheme, args) in self.transforms.items(): + output_size = output_partition_sizes[part_index] + weight_size = self._get_weight_size( + layer, scheme, args, input_size, output_size + ) + + data_key = self._get_data_key(scheme, weight_size) + self.weight.add_partition( + part_index, + data_key, + size=(weight_size, weight_size), + dtype=scheme.precision, + ) + + # validate that shared tensors and schemes are correct + self._validate_input_transforms() + + def process_weights_after_loading(self): + for part_id in self.weight.partitions: + data = self.weight.partitions[part_id].data + + # required by torch.compile + self.weight.process_weights_after_loading() + + # precompute scale as a runtime multiply, not division + # do not fold into weight in order to utilize FWHT + self.scales[part_id] = 1 / math.sqrt(data.size(0)) + + # FUTURE: avoid runtime transpose by processing weights + # prior to apply + + def forward(self, value: Tensor, part_id: int = 0) -> Tensor: + if part_id not in self.weight.partitions: + return value + + # use hadacore if possible + if self.transforms[part_id].scheme.type == "hadamard": + if self.transforms[part_id].scheme.head_dim is not None: + weight_size = self.transforms[part_id].scheme.head_dim + value = value.unflatten(-1, (-1, weight_size)) + value = ops.hadacore_transform(value) + value = value.flatten(-2, -1) + + return value + + # sylvester transforms are symmetric, inv => transpose => original + return ops.hadacore_transform(value) + + # fall back to dense + else: + weight = self.weight.partitions[part_id] + weight = ( + weight if self.transforms[part_id].args.inverse else weight.T + ) # linear := x(W.T) + scale = self.scales[part_id] + + if self.transforms[part_id].scheme.head_dim is not None: + value = value.unflatten(-1, (-1, weight.size(0))) + value = ( + dispatch_unquantized_gemm()( + self, value.to(weight.dtype), weight, None + ).to(value.dtype) + * scale + ) + value = value.flatten(-2, -1) + + return value + + return ( + dispatch_unquantized_gemm()( + self, value.to(weight.dtype), weight, None + ).to(value.dtype) + * scale + ) + + def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable: + return (id(scheme), weight_size) + + def _get_weight_size( + self, + layer: torch.nn.Module, + scheme: TransformScheme, + args: TransformArgs, + input_size: int, + output_size: int, + ) -> int: + if scheme.head_dim is not None: + return scheme.head_dim + + if isinstance(layer, 
LinearBase): + if args.location == TransformLocation.INPUT: + return input_size + + elif args.location == TransformLocation.OUTPUT: + return output_size + + elif isinstance(layer, VocabParallelEmbedding): + if args.location == TransformLocation.INPUT: + return output_size + + elif args.location == TransformLocation.OUTPUT: + return input_size + + raise ValueError() + + def _validate_input_transforms(self): + assert len(self.transforms) > 0 + location = list(self.transforms.values())[0].args.location + + if location == TransformLocation.INPUT: + first_data = self.weight.partitions[0].data + for partition in self.weight.partitions.values(): + if partition.data.data_ptr() != first_data.data_ptr(): + raise ValueError("") diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/__init__.py similarity index 100% rename from vllm/attention/backends/mla/__init__.py rename to vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/__init__.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py new file mode 100644 index 0000000000000..b800c5f5d436a --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, +) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod, + TransformTuple, +) + +__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"] + + +def is_qutlass_fp4_scheme( + quant_scheme: Optional[CompressedTensorsScheme], + input_tfms: dict[int, TransformTuple], +) -> bool: + return ( + isinstance(quant_scheme, (CompressedTensorsW4A4Fp4,)) + and len(input_tfms) == 1 + and input_tfms[0].scheme.head_dim == quant_scheme.group_size + ) + + +class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod): + def create_weights( + self, + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ): + # initializes fp4 qparams + assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4,)) + ret = super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) + + assert self.input_transform is not None + assert len(self.input_transform.weight) == 1 + assert self.input_transform.weight[0].size(0) == layer.scheme.group_size + + return ret + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py new file mode 100644 index 0000000000000..2f353de1e6a74 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple + +from compressed_tensors.transform import TransformArgs, TransformScheme + +__all__ = ["TransformTuple"] + + +class TransformTuple(NamedTuple): + scheme_name: str + scheme: TransformScheme + args: TransformArgs diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index d926b4c12db14..ed326197295dd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -17,13 +17,29 @@ def is_weak_contiguous(x: torch.Tensor): @triton.jit -def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, - M, N, K, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_SCALE_A: tl.constexpr, - BLOCK_SIZE_SCALE_B: tl.constexpr): +def scaled_mm_kernel( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -32,8 +48,7 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, pid_n = pid % num_pid_n accumulator_dtype = ACCUMULATOR_DTYPE - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) # NOTE: Some tensor inputs are so large, they will cause int32 overflow # so it is necessary to use tl.int64 for all the offsets, else SEGV will @@ -47,20 +62,22 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, masks_bn = offsets_bn < N offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) - offsets_a = (stride_am * offsets_am[:, None] + - stride_ak * offsets_k[None, :]) - offsets_b = (stride_bk * offsets_k[:, None] + - stride_bn * offsets_bn[None, :]) + offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create # appropriate offsets and masks for each case. Same goes for # BLOCK_SIZE_SCALE_B. 
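A plain-Python sketch of the offset logic this NOTE describes: when the scale is per-tensor the block size is 1 and every program reads element 0, otherwise each program reads its own block of rows (names here are illustrative, not Triton code):

def scale_a_offsets(block_size_scale_a: int, block_size_m: int, pid_m: int) -> list[int]:
    # Mirrors: tl.arange(0, BLOCK_SIZE_SCALE_A) + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
    return [
        i + (block_size_scale_a > 1) * pid_m * block_size_m
        for i in range(block_size_scale_a)
    ]

assert scale_a_offsets(1, 32, pid_m=3) == [0]                # per-tensor scale_a
assert scale_a_offsets(32, 32, pid_m=3)[:3] == [96, 97, 98]  # per-row scale_a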
- offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) + - (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M) + offsets_scale_am = ( + tl.arange(0, BLOCK_SIZE_SCALE_A) + + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + ) masks_scale_am = offsets_scale_am < M - offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) + - (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N) + offsets_scale_bn = ( + tl.arange(0, BLOCK_SIZE_SCALE_B) + + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + ) masks_scale_bn = offsets_scale_bn < N a_ptrs = a_ptr + offsets_a @@ -114,8 +131,7 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) offs_cm = offs_cm.to(tl.int64) offs_cn = offs_cn.to(tl.int64) - c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + - stride_cn * offs_cn[None, :]) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @@ -123,16 +139,18 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, # input - [M, K] # weight - [K, N] -def triton_scaled_mm(input: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None, - block_size_m: int = 32, - block_size_n: int = 32, - block_size_k: int = 32, - use_heuristic=True) -> torch.Tensor: +def triton_scaled_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True, +) -> torch.Tensor: M, K = input.shape N = weight.shape[1] @@ -144,17 +162,16 @@ def triton_scaled_mm(input: torch.Tensor, scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() - assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 - or scale_a.shape[0] == M) - assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 - or scale_b.shape[0] == N) + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N) assert out_dtype.is_floating_point assert bias is None or bias.is_floating_point() assert is_weak_contiguous(input) assert is_weak_contiguous(weight) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - N, META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) result = torch.empty((M, N), dtype=out_dtype, device=input.device) @@ -181,26 +198,28 @@ def triton_scaled_mm(input: torch.Tensor, # A = input, B = weight, C = result # A = M x K, B = K x N, C = M x N - scaled_mm_kernel[grid](input, - weight, - scale_a, - scale_b, - result, - bias, - M, - N, - K, - input.stride(0), - input.stride(1), - weight.stride(0), - weight.stride(1), - result.stride(0), - result.stride(1), - accumulator_dtype, - BLOCK_SIZE_M=block_size_m, - BLOCK_SIZE_N=block_size_n, - BLOCK_SIZE_K=block_size_k, - BLOCK_SIZE_SCALE_A=block_size_sa, - BLOCK_SIZE_SCALE_B=block_size_sb) + scaled_mm_kernel[grid]( + input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + 
accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb, + ) return result.to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 099d8613fc1a7..d8beaafff2ef1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -15,7 +15,7 @@ def is_activation_quantization_format(format: str) -> bool: CompressionFormat.naive_quantized.value, CompressionFormat.int_quantized.value, CompressionFormat.float_quantized.value, - CompressionFormat.nvfp4_pack_quantized.value + CompressionFormat.nvfp4_pack_quantized.value, ] return format in _ACTIVATION_QUANTIZATION_FORMATS @@ -23,7 +23,7 @@ def is_activation_quantization_format(format: str) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str] = tuple(), - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: if layer_name is None: return False @@ -49,7 +49,8 @@ def should_ignore_layer( should_ignore_layer = None for shard_name in shard_names: should_ignore_shard = check_equal_or_regex_match( - layer_name=shard_name, targets=ignore) + layer_name=shard_name, targets=ignore + ) # If shard_idx=0, set layer ignore to match shard. if should_ignore_layer is None: @@ -57,44 +58,43 @@ def should_ignore_layer( # If shard_idx=1+ confirm scheme matches prior shards. elif should_ignore_shard != should_ignore_layer: - raise ValueError(f"Found a different quantization schemes for " - f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) # Unfused layers like down_proj and o_proj will match # the safetensors checkpoint already. else: - should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, - targets=ignore) + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) assert should_ignore_layer is not None return should_ignore_layer -def check_equal_or_regex_match(layer_name: str, - targets: Iterable[str]) -> bool: +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: """ Checks whether a layer_name is exactly equal or a regex match for if target starts with 're:' to any target in list. """ - for target in targets: - if _is_equal_or_regex_match(layer_name, target): - return True - return False + return any(_is_equal_or_regex_match(layer_name, target) for target in targets) def find_matched_target( layer_name: Optional[str], module: Module, targets: Iterable[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> str: """ Helper function to look up which "target" in the compressed-tensors config that a layer corresponds to. Recall that a compressed-tensors configs has a concept of - config_groups, where each layer can be quantized with with a different + config_groups, where each layer can be quantized with a different scheme. 
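For reference, a minimal sketch of a config_groups entry and of the equal-or-regex matching applied to its targets (the group contents and layer names below are hypothetical):

import re

config_groups = {
    "group_0": {
        "targets": ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "Linear"],
        "weights": {"num_bits": 8, "strategy": "channel"},
    }
}

def equal_or_regex_match(layer_name: str, target: str) -> bool:
    # 're:'-prefixed targets are treated as regexes, everything else as exact names.
    if target.startswith("re:"):
        return re.match(target[3:], layer_name) is not None
    return layer_name == target

targets = config_groups["group_0"]["targets"]
assert any(equal_or_regex_match("model.layers.0.self_attn.q_proj", t) for t in targets)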
targets in each config_group will be a list of either layer names @@ -120,19 +120,21 @@ def find_matched_target( matched_target = ( _find_first_match(layer_name, targets) or _find_first_match(module.__class__.__name__, targets, True) - or _match_fused_layer(layer_name, targets, fused_mapping)) + or _match_fused_layer(layer_name, targets, fused_mapping) + ) if matched_target is None: raise ValueError( f"Unable to find matching target for {layer_name} in the " - "compressed-tensors config.") + "compressed-tensors config." + ) return matched_target -def _find_first_match(value: str, - targets: Iterable[str], - check_contains: bool = False) -> Optional[str]: +def _find_first_match( + value: str, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: """ Returns first element of target that matches value either exactly or as a regex after 're:'. If check_contains is set to True, @@ -144,16 +146,14 @@ def _find_first_match(value: str, """ for target in targets: - if _is_equal_or_regex_match(value, - target, - check_contains=check_contains): + if _is_equal_or_regex_match(value, target, check_contains=check_contains): return target return None -def _is_equal_or_regex_match(value: str, - target: str, - check_contains: bool = False) -> bool: +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: """ Checks whether a value is exactly equal or a regex match for target if target starts with 're:'. If check_contains is set to True, @@ -173,10 +173,12 @@ def _is_equal_or_regex_match(value: str, def _match_fused_layer( - layer_name: str, target_layers: Iterable[str], - fused_mapping: Mapping[str, list[str]]) -> Optional[str]: + layer_name: str, + target_layers: Iterable[str], + fused_mapping: Mapping[str, list[str]], +) -> Optional[str]: """ - Match a fused layer name to its corresponding individual layer in + Match a fused layer name to its corresponding individual layer in target_layers. 
Returns first value in fused_mapping which matches targets Implements an "all" matching strategy where a fused layer matches iff @@ -193,8 +195,7 @@ def _match_fused_layer( "model.layers.0.self_attn.v_proj"] """ # find layer_name in mapping - fused = next((key for key in fused_mapping if layer_name.endswith(key)), - None) + fused = next((key for key in fused_mapping if layer_name.endswith(key)), None) if fused is None: return None diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py deleted file mode 100644 index d26a932eddb2c..0000000000000 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging - -import torch - -from vllm.platforms import current_platform -from vllm.triton_utils import triton -from vllm.utils import direct_register_custom_op -from vllm.utils.deep_gemm import fp8_gemm_nt - -logger = logging.getLogger(__name__) - - -def prepare_block_fp8_matmul_inputs( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype = torch.float16, -) -> tuple[int, int, int, torch.Tensor]: - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - - assert A.shape[-1] == B.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - assert A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] - - M = A.numel() // A.shape[-1] - - assert B.ndim == 2 - assert B.is_contiguous() - assert Bs.ndim == 2 - N, K = B.shape - assert triton.cdiv(N, block_n) == Bs.shape[0] - assert triton.cdiv(K, block_k) == Bs.shape[1] - - C_shape = A.shape[:-1] + (N, ) - C = A.new_empty(C_shape, dtype=output_dtype) - - return M, N, K, C - - -def w8a8_block_fp8_matmul_deepgemm( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, - output_dtype) - # Deepgemm only supports output tensor type as bfloat16 - assert C.dtype == torch.bfloat16 - fp8_gemm_nt((A, As), (B, Bs), C) - return C - - -def w8a8_block_fp8_matmul_deepgemm_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, - output_dtype) - return C - - -direct_register_custom_op( - op_name="w8a8_block_fp8_matmul_deepgemm", - op_func=w8a8_block_fp8_matmul_deepgemm, - mutates_args=[], - fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, - dispatch_key=current_platform.dispatch_key, -) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 2922aef32939a..82a2103a19f33 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -9,16 +9,17 @@ import torch.nn.functional as F from packaging import version from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from 
vllm.model_executor.utils import set_weight_attrs class DeepSpeedFPConfig(QuantizationConfig): """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. - - Args: + + Args: weight_bits: the target quantization bits, 6 or 8. group_size: group size for quantizaiton, default to 128. """ @@ -37,11 +38,14 @@ class DeepSpeedFPConfig(QuantizationConfig): raise ValueError( "Currently, only 6-bit or 8-bit weight quantization are " f"supported for DeepSpeed FP quantizaiton, but got " - f"{self.weight_bits} bits.") + f"{self.weight_bits} bits." + ) def __repr__(self) -> str: - return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " - f"group_size={self.group_size}") + return ( + f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -72,8 +76,9 @@ class DeepSpeedFPConfig(QuantizationConfig): "quantize_config.json", ] - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["DeepSpeedFPLinearMethod"]: if isinstance(layer, LinearBase): return DeepSpeedFPLinearMethod(self) return None @@ -90,15 +95,17 @@ class DeepSpeedFPLinearMethod(LinearMethodBase): self.quant_config = quant_config self.weight = None - def create_weights(self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - weight_loader=None, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs, + ): del output_size del input_size output_size_per_partition = sum(output_partition_sizes) @@ -107,10 +114,13 @@ class DeepSpeedFPLinearMethod(LinearMethodBase): params_dtype=params_dtype, quant_config=self.quant_config, ) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - }) + set_weight_attrs( + weight, + { + "input_dim": 1, + "output_dim": 0, + }, + ) layer.register_parameter("weight", weight) def quant_weight_loader(param, loaded_weight, *args, **kwargs): @@ -126,10 +136,12 @@ class DeepSpeedFPLinearMethod(LinearMethodBase): extra_weight_attrs["weight_loader"] = quant_weight_loader set_weight_attrs(weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: weight = layer.weight y = weight.ds_dequantize() return F.linear(x, y, bias) @@ -142,23 +154,33 @@ class DeepSpeedFPParameter(nn.Parameter): GPUs, and can be dequantized on-the-fly when needed by the model. """ - def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, - quant_config: DeepSpeedFPConfig): + def __new__( + cls, + orig_shape: torch.Size, + params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig, + ): try: import deepspeed + if version.parse(deepspeed.__version__) < version.parse("0.14.2"): - raise ImportError("deepspeed version is wrong. Please " - "install deepspeed>=0.14.2.") + raise ImportError( + "deepspeed version is wrong. Please install deepspeed>=0.14.2." 
+ ) from deepspeed.ops.fp_quantizer import FP_Quantize except ImportError as err: - raise ImportError("Please install deepspeed>=0.14.2 via " - "`pip install deepspeed>=0.14.2` to use " - "deepspeedfp quantizer.") from err - data = torch.empty(( - orig_shape.numel() // quant_config.group_size, - quant_config.group_size * quant_config.weight_bits // 8 + 4, - ), - dtype=torch.int8) + raise ImportError( + "Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer." + ) from err + data = torch.empty( + ( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8, + ) self = torch.Tensor._make_subclass(cls, data, data.requires_grad) self.orig_shape = orig_shape self.quant_config = quant_config @@ -173,7 +195,8 @@ class DeepSpeedFPParameter(nn.Parameter): self.fp_quantizer.quantize( tensor.data, q_bits=self.quant_config.weight_bits, - )) + ) + ) def ds_dequantize(self, fp_out=None) -> torch.Tensor: """ @@ -181,7 +204,8 @@ class DeepSpeedFPParameter(nn.Parameter): """ assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 return self.fp_quantizer.dequantize( - self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits + ) def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: """ @@ -190,7 +214,5 @@ class DeepSpeedFPParameter(nn.Parameter): """ assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 return self.fp_quantizer.selective_dequantize( - self.data, - indices, - fp_out=fp_out, - q_bits=self.quant_config.weight_bits) + self.data, indices, fp_out=fp_out, q_bits=self.quant_config.weight_bits + ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 3e43caa4cbf72..909b04c79f238 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,18 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + int8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs @@ -42,8 +50,9 @@ class ExpertsInt8Config(QuantizationConfig): def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config": return cls() - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return 
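The `DeepSpeedFPParameter` buffer above packs each quantization group of `group_size` weights into `group_size * weight_bits // 8` bytes of int8 storage plus four extra bytes per group, which the shape arithmetic suggests hold per-group scale metadata for DeepSpeed's `FP_Quantize`. A quick back-of-the-envelope check of that layout, purely illustrative:

```python
# Rough storage estimate for the DeepSpeedFP packed layout; illustrative only.
def deepspeedfp_storage_bytes(numel: int, weight_bits: int, group_size: int = 128) -> int:
    # One int8 row per group: packed payload plus 4 bytes of per-group metadata.
    groups = numel // group_size
    bytes_per_group = group_size * weight_bits // 8 + 4
    return groups * bytes_per_group


# A 4096 x 4096 weight quantized with group size 128:
numel = 4096 * 4096
print(deepspeedfp_storage_bytes(numel, weight_bits=6))  # 13107200 bytes (~0.78 B/weight)
print(deepspeedfp_storage_bytes(numel, weight_bits=8))  # 17301504 bytes (~1.03 B/weight)
print(numel * 2)                                        # 33554432 bytes for fp16
```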
UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): @@ -52,7 +61,6 @@ class ExpertsInt8Config(QuantizationConfig): class ExpertsInt8MoEMethod(FusedMoEMethodBase): - def __init__( self, quant_config: ExpertsInt8Config, @@ -61,51 +69,71 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): int8_dtype = torch.int8 - assert 'weight_loader' in extra_weight_attrs - weight_loader = extra_weight_attrs['weight_loader'] + assert "weight_loader" in extra_weight_attrs + weight_loader = extra_weight_attrs["weight_loader"] wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( - layer, weight_loader) - extra_weight_attrs['weight_loader'] = wrapped_weight_loader + layer, weight_loader + ) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=int8_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=int8_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=int8_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=int8_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - w13_scale = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - dtype=torch.float32), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.zeros( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32 + ), + requires_grad=False, + ) layer.register_parameter("w13_scale", w13_scale) - w2_scale = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=torch.float32), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_scale", w2_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return int8_w8a16_moe_quant_config( + w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None + ) + def apply( self, layer: torch.nn.Module, @@ -120,6 +148,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -127,16 +156,17 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, 
logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + "EPLB not supported for `ExpertsInt8MoEMethod` yet." + ) from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -146,8 +176,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -157,20 +189,21 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_int8_w8a16=True, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale) + quant_config=self.moe_quant_config, + ) @staticmethod def quantizing_weight_loader(layer, weight_loader): - - def quantize_and_call_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, shard_id: int, - expert_id: int): + def quantize_and_call_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() shard_size = layer.intermediate_size_per_partition shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) @@ -178,33 +211,28 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): loaded_weight = loaded_weight.to(device) # w1, gate_proj case: Load into first shard of w13. if shard_id == "w1": - scales = quantize_in_place_and_get_scales( - loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, - 0]) + scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0]) # w3, up_proj case: Load into second shard of w13. elif shard_id == "w3": - scales = quantize_in_place_and_get_scales( - loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, shard_size:2 * - shard_size].copy_(scales[:, 0]) + scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_( + scales[:, 0] + ) # w2, down_proj case: Load into only shard of w2. 
elif shard_id == "w2": - scales = quantize_in_place_and_get_scales(loaded_weight[:, - shard]) + scales = quantize_in_place_and_get_scales(loaded_weight[:, shard]) layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) else: - raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") - weight_loader(param, loaded_weight, weight_name, shard_id, - expert_id) + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) return quantize_and_call_weight_loader def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor: vmax = torch.iinfo(torch.int8).max - scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax) + scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax weight.div_(scales) weight.round_() diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index b2cab7d4614ad..5d390cbd7b1ef 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -8,19 +8,33 @@ from torch.nn import Module from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped) + GroupShape, + is_layer_skipped, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter) + Fp8LinearOp, + maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, +) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -60,23 +74,26 @@ class FBGEMMFp8Config(QuantizationConfig): input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): - if is_layer_skipped(prefix=prefix, - ignored_layers=self.ignore_list, - fused_mapping=self.packed_modules_mapping): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignore_list, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() return FBGEMMFp8LinearMethod(self) return None class FBGEMMFp8LinearMethod(LinearMethodBase): - def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) + 
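`quantize_in_place_and_get_scales` above performs symmetric per-output-channel int8 quantization: each row's scale is its absolute maximum divided by 127, and the row is divided by that scale and rounded in place as the expert weights are loaded. A small self-contained illustration of the same idea (not the vLLM loader path):

```python
# Illustrative per-row symmetric int8 quantization; not vLLM's expert loader.
import torch


def quantize_rows_int8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Symmetric per-row quantization: scale = row_absmax / 127.
    vmax = torch.iinfo(torch.int8).max  # 127
    scales = weight.abs().amax(dim=1, keepdim=True) / vmax
    q = torch.round(weight / scales).clamp(-128, 127).to(torch.int8)
    return q, scales


def dequantize_rows_int8(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return q.float() * scales


w = torch.randn(8, 16)
q, s = quantize_rows_int8(w)
# Round-trip error is bounded by half a quantization step per element.
print((dequantize_rows_int8(q, s) - w).abs().max() <= (s / 2).max() + 1e-6)
```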
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + ) self.out_dtype = torch.get_default_dtype() def create_weights( @@ -101,43 +118,45 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): layer.orig_dtype = params_dtype # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE - weight_scale = ChannelQuantScaleParameter(data=torch.empty( - (sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND - input_scale_ub = torch.nn.Parameter(torch.tensor( - (self.quant_config.input_scale_ub), dtype=torch.float32), - requires_grad=False) + input_scale_ub = torch.nn.Parameter( + torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False, + ) layer.input_scale_ub = input_scale_ub def process_weights_after_loading(self, layer: Module) -> None: # required by torch.compile - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) weight = layer.weight if current_platform.is_fp8_fnuz(): - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=None) + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=layer.weight_scale, input_scale=None + ) if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) @@ -148,11 +167,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): # Activations not quantized for marlin. 
del layer.input_scale_ub - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.quant_config.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -161,12 +181,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias) + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a4de4d7094c30..73e0044803984 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable, Optional +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -14,42 +14,85 @@ from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) + FusedMoE, + FusedMoEActivationFormat, + FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + FlashinferMoeBackend, + apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - select_cutlass_fp8_gemm_impl, swap_w13_to_w31) + flashinfer_cutlass_moe_fp8, + get_flashinfer_moe_backend, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, + swap_w13_to_w31, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - 
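`FBGEMMFp8LinearMethod` above combines per-channel FP8 weight scales with dynamic per-token activation quantization, where `input_scale_ub` (the checkpoint's `activation_scale_ub`) caps the per-token dynamic range. A hedged sketch of that per-token scheme, assuming the scale is `min(absmax, ub) / fp8_max`; the real kernel lives behind `Fp8LinearOp` / `ops.scaled_fp8_quant`:

```python
# Sketch of dynamic per-token FP8 activation quantization with an upper bound.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def per_token_fp8_quant(x: torch.Tensor, scale_ub: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-token scale from the clamped absmax of each row.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(max=scale_ub)
    scales = (absmax / FP8_MAX).clamp(min=1e-12)
    x_q = (x / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scales


x = torch.randn(4, 64) * 10
x_q, s = per_token_fp8_quant(x, scale_ub=1200.0)
print(x_q.dtype, s.shape)  # torch.float8_e4m3fn (4, 1)
```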
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + W8A8BlockFp8LinearOp, + check_aiter_fp8_linear_support, + create_fp8_input_scale, + create_fp8_scale_parameter, + create_fp8_weight_parameter, + expert_weight_is_col_major, + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, + process_fp8_weight_tensor_strategy, + requant_weight_ue8m0_inplace, + validate_fp8_block_shape, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, - prepare_moe_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, + prepare_moe_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped) + GroupShape, + is_layer_skipped, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, - cutlass_fp8_supported, maybe_create_device_identity, - normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, - requantize_with_max_scale) -from vllm.model_executor.parameter import (BlockQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + all_close_1d, + cutlass_block_fp8_supported, + cutlass_fp8_supported, + maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import ( + get_col_major_tma_aligned_tensor, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, +) from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -60,10 +103,67 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) -def _is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m +class Fp8MoeBackend(Enum): + NONE = 0 + FLASHINFER_TRTLLM = 1 + FLASHINFER_CUTLASS = 2 + DEEPGEMM = 3 + CUTLASS_BLOCK_SCALED_GROUPED_GEMM = 4 + MARLIN = 5 + TRITON = 6 + + +def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: + """ + Select the primary FP8 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. 
+ """ + # prefer FlashInfer backends when available and enabled on supported GPUs + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + ): + backend = get_flashinfer_moe_backend() + if backend == FlashinferMoeBackend.TENSORRT_LLM: + logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") + return Fp8MoeBackend.FLASHINFER_TRTLLM + else: + logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100") + return Fp8MoeBackend.FLASHINFER_CUTLASS + + # weight-only path for older GPUs without native FP8 + use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) + if current_platform.is_rocm(): + use_marlin = False + if use_marlin: + logger.info_once("Using Marlin backend for FP8 MoE") + return Fp8MoeBackend.MARLIN + + # deepGEMM on supported platforms with block-quantized weights + if envs.VLLM_USE_DEEP_GEMM and block_quant: + if not has_deep_gemm(): + logger.warning_once("DeepGEMM backend requested but not available.") + elif is_deep_gemm_supported(): + logger.info_once("Using DeepGEMM backend for FP8 MoE") + return Fp8MoeBackend.DEEPGEMM + + # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and block_quant + ): + logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") + return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM + + # default to Triton + logger.info_once("Using Triton backend for FP8 MoE") + return Fp8MoeBackend.TRITON class Fp8Config(QuantizationConfig): @@ -81,23 +181,26 @@ class Fp8Config(QuantizationConfig): self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError( - f"Unsupported activation scheme {activation_scheme}") + raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] if weight_block_size is not None: if not is_checkpoint_fp8_serialized: raise ValueError( "The block-wise quantization only supports fp8-serialized " - "checkpoint for now.") + "checkpoint for now." + ) if len(weight_block_size) != 2: raise ValueError( "The quantization block size of weight must have 2 " - f"dimensions, but got {len(weight_block_size)} dimensions") + f"dimensions, but got {len(weight_block_size)} dimensions" + ) if activation_scheme != "dynamic": - raise ValueError("The block-wise quantization only supports " - "dynamic activation scheme for now, but got " - f"{activation_scheme} activation scheme.") + raise ValueError( + "The block-wise quantization only supports " + "dynamic activation scheme for now, but got " + f"{activation_scheme} activation scheme." 
+ ) self.weight_block_size = weight_block_size @classmethod @@ -118,37 +221,78 @@ class Fp8Config(QuantizationConfig): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.ignored_layers is not None: - self.ignored_layers = hf_to_vllm_mapper.apply_list( - self.ignored_layers) + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) @classmethod def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = ("fp8" in quant_method) + is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], - None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if not ignored_layers: - ignored_layers = cls.get_from_keys_or(config, - ["modules_to_not_convert"], - None) - return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers, - weight_block_size=weight_block_size) + ignored_layers = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: - from vllm.attention.layer import Attention # Avoid circular import + def get_xpu_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention + from vllm.model_executor.layers.quantization.ipex_quant import ( + XPUFp8LinearMethod, + XPUFp8MoEMethod, + ) + + fp8_config = Fp8Config( + is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized, + activation_scheme=self.activation_scheme, + ignored_layers=self.ignored_layers, + weight_block_size=self.weight_block_size, + ) if isinstance(layer, LinearBase): - if is_layer_skipped(prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + return XPUFp8LinearMethod(fp8_config) + elif isinstance(layer, FusedMoE): + return XPUFp8MoEMethod(fp8_config, layer) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if current_platform.is_xpu(): + return self.get_xpu_quant_method(layer, prefix) + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedFusedMoEMethod(layer.moe_config) return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) @@ -200,31 +344,42 @@ class Fp8LinearMethod(LinearMethodBase): # For GPUs that 
lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False - # AITER is only supported on ROCm and only for FP8_FNUZ - # and at the moment are MI300 series - self.use_aiter_and_is_supported = (current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()) + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, - cutlass_fp8_supported=cutlass_fp8_supported()) + if self.block_quant: + assert not self.act_q_static + assert self.weight_block_size is not None + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape, + ) def create_weights( self, @@ -247,50 +402,34 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_block_size = None if self.block_quant: - tp_size = get_tensor_model_parallel_world_size() - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, ) - # Required by row parallel - if (tp_size > 1 - and input_size // input_size_per_partition == tp_size - and input_size_per_partition % block_k != 0): - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") - # Required by column parallel or enabling merged weights - is_tp_split = (tp_size > 1 and - output_size // output_size_per_partition == tp_size) - is_merged_gemm = len(output_partition_sizes) > 1 - if is_tp_split or is_merged_gemm: - sizes_to_check = output_partition_sizes - if not is_tp_split and is_merged_gemm: - # In case of merged 
matrices, we allow the last - # matrix to not be a multiple of block size - sizes_to_check = output_partition_sizes[:-1] - for output_partition_size in sizes_to_check: - if output_partition_size % block_n != 0: - raise ValueError( - f"Weight output_partition_size = " - f"{output_partition_size} is not divisible by " - f"weight quantization block_n = {block_n}.") # WEIGHT - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + if self.quant_config.is_checkpoint_fp8_serialized: + weight = create_fp8_weight_parameter( + output_size_per_partition, input_size_per_partition, weight_loader + ) + else: + # For non-serialized checkpoints, use original dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # If checkpoint is serialized fp8, load them. @@ -298,149 +437,101 @@ class Fp8LinearMethod(LinearMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if not self.block_quant: - scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader, + scale = create_fp8_scale_parameter( + PerTensorScaleParameter, + output_partition_sizes, + input_size_per_partition, + None, + weight_loader, ) - scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, {"scale_type": "weight_scale"}) layer.register_parameter("weight_scale", scale) else: - assert self.quant_config.activation_scheme == "dynamic" - scale = BlockQuantScaleParameter( - data=torch.empty( - (output_size_per_partition + block_n - 1) // block_n, - (input_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, + assert not self.act_q_static + assert self.weight_block_size is not None + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes, + input_size_per_partition, + self.weight_block_size, + weight_loader, ) - scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, {"scale_type": "weight_scale"}) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE - if self.quant_config.activation_scheme == "static": - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - scale[:] = torch.finfo(torch.float32).min + if self.act_q_static: + scale = create_fp8_input_scale(output_partition_sizes, weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) else: layer.register_parameter("input_scale", None) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. 
This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - return weight - def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True + input_scale = None # TODO(rob): refactor block quant into separate class. if self.block_quant: - assert self.quant_config.activation_scheme == "dynamic" + assert not self.act_q_static size_k_first = False - if current_platform.is_fp8_fnuz(): - weight, weight_scale_inv, _ = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=layer.weight, - weight_scale=layer.weight_scale_inv) - else: - weight = layer.weight.data - weight_scale_inv = layer.weight_scale_inv.data - weight = self._maybe_pad_weight(weight) - - # Torch.compile cannot use Parameter subclasses. - layer.weight = Parameter(weight, requires_grad=False) - layer.weight_scale_inv = Parameter(weight_scale_inv, - requires_grad=False) + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale_inv + ) + # Delete the weight_scale_inv parameter to avoid confusion + # with the weight_scale parameter + del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. elif not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, - scale=None) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + weight = qweight.t() - # Update the layer with the new values. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.input_scale = None - - # If checkpoint is fp8, handle that there are N scales for N + # If checkpoint is fp8 per-tensor, handle that there are N scales for N # shards in a fused module else: - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) - weight = layer.weight weight_scale = layer.weight_scale # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. if not self.use_marlin: - # Dequant -> Quant with max scale so we can run per tensor. - if current_platform.is_fp8_fnuz(): - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=weight_scale, - input_scale=layer.input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - - weight_scale, weight = requantize_with_max_scale( - weight=weight, - weight_scale=weight_scale, - logical_widths=layer.logical_widths, + weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( + weight, + weight_scale, + layer.logical_widths, + getattr(layer, "input_scale", None), ) + if self.act_q_static: + assert input_scale is not None + input_scale = input_scale.max() + weight = weight.t() - weight = self._maybe_pad_weight(weight) - # Update layer with new values. 
- layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + # Update layer with new values. + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + layer.input_scale = ( + Parameter(input_scale, requires_grad=False) + if input_scale is not None + else None + ) if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) # Activations not quantized for marlin. del layer.input_scale + return - # On B200, if E8M0 for DeepGemm is used, we need to - # requantize the weight and input to the specific scale - # at the same time. - if is_blackwell_deep_gemm_e8m0_used(): - assert layer.weight_block_size is not None - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.weight.data, - layer.weight_scale_inv.data if hasattr( - layer, "weight_scale_inv") else layer.weight_scale.data, - block_sz, - ) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.block_quant: + maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -449,28 +540,28 @@ class Fp8LinearMethod(LinearMethodBase): workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) - - if self.block_quant: - assert self.quant_config.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( - input=x, - weight=layer.weight, - block_size=self.quant_config.weight_block_size, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + if self.block_quant: + assert self.weight_block_size is not None + + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) + + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) class Fp8MoEMethod(FusedMoEMethodBase): @@ -490,73 +581,34 @@ class Fp8MoEMethod(FusedMoEMethodBase): super().__init__(layer.moe_config) self.layer = layer self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant: bool = self.weight_block_size is not None + self.fused_experts: Optional[mk.FusedMoEModularKernel] = None # type: ignore + + self.fp8_backend = get_fp8_moe_backend(self.block_quant) + + self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore - if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): - 
self.flashinfer_moe_backend = get_flashinfer_moe_backend() - logger.info_once( - f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - ) - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) - # Disable marlin for rocm - if current_platform.is_rocm(): - self.use_marlin = False + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM + elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - # Check for DeepGemm support. - self.allow_deep_gemm = False - if envs.VLLM_USE_DEEP_GEMM: - if not has_deep_gemm(): - logger.warning_once("Failed to import DeepGemm kernels.") - elif not self.block_quant: - logger.warning_once("Model is not block quantized. Not using " - "DeepGemm kernels") - elif (is_deep_gemm_supported()): - logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") - self.allow_deep_gemm = True - else: - logger.warning_once( - "DeepGemm not supported on the current platform.") - - # Check for CutlassBlockScaledGroupedGemm support. - self.allow_cutlass_block_scaled_grouped_gemm = False - if not self.block_quant: - logger.debug_once("Model is not block quantized. Not using " - "CutlassBlockScaledGroupedGemm kernels") - elif (current_platform.is_cuda() - and current_platform.is_device_capability(100)): - logger.info_once( - "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." - ) - self.allow_cutlass_block_scaled_grouped_gemm = True - else: - logger.warning_once( - "CutlassBlockScaledGroupedGemm not supported on the current " - "platform.") - - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, + self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM + self.allow_cutlass_block_scaled_grouped_gemm = ( + self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts @@ -566,12 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], + self.weight_block_size[0], + self.weight_block_size[1], ) # NOTE: To ensure proper alignment of the 
block-wise quantization # scales, the output_size of the weights for both the gate and up @@ -581,31 +633,38 @@ class Fp8MoEMethod(FusedMoEMethodBase): raise ValueError( f"The output_size of gate's and up's weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_n = {block_n}.") - if (tp_size > 1 - and intermediate_size_per_partition % block_k != 0): + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") + f"weight quantization block_k = {block_k}." + ) # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -613,20 +672,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): if not self.block_quant: # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, 2, dtype=torch.float32), - requires_grad=False) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) else: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * ((intermediate_size_per_partition + block_n - 1) // - block_n), + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), @@ -648,9 +706,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK. - value} if self.block_quant else - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # If loading fp8 checkpoint, pass the weight loaders. 
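For block-quantized MoE experts, the scale parameters created above are 3-D: one tile scale per expert, per `block_n` output rows and per `block_k` input columns, which is why `intermediate_size_per_partition` must divide by `block_n` (gate/up outputs) and, under TP, by `block_k` (down-proj inputs). A quick shape check with hypothetical sizes, following the same tiling rule:

```python
# Per-expert block-scale shapes with hypothetical sizes; illustrative only.
import math

num_experts = 8
hidden_size = 4096
intermediate = 1536            # per-partition, hypothetical
block_n, block_k = 128, 128

assert intermediate % block_n == 0     # gate/up output alignment
assert intermediate % block_k == 0     # down-proj input alignment under TP

w13_scale_shape = (num_experts,
                   2 * math.ceil(intermediate / block_n),
                   math.ceil(hidden_size / block_k))
w2_scale_shape = (num_experts,
                  math.ceil(hidden_size / block_n),
                  math.ceil(intermediate / block_k))
print(w13_scale_shape, w2_scale_shape)  # (8, 24, 32) (8, 32, 12)
```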
# If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() @@ -663,17 +722,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8.") + "was not serialized fp8." + ) - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) @@ -681,10 +741,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale = None layer.w2_input_scale = None + self.rocm_aiter_moe_enabled = False + def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) + is_rocm_aiter_moe_enabled, + shuffle_weights, + ) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -692,20 +756,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" if current_platform.is_fp8_fnuz(): - w13_weight, w13_weight_scale_inv, w13_input_scale = \ + w13_weight, w13_weight_scale_inv, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale_inv, - layer.w13_input_scale) - w2_weight, w2_weight_scale_inv, w2_input_scale = \ + layer.w13_weight, + layer.w13_weight_scale_inv, + layer.w13_input_scale, + ) + ) + w2_weight, w2_weight_scale_inv, w2_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale_inv, - layer.w2_input_scale) + layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale + ) + ) elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) - w13_weight_scale_inv = swap_w13_to_w31( - layer.w13_weight_scale_inv.data) + w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data else: @@ -716,66 +783,67 @@ class Fp8MoEMethod(FusedMoEMethodBase): # torch.compile() cannot use Parameter subclasses. layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv, - requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + w13_weight_scale_inv, requires_grad=False + ) layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, - requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + w2_weight_scale_inv, requires_grad=False + ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - # DeepGemm scales need to be transposed and aligned. We try to do + # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used(): - # Lazy import to avoid CUDA initialization problems. - if _is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() - if _is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale_inv + ) + if expert_weight_is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale_inv + ) # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=fp8_dtype) + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + layer.w13_weight, layer.w2_weight + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) # If checkpoint is fp8, we need to handle that the # MoE kernels require single activation scale and single weight # scale for w13 per expert. @@ -783,46 +851,54 @@ class Fp8MoEMethod(FusedMoEMethodBase): # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": - if (layer.w13_input_scale is None - or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer.") + "for each layer." + ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False) + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False) + w2_input_scale, requires_grad=False + ) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
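# --- Editor's note (illustrative sketch, not part of this diff) ---------------
# The hunk below dequantizes each w13 shard with its own scale and requantizes
# it with the per-expert max scale, so the fused fp8 MoE kernel only needs a
# single w13 weight scale per expert. A minimal pure-torch version of that idea,
# assuming a PyTorch build with torch.float8_e4m3fn (448.0 is the e4m3fn max);
# the real code uses vLLM's per_tensor_dequantize and ops.scaled_fp8_quant
# helpers, and requant_w13_with_max_scale is a hypothetical name used only here.
import torch


def requant_w13_with_max_scale(w13_fp8: torch.Tensor, shard_scales: torch.Tensor):
    """w13_fp8: (2 * shard_size, hidden) fp8 weight; shard_scales: (2,) fp32."""
    shard_size = w13_fp8.shape[0] // 2
    max_scale = shard_scales.max()
    out = torch.empty_like(w13_fp8)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        # Dequantize this shard with the scale it was stored with ...
        dq = w13_fp8[rows].to(torch.float32) * shard_scales[shard_id]
        # ... then requantize with the shared (max) scale.
        out[rows] = (dq / max_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return out, max_scale
# ------------------------------------------------------------------------------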
@@ -833,25 +909,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) start += shard_size if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + layer.w13_weight, layer.w2_weight + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) if self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is @@ -859,8 +935,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert not self.block_quant register_moe_scaling_factors(layer) w13_weight = swap_w13_to_w31(layer.w13_weight.data) - if self.flashinfer_moe_backend == \ - FlashinferMoeBackend.TENSORRT_LLM: + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) layer.w13_weight.data = w13_weight.data @@ -870,7 +945,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) @@ -886,60 +961,106 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) # Ensure column-major TMA alignment expected by DeepGEMM. 
- if _is_col_major(layer.w13_weight_scale_inv): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv).contiguous() - if _is_col_major(layer.w2_weight_scale_inv): + layer.w13_weight_scale_inv + ) + if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv).contiguous() + layer.w2_weight_scale_inv + ) + + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if ( + self.rocm_aiter_moe_enabled + or self.use_marlin + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, + TritonOrDeepGemmExperts, + ) assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( - "Marlin and ROCm AITER are not supported with all2all yet.") + "Marlin and ROCm AITER are not supported with all2all yet." + ) - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): - max_num_tokens_per_rank = ( - prepare_finalize.max_num_tokens_per_rank()) + assert self.moe_quant_config is not None + + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None logger.debug( "BatchedTritonOrDeepGemmExperts(%s): " "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - self.__class__.__name__, max_num_tokens_per_rank, - self.quant_config.weight_block_size, False) + self.__class__.__name__, + max_num_tokens_per_rank, + self.weight_block_size, + False, + ) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=False, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, self.quant_config.weight_block_size, - False) + self.__class__.__name__, + self.weight_block_size, + False, + ) return TritonOrDeepGemmExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return fp8_w8a8_moe_quant_config( + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv if 
self.block_quant else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.weight_block_size, + ) + def apply( self, layer: torch.nn.Module, @@ -954,6 +1075,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -961,22 +1083,34 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") + if ( + self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and self.fused_experts is None + ): + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) + assert scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {scoring_func}" + ) if self.block_quant: - assert (renormalize and use_grouped_topk - and custom_routing_function is None) + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + assert ( + renormalize and use_grouped_topk and custom_routing_function is None + ) + e_score_correction_bias = ( + e_score_correction_bias.to(x.dtype) + if e_score_correction_bias is not None + else None + ) return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( routing_logits=router_logits.to(torch.float32), routing_bias=e_score_correction_bias, @@ -992,13 +1126,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, + block_shape=self.weight_block_size, + routed_scaling=routed_scaling_factor, ) else: - assert (not renormalize - and custom_routing_function is not None) - return apply_flashinfer_per_tensor_scale_fp8( + assert not renormalize and custom_routing_function is not None + result = apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, router_logits=router_logits, @@ -1007,9 +1140,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) - topk_weights, topk_ids = FusedMoE.select_experts( + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + select_result = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -1019,6 +1156,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + 
routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, enable_eplb=enable_eplb, @@ -1026,32 +1164,38 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, ) + # + # Note: the order of checks is important since self.fused_experts + # can override fused_experts or cutlass but not rocm or marlin. + # + topk_weights, topk_ids, zero_expert_result = select_result + if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_fused_experts) - return rocm_aiter_fused_experts( + rocm_aiter_fused_experts, + ) + + assert self.fused_experts is None + result = rocm_aiter_fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - use_fp8_w8a8=True, apply_router_weight_on_input=apply_router_weight_on_input, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - expert_map=expert_map) + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) elif self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") - return torch.ops.vllm.fused_marlin_moe( + assert activation == "silu", f"{activation} not supported for Marlin MoE." + assert self.fused_experts is None + result = torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, @@ -1065,41 +1209,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_type_id=scalar_types.float8_e4m3fn.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert self.block_quant is None - assert (not renormalize and custom_routing_function is not None) - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - common_kwargs = dict( + expert_map=expert_map, + workspace=layer.workspace, + ) + elif self.fused_experts: + result = self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1110,26 +1224,55 @@ class Fp8MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - 
if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert not self.block_quant + assert not renormalize and custom_routing_function is not None + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) + assert scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {scoring_func}" ) - if self.fused_experts is not None: - return self.fused_experts(**common_kwargs) - else: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - **common_kwargs, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm), - ) + result = flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + + result = fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm + ), + ) + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py new file mode 100644 index 0000000000000..929e603149905 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202 + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._custom_ops import ( + cutlass_scaled_fp4_mm, + fusedQuantizeMx, + fusedQuantizeNv, + matmul_mxf4_bf16_tn, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +class FPQuantConfig(QuantizationConfig): + """Config class for FPQuant.""" + + def __init__( + self, + hadamard_group_size: int = 32, + forward_dtype: str = "mxfp4", + forward_method: str = "abs_max", + pseudoquantization: bool = False, + modules_to_not_convert: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.hadamard_group_size = hadamard_group_size + self.forward_dtype = forward_dtype + 
self.forward_method = forward_method + self.pseudoquantization = pseudoquantization + self.modules_to_not_convert = modules_to_not_convert + + if pseudoquantization: + raise ValueError("Pseudoquantization is not supported for vLLM") + + def __repr__(self) -> str: + return ( + f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, " + f"forward_dtype={self.forward_dtype}, " + f"forward_method={self.forward_method}, " + f"pseudoquantization={self.pseudoquantization}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "fp_quant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig": + hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"]) + forward_dtype = cls.get_from_keys(config, ["forward_dtype"]) + forward_method = cls.get_from_keys(config, ["forward_method"]) + pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"]) + modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"]) + return cls( + hadamard_group_size, + forward_dtype, + forward_method, + pseudoquantization, + modules_to_not_convert, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[LinearMethodBase]: + if self.modules_to_not_convert is not None and any( + prefix.endswith(module) for module in self.modules_to_not_convert + ): + return UnquantizedLinearMethod() + + if isinstance(layer, LinearBase): + return FPQuantLinearMethod(self) + return None + + +class FPQuantLinearMethod(LinearMethodBase): + """Linear method for FPQuant. + + Args: + quant_config: The FPQuant quantization config. + """ + + def __init__(self, quant_config: FPQuantConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.bfloat16: + raise ValueError("Only bfloat16 is currently supported by FPQuant") + if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501 + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size. Or other skill issues." 
+ ) + + assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], ( + "Only mxfp4 and nvfp4 are supported for now" + ) + if self.quant_config.forward_dtype == "mxfp4": + group_size = 32 + elif self.quant_config.forward_dtype == "nvfp4": + group_size = 16 + else: + raise ValueError( + f"Unsupported forward_dtype: {self.quant_config.forward_dtype}" + ) + + qweight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": 2, + } + | extra_weight_attrs, + ) + layer.register_parameter("qweight", qweight) + + scales = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": group_size, + } + | extra_weight_attrs, + ) + layer.register_parameter("scales", scales) + + weight_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + weight_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + act_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + act_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("act_global_scale", act_global_scale) + + forward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix) + + backward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return quantized_forward( + x, + layer.qweight, + layer.scales, + layer.weight_global_scale, + layer.act_global_scale, + bias, + layer.forward_hadamard_matrix, + self.quant_config.forward_method, + self.quant_config.forward_dtype, + ) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def fused_quantize_mx( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method) + + +def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method): + rows, cols = x_flat.size(0), x_flat.size(1) // 32 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_mx", + op_func=fused_quantize_mx, + mutates_args=[], + 
fake_impl=fused_quantize_mx_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_mxf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return matmul_mxf4_bf16_tn( + x, + w, + to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu), + to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu), + alpha, + ) + + +def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_mxf4_bf16", + op_func=matmul_mxf4_bf16, + mutates_args=[], + fake_impl=matmul_mxf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def fused_quantize_nv( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale) + + +def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale): + rows, cols = x_flat.size(0), x_flat.size(1) // 16 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_nv", + op_func=fused_quantize_nv, + mutates_args=[], + fake_impl=fused_quantize_nv_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_nvf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return cutlass_scaled_fp4_mm( + x, + w, + to_blocked(xs, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), # *2//16 + to_blocked(ws, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), + alpha, + torch.bfloat16, + ) + + +def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_nvf4_bf16", + op_func=matmul_nvf4_bf16, + mutates_args=[], + fake_impl=matmul_nvf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def quantized_forward( + x: torch.Tensor, + qweight: torch.Tensor, + weight_scales: torch.Tensor, + weight_global_scale: torch.Tensor, + act_global_scale: torch.Tensor, + bias: Optional[torch.Tensor], + forward_hadamard_matrix: torch.Tensor, + forward_method: str, + forward_dtype: str, +) -> torch.Tensor: + x_flat = x.contiguous().flatten(end_dim=-2) + + if forward_dtype == "mxfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx( + x_flat, forward_hadamard_matrix, forward_method + ) + y = torch.ops.vllm.matmul_mxf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + elif forward_dtype == "nvfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv( + x_flat, forward_hadamard_matrix, act_global_scale + ) + y = torch.ops.vllm.matmul_nvf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + else: + raise ValueError(f"Unsupported forward_dtype: {forward_dtype}") + + y = y.view(*x.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + + return y diff --git a/vllm/model_executor/layers/quantization/gguf.py 
b/vllm/model_executor/layers/quantization/gguf.py index 90222f2e3b0e5..8296bc2ea3b48 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import gguf import torch @@ -10,16 +10,22 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs from vllm.utils import direct_register_custom_op @@ -29,13 +35,12 @@ logger = init_logger(__name__) class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, - unquantized_modules: Optional[list[str]] = None) -> None: + def __init__(self, unquantized_modules: Optional[list[str]] = None) -> None: super().__init__() self.unquantized_modules = unquantized_modules or [] def __repr__(self) -> str: - return ("GGUFConfig()") + return "GGUFConfig()" def get_name(self) -> QuantizationMethods: return "gguf" @@ -55,8 +60,9 @@ class GGUFConfig(QuantizationConfig): def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": return cls() - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if is_layer_skipped_gguf(prefix, self.unquantized_modules): return UnquantizedLinearMethod() @@ -107,8 +113,9 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES -def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, - qweight_type: int) -> torch.Tensor: +def _fused_mul_mat_gguf( + x: torch.Tensor, qweight: torch.Tensor, qweight_type: int +) -> torch.Tensor: if qweight_type in IMATRIX_QUANT_TYPES: mmvq_safe = 8 if qweight.shape[0] > 5120 else 16 else: @@ -116,10 +123,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # HACK: when doing chunked prefill we don't generate output tokens # so input to logits generator is empty which causes invalid parameter if x.shape[0] == 0: - return torch.empty(x.shape[0], - qweight.shape[0], - dtype=x.dtype, - device=x.device) + return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device) # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: return x @ 
qweight.T @@ -140,8 +144,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # Might be useful if llama.cpp adds a new quantization type. # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. qweight_type = WeightType(qweight_type) - raise NotImplementedError( - f"Unsupported GGUF quantization type: {qweight_type}") + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") return y @@ -150,17 +153,13 @@ def _fused_mul_mat_gguf_fake( qweight: torch.Tensor, qweight_type: int, ) -> torch.Tensor: - return torch.empty(x.shape[0], - qweight.shape[0], - dtype=x.dtype, - device=x.device) + return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device) try: direct_register_custom_op( op_name="_fused_mul_mat_gguf", op_func=_fused_mul_mat_gguf, - mutates_args=[], fake_impl=_fused_mul_mat_gguf_fake, ) fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf @@ -179,10 +178,9 @@ def _fused_moe_gguf( qweight_type2: int, activation: str, ) -> torch.Tensor: - def act(x: torch.Tensor): d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "silu": torch.ops._C.silu_and_mul(out, x) @@ -193,50 +191,73 @@ def _fused_moe_gguf( return out # lazy import to avoid triggering triton import in CPU backend - from vllm.model_executor.layers.fused_moe.fused_moe import ( - moe_align_block_size) + from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size out_hidden_states = torch.empty_like(x) # unless we decent expert reuse we are better off running moe_vec kernel - if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES - and x.shape[0] > 64): + if ( + qweight_type2 in MMQ_QUANT_TYPES + and qweight_type in MMQ_QUANT_TYPES + and x.shape[0] > 64 + ): num_tokens, _ = x.shape E, N, _ = w1.shape top_k = topk_ids.shape[1] BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type) - sorted_token_ids, expert_ids, num_tokens_post_padded = \ - moe_align_block_size(topk_ids, BLOCK_SIZE, E) - out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids, - num_tokens_post_padded, qweight_type, N, top_k, - num_tokens) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, BLOCK_SIZE, E + ) + out = ops.ggml_moe_a8( + x, + w1, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + qweight_type, + N, + top_k, + num_tokens, + ) out = act(out) - out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids, - num_tokens_post_padded, qweight_type2, - w2.shape[1], 1, num_tokens * top_k) + out = ops.ggml_moe_a8( + out, + w2, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + qweight_type2, + w2.shape[1], + 1, + num_tokens * top_k, + ) out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( - topk_weights.view(num_tokens, top_k, 1)) + topk_weights.view(num_tokens, top_k, 1) + ) ops.moe_sum(out, out_hidden_states) elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES: num_tokens, _ = x.shape E, N, _ = w1.shape top_k = topk_ids.shape[1] - out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, - num_tokens) + out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens) out = act(out) - out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2, - w2.shape[1], num_tokens * top_k) + out = ops.ggml_moe_a8_vec( + out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k + ) out = out.reshape(num_tokens, 
top_k, w2.shape[1]).mul_( - topk_weights.view(num_tokens, top_k, 1)) + topk_weights.view(num_tokens, top_k, 1) + ) ops.moe_sum(out, out_hidden_states) else: - logger.warning_once("There is no support for fast MoE kernel " - "for current quantization method. " - "Falling back to slow implementation. ") + logger.warning_once( + "There is no support for fast MoE kernel " + "for current quantization method. " + "Falling back to slow implementation. " + ) for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): - inp = x[tok].reshape((1, ) + x.shape[1:]) + inp = x[tok].reshape((1,) + x.shape[1:]) current_hidden_state = None for ww, ii in zip(w, idx): expert_up = w1[ii] @@ -245,8 +266,9 @@ def _fused_moe_gguf( out = act(out) expert_down = w2[ii] - current_state = fused_mul_mat_gguf(out, expert_down, - qweight_type2).mul_(ww) + current_state = fused_mul_mat_gguf( + out, expert_down, qweight_type2 + ).mul_(ww) if current_hidden_state is None: current_hidden_state = current_state else: @@ -272,7 +294,6 @@ try: direct_register_custom_op( op_name="_fused_moe_gguf", op_func=_fused_moe_gguf, - mutates_args=[], fake_impl=_fused_moe_gguf_fake, ) fused_moe_gguf = torch.ops.vllm._fused_moe_gguf @@ -293,15 +314,15 @@ def _apply_gguf_embedding( elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] x_flat = x.flatten() - assert (hidden_size == qweight.shape[1] // type_size * block_size) + assert hidden_size == qweight.shape[1] // type_size * block_size quant = torch.index_select(qweight, dim=0, index=x_flat) - dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, - x_flat.shape[0], dtype) + dequant = ops.ggml_dequantize( + quant, qweight_type, hidden_size, x_flat.shape[0], dtype + ) return dequant.view(*x.shape, hidden_size) else: qweight_type = WeightType(qweight_type) - raise NotImplementedError( - f"Unsupported GGUF quantization type: {qweight_type}") + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") def _apply_gguf_embedding_fake( @@ -318,7 +339,6 @@ try: direct_register_custom_op( op_name="_apply_gguf_embedding", op_func=_apply_gguf_embedding, - mutates_args=[], fake_impl=_apply_gguf_embedding_fake, ) apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding @@ -337,18 +357,24 @@ class GGUFLinearMethod(LinearMethodBase): def __init__(self, quant_config: GGUFConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): self.params_dtype = params_dtype output_size_per_partition = sum(output_partition_sizes) tensor_shape = (output_size_per_partition, input_size_per_partition) qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, @@ -356,31 +382,34 @@ class GGUFLinearMethod(LinearMethodBase): "data_container": [], "shard_id": [], "shard_id_map": {}, - }) + }, + ) set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("qweight", qweight) - qweight_type = Parameter(torch.empty(len(output_partition_sizes), - dtype=torch.uint8), - requires_grad=False) + qweight_type = Parameter( + 
torch.empty(len(output_partition_sizes), dtype=torch.uint8), + requires_grad=False, + ) set_weight_attrs( - qweight_type, { + qweight_type, + { "is_gguf_weight_type": True, "weight_type": 0, "shard_weight_type": {}, - "ignore_warning": True - }) + "ignore_warning": True, + }, + ) set_weight_attrs(qweight_type, extra_weight_attrs) layer.register_parameter("qweight_type", qweight_type) def process_weights_after_loading(self, layer: torch.nn.Module): qweight_type = layer.qweight_type.weight_type - if not (qweight_type in UNQUANTIZED_TYPES - or qweight_type in DEQUANT_TYPES): + if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES): qweight_type = WeightType(qweight_type) raise ValueError( - f"Unsupported GGUF quantization type {qweight_type} in " - f"layer {layer}.") + f"Unsupported GGUF quantization type {qweight_type} in layer {layer}." + ) # For MergedColumnParallelLinear and QKVParallelLinear, we need to # materialize the padded weight parameter for CUDA Graph compatibility. self._create_padded_weight_param(layer) @@ -393,22 +422,22 @@ class GGUFLinearMethod(LinearMethodBase): if len(data_container := qweight.data_container) > 1: dtype = {data.dtype for data in data_container} assert len(dtype) == 1, ValueError( - f"Data container has mixed dtypes: {dtype}") + f"Data container has mixed dtypes: {dtype}" + ) dtype = next(iter(dtype)) # concat dim0 and pad dim1 padded_side = max(x.size(1) for x in data_container) concat_side = sum(x.size(0) for x in data_container) # Pad the quantized weights to dense tensor, and create a map # with the location of each shard in the padded tensor. - padded_data = torch.zeros((concat_side, padded_side), - dtype=dtype, - device=qweight.device) + padded_data = torch.zeros( + (concat_side, padded_side), dtype=dtype, device=qweight.device + ) # (dim0_start, dim0_end, dim1_size) shard_offset_map = dict[str, tuple[int, int, int]]() for idx in shard_id: id_in_container = shard_id_map[idx] - start = sum( - x.size(0) for x in data_container[:id_in_container]) + start = sum(x.size(0) for x in data_container[:id_in_container]) end = start + data_container[id_in_container].size(0) size = data_container[id_in_container].size(1) padded_data[start:end, :size] = data_container[id_in_container] @@ -416,14 +445,15 @@ class GGUFLinearMethod(LinearMethodBase): qweight.data_container.clear() padded_param = Parameter(padded_data, requires_grad=False) set_weight_attrs(padded_param, vars(qweight)) - set_weight_attrs(padded_param, - {"shard_offset_map": shard_offset_map}) + set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map}) layer.register_parameter("qweight", padded_param) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: shard_id = layer.qweight.shard_id if shard_id: @@ -436,8 +466,9 @@ class GGUFLinearMethod(LinearMethodBase): qweight_type = layer.qweight_type.shard_weight_type[idx] result.append( fused_mul_mat_gguf( - x, qweight[start:end, :offset].contiguous(), - qweight_type)) + x, qweight[start:end, :offset].contiguous(), qweight_type + ) + ) out = torch.cat(result, axis=1) else: qweight = layer.qweight @@ -463,61 +494,73 @@ class GGUFMoEMethod(FusedMoEMethodBase): super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - 
params_dtype: torch.dtype, **extra_weight_attrs): - - tensor_shape = (num_experts, 2 * intermediate_size_per_partition, - hidden_size) - #gate up proj + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size) + # gate up proj w13_qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - w13_qweight, { + w13_qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, "is_gguf_weight": True, "data_container": [], - }) + }, + ) set_weight_attrs(w13_qweight, extra_weight_attrs) layer.register_parameter("w13_qweight", w13_qweight) - w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), - requires_grad=False) - set_weight_attrs(w13_qweight_type, { - "is_gguf_weight_type": True, - "weight_type": 0, - "ignore_warning": True - }) + w13_qweight_type = Parameter( + torch.empty(1, dtype=torch.uint8), requires_grad=False + ) + set_weight_attrs( + w13_qweight_type, + {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True}, + ) set_weight_attrs(w13_qweight_type, extra_weight_attrs) layer.register_parameter("w13_qweight_type", w13_qweight_type) - tensor_shape = (num_experts, intermediate_size_per_partition, - hidden_size) - #gate down proj + tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size) + # gate down proj w2_qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - w2_qweight, { + w2_qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, "is_gguf_weight": True, "data_container": [], - }) + }, + ) set_weight_attrs(w2_qweight, extra_weight_attrs) layer.register_parameter("w2_qweight", w2_qweight) - w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), - requires_grad=False) - set_weight_attrs(w2_qweight_type, { - "is_gguf_weight_type": True, - "weight_type": 0, - "ignore_warning": True - }) + w2_qweight_type = Parameter( + torch.empty(1, dtype=torch.uint8), requires_grad=False + ) + set_weight_attrs( + w2_qweight_type, + {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True}, + ) set_weight_attrs(w2_qweight_type, extra_weight_attrs) layer.register_parameter("w2_qweight_type", w2_qweight_type) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -532,6 +575,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -539,20 +583,20 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `GGUFMoEMethod` yet.") + raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.") assert activation == "silu", "Only SiLU activation is supported." 
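# --- Editor's note (illustrative sketch, not part of this diff) ---------------
# "silu" here is applied by the act() helper earlier in this file through the
# fused torch.ops._C.silu_and_mul kernel: the w13 output is split in half along
# the last dimension, SiLU is applied to the first (gate) half, and the result
# is multiplied by the second (up) half. A pure-torch reference, assuming only
# PyTorch; silu_and_mul_reference is a hypothetical name used only here.
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x holds the concatenated [gate, up] projections along the last dimension.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
# ------------------------------------------------------------------------------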
if apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for" - "fused GGUF MoE method.") + "fused GGUF MoE method." + ) - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -562,12 +606,20 @@ class GGUFMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) - return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, - topk_weights, topk_ids, - layer.w13_qweight_type.weight_type, - layer.w2_qweight_type.weight_type, activation) + indices_type=self.topk_indices_dtype, + ) + return fused_moe_gguf( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights, + topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, + activation, + ) class GGUFEmbeddingMethod(GGUFLinearMethod): @@ -577,17 +629,14 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): quant_config: The GGUF quantization config. """ - def embedding(self, layer: torch.nn.Module, - x: torch.Tensor) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: qweight = layer.qweight qweight_type = layer.qweight_type.weight_type hidden_size = qweight.tensor_shape[1] - return apply_gguf_embedding(x, - qweight, - qweight_type, - hidden_size, - dtype=self.params_dtype) + return apply_gguf_embedding( + x, qweight, qweight_type, hidden_size, dtype=self.params_dtype + ) class GGUFUninitializedParameter(UninitializedParameter): diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f18c936bac605..8f36fc70c4447 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,24 +4,36 @@ import enum from enum import Enum from fractions import Fraction -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) + get_linear_quant_method, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) +from vllm.transformers_utils.config import get_safetensors_params_metadata +from vllm.utils import is_list_of + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods +else: + QuantizationMethods = str class GPTQConfig(QuantizationConfig): @@ -37,6 +49,8 @@ class GPTQConfig(QuantizationConfig): desc_act: bool, 
lm_head_quantized: bool, dynamic: dict[str, dict[str, Union[int, bool]]], + autoround_version: str = "", + modules_in_block_to_quantize: Optional[list[str]] = None, ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. @@ -72,14 +86,23 @@ class GPTQConfig(QuantizationConfig): if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {self.weight_bits} bits.") + f"supported for GPTQ, but got {self.weight_bits} bits." + ) + + self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] + + # used to identify GPTQ model quantized by autoround + self.autoround_version = autoround_version def __repr__(self) -> str: - return (f"GPTQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}), " - f"lm_head_quantized={self.lm_head_quantized}), " - f"dynamic={self.dynamic}") + return ( + f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}), " + f"lm_head_quantized={self.lm_head_quantized}, " + f"dynamic={self.dynamic}, " + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -106,10 +129,22 @@ class GPTQConfig(QuantizationConfig): weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or( + config, ["autoround_version"], default="" + ) + modules_in_block_to_quantize = cls.get_from_keys_or( + config, ["modules_in_block_to_quantize"], default=None + ) + return cls( + weight_bits, + group_size, + desc_act, + lm_head_quantized, + dynamic, + autoround_version, + modules_in_block_to_quantize, + ) def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -125,14 +160,40 @@ class GPTQConfig(QuantizationConfig): "sym": True, # GPTQ typically uses symmetric quantization "lm_head": False, } - return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + def apply_vllm_mapper(self, hf_to_vllm_mapper): + if self.modules_in_block_to_quantize is not None: + self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( + self.modules_in_block_to_quantize + ) + + def maybe_update_config(self, model_name: str, revision: Optional[str] = None): + if self.modules_in_block_to_quantize: + if is_list_of(self.modules_in_block_to_quantize, list): + # original modules_in_block_to_quantize: list[list[str]] + # flatten original modules_in_block_to_quantize + self.modules_in_block_to_quantize = [ + item + for sublist in self.modules_in_block_to_quantize + for item in sublist + ] + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in 
unquant_dtypes + } + self.modules_in_block_to_quantize = list(quant_layers) + class ExllamaState(Enum): - UNUSED = enum.auto() UNINITIALIZED = enum.auto() READY = enum.auto() @@ -164,14 +225,15 @@ class GPTQLinearMethod(LinearMethodBase): raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) output_size_per_partition = sum(output_partition_sizes) - if (output_size_per_partition % self.quant_config.pack_factor.numerator - != 0): + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) if self.quant_config.group_size != -1: group_size = self.quant_config.group_size @@ -180,8 +242,10 @@ class GPTQLinearMethod(LinearMethodBase): exllama_state = ExllamaState.UNINITIALIZED scale_and_zero_size = input_size // group_size scale_and_zero_input_dim = None - if (input_size != input_size_per_partition - and self.quant_config.group_size != -1): + if ( + input_size != input_size_per_partition + and self.quant_config.group_size != -1 + ): # For act-order models, we cannot use Exllama for row parallel layer if self.quant_config.desc_act: exllama_state = ExllamaState.UNUSED @@ -200,56 +264,56 @@ class GPTQLinearMethod(LinearMethodBase): output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - g_idx = RowvLLMParameter(data=torch.tensor( - [ - i // self.quant_config.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scale_and_zero_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scale_and_zero_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) @@ -271,24 +335,30 @@ class GPTQLinearMethod(LinearMethodBase): if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + 
) layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) reshaped_x = x.reshape(-1, x.shape[-1]) - output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits) + output = ops.gptq_gemm( + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + layer.exllama_state == ExllamaState.READY, + self.quant_config.weight_bits, + ) if bias is not None: output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index d03074f861848..85cf4ed4ac58c 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -7,26 +7,39 @@ from packaging import version from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - BitBLASLinearKernel, MPLinearLayerConfig) + BitBLASLinearKernel, + MPLinearLayerConfig, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, - check_bitblas_supported, verify_bitblas_supported) + MINIMUM_BITBLAS_VERSION, + bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, + verify_bitblas_supported, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -61,14 +74,16 @@ class GPTQBitBLASConfig(QuantizationConfig): quant_method: Optional[str], lm_head_quantized: bool, ) -> None: - try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + 
MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e raise ValueError( @@ -96,17 +111,20 @@ class GPTQBitBLASConfig(QuantizationConfig): raise ValueError( f"BitBLAS does not support weight_bits = {self.weight_bits}. " f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " - "are supported.") + "are supported." + ) if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: raise ValueError( f"BitBLAS does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported." + ) self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE - storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE - if c.isdigit())) + storage_nbit = int( + "".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE if c.isdigit()) + ) # 4 Bits packed into 32 bit datatype. self.pack_factor = storage_nbit // weight_bits @@ -116,17 +134,20 @@ class GPTQBitBLASConfig(QuantizationConfig): self.zeros_mode = self.ZEROS_MODE if (weight_bits, is_sym) not in self.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={weight_bits}, sym={is_sym}") + raise ValueError( + f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}" + ) self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] def __repr__(self) -> str: - return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act})" - f"is_sym={self.is_sym}, " - f"quant_method={self.quant_method})") + return ( + f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -151,36 +172,46 @@ class GPTQBitBLASConfig(QuantizationConfig): desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) quant_method = cls.get_from_keys(config, ["quant_method"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, quant_method, - lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls( + weight_bits, group_size, desc_act, is_sym, quant_method, lm_head_quantized + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "bitblas" - or user_quant == "gptq_bitblas") + is_valid_user_quant = ( + user_quant is None + or user_quant == "bitblas" + or user_quant == "gptq_bitblas" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_bitblas" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. 
Use quantization=gptq_bitblas for" - " faster inference") + logger.info( + "Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): return GPTQBitBLASLinearMethod(self) return None @@ -201,8 +232,7 @@ class GPTQBitBLASConfig(QuantizationConfig): return False # If we cannot find the info needed in the config, cannot convert. - if (num_bits is None or group_size is None or sym is None - or desc_act is None): + if num_bits is None or group_size is None or sym is None or desc_act is None: return False if (num_bits, sym) not in cls.TYPE_MAP: @@ -215,9 +245,9 @@ class GPTQBitBLASConfig(QuantizationConfig): return False # Otherwise, can convert if model satisfies bitblas constraints. - return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, - sym)], - group_size=group_size) + return check_bitblas_supported( + quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size + ) class GPTQBitBLASLinearMethod(LinearMethodBase): @@ -233,8 +263,10 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQBitBLASConfig) -> None: self.quant_config = quant_config # Verify supported on platform. - verify_bitblas_supported(quant_type=self.quant_config.quant_type, - group_size=self.quant_config.group_size) + verify_bitblas_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + ) def create_weights( self, @@ -248,7 +280,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): ) -> None: """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing + The function initializes and returns a dictionary containing quantized weights, scales, and zeros for performing quantized matrix multiplication operations. @@ -257,21 +289,22 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): output_partition_sizes: The size of the output partition. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: + params_dtype: The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or - if the input size per partition is not divisible by the - group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. 
""" if params_dtype != torch.float16: - raise ValueError("Parameter data type must be torch.float16, " - f"but got {params_dtype}") + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) # Normalize group_size if self.quant_config.group_size != -1: @@ -294,18 +327,19 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_config.quant_type, act_type=params_dtype, group_size=self.quant_config.group_size, zero_points=False, - has_g_idx=self.quant_config.desc_act + has_g_idx=self.quant_config.desc_act, ) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for GPTQBitBLASLinearMethod", - kernel_type.__name__) + logger.info("Using %s for GPTQBitBLASLinearMethod", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # Normalize group_size @@ -315,9 +349,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): group_size = input_size # Determine sharding - if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, - self.quant_config.group_size, - is_row_parallel): + if bitblas_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None @@ -340,16 +374,19 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Activation order # Ignore warning from fused linear layers such as QKVParallelLinear. 
- g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) # Scales scales = Parameter( @@ -371,45 +408,42 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): # Quantized zero-points qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scales_and_zp_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c5d1e017014f3..8fa70a240f9ff 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -5,40 +5,63 @@ from copy import deepcopy from typing import Any, Callable, Optional, Union import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, - UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_dynamic_override, get_linear_quant_method, override_config) + get_dynamic_override, + get_linear_quant_method, + override_config, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, 
check_moe_marlin_supports_layer, - marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, - marlin_repeat_scales_on_all_ranks, verify_marlin_supported) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) + check_marlin_supported, + check_moe_marlin_supports_layer, + marlin_make_workspace_new, + marlin_moe_permute_scales, + marlin_permute_bias, + marlin_repeat_scales_on_all_ranks, + verify_marlin_supported, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.transformers_utils.config import get_safetensors_params_metadata +from vllm.utils import is_list_of logger = init_logger(__name__) def get_moe_quant_method( - config: QuantizationConfig, + config: "GPTQMarlinConfig", layer: torch.nn.Module, prefix: str, moe_method_cls: type, @@ -47,9 +70,13 @@ def get_moe_quant_method( if isinstance(layer, FusedMoE): # False = skip module, None = no override, else = Positive match - if get_dynamic_override( # noqa: E712 + if ( + get_dynamic_override( # noqa: E712 cloned_config, # noqa: E712 - layer_name=prefix) == False: # noqa: E712 + layer_name=prefix, + ) + == False + ): # noqa: E712 return UnquantizedFusedMoEMethod(layer.moe_config) if prefix: @@ -69,10 +96,17 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool, - dynamic: dict[str, dict[str, Union[int, bool]]], - full_config: dict[str, Any]) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + dynamic: dict[str, dict[str, Union[int, bool]]], + full_config: dict[str, Any], + modules_in_block_to_quantize: Optional[list[str]] = None, + ) -> None: super().__init__() if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False @@ -114,17 +148,25 @@ class GPTQMarlinConfig(QuantizationConfig): self.full_config = full_config if (weight_bits, is_sym) not in self.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={weight_bits}, sym={is_sym}") + raise ValueError( + f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}" + ) self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] + # used to identify GPTQ model quantized by autoround + self.autoround_version = full_config.get("autoround_version", "") + def __repr__(self) -> str: - return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}, " - f"lm_head_quantized={self.lm_head_quantized}), " - f"dynamic={self.dynamic}") + return ( + f"GPTQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"dynamic={self.dynamic}, " + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -151,47 +193,64 @@ class GPTQMarlinConfig(QuantizationConfig): group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) 
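# --- Illustration (hedged): a representative HF `quantization_config` block of
# the kind GPTQMarlinConfig.from_config() parses. The values below are made-up
# examples; the keys follow the usual GPTQ checkpoint layout referenced in this
# file ("bits", "group_size", "desc_act", "sym", "lm_head", "dynamic",
# "modules_in_block_to_quantize"), not any specific model's config.
example_hf_quant_config = {
    "quant_method": "gptq",
    "bits": 4,                 # -> weight_bits
    "group_size": 128,         # -> group_size (-1 means per-channel)
    "desc_act": True,          # activation reordering; enables g_idx handling
    "sym": True,               # (bits, sym) must be a key of TYPE_MAP
    "lm_head": False,          # -> lm_head_quantized
    "dynamic": {},             # optional per-layer overrides
    # Optional; a list[list[str]] here is flattened by maybe_update_config().
    "modules_in_block_to_quantize": [["self_attn.q_proj", "self_attn.k_proj"]],
}
assert (example_hf_quant_config["bits"], example_hf_quant_config["sym"]) in {
    (4, True),
    (8, True),
}  # mirrors the (bits, sym) pairs accepted by TYPE_MAP in this config class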
is_sym = cls.get_from_keys(config, ["sym"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, - lm_head_quantized, dynamic, config) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + modules_in_block_to_quantize = cls.get_from_keys_or( + config, ["modules_in_block_to_quantize"], default=None + ) + return cls( + weight_bits, + group_size, + desc_act, + is_sym, + lm_head_quantized, + dynamic, + config, + modules_in_block_to_quantize, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "marlin" - or user_quant == "gptq_marlin") + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_marlin" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_marlin for" - " faster inference") + logger.info( + "Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, FusedMoE): - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by GPTQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") - return MoeWNA16Config.from_config( - self.full_config).get_quant_method(layer, prefix) - return get_moe_quant_method(self, layer, prefix, - GPTQMarlinMoEMethod) - return get_linear_quant_method(self, layer, prefix, - GPTQMarlinLinearMethod) + "Falling back to Moe WNA16 kernels." 
+ ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod) + return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @classmethod def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): @@ -208,15 +267,43 @@ class GPTQMarlinConfig(QuantizationConfig): return False # Marlin conversion is only valid if required properties are found - if (num_bits is None or group_size is None or sym is None - or desc_act is None): + if num_bits is None or group_size is None or sym is None or desc_act is None: return False if (num_bits, sym) not in cls.TYPE_MAP: return False - return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], - group_size=group_size) + return check_marlin_supported( + quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size + ) + + def apply_vllm_mapper(self, hf_to_vllm_mapper): + if self.modules_in_block_to_quantize is not None: + self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( + self.modules_in_block_to_quantize + ) + + def maybe_update_config(self, model_name: str, revision: Optional[str] = None): + if self.modules_in_block_to_quantize: + if is_list_of(self.modules_in_block_to_quantize, list): + # original modules_in_block_to_quantize: list[list[str]] + # flatten original modules_in_block_to_quantize + self.modules_in_block_to_quantize = [ + item + for sublist in self.modules_in_block_to_quantize + for item in sublist + ] + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_in_block_to_quantize = list(quant_layers) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -232,8 +319,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): self.quant_config = quant_config # Verify supported on platform. 
- verify_marlin_supported(quant_type=self.quant_config.quant_type, - group_size=self.quant_config.group_size) + verify_marlin_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + ) def create_weights( self, @@ -251,20 +340,21 @@ class GPTQMarlinLinearMethod(LinearMethodBase): mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_config.quant_type, act_type=params_dtype, group_size=self.quant_config.group_size, zero_points=False, - has_g_idx=self.quant_config.desc_act + has_g_idx=self.quant_config.desc_act, ) kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for GPTQMarlinLinearMethod", - kernel_type.__name__) + logger.info("Using %s for GPTQMarlinLinearMethod", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # Normalize group_size @@ -274,9 +364,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase): group_size = input_size # Determine sharding - if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, - self.quant_config.group_size, - is_row_parallel): + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None @@ -298,67 +388,69 @@ class GPTQMarlinLinearMethod(LinearMethodBase): output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Activation order - g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scales_and_zp_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="qweight", - w_s_param_name="scales", 
- w_zp_param_name="qzeros", - w_gidx_param_name="g_idx") + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) @@ -387,8 +479,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): elif self.quant_config.quant_type.size_bits == 8: self.quant_type = scalar_types.uint8b128 else: - raise ValueError( - "GPTQMarlinMoEMethod only supports int4 and int8 now.") + raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") def create_weights( self, @@ -399,28 +490,27 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - intermediate_size_full = extra_weight_attrs.pop( - "intermediate_size_full") + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") self.is_k_full = (not self.quant_config.desc_act) or ( - intermediate_size_per_partition == intermediate_size_full) + intermediate_size_per_partition == intermediate_size_full + ) if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size - w2_scales_size = (intermediate_size_full - if self.quant_config.desc_act else - intermediate_size_per_partition) - scales_size2 = (w2_scales_size // self.quant_config.group_size) + w2_scales_size = ( + intermediate_size_full + if self.quant_config.desc_act + else intermediate_size_per_partition + ) + scales_size2 = w2_scales_size // self.quant_config.group_size strategy = FusedMoeWeightScaleSupported.GROUP.value else: scales_size13 = 1 scales_size2 = 1 strategy = FusedMoeWeightScaleSupported.CHANNEL.value - extra_weight_attrs.update({ - "quant_method": strategy, - "is_transposed": True - }) + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True}) # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter( torch.empty( @@ -437,8 +527,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): w2_qweight = torch.nn.Parameter( torch.empty( num_experts, - intermediate_size_per_partition // - self.quant_config.pack_factor, + intermediate_size_per_partition // self.quant_config.pack_factor, hidden_size, dtype=torch.int32, ), @@ -448,51 +537,51 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_qweight, extra_weight_attrs) # up_proj scales w13_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) # down_proj scales w2_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size2, - hidden_size, - dtype=params_dtype), + torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) - # dont shard the w2 scales when running act order - set_weight_attrs(w2_scales, - {"load_full_w2": self.quant_config.desc_act}) + # don't shard the w2 scales when running act order + set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act}) # up_proj scales w13_qzeros = torch.nn.Parameter( - torch.empty(num_experts, - scales_size13, - 2 * intermediate_size_per_partition // - 
self.quant_config.pack_factor, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) # down_proj scales w2_qzeros = torch.nn.Parameter( - torch.empty(num_experts, - scales_size2, - hidden_size // self.quant_config.pack_factor, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - # dont shard the w2 scales when running act order - set_weight_attrs(w2_qzeros, - {"load_full_w2": self.quant_config.desc_act}) + # don't shard the w2 scales when running act order + set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act}) w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, @@ -521,8 +610,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( torch.empty( @@ -532,15 +620,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) device = layer.w13_qweight.device layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Process act_order if self.quant_config.desc_act: # Get sorting based on g_idx @@ -550,42 +636,36 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( + torch.int32 + ) w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( - torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - w2_g_idx_sort_indices[e]] + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] device = layer.w13_g_idx.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + 
torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) # Repack weights @@ -615,9 +695,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * - (self.quant_config.group_size if self.quant_config.group_size != -1 - else self.quant_config.pack_factor), + size_k=layer.w2_scales.shape[1] + * ( + self.quant_config.group_size + if self.quant_config.group_size != -1 + else self.quant_config.pack_factor + ), size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -629,6 +712,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -643,6 +731,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -650,16 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + "EPLB not supported for `GPTQMarlinMoEMethod` yet." + ) assert activation == "silu", "Only SiLU activation is supported." 
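# --- Illustration (hedged): a generic softmax top-k router showing the shape
# and meaning of the (topk_weights, topk_ids) pair that fused_marlin_moe
# consumes below. This is NOT vLLM's FusedMoE.select_experts, which also
# handles grouped top-k, alternative scoring functions, e_score_correction_bias,
# routed_scaling_factor and EPLB bookkeeping; names here are hypothetical.
import torch

def naive_select_experts(router_logits: torch.Tensor, top_k: int):
    """router_logits: [num_tokens, num_experts] -> two [num_tokens, top_k] tensors."""
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    # Renormalize so each token's routing weights sum to 1.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

w, ids = naive_select_experts(torch.randn(4, 8), top_k=2)
print(w.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])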
- topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -669,8 +759,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -692,4 +784,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, - is_k_full=self.is_k_full) + is_k_full=self.is_k_full, + ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index eba917d854118..8f0df55b0a5cf 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,13 +9,16 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -25,15 +28,12 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128 GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ - scalar_types.uint4b8, scalar_types.uint8b128 -] +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] class GPTQMarlin24Config(QuantizationConfig): - """Config class for Marlin24. - """ + """Config class for Marlin24.""" def __init__( self, @@ -49,17 +49,18 @@ class GPTQMarlin24Config(QuantizationConfig): self.group_size = group_size # Verify - if quant_type is None or \ - quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: + if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: raise ValueError( f"Marlin_24 does not support quant_type = {quant_type}. " f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " - "are supported.") + "are supported." + ) if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( f"Marlin_24 does not support group_size = {self.group_size}. " f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " - "are supported.") + "are supported." 
+ ) self.quant_type = quant_type @@ -84,7 +85,8 @@ class GPTQMarlin24Config(QuantizationConfig): def __repr__(self) -> str: return "Marlin24Config(quant_type={}, group_size={})".format( - self.quant_type, self.group_size) + self.quant_type, self.group_size + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -111,23 +113,26 @@ class GPTQMarlin24Config(QuantizationConfig): @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - is_marlin_24_format = ( - hf_quant_cfg.get("checkpoint_format") == "marlin_24") + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: + is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24" - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "gptq_marlin_24") + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24" + ) if is_marlin_24_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. " - "Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) logger.info(msg) return cls.get_name() return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["GPTQMarlin24LinearMethod"]: if isinstance(layer, LinearBase): return GPTQMarlin24LinearMethod(self) return None @@ -157,7 +162,8 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): weight_loader = extra_weight_attrs["weight_loader"] if params_dtype != torch.float16: raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") + f"The params dtype must be float16, but got {params_dtype}" + ) # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) @@ -165,38 +171,46 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") + f"min_n_threads = {self.quant_config.min_n_threads}." + ) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") + f"pack_factor = {self.quant_config.pack_factor}." + ) # Validate input_size_per_partition if input_size_per_partition % self.quant_config.min_k_threads != 0: raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") + f"min_k_threads = {self.quant_config.min_k_threads}." + ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." 
+ ) # Check that we have at least 4 tiles horizontally in the shard num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) + self.quant_config.tile_size**2 + ) if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") + raise ValueError("Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. qweight = PackedvLLMParameter( data=torch.empty( input_size_per_partition // self.quant_config.tile_size // 2, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, + output_size_per_partition + * self.quant_config.tile_size + // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -205,55 +219,57 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): packed_dim=1, packed_factor=self.quant_config.pack_factor, marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Meta - meta = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - device="cuda", - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader) + meta = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader, + ) # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) + input_groups = ( + 1 + if self.quant_config.group_size == -1 + else input_size_per_partition // self.quant_config.group_size + ) weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( input_groups, output_size_per_partition, device="cuda", dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) # Allocate workspace (Used for internal locking mechanism) max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel + output_size_per_partition // self.quant_config.min_n_threads + ) * self.quant_config.max_parallel - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) + workspace = BasevLLMParameter( + data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), + weight_loader=weight_loader, + ) layer.register_parameter("B_24", qweight) layer.register_parameter("B_meta", meta) @@ -284,12 +300,19 @@ class GPTQMarlin24LinearMethod(LinearMethodBase): size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, - self.quant_config.quant_type, - size_m, size_n, size_k) + output_2d = ops.gptq_marlin_24_gemm( + x_2d, + qweight, + meta, + scales, + workspace, + self.quant_config.quant_type, + size_m, + size_n, + size_k, + ) - output = 
output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 8385ccac32a28..e61caf6b459b0 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -7,20 +7,32 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, + marlin_permute_bias, + marlin_permute_scales, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace) + MarlinWorkspace, +) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack -from vllm.model_executor.parameter import (BasevLLMParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -36,10 +48,10 @@ class HQQMarlinConfig(QuantizationConfig): skip_modules: Optional[list[str]] = None, ) -> None: super().__init__() - assert group_size == 64, ("The only supported HQQ group size is " - "currently 64.") - assert weight_bits == 4, ("The only supported HQQ quantization " - "bitsize is currently 4.") + assert group_size == 64, "The only supported HQQ group size is currently 64." + assert weight_bits == 4, ( + "The only supported HQQ quantization bitsize is currently 4." 
+ ) self.weight_bits = weight_bits self.group_size = group_size @@ -48,8 +60,10 @@ class HQQMarlinConfig(QuantizationConfig): self.skip_modules = skip_modules def __repr__(self) -> str: - return (f"HQQMarlinConfig(quant_type={self.quant_type}, " - f"group_size={self.group_size})") + return ( + f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -69,7 +83,7 @@ class HQQMarlinConfig(QuantizationConfig): @classmethod def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig": - wq_params = (config["quant_config"]["weight_quant_params"]) + wq_params = config["quant_config"]["weight_quant_params"] weight_bits = cls.get_from_keys(wq_params, ["nbits"]) group_size = cls.get_from_keys(wq_params, ["group_size"]) skip_modules = config["skip_modules"] @@ -77,14 +91,16 @@ class HQQMarlinConfig(QuantizationConfig): def is_layer_skipped(self, prefix: str) -> bool: # Split the prefix into its dot-separated components - components = prefix.split('.') + components = prefix.split(".") # Check if any of the skip modules exactly matches any component return self.skip_modules is not None and any( - module_name in components for module_name in self.skip_modules) + module_name in components for module_name in self.skip_modules + ) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if self.is_layer_skipped(prefix): return UnquantizedLinearMethod() @@ -94,7 +110,6 @@ class HQQMarlinConfig(QuantizationConfig): # Empty HQQ parameter, will be ignored during loading class HQQEmptyParameter(BasevLLMParameter): - def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): pass @@ -112,23 +127,18 @@ def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # HQQ packing creates issues with sharding - therefore, prior to loading, we # repack to GPTQ. We also reshape the weights to their proper GPTQ shape. 
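# --- Illustration (hedged): a self-contained sketch of the 4-bit unpacking that
# HQQweightParameter.unpack_4bit_u8 performs below (two quantized values per
# uint8 byte, high nibble first, stacked along dim 0). Function and variable
# names here are hypothetical; this is not the vLLM implementation itself.
import torch

def unpack_4bit_u8_sketch(w_q_packed: torch.Tensor) -> torch.Tensor:
    """Unpack a uint8 tensor holding two 4-bit values per byte into uint8 nibbles."""
    step = w_q_packed.shape[0]
    out = torch.empty(
        (2 * step, w_q_packed.shape[1]), dtype=torch.uint8, device=w_q_packed.device
    )
    out[:step] = (w_q_packed & 0b11110000) >> 4  # high nibbles
    out[step:] = w_q_packed & 0b00001111         # low nibbles
    return out

# A (1, 2) packed tensor expands to (2, 2) unpacked 4-bit values.
packed = torch.tensor([[0xAB, 0x12]], dtype=torch.uint8)
print(unpack_4bit_u8_sketch(packed))  # tensor([[10, 1], [11, 2]], dtype=torch.uint8)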
class HQQweightParameter(PackedvLLMParameter): - # unpack function from https://github.com/mobiusml/hqq - def unpack_4bit_u8(self, - W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 + def unpack_4bit_u8(self, W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)" dtype = torch.uint8 step = W_q.shape[0] - tmp = torch.empty([2 * step, W_q.shape[1]], - dtype=dtype, - device=W_q.device) + tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) tmp[:step] = (W_q & 0b11110000) >> 4 tmp[step:] = W_q & 0b00001111 return tmp - def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, - **kwargs): + def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, **kwargs): super().__init__(packed_factor, packed_dim, None, **kwargs) self.weight_bits = weight_bits self.input_shape = self.shape[self.input_dim] * self.packed_factor @@ -136,36 +146,41 @@ class HQQweightParameter(PackedvLLMParameter): def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( - 1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_merged_column_weight(loaded_weight, **kwargs) def load_row_parallel_weight(self, loaded_weight: torch.Tensor): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(self.output_shape, - -1).transpose(1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(self.output_shape, -1).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_row_parallel_weight(loaded_weight) def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( - 1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_qkv_weight(loaded_weight, **kwargs) # Zero points and scales in HQQ must also be reshaped to correspond to W_q's # GPTQ shape (transposed - we transpose them too when processing weights). class HQQZeroScaleParameter(GroupQuantScaleParameter): - def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = loaded_weight.reshape(-1, self.shape[1]) super().load_merged_column_weight(loaded_weight, **kwargs) @@ -180,8 +195,7 @@ class HQQZeroScaleParameter(GroupQuantScaleParameter): class HQQMarlinMethod(LinearMethodBase): - """Linear method for HQQ Marlin. 
- """ + """Linear method for HQQ Marlin.""" def __init__( self, @@ -204,8 +218,9 @@ class HQQMarlinMethod(LinearMethodBase): weight_loader = extra_weight_attrs.get("weight_loader", error_loader) - self.scales_and_zp_size = (input_size_per_partition // - self.quant_config.group_size) + self.scales_and_zp_size = ( + input_size_per_partition // self.quant_config.group_size + ) qweight = HQQweightParameter( data=torch.empty( @@ -218,25 +233,30 @@ class HQQMarlinMethod(LinearMethodBase): packed_dim=0, packed_factor=self.quant_config.pack_factor, weight_bits=self.quant_config.weight_bits, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - zeros = HQQZeroScaleParameter(data=torch.empty( - self.output_size_per_partition, - self.scales_and_zp_size, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + zeros = HQQZeroScaleParameter( + data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) - scales = HQQZeroScaleParameter(data=torch.empty( - self.output_size_per_partition, - self.scales_and_zp_size, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + scales = HQQZeroScaleParameter( + data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("W_q", qweight) layer.register_parameter("zero", zeros) @@ -244,17 +264,29 @@ class HQQMarlinMethod(LinearMethodBase): # Ignore extra parameters in the HQQ model. # To be added as needed. - ignore_parameters = ("axis", "channel_wise", "compute_dtype", - "encoded_state_dict", "group_size", "nbits", - "offload_meta", "optimize", "packing", - "quant_scale", "quant_zero", "round_zero", - "shape", "stores_quant_config", - "unpack_view_dtype", "view_as_float") + ignore_parameters = ( + "axis", + "channel_wise", + "compute_dtype", + "encoded_state_dict", + "group_size", + "nbits", + "offload_meta", + "optimize", + "packing", + "quant_scale", + "quant_zero", + "round_zero", + "shape", + "stores_quant_config", + "unpack_view_dtype", + "view_as_float", + ) for name in ignore_parameters: layer.register_parameter( name, - HQQEmptyParameter(data=torch.empty(0), - weight_loader=weight_loader)) + HQQEmptyParameter(data=torch.empty(0), weight_loader=weight_loader), + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: dev = layer.W_q.device @@ -268,14 +300,18 @@ class HQQMarlinMethod(LinearMethodBase): self.output_size_per_partition, self.quant_config.weight_bits, ).to(dev) - marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - self.quant_config.group_size).to(dev) - marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - self.quant_config.group_size).to(dev) + marlin_s = marlin_permute_scales( + layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size, + ).to(dev) + marlin_zp = marlin_permute_scales( + layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size, + ).to(dev) layer.g_idx = marlin_make_empty_g_idx(dev) layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -293,9 +329,11 @@ class 
HQQMarlinMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - workspace = MarlinWorkspace(self.output_size_per_partition, - GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = MarlinWorkspace( + self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL, + ) scales = layer.marlin_scales zeros = layer.marlin_zeros diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index 8aa1f1a14bfc9..4e736378e9dac 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -21,12 +21,15 @@ from typing import Any, Optional import torch from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + FusedMoE, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) class INCConfig(QuantizationConfig): @@ -44,8 +47,9 @@ class INCConfig(QuantizationConfig): def from_config(cls, config: dict[str, Any]) -> "INCConfig": raise AssertionError - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index e1a9bdde9334d..8786638869a4e 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,8 +7,7 @@ import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy @@ -23,28 +22,44 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): """ - Quantize input tensor to per-tensor or per-token FP8. + Quantize input tensor to FP8 (per-tensor, per-token, or per-group). This CustomOp supports both static and dynamic quantization. 
""" - def __init__(self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None): + def __init__( + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile + ): """ - :param static: static or dynamic quantization - :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) - :param num_token_padding: Pad the token dimension of output to this size + :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, + or arbitrary block size) + :param num_token_padding: Pad the token dimension of output to this + size + :param column_major_scales: For group quantization, output scales in + column major format """ super().__init__() - self.num_token_padding = num_token_padding - assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} - assert not static or group_shape == GroupShape.PER_TENSOR, \ - "Only per-tensor scales supported for static quantization." self.static = static self.group_shape = group_shape - self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + self.num_token_padding = num_token_padding + self.column_major_scales = column_major_scales + self.use_ue8m0 = use_ue8m0 + + self.is_group_quant = group_shape.is_per_group() + if self.is_group_quant: + assert not static, "Group quantization only supports dynamic mode" + self.group_size = group_shape.col + else: + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, ( + "Only per-tensor scales supported for static quantization." + ) + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN def forward_cuda( self, @@ -52,17 +67,31 @@ class QuantFP8(CustomOp): scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - assert (scale is not None) == self.static - assert scale_ub is None or (not self.static and self.group_shape - == GroupShape.PER_TOKEN - and scale_ub.numel() == 1) + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + from vllm.model_executor.layers.quantization.utils import fp8_utils + return fp8_utils.per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE, + use_ue8m0=self.use_ue8m0, + ) + + assert (scale is not None) == self.static + assert scale_ub is None or ( + not self.static + and self.group_shape == GroupShape.PER_TOKEN + and scale_ub.numel() == 1 + ) return ops.scaled_fp8_quant( x, scale, num_token_padding=self.num_token_padding, scale_ub=scale_ub, - use_per_token_if_dynamic=self.use_per_token_if_dynamic) + use_per_token_if_dynamic=self.use_per_token_if_dynamic, + ) def forward_native( self, @@ -70,10 +99,16 @@ class QuantFP8(CustomOp): scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ): + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + return self._quantize_group_native(x) + assert (scale is not None) == self.static - assert scale_ub is None or (not self.static and self.group_shape - == GroupShape.PER_TOKEN - and scale_ub.numel() == 1) + assert scale_ub is None or ( + not self.static + and self.group_shape == GroupShape.PER_TOKEN + and scale_ub.numel() == 1 + ) if scale is None: if self.group_shape == GroupShape.PER_TOKEN: @@ -84,8 +119,7 @@ class QuantFP8(CustomOp): else: x_max = 
x.abs().max().unsqueeze(-1).to(torch.float32) - scale = x_max / _FP8_MAX - scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) # Even for dynamic per-token scales, # reciprocal performs slightly better than division @@ -101,3 +135,38 @@ class QuantFP8(CustomOp): out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) return out, scale + + def _quantize_group_native( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + hidden_dim = x.shape[-1] + num_groups = (hidden_dim + self.group_size - 1) // self.group_size + padded_dim = num_groups * self.group_size + + if padded_dim != hidden_dim: + padding = padded_dim - hidden_dim + x = F.pad(x, (0, padding), mode="constant", value=0.0) + + x_grouped = x.view(-1, num_groups, self.group_size) + absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() + scales_raw = absmax / _FP8_MAX + if self.use_ue8m0: + scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) + scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) + + x_scaled = x_grouped / scales + x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + x_quant = x_quant.view(-1, padded_dim) + if padded_dim != hidden_dim: + x_quant = x_quant[..., :hidden_dim] + x_quant = x_quant.view(orig_shape) + + scales = scales.squeeze(-1) + scales = scales.reshape(orig_shape[:-1] + (num_groups,)) + + if self.column_major_scales: + scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) + + return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 9c458954f960f..4aa0e464e0f53 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,19 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from packaging import version +from torch.nn import Module +from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, - is_layer_skipped_awq) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm._ipex_ops import ipex_ops as ops +from vllm.model_executor.layers.fused_moe import ( + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) +from vllm.model_executor.layers.quantization.awq import ( + AWQLinearMethod, + is_layer_skipped_awq, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform MIN_IPEX_VERSION = "2.6.0" @@ -48,17 +64,22 @@ class IPEXConfig(QuantizationConfig): self.pack_factor = 32 // self.weight_bits if self.weight_bits not in [4]: - raise ValueError(f"IPEX quantization supports weight bits [4], " - f"but 
got {self.weight_bits}.") + raise ValueError( + f"IPEX quantization supports weight bits [4], " + f"but got {self.weight_bits}." + ) if self.method not in ["awq", "gptq"]: - raise ValueError(f"IPEX quantization supports [awq, gptq], " - f"but got {self.method}.") + raise ValueError( + f"IPEX quantization supports [awq, gptq], but got {self.method}." + ) def __repr__(self) -> str: - return (f"IPEXConfig(method={self.method}," - f"weight_bits={self.weight_bits}, " - f"group_size={self.group_size})") + return ( + f"IPEXConfig(method={self.method}," + f"weight_bits={self.weight_bits}, " + f"group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -84,24 +105,24 @@ class IPEXConfig(QuantizationConfig): method = cls.get_from_keys(config, ["quant_method"]).lower() if method == "awq": weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, - ["q_group_size", "group_size"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) - return cls(method, weight_bits, group_size, modules_to_not_convert, - False, False) + config, ["modules_to_not_convert"], None + ) + return cls( + method, weight_bits, group_size, modules_to_not_convert, False, False + ) # otherwise for gptq weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) - return cls(method, weight_bits, group_size, [], desc_act, - lm_head_quantized) + return cls(method, weight_bits, group_size, [], desc_act, lm_head_quantized) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: if not current_platform.is_cpu() and not current_platform.is_xpu(): return None @@ -112,8 +133,9 @@ class IPEXConfig(QuantizationConfig): return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["LinearMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["LinearMethodBase"]: if isinstance(layer, LinearBase): if self.method == "awq": if is_layer_skipped_awq(prefix, self.modules_to_not_convert): @@ -125,8 +147,7 @@ class IPEXConfig(QuantizationConfig): class IPEXGPTQLinearMethod(GPTQLinearMethod): - """GPTQ linear method using IPEX for the CPU/XPU backend. - """ + """GPTQ linear method using IPEX for the CPU/XPU backend.""" def __init__(self, quant_config: IPEXConfig): self.quant_config = quant_config # type: ignore @@ -136,18 +157,20 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): try: import intel_extension_for_pytorch as ipex - if version.parse( - ipex.__version__) < version.parse(MIN_IPEX_VERSION): + + if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." 
+ ) except ImportError as err: raise ImportError( "Please install " f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method.") from err + " to use IPEX-AWQ linear method." + ) from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions # with better performance. lowp_mode = ipex.quantization.WoqLowpMode.INT8 @@ -164,32 +187,34 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): ) layer.ipex_output_size = layer.qweight.shape[-1] g_idx = layer.g_idx if self.quant_config.desc_act else None - layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ - IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - g_idx=g_idx, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + layer.ipex_qlinear = ( + ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"], + ) ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) class IPEXAWQLinearMethod(AWQLinearMethod): - """AWQ linear method using IPEX for the CPU/XPU backend. - """ + """AWQ linear method using IPEX for the CPU/XPU backend.""" def __init__(self, quant_config: IPEXConfig): self.quant_config = quant_config # type: ignore @@ -201,18 +226,20 @@ class IPEXAWQLinearMethod(AWQLinearMethod): try: import intel_extension_for_pytorch as ipex - if version.parse( - ipex.__version__) < version.parse(MIN_IPEX_VERSION): + + if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." + ) except ImportError as err: raise ImportError( "Please install " f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method.") from err + " to use IPEX-AWQ linear method." + ) from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions # with better performance. @@ -229,25 +256,193 @@ class IPEXAWQLinearMethod(AWQLinearMethod): group_size=self.quant_config.group_size, ) - layer.ipex_output_size = layer.qweight.size( - 1) * self.quant_config.pack_factor - layer.ipex_qlinear = ipex.llm.quantization.woq_linear. 
\ - IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + layer.ipex_output_size = layer.qweight.size(1) * self.quant_config.pack_factor + layer.ipex_qlinear = ( + ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"], # type: ignore + ) ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) + + +class XPUFp8LinearMethod(Fp8LinearMethod): + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) + + def process_weights_after_loading(self, layer: Module) -> None: + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + # Update the layer with the new values. + layer.weight = Parameter(qweight, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + weight = layer.weight.data + weight_scale = layer.weight_scale.data + output = torch.ops.torch_ipex.fp8_gemm_w8a16( + x, weight, True, weight_scale, bias + ) + return output + + +class XPUFp8MoEMethod(FusedMoEMethodBase): + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(layer.moe_config) + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + # INPUT_SCALES + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + if not self.quant_config.is_checkpoint_fp8_serialized: + fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + import intel_extension_for_pytorch as ipex + + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + w1_scale_inv=layer.w13_weight_scale, + w2_scale_inv=layer.w2_weight_scale, + a1_scale_inv=layer.w13_input_scale, + a2_scale_inv=layer.w2_input_scale, + use_prepack=True, + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function=custom_routing_function, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 07ecc096231a4..055a3ebbced61 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -20,10 +20,10 @@ class MPLinearLayerConfig: group_size: int zero_points: bool has_g_idx: bool + out_type: Optional[torch.dtype] = None class MPLinearKernel(ABC): - 
@classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -31,16 +31,17 @@ class MPLinearKernel(ABC): @classmethod @abstractmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError - def __init__(self, - c: MPLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None) -> None: + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + ) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -57,31 +58,34 @@ class MPLinearKernel(ABC): raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError - def _transform_param(self, layer: torch.nn.Module, name: Optional[str], - fn: Callable) -> None: + def _transform_param( + self, layer: torch.nn.Module, name: Optional[str], fn: Callable + ) -> None: if name is not None and getattr(layer, name, None) is not None: - old_param = getattr(layer, name) new_param = fn(old_param) # replace the parameter with torch.nn.Parameter for TorchDynamo # compatibility replace_parameter( - layer, name, - torch.nn.Parameter(new_param.data, requires_grad=False)) + layer, name, torch.nn.Parameter(new_param.data, requires_grad=False) + ) def _get_weight_params( - self, layer: torch.nn.Module) -> tuple[ - torch.Tensor, # w_q - torch.Tensor, # w_s - Optional[torch.Tensor], # w_zp, - Optional[torch.Tensor] # w_gidx - ]: + self, layer: torch.nn.Module + ) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor], # w_gidx + ]: return ( getattr(layer, self.w_q_name), getattr(layer, self.w_s_name), diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee92cd..1759d142e6cc1 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -5,25 +5,38 @@ from typing import Optional import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 - AllSparkLinearKernel) + AllSparkLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 - BitBLASLinearKernel) + BitBLASLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 - ConchLinearKernel) + ConchLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 - Dynamic4bitLinearKernel) + Dynamic4bitLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 - ExllamaLinearKernel) + ExllamaLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete 
import ( # noqa: E501 - MacheteLinearKernel) + MacheteLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 - MarlinLinearKernel) + MarlinLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 - MPLinearKernel, MPLinearLayerConfig) + MPLinearKernel, + MPLinearLayerConfig, +) from vllm.platforms import current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, @@ -35,19 +48,19 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> type[MPLinearKernel]: + config: MPLinearLayerConfig, compute_capability: Optional[int] = None +) -> type[MPLinearKernel]: """ Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of + compute capability. Attempts to choose the best kernel in terms of performance. Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. + the target device, if None uses `current_platform` to get + the compute capability. Defaults to None. Raises: ValueError: If no kernel can implement the given config. @@ -66,14 +79,18 @@ def choose_mp_linear_kernel( for kernel in _POSSIBLE_KERNELS: if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') + f" {kernel.__name__} disabled by environment variable" + ) continue - if (compute_capability is not None - and kernel.get_min_capability() > compute_capability): + if ( + compute_capability is not None + and kernel.get_min_capability() > compute_capability + ): failure_reasons.append( f"{kernel.__name__} requires capability " f"{kernel.get_min_capability()}, current compute " - f" capability is {compute_capability}") + f" capability is {compute_capability}" + ) continue can_implement, failure_reason = kernel.can_implement(config) @@ -81,10 +98,10 @@ def choose_mp_linear_kernel( return kernel else: failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' + f" {kernel.__name__} cannot implement due to: {failure_reason}" ) raise ValueError( - "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) + "Failed to find a kernel that can implement the " + "WNA16 linear layer. 
Reasons: \n" + "\n".join(failure_reasons) + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py index 785e559df8f75..c353372b05ec1 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -8,22 +8,21 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + check_allspark_supported_dtype_shape, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class AllSparkLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.has_g_idx: return False, "Act reordering currently not supported by AllSpark" @@ -35,7 +34,8 @@ class AllSparkLinearKernel(MPLinearKernel): c.partition_weight_shape[1], # out_features c.group_size, c.weight_type, - c.act_type) + c.act_type, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -49,8 +49,8 @@ class AllSparkLinearKernel(MPLinearKernel): sm_count = properties.multi_processor_count sm_version = properties.major * 10 + properties.minor gemm_args = {} - gemm_args['sm_count'] = sm_count - gemm_args['sm_version'] = sm_version + gemm_args["sm_count"] = sm_count + gemm_args["sm_version"] = sm_version self.gemm_args = gemm_args @@ -59,43 +59,42 @@ class AllSparkLinearKernel(MPLinearKernel): old_scale_param = getattr(layer, self.w_s_name) assert isinstance(old_weight_param, BasevLLMParameter) - permute_param_layout_(old_weight_param, - input_dim=0, - output_dim=1, - packed_dim=0) + permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0) assert isinstance(old_scale_param, BasevLLMParameter) permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) # unpack weight from K / 4 x N int32 to K x N uint8 - new_weight_param = torch.nn.Parameter(old_weight_param.data, - requires_grad=False) - new_weight_param.data = new_weight_param.data.t().contiguous().view( - dtype=torch.uint8) + new_weight_param = torch.nn.Parameter( + old_weight_param.data, requires_grad=False + ) + new_weight_param.data = ( + new_weight_param.data.t().contiguous().view(dtype=torch.uint8) + ) new_weight_param.data = new_weight_param.data.t().contiguous() - new_scale_param = torch.nn.Parameter(old_scale_param.data, - requires_grad=False) + new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False) # reorder K x N weight as N32K16 format for Ampere W8A16 - new_weight_param.data, new_scale_param.data, _ = \ - ops.allspark_repack_weight( - new_weight_param.data, new_scale_param.data, None, - c.zero_points) + new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, c.zero_points + ) replace_parameter(layer, self.w_q_name, new_weight_param.data) replace_parameter(layer, 
self.w_s_name, new_scale_param.data) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config gemm_args = self.gemm_args w_q, w_s, _, _ = self._get_weight_params(layer) reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) output = ops.allspark_w8a16_gemm( a=reshaped_x, @@ -104,11 +103,12 @@ class AllSparkLinearKernel(MPLinearKernel): b_qzeros=None, n=c.partition_weight_shape[1], group_size=c.group_size, - sm_count=gemm_args['sm_count'], - sm_version=gemm_args['sm_version'], + sm_count=gemm_args["sm_count"], + sm_version=gemm_args["sm_version"], CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp=c.zero_points, - n32k16_reorder=True) + n32k16_reorder=True, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index 0eca3b4c024e7..d1ff582c4e216 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -7,14 +7,19 @@ import torch from packaging import version from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, - MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, - check_bitblas_supports_shape, query_bitblas_supported_quant_types, - unpack_gptq_qweight, unpack_gptq_qzeros) + BITBLAS_OPTIMIZE_FEATURES, + BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, + bitblas_make_empty_g_idx, + bitblas_sort_g_idx, + check_bitblas_supports_shape, + query_bitblas_supported_quant_types, + unpack_gptq_qweight, + unpack_gptq_qzeros, +) from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -22,7 +27,6 @@ logger = init_logger(__name__) class BitBLASLinearKernel(MPLinearKernel): - OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES ENABLE_TUNING: bool = True MATMUL_LAYOUT: str = "nt" @@ -45,8 +49,9 @@ class BitBLASLinearKernel(MPLinearKernel): bitblas_quant_config: Optional[QuantizationConfig] = None, ): self.quant_config = bitblas_quant_config - super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, - w_gidx_param_name) + super().__init__( + c, w_q_param_name, w_s_param_name, w_zp_param_name, w_gidx_param_name + ) def repack_bitblas_from_gptq( self, @@ -55,19 +60,18 @@ class BitBLASLinearKernel(MPLinearKernel): qzeros: Optional[torch.Tensor] = None, ): from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" quant_config = self.quant_config # qweight in gptq old quant linear stored with # (outfeatures, infeatures), should be transposed. 
- qweight = b_q_weight.T.contiguous().view( - quant_config.torch_storage_dtype) # type: ignore[union-attr] - intweight = unpack_gptq_qweight( - qweight, - quant_config.weight_bits).contiguous() # type: ignore[union-attr] + qweight = b_q_weight.T.contiguous().view(quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight(qweight, quant_config.weight_bits).contiguous() # type: ignore[union-attr] if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] - intweight.cpu()).cuda() + intweight.cpu() + ).cuda() # scales in gptq old quant linear stored with # (infeatures // group_size, outfeatures), should be transposed. scales = scales.T.contiguous() @@ -91,9 +95,14 @@ class BitBLASLinearKernel(MPLinearKernel): general_compress( intzeros.T.contiguous().cpu().numpy(), weight_bits, - )).to(qweight.device). - to(quant_config.torch_storage_dtype # type: ignore[union-attr] - ).contiguous()) + ) + ) + .to(qweight.device) + .to( + quant_config.torch_storage_dtype # type: ignore[union-attr] + ) + .contiguous() + ) else: raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) @@ -104,41 +113,50 @@ class BitBLASLinearKernel(MPLinearKernel): return 70 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: is_bitblas_installed = True try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError: is_bitblas_installed = False if not is_bitblas_installed: - return False, "bitblas is not installed. Please install bitblas "\ - "by running `pip install bitblas>="\ - f"{MINIMUM_BITBLAS_VERSION}`" + return ( + False, + "bitblas is not installed. 
Please install bitblas " + "by running `pip install bitblas>=" + f"{MINIMUM_BITBLAS_VERSION}`", + ) quant_types = query_bitblas_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: - return False, (f"Quant type ({c.weight_type}) not supported by" - f" BitBLAS, supported types are: {quant_types}") + return False, ( + f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}" + ) if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: - return False, (f"Group size ({c.group_size}) not supported by " - "BitBLAS, supported group sizes are: " - f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + return False, ( + f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}" + ) return check_bitblas_supports_shape( c.partition_weight_shape[1], # out_features c.partition_weight_shape[0], # in_features c.full_weight_shape[0], # in_features - c.group_size) + c.group_size, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -150,14 +168,15 @@ class BitBLASLinearKernel(MPLinearKernel): # Default names since bitblas requires empty parameters for these, # TODO: remove this requirement from bitblas (allow optional tensors) - if self.w_gidx_name is None: - self.w_gidx_name = "g_idx" - if self.w_zp_name is None: - self.w_zp_name = "qzeros" + if getattr(self, "w_gidx_name", None) is None: + self.w_gidx_name: str = "g_idx" + if getattr(self, "w_zp_name", None) is None: + self.w_zp_name: str = "qzeros" if c.has_g_idx: g_idx, g_idx_sort_indices = bitblas_sort_g_idx( - getattr(layer, self.w_gidx_name)) + getattr(layer, self.w_gidx_name) + ) self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) layer.g_idx_sort_indices = g_idx_sort_indices else: @@ -170,13 +189,11 @@ class BitBLASLinearKernel(MPLinearKernel): setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) # Repack weights - bitblas_qweight, bitblas_scales, bitblas_qzeros = ( - self.repack_bitblas_from_gptq( - layer.qweight, - layer.scales, - None if quant_config.is_sym else # type: ignore[union-attr] - layer.qzeros, # type: ignore[union-attr] - )) + bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else layer.qzeros, # type: ignore[union-attr] + ) replace_parameter(layer, self.w_q_name, bitblas_qweight) replace_parameter(layer, self.w_s_name, bitblas_scales) if bitblas_qzeros is not None: @@ -213,6 +230,7 @@ class BitBLASLinearKernel(MPLinearKernel): bits, ): from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] quant_config = self.quant_config with_scaling = False @@ -249,30 +267,33 @@ class BitBLASLinearKernel(MPLinearKernel): zeros_mode=zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning) + matmul_config, enable_tuning + ) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul, auto_detect_nvidia_target from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: - global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, - BITBLAS_TARGET) + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = 
Matmul(config, - target=BITBLAS_TARGET, - enable_tuning=False) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) TUNING_MESSAGE = ( - f"BitBLAS Operator {config} tuned and saved to database.") + f"BitBLAS Operator {config} tuned and saved to database." + ) logger.info(TUNING_MESSAGE) else: _message = f"BitBLAS Operator {config} created without tuning. " @@ -288,7 +309,7 @@ class BitBLASLinearKernel(MPLinearKernel): x: torch.Tensor, ) -> torch.Tensor: output_size_per_partition = self.config.partition_weight_shape[1] - out_shape = x.shape[:-1] + (output_size_per_partition, ) + out_shape = x.shape[:-1] + (output_size_per_partition,) args = [x, layer.qweight, layer.scales] if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] args.append(layer.qzeros) @@ -298,5 +319,6 @@ class BitBLASLinearKernel(MPLinearKernel): def apply_weights(self, layer, x, bias=None): NOT_IMPLEMENT_MESSAGE = ( f"{self.__class__.__name__}.apply_weights is not implemented. " - "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead" + ) raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py index f80af548f0199..281fca7888ab3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py @@ -6,44 +6,49 @@ from typing import Final, Optional import torch -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig _CONCH_SUPPORTED_WEIGHT_TYPES: Final = [ - scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8, - scalar_types.uint8b128 + scalar_types.uint4, + scalar_types.uint8, + scalar_types.uint4b8, + scalar_types.uint8b128, ] _CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128] class ConchLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES: - error_msg = f"Weight type ({c.weight_type}) not supported by "\ - "ConchLinearKernel, supported types are: " \ - f"{_CONCH_SUPPORTED_WEIGHT_TYPES}" + error_msg = ( + f"Weight type ({c.weight_type}) not supported by " + "ConchLinearKernel, supported types are: " + f"{_CONCH_SUPPORTED_WEIGHT_TYPES}" + ) return False, error_msg if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES: - error_msg = f"Group size ({c.group_size}) not supported by "\ - "ConchLinearKernel, supported group sizes are: " \ - f"{_CONCH_SUPPORTED_GROUP_SIZES}" + error_msg = ( + f"Group size ({c.group_size}) not supported by " + "ConchLinearKernel, supported group sizes are: " + f"{_CONCH_SUPPORTED_GROUP_SIZES}" + ) return False, error_msg if find_spec("conch") is None: - error_msg = "conch-triton-kernels is not installed, 
please "\ - "install it via `pip install conch-triton-kernels` "\ - "and try again!" + error_msg = ( + "conch-triton-kernels is not installed, please " + "install it via `pip install conch-triton-kernels` " + "and try again!" + ) return False, error_msg return True, None @@ -52,7 +57,6 @@ class ConchLinearKernel(MPLinearKernel): # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) @@ -68,10 +72,12 @@ class ConchLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: from conch.ops.quantization.gemm import mixed_precision_gemm w_q, w_s, w_zp, _ = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 0000000000000..f5df7a244b426 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "CUTLASS W4A8, only supported int4", + ) + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return ( + False, + f"K and N must be divisible by 128, got {c.partition_weight_shape}", + ) + + if c.out_type != torch.bfloat16: 
+ return ( + False, + f"Only bfloat16 output type currently supported, got {c.out_type=}", + ) + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + self._transform_param(layer, "weight_chan_scale", lambda x: x) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + w_ch_s = layer.weight_chan_scale + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm( + a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=w_ch_s, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py index 7bd326f47f9e4..7631236e6f642 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py @@ -20,37 +20,45 @@ class Dynamic4bitLinearKernel(MPLinearKernel): return 1 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cpu(): return False, "Only CPU is supported" if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: return False, f"Unsupported quant type {c.weight_type}" - if current_platform.get_cpu_architecture( - ) == CpuArchEnum.ARM and c.act_type not in [ + if ( + current_platform.get_cpu_architecture() == CpuArchEnum.ARM + and c.act_type + not in [ torch.float32, - ]: - return False, "Dynamic4bitLinearKernel on Arm requires"\ - " Float32 activations" + ] + ): + return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations" if c.full_weight_shape[0] % c.group_size != 0: - return False, f"Group size ({c.group_size}) does not evenly divide"\ - " the number of input features "\ - f"({c.full_weight_shape[0]})" + return ( + False, + f"Group size ({c.group_size}) does not evenly divide" + " the number of input features " + f"({c.full_weight_shape[0]})", + ) if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: try: # Attempt to retrieve the operation _ = torch.ops.aten._dyn_quant_matmul_4bit except AttributeError: - return False, f"PyTorch {torch.__version__} does not support"\ - " _dyn_quant_matmul_4bit. 
Install a newer version" + return ( + False, + f"PyTorch {torch.__version__} does not support" + " _dyn_quant_matmul_4bit. Install a newer version", + ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config packed_weight = getattr(layer, self.w_q_name) packed_weight = packed_weight.add(8) - uint8_packed = (packed_weight[::, 1::2] << 4 - | packed_weight[::, ::2]).to(torch.uint8) + uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to( + torch.uint8 + ) scales = getattr(layer, self.w_s_name) block_size = c.group_size @@ -71,22 +79,34 @@ class Dynamic4bitLinearKernel(MPLinearKernel): # Repack weights as per kernel requirement w = torch.ops.aten._dyn_quant_pack_4bit_weight( - uint8_packed, scales, layer.bias, block_size, - c.partition_weight_shape[0], c.partition_weight_shape[1]) - replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(w, requires_grad=False)) + uint8_packed, + scales, + layer.bias, + block_size, + c.partition_weight_shape[0], + c.partition_weight_shape[1], + ) + replace_parameter( + layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False) + ) setattr(layer, self.w_s_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) w_q = getattr(layer, self.w_q_name) output = torch.ops.aten._dyn_quant_matmul_4bit( - x_2d, w_q, c.group_size, c.partition_weight_shape[0], - c.partition_weight_shape[1]) + x_2d, + w_q, + c.group_size, + c.partition_weight_shape[0], + c.partition_weight_shape[1], + ) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py index fef333e862d5a..a57d3f65267ec 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -7,9 +7,9 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_quantized_values_into_int32) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + pack_quantized_values_into_int32, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -25,31 +25,41 @@ class ExllamaLinearKernel(MPLinearKernel): return 60 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: - if c.has_g_idx and\ - c.partition_weight_shape[0] != c.full_weight_shape[0]: - return False, "Act reordering currently not supported by Exllama, "\ - "when the input features are partitioned across "\ - "devices" + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]: + return ( + False, + "Act reordering currently not supported by Exllama, " + "when the input features are partitioned across " + "devices", + ) if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: - return False, 
"Output features must be a multiple of the pack " \ - "factor (32 / num_bits) so that we can correctly " \ - "pack the zero points" + return ( + False, + "Output features must be a multiple of the pack " + "factor (32 / num_bits) so that we can correctly " + "pack the zero points", + ) if c.act_type != torch.float16: return False, "Exllama only supports float16 activations" if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: - return False, f"Quant type ({c.weight_type}) not supported by "\ - "Exllama, supported types are: "\ - f"{cls.SUPPORTED_QUANT_TYPES}" + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "Exllama, supported types are: " + f"{cls.SUPPORTED_QUANT_TYPES}", + ) if c.full_weight_shape[0] % c.group_size != 0: - return False, f"Group size ({c.group_size}) does not evenly divide"\ - " the number of input features "\ - f"({c.full_weight_shape[0]})" + return ( + False, + f"Group size ({c.group_size}) does not evenly divide" + " the number of input features " + f"({c.full_weight_shape[0]})", + ) return True, None @@ -70,21 +80,23 @@ class ExllamaLinearKernel(MPLinearKernel): # exllama kernel adding 1 to the zero points during inference) # Documentation of the bug can be found here: # https://garden.danieldk.eu/GPTQ-Checkpoint-Format - zeros = torch.full((groups, out_features), - c.weight_type.bias - 1, - dtype=torch.int32, - device=device) + zeros = torch.full( + (groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device, + ) else: raise NotImplementedError( "A 0 zero-point is not supported by Exllama due to " "a bug in the original GPTQ checkpoint format leading to " "exllama kernel adding 1 to the zero points during " - "inference") - zeros = pack_quantized_values_into_int32(zeros, - c.weight_type, - packed_dim=1) - setattr(layer, self.w_zp_name, - torch.nn.Parameter(zeros, requires_grad=False)) + "inference" + ) + zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1) + setattr( + layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False) + ) if c.has_g_idx: @@ -96,10 +108,9 @@ class ExllamaLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) else: self.w_gidx_name = "g_idx" - empty_g_idx = torch.nn.Parameter(torch.empty((0, ), - dtype=torch.int, - device=device), - requires_grad=False) + empty_g_idx = torch.nn.Parameter( + torch.empty((0,), dtype=torch.int, device=device), requires_grad=False + ) setattr(layer, self.w_gidx_name, empty_g_idx) def transform_w_q(x): @@ -122,21 +133,24 @@ class ExllamaLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) assert w_zp is not None, "Zero points are required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" - output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, - c.weight_type.size_bits) + output = ops.gptq_gemm( + x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits + ) if bias is not None: 
output.add_(bias) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index da951ddab2e4e..df2f8fedce7e7 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -8,26 +8,27 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.machete_utils import ( - check_machete_supports_shape, query_machete_supported_group_sizes, - query_machete_supported_quant_types) + check_machete_supports_shape, + query_machete_supported_group_sizes, + query_machete_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_quantized_values_into_int32, unpack_quantized_values_into_int32) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + pack_quantized_values_into_int32, + unpack_quantized_values_into_int32, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class MacheteLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: # Machete uses CUTLASS, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Machete only supported on CUDA" @@ -35,25 +36,33 @@ class MacheteLinearKernel(MPLinearKernel): if not current_platform.is_device_capability(90): return False, "Machete requires compute capability of 90 (Hopper)" - if c.has_g_idx and\ - c.partition_weight_shape[0] != c.full_weight_shape[0]: - return False, "Act reordering currently not supported by Machete, "\ - "when the input features are partitioned across "\ - "devices" + if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]: + return ( + False, + "Act reordering currently not supported by Machete, " + "when the input features are partitioned across " + "devices", + ) - if c.weight_type not in query_machete_supported_quant_types( - c.zero_points): - return False, f"Quant type ({c.weight_type}) not supported by "\ - "Machete, supported types are: "\ - f"{query_machete_supported_quant_types(c.zero_points)}" + if c.weight_type not in query_machete_supported_quant_types(c.zero_points): + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "Machete, supported types are: " + f"{query_machete_supported_quant_types(c.zero_points)}", + ) if c.group_size not in query_machete_supported_group_sizes(c.act_type): - return False, f"Group size ({c.group_size}) not supported by "\ - "Machete, supported group sizes are: "\ - f"{query_machete_supported_group_sizes(c.act_type)}" + return ( + False, + f"Group size ({c.group_size}) not supported by " + "Machete, supported group sizes are: " + f"{query_machete_supported_group_sizes(c.act_type)}", + ) - return check_machete_supports_shape(c.partition_weight_shape[0], - c.partition_weight_shape[1]) + return check_machete_supports_shape( + c.partition_weight_shape[0], c.partition_weight_shape[1] + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -64,30 +73,33 @@ class MacheteLinearKernel(MPLinearKernel): if 
c.has_g_idx: assert self.w_gidx_name is not None - perm = torch.argsort(getattr(layer, self.w_gidx_name))\ - .to(torch.int) + perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int) self.act_perm = lambda x: x[:, perm] # use `ops.permute_cols` if possible - if c.act_type in [torch.float16, torch.bfloat16] \ - and c.partition_weight_shape[0] % 8 == 0: + if ( + c.act_type in [torch.float16, torch.bfloat16] + and c.partition_weight_shape[0] % 8 == 0 + ): self.act_perm = partial(ops.permute_cols, perm=perm) def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) if c.has_g_idx: - x_unpacked = unpack_quantized_values_into_int32(x.data, - c.weight_type, - packed_dim=0) + x_unpacked = unpack_quantized_values_into_int32( + x.data, c.weight_type, packed_dim=0 + ) x_perm = x_unpacked[perm, :] - x.data = pack_quantized_values_into_int32(x_perm, - c.weight_type, - packed_dim=0) - x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), - a_type=c.act_type, - b_type=c.weight_type, - group_scales_type=c.act_type) + x.data = pack_quantized_values_into_int32( + x_perm, c.weight_type, packed_dim=0 + ) + x.data = ops.machete_prepack_B( + x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type, + ) return x def transform_w_s(x): @@ -99,9 +111,9 @@ class MacheteLinearKernel(MPLinearKernel): def transform_w_zp(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) - x_unpacked = unpack_quantized_values_into_int32(x.data, - c.weight_type, - packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32( + x.data, c.weight_type, packed_dim=1 + ) w_s = getattr(layer, self.w_s_name).data # pre-apply scales to zero-points x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() @@ -113,15 +125,17 @@ class MacheteLinearKernel(MPLinearKernel): if c.zero_points: self._transform_param(layer, self.w_zp_name, transform_w_zp) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, _ = self._get_weight_params(layer) x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) if c.has_g_idx: x_2d = self.act_perm(x_2d) @@ -131,12 +145,14 @@ class MacheteLinearKernel(MPLinearKernel): else: w_zp = None - output = ops.machete_mm(a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_group_zeros=w_zp, - b_group_scales=w_s, - b_group_size=c.group_size) + output = ops.machete_mm( + a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=w_zp, + b_group_scales=w_s, + b_group_size=c.group_size, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 5eb99383097b5..0be448e4e3d8a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -7,46 +7,58 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, - 
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, - marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, - unpack_cols) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + MARLIN_SUPPORTED_GROUP_SIZES, + apply_gptq_marlin_linear, + check_marlin_supports_shape, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + marlin_sort_g_idx, + marlin_zero_points, + query_marlin_supported_quant_types, + unpack_cols, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class MarlinLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: # Marlin uses inline PTX, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Marlin only supported on CUDA" quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: - return False, f"Quant type ({c.weight_type}) not supported by"\ - f" Marlin, supported types are: {quant_types}" + return ( + False, + f"Quant type ({c.weight_type}) not supported by" + f" Marlin, supported types are: {quant_types}", + ) if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return False, f"Group size ({c.group_size}) not supported by "\ - "Marlin, supported group sizes are: "\ - f"{MARLIN_SUPPORTED_GROUP_SIZES}" + return ( + False, + f"Group size ({c.group_size}) not supported by " + "Marlin, supported group sizes are: " + f"{MARLIN_SUPPORTED_GROUP_SIZES}", + ) return check_marlin_supports_shape( c.partition_weight_shape[1], # out_features c.partition_weight_shape[0], # in_features c.full_weight_shape[0], # in_features - c.group_size) + c.group_size, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -55,7 +67,7 @@ class MarlinLinearKernel(MPLinearKernel): device = getattr(layer, self.w_q_name).device c = self.config - row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) # Allocate marlin workspace. 
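# Illustrative sketch (not part of the patch): how row parallelism is inferred
# purely from the shapes carried in the config. The concrete numbers below are
# assumptions for the example (a 4096 x 11008 weight sharded over TP=2 along
# the input dimension K).
full_weight_shape = (4096, 11008)       # (in_features K, out_features N), unsharded
partition_weight_shape = (2048, 11008)  # this rank's shard: K is halved

row_parallel = partition_weight_shape[0] != full_weight_shape[0]
assert row_parallel
# With act-order (g_idx) enabled, a row-parallel shard means each rank only
# sees part of K, which is what marlin_is_k_full(has_g_idx, row_parallel)
# encodes for the kernel.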
@@ -71,25 +83,30 @@ class MarlinLinearKernel(MPLinearKernel): def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - x.data = ops.gptq_marlin_repack(x.data.contiguous(), - perm=layer.g_idx_sort_indices, - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits) + x.data = ops.gptq_marlin_repack( + x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ) return x def transform_w_s(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1) - x.data = marlin_permute_scales(x.data.contiguous(), - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - group_size=c.group_size) + x.data = marlin_permute_scales( + x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size, + ) return x if c.has_g_idx: g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name)) + getattr(layer, self.w_gidx_name) + ) self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) layer.g_idx_sort_indices = g_idx_sort_indices else: @@ -97,16 +114,24 @@ class MarlinLinearKernel(MPLinearKernel): layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) if c.zero_points: - grouped_k = (c.partition_weight_shape[0] // - c.group_size if c.group_size != -1 else 1) - self._transform_param(layer, self.w_zp_name, lambda x: \ - marlin_zero_points( - unpack_cols(x.t(), c.weight_type.size_bits, - grouped_k, - c.partition_weight_shape[1]), + grouped_k = ( + c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + ) + self._transform_param( + layer, + self.w_zp_name, + lambda x: marlin_zero_points( + unpack_cols( + x.t(), + c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1], + ), size_k=grouped_k, size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits)) + num_bits=c.weight_type.size_bits, + ), + ) else: setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) self._transform_param(layer, self.w_q_name, transform_w_q) @@ -115,10 +140,12 @@ class MarlinLinearKernel(MPLinearKernel): if hasattr(layer, "bias") and layer.bias is not None: layer.bias.data = marlin_permute_bias(layer.bias) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) @@ -136,4 +163,5 @@ class MarlinLinearKernel(MPLinearKernel): input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], is_k_full=self.is_k_full, - bias=bias) + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 9ebf5f3037922..d9b999e3d5ddc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -16,7 +16,6 @@ class ScaledMMLinearLayerConfig: class ScaledMMLinearKernel(ABC): - @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -24,13 +23,18 @@ class 
ScaledMMLinearKernel(ABC): @classmethod @abstractmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError - def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, - w_s_param_name: str, i_s_param_name: str, - i_zp_param_name: str, azp_adj_param_name: str) -> None: + def __init__( + self, + c: ScaledMMLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + i_s_param_name: str, + i_zp_param_name: str, + azp_adj_param_name: str, + ) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -44,20 +48,23 @@ class ScaledMMLinearKernel(ABC): raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError def _get_weight_params( - self, layer: torch.nn.Module) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - Optional[torch.Tensor], # input_scale, - Optional[torch.Tensor], # input_zp - Optional[torch.Tensor], # azp_adj - ]: + self, layer: torch.nn.Module + ) -> tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: return ( getattr(layer, self.w_q_name), getattr(layer, self.w_s_name), diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2bc68ab3ebd18..ee5416bae01c6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -5,17 +5,24 @@ import os from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( - AiterScaledMMLinearKernel) + AiterScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( - CPUScaledMMLinearKernel) + CPUScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( - CutlassScaledMMLinearKernel) + CutlassScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearKernel, ScaledMMLinearLayerConfig) + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( - TritonScaledMMLinearKernel) + TritonScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( - XLAScaledMMLinearKernel) + XLAScaledMMLinearKernel, +) from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) @@ -28,19 +35,18 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, - compute_capability: Optional[int] = None + config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None ) -> type[ScaledMMLinearKernel]: """ - Choose an ScaledMMLinearKernel that can implement the given config for the - given compute capability. 
Attempts to choose the best kernel in terms of + Choose an ScaledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (ScaledMMLinearLayerConfig): Description of the linear layer to be implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the + the target device, if None uses `current_platform` to get the compute capability. Defaults to None. Raises: @@ -57,22 +63,25 @@ def choose_scaled_mm_linear_kernel( failure_reasons = [] for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') + f" {kernel.__name__} disabled by environment variable" + ) continue # If the current platform uses compute_capability, # make sure the kernel supports the compute cability. if compute_capability is not None: kernel_min_capability = kernel.get_min_capability() - if (kernel_min_capability is not None - and kernel_min_capability > compute_capability): + if ( + kernel_min_capability is not None + and kernel_min_capability > compute_capability + ): failure_reasons.append( f"{kernel.__name__} requires capability " f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}") + f"is {compute_capability}" + ) continue can_implement, failure_reason = kernel.can_implement(config) @@ -80,10 +89,10 @@ def choose_scaled_mm_linear_kernel( return kernel else: failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' + f" {kernel.__name__} cannot implement due to: {failure_reason}" ) raise ValueError( - "Failed to find a kernel that can implement the "\ - "ScaledMM linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) + "Failed to find a kernel that can implement the " + "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reasons) + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 7f808fa92a9a8..e97beefdd9c2c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -22,7 +22,6 @@ def rocm_aiter_gemm_w8a8_impl( bias: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter import gemm_a8w8_CK # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects @@ -40,7 +39,6 @@ def rocm_aiter_gemm_w8a8_fake( bias: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) @@ -51,57 +49,58 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_gemm_w8a8", op_func=rocm_aiter_gemm_w8a8_impl, - mutates_args=[], fake_impl=rocm_aiter_gemm_w8a8_fake, - dispatch_key=current_platform.dispatch_key, ) class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_rocm(): return ( False, - "AiterScaledMMLinearKernel requires `aiter` which is not " + - "currently supported on non-ROCm platform.") + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "currently supported on non-ROCm platform.", + ) try: import aiter # noqa: F401 # deliberately attempt to import aiter except Exception: return ( False, - "AiterScaledMMLinearKernel requires `aiter` which is not " + - "installed on ROCm.") + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "installed on ROCm.", + ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not ( - envs.VLLM_ROCM_USE_AITER_LINEAR \ - and envs.VLLM_ROCM_USE_AITER - ): - return (False, "AiterScaledMMLinearKernel is disabled. " + - "Enable by setting `VLLM_ROCM_USE_AITER=1` " + - "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + - "`VLLM_ROCM_USE_AITER_LINEAR` default is True.") + if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + return ( + False, + "AiterScaledMMLinearKernel is disabled. " + + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + ) if not c.input_symmetric: - return (False, - "AiterScaledMMLinearKernel only supports symmetric " + - "quantization.") + return ( + False, + "AiterScaledMMLinearKernel only supports symmetric " + "quantization.", + ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ `AiterScaledMMLinearKernel` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` @@ -118,29 +117,27 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. 
symmetric = azp_adj is None - assert symmetric, ("AiterScaledMMLinearKernel only supports" - " symmetric quantization.") - x_q, x_s, x_zp = ops.scaled_int8_quant(x, - i_s, - i_zp, - symmetric=symmetric) + assert symmetric, ( + "AiterScaledMMLinearKernel only supports symmetric quantization." + ) + x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric) - assert x_zp is None, ("AiterScaledMMLinearKernel only supports" - " symmetric quantization.") + assert x_zp is None, ( + "AiterScaledMMLinearKernel only supports symmetric quantization." + ) out_dtype = x.dtype - assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == w_q.shape[ - 1] and bias.dtype == out_dtype + assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype m = x_q.shape[0] # a n = w_q.shape[1] # b - per_tensor_scale_a = (x_s.numel() == 1) - per_tensor_scale_b = (w_s.numel() == 1) - per_token_scale_a = (x_s.numel() == m) - per_channel_scale_b = (w_s.numel() == n) + per_tensor_scale_a = x_s.numel() == 1 + per_tensor_scale_b = w_s.numel() == 1 + per_token_scale_a = x_s.numel() == m + per_channel_scale_b = w_s.numel() == n # @TODO: # Maybe broadcast the per-tensor-scale into per-channel-scale @@ -148,16 +145,19 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): # For now, it only supports: # - per-tensor-per-tensor a8w8 scaled GEMM, and # - per-token-per-channel a8w8 scaled GEMM - assert ((per_tensor_scale_a and per_tensor_scale_b) - or (per_token_scale_a and per_channel_scale_b)), ( - "Currently only support per-tensor-per-tensor GEMM " + - " and per-token-per-channel GEMM through AITER" - " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + - "does not support AITER block scaled GEMM.") + assert (per_tensor_scale_a and per_tensor_scale_b) or ( + per_token_scale_a and per_channel_scale_b + ), ( + "Currently only support per-tensor-per-tensor GEMM " + + " and per-token-per-channel GEMM through AITER" + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + + "does not support AITER block scaled GEMM." 
+ ) # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s, - bias, out_dtype) + return torch.ops.vllm.rocm_aiter_gemm_w8a8( + x_q, w_q.t(), x_s, w_s, bias, out_dtype + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 59d2b5bce962e..cb00b0c8af210 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -9,24 +9,22 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.model_executor.layers.utils import check_cpu_sgl_kernel from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class CPUScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cpu(): return False, "CPUScaledMM requires running on CPU." @@ -36,9 +34,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): weight = getattr(layer, self.w_q_name) dtype = weight.dtype N, K = weight.size() - if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 - and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric - and check_cpu_sgl_kernel(N, K, dtype)): + if ( + current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL + and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype) + ): self.linear_method = self._apply_weights_sgl self.process_weights_for_sgl(layer) else: @@ -50,8 +51,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # Transpose to [K, N] for convenience weight = getattr(layer, self.w_q_name) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) + layer, + self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) # WEIGHT SCALE # oneDNN kernels support only per-tensor and per-channel. 
@@ -60,11 +63,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # INPUT SCALE if self.config.is_static_input_scheme: @@ -72,8 +76,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): if self.config.input_symmetric: replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) + layer, + self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) setattr(layer, self.i_zp_name, None) else: input_zero_point = getattr(layer, self.i_zp_name) @@ -84,16 +90,17 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): range_max = (input_scale * (int8_traits.max - azps)).max() range_min = (input_scale * (int8_traits.min - azps)).min() - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(scale, requires_grad=False)) + layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) - azp = (int8_traits.min - - range_min / scale).round().to(dtype=torch.int32) - replace_parameter(layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) + azp = ( + (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) + ) + replace_parameter( + layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) else: setattr(layer, self.i_s_name, None) @@ -105,14 +112,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): # s_a * s_b * [(A - zp_a)B] + bias = # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias - if not (self.config.input_symmetric - and self.config.is_static_input_scheme): + if not (self.config.input_symmetric and self.config.is_static_input_scheme): weight = getattr(layer, self.w_q_name) weight_scale = getattr(layer, self.w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() - setattr(layer, self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False)) + setattr( + layer, + self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) else: setattr(layer, self.azp_adj_name, None) @@ -135,34 +144,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): weight = getattr(layer, self.w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(packed_weight, requires_grad=False)) + layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + ) if layer.bias is not None: bias = layer.bias layer.register_parameter( - "bias_fp32", - torch.nn.Parameter(bias.float().data, requires_grad=False)) + "bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False) + ) # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
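# Illustrative sketch (not part of the patch): how several per-partition static
# (input_scale, azp) pairs of a fused module are folded into one per-tensor
# pair, mirroring the range_max / range_min arithmetic above. Values are made
# up for the example.
import torch

int8_min, int8_max = -128, 127
input_scale = torch.tensor([0.02, 0.05])         # one scale per fused partition
azps = torch.tensor([3, -7], dtype=torch.int32)  # matching zero points

range_max = (input_scale * (int8_max - azps)).max()
range_min = (input_scale * (int8_min - azps)).min()
scale = (range_max - range_min) / (int8_max - int8_min)       # fused scale
azp = (int8_min - range_min / scale).round().to(torch.int32)  # fused zero point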
weight_scale = getattr(layer, self.w_s_name) if not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) setattr(layer, self.i_s_name, None) setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return self.linear_method( layer, x, @@ -170,31 +182,33 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ) def _apply_weights_onednn( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( - x, i_s, i_zp, self.config.input_symmetric) + x, i_s, i_zp, self.config.input_symmetric + ) m = x.size(0) n = self.dnnl_handler.n out = torch.empty((m, n), dtype=x.dtype) - ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, - bias) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias) return out def _apply_weights_sgl( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2f982f96b0d04..13dbd55c32df9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -8,23 +8,20 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cuda(): return False, "CutlassScaledMM requires running on CUDA." @@ -35,8 +32,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): # Cutlass kernels need transposed weight. 
weight = getattr(layer, self.w_q_name) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) + layer, + self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) # WEIGHT SCALE # Cutlass kernels support only per-tensor and per-channel. @@ -45,11 +44,12 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # INPUT SCALE if self.config.is_static_input_scheme: @@ -57,8 +57,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): if self.config.input_symmetric: replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) + layer, + self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) setattr(layer, self.i_zp_name, None) else: input_zero_point = getattr(layer, self.i_zp_name) @@ -69,17 +71,16 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): range_max = (input_scale * (int8_traits.max - azps)).max() range_min = (input_scale * (int8_traits.min - azps)).min() - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(scale, requires_grad=False)) + layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - - range_min / scale).to(dtype=torch.int32) - replace_parameter(layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) else: setattr(layer, self.i_s_name, None) @@ -88,8 +89,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
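# Illustrative sketch (not part of the patch): the identity behind azp_adj.
# With asymmetric activation quantization, (A_q - azp) @ B expands to
# A_q @ B - azp * colsum(B), and the column-sum term depends only on the
# weight, so it can be precomputed once. Toy values, float for simplicity.
import torch

A_q = torch.randint(-128, 128, (4, 8)).float()  # quantized activations [M, K]
B = torch.randint(-128, 128, (8, 3)).float()    # quantized weight [K, N]
azp = 5.0                                       # per-tensor activation zero point

azp_adj = B.sum(dim=0, keepdim=True)            # [1, N], weight-only term
assert torch.allclose((A_q - azp) @ B, A_q @ B - azp * azp_adj)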
- # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md if not self.config.input_symmetric: weight = getattr(layer, self.w_q_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) @@ -97,41 +98,44 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr(layer, self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False)) + setattr( + layer, + self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) else: setattr(layer, self.azp_adj_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. symmetric = azp_adj is None - x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), - i_s, - i_zp, - symmetric=symmetric) + x_q, x_s, x_zp = ops.scaled_int8_quant( + x.contiguous(), i_s, i_zp, symmetric=symmetric + ) if x_zp is not None: # Currently, static is always per-tensor and dynamic is per-token static = i_zp is not None azp = None if static else x_zp - return ops.cutlass_scaled_mm_azp(x_q, - w_q, - scale_a=x_s, - scale_b=w_s, - out_dtype=x.dtype, - azp_adj=azp_adj, - azp=azp, - bias=bias) - return ops.cutlass_scaled_mm(x_q, - w_q, - scale_a=x_s, - scale_b=w_s, - out_dtype=x.dtype, - bias=bias) + return ops.cutlass_scaled_mm_azp( + x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias, + ) + return ops.cutlass_scaled_mm( + x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 817565cf28277..7e21afca5750c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -12,30 +12,32 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if current_platform.is_cpu(): return ( False, - "TritonScaledMMLinearKernel requires Triton which is not " + - "currently supported on CPU.") + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.", + ) if not c.input_symmetric: - return (False, - "TritonScaledMMLinearKernel only supports symmetric " + - "quantization.") + return ( + False, + "TritonScaledMMLinearKernel only supports symmetric " + "quantization.", + ) return True, None def process_weights_after_loading(self, layer: 
torch.nn.Module) -> None: super().process_weights_after_loading(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return super().apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 0b931b2d8b815..63eee1e288618 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -9,25 +9,23 @@ from functorch.experimental.control_flow import cond # noqa: F401 from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class XLAScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: raise NotImplementedError( "TPU platform does have a concept of compute capability, " - "this method should not be called.") + "this method should not be called." + ) @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_tpu(): return False, "ScaledMMXLA requires running on TPU." @@ -46,8 +44,9 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): # WEIGHT # [out, in] (different than cutlass_scaled_mm) weight = getattr(layer, self.w_q_name) - replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(weight.data, requires_grad=False)) + replace_parameter( + layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) + ) # WEIGHT SCALE # XLA kernels support only per-tensor and per-channel. @@ -56,14 +55,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) # [out_channel,] (different than cutlass_scaled_mm) weight_scale = weight_scale.squeeze(-1) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # Only support symmetric dynamic activation quantization. setattr(layer, self.i_s_name, None) @@ -74,8 +74,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): # to specialize the graph since bias is not dynamic. warnings.filterwarnings( "ignore", - message= - "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + message="Pred is a Python constant. 
When used with torch.cond, it specializes on one of the branches.", # noqa: E501 ) def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): @@ -84,14 +83,17 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel): def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): return x + bias - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) # Required to register custom ops. import torch_xla.experimental.custom_kernel # noqa: F401 + out = torch.ops.xla.quantized_matmul_int8( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index e5604670fb4c1..78456dcf1ca56 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -5,7 +5,9 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -14,12 +16,12 @@ logger = init_logger(__name__) class BaseKVCacheMethod(QuantizeMethodBase): """ Quant method that adds `_k_scale` and `_v_scale` attributes to the - Attention layer to support loading those scaling factors from checkpoints. + Attention layer to support loading those scaling factors from checkpoints. The k/v_scale will be used to: - quantize k/v_cache entries before saving them to the cache - dequantize k/v_cache entries before fetching them from the cache - :param quant_config: the appropriate QuantizationConfig + :param quant_config: the appropriate QuantizationConfig """ def __init__(self, quant_config: QuantizationConfig): @@ -33,19 +35,14 @@ class BaseKVCacheMethod(QuantizeMethodBase): # Initialize the Q and KV cache scales to -1.0, an invalid value. # If the q and k/v_scales appear in the checkpoint, it will be # overwritten when loading weights. 
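# Illustrative sketch (not part of the patch): what a per-tensor k_scale does
# once loaded. The names and the naive max-based scale here are assumptions
# for the example; real checkpoints ship calibrated scales.
import torch

k = torch.randn(2, 4, dtype=torch.float16)
k_scale = k.float().abs().max() / torch.finfo(torch.float8_e4m3fn).max

k_cache = (k.float() / k_scale).to(torch.float8_e4m3fn)  # quantize before caching
k_restored = k_cache.float() * k_scale                   # dequantize when fetched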
- layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) - layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) - layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) # Initialize P = softmax(QK^T) scales - layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: - raise RuntimeError( - f"{self.__class__.__name__}.apply should not be called.") + raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 @@ -77,29 +74,31 @@ class BaseKVCacheMethod(QuantizeMethodBase): k_scale *= 2 v_scale *= 2 - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError( + "Only support per-tensor scaling factor for fp8 KV cache" + ) if layer.q_scale < 0.0: logger.warning_once( "Checkpoint does not provide a q scaling factor. " "Setting it to k_scale. This only matters for " - "the flash-attn backend.") + "FP8 Attention backends (flash-attn or flashinfer)." + ) layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) layer._k_scale_float = k_scale layer._v_scale_float = v_scale - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): + if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: logger.warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + "Using KV cache scaling factor 1.0 for fp8_e4m3. " + "If this is unintended, verify that k/v_scale " + "scaling factors are properly set in the checkpoint." 
+ ) if layer.q_scale > 0.0: q_scale = layer.q_scale @@ -115,23 +114,31 @@ class BaseKVCacheMethod(QuantizeMethodBase): else: prob_scale = 1.0 - is_singleton_float = lambda x: isinstance(x, float) or isinstance( - x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() - if not is_singleton_float(q_scale) or not is_singleton_float( - prob_scale): - raise ValueError("Only support per-tensor scaling factor" - "for fp8-quantized Q/prob") + is_singleton_float = ( + lambda x: isinstance(x, float) + or isinstance(x, torch.Tensor) + and x.numel() == 1 + and x.is_floating_point() + ) + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError( + "Only support per-tensor scaling factorfor fp8-quantized Q/prob" + ) # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) + layer._q_scale_float = ( + q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + ) + layer._prob_scale.copy_(prob_scale) - if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 - or prob_scale == 1.0): + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): logger.warning_once( f"Using uncalibrated q_scale {q_scale} and/or prob_scale " f"{prob_scale} with fp8 attention. This may cause accuracy " "issues. Please make sure q/prob scaling factors are " - "available in the fp8 checkpoint.") + "available in the fp8 checkpoint." + ) del layer.k_scale del layer.v_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 046234057f04a..c285b10720d86 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch.nn import Module @@ -11,39 +11,74 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe) + is_valid_flashinfer_cutlass_fused_moe, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, - select_nvfp4_gemm_impl) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, + 
reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + FlashinferMoeBackend, + apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - select_cutlass_fp8_gemm_impl, swap_w13_to_w31) + flashinfer_cutlass_moe_fp8, + get_flashinfer_moe_backend, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, + swap_w13_to_w31, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, is_fp4_marlin_supported, - prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) + apply_fp4_marlin_linear, + is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale) + GroupShape, + cutlass_fp4_supported, + is_layer_skipped, + swizzle_blockscale, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, requantize_with_max_scale) -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types from vllm.utils import next_power_of_2 -from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, - has_flashinfer_moe) +from vllm.utils.flashinfer import ( + flashinfer_scaled_fp4_mm, + has_flashinfer, + has_flashinfer_moe, +) + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -63,10 +98,12 @@ class ModelOptFp8Config(QuantizationConfig): super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method - self.exclude_modules = exclude_modules + self.exclude_modules = exclude_modules or [] if is_checkpoint_fp8_serialized: - logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" - " the format is experimental and could change.") + logger.warning( + "Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change." 
+ ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -84,9 +121,14 @@ class ModelOptFp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) + @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """Detect if this ModelOpt config should be used based on quantization config.""" @@ -122,34 +164,36 @@ class ModelOptFp8Config(QuantizationConfig): # ModelOpt format: {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): - raise ValueError( - "Expected 'quantization' to be a dictionary in config") + raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") + # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules") else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo", "") kv_cache_quant_method = config.get("kv_cache_quant_algo") - exclude_modules = config.get("exclude_modules") + # "ignore" is the key in config.json + exclude_modules = config.get("ignore") if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " - "quant configuration.") - is_checkpoint_fp8_serialized = ("FP8" in quant_method) + "quant configuration." + ) + is_checkpoint_fp8_serialized = "FP8" in quant_method - return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, - exclude_modules) + return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules) def is_layer_excluded(self, prefix: str) -> bool: """ Check if a layer should be excluded from quantization. + Handles both exact matching (for fused layers) and substring matching. This method handles both regular models and multimodal models that use the language_model prefix. 
For multimodal models, it checks if the @@ -158,20 +202,34 @@ class ModelOptFp8Config(QuantizationConfig): if self.exclude_modules is None: return False - # Check if any excluded module matches the prefix + # First check exact matching with fused layer support + if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): + return True + + # Then check substring matching for patterns not caught by exact match for module in self.exclude_modules: - if (module in prefix - or (prefix.startswith("language_model.") - and module in prefix.removeprefix("language_model."))): + # Skip exact matches already handled above + if module != prefix and ( + module in prefix + or ( + prefix.startswith("language_model.") + and module in prefix.removeprefix("language_model.") + ) + ): return True return False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): if self.is_layer_excluded(prefix): return UnquantizedLinearMethod() + # Check if this is a vision model layer that should not be quantized + if "vision_tower" in prefix or "vision_model" in prefix: + return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) @@ -195,7 +253,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + ) def create_weights( self, @@ -213,29 +272,34 @@ class ModelOptFp8LinearMethod(LinearMethodBase): layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) scale[:] = torch.finfo(torch.float32).min 
layer.register_parameter("input_scale", scale) @@ -245,11 +309,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase): max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): max_w_scale, weight = requantize_with_max_scale( - layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight, layer.weight_scale, layer.logical_widths + ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, @@ -257,11 +321,13 @@ class ModelOptFp8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias) + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) class ModelOptFp8MoEMethod(FusedMoEMethodBase): @@ -281,11 +347,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported) + cutlass_fp8_supported, + ) + self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( @@ -294,27 +360,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def maybe_make_prepare_finalize( self, - moe: FusedMoEConfig, ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.fused_experts is not None or \ - self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize + # TRT LLM not supported with all2all yet. 
+ if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -328,18 +395,21 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - # Use FP8 dtype if checkpoint is serialized - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) weight_loader = extra_weight_attrs.get("weight_loader") w13_weight = ModelWeightParameter( - data=torch.empty(num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=weight_dtype), + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype, + ), input_dim=2, output_dim=1, weight_loader=weight_loader, @@ -347,10 +417,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.register_parameter("w13_weight", w13_weight) w2_weight = ModelWeightParameter( - data=torch.empty(num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=weight_dtype), + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype, + ), input_dim=2, output_dim=1, weight_loader=weight_loader, @@ -370,7 +442,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): weight_loader=weight_loader, ) w2_weight_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) @@ -378,15 +450,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): # Set weight loader attributes for scales extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # INPUT SCALES - Per-tensor scaling for ModelOpt w13_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) w2_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) @@ -397,22 +470,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): Only supports pre-quantized checkpoints with FP8 weights and scales. 
""" - layer.w13_weight = Parameter(layer.w13_weight.data, - requires_grad=False) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) from vllm._custom_ops import scaled_fp8_quant from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - per_tensor_dequantize) + per_tensor_dequantize, + ) # Handle scale parameters - if hasattr(layer, - "w13_weight_scale") and layer.w13_weight_scale is not None: + if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. if layer.w13_weight_scale.dim() == 2: - # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values @@ -425,48 +496,62 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): for shard_id in range(2): # w1 and w3 # Dequantize using the original scale for this shard dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - intermediate_size, :], + layer.w13_weight[expert_id][ + start : start + intermediate_size, : + ], layer.w13_weight_scale[expert_id][shard_id], ) # Requantize using the combined max scale ( - layer.w13_weight[expert_id][start:start + - intermediate_size, :], + layer.w13_weight[expert_id][ + start : start + intermediate_size, : + ], _, - ) = scaled_fp8_quant(dq_weight, - max_w13_scales[expert_id]) + ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) start += intermediate_size # Update the scale parameter to be per-expert - layer.w13_weight_scale = Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) else: - layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, - requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) - if hasattr(layer, - "w2_weight_scale") and layer.w2_weight_scale is not None: - layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, - requires_grad=False) + if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None: + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) # Input scales must be equal for each expert in fp8 MoE layers. 
- if hasattr(layer, - "w13_input_scale") and layer.w13_input_scale is not None: - layer.w13_input_scale = Parameter(layer.w13_input_scale.max(), - requires_grad=False) - if hasattr(layer, - "w2_input_scale") and layer.w2_input_scale is not None: - layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), - requires_grad=False) + if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None: + layer.w13_input_scale = Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None: + layer.w2_input_scale = Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) register_moe_scaling_factors(layer) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=False, + ) def apply( self, @@ -482,6 +567,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -489,14 +575,17 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ModelOptFp8MoEMethod` yet.") + "EPLB not supported for `ModelOptFp8MoEMethod` yet." + ) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") + assert self.fused_experts is None + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -507,10 +596,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) # Expert selection - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -520,59 +610,62 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # + # Note: the order here is important. 
self.fused_experts can override + # cutlass or fused_experts. + # + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - per_channel_quant=False, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts + + assert self.moe_quant_config is not None + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) class ModelOptNvFp4Config(QuantizationConfig): @@ -590,7 +683,8 @@ class ModelOptNvFp4Config(QuantizationConfig): if is_checkpoint_nvfp4_serialized: logger.warning( "Detected ModelOpt NVFP4 checkpoint. Please note that" - " the format is experimental and could change in future.") + " the format is experimental and could change in future." 
+ ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo @@ -612,9 +706,14 @@ class ModelOptNvFp4Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) + @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """Detect if this ModelOpt FP4 config should be used based on quantization config.""" if hf_quant_cfg is None: @@ -652,8 +751,7 @@ class ModelOptNvFp4Config(QuantizationConfig): # {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): - raise ValueError( - "Expected 'quantization' to be a dictionary in config") + raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: @@ -667,8 +765,10 @@ class ModelOptNvFp4Config(QuantizationConfig): elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: - raise ValueError(f"kv_cache_quant_algo must be a string, got " - f"{type(kv_cache_quant_algo_raw)}") + raise ValueError( + f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}" + ) # Handle group_size with proper type validation group_size_raw = quant_config.get("group_size") @@ -680,13 +780,16 @@ class ModelOptNvFp4Config(QuantizationConfig): try: group_size = int(group_size_raw) except (ValueError, TypeError): - raise ValueError(f"group_size must be an integer, got " - f"{type(group_size_raw)}") from None + raise ValueError( + f"group_size must be an integer, got {type(group_size_raw)}" + ) from None + # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) if not isinstance(exclude_modules, list): - raise ValueError(f"exclude_modules must be a list, got " - f"{type(exclude_modules)}") + raise ValueError( + f"exclude_modules must be a list, got {type(exclude_modules)}" + ) else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} @@ -700,8 +803,10 @@ class ModelOptNvFp4Config(QuantizationConfig): elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: - raise ValueError(f"kv_cache_quant_algo must be a string, got " - f"{type(kv_cache_quant_algo_raw)}") + raise ValueError( + f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}" + ) # Handle group_size with proper type validation group_size_raw = config.get("group_size") @@ -713,60 +818,85 @@ class ModelOptNvFp4Config(QuantizationConfig): try: group_size = int(group_size_raw) except (ValueError, TypeError): - raise ValueError(f"group_size must be an integer, got " - f"{type(group_size_raw)}") from None + raise ValueError( + f"group_size must be an integer, got {type(group_size_raw)}" + ) from None - exclude_modules = config.get("exclude_modules", []) + # "ignore" is the key in config.json + exclude_modules = config.get("ignore", []) if not isinstance(exclude_modules, list): - raise ValueError(f"exclude_modules must be a list, got " - f"{type(exclude_modules)}") + raise ValueError( + f"exclude_modules must be a list, got {type(exclude_modules)}" + ) if 
quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " - "quant configuration.") - is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) + "quant configuration." + ) + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method # For FP4, these fields are required if is_checkpoint_nvfp4_serialized and "quantization" in config: # Check if required fields are present in the quantization config quant_config = config["quantization"] - required_fields = [ - "group_size", "kv_cache_quant_algo", "exclude_modules" - ] + required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"NVFP4 quantization requires the following fields in " - f"hf_quant_config.json: {missing_fields}") + f"hf_quant_config.json: {missing_fields}" + ) - return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, - exclude_modules, group_size) + return cls( + is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo, + exclude_modules, + group_size, + ) - def is_layer_excluded(self, prefix: str, - exclude_modules: list[str]) -> bool: + def is_layer_excluded(self, prefix: str) -> bool: + """ + Check if a layer should be excluded from quantization. + Handles both exact matching (for fused layers) and pattern matching. + """ + # First check exact matching with fused layer support + if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): + return True + + # Check regex pattern matching for patterns not caught by exact match import regex as re - for pattern in exclude_modules: - regex_str = pattern.replace('.', r'\.').replace('*', r'.*') - if re.fullmatch(regex_str, prefix): - return True + + for pattern in self.exclude_modules: + # Skip patterns that would be caught by exact matching + if "*" in pattern or "." in pattern: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True return False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + + skip_layer = self.is_layer_excluded(prefix) if isinstance(layer, LinearBase): - if (is_layer_skipped(prefix, self.exclude_modules) - or self.is_layer_excluded(prefix, self.exclude_modules)): + if skip_layer: + return UnquantizedLinearMethod() + # Check if this is a vision model layer that should not be quantized + if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): + if skip_layer: + return None return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer) return None @@ -776,8 +906,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): Supports loading kv-cache scaling factors from FP8 checkpoints. 
""" - def __init__(self, quant_config: Union[ModelOptFp8Config, - ModelOptNvFp4Config]): + def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]): super().__init__(quant_config) @@ -805,9 +934,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): elif is_fp4_marlin_supported(): self.backend = "marlin" else: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above.") + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) def create_weights( self, @@ -821,59 +952,69 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." + ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - if (input_size_per_partition % 16 != 0): - raise ValueError("Unsupported model when in features size is " - "not multiple of 16") + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model when in features size is not multiple of 16" + ) # The nvfp4 weight is still represented as - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_nvfp4_serialized - else params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) # Weight weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension layer.output_size_per_partition, layer.input_size_per_partition // 2, - dtype=torch.uint8), + dtype=torch.uint8, + ), input_dim=1, output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # Input Weight Scale - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) # Global Weight Scale - weight_scale_2 = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale_2", weight_scale_2) # Per Block Weight Scale - weight_scale = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition // self.quant_config.group_size, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: Module) -> None: - # global scales: input_scale_2 = 
layer.input_scale.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) @@ -881,16 +1022,27 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) - layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, - requires_grad=False) + layer.alpha = Parameter( + layer.input_scale * layer.weight_scale_2, requires_grad=False + ) + + # Calculate `1 / input_scale` so that we don't need to do so at runtime + layer.input_scale_inv = Parameter( + (1 / layer.input_scale).to(torch.float32), requires_grad=False + ) # Swizzle the weight blockscale. # contracting dimension is input dimension # block_size = 16; - assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Block scale must be represented as FP8-E4M3") + assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Block scale must be represented as FP8-E4M3" + ) - if self.backend == "flashinfer-trtllm": + if self.backend == "marlin": + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + elif self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call @@ -901,27 +1053,20 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): weight_scale = layer.weight_scale.data epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), - epilogue_tile_m) - weight_scale = (shuffle_matrix_sf_a(weight_scale.view( - torch.uint8), epilogue_tile_m).reshape( - weight_scale.shape).view(torch.float8_e4m3fn)) + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) - layer.weight_scale_swizzled = Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) - if self.backend == "marlin": - prepare_fp4_layer_for_marlin(layer) - del layer.alpha - del layer.input_scale - del layer.weight_scale_swizzled - def apply( self, layer: torch.nn.Module, @@ -937,28 +1082,28 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) - s_quant = 1 / layer.input_scale - x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale - assert (x_fp4.dtype == torch.uint8) - assert (layer.weight.dtype == torch.uint8) - assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) - assert (layer.alpha.dtype == torch.float32) + assert x_fp4.dtype == torch.uint8 + assert 
layer.weight.dtype == torch.uint8 + assert x_blockscale.dtype == torch.float8_e4m3fn + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 mm_args = ( x_fp4, layer.weight, x_blockscale, - layer.weight_scale_swizzled, + layer.weight_scale, layer.alpha, output_dtype, ) @@ -998,7 +1143,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer: torch.nn.Module, ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 - detect_nvfp4_moe_support) + detect_nvfp4_moe_support, + ) + super().__init__(moe) self.quant_config = quant_config self.layer = layer @@ -1007,40 +1154,42 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.flashinfer_moe_backend = None - + self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} if self.allow_flashinfer: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for ModelOptNvFp4FusedMoE.") + " for ModelOptNvFp4FusedMoE." + ) - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.allow_flashinfer and self.flashinfer_moe_backend - == FlashinferMoeBackend.CUTLASS): - prepare_finalize = ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - )) + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin or ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): + return None + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): + # For now, fp4 moe only works with the flashinfer dispatcher. + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + self.moe + ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize - - return super().maybe_make_prepare_finalize(moe) + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, + layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) @@ -1052,12 +1201,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ return True - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." 
+ ) layer.num_experts = num_experts layer.params_dtype = params_dtype @@ -1072,10 +1229,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, - dtype=weight_dtype), + dtype=weight_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight", w13_weight) # GEMM 2 @@ -1085,10 +1244,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, - dtype=weight_dtype), + dtype=weight_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight", w2_weight) w13_weight_scale = ModelWeightParameter( @@ -1097,10 +1258,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, - dtype=weight_scale_dtype), + dtype=weight_scale_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = ModelWeightParameter( @@ -1108,128 +1271,170 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_experts, hidden_size, # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // - self.quant_config.group_size, - dtype=weight_scale_dtype), + intermediate_size_per_partition // self.quant_config.group_size, + dtype=weight_scale_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) w13_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) - w13_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + w13_input_scale = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("w13_input_scale", w13_input_scale) - w2_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, dtype=torch.float32), - weight_loader=weight_loader) + w2_input_scale = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("w2_input_scale", w2_input_scale) - def prepare_static_weight_layouts_for_trtllm_moe( + def prepare_static_weights_for_trtllm_fp4_moe( self, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm1_scales_linear_fp4_bytes: torch.Tensor, - gemm2_scales_linear_fp4_bytes: torch.Tensor, - hidden_size: int, - intermediate_size: int, - 
num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # args_dequant, + # args, + gemm1_weights, + gemm2_weights, + gemm1_scales_linear_fp4_bytes, + gemm2_scales_linear_fp4_bytes, + hidden_size, + intermediate_size, + num_experts, + ): + from flashinfer import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w2_permute_indices, + _maybe_get_cached_w3_w1_permute_indices, + ) + """Prepare quantized weights for kernel (done offline with weights).""" - from flashinfer import (reorder_rows_for_gated_act_gemm, - shuffle_matrix_a, shuffle_matrix_sf_a) epilogue_tile_m = 128 # FIXME: this depends on the kernel internals # Convert quantized weights to proper formats gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4 + num_experts, 2 * intermediate_size, hidden_size // 2 + ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size, - hidden_size // - 16) # fp8 scaling factors + torch.float8_e4m3fn + ).reshape( + num_experts, 2 * intermediate_size, hidden_size // 16 + ) # fp8 scaling factors gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( - num_experts, hidden_size, intermediate_size // 2) # packed fp4 + num_experts, hidden_size, intermediate_size // 2 + ) # packed fp4 gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn).reshape(num_experts, hidden_size, - intermediate_size // - 16) # fp8 scaling factors + torch.float8_e4m3fn + ).reshape( + num_experts, hidden_size, intermediate_size // 16 + ) # fp8 scaling factors - # Reorder rows of W1 and scales for fused gated activation - gemm1_weights_fp4_interleaved = [] - gemm1_scales_fp4_interleaved = [] - for i in range(num_experts): - gemm1_weights_fp4_interleaved.append( - reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())) - gemm1_scales_fp4_interleaved.append( - reorder_rows_for_gated_act_gemm( - gemm1_scales_linear_fp4[i].clone())) - - # Stack weights and scales for all experts - gemm1_weights_fp4_interleaved = torch.stack( - gemm1_weights_fp4_interleaved).reshape(num_experts, - 2 * intermediate_size, - hidden_size // 2) - gemm1_scales_fp4_interleaved = torch.stack( - gemm1_scales_fp4_interleaved).reshape(num_experts, - 2 * intermediate_size, - hidden_size // 16) - - # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] gemm2_weights_fp4_shuffled = [] gemm2_scales_fp4_shuffled = [] for i in range(num_experts): + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. 
Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_weights_fp4[i].view(torch.uint8), + epilogue_tile_m, + ) gemm1_weights_fp4_shuffled.append( - shuffle_matrix_a( - gemm1_weights_fp4_interleaved[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_scales_fp4_shuffled.append( - shuffle_matrix_sf_a( - gemm1_scales_fp4_interleaved[i].view(torch.uint8), - epilogue_tile_m)) + gemm1_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)] + .contiguous() + ) + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_scales_linear_fp4[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm1_scales_fp4_shuffled.append( + nvfp4_block_scale_interleave( + gemm1_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm1_scales_linear_fp4.device) + ] + .contiguous() + ) + ) + + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + gemm2_weights_fp4[i].view(torch.uint8), + epilogue_tile_m, + ) gemm2_weights_fp4_shuffled.append( - shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), - epilogue_tile_m)) + gemm2_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)] + .contiguous() + ) + + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + gemm2_scales_linear_fp4[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm2_scales_fp4_shuffled.append( - shuffle_matrix_sf_a( - gemm2_scales_linear_fp4[i].view(torch.uint8), - epilogue_tile_m)) + nvfp4_block_scale_interleave( + gemm2_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm2_scales_linear_fp4.device) + ] + .contiguous() + ) + ) # Stack weights for all experts gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) gemm1_scales_fp4_shuffled = ( - torch.stack(gemm1_scales_fp4_shuffled).view( - torch.float8_e4m3fn).reshape(num_experts, - 2 * intermediate_size, - hidden_size // 16)) + torch.stack(gemm1_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + ) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) gemm2_scales_fp4_shuffled = ( - torch.stack(gemm2_scales_fp4_shuffled).view( - torch.float8_e4m3fn).reshape(num_experts, hidden_size, - intermediate_size // 16)) - return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled) + torch.stack(gemm2_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, hidden_size, intermediate_size // 16) + ) + return ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # GEMM 1 processing @@ -1238,72 +1443,86 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): if self.allow_flashinfer: gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( - gemm1_weight, gemm1_weight_scale, dim=-2) + gemm1_weight, gemm1_weight_scale, dim=-2 + ) layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(gemm1_weight_scale, - requires_grad=False) + layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) # Common processing for w13_weight_scale_2 - if not 
torch.allclose(layer.w13_weight_scale_2[:, 0], - layer.w13_weight_scale_2[:, 1]): + if not torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ): logger.warning_once( "w1_weight_scale_2 must match w3_weight_scale_2. " - "Accuracy may be affected.") + "Accuracy may be affected." + ) w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] - layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, - requires_grad=False) + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( - torch.float32) + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), - requires_grad=False) + requires_grad=False, + ) # This is for quantization, so we need to invert it. layer.w13_input_scale_quant = Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False) + (1 / w13_input_scale).to(torch.float32), requires_grad=False + ) # GEMM 2 processing layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False) + requires_grad=False, + ) # This is for quantization, so we need to invert it. layer.w2_input_scale_quant = Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False + ) # TensorRT-LLM specific processing - if self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): # Prepare static weights for TRT-LLM kernel - (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled - ) = self.prepare_static_weight_layouts_for_trtllm_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) + # alternate: prepare_static_weight_layouts_for_trtllm_moe + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = self.prepare_static_weights_for_trtllm_fp4_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) + logger.debug_once("Finished shuffling weights for TRT-LLM MOE") layer.gemm1_weights_fp4_shuffled = Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False) + gemm1_weights_fp4_shuffled, requires_grad=False + ) layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False) + gemm2_weights_fp4_shuffled, requires_grad=False + ) layer.gemm1_scales_fp4_shuffled = Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False) + gemm1_scales_fp4_shuffled, requires_grad=False + ) layer.gemm2_scales_fp4_shuffled = Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False) + gemm2_scales_fp4_shuffled, requires_grad=False + ) # Additional parameter needed for TRT-LLM layer.g1_scale_c = Parameter( - (layer.w2_input_scale_quant * layer.g1_alphas).to( - torch.float32), + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) 
@@ -1312,35 +1531,55 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): del layer.w2_weight_scale del layer.w13_weight del layer.w13_weight_scale - else: - # Non-TRT-LLM processing (Cutlass or non-flashinfer) - assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w13_blockscale_swizzled = swizzle_blockscale( - layer.w13_weight_scale) - layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, - requires_grad=False) - - assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, - requires_grad=False) - layer.w2_weight = Parameter(layer.w2_weight.data, - requires_grad=False) - - if self.use_marlin: + elif self.use_marlin: + # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) del layer.g1_alphas del layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant - del layer.w13_blockscale_swizzled - del layer.w2_blockscale_swizzled + else: + # Non-TRT-LLM processing (Cutlass or non-flashinfer) + assert layer.w13_weight_scale.shape[2] % 16 == 0, ( + "Expected weight_scale.dim(1) to be divisible by 16" + ) + assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Blockscale must be represented as FP8-E4M3" + ) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) + layer.w13_weight_scale = Parameter( + w13_blockscale_swizzled, requires_grad=False + ) + + assert layer.w2_weight_scale.shape[2] % 16 == 0, ( + "Expected weight_scale.dim(1) to be divisible by 16" + ) + assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Blockscale must be represented as FP8-E4M3" + ) + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) + layer.w2_weight_scale = Parameter( + w2_blockscale_swizzled, requires_grad=False + ) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if ( + self.use_marlin + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): + return None + + return nvfp4_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + ) def apply( self, @@ -1356,6 +1595,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -1363,68 +1603,82 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + "EPLB not supported for 
`ModelOptNvFp4FusedMoE` yet." + ) assert activation == "silu", "Only SiLU activation is supported." - if self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): import flashinfer from vllm.model_executor.models.llama4 import Llama4MoE + assert self.fused_experts is None + a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, - hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) - use_llama4_routing = \ + (hidden_states_fp4, hidden_states_scale_linear_fp4) = ( + flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) + ) + use_llama4_routing = ( custom_routing_function is Llama4MoE.custom_routing_function + ) routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 if use_llama4_routing: routing_method_type = flashinfer.RoutingMethodType.Llama4 + routing_bias = e_score_correction_bias + if routing_bias is not None: + routing_bias = routing_bias.to(torch.bfloat16) out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits - if use_llama4_routing else router_logits.to(torch.float32), - routing_bias=e_score_correction_bias, + if use_llama4_routing + else router_logits.to(torch.float32), + routing_bias=routing_bias, hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn).flatten(), + torch.float8_e4m3fn + ).flatten(), gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn), + torch.float8_e4m3fn + ), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn), + torch.float8_e4m3fn + ), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, output2_scale_scalar=layer.g2_alphas.data, num_experts=global_num_experts, top_k=top_k, - n_group=num_expert_group - if num_expert_group is not None else 0, + n_group=num_expert_group if num_expert_group is not None else 0, topk_group=topk_group if topk_group is not None else 0, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, routed_scaling_factor=None, - tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k, - layer.local_num_experts), + tile_tokens_dim=_get_tile_tokens_dim( + x.shape[0], top_k, layer.local_num_experts + ), routing_method_type=routing_method_type, do_finalize=True, )[0] return out - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -1434,10 +1688,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or + # trtllm. 
+ # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -1454,17 +1716,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): quant_type_id=scalar_types.float4_e2m1f.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace, + ) - if self.fused_experts is not None: - assert self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + elif self.fused_experts is not None: + assert ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ) assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" - out = self.fused_experts( + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1474,28 +1740,26 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4) + flashinfer_cutlass_moe_fp4, + ) - out = flashinfer_cutlass_moe_fp4( + assert self.moe_quant_config is not None + + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - w1_scale=layer.w13_blockscale_swizzled, - w2_scale=layer.w2_blockscale_swizzled, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - inplace=False, # TODO(shuw): fix later, now output is high prec + quant_config=self.moe_quant_config, + inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1504,25 +1768,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). 
- from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) - out = cutlass_moe_fp4( + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 + + assert self.moe_quant_config is not None + return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w2_blockscale=layer.w2_blockscale_swizzled, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - return out + ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 364d1ac314d2d..3719672f6e52f 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,20 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supports_layer) + check_marlin_supports_layer, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -22,10 +33,16 @@ from vllm.platforms import current_platform class MoeWNA16Config(QuantizationConfig): """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" - def __init__(self, linear_quant_method: str, weight_bits: int, - group_size: int, has_zp: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], - full_config: dict[str, Any]) -> None: + def __init__( + self, + linear_quant_method: str, + weight_bits: int, + group_size: int, + has_zp: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: super().__init__() self.weight_bits = weight_bits self.group_size = group_size @@ -37,26 +54,25 @@ class MoeWNA16Config(QuantizationConfig): self.use_marlin = False # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig - from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig) - from vllm.model_executor.layers.quantization.gptq_marlin import ( - 
GPTQMarlinConfig) + from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig + from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig + if self.linear_quant_method == "gptq": - self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( - full_config) + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) elif self.linear_quant_method == "awq": capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) awq_min_capability = AWQConfig.get_min_capability() if device_capability < awq_min_capability: raise ValueError( "The quantization method moe_wna16 + awq is not supported " "for the current GPU. " f"Minimum capability: {awq_min_capability}. " - f"Current capability: {device_capability}.") - self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( - full_config) + f"Current capability: {device_capability}." + ) + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(full_config) else: raise ValueError("moe_wna16 only support gptq and awq.") @@ -86,24 +102,32 @@ class MoeWNA16Config(QuantizationConfig): linear_quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) if linear_quant_method == "gptq": has_zp = not cls.get_from_keys(config, ["sym"]) modules_to_not_convert = [] elif linear_quant_method == "awq": has_zp = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) + config, ["modules_to_not_convert"], None + ) else: raise ValueError("moe_wna16 only support gptq and awq.") - return cls(linear_quant_method, weight_bits, group_size, has_zp, - lm_head_quantized, modules_to_not_convert, config) + return cls( + linear_quant_method, + weight_bits, + group_size, + has_zp, + lm_head_quantized, + modules_to_not_convert, + config, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) if can_convert and user_quant == "moe_wna16": return cls.get_name() @@ -117,46 +141,59 @@ class MoeWNA16Config(QuantizationConfig): desc_act = quant_config.get("desc_act") capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig + awq_min_capability = AWQConfig.get_min_capability() - gptq_compatible = quant_method == "gptq" and \ - not desc_act and num_bits in [4, 8] - awq_compatible = quant_method == "awq" and num_bits == 4 and \ - device_capability >= awq_min_capability + gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8] + awq_compatible = ( + quant_method == "awq" + and num_bits == 4 + and device_capability >= awq_min_capability + ) return gptq_compatible or awq_compatible - def get_quant_method(self, layer: torch.nn.Module, - prefix: 
str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig) + AWQMarlinConfig, + ) from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) + GPTQMarlinConfig, + ) + if self.linear_quant_method == "gptq": if self.use_marlin: return GPTQMarlinConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + self.full_config + ).get_quant_method(layer, prefix) else: - return GPTQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return GPTQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) elif self.linear_quant_method == "awq": if self.use_marlin and check_marlin_supports_layer( - layer, self.group_size): + layer, self.group_size + ): return AWQMarlinConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + self.full_config + ).get_quant_method(layer, prefix) else: - return AWQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): @@ -175,28 +212,29 @@ class MoeWNA16Method(FusedMoEMethodBase): quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__( - self, - quant_config: MoeWNA16Config, - moe: FusedMoEConfig, - ): + def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None: super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size group_size_div_factor = 1 - # make intermediate_size and hidden_size diviable by group_size + # make intermediate_size and hidden_size divisible by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later - while intermediate_size_per_partition % group_size or \ - hidden_size % group_size: + while intermediate_size_per_partition % group_size or hidden_size % group_size: group_size = group_size // 2 group_size_div_factor *= 2 assert group_size >= 32 @@ -204,71 +242,85 @@ class MoeWNA16Method(FusedMoEMethodBase): layer.group_size_div_factor = group_size_div_factor strategy = FusedMoeWeightScaleSupported.GROUP.value - extra_weight_attrs.update({ - "quant_method": strategy, - "is_transposed": False - }) + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False}) - assert 'weight_loader' in extra_weight_attrs - weight_loader = extra_weight_attrs['weight_loader'] - wrapped_weight_loader = MoeWNA16Method.get_weight_loader( - layer, weight_loader) - extra_weight_attrs['weight_loader'] = 
wrapped_weight_loader + assert "weight_loader" in extra_weight_attrs + weight_loader = extra_weight_attrs["weight_loader"] + wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // bit8_pack_factor, - dtype=torch.uint8), - requires_grad=False) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) # down_proj (row parallel) - w2_qweight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // bit8_pack_factor, - dtype=torch.uint8), - requires_grad=False) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) - w13_scales = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // group_size, - dtype=params_dtype), - requires_grad=False) + w13_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition // group_size, - dtype=params_dtype), - requires_grad=False) + w2_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) if self.quant_config.has_zp: - w13_qzeros = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition // bit8_pack_factor, - hidden_size // group_size, - dtype=torch.uint8), - requires_grad=False) + w13_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) - w2_qzeros = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size // bit8_pack_factor, - intermediate_size_per_partition // group_size, - dtype=torch.uint8), - requires_grad=False) + w2_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + intermediate_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) @@ -279,12 +331,32 @@ class MoeWNA16Method(FusedMoEMethodBase): if not self.quant_config.has_zp: invalid_param_keys += ["w13_qzeros", "w2_qzeros"] for key in invalid_param_keys: - param = torch.nn.Parameter(torch.empty((0, ), - dtype=torch.int32), - requires_grad=False) + param = torch.nn.Parameter( + torch.empty((0,), dtype=torch.int32), 
requires_grad=False + ) layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + assert weight_bits == 4 or weight_bits == 8 + config_builder = ( + int4_w4a16_moe_quant_config + if weight_bits == 4 + else int8_w8a16_moe_quant_config + ) + + return config_builder( + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -299,6 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -306,16 +379,15 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None - if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `MoeWNA16Method` yet.") + raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.") from vllm.model_executor.layers.fused_moe import fused_experts + assert activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -325,11 +397,10 @@ class MoeWNA16Method(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) - - weight_bits = self.quant_config.weight_bits - has_zp = self.quant_config.has_zp + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -338,20 +409,14 @@ class MoeWNA16Method(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, - w1_zp=layer.w13_qzeros if has_zp else None, - w2_zp=layer.w2_qzeros if has_zp else None, - block_shape=[0, layer.group_size]) + quant_config=self.moe_quant_config, + ) @staticmethod def get_weight_loader(layer, weight_loader): - def convert_awq_tensor(tensor, tensor_type): # convert awq qweight/qzeros to a standard format (assume int4) # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) @@ -367,9 +432,7 @@ class MoeWNA16Method(FusedMoEMethodBase): # 2. unpack to uint4 (only when weight_bits == 4) # shape (a, 4 * b) -> (a, 4 * b, 2) - shifter = torch.tensor([0, 4], - dtype=torch.uint8, - device=tensor.device) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF # 3. 
change order, see @@ -394,22 +457,24 @@ class MoeWNA16Method(FusedMoEMethodBase): def convert_gptq_int4_qzeros(tensor): tensor = tensor.view(torch.uint8) - shifter = torch.tensor([0, 4], - dtype=torch.uint8, - device=tensor.device) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF tensor = tensor + 1 tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 return tensor - def moe_wna16_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, shard_id: str, - expert_id: int): + def moe_wna16_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False, + ): if "g_idx" in weight_name: - return + return False if return_success else None if not layer.quant_config.has_zp and "qzeros" in weight_name: - return + return False if return_success else None device = get_tp_group().device tp_rank = get_tensor_model_parallel_rank() @@ -420,8 +485,7 @@ class MoeWNA16Method(FusedMoEMethodBase): if layer.quant_config.linear_quant_method == "awq": assert layer.quant_config.weight_bits == 4 if "weight" in weight_name: - loaded_weight = convert_awq_tensor(loaded_weight, - "qweight") + loaded_weight = convert_awq_tensor(loaded_weight, "qweight") elif "zeros" in weight_name: loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") else: @@ -429,37 +493,50 @@ class MoeWNA16Method(FusedMoEMethodBase): elif layer.quant_config.linear_quant_method == "gptq": assert layer.quant_config.weight_bits in [4, 8] if "weight" in weight_name: - loaded_weight = loaded_weight.T.contiguous().view( - torch.uint8) + loaded_weight = loaded_weight.T.contiguous().view(torch.uint8) elif "zeros" in weight_name: # add 1 to gptq qzeros to align with awq loaded_weight = loaded_weight.view(torch.uint8) if layer.quant_config.weight_bits == 4: - loaded_weight = convert_gptq_int4_qzeros( - loaded_weight).T + loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T else: loaded_weight = loaded_weight.T + 1 else: loaded_weight = loaded_weight.T # repeat the qzeros/scales to fit new group size - if layer.group_size_div_factor > 1 and \ - "qzeros" in weight_name or "scales" in weight_name: + if ( + layer.group_size_div_factor > 1 + and "qzeros" in weight_name + or "scales" in weight_name + ): loaded_weight = loaded_weight.repeat_interleave( - layer.group_size_div_factor, 1) + layer.group_size_div_factor, 1 + ) if "w13_qzeros" in weight_name: - tensor = loaded_weight.view(layer.tp_size, -1, - loaded_weight.size(1))[tp_rank] + tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[ + tp_rank + ] if shard_id == "w1": - param.data[expert_id, :shard_size // 2] = tensor + param.data[expert_id, : shard_size // 2] = tensor else: - param.data[expert_id, shard_size // 2:] = tensor + param.data[expert_id, shard_size // 2 :] = tensor + return True if return_success else None elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( - loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + loaded_weight.size(0), layer.tp_size, -1 + )[:, tp_rank] + return True if return_success else None else: - weight_loader(param, loaded_weight, weight_name, shard_id, - expert_id) + # Delegate to the original loader, passing return_success + return weight_loader( + param, + loaded_weight, + weight_name, + shard_id, + expert_id, + return_success=return_success, + ) return moe_wna16_weight_loader diff --git 
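Before the mxfp4.py hunk below, a self-contained illustration of the nibble unpacking that `convert_awq_tensor` and `convert_gptq_int4_qzeros` above both rely on; the packed byte values here are invented for the example.

import torch

packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)  # packs 1, 2, 3, 4
shifter = torch.tensor([0, 4], dtype=torch.uint8)
# Broadcasting against [0, 4] splits every byte into its low and high nibble.
unpacked = (packed[:, :, None] >> shifter) & 0xF
assert unpacked.flatten().tolist() == [1, 2, 3, 4]

The GPTQ zero-point path additionally adds 1 after unpacking so that its convention lines up with AWQ, as the converter above notes.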
a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 6a190ebbc063e..dd9532be7585c 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,68 +1,142 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from enum import Enum +from typing import Callable, Optional, Union import torch from torch.nn.parameter import Parameter from vllm import envs +from vllm.config import get_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + mxfp4_w4a16_moe_quant_config, + ocp_mx_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - triton_kernel_moe_forward) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + OAITritonExperts, +) +from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin) + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - _can_support_mxfp4, _swizzle_mxfp4) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + _can_support_mxfp4, + _swizzle_mxfp4, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, - next_power_of_2, round_up) +from vllm.utils import ( + has_triton_kernels, + is_torch_equal_or_newer, + next_power_of_2, + round_up, +) from vllm.utils.flashinfer import has_flashinfer logger = init_logger(__name__) -def _should_use_flashinfer_mxfp4_bf16(): - """Determine if FlashInfer MXFP4 BF16 should be used.""" - # If explicitly set, respect the setting - if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 +# enum for mxfp4 backend +class Mxfp4Backend(Enum): + NONE = 0 - # Enable by default on SM100 if MXFP8 is not explicitly enabled - if (current_platform.is_device_capability(100) and has_flashinfer() - and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): - logger.info_once( - "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. 
" - "For faster performance, consider setting " - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " - "though this may impact accuracy.") - return True + # FlashInfer Backend + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 - return False + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 -def _should_use_flashinfer_mxfp4_mxfp8(): - """Determine if FlashInfer MXFP4 MXFP8 should be used.""" - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 +def get_mxfp4_backend(): + # Backend Selection + if current_platform.is_cuda(): + if ( + current_platform.is_device_capability(90) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + ): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + elif ( + current_platform.is_device_capability(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS + ): + logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + elif ( + current_platform.is_device_capability(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + ): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " + "for high concurrency throughput workloads consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " + "performance" + ) + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + elif current_platform.is_device_capability(100) and has_flashinfer(): + logger.info_once( + "Using FlashInfer MXFP4 BF16 backend for SM100, " + "For faster performance on SM100, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " + "accuracy." + ) + return Mxfp4Backend.SM100_FI_MXFP4_BF16 + elif ( + current_platform.is_device_capability(100) + or current_platform.is_device_capability(90) + ) and not has_flashinfer(): + logger.warning_once( + "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results." 
+ ) + # If FlashInfer is not available, try either Marlin or Triton + if ( + envs.VLLM_MXFP4_USE_MARLIN + or current_platform.get_device_capability()[0] < 9 + or not has_triton_kernels() + or not is_torch_equal_or_newer("2.8.0") + ): + logger.info_once("Using Marlin backend") + return Mxfp4Backend.MARLIN + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + elif current_platform.is_rocm() and has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON -def should_use_flashinfer_mxfp4(): - return (_should_use_flashinfer_mxfp4_mxfp8() - or _should_use_flashinfer_mxfp4_bf16()) + return Mxfp4Backend.NONE class Mxfp4Config(QuantizationConfig): - def __init__(self, ignored_layers: Optional[list[str]] = None): super().__init__() self.ignored_layers = ignored_layers @@ -87,56 +161,51 @@ class Mxfp4Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return [] - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): if self.ignored_layers and is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() raise NotImplementedError("Mxfp4 linear layer is not implemented") elif isinstance(layer, FusedMoE): return Mxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): - raise NotImplementedError( - "Mxfp4 attention layer is not implemented") + raise NotImplementedError("Mxfp4 attention layer is not implemented") return None class Mxfp4MoEMethod(FusedMoEMethodBase): - def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.topk_indices_dtype = None self.moe = moe - self.use_marlin = self._should_use_marlin() + self.mxfp4_backend = get_mxfp4_backend() + self.max_capture_size = ( + get_current_vllm_config().compilation_config.max_capture_size + ) - if current_platform.is_device_capability(100) and not has_flashinfer(): - logger.warning_once( - "MXFP4 MoE is enabled on Blackwell but FlashInfer " - "is not available. This may result in degraded performance. " - "Please `pip install vllm[flashinfer]` for best results.") + assert self.mxfp4_backend != Mxfp4Backend.NONE, ( + "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available." + "Please check your environment and try again." 
+ ) + self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} - def _should_use_marlin(self): - if envs.VLLM_MXFP4_USE_MARLIN is not None: - return envs.VLLM_MXFP4_USE_MARLIN - if current_platform.is_cuda() and \ - not current_platform.is_device_capability(100): - if not current_platform.has_device_capability(90): - # marlin kernel has better performance on ampere - return True - if not has_triton_kernels(): - return True - if not is_torch_equal_or_newer("2.8.0"): - return True - return False - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): self.num_experts = num_experts weight_dtype = torch.uint8 scale_dtype = torch.uint8 @@ -151,9 +220,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): mxfp4_block = 32 - intermediate_size_per_partition_after_pad = \ - intermediate_size_per_partition - if self.use_marlin: + intermediate_size_per_partition_after_pad = intermediate_size_per_partition + if self.mxfp4_backend == Mxfp4Backend.MARLIN: # The moe marlin kernel requires that for each linear # n % 256 == 0 and k % 128 == 0. # In gate_up_proj: @@ -163,27 +231,44 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): # n = hidden_size # k = intermediate_size_per_partition_after_pad intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128) + intermediate_size_per_partition, 128 + ) hidden_size = round_up(hidden_size, 256) layer.params_dtype = params_dtype layer.num_experts = num_experts layer.hidden_size = hidden_size - layer.intermediate_size_per_partition = \ + layer.intermediate_size_per_partition = ( intermediate_size_per_partition_after_pad - elif should_use_flashinfer_mxfp4(): + ) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 256) + intermediate_size_per_partition, 256 + ) hidden_size = round_up(hidden_size, 256) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128 + ) + hidden_size = round_up(hidden_size, 128) elif current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128) + intermediate_size_per_partition, 256 + ) + hidden_size = round_up(hidden_size, 256) else: intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 64) + intermediate_size_per_partition, 64 + ) self.intermediate_size = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size @@ -260,45 +345,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: prepare_moe_fp4_layer_for_marlin(layer) - elif should_use_flashinfer_mxfp4(): - from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a - 
layer.gemm1_alpha = Parameter(torch.tensor( - [1.702] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) - layer.gemm1_beta = Parameter(torch.tensor( - [1.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) - layer.gemm1_clamp_limit = Parameter(torch.tensor( - [7.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices + + layer.gemm1_alpha = Parameter( + torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_beta = Parameter( + torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_clamp_limit = Parameter( + torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) sf_block_size = 32 # mxfp4 block size - assert (layer.w13_weight.dim() == 3 - and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 - and layer.w13_weight.shape[2] == self.hidden_size // 2) - assert (layer.w13_weight_scale.dim() == 3 - and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] - == self.intermediate_size * 2 - and layer.w13_weight_scale.shape[2] - == self.hidden_size // sf_block_size) - assert (layer.w2_weight.dim() == 3 - and layer.w2_weight.shape[0] == self.num_experts - and layer.w2_weight.shape[1] == self.hidden_size and - layer.w2_weight.shape[2] == self.intermediate_size // 2) - assert (layer.w2_weight_scale.dim() == 3 - and layer.w2_weight_scale.shape[1] == self.hidden_size - and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size) - assert (layer.w13_bias.dim() == 2 - and layer.w13_bias.shape[0] == self.num_experts - and layer.w13_bias.shape[1] == self.intermediate_size * 2) - assert (layer.w2_bias.dim() == 2 - and layer.w2_bias.shape[0] == self.num_experts - and layer.w2_bias.shape[1] == self.hidden_size) + assert ( + layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2 + ) + assert ( + layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size + ) + assert ( + layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size + and layer.w2_weight.shape[2] == self.intermediate_size // 2 + ) + assert ( + layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size + ) + assert ( + layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2 + ) + assert ( + layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size + ) w13_weight_scale = layer.w13_weight_scale.data w2_weight_scale = layer.w2_weight_scale.data @@ -307,7 +410,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w13_bias 
= layer.w13_bias.data.to(torch.float32) w2_bias = layer.w2_bias.data.to(torch.float32) - # Swap w1 and w3 as the defenition of + # Swap w1 and w3 as the definition of # swiglu is different in the trtllm-gen def swap_every_two_rows(x, axis=-1): shape = x.shape @@ -340,51 +443,248 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in range(self.num_experts): + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) gemm1_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), - epilogue_tile_m)) + w13_weight[i] + .view(torch.uint8)[permute_indices.to(w13_weight.device)] + .contiguous() + ) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm1_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) + nvfp4_block_scale_interleave( + w13_weight_scale[i] + .view(torch.uint8)[ + permute_sf_indices.to(w13_weight_scale.device) + ] + .contiguous() + ) + ) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) - + w13_bias[i] + .clone() + .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] + .contiguous() + ) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) gemm2_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), - epilogue_tile_m)) + w2_weight[i] + .view(torch.uint8)[permute_indices.to(w2_weight.device)] + .contiguous() + ) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm2_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) + nvfp4_block_scale_interleave( + w2_weight_scale[i] + .view(torch.uint8)[ + permute_sf_indices.to(w2_weight_scale.device) + ] + .contiguous() + ) + ) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) + w2_bias[i] + .clone() + .reshape(-1, 1)[permute_indices.to(w2_bias.device)] + .contiguous() + ) w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) - w13_weight_scale = torch.stack( - gemm1_scales_mxfp4_shuffled).reshape( - self.num_experts, 2 * self.intermediate_size, - self.hidden_size // sf_block_size).view( - torch.float8_e4m3fn) + w13_weight_scale = ( + torch.stack(gemm1_scales_mxfp4_shuffled) + .reshape( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) - w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( - self.num_experts, self.hidden_size, self.intermediate_size // - 
sf_block_size).view(torch.float8_e4m3fn) + w2_weight_scale = ( + torch.stack(gemm2_scales_mxfp4_shuffled) + .reshape( + self.num_experts, + self.hidden_size, + self.intermediate_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = Parameter(w2_weight_scale, - requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) layer.w13_bias = Parameter( torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), - requires_grad=False) - layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( - self.num_experts, -1), - requires_grad=False) - else: + requires_grad=False, + ) + layer.w2_bias = Parameter( + torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1), + requires_grad=False, + ) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): + layer.gemm1_alpha = Parameter( + torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_beta = Parameter( + torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_clamp_limit = Parameter( + torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + + sf_block_size = 32 # mxfp4 block size + + # Common shape assertions + assert ( + layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2 + ) + assert ( + layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size + ) + assert ( + layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size + and layer.w2_weight.shape[2] == self.intermediate_size // 2 + ) + assert ( + layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size + ) + assert ( + layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2 + ) + assert ( + layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size + ) + + # De-interleave and swap for w13 weight, bias, and scales + w13_w = layer.w13_weight.data + gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] + deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) + w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + + w13_b = layer.w13_bias.data.to(torch.float32) + gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] + deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) + b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) + w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w13_s = layer.w13_weight_scale.data + gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] + deinterleaved_w13_s = torch.cat([gate_s, 
up_s], dim=1) + s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + from flashinfer import block_scale_interleave + + orig_shape = w13_scale_swapped.shape + w13_scale_interleaved = block_scale_interleave( + w13_scale_swapped.view(torch.uint8) + ).reshape(orig_shape) + + w2_s = layer.w2_weight_scale.data + orig_shape = w2_s.shape + w2_scale_interleaved = block_scale_interleave( + w2_s.view(torch.uint8) + ).reshape(orig_shape) + + layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False) + layer.w13_weight_scale = Parameter( + w13_scale_interleaved, requires_grad=False + ) + layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False) + layer.w2_weight_scale = Parameter( + w2_scale_interleaved, requires_grad=False + ) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape( + w_shape[0], w_shape[1], (w_shape[2] // 4), 4 + ) + w_interleaved = w_interleaved.permute(0, 2, 1, 3) + w_interleaved = w_interleaved.reshape( + w_shape[0], w_shape[2] // 4, w_shape[1] * 4 + ) + return w_interleaved + + w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8) + w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales) + + w2_weight_scale = layer.w2_weight_scale.data + w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) + w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales) + + layer.w13_weight = torch.nn.Parameter( + torch.cat([w3_w, w1_w], dim=1), requires_grad=False + ) + layer.w13_bias = torch.nn.Parameter( + w13_bias_swapped, requires_grad=False + ) + layer.w13_weight_scale = torch.nn.Parameter( + w31_scales_interleaved, requires_grad=False + ) + layer.w2_weight_scale = torch.nn.Parameter( + w2_scales_interleaved, requires_grad=False + ) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig w13_bias = layer.w13_bias.to(torch.float32) @@ -393,22 +693,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False) - # FIXME warp need to be adjusted based on batch size - # only apply to batched mode - if self.moe.use_ep: + # Ideally we'd use FusedMoEModularKernel.prepare_finalize object + # (stored in self.fused_experts) to determine if the MoE has a + # batched activation format. As self.fused_experts is not + # initialized at this point, we resort to checking the MoE config + # directly. 
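As an aside on the SM90 branch earlier in this hunk, a toy check of the block-scale interleaving performed by `_interleave_mxfp4_cutlass_sm90`: every group of four scale columns is gathered across all rows into one output row, so `[E, rows, cols]` becomes `[E, cols // 4, rows * 4]`. Sizes here are hypothetical.

import torch

def interleave_sm90(w: torch.Tensor) -> torch.Tensor:
    e, rows, cols = w.shape
    # Mirror of the reshape / permute / reshape sequence above.
    return (
        w.reshape(e, rows, cols // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(e, cols // 4, rows * 4)
    )

w = torch.arange(2 * 3 * 8).reshape(2, 3, 8)
assert interleave_sm90(w).shape == (2, 2, 12)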
+ is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels + if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: num_warps = 8 w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( - layer.w13_weight, layer.w13_weight_scale, num_warps) + layer.w13_weight, layer.w13_weight_scale, num_warps + ) w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( - layer.w2_weight, layer.w2_weight_scale, num_warps) + layer.w2_weight, layer.w2_weight_scale, num_warps + ) self.w13_precision_config = PrecisionConfig( - weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) self.w2_precision_config = PrecisionConfig( - weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) self.w13_weight_triton_tensor = w13_weight self.w2_weight_triton_tensor = w2_weight @@ -419,6 +727,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w13_weight = None layer.w2_weight = None torch.cuda.empty_cache() + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): # Number of tokens in the input tensor. @@ -444,7 +754,69 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return tile_tokens_dim - def apply( + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + else: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + return ocp_mx_moe_quant_config( + quant_dtype="mxfp4", + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + if ( + prepare_finalize.activation_format + == mk.FusedMoEActivationFormat.BatchedExperts + ): + raise NotImplementedError( + "Mxfp4 does not support batched experts format for EP" + ) + else: + assert self.moe_quant_config is not None + if ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): + # B200 code-path + kwargs = { + "gemm1_alpha": layer.gemm1_alpha, + "gemm1_beta": layer.gemm1_beta, + "gemm1_clamp_limit": layer.gemm1_clamp_limit, + # TODO(bnell): part of quant_config + "max_capture_size": self.max_capture_size, + } + return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) + elif self.mxfp4_backend == Mxfp4Backend.MARLIN: + return MarlinExperts(self.moe_quant_config) + else: + return OAITritonExperts(self.moe_quant_config) + + def _route_and_experts( self, layer: torch.nn.Module, x: torch.Tensor, @@ -466,12 +838,101 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) + topk_weights, topk_ids, _ = FusedMoE.select_experts( + hidden_states=x, + 
router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + w13_weight = ( + self.w13_weight_triton_tensor + if layer.w13_weight is None + else layer.w13_weight + ) + w2_weight = ( + self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight + ) + assert all([w is not None for w in [w13_weight, w2_weight]]) + + return self.fused_experts( + hidden_states=x, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") - if self.use_marlin: - topk_weights, topk_ids = FusedMoE.select_experts( + if self.fused_experts is not None: + return self._route_and_experts( + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) + + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -481,7 +942,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -500,27 +963,40 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) + expert_map=expert_map, + ) assert _can_support_mxfp4( - use_grouped_topk, topk_group, num_expert_group, expert_map, - custom_routing_function, e_score_correction_bias, - apply_router_weight_on_input, scoring_func, activation, - expert_load_view, 
logical_to_physical_map, - logical_replica_count), ( - "MXFP4 are not supported with this configuration.") + use_grouped_topk, + topk_group, + num_expert_group, + expert_map, + custom_routing_function, + e_score_correction_bias, + apply_router_weight_on_input, + scoring_func, + activation, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ), "MXFP4 are not supported with this configuration." - if should_use_flashinfer_mxfp4(): - from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe - assert not self.moe.use_ep, ( - "EP is not supported for flashinfer mxfp4 moe backend yet.") - if _should_use_flashinfer_mxfp4_bf16(): + if ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): + from flashinfer import trtllm_fp4_block_scale_moe + + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 x_quant = x x_scale = None - else: + elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: + from flashinfer import mxfp8_quantize + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -538,20 +1014,102 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): None, # output1_scale_scalar None, # output1_scale_gate_scalar None, # output2_scale_scalar - self.num_experts, + global_num_experts, top_k, None, # n_group None, # topk_group self.intermediate_size, # padded to multiple of 256 - 0, # local_expert_offset + layer.ep_rank * layer.local_num_experts, # local_expert_offset self.num_experts, # local num experts None, self._get_tile_tokens_dim(x, top_k), 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize + tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output - else: + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + topk_weights, topk_ids, _ = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + # Backend-specific preparation + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + from flashinfer import mxfp8_quantize + + x_quant, x_scale = mxfp8_quantize(x, True, 32) + + fake_input_scale = torch.ones(self.num_experts, device=x.device) + quant_scales = [ + layer.w13_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + layer.w2_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + ] + + fi_input = x_quant + extra_kwargs = dict( + use_mxfp8_act_scaling=True, + input_sf=x_scale, + fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long), + fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long), + ) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + assert x.dtype == torch.bfloat16 + + quant_scales = [ + layer.w13_weight_scale, + layer.w2_weight_scale, + ] + + fi_input = x + extra_kwargs = dict( + use_w4_group_scaling=True, + 
fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + ) + + output = torch.empty_like(x, dtype=torch.bfloat16) + _ = flashinfer_cutlass_fused_moe( + input=fi_input, + token_selected_experts=topk_ids.to(torch.int).contiguous(), + token_final_scales=topk_weights, + output_dtype=torch.bfloat16, + output=output, + quant_scales=quant_scales, + fc1_expert_biases=layer.w13_bias, + fc2_expert_biases=layer.w2_bias, + swiglu_alpha=layer.gemm1_alpha, + swiglu_beta=layer.gemm1_beta, + swiglu_limit=layer.gemm1_clamp_limit, + tp_size=self.moe.tp_size, + tp_rank=self.moe.tp_rank, + ep_size=self.moe.ep_size, + ep_rank=self.moe.ep_rank, + tune_max_num_tokens=self.max_capture_size, + **extra_kwargs, + ) + + return output + elif self.mxfp4_backend == Mxfp4Backend.TRITON: + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward, + ) + return triton_kernel_moe_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, @@ -561,9 +1119,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): renormalize=renormalize, global_num_experts=global_num_experts, expert_map=expert_map, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_precision=self.w13_precision_config, - w2_precision=self.w2_precision_config, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py deleted file mode 100644 index 8040236663dd1..0000000000000 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from importlib.util import find_spec -from typing import Any, Optional - -from torch.nn import Module - -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - -SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] - - -class AlwaysSupportedDtypes(list): - - def __contains__(self, item): - return True - - -class NeuronQuantConfig(QuantizationConfig): - """Int8 Quantization Config class for Neuron Backend.""" - - def __init__( - self, - dequant_dtype: str = "f16", - quantize_method: str = "vector_dynamic", - ) -> None: - super().__init__() - self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") - if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: - raise ValueError( - f"Neuron quantization datatype {self.quant_dtype} is not valid," - f" the quantization datatype should match one of the below " - f"types {SUPPORTED_QUANT_DTYPE_LIST}") - self.dequant_dtype = dequant_dtype - self.quantize_method = quantize_method - - def get_name(self) -> QuantizationMethods: - return "neuron_quant" - - def get_supported_act_dtypes(self) -> list[str]: - # Neuron implements custom handling logic for quantization support - return AlwaysSupportedDtypes() - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError( - "This function should not be called with Neuron Backend") - - @staticmethod - def get_config_filenames() -> list[str]: - return [] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig": - quantize_method = cls.get_from_keys(config, ["quantize_method"]) - dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) - 
return cls(dequant_dtype=dequant_dtype, - quantize_method=quantize_method) - - def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: - if find_spec("transformers_neuronx") is not None: - return self.get_quantization_config() - else: - raise NotImplementedError( - "Neuron Quantization is only supported through" - " transformers_neuronx.") - - def get_quantization_config(self): - from transformers_neuronx.config import QuantizationConfig - return QuantizationConfig(quant_dtype=self.quant_dtype, - dequant_dtype=self.dequant_dtype, - quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py index 5b9fee69bb021..60519bdaea028 100644 --- a/vllm/model_executor/layers/quantization/petit.py +++ b/vllm/model_executor/layers/quantization/petit.py @@ -9,19 +9,24 @@ import torch from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.petit_utils import ( - apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, - verify_petit_nvfp4_supported) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter) + apply_petit_nvfp4_linear, + prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.platforms import current_platform # Initialize logger for the module @@ -43,8 +48,10 @@ class PetitNvFp4Config(QuantizationConfig): self._check_hardware_support() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: - logger.warning("Detected nvfp4 checkpoint. Please note that the " - "format is experimental and subject to change.") + logger.warning( + "Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change." + ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo self.exclude_modules = exclude_modules @@ -61,7 +68,8 @@ class PetitNvFp4Config(QuantizationConfig): "The 'petit' quantization backend is designed for AMD GPUs " "and is not supported on the CUDA platform. For NVIDIA GPUs, " "please use a different quantization method such as FP8, AWQ, " - "or GPTQ.") + "or GPTQ." 
+ ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -86,8 +94,7 @@ class PetitNvFp4Config(QuantizationConfig): quant_method_raw = qc.get("quant_algo") if not isinstance(quant_method_raw, str) or not quant_method_raw: - raise ValueError( - "Missing or invalid 'quant_algo' in quantization config.") + raise ValueError("Missing or invalid 'quant_algo' in quantization config.") quant_method = quant_method_raw.upper() group_size_raw = qc.get("group_size") @@ -101,19 +108,18 @@ class PetitNvFp4Config(QuantizationConfig): kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" if not isinstance(kv_cache_quant_algo_raw, str): - raise ValueError( - "'kv_cache_quant_algo' must be a string if provided.") + raise ValueError("'kv_cache_quant_algo' must be a string if provided.") kv_cache_quant_algo = kv_cache_quant_algo_raw exclude_raw = qc.get("exclude_modules", []) if exclude_raw is None: exclude_modules: list[str] = [] elif isinstance(exclude_raw, list) and all( - isinstance(x, str) for x in exclude_raw): + isinstance(x, str) for x in exclude_raw + ): exclude_modules = exclude_raw else: - raise ValueError( - "'exclude_modules' must be a list[str] (or omitted).") + raise ValueError("'exclude_modules' must be a list[str] (or omitted).") is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method @@ -126,7 +132,8 @@ class PetitNvFp4Config(QuantizationConfig): @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: if not current_platform.is_rocm(): return None @@ -142,23 +149,24 @@ class PetitNvFp4Config(QuantizationConfig): algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() return algo == "NVFP4" - def is_layer_excluded(self, prefix: str, - exclude_modules: list[str]) -> bool: + def is_layer_excluded(self, prefix: str, exclude_modules: list[str]) -> bool: for pattern in exclude_modules: regex_str = pattern.replace(".", r"\.").replace("*", r".*") if re.fullmatch(regex_str, prefix): return True return False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import exclude = self.require_exclude_modules() if isinstance(layer, LinearBase): if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( - prefix, exclude): + prefix, exclude + ): return UnquantizedLinearMethod() return PetitNvFp4LinearMethod(self) elif isinstance(layer, Attention): @@ -220,8 +228,10 @@ class PetitNvFp4LinearMethod(LinearMethodBase): ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." 
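The `is_layer_excluded` helper in the petit hunk above converts glob-style entries from `exclude_modules` into regexes and full-matches them against the layer prefix, falling back to `UnquantizedLinearMethod` on a hit. A standalone illustration of that matching logic (the pattern strings and prefixes below are invented for the example):

```python
import re

def is_layer_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    for pattern in exclude_modules:
        # Escape literal dots, turn "*" into ".*", then require a full match.
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
    return False

exclude = ["*.mlp.gate", "model.layers.0.*"]
print(is_layer_excluded("model.layers.3.mlp.gate", exclude))          # True
print(is_layer_excluded("model.layers.0.self_attn.q_proj", exclude))  # True
print(is_layer_excluded("model.layers.3.mlp.up_proj", exclude))       # False
```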
+ ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") @@ -231,12 +241,15 @@ class PetitNvFp4LinearMethod(LinearMethodBase): layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % 16 != 0: - raise ValueError("Unsupported model when in features size is " - "not multiple of 16") + raise ValueError( + "Unsupported model when in features size is not multiple of 16" + ) - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_nvfp4_serialized - else params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) weight = ModelWeightParameter( data=torch.empty( @@ -283,8 +296,9 @@ class PetitNvFp4LinearMethod(LinearMethodBase): weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) - layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, - requires_grad=False) + layer.alpha = Parameter( + layer.input_scale * layer.weight_scale_2, requires_grad=False + ) prepare_nvfp4_layer_for_petit(layer) del layer.input_scale diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index d11cba2caba88..c0156321f65d2 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -8,18 +8,19 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - Fp8KVCacheMethod, - Fp8LinearMethod) +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) + GroupShape, + is_layer_skipped, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -36,20 +37,20 @@ class PTPCFp8Config(Fp8Config): ignored_layers: Optional[list[str]] = None, ) -> None: if not current_platform.is_rocm(): - raise ValueError( - "ptpc_fp8 quantization is supported only on ROCm.") + raise ValueError("ptpc_fp8 quantization is supported only on ROCm.") if not current_platform.has_device_capability(94): raise ValueError( "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." 
# noqa: E501 ) if activation_scheme == "static": - raise ValueError( - "ptpc_fp8 as of now only support dynamic quantization.") + raise ValueError("ptpc_fp8 as of now only support dynamic quantization.") - super().__init__(is_checkpoint_fp8_serialized=False, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + super().__init__( + is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -59,11 +60,11 @@ class PTPCFp8Config(Fp8Config): def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config": activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - return cls(activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + return cls(activation_scheme=activation_scheme, ignored_layers=ignored_layers) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): @@ -79,7 +80,7 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """Linear method for Per-Token and Per-Channel FP8 Quantization. Only supports loading quantized BF16 model checkpoints with dynamic activation scaling. To load FP16 model checkpoints, user must specify - to convert the FP16 model weight loading into BF16. + to convert the FP16 model weight loading into BF16. The weight scaling factor will be initialized after the model weights are loaded. @@ -92,38 +93,45 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """ def __init__(self, quant_config: PTPCFp8Config): + assert current_platform.is_rocm(), ( + "PTPCFp8LinearMethod is only supported on ROCm." + ) super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( - act_quant_static=False, - cutlass_fp8_supported=False, - act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - assert layer.weight.data.dtype == torch.bfloat16, \ - f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + assert layer.weight.data.dtype == torch.bfloat16, ( + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + ) # Quantize the weights. qweight, weight_scale = ops.scaled_fp8_quant( - layer.weight, scale=None, use_per_token_if_dynamic=True) + layer.weight, scale=None, use_per_token_if_dynamic=True + ) # Update the layer with the new values. 
layer.weight = Parameter( - qweight.t(), requires_grad=False) # Pretranspose the weight + qweight.t(), requires_grad=False + ) # Pretranspose the weight layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.input_scale = None - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index b67ee5cf453d7..51f9d56121bdd 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -8,18 +8,30 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 - QuarkMoEMethod) + QuarkMoEMethod, +) from vllm.model_executor.layers.quantization.quark.schemes import ( - QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkOCP_MX, + QuarkScheme, + QuarkW8A8Fp8, + QuarkW8A8Int8, +) from vllm.model_executor.layers.quantization.quark.utils import ( - deep_compare, should_ignore_layer) + deep_compare, + should_ignore_layer, +) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] @@ -28,12 +40,13 @@ logger = init_logger(__name__) class QuarkConfig(QuantizationConfig): - - def __init__(self, - quant_config: dict[str, Any], - kv_cache_group: Optional[list[str]] = None, - kv_cache_config: Optional[dict[str, Any]] = None, - pack_method: str = "reorder"): + def __init__( + self, + quant_config: dict[str, Any], + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, + pack_method: str = "reorder", + ): super().__init__() if kv_cache_group is None: kv_cache_group = [] @@ -55,15 +68,16 @@ class QuarkConfig(QuantizationConfig): def get_name(self) -> QuantizationMethods: return "quark" - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import # Check if the layer is skipped for quantization. 
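The PTPC linear method above quantizes weights per output channel at load time and activations dynamically per token, with `Fp8LinearOp` configured for `GroupShape.PER_TOKEN`. A rough emulation of the per-token activation step, assuming a PyTorch build that exposes `torch.float8_e4m3fn` (the real path goes through `ops.scaled_fp8_quant` and the hipBLASLt rowwise GEMM):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value in float8_e4m3fn

def per_token_fp8_quant(x: torch.Tensor):
    """One scale per token (row); returns fp8 values plus float32 scales."""
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax.to(torch.float32) / FP8_E4M3_MAX             # [num_tokens, 1]
    x_q = (x.to(torch.float32) / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q.to(torch.float8_e4m3fn), scale

x = torch.randn(4, 16, dtype=torch.bfloat16)
x_q, x_scale = per_token_fp8_quant(x)
# Dequantizing approximately recovers x: x ≈ x_q.to(torch.float32) * x_scale
```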
exclude_layers = cast(list[str], self.quant_config.get("exclude")) - if should_ignore_layer(prefix, - ignore=exclude_layers, - fused_mapping=self.packed_modules_mapping): + if should_ignore_layer( + prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + ): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -73,17 +87,17 @@ class QuarkConfig(QuantizationConfig): return QuarkKVCacheMethod(self) if isinstance(layer, FusedMoE): - return QuarkMoEMethod.get_moe_method(self, - module=layer, - layer_name=prefix) + return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix) return None @classmethod def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": export_config = config.get("export") if export_config is None: - raise ValueError("The export key should be included in " - "the configurations of Quark quantized model") + raise ValueError( + "The export key should be included in " + "the configurations of Quark quantized model" + ) kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) pack_method = cast(str, export_config.get("pack_method")) @@ -96,33 +110,32 @@ class QuarkConfig(QuantizationConfig): kv_cache_config = None else: kv_cache_set = set(kv_cache_group) - layer_quant_config = cast(dict[str, Any], - config.get("layer_quant_config")) + layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) layer_quant_names = list(layer_quant_config.keys()) layer_quant_set = set(layer_quant_names) if not kv_cache_set.issubset(layer_quant_set): - raise ValueError("The Quark quantized model has the " - "kv_cache_group parameter setting, " - "but no kv_cache quantization settings " - "were found in the quantization " - "configuration.") + raise ValueError( + "The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration." + ) q_configs = [ cast(dict[str, Any], layer_quant_config.get(name)) for name in kv_cache_group ] - if not all( - deep_compare(q_config, q_configs[0]) - for q_config in q_configs): + if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): raise ValueError( "The quantization method used for kv_cache should " "be the same, but the quantization method for the " - "kv_cache layer in the config is different.") + "kv_cache layer in the config is different." + ) kv_cache_config = q_configs[0].get("output_tensors") if kv_cache_config is None: - raise ValueError( - "The kv_cache quantization configuration is empty.") + raise ValueError("The kv_cache quantization configuration is empty.") # Since we have already set kv_cache quantization configurations, # we will remove the quantization configuration for the @@ -132,23 +145,22 @@ class QuarkConfig(QuantizationConfig): # In case q_proj output is also quantized, remove the configuration # to keep qkv consistency. 
- q_proj_q_config = cast(dict[str, Any], - layer_quant_config.get("*q_proj")) + q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) if q_proj_q_config is not None: q_proj_q_config["output_tensors"] = None - return cls(quant_config=config, - kv_cache_group=kv_cache_group, - kv_cache_config=kv_cache_config, - pack_method=pack_method) + return cls( + quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method, + ) @classmethod def get_config_filenames(cls) -> list[str]: return [] - def _check_scheme_supported(self, - min_capability: int, - error: bool = True) -> bool: + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: @@ -158,26 +170,33 @@ class QuarkConfig(QuantizationConfig): raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") + f"Current capability: {capability}.", + ) return supported else: return False - def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_fp8_w8a8( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: return False # Confirm weight scheme is supported - is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3" - and input_quant.get("dtype") == "fp8_e4m3") + is_fp8_dtype = ( + weight_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("dtype") == "fp8_e4m3" + ) is_static_weight = not weight_quant.get("is_dynamic") - is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") - in ["per_tensor", "per_channel"]) + is_per_tensor_or_channel_weight = weight_quant.get("qscheme") in [ + "per_tensor", + "per_channel", + ] - if not (is_fp8_dtype and is_static_weight - and is_per_tensor_or_channel_weight): + if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight): return False # Dynamic quantization is always supported if weights supported. @@ -185,76 +204,88 @@ class QuarkConfig(QuantizationConfig): return True # Confirm activation scheme is supported. - is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") + is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor" return is_per_tensor_activation - def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_static_tensor_w8a8( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. 
if weight_quant is None or input_quant is None: return False - is_int8_dtype = (weight_quant.get("dtype") == "int8" - and input_quant.get("dtype") == "int8") + is_int8_dtype = ( + weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8" + ) - is_tensor = (weight_quant.get("qscheme") - in ["per_tensor", "per_channel"] - and input_quant.get("qscheme") == "per_tensor") + is_tensor = ( + weight_quant.get("qscheme") in ["per_tensor", "per_channel"] + and input_quant.get("qscheme") == "per_tensor" + ) - is_static = (not weight_quant.get("is_dynamic") - and not input_quant.get("is_dynamic")) + is_static = not weight_quant.get("is_dynamic") and not input_quant.get( + "is_dynamic" + ) - is_weight_symmetric = (weight_quant.get("symmetric") is True) + is_weight_symmetric = weight_quant.get("symmetric") is True # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static - def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_ocp_mx( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: - logger.debug("Quark model is not in MX-FP4 format: " - "weight_quant or input_quant not set") - return False - - # Input and weight dtype needs to be fp4. - if weight_quant.get("dtype") != "fp4" or input_quant.get( - "dtype") != "fp4": - logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + logger.debug( + "Quark model is not in OCP MX format: " + "weight_quant or input_quant not set" + ) return False # Input and weight qscheme needs to be per group. - if weight_quant.get("qscheme") != "per_group" or input_quant.get( - "qscheme") != "per_group": - logger.debug("Quark model is not in MX-FP4 format: not per_group") + if ( + weight_quant.get("qscheme") != "per_group" + or input_quant.get("qscheme") != "per_group" + ): + logger.debug("Quark model is not in OCP MX format: not per_group") return False # Input and weight group size needs to be 32. - if weight_quant.get("group_size") != 32 or input_quant.get( - "group_size") != 32: - logger.debug( - "Quark model is not in MX-FP4 format: not group_size=32") - return False - - # Activations need to use dynamic quantization. - if input_quant.get("is_dynamic") is False: - logger.debug( - "Quark model is not in MX-FP4 format: not activation dynamic") + if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: + logger.debug("Quark model is not in OCP MX format: not group_size=32") return False # Activations and weight scales need to be in e8m0 format. - if weight_quant.get("scale_format") != "e8m0" or input_quant.get( - "scale_format") != "e8m0": + if ( + weight_quant.get("scale_format") != "e8m0" + or input_quant.get("scale_format") != "e8m0" + ): + logger.debug("Quark model is not in OCP MX format: not scale_format e8m0") + return False + + # Input and weight dtypes need to be any of fp4, + # fp6_e3m2 or fp6_e3m2, possibly mixed. 
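`_is_ocp_mx` above accepts any OCP MX combination: per-group quantization with `group_size=32`, `e8m0` scale storage, and fp4 / fp6_e3m2 / fp6_e2m3 element types, possibly mixed between weights and activations. An e8m0 scale is simply an 8-bit biased power-of-two exponent; the following is a hedged sketch of decoding and applying such block scales (the helper names are illustrative, not vLLM APIs):

```python
import torch

OCP_MX_BLOCK_SIZE = 32
E8M0_BIAS = 127  # decoded scale = 2 ** (stored_byte - 127)

def decode_e8m0(scale_u8: torch.Tensor) -> torch.Tensor:
    """Turn uint8 e8m0 exponents into float32 power-of-two scales."""
    return torch.exp2(scale_u8.to(torch.float32) - E8M0_BIAS)

def apply_block_scales(elems: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    """elems: [rows, K] already-decoded element values,
    scale_u8: [rows, K // 32] with one e8m0 scale per 32-element block."""
    scales = decode_e8m0(scale_u8).repeat_interleave(OCP_MX_BLOCK_SIZE, dim=-1)
    return elems * scales

vals = torch.randn(2, 64)
scales = torch.randint(120, 134, (2, 64 // OCP_MX_BLOCK_SIZE), dtype=torch.uint8)
print(apply_block_scales(vals, scales).shape)   # torch.Size([2, 64])
```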
+ if weight_quant.get("dtype") not in { + "fp4", + "fp6_e3m2", + "fp6_e2m3", + } or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}: logger.debug( - "Quark model is not in MX-FP4 format: not scale_format e8m0") + "Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3" + ) return False return True - def _find_matched_config(self, layer_name: str, - module: torch.nn.Module) -> dict[str, Any]: - + def _find_matched_config( + self, layer_name: str, module: torch.nn.Module + ) -> dict[str, Any]: proj_name = layer_name.split(".")[-1] if proj_name in self.packed_modules_mapping: shard_proj_names = self.packed_modules_mapping[proj_name] @@ -269,59 +300,66 @@ class QuarkConfig(QuantizationConfig): for shard_name in shard_names ] if not all( - deep_compare(q_config, shard_configs[0]) - for q_config in shard_configs): + deep_compare(q_config, shard_configs[0]) for q_config in shard_configs + ): raise ValueError( f"Found a different quantization configuration for " f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + "requires all to use the same scheme." + ) return shard_configs[0] else: layer_quant_config = cast( - dict[str, Any], self.quant_config.get("layer_quant_config")) + dict[str, Any], self.quant_config.get("layer_quant_config") + ) for name_pattern in layer_quant_config: if fnmatch.fnmatch(layer_name, name_pattern): return layer_quant_config[name_pattern] layer_type = cast(str, type(module)) layer_type_quant_config = cast( - dict[str, Any], - self.quant_config.get("layer_type_quant_config")) + dict[str, Any], self.quant_config.get("layer_type_quant_config") + ) if layer_type in layer_type_quant_config: return layer_type_quant_config[layer_type] global_quant_config = cast( - dict[str, Any], self.quant_config.get("global_quant_config")) + dict[str, Any], self.quant_config.get("global_quant_config") + ) return global_quant_config def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " - "and bias quantized are not supported") + "and bias quantized are not supported" + ) weight_config = cast(dict[str, Any], config.get("weight")) input_config = cast(dict[str, Any], config.get("input_tensors")) if self._is_fp8_w8a8(weight_config, input_config): is_fp8_w8a8_supported = self._check_scheme_supported( - QuarkW8A8Fp8.get_min_capability(), error=False) + QuarkW8A8Fp8.get_min_capability(), error=False + ) if is_fp8_w8a8_supported: return QuarkW8A8Fp8(weight_config, input_config) elif self._is_static_tensor_w8a8(weight_config, input_config): weight_qscheme = cast(str, weight_config.get("qscheme")) - return QuarkW8A8Int8(qscheme=weight_qscheme, - is_static_input_scheme=True, - input_symmetric=input_config.get("symmetric")) - elif self._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFP4(weight_config, input_config) + return QuarkW8A8Int8( + qscheme=weight_qscheme, + is_static_input_scheme=True, + input_symmetric=input_config.get("symmetric"), + ) + elif self._is_ocp_mx(weight_config, input_config): + return QuarkOCP_MX(weight_config, input_config) - raise NotImplementedError("No quark compatible scheme was found. " - f"Weight config: {weight_config}, " - f"Input config: {input_config}") - - def get_scheme(self, layer: torch.nn.Module, - layer_name: str) -> "QuarkScheme": + raise NotImplementedError( + "No quark compatible scheme was found. 
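`_find_matched_config` above resolves a layer's quantization settings by expanding fused projections through `packed_modules_mapping` and then falling back from glob patterns in `layer_quant_config` to the layer-type and global configs. A toy version of the pattern-matching fallback (the config contents below are invented for illustration):

```python
import fnmatch
from typing import Any

layer_quant_config: dict[str, Any] = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*mlp.down_proj": {"weight": {"dtype": "int8"}},
}
global_quant_config: dict[str, Any] = {"weight": {"dtype": "fp8_e4m3"}}

def find_matched_config(layer_name: str) -> dict[str, Any]:
    for name_pattern, cfg in layer_quant_config.items():
        if fnmatch.fnmatch(layer_name, name_pattern):
            return cfg
    # The layer_type_quant_config lookup is omitted in this sketch.
    return global_quant_config

print(find_matched_config("model.layers.5.self_attn.q_proj"))   # fp8_e4m3 entry
print(find_matched_config("model.layers.5.mlp.gate_proj"))      # global fallback
```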
" + f"Weight config: {weight_config}, " + f"Input config: {input_config}" + ) + def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme @@ -335,7 +373,7 @@ class QuarkConfig(QuantizationConfig): def get_cache_scale(self, name: str) -> Optional[str]: """ Check whether the param name matches the format for k/v cache scales - in quark. If this is the case, return its equivalent param name + in quark. If this is the case, return its equivalent param name expected by vLLM :param name: param name @@ -355,18 +393,22 @@ class QuarkConfig(QuantizationConfig): class QuarkLinearMethod(LinearMethodBase): - def __init__(self, quantization_config: QuarkConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """ Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param @@ -380,12 +422,15 @@ class QuarkLinearMethod(LinearMethodBase): output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the @@ -395,6 +440,7 @@ class QuarkLinearMethod(LinearMethodBase): scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) @@ -421,11 +467,13 @@ class QuarkKVCacheMethod(BaseKVCacheMethod): if dtype != "fp8_e4m3": raise NotImplementedError( "Currently supported kv cache quantization is " - f"dtype=fp8_e4m3, however received {dtype}") + f"dtype=fp8_e4m3, however received {dtype}" + ) qscheme = kv_cache_config.get("qscheme") if qscheme != "per_tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for quark KV cache. 
" - f"Expected qscheme: per_tensor, found qscheme: {qscheme}") + f"Expected qscheme: per_tensor, found qscheme: {qscheme}" + ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 58f56c6381b31..f00188a6f8c40 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,62 +1,79 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + ocp_mx_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_moe_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( + OCP_MX_BLOCK_SIZE, + OCP_MX_Scheme, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) + all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) -__all__ = [ - "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod" -] +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"] class QuarkMoEMethod(FusedMoEMethodBase): - def __init__(self, moe: FusedMoEConfig): super().__init__(moe) @staticmethod def get_moe_method( - quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 - module: torch.nn.Module, - layer_name: str) -> "QuarkMoEMethod": - layer_quant_config = quant_config._find_matched_config( - layer_name, module) + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str, + ) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config(layer_name, module) - if (layer_quant_config.get("output_tensors") - or layer_quant_config.get("bias")): - raise NotImplementedError("Currently, Quark models with " - "output_tensors and bias " - "quantized are not supported") + if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported" + ) weight_config = layer_quant_config.get("weight") input_config = layer_quant_config.get("input_tensors") if quant_config._is_fp8_w8a8(weight_config, input_config): - return QuarkW8A8Fp8MoEMethod(weight_config, input_config, - module.moe_config) - elif quant_config._is_mx_fp4(weight_config, input_config): - return 
QuarkW4A4MXFp4MoEMethod(weight_config, input_config, - module.moe_config) + return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) + elif quant_config._is_ocp_mx(weight_config, input_config): + return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): - def __init__( self, weight_config: dict[str, Any], @@ -67,73 +84,136 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): self.weight_quant = weight_config self.input_quant = input_config - weight_qscheme = self.weight_quant.get("qscheme") - input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_tensor" - and input_qscheme == "per_tensor"): + self.weight_qscheme = self.weight_quant.get("qscheme") + self.input_qscheme = self.input_quant.get("qscheme") + per_tensor = ( + self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor" + ) + per_channel = ( + self.weight_qscheme == "per_channel" and self.input_qscheme == "per_channel" + ) + self.act_quant_group_shape = ( + GroupShape.PER_TOKEN if per_channel else GroupShape.PER_TENSOR + ) + if not (per_tensor or per_channel): raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales " - "for weights and activations are supported. Found " - f"{weight_qscheme}, {input_qscheme}") # noqa E501 + "For FP8 Fused MoE layers, only per-tensor and per-channel " + "scales for weights and activations are supported. Found " + f"{self.weight_qscheme}, {self.input_qscheme}" + ) # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization." 
+ ) - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None params_dtype = torch.float8_e4m3fn # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + if self.weight_qscheme == "per_tensor": + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + elif self.weight_qscheme == "per_channel": + # quark's scale is 1 dim. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: @@ -144,65 +224,123 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.static_input_scales: - if (layer.w13_input_scale is None or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") + "for each layer. 
" + ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, - requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size + # For per-tensor case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. 
+ if self.weight_qscheme == "per_tensor": + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + # quark's scale is 1 dim. + elif self.weight_qscheme == "per_channel": + if self.act_quant_group_shape == GroupShape.PER_TOKEN: + w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + # Property to determine if AITER is used + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, + shuffle_weights, + ) + + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data + ) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + + self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts + elif self.use_marlin: + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. + del layer.w13_input_scale + del layer.w2_input_scale + self.fused_experts_func = None + else: + from vllm.model_executor.layers.fused_moe import fused_experts + + self.fused_experts_func = fused_experts + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=self.weight_qscheme == "per_channel", + ) def apply( self, @@ -218,6 +356,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -225,16 +364,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet." 
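For the per-tensor case, the loop above collapses the separate w1/w3 scales into a single scale per expert: take the per-expert max, dequantize each half of `w13_weight` with its original scale, then requantize with the shared max. A float32 emulation of that idea (the real code stores fp8 and uses `ops.scaled_fp8_quant`):

```python
import torch

def fuse_w13_scales(w13_q: torch.Tensor, w13_scale: torch.Tensor, shard_size: int):
    """w13_q: [E, 2 * shard_size, H] quantized values (emulated in float32),
    w13_scale: [E, 2] with one scale for w1 and one for w3 per expert."""
    max_scales = w13_scale.max(dim=1).values                  # [E]
    for e in range(w13_q.shape[0]):
        start = 0
        for shard_id in range(2):
            block = w13_q[e, start:start + shard_size, :]
            dq = block * w13_scale[e, shard_id]               # dequantize
            w13_q[e, start:start + shard_size, :] = dq / max_scales[e]  # requantize
            start += shard_size
    return w13_q, max_scales

E, I, H = 2, 8, 16
w13 = torch.randn(E, 2 * I, H)
scales = torch.rand(E, 2) + 0.5
_, fused = fuse_w13_scales(w13, scales, shard_size=I)
print(fused.shape)   # torch.Size([2])
```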
+ ) - from vllm.model_executor.layers.fused_moe import fused_experts - - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -244,29 +382,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + quant_config=self.moe_quant_config, + expert_map=expert_map, + ) + if self.use_marlin: + assert activation == "silu", f"{activation} not supported for Marlin MoE." + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + None, + None, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + assert self.fused_experts_func is not None + + return self.fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, + activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - activation=activation) + quant_config=self.moe_quant_config, + ) -class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): - +class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): def __init__( self, weight_config: dict[str, Any], @@ -279,64 +448,95 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): weight_qscheme = self.weight_quant.get("qscheme") input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_group" - and input_qscheme == "per_group"): + if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): raise ValueError( "For MX(FP4) Fused MoE layers, only per-group scales " "for weights and activations are supported. Found " - f"{weight_qscheme}, {input_qscheme}") # noqa E501 + f"{weight_qscheme}, {input_qscheme}" + ) # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") + self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp") + self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp") + + self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( + self.input_dtype, self.weight_dtype + ) + if self.static_input_scales: raise NotImplementedError( - "QuarkW4A4MXFp4MoEMethod with static input scales is currently " - "not implemented. Please open an issue.") + "QuarkOCP_MX_MoEMethod with static input scales is currently " + "not implemented. Please open an issue." 
+ ) if not current_platform.supports_mx(): self.emulate = True logger.warning_once( - "The current platform does not support native MXFP4 " + "The current platform does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") + "layers computed in high precision." + ) else: self.emulate = True logger.warning_once( - "The current platform supports native MXFP4 " + "The current platform supports native MXFP4/MXFP6 " "computation, but kernels are not yet integrated in vLLM. " "Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") + "layers computed in high precision." + ) - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def get_packed_dim(self, dim: int, quant_dtype: str): + if quant_dtype == "mxfp4": + assert dim % 2 == 0 + return dim // 2 + else: + # FP6 packs 4 * 6 = 24 bits on 3 bytes. + assert (dim * 3) % 4 == 0 + return (dim * 3) // 4 + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) params_dtype = torch.uint8 # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // 2, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + self.get_packed_dim(hidden_size, self.weight_dtype), + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // 2, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype), + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -366,6 +566,19 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, + ) + def apply( self, layer: torch.nn.Module, @@ -380,6 +593,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: 
bool = False, activation: str = "silu", @@ -387,16 +601,17 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") + "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet." + ) from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -406,8 +621,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) out = fused_experts( x, @@ -416,15 +633,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_mxfp4_w4a4=True, + activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=None, - a2_scale=None, - block_shape=None, - activation=activation, + quant_config=self.moe_quant_config, ) return out diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index ec09d9b2ac26f..7620d6e41b58a 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .quark_ocp_mx import QuarkOCP_MX from .quark_scheme import QuarkScheme -from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"] +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py new file mode 100644 index 0000000000000..0eefa7f7e96c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -0,0 +1,298 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from fractions import Fraction +from functools import cache, partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F + +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + dequant_mxfp4, + quant_dequant_mxfp4, +) +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import ( + dequant_mxfp6, + quant_dequant_mxfp6, +) +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( + OCP_MX_BLOCK_SIZE, + OCP_MX_Scheme, +) +from vllm.model_executor.parameter import 
GroupQuantScaleParameter, PackedvLLMParameter +from vllm.platforms import current_platform + +from .quark_scheme import QuarkScheme + +logger = init_logger(__name__) + + +@cache +def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + and envs.VLLM_ROCM_USE_AITER + ) + + +try: + from aiter.ops.shuffle import shuffle_weight + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + from vllm.utils import direct_register_custom_op + + if is_rocm_aiter_fp4_asm_gemm_enabled(): + from aiter import gemm_a4w4, per_1x32_f4_quant_hip + + def gemm_with_dynamic_quant( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + rocm_use_aiter_fp4_asm_gemm: bool = False, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + x_scales: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + M = x.shape[0] + if rocm_use_aiter_fp4_asm_gemm: + if x_scales is None: + # use hip quant kernel for performance + x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True) + else: + x_q = x + x_s = x_scales + + # 32 alignment is enough for dim0 padding of output for + # gemm_a4w4 kernel + y = torch.empty( + (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_a4w4( + x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True + ) + return y[:M] + else: + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + def gemm_with_dynamic_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + x_scales: torch.Tensor = None, + rocm_use_aiter_fp4_asm_gemm: bool = False, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + ) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device + ) + + direct_register_custom_op( + op_name="gemm_with_dynamic_quant", + op_func=gemm_with_dynamic_quant, + mutates_args=[], + fake_impl=gemm_with_dynamic_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) +except (ImportError, AttributeError): + dynamic_mxfp4_quant = gemm_afp4wfp4 = None + + +class QuarkOCP_MX(QuarkScheme): + def __init__( + self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] + ): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + + self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") + self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") + + self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( + self.input_dtype, self.weight_dtype + ) + + if self.weight_dtype == "mxfp4": + self.packed_factor: Union[int, Fraction] = 2 + self.dequant_func = dequant_mxfp4 + else: + self.packed_factor = Fraction(numerator=8, denominator=6) + self.dequant_func = partial( + dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "") + ) + + if self.input_dtype == "mxfp4": + self.quant_dequant_func = quant_dequant_mxfp4 + else: + self.quant_dequant_func = partial( + quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "") + ) + + self.static_input_scales = not input_quant_spec.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkOCP_MX with static input scales is 
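`QuarkOCP_MX` above uses a `packed_factor` of 2 for mxfp4 (two 4-bit codes per uint8 byte) and `Fraction(8, 6)` for the mxfp6 variants (four 6-bit codes per three bytes), which is exactly the arithmetic in `get_packed_dim`. A quick check of that math, plus an illustrative nibble pack/unpack for the fp4 case (low-nibble-first ordering is an assumption here, not necessarily the kernel layout):

```python
import torch
from fractions import Fraction

def get_packed_dim(dim: int, quant_dtype: str) -> int:
    if quant_dtype == "mxfp4":
        assert dim % 2 == 0
        return dim // 2              # two codes per byte
    assert (dim * 3) % 4 == 0
    return (dim * 3) // 4            # fp6: four codes per three bytes

print(get_packed_dim(4096, "mxfp4"))        # 2048
print(get_packed_dim(4096, "mxfp6_e2m3"))   # 3072
print(int(4096 / Fraction(8, 6)))           # 3072, same result via packed_factor

codes = torch.randint(0, 16, (8,), dtype=torch.uint8)         # 4-bit codes
packed = (codes[0::2] | (codes[1::2] << 4)).to(torch.uint8)   # 4 bytes
lo, hi = packed & 0xF, packed >> 4
assert torch.equal(torch.stack([lo, hi], dim=1).flatten(), codes)
```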
currently not " + "implemented. Please open an issue." + ) + + # TODO: integrate (or test) mixed-precision kernel. + self.emulate = not current_platform.supports_mx() or ( + self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4" + ) + + self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() + + if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None): + # Currently need these kernels if not emulating + raise NotImplementedError( + f"{self.__class__.__name__} requires AITER to be installed " + "for non-emulation mode! Please refer to " + "https://github.com/ROCm/aiter for installation details." + ) + + if not current_platform.supports_mx(): + logger.warning_once( + "The current platform does not support native MXFP4/MXFP6 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision." + ) + + if current_platform.supports_mx() and ( + self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4" + ): + logger.warning_once( + "The current platform supports native MXFP4/MXFP6 " + f"computation, but kernels for input_dtype={self.input_dtype} " + f"and weight_dtype={self.weight_dtype} are not yet integrated " + "in vLLM. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision." + ) + + def get_packed_dim(self, dim: int, quant_dtype: str): + if quant_dtype == "mxfp4": + assert dim % 2 == 0 + return dim // 2 + elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}: + # FP6 packs 4 * 6 = 24 bits on 3 bytes. + assert (dim * 3) % 4 == 0 + return (dim * 3) // 4 + else: + raise NotImplementedError( + "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, " + f"got quant_dtype={quant_dtype}. Something is wrong, please " + "open an issue." 
+ ) + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + + if self.emulate: + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) + else: + if self.rocm_use_aiter_fp4_asm_gemm: + # shuffle weight scale + weight_scale_shuffle = layer.weight_scale.data + sm, sn = weight_scale_shuffle.shape + weight_scale_shuffle = weight_scale_shuffle.view( + sm // 32, 2, 16, sn // 8, 2, 4, 1 + ) + weight_scale_shuffle = weight_scale_shuffle.permute( + 0, 3, 5, 2, 4, 1, 6 + ).contiguous() + weight_scale_shuffle = weight_scale_shuffle.view(sm, sn) + layer.weight_scale = torch.nn.Parameter( + weight_scale_shuffle, requires_grad=False + ) + + # shuffle weight + weight_shuffle = layer.weight.data + weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16)) + layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False) + else: + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), requires_grad=False + ) + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + self.get_packed_dim(input_size_per_partition, self.weight_dtype), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.packed_factor, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.emulate: + dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype) + qdq_x = self.quant_dequant_func(x) + return F.linear(qdq_x, dq_w, bias) + else: + return torch.ops.vllm.gemm_with_dynamic_quant( + x, + layer.weight, + layer.weight_scale, + self.rocm_use_aiter_fp4_asm_gemm, + self.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py index c167e949ac262..ddec0f6ea8eb8 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -11,7 +11,7 @@ __all__ = ["QuarkScheme"] class QuarkScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass + Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by Quark. """ @@ -26,20 +26,21 @@ class QuarkScheme(ABC): @abstractmethod def create_weights(self, *args, **kwargs): """ - Weight creation for the particular scheme. Inputs to this function + Weight creation for the particular scheme. 
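A note on the packed-dimension arithmetic used by the new `QuarkOCP_MX` scheme above: MXFP4 stores two 4-bit codes per byte, while MXFP6 stores four 6-bit codes (24 bits) in three bytes, which is where the `Fraction(numerator=8, denominator=6)` packing factor comes from. A minimal standalone sketch of that arithmetic (the function name here is illustrative, not part of the patch):

```python
from fractions import Fraction

def packed_dim(dim: int, quant_dtype: str) -> int:
    """Size of `dim` once sub-byte codes are packed into uint8 storage."""
    if quant_dtype == "mxfp4":
        # Two 4-bit codes share one byte.
        assert dim % 2 == 0
        return dim // 2
    if quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
        # Four 6-bit codes (24 bits) fit into three bytes.
        assert (dim * 3) % 4 == 0
        return (dim * 3) // 4
    raise NotImplementedError(quant_dtype)

assert packed_dim(4096, "mxfp4") == 2048
assert packed_dim(4096, "mxfp6_e2m3") == 3072
# Ratio of unpacked to packed elements matches the PackedvLLMParameter factors.
assert Fraction(4096, packed_dim(4096, "mxfp6_e2m3")) == Fraction(8, 6)
```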
Inputs to this function """ raise NotImplementedError @abstractmethod - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]): + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): """ - Run the forward pass for the particular scheme. This is where + Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. :param x: input to the layer :param bias: bias parameter diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py deleted file mode 100644 index 880438a22a695..0000000000000 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Optional - -import torch -import torch.nn.functional as F - -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) -from vllm.platforms import current_platform - -logger = init_logger(__name__) - -__all__ = ["QuarkW4A4MXFP4"] - - -class QuarkW4A4MXFP4(QuarkScheme): - - def __init__(self, weight_quant_spec: dict[str, Any], - input_quant_spec: dict[str, Any]): - self.out_dtype = torch.get_default_dtype() - self.qscheme = "per_group" - self.weight_quant_spec = weight_quant_spec - self.input_quant_spec = input_quant_spec - - self.static_input_scales = not input_quant_spec.get("is_dynamic") - - if self.static_input_scales: - raise NotImplementedError( - "QuarkW4A4MXFP4 with static input scales is currently not " - "implemented. Please open an issue.") - - if not current_platform.supports_mx(): - self.emulate = True - logger.warning_once( - "The current platform does not support native MXFP4 " - "computation. Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") - else: - self.emulate = True - logger.warning_once( - "The current platform supports native MXFP4 " - "computation, but kernels are not yet integrated in vLLM. 
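Both the deleted `QuarkW4A4MXFP4` scheme and the new `QuarkOCP_MX` fall back to the same emulation path when native MX kernels are unavailable: weights are dequantized to high precision, activations go through quantize-dequantize (QDQ), and the layer runs as a plain `F.linear`. The sketch below illustrates only that control flow; it substitutes a per-tensor int8 round-trip for the real MXFP4/MXFP6 QDQ helpers, so the quantizer itself is a stand-in, not the scheme's math:

```python
import torch
import torch.nn.functional as F

def qdq_stand_in(x: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize round-trip through a coarse grid (int8 stand-in)."""
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    return (x / scale).round().clamp(-128, 127) * scale

# Emulated forward pass: dequantized (high-precision) weight, QDQ'd activation.
x = torch.randn(4, 64)
w_dequantized = torch.randn(32, 64)           # stands in for dequant_mxfp4(...)
y = F.linear(qdq_stand_in(x), w_dequantized)  # computed entirely in high precision
assert y.shape == (4, 32)
```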
" - "Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - - # WEIGHT - weight = PackedvLLMParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition // 2, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - packed_dim=1, - packed_factor=2, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - weight_scale = GroupQuantScaleParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition // OCP_MX_BLOCK_SIZE, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - if self.emulate: - dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) - - x = quant_dequant_mxfp4(x) - - return F.linear(x, dq_w, bias) - else: - raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 2cb35249f49ef..553698a7dc94a 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,37 +7,43 @@ import torch from torch.nn import Parameter from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.platforms import current_platform __all__ = ["QuarkW8A8Fp8"] class QuarkW8A8Fp8(QuarkScheme): - - def __init__(self, weight_config: dict[str, Any], - input_config: Optional[dict[str, Any]]): + def __init__( + self, weight_config: dict[str, Any], input_config: Optional[dict[str, Any]] + ): self.weight_qscheme = cast(str, weight_config.get("qscheme")) self.is_static_input_scheme: bool = False self.input_qscheme: Optional[str] = None if input_config is not None: - self.is_static_input_scheme = not cast( - bool, input_config.get("is_dynamic")) + self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = (not 
self.is_static_input_scheme - and self.input_qscheme == "per_channel") - self.act_quant_group_shape = GroupShape.PER_TOKEN \ - if per_token else GroupShape.PER_TENSOR + per_token = ( + not self.is_static_input_scheme and self.input_qscheme == "per_channel" + ) + self.act_quant_group_shape = ( + GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + ) self.fp8_linear = Fp8LinearOp( act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_quant_group_shape) + act_quant_group_shape=self.act_quant_group_shape, + ) self.out_dtype = torch.get_default_dtype() @classmethod @@ -51,14 +57,14 @@ class QuarkW8A8Fp8(QuarkScheme): # requantize so we can always run per tensor if self.weight_qscheme == "per_tensor": if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) + input_scale = getattr(layer, "input_scale", None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, weight_scale=layer.weight_scale, - input_scale=input_scale) + input_scale=input_scale, + ) if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + layer.input_scale = Parameter(input_scale, requires_grad=False) else: max_w_scale = layer.weight_scale weight = layer.weight @@ -77,15 +83,14 @@ class QuarkW8A8Fp8(QuarkScheme): weight = layer.weight if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=input_scale) + input_scale = getattr(layer, "input_scale", None) + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale, + ) if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data if self.act_quant_group_shape == GroupShape.PER_TOKEN: @@ -95,32 +100,37 @@ class QuarkW8A8Fp8(QuarkScheme): layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: - raise ValueError( - f"Unknown quantization scheme {self.weight_qscheme}") + raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}") # INPUT SCALE if self.is_static_input_scheme: - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE @@ -128,15 +138,16 @@ 
class QuarkW8A8Fp8(QuarkScheme): # the newly added parameters if self.weight_qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.weight_qscheme == "per_tensor" - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) # min requirement for fp8 kernels weight_scale[:] = torch.finfo(torch.float32).min @@ -144,20 +155,24 @@ class QuarkW8A8Fp8(QuarkScheme): # INPUT SCALE if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index ae68d5bbc2680..c41dd05d10629 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,12 +7,16 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) + ScaledMMLinearLayerConfig, + choose_scaled_mm_linear_kernel, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) logger = init_logger(__name__) @@ -20,8 +24,12 @@ logger = init_logger(__name__) class QuarkW8A8Int8(QuarkScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], - input_symmetric: Optional[bool]): + def __init__( + self, + qscheme: str, + is_static_input_scheme: Optional[bool], + input_symmetric: Optional[bool], + ): self.qscheme = qscheme self.is_static_input_scheme = is_static_input_scheme self.input_symmetric = input_symmetric @@ -31,92 +39,101 @@ class QuarkW8A8Int8(QuarkScheme): # turing and up return 75 - def create_weights(self, layer: torch.nn.Module, - 
output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): layer.logical_widths = output_partition_sizes scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), - input_symmetric=(self.input_symmetric is True)) + input_symmetric=(self.input_symmetric is True), + ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) ChannelQuantZPParameter = ChannelQuantScaleParameter weight_zero_point = ChannelQuantZPParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.int8), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.qscheme == "per_tensor" - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) PerTensorZPParameter = PerTensorScaleParameter weight_zero_point = PerTensorZPParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.int8), - weight_loader=weight_loader) + data=torch.empty(len(output_partition_sizes), dtype=torch.int8), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE if self.is_static_input_scheme: - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader + ) layer.register_parameter("input_scale", input_scale) - input_zero_point = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.int8), - weight_loader=weight_loader) + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader + ) layer.register_parameter("input_zero_point", input_zero_point) - self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", 
- i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj") + self.kernel = kernel_type( + c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj", + ) # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.register_parameter("weight_zero_point", None) - delattr(layer, 'weight_zero_point') + delattr(layer, "weight_zero_point") if self.input_symmetric: layer.register_parameter("input_zero_point", None) - delattr(layer, 'input_zero_point') + delattr(layer, "input_zero_point") self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 99f5ec15933ab..0eb4b20a6e52c 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -24,7 +24,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: if layer_name is None: return False @@ -50,7 +50,8 @@ def should_ignore_layer( should_ignore_layer = None for shard_name in shard_names: should_ignore_shard = check_equal_or_regex_match( - layer_name=shard_name, targets=ignore) + layer_name=shard_name, targets=ignore + ) # If shard_idx=0, set layer ignore to match shard. if should_ignore_layer is None: @@ -58,35 +59,34 @@ def should_ignore_layer( # If shard_idx=1+ confirm scheme matches prior shards. elif should_ignore_shard != should_ignore_layer: - raise ValueError(f"Found a different quantization schemes for " - f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) # Unfused layers like down_proj and o_proj will match # the safetensors checkpoint already. else: - should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, - targets=ignore) + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) assert should_ignore_layer is not None return should_ignore_layer -def check_equal_or_regex_match(layer_name: str, - targets: Iterable[str]) -> bool: +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: """ - Checks whether a layer_name is exactly equal or a regex match for + Checks whether a layer_name is exactly equal or a regex match for if target starts with 're:' to any target in list. 
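The ignore-list helpers above rely on a small convention: an entry either names a layer exactly or, when prefixed with `re:`, the remainder is treated as a regular expression. A self-contained sketch of that convention (the function name and the example patterns are illustrative, not taken from a real config):

```python
import re
from collections.abc import Iterable

def matches_ignore_list(layer_name: str, targets: Iterable[str]) -> bool:
    """Return True if layer_name equals a target or matches an `re:` pattern."""
    for target in targets:
        if target.startswith("re:"):
            # Everything after 're:' is a regular expression.
            if re.match(target[3:], layer_name):
                return True
        elif layer_name == target:
            return True
    return False

ignore = ["lm_head", r"re:.*mlp\.gate$"]
assert matches_ignore_list("lm_head", ignore)
assert matches_ignore_list("model.layers.0.mlp.gate", ignore)
assert not matches_ignore_list("model.layers.0.mlp.gate_up_proj", ignore)
```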
""" - for target in targets: - if _is_equal_or_regex_match(layer_name, target): - return True - return False + return any(_is_equal_or_regex_match(layer_name, target) for target in targets) -def _is_equal_or_regex_match(value: str, - target: str, - check_contains: bool = False) -> bool: +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: """ Checks whether a value is exactly equal or a regex match for target if target starts with 're:'. If check_contains is set to True, diff --git a/vllm/model_executor/layers/quantization/qutlass_utils.py b/vllm/model_executor/layers/quantization/qutlass_utils.py new file mode 100644 index 0000000000000..395bde76d02ae --- /dev/null +++ b/vllm/model_executor/layers/quantization/qutlass_utils.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# +# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Literal + +import torch +import triton +import triton.language as tl +from torch.library import wrap_triton + + +@triton.jit +def triton_scale_swizzle( + scale_ptr: torch.Tensor, + scale_rows: int, + scale_cols: int, + output_ptr: torch.Tensor, + input_row_stride: int, + output_block_stride: int, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + """ + Rearranges tensor data from row-major to block-scaled swizzle format. + + Args: + scale_ptr: Pointer to the input scale tensor + scale_rows: Number of rows in the scale tensor + scale_cols: Number of columns in the scale tensor + output_ptr: Pointer to the output tensor + input_row_stride: Stride between rows in the input tensor + output_block_stride: Stride between blocks in the output tensor + BLOCK_ROWS: Number of rows in a tile (compile-time constant) + BLOCK_COLS: Number of columns in a tile (compile-time constant) + """ + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + rows = tl.arange(0, BLOCK_ROWS)[:, None] + cols = tl.arange(0, BLOCK_COLS)[None, :] + + # Calculate starting row and column for this tile + start_row = pid_row * BLOCK_ROWS + start_col = pid_col * BLOCK_COLS + global_rows = start_row + rows + global_cols = start_col + cols + + mask = (global_rows < scale_rows) & (global_cols < scale_cols) + + input_scales = tl.load( + scale_ptr + global_rows * input_row_stride + global_cols, + mask=mask, + other=0.0, + ) + + r_div_32 = rows // 32 + r_mod_32 = rows % 32 + + # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS + block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride) + + tl.store( + output_ptr + block_offset + dest_indices_flat, + scales_flat, + ) + + +def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale from row-major format to + block-scaled swizzle format. 
+ + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scale_tensor: Input tensor in row-major format with 8-bit elements + + Returns: + Rearranged tensor in block-scaled swizzle format + """ + assert scale_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + assert scale_tensor.is_contiguous(), "Input tensor must be contiguous" + + rows, cols = scale_tensor.shape + + # Calculate blocks needed + n_row_blocks = triton.cdiv(rows, 128) + n_col_blocks = triton.cdiv(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + out = scale_tensor.new_empty((padded_rows, padded_cols)) + + # Input stride (for row-major format) + input_row_stride = cols + + # We probably want handle multiple blocks per tile but + # for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + + # Output block stride for the rearranged format + output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + + grid = lambda META: ( + triton.cdiv(padded_rows, BLOCK_ROWS), + triton.cdiv(padded_cols, BLOCK_COLS), + ) + + wrap_triton(triton_scale_swizzle)[grid]( + scale_tensor.view(torch.uint8), + rows, + cols, + out.view(torch.uint8), + input_row_stride, + output_block_stride, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + return out + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def to_blocked( + input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton" +) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying + the rearrangement pattern. + + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + backend: "torch" (PyTorch path) or "triton" (Triton kernel) + + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + if backend == "triton": + return triton_mx_block_rearrange(input_matrix).flatten() + elif backend != "torch": + raise ValueError(f'backend must be "torch" or "triton", got {backend!r}') + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + assert (rows, cols) == (padded_rows, padded_cols) + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 8bdb50e07b137..e0070e207048f 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,40 +3,52 @@ # Copyright © 2025, Oracle and/or its affiliates. 
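For reference, the pure-PyTorch fallback path of `to_blocked` above can be exercised on its own when the scale matrix is already padded to multiples of (128, 4); the helper name below is illustrative, but the view/permute/reshape sequence mirrors the torch branch of the patch:

```python
import torch

def to_blocked_torch(scales: torch.Tensor) -> torch.Tensor:
    rows, cols = scales.shape
    assert rows % 128 == 0 and cols % 4 == 0, "sketch assumes pre-padded input"
    n_row_blocks, n_col_blocks = rows // 128, cols // 4
    # Split into (128, 4) tiles, then lay each tile out as (32, 16): element
    # (r, c) of a tile lands at row r % 32, column (r // 32) * 4 + c.
    blocks = scales.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return rearranged.flatten()

e8m0_scales = torch.randint(0, 255, (256, 8), dtype=torch.uint8)
swizzled = to_blocked_torch(e8m0_scales)
assert swizzled.numel() == e8m0_scales.numel()  # same data, swizzled order
```

The Triton kernel in the patch computes the same destination indices (`r_mod_32 * 16 + r_div_32 * 4 + cols`) but writes them directly into the padded output buffer.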
import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) logger = init_logger(__name__) """By default, use 8 bit as target precision, but it can be overridden by setting the RTN_NUM_BITS envvar """ -NUM_BITS = os.getenv('RTN_NUM_BITS', "8") +NUM_BITS = os.getenv("RTN_NUM_BITS", "8") """By default, use group size of 128 parameters, but it can be overridden by setting the RTN_GROUP_SIZE envvar """ -GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") +GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128") class RTNConfig(QuantizationConfig): - """Config class for RTN. - """ + """Config class for RTN.""" def __init__( - self, - weight_bits: int = int(NUM_BITS), - group_size: int = int(GROUP_SIZE), + self, + weight_bits: int = int(NUM_BITS), + group_size: int = int(GROUP_SIZE), ) -> None: self.weight_bits = weight_bits self.group_size = group_size @@ -44,11 +56,13 @@ class RTNConfig(QuantizationConfig): if self.weight_bits != 4 and self.weight_bits != 8: raise ValueError( "Currently, only 4-bit or 8-bit weight quantization is " - f"supported for RTN, but got {self.weight_bits} bits.") + f"supported for RTN, but got {self.weight_bits} bits." + ) def __repr__(self) -> str: - return (f"RTNConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size})") + return ( + f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -72,8 +86,9 @@ class RTNConfig(QuantizationConfig): group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return RTNLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -86,8 +101,9 @@ class RTNTensor: overloading the copy_ method. 
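The `RTNTensor`/`RTNParameter` pair above exists so that quantization happens at weight-load time: the loader only ever calls `copy_`, so intercepting that call is enough to store the quantized data plus its scale. A toy, self-contained version of the pattern (per-row symmetric int8 here purely for brevity; the real code calls `rtn_quantize` and also implements `narrow`/`__getitem__` for sharded loading):

```python
import torch

class QuantizeOnCopy:
    def __init__(self, data: torch.Tensor, scale: torch.Tensor) -> None:
        self.data = data    # quantized storage registered on the layer
        self.scale = scale  # per-row scales registered on the layer

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        # Quantize on the fly instead of copying the fp weight verbatim.
        scale = loaded_weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        q = (loaded_weight / scale).round().clamp(-128, 127).to(torch.int8)
        self.data.copy_(q)
        self.scale.copy_(scale)

holder = QuantizeOnCopy(torch.empty(16, 32, dtype=torch.int8), torch.empty(16, 1))
w_fp = torch.randn(16, 32)
holder.copy_(w_fp)
w_hat = holder.data.float() * holder.scale
assert (w_fp - w_hat).abs().max() <= holder.scale.max() / 2 + 1e-6
```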
""" - def __init__(self, data: torch.Tensor, scale: torch.Tensor, - quant_config: RTNConfig) -> None: + def __init__( + self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig + ) -> None: self.data = data self.scale = scale self.quant_config = quant_config @@ -96,7 +112,9 @@ class RTNTensor: factor = 1 if self.quant_config.weight_bits == 8 else 2 return RTNTensor( self.data.narrow(dim, start // factor, length // factor), - self.scale.narrow(dim, start, length), self.quant_config) + self.scale.narrow(dim, start, length), + self.quant_config, + ) def __getitem__(self, key): return RTNTensor(self.data[key], self.scale[key], self.quant_config) @@ -112,9 +130,11 @@ class RTNTensor: return torch.Size((shape[0] * factor, shape[1])) def copy_(self, loaded_weight: torch.Tensor) -> None: - qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), - self.quant_config.weight_bits, - self.quant_config.group_size) + qweight, weight_scale = rtn_quantize( + loaded_weight.cuda(), + self.quant_config.weight_bits, + self.quant_config.group_size, + ) self.data.copy_(qweight) self.scale.data.copy_(weight_scale) @@ -130,8 +150,9 @@ class RTNParameter(Parameter): def __new__(cls, data: torch.Tensor, **kwargs): return super().__new__(cls, data=data, requires_grad=False) - def __init__(self, data: torch.Tensor, scale: torch.Tensor, - quant_config: RTNConfig) -> None: + def __init__( + self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig + ) -> None: self.scale = scale self.quant_config = quant_config @@ -161,31 +182,39 @@ class RTNLinearMethod(LinearMethodBase): **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) - num_groups_per_col = (input_size_per_partition // - self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) + num_groups_per_col = ( + input_size_per_partition // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) scale = Parameter( - torch.empty(output_size_per_partition, - num_groups_per_col, - dtype=params_dtype), + torch.empty( + output_size_per_partition, num_groups_per_col, dtype=params_dtype + ), requires_grad=False, ) factor = 1 if self.quant_config.weight_bits == 8 else 2 - weight = RTNParameter(data=torch.empty(output_size_per_partition // - factor, - input_size_per_partition, - dtype=torch.uint8), - scale=scale, - quant_config=self.quant_config) + weight = RTNParameter( + data=torch.empty( + output_size_per_partition // factor, + input_size_per_partition, + dtype=torch.uint8, + ), + scale=scale, + quant_config=self.quant_config, + ) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) layer.register_parameter("scale", scale) layer.output_size_per_partition = output_size_per_partition @@ -193,10 +222,12 @@ class RTNLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: fix_weights(layer, "weight") - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = layer.weight scale = layer.scale @@ -210,57 +241,75 @@ class RTNLinearMethod(LinearMethodBase): class RTNMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: RTNConfig, moe: 
FusedMoEConfig): super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): factor = 1 if self.quant_config.weight_bits == 8 else 2 # Fused gate_up_proj (column parallel) - num_groups_per_col = (hidden_size // self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) + num_groups_per_col = ( + hidden_size // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) w13_scale = Parameter( - torch.empty(num_experts, - 2 * intermediate_size_per_partition, - num_groups_per_col, - dtype=params_dtype), + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + num_groups_per_col, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_scale", w13_scale) - w13_weight = RTNParameter(data=torch.empty( - num_experts, - 2 * intermediate_size_per_partition // factor, - hidden_size, - dtype=torch.uint8), - scale=w13_scale, - quant_config=self.quant_config) + w13_weight = RTNParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition // factor, + hidden_size, + dtype=torch.uint8, + ), + scale=w13_scale, + quant_config=self.quant_config, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - num_groups_per_col = (intermediate_size_per_partition // - self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) - w2_scale = Parameter(torch.zeros(num_experts, - hidden_size, - num_groups_per_col, - dtype=params_dtype), - requires_grad=False) + num_groups_per_col = ( + intermediate_size_per_partition // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) + w2_scale = Parameter( + torch.zeros( + num_experts, hidden_size, num_groups_per_col, dtype=params_dtype + ), + requires_grad=False, + ) layer.register_parameter("w2_scale", w2_scale) - w2_weight = RTNParameter(data=torch.empty( - num_experts, - hidden_size // factor, - intermediate_size_per_partition, - dtype=torch.uint8), - scale=w2_scale, - quant_config=self.quant_config) + w2_weight = RTNParameter( + data=torch.empty( + num_experts, + hidden_size // factor, + intermediate_size_per_partition, + dtype=torch.uint8, + ), + scale=w2_scale, + quant_config=self.quant_config, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -269,6 +318,25 @@ class RTNMoEMethod(FusedMoEMethodBase): fix_weights(layer, "w13_weight", weight_bits == 4) fix_weights(layer, "w2_weight", weight_bits == 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + assert weight_bits == 4 or weight_bits == 8 + config_builder = ( + int4_w4a16_moe_quant_config + if weight_bits == 4 + else int8_w8a16_moe_quant_config + ) + return config_builder( + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -283,6 +351,7 @@ class RTNMoEMethod(FusedMoEMethodBase): expert_map: Optional[torch.Tensor] = None, 
custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -290,16 +359,15 @@ class RTNMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `RTNMoEMethod` yet.") + raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -309,13 +377,12 @@ class RTNMoEMethod(FusedMoEMethodBase): num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) - weight_bits = self.quant_config.weight_bits - group_size = self.quant_config.group_size - - ret = fused_experts( + return fused_experts( x, layer.w13_weight, layer.w2_weight, @@ -323,27 +390,23 @@ class RTNMoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - block_shape=[0, group_size]) - - return ret + quant_config=self.moe_quant_config, + ) -def rtn_quantize(tensor: torch.Tensor, num_bits: int, - group_size: int) -> tuple[torch.Tensor, torch.Tensor]: +def rtn_quantize( + tensor: torch.Tensor, num_bits: int, group_size: int +) -> tuple[torch.Tensor, torch.Tensor]: """Quantize a tensor using per-group static scaling factor. Args: tensor: The input tensor. num_bits: Target precision for the result (supported values are 8 or 4). - group_size: Quantization granularity. + group_size: Quantization granularity. If equal to -1, each row in the input tensor is treated as one group. """ @@ -352,15 +415,18 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, tensor = tensor.unsqueeze(0) q_range = 2**num_bits - num_groups = (tensor.shape[1] * tensor.shape[2] // - group_size if group_size != -1 else tensor.shape[1]) + num_groups = ( + tensor.shape[1] * tensor.shape[2] // group_size + if group_size != -1 + else tensor.shape[1] + ) """Calculate a scaling factor per input group. """ input_flat = tensor.reshape(tensor.shape[0], num_groups, -1) input_min = torch.min(input_flat, dim=2, keepdim=True)[0] input_max = torch.max(input_flat, dim=2, keepdim=True)[0] input_max_abs = torch.max(input_min.abs(), input_max.abs()) - scale = (input_max_abs * 2.0 / (q_range - 1)) + scale = input_max_abs * 2.0 / (q_range - 1) """Scale each input group, round to the nearest integer, shift the range and truncate. """ @@ -376,9 +442,10 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, if num_bits == 4: """Pack two 4-bit values into each byte. 
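As the docstring above notes, the 4-bit path stores two codes per byte: `rtn_quantize` packs the odd-indexed code into the high nibble and the even-indexed code into the low nibble, and `rtn_dequantize` reverses it. A standalone round-trip of that nibble packing (1-D and unsigned for brevity; the real code works on shifted, batched 3-D tensors):

```python
import torch

# Eight unsigned 4-bit codes (rtn_quantize shifts by q_range // 2 before packing,
# so stored codes are in [0, 16)).
codes = torch.tensor([3, 15, 0, 7, 9, 1, 12, 4], dtype=torch.uint8)

# Pack: odd-indexed code -> high nibble, even-indexed code -> low nibble.
packed = (codes[1::2] << 4) | (codes[::2] & 0xF)  # 4 bytes hold 8 codes

# Unpack and re-interleave.
unpacked = torch.empty_like(codes)
unpacked[::2] = packed & 0xF
unpacked[1::2] = packed >> 4

assert torch.equal(unpacked, codes)
```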
""" - inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf) - inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2, - tensor.shape[2]) + inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF) + inputs_q = inputs_q.reshape( + tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2] + ) inputs_q = inputs_q.contiguous() if not batch_present: @@ -408,9 +475,9 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: if num_bits == 4: input_dim *= 2 - data = torch.empty((batch, input_dim, output_dim), - dtype=scale.dtype, - device=tensor.device) + data = torch.empty( + (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device + ) if num_bits == 8: data.copy_(tensor) @@ -420,8 +487,9 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """ tensor = tensor.reshape(batch, input_dim, output_dim // 2) for i in range(2): - data[:, :, i::2] = ((tensor << 4 * - (1 - i)) >> 4).to(torch.int8) - q_range // 2 + data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to( + torch.int8 + ) - q_range // 2 """Scale each input group with its scaling factor. """ scale = scale.reshape(batch, num_groups, -1) @@ -435,9 +503,7 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return input_deq -def fix_weights(layer: torch.nn.Module, - param_name: str, - reshape: bool = False): +def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False): """torch.compile does not know how to deal with a Parameter subclass (aka RTNParameter). As we don't really need RTNParameters for the forward pass, we replace them with equivalent instances of Parameters. diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py index a108152929d9a..9396da0ecd1a0 100644 --- a/vllm/model_executor/layers/quantization/schema.py +++ b/vllm/model_executor/layers/quantization/schema.py @@ -30,7 +30,8 @@ class KVCacheQuantSchema(BaseModel): def check_is_fp8(self) -> "KVCacheQuantSchema": assert self.dtype == "float8_e4m3fn", ( "Loaded scaling factors intended for KV cache dtype = " - f"{self.dtype} rather than float8_e4m3fn!") + f"{self.dtype} rather than float8_e4m3fn!" + ) return self @model_validator(mode="after") @@ -41,15 +42,18 @@ class KVCacheQuantSchema(BaseModel): num_hidden_layers = context["num_hidden_layers"] assert len(self.scaling_factor) == tp_size, ( f"Loaded dictionary has TP size {len(self.scaling_factor)} " - f"but LLM engine is currently running with TP size {tp_size}.") + f"but LLM engine is currently running with TP size {tp_size}." + ) for tp_rank, layer_maps in self.scaling_factor.items(): assert len(layer_maps) == num_hidden_layers, ( f"KV cache scales map for TP rank {tp_rank} is malformed. " f"Expected {num_hidden_layers} layers, got " - f"{len(layer_maps)}.") + f"{len(layer_maps)}." + ) for i in range(tp_size): assert i in self.scaling_factor, ( - f"KV cache scales map for TP rank {i} not found.") + f"KV cache scales map for TP rank {i} not found." + ) return self @model_validator(mode="after") @@ -62,7 +66,8 @@ class KVCacheQuantSchema(BaseModel): for i in range(num_hidden_layers): assert i in layer_scales_map, ( f"Could not find KV cache scales for layer {i} in " - f"TP rank {tp_rank}.") + f"TP rank {tp_rank}." 
+ ) return self @@ -82,5 +87,6 @@ class QuantParamSchema(BaseModel): assert model_type == self.model_type, ( f"Model type is {model_type} but loaded " f"scaling factors belonging to different " - f"model type {self.model_type}!") + f"model type {self.model_type}!" + ) return self diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 63b2ab6bab063..6f076401ac32e 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -1,22 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import json +from importlib.util import find_spec from typing import Any, Optional +import regex as re import torch import torch.nn.functional as F +from packaging import version from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) +def torchao_version_at_least(torchao_version: str) -> bool: + if find_spec("torchao"): + try: + if version.parse(importlib.metadata.version("torchao")) >= version.parse( + torchao_version + ): + return True + except (ImportError, version.InvalidVersion): + return False + return False + + def should_skip(prefix: str, skip_modules: list[str]) -> bool: """ Robust skipping logic: @@ -38,9 +60,12 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool: class TorchAOConfig(QuantizationConfig): """Config class for torchao.""" - def __init__(self, - torchao_config, - skip_modules: Optional[list[str]] = None) -> None: + def __init__( + self, + torchao_config, + skip_modules: Optional[list[str]] = None, + is_checkpoint_torchao_serialized: bool = False, + ) -> None: """ # TorchAO quantization relies on tensor subclasses. In order, # to enable proper caching this needs standalone compile @@ -58,9 +83,13 @@ class TorchAOConfig(QuantizationConfig): super().__init__() self.torchao_config = torchao_config self.skip_modules = skip_modules or [] + self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized def __repr__(self) -> str: - return f"TorchAOConfig({self.torchao_config})" + return ( + f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " + f"{self.is_checkpoint_torchao_serialized=})" + ) def get_name(self) -> QuantizationMethods: return "torchao" @@ -74,7 +103,10 @@ class TorchAOConfig(QuantizationConfig): @staticmethod def get_config_filenames() -> list[str]: - return ["config.json"] + """torchao doesn't require additional config files, we use + `config.json` from huggingface: `model_config.hf_config` + """ + return [] @classmethod def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig": @@ -87,10 +119,16 @@ class TorchAOConfig(QuantizationConfig): "`pip install torchao>=0.10.0` to use torchao quantization." 
) from err + quant_method = cls.get_from_keys_or(config, ["quant_method"], None) + is_checkpoint_torchao_serialized = ( + quant_method is not None and "torchao" in quant_method + ) + hf_config = cls.get_from_keys_or(config, ["quant_type"], None) assert hf_config is not None, "quant_type must be specified" assert len(hf_config) == 1 and "default" in hf_config, ( - "Expected only one key 'default' in quant_type dictionary") + "Expected only one key 'default' in quant_type dictionary" + ) quant_type = hf_config["default"] ao_config = config_from_dict(quant_type) @@ -110,10 +148,40 @@ class TorchAOConfig(QuantizationConfig): if layer_cfg is None: skip_modules.append(layer) - return cls(ao_config, skip_modules) + return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + @classmethod + def from_config_file(cls, config_file: str) -> "TorchAOConfig": + """Initialize class from a config file. Example: + ``` + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + fn = "torchao_config.json" + + with open(fn, "w") as f: + f.write(json.dumps(config_to_dict(config))) + ``` + """ + with open(config_file) as f: + f.seek(0) + f_read = f.read() + config_dict = json.loads(f_read) + + hf_config = {"quant_type": {"default": config_dict}} + return cls.from_config(hf_config) + + @classmethod + def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig": + """Iniitalize class from a config_dict json string, got from + torchao_config_object = some AOBaseConfig object + json.dumps(config_to_dict(torchao_config_object)) + """ + config_dict = json.loads(config_dict_json) + hf_config = {"quant_type": {"default": config_dict}} + return cls.from_config(hf_config) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if not isinstance(layer, LinearBase): return None @@ -125,10 +193,30 @@ class TorchAOConfig(QuantizationConfig): module_fqn = prefix if isinstance(self.torchao_config, ModuleFqnToConfig): module_fqn_to_config = self.torchao_config.module_fqn_to_config - c = module_fqn_to_config.get( - module_fqn) or module_fqn_to_config.get("_default", None) + c = None + if module_fqn in module_fqn_to_config: + assert not module_fqn.startswith("re:"), ( + "module fqn should not start with" + "`re:`, which is used for specifying regex" + ) + c = module_fqn_to_config[module_fqn] + else: + for maybe_module_fqn_pattern in module_fqn_to_config: + if not maybe_module_fqn_pattern.startswith("re:"): + continue + elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): + # we'll apply the config for first fully matched pattern + c = module_fqn_to_config[maybe_module_fqn_pattern] + break + else: + # fallback to use default if no module specific + # config is provided + c = module_fqn_to_config.get("_default", None) + if c is not None: - current_torchao_config = TorchAOConfig(c, self.skip_modules) + current_torchao_config = TorchAOConfig( + c, self.skip_modules, self.is_checkpoint_torchao_serialized + ) return TorchAOLinearMethod(current_torchao_config) else: return UnquantizedLinearMethod() @@ -139,39 +227,43 @@ class TorchAOConfig(QuantizationConfig): return [] -def torchao_quantize_param_data(param: torch.Tensor, - torchao_config: Any) -> torch.nn.Parameter: +def torchao_quantize_param_data( + param: torch.Tensor, torchao_config: Any +) -> torch.nn.Parameter: """Quantize a Tensor with torchao quantization specified by 
torchao_config Args: - `param`: weight parameter of the linear module - `torchao_config`: type of quantization and their arguments we want to - use to quantize the Tensor + param: weight parameter of the linear module + torchao_config: type of quantization and their arguments we want to + use to quantize the Tensor """ from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" - """ - Avoid real weight allocation for faster load, since we will + """ + Avoid real weight allocation for faster load, since we will end up setting it to param. """ with torch.device("meta"): - dummy_linear = torch.nn.Linear(param.shape[1], - param.shape[0], - bias=False) + # linear can't be top level module since quantize_ is inplace + # while some of our configs need to do module swap, and only non-top + # level modules support module swap + dummy_linear = torch.nn.Sequential( + torch.nn.Linear(param.shape[1], param.shape[0], bias=False) + ) - dummy_linear.weight = param + dummy_linear[0].weight = param quantize_(dummy_linear, torchao_config) - return dummy_linear.weight + return dummy_linear[0].weight class TorchAOLinearMethod(LinearMethodBase): """Linear method for torchao. Args: - torchao_config: The torchao quantization config, a string - that encodes the type of quantization and all relevant arguments. + quant_config: The torchao quantization config, a string that encodes + the type of quantization and all relevant arguments. """ def __init__(self, quant_config: TorchAOConfig): @@ -195,8 +287,10 @@ class TorchAOLinearMethod(LinearMethodBase): ), requires_grad=False, ) - weight = torchao_quantize_param_data(weight, - self.quant_config.torchao_config) + if self.quant_config.is_checkpoint_torchao_serialized: + weight = torchao_quantize_param_data( + weight, self.quant_config.torchao_config + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -210,3 +304,15 @@ class TorchAOLinearMethod(LinearMethodBase): bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return F.linear(x, layer.weight, bias) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.quant_config.is_checkpoint_torchao_serialized: + return + + # quantize the weight on the fly if the checkpoint is not already + # quantized by torchao + weight = torchao_quantize_param_data( + layer.weight, self.quant_config.torchao_config + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index 38de4b54fb191..a24cd41659a0e 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -8,9 +8,10 @@ from torch.nn import Module from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.parameter import ModelWeightParameter ACTIVATION_SCHEMES = ["none", "dynamic"] @@ -25,8 +26,7 @@ class Int8TpuConfig(QuantizationConfig): ) -> None: super().__init__() if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError( - f"Unsupported activation 
scheme {activation_scheme}") + raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme def get_name(self) -> QuantizationMethods: @@ -37,8 +37,7 @@ class Int8TpuConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - raise NotImplementedError( - "This function should not be called with TPU Backend") + raise NotImplementedError("This function should not be called with TPU Backend") @staticmethod def get_config_filenames() -> list[str]: @@ -49,50 +48,61 @@ class Int8TpuConfig(QuantizationConfig): activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) return cls(activation_scheme=activation_scheme) - def get_quant_method(self, layer: Module, - prefix: str) -> Optional["TPUInt8LinearMethod"]: + def get_quant_method( + self, layer: Module, prefix: str + ) -> Optional["TPUInt8LinearMethod"]: if isinstance(layer, LinearBase): return TPUInt8LinearMethod(self) return None class TPUInt8LinearMethod(LinearMethodBase): - """Int8 Linear method for TPU Quant. """ + """Int8 Linear method for TPU Quant.""" def __init__(self, quant_config: Int8TpuConfig): self.quant_config = quant_config self.quantize_activation = False - if self.quant_config.activation_scheme == 'dynamic': + if self.quant_config.activation_scheme == "dynamic": self.quantize_activation = True - def create_weights(self, layer: Module, input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - + def create_weights( + self, + layer: Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): weight_loader = extra_weight_attrs.get("weight_loader") - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) def _quantize_weight( - self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: weight_dtype = weight.dtype weight = weight.cpu().to(torch.float32) n_bit = 8 eps = 1e-5 - max_int = 2**(n_bit - 1) - 1 - min_int = -(2**(n_bit - 1)) + max_int = 2 ** (n_bit - 1) - 1 + min_int = -(2 ** (n_bit - 1)) max_val = weight.abs().amax(dim=-1, keepdim=True) max_val = max_val.clamp(min=eps) qscale = max_val / max_int - qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int, - max_int).to(torch.int8) + qweight = torch.clamp( + torch.round(weight * (1.0 / qscale)), min_int, max_int + ).to(torch.int8) qscale = qscale.squeeze().to(weight_dtype) return qweight, qscale @@ -105,21 +115,25 @@ class TPUInt8LinearMethod(LinearMethodBase): layer.weight = Parameter(qweight, requires_grad=False) layer.scale = Parameter(qscale, requires_grad=False) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: try: import torch_xla.experimental.custom_kernel # noqa: F401 except ImportError as err: raise ImportError( 
"Please install torch_xla by following the instructions at " "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 - "to run vLLM on TPU.") from err + "to run vLLM on TPU." + ) from err weight = layer.weight scale = layer.scale out = torch.ops.xla.quantized_matmul_int8( - x, weight, scale, quantize_activation=self.quantize_activation) + x, weight, scale, quantize_activation=self.quantize_activation + ) if bias is not None: out = out + bias return out diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index 6ad56bae3dca0..07c18029fb4de 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .layer_utils import replace_parameter, update_tensor_inplace - -__all__ = ['update_tensor_inplace', 'replace_parameter'] +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ["update_tensor_inplace", "replace_parameter"] diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py index 1992b4d201478..4c324682e5e62 100644 --- a/vllm/model_executor/layers/quantization/utils/allspark_utils.py +++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py @@ -12,41 +12,56 @@ ALLSPARK_AMPERE_N_ALIGN = 16 ALLSPARK_AMPERE_K_ALIGN = 16 -def check_allspark_supported_dtype_shape(input_size_per_partition: int, - output_size_per_partition: int, - group_size: int, - weight_dtype: ScalarType, - act_dtype: torch.dtype): +def check_allspark_supported_dtype_shape( + input_size_per_partition: int, + output_size_per_partition: int, + group_size: int, + weight_dtype: ScalarType, + act_dtype: torch.dtype, +): capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = -1 if capability_tuple is None else capability_tuple.to_int() # For Ampere GPU if device_capability >= 80 and device_capability < 90: if group_size != -1: - return False, \ - "For Ampere GPU, AllSpark does not support group_size "\ - f"= {group_size}. Only group_size = -1 are supported." + return ( + False, + "For Ampere GPU, AllSpark does not support group_size " + f"= {group_size}. Only group_size = -1 are supported.", + ) if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES: - return False, "For Ampere GPU, AllSpark does not support "\ - f"quant type ({weight_dtype}). Only quant type "\ - f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported." + return ( + False, + "For Ampere GPU, AllSpark does not support " + f"quant type ({weight_dtype}). Only quant type " + f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.", + ) - if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \ - or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0: - return False, \ - "AllSpark needs input_size_per_partition % "\ - f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\ - f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\ - "for Ampere GPU optimized kernels." 
+ if ( + input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 + or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0 + ): + return ( + False, + "AllSpark needs input_size_per_partition % " + f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and " + f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 " + "for Ampere GPU optimized kernels.", + ) if act_dtype != torch.float16 and act_dtype != torch.bfloat16: - return False, \ - "AllSpark only supports act_dtype = float16 or bfloat16,"\ - f"for Ampere GPU, but got act_dtype = {act_dtype}." + return ( + False, + "AllSpark only supports act_dtype = float16 or bfloat16," + f"for Ampere GPU, but got act_dtype = {act_dtype}.", + ) else: - return False, "AllSpark currently does not support "\ - f"device_capability = {device_capability}." + return ( + False, + "AllSpark currently does not support " + f"device_capability = {device_capability}.", + ) return True, None diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py index 4c2e548735869..4b7a22a266533 100644 --- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -28,13 +28,14 @@ BITBLAS_SUPPORTED_SYM = [False, True] # Determines the supported quantization types for BitBLAS based on the # device's capability and whether zero-point (zp) is used. -def query_bitblas_supported_quant_types(has_zp: bool, - device_capability: Optional[int] = None - ): +def query_bitblas_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) if device_capability < 70: return [] @@ -50,97 +51,116 @@ def query_bitblas_supported_quant_types(has_zp: bool, def _check_bitblas_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: - + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) - supported_types = query_bitblas_supported_quant_types( - has_zp, device_capability) + supported_types = query_bitblas_supported_quant_types(has_zp, device_capability) if quant_type not in supported_types: - return (False, f"BitBLAS does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).") - if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): - return (False, f"BitBLAS does not support group_size = {group_size}. " - f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " - "are supported.") + return ( + False, + f"BitBLAS does not support weight_bits = {quant_type}. 
" + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return ( + False, + f"BitBLAS does not support group_size = {group_size}. " + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) # Finally, check if bitblas is installed try: import bitblas - if version.parse( - bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): - raise ImportError("bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + + if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError: return False, "BitBLAS is not installed." return True, None -def check_bitblas_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None) -> bool: - cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, - device_capability) +def check_bitblas_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_bitblas_supported( + quant_type, group_size, has_zp, device_capability + ) return cond -def verify_bitblas_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False) -> None: +def verify_bitblas_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_bitblas_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) -> None: - +def verify_bitblas_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: # Validate output_size_per_partition if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: - raise ValueError(f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) # Validate input_size_per_partition if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) - if (group_size < input_size - and input_size_per_partition % group_size != 0): + if group_size < input_size and input_size_per_partition % group_size != 0: raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" f" is not divisible by group_size = {group_size}." "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + "with --quantization gptq." + ) -def check_bitblas_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) \ - -> tuple[bool, Optional[str]]: +def check_bitblas_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: try: - verify_bitblas_supports_shape(output_size_per_partition, - input_size_per_partition, input_size, - group_size) + verify_bitblas_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) except ValueError as e: return False, e.__str__() return True, None @@ -150,8 +170,9 @@ def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, - is_row_parallel: bool) -> bool: +def bitblas_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -159,17 +180,18 @@ def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) -def bitblas_sort_g_idx( - g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def bitblas_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices @@ -186,8 +208,7 @@ def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: for col in range(unpacked_zeros.shape[1]): i = col % elems_per_int32 - unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> - (bits * i)) & 0xF + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF if not is_gptq_v2: return unpacked_zeros + 1 return unpacked_zeros @@ -204,7 +225,6 @@ def unpack_gptq_qweight(qweight, bits): ) for col in range(unpacked_weight.shape[1]): i = col % elems_per_int8 - unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> - (bits * i)) + unpacked_weight[:, col] = qweight[:, col // elems_per_int8] >> (bits * i) return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 
0000000000000..0ea0225c96af1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..be487f2805b85 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, 
+ "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..f81e09e198c86 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e073843af64c5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + 
"96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..f74a52fc17c9d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..8cab1b093276a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 1c61451fb34e5..ae244f90bb064 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,73 +1,73 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "24": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -115,15 +115,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 
128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,13 +133,13 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63e661c80de6a..b2931d68f488a 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,83 +1,83 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -99,9 +99,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, 
"BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } -} +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 56b939e52fac3..ad630f0d787cf 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,30 +1,30 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -32,19 +32,19 @@ "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, - "num_warps": 4, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,9 +59,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -83,7 +83,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 63d9a0bf5d79d..10b940c04fad3 
100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,50 +1,50 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, "32": { @@ -59,15 +59,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -75,7 +75,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 7fa398c15a2a5..94ce6e77f09ce 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,55 +1,55 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "2": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "16": { - "BLOCK_SIZE_M": 64, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, - "24": { - "BLOCK_SIZE_M": 64, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -59,31 +59,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,7 +99,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index f15d8f64c7090..9540df407975e 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, 
"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -59,33 +59,33 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,23 +93,23 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..96f6c307b357d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..567675787d4f9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 51e237b91b8e7..0894ff2fa3322 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,6 +1,6 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -8,55 +8,55 @@ "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, - "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "48": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -64,83 +64,83 @@ "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { - "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, - "3072": { + "1536": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 6280219c9ee7d..86c68e08a1a6a 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,78 +1,78 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - 
"num_stages": 4 - }, - "48": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "96": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, @@ -80,38 +80,14 @@ "num_stages": 5 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { + "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, @@ -119,19 +95,43 @@ "num_warps": 4, "num_stages": 5 }, - "2048": { + "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 0a1e14cffbb2a..af1a384cbcbd3 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,14 +1,14 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, @@ -16,26 +16,26 @@ "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - 
"GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, @@ -43,9 +43,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -67,31 +67,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -101,25 +101,9 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -127,13 +111,29 @@ "num_warps": 4, "num_stages": 3 }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 15b1c93f60fc5..d381764a26414 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,22 +1,22 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - 
"BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, @@ -24,18 +24,18 @@ "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, @@ -45,47 +45,47 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -93,29 +93,29 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index 8ff12e64c172f..821ad0c704573 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,43 +1,43 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - 
"GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5 }, "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, @@ -45,7 +45,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 64, @@ -59,7 +59,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, @@ -73,19 +73,19 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, @@ -99,21 +99,21 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, @@ -123,9 +123,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 4532f93681e2b..daaf21c286553 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,67 +1,67 @@ { "1": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5 }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 
1, - "num_warps": 4, - "num_stages": 5 - }, "8": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "64": { "BLOCK_SIZE_M": 64, @@ -73,25 +73,25 @@ }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,31 +99,31 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -133,7 +133,7 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, @@ -141,6 +141,6 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json index ca7f32b9552b4..2583b5a3441ca 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,57 +1,57 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 - }, - "2": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, 
"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 + "num_warps": 8, + "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "16": { - "BLOCK_SIZE_M": 64, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -59,43 +59,35 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "512": { + "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -103,19 +95,27 @@ "num_warps": 4, "num_stages": 3 }, - "1024": { + "512": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, "1536": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 }, @@ -139,8 +139,8 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json index 5acea242cc0ad..baa64f8d3d141 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +++ 
b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,65 +1,65 @@ { "1": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, - "8": { - "BLOCK_SIZE_M": 64, + "16": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -69,21 +69,21 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -91,7 +91,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -99,13 +99,13 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, @@ -123,7 +123,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, @@ -131,15 +131,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 0000000000000..1110ced4fa063 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel 
Config + +Use scripts under `benchmarks/kernels/` to generate these config files. diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index f5d7c57fe2a87..7059a029ba67e 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -1,17 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" + from __future__ import annotations import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts) + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -24,15 +30,17 @@ __all__ = [ def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and current_platform.is_cuda() - and current_platform.is_device_capability(100)) + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() + and current_platform.is_device_capability(100) + ) -def reorder_w1w3_to_w3w1(weight: torch.Tensor, - scale: torch.Tensor, - dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]: +def reorder_w1w3_to_w3w1( + weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 +) -> tuple[torch.Tensor, torch.Tensor]: """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`""" size = weight.size(dim) assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" @@ -41,38 +49,34 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, w1, w3 = weight.split(half, dim=dim) s1, s3 = scale.split(half, dim=dim) - return (torch.cat([w3, w1], - dim=dim).contiguous(), torch.cat([s3, s1], - dim=dim).contiguous()) + return ( + torch.cat([w3, w1], dim=dim).contiguous(), + torch.cat([s3, s1], dim=dim).contiguous(), + ) def build_flashinfer_fp4_cutlass_moe_prepare_finalize( moe: FusedMoEConfig, - a1_gscale: torch.Tensor, ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) + enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv" + return create_flashinfer_prepare_finalize( + use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv + ) def select_nvfp4_gemm_impl( moe: FusedMoEConfig, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + moe_quant_config: FusedMoEQuantConfig, allow_flashinfer: bool, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" if allow_flashinfer: return FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - 
a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=moe.in_dtype, - quant_dtype="nvfp4", + quant_config=moe_quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, @@ -82,4 +86,5 @@ def select_nvfp4_gemm_impl( # native cutlass experts currently don't support DP; TP case won't call this raise ValueError( "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS " - "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)") + "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)" + ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9889808f0760f..7f32ef00647ca 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -8,11 +8,16 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts) + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize, +) logger = init_logger(__name__) @@ -23,7 +28,6 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): - # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. 
@@ -43,13 +47,16 @@ def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: - return x.reshape(-1, 2, x.shape[-2] // 2, - x.shape[-1]).flip(dims=[1]).reshape(x.shape) + return ( + x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape) + ) -def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor): +def rotate_flashinfer_fp8_moe_weights( + gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor +): from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + epilogue_tile_m = 128 num_experts = gemm1_weights.shape[0] hidden_size = gemm1_weights.shape[-1] @@ -59,13 +66,13 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, gemm1_weights_fp8_interleaved = [] for i in range(num_experts): gemm1_weights_fp8_interleaved.append( - reorder_rows_for_gated_act_gemm(gemm1_weights[i])) + reorder_rows_for_gated_act_gemm(gemm1_weights[i]) + ) # Stack weights and scales for all experts - gemm1_weights_fp8_interleaved = torch.stack( - gemm1_weights_fp8_interleaved).reshape(num_experts, - 2 * intermediate_size, - hidden_size) + gemm1_weights_fp8_interleaved = torch.stack(gemm1_weights_fp8_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size + ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -73,18 +80,21 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, for i in range(num_experts): gemm1_weights_fp8_shuffled.append( shuffle_matrix_a( - gemm1_weights_fp8_interleaved[i].view(torch.uint8), - epilogue_tile_m)) + gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) gemm2_weights_fp8_shuffled.append( - shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m) + ) # Stack weights for all experts gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view( - torch.float8_e4m3fn) + torch.float8_e4m3fn + ) gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view( - torch.float8_e4m3fn) + torch.float8_e4m3fn + ) def apply_flashinfer_per_tensor_scale_fp8( @@ -99,16 +109,24 @@ def apply_flashinfer_per_tensor_scale_fp8( apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_scalar to be initialized") + "Expected output1_scales_scalar to be initialized" + ) assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_gate_scalar to be initialized") + "Expected output1_scales_gate_scalar to be initialized" + ) assert layer.output1_scales_scalar is not None, ( - "Expected output2_scales_scalar to be initialized") + "Expected output2_scales_scalar to be initialized" + ) from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" + ) return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( routing_logits=router_logits, routing_bias=routing_bias, @@ -137,79 +155,65 @@ def get_moe_scaling_factors( activation_scale: torch.Tensor, gemm2_weights_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: - output1_scales_scalar = gemm1_weights_scale * input_scale * ( - 1.0 / activation_scale) + output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale) output1_scales_gate_scalar = gemm1_weights_scale * input_scale output2_scales_scalar = activation_scale * gemm2_weights_scale - return output1_scales_scalar, output1_scales_gate_scalar, \ - output2_scales_scalar + return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar def register_moe_scaling_factors(layer: torch.nn.Module) -> None: - output1_scales, output1_gate_scales, output2_scales = \ - get_moe_scaling_factors( - layer.w13_input_scale, layer.w13_weight_scale, - layer.w2_input_scale, layer.w2_weight_scale - ) + output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( + layer.w13_input_scale, + layer.w13_weight_scale, + layer.w2_input_scale, + layer.w2_weight_scale, + ) layer.register_parameter( - 'output1_scales_scalar', - torch.nn.Parameter(output1_scales, requires_grad=False)) + "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False) + ) layer.register_parameter( - 'output1_scales_gate_scalar', - torch.nn.Parameter(output1_gate_scales, requires_grad=False)) + "output1_scales_gate_scalar", + torch.nn.Parameter(output1_gate_scales, requires_grad=False), + ) layer.register_parameter( - 'output2_scales_scalar', - torch.nn.Parameter(output2_scales, requires_grad=False)) + "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False) + ) layer.register_parameter( - 'w2_input_scale_inv', - torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + "w2_input_scale_inv", + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False), + ) def build_flashinfer_fp8_cutlass_moe_prepare_finalize( moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - return FlashInferCutlassMoEPrepareAndFinalize( - use_dp, a1_gscale=layer.w13_input_scale) + return create_flashinfer_prepare_finalize(use_dp) def select_cutlass_fp8_gemm_impl( moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, + quant_config: FusedMoEQuantConfig, out_dtype: Optional[torch.dtype] = None, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for fused-MoE layers""" - from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ - "FusedMoE flashinfer kernels are only supported for Llama4" - if moe is not None: return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=moe.in_dtype, - quant_dtype=torch.float8_e4m3fn, + quant_config=quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, tp_size=moe.moe_parallel_config.tp_size, ) - assert out_dtype is not None, ( - "If moe config is None, out_dtype must be passed") + assert out_dtype is not None, "If moe config is None, out_dtype must be passed" return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=out_dtype, - quant_dtype=torch.float8_e4m3fn, + 
quant_config=quant_config, ) @@ -224,12 +228,15 @@ def flashinfer_cutlass_moe_fp8( expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: + quant_config = layer.quant_method.get_fused_moe_quant_config(layer) + assert quant_config is not None + fused_experts = mk.FusedMoEModularKernel( - build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, - layer=layer), - select_cutlass_fp8_gemm_impl(moe=None, - layer=layer, - out_dtype=hidden_states.dtype)) + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None), + select_cutlass_fp8_gemm_impl( + moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype + ), + ) return fused_experts( hidden_states, @@ -255,4 +262,5 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: allowed_backends = ["throughput", "latency"] raise ValueError( f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}") + f" expected one of {allowed_backends}" + ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ab1d5383f4651..fa34dba371e81 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,15 +13,28 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + GroupShape, + group_broadcast, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_BLOCK_FP8_SUPPORTED) + CUTLASS_BLOCK_FP8_SUPPORTED, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, direct_register_custom_op -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - should_use_deepgemm_for_fp8_linear) +from vllm.utils import direct_register_custom_op +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, + should_use_deepgemm_for_fp8_linear, +) logger = init_logger(__name__) @@ -32,6 +45,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +# We need to pass in the is_hopper flag as argument because the function +# current_platform.is_device_capability() is not supported by Torch compiler. 
def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -39,12 +54,18 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, + is_hopper: Optional[bool] = None, ) -> torch.Tensor: - return ops.cutlass_scaled_mm(A, - B.T, - out_dtype=output_dtype, - scale_a=As, - scale_b=Bs.T) + if is_hopper is None: + is_hopper = current_platform.is_device_capability(90) + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=output_dtype, + scale_a=As, + # SM90 block FP8 requires row-major scale_b, which we do ahead of time + scale_b=Bs if block_size is not None and is_hopper else Bs.T, + ) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -68,7 +89,6 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) @@ -79,143 +99,326 @@ if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_gemm_w8a8_blockscale", op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - mutates_args=[], fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, ) - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()): - + if ( + envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz() + ): import aiter as rocm_aiter from aiter import get_hip_quant aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_func( - use_cutlass: bool, use_aiter_and_is_supported: bool -) -> Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - list[int], - torch.dtype, -], torch.Tensor]: - if use_cutlass: - return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale - return w8a8_block_fp8_matmul +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return w8a8_triton_block_scaled_mm( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) + + +def _w8a8_triton_block_scaled_mm_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) + + +direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + fake_impl=_w8a8_triton_block_scaled_mm_fake, +) + + +def _padded_cutlass( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + pad_multiple = 4 + dim = qx.shape[0] + padded = ( + dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) + ) + + padded_shape = [padded, *qx.shape[1:]] + padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) + padded_qx[0 : qx.shape[0], ...].copy_(qx) + + padded_x_scale_shape = [*x_scale.shape[1:], padded] + padded_x_scale = torch.ones( + padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype + 
).permute(-1, -2) + padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale) + + output = cutlass_scaled_mm( + padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True + ) + return output[0 : qx.shape[0], ...] + + +def _padded_cutlass_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) + + +direct_register_custom_op( + "padded_cutlass", + _padded_cutlass, + fake_impl=_padded_cutlass_fake, +) + + +def _fp8_gemm_nt_op( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: + fp8_gemm_nt( + (q_input, input_scale), + (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, + ) + + +def _fp8_gemm_nt_op_fake( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: + return None + + +direct_register_custom_op( + "fp8_gemm_nt_op", + _fp8_gemm_nt_op, + mutates_args=["output"], + fake_impl=_fp8_gemm_nt_op_fake, +) # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 -def apply_w8a8_block_fp8_linear( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype +class W8A8BlockFp8LinearOp: + """ + This class executes a Blocked FP8 linear layer using cutlass if supported + and torch.scaled_mm otherwise. + """ - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): + def __init__( + self, + weight_group_shape: GroupShape, + act_quant_group_shape: GroupShape, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, + ): + self.weight_group_shape = weight_group_shape + self.act_quant_group_shape = act_quant_group_shape + self.is_deep_gemm_supported = is_deep_gemm_supported() + self.is_hopper = current_platform.is_device_capability(90) + self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, + # Get the correct blockscale mul and input quant operations. + # We can't use _dispatch_w8a8_blockscale_op to figure out if we want + # to use deepgemm because we don't know the shape of weights (and + # whether deepgemm supports it) at the init time. 
+ self.w8a8_blockscale_op, self.input_quant_op = ( + self._dispatch_w8a8_blockscale_op( + cutlass_block_fp8_supported, use_aiter_and_is_supported + ) + ) + self.deepgemm_input_quant_op = ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=self.use_deep_gemm_e8m0, + ) + if self.is_deep_gemm_supported + else None ) - # ensure DeepGEMM-backed custom op is registered before use - import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype - output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + if should_use_deepgemm_for_fp8_linear( + output_dtype, weight, self.is_deep_gemm_supported + ): + output = self._run_deepgemm(input_2d, weight, weight_scale) + else: + output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + def _run_deepgemm( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.deepgemm_input_quant_op is not None + q_input, input_scale = self.deepgemm_input_quant_op(input_2d) + output = torch.empty( + (q_input.shape[0], weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + torch.ops.vllm.fp8_gemm_nt_op( + q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0 + ) + return output + + def _run_cutlass( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + if self.is_hopper: + return torch.ops.vllm.padded_cutlass( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype, + ) + else: + return cutlass_scaled_mm( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype, + False, + ) + + def _run_aiter( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.act_quant_group_shape == GroupShape(1, 128) + q_input, input_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 + ) + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( q_input, weight, - x_scale, + input_scale, weight_scale, - block_size, - output_dtype=output_dtype) - if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) + list(self.weight_group_shape), + input_2d.dtype, + ) - if current_platform.is_cuda(): - if current_platform.has_device_capability(100): + def _run_triton( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype, + ) - use_cutlass = cutlass_block_fp8_supported and ( - cdiv(weight.shape[0], 128) == weight_scale.shape[0] - and cdiv(weight.shape[1], 128) == 
weight_scale.shape[1]) - else: - # TODO: update this after switching to public sm90 block scale gemm - # as it also supports weight.shape % 128 != 0 - use_cutlass = cutlass_block_fp8_supported and ( - weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - else: - use_cutlass = False - - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported) - if use_cutlass: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - - else: + def _dispatch_w8a8_blockscale_op( + self, + use_cutlass: bool, + use_aiter_and_is_supported: bool, + ) -> tuple[ + Callable[ + [ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + torch.Tensor, + ], + Optional[QuantFP8], + ]: + if use_cutlass: + return self._run_cutlass, ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False, + ) + ) if use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) - - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) - - -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -if not current_platform.is_cpu(): - direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, - ) + return self._run_aiter, None + return self._run_triton, ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False, + ) + ) def input_to_float8( - x: torch.Tensor, - dtype: Optional[torch.dtype] = None + x: torch.Tensor, dtype: Optional[torch.dtype] = None ) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" @@ -274,8 +477,9 @@ def _per_token_group_quant_fp8( row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow - y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * - group_size) + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + ( + row_g_id.to(tl.int64) * group_size + ) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size @@ -329,8 +533,9 @@ def _per_token_group_quant_fp8_colmajor( row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow - y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * - group_size) + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + ( + row_g_id.to(tl.int64) * group_size + ) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size @@ -342,8 +547,7 @@ def _per_token_group_quant_fp8_colmajor( scale_col = g_id % blocks_per_row scale_row = g_id // blocks_per_row 
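The per-token group quantization kernels here compute one FP8 scale per `group_size` contiguous elements of the last dimension. As a mental model (and a possible test oracle), here is a plain-PyTorch sketch of the same computation; `per_token_group_quant_fp8_ref` is an illustrative name, not a vLLM API:

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # One fp32 scale per group of `group_size` elements along the last dim,
    # scale = amax(group) / fp8_max, mirroring the kernels above at a high level.
    assert x.shape[-1] % group_size == 0
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size).float()
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / fp8_max
    x_q = (g / scale).clamp(-fp8_max, fp8_max).to(fp8)
    return x_q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_fp8_ref(x, group_size=128)
print(x_q.dtype, x_q.shape, x_s.shape)  # torch.float8_e4m3fn torch.Size([4, 256]) torch.Size([4, 2])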
# Ensure offset calculation uses int64 for y_s_ptr - y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to( - tl.int64) + y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(tl.int64) y_s_ptr += y_s_ptr_offset cols = tl.arange(0, BLOCK) # group_size <= BLOCK @@ -385,11 +589,12 @@ def per_token_group_quant_fp8( scaling factor. """ if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_e8m0_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype - assert (x.shape[-1] % group_size == 0), ( + assert x.shape[-1] % group_size == 0, ( f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + f"by `group_size` {group_size}" + ) assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) @@ -403,17 +608,18 @@ def per_token_group_quant_fp8( # Allocate the scale tensor in either row- or column-major format. if column_major_scales: - shape = (x.shape[-1] // group_size, ) + x.shape[:-1] - x_s = torch.empty(shape, device=x.device, - dtype=torch.float32).permute(-1, -2) + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) else: - shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + shape = x.shape[:-1] + (x.shape[-1] // group_size,) x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available + # TODO(bnell): this causes some fp8 moe test to fail. if current_platform.is_cuda() and x.is_contiguous(): - torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, - fp8_min, fp8_max, use_ue8m0) + torch.ops._C.per_token_group_fp8_quant( + x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0 + ) return x_q, x_s # TRITON FALLBACK @@ -424,7 +630,7 @@ def per_token_group_quant_fp8( num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 if column_major_scales: - _per_token_group_quant_fp8_colmajor[(M, )]( + _per_token_group_quant_fp8_colmajor[(M,)]( x, x_q, x_s, @@ -441,7 +647,7 @@ def per_token_group_quant_fp8( num_stages=num_stages, ) else: - _per_token_group_quant_fp8[(M, )]( + _per_token_group_quant_fp8[(M,)]( x, x_q, x_s, @@ -461,7 +667,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_block_fp8_matmul( +def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output A, B, @@ -519,12 +725,8 @@ def _w8a8_block_fp8_matmul( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k @@ -550,8 +752,9 @@ def _w8a8_block_fp8_matmul( @functools.lru_cache -def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[dict[int, Any]]: +def get_w8a8_block_fp8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. 
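At a high level, the tuned `_w8a8_triton_block_scaled_mm` kernel computes a matmul in which the activations carry per-group scales (`As`) and the weights carry per-`[block_n, block_k]` block scales (`Bs`). The following naive PyTorch sketch shows that contract with an illustrative helper name and float stand-ins for the quantized inputs:

import torch

def w8a8_block_scaled_mm_ref(A_q, As, B_q, Bs, block_size, out_dtype=torch.bfloat16):
    # A_q: [M, K] quantized activations, As: [M, ceil(K/block_k)] per-group scales
    # B_q: [N, K] quantized weights,     Bs: [ceil(N/block_n), ceil(K/block_k)] block scales
    block_n, block_k = block_size
    A = A_q.float() * As.repeat_interleave(block_k, dim=-1)[..., : A_q.shape[-1]]
    Bs_full = Bs.repeat_interleave(block_n, dim=0)[: B_q.shape[0]]
    Bs_full = Bs_full.repeat_interleave(block_k, dim=1)[:, : B_q.shape[1]]
    B = B_q.float() * Bs_full
    return (A @ B.t()).to(out_dtype)

M, N, K = 4, 512, 256
out = w8a8_block_scaled_mm_ref(
    torch.randn(M, K), torch.rand(M, K // 128),      # float stand-ins for fp8 data
    torch.randn(N, K), torch.rand(N // 128, K // 128),
    block_size=[128, 128],
)
print(out.shape)  # torch.Size([4, 512])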
The return value will be a dictionary that maps an irregular grid of @@ -566,7 +769,8 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( @@ -586,7 +790,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_block_fp8_matmul( +def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -622,7 +826,7 @@ def w8a8_block_fp8_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) @@ -643,10 +847,11 @@ def w8a8_block_fp8_matmul( } def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) - _w8a8_block_fp8_matmul[grid]( + _w8a8_triton_block_scaled_mm[grid]( A, B, C, @@ -673,74 +878,10 @@ def w8a8_block_fp8_matmul( return C -# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 -# TODO(wentao): remove this function when DeepGEMM exposes this function -def get_tma_aligned_size(x: int, element_size: int) -> int: - """ - Global memory address of TMA must be 16-byte aligned. - Since we use column-major layout for the LHS scaling tensor, - the M-axis of the LHS scaling tensor needs to be padded to a multiple of - 16 bytes. - - Arguments: - x: original M-axis shape of the LHS scaling tensor. - element_size: element size of the LHS scaling tensor. - - Returns: - M-axis shape of the LHS scaling tensor after padding. - """ - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return cdiv(x, alignment) * alignment - - -# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 -# TODO(wentao): remove this function when DeepGEMM exposes this function -def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: - """ - Returns TMA-aligned transposed format of the input tensor. `torch.transpose` - will be called if necessary. - If the input tensor is already column-major layout and 16-byte aligned along - the M axis (thus meets the requirement of LHS scaling tensor in - DeepGEMM), this function will do nothing. - - Arguments: - x: usually the LHS scaling tensor in GEMM. - - Returns: - The LHS scaling tensor of TMA-aligned transposed format. 
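Both the removed TMA-alignment helpers and the `column_major_scales=True` paths above revolve around giving the activation-scale tensor a column-major layout, i.e. stride 1 along the token (M) axis. A small, standalone illustration of how allocating the transposed shape and permuting back produces that layout:

import torch

M, K, group_size = 4, 256, 128

# Row-major scales: the group dimension is contiguous.
row_major = torch.empty(M, K // group_size, dtype=torch.float32)

# Column-major scales: allocate transposed, then permute, so the M axis is contiguous.
col_major = torch.empty(K // group_size, M, dtype=torch.float32).permute(-1, -2)

print(row_major.shape, row_major.stride())  # torch.Size([4, 2]) (2, 1)
print(col_major.shape, col_major.stride())  # torch.Size([4, 2]) (1, 4)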
- """ - # NOTES: for the extreme performance, you may rewrite/fuse this function in - # CUDA - assert x.dim() in (2, 3) - remove_dim = False - m, n = x.shape[-2], x.shape[-1] - aligned_m = get_tma_aligned_size(m, x.element_size()) - if x.dim() == 2: - if x.stride(0) == 1 and x.stride(1) == aligned_m: - return x - x, remove_dim = x.unsqueeze(0), True - - b = x.shape[0] - - # The last kernel gives a column-major TMA aligned layout - if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride( - 2) == aligned_m: - return x.squeeze(0) if remove_dim else x - - # Normal layout requires transposing - aligned_x = torch.transpose( - torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) - aligned_x[:, :m, :] = x - aligned_x = aligned_x[:, :m, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x - - def requant_weight_ue8m0_inplace( - weight: torch.Tensor, - weight_scale: torch.Tensor, - block_size: Sequence[int] = (128, 128), + weight: torch.Tensor, + weight_scale: torch.Tensor, + block_size: Sequence[int] = (128, 128), ) -> None: """Re-quantise *weight* so that its per-block scaling factors are in the UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace. @@ -757,8 +898,9 @@ def requant_weight_ue8m0_inplace( return if weight.dtype != torch.float8_e4m3fn: - raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got " - f"{weight.dtype} instead.") + raise ValueError( + f"Expected *weight* to be torch.float8_e4m3fn, got {weight.dtype} instead." + ) from vllm.utils.deep_gemm import per_block_cast_to_fp8 @@ -787,9 +929,257 @@ def requant_weight_ue8m0_inplace( s_exp = s_exp[:m_cur, :k_cur] w_dq = w_q.to(torch.float32) * s_exp # Re-quantise using power-of-two scaling (UE8M0). - w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k], - use_ue8m0=True) + w_requant, s_requant = per_block_cast_to_fp8( + w_dq, [block_m, block_k], use_ue8m0=True + ) # Write back the results in-place. w_q.copy_(w_requant) s_old.copy_(s_requant) + + +def check_aiter_fp8_linear_support() -> bool: + """AITER is only supported on ROCm and only for FP8_FNUZ + and at the moment are MI300 series""" + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz() + ) + + +def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: + """Pad the weight tensor. 
This is an optimization on ROCm platform, which + can benefit from tensors located far enough from one another in memory""" + if ( + envs.VLLM_ROCM_FP8_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): + num_pad = 256 // weight.element_size() + import torch.nn.functional as F + + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + +def validate_fp8_block_shape( + layer: torch.nn.Module, + input_size: int, + output_size: int, + input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int], +) -> None: + """Validate block quantization shapes for tensor parallelism.""" + from vllm.distributed import get_tensor_model_parallel_world_size + + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) + block_n, block_k = block_size[0], block_size[1] + + # Required by row parallel + if ( + tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} " + f"is not divisible by weight quantization block_k = {block_k}." + ) + + # Required by column parallel or enabling merged weights + is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." 
+ ) + + +def create_fp8_weight_parameter( + output_size_per_partition: int, + input_size_per_partition: int, + weight_loader: Optional[Callable], +) -> torch.nn.Parameter: + """Create FP8 weight parameter.""" + from vllm.model_executor.parameter import ModelWeightParameter + + return ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + +def create_fp8_scale_parameter( + parameter_type: torch.nn.Parameter, + output_partition_sizes: list[int], + input_size_per_partition: int, + block_size: Optional[list[int]], + weight_loader: Optional[Callable], +) -> torch.nn.Parameter: + """Create scale parameter based on quantization strategy.""" + if parameter_type == ChannelQuantScaleParameter: + scale = parameter_type( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + elif parameter_type == BlockQuantScaleParameter: + assert block_size is not None + block_n, block_k = block_size[0], block_size[1] + output_size_per_partition = sum(output_partition_sizes) + scale = parameter_type( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + elif parameter_type == PerTensorScaleParameter: + scale = parameter_type( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + else: + raise ValueError(f"Unknown parameter type: {parameter_type}") + + scale[:] = torch.finfo(torch.float32).min + return scale + + +def create_fp8_input_scale( + output_partition_sizes: list[int], weight_loader: Optional[Callable] +) -> torch.nn.Parameter: + """Create input scale parameter for static activation quantization.""" + from vllm.model_executor.parameter import PerTensorScaleParameter + + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + return scale + + +def process_fp8_weight_tensor_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: list[int], + input_scale: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for tensor-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale, + ) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale + ) + + # Requantize with max scale + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=logical_widths, + ) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale, input_scale + + +def process_fp8_weight_channel_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for channel-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, + ) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = 
normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale + ) + + return weight, weight_scale, input_scale + + +def process_fp8_weight_block_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Process weights for block-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, + ) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale + + +def maybe_post_process_fp8_weight_block( + layer: torch.nn.Module, cutlass_block_fp8_supported: bool +): + assert layer.weight_block_size is not None + + from vllm.utils.deep_gemm import ( + is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear, + ) + + # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to + # requantize the weight and input to the specific scale + # at the same time. + should_use_deepgemm = should_use_deepgemm_for_fp8_linear( + layer.orig_dtype, layer.weight + ) + if is_deep_gemm_e8m0_used() and should_use_deepgemm: + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.weight.data, layer.weight_scale.data, block_sz + ) + # SM90 Block FP8 CUTLASS requires row-major weight scales + elif ( + current_platform.is_device_capability(90) + and cutlass_block_fp8_supported + and not should_use_deepgemm + ): + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), requires_grad=False + ) + + +def expert_weight_is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index db82b0def1653..6209dda955ce7 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -1,59 +1,70 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Mapping from copy import deepcopy -from typing import Optional, Union +from fractions import Fraction +from types import MappingProxyType +from typing import TYPE_CHECKING, Optional, Union import regex as re import torch -from vllm.config import QuantizationConfig -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, UnquantizedEmbeddingMethod) + ParallelLMHead, + UnquantizedEmbeddingMethod, +) + +if TYPE_CHECKING: + from ..gptq import GPTQConfig + from ..gptq_marlin import GPTQMarlinConfig +else: + GPTQConfig = object + GPTQMarlinConfig = object # Match dynamic rules with module name (prefix) and override quantize # config if module (prefix) matches a rule -def override_config(config: QuantizationConfig, prefix: str): - weight_bits = get_dynamic_override(config, prefix, "bits", - config.weight_bits) +def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) if isinstance(weight_bits, int): config.weight_bits = weight_bits - 
group_size = get_dynamic_override(config, prefix, "group_size", - config.group_size) + group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) if isinstance(group_size, int): config.group_size = group_size - desc_act = get_dynamic_override(config, prefix, "desc_act", - config.desc_act) + desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) if isinstance(desc_act, bool): config.desc_act = desc_act - config.pack_factor = 32 // config.weight_bits # packed into int32 + config.pack_factor = Fraction(32, config.weight_bits) # packed into int32 if config.get_name() == "gptq_marlin": + assert isinstance(config, GPTQMarlinConfig) is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) if isinstance(is_sym, bool): config.is_sym = is_sym if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={config.weight_bits}, sym={config.is_sym}") + raise ValueError( + "Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}" + ) - config.quant_type = config.TYPE_MAP[(config.weight_bits, - config.is_sym)] + config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] elif config.get_name() == "gptq": + assert isinstance(config, GPTQConfig) if config.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {config.weight_bits} bits.") + f"supported for GPTQ, but got {config.weight_bits} bits." + ) def get_dynamic_override( - config: QuantizationConfig, + config: Union[GPTQConfig, GPTQMarlinConfig], layer_name: str, key: Optional[str] = None, - default_value: Union[int, bool, - None] = None) -> Union[dict, int, bool, None]: + default_value: Union[int, bool, None] = None, +) -> Union[dict, int, bool, None]: for pattern, pattern_dict in config.dynamic.items(): # Negative match: matched modules are excluded from quantized init if pattern.startswith("-:"): @@ -69,20 +80,72 @@ def get_dynamic_override( return default_value +def is_layer_gptq_quantized( + prefix: str, + quantized_layers: list[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), +) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + + # GPTQ's `modules_in_block_to_quantize`: + # Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"] + # Full prefix ["model.layers.0.self_attn.q_proj"] + + proj_name = prefix.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_quantized = None + for shard_prefix in shard_prefixes: + is_shard_quantized = any( + layer in shard_prefix for layer in quantized_layers + ) + + if is_quantized is None: + is_quantized = is_shard_quantized + elif is_shard_quantized != is_quantized: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_quantized = any(layer in prefix for layer in quantized_layers) + + assert is_quantized is not None + return is_quantized + + def get_linear_quant_method( - config: QuantizationConfig, + config: Union[GPTQConfig, GPTQMarlinConfig], layer: torch.nn.Module, prefix: str, linear_method_cls: type, ): cloned_config = deepcopy(config) - parallel_lm_head_quantized = isinstance( - layer, ParallelLMHead) and cloned_config.lm_head_quantized + parallel_lm_head_quantized = ( + isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized + ) if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + is_layer_quantized = is_layer_gptq_quantized( + prefix=prefix, + quantized_layers=cloned_config.modules_in_block_to_quantize, + fused_mapping=cloned_config.packed_modules_mapping, + ) # False = skip module, None = no override, else = Positive match if get_dynamic_override( # noqa: E712 - cloned_config, # noqa: E712 - layer_name=prefix) == False: # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix, + ) == False or (not is_layer_quantized): # noqa: E712 if parallel_lm_head_quantized: return UnquantizedEmbeddingMethod() return UnquantizedLinearMethod() diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 6840cabbf1ae3..1b8efe4332c54 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -30,12 +30,9 @@ def apply_w8a8_block_int8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1]) - output = w8a8_block_int8_matmul(q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=input.dtype) + output = w8a8_block_int8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) if bias is not None: output = output + bias @@ -43,8 +40,8 @@ def apply_w8a8_block_int8_linear( def input_to_int8( - x: torch.Tensor, - dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, dtype: torch.dtype = torch.int8 +) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to int8 values with tensor-wise quantization.""" iinfo = torch.iinfo(dtype) @@ -78,8 +75,8 @@ def block_dequant( for i in range(k_tiles): for j in range(n_tiles): x_dq_block[ - j * block_n:min((j + 1) * block_n, n), - i * block_k:min((i + 1) * block_k, k), + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), ] *= x_s[j][i] return x_dq_block @@ -91,15 +88,17 @@ if current_platform.is_rocm(): # NOTE: This can be removed when hip.libdevice.round() is available. 
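The per-token INT8 path below uses one symmetric scale per token, `scale = absmax / 127`, and rounds `x / scale` into int8. A plain-PyTorch sketch of that computation (illustrative name only, not a vLLM API):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row (token): absmax / 127, symmetric quantization to int8.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale.to(torch.float32)

x = torch.randn(4, 64)
x_q, scales = per_token_quant_int8_ref(x)
print(x_q.dtype, scales.shape)  # torch.int8 torch.Size([4, 1])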
@core.extern def round_f32(arg0, _builder=None): - return core.extern_elementwise("", - "", [arg0], { - (core.dtype("fp32"), ): - ("llvm.round", core.dtype("fp32")), - (core.dtype("fp64"), ): - ("llvm.round", core.dtype("fp64")), - }, - is_pure=True, - _builder=_builder) + return core.extern_elementwise( + "", + "", + [arg0], + { + (core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")), + (core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")), + }, + is_pure=True, + _builder=_builder, + ) @triton.jit def round_int8(x): @@ -127,8 +126,7 @@ def _per_token_quant_int8( cols = tl.arange(0, BLOCK) mask = cols < N - x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, - other=0.0).to(tl.float32) + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) scale_x = absmax / 127 x_q = x * (127 / absmax) @@ -142,15 +140,13 @@ def per_token_quant_int8(x): M = x.numel() // x.shape[-1] N = x.shape[-1] x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) - scales = torch.empty(x.shape[:-1] + (1, ), - device=x.device, - dtype=torch.float32) + scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32) BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) assert x.is_contiguous() - _per_token_quant_int8[(M, )]( + _per_token_quant_int8[(M,)]( x, x_q, scales, @@ -229,8 +225,9 @@ def per_token_group_quant_int8( tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` cannot be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" iinfo = torch.iinfo(dtype) @@ -239,15 +236,15 @@ def per_token_group_quant_int8( x_q = torch.empty_like(x, device=x.device, dtype=dtype) x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size, ), + x.shape[:-1] + (x.shape[-1] // group_size,), device=x.device, dtype=torch.float32, ) # prefer CUDA kernel if available if current_platform.is_cuda(): - torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps, - float(int8_min), - float(int8_max)) + torch.ops._C.per_token_group_quant_int8( + x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max) + ) return x_q, x_s M = x.numel() // group_size @@ -257,7 +254,7 @@ def per_token_group_quant_int8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_int8[(M, )]( + _per_token_group_quant_int8[(M,)]( x, x_q, x_s, @@ -333,20 +330,15 @@ def _w8a8_block_int8_matmul( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) - accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, - None] * b_s[None, :] + accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * 
stride_bk @@ -365,8 +357,9 @@ def _w8a8_block_int8_matmul( @functools.lru_cache -def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[dict[int, Any]]: +def get_w8a8_block_int8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. @@ -382,7 +375,8 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501 config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( @@ -395,8 +389,10 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, # If no optimized configuration is available, we will use the default # configuration logger.warning( - ("Using default W8A8 Block INT8 kernel config. Performance might " - "be sub-optimal! Config file not found at %s"), + ( + "Using default W8A8 Block INT8 kernel config. Performance might " + "be sub-optimal! Config file not found at %s" + ), config_file_path, ) return None @@ -423,7 +419,7 @@ def w8a8_block_int8_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. @@ -441,7 +437,7 @@ def w8a8_block_int8_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) @@ -462,8 +458,9 @@ def w8a8_block_int8_matmul( } def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) _w8a8_block_int8_matmul[grid]( A, diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py index fbc0f23acb59a..4bf31340a2f68 100644 --- a/vllm/model_executor/layers/quantization/utils/layer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -20,12 +20,15 @@ def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) -def replace_parameter(mod: torch.nn.Module, name: str, - new: Union[torch.Tensor, torch.nn.Parameter]) -> None: - +def replace_parameter( + mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter] +) -> None: old = getattr(mod, name) - if type(old) is type(new) and old.dtype == new.dtype and \ - old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + if ( + type(old) is type(new) + and old.dtype == new.dtype + and old.untyped_storage().nbytes() == new.untyped_storage().nbytes() + ): # If we can just update in-place to avoid re-registering # can be faster if the underlying storage is the same update_tensor_inplace(old, new) @@ -36,5 +39,4 @@ def replace_parameter(mod: torch.nn.Module, name: str, # parameters for `torch.compile` 
compatibility if not isinstance(new, torch.nn.Parameter): new = torch.nn.Parameter(new, requires_grad=False) - mod.register_parameter(name, - torch.nn.Parameter(new, requires_grad=False)) + mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py index fbb850d227765..69466bdcb64c2 100644 --- a/vllm/model_executor/layers/quantization/utils/machete_utils.py +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -39,12 +39,19 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]: return [-1, 128] -def check_machete_supports_shape(in_features: int, out_featrues: int) \ - -> tuple[bool, Optional[str]]: +def check_machete_supports_shape( + in_features: int, out_featrues: int +) -> tuple[bool, Optional[str]]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: - return False, "Input features size must be divisible by "\ - f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + return ( + False, + "Input features size must be divisible by " + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}", + ) if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: - return False, "Output features size must be divisible by "\ - f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return ( + False, + "Output features size must be divisible by " + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}", + ) return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 02057b476c6e2..d2fa5af1b8540 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -40,8 +40,9 @@ def query_marlin_supported_quant_types( ): if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) if device_capability < 80: return [] @@ -50,10 +51,12 @@ def query_marlin_supported_quant_types( # - has_zp is False: return quant_types that has not zero points # - has_zp is None: both if has_zp is None: - types0 = query_marlin_supported_quant_types(False, include_fp_type, - device_capability) - types1 = query_marlin_supported_quant_types(True, include_fp_type, - device_capability) + types0 = query_marlin_supported_quant_types( + False, include_fp_type, device_capability + ) + types1 = query_marlin_supported_quant_types( + True, include_fp_type, device_capability + ) return types0 + types1 if has_zp: @@ -68,108 +71,126 @@ def query_marlin_supported_quant_types( def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: - + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) supported_types = query_marlin_supported_quant_types( - has_zp, True, device_capability) + has_zp, True, device_capability + ) if quant_type not in supported_types: - 
return (False, f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).") - if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): - return (False, f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) return True, None -def check_marlin_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, - device_capability) +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) return cond -def verify_marlin_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False) -> None: +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) -> None: - +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError(f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) - if (group_size < input_size - and input_size_per_partition % group_size != 0): + if group_size < input_size and input_size_per_partition % group_size != 0: raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" f" is not divisible by group_size = {group_size}. 
" "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + "with --quantization gptq." + ) -def check_marlin_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) \ - -> tuple[bool, Optional[str]]: +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape(output_size_per_partition, - input_size_per_partition, input_size, - group_size) + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) except ValueError as e: return False, e.__str__() return True, None -def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ - -> bool: - output_size_per_partition = getattr(layer, "output_size_per_partition", - None) or layer.output_size - input_size_per_partition = getattr(layer, "input_size_per_partition", - None) or layer.input_size +def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + output_size_per_partition = ( + getattr(layer, "output_size_per_partition", None) or layer.output_size + ) + input_size_per_partition = ( + getattr(layer, "input_size_per_partition", None) or layer.input_size + ) return check_marlin_supports_shape( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=layer.input_size, - group_size=group_size)[0] + group_size=group_size, + )[0] -def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ - -> bool: +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin @@ -180,41 +201,58 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) # moe marlin requires n % 128 == 0 and k % 64 == 0 - supports_shape = hidden_size % 128 == 0 and \ - intermediate_size_per_partition % max(64, group_size) == 0 + supports_shape = ( + hidden_size % 128 == 0 + and intermediate_size_per_partition % max(64, group_size) == 0 + ) supports_group_size = group_size in [-1, 32, 64, 128] - return supports_shape and supports_group_size and \ - supports_router_weight and supports_activation + return ( + supports_shape + and supports_group_size + and supports_router_weight + and supports_activation + ) -def marlin_make_workspace(output_size_per_partition: int, - device: torch.device) -> torch.Tensor: - max_workspace_size = (output_size_per_partition // - GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) +def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor): + """ + Given Marlin packed weight matrices w1_packed, and w2_packed, + return the MoE intermediate size N + """ + marlin_tile_size = 16 + return w2_packed.size(1) * marlin_tile_size -def marlin_make_workspace_new(device: torch.device, - max_blocks_per_sm: int = 1) -> torch.Tensor: +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * 
GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_make_workspace_new( + device: torch.device, max_blocks_per_sm: int = 1 +) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace - # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + # size. The num of threadblocks is sms_count * max_blocks_per_sm. sms = torch.cuda.get_device_properties(device).multi_processor_count - return torch.zeros(sms * max_blocks_per_sm, - dtype=torch.int, - device=device, - requires_grad=False) + return torch.zeros( + sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False + ) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, - is_row_parallel: bool) -> bool: +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -222,17 +260,18 @@ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) -def marlin_sort_g_idx( - g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices @@ -243,14 +282,13 @@ def get_scale_perms(): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int) -> torch.Tensor: - +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] @@ -286,8 +324,9 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -308,8 +347,9 @@ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, return zp -def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, - size_n: int, num_bits: int) -> torch.Tensor: +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: 
int, num_bits: int +) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -331,8 +371,9 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, - size_n: int, num_bits: int): +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -340,8 +381,7 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, - num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) return output @@ -353,7 +393,8 @@ def maybe_warn_marlin_atomic_add(device, dtype): logger.info_once( "You are running Marlin kernel with bf16 on GPUs before SM90. " "You can consider change to fp16 to achieve better performance " - "if possible.") + "if possible." + ) def maybe_warn_marlin_atomic_add_env(): @@ -365,12 +406,13 @@ def maybe_warn_marlin_atomic_add_env(): "Marlin kernel can achieve better performance for small size_n " "with experimental use_atomic_add feature. " "You can consider set environment variable " - "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible." + ) -def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, - dtype: torch.dtype) -> bool: - +def should_use_atomic_add_reduce( + m: int, n: int, k: int, device: torch.device, dtype: torch.dtype +) -> bool: # the performance of atomicAdd is better than global reduce # only when m*n is small and k is large if n >= 2048 or k < 2048 or device.type != "cuda": @@ -392,88 +434,98 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition, ) + out_shape = input.shape[:-1] + (output_size_per_partition,) - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=output_size_per_partition, - k=reshaped_x.size(1), - device=input.device, - dtype=input.dtype) + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) - output = ops.gptq_marlin_gemm(reshaped_x, - None, - weight, - bias, - weight_scale, - None, - weight_zp, 
- g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False) + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) return output.reshape(out_shape) def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition, ) + out_shape = input.shape[:-1] + (output_size_per_partition,) - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=output_size_per_partition, - k=reshaped_x.size(1), - device=input.device, - dtype=input.dtype) + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) - output = ops.gptq_marlin_gemm(reshaped_x, - None, - weight, - bias, - weight_scale, - None, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False) + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 94ffdcd26ecde..c5e34f392fb22 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,12 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, - marlin_permute_scales, should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) from vllm.platforms 
import current_platform from vllm.scalar_type import scalar_types @@ -28,7 +32,8 @@ def nvfp4_marlin_process_scales(marlin_scales): "NVFP4 Marlin assumes the scales to be >=0, but has encountered " "negative scales. Accuracy will likely be degraded. This is " "because it changes the scales from FP8-S1E4M3 to a special " - "FP8-S0E5M3 format to speedup the dequantization.") + "FP8-S0E5M3 format to speedup the dequantization." + ) # convert to half first, we would convert to fp8 later marlin_scales = marlin_scales.to(torch.half) @@ -36,11 +41,13 @@ def nvfp4_marlin_process_scales(marlin_scales): # 8 is the number of scale number using by one thread marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1) + marlin_scales.size(0) * 2, -1 + ) # fit the layout of fp8 dequantization marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1) + marlin_scales.size(0), -1 + ) # We assume that weight_scale (FP8-S1E4M3) is always greater # than or equal to 0. So we can convert @@ -60,11 +67,13 @@ def mxfp4_marlin_process_scales(marlin_scales): # 8 is the number of scale number using by one thread marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1) + marlin_scales.size(0) * 2, -1 + ) # fit the layout of fp8 dequantization marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1) + marlin_scales.size(0), -1 + ) marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) return marlin_scales @@ -78,48 +87,49 @@ def nvfp4_marlin_process_global_scale(global_scale): target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 - exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) - return global_scale * (2.0**(exponent_bias - 7)) + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) def apply_fp4_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_scale_2: Optional[torch.Tensor], - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: # For GPUs that lack FP4 hardware support, we can leverage the # Marlin kernel for fast weight-only FP4 quantization reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) + out_shape = input.shape[:-1] + (size_n,) - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=size_n, - k=size_k, - device=input.device, - dtype=input.dtype) + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype + ) - output = ops.gptq_marlin_gemm(a=reshaped_x, - c=None, - b_q_weight=weight, - b_bias=bias, - b_scales=weight_scale, - global_scale=weight_scale_2, - b_zeros=None, - g_idx=None, - perm=None, - workspace=workspace, - b_q_type=scalar_types.float4_e2m1f, - size_m=reshaped_x.size(0), - size_n=size_n, - size_k=size_k, - 
use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce) + output = ops.gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) return output.reshape(out_shape) @@ -129,7 +139,8 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. Weight-only FP4 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." + ) is_nvfp4 = hasattr(layer, "weight_scale_2") group_size = 16 if is_nvfp4 else 32 @@ -150,11 +161,13 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: perm = torch.empty(0, dtype=torch.int, device=device) qweight = layer.weight.view(torch.int32).T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=4) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES @@ -165,27 +178,23 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: weight_scale = weight_scale.view(torch.float8_e8m0fnu) weight_scale = weight_scale.to(param_dtype) - weight_scale = marlin_permute_scales(s=weight_scale, - size_k=part_size_k, - size_n=part_size_n, - group_size=group_size) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) if is_nvfp4: weight_scale = nvfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) weight_scale_2 = layer.weight_scale_2.to(param_dtype) weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) else: weight_scale = mxfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n, ) + assert layer.bias.shape == (part_size_n,) bias = marlin_permute_bias(layer.bias) layer.bias = torch.nn.Parameter(bias, requires_grad=False) @@ -197,7 +206,8 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. Weight-only FP4 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) is_nvfp4 = hasattr(layer, "w13_weight_scale_2") group_size = 16 if is_nvfp4 else 32 @@ -227,11 +237,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: for i in range(e): qweight = weight[i].view(torch.int32).T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=size_k, - size_n=size_n, - num_bits=4) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 + ) tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -247,8 +255,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: scales = scales.view(torch.float8_e8m0fnu) scales = scales.to(param_dtype) if is_nvfp4: - global_scale = getattr(layer, - name + "_weight_scale_2").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -259,10 +266,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: for i in range(e): scale = scales[i].T - marlin_scales = marlin_permute_scales(s=scale, - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scale, size_k=size_k, size_n=size_n, group_size=group_size + ) if is_nvfp4: marlin_scales = nvfp4_marlin_process_scales(marlin_scales) else: @@ -275,8 +281,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: if is_nvfp4: global_scale = nvfp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, - requires_grad=False) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) setattr(layer, name + "_weight_scale_2", global_scale) # BIAS @@ -306,26 +311,26 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): global_scale = scales.max() / 448 scales = (scales / global_scale).to(torch.float8_e4m3fn) - fp4_weight = torch.randint(0, - 256, (size_n, size_k // 2), - dtype=torch.uint8, - device=weight.device) - fp4_weight_part_1 = ((fp4_weight & 0b10000000) | - ((fp4_weight & 0b01110000) >> 2)) + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | - ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), - fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) - weight_ref = weight_ref * global_scale.to(weight.dtype) * \ - scales.repeat_interleave(group_size, 1).to(weight.dtype) + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = ( + weight_ref + * global_scale.to(weight.dtype) + * scales.repeat_interleave(group_size, 1).to(weight.dtype) + ) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), @@ -335,10 +340,9 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): num_bits=4, ) - marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), - size_k=size_k, - size_n=size_n, - group_size=group_size) + 
marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = nvfp4_marlin_process_scales(marlin_scales) global_scale = nvfp4_marlin_process_global_scale(global_scale) @@ -351,32 +355,31 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): size_n, size_k = weight.shape device = weight.device - scales = torch.randint(100, - 125, (size_n, size_k // group_size), - dtype=torch.uint8, - device=weight.device) + scales = torch.randint( + 100, + 125, + (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device, + ) scales = scales.view(torch.float8_e8m0fnu) - fp4_weight = torch.randint(0, - 256, (size_n, size_k // 2), - dtype=torch.uint8, - device=weight.device) - fp4_weight_part_1 = ((fp4_weight & 0b10000000) | - ((fp4_weight & 0b01110000) >> 2)) + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | - ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), - fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) - weight_ref = weight_ref * \ - scales.repeat_interleave(group_size, 1).to(weight.dtype) + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), @@ -386,10 +389,9 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): num_bits=4, ) - marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = mxfp4_marlin_process_scales(marlin_scales) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 511e19545d5ae..9348ac158daa7 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,8 +8,12 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, - marlin_permute_scales, should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -28,60 +32,63 @@ def fp8_fused_exponent_bias_into_scales(scales): target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 - exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + exponent_bias = 2 ** (target_exponent - 1) - 2 ** 
(fp8_exponent - 1) s = torch.ones_like(scales) * 2 s = s**exponent_bias return scales * s def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) + out_shape = input.shape[:-1] + (size_n,) - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=size_n, - k=size_k, - device=input.device, - dtype=input.dtype) + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype + ) - output = ops.gptq_marlin_gemm(a=reshaped_x, - c=None, - b_q_weight=weight, - b_bias=bias, - b_scales=weight_scale, - global_scale=None, - b_zeros=None, - g_idx=None, - perm=None, - workspace=workspace, - b_q_type=scalar_types.float8_e4m3fn, - size_m=reshaped_x.size(0), - size_n=size_n, - size_k=size_k, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce) + output = ops.gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) return output.reshape(out_shape) -def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, - size_k_first: bool = True) -> None: +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, size_k_first: bool = True +) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition @@ -104,11 +111,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, if not size_k_first: qweight = qweight.T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=8) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES @@ -151,26 +160,27 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, # size_n may not divisible by block_size[0] scales = scales[:, :part_size_n] - marlin_scales = marlin_permute_scales(s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n, ) + assert layer.bias.shape == (part_size_n,) bias = marlin_permute_bias(layer.bias) layer.bias = torch.nn.Parameter(bias, requires_grad=False) -def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, - size_k_first: bool = True) -> None: +def prepare_moe_fp8_layer_for_marlin( + layer: torch.nn.Module, size_k_first: bool = True +) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) e = layer.num_experts k = layer.hidden_size @@ -202,11 +212,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, if not size_k_first: qweight = qweight.T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=size_k, - size_n=size_n, - num_bits=8) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8 + ) tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -265,10 +273,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, scales = scales[..., :size_n].contiguous() for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i], - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size + ) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -295,8 +302,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, setattr(layer, name, bias) -def pack_fp8_to_int32(fp8_tensor: torch.Tensor, - size_k_first: bool = True) -> torch.Tensor: +def pack_fp8_to_int32( + fp8_tensor: torch.Tensor, size_k_first: bool = True +) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ @@ -335,10 +343,9 @@ def marlin_quant_fp8_torch(weight, group_size): num_bits=8, ) - marlin_scales = marlin_permute_scales(s=scales.T, - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index b2c228c242532..1bbd88d5ca710 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -9,24 +9,26 @@ import torch from vllm.scalar_type import ScalarType -from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, - marlin_zero_points) -from .quant_utils import (get_pack_factor, gptq_quantize_weights, - quantize_weights, sort_weights) +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) class MarlinWorkspace: - def __init__(self, out_features, min_thread_n, max_parallel): - assert (out_features % min_thread_n == 0), ( - "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n)) + assert out_features % min_thread_n == 0, ( + "out_features = {} is indivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + ) - max_workspace_size = ((out_features // min_thread_n) * max_parallel) + max_workspace_size = (out_features // min_thread_n) * max_parallel - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): @@ -54,8 +56,7 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): q_w = q_w.cpu().numpy().astype(np.uint32) - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=np.uint32) + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // 
pack_factor), dtype=np.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i @@ -71,10 +72,10 @@ def get_weight_perm(num_bits: int): col = i // 4 for block in [0, 1]: for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col + 8 * block) for j in range(4): @@ -94,11 +95,13 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None): +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -109,7 +112,8 @@ def marlin_quantize(w: torch.Tensor, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm) + w, quant_type, group_size, act_order, test_perm + ) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -130,8 +134,7 @@ def marlin_quantize(w: torch.Tensor, return res_list -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, - group_size: int): +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): size_k, size_n = w.shape # Normalize group_size @@ -144,18 +147,13 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, num_groups = size_k // group_size # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, - quant_type, - group_size, - zero_points=True) + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) # Reformat to marlin weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, - quant_type.size_bits) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) # Create result res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 1c93c364679da..90011f116bb0b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -26,8 +26,7 @@ from .quant_utils import gptq_quantize_weights # matrix elements into reordered metadata matrix elements (or, # equivalently, for gathering reordered metadata matrix element back # into metadata matrix elements). 
-def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, - device): +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) @@ -35,9 +34,13 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, group_x = 64 group_y = 32 if meta_dtype.itemsize == 2 else 16 - dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + - (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + - ((dst_rows % group_x) // 8) * 4) + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) @@ -50,8 +53,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, interleave = 2 cols_maj = dst_cols // interleave cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + - cols_min).view(-1) + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) # This function converts dense matrix into sparse semi-structured @@ -75,17 +77,18 @@ def sparse_semi_structured_from_dense_cutlass(dense): raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError( - "Invalid number of elements per meta element calculated") + raise RuntimeError("Invalid number of elements per meta element calculated") if meta_dtype == torch.int32: if m % 16 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16") + f"Number of rows of dense matrix {m} must be divisible by 16" + ) else: if m % 32 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32") + f"Number of rows of dense matrix {m} must be divisible by 32" + ) if k % (4 * quadbits_per_meta_elem) != 0: raise RuntimeError( f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 @@ -146,40 +149,39 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: - sparse = dense_2.gather(-1, - idxs0.unsqueeze(-1) // 2).view( - m, - k // 2) # type: ignore[possibly-undefined] + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view( - (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) elif quadbits_per_meta_elem == 8: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) 
- | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) + m, meta_ncols, meta_dtype, device + ) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) @@ -222,13 +224,14 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: raise RuntimeError( f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix") + "expected according to the number of columns of meta matrix" + ) # Undo meta tensor elements reordering. meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) - meta = torch.gather(meta_reordered.view(-1), 0, - meta_offsets).view(m, meta_ncols) + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) # Unpack sparse tensor back to original dense tensor, using # information provided by meta tensor. Note that torch.float @@ -270,16 +273,17 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): meta_2[:, :, 15] = (meta >> 30) & 0b11 dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( - -1, 1).repeat(1, 2).view(-1) + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: - dense.view(torch.half).scatter_(0, dense_offsets, - sparse.view(torch.half).view(-1)) + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) return dense.view(m, 2 * k) @@ -287,8 +291,8 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): def mask_creator(tensor): """ Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask will correspond to the given tensor. 
:param N: The number of weights in a group to keep @@ -301,14 +305,14 @@ def mask_creator(tensor): # for i, tensor in enumerate(tensors): if tensor.numel() % M != 0: raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " - f"{M} groups") + f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups" + ) num_groups = tensor.numel() // M # N:M sparsity for linear layers tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) @@ -342,7 +346,7 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): for i in sampled_row_idxs: for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): total_segments += 1 - block = w[i, j:j + BLOCK_SIZE] + block = w[i, j : j + BLOCK_SIZE] num_nonzero = torch.count_nonzero(block) if num_nonzero > MAX_NON_ZEROS: print("i = {} j = {} block = {}".format(i, j, block)) @@ -359,8 +363,7 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): # Compress q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( - q_24_no_zp) + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() # Restore bias @@ -390,13 +393,12 @@ def get_weight_perm_24(num_bits: int): col_o = col // 2 for block in [0, 1]: for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + - 4 * block) + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) perm = numpy.array(perm_list) @@ -413,9 +415,9 @@ def get_weight_perm_24(num_bits: int): return perm -def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, - group_size: int) -> torch.Tensor: - +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms_24() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] @@ -443,17 +445,18 @@ def marlin_24_quantize( # Quantize w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False) + w_24, quant_type, group_size, act_order=False + ) # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - quant_type) + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) size_k_comp = size_k // 2 # Reformat to marlin weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - quant_type.size_bits, weight_perm) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) # Create result diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 48f9cc3737e47..ee6c826f8b2c5 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -10,34 +10,50 @@ from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) -OCP_MX_BLOCK_SIZE = 32 - def _swizzle_mxfp4(quant_tensor, scale, num_warps): - """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel - """ + """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" import triton_kernels.matmul_ogs_details.opt_flags as opt_flags from triton_kernels.numerics import InFlexData from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.tensor_details.layout import StridedLayout - if (current_platform.is_cuda() - and current_platform.is_device_capability(90) - and not is_torch_equal_or_newer("2.8.1")): + + value_layout_opts: dict[str, Any] = {} + scale_layout_opts: dict[str, Any] = {} + + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(90) + and not is_torch_equal_or_newer("2.8.1") + ): logger.warning_once( "Mxfp4 on hopper is running on torch < 2.8.1, " "this cause swizling to be disabled, which may " - "cause performance degradation. Please upgrade to torch nightly") - value_layout, value_layout_opts = StridedLayout, dict() - scale_layout, scale_layout_opts = StridedLayout, dict() + "cause performance degradation. Please upgrade to torch nightly" + ) + value_layout = StridedLayout + scale_layout = StridedLayout + elif current_platform.is_rocm(): + from triton_kernels.tensor_details.layout import ( + GFX950MXScaleLayout, + StridedLayout, + ) + + from vllm.platforms.rocm import on_gfx950 + + value_layout = StridedLayout + scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout else: - value_layout, value_layout_opts = \ - layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1 + ) scale_layout, scale_layout_opts = ( layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=1, num_warps=num_warps)) - if current_platform.is_cuda() and \ - current_platform.is_device_capability(100): + mx_axis=1, num_warps=num_warps + ) + ) + if current_platform.is_cuda() and current_platform.is_device_capability(100): constraints = { "is_persistent": True, "epilogue_subtile": 1, @@ -46,75 +62,98 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): # transpose the tensor so that the quantization axis is on dim1 quant_tensor = quant_tensor.transpose(-2, -1) scale = scale.transpose(-2, -1) - quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), - value_layout, **value_layout_opts) - scale = convert_layout(wrap_torch_tensor(scale), scale_layout, - **scale_layout_opts) + quant_tensor = convert_layout( + wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts + ) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts) return quant_tensor, InFlexData(), scale -def _can_support_mxfp4(use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - 
scoring_func: str = "softmax", - activation: str = "swigluoai", - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None): - return not (use_grouped_topk or topk_group or num_expert_group - or expert_map or custom_routing_function - or e_score_correction_bias or apply_router_weight_on_input - or scoring_func != "softmax" or activation != "swigluoai" - or expert_load_view or logical_to_physical_map - or logical_replica_count) +def _can_support_mxfp4( + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + scoring_func: str = "softmax", + activation: str = "swigluoai", + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, +): + return not ( + use_grouped_topk + or topk_group + or num_expert_group + or custom_routing_function + or e_score_correction_bias + or apply_router_weight_on_input + or scoring_func != "softmax" + or activation != "swigluoai" + or expert_load_view + or logical_to_physical_map + or logical_replica_count + ) -def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: +def _dequant_mxfp4( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: try: from quark.torch.kernel import mx except ImportError as err: - raise ImportError("The package `amd-quark` is required to use " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." + ) from err return mx.dq_mxfp4(x, scale, float_dtype) -def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: - return torch.empty((*x.shape[:-1], x.shape[-1] * 2), - dtype=float_dtype, - device=x.device) +def _dequant_mxfp4_fake( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device + ) -def _quant_dequant_mxfp4(x: torch.Tensor, - scale_calculation_mode: str = "even") -> torch.Tensor: +def _quant_dequant_mxfp4( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: try: from quark.torch.kernel import mx except ImportError as err: - raise ImportError("The package `amd-quark` is required to use " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." 
+ ) from err return mx.qdq_mxfp4(x, scale_calculation_mode) -def _quant_dequant_mxfp4_fake(x: torch.Tensor, - scale_calculation_mode: str = "even" - ) -> torch.Tensor: +def _quant_dequant_mxfp4_fake( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: return torch.empty_like(x) +# Protect these operations into a torch custom op to avoid errors as +# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped +# Explanation: Dynamo does not know how to trace the builtin +# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a +# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python +# extension (perhaps created with pybind). +# TODO: Make sure there is no way to avoid having these functions +# marked as skipped by dynamo. try: direct_register_custom_op( op_name="dequant_mxfp4", op_func=_dequant_mxfp4, - mutates_args=[], fake_impl=_dequant_mxfp4_fake, ) dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 @@ -125,7 +164,6 @@ try: direct_register_custom_op( op_name="quant_dequant_mxfp4", op_func=_quant_dequant_mxfp4, - mutates_args=[], fake_impl=_quant_dequant_mxfp4_fake, ) quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py new file mode 100644 index 0000000000000..2249e96589708 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE +from vllm.utils import direct_register_custom_op + + +def _quant_dequant_mxfp6( + x: torch.Tensor, + quant_dtype: str, + scale_calculation_mode: str = "even", +) -> torch.Tensor: + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_fp4_fp6_per_group_with_scale, + ) + from quark.torch.quantization.utils import even_round, reshape_to_blocks + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP6 models. Please install it with `pip install " + "amd-quark`." + ) from err + + axis = -1 + block_x = reshape_to_blocks(x, OCP_MX_BLOCK_SIZE, axis) + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + amax = amax.squeeze(-1) + + # TODO: there are other rounding strategies supported in quark and in the + # config.json that we do not check for here! + if scale_calculation_mode != "even": + raise NotImplementedError( + f"Scale calculation mode {scale_calculation_mode} is not yet " + "supported in MX-FP6 quantization" + ) + scale = even_round(amax, quant_dtype) + + # Apply dequantize(quantize(x)). 
+ x = fake_quantize_fp4_fp6_per_group_with_scale( + x, + scale.to(x.device), + axis=axis, + group_size=OCP_MX_BLOCK_SIZE, + quant_dtype=quant_dtype, + ) + + return x + + +def _quant_dequant_mxfp6_fake( + x: torch.Tensor, + quant_dtype: str, + scale_calculation_mode: str = "even", +) -> torch.Tensor: + return torch.empty_like(x) + + +def _dequant_mxfp6( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str +) -> torch.Tensor: + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + dequantize_fp4_fp6_per_group, + ) + from quark.torch.utils.pack import create_pack_method + except ImportError as e: + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP6 models. Please install it with `pip install " + "amd-quark`." + ) from e + + pack_method = create_pack_method(None, dtype=quant_dtype) + unpacked_x = pack_method.unpack(x, reorder=False) + + scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype) + + # TODO: `dequantize_fp4_fp6_per_group` and `prepare_inputs_per_group` + # always return fp32. + return dequantize_fp4_fp6_per_group( + unpacked_x, + scale, + axis=-1, + group_size=OCP_MX_BLOCK_SIZE, + quant_dtype=quant_dtype, + ).to(float_dtype) + + +def _dequant_mxfp6_fake( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str +) -> torch.Tensor: + assert (x.shape[-1] * 4) % 3 == 0 + return torch.empty( + (*x.shape[:-1], (x.shape[-1] * 4) // 3), dtype=float_dtype, device=x.device + ) + + +# Protect these operations into a torch custom op to avoid errors as +# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped +# Explanation: Dynamo does not know how to trace the builtin +# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a +# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python +# extension (perhaps created with pybind). +# TODO: Make sure there is no way to avoid having these functions +# marked as skipped by dynamo. +try: + direct_register_custom_op( + op_name="quant_dequant_mxfp6", + op_func=_quant_dequant_mxfp6, + mutates_args=[], + fake_impl=_quant_dequant_mxfp6_fake, + ) +except AttributeError as error: + raise error + + +# Expose keyword arguments. 
+def quant_dequant_mxfp6( + x: torch.Tensor, + quant_dtype: str, + scale_calculation_mode: str = "even", +) -> torch.Tensor: + return torch.ops.vllm.quant_dequant_mxfp6(x, quant_dtype, scale_calculation_mode) + + +try: + direct_register_custom_op( + op_name="dequant_mxfp6", + op_func=_dequant_mxfp6, + mutates_args=[], + fake_impl=_dequant_mxfp6_fake, + ) +except AttributeError as error: + raise error + + +def dequant_mxfp6( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str +) -> torch.Tensor: + return torch.ops.vllm.dequant_mxfp6(x, scale, float_dtype, quant_dtype) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py new file mode 100644 index 0000000000000..248b2d6c4af2b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + try: + from flashinfer import mxfp8_quantize as mxfp8_e4m3_quantize + except ImportError as err: + raise ImportError( + "The package `flashinfer` is required to do " + "MX-FP8 quantization. Please install it with" + "`pip install flashinfer`" + ) from err + + return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 8648771cb0177..62b480210fc06 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -12,8 +12,9 @@ __all__ = [ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) def break_fp4_bytes(a, dtype): @@ -45,12 +46,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): +def dequantize_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. 
assert tensor_fp4.dtype == torch.uint8 @@ -95,8 +93,7 @@ def ref_nvfp4_quant(x, global_scale, block_size): assert x.ndim == 2 m, n = x.shape x = torch.reshape(x, (m, n // block_size, block_size)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = torch.clamp(scale, max=448, min=-448) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) @@ -108,10 +105,13 @@ def ref_nvfp4_quant(x, global_scale, block_size): return cast_to_fp4(clipped_x), scale.squeeze(-1) -def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor, - weight: torch.Tensor, - weight_scale_swizzled: torch.Tensor, - weight_global_scale: torch.Tensor): +def run_nvfp4_emulations( + x: torch.Tensor, + input_global_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale_swizzled: torch.Tensor, + weight_global_scale: torch.Tensor, +): group_size = 16 x_m, x_k = x.shape output_dtype = x.dtype @@ -127,9 +127,14 @@ def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor, # dequantize weight w_fp4 = weight.data.view(torch.uint8) - w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data, - weight_global_scale, output_dtype, x.device, - group_size) + w_dq = dequantize_to_dtype( + w_fp4, + weight_scale_swizzled.data, + weight_global_scale, + output_dtype, + x.device, + group_size, + ) # matmul out = torch.matmul(x_dq, w_dq.t()) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py index 21af74c6b72b5..c3f26cc774118 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py @@ -5,11 +5,14 @@ from dataclasses import dataclass import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutlass_moe_available) + is_flashinfer_fp4_cutlass_moe_available, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported) + is_fp4_marlin_supported, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) __all__ = ["detect_nvfp4_moe_support", "NvFp4Support"] @@ -29,12 +32,12 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: """Detect platform support for NV-FP4 fused-MoE path""" cutlass_supported = cutlass_fp4_supported() - allow_flashinfer = (cutlass_supported - and is_flashinfer_fp4_cutlass_moe_available()) + allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available() if allow_flashinfer: - _logger.info_once("Using FlashInfer kernels for %s.", class_name - or "NVFP4 path") + _logger.info_once( + "Using FlashInfer kernels for %s.", class_name or "NVFP4 path" + ) else: if envs.VLLM_USE_FLASHINFER_MOE_FP4: _logger.warning_once( @@ -50,7 +53,8 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: else: raise ValueError( "Current platform does not support NVFP4 quantization. " - "Please use Blackwell GPUs or enable FlashInfer.") + "Please use Blackwell GPUs or enable FlashInfer." 
+ ) return NvFp4Support( cutlass_supported=cutlass_supported, diff --git a/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py new file mode 100644 index 0000000000000..3c71441a3df7e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum +from typing import Union + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +OCP_MX_BLOCK_SIZE = 32 + +OCP_MX_DTYPES = { + "mxfp4", + "mxfp6_e3m2", + "mxfp6_e2m3", + "mxfp8_e4m3", + "mxfp8_e5m2", + "mxint8", +} +SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"} + + +class OCP_MX_Scheme(str, Enum): + w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4" + w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2" + w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3" + w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2" + w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3" + + @classmethod + def from_quant_dtype( + cls, input_dtype: Union[str, None], weight_dtype: Union[str, None] + ): + if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES: + return None + elif input_dtype == "mxfp4" and weight_dtype == "mxfp4": + return cls.w_mxfp4_a_mxfp4 + elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4": + return cls.w_mxfp4_a_mxfp6_e3m2 + elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4": + return cls.w_mxfp4_a_mxfp6_e2m3 + elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2": + return cls.w_mxfp6_e3m2_a_mxfp6_e3m2 + elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3": + return cls.w_mxfp6_e2m3_a_mxfp6_e2m3 + else: + logger.warning( + "input_dtype='%s' and" + " weight_dtype='%s' is not supported " + "in OCP_MX_Scheme at the moment.", + input_dtype, + weight_dtype, + ) + return None diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py index 00d3def1db81e..1f053103fc3c6 100644 --- a/vllm/model_executor/layers/quantization/utils/petit_utils.py +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -11,14 +11,15 @@ if TYPE_CHECKING: # 1. Create a global variable as a placeholder for the module _petit_kernel: Optional["ModuleType"] = None -_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " - "`pip install petit-kernel`.") +_PETIT_INSTALL_MSG = ( + "Petit is not installed. Please install it with `pip install petit-kernel`." +) def _import_petit_kernel() -> "ModuleType": """ A helper function to handle the lazy import. - The first time this function is called, it will import the petit_kernel + The first time this function is called, it will import the petit_kernel library and store it in the global _petit_kernel variable. Subsequent calls will return the already-loaded module directly. """ @@ -28,6 +29,7 @@ def _import_petit_kernel() -> "ModuleType": try: import petit_kernel + _petit_kernel = petit_kernel return _petit_kernel except ImportError: @@ -41,14 +43,16 @@ _require_petit = _import_petit_kernel def _check_petit_nvfp4_supported( - quant_method: str, - group_size: Optional[int]) -> tuple[bool, Optional[str]]: + quant_method: str, group_size: Optional[int] +) -> tuple[bool, Optional[str]]: if quant_method != "NVFP4": return ( False, - ("Petit currently only supports: NVFP4 quantizations in sglang. 
" - "Please check the `hf_quant_config.json` file for your model's " - "quant configuration."), + ( + "Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration." + ), ) if group_size is not None and group_size != 16: return ( @@ -58,10 +62,8 @@ def _check_petit_nvfp4_supported( return (True, None) -def verify_petit_nvfp4_supported(quant_method: str, - group_size: Optional[int]) -> None: - supported, error_msg = _check_petit_nvfp4_supported( - quant_method, group_size) +def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size) if not supported: assert error_msg is not None raise ValueError(error_msg) @@ -77,15 +79,15 @@ def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: qweight = layer.weight.view(torch.int32).contiguous() # 3. Call functions through the imported module variable. - petit_qweight = petit_kernel.repack_nvfp4(qweight, - size_n=part_size_n, - size_k=part_size_k) + petit_qweight = petit_kernel.repack_nvfp4( + qweight, size_n=part_size_n, size_k=part_size_k + ) layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) # Permute scales - weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, - size_k=part_size_k, - size_n=part_size_n) + weight_scale = petit_kernel.process_nvfp4_scales( + scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n + ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) @@ -102,7 +104,7 @@ def apply_petit_nvfp4_linear( petit_kernel = _import_petit_kernel() reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) + out_shape = input.shape[:-1] + (size_n,) # TODO: Use auto-tuning to find the performant solution_id # Call the function via the module variable. 
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 6154fca2e416d..2e9b279465f93 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" + from collections.abc import Mapping from dataclasses import dataclass from types import MappingProxyType @@ -31,8 +32,17 @@ class GroupShape(_GroupShape): """ # Aliases for common quantization group shapes - PER_TENSOR: ClassVar['GroupShape'] - PER_TOKEN: ClassVar['GroupShape'] + PER_TENSOR: ClassVar["GroupShape"] + PER_TOKEN: ClassVar["GroupShape"] + + def is_per_tensor(self) -> bool: + return self.row == -1 and self.col == -1 + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR = GroupShape(-1, -1) @@ -47,18 +57,26 @@ class ScaleDesc: static: static scale if True, dynamic if False group_shape: group shape of the scale """ + dtype: torch.dtype static: bool group_shape: GroupShape def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) + group_shape = ( + "per_tensor" + if self.group_shape == GroupShape.PER_TENSOR + else ( + "per_token" + if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape) + ) + ) - return (f"{fx.graph.dtype_abbrs[self.dtype]}," - f"{'static' if self.static else 'dynamic'},{group_shape}") + return ( + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}" + ) @dataclass(frozen=True) @@ -70,6 +88,7 @@ class QuantKey: scale2: second-level scale descriptor symmetric: symmetric if True, asymmetric if False """ + dtype: torch.dtype scale: ScaleDesc scale2: Optional[ScaleDesc] = None @@ -77,9 +96,11 @@ class QuantKey: def __str__(self): scale2_str = f"scale2({self.scale2})," if self.scale2 else "" - return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," - f"scale({self.scale}),{scale2_str}" - f"{'a' if not self.symmetric else ''}symmetric)") + return ( + f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)" + ) kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) @@ -92,16 +113,16 @@ kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) -kNvfp4Quant = QuantKey(FP4_DTYPE, - scale=kNvfp4GroupScale, - scale2=kStaticTensorScale) +kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent - return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], - group_shape[1] if group_shape[1] > 0 else x.shape[-1]) + return ( + group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if group_shape[1] > 0 else x.shape[-1], + ) # Useful when treating N-dimensional group scaling as extended numpy-style @@ -116,15 +137,17 
@@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # then we would expand a to: # a = [[1, 1, 2, 2], # [3, 3, 4, 4]] -# NOTE this function this function does not explicitly broadcast dimensions +# NOTE this function does not explicitly broadcast dimensions # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: assert s % t.shape[i] == 0 - t = t.unsqueeze(i + 1)\ - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) .flatten(i, i + 1) + ) return t @@ -142,9 +165,10 @@ def scaled_quantize( quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) - assert quant_dtype.is_floating_point, \ - "currently `scaled_quantize` only supports floating point dtypes " \ + assert quant_dtype.is_floating_point, ( + "currently `scaled_quantize` only supports floating point dtypes " "but could be extended to support other dtypes" + ) finfo = torch.finfo(quant_dtype) @@ -166,11 +190,13 @@ def scaled_quantize( # Apply scale and convert form: # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) - x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ - .clamp(min=finfo.min, max=finfo.max)\ - .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ - .permute(0, 2, 1, 3)\ + x_scl_sat = ( + (x_blkd_permd * scale.unsqueeze(-1)) + .clamp(min=finfo.min, max=finfo.max) + .reshape(blk_m, blk_n, group_shape[0], group_shape[1]) + .permute(0, 2, 1, 3) .reshape(x.shape) + ) return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() @@ -191,7 +217,8 @@ def scaled_dequantize( if group_shape is None: raise AssertionError( "if x_s is 1D tensor, group_shape must be provided otherwise " - "its ambiguous which dimension to broadcast x_s to") + "its ambiguous which dimension to broadcast x_s to" + ) # unsqueeze the scales for the dimension where we want to broadcast # across the full extent if group_shape[0] == x_q.shape[-2]: @@ -201,7 +228,8 @@ def scaled_dequantize( else: raise AssertionError( "if x_s is a vector we should be broadcasting it to the full " - "extent of one of the dimensions") + "extent of one of the dimensions" + ) if group_shape is not None: assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] @@ -210,9 +238,9 @@ def scaled_dequantize( return (x_q.to(torch.float32) * x_s).to(out_dtype) -def pack_quantized_values_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -232,9 +260,9 @@ def pack_quantized_values_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def unpack_quantized_values_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -256,7 +284,7 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, def is_layer_skipped( prefix: str, ignored_layers: list[str], - fused_mapping: Mapping[str, list[str]] 
= MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj @@ -282,7 +310,16 @@ def is_layer_skipped( raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " - "to have the same precision.") + "to have the same precision." + ) + elif "experts" in prefix: + return any( + [ + prefix in layer_name + for layer_name in ignored_layers + if "experts" in layer_name + ] + ) else: is_skipped = prefix in ignored_layers @@ -295,16 +332,18 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None): +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): assert q_w.shape == w_ref.shape orig_device = q_w.device k_size, _ = q_w.shape - g_idx = torch.zeros((k_size, ), dtype=torch.int32) + g_idx = torch.zeros((k_size,), dtype=torch.int32) for i in range(k_size): g_idx[i] = i // group_size @@ -323,16 +362,20 @@ def permute_rows(q_w: torch.Tensor, ) -def quantize_weights(w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False): - assert quant_type.is_integer(), \ +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type.is_integer(), ( "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, \ - "to have group zero points, group_size must be provided "\ + ) + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " "(-1 group_size is channelwise)" + ) orig_device = w.device orig_type = w.dtype @@ -362,14 +405,16 @@ def quantize_weights(w: torch.Tensor, if zero_points: assert not quant_type.is_signed() and quant_type.max() > 0 w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ - .clamp(min_q_val, max_q_val).int() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) else: # If the bias is such that there are no possible negative/positive # values, set the max value to inf to avoid divide by 0 w_s = torch.max( abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) @@ -416,19 +461,22 @@ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -def gptq_quantize_weights(w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None): +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" - assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, ( f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = 
{group_size}" + ) + assert group_size in SUPPORTED_GROUP_SIZES + [size_k], ( + f"Unsupported groupsize = {group_size}" + ) w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) @@ -436,13 +484,13 @@ def gptq_quantize_weights(w: torch.Tensor, g_idx = torch.empty(0, dtype=torch.int, device=w.device) rand_perm = torch.empty(0, dtype=torch.int, device=w.device) if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k) + assert group_size < size_k, ( + "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + ) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, - test_perm) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) return w_ref, w_q, w_s, g_idx, rand_perm @@ -450,8 +498,7 @@ def gptq_quantize_weights(w: torch.Tensor, def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device - sort_indices = torch.argsort(g_idx).to( - dtype=torch.int32) # Sort based on g_idx + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx g_idx = g_idx[sort_indices].contiguous() q_w = q_w[sort_indices, :].contiguous() @@ -521,10 +568,11 @@ def unpack_cols( ): pack_factor = get_pack_factor(num_bits) assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, size_n // pack_factor - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor) + assert packed_q_w.shape == (size_k, size_n // pack_factor), ( + "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + ) orig_device = packed_q_w.device @@ -590,7 +638,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: """ assert scale.dtype == torch.float8_e4m3fn, ( "swizzle_blockscale expects the input tensor to be in " - "torch.float8_e4m3fn format.") + "torch.float8_e4m3fn format." + ) scale_ndim = scale.ndim if scale_ndim == 2: @@ -605,9 +654,9 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: M_padded = _round_up(M, 128) K_padded = _round_up(K, 4) - padded = torch.zeros((B, M_padded, K_padded), - dtype=scale.dtype, - device=scale.device) + padded = torch.zeros( + (B, M_padded, K_padded), dtype=scale.dtype, device=scale.device + ) padded[:B, :M, :K] = scale # Reshape / permute to the layout required by the kernel. diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 36d16960ec57c..c26cd4f28cb69 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -10,10 +10,10 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale @@ -22,10 +22,12 @@ TORCH_DEVICE_IDENTITY = None # The condition to determine if it is on a platform that supports # torch._scaled_mm rowwise feature. # The condition is determined once as the operations -# are time consuming. -USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() and version.parse( - torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94)) +# are time-consuming. +USE_ROWWISE_TORCH_SCALED_MM = ( + current_platform.is_rocm() + and version.parse(torch.__version__) >= version.parse("2.7") + and current_platform.has_device_capability(94) +) def sparse_cutlass_supported() -> bool: @@ -73,8 +75,8 @@ CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, - torch.Tensor]) -> torch.Tensor: + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale return dq_weight @@ -86,12 +88,12 @@ def all_close_1d(x: torch.Tensor) -> bool: def convert_to_channelwise( - weight_scale: torch.Tensor, - logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + weight_scale: torch.Tensor, logical_widths: list[int] +) -> tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer - weight_scale_channel = torch.empty((sum(logical_widths), 1), - dtype=torch.float32, - device=weight_scale.device) + weight_scale_channel = torch.empty( + (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device + ) # Expand each scale to match the size of each logical matrix. start = 0 @@ -104,8 +106,8 @@ def convert_to_channelwise( def requantize_with_max_scale( - weight: torch.Tensor, weight_scale: torch.Tensor, - logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: list[int] +) -> tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() @@ -115,8 +117,9 @@ def requantize_with_max_scale( # from disk in this case. Skip requantization in this case (since) # we already are quantized with the single scale. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = (weight_scale[-1] - > torch.finfo(torch.float8_e4m3fn).min) + unfused_module_in_checkpoint = ( + weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) # If unfused checkpoint, need requanize with the single scale. 
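# For reference, the requantization below amounts to dequantizing each shard with
# its own scale and re-quantizing it with the shared max scale, so all fused shards
# end up sharing a single per-tensor scale. A minimal sketch in plain float32 (the
# real path uses FP8 via ops.scaled_fp8_quant; names and sizes here are illustrative):
import torch
shard = torch.randn(4, 8)
shard_scale, max_scale = torch.tensor(0.02), torch.tensor(0.05)
w_q = torch.round(shard / shard_scale)       # shard quantized with its own scale
w_dq = w_q * shard_scale                     # per-tensor dequantize
w_q_shared = torch.round(w_dq / max_scale)   # requantize with the shared max scale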
if unfused_module_in_checkpoint: @@ -126,10 +129,8 @@ def requantize_with_max_scale( if logical_width == 0: continue end = start + logical_width - weight_dq = per_tensor_dequantize(weight[start:end, :], - weight_scale[idx]) - weight[start:end, :], _ = ops.scaled_fp8_quant( - weight_dq, max_w_scale) + weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale) start = end return max_w_scale, weight @@ -142,97 +143,144 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) -def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, - out_dtype: torch.dtype, scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list, **kwargs) -> torch.Tensor: - +def cutlass_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) + output = ops.cutlass_scaled_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) return output.view(*output_shape) +def flashinfer_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: + return flashinfer_scaled_fp8_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) + + def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( - ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: - output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, - current_platform.get_cu_count()) + + if ( + envs.VLLM_ROCM_USE_SKINNY_GEMM + and on_mi3xx() + and qinput.shape[0] == 1 + and qinput.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + weight.t(), + qinput, + out_dtype, + scale_a, + scale_b, + current_platform.get_cu_count(), + bias, + ) else: - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) + output = torch._scaled_mm( + qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + ) return output def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), - dtype=out_dtype) + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: 
torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + qinput, weight, out_dtype, scale_a, scale_b, bias + ) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) direct_register_custom_op( op_name="rocm_per_tensor_w8a8_scaled_mm_impl", op_func=rocm_per_tensor_w8a8_scaled_mm_impl, - mutates_args=[], fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, - dispatch_key=current_platform.dispatch_key, ) -def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) +def torch_per_tensor_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + output = torch._scaled_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) -def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def torch_per_token_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. 
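To make the rowwise path concrete, here is a reference sketch of the math the fused kernel expresses: a per-token activation scale of shape (M, 1) and a per-channel weight scale of shape (N, 1) are applied around a plain matmul, which is why the call below transposes the weight scale via scale_b.t(). Plain float32 stands in for FP8 so the snippet runs anywhere; all names and sizes are illustrative.

import torch

M, K, N = 4, 8, 16
qinput = torch.randint(-8, 8, (M, K)).float()   # stand-in for FP8 activations
weight = torch.randint(-8, 8, (K, N)).float()   # stand-in for FP8 weights
scale_a = torch.rand(M, 1)                      # per-token activation scales
scale_b = torch.rand(N, 1)                      # per-channel weight scales

# Unfused reference: out[m, n] = scale_a[m] * scale_b[n] * sum_k qinput[m, k] * weight[k, n]
reference = (qinput * scale_a) @ (weight * scale_b.t())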
@@ -244,26 +292,31 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, # rowwise scaled GEMM before using it # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias) + output = torch._scaled_mm( + qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias, + ) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) output = output.view(*output_shape) return output -def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list, - **kwargs) -> torch.Tensor: +def torch_channelwise_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm # Symmetric quantized GEMM by definition computes the following: @@ -281,18 +334,20 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -303,19 +358,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - cutlass_fp8_supported: bool, per_tensor_weights: bool, - per_tensor_activations: bool) -> Callable[..., torch.Tensor]: + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool +) -> Callable[..., torch.Tensor]: + if per_tensor_weights and per_tensor_activations: + if preferred_backend == "rocm": + return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if cutlass_fp8_supported: + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": return cutlass_w8a8_scaled_mm - if per_tensor_weights and per_tensor_activations: - if current_platform.is_rocm(): - return rocm_per_tensor_w8a8_scaled_mm - return torch_per_tensor_w8a8_scaled_mm + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if not per_tensor_weights and not per_tensor_activations \ - and USE_ROWWISE_TORCH_SCALED_MM: + if ( + not per_tensor_weights + and not per_tensor_activations + and USE_ROWWISE_TORCH_SCALED_MM + ): return torch_per_token_w8a8_scaled_mm # Normally, torch.scaled_mm supports per tensor weights + activations only # so 
fallback to naive if per channel or per token @@ -332,12 +395,21 @@ class Fp8LinearOp: in the __init__ method, as reading config is not allowed inside forward. """ - def __init__(self, - act_quant_static: bool, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None): - self.cutlass_fp8_supported = cutlass_fp8_supported + def __init__( + self, + act_quant_static: bool, + act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, + pad_output: Optional[bool] = None, + ): + if current_platform.is_rocm(): + self.preferred_backend = "rocm" + elif current_platform.is_cuda() and cutlass_fp8_supported(): + if has_flashinfer() and current_platform.has_device_capability(100): + self.preferred_backend = "flashinfer" + else: + self.preferred_backend = "cutlass" + else: + self.preferred_backend = "torch" # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -346,16 +418,19 @@ class Fp8LinearOp: # as it breaks with dynamic shapes. if pad_output is None: config = get_current_vllm_config().compilation_config - pad_output = config.level < CompilationLevel.PIECEWISE and \ - not cutlass_fp8_supported and \ - not current_platform.is_rocm() + pad_output = ( + config.level < CompilationLevel.PIECEWISE + and self.preferred_backend == "torch" + ) self.output_padding = 17 if pad_output else None self.act_quant_static = act_quant_static self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8(static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding) + self.quant_fp8 = QuantFP8( + static=act_quant_static, + group_shape=act_quant_group_shape, + num_token_padding=self.output_padding, + ) def apply( self, @@ -389,28 +464,29 @@ class Fp8LinearOp: else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations) + self.preferred_backend, per_tensor_weights, per_tensor_activations + ) - return w8a8_scaled_mm_func(qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - input_2d=input_2d, - output_shape=output_shape) + return w8a8_scaled_mm_func( + qinput=qinput, + weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + output_shape=output_shape, + ) def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: assert weight.dtype == torch.float8_e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 3f2d571777c00..8dc237f8232d7 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -32,6 +32,7 @@ related helpers for sincos positional embeddings. 
Example models: Qwen (Qwen-VL), MiniCPM-V 2.0 """ + import math from functools import partial from typing import Callable, Optional, Union @@ -47,8 +48,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, - int]) -> torch.Tensor: +def get_abs_pos( + abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int] +) -> torch.Tensor: # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -56,21 +58,26 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, dtype = abs_pos.dtype if isinstance(tgt_size, int): tgt_size = (tgt_size, tgt_size) - if (src_size == tgt_size[0] and src_size == tgt_size[1]): + if src_size == tgt_size[0] and src_size == tgt_size[1]: return abs_pos - return (F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size[0], tgt_size[1]), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + return ( + F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .flatten(0, 2) + .to(dtype=dtype) + ) # sin/cos positional embedding helpers are adapted from: # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_1d_sincos_pos_embed_from_grid( - embed_dim: int, pos: np.ndarray, - version: tuple[int, int] = (2, 0)) -> torch.Tensor: + embed_dim: int, pos: np.ndarray, version: tuple[int, int] = (2, 0) +) -> torch.Tensor: """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -96,15 +103,17 @@ def get_1d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed_from_grid( - embed_dim: int, grid: np.ndarray, - version: tuple[int, int] = (2, 0)) -> torch.Tensor: + embed_dim: int, grid: np.ndarray, version: tuple[int, int] = (2, 0) +) -> torch.Tensor: assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) if version == (2, 0): emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) @@ -114,10 +123,10 @@ def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed( - embed_dim: int, - grid_size: Union[int, tuple[int, int]], - cls_token: bool = False, - version: tuple[int, int] = (2, 0), + embed_dim: int, + grid_size: Union[int, tuple[int, int]], + cls_token: bool = False, + version: tuple[int, int] = (2, 0), ) -> torch.Tensor: """ grid_size: int of the grid height and width @@ -134,15 +143,13 @@ def get_2d_sincos_pos_embed( grid_w = np.arange(grid_w_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) - assert isinstance(grid, np.ndarray) and \ - grid.shape == (2, grid_h_size, grid_w_size) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), 
pos_embed], - axis=0) + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) else: pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) return pos_embed @@ -156,15 +163,17 @@ class BaseResampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__(self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.num_queries = num_queries @@ -174,14 +183,16 @@ class BaseResampler(nn.Module): self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim)) if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = ReplicatedLinear(kv_dim, - embed_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_proj") + self.kv_proj = ReplicatedLinear( + kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) else: # Maintain the same return value with ReplicatedLinear.forward - self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa nn.Identity()(*args, **kwargs), None, ) @@ -189,10 +200,10 @@ class BaseResampler(nn.Module): self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) self.do_post_projection = do_post_projection - self.ln_post = norm_layer(embed_dim) if do_post_projection else None - self.proj = nn.Parameter( - (embed_dim**-0.5) * - torch.empty(embed_dim, embed_dim)) if do_post_projection else None + if self.do_post_projection: + self.ln_post = norm_layer(embed_dim) + data = (embed_dim**-0.5) * torch.empty(embed_dim, embed_dim) + self.proj = nn.Parameter(data=data) def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) @@ -206,32 +217,35 @@ class Resampler2(BaseResampler): present in minicpmv2.0, but not qwen-vl. 
""" - def __init__(self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(grid_size**2, - embed_dim, - num_heads, - kv_dim, - norm_layer, - do_post_projection=do_post_projection, - quant_config=quant_config, - prefix=prefix) + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix, + ) self.adaptive = adaptive - pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, - grid_size, - version=(2, 0)) + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, grid_size, version=(2, 0)) self.pos_embed = nn.Parameter( - torch.from_numpy(pos_embed_arr).requires_grad_(False)) + torch.from_numpy(pos_embed_arr).requires_grad_(False) + ) def forward( self, @@ -242,15 +256,16 @@ class Resampler2(BaseResampler): if tgt_sizes is None: tgt_sizes = int(math.sqrt(x.size(1))) if self.adaptive: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes, - version=(2, 0)) - pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, - dtype=x.dtype) + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, tgt_sizes, version=(2, 0) + ) + pos_embed = torch.from_numpy(pos_embed_arr).to( + device=x.device, dtype=x.dtype + ) else: - pos_embed = get_abs_pos(self.pos_embed, - tgt_sizes).to(device=x.device, - dtype=x.dtype) + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes).to( + device=x.device, dtype=x.dtype + ) x, _ = self.kv_proj(x) x = self.ln_kv(x).permute(1, 0, 2) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 564f9a5c00750..e6956de4bfaaa 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings.""" + from typing import Any, Optional import torch @@ -37,8 +38,7 @@ def get_rope( if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { - k: tuple(v) if isinstance(v, list) else v - for k, v in rope_scaling.items() + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() } rope_scaling_args = tuple(rope_scaling_tuple.items()) else: @@ -56,8 +56,16 @@ def get_rope( if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) - key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dual_chunk_attention_args, dtype) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dual_chunk_attention_args, + dtype, + ) if key in _ROPE_DICT: return _ROPE_DICT[key] @@ -67,13 +75,19 @@ def get_rope( for k, v in dual_chunk_attention_config.items() if k in ("chunk_size", "local_size") } - rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, - 
max_position, base, - is_neox_style, dtype, - **extra_kwargs) + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) elif not rope_scaling: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, dtype) + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) else: scaling_type = rope_scaling["rope_type"] @@ -81,18 +95,23 @@ def get_rope( scaling_factor = rope_scaling["factor"] low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - scaling_factor, low_freq_factor, - high_freq_factor, - original_max_position) + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) elif scaling_type == "mllama4": - rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype) + rotary_emb = Llama4VisionRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) elif scaling_type == "default": if "mrope_section" in rope_scaling: rotary_emb = MRotaryEmbedding( @@ -103,6 +122,7 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), ) else: rotary_emb = RotaryEmbedding( @@ -115,75 +135,136 @@ def get_rope( ) elif scaling_type == "linear": scaling_factor = rope_scaling["factor"] - rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype) + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) elif scaling_type == "ntk": scaling_factor = rope_scaling["factor"] - mixed_b = rope_scaling.get('mixed_b', None) - rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype, - mixed_b) + mixed_b = rope_scaling.get("mixed_b", None) + rotary_emb = NTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + mixed_b, + ) elif scaling_type == "dynamic": if "alpha" in rope_scaling: scaling_alpha = rope_scaling["alpha"] rotary_emb = DynamicNTKAlphaRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_alpha, dtype) + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_alpha, + dtype, + ) elif "factor" in rope_scaling: scaling_factor = rope_scaling["factor"] rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor, dtype) + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) else: - raise ValueError("Dynamic rope scaling must contain either " - "'alpha' or 'factor' field") + raise ValueError( + "Dynamic rope scaling must contain either 'alpha' or 'factor' field" + ) elif scaling_type == "yarn": scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] 
+ original_max_position = rope_scaling["original_max_position_embeddings"] extra_kwargs = { k: v for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow") + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } - rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, - original_max_position, - base, is_neox_style, - scaling_factor, dtype, - **extra_kwargs) + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + scaling_factor=scaling_factor, + **extra_kwargs, + ) + else: + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) elif scaling_type == "deepseek_yarn": scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] # assert max_position == original_max_position * scaling_factor extra_kwargs = { k: v for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow", "mscale", "mscale_all_dim") + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) } rotary_emb = DeepseekScalingRotaryEmbedding( - head_size, rotary_dim, original_max_position, base, - is_neox_style, scaling_factor, dtype, **extra_kwargs) + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) elif scaling_type == "longrope": short_factor = rope_scaling["short_factor"] long_factor = rope_scaling["long_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] extra_kwargs = { k: v for k, v in rope_scaling.items() if k in ("short_mscale", "long_mscale") } rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( - head_size, rotary_dim, max_position, original_max_position, - base, is_neox_style, dtype, short_factor, long_factor, - **extra_kwargs) + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 10fce857a8ae2..cf50b60118b9b 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings Base Class.""" + from typing import Optional import torch from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch +from .common import apply_rotary_emb_torch +from .rocm_aiter_rope_ops import ( + is_rocm_triton_rotary_embedding_enabled, + rocm_aiter_rotary_emb, +) @CustomOp.register("rotary_embedding") @@ -30,11 +35,24 @@ class RotaryEmbedding(CustomOp): self.base = base self.is_neox_style = is_neox_style self.dtype = dtype + # 
TODO(mgoin): disabled for now due to failures + # Flashinfer only supports head_size=64, 128, 256, 512. + # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202 + # self.use_flashinfer = (self.enabled() + # and dtype in (torch.float16, torch.bfloat16) + # and current_platform.is_cuda() + # and has_flashinfer() + # and self.head_size in [64, 128, 256, 512]) + self.use_flashinfer = False cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) + if not self.use_flashinfer: + cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) + self.is_rocm_triton_rotary_embedding_enabled = ( + is_rocm_triton_rotary_embedding_enabled() + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -42,8 +60,12 @@ class RotaryEmbedding(CustomOp): # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -57,16 +79,22 @@ class RotaryEmbedding(CustomOp): cache = torch.cat((cos, sin), dim=-1) return cache + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if ( + self.cos_sin_cache.device != query.device + or self.cos_sin_cache.dtype != query.dtype + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) @@ -74,20 +102,18 @@ class RotaryEmbedding(CustomOp): query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. 
cross-layer KV sharing if key is not None: key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -96,27 +122,55 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.use_flashinfer: + torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + from vllm import _custom_ops as ops - # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) - # is expensive, so avoid calling it if possible - if self.cos_sin_cache.device != query.device or \ - self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) + self._match_cos_sin_cache_dtype(query) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.is_rocm_triton_rotary_embedding_enabled: + self._match_cos_sin_cache_dtype(query) + rocm_aiter_rotary_emb( + positions, + query, + key, + self.cos_sin_cache, + self.head_size, + self.rotary_dim, + self.is_neox_style, + ) else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + self.forward_cuda(positions, query, key) return query, key def forward_xpu( @@ -124,110 +178,26 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device, - dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. + self._match_cos_sin_cache_dtype(query) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. 
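        # For reference, these in-place rotary kernels apply the same math as the
        # native path: with NeoX-style RoPE the first rotary_dim channels of each
        # head are split in half and rotated by position-dependent cos/sin pairs.
        # Minimal standalone sketch (shapes and the 10000 base are illustrative
        # assumptions):
        #   import torch
        #   num_tokens, num_heads, rotary_dim = 3, 2, 8
        #   inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        #   angles = torch.arange(num_tokens).float()[:, None] * inv_freq[None, :]
        #   cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
        #   q = torch.randn(num_tokens, num_heads, rotary_dim)
        #   x1, x2 = torch.chunk(q, 2, dim=-1)
        #   q_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)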
if key is None: # XPU kernel doesn't support key=None so fall back to native impl # TODO(sarckk): add support for optional key in # ipex.llm.functional.rotary_embedding_batched - return self.forward_native(positions, query, key, offsets) + return self.forward_native(positions, query, key) else: - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) - return query, key - - def forward_neuron( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - - def _apply_rotary_emb_neuron( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, - ) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - # x1 = x[..., ::2] - - # x2 = x[..., 1::2] - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) - x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - if offsets is not None: - positions = positions + offsets - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - - if self.rotary_dim == self.head_size: - query = apply_rotary_emb_dispatch(query, cos, sin, - self.is_neox_style) - query = query.reshape(query_shape) - if key is not None: - key = apply_rotary_emb_dispatch(key, cos, sin, - self.is_neox_style) - key = key.reshape(key_shape) - else: - head_size = query.shape[-1] - query_reshaped = query.view(-1, head_size) - query_pass = query_reshaped[:, self.rotary_dim:].view( - *query.shape[:-1], head_size - self.rotary_dim) - query_rot = query_reshaped[:, :self.rotary_dim].view( - *query.shape[:-1], self.rotary_dim) - query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, - self.is_neox_style) - query = torch.cat((query_rot, query_pass), - dim=-1).reshape(query_shape) - - if key is not None: - key_reshaped = key.view(-1, head_size) - key_pass = key_reshaped[:, self.rotary_dim:].view( - *key.shape[:-1], head_size - self.rotary_dim) - key_rot = key_reshaped[:, :self.rotary_dim].view( - *key.shape[:-1], self.rotary_dim) - key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 8d821bea19e3e..124ea0236cbfb 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -2,19 +2,26 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from functools import cache +from importlib.util import find_spec +from typing import Callable, Optional import torch +from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb +logger = init_logger(__name__) + # common functions def rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -46,9 +53,9 @@ def apply_rotary_emb_torch( return torch.stack((o1, o2), dim=-1).flatten(-2) -def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool) -> torch.Tensor: +def apply_rotary_emb_dispatch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool +) -> torch.Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -58,39 +65,68 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, positional embeddings. """ if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, - not is_neox_style).squeeze(0) + return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) else: return apply_rotary_emb_torch(x, cos, sin, is_neox_style) +@cache +def dispatch_rotary_emb_function( + default: Optional[Callable[..., torch.Tensor]] = None, +) -> Callable[..., torch.Tensor]: + if current_platform.is_cuda(): + return apply_rotary_emb + + if current_platform.is_rocm(): + if find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + return apply_rotary + else: + logger.warning( + "flash_attn is not installed. Falling back to PyTorch " + "implementation for rotary embeddings." 
+ ) + + if default is not None: + return default + else: + return apply_rotary_emb_torch + + # yarn functions # Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> float: - return (dim * math.log(max_position_embeddings / - (num_rotations * 2 * math.pi))) / (2 * - math.log(base)) +def yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations def yarn_find_correction_range( - low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> tuple[int, int]: + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> tuple[int, int]: low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case -def yarn_linear_ramp_mask(low: float, high: float, dim: int, - dtype: torch.dtype) -> torch.Tensor: +def yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: if low == high: high += 0.001 # Prevent singularity @@ -103,3 +139,47 @@ def yarn_get_mscale(scale: float = 1) -> float: if scale <= 1: return 1.0 return 0.1 * math.log(scale) + 1.0 + + +def _flashinfer_rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + """Custom op wrapper for flashinfer's rotary embedding. + + This is an in-place operation that modifies query and key tensors directly. 
+ """ + from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + + +def _flashinfer_rotary_embedding_fake( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + return + + +# Register flashinfer rotary embedding custom op +direct_register_custom_op( + op_name="flashinfer_rotary_embedding", + op_func=_flashinfer_rotary_embedding, + mutates_args=["query", "key"], # These tensors are modified in-place + fake_impl=_flashinfer_rotary_embedding_fake, +) diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index cd888b733426b..eaedca9b52192 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -9,8 +9,12 @@ import torch from vllm.platforms import current_platform from .base import RotaryEmbedding -from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range, - yarn_linear_ramp_mask) +from .common import ( + rotate_gptj, + rotate_neox, + yarn_find_correction_range, + yarn_linear_ramp_mask, +) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: @@ -49,46 +53,60 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation. self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) / - yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * - attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, - self.rotary_dim, - 2, - dtype=torch.float, - device=current_platform.device_type) / - self.rotary_dim) + pos_freqs = self.base ** ( + torch.arange( + 0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type, + ) + / self.rotary_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + inv_freq_mask = ( + 1 + - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) - t = 
torch.arange(self.max_position_embeddings * self.scaling_factor, - device=current_platform.device_type, - dtype=torch.float32) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32, + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -97,17 +115,16 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" assert key is not None - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] + self._match_cos_sin_cache_dtype(query) + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] - if self.cos_sin_cache.device != positions.device: - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) - cos_sin = self.cos_sin_cache[torch.add(positions, offsets) - if offsets is not None else positions] + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the @@ -129,3 +146,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): query = query_rot key = key_rot return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py index 3d8da0fa9d8f5..0e6eddda772f9 100644 --- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -35,18 +35,17 @@ class DualChunkRotaryEmbedding(CustomOp): self.local_size = local_size self.dtype = dtype self.device = torch.device(f"cuda:{torch.cuda.current_device()}") - (q_cache, qc_cache, k_cache, qc_no_clamp_cache, - q_inter_cache) = self._compute_cos_sin_cache() + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = ( + self._compute_cos_sin_cache() + ) self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) - self.register_buffer("cos_sin_qc_no_clamp_cache", - qc_no_clamp_cache, - persistent=False) - self.register_buffer("cos_sin_q_inter_cache", - q_inter_cache, - persistent=False) + self.register_buffer( + "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False + ) + self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False) def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -59,8 +58,12 @@ class DualChunkRotaryEmbedding(CustomOp): # use CPU to compute the cache and then move it to GPU. 
However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -68,16 +71,15 @@ class DualChunkRotaryEmbedding(CustomOp): inv_freq = self._compute_inv_freq(self.base) chunk_len = self.chunk_size - self.local_size q_t = torch.arange(chunk_len, dtype=torch.float) - qc_t = (torch.arange(chunk_len, dtype=torch.float) + - chunk_len).clamp(max=self.chunk_size) - k_t = torch.arange(self.max_position_embeddings, - dtype=torch.float) % chunk_len + qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp( + max=self.chunk_size + ) + k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len # count from chunk_len, no clamp(self.chunk_size) restriction qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len # count from self.chunk_size for q_inter's rope - q_inter_t = torch.arange(chunk_len, - dtype=torch.float) + self.chunk_size + q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size q_freqs = torch.outer(q_t, inv_freq) qc_freqs = torch.outer(qc_t, inv_freq) @@ -97,21 +99,24 @@ class DualChunkRotaryEmbedding(CustomOp): q_inter_cos = q_inter_freqs.cos() q_inter_sin = q_inter_freqs.sin() - q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) - q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) + q_cache = torch.cat((q_cos, q_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -120,47 +125,70 @@ class DualChunkRotaryEmbedding(CustomOp): ) -> tuple[torch.Tensor, torch.Tensor]: query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] else: query_pass = None key_pass = None - positions_with_offsets = (torch.add(positions, offsets) - if offsets is not None else positions) + positions_with_offsets = ( + torch.add(positions, offsets) if offsets is not None else positions + ) key = 
self._apply_rotary_embedding( - self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass + ) chunk_len = self.chunk_size - self.local_size query = self._apply_rotary_embedding( self.cos_sin_q_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_succ = self._apply_rotary_embedding( self.cos_sin_qc_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_inter = self._apply_rotary_embedding( self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), - query_rot, query_pass) + query_rot, + query_pass, + ) query_succ_critical = self._apply_rotary_embedding( self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_inter_critical = self._apply_rotary_embedding( self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) # merge query into one tensor to simplify the interfaces - query = torch.cat(( - query, - query_succ, - query_inter, - query_succ_critical, - query_inter_critical, - ), - dim=-1) + query = torch.cat( + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1, + ) return query, key + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.forward_native(positions, query, key, offsets) + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py index 1da39bbd303bd..dd9d06d4b288f 100644 --- a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py @@ -23,14 +23,16 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): dtype: torch.dtype, ) -> None: self.scaling_alpha = scaling_alpha - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_cos_sin_cache(self) -> torch.Tensor: # For Hunyuan DynamicNTKAlphaRotaryEmbedding max_len = self.max_position_embeddings - base = self.base * self.scaling_alpha**(self.rotary_dim / - (self.rotary_dim - 2)) + base = self.base * self.scaling_alpha ** ( + self.rotary_dim / (self.rotary_dim - 2) + ) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float) diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py index ec2008b90cfb8..28fd87ecc21fc 100644 --- a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py @@ -44,8 +44,9 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): dtype: torch.dtype, ) -> None: self.scaling_factor = scaling_factor - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_cos_sin_cache(self) 
-> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -54,9 +55,9 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): # self.max_position_embeddings * self.scaling_factor. max_len = self.max_position_embeddings * self.scaling_factor base = self.base * ( - (self.scaling_factor * max_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.rotary_dim / - (self.rotary_dim - 2)) + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float) diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py new file mode 100644 index 0000000000000..2bc0477c5af28 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + + +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): + """3D rotary positional embedding. 3D is t:time h:height w:width""" + + def forward_native( # type: ignore[override] + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 + assert section_h == section_w + # Split according to [h w h w h w h w... t t t...] 
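+            # Illustrative note: with the assumed sections 22/22/20 above,
+            # the last 20 channels carry t while the first 44 alternate
+            # h (even offsets) and w (odd offsets); the [0]/[1]/[2] indexing
+            # below selects the t/h/w position axes respectively.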
+ section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., : section_h + section_w : 2] + section_cos_w = cos[..., 1 : section_h + section_w : 2] + + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape( + cos_h.shape[:-1] + (cos_h.shape[-1] * 2,) + ) + cos = torch.cat([cos_hw, cos_t], dim=-1) + + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., : section_h + section_w : 2] + section_sin_w = sin[..., 1 : section_h + section_w : 2] + + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape( + sin_h.shape[:-1] + (sin_h.shape[-1] * 2,) + ) + sin = torch.cat([sin_hw, sin_t], dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( # type: ignore[override] + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py index 6e920991882d4..cbb3ee4e9974b 100644 --- a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py @@ -71,8 +71,9 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors: list[float] = scaling_factors # noqa - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) # Lazy initialized. 
self._scaling_factor_to_offset: dict[float, int] diff --git a/vllm/model_executor/layers/rotary_embedding/llama3_rope.py b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py index adcef549bc4c2..ed9a6031eb6f3 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama3_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py @@ -9,7 +9,6 @@ from .base import RotaryEmbedding class Llama3RotaryEmbedding(RotaryEmbedding): - def __init__( self, head_size: int, @@ -27,8 +26,9 @@ class Llama3RotaryEmbedding(RotaryEmbedding): self.low_freq_factor = low_freq_factor self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) @@ -37,8 +37,9 @@ class Llama3RotaryEmbedding(RotaryEmbedding): wave_len = 2 * math.pi / inv_freqs if self.low_freq_factor != self.high_freq_factor: - smooth = (self.orig_max_position / wave_len - self.low_freq_factor - ) / (self.high_freq_factor - self.low_freq_factor) + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) else: smooth = 0 new_freqs = torch.where( @@ -47,8 +48,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): torch.where( wave_len > low_freq_wavelen, inv_freqs / self.scaling_factor, - (1 - smooth) * inv_freqs / self.scaling_factor + - smooth * inv_freqs, + (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs, ), ) return new_freqs diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index 415a85ab698bc..0b808e31c903e 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -10,7 +10,6 @@ from .base import RotaryEmbedding class Llama4VisionRotaryEmbedding(RotaryEmbedding): - def __init__( self, head_size: int, @@ -20,12 +19,13 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, ): - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) - inv_freqs = inv_freqs[:(self.rotary_dim // 2)] + inv_freqs = inv_freqs[: (self.rotary_dim // 2)] return inv_freqs def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -34,36 +34,36 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): # self.max_position_embeddings here is number of image patches # i.e. 
(image_size // patch_size) ** 2 num_patches = self.max_position_embeddings - img_idx = torch.arange(num_patches, - dtype=torch.int32) \ - .reshape(num_patches, 1) + img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN num_patches_single_dim = int(math.sqrt(num_patches)) frequencies_x = img_idx % num_patches_single_dim frequencies_y = img_idx // num_patches_single_dim - freqs_x = ((frequencies_x + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs_y = ((frequencies_y + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs = torch.cat([freqs_x, freqs_y], - dim=-1).float().contiguous()[..., ::2] + freqs_x = ( + (frequencies_x + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs_y = ( + (frequencies_y + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) cache = torch.view_as_complex( - torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + ) return cache - def forward( + def forward_native( # type: ignore[override] self, query: torch.Tensor, key: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert key is not None + # self.cos_sin_cache here is complex tensor so we cannot cast into + # query's dtype directly with self._match_cos_sin_cache_dtype self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) - query_ = torch.view_as_complex(query.float().reshape( - *query.shape[:-1], -1, 2)) - key_ = torch.view_as_complex(key.float().reshape( - *key.shape[:-1], -1, 2)) + query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) broadcast_shape = [ d if i == 1 or i == (query_.ndim - 1) else 1 for i, d in enumerate(query_.shape) @@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) + + def forward_cuda( # type: ignore[override] + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a091cfb743291..fce110e6a5270 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -1,22 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools from typing import Optional, Union import numpy as np import torch from transformers import PretrainedConfig -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from .base import RotaryEmbedding from .common import apply_rotary_emb_dispatch +from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @triton.jit -def _triton_qwen2vl_mrope_forward( +def _triton_mrope_forward( q_ptr, k_ptr, cos, @@ -31,12 +30,14 @@ def _triton_qwen2vl_mrope_forward( pad_hd: tl.constexpr, mrope_section_t: tl.constexpr, 
mrope_section_h: tl.constexpr, + mrope_section_w: tl.constexpr, + is_interleaved: tl.constexpr, ): # Adapted from # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # This version supports flatten input tensors from vllm # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2) - # instead of (3, bsz, seq_len, head_dim) + # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary pid = tl.program_id(0) # locate start address q_ptr = q_ptr + pid * (n_qh * hd) @@ -48,9 +49,6 @@ def _triton_qwen2vl_mrope_forward( # #################################################################### # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) - t_end = mrope_section_t - h_end = t_end + mrope_section_h - # Updated stride calculation for half head_dim half_rd = rd // 2 t_cos = cos + pid * half_rd @@ -62,9 +60,16 @@ def _triton_qwen2vl_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) - t_mask = cos_offsets < t_end - h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) - w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) + if is_interleaved: + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) + else: + t_end = mrope_section_t + h_end = t_end + mrope_section_h + t_mask = cos_offsets < mrope_section_t + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) @@ -81,21 +86,25 @@ def _triton_qwen2vl_mrope_forward( # program instance (i.e. 
for the current token) separately # #################################################################### # left half of the head - first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) + first_half_q_offsets = ( + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_half_k_offsets = ( + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, - mask=first_q_mask, - other=0).to(sin_row.dtype) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, - mask=first_k_mask, - other=0).to(sin_row.dtype) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( + sin_row.dtype + ) # right half of the head second_half_q_offsets = first_half_q_offsets + (rd // 2) @@ -103,12 +112,12 @@ def _triton_qwen2vl_mrope_forward( second_q_mask = first_q_mask second_k_mask = first_k_mask - q_tile_2 = tl.load(q_ptr + second_half_q_offsets, - mask=second_q_mask, - other=0).to(sin_row.dtype) - k_tile_2 = tl.load(k_ptr + second_half_k_offsets, - mask=second_k_mask, - other=0).to(sin_row.dtype) + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( + sin_row.dtype + ) # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] # Since cos and sin are now half-size, @@ -132,12 +141,13 @@ def triton_mrope( mrope_section: list[int], head_size: int, rotary_dim: int, + mrope_interleaved: bool, ) -> tuple[torch.Tensor, torch.Tensor]: """Qwen2VL mrope kernel. Args: - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] + q: [num_tokens, num_heads * head_size] + k: [num_tokens, num_kv_heads * head_size] cos: [3, num_tokens, head_size //2 ] (T/H/W positions with multimodal inputs) sin: [3, num_tokens, head_size //2 ] @@ -159,7 +169,7 @@ def triton_mrope( cos = cos.contiguous() sin = sin.contiguous() - _triton_qwen2vl_mrope_forward[(n_row, )]( + _triton_mrope_forward[(n_row,)]( q, k, cos, @@ -174,10 +184,23 @@ def triton_mrope( pad_hd, mrope_section[0], mrope_section[1], + mrope_section[2], + mrope_interleaved, ) return q, k +def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. 
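+
+    Concretely, h overwrites indices with i % 3 == 1 up to 3 * mrope_section[1],
+    w overwrites indices with i % 3 == 2 up to 3 * mrope_section[2], and t
+    keeps every remaining index.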
+ """ + x_t = x[0].clone() + x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] + x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] + return x_t + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -190,39 +213,53 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, + mrope_interleaved: bool = False, + # YaRN parameters. + *, + scaling_factor: Optional[float] = None, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + if self.scaling_factor is not None: + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) + else: + self.mscale = 1.0 + # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get # a larger the cos and sin cache. self.cache_max_position_num = max_position_embeddings * 4 - super().__init__(head_size, rotary_dim, self.cache_max_position_num, - base, is_neox_style, dtype) + super().__init__( + head_size, + rotary_dim, + self.cache_max_position_num, + base, + is_neox_style, + dtype, + ) self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 - self.use_triton = current_platform.is_cuda_alike() + def _compute_inv_freq(self, base: float) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_inv_freq(base) + return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base) - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """MRope forward. 
- - Args: - positions: - [num_tokens,] (text only) or - [3, num_tokens] (T/H/W positions with multimodal inputs) - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] - """ - if self.use_triton: - return self.forward_cuda(positions, query, key) - else: - return self.forward_native(positions, query, key) + def _compute_cos_sin_cache(self) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_cos_sin_cache() + return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self) def forward_native( self, @@ -243,37 +280,37 @@ class MRotaryEmbedding(RotaryEmbedding): assert positions.ndim == 1 or positions.ndim == 2 assert key is not None + self._match_cos_sin_cache_dtype(query) num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat([ - m[i] - for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) - ], - dim=-1) - sin = torch.cat([ - m[i] - for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) - ], - dim=-1) + if self.mrope_interleaved: + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -284,10 +321,10 @@ class MRotaryEmbedding(RotaryEmbedding): key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert positions.ndim == 1 or positions.ndim == 2 assert key is not None + self._match_cos_sin_cache_dtype(query) num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -304,25 +341,42 @@ class MRotaryEmbedding(RotaryEmbedding): self.mrope_section, self.head_size, self.rotary_dim, + self.mrope_interleaved, ) return q.reshape(query_shape), k.reshape(key_shape) query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key = key.view(num_tokens, -1, self.head_size) - key_rot = 
key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + + def forward_cpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + @classmethod def get_input_positions( cls, @@ -340,496 +394,22 @@ class MRotaryEmbedding(RotaryEmbedding): image_grid_thw = [] if image_grid_thw is None else image_grid_thw video_grid_thw = [] if video_grid_thw is None else video_grid_thw - second_per_grid_ts = [] if second_per_grid_ts is None else \ - second_per_grid_ts + second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts - llm_positions, mrope_position_delta = \ - cls.get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + llm_positions, mrope_position_delta = cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) return llm_positions.tolist(), mrope_position_delta - @classmethod - def get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - from vllm.transformers_utils.config import thinker_uses_mrope - if thinker_uses_mrope(hf_config): - return cls._omni_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - elif hf_config.model_type in ["glm4v", "glm4v_moe"]: - return cls._glm4v_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - else: - return cls._vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - ) - - 
@classmethod - def _glm4v_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" - - image_token_id = hf_config.image_token_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1]): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - mm_data_idx += 1 - - elif modality_type == "video": - t, h, w = ( - video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - - for t_idx in range(llm_grid_t): - t_index = torch.tensor(t_idx).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view( - 1, -1, 1).expand(1, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view( - 1, 1, -1).expand(1, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + - st_idx) - video_frame_num = 1 - - else: - text_len = len(input_tokens) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() - return llm_positions, 
mrope_position_delta - - @classmethod - def _vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, - "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * - tokens_per_second).long().flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _omni_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: 
Union[list[list[int]], torch.Tensor], - second_per_grid_ts: Optional[list[float]] = None, - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - - Example: - - (V_i are vision position ids, A_i are audio position ids) - - |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... - |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... - """ - - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. - - thinker_config = hf_config.thinker_config - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) - - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] - - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if src_item[idx] not in [ - audio_token_id, video_token_id, image_token_id - ]: - if use_audio_in_video and idx > 0: - if src_item[idx] == vision_end_token_id and \ - src_item[idx - 1] == audio_end_token_id: - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif src_item[idx] == audio_start_token_id and \ - src_item[idx - 1] == vision_start_token_id: - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], - dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - grid_t = image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = image_grid_thw[image_idx].prod() // ( - 
spatial_merge_size**2) - new_src_item.extend([image_token_id] * vision_seqlen) - image_idx += 1 - elif src_item[idx] == video_token_id and not use_audio_in_video: - grid_t = video_grid_thw[video_idx][0] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) - new_src_item.extend([video_token_id] * vision_seqlen) - video_idx += 1 - else: - # read audio from video - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) - grid_t = video_grid_thw[video_idx][0] - grid_h = video_grid_thw[video_idx][1] - grid_w = video_grid_thw[video_idx][2] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 - pure_audio_len = place_num - 2 - added_audio_len = 0 - audio_llm_pos_ids_list: list[torch.Tensor] = [] - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len( - t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - new_src_item.extend([video_token_id] * - vision_ntoken_per_chunk) - vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_chunk, - grid_hs, grid_ws).split(1, dim=1) - llm_pos_ids_list.extend(vision_llm_pos_ids_list) - new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - - added_audio_len) * [audio_token_id]) - audio_start_idx = start_idx if len( - audio_llm_pos_ids_list - ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 - if min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = (torch.arange( - min(t_ntoken_per_chunk, pure_audio_len - - added_audio_len)).expand(3, -1) + - audio_start_idx).split(1, - dim=1) - else: - audio_llm_pos_ids_list = [] - added_audio_len += min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - if added_audio_len < pure_audio_len: - new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id]) - audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand( - 3, -1) + llm_pos_ids_list[-1].max() + 1).split( - 1, dim=1) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - audio_idx += 1 - video_idx += 1 - # move to the next token - idx += len(new_src_item) - new_src_item_len - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = torch.cat(llm_pos_ids_list, - dim=1).max() + 1 - len(src_item) - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @staticmethod - def _get_llm_pos_ids_for_vision( - start_idx: int, - vision_idx: int, - spatial_merge_size: int, - t_index: list[int], - grid_hs: torch.Tensor, - grid_ws: torch.Tensor, - ) -> torch.Tensor: - llm_pos_ids_list = [] - llm_grid_h = grid_hs[vision_idx] // spatial_merge_size - llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = 
(torch.arange(llm_grid_h).view(1, -1, 1).expand( - len(t_index), -1, llm_grid_w).flatten()) - w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( - len(t_index), llm_grid_h, -1).flatten()) - t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( - -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() - _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) - llm_pos_ids_list.append(_llm_pos_ids + start_idx) - llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) - return llm_pos_ids - - @staticmethod - def _split_list_into_ranges(lst: torch.Tensor, - interval: int) -> list[list[int]]: - ranges: list[list[int]] = [[] - for _ in range((max(lst) // interval) + 1)] - for num in lst: - index = num // interval - ranges[index].append(num) - return ranges - @staticmethod def get_next_input_positions( mrope_position_delta: int, @@ -838,68 +418,24 @@ class MRotaryEmbedding(RotaryEmbedding): ) -> list[list[int]]: return [ list( - range(context_len + mrope_position_delta, - seq_len + mrope_position_delta)) for _ in range(3) + range( + context_len + mrope_position_delta, seq_len + mrope_position_delta + ) + ) + for _ in range(3) ] @staticmethod - def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, - mrope_position_delta: int, - context_len: int, num_new_tokens: int): - - values = np.arange(mrope_position_delta + context_len, - mrope_position_delta + context_len + num_new_tokens, - dtype=out.dtype) - out[:, out_offset:out_offset + num_new_tokens] = values - - @classmethod - def omni_get_updates_use_audio_in_video( - cls, - thinker_config: PretrainedConfig, - audio_len: int, - video_grid_thw: Union[list[int], torch.Tensor], - video_second_per_grid_t: float, - ) -> list[int]: - """Get video prompt updates when `use_audio_in_video` is True. - - In this case, audio and vision update ids will be split into - chunks and interleaved (details in `_omni_get_input_positions_tensor`). - - <|video_bos|><|VIDEO|><|video_eos|> => - <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> - """ - - audio_token_id = thinker_config.audio_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) - - grid_t = video_grid_thw[0] - grid_h = video_grid_thw[1] - grid_w = video_grid_thw[2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * video_second_per_grid_t * - tokens_per_second).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) - - updates = [audio_start_token_id] - added_audio_len = 0 - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( - spatial_merge_size**2) - updates.extend([video_token_id] * vision_ntoken_per_chunk) - - audio_chunk_size = min(t_ntoken_per_chunk, - audio_len - added_audio_len) - updates.extend(audio_chunk_size * [audio_token_id]) - added_audio_len += audio_chunk_size - if added_audio_len < audio_len: - updates.extend((audio_len - added_audio_len) * [audio_token_id]) - updates.extend([audio_end_token_id]) - - return updates + def get_next_input_positions_tensor( + out: np.ndarray, + out_offset: int, + mrope_position_delta: int, + context_len: int, + num_new_tokens: int, + ): + values = np.arange( + mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype, + ) + out[:, out_offset : out_offset + num_new_tokens] = values diff --git a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py index 42926bad22ef6..560fb100413d1 100644 --- a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py @@ -10,33 +10,39 @@ from .base import RotaryEmbedding class NTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with fixed and mixed NTK scaling. 
- https://kexue.fm/archives/9706 """ + https://kexue.fm/archives/9706""" - def __init__(self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - mixed_b: Optional[float] = None) -> None: + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + mixed_b: Optional[float] = None, + ) -> None: self.scaling_factor = scaling_factor self.mixed_b = mixed_b - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: base = self.base * (self.scaling_factor if self.mixed_b is None else 1) inv_freq = super()._compute_inv_freq(base) if self.mixed_b is None: - inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) + inv_freq = inv_freq / self.scaling_factor ** (2 / self.rotary_dim) else: - a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / - 2)**self.mixed_b - lambda_1_m = (a * torch.arange( - 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() + a = ( + torch.tensor(self.scaling_factor).log() + / (self.rotary_dim / 2) ** self.mixed_b + ) + lambda_1_m = ( + a * torch.arange(1, self.rotary_dim // 2 + 1).float() ** self.mixed_b + ).exp() inv_freq = inv_freq / lambda_1_m return inv_freq diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py index 9c36d633e2a9f..02ad142d676b7 100644 --- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py @@ -44,14 +44,13 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): self.short_factor = short_factor self.long_factor = long_factor - scale = self.max_position_embeddings / \ - self.original_max_position_embeddings + scale = self.max_position_embeddings / self.original_max_position_embeddings if scale <= 1.0: scaling_factor = 1.0 else: scaling_factor = math.sqrt( - 1 + math.log(scale) / - math.log(self.original_max_position_embeddings)) + 1 + math.log(scale) / math.log(self.original_max_position_embeddings) + ) if short_mscale is None: short_mscale = scaling_factor if long_mscale is None: @@ -61,22 +60,32 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): self.long_mscale = long_mscale short_cache = self._compute_cos_sin_cache( - original_max_position_embeddings, short_factor, short_mscale) + original_max_position_embeddings, short_factor, short_mscale + ) short_cache = short_cache.to(dtype) - long_cache = self._compute_cos_sin_cache(max_position_embeddings, - long_factor, long_mscale) + long_cache = self._compute_cos_sin_cache( + max_position_embeddings, long_factor, long_mscale + ) long_cache = long_cache.to(dtype) long_short_cache = torch.cat([short_cache, long_cache], dim=0) - self.register_buffer("long_short_cos_sin_cache", - long_short_cache, - persistent=False) + self.register_buffer( + "long_short_cos_sin_cache", long_short_cache, persistent=False + ) def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) - inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / 
self.rotary_dim))) + inv_freq = 1.0 / ( + rescale_factors + * ( + self.base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) + / self.rotary_dim + ) + ) + ) return inv_freq def _compute_cos_sin_cache( @@ -105,10 +114,14 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): key = key.view(*key.shape[:-1], -1, self.head_size) k = self.original_max_position_embeddings - long_prompt_offset = (torch.any(positions > k).float() * - torch.full_like(positions, k)).long() - idx = (torch.add(positions, long_prompt_offset) - if long_prompt_offset is not None else positions) + long_prompt_offset = ( + torch.any(positions > k).float() * torch.full_like(positions, k) + ).long() + idx = ( + torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None + else positions + ) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) @@ -116,13 +129,13 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): cos = cos.repeat(1, 2).unsqueeze(-2) sin = sin.repeat(1, 2).unsqueeze(-2) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] query_rot = query_rot * cos + rotate_neox(query_rot) * sin query = torch.cat((query_rot, query_pass), dim=-1) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] key_rot = key_rot * cos + rotate_neox(key_rot) * sin key = torch.cat((key_rot, key_pass), dim=-1) diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py new file mode 100644 index 0000000000000..223350d432674 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def is_rocm_triton_rotary_embedding_enabled() -> bool: + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_TRITON_ROPE + ) + + +def rocm_aiter_rotary_emb_with_key_forward_triton_impl( + positions: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + rotate_style: int = 0, + is_nope_first: bool = False, +) -> None: + import aiter.ops.triton.rope as ops + + ops.rope_cached_thd_positions_2c_fwd_inplace( + query, + key, + cos, + sin, + positions, + rotate_style, + reuse_freqs_front_part=True, + nope_first=is_nope_first, + ) + + +def rocm_aiter_rotary_emb_with_key_forward_triton_fake( + positions: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + rotate_style: int = 0, + is_nope_first: bool = False, +) -> None: + pass + + +if is_rocm_triton_rotary_embedding_enabled(): + direct_register_custom_op( + op_name="rocm_aiter_rotary_emb_with_key_forward_triton", + op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, + mutates_args=["key", "query"], + fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def rocm_aiter_rotary_emb( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + 
rotary_dim: int, + is_neox_style: bool, +): + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( + positions, + sin, + cos, + query_, + key_, + rotate_style, + False, + ) + query = query.view(query_shape) + key = key.view(key_shape) diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py index 851565c5667a4..93c92e7801e13 100644 --- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py @@ -4,8 +4,7 @@ import torch from .base import RotaryEmbedding -from .common import (yarn_find_correction_range, yarn_get_mscale, - yarn_linear_ramp_mask) +from .common import yarn_find_correction_range, yarn_get_mscale, yarn_linear_ramp_mask class YaRNScalingRotaryEmbedding(RotaryEmbedding): @@ -36,33 +35,42 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / - self.rotary_dim) + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + inv_freq_mask = ( + 1 + - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - dtype=torch.float32) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py deleted file mode 100644 index 
e77eb637c8942..0000000000000 --- a/vllm/model_executor/layers/sampler.py +++ /dev/null @@ -1,1198 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A layer that samples the next tokens from the model's outputs.""" -import itertools -from collections.abc import Iterator -from dataclasses import dataclass -from importlib.util import find_spec -from math import inf -from typing import Optional, Union - -import msgspec -import torch -import torch.nn as nn - -import vllm.envs as envs -from vllm.model_executor.layers.utils import apply_penalties -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors, - SequenceGroupToSample) -from vllm.sampling_params import SamplingType -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SequenceOutput) - -if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - # yapf: disable - from flashinfer.sampling import ( - top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) - - # yapf: enable -else: - flashinfer_top_k_top_p_sampling = None - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -def get_sampler() -> torch.nn.Module: - if envs.VLLM_USE_V1: - # Lazy import: the v1 package isn't distributed - from vllm.v1.sample.sampler import Sampler as V1Sampler - return V1Sampler() - return Sampler() - - -# (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = list[tuple[list[int], list[int]]] - -# Types of temporary data structures used for -# computing sample_result -SampleMetadataType = dict[SamplingType, tuple[list[int], - list[SequenceGroupToSample]]] -MultinomialSamplesType = dict[SamplingType, torch.Tensor] -SampleResultsDictType = dict[int, tuple[list[int], list[int]]] - - -# Encapsulates temporary data structures for computing -# sample_result. -# -# * For multi-step scheduling: must be returned -# by `Sampler.forward()` and used later to compute the pythonized -# sample_result -# -# * For single-step scheduling: consumed immediately -# inside `Sampler.forward()` to compute pythonized sample_result. -@dataclass -class SampleResultArgsType: - sample_metadata: SampleMetadataType - multinomial_samples: MultinomialSamplesType - sample_results_dict: SampleResultsDictType - sampling_metadata: SamplingMetadata - greedy_samples: Optional[torch.Tensor] - - -# Union of non-deferred (single-step scheduling) -# vs deferred (multi-step scheduling) -# sample result types -MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] - -# Abbreviation of the _sample() return type -SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] - - -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: list[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. 
- logprobs: Optional["torch.Tensor"] = None - - # Holds either (1) the pythonized sampler result (single-step scheduling) - # or (2) what will be arguments for later deferred pythonization of the - # sampler result (muliti-step scheduling) - deferred_sample_results_args: Optional[SampleResultArgsType] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # On-device tensor containing the sampled token embeddings (embeddings - # corresponding to the sampled token ids). Used when prompt embeddings are - # specified in lieu of prompt token ids or text. - sampled_token_embeds: Optional[torch.Tensor] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: - return iter(self.outputs) - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return (f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr})") - - -class Sampler(nn.Module): - """Samples the next tokens from the model's outputs. - - This layer does the following: - 1. Discard the hidden states that are not used for sampling (i.e., all - tokens except the final one in each prompt). - 2. Compute the logits for the next tokens. - 3. Apply presence, frequency and repetition penalties. - 4. Apply temperature scaling. - 5. Apply top-p and top-k truncation. - 6. Sample the next tokens. - Here, each sequence group within the batch can have different sampling - parameters (e.g., sampling method, temperature, top-p, top-k, etc.). - - The structure of the logits tensor is coupled with the seq_groups in - sampling_metadata. Typically, each sequence in each seq_group has one row in - logits for the next token to be sampled; however, for a seq_group with a - prompt request with the prompt_logprobs sampling parameter, there are rows - in logits for each token in the input prompt. - """ - - def __init__(self): - super().__init__() - - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. 
This is used by - # speculative decoding and when prompt embeddings are specified. - self.include_gpu_probs_tensor = False - self.should_modify_greedy_probs_inplace = False - - def _init_sampling_tensors( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - """The goal here is to reuse sampling tensors between similar decode - runs. This is possible because sampling logic does not change between - decodes of the same sequences. - """ - _, vocab_size = logits.shape - - # First free any existing stored sampling tensors. - # This is necessary because some sampling tensors may - # have pinned memory. - self._sampling_tensors = None - - # Initialize new sampling tensors - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) - - self._sampling_tensors = sampling_tensors - self._do_penalties = do_penalties - self._do_top_p_top_k = do_top_p_top_k - self._do_min_p = do_min_p - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """ - Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor - - Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the - [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] - structure - - Args: - logits: (num_tokens, vocab_size). - sampling_metadata: Metadata for sampling. - """ - assert logits is not None - _, vocab_size = logits.shape - - # Prepare sampling tensors with pinned memory to avoid blocking. - if not sampling_metadata.reuse_sampling_tensors: - self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: - # In this case, the sampling tensors logic depends on - # "output_tokens" of a sequence. As a result, we cannot - # reuse sampling tensors, since "output_tokens" changes - # between decode runs. - self._init_sampling_tensors(logits, sampling_metadata) - - assert self._sampling_tensors is not None - sampling_tensors = self._sampling_tensors - do_penalties = self._do_penalties - do_top_p_top_k = self._do_top_p_top_k - do_min_p = self._do_min_p - - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Apply presence and frequency penalties. - if do_penalties: - logits = apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - # Use float32 to apply temperature scaling. - # Use in-place division to avoid creating a new tensor. - logits = logits.to(torch.float) - logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) - - if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. 
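# Standalone illustrative sketch (toy values, not from the patch): the
# temperature division applied just above flattens (T > 1) or sharpens
# (T < 1) the softmax distribution before sampling.
import torch

toy_logits = torch.tensor([[2.0, 1.0, 0.0]])
for temperature in (0.5, 1.0, 2.0):
    print(temperature, torch.softmax(toy_logits / temperature, dim=-1))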
- maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=self.include_gpu_probs_tensor, - modify_greedy_probs=self._should_modify_greedy_probs_inplace, - ) - - if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs - assert maybe_sampled_tokens_tensor is not None - on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) - else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. - on_device_tensors = None - - # Get the logprobs query results. - prompt_logprobs = None - sample_logprobs = None - if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - prompt_logprobs, sample_logprobs = get_logprobs( - logprobs, sampling_metadata, maybe_deferred_sample_results) - - return _build_sampler_output( - maybe_deferred_sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors, - skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) - - @property - def _should_modify_greedy_probs_inplace(self) -> bool: - """Whether or not the sampler should modify the probability distribution - of greedily-sampled tokens such that multinomial sampling would sample - the greedily-sampled token. - - In other words, if True then we set the probability of the greedily- - sampled token to 1. - - This is used by speculative decoding, which requires that the sampling - method be encoded into the probability distribution. - """ - return self.should_modify_greedy_probs_inplace - - -def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - have not been generated yet - """ - # list of indices in logits that will be set to -inf - logits_to_penalize: list[tuple[int, int]] = [] - logits_applied = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - - sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len( - seq_group.prompt_logprob_indices) - if not seq_group.do_sample: - continue - - start_idx = sample_indices[0] - min_tokens = sampling_params.min_tokens - token_ids_to_penalize = sampling_params.all_stop_token_ids - if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: list[int] = [] - for j, seq_id in enumerate(seq_ids): - seq_data = seq_group.seq_data[seq_id] - if len(seq_data.output_token_ids_array) < min_tokens: - seqs_to_penalize.append(j) - - if seqs_to_penalize: - # convert to the index into logits - seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] - # itertools.product pairs each seq index with every token id - logits_to_penalize.extend( - itertools.product(seqs_to_penalize, token_ids_to_penalize)) - - if logits_to_penalize: - # use zip and * to group indices along each dimension - # eg. 
[ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) - logits[tuple(zip(*logits_to_penalize))] = -float("inf") - - # verifies that no rows in logits were missed unexpectedly - assert logits_applied == logits.shape[0] - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = torch.empty_like(logits_sort).scatter_(dim=-1, - index=logits_idx, - src=logits_sort) - return logits - - -def _apply_min_p( - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Adapted from - https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 - """ - probs = torch.softmax(logits, dim=-1) - top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs - tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill_(tokens_to_remove, -float("inf")) - - return logits - - -def _greedy_sample( - selected_seq_groups: list[SequenceGroupToSample], - samples: torch.Tensor, -) -> SampleResultType: - """Run greedy sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - samples: (num_selected_samples,) A tensor of samples. The length of - samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - samples_lst = samples.tolist() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - num_parent_seqs = len(seq_ids) - assert num_parent_seqs == 1, ( - "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples_lst[sample_idx]] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _random_sample( - selected_seq_groups: list[SequenceGroupToSample], - random_samples: torch.Tensor, -) -> SampleResultType: - """Run random sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - random_samples: (num_selected_samples,) A tensor of samples. The - length of samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # Find the maximum n value of the prompt phase requests. 
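# Standalone illustrative sketch (toy values, not from the patch) of the
# min-p rule implemented by _apply_min_p above: tokens whose probability is
# below min_p * max_prob are masked to -inf.
import torch

toy_logits = torch.tensor([[2.0, 1.0, 0.0, -2.0]])
min_p = torch.tensor([0.2])

probs = torch.softmax(toy_logits, dim=-1)
top_probs, _ = probs.max(dim=-1, keepdim=True)
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
filtered = toy_logits.masked_fill(probs < scaled_min_p, -float("inf"))
print(filtered)  # tensor([[2., 1., -inf, -inf]])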
- random_samples = random_samples.cpu() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - num_parent_seqs = len(seq_ids) - if is_prompt: - # Prompt phase. - parent_ids = [0] * sampling_params.n - next_token_ids = random_samples[ - sample_idx, :sampling_params.n].tolist() - else: - # Generation phase. - parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. -def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[list[SequenceGroupToSample]] = None, -) -> torch.Tensor: - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - q = torch.empty_like(probs) - if seq_groups is None: - q.exponential_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - q[sample_idx:sample_idx + - stride].exponential_(generator=seq_group.generator) - sample_idx += stride - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - -def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - top_ks = top_ks.repeat_interleave(num_samples) - top_ps = top_ps.repeat_interleave(num_samples) - batch_next_token_ids = flashinfer_top_k_top_p_sampling( - probs, - top_ks, - top_ps, - ) - return batch_next_token_ids.view(-1, num_samples) - - -def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: - '''This function consumes GPU-side sampler results and computes - Pythonized CPU-side sampler results (GPU -> CPU sync.) - - Single-step scheduling: this function is invoked at sampling-time - for immediate Pythonization. - - Multi-step scheduling: Pythonization is deferred until after multiple - GPU-side steps have been completed. 
- - Args: - sample_result_args: GPU-side inputs to the Pythonization process - - Returns: - Pythonized sampler results - ''' - - ( - sample_metadata, - sampling_metadata, - greedy_samples, - multinomial_samples, - sample_results_dict, - ) = ( - sample_result_args.sample_metadata, - sample_result_args.sampling_metadata, - sample_result_args.greedy_samples, - sample_result_args.multinomial_samples, - sample_result_args.sample_results_dict, - ) - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - return [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - - -def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - '''Torch-oriented _sample() implementation. - - Single-step scheduling: - * Perform GPU-side sampling computation - * Immediately Pythonize sampling result - - Multi-step scheduling: - * Perform GPU-side sampling computation - * Defer Pythonization & preserve GPU-side - tensors required for Pythonization - ''' - - categorized_seq_group_ids: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: SampleResultsDictType = {} - sample_metadata: SampleMetadataType = {} - multinomial_samples: MultinomialSamplesType = {} - greedy_samples: Optional[torch.Tensor] = None - - # Create output tensor for sampled token ids. - if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), - VLLM_INVALID_TOKEN_ID, - dtype=torch.long, - device=logprobs.device) - else: - sampled_token_ids_tensor = None - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups) - long_sample_indices = sample_indices.long() - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], - dim=-1) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = greedy_samples.unsqueeze(-1) - - if modify_greedy_probs: - # If required, modify the probabilities such that sampling from - # the modified distribution would always sample the argmax - # token id. 
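# Standalone illustrative sketch (toy values, not from the patch): once the
# greedy rows of `probs` are overwritten with a one-hot distribution, as the
# comment above describes, multinomial sampling of those rows can only return
# the argmax token.
import torch

probs = torch.tensor([[0.5, 0.3, 0.2],
                      [0.1, 0.7, 0.2]])
sample_indices = torch.tensor([0, 1])
greedy_samples = probs[sample_indices].argmax(dim=-1)

probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0

drawn = torch.multinomial(probs, num_samples=1).squeeze(1)
assert torch.equal(drawn, greedy_samples)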
- _modify_greedy_probs_inplace(logprobs, probs, - long_sample_indices, - greedy_samples) - - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_n_in_batch = 1 - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_n_in_batch = max(max_n_in_batch, sampling_params.n) - seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else - seq_groups) - - if flashinfer_top_k_top_p_sampling is not None: - logger.warning("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") - - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_n_in_batch, - seq_groups=seq_groups_arg) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[long_sample_indices] = \ - multinomial_samples[sampling_type].to(torch.long) - - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # Encapsulate arguments for computing Pythonized sampler - # results, whether deferred or otherwise. - maybe_deferred_args = SampleResultArgsType( - sampling_metadata=sampling_metadata, - sample_metadata=sample_metadata, - multinomial_samples=multinomial_samples, - greedy_samples=greedy_samples, - sample_results_dict=sample_results_dict) - - if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. - # This also converts the sampler output to a Python object. - # Return Pythonized sampler result & sampled token ids - return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) - - -def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - """ - Args: - probs: (num_query_tokens_in_batch, num_vocab) - logprobs: (num_query_tokens_in_batch, num_vocab) - sampling_metadata: The metadata for a batch for sampling. - sampling_tensors: Tensors that include sampling related metadata. - - Returns: - (next_token_ids, parent_seq_ids) for each seq group in a batch. - If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. - """ - return _sample_with_torch( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=include_gpu_probs_tensor, - modify_greedy_probs=modify_greedy_probs, - ) - - -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """ - This function calculates the ranks of the chosen tokens in a logprob tensor. - - Args: - x (torch.Tensor): 2D logprob tensor of shape (N, M) - where N is the no. of tokens and M is the vocab dim. - indices (torch.Tensor): List of chosen token indices. - - Returns: - torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank - of the chosen token in the input logprob tensor. 
- """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) - - -def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, -) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: - """Return sample logprobs and prompt logprobs. - - The logic consists of 3 parts. - - Select indices to compute logprob from, ranks of token ids, and - the top k token ids from logprobs. - - Compute prompt logprobs if required. - - Compute sample logprobs if required. - - Args: - logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's - logprob per vocab. Sequence groups' query tokens are batched in a - single flattened tensor. For example, assuming there are N - seq groups, it is sorted by prefill tokens for seq_group_1 (if - prompt logprob is enabled), decode tokens for seq_group_1 (if - sampling is required), prefill tokens for seq_group_2, ... - sampling_metadata: The sampling metadata. - sample_results: (num_seq_groups) The tuple of (next_token_ids, - parent_ids) for each sequence group. When beam search is enabled, - sample_results can contain different number of seq_ids from - sampling_metadata.seq_groups. It is because beam search creates - 2 * BEAM_WIDTH number of samples (whereas there are only up to - BEAM_WIDTH number of seq_ids). - - Returns: - A tuple of prompt and sample logprobs per sequence group in a batch. - """ - # The index of query token to calculate logprobs. It includes both - # prompt and sample logprob indices. - query_indices: list[int] = [] - # The next token ids to get the logprob value from. - next_token_ids: list[int] = [] - # The largest requested number of logprobs. We find logprobs as many as the - # largest num logprobs in this API. If every logprobs is None, it will be - # set to -1. - largest_num_logprobs = -1 - - # Select indices to compute logprob from, ranks of token ids, and the top - # k token ids from logprobs. - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - sample_results): - sampling_params = seq_group.sampling_params - - # Update indices and tokens for prompt logprobs. - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - query_indices.extend(seq_group.prompt_logprob_indices) - next_token_ids.extend(next_prompt_tokens) - - # Update indices and next tokenes for sample logprob. - if seq_group.do_sample: - token_ids, parent_seq_ids = sample_result - # NOTE: We cannot directly use sample_indices because - # sample_indices only contain parent seq_ids of a previous step. - # The current step may have different number of seq_ids, and - # we can obtain it from `sample_result[1]`. 
- query_idx = seq_group.sample_indices[0] - query_indices.extend( - [query_idx + parent_id for parent_id in parent_seq_ids]) - next_token_ids.extend(token_ids) - - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - - assert len(next_token_ids) == len(query_indices) - - if len(query_indices) == 0: - empty_sampled_logprob: SampleLogprobs = [] - empty_prompt_logprob: Optional[PromptLogprobs] = None - num_seq_groups = len(sampling_metadata.seq_groups) - return [empty_prompt_logprob - ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups - - selected_logprobs, ranks = None, None - top_logprobs, top_token_ids = None, None - - # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can - # skip the whole logprob calculation. - if largest_num_logprobs >= 0: - query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, - device=logprobs.device) - - # (num_selected_query_tokens, num_logprobs). Note that query_indices can - # contain duplicates if beam search is enabled. - selected_logprobs = logprobs[[ - query_indices_gpu, - next_token_ids_gpu, - ]] - ranks = _get_ranks( - logprobs[query_indices_gpu], - next_token_ids_gpu, - ) - assert selected_logprobs.shape[0] == ranks.shape[0] - - # We need to compute top k only if there exists logprobs > 0. - if largest_num_logprobs > 0: - # Logprobs of topk tokens for a batch of sequence groups. - # (num_query_tokens_across_batch). - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - top_logprobs = top_logprobs.to('cpu') - top_token_ids = top_token_ids.to('cpu') - - selected_logprobs = selected_logprobs.to('cpu') - ranks = ranks.to('cpu') - - # Find prompt/sample logprobs. - prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: list[SampleLogprobs] = [] - top_logprob_idx = 0 - selected_logprobs_idx = 0 - - for seq_group, sample_result in zip(sampling_metadata.seq_groups, - sample_results): - (prompt_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_prompt_logprob_if_needed( - seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, - selected_logprobs_idx, top_logprob_idx) - prompt_logprobs_per_seq_group.append(prompt_logprobs) - - (sampled_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_sampled_logprob_if_needed( - seq_group, sample_result, selected_logprobs, ranks, top_token_ids, - top_logprobs, selected_logprobs_idx, top_logprob_idx) - sample_logprobs_per_seq_group.append(sampled_logprobs) - - return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group - - -def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the prompt logprob from a sequence group if needed.""" - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - - # Find prompt logprobs - prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] - num_logprobs = sampling_params.prompt_logprobs - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - # Pre-select indexes and create a list. It is faster than calling .item - # repetitively. 
- selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - - for idx, token_id in enumerate(next_prompt_tokens): - # Calculate the prompt logprob of the real prompt tokens. - # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: dict[int, tuple[float, int]] = { - token_id: (selected_logprob_items[idx], rank_items[idx]) - } - - # Add top K prompt logprobs along with its rank. - if num_logprobs > 0: - top_ids = top_token_ids[ - top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - prompt_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) - }) - prompt_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in prompt_logprobs_dict.items() - }) - # + 1 to go to the next prompt token. - top_logprob_idx += 1 - - # + len(next_prompt_tokens) to go to the next prompt. - selected_logprobs_idx += len(next_prompt_tokens) - return prompt_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: tuple[list[int], list[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the sample logprob if needed.""" - seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - sampled_logprobs: SampleLogprobs = [] - next_token_ids, parent_seq_ids = sample_result - - if seq_group.do_sample: - assert len(next_token_ids) > 0 - if num_logprobs is None: - for next_token_id in next_token_ids: - # Use a dummy logprob - sampled_logprobs.append({next_token_id: Logprob(inf)}) - else: - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: - (selected_logprob_items[idx], rank_items[idx]) - } - if num_logprobs is not None and num_logprobs > 0: - # Get top K logprobs. - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks) - }) - - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - - # NOTE: This part of code is not intuitive. `selected_logprobs` include - # logprobs for the current step, which has len(next_token_ids) tokens - # per sequence group. 
`logprobs` includes logprobs from the previous - # steps, which has len(seq_ids) tokens per sequence group. - - # Iterate to the next sequence group in a batch. - selected_logprobs_idx += len(next_token_ids) - # Iterate to the next sequence group in a batch. - top_logprob_idx += len(seq_ids) - return sampled_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor) -> None: - """Modify the probability distributions of the greedily-sampled tokens such - that each sampled token has a "probability" of 1.0. This is required by - speculative decoding, which depends on the sampling method being encoded - within the probability distribution for correctness. - - # Why do we only need to do this for greedy sampling? - - vLLM's sampler performs the following steps for greedy or multinomial - (random) sampling: - 1. Get logits from model. - 2. Modify logits according to per-sequence sampling parameters. - - Multiply by temperature, top-k and top-p masking, penalize tokens - according to their frequency, etc. - 3. Sample a token. - - Random sampling simply samples from the modified probability - distribution. - - Greedy sampling performs `argmax` to obtain the token with the - highest likelihood. - - Ignoring greedy sampling for a moment, we find that the computed probability - distribution has the following property: we can sample from it independently - and find that the token sampled by the Sampler has a frequency corresponding - to how often we see it in our sampling. In other words, for tokens sampled - with vLLM's random SamplingType, the computed probability distribution - encodes the sampling methodology completely. - - Greedy sampling does not normally have this property. vLLM modifies logits - according to sampling params, then performs `argmax`, then returns the - sampled token and the computed probability distribution. If we sample from - the distribution, we'll find the likelihood of the greedily-sampled token - is not always 1.0. - - Since lossless speculative decoding requires that the sampling methodology - be encoded within the probability distribution, we are motivated to modify - the probability distribution such that the sampled token has probability 1 - when speculative decoding is used. - - NOTE: Alternatively, we could use an extremely low temperature to achieve - greedy sampling using multinomial computation and unite the codepaths. This - has implications on the overall design of the sampler, e.g. how to record - accurate logprobs for the user, so this improvement is deferred to later. - """ - # NOTE: logprobs are not modified so they can be returned to the user. - probs[sample_indices, :] = 0 - probs[sample_indices, greedy_samples] = 1.0 - - -def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], - sample_logprobs: Optional[list[SampleLogprobs]], - on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]], - skip_sampler_cpu_output: bool = False, -) -> SamplerOutput: - """Construct Python objects with the output of sampling. - - Args: - on_device_tensors: Tuple containing on-device tensors with the - probabilities used in sampling and the sampled token ids. This - allows post-processing without copies to CPU/serialization, e.g. in - speculative decoding rejection sampling. 
- """ - sampler_output: list[CompletionSequenceGroupOutput] = [] - - if skip_sampler_cpu_output: - assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) - deferred_sample_results_args = maybe_deferred_sample_results - else: - assert prompt_logprobs is not None - assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - assert len(sampling_metadata.seq_groups) \ - == len(maybe_deferred_sample_results) \ - == len(prompt_logprobs) \ - == len(sample_logprobs) - deferred_sample_results_args = None - - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - maybe_deferred_sample_results, - prompt_logprobs, sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: list[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, - group_prompt_logprobs)) - - # If not specified, store None values in SamplerOutput. - if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, - sampled_token_ids) = on_device_tensors - else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, - None) - - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs_tensor, - deferred_sample_results_args=deferred_sample_results_args) - - -def _get_next_prompt_tokens( - seq_group: SequenceGroupToSample) -> tuple[int, ...]: - """Get a list of next prompt tokens to compute logprob from a - given sequence group. - - It is used to compute prompt logprob. Imagine you have logprob for each - query token. Query token needs to know the next prompt token id to compute - prompt logprob. This is a helper to obtain next prompt token ids. - - This API has to be used only when the caller knows seq_group is in prefill - stage. - - Returns: - A list of next prompt tokens to compute logprob. - """ - assert seq_group.is_prompt, ( - "Caller should ensure the sequence group is in a prefill stage.") - seq_ids = seq_group.seq_ids - query_len = seq_group.query_len - assert query_len is not None - # prompt has only 1 seq id. - assert len(seq_ids) == 1 - seq_data = seq_group.seq_data[seq_ids[0]] - computed_len = seq_data.get_num_computed_tokens() - prompt_tokens = seq_data.prompt_token_ids - # +1 because we are looking for a next prompt token. 
- next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, - len(prompt_tokens)) - next_prompt_tokens = prompt_tokens[ - next_token_index_start:next_token_index_end] - return next_prompt_tokens diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 2897f75b3129e..e522cc450d6bd 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility methods for model layers.""" + from typing import Callable, Optional import torch from vllm import _custom_ops as ops from vllm import envs -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform from vllm.utils import direct_register_custom_op @@ -24,8 +25,8 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: # This will be used together with triton swiglu kernel shape = w.shape N = shape[-1] - first = w[..., :N // 2] - second = w[..., N // 2:] + first = w[..., : N // 2] + second = w[..., N // 2 :] stacked = torch.stack((first, second), dim=-1) w_shuffled = stacked.reshape(shape) @@ -39,9 +40,9 @@ def get_token_bin_counts_and_mask( ) -> tuple[torch.Tensor, torch.Tensor]: # Compute the bin counts for the tokens. # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) + bin_counts = torch.zeros( + (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device + ) bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 @@ -49,18 +50,21 @@ def get_token_bin_counts_and_mask( return bin_counts, mask -def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: +def apply_penalties( + logits: torch.Tensor, + prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> torch.Tensor: """ Applies penalties in place to the logits tensor logits : The input logits tensor of shape [num_seqs, vocab_size] - prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts - are padded to the maximum prompt length within the batch using - `vocab_size` as the padding value. The value `vocab_size` is used - for padding because it does not correspond to any valid token ID + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID in the vocabulary. output_tokens_tensor: The output tokens tensor. 
presence_penalties: The presence penalties of shape (num_seqs, ) @@ -68,15 +72,17 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties: The repetition penalties of shape (num_seqs, ) """ num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) + _, prompt_mask = get_token_bin_counts_and_mask( + prompt_tokens_tensor, vocab_size, num_seqs + ) output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) + output_tokens_tensor, vocab_size, num_seqs + ) # Apply repetition penalties as a custom op from vllm._custom_ops import apply_repetition_penalties - apply_repetition_penalties(logits, prompt_mask, output_mask, - repetition_penalties) + + apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties) # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details @@ -85,22 +91,27 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def default_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): +def default_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +): return torch.nn.functional.linear(x, weight, bias) def rocm_unquantized_gemm_impl( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 + k = weight.shape[1] - use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ - x.dtype in [torch.float16, torch.bfloat16] \ - and k % 8 == 0 and bias is None) + use_skinny = ( + envs.VLLM_ROCM_USE_SKINNY_GEMM + and on_gfx9() + and x.dtype in [torch.float16, torch.bfloat16] + and k % 8 == 0 + ) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) @@ -111,51 +122,84 @@ def rocm_unquantized_gemm_impl( cu_count = current_platform.get_cu_count() if m > 8 and 0 < n <= 4: - out = ops.wvSplitK(weight, x_view, cu_count) + out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.view(*x.shape[:-1], weight.shape[0]) - elif m % 4 == 0 and n == 1 and k <= 8192: + elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: out = ops.LLMM1(weight, x_view, 4) return out.view(*x.shape[:-1], weight.shape[0]) return torch.nn.functional.linear(x, weight, bias) def rocm_unquantized_gemm_impl_fake( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: return x.new_empty((*x.shape[:-1], weight.shape[0])) -def rocm_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def rocm_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) direct_register_custom_op( op_name="rocm_unquantized_gemm_impl", op_func=rocm_unquantized_gemm_impl, - mutates_args=[], fake_impl=rocm_unquantized_gemm_impl_fake, - dispatch_key=current_platform.dispatch_key, ) -def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype): 
- return (torch._C._cpu._is_amx_tile_supported() - and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 - and n % 16 == 0) +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool: + return ( + torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) + and k % 32 == 0 + and n % 16 == 0 + ) -def cpu_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): - if getattr(layer, "use_cpu_sgl", False): - return torch.ops._C.weight_packed_linear(x, weight, bias, True) +def dispatch_cpu_unquantized_gemm( + layer: torch.nn.Module, + remove_weight: bool, +) -> None: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype): + packed_weight = torch.ops._C.convert_weight_packed(layer.weight) + if getattr(layer, "bias", None) is not None: + bias_f32 = layer.bias.to(torch.float32) + else: + bias_f32 = None + layer.cpu_linear = lambda x, weight, bias: torch.ops._C.weight_packed_linear( + x, packed_weight, bias_f32 if bias is not None else None, True + ) + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + elif ops._supports_onednn and ( + current_platform.get_cpu_architecture() == CpuArchEnum.X86 + or ops.is_onednn_acl_supported() + ): + origin_weight = layer.weight + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + handler = ops.create_onednn_mm(origin_weight.t(), 32) + layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) else: - return torch.nn.functional.linear(x, weight, bias) + layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( + x, weight, bias + ) + + +def cpu_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +): + return layer.cpu_linear(x, weight, bias) def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9f223998e554f..b7253c7f0e523 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -9,12 +9,18 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs @@ -26,59 +32,73 @@ DEFAULT_VOCAB_PADDING_SIZE = 64 class UnquantizedEmbeddingMethod(QuantizeMethodBase): """Unquantized method for embeddings.""" - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - 
**extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """Create weights for embedding layer.""" - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm + + dispatch_cpu_unquantized_gemm(layer, remove_weight=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) - def embedding(self, layer: torch.nn.Module, - input_: torch.Tensor) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: return F.embedding(input_, layer.weight) -def pad_vocab_size(vocab_size: int, - pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: +def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, - rank: int, - offset: int = 0) -> Sequence[int]: + per_partition_vocab_size: int, rank: int, offset: int = 0 +) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f + offset, index_l + offset -def vocab_range_from_global_vocab_size(global_vocab_size: int, - rank: int, - world_size: int, - offset: int = 0) -> Sequence[int]: +def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int, offset: int = 0 +) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, - offset=offset) + return vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, offset=offset + ) @dataclass class VocabParallelEmbeddingShardIndices: """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int padded_org_vocab_end_index: int padded_added_vocab_start_index: int @@ -99,13 +119,11 @@ class VocabParallelEmbeddingShardIndices: @property def num_org_elements_padded(self) -> int: - return (self.padded_org_vocab_end_index - - self.padded_org_vocab_start_index) + return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index @property def num_added_elements_padded(self) -> int: - return (self.padded_added_vocab_end_index - - self.padded_added_vocab_start_index) + return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index @property def num_org_vocab_padding(self) -> int: @@ -121,17 +139,14 @@ class VocabParallelEmbeddingShardIndices: def __post_init__(self): # sanity checks - assert (self.padded_org_vocab_start_index - <= 
self.padded_org_vocab_end_index) - assert (self.padded_added_vocab_start_index - <= self.padded_added_vocab_end_index) + assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index + assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index assert self.org_vocab_start_index <= self.org_vocab_end_index assert self.added_vocab_start_index <= self.added_vocab_end_index assert self.org_vocab_start_index <= self.padded_org_vocab_start_index - assert (self.added_vocab_start_index - <= self.padded_added_vocab_start_index) + assert self.added_vocab_start_index <= self.padded_added_vocab_start_index assert self.org_vocab_end_index <= self.padded_org_vocab_end_index assert self.added_vocab_end_index <= self.padded_added_vocab_end_index @@ -141,20 +156,27 @@ class VocabParallelEmbeddingShardIndices: @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask( - input_: torch.Tensor, org_vocab_start_index: int, - org_vocab_end_index: int, num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: + input_: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, +) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & ( - input_ < org_vocab_end_index) + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) added_vocab_mask = (input_ >= added_vocab_start_index) & ( - input_ < added_vocab_end_index) - added_offset = added_vocab_start_index - ( - org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding - valid_offset = (org_vocab_start_index * - org_vocab_mask) + (added_offset * added_vocab_mask) + input_ < added_vocab_end_index + ) + added_offset = ( + added_vocab_start_index + - (org_vocab_end_index - org_vocab_start_index) + - num_org_vocab_padding + ) + valid_offset = (org_vocab_start_index * org_vocab_mask) + ( + added_offset * added_vocab_mask + ) vocab_mask = org_vocab_mask | added_vocab_mask input_ = vocab_mask * (input_ - valid_offset) return input_, ~vocab_mask @@ -200,14 +222,16 @@ class VocabParallelEmbedding(CustomOp): prefix: full name of the layer in the state dict """ # noqa: E501 - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # Keep the input dimensions. 
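For reference, a minimal eager-mode sketch of the index arithmetic that `get_masked_input_and_mask` performs in the hunk above (the helper name `remap_for_shard` and all numbers are invented for illustration; the real function is wrapped in `torch.compile` so the pointwise ops fuse into one kernel):

```python
import torch

def remap_for_shard(input_ids, org_start, org_end, num_org_pad, added_start, added_end):
    org_mask = (input_ids >= org_start) & (input_ids < org_end)
    added_mask = (input_ids >= added_start) & (input_ids < added_end)
    # Added-vocab rows sit right after the padded original-vocab rows.
    added_offset = added_start - (org_end - org_start) - num_org_pad
    valid_offset = org_start * org_mask + added_offset * added_mask
    vocab_mask = org_mask | added_mask
    local_ids = vocab_mask * (input_ids - valid_offset)
    return local_ids, ~vocab_mask  # True -> token owned by another rank

# Toy shard layout: org tokens 4..7 in rows 0..3, two padding rows, added tokens 16..17 in rows 6..7.
ids = torch.tensor([5, 16, 2])
local, other_rank = remap_for_shard(ids, 4, 8, 2, 16, 18)
print(local)       # tensor([1, 6, 0])
print(other_rank)  # tensor([False, False,  True])
```

Token 2 belongs to another rank, so it maps to row 0 and its mask entry comes back True; the forward path later zeroes that row before the all-reduce.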
@@ -217,18 +241,22 @@ class VocabParallelEmbedding(CustomOp): self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, - self.padding_size) + self.org_vocab_size_padded = pad_vocab_size( + self.org_vocab_size, self.padding_size + ) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, - self.padding_size) + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) assert self.org_vocab_size_padded <= self.num_embeddings_padded - self.shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) self.embedding_dim = embedding_dim quant_method = None @@ -242,70 +270,87 @@ class VocabParallelEmbedding(CustomOp): # layer type like ParallelLMHead, this is not important. is_embedding_layer = type(self) is VocabParallelEmbedding quant_method_implements_embedding = method_has_implemented_embedding( - type(quant_method)) + type(quant_method) + ) if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod.") + "the 'embedding' method, see UnquantizedEmbeddingMethod." + ) self.quant_method: QuantizeMethodBase = quant_method if params_dtype is None: params_dtype = torch.get_default_dtype() - # Divide the weight matrix along the vocaburaly dimension. + # Divide the weight matrix along the vocabulary dimension. 
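A quick worked example of the padding and sharding arithmetic used in `__init__` above (hypothetical sizes; `pad_vocab_size` is restated only to keep the sketch self-contained):

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size, num_added, tp_size = 32000, 10, 2
org_padded = pad_vocab_size(org_vocab_size)            # 32000 (already a multiple of 64)
total_padded = pad_vocab_size(org_padded + num_added)  # 32064
per_partition = total_padded // tp_size                # 16032 rows per rank

for rank in range(tp_size):
    start = org_padded // tp_size * rank
    end = org_padded // tp_size * (rank + 1)
    print(f"rank {rank}: original-vocab rows [{start}, {end})")
# rank 0: original-vocab rows [0, 16000)
# rank 1: original-vocab rows [16000, 32000)
```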
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide(self.num_embeddings_padded, - self.tp_size) - assert (self.shard_indices.num_elements_padded == - self.num_embeddings_per_partition) + self.num_embeddings_per_partition = divide( + self.num_embeddings_padded, self.tp_size + ) + assert ( + self.shard_indices.num_elements_padded == self.num_embeddings_per_partition + ) self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index) + self.shard_indices.org_vocab_end_index + - self.shard_indices.org_vocab_start_index + ) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index) + self.shard_indices.added_vocab_end_index + - self.shard_indices.added_vocab_start_index + ) - self.quant_method.create_weights(self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) @classmethod - def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, - vocab_size: int, org_vocab_size: int, tp_rank: int, - tp_size: int) -> VocabParallelEmbeddingShardIndices: + def _get_indices( + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: """Get start and end indices for vocab parallel embedding, following the layout outlined in the class docstring, based on the given tp_rank and tp_size.""" num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded padded_org_vocab_start_index, padded_org_vocab_end_index = ( - vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, - tp_size)) + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size) + ) padded_added_vocab_start_index, padded_added_vocab_end_index = ( - vocab_range_from_global_vocab_size(num_added_embeddings_padded, - tp_rank, - tp_size, - offset=org_vocab_size)) + vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + ) # remove padding - org_vocab_start_index = min(padded_org_vocab_start_index, - org_vocab_size) + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) - added_vocab_start_index = min(padded_added_vocab_start_index, - vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) return VocabParallelEmbeddingShardIndices( - padded_org_vocab_start_index, padded_org_vocab_end_index, - padded_added_vocab_start_index, padded_added_vocab_end_index, - org_vocab_start_index, org_vocab_end_index, - added_vocab_start_index, added_vocab_end_index) + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) def get_sharded_to_full_mapping(self) -> Optional[list[int]]: """Get a mapping that can be used to reindex the gathered 
logits for sampling. - + During sampling, we gather logits from all ranks. The relationship of index->token_id will follow the same format as outlined in the class docstring. However, after the gather, we want to reindex the final @@ -320,32 +365,49 @@ class VocabParallelEmbedding(CustomOp): added_embeddings: list[int] = [] padding: list[int] = [] for tp_rank in range(self.tp_size): - shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) + shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) range_start = self.num_embeddings_per_partition * tp_rank range_end = self.num_embeddings_per_partition * (tp_rank + 1) base_embeddings.extend( - range(range_start, - range_start + shard_indices.num_org_elements)) + range(range_start, range_start + shard_indices.num_org_elements) + ) padding.extend( - range(range_start + shard_indices.num_org_elements, - range_start + shard_indices.num_org_elements_padded)) + range( + range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded, + ) + ) added_embeddings.extend( range( range_start + shard_indices.num_org_elements_padded, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements)) + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + ) + ) padding.extend( range( - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded)) - assert (range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded == range_end) + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded, + ) + ) + assert ( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded + == range_end + ) ret = base_embeddings + added_embeddings + padding assert len(ret) == self.num_embeddings_padded return ret @@ -379,10 +441,14 @@ class VocabParallelEmbedding(CustomOp): # If param packed on the same dim we are sharding on, then # need to adjust offsets of loaded weight by pack_factor. if packed_dim is not None and packed_dim == output_dim: - packed_factor = param.packed_factor if isinstance( - param, BasevLLMParameter) else param.pack_factor - assert loaded_weight.shape[output_dim] == (self.org_vocab_size // - param.packed_factor) + packed_factor = ( + param.packed_factor + if isinstance(param, BasevLLMParameter) + else param.pack_factor + ) + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size // param.packed_factor + ) start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: @@ -390,23 +456,24 @@ class VocabParallelEmbedding(CustomOp): # Copy the data. Select chunk corresponding to current shard. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) + param[: loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0] :].data.fill_(0) - def forward(self, input_): + def forward_native(self, input_): if self.tp_size > 1: # Build the mask. 
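The `forward_native` path built in the next hunk composes the index remapping, the local embedding lookup, the zeroing of rows owned by other ranks, and the all-reduce. A self-contained eager sketch of that composition, with made-up shapes and a plain sum standing in for `tensor_model_parallel_all_reduce`:

```python
import torch

vocab, dim, tp = 8, 4, 2
weight = torch.arange(vocab * dim, dtype=torch.float32).view(vocab, dim)
shards = weight.chunk(tp, dim=0)          # 4 rows per rank; no padding in this toy case
ids = torch.tensor([1, 6])                # each token lives on a different shard

partials = []
for rank, shard in enumerate(shards):
    start, end = rank * 4, (rank + 1) * 4
    not_mine = (ids < start) | (ids >= end)
    local = (ids - start).clamp(min=0) * (~not_mine)   # unowned tokens -> row 0
    out = shard[local].masked_fill(not_mine.unsqueeze(-1), 0)
    partials.append(out)

output = torch.stack(partials).sum(dim=0)  # stand-in for the all-reduce
assert torch.equal(output, weight[ids])
```

Because every rank contributes zeros for tokens it does not own, the sum across ranks recovers exactly the rows of the full embedding table.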
masked_input, input_mask = get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, + input_, + self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) + self.shard_indices.added_vocab_end_index, + ) else: masked_input = input_ # Get the embeddings. - output_parallel = self.quant_method.embedding(self, - masked_input.long()) + output_parallel = self.quant_method.embedding(self, masked_input.long()) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) @@ -414,15 +481,19 @@ class VocabParallelEmbedding(CustomOp): output = tensor_model_parallel_all_reduce(output_parallel) return output + def forward_cuda(self, input_): + return self.forward_native(input_) + def extra_repr(self) -> str: s = f"num_embeddings={self.num_embeddings_per_partition}" s += f", embedding_dim={self.embedding_dim}" s += f", org_vocab_size={self.org_vocab_size}" - s += f', num_embeddings_padded={self.num_embeddings_padded}' - s += f', tp_size={self.tp_size}' + s += f", num_embeddings_padded={self.num_embeddings_padded}" + s += f", tp_size={self.tp_size}" return s +@CustomOp.register("parallel_lm_head") class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. @@ -439,27 +510,38 @@ class ParallelLMHead(VocabParallelEmbedding): padding_size: padding size for the vocabulary. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config, - prefix) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + num_embeddings, + embedding_dim, + params_dtype, + org_num_embeddings, + padding_size, + quant_config, + prefix, + ) self.quant_config = quant_config if bias: self.bias = Parameter( - torch.empty(self.num_embeddings_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.num_embeddings_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3e..df0d059594a76 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -5,21 +5,24 @@ from typing import Literal, Optional from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.bitsandbytes_loader import ( - BitsAndBytesModelLoader) +from 
vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader from vllm.model_executor.model_loader.runai_streamer_loader import ( - RunaiModelStreamerLoader) -from vllm.model_executor.model_loader.sharded_state_loader import ( - ShardedStateLoader) + RunaiModelStreamerLoader, +) +from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.utils import ( - get_architecture_class_name, get_model_architecture, get_model_cls) + get_architecture_class_name, + get_model_architecture, + get_model_cls, +) logger = init_logger(__name__) @@ -67,8 +70,11 @@ def register_model_loader(load_format: str): load_format (str): The model loader format name. Examples: - >>> from vllm.config import LoadConfig - >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader + >>> from vllm.config.load import LoadConfig + >>> from vllm.model_executor.model_loader import ( + ... get_model_loader, + ... register_model_loader, + ... ) >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> >>> @register_model_loader("my_loader") @@ -88,14 +94,20 @@ def register_model_loader(load_format: str): if load_format in _LOAD_FORMAT_TO_MODEL_LOADER: logger.warning( "Load format `%s` is already registered, and will be " - "overwritten by the new loader class `%s`.", load_format, - model_loader_cls) + "overwritten by the new loader class `%s`.", + load_format, + model_loader_cls, + ) if not issubclass(model_loader_cls, BaseModelLoader): - raise ValueError("The model loader must be a subclass of " - "`BaseModelLoader`.") + raise ValueError( + "The model loader must be a subclass of `BaseModelLoader`." 
+ ) _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls - logger.info("Registered model loader `%s` with load format `%s`", - model_loader_cls, load_format) + logger.info( + "Registered model loader `%s` with load format `%s`", + model_loader_cls, + load_format, + ) return model_loader_cls return _wrapper @@ -109,14 +121,13 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config) -def get_model(*, - vllm_config: VllmConfig, - model_config: Optional[ModelConfig] = None) -> nn.Module: +def get_model( + *, vllm_config: VllmConfig, model_config: Optional[ModelConfig] = None +) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config - return loader.load_model(vllm_config=vllm_config, - model_config=model_config) + return loader.load_model(vllm_config=vllm_config, model_config=model_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960d..6106a1ab8a85c 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -5,10 +5,14 @@ from abc import ABC, abstractmethod import torch import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -25,24 +29,26 @@ class BaseModelLoader(ABC): raise NotImplementedError @abstractmethod - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: - """Load weights into a model. This standalone API allows + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + """Load weights into a model. 
This standalone API allows inplace weights loading for an already-initialized model""" raise NotImplementedError - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config load_config = vllm_config.load_config - load_device = device_config.device if load_config.device is None else \ - load_config.device + load_device = ( + device_config.device if load_config.device is None else load_config.device + ) target_device = torch.device(load_device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config) + model = initialize_model( + vllm_config=vllm_config, model_config=model_config + ) logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index b8393956eed3f..d41b8ae55ea5f 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -16,39 +16,45 @@ from packaging import version from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -# yapf: enable +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import (ParamMapping, - set_default_torch_dtype) +from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - pt_weights_iterator, safetensors_weights_iterator) + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + pt_weights_iterator, + safetensors_weights_iterator, +) from vllm.model_executor.models import is_pooling_model -from vllm.model_executor.utils import (get_moe_expert_mapping, - get_packed_modules_mapping, - set_weight_attrs) +from vllm.model_executor.utils import ( + get_moe_expert_mapping, + get_packed_modules_mapping, + set_weight_attrs, +) from vllm.platforms import current_platform -# yapf conflicts with isort for this block - logger = init_logger(__name__) def is_moe_model(model: torch.nn.Module) -> bool: """Checks if the model contains FusedMoE layers.""" - return bool(any( - isinstance(module, FusedMoE) for module in model.modules())) + return bool(any(isinstance(module, 
FusedMoE) for module in model.modules())) class BitsAndBytesModelLoader(BaseModelLoader): @@ -69,6 +75,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Store all module names (from transformers) that support # BNB quantization. self.target_modules: list[str] = [] + self.tp_disabled_modules: list[str] = [] # Store the mapping of expert parameters for MoE models. self.expert_params_mapping: list[tuple[str, str, int, str]] = [] # mapping weight names from transformers to vllm. @@ -90,8 +97,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): if is_local: for pattern in allowed_patterns: - weight_files = glob.glob( - os.path.join(model_name_or_path, pattern)) + weight_files = glob.glob(os.path.join(model_name_or_path, pattern)) if weight_files: return model_name_or_path, weight_files, pattern else: @@ -107,20 +113,24 @@ class BitsAndBytesModelLoader(BaseModelLoader): revision, ignore_patterns=self.load_config.ignore_patterns, ) - return hf_folder, glob.glob( - os.path.join(hf_folder, pattern)), pattern + return ( + hf_folder, + glob.glob(os.path.join(hf_folder, pattern)), + pattern, + ) - raise RuntimeError( - f"No model weights found in: `{model_name_or_path}`") + raise RuntimeError(f"No model weights found in: `{model_name_or_path}`") - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> tuple[list[str], bool]: + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> tuple[list[str], bool]: """Prepare weight files for the model.""" allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( - model_name_or_path, allowed_patterns, revision) + model_name_or_path, allowed_patterns, revision + ) use_safetensors = matched_pattern == "*.safetensors" is_local = os.path.isdir(model_name_or_path) @@ -139,25 +149,27 @@ class BitsAndBytesModelLoader(BaseModelLoader): revision, ) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) + hf_weights_files, hf_folder, index_file + ) else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + f"Cannot find any model weights with `{model_name_or_path}`" + ) return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - def _maybe_pool_model(module_name: str): # For pool model, we need to add the prefix `model.` # for the weight name if possible. - if self.is_pool_model and self.target_modules[0]. \ - startswith("model.") and not module_name.startswith( - "model."): + if ( + self.is_pool_model + and self.target_modules[0].startswith("model.") + and not module_name.startswith("model.") + ): return "model." 
+ module_name return module_name @@ -185,8 +197,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): self, model_name_or_path: str, revision: Optional[str], - ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, - Any]]: + ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, as well as the quantization state dictionary.""" @@ -194,37 +205,41 @@ class BitsAndBytesModelLoader(BaseModelLoader): try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." + ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." + ) from err hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision) + model_name_or_path, revision + ) quant_state_dict: dict[str, Any] = {} if self.pre_quant: if self.load_8bit: return self._quantized_8bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict else: return self._quantized_4bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict - return self._unquantized_generator(hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + return self._unquantized_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict def _is_8bit_weight_name(self, weight_name: str): quantized_suffix = {".scb", ".weight_format"} - return any(weight_name.lower().endswith(suffix) - for suffix in quantized_suffix) + return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix) def _is_4bit_weight_name(self, weight_name: str): quantized_suffix = { @@ -237,12 +252,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): suffix = weight_name.split(".")[-1] return any(q_suffix in suffix for q_suffix in quantized_suffix) - def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: + def _quantized_8bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if not mapped_weight_name.lower().endswith(".scb"): continue @@ -251,9 +267,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict[weight_key] = weight_tensor for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if self._is_8bit_weight_name(mapped_weight_name): continue @@ -264,18 +280,18 @@ class BitsAndBytesModelLoader(BaseModelLoader): else: yield org_weight_name, weight_tensor - def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> 
Generator: + def _quantized_4bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: from bitsandbytes.functional import QuantState # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) + weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) temp_state_dict = {} for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in weight_iterator: if not self._is_4bit_weight_name(mapped_weight_name): continue @@ -287,87 +303,111 @@ class BitsAndBytesModelLoader(BaseModelLoader): temp_state_dict[mapped_weight_name] = weight_tensor # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: dict) -> QuantState: + def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState: quant_state = {} for k in temp_state_dict: if param_name + "." in k: quant_state[k] = temp_state_dict[k] - return QuantState.from_dict(quant_state, - device=current_platform.device_type) + return QuantState.from_dict( + quant_state, device=current_platform.device_type + ) # Second iterate over all prequant and normal weights # pre quantized weights would have a quant_state for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if self._is_4bit_weight_name(mapped_weight_name): continue - if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" - in temp_state_dict) or ( - f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" - in temp_state_dict): - quant_state = _parse_quant_state(mapped_weight_name, - temp_state_dict) + if ( + f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict + ) or ( + f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict + ): + quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict) quant_state_dict[mapped_weight_name] = quant_state yield org_weight_name, weight_tensor else: yield org_weight_name, weight_tensor - def _unquantized_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: + def _unquantized_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: from bitsandbytes.functional import quantize_4bit - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - + global_tp_size = get_tensor_model_parallel_world_size() + global_tp_rank = get_tensor_model_parallel_rank() + check_match = ( + lambda weight_name, module_name: weight_name.removesuffix(".weight") + == module_name + ) for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if any(target_module in mapped_weight_name - for target_module in self.target_modules - ) and mapped_weight_name.endswith(".weight"): + # override tp_size and tp_rank if the module has disabled TP + if any( + tp_disabled_module in mapped_weight_name + for tp_disabled_module in self.tp_disabled_modules + ): + tp_size = 1 + tp_rank = 0 + else: + tp_size = global_tp_size + tp_rank = global_tp_rank + + if any( + target_module in mapped_weight_name + for target_module in self.target_modules + ) and mapped_weight_name.endswith(".weight"): # Without sharding if any( - mapped_weight_name.startswith(module) - for module 
in self.unsharded_weights_modules): + check_match(mapped_weight_name, module) + for module in self.unsharded_weights_modules + ): weight_sub_tensor = weight_tensor # Shard by column elif any( - mapped_weight_name.startswith(module) - for module in self.column_sharded_weights_modules): + check_match(mapped_weight_name, module) + for module in self.column_sharded_weights_modules + ): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[..., - start_index:end_index] + weight_sub_tensor = weight_tensor[..., start_index:end_index] # Weights have fused on disk. In this case, we assume that the # weight and module use same name. elif any( - mapped_weight_name.startswith(module) - for module in self.maybe_fused_weights_modules): + check_match(mapped_weight_name, module) + for module in self.maybe_fused_weights_modules + ): # special case for fused weights # get the size of each shard weight tensor total_shard_sizes = next( - (sizes for module, sizes in - self.maybe_fused_weights_modules.items() - if mapped_weight_name.startswith(module))) + ( + sizes + for module, sizes in self.maybe_fused_weights_modules.items() # noqa: E501 + if check_match(mapped_weight_name, module) + ) + ) total_size = weight_tensor.size(0) assert total_size == sum(total_shard_sizes) # get the start/end index of each shard weight tensor total_start_index = list( - itertools.accumulate([0] + total_shard_sizes))[:-1] - shard_weights_index = [( - idx + size // tp_size * tp_rank, - idx + size // tp_size * (tp_rank + 1), - ) for idx, size in zip(total_start_index, - total_shard_sizes)] + itertools.accumulate([0] + total_shard_sizes) + )[:-1] + shard_weights_index = [ + ( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) + for idx, size in zip(total_start_index, total_shard_sizes) + ] # slice and reorder the weight tensor weight_tensor = [ weight_tensor[start_index:end_index, ...] @@ -379,15 +419,15 @@ class BitsAndBytesModelLoader(BaseModelLoader): total_size = weight_tensor.size(0) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[start_index:end_index, - ...] + weight_sub_tensor = weight_tensor[start_index:end_index, ...] # bitsandbytes requires data in GPU if weight_sub_tensor.is_cuda: loaded_weight = weight_sub_tensor else: loaded_weight = weight_sub_tensor.to( - device=current_platform.device_type) + device=current_platform.device_type + ) # remove the following after the issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 @@ -408,62 +448,70 @@ class BitsAndBytesModelLoader(BaseModelLoader): def _get_bnb_target_modules(self, model: nn.Module) -> None: """ - Identify and collect all modules that support BitsAndBytes + Identify and collect all modules that support BitsAndBytes quantization. """ for name, module in model.named_modules(): - if (isinstance(module, LinearBase) - and hasattr(module.quant_method, "quant_config")): + if isinstance(module, LinearBase) and hasattr( + module.quant_method, "quant_config" + ): if modules_info := self.modules_mapping.get_sub_modules(name): # Map vllm's names to transformers's names. 
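The sharding rules in `_unquantized_generator` above reduce to two slicing patterns once a module's category is known: column-sharded modules are cut along the last dimension, everything else along dim 0 (unsharded and fused modules are handled separately). A toy sketch of those two slices with an invented tensor and TP setup:

```python
import torch

tp_size, tp_rank = 2, 1
weight = torch.randn(8, 6)   # a stand-in checkpoint tensor

# Column-sharded modules (e.g. row-parallel linears): slice the last dimension.
cols = weight.size(-1)
col_shard = weight[..., cols // tp_size * tp_rank : cols // tp_size * (tp_rank + 1)]

# Default case: slice dim 0.
rows = weight.size(0)
row_shard = weight[rows // tp_size * tp_rank : rows // tp_size * (tp_rank + 1), ...]

print(col_shard.shape, row_shard.shape)  # torch.Size([8, 3]) torch.Size([4, 6])
```

Modules listed in `tp_disabled_modules` skip this entirely by forcing `tp_size=1, tp_rank=0`, so they always receive the full tensor.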
rep_name, sub_modules = modules_info for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) + new_name = name.replace(rep_name, sub_name) + self.target_modules.append(new_name) + if module.disable_tp: + self.tp_disabled_modules.append(new_name) # Add original module name even if the module has stacked map, # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + if module.disable_tp: + self.tp_disabled_modules.append(name) elif isinstance(module, FusedMoE) and hasattr( - module.quant_method, "quant_config"): + module.quant_method, "quant_config" + ): # TODO: support FusedMoE with prequant and 8bit. if self.pre_quant and self.load_8bit: raise ValueError( "Prequant BitsAndBytes 8bit models with FusedMoE " - "is not supported yet.") + "is not supported yet." + ) # Get the corresponding weight name using module name and # expert_params_mapping. for exp in self.expert_params_mapping: weight_name = exp[1] - rep_name = name.replace("experts", - "") + weight_name.removesuffix(".") + rep_name = name.replace("experts", "") + weight_name.removesuffix( + "." + ) self.target_modules.append(rep_name) - assert (self.target_modules - ), "vLLM currently does not support BNB quantization for" + assert self.target_modules, ( + "vLLM currently does not support BNB quantization for" + ) f" {type(model).__name__}" def _classify_module_sharding(self, model: nn.Module): """ - Categorize modules based on their weight sharding requirements + Categorize modules based on their weight sharding requirements for tensor parallelism. """ for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new # static variable in the model implementation. - if isinstance(module, (ReplicatedLinear, )): + if isinstance(module, (ReplicatedLinear,)): self.unsharded_weights_modules.append(name) # `QKVParallelLinear` and `MergedColumnParallelLinear` might have # fused weights on disk. We need to use the output sizes of these # modules to shard the weights correctly. - elif isinstance(module, - (QKVParallelLinear, MergedColumnParallelLinear)): + elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)): self.maybe_fused_weights_modules[name] = module.output_sizes # In TP, these weights are partitioned along the column # dimension (dim=-1) - elif isinstance(module, (RowParallelLinear, )): + elif isinstance(module, (RowParallelLinear,)): self.column_sharded_weights_modules.append(name) elif isinstance(module, FusedMoE): expert_mapping = self.expert_params_mapping @@ -471,48 +519,52 @@ class BitsAndBytesModelLoader(BaseModelLoader): if exp[-1] == "w2": weight_name = exp[1] rep_name = name.replace( - "experts", "") + weight_name.removesuffix(".") + "experts", "" + ) + weight_name.removesuffix(".") self.column_sharded_weights_modules.append(rep_name) - def _verify_model_compatibility(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _verify_model_compatibility( + self, model: nn.Module, model_config: ModelConfig + ) -> None: """ Verify that the model is compatible with BitsAndBytes quantization. """ if not hasattr(model, "load_weights"): raise AttributeError( "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") + f" {type(model).__name__}." 
+ ) if not hasattr(model, "packed_modules_mapping"): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") + "quantization yet. No 'packed_modules_mapping' found." + ) - quant_config = getattr(model_config.hf_config, "quantization_config", - None) - if quant_config is not None: - quant_method = quant_config.get("quant_method") + quant_config = getattr(model_config.hf_config, "quantization_config", None) + if quant_config and (quant_method := quant_config.get("quant_method")): if quant_method == "bitsandbytes": self.pre_quant = True else: raise ValueError( - f"BitsAndBytes loader does not support {quant_method} " - "quantization") + f"BitsAndBytes loader does not support {quant_method} quantization" + ) # The quant_states in pre_quantized models cannot work with a split # weight tensor. So TP does not work with pre_quantized bnb models. if self.pre_quant and get_tensor_model_parallel_world_size() > 1: raise ValueError( "Prequant BitsAndBytes models with tensor parallelism is not " - "supported. Please try with pipeline parallelism.") - if self.pre_quant: + "supported. Please try with pipeline parallelism." + ) + if quant_config and self.pre_quant: self.load_8bit = quant_config.get("load_in_8bit", False) - def _initialize_loader_state(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _initialize_loader_state( + self, model: nn.Module, model_config: ModelConfig + ) -> None: """ - Initialize the loader's internal state based on the model and + Initialize the loader's internal state based on the model and configuration. """ self.is_pool_model = is_pooling_model(model) @@ -524,7 +576,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): raise AttributeError( f"MoE Model {type(model).__name__} does not support " "BitsAndBytes quantization yet. Ensure this model has " - "'get_expert_mapping' method.") + "'get_expert_mapping' method." + ) # For some models like Molmo, we need to use hf_to_vllm_mapper # to ensure correct loading of weights. if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): @@ -535,22 +588,20 @@ class BitsAndBytesModelLoader(BaseModelLoader): def _dequantize_dq(self, quant_states: Any): """ - When BNB employs Double Quantization, we perform the dequantization of - these constants during weight loading rather than at inference time, - thereby avoiding this computational overhead during inference. This + When BNB employs Double Quantization, we perform the dequantization of + these constants during weight loading rather than at inference time, + thereby avoiding this computational overhead during inference. This comes at the cost of increased memory usage. 
""" from bitsandbytes.functional import QuantState, dequantize_blockwise def _dequantize_single_state(quant_state): """Helper function to dequantize a single QuantState object.""" - if not (isinstance(quant_state, QuantState) - and quant_state.nested): + if not (isinstance(quant_state, QuantState) and quant_state.nested): return # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 - absmax = dequantize_blockwise(quant_state.absmax, - quant_state.state2) + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset # Ensure float32 dtype @@ -569,10 +620,9 @@ class BitsAndBytesModelLoader(BaseModelLoader): _dequantize_single_state(quant_states) return quant_states - def _fuse_moe_quant_states(self, model: nn.Module, - quant_states_dict: dict) -> dict: + def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict: """ - + This function consolidates individual expert quantization states into fused representations for w13 and w2. """ @@ -592,12 +642,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): for exp in expert_mapping: shard_id = exp[-1] if shard_id not in ("w1", "w2", "w3"): - raise ValueError(f"shard_id must be ['w1','w2','w3'] but " - f"got {shard_id}.") + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but got {shard_id}." + ) layer_prefix = name.split("experts")[0] weight_qual_name = layer_prefix + exp[1] + "weight" - quant_state = self._dequantize_dq( - quant_states_dict[weight_qual_name]) + quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name]) if shard_id == "w1": w1_states_lst.append(quant_state) elif shard_id == "w2": @@ -605,14 +655,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): else: w3_states_lst.append(quant_state) del quant_states_dict[weight_qual_name] - assert (len(w1_states_lst) == len(w2_states_lst) == - len(w3_states_lst)) + assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst) w13_absmax_lst = [] w2_absmax_lst = [] w13_total_dim0 = 0 w2_total_dim0 = 0 - for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, - w3_states_lst): + for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst): assert w1_qs.shape == w3_qs.shape assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype @@ -652,12 +700,13 @@ class BitsAndBytesModelLoader(BaseModelLoader): return expert_qs_dict def _stack_quantization_states( - self, model: nn.Module, - quant_state_dict: dict) -> dict[str, dict[int, Any]]: + self, model: nn.Module, quant_state_dict: dict + ) -> dict[str, dict[int, Any]]: stacked_quant_state_dict: dict[str, dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter + param_dict = dict(model.named_parameters()) for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): @@ -667,23 +716,23 @@ class BitsAndBytesModelLoader(BaseModelLoader): shard_index = 0 for shard_name, ( - weight_name, - index, + weight_name, + index, ) in self.modules_mapping.inverse_packed_mapping.items(): # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. 
To prevent 'kv_proj' # from being incorrectly identified as being present in # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight shard_pos = quant_param_name.find(shard_name) - can_correct_rename = (shard_pos - > 0) and (quant_param_name[shard_pos - 1] - == ".") + can_correct_rename = (shard_pos > 0) and ( + quant_param_name[shard_pos - 1] == "." + ) # If the quant_param_name is packed, it won't occur in the # param_dict before renaming. - new_quant_param_name = quant_param_name.replace( - shard_name, weight_name) - need_rename = (quant_param_name not in param_dict) \ - and (new_quant_param_name in param_dict) + new_quant_param_name = quant_param_name.replace(shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) and ( + new_quant_param_name in param_dict + ) if can_correct_rename and need_rename: shard_index = index quant_param_name = new_quant_param_name @@ -697,12 +746,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): if quant_param_name not in stacked_quant_state_dict: stacked_quant_state_dict[quant_param_name] = {} - stacked_quant_state_dict[quant_param_name][shard_index] = ( - quant_state_dict[non_stacked_param_name]) + stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[ + non_stacked_param_name + ] return stacked_quant_state_dict - def _bind_quant_states_to_params(self, model: nn.Module, - stacked_quant_state_dict: dict) -> None: + def _bind_quant_states_to_params( + self, model: nn.Module, stacked_quant_state_dict: dict + ) -> None: # save quant_states and offsets as the attributes of the parameters param_dict = dict(model.named_parameters()) for param_name, param in param_dict.items(): @@ -716,13 +767,11 @@ class BitsAndBytesModelLoader(BaseModelLoader): pack_ratio = getattr(param, "pack_factor", -1) if pack_ratio == -1: - raise ValueError( - f"pack_factor not set for parameter {param_name}.") + raise ValueError(f"pack_factor not set for parameter {param_name}.") num_elements = [0] * len(quant_states) for seq, quant_state in quant_states.items(): - num_elements[seq] = (math.prod(quant_state.shape) // - pack_ratio) + num_elements[seq] = math.prod(quant_state.shape) // pack_ratio offsets = np.concatenate(([0], np.cumsum(num_elements))) # Make torch infer_schema happy @@ -731,38 +780,39 @@ class BitsAndBytesModelLoader(BaseModelLoader): if self.load_8bit: set_weight_attrs( - param, {"matmul_state": [None] * len(quant_states)}) - - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + param, {"matmul_state": [None] * len(quant_states)} + ) + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: self._verify_model_compatibility(model, model_config) self._initialize_loader_state(model, model_config) - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator( - model_config.model, - model_config.revision, - )) + logger.info( + "Loading weights with BitsAndBytes quantization. May take a while ..." + ) + qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + ) weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(qweight_iterator) # Some models may have weights loading tracker unimplemented. 
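The offset bookkeeping in `_bind_quant_states_to_params` above can be reproduced with a small standalone computation; the pack factor and shard shapes below are invented, but the cumulative-sum pattern is the same:

```python
import math
import numpy as np

pack_factor = 2  # two 4-bit values per stored element (illustrative)
shard_shapes = {0: (1024, 512), 1: (1024, 512), 2: (4096, 512)}  # three stacked shards

num_elements = [math.prod(shape) // pack_factor for _, shape in sorted(shard_shapes.items())]
offsets = np.concatenate(([0], np.cumsum(num_elements)))
print(offsets)  # offsets: 0, 262144, 524288, 1572864
```

Each packed shard contributes `prod(shape) // pack_factor` elements, and the offsets mark the slice boundaries of the shards inside the flat packed parameter.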
if loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - expert_quant_state_dict = self._fuse_moe_quant_states( - model, quant_state_dict) + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) + expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict) stacked_quant_state_dict = self._stack_quantization_states( - model, quant_state_dict) + model, quant_state_dict + ) stacked_quant_state_dict = { **expert_quant_state_dict, - **stacked_quant_state_dict + **stacked_quant_state_dict, } self._bind_quant_states_to_params(model, stacked_quant_state_dict) torch.cuda.empty_cache() diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 34b8d8e4ed622..00944989a002f 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -7,20 +7,28 @@ import time from collections.abc import Generator, Iterable from typing import Optional, cast -import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm import envs -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, - pt_weights_iterator, safetensors_weights_iterator) + download_safetensors_index_file_from_hf, + download_weights_from_hf, + fastsafetensors_weights_iterator, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + maybe_download_from_modelscope, + multi_thread_pt_weights_iterator, + multi_thread_safetensors_weights_iterator, + np_cache_weights_iterator, + pt_weights_iterator, + safetensors_weights_iterator, +) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -29,6 +37,9 @@ logger = init_logger(__name__) class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + # default number of thread when enable multithread weight loading + DEFAULT_NUM_THREADS = 8 + @dataclasses.dataclass class Source: """A source for weights.""" @@ -53,38 +64,17 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. 
+ extra_config = load_config.model_loader_extra_config + allowed_keys = {"enable_multithread_load", "num_threads"} + unexpected_keys = set(extra_config.keys()) - allowed_keys - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if envs.VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None + if unexpected_keys: + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}" + ) def _prepare_weights( self, @@ -96,8 +86,10 @@ class DefaultModelLoader(BaseModelLoader): """Prepare weights for the model. If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( - model_name_or_path, revision) or model_name_or_path) + model_name_or_path = ( + maybe_download_from_modelscope(model_name_or_path, revision) + or model_name_or_path + ) is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format @@ -106,8 +98,7 @@ class DefaultModelLoader(BaseModelLoader): # Some quantized models use .pt files for storing the weights. 
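With the stricter validation added to `DefaultModelLoader.__init__` above, only two extra-config keys are accepted. A hedged usage sketch, assuming `LoadConfig` can be constructed with `load_format` and `model_loader_extra_config` keyword arguments as the diff suggests (in practice this config usually arrives via the engine's `model_loader_extra_config` option rather than being built by hand):

```python
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import get_model_loader

load_config = LoadConfig(
    load_format="auto",
    model_loader_extra_config={
        "enable_multithread_load": True,  # switch to the multi-thread weight iterators
        "num_threads": 16,                # otherwise DEFAULT_NUM_THREADS (8) applies
    },
)
loader = get_model_loader(load_config)  # any other key raises ValueError here
```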
if load_format == "auto": allow_patterns = ["*.safetensors", "*.bin"] - elif (load_format == "safetensors" - or load_format == "fastsafetensors"): + elif load_format == "safetensors" or load_format == "fastsafetensors": use_safetensors = True allow_patterns = ["*.safetensors"] elif load_format == "mistral": @@ -160,24 +151,29 @@ class DefaultModelLoader(BaseModelLoader): revision, ) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) + hf_weights_files, hf_folder, index_file + ) else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + f"Cannot find any model weights with `{model_name_or_path}`" + ) return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, source: "Source" ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - source.model_or_path, source.revision, source.fall_back_to_pt, - source.allow_patterns_overrides) + source.model_or_path, + source.revision, + source.fall_back_to_pt, + source.allow_patterns_overrides, + ) if self.load_config.load_format == "npcache": # Currently np_cache only support *.bin checkpoints assert use_safetensors is False @@ -195,39 +191,57 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.use_tqdm_on_load, ) else: - weights_iterator = safetensors_weights_iterator( + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.safetensors_load_strategy, + ) + else: + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) if current_platform.is_tpu(): - from vllm.platforms.tpu import USE_TPU_COMMONS + from vllm.platforms.tpu import USE_TPU_INFERENCE - if not USE_TPU_COMMONS: - # In PyTorch XLA, we should call `xm.mark_step` - # requently so that not too many ops are accumulated - # in the XLA program. import torch_xla.core.xla_model - # as xm - import torch_xla.core.xla_model as xm + if not USE_TPU_INFERENCE: + # In PyTorch XLA, we should call `torch_xla.sync` + # frequently so that not too many ops are accumulated + # in the XLA program. 
+ import torch_xla def _xla_weights_iterator(iterator: Generator): for weights in iterator: yield weights - xm.mark_step() + torch_xla.sync(wait=False) weights_iterator = _xla_weights_iterator(weights_iterator) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. - return ((source.prefix + name, tensor) - for (name, tensor) in weights_iterator) + return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def get_all_weights( self, @@ -238,10 +252,8 @@ class DefaultModelLoader(BaseModelLoader): model_config.model, model_config.revision, prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", - True), - allow_patterns_overrides=getattr(model, "allow_patterns_overrides", - None), + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) yield from self._get_weights_iterator(primary_weights) @@ -253,25 +265,62 @@ class DefaultModelLoader(BaseModelLoader): yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, - model_config.revision, - fall_back_to_pt=True, - allow_patterns_overrides=None) + self._prepare_weights( + model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None, + ) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + if model_config.quantization == "torchao" and torchao_version_at_least( + "0.14.0" + ): + self.load_config.safetensors_load_strategy = "torchao" weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) + + # if we don't have `model.weight_metadata_and_attr_saved` defined and + # set to True, it means that this is either offline quantization case + # or the first run of online quantization + # see online_quantization.py for detailed notes + offline_quantization_or_first_run_of_online_quantization = not getattr( + model, "weight_metadata_and_attr_saved", False + ) + + if model_config.quantization is None: + # model is not quantized + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model) + ) + elif offline_quantization_or_first_run_of_online_quantization: + # case 1: offline quantized checkpoint + # case 2: Step I1 first run of weight loading with + # online quantization + # see online_quantization.py for detailed notes + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model) + ) + else: + # to avoid circular dependency + from vllm.model_executor.model_loader.online_quantization import ( + load_weights_and_online_quantize, + ) + + # subsequent runs of weight loading with online + # quantization + loaded_weights = load_weights_and_online_quantize(self, model, model_config) + self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) + self.counter_after_loading_weights - self.counter_before_loading_weights, + ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
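The quantization handling in load_weights above boils down to a three-way dispatch: plain loading for unquantized models, plain loading for offline-quantized checkpoints or the first run of online quantization (step I1), and the dedicated online-quantization reload path for subsequent runs. A compact sketch of just that control flow, using stand-in callables rather than vLLM's real helpers:

# Stand-in sketch of the dispatch; none of these callables are vLLM APIs.
def dispatch_load_weights(quantization, offline_or_first_run,
                          load_plain, load_and_online_quantize):
    if quantization is None:
        return load_plain()  # unquantized checkpoint
    if offline_or_first_run:
        return load_plain()  # offline-quantized checkpoint, or online-quant step I1
    return load_and_online_quantize()  # subsequent reload, online-quant steps R1-R4

loaded = dispatch_load_weights(
    quantization="torchao",
    offline_or_first_run=True,
    load_plain=lambda: {"model.layers.0.weight"},
    load_and_online_quantize=lambda: {"model.layers.0.weight"},
)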
if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index f4a7da5744e04..b2a934ce59497 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.weight_utils import ( - initialize_dummy_weights) +from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights class DummyModelLoader(BaseModelLoader): @@ -14,14 +14,15 @@ class DummyModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
initialize_dummy_weights(model) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06e..dbcd864516ec2 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -9,13 +9,19 @@ import torch.nn as nn from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, get_gguf_weight_type_map, - gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, + get_gguf_weight_type_map, + gguf_quant_weights_iterator, +) class GGUFModelLoader(BaseModelLoader): @@ -28,15 +34,18 @@ class GGUFModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) def _prepare_weights(self, model_name_or_path: str): if os.path.isfile(model_name_or_path): return model_name_or_path # for raw HTTPS link if model_name_or_path.startswith( - ("http://", "https://")) and model_name_or_path.endswith(".gguf"): + ("http://", "https://") + ) and model_name_or_path.endswith(".gguf"): return hf_hub_download(url=model_name_or_path) # repo id/filename.gguf if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): @@ -45,7 +54,8 @@ class GGUFModelLoader(BaseModelLoader): else: raise ValueError( f"Unrecognised GGUF reference: {model_name_or_path} " - "(expected local file, raw URL, or <repo_id>/<filename>.gguf)") + "(expected local file, raw URL, or <repo_id>/<filename>.gguf)" + ) def _get_gguf_weights_map(self, model_config: ModelConfig): """ @@ -62,30 +72,41 @@ class GGUFModelLoader(BaseModelLoader): # hack: ggufs have a different name than transformers if model_type == "cohere": model_type = "command-r" + if model_type == "gemma3_text": + # Gemma3 models use "gemma3_text" in HuggingFace but + # "gemma3" in GGUF architecture naming + model_type = "gemma3" if model_type in ("deepseek_v3", "deepseek_v2"): model_type = "deepseek2" # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ - f"model.layers.{idx}.mlp.gate.e_score_correction_bias" - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = ( + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + 
f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) if model_type in ("qwen2_moe", "qwen3_moe"): model_type = model_type.replace("_", "") # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -98,7 +119,8 @@ class GGUFModelLoader(BaseModelLoader): name_map = gguf.get_tensor_name_map(arch, num_layers) with torch.device("meta"): dummy_model = AutoModelForCausalLM.from_config( - config, trust_remote_code=model_config.trust_remote_code) + config, trust_remote_code=model_config.trust_remote_code + ) state_dict = dummy_model.state_dict() for hf_name in state_dict: @@ -110,31 +132,31 @@ class GGUFModelLoader(BaseModelLoader): def _get_weights_iterator( self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(model_name_or_path, - gguf_to_hf_name_map) + return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) + self._get_weights_iterator(local_model_path, gguf_weights_map) + ) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights if "lm_head.weight" in get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map): + local_model_path, gguf_weights_map + ): model_config.hf_config.update({"tie_word_embeddings": True}) - weight_type_map = get_gguf_weight_type_map(model_config.model, - gguf_weights_map) + weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map) # filter out unquantized modules to skip unquant_names = [ diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py deleted file mode 100644 index 
fad97aba84b6a..0000000000000 --- a/vllm/model_executor/model_loader/neuron.py +++ /dev/null @@ -1,476 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in transformers-neuronx -framework.""" -import ast -import copy -import importlib -import os -from typing import Optional - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import get_quantization_config -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput) - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "f32", - "half": "f16", - "float16": "f16", - "bfloat16": "bf16", - "float": "f32", - "float32": "f32", - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", -} - -# Models supported by Neuron. -_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = { - "LlamaForCausalLM": ("transformers_neuronx.llama.model", - "LlamaForSampling", "LlamaForCausalLM"), - "MistralForCausalLM": ("transformers_neuronx.mistral.model", - "MistralForSampling", "MistralForCausalLM") -} - - -class NeuronCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - logits = self.model(input_ids, - cache_ids=positions, - start_ids=input_block_ids) - return logits - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - - if self.on_device_sampling_disabled: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - # On-device sampling outputs the token ids directly. 
- sampled_token_ids = logits.flatten() - next_tokens = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - samples = [] - for seq_id in seq_group.seq_ids: - token_id = sampled_token_ids[sample_idx].item() - samples.append( - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - next_tokens.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - - return SamplerOutput(outputs=next_tokens) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - - self.model = neuronx_model_cls.from_pretrained(model_name_or_path, - **kwargs) - self.model.to_neuron() - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - SPECULATION_TERMINATION_ID = -1 - - def __init__(self, speculation_model) -> None: - super().__init__() - self.model = speculation_model - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - tokens, counts = self.model.speculative_iteration( - input_ids, positions, input_block_ids) - - # Mark the end of accepted speculative tokens for each sequence with the - # speculation termination id. - batch_size, steps = tokens.shape - mask = torch.arange(steps).expand(batch_size, -1) >= counts - tokens[mask] = self.SPECULATION_TERMINATION_ID - - return tokens - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. - accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == self.SPECULATION_TERMINATION_ID - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. 
Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_buckets(env: str, default_value: list[int]) -> list[int]: - env_value = os.getenv(env) - if env_value is None: - return default_value - buckets_remove_empty = filter( - lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) - buckets_int = map(int, buckets_remove_empty) - buckets_list = list(buckets_int) - return buckets_list - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config based on vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - quant_config = dict( - dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - quantize_method="vector_dynamic") - neuron_quantization_config_builder = lambda quant: get_quantization_config( - quant).from_config(quant_config).get_quant_method(None, "") - # TODO: Add Paged attention config to the default neuron arguments. - default_neuron_args = dict( - collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - quant=neuron_quantization_config_builder(model_config.quantization) - if model_config.quantization else None, - continuous_batching=continuous_batching_config, - weight_tiling=bool(model_config.quantization), - on_device_generation=_get_neuron_on_device_generation_config( - model_config)) - return default_neuron_args - - -def _get_default_neuron_config_for_speculation( - model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config for speculative decoding based on - vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - - default_neuron_args = dict(collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - on_device_embedding=True, - continuous_batching=continuous_batching_config, - on_device_generation=copy.deepcopy( - model_config.neuron_sampling_params)) - return default_neuron_args - - -def _get_neuron_on_device_generation_config(model_config: ModelConfig): - if not _is_neuron_on_device_sampling_disabled(model_config): - return copy.deepcopy(model_config.neuron_sampling_params) - return None - - -def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: - return not getattr(model_config, "neuron_sampling_params", None) - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - from transformers_neuronx.config import (ContinuousBatchingConfig, - GenerationConfig, - KVCacheQuantizationConfig, - NeuronConfig, QuantizationConfig, - SparseAttnConfig) - - sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) - if sparse_attn: - overridden_neuron_config["sparse_attn"] = SparseAttnConfig( - **sparse_attn) - - kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {}) - if kv_cache_quant: - overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig( - **kv_cache_quant) - - continuous_batching = overridden_neuron_config.pop("continuous_batching", - {}) - if continuous_batching: - overridden_neuron_config[ - 
"continuous_batching"] = ContinuousBatchingConfig( - **continuous_batching) - - quant = overridden_neuron_config.pop("quant", {}) - if quant: - overridden_neuron_config["quant"] = QuantizationConfig(**quant) - - on_device_generation = overridden_neuron_config.pop( - "on_device_generation", {}) - if on_device_generation: - overridden_neuron_config["on_device_generation"] = GenerationConfig( - **on_device_generation) - default_neuron_config.update(overridden_neuron_config) - return NeuronConfig(**default_neuron_config) - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - # Create a model instance. - model = NeuronCausalLM( - model_config.hf_config, - _is_neuron_on_device_sampling_disabled(model_config)) - - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config) - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - model.load_weights(model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This method is only applicable for speculation with a standalone draft model - """ - from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder - - # For Eagle SD, we need to pass in additional parameters in neuron config. - is_eagle = getattr(speculation_config.draft_model_config.hf_config, - "is_eagle", False) - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - if is_eagle: - default_neuron_config_args['is_eagle_target'] = True - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. 
- draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - if is_eagle: - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. - draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - num_speculative_tokens = speculation_config.num_speculative_tokens - # Create speculation model instance. - speculation_model = FusedSpeculativeDecoder(draft_model.model, - target_model.model, - num_speculative_tokens) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) - - -def get_neuron_eagle_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized EAGLE speculation model for inference.""" - from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - default_neuron_config_args['is_eagle_target'] = True - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. - draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. 
- draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - token_tree: dict[int, list[int]] = ast.literal_eval( - speculation_config.speculative_token_tree) - - speculation_model = EagleSpeculativeDecoder(draft_model.model, - target_model.model, - token_tree=token_tree) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py deleted file mode 100644 index f450961c64ff4..0000000000000 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ /dev/null @@ -1,685 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in -neuronx-distributed-inference framework.""" -# Disabling yapf because yapf and isort have conflicts for the below imports -# yapf: disable -import copy -import hashlib -import importlib -import multiprocessing -import os -import shutil -from typing import Optional - -import torch -import torch.nn as nn -from neuronx_distributed_inference.models.config import ( - FusedSpecNeuronConfig, OnDeviceSamplingConfig) -from neuronx_distributed_inference.models.mllama.utils import ( - create_vision_mask) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraServingConfig) -from neuronx_distributed_inference.utils.hf_adapter import ( - load_pretrained_config) -from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceOutput) - -# yapf: enable -logger = init_logger(__name__) - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "float32", - "half": "float16", - "float16": "float16", - "bfloat16": "bfloat16", - "float": "float32", - "float32": "float32", - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float32: "float32", -} - -# Models supported by Neuronx distributed for inference. 
-_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { - "LlamaForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "MistralForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "DbrxForCausalLM": - ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", - "NeuronDbrxForCausalLM"), - "MixtralForCausalLM": - ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", - "NeuronMixtralForCausalLM"), - "MllamaForConditionalGeneration": - ("neuronx_distributed_inference.models.mllama.modeling_mllama", - "NeuronMllamaForCausalLM"), -} - - -class NeuronCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - prev_hidden: Optional[torch.Tensor] = None, - adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - output = self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params, - prev_hidden=prev_hidden, - adapter_ids=adapter_ids) - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - restored_indices = torch.argsort(sorted_indices) - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - batch_size = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" - # Organize input tensors by step instead of by sequence. 
- accepted_token_ids_by_step = logits.flatten() - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - step_output_token_ids = [] - for i, seq_id in enumerate(seq_ids): - token_id = accepted_token_ids_by_step[i] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - return SamplerOutput(outputs=step_output_token_ids) - else: - return self.sampler(logits, sampling_metadata) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -class NeuronMllamaForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - # has_image is the only multimodal input that is used in - # token-generation - # This is a cache (on CPU) that saves has_image data per sequence id - # The number of entries in this cache is <= Batch-Size - self.has_image_cache: dict[int, torch.Tensor] = {} - self.config = config - self.logits_processor = LogitsProcessor( - config.get_text_config().vocab_size, logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - self.is_reorder_needed: bool = True - - def read_from_has_image_cache(self, seq_ids: torch.Tensor): - has_image_list = [] - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if 
seq_id in self.has_image_cache: - has_image_list.append(self.has_image_cache[seq_id]) - else: - has_image_list.append(torch.tensor([0])) - return torch.tensor(has_image_list) - - def write_to_has_image_cache(self, seq_ids: torch.Tensor, - has_image: torch.Tensor): - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if index < len(has_image): - self.has_image_cache[seq_id] = has_image[index] - else: - self.has_image_cache[seq_id] = torch.zeros(1) - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - seq_ids: torch.Tensor, pixel_values: torch.Tensor, - aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, - has_image: torch.Tensor, sampling_params) -> torch.Tensor: - - # We update the has_image cache during prefill - # and read the has_image cache during decode - if input_ids.shape[-1] > 1: # prefill - self.write_to_has_image_cache(seq_ids, has_image) - else: - has_image = self.read_from_has_image_cache(seq_ids) - bs = input_ids.shape[0] - num_chunks = torch.zeros((bs, 1)) - aspect_ratios = torch.zeros((bs, 1, 2)) - - input_block_ids = seq_ids - origin_input_block_ids = seq_ids - if self.is_reorder_needed: - # sort block ids sequentially for perf/neuron support reasons - input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - pixel_values = torch.index_select(pixel_values, 0, sorted_indices) - aspect_ratios = torch.index_select(aspect_ratios, 0, - sorted_indices) - num_chunks = torch.index_select(num_chunks, 0, sorted_indices) - has_image = torch.index_select(has_image, 0, sorted_indices) - - self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) - output = self.model( - input_ids.to(torch.int32), - attention_mask=None, - position_ids=positions.to(torch.int32), - seq_ids=seq_ids.flatten().to(torch.int32), - pixel_values=pixel_values.to( - self.config.vision_config.torch_dtype), - aspect_ratios=aspect_ratios.to(torch.int32), - vision_mask=self.vision_mask.to(torch.int32), - sampling_params=sampling_params, - num_chunks=num_chunks.to(torch.int32), - has_image=has_image.to(torch.int32), - ) - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1: - restored_indices = torch.argsort(sorted_indices) - output = torch.index_select(output, 0, restored_indices) - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample(self, hidden_states, sampling_metadata): - if not self.on_device_sampling_disabled: - with torch.profiler.record_function("sample"): - hidden_states = hidden_states.flatten() - res = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - samples = [] - for seq_id in seq_ids: - token_id = hidden_states[sample_idx].item() - samples.append( - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - res.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - next_tokens = SamplerOutput(outputs=res) - else: - next_tokens = self.sampler(None, hidden_states, 
sampling_metadata) - return next_tokens - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - logger.info("neuron_config buckets: %s", - self.config.neuron_config.buckets) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - try: - self.model = neuronx_model_cls(compiled_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.vision_token_id = tokenizer( - "<|image|>", add_special_tokens=False).input_ids[0] - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError): - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - - logger.info("\nCompiling and saving model to %s", model_name_or_path) - - p = multiprocessing.Process(target=compile_model, - args=(self, compiled_model_path)) - p.start() - p.join() - - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.save_pretrained(compiled_model_path) - logger.info("Successfully compiled and saved the model in %s", - compiled_model_path) - - # Read "<|image|>" token_id from the tokenizer - self.vision_token_id = tokenizer("<|image|>", - add_special_tokens=False).input_ids[0] - logger.info("\nLoading model from compiled checkpoint...") - self.model.load(compiled_model_path) - - -def compile_model(neuron_model, traced_model_path): - neuron_model.model.compile(traced_model_path) - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - ) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - - output = 
self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params) - restored_indices = torch.argsort(sorted_indices) - - # CTX encoding - if (positions[:, 0]).sum().item() == 0: - output = output.fused_outputs[0][:, 0:1] - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - return output - - # Fused Spec (Generation) - accepted_tokens_with_padding = output.fused_outputs[0] - next_pos_ids = output.fused_outputs[-1] - generated_token_counts = next_pos_ids - positions - - assert torch.any(generated_token_counts == 0).item() is False, \ - "NxDI model generated no output for one or more sequences." - - batch_size, steps = accepted_tokens_with_padding.shape - mask = torch.arange(steps).expand(batch_size, - -1) >= generated_token_counts - accepted_tokens_with_padding[mask] = -1 - - if input_block_ids.shape[0] != 1: - accepted_tokens_with_padding = torch.index_select( - accepted_tokens_with_padding, 0, restored_indices) - - return accepted_tokens_with_padding - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. - accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == -1 - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - def load_weights(self, model_name_or_path: str, - draft_model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - - draft_neuron_config = copy.deepcopy(config.neuron_config) - if not config.neuron_config.enable_eagle_speculation: - draft_neuron_config.speculation_length = 0 - draft_neuron_config.trace_tokengen_model = True - draft_neuron_config.enable_fused_speculation = False - if getattr(config.neuron_config, "draft_model_modules_to_not_convert", - None): - draft_neuron_config.modules_to_not_convert = ( - draft_neuron_config.draft_model_modules_to_not_convert) - if config.neuron_config.enable_eagle_speculation: - draft_neuron_config.is_eagle_draft = True - draft_neuron_config.sequence_parallel_enabled = False - draft_config = neuronx_model_cls.get_config_cls()( - draft_neuron_config, - load_config=load_pretrained_config(draft_model_name_or_path)) - fused_spec_config = (FusedSpecNeuronConfig( - neuronx_model_cls._model_cls, 
- draft_config=draft_config, - draft_model_path=draft_model_name_or_path)) - config.fused_spec_config = fused_spec_config - self.config.neuron_config = neuron_config - - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - if not os.path.exists(draft_model_name_or_path): - if draft_model_name_or_path != model_name_or_path: - hf_model = AutoModelForCausalLM.from_pretrained( - draft_model_name_or_path) - saved_path = os.path.join("local-models", - draft_model_name_or_path) - hf_model.save_pretrained(saved_path) - draft_model_name_or_path = saved_path - else: - draft_model_name_or_path = model_name_or_path - config.fused_spec_config.draft_model_path = draft_model_name_or_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. 
Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig): - """Generate a neuron config based on vllm config args.""" - on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, - deterministic=False) - batch_size = scheduler_config.max_num_seqs - - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=batch_size, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - padding_side="right", - on_device_sampling_config=on_device_sampling_config, - sequence_parallel_enabled=True, - lora_serving_config=lora_serving_config) - return neuron_config - - -def _get_default_speculation_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Generate a neuron config for speculative decoding based on vllm config - args.""" - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=scheduler_config.max_num_seqs, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - speculation_length=speculation_config.num_speculative_tokens, - trace_tokengen_model=False, - enable_fused_speculation=True, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - on_device_sampling_config=dict( - top_k=1, - do_sample=False, - )) - return neuron_config - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - """Update default neuron config values with override args""" - overridden_neuron_config = overridden_neuron_config or {} - default_neuron_config.update(overridden_neuron_config) - return default_neuron_config - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - model_arch = _get_model_architecture(model_config.hf_config) - if model_arch == "MllamaForConditionalGeneration": - model = NeuronMllamaForCausalLM(model_config.hf_config) - else: - model = NeuronCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config, lora_serving_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This model handles speculation using both a draft model and an EAGLE draft. 
- """ - model = NeuronSpeculationCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_speculation_config( - model_config, parallel_config, scheduler_config, speculation_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - speculation_config.draft_model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py new file mode 100644 index 0000000000000..890dd7231a0e1 --- /dev/null +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types + +import torch +from torch import nn + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.utils import process_weights_after_loading + +logger = init_logger(__name__) + +# Notes for Online Quantization +# In terms of state of checkpoints, quantization config and their +# correspondance to online quantization: +# | Use Case | Checkpoints | model_config.quantization | +# | no quant | high precision | None | +# | offline quant | quantized | fp8, torchao etc. | +# | online quant | high precision | torchao etc. | +# +# The process for loading non-quantized checkpoint +# 1. load non-quantized weights (load_weights) +# 2. do any additional post processing (process_weights_after_loading) +# +# The process for loading offline quantized checkpoint +# 1. load offline-quantized weights (load_weights) +# 2. do any additional post processing (process_weights_after_loading) + +# The process for unquantized model reloading +# (repeated run in RL training loop) +# first run +# UI1. load_weights: load bfloat16 weights +# UI2. process_weights_after_loading: any additional post processing +# subsequent run +# UC1: load_weights: load bfloat16 weights +# (shouldn't be any issues since we didn't change any attributes +# of the weights) +# UC2: process_weights_after_loading: any additional post processing + +# The process for weight reloading with online quantization +# (repeated run in RL training loop) +# first run +# I1. load_weights: load bfloat16 weights +# I2. process_weights_after_loading: +# record weight metadata and attributes for R1 and R2 +# quantize weights to fp8 +# subsequent run +# (beginning model weight is in fp8) +# load_weights: +# R1. restore bfloat16 model weight metadata +# R2. restore the model weight attributes +# R3. reload bfloat16 weights +# R4. 
quantize weights (by calling process_weights_after_loading), +# also set `process_weights_after_loading_already_called` to +# True to stop it from running again +# process_weights_after_loading (if called): +# this will be skipped since it's already run in +# load_weights + + +def maybe_save_metadata_and_attributes_for_weight_reloading( + model: nn.Module, model_config: ModelConfig +): + # the following is to support on-the-fly quantization, currently only supported + # for torchao + if model_config.quantization != "torchao": + return + + if getattr(model, "process_weights_after_loading_already_called", False): + # In case `process_weights_after_loading` is called multiple times + # we'll skip it on later calls + logger.warning( + "process_weights_after_loading already called for model %s", model + ) + return + + from vllm.model_executor.model_loader.weight_utils import get_quant_config + + quant_config = get_quant_config(model_config, None) + + # If the checkpoint is already torchao serialized, this means it's the + # pre-quantized (offline quantization) case, so we'll skip saving the metadata + # Otherwise, this is Step I2 of initialization steps of + # online quantization + # This step records the weight metadata and weight attributes so we can + # restore the bfloat16 model weights during the reload step (R1 and R2) + # see Notes in online_quantization.py for more details + if not ( + hasattr(quant_config, "is_checkpoint_torchao_serialized") + and not quant_config.is_checkpoint_torchao_serialized + ): + return + + # This is the I2 step of online quantization that saves + # metadata and attributes of weights so they can be used in the R1 and + # R2 steps; note that we only save these during initialization + + # This includes two things: + # 1. save floating point metadata (shape, dtype, device) for init + # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init + + if getattr(model, "weight_metadata_and_attr_saved", False): + return + + # save the dtype, shape and device for each model parameter, used for + # restoring the model high precision parameters before + # reloading the weights + assert not hasattr(model, "original_weights_rebuild_keys") + model.original_weights_rebuild_keys = {} + for name, p in model.named_parameters(): + model.original_weights_rebuild_keys[name] = { + "shape": p.shape, + "dtype": p.dtype, + "device": p.device, + } + + # record the weight attributes (loader functions etc.)
+ # so these can be recovered later when we reload the weights + # structure: {"weight_name": {"weight_attr_key": attr}} + assert not hasattr(model, "recorded_weight_attr") + model.recorded_weight_attr = {} + for name, param in model.named_parameters(): + model.recorded_weight_attr[name] = {} + for key in param.__dict__: + if hasattr(param, key): + attr = getattr(param, key) + if not callable(attr): + model.recorded_weight_attr[name][key] = attr + elif hasattr(attr, "__self__") and param is attr.__self__: + # if attr is a bound method of an instance, and + # attr.__self__ points to the instance (param) + # we'll record the underlying function object + model.recorded_weight_attr[name][key] = attr.__func__ + else: + model.recorded_weight_attr[name][key] = attr + # mark the metadata and attributes saved so we don't run it again + model.weight_metadata_and_attr_saved = True + + +def _bond_method_to_cls(func, obj): + if hasattr(func, "__self__") or not callable(func): + # If the function is already bound to an instance, return it as is + return func + else: + return types.MethodType(func, obj) + + +def load_weights_and_online_quantize( + model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig +) -> set[str]: + # online quantization, right now only enabled for + # torchao + # R1, R2, R3, R4 in the Notes + + # TODO: Add fp8 support + assert model_config.quantization == "torchao", ( + "online quantization is only enabled for torchao currently" + ) + # TODO: use create_weights to restore the weights to original state + + # Step R1: First restore the quantized weights to original bfloat16 + # weights, with original metadata (shape, dtype, device) + # and attributes, so that bfloat16 weights can be loaded properly + existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() + named_modules = dict(model.named_modules(remove_duplicate=False)) + model_device = None + + # Step R2: recover the parameter to the state before first loading + for name, d in model.original_weights_rebuild_keys.items(): + _shape = d["shape"] + _dtype = d["dtype"] + _device = d["device"] + if model_device is not None: + assert model_device == _device, ( + "Expecting all weights " + "to be on the same device for now, got both: " + f"{model_device} and {_device}" + ) + else: + model_device = _device + + if name in existing_param_names: + module_name, weight_name = name.rsplit(".", 1) + module = named_modules[module_name] + setattr( + module, + weight_name, + torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)), + ) + + # recorded_weight_attr is + # {"weight_name": {"weight_attr_key": attr}} + # e.g. + # { + # { + # "layer.0.weight": { + # "weight_loader": weight_loader_function_object, + # "input_dim": 0, ...
+ # }, + # "layer.1.weight": ..., + # } + # } + for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items(): + for attr_name, attr in weight_attr_dict.items(): + module_name, weight_name = full_weight_name.rsplit(".", 1) + module = named_modules[module_name] + weight = getattr(module, weight_name) + if not hasattr(weight, attr_name): + setattr(weight, attr_name, _bond_method_to_cls(attr, weight)) + + # Step I1: reload bfloat16 / high precision weights + loaded_weights = model.load_weights( + model_loader.get_all_weights(model_config, model) + ) + + # Step I2: online quantize the weights + # manually process weights after loading + model.process_weights_after_loading_already_called = False + process_weights_after_loading(model, model_config, model_device) + model.process_weights_after_loading_already_called = True + return loaded_weights diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 83e0f386c1082..50a92edd1162c 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa: SIM117 -import glob import os from collections.abc import Generator from typing import Optional @@ -10,19 +9,21 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - runai_safetensors_weights_iterator) -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 + download_safetensors_index_file_from_hf, + download_weights_from_hf, + runai_safetensors_weights_iterator, +) +from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors class RunaiModelStreamerLoader(BaseModelLoader): """ - Model loader that can load safetensors - files from local FS or S3 bucket. + Model loader that can load safetensors + files from local FS or S3 bucket. 
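To make the I1/I2 and R1-R4 bookkeeping in `online_quantization.py` above concrete, here is a minimal, self-contained sketch of the record-and-restore pattern it implements. The toy `nn.Linear`, the `example_weight_loader` helper, and the stand-alone dictionaries are illustrative stand-ins, not vLLM APIs; the real module hangs this state off the model object and re-binds recorded loader functions with `types.MethodType`, as `_bond_method_to_cls` does.

```python
# Sketch (not vLLM API): record parameter metadata/attributes on the first load,
# then rebuild empty high-precision parameters and re-bind the recorded
# attributes before reloading weights, mirroring steps I2 and R1-R3 above.
import types

import torch
from torch import nn


def example_weight_loader(param: nn.Parameter, loaded: torch.Tensor) -> None:
    # Stand-in for a vLLM-style weight loader attached to a parameter.
    param.data.copy_(loaded)


model = nn.Linear(4, 4, bias=False)
model.weight.weight_loader = types.MethodType(example_weight_loader, model.weight)

# I2 (first run): record shape/dtype/device plus per-parameter attributes,
# unbinding any methods that are bound to the parameter itself.
rebuild_keys = {
    name: {"shape": p.shape, "dtype": p.dtype, "device": p.device}
    for name, p in model.named_parameters()
}
recorded_attrs = {
    name: {
        key: (val.__func__ if getattr(val, "__self__", None) is p else val)
        for key, val in p.__dict__.items()
    }
    for name, p in model.named_parameters()
}

# ... in the real flow the parameter is replaced by a quantized tensor here ...

# R1/R2 (subsequent run): rebuild an empty high-precision parameter and
# re-attach the recorded attributes, re-binding callables to the new parameter.
for name, meta in rebuild_keys.items():
    new_param = nn.Parameter(
        torch.empty(meta["shape"], dtype=meta["dtype"], device=meta["device"])
    )
    setattr(model, name, new_param)
    for attr_name, attr in recorded_attrs[name].items():
        if not hasattr(new_param, attr_name):
            setattr(
                new_param,
                attr_name,
                types.MethodType(attr, new_param) if callable(attr) else attr,
            )

# R3: the restored loader can now ingest fresh high-precision weights.
model.weight.weight_loader(torch.zeros(4, 4))
```

The real `load_weights_and_online_quantize` additionally calls `process_weights_after_loading` itself (R4) and sets `process_weights_after_loading_already_called` so the quantization step is not run a second time.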
""" def __init__(self, load_config: LoadConfig): @@ -30,64 +31,65 @@ class RunaiModelStreamerLoader(BaseModelLoader): if load_config.model_loader_extra_config: extra_config = load_config.model_loader_extra_config - if ("concurrency" in extra_config - and isinstance(extra_config.get("concurrency"), int)): + if "concurrency" in extra_config and isinstance( + extra_config.get("concurrency"), int + ): os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( - extra_config.get("concurrency")) + extra_config.get("concurrency") + ) - if ("memory_limit" in extra_config - and isinstance(extra_config.get("memory_limit"), int)): + if "memory_limit" in extra_config and isinstance( + extra_config.get("memory_limit"), int + ): os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( - extra_config.get("memory_limit")) + extra_config.get("memory_limit") + ) - runai_streamer_s3_endpoint = os.getenv( - 'RUNAI_STREAMER_S3_ENDPOINT') - aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') - if (runai_streamer_s3_endpoint is None - and aws_endpoint_url is not None): + runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT") + aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL") + if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None: os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> list[str]: + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> list[str]: """Prepare weights for the model. If the model is not local, it will be downloaded.""" - is_s3_path = is_s3(model_name_or_path) + is_object_storage_path = is_runai_obj_uri(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [safetensors_pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) + hf_folder = ( + model_name_or_path + if (is_local or is_object_storage_path) + else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + ) + hf_weights_files = list_safetensors(path=hf_folder) - if not is_local and not is_s3_path: + if not is_local and not is_object_storage_path: download_safetensors_index_file_from_hf( - model_name_or_path, index_file, self.load_config.download_dir, - revision) + model_name_or_path, index_file, self.load_config.download_dir, revision + ) if not hf_weights_files: raise RuntimeError( - f"Cannot find any safetensors model weights with " - f"`{model_name_or_path}`") + f"Cannot find any safetensors model weights with `{model_name_or_path}`" + ) return hf_weights_files def _get_weights_iterator( - self, model_or_path: str, - revision: str) -> Generator[tuple[str, torch.Tensor], None, None]: + self, model_or_path: str, revision: str + ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_weights_files = self._prepare_weights(model_or_path, revision) return runai_safetensors_weights_iterator( @@ -99,11 +101,11 @@ class RunaiModelStreamerLoader(BaseModelLoader): """Download 
model if necessary""" self._prepare_weights(model_config.model, model_config.revision) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: """Load weights into a model.""" model_weights = model_config.model if hasattr(model_config, "model_weights"): model_weights = model_config.model_weights model.load_weights( - self._get_weights_iterator(model_weights, model_config.revision)) + self._get_weights_iterator(model_weights, model_config.revision) + ) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 3edd4ec4007e8..e65eb78819e29 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -10,11 +10,14 @@ from typing import Any, Optional import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, runai_safetensors_weights_iterator) + download_weights_from_hf, + runai_safetensors_weights_iterator, +) from vllm.transformers_utils.s3_utils import glob as s3_glob from vllm.transformers_utils.utils import is_s3 @@ -35,23 +38,30 @@ class ShardedStateLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - extra_config = ({} if load_config.model_loader_extra_config is None - else load_config.model_loader_extra_config.copy()) + extra_config = ( + {} + if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy() + ) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) if extra_config: - raise ValueError(f"Unexpected extra config keys for load format " - f"{load_config.load_format}: " - f"{load_config.model_loader_extra_config.keys()}") + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}" + ) @staticmethod def _filter_subtensors( - tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: + tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. 
""" same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = ( - collections.defaultdict(list)) + collections.defaultdict(list) + ) for key, tensor in tensors.items(): if tensor.numel(): ptr = tensor.untyped_storage().data_ptr() @@ -79,8 +89,7 @@ class ShardedStateLoader(BaseModelLoader): result[k] = t return result - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]): + def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): return model_name_or_path else: @@ -96,8 +105,7 @@ class ShardedStateLoader(BaseModelLoader): def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: from vllm.distributed import get_tensor_model_parallel_rank model_weights = model_config.model @@ -113,16 +121,16 @@ class ShardedStateLoader(BaseModelLoader): filepaths = [] if is_s3(local_model_path): - file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" - filepaths = s3_glob(path=local_model_path, - allow_pattern=[file_pattern]) + file_pattern = f"*{self.pattern.format(rank=rank, part='*')}" + filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern]) else: filepaths = glob.glob(pattern) if not filepaths: # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!") + f"pre-sharded checkpoints are currently supported!" + ) state_dict = self._filter_subtensors(model.state_dict()) for key, tensor in self.iterate_over_files(filepaths): # If loading with LoRA enabled, additional padding may @@ -135,8 +143,7 @@ class ShardedStateLoader(BaseModelLoader): param_data = param_data.narrow(dim, 0, size) if tensor.shape != param_shape: logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", + "loading tensor of shape %s into parameter '%s' of shape %s", tensor.shape, key, param_shape, @@ -144,15 +151,16 @@ class ShardedStateLoader(BaseModelLoader): param_data.copy_(tensor) state_dict.pop(key) if state_dict: - raise ValueError( - f"Missing keys {tuple(state_dict)} in loaded state!") + raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") def iterate_over_files( - self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: + self, paths + ) -> Generator[tuple[str, torch.Tensor], None, None]: if self.load_config.load_format == "runai_streamer_sharded": yield from runai_safetensors_weights_iterator(paths, True) else: from safetensors.torch import safe_open + for path in paths: with safe_open(path, framework="pt") as f: for key in f.keys(): # noqa: SIM118 diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3d491be3156b6..9d58278f996b6 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -22,11 +22,9 @@ from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import (ModelConfig, ParallelConfig, VllmConfig, - set_current_vllm_config) +from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger -from 
vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, PlaceholderModule @@ -34,11 +32,14 @@ if TYPE_CHECKING: from vllm.engine.arg_utils import EngineArgs try: - from tensorizer import (DecryptionParams, EncryptionParams, - TensorDeserializer, TensorSerializer) + from tensorizer import ( + DecryptionParams, + EncryptionParams, + TensorDeserializer, + TensorSerializer, + ) from tensorizer.stream_io import open_stream - from tensorizer.utils import (convert_bytes, get_mem_usage, - no_init_or_tensor) + from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor except ImportError: tensorizer = PlaceholderModule("tensorizer") @@ -52,9 +53,15 @@ except ImportError: no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") __all__ = [ - 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', - 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', - 'no_init_or_tensor', 'TensorizerConfig' + "EncryptionParams", + "DecryptionParams", + "TensorDeserializer", + "TensorSerializer", + "open_stream", + "convert_bytes", + "get_mem_usage", + "no_init_or_tensor", + "TensorizerConfig", ] logger = init_logger(__name__) @@ -73,12 +80,12 @@ def tensorizer_kwargs_arg(value): raise argparse.ArgumentTypeError( f"Not deserializable to dict: {value}. serialization_kwargs and " f"deserialization_kwargs must be " - f"deserializable from a JSON string to a dictionary. ") + f"deserializable from a JSON string to a dictionary. " + ) return loaded class MetaTensorMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} @@ -88,8 +95,9 @@ class MetaTensorMode(TorchDispatchMode): return func(*args, **kwargs) -def meta_tensor_mode(loading_code=None, ): - +def meta_tensor_mode( + loading_code=None, +): if loading_code is None: return _NoInitOrTensorImpl.context_manager() elif callable(loading_code): @@ -99,15 +107,15 @@ def meta_tensor_mode(loading_code=None, ): raise TypeError( "expected a callable to evaluate," " or None if being used as a context manager;" - f' got an object of type "{type(loading_code).__name__}" instead.') + f' got an object of type "{type(loading_code).__name__}" instead.' 
+ ) class _NoInitOrTensorImpl: _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) - is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", - default=False) + is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False) _count_active: int = 0 _count_active_lock = threading.Lock() @@ -139,7 +147,6 @@ class _NoInitOrTensorImpl: @staticmethod def _disable(func): - def wrapper(*args, **kwargs): if not _NoInitOrTensorImpl.is_active.get(): return func(*args, **kwargs) @@ -162,76 +169,81 @@ class TensorizerConfig(MutableMapping): stream_kwargs: Optional[dict[str, Any]] = None serialization_kwargs: Optional[dict[str, Any]] = None deserialization_kwargs: Optional[dict[str, Any]] = None - _extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False, - default=None) - model_class: Optional[type[torch.nn.Module]] = field(init=False, - default=None) + _extra_serialization_attrs: Optional[dict[str, Any]] = field( + init=False, default=None + ) + model_class: Optional[type[torch.nn.Module]] = field(init=False, default=None) hf_config: Optional[PretrainedConfig] = field(init=False, default=None) dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None) _is_sharded: bool = field(init=False, default=False) _fields: ClassVar[tuple[str, ...]] _keys: ClassVar[frozenset[str]] - """ - Args for the TensorizerConfig class. These are used to configure the - behavior of model serialization and deserialization using Tensorizer. + """Configuration class for Tensorizer settings. - Args: - tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. This is a required field unless lora_dir is - provided and the config is meant to be used for the - `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or - `lora_dir` is passed to this object's initializer, this is a required - argument. - tensorizer_dir: Path to a directory containing serialized model tensors, - and all other potential model artifacts to load the model, such as - configs and tokenizer files. Can be passed instead of `tensorizer_uri` - where the `model.tensors` file will be assumed to be in this - directory. - vllm_tensorized: If True, indicates that the serialized model is a - vLLM model. This is used to determine the behavior of the - TensorDeserializer when loading tensors from a serialized model. - It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. Note that this is now - deprecated, as serialized vLLM models are now automatically - inferred as vLLM models. - verify_hash: If True, the hashes of each tensor will be verified against - the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. - num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available - resources and model size. This greatly increases performance. - encryption_keyfile: File path to a binary file containing a - binary key to use for decryption. `None` (the default) means - no decryption. See the example script in - examples/others/tensorize_vllm_model.py. - s3_access_key_id: The access key for the S3 bucket. Can also be set via - the S3_ACCESS_KEY_ID environment variable. - s3_secret_access_key: The secret access key for the S3 bucket. 
Can also - be set via the S3_SECRET_ACCESS_KEY environment variable. - s3_endpoint: The endpoint for the S3 bucket. Can also be set via the - S3_ENDPOINT_URL environment variable. - lora_dir: Path to a directory containing LoRA adapter artifacts for - serialization or deserialization. When serializing LoRA adapters - this is the only necessary parameter to pass to this object's - initializer. - """ + These settings configure the behavior of model serialization and + deserialization using Tensorizer. + + Attributes: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or + `lora_dir` is passed to this object's initializer, this is + a required argument. + tensorizer_dir: Path to a directory containing serialized model tensors, + and all other potential model artifacts to load the model, such as + configs and tokenizer files. Can be passed instead of + `tensorizer_uri` where the `model.tensors` file will be assumed + to be in this directory. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified + against the hashes stored in the metadata. A `HashMismatchError` + will be raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + lora_dir: Path to a directory containing LoRA adapter artifacts for + serialization or deserialization. When serializing LoRA adapters + this is the only necessary parameter to pass to this object's + initializer. + """ def __post_init__(self): # check if the configuration is for a sharded vLLM model - self._is_sharded = isinstance(self.tensorizer_uri, str) \ - and re.search(r'%0\dd', self.tensorizer_uri) is not None + self._is_sharded = ( + isinstance(self.tensorizer_uri, str) + and re.search(r"%0\dd", self.tensorizer_uri) is not None + ) if self.tensorizer_dir and self.lora_dir: raise ValueError( "Only one of tensorizer_dir or lora_dir may be specified. " "Use lora_dir exclusively when serializing LoRA adapters, " - "and tensorizer_dir or tensorizer_uri otherwise.") + "and tensorizer_dir or tensorizer_uri otherwise." + ) if self.tensorizer_dir and self.tensorizer_uri: logger.warning_once( "Provided both tensorizer_dir and tensorizer_uri. 
" "Inferring tensorizer_dir from tensorizer_uri as the " - "latter takes precedence.") + "latter takes precedence." + ) self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) if not self.tensorizer_uri: if self.lora_dir: @@ -239,11 +251,13 @@ class TensorizerConfig(MutableMapping): elif self.tensorizer_dir: self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors" else: - raise ValueError("Unable to resolve tensorizer_uri. " - "A valid tensorizer_uri or tensorizer_dir " - "must be provided for deserialization, and a " - "valid tensorizer_uri, tensorizer_uri, or " - "lora_dir for serialization.") + raise ValueError( + "Unable to resolve tensorizer_uri. " + "A valid tensorizer_uri or tensorizer_dir " + "must be provided for deserialization, and a " + "valid tensorizer_uri, tensorizer_uri, or " + "lora_dir for serialization." + ) else: self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) @@ -279,8 +293,12 @@ class TensorizerConfig(MutableMapping): tc_dict = {} for k, v in raw_tc_dict.items(): - if (k not in blacklisted and k not in tc_dict - and not k.startswith("_") and v is not None): + if ( + k not in blacklisted + and k not in tc_dict + and not k.startswith("_") + and v is not None + ): tc_dict[k] = v return tc_dict @@ -292,26 +310,25 @@ class TensorizerConfig(MutableMapping): self, parallel_config: "ParallelConfig", ) -> None: - if parallel_config.tensor_parallel_size > 1 \ - and not self._is_sharded: + if parallel_config.tensor_parallel_size > 1 and not self._is_sharded: raise ValueError( "For a sharded model, tensorizer_uri should include a" " string format template like '%04d' to be formatted" - " with the rank of the shard") + " with the rank of the shard" + ) def verify_with_model_config(self, model_config: "ModelConfig") -> None: - if (model_config.quantization is not None - and self.tensorizer_uri is not None): + if model_config.quantization is not None and self.tensorizer_uri is not None: logger.warning( "Loading a model using Tensorizer with quantization on vLLM" - " is unstable and may lead to errors.") + " is unstable and may lead to errors." 
+ ) def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): if tensorizer_args is None: tensorizer_args = self._construct_tensorizer_args() - return open_stream(self.tensorizer_uri, - **tensorizer_args.stream_kwargs) + return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs) def keys(self): return self._keys @@ -353,34 +370,36 @@ class TensorizerArgs: for k, v in tensorizer_config.items(): setattr(self, k, v) self.file_obj = tensorizer_config.tensorizer_uri - self.s3_access_key_id = (tensorizer_config.s3_access_key_id - or envs.S3_ACCESS_KEY_ID) - self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key - or envs.S3_SECRET_ACCESS_KEY) + self.s3_access_key_id = ( + tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID + ) + self.s3_secret_access_key = ( + tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY + ) self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL self.stream_kwargs = { "s3_access_key_id": tensorizer_config.s3_access_key_id, "s3_secret_access_key": tensorizer_config.s3_secret_access_key, "s3_endpoint": tensorizer_config.s3_endpoint, - **(tensorizer_config.stream_kwargs or {}) + **(tensorizer_config.stream_kwargs or {}), } self.deserialization_kwargs = { "verify_hash": tensorizer_config.verify_hash, "encryption": tensorizer_config.encryption_keyfile, "num_readers": tensorizer_config.num_readers, - **(tensorizer_config.deserialization_kwargs or {}) + **(tensorizer_config.deserialization_kwargs or {}), } if self.encryption_keyfile: with open_stream( - tensorizer_config.encryption_keyfile, - **self.stream_kwargs, + tensorizer_config.encryption_keyfile, + **self.stream_kwargs, ) as stream: key = stream.read() decryption_params = DecryptionParams.from_key(key) - self.deserialization_kwargs['encryption'] = decryption_params + self.deserialization_kwargs["encryption"] = decryption_params @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -388,17 +407,20 @@ class TensorizerArgs: # Tensorizer options arg group group = parser.add_argument_group( - 'tensorizer options', - description=('Options for configuring the behavior of the' - ' tensorizer deserializer when ' - 'load_format=tensorizer is specified when ' - 'initializing an LLMEngine, either via the CLI ' - 'when running the vLLM OpenAI inference server ' - 'with a JSON string passed to ' - '--model-loader-extra-config or as arguments given ' - 'to TensorizerConfig when passed to ' - 'model_loader_extra_config in the constructor ' - 'for LLMEngine.')) + "tensorizer options", + description=( + "Options for configuring the behavior of the" + " tensorizer deserializer when " + "load_format=tensorizer is specified when " + "initializing an LLMEngine, either via the CLI " + "when running the vLLM OpenAI inference server " + "with a JSON string passed to " + "--model-loader-extra-config or as arguments given " + "to TensorizerConfig when passed to " + "model_loader_extra_config in the constructor " + "for LLMEngine." + ), + ) group.add_argument( "--tensorizer-uri", @@ -418,7 +440,8 @@ class TensorizerArgs: type=str, default=None, help="The file path to a binary file containing a binary key to " - "use for decryption. Can be a file path or S3 network URI.") + "use for decryption. Can be a file path or S3 network URI.", + ) group.add_argument( "--num-readers", default=None, @@ -426,7 +449,8 @@ class TensorizerArgs: help="Controls how many threads are allowed to read concurrently " "from the source file. 
Default is `None`, which will dynamically " "set the number of readers based on the available resources " - "and model size. This greatly increases performance.") + "and model size. This greatly increases performance.", + ) group.add_argument( "--s3-access-key-id", type=str, @@ -454,72 +478,81 @@ class TensorizerArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": attrs = [attr.name for attr in dataclasses.fields(cls)] - tensorizer_args = cls(**{ - attr: getattr(args, attr) - for attr in attrs if hasattr(args, attr) - }) + tensorizer_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return tensorizer_args def _check_tensors_on_meta_device(model: nn.Module) -> None: for tensor in model.state_dict().values(): - if tensor.device.type == 'meta': + if tensor.device.type == "meta": raise ValueError( "The serialized model contains tensors on the meta device," " indicating that some tensors were not loaded properly." " Please check that the parameters of the model being" " specified match that of the serialized model, such as" - " its quantization.") + " its quantization." + ) def _resize_lora_embeddings(model: nn.Module): """Modify LoRA embedding layers to use bigger tensors to allow for adapter added tokens.""" for child in model.modules(): - if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0] - < child.num_embeddings_per_partition): - new_weight = torch.empty(child.num_embeddings_per_partition, - child.embedding_dim, - dtype=child.weight.dtype, - device=child.weight.device) - new_weight[:child.weight.shape[0]].copy_(child.weight.data) - new_weight[child.weight.shape[0]:].fill_(0) + if ( + isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < child.num_embeddings_per_partition + ): + new_weight = torch.empty( + child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device, + ) + new_weight[: child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0] :].fill_(0) child.weight.data = new_weight -def init_tensorizer_model(tensorizer_config: TensorizerConfig, - vllm_config: VllmConfig) -> nn.Module: +def init_tensorizer_model( + tensorizer_config: TensorizerConfig, vllm_config: VllmConfig +) -> nn.Module: assert tensorizer_config.hf_config is not None model_args = tensorizer_config.hf_config model_args.torch_dtype = tensorizer_config.dtype assert tensorizer_config.model_class is not None # TODO: Do we need to consider old-style model class? - with meta_tensor_mode(), set_current_vllm_config(vllm_config, - check_compile=True): + with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True): return tensorizer_config.model_class(vllm_config=vllm_config) -def deserialize_tensorizer_model(model: nn.Module, - tensorizer_config: TensorizerConfig) -> None: +def deserialize_tensorizer_model( + model: nn.Module, tensorizer_config: TensorizerConfig +) -> None: tensorizer_args = tensorizer_config._construct_tensorizer_args() if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri): raise ValueError( f"{tensorizer_config.tensorizer_uri} is not a valid " f"tensorizer URI. Please check that the URI is correct. " f"It must either point to a local existing file, or have a " - f"S3, HTTP or HTTPS scheme.") + f"S3, HTTP or HTTPS scheme." 
+ ) before_mem = get_mem_usage() start = time.perf_counter() - with open_stream( - tensorizer_config.tensorizer_uri, - mode="rb", - **tensorizer_args.stream_kwargs) as stream, TensorDeserializer( - stream, - dtype=tensorizer_config.dtype, - device=f'xpu:{torch.xpu.current_device()}' - if current_platform.is_xpu() else - f'cuda:{torch.cuda.current_device()}', - **tensorizer_args.deserialization_kwargs) as deserializer: + with ( + open_stream( + tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs + ) as stream, + TensorDeserializer( + stream, + dtype=tensorizer_config.dtype, + device=f"xpu:{torch.xpu.current_device()}" + if current_platform.is_xpu() + else f"cuda:{torch.cuda.current_device()}", + **tensorizer_args.deserialization_kwargs, + ) as deserializer, + ): deserializer.load_into_module(model) end = time.perf_counter() @@ -528,8 +561,9 @@ def deserialize_tensorizer_model(model: nn.Module, per_second = convert_bytes(deserializer.total_tensor_bytes / duration) after_mem = get_mem_usage() deserializer.close() - logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, - end - start, per_second) + logger.info( + "Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second + ) logger.info("Memory usage before: %s", before_mem) logger.info("Memory usage after: %s", after_mem) @@ -539,20 +573,21 @@ def deserialize_tensorizer_model(model: nn.Module, def tensorizer_weights_iterator( - tensorizer_args: "TensorizerArgs" + tensorizer_args: "TensorizerArgs", ) -> Generator[tuple[str, torch.Tensor], None, None]: - logger.warning("Deserializing HuggingFace models is not optimized for " - "loading on vLLM, as tensorizer is forced to load to CPU. " - "Consider deserializing a vLLM model instead for faster " - "load times. See the " - "examples/others/tensorize_vllm_model.py example script " - "for serializing vLLM models.") + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the " + "examples/others/tensorize_vllm_model.py example script " + "for serializing vLLM models." + ) deserializer_args = tensorizer_args.deserialization_kwargs stream_kwargs = tensorizer_args.stream_kwargs stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs) - with TensorDeserializer(stream, **deserializer_args, - device="cpu") as state: + with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: yield from state.items() del state @@ -570,41 +605,54 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: bool: True if the model is a vLLM model, False otherwise. """ tensorizer_args = tensorizer_config._construct_tensorizer_args() - deserializer = TensorDeserializer(open_stream( - tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), - **tensorizer_args.deserialization_kwargs, - lazy_load=True) + deserializer = TensorDeserializer( + open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), + **tensorizer_args.deserialization_kwargs, + lazy_load=True, + ) if tensorizer_config.vllm_tensorized: logger.warning( "Please note that newly serialized vLLM models are automatically " "inferred as vLLM models, so setting vllm_tensorized=True is " - "only necessary for models serialized prior to this change.") + "only necessary for models serialized prior to this change." 
+ ) return True return ".vllm_tensorized_marker" in deserializer def serialize_extra_artifacts( - tensorizer_args: TensorizerArgs, - served_model_name: Union[str, list[str], None]) -> None: + tensorizer_args: TensorizerArgs, served_model_name: Union[str, list[str], None] +) -> None: if not isinstance(served_model_name, str): raise ValueError( f"served_model_name must be a str for serialize_extra_artifacts, " - f"not {type(served_model_name)}.") + f"not {type(served_model_name)}." + ) with tempfile.TemporaryDirectory() as tmpdir: - snapshot_download(served_model_name, - local_dir=tmpdir, - ignore_patterns=[ - "*.pt", "*.safetensors", "*.bin", "*.cache", - "*.gitattributes", "*.md" - ]) + snapshot_download( + served_model_name, + local_dir=tmpdir, + ignore_patterns=[ + "*.pt", + "*.safetensors", + "*.bin", + "*.cache", + "*.gitattributes", + "*.md", + ], + ) for artifact in os.scandir(tmpdir): if not artifact.is_file(): continue - with open(artifact.path, "rb") as f, open_stream( + with ( + open(artifact.path, "rb") as f, + open_stream( f"{tensorizer_args.tensorizer_dir}/{artifact.name}", mode="wb+", - **tensorizer_args.stream_kwargs) as stream: + **tensorizer_args.stream_kwargs, + ) as stream, + ): logger.info("Writing artifact %s", artifact.name) stream.write(f.read()) @@ -616,7 +664,8 @@ def serialize_vllm_model( ) -> nn.Module: model.register_parameter( "vllm_tensorized_marker", - nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False), + ) tensorizer_args = tensorizer_config._construct_tensorizer_args() @@ -629,13 +678,17 @@ def serialize_vllm_model( output_file = tensorizer_args.tensorizer_uri if tensorizer_config._is_sharded: from vllm.distributed import get_tensor_model_parallel_rank + output_file = output_file % get_tensor_model_parallel_rank() - with open_stream(output_file, mode="wb+", - **tensorizer_args.stream_kwargs) as stream: - serializer = TensorSerializer(stream, - encryption=encryption_params, - **tensorizer_config.serialization_kwargs) + with open_stream( + output_file, mode="wb+", **tensorizer_args.stream_kwargs + ) as stream: + serializer = TensorSerializer( + stream, + encryption=encryption_params, + **tensorizer_config.serialization_kwargs, + ) serializer.write_module(model) serializer.close() @@ -645,51 +698,47 @@ def serialize_vllm_model( return model -def tensorize_vllm_model(engine_args: "EngineArgs", - tensorizer_config: TensorizerConfig, - generate_keyfile: bool = True): +def tensorize_vllm_model( + engine_args: "EngineArgs", + tensorizer_config: TensorizerConfig, + generate_keyfile: bool = True, +): """Utility to load a model and then serialize it with Tensorizer - Intended to be used separately from running a vLLM server since it - creates its own Engine instance. + Intended to be used separately from running a vLLM server since it + creates its own Engine instance. 
""" engine_config = engine_args.create_engine_config() tensorizer_config.verify_with_model_config(engine_config.model_config) - tensorizer_config.verify_with_parallel_config( - engine_config.parallel_config) + tensorizer_config.verify_with_parallel_config(engine_config.parallel_config) # generate the encryption key before creating the engine to support sharding - if generate_keyfile and (keyfile := - tensorizer_config.encryption_keyfile) is not None: + if ( + generate_keyfile + and (keyfile := tensorizer_config.encryption_keyfile) is not None + ): encryption_params = EncryptionParams.random() with open_stream( - keyfile, - mode="wb+", - s3_access_key_id=tensorizer_config.s3_access_key_id, - s3_secret_access_key=tensorizer_config.s3_secret_access_key, - s3_endpoint=tensorizer_config.s3_endpoint, + keyfile, + mode="wb+", + s3_access_key_id=tensorizer_config.s3_access_key_id, + s3_secret_access_key=tensorizer_config.s3_secret_access_key, + s3_endpoint=tensorizer_config.s3_endpoint, ) as stream: stream.write(encryption_params.key) - from vllm import LLMEngine - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + assert envs.VLLM_USE_V1 - if not envs.VLLM_USE_V1: - engine = LLMEngine.from_engine_args(engine_args) - engine.model_executor.collective_rpc( - "save_tensorized_model", - kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, - ) - else: - engine = V1LLMEngine.from_vllm_config(engine_config) - engine.collective_rpc( - "save_tensorized_model", - kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, - ) + from vllm.v1.engine.llm_engine import LLMEngine + + engine = LLMEngine.from_vllm_config(engine_config) + engine.collective_rpc( + "save_tensorized_model", + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, + ) -def tensorize_lora_adapter(lora_path: str, - tensorizer_config: TensorizerConfig): +def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig): """ Uses tensorizer to serialize a LoRA adapter. 
Assumes that the files needed to load a LoRA adapter are a safetensors-format file called @@ -725,19 +774,20 @@ def tensorize_lora_adapter(lora_path: str, tensorizer_args = tensorizer_config._construct_tensorizer_args() - with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json", - mode="wb+", - **tensorizer_args.stream_kwargs) as f: - + with open_stream( + f"{tensorizer_config.tensorizer_dir}/adapter_config.json", + mode="wb+", + **tensorizer_args.stream_kwargs, + ) as f: f.write(json.dumps(config).encode("utf-8")) - lora_uri = (f"{tensorizer_config.tensorizer_dir}" - f"/adapter_model.tensors") - with open_stream(lora_uri, mode="wb+", - **tensorizer_args.stream_kwargs) as f: + lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors" + with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f: serializer = TensorSerializer(f) serializer.write_state_dict(tensors) serializer.close() - logger.info("Successfully serialized LoRA files to %s", - str(tensorizer_config.tensorizer_dir)) + logger.info( + "Successfully serialized LoRA files to %s", + str(tensorizer_config.tensorizer_dir), + ) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4cee..5585a74f8926e 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -8,15 +8,23 @@ from typing import Union import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, - is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (get_model_architecture, - initialize_model, - set_default_torch_dtype) + TensorizerConfig, + deserialize_tensorizer_model, + init_tensorizer_model, + is_vllm_tensorized, + serialize_vllm_model, + tensorizer_weights_iterator, +) +from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + initialize_model, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -43,15 +51,18 @@ class TensorizerLoader(BaseModelLoader): else: validate_config(load_config.model_loader_extra_config) self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config["tensorizer_config"]) + **load_config.model_loader_extra_config["tensorizer_config"] + ) - def _verify_config(self, model_config: ModelConfig, - parallel_config: ParallelConfig): + def _verify_config( + self, model_config: ModelConfig, parallel_config: ParallelConfig + ): self.tensorizer_config.verify_with_model_config(model_config) self.tensorizer_config.verify_with_parallel_config(parallel_config) def _get_weights_iterator( - self, ) -> Generator[tuple[str, torch.Tensor], None, None]: + self, + ) -> Generator[tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) @@ -81,8 +92,7 @@ class TensorizerLoader(BaseModelLoader): with self.tensorizer_config.open_stream(): pass - def _patch_tensorizer_config( - self, model_config: ModelConfig) -> TensorizerConfig: + def 
_patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig: model_class = get_model_architecture(model_config)[0] tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -90,8 +100,7 @@ class TensorizerLoader(BaseModelLoader): tensorizer_config.dtype = model_config.dtype return tensorizer_config - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: """Load serialized model weights with tensorizer. Expects a vLLM-tensorized model. See the @@ -103,8 +112,9 @@ class TensorizerLoader(BaseModelLoader): else: model.load_weights(self._get_weights_iterator()) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -112,8 +122,8 @@ class TensorizerLoader(BaseModelLoader): from vllm.distributed import get_tensor_model_parallel_rank self.tensorizer_config.tensorizer_uri = ( - self.tensorizer_config.tensorizer_uri % - get_tensor_model_parallel_rank()) + self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank() + ) if is_vllm_tensorized(self.tensorizer_config): tensorizer_config = self._patch_tensorizer_config(model_config) @@ -121,8 +131,8 @@ class TensorizerLoader(BaseModelLoader): with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = init_tensorizer_model( - tensorizer_config=tensorizer_config, - vllm_config=vllm_config) + tensorizer_config=tensorizer_config, vllm_config=vllm_config + ) self.load_weights(model, model_config) return model return self._load_model_serialized_cpu(vllm_config=vllm_config) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index a70cdeb483e67..fc97003de8e3c 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -13,7 +13,10 @@ from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model from vllm.logger import init_logger from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -34,33 +37,31 @@ class TPUModelLoader(DefaultModelLoader): self.counter_before_loading_weights = time.perf_counter() model_config = vllm_config.model_config assert model_config.quantization is None, "Quantization not supported" - target_device = torch.device('cpu') + target_device = torch.device("cpu") with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model(vllm_config=vllm_config) load_format = vllm_config.load_config.load_format if load_format != "dummy": - weights_to_load = { - name - for name, _ in model.named_parameters() - } + weights_to_load = {name for name, _ in model.named_parameters()} all_weights = self.get_all_weights(model_config, model) loaded_weights = model.load_weights(all_weights) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) + self.counter_after_loading_weights + - 
self.counter_before_loading_weights, + ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. - if model_config.quantization is None and \ - loaded_weights is not None: + if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: raise ValueError( "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + f"checkpoint: {weights_not_loaded}" + ) else: logger.info("Use dummy weight during weight loading.") @@ -68,11 +69,13 @@ class TPUModelLoader(DefaultModelLoader): counter_before_partition = time.perf_counter() model = model.eval() - model = model.to('xla') + model = model.to("xla") shard_model(model, mesh) counter_after_partition = time.perf_counter() - logger.info("Partition model took %.2f seconds", - counter_after_partition - counter_before_partition) + logger.info( + "Partition model took %.2f seconds", + counter_after_partition - counter_before_partition, + ) # Ensure the model is properly loaded. self._check_model_is_loaded(mesh, model) @@ -82,12 +85,12 @@ class TPUModelLoader(DefaultModelLoader): if not model_config.is_multimodal_model: model.model = torch.compile(model.model, backend="openxla") else: - model.language_model.model = \ - torch.compile(model.language_model.model, backend="openxla") + model.language_model.model = torch.compile( + model.language_model.model, backend="openxla" + ) return model - def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], - model: nn.Module) -> None: + def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], model: nn.Module) -> None: """ Ensure the model is properly loaded. 1. All model parameters and buffers are on XLA device. @@ -99,16 +102,18 @@ class TPUModelLoader(DefaultModelLoader): # Check parameters for name, param in model.named_parameters(): assert param.device.type == device_type, ( - f"Parameter {name} is on {param.device.type} " - f"instead of {device_type}") + f"Parameter {name} is on {param.device.type} instead of {device_type}" + ) # Check buffers for name, buffer in model.named_buffers(): assert buffer.device.type == device_type, ( - f"Buffer {name} is on {buffer.device.type} " - f"instead of {device_type}") + f"Buffer {name} is on {buffer.device.type} instead of {device_type}" + ) for module in model.modules(): - if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): - raise AssertionError("QKVParallelLinear should be replaced by \ - XlaQKVParallelLinear under SPMD mode.") + if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"): + raise AssertionError( + "QKVParallelLinear should be replaced by \ + XlaQKVParallelLinear under SPMD mode." 
+ ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f57ebdb1abcbc..5ae32f1d120c0 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for selecting and loading models.""" + import contextlib import inspect import warnings @@ -13,16 +14,20 @@ from torch import nn from typing_extensions import assert_never from vllm.attention import Attention -from vllm.config import (ModelConfig, ModelImpl, VllmConfig, - set_current_vllm_config) +from vllm.attention.layer import MLAAttention +from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model, - as_seq_cls_model) -from vllm.model_executor.models.interfaces import SupportsQuant + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.models.adapters import ( + as_embedding_model, + as_reward_model, + as_seq_cls_model, + try_create_mm_pooling_model_cls, +) +from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -57,16 +62,16 @@ def initialize_model( all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - with set_current_vllm_config(vllm_config, - check_compile=True, - prefix=prefix): + with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix): return model_class(vllm_config=vllm_config, prefix=prefix) - msg = ("vLLM model class should accept `vllm_config` and `prefix` as " - "input arguments. Possibly you have an old-style model class" - " registered from out of tree and it is used for new vLLM version. " - "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " - "for the design and update the model class accordingly.") + msg = ( + "vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered from out of tree and it is used for new vLLM version. " + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly." 
+ ) warnings.warn(msg, DeprecationWarning, stacklevel=2) logger.warning( @@ -87,20 +92,21 @@ def initialize_model( kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config, - check_compile=True, - prefix=prefix): + with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix): return model_class(**kwargs) -def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, - target_device: torch.device) -> None: +def process_weights_after_loading( + model: nn.Module, model_config: ModelConfig, target_device: torch.device +) -> None: + # to avoid circular dependency + from vllm.model_executor.model_loader.online_quantization import ( + maybe_save_metadata_and_attributes_for_weight_reloading, + ) + + maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) + for _, module in model.named_modules(): - if isinstance(module, QKVCrossParallelLinear): - # NOTE(Isotr0py): special case for cross QKV layer because - # q and kv proj aren't registered as submodules intentionally - module.process_weights_after_loading() - continue quant_method = getattr(module, "quant_method", None) if isinstance(quant_method, QuantizeMethodBase): # When quant methods need to process weights after loading @@ -111,20 +117,19 @@ def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - # Currently only used by MLA. - # NOTE: This intentionally happens after other modules so we can easily - # decompress the weights for MLA. + # Initialize post-load attention weights for both Attention and MLA. + # NOTE: Happens after other modules so we can easily decompress weights. for _, module in model.named_modules(): - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): + if isinstance(module, (Attention, MLAAttention)) and hasattr( + module, "process_weights_after_loading" + ): # TODO(lucas): see if there is a way to unify the signatures # of process_weights_after_loading module.process_weights_after_loading(model_config.dtype) @contextmanager -def device_loading_context(module: torch.nn.Module, - target_device: torch.device): +def device_loading_context(module: torch.nn.Module, target_device: torch.device): if target_device.type == "cpu": # If target is CPU, no need to move anything yield module @@ -165,40 +170,38 @@ def device_loading_context(module: torch.nn.Module, # New parameters or parameters already on target device are untouched -def get_model_architecture( - model_config: ModelConfig) -> tuple[type[nn.Module], str]: +_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() +"""Caches the outputs of `_get_model_architecture`.""" + + +def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = [ - "fp8", - "compressed-tensors", - "gptq_marlin", - "awq_marlin", - "quark", - "bitsandbytes", - ] - - if (model_config.quantization is not None - and model_config.quantization not in mixtral_supported - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - model_cls, arch = model_config.registry.resolve_model_cls( architectures, model_config=model_config, ) if arch == model_config._get_transformers_backend_cls(): - assert model_config.model_impl != ModelImpl.VLLM - if model_config.model_impl == ModelImpl.AUTO: + assert model_config.model_impl != "vllm" + if model_config.model_impl == "auto": logger.warning_once( "%s has no vLLM implementation, falling back to Transformers " "implementation. Some features may not be supported and " - "performance may not be optimal.", arch) + "performance may not be optimal.", + arch, + ) convert_type = model_config.convert_type + if convert_type != "none" and supports_multimodal(model_cls): + logger.debug_once("Detected conversion of Multi Modal model.") + converted = try_create_mm_pooling_model_cls(model_cls) + if converted is not None: + logger.debug_once("Creating wrapper class to forward pooler.") + return converted, arch + else: + logger.debug_once("Attempting direct conversion.") + if convert_type == "none": pass elif convert_type == "embed": @@ -216,6 +219,25 @@ def get_model_architecture( return model_cls, arch +def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: + key = hash( + ( + model_config.model, + model_config.convert_type, + model_config.runner_type, + model_config.trust_remote_code, + model_config.model_impl, + tuple(getattr(model_config.hf_config, "architectures", [])), + ) + ) + if key in _MODEL_ARCH_BY_HASH: + return _MODEL_ARCH_BY_HASH[key] + + model_arch = _get_model_architecture(model_config) + _MODEL_ARCH_BY_HASH[key] = model_arch + return model_arch + + def get_model_cls(model_config: ModelConfig) -> type[nn.Module]: return get_model_architecture(model_config)[0] @@ -228,12 +250,12 @@ def get_architecture_class_name(model_config: ModelConfig) -> str: class ParamMapping: """ A class to handle parameter mapping for model weight loading. - It creates a bidirectional mapping between packed parameters and their + It creates a bidirectional mapping between packed parameters and their constituent parts. 
""" + packed_mapping: dict[str, list[str]] - inverse_packed_mapping: dict[str, tuple[str, - int]] = field(default_factory=dict) + inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict) def __post_init__(self): for packed_name, sub_params in self.packed_mapping.items(): @@ -246,16 +268,16 @@ class ParamMapping: index, ) - def get_sub_modules(self, - module_name: str) -> Optional[tuple[str, list[str]]]: + def get_sub_modules(self, module_name: str) -> Optional[tuple[str, list[str]]]: for key, value in self.packed_mapping.items(): if module_name.endswith(key): return key, value return None -def configure_quant_config(quant_config: QuantizationConfig, - model_class: type[nn.Module]): +def configure_quant_config( + quant_config: QuantizationConfig, model_class: type[nn.Module] +): """ Pass packed_modules_mapping by reference to quant_config so that quant_config can properly match fused modules diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 3bb47f82d2f37..5f83482bec3a0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for downloading and initializing model weights.""" + +import concurrent.futures import fnmatch import glob import hashlib @@ -10,32 +12,35 @@ import tempfile import time from collections import defaultdict from collections.abc import Generator +from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import IO, Any, Callable, Optional, Union import filelock import huggingface_hub.constants import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download -from safetensors.torch import load_file, safe_open, save_file +from safetensors.torch import load, load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.config import LoadConfig, ModelConfig +from vllm import envs +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QuantizationConfig, - get_quantization_config) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + get_quantization_config, +) from vllm.platforms import current_platform from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer except ImportError: - runai_model_streamer = PlaceholderModule( - "runai_model_streamer") # type: ignore[assignment] - SafetensorsStreamer = runai_model_streamer.placeholder_attr( - "SafetensorsStreamer") + runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] + SafetensorsStreamer = runai_model_streamer.placeholder_attr("SafetensorsStreamer") try: import gguf @@ -46,10 +51,11 @@ try: from fastsafetensors import SafeTensorsFileLoader, SingleGroup except ImportError: fastsafetensors = PlaceholderModule("fastsafetensors") - SafeTensorsFileLoader = fastsafetensors.placeholder_attr( - "SafeTensorsFileLoader") + SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader") SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") +from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least + logger = 
init_logger(__name__) # use system-level temp directory for file locks, so that multiple users @@ -60,12 +66,12 @@ temp_dir = tempfile.gettempdir() def enable_hf_transfer(): - """automatically activates hf_transfer - """ + """automatically activates hf_transfer""" if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True except ImportError: pass @@ -75,13 +81,11 @@ enable_hf_transfer() class DisabledTqdm(tqdm): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) -def get_lock(model_name_or_path: Union[str, Path], - cache_dir: Optional[str] = None): +def get_lock(model_name_or_path: Union[str, Path], cache_dir: Optional[str] = None): lock_dir = cache_dir or temp_dir model_name_or_path = str(model_name_or_path) os.makedirs(os.path.dirname(lock_dir), exist_ok=True) @@ -90,11 +94,88 @@ def get_lock(model_name_or_path: Union[str, Path], # add hash to avoid conflict with old users' lock files lock_file_name = hash_name + model_name + ".lock" # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), - mode=0o666) + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) return lock +@contextmanager +def atomic_writer( + filepath: Union[str, Path], mode: str = "w", encoding: Optional[str] = None +) -> Generator[IO]: + """ + Context manager that provides an atomic file writing routine. + + The context manager writes to a temporary file and, if successful, + atomically replaces the original file. + + Args: + filepath (str or Path): The path to the file to write. + mode (str): The file mode for the temporary file (e.g., 'w', 'wb'). + encoding (str): The encoding for text mode. + + Yields: + file object: A handle to the temporary file. + """ + # Create a temporary file in the same directory as the target file + # to ensure it's on the same filesystem for an atomic replace. + temp_dir = os.path.dirname(filepath) + temp_fd, temp_path = tempfile.mkstemp(dir=temp_dir) + + try: + # Open the temporary file for writing + with os.fdopen(temp_fd, mode=mode, encoding=encoding) as temp_file: + yield temp_file + + # If the 'with' block completes successfully, + # perform the atomic replace. + os.replace(temp_path, filepath) + + except Exception: + logger.exception( + "Error during atomic write. Original file '%s' not modified", filepath + ) + raise + finally: + # Clean up the temporary file if it still exists. + if os.path.exists(temp_path): + os.remove(temp_path) + + +def maybe_download_from_modelscope( + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, +) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
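As a usage illustration of the `atomic_writer` context manager added above (the file name and payload here are hypothetical, not part of the patch): the caller writes to the yielded temporary file, and the target path is only replaced once the block exits cleanly, so a crash mid-write never leaves a partially written file behind.

import json
from vllm.model_executor.model_loader.weight_utils import atomic_writer

# Hedged sketch: persist a small JSON blob atomically; on any exception the
# original file (if it exists) is left untouched.
with atomic_writer("weight_metadata.json", mode="w", encoding="utf-8") as f:
    json.dump({"format": "safetensors", "shards": 3}, f)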
+ with get_lock(model, download_dir): + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=ignore_patterns, + allow_patterns=allow_patterns, + ) + else: + model_path = model + return model_path + return None + + def _shared_pointers(tensors): ptrs = defaultdict(list) for k, v in tensors.items(): @@ -144,9 +225,9 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. -def get_quant_config(model_config: ModelConfig, - load_config: LoadConfig) -> QuantizationConfig: - +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) # GGUF doesn't have config file @@ -154,22 +235,55 @@ def get_quant_config(model_config: ModelConfig, return quant_cls() # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", - None) + hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) # some vision model may keep quantization_config in their text_config hf_text_config = getattr(model_config.hf_config, "text_config", None) if hf_quant_config is None and hf_text_config is not None: hf_quant_config = getattr(hf_text_config, "quantization_config", None) if hf_quant_config is None: # compressed-tensors uses a compressions_config - hf_quant_config = getattr(model_config.hf_config, "compression_config", - None) + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) + + # if hf_quant_config is None, we will try to get config from + # hf_overrides + hf_overrides = model_config.hf_overrides + quantization_config_file = hf_overrides.get("quantization_config_file", None) + if quantization_config_file is not None: + if hasattr(quant_cls, "from_config_file"): + return quant_cls.from_config_file(quantization_config_file) + else: + raise NotImplementedError( + "from_config_file is specified in hf_override config, " + "but quant_cls.from_config_file is not implemented in " + f"{quant_cls}" + ) + quantization_config_json = hf_overrides.get("quantization_config_dict_json", None) + if quantization_config_json is not None: + if hasattr(quant_cls, "from_config_dict_json"): + return quant_cls.from_config_dict_json(quantization_config_json) + else: + raise NotImplementedError( + "from_config_dict_json is specified in hf_override config, " + "but quant_cls.from_config_dict_json is not implemented in " + f"{quant_cls}" + ) + # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - is_local = os.path.isdir(model_config.model) + model_name_or_path = ( + maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) + or model_config.model + ) + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. 
with get_lock(model_config.model, load_config.download_dir): @@ -182,7 +296,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_config.model + hf_folder = model_name_or_path possible_config_filenames = quant_cls.get_config_filenames() @@ -193,16 +307,15 @@ def get_quant_config(model_config: ModelConfig, config_files = glob.glob(os.path.join(hf_folder, "*.json")) quant_config_files = [ - f for f in config_files if any( - f.endswith(x) for x in possible_config_filenames) + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) ] if len(quant_config_files) == 0: - raise ValueError( - f"Cannot find the config file for {model_config.quantization}") + raise ValueError(f"Cannot find the config file for {model_config.quantization}") if len(quant_config_files) > 1: raise ValueError( f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}") + f"{quant_config_files}" + ) quant_config_file = quant_config_files[0] with open(quant_config_file) as f: @@ -216,7 +329,8 @@ def get_quant_config(model_config: ModelConfig, else: raise ValueError( f"Unsupported quantization config" - f" found for {model_config.quantization} in {f}.") + f" found for {model_config.quantization} in {f}." + ) return quant_cls.from_config(config) @@ -278,37 +392,56 @@ def download_weights_from_hf( Returns: str: The path to the downloaded model weights. """ + assert len(allow_patterns) > 0 local_only = huggingface_hub.constants.HF_HUB_OFFLINE if not local_only: - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + # Attempt to reduce allow_patterns to a single pattern + # so we only have to call snapshot_download once. + try: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] + # Use the first pattern found in the HF repo's files. + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] break + except Exception as e: + logger.warning( + "Failed to get file list for '%s'. Trying each pattern in " + "allow_patterns individually until weights have been " + "downloaded. Error: %s", + model_name_or_path, + e, + ) logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): start_time = time.perf_counter() - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision, - local_files_only=local_only, - ) + for allow_pattern in allow_patterns: + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_pattern, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=local_only, + ) + # If we have downloaded weights for this allow_pattern, + # we don't need to check the rest. 
+ if any(Path(hf_folder).glob(allow_pattern)): + break time_taken = time.perf_counter() - start_time if time_taken > 0.5: - logger.info("Time spent downloading weights for %s: %.6f seconds", - model_name_or_path, time_taken) + logger.info( + "Time spent downloading weights for %s: %.6f seconds", + model_name_or_path, + time_taken, + ) return hf_folder @@ -352,9 +485,9 @@ def download_safetensors_index_file_from_hf( # Passing both of these to the weight loader functionality breaks. # So, we use the index_file to # look up which safetensors files should be used. -def filter_duplicate_safetensors_files(hf_weights_files: list[str], - hf_folder: str, - index_file: str) -> list[str]: +def filter_duplicate_safetensors_files( + hf_weights_files: list[str], hf_folder: str, index_file: str +) -> list[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. index_file_name = os.path.join(hf_folder, index_file) @@ -367,17 +500,13 @@ def filter_duplicate_safetensors_files(hf_weights_files: list[str], weight_map = json.load(f)["weight_map"] weight_files_in_index = set() for weight_name in weight_map: - weight_files_in_index.add( - os.path.join(hf_folder, weight_map[weight_name])) + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) # Filter out any fields that are not found in the index file. - hf_weights_files = [ - f for f in hf_weights_files if f in weight_files_in_index - ] + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] return hf_weights_files -def filter_files_not_needed_for_inference( - hf_weights_files: list[str]) -> list[str]: +def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]: """ Exclude files that are not needed for inference. 
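The download fallback introduced in the hunk above (try each `allow_pattern` with `snapshot_download` and stop once matching files actually land on disk, instead of relying solely on the repo file listing) can be summarized by this standalone sketch; the helper name and pattern list are illustrative only.

from pathlib import Path
from huggingface_hub import snapshot_download

def download_first_matching(repo_id: str, patterns: list[str], cache_dir: str) -> str:
    # Try each weight-file pattern in priority order; keep the first snapshot
    # folder that actually contains files matching the requested pattern.
    folder = ""
    for pattern in patterns:  # e.g. ["*.safetensors", "*.bin", "*.pt"]
        folder = snapshot_download(
            repo_id,
            allow_patterns=pattern,
            cache_dir=cache_dir,
        )
        if any(Path(folder).glob(pattern)):
            break
    return folder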
@@ -391,8 +520,7 @@ def filter_files_not_needed_for_inference( "scaler.pt", ] hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) ] return hf_weights_files @@ -405,8 +533,9 @@ _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elap def enable_tqdm(use_tqdm_on_load: bool): - return use_tqdm_on_load and (not torch.distributed.is_initialized() - or torch.distributed.get_rank() == 0) + return use_tqdm_on_load and ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) def np_cache_weights_iterator( @@ -431,14 +560,12 @@ def np_cache_weights_iterator( if not os.path.exists(weight_names_file): weight_names: list[str] = [] for bin_file in tqdm( - hf_weights_files, - desc="Loading np_cache checkpoint shards", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, - map_location="cpu", - weights_only=True) + state = torch.load(bin_file, map_location="cpu", weights_only=True) for name, param in state.items(): param_path = os.path.join(np_folder, name) with open(param_path, "wb") as f: @@ -460,18 +587,71 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + safetensors_load_strategy: str = "lazy", ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" + loading_desc = "Loading safetensors checkpoint shards" + if safetensors_load_strategy == "eager": + loading_desc += " (eager)" + for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors checkpoint shards", + hf_weights_files, + desc=loading_desc, + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + if safetensors_load_strategy == "eager": + with open(st_file, "rb") as f: + state_dict = load(f.read()) + yield from state_dict.items() + elif safetensors_load_strategy == "torchao": + if not torchao_version_at_least("0.14.0"): + raise ValueError( + "Please use torchao version >= 0.14.0 \ + to load torchao safetensors checkpoint" + ) + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + + with safe_open(st_file, framework="pt") as f: + state_dict = {} + for name in f.keys(): # noqa: SIM118 + state_dict[name] = f.get_tensor(name) + metadata = f.metadata() + updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata) + yield from updated_state_dict.items() + else: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def multi_thread_safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model safetensor files.""" + + def _load_file(st_file: str): + result = load_file(st_file, device="cpu") + return result + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading shards", disable=not 
enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, - ): - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param + ) + + for future in futures_iter: + state_dict = future.result() + yield from state_dict.items() def runai_safetensors_weights_iterator( @@ -483,7 +663,8 @@ def runai_safetensors_weights_iterator( streamer.stream_files(hf_weights_files) total_tensors = sum( len(tensors_meta) - for tensors_meta in streamer.files_to_tensors_metadata.values()) + for tensors_meta in streamer.files_to_tensors_metadata.values() + ) tensor_iter = tqdm( streamer.get_tensors(), @@ -496,6 +677,19 @@ def runai_safetensors_weights_iterator( yield from tensor_iter +def _init_loader( + pg: torch.distributed.ProcessGroup, + device: torch.device, + f_list: list[str], + *, + nogds: bool = False, +): + loader = SafeTensorsFileLoader(pg, device, nogds=nogds) + rank_file_map = {i: [f] for i, f in enumerate(f_list)} + loader.add_filenames(rank_file_map) + return loader + + def fastsafetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -507,23 +701,37 @@ def fastsafetensors_weights_iterator( else: pg = SingleGroup() - device = torch.device(f'cuda:{pg.rank()}') + device = torch.device(f"cuda:{pg.rank()}") weight_files_sub_lists = [ - hf_weights_files[i:i + pg.size()] + hf_weights_files[i : i + pg.size()] for i in range(0, len(hf_weights_files), pg.size()) ] + nogds = False + for f_list in tqdm( - weight_files_sub_lists, - desc="Loading safetensors using Fastsafetensor loader", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + weight_files_sub_lists, + desc="Loading safetensors using Fastsafetensor loader", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): - loader = SafeTensorsFileLoader(pg, device) - rank_file_map = {i: [f] for i, f in enumerate(f_list)} - loader.add_filenames(rank_file_map) + loader = _init_loader(pg, device, f_list, nogds=nogds) try: - fb = loader.copy_files_to_device() + try: + fb = loader.copy_files_to_device() + except RuntimeError as e: + if "gds" not in str(e): + raise + + loader.close() + nogds = True + logger.warning_once( + "GDS not enabled, setting `nogds=True`.\n" + "For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages" + ) + loader = _init_loader(pg, device, f_list, nogds=nogds) + fb = loader.copy_files_to_device() + try: keys = list(fb.key_to_rank_lidx.keys()) for k in keys: @@ -542,20 +750,52 @@ def pt_weights_iterator( ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( - hf_weights_files, - desc="Loading pt checkpoint shards", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, - map_location=pt_load_map_location, - weights_only=True) + state = torch.load( + bin_file, map_location=pt_load_map_location, weights_only=True + ) yield from state.items() del state +def multi_thread_pt_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model bin/pt files.""" + + def _load_file(bin_file: str): + 
return torch.load( + bin_file, map_location=pt_load_map_location, weights_only=True + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, bin_file) for bin_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state = future.result() + yield from state.items() + del state + + def get_gguf_extra_tensor_names( - gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> list[str]: reader = gguf.GGUFReader(gguf_file) expected_gguf_keys = set(gguf_to_hf_name_map.keys()) exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) @@ -564,14 +804,16 @@ def get_gguf_extra_tensor_names( def get_gguf_weight_type_map( - gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> dict[str, str]: """ Return GGUF mapped weight's name and its quant type """ reader = gguf.GGUFReader(gguf_file) return { gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name - for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map + for tensor in reader.tensors + if tensor.name in gguf_to_hf_name_map } @@ -621,8 +863,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: return x -def default_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" try: if param.numel() == 1 and loaded_weight.numel() == 1: @@ -633,7 +874,8 @@ def default_weight_loader(param: torch.Tensor, else: assert param.size() == loaded_weight.size(), ( f"Attempted to load weight ({loaded_weight.size()}) " - f"into parameter ({param.size()})") + f"into parameter ({param.size()})" + ) param.data.copy_(loaded_weight) except Exception: @@ -642,8 +884,9 @@ def default_weight_loader(param: torch.Tensor, raise -def row_parallel_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: +def row_parallel_weight_loader( + param: torch.Tensor, loaded_weight: torch.Tensor +) -> None: """Load weights that are row-parallelized.""" tp_rank = get_tensor_model_parallel_rank() shard_dim = 0 if param.dim() != 1 else None @@ -675,12 +918,11 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction: def composed_weight_loader( - loader: LoaderFunction, fn: Callable[[torch.Tensor], - torch.Tensor]) -> LoaderFunction: + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] +) -> LoaderFunction: """Create a weight loader that post-processes the weights after loading""" - def composed_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: + def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: loader(param, loaded_weight) param.data.copy_(fn(param)) return @@ -716,13 +958,18 @@ def initialize_dummy_weights( # from a CPU tensor. # Note: We avoid using torch.rank_like as it doesn't currently # support the generator argument. 
- param.copy_((high - low) * - torch.rand(param.shape, - generator=generator, - dtype=param.dtype, - layout=param.layout, - requires_grad=param.requires_grad, - device="cpu") + low) + param.copy_( + (high - low) + * torch.rand( + param.shape, + generator=generator, + dtype=param.dtype, + layout=param.layout, + requires_grad=param.requires_grad, + device="cpu", + ) + + low + ) torch._sync(param) continue @@ -732,8 +979,7 @@ def initialize_dummy_weights( # uniform_ doesn't support < 16-bit datatypes (FP8) dtype = param.data.dtype tmp_param = param.data.to(torch.float16) - tmp_param = tmp_param.uniform_(low, high, - generator=generator).to(dtype) + tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) param.data.copy_(tmp_param) else: param.uniform_(low, high, generator=generator) @@ -762,7 +1008,8 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: "This format is deprecated in favor of separate k_scale and " "v_scale tensors and will be removed in a future release. " "Functionally, we will remap kv_scale to k_scale and duplicate " - "k_scale to v_scale") + "k_scale to v_scale" + ) # NOTE: we remap the deprecated kv_scale to k_scale remapped_name = name.replace(".kv_scale", ".attn.k_scale") if remapped_name not in params_dict: @@ -774,19 +1021,28 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return None return remapped_name + if any("mla_attn" in key for key in params_dict): + attn_str = "mla_attn.mla_attn" + logger.debug_once( + f"Found mla_attn with k_scale and v_scale in " + f"the checkpoint, using {attn_str} as attn_str" + ) + else: + attn_str = "attn" # Define scale name mapping patterns in order of precedence scale_mapping_patterns = [ # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale - (r"\.self_attn\.([kv])_proj\.([kv])_scale$", - r".self_attn.attn.\2_scale"), + ( + r"\.self_attn\.([kv])_proj\.([kv])_scale$", + rf".self_attn.{attn_str}.\2_scale", + ), # QKV proj format: .self_attn.qkv_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale - (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale" - ), + (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), # Default format: .{k,v}_scale -> .attn.{k,v}_scale (r"\.([kv])_scale$", r".attn.\1_scale"), ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d3ee6872dd8bf..b56cb33400480 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,12 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, - SupportsPP, SupportsTranscription, SupportsV0Only, - has_inner_state, supports_lora, supports_multimodal, - supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, - is_pooling_model, is_text_generation_model) +from .interfaces import ( + HasInnerState, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, + SupportsTranscription, + SupportsV0Only, + has_inner_state, + supports_lora, + supports_mrope, + supports_multimodal, + supports_pp, + supports_transcription, + supports_v0_only, +) +from .interfaces_base import ( + 
VllmModelForPooling, + VllmModelForTextGeneration, + is_pooling_model, + is_text_generation_model, +) from .registry import ModelRegistry __all__ = [ @@ -21,6 +37,8 @@ __all__ = [ "supports_lora", "SupportsMultiModal", "supports_multimodal", + "SupportsMRoPE", + "supports_mrope", "SupportsPP", "supports_pp", "SupportsTranscription", diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1dbe70f84a626..fd8a0b87e43ec 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -1,21 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import inspect from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig +from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig _T = TypeVar("_T", bound=type[nn.Module]) +logger = init_logger(__name__) + _GENERATE_SUFFIXES = [ "ForCausalLM", "ForConditionalGeneration", @@ -24,6 +32,104 @@ _GENERATE_SUFFIXES = [ ] +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: + """Load Sentence-Transformers Dense projection layers.""" + + try: + modules = get_hf_file_to_dict( + "modules.json", model_config.model, model_config.revision + ) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + layers = [] + for module in dense_modules: + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict( + config_path, model_config.model, model_config.revision + ) + if not layer_config: + continue + + linear = nn.Linear( + layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=model_config.head_dtype, + ) + + if not _load_dense_weights(linear, folder, model_config): + continue + + layers.append(linear) + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).to(dtype=model_config.head_dtype) + except Exception: + logger.exception("ST projector loading failed") + + return None + + +def _load_dense_weights( + linear: nn.Linear, folder: str, model_config: "ModelConfig" +) -> bool: + """Load weights using vLLM's weight_loader pattern.""" + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + for filename in ["model.safetensors", "pytorch_model.bin"]: + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_bytes( + file_path, model_config.model, model_config.revision + ) + if not file_bytes: + continue + + if filename.endswith(".safetensors"): + from safetensors.torch import load as load_safetensors + + state_dict = load_safetensors(file_bytes) + else: + import io + + state_dict = torch.load( + io.BytesIO(file_bytes), map_location="cpu", 
weights_only=True + ) + + for weight_key in ["weight", "linear.weight", "dense.weight"]: + if weight_key in state_dict: + weight_loader = getattr( + linear.weight, "weight_loader", default_weight_loader + ) + weight_loader(linear.weight, state_dict[weight_key]) + + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = getattr( + linear.bias, "weight_loader", default_weight_loader + ) + bias_loader(linear.bias, state_dict[bias_key]) + return True + except Exception: + logger.exception("Failed to load %s", filename) + continue + + return False + + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name @@ -33,12 +139,43 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix +def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: + class CallVisitor(ast.NodeVisitor): + def __init__(self): + self.calls = [] + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + self.calls.append(node.func.id) + self.generic_visit(node) + + visitor = CallVisitor() + visitor.visit(ast.parse(inspect.getsource(orig_cls))) + if "init_vllm_registered_model" not in visitor.calls: + return None + + class ModelForPooling(orig_cls, VllmModelForPooling): + is_pooling_model = True + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self.pooler = self.get_language_model().pooler + + return ModelForPooling # type: ignore + + def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import from .utils import AutoWeightsLoader, WeightsMapper class ModelForPooling(orig_cls, VllmModelForPooling): - is_pooling_model = True def __init__( @@ -68,8 +205,11 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: # TODO: Support uninitialized params tracking # We have deleted this attribute, so don't load it - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) + weights = ( + (name, data) + for name, data in weights + if not name.startswith("lm_head.") + ) # If `*ForCausalLM` defines `load_weights` on the inner model # and there are no other inner modules with parameters, @@ -78,7 +218,8 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: # Whether only `self.model` contains parameters model_is_only_param = all( name == "model" or next(child.parameters(), None) is None - for name, child in self.named_children()) + for name, child in self.named_children() + ) if model_is_only_param: mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) @@ -118,7 +259,6 @@ def as_embedding_model(cls: _T) -> _T: from vllm.model_executor.layers.pooler import DispatchPooler, Pooler class ModelForEmbedding(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -127,10 +267,10 @@ def as_embedding_model(cls: _T) -> _T: { "encode": Pooler.for_encode(pooler_config), "embed": Pooler.for_embed(pooler_config), - }, ) + }, + ) - ModelForEmbedding.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForEmbedding") + ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") return ModelForEmbedding # type: ignore @@ -152,26 +292,30 @@ def as_seq_cls_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.linear 
import RowParallelLinear - from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, PoolingType) + from vllm.model_executor.layers.linear import ReplicatedLinear + from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingType, + ) from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors - from .utils import maybe_prefix - - class ModelForSequenceClassification(_create_pooling_model_cls(cls), - SupportsCrossEncoding): + from .utils import get_model_hidden_size, maybe_prefix + class ModelForSequenceClassification( + _create_pooling_model_cls(cls), SupportsCrossEncoding + ): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + hidden_size = get_model_hidden_size(config) - self.score = RowParallelLinear( - config.hidden_size, + self.score = ReplicatedLinear( + hidden_size, config.num_labels, - input_is_parallel=False, bias=False, params_dtype=torch.float32, quant_config=quant_config, @@ -185,24 +329,25 @@ def as_seq_cls_model(cls: _T) -> _T: assert pooling_type_str is not None pooling_type = PoolingType[pooling_type_str] - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def _classifier(self, x: torch.Tensor): x, _ = self.score(x.float()) @@ -215,8 +360,9 @@ def as_seq_cls_model(cls: _T) -> _T: intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return super().forward(input_ids, positions, intermediate_tensors, - inputs_embeds) + return super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): tokens = getattr(self.config, "classifier_from_token", None) @@ -229,9 +375,9 @@ def as_seq_cls_model(cls: _T) -> _T: # ForSequenceClassification model. 
return seq_cls_model_loader(self, weights) - - ModelForSequenceClassification.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForSequenceClassification") + ModelForSequenceClassification.__name__ = _get_pooling_model_name( + cls.__name__, "ForSequenceClassification" + ) return ModelForSequenceClassification # type: ignore @@ -254,22 +400,20 @@ def as_reward_model(cls: _T) -> _T: from vllm.model_executor.layers.pooler import DispatchPooler, Pooler class ModelForReward(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) - ModelForReward.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForReward") + ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") return ModelForReward # type: ignore class SequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -294,15 +438,15 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig): def load_weights_using_from_2_way_softmax( - model, weights: Iterable[tuple[str, torch.Tensor]]): + model, weights: Iterable[tuple[str, torch.Tensor]] +): # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) - from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader) + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) tokens = cast(list[int], tokens) assert len(tokens) == 2 @@ -310,24 +454,28 @@ def load_weights_using_from_2_way_softmax( if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: - model.lm_head = ParallelLMHead(model.config.vocab_size, - model.config.hidden_size, - quant_config=model.quant_config) + quant_config = model.vllm_config.quant_config + model.lm_head = ParallelLMHead( + model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + ) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) from vllm.transformers_utils.tokenizer import get_tokenizer - tokenizer = get_tokenizer(model_config.tokenizer, - revision=model_config.tokenizer_revision, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code) + + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + ) false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) score_weight = model.lm_head.weight.data[[true_id]].to( - torch.float32) - model.lm_head.weight.data[[false_id]].to( - torch.float32) + torch.float32 + ) - model.lm_head.weight.data[[false_id]].to(torch.float32) param = model.score.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -339,13 +487,9 @@ def load_weights_using_from_2_way_softmax( return loaded_weights -def 
load_weights_no_post_processing(model, - weights: Iterable[tuple[str, - torch.Tensor]]): - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) - from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader) +def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]): + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -356,18 +500,22 @@ def load_weights_no_post_processing(model, if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: - model.lm_head = ParallelLMHead(model.config.vocab_size, - model.config.hidden_size, - quant_config=model.quant_config) + quant_config = model.vllm_config.quant_config + model.lm_head = ParallelLMHead( + model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + ) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) from vllm.transformers_utils.tokenizer import get_tokenizer - tokenizer = get_tokenizer(model_config.tokenizer, - revision=model_config.tokenizer_revision, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code) + + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + ) token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] score_weight = model.lm_head.weight.data[token_ids] diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index b13d863ebb744..2423ad5b0c3ad 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -14,19 +14,20 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.ovis import AIMv2Config class AIMv2SwiGLUFFN(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() hidden_features = config.intermediate_size in_features = config.hidden_size @@ -56,7 +57,6 @@ class AIMv2SwiGLUFFN(nn.Module): class AIMv2PatchEmbed(nn.Module): - def __init__(self, config: AIMv2Config): super().__init__() self.proj = nn.Conv2d( @@ -74,14 +74,12 @@ class AIMv2PatchEmbed(nn.Module): class AIMv2ViTPreprocessor(nn.Module): - def __init__(self, config: AIMv2Config): super().__init__() - num_patches = (config.image_size // config.patch_size)**2 + num_patches = (config.image_size // config.patch_size) ** 2 self.patchifier = AIMv2PatchEmbed(config) - self.pos_embed = 
nn.Parameter( - torch.zeros((1, num_patches, config.hidden_size))) + self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size))) def forward(self, x: torch.Tensor) -> torch.Tensor: tokens = self.patchifier(x) @@ -92,9 +90,9 @@ class AIMv2ViTPreprocessor(nn.Module): class AIMv2Attention(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -104,7 +102,8 @@ class AIMv2Attention(nn.Module): raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.qkv = QKVParallelLinear( @@ -127,8 +126,9 @@ class AIMv2Attention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward(self, x: torch.Tensor) -> torch.Tensor: qkv, _ = self.qkv(x) @@ -140,17 +140,17 @@ class AIMv2Attention(nn.Module): class AIMv2Block(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() - self.attn = AIMv2Attention(config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = AIMv2Attention( + config, quant_config=quant_config, prefix=f"{prefix}.attn" + ) self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = AIMv2SwiGLUFFN(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = AIMv2SwiGLUFFN( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -160,7 +160,6 @@ class AIMv2Block(nn.Module): class AIMv2Transformer(nn.Module): - def __init__( self, config: AIMv2Config, @@ -171,13 +170,14 @@ class AIMv2Transformer(nn.Module): ): super().__init__() - self.blocks = nn.ModuleList([ - AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") - for i in range(config.num_hidden_layers) - ]) + self.blocks = nn.ModuleList( + [ + AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") + for i in range(config.num_hidden_layers) + ] + ) if require_post_norm: - self.post_trunk_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.post_trunk_norm = None @@ -191,29 +191,30 @@ class AIMv2Transformer(nn.Module): class AIMv2Model(torch.nn.Module): - - def __init__(self, - config: AIMv2Config, - quant_config: QuantizationConfig, - *, - require_post_norm: Optional[bool] = None, - prefix: str = ""): + def __init__( + self, + config: AIMv2Config, + quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ): super().__init__() self.preprocessor = AIMv2ViTPreprocessor(config) - self.trunk = AIMv2Transformer(config, - quant_config=quant_config, - require_post_norm=require_post_norm, - prefix=f"{prefix}.trunk") + self.trunk = AIMv2Transformer( + config, + 
quant_config=quant_config, + require_post_norm=require_post_norm, + prefix=f"{prefix}.trunk", + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - x = self.preprocessor(pixel_values) x = self.trunk(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".fc13", ".fc1", 0), @@ -224,11 +225,13 @@ class AIMv2Model(torch.nn.Module): for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if (name.startswith("trunk.post_trunk_norm") - and self.trunk.post_trunk_norm is None): + if ( + name.startswith("trunk.post_trunk_norm") + and self.trunk.post_trunk_norm is None + ): continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -239,8 +242,7 @@ class AIMv2Model(torch.nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py new file mode 100644 index 0000000000000..c5d3d49d67602 --- /dev/null +++ b/vllm/model_executor/models/apertus.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The Swiss AI Initiative. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate the architectural differences made by +# the Swiss AI Initiative that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
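For reference, the `stacked_params_mapping` convention used in `AIMv2Model.load_weights` above maps per-shard checkpoint names onto the single fused `MergedColumnParallelLinear` parameter; a minimal sketch with an illustrative weight name follows.

# Hedged sketch: fc1/fc3 checkpoint shards are loaded into the fused fc13 parameter.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".fc13", ".fc1", 0),
    (".fc13", ".fc3", 1),
]

checkpoint_name = "trunk.blocks.0.mlp.fc1.weight"  # illustrative name only
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_name:
        fused_name = checkpoint_name.replace(weight_name, param_name)
        # -> "trunk.blocks.0.mlp.fc13.weight"; the fused parameter's
        #    weight_loader(param, loaded_weight, shard_id) writes it into shard 0.
        break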
+"""Inference-only Apertus model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import ApertusConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import XIELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class ApertusMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "xielu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only xIELU is supported for now." + ) + self.act_fn = XIELU() + + def forward(self, x): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class ApertusAttention(nn.Module): + def __init__( + self, + config: ApertusConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) + + sliding_window = None + if layer_types := getattr(config, "layer_types", None): + is_sliding = layer_types[layer_idx] == "sliding_attention" + if is_sliding: + sliding_window = config.sliding_window + + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) + + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q) + k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb( + self, + config: ApertusConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "apertus": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class ApertusDecoderLayer(nn.Module): + def __init__( + self, + config: ApertusConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, 
"rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + + # Apertus defaults to causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = ApertusAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = ApertusMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_layernorm(hidden_states) + else: + hidden_states, residual = self.attention_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.feedforward_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ApertusModel(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens 
= PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.aux_hidden_state_layers = tuple[int, ...]() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): + return ApertusModel( + vllm_config=vllm_config, prefix=prefix, layer_type=layer_type + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 4cf73e2e0ea56..634e94b168143 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -9,6 +9,7 @@ # activation. from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -19,32 +20,43 @@ from vllm.compilation.decorators import support_torch_compile from vllm.distributed import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, +) class ArceeMLP(nn.Module): """Feed-forward layer for Arcee using ReLU^2 activation (no gating as in LLaMA).""" - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[Any] = None, - bias: bool = False, - prefix: str = "", - reduce_results: bool = True) -> None: + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[Any] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: super().__init__() # Single linear projection up to intermediate size # (no separate gate projection) @@ -65,8 +77,10 @@ class ArceeMLP(nn.Module): prefix=f"{prefix}.down_proj", ) if hidden_act != "relu2": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only 'relu2' is supported for AFM.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only 'relu2' is supported for AFM." 
+ ) # Define ReLU^2 activation: (ReLU(x))^2 elementwise self.act_fn = ReLUSquaredActivation() @@ -81,38 +95,45 @@ class ArceeDecoderLayer(nn.Module): """Transformer decoder block for Arcee, with self-attention and ReLU^2 MLP.""" - def __init__(self, - config: LlamaConfig, - cache_config: Optional[Any] = None, - quant_config: Optional[Any] = None, - prefix: str = "") -> None: + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[Any] = None, + quant_config: Optional[Any] = None, + prefix: str = "", + ) -> None: super().__init__() self.hidden_size = config.hidden_size # Rotary embedding parameters (reuse LLaMA defaults) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Determine if attention bias is needed (some variants use bias terms) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias # Self-Attention (using LLaMA's attention structure) from vllm.model_executor.models.llama import ( - LlamaAttention) # import here to avoid circular import + LlamaAttention, # import here to avoid circular import + ) + self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -122,8 +143,8 @@ class ArceeDecoderLayer(nn.Module): cache_config=cache_config, prefix=f"{prefix}.self_attn", attn_type=getattr( - config, "attn_type", - "decoder"), # assume decoder (causal) unless specified + config, "attn_type", "decoder" + ), # assume decoder (causal) unless specified ) # MLP with ReLU^2 activation self.mlp = ArceeMLP( @@ -135,14 +156,16 @@ class ArceeDecoderLayer(nn.Module): prefix=f"{prefix}.mlp", ) # Layer normalization layers (RMSNorm as in LLaMA) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( - self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: # Self-Attention block if residual is None: @@ -150,13 +173,10 @@ class ArceeDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) else: # Fused residual add + layernorm if supported - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, 
residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Feed-forward block - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -166,11 +186,13 @@ class ArceeModel(nn.Module): """The transformer model backbone for Arcee (embedding layer + stacked decoder blocks + final norm).""" - def __init__(self, - *, - vllm_config, - prefix: str = "", - layer_type: type[nn.Module] = ArceeDecoderLayer) -> None: + def __init__( + self, + *, + vllm_config, + prefix: str = "", + layer_type: type[nn.Module] = ArceeDecoderLayer, + ) -> None: super().__init__() config: LlamaConfig = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -181,8 +203,9 @@ class ArceeModel(nn.Module): self.org_vocab_size = config.vocab_size # Word embeddings (parallelized if using pipeline parallel) - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -190,16 +213,17 @@ class ArceeModel(nn.Module): quant_config=quant_config, ) else: - self.embed_tokens = PPMissingLayer( - ) # placeholder on non-embedding ranks + self.embed_tokens = PPMissingLayer() # placeholder on non-embedding ranks # Build decoder layers across pipeline ranks self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) # Final RMSNorm on the last pipeline stage @@ -214,9 +238,9 @@ class ArceeModel(nn.Module): # Prepare factory for empty intermediate tensors # (for pipeline scheduling) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -226,44 +250,47 @@ class ArceeModel(nn.Module): input_ids: Optional[torch.Tensor], positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: # Embedding lookup (on first pipeline rank) if get_pp_group().is_first_rank: - hidden_states = (inputs_embeds if inputs_embeds is not None else - self.get_input_embeddings(input_ids)) + hidden_states = ( + inputs_embeds + if inputs_embeds is not None + else self.get_input_embeddings(input_ids) + ) residual = None else: assert intermediate_tensors is not None, ( - "IntermediateTensors must be provided for non-first " - "pipeline ranks") + "IntermediateTensors must be provided for non-first pipeline ranks" + ) hidden_states = 
intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append( - hidden_states + - residual) # capture pre-layer hidden state if needed + hidden_states + residual + ) # capture pre-layer hidden state if needed hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: # Send intermediate results to the next pipeline stage - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) # On last rank: apply final layer norm hidden_states, _ = self.norm(hidden_states, residual) if len(aux_hidden_states) > 0: return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights, mapping q/k/v projections to fused qkv_proj.""" stacked_params_mapping = [ (".qkv_proj", ".q_proj", "q"), @@ -277,17 +304,17 @@ class ArceeModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -330,8 +357,7 @@ class ArceeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -341,7 +367,8 @@ class ArceeModel(nn.Module): class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM runtime.""" - # Map fused module names to their sub-module components + + # Map fused module names to their submodule components # (for quantization and LoRA) packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -353,8 +380,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.config = config # Initialize the inner Transformer model (ArceeModel) - self.model = ArceeModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") + self.model = ArceeModel(vllm_config=vllm_config, prefix=f"{prefix}.model") # On the last pipeline stage, set up the LM head and logits processor if get_pp_group().is_last_rank: # Determine vocabulary size (including any LoRA extra tokens @@ -372,51 +398,50 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ) if config.tie_word_embeddings: # Tie output weights with input embedding 
matrix - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: # Placeholder for lm_head on non-last ranks self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # Compute final logits from hidden states (last pipeline rank only) - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights into the model (delegates to inner model and handles tied embeddings).""" loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - skip_substrs=["gate_proj"]) + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + skip_substrs=["gate_proj"], + ) # AutoWeightLoader handles weight name remapping, including fusing # separate q_proj, k_proj, v_proj into qkv_proj return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 4693c9487a8bf..760df1cef82b0 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Snowflake Arctic model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -10,67 +12,84 @@ from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from 
vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.deepspeedfp import ( - DeepSpeedFPConfig, DeepSpeedFPParameter) + DeepSpeedFPConfig, + DeepSpeedFPParameter, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP, SupportsQuant -from .utils import (extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class ArcticMLP(nn.Module): - - def __init__(self, - config: ArcticConfig, - expert_id: int = -1, - is_residual_mlp: bool = False, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = ""): + def __init__( + self, + config: ArcticConfig, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size self.expert_id = expert_id - self.ffn_dim = config.intermediate_size if not is_residual_mlp \ - else self.hidden_size + self.ffn_dim = ( + config.intermediate_size if not is_residual_mlp else self.hidden_size + ) - self.w13 = MergedColumnParallelLinear(self.hidden_size, - [self.ffn_dim] * 2, - bias=False, - quant_config=quant_config) - self.w2 = RowParallelLinear(self.ffn_dim, - self.hidden_size, - bias=False, - reduce_results=reduce_results, - quant_config=quant_config) + self.w13 = MergedColumnParallelLinear( + self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config + ) + self.w2 = RowParallelLinear( + self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config, + ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, hidden_states): @@ -85,13 +104,15 @@ class ArcticMoE(nn.Module): Model-parallel implementation of Arctic MoE Layer. 
""" - def __init__(self, - config: ArcticConfig, - tp_size: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = ""): + def __init__( + self, + config: ArcticConfig, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ): super().__init__() layer_id = extract_layer_index(prefix) @@ -111,52 +132,75 @@ class ArcticMoE(nn.Module): self.params_dtype = params_dtype if not self.is_moe_layer: - self.mlp = ArcticMLP(config, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.mlp") + self.mlp = ArcticMLP( + config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.mlp", + ) else: - self.gate = ReplicatedLinear(self.hidden_size, - self.num_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) if self.is_quant: self.ws = DeepSpeedFPParameter( - torch.Size((self.num_experts, 2 * self.intermediate_size, - self.hidden_size)), + torch.Size( + (self.num_experts, 2 * self.intermediate_size, self.hidden_size) + ), params_dtype=params_dtype, quant_config=quant_config, ) self.w2s = DeepSpeedFPParameter( - torch.Size((self.num_experts, self.hidden_size, - self.intermediate_size)), + torch.Size( + (self.num_experts, self.hidden_size, self.intermediate_size) + ), params_dtype=params_dtype, quant_config=quant_config, ) else: self.ws = nn.Parameter( - torch.empty(self.num_experts, - 2 * self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2s = nn.Parameter( - torch.empty(self.num_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) + torch.empty( + self.num_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.ds_dequantize() if self.is_quant else param.data shard_size = self.intermediate_size @@ -164,8 +208,9 @@ class ArcticMoE(nn.Module): if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = 
loaded_weight[:, shard] if self.is_quant: @@ -178,15 +223,14 @@ class ArcticMoE(nn.Module): router_logits, _ = self.gate(hidden_states) do_normalize = self.top_k > 1 topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, router_logits, self.top_k, renormalize=do_normalize) + hidden_states, router_logits, self.top_k, renormalize=do_normalize + ) # topk_ids: (num_tokens, k) if self.is_quant: if 2 * num_tokens <= self.num_experts: # If much fewer tokens than experts, use selective dequantize. - ws_dequantized = self.ws.ds_selective_dequantize( - topk_ids.flatten()) - w2s_dequantized = self.w2s.ds_selective_dequantize( - topk_ids.flatten()) + ws_dequantized = self.ws.ds_selective_dequantize(topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize(topk_ids.flatten()) # We gathered the experts to the tokens so update the mapping. topk_ids = torch.arange( 0, @@ -203,10 +247,10 @@ class ArcticMoE(nn.Module): w2s_dequantized if self.is_quant else self.w2s, topk_weights, topk_ids, - inplace=True) + inplace=True, + ) if self.reduce_results and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) def forward(self, hidden_states: torch.Tensor): @@ -218,7 +262,6 @@ class ArcticMoE(nn.Module): class ArcticAttention(nn.Module): - def __init__( self, config: ArcticConfig, @@ -248,12 +291,14 @@ class ArcticAttention(nn.Module): self.rope_theta = config.rope_theta self.scaling = self.head_dim**-0.5 - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, @@ -270,13 +315,15 @@ class ArcticAttention(nn.Module): is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -292,7 +339,6 @@ class ArcticAttention(nn.Module): class ArcticDecoderLayer(nn.Module): - def __init__( self, config: ArcticConfig, @@ -305,10 +351,12 @@ class ArcticDecoderLayer(nn.Module): layer_idx = extract_layer_index(prefix) is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 self.use_residual = config.use_residual and is_moe_layer - self.self_attn = ArcticAttention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = ArcticAttention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = ArcticMoE( config, quant_config=quant_config, @@ -316,18 +364,21 @@ class ArcticDecoderLayer(nn.Module): prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) if self.use_residual: - self.residual_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.residual_mlp = ArcticMLP(config, - is_residual_mlp=True, - reduce_results=False, - prefix=f"{prefix}.residual_mlp") + self.residual_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.residual_mlp = ArcticMLP( + config, + is_residual_mlp=True, + reduce_results=False, + prefix=f"{prefix}.residual_mlp", + ) def forward( self, @@ -361,7 +412,6 @@ class ArcticDecoderLayer(nn.Module): @support_torch_compile class ArcticModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -371,19 +421,20 @@ class ArcticModel(nn.Module): self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=self.vocab_size) + self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: ArcticDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -403,7 +454,7 @@ class ArcticModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -419,23 +470,27 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = ArcticModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = ArcticModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.unpadded_vocab_size = config.vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -447,21 +502,19 @@ class ArcticForCausalLM(nn.Module, SupportsPP, 
SupportsQuant): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -475,28 +528,47 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): for layer in range(num_layers): mlp_params_mapping.append( - (f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w1.weight", 0)) + ( + f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", + 0, + ) + ) mlp_params_mapping.append( - (f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w3.weight", 1)) + ( + f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", + 1, + ) + ) if layer % 2 == 0: # MLP layers mlp_params_mapping.append( - (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + ( + f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", + 0, + ) + ) mlp_params_mapping.append( - (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + ( + f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", + 1, + ) + ) else: # MoE layers for expert_id in range(self.config.num_local_experts): expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + ("ws", f"experts.{expert_id}.w1.weight", expert_id) + ) expert_params_mapping.append( - ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + ("w2s", f"experts.{expert_id}.w2.weight", expert_id) + ) expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + ("ws", f"experts.{expert_id}.w3.weight", expert_id) + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -504,9 +576,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): logger.info( "It will take ~10 minutes loading from the 16-bit weights. " "Alternatively, use the prequantized 8-bit weights of arctic " - "and set load-format to `sharded_state` will accelerate loading.") + "and set load-format to `sharded_state` will accelerate loading." 
+ ) for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -531,8 +604,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, shard_id \ - in expert_params_mapping: + for param_name, weight_name, shard_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -540,10 +612,9 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=shard_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=shard_id + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -552,8 +623,9 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 1c7960fa3e0a5..734ae8cbd6087 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Optional, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -9,38 +9,48 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.processing_aria import AriaProcessor -from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, 
PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -# yapf: disable from .idefics2_vision_model import Idefics2VisionConfig from .idefics2_vision_model import ( - Idefics2VisionTransformer as Idefics3VisionTransformer) -# yapf: enable + Idefics2VisionTransformer as Idefics3VisionTransformer, +) from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + maybe_prefix, +) class AriaImagePixelInputs(TensorSchema): @@ -53,6 +63,8 @@ class AriaImagePixelInputs(TensorSchema): - w: Width of each image """ + type: Literal["pixel_values"] + pixel_values: Annotated[ torch.Tensor, TensorShape("bn", 3, "h", "w"), @@ -79,8 +91,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): # Identity layer self.post_layernorm = nn.Identity() - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -90,7 +101,6 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - # NOTE: post_layernorm is not used in Aria if "post_layernorm" in name: continue @@ -105,15 +115,13 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class AriaProjectorMLP(nn.Module): - def __init__( self, in_features: int, @@ -122,12 +130,8 @@ class AriaProjectorMLP(nn.Module): ) -> None: super().__init__() - self.linear_in = ColumnParallelLinear(in_features, - hidden_features, - bias=False) - self.linear_out = RowParallelLinear(hidden_features, - output_dim, - bias=False) + self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False) + self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False) self.act = get_act_fn("gelu_new") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -143,16 +147,8 @@ class AriaProjector(nn.Module): projects ViT's outputs into MoE's inputs. Args: - patch_to_query_dict (dict): Maps patch numbers to their corresponding - query numbers, - e.g., {1225: 128, 4900: 256}. This allows for different query sizes - based on image resolution. - embed_dim (int): Embedding dimension. - num_heads (int): Number of attention heads. - kv_dim (int): Dimension of key and value. - ff_dim (int): Hidden dimension of the feed-forward network. - output_dim (int): Output dimension. - norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. 
+ config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig) + containing projector configuration parameters. Outputs: A tensor with the shape of (batch_size, query_number, output_dim) @@ -169,15 +165,17 @@ class AriaProjector(nn.Module): self.output_dim = config.text_config.hidden_size self.query = nn.Parameter( - torch.empty(config.max_value_projector_patch_to_query_dict, - self.in_features)) + torch.empty( + config.max_value_projector_patch_to_query_dict, self.in_features + ) + ) self.cross_attn = AriaCrossAttention(config) self.layer_norm = nn.LayerNorm(self.in_features) - self.feed_forward = AriaProjectorMLP(self.in_features, - self.hidden_features, - self.output_dim) + self.feed_forward = AriaProjectorMLP( + self.in_features, self.hidden_features, self.output_dim + ) def forward( self, @@ -187,9 +185,11 @@ class AriaProjector(nn.Module): batch_size, num_patches = x.shape[0], x.shape[1] if num_patches not in self.patch_to_query_dict: - raise KeyError(f"Number of patches {num_patches} not found in " - "patch_to_query_dict amongst possible values " - f"{self.patch_to_query_dict.keys()}.") + raise KeyError( + f"Number of patches {num_patches} not found in " + "patch_to_query_dict amongst possible values " + f"{self.patch_to_query_dict.keys()}." + ) query_num = self.patch_to_query_dict[num_patches] @@ -206,33 +206,33 @@ class AriaProjector(nn.Module): return out -class AriaFusedMoE(FusedMoE): - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - shard_id: str) -> None: +class AriaFusedMoE(SharedFusedMoE): + def weight_loader( + self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str + ) -> None: # Override the weight_loader to handle the expert weights in the Aria # model, which are already packed with experts, and merge the gate and # up weights for each expert. 
# Note: Loading expert weights with quantization is not supported tp_rank = get_tensor_model_parallel_rank() - if shard_id == 'w13': + if shard_id == "w13": # the shape of loaded_weight is # (num_experts, hidden_size, 2 * moe_intermediate_size) if self.tp_size > 1: up, gate = loaded_weight.chunk(2, dim=-1) up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] - up_and_gate = torch.cat([up_current_rank, gate_current_rank], - dim=-1).transpose(1, 2) + up_and_gate = torch.cat( + [up_current_rank, gate_current_rank], dim=-1 + ).transpose(1, 2) param.data.copy_(up_and_gate) else: param.data.copy_(loaded_weight.transpose(1, 2)) - elif shard_id == 'w2': + elif shard_id == "w2": # the shape of loaded_weight is # (num_experts, moe_intermediate_size, hidden_size) if self.tp_size > 1: - down_current_rank = loaded_weight.chunk(self.tp_size, - dim=1)[tp_rank] + down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[tp_rank] param.data.copy_(down_current_rank.transpose(1, 2)) else: param.data.copy_(loaded_weight.transpose(1, 2)) @@ -257,18 +257,9 @@ class AriaTextMoELayer(nn.Module): self.config = config self.router_weight = nn.Parameter( - torch.empty( - (self.config.moe_num_experts, self.config.hidden_size))) - - self.experts = AriaFusedMoE( - num_experts=config.moe_num_experts, - top_k=config.moe_topk, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - reduce_results=True, - prefix=f"{prefix}.experts", + torch.empty((self.config.moe_num_experts, self.config.hidden_size)) ) + self.shared_experts = LlamaMLP( config.hidden_size, config.intermediate_size * config.moe_num_shared_experts, @@ -277,27 +268,37 @@ class AriaTextMoELayer(nn.Module): bias=config.mlp_bias, ) + self.experts = AriaFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.experts", + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Forward pass of the MoE Layer. Args: - hidden_states (torch.Tensor): Input tensor of shape (batch_size, - sequence_length, hidden_size). + hidden_states: Input tensor of shape + (batch_size, sequence_length, hidden_size). Returns: torch.Tensor: Output tensor after passing through the MoE layer. """ - router_output = torch.nn.functional.linear(hidden_states, - self.router_weight) + router_output = torch.nn.functional.linear(hidden_states, self.router_weight) - hidden_states_copy = hidden_states.clone() - # NOTE: hidden_states will be modified inplace by `FusedMoE` sparse_expert_output = self.experts(hidden_states, router_output) - shared_expert_output = self.shared_experts(hidden_states_copy) - return sparse_expert_output + shared_expert_output + if self.shared_experts is not None: + return sparse_expert_output[0] + sparse_expert_output[1] + else: + return sparse_expert_output class AriaTextDecoderLayer(LlamaDecoderLayer): @@ -307,17 +308,15 @@ class AriaTextDecoderLayer(LlamaDecoderLayer): Experts (MoE) Layer. 
""" - def __init__( - self, - config: AriaTextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, cache_config, quant_config, prefix) - self.mlp = AriaTextMoELayer(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config, prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.mlp = AriaTextMoELayer( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) class AriaTextModel(LlamaModel, SupportsQuant): @@ -325,6 +324,7 @@ class AriaTextModel(LlamaModel, SupportsQuant): Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. """ + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -333,14 +333,13 @@ class AriaTextModel(LlamaModel, SupportsQuant): } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=AriaTextDecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=AriaTextDecoderLayer + ) # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to `stacked_params_mapping` - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -348,27 +347,27 @@ class AriaTextModel(LlamaModel, SupportsQuant): (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), - ("experts.w13_weight", "experts.fc1.weight", 'w13'), - ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ("experts.w13_weight", "experts.fc1.weight", "w13"), + ("experts.w2_weight", "experts.fc2.weight", "w2"), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -400,15 +399,13 @@ class AriaTextModel(LlamaModel, SupportsQuant): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class AriaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(AriaConfig) @@ -427,7 +424,6 @@ class AriaProcessingInfo(BaseProcessingInfo): class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -440,22 +436,26 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: vision_config = self.info.get_vision_config() max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -486,9 +486,11 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, - info=AriaProcessingInfo, - dummy_inputs=AriaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_inputs=AriaDummyInputsBuilder, +) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. @@ -496,6 +498,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): This model combines a vision tower, a multi-modal projector, and a language model to perform tasks that involve both image and text inputs. 
""" + + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 @@ -539,21 +544,25 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): vllm_config=vllm_config.with_hf_config(config.text_config), prefix=maybe_prefix(prefix, "language_model.model"), ) - self.pad_token_id = (self.config.pad_token_id - if self.config.pad_token_id is not None else -1) + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.text_config.hidden_size, org_num_embeddings=self.language_model.org_vocab_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.vocab_size, logit_scale + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + self, **kwargs: object + ) -> Optional[AriaImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) pixel_mask = kwargs.pop("pixel_mask", None) @@ -561,12 +570,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): return None return AriaImagePixelInputs( - pixel_values=flatten_bn(pixel_values, concat=True), - pixel_mask=flatten_bn(pixel_mask, concat=True), + type="pixel_values", + pixel_values=pixel_values, + pixel_mask=pixel_mask, ) def _create_patch_attention_mask( - self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: + self, + pixel_mask: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: if pixel_mask is None: return None @@ -586,8 +598,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ) -> tuple[torch.Tensor, torch.Tensor]: assert self.vision_tower is not None - pixel_values = image_input['pixel_values'] - pixel_mask = image_input['pixel_mask'] + pixel_values = image_input["pixel_values"] + pixel_mask = image_input["pixel_mask"] patch_attention_mask = self._create_patch_attention_mask(pixel_mask) @@ -605,27 +617,13 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] multimodal_embeddings = self._process_image_input(image_input) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -636,10 +634,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ) -> Union[torch.Tensor, IntermediateTensors]: if inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via 
`inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model( @@ -651,10 +650,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 687c82ded9d0a..6e93de524e482 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union, cast +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -10,32 +10,36 @@ from transformers import BatchFeature, GotOcr2ImageProcessor from transformers.activations import ACT2FN from transformers.image_processing_utils import get_size_dict from transformers.models.aya_vision import AyaVisionConfig -from transformers.models.aya_vision.processing_aya_vision import ( - AyaVisionProcessor) +from transformers.models.aya_vision.processing_aya_vision import AyaVisionProcessor from transformers.models.got_ocr2.image_processing_got_ocr2 import ( - get_optimal_tiled_canvas) + get_optimal_tiled_canvas, +) from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class AyaVisionImagePixelInputs(TensorSchema): @@ -63,17 +67,17 @@ class 
AyaVisionImagePixelInputs(TensorSchema): class AyaVisionMultiModalProjector(nn.Module): - def __init__(self, config: AyaVisionConfig): super().__init__() self.config = config self.downsample_factor = config.downsample_factor self.alignment_intermediate_size = getattr( - config, "alignment_intermediate_size", - config.text_config.hidden_size) - self.layernorm = nn.LayerNorm(config.vision_config.hidden_size * - (config.downsample_factor**2), - eps=config.adapter_layer_norm_eps) + config, "alignment_intermediate_size", config.text_config.hidden_size + ) + self.layernorm = nn.LayerNorm( + config.vision_config.hidden_size * (config.downsample_factor**2), + eps=config.adapter_layer_norm_eps, + ) self.linear_1 = nn.Linear( config.vision_config.hidden_size * (config.downsample_factor**2), @@ -83,9 +87,11 @@ class AyaVisionMultiModalProjector(nn.Module): self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation # For SwiGLU, project down to half size since we split intermediate dim - self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, - config.text_config.hidden_size, - bias=True) + self.linear_2 = nn.Linear( + self.alignment_intermediate_size // 2, + config.text_config.hidden_size, + bias=True, + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: image_features = self.pixel_shuffle(image_features) @@ -99,26 +105,31 @@ class AyaVisionMultiModalProjector(nn.Module): hidden_states = self.linear_2(hidden_states) return hidden_states - def pixel_shuffle(self, - image_features: torch.Tensor) -> torch.Tensor: # B, S, D + def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor: # B, S, D batch_size, seq_length, _ = image_features.shape height = width = int(seq_length**0.5) - image_features = image_features.reshape(image_features.shape[0], width, - height, -1) + image_features = image_features.reshape( + image_features.shape[0], width, height, -1 + ) channels = image_features.shape[-1] image_features = image_features.reshape( - batch_size, width, int(height / self.downsample_factor), - int(channels * self.downsample_factor)) + batch_size, + width, + int(height / self.downsample_factor), + int(channels * self.downsample_factor), + ) image_features = image_features.permute(0, 2, 1, 3) image_features = image_features.reshape( - batch_size, int(height / self.downsample_factor), - int(width / self.downsample_factor), -1) + batch_size, + int(height / self.downsample_factor), + int(width / self.downsample_factor), + -1, + ) image_features = image_features.permute(0, 2, 1, 3) return image_features class AyaVisionProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> AyaVisionConfig: return self.ctx.get_hf_config(AyaVisionConfig) @@ -133,14 +144,20 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - height = image_processor.size['height'] - width = image_processor.size['width'] + height = image_processor.size["height"] + width = image_processor.size["width"] max_patches = image_processor.max_patches - return ImageSize(height=height * max_patches, - width=width * max_patches) + return ImageSize(height=height * max_patches, width=width * max_patches) - def get_num_patches(self, *, image_width: int, image_height: int, - size: dict, min_patches: int, max_patches: int) -> int: + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + size: dict, + min_patches: int, + max_patches: int, + ) -> int: """ Calculate the number of patches needed for a 
given image based on size constraints. This method replicates and adjusts the logic from: @@ -148,15 +165,16 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): """ size = get_size_dict(size, default_to_square=False) num_columns, num_rows = get_optimal_tiled_canvas( - (image_height, image_width), (size["height"], size["width"]), - min_patches, max_patches) + (image_height, image_width), + (size["height"], size["width"]), + min_patches, + max_patches, + ) num_blocks = num_columns * num_rows return num_blocks if num_blocks == 1 else num_blocks + 1 -class AyaVisionDummyInputsBuilder( - BaseDummyInputsBuilder[AyaVisionProcessingInfo]): - +class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -169,22 +187,24 @@ class AyaVisionDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - image_size = \ - self.info.get_image_size_with_most_features() + image_size = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=image_size.width, - height=image_size.height, - num_images=num_images) + "image": self._get_dummy_images( + width=image_size.width, + height=image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } -class AyaVisionMultiModalProcessor( - BaseMultiModalProcessor[AyaVisionProcessingInfo]): - +class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -203,13 +223,13 @@ class AyaVisionMultiModalProcessor( # HF processor pops the `num_patches` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ - parsed_images.get_image_size(i) - for i in range(len(parsed_images)) + parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] num_patches = [ @@ -218,7 +238,8 @@ class AyaVisionMultiModalProcessor( image_height=image_size.height, size=image_processor.size, min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches) + max_patches=image_processor.max_patches, + ) for image_size in image_sizes ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -232,8 +253,7 @@ class AyaVisionMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -280,10 +300,10 @@ def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest m elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is 
not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -295,9 +315,10 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @MULTIMODAL_REGISTRY.register_processor( AyaVisionMultiModalProcessor, info=AyaVisionProcessingInfo, - dummy_inputs=AyaVisionDummyInputsBuilder) -class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=AyaVisionDummyInputsBuilder, +) +class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -306,7 +327,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -329,7 +351,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, config.vision_config, quant_config, num_hidden_layers_override=num_hidden_layers, - prefix=maybe_prefix(prefix, "vision_model")) + prefix=maybe_prefix(prefix, "vision_model"), + ) self.vocab_size = config.text_config.vocab_size self.multi_modal_projector = AyaVisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( @@ -337,58 +360,42 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, hf_config=config.text_config, prefix=maybe_prefix(prefix, "model"), # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm - architectures=["Cohere2ForCausalLM"]) + architectures=["Cohere2ForCausalLM"], + ) @property def dtype(self): return next(self.parameters()).dtype - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor, - **kwargs) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_features = vision_tower(pixel_values.to(dtype=target_dtype), - **kwargs) - - def select_features(leaf: torch.Tensor): - return self._select_image_features( - leaf, - strategy=self.config.vision_feature_select_strategy, - ) - - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + return vision_tower( + pixel_values.to(dtype=vision_tower.dtype), + feature_select_strategy=self.config.vision_feature_select_strategy, ) - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") - - def _process_image_input(self, image_input: AyaVisionImagePixelInputs, - **kwargs) -> list[torch.Tensor]: + def _process_image_input( + self, image_input: AyaVisionImagePixelInputs, **kwargs + ) -> 
list[torch.Tensor]: assert self.vision_tower is not None pixel_values = image_input["pixel_values"] num_patches = image_input["num_patches"] image_features = self._image_pixels_to_features( - self.vision_tower, pixel_values=pixel_values) + self.vision_tower, pixel_values=pixel_values + ) image_embeds = self.multi_modal_projector(image_features) - return [ - e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())] def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: + self, **kwargs: object + ) -> Optional[AyaVisionImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -399,41 +406,24 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, return AyaVisionImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - num_patches=flatten_bn(num_patches, concat=True), + pixel_values=pixel_values, + num_patches=num_patches, resolve_bindings={ "h": self.config.vision_config.image_size, "w": self.config.vision_config.image_size, - }) + }, + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -445,14 +435,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -464,7 +446,5 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 804a2f1785d5c..a8f0e5993e2bc 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -20,8 +20,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
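# [Editor's illustrative note, not part of the patch] The baichuan and
# bailing_moe hunks below replace `self.layers[self.start_layer:self.end_layer]`
# with `islice(self.layers, self.start_layer, self.end_layer)`. Slicing an
# nn.ModuleList constructs a fresh ModuleList on every forward pass, whereas
# islice lazily iterates the existing container. A tiny sketch with made-up
# layer counts and a hypothetical pipeline-parallel partition:
from itertools import islice

import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))
start_layer, end_layer = 2, 5  # hypothetical partition owned by this PP rank

eager = layers[start_layer:end_layer]          # builds a new ModuleList
lazy = islice(layers, start_layer, end_layer)  # iterator over the same modules

assert list(lazy) == list(eager)               # identical layer objects either way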
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -31,32 +33,45 @@ from transformers import PretrainedConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + row_parallel_weight_loader, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -64,22 +79,20 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class BaiChuanMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -89,16 +102,15 @@ class BaiChuanMLP(nn.Module): ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = 
RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -124,12 +136,10 @@ class BaiChuanAttention(nn.Module): ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads self.position_embedding = position_embedding self.rope_theta = rope_theta @@ -159,12 +169,14 @@ class BaiChuanAttention(nn.Module): alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) else: self.rotary_emb = get_rope( self.head_dim, @@ -173,12 +185,14 @@ class BaiChuanAttention(nn.Module): base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -195,18 +209,18 @@ class BaiChuanAttention(nn.Module): class BaiChuanDecoderLayer(nn.Module): - - def __init__(self, - config: PretrainedConfig, - position_embedding: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = BaiChuanAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -223,10 +237,10 @@ class BaiChuanDecoderLayer(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -239,23 +253,20 @@ class 
BaiChuanDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class BaiChuanModel(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -277,17 +288,15 @@ class BaiChuanModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: BaiChuanDecoderLayer(config, - position_embedding, - cache_config, - quant_config, - prefix=prefix), + lambda prefix: BaiChuanDecoderLayer( + config, position_embedding, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -309,22 +318,23 @@ class BaiChuanModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -336,7 +346,7 @@ class BaiChuanModel(nn.Module): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -356,15 +366,13 @@ class BaiChuanModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - SupportsQuant): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -388,18 +396,24 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, self.lora_config = 
lora_config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config - self.model = BaiChuanModel(vllm_config=vllm_config, - prefix=prefix, - position_embedding=position_embedding) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.model = BaiChuanModel( + vllm_config=vllm_config, + prefix=prefix, + position_embedding=position_embedding, + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.lm_head.weight.weight_loader = self.lm_head_weight_loader if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -411,26 +425,23 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) - def lm_head_weight_loader(self, param: nn.Parameter, - loaded_weight: torch.Tensor): + def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): # Unlike Baichuan, Baichuan2 normalizes the head weights. 
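# [Editor's illustrative sketch, not part of the patch] As the surrounding
# comment notes, Baichuan2's "NormHead" ships an L2-normalized lm_head. A
# hypothetical standalone illustration of that normalization; the exact dim,
# eps, and Baichuan2 detection in vLLM's lm_head_weight_loader may differ:
import torch
import torch.nn.functional as F

def normalize_head(weight: torch.Tensor) -> torch.Tensor:
    # Give every vocab row unit L2 norm, so each logit is proportional to the
    # cosine similarity between the hidden state and that token's embedding.
    return F.normalize(weight, p=2, dim=-1)

w = torch.randn(1000, 64)  # made-up (vocab_size, hidden_size)
w_hat = normalize_head(w)
assert torch.allclose(w_hat.norm(dim=-1), torch.ones(1000), atol=1e-5)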
# Refer to: # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 @@ -454,13 +465,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if config.hidden_size == 4096: # baichuan2 7b - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ROPE") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE" + ) else: # baichuan 13b, baichuan2 13b - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ALIBI") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ALIBI" + ) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -469,6 +480,6 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ROPE") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE" + ) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 23cab3509ca82..c016d46e194f2 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only BailingMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -34,39 +36,48 @@ from transformers.configuration_utils import PretrainedConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - 
make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class BailingAttention(nn.Module): - def __init__( self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, prefix: str = "", ): super().__init__() @@ -80,13 +91,13 @@ class BailingAttention(nn.Module): assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size - self.head_dim = config.head_dim or (self.hidden_size // - self.total_num_heads) + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads - self.num_kv_heads = self.total_kv_heads // tp_size self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.use_rmsnorm = getattr(config, "use_rmsnorm", False) self.query_key_value = QKVParallelLinear( self.hidden_size, @@ -98,28 +109,48 @@ class BailingAttention(nn.Module): prefix=f"{prefix}.query_key_value", ) + if self.use_qk_norm: + self.query_layernorm = ( + RMSNorm(self.head_dim, eps=config.rms_norm_eps) + if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6) + ) + self.key_layernorm = ( + RMSNorm(self.head_dim, eps=config.rms_norm_eps) + if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6) + ) + self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=config.use_bias, quant_config=quant_config, + reduce_results=reduce_results, prefix=f"{prefix}.dense", ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - prefix=f"{prefix}.attn") + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + + self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, + rotary_dim=self.rotary_dim, max_position=config.max_position_embeddings, base=config.rope_theta, is_neox_style=True, rope_scaling=config.rope_scaling, + partial_rotary_factor=self.partial_rotary_factor, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", ) def forward( @@ -127,12 +158,18 @@ class BailingAttention(nn.Module): hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split([ - self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank - ], - dim=-1) + q, k, v = qkv.split( + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1 + ) + + if self.use_qk_norm: + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + q = self.query_layernorm(q) + k = self.key_layernorm(k) + q = q.view(-1, self.q_size_per_rank) + k = k.view(-1, self.kv_size_per_rank) q, k = self.rotary_emb(position_ids, q, k) @@ -143,7 +180,6 @@ class BailingAttention(nn.Module): class BailingMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -178,7 +214,6 @@ class BailingMLP(nn.Module): class BailingMoE(nn.Module): - def __init__( self, intermediate_size: int, @@ -197,54 +232,111 @@ class 
BailingMoE(nn.Module): self.hidden_size = config.hidden_size self.quant_config = quant_config self.num_shared_experts = config.num_shared_experts - # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(self.hidden_size, - self.num_experts, - bias=False, - quant_config=None) + self.score_function = getattr(config, "score_function", None) + self.n_group = getattr(config, "n_group", None) + self.topk_group = getattr(config, "topk_group", None) + self.use_grouped_topk = self.n_group is not None and self.topk_group is not None + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) - self.experts = FusedMoE(num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_expert_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") + router_dtype = getattr(config, "router_dtype", None) + if router_dtype is None: + self.router_dtype = None + elif router_dtype == "fp32": + self.router_dtype = torch.float32 + else: + self.router_dtype = torch.bfloat16 + + self.gate = nn.Linear( + self.hidden_size, + self.num_experts, + bias=False, + dtype=self.router_dtype, + ) + + if getattr(config, "moe_router_enable_expert_bias", False): + self.gate.expert_bias = nn.Parameter( + torch.empty((config.num_experts,), dtype=torch.float32) + ) + else: + self.gate.expert_bias = None + + self.correction_bias = ( + self.gate.expert_bias.data if self.gate.expert_bias is not None else None + ) + + if self.score_function is not None: + assert ( + self.score_function == "softmax" and self.correction_bias is None + ) or ( + self.score_function == "sigmoid" and self.correction_bias is not None + ), ( + "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501 + ) + else: + # default value for scoring_func + self.score_function = "softmax" if self.num_shared_experts > 0: - intermediate_size = (config.moe_intermediate_size * - self.num_shared_experts) + if hasattr(config, "moe_shared_expert_intermediate_size"): + intermediate_size = config.moe_shared_expert_intermediate_size + else: + intermediate_size = config.moe_intermediate_size + intermediate_size *= config.num_shared_experts self.shared_experts = BailingMLP( intermediate_size=intermediate_size, config=config, quant_config=quant_config, reduce_results=False, - prefix=f"{prefix}.shared_experts") + prefix=f"{prefix}.shared_experts", + ) else: self.shared_experts = None + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.score_function, + e_score_correction_bias=self.gate.expert_bias, + num_expert_group=self.n_group, + topk_group=self.topk_group, + use_grouped_topk=self.use_grouped_topk, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_size) - if self.num_shared_experts > 0: - shared_output = self.shared_experts(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if self.num_shared_experts > 0: + # 
router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states.to(self.router_dtype)) + router_logits = router_logits.to(hidden_states.dtype) + + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_experts is not None: + shared_output, final_hidden_states = final_hidden_states + else: + shared_output = None + + final_hidden_states *= self.routed_scaling_factor + + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) class BailingMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -253,20 +345,26 @@ class BailingMoeBlock(nn.Module): prefix: str = "", ): super().__init__() + layer_idx = int(prefix.split(".")[-1]) + self.config = config hidden_size = config.hidden_size intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) - self.attention = BailingAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attention") - self.post_attention_layernorm = RMSNorm(hidden_size, - eps=config.rms_norm_eps) - self.mlp = BailingMoE(intermediate_size, - config, - quant_config, - True, - prefix=f"{prefix}.mlp") + self.attention = BailingAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attention" + ) + + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + + # Choose MLP class based on the number of experts and layer index + if layer_idx < config.first_k_dense_replace: + mlp_class = BailingMLP + else: + mlp_class = BailingMoE + self.mlp = mlp_class( + intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp" + ) def forward( self, @@ -278,23 +376,20 @@ class BailingMoeBlock(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attention( hidden_states=hidden_states, position_ids=position_ids, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class BailingMoeModel(nn.Module): - def __init__( self, *, @@ -309,11 +404,17 @@ class BailingMoeModel(nn.Module): self.config = config self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + self.tie_word_embeddings and get_pp_group().is_last_rank + ): self.word_embeddings = VocabParallelEmbedding( - self.vocab_size, self.embed_dim) + self.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.word_embeddings", + ) else: self.word_embeddings = PPMissingLayer() @@ -327,11 +428,12 @@ class BailingMoeModel(nn.Module): quant_config=quant_config, prefix=prefix, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) - self.make_empty_intermediate_tensors = ( - 
make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) @@ -359,8 +461,7 @@ class BailingMoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( hidden_states, position_ids, @@ -368,24 +469,25 @@ class BailingMoeModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + else: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -396,13 +498,14 @@ class BailingMoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.config.norm_head and "lm_head.weight" in name: - loaded_weight = F.normalize(loaded_weight, - dim=0, - p=2, - eps=1e-7) + if ( + hasattr(self.config, "norm_head") + and self.config.norm_head + and "lm_head.weight" in name + ): + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: @@ -430,13 +533,17 @@ class BailingMoeModel(nn.Module): if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -448,15 +555,15 @@ class BailingMoeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = { "query_key_value": ["query_key_value"], "gate_up_proj": [ @@ -473,25 +580,37 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ) -> None: super().__init__() - config = vllm_config.model_config.hf_config + config = 
vllm_config.model_config.hf_config.get_text_config() + vllm_config.model_config.hf_config = config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings - self.model = BailingMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = BailingMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if get_pp_group().is_last_rank: - self.lm_head = (self.word_embeddings if config.tie_word_embeddings - else ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config)) + if self.tie_word_embeddings: + self.lm_head = self.model.word_embeddings + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -503,27 +622,28 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None), ) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + +class BailingMoeV2ForCausalLM(BailingMoeForCausalLM): + pass diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index e2cd31af5390a..42c1c7be1a75a 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Bamba model.""" + # Added by the IBM Team, 2024 from collections.abc import Iterable from typing import Optional @@ -9,44 +10,45 @@ import torch from torch import nn from transformers import BambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context 
import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class BambaMLP(nn.Module): - def __init__( self, config: BambaConfig, @@ -67,8 +69,10 @@ class BambaMLP(nn.Module): quant_config=quant_config, ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -79,56 +83,53 @@ class BambaMLP(nn.Module): class BambaMixerDecoderLayer(nn.Module): - - def __init__(self, - config: BambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: BambaConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config - self.mamba = MambaMixer2(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - head_dim=config.mamba_d_head, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -136,7 +137,6 @@ class BambaMixerDecoderLayer(nn.Module): class BambaAttentionDecoderLayer(nn.Module): - def __init__( self, config: BambaConfig, @@ -149,8 +149,7 @@ class BambaAttentionDecoderLayer(nn.Module): super().__init__() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -198,10 +197,12 @@ class 
BambaAttentionDecoderLayer(nn.Module): bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) self.attn = Attention( self.num_heads, @@ -213,10 +214,8 @@ class BambaAttentionDecoderLayer(nn.Module): ) self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def self_attention( self, @@ -243,29 +242,26 @@ class BambaAttentionDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual ALL_DECODER_LAYER_TYPES = { "attention": BambaAttentionDecoderLayer, - "mamba": BambaMixerDecoderLayer + "mamba": BambaMixerDecoderLayer, } @support_torch_compile class BambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -276,8 +272,11 @@ class BambaModel(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -289,8 +288,7 @@ class BambaModel(nn.Module): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layers_block_type[layer_idx]] + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[layer_idx]] return layer_class( config, layer_idx, @@ -301,13 +299,13 @@ class BambaModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -316,22 +314,9 @@ class BambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: 
Optional[torch.Tensor] = None, ) -> torch.Tensor: - - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -344,36 +329,21 @@ class BambaModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] - if isinstance(layer, BambaAttentionDecoderLayer): - num_attn += 1 - - layer_mamba_cache_params = None - if isinstance(layer, - BambaMixerDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - + for i, layer in enumerate(self.layers): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -418,22 +388,22 @@ class BambaModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class BambaForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], - "gate_up_proj": ["up_proj", "down_proj"] + "gate_up_proj": ["up_proj", "down_proj"], } # LoRA specific attributes @@ -448,7 +418,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -459,13 +428,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -484,26 +451,22 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Bamba currently does not support prefix caching" - self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = BambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = BambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -514,70 +477,43 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): 
- return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py deleted file mode 100644 index 32551d8102f32..0000000000000 --- a/vllm/model_executor/models/bart.py +++ /dev/null @@ -1,1342 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Derived from BART implementation posted on HuggingFace; license below: -# -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BART model.""" -import math -from collections.abc import Iterable -from typing import Optional - -import torch -from torch import nn -from transformers import BartConfig -from transformers.utils import logging - -from vllm.attention import Attention, AttentionType -from vllm.config import CacheConfig, LoRAConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsQuant, SupportsV0Only -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - maybe_prefix) - -logger = logging.get_logger(__name__) - - -def get_bsz_seq_len(input_ids): - shp = input_ids.shape - ndim = len(shp) - if ndim == 1: - return 1, input_ids.numel() - else: - return shp[:2] - - -class BartLearnedPositionalEmbedding(VocabParallelEmbedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int): - # Bart is set up so that if padding_idx is - # specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. 
- # Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward( - self, - positions: torch.Tensor, - ) -> torch.Tensor: - """`input_ids' shape is expected to be [bsz x seqlen].""" - return super().forward(positions + self.offset) - - -class BartScaledWordEmbedding(VocabParallelEmbedding): - """ - This module overrides VocabParallelEmbedding's - forward by multiplying with embeddings scale. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - embed_scale: float = 1.0): - super().__init__(num_embeddings, embedding_dim) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - return super().forward(input_ids) * self.embed_scale - - -class BartParallelLMHead(ParallelLMHead): - """ - This module overrides ParallelLMHead's - forward by dividing by embeddings scale, - yielding effectively the inverse of - BartScaledWordEmbedding - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - embed_scale: float = 1.0): - super().__init__(num_embeddings, embedding_dim) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - return super().forward(input_ids) / self.embed_scale - - -class BartEncoderAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - self.d_model, - self.d_model // self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartDecoderSelfAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - self.d_model, - self.d_model // self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartCrossAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - # TP sharding sizes is accounted for within "*Parallel" layers. - self.qkv_proj = QKVCrossParallelLinear(self.d_model, - self.d_model // - self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias, - quant_config=quant_config) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads # No GQA in bart - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_DECODER) - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartEncoderLayer(nn.Module): - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BartEncoderAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.activation_fn = get_act_fn(config.activation_function) - - ffn_hidden_size = self.embed_dim - ffn_intermediate_size = config.encoder_ffn_dim - ffn_has_bias = True - self.fc1 = ColumnParallelLinear( - ffn_hidden_size, - ffn_intermediate_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - self.act = get_act_fn("gelu") - self.fc2 = RowParallelLinear( - ffn_intermediate_size, - ffn_hidden_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Args: - hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Encoder layer output torch.Tensor - """ - residual = hidden_states - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - hidden_states = cast_overflow_tensors(hidden_states) - - return hidden_states - - -class BartDecoderLayer(nn.Module): - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BartDecoderSelfAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.activation_fn = get_act_fn(config.activation_function) - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - ''' - afeldman-nm: personally I would call this "cross-attention", - however I left the name as "encoder_attn" to maintain consistency - with the name of the pretrained weights. 
- ''' - self.encoder_attn = BartCrossAttention( - self.embed_dim, - config.decoder_attention_heads, - config=config, - prefix=f"{prefix}.encoder_attn", - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - - ffn_hidden_size = self.embed_dim - ffn_intermediate_size = config.encoder_ffn_dim - ffn_has_bias = True - self.fc1 = ColumnParallelLinear( - ffn_hidden_size, - ffn_intermediate_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - self.fc2 = RowParallelLinear( - ffn_intermediate_size, - ffn_hidden_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_hidden_states - torch.Tensor of *decoder* input embeddings. - encoder_hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Decoder layer output torch.Tensor - """ - residual = decoder_hidden_states - - # Self Attention - hidden_states = self.self_attn(hidden_states=decoder_hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - - residual = hidden_states - - hidden_states = self.encoder_attn( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return hidden_states - - -class BartEncoder(nn.Module): - """ - Transformer encoder consisting of *config.encoder_layers* - self attention layers. Each layer is a [`BartEncoderLayer`]. - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = ""): - super().__init__() - - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - embed_dim = config.d_model - self.max_source_positions = config.max_position_embeddings - embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - embed_dim, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_dim, - ) - self.layers = nn.ModuleList([ - BartEncoderLayer(config, - cache_config, - quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.encoder_layers) - ]) - - self.layernorm_embedding = nn.LayerNorm(embed_dim) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. 
- Returns: - Decoder output torch.Tensor - """ - # retrieve input_ids and inputs_embeds - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed_pos = self.embed_positions(positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states=hidden_states) - - return hidden_states - - -class BartDecoder(nn.Module): - """ - Transformer decoder consisting of *config.decoder_layers* layers. - Each layer is a [`BartDecoderLayer`] - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = "", - ): - super().__init__() - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - - self.layers = nn.ModuleList( - [BartDecoderLayer(config,cache_config,quant_config, - prefix=f"{prefix}.layers.{layer_idx}") \ - for layer_idx in range(config.decoder_layers)]) - - self.layernorm_embedding = nn.LayerNorm(config.d_model) - - def forward( - self, - decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. 
- encoder_hidden_states: - Tensor of encoder output embeddings - Returns: - Decoder output torch.Tensor - """ - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(decoder_input_ids) - else: - decoder_positions = inputs_embeds[:, -1] - - # embed positions - embed_pos = self.embed_positions(decoder_positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - - for decoder_layer in self.layers: - hidden_states = decoder_layer( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - return hidden_states - - -class BartModel(nn.Module, SupportsQuant): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.encoder = BartEncoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = BartDecoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. 
- Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if encoder_input_ids.numel() > 0: - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - other_weights = [] - loaded_stacked_params = [] - model_params_dict = dict(self.named_parameters()) - - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if name not in model_params_dict: - continue - param = model_params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - loaded_stacked_params.append(name) - break - else: - if name in model_params_dict: - other_weights.append((name, loaded_weight)) - - loader = AutoWeightsLoader(self) - loaded_params = loader.load_weights(other_weights) - loaded_params.update(loaded_stacked_params) - return loaded_params - - -class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "decoder.": "model.decoder.", - "encoder.": "model.encoder.", - "shared.": "model.shared." - }, - orig_to_new_substr={ - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - super().__init__() - config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - # currently all existing BART models have `tie_word_embeddings` enabled - assert config.tie_word_embeddings - self.config = config - self.model = BartModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.lm_head = BartParallelLMHead(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. 
- encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - weights_tuple_list = list(weights) - - shared_embedding_weight = None - for name, loaded_weight in weights_tuple_list: - if ('shared.weight' in name - or 'encoder.embed_tokens.weight' in name - or 'decoder.embed_tokens.weight' in name - or 'lm_head.weight' in name): - assert shared_embedding_weight is None, ( - "Conflicting embedding weights.") - shared_embedding_weight = loaded_weight - - loader = AutoWeightsLoader( - self, - skip_prefixes=(["cls.", "pooler."]), - ) - loaded_params = loader.load_weights(weights_tuple_list, - mapper=self.hf_to_vllm_mapper) - - if shared_embedding_weight is not None: - weight_loader = getattr(self.lm_head.weight, "weight_loader", - default_weight_loader) - weight_loader(self.lm_head.weight, shared_embedding_weight) - - self.model.encoder.embed_tokens.weight = self.lm_head.weight - self.model.decoder.embed_tokens.weight = self.lm_head.weight - loaded_params.update({ - 'model.encoder.embed_tokens.weight', 'lm_head.weight', - 'model.decoder.embed_tokens.weight' - }) - - return loaded_params - - -class MBartEncoderLayer(BartEncoderLayer): - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Args: - hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Encoder layer output torch.Tensor - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - hidden_states = cast_overflow_tensors(hidden_states) - - return hidden_states - - -class MBartDecoderLayer(BartDecoderLayer): - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - residual = decoder_hidden_states - hidden_states = self.self_attn_layer_norm(decoder_hidden_states) - - # Self Attention - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - - # Cross-Attention Block - - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - hidden_states = self.encoder_attn( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - - return hidden_states - - -class MBartEncoder(nn.Module): - """ - Transformer encoder consisting of 
*config.encoder_layers* - self attention layers. Each layer is a [`BartEncoderLayer`]. - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = ""): - super().__init__() - - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - embed_dim = config.d_model - self.max_source_positions = config.max_position_embeddings - embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - embed_dim, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_dim, - ) - self.layers = nn.ModuleList([ - MBartEncoderLayer(config, - cache_config, - quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.encoder_layers) - ]) - - self.layernorm_embedding = nn.LayerNorm(embed_dim) - self.layer_norm = nn.LayerNorm(config.d_model) # 改动 - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. - Returns: - Decoder output torch.Tensor - """ - # retrieve input_ids and inputs_embeds - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed_pos = self.embed_positions(positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states=hidden_states) - - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -class MBartDecoder(nn.Module): - """ - Transformer decoder consisting of *config.decoder_layers* layers. 
- Each layer is a [`BartDecoderLayer`] - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = "", - ): - super().__init__() - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - - self.layers = nn.ModuleList( - [MBartDecoderLayer(config, cache_config, quant_config, - prefix=f"{prefix}.layers.{layer_idx}") \ - for layer_idx in range(config.decoder_layers)]) - - self.layernorm_embedding = nn.LayerNorm(config.d_model) - self.layer_norm = nn.LayerNorm(config.d_model) - - def forward( - self, - decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. - encoder_hidden_states: - Tensor of encoder output embeddings - Returns: - Decoder output torch.Tensor - """ - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(decoder_input_ids) - else: - decoder_positions = inputs_embeds[:, -1] - - # embed positions - embed_pos = self.embed_positions(decoder_positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - - for decoder_layer in self.layers: - hidden_states = decoder_layer( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -class MBartModel(nn.Module, SupportsQuant): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.encoder = MBartEncoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = MBartDecoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. 
- Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. - Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if encoder_input_ids.numel() > 0: - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - -class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): - base_model_prefix = "model" - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "decoder.": "model.decoder.", - "encoder.": "model.encoder.", - "shared.": "model.shared." - }, - orig_to_new_substr={ - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - assert config.tie_word_embeddings - self.config = config - self.model = MBartModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.lm_head = BartParallelLMHead(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - model_params_dict = dict(self.named_parameters()) - loaded_params = set() - remaining_weights = [] - shared_embedding_weight = None - - for name, loaded_weight in weights: - if any(skip in name - for skip in ["cls.", "pooler.", "final_logits_bias"]): - continue - if any(embed_name in name for embed_name in [ - 'shared.weight', 'encoder.embed_tokens.weight', - 'decoder.embed_tokens.weight' - ]): - if shared_embedding_weight is None: - shared_embedding_weight = loaded_weight - continue - is_stacked = False - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - vllm_name = name - for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items( - ): - vllm_name = vllm_name.replace(src, dst) - for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items( - ): - if 
vllm_name.startswith(src): - vllm_name = dst + vllm_name[len(src):] - break - vllm_name = vllm_name.replace(weight_name, param_name) - if vllm_name in model_params_dict: - param = model_params_dict[vllm_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, shard_id) - loaded_params.add(vllm_name) - is_stacked = True - break - if not is_stacked: - remaining_weights.append((name, loaded_weight)) - loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]) - auto_loaded_params = loader.load_weights(remaining_weights, - mapper=self.hf_to_vllm_mapper) - loaded_params.update(auto_loaded_params) - if shared_embedding_weight is not None: - lm_head_param = self.lm_head.weight - weight_loader = getattr(lm_head_param, "weight_loader", - default_weight_loader) - weight_loader(lm_head_param, shared_embedding_weight) - self.model.encoder.embed_tokens.weight = self.lm_head.weight - self.model.decoder.embed_tokens.weight = self.lm_head.weight - loaded_params.update({ - 'model.encoder.embed_tokens.weight', 'lm_head.weight', - 'model.decoder.embed_tokens.weight' - }) - return loaded_params diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 22b6c4401213c..d9d4c62639d50 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -13,40 +13,44 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, - PoolingParamsUpdate, - PoolingType) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingParamsUpdate, + PoolingType, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import (SupportsCrossEncoding, SupportsQuant, - default_pooling_type) +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): - def __init__(self, config: BertConfig): - super().__init__() self.size = config.hidden_size - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.position_embeddings = VocabParallelEmbedding( - config.max_position_embeddings, config.hidden_size) + config.max_position_embeddings, config.hidden_size + ) self.token_type_embeddings = VocabParallelEmbedding( - config.type_vocab_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + 
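The load_weights logic above follows vLLM's stacked-parameter convention: separate q/k/v checkpoint tensors are written into row slices of one fused qkv_proj weight, selected by a shard id. The following standalone sketch (a hypothetical helper with simplified, equal-sized shards, not the actual vLLM loader) illustrates the idea:

```python
# Illustrative sketch of the stacked-parameter mapping used by the
# load_weights implementations in this patch. Hypothetical helper,
# simplified to equal-sized q/k/v shards.
import torch

def load_fused_qkv(qkv_weight: torch.Tensor, name: str, tensor: torch.Tensor,
                   hidden_size: int) -> bool:
    stacked_params_mapping = [
        # (fused param, checkpoint shard name, row offset index)
        ("qkv_proj", "q_proj", 0),
        ("qkv_proj", "k_proj", 1),
        ("qkv_proj", "v_proj", 2),
    ]
    for _, shard_name, shard_idx in stacked_params_mapping:
        if shard_name in name:
            start = shard_idx * hidden_size
            qkv_weight[start:start + hidden_size].copy_(tensor)
            return True   # handled as a stacked shard
    return False          # fall through to the default weight loader

# Usage: a [3 * hidden, hidden] fused weight receives three [hidden, hidden] shards.
hidden = 4
qkv = torch.zeros(3 * hidden, hidden)
load_fused_qkv(qkv, "decoder.layers.0.self_attn.q_proj.weight",
               torch.ones(hidden, hidden), hidden)
```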
config.type_vocab_size, config.hidden_size + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.register_buffer( "position_ids", @@ -54,18 +58,21 @@ class BertEmbedding(nn.Module): ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": - raise ValueError("Only 'absolute' position_embedding_type" + - " is supported") + raise ValueError( + "Only 'absolute' position_embedding_type" + " is supported" + ) def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - token_type_ids = _decode_token_type_ids(input_ids) - inputs_embeds = self.word_embeddings(input_ids) + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) @@ -76,7 +83,6 @@ class BertEmbedding(nn.Module): class BertPooler(Pooler): - def __init__(self, config: BertConfig): super().__init__() @@ -111,19 +117,22 @@ class BertPooler(Pooler): class BertEncoder(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.layer = nn.ModuleList([ - BertLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + BertLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -135,12 +144,13 @@ class BertEncoder(nn.Module): class BertLayer(nn.Module): - - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.attention = BertAttention( @@ -149,20 +159,24 @@ class BertLayer(nn.Module): layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attention") + prefix=f"{prefix}.attention", + ) self.intermediate = BertIntermediate( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.intermediate") + prefix=f"{prefix}.intermediate", + ) - self.output = BertOutput(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - layer_norm_eps=config.layer_norm_eps, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.output = BertOutput( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) def forward(self, hidden_states: torch.Tensor): attn_output = self.attention(hidden_states) @@ -172,7 +186,6 @@ class BertLayer(nn.Module): class BertAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -184,16 +197,20 @@ class BertAttention(nn.Module): ): super().__init__() - self.self = BertSelfAttention(hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - 
cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.self = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) - self.output = BertSelfOutput(hidden_size=hidden_size, - layer_norm_eps=layer_norm_eps, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.output = BertSelfOutput( + hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) def forward( self, @@ -204,7 +221,6 @@ class BertAttention(nn.Module): class BertSelfAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -237,15 +253,18 @@ class BertSelfAttention(nn.Module): total_num_kv_heads=self.total_num_kv_heads, bias=True, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + prefix=f"{prefix}.qkv_proj", + ) - self.attn = EncoderOnlyAttention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = EncoderOnlyAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -258,41 +277,48 @@ class BertSelfAttention(nn.Module): class BertSelfOutput(nn.Module): - - def __init__(self, - hidden_size: int, - layer_norm_eps: float, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.dense = RowParallelLinear(input_size=hidden_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - def forward(self, hidden_states: torch.Tensor, - input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertIntermediate(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.dense = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.intermediate_act_fn = get_act_fn(hidden_act) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -302,25 +328,29 @@ class BertIntermediate(nn.Module): class BertOutput(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - layer_norm_eps: float, - quant_config: 
Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.dense = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - def forward(self, hidden_states: torch.Tensor, - input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states @@ -329,7 +359,6 @@ class BertOutput(nn.Module): @support_torch_compile @default_pooling_type("CLS") class BertModel(nn.Module, SupportsQuant): - is_pooling_model = True packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} @@ -345,8 +374,10 @@ class BertModel(nn.Module, SupportsQuant): self.config = vllm_config.model_config.hf_config self.embeddings = embedding_class(self.config) - self.encoder = BertEncoder(vllm_config=vllm_config, - prefix=f"{prefix}.encoder") + self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.word_embeddings(input_ids) def forward( self, @@ -355,11 +386,12 @@ class BertModel(nn.Module, SupportsQuant): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=positions) + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + ) + return self.encoder(hidden_states) def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -374,7 +406,7 @@ class BertModel(nn.Module, SupportsQuant): other_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -392,8 +424,7 @@ class BertModel(nn.Module, SupportsQuant): return other_weights, loaded_stacked_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) loader = AutoWeightsLoader(self, skip_prefixes=["pooler."]) @@ -404,7 +435,6 @@ class BertModel(nn.Module, SupportsQuant): @default_pooling_type("ALL") class BertPoolingModel(BertModel): - is_pooling_model = True def __init__( @@ -423,8 +453,7 @@ class BertPoolingModel(BertModel): config = vllm_config.model_config.hf_config self.pooler = BertPooler(config) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) loader = AutoWeightsLoader(self) @@ 
-453,10 +482,14 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.model = self._build_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._build_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.pooler = self._build_pooler(pooler_config) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -464,34 +497,35 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) - has_model_prefix = any( - name.startswith("model.") for name, _ in weights_list) + has_model_prefix = any(name.startswith("model.") for name, _ in weights_list) if not has_model_prefix: mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) return loader.load_weights(weights_list, mapper=mapper) - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> BertModel: - return BertModel(vllm_config=vllm_config, - prefix=prefix, - embedding_class=BertEmbedding) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: + return BertModel( + vllm_config=vllm_config, prefix=prefix, embedding_class=BertEmbedding + ) def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: - return DispatchPooler({ - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }) + return DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) # Here we encode the token type ids together with the input ids. @@ -518,18 +552,18 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): TOKEN_TYPE_SHIFT = 30 -def _encode_token_type_ids(input_ids: torch.Tensor, - token_type_ids: torch.Tensor) -> None: +def _encode_token_type_ids( + input_ids: torch.Tensor, token_type_ids: torch.Tensor +) -> None: # input_ids can be padded to the right - input_ids[:token_type_ids.shape[0]].bitwise_or_( - token_type_ids << TOKEN_TYPE_SHIFT) + input_ids[: token_type_ids.shape[0]].bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT) def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - - ids_mask = torch.ones_like(input_ids, - dtype=torch.int32, - device=input_ids.device) << TOKEN_TYPE_SHIFT + ids_mask = ( + torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device) + << TOKEN_TYPE_SHIFT + ) tokens_mask = ids_mask.bitwise_not() token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT @@ -540,17 +574,16 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: @default_pooling_type("CLS") -class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsQuant): +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
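The TOKEN_TYPE_SHIFT trick above packs token type ids into an otherwise unused high bit of the input ids and recovers them inside the model. A tiny standalone sketch, assuming a vocabulary smaller than 2**30 so bit 30 is free:

```python
# Standalone illustration of the bit-packing used by
# _encode_token_type_ids / _decode_token_type_ids (not part of the patch).
import torch

TOKEN_TYPE_SHIFT = 30

input_ids = torch.tensor([101, 2023, 102, 2003, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# Pack: set bit 30 for tokens that belong to segment 1.
packed = input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)

# Unpack: read bit 30 back out and clear it to restore the raw ids.
mask = torch.ones_like(packed) << TOKEN_TYPE_SHIFT
recovered_types = (packed & mask) >> TOKEN_TYPE_SHIFT
recovered_ids = packed & mask.bitwise_not()

assert torch.equal(recovered_ids, input_ids)
assert torch.equal(recovered_types, token_type_ids)
```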
- This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ is_pooling_model = True @@ -559,32 +592,42 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, config = vllm_config.model_config.hf_config self.num_labels = config.num_labels - self.bert = BertPoolingModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=BertEmbedding) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.bert = BertPoolingModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + ) + self.classifier = nn.Linear( + config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -599,13 +642,73 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if token_type_ids is not None: assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) assert input_ids is not None _encode_token_type_ids(input_ids, token_type_ids) - return self.bert(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.bert( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) + + +@default_pooling_type("ALL") +class BertForTokenClassification(nn.Module): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.head_dtype = vllm_config.model_config.head_dtype + self.num_labels = config.num_labels + self.bert = BertModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + ) + self.classifier = nn.Linear( + 
config.hidden_size, config.num_labels, dtype=self.head_dtype + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + } + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.bert.get_input_embeddings(input_ids) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if token_type_ids is not None: + assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + + hidden_states = self.bert( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) + + hidden_states = hidden_states.to(self.head_dtype) + return self.classifier(hidden_states) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 129450927e564..05cb0e22a0aad 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -10,50 +10,59 @@ from transformers import PretrainedConfig from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import (get_act_and_mul_fn, - get_act_fn) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, torch_vllm_outplace_fused_experts) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.activation import get_act_and_mul_fn, get_act_fn +from vllm.model_executor.layers.fused_moe import activation_without_mul, fused_topk +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, - default_pooling_type) -from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from 
..layers.pooler import ClassifierPooler, DispatchPooler, Pooler +from .bert import BertPooler +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): - def __init__(self, config: PretrainedConfig): - super().__init__() if config.position_embedding_type not in ["rope", "rotary"]: - raise ValueError("Only 'rotary'('rope') position_embedding_type" + - " is supported") + raise ValueError( + "Only 'rotary'('rope') position_embedding_type" + " is supported" + ) - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) if config.type_vocab_size > 0: self.token_type_embeddings = VocabParallelEmbedding( - config.type_vocab_size, config.hidden_size) + config.type_vocab_size, config.hidden_size + ) else: self.token_type_embeddings = None - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, @@ -66,9 +75,9 @@ class BertWithRopeEmbedding(nn.Module): embeddings = inputs_embeds if self.token_type_embeddings is not None: if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=inputs_embeds.device + ) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings += token_type_embeddings @@ -78,7 +87,6 @@ class BertWithRopeEmbedding(nn.Module): class BertWithRopeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -115,23 +123,28 @@ class BertWithRopeAttention(nn.Module): total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + prefix=f"{prefix}.qkv_proj", + ) self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = EncoderOnlyAttention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = EncoderOnlyAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - self.out_proj = RowParallelLinear(input_size=hidden_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.out_proj = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) def forward( self, @@ -147,14 +160,15 @@ class BertWithRopeAttention(nn.Module): class BertWithRopeGatedMLP(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.act_fn = get_act_and_mul_fn(hidden_act) self.gate_up_proj = MergedColumnParallelLinear( @@ -164,11 +178,13 @@ class BertWithRopeGatedMLP(nn.Module): quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) - self.down_proj = 
RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(hidden_states) @@ -178,26 +194,31 @@ class BertWithRopeGatedMLP(nn.Module): class BertWithRopeMLP(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.act_fn = get_act_fn(hidden_act) - self.up_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.up_proj(hidden_states) @@ -207,7 +228,6 @@ class BertWithRopeMLP(nn.Module): class NomicMoE(nn.Module): - def __init__( self, num_experts: int, @@ -226,34 +246,46 @@ class NomicMoE(nn.Module): self.hidden_size = hidden_size self.total_intermediate_size = intermediate_size self.intermediate_size = divide(intermediate_size, self.tp_size) - self.hidden_act = hidden_act + self.hidden_act = activation_without_mul(hidden_act) if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False) + self.router = ReplicatedLinear( + self.hidden_size, self.num_total_experts, bias=False + ) self.w1 = nn.Parameter( - torch.empty(self.num_total_experts, - self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2 = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.bias = nn.Parameter(torch.zeros(self.hidden_size)) - set_weight_attrs(self.w1, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w1, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2, + { + "weight_loader": self.weight_loader, + }, + ) def weight_loader( self, @@ -289,37 +321,36 @@ class NomicMoE(nn.Module): # 
FIXME(Isotr0py): This implementation is too tricky, # we should use FusedMoE instead in the future # after supporting ungated activation for it. - topk_weights, topk_ids, _ = fused_topk(hidden_states, - router_logits, - self.top_k, - renormalize=False) - final_hidden_states = torch_vllm_outplace_fused_experts( + topk_weights, topk_ids, _ = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=False + ) + + final_hidden_states = torch.ops.vllm.outplace_fused_experts( hidden_states=hidden_states, w1=self.w1, w2=self.w2, topk_weights=topk_weights, topk_ids=topk_ids, activation=self.hidden_act, - is_act_and_mul=False, ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) + self.bias class BertWithRopeBlock(nn.Module): - - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - moe: bool = False, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + moe: bool = False, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): super().__init__() self.attn = BertWithRopeAttention( hidden_size=config.hidden_size, @@ -328,14 +359,17 @@ class BertWithRopeBlock(nn.Module): quant_config=quant_config, bias=bias, rotary_kwargs=rotary_kwargs, - prefix=f"{prefix}.attention") + prefix=f"{prefix}.attention", + ) if moe: - self.mlp = NomicMoE(num_experts=config.num_experts, - top_k=config.moe_top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act) + self.mlp = NomicMoE( + num_experts=config.num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) else: if config.hidden_act in ["silu", "geglu"]: self.mlp = BertWithRopeGatedMLP( @@ -344,7 +378,8 @@ class BertWithRopeBlock(nn.Module): hidden_act=config.hidden_act, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) else: self.mlp = BertWithRopeMLP( hidden_size=config.hidden_size, @@ -352,12 +387,11 @@ class BertWithRopeBlock(nn.Module): hidden_act=config.hidden_act, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) - self.attn_ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp_ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): attn_output = self.attn(positions, hidden_states) @@ -368,27 +402,32 @@ class BertWithRopeBlock(nn.Module): class BertWithRopeEncoder(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, - prefix: str = ""): + def __init__( + self, + vllm_config: VllmConfig, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = 
vllm_config.quant_config every_n = getattr(config, "moe_every_n_layers", 0) - self.layers = nn.ModuleList([ - BertWithRopeBlock(config=config, - cache_config=cache_config, - quant_config=quant_config, - bias=bias, - moe=every_n > 0 and (layer_idx % every_n == 1), - rotary_kwargs=rotary_kwargs, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + BertWithRopeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + moe=every_n > 0 and (layer_idx % every_n == 1), + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -405,16 +444,28 @@ class BertWithRopeEncoder(nn.Module): class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False, + ): super().__init__() self.vllm_config = vllm_config + self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config self.embeddings = BertWithRopeEmbedding(self.config) self.encoder = BertWithRopeEncoder( vllm_config=vllm_config, bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, - prefix=f"{prefix}.encoder") + prefix=f"{prefix}.encoder", + ) + self.pooler = BertPooler(self.config) if add_pooling_layer else None + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) def forward( self, @@ -427,12 +478,12 @@ class BertWithRope(nn.Module, SupportsQuant): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings( + input_ids=input_ids, token_type_ids=token_type_ids + ) return self.encoder(positions, hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) if self.config.hidden_act in ["silu", "geglu"]: @@ -447,9 +498,9 @@ class BertWithRope(nn.Module, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if not self.add_pooling_layer and "pooler" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -465,8 +516,7 @@ class BertWithRope(nn.Module, SupportsQuant): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if name.endswith((".w1", ".w2")): # Nomic-MoE has fused experts weights weight_loader(param, loaded_weight, name) @@ -493,7 +543,8 @@ class NomicBertModel(BertWithRope): "experts.mlp.": "", "experts.": "", "router.layer": "router", - }) + } + ) class GteNewModel(BertWithRope): @@ -505,10 +556,11 @@ class GteNewModel(BertWithRope): "layer": "layers", "attention.qkv_proj": "attn.qkv_proj", "attention.o_proj": "attn.out_proj", - }) + } + ) 
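The NomicMoE forward above routes each token to its top-k experts with fused_topk and then runs un-gated expert MLPs through a fused kernel. A loop-based reference sketch of that routing (assumed shapes and GELU activation, not the fused implementation):

```python
# Simplified reference of top-k MoE routing with un-gated expert MLPs.
# Illustrative only; the real path uses fused_topk + outplace_fused_experts.
import torch
import torch.nn.functional as F

def moe_reference(x, router_w, w1, w2, top_k):
    # x: [tokens, hidden]   router_w: [experts, hidden]
    # w1: [experts, inter, hidden]   w2: [experts, hidden, inter]
    probs = (x @ router_w.t()).softmax(dim=-1)      # router scores
    topk_w, topk_ids = probs.topk(top_k, dim=-1)    # renormalize=False
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        expert_out = F.gelu(x[token_idx] @ w1[e].t()) @ w2[e].t()
        out.index_add_(0, token_idx, expert_out * topk_w[token_idx, slot, None])
    return out

# e.g. 8 experts, top-2 routing over a [tokens, hidden] batch
x = torch.randn(5, 16)
y = moe_reference(x, torch.randn(8, 16), torch.randn(8, 32, 16),
                  torch.randn(8, 16, 32), top_k=2)
```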
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # GteNewModel only gate_up_proj does not have bias. # Hack method learned from vllm/model_executor/models/glm.py @@ -526,15 +578,13 @@ class GteNewModel(BertWithRope): else: yield name, weight - def ignore_unnecessary_layers(self, - weights: Iterable[tuple[str, torch.Tensor]]): + def ignore_unnecessary_layers(self, weights: Iterable[tuple[str, torch.Tensor]]): for name, weight in weights: if name.startswith("classifier"): continue yield name, weight - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.ignore_unnecessary_layers(weights) weights = self.split_up_gate_proj(weights) return super().load_weights(weights) @@ -548,7 +598,8 @@ class SnowflakeGteNewModel(GteNewModel): "layer": "layers", "attention.qkv_proj": "attn.qkv_proj", "attention.o_proj": "attn.out_proj", - }) + } + ) class JinaRobertaModel(BertWithRope): @@ -563,11 +614,11 @@ class JinaRobertaModel(BertWithRope): "mlp.fc1.": "mlp.up_proj.", "mlp.fc2": "mlp.down_proj", "norm2": "mlp_ln", - }) + } + ) @torch.inference_mode() - def jina_merge_lora_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]): + def jina_merge_lora_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # use for jina-embeddings-v3 # Merge Lora weights into a single weight tensor. # This is a temporary solution until we have a better way to handle @@ -588,7 +639,7 @@ class JinaRobertaModel(BertWithRope): if o in name: dtype = weights[name].dtype shape = weights[name].shape - weight_name = name[:-len(o)] + weight_name = name[: -len(o)] if "embeddings" in weight_name: B = weights[weight_name + a][i].to(device).float() @@ -597,19 +648,90 @@ class JinaRobertaModel(BertWithRope): B = weights[weight_name + b][i].to(device).float() A = weights[weight_name + a][i].to(device).float() - weight = (weights[weight_name + o].to(device) + - torch.matmul(B, A).view(shape) * scaling) + weight = ( + weights[weight_name + o].to(device) + + torch.matmul(B, A).view(shape) * scaling + ) weight = weight.cpu().to(dtype) weights[weight_name.replace(".parametrizations", "")] = weight - del weights[weight_name + o], weights[weight_name + - a], weights[weight_name + - b] + del ( + weights[weight_name + o], + weights[weight_name + a], + weights[weight_name + b], + ) return [(name, weight) for name, weight in weights.items()] - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) + + +@default_pooling_type("CLS") +class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.new = GteNewModel( + vllm_config=vllm_config, prefix=prefix, add_pooling_layer=True + ) + self.classifier = ReplicatedLinear( + config.hidden_size, + config.num_labels, + bias=True, + quant_config=quant_config, + params_dtype=vllm_config.model_config.head_dtype, + prefix=maybe_prefix(prefix, 
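jina_merge_lora_weights above folds each low-rank adapter back into its base weight at load time. A minimal sketch of that merge, with illustrative shapes and scaling:

```python
# Minimal sketch of a LoRA merge: the low-rank update B @ A is added back
# into the base weight. Shapes and scaling below are assumptions.
import torch

def merge_lora(base: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               scaling: float) -> torch.Tensor:
    # base: [out, in]   lora_a: [rank, in]   lora_b: [out, rank]
    return base + (lora_b @ lora_a).view_as(base) * scaling

w = torch.randn(8, 4)
a = torch.randn(2, 4)   # rank-2 adapter
b = torch.randn(8, 2)
merged = merge_lora(w, a, b, scaling=0.5)
```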
"classifier"), + return_bias=False, + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.new.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.new( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 2b457fd8a5b25..aa361e0a2a398 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal implementation of BlipVisionModel intended to be only used +"""Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" + from collections.abc import Iterable from typing import Optional, Union @@ -12,9 +13,11 @@ from transformers import Blip2VisionConfig, BlipVisionConfig from vllm.attention.layer import MultiHeadAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -27,14 +30,14 @@ def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_blip_patch_grid_length(image_size=image_size, - patch_size=patch_size) + grid_length = get_blip_patch_grid_length( + image_size=image_size, patch_size=patch_size + ) return grid_length * grid_length # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa class BlipVisionEmbeddings(nn.Module): - def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]): super().__init__() @@ -52,25 +55,28 @@ class BlipVisionEmbeddings(nn.Module): stride=self.patch_size, ) - self.num_patches = get_blip_num_patches(image_size=self.image_size, - patch_size=self.patch_size) + self.num_patches = get_blip_num_patches( + image_size=self.image_size, patch_size=self.patch_size + ) self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter( - torch.randn(1, self.num_positions, self.embed_dim)) + 
torch.randn(1, self.num_positions, self.embed_dim) + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) position_embeds = self.position_embedding.to(target_dtype) - embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + embeddings = embeddings + position_embeds[:, : embeddings.size(1), :] return embeddings @@ -93,7 +99,8 @@ class BlipAttention(nn.Module): raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -115,12 +122,16 @@ class BlipAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -137,7 +148,6 @@ class BlipAttention(nn.Module): class BlipMLP(nn.Module): - def __init__( self, config: BlipVisionConfig, @@ -149,16 +159,20 @@ class BlipMLP(nn.Module): self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -169,7 +183,6 @@ class BlipMLP(nn.Module): class BlipEncoderLayer(nn.Module): - def __init__( self, config: BlipVisionConfig, @@ -184,13 +197,9 @@ class BlipEncoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm1 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = BlipMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.layer_norm2 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = 
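As a quick sanity check of the patch arithmetic used by get_blip_num_patches and BlipVisionEmbeddings above (numbers chosen for illustration only):

```python
# e.g. a 224x224 image split into 14x14 patches
image_size, patch_size = 224, 14
grid_length = image_size // patch_size       # 16
num_patches = grid_length * grid_length      # 256
num_positions = num_patches + 1              # 257, including the class token
```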
hidden_states @@ -209,7 +218,7 @@ class BlipEncoderLayer(nn.Module): class BlipEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`BlipEncoderLayer`]. Args: @@ -232,12 +241,16 @@ class BlipEncoder(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - BlipEncoderLayer(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + BlipEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): hidden_states = inputs_embeds @@ -284,8 +297,9 @@ class BlipVisionModel(nn.Module, SupportsQuant): require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) else: self.post_layernorm = None @@ -298,8 +312,7 @@ class BlipVisionModel(nn.Module, SupportsQuant): return self.post_layernorm(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -312,8 +325,7 @@ class BlipVisionModel(nn.Module, SupportsQuant): for name, loaded_weight in weights: # post_layernorm is not needed in BlipVisionModel - if (name.startswith("post_layernorm") - and self.post_layernorm is None): + if name.startswith("post_layernorm") and self.post_layernorm is None: continue # omit layers when num_hidden_layers_override is set @@ -322,7 +334,7 @@ class BlipVisionModel(nn.Module, SupportsQuant): if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -332,8 +344,7 @@ class BlipVisionModel(nn.Module, SupportsQuant): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 2f2b880bb0e14..8e94d59350268 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -6,33 +6,43 @@ from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn -from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, - apply_chunking_to_forward) +from transformers import ( + BatchFeature, + Blip2Config, + Blip2QFormerConfig, + apply_chunking_to_forward, +) from vllm.config import CacheConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import 
(MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip import BlipVisionModel -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) - -# We use this internally as placeholders since there is no image token -# defined on the HuggingFace repo -_IMAGE_TOKEN_ID = 50265 +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix class Blip2ImagePixelInputs(TensorSchema): @@ -43,6 +53,7 @@ class Blip2ImagePixelInputs(TensorSchema): - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -54,6 +65,7 @@ class Blip2ImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] @@ -62,7 +74,6 @@ Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] class Blip2QFormerMultiHeadAttention(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -83,8 +94,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module): ) self.num_attention_heads = config.num_attention_heads - self.attention_head_size = (config.hidden_size // - config.num_attention_heads) + self.attention_head_size = config.hidden_size // config.num_attention_heads self.all_head_size = self.num_attention_heads * self.attention_head_size self.scaling = self.attention_head_size**-0.5 @@ -96,18 +106,18 @@ class Blip2QFormerMultiHeadAttention(nn.Module): self.key = nn.Linear(kv_hidden_size, self.all_head_size) self.value = nn.Linear(kv_hidden_size, self.all_head_size) - self.position_embedding_type = getattr(config, - "position_embedding_type", - "absolute") + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) if self.position_embedding_type != "absolute": - raise NotImplementedError("Unsupported position_embedding_type: " - f"{self.position_embedding_type}") + raise NotImplementedError( + f"Unsupported position_embedding_type: {self.position_embedding_type}" + ) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): - x = x.view(*x.size()[:-1], self.num_attention_heads, - self.attention_head_size) + x = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size) return x.permute(0, 2, 1, 3) def forward( @@ -118,10 +128,8 @@ class Blip2QFormerMultiHeadAttention(nn.Module): is_cross_attention = encoder_hidden_states is not None if is_cross_attention: - key_layer = self.transpose_for_scores( - self.key(encoder_hidden_states)) - value_layer = 
self.transpose_for_scores( - self.value(encoder_hidden_states)) + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -130,10 +138,8 @@ class Blip2QFormerMultiHeadAttention(nn.Module): query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, - key_layer.transpose(-1, -2)) - attention_probs = torch.softmax(attention_scores * self.scaling, - dim=-1) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -142,20 +148,19 @@ class Blip2QFormerMultiHeadAttention(nn.Module): context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - context_layer = context_layer.view(*context_layer.size()[:-2], - self.all_head_size) + context_layer = context_layer.view( + *context_layer.size()[:-2], self.all_head_size + ) return context_layer class Blip2QFormerSelfOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -170,7 +175,6 @@ class Blip2QFormerSelfOutput(nn.Module): class Blip2QFormerAttention(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -207,7 +211,6 @@ class Blip2QFormerAttention(nn.Module): class Blip2QFormerIntermediate(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() @@ -221,13 +224,11 @@ class Blip2QFormerIntermediate(nn.Module): class Blip2QFormerOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -242,7 +243,6 @@ class Blip2QFormerOutput(nn.Module): class Blip2QFormerLayer(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -256,10 +256,12 @@ class Blip2QFormerLayer(nn.Module): self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = Blip2QFormerAttention(config, - quant_config=quant_config, - cache_config=cache_config, - prefix=f"{prefix}.attention") + self.attention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.attention", + ) self.layer_idx = layer_idx @@ -269,15 +271,16 @@ class Blip2QFormerLayer(nn.Module): quant_config=quant_config, cache_config=cache_config, is_cross_attention=True, - prefix=f"{prefix}.crossattention") + prefix=f"{prefix}.crossattention", + ) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate_query = Blip2QFormerIntermediate( - config, 
prefix=f"{prefix}.intermediate_query") - self.output_query = Blip2QFormerOutput(config, - prefix=f"{prefix}.output_query") + config, prefix=f"{prefix}.intermediate_query" + ) + self.output_query = Blip2QFormerOutput(config, prefix=f"{prefix}.output_query") def forward( self, @@ -310,8 +313,7 @@ class Blip2QFormerLayer(nn.Module): self.seq_len_dim, attention_output[:, query_length:, :], ) - layer_output = torch.cat([layer_output, layer_output_text], - dim=1) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, @@ -322,21 +324,18 @@ class Blip2QFormerLayer(nn.Module): return layer_output - def feed_forward_chunk(self, - attention_output: torch.Tensor) -> torch.Tensor: + def feed_forward_chunk(self, attention_output: torch.Tensor) -> torch.Tensor: intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output - def feed_forward_chunk_query( - self, attention_output: torch.Tensor) -> torch.Tensor: + def feed_forward_chunk_query(self, attention_output: torch.Tensor) -> torch.Tensor: intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class Blip2QFormerEncoder(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -349,14 +348,18 @@ class Blip2QFormerEncoder(nn.Module): self.config = config - self.layer = nn.ModuleList([ - Blip2QFormerLayer(config, - quant_config=quant_config, - cache_config=cache_config, - layer_idx=layer_idx, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + Blip2QFormerLayer( + config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -378,7 +381,6 @@ class Blip2QFormerEncoder(nn.Module): # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 class Blip2QFormerModel(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -391,14 +393,15 @@ class Blip2QFormerModel(nn.Module): self.config = config - self.layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.encoder = Blip2QFormerEncoder(config, - quant_config=quant_config, - cache_config=cache_config, - prefix=f"{prefix}.encoder") + self.encoder = Blip2QFormerEncoder( + config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.encoder", + ) def forward( self, @@ -420,7 +423,6 @@ class Blip2QFormerModel(nn.Module): class Blip2ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Blip2Config) @@ -433,7 +435,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo): class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -441,6 +442,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ 
-448,16 +450,19 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): max_image_size = vision_config.image_size num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -510,11 +515,15 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, - info=Blip2ProcessingInfo, - dummy_inputs=Blip2DummyInputsBuilder) -class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_inputs=Blip2DummyInputsBuilder, +) +class Blip2ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant +): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -524,7 +533,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, raise ValueError("Only image modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -537,13 +545,15 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.vision_model = BlipVisionModel(config.vision_config, quant_config) self.query_tokens = nn.Parameter( - torch.zeros(1, config.num_query_tokens, - config.qformer_config.hidden_size)) + torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size) + ) - self.qformer = Blip2QFormerModel(config.qformer_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.qformer") + self.qformer = Blip2QFormerModel( + config.qformer_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.qformer", + ) self.language_projection = nn.Linear( config.qformer_config.hidden_size, @@ -558,10 +568,12 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _create_image_input(self, - **kwargs: object) -> Optional[Blip2ImageInputs]: + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Blip2ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -570,50 +582,44 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, if pixel_values is not None: expected_h = expected_w = self.config.vision_config.image_size - return Blip2ImagePixelInputs(type="pixel_values", - data=flatten_bn(pixel_values, - concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return Blip2ImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) if image_embeds is not None: return Blip2ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + 
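With `merge_by_field_config = True`, the processor-side machinery now delivers `pixel_values` already merged across requests, which is why the explicit `flatten_bn(..., concat=True)` call is dropped in `_parse_and_validate_image_input`. Conceptually the merge amounts to concatenating per-request image tensors along the leading dimension; the snippet below is a rough, non-vLLM illustration of that "bn" flattening with made-up sizes.

```python
# Rough illustration of the batch/image flattening that flatten_bn used to do
# by hand. Sizes are illustrative only.
import torch

per_request_pixels = [
    torch.randn(1, 3, 224, 224),  # request 0: one image
    torch.randn(2, 3, 224, 224),  # request 1: two images
]
pixel_values = torch.cat(per_request_pixels, dim=0)
print(pixel_values.shape)  # torch.Size([3, 3, 224, 224]) -> the "bn" dimension
```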
data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def _image_pixels_to_features(self, vision_model: BlipVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: - + def _image_pixels_to_features( + self, vision_model: BlipVisionModel, pixel_values: torch.Tensor + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower image_features = vision_model(pixel_values) return image_features - def _process_image_pixels(self, - inputs: Blip2ImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor: assert self.vision_model is not None pixel_values = inputs["data"] return self._image_pixels_to_features(self.vision_model, pixel_values) - def _process_image_input(self, - image_input: Blip2ImageInputs) -> torch.Tensor: - + def _process_image_input(self, image_input: Blip2ImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None image_features = self._process_image_pixels(image_input) - query_tokens = self.query_tokens.expand(image_features.shape[0], -1, - -1) + query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) query_output = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_features, @@ -624,27 +630,13 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - _IMAGE_TOKEN_ID) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -665,7 +657,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends dummy tokens (denoted as `50265`), resulting in: `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`. @@ -678,39 +670,26 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. - + Info: - [Blip2ImageInputs][] + [`Blip2ImageInputs`][vllm.model_executor.models.blip2.Blip2ImageInputs] """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
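`_process_image_input` above expands the learned query tokens to one copy per image, lets them attend to the vision features inside the Q-Former, and projects the result into the language model's hidden size. The following is a shape-only sketch with hypothetical dimensions; the Q-Former itself is faked with a random tensor of the right shape.

```python
# Shape-level sketch of the BLIP-2 image path: query tokens -> Q-Former -> projection.
# All sizes are assumptions for illustration.
import torch
import torch.nn as nn

num_images, num_patches, vision_hidden = 3, 257, 1408
num_query_tokens, qformer_hidden, lm_hidden = 32, 768, 2560

query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, qformer_hidden))
language_projection = nn.Linear(qformer_hidden, lm_hidden)

image_features = torch.randn(num_images, num_patches, vision_hidden)
queries = query_tokens.expand(image_features.shape[0], -1, -1)

# Stand-in for the Q-Former: in the real model `queries` cross-attend to
# `image_features`; here we only fake an output with the correct shape.
query_output = torch.randn(num_images, num_query_tokens, qformer_hidden)

vision_embeddings = language_projection(query_output)
print(vision_embeddings.shape)  # torch.Size([3, 32, 2560]): 32 LM embeddings per image
```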
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6e4a399f3cc6e..4a814fc4020d7 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -18,8 +18,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -29,30 +31,40 @@ from transformers import BloomConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -60,22 +72,20 @@ 
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class BloomAttention(nn.Module): - def __init__( self, config: BloomConfig, @@ -115,13 +125,15 @@ class BloomAttention(nn.Module): alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -137,7 +149,6 @@ class BloomAttention(nn.Module): class BloomMLP(nn.Module): - def __init__( self, config: BloomConfig, @@ -165,7 +176,6 @@ class BloomMLP(nn.Module): class BloomBlock(nn.Module): - def __init__( self, config: BloomConfig, @@ -176,17 +186,17 @@ class BloomBlock(nn.Module): super().__init__() hidden_size = config.hidden_size - self.input_layernorm = nn.LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.post_attention_layernorm = nn.LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + hidden_size, eps=config.layer_norm_epsilon + ) self.mlp = BloomMLP(config, quant_config) self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm) + config.apply_residual_connection_post_layernorm + ) def forward( self, @@ -223,7 +233,6 @@ class BloomBlock(nn.Module): @support_torch_compile class BloomModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -240,23 +249,26 @@ class BloomModel(nn.Module): self.embed_dim, ) self.word_embeddings_layernorm = nn.LayerNorm( - self.embed_dim, eps=config.layer_norm_epsilon) + self.embed_dim, eps=config.layer_norm_epsilon + ) # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: BloomBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.h", + ) # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], 
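The `_get_alibi_slopes` helper reformatted above builds the per-head ALiBi slopes: a geometric series for the largest power-of-two head count, plus interleaved slopes from the next power of two when the head count is not a power of two. `BloomAttention` then keeps only the `head_start:head_end` slice for its tensor-parallel rank. Here is a compact, self-contained restatement of that formula with example outputs (the values shown are computed from the same formula, not taken from a checkpoint).

```python
# Worked example of the ALiBi slope computation used by BLOOM.
import math
import torch

def alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest = 2 ** math.floor(math.log2(total_num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest))
    if closest != total_num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest) - 3)))
        n_extra = min(closest, total_num_heads - closest)
        extra = torch.pow(torch.tensor(extra_base), torch.arange(1, 1 + 2 * n_extra, 2))
        slopes = torch.cat([slopes, extra])
    return slopes

print(alibi_slopes(8))   # geometric series: 0.5, 0.25, 0.125, ..., 0.0039
print(alibi_slopes(12))  # the 8 slopes above, then 0.7071, 0.3536, 0.1768, 0.0884
# BloomAttention slices this per TP rank: alibi_slopes(total)[head_start:head_end]
```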
config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + return self.word_embeddings(input_ids) def forward( self, @@ -270,18 +282,18 @@ class BloomModel(nn.Module): hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -299,40 +311,43 @@ class BloomModel(nn.Module): if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): - +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = BloomModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = BloomModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.word_embeddings else: - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -344,30 +359,28 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = 
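The `query_key_value` handling in `BloomModel.load_weights` above converts BLOOM's per-head-interleaved fused QKV weight (rows ordered `q_h0, k_h0, v_h0, q_h1, ...`) into the grouped layout expected by vLLM's `QKVParallelLinear` (`all q heads, then all k, then all v`). The snippet below reproduces just that view/transpose/reshape on a toy tensor, assuming `output_dim == 0` for simplicity.

```python
# Toy demonstration of the QKV re-layout in BloomModel.load_weights.
# Tiny made-up sizes; assumes the output dimension is dim 0.
import torch

num_heads, head_dim = 2, 3
hidden = num_heads * head_dim
loaded_weight = torch.arange(3 * hidden * hidden).reshape(3 * hidden, hidden)

shape = loaded_weight.shape
w = loaded_weight.view((num_heads, 3, -1) + shape[1:])  # (heads, qkv, head_dim, hidden)
w = w.transpose(0, 1)                                   # (qkv, heads, head_dim, hidden)
w = w.reshape(shape)                                    # rows now grouped as q | k | v

q_rows, k_rows, v_rows = w.split(hidden, dim=0)
print(q_rows.shape, k_rows.shape, v_rows.shape)  # each torch.Size([6, 6])
```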
self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) weights = _add_transformer_prefix(weights) return loader.load_weights(weights) def _add_transformer_prefix( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: for name, tensor in weights: - if not name.startswith('transformer.'): - name = 'transformer.' + name + if not name.startswith("transformer."): + name = "transformer." + name yield name, tensor diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index e6914ad4c495d..d8756e236f4cc 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -3,48 +3,73 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property +from itertools import islice from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, - ChameleonVQVAEConfig) +from transformers import ( + BatchFeature, + ChameleonConfig, + ChameleonProcessor, + ChameleonVQVAEConfig, +) from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + row_parallel_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + 
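`_add_transformer_prefix` above is a small generator that rewrites checkpoint names so they line up with the `transformer.` submodule of `BloomForCausalLM`. A quick usage sketch (checkpoint names are illustrative):

```python
# Usage sketch of the _add_transformer_prefix helper.
import torch

def _add_transformer_prefix(weights):
    for name, tensor in weights:
        if not name.startswith("transformer."):
            name = "transformer." + name
        yield name, tensor

ckpt = [
    ("word_embeddings.weight", torch.zeros(4, 2)),      # hypothetical entries
    ("transformer.ln_f.weight", torch.ones(2)),
]
print([name for name, _ in _add_transformer_prefix(ckpt)])
# ['transformer.word_embeddings.weight', 'transformer.ln_f.weight']
```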
BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (flatten_bn, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -57,12 +82,12 @@ class ChameleonImagePixelInputs(TensorSchema): - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] class ChameleonProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(ChameleonConfig) @@ -77,9 +102,7 @@ class ChameleonProcessingInfo(BaseProcessingInfo): return processor.image_seq_length -class ChameleonDummyInputsBuilder( - BaseDummyInputsBuilder[ChameleonProcessingInfo]): - +class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -92,23 +115,26 @@ class ChameleonDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: config = self.info.get_hf_config() width = height = config.vq_config.resolution num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=width, - height=height, - num_images=num_images) + "image": self._get_dummy_images( + width=width, + height=height, + num_images=num_images, + overrides=image_overrides, + ) } -class ChameleonMultiModalProcessor( - BaseMultiModalProcessor[ChameleonProcessingInfo]): - +class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -177,29 +203,23 @@ class ChameleonMultiModalProcessor( class ChameleonLayerNorm(nn.LayerNorm): - def __init__(self, hidden_size, *args, **kwargs): super().__init__(hidden_size, *args, **kwargs) - self.normalized_shape = (hidden_size[-1], ) + self.normalized_shape = (hidden_size[-1],) - set_weight_attrs(self.weight, - {"weight_loader": row_parallel_weight_loader}) - set_weight_attrs(self.bias, - {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader}) def forward(self, hidden_states): - hidden_states = F.layer_norm(hidden_states, - self.normalized_shape, - None, - None, - eps=1e-5) + hidden_states = F.layer_norm( + hidden_states, self.normalized_shape, None, None, eps=1e-5 + ) hidden_states = hidden_states * self.weight + self.bias return hidden_states # Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP class ChameleonMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -213,14 +233,18 @@ class ChameleonMLP(nn.Module): input_size=hidden_size, output_sizes=[intermediate_size] * 2, bias=bias, - 
quant_config=quant_config) - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config) + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -232,7 +256,6 @@ class ChameleonMLP(nn.Module): # Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa class ChameleonAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -293,16 +316,19 @@ class ChameleonAttention(nn.Module): rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # reshape for layernorm q = q.reshape(-1, self.num_heads, self.head_dim) k = k.reshape(-1, self.num_kv_heads, self.head_dim) @@ -328,7 +354,6 @@ class ChameleonAttention(nn.Module): class ChameleonDecoderLayer(nn.Module): - def __init__( self, config: ChameleonConfig, @@ -341,17 +366,19 @@ class ChameleonDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = ChameleonAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -367,10 +394,10 @@ class ChameleonDecoderLayer(nn.Module): quant_config=quant_config, bias=getattr(config, "mlp_bias", False), ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -378,28 +405,24 @@ class ChameleonDecoderLayer(nn.Module): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - 
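`ChameleonAttention._apply_qk_norm` above temporarily exposes the head dimension so the q/k normalization runs per head, then flattens back before rotary embedding and attention. A shape-only sketch follows; plain `nn.LayerNorm` over `head_dim` stands in for `ChameleonLayerNorm`, and the flatten-back step is assumed.

```python
# Shape sketch of the per-head q/k normalization in ChameleonAttention.
# Stand-in norms and assumed sizes.
import torch
import torch.nn as nn

num_tokens, num_heads, num_kv_heads, head_dim = 5, 8, 4, 16
q_norm, k_norm = nn.LayerNorm(head_dim), nn.LayerNorm(head_dim)

q = torch.randn(num_tokens, num_heads * head_dim)
k = torch.randn(num_tokens, num_kv_heads * head_dim)

q = q_norm(q.reshape(-1, num_heads, head_dim)).view(num_tokens, -1)
k = k_norm(k.reshape(-1, num_kv_heads, head_dim)).view(num_tokens, -1)
print(q.shape, k.shape)  # torch.Size([5, 128]) torch.Size([5, 64])
```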
hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class ChameleonSwinDecoderLayer(nn.Module): - def __init__( self, config: ChameleonConfig, @@ -412,17 +435,19 @@ class ChameleonSwinDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = ChameleonAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -438,10 +463,10 @@ class ChameleonSwinDecoderLayer(nn.Module): quant_config=quant_config, bias=getattr(config, "mlp_bias", False), ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -449,7 +474,6 @@ class ChameleonSwinDecoderLayer(nn.Module): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - residual = hidden_states hidden_states = self.self_attn( positions=positions, @@ -470,7 +494,6 @@ class ChameleonSwinDecoderLayer(nn.Module): # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa class ChameleonVQVAEVectorQuantizer(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() self.num_embeddings = config.num_embeddings @@ -486,55 +509,52 @@ class ChameleonVQVAEVectorQuantizer(nn.Module): # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z distances = ( - torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + - torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, - self.embedding.weight.transpose(0, 1))) + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", + hidden_state_flattened, + self.embedding.weight.transpose(0, 1), + ) + ) min_encoding_indices = torch.argmin(distances, dim=1) hidden_state_quant = self.embedding(min_encoding_indices).view( - hidden_state.shape) + hidden_state.shape + ) # compute loss for embedding - loss = torch.mean((hidden_state_quant.detach() - hidden_state)** - 2) + self.beta * torch.mean( - (hidden_state_quant - 
hidden_state.detach())**2) + loss = torch.mean( + (hidden_state_quant.detach() - hidden_state) ** 2 + ) + self.beta * torch.mean((hidden_state_quant - hidden_state.detach()) ** 2) # preserve gradients - hidden_state_quant = hidden_state + (hidden_state_quant - - hidden_state).detach() + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() # reshape back to match original input shape - hidden_state_quant = hidden_state_quant.permute(0, 3, 1, - 2).contiguous() + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() return hidden_state_quant, loss, min_encoding_indices # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa class ChameleonVQVAEEncoderConvDownsample(nn.Module): - def __init__(self, in_channels: int): super().__init__() - self.conv = nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, hidden_states: torch.Tensor): # no asymmetric padding in torch conv, must do it ourselves - hidden_states = F.pad(hidden_states, - pad=(0, 1, 0, 1), - mode="constant", - value=0) + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) hidden_states = self.conv(hidden_states) return hidden_states # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa class ChameleonVQVAEEncoderResnetBlock(nn.Module): - def __init__( self, config: ChameleonVQVAEConfig, @@ -544,42 +564,31 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module): ): super().__init__() self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None \ - else out_channels + self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut - self.norm1 = torch.nn.GroupNorm(num_groups=32, - num_channels=in_channels, - eps=1e-6, - affine=True) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - self.norm2 = torch.nn.GroupNorm(num_groups=32, - num_channels=out_channels, - eps=1e-6, - affine=True) + self.norm1 = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = torch.nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) self.dropout = torch.nn.Dropout(config.dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, hidden_states: torch.Tensor): residual = hidden_states @@ -603,35 +612,25 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module): # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa class 
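The vector quantizer reformatted above expands the squared distance as |z|² + |e|² − 2 z·e, picks the nearest codebook entry per position, and uses the straight-through trick so gradients flow back to the encoder. A minimal standalone sketch of that lookup, with made-up sizes and a flattened input:

```python
# Minimal sketch of nearest-codebook lookup + straight-through estimator,
# mirroring ChameleonVQVAEVectorQuantizer. Illustrative sizes only.
import torch
import torch.nn as nn

num_embeddings, embed_dim = 16, 8
codebook = nn.Embedding(num_embeddings, embed_dim)

z = torch.randn(4, embed_dim, requires_grad=True)  # flattened encoder output

distances = (
    z.pow(2).sum(dim=1, keepdim=True)
    + codebook.weight.pow(2).sum(dim=1)
    - 2 * z @ codebook.weight.t()
)                                          # (4, num_embeddings)
indices = distances.argmin(dim=1)          # discrete image-token ids
z_q = codebook(indices)                    # quantized vectors

# Straight-through: forward uses z_q, backward treats quantization as identity.
z_q = z + (z_q - z).detach()
print(indices.shape, z_q.shape)  # torch.Size([4]) torch.Size([4, 8])
```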
ChameleonVQVAEEncoderAttnBlock(nn.Module): - def __init__(self, in_channels: int): super().__init__() self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=32, - num_channels=in_channels, - eps=1e-6, - affine=True) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, hidden_states: torch.Tensor): residual = hidden_states @@ -642,20 +641,20 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module): # compute attention batch_size, channels, height, width = query_states.shape - query_states = query_states.reshape(batch_size, channels, - height * width).permute(0, 2, 1) + query_states = query_states.reshape( + batch_size, channels, height * width + ).permute(0, 2, 1) key_states = key_states.reshape(batch_size, channels, height * width) attn_weights = torch.bmm(query_states, key_states) - attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = attn_weights * (int(channels) ** (-0.5)) attn_weights = F.softmax(attn_weights, dim=2) # attend to values - value_states = value_states.reshape(batch_size, channels, - height * width) + value_states = value_states.reshape(batch_size, channels, height * width) attn_weights = attn_weights.permute(0, 2, 1) - attn_output = torch.bmm(value_states, - attn_weights).reshape(batch_size, channels, - height, width) + attn_output = torch.bmm(value_states, attn_weights).reshape( + batch_size, channels, height, width + ) attn_output = self.proj_out(attn_output) return residual + attn_output @@ -663,7 +662,6 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module): # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa class ChameleonVQVAEEncoder(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() @@ -676,14 +674,12 @@ class ChameleonVQVAEEncoder(nn.Module): latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier - self.conv_in = torch.nn.Conv2d(in_channels, - base_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, base_channels, kernel_size=3, stride=1, padding=1 + ) curr_res = resolution - in_channel_multiplier = (1, ) + tuple(channel_multiplier) + in_channel_multiplier = (1,) + tuple(channel_multiplier) self.in_channel_multiplier = in_channel_multiplier self.down = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -697,11 +693,14 @@ class ChameleonVQVAEEncoder(nn.Module): config=config, in_channels=block_in, out_channels=block_out, - )) + ) + ) block_in = block_out - if (config.attn_resolutions is not None - and curr_res in config.attn_resolutions - and config.attn_type == "vanilla"): + if ( + config.attn_resolutions is not None + and curr_res in 
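The VQ-VAE attention block above uses 1x1 convolutions as per-pixel q/k/v projections and attends over the H*W spatial positions of the feature map via batched matmuls. Below is a toy shape walk-through (independent projections, assumed sizes) of the same sequence of reshapes.

```python
# Shape sketch of ChameleonVQVAEEncoderAttnBlock.forward on a toy feature map.
import torch
import torch.nn as nn

b, c, h, w = 1, 32, 8, 8
q_proj, k_proj, v_proj = (nn.Conv2d(c, c, kernel_size=1) for _ in range(3))

x = torch.randn(b, c, h, w)
q = q_proj(x).reshape(b, c, h * w).permute(0, 2, 1)        # (B, HW, C)
k = k_proj(x).reshape(b, c, h * w)                         # (B, C, HW)
attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)   # (B, HW, HW)

v = v_proj(x).reshape(b, c, h * w)                         # (B, C, HW)
out = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)
print(out.shape)  # torch.Size([1, 32, 8, 8]); the block adds this to the residual
```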
config.attn_resolutions + and config.attn_type == "vanilla" + ): attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) down = nn.Module() @@ -718,18 +717,20 @@ class ChameleonVQVAEEncoder(nn.Module): in_channels=block_in, out_channels=block_in, ) - self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( - block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.attn_1 = ( + ChameleonVQVAEEncoderAttnBlock(block_in) + if config.attn_type == "vanilla" + else nn.Identity() + ) self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) - self.norm_out = torch.nn.GroupNorm(num_groups=32, - num_channels=block_in, - eps=1e-6, - affine=True) + self.norm_out = torch.nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) self.conv_out = torch.nn.Conv2d( block_in, 2 * latent_channels if double_latent else latent_channels, @@ -745,15 +746,12 @@ class ChameleonVQVAEEncoder(nn.Module): hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - hidden_state = self.down[i_level].block[i_block]( - hidden_states[-1]) + hidden_state = self.down[i_level].block[i_block](hidden_states[-1]) if len(self.down[i_level].attn) > 0: - hidden_state = self.down[i_level].attn[i_block]( - hidden_state) + hidden_state = self.down[i_level].attn[i_block](hidden_state) hidden_states.append(hidden_state) if i_level != self.num_resolutions - 1: - hidden_states.append(self.down[i_level].downsample( - hidden_states[-1])) + hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) # middle last_hidden_state = hidden_states[-1] @@ -770,15 +768,14 @@ class ChameleonVQVAEEncoder(nn.Module): # Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa class ChameleonVQVAE(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() self.encoder = ChameleonVQVAEEncoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) - self.quant_conv = torch.nn.Conv2d(config.latent_channels, - config.embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, - config.latent_channels, 1) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d( + config.embed_dim, config.latent_channels, 1 + ) self.eval() # Chameleon's VQ model is frozen def encode( @@ -806,10 +803,9 @@ class ChameleonImageVocabularyMapping: @cached_property def image_tokens(self): - return sorted([ - val for name, val in self.vocab_map.items() - if name.startswith("IMGIMG") - ]) + return sorted( + [val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")] + ) @cached_property def bpe2img(self): @@ -817,13 +813,10 @@ class ChameleonImageVocabularyMapping: def remap(old_name: str) -> str: return "".join( - img_tkn_chr_mapping.get(c, c) - for c in old_name[len("IMGIMG"):-1]) + img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1] + ) - return { - tok: int(remap(self.val2name[tok])) - for tok in self.image_tokens - } + return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} @cached_property def img2bpe(self): @@ -832,7 +825,8 @@ class ChameleonImageVocabularyMapping: @cached_property def bpe2img_search_tensors(self): return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor( - sorted(self.bpe2img.values())) + sorted(self.bpe2img.values()) + ) @cached_property def img2bpe_mapping_tensor(self): @@ -848,7 +842,6 @@ class 
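`ChameleonImageVocabularyMapping` above translates between the small VQ-VAE codebook indices and the dedicated `IMGIMG...` BPE ids in the tokenizer vocabulary. The exact tensor construction is not shown in this hunk, so the snippet below is only a hedged sketch of the lookup pattern such a mapping enables; every id value in it is made up.

```python
# Hedged sketch of an image-index -> BPE-id lookup table. Hypothetical ids.
import torch

img2bpe = {0: 50003, 1: 50007, 2: 50011, 5: 50020}  # made-up pairs

mapping = torch.zeros(max(img2bpe) + 1, dtype=torch.long)
mapping[torch.tensor(list(img2bpe.keys()))] = torch.tensor(list(img2bpe.values()))

image_tokens = torch.tensor([[0, 2], [5, 1]])   # VQ indices from the encoder
bpe_tokens = mapping[image_tokens]              # ids the language model can embed
print(bpe_tokens)  # tensor([[50003, 50011], [50020, 50007]])
```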
ChameleonImageVocabularyMapping: class ChameleonModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -862,25 +855,29 @@ class ChameleonModel(nn.Module): self.vocab_size, config.hidden_size, ) - self.vocabulary_mapping = ChameleonImageVocabularyMapping( - config.vocabulary_map) - decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ + self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + decoder_layer = ( + ChameleonDecoderLayer + if not self.config.swin_norm else ChameleonSwinDecoderLayer + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.vqmodel = ChameleonVQVAE(config.vq_config) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -914,17 +911,16 @@ class ChameleonModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -932,12 +928,16 @@ class ChameleonModel(nn.Module): @MULTIMODAL_REGISTRY.register_processor( ChameleonMultiModalProcessor, info=ChameleonProcessingInfo, - dummy_inputs=ChameleonDummyInputsBuilder) -class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsQuant): + dummy_inputs=ChameleonDummyInputsBuilder, +) +class ChameleonForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant +): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -953,24 +953,29 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - self.model = ChameleonModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = ChameleonModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = 
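This hunk (like the BLOOM and ChatGLM ones) swaps `self.layers[self.start_layer:self.end_layer]` for `islice(self.layers, self.start_layer, self.end_layer)`, which walks only this pipeline rank's range of the `ModuleList` without materializing a sliced copy. A small illustration of the pattern, using `nn.Identity` placeholders and assumed rank boundaries:

```python
# Toy illustration of iterating only this PP rank's layer range with islice.
from itertools import islice

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Identity() for _ in range(8)])
start_layer, end_layer = 2, 6  # assume this rank owns layers [2, 6)

x = torch.zeros(1)
for i, layer in enumerate(islice(layers, start_layer, end_layer), start=start_layer):
    x = layer(x)
    print("ran layer", i)  # prints layers 2..5 only
```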
LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + self, **kwargs: object + ) -> Optional[ChameleonImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: @@ -979,42 +984,26 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, vq_config: ChameleonVQVAEConfig = self.config.vq_config expected_h = expected_w = vq_config.resolution - return ChameleonImagePixelInputs(type="pixel_values", - data=flatten_bn(pixel_values, - concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return ChameleonImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens(image_input["data"].to( - self.config.torch_dtype)) + image_tokens = self.model.get_image_tokens( + image_input["data"].to(self.config.torch_dtype) + ) vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.model.vocabulary_mapping.image_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1023,31 +1012,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) # Disallow image tokens which does not include special # begin-image and end-image tokens @@ -1057,8 +1034,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1073,8 +1049,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -1092,8 +1067,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, # not vqvae for now. use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1113,7 +1087,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
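The "disallow image tokens" step mentioned in `compute_logits` keeps the text decoder from emitting raw image-code ids directly; the masking lines themselves are outside this hunk, so the sketch below only illustrates the general idea of pushing those logits to the minimum representable value. The token ids are placeholders, not Chameleon's real vocabulary.

```python
# Hedged sketch of masking disallowed token ids in the output logits.
import torch

vocab_size = 32
image_token_ids = torch.tensor([10, 11, 12, 13])  # hypothetical image-code ids

logits = torch.randn(2, vocab_size)
logits[:, image_token_ids] = torch.finfo(logits.dtype).min

# Sampling/argmax can no longer select the masked ids.
assert logits[:, image_token_ids].max() == torch.finfo(logits.dtype).min
```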
kv-scale is not loaded.", # noqa: E501 @@ -1126,15 +1101,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) if use_default_weight_loading and name in params_dict: if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5470ff3e8b612..ece719df61f7c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -3,8 +3,10 @@ # Adapted from # https://github.com/zai-org/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" + import json from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -17,27 +19,34 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GLMAttention(nn.Module): - def __init__( self, config: ChatGLMConfig, @@ -52,9 +61,11 @@ class GLMAttention(nn.Module): assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention - self.total_num_kv_heads = (config.multi_query_group_num - if config.multi_query_attention else - config.num_attention_heads) + self.total_num_kv_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. 
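`GLMAttention` above sizes its KV heads from `multi_query_group_num` when multi-query attention is enabled, then either partitions those heads across tensor-parallel ranks or replicates them when there are fewer KV heads than ranks. The helper below is a pure-arithmetic illustration of that rule, not the vLLM implementation.

```python
# Arithmetic sketch of KV-head partitioning vs. replication under TP.
def kv_heads_per_rank(total_kv_heads: int, tp_size: int) -> int:
    if total_kv_heads >= tp_size:
        # Partition: each rank owns an equal share of the KV heads.
        assert total_kv_heads % tp_size == 0
        return total_kv_heads // tp_size
    # Replicate: each rank keeps one KV head, shared by several ranks.
    assert tp_size % total_kv_heads == 0
    return 1

print(kv_heads_per_rank(total_kv_heads=32, tp_size=4))  # 8 (partitioned)
print(kv_heads_per_rank(total_kv_heads=2, tp_size=2))   # 1 (partitioned)
print(kv_heads_per_rank(total_kv_heads=2, tp_size=8))   # 1 (each head replicated on 4 ranks)
```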
@@ -99,13 +110,15 @@ class GLMAttention(nn.Module): base=10000 * rope_ratio, is_neox_style=is_neox_style, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -183,25 +196,27 @@ class GLMBlock(nn.Module): ): super().__init__() self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm) + config.apply_residual_connection_post_layernorm + ) self.fp32_residual_connection = config.fp32_residual_connection layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = layer_norm_func(config.hidden_size, - eps=config.layernorm_epsilon) + self.input_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon + ) # Self attention. - self.self_attention = GLMAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + self.self_attention = GLMAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output self.post_attention_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) # MLP self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp") @@ -261,8 +276,7 @@ class GLMTransformer(nn.Module): # Transformer layers. self.start_layer, self.end_layer, self.layers = make_layers( self.num_layers, - lambda prefix: GLMBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: GLMBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -270,20 +284,22 @@ class GLMTransformer(nn.Module): layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. 
self.final_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> Union[torch.Tensor, IntermediateTensors]: - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states = layer(hidden_states=hidden_states, - position_ids=position_ids) + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states = layer( + hidden_states=hidden_states, position_ids=position_ids + ) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -298,8 +314,10 @@ class GLMTransformer(nn.Module): @support_torch_compile class ChatGLMModel(nn.Module, SupportsQuant): packed_modules_mapping = { - "linear_proj.merged_proj": - ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"] + "linear_proj.merged_proj": [ + "linear_proj.gate_proj", + "linear_proj.dense_h_to_4h", + ] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -311,26 +329,30 @@ class ChatGLMModel(nn.Module, SupportsQuant): self.config = config - self.embedding = VocabParallelEmbedding(config.padded_vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embedding") + self.embedding = VocabParallelEmbedding( + config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embedding", + ) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, - cache_config, - quant_config, - prefix=f"{prefix}.encoder") + self.encoder = GLMTransformer( + config, cache_config, quant_config, prefix=f"{prefix}.encoder" + ) - self.output_layer = ParallelLMHead(config.padded_vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.output_layer") + self.output_layer = ParallelLMHead( + config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.output_layer", + ) self.make_empty_intermediate_tensors = ( - self.encoder.make_empty_intermediate_tensors) + self.encoder.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) @@ -360,8 +382,7 @@ class ChatGLMModel(nn.Module, SupportsQuant): return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), @@ -371,7 +392,7 @@ class ChatGLMModel(nn.Module, SupportsQuant): loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -392,8 +413,7 @@ class ChatGLMModel(nn.Module, SupportsQuant): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + 
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -401,7 +421,8 @@ class ChatGLMModel(nn.Module, SupportsQuant): class ChatGLMBaseModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={".word_embeddings": ""}, ) + orig_to_new_substr={".word_embeddings": ""}, + ) def __init__( self, @@ -420,26 +441,26 @@ class ChatGLMBaseModel(nn.Module): self.multimodal_config = multimodal_config self.quant_config = quant_config - self.max_position_embeddings = getattr(config, "max_sequence_length", - 8192) - self.transformer = transformer_type(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) + self.transformer = transformer_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: - self.transformer.output_layer.weight = ( - self.transformer.embedding.weight) + self.transformer.output_layer.weight = self.transformer.embedding.weight self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -447,11 +468,10 @@ class ChatGLMBaseModel(nn.Module): return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, - SupportsQuant): +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQuant): packed_modules_mapping = { "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"] + "dense_h_to_4h": ["dense_h_to_4h"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -462,7 +482,8 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, "The configuration of this model indicates that it supports " "vision inputs, but you instantiated the text-only version " "of this model. 
Please use the vision model by setting " - f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + f"`--hf-overrides '{json.dumps(hf_overrides)}'`" + ) super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -473,6 +494,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index dcab008228704..f05d5c4cc1d8b 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,36 +1,88 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal implementation of CLIPVisionModel intended to be only used -within a vision language model.""" -from collections.abc import Iterable -from typing import Optional, Union +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn -from transformers import CLIPVisionConfig +from transformers import ( + BatchFeature, + CLIPConfig, + CLIPProcessor, + CLIPTextConfig, + CLIPVisionConfig, +) +from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, maybe_prefix +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + VisionFeatureSelectStrategyStr, + get_num_selected_vision_tokens, + resolve_visual_encoder_outputs, +) + + +class CLIPImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: 
Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): - def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + 1 + return self.get_patch_grid_length() ** 2 + 1 def get_image_size(self) -> int: return self.vision_config.image_size @@ -44,9 +96,215 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): return image_size // patch_size -# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa -class CLIPVisionEmbeddings(nn.Module): +_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = { + "MEAN": "full", + "ALL": "full", + "CLS": "class", + # This lets us use the same pooling type for both text and image + "LAST": "class", +} + +def _get_vision_feature_select_strategy(pooling_type: str): + try: + return _POOLING_TYPE_TO_STRATEGY[pooling_type] + except KeyError: + raise ValueError( + f"No feature selection strategy is defined for " + f"pooling_type: {pooling_type!r}" + ) from None + + +class CLIPProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(CLIPConfig) + + def get_vision_encoder_info(self): + return CLIPEncoderInfo(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(CLIPProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + pooler_config = self.ctx.model_config.pooler_config + assert pooler_config is not None + + return get_num_selected_vision_tokens( + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + _get_vision_feature_select_strategy(pooler_config.pooling_type), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): + @cached_property + def image_token_id(self) -> int: + tokenizer = self.info.get_tokenizer() + dummy_token_id = 0 + + assert dummy_token_id not in tokenizer.all_special_ids + + return dummy_token_id + + def 
apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> MultiModalInputs: + if prompt and mm_data: + raise ValueError( + "CLIP accepts text-only or image-only inputs, not both! " + "Image-only inputs means passing an image with an empty text " + "prompt." + ) + + if mm_data: + # For multi-modal data, the prompt after processing should + # only contain the dummy image tokens + tokenization_kwargs = { + **(tokenization_kwargs or {}), + "add_special_tokens": False, + } + + return super().apply( + prompt=prompt, + mm_data=mm_data, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + image_token_id = self.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + +# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + + embed_dim = config.hidden_size + + self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim) + self.position_embedding = VocabParallelEmbedding( + config.max_position_embeddings, embed_dim + ) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is None: + if input_ids is None: + raise ValueError( + "Either `input_ids` or `input_embeds` must be provided" + ) + + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPVisionEmbeddings(nn.Module): def __init__(self, config: CLIPVisionConfig): super().__init__() self.config = config @@ -65,19 +323,21 @@ class CLIPVisionEmbeddings(nn.Module): bias=False, ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) - self.register_buffer("position_ids", - torch.arange(self.num_positions).expand((1, -1)), - persistent=False) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( 
+ "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) @@ -88,15 +348,16 @@ class CLIPVisionEmbeddings(nn.Module): class CLIPAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", - ): + attn_cls: Union[type[Attention], type[MultiHeadAttention]], + ) -> None: super().__init__() + self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -105,7 +366,8 @@ class CLIPAttention(nn.Module): raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( @@ -126,8 +388,12 @@ class CLIPAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -144,26 +410,29 @@ class CLIPAttention(nn.Module): class CLIPMLP(nn.Module): - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -174,29 +443,26 @@ class CLIPMLP(nn.Module): class CLIPEncoderLayer(nn.Module): - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() self.self_attn = CLIPAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_cls=attn_cls, ) - self.layer_norm1 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.layer_norm2 = 
nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -222,10 +488,12 @@ class CLIPEncoder(nn.Module): def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() @@ -235,15 +503,22 @@ class CLIPEncoder(nn.Module): num_hidden_layers = config.num_hidden_layers else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - CLIPEncoderLayer(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + attn_cls=attn_cls, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( - self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + self, + inputs_embeds: torch.Tensor, + return_all_hidden_states: bool, ) -> Union[torch.Tensor, list[torch.Tensor]]: hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds @@ -259,8 +534,85 @@ class CLIPEncoder(nn.Module): return hidden_states -class CLIPVisionTransformer(nn.Module): +class CLIPTextTransformer(nn.Module): + def __init__( + self, + config: CLIPTextConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPTextEmbeddings(config) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + attn_cls=Attention, + ) + + self.final_layer_norm = nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.token_embedding(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=False, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = 
params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CLIPVisionTransformer(nn.Module): def __init__( self, config: CLIPVisionConfig, @@ -286,6 +638,7 @@ class CLIPVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, ) num_hidden_layers = config.num_hidden_layers @@ -300,73 +653,47 @@ class CLIPVisionTransformer(nn.Module): require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) else: self.post_layernorm = None - def forward( - self, - pixel_values: torch.Tensor, - feature_sample_layers: Optional[list[int]] = None, - ) -> torch.Tensor: - - hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - return_all_hidden_states = feature_sample_layers is not None - - # Produces either the last layer output or all of the hidden states, - # depending on if we have feature_sample_layers or not - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - return_all_hidden_states=return_all_hidden_states) - - # Handle post-norm (if applicable) and stacks feature layers if needed - encoder_outputs = resolve_visual_encoder_outputs( - encoder_outputs, feature_sample_layers, self.post_layernorm, - self.config.num_hidden_layers) - - return encoder_outputs - - -class CLIPVisionModel(nn.Module, SupportsQuant): - config_class = CLIPVisionConfig - main_input_name = "pixel_values" - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - - def __init__( - self, - config: CLIPVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.vision_model = CLIPVisionTransformer( - config=config, - quant_config=quant_config, - num_hidden_layers_override=num_hidden_layers_override, - require_post_norm=require_post_norm, - prefix=f"{prefix}.vision_model") - - def forward( - self, - pixel_values: torch.Tensor, - feature_sample_layers: Optional[list[int]] = None, - ) -> torch.Tensor: - return self.vision_model(pixel_values, feature_sample_layers) + @property + def dtype(self): + return next(self.parameters()).dtype @property def device(self): return next(self.parameters()).device - # (TODO) Add prefix argument for filtering out weights to be loaded - # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def forward( + self, + pixel_values: torch.Tensor, + *, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + # Produces either the last layer output or all of the hidden states, + # depending on if we have select_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=select_layers is not None, + ) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + 
encoder_outputs, + self.post_layernorm, + select_layers=select_layers, + max_possible_layers=self.config.num_hidden_layers, + feature_select_strategy=feature_select_strategy, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -375,21 +702,20 @@ class CLIPVisionModel(nn.Module, SupportsQuant): ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - layer_count = len(self.vision_model.encoder.layers) + layer_count = len(self.encoder.layers) for name, loaded_weight in weights: # post_layernorm is not needed in CLIPVisionModel - if (name.startswith("vision_model.post_layernorm") - and self.vision_model.post_layernorm is None): + if name.startswith("post_layernorm") and self.post_layernorm is None: continue # omit layers when num_hidden_layers_override is set - if name.startswith("vision_model.encoder.layers"): - layer_idx = int(name.split(".")[3]) + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -400,8 +726,239 @@ class CLIPVisionModel(nn.Module, SupportsQuant): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class CLIPVisionModel(nn.Module): + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + ) -> torch.Tensor: + return self.vision_model( + pixel_values, + select_layers=select_layers, + feature_select_strategy=feature_select_strategy, + ) + + @property + def dtype(self): + return self.vision_model.dtype + + @property + def device(self): + return self.vision_model.device + + +# Assume EOS token corresponds to LAST token in text model +@default_pooling_type("LAST") +@MULTIMODAL_REGISTRY.register_processor( + CLIPMultiModalProcessor, + info=CLIPProcessingInfo, + dummy_inputs=CLIPDummyInputsBuilder, +) +class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): + is_pooling_model = True + + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + merge_by_field_config = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: CLIPConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = 
vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer( + text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_model"), + ) + self.vision_model = CLIPVisionTransformer( + vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, + self.projection_dim, + bias=False, + ) + self.text_projection = nn.Linear( + self.text_embed_dim, + self.projection_dim, + bias=False, + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler_config = pooler_config + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + # Assumes that self.forward is called after self.get_input_embeddings + self._is_text_input = True + + def get_text_features( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pooled_output = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + text_features = self.text_projection(pooled_output) + + return text_features + + def get_image_features( + self, + pixel_values: torch.Tensor, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + ) -> torch.Tensor: + if feature_select_strategy is None: + feature_select_strategy = _get_vision_feature_select_strategy( + self.pooler_config.pooling_type + ) + + pooled_output = self.vision_model( + pixel_values=pixel_values, + select_layers=None, + feature_select_strategy=feature_select_strategy, + ) + + image_features = self.visual_projection(pooled_output) + + return image_features + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[CLIPImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + expected_h = expected_w = self.config.vision_config.image_size + return CLIPImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) + + def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["data"] + + return self.get_image_features(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.text_model + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + self._is_text_input = ( + multimodal_embeddings is None or len(multimodal_embeddings) == 0 + ) + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = 
self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + vision_embeddings = self._process_image_inputs(image_input) + return vision_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + raise RuntimeError("PP is not supported for this model") + + # Multimodal inputs + if not self._is_text_input: + return inputs_embeds + + # Text inputs + return self.get_text_features( + input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_substrs=[".position_ids"], + ignore_unexpected_prefixes=["logit_scale."], + ) + + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 179cc2af8eb3f..73aafbd011444 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -11,35 +11,44 @@ from torch import nn from transformers import BatchFeature, PretrainedConfig from transformers.models.cohere2_vision import Cohere2VisionConfig from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 - get_optimal_tiled_canvas) + get_optimal_tiled_canvas, +) from transformers.models.cohere2_vision.processing_cohere2_vision import ( - Cohere2VisionProcessor) + Cohere2VisionProcessor, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import MulAndSilu -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class Cohere2VisionImagePixelInputs(TensorSchema): @@ -68,7 +77,7 @@ class Cohere2VisionImagePixelInputs(TensorSchema): class 
Cohere2VisionMultiModalProjector(nn.Module): """Multimodal projector that maps vision features to text embedding space. - + Uses pixel shuffle downsampling followed by SwiGLU activation. """ @@ -77,8 +86,7 @@ class Cohere2VisionMultiModalProjector(nn.Module): self.downsample_factor = config.downsample_factor # Input dimension after pixel shuffle downsampling - input_dim = config.vision_config.hidden_size * ( - config.downsample_factor**2) + input_dim = config.vision_config.hidden_size * (config.downsample_factor**2) # MergedColumnParallelLinear expects the intermediate size to be a list # of sizes, so that it will load the weights as two separate linear # layers before applying any parallelism. @@ -111,28 +119,26 @@ class Cohere2VisionMultiModalProjector(nn.Module): def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor: """Apply pixel shuffle downsampling to reduce spatial dimensions. - + Args: image_features: Input tensor of shape [B, S, D] where S = H*W - + Returns: Downsampled tensor with increased channel dimension """ - height = width = int(image_features.shape[1]**0.5) + height = width = int(image_features.shape[1] ** 0.5) x = image_features.reshape(image_features.shape[0], width, height, -1) n, h, w, c = x.size() - scale_factor = 1. / self.downsample_factor + scale_factor = 1.0 / self.downsample_factor nh = int(h * scale_factor) nw = int(w * scale_factor) - x = x.reshape(n, nh, self.downsample_factor, nw, - self.downsample_factor, c) + x = x.reshape(n, nh, self.downsample_factor, nw, self.downsample_factor, c) x = x.permute(0, 1, 3, 2, 4, 5).contiguous() x = x.reshape(n, nh, nw, -1) return x class Cohere2VisionProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Cohere2VisionConfig: return self.ctx.get_hf_config(Cohere2VisionConfig) @@ -147,8 +153,8 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - height = image_processor.size['height'] - width = image_processor.size['width'] + height = image_processor.size["height"] + width = image_processor.size["width"] max_patches = image_processor.max_patches return ImageSize(height=height * max_patches, width=width) @@ -197,8 +203,8 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): class Cohere2VisionDummyInputsBuilder( - BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]): - + BaseDummyInputsBuilder[Cohere2VisionProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -211,22 +217,26 @@ class Cohere2VisionDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - image_size = \ - self.info.get_image_size_with_most_features() + image_size = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=image_size.width, - height=image_size.height, - num_images=num_images) + "image": self._get_dummy_images( + width=image_size.width, + height=image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } class Cohere2VisionMultiModalProcessor( - BaseMultiModalProcessor[Cohere2VisionProcessingInfo]): - + BaseMultiModalProcessor[Cohere2VisionProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -242,22 +252,26 @@ class Cohere2VisionMultiModalProcessor( ) # Ensure 
num_patches is available for proper tensor splitting - if "num_patches" not in processed_outputs and ( - images := mm_data.get("images")) is not None: + if ( + "num_patches" not in processed_outputs + and (images := mm_data.get("images")) is not None + ): hf_processor = self.info.get_hf_processor(**mm_kwargs) # Fallback calculation if HF processor didn't provide num_patches - parsed_images = self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) num_patches = [ self.info.get_num_patches( image_width=parsed_images.get_image_size(i).width, image_height=parsed_images.get_image_size(i).height, processor=hf_processor, - ) for i in range(len(parsed_images)) + ) + for i in range(len(parsed_images)) ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -270,8 +284,7 @@ class Cohere2VisionMultiModalProcessor( ) -> Mapping[str, MultiModalFieldConfig]: num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -298,8 +311,7 @@ class Cohere2VisionMultiModalProcessor( image_height=image_size.height, processor=hf_processor, ) - patch_tokens = (image_token * img_tokens_per_tile + - img_line_break_token) + patch_tokens = image_token * img_tokens_per_tile + img_line_break_token repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" return PromptUpdateDetails.select_text(repl, image_token) @@ -316,9 +328,10 @@ class Cohere2VisionMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( Cohere2VisionMultiModalProcessor, info=Cohere2VisionProcessingInfo, - dummy_inputs=Cohere2VisionDummyInputsBuilder) -class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=Cohere2VisionDummyInputsBuilder, +) +class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -326,7 +339,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.language_model.": "language_model.model.", "lm_head.": "language_model.lm_head.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -338,37 +352,39 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, self.multimodal_config = multimodal_config self._patch_quant_config(config, quant_config) - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.vocab_size = config.text_config.vocab_size - self.multi_modal_projector = \ - Cohere2VisionMultiModalProjector( - config, prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.multi_modal_projector = Cohere2VisionMultiModalProjector( + config, prefix=maybe_prefix(prefix, "multi_modal_projector") + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, 
"language_model"), - architectures=config.text_config.architectures) + architectures=config.text_config.architectures, + ) @property def dtype(self): return next(self.parameters()).dtype - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs, - **kwargs) -> list[torch.Tensor]: + def _process_image_input( + self, image_input: Cohere2VisionImagePixelInputs, **kwargs + ) -> list[torch.Tensor]: """Process image pixels through vision tower and projector. - + Args: - image_input: Validated image input containing pixel values and + image_input: Validated image input containing pixel values and patch counts - + Returns: List of flattened image embeddings, one per image """ @@ -384,70 +400,52 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.multi_modal_projector(image_features) # Split and flatten embeddings per image - return [ - e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())] def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Cohere2VisionImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) - assert image_embeds is None, \ - "Cohere2Vision does not support image_embeds." + assert image_embeds is None, "Cohere2Vision does not support image_embeds." 
if pixel_values is None: return None return Cohere2VisionImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - num_patches=flatten_bn(num_patches, concat=True), + pixel_values=pixel_values, + num_patches=num_patches, resolve_bindings={ "h": self.config.vision_config.image_size, "w": self.config.vision_config.image_size, - }) + }, + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and (llm_quant_config - is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_tower") def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input, **kwargs) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.image_token_id, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -459,14 +457,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -478,7 +468,5 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 4dd84b8f8fdd5..e38c3c0492fbf 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -22,7 +22,9 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -34,27 +36,33 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name, - row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, + row_parallel_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) @torch.compile(backend=current_platform.simple_compile_backend) @@ -63,30 +71,27 @@ def layer_norm_func(hidden_states, weight, variance_epsilon): hidden_states = hidden_states.to(torch.float32) mean = hidden_states.mean(-1, keepdim=True) variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - mean) * torch.rsqrt(variance + - variance_epsilon) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon) hidden_states = weight.to(torch.float32) * hidden_states return hidden_states.to(input_dtype) class LayerNorm(nn.Module): - def __init__(self, param_shape=None, eps=1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(param_shape)) 
self.variance_epsilon = eps - set_weight_attrs(self.weight, - {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader}) def forward(self, hidden_states, residuals=None): - hidden_states = layer_norm_func(hidden_states, self.weight, - self.variance_epsilon) + hidden_states = layer_norm_func( + hidden_states, self.weight, self.variance_epsilon + ) return hidden_states, residuals # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): - def __init__( self, config: Union[CohereConfig, Cohere2Config], @@ -121,7 +126,6 @@ class CohereMLP(nn.Module): class CohereAttention(nn.Module): - def __init__( self, config: Union[CohereConfig, Cohere2Config], @@ -151,8 +155,8 @@ class CohereAttention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.max_position_embeddings = getattr( - config, "model_max_length", None) or getattr( - config, "max_position_embeddings", 8192) + config, "model_max_length", None + ) or getattr(config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) self.use_qk_norm = getattr(config, "use_qk_norm", False) @@ -190,21 +194,24 @@ class CohereAttention(nn.Module): if config.layer_types[layer_idx] == "sliding_attention": self.sliding_window = config.sliding_window - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - per_layer_sliding_window=self.sliding_window, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn", + ) if self.use_qk_norm: - self.q_norm = LayerNorm(param_shape=(self.num_heads, - self.head_dim), - eps=config.layer_norm_eps) - self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, - self.head_dim), - eps=config.layer_norm_eps) + self.q_norm = LayerNorm( + param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps + ) + self.k_norm = LayerNorm( + param_shape=(self.num_kv_heads, self.head_dim), + eps=config.layer_norm_eps, + ) def _apply_qk_norm(self, q, k): q = q.view(*q.shape[:-1], -1, self.head_dim) @@ -232,25 +239,27 @@ class CohereAttention(nn.Module): class CohereDecoderLayer(nn.Module): - - def __init__(self, - config: Union[CohereConfig, Cohere2Config], - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Union[CohereConfig, Cohere2Config], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = CohereAttention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) - self.mlp = CohereMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) + self.mlp = CohereMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.input_layernorm = LayerNorm( + 
param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) def forward( self, @@ -274,7 +283,6 @@ class CohereDecoderLayer(nn.Module): @support_torch_compile class CohereModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -285,22 +293,29 @@ class CohereModel(nn.Module): self.quant_config = quant_config self.config = config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: CohereDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = LayerNorm( + param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -322,22 +337,20 @@ class CohereModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -349,14 +362,15 @@ class CohereModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -386,8 +400,7 @@ class 
CohereModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -421,13 +434,15 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.quant_config = quant_config - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=config.logit_scale) - self.model = CohereModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale + ) + self.model = CohereModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -440,27 +455,27 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - is_not_lora = hasattr(self.model.embed_tokens, 'weight') + is_not_lora = hasattr(self.model.embed_tokens, "weight") if is_not_lora: - logits = self.logits_processor(self.model.embed_tokens, - hidden_states, sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) else: - logits = self.logits_processor(self.model.embed_tokens.base_layer, - hidden_states, sampling_metadata) + logits = self.logits_processor( + self.model.embed_tokens.base_layer, hidden_states + ) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( - self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"]) + self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"] + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 882df7e8162c5..ee6a3ba773bb8 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -10,21 +10,25 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: - from vllm.config import VllmConfig logger = init_logger(__name__) class VerifyAndUpdateConfig: - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: raise NotImplementedError -class GteNewModelConfig(VerifyAndUpdateConfig): +class Gemma3TextModelConfig: + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + hf_config = vllm_config.model_config.hf_config + hf_config.is_causal = not hf_config.use_bidirectional_attention + +class 
GteNewModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -40,12 +44,11 @@ class GteNewModelConfig(VerifyAndUpdateConfig): "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -54,7 +57,6 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): class JinaRobertaModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -68,29 +70,27 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig): "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": getattr(config, "rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } class NomicBertModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config assert config.__class__.__name__ == "NomicBertConfig" assert config.activation_function in ["swiglu", "gelu"] - config.position_embedding_type = getattr(config, - "position_embedding_type", - "rope") + config.position_embedding_type = getattr( + config, "position_embedding_type", "rope" + ) if config.activation_function == "swiglu": config.hidden_act = "silu" else: config.hidden_act = config.activation_function - assert (config.mlp_fc1_bias == config.mlp_fc2_bias == - config.qkv_proj_bias) + assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias config.bias = config.qkv_proj_bias assert config.rotary_emb_scale_base is None @@ -109,7 +109,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): "rotary_dim": rotary_emb_dim, "max_position": max_trained_positions, "base": getattr(config, "rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } # we ignore config.rotary_scaling_factor so that for datasets shorter @@ -117,15 +117,18 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): # with SentenceTransformer. # The context extension uses vllm style rope_theta and rope_scaling. # See #17785 #18755 - if (not vllm_config.model_config.hf_overrides - and vllm_config.model_config.original_max_model_len is None): + if ( + not vllm_config.model_config.hf_overrides + and vllm_config.model_config.original_max_model_len is None + ): # Default # Reset max_model_len to max_trained_positions. # nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. max_model_len_before = vllm_config.model_config.max_model_len - max_model_len = min(vllm_config.model_config.max_model_len, - max_trained_positions) + max_model_len = min( + vllm_config.model_config.max_model_len, max_trained_positions + ) vllm_config.recalculate_max_model_len(max_model_len) logger.warning( @@ -133,7 +136,9 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): "Changing max_model_len from %s to %s. 
" "To enable context extension, see: " "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", - max_model_len_before, vllm_config.model_config.max_model_len) + max_model_len_before, + vllm_config.model_config.max_model_len, + ) else: # We need to re-verify max_model_len to avoid lengths # greater than position_embedding. @@ -143,7 +148,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): if isinstance(model_config.hf_overrides, dict): # hf_overrides_kw max_model_len = model_config.hf_overrides.get( - "max_model_len", vllm_config.model_config.max_model_len) + "max_model_len", vllm_config.model_config.max_model_len + ) else: # hf_overrides_fn # This might be overridden by sentence_bert_config.json. @@ -165,7 +171,6 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -175,7 +180,6 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -185,36 +189,36 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config - is_original_qwen3_reranker = getattr(config, - "is_original_qwen3_reranker", - False) + is_original_qwen3_reranker = getattr( + config, "is_original_qwen3_reranker", False + ) if not is_original_qwen3_reranker: return tokens = getattr(config, "classifier_from_token", None) - assert tokens is not None and len(tokens) == 2, \ - ("Try loading the original Qwen3 Reranker?, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + assert tokens is not None and len(tokens) == 2, ( + "Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py" + ) vllm_config.model_config.hf_config.method = "from_2_way_softmax" class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config - config.num_labels = 1 + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.logit_bias is None: + pooler_config.logit_bias = 2.65 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -230,53 +234,89 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } -class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig): - - @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config - config.max_seq_len_to_capture = config.max_model_len - logger.info( - "Setting max_seq_len_to_capture to %d " - "to ensure that CUDA graph capture " - "covers sequences of length up to 
max_model_len.", - config.max_model_len) - - class GptOssForCausalLMConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - decoding_config = vllm_config.decoding_config - if decoding_config.reasoning_backend == "": - decoding_config.reasoning_backend = "GptOss" + structured_outputs_config = vllm_config.structured_outputs_config + if structured_outputs_config.reasoning_parser == "": + structured_outputs_config.reasoning_parser = "openai_gptoss" - # Increase the max capture size from 512 to 1024 for performance. + # Increase the max capture size from 512 to 992 for performance. # NOTE(woosuk): This will increase the number of CUDA graphs - # from 67 to 83. + # from 67 to 81. scheduler_config = vllm_config.scheduler_config if len(scheduler_config.cuda_graph_sizes) == 1: max_capture_size = scheduler_config.cuda_graph_sizes[0] # FIXME(woosuk): When using full cuda graph with FA3, the max # supported size is 992. - if max_capture_size < 1024: + if max_capture_size < 992: cuda_graph_sizes = [1, 2, 4] # Step size 8 for small batch sizes cuda_graph_sizes += [i for i in range(8, 256, 8)] # Step size 16 for larger batch sizes - cuda_graph_sizes += [i for i in range(256, 1025, 16)] + cuda_graph_sizes += [i for i in range(256, 993, 16)] scheduler_config.cuda_graph_sizes = cuda_graph_sizes logger.info( - "Overriding max cuda graph capture size to " - "%d for performance.", 1024) + "Overriding max cuda graph capture size to %d for performance.", 992 + ) + + +class MambaModelConfig(VerifyAndUpdateConfig): + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Enable FULL_AND_PIECEWISE cuda graph mode by default (required + to get good performance for mamba layers in V1). + + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + + # Set mamba block size to max_model_len (this may get + # override by prefix caching logic later) + cache_config.mamba_block_size = model_config.max_model_len + + # TODO(@tdoublep) find a better way to do this than whitelist + MAMBA2_MODELS = [ + "BambaForCausalLM", + "FalconH1ForCausalLM", + "GraniteMoeHybridForCausalLM", + "Mamba2ForCausalLM", + "NemotronHForCausalLM", + "Zamba2ForCausalLM", + ] + if cache_config.enable_prefix_caching: + if model_config.architecture in MAMBA2_MODELS: + logger.info( + "Warning: Prefix caching is currently enabled. " + "Its support for Mamba2 layers is experimental. " + "Please report any issues you may observe." + ) + else: + logger.info( + "Hybrid or mamba-based model detected without " + "support for prefix caching: disabling." + ) + cache_config.enable_prefix_caching = False + + # TODO(tdoublep): remove once cascade attention is supported + logger.info( + "Disabling cascade attention since it is not supported for hybrid models." 
+ ) + model_config.disable_cascade_attn = True class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): - @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ @@ -293,6 +333,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): if not envs.VLLM_USE_V1: return + # Enable FULL_AND_PIECEWISE by default + MambaModelConfig.verify_and_update_config(vllm_config) + cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -308,7 +351,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), dtype=kv_cache_dtype, - use_mla=model_config.use_mla).page_size_bytes + ).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, @@ -322,27 +365,75 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): block_size=model_config.max_model_len, ).page_size_bytes - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + # Model may be marked as is_hybrid + # but mamba is skipped via config, + # return directly + if mamba_page_size == 0: + return + + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: 128-byte alignment + # * Other MLA backends: 64-byte alignment + if model_config.use_mla: + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 + else: + kernel_block_alignment_size = 16 + + if cache_config.enable_prefix_caching: + # With prefix caching, select attention block size to + # optimize for mamba kernel performance + + # mamba SSD kernel uses a chunk_size, e.g. 256 + # Align the block to the kernel: use lowest multiple of chunk_size + # of attention tokens that would fit mamba_page_size: + # e.g. for mamba page size = 788kB + # attn_1_token = 2kB -> fits ~394 tokens + # then round up to a multiple of 256 -> 512 tokens + # End result: + # attn_block_size = 512 + # mamba_block_size = 512 (aligned to a multiple of chunk_size) + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. + + from math import gcd + + def lcm(a, b): + return a * b // gcd(a, b) + + base_chunk_size = model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + + chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) + attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, select minimum valid attention block size + # to minimize mamba state padding + + # Calculate minimum attention block size that satisfies both: + # 1. Backend alignment requirements (kernel_block_alignment_size) + # 2. 
Mamba page size compatibility (attn_page_size >= mamba_page_size) + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token + ) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. - if (cache_config.block_size is None - or cache_config.block_size < attn_block_size): + if cache_config.block_size is None or cache_config.block_size < attn_block_size: cache_config.block_size = attn_block_size logger.info( "Setting attention block size to %d tokens " "to ensure that attention page size is >= mamba page size.", - attn_block_size) + attn_block_size, + ) # compute new attention page size - attn_page_size = \ - cache_config.block_size * attn_page_size_1_token + attn_page_size = cache_config.block_size * attn_page_size_1_token assert attn_page_size >= mamba_page_size @@ -351,20 +442,52 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): return # pad mamba page size to exactly match attention - if (cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size): - cache_config.mamba_page_size_padded = (attn_page_size) - mamba_padding_pct = 100 * (attn_page_size - - mamba_page_size) / mamba_page_size + if ( + cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size + ): + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = ( + 100 * (attn_page_size - mamba_page_size) / mamba_page_size + ) logger.info( "Padding mamba page size by %.2f%% to ensure " "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) + "exactly equal.", + mamba_padding_pct, + ) + + +class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32 + """ + hf_config = vllm_config.model_config.hf_config + + # Mirror the check in vllm/model_executor/models/deepseek_v2.py + is_v32 = hasattr(hf_config, "index_topk") + assert is_v32 + + # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. 
+ # "auto") + cache_config = vllm_config.cache_config + if cache_config.cache_dtype == "auto" or cache_config.cache_dtype.startswith( + "fp8" + ): + cache_config.cache_dtype = "fp8_ds_mla" + logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") + if cache_config.cache_dtype == "bfloat16": + cache_config.cache_dtype = "auto" + logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, + "GteNewForSequenceClassification": GteNewModelConfig, + "Gemma3TextModel": Gemma3TextModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, @@ -372,6 +495,9 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JambaForSequenceClassification": JambaForSequenceClassificationConfig, - "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, "GptOssForCausalLM": GptOssForCausalLMConfig, + "MambaForCausalLM": MambaModelConfig, + "Mamba2ForCausalLM": MambaModelConfig, + "FalconMambaForCausalLM": MambaModelConfig, + "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, } diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py deleted file mode 100644 index f03c58a12932f..0000000000000 --- a/vllm/model_executor/models/constant_size_cache.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod -from typing import Any - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID - - -class ConstantSizeCache(ABC): - """ - Abstract base class for managing constant size caches - like Mamba and Minimax. - """ - - def __init__(self, max_batch_size: int): - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the cache - self.cache_indices_mapping: dict[str, dict[int, int]] = {} - self.free_cache_indices = list(range(max_batch_size)) - - @property - @abstractmethod - def cache(self) -> Any: - """Return the underlying cache tensor(s)""" - pass - - @abstractmethod - def _copy_cache(self, from_index: int, to_index: int): - """Copy cache data from one index to another""" - pass - - def current_run_tensors(self, **kwargs) -> tuple: - """ - Return the tensors for the current run's conv and ssm state. 
- """ - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - finished_requests_ids = kwargs["finished_requests_ids"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - - state_indices_tensor = torch.as_tensor(state_indices, - dtype=torch.int32, - device="cuda") - cache_tensors = self.cache - else: - # CUDA graph capturing runs - cache_tensors, state_indices_tensor = kwargs[ - "seqlen_agnostic_capture_inputs"] - - return (cache_tensors, state_indices_tensor) - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant state_indices into the CUDA graph input buffer - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - finished_requests_ids = kwargs["finished_requests_ids"] - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - assert "seqlen_agnostic_capture_inputs" in input_buffers - _, input_state_indices_buffer = input_buffers[ - "seqlen_agnostic_capture_inputs"] - - self._release_finished_requests(finished_requests_ids) - state_indices = self._prepare_current_run_cache( - request_ids_to_seq_ids, finished_requests_ids) - cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( - state_indices) - state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - - input_state_indices_buffer.copy_( - torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Cache during the CUDA graph replay - runs. - """ - state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") - return (self.cache, state_indices_tensor) - - def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, - finished_requests_ids) -> int: - """ - Assign (req_id,seq_id) pair to a `destination_index` index, if - already occupied, move the occupying index to a free index. 
- """ - if cur_rid in finished_requests_ids: - # set as pad, do not allocate destination index - return PAD_SLOT_ID - elif cur_rid not in self.cache_indices_mapping: - destination_index = self.free_cache_indices.pop() - self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} - return destination_index - elif seq_id not in (seq_ids2indices := - self.cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have - # already happened, so we copy the - # existing cache into the siblings seq_ids caches - index_exists = next(iter(seq_ids2indices.values())) - # case of decoding n>1, copy prefill cache to decoding indices - destination_index = self.free_cache_indices.pop() - self._copy_cache(from_index=index_exists, - to_index=destination_index) - self.cache_indices_mapping[cur_rid][seq_id] = destination_index - return destination_index - else: - return self.cache_indices_mapping[cur_rid][seq_id] - - def _prepare_current_run_cache( - self, request_ids_to_seq_ids: dict[str, list[int]], - finished_requests_ids: list[str]) -> list[int]: - return [ - self._assign_seq_id_to_cache_index(req_id, seq_id, - finished_requests_ids) - for req_id, seq_ids in request_ids_to_seq_ids.items() - for seq_id in seq_ids - ] - - def _release_finished_requests(self, - finished_seq_groups_req_ids: list[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.cache_indices_mapping: - for seq_id in self.cache_indices_mapping[req_id]: - self.free_cache_indices.append( - self.cache_indices_mapping[req_id][seq_id]) - self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index e74d90e0b1d7d..8ec7a82a7b2ad 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -10,26 +11,39 @@ from transformers import DbrxConfig from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from 
.utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class DbrxRouter(nn.Module): @@ -60,7 +74,6 @@ class DbrxRouter(nn.Module): class DbrxExperts(FusedMoE): - def __init__( self, config: DbrxConfig, @@ -82,12 +95,16 @@ class DbrxExperts(FusedMoE): ) self.config = config self.d_model = config.d_model - self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // - self.tp_size) + self.intermediate_size = self.config.ffn_config.ffn_hidden_size // self.tp_size # Define custom weight loader for dbrx model - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, param_name: str): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + param_name: str, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -111,8 +128,9 @@ class DbrxExperts(FusedMoE): loaded_weight, [-1, self.intermediate_size * self.tp_size, self.d_model], ) - param_data[:, shard_size:2 * - shard_size, :] = loaded_weight[:, shard, :] + param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[ + :, shard, : + ] elif param_name.endswith("weight_scale"): param_data[:, 1] = loaded_weight else: @@ -151,10 +169,12 @@ class DbrxMoE(nn.Module): self.router = DbrxRouter(config, self.params_dtype) - self.experts = DbrxExperts(config=config, - quant_config=quant_config, - params_dtype=self.params_dtype, - prefix=f"{prefix}.experts") + self.experts = DbrxExperts( + config=config, + quant_config=quant_config, + params_dtype=self.params_dtype, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -166,7 +186,6 @@ class DbrxMoE(nn.Module): class DbrxAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -222,13 +241,15 @@ class DbrxAttention(nn.Module): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -246,7 +267,6 @@ class DbrxAttention(nn.Module): class DbrxFusedNormAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -256,10 +276,9 @@ class DbrxFusedNormAttention(nn.Module): ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = DbrxAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -281,7 +300,6 @@ class DbrxFusedNormAttention(nn.Module): class DbrxBlock(nn.Module): - def __init__( self, config: DbrxConfig, @@ -291,10 +309,8 @@ class DbrxBlock(nn.Module): ): super().__init__() self.norm_attn_norm = DbrxFusedNormAttention( - config, - cache_config, - quant_config, - prefix=f"{prefix}.norm_attn_norm") + config, cache_config, quant_config, prefix=f"{prefix}.norm_attn_norm" + ) self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn") def forward( @@ -312,7 +328,6 @@ class DbrxBlock(nn.Module): 
class DbrxModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -327,19 +342,17 @@ class DbrxModel(nn.Module): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: DbrxBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: DbrxBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks", ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, - nn.Parameter): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. module.register_parameter("bias", None) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.d_model)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.d_model + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -359,31 +372,34 @@ class DbrxModel(nn.Module): else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - expert_params_mapping = [( - "w13" if weight_name in ["w1", "v1"] else "w2", - f"mlp.{weight_name}", - ) for weight_name in ["w1", "v1", "w2"]] + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + expert_params_mapping = [ + ( + "w13" if weight_name in ["w1", "v1"] else "w2", + f"mlp.{weight_name}", + ) + for weight_name in ["w1", "v1", "w2"] + ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -409,39 +425,39 @@ class DbrxModel(nn.Module): if name is None: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class DbrxForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config if config.tie_word_embeddings: - raise ValueError( - "tie_word_embeddings is not supported for Dbrx models.") + raise 
ValueError("tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = DbrxModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -453,20 +469,18 @@ class DbrxForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2f0202f1e038d..67258c2f77b83 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Deepseek model.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -32,34 +34,43 @@ from transformers import PretrainedConfig from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class DeepseekMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -71,17 +82,19 @@ class DeepseekMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results) + quant_config=quant_config, + reduce_results=reduce_results, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -92,7 +105,6 @@ class DeepseekMLP(nn.Module): class DeepseekMoE(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -108,26 +120,29 @@ class DeepseekMoE(nn.Module): if self.tp_size > self.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") + f"the number of experts {self.n_routed_experts}." 
+ ) - self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) + self.experts = nn.ModuleList( + [ + DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + for idx in range(self.n_routed_experts) + ] + ) self.pack_params() - self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, - bias=False, - quant_config=None) + self.gate = ReplicatedLinear( + config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + ) if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -162,24 +177,26 @@ class DeepseekMoE(nn.Module): shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk( + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) + + final_hidden_states = fused_experts( + hidden_states, self.w1, self.w2, topk_weights, topk_ids, inplace=True + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class DeepseekAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -238,13 +255,15 @@ class DeepseekAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -260,7 +279,6 @@ class DeepseekAttention(nn.Module): class DeepseekDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -273,8 +291,7 @@ class DeepseekDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -286,12 +303,14 @@ class DeepseekDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( 
+ config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = DeepseekMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, @@ -300,10 +319,10 @@ class DeepseekDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -316,22 +335,19 @@ class DeepseekDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class DeepseekModel(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -352,11 +368,12 @@ class DeepseekModel(nn.Module): lambda prefix: DeepseekDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -377,18 +394,16 @@ class DeepseekModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -403,7 +418,7 @@ class DeepseekModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -411,8 +426,9 @@ class DeepseekModel(nn.Module): if name.endswith(".bias") 
and name not in params_dict: continue # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue @@ -425,20 +441,24 @@ class DeepseekModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP): +class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -446,16 +466,21 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = DeepseekModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.model = DeepseekModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -467,20 +492,18 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 0c9c83cf61000..faa7edd4bc3c3 100644 --- 
a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -14,19 +14,23 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, - DeepseekV3ForCausalLM) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM, +) from .utils import AutoWeightsLoader, maybe_prefix @support_torch_compile class DeepseekV2Model(nn.Module): - def __init__( self, *, @@ -35,10 +39,7 @@ class DeepseekV2Model(nn.Module): start_layer_id: int = 0, ) -> None: super().__init__() - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config quant_config = vllm_config.quant_config self.vocab_size = self.config.vocab_size @@ -49,15 +50,16 @@ class DeepseekV2Model(nn.Module): prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - DeepseekV2DecoderLayer( - self.config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ) for i in range(self.config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) self.fc = nn.Linear( self.config.model.hidden_size * 2, @@ -65,12 +67,12 @@ class DeepseekV2Model(nn.Module): bias=False, ) - self.enorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - self.hnorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.enorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -81,8 +83,8 @@ class DeepseekV2Model(nn.Module): input_embeds = self.embed_tokens(input_ids) inputs = torch.cat( - [self.enorm(input_embeds), - self.hnorm(hidden_states)], dim=-1) + [self.enorm(input_embeds), self.hnorm(hidden_states)], dim=-1 + ) hidden_states = self.fc(inputs) residual = None for layer in self.layers: @@ -94,8 +96,7 @@ class DeepseekV2Model(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -110,7 +111,8 @@ 
class DeepseekV2Model(nn.Module): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -135,8 +137,9 @@ class DeepseekV2Model(nn.Module): # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -168,8 +171,7 @@ class DeepseekV2Model(nn.Module): break else: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue # Skip loading extra bias for GPTQ models. @@ -182,33 +184,40 @@ class DeepseekV2Model(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config quant_config = vllm_config.quant_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + vllm_config.parallel_config + ) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -226,21 +235,19 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def transform(inputs): + name, loaded_weight = inputs + if "lm_head" not in name: + name = "model." + name + return name, loaded_weight + loader = AutoWeightsLoader( self, skip_prefixes=None, ) - - model_weights = {} - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." 
+ name - model_weights[name] = loaded_weight - loader.load_weights(model_weights.items()) + loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 0ad001be71c19..bf3ab7bb3079b 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -7,59 +7,77 @@ import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .deepseek_v2 import (DeepseekV2DecoderLayer, - get_spec_layer_idx_from_weight_name) +from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name from .interfaces import SupportsPP from .utils import maybe_prefix class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, + prefix: str, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) class DeepSeekMultiTokenPredictorLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() + + config = vllm_config.speculative_config.draft_model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, - cache_config, quant_config) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda", + ) + else: + topk_indices_buffer = None + + self.shared_head = SharedHead( + config=config, prefix=prefix, quant_config=quant_config + ) + self.mtp_block = DeepseekV2DecoderLayer( + vllm_config, + prefix, + config=self.config, + topk_indices_buffer=topk_indices_buffer, + ) def 
forward( self, @@ -76,41 +94,43 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class DeepSeekMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - DeepSeekMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): DeepSeekMultiTokenPredictorLayer( + vllm_config, f"{prefix}.layers.{idx}" + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -121,7 +141,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -133,26 +153,27 @@ class DeepSeekMultiTokenPredictor(nn.Module): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits +@support_torch_compile class DeepSeekMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = DeepSeekMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -163,21 +184,19 @@ class DeepSeekMTP(nn.Module, SupportsPP): inputs_embeds: 
Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, sampling_metadata, - spec_step_idx) + return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), @@ -189,7 +208,8 @@ class DeepSeekMTP(nn.Module, SupportsPP): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -200,7 +220,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): if spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -210,14 +230,15 @@ class DeepSeekMTP(nn.Module, SupportsPP): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -239,11 +260,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -252,13 +275,16 @@ class DeepSeekMTP(nn.Module, SupportsPP): # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. - if (spec_layer != self.model.mtp_start_layer_idx - and ".layers" not in name): + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -270,7 +296,11 @@ class DeepSeekMTP(nn.Module, SupportsPP): and rename shared layer weights to be top level. 
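# A minimal sketch of how the multi-token-prediction (MTP) weights above are
# indexed and renamed. The layer counts are assumed example values, not read
# from a specific checkpoint.
num_hidden_layers = 61          # assumed size of the target model
num_nextn_predict_layers = 1    # assumed number of MTP modules

mtp_start_layer_idx = num_hidden_layers
layer_keys = [str(mtp_start_layer_idx + i) for i in range(num_nextn_predict_layers)]
# -> ["61"]: checkpoint weights named "model.layers.61.*" belong to the MTP module.

spec_step_idx = 3
current_step_idx = spec_step_idx % num_nextn_predict_layers    # wraps back to 0
selected_key = str(mtp_start_layer_idx + current_step_idx)     # "61"

# _rewrite_spec_layer_name then lifts shared weights to the top level and nests
# transformer-block weights under "mtp_block", e.g.:
#   "model.layers.61.embed_tokens.weight"        -> "model.embed_tokens.weight"
#   "model.layers.61.self_attn.q_b_proj.weight"  -> "model.layers.61.mtp_block.self_attn.q_b_proj.weight"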
""" spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] shared_weight_names = ["embed_tokens"] spec_layer_weight = False @@ -283,8 +313,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) elif shared_weight: # treat shared weights as top level weights name = name.replace(f"model.layers.{spec_layer}.", "model.") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d56224b4b7b30..f8456c5452494 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -23,8 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" + import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -32,37 +34,73 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config from vllm.attention import Attention +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - MergedReplicatedLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from 
vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerBackend, + DeepseekV32IndexerMetadata, +) +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from .interfaces import MixtureOfExperts, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) class DeepseekV2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -70,23 +108,36 @@ class DeepseekV2MLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + is_sequence_parallel=False, prefix: str = "", ) -> None: super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,16 +148,17 @@ class DeepseekV2MLP(nn.Module): class DeepseekV2MoE(nn.Module): - def __init__( self, config: Union[DeepseekV2Config, DeepseekV3Config], + parallel_config: ParallelConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group @@ -115,38 +167,59 @@ class DeepseekV2MoE(nn.Module): self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" - "Only silu is supported for now.") + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) else: self.gate.e_score_correction_bias = None # Load balancing settings. - vllm_config = get_current_vllm_config() - eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) - self.experts = FusedMoE( + if config.n_shared_experts is None: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -159,68 +232,75 @@ class DeepseekV2MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - self.shared_experts = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), - prefix=f"{prefix}.shared_experts", - ) + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + + # Chunk the hidden states so they aren't 
replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - if self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + else: + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math + if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV2Attention(nn.Module): - def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -234,6 +314,7 @@ class DeepseekV2Attention(nn.Module): max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + topk_indices_buffer: Optional[torch.Tensor] = None, prefix: str = "", ) -> None: super().__init__() @@ -251,58 +332,70 @@ class DeepseekV2Attention(nn.Module): self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + assert topk_indices_buffer is None, ( + "topk_indices_buffer is not \ + supported for DeepseekV2Attention" + ) if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_proj = ReplicatedLinear( + 
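# A minimal sketch of how DeepseekV2MoE.forward above combines the routed and
# shared expert outputs. routed_scaling_factor and the tensor shapes are
# assumed example values.
import torch

routed_scaling_factor = 2.5                 # assumed example value
routed_out = torch.randn(8, 16)             # output of the routed (fused) experts
shared_out = torch.randn(8, 16)             # output of the shared experts

# bf16/fp32 path: scale the routed output, then add the shared output.
combined = routed_out * routed_scaling_factor + shared_out

# fp16 path: to avoid overflow the routed output is left unscaled and the
# shared output is divided by the same factor instead, keeping both terms on
# the same (smaller) scale.
combined_fp16_style = routed_out + shared_out * (1.0 / routed_scaling_factor)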
self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") + prefix=f"{prefix}.kv_b_proj", + ) # O projection. - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' + rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) @@ -310,13 +403,15 @@ class DeepseekV2Attention(nn.Module): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -326,53 +421,465 @@ class DeepseekV2Attention(nn.Module): if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = 
latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output +class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): + def __init__( + self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig + ): + super().__init__() + self.kv_cache = [torch.tensor([])] + self.head_dim = head_dim + self.prefix = prefix + self.cache_config = cache_config + self.dtype = dtype + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def get_kv_cache_spec(self) -> KVCacheSpec: + return MLAAttentionSpec( # Only has one vector instead of K + V + block_size=self.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + ) + + def forward(self): ... 
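# A minimal sketch of the per-token layout behind the indexer K cache declared
# above: each token stores its fp8-quantized key plus one fp32 scale per quant
# group, packed into a single uint8 vector. The sizes follow the example values
# noted in the comments of Indexer.__init__ further below.
index_head_dim = 128        # fp8 key elements per token (1 byte each), assumed
quant_block_size = 128      # elements covered by one fp32 scale, assumed
num_scales = index_head_dim // quant_block_size        # -> 1
bytes_per_token = index_head_dim + num_scales * 4      # 128 + 4 = 132
# The cache is therefore declared with head_size=132 and dtype=torch.uint8;
# cp_gather_indexer_k_quant_cache re-views the first 128 bytes of each token as
# float8_e4m3fn values and the trailing 4 bytes as a float32 scale.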
+ + def get_attn_backend(self) -> AttentionBackend: + return DeepseekV32IndexerBackend + + +@torch.inference_mode() +def cp_gather_indexer_k_quant_cache( + kv_cache, # [num_blocks, block_size, head_dim + 1] + dst_value, # [cu_seq_lens[-1], head_dim] + dst_scale, # [cu_seq_lens[-1], 4] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] + batch_size, +): + num_blocks, block_size, _ = kv_cache.shape + head_dim = dst_value.shape[-1] + kv_cache = kv_cache.view(num_blocks, -1) + + expected_value = [] + expected_scale = [] + for b in range(batch_size): + s = cu_seq_lens[b + 1] - cu_seq_lens[b] + if s == 0: + continue + tot = cdiv(s, block_size) + blocks = block_table[b, :tot] + + value = [] + scale = [] + full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) + non_remaining_value = kv_cache[ + blocks[full_block], : block_size * head_dim + ].view(-1, head_dim) + non_remaining_scale = kv_cache[ + blocks[full_block], block_size * head_dim : + ].view(-1, 4) + + remaining = s - (tot - 1) * block_size + + value = torch.cat( + [ + non_remaining_value, + kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim), + ], + dim=0, + ) + scale = torch.cat( + [ + non_remaining_scale, + kv_cache[ + blocks[-1], + block_size * head_dim : block_size * head_dim + remaining * 4, + ].view(-1, 4), + ], + dim=0, + ) + + expected_value.append(value) + expected_scale.append(scale) + + gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) + gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) + gather_value = gather_value.view(torch.float8_e4m3fn) + gather_scale = gather_scale.view(torch.float32) + dst_value.copy_(gather_value) + dst_scale.copy_(gather_scale) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=torch.float8_e4m3fn, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32 + ) + cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + chunk.num_reqs, + ) + logits = fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = torch.empty( + num_rows, topk_tokens, dtype=torch.int32, device=logits.device + ) + topk_values = torch.empty( + num_rows, topk_tokens, dtype=logits.dtype, device=logits.device + ) + torch.ops._C.top_k_per_row( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + topk_values, + num_rows, + logits.stride(0), + logits.stride(1), + ) + topk_indices_buffer[ + chunk.token_start : chunk.token_end, : topk_indices.shape[-1] + ] = topk_indices.to(dtype=torch.int32) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + logits = fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + # padded query len + current_device = padded_q_fp8_decode_tokens.device + padded_num_tokens = batch_size * next_n + row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n + next_n_offset = ( + torch.arange(padded_num_tokens, 
device=padded_q_fp8_decode_tokens.device) + % next_n + ) + index_end_pos = ( + decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1 + ).unsqueeze(1) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = torch.empty( + num_rows, topk_tokens, dtype=torch.int32, device=logits.device + ) + topk_values = torch.empty( + num_rows, topk_tokens, dtype=logits.dtype, device=logits.device + ) + torch.ops._C.top_k_per_row( + logits, + torch.zeros(num_rows, dtype=torch.int32, device=logits.device), + index_end_pos.to(dtype=torch.int32, device=logits.device), + topk_indices, + topk_values, + num_rows, + logits.stride(0), + logits.stride(1), + ) + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices.to(dtype=torch.int32) + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +class Indexer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = config + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 + self.q_lora_rank = q_lora_rank # 1536 + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.wk = ReplicatedLinear( + hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + ) + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.weights_proj = ReplicatedLinear( + hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj" + ) + self.softmax_scale = self.head_dim**-0.5 + + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer + + # NOTE: (zyongye) we use fp8 naive cache, + # where we store value in 
fp8 and scale in fp32 + # per self.quant_block_size element + self.k_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4, + dtype=torch.uint8, + prefix=f"{prefix}.k_cache", + cache_config=cache_config, + ) + self.max_model_len = vllm_config.model_config.max_model_len + self.prefix = prefix + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + + def forward( + self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb + ) -> torch.Tensor: + q, _ = self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + k, _ = self.wk(hidden_states) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) + + # we only quant q here since k quant is fused with cache insertion + q = q.view(-1, self.head_dim) + q_fp8, q_scale = per_token_group_quant_fp8( + q, + self.quant_block_size, + column_major_scales=False, + use_ue8m0=self.scale_fmt is not None, + ) + q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + q_scale = q_scale.view(-1, self.n_head, 1) + + weights, _ = self.weights_proj(hidden_states) + weights = ( + weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 + ) + weights = weights.squeeze(-1) + + return torch.ops.vllm.sparse_attn_indexer( + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + + class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
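# A minimal sketch of the per-token-group fp8 quantization applied to the
# indexer queries in Indexer.forward above. It approximates the idea with plain
# torch ops only; the per_token_group_quant_fp8 kernel used in the diff
# additionally handles column-major scales and the ue8m0 scale format.
import torch

FP8_MAX = 448.0              # largest finite value of float8_e4m3fn
group_size = 128             # matches quant_block_size in the Indexer

def naive_per_token_group_quant(x: torch.Tensor):
    # x: [num_tokens, dim], with dim divisible by group_size
    num_tokens, dim = x.shape
    groups = x.view(num_tokens, dim // group_size, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (groups / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(num_tokens, dim), scale.view(num_tokens, dim // group_size)

q_fp8, q_scale = naive_per_token_group_quant(torch.randn(4, 128))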
- - For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py """ def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -387,6 +894,7 @@ class DeepseekV2MLAAttention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + topk_indices_buffer: Optional[torch.Tensor] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -408,166 +916,168 @@ class DeepseekV2MLAAttention(nn.Module): self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.fused_qkv_a_proj = MergedReplicatedLinear( + self.fused_qkv_a_proj = MergedColumnParallelLinear( self.hidden_size, [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], bias=False, quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj") + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True, + ) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) 
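# A worked example of the YaRN mscale correction computed just above. The
# rope_scaling values and head dimension are assumed example numbers.
import math

scaling_factor = 40.0    # rope_scaling["factor"], assumed example value
mscale_all_dim = 1.0     # rope_scaling.get("mscale_all_dim"), assumed example value

mscale = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0   # ~1.3689
base_scaling = 192 ** -0.5   # qk_head_dim ** -0.5 for an assumed 192-dim head
scaling = base_scaling * mscale * mscale
# The softmax scale actually used by the attention grows by mscale**2 ~ 1.87
# relative to the unscaled 1/sqrt(head_dim) value.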
self.scaling = self.scaling * mscale * mscale - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. - # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, + self.is_v32 = hasattr(config, "index_topk") + + if self.is_v32: + self.indexer = Indexer( + vllm_config, + config, + hidden_size, + q_lora_rank, + quant_config, + cache_config, + topk_indices_buffer, + f"{prefix}.indexer", + ) + else: + self.indexer = None + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None + else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=self.indexer, + is_sparse=self.is_v32, + topk_indices_buffer=topk_indices_buffer, ) - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - q_c = None - kv_lora = None - - if self.q_lora_rank is not None: - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] - else: - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], - dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) - - q = q.view(-1, self.num_local_heads, self.qk_head_dim) - # Add head dim of 1 to k_pe - k_pe = k_pe.unsqueeze(1) - - q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim:], k_pe) - - attn_out = self.mla_attn( - q, - kv_c_normed, - k_pe, - output_shape=(hidden_states.shape[0], - self.num_local_heads * self.v_head_dim)) - return self.o_proj(attn_out)[0] + return self.mla_attn(positions, hidden_states) class DeepseekV2DecoderLayer(nn.Module): - def __init__( self, - config: Union[DeepseekV2Config, DeepseekV3Config], + vllm_config: VllmConfig, prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - enable_eplb: bool = False, + config: Optional[DeepseekV2Config] = None, + topk_indices_buffer: Optional[torch.Tensor] = None, ) -> None: super().__init__() + + if config 
is None: + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx if model_config.use_mla: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( + vllm_config=vllm_config, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, qk_nope_head_dim=config.qk_nope_head_dim, qk_rope_head_dim=config.qk_rope_head_dim, v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, + q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, kv_lora_rank=config.kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, @@ -575,16 +1085,19 @@ class DeepseekV2DecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): self.mlp = DeepseekV2MoE( config=config, + parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -594,10 +1107,10 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -608,11 +1121,10 @@ class DeepseekV2DecoderLayer(nn.Module): ) -> torch.Tensor: # Self Attention if residual is None: - residual = hidden_states + residual = hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -622,74 +1134,76 @@ class DeepseekV2DecoderLayer(nn.Module): # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. - residual *= 1. 
/ self.routed_scaling_factor + residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) - if isinstance(self.mlp, - DeepseekV2MLP) and hidden_states.dtype == torch.float16: + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: # Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor return hidden_states, residual @support_torch_compile class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda", + ) + else: + topk_indices_buffer = None if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: DeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - enable_eplb=enable_eplb, + vllm_config, prefix, topk_indices_buffer=topk_indices_buffer ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -712,20 +1226,19 @@ class DeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): packed_modules_mapping = { 
"gate_up_proj": ["gate_proj", "up_proj"], } @@ -741,33 +1254,38 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) if self.fuse_qkv_a_proj: self.packed_modules_mapping["fused_qkv_a_proj"] = [ "q_a_proj", "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -813,8 +1331,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, DeepseekV2MoE): moe = layer.mlp @@ -833,21 +1350,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -858,12 +1373,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + 
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -875,7 +1391,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): if spec_layer is not None: continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -885,15 +1401,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -930,14 +1447,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. 
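# (Editor's illustrative sketch, not part of the diff.) The comment above explains why the
# expert weight loader is asked to report success: with EPLB, one checkpoint expert id can
# map to several physical replicas, and only the candidate whose replica lives on this rank
# will load. A loader without a return value would let the caller break out of the candidate
# loop too early and skip replicas that are in fact local. The toy names below
# (`toy_weight_loader`, `local_expert_ids`, `load_one_expert_weight`) are hypothetical and
# exist only in this sketch of the `return_success=True` contract.
from typing import Optional


def toy_weight_loader(local_expert_ids: set[int], expert_id: int) -> bool:
    # Succeed only when this rank actually owns a replica of the expert.
    return expert_id in local_expert_ids


def load_one_expert_weight(
    candidates: list[int], local_expert_ids: set[int]
) -> Optional[int]:
    # Try every candidate expert id; stop at the first replica owned locally.
    for expert_id in candidates:
        if toy_weight_loader(local_expert_ids, expert_id):
            return expert_id
    return None


# A rank owning physical experts {4, 5}: checkpoint expert 1 maps to replicas [1, 5];
# only the second candidate succeeds, so the loop must not give up after the first miss.
assert load_one_expert_weight([1, 5], {4, 5}) == 5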
- weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -961,8 +1481,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -975,13 +1496,15 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, - DeepseekV3Config], - weight_name: str) -> Optional[int]: - if (hasattr(config, "num_nextn_predict_layers") - and config.num_nextn_predict_layers > 0): +def get_spec_layer_idx_from_weight_name( + config: Union[DeepseekV2Config, DeepseekV3Config], weight_name: str +) -> Optional[int]: + if ( + hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index ceb5e1364b68d..8226e88c47a2c 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -3,6 +3,7 @@ # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal, Optional, Union @@ -14,35 +15,50 @@ from einops import rearrange, repeat from transformers import BatchFeature from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + 
BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, - MlpProjectorConfig, - VisionEncoderConfig) -from vllm.transformers_utils.processors.deepseek_vl2 import ( - DeepseekVLV2Processor) +from vllm.transformers_utils.configs.deepseek_vl2 import ( + DeepseekVLV2Config, + MlpProjectorConfig, + VisionEncoderConfig, +) +from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) # The image token id may be various _IMAGE_TOKEN = "<image>" @@ -51,15 +67,15 @@ _IMAGE_TOKEN = "<image>" class DeepseekVL2ImagePixelInputs(TensorSchema): """ Dimensions: - - bn: Batch size * number of images + - bnp: Batch size * number of images * number of patches - p: Number of patches - c: Number of channels (3) - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})] + data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})] images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)] @@ -70,51 +86,53 @@ class DeepseekVL2VImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "f", "h")] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("bn", "f", "h") + ] -DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, - DeepseekVL2VImageEmbeddingInputs] +DeepseekVL2ImageInputs = Union[ + DeepseekVL2ImagePixelInputs, DeepseekVL2VImageEmbeddingInputs +] class MlpProjector(nn.Module): - def __init__(self, cfg: MlpProjectorConfig): - super().__init__() self.cfg = cfg - assert not cfg.token_pooling, ( - "Token pooling is not supported currently.") + assert not cfg.token_pooling, "Token pooling is not supported currently." 
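# (Editor's illustrative sketch, not part of the diff.) The pixel-value schema above changes
# from a per-image list of patch tensors of shape [p, 3, h, w] to one tensor flattened over
# batch * images * patches ("bnp", 3, h, w); downstream code recovers the per-image grouping
# from a patch-count vector (in this file, `images_spatial_crop.prod(-1) + 1`, i.e. global
# view plus local tiles). The names `per_image_patches` and `num_patches` below are
# hypothetical and only demonstrate the round trip under that assumption.
import torch

per_image_patches = [torch.randn(3, 3, 8, 8), torch.randn(5, 3, 8, 8)]
num_patches = torch.tensor([p.shape[0] for p in per_image_patches])

flat = torch.cat(per_image_patches, dim=0)  # flattened ("bnp", 3, h, w) layout
assert flat.shape == (8, 3, 8, 8)

recovered = flat.split(num_patches.tolist())  # back to per-image patch groups
assert [r.shape[0] for r in recovered] == [3, 5]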
if cfg.projector_type == "downsample_mlp_gelu": mlp_depth = cfg.depth mlp_ratio = cfg.mlp_ratio modules = [ nn.Linear( - cfg.input_dim * cfg.downsample_ratio * - cfg.downsample_ratio, cfg.n_embed * mlp_ratio) + cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, + cfg.n_embed * mlp_ratio, + ) ] for _ in range(1, mlp_depth - 1): modules.append(nn.GELU()) modules.append( - nn.Linear(cfg.n_embed * mlp_ratio, - cfg.n_embed * mlp_ratio)) + nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio) + ) modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) modules = nn.Sequential(*modules) else: raise NotImplementedError( - f"Unsupported projector type: {cfg.projector_type}") + f"Unsupported projector type: {cfg.projector_type}" + ) self.layers = modules def forward(self, x): bs, hw, input_dim = x.shape - h = w = int((hw)**0.5) + h = w = int((hw) ** 0.5) """compute padding""" if h % self.cfg.downsample_ratio: pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio @@ -125,17 +143,18 @@ class MlpProjector(nn.Module): x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) """4 to 1 concat""" x = x.permute(0, 3, 1, 2) # B, C, H, W - x = F.unfold(x, - kernel_size=self.cfg.downsample_ratio, - stride=self.cfg.downsample_ratio, - padding=0) # B, C*4, HW // 4 + x = F.unfold( + x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0, + ) # B, C*4, HW // 4 x = x.permute(0, 2, 1) return self.layers(x) class DeepseekVL2ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(DeepseekVLV2Config) @@ -145,11 +164,9 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_num_image_tokens(self, - *, - image_width: int, - image_height: int, - cropping: bool = True) -> int: + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: hf_processor = self.get_hf_processor() image_size = hf_processor.image_size patch_size = hf_processor.patch_size @@ -157,9 +174,12 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): if cropping: best_width, best_height = hf_processor.select_best_resolution( - (image_width, image_height)) - num_width_tiles, num_height_tiles = (best_width // image_size, - best_height // image_size) + (image_width, image_height) + ) + num_width_tiles, num_height_tiles = ( + best_width // image_size, + best_height // image_size, + ) else: num_width_tiles = num_height_tiles = 1 @@ -172,15 +192,16 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() candidate_resolutions = hf_config.candidate_resolutions - height, width = max(candidate_resolutions, - key=lambda x: self.get_num_image_tokens( - image_width=x[1], image_height=x[0])) + height, width = max( + candidate_resolutions, + key=lambda x: self.get_num_image_tokens( + image_width=x[1], image_height=x[0] + ), + ) return ImageSize(width=width, height=height) -class DeepseekVL2DummyInputsBuilder( - BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): - +class DeepseekVL2DummyInputsBuilder(BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -193,22 +214,27 @@ class DeepseekVL2DummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, 
BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) max_image_size = self.info.get_image_size_with_most_features() + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=max_image_size.width, - height=max_image_size.height, - num_images=num_images) + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } class DeepseekVL2MultiModalProcessor( - BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): - + BaseMultiModalProcessor[DeepseekVL2ProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -218,9 +244,7 @@ class DeepseekVL2MultiModalProcessor( ) -> BatchFeature: if not mm_data: tokenizer = self.info.get_tokenizer() - return tokenizer(prompt, - add_special_tokens=True, - return_tensors="pt") + return tokenizer(prompt, add_special_tokens=True, return_tensors="pt") processed_outputs = super()._call_hf_processor( prompt=prompt, @@ -229,12 +253,9 @@ class DeepseekVL2MultiModalProcessor( tok_kwargs=tok_kwargs, ) - pixel_values = processed_outputs["pixel_values"] - # split pixel values into patches corresponding to each image - images_spatial_crop = processed_outputs["images_spatial_crop"] - patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop] - pixel_values = pixel_values.split(patches_per_image) - processed_outputs["pixel_values"] = pixel_values + processed_outputs["num_patches"] = ( + processed_outputs["images_spatial_crop"].prod(-1) + 1 + ) return processed_outputs @@ -243,8 +264,10 @@ class DeepseekVL2MultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + return dict( - pixel_values=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), images_spatial_crop=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -262,7 +285,8 @@ class DeepseekVL2MultiModalProcessor( def get_replacement_deepseek_vl2(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -290,6 +314,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is @@ -301,6 +326,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -308,18 +334,23 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) @MULTIMODAL_REGISTRY.register_processor( DeepseekVL2MultiModalProcessor, info=DeepseekVL2ProcessingInfo, - dummy_inputs=DeepseekVL2DummyInputsBuilder) + dummy_inputs=DeepseekVL2DummyInputsBuilder, +) class DeepseekVLV2ForCausalLM(nn.Module, 
SupportsMultiModal, SupportsPP): + merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "language.": "language_model.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language.": "language_model.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -343,11 +374,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): model_config = vllm_config.model_config tokenizer = cached_tokenizer_from_config(model_config) - self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] - self.vision = self._init_vision_module(self.vision_config, - quant_config, - maybe_prefix(prefix, "vision")) + self.vision = self._init_vision_module( + self.vision_config, quant_config, maybe_prefix(prefix, "vision") + ) self.projector = MlpProjector(self.projector_config) self.tile_tag = config.tile_tag @@ -355,14 +386,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # special token for image token sequence format embed_std = 1 / torch.sqrt( - torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) + torch.tensor(self.projector_config.n_embed, dtype=torch.float32) + ) if self.tile_tag == "2D": # <|view_seperator|>, <|\n|> self.image_newline = nn.Parameter( - torch.randn(self.projector_config.n_embed) * embed_std) + torch.randn(self.projector_config.n_embed) * embed_std + ) # This is a typo in original implementation self.view_seperator = nn.Parameter( - torch.randn(self.projector_config.n_embed) * embed_std) + torch.randn(self.projector_config.n_embed) * embed_std + ) else: raise ValueError( f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" @@ -383,19 +417,19 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str): """Return (parent_module, final_attr_name) for a dotted module path.""" - names = dotted_name.split('.') + names = dotted_name.split(".") parent = root for n in names[:-1]: parent = getattr(parent, n) return parent, names[-1] - #patch for timm ViT instance to support tensor parallel - def patch_vit_for_tp(self, vit: torch.nn.Module, - quant_config: QuantizationConfig): + # patch for timm ViT instance to support tensor parallel + def patch_vit_for_tp(self, vit: torch.nn.Module, quant_config: QuantizationConfig): try: import timm except ImportError as e: @@ -405,13 +439,14 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if isinstance(module, nn.Linear): parent, attr_name = self._get_parent_and_attr(vit, name) if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": - new_linear = replace_linear_class(module, "colwise", - quant_config) + new_linear = replace_linear_class( + module, "colwise", quant_config, prefix=name + ) setattr(parent, attr_name, new_linear) - elif isinstance(parent, - timm.layers.Mlp) and attr_name == "fc2": - new_linear = replace_linear_class(module, "rowwise", - quant_config) + elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": + new_linear = replace_linear_class( + module, "rowwise", quant_config, prefix=name + ) setattr(parent, attr_name, new_linear) return vit @@ -444,7 +479,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return model def 
_parse_and_validate_image_input( - self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: + self, **kwargs: object + ) -> Optional[DeepseekVL2ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) images_spatial_crop = kwargs.pop("images_spatial_crop", None) image_embeds = kwargs.pop("image_embeds", None) @@ -454,37 +490,31 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if pixel_values is not None: expected_h = expected_w = self.vision_config.image_size - return DeepseekVL2ImagePixelInputs(type="pixel_values", - data=flatten_bn(pixel_values), - images_spatial_crop=flatten_bn( - images_spatial_crop, - concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w, - }) + return DeepseekVL2ImagePixelInputs( + type="pixel_values", + data=pixel_values, + images_spatial_crop=images_spatial_crop, + resolve_bindings={ + "h": expected_h, + "w": expected_w, + }, + ) if image_embeds is not None: return DeepseekVL2VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") def _pixel_values_to_embedding( self, - pixel_values: NestedTensors, + pixel_values: torch.Tensor, images_spatial_crop: torch.Tensor, - ) -> NestedTensors: - # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width] - total_tiles = [x for x in pixel_values] - - # [batch_all_tiles, 3, height, width] - total_tiles = torch.cat(total_tiles, dim=0) - + ) -> list[torch.Tensor]: # [batch_all_tiles, vit_seq_len, c] - images_feature = self.vision.forward_features(total_tiles) + images_feature = self.vision.forward_features(pixel_values) # [batch_all_tiles, hw, D] images_embeds = self.projector(images_feature) @@ -506,8 +536,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): global_features = images_embeds[tile_index] # [num_height_tiles * num_width_tiles, hw, D] - local_features = images_embeds[tile_index + 1:tile_index + 1 + - num_tiles_in_image] + local_features = images_embeds[ + tile_index + 1 : tile_index + 1 + num_tiles_in_image + ] tile_index += num_tiles_in_image + 1 # format global and local features @@ -519,8 +550,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] - global_features = torch.cat([global_features, new_lines_in_global], - dim=1) + global_features = torch.cat([global_features, new_lines_in_global], dim=1) # [h, w + 1, D] -> [h * (w + 1), D] global_features = global_features.view(-1, n_dim) @@ -528,22 +558,22 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # ----------------- local view add newline ----------------- # [num_height_tiles * num_width_tiles, h * w, D] -> # [num_height_tiles * h, num_width_tiles * w, D] - local_features = rearrange(local_features, - "(th tw) (h w) d -> (th h) (tw w) d", - th=num_height_tiles, - tw=num_width_tiles, - h=h, - w=w) + local_features = rearrange( + local_features, + "(th tw) (h w) d -> (th h) (tw w) d", + th=num_height_tiles, + tw=num_width_tiles, + h=h, + w=w, + ) # [D] -> [num_height_tiles * h, 1, D] - new_lines_in_local = repeat(self.image_newline, - "d -> (th h) 1 d", - th=num_height_tiles, - h=h) + new_lines_in_local = repeat( + self.image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h + ) # [num_height_tiles * h, num_width_tiles * w + 1, D] - local_features = torch.cat([local_features, new_lines_in_local], 
- dim=1) + local_features = torch.cat([local_features, new_lines_in_local], dim=1) # [num_height_tiles * h, num_width_tiles * w + 1, D] # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] @@ -551,23 +581,28 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # merge global and local tiles if self.global_view_pos == "head": - global_local_features = torch.cat([ - global_features, - self.view_seperator[None, :], - local_features, - ]) + global_local_features = torch.cat( + [ + global_features, + self.view_seperator[None, :], + local_features, + ] + ) else: - global_local_features = torch.cat([ - local_features, - self.view_seperator[None, :], - global_features, - ]) + global_local_features = torch.cat( + [ + local_features, + self.view_seperator[None, :], + global_features, + ] + ) vision_embeddings.append(global_local_features) return vision_embeddings def _process_image_input( - self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor: + self, image_input: DeepseekVL2ImageInputs + ) -> list[torch.Tensor]: if image_input["type"] == "image_embeds": image_data = image_input["data"] if is_list_of(image_data, torch.Tensor): @@ -585,69 +620,43 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): images_spatial_crop = image_input["images_spatial_crop"] return self._pixel_values_to_embedding( - pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) + pixel_values=pixel_values, images_spatial_crop=images_spatial_crop + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( + def forward( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object): - + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - 
sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + return self.language_model.compute_logits(hidden_states) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return autoloaded_weights diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 5f410c0ae5fb0..55f8d4b231f78 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -24,7 +24,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only dots1 model.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -34,34 +36,45 @@ from transformers import Dots1Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Dots1MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -73,19 +86,24 @@ class Dots1MLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = 
RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -96,7 +114,6 @@ class Dots1MLP(nn.Module): class Dots1MoE(nn.Module): - def __init__( self, config: Dots1Config, @@ -109,21 +126,40 @@ class Dots1MoE(nn.Module): self.n_shared_experts = config.n_shared_experts if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = (nn.Parameter( - torch.empty(config.n_routed_experts))) + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts) + ) else: self.gate.e_score_correction_bias = None - self.experts = FusedMoE( + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -136,39 +172,30 @@ class Dots1MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) - - if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - self.shared_experts = Dots1MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_experts", - ) + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + if self.tp_size 
> 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class Dots1Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -198,8 +225,7 @@ class Dots1Attention(nn.Module): # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = getattr(config, "head_dim", - hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", hidden_size // self.total_num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -242,14 +268,15 @@ class Dots1Attention(nn.Module): self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - def forward(self, positions: torch.Tensor, - hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = self.q_norm(q.reshape(-1, self.num_heads, - self.head_dim)).reshape(q.shape) - k = self.k_norm(k.reshape(-1, self.num_kv_heads, - self.head_dim)).reshape(k.shape) + q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape( + k.shape + ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -257,7 +284,6 @@ class Dots1Attention(nn.Module): class Dots1DecoderLayer(nn.Module): - def __init__( self, config: Dots1Config, @@ -270,9 +296,8 @@ class Dots1DecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - layer_idx = int(prefix.split(sep='.')[-1]) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx self.self_attn = Dots1Attention( @@ -287,12 +312,14 @@ class Dots1DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = Dots1MoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = Dots1MoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Dots1MLP( hidden_size=config.hidden_size, @@ -301,10 +328,10 @@ class Dots1DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -317,19 +344,15 @@ 
class Dots1DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Dots1Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -348,7 +371,8 @@ class Dots1Model(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() @@ -361,15 +385,16 @@ class Dots1Model(nn.Module): cache_config=cache_config, quant_config=quant_config, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -391,29 +416,28 @@ class Dots1Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -428,10 +452,10 @@ class Dots1Model(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." 
in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: @@ -454,11 +478,13 @@ class Dots1Model(nn.Module): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -469,15 +495,15 @@ class Dots1Model(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -496,17 +522,22 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Dots1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Dots1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -529,14 +560,11 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py new file mode 100644 index 0000000000000..d1a9f4cb3b2e7 --- /dev/null +++ b/vllm/model_executor/models/dots_ocr.py @@ -0,0 +1,879 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping +from typing import Annotated, Literal, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from transformers.models.qwen2_vl import Qwen2VLProcessor + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import utils as dist_utils +from vllm.distributed.parallel_state import ( + 
get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLDummyInputsBuilder, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .vision import run_dp_sharded_mrope_vision_model + +IMAGE_TOKEN = "<|imgpad|>" + + +class DotsOCRImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + """ + + type: Literal["pixel_values"] + + pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +class DotsOCRImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + """ + + type: Literal["image_embeds"] + + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, DotsOCRImageEmbeddingInputs] + + +class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return IMAGE_TOKEN * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501 + ) + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): + def get_hf_config(self) -> DotsOCRConfig: + config = self.ctx.get_hf_config() + if not config.__class__.__name__ == "DotsOCRConfig": + raise TypeError(f"Expected DotsOCRConfig, got {type(config)}") + + if hasattr(config, "vision_config") and isinstance(config.vision_config, dict): + config.vision_config = DotsVisionConfig(**config.vision_config) + + return config + + def 
get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + return {"image": max_image_tokens} + + def get_hf_processor( + self, + **kwargs: object, + ) -> Qwen2VLProcessor: + self.get_tokenizer().image_token = IMAGE_TOKEN # Ensure image token is set + processor = self.ctx.get_hf_processor( + Qwen2VLProcessor, + **kwargs, + ) + processor.image_token = IMAGE_TOKEN + processor.video_token = "<|video_pad|>" + return processor + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + + cos = freqs.cos() + sin = freqs.sin() + + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + + output = (tensor * cos) + (rotate_half(tensor) * sin) + + output = output.to(orig_dtype) + + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + pre_norm="layernorm", + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.pre_norm = pre_norm + if self.pre_norm == "layernorm": + self.ln_q = LayerNorm(context_dim, eps=1e-6) + elif self.pre_norm == "rmsnorm": + self.ln_q = RMSNorm(context_dim, eps=1e-6) + + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + return_bias=False, + prefix=f"{prefix}.0", + disable_tp=use_data_parallel, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + dim, + bias=True, + return_bias=False, + prefix=f"{prefix}.2", + disable_tp=use_data_parallel, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + else: + x = self.mlp(x.view(-1, self.hidden_size)) + return x + + +class DotsVisionAttention(nn.Module): + def __init__( + self, + config, + dim: int, + num_heads: int = 16, + bias: bool = True, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + + self.embed_dim = dim + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size + ) + # qkv/proj follow Qwen2-VL style; bias controlled by arg + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.hidden_size_per_attention_head, + 
total_num_heads=num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) + # Select attention backend + self.attn_backend = get_vit_attn_backend( + self.hidden_size_per_attention_head, torch.get_default_dtype() + ) + self.use_upstream_fa = False + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) + ) + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Unsupported vision attention backend: {self.attn_backend}" + ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + *, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None, + ) -> torch.Tensor: + # [S, C] -> [S, B=1, C] + x = hidden_states.unsqueeze(1) + x, _ = self.qkv(x) + q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x) + bs = q.shape[1] + # [S,B,H,D] -> [B,S,H,D] + q = q.permute(1, 0, 2, 3).contiguous() + k = k.permute(1, 0, 2, 3).contiguous() + v = v.permute(1, 0, 2, 3).contiguous() + + if rotary_pos_emb is not None: + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + + if self.is_flash_attn_backend: + q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) + k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) + v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) + output = self.flash_attn_varlen_func( + q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + context_layer = output.view( + bs, + -1, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + elif self.attn_backend == _Backend.TORCH_SDPA: + outputs = [] + for i in range(1, len(cu_seqlens)): + s = int(cu_seqlens[i - 1]) + e = int(cu_seqlens[i]) + q_i = q[:, s:e].permute(0, 2, 1, 3) + k_i = k[:, s:e].permute(0, 2, 1, 3) + v_i = v[:, s:e].permute(0, 2, 1, 3) + out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) + out_i = out_i.permute(0, 2, 1, 3) + outputs.append(out_i) + context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + else: + raise RuntimeError("Unsupported attention backend") + + # [B,S,H,D] -> [S,B,H*D] -> [S, C] + context_layer = context_layer.permute(1, 0, 2, 3).contiguous() + context_layer = context_layer.view(context_layer.shape[0], bs, -1) + out, _ = self.proj(context_layer) + return out.squeeze(1) + + +class DotsSwiGLUFFN(nn.Module): + def __init__( + self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + 
hidden_features = config.intermediate_size + in_features = config.embed_dim + bias = config.use_bias + + # Referenced aimv2.py AIMv2SwiGLUFFN + self.fc13 = MergedColumnParallelLinear( + in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("fc13", "fc1", 0), + ("fc13", "fc3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DotsPatchEmbed(nn.Module): + def __init__(self, config): + super().__init__() + self.num_channels = config.num_channels + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.embed_dim = config.embed_dim + self.config = config + self.proj = nn.Conv2d( + config.num_channels, + config.embed_dim, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + x = x.view( + -1, + self.num_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + )[:, :, 0] + x = self.proj(x).view(-1, self.embed_dim) + x = self.norm(x) + return x + + +class DotsViTPreprocessor(nn.Module): + def __init__(self, config): + super().__init__() + self.patch_h = config.patch_size + self.patch_w = config.patch_size + self.embed_dim = config.embed_dim + self.config = config + self.patchifier = DotsPatchEmbed(config) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + tokens = self.patchifier(x, grid_thw) + return tokens + + +class DotsVisionBlock(nn.Module): + def __init__( + self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + + self.attn = DotsVisionAttention( + config, + config.embed_dim, + num_heads=config.num_attention_heads, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) + self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + self.mlp = DotsSwiGLUFFN( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward( + 
self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class DotsVisionTransformer(nn.Module): + def __init__( + self, + config: DotsVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = DotsViTPreprocessor(config) + + head_dim = config.embed_dim // config.num_attention_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN + self.out_hidden_size = config.hidden_size + # Keep blocks for compatibility with other vision towers + num_layers = ( + config.num_hidden_layers + if num_hidden_layers_override is None + else num_hidden_layers_override + ) + self.blocks = nn.ModuleList( + [ + DotsVisionBlock( + config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + use_data_parallel=use_data_parallel, + ) + for i in range(num_layers) + ] + ) + if require_post_norm is None: + require_post_norm = len(self.blocks) == config.num_hidden_layers + if require_post_norm and self.config.post_norm: + self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + else: + self.post_trunk_norm = None + + self.merger = PatchMerger( + dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + use_data_parallel=use_data_parallel, + ) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.patchifier.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.patchifier.proj.weight.device + + def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return pos_ids + + def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: + pos_ids = self.get_pos_ids_by_grid(grid_thw) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def 
compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: list[list[int]] + ) -> torch.Tensor: + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # Convert grid_thw to tensor (always expecting list format now) + grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long) + hidden_states = hidden_states.to(self.dtype) + hidden_states = self.patch_embed(hidden_states, grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + if self.post_trunk_norm is not None: + hidden_states = self.post_trunk_norm(hidden_states) + + hidden_states = self.merger(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=DotsOCRProcessingInfo, + dummy_inputs=DotsOCRDummyInputsBuilder, +) +class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".attn.qkv_proj.": ".attn.qkv.", + ".attn.out_proj.": ".attn.proj.", + }, + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }, + ) + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + ".attn.qkv": [".attn.qkv"], + "fc13": ["fc1", "fc3"], + } + supports_encoder_tp_data = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|img|><|imgpad|><|endofimg|>" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config: DotsOCRConfig = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + if isinstance(self.config.vision_config, dict): + vision_config = DotsVisionConfig(**self.config.vision_config) + self.config.vision_config = vision_config + else: + vision_config = self.config.vision_config + self.vision_tower = DotsVisionTransformer( + vision_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + use_data_parallel=self.use_data_parallel, + ) + self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[DotsOCRImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = 
kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + return DotsOCRImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return DotsOCRImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _process_image_input( + self, image_input: DotsOCRImageInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.vision_tower, + pixel_values, + grid_thw_list, + rope_type="rope_3d", + ) + else: + image_embeds = self.vision_tower(pixel_values, grid_thw_list)[ + :, : self.config.hidden_size + ] + + # Split concatenated embeddings for each image item. + merge_size = self.vision_tower.spatial_merge_size + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() + + return image_embeds.split(sizes) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="vision_tower.merger", + tower_model="vision_tower.", + ) diff --git a/vllm/model_executor/models/ernie45.py b/vllm/model_executor/models/ernie45.py index e7302dc5ecdd7..b1d26cddcc5eb 100644 --- a/vllm/model_executor/models/ernie45.py +++ b/vllm/model_executor/models/ernie45.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Erine model compatible with HuggingFace weights.""" + from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM @@ -29,7 +30,6 @@ from .utils import PPMissingLayer class Ernie4_5ForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # Hack Llama model to fit HF format Ernie4.5 dense implementation diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 4780ea931ea50..7516cb5abaf9a 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -22,7 +22,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only ErineMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -35,33 +37,42 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Ernie4_5_MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -74,19 +85,24 @@ class Ernie4_5_MoeMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=use_bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=use_bias, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + 
prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,7 +113,6 @@ class Ernie4_5_MoeMLP(nn.Module): class Ernie4_5_MoeMoE(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -109,24 +124,44 @@ class Ernie4_5_MoeMoE(nn.Module): layer_idx = extract_layer_index(prefix) self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() - self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) - > 0) + self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0 if self.tp_size > config.moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.moe_num_experts}.") + f"the number of experts {config.moe_num_experts}." + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.moe_num_experts)) + torch.empty(config.moe_num_experts, dtype=torch.float32) + ) - self.experts = FusedMoE( + if self.has_shared_experts: + intermediate_size = ( + config.moe_intermediate_size * config.moe_num_shared_experts + ) + self.shared_experts = Ernie4_5_MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=False, + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.moe_num_experts, top_k=config.moe_k, hidden_size=config.hidden_size, @@ -135,47 +170,32 @@ class Ernie4_5_MoeMoE(nn.Module): renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", - e_score_correction_bias=self.gate.e_score_correction_bias) - - if self.has_shared_experts: - intermediate_size = (config.moe_intermediate_size * - config.moe_num_shared_experts) - self.shared_experts = Ernie4_5_MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.shared_experts", - reduce_results=self.experts.must_reduce_shared_expert_outputs( - )) + e_score_correction_bias=self.gate.e_score_correction_bias, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None + + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if self.has_shared_experts: - shared_output = self.shared_experts(hidden_states) - - router_logits, _ = self.gate(hidden_states) - - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - - if self.has_shared_experts and \ - shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: - 
final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(orig_shape) class Ernie4_5_MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -218,19 +238,23 @@ class Ernie4_5_MoeAttention(nn.Module): self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, @@ -240,20 +264,21 @@ class Ernie4_5_MoeAttention(nn.Module): is_neox_style=False, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -267,7 +292,6 @@ class Ernie4_5_MoeAttention(nn.Module): class Ernie4_5_MoeDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -279,18 +303,17 @@ class Ernie4_5_MoeDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 500000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 131072) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) self.self_attn = Ernie4_5_MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, - head_dim=getattr(config, 'head_dim', None), + head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'use_bias', False), + qkv_bias=getattr(config, "use_bias", False), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -302,30 +325,35 @@ class Ernie4_5_MoeDecoderLayer(nn.Module): # MoE moe_num_experts = getattr(config, "moe_num_experts", 0) moe_layer_start_index = getattr(config, "moe_layer_start_index", 0) - moe_layer_end_index = getattr(config, "moe_layer_end_index", - config.num_hidden_layers - 1) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", config.num_hidden_layers - 1 + ) moe_layer_interval = getattr(config, "moe_layer_interval", 1) use_moe = getattr(config, "use_moe", moe_num_experts > 0) - if (use_moe and ((layer_idx 
+ 1) % moe_layer_interval == 0) - and layer_idx >= moe_layer_start_index - and layer_idx <= moe_layer_end_index): - self.mlp = Ernie4_5_MoeMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + use_moe + and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index + ): + self.mlp = Ernie4_5_MoeMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Ernie4_5_MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - use_bias=getattr(config, 'use_bias', False), + use_bias=getattr(config, "use_bias", False), quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -333,14 +361,12 @@ class Ernie4_5_MoeDecoderLayer(nn.Module): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: - # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, @@ -348,8 +374,7 @@ class Ernie4_5_MoeDecoderLayer(nn.Module): ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -358,7 +383,6 @@ class Ernie4_5_MoeDecoderLayer(nn.Module): @support_torch_compile class Ernie4_5_MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -375,16 +399,19 @@ class Ernie4_5_MoeModel(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Ernie4_5_MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Ernie4_5_MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) @@ -393,9 +420,9 @@ class Ernie4_5_MoeModel(nn.Module): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -407,7 +434,6 @@ class Ernie4_5_MoeModel(nn.Module): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -419,32 +445,29 @@ 
class Ernie4_5_MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.moe_num_experts) + num_experts=self.config.moe_num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -458,8 +481,7 @@ class Ernie4_5_MoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.config.tie_word_embeddings and name.endswith( - "lm_head.weight"): + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): continue # MTP will be supported soon. if "mtp" in name: @@ -469,17 +491,18 @@ class Ernie4_5_MoeModel(nn.Module): name = name.replace("moe_statics", "gate") loaded_weight = loaded_weight.squeeze(0) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -502,22 +525,26 @@ class Ernie4_5_MoeModel(nn.Module): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -528,8 +555,9 @@ class Ernie4_5_MoeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -556,13 +584,17 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Ernie4_5_MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Ernie4_5_MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() @@ -570,7 +602,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -582,25 +615,22 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 0000000000000..d5b2caa2ddfd6 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1717 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" + +import itertools +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Callable, Literal, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature, PretrainedConfig + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +# === Vision Transformer === # + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class Ernie4_5_VisionAttention(nn.Module): + """VisionAttention using VLLM framework APIs""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size + ) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + # Detect attention implementation. + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype(), + ) + + self.use_upstream_fa = False + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) + ) + + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now." 
+ ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) + if rotary_pos_emb is not None: + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + + if self.is_flash_attn_backend: + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
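+            # Descriptive note (added comment, not part of the upstream change):
+            # cu_seqlens holds cumulative patch counts per image / video frame, so
+            # the loop below runs SDPA on one visual item's slice at a time and
+            # never attends across item boundaries.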
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Ernie4_5_VisionMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.act = act_layer() + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class Ernie4_5_VisionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Ernie4_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + self.mlp = Ernie4_5_VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Ernie4_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", + ) -> None: + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = nn.Linear( + in_channels * patch_size * patch_size, embed_dim, bias=False + ) + + def forward(self, hidden_states: 
torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.to(target_dtype) + hidden_states = self.proj(hidden_states) + + return hidden_states + + +class Ernie4_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / theta ** ( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim + ) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +class Ernie4_5_VisionTransformer(nn.Module): + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Ernie4_5_VisionPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + prefix=f"{prefix}.patch_embed", + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Ernie4_5_VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(depth) + ] + ) + + assert hidden_size == embed_dim, ( + "vit's config.hidden must be equal to config.embed_dim" + ) + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if ( + 
self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 + ) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + + zeros = cu_seqlens.new_zeros(1) + if num_pad > 0: + cu_seqlens = torch.cat([zeros, cu_seqlens, zeros]) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = torch.cat([zeros, cu_seqlens]) + + # add batch size + if hidden_states.ndim == 2: + hidden_states = hidden_states.unsqueeze(dim=1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + for i, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + final_output = self.ln(hidden_states) + + if final_output.ndim == 3: + final_output = final_output.squeeze(dim=1) + + return final_output + + def load_weights(self, weights) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# === Vision Inputs === # + + +class Ernie4_5_VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + """ + + type: Literal["pixel_values"] + + pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + + +class Ernie4_5_VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * temporal_patch_size * patch_size * + patch_size + """ + + type: Literal["pixel_values_videos"] + pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs + +# === Vision Processor === # + + +def round_by_factor(number: Union[int, float], factor: int) -> int: + return round(number / factor) * factor + + +def ceil_by_factor(number: Union[int, float], factor: int) -> int: + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: Union[int, float], factor: int) -> int: + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = 
floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +class VariableResolutionResamplerModel(nn.Module): + def __init__( + self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size + # compress 3d conv(video) to 1d + self.temporal_dim = ( + self.in_dim + * self.spatial_conv_size + * self.spatial_conv_size + * self.temporal_conv_size + ) + + self.spatial_linear1 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, "quant_config", None), + prefix=f"{prefix}.spatial_linear1", + ) + + self.spatial_gelu = nn.GELU() + + self.spatial_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, "quant_config", None), + prefix=f"{prefix}.spatial_linear2", + ) + + self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + if self.use_temporal_conv: + self.temporal_linear1 = ColumnParallelLinear( + self.temporal_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, "quant_config", None), + prefix=f"{prefix}.temporal_linear1", + ) + + self.temporal_gelu = nn.GELU() + + self.temporal_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, "quant_config", None), + prefix=f"{prefix}.temporal_linear2", + ) + + self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + self.mlp = ColumnParallelLinear( + self.spatial_dim, + self.out_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, "quant_config", None), + prefix=f"{prefix}.mlp", + ) + + self.after_norm = RMSNorm( + hidden_size=out_dim, eps=getattr(config, "rms_norm_eps", 1e-6) + ) + + def spatial_conv_reshape(self, x, spatial_conv_size): + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size**2)]) + return x + + def forward(self, x, grid_thw): + def fwd_spatial(x): + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x, _ = self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // 
(self.spatial_conv_size**2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size**2) + batch_offset = np.empty( + tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype + ) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + slice_offsets = [] + for temporal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range(0, temporal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( + x.device + ) + + slice_offsets2 = [] + for temporal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range( + 1 if temporal_size > 1 else 0, temporal_size, 2 + ): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( + x.device + ) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts) + return {"image": max_image_tokens, "video": max_video_tokens} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Any], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + 
min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Any], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Any], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. 
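+        # NOTE: frames are grouped along the temporal axis in chunks of
+        # temporal_conv_size (typically 2) before being turned into vision
+        # tokens, so an odd trailing frame would not form a complete temporal
+        # patch; trimming it keeps the token count consistent.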
+ if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), + image_processor=None, + ) + + +class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor( + image_processor.image_mean, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + image_std_tensor = torch.tensor( + image_processor.image_std, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + rescale_factor = torch.tensor( + image_processor.rescale_factor, dtype=torch.float32 + ) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = image_mean_tensor.squeeze([-2, -1]).repeat_interleave( + patch_size_squared, -1 + ) + image_std_tensor = image_std_tensor.squeeze([-2, -1]).repeat_interleave( + patch_size_squared, -1 + ) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = image_std_tensor.contiguous() + + pixel_values = ( + rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor + ) / image_std_tensor + pixel_values = pixel_values.to(hf_config.torch_dtype) + return pixel_values + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # when the prompt is not empty but the multimodal data is empty, + # directly invoke the tokenizer. + if "images" not in mm_data and "videos" not in mm_data and prompt != "": + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + tokenizer_output = BatchFeature( + dict(input_ids=[prompt_ids]), tensor_type="pt" + ) + return tokenizer_output + + if "images" not in mm_data: + mm_data["images"] = [] + if "videos" not in mm_data: + mm_data["videos"] = [] + processor_output = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=[prompt], images=mm_data["images"], videos=mm_data["videos"]), + dict(**mm_kwargs, **tok_kwargs), + ) + + # Divide the processor_output into two modalities: image and video. 
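+        # The HF processor returns a single flat "images" tensor plus one
+        # grid_thw row per item; rows whose temporal dimension is > 1 are
+        # treated as videos, and the pixel values are split below by the
+        # cumulative image patch count into pixel_values / pixel_values_videos.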
+ if processor_output is not None: + pixel_values = processor_output["images"] + if pixel_values is not None: + processor_output["images"] = self._pixel_values_norm( + pixel_values, mm_kwargs + ) + for key in list(processor_output.keys()): + if processor_output[key] is None: + del processor_output[key] + continue + if key == "grid_thw": + grid_thw = processor_output["grid_thw"] + pixel_values_all = processor_output["images"] + # Identify elements where the first + # dimension is greater than 1 and + # treat them as the video modality + mask = grid_thw[:, 0] > 1 + processor_output["video_grid_thw"] = grid_thw[mask] + processor_output["image_grid_thw"] = grid_thw[~mask] + image_patch_num = ( + processor_output["image_grid_thw"].prod(dim=1).sum() + ) + processor_output["pixel_values"] = pixel_values_all[ + :image_patch_num + ] + processor_output["pixel_values_videos"] = pixel_values_all[ + image_patch_num: + ] + del processor_output["images"] + + return processor_output + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + before_placeholder = { + "image": "<|image@placeholder|>", + "video": "<|video@placeholder|>", + } + + after_placeholder = { + # image and video have same placeholder + "image": "<|IMAGE_PLACEHOLDER|>", + "video": "<|IMAGE_PLACEHOLDER|>", + } + + merge_length = hf_processor.spatial_conv_size**2 + + def get_replacement_ernie45vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + if modality == "video": + num_tokens = ( + int(grid_thw.prod()) + // hf_processor.temporal_conv_size + // merge_length + ) + else: + num_tokens = int(grid_thw.prod()) // merge_length + return after_placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, modality=modality), + ) + for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes + ), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes + ), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += ( + f"Picture {i + 1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + ) + + for i in range(num_videos): + prompt += f"Video {i + 1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> 
MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), + } + + +@MULTIMODAL_REGISTRY.register_processor( + Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder, +) +class Ernie4_5_VLMoeForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + merge_by_field_config = True + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. + # language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }, + # resampler_weight_mappings + orig_to_new_substr={ + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.3.": "temporal_norm.", + }, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + if modality.startswith("video"): + return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_model = Ernie4_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = Ernie4_5_VLMoeForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "resampler_model"), + ) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states) + + def _vision_forward( + self, + pixel_values: torch.Tensor, + 
grid_thw: torch.Tensor, + ) -> torch.Tensor: + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering," + "which is not divisible by 3." + ) + grid_thw = grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape( + -1, 1 + ) + else: + self.visual_token_mask = None + + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + 
video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t // temporal_conv_size, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Ernie4_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + if pixel_values is not None: + return Ernie4_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[Ernie4_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + if pixel_values_videos is not None: + return Ernie4_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, image_input: Ernie4_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.vision_model.dtype) + image_features = self._vision_forward( + pixel_values=pixel_values, grid_thw=grid_thw + ) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Ernie4_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype + ) + video_features = self._vision_forward( + pixel_values=pixel_values_videos, grid_thw=grid_thw + ) + video_embeds = self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = ( + (grid_thw.prod(-1) // self.config.temporal_conv_size) + // merge_size + // merge_size + ) + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
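+        # e.g. if pixel_values appears before pixel_values_videos in kwargs,
+        # the dict becomes {"images": ..., "videos": ...} in that order, which
+        # get_multimodal_embeddings relies on when concatenating embeddings.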
+ for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: + self._set_visual_token_mask(input_ids) + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + if self.visual_token_mask is not None: + if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: + padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0] + # right pad False + pad = torch.zeros( + (padding_len, self.visual_token_mask.shape[1]), + dtype=self.visual_token_mask.dtype, + device=self.visual_token_mask.device, + ) + self.visual_token_mask = torch.cat([self.visual_token_mask, pad], dim=0) + + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model( + **forward_kwargs, + **kwargs, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py new file mode 100644 index 0000000000000..2c49895561409 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -0,0 +1,786 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + 
+# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie VL model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention + +# from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding, +) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.sequence import IntermediateTensors + +from .ernie45_moe import Ernie4_5_MoeMLP +from .interfaces import SupportsPP +from .utils import ( + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): + def __init__(self, shared_experts: Optional[torch.nn.Module] = None, **kwargs): + super().__init__(**kwargs) + self.shared_experts = shared_experts + + def forward(self, x): + if self.shared_experts is not None: + return self.shared_experts(x) + super().forward(x) + else: + return super().forward(x) + + +class Ernie4_5_VLMoeAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + freq_allocation: int = 20, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = 
layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype=torch.get_default_dtype(), + mrope_section=[h_rope, w_rope, t_rope], + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Ernie4_5_VLMoeMoE(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0 + self.hidden_size = config.hidden_size + + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + + if self.tp_size > max_moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {moe_num_experts}." 
+ ) + + moe_layer_start_index = config.moe_layer_start_index + text_moe_layer_start_index = moe_layer_start_index[0] + vision_moe_layer_start_index = moe_layer_start_index[1] + moe_layer_end_index = config.moe_layer_end_index + moe_layer_end_index = getattr( + config, + "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1], + ) + text_moe_layer_end_index = moe_layer_end_index[0] + vision_moe_layer_end_index = moe_layer_end_index[1] + + assert config.moe_num_experts[0] == config.moe_num_experts[1] + self.e_score_correction_bias = nn.Parameter( + torch.empty(2, config.moe_num_experts[0], dtype=torch.float32) + ) + + assert text_moe_layer_start_index <= text_moe_layer_end_index + + if self.has_shared_experts: + intermediate_size = ( + config.moe_intermediate_size[0] * config.moe_num_shared_experts + ) + self.shared_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=False, + ) + else: + self.shared_experts = None + + if ( + layer_idx >= text_moe_layer_start_index + and layer_idx <= text_moe_layer_end_index + ): + self.text_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[0], + bias=False, + params_dtype=torch.float32, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate", + ) + + self.text_experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts", + ) + else: + self.text_experts = Ernie4_5_VLMoeMLP( + shared_experts=self.shared_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, "use_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + assert vision_moe_layer_start_index <= vision_moe_layer_end_index + if ( + layer_idx >= vision_moe_layer_start_index + and layer_idx <= vision_moe_layer_end_index + ): + self.vision_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[1], + bias=False, + params_dtype=torch.float32, + quant_config=quant_config, + prefix=f"{prefix}.vision_experts_gate", + ) + + self.vision_experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.vision_experts", + ) + else: + self.vision_experts = Ernie4_5_VLMoeMLP( + shared_experts=self.shared_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, "use_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + visual_token_mask: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if visual_token_mask is not None and 
visual_token_mask.all(): + # only vision modal input + router_logits, _ = self.vision_experts_gate( + hidden_states.to(dtype=torch.float32) + ) + final_hidden_states = self.vision_experts( + hidden_states=hidden_states, router_logits=router_logits + ) + elif visual_token_mask is not None and visual_token_mask.any(): + # text and vision modals input + visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + final_hidden_states = torch.zeros_like(hidden_states) + + text_hidden_states = hidden_states[text_token_mask].reshape( + -1, self.hidden_size + ) + vision_hidden_states = hidden_states[visual_token_mask].reshape( + -1, self.hidden_size + ) + + text_router_logits, _ = self.text_experts_gate( + text_hidden_states.to(dtype=torch.float32) + ) + final_hidden_states[text_token_mask] = self.text_experts( + hidden_states=text_hidden_states, router_logits=text_router_logits + ).flatten() + + vision_router_logits, _ = self.vision_experts_gate( + vision_hidden_states.to(dtype=torch.float32) + ) + final_hidden_states[visual_token_mask] = self.vision_experts( + hidden_states=vision_hidden_states, router_logits=vision_router_logits + ).flatten() + else: + # only text modal input + text_router_logits, _ = self.text_experts_gate( + hidden_states.to(dtype=torch.float32) + ) + + final_hidden_states = self.text_experts( + hidden_states=hidden_states, router_logits=text_router_logits + ) + + if self.has_shared_experts: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + + if self.tp_size > 1: + final_hidden_states = ( + self.text_experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + ) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_VLMoeDecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + freq_allocation = getattr(config, "freq_allocation", 20) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) + + self.self_attn = Ernie4_5_VLMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, "head_dim", None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + freq_allocation=freq_allocation, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "use_bias", False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_layer_start_index = config.moe_layer_start_index + min_moe_layer_start_index = min(moe_layer_start_index) + moe_layer_end_index = getattr( + config, + "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1], + ) + max_moe_layer_end_index = max(moe_layer_end_index) + assert min_moe_layer_start_index <= max_moe_layer_end_index + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", max_moe_num_experts > 0) + + if ( + use_moe + and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx 
>= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index + ): + self.mlp = Ernie4_5_VLMoeMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + else: + self.mlp = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, "use_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and vision experts, +# enabling torch.compile will cause errors. +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# "visual_token_mask": 0, +# }) +class Ernie4_5_VLMoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.im_patch_id = config.im_patch_id + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_VLMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert 
intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, hidden_states, residual, visual_token_mask, **kwargs + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +# only used as text backbone for ernie4.5-vl +class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_VLMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=max(self.config.moe_num_experts), + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): + loaded_params.add("lm_head.weight") + continue + # MTP will be supported soon. + if "mtp" in name or "vision_model" in name or "resampler_model" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if ("mlp.experts." 
in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Distinguish between vision experts and text experts + if "mlp.experts" in name: + moe_offset = int(name.split(".")[-3]) + vision_expert_start_idx = self.config.moe_num_experts[0] + is_text_expert = moe_offset <= vision_expert_start_idx - 1 + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace( + f".experts.{moe_offset}", + f".vision_experts.{moe_offset - vision_expert_start_idx}", + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + # Distinguish between vision experts and text experts + moe_offset = int(name.split(".")[-3]) + is_text_expert = moe_offset <= self.config.moe_num_experts[0] - 1 + + name = name.replace(weight_name, param_name) + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(".experts.", ".vision_experts.") + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Distinguish between vision expert gate + # and text expert gate + if name.endswith("mlp.gate.weight"): + name = name.replace("gate.weight", "text_experts_gate.weight") + loaded_weight = loaded_weight.T + elif name.endswith("mlp.gate.weight_1"): + name = name.replace( + "gate.weight_1", "vision_experts_gate.weight" + ) + loaded_weight = loaded_weight.T + + if "e_score_correction_bias" in name: + name = name.replace(".moe_statics.", ".") + + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 90a1267b28f0a..46a7131f2499a 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Ernie-MTP model.""" + from collections.abc import Iterable from typing import Optional @@ -29,15 +30,14 @@ import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -46,26 +46,20 @@ from .utils import is_pp_missing_parameter, maybe_prefix class ErnieMultiTokenPredictorLayer(nn.Module): - def __init__( self, - config: PretrainedConfig, + vllm_config: VllmConfig, prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + config = vllm_config.model_config.hf_config - self.mtp_emb_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.mtp_hidden_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, - prefix) + self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear( + config.hidden_size * 2, config.hidden_size, bias=False + ) + self.mtp_block = LlamaDecoderLayer(vllm_config, prefix) def forward( self, @@ -82,18 +76,18 @@ class ErnieMultiTokenPredictorLayer(nn.Module): previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) hidden_states = self.mtp_linear_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class ErnieMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -101,23 +95,27 @@ class ErnieMultiTokenPredictor(nn.Module): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - ErnieMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): ErnieMultiTokenPredictorLayer( + vllm_config, + f"{prefix}.layers.{idx}", + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = 
VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -139,31 +137,33 @@ class ErnieMultiTokenPredictor(nn.Module): self, hidden_states: torch.Tensor, lm_head: ParallelLMHead, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] - logits = self.logits_processor(lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(lm_head, hidden_states) return logits class ErnieMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) - self.sampler = get_sampler() + self.model = ErnieMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -174,29 +174,19 @@ class ErnieMTP(nn.Module, SupportsPP): spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "ernie_mtp only support predict one token" - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, self.lm_head, - sampling_metadata, spec_step_idx) + return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -208,16 +198,14 @@ class ErnieMTP(nn.Module, SupportsPP): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - - if self.config.tie_word_embeddings and name.endswith( - "lm_head.weight"): + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): continue if "rotary_emb.inv_freq" in name: continue if "mtp" in name: name = self._rewrite_spec_layer_name(self.config, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: continue @@ -229,12 +217,13 @@ class ErnieMTP(nn.Module, SupportsPP): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -246,8 +235,9 @@ class ErnieMTP(nn.Module, SupportsPP): break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -255,33 +245,36 @@ class ErnieMTP(nn.Module, SupportsPP): # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. - if "mtp_" not in name and ("embed_tokens" not in name - and "lm_head" not in name): + if "mtp_" not in name and ( + "embed_tokens" not in name and "lm_head" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - def _rewrite_spec_layer_name(self, config: PretrainedConfig, - name: str) -> str: + def _rewrite_spec_layer_name(self, config: PretrainedConfig, name: str) -> str: """ Rewrite the weight name to match the format of the original model. """ spec_layer_weight_names = [ - "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", - "mtp_linear_proj" + "embed_tokens", + "mtp_emb_norm", + "mtp_hidden_norm", + "mtp_linear_proj", ] layer_idx = config.num_hidden_layers for weight_name in spec_layer_weight_names: if weight_name in name: name = name.replace( f"model.{weight_name}.0.", - f"model.layers.{layer_idx}.{weight_name}.") + f"model.layers.{layer_idx}.{weight_name}.", + ) return name - name = name.replace("model.mtp_block.0.", - f"model.layers.{layer_idx}.mtp_block.") + name = name.replace( + "model.mtp_block.0.", f"model.layers.{layer_idx}.mtp_block." 
+ ) return name diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 8052b6bb82348..1f0b5723721c6 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -26,6 +26,7 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -38,27 +39,37 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class ExaoneGatedMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -84,8 +95,9 @@ class ExaoneGatedMLP(nn.Module): prefix=f"{prefix}.c_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -96,7 +108,6 @@ class ExaoneGatedMLP(nn.Module): class ExaoneAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -191,7 +202,6 @@ class ExaoneAttention(nn.Module): class ExaoneBlockAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -233,7 +243,6 @@ class ExaoneBlockAttention(nn.Module): class ExaoneDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -246,21 +255,24 @@ class ExaoneDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.attn = ExaoneBlockAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -305,7 +317,6 @@ class ExaoneDecoderLayer(nn.Module): @support_torch_compile class ExaoneModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -316,12 +327,16 @@ class ExaoneModel(nn.Module): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.wte = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.wte = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -341,14 +356,13 @@ class ExaoneModel(nn.Module): prefix=f"{prefix}.h", ) if get_pp_group().is_last_rank: - self.ln_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) else: self.ln_f = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -371,7 +385,7 @@ class ExaoneModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, 
hidden_states, @@ -379,16 +393,14 @@ class ExaoneModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -402,19 +414,19 @@ class ExaoneModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -447,8 +459,7 @@ class ExaoneModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -499,21 +510,24 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -525,27 +539,24 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + model_output = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - 
sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 827e9014184b5..230a2c80104b1 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -22,6 +22,7 @@ """Inference-only Exaone model compatible with HuggingFace weights.""" from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -34,28 +35,38 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Exaone4GatedMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -81,8 +92,9 @@ class Exaone4GatedMLP(nn.Module): prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -93,7 +105,6 @@ class Exaone4GatedMLP(nn.Module): class Exaone4Attention(nn.Module): - def __init__( self, config: Exaone4Config, @@ -163,8 +174,8 @@ class Exaone4Attention(nn.Module): is_sliding = config.layer_types[layer_idx] == "sliding_attention" self.sliding_window = config.sliding_window if is_sliding else None - # apply rotary embeddings to every layer - self.apply_all_layers = not is_sliding + # apply rotary embeddings to every layer in full attention models + self.apply_rope_all_layers = "sliding_attention" not in config.layer_types self.rotary_emb = get_rope( self.head_dim, @@ -200,7 +211,7 @@ class Exaone4Attention(nn.Module): k = self.k_norm(k) k = k.flatten(-2, -1) - if self.sliding_window or self.apply_all_layers: + if self.sliding_window or self.apply_rope_all_layers: q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -208,7 +219,6 @@ class Exaone4Attention(nn.Module): class Exaone4DecoderLayer(nn.Module): - def __init__( self, config: Exaone4Config, @@ -221,22 +231,25 @@ class Exaone4DecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = Exaone4Attention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -253,10 +266,12 @@ class Exaone4DecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -290,7 +305,6 @@ class Exaone4DecoderLayer(nn.Module): @support_torch_compile class Exaone4Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -301,11 +315,15 @@ class Exaone4Model(nn.Module): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + 
config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -329,9 +347,9 @@ class Exaone4Model(nn.Module): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -354,7 +372,7 @@ class Exaone4Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -362,16 +380,14 @@ class Exaone4Model(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -385,19 +401,19 @@ class Exaone4Model(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -430,8 +446,7 @@ class Exaone4Model(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -482,21 +497,24 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -508,27 +526,24 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. 
- skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py index d78ee100b26df..ca0e7e64df53d 100644 --- a/vllm/model_executor/models/fairseq2_llama.py +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -23,8 +23,10 @@ import torch from torch.nn import Parameter from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.linear import set_weight_attrs from vllm.model_executor.models.llama import LlamaForCausalLM @@ -32,7 +34,6 @@ from .utils import AutoWeightsLoader, WeightsMapper class Fairseq2LlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) self.tp_rank = get_tensor_model_parallel_rank() @@ -45,14 +46,12 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): f"model.{self.tp_rank}.pt", ] - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # fairseq2's serialization adds a wrapper to usual .pt state_dict's: # { "model_key": my_model_name, "my_model_name": state_dict } # which we first need to unpack weights_wrapped = dict(weights) - weights = weights_wrapped[ - weights_wrapped["model_key"]].items() # type: ignore + weights = weights_wrapped[weights_wrapped["model_key"]].items() # type: ignore # remap keys fs2_to_vllm_mapper = WeightsMapper( @@ -77,12 +76,14 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( - (self.reshape_fairseq2_weights(name, loaded_weight, params) - for name, loaded_weight in weights)) + ( + self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights + ) + ) def flag_sharded_weights(self, params: dict[str, Parameter]): """Sets the `is_sharded_weight` flag to True for all sharded weights""" @@ -113,35 +114,34 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): attn_in //= self.tp_size n_heads //= self.tp_size attn_out = self.config.hidden_size - return (w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, - 2).reshape(attn_in, attn_out)) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) modules = name.split(".") # rotary embeds should be sliced if "k_proj" in modules: - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads) + loaded_weight = permute(loaded_weight, self.config.num_key_value_heads) elif "q_proj" in modules: - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads) + loaded_weight = permute(loaded_weight, self.config.num_attention_heads) # We make the loaded weights compatible with both # full checkpoints and tp sharded checkpoints. # Embeddings are repeated to fit the vocab size. - # Other weights are flagged for the weight_loader calls. + # Other weights are flagged for the weight_loader calls. 
if any(emb in modules for emb in ["embed_tokens", "lm_head"]): # Embeddings are sharded on dim 0 dim = 0 # In fairseq2, vocab size has to be divisible by tp_size # so we don't worry about padding - if self.tp_size > 1 and loaded_weight.shape[ - dim] < self.config.vocab_size: - assert loaded_weight.shape[ - dim] * self.tp_size == self.config.vocab_size, \ - "vocab_size should be divisible by tp_size." + if self.tp_size > 1 and loaded_weight.shape[dim] < self.config.vocab_size: + assert ( + loaded_weight.shape[dim] * self.tp_size == self.config.vocab_size + ), "vocab_size should be divisible by tp_size." repeats = [1] * len(loaded_weight.size()) repeats[dim] = self.tp_size # repeat to match vocab size and to be easily 'narrow'able diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 62a93dabd5d7f..211a9120789e2 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -22,6 +22,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -32,56 +33,65 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) FalconConfig = Union[HF_FalconConfig, RWConfig] def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32) - 
num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(1, - 1 + 2 * num_remaining_heads, - 2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32 + ) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class FalconAttention(nn.Module): - def __init__( self, config: FalconConfig, @@ -133,59 +143,68 @@ class FalconAttention(nn.Module): # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=config.bias, skip_bias_add=True, quant_config=quant_config, - reduce_results=self.reduce_row_parallel_results) + reduce_results=self.reduce_row_parallel_results, + ) self.use_rotary = config.rotary self.use_alibi = config.alibi assert not (self.use_rotary and self.use_alibi), ( - "Rotary and alibi are mutually exclusive.") + "Rotary and alibi are mutually exclusive." + ) if self.use_rotary: rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * - self.inv_norm_factor) + alibi_slopes = ( + _get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor + ) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) else: - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -204,7 +223,6 @@ class FalconAttention(nn.Module): class FalconMLP(nn.Module): - def __init__( self, config: FalconConfig, @@ -213,21 +231,25 @@ 
class FalconMLP(nn.Module): super().__init__() hidden_size = config.hidden_size - self.dense_h_to_4h = ColumnParallelLinear(hidden_size, - 4 * hidden_size, - bias=config.bias, - skip_bias_add=True, - quant_config=quant_config) + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + ) self.act = get_act_fn("gelu") - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - quant_config=quant_config) + quant_config=quant_config, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. @@ -240,7 +262,6 @@ class FalconMLP(nn.Module): class FalconDecoderLayer(nn.Module): - def __init__( self, config: FalconConfig, @@ -252,39 +273,36 @@ class FalconDecoderLayer(nn.Module): hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.self_attention = FalconAttention( - config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.mlp = FalconMLP(config, quant_config) self.config = config - if (not hasattr(config, "num_ln_in_parallel_attn")): + if not hasattr(config, "num_ln_in_parallel_attn"): config.num_ln_in_parallel_attn = None - if (config.num_ln_in_parallel_attn is None - and config.new_decoder_architecture): + if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture: config.num_ln_in_parallel_attn = 2 if not config.parallel_attn: self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) - self.input_layernorm = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + hidden_size, eps=config.layer_norm_epsilon + ) + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: if config.num_ln_in_parallel_attn == 2: # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: - self.input_layernorm = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon + ) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) def forward( self, @@ -314,8 +332,11 @@ class FalconDecoderLayer(nn.Module): residual += attention_output mlp_layernorm_out = self.post_attention_layernorm(residual) - if (self.config.new_decoder_architecture and self.config.parallel_attn - and self.config.num_ln_in_parallel_attn == 1): + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): mlp_layernorm_out = attention_layernorm_out # MLP. 
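As context for the FalconAttention changes reformatted above: when `config.alibi` is set, the layer multiplies the slopes from `_get_alibi_slopes` by `inv_norm_factor` (1 / sqrt(head_dim)) and passes each rank's `head_start:head_end` slice to `Attention`. The following standalone sketch is illustrative only and not part of the patch; it simply mirrors that helper so the geometric slope schedule is easy to inspect:

# Standalone sketch of the ALiBi slope schedule computed by _get_alibi_slopes;
# example head count only, not part of this patch.
import math

import torch


def alibi_slopes(total_num_heads: int) -> torch.Tensor:
    # Slopes form a geometric sequence anchored at the closest power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        # Non-power-of-two head counts interleave a second, steeper sequence.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining = min(closest_power_of_2, total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


print(alibi_slopes(4))  # tensor([0.2500, 0.0625, 0.0156, 0.0039])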
@@ -340,7 +361,6 @@ class FalconDecoderLayer(nn.Module): @support_torch_compile class FalconModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -363,14 +383,16 @@ class FalconModel(nn.Module): self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: FalconDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.h", + ) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.word_embeddings(input_ids) @@ -389,15 +411,14 @@ class FalconModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -420,26 +441,34 @@ class FalconModel(nn.Module): loaded_weight_shape = loaded_weight.shape if output_dim is not None: loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + - (total_num_kv_heads, num_query_heads_per_kv_head + 2, - -1) + loaded_weight_shape[output_dim + 1:]) + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + + loaded_weight_shape[output_dim + 1 :] + ) wq = loaded_weight.narrow( - output_dim + 1, 0, - num_query_heads_per_kv_head).reshape( - *loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, 0, num_query_heads_per_kv_head + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) wk = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, num_query_heads_per_kv_head, 1 + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) wv = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head + 1, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, num_query_heads_per_kv_head + 1, 1 + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -456,15 +485,17 @@ class FalconForCausalLM(nn.Module, SupportsPP): quant_config = 
vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = FalconModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = FalconModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) # only Falcon-11B doesn't share lm_head weight with word embeddings # and previous Falcon model doesn't have tie_word_embeddings config # so we set tie_word_embeddings to True by default - self.tie_word_embeddings = (config.tie_word_embeddings - if config.tie_word_embeddings is not None - else True) + self.tie_word_embeddings = ( + config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True + ) if self.tie_word_embeddings: self.lm_head = self.transformer.word_embeddings else: @@ -472,10 +503,12 @@ class FalconForCausalLM(nn.Module, SupportsPP): config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -487,24 +520,21 @@ class FalconForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 5e2b6d69124c8..db938dda5d637 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -1,49 +1,54 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only FalconH1 model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional import torch from torch import nn from transformers import FalconH1Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - 
RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class FalconH1MLP(nn.Module): - def __init__( self, config: FalconH1Config, @@ -67,13 +72,15 @@ class FalconH1MLP(nn.Module): self.intermediate_size = config.intermediate_size self.gate_multiplier, self.down_multiplier = config.mlp_multipliers if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): x, _ = self.gate_up_proj(x) - x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier + x[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier x = self.act_fn(x) x, _ = self.down_proj(x) x = x * self.down_multiplier @@ -81,7 +88,6 @@ class FalconH1MLP(nn.Module): class FalconH1SSMDecoderLayer(nn.Module): - def __init__( self, config: FalconH1Config, @@ -94,8 +100,11 @@ class FalconH1SSMDecoderLayer(nn.Module): self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.d_ssm = (int(config.mamba_expand * config.hidden_size) - if config.mamba_d_ssm is None else config.mamba_d_ssm) + self.d_ssm = ( + int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None + else config.mamba_d_ssm + ) self.mamba = MambaMixer2( hidden_size=config.hidden_size, @@ -122,15 +131,15 @@ class FalconH1SSMDecoderLayer(nn.Module): def _init_mup_vector(self): """ - Non learnable per-block scaling vector composed of element-wise - multipliersapplied to each separate contiguous block of the output + Non-learnable per-block scaling vector composed of element-wise + multipliers applied to each separate contiguous block of the output of the linear projection (in_proj) before further processing (gating, convolution, SSM): - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] - - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] → zxbcdt_multipliers[3] - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4] @@ -140,38 +149,38 @@ class FalconH1SSMDecoderLayer(nn.Module): - S: SSM state size per group - All indices are divided by tp_size to support tensor parallelism """ - vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size + - self.config.mamba_n_heads) // self.tp_size + vector_shape = ( + 2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads ) // self.tp_size mup_vector = torch.ones(1, vector_shape) # Z vector 0 -> d_ssm - mup_vector[:, :self.d_ssm // - self.tp_size] *= self.zxbcdt_multipliers[0] + mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0] # X vector d_ssm -> 2 * d_ssm - mup_vector[:, - (self.d_ssm // - self.tp_size):(2 * self.d_ssm // - self.tp_size)] *= self.zxbcdt_multipliers[1] + mup_vector[ + :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size) + ] *= self.zxbcdt_multipliers[1] # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) mup_vector[ :, - (2 * self.d_ssm) // - self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) // - self.tp_size, + (2 * self.d_ssm) // self.tp_size : ( + 2 * self.d_ssm + self.groups_time_state_size + ) + // self.tp_size, ] *= self.zxbcdt_multipliers[2] # C vector 2 * d_ssm + (n_group * d_state) # -> 2 * d_ssm + 2 * (n_group * d_state) mup_vector[ :, - (2 * self.d_ssm + self.groups_time_state_size) // - self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) // - self.tp_size, + (2 * self.d_ssm + self.groups_time_state_size) // self.tp_size : ( + 2 * self.d_ssm + 2 * self.groups_time_state_size + ) + // self.tp_size, ] *= self.zxbcdt_multipliers[3] # dt vector 2 * d_ssm + 2 * (n_group * d_state) # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads mup_vector[ :, - (2 * self.d_ssm + 2 * self.groups_time_state_size) //
self.tp_size :, ] *= self.zxbcdt_multipliers[4] self.register_buffer("mup_vector", mup_vector, persistent=False) @@ -180,23 +189,18 @@ class FalconH1SSMDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): output = torch.empty_like(hidden_states) self.mamba( hidden_states, output, - mamba_cache_params, - mamba2_metadata=mamba2_metadata, mup_vector=self.mup_vector, ) return output, residual class FalconH1AttentionDecoderLayer(nn.Module): - def __init__( self, config: FalconH1Config, @@ -207,8 +211,7 @@ class FalconH1AttentionDecoderLayer(nn.Module): super().__init__() rope_theta = getattr(config, "rope_theta", 1e11) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -224,8 +227,11 @@ class FalconH1AttentionDecoderLayer(nn.Module): # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = (config.hidden_size // self.total_num_heads if getattr( - config, "head_dim", None) is None else config.head_dim) + self.head_dim = ( + config.hidden_size // self.total_num_heads + if getattr(config, "head_dim", None) is None + else config.head_dim + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -356,17 +362,13 @@ class FalconH1ParallelHybrid(nn.Module): self.feed_forward = FalconH1MLP(config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states @@ -383,19 +385,18 @@ class FalconH1ParallelHybrid(nn.Module): # Process input through the SSM branch. # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, - # residual, mamba_cache_params, and sequence_idx. + # residual, and sequence_idx. ssm_hidden, _ = self.mamba( hidden_states=hidden_states * self.ssm_in_multiplier, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, **kwargs, ) # Sum the outputs from both branches. # We assume both branches produce outputs of the same # dimensionality (config.hidden_size). 
hidden_states = (attn_hidden * self.attn_out_multiplier) + ( - ssm_hidden * self.ssm_out_multiplier) + ssm_hidden * self.ssm_out_multiplier + ) hidden_states = hidden_states + residual # feed-forward @@ -409,7 +410,6 @@ class FalconH1ParallelHybrid(nn.Module): @support_torch_compile class FalconH1Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: FalconH1Config = vllm_config.model_config.hf_config @@ -419,12 +419,14 @@ class FalconH1Model(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -448,13 +450,13 @@ class FalconH1Model(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.final_layernorm = PPMissingLayer() @@ -465,56 +467,36 @@ class FalconH1Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # pass a sequence index tensor, that is required for - # proper continuous batching computation including - # chunked prefill - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier else: - hidden_states = (self.get_input_embeddings(input_ids) * - self.embedding_multiplier) + hidden_states = ( + self.get_input_embeddings(input_ids) * self.embedding_multiplier + ) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - layer_mamba_cache_params = None - if mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( positions=positions, hidden_states=hidden_states, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.final_layernorm(hidden_states) return hidden_states 
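To make the `_init_mup_vector` layout above easier to follow, here is a small sketch of the block boundaries it scales. All sizes are made-up example values with tp_size = 1; they are not taken from any real FalconH1 config:

# Illustration of the Z/X/B/C/dt block boundaries described in the
# _init_mup_vector docstring. All sizes below are hypothetical examples.
d_ssm = 1024                 # SSM inner width
n_groups, d_state = 2, 128   # mamba_n_groups and mamba_d_state
n_heads = 32                 # mamba_n_heads
gs = n_groups * d_state      # G * S in the docstring

blocks = {
    "Z": (0, d_ssm),
    "X": (d_ssm, 2 * d_ssm),
    "B": (2 * d_ssm, 2 * d_ssm + gs),
    "C": (2 * d_ssm + gs, 2 * d_ssm + 2 * gs),
    "dt": (2 * d_ssm + 2 * gs, 2 * d_ssm + 2 * gs + n_heads),
}
# Total in_proj width covered by the mup vector: 2592 here,
# divided by tp_size when the projection is sharded.
vector_shape = 2 * d_ssm + 2 * gs + n_heads
for i, (name, (lo, hi)) in enumerate(blocks.items()):
    print(f"{name:>2} block [{lo}:{hi}) <- zxbcdt_multipliers[{i}]")

Each slice of the in_proj output is multiplied by the matching `zxbcdt_multipliers` entry before gating, convolution, and the SSM scan, as the docstring describes.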
-class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): +class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -531,7 +513,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -542,13 +523,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -558,10 +537,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - intermediate_size = (int(hf_config.mamba_expand * - hf_config.hidden_size) - if hf_config.mamba_d_ssm is None else - hf_config.mamba_d_ssm) + intermediate_size = ( + int(hf_config.mamba_expand * hf_config.hidden_size) + if hf_config.mamba_d_ssm is None + else hf_config.mamba_d_ssm + ) return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, @@ -571,29 +551,25 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "FalconH1 currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = FalconH1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = FalconH1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.tie_word_embeddings = config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size - self.mamba_cache: Optional[MambaCacheManager] = None if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size if get_pp_group().is_last_rank: @@ -605,13 +581,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), + prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_head_multiplier = config.lm_head_multiplier if self.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Used to track and store by the Mamba cache between steps. 
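# NOTE (editor's illustrative sketch, not part of the diff): get_mamba_state_shape_from_config
# above derives the SSM intermediate size from the HF config, falling back to
# mamba_expand * hidden_size when mamba_d_ssm is unset. A minimal sketch of just that
# fallback, with hypothetical config values:
from typing import Optional

def ssm_intermediate_size(
    hidden_size: int,
    mamba_expand: float,
    mamba_d_ssm: Optional[int],
) -> int:
    # A checkpoint may pin the SSM width explicitly via mamba_d_ssm;
    # otherwise it is the expand factor times the model hidden size,
    # matching the expression in the code above.
    if mamba_d_ssm is not None:
        return mamba_d_ssm
    return int(mamba_expand * hidden_size)

# e.g. hidden_size=4096, mamba_expand=2.0 and no explicit mamba_d_ssm
# (hypothetical numbers) gives an intermediate size of 8192.
assert ssm_intermediate_size(4096, 2.0, None) == 8192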
self.logits_processor = LogitsProcessor( @@ -623,7 +600,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -636,53 +614,24 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.config.num_hidden_layers, - *mamba_state_shape, - *mamba_state_dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - hidden_states = self.model( input_ids, positions, - mamba_cache_params, intermediate_tensors, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -729,8 +678,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) diff --git a/vllm/model_executor/models/flex_olmo.py b/vllm/model_executor/models/flex_olmo.py new file mode 100644 index 0000000000000..b1fbbf086896d --- /dev/null +++ b/vllm/model_executor/models/flex_olmo.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
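# NOTE (editor's illustrative sketch, not part of the diff): the FalconH1 load_weights
# above uses a stacked_params_mapping so that per-projection checkpoint weights
# (q_proj / k_proj / v_proj) are routed into the fused vLLM parameter (qkv_proj).
# A minimal sketch of the renaming step only, with a hypothetical weight name;
# the gate/up entries are shown for illustration and mirror packed_modules_mapping.
STACKED_PARAMS_MAPPING = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def map_checkpoint_name(name: str):
    # Returns the fused vLLM parameter name plus the shard id the weight
    # should be written into, or (name, None) when the weight is not part
    # of a fused parameter and can be loaded as-is.
    for fused_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in name:
            return name.replace(shard_name, fused_name), shard_id
    return name, None

# e.g. "model.layers.0.self_attn.k_proj.weight" (hypothetical name) maps to
# ("model.layers.0.self_attn.qkv_proj.weight", "k").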
+"""Inference-only FlexOlmo model compatible with HuggingFace weights.""" + +from typing import Optional + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM +from vllm.transformers_utils.configs import FlexOlmoConfig + +logger = init_logger(__name__) + + +class FlexOlmoAttention(OlmoeAttention): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + self.q_norm = RMSNorm( + self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + + +class FlexOlmoMoE(nn.Module): + """A tensor-parallel MoE implementation for FlexOlmo that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + tp_size = get_tensor_model_parallel_world_size() + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hf_config.hidden_size, + hf_config.num_experts, + bias=False, + return_bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Gate always runs at half / full precision for now. + self.experts = FusedMoE( + num_experts=hf_config.num_experts, + top_k=hf_config.num_experts_per_tok, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=None, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) + + self.top_k = hf_config.num_experts_per_tok + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + # Warning: The experts mutate the hidden state input! This messes up + # basic things like the residual stream. 
+ final_hidden_states = self.experts( + hidden_states=hidden_states.detach().clone(), + router_logits=router_logits.float(), + ) + + return final_hidden_states.view(orig_shape) + + +class FlexOlmoDecoderLayer(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + self.self_attn = FlexOlmoAttention( + vllm_config=vllm_config, prefix=f"{prefix}.self_attn" + ) + self.post_attention_layernorm = RMSNorm( + hf_config.hidden_size, eps=hf_config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + hf_config.hidden_size, eps=hf_config.rms_norm_eps + ) + + self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, None + + +class FlexOlmoForCausalLM(OlmoeForCausalLM): + fall_back_to_pt_during_load = False + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = FlexOlmoDecoderLayer, + ): + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py deleted file mode 100644 index d0881231fb1e7..0000000000000 --- a/vllm/model_executor/models/florence2.py +++ /dev/null @@ -1,1107 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from collections import OrderedDict -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import BartTokenizer, BatchFeature, PretrainedConfig - -from vllm.config import VllmConfig -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, - BartParallelLMHead, - BartScaledWordEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptIndexTargets, PromptInsertion, - PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsV0Only) -from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings - - -class Florence2ImagePixelInputs(TensorSchema): - """ - Dimensions: - - b: Batch size - - 
c: Number of channels (3) - - h: Height of the image - - w: Width of the image - """ - - type: Literal["pixel_values"] - - data: Annotated[ - torch.Tensor, - TensorShape("b", 3, "h", "w"), - ] - - -# ViT implementation are all copied from -# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py -class LearnedAbsolutePositionEmbedding2D(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, embedding_dim=256, num_pos=50): - super().__init__() - self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) - self.column_embeddings = nn.Embedding( - num_pos, embedding_dim - (embedding_dim // 2)) - - def forward(self, pixel_values): - """ - pixel_values: (batch_size, height, width, num_channels) - returns: (batch_size, height, width, embedding_dim * 2) - """ - if len(pixel_values.shape) != 4: - raise ValueError('pixel_values must be a 4D tensor') - height, width = pixel_values.shape[1:3] - width_values = torch.arange(width, device=pixel_values.device) - height_values = torch.arange(height, device=pixel_values.device) - x_emb = self.column_embeddings(width_values) - y_emb = self.row_embeddings(height_values) - # (height, width, embedding_dim * 2) - pos = torch.cat([ - x_emb.unsqueeze(0).repeat(height, 1, 1), - y_emb.unsqueeze(1).repeat(1, width, 1) - ], - dim=-1) - # (embedding_dim * 2, height, width) - pos = pos.permute(2, 0, 1) - pos = pos.unsqueeze(0) - # (batch_size, embedding_dim * 2, height, width) - pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) - # (batch_size, height, width, embedding_dim * 2) - pos = pos.permute(0, 2, 3, 1) - return pos - - -class PositionalEmbeddingCosine1D(nn.Module): - """ - This class implements a very simple positional encoding. It follows closely - the encoder from the link below: - https://pytorch.org/tutorials/beginner/translation_transformer.html - Args: - embed_dim: The dimension of the embeddings. - dropout_prob: The dropout probability. - max_seq_len: The maximum length to precompute the positional encodings. - """ - - def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: - super().__init__() - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - # Generate the sinusoidal arrays. - factor = math.log(10000) - denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / - self.embed_dim) - # Matrix where rows correspond to a positional embedding as a function - # of the position index (i.e., the row index). - frequencies = \ - torch.arange(0, self.max_seq_len) \ - .reshape(self.max_seq_len, 1) * denominator - pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) - # Populate uneven entries. - pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) - pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) - # Save the positional embeddings in a constant buffer. - # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) - self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, - requires_grad=False) - - def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: - """ - Args: - seq_embeds: The sequence embeddings in order. Allowed size: - 1. [T, D], where T is the length of the sequence, and D is the - frame embedding dimension. - 2. [B, T, D], where B is the batch size and T and D are the - same as above. - Returns a tensor of with the same dimensions as the input: i.e., - [1, T, D] or [T, D]. 
- """ - shape_len = len(seq_embeds.shape) - assert 2 <= shape_len <= 3 - len_seq = seq_embeds.size(-2) - assert len_seq <= self.max_seq_len - pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] - # Adapt pre-computed positional embeddings to the input. - if shape_len == 3: - pos_embeds = pos_embeds.view( - (1, pos_embeds.size(0), pos_embeds.size(1))) - return pos_embeds - - -class MySequential(nn.Sequential): - - def forward(self, *inputs): - for module in self._modules.values(): - if isinstance(inputs, tuple): - inputs = module(*inputs) - else: - inputs = module(inputs) - return inputs - - -class PreNorm(nn.Module): - - def __init__(self, norm, fn): - super().__init__() - self.norm = norm - self.fn = fn - - def forward(self, x, *args, **kwargs): - shortcut = x - if self.norm is not None: - x, size = self.fn(self.norm(x), *args, **kwargs) - else: - x, size = self.fn(x, *args, **kwargs) - - x = shortcut + x - - return x, size - - -class Mlp(nn.Module): - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.net = nn.Sequential( - OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), - ("act", act_layer()), - ("fc2", nn.Linear(hidden_features, out_features))])) - - def forward(self, x, size): - return self.net(x), size - - -class DepthWiseConv2d(nn.Module): - - def __init__( - self, - dim_in, - kernel_size, - padding, - stride, - bias=True, - ): - super().__init__() - self.dw = nn.Conv2d(dim_in, - dim_in, - kernel_size=kernel_size, - padding=padding, - groups=dim_in, - stride=stride, - bias=bias) - - def forward(self, x, size): - B, N, C = x.shape - H, W = size - assert N == H * W - - x = self.dw(x.transpose(1, 2).view(B, C, H, W)) - size = (x.size(-2), x.size(-1)) - x = x.flatten(2).transpose(1, 2) - return x, size - - -class ConvEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, - patch_size=7, - in_chans=3, - embed_dim=64, - stride=4, - padding=2, - norm_layer=None, - pre_norm=True): - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d(in_chans, - embed_dim, - kernel_size=patch_size, - stride=stride, - padding=padding) - - dim_norm = in_chans if pre_norm else embed_dim - self.norm = norm_layer(dim_norm) if norm_layer else None - - self.pre_norm = pre_norm - - def forward(self, x, size): - H, W = size - if len(x.size()) == 3: - if self.norm and self.pre_norm: - x = self.norm(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) - - x = self.proj(x) - - _, _, H, W = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - if self.norm and not self.pre_norm: - x = self.norm(x) - - return x, (H, W) - - -class ChannelAttention(nn.Module): - - def __init__(self, dim, groups=8, qkv_bias=True): - super().__init__() - - self.groups = groups - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x, size): - B, N, C = x.shape - - qkv = self.qkv(x).reshape(B, N, 3, self.groups, - C // self.groups).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * (float(N)**-0.5) - attention = q.transpose(-1, -2) @ k - attention = attention.softmax(dim=-1) - x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - return x, size - - -class ChannelBlock(nn.Module): - - def __init__(self, - dim, - groups, - mlp_ratio=4., - qkv_bias=True, - 
drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.channel_attn = PreNorm( - norm_layer(dim), - ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: - x, size = self.conv1(x, size) - x, size = self.channel_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - - return x, size - - -def window_partition(x, window_size: int): - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) - windows = x.permute(0, 1, 3, 2, 4, - 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): - B = batch_size - - x = windows.view(B, H // window_size, W // window_size, window_size, - window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, num_heads, window_size, qkv_bias=True): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = float(head_dim)**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, size): - - H, W = size - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C) - - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - x = window_partition(x, self.window_size) - x = x.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # attn_windows = self.attn(x_windows) - - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = self.softmax(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - - # merge windows - x = x.view(-1, self.window_size, self.window_size, C) - x = window_reverse(x, B, self.window_size, Hp, Wp) - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - return x, size - - -class SpatialBlock(nn.Module): - - def __init__(self, - dim, - num_heads, - window_size, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.window_attn = PreNorm( - norm_layer(dim), - WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: 
- x, size = self.conv1(x, size) - x, size = self.window_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - return x, size - - -class DaViT(nn.Module): - - def __init__( - self, - in_chans=3, - num_classes=1000, - depths=(1, 1, 3, 1), - patch_size=(7, 2, 2, 2), - patch_stride=(4, 2, 2, 2), - patch_padding=(3, 0, 0, 0), - patch_prenorm=(False, False, False, False), - embed_dims=(64, 128, 192, 256), - num_heads=(3, 6, 12, 24), - num_groups=(3, 6, 12, 24), - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - enable_checkpoint=False, - conv_at_attn=True, - conv_at_ffn=True, - ): - super().__init__() - - self.num_classes = num_classes - self.embed_dims = embed_dims - self.num_heads = num_heads - self.num_groups = num_groups - self.num_stages = len(self.embed_dims) - self.enable_checkpoint = enable_checkpoint - assert self.num_stages == len(self.num_heads) == len(self.num_groups) - - num_stages = len(embed_dims) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, - sum(depths) * 2) - ] - - depth_offset = 0 - convs = [] - blocks = [] - for i in range(num_stages): - conv_embed = ConvEmbed( - patch_size=patch_size[i], - stride=patch_stride[i], - padding=patch_padding[i], - in_chans=in_chans if i == 0 else self.embed_dims[i - 1], - embed_dim=self.embed_dims[i], - norm_layer=norm_layer, - pre_norm=patch_prenorm[i]) - convs.append(conv_embed) - - block = MySequential(*[ - MySequential( - OrderedDict([('spatial_block', - SpatialBlock( - embed_dims[i], - num_heads[i], - window_size, - drop_path_rate=dpr[depth_offset + j * 2], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - )), - ('channel_block', - ChannelBlock( - embed_dims[i], - num_groups[i], - drop_path_rate=dpr[depth_offset + j * 2 + - 1], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - ))])) for j in range(depths[i]) - ]) - blocks.append(block) - depth_offset += depths[i] * 2 - - self.convs = nn.ModuleList(convs) - self.blocks = nn.ModuleList(blocks) - - self.avgpool = nn.AdaptiveAvgPool1d(1) - - @property - def dim_out(self): - return self.embed_dims[-1] - - def forward_features_unpool(self, x): - """ - forward until avg pooling - Args: - x (_type_): input image tensor - """ - input_size = (x.size(2), x.size(3)) - for conv, block in zip(self.convs, self.blocks): - x, input_size = conv(x, input_size) - x, input_size = block(x, input_size) - return x - - def forward_features(self, x): - x = self.forward_features_unpool(x) - - # (batch_size, num_tokens, token_dim) - x = self.avgpool(x.transpose(1, 2)) - # (batch_size, 1, num_tokens) - x = torch.flatten(x, 1) - x = self.norms(x) - - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - @classmethod - def from_config(cls, config): - return cls( - depths=config.depths, - embed_dims=config.dim_embed, - num_heads=config.num_heads, - num_groups=config.num_groups, - patch_size=config.patch_size, - patch_stride=config.patch_stride, - patch_padding=config.patch_padding, - patch_prenorm=config.patch_prenorm, - drop_path_rate=config.drop_path_rate, - window_size=config.window_size, - ) - - -# Language backbone and processor implementation -class Florence2LanguageModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = 
vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.config = config - - self.vocab_size = config.vocab_size - - self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) - self.encoder = BartEncoder(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = BartDecoder(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - if self.config.tie_word_embeddings: - self.encoder.embed_tokens.weight = self.shared.weight - self.decoder.embed_tokens.weight = self.shared.weight - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. - Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if ((inputs_embeds is not None and inputs_embeds.numel() > 0) - or encoder_input_ids.numel() > 0): - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - inputs_embeds=inputs_embeds) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - -class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - - self.config = config - self.model = Florence2LanguageModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.vocab_size = config.vocab_size - self.lm_head = BartParallelLMHead(self.vocab_size, - config.d_model, - embed_scale=embed_scale) - if self.config.tie_word_embeddings: - self.lm_head.tie_weights(self.model.shared) - - self.logits_processor = LogitsProcessor(self.vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. 
- encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - - return self.model(input_ids, - positions, - encoder_input_ids, - encoder_positions, - inputs_embeds=inputs_embeds) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.encoder.embed_tokens(input_ids) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "final_logits_bias" in name: - continue - if self.config.tie_word_embeddings and ("embed_tokens" in name - or "lm_head" in name): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Florence2ProcessingInfo(BaseProcessingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_num_image_tokens(self) -> int: - processor_config = self.ctx.get_hf_image_processor_config() - return processor_config["image_seq_length"] - - -class Florence2DummyInputsBuilder( - BaseDummyInputsBuilder[Florence2ProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width = target_height = self.info.get_hf_config().projection_dim - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - -class Florence2MultiModalProcessor( - EncDecMultiModalProcessor[Florence2ProcessingInfo]): - - def _hf_processor_applies_updates( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> bool: - return False - - def create_encoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return prompt - - def create_decoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return [self.info.get_hf_config().eos_token_id] - - def _apply_hf_processor_tokens_only( - self, - prompt_tokens: list[int], - ) -> list[int]: - hf_processor = self.info.get_hf_processor() - tokenizer: BartTokenizer = hf_processor.tokenizer - prompt_text = tokenizer.decode(prompt_tokens) - # convert task tokens to prompt - prompt_text = hf_processor._construct_prompts([prompt_text])[0] - prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False) - return prompt_tokens - - def _call_hf_processor( - self, - prompt: str, - mm_data: 
Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - if mm_data: - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) - else: - hf_processor = self.info.get_hf_processor() - tokenizer = hf_processor.tokenizer - prompt = hf_processor._construct_prompts([prompt])[0] - processed_outputs = tokenizer(prompt, - add_special_tokens=True, - return_tensors="pt") - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image")) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - pad_token_id = hf_config.pad_token_id - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [pad_token_id] * num_image_tokens - - return [ - PromptInsertion( - modality="image", - target=PromptIndexTargets.start(), - insertion=image_tokens, - ) - ] - - -@MULTIMODAL_REGISTRY.register_processor( - Florence2MultiModalProcessor, - info=Florence2ProcessingInfo, - dummy_inputs=Florence2DummyInputsBuilder) -class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsV0Only): - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - processor_config = vllm_config.model_config.hf_image_processor_config - - self.config = config - self.vision_config = config.vision_config - self.processor_config = processor_config - assert config.vision_config.model_type == 'davit', ( - 'only DaViT is supported for now') - self.vision_tower = DaViT.from_config(config=config.vision_config) - self._build_image_projection_layers(config) - self.language_model = Florence2LanguageForConditionalGeneration( - vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=f"{prefix}.language_model", - ) - self.pad_token_id = config.pad_token_id - - def _build_image_projection_layers(self, config: PretrainedConfig): - image_dim_out = config.vision_config.dim_embed[-1] - dim_projection = config.vision_config.projection_dim - self.image_projection = nn.Parameter( - torch.empty(image_dim_out, dim_projection)) - self.image_proj_norm = nn.LayerNorm(dim_projection) - image_pos_embed_config = config.vision_config.image_pos_embed - if image_pos_embed_config['type'] == 'learned_abs_2d': - self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, - num_pos=image_pos_embed_config['max_pos_embeddings']) - else: - raise NotImplementedError("Florence2 only supports learned_abs_2d " - "as image position embedding.") - - self.image_feature_source = config.vision_config.image_feature_source - - # temporal embedding - visual_temporal_embedding_config = ( - self.vision_config.visual_temporal_embedding) - if visual_temporal_embedding_config['type'] == 'COSINE': - self.visual_temporal_embed = PositionalEmbeddingCosine1D( - embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config[ - 'max_temporal_embeddings']) - else: - raise NotImplementedError( - 'Florence2 only 
supports COSINE as temporal embedding.') - - def _parse_and_validate_image_input(self, **kwargs: object): - pixel_values: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "pixel_values", None) - image_embeds: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "image_embeds", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - - if pixel_values is not None: - size = self.processor_config["size"] - expected_h, expected_w = size["height"], size["width"] - - return Florence2ImagePixelInputs( - type="pixel_values", - data=flatten_bn(pixel_values, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, - ) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: - dtype = next(self.vision_tower.parameters()).dtype - pixel_values = pixel_values.to(dtype) - - batch_size, T = pixel_values.size(0), 1 - x = self.vision_tower.forward_features_unpool(pixel_values) - if self.image_pos_embed is not None: - x = x.view(batch_size * T, -1, x.shape[-1]) - num_tokens = x.shape[-2] - h, w = int(num_tokens**0.5), int(num_tokens**0.5) - assert h * w == num_tokens, ( - 'only support square feature maps for now') - x = x.view(batch_size * T, h, w, x.shape[-1]) - pos_embed = self.image_pos_embed(x) - x = x + pos_embed - x = x.view(batch_size, T * h * w, x.shape[-1]) - - if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) - x = x.view(batch_size, T, -1, - x.shape[-1]) + visual_temporal_embed.view( - 1, T, 1, x.shape[-1]) - - x_feat_dict = {} - - spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) - x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x - - temporal_avg_pool_x = x.view(batch_size, T, -1, - x.shape[-1]).mean(dim=1) - x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x - - x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] - x_feat_dict['last_frame'] = x - - new_x = [] - for _image_feature_source in self.image_feature_source: - if _image_feature_source not in x_feat_dict: - raise ValueError('invalid image feature source: {}'.format( - _image_feature_source)) - new_x.append(x_feat_dict[_image_feature_source]) - - x = torch.cat(new_x, dim=1) - - x = x @ self.image_projection - x = self.image_proj_norm(x) - - return x - - def _process_image_input( - self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: - assert image_input["type"] == "pixel_values" - pixel_values = image_input["data"] - return self._encode_image(pixel_values) - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return [] - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - 
inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.pad_token_id) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - if encoder_input_ids.numel() > 0 or vision_embeddings is not None: - inputs_embeds = self.get_input_embeddings(encoder_input_ids, - vision_embeddings) - else: - inputs_embeds = None - - hidden_states = self.language_model(input_ids, - positions, - encoder_input_ids, - encoder_positions, - inputs_embeds=inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 90af859ab92ec..83572563c15ef 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,35 +16,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch Fuyu model.""" +"""PyTorch Fuyu model.""" + import math from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal, Optional import torch import torch.nn as nn -from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, - FuyuProcessor) +from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.persimmon import PersimmonForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 71011 @@ -61,22 +66,18 @@ class FuyuImagePatchInputs(TensorSchema): type: Literal["image_patches"] = "image_patches" - flat_data: Annotated[ - torch.Tensor, - TensorShape("bnp", "fn"), - ] + image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")] patches_per_image: Annotated[list[int], TensorShape("bn")] """ The number of total patches for each image in the batch. This is used to split the embeddings which has the first two dimensions - flattened just like `flat_data`. + flattened just like `image_patches_flat`. 
""" class FuyuProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(FuyuConfig) @@ -128,12 +129,12 @@ class FuyuProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size["width"], - height=image_processor.size["height"]) + return ImageSize( + width=image_processor.size["width"], height=image_processor.size["height"] + ) class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -141,21 +142,24 @@ class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -176,28 +180,11 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): tok_kwargs=tok_kwargs, ) - image_patches = processed_outputs.get("image_patches") - if image_patches is not None: - images = mm_data["images"] - assert isinstance(images, list) - - # Original output: (1, num_images, Pn, Px * Py * C) - # New output: (num_images, Pn, Px * Py * C) - # image_patches is a list with shape: - # (1, num_images, Pn, Px * Py * C) - # before Transformers 4.53 - if isinstance(image_patches, list): - assert len(image_patches) == 1 - assert (isinstance(image_patches[0], torch.Tensor) - and len(image_patches[0]) == len(images)) - processed_outputs["image_patches"] = image_patches[0] - # image_patches is a tensor with shape: - # (num_images, Pn, Px * Py * C) - # after Transformers 4.53 - elif isinstance(image_patches, torch.Tensor): - assert len(image_patches) == len(images) - else: - raise AssertionError("This line should be unreachable.") + image_patches = processed_outputs["image_patches"] + processed_outputs["image_patches"] = flatten_bn(image_patches) + processed_outputs["patches_per_image"] = torch.tensor( + [len(p) for p in image_patches] + ) return processed_outputs @@ -220,7 +207,14 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict(image_patches=MultiModalFieldConfig.batched("image")) + patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) + + return dict( + image_patches=MultiModalFieldConfig.flat_from_sizes( + "image", patches_per_image + ), + patches_per_image=MultiModalFieldConfig.batched("image"), + ) def _get_prompt_updates( self, @@ -244,8 +238,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): image_width=image_size.width, image_height=image_size.height, ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows + image_tokens = ([_IMAGE_TOKEN_ID] * 
ncols + [_NEWLINE_TOKEN_ID]) * nrows return PromptUpdateDetails.select_token_id( image_tokens + [bos_token_id], @@ -261,17 +254,21 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, - info=FuyuProcessingInfo, - dummy_inputs=FuyuDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_inputs=FuyuDummyInputsBuilder, +) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.": "vision_embed_tokens.", "model.language_model.": "language_model.model.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -303,62 +300,46 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): prefix=maybe_prefix(prefix, "language_model"), ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: + self, **kwargs: object + ) -> Optional[FuyuImagePatchInputs]: image_patches = kwargs.pop("image_patches", None) - if image_patches is not None: - image_patches_flat = flatten_bn(image_patches) - flat_data = flatten_bn(image_patches_flat, concat=True) + patches_per_image = kwargs.pop("patches_per_image", None) - return FuyuImagePatchInputs( - type="image_patches", - flat_data=flat_data, - patches_per_image=[x.size(0) for x in image_patches_flat], - resolve_bindings={"fn": self.image_feature_size}, - ) + if image_patches is None: + return None - return None + return FuyuImagePatchInputs( + type="image_patches", + image_patches_flat=image_patches, + patches_per_image=patches_per_image, + resolve_bindings={"fn": self.image_feature_size}, + ) def _process_image_input( - self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings: - image_patches_flat = image_input["flat_data"] + self, image_input: FuyuImagePatchInputs + ) -> MultiModalEmbeddings: + image_patches_flat = image_input["image_patches_flat"] patches_per_image = image_input["patches_per_image"] assert self.vision_embed_tokens is not None - vision_embeddings_flat, _ = self.vision_embed_tokens( - image_patches_flat) + vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat) - return vision_embeddings_flat.split(patches_per_image, dim=0) + return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - _IMAGE_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -370,14 
+351,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - hidden_states = self.language_model( input_ids=input_ids, positions=positions, @@ -389,13 +362,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.language_model.logits_processor( - self.language_model.lm_head, hidden_states, sampling_metadata) + self.language_model.lm_head, hidden_states + ) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 59c3102add4c7..b152f52223cf6 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -16,8 +16,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" + from collections.abc import Iterable from functools import cache +from itertools import islice from typing import Optional, Union import torch @@ -31,22 +33,26 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -66,19 +72,22 @@ def _get_gemma_act_fn( "`%s`, edit the config JSON to set " "`hidden_activation=%s` instead of `hidden_act`. 
" "See https://github.com/huggingface/transformers/pull/29402 " - "for more details.", hidden_act, hidden_act) + "for more details.", + hidden_act, + hidden_act, + ) return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu_pytorch_tanh": return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu": return GeluAndMul(approximate="none") else: - raise ValueError(f"Activation function {hidden_act} is not " - "supported for Gemma models.") + raise ValueError( + f"Activation function {hidden_act} is not supported for Gemma models." + ) class GemmaMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -113,7 +122,6 @@ class GemmaMLP(nn.Module): class GemmaAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -172,13 +180,15 @@ class GemmaAttention(nn.Module): base=self.rope_theta, is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -194,7 +204,6 @@ class GemmaAttention(nn.Module): class GemmaDecoderLayer(nn.Module): - def __init__( self, config: GemmaConfig, @@ -223,10 +232,10 @@ class GemmaDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -239,23 +248,20 @@ class GemmaDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class GemmaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -272,8 +278,10 @@ class GemmaModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: GemmaDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -281,12 +289,10 @@ class GemmaModel(nn.Module): # data type such as bfloat16, not float32. 
# See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -308,22 +314,20 @@ class GemmaModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -335,7 +339,7 @@ class GemmaModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -355,8 +359,7 @@ class GemmaModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -388,11 +391,13 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = GemmaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GemmaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -404,24 +409,21 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, 
hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 8cfe92c64540f..2d26edcf6609f 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -30,30 +31,35 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Gemma2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -64,18 +70,17 @@ class Gemma2MLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): raise ValueError( "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) self.act_fn = GeluAndMul(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -86,19 +91,20 @@ class Gemma2MLP(nn.Module): class Gemma2Attention(nn.Module): - - def __init__(self, - config: Gemma2Config, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - rope_theta: float, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -148,15 +154,17 @@ class Gemma2Attention(nn.Module): is_sliding = config.layer_types[layer_idx] == "sliding_attention" sliding_window = config.sliding_window if is_sliding else None - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -172,7 +180,6 @@ class Gemma2Attention(nn.Module): class Gemma2DecoderLayer(nn.Module): - def __init__( self, config: Gemma2Config, @@ -203,14 +210,16 @@ class Gemma2DecoderLayer(nn.Module): hidden_activation=config.hidden_activation, quant_config=quant_config, ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -222,8 +231,7 @@ class Gemma2DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -231,7 +239,8 @@ class Gemma2DecoderLayer(nn.Module): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) return hidden_states, residual @@ -239,7 +248,6 @@ class 
Gemma2DecoderLayer(nn.Module): @support_torch_compile class Gemma2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -255,8 +263,10 @@ class Gemma2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Gemma2DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -264,12 +274,10 @@ class Gemma2Model(nn.Module): # data type such as bfloat16, not float32. # See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -292,22 +300,20 @@ class Gemma2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -319,17 +325,17 @@ class Gemma2Model(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -353,8 +359,7 @@ class Gemma2Model(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, 
"weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -384,12 +389,15 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.quant_config = quant_config - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -401,24 +409,21 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index b762be3c52925..7e6fc401757aa 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -23,37 +24,43 @@ import torch.nn.functional as F from torch import nn from transformers import Gemma3TextConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors +from ...attention.layers.encoder_only_attention import EncoderOnlyAttention from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Gemma3MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -81,7 +88,8 @@ class Gemma3MLP(nn.Module): raise ValueError( "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." + ) self.act_fn = GeluAndMul(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -92,18 +100,19 @@ class Gemma3MLP(nn.Module): class Gemma3Attention(nn.Module): - - def __init__(self, - config: Gemma3TextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -168,16 +177,29 @@ class Gemma3Attention(nn.Module): rope_scaling=self.rope_scaling, ) - # Initialize the attention. 
- self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn") + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) + + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -216,11 +238,7 @@ class Gemma3Attention(nn.Module): # output is discarded and overwritten below. While this duplicates # computation, it maintains compatibility. # TODO(woosuk): Optimize by implementing custom attention kernels. - attn_output = self.naive_attn_with_masks(q, - k, - v, - out=attn_output, - **kwargs) + attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs) output, _ = self.o_proj(attn_output) return output @@ -274,7 +292,6 @@ class Gemma3Attention(nn.Module): class Gemma3DecoderLayer(nn.Module): - def __init__( self, config: Gemma3TextConfig, @@ -304,14 +321,16 @@ class Gemma3DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -324,8 +343,7 @@ class Gemma3DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -334,7 +352,8 @@ class Gemma3DecoderLayer(nn.Module): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) return hidden_states, residual @@ -342,7 +361,6 @@ class Gemma3DecoderLayer(nn.Module): @support_torch_compile class Gemma3Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -354,13 +372,16 @@ class Gemma3Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( 
config.num_hidden_layers, lambda prefix: Gemma3DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -368,12 +389,10 @@ class Gemma3Model(nn.Module): # data type such as bfloat16, not float32. # See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # NOTE(woosuk): Only apply the normalizer to the output of @@ -398,7 +417,7 @@ class Gemma3Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -406,15 +425,13 @@ class Gemma3Model(nn.Module): **kwargs, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -426,17 +443,42 @@ class Gemma3Model(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + # Revert +1 during llama.cpp conversion + # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400 + if ( + self.quant_config + and self.quant_config.get_name() == "gguf" + and name.endswith("norm.weight") + ): + loaded_weight -= 1 + + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, shard_name, shard_id) in stacked_params_mapping: + + # Check if this is a scale parameter that needs remapping first + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + # Try to remap the scale name first + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is not None and remapped_name in params_dict: + # Successfully remapped, use the 
remapped name + param = params_dict[remapped_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(remapped_name) + continue + # If remapping failed, continue with normal processing + + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -460,8 +502,7 @@ class Gemma3Model(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -491,12 +532,15 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.quant_config = quant_config - self.model = Gemma3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -509,24 +553,21 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index bf5ad633b94a5..95b0b0dab5a1e 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -11,34 +11,45 @@ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs import vllm.envs as envs from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from 
vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, - PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, - replace_token_matches) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + replace_token_matches, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) logger = init_logger(__name__) @@ -53,6 +64,7 @@ class Gemma3ImagePixelInputs(TensorSchema): - w: Width of each patch - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")] @@ -64,7 +76,6 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs class Gemma3ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Gemma3Config) @@ -107,19 +118,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( - processor, { - "do_pan_and_scan", "pan_and_scan_min_crop_size", + processor, + { + "do_pan_and_scan", + "pan_and_scan_min_crop_size", "pan_and_scan_max_num_crops", - "pan_and_scan_min_ratio_to_activate" - }) + "pan_and_scan_min_ratio_to_activate", + }, + ) do_pan_and_scan = images_kwargs["do_pan_and_scan"] - pan_and_scan_min_crop_size = images_kwargs[ - "pan_and_scan_min_crop_size"] - pan_and_scan_max_num_crops = images_kwargs[ - "pan_and_scan_max_num_crops"] + pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"] + pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] pan_and_scan_min_ratio_to_activate = images_kwargs[ - "pan_and_scan_min_ratio_to_activate"] + "pan_and_scan_min_ratio_to_activate" + ] if not do_pan_and_scan: return 0 @@ -127,7 +140,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): if envs.VLLM_USE_V1: logger.warning_once( "`do_pan_and_scan=True` has suboptimal results on V1 " - "because of the simplified attention pattern being used.") + "because of the simplified attention pattern being used." 
+ ) # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: @@ -187,10 +201,10 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) image_text = ( f"Here is the original image {boi_token} and here are some " - f"crops to help you see better {crops_image_tokens}") + f"crops to help you see better {crops_image_tokens}" + ) - repl_full = image_text.replace(boi_token, - processor.full_image_sequence) + repl_full = image_text.replace(boi_token, processor.full_image_sequence) tokenizer = processor.tokenizer vocab = tokenizer.get_vocab() @@ -221,7 +235,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( - processor, {"pan_and_scan_max_num_crops"}) + processor, {"pan_and_scan_max_num_crops"} + ) max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] # Result in the max possible feature size (h:w = max_num_crops:1) @@ -229,7 +244,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -242,22 +256,25 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -274,23 +291,25 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ - parsed_images.get_image_size(i) - for i in range(len(parsed_images)) + parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] hf_processor = self.info.get_hf_processor(**mm_kwargs) num_crops = [ - self.info.get_num_crops(image_width=size.width, - image_height=size.height, - processor=hf_processor) + self.info.get_num_crops( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) for size in image_sizes ] - processed_outputs["num_crops"] = torch.tensor(num_crops) + processed_outputs["num_patches"] = torch.tensor(num_crops) + 1 return processed_outputs @@ -299,12 +318,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_crops = hf_inputs.get("num_crops", torch.empty(0)) + num_patches = hf_inputs.get("num_patches", torch.empty(0)) 
return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops + 1), - num_crops=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -337,14 +355,9 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -373,13 +386,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -404,8 +416,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) return { modality: [ @@ -415,39 +426,43 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, is_embed=p.is_embed, - ) for p in placeholders + ) + for p in placeholders ] for modality, placeholders in repls.items() } class Gemma3MultiModalProjector(nn.Module): - def __init__(self, config: Gemma3Config): super().__init__() self.mm_input_projection_weight = nn.Parameter( - torch.zeros(config.vision_config.hidden_size, - config.text_config.hidden_size)) + torch.zeros( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + ) self.mm_soft_emb_norm = GemmaRMSNorm( - config.vision_config.hidden_size, - eps=config.vision_config.layer_norm_eps) + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) - self.patches_per_image = int(config.vision_config.image_size // - config.vision_config.patch_size) + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) self.tokens_per_side = int(config.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, - stride=self.kernel_size) + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) def forward(self, vision_outputs: torch.Tensor): batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) reshaped_vision_outputs = reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, - self.patches_per_image) + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) reshaped_vision_outputs = 
reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -457,15 +472,21 @@ class Gemma3MultiModalProjector(nn.Module): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) projected_vision_outputs = torch.matmul( - normed_vision_outputs, self.mm_input_projection_weight) + normed_vision_outputs, self.mm_input_projection_weight + ) return projected_vision_outputs.type_as(vision_outputs) -@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, - info=Gemma3ProcessingInfo, - dummy_inputs=Gemma3DummyInputsBuilder) -class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): +@MULTIMODAL_REGISTRY.register_processor( + Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder, +) +class Gemma3ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -485,7 +506,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -503,10 +525,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, self.quant_config = quant_config self.multimodal_config = multimodal_config - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = Gemma3MultiModalProjector(config) self.language_model = init_vllm_registered_model( @@ -516,41 +539,37 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, architectures=["Gemma3ForCausalLM"], ) logit_scale = getattr(config, "logit_scale", 1.0) - self.language_model.logits_processor.scale *= logit_scale + + if hasattr(self.language_model, "logits_processor"): + # The logits processor can be unset if we're using + # automatic conversion to pooling model. + self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) @property def dtype(self): return next(self.parameters()).dtype def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + self, **kwargs: object + ) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - num_crops = kwargs.pop("num_crops", None) + num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) assert image_embeds is None, "Gemma3 does not support image_embeds." if pixel_values is None: return None - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_crops. 
" - f"Got type: {type(num_crops)}") - image_size = self.config.vision_config.image_size return Gemma3ImagePixelInputs( - pixel_values=flatten_bn(pixel_values, concat=True), - num_patches=flatten_bn(num_crops, concat=True) + 1, - resolve_bindings={ - "h": image_size, - "w": image_size - }) + pixel_values=pixel_values, + num_patches=num_patches, + resolve_bindings={"h": image_size, "w": image_size}, + ) def _image_pixels_to_features( self, @@ -574,67 +593,36 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ) image_embeds = self.multi_modal_projector(image_features) - return [ - e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())] def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( + def forward( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - if (vision_embeddings is not None) and len(vision_embeddings) != 0: - kwargs = self.prepare_attn_masks( - input_ids, - positions, - mask_dtype=self.dtype, - **kwargs, - ) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - **kwargs) + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) return hidden_states @@ -682,19 +670,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, # Consider the bidirectional attention between image tokens. 
img_mask = torch.zeros_like(global_attn_mask) - img_pos = (input_token_ids == self.config.image_token_index) + img_pos = input_token_ids == self.config.image_token_index img_mask[:, :, :, img_pos] += 1 img_mask[:, :, img_pos, :] += 1 global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) - if (sliding_window := self.config.sliding_window) is not None: + sliding_window = self.config.text_config.sliding_window + if sliding_window is not None: # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, - diagonal=-sliding_window) - local_attn_mask = torch.where(local_attn_mask == 0, - global_attn_mask, float("-inf")) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window) + local_attn_mask = torch.where( + local_attn_mask == 0, global_attn_mask, float("-inf") + ) local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks @@ -703,13 +692,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -720,4 +706,5 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index ffec3408702c9..e4ea4256ebc23 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -26,32 +26,45 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, - GeluAndMul, - GeluAndMulSparse) +from vllm.model_executor.layers.activation import ( + _ACTIVATION_REGISTRY, + GeluAndMul, + GeluAndMulSparse, +) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from 
vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import SupportsQuant -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, make_layers, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) +EPS = torch.tensor(torch.finfo().min) + class Gemma3nAltUp(nn.Module): """Alternating updates (Altup) @@ -107,9 +120,11 @@ class Gemma3nAltUp(nn.Module): eps=rms_norm_eps, ) self.router_input_scale = torch.tensor( - hidden_size**-1.0, dtype=self.modality_router.weight.dtype) + hidden_size**-1.0, dtype=self.modality_router.weight.dtype + ) self.correct_output_scale = nn.Parameter( - torch.zeros(hidden_size, dtype=torch.float32)) + torch.zeros(hidden_size, dtype=torch.float32) + ) def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: router_inputs = self.router_norm(x) * self.router_input_scale @@ -117,15 +132,17 @@ class Gemma3nAltUp(nn.Module): return torch.tanh(routed.float()).type_as(x) def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: - return (corrected.type_as(self.correct_output_scale) * - self.correct_output_scale).type_as(corrected) + return ( + corrected.type_as(self.correct_output_scale) * self.correct_output_scale + ).type_as(corrected) def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: # hidden: [altup_num_inputs, num_tokens, hidden_size] # modalities: [num_tokens, num_altup_inputs] # all_coefs: [num_tokens, num_altup_inputs ** 2] modalities = self._compute_router_modalities( - hidden_states[self.altup_active_idx]) + hidden_states[self.altup_active_idx] + ) all_coefs = self.prediction_coefs(modalities) # Reshape and transpose the 2D matrix for the matmul. @@ -143,8 +160,9 @@ class Gemma3nAltUp(nn.Module): predictions += hidden_states return predictions.contiguous() - def correct(self, predictions: torch.Tensor, - activated: torch.Tensor) -> torch.Tensor: + def correct( + self, predictions: torch.Tensor, activated: torch.Tensor + ) -> torch.Tensor: # predictions: [altup_num_inputs, num_tokens, hidden_size] # activated: [num_tokens, hidden_size] # modalities: [num_tokens, altup_num_inputs] @@ -212,7 +230,6 @@ class Gemma3nLaurelBlock(nn.Module): class Gemma3nMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -241,12 +258,16 @@ class Gemma3nMLP(nn.Module): raise ValueError( "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) - self.act_fn = GeluAndMulSparse( - activation_sparsity=activation_sparsity, - approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul( - approximate="tanh") + self.act_fn = ( + GeluAndMulSparse( + activation_sparsity=activation_sparsity, approximate="tanh" + ) + if activation_sparsity > 0.0 + else GeluAndMul(approximate="tanh") + ) def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) @@ -256,17 +277,18 @@ class Gemma3nMLP(nn.Module): class Gemma3nAttention(nn.Module): - - def __init__(self, - config: Gemma3nTextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma3nTextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -304,13 +326,11 @@ class Gemma3nAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.o_proj", ) - self.q_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps) - self.k_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps) - self.v_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps, - has_weight=False) + self.q_norm = RMSNorm(hidden_size=self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(hidden_size=self.head_dim, eps=config.rms_norm_eps) + self.v_norm = RMSNorm( + hidden_size=self.head_dim, eps=config.rms_norm_eps, has_weight=False + ) layer_idx = extract_layer_index(prefix) is_sliding = config.layer_types[layer_idx] == "sliding_attention" @@ -326,8 +346,9 @@ class Gemma3nAttention(nn.Module): rope_theta = config.rope_theta rope_scaling = config.rope_scaling - first_kv_shared_layer_idx = (config.num_hidden_layers - - config.num_kv_shared_layers) + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx kv_sharing_target_layer_name = None @@ -358,7 +379,8 @@ class Gemma3nAttention(nn.Module): quant_config=quant_config, per_layer_sliding_window=self.sliding_window, kv_sharing_target_layer_name=kv_sharing_target_layer_name, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + ) def forward( self, @@ -387,7 +409,6 @@ class Gemma3nAttention(nn.Module): class Gemma3nDecoderLayer(nn.Module): - def __init__( self, config: Gemma3nTextConfig, @@ -423,12 +444,12 @@ class Gemma3nDecoderLayer(nn.Module): self.mlp = Gemma3nMLP( hidden_size=config.hidden_size, # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501 - intermediate_size=config.intermediate_size[extract_layer_index( - prefix)], + intermediate_size=config.intermediate_size[extract_layer_index(prefix)], hidden_activation=config.hidden_activation, quant_config=quant_config, activation_sparsity=config.activation_sparsity_pattern[ - extract_layer_index(prefix)], + extract_layer_index(prefix) + ], prefix=f"{prefix}.mlp", ) self.laurel = Gemma3nLaurelBlock( @@ -490,7 +511,6 @@ class Gemma3nDecoderLayer(nn.Module): per_layer_input: torch.Tensor, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: - # ActUp 
(predict). predictions = self.altup.predict(hidden_states) active_prediction = predictions[self.altup_active_idx] @@ -505,8 +525,7 @@ class Gemma3nDecoderLayer(nn.Module): ) attn = self.post_attention_layernorm(attn) attn_gated = attn + active_prediction - attn_laurel = (attn_gated + laurel_output) / torch.sqrt( - torch.tensor(2.0)) + attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0)) # MLP. attn_norm = self.pre_feedforward_layernorm(attn_laurel) @@ -515,8 +534,7 @@ class Gemma3nDecoderLayer(nn.Module): attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm # ActUp (connect). - corrected_predictions = self.altup.correct(predictions, - attn_ffw_laurel_gated) + corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) first_prediction = corrected_predictions[self.altup_active_idx] first_prediction = self.altup.scale_corrected_output(first_prediction) @@ -533,16 +551,30 @@ class Gemma3nDecoderLayer(nn.Module): return corrected_predictions -@support_torch_compile -class Gemma3nTextModel(nn.Module, SupportsQuant): +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma3nSelfDecoder(nn.Module): + """ + Includes altup embedding and self decoder layers + """ - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.config = config - self.quant_config = quant_config + quant_config = vllm_config.quant_config self.embed_tokens = VocabParallelEmbedding( config.vocab_size, @@ -579,106 +611,147 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): eps=config.rms_norm_eps, ) self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to( - self.embed_tokens.weight.dtype) + self.embed_tokens.weight.dtype + ) self.per_layer_projection_scale = torch.tensor( config.hidden_size**0.5, dtype=self.embed_tokens.weight.dtype, ) - self.altup_projections = nn.ModuleList([ - ColumnParallelLinear( - config.hidden_size, - config.hidden_size, - bias=False, - gather_output=True, - return_bias=False, - quant_config=quant_config, - prefix=f"{prefix}.altup_projections.{idx-1}", - ) for idx in range(1, self.config.altup_num_inputs) - ]) - self.altup_unembed_projections = nn.ModuleList([ - ColumnParallelLinear( - config.hidden_size, - config.hidden_size, - bias=False, - gather_output=True, - return_bias=False, - quant_config=quant_config, - prefix=f"{prefix}.altup_unembed_projections.{idx-1}", - ) for idx in range(1, self.config.altup_num_inputs) - ]) - - # Transformer blocks. 
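The self/cross decoder split introduced here hinges on `support_torch_compile` accepting an `enable_if` predicate, so each sub-module is compiled only when `--kv-sharing-fast-prefill` is in effect. A minimal sketch of the same gating pattern, assuming only what the decorator usage above shows (the `ToyDecoder` module itself is illustrative, not part of this change):

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


# Compiled only when the fast-prefill KV-sharing option is enabled; otherwise
# the class behaves like a plain eager nn.Module.
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class ToyDecoder(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)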
- self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: Gemma3nDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, + self.altup_projections = nn.ModuleList( + [ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.altup_projections.{idx - 1}", + ) + for idx in range(1, self.config.altup_num_inputs) + ] ) - self.eps = torch.tensor(torch.finfo().min) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.embed_scale - - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: + def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # Deal with the fact that vocab_size_per_layer_input < vocab_size # which causes us to have some out of vocab tokens by setting # those token ids to 0. This matches the HF implementation. per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) - per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, - torch.zeros_like(input_ids)) - return self.embed_tokens_per_layer( - per_layer_inputs_tokens) * self.embed_scale_per_layer + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + return ( + self.embed_tokens_per_layer(per_layer_inputs_tokens) + * self.embed_scale_per_layer + ) - def forward( + def get_per_layer_inputs( self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - per_layer_inputs: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is not None: - hidden_states_0 = inputs_embeds - else: - hidden_states_0 = self.get_input_embeddings(input_ids) - + hidden_states_0: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = per_layer_projection.reshape( *hidden_states_0.shape[:-1], self.config.num_hidden_layers, self.config.hidden_size_per_layer_input, ) - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection) - + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) if per_layer_inputs is not None: # Profiling run does not compute per_layer_inputs per_layer_inputs = per_layer_projection + per_layer_inputs per_layer_inputs *= self.per_layer_input_scale else: per_layer_inputs = per_layer_projection + return per_layer_inputs + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: # Altup embed. 
hidden_states = [hidden_states_0] * self.config.altup_num_inputs - target_magnitude = torch.mean(hidden_states_0**2, dim=-1, - keepdim=True)**0.5 + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 for i in range(1, self.config.altup_num_inputs): hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - hidden_states = torch.stack(hidden_states, dim=0) + new_magnitude = ( + torch.mean(hidden_states[i] ** 2, dim=-1, keepdim=True) ** 0.5 + ) + hidden_states[i] *= target_magnitude / torch.maximum(new_magnitude, EPS) + hidden_states = torch.stack(hidden_states, dim=-1) + return hidden_states - # Transformer blocks. - for layer_idx, layer in enumerate(self.layers): + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + adjusted_per_layer_inputs = self.get_per_layer_inputs( + hidden_states_0, per_layer_inputs + ) + hidden_states = self.altup_embed(hidden_states_0) + + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + + return hidden_states, adjusted_per_layer_inputs + + +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma3nCrossDecoder(nn.Module): + """ + Cross-decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start # [altup_num_inputs, num_tokens, hidden_size] hidden_states = layer( positions=positions, @@ -686,26 +759,264 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): per_layer_input=per_layer_inputs[:, layer_idx, :], **kwargs, ) + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + return hidden_states + +# This disables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma3nTextModel(nn.Module, SupportsQuant): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = 
vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.altup_unembed_projections = nn.ModuleList( + [ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.altup_unembed_projections.{idx - 1}", + ) + for idx in range(1, self.config.altup_num_inputs) + ] + ) + + # Allocate config.num_kv_shared_layers layers for self-decoder + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3nDecoderLayer( + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) + + # NOTE(sarckk): importing this top level seems to cause issues + # during running of tests. + from vllm.compilation.backends import set_model_tag + + # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) + with set_model_tag("self_decoder"): + self.self_decoder = Gemma3nSelfDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + ) + # Layer idx 20-30 are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma3nCrossDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + # TODO(sarckk): Extract this functionality to interface + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros( + max_num_tokens, dtype=torch.int64, device=device + ) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size, self.config.altup_num_inputs), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + self.per_layer_inputs = torch.zeros( + ( + max_num_tokens, + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + + @property + def embed_tokens(self): + return self.self_decoder.embed_tokens + + def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_per_layer_input_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_input_embeddings(input_ids) + + def fast_prefill_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata + + # attn_metadata is None during dummy runs + if self.fast_prefill_enabled and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + # Last layer is a KV sharing layer + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name + ] + if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata): + logits_indices_padded = layer_attn_metadata.logits_indices_padded + num_logits_indices = 
layer_attn_metadata.num_logits_indices + + # Copy inputs for cudagraph + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + self_decoder_hidden_states, per_layer_inputs_adjusted = self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + positions.size(0), + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE(sarckk): There is currently a bug caused by + # vLLM converting output of last piecewise CUDA graph + # to weakref, causing memory to be prematurely freed + # when there are multiple compilation units + # Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + # Copy inputs for cudagraph + num_padded_logits_indices = logits_indices_padded.size(0) + self.positions[:num_padded_logits_indices].copy_( + positions[logits_indices_padded] + ) + self.hidden_states[:num_padded_logits_indices].copy_( + self_decoder_hidden_states[logits_indices_padded] + ) + self.per_layer_inputs[:num_padded_logits_indices].copy_( + per_layer_inputs_adjusted[logits_indices_padded] + ) + cross_decoder_hidden_states = self.cross_decoder( + positions=self.positions[:num_padded_logits_indices], + hidden_states=self.hidden_states[:num_padded_logits_indices], + per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices], + **kwargs, + ) + + if num_logits_indices is not None: + assert num_logits_indices > 0 + # Merge cross-decoder and self-decoder hidden states + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_decoder_hidden_states[:num_logits_indices] + ) + else: + hidden_states = cross_decoder_hidden_states + + return hidden_states + + def normal_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + hidden_states = self.cross_decoder( + positions=positions, + hidden_states=hidden_states, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + return hidden_states + + def altup_unembed( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: # Altup unembed. 
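To summarize `fast_prefill_forward` above: the self-decoder runs over every scheduled token, the cross-decoder runs only over the (padded) logits indices, and its outputs are scattered back into the full hidden-state buffer. A standalone sketch of that gather/merge step, with illustrative shapes and a dummy stand-in for the cross-decoder:

import torch

num_tokens, hidden_size = 6, 4
hidden_states = torch.randn(num_tokens, hidden_size)   # self-decoder output
logits_indices = torch.tensor([2, 5])                   # tokens that need logits
num_logits_indices = logits_indices.numel()

# Cross-decoder only sees the gathered rows (here replaced by a dummy op).
cross_out = hidden_states[logits_indices] * 2.0

# Merge: rows that need logits take the cross-decoder result, the rest keep
# the self-decoder hidden states, mirroring the merge in fast_prefill_forward.
merged = hidden_states.clone()
merged[logits_indices[:num_logits_indices]] = cross_out[:num_logits_indices]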
- target_magnitude = torch.mean(hidden_states[0]**2, - dim=-1, - keepdim=True)**0.5 + target_magnitude = ( + torch.mean(hidden_states[..., 0] ** 2, dim=-1, keepdim=True) ** 0.5 + ) for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_unembed_projections[i - 1]( - hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] - hidden_states = torch.mean(hidden_states, dim=0) + hidden_states[..., i] = self.altup_unembed_projections[i - 1]( + hidden_states[..., i] + ) + new_magnitude = ( + torch.mean(hidden_states[..., i] ** 2, dim=-1, keepdim=True) ** 0.5 + ) + hidden_states[..., i] *= target_magnitude / torch.maximum( + new_magnitude, EPS + ) + # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=-1) + return hidden_states + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + else: + hidden_states = self.normal_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + hidden_states = self.altup_unembed(hidden_states) return self.norm(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -717,17 +1028,26 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + # decoder layer weights, altup_unembed_projections and rmsnorm + # are initialized in text model, others are in self decoder + if ( + not name.startswith("layers") + and not name.startswith("altup_unembed_projections") + and not name.startswith("norm") + ): + name = f"self_decoder.{name}" + + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue # Avoid spurious match with ".up_proj". 
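Because most text-model submodules now live under `self_decoder` while checkpoints keep the original flat names, `load_weights` above rewrites parameter names before lookup. A standalone sketch of just that remapping rule (the checkpoint names are made up for illustration):

def remap(name: str) -> str:
    # Decoder layers, altup_unembed_projections and the final norm stay on the
    # top-level text model; every other weight moved into the self-decoder.
    if not name.startswith(("layers", "altup_unembed_projections", "norm")):
        return f"self_decoder.{name}"
    return name


assert remap("embed_tokens.weight") == "self_decoder.embed_tokens.weight"
assert remap("layers.0.mlp.gate_up_proj.weight") == "layers.0.mlp.gate_up_proj.weight"
assert remap("norm.weight") == "norm.weight"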
@@ -754,8 +1074,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -782,10 +1101,12 @@ class Gemma3nForCausalLM(nn.Module): super().__init__() self.config = config self.cache_config = vllm_config.cache_config - self.model = Gemma3nTextModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma3nTextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -800,7 +1121,6 @@ class Gemma3nForCausalLM(nn.Module): inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model( input_ids, positions, @@ -814,17 +1134,15 @@ class Gemma3nForCausalLM(nn.Module): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata], ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_substrs=([ - "embed_audio.", "embed_vision.", - "audio_tower.", "vision_tower." 
- ])) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_substrs=( + ["embed_audio.", "embed_vision.", "audio_tower.", "vision_tower."] + ), + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 79061fd30c39b..0e69fcfd8febd 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -1,48 +1,67 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union, cast +from typing import Annotated, Any, Literal, Optional, Union, cast +import numpy as np import torch + from torch import nn from transformers import AutoModel, BatchFeature -from transformers.models.gemma3n import (Gemma3nAudioConfig, - Gemma3nAudioFeatureExtractor, - Gemma3nConfig, Gemma3nProcessor, - Gemma3nTextConfig, - Gemma3nVisionConfig) +from transformers.models.gemma3n import ( + Gemma3nAudioConfig, + Gemma3nAudioFeatureExtractor, + Gemma3nConfig, + Gemma3nProcessor, + Gemma3nTextConfig, + Gemma3nVisionConfig, +) from transformers.models.siglip import SiglipImageProcessorFast -from vllm.config import VllmConfig +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import RowParallelLinear -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, - MultiModalDataParser) -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, - PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, - replace_token_matches) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + replace_token_matches, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) 
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) logger = init_logger(__name__) @@ -51,23 +70,36 @@ TOKENS_PER_IMAGE = 256 TOKENS_PER_AUDIO = 188 -class Gemma3nImagePixelInputs(TypedDict): - pixel_values: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" +class Gemma3nImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each patch + - w: Width of each patch + """ + + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class Gemma3nAudioInputs(TypedDict): - input_features: torch.Tensor - """Shape: `(batch_size * num_audio, seq_length, num_features)`""" - input_features_mask: torch.Tensor - """Shape: `(batch_size * num_audio, seq_length)`""" +class Gemma3nAudioInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - s: seq_length + - f: num_features + """ + + type: Literal["audio"] = "audio" + input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")] + input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")] Gemma3nImageInputs = Gemma3nImagePixelInputs class Gemma3nProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Gemma3nConfig) @@ -78,9 +110,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): return {"image": None, "audio": None} def get_max_tokens_per_item( - self, seq_len: int, - mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]: - + self, seq_len: int, mm_counts: Mapping[str, int] + ) -> Optional[Mapping[str, int]]: return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO} def get_image_repl( @@ -92,7 +123,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ) -> str: """ Get the replacement text for image tokens. - + For Gemma3n, this should return the full_image_sequence which includes BOI token, repeated image tokens, and EOI token. """ @@ -100,7 +131,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() return PromptUpdateDetails.select_token_id( - processor.full_image_sequence, processor.image_token_id) + processor.full_image_sequence, processor.image_token_id + ) def get_audio_repl( self, @@ -109,7 +141,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ) -> str: """ Get the replacement text for audio tokens. - + For Gemma3n, this should return the full_audio_sequence which includes BOA token, repeated audio tokens, and EOA token. 
""" @@ -118,11 +150,11 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): # Return the full audio sequence as defined by the processor return PromptUpdateDetails.select_token_id( - processor.full_audio_sequence, processor.audio_token_id) + processor.full_audio_sequence, processor.audio_token_id + ) class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_audios = mm_counts.get("audio", 0) @@ -137,29 +169,36 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_audios = mm_counts.get("audio", 0) processor = self.info.get_hf_processor() - audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501 + audio_feature_extractor: Gemma3nAudioFeatureExtractor = ( + processor.feature_extractor + ) audio_len = audio_feature_extractor.fft_length image_processor: SiglipImageProcessorFast = processor.image_processor img_width = image_processor.size.get("width", 224) img_height = image_processor.size.get("height", 224) + image_overrides = mm_options.get("image") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "image": - self._get_dummy_images(width=img_width, - height=img_height, - num_images=num_images), - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) + "image": self._get_dummy_images( + width=img_width, + height=img_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ), } -class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] - ): - +class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_hf_processor().feature_extractor return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -171,23 +210,29 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - # HF Transformers audio processor no longer accepts `audios` key. - # We pop `audios` and replace it with `audio` key to surpress + # We pop `audios` and replace it with `audio` key to suppress # the warning. 
- if 'audios' in mm_data: - mm_data['audio'] = mm_data.pop('audios') + if "audios" in mm_data: + mm_data["audio"] = mm_data.pop("audios") processed_outputs = super()._call_hf_processor( prompt, mm_data, mm_kwargs, tok_kwargs, ) - if 'input_features' in processed_outputs: - # Avoid padding since we need the output of each item to be + + if "input_features" in processed_outputs: + # Padding enables audio_tower to run in batched mode + processed_outputs["input_features_padded"] = processed_outputs[ + "input_features" + ] + + # Unpad features here since we need the output of each item to be # independent of other items for the cache to work correctly unpadded_features = [ - f[mask] for f, mask in zip( + f[mask] + for f, mask in zip( processed_outputs["input_features"], processed_outputs["input_features_mask"], ) @@ -200,10 +245,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - input_features=MultiModalFieldConfig.batched("audio"), - input_features_mask=MultiModalFieldConfig.batched("audio")) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + input_features_padded=MultiModalFieldConfig.batched("audio"), + input_features_mask=MultiModalFieldConfig.batched("audio"), + ) def _get_prompt_updates( self, @@ -233,35 +279,34 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] modality="image", target=image_token, replacement=get_replacement_image, - )) + ) + ) # Handle audio tokens if "audio" in mm_items: audio_token = hf_processor.audio_token def get_replacement_audio(item_idx: int): - return self.info.get_audio_repl(processor=hf_processor, ) + return self.info.get_audio_repl( + processor=hf_processor, + ) prompt_updates.append( PromptReplacement( modality="audio", target=audio_token, replacement=get_replacement_audio, - )) + ) + ) return prompt_updates def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -290,13 +335,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -321,8 +365,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) return { modality: [ @@ -332,14 +375,15 @@ class 
Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, is_embed=p.is_embed, - ) for p in placeholders + ) + for p in placeholders ] for modality, placeholders in repls.items() } class Gemma3nMultimodalEmbedder(nn.Module): - """Embeds token ids or soft tokens for multimodal content into language + """Embeds token ids or soft tokens for multimodal content into language model space.""" def __init__( @@ -399,7 +443,8 @@ class Gemma3nMultimodalEmbedder(nn.Module): """ # noqa: E501 if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds") + "You must specify exactly one of input_ids or inputs_embeds" + ) if inputs_embeds is not None: emb_norm = self.soft_embedding_norm(inputs_embeds) @@ -411,10 +456,17 @@ class Gemma3nMultimodalEmbedder(nn.Module): return self.embedding_post_projection_norm(emb_norm_proj) -@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor, - info=Gemma3nProcessingInfo, - dummy_inputs=Gemma3nDummyInputsBuilder) -class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor( + Gemma3nMultiModalProcessor, + info=Gemma3nProcessingInfo, + dummy_inputs=Gemma3nDummyInputsBuilder, +) +class Gemma3nForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsTranscription +): + merge_by_field_config = True + supported_languages = ISO639_1_SUPPORTED_LANGS + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -438,7 +490,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", "model": "language_model.model", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -450,15 +503,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): self.multimodal_config = multimodal_config self.vocab_size = config.text_config.vocab_size - self.sliding_window = getattr(config.text_config, - "interleaved_sliding_window", None) - self.vision_tower = AutoModel.from_config(config=config.vision_config) self.audio_tower = AutoModel.from_config(config=config.audio_config) - self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, - config.text_config) - self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, - config.text_config) + self.embed_vision = Gemma3nMultimodalEmbedder( + config.vision_config, config.text_config + ) + self.embed_audio = Gemma3nMultimodalEmbedder( + config.audio_config, config.text_config + ) self.language_model: nn.Module = init_vllm_registered_model( vllm_config=vllm_config, @@ -474,18 +526,12 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): self.config.text_config.num_hidden_layers, self.config.text_config.hidden_size_per_layer_input, device=self.language_model.model.embed_tokens.weight.device, - dtype=self.language_model.model.embed_tokens.weight.dtype) - - @property - def dtype(self): - return next(self.parameters()).dtype - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - # TODO check if there are any - return data + dtype=self.language_model.model.embed_tokens.weight.dtype, + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Gemma3nImageInputs]: + self, **kwargs: object + ) -> Optional[Gemma3nImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = 
kwargs.pop("image_embeds", None) # TODO is this the case? @@ -493,20 +539,13 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): if pixel_values is None: return None - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - pixel_values = flatten_bn(pixel_values, concat=True) - pixel_values = pixel_values.contiguous() - - return Gemma3nImagePixelInputs( - pixel_values=self._validate_pixel_values(pixel_values), ) + return Gemma3nImagePixelInputs(pixel_values=pixel_values) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Gemma3nAudioInputs]: - input_features = kwargs.pop("input_features", None) - if input_features is None: + self, **kwargs: object + ) -> Optional[Gemma3nAudioInputs]: + input_features_padded = kwargs.pop("input_features_padded", None) + if input_features_padded is None: return None input_features_mask = kwargs.pop("input_features_mask", None) @@ -514,7 +553,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): return None return Gemma3nAudioInputs( - input_features=input_features, + input_features_padded=input_features_padded, input_features_mask=input_features_mask, ) @@ -524,14 +563,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key == "input_features" \ - and "audio" not in mm_input_by_modality: - mm_input_by_modality[ - "audio"] = self._parse_and_validate_audio_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key == "input_features_padded" + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) return mm_input_by_modality def _process_image_input( @@ -541,16 +586,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): assert self.vision_tower is not None pixel_values = image_input["pixel_values"] - vision_outputs = self.vision_tower(pixel_values=pixel_values, - do_pooling=False, - return_dict=True).last_hidden_state + vision_outputs = self.vision_tower( + pixel_values=pixel_values, do_pooling=False, return_dict=True + ).last_hidden_state # TODO try to avoid copy here # (batch, channels, height, width) to (batch, height * width, channels) - vision_outputs = vision_outputs.reshape( - vision_outputs.shape[0], - self.config.vision_config.hidden_size, - self.config.vision_soft_tokens_per_image, - ).permute(0, 2, 1).contiguous() + vision_outputs = ( + vision_outputs.reshape( + vision_outputs.shape[0], + self.config.vision_config.hidden_size, + self.config.vision_soft_tokens_per_image, + ) + .permute(0, 2, 1) + .contiguous() + ) # Normalize and embed the soft tokens into language model space. 
vision_outputs *= self.config.vision_config.hidden_size**0.5 # Return a list of embeddings instead of a batched tensor @@ -561,43 +610,44 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): audio_input: Gemma3nAudioInputs, ) -> list[torch.Tensor]: assert self.audio_tower is not None - input_features = audio_input["input_features"].squeeze(1) + # Run on padded features to enable batching + input_features = audio_input["input_features_padded"].squeeze(1) input_features_mask = audio_input["input_features_mask"].squeeze(1) - audio_outputs, audio_mask = self.audio_tower(input_features, - ~input_features_mask) + audio_outputs, audio_mask = self.audio_tower( + input_features, ~input_features_mask + ) audio_features = self.embed_audio(inputs_embeds=audio_outputs) # ruff: noqa # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the - # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will + # text to account for this. However, the audio preprocessing and encoder do not guarantee they will # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad - # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab. + # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab. # TODO precompute and cache padding - audio_padding_toks = torch.tensor([[self.vocab_size - 1]], - dtype=torch.long, - device=audio_features.device) + audio_padding_toks = torch.tensor( + [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device + ) audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) - audio_features = torch.where(audio_mask.unsqueeze(-1), - audio_padding_embs, audio_features) + audio_features = torch.where( + audio_mask.unsqueeze(-1), audio_padding_embs, audio_features + ) audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501 extra_padding_features = audio_padding_embs.expand( - audio_batch_size, extra_padding_tokens, audio_embed_dim) + audio_batch_size, extra_padding_tokens, audio_embed_dim + ) - audio_features = torch.cat((audio_features, extra_padding_features), - dim=1) + audio_features = torch.cat((audio_features, extra_padding_features), dim=1) # Return a list of embeddings instead of a batched tensor return audio_features.unbind(0) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if mm_input_by_modality is None: return [] @@ -619,35 +669,44 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model 
forward has only access to the input_embeds. if input_ids is not None: per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( - input_ids) + input_ids + ) per_layer_inputs = per_layer_inputs.reshape( - -1, self.config.text_config.num_hidden_layers, - self.config.text_config.hidden_size_per_layer_input) - self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( - per_layer_inputs) + -1, + self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input, + ) + self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_( + per_layer_inputs + ) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - # NOTE: this order of processing mm items is important - [self.config.image_token_id, self.config.audio_token_id]) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -656,7 +715,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): # select a chunk of pre-allocated PLEs. During normal execution, # `get_input_embeddings` is called before forward, hence this slice # will contain PLEs computed from the actual input_ids. 
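The per-layer-embedding (PLE) buffer referenced above is written once per batch in `get_input_embeddings` and sliced again inside `forward`. A stripped-down sketch of that copy-then-slice pattern with illustrative sizes:

import torch

max_tokens, num_layers, ple_dim = 8, 4, 16
ple_buffer = torch.zeros(max_tokens, num_layers, ple_dim)  # preallocated once


def cache_ple(per_layer_inputs: torch.Tensor) -> None:
    # get_input_embeddings: stash the PLEs computed from the current input_ids.
    ple_buffer[: per_layer_inputs.shape[0]].copy_(per_layer_inputs)


def read_ple(num_tokens: int) -> torch.Tensor:
    # forward: slice out exactly the rows belonging to this batch.
    return ple_buffer[:num_tokens]


cache_ple(torch.randn(3, num_layers, ple_dim))
assert read_ple(3).shape == (3, num_layers, ple_dim)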
- per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]] + per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]] hidden_states = self.language_model.model( input_ids, @@ -664,20 +723,18 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): per_layer_inputs=per_layer_inputs, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **kwargs) + **kwargs, + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -688,7 +745,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -698,3 +756,57 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): return "<audio_soft_token>" else: raise ValueError(f"Unsupported modality: {modality}") + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + """ + Gemma3n supports "free-form" transcription. + We fix its prompt here to standardize transcriptions/translations + requests. + """ + # Transcribe this audio [into <>] | for transcription + # Translate this audio [from <> into <>] | for translation + prompt = "<start_of_turn>user\n" + prompt += "Transcribe" if task_type == "transcribe" else "Translate" + prompt += " this audio" + + # We assume the language is a valid ISO 639-1 code. + full_lang_name = cls.supported_languages.get(language, "") + # Translation only for now + full_lang_name_to = cls.supported_languages.get(to_language, "") + + if task_type == "transcribe" and full_lang_name: + prompt += f" into {full_lang_name}" + elif task_type == "translate": + if full_lang_name: + prompt += f" from {full_lang_name}" + if full_lang_name_to: + prompt += f" into {full_lang_name_to}" + + prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n" + + audio = (audio, stt_config.sample_rate) + prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt} + return cast(PromptType, prompts_dict) + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + return SpeechToTextConfig( + # Let's set this to 30 as suggested in the docs for now, although + # the model is only limited by its context length. + max_audio_clip_s=30, + sample_rate=16000, + # TODO enable chunking after more thorough testing. 
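For reference, the prompt assembled by `get_generation_prompt` above looks as follows, assuming the ISO 639-1 table maps "en" to "English" and "de" to "German" (shown as a comment sketch rather than executable output):

# Transcription request, language="en":
#   <start_of_turn>user
#   Transcribe this audio into English: <audio_soft_token><end_of_turn>
#   <start_of_turn>model
#
# Translation request, language="de", to_language="en":
#   <start_of_turn>user
#   Translate this audio from German into English: <audio_soft_token><end_of_turn>
#   <start_of_turn>model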
+ min_energy_split_window_size=None, + ) diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py index defa77b84e441..a6991f8e43fef 100644 --- a/vllm/model_executor/models/glm.py +++ b/vllm/model_executor/models/glm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only HF format GLM-4 model compatible with THUDM weights.""" + from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM @@ -8,7 +9,6 @@ from .utils import PPMissingLayer class GlmForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.partial_rotary_factor = 0.5 super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 5e2908a82c418..f25f50602e6c2 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Optional, Union @@ -34,13 +35,11 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -50,21 +49,22 @@ from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix class Glm4Attention(nn.Module): - - def __init__(self, - config: Glm4Config, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - head_dim: Optional[int] = None, - qkv_bias: bool = False, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -113,14 +113,16 @@ class Glm4Attention(nn.Module): partial_rotary_factor=partial_rotary_factor, is_neox_style=False, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - 
quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) def forward( self, @@ -136,15 +138,18 @@ class Glm4Attention(nn.Module): class Glm4DecoderLayer(nn.Module): - def __init__( self, - config: Glm4Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", + config: Optional[Glm4Config] = None, ) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) @@ -156,8 +161,8 @@ class Glm4DecoderLayer(nn.Module): max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, rope_scaling=rope_scaling, @@ -171,14 +176,14 @@ class Glm4DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_self_attn_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_self_attn_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -191,8 +196,7 @@ class Glm4DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -201,8 +205,7 @@ class Glm4DecoderLayer(nn.Module): hidden_states = self.post_self_attn_layernorm(hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = self.post_mlp_layernorm(hidden_states) @@ -220,13 +223,13 @@ ALL_DECODER_LAYER_TYPES = { "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Glm4Model(LlamaModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=Glm4DecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer + ) class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -252,25 +255,28 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = 
lora_config self.quant_config = quant_config - self.model = Glm4Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Glm4Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -282,24 +288,21 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 662728e6b1393..304e721fade5b 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -36,50 +36,71 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from packaging.version import Version from transformers import BatchFeature +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( - Glm4vImageProcessor, smart_resize) -from transformers.models.glm4v.video_processing_glm4v import ( - Glm4vVideoProcessor) + Glm4vImageProcessor, + smart_resize, +) +from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig -from vllm.distributed import parallel_state +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import 
SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .qwen2_vl import (_create_qwen2vl_field_factory, - apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -97,6 +118,7 @@ class Glm4vImagePixelInputs(TensorSchema): - ni: Number of images - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")] @@ -111,6 +133,7 @@ class Glm4vImageEmbeddingInputs(TensorSchema): - n: Number of images - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["image_embeds"] = "image_embeds" image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")] @@ -130,6 +153,7 @@ class Glm4vVideoPixelInputs(TensorSchema): - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")] @@ -145,6 +169,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ + type: Literal["video_embeds"] = "video_embeds" video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")] @@ -153,11 +178,10 @@ class 
Glm4vVideoEmbeddingInputs(TensorSchema): Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs] -# === Vision Encoder === # +# ==== Vision Encoder ==== # class Glm4vVisionMLP(nn.Module): - def __init__( self, in_features: int, @@ -165,6 +189,7 @@ class Glm4vVisionMLP(nn.Module): bias: bool = False, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -172,12 +197,17 @@ class Glm4vVisionMLP(nn.Module): output_sizes=[hidden_features] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor): @@ -199,8 +229,7 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair @@ -210,7 +239,6 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): class Glm4vVisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -218,15 +246,22 @@ class Glm4vVisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = ( + 0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank() + ) self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) self.qkv = QKVParallelLinear( hidden_size=embed_dim, @@ -235,8 +270,9 @@ class Glm4vVisionAttention(nn.Module): total_num_kv_heads=num_heads, bias=False, quant_config=quant_config, - # Change qkv prefix to align with GLM-4.5V-FP8 quantization config + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", + disable_tp=use_data_parallel, ) self.proj = RowParallelLinear( input_size=projection_size, @@ -244,38 +280,45 @@ class Glm4vVisionAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.proj", bias=False, + disable_tp=use_data_parallel, ) # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype(), + ) + self.use_upstream_fa = False + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) + ) + if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"GLM-4V does not support {self.attn_backend} backend now.") + f"GLM-4V does not support {self.attn_backend} backend now." + ) + + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape - if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) - # 3 * [s, b, head * head_dim] - if self.tp_size > 1: - splitter = partial( - dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size, - ) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - v = splitter(v)[self.tp_rank] - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, @@ -287,12 +330,12 @@ class Glm4vVisionAttention(nn.Module): return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -301,20 +344,17 @@ class Glm4vVisionAttention(nn.Module): q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - - if self.attn_backend == _Backend.FLASH_ATTN: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func( + output = self.flash_attn_varlen_func( q, k, v, @@ -326,9 +366,9 @@ class Glm4vVisionAttention(nn.Module): causal=False, ) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
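As a point of reference for the refactored rotary path above (one `apply_rotary_pos_emb_vision` call on `torch.cat([q, k])`, then a chunk back into q and k), the fused form is numerically identical to two separate calls, since rotary is applied elementwise per position. A minimal sketch with a simplified stand-in rotary function, not the model's actual implementation:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotary helper: negate-and-swap the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in rotary embedding: x is [batch, seq, heads, head_dim],
    # freqs is [seq, head_dim // 2]; cos/sin broadcast over batch and heads.
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()[None, :, None, :]
    sin = emb.sin()[None, :, None, :]
    return x * cos + rotate_half(x) * sin

b, s, h, d = 2, 8, 4, 16
q, k = torch.randn(b, s, h, d), torch.randn(b, s, h, d)
freqs = torch.randn(s, d // 2)

# One call on the concatenated tensor, then chunk back into q and k ...
qk_rotated = apply_rotary(torch.cat([q, k], dim=0), freqs)
q_rot, k_rot = qk_rotated.chunk(2, dim=0)

# ... matches applying the embedding to q and k separately.
assert torch.allclose(q_rot, apply_rotary(q, freqs))
assert torch.allclose(k_rot, apply_rotary(k, freqs))
```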
outputs = [] @@ -338,35 +378,36 @@ class Glm4vVisionAttention(nn.Module): q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Glm4vVisionBlock(nn.Module): - def __init__( self, dim: int, @@ -375,6 +416,7 @@ class Glm4vVisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -387,6 +429,7 @@ class Glm4vVisionBlock(nn.Module): projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, ) self.mlp = Glm4vVisionMLP( dim, @@ -394,30 +437,31 @@ class Glm4vVisionBlock(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn( + x_attn = self.attn( self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, seqlens=seqlens, ) + x_fused_norm, residual = self.norm2(x, residual=x_attn) + x = residual + self.mlp(x_fused_norm) - x = x + self.mlp(self.norm2(x)) return x class Glm4vVisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -441,14 +485,12 @@ class Glm4vVisionPatchEmbed(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Glm4vPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -456,15 +498,19 @@ class Glm4vPatchMerger(nn.Module): quant_config: Optional[QuantizationConfig] = None, bias: bool = False, prefix: str = "", + use_data_parallel: bool = False, ) 
-> None: super().__init__() self.hidden_size = d_model - self.proj = ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=bias, - gather_output=True, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=bias, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.post_projection_norm = nn.LayerNorm(self.hidden_size) self.gate_up_proj = MergedColumnParallelLinear( input_size=self.hidden_size, @@ -472,6 +518,7 @@ class Glm4vPatchMerger(nn.Module): bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) self.down_proj = RowParallelLinear( context_dim, @@ -479,6 +526,7 @@ class Glm4vPatchMerger(nn.Module): bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) self.act_fn = SiluAndMul() self.extra_activation_func = nn.GELU() @@ -493,7 +541,6 @@ class Glm4vPatchMerger(nn.Module): class Glm4vVisionEmbeddings(nn.Module): - def __init__(self, config: Glm4vVisionConfig): super().__init__() self.config = config @@ -501,18 +548,18 @@ class Glm4vVisionEmbeddings(nn.Module): self.image_size = config.image_size self.patch_size = config.patch_size - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer( "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) - def forward(self, embeddings, lengths, image_shapes, h_coords, - w_coords) -> torch.Tensor: + def forward( + self, embeddings, lengths, image_shapes, h_coords, w_coords + ) -> torch.Tensor: pos_embed_weight = self.position_embedding.weight hidden_size = pos_embed_weight.shape[1] total_seq = h_coords.shape[0] @@ -523,39 +570,54 @@ class Glm4vVisionEmbeddings(nn.Module): # Handle empty sequence case if total_seq == 0: - adapted_pos_embed = torch.empty(0, - hidden_size, - device=device, - dtype=pos_embed_weight.dtype) + adapted_pos_embed = torch.empty( + 0, hidden_size, device=device, dtype=pos_embed_weight.dtype + ) else: # Convert inputs to tensors if needed if isinstance(lengths, list): - lengths = torch.tensor(lengths, - device=device, - dtype=torch.long) + lengths = torch.tensor(lengths, device=device, dtype=torch.long) if not isinstance(image_shapes, torch.Tensor): - image_shapes = torch.tensor(image_shapes, - device=device, - dtype=torch.long) + image_shapes = torch.tensor( + image_shapes, device=device, dtype=torch.long + ) # Prepare 2D position embedding orig_size_sq = pos_embed_weight.shape[0] orig_size = int(orig_size_sq**0.5) - pos_embed_2d = (pos_embed_weight.view( - orig_size, orig_size, - hidden_size).permute(2, 0, - 1).unsqueeze(0).to(device=device, - dtype=torch.float32)) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) # Calculate target dimensions for each patch - target_h = torch.cat([ - image_shapes[i, 1].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) - target_w = torch.cat([ - image_shapes[i, 2].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) + # Add bounds checking 
for data parallel mode + if len(lengths) > image_shapes.shape[0]: + # In data parallel mode, some GPUs might not have all + # image shapes + # Use available image shapes, cycling if necessary + target_h_list = [] + target_w_list = [] + for i in range(len(lengths)): + # Cycle through available shapes + shape_idx = i % image_shapes.shape[0] + target_h_list.append(image_shapes[shape_idx, 1].repeat(lengths[i])) + target_w_list.append(image_shapes[shape_idx, 2].repeat(lengths[i])) + target_h = torch.cat(target_h_list).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat(target_w_list).to( + device=device, dtype=torch.float32 + ) + else: + target_h = torch.cat( + [image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))] + ).to(device=device, dtype=torch.float32) + target_w = torch.cat( + [image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))] + ).to(device=device, dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample h_coords = h_coords.to(device=device, dtype=torch.float32) @@ -564,8 +626,7 @@ class Glm4vVisionEmbeddings(nn.Module): norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 # Create sampling grid - grid = (torch.stack((norm_w, norm_h), - dim=-1).unsqueeze(0).unsqueeze(2)) + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) # Perform bicubic interpolation interpolated_embed_fp32 = F.grid_sample( @@ -578,9 +639,11 @@ class Glm4vVisionEmbeddings(nn.Module): # Reshape and convert back to original dtype adapted_pos_embed_fp32 = ( - interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)) - adapted_pos_embed = adapted_pos_embed_fp32.to( - pos_embed_weight.dtype).to(embeddings.device) + interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + ) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to( + embeddings.device + ) # Add adapted position encoding to embeddings embeddings = embeddings + adapted_pos_embed @@ -588,13 +651,11 @@ class Glm4vVisionEmbeddings(nn.Module): class Glm4vVisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -603,16 +664,22 @@ class Glm4vVisionRotaryEmbedding(nn.Module): if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, - self.dim, - 2, - dtype=torch.float, - device=self.inv_freq.device, - ) / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, + self.dim, + 2, + dtype=torch.float, + device=self.inv_freq.device, + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -622,13 +689,13 @@ class Glm4vVisionRotaryEmbedding(nn.Module): class Glm4vVisionTransformer(nn.Module): - def __init__( self, vision_config: Glm4vVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -638,6 +705,7 @@ class Glm4vVisionTransformer(nn.Module): depth = vision_config.depth 
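The position-embedding adaptation above can be viewed in isolation: the square learned grid is treated as a 2D feature map and re-sampled at each patch's normalized center with bicubic `grid_sample`. A minimal sketch under made-up sizes; the exact `grid_sample` flags used by the model may differ:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: a 16x16 learned grid of 32-dim position embeddings.
orig_size, hidden = 16, 32
pos_embed_weight = torch.randn(orig_size * orig_size, hidden)
pos_2d = (
    pos_embed_weight.view(orig_size, orig_size, hidden)
    .permute(2, 0, 1)
    .unsqueeze(0)
)  # [1, hidden, 16, 16], the layout grid_sample expects

# Target patch grid for one image resized to 24x20 patches.
target_h, target_w = 24, 20
h_coords, w_coords = torch.meshgrid(
    torch.arange(target_h), torch.arange(target_w), indexing="ij"
)
h_coords = h_coords.flatten().float()
w_coords = w_coords.flatten().float()

# Normalize patch centers into grid_sample's [-1, 1] coordinate space.
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)  # [1, N, 1, 2]

sampled = F.grid_sample(pos_2d, grid, mode="bicubic", align_corners=False)
adapted = sampled.squeeze(0).squeeze(-1).permute(1, 0)  # [N, hidden]
assert adapted.shape == (target_h * target_w, hidden)
```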
self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads + self.use_data_parallel = use_data_parallel self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size @@ -653,37 +721,50 @@ class Glm4vVisionTransformer(nn.Module): norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Glm4vVisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.out_hidden_size, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - ) for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Glm4vVisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.out_hidden_size, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=self.use_data_parallel, + ) + for layer_idx in range(depth) + ] + ) self.merger = Glm4vPatchMerger( d_model=vision_config.out_hidden_size, context_dim=vision_config.intermediate_size, quant_config=quant_config, bias=False, prefix=f"{prefix}.merger", + use_data_parallel=self.use_data_parallel, ) self.embeddings = Glm4vVisionEmbeddings(vision_config) - self.post_conv_layernorm = RMSNorm(vision_config.hidden_size, - eps=vision_config.rms_norm_eps) + self.post_conv_layernorm = RMSNorm( + vision_config.hidden_size, eps=vision_config.rms_norm_eps + ) self.downsample = nn.Conv2d( in_channels=vision_config.hidden_size, out_channels=vision_config.out_hidden_size, kernel_size=vision_config.spatial_merge_size, stride=vision_config.spatial_merge_size, ) - self.post_layernorm = RMSNorm(vision_config.hidden_size, - eps=vision_config.rms_norm_eps) + self.post_layernorm = RMSNorm( + vision_config.hidden_size, eps=vision_config.rms_norm_eps + ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -698,20 +779,27 @@ class Glm4vVisionTransformer(nn.Module): for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = (hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten()) - wpos_ids = (wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten()) - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = 
self.rotary_pos_emb(max_grid_size) @@ -724,15 +812,21 @@ class Glm4vVisionTransformer(nn.Module): ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - if self.attn_backend == _Backend.FLASH_ATTN: + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: list[list[int]], ) -> torch.Tensor: + # Convert grid_thw to tensor (always expecting list format now) + grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long) + # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) @@ -741,15 +835,16 @@ class Glm4vVisionTransformer(nn.Module): # compute position embedding rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # pre-compute seqlens for attn mask to reduce cuMemcpy operations max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) - x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0], - image_type_ids[:, 1]) + x = self.embeddings( + x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1] + ) # transformers x = x.unsqueeze(1) @@ -765,16 +860,14 @@ class Glm4vVisionTransformer(nn.Module): # adapter x = self.post_layernorm(x) - x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, - x.shape[-1]) + x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1]) x = x.permute(0, 3, 1, 2) x = self.downsample(x).view(-1, self.out_hidden_size) x = self.merger(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -798,15 +891,13 @@ class Glm4vVisionTransformer(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Glm4vProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -839,17 +930,16 @@ class Glm4vProcessingInfo(BaseProcessingInfo): if do_resize: resized_height, resized_width = smart_resize( num_frames=num_frames - if num_frames > temporal_patch_size else temporal_patch_size, + if num_frames > temporal_patch_size + else temporal_patch_size, height=image_height, width=image_width, factor=patch_size * merge_size, max_pixels=max_image_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) # NOTE: Frames are padded to be divisible by `temporal_patch_size` # 
https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 @@ -865,8 +955,9 @@ class Glm4vProcessingInfo(BaseProcessingInfo): return preprocessed_size, num_vision_tokens def get_image_size_with_most_features(self) -> ImageSize: - max_image_size, _ = self._get_vision_info(image_width=9999999, - image_height=9999999) + max_image_size, _ = self._get_vision_info( + image_width=9999999, image_height=9999999 + ) return max_image_size def get_num_image_tokens( @@ -933,44 +1024,47 @@ class Glm4vProcessingInfo(BaseProcessingInfo): max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) - def _get_video_second_idx(self, metadata: dict[str, Any], - total_frames: int) -> list[int]: + def _get_video_second_idx( + self, metadata: dict[str, Any], total_frames: int + ) -> list[int]: video_processor = self.get_video_processor() video_fps = metadata.get("fps", video_processor.fps) meta_frames = metadata.get("total_num_frames", total_frames) max_frame_idx = meta_frames - 1 - duration = metadata.get("duration", - round(max_frame_idx / video_fps) + 1) - if duration <= video_processor.max_duration: - n = int(math.floor(duration * video_processor.fps)) - frame_indices = [ - min( - max_frame_idx, - int(math.ceil(i * video_fps / video_processor.fps)), - ) for i in range(n) - ] + duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1) + do_sample_frames = metadata["do_sample_frames"] + if not do_sample_frames: + frame_indices = metadata["frames_indices"] else: - num_samples = int(video_processor.max_duration * - video_processor.fps) - if num_samples >= meta_frames: - frame_indices = list(range(meta_frames)) - else: - target_seconds = np.linspace(0, - duration, - num_samples, - endpoint=True) + if duration <= video_processor.max_duration: + n = int(math.floor(duration * video_processor.fps)) frame_indices = [ - min(max_frame_idx, int(math.ceil(t * video_fps))) - for t in target_seconds + min( + max_frame_idx, + int(math.ceil(i * video_fps / video_processor.fps)), + ) + for i in range(n) ] + else: + num_samples = int(video_processor.max_duration * video_processor.fps) + if num_samples >= meta_frames: + frame_indices = list(range(meta_frames)) + else: + target_seconds = np.linspace( + 0, duration, num_samples, endpoint=True + ) + frame_indices = [ + min(max_frame_idx, int(math.ceil(t * video_fps))) + for t in target_seconds + ] seen, uniq = set(), [] for idx in frame_indices: @@ -988,9 +1082,43 @@ class Glm4vProcessingInfo(BaseProcessingInfo): selected_timestamps.append(timestamps_list[idx]) return selected_timestamps + def _construct_video_placeholder( + self, + video_array: np.ndarray, + metadata: dict[str, Any], + grid_thw: torch.Tensor, + ) -> str: + hf_processor = self.get_hf_processor() + tokenizer = self.get_tokenizer() + image_processor = hf_processor.image_processor + + hf_config = self.get_hf_config() + boi_token_id = hf_config.image_start_token_id + eoi_token_id = hf_config.image_end_token_id + bov_token_id = hf_config.video_start_token_id + eov_token_id = hf_config.video_end_token_id + merge_length = 
image_processor.merge_size**2 + + assert isinstance(grid_thw, torch.Tensor) + timestamps = self._get_video_second_idx(metadata, len(video_array)) + frames_idx_token = [ + tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps + ] + T, H, W = grid_thw + num_tokens_per_frame = int(H * W) // merge_length + placeholder = [] + placeholder.append(bov_token_id) + for frame_idx in frames_idx_token: + placeholder.append(boi_token_id) + placeholder.extend([hf_processor.video_token_id] * num_tokens_per_frame) + placeholder.append(eoi_token_id) + placeholder.extend(frame_idx) + placeholder.append(eov_token_id) + + return placeholder + class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1013,25 +1141,32 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( - seq_len, mm_counts) + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, + overrides=video_overrides, ), } @@ -1042,7 +1177,37 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): height: int, num_frames: int, num_videos: int, + overrides: Optional[VideoDummyOptions] = None, ) -> list[VideoItem]: + if overrides: + if overrides.num_frames: + if overrides.num_frames > num_frames: + logger.warning( + "video.num_frames override (%d) exceeds model's " + "maximum number of frames (%d), will be ignored", + overrides.num_frames, + num_frames, + ) + num_frames = min(num_frames, overrides.num_frames) + if overrides.width: + if overrides.width > width: + logger.warning( + "video.width override (%d) exceeds model's " + "maximum width (%d), will be ignored", + overrides.width, + width, + ) + width = min(width, overrides.width) + if overrides.height: + if overrides.height > height: + logger.warning( + "video.height override (%d) exceeds model's " + "maximum height (%d), will be ignored", + overrides.height, + height, + ) + height = min(height, overrides.height) + video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video_items = [] for i in range(num_videos): @@ -1050,7 +1215,9 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): "fps": 2.0, "duration": num_frames / 2.0, "total_num_frames": num_frames, + "frames_indices": [i for i in range(num_frames)], "video_backend": "opencv", + "do_sample_frames": False, } video_item = (video.copy(), video_metadata) video_items.append(video_item) @@ -1059,7 +1226,6 @@ class 
Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: return MultiModalDataParser(video_needs_metadata=True) @@ -1076,60 +1242,77 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): # GLM-4.1V use `image_token_id` as video placeholder, we need to # replace it with `video_token_id` for video processing. So we # separate video processing from image processing. - if ("videos" in mm_data and isinstance(mm_data["videos"], list) - and len(mm_data["videos"]) > 0): + if ( + "videos" in mm_data + and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0 + ): video_grid_thw_lst = [] pixel_values_videos_lst = [] for item in mm_data.pop("videos", []): video_array, metadata = item - # FIXME(Isotr0py): Activate the below logic after we can disable - # resampling from video loader backend. - # assert metadata["total_num_frames"] == len(video_array), ( - # f"Total frames {metadata['total_num_frames']} does not " - # f"match the length of video array {len(video_array)}.") - - # NOTE: Temporary workaround for resampled videos. - # this can cause a divergence with HF implementation if - # the input video is resampled in advance. - - if metadata["total_num_frames"] != len(video_array): - logger.warning( - "Total frames in metadata " - "(%s) does not match the length of " - "video array %s. This can " - "be because the video is resampled " - "in advance. This may cause " - "a divergence with HF implementation.", - metadata["total_num_frames"], - len(video_array), - ) - metadata["total_num_frames"] = len(video_array) - metadata = VideoMetadata(**metadata) + # don't update mm_kwargs inplace + video_mm_kwargs = dict(**mm_kwargs) + video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", True + ) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] - video_mm_data["video_metadata"] = [[metadata]] + + # backward compatibility for Transformers 4.55 + unuse_metadata = ["do_sample_frames"] + if ( + not hasattr(VideoMetadata, "frames_indices") + and "frames_indices" in metadata + ): + unuse_metadata.append("frames_indices") + + video_mm_data["video_metadata"] = [ + [ + VideoMetadata( + **{ + k: metadata[k] + for k in metadata + if k not in unuse_metadata + } + ) + ] + ] video_outputs = super()._call_hf_processor( prompt="<|begin_of_video|><|video|><|end_of_video|>", mm_data=video_mm_data, - mm_kwargs=mm_kwargs, + mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) - input_ids = video_outputs.pop("input_ids") - input_ids[input_ids == processor.image_token_id] = ( - processor.video_token_id) - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + if not video_mm_kwargs["do_sample_frames"] and Version( + TRANSFORMERS_VERSION + ) < Version("4.56.0"): + # Transformers v4.55 has incorrect timestamps issue for + # skip sampling. We construct the placeholder manually to + # get placeholders with correct timestamps. 
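The guard above takes the manual-placeholder path only on Transformers releases older than 4.56 and only when frame sampling is disabled; in isolation it is a plain `packaging` version comparison (the helper name below is illustrative):

```python
from packaging.version import Version

def needs_manual_placeholder(transformers_version: str, do_sample_frames: bool) -> bool:
    # Pre-4.56 Transformers with frame sampling disabled takes the manually
    # constructed placeholder path; everything else keeps the processor's
    # own input_ids.
    return not do_sample_frames and Version(transformers_version) < Version("4.56.0")

assert needs_manual_placeholder("4.55.2", do_sample_frames=False)
assert not needs_manual_placeholder("4.56.0", do_sample_frames=False)
assert not needs_manual_placeholder("4.55.2", do_sample_frames=True)
```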
+ placeholder = self.info._construct_video_placeholder( + video_array, + metadata, + video_outputs["video_grid_thw"].squeeze(0), + ) + video_placeholder = processor.tokenizer.decode(placeholder) + else: + input_ids = video_outputs.pop("input_ids") + input_ids[input_ids == processor.image_token_id] = ( + processor.video_token_id + ) + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, + 1, ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) - pixel_values_videos_lst.append( - video_outputs["pixel_values_videos"]) + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), @@ -1155,8 +1338,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _create_qwen2vl_field_factory( - self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) def _get_prompt_updates( self, @@ -1165,16 +1348,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() - hf_config = self.info.get_hf_config() - - boi_token_id = hf_config.image_start_token_id - eoi_token_id = hf_config.image_end_token_id - - bov_token_id = hf_config.video_start_token_id - eov_token_id = hf_config.video_end_token_id + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) merge_length = image_processor.merge_size**2 @@ -1192,21 +1366,9 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] - timestamps = self.info._get_video_second_idx(metadata, len(video)) - frames_idx_token = [ - tokenizer.encode(str(i), add_special_tokens=False) - for i in timestamps - ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length - placeholder = [] - placeholder.append(bov_token_id) - for frame_idx in frames_idx_token: - placeholder.append(boi_token_id) - placeholder.extend([hf_processor.video_token_id] * - num_tokens_per_frame) - placeholder.append(eoi_token_id) - placeholder.extend(frame_idx) - placeholder.append(eov_token_id) + placeholder = self.info._construct_video_placeholder( + video, metadata, grid_thw + ) return PromptUpdateDetails.select_token_id( placeholder, embed_token_id=hf_processor.video_token_id, @@ -1231,15 +1393,18 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): info=Glm4vProcessingInfo, dummy_inputs=Glm4vDummyInputsBuilder, ) -class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): +class Glm4vForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], - "gate_up_proj": ["gate_up_proj"] + "gate_up_proj": ["gate_up_proj"], } # To ensure correct weight loading and mapping. 
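`_construct_video_placeholder`, used both during processing and in the prompt update above, emits one image-start/image-end span of video tokens per frame followed by that frame's timestamp tokens, all wrapped in video start/end tokens. A sketch of that layout with hypothetical token ids; the real ids come from the HF config and the per-frame token count is `(H * W) // merge_size**2`:

```python
# Hypothetical ids standing in for the bov/eov/boi/eoi/video tokens.
BOV, EOV, BOI, EOI, VID = 1001, 1002, 1003, 1004, 1005

def build_placeholder(timestamp_token_ids: list[list[int]], tokens_per_frame: int) -> list[int]:
    # Mirrors the loop in _construct_video_placeholder: per frame, an image
    # span of `tokens_per_frame` video tokens followed by its timestamp tokens.
    placeholder = [BOV]
    for ts_tokens in timestamp_token_ids:
        placeholder.append(BOI)
        placeholder.extend([VID] * tokens_per_frame)
        placeholder.append(EOI)
        placeholder.extend(ts_tokens)
    placeholder.append(EOV)
    return placeholder

# Two frames whose timestamps "0" and "2" each encode to a single token here.
print(build_placeholder([[15], [17]], tokens_per_frame=4))
```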
@@ -1248,7 +1413,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", "model.visual.": "visual.", - }) + } + ) + + supports_encoder_tp_data = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1267,12 +1435,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.visual = Glm4vVisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) if config.model_type == "glm4v": @@ -1286,29 +1456,16 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=architectures) + architectures=architectures, + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) - else: - return torch.concat(mm_input) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Glm4vImageInputs]: + self, **kwargs: object + ) -> Optional[Glm4vImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1317,11 +1474,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - return Glm4vImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1329,11 +1481,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - return Glm4vImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1341,7 +1488,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Glm4vVideoInputs]: + self, **kwargs: object + ) -> Optional[Glm4vVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1350,11 +1498,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") - video_grid_thw = 
self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return Glm4vVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1362,11 +1505,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return Glm4vVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, @@ -1374,43 +1512,60 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, ) def _process_image_input( - self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: + self, image_input: Glm4vImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() + return image_embeds.split(sizes) def _process_video_input( - self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: + self, video_input: Glm4vVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() - device = self.visual.device - flat_grid_thw = torch.cat([ - torch.tensor([[1, h, w]] * t, device=device) - for t, h, w in grid_thw - ]) if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=flat_grid_thw) - + self.visual.dtype + ) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw.tolist(), + rope_type="rope_3d", + ) + else: + video_embeds = self.visual( + pixel_values_videos, grid_thw=grid_thw.tolist() + ) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return video_embeds.split(sizes.tolist()) + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1418,28 +1573,34 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
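The per-item split used by `_process_image_input` and `_process_video_input` above derives each item's merged-token count from its grid (`t * h * w // merge_size**2`). A small worked example with made-up grids:

```python
import torch

# Made-up grids: two images of 4x6 and 8x8 patches, spatial_merge_size = 2.
grid_thw_list = [[1, 4, 6], [1, 8, 8]]
merge_size = 2

sizes = (
    torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
    // (merge_size * merge_size)
).tolist()
print(sizes)  # [6, 16] merged tokens per image

# The encoder output is one concatenated tensor; split() recovers per-item chunks.
image_embeds = torch.randn(sum(sizes), 32)
per_image = image_embeds.split(sizes)
assert [t.shape[0] for t in per_image] == sizes
```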
for input_key in kwargs: - if (input_key in ("pixel_values", "image_embeds") - and "image" not in mm_input_by_modality): - mm_input_by_modality["image"] = ( - self._parse_and_validate_image_input(**kwargs)) - if (input_key in ("pixel_values_videos", "video_embeds") - and "video" not in mm_input_by_modality): - mm_input_by_modality["video"] = ( - self._parse_and_validate_video_input(**kwargs)) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -1454,49 +1615,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0 - and all(embed.numel() > 0 for embed in multimodal_embeddings)): - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id], - ) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Glm4vImageInputs] = None, - video_input: Optional[Glm4vVideoInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1515,41 +1633,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, **NOTE**: If mrope is enabled (default setting for GLM-4V opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. 
- `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. + intermediate_tensors: Optional intermediate tensors for pipeline + parallelism. + inputs_embeds: Optional pre-computed input embeddings. + **kwargs: Additional keyword arguments. """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -1561,13 +1652,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index fe5e46a99826f..b9cdee29417a6 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -21,9 +21,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Inference-only GLM-4.5 model compatible with HuggingFace weights.""" +"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -33,35 +35,47 @@ from transformers.models.glm4_moe import Glm4MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Glm4MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -73,19 +87,24 @@ class Glm4MoeMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -96,7 +115,6 @@ class Glm4MoeMLP(nn.Module): class Glm4MoE(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -115,8 +133,10 @@ class Glm4MoE(nn.Module): self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) # NOTE In the transformers implementation, the gate isn't an nn.Linear, # so we cannot use ReplicatedLinear here. # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260 @@ -127,7 +147,8 @@ class Glm4MoE(nn.Module): dtype=torch.float32, ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) # Load balancing settings. vllm_config = get_current_vllm_config() @@ -136,16 +157,29 @@ class Glm4MoE(nn.Module): self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) - self.experts = FusedMoE( + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -158,44 +192,41 @@ class Glm4MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - self.shared_experts = Glm4MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), - prefix=f"{prefix}.shared_experts", - ) + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits = 
self.gate(hidden_states.to(dtype=torch.float32)) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: + + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + assert shared_output is not None final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states * self.routed_scaling_factor + shared_output + ) + else: + final_hidden_states = fused_moe_out * self.routed_scaling_factor + + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) class Glm4MoeAttention(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -237,19 +268,23 @@ class Glm4MoeAttention(nn.Module): self.max_position_embeddings = max_position_embeddings self.use_qk_norm = use_qk_norm - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) self.rotary_emb = get_rope( @@ -282,10 +317,12 @@ class Glm4MoeAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_qk_norm: - q = self.q_norm(q.reshape(-1, self.num_heads, - self.head_dim)).reshape(q.shape) - k = self.k_norm(k.reshape(-1, self.num_kv_heads, - self.head_dim)).reshape(k.shape) + q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape( + q.shape + ) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape( + k.shape + ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) @@ -294,7 +331,6 @@ class Glm4MoeAttention(nn.Module): class Glm4MoeDecoderLayer(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -307,11 +343,10 @@ class Glm4MoeDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 131072) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
- layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx self.self_attn = Glm4MoeAttention( @@ -331,8 +366,10 @@ class Glm4MoeDecoderLayer(nn.Module): use_qk_norm=config.use_qk_norm, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + ): self.mlp = Glm4MoE( config=config, quant_config=quant_config, @@ -340,16 +377,18 @@ class Glm4MoeDecoderLayer(nn.Module): enable_eplb=enable_eplb, ) else: - self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -362,12 +401,9 @@ class Glm4MoeDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -378,9 +414,9 @@ class Glm4MoeDecoderLayer(nn.Module): "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Glm4MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -394,9 +430,8 @@ class Glm4MoeModel(nn.Module): if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=f"{prefix}.embed_tokens") + config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens" + ) else: self.embed_tokens = PPMissingLayer() @@ -409,15 +444,16 @@ class Glm4MoeModel(nn.Module): prefix=prefix, enable_eplb=enable_eplb, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -440,44 +476,42 @@ class Glm4MoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = 
self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -494,7 +528,7 @@ class Glm4MoeModel(nn.Module): spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -504,7 +538,7 @@ class Glm4MoeModel(nn.Module): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -539,14 +573,17 @@ class Glm4MoeModel(nn.Module): # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. 
- weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -570,8 +607,9 @@ class Glm4MoeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -599,25 +637,29 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Glm4MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Glm4MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -665,21 +707,19 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -687,13 +727,14 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): return self.model.get_expert_mapping() -def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): +def get_spec_layer_idx_from_weight_name( + config: Glm4MoeConfig, weight_name: str +) -> Optional[int]: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 
0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if f"layers.{layer_idx+i}." in weight_name: + if f"layers.{layer_idx + i}." in weight_name: return layer_idx + i return None diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 322c5619c1783..beb40632246c0 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -36,9 +36,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name @@ -47,24 +48,26 @@ from .utils import maybe_prefix class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, + prefix: str, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) class Glm4MoeMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -75,14 +78,16 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = Glm4MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + self.shared_head = SharedHead( + config=config, prefix=prefix, quant_config=quant_config + ) + self.mtp_block = Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) def forward( self, @@ -99,40 +104,46 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module): previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class Glm4MoeMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the 
exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - Glm4MoeMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -143,7 +154,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -155,26 +166,26 @@ class Glm4MoeMultiTokenPredictor(nn.Module): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits class Glm4MoeMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = Glm4MoeMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -185,21 +196,19 @@ class Glm4MoeMTP(nn.Module, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, sampling_metadata, - spec_step_idx) + return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -215,7 +224,8 @@ class 
Glm4MoeMTP(nn.Module, SupportsPP): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -224,7 +234,7 @@ class Glm4MoeMTP(nn.Module, SupportsPP): if spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -234,7 +244,7 @@ class Glm4MoeMTP(nn.Module, SupportsPP): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -254,11 +264,13 @@ class Glm4MoeMTP(nn.Module, SupportsPP): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -267,13 +279,16 @@ class Glm4MoeMTP(nn.Module, SupportsPP): # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. - if (spec_layer != self.model.mtp_start_layer_idx - and ".layers" not in name): + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -285,7 +300,11 @@ class Glm4MoeMTP(nn.Module, SupportsPP): and rename shared layer weights to be top level. """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] shared_weight_names = ["embed_tokens"] spec_layer_weight = False @@ -298,8 +317,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP): break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." 
+ ) elif shared_weight: # treat shared weights as top level weights name = name.replace(f"model.layers.{spec_layer}.", "model.") diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index bf33575859aea..63731b2947d2d 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,6 +4,8 @@ # Adapted from # https://github.com/zai-org/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" + +import itertools from argparse import Namespace from collections.abc import Mapping, Sequence from typing import Annotated, Literal, Optional, Union @@ -13,37 +15,50 @@ from torch import nn from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import BatchFeature, PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape from .chatglm import ChatGLMBaseModel, ChatGLMModel -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import flatten_bn, merge_multimodal_embeddings +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) class GLMVImagePixelInputs(TensorSchema): @@ -54,21 +69,22 @@ class GLMVImagePixelInputs(TensorSchema): - h: Height of image - w: Width of image """ + type: Literal["pixel_values"] = "pixel_values" data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] class EVA2CLIPPatchEmbedding(nn.Module): - def __init__(self, config): super().__init__() - self.proj = nn.Conv2d(config.in_channels, - config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size) + self.proj = nn.Conv2d( + config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + ) 
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) - self.position_embedding = nn.Embedding(config.num_positions, - config.hidden_size) + self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size) def forward(self, images: torch.Tensor) -> torch.Tensor: """ @@ -80,8 +96,7 @@ class EVA2CLIPPatchEmbedding(nn.Module): torch.Tensor Transformed tensor with shape (B, L, D) """ - images = images.to(device=self.proj.weight.device, - dtype=self.proj.weight.dtype) + images = images.to(device=self.proj.weight.device, dtype=self.proj.weight.dtype) x = self.proj(images) x = x.flatten(2).transpose(1, 2) cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) @@ -91,12 +106,11 @@ class EVA2CLIPPatchEmbedding(nn.Module): class EVA2CLIPAttention(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -119,8 +133,9 @@ class EVA2CLIPAttention(nn.Module): prefix=f"{prefix}.dense", ) - self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, - self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_rank, self.head_dim, self.scale + ) self.output_dropout = torch.nn.Dropout(config.dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -134,12 +149,11 @@ class EVA2CLIPAttention(nn.Module): class EVA2CLIPMLP(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() self.config = config @@ -165,29 +179,27 @@ class EVA2CLIPMLP(nn.Module): class EVA2CLIPTransformerLayer(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() - self.input_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.attention = EVA2CLIPAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.mlp = EVA2CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.post_attention_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = EVA2CLIPAttention( + config, quant_config=quant_config, prefix=f"{prefix}.attention" + ) + self.mlp = EVA2CLIPMLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.post_attention_layernorm = LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward(self, hidden_states): attention_input = hidden_states - attention_output = self.input_layernorm( - self.attention(attention_input)) + attention_output = self.input_layernorm(self.attention(attention_input)) hidden_states = attention_input + attention_output mlp_input = hidden_states mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) @@ -196,20 +208,23 @@ class EVA2CLIPTransformerLayer(nn.Module): class EVA2CLIPTransformer(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() - self.layers = nn.ModuleList([ - EVA2CLIPTransformerLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + EVA2CLIPTransformerLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in 
range(config.num_hidden_layers) + ] + ) def forward(self, hidden_states): for layer_module in self.layers: @@ -218,13 +233,12 @@ class EVA2CLIPTransformer(nn.Module): class EVA2CLIPGLU(nn.Module): - def __init__( self, config, in_features, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): """ The original implementation is the same as: @@ -233,14 +247,14 @@ class EVA2CLIPGLU(nn.Module): config.hidden_size, config.ffn_hidden_size, bias=False, - quant_config=quant_config + quant_config=quant_config, ) self.gate_proj = ColumnParallelLinear( config.hidden_size, config.ffn_hidden_size, bias=False, - quant_config=quant_config + quant_config=quant_config, ) ``` ``` @@ -255,7 +269,7 @@ class EVA2CLIPGLU(nn.Module): config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, - quant_config=quant_config + quant_config=quant_config, ) ``` ``` @@ -263,27 +277,32 @@ class EVA2CLIPGLU(nn.Module): ``` """ super().__init__() - self.linear_proj = ReplicatedLinear(in_features, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") + self.linear_proj = ReplicatedLinear( + in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj", + ) self.norm1 = nn.LayerNorm(config.hidden_size) self.act1 = nn.GELU() self.act2 = SiluAndMul() self.merged_proj = MergedColumnParallelLinear( - config.hidden_size, [config.ffn_hidden_size] * 2, + config.hidden_size, + [config.ffn_hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.merged_proj") + prefix=f"{prefix}.merged_proj", + ) self.dense_4h_to_h = RowParallelLinear( config.ffn_hidden_size, config.hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.dense_4h_to_h") + prefix=f"{prefix}.dense_4h_to_h", + ) def forward(self, x): x, _ = self.linear_proj(x) @@ -295,27 +314,30 @@ class EVA2CLIPGLU(nn.Module): class EVA2CLIPModel(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() vision_config = Namespace(**config.vision_config) self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config) - self.transformer = EVA2CLIPTransformer(vision_config, - quant_config=quant_config, - prefix=f"{prefix}.transformer") - self.linear_proj = EVA2CLIPGLU(config, - in_features=config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") - self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, - out_channels=config.hidden_size, - kernel_size=2, - stride=2) + self.transformer = EVA2CLIPTransformer( + vision_config, quant_config=quant_config, prefix=f"{prefix}.transformer" + ) + self.linear_proj = EVA2CLIPGLU( + config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj", + ) + self.conv = nn.Conv2d( + in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2, + ) self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.scaling_factor = vision_config.scaling_factor @@ -349,15 +371,14 @@ class EVA2CLIPModel(nn.Module): class GLM4VModel(ChatGLMModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) quant_config = vllm_config.quant_config - self.vision = EVA2CLIPModel(self.config, - quant_config, - prefix=f"{prefix}.vision") + self.vision = EVA2CLIPModel( + 
self.config, quant_config, prefix=f"{prefix}.vision" + ) class GLM4VProcessor: @@ -379,17 +400,19 @@ class GLM4VProcessor: vision_config = config.vision_config image_size = vision_config["image_size"] - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) def __call__( self, @@ -424,7 +447,6 @@ class GLM4VProcessor: class GLM4VProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(ChatGLMConfig) @@ -454,7 +476,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo): class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -466,6 +487,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ -473,16 +495,19 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - def _hf_processor_applies_updates( self, prompt_text: str, @@ -526,16 +551,20 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, - info=GLM4VProcessingInfo, - dummy_inputs=GLM4VDummyInputsBuilder) -class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, - SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor( + GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder, +) +class GLM4VForCausalLM( + ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + merge_by_field_config = True packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], - "merged_proj": ["gate_proj", "dense_h_to_4h"] + "merged_proj": ["gate_proj", "dense_h_to_4h"], } def get_mm_mapping(self) -> MultiModelKeys: @@ -545,7 +574,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, return MultiModelKeys.from_string_field( language_model="transformer.encoder", connector="transformer.vision.linear_proj", - tower_model="transformer.vision.transformer") + tower_model="transformer.vision.transformer", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -570,36 +600,175 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, self.transformer: GLM4VModel 
def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: + self, **kwargs: object + ) -> Optional[GLMVImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - expected_h = expected_w = self.config.vision_config["image_size"] - return GLMVImagePixelInputs(type="pixel_values", - data=flatten_bn(pixel_values, - concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return GLMVImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) return None - def _process_image_input( - self, image_input: GLMVImagePixelInputs) -> torch.Tensor: + def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor: pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) return self.transformer.vision(pixel_values) + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + 
torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.transformer - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + get_input_embeddings = SupportsMultiModal.get_input_embeddings + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -607,28 +776,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=[ - self.config.boi_token_id, - self.config.pad_token_id, - self.config.eoi_token_id, - ], - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -640,15 +787,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 98d76337395b9..53d6026c5938e 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -19,7 +19,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -30,28 +32,36 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( - get_pp_group, get_tensor_model_parallel_world_size) + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from ..layers.pooler import DispatchPooler, Pooler from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPT2Attention(nn.Module): - def __init__( self, config: GPT2Config, @@ -62,8 +72,7 @@ class GPT2Attention(nn.Module): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -84,12 +93,14 @@ class GPT2Attention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.c_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -103,7 +114,6 @@ class GPT2Attention(nn.Module): class GPT2MLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -137,7 +147,6 @@ class GPT2MLP(nn.Module): class GPT2Block(nn.Module): - def __init__( self, config: GPT2Config, @@ -147,19 +156,14 @@ class GPT2Block(nn.Module): ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPT2Attention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = 
GPT2MLP(inner_dim, - config, - quant_config, - prefix=f"{prefix}.mlp") + self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") def forward( self, @@ -181,7 +185,6 @@ class GPT2Block(nn.Module): @support_torch_compile class GPT2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -194,20 +197,22 @@ class GPT2Model(nn.Module): assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - self.wte = VocabParallelEmbedding(config.vocab_size, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.wte") + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.wte", + ) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: GPT2Block( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix: GPT2Block(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -228,7 +233,7 @@ class GPT2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: @@ -237,8 +242,7 @@ class GPT2Model(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -260,34 +264,35 @@ class GPT2Model(nn.Module): if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPT2LMHeadModel(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head") + self.transformer = GPT2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.transformer.wte) 
self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -299,21 +304,19 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) weights = _add_transformer_prefix(weights) return loader.load_weights(weights) @@ -336,19 +339,25 @@ class GPT2ForSequenceClassification(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.transformer = GPT2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "gpt2")) - self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + self.transformer = GPT2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2") + ) + self.score = nn.Linear( + config.n_embd, + config.num_labels, + bias=False, + dtype=vllm_config.model_config.head_dtype, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify(pooler_config, classifier=None), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify(pooler_config, classifier=self.score), + } + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -365,16 +374,15 @@ class GPT2ForSequenceClassification(nn.Module): input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - logits = self.score(hidden_states) - return logits + intermediate_tensors=intermediate_tensors, + ) + return hidden_states def _add_transformer_prefix( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: for name, tensor in weights: - if not name.startswith('transformer.') and not name.startswith( - "lm_head"): - name = 'transformer.' + name + if not name.startswith("transformer.") and not name.startswith("lm_head"): + name = "transformer." + name yield name, tensor diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 036ded530f97d..b6d3d8f3f2e60 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -20,7 +20,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -32,25 +34,31 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTBigCodeAttention(nn.Module): - def __init__( self, config: GPTBigCodeConfig, @@ -61,11 +69,9 @@ class GPTBigCodeAttention(nn.Module): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - self.tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % self.tensor_model_parallel_world_size == 0 - self.num_heads = (total_num_heads // - self.tensor_model_parallel_world_size) + self.num_heads = total_num_heads // self.tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 @@ -94,13 +100,15 @@ class GPTBigCodeAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.c_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -110,7 +118,8 @@ class GPTBigCodeAttention(nn.Module): q, k, v = qkv.split( [ self.hidden_size // self.tensor_model_parallel_world_size, - self.kv_dim, self.kv_dim + self.kv_dim, + self.kv_dim, ], dim=-1, ) @@ -120,7 +129,6 @@ class GPTBigCodeAttention(nn.Module): class GPTBigMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -154,7 +162,6 @@ class GPTBigMLP(nn.Module): class GPTBigCodeBlock(nn.Module): - def __init__( self, config: GPTBigCodeConfig, @@ -164,19 +171,14 @@ class GPTBigCodeBlock(nn.Module): ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = 
nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPTBigCodeAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, - config, - quant_config, - prefix=f"{prefix}.mlp") + self.mlp = GPTBigMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") def forward( self, @@ -184,7 +186,9 @@ class GPTBigCodeBlock(nn.Module): ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn(hidden_states=hidden_states, ) + attn_output = self.attn( + hidden_states=hidden_states, + ) # residual connection hidden_states = attn_output + residual @@ -198,7 +202,6 @@ class GPTBigCodeBlock(nn.Module): @support_torch_compile class GPTBigCodeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -211,23 +214,27 @@ class GPTBigCodeModel(nn.Module): assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab - self.wte = VocabParallelEmbedding(self.vocab_size, - self.embed_dim, - org_num_embeddings=config.vocab_size) + self.wte = VocabParallelEmbedding( + self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size + ) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: GPTBigCodeBlock( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -246,7 +253,7 @@ class GPTBigCodeModel(nn.Module): else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: @@ -254,8 +261,7 @@ class GPTBigCodeModel(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -266,13 +272,12 @@ class GPTBigCodeModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method if "c_attn.input_scale" in name: - weight_loader(param, loaded_weight, 'q') - 
weight_loader(param, loaded_weight, 'k') - weight_loader(param, loaded_weight, 'v') + weight_loader(param, loaded_weight, "q") + weight_loader(param, loaded_weight, "k") + weight_loader(param, loaded_weight, "v") else: weight_loader(param, loaded_weight) loaded_params.add(name) @@ -292,23 +297,27 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = GPTBigCodeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: self.lm_head = ParallelLMHead( self.transformer.vocab_size, self.transformer.embed_dim, - org_num_embeddings=self.config.vocab_size) + org_num_embeddings=self.config.vocab_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -320,21 +329,19 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = None if self.config.tie_word_embeddings: skip_prefixes = ["lm_head."] diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index bd162a5e57bc1..5428512dec195 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -18,7 +18,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -30,27 +32,35 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTJAttention(nn.Module): - def __init__( self, config: GPTJConfig, @@ -85,8 +95,7 @@ class GPTJAttention(nn.Module): assert getattr(config, "rotary", True) assert config.rotary_dim % 2 == 0 rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, rotary_dim=config.rotary_dim, @@ -94,12 +103,14 @@ class GPTJAttention(nn.Module): base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -115,7 +126,6 @@ class GPTJAttention(nn.Module): class GPTJMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -144,7 +154,6 @@ class GPTJMLP(nn.Module): class GPTJBlock(nn.Module): - def __init__( self, config: GPTJConfig, @@ -153,13 +162,11 @@ class GPTJBlock(nn.Module): prefix: str = "", ): super().__init__() - inner_dim = (4 * config.n_embd - if config.n_inner is None else config.n_inner) + inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPTJAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -180,7 +187,6 @@ class GPTJBlock(nn.Module): @support_torch_compile class GPTJModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
super().__init__() @@ -197,14 +203,13 @@ class GPTJModel(nn.Module): ) self.start_layer, self.end_layer, self.h = make_layers( config.n_layer, - lambda prefix: GPTJBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: GPTJBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -223,15 +228,14 @@ class GPTJModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -246,19 +250,20 @@ class GPTJModel(nn.Module): if "attn.bias" in name or "attn.masked_bias" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -281,15 +286,13 @@ class GPTJModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPTJForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -297,18 +300,20 @@ class GPTJForCausalLM(nn.Module, SupportsPP): self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = GPTJModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, bias=True, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) 
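# Illustrative sketch, not part of the patch: the `stacked_params_mapping` loop
# in GPTJModel.load_weights above folds separate q_proj/k_proj/v_proj checkpoint
# tensors into one fused qkv_proj parameter, with a shard id telling the loader
# which block to fill. `load_fused_qkv` below mimics that with plain tensors and
# hypothetical names; the real path goes through each parameter's weight_loader.
import torch


def load_fused_qkv(
    params: dict[str, torch.Tensor],
    weights: list[tuple[str, torch.Tensor]],
) -> None:
    mapping = [
        ("qkv_proj", "q_proj", 0),
        ("qkv_proj", "k_proj", 1),
        ("qkv_proj", "v_proj", 2),
    ]
    for name, w in weights:
        for fused, shard_name, shard_idx in mapping:
            if shard_name not in name:
                continue
            fused_name = name.replace(shard_name, fused)
            rows = w.shape[0]
            # Each shard occupies a contiguous row block of the fused weight.
            params[fused_name][shard_idx * rows:(shard_idx + 1) * rows].copy_(w)
            break


if __name__ == "__main__":
    dim = 4
    params = {"layer.qkv_proj.weight": torch.zeros(3 * dim, dim)}
    ckpt = [
        (f"layer.{n}.weight", torch.full((dim, dim), float(i)))
        for i, n in enumerate(["q_proj", "k_proj", "v_proj"])
    ]
    load_fused_qkv(params, ckpt)
    print(params["layer.qkv_proj.weight"][::dim, 0])  # tensor([0., 1., 2.])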
self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -320,20 +325,18 @@ class GPTJForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) + logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index d418d8bb86cee..8278ae03d88a5 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -18,7 +18,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -30,26 +32,32 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTNeoXAttention(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -63,11 +71,9 @@ class GPTNeoXAttention(nn.Module): self.head_size = self.hidden_size // self.total_num_heads self.bias = getattr(config, "attention_bias", 
True) - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.query_key_value = QKVParallelLinear( config.hidden_size, @@ -86,20 +92,21 @@ class GPTNeoXAttention(nn.Module): rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, rotary_dim=rotary_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -115,7 +122,6 @@ class GPTNeoXAttention(nn.Module): class GPTNeoXMLP(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -142,7 +148,6 @@ class GPTNeoXMLP(nn.Module): class GPTNeoXLayer(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -152,14 +157,15 @@ class GPTNeoXLayer(nn.Module): ): super().__init__() self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attention") + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.attention = GPTNeoXAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attention" + ) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -192,7 +198,6 @@ class GPTNeoXLayer(nn.Module): @support_torch_compile class GPTNeoXModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -209,14 +214,16 @@ class GPTNeoXModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: GPTNeoXLayer( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_in(input_ids) @@ -235,23 +242,24 @@ class GPTNeoXModel(nn.Module): hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for 
layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layer_norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if ("attention.bias" in name or "attention.masked_bias" in name - or "rotary_emb.inv_freq" in name): + if ( + "attention.bias" in name + or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name + ): continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using OpenRLHF may include # these tensors in the checkpoint. Skip them. continue @@ -269,39 +277,41 @@ class GPTNeoXModel(nn.Module): if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPTNeoXForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "gpt_neox")) + self.gpt_neox = GPTNeoXModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt_neox") + ) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_out"), ) if self.config.tie_word_embeddings: self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.gpt_neox.make_empty_intermediate_tensors) + self.gpt_neox.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.gpt_neox.get_input_embeddings(input_ids) @@ -313,20 +323,18 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.gpt_neox(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.gpt_neox( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.embed_out, hidden_states, - sampling_metadata) + logits = 
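# Illustrative sketch, not part of the patch: GPT-NeoX checkpoints store
# query_key_value with Q/K/V interleaved per head along the output dimension;
# the view/transpose/reshape in the hunk above regroups it so Q, K and V become
# contiguous blocks. A minimal demonstration on a fake weight (toy sizes):
import torch

num_heads, head_dim, hidden = 2, 3, 6
w = torch.arange(num_heads * 3 * head_dim * hidden, dtype=torch.float32)
w = w.reshape(num_heads * 3 * head_dim, hidden)  # checkpoint layout

shape = w.shape
out_dim = 0  # rows hold the stacked heads
w = w.view(shape[:out_dim] + (num_heads, 3, -1) + shape[out_dim + 1:])
w = w.transpose(out_dim, out_dim + 1)  # [3, num_heads, head_dim, hidden]
w = w.reshape(shape)  # Q rows first, then K, then V

q, k, v = w.chunk(3, dim=0)
print(q.shape, k.shape, v.shape)  # each torch.Size([6, 6])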
self.logits_processor(self.embed_out, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index cd93f0ef1e310..17f9114350798 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -11,28 +11,41 @@ from transformers import GptOssConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, - maybe_prefix) +from .interfaces import SupportsEagle3, SupportsPP +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OAIAttention(nn.Module): - def __init__( self, config: GptOssConfig, @@ -54,16 +67,13 @@ class OAIAttention(nn.Module): base=config.rope_theta, dtype=torch.float32, rope_scaling={ - "rope_type": - "yarn", - "factor": - config.rope_scaling["factor"], - "original_max_position_embeddings": - config.rope_scaling["original_max_position_embeddings"], - "beta_fast": - config.rope_scaling["beta_fast"], - "beta_slow": - config.rope_scaling["beta_slow"], + "rope_type": "yarn", + "factor": config.rope_scaling["factor"], + "original_max_position_embeddings": config.rope_scaling[ + "original_max_position_embeddings" + ], + "beta_fast": config.rope_scaling["beta_fast"], + "beta_slow": config.rope_scaling["beta_slow"], }, is_neox_style=True, ) @@ -71,11 +81,8 @@ class OAIAttention(nn.Module): tp_size = get_tensor_model_parallel_world_size() self.sinks = torch.nn.Parameter( - torch.empty(config.num_attention_heads // tp_size, - dtype=torch.bfloat16, - requires_grad=False)) - - self.norm = RMSNorm(config.hidden_size, eps=1e-5) + torch.empty(config.num_attention_heads // tp_size, requires_grad=False) + ) self.q_size = self.num_attention_heads * self.head_dim // tp_size self.kv_size = self.num_key_value_heads * self.head_dim // tp_size 
@@ -102,8 +109,7 @@ class OAIAttention(nn.Module): self.num_local_key_value_heads = config.num_key_value_heads // tp_size # Only apply sliding window to every other layer - sliding_window = (config.sliding_window if self.layer_idx % - 2 == 0 else None) + sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None self.attn = Attention( self.num_local_attention_heads, self.head_dim, @@ -117,84 +123,108 @@ class OAIAttention(nn.Module): sinks=self.sinks, ) - def forward(self, hidden_states: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - t = self.norm(hidden_states) - - qkv, _ = self.qkv(t) + def forward( + self, hidden_states: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: + qkv, _ = self.qkv(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) v = v.contiguous() attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) - - return output + hidden_states + return output class MLPBlock(torch.nn.Module): - def __init__( self, - config: GptOssConfig, + vllm_config: VllmConfig, layer_idx: int, - quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + self.layer_idx = layer_idx self.num_experts = config.num_local_experts self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.norm = RMSNorm(config.hidden_size, eps=1e-5) - self.router = torch.nn.Linear(config.hidden_size, - config.num_local_experts, - dtype=torch.bfloat16) + self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) assert config.intermediate_size % self.world_size == 0 - self.experts = FusedMoE(num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - prefix=f"{prefix}.experts", - apply_router_weight_on_input=False, - has_bias=True, - activation="swigluoai") + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False, + has_bias=True, + activation="swigluoai", + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - t = self.norm(x) - g = self.router(t) - t = self.experts(hidden_states=t, router_logits=g) - return x + t + num_tokens = x.shape[0] + if self.is_sequence_parallel: + x = sequence_parallel_chunk(x) + + g = self.router(x) + x = self.experts(hidden_states=x, router_logits=g) + + if self.is_sequence_parallel: + x = tensor_model_parallel_all_gather(x.contiguous(), 0) + x = x[:num_tokens] + return x class TransformerBlock(torch.nn.Module): - def __init__( self, - config: GptOssConfig, - quant_config: QuantizationConfig, + vllm_config: VllmConfig, prefix: str = "", ): super().__init__() - self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, prefix=f"{prefix}.attn") - self.mlp = MLPBlock(config, - self.layer_idx, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - def 
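# Illustrative sketch, not part of the patch: with `is_sequence_parallel`, the
# rewritten MLPBlock.forward splits the token dimension across ranks before the
# router/experts and all-gathers afterwards, trimming back to `num_tokens`.
# `sp_moe` is a single-process toy of that shape bookkeeping; the real code uses
# vLLM's sequence_parallel_chunk / tensor_model_parallel_all_gather helpers.
import torch


def sp_moe(x: torch.Tensor, tp_size: int, expert_fn) -> torch.Tensor:
    num_tokens = x.shape[0]
    # Pad so the token dimension splits evenly, as a per-rank chunk would need.
    pad = (-num_tokens) % tp_size
    x_padded = torch.nn.functional.pad(x, (0, 0, 0, pad))
    chunks = x_padded.chunk(tp_size, dim=0)   # one chunk per "rank"
    outs = [expert_fn(c) for c in chunks]     # local expert compute
    gathered = torch.cat(outs, dim=0)         # stand-in for the all-gather
    return gathered[:num_tokens]              # drop the padding rows


if __name__ == "__main__":
    x = torch.randn(5, 8)
    y = sp_moe(x, tp_size=2, expert_fn=lambda t: t * 2)
    print(torch.allclose(y, x * 2))  # True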
forward(self, hidden_states: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - attn_output = self.attn(hidden_states, positions) - output = self.mlp(attn_output) - return output + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention( + config, prefix=f"{prefix}.attn", cache_config=cache_config + ) + self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.attn(hidden_states, positions) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + output = self.mlp(hidden_states) + return output, residual @support_torch_compile class GptOssModel(nn.Module): - def __init__( self, *, @@ -203,28 +233,60 @@ class GptOssModel(nn.Module): ): super().__init__() self.config = vllm_config.model_config.hf_config - self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, ) - self.layers = torch.nn.ModuleList([ - TransformerBlock( - self.config, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, f"block.{layer_idx}"), - ) for layer_idx in range(self.config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: TransformerBlock( + vllm_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) + self.aux_hidden_state_layers = tuple[int, ...]() - def forward(self, input_ids: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: - x = self.embedding(input_ids) - for layer in self.layers: - x = layer(x, positions) - x = self.norm(x) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + x = inputs_embeds + else: + x = self.get_input_embeddings(input_ids) + + residual = None + else: + assert intermediate_tensors is not None + x = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states = [] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + if i in self.aux_hidden_state_layers: + aux_hidden_states.append(x if residual is None else x + residual) + x, residual = layer(x, positions, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": x, "residual": residual}) + x, _ = self.norm(x, residual) + + if len(aux_hidden_states) > 0: + 
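# Illustrative sketch, not part of the patch: the rewritten TransformerBlock
# threads a `residual` tensor through the block, relying on a norm that adds the
# incoming residual and returns (normalised, updated residual); the model's
# final `self.norm(x, residual)` folds the last residual back in. ToyRMSNorm is
# a hypothetical norm with that same calling convention.
from typing import Optional

import torch
import torch.nn as nn


class ToyRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x = x + residual  # fold in the skip connection first
            residual = x      # the sum becomes the next residual
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed * self.weight
        return normed if residual is None else (normed, residual)


if __name__ == "__main__":
    norm = ToyRMSNorm(4)
    h = torch.randn(2, 4)
    first = norm(h)                            # first block: no residual yet
    normed, res = norm(h, torch.zeros_like(h))
    print(first.shape, normed.shape, res.shape)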
return x, aux_hidden_states return x def _load_weights_mxfp4( @@ -248,17 +310,18 @@ class GptOssModel(nn.Module): intermediate_size = self.config.intermediate_size intermediate_size_block = intermediate_size // mxfp4_block - per_rank_intermediate_size_block = cdiv(intermediate_size_block, - tp_size) - per_rank_intermediate_size = (per_rank_intermediate_size_block * - mxfp4_block) + per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) + per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # FIXME(woosuk): Remove this after testing. weight = weight.cuda() @@ -267,18 +330,17 @@ class GptOssModel(nn.Module): if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_weight_scale" in name: @@ -286,66 +348,68 @@ class GptOssModel(nn.Module): if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] + narrow_weight = weight[ + ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block + ] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w13_weight" in name: # Handle MLP gate and up projection weights # flat weight from (E, 2 * N, block_size, entry_per_block) # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() + weight = weight.view( + num_experts, 2 * intermediate_size, -1 + ).contiguous() # Extract gate and up projection parts # since the weight is shuffled, we can slice directly if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] 
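# Illustrative sketch, not part of the patch: the MXFP4 loader above rounds each
# rank's slice of the intermediate dimension up to a whole number of 32-wide
# MXFP4 blocks before computing start/end, so no quantization block is split
# across tensor-parallel ranks. Plain-Python version of that arithmetic
# (`cdiv` is just ceiling division; the 2880 below is a hypothetical size):
def cdiv(a: int, b: int) -> int:
    return -(a // -b)


def tp_bounds(intermediate_size: int, tp_size: int, tp_rank: int,
              block: int = 32) -> tuple[int, int]:
    per_rank_blocks = cdiv(intermediate_size // block, tp_size)
    per_rank = per_rank_blocks * block
    start = tp_rank * per_rank
    end = min((tp_rank + 1) * per_rank, intermediate_size)
    return start, end


if __name__ == "__main__":
    print([tp_bounds(2880, 4, r) for r in range(4)])
    # [(0, 736), (736, 1472), (1472, 2208), (2208, 2880)]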
param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_weight" in name: # Handle MLP down projection weights # same flatten here, but since 2 mx4 value are packed in 1 # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() + weight = weight.view( + num_experts, -1, intermediate_size // 2 + ).contiguous() if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] + narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w13_bias" in name: @@ -354,35 +418,32 @@ class GptOssModel(nn.Module): if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_bias" in name: # Handle MLP down projection bias param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] 
else: # (only load on rank 0 to avoid duplication) if tp_rank != 0: weight.zero_() - weight_loader(param, - weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader( + param, weight, weight_name=name, shard_id=None, expert_id=None + ) loaded_params.add(name) continue elif "sinks" in name: @@ -397,8 +458,7 @@ class GptOssModel(nn.Module): continue name = name.replace(weight_name, param_name) param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: @@ -409,8 +469,7 @@ class GptOssModel(nn.Module): if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params @@ -436,18 +495,20 @@ class GptOssModel(nn.Module): per_rank_intermediate_size = cdiv(intermediate_size, tp_size) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ".w13_weight" in name: # Handle MLP gate and up projection weights # Extract gate and up projection parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, :, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() param = params_dict[name] @@ -473,8 +534,7 @@ class GptOssModel(nn.Module): if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
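# Illustrative sketch, not part of the patch: `w2_bias` is zeroed on every rank
# except 0, presumably because the down projection's partial outputs are summed
# across tensor-parallel ranks and a bias loaded on all ranks would be added
# tp_size times. A single-process simulation of that reduction:
import torch

tp_size, dim = 4, 3
bias = torch.ones(dim)
partial_outputs = [torch.zeros(dim) for _ in range(tp_size)]  # per-rank partials

# Each rank contributes its partial plus the bias it loaded; only rank 0 kept
# a non-zero bias, so the reduced result contains the bias exactly once.
reduced = sum(out + (bias if rank == 0 else torch.zeros(dim))
              for rank, out in enumerate(partial_outputs))
print(reduced)  # tensor([1., 1., 1.])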
else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] param.copy_(narrow_weight) @@ -504,8 +564,7 @@ class GptOssModel(nn.Module): continue name = name.replace(weight_name, param_name) param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: @@ -516,14 +575,12 @@ class GptOssModel(nn.Module): if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv", ".q_proj", "q"), @@ -545,41 +602,48 @@ class GptOssModel(nn.Module): ep_rank_start = ep_rank * experts_per_rank ep_rank_end = (ep_rank + 1) * experts_per_rank - quant_method = (self.config.quantization_config['quant_method'] if - hasattr(self.config, "quantization_config") else None) + quant_method = ( + self.config.quantization_config["quant_method"] + if hasattr(self.config, "quantization_config") + else None + ) if quant_method == "mxfp4": - return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, - heads_per_rank, head_start, - weights, stacked_params_mapping) + return self._load_weights_mxfp4( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) else: - return self._load_weights_other(ep_rank_end, ep_rank_start, - heads_per_rank, head_start, - weights, stacked_params_mapping) + return self._load_weights_other( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) -class GptOssForCausalLM(nn.Module): +class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ ".self_attn.": ".attn.", - ".post_attention_layernorm.": ".mlp.norm.", }, orig_to_new_suffix={ ".embed_tokens.weight": ".embedding.weight", - ".input_layernorm.weight": ".attn.norm.weight", - ".post_attention_layernorm.weight": ".mlp.norm.weight", - # MoE MXFP4 weights ".gate_up_proj_blocks": ".w13_weight", ".down_proj_blocks": ".w2_weight", ".gate_up_proj_scales": ".w13_weight_scale", ".down_proj_scales": ".w2_weight_scale", - # MoE other weights ".gate_up_proj": ".w13_weight", ".down_proj": ".w2_weight", - # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", @@ -602,29 +666,39 @@ class GptOssForCausalLM(nn.Module): self.lm_head = ParallelLMHead( self.config.vocab_size, self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: - assert intermediate_tensors is None - assert inputs_embeds is None - return 
self.model(input_ids, positions) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 507a9206c4281..e9bc592c0797b 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -36,27 +38,36 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_layers, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) class GraniteMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -72,15 +83,19 @@ class GraniteMLP(nn.Module): output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -91,7 +106,6 @@ class GraniteMLP(nn.Module): class GraniteAttention(nn.Module): - def __init__( self, config: GraniteConfig, @@ -156,13 +170,15 @@ class GraniteAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -178,7 +194,6 @@ class GraniteAttention(nn.Module): class GraniteDecoderLayer(nn.Module): - def __init__( self, config: GraniteConfig, @@ -192,21 +207,24 @@ class GraniteDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = GraniteAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -224,10 +242,10 @@ class GraniteDecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -252,7 +270,6 @@ class GraniteDecoderLayer(nn.Module): @support_torch_compile class GraniteModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -263,12 +280,16 @@ class GraniteModel(nn.Module): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -276,18 +297,22 @@ class GraniteModel(nn.Module): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger 
padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GraniteDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: GraniteDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -308,28 +333,26 @@ class GraniteModel(nn.Module): hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) - residual = None hidden_states *= self.config.embedding_multiplier else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -341,18 +364,19 @@ class GraniteModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -381,8 +405,7 @@ class GraniteModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -418,8 +441,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = GraniteModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GraniteModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if 
get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -431,8 +455,10 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight @@ -441,9 +467,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if hasattr(config, "logits_scaling"): logit_scale /= config.logits_scaling - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, scale=logit_scale + ) else: self.lm_head = PPMissingLayer() @@ -457,38 +483,31 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - skip_prefixes = (["lm_head."] - if self.config.tie_word_embeddings else None) + skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index c3ac3bb78c83d..82bceaf3ed019 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
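The granite_speech.py processor change below replaces a per-row Python loop over `input_ids` with a single vectorized comparison when computing `audio_embed_sizes`. A minimal, self-contained sketch of that counting pattern (the token id below is made up; the real value comes from the HF config's `audio_token_index`):

import torch

# Hypothetical placeholder id; the model reads the real value from its HF config.
AUDIO_TOKEN_INDEX = 49155

# Two already-padded "tokenized" prompts.
input_ids = torch.tensor([
    [1, AUDIO_TOKEN_INDEX, AUDIO_TOKEN_INDEX, 7, 2],
    [1, 5, AUDIO_TOKEN_INDEX, 2, 0],
])

# Old style: one Python-level reduction per row.
sizes_loop = [int((row == AUDIO_TOKEN_INDEX).sum()) for row in input_ids]

# New style: a single tensor reduction over the last dimension.
sizes_vec = (input_ids == AUDIO_TOKEN_INDEX).sum(-1)

assert sizes_loop == sizes_vec.tolist()  # [2, 1]

Either form yields one count per batch row, which is what later splits the merged audio embeddings back apart.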
"""Inference-only IBM Granite speech model.""" + import math from collections.abc import Iterable, Mapping from typing import Annotated, Optional, Union @@ -33,35 +34,46 @@ from torch import nn from transformers import BatchFeature, PretrainedConfig from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip2 import Blip2QFormerModel -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, embed_multimodal, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix ### Audio Input class GraniteSpeechAudioInputs(TensorSchema): """ Audio input features for Granite Speech model. - + Dimensions: - b: Batch size - fi: Number of input features from the Mel spectrogram. @@ -80,7 +92,6 @@ class GraniteSpeechAudioInputs(TensorSchema): class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1} @@ -97,8 +108,8 @@ class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): ### Input Processing & Multimodal utils class GraniteSpeechMultiModalProcessor( - BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]): - + BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo] +): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_hf_processor().audio_processor sampling_rate = feature_extractor.melspec_kwargs["sample_rate"] @@ -134,7 +145,8 @@ class GraniteSpeechMultiModalProcessor( audio = audios.get(item_idx) audio_length = audio.shape[-1] num_projector_features = feature_extractor._get_num_audio_features( - [audio_length])[0] + [audio_length] + )[0] return [audio_token_id] * num_projector_features return [ @@ -170,28 +182,30 @@ class GraniteSpeechMultiModalProcessor( # Calculate the number of audio tokens per entry in the batch; # This is used to split the batch back out after padding. 
audio_token_index = self.info.get_hf_config().audio_token_index - processed_outputs["audio_embed_sizes"] = [ - torch.sum(indices == audio_token_index).item() - for indices in processed_outputs["input_ids"] - ] + processed_outputs["audio_embed_sizes"] = ( + processed_outputs["input_ids"] == audio_token_index + ).sum(-1) return processed_outputs class GraniteSpeechDummyInputsBuilder( - BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]): - + BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo] +): def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "audio": - self._get_dummy_audios( + "audio": self._get_dummy_audios( length=self.info.get_max_audio_len(), num_audios=num_audios, + overrides=audio_overrides, ) } @@ -204,7 +218,6 @@ class GraniteSpeechDummyInputsBuilder( ### QFormer Projector class GraniteSpeechEncoderProjector(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -219,8 +232,8 @@ class GraniteSpeechEncoderProjector(nn.Module): self.num_queries = config.window_size // config.downsample_rate self.query = nn.Parameter( - torch.zeros(1, self.num_queries, - config.projector_config.hidden_size)) + torch.zeros(1, self.num_queries, config.projector_config.hidden_size) + ) # NOTE - this is implemented generically in transformers, # but for now we create the QFormer model directly since @@ -231,17 +244,16 @@ class GraniteSpeechEncoderProjector(nn.Module): cache_config=cache_config, prefix=f"{prefix}.qformer", ) - self.linear = nn.Linear(config.projector_config.hidden_size, - config.text_config.hidden_size) + self.linear = nn.Linear( + config.projector_config.hidden_size, config.text_config.hidden_size + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = hidden_states.size() nblocks = math.ceil(seq_len / self.window_size) pad = nblocks * self.window_size - seq_len - hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), - "constant", 0) - hidden_states = hidden_states.view(batch_size * nblocks, - self.window_size, dim) + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim) last_hidden_state = self.qformer( query_embeds=self.query.data, @@ -253,7 +265,8 @@ class GraniteSpeechEncoderProjector(nn.Module): batch_size, nblocks * self.window_size // self.downsample_rate, -1, - )) + ) + ) return query_proj @@ -263,10 +276,12 @@ class GraniteSpeechEncoderProjector(nn.Module): class GraniteSpeechConformerFeedForward(nn.Module): """Feedforward module for conformer encoder blocks.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.pre_norm = nn.LayerNorm(config.hidden_dim) @@ -312,16 +327,16 @@ class GraniteSpeechConformerAttention(nn.Module): self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, config.hidden_dim) - self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, - self.dim_head) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, 
self.dim_head) if self.context_size <= 0 or self.context_size > self.max_pos_emb: raise ValueError( "Context size is either less than 0 or exceeds the max_pos_emb" ) - def forward(self, hidden_states: torch.Tensor, - attention_dists: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_dists: torch.Tensor + ) -> torch.Tensor: hidden_states = self.pre_norm(hidden_states) bsz, num_features, _ = hidden_states.shape @@ -330,47 +345,53 @@ class GraniteSpeechConformerAttention(nn.Module): if remainder > 0: # right padding to reach block size hidden_states = torch.nn.functional.pad( - hidden_states, (0, 0, 0, self.context_size - remainder)) + hidden_states, (0, 0, 0, self.context_size - remainder) + ) # NOTE: would be nice to try to use qkvparallellinear # here for this block attention implementation if possible query_states = self.to_q(hidden_states) key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) - query_states = query_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, - -1).transpose(2, 3) - key_states = key_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, -1).transpose(2, 3) - value_states = value_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, - -1).transpose(2, 3) + query_states = query_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) + key_states = key_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) + value_states = value_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) # shaw's relative positional embedding dist = attention_dists.to(hidden_states.device) rel_pos_emb = self.rel_pos_emb(dist) - rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + - list(rel_pos_emb.shape)) - pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, - dim=-1) * self.scale + rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) + pos_attn = ( + torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) + * self.scale + ) if remainder > 0: # masked attention in the extended block - mask = torch.ones(self.context_size, - self.context_size, - dtype=bool, - device=hidden_states.device) + mask = torch.ones( + self.context_size, + self.context_size, + dtype=bool, + device=hidden_states.device, + ) mask[:remainder, :remainder] = 0 mask_value = -torch.finfo(pos_attn.dtype).max pos_attn[:, -1, :].masked_fill_(mask, mask_value) - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH): - out = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - attn_mask=pos_attn, - scale=self.scale) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=pos_attn, + scale=self.scale, + ) out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) return self.to_out(out[:, :num_features, :]) @@ -378,22 +399,16 @@ class GraniteSpeechConformerAttention(nn.Module): class GraniteSpeechConformerDepthWiseConv1d(nn.Module): """Wrapper for padded 1D pointwise convolution.""" - def __init__(self, - chan_in: int, - chan_out: int, - kernel_size: int, - prefix: str = ""): + def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""): super().__init__() # Padding for the 1D conv is symmetric or close (i.e., offset by one). 
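The comment above describes the "same length" padding rule used by the depthwise convolution wrapper that follows: pad by kernel_size // 2 on the left, and by one element less on the right when the kernel is even. A standalone sketch of that arithmetic, assuming stride 1 (the only case the wrapper handles):

import torch
import torch.nn.functional as F

def same_length_depthwise_conv(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Depthwise 1D conv (stride 1) whose output keeps the input length.

    x has shape (batch, channels, time). For odd kernels the padding is
    symmetric; for even kernels the right side gets one element less.
    """
    channels = x.shape[1]
    pad = kernel_size // 2
    pad_offset = (kernel_size + 1) % 2          # 1 for even kernels, 0 for odd
    x = F.pad(x, (pad, pad - pad_offset))       # (left, right) padding on the time dim
    conv = torch.nn.Conv1d(channels, channels, kernel_size,
                           groups=channels, bias=False)
    return conv(x)

x = torch.randn(2, 4, 37)
assert same_length_depthwise_conv(x, 9).shape == x.shape   # odd kernel
assert same_length_depthwise_conv(x, 8).shape == x.shape   # even kernel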
pad = kernel_size // 2 pad_offset = (kernel_size + 1) % 2 self.padding = (pad, pad - pad_offset) - self.conv = nn.Conv1d(chan_in, - chan_out, - kernel_size, - groups=chan_in, - bias=False) + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size, groups=chan_in, bias=False + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = F.pad(hidden_states, self.padding) @@ -438,21 +453,19 @@ class GraniteSpeechConformerBlock(nn.Module): def __init__(self, config: PretrainedConfig, prefix: str = ""): super().__init__() - self.ff1 = GraniteSpeechConformerFeedForward(config, - prefix=f"{prefix}.ff1") - self.attn = GraniteSpeechConformerAttention(config, - prefix=f"{prefix}.attn") - self.conv = GraniteSpeechConformerConvModule(config, - prefix=f"{prefix}.conv") - self.ff2 = GraniteSpeechConformerFeedForward(config, - prefix=f"{prefix}.ff2") + self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1") + self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn") + self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv") + self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2") self.post_norm = nn.LayerNorm(config.hidden_dim) - def forward(self, hidden_states: torch.Tensor, - attention_dists: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_dists: torch.Tensor + ) -> torch.Tensor: hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states - hidden_states = self.attn( - hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = ( + self.attn(hidden_states, attention_dists=attention_dists) + hidden_states + ) hidden_states = self.conv(hidden_states) + hidden_states hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states hidden_states = self.post_norm(hidden_states) @@ -462,29 +475,33 @@ class GraniteSpeechConformerBlock(nn.Module): class GraniteSpeechCTCEncoder(nn.Module): """CTC Encoder comprising conformer blocks and additional linear layers.""" - def __init__(self, - config: PretrainedConfig, - prefix: str, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config: PretrainedConfig, + prefix: str, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config # Precompute clamped relative positional encoding distances seq = torch.arange(config.context_size) relpos_dist = seq.view(-1, 1) - seq.view(1, -1) - self.attention_dists = torch.clamp( - relpos_dist, -config.context_size, - config.context_size) + config.max_pos_emb + self.attention_dists = ( + torch.clamp(relpos_dist, -config.context_size, config.context_size) + + config.max_pos_emb + ) - self.input_linear = nn.Linear(config.input_dim, - config.hidden_dim, - bias=True) - self.layers = nn.ModuleList([ - GraniteSpeechConformerBlock( - config, - prefix=f"{prefix}.layers.{idx}", - ) for idx in range(config.num_layers) - ]) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList( + [ + GraniteSpeechConformerBlock( + config, + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(config.num_layers) + ] + ) self.out = ColumnParallelLinear( input_size=config.hidden_dim, @@ -507,8 +524,7 @@ class GraniteSpeechCTCEncoder(nn.Module): def forward(self, hidden_states: torch.Tensor): hidden_states = self.input_linear(hidden_states) for idx, layer in enumerate(self.layers, start=1): - hidden_states = layer(hidden_states, - attention_dists=self.attention_dists) + 
hidden_states = layer(hidden_states, attention_dists=self.attention_dists) if idx == self.num_layers // 2: hidden_states_mid = hidden_states.clone() @@ -522,13 +538,15 @@ class GraniteSpeechCTCEncoder(nn.Module): @MULTIMODAL_REGISTRY.register_processor( GraniteSpeechMultiModalProcessor, info=GraniteSpeechMultiModalProcessingInfo, - dummy_inputs=GraniteSpeechDummyInputsBuilder) + dummy_inputs=GraniteSpeechDummyInputsBuilder, +) class GraniteSpeechForConditionalGeneration( - nn.Module, - SupportsMultiModal, - SupportsPP, - SupportsLoRA, + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, ): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": [ @@ -549,7 +567,7 @@ class GraniteSpeechForConditionalGeneration( raise ValueError("Only audio modality is supported") - def __init__(self, *, vllm_config: VllmConfig, prefix: str): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -582,7 +600,8 @@ class GraniteSpeechForConditionalGeneration( ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_audio_input( self, @@ -600,17 +619,21 @@ class GraniteSpeechForConditionalGeneration( # from the processor, but we handle rebuilding it here since # vLLM generally processes everything independently + batches. if input_features_mask is None: - input_features_mask = self._build_input_features_mask( - audio_embed_sizes) + input_features_mask = self._build_input_features_mask(audio_embed_sizes) if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_features)}" + ) if input_features_mask is not None and not isinstance( - input_features_mask, torch.Tensor): - raise ValueError("Incorrect type of audio input features mask. " - f"Got type: {type(input_features_mask)}") + input_features_mask, torch.Tensor + ): + raise ValueError( + "Incorrect type of audio input features mask. " + f"Got type: {type(input_features_mask)}" + ) if isinstance(input_features, torch.Tensor): # Granite speech currently only allows one audio token per instance @@ -623,16 +646,17 @@ class GraniteSpeechForConditionalGeneration( if len(input_features.shape) != 3: raise ValueError( "Squeezed input features should be 3D but are of shape " - f"{input_features.shape}") - input_features = input_features.to( - self.encoder.input_linear.weight.dtype) + f"{input_features.shape}" + ) + input_features = input_features.to(self.encoder.input_linear.weight.dtype) else: # Otherwise we have a list of tensors, which are almost certainly # differing in their respective numbers of audio features; # stack them into a 3D tensor of size [bsz, most_num_features, 160]. input_features = self._pad_and_stack_input_features( - input_features, ).to(self.encoder.input_linear.weight.dtype) + input_features, + ).to(self.encoder.input_linear.weight.dtype) return GraniteSpeechAudioInputs( input_features=input_features, @@ -704,7 +728,7 @@ class GraniteSpeechForConditionalGeneration( audio_input: GraniteSpeechAudioInputs, ) -> tuple[torch.Tensor]: """Compute the audio features to be merged into the LLM embeddings. 
- + Args: audio_input: GraniteSpeechAudioInputs Audio inputs object containing Mel features, an input features @@ -721,6 +745,9 @@ class GraniteSpeechForConditionalGeneration( # Split variable length features into a tuple return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + def get_language_model(self) -> torch.nn.Module: + return self.language_model + def get_multimodal_embeddings( self, **kwargs: object, @@ -729,7 +756,7 @@ class GraniteSpeechForConditionalGeneration( audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] - return None + audio_features = self._process_audio_input(audio_input) return audio_features @@ -737,19 +764,21 @@ class GraniteSpeechForConditionalGeneration( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - """Compute the merged LLM / audio embeddings.""" - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.audio_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -762,26 +791,16 @@ class GraniteSpeechForConditionalGeneration( if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - audio_embeds = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) - input_ids = None - - model_output = self.language_model(input_ids, positions, - intermediate_tensors, inputs_embeds) + model_output = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits( - hidden_states, - sampling_metadata, - ) + return self.language_model.compute_logits(hidden_states) def load_weights( self, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 7d31854dce8d8..4711ed05c5879 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -23,36 +23,46 @@ # See the License for the specific language governing permissions and # limitations under the License. 
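The granitemoe.py changes below thread an `is_sequence_parallel` flag into the MoE block: the flattened token batch is chunked across tensor-parallel ranks before the experts run and all-gathered afterwards, with padding added and then trimmed so the split is even. A single-process sketch of that bookkeeping, simulating the ranks with a plain loop instead of `sequence_parallel_chunk` / `tensor_model_parallel_all_gather`:

import torch

def moe_sequence_parallel_sketch(hidden_states: torch.Tensor, tp_size: int):
    """Toy version of the chunk -> compute -> all-gather -> trim pattern.

    hidden_states: (num_tokens, hidden_size). The "ranks" here are simulated
    with a Python loop; the real code uses distributed collectives.
    """
    num_tokens = hidden_states.shape[0]
    pad = (-num_tokens) % tp_size                      # pad so tp_size divides evenly
    padded = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    chunks = padded.chunk(tp_size, dim=0)              # one chunk per "rank"

    # Each rank would run its experts on its own chunk; identity here.
    outputs = [chunk for chunk in chunks]

    gathered = torch.cat(outputs, dim=0)               # all-gather along dim 0
    return gathered[:num_tokens]                       # drop the padding again

x = torch.randn(10, 8)
assert torch.equal(moe_sequence_parallel_sketch(x, tp_size=4), x)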
"""Inference-only GraniteMoe model.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch from torch import nn -from transformers.models.granitemoe import GraniteMoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers, - maybe_prefix) +from .utils import AutoWeightsLoader, is_pp_missing_parameter, make_layers, maybe_prefix class GraniteMoeMoE(nn.Module): @@ -63,49 +73,69 @@ class GraniteMoeMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + is_sequence_parallel=False, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size + self.is_sequence_parallel = is_sequence_parallel # Gate always runs at half / full precision for now. 
- self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + num_tokens = orig_shape[0] + final_hidden_states = final_hidden_states[:num_tokens] + return final_hidden_states.view(orig_shape) class GraniteMoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -138,8 +168,11 @@ class GraniteMoeAttention(nn.Module): self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = (attention_multiplier if attention_multiplier - is not None else self.head_dim**-1) + self.scaling = ( + attention_multiplier + if attention_multiplier is not None + else self.head_dim**-1 + ) self.rope_theta = rope_theta self.qkv_proj = QKVParallelLinear( @@ -166,13 +199,15 @@ class GraniteMoeAttention(nn.Module): is_neox_style=True, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -188,15 +223,18 @@ class GraniteMoeAttention(nn.Module): class GraniteMoeDecoderLayer(nn.Module): - def __init__( self, - config: GraniteMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) @@ -211,19 +249,22 @@ class GraniteMoeDecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - attention_multiplier=config.attention_multiplier) + 
attention_multiplier=config.attention_multiplier, + ) self.block_sparse_moe = GraniteMoeMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + is_sequence_parallel=parallel_config.use_sequence_parallel_moe, + prefix=f"{prefix}.block_sparse_moe", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.residual_multiplier = config.residual_multiplier @@ -250,19 +291,20 @@ class GraniteMoeDecoderLayer(nn.Module): @support_torch_compile class GraniteMoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -275,10 +317,9 @@ class GraniteMoeModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GraniteMoeDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers") + lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -298,26 +339,24 @@ class GraniteMoeModel(nn.Module): else: hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier - residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def _load_weights(self, - weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """ - This function is copied from `MixtralModel.load_weights`, mainly to - decouple from mixtral, avoiding impact on support like BNB + This function is copied from `MixtralModel.load_weights`, mainly to + decouple from mixtral, avoiding impact on support like BNB quantization. 
""" stacked_params_mapping = [ @@ -333,30 +372,33 @@ class GraniteMoeModel(nn.Module): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -379,21 +421,25 @@ class GraniteMoeModel(nn.Module): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -404,40 +450,45 @@ class GraniteMoeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: - if n.endswith('.block_sparse_moe.input_linear.weight'): + if n.endswith(".block_sparse_moe.input_linear.weight"): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) assert w1_name not in new_weights assert w3_name not in new_weights new_weights[w1_name] = w1_param new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif n.endswith(".block_sparse_moe.output_linear.weight"): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] assert w2_name not in new_weights new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) assert gate_name not in new_weights new_weights[gate_name] = p else: @@ -472,8 +523,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.config = config self.lora_config = lora_config - self.model = GraniteMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GraniteMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -484,16 +536,19 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -505,36 +560,29 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 
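The `load_weights` remapping above splits each expert's fused `input_linear` slab into the separate w1 and w3 tensors that `FusedMoE` expects, via `chunk(2, dim=0)`. A minimal sketch of that split with made-up sizes:

import torch

num_experts, intermediate, hidden = 4, 6, 8
# Fused checkpoint tensor: one (2 * intermediate, hidden) slab per expert.
input_linear = torch.randn(num_experts, 2 * intermediate, hidden)

new_weights = {}
for e in range(num_experts):
    w1, w3 = input_linear[e].chunk(2, dim=0)   # first half -> w1, second half -> w3
    new_weights[f"experts.{e}.w1.weight"] = w1
    new_weights[f"experts.{e}.w3.weight"] = w3

assert new_weights["experts.0.w1.weight"].shape == (intermediate, hidden)
assert torch.equal(new_weights["experts.2.w3.weight"], input_linear[2, intermediate:])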
intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index f451e65338b78..f877dc5764275 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only GraniteMoeHybrid model.""" + # Added by the IBM Team, 2025 from collections.abc import Iterable from typing import Optional @@ -9,72 +10,73 @@ import torch from torch import nn from transformers import GraniteMoeHybridConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GraniteMoeHybridMambaDecoderLayer(nn.Module): - - def __init__(self, - config: GraniteMoeHybridConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: GraniteMoeHybridConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size self.residual_multiplier = config.residual_multiplier - self.mamba = MambaMixer2(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - head_dim=config.mamba_d_head, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -84,33 +86,32 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ - else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + 
self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 + else GraniteMoeSharedMLP( + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" + ) + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mamba(hidden_states, output) hidden_states = residual + output * self.residual_multiplier residual = hidden_states @@ -124,8 +125,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): if self.block_sparse_moe is not None: moe_hidden_states = hidden_states.clone() moe_hidden_states = self.block_sparse_moe(moe_hidden_states) - hidden_states = moe_hidden_states + self.shared_mlp( - hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) del moe_hidden_states else: hidden_states = self.shared_mlp(hidden_states) @@ -135,7 +135,6 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): class GraniteMoeHybridAttentionDecoderLayer(nn.Module): - def __init__( self, config: GraniteMoeHybridConfig, @@ -153,7 +152,8 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): config, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -163,28 +163,27 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ - else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 + else GraniteMoeSharedMLP( + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" + ) + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -206,8 +205,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): if self.block_sparse_moe is not None: moe_hidden_states = hidden_states.clone() moe_hidden_states = self.block_sparse_moe(moe_hidden_states) - hidden_states = moe_hidden_states + self.shared_mlp( - hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) del moe_hidden_states else: hidden_states = self.shared_mlp(hidden_states) @@ -217,7 +215,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): class GraniteMoeHybridAttention(nn.Module): - def 
__init__( self, config: GraniteMoeHybridConfig, @@ -249,19 +246,23 @@ class GraniteMoeHybridAttention(nn.Module): assert tp_size % self.total_num_kv_heads == 0 self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size) - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if config.position_embedding_type == "rope": self.rotary_emb = get_rope( @@ -269,34 +270,38 @@ class GraniteMoeHybridAttention(nn.Module): rotary_dim=self.head_dim, max_position=config.max_position_embeddings, base=int(config.rope_theta), - rope_scaling=config.rope_scaling \ - if hasattr(config, "rope_scaling") \ - and config.rope_scaling is not None else None, + rope_scaling=config.rope_scaling + if hasattr(config, "rope_scaling") and config.rope_scaling is not None + else None, is_neox_style=True, ) else: self.rotary_emb = None - self.attn = Attention(self.num_heads, - self.head_dim, - self.attention_multiplier, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.attention_multiplier, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - query, key, value = qkv.split([ - self.num_heads * self.head_dim, self.num_key_value_heads * - self.head_dim, self.num_key_value_heads * self.head_dim - ], - dim=-1) + query, key, value = qkv.split( + [ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + dim=-1, + ) if self.rotary_emb is not None: query, key = self.rotary_emb(positions, query, key) @@ -316,7 +321,6 @@ ALL_DECODER_LAYER_TYPES = { @support_torch_compile class GraniteMoeHybridModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -327,8 +331,11 @@ class GraniteMoeHybridModel(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -341,8 +348,7 @@ class GraniteMoeHybridModel(nn.Module): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layer_types[layer_idx]] + layer_class = ALL_DECODER_LAYER_TYPES[config.layer_types[layer_idx]] return layer_class( config, layer_idx, @@ -353,10 +359,11 @@ class GraniteMoeHybridModel(nn.Module): ) self.start_layer, self.end_layer, 
self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -367,22 +374,9 @@ class GraniteMoeHybridModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -392,41 +386,27 @@ class GraniteMoeHybridModel(nn.Module): residual = None else: if intermediate_tensors is None: - raise RuntimeError('Intermediate tensors may not be None!') + raise RuntimeError("Intermediate tensors may not be None!") hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] num_attn = 0 - for i in range(len(self.layers)): - layer = self.layers[i] + for i, layer in enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 - - layer_mamba_cache_params = None - if isinstance( - layer, - GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_attn) - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata) + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -438,8 +418,7 @@ class GraniteMoeHybridModel(nn.Module): def _load(n, p): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, p) loaded_params.add(n) @@ -447,20 +426,14 @@ class GraniteMoeHybridModel(nn.Module): # Skip layers on other devices. 
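# (With pipeline parallelism, layers assigned to other PP ranks are not
# materialized on this rank, so their weights are simply skipped here.)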
if not is_pp_missing_parameter(n, self): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, p, shard_id) loaded_params.add(n) def _load_expert(n, p, name, shard_id, expert_id): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - p, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id) loaded_params.add(n) for n, p in weights: @@ -473,49 +446,62 @@ class GraniteMoeHybridModel(nn.Module): # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) # The renaming and parameter loading logic is the same for weight # and weight_scale tensors so we can reuse them without issues. - if (n.endswith('.block_sparse_moe.input_linear.weight') or - n.endswith('.block_sparse_moe.input_linear.weight_scale')): + if n.endswith(".block_sparse_moe.input_linear.weight") or n.endswith( + ".block_sparse_moe.input_linear.weight_scale" + ): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) - _load_expert(n.replace('.input_linear.', '.experts.w13_'), - w1_param, - w1_name, - shard_id='w1', - expert_id=e) - _load_expert(n.replace('.input_linear.', '.experts.w13_'), - w3_param, - w3_name, - shard_id='w3', - expert_id=e) - elif (n.endswith('.block_sparse_moe.output_linear.weight') or - n.endswith('.block_sparse_moe.output_linear.weight_scale')): + _load_expert( + n.replace(".input_linear.", ".experts.w13_"), + w1_param, + w1_name, + shard_id="w1", + expert_id=e, + ) + _load_expert( + n.replace(".input_linear.", ".experts.w13_"), + w3_param, + w3_name, + shard_id="w3", + expert_id=e, + ) + elif n.endswith(".block_sparse_moe.output_linear.weight") or n.endswith( + ".block_sparse_moe.output_linear.weight_scale" + ): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] - _load_expert(n.replace('.output_linear.', '.experts.w2_'), - w2_param, - w2_name, - shard_id='w2', - expert_id=e) - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + _load_expert( + n.replace(".output_linear.", ".experts.w2_"), + w2_param, + w2_name, + shard_id="w2", + expert_id=e, + ) + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) _load(gate_name, p) else: loaded = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name in n: - _load_shard(n.replace(weight_name, param_name), - p, - shard_id=shard_id) + _load_shard( + n.replace(weight_name, param_name), p, shard_id=shard_id + ) loaded = True if not loaded: _load(n, p) @@ -523,8 +509,9 @@ 
class GraniteMoeHybridModel(nn.Module): return loaded_params -class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, - SupportsPP, IsHybrid, SupportsQuant): +class GraniteMoeHybridForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -543,7 +530,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -554,13 +540,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -579,7 +563,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, head_dim=hf_config.mamba_d_head, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -588,19 +571,14 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - if cache_config.enable_prefix_caching: - raise RuntimeError( - "GraniteMoeHybrid currently does not support prefix caching") - self.quant_config = vllm_config.quant_config self.config = config self.scheduler_config = scheduler_config - self.model = GraniteMoeHybridModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = GraniteMoeHybridModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -612,74 +590,47 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix(prefix, "lm_head"), + ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) - - # Used to track and store by the Mamba cache between steps. 
- self.mamba_cache: Optional[MambaCacheManager] = None + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 1e2e8544179c7..93302821ca68d 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -5,7 +5,9 @@ The architecture is the same as granitemoe but with the addition of shared experts. 
""" + from collections.abc import Iterable +from itertools import islice from typing import Optional import torch @@ -17,14 +19,17 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.sequence import IntermediateTensors from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE @@ -33,7 +38,6 @@ from .utils import AutoWeightsLoader, make_layers, maybe_prefix class GraniteMoeSharedMLP(nn.Module): - def __init__( self, config: GraniteMoeSharedConfig, @@ -49,16 +53,20 @@ class GraniteMoeSharedMLP(nn.Module): output_sizes=[self.hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.input_linear") + prefix=f"{prefix}.input_linear", + ) self.output_linear = RowParallelLinear( self.hidden_size, self.input_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.output_linear") + prefix=f"{prefix}.output_linear", + ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -69,7 +77,6 @@ class GraniteMoeSharedMLP(nn.Module): class GraniteMoeSharedDecoderLayer(nn.Module): - def __init__( self, config: GraniteMoeSharedConfig, @@ -92,26 +99,28 @@ class GraniteMoeSharedDecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - attention_multiplier=config.attention_multiplier) + attention_multiplier=config.attention_multiplier, + ) self.block_sparse_moe = GraniteMoeMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ + prefix=f"{prefix}.block_sparse_moe", + ) + self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" ) + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.residual_multiplier = config.residual_multiplier @@ -145,7 +154,6 @@ class GraniteMoeSharedDecoderLayer(nn.Module): @support_torch_compile class GraniteMoeSharedModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -157,8 +165,11 @@ class GraniteMoeSharedModel(nn.Module): self.config = config self.quant_config = quant_config # Required by MixtralModel self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -175,7 +186,8 @@ class GraniteMoeSharedModel(nn.Module): lambda prefix: GraniteMoeSharedDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -195,50 +207,52 @@ class GraniteMoeSharedModel(nn.Module): else: hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier - residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: - if 
n.endswith('.block_sparse_moe.input_linear.weight'): + if n.endswith(".block_sparse_moe.input_linear.weight"): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) assert w1_name not in new_weights assert w3_name not in new_weights new_weights[w1_name] = w1_param new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif n.endswith(".block_sparse_moe.output_linear.weight"): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] assert w2_name not in new_weights new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) assert gate_name not in new_weights new_weights[gate_name] = p else: @@ -273,9 +287,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.config = config self.lora_config = lora_config - self.model = GraniteMoeSharedModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = GraniteMoeSharedModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -286,16 +300,19 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix(prefix, "lm_head"), + ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -307,36 +324,29 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, 
hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 3f6790269ae62..ac78dd9e753aa 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -9,18 +9,22 @@ import torch.nn as nn from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolerHead, PoolerNormalize, - PoolingParamsUpdate, - build_output, get_prompt_lens, - get_prompt_token_ids) +from vllm.model_executor.layers.pooler import ( + DispatchPooler, + Pooler, + PoolerHead, + PoolerNormalize, + PoolingParamsUpdate, + get_prompt_lens, + get_prompt_token_ids, +) from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import PoolerOutput from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.v1.outputs import PoolerOutput +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import default_pooling_type +from .interfaces_base import default_pooling_type logger = init_logger(__name__) @@ -47,12 +51,11 @@ class GritLMMeanPool(nn.Module): def tokens_to_ids(tokens: list[str]) -> np.ndarray: return np.array([self.token_ids[token] for token in tokens]) - self.user_pattern_ids = tokens_to_ids( - ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) self.embed_newline_pattern_ids = tokens_to_ids( - ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) - self.embed_pattern_ids = tokens_to_ids( - ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"] + ) + self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"]) def _find_array( self, @@ -86,7 +89,7 @@ class GritLMMeanPool(nn.Module): end_idx = arr_len for i in range(start_idx, min(end_idx, arr_len - target_len + 1)): - if (arr[i:i + target_len] == target).all(): + if (arr[i : i + target_len] == target).all(): return i return -1 @@ -105,31 +108,37 @@ class GritLMMeanPool(nn.Module): # Return no instruction in case of missing BOS token. 
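# (These pattern searches let mean pooling exclude the instruction prefix:
# tokens up to and including the embed marker are dropped, and the fallback
# treats only the BOS token as the instruction.)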
if prompt_token_ids[0] != self.token_ids["<s>"]: - logger.warning("BOS token not found in prompt, " - "thus using empty string for instruction. " - "GritLM requires BOS token in prompt.") + logger.warning( + "BOS token not found in prompt, " + "thus using empty string for instruction. " + "GritLM requires BOS token in prompt." + ) return instruction_len # If user pattern is found in the prompt, that means there should be # a newline token before the embed pattern. embed_pattern_ids = self.embed_pattern_ids - if self._find_array(prompt_token_ids, - self.user_pattern_ids, - start_idx=1, - end_idx=2) == 1: + if ( + self._find_array( + prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2 + ) + == 1 + ): embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. - found_embed_pattern_idx = self._find_array(prompt_token_ids, - embed_pattern_ids, - start_idx=1) + found_embed_pattern_idx = self._find_array( + prompt_token_ids, embed_pattern_ids, start_idx=1 + ) if found_embed_pattern_idx != -1: instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) else: - logger.warning("Query instruction not found in prompt, " - "thus using BOS token as instruction instead. " - "GritLM requires query instruction in prompt.") + logger.warning( + "Query instruction not found in prompt, " + "thus using BOS token as instruction instead. " + "GritLM requires query instruction in prompt." + ) instruction_len = 1 return instruction_len @@ -146,8 +155,9 @@ class GritLMMeanPool(nn.Module): prompt_len: Optional[torch.Tensor] = None, instr_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + assert prompt_len is None or prompt_len == hidden_states.shape[0], ( "partial prefill not supported with MEAN pooling" + ) return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32) @@ -161,9 +171,11 @@ class GritLMMeanPool(nn.Module): pooled_data = list[torch.Tensor]() for prompt_len, instr_len in zip(prompt_lens, instr_lens): - pooled_data.append(hidden_states[offset + instr_len:offset + - prompt_len].mean( - dim=0, dtype=torch.float32)) + pooled_data.append( + hidden_states[offset + instr_len : offset + prompt_len].mean( + dim=0, dtype=torch.float32 + ) + ) offset += prompt_len return pooled_data @@ -184,15 +196,16 @@ class GritLMMeanPool(nn.Module): if isinstance(hidden_states, list): return [ - self.forward_one(h, prompt_len, instr_len) for h, prompt_len, - instr_len in zip(hidden_states, prompt_lens, instr_lens) + self.forward_one(h, prompt_len, instr_len) + for h, prompt_len, instr_len in zip( + hidden_states, prompt_lens, instr_lens + ) ] return self.forward_all(hidden_states, prompt_lens, instr_lens) class GritLMPooler(Pooler): - def __init__(self, model_config: ModelConfig): super().__init__() @@ -212,7 +225,7 @@ class GritLMPooler(Pooler): ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata) - return build_output(pooled_data) + return pooled_data @default_pooling_type("MEAN") @@ -254,9 +267,9 @@ class GritLM(LlamaForCausalLM): pooler_config = vllm_config.model_config.pooler_config if pooler_config is not None: - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "embed": - GritLMPooler(vllm_config.model_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": GritLMPooler(vllm_config.model_config), + } + ) diff --git 
a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 3659249cd8bd6..f4139685b79f6 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -22,7 +22,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Grok1 model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -35,23 +37,33 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # Default Grok1-specific constants, overridden by config values if present DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 @@ -68,37 +80,43 @@ class Grok1MoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. 
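# (quant_config=None below keeps the router gate unquantized even when the
# experts themselves are built with a quantized config.)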
- self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - activation="gelu", - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -112,18 +130,17 @@ class Grok1MoE(nn.Module): class Grok1Attention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - config=None, # Added config parameter + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter ) -> None: super().__init__() self.hidden_size = hidden_size @@ -172,19 +189,21 @@ class Grok1Attention(nn.Module): is_neox_style=True, ) - attn_logits_soft_cap = max( - getattr(config, "attn_logit_softcapping", 30.0), 0.0) + attn_logits_soft_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - prefix=f"{prefix}.attn") - self.attn_multiplier = getattr(self.config, "attn_output_multiplier", - 1.0) if self.config else 1.0 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn", + ) + self.attn_multiplier = ( + getattr(self.config, "attn_output_multiplier", 1.0) if self.config else 1.0 + ) def forward( self, @@ -201,7 +220,6 @@ class Grok1Attention(nn.Module): class Grok1DecoderLayer(nn.Module): - def __init__( self, config, @@ -214,8 +232,7 @@ class Grok1DecoderLayer(nn.Module): # Check for fp8 quantization self.use_fp8 = False if quant_config is not None: - self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", - lambda: False)() + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", lambda: False)() if not self.use_fp8 and hasattr(quant_config, "is_fp8"): self.use_fp8 = quant_config.is_fp8 @@ -231,27 +248,26 @@ class Grok1DecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", - config=config) # Pass config to Grok1Attention + config=config, + ) # Pass config to Grok1Attention # Grok1 uses "num_experts" in its config num_experts = getattr(config, "num_experts", 8) 
num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) - self.moe_block = Grok1MoE(num_experts=num_experts, - top_k=num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - prefix=f"{prefix}.moe_block") + self.moe_block = Grok1MoE( + num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block", + ) - self.pre_attn_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attn_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_moe_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -264,8 +280,7 @@ class Grok1DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.pre_attn_norm(hidden_states) else: - hidden_states, residual = self.pre_attn_norm( - hidden_states, residual) + hidden_states, residual = self.pre_attn_norm(hidden_states, residual) hidden_states = self.attn( positions=positions, @@ -285,7 +300,6 @@ class Grok1DecoderLayer(nn.Module): @support_torch_compile class Grok1Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -297,13 +311,16 @@ class Grok1Model(nn.Module): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embedding_multiplier_scale = getattr( - config, "embedding_multiplier_scale", - DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + config, "embedding_multiplier_scale", DEFAULT_EMBEDDING_MULTIPLIER_SCALE + ) self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -317,12 +334,13 @@ class Grok1Model(nn.Module): lambda prefix: Grok1DecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -347,15 +365,13 @@ class Grok1Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": 
residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -368,10 +384,10 @@ class Grok1Model(nn.Module): ckpt_gate_proj_name="linear", # Grok1 specific ckpt_down_proj_name="linear_1", # Grok1 specific ckpt_up_proj_name="linear_v", # Grok1 specific - num_experts=num_experts) + num_experts=num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -383,25 +399,27 @@ class Grok1Model(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -424,21 +442,25 @@ class Grok1Model(nn.Module): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
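# (Parameters that live on other pipeline-parallel ranks are absent from this
# rank and are skipped rather than loaded.)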
if is_pp_missing_parameter(name, self): @@ -454,8 +476,9 @@ class Grok1Model(nn.Module): name = name.replace("scale", "weight") param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -483,8 +506,9 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = Grok1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Grok1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -503,13 +527,15 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head.weight = self.model.embed_tokens.weight self.output_multiplier_scale = getattr( - config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - self.output_multiplier_scale) + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, self.output_multiplier_scale + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -521,24 +547,21 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Skip lm_head when tie_word_embeddings is True - skip_prefixes = (["lm_head"] - if self.config.tie_word_embeddings else None) + skip_prefixes = ["lm_head"] if self.config.tie_word_embeddings else None loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 87e451a2769ea..d7ee0fd8fd37c 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,21 +17,34 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalProcessingInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + 
ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel -from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, - BaseInternVLDummyInputsBuilder, - BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel, build_transform, - find_closest_aspect_ratio, get_internvl_target_ratios) +from .internvl import ( + IMG_CONTEXT, + IMG_END, + IMG_START, + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + BaseInternVLProcessor, + InternVLChatModel, + build_transform, + find_closest_aspect_ratio, + get_internvl_target_ratios, +) def resolve_h2ovl_min_max_num( @@ -61,8 +74,10 @@ def get_h2ovl_target_ratios( # if prior_aspect_ratio is provided, filter the target ratios if prior_aspect_ratio is not None: target_ratios = [ - ratio for ratio in target_ratios if prior_aspect_ratio[0] % - ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ratio + for ratio in target_ratios + if prior_aspect_ratio[0] % ratio[0] != 0 + and prior_aspect_ratio[1] % ratio[1] != 0 ] return target_ratios @@ -207,7 +222,8 @@ def image_to_pixel_values_h2ovl( ) # combine pixel values pixel_values = torch.cat( - [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0) + [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0 + ) else: pixel_values, _ = _preprocess_image( @@ -223,7 +239,6 @@ def image_to_pixel_values_h2ovl( class H2OVLProcessor(BaseInternVLProcessor): - def __init__( self, config: PretrainedConfig, @@ -270,14 +285,18 @@ class H2OVLProcessor(BaseInternVLProcessor): dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_h2ovl_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -318,7 +337,7 @@ class H2OVLProcessor(BaseInternVLProcessor): image_height: int, use_msac: Optional[bool] = None, ) -> int: - use_msac = (self.use_msac if use_msac is None else use_msac) + use_msac = self.use_msac if use_msac is None else use_msac use_thumbnail = self.use_thumbnail @@ -387,12 +406,12 @@ class H2OVLProcessor(BaseInternVLProcessor): max_num=max_num, use_thumbnail=self.use_thumbnail, use_msac=use_msac, - ) for image in images + ) + for image in images ] class H2OVLProcessingInfo(BaseInternVLProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor: return self.ctx.init_processor( H2OVLProcessor, @@ -419,9 +438,7 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ) -class H2OVLMultiModalProcessor( - 
BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): - +class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -446,7 +463,8 @@ class H2OVLMultiModalProcessor( def get_replacement_internvl(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -479,6 +497,7 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is @@ -490,6 +509,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -497,15 +517,16 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) @MULTIMODAL_REGISTRY.register_processor( H2OVLMultiModalProcessor, info=H2OVLProcessingInfo, - dummy_inputs=BaseInternVLDummyInputsBuilder) + dummy_inputs=BaseInternVLDummyInputsBuilder, +) class H2OVLChatModel(InternVLChatModel): - def _init_vision_model( self, config: PretrainedConfig, @@ -517,8 +538,9 @@ class H2OVLChatModel(InternVLChatModel): if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = (config.vision_config.num_hidden_layers + - vision_feature_layer + 1) + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index fbba849a76f23..cf2e5d0d0bd6e 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -23,7 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only HunYuan model compatible with HuggingFace weights.""" -from collections.abc import Iterable + +import typing +from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import regex as re @@ -33,32 +36,45 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_layers) +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) def _is_moe(config: PretrainedConfig) -> bool: @@ -81,7 +97,6 @@ def _get_cla_factor(config: PretrainedConfig) -> int: class HunYuanMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -109,8 +124,9 @@ class HunYuanMLP(nn.Module): reduce_results=reduce_results, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -121,7 +137,6 @@ class HunYuanMLP(nn.Module): class HunYuanAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -205,10 +220,8 @@ class HunYuanAttention(nn.Module): ) if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -222,9 +235,11 @@ class HunYuanAttention(nn.Module): ori_k = k if self.use_qk_norm: q = self.query_layernorm( - q.view(-1, self.num_heads, self.head_dim).contiguous()) + q.view(-1, self.num_heads, self.head_dim).contiguous() + ) k = self.key_layernorm( - k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + k.view(-1, self.num_kv_heads, self.head_dim).contiguous() + ) attn_output = self.attn(q, k, v) # For o_proj @@ -234,7 +249,6 @@ class HunYuanAttention(nn.Module): class HunYuanCrossAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -317,10 +331,8 @@ class HunYuanCrossAttention(nn.Module): ) if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -336,9 +348,11 @@ class HunYuanCrossAttention(nn.Module): q, _ = self.rotary_emb(positions, q, k_tmp) if self.use_qk_norm: q = self.query_layernorm( - q.view(-1, self.num_heads, self.head_dim).contiguous()) + q.view(-1, self.num_heads, self.head_dim).contiguous() + ) k = self.key_layernorm( - k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + k.view(-1, self.num_kv_heads, self.head_dim).contiguous() + ) attn_output = self.attn(q, k, v) # For o_proj @@ -348,21 +362,27 @@ class HunYuanCrossAttention(nn.Module): class HunYuanSparseMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, layer_id: int = -1, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." 
+ ) # Get layer_id topk if config.moe_topk is a list if isinstance(config.moe_topk, list): @@ -375,26 +395,33 @@ class HunYuanSparseMoeBlock(nn.Module): # If it is moe, moe_intermediate_size is preferred intermediate_size = config.intermediate_size if config.moe_intermediate_size is not None: - intermediate_size = (config.moe_intermediate_size if isinstance( - config.moe_intermediate_size, int) else - config.moe_intermediate_size[layer_id]) + intermediate_size = ( + config.moe_intermediate_size + if isinstance(config.moe_intermediate_size, int) + else config.moe_intermediate_size[layer_id] + ) - self.experts = FusedMoE( - num_experts=config.num_experts, - top_k=top_k, - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - reduce_results=False, - renormalize=top_k > 1, - quant_config=quant_config, - prefix=f"{prefix}.experts", + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.use_mixed_mlp_moe > 0: # Get layer_id num_shared_expert if config.num_shared_expert is # a list. @@ -415,30 +442,41 @@ class HunYuanSparseMoeBlock(nn.Module): else: self.shared_mlp = None + self.experts = SharedFusedMoE( + shared_experts=self.shared_mlp, + num_experts=self.n_routed_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=top_k > 1, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
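# (Tokens are flattened to (num_tokens, hidden_dim) before routing; the
# original shape is restored via view(orig_shape) after the expert mix.)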
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_mlp is not None: - shared_output = self.shared_mlp(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if self.shared_mlp is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) class HunYuanDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -446,35 +484,43 @@ class HunYuanDecoderLayer(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", layer_id: int = -1, + enable_eplb: bool = False, ) -> None: super().__init__() assert layer_id >= 0 self.layer_id = layer_id self.hidden_size = config.hidden_size - self.intermediate_size = (config.intermediate_size if isinstance( - config.intermediate_size, int) else - config.intermediate_size[layer_id]) + self.intermediate_size = ( + config.intermediate_size + if isinstance(config.intermediate_size, int) + else config.intermediate_size[layer_id] + ) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) cla_factor = _get_cla_factor(config) - attention_type = (AttentionType.ENCODER_DECODER - if layer_id >= 0 and layer_id % cla_factor != 0 else - AttentionType.DECODER) + attention_type = ( + AttentionType.ENCODER_DECODER + if layer_id >= 0 and layer_id % cla_factor != 0 + else AttentionType.DECODER + ) if attention_type == AttentionType.DECODER: self.self_attn = HunYuanAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -489,8 +535,9 @@ class HunYuanDecoderLayer(nn.Module): config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -509,6 +556,7 @@ class HunYuanDecoderLayer(nn.Module): quant_config=quant_config, layer_id=layer_id, 
prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = HunYuanMLP( @@ -520,10 +568,10 @@ class HunYuanDecoderLayer(nn.Module): prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -537,8 +585,7 @@ class HunYuanDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, ori_kv_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -546,15 +593,13 @@ class HunYuanDecoderLayer(nn.Module): ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual, ori_kv_states @support_torch_compile class HunYuanModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -562,16 +607,23 @@ class HunYuanModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + eplb_config = vllm_config.parallel_config.eplb_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -588,6 +640,7 @@ class HunYuanModel(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=prefix, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers", ) @@ -619,8 +672,9 @@ class HunYuanModel(nn.Module): cla_factor = _get_cla_factor(self.config) prev_kv_states = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for i, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): hidden_states, residual, kv_states = layer( positions, hidden_states, @@ -628,25 +682,24 @@ class HunYuanModel(nn.Module): prev_kv_states, ) - if (getattr(self.config, "use_cla", False) - and (i - self.start_layer) % cla_factor == 0): + if getattr(self.config, "use_cla", False) and i % cla_factor == 0: prev_kv_states = kv_states else: prev_kv_states = None if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def 
_split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads) + num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) num_key_value_groups = num_attention_heads // num_kv_heads hidden_size = self.config.hidden_size @@ -657,8 +710,9 @@ class HunYuanModel(nn.Module): else: attention_head_dim = self.config.hidden_size // num_attention_heads - qkv = qkv.reshape(num_kv_heads, num_key_value_groups + 2, - attention_head_dim, hidden_size) + qkv = qkv.reshape( + num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size + ) q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) q = q.reshape(-1, hidden_size) k = k.reshape(-1, hidden_size) @@ -669,11 +723,12 @@ class HunYuanModel(nn.Module): if _is_moe(self.config): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, ) else: return [] @@ -690,16 +745,16 @@ class HunYuanModel(nn.Module): ] num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads) + num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) split_params_mapping = [ (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), ( ".qkv_proj", ".qkv_proj", num_attention_heads + num_kv_heads * 2, - [("q", num_attention_heads), ("k", num_kv_heads), - ("v", num_kv_heads)], + [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], self._split_qkv_weight, ), ] @@ -714,8 +769,7 @@ class HunYuanModel(nn.Module): name = name.replace("gate_proj_bias", "gate_proj.bias") if "up_proj_bias" in name: name = name.replace("up_proj_bias", "up_proj.bias") - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue @@ -725,11 +779,11 @@ class HunYuanModel(nn.Module): if self.config.tie_word_embeddings and "lm_head.weight" in name: continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue @@ -765,11 +819,11 @@ class HunYuanModel(nn.Module): continue for ( - param_name, - weight_name, - den, - split_param, - func, + param_name, + weight_name, + den, + split_param, + func, ) in split_params_mapping: if weight_name not in name: continue @@ -790,12 +844,11 @@ class HunYuanModel(nn.Module): for shard_id, num in split_param: new_offset = offset + num * units if func: - weight_loader(param, - func(loaded_weight)[offset:new_offset], - shard_id) + weight_loader( + param, func(loaded_weight)[offset:new_offset], shard_id + ) else: - weight_loader(param, loaded_weight[offset:new_offset], - shard_id) + weight_loader(param, loaded_weight[offset:new_offset], shard_id) offset = new_offset break @@ -803,25 +856,44 @@ class HunYuanModel(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + # this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( param, loaded_weight, - name, + name_mapped, shard_id=shard_id, expert_id=expert_id, + return_success=True, ) - break + if success: + name = name_mapped + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) if name is None: @@ -834,14 +906,15 @@ class HunYuanModel(nn.Module): name = name.replace("wg.", "") param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class HunYuanV1Base(nn.Module, SupportsLoRA): +class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -871,14 +944,15 @@ class HunYuanV1Base(nn.Module, SupportsLoRA): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() @@ -889,49 +963,116 @@ class HunYuanV1Base(nn.Module, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + +class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Set MoE hyperparameters + self.expert_weights = [] + self.num_expert_groups = 1 + self.moe_layers: list[SharedFusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + 
assert isinstance(layer, HunYuanDecoderLayer) + if isinstance(layer.mlp, HunYuanSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No HunYuanMoE layer found in model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + self.expert_weights.append(layer.get_expert_weights()) + # Register the expert weights. + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, HunYuanSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() -class HunYuanDenseV1ForCausalLM(HunYuanV1Base): +class HunYuanDenseV1Base(HunyuanV1ModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + +class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base): pass -class HunYuanMoEV1ForCausalLM(HunYuanV1Base): +class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base): pass diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index eeb8291c77847..611c14733c71f 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -2,50 +2,52 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # copied from : https://github.com/huggingface/transformers import ast -import sys from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial -from itertools import chain -from typing import Any, Literal, Optional, TypedDict, Union +from itertools import accumulate +from typing import Annotated, Any, Literal, Optional, Union import numpy as np -import PIL -from einops import rearrange -from PIL import Image - -if sys.version_info >= (3, 11): - import typing - Unpack = typing.Unpack -else: - import typing_extensions - Unpack = typing_extensions.Unpack - import torch import torch.nn as nn +from einops import rearrange from timm.layers import LayerNorm, LayerNorm2d from timm.models.regnet import RegStage from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from transformers.modeling_utils 
import no_init_weights from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix +from .utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -56,8 +58,8 @@ VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" # Based on combine_frames_into_images in # https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py def get_num_combined_frames( - num_frames: int, - max_grid_shape: tuple[int, int] = (3, 3), + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), ) -> int: max_num_grids = max_grid_shape[0] * max_grid_shape[1] @@ -68,32 +70,48 @@ def get_num_combined_frames( return num_canvases + (leftover_frames > 0) -class HCXVisionMultimodalPixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_images: list[torch.Tensor] +class HCXVisionImagePixelInputs(TensorSchema): """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres - - Note that `height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. 
+ Dimensions: + - n: Number of images + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width """ - image_sizes_images: list[tuple[Union[int, float]]] - """ - Shape: `[(height, width), ...]` - """ - vision_query_lengths_images: list[Union[int, float]] - pixel_values_videos: list[tuple[Union[int, float]]] - """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres - """ - vision_query_lengths_videos: list[Union[int, float]] + + type: Literal["pixel_values"] = "pixel_values" + pixel_values_images: Annotated[ + list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"}) + ] + image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)] -HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs] +HCXVisionImageInputs = HCXVisionImagePixelInputs + + +class HCXVisionVideoPixelInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Number of frames + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width + """ + + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[ + list[list[torch.Tensor]], + TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}), + ] + + +HCXVisionVideoInputs = HCXVisionVideoPixelInputs class HCXVisionProcessingInfo(BaseProcessingInfo): - def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) @@ -134,48 +152,49 @@ class HCXVisionProcessingInfo(BaseProcessingInfo): ) -class HCXVisionDummyInputsBuilder( - BaseDummyInputsBuilder[HCXVisionProcessingInfo]): - +class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]): def get_dummy_text( self, mm_counts: Mapping[str, int], ) -> str: dummy_text = IMAGE_TOKEN * mm_counts.get( - "image", 0) + VIDEO_TOKEN * mm_counts.get("video", 0) + "image", 0 + ) + VIDEO_TOKEN * mm_counts.get("video", 0) return dummy_text def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = 32 + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + return { - "image": - self._get_dummy_images( + "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, + overrides=image_overrides, ), - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width - 1, height=target_height - 1, num_frames=target_num_frames, num_videos=num_videos, - ) + overrides=video_overrides, + ), } -class HCXVisionMultiModalProcessor( - BaseMultiModalProcessor[HCXVisionProcessingInfo]): - +class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -183,27 +202,9 @@ class HCXVisionMultiModalProcessor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - - def replace_multimodal_token( - token_ids: torch.Tensor, - target_token: int, - repeats: list[int], - ): - output = list[int]() - _repeats_idx = 0 - for token_id in token_ids: - if token_id == target_token: - output += [token_id.item()] * repeats[_repeats_idx] - _repeats_idx += 1 - else: - output += [token_id.item()] - - 
return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", [])): - if video_arr.dtype == np.uint8: - continue - mm_data["videos"][video_idx] = video_arr.astype(np.uint8) + if video_arr.dtype != np.uint8: + mm_data["videos"][video_idx] = video_arr.astype(np.uint8) processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), @@ -215,20 +216,16 @@ class HCXVisionMultiModalProcessor( ) # text-only if len(mm_data) > 0: + images = mm_data.get("images") + videos = mm_data.get("videos") + # batchify input as a single item - images = mm_data.get("images", None) - batched_images = None if images is None else [images] - - # list of video in single conversation - videos = mm_data.get("videos", None) - batched_videos = None if videos is None else [videos] - _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=batched_images, - videos=batched_videos, + images=None if images is None else [images], + videos=None if videos is None else [videos], ), ) # mm-only @@ -238,51 +235,48 @@ class HCXVisionMultiModalProcessor( _processed_outputs[k] = v[0] if images: - tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=image_token_id, - repeats=_processed_outputs[ - "vision_query_lengths_images"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) + _processed_outputs["image_sizes_images"] = torch.tensor( + _processed_outputs["image_sizes_images"] + ) + _processed_outputs["vision_query_lengths_images"] = torch.tensor( + _processed_outputs["vision_query_lengths_images"] + ) if videos: - _num_per_videos = [ - get_num_combined_frames(len(video)) for video in videos + _idx_per_video = [ + 0, + *accumulate( + get_num_combined_frames(len(video)) for video in videos + ), ] _processed_outputs["pixel_values_videos"] = [ - _processed_outputs["pixel_values_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + _processed_outputs["pixel_values_videos"][ + _idx_per_video[i] : _idx_per_video[i + 1] + ] + for i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ - _processed_outputs["vision_query_lengths_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + torch.tensor( + _processed_outputs["vision_query_lengths_videos"][ + _idx_per_video[i] : _idx_per_video[i + 1] + ] + ) + for i in range(len(videos)) ] - tokenizer = self.info.get_tokenizer() - video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=video_token_id, - repeats=[ - sum(lens) for lens in - _processed_outputs["vision_query_lengths_videos"] - ], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - processed_outputs.update(_processed_outputs) return processed_outputs + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -303,13 +297,11 @@ class HCXVisionMultiModalProcessor( out_item = out_mm_kwargs[modality][item_idx] if 
modality == "image": - lens = out_item["vision_query_lengths_images"].data - num_tokens = self.info.get_num_image_tokens( - vision_query_length=lens) + lens = out_item["vision_query_lengths_images"].data.tolist() + num_tokens = self.info.get_num_image_tokens(vision_query_length=lens) elif modality == "video": - lens = out_item["vision_query_lengths_videos"].data - num_tokens = self.info.get_num_video_tokens( - vision_query_length=lens) + lens = out_item["vision_query_lengths_videos"].data.tolist() + num_tokens = self.info.get_num_video_tokens(vision_query_length=lens) else: raise NotImplementedError(modality) @@ -326,7 +318,8 @@ class HCXVisionMultiModalProcessor( modality=modality, out_mm_kwargs=out_mm_kwargs, ), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -335,31 +328,17 @@ class HCXVisionMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - # image pixel_values_images=MultiModalFieldConfig.batched("image"), image_sizes_images=MultiModalFieldConfig.batched("image"), vision_query_lengths_images=MultiModalFieldConfig.batched("image"), - num_queries_vis_abstractors_images=MultiModalFieldConfig.batched( - "image"), - num_queries_vis_abstractors_slow_images=MultiModalFieldConfig. - batched("image"), - first_last_frames_slows_images=MultiModalFieldConfig.batched( - "image"), - # video pixel_values_videos=MultiModalFieldConfig.batched("video"), - image_sizes_videos=MultiModalFieldConfig.batched("video"), vision_query_lengths_videos=MultiModalFieldConfig.batched("video"), - num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched( - "video"), - num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig. - batched("video"), - first_last_frames_slows_videos=MultiModalFieldConfig.batched( - "video"), ) def _build_hcxvision_hf_info( - ctx: InputProcessingContext, ) -> HCXVisionProcessingInfo: + ctx: InputProcessingContext, +) -> HCXVisionProcessingInfo: return HCXVisionProcessingInfo(ctx) @@ -367,7 +346,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( @@ -417,7 +396,6 @@ def init_vision_tower_for_hcxvision( class HCXVisionMlp(nn.Module): - def __init__( self, mm_projector_type, @@ -439,8 +417,9 @@ class HCXVisionMlp(nn.Module): self.act = act_layer() self.fc2 = nn.Linear(2 * hidden_features, out_features) else: - raise NotImplementedError("{} is not implemented".format( - self.mm_projector_type)) + raise NotImplementedError( + "{} is not implemented".format(self.mm_projector_type) + ) def forward(self, x): x = self.fc1(x) @@ -452,7 +431,7 @@ class HCXVisionMlp(nn.Module): class HCXVisionCAbstractor(nn.Module): """ This module is based on C-Abstractor, whose license is under apache-2.0. - You can check the original code at + You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py and we made necessary modifications. 
""" @@ -474,7 +453,8 @@ class HCXVisionCAbstractor(nn.Module): # Positional embedding if pos_emb: self.pos_emb = torch.nn.Parameter( - torch.zeros(1, num_input_tokens, encoder_hidden_size)) + torch.zeros(1, num_input_tokens, encoder_hidden_size) + ) self.pos_emb.data.normal_(mean=0.0, std=0.02) else: self.pos_emb = None @@ -485,8 +465,9 @@ class HCXVisionCAbstractor(nn.Module): else: self.prenorm = None - self.build_net(num_queries, encoder_hidden_size, hidden_size, - output_hidden_size) + self.build_net( + num_queries, encoder_hidden_size, hidden_size, output_hidden_size + ) self.dtype = next(self.parameters()).dtype def forward( @@ -523,7 +504,8 @@ class HCXVisionCAbstractor(nn.Module): if num_queries_vis_abstractors is not None: assert num_grids is not None return self._forward_adaptive_num_query( - x, num_queries_vis_abstractors, num_grids) + x, num_queries_vis_abstractors, num_grids + ) x = self.net(x) x = rearrange(x, "b d h w -> b (h w) d") @@ -544,7 +526,7 @@ class HCXVisionCAbstractor(nn.Module): for i, num_queries in enumerate(num_queries_vis_abstractors): hw = int(num_queries**0.5) sampler = nn.AdaptiveAvgPool2d((hw, hw)) - out = sampler(x[num_grids[i]:num_grids[i + 1], :]) + out = sampler(x[num_grids[i] : num_grids[i + 1], :]) out = self.net[2](out) # s2 out = rearrange(out, "b d h w -> b (h w) d") @@ -562,8 +544,9 @@ class HCXVisionCAbstractor(nn.Module): depth: int = 3, mlp_depth: int = 2, ): - assert (n_queries**0.5).is_integer( - ), f"n_queries must be square number. n_queries: {n_queries}" + assert (n_queries**0.5).is_integer(), ( + f"n_queries must be square number. n_queries: {n_queries}" + ) hw = int(n_queries**0.5) # RegBlock = ResBlock + SE @@ -588,8 +571,7 @@ class HCXVisionCAbstractor(nn.Module): ) self.net = nn.Sequential(s1, sampler, s2) - self.readout = self.build_mlp(mlp_depth, hidden_size, - output_hidden_size) + self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) def build_mlp( self, @@ -607,12 +589,14 @@ class HCXVisionCAbstractor(nn.Module): @MULTIMODAL_REGISTRY.register_processor( _build_hcxvision_hf_processor, info=_build_hcxvision_hf_info, - dummy_inputs=HCXVisionDummyInputsBuilder) + dummy_inputs=HCXVisionDummyInputsBuilder, +) class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__( @@ -642,7 +626,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ## possible_resolution should be matched with preprocessor_config.json config.possible_resolutions = self._init_possible_resolutions( - config, vision_config) + config, vision_config + ) # init models & parameters with no_init_weights(): # weight will be loaded in from_pretrained @@ -653,11 +638,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): require_post_norm=False, prefix=maybe_prefix(prefix, "vision_model"), ) - self.mm_projector = self._init_mm_projector(config, text_config, - vision_config) + self.mm_projector = self._init_mm_projector(config, text_config, vision_config) - self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size", - text_config.vocab_size) + self.lm_head_vocab_size = getattr( + text_config, "padded_vocab_size", text_config.vocab_size + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=text_config, @@ -666,7 +651,8 @@ class HCXVisionForCausalLM(nn.Module, 
SupportsMultiModal, SupportsPP): if config.anyres: self.image_newline = nn.Parameter( - torch.empty(text_config.hidden_size, dtype=self.dtype)) + torch.empty(text_config.hidden_size, dtype=self.dtype) + ) self.config = config self.vision_config = vision_config @@ -684,90 +670,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Only image or video modality is supported") + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionImageInputs]: + pixel_values_images = kwargs.pop("pixel_values_images", None) + + if pixel_values_images is None: + return None + + image_sizes_images = kwargs.pop("image_sizes_images") + + return HCXVisionImagePixelInputs( + pixel_values_images=pixel_values_images, + image_sizes_images=image_sizes_images, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + + if pixel_values_videos is None: + return None + + return HCXVisionVideoPixelInputs( + pixel_values_videos=pixel_values_videos, + ) + + def _process_image_input( + self, + image_input: HCXVisionImageInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_images( + pixel_values_images=image_input["pixel_values_images"], + image_sizes_images=image_input["image_sizes_images"], + ) + + def _process_video_input( + self, + video_input: HCXVisionVideoInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_videos( + pixel_values_videos=video_input["pixel_values_videos"], + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key == "pixel_values_images" and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key == "pixel_values_videos" and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( self, - **kwargs: Unpack[HCXVisionMultimodalInputs], - ) -> Optional[MultiModalEmbeddings]: + **kwargs: object, + ) -> MultiModalEmbeddings: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] - multimodal_embeddings = list() - if kwargs.get("pixel_values_images") is not None: - for _pixel_values_images, _image_sizes_images in zip( - kwargs["pixel_values_images"], - kwargs["image_sizes_images"]): - _pixel_values_images = _pixel_values_images.unsqueeze(dim=0) - _image_sizes_images = _image_sizes_images.unsqueeze(dim=0) - _len_pixel_values_images = [ - len(pixel_value) for pixel_value in _pixel_values_images - ] - if isinstance(_image_sizes_images, torch.Tensor): - _image_sizes_images = _image_sizes_images.detach().cpu( - ).tolist() - _multimodal_embeddings_images = self.forward_images( - pixel_values_images=_pixel_values_images, - image_sizes_images=_image_sizes_images, - len_pixel_values_images=_len_pixel_values_images, - ) - _multimodal_embeddings_images = torch.cat( - _multimodal_embeddings_images, dim=0) - multimodal_embeddings.append(_multimodal_embeddings_images) + # The result multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] 
= () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings - if kwargs.get("pixel_values_videos") is not None: - for _pixel_values_videos, _vision_query_lengths_videos in zip( - kwargs["pixel_values_videos"], - kwargs["vision_query_lengths_videos"]): - _len_pixel_values_videos = [ - len(_vision_query_lengths) - for _vision_query_lengths in _vision_query_lengths_videos - ] - _c, _w, _h = _pixel_values_videos.shape[-3:] - _pixel_values_videos = _pixel_values_videos.reshape( - sum(_len_pixel_values_videos), -1, _c, _w, - _h).unsqueeze(dim=0) - _multimodal_embeddings_videos = self.forward_videos( - pixel_values_videos=_pixel_values_videos, - len_pixel_values_videos=_len_pixel_values_videos, - ) - _multimodal_embeddings_videos = torch.cat( - _multimodal_embeddings_videos, dim=0) - multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - **kwargs, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if (kwargs.get("pixel_values_images") is not None - or kwargs.get("pixel_values_videos") - is not None): # v0 compatibility - multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - if multimodal_embeddings is not None: - multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0) - _mask_image = input_ids == self.config.image_token_id - _mask_video = input_ids == self.config.video_token_id - assert _mask_image.sum() + _mask_video.sum() == len( - multimodal_embeddings) - - if multimodal_embeddings.dtype != inputs_embeds.dtype: - multimodal_embeddings = multimodal_embeddings.to( - dtype=inputs_embeds.dtype) - if multimodal_embeddings.device != inputs_embeds.device: - multimodal_embeddings = multimodal_embeddings.to( - device=inputs_embeds.device) - - if _mask_image.sum() > 0: - inputs_embeds[ - _mask_image] = multimodal_embeddings[:sum(_mask_image)] - if _mask_video.sum() > 0: - inputs_embeds[_mask_video] = multimodal_embeddings[ - -sum(_mask_video):] - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -779,93 +769,66 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids=input_ids, - **kwargs) - input_ids = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def forward_images( self, - pixel_values_images: list[list[torch.FloatTensor]], - image_sizes_images: list[list[tuple[int, int]]], - len_pixel_values_images: list[int], - ) -> list[list[torch.Tensor]]: - if sum(len_pixel_values_images) == 0: - return None - - concat_pixel_values_images = torch.cat(list( - chain(*pixel_values_images)), - dim=0) + pixel_values_images: list[torch.Tensor], + image_sizes_images: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - image_forward_outs = self.vision_model( - concat_pixel_values_images)[:, visual_token_idx:] + image_forward_outs = self.vision_model(pixel_values_image_flat)[ + :, visual_token_idx: + ] - image_forward_outs = image_forward_outs.to( - dtype=self.mm_projector.dtype) + image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d - split_sizes = [ - pixel_value.shape[0] for pixel_value in chain(*pixel_values_images) - ] - image_forward_outs = torch.split(image_forward_outs, - split_sizes, - dim=0) + split_sizes = [len(item) for item in pixel_values_images] + image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) # newline for anyres postprocessing image_features = anyres_postprocessing( image_forward_outs=image_forward_outs, - image_sizes=[ - image_size for image_sizes in image_sizes_images - for image_size in image_sizes - ], - num_queries_vis_abstractor=self.config. - num_queries_vis_abstractor_image, + image_sizes=image_sizes_images.tolist(), + num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image, unpad=self.config.unpad, patch_size=self.vision_config.patch_size, grid_size=self.vision_config.image_size, image_newline=self.image_newline, possible_resolutions=self.config.possible_resolutions, ) - return image_features + + return tuple(image_features) def forward_videos( self, - pixel_values_videos: list[list[torch.FloatTensor]], - len_pixel_values_videos: list[int], - ) -> list[torch.Tensor]: - - len_video_grids = sum(len_pixel_values_videos) - if len_video_grids == 0: - return None - - # Run Vision Model - concat_pixel_values_videos = torch.cat(list( - chain(*pixel_values_videos)), - dim=0) + pixel_values_videos: list[list[torch.Tensor]], + ) -> tuple[torch.Tensor, ...]: + pixel_values_videos_flat = flatten_bn( + [frame for frames in pixel_values_videos for frame in frames], + concat=True, + ) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - video_forward_outs = self.vision_model( - concat_pixel_values_videos)[:, visual_token_idx:] + video_forward_outs = self.vision_model(pixel_values_videos_flat)[ + :, visual_token_idx: + ] - video_forward_outs = video_forward_outs.to( - dtype=self.mm_projector.dtype) + video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype) # Run MM-Projector # len(num_grids) == len(num_queries_vis_abstractors) + 1 grid_idx = 0 - num_grids = [ - grid_idx - ] # e.g. 
[0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] - num_queries_vis_abstractors = [ - ] # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] + # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] + num_grids = [grid_idx] + # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] + num_queries_vis_abstractors = [] len_total_frames = video_forward_outs.shape[0] if self.config.first_last_frames_slow: @@ -873,22 +836,26 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): assert len_total_frames != 0 if len_total_frames <= 2: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += len_total_frames num_grids.append(grid_idx) else: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast) + self.config.num_queries_vis_abstractor_video_fast + ) grid_idx += len_total_frames - 2 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) else: @@ -897,17 +864,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): for pixel_values_frame in pixel_values_frames: if len(pixel_values_frame) > 0: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast) + self.config.num_queries_vis_abstractor_video_fast + ) grid_idx = grid_idx + len(pixel_values_frame) - 1 num_grids.append(grid_idx) - video_forward_outs = self.mm_projector(video_forward_outs, - num_queries_vis_abstractors, - num_grids) + video_forward_outs = self.mm_projector( + video_forward_outs, num_queries_vis_abstractors, num_grids + ) video_features = [] # what we want to return target_features = [] @@ -929,14 +898,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): target_group_size = 0 elif video_group_size < target_group_size: - raise RuntimeError( - f"{video_group_size=} < {target_group_size=}") + raise RuntimeError(f"{video_group_size=} < {target_group_size=}") - assert len(target_features - ) == 0, f"target_features is not empty!! {target_features}" + assert len(target_features) == 0, ( + f"target_features is not empty!! 
{target_features}" + ) assert len(video_groups) == len(video_features) - return video_features + feats_per_video = [len(video) for video in pixel_values_videos] + idxs_per_video = [0, *accumulate(feats_per_video)] + return tuple( + torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]]) + for i in range(len(feats_per_video)) + ) def _prepare_multimodal_kwargs(self, **kwargs: object): output = defaultdict(list) @@ -945,7 +919,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): continue # if empty batch of empty sample new_k, is_video = k, False - if (not k.endswith("_images") and not k.endswith("_videos")): + if not k.endswith("_images") and not k.endswith("_videos"): pass else: new_k, is_video = k.split("_")[:-1], k.split("_")[-1] @@ -972,10 +946,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights( self, @@ -1000,10 +972,10 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if i * j <= config.max_num_grids: possible_resolutions.append([i, j]) - possible_resolutions = [[ - ys * vision_config.image_size, - xs * vision_config.image_size - ] for ys, xs in possible_resolutions] + possible_resolutions = [ + [ys * vision_config.image_size, xs * vision_config.image_size] + for ys, xs in possible_resolutions + ] return possible_resolutions else: return config.possible_resolutions @@ -1016,14 +988,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ): input_hidden_size = vision_config.hidden_size if config.mm_projector_type == "linear": - mm_projector = nn.Linear(input_hidden_size, - text_config.hidden_size) + mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) mm_projector.dtype = next(mm_projector.parameters()).dtype elif config.mm_projector_type == "cabstractor": mm_projector = HCXVisionCAbstractor( num_queries=config.num_queries_vis_abstractor_image, - num_input_tokens=(vision_config.image_size // - vision_config.patch_size)**2, + num_input_tokens=(vision_config.image_size // vision_config.patch_size) + ** 2, encoder_hidden_size=input_hidden_size, hidden_size=input_hidden_size, output_hidden_size=text_config.hidden_size, @@ -1040,8 +1011,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return mm_projector -def unpad_image(tensor: torch.Tensor, - original_size: tuple[int, int]) -> torch.Tensor: +def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor: original_width, original_height = original_size current_height, current_width = tensor.shape[1:] @@ -1052,18 +1022,17 @@ def unpad_image(tensor: torch.Tensor, scale_factor = current_width / original_width new_height = int(original_height * scale_factor) padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding:current_height - padding, :] + unpadded_tensor = tensor[:, padding : current_height - padding, :] else: scale_factor = current_height / original_height new_width = int(original_width * scale_factor) padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding:current_width - padding] + unpadded_tensor = tensor[:, :, padding : current_width - padding] return unpadded_tensor -def select_best_resolution(original_size: tuple, - 
possible_resolutions: list) -> tuple: +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: original_height, original_width = original_size best_fit = None max_effective_resolution = 0 @@ -1071,15 +1040,19 @@ def select_best_resolution(original_size: tuple, for height, width in possible_resolutions: scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int(original_width * scale), int( - original_height * scale) - effective_resolution = min(downscaled_width * downscaled_height, - original_width * original_height) + downscaled_width, downscaled_height = ( + int(original_width * scale), + int(original_height * scale), + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( - effective_resolution == max_effective_resolution - and wasted_resolution < min_wasted_resolution): + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (height, width) @@ -1092,12 +1065,16 @@ def get_anyres_image_grid_shape( grid_pinpoints: Union[str, list[tuple[int, int]]], patch_size: int, ) -> tuple[int, int]: - possible_resolutions = grid_pinpoints if isinstance( - grid_pinpoints, list) else ast.literal_eval(grid_pinpoints) + possible_resolutions = ( + grid_pinpoints + if isinstance(grid_pinpoints, list) + else ast.literal_eval(grid_pinpoints) + ) original_width, original_height = image_size - height, width = select_best_resolution((original_height, original_width), - possible_resolutions) + height, width = select_best_resolution( + (original_height, original_width), possible_resolutions + ) return width // patch_size, height // patch_size @@ -1115,12 +1092,15 @@ def reshape_and_unpad_image_features( image_feature = image_feature[1:] assert height * width == base_image_feature.shape[0], ( - f"{height=} * {width=} != {base_image_feature.shape[0]=}") + f"{height=} * {width=} != {base_image_feature.shape[0]=}" + ) num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, possible_resolutions, grid_size) - image_feature = image_feature.view(num_patch_height, num_patch_width, - height, width, -1) + image_size, possible_resolutions, grid_size + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) if unpad: image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() @@ -1129,8 +1109,9 @@ def reshape_and_unpad_image_features( image_feature = torch.cat( ( image_feature, - image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to(image_feature.device), + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), ), dim=-1, ) @@ -1144,20 +1125,21 @@ def reshape_and_unpad_image_features( def anyres_postprocessing( - image_forward_outs: list[torch.FloatTensor], + image_forward_outs: list[torch.Tensor], image_sizes: list[list[int]], possible_resolutions: list[tuple[int, int]], patch_size: int, grid_size: int, - image_newline: torch.FloatTensor, + image_newline: torch.Tensor, num_queries_vis_abstractor: int = -1, unpad: bool = False, -) -> list[torch.FloatTensor]: +) -> list[torch.Tensor]: height = width = grid_size // patch_size if num_queries_vis_abstractor > 0: - assert (num_queries_vis_abstractor**0.5 - 
).is_integer(), "n_queries must be square number" + assert (num_queries_vis_abstractor**0.5).is_integer(), ( + "n_queries must be square number" + ) height = width = int(num_queries_vis_abstractor**0.5) # post-processing (unpad, add newline) @@ -1177,29 +1159,8 @@ def anyres_postprocessing( else: image_feature = image_feature[0] image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature.device)), - dim=0) + (image_feature, image_newline[None].to(image_feature.device)), dim=0 + ) new_image_features.append(image_feature) - image_features = new_image_features - return image_features - -def resize_image( - image: Union[np.ndarray, PIL.Image.Image], - max_side: int = 378, -) -> np.ndarray: - image_arr = image - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - - width, height = image.size - cur_max_size = max(width, height) - if cur_max_size <= max_side: - return image_arr - - scale = max_side / cur_max_size - width = int(width * scale) - height = int(height * scale) - image = image.resize((width, height), Image.LANCZOS) - image_arr = np.array(image) - return image_arr + return new_image_features diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 88b2a295905b7..02c46a11a1798 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -24,18 +24,22 @@ from typing import Optional import torch from torch import nn from transformers.models.idefics2.configuration_idefics2 import ( - Idefics2Config, Idefics2VisionConfig) + Idefics2Config, + Idefics2VisionConfig, +) from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import run_dp_sharded_vision_model + +from .vision import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): @@ -67,13 +71,14 @@ class Idefics2VisionEmbeddings(nn.Module): self.num_patches_per_side = self.image_size // self.patch_size self.num_patches = self.num_patches_per_side**2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - def forward(self, - pixel_values: torch.FloatTensor, - patch_attention_mask: torch.BoolTensor, - tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) @@ -82,14 +87,14 @@ class Idefics2VisionEmbeddings(nn.Module): max_im_h // self.patch_size, max_im_w // self.patch_size, ) - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, - 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, - max_nb_patches_h * 
max_nb_patches_w), - fill_value=0) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: nb_patches_h = tgt_sizes[batch_idx][0] nb_patches_w = tgt_sizes[batch_idx][1] @@ -98,17 +103,18 @@ class Idefics2VisionEmbeddings(nn.Module): nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - bucket_coords_h = torch.bucketize(fractional_coords_h, - boundaries, - right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, - boundaries, - right=True) - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + - bucket_coords_w).flatten() + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) - embeddings = embeddings + self.position_embedding(position_ids) + embeddings += self.position_embedding(position_ids) return embeddings @@ -130,48 +136,35 @@ class Idefics2VisionAttention(nn.Module): if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() assert self.num_heads % tp_size == 0 self.num_heads_per_partition = self.num_heads // tp_size - if use_data_parallel: - self.q_size = self.num_heads * self.head_dim - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = ReplicatedLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, + ) + # Use unified MultiHeadAttention with Flash Attention support + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward( self, @@ -181,13 +174,14 @@ class Idefics2VisionAttention(nn.Module): hidden_states ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim 
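The position-id computation reformatted above in `Idefics2VisionEmbeddings.forward` maps a smaller-than-maximum patch grid onto the full learned position table by bucketizing fractional patch coordinates. A minimal standalone sketch of that mapping, using assumed toy sizes (a 4x4 maximum grid and an image that only fills a 2x2 grid), not the module itself:

```python
# Standalone illustration of the bucketized position-id assignment shown above
# (toy sizes; the real values come from the Idefics2 vision config).
import torch

num_patches_per_side = 4          # assume a 4x4 maximum patch grid
nb_patches_h = nb_patches_w = 2   # this image only produces a 2x2 patch grid

# Bucket boundaries split [0, 1) into num_patches_per_side bins.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# Each real patch lands on a cell of the full 4x4 position grid.
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
print(pos_ids)  # tensor([ 0,  2,  8, 10])
```

Partially filled images therefore reuse a spread-out subset of the same learned position embeddings rather than a separate table per resolution.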
query_states, key_states, value_states = qkv.chunk(3, dim=-1) + + # Use unified MultiHeadAttention implementation out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output class Idefics2VisionMLP(nn.Module): - def __init__( self, config: Idefics2VisionConfig, @@ -198,23 +192,21 @@ class Idefics2VisionMLP(nn.Module): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1( + self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2( + self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -225,7 +217,6 @@ class Idefics2VisionMLP(nn.Module): class Idefics2EncoderLayer(nn.Module): - def __init__( self, config: Idefics2Config, @@ -239,15 +230,16 @@ class Idefics2EncoderLayer(nn.Module): config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Idefics2VisionMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + use_data_parallel=use_data_parallel, + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -262,11 +254,11 @@ class Idefics2EncoderLayer(nn.Module): residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states = self.self_attn(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + hidden_states += residual return hidden_states @@ -298,13 +290,17 @@ class Idefics2Encoder(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - Idefics2EncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -327,7 +323,6 @@ class Idefics2Encoder(nn.Module): class Idefics2VisionTransformer(nn.Module): - def __init__( self, config: Idefics2VisionConfig, @@ -349,7 +344,8 @@ class Idefics2VisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + ) num_hidden_layers = config.num_hidden_layers if len(self.encoder.layers) > 
config.num_hidden_layers: @@ -359,10 +355,14 @@ class Idefics2VisionTransformer(nn.Module): ) self.require_post_norm = require_post_norm - self.post_layernorm = nn.LayerNorm( - embed_dim, - eps=config.layer_norm_eps, - ) if require_post_norm else nn.Identity() + self.post_layernorm = ( + nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) + if require_post_norm + else nn.Identity() + ) def get_input_embeddings(self): return self.embeddings @@ -379,39 +379,13 @@ class Idefics2VisionTransformer(nn.Module): tgt_sizes=tgt_sizes, ) if self.use_data_parallel: - encoder_outputs = run_dp_sharded_vision_model( - hidden_states, self.encoder) + encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder) else: encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def _consolidate_qkv_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> Iterable[tuple[str, torch.Tensor]]: - qkv_idx_mappings = { - ".self_attn.q_proj": 0, - ".self_attn.k_proj": 1, - ".self_attn.v_proj": 2, - } - qkv_weights = {} - for name, loaded_weight in weights: - for weight_name, idx in qkv_idx_mappings.items(): - if weight_name not in name: - continue - new_name = name.replace(weight_name, ".self_attn.qkv_proj") - if new_name not in qkv_weights: - qkv_weights[new_name] = [None] * 3 - qkv_weights[new_name][idx] = loaded_weight - break - else: - yield name, loaded_weight - for key, weight in qkv_weights.items(): - qkv_weight = torch.cat(weight, dim=0) - yield key, qkv_weight - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -422,17 +396,13 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) - if self.use_data_parallel: - weights = self._consolidate_qkv_weights(weights) - for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): continue # post_layernorm is optional - if (name.startswith("post_layernorm.") - and not self.require_post_norm): + if name.startswith("post_layernorm.") and not self.require_post_norm: continue # omit layers when num_hidden_layers_override is set @@ -451,8 +421,7 @@ class Idefics2VisionTransformer(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 63307470d959b..effdbdc1ac384 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -22,39 +22,45 @@ from typing import Annotated, Literal, Optional, Union import torch from torch import nn -from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, - Idefics3Processor) +from transformers import ( + BatchFeature, + Idefics3Config, + Idefics3ImageProcessor, + Idefics3Processor, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import 
QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageProcessorItems, ImageSize -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalDataItems, PromptReplacement, - PromptUpdate, PromptUpdateDetails) -# yapf: enable +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalDataItems, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -# yapf: disable from .idefics2_vision_model import ( - Idefics2VisionTransformer as Idefics3VisionTransformer) -# yapf: enable + Idefics2VisionTransformer as Idefics3VisionTransformer, +) from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .llama import LlamaModel -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, maybe_prefix class Idefics3ImagePixelInputs(TensorSchema): @@ -66,9 +72,10 @@ class Idefics3ImagePixelInputs(TensorSchema): - h: Height - w: Width """ + type: Literal["pixel_values"] pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] - pixel_attention_mask: torch.Tensor + pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -79,6 +86,7 @@ class Idefics3ImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] @@ -87,20 +95,21 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] class Idefics3ProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> Idefics3Processor: return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def _resize_output_size(self, - *, - height: int, - width: int, - max_len: Optional[int] = None, - min_len: int = 1, - max_size: Optional[int] = None) -> tuple[int, int]: + def _resize_output_size( + self, + *, + height: int, + width: int, + max_len: Optional[int] = None, + min_len: int = 1, + max_size: Optional[int] = None, + ) -> tuple[int, int]: # Set default value for max_len if not provided max_len = max(height, width) if max_len is None else max_len aspect_ratio = width / height @@ -136,18 +145,19 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ) -> tuple[int, int]: hf_processor = self.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor - max_image_size = image_processor.size['longest_edge'] + max_image_size = image_processor.size["longest_edge"] if resolution_max_side > max_image_size: raise ValueError( - "`resolution_max_side` cannot be larger 
than `max_image_size`") + "`resolution_max_side` cannot be larger than `max_image_size`" + ) height, width = image_height, image_width # Find the output size, when rescaling the longest edge to max_len and # preserving the aspect ratio - height, width = self._resize_output_size(height=height, - width=width, - max_len=resolution_max_side) + height, width = self._resize_output_size( + height=height, width=width, max_len=resolution_max_side + ) return height, width def _get_image_feature_grid_size( @@ -162,12 +172,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): image_processor: Idefics3ImageProcessor = processor.image_processor - max_image_size = image_processor.max_image_size['longest_edge'] - size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size["longest_edge"] + size = image_processor.size["longest_edge"] assert size % max_image_size == 0, ( "`longest_edge` in image_processor's `size` must be divisible by " "`longest_edge` in `max_image_size`, this may be caused by " - "incorrect mm_kwargs override.") + "incorrect mm_kwargs override." + ) resized_height, resized_width = self._get_resize_output_image_size( image_width=image_width, @@ -197,8 +208,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): return grid_w * grid_h + 1 def _get_image_token( - self, - processor: Optional[Idefics3Processor]) -> tuple[str, str, str]: + self, processor: Optional[Idefics3Processor] + ) -> tuple[str, str, str]: if processor is None: processor = self.get_hf_processor() @@ -218,7 +229,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() image_token, fake_image_token, global_img_token = self._get_image_token( - processor) + processor + ) image_seq_len = processor.image_seq_len grid_placeholder = "<row_{n_h}_col_{n_w}>" @@ -237,19 +249,20 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): tiles_placeholder = list[str]() for i in range(grid_h): for j in range(grid_w): - placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, - n_w=j + 1) + placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, n_w=j + 1) tiles_placeholder.append(placeholder_per_tile) # Add line break if it is the last tile in the row if j == grid_w - 1: tiles_placeholder.append("\n") - return "".join([ - *tiles_placeholder, - "\n", - global_img_placeholder, - fake_image_token, - ]) + return "".join( + [ + *tiles_placeholder, + "\n", + global_img_placeholder, + fake_image_token, + ] + ) def get_num_image_tokens( self, @@ -279,9 +292,7 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ) -class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] - ): - +class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -294,23 +305,26 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) hf_processor = self.info.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor - longest_edge = image_processor.max_image_size['longest_edge'] + longest_edge = image_processor.max_image_size["longest_edge"] + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=longest_edge, - height=longest_edge, - 
num_images=num_images) + "image": self._get_dummy_images( + width=longest_edge, + height=longest_edge, + num_images=num_images, + overrides=image_overrides, + ) } -class Idefics3MultiModalProcessor( - BaseMultiModalProcessor[Idefics3ProcessingInfo]): - +class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -331,9 +345,11 @@ class Idefics3MultiModalProcessor( tok_kwargs, ) - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] @@ -344,7 +360,8 @@ class Idefics3MultiModalProcessor( image_width=size.width, image_height=size.height, processor=hf_processor, - ) for size in image_sizes + ) + for size in image_sizes ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -362,10 +379,10 @@ class Idefics3MultiModalProcessor( num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + "image", num_patches + ), image_embeds=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"), ) @@ -405,7 +422,6 @@ class Idefics3MultiModalProcessor( class Idefics3SimpleMLP(nn.Module): - def __init__( self, config: Idefics3Config, @@ -413,8 +429,7 @@ class Idefics3SimpleMLP(nn.Module): prefix: str = "", ): super().__init__() - input_size = config.vision_config.hidden_size * (config.scale_factor** - 2) + input_size = config.vision_config.hidden_size * (config.scale_factor**2) output_size = config.text_config.hidden_size self.proj = ReplicatedLinear( input_size, @@ -430,7 +445,6 @@ class Idefics3SimpleMLP(nn.Module): class Idefics3Connector(nn.Module): - def __init__( self, config: Idefics3Config, @@ -445,14 +459,11 @@ class Idefics3Connector(nn.Module): prefix=maybe_prefix(prefix, "modality_projection"), ) - def pixel_shuffle(self, - x: torch.Tensor, - scale_factor: int = 2) -> torch.Tensor: + def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor: bsz, seq, embed_dim = x.size() height = width = int(seq**0.5) x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), - embed_dim * scale_factor) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) x = x.permute(0, 2, 1, 3) x = x.reshape( bsz, @@ -461,19 +472,16 @@ class Idefics3Connector(nn.Module): embed_dim * (scale_factor**2), ) x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), - embed_dim * (scale_factor**2)) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) return x def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: - image_hidden_states = self.pixel_shuffle(image_hidden_states, - self.scale_factor) + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states class Idefics3Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -485,7 +493,8 @@ class Idefics3Model(nn.Module): 
self.vision_model = Idefics3VisionTransformer( config.vision_config, quant_config=quant_config, - prefix=maybe_prefix(prefix, "vision_model")) + prefix=maybe_prefix(prefix, "vision_model"), + ) self.connector = Idefics3Connector( config, quant_config, @@ -497,8 +506,9 @@ class Idefics3Model(nn.Module): ) self.image_seq_len = int( - ((config.vision_config.image_size // - config.vision_config.patch_size)**2) / (config.scale_factor**2)) + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) self.image_token_id = self.config.image_token_id def image_pixels_to_features( @@ -515,21 +525,21 @@ class Idefics3Model(nn.Module): # Remove padding images - padding images are full 0. nb_values_per_image = pixel_values.shape[1:].numel() real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3)) != nb_values_per_image + dim=(-1, -2, -3) + ) != nb_values_per_image pixel_values = pixel_values[real_images_inds].contiguous() # Handle the vision attention mask # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask[ - real_images_inds].contiguous() + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, - size=patch_size, - step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, - size=patch_size, - step=patch_size) + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() # Get sequence from the vision encoder @@ -540,10 +550,7 @@ class Idefics3Model(nn.Module): return image_hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.text_model.get_input_embeddings(input_ids) def forward( @@ -553,7 +560,6 @@ class Idefics3Model(nn.Module): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.text_model( input_ids, positions, @@ -566,9 +572,11 @@ class Idefics3Model(nn.Module): @MULTIMODAL_REGISTRY.register_processor( Idefics3MultiModalProcessor, info=Idefics3ProcessingInfo, - dummy_inputs=Idefics3DummyInputsBuilder) -class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA): + dummy_inputs=Idefics3DummyInputsBuilder, +) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -598,21 +606,24 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.model = Idefics3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Idefics3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.image_token_id = self.config.image_token_id self.lm_head = ParallelLMHead( config.text_config.vocab_size, config.text_config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.text_config.tie_word_embeddings: - self.lm_head.weight = self.model.text_model.wte.weight + self.lm_head.weight = 
self.model.text_model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[ImageInputs]: + self, **kwargs: object + ) -> Optional[ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -620,47 +631,27 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return Idefics3ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_attention_mask = kwargs.pop("pixel_attention_mask") - if not isinstance(pixel_attention_mask, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel_attention_mask. " - f"Got type: {type(pixel_attention_mask)}") - num_patches = kwargs.pop("num_patches") - if not isinstance(num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_patches. " - f"Got type: {type(num_patches)}") - expected_h = expected_w = self.config.vision_config.image_size + return Idefics3ImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - pixel_attention_mask=flatten_bn(pixel_attention_mask, - concat=True), - num_patches=flatten_bn(num_patches, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + resolve_bindings={"h": expected_h, "w": expected_w}, ) raise AssertionError("This line should be unreachable.") - def _process_image_pixels( - self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: pixel_values = inputs["pixel_values"] pixel_attention_mask = inputs["pixel_attention_mask"] @@ -680,37 +671,18 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, image_features = self.model.connector(image_features) num_patches = image_input["num_patches"] - return [ - e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) - ] + return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())] def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -722,29 +694,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: 
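For reference, the `image_seq_len` expression reformatted above in `Idefics3Model.__init__` is just the patch count divided by the square of the pixel-shuffle factor: the connector's `pixel_shuffle` folds each `scale_factor x scale_factor` block of patches into the channel dimension, shrinking the sequence accordingly. A hedged arithmetic sketch with example config values (chosen for illustration, not read from any checkpoint here):

```python
# Rough sketch of the token-count arithmetic behind Idefics3Model.image_seq_len
# (example values; check the vision_config of the checkpoint you actually load).
image_size = 364      # longest edge fed to the vision tower
patch_size = 14       # ViT patch size
scale_factor = 2      # pixel-shuffle factor used by Idefics3Connector

patches_per_side = image_size // patch_size             # 26
vision_tokens = patches_per_side ** 2                   # 676 patch embeddings
image_seq_len = vision_tokens // (scale_factor ** 2)    # 169 tokens after pixel shuffle
print(patches_per_side, vision_tokens, image_seq_len)   # 26 676 169
```

In shape terms the shuffle maps `(bsz, seq, embed_dim)` to `(bsz, seq / scale_factor**2, embed_dim * scale_factor**2)`, so the language model sees fewer but wider image tokens.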
inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.model.text_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model.text_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -755,4 +715,5 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, return MultiModelKeys.from_string_field( language_model="model.text_model", connector="model.connector", - tower_model="model.vision_model") + tower_model="model.vision_model", + ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9415e67924e74..68915d60ef480 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,12 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, MutableSequence -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - TypeVar, Union, overload, runtime_checkable) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + Literal, + Optional, + Protocol, + Union, + overload, + runtime_checkable, +) import numpy as np import torch from torch import Tensor +from transformers import PretrainedConfig from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs @@ -15,14 +25,12 @@ from vllm.config import ModelConfig, SpeechToTextConfig from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils import supports_kw -from .interfaces_base import is_pooling_model +from .interfaces_base import VllmModel, is_pooling_model if TYPE_CHECKING: - from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors @@ -52,6 +60,24 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_multimodal_raw_input_only: ClassVar[bool] = False + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + """ + + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. 
+ """ + + merge_by_field_config: ClassVar[bool] = False + """ + A flag that indicates which implementation of + `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use. + """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -59,10 +85,9 @@ class SupportsMultiModal(Protocol): """ ... - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: """ - Returns multimodal embeddings generated from multimodal kwargs + Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. Note: @@ -72,11 +97,11 @@ class SupportsMultiModal(Protocol): """ ... - def get_language_model(self) -> torch.nn.Module: + def get_language_model(self) -> VllmModel: """ Returns the underlying language model used for text generation. - This is typically the `torch.nn.Module` instance responsible for + This is typically the `torch.nn.Module` instance responsible for processing the merged multimodal embeddings and producing hidden states Returns: @@ -84,51 +109,130 @@ class SupportsMultiModal(Protocol): """ ... - # Only for models that support v0 chunked prefill - # TODO(ywang96): Remove this overload once v0 is deprecated @overload - def get_input_embeddings( - self, - input_ids: Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - ) -> Tensor: - ... + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ... - # TODO: Remove this overload once v0 is deprecated @overload def get_input_embeddings( self, input_ids: Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings, + *, + is_multimodal: torch.Tensor, + handle_oov_mm_token: bool = False, + ) -> Tensor: ... + + def _get_text_embeddings( + self, + input_ids: Tensor, + get_input_embeddings: Callable[[Tensor], Tensor], + *, + is_multimodal: Optional[Tensor], + handle_oov_mm_token: bool, ) -> Tensor: - ... + if handle_oov_mm_token and is_multimodal is not None: + is_text = ~is_multimodal + text_embeds = get_input_embeddings(input_ids[is_text]) + + return torch.empty( + (input_ids.shape[0], text_embeds.shape[1]), + dtype=text_embeds.dtype, + device=text_embeds.device, + ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) + + return get_input_embeddings(input_ids) def get_input_embeddings( self, input_ids: Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - # Only necessary so that the v0 overload is valid - # TODO: Remove attn_metadata once v0 is deprecated - attn_metadata: Optional["AttentionMetadata"] = None, + *, + is_multimodal: Optional[Tensor] = None, + handle_oov_mm_token: bool = False, ) -> Tensor: """ - Returns the input embeddings merged from the text embeddings from - input_ids and the multimodal embeddings generated from multimodal - kwargs. + Apply token embeddings to `input_ids`. + + If `multimodal_embeddings` is passed, scatter them into + `input_ids` according to the mask `is_multimodal`. + + In case the multi-modal token IDs exceed the vocabulary size of + the language model, you can set `handle_oov_mm_token=False` + to avoid calling the language model's `get_input_embeddings` method + on those tokens. Note however that doing so increases memory usage + as an additional buffer is needed to hold the input embeddings. 
+ """ + from .utils import _merge_multimodal_embeddings + + inputs_embeds = self._get_text_embeddings( + input_ids, + self.get_language_model().get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." + ) + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + +@runtime_checkable +class SupportsMultiModalPruning(Protocol): + """The interface required for models that support returning both input + embeddings and positions. Model may require custom positions for dynamic + pruning of multimodal embeddings. + """ + + supports_multimodal_pruning: ClassVar[Literal[True]] = True + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: MultiModalEmbeddings, + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[MultiModalEmbeddings, Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt containing + entire sequence. + multimodal_embeddings: Tuple of multimodal embeddings that + fits into the prefill chunk that is being processed. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). """ ... @overload -def supports_multimodal( - model: type[object]) -> TypeIs[type[SupportsMultiModal]]: - ... +def supports_multimodal(model: type[object]) -> TypeIs[type[SupportsMultiModal]]: ... @overload -def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: - ... +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ... def supports_multimodal( @@ -137,38 +241,28 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) -@runtime_checkable -class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): - """The interface required for all multi-modal models.""" +def supports_multimodal_raw_input_only(model: Union[type[object], object]) -> bool: + return getattr(model, "supports_multimodal_raw_input_only", False) - supports_multimodal_raw_input: ClassVar[Literal[True]] = True - """ - A flag that indicates this model supports multi-modal inputs and processes - them in their raw form and not embeddings. - Note: - There is no need to redefine this flag if this class is in the - MRO of your model class. - """ +def supports_multimodal_encoder_tp_data(model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) @overload -def supports_multimodal_raw_input( - model: object) -> TypeIs[SupportsMultiModalWithRawInput]: - ... +def supports_multimodal_pruning( + model: type[object], +) -> TypeIs[type[SupportsMultiModalPruning]]: ... 
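The default `get_input_embeddings` shown above embeds the text tokens and then scatters precomputed multimodal embeddings into the positions flagged by `is_multimodal`; with `handle_oov_mm_token` enabled, the embedding table is only ever applied to text positions, so out-of-vocabulary placeholder IDs never reach it. A minimal, self-contained sketch of that masking pattern (illustrative only; it does not reproduce `_merge_multimodal_embeddings`, whose body is not part of this diff):

```python
# Minimal sketch of the mask-based merge performed by the default
# get_input_embeddings (illustration, not vLLM's _merge_multimodal_embeddings).
import torch

hidden = 8
embed = torch.nn.Embedding(100, hidden)

input_ids = torch.tensor([5, 7, 999, 999, 11])   # 999 stands in for an OOV image token
is_multimodal = torch.tensor([False, False, True, True, False])

# Embed only the text positions, then write them into a fresh buffer, so the
# embedding table never sees the placeholder ID 999.
text_embeds = embed(input_ids[~is_multimodal])
inputs_embeds = torch.empty(input_ids.shape[0], hidden, dtype=text_embeds.dtype)
inputs_embeds[~is_multimodal] = text_embeds

# Scatter the precomputed multimodal embeddings into the masked positions.
mm_embeds = torch.randn(int(is_multimodal.sum()), hidden)
inputs_embeds[is_multimodal] = mm_embeds
print(inputs_embeds.shape)  # torch.Size([5, 8])
```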
@overload -def supports_multimodal_raw_input( - model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: - ... +def supports_multimodal_pruning(model: object) -> TypeIs[SupportsMultiModalPruning]: ... -def supports_multimodal_raw_input( - model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], - TypeIs[SupportsMultiModalWithRawInput]]: - return getattr(model, "supports_multimodal_raw_input", False) +def supports_multimodal_pruning( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsMultiModalPruning]], TypeIs[SupportsMultiModalPruning]]: + return getattr(model, "supports_multimodal_pruning", False) @runtime_checkable @@ -188,7 +282,7 @@ class SupportsScoreTemplate(Protocol): def get_score_template(cls, query: str, document: str) -> Optional[str]: """ Generate a full prompt by populating the score template with query and document content. - """ # noqa: E501 + """ # noqa: E501 ... @classmethod @@ -201,13 +295,12 @@ class SupportsScoreTemplate(Protocol): @overload def supports_score_template( - model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]: - ... + model: type[object], +) -> TypeIs[type[SupportsScoreTemplate]]: ... @overload -def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: - ... +def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: ... def supports_score_template( @@ -232,7 +325,7 @@ class SupportsLoRA(Protocol): # are empty by default. embedding_modules: ClassVar[dict[str, str]] = {} embedding_padding_modules: ClassVar[list[str]] = [] - packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} + packed_modules_mapping: dict[str, list[str]] = {} # We can't use runtime_checkable with ClassVar for issubclass checks @@ -247,13 +340,11 @@ class _SupportsLoRAType(Protocol): @overload -def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: - ... +def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: ... @overload -def supports_lora(model: object) -> TypeIs[SupportsLoRA]: - ... +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ... def supports_lora( @@ -267,8 +358,7 @@ def supports_lora( "embedding_modules", "embedding_padding_modules", ) - missing_attrs = tuple(attr for attr in lora_attrs - if not hasattr(model, attr)) + missing_attrs = tuple(attr for attr in lora_attrs if not hasattr(model, attr)) if getattr(model, "supports_lora", False): if missing_attrs: @@ -282,7 +372,9 @@ def supports_lora( if not missing_attrs: logger.warning( "The model (%s) contains all LoRA-specific attributes, " - "but does not set `supports_lora=True`.", model) + "but does not set `supports_lora=True`.", + model, + ) return result @@ -342,25 +434,21 @@ class _SupportsPPType(Protocol): batch_size: int, dtype: torch.dtype, device: torch.device, - ) -> "IntermediateTensors": - ... + ) -> "IntermediateTensors": ... def forward( self, *, intermediate_tensors: Optional["IntermediateTensors"], - ) -> Union[Tensor, "IntermediateTensors"]: - ... + ) -> Union[Tensor, "IntermediateTensors"]: ... @overload -def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: - ... +def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: ... @overload -def supports_pp(model: object) -> TypeIs[SupportsPP]: - ... +def supports_pp(model: object) -> TypeIs[SupportsPP]: ... 
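The `supports_*` helpers above all follow one pattern: a class-level flag checked with `getattr` at runtime, wrapped in a `TypeIs` return type so static checkers can narrow the model type after the call. A small generic sketch of that pattern with hypothetical names (not vLLM interfaces), using the same `typing_extensions.TypeIs` import the file relies on:

```python
# Sketch of the flag-plus-TypeIs narrowing pattern used by the supports_* helpers.
from typing import ClassVar, Literal, Protocol, Union, runtime_checkable

from typing_extensions import TypeIs


@runtime_checkable
class SupportsStreaming(Protocol):
    supports_streaming: ClassVar[Literal[True]] = True

    def stream(self) -> str: ...


def supports_streaming(model: Union[type[object], object]) -> TypeIs[SupportsStreaming]:
    # The class-level flag doubles as the runtime marker and lets
    # static checkers narrow the type at the call site.
    return getattr(model, "supports_streaming", False)


class MyModel:
    supports_streaming = True

    def stream(self) -> str:
        return "token"


model = MyModel()
if supports_streaming(model):
    print(model.stream())  # checkers now know `stream` exists on `model`
```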
def supports_pp( @@ -372,12 +460,13 @@ def supports_pp( if supports_attributes and not supports_inspect: logger.warning( "The model (%s) sets `supports_pp=True`, but does not accept " - "`intermediate_tensors` in its `forward` method", model) + "`intermediate_tensors` in its `forward` method", + model, + ) if not supports_attributes: - pp_attrs = ("make_empty_intermediate_tensors", ) - missing_attrs = tuple(attr for attr in pp_attrs - if not hasattr(model, attr)) + pp_attrs = ("make_empty_intermediate_tensors",) + missing_attrs = tuple(attr for attr in pp_attrs if not hasattr(model, attr)) if getattr(model, "supports_pp", False): if missing_attrs: @@ -391,7 +480,9 @@ def supports_pp( if not missing_attrs: logger.warning( "The model (%s) contains all PP-specific attributes, " - "but does not set `supports_pp=True`.", model) + "but does not set `supports_pp=True`.", + model, + ) return supports_attributes and supports_inspect @@ -424,17 +515,15 @@ class HasInnerState(Protocol): @overload -def has_inner_state(model: object) -> TypeIs[HasInnerState]: - ... +def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @overload -def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: - ... +def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: ... def has_inner_state( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: return getattr(model, "has_inner_state", False) @@ -453,17 +542,15 @@ class IsAttentionFree(Protocol): @overload -def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: - ... +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ... @overload -def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: - ... +def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: ... def is_attention_free( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: return getattr(model, "is_attention_free", False) @@ -471,7 +558,7 @@ def is_attention_free( @runtime_checkable class IsHybrid(Protocol): """The interface required for all models like Jamba that have both - attention and mamba blocks, indicates that + attention and mamba blocks, indicates that hf_config has 'layers_block_type'""" is_hybrid: ClassVar[Literal[True]] = True @@ -501,17 +588,15 @@ class IsHybrid(Protocol): @overload -def is_hybrid(model: object) -> TypeIs[IsHybrid]: - ... +def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... @overload -def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: - ... +def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: ... def is_hybrid( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: return getattr(model, "is_hybrid", False) @@ -562,7 +647,7 @@ class MixtureOfExperts(Protocol): ) -> None: """ Register the EPLB state in the MoE model. - + Since these are views of the actual EPLB state, any changes made by the EPLB algorithm are automatically reflected in the model's behavior without requiring additional method calls to set new states. @@ -582,8 +667,7 @@ class MixtureOfExperts(Protocol): self, num_physical_experts: int, num_local_physical_experts: int, - ) -> None: - ... + ) -> None: ... 
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: @@ -596,17 +680,15 @@ class HasNoOps(Protocol): @overload -def has_noops(model: object) -> TypeIs[HasNoOps]: - ... +def has_noops(model: object) -> TypeIs[HasNoOps]: ... @overload -def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: - ... +def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: ... def has_noops( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: return getattr(model, "has_noops", False) @@ -620,13 +702,12 @@ class SupportsCrossEncoding(Protocol): @overload def supports_cross_encoding( - model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]: - ... + model: type[object], +) -> TypeIs[type[SupportsCrossEncoding]]: ... @overload -def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: - ... +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ... def _supports_cross_encoding( @@ -641,23 +722,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -_T = TypeVar("_T", bound=type[torch.nn.Module]) - - -def default_pooling_type(pooling_type: str): - """Set default_pooling_type decorator. """ - - def func(model: _T) -> _T: - model.default_pooling_type = pooling_type # type: ignore - return model - - return func - - -def get_default_pooling_type(model: Union[type[object], object]) -> str: - return getattr(model, "default_pooling_type", "LAST") - - class SupportsQuant: """The interface required for all models that support quantization.""" @@ -671,7 +735,6 @@ class SupportsQuant: # find config passed in arguments quant_config = cls._find_quant_config(*args, **kwargs) if quant_config is not None: - # attach config to model for general use instance.quant_config = quant_config @@ -680,7 +743,8 @@ class SupportsQuant: instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper) if instance.packed_modules_mapping is not None: instance.quant_config.packed_modules_mapping.update( - instance.packed_modules_mapping) + instance.packed_modules_mapping + ) return instance @@ -703,6 +767,7 @@ class SupportsQuant: @runtime_checkable class SupportsTranscription(Protocol): """The interface required for all models that support transcription.""" + # Mapping from ISO639_1 language codes: language names supported_languages: ClassVar[Mapping[str, str]] @@ -723,14 +788,20 @@ class SupportsTranscription(Protocol): raise ValueError( f"{cls.__name__}.supported_languages contains invalid " f"language codes: {sorted(invalid)}\n. " - f"Valid choices are: {sorted(LANGUAGES.keys())}") + f"Valid choices are: {sorted(LANGUAGES.keys())}" + ) @classmethod - def get_generation_prompt(cls, audio: np.ndarray, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - language: Optional[str], task_type: str, - request_prompt: str) -> PromptType: + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: """Get the prompt for the ASR model. 
The model has control over the construction, as long as it returns a valid PromptType.""" @@ -739,17 +810,14 @@ class SupportsTranscription(Protocol): @classmethod def get_other_languages(cls) -> Mapping[str, str]: # other possible language codes from the whisper map - return { - k: v - for k, v in LANGUAGES.items() if k not in cls.supported_languages - } + return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages} @classmethod def validate_language(cls, language: Optional[str]) -> Optional[str]: """ - Ensure the language specified in the transcription request - is a valid ISO 639-1 language code. If the request language is - valid, but not natively supported by the model, trigger a + Ensure the language specified in the transcription request + is a valid ISO 639-1 language code. If the request language is + valid, but not natively supported by the model, trigger a warning (but not an exception). """ if language is None or language in cls.supported_languages: @@ -766,22 +834,25 @@ class SupportsTranscription(Protocol): else: raise ValueError( f"Unsupported language: {language!r}. Must be one of " - f"{list(cls.supported_languages.keys())}.") + f"{list(cls.supported_languages.keys())}." + ) @classmethod def get_speech_to_text_config( - cls, model_config: ModelConfig, - task_type: Literal["transcribe", - "translate"]) -> SpeechToTextConfig: + cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"] + ) -> SpeechToTextConfig: """Get the speech to text config for the ASR model.""" ... @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: """ - Map from audio duration to number of audio tokens produced by the ASR + Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. This is used for estimating the amount of processing for this audio. """ @@ -790,13 +861,12 @@ class SupportsTranscription(Protocol): @overload def supports_transcription( - model: type[object]) -> TypeIs[type[SupportsTranscription]]: - ... + model: type[object], +) -> TypeIs[type[SupportsTranscription]]: ... @overload -def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: - ... +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: ... def supports_transcription( @@ -813,13 +883,11 @@ class SupportsV0Only(Protocol): @overload -def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: - ... +def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: ... @overload -def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: - ... +def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: ... def supports_v0_only( @@ -830,7 +898,7 @@ def supports_v0_only( @runtime_checkable class SupportsEagle3(Protocol): - """The interface required for models that support + """The interface required for models that support EAGLE3 speculative decoding.""" supports_eagle3: ClassVar[Literal[True]] = True @@ -847,10 +915,10 @@ class SupportsEagle3(Protocol): """ Set which layers should output auxiliary hidden states for EAGLE3. - + Args: layers: Tuple of layer indices that should output auxiliary - hidden states. + hidden states. """ ... 
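`SupportsTranscription.validate_language` (reformatted above) implements a three-way check: accept natively supported ISO 639-1 codes, warn for codes known to Whisper but not to the model, and raise otherwise. A toy sketch of that flow with made-up language tables (the real classes use the Whisper `LANGUAGES` map):

```python
# Toy version of the three-way language check described above.
from typing import Optional

SUPPORTED = {"en": "english", "fr": "french"}
OTHER_KNOWN = {"de": "german"}


def validate_language(language: Optional[str]) -> Optional[str]:
    if language is None or language in SUPPORTED:
        return language
    if language in OTHER_KNOWN:
        print(f"warning: {language!r} is not natively supported; "
              "transcription quality may degrade")
        return language
    raise ValueError(
        f"Unsupported language: {language!r}. Must be one of {list(SUPPORTED)}."
    )


print(validate_language("fr"))  # 'fr'
print(validate_language("de"))  # warns, then returns 'de'
```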
@@ -858,7 +926,7 @@ class SupportsEagle3(Protocol): """ Get the layer indices that should output auxiliary hidden states for EAGLE3. - + Returns: Tuple of layer indices for auxiliary hidden state outputs. """ @@ -866,16 +934,79 @@ class SupportsEagle3(Protocol): @overload -def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: - ... +def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: ... @overload -def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: - ... +def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: ... def supports_eagle3( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]: return isinstance(model, SupportsEagle3) + + +@runtime_checkable +class SupportsMRoPE(Protocol): + """The interface required for all models that support M-RoPE.""" + + supports_mrope: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports M-RoPE. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """ + Get M-RoPE input positions and delta value for this specific model. + + This method should be implemented by each model that supports M-RoPE + to provide model-specific logic for computing input positions. + + Args: + input_tokens: List of input token IDs + hf_config: HuggingFace model configuration + image_grid_thw: Image grid dimensions (t, h, w) + video_grid_thw: Video grid dimensions (t, h, w) + second_per_grid_ts: Seconds per grid timestep for videos + context_len: Context length + seq_len: Sequence length + audio_feature_lengths: Audio feature lengths for multimodal models + use_audio_in_video: Whether to use audio in video for interleaving + + Returns: + Tuple of (llm_positions, mrope_position_delta) + - llm_positions: Tensor of shape [3, num_tokens] + with T/H/W positions + - mrope_position_delta: Delta for position calculations + """ + ... + + +@overload +def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]: ... + + +@overload +def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]: ... 
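`SupportsMRoPE.get_mrope_input_positions` returns a `[3, num_tokens]` tensor of T/H/W positions plus a position delta. As a rough illustration of that shape contract only: for a text-only prompt the three rows typically collapse to ordinary sequential positions and the delta is zero (an assumption stated here for the sketch; multimodal inputs need the model-specific grid logic the docstring describes):

```python
# Shape sketch for the SupportsMRoPE contract: a text-only prompt where the
# temporal/height/width rows all reduce to plain 1-D positions (illustrative).
import torch

num_tokens = 6
context_len = 0

positions_1d = torch.arange(context_len, num_tokens)
llm_positions = positions_1d.unsqueeze(0).expand(3, -1)  # shape [3, num_tokens]

# The delta records how far the M-RoPE positions run ahead of (or behind)
# the raw token count, so later decode steps can be offset consistently.
mrope_position_delta = int(llm_positions.max()) + 1 - num_tokens  # 0 for pure text

print(llm_positions.shape, mrope_position_delta)  # torch.Size([3, 6]) 0
```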
+ + +def supports_mrope( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]: + return isinstance(model, SupportsMRoPE) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 697fa020deb46..b697eb25b5cc2 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,7 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol, - Union, overload, runtime_checkable) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Optional, + Protocol, + Union, + overload, + runtime_checkable, +) import torch import torch.nn as nn @@ -13,11 +22,9 @@ from vllm.utils import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler - from vllm.model_executor.sampling_metadata import SamplingMetadata else: VllmConfig = Any Pooler = Any - SamplingMetadata = Any logger = init_logger(__name__) @@ -40,15 +47,20 @@ class VllmModel(Protocol[T_co]): self, vllm_config: VllmConfig, prefix: str = "", - ) -> None: + ) -> None: ... + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + """Apply token embeddings to `input_ids`.""" ... def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - ) -> T_co: - ... + ) -> T_co: ... def _check_vllm_model_init(model: Union[type[object], object]) -> bool: @@ -56,17 +68,27 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: return supports_kw(model_init, "vllm_config") +def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool: + model_get_input_embeddings = getattr(model, "get_input_embeddings", None) + if not callable(model_get_input_embeddings): + logger.warning( + "The model (%s) is missing the `get_input_embeddings` method.", + model, + ) + return False + + return True + + def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False vllm_kws = ("input_ids", "positions") - missing_kws = tuple(kw for kw in vllm_kws - if not supports_kw(model_forward, kw)) + missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw)) - if missing_kws and (isinstance(model, type) - and issubclass(model, nn.Module)): + if missing_kws and (isinstance(model, type) and issubclass(model, nn.Module)): logger.warning( "The model (%s) is missing " "vLLM-specific keywords from its `forward` method: %s", @@ -78,19 +100,21 @@ def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: @overload -def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: - ... +def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ... @overload -def is_vllm_model(model: object) -> TypeIs[VllmModel]: - ... +def is_vllm_model(model: object) -> TypeIs[VllmModel]: ... 
def is_vllm_model( model: Union[type[object], object], ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: - return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + return ( + _check_vllm_model_init(model) + and _check_vllm_model_get_input_embeddings(model) + and _check_vllm_model_forward(model) + ) @runtime_checkable @@ -100,7 +124,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - sampling_metadata: SamplingMetadata, ) -> Optional[T]: """Return `None` if TP rank > 0.""" ... @@ -108,20 +131,19 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): @overload def is_text_generation_model( - model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]: - ... + model: type[object], +) -> TypeIs[type[VllmModelForTextGeneration]]: ... @overload -def is_text_generation_model( - model: object) -> TypeIs[VllmModelForTextGeneration]: - ... +def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ... def is_text_generation_model( model: Union[type[object], object], -) -> Union[TypeIs[type[VllmModelForTextGeneration]], - TypeIs[VllmModelForTextGeneration]]: +) -> Union[ + TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration] +]: if not is_vllm_model(model): return False @@ -144,18 +166,27 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. + """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @overload -def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: - ... +def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ... @overload -def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: - ... +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... 
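# Illustrative sketch (not part of the diff): applying the
# `default_pooling_type` decorator mentioned in the docstring above (it is
# defined a few hunks below in this file). `MyRewardModel` is a hypothetical
# class; the decorator only stamps the class attribute that
# `get_default_pooling_type` reads back, falling back to "LAST" otherwise.
import torch.nn as nn

from vllm.model_executor.models.interfaces_base import (
    default_pooling_type,
    get_default_pooling_type,
)


@default_pooling_type("ALL")
class MyRewardModel(nn.Module):
    is_pooling_model = True


print(get_default_pooling_type(MyRewardModel))    # "ALL"
print(get_default_pooling_type(nn.Linear(2, 2)))  # "LAST" (no override set)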
def is_pooling_model( @@ -165,3 +196,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 58e8163e0b26e..9435ff0d26cff 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -17,26 +17,32 @@ import torch.nn.functional as F from transformers import PretrainedConfig from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from .vision import run_dp_sharded_vision_model + NORM2FN = { - 'rms_norm': RMSNorm, - 'layer_norm': nn.LayerNorm, + "rms_norm": RMSNorm, + "layer_norm": nn.LayerNorm, } class InternVisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -46,28 +52,36 @@ class InternVisionEmbeddings(nn.Module): self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) - self.patch_embedding = nn.Conv2d(in_channels=3, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size) + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter( - torch.randn(1, self.num_positions, self.embed_dim)) + torch.randn(1, self.num_positions, self.embed_dim) + ) def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int): target_dtype = pos_embed.dtype - pos_embed = pos_embed.float().reshape( - 1, self.image_size // self.patch_size, - self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) - pos_embed = F.interpolate(pos_embed, - size=(H, W), - mode='bicubic', - align_corners=False) - return pos_embed.reshape(1, -1, H * W).permute(0, 2, - 1).to(target_dtype) + pos_embed = ( + pos_embed.float() + .reshape( + 1, + self.image_size // self.patch_size, + self.image_size // self.patch_size, + -1, + ) + .permute(0, 3, 1, 2) + ) + pos_embed = F.interpolate( + pos_embed, size=(H, W), mode="bicubic", align_corners=False + ) + return pos_embed.reshape(1, -1, H * W).permute(0, 2, 
1).to(target_dtype) def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: position_embedding = self.position_embedding @@ -84,12 +98,12 @@ class InternVisionEmbeddings(nn.Module): def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - target_dtype)) # shape = [*, channel, width, height] + patch_embeds = self.patch_embedding( + pixel_values.to(target_dtype) + ) # shape = [*, channel, width, height] batch_size, _, height, width = patch_embeds.shape patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, - -1).to(target_dtype) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) position_embedding = self._get_position_embedding(height, width) embeddings = embeddings + position_embedding.to(target_dtype) @@ -97,7 +111,6 @@ class InternVisionEmbeddings(nn.Module): class InternVisionPatchModel(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -112,8 +125,7 @@ class InternVisionPatchModel(nn.Module): pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -121,8 +133,7 @@ class InternVisionPatchModel(nn.Module): if pixel_values.ndim == 4: hidden_states = self.embeddings(pixel_values) else: - raise ValueError( - f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") return hidden_states @@ -137,6 +148,7 @@ class InternParallelAttention(nn.Module): *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -146,17 +158,21 @@ class InternParallelAttention(nn.Module): self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() # Additional dummy heads are used to enable TP for common GPU counts. 
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim - self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads, - self.tp_size) + self.num_heads_per_partition = divide( + num_dummy_heads + self.num_heads, self.tp_size + ) self.scale = self.head_dim**-0.5 self.qkv = QKVParallelLinear( @@ -166,27 +182,34 @@ class InternParallelAttention(nn.Module): bias=config.qkv_bias, quant_config=quant_config, prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, ) self.qk_normalization = config.qk_normalization if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) + self.q_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) + self.k_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) self.proj = RowParallelLinear( self.dummy_dim, self.embed_dim, quant_config=quant_config, prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, ) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: @@ -195,8 +218,7 @@ class InternParallelAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -214,93 +236,34 @@ class InternParallelAttention(nn.Module): return out -class InternSdpaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: PretrainedConfig, - *, - num_dummy_heads: int = 0, - ) -> None: - super().__init__() - - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') - - # Additional dummy heads are used to enable TP for common GPU counts. 
- self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim - - self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, - 3 * self.dummy_dim, - bias=config.qkv_bias) - - self.qk_normalization = config.qk_normalization - - if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - - self.proj = nn.Linear(self.dummy_dim, self.embed_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = qkv.chunk(3, dim=-1) - - q = q.view(B, N, self.num_heads, self.head_dim) - k = k.view(B, N, self.num_heads, self.head_dim) - v = v.view(B, N, self.num_heads, self.head_dim) - - if self.qk_normalization: - B_, N_, H_, D_ = q.shape - q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) - k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, -1) - - x = self.proj(x) - return x - - class InternMLP(nn.Module): - def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -311,7 +274,6 @@ class InternMLP(nn.Module): class InternVisionEncoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -319,6 +281,7 @@ class InternVisionEncoderLayer(nn.Module): *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -326,23 +289,25 @@ class InternVisionEncoderLayer(nn.Module): self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = self._init_attn(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attn") + self.attn = self._init_attn( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) - self.mlp = InternMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.norm1 = NORM2FN[self.norm_type](self.embed_dim, - eps=config.layer_norm_eps) - self.norm2 = NORM2FN[self.norm_type](self.embed_dim, - eps=config.layer_norm_eps) + self.mlp = InternMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, 
eps=config.layer_norm_eps) - self.ls1 = nn.Parameter(config.initializer_factor * - torch.ones(self.embed_dim)) - self.ls2 = nn.Parameter(config.initializer_factor * - torch.ones(self.embed_dim)) + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) def _init_attn( self, @@ -351,34 +316,37 @@ class InternVisionEncoderLayer(nn.Module): *, num_dummy_heads: int, prefix: str = "", + use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable - tp_size = get_tensor_model_parallel_world_size() + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads - if (num_heads + num_dummy_heads) % tp_size == 0: - return InternParallelAttention(config, - quant_config=quant_config, - num_dummy_heads=num_dummy_heads, - prefix=prefix) - - return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) + # if the number of heads is not divisible by tp_size, + # we also disable Attention's TP + use_data_parallel = ( + use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0 + ) + return InternParallelAttention( + config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads, + prefix=prefix, + use_data_parallel=use_data_parallel, + ) def forward( self, hidden_states: torch.Tensor, ): - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states)) * self.ls1 + hidden_states = hidden_states + self.attn(self.norm1(hidden_states)) * self.ls1 - hidden_states = hidden_states + self.mlp( - self.norm2(hidden_states)) * self.ls2 + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2 return hidden_states class InternVisionEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -387,6 +355,7 @@ class InternVisionEncoder(nn.Module): num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() @@ -397,16 +366,20 @@ class InternVisionEncoder(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - InternVisionEncoderLayer(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + InternVisionEncoderLayer( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): - hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) @@ -415,7 +388,6 @@ class InternVisionEncoder(nn.Module): class InternVisionModel(nn.Module): - packed_modules_mapping = { "qkv": ["qkv"], } @@ -428,10 +400,12 @@ class InternVisionModel(nn.Module): num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = InternVisionEmbeddings(config) self.encoder = InternVisionEncoder( @@ -440,6 +414,7 @@ class InternVisionModel(nn.Module): num_hidden_layers_override=num_hidden_layers_override, num_dummy_heads=num_dummy_heads, prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, ) def get_input_embeddings(self): @@ -451,8 +426,7 
@@ class InternVisionModel(nn.Module): pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -460,21 +434,21 @@ class InternVisionModel(nn.Module): if pixel_values.ndim == 4: hidden_states = self.embeddings(pixel_values) else: - raise ValueError( - f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") - encoder_outputs = self.encoder(inputs_embeds=hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d0c4bf5450d6d..128791541b3db 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from functools import partial +from itertools import islice from typing import Any, Optional, Union import torch @@ -12,33 +13,42 @@ from transformers import PretrainedConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type -from .utils import (is_pp_missing_parameter, - 
make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class InternLM2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -63,8 +73,9 @@ class InternLM2MLP(nn.Module): prefix=f"{prefix}.w2", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -75,7 +86,6 @@ class InternLM2MLP(nn.Module): class InternLM2Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -156,16 +166,16 @@ class InternLM2Attention(nn.Module): qkv = qkv[::3] + qkv[1::3] + qkv[2::3] qkv = torch.cat(qkv, dim=-1) - qkv = qkv.view(seq_len, self.total_num_kv_heads, - self.key_value_groups + 2, self.head_dim) + qkv = qkv.view( + seq_len, self.total_num_kv_heads, self.key_value_groups + 2, self.head_dim + ) q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) q = q.reshape(seq_len, self.q_size * self.tp_size) k = k.reshape(seq_len, self.kv_size * self.tp_size) v = v.reshape(seq_len, self.kv_size * self.tp_size) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] @@ -185,7 +195,6 @@ class InternLM2Attention(nn.Module): class InternLMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -197,8 +206,7 @@ class InternLMDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.attention = InternLM2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -217,8 +225,7 @@ class InternLMDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -232,8 +239,7 @@ class InternLMDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.attention_norm(hidden_states) else: - hidden_states, residual = self.attention_norm( - hidden_states, residual) + hidden_states, residual = self.attention_norm(hidden_states, residual) hidden_states = self.attention( positions=positions, hidden_states=hidden_states, @@ -247,13 +253,13 @@ class InternLMDecoderLayer(nn.Module): @support_torch_compile class InternLM2Model(nn.Module): - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer): + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -269,12 +275,14 @@ class InternLM2Model(nn.Module): self.start_layer, self.end_layer, self.layers = 
make_layers( config.num_hidden_layers, lambda prefix: layer_type( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -296,13 +304,12 @@ class InternLM2Model(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -313,11 +320,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "gate_up_proj": ["w1", "w3"], } - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - model_type: type[InternLM2Model] = InternLM2Model): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: type[InternLM2Model] = InternLM2Model, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -327,17 +336,21 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): self.quant_config = quant_config self.lora_config = lora_config - self.model = model_type(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.output = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "output")) + self.model = model_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.output = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "output"), + ) if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -349,21 +362,19 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.output, hidden_states, - sampling_metadata) + logits = 
self.logits_processor(self.output, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), @@ -374,7 +385,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -394,8 +405,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -403,7 +413,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): @default_pooling_type("ALL") class InternLM2ForRewardModel(InternLM2ForCausalLM): - is_pooling_model = True def __init__( @@ -413,27 +422,30 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): prefix: str = "", model_type: type[InternLM2Model] = InternLM2Model, ): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - model_type=model_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type) for attr in ("output", "logits_processor"): delattr(self, attr) config = vllm_config.model_config.hf_config + self.head_dtype = vllm_config.model_config.head_dtype + self.v_head = RowParallelLinear( config.hidden_size, 1, bias=False, input_is_parallel=False, + params_dtype=self.head_dtype, prefix=maybe_prefix(prefix, "v_head"), + return_bias=False, ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) def forward( self, @@ -442,7 +454,9 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - logits, _ = self.v_head(hidden_states) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + hidden_states = hidden_states.to(self.head_dtype) + logits = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 4bbb49da0e96f..5344ded280b2a 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from itertools import islice from typing import Optional, Union import torch @@ -11,14 +12,16 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.internlm2 import (InternLM2Attention, - InternLM2ForCausalLM, - InternLM2MLP, 
InternLM2Model) +from vllm.model_executor.models.internlm2 import ( + InternLM2Attention, + InternLM2ForCausalLM, + InternLM2MLP, + InternLM2Model, +) from vllm.sequence import IntermediateTensors class InternLM2VEDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -30,8 +33,7 @@ class InternLM2VEDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.attention = InternLM2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -57,8 +59,7 @@ class InternLM2VEDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.feed_forward_ve", ) - self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -73,8 +74,7 @@ class InternLM2VEDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.attention_norm(hidden_states) else: - hidden_states, residual = self.attention_norm( - hidden_states, residual) + hidden_states, residual = self.attention_norm(hidden_states, residual) hidden_states = self.attention( positions=positions, hidden_states=hidden_states, @@ -83,27 +83,25 @@ class InternLM2VEDecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.ffn_norm(hidden_states, residual) if visual_token_mask is not None and visual_token_mask.any(): - visual_token_mask = visual_token_mask.repeat( - 1, self.hidden_size).bool() + visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() text_token_mask = ~visual_token_mask hidden_states[visual_token_mask] = self.feed_forward_ve( - hidden_states[visual_token_mask].reshape( - -1, self.hidden_size)).flatten() + hidden_states[visual_token_mask].reshape(-1, self.hidden_size) + ).flatten() if text_token_mask.any(): hidden_states[text_token_mask] = self.feed_forward( - hidden_states[text_token_mask].reshape( - -1, self.hidden_size)).flatten() + hidden_states[text_token_mask].reshape(-1, self.hidden_size) + ).flatten() else: hidden_states = self.feed_forward(hidden_states) return hidden_states, residual class InternLM2VEModel(InternLM2Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=InternLM2VEDecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=InternLM2VEDecoderLayer + ) def forward( self, @@ -123,7 +121,7 @@ class InternLM2VEModel(InternLM2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -131,17 +129,15 @@ class InternLM2VEModel(InternLM2Model): visual_token_mask=visual_token_mask, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class 
InternLM2VEForCausalLM(InternLM2ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - model_type=InternLM2VEModel) + super().__init__( + vllm_config=vllm_config, prefix=prefix, model_type=InternLM2VEModel + ) diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index c739e74b058fa..06c7c8ccd0b5e 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -7,7 +7,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import regex as re import torch @@ -15,44 +15,69 @@ import torch.nn as nn from transformers import BatchFeature, InternVLProcessor, PretrainedConfig from transformers.activations import ACT2FN from transformers.models.got_ocr2.image_processing_got_ocr2_fast import ( - GotOcr2ImageProcessorFast) + GotOcr2ImageProcessorFast, +) +from transformers.models.internvl.video_processing_internvl import ( + InternVLVideoProcessor, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import cached_video_processor_from_config +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class InternS1MultiModalProjector(nn.Module): - def __init__(self, config): super().__init__() - self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * - int(1 / config.downsample_ratio)**2) + self.layer_norm = nn.LayerNorm( + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2 + ) self.linear_1 = nn.Linear( - config.vision_config.hidden_size * - int(1 / 
config.downsample_ratio)**2, - config.text_config.hidden_size) + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, + config.text_config.hidden_size, + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, - config.text_config.hidden_size) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size + ) def forward(self, image_features): hidden_states = self.layer_norm(image_features) @@ -62,55 +87,68 @@ class InternS1MultiModalProjector(nn.Module): return hidden_states -class InternS1ImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class InternS1ImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class InternS1ImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] + +class InternS1ImageEmbeddingInputs(TensorSchema): """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. + Dimensions: + - ni: Number of images + - tifs: Total image feature size + - hs: Hidden size (must match language model backbone) """ - -InternS1ImageInputs = Union[InternS1ImagePixelInputs, - InternS1ImageEmbeddingInputs] + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("ni", "tifs", "hs") + ] -class InternS1VideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values: torch.Tensor +InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageEmbeddingInputs] + + +class InternS1VideoPixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_video * num_frames, num_channels, height, width)` + Dimensions: + - bnv: Batch size * number of videos * number of frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width """ - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class InternS1VideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] +class InternS1VideoEmbeddingInputs(TensorSchema): """ - A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` - or a list of tensors of shape `(total_video_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+ Dimensions: + - nv: Number of videos + - tvfs: Total video feature size + - hs: Hidden size (must match language model backbone) """ + type: Literal["video_embeds"] = "video_embeds" + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("nv", "tvfs", "hs") + ] -InternS1VideoInputs = Union[InternS1VideoPixelInputs, - InternS1VideoEmbeddingInputs] + +InternS1VideoInputs = Union[InternS1VideoPixelInputs, InternS1VideoEmbeddingInputs] def resolve_interns1_min_max_num( @@ -132,10 +170,13 @@ def get_interns1_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -143,7 +184,11 @@ class InternS1ProcessingInfo(BaseProcessingInfo): """ProcessingInfo for InternS1-style models.""" def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: - return self.ctx.get_hf_processor(InternVLProcessor, **kwargs) + hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs) + hf_processor.video_processor = cached_video_processor_from_config( + self.ctx.model_config, processor_cls=InternVLVideoProcessor, **kwargs + ) + return hf_processor def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} @@ -153,18 +198,19 @@ class InternS1ProcessingInfo(BaseProcessingInfo): *, image_width: int, image_height: int, - processor: Optional['GotOcr2ImageProcessorFast'] = None, + processor: Optional["GotOcr2ImageProcessorFast"] = None, ) -> int: if processor is None: processor = self.get_hf_processor().image_processor if not isinstance(processor, GotOcr2ImageProcessorFast): - raise ValueError(f'GotOcr2ImageProcessorFast is expected but got ' - f'{type(processor)}') + raise ValueError( + f"GotOcr2ImageProcessorFast is expected but got {type(processor)}" + ) num_image_patches = processor.get_number_of_image_patches( - image_height, image_width, images_kwargs=dict()) - num_image_tokens = self.get_hf_processor( - ).image_seq_length * num_image_patches + image_height, image_width, images_kwargs=dict() + ) + num_image_tokens = self.get_hf_processor().image_seq_length * num_image_patches return num_image_tokens def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None): @@ -179,7 +225,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo): min_dynamic_patch, max_dynamic_patch, dynamic_image_size, - use_thumbnail=use_thumbnail) + use_thumbnail=use_thumbnail, + ) return get_interns1_target_ratios(min_num, max_num) @@ -201,11 +248,11 @@ class InternS1ProcessingInfo(BaseProcessingInfo): ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) - assert not (largest_feature_size == 0 or largest_feature_pinpoint - is None), ("Cannot have a largest feature size of 0!") + assert not (largest_feature_size == 0 or largest_feature_pinpoint is None), ( + "Cannot have a largest feature size of 0!" 
+ ) return largest_feature_pinpoint @@ -230,15 +277,13 @@ class InternS1ProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - - max_image_tokens) // processor.image_seq_length + max_total_frames = (seq_len - max_image_tokens) // processor.image_seq_length max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) -class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo] - ): +class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]): """DummyInputsBuilder for InternS1-style models.""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -253,33 +298,40 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo] self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) config = self.info.get_hf_config() image_size_h, image_size_w = config.vision_config.image_size + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos(width=image_size_w, - height=image_size_h, - num_frames=target_num_frames, - num_videos=num_videos), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=image_size_w, + height=image_size_h, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), } -class InternS1MultiModalProcessor( - BaseMultiModalProcessor[InternS1ProcessingInfo]): - """ Basic image-only MultiModalProcessor for InternS1-style models.""" +class InternS1MultiModalProcessor(BaseMultiModalProcessor[InternS1ProcessingInfo]): + """Basic image-only MultiModalProcessor for InternS1-style models.""" def _call_hf_processor( self, @@ -287,7 +339,7 @@ class InternS1MultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: mm_data = dict(mm_data) videos = mm_data.pop("videos", []) images = mm_data.pop("images", []) @@ -296,15 +348,14 @@ class InternS1MultiModalProcessor( hf_processor = self.info.get_hf_processor(**mm_kwargs) tokenizer = hf_processor.tokenizer - video_token_id = tokenizer.encode(hf_processor.video_token, - add_special_tokens=False) + video_token_id = tokenizer.encode( + hf_processor.video_token, add_special_tokens=False + ) assert len(video_token_id) == 1 video_token_id = video_token_id[0] - prompt = re.sub(hf_processor.image_token, "<image_placeholder>", - prompt) - prompt = re.sub(hf_processor.video_token, "<video_placeholder>", - prompt) + prompt = re.sub(hf_processor.image_token, "<image_placeholder>", prompt) + prompt = re.sub(hf_processor.video_token, "<video_placeholder>", 
prompt) image_outputs = {} if images: @@ -316,16 +367,14 @@ class InternS1MultiModalProcessor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - image_pixel_values.append( - processed_outputs.pop("pixel_values")) + image_pixel_values.append(processed_outputs.pop("pixel_values")) input_ids = processed_outputs.pop("input_ids") image_placeholder = tokenizer.batch_decode(input_ids)[0] - prompt = prompt.replace("<image_placeholder>", - image_placeholder, 1) + prompt = prompt.replace("<image_placeholder>", image_placeholder, 1) num_patches = [len(item) for item in image_pixel_values] - image_outputs: dict[str, NestedTensors] = { + image_outputs = { "pixel_values": torch.concat(image_pixel_values), "image_num_patches": torch.tensor(num_patches), "image_token_id": torch.tensor(hf_processor.image_token_id), @@ -341,43 +390,32 @@ class InternS1MultiModalProcessor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - video_pixel_values.append( - processed_outputs.pop("pixel_values")) + video_pixel_values.append(processed_outputs.pop("pixel_values")) input_ids = processed_outputs.pop("input_ids") - input_ids[input_ids == - hf_processor.image_token_id] = video_token_id + input_ids[input_ids == hf_processor.image_token_id] = video_token_id video_placeholder = tokenizer.batch_decode(input_ids)[0] - prompt = prompt.replace("<video_placeholder>", - video_placeholder, 1) + prompt = prompt.replace("<video_placeholder>", video_placeholder, 1) num_frames = [len(item) for item in video_pixel_values] - video_outputs: dict[str, NestedTensors] = { + video_outputs = { "pixel_values_videos": torch.concat(video_pixel_values), "video_num_patches": torch.tensor(num_frames), "video_token_id": torch.tensor(video_token_id), } - prompt = re.sub("<image_placeholder>", hf_processor.image_token, - prompt) - prompt = re.sub("<video_placeholder>", hf_processor.video_token, - prompt) + prompt = re.sub("<image_placeholder>", hf_processor.image_token, prompt) + prompt = re.sub("<video_placeholder>", hf_processor.video_token, prompt) text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt") - combined_outputs = dict( - **text_outputs, - **image_outputs, - **video_outputs, - ) - return BatchFeature(combined_outputs) + return BatchFeature({**text_outputs, **image_outputs, **video_outputs}) def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_images = len(image_num_patches) @@ -385,12 +423,14 @@ class InternS1MultiModalProcessor( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), + "video", video_num_patches + ), video_num_patches=MultiModalFieldConfig.batched("video"), video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) @@ -424,7 +464,8 @@ class InternS1MultiModalProcessor( def get_replacement_interns1_image(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, 
ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -434,19 +475,16 @@ class InternS1MultiModalProcessor( repl_features = img_context_token * feature_size repl_full = start_image_token + repl_features + end_image_token - return PromptUpdateDetails.select_text(repl_full, - img_context_token) + return PromptUpdateDetails.select_text(repl_full, img_context_token) def get_replacement_interns1_video(item_idx: int): num_patches = video_num_patches[item_idx] repl_features = video_token * hf_processor.image_seq_length - repl_features_with_sep = (start_image_token + repl_features + - end_image_token) + repl_features_with_sep = start_image_token + repl_features + end_image_token # num_patches is equal to num_frames - repl_full = '\n'.join([ - f'Frame{i+1}: {repl_features_with_sep}' - for i in range(num_patches) - ]) + repl_full = "\n".join( + [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] + ) return PromptUpdateDetails.select_text(repl_full, video_token) @@ -467,9 +505,12 @@ class InternS1MultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( InternS1MultiModalProcessor, info=InternS1ProcessingInfo, - dummy_inputs=InternS1DummyInputsBuilder) -class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsLoRA): + dummy_inputs=InternS1DummyInputsBuilder, +) +class InternS1ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): + merge_by_field_config = True # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -478,14 +519,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, "model.language_model.": "language_model.model.", "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - # transformers InternVLProcessor uses <IMG_CONTEXT> as the seperator + # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): - return '<IMG_CONTEXT>' + return "<IMG_CONTEXT>" if modality.startswith("video"): return "<video>" @@ -504,7 +546,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, patch_size = config.vision_config.patch_size[0] self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.llm_arch_name = config.text_config.architectures[0] @@ -527,7 +570,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _init_vision_model( self, @@ -544,7 +588,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, prefix=prefix, ) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: return InternS1MultiModalProjector(config) def pixel_shuffle(self, x, scale_factor=0.5): @@ -553,8 +597,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, x = x.view(n, w, int(h * scale_factor), int(c / 
scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) x = x.permute(0, 2, 1, 3).contiguous() return x @@ -562,38 +610,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, vit_embeds = self.vision_tower(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.multi_modal_projector(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h, w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternS1ImageInputs]: + self, **kwargs: object + ) -> Optional[InternS1ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -602,13 +629,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternS1ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -616,27 +639,22 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. 
" - f"Got type: {type(image_num_patches)}") - - pixel_values = flatten_bn(pixel_values, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - + h, w = self.config.vision_config.image_size return InternS1ImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values(pixel_values), + pixel_values=pixel_values, num_patches=image_num_patches, + resolve_bindings={ + "h": h, + "w": w, + }, ) raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]: + self, **kwargs: object + ) -> Optional[InternS1VideoInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_videos", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -645,13 +663,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if video_embeds is not None: - if not isinstance(video_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") - - return InternS1ImageEmbeddingInputs( + return InternS1VideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -659,32 +673,27 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) - + h, w = self.config.vision_config.image_size return InternS1VideoPixelInputs( type="pixel_values_videos", - pixel_values=self._validate_pixel_values( - pixel_values_flat_video), num_patches=video_num_patches, + pixel_values=pixel_values_flat_video, + resolve_bindings={ + "h": h, + "w": w, + }, ) raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs], + image_input: Union[InternS1ImageInputs, InternS1VideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if ( + image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds" + ): return image_input["data"] assert self.vision_tower is not None @@ -695,14 +704,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. 
feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -714,14 +721,13 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ( - "pixel_values_videos", ) and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities @@ -731,15 +737,13 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary @@ -747,11 +751,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -760,24 +764,23 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -787,19 +790,10 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - forward_kwargs = { "input_ids": input_ids, "positions": positions, @@ -813,13 +807,10 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -830,4 +821,5 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index 300ed17ecaabc..f5965bdf7c9c7 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -12,54 +12,51 @@ from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F from transformers import PretrainedConfig from transformers.utils import torch_int +from vllm.attention.layer import MultiHeadAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader NORM2FN = { - 'rms_norm': RMSNorm, - 'layer_norm': nn.LayerNorm, + "rms_norm": RMSNorm, + "layer_norm": nn.LayerNorm, } class InternS1VisionPatchEmbeddings(nn.Module): - def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // - patch_size[0]) - patch_shape = (image_size[0] // patch_size[0], - image_size[1] // patch_size[1]) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.patch_shape = patch_shape - self.projection = nn.Conv2d(num_channels, - hidden_size, - kernel_size=patch_size, - stride=patch_size) + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values " - "match with the one set in the configuration.") + "match with the one set in the configuration." 
+ ) - embeddings = self.projection( - pixel_values.to(self.projection.weight.dtype)) + embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] embeddings = embeddings.flatten(2).transpose(1, 2) @@ -67,30 +64,32 @@ class InternS1VisionPatchEmbeddings(nn.Module): class InternS1VisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if config.use_mask_token: - self.mask_token = nn.Parameter( - torch.zeros(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) else: self.mask_token = None self.patch_embeddings = InternS1VisionPatchEmbeddings(config) self.patch_size = config.patch_size - self.image_size = (config.image_size if isinstance( - config.image_size, Iterable) else - (config.image_size, config.image_size)) + self.image_size = ( + config.image_size + if isinstance(config.image_size, Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter( - torch.zeros(1, num_patches + 1, config.hidden_size)) + torch.zeros(1, num_patches + 1, config.hidden_size) + ) else: self.position_embeddings = None - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, - width: int) -> torch.Tensor: + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing. 
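Aside (editor's sketch, not part of the patch): the docstring above describes resizing pre-trained patch position embeddings so the vision encoder can accept inputs at a resolution other than the pre-training one. A minimal standalone version of that idea is shown below; the function name, the square pre-training grid, and the bicubic/align_corners settings are assumptions for illustration, not the exact kwargs used in the patch.

import torch
import torch.nn as nn

def interpolate_patch_pos_embed(
    pos_embed: torch.Tensor,  # (1, 1 + num_positions, dim); index 0 is the class token
    new_height: int,
    new_width: int,
    patch_size: int,
) -> torch.Tensor:
    class_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = patch_pos.shape[-1]
    side = int(patch_pos.shape[1] ** 0.5)          # pre-training grid assumed square
    # (1, N, dim) -> (1, dim, side, side) so F.interpolate sees a 2D grid
    grid = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = nn.functional.interpolate(
        grid,
        size=(new_height // patch_size, new_width // patch_size),
        mode="bicubic",
        align_corners=False,
    )
    patch_pos = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat([class_pos, patch_pos], dim=1)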
@@ -105,8 +104,11 @@ class InternS1VisionEmbeddings(nn.Module): # always interpolate when tracing to ensure the exported model # works for dynamic input shapes - if not torch.jit.is_tracing( - ) and num_patches == num_positions and height == width: + if ( + not torch.jit.is_tracing() + and num_patches == num_positions + and height == width + ): return self.position_embeddings class_pos_embed = self.position_embeddings[:, :1] @@ -118,8 +120,9 @@ class InternS1VisionEmbeddings(nn.Module): new_width = width // self.patch_size[1] sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, - sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( @@ -139,8 +142,7 @@ class InternS1VisionEmbeddings(nn.Module): bool_masked_pos: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: _, _, height, width = pixel_values.shape - embeddings, (patch_height, - patch_width) = self.patch_embeddings(pixel_values) + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: @@ -154,7 +156,8 @@ class InternS1VisionEmbeddings(nn.Module): if self.position_embeddings is not None: embeddings = embeddings + self.interpolate_pos_encoding( - embeddings, height, width) + embeddings, height, width + ) return embeddings, (patch_height, patch_width) @@ -176,36 +179,44 @@ class InternSdpaAttention(nn.Module): self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) # Additional dummy heads are used to enable TP for common GPU counts. 
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim self.scale = self.head_dim**-0.5 - self.q_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) - self.k_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) - self.v_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) + self.q_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) self.qk_normalization = config.use_qk_norm if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) + self.q_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) + self.k_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape @@ -213,27 +224,19 @@ class InternSdpaAttention(nn.Module): k = self.k_proj(x) v = self.v_proj(x) - q = q.view(B, N, self.num_heads, self.head_dim) - k = k.view(B, N, self.num_heads, self.head_dim) - v = v.view(B, N, self.num_heads, self.head_dim) - if self.qk_normalization: B_, N_, H_, D_ = q.shape q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, -1) + # Use unified MultiHeadAttention with automatic backend selection + x = self.attn(q, k, v) x = self.projection_layer(x) return x class InternS1VisionMLP(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -244,16 +247,20 @@ class InternS1VisionMLP(nn.Module): self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -264,7 +271,6 @@ class InternS1VisionMLP(nn.Module): class InternS1VisionLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -275,26 +281,30 @@ class InternS1VisionLayer(nn.Module): ) -> None: super().__init__() - self.attention = self._init_attn(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attention") + self.attention = 
self._init_attn( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attention", + ) - self.mlp = InternS1VisionMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = InternS1VisionMLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) self.layernorm_before = NORM2FN[config.norm_type]( - config.hidden_size, eps=config.layer_norm_eps) + config.hidden_size, eps=config.layer_norm_eps + ) self.layernorm_after = NORM2FN[config.norm_type]( - config.hidden_size, eps=config.layer_norm_eps) + config.hidden_size, eps=config.layer_norm_eps + ) init_values = config.layer_scale_init_value - self.lambda_1 = nn.Parameter(init_values * - torch.ones(config.hidden_size), - requires_grad=True) - self.lambda_2 = nn.Parameter(init_values * - torch.ones(config.hidden_size), - requires_grad=True) + self.lambda_1 = nn.Parameter( + init_values * torch.ones(config.hidden_size), requires_grad=True + ) + self.lambda_2 = nn.Parameter( + init_values * torch.ones(config.hidden_size), requires_grad=True + ) def _init_attn( self, @@ -310,17 +320,20 @@ class InternS1VisionLayer(nn.Module): self, hidden_states: torch.Tensor, ): - hidden_states = hidden_states + self.attention( - self.layernorm_before(hidden_states)) * self.lambda_1 + hidden_states = ( + hidden_states + + self.attention(self.layernorm_before(hidden_states)) * self.lambda_1 + ) - hidden_states = hidden_states + self.mlp( - self.layernorm_after(hidden_states)) * self.lambda_2 + hidden_states = ( + hidden_states + + self.mlp(self.layernorm_after(hidden_states)) * self.lambda_2 + ) return hidden_states class InternS1VisionEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -339,16 +352,19 @@ class InternS1VisionEncoder(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layer = nn.ModuleList([ - InternS1VisionLayer(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + InternS1VisionLayer( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): - hidden_states = inputs_embeds for encoder_layer in self.layer: hidden_states = encoder_layer(hidden_states) @@ -357,7 +373,6 @@ class InternS1VisionEncoder(nn.Module): class InternS1VisionModel(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -378,9 +393,11 @@ class InternS1VisionModel(nn.Module): num_dummy_heads=num_dummy_heads, prefix=f"{prefix}.encoder", ) - self.layernorm = (nn.Identity() if config.use_mean_pooling else - nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps)) + self.layernorm = ( + nn.Identity() + if config.use_mean_pooling + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) def get_input_embeddings(self): return self.embeddings.patch_embeddings @@ -391,8 +408,7 @@ class InternS1VisionModel(nn.Module): pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -400,22 +416,19 @@ class InternS1VisionModel(nn.Module): if pixel_values.ndim == 4: hidden_states, _ = self.embeddings(pixel_values) else: - raise ValueError( - 
f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") encoder_outputs = self.encoder(inputs_embeds=hidden_states) encoder_outputs = self.layernorm(encoder_outputs) return encoder_outputs - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index da8ad8396725d..3cd3807dd8884 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,6 +7,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +import os from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, Optional, TypeVar, Union @@ -16,37 +17,54 @@ import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.models.intern_vit import (InternVisionModel, - InternVisionPatchModel) +from vllm.model_executor.models.intern_vit import ( + InternVisionModel, + InternVisionPatchModel, +) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .interfaces 
import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<IMG_CONTEXT>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<IMG_CONTEXT>" IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) @@ -61,6 +79,7 @@ class InternVLImagePixelInputs(TensorSchema): - h: Height of each image patch - w: Width of each image patch """ + type: Literal["pixel_values"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -73,13 +92,12 @@ class InternVLImageEmbeddingInputs(TensorSchema): - f: Total image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("n", "f", "h")] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] -InternVLImageInputs = Union[InternVLImagePixelInputs, - InternVLImageEmbeddingInputs] +InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs] class InternVLVideoPixelInputs(TensorSchema): @@ -91,6 +109,7 @@ class InternVLVideoPixelInputs(TensorSchema): - h: Height of each video frame - w: Width of each video frame """ + type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -103,25 +122,40 @@ class InternVLVideoEmbeddingInputs(TensorSchema): - f: Total video feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["video_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("n", "f", "h")] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] -InternVLVideoInputs = Union[InternVLVideoPixelInputs, - InternVLVideoEmbeddingInputs] +InternVLVideoInputs = Union[InternVLVideoPixelInputs, InternVLVideoEmbeddingInputs] # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - return T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=MEAN, std=STD) - ]) + transform = T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + # Image transformation operations (which include tensor computations + # on the CPU) can occupy a substantial number of CPU cores, introducing + # overhead due to CPU contention. This issue becomes particularly + # noticeable when deploying multiple vLLM instances on a single machine. + # Therefore, it is necessary to limit the number of threads allocated to + # image transformation tasks. 
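# --- Illustrative aside (editor's sketch, not part of the patch): the comment
# above motivates capping intra-op CPU threads while the torchvision transforms
# run, since ToTensor/Normalize do tensor math on the CPU and can contend across
# vLLM instances. A rough standalone equivalent of wrapping the transform with
# vLLM's set_default_torch_num_threads; limited_torch_threads is a hypothetical
# stand-in, not an API from the patch.
import contextlib
import os
import torch

@contextlib.contextmanager
def limited_torch_threads(num_threads: int):
    previous = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        torch.set_num_threads(previous)

capped = int(os.environ.get("OMP_NUM_THREADS", "1"))
with limited_torch_threads(capped):
    ...  # apply the image transform here without saturating every CPU core
# --- end aside ---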
+ num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) + + def apply(img): + with set_default_torch_num_threads(num_threads): + return transform(img) + + return apply # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B @@ -133,7 +167,7 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_ratio_diff = float('inf') + best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: @@ -168,10 +202,13 @@ def get_internvl_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -229,10 +266,12 @@ def dynamic_preprocess_internvl( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -335,7 +374,8 @@ class BaseInternVLProcessor(ABC): assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -363,14 +403,18 @@ class BaseInternVLProcessor(ABC): dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_internvl_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -437,7 +481,8 @@ class BaseInternVLProcessor(ABC): min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def _preprocess_image( @@ -457,11 +502,11 @@ class BaseInternVLProcessor(ABC): max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in 
pixel_values_lst]), + image_inputs = { + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -469,11 +514,10 @@ class BaseInternVLProcessor(ABC): feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace('<image>', image_repl.full, 1) for t in text] + text = [t.replace("<image>", image_repl.full, 1) for t in text] return text, image_inputs - def _make_batch_input(self, - input_item: Optional[Union[Any, list[Any]]] = None): + def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): if input_item is None: input_item = [] if not isinstance(input_item, list): @@ -488,7 +532,7 @@ class BaseInternVLProcessor(ABC): max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images = [self._make_batch_input(x) for x in (text, images)] text, image_inputs = self._preprocess_image( @@ -501,10 +545,9 @@ class BaseInternVLProcessor(ABC): text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class InternVLProcessor(BaseInternVLProcessor): @@ -568,7 +611,8 @@ class InternVLProcessor(BaseInternVLProcessor): min_num=min_num, max_num=max_num, use_thumbnail=False, - ) for video in videos + ) + for video in videos ] def _preprocess_video( @@ -584,19 +628,20 @@ class InternVLProcessor(BaseInternVLProcessor): videos, dynamic_image_size=dynamic_image_size, ) - video_inputs: dict[str, NestedTensors] = { - "pixel_values_flat_video": - torch.cat(pixel_values_lst_video), - "video_num_patches": - torch.tensor([len(item) for item in pixel_values_lst_video]), + video_inputs = { + "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "video_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst_video] + ), } for pixel_values in pixel_values_lst_video: num_patches = pixel_values.shape[0] - video_repl = self.get_video_repl(self.num_image_token, - num_patches, self.video_token) - text = [t.replace('<video>', video_repl.full, 1) for t in text] + video_repl = self.get_video_repl( + self.num_image_token, num_patches, self.video_token + ) + text = [t.replace("<video>", video_repl.full, 1) for t in text] return text, video_inputs def __call__( @@ -608,7 +653,7 @@ class InternVLProcessor(BaseInternVLProcessor): max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: text, images, videos = [ self._make_batch_input(x) for x in (text, images, videos) ] @@ -629,11 +674,9 @@ class InternVLProcessor(BaseInternVLProcessor): text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - **video_inputs, - } + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) def get_image_repl( self, @@ -654,9 +697,9 @@ class InternVLProcessor(BaseInternVLProcessor): repl_features = video_context_token * self.num_image_token repl_features_with_sep = IMG_START + repl_features + IMG_END # 
num_patches is equal to num_frames - repl_full = ''.join([ - f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) - ]) + repl_full = "".join( + [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] + ) return PromptUpdateDetails.select_text(repl_full, video_context_token) @@ -703,8 +746,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -737,21 +779,25 @@ class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): - """ Basic image-only MultiModalProcessor for InternVL-style models.""" + """Basic image-only MultiModalProcessor for InternVL-style models.""" def _call_hf_processor( self, @@ -759,7 +805,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -779,7 +825,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -787,7 +833,8 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), @@ -815,7 +862,8 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): def get_replacement_internvl(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -855,9 +903,13 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): def get_video_token(self) -> Optional[str]: text_model_type = self.get_hf_config().get_text_config().model_type - if text_model_type == "qwen2": - return "<|video_pad|>" - return None + video_token_map = { + "qwen2": "<|video_pad|>", + "qwen3": "<|video_pad|>", + "qwen3_moe": "<|video_pad|>", + "gpt_oss": 
"<|reserved_200000|>", + } + return video_token_map.get(text_model_type) def get_num_frames_with_most_features( self, @@ -870,8 +922,7 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): processor = self.get_hf_processor() max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - - max_image_tokens) // processor.num_image_token + max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -887,7 +938,8 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo): class InternVLDummyInputsBuilder( - BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]): + BaseInternVLDummyInputsBuilder[InternVLProcessingInfo] +): """InternVL DummyInputsBuilder extended for video support""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -899,21 +951,27 @@ class InternVLDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - dummy_image = super().get_dummy_mm_data(seq_len=seq_len, - mm_counts=mm_counts) + dummy_image = super().get_dummy_mm_data( + seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options + ) if self.info.supports_video: config = self.info.get_hf_config() image_size: int = config.vision_config.image_size - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) num_videos = mm_counts.get("video", 0) + video_overrides = mm_options.get("video") if mm_options else None dummy_video = { - "video": - self._get_dummy_videos(width=image_size, - height=image_size, - num_frames=target_num_frames, - num_videos=num_videos) + "video": self._get_dummy_videos( + width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ) } else: dummy_video = {} @@ -921,7 +979,8 @@ class InternVLDummyInputsBuilder( class InternVLMultiModalProcessor( - BaseInternVLMultiModalProcessor[InternVLProcessingInfo]): + BaseInternVLMultiModalProcessor[InternVLProcessingInfo] +): """InternVL MultiModalProcessor extended for video support""" def _call_hf_processor( @@ -930,33 +989,34 @@ class InternVLMultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) hf_processor = self.info.get_hf_processor(**mm_kwargs) - if self.info.supports_video and ( - video_token_id := hf_processor.video_token_id) is not None: + if ( + self.info.supports_video + and (video_token_id := hf_processor.video_token_id) is not None + ): processed_outputs["video_token_id"] = torch.tensor(video_token_id) return processed_outputs def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_fields = super()._get_mm_fields_config(hf_inputs, - hf_processor_mm_kwargs) + image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) if self.info.supports_video: - video_num_patches = hf_inputs.get("video_num_patches", - torch.empty(0)) + 
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_videos = len(video_num_patches) video_fields = dict( pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), + "video", video_num_patches + ), video_num_patches=MultiModalFieldConfig.batched("video"), - video_token_id=MultiModalFieldConfig.shared( - "video", num_videos), + video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) else: video_fields = {} @@ -992,9 +1052,8 @@ class InternVLMultiModalProcessor( assert isinstance(num_patches, int) return hf_processor.get_video_repl( - feature_size, - num_patches, - video_context_token=hf_processor.video_token) + feature_size, num_patches, video_context_token=hf_processor.video_token + ) if self.info.supports_video: prompt_repl = [ @@ -1003,7 +1062,7 @@ class InternVLMultiModalProcessor( modality="video", target="<video>", replacement=get_video_replacement_internvl, - ) + ), ] return prompt_repl @@ -1012,9 +1071,12 @@ class InternVLMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( InternVLMultiModalProcessor, info=InternVLProcessingInfo, - dummy_inputs=InternVLDummyInputsBuilder) -class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): + dummy_inputs=InternVLDummyInputsBuilder, +) +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True + + supports_encoder_tp_data = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1034,18 +1096,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.llm_arch_name = config.text_config.architectures[0] - self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + self.is_mono = self.llm_arch_name == "InternLM2VEForCausalLM" self.vision_model = self._init_vision_model( config, quant_config=quant_config, @@ -1066,18 +1130,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -1091,8 
+1157,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 @@ -1101,18 +1168,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, + use_data_parallel=self.use_data_parallel, ) else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - llm_hidden_size), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size + ), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size), ) @@ -1123,9 +1192,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -1135,17 +1208,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternVLImageInputs]: + self, **kwargs: object + ) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -1154,13 +1226,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -1168,16 +1236,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1191,7 +1249,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[InternVLVideoPixelInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("image_embeds", None) @@ -1202,7 +1261,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, if video_embeds is not None: return InternVLVideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) video_token_id = kwargs["video_token_id"] @@ -1210,17 +1269,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.video_context_token_id = video_token_id.flatten().unique().item() if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") - - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) expected_h = expected_w = self.config.vision_config.image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1233,11 +1281,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, raise AssertionError("This line should be unreachable.") - def _process_image_input( + def _process_vision_input( self, - image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs], + image_input: Union[InternVLImageInputs, InternVLVideoInputs], ) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": + if ( + image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds" + ): return image_input["data"] assert self.vision_model is not None @@ -1248,14 +1299,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. 
feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -1267,31 +1316,29 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values_flat", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_flat_video", - ) and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_flat_video",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: assert self.img_context_token_id is not None - self.visual_token_mask = ( - input_ids == self.img_context_token_id).reshape(-1, 1) + self.visual_token_mask = (input_ids == self.img_context_token_id).reshape( + -1, 1 + ) else: self.visual_token_mask = None def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1305,11 +1352,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + vision_embeddings = self._process_vision_input(image_input) multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] - video_embeddings = self._process_image_input(video_input) + video_embeddings = self._process_vision_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -1318,24 +1365,23 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [ - token_id for token_id in (self.img_context_token_id, - self.video_context_token_id) - if token_id is not None - ] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + 
multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1345,19 +1391,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - forward_kwargs = { "input_ids": input_ids, "positions": positions, @@ -1367,8 +1404,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -1377,19 +1413,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B skip_prefixes = [ - "action_embed", "temporal_embed", "track_embed", - "track_embed_decoder", "box_token", "cg_criterion", "cg_model", - "loc_encoder", "loc_decoder", "sam", "temporal_token", - "track_token" + "action_embed", + "temporal_embed", + "track_embed", + "track_embed_decoder", + "box_token", + "cg_criterion", + "cg_model", + "loc_encoder", + "loc_decoder", + "sam", + "temporal_token", + "track_token", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) @@ -1401,4 +1442,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return MultiModelKeys.from_string_field( language_model="language_model", connector="mlp1", - tower_model="vision_model") + tower_model="vision_model", + ) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bed4a5dff2efa..d788ed7ec2af7 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -23,6 +23,7 @@ import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -31,49 +32,57 @@ from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from 
vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class SwiGLUActivation(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) class JAISAttention(nn.Module): - def __init__( self, config: JAISConfig, @@ -84,8 +93,7 @@ class JAISAttention(nn.Module): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -113,13 +121,15 @@ class JAISAttention(nn.Module): head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -133,7 +143,6 @@ class JAISAttention(nn.Module): class JAISMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -149,12 +158,16 @@ class JAISMLP(nn.Module): bias=True, quant_config=quant_config, ) - self.c_fc2 = (ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config, - ) if self.swiglu else None) + self.c_fc2 = ( + ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + if self.swiglu + else None + ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -168,14 +181,16 @@ class JAISMLP(nn.Module): if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = (self.act(hidden_states, 
hidden_states2) - if self.swiglu else self.act(hidden_states)) + hidden_states = ( + self.act(hidden_states, hidden_states2) + if self.swiglu + else self.act(hidden_states) + ) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): - def __init__( self, config: JAISConfig, @@ -185,14 +200,12 @@ class JAISBlock(nn.Module): ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = JAISAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -202,7 +215,9 @@ class JAISBlock(nn.Module): ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn(hidden_states=hidden_states, ) + attn_output = self.attn( + hidden_states=hidden_states, + ) # residual connection hidden_states = attn_output + residual @@ -216,7 +231,6 @@ class JAISBlock(nn.Module): @support_torch_compile class JAISModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -230,9 +244,11 @@ class JAISModel(nn.Module): assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = (nn.Embedding(config.max_position_embeddings, - self.embed_dim) - if config.position_embedding_type != "alibi" else None) + self.wpe = ( + nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.position_embedding_type != "alibi" + else None + ) if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: @@ -240,17 +256,19 @@ class JAISModel(nn.Module): self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: JAISBlock(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: JAISBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -270,13 +288,14 @@ class JAISModel(nn.Module): hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor(float(self.embeddings_scale), - dtype=hidden_states.dtype) + hidden_states *= torch.tensor( + float(self.embeddings_scale), dtype=hidden_states.dtype + ) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in islice(self.h, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: @@ -287,30 +306,33 @@ class JAISModel(nn.Module): class JAISLMHeadModel(nn.Module, SupportsPP): - 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = JAISModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = JAISModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = (config.mup_output_alpha * - config.mup_width_scale) - self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, - scale=self.output_logits_scale) + self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale + self.logits_processor = LogitsProcessor( + vocab_size=config.vocab_size, scale=self.output_logits_scale + ) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -322,21 +344,19 @@ class JAISLMHeadModel(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -366,8 +386,7 @@ class JAISLMHeadModel(nn.Module, SupportsPP): if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 3c1a0b68df56e..0371458f55784 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jamba model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional import torch from torch import nn from transformers import JambaConfig -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import 
CacheConfig, ModelConfig, VllmConfig @@ -16,41 +17,50 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class JambaMoE(nn.Module): - - def __init__(self, - config: JambaConfig, - num_experts: Optional[int] = None, - top_k: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.num_total_experts = num_experts or config.num_experts self.top_k = top_k or config.num_experts_per_tok @@ -58,23 +68,27 @@ class JambaMoE(nn.Module): self.intermediate_size = config.intermediate_size if self.num_total_experts > 1: - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None, - params_dtype=params_dtype) + self.router = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype, + ) - self.experts = FusedMoE(self.num_total_experts, - self.top_k, - self.hidden_size, - self.intermediate_size, - tp_size=tp_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=False, - use_grouped_topk=False, - quant_config=quant_config, - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + params_dtype=params_dtype, + reduce_results=True, + 
renormalize=False, + use_grouped_topk=False, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -83,43 +97,46 @@ class JambaMoE(nn.Module): if self.num_total_experts > 1: router_logits, _ = self.router(hidden_states) else: - router_logits = torch.ones((hidden_states.shape[0], 1), - device=hidden_states.device, - dtype=hidden_states.dtype) + router_logits = torch.ones( + (hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) hidden_states = self.experts(hidden_states, router_logits) return hidden_states.view(orig_shape) class JambaMambaDecoderLayer(nn.Module): - - def __init__(self, - config: JambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, + config: JambaConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.config = config self.is_lora_enabled = is_lora_enabled - self.mamba = MambaMixer(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - time_step_rank = config.mamba_dt_rank, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - use_rms_norm=True, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - is_lora_enabled = self.is_lora_enabled, - model_config=model_config, - cache_config=cache_config, - prefix=f"{prefix}.mixer", - ) + self.mamba = MambaMixer( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + time_step_rank=config.mamba_dt_rank, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + use_rms_norm=True, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer", + ) num_experts = config.layers_num_experts[layer_idx] if num_experts > 1: @@ -136,27 +153,23 @@ class JambaMambaDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mamba(hidden_states, output, mamba_cache_params) + self.mamba(hidden_states, output) # Fully Connected hidden_states, residual = 
self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -164,15 +177,16 @@ class JambaMambaDecoderLayer(nn.Module): class JambaAttentionDecoderLayer(nn.Module): - - def __init__(self, - config: JambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, + config: JambaConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -202,10 +216,12 @@ class JambaAttentionDecoderLayer(nn.Module): bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) self.attn = Attention( self.num_heads, @@ -231,10 +247,8 @@ class JambaAttentionDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def self_attention( self, @@ -259,29 +273,26 @@ class JambaAttentionDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual ALL_DECODER_LAYER_TYPES = { "attention": JambaAttentionDecoderLayer, - "mamba": JambaMambaDecoderLayer + "mamba": JambaMambaDecoderLayer, } @support_torch_compile class JambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -292,8 +303,11 @@ class JambaModel(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -307,24 +321,25 @@ class JambaModel(nn.Module): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layers_block_type[layer_idx]] - return layer_class(config, - layer_idx, - model_config, - cache_config, - quant_config=quant_config, - prefix=prefix, - **extra_kwargs) + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + model_config, + cache_config, + 
quant_config=quant_config, + prefix=prefix, + **extra_kwargs, + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -333,7 +348,6 @@ class JambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -348,29 +362,15 @@ class JambaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - kv_cache_index = 0 - mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: - layer_mamba_cache_params = None - if isinstance(layer, JambaAttentionDecoderLayer): - kv_cache_index += 1 - if isinstance(layer, - JambaMambaDecoderLayer) and mamba_cache_params: - current_state_layer = mamba_cache_index - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - current_state_layer) - mamba_cache_index += 1 - + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_mamba_cache_params) + positions=positions, hidden_states=hidden_states, residual=residual + ) + if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -381,10 +381,10 @@ class JambaModel(nn.Module): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -403,7 +403,7 @@ class JambaModel(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if 'experts' in name: + if "experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
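The JAISModel and JambaModel forward loops above now walk the pipeline-parallel layer range with itertools.islice(layers, self.start_layer, self.end_layer) instead of slicing the module list. A minimal, self-contained sketch of that pattern (toy layer count and sizes, not the real JAIS/Jamba configuration):

from itertools import islice

import torch
from torch import nn

# Six toy layers; this pipeline-parallel rank owns layers [2, 5).
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
start_layer, end_layer = 2, 5

hidden_states = torch.zeros(1, 8)
# islice iterates the assigned range lazily, without materializing the
# sub-list that layers[start_layer:end_layer] would create.
for layer in islice(layers, start_layer, end_layer):
    hidden_states = layer(hidden_states)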
@@ -418,10 +418,10 @@ class JambaModel(nn.Module): break else: for ( - param_name, - weight_name, - expert_id, - shard_id, + param_name, + weight_name, + expert_id, + shard_id, ) in expert_params_mapping: if weight_name not in name: continue @@ -431,11 +431,13 @@ class JambaModel(nn.Module): name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -445,19 +447,18 @@ class JambaModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - ".self_attn.": ".", - ".A_log": ".A" - }, ) +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"}, + ) packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -480,16 +481,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Jamba currently does not support prefix caching" + ) super().__init__() self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - self.model = JambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = JambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -500,49 +503,37 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. 
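The hunk that follows drops the v0 MambaCacheManager plumbing from JambaForCausalLM: forward() no longer threads per-layer mamba_cache_params, matching the decoder-layer change above where the mixer is called as self.mamba(hidden_states, output) on a preallocated buffer. A rough sketch of that call shape with a stand-in mixer (not the real vLLM MambaMixer API; state handling is omitted):

import torch
from torch import nn


class ToyMixer(nn.Module):
    # Stand-in for the mixer call shape only.
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Write the result into the caller-provided buffer, mirroring
        # `self.mamba(hidden_states, output)` in the layer above.
        output.copy_(self.proj(hidden_states))


mixer = ToyMixer(16)
hidden_states = torch.randn(4, 16)
output = torch.empty_like(hidden_states)
with torch.no_grad():
    mixer(hidden_states, output)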
- self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) @@ -552,7 +543,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba1_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -573,20 +563,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, intermediate_size=hf_config.mamba_expand * hidden_size, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=envs.VLLM_USE_V1, ) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -595,7 +581,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, class JambaForSequenceClassification(JambaForCausalLM): - is_pooling_model = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -603,7 +588,7 @@ class JambaForSequenceClassification(JambaForCausalLM): config = vllm_config.model_config.hf_config num_labels: int = 
config.num_labels - score_bias: bool = getattr(config, 'score_bias', False) + score_bias: bool = getattr(config, "score_bias", False) # TODO: The original reward weights have float32 accuracy data, we # would like to load them in fp32 to get that extra precision. @@ -612,18 +597,18 @@ class JambaForSequenceClassification(JambaForCausalLM): config.hidden_size, num_labels, bias=score_bias, - dtype=torch.float32, + dtype=vllm_config.model_config.head_dtype, ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify( - pooler_config, - classifier=self.score, - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify( + pooler_config, + classifier=self.score, + ), + } + ) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 8c64f636c6a0f..9711eeeeec33e 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -5,37 +5,39 @@ from typing import Optional import torch import torch.nn as nn -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.inputs import TokensPrompt from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors -from .interfaces import (SupportsCrossEncoding, SupportsMultiModal, - SupportsScoreTemplate) -from .qwen2_vl import (Qwen2VLDummyInputsBuilder, - Qwen2VLForConditionalGeneration, - Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) +from .interfaces import SupportsCrossEncoding, SupportsMultiModal, SupportsScoreTemplate +from .qwen2_vl import ( + Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, +) from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = init_logger(__name__) class JinaVLScorer(nn.Module): - - def __init__(self, config: PretrainedConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() - self.dense = ColumnParallelLinear(config.hidden_size, - config.hidden_size, - bias=True) - self.out_proj = RowParallelLinear(config.hidden_size, - config.num_labels, - bias=True) + config = model_config.hf_config + head_dtype = model_config.head_dtype + self.dense = ColumnParallelLinear( + config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True + ) + self.out_proj = RowParallelLinear( + config.hidden_size, config.num_labels, params_dtype=head_dtype, bias=True + ) def forward(self, x, **kwargs): x, _ = self.dense(x) @@ -45,7 +47,6 @@ class JinaVLScorer(nn.Module): class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor): - def _call_hf_processor( self, prompt: str, @@ -53,25 +54,26 @@ class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor): mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - # NOTE: We should reverse the order of the mm_data because the # query prompt is placed after the document prompt in the score # template for JinaVLForRanking model, but in mm_data 
they are # stored in the opposite order (query first, then document). for _, value in mm_data.items(): value.reverse() - return super()._call_hf_processor(prompt, mm_data, mm_kwargs, - tok_kwargs) + return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) -@MULTIMODAL_REGISTRY.register_processor(JinaVLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, - SupportsCrossEncoding, - SupportsMultiModal, - SupportsScoreTemplate): - +@MULTIMODAL_REGISTRY.register_processor( + JinaVLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class JinaVLForSequenceClassification( + Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, + SupportsMultiModal, + SupportsScoreTemplate, +): is_pooling_model = True weight_mapper = WeightsMapper( orig_to_new_prefix={ @@ -83,27 +85,24 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "qwen2_vl")) - config = vllm_config.model_config.hf_config + super().__init__( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl") + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - # logit bias for sigmoid normalization - self.LOGIT_BIAS = 2.65 - - self.score = JinaVLScorer(config) - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify(pooler_config, classifier=None), - "score": - Pooler.for_classify(pooler_config, classifier=None), - }) + self.score = JinaVLScorer(vllm_config.model_config) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify(pooler_config, classifier=self.score), + "score": Pooler.for_classify(pooler_config, classifier=self.score), + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -118,9 +117,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, @classmethod def post_process_tokens(cls, prompt: TokensPrompt) -> None: - # add score target token at the end of prompt tokens - prompt['prompt_token_ids'].append(100) + prompt["prompt_token_ids"].append(100) def forward( self, @@ -137,9 +135,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, inputs_embeds=inputs_embeds, **kwargs, ) - - logits = self.score(hidden_states) - self.LOGIT_BIAS - return logits + return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index c6dbd62b905e1..7ccbc81431f62 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, TypeVar, Union import numpy as np import torch @@ -12,61 +13,77 @@ from einops import rearrange from 
transformers import PretrainedConfig from transformers.activations import GELUActivation from transformers.feature_extraction_utils import BatchFeature -from transformers.modeling_outputs import (BaseModelOutput, - BaseModelOutputWithPooling) +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .siglip import SiglipMLP -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, is_pp_missing_parameter, - maybe_prefix, merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + is_pp_missing_parameter, + maybe_prefix, +) from .vision import get_vit_attn_backend logger = init_logger(__name__) -_MAX_FRAMES_PER_VIDEO = 16 -_MAX_IMAGE_SIZE = 9999999 - def smart_resize( height: int, width: int, - factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + factor: int, + min_pixels: int, + max_pixels: int, ): if height < factor: logger.warning( @@ -87,8 
+104,10 @@ def smart_resize( width = factor if max(height, width) / min(height, width) > 200: - raise ValueError("absolute aspect ratio must be smaller than 200, got " - "{max(height, width) / min(height, width)}") + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + "{max(height, width) / min(height, width)}" + ) h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: @@ -105,17 +124,17 @@ def smart_resize( class KeyeImagePixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values"] pixel_values: Annotated[ - torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -123,11 +142,12 @@ class KeyeImageEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of image features - - hs: Hidden size (must match the hidden size of language model + - hs: Hidden size (must match the hidden size of language model backbone) - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["image_embeds"] image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -139,17 +159,17 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs] class KeyeVideoPixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ - torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -157,11 +177,12 @@ class KeyeVideoEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of video features - - hs: Hidden size (must match the hidden size of language model + - hs: Hidden size (must match the hidden size of language model backbone) - nv: Number of videos - g: Grid dimensions (3 for t, h, w) """ + type: Literal["video_embeds"] video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -171,7 +192,6 @@ KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs] class KeyeVisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -187,12 +207,11 @@ class KeyeVisionEmbeddings(nn.Module): padding="valid", ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.cache_position_embedding = dict() self.cache_position_count = dict() - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) self.register_buffer( @@ -208,7 +227,6 @@ class KeyeVisionEmbeddings(nn.Module): width: int, is_after_patchify: bool = False, ) -> torch.Tensor: - num_positions = 
self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) @@ -223,8 +241,9 @@ class KeyeVisionEmbeddings(nn.Module): new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, - sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( @@ -237,11 +256,7 @@ class KeyeVisionEmbeddings(nn.Module): patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed - def fetch_position_embedding_lfu_cache(self, - embeddings, - h, - w, - max_cache: int = 20): + def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20): grid = (h, w) if grid in self.cache_position_embedding: self.cache_position_count[grid] += 1 @@ -255,8 +270,7 @@ class KeyeVisionEmbeddings(nn.Module): self.cache_position_count.pop(min_hit_grid) self.cache_position_embedding.pop(min_hit_grid) - position_embedding = self.interpolate_pos_encoding( - embeddings, h, w, True) + position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True) self.cache_position_count[grid] = 1 self.cache_position_embedding[grid] = position_embedding return position_embedding @@ -265,10 +279,14 @@ class KeyeVisionEmbeddings(nn.Module): self, pixel_values: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 4: @@ -287,8 +305,7 @@ class KeyeVisionEmbeddings(nn.Module): ) = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) embeddings = patch_embeds.flatten(-2).squeeze(-1) if interpolate_pos_encoding and image_grid_thw is not None: @@ -298,19 +315,23 @@ class KeyeVisionEmbeddings(nn.Module): t, h, w = image_grid end = start + t * h * w image_embeddings = embeddings[start:end, :] - position_embedding = (self.interpolate_pos_encoding( - image_embeddings, h, w, True).squeeze(0).repeat(t, 1)) + position_embedding = ( + self.interpolate_pos_encoding(image_embeddings, h, w, True) + .squeeze(0) + .repeat(t, 1) + ) image_embeddings = image_embeddings + position_embedding tmp_embeddings.append(image_embeddings) start = end embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) else: - embeddings = embeddings + self.packing_position_embedding( - position_ids) + embeddings = embeddings + self.packing_position_embedding(position_ids) return embeddings else: - raise ValueError("Unsupported pixel_values dimension:" - f" {pixel_values.dim()}. Expected 4 or 5.") + raise ValueError( + "Unsupported pixel_values dimension:" + f" {pixel_values.dim()}. Expected 4 or 5." + ) def apply_rotary_pos_emb_flashatt( @@ -376,10 +397,21 @@ class KeyeSiglipAttention(nn.Module): ) # Detect attention implementation. 
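interpolate_pos_encoding above resizes the learned vision position-embedding table to whatever patch grid the input produces: the flat table is viewed as a square grid, resampled with nn.functional.interpolate, and flattened back. A condensed, self-contained sketch with toy sizes (the interpolation mode is an assumption, since the real call's keyword arguments fall outside the hunk shown):

import torch
from torch import nn

dim, grid = 32, 16  # toy: a 16x16 grid of learned positions
position_embedding = nn.Embedding(grid * grid, dim)

new_h, new_w = 12, 20  # target patch grid for a non-square image
patch_pos = position_embedding.weight.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
# Mode chosen for illustration only; the actual kwargs are not visible here.
patch_pos = nn.functional.interpolate(
    patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False
)
patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, new_h * new_w, dim)
assert patch_pos.shape == (1, new_h * new_w, dim)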
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.head_dim, dtype=torch.get_default_dtype() + ) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( - f"Keye-VL does not support {self.attn_backend} backend now.") + f"Keye-VL does not support {self.attn_backend} backend now." + ) def forward( self, @@ -413,8 +445,7 @@ class KeyeSiglipAttention(nn.Module): ) else: if cu_seqlens is None: - raise ValueError( - "cu_seqlens cannot be None when rope_emb is not None.") + raise ValueError("cu_seqlens cannot be None when rope_emb is not None.") cos, sin = rope_emb q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( @@ -430,7 +461,10 @@ class KeyeSiglipAttention(nn.Module): ) if self.attn_backend == _Backend.FLASH_ATTN: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -445,29 +479,26 @@ class KeyeSiglipAttention(nn.Module): causal=False, softmax_scale=self.scale, ) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) - context_layer = rearrange(context_layer, - "b s h d -> b s (h d)").contiguous() + context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() output, _ = self.out_proj(context_layer) return output class SigLIPRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim @@ -475,8 +506,9 @@ class SigLIPRotaryEmbedding(nn.Module): self.rope_init() def rope_init(self): - inv_freq = 1.0 / (self.theta**( - torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)) + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: @@ -490,7 +522,6 @@ class SigLIPRotaryEmbedding(nn.Module): class KeyeSiglipEncoderLayer(nn.Module): - def __init__( self, config: Union[PretrainedConfig], @@ -499,15 +530,13 @@ class KeyeSiglipEncoderLayer(nn.Module): ): super().__init__() self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = 
SiglipMLP( config, quant_config=quant_config, @@ -522,7 +551,6 @@ class KeyeSiglipEncoderLayer(nn.Module): cu_seqlens: Optional[list[torch.Tensor]] = None, rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.FloatTensor]: - residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -546,7 +574,6 @@ class KeyeSiglipEncoderLayer(nn.Module): class KeyeSiglipEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -558,13 +585,16 @@ class KeyeSiglipEncoder(nn.Module): embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads - self.layers = nn.ModuleList([ - KeyeSiglipEncoderLayer( - config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - ) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + KeyeSiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) @staticmethod @@ -584,10 +614,14 @@ class KeyeSiglipEncoder(nn.Module): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, use_rope: Optional[bool] = False, @@ -603,8 +637,7 @@ class KeyeSiglipEncoder(nn.Module): split_hids = list() split_wids = list() for t, h, w in flatten_image_grid_thw: - image_pids = torch.arange(t * h * w, - device=device) % (h * w) + image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w split_hids.append(sample_hids) @@ -640,7 +673,6 @@ class KeyeSiglipEncoder(nn.Module): class KeyeSiglipVisionTransformer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -657,8 +689,7 @@ class KeyeSiglipVisionTransformer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.encoder", ) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -675,15 +706,18 @@ class KeyeSiglipVisionTransformer(nn.Module): cu_seqlens: Optional[list[torch.Tensor]] = None, padding_mask: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: - hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, @@ -709,8 +743,10 @@ class KeyeSiglipVisionTransformer(nn.Module): sample_hidden_state = list() if cu_seqlens is None: - raise ValueError("cu_seqlens cannot be None for " - "SiglipVisionTransformer output processing.") + raise ValueError( + "cu_seqlens cannot be None for " + "SiglipVisionTransformer output processing." 
+ ) for i in range(cu_seqlens.shape[0] - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] @@ -759,16 +795,19 @@ class KeyeSiglipVisionModel(nn.Module): interpolate_pos_encoding: bool = False, position_ids: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: - return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, @@ -784,8 +823,7 @@ class KeyeSiglipVisionModel(nn.Module): window_size=window_size, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -801,22 +839,24 @@ class KeyeSiglipVisionModel(nn.Module): if "head.mlp" in name or "head.probe" in name: continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] weight_loader = getattr( param, "weight_loader", default_weight_loader, ) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue for ( - param_name, - weight_name, - shard_id, + param_name, + weight_name, + shard_id, ) in stacked_params_mapping: if weight_name not in name: continue @@ -849,7 +889,6 @@ class KeyeSiglipVisionModel(nn.Module): class Projector(nn.Module): - def __init__( self, text_config: PretrainedConfig, @@ -862,12 +901,13 @@ class Projector(nn.Module): self.vision_config = vision_config self.merge_kernel_size = (2, 2) - self.hidden_size = (self.vision_config.hidden_size * - self.merge_kernel_size[0] * - self.merge_kernel_size[1]) + self.hidden_size = ( + self.vision_config.hidden_size + * self.merge_kernel_size[0] + * self.merge_kernel_size[1] + ) - self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, - eps=1e-05) + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) self.act = GELUActivation() self.linear_1 = ColumnParallelLinear( @@ -887,14 +927,13 @@ class Projector(nn.Module): def forward( self, - image_features: torch.Tensor, + image_features: Union[torch.Tensor, list[torch.Tensor]], image_grid_thw: list[tuple[int, int, int]], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, list[torch.Tensor]]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() - for image_feature, image_grid in zip(image_features, - image_grid_thw): + for image_feature, image_grid in zip(image_features, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid @@ -917,8 +956,7 @@ class Projector(nn.Module): dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) - hidden_states = self.pre_norm(image_features).view( - -1, self.hidden_size) + hidden_states = self.pre_norm(image_features).view(-1, 
self.hidden_size) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -926,7 +964,9 @@ class Projector(nn.Module): return hidden_states.view(*dims, -1) -def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): +def _keye_field_config( + hf_inputs: Mapping[str, torch.Tensor], +): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) @@ -934,21 +974,18 @@ def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): video_grid_sizes = video_grid_thw.prod(-1) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes), video_grid_thw=MultiModalFieldConfig.batched("video"), ) class KeyeMultiModalDataParser(MultiModalDataParser): - def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -985,11 +1022,18 @@ class KeyeMultiModalDataParser(MultiModalDataParser): class KeyeProcessingInfo(BaseProcessingInfo): + def get_max_image_size(self) -> int: + return 9999999 # _MAX_IMAGE_SIZE + + def get_max_frame_per_video(self) -> int: + return 16 # _MAX_FRAMES_PER_VIDEO def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits( + self, + ) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( @@ -1028,11 +1072,9 @@ class KeyeProcessingInfo(BaseProcessingInfo): min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size @@ -1075,10 +1117,12 @@ class KeyeProcessingInfo(BaseProcessingInfo): ) return num_video_tokens - def get_image_size_with_most_features(self, ) -> ImageSize: + def get_image_size_with_most_features( + self, + ) -> ImageSize: max_image_size, _ = self._get_vision_info( - image_width=_MAX_IMAGE_SIZE, - image_height=_MAX_IMAGE_SIZE, + image_width=self.get_max_image_size(), + image_height=self.get_max_image_size(), image_processor=None, ) return max_image_size @@ -1119,11 +1163,10 @@ class KeyeProcessingInfo(BaseProcessingInfo): max_videos = mm_config.get_limit_per_prompt("video") max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO, + self.get_max_frame_per_video(), ) return 
max(max_frames_per_video, 1) @@ -1139,8 +1182,10 @@ class KeyeProcessingInfo(BaseProcessingInfo): ) -class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): +_I = TypeVar("_I", bound=KeyeProcessingInfo) + +class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1155,36 +1200,40 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) - target_num_frames = self.info.get_num_frames_with_most_features( - seq_len) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features(seq_len) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None mm_data = { - "image": - self._get_dummy_images( + "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, + overrides=image_overrides, ), - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, + overrides=video_overrides, ), } return mm_data -class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): +class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): ... + +class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return KeyeMultiModalDataParser() @@ -1195,8 +1244,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -1220,7 +1268,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): modality=modality, target=[placeholder[modality]], replacement=partial(get_replacement_keye, modality=modality), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1231,13 +1280,9 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): return _keye_field_config(hf_inputs) -@MULTIMODAL_REGISTRY.register_processor( - KeyeMultiModalProcessor, - info=KeyeProcessingInfo, - dummy_inputs=KeyeDummyInputsBuilder, -) -class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, - SupportsPP): +class BaseKeyeModule(nn.Module): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1250,10 +1295,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ], } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "lm_head.": "language_model.lm_head.", - "model.": "language_model.model.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": 
"language_model.model.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1275,13 +1322,14 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, self.visual = KeyeSiglipVisionModel( config.vision_config, - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), ) - self.mlp_AR = Projector( + + self.mlp_AR = self._build_projector( config, config.vision_config, - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=quant_config, prefix=maybe_prefix(prefix, "mlp_AR"), ) @@ -1292,104 +1340,20 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config + @abstractmethod + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + raise ValueError("Need projector") - def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim == 5: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) - elif is_list_of(mm_input, torch.Tensor): - if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 - for p in mm_input): - return mm_input - return torch.concat(list(mm_input)) - - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[KeyeImageInputs]: - pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) - image_grid_thw = kwargs.pop("image_grid_thw", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - return KeyeImagePixelInputs( - type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - - if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - return KeyeImageEmbeddingInputs( - type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw, - ) - - def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[KeyeVideoInputs]: - pixel_values_videos = kwargs.pop("pixel_values_videos", None) - video_embeds = kwargs.pop("video_embeds", None) - video_grid_thw = kwargs.pop("video_grid_thw", None) - - if pixel_values_videos is None and video_embeds is None: - return None - - if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, - "video pixel values", - ) - 
video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - - return KeyeVideoPixelInputs( - type="pixel_values_videos", - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thw, - ) - - if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - - return KeyeVideoEmbeddingInputs( - type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw, - ) - - def _process_image_input( - self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]: + def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]: siglip_position_ids = list() image_grid_hws = list() sample_indices = list() @@ -1404,21 +1368,22 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, image_grid_hws.append(thw_tuple) image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(image_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) + sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if image_input["type"] == "image_embeds": raise ValueError( - "Image embeddings are not supported for this processing path.") + "Image embeddings are not supported for this processing path." + ) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, - dim=0).to(pixel_values.device) + siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( + pixel_values.device + ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values.device) + pixel_values.device + ) + sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device) image_embeds = self.visual( pixel_values=pixel_values, @@ -1434,39 +1399,43 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) return image_embeds - def _process_video_input( - self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: + def _process_video_embeds( + self, + video_type: Literal["video_embeds", "pixel_values_videos"], + video_grid_thw: list[torch.Tensor], + pixel_values_videos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, list[torch.Tensor]]: siglip_position_ids = list() video_grid_hws = list() sample_indices = list() cu_seqlens = [0] - video_grid_thw = video_input["video_grid_thw"] assert video_grid_thw.ndim == 2 - - for idx, thaw in enumerate(video_grid_thw): - thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) + for idx, sub_thw in enumerate(video_grid_thw): + thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) video_grid_hws.append(thw_tuple) video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(video_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) + sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) - if video_input["type"] == "video_embeds": + if video_type == "video_embeds": raise ValueError( - "Video embeddings are not supported for this processing path.") + "Video embeddings are not supported for this processing path." 
+ ) else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( - pixel_values_videos.device) + pixel_values_videos.device + ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values_videos.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values_videos.device) + pixel_values_videos.device + ) + sample_indices = torch.concat(sample_indices, dim=0).to( + pixel_values_videos.device + ) video_embeds = self.visual( pixel_values=pixel_values_videos, @@ -1479,21 +1448,23 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, use_rope=True, window_size=-1, ) - video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw)) + video_embeds = self.mlp_AR(video_embeds, video_grid_thw) return video_embeds def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} for input_key in kwargs: - if (input_key in ("pixel_values", "image_embeds") - and "images" not in modalities): - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if (input_key in ("pixel_values_videos", "video_embeds") - and "videos" not in modalities): - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities @@ -1501,8 +1472,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None @@ -1520,50 +1491,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[KeyeImagePixelInputs] = None, - video_input: Optional[KeyeVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ 
-1572,7 +1499,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - """Run forward pass for Qwen2-VL. + """Run forward pass for Keye-VL. Args: input_ids: Flattened (concatenated) input_ids corresponding to a @@ -1581,57 +1508,29 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ - if intermediate_tensors is not None: inputs_embeds = None - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input, - ) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + return self.language_model.compute_logits(hidden_states) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1639,6 +1538,83 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, """Get the module prefix in multimodal models.""" return MultiModelKeys.from_string_field( language_model="language_model", - connector="visual.", - tower_model="mlp_AR.", + connector="mlp_AR.", + tower_model="visual.", + ) + + +@MULTIMODAL_REGISTRY.register_processor( + KeyeMultiModalProcessor, + info=KeyeProcessingInfo, + dummy_inputs=KeyeDummyInputsBuilder, +) +class KeyeForConditionalGeneration( + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP +): + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + return Projector(text_config, vision_config, quant_config, prefix) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[KeyeImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + 
image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + return KeyeImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return KeyeImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[KeyeVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + return KeyeVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + return KeyeVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + def _process_video_input( + self, video_input: KeyeVideoInputs + ) -> tuple[torch.Tensor, ...]: + video_type = video_input["type"] + video_grid_thw = video_input["video_grid_thw"] + pixel_values_videos = video_input.get("pixel_values_videos", None) + + return tuple( + self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) ) diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py new file mode 100644 index 0000000000000..21d8099b43d16 --- /dev/null +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -0,0 +1,731 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections.abc import Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Literal, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from transformers import PretrainedConfig +from transformers.activations import GELUActivation +from transformers.feature_extraction_utils import BatchFeature + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP +from .keye import ( + BaseKeyeModule, + BaseMultiModalProcessor, + KeyeBaseDummyInputsBuilder, + KeyeProcessingInfo, +) + +logger = init_logger(__name__) + + +def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: + """ + Split grid_thw in t dimension. 
+ + Args: + grid_thw: [N, 3] tensor of [t, h, w] + + Returns: + [Σt, 3] tensor where each row is [1, h, w] + + Example: + >>> grid_thw = torch.tensor([[2, 3, 4], [1, 5, 6]]) + >>> split_thw(grid_thw) + tensor([[1, 3, 4], + [1, 3, 4], + [1, 5, 6]]) + """ + t = grid_thw[:, 0] + h_w = grid_thw[:, 1:] + ones = torch.ones_like(h_w[:, :1]) + return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) + + +def get_num_patches( + grid_thw: torch.Tensor, num_frames: Union[list[int], torch.Tensor] +) -> list[int]: + """ + Return num_patches per video. + + Args: + grid_thw: Tensor with shape [N, 3] containing temporal, height, width + dimensions + num_frames: List or tensor indicating the number of frames per video + + Returns: + List of ints representing the number of patches for each video + + Examples: + >>> # Suppose there are 2 videos with a total of 3 grids + >>> grid_thw = torch.tensor( + ... [ + ... [2, 2, 2], # grid 0: 2*2*2=8 patches + ... [2, 2, 2], # grid 1: 2*2*2=8 patches + ... [1, 1, 1], + ... ] + ... ) # grid 2: 1*1*1=1 patches + >>> num_frames = [2, 1] # The first video contains 2 grids, + the second contains 1 grid. + >>> get_num_patches(grid_thw, num_frames) + tensor([16, 1]) # Total patches for first video: 8+8=16, + second video: 1. + """ + + assert len(grid_thw.shape) == 2 + if isinstance(num_frames, torch.Tensor): + num_frames = num_frames.clone().tolist() + + num_grids_per_frame = grid_thw.prod(dim=1) + start_idx_per_video = [0, *itertools.accumulate(num_frames)] + num_patches = [ + num_grids_per_frame[start_idx_per_video[i] : start_idx_per_video[i + 1]].sum() + for i in range(len(num_frames)) + ] + return ( + torch.stack(num_patches) + if num_patches + else torch.zeros(0, dtype=grid_thw.dtype, device=grid_thw.device) + ) + + +class KeyeVL1_5ImagePixelInputs(TensorSchema): + """ + Dimensions: + - bnp: Batch size * Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + + type: Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] + + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +class KeyeVL1_5ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size (must match the hidden size of language model + backbone) + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + + type: Literal["image_embeds"] + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, KeyeVL1_5ImageEmbeddingInputs] + + +class KeyeVL1_5VideoPixelInputs(TensorSchema): + """ + Dimensions: + - bnp: Batch size * Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + + type: Literal["pixel_values_videos"] + pixel_values_videos: Annotated[ + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + num_frames: torch.Tensor + + +class KeyeVL1_5VideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size (must match the hidden size of language model + backbone) + - nv: Number of videos + - g: Grid dimensions (3 for t, h, w) + """ + + type: Literal["video_embeds"] + video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + 
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + num_frames: torch.Tensor + + +KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, KeyeVL1_5VideoEmbeddingInputs] + + +class KeyeVL1_5Projector(nn.Module): + def __init__( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = ( + self.vision_config.hidden_size + * self.merge_kernel_size[0] + * self.merge_kernel_size[1] + ) + + self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05) + self.act = GELUActivation() + + self.linear_1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) + self.linear_2 = RowParallelLinear( + self.hidden_size, + self.text_config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) + + def forward( + self, + image_features: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]], + image_grid_thw: list[tuple[int, int, int]], + ) -> Union[torch.Tensor, list[torch.Tensor]]: + m1, m2 = self.merge_kernel_size + if isinstance(image_features, (list, tuple)): + processed_features = list() + for image_feature, image_grid in zip(image_features, image_grid_thw): + t, h, w = image_grid + image_feature = rearrange( + image_feature, + "(t h p1 w p2) d -> (t h w) (p1 p2 d)", + t=t, + h=h // m1, + p1=m1, + w=w // m2, + p2=m2, + ) + image_feature = self.pre_norm(image_feature) + hidden_states, _ = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + dims = image_features.shape[:-1] + dim = image_features.shape[-1] + image_features = image_features.view(np.prod(dims), dim) + hidden_states = self.pre_norm(image_features.view(-1, self.hidden_size)) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states.view(*dims, -1) + + +class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo): + def get_max_frame_per_video(self) -> int: + return 2048 + + def get_supported_mm_limits( + self, + ) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + +def _keye_field_config( + hf_inputs: Mapping[str, torch.Tensor], +): + image_grid_thw = hf_inputs.get( + "image_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get( + "video_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) + video_grid_thw = split_thw(video_grid_thw) + num_frames = hf_inputs.get("num_frames", video_grid_thw[:, 0]).clone().tolist() + + video_num_patches = get_num_patches(video_grid_thw, num_frames) + + video_num_grids = [] + if len(num_frames) > 0: + i = 0 + j = 1 + cur_frames = num_frames[i] + for t, _, _ in video_grid_thw.tolist(): + cur_frames -= t + if cur_frames == 0: + video_num_grids.append(j) + i += 1 + if i < len(num_frames): + cur_frames = num_frames[i] + j = 1 + else: + j += 1 + video_num_grids = torch.tensor(video_num_grids) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + 
image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_num_patches), + video_grid_thw=MultiModalFieldConfig.flat_from_sizes("video", video_num_grids), + num_frames=MultiModalFieldConfig.batched("video"), + ) + + +class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={ + "image_embeds", + "image_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="video", + required_fields={ + "video_embeds", + "video_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_video_data(data) + + +class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): + def _get_data_parser(self) -> MultiModalDataParser: + return KeyeVL1_5MultiModalDataParser() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + image_token_id = vocab[hf_processor.image_token] + video_token_id = vocab[hf_processor.video_token] + placeholder = {"image": image_token_id, "video": video_token_id} + merge_length = image_processor.merge_size**2 + + out_mm_kwargs_data = out_mm_kwargs.get_data() + frame_types: list[torch.Tensor] = hf_processor_mm_kwargs.get( + "frame_types", None + ) + timestamps: list[torch.Tensor] = hf_processor_mm_kwargs.get("timestamps", None) + num_videos = mm_items.get_count("video", strict=False) + + if frame_types is None: + frame_types = [None] * num_videos + assert len(frame_types) == num_videos, ( + f"Number of frame_types={len(frame_types)} " + f"doesn't equal to number of videos={num_videos}" + ) + if timestamps is None: + timestamps = [None] * num_videos + assert len(timestamps) == num_videos, ( + f"Number of timestamps={len(timestamps)} " + f"doesn't equal to number of videos={num_videos}" + ) + + video_grid_thw = out_mm_kwargs_data.get( + "video_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) + num_frames = out_mm_kwargs_data.get( + "num_frames", torch.tensor([], dtype=torch.int64) + ) + + assert len(num_frames) == num_videos, ( + f"Size of num_frames={len(num_frames)} " + f"doesn't equal to number of videos={num_videos}" + ) + + video_grid_hws = split_thw(video_grid_thw) + assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], ( + f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}" + f"doesn't equal to num of frames." 
+ ) + + cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), dim=-1) + + def get_replacement_keye(item_idx: int, modality: str): + """ + Args: + item_idx(int): The item index of modality to replace + modality(str): The modality + """ + if modality == "image": + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [image_token_id] * num_tokens + elif modality == "video": + placeholders = [] + video_timestamps = timestamps[item_idx] + video_frame_types = frame_types[item_idx] + grid_thw = video_grid_hws[ + cu_seqlens[item_idx] : cu_seqlens[item_idx + 1] + ] + + nframes = grid_thw.shape[0] + + if video_timestamps is None: + video_timestamps = [""] * nframes + else: + video_timestamps = [format(ts, ".1f") for ts in video_timestamps] + + if video_frame_types is None: + video_frame_types = [0] * nframes + for i, sub_thw in enumerate(grid_thw): + s = f"{hf_processor.frame_token}{video_timestamps[i]}" + if video_frame_types[i] == 1: + s += hf_processor.fast_start + placeholders.extend(tokenizer.encode(s)) + num_frame_tokens = int(sub_thw.prod()) // merge_length + placeholders.extend([video_token_id] * num_frame_tokens) + if video_frame_types[i] == 1: + placeholders.append(vocab[hf_processor.fast_end]) + + return PromptUpdateDetails.select_token_id( + placeholders, embed_token_id=video_token_id + ) + else: + raise ValueError(f"Unsupported modality {modality}") + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_keye, modality=modality), + ) + for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _keye_field_config(hf_inputs) + + +class KeyeVL1_5DummyInputsBuilder( + KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo] +): ... 
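For readers following the new keye_vl1_5.py hunks above, here is a minimal standalone sketch of the per-video grid bookkeeping that `split_thw`, `get_num_patches`, and the (2, 2) merge kernel implement. It is not part of the patch: the helper name `split_per_frame` and the toy tensor values are invented for illustration, and only PyTorch is assumed.

# Illustrative sketch only -- mirrors the split_thw / get_num_patches logic defined above.
import torch

def split_per_frame(grid_thw: torch.Tensor) -> torch.Tensor:
    # [N, 3] rows of [t, h, w] -> [sum(t), 3] rows of [1, h, w], one row per frame.
    t, hw = grid_thw[:, 0], grid_thw[:, 1:]
    ones = torch.ones_like(hw[:, :1])
    return torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)

grid_thw = torch.tensor([[2, 4, 4], [1, 2, 2]])  # two toy videos: 2 frames and 1 frame
num_frames = [2, 1]

per_frame = split_per_frame(grid_thw)
# tensor([[1, 4, 4],
#         [1, 4, 4],
#         [1, 2, 2]])

# Patches per video: sum of t*h*w over each video's frames.
starts = [0]
for n in num_frames:
    starts.append(starts[-1] + n)
num_patches = [
    int(per_frame[s:e].prod(dim=1).sum()) for s, e in zip(starts, starts[1:])
]
# [32, 4]

# With the (2, 2) merge kernel, each video contributes num_patches // 4 embedding
# rows -- the same per-video counts that patch_cu_seqlens is used to slice out in
# _process_video_input.
merge_length = 2 * 2
print(num_patches, [n // merge_length for n in num_patches])  # [32, 4] [8, 1]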
+ + +@MULTIMODAL_REGISTRY.register_processor( + KeyeVL1_5MultiModalProcessor, + info=KeyeVL1_5ProcessingInfo, + dummy_inputs=KeyeVL1_5DummyInputsBuilder, +) +class KeyeVL1_5ForConditionalGeneration( + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + return KeyeVL1_5Projector(text_config, vision_config, quant_config, prefix) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config: PretrainedConfig = vllm_config.model_config.hf_config + self.merge_size = config.vision_config.spatial_merge_size + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[KeyeVL1_5ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + return KeyeVL1_5ImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return KeyeVL1_5ImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[KeyeVL1_5VideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + num_frames = kwargs.pop("num_frames", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + return KeyeVL1_5VideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + num_frames=num_frames, + ) + + if video_embeds is not None: + return KeyeVL1_5VideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + num_frames=num_frames, + ) + + def _process_video_input( + self, video_input: KeyeVL1_5VideoInputs + ) -> tuple[torch.Tensor, ...]: + video_type = video_input["type"] + video_grid_thw = split_thw(video_input["video_grid_thw"]) + pixel_values_videos = video_input.get("pixel_values_videos", None) + + video_embeds = self._process_video_embeds( + video_type, video_grid_thw, pixel_values_videos + ) + video_embeds = torch.concat(video_embeds, dim=0) + + num_frames = video_input["num_frames"].clone().tolist() + + num_patches = get_num_patches(video_grid_thw, num_frames).tolist() + + patch_cu_seqlens = torch.cumsum( + torch.tensor([0] + num_patches).detach().clone(), dim=-1 + ) + patch_cu_seqlens = torch.div( + patch_cu_seqlens, self.merge_size**2, rounding_mode="floor" + ) + + new_video_embeds = [] + for idx in range(patch_cu_seqlens.shape[0] - 1): + start = patch_cu_seqlens[idx] + end = patch_cu_seqlens[idx + 1] + new_video_embeds.append(video_embeds[start:end]) + return tuple(new_video_embeds) + + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, 
+ audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: + video_grid_thw = video_grid_thw[0] + """Get mrope input positions and delta value (Keye series).""" + + def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: + """ + Split grid_thw along the t dimension. + + Args: + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. + + Returns: + List of [1, h, w] rows, repeated t times for each original row. + """ + + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.long) + + if grid_thw.numel() == 0: + return [] + + t, hw = grid_thw[:, 0], grid_thw[:, 1:] + ones = torch.ones_like(hw[:, :1]) # [N,1] + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) + return out.tolist() + + video_grid_thw = split_thw(video_grid_thw) + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_nums = len(image_grid_thw) + frame_nums = len(video_grid_thw) + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_frames = image_nums, frame_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + frame_nums): + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_frames > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_frames -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index a08a9a62a57c5..f7381e6b6b93e 100644 --- a/vllm/model_executor/models/kimi_vl.py 
+++ b/vllm/model_executor/models/kimi_vl.py @@ -54,27 +54,40 @@ from transformers import BatchFeature from transformers.activations import GELUActivation from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - SupportsPP) +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP from vllm.model_executor.models.moonvit import MoonVitPretrainedModel -from vllm.model_executor.models.utils import merge_multimodal_embeddings -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig @@ -82,6 +95,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix +from .vision import run_dp_sharded_mrope_vision_model # For dummy input only @@ -92,30 +106,38 @@ class MaxImageTokenMeta: class KimiVLMultiModalProjector(nn.Module): - - def __init__(self, config: KimiVLConfig): + def __init__( + self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = "" + ): super().__init__() + self.use_data_parallel = use_data_parallel - self.hidden_size = (config.vision_config.hidden_size * - config.vision_config.merge_kernel_size[0] * - config.vision_config.merge_kernel_size[1]) + self.hidden_size = ( + config.vision_config.hidden_size + * config.vision_config.merge_kernel_size[0] + * config.vision_config.merge_kernel_size[1] + ) - self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, - eps=1e-5) - self.linear_1 = nn.Linear(self.hidden_size, - self.hidden_size, - bias=True) + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5) + self.linear_1 = ReplicatedLinear( + self.hidden_size, + self.hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_1"), + ) + self.linear_2 = ReplicatedLinear( + self.hidden_size, + 
config.text_config.hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_2"), + ) self.act = GELUActivation() - self.linear_2 = nn.Linear(self.hidden_size, - config.text_config.hidden_size, - bias=True) def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.pre_norm(image_features).view( - -1, self.hidden_size) - hidden_states = self.linear_1(hidden_states) + hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) + hidden_states, _ = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) return hidden_states @@ -127,6 +149,7 @@ class KimiVLImagePixelInputs(TensorSchema): - ps: Patch size - ni: Number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ @@ -143,7 +166,6 @@ KimiVLImageInputs = KimiVLImagePixelInputs class KimiVLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(KimiVLConfig) @@ -162,25 +184,25 @@ class KimiVLProcessingInfo(BaseProcessingInfo): in_token_limit = hf_processor.image_processor.in_token_limit height = image_height width = image_width - assert isinstance(height, - int), f"height must be int, current height {height}" - assert isinstance(width, - int), f"width must be int, current width {width}" + assert isinstance(height, int), f"height must be int, current height {height}" + assert isinstance(width, int), f"width must be int, current width {width}" assert kernel_size is not None, "kernel_size must be specified" if (width // patch_size) * (height // patch_size) > in_token_limit: - scale = math.sqrt(in_token_limit / ((width // patch_size) * - (height // patch_size))) + scale = math.sqrt( + in_token_limit / ((width // patch_size) * (height // patch_size)) + ) new_w, new_h = int(width * scale), int(height * scale) width, height = new_w, new_h kernel_height, kernel_width = kernel_size - pad_height = (kernel_height * patch_size - height % - (kernel_height * patch_size)) % (kernel_height * - patch_size) - pad_width = (kernel_width * patch_size - width % - (kernel_width * patch_size)) % (kernel_width * patch_size) + pad_height = ( + kernel_height * patch_size - height % (kernel_height * patch_size) + ) % (kernel_height * patch_size) + pad_width = ( + kernel_width * patch_size - width % (kernel_width * patch_size) + ) % (kernel_width * patch_size) # Calculate new dimensions after padding and patching token_height = (height + pad_height) // (kernel_size[0] * patch_size) @@ -193,7 +215,6 @@ class KimiVLProcessingInfo(BaseProcessingInfo): class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -206,19 +227,23 @@ class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=MaxImageTokenMeta.width, - height=MaxImageTokenMeta.height, - num_images=num_images) + "image": self._get_dummy_images( + width=MaxImageTokenMeta.width, + height=MaxImageTokenMeta.height, + num_images=num_images, + overrides=image_overrides, + ) } class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): - 
def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -231,7 +256,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): # image_grid_hws is shapes for each subtensor in pixel_values return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_grid_hws=MultiModalFieldConfig.batched("image"), ) @@ -245,7 +271,8 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -267,11 +294,15 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, - info=KimiVLProcessingInfo, - dummy_inputs=KimiVLDummyInputsBuilder) -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + KimiVLMultiModalProcessor, + info=KimiVLProcessingInfo, + dummy_inputs=KimiVLDummyInputsBuilder, +) +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + + supports_encoder_tp_data = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -292,14 +323,27 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config = vllm_config.quant_config assert isinstance(config.vision_config, MoonViTConfig) + self.use_data_parallel = ( + model_config.multimodal_config.mm_encoder_tp_mode == "data" + ) + self.hidden_size = config.text_config.hidden_size + self.vision_tower = MoonVitPretrainedModel( + config.vision_config, + self.use_data_parallel, + prefix=maybe_prefix(prefix, "vision_tower"), + ) - self.vision_tower = MoonVitPretrainedModel(config.vision_config) - - self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + self.multi_modal_projector = KimiVLMultiModalProjector( + config=config, + use_data_parallel=self.use_data_parallel, + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) self.quant_config = quant_config sub_vllm_config = copy.deepcopy(vllm_config) - sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config + sub_vllm_config.model_config.hf_config = ( + sub_vllm_config.model_config.hf_config.text_config + ) self.language_model = DeepseekV2Model( vllm_config=sub_vllm_config, prefix=maybe_prefix(prefix, "language_model"), @@ -311,35 +355,22 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, config.text_config.hidden_size, org_num_embeddings=self.config.text_config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) self.media_placeholder: int = self.config.media_placeholder_token_id - # ref: qwen2_vl.py - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> 
torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[KimiVLImageInputs]: + self, **kwargs: object + ) -> Optional[KimiVLImageInputs]: # image input type must be pixel values now pixel_values = kwargs.pop("pixel_values", None) image_grid_hws = kwargs.pop("image_grid_hws", None) @@ -347,21 +378,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values is None: return None - image_grid_hws = self._validate_and_reshape_mm_tensor( - image_grid_hws, "image grid hws") - # pixel_values may have complex shapes - num_channels = 3 - patch_size = self.config.vision_config.patch_size - if isinstance(pixel_values, list): - pixel_values = torch.cat([ - x.reshape(-1, num_channels, patch_size, patch_size) - for x in pixel_values - ]) - else: - pixel_values = pixel_values.reshape(-1, num_channels, patch_size, - patch_size) - pixel_values = pixel_values.to(self.vision_tower.dtype) - return KimiVLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -370,28 +386,32 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, # perform vt on processored pixel_values @torch.inference_mode() - def _process_image_pixels(self, - inputs: KimiVLImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] image_grid_hws = inputs["image_grid_hws"] - return self.vision_tower(pixel_values, image_grid_hws) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.vision_tower, + pixel_values, + image_grid_hws.tolist(), + rope_type="rope_2d", + ) + else: + return self.vision_tower(pixel_values, image_grid_hws) - def _process_image_input(self, - image_input: KimiVLImageInputs) -> torch.Tensor: + def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor: assert image_input["type"] == "pixel_values" image_features = self._process_image_pixels(image_input) - assert isinstance(image_features, list) + assert isinstance(image_features, (list, tuple)) lengths = [x.shape[0] for x in image_features] - return self.multi_modal_projector( - torch.cat(image_features)).split(lengths) + return self.multi_modal_projector(torch.cat(image_features)).split(lengths) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: @@ -401,26 +421,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - - # `get_input_embeddings` should already be implemented for the language - # 
model as one of the requirements of basic vLLM model implementation. - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=self.config.media_placeholder_token_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -431,24 +431,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - inputs_embeds = None - else: - inputs_embeds = self.get_input_embeddings(input_ids) - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config. - media_placeholder_token_id, - ) - input_ids = None hidden_states = self.language_model( input_ids=input_ids, @@ -459,11 +441,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - **kwargs) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, **kwargs) + def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -491,11 +470,13 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=config.n_routed_experts) + num_experts=config.n_routed_experts, + ) else: expert_params_mapping = [] params_dict = dict(self.named_parameters()) + for args in weights: name, loaded_weight = args[:2] kwargs = args[2] if len(args) > 2 else {} @@ -506,8 +487,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, if spec_layer is not None: continue # skip spec decode layers for main model - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -521,8 +501,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, # not vision model for now. use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. @@ -531,7 +510,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." 
in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -546,8 +525,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, weight_loader(param, loaded_weight, shard_id, **kwargs) break else: - for idx, (param_name, weight_name, expert_id, - shard_id) in enumerate(expert_params_mapping): + for idx, ( + param_name, + weight_name, + expert_id, + shard_id, + ) in enumerate(expert_params_mapping): if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -557,12 +540,14 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - expert_id=expert_id, - shard_id=shard_id, - **kwargs) + weight_loader( + param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs, + ) break else: use_default_weight_loading = True @@ -579,18 +564,18 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, **kwargs) -def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): +def get_spec_layer_idx_from_weight_name( + config: DeepseekV2Config, weight_name: str +) -> Optional[int]: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 5f3148b47eadc..425c936877602 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -1,44 +1,53 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch import torch.nn as nn from transformers import Lfm2Config -from vllm import envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.short_conv import ShortConv from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from 
vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Lfm2MLP(nn.Module): - def __init__( self, dim: int, @@ -62,14 +71,14 @@ class Lfm2MLP(nn.Module): output_sizes=[ff_dim] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=f"{prefix}.w1", ) self.w2 = RowParallelLinear( input_size=ff_dim, output_size=dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=f"{prefix}.w2", ) self.act_fn = SiluAndMul() @@ -81,7 +90,6 @@ class Lfm2MLP(nn.Module): class Lfm2Attention(nn.Module): - def __init__( self, config: Lfm2Config, @@ -178,7 +186,6 @@ class Lfm2Attention(nn.Module): class Lfm2AttentionDecoderLayer(nn.Module): - def __init__( self, config: Lfm2Config, @@ -196,11 +203,12 @@ class Lfm2AttentionDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Lfm2Attention( config=config, @@ -239,16 +247,13 @@ class Lfm2AttentionDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.operator_norm(hidden_states) else: - hidden_states, residual = self.operator_norm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.operator_norm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states, residual = self.ffn_norm(hidden_states, residual) return self.feed_forward(hidden_states), residual class Lfm2ShortConvDecoderLayer(nn.Module): - def __init__( self, config: Lfm2Config, @@ -291,13 +296,11 @@ class Lfm2ShortConvDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.operator_norm(hidden_states) else: - hidden_states, residual = self.operator_norm( - hidden_states, residual) + hidden_states, residual = self.operator_norm(hidden_states, residual) output = torch.empty_like(hidden_states) self.conv( hidden_states, output, - conv_metadata=None, ) hidden_states, residual = self.ffn_norm(output, residual) hidden_states = self.feed_forward(hidden_states) @@ -306,7 +309,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module): @support_torch_compile class 
Lfm2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -317,21 +319,24 @@ class Lfm2Model(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size + ) def get_layer(prefix: str): layer_idx = extract_layer_index(prefix) is_attn = self.config.layer_types[layer_idx] == "full_attention" - layer_class = (Lfm2AttentionDecoderLayer - if is_attn else Lfm2ShortConvDecoderLayer) + layer_class = ( + Lfm2AttentionDecoderLayer if is_attn else Lfm2ShortConvDecoderLayer + ) return layer_class( config, layer_idx, @@ -342,14 +347,14 @@ class Lfm2Model(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: - self.embedding_norm = RMSNorm(config.hidden_size, - eps=config.norm_eps) + self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) else: self.embedding_norm = PPMissingLayer() @@ -374,22 +379,20 @@ class Lfm2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.embedding_norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), @@ -400,7 +403,6 @@ class Lfm2Model(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -416,15 +418,15 @@ class Lfm2Model(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class Lfm2ForCausalLM( + nn.Module, 
HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -449,7 +451,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, ...]: - return MambaStateDtypeCalculator.short_conv_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -459,13 +460,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int]]: - """ Calculate shapes for LFM2's convolutional cache. + """Calculate shapes for LFM2's convolutional cache. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -478,7 +477,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.conv_dim, conv_kernel=hf_config.conv_L_cache, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -486,20 +484,15 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config - scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "Lfm2 currently does not support prefix caching" - assert envs.VLLM_USE_V1, ( - "Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1") + assert not cache_config.enable_prefix_caching, ( + "Lfm2 currently does not support prefix caching" + ) super().__init__() self.config = config - self.vllm_config = vllm_config - self.scheduler_config = scheduler_config - self.model_config = vllm_config.model_config - - self.model = Lfm2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Lfm2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.config.vocab_size @@ -514,8 +507,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -523,11 +517,16 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -537,21 +536,18 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: 
torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py new file mode 100644 index 0000000000000..f7903a7af53fe --- /dev/null +++ b/vllm/model_executor/models/lfm2_moe.py @@ -0,0 +1,798 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from itertools import islice +from typing import Any, Optional + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Lfm2MoeConfig + +from .interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class Lfm2MoeMlp(nn.Module): + def __init__( + self, + dim: int, + ff_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.w1 = MergedColumnParallelLinear( + input_size=dim, + output_sizes=[ff_dim] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w1", + ) + self.w2 = RowParallelLinear( + input_size=ff_dim, + output_size=dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w2", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.w1(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + 
return x + + +class Lfm2MoeSparseMoeBlock(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}." + ) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + if config.use_expert_bias: + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, # needed for softmax score func + num_expert_group=1, + topk_group=1, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) + + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class Lfm2MoeAttention(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + 
self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + n_tokens, _ = hidden_states.shape + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(n_tokens, self.num_heads, self.head_dim).contiguous() + k = k.view(n_tokens, self.num_kv_heads, self.head_dim).contiguous() + q = self.q_layernorm(q) + k = self.k_layernorm(k) + q, k = self.rotary_emb(positions, q, k) + q = q.view(n_tokens, self.num_heads * self.head_dim) + k = k.view(n_tokens, self.num_kv_heads * self.head_dim) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class Lfm2MoeAttentionDecoderLayer(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.prefix = prefix + self.config = config + self.layer_idx = layer_idx + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + self.self_attn = Lfm2MoeAttention( + config=config, + layer_idx=layer_idx, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, 
+ quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if layer_idx < config.num_dense_layers: + self.feed_forward = Lfm2MoeMlp( + dim=config.hidden_size, + ff_dim=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = Lfm2MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + enable_eplb=enable_eplb, + ) + + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + hidden_states, residual = self.ffn_norm(hidden_states, residual) + return self.feed_forward(hidden_states), residual + + +class Lfm2MoeShortConvDecoderLayer(nn.Module): + def __init__( + self, + config: Lfm2MoeConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.conv = ShortConv( + config=config, + dim=config.hidden_size, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.conv", + ) + + if layer_idx < config.num_dense_layers: + self.feed_forward = Lfm2MoeMlp( + dim=config.hidden_size, + ff_dim=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + else: + self.feed_forward = Lfm2MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + enable_eplb=enable_eplb, + ) + + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.operator_norm(hidden_states) + else: + hidden_states, residual = self.operator_norm(hidden_states, residual) + output = torch.empty_like(hidden_states) + self.conv( + hidden_states, + output, + ) + hidden_states, residual = self.ffn_norm(output, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Lfm2MoeModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config + enable_eplb = parallel_config.enable_eplb + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = 
VocabParallelEmbedding( + self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size + ) + + def get_layer(prefix: str): + layer_idx = extract_layer_index(prefix) + is_attn = self.config.layer_types[layer_idx] == "full_attention" + layer_class = ( + Lfm2MoeAttentionDecoderLayer + if is_attn + else Lfm2MoeShortConvDecoderLayer + ) + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + else: + self.embedding_norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.embedding_norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".w1", ".w1", 0), + (".w1", ".w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "expert_bias" in name: + name = name.replace("expert_bias", "gate.e_score_correction_bias") + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if ("feed_forward.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Lfm2MoeForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsQuant, + MixtureOfExperts, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "w1": [ + "w1", + "w3", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, ...]: + return MambaStateDtypeCalculator.short_conv_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int]]: + """Calculate shapes for LFM2's convolutional cache. 
+ + Args: + vllm_config: vLLM config + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.short_conv_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.hidden_size, + conv_kernel=hf_config.conv_L_cache, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( + "Lfm2Moe currently does not support prefix caching" + ) + + super().__init__() + self.config = config + self.model = Lfm2MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = self.config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance( + layer, (Lfm2MoeAttentionDecoderLayer, Lfm2MoeShortConvDecoderLayer) + ) + if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock): + example_layer = layer.feed_forward + self.moe_layers.append(layer.feed_forward.experts) + + if example_layer is None: + raise RuntimeError( + "No Lfm2MoeSparseMoeBlock layer found in the model.layers." + ) + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.feed_forward, Lfm2MoeSparseMoeBlock): + moe = layer.feed_forward + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f99f1c3643fd4..948c9280f953a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -37,28 +39,38 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class LlamaMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -68,6 +80,7 @@ class LlamaMLP(nn.Module): bias: bool = False, prefix: str = "", reduce_results: bool = True, + disable_tp: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -75,6 +88,7 @@ class LlamaMLP(nn.Module): output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, + disable_tp=disable_tp, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( @@ -83,11 +97,13 @@ class LlamaMLP(nn.Module): bias=bias, quant_config=quant_config, reduce_results=reduce_results, + disable_tp=disable_tp, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -98,7 +114,6 @@ class LlamaMLP(nn.Module): class LlamaAttention(nn.Module): - def __init__( self, config: LlamaConfig, @@ -138,8 +153,7 @@ class LlamaAttention(nn.Module): head_dim = self.hidden_size // self.total_num_heads self.head_dim = head_dim # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", - 1) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -164,18 +178,36 @@ class LlamaAttention(nn.Module): prefix=f"{prefix}.o_proj", ) - self._init_rotary_emb(config, - rope_scaling=rope_scaling, - quant_config=quant_config) + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) sliding_window = None if layer_types := getattr(config, "layer_types", None): - is_sliding = layer_types[layer_idx] == "sliding_attention" + # Fix for Eagle3 compatibility: + # for draft models, subtract target layer count + # to get draft-relative layer index starting from 0 + if hasattr(config, "target_layer_count"): + # This is a draft model, + # adjust layer_idx to be relative to draft layers + effective_layer_idx = layer_idx - config.target_layer_count + else: + # This is a target model, use layer_idx directly + effective_layer_idx = layer_idx + assert effective_layer_idx < len(layer_types), ( + f"effective_layer_idx: {effective_layer_idx} \ + is out of bounds for layer_types: {layer_types}" + ) + + is_sliding = layer_types[effective_layer_idx] == "sliding_attention" if is_sliding: sliding_window = config.sliding_window - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) self.attn = attn_cls( self.num_heads, @@ -201,9 +233,12 @@ class LlamaAttention(nn.Module): output, _ = self.o_proj(attn_output) return output - def _init_rotary_emb(self, config: LlamaConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: + def _init_rotary_emb( + self, + config: LlamaConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: is_neox_style = True is_gguf = quant_config and quant_config.get_name() == "gguf" if is_gguf and config.model_type == "llama": @@ -221,31 +256,36 @@ class LlamaAttention(nn.Module): class LlamaDecoderLayer(nn.Module): - def __init__( self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", + config: Optional[LlamaConfig] = None, ) -> None: super().__init__() + + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = self.get_quant_config(vllm_config) + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + 
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias - if hasattr(config, 'qkv_bias'): + if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias # By default, Llama uses causal attention as it is a decoder-only model. @@ -261,8 +301,9 @@ class LlamaDecoderLayer(nn.Module): config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -281,10 +322,10 @@ class LlamaDecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -297,41 +338,46 @@ class LlamaDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual + def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + """Get quantization config for this layer. 
Override in subclasses.""" + return vllm_config.quant_config + @support_torch_compile class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -342,10 +388,7 @@ class LlamaModel(nn.Module): self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -353,11 +396,11 @@ class LlamaModel(nn.Module): else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -368,8 +411,9 @@ class LlamaModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -383,16 +427,16 @@ class LlamaModel(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -400,8 +444,7 @@ class LlamaModel(nn.Module): return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: 
Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -415,19 +458,19 @@ class LlamaModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -460,8 +503,7 @@ class LlamaModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -470,13 +512,13 @@ class LlamaModel(nn.Module): class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings" + "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] @@ -506,11 +548,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): "norm": "model.norm", } - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -518,9 +562,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type) + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -534,39 +580,45 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, 
"logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Override to return default layers for Llama + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. + """ num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): - return LlamaModel(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): + return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -578,29 +630,27 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( self.maybe_remap_mistral(name, loaded_weight) - for name, loaded_weight in weights) + for name, loaded_weight in weights + ) # This function is used to remap the mistral format as # used by Mistral and Llama <=2 @@ -609,32 +659,48 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): name: str, loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - - def permute(w: torch.Tensor, n_heads: int): + def permute(w: torch.Tensor, n_heads: int, attn_out: int): attn_in = self.config.head_dim * n_heads - attn_out = self.config.hidden_size - return w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, 2).reshape(attn_in, attn_out) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) mapping = self.mistral_mapping modules = name.split(".") # rotary embeds should be sliced + # 
If using quantized model in mistral format, + # quantization scales (qscale_weight) also need to be sliced if "wk" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads) + loaded_weight = permute( + loaded_weight, self.config.num_key_value_heads, self.config.hidden_size + ) + elif ( + "wk" in modules + and modules[-1] == "qscale_weight" + and loaded_weight.numel() > 1 + ): + loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1) elif "wq" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads) + loaded_weight = permute( + loaded_weight, self.config.num_attention_heads, self.config.hidden_size + ) + elif ( + "wq" in modules + and modules[-1] == "qscale_weight" + and loaded_weight.numel() > 1 + ): + loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1) num_modules = len(modules) for i in range(num_modules): item = modules[i] next_item = modules[i + 1] if i < num_modules - 1 else None - combined_item = (f"{item}.{next_item}" - if next_item is not None else None) + combined_item = f"{item}.{next_item}" if next_item is not None else None if combined_item in mapping: name = name.replace(combined_item, mapping[combined_item]) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index ba08e6f81f7fe..df7bd9b7f6d1b 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Any, Optional @@ -28,24 +29,35 @@ from vllm.attention import Attention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel -from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, - is_pp_missing_parameter) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + fast_topk, + is_pp_missing_parameter, +) class Llama4MoE(nn.Module): - @staticmethod def custom_routing_function( hidden_states: torch.Tensor, @@ -58,22 +70,39 @@ class Llama4MoE(nn.Module): router_scores = torch.sigmoid(router_scores.float()) return (router_scores, router_indices.to(torch.int32)) - def __init__(self, - config: Llama4TextConfig, 
- quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.top_k = config.num_experts_per_tok + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe intermediate_size_moe = config.intermediate_size - self.router = ReplicatedLinear(config.hidden_size, - config.num_local_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.router") + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.router", + ) - self.experts = FusedMoE( + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, + disable_tp=self.is_sequence_parallel, + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -83,49 +112,50 @@ class Llama4MoE(nn.Module): reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.shared_expert = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size_moe, - hidden_act="silu", - quant_config=quant_config, - bias=False, - prefix=f"{prefix}.shared_expert", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, hidden_states): + num_tokens = hidden_states.shape[0] + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + router_logits, _ = self.router(hidden_states) - shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( + + shared_out, routed_out = self.experts( hidden_states=hidden_states, router_logits=router_logits, ) experts_out = routed_out + shared_out - if self.tp_size > 1: + if self.is_sequence_parallel: + experts_out = tensor_model_parallel_all_gather(experts_out, 0) + experts_out = experts_out[:num_tokens] + elif self.tp_size > 1: experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( - experts_out) + experts_out + ) return experts_out class Llama4Attention(nn.Module): - - def __init__(self, - config: Llama4TextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Llama4TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size @@ -150,20 +180,23 @@ class Llama4Attention(nn.Module): self.q_size = 
self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn_temperature_tuning = self.nope and \ - config.attn_temperature_tuning + self.attn_temperature_tuning = self.nope and config.attn_temperature_tuning self.floor_scale = getattr(config, "floor_scale", 8192.0) self.attn_scale = getattr(config, "attn_scale", 0.1) self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.n_rep = self.num_heads // self.num_kv_heads - self.qk_norm = RMSNorm( - hidden_size=self.head_dim, - eps=config.rms_norm_eps, - has_weight=False, - dtype=torch.float32, - ) if self.use_qk_norm else None + self.qk_norm = ( + RMSNorm( + hidden_size=self.head_dim, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) + if self.use_qk_norm + else None + ) self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, head_size=self.head_dim, @@ -186,18 +219,21 @@ class Llama4Attention(nn.Module): if is_gguf and config.model_type == "llama": is_neox_style = False - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - rope_scaling=rope_scaling if rope_scaling != "default" else None, - is_neox_style=is_neox_style, - ) if not self.nope else None + self.rotary_emb = ( + get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling if rope_scaling != "default" else None, + is_neox_style=is_neox_style, + ) + if not self.nope + else None + ) use_chunked_local_attn = not self.nope and config.attention_chunk_size - attn_cls = (ChunkedLocalAttention - if use_chunked_local_attn else Attention) + attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention self.attn = attn_cls( self.num_heads, self.head_dim, @@ -206,9 +242,12 @@ class Llama4Attention(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", - **({ - "attention_chunk_size": config.attention_chunk_size - } if use_chunked_local_attn else {})) + **( + {"attention_chunk_size": config.attention_chunk_size} + if use_chunked_local_attn + else {} + ), + ) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) @@ -253,16 +292,18 @@ class Llama4Attention(nn.Module): class Llama4DecoderLayer(nn.Module): - def __init__( self, - config: Llama4TextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", + config: Optional[Llama4TextConfig] = None, ) -> None: super().__init__() + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.layer_idx = extract_layer_index(prefix) self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size @@ -284,12 +325,13 @@ class Llama4DecoderLayer(nn.Module): cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - is_moe_layer = config.interleave_moe_layer_step > 0 and ( - self.layer_idx + 1) % config.interleave_moe_layer_step == 0 + is_moe_layer = ( + config.interleave_moe_layer_step > 0 + and (self.layer_idx + 1) % config.interleave_moe_layer_step == 0 + ) if is_moe_layer: self.feed_forward = Llama4MoE( - config=config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.feed_forward", ) else: @@ -301,10 
+343,10 @@ class Llama4DecoderLayer(nn.Module): bias=False, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -317,30 +359,26 @@ class Llama4DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @support_torch_compile class Llama4Model(LlamaModel): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, + ): self.num_experts = vllm_config.model_config.hf_config.num_local_experts - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) def load_moe_expert_weights( self, @@ -360,7 +398,7 @@ class Llama4Model(LlamaModel): params_dict: The dictionary of module parameters. loaded_params: The set of already loaded parameters. expert_params_mapping: The mapping of expert parameters. Must be - generated by FusedMoE.make_expert_params_mapping(). + generated by SharedFusedMoE.make_expert_params_mapping(). fused: Whether the expert weights are fused into a single weight tensor or are separate weight tensors for each expert. When fused is True, loaded_weight should have shape of: @@ -391,9 +429,7 @@ class Llama4Model(LlamaModel): # Iterate over all the expert parameters and load the weights if we find # a match in weight name. - for (param_name, weight_name, expert_id, - shard_id) in expert_params_mapping: - + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: # Get a view of the loaded_weight to avoid modifying the original # one across iterations. new_loaded_weight = loaded_weight @@ -402,7 +438,7 @@ class Llama4Model(LlamaModel): # the expert index from the expected weight name. if fused: # The string between e_str and proj_str is the expert index. - e_str, _, proj_str, _ = weight_name.split('.') + e_str, _, proj_str, _ = weight_name.split(".") weight_name = f"{e_str}.{proj_str}" param_name = f"{param_name}weight" @@ -419,8 +455,9 @@ class Llama4Model(LlamaModel): continue # Skip if the current weight is for the bias. 
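
A little further up in this hunk, is_moe_layer decides which decoder layers get a Llama4MoE feed-forward via interleave_moe_layer_step. As a minimal sketch with invented values (plain Python, not a real Llama4TextConfig), the selection rule reduces to:

def moe_layer_indices(num_layers: int, interleave_moe_layer_step: int) -> list[int]:
    """0-based indices of decoder layers that use the MoE feed-forward."""
    if interleave_moe_layer_step <= 0:
        return []
    return [
        layer_idx
        for layer_idx in range(num_layers)
        if (layer_idx + 1) % interleave_moe_layer_step == 0
    ]

# Toy 8-layer model with a step of 2: every second layer is MoE.
print(moe_layer_indices(8, 2))  # [1, 3, 5, 7]
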
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[full_param_name] @@ -439,13 +476,14 @@ class Llama4Model(LlamaModel): # starting expert index for the current EP rank and extract the # corresponding expert weights. layer_idx = extract_layer_index(name) - expert_map = self.layers[ - layer_idx].feed_forward.experts.expert_map + expert_map = self.layers[layer_idx].feed_forward.experts.expert_map if expert_map is not None: - local_expert_indices = (expert_map != -1) \ - .nonzero() \ - .flatten() \ - .to(new_loaded_weight.device) + local_expert_indices = ( + (expert_map != -1) + .nonzero() + .flatten() + .to(new_loaded_weight.device) + ) new_loaded_weight = new_loaded_weight[local_expert_indices] expert_id = local_expert_indices[0].item() else: @@ -454,19 +492,20 @@ class Llama4Model(LlamaModel): # Load the weight into the module parameter with corresponding # shard id and expert id. - weight_loader(param, - new_loaded_weight, - full_param_name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + new_loaded_weight, + full_param_name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(full_param_name) expert_param_loaded = True return expert_param_loaded - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Name mapping from the parameter name to the shard name and # corresponding shard id. stacked_params_mapping = [ @@ -482,18 +521,20 @@ class Llama4Model(LlamaModel): fused_experts_params = False # Expert parameter mapping for the case where the expert weights are # not fused into a single weight tensor. - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.num_experts) + num_experts=self.num_experts, + ) # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. - expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( + expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="gate_up_proj", - num_experts=1) + num_experts=1, + ) # All the module parameters. params_dict = dict(self.named_parameters()) # The module parameters that have been loaded. @@ -501,7 +542,6 @@ class Llama4Model(LlamaModel): # Iterate over all the weights and load them into module parameters. for name, loaded_weight in weights: - # If the name contains "experts.gate_up_proj" or "experts.down_proj" # without the expert indices, it means the expert weights are fused # into a single weight tensor across all experts. @@ -512,13 +552,14 @@ class Llama4Model(LlamaModel): # If kv cache quantization scales exist and the weight name # corresponds to one of the kv cache quantization scales, load # them. 
- if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -535,8 +576,9 @@ class Llama4Model(LlamaModel): # For ModelOpt checkpoints, we need to rename the self_attn # weight/weight_scale names except for kv cache scales. - if not (name.endswith( - (".k_scale", ".v_scale")) and "self_attn" in name): + if not ( + name.endswith((".k_scale", ".v_scale")) and "self_attn" in name + ): name = name.replace(weight_name, param_name) # Skip if the current weight corresponds to a parameter that @@ -555,8 +597,7 @@ class Llama4Model(LlamaModel): # Load the weight into the module parameter with corresponding # shard id and exit the for loop and the else block. param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) @@ -570,12 +611,14 @@ class Llama4Model(LlamaModel): else: # First, try to load MoE weights using load_moe_expert_weights. # If successful, move on to next loaded weight. - if self.load_moe_expert_weights(name, - loaded_weight, - params_dict, - loaded_params, - expert_params_mapping, - fused=fused_experts_params): + if self.load_moe_expert_weights( + name, + loaded_weight, + params_dict, + loaded_params, + expert_params_mapping, + fused=fused_experts_params, + ): continue # Skip if the current weight corresponds to a parameter that @@ -587,37 +630,40 @@ class Llama4Model(LlamaModel): # per-expert patterns, i.e. one weight scale tensor for all # experts. scale_names = [ - "w13_input_scale", "w13_weight_scale", "w2_input_scale", - "w2_weight_scale" + "w13_input_scale", + "w13_weight_scale", + "w2_input_scale", + "w2_weight_scale", ] - if ("experts." in name and any(scale_name in name - for scale_name in scale_names)): - + if "experts." in name and any( + scale_name in name for scale_name in scale_names + ): param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) # If weight loader supports special moe loading, use it to # avoid expensive runtime reflection - if getattr(weight_loader, 'supports_moe_loading', False): + if getattr(weight_loader, "supports_moe_loading", False): # Map the weight name to the corresponding shard id. shard_id = "w2" if "w2_" in name else "w1" # Transpose if weight scales are FP8 block scales with # three dimensions: # [num_experts, hidden_in, hidden_out]. - if name.endswith("weight_scale") \ - and loaded_weight.dtype == torch.float8_e4m3fn \ - and loaded_weight.ndim == 3: + if ( + name.endswith("weight_scale") + and loaded_weight.dtype == torch.float8_e4m3fn + and loaded_weight.ndim == 3 + ): loaded_weight = loaded_weight.transpose(-1, -2) # Load the weight into the module parameter with # corresponding shard id and expert id. 
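
The branch above detects MoE-aware weight loaders through a plain supports_moe_loading attribute rather than reflecting over call signatures at runtime. A standalone illustration of that pattern (both loader functions here are hypothetical, not vLLM's):

import torch

def plain_loader(param: torch.nn.Parameter, w: torch.Tensor) -> None:
    param.data.copy_(w)

def moe_loader(param, w, name, shard_id=None, expert_id=None) -> None:
    # A real MoE-aware loader would route `w` into the right expert/shard
    # slice; here we only copy and report where it would go.
    param.data.copy_(w)
    print(f"{name}: shard_id={shard_id}, expert_id={expert_id}")

# Advertise the richer calling convention with an attribute, which is what
# the getattr(..., "supports_moe_loading", False) check keys on.
moe_loader.supports_moe_loading = True

for loader in (plain_loader, moe_loader):
    param = torch.nn.Parameter(torch.zeros(2))
    if getattr(loader, "supports_moe_loading", False):
        loader(param, torch.ones(2), "experts.w13_weight_scale",
               shard_id="w1", expert_id=0)
    else:
        loader(param, torch.ones(2))
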
- weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=0) + weight_loader( + param, loaded_weight, name, shard_id=shard_id, expert_id=0 + ) else: # Regular weight loader (handles both @@ -629,8 +675,7 @@ class Llama4Model(LlamaModel): # Handle normal (non-stacked, non-MoE) weights. param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -639,7 +684,6 @@ class Llama4Model(LlamaModel): class Llama4ForCausalLM(LlamaForCausalLM): - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -650,30 +694,29 @@ class Llama4ForCausalLM(LlamaForCausalLM): gen_config = vllm_config.model_config.try_get_generation_config() gen_config.update(vllm_config.model_config.override_generation_config) # enable temperature tuning by default when max_model_len > 32K - default_attn_temperature_tuning = \ - vllm_config.model_config.max_model_len > 32768 - vllm_config.model_config.hf_config.attn_temperature_tuning \ - = gen_config.get( - "attn_temperature_tuning", default_attn_temperature_tuning) + default_attn_temperature_tuning = vllm_config.model_config.max_model_len > 32768 + vllm_config.model_config.hf_config.attn_temperature_tuning = gen_config.get( + "attn_temperature_tuning", default_attn_temperature_tuning + ) - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=Llama4DecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer + ) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): - return Llama4Model(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, + ): + return Llama4Model( + vllm_config=vllm_config, prefix=prefix, layer_type=layer_type + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) weights = [ self.permute_qk_weight_for_rotary(name, loaded_weight) @@ -686,10 +729,8 @@ class Llama4ForCausalLM(LlamaForCausalLM): name: str, loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - # Helper function to permute the weight's channels def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool): - # Calculate the expected shape of the weight. # Do not rely on w's shape, as it may be in another layout. attn_in = self.config.head_dim * n_heads @@ -702,28 +743,39 @@ class Llama4ForCausalLM(LlamaForCausalLM): # If the weight is a weight scale, we need to divide attn_out by # block size, which is currently 16. 
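
permute_qk_weight_for_rotary above rewires each attention head's output rows with a view/transpose/reshape. Running the same recipe on a tiny labelled weight makes the reordering visible; the sizes are invented and the toy tensor is not a real checkpoint:

import torch

n_heads, head_dim, hidden = 2, 4, 8          # toy sizes, not Llama 4's
attn_in = n_heads * head_dim

# Row i of the toy weight holds the constant value i so rows can be tracked.
w = (
    torch.arange(attn_in, dtype=torch.float32)
    .unsqueeze(1)
    .expand(attn_in, hidden)
    .clone()
)

permuted = (
    w.view(n_heads, attn_in // n_heads // 2, 2, hidden)
    .transpose(1, 2)
    .reshape(attn_in, hidden)
)

# Within each head, row order goes 0,1,2,3 -> 0,2,1,3: pair-interleaved
# rotary channels are regrouped into an even-half / odd-half layout.
print(permuted[:, 0].tolist())  # [0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0]
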
- elif w.dtype == torch.float8_e4m3fn and is_weight_scale \ - and w.shape[1] * 16 == attn_out: + elif ( + w.dtype == torch.float8_e4m3fn + and is_weight_scale + and w.shape[1] * 16 == attn_out + ): attn_out = attn_out // 16 - return w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, 2).reshape(attn_in, attn_out) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) modules = name.split(".") # Permute Q/K weights and weight block scales for rotary embedding is_weight = modules[-1] == "weight" - is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and - loaded_weight.dtype == torch.float8_e4m3fn) + is_nvfp4_weight_scale = ( + modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn + ) if is_weight or is_nvfp4_weight_scale: - if ("wk" in modules or "k_proj" in modules): - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, - is_nvfp4_weight_scale) - elif ("wq" in modules or "q_proj" in modules): - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads, - is_nvfp4_weight_scale) + if "wk" in modules or "k_proj" in modules: + loaded_weight = permute( + loaded_weight, + self.config.num_key_value_heads, + is_nvfp4_weight_scale, + ) + elif "wq" in modules or "q_proj" in modules: + loaded_weight = permute( + loaded_weight, + self.config.num_attention_heads, + is_nvfp4_weight_scale, + ) return name, loaded_weight diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index ece490ff2f2a8..039022ef4527f 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -28,25 +28,21 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.torchao import TorchAOConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, - Llama4ForCausalLM) +from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM from vllm.model_executor.models.utils import extract_layer_index -from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, maybe_prefix logger = init_logger(__name__) @support_torch_compile class LlamaModel(nn.Module): - def __init__( self, *, @@ -56,8 +52,7 @@ class LlamaModel(nn.Module): quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.validate_and_update_config(start_layer_id, quant_config) self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -66,23 +61,22 @@ class LlamaModel(nn.Module): prefix=maybe_prefix(prefix, "embed_tokens"), ) - 
self.layers = nn.ModuleList([ - Llama4DecoderLayer( - self.config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) - ]) - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.layers = nn.ModuleList( + [ + Llama4DecoderLayer( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) + self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -94,8 +88,7 @@ class LlamaModel(nn.Module): ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) - hidden_states = self.fc( - torch.cat((inputs_embeds, hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -106,8 +99,7 @@ class LlamaModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -130,69 +122,71 @@ class LlamaModel(nn.Module): break else: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) for name in params_dict: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue assert name in loaded_params, f"{name} is not loaded!" 
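
The EAGLE draft model above fuses the draft token embedding with the target model's hidden state by concatenation followed by a bias-free projection back to hidden_size (self.fc). A toy-sized sketch of that step (dimensions are arbitrary):

import torch
import torch.nn as nn

hidden_size, num_tokens = 16, 4
fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)

inputs_embeds = torch.randn(num_tokens, hidden_size)   # draft token embeddings
target_hidden = torch.randn(num_tokens, hidden_size)   # hidden states from the target model

fused = fc(torch.cat((inputs_embeds, target_hidden), dim=-1))
print(fused.shape)  # torch.Size([4, 16])
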
return loaded_params def validate_and_update_config( - self, - start_layer_id: int, - quant_config: Optional[QuantizationConfig] = None) -> None: + self, start_layer_id: int, quant_config: Optional[QuantizationConfig] = None + ) -> None: # yoco and moe is not supported by draft model yet assert self.config.yoco_global_kv_layer is None assert self.config.yoco_local_kv_layer is None assert len(self.config.moe_layers) == 0 # draft model layer index is increased by start_layer_id, # so we need to pad relevant configs accordingly - self.config.no_rope_layers = [ - 0 - ] * start_layer_id + self.config.no_rope_layers + self.config.no_rope_layers = [0] * start_layer_id + self.config.no_rope_layers # currently only TorchAO quantization is supported if isinstance(quant_config, TorchAOConfig): def pad_layer_name(layer: str) -> str: layer_index = extract_layer_index(layer) - return layer.replace(str(layer_index), - str(layer_index + start_layer_id)) + return layer.replace( + str(layer_index), str(layer_index + start_layer_id) + ) - quant_config.torchao_config.module_fqn_to_config = { + torchao_config = quant_config.torchao_config + torchao_config.module_fqn_to_config = { pad_layer_name(layer): quantization - for layer, quantization in - quant_config.torchao_config.module_fqn_to_config.items() + for layer, quantization in torchao_config.module_fqn_to_config.items() } class EagleLlama4ForCausalLM(Llama4ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) + self.config = vllm_config.speculative_config.draft_model_config.hf_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) # draft model quantization config may differ from target model quant_config = VllmConfig.get_quantization_config( - vllm_config.speculative_config.draft_model_config, - vllm_config.load_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num, - quant_config=quant_config) + vllm_config.speculative_config.draft_model_config, vllm_config.load_config + ) + self.model = LlamaModel( + vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num, + quant_config=quant_config, + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) + + def get_language_model(self) -> torch.nn.Module: + return self.model + + get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore def forward( self, @@ -203,39 +197,17 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states, inputs_embeds) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> None: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: + def transform(inputs): + name, loaded_weight = inputs + name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight) + if "lm_head" not in name: + name = "model." 
+ name + return name, weight + loader = AutoWeightsLoader( self, # lm_head is tied with target model (Llama4ForCausalLM) skip_prefixes=(["lm_head."]), ) - - model_weights = {} - weights = [ - self.permute_qk_weight_for_rotary(name, loaded_weight) - for name, loaded_weight in weights - ] - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." + name - model_weights[name] = loaded_weight - - loader.load_weights(model_weights.items()) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds + loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index a4933b77e3a53..5df158818c9fb 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -13,11 +13,9 @@ from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM) +from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM from .utils import AutoWeightsLoader, maybe_prefix @@ -25,14 +23,14 @@ logger = init_logger(__name__) class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( self, - config: LlamaConfig, + vllm_config: VllmConfig, disable_input_layernorm: bool, prefix: str = "", + config: Optional[LlamaConfig] = None, ) -> None: - super().__init__(config, prefix=prefix) + super().__init__(vllm_config, prefix=prefix, config=config) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 @@ -43,7 +41,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer): @support_torch_compile class LlamaModel(nn.Module): - def __init__( self, *, @@ -52,8 +49,7 @@ class LlamaModel(nn.Module): start_layer_id: int = 0, ) -> None: super().__init__() - self.config = vllm_config. 
\ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -62,16 +58,23 @@ class LlamaModel(nn.Module): prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer( - self.config, - i == 0, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) - ]) - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + vllm_config, + i == 0, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -80,8 +83,7 @@ class LlamaModel(nn.Module): hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) - hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -92,8 +94,7 @@ class LlamaModel(nn.Module): hidden_states = hidden_states + residual return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -114,35 +115,40 @@ class LlamaModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class EagleLlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. 
\ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config + # Ensure draft_vocab_size is set + # default to the base vocab size when absent + if getattr(self.config, "draft_vocab_size", None) is None: + base_vocab_size = getattr(self.config, "vocab_size", None) + self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + vllm_config.parallel_config + ) + self.model = LlamaModel( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -158,14 +164,14 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def transform(inputs): + name, loaded_weight = inputs + if "lm_head" not in name: + name = "model." + name + return name, loaded_weight + loader = AutoWeightsLoader( self, skip_prefixes=None, ) - - model_weights = {} - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." + name - model_weights[name] = loaded_weight - loader.load_weights(model_weights.items()) + loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 572930c39a846..67d4669899193 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -9,19 +9,21 @@ import torch.nn as nn from transformers import LlamaConfig from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM) -from vllm.v1.sample.metadata import SamplingMetadata +from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from .utils import AutoWeightsLoader, maybe_prefix @@ -29,18 +31,25 @@ logger = init_logger(__name__) class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( self, - config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", + config: Optional[LlamaConfig] = None, + 
layer_idx: int = 0, ) -> None: - super().__init__(config, quant_config=quant_config, prefix=prefix) + super().__init__(vllm_config, prefix=prefix, config=config) + + config = config or vllm_config.model_config.hf_config + quant_config = self.get_quant_config(vllm_config) + + # First layer uses 2*hidden_size (embeds + hidden_states concatenated) + # Subsequent layers use hidden_size (only hidden_states, no embeds) + qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size # override qkv self.self_attn.qkv_proj = QKVParallelLinear( - 2 * self.hidden_size, + qkv_input_size, self.self_attn.head_dim, self.self_attn.total_num_heads, self.self_attn.total_num_kv_heads, @@ -50,22 +59,34 @@ class LlamaDecoderLayer(LlamaDecoderLayer): ) self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx if getattr(config, "norm_before_residual", False): self._residual_norm = self._norm_before_residual else: self._residual_norm = self._norm_after_residual + def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + """Use drafter's quantization config instead of verifier's.""" + draft_model_config = vllm_config.speculative_config.draft_model_config + draft_load_config = vllm_config.load_config + + return ( + VllmConfig.get_quantization_config(draft_model_config, draft_load_config) + if draft_model_config + else None + ) + def _norm_before_residual( - self, - hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = self.hidden_norm(hidden_states) residual = hidden_states return hidden_states, residual def _norm_after_residual( - self, - hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.hidden_norm(hidden_states) return hidden_states, residual @@ -77,21 +98,22 @@ class LlamaDecoderLayer(LlamaDecoderLayer): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: + if self.layer_idx == 0: + # First layer: concatenate embeds with hidden_states + embeds = self.input_layernorm(embeds) + hidden_states, residual = self._residual_norm(hidden_states=hidden_states) + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + else: + # Subsequent layers: process hidden_states and residuals only + hidden_states, residual = self.input_layernorm(hidden_states, residual) - embeds = self.input_layernorm(embeds) - - hidden_states, residual = self._residual_norm( - hidden_states=hidden_states) - - hidden_states = torch.cat([embeds, hidden_states], dim=-1) # Self Attention hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) # Fully Connected hidden_states = self.mlp(hidden_states) @@ -99,9 +121,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer): return hidden_states, residual -@support_torch_compile +@support_torch_compile( + # torch.compile is disabled for multimodal EAGLE3 models due to constraint + # violations with dynamic shapes during tensor concatenation operations. + # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132 + # Non-multimodal EAGLE3 models can still use torch.compile safely. 
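
In the EAGLE3 decoder layer above, only layer 0 consumes the concatenated [embeds ; hidden_states] input, so its QKV projection is built with twice the input width while later draft layers behave like ordinary decoder layers. A plain-PyTorch sketch of that shape bookkeeping (toy sizes, with nn.Linear standing in for QKVParallelLinear):

import torch
import torch.nn as nn

hidden_size, num_tokens, num_layers = 8, 3, 2

qkv_like = nn.ModuleList(
    nn.Linear(2 * hidden_size if layer_idx == 0 else hidden_size, hidden_size)
    for layer_idx in range(num_layers)
)

embeds = torch.randn(num_tokens, hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)

for layer_idx, proj in enumerate(qkv_like):
    if layer_idx == 0:
        x = torch.cat([embeds, hidden_states], dim=-1)  # width 2 * hidden_size
    else:
        x = hidden_states                               # width hidden_size
    hidden_states = proj(x)

print(hidden_states.shape)  # torch.Size([3, 8])
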
+ enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs( + vllm_config.model_config + ), +) class LlamaModel(nn.Module): - def __init__( self, *, @@ -110,57 +139,67 @@ class LlamaModel(nn.Module): prefix: str = "", ) -> None: super().__init__() - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size + current_vllm_config = get_current_vllm_config() + self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer( - config=self.config, - prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), - ) - ]) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + current_vllm_config, + prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), + config=self.config, + layer_idx=layer_idx, + ) + for layer_idx in range(self.config.num_hidden_layers) + ] + ) if hasattr(self.config, "target_hidden_size"): - self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.target_hidden_size * 3, self.config.hidden_size, bias=False + ) else: - self.fc = torch.nn.Linear(self.config.hidden_size * 3, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.hidden_size * 3, self.config.hidden_size, bias=False + ) self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = self.embed_tokens(input_ids) + if input_embeds is None: + input_embeds = self.get_input_embeddings(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None - hidden_states, residual = self.layers[0]( - positions, - input_embeds, - hidden_states, - residual, - ) - + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + embeds=input_embeds, + hidden_states=hidden_states, + residual=residual, + ) hidden_states, hidden_prenorm = self.norm(hidden_states, residual) return hidden_states, hidden_prenorm - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -172,8 +211,8 @@ class LlamaModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if 'midlayer.' in name: - name = name.replace('midlayer.', 'layers.0.') + if "midlayer." 
in name: + name = name.replace("midlayer.", "layers.0.") for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -184,24 +223,31 @@ class LlamaModel(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Eagle3LlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config + # Ensure draft_vocab_size is set + # default to the base vocab size when absent + if getattr(self.config, "draft_vocab_size", None) is None: + base_vocab_size = getattr(self.config, "vocab_size", None) + self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + vllm_config.parallel_config + ) + + # Store target layer count in draft config for + # proper layer_types indexing in draft models + self.config.target_layer_count = target_layer_num + self.model = LlamaModel( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) logit_scale = getattr(self.config, "logit_scale", 1.0) self.lm_head = ParallelLMHead( @@ -209,14 +255,24 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): self.config.hidden_size, org_num_embeddings=self.config.draft_vocab_size, padding_size=(DEFAULT_VOCAB_PADDING_SIZE), - prefix="") - self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, - scale=logit_scale) + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.config.draft_vocab_size, scale=logit_scale + ) self.draft_id_to_target_id = nn.Parameter( torch.zeros(self.config.draft_vocab_size, dtype=torch.long), requires_grad=False, ) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + is_multimodal: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -224,31 +280,29 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is not None: - raise NotImplementedError( - f"{type(self).__name__} does not support multimodal inputs yet." 
- ) - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) if self.draft_id_to_target_id is None: - assert logits.shape[1] == self.config.vocab_size, \ - "Expected logits to have shape " \ + assert logits.shape[1] == self.config.vocab_size, ( + "Expected logits to have shape " f"(*, {self.config.vocab_size}), but got {logits.shape}" + ) return logits base = torch.arange(self.config.draft_vocab_size, device=logits.device) targets = base + self.draft_id_to_target_id - logits_new = logits.new_full(( - logits.shape[0], - self.config.vocab_size, - ), float('-inf')) + logits_new = logits.new_full( + ( + logits.shape[0], + self.config.vocab_size, + ), + float("-inf"), + ) logits_new[:, targets] = logits return logits_new diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index cd41d4fb43885..3d46e22a0d217 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -3,46 +3,64 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union, cast) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, - PixtralVisionConfig, PretrainedConfig, - SiglipVisionConfig) +from transformers import ( + BatchFeature, + CLIPVisionConfig, + LlavaConfig, + PixtralVisionConfig, + PretrainedConfig, + SiglipVisionConfig, +) from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import 
IntermediateTensors -from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) -from .vision import get_vision_encoder_info +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_num_selected_vision_tokens, get_vision_encoder_info class LlavaImagePixelInputs(TensorSchema): @@ -52,10 +70,11 @@ class LlavaImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height - w: Width - + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -67,14 +86,16 @@ class PixtralHFImagePixelInputs(TensorSchema): - c: Number of channels - h: Height - w: Width - + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"})] + TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}), + ] class LlavaImageEmbeddingInputs(TensorSchema): @@ -84,36 +105,43 @@ class LlavaImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs, - LlavaImageEmbeddingInputs] +LlavaImageInputs = Union[ + LlavaImagePixelInputs, PixtralHFImagePixelInputs, LlavaImageEmbeddingInputs +] class LlavaMultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -134,7 +162,6 @@ 
class LlavaLikeProcessor(Protocol): class BaseLlavaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(LlavaConfig) @@ -148,19 +175,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def _apply_feature_select_strategy( - self, - strategy: str, - encoder_num_image_tokens: int, - ) -> int: - if strategy == "default": - return encoder_num_image_tokens - 1 - if strategy == "full": - return encoder_num_image_tokens - - msg = f"Unexpected feature select strategy: {strategy!r}" - raise NotImplementedError(msg) - def get_num_image_tokens( self, *, @@ -170,12 +184,12 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): hf_config = self.get_hf_config() vision_encoder_info = self.get_vision_encoder_info() - return self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, + return get_num_selected_vision_tokens( vision_encoder_info.get_num_image_tokens( image_width=image_width, image_height=image_height, ), + hf_config.vision_feature_select_strategy, ) def get_image_size_with_most_features(self) -> ImageSize: @@ -196,7 +210,6 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -209,22 +222,25 @@ class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class LlavaProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs) # In case patch_size is omitted from `processor_config.json` @@ -236,7 +252,6 @@ class LlavaProcessingInfo(BaseLlavaProcessingInfo): class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): - # Copied from BaseMultiModalProcessor @abstractmethod def _get_mm_fields_config( @@ -257,7 +272,8 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -279,9 +295,7 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): ] -class LlavaMultiModalProcessor( - BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): - +class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -294,14 +308,11 @@ class LlavaMultiModalProcessor( class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) -class PixtralHFMultiModalProcessor( - 
BaseMultiModalProcessor[PixtralHFProcessingInfo]): - +class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -381,7 +392,8 @@ class PixtralHFMultiModalProcessor( def _build_llava_or_pixtral_hf_info( - ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + ctx: InputProcessingContext, +) -> BaseLlavaProcessingInfo: hf_config = ctx.get_hf_config(LlavaConfig) if isinstance(hf_config.vision_config, PixtralVisionConfig): @@ -394,7 +406,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( @@ -416,7 +428,7 @@ def _build_llava_or_pixtral_hf_processor( def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: """Determine the number of hidden layers to initialize up to in the visual encoder. - + Args: hf_config: Model config with vision feature layer(s). """ @@ -427,10 +439,10 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest one elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -488,14 +500,17 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, - info=_build_llava_or_pixtral_hf_info, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + _build_llava_or_pixtral_hf_processor, + info=_build_llava_or_pixtral_hf_info, + dummy_inputs=LlavaDummyInputsBuilder, +) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( @@ -505,7 +520,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -526,11 +542,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # NOTE: These are special cases for Pixtral-12B in the HF-format # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa - if (config.text_config.architectures is None - and config.text_config.model_type == "mistral"): + if ( + config.text_config.architectures is None + and config.text_config.model_type == "mistral" + ): config.text_config.architectures = ["MistralForCausalLM"] - if (config.projector_hidden_act is None - and config.vision_config.hidden_act == "gelu"): + if ( + config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu" + ): 
config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. @@ -539,14 +559,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=config.multimodal_projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) else: self.vision_tower = None self.multi_modal_projector = None @@ -558,10 +580,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -569,70 +593,40 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - if self.config.vision_config.model_type == "pixtral": return PixtralHFImagePixelInputs( type="pixel_values_pixtral", - pixel_values=flatten_bn(pixel_values), + pixel_values=pixel_values, ) expected_h = expected_w = self.config.vision_config.image_size return LlavaImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, + pixel_values=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - if self.config.vision_config.model_type == "pixtral": raise ValueError("Pixtral-HF does not support image_embeds.") return LlavaImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, - PixtralHFVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) - - def select_features(leaf: torch.Tensor): - return self._select_image_features( - leaf, - strategy=self.config.vision_feature_select_strategy, - ) - - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), + return vision_tower( + pixel_values, + feature_select_strategy=self.config.vision_feature_select_strategy, ) def _process_image_pixels( @@ -658,9 +652,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): if isinstance(image_features, torch.Tensor): return self.multi_modal_projector(image_features) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) @@ -669,30 +661,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -731,39 +706,29 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. 
Info: - [LlavaImageInputs][] + [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs] """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.vision_tower is None and self.multi_modal_projector is None: skip_prefixes.extend(["vision_tower.", "multi_modal_projector."]) @@ -773,7 +738,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class MantisProcessingInfo(LlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): hf_config = self.get_hf_config() vision_info = self.get_vision_encoder_info() @@ -788,13 +752,13 @@ class MantisProcessingInfo(LlavaProcessingInfo): class MantisMultiModalProcessor(LlavaMultiModalProcessor): - def apply( self, prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -805,8 +769,13 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): image_height=-1, ) - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs) + result = super().apply( + prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() @@ -816,38 +785,36 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): # We reimplement the functionality of MLlavaProcessor from # https://github.com/TIGER-AI-Lab/Mantis.git def get_replacement_mantis(item_idx: int): - return "".join([ - f"(image {item_idx+1}: <Image>", # 7 tokens - "<image>" * num_image_tokens, - "</Image>)", # 3 tokens - ]) - - mantis_mm_repls = self._bind_and_group_updates([ - PromptReplacement( - modality="image", - target=[image_token_id] * num_image_tokens, - replacement=get_replacement_mantis, + return "".join( + [ + f"(image {item_idx + 1}: <Image>", # 7 tokens + "<image>" * num_image_tokens, + "</Image>)", # 3 tokens + ] ) - ]) - prompt_ids, prompt, _ = self._apply_prompt_updates( - result["prompt_token_ids"], - mantis_mm_repls, + mantis_mm_repls = self._bind_and_group_updates( + [ + PromptReplacement( + modality="image", + target=[image_token_id] * num_image_tokens, + replacement=get_replacement_mantis, + ) + ], mm_item_counts, ) - unbound_orig_repls = self._get_prompt_updates( + prompt_ids, _ = self._apply_prompt_updates( + 
result["prompt_token_ids"], + mantis_mm_repls, + ) + + orig_repls = self._get_mm_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_updates(unbound_orig_repls) - - mm_placeholders = self._find_mm_placeholders( - orig_repls, - prompt_ids, - mm_item_counts, - ) + mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { @@ -857,7 +824,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, @@ -867,8 +833,10 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` -@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, - info=MantisProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + MantisMultiModalProcessor, + info=MantisProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index a63c18493df5e..caedace7cab1e 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -3,17 +3,17 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.parse import ImageSize @@ -22,12 +22,22 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, - LlavaDummyInputsBuilder, LlavaLikeConfig, - LlavaMultiModalProjector, init_vision_tower_for_llava) +from .llava import ( + BaseLlavaMultiModalProcessor, + BaseLlavaProcessingInfo, + LlavaDummyInputsBuilder, + LlavaLikeConfig, + LlavaMultiModalProjector, + init_vision_tower_for_llava, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, - flatten_bn, init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_num_selected_vision_tokens class LlavaNextImagePixelInputs(TensorSchema): @@ -38,14 +48,16 @@ class LlavaNextImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height - w: Width - + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
""" + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})] + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), + ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] # This should be in `(height, width)` format. @@ -58,12 +70,12 @@ class LlavaNextImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, - LlavaNextImageEmbeddingInputs] +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageEmbeddingInputs] class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): @@ -71,7 +83,6 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_config(self) -> LlavaNextLikeConfig: return self.ctx.get_hf_config(LlavaNextConfig) @@ -96,12 +107,12 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): hf_config = self.get_hf_config() vision_encoder_info = self.get_vision_encoder_info() - base_feature_size = self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, + base_feature_size = get_num_selected_vision_tokens( vision_encoder_info.get_num_image_tokens( image_width=image_width, image_height=image_height, ), + hf_config.vision_feature_select_strategy, ) num_patch_height, num_patch_width = get_anyres_image_grid_shape( @@ -141,12 +152,14 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): if aspect_ratio > current_aspect_ratio: new_height = int( - round(original_height * (current_width / original_width), 7)) + round(original_height * (current_width / original_width), 7) + ) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: new_width = int( - round(original_width * (current_height / original_height), 7)) + round(original_width * (current_height / original_height), 7) + ) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -159,13 +172,13 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): hf_config = self.get_hf_config() largest_feature_size, largest_feature_pinpoint = 0, None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self.get_num_image_tokens(image_width=width, - image_height=height) + for height, width in hf_config.image_grid_pinpoints: + feat_size = self.get_num_image_tokens( + image_width=width, image_height=height + ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -177,7 +190,6 @@ _I = TypeVar("_I", bound=LlavaNextProcessingInfo) class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): - # Copied from BaseMultiModalProcessor @abstractmethod def _get_mm_fields_config( @@ -189,8 +201,8 @@ class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): class LlavaNextMultiModalProcessor( - BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): - + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -203,11 
+215,13 @@ class LlavaNextMultiModalProcessor( ) -@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, - info=LlavaNextProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) -class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=LlavaNextProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -217,7 +231,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -236,16 +251,18 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, # Determine the layer up to which we will initialize the vision tower if isinstance(vision_feature_layer, int): vision_hidden_size = config.vision_config.hidden_size - self.feature_sample_layers = None + self.select_layers = None # Used for multimodal granite models to control encoder outputs elif isinstance(vision_feature_layer, (list, tuple)): vision_hidden_size = config.vision_config.hidden_size * len( - vision_feature_layer) - self.feature_sample_layers = vision_feature_layer + vision_feature_layer + ) + self.select_layers = vision_feature_layer else: raise TypeError( f"vision_layer_feature type: {type(vision_feature_layer)}" - " is not supported") + " is not supported" + ) self.config = config self.multimodal_config = multimodal_config @@ -255,14 +272,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=vision_hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, - multimodal_projector_bias=config.multimodal_projector_bias) + multimodal_projector_bias=config.multimodal_projector_bias, + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -271,10 +289,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaNextImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -283,78 +303,56 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. 
" - f"Got type: {type(image_sizes)}") - expected_h = expected_w = self.config.vision_config.image_size return LlavaNextImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": expected_h, "w": expected_w, - }) + }, + ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeds. " - f"Got type: {type(image_embeds)}") - return LlavaNextImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features( self, vision_tower: Union[CLIPVisionModel, SiglipVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: - # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower( - pixel_values, feature_sample_layers=self.feature_sample_layers) - - return self._select_image_features( - image_features, - strategy=self.config.vision_feature_select_strategy, + return vision_tower( + pixel_values, + select_layers=self.select_layers, + feature_select_strategy=self.config.vision_feature_select_strategy, ) # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py - def _merge_image_patch_embeddings(self, image_size: torch.Tensor, - patch_embeddings: torch.Tensor, *, - strategy: str) -> torch.Tensor: + def _merge_image_patch_embeddings( + self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str + ) -> torch.Tensor: if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): - height = width = self.config.vision_config.image_size \ + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) base_patch_embeds = patch_embeddings[0] if height * width != base_patch_embeds.shape[0]: raise ValueError( - "The number of patches is not consistent with the " - "image size.") + "The number of patches is not consistent with the image size." 
+ ) if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] @@ -371,37 +369,51 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, num_patches = num_patch_height * num_patch_width # Image patches might be padded for batch processing - other_patch_embeds = other_patch_embeds[:num_patches] \ - .view(num_patch_height, num_patch_width, height, width, -1) + other_patch_embeds = other_patch_embeds[:num_patches].view( + num_patch_height, num_patch_width, height, width, -1 + ) if "unpad" in strategy: - other_patch_embeds = other_patch_embeds \ - .permute(4, 0, 2, 1, 3).contiguous() \ - .flatten(1, 2).flatten(2, 3) - other_patch_embeds = unpad_image(other_patch_embeds, - (orig_height, orig_width)) - other_patch_embeds = torch.cat(( - other_patch_embeds, - self.image_newline[:, None, None] \ - .expand(*other_patch_embeds.shape[:-1], 1) \ + other_patch_embeds = ( + other_patch_embeds.permute(4, 0, 2, 1, 3) + .contiguous() + .flatten(1, 2) + .flatten(2, 3) + ) + other_patch_embeds = unpad_image( + other_patch_embeds, (orig_height, orig_width) + ) + other_patch_embeds = torch.cat( + ( + other_patch_embeds, + self.image_newline[:, None, None] + .expand(*other_patch_embeds.shape[:-1], 1) .to(other_patch_embeds.device), - ), dim=-1) - other_patch_embeds = other_patch_embeds \ - .flatten(1, 2).transpose(0, 1) + ), + dim=-1, + ) + other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose( + 0, 1 + ) else: - other_patch_embeds = other_patch_embeds \ - .permute(0, 2, 1, 3, 4).contiguous() \ + other_patch_embeds = ( + other_patch_embeds.permute(0, 2, 1, 3, 4) + .contiguous() .flatten(0, 3) + ) merged_patch_embeddings = torch.cat( - (base_patch_embeds, other_patch_embeds), dim=0) + (base_patch_embeds, other_patch_embeds), dim=0 + ) else: if "unpad" in strategy: merged_patch_embeddings = torch.cat( - (base_patch_embeds, - self.image_newline[None] \ - .to(base_patch_embeds.device) - ), dim=0) + ( + base_patch_embeds, + self.image_newline[None].to(base_patch_embeds.device), + ), + dim=0, + ) else: merged_patch_embeddings = base_patch_embeds @@ -421,20 +433,25 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, b, num_patches, c, h, w = pixel_values.shape stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) stacked_patch_embeddings = self.multi_modal_projector( - stacked_image_features) + stacked_image_features + ) return stacked_patch_embeddings.view( - b, num_patches, *stacked_patch_embeddings.shape[1:]) + b, num_patches, *stacked_patch_embeddings.shape[1:] + ) num_patches_per_batch = [v.shape[0] for v in pixel_values] stacked_pixel_values = torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) - return torch.split(self.multi_modal_projector(stacked_image_features), - num_patches_per_batch) + return torch.split( + self.multi_modal_projector(stacked_image_features), num_patches_per_batch + ) def _process_image_input( self, @@ -450,21 +467,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, batch_size = len(image_input["data"]) vision_config = self.config.vision_config default_height = default_width = vision_config.image_size - image_sizes = torch.as_tensor([[default_height, default_width] - for _ in range(batch_size)]) + image_sizes = 
torch.as_tensor( + [[default_height, default_width] for _ in range(batch_size)] + ) return [ - self._merge_image_patch_embeddings(image_sizes[i], - patch_features_batch, - strategy="spatial_unpad") + self._merge_image_patch_embeddings( + image_sizes[i], patch_features_batch, strategy="spatial_unpad" + ) for i, patch_features_batch in enumerate(patch_embeddings) ] def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -475,19 +492,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - if multimodal_embeddings is None \ - or len(multimodal_embeddings) == 0: - return self.language_model.get_input_embeddings(input_ids) - - inputs_embeds = embed_multimodal( + return super().get_input_embeddings( input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, ) - return inputs_embeds def forward( self, @@ -527,7 +546,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, Unlike in LLaVA-1.5, the number of image tokens inputted to the language model depends on the original size of the input image. Including the original image token in the input, the required number of image tokens - is given by [get_llava_next_image_feature_size][]. + is given by [`LlavaNextProcessingInfo.get_num_image_tokens`][vllm.\ +model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens]. This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. @@ -535,38 +555,27 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each grid patch for each input image. - image_sizes: The original `(height, width)` for each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: - [LlavaNextImageInputs][] + [`LlavaNextImageInputs`][vllm.model_executor.models.llava_next.LlavaNextImageInputs] """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
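# Illustrative sketch (toy shapes, names assumed): the `is_multimodal` mask accepted
# by `get_input_embeddings` above marks which token positions are image placeholders.
# Text embeddings are looked up for every position, and the flagged rows are then
# overwritten with the projected vision features, roughly:
import torch

hidden_size = 8
text_embeds = torch.zeros(6, hidden_size)                      # one row per input token
is_multimodal = torch.tensor([False, True, True, False, True, False])
vision_embeds = torch.ones(int(is_multimodal.sum().item()), hidden_size)

inputs_embeds = text_embeds.clone()
inputs_embeds[is_multimodal] = vision_embeds                   # boolean-mask scatter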
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index cf9852de633f3..074acc7943a43 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -7,21 +7,30 @@ from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn -from transformers import (BatchFeature, LlavaNextVideoConfig, - LlavaNextVideoProcessor) +from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageSize, + MultiModalDataItems, + VideoEmbeddingItems, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -30,35 +39,39 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info class LlavaNextVideoPixelInputs(TensorSchema): - """ + """ Dimensions: - - bs: Batch size - - nv: Number of videos - - nf: Number of frames - - nc: Number of channels (3) + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) - h: Height of each frame - w: Width of each frame - Note that 
`num_frames` may be different for each batch, in which case + Note that `f` may be different for each batch, in which case the data is passed as a list instead of a batched tensor. Note that it only supports one video input for one batch. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bs", "nv", "nf", 3, "h", "w")] + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] class LlavaNextVideoProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(LlavaNextVideoConfig) @@ -138,8 +151,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): class LlavaNextVideoDummyInputsBuilder( - BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): - + BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_videos = mm_counts.get("video", 0) @@ -152,28 +165,31 @@ class LlavaNextVideoDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + video_overrides = mm_options.get("video") if mm_options else None return { - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, + overrides=video_overrides, ) } class LlavaNextVideoMultiModalProcessor( - BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): - + BaseMultiModalProcessor[LlavaNextVideoProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -192,7 +208,8 @@ class LlavaNextVideoMultiModalProcessor( def get_replacement(item_idx: int): videos = mm_items.get_items( - "video", (VideoEmbeddingItems, VideoProcessorItems)) + "video", (VideoEmbeddingItems, VideoProcessorItems) + ) if isinstance(videos, VideoEmbeddingItems): num_video_tokens = videos.get_feature_size(item_idx) @@ -217,7 +234,6 @@ class LlavaNextVideoMultiModalProcessor( # adopted from transformers modeling_llava_next_video.py class LlavaNextVideoPooler(nn.Module): - def __init__(self, config: LlavaNextVideoConfig): super().__init__() @@ -234,36 +250,41 @@ class LlavaNextVideoPooler(nn.Module): else: # TODO: Support Conv2d pooling layer, need to load weights raise ValueError( - f"Unknown pooling mode: {mode}. Expected [`average`, `max`]") + f"Unknown pooling mode: {mode}. 
Expected [`average`, `max`]" + ) def forward(self, image_features: torch.Tensor): ori_width = int( - math.sqrt(image_features.shape[1] * self.image_size // - self.image_size)) + math.sqrt(image_features.shape[1] * self.image_size // self.image_size) + ) ori_height = int(ori_width * self.image_size // self.image_size) batch_size, _, dim = image_features.shape - image_features_spatial = image_features \ - .view(batch_size, ori_height, ori_height, dim) \ - .permute(0, 3, 1, 2) + image_features_spatial = image_features.view( + batch_size, ori_height, ori_height, dim + ).permute(0, 3, 1, 2) image_features_spatial = self.pool(image_features_spatial) return image_features_spatial.flatten(2).transpose(1, 2).contiguous() class LlavaNextMultiModalProjector(nn.Module): - - def __init__(self, vision_hidden_size: int, text_hidden_size: int, - projector_hidden_act: str, multimodal_projector_bias: bool): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + ): super().__init__() - self.linear_1 = nn.Linear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias) + self.linear_1 = nn.Linear( + vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = nn.Linear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias) + self.linear_2 = nn.Linear( + text_hidden_size, text_hidden_size, bias=multimodal_projector_bias + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) @@ -277,8 +298,8 @@ class LlavaNextMultiModalProjector(nn.Module): info=LlavaNextVideoProcessingInfo, dummy_inputs=LlavaNextVideoDummyInputsBuilder, ) -class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -288,7 +309,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -313,13 +335,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, - multimodal_projector_bias=config.multimodal_projector_bias) + multimodal_projector_bias=config.multimodal_projector_bias, + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, @@ -327,14 +351,16 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.make_empty_intermediate_tensors = ( - self.language_model.model.make_empty_intermediate_tensors) + self.language_model.model.make_empty_intermediate_tensors + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: + self, **kwargs: object + ) -> 
Optional[LlavaNextVideoPixelInputs]: """ A legal video input should have the following dimensions: { - "pixel_values_videos" : + "pixel_values_videos" : list[b, Tensor(nb_frames, nb_channels, height, width)] } """ @@ -344,34 +370,25 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, return None expected_h = expected_w = self.config.vision_config.image_size - return LlavaNextVideoPixelInputs(type="pixel_values_videos", - data=pixel_values_videos, - resolve_bindings={ - "h": expected_h, - "w": expected_w, - }) - - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") + return LlavaNextVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + resolve_bindings={ + "h": expected_h, + "w": expected_w, + }, + ) def _video_pixels_to_features( self, vision_tower: Union[CLIPVisionModel, SiglipVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: - # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) - image_features = self._select_image_features( - image_features, - strategy=self.config.vision_feature_select_strategy, + image_features = vision_tower( + pixel_values, + feature_select_strategy=self.config.vision_feature_select_strategy, ) image_features = self.vision_resampler(image_features) image_features = self.multi_modal_projector(image_features) @@ -380,55 +397,38 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): assert self.vision_tower is not None - video_pixels = inputs["data"] + video_pixels = inputs["pixel_values_videos"] if isinstance(video_pixels, torch.Tensor): - # TODO: support multiple videos per input - b, num_videos, num_frames, c, h, w = video_pixels.shape - assert (num_videos == 1) - stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, - h, w) + bn, f, c, h, w = video_pixels.shape + stacked_pixels = video_pixels.view(bn * f, c, h, w) stacked_embeddings = self._video_pixels_to_features( - self.vision_tower, stacked_pixels) - embeds = stacked_embeddings.view(b, num_frames, - *stacked_embeddings.shape[1:]) + self.vision_tower, stacked_pixels + ) + embeds = stacked_embeddings.view(bn, f, *stacked_embeddings.shape[1:]) elif is_list_of(video_pixels, torch.Tensor): frames_per_videos = [v.shape[0] for v in video_pixels] stacked_pixels = torch.cat(video_pixels, dim=0) stacked_embeddings = self._video_pixels_to_features( - self.vision_tower, stacked_pixels) + self.vision_tower, stacked_pixels + ) embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0) else: - raise ValueError( - f"Unsupported type of video input {type(video_pixels)}") + raise ValueError(f"Unsupported type of video input {type(video_pixels)}") return [e.flatten(0, 1) for e in embeds] def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: video_input = self._parse_and_validate_video_input(**kwargs) if video_input is None: return [] vision_embeddings = self._process_video_pixels(video_input) return vision_embeddings - def 
get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.video_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -446,31 +446,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # This model doesn't support images for now diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 42ab5e7c74d37..05f1621694c36 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -7,19 +7,27 @@ from typing import Annotated, Final, Literal, Optional, Protocol, Union import torch import torch.nn as nn -from transformers import (BatchFeature, LlavaOnevisionConfig, - LlavaOnevisionProcessor) +from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor from transformers.models.llava_onevision.modeling_llava_onevision import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageSize, + MultiModalDataItems, + VideoEmbeddingItems, + VideoProcessorItems, +) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -27,12 +35,18 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava 
import LlavaDummyInputsBuilder, init_vision_tower_for_llava -from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, - LlavaNextProcessingInfo) +from .llava_next import ( + BaseLlavaNextMultiModalProcessor, + LlavaNextLikeConfig, + LlavaNextProcessingInfo, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -47,10 +61,11 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema): - h: Height - w: Width - Note that `num_videos` may be different for each batch, and 'num_frames' + Note that `f` may be different for each batch, and 'num_frames' may be different for each video, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[ @@ -71,11 +86,12 @@ class LlavaOnevisionImagePixelInputs(TensorSchema): Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w"), + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] @@ -88,6 +104,7 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[ @@ -96,11 +113,13 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema): ] -LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, - LlavaOnevisionImageEmbeddingInputs] +LlavaOnevisionImageInputs = Union[ + LlavaOnevisionImagePixelInputs, LlavaOnevisionImageEmbeddingInputs +] -LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs, - LlavaOnevisionVideoPixelInputs] +LlavaOnevisionMultiInputs = Union[ + LlavaOnevisionImageInputs, LlavaOnevisionVideoPixelInputs +] class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): @@ -108,7 +127,6 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self) -> LlavaOnevisionLikeConfig: return self.ctx.get_hf_config(LlavaOnevisionConfig) @@ -137,12 +155,14 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): if aspect_ratio > current_aspect_ratio: new_height = int( - round(original_height * (current_width / original_width), 7)) + round(original_height * (current_width / original_width), 7) + ) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: new_width = int( - round(original_width * (current_height / original_height), 7)) + round(original_width * (current_height / original_height), 7) + ) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -216,14 +236,12 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) - 
max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self._get_max_video_frames(seq_len) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) @@ -237,14 +255,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), ) class LlavaOnevisionDummyInputsBuilder( - LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): - + LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -259,34 +276,39 @@ class LlavaOnevisionDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, - mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, - ) + overrides=video_overrides, + ), } class LlavaOnevisionMultiModalProcessor( - BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): - + BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -403,7 +425,8 @@ class LlavaOnevisionMultiModalProcessor( def get_video_replacement(item_idx: int): videos = mm_items.get_items( - "video", (VideoEmbeddingItems, VideoProcessorItems)) + "video", (VideoEmbeddingItems, VideoProcessorItems) + ) if isinstance(videos, VideoEmbeddingItems): num_video_tokens = videos.get_feature_size(item_idx) @@ -428,17 +451,20 @@ class LlavaOnevisionMultiModalProcessor( class LlavaOnevisionMultiModalProjector(nn.Module): - def __init__(self, config: LlavaOnevisionConfig): super().__init__() - self.linear_1 = nn.Linear(config.vision_config.hidden_size, - config.text_config.hidden_size, - bias=config.multimodal_projector_bias) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) self.act = get_act_fn(config.projector_hidden_act) - self.linear_2 = nn.Linear(config.text_config.hidden_size, - config.text_config.hidden_size, - bias=config.multimodal_projector_bias) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) def forward(self, image_features: torch.Tensor) -> 
torch.Tensor: hidden_states = self.linear_1(image_features) @@ -450,9 +476,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module): @MULTIMODAL_REGISTRY.register_processor( LlavaOnevisionMultiModalProcessor, info=LlavaOnevisionProcessingInfo, - dummy_inputs=LlavaOnevisionDummyInputsBuilder) -class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=LlavaOnevisionDummyInputsBuilder, +) +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -462,7 +489,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -487,21 +515,23 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.make_empty_intermediate_tensors = ( - self.language_model.model.make_empty_intermediate_tensors) + self.language_model.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaOnevisionImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -510,42 +540,31 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") - return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": self.config.vision_config.image_size, - "w": self.config.vision_config.image_size - }) + "w": self.config.vision_config.image_size, + }, + ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeds. 
" - f"Got type: {type(image_embeds)}") - return LlavaOnevisionImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, - **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[LlavaOnevisionVideoPixelInputs]: """ A legal video input should have the following dimensions: { - "pixel_values_videos" : + "pixel_values_videos" : list[b, Tensor(nb_frames, nb_channels, height, width)] } """ @@ -553,17 +572,14 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values_videos is None: return None - if not isinstance(pixel_values_videos, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel_values_videos. " - f"Got type: {type(pixel_values_videos)}") - return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", - pixel_values_videos=flatten_bn(pixel_values_videos), + pixel_values_videos=pixel_values_videos, resolve_bindings={ "h": self.config.vision_config.image_size, - "w": self.config.vision_config.image_size - }) + "w": self.config.vision_config.image_size, + }, + ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -571,60 +587,59 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features( self, vision_tower: Union[CLIPVisionModel, SiglipVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: - # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) - return self._select_image_features( - image_features, - strategy=self.config.vision_feature_select_strategy, + return vision_tower( + pixel_values, + feature_select_strategy=self.config.vision_feature_select_strategy, ) # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py - def _merge_image_patch_embeddings(self, - image_size: torch.Tensor, - patch_embeddings: torch.Tensor, - *, - image_newline=None, - vision_aspect_ratio="anyres_max_9", - strategy: str) -> torch.Tensor: + def _merge_image_patch_embeddings( + self, + image_size: torch.Tensor, + patch_embeddings: torch.Tensor, + *, + image_newline=None, + 
vision_aspect_ratio="anyres_max_9", + strategy: str, + ) -> torch.Tensor: if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): - height = width = self.config.vision_config.image_size \ + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) base_patch_embeds = patch_embeddings[0] if height * width != base_patch_embeds.shape[0]: raise ValueError( - "The number of patches is not consistent with the " - "image size.") + "The number of patches is not consistent with the image size." + ) if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] @@ -641,53 +656,66 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, num_patches = num_patch_height * num_patch_width # Image patches might be padded for batch processing - other_patch_embeds = other_patch_embeds[:num_patches] \ - .view(num_patch_height, num_patch_width, height, width, -1) + other_patch_embeds = other_patch_embeds[:num_patches].view( + num_patch_height, num_patch_width, height, width, -1 + ) if "unpad" in strategy: - other_patch_embeds = other_patch_embeds \ - .permute(4, 0, 2, 1, 3).contiguous() \ - .flatten(1, 2).flatten(2, 3) - other_patch_embeds = unpad_image(other_patch_embeds, - (orig_height, orig_width)) + other_patch_embeds = ( + other_patch_embeds.permute(4, 0, 2, 1, 3) + .contiguous() + .flatten(1, 2) + .flatten(2, 3) + ) + other_patch_embeds = unpad_image( + other_patch_embeds, (orig_height, orig_width) + ) max_num_patches = int( - vision_aspect_ratio.removeprefix("anyres_max_")) + vision_aspect_ratio.removeprefix("anyres_max_") + ) channels, curr_height, curr_width = other_patch_embeds.shape - ratio = math.sqrt(curr_height * curr_width / - (max_num_patches * height**2)) + ratio = math.sqrt( + curr_height * curr_width / (max_num_patches * height**2) + ) if ratio > 1.1: other_patch_embeds = other_patch_embeds[None] other_patch_embeds = nn.functional.interpolate( - other_patch_embeds, [ - int(curr_height // ratio), - int(curr_width // ratio) - ], - mode="bilinear")[0] + other_patch_embeds, + [int(curr_height // ratio), int(curr_width // ratio)], + mode="bilinear", + )[0] if image_newline is not None: other_patch_embeds = torch.cat( ( other_patch_embeds, - image_newline[:, None, None] \ - .expand(*other_patch_embeds.shape[:-1], 1) \ + image_newline[:, None, None] + .expand(*other_patch_embeds.shape[:-1], 1) .to(other_patch_embeds.device), ), - dim=-1) - other_patch_embeds = other_patch_embeds \ - .flatten(1, 2).transpose(0, 1) + dim=-1, + ) + other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose( + 0, 1 + ) else: - other_patch_embeds = other_patch_embeds \ - .permute(0, 2, 1, 3, 4).contiguous() \ + other_patch_embeds = ( + other_patch_embeds.permute(0, 2, 1, 3, 4) + .contiguous() .flatten(0, 3) + ) merged_patch_embeddings = torch.cat( - (base_patch_embeds, other_patch_embeds), dim=0) + (base_patch_embeds, other_patch_embeds), dim=0 + ) else: if "unpad" in strategy: merged_patch_embeddings = torch.cat( - (base_patch_embeds, - self.image_newline[None] \ - .to(base_patch_embeds.device) - ), dim=0) + ( + base_patch_embeds, + self.image_newline[None].to(base_patch_embeds.device), + ), + dim=0, + ) else: merged_patch_embeddings = base_patch_embeds @@ -707,21 +735,27 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, b, num_patches, c, h, w = pixel_values.shape stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) stacked_image_features = 
self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) stacked_patch_embeddings = self.multi_modal_projector( - stacked_image_features) + stacked_image_features + ) return stacked_patch_embeddings.view( - b, num_patches, *stacked_patch_embeddings.shape[1:]) + b, num_patches, *stacked_patch_embeddings.shape[1:] + ) num_patches_per_batch = [v.shape[0] for v in pixel_values] stacked_pixel_values = torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) return [ - self.multi_modal_projector(image_features) for image_features in - torch.split(stacked_image_features, num_patches_per_batch) + self.multi_modal_projector(image_features) + for image_features in torch.split( + stacked_image_features, num_patches_per_batch + ) ] def _process_image_input( @@ -738,15 +772,17 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, batch_size = len(image_input["pixel_values"]) vision_config = self.config.vision_config default_height = default_width = vision_config.image_size - image_sizes = torch.as_tensor([[default_height, default_width] - for _ in range(batch_size)]) + image_sizes = torch.as_tensor( + [[default_height, default_width] for _ in range(batch_size)] + ) return [ self._merge_image_patch_embeddings( image_sizes[i], patch_features_batch, image_newline=self.image_newline, - strategy="spatial_unpad") + strategy="spatial_unpad", + ) for i, patch_features_batch in enumerate(patch_embeddings) ] @@ -755,13 +791,11 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, vision_tower: Union[CLIPVisionModel, SiglipVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: - # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - video_features = vision_tower(pixel_values) - video_features = self._select_image_features( - video_features, - strategy=self.config.vision_feature_select_strategy, + video_features = vision_tower( + pixel_values, + feature_select_strategy=self.config.vision_feature_select_strategy, ) video_features = self.multi_modal_projector(video_features) video_features = self.apply_pooling(video_features) @@ -774,36 +808,39 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, if isinstance(video_pixels, torch.Tensor): total_videos, frames, c, h, w = video_pixels.shape - video_pixels_flat = video_pixels.view(total_videos * frames, c, h, - w) + video_pixels_flat = video_pixels.view(total_videos * frames, c, h, w) embeddings_flat = self._video_pixels_to_features( - self.vision_tower, video_pixels_flat) + self.vision_tower, video_pixels_flat + ) embeddings_flat = embeddings_flat.reshape( - total_videos, frames * embeddings_flat.shape[1], -1) + total_videos, frames * embeddings_flat.shape[1], -1 + ) image_newline = self.image_newline[None, None, :].expand( - total_videos, -1, -1) + total_videos, -1, -1 + ) return torch.cat((embeddings_flat, image_newline), dim=1) frames_per_video = [len(video) for video in video_pixels] video_pixels_flat = torch.cat(video_pixels) embeddings_flat = self._video_pixels_to_features( - self.vision_tower, video_pixels_flat) + self.vision_tower, video_pixels_flat + ) image_newline = self.image_newline[None, None, :] return [ torch.cat( ( - embeds.reshape(1, num_frame * embeddings_flat.shape[1], - -1), + embeds.reshape(1, num_frame * embeddings_flat.shape[1], 
-1), image_newline, ), dim=1, - ) for num_frame, embeds in zip( + ) + for num_frame, embeds in zip( frames_per_video, torch.split(embeddings_flat, frames_per_video), ) @@ -819,9 +856,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, # TODO support other pooling types config height, width = image_features.shape[2:] scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] - image_feature = nn.functional.interpolate(image_features, - size=scaled_shape, - mode='bilinear') + image_feature = nn.functional.interpolate( + image_features, size=scaled_shape, mode="bilinear" + ) image_feature = image_feature.permute(0, 2, 3, 1) image_feature = image_feature.view(batch_frames, -1, dim) return image_feature @@ -829,16 +866,14 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -854,46 +889,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_index, self.config.video_token_index]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[LlavaOnevisionImagePixelInputs] = None, - video_input: Optional[LlavaOnevisionVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_index, - ) - - if video_input is not None: - video_embeds = self._process_video_pixels(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -911,38 +906,18 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
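# [Editor's note — illustrative sketch, not part of the patch.]
# The v0 branch removed below relied on `merge_multimodal_embeddings` (and the
# deleted `get_input_embeddings` / `get_input_embeddings_v0` helpers above) to
# splice image/video features into the text embeddings; in v1 the model runner
# builds `inputs_embeds` before calling `forward`. A simplified sketch of that
# merge, with hypothetical names and shapes:
#
#     import torch
#
#     def merge_embeddings_sketch(
#         input_ids: torch.Tensor,      # (num_tokens,)
#         inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size)
#         mm_embeds: torch.Tensor,      # (num_mm_tokens, hidden_size)
#         placeholder_id: int,
#     ) -> torch.Tensor:
#         mask = input_ids == placeholder_id
#         # Expect exactly one feature row per placeholder token.
#         assert int(mask.sum()) == mm_embeds.shape[0]
#         out = inputs_embeds.clone()
#         out[mask] = mm_embeds.to(out.dtype)
#         return out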
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py new file mode 100644 index 0000000000000..17ec6b7d2b06a --- /dev/null +++ b/vllm/model_executor/models/longcat_flash.py @@ -0,0 +1,751 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
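# [Editor's note — orientation for this new file; illustrative only, not part of
# the patch.] This file adds the LongCat-Flash ("Flash") decoder. Each
# FlashDecoderLayer chains two MLA attention blocks and two dense MLPs, while a
# single MoE branch is computed from the output of the first attention block and
# added back at the end (a "shortcut" expert path). A condensed view of the
# per-layer data flow implemented in FlashDecoderLayer.forward further below,
# with layernorm/residual plumbing simplified:
#
#     h   = attn0(x)
#     moe = MoE(h)          # shortcut branch, taken from a copy of h
#     h   = mlps[0](h)
#     h   = attn1(h)
#     h   = mlps[1](h)
#     out = h + moe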
+"""Inference-only Flash model compatible with HuggingFace weights.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.int8_utils import block_dequant +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class FlashConfig(PretrainedConfig): + """Flash model configuration.""" + + model_type = "longcat_flash" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + hidden_size=4096, + intermediate_size=8192, + num_layers=28, + num_hidden_layers=None, + num_attention_heads=96, + num_key_value_heads=128, + ep_size=1, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + num_experts_per_tok=None, + norm_topk_prob=False, + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=1000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mla_scale_q_lora=False, + mla_scale_kv_lora=False, + torch_dtype="bfloat16", + params_dtype="bfloat16", + router_dtype="float32", + router_bias=False, + topk_method=None, + routed_scaling_factor=None, + zero_expert_num=0, + zero_expert_type=None, + nextn_use_scmoe=False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + torch_dtype=torch_dtype, + params_dtype=params_dtype, + router_dtype=router_dtype, + topk_method=topk_method, + router_bias=router_bias, + nextn_use_scmoe=nextn_use_scmoe, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = ( + num_hidden_layers if num_hidden_layers is not None else num_layers + ) + self.num_attention_heads = num_attention_heads + self.ep_size = ep_size + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.num_experts_per_tok = num_experts_per_tok + 
self.norm_topk_prob = norm_topk_prob + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mla_scale_q_lora = mla_scale_q_lora + self.mla_scale_kv_lora = mla_scale_kv_lora + self.zero_expert_num = zero_expert_num + self.zero_expert_type = zero_expert_type + self.routed_scaling_factor = routed_scaling_factor + self.hidden_act = "silu" + self.intermediate_size = ( + self.ffn_hidden_size + if hasattr(self, "ffn_hidden_size") + else self.intermediate_size + ) + if hasattr(self, "moe_intermediate_size"): + self.moe_intermediate_size = self.moe_intermediate_size + elif hasattr(self, "expert_ffn_hidden_size"): + self.moe_intermediate_size = self.expert_ffn_hidden_size + else: + self.moe_intermediate_size = self.intermediate_size + + +class FlashMLP(nn.Module): + """Flash MLP layer.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LongcatRouter(nn.Module): + def __init__( + self, + config, + zero_expert_num=0, + rounter_params_dtype=torch.bfloat16, + prefix: str = "", + ): + super().__init__() + self.n_routed_experts = ( + config.n_routed_experts + if hasattr(config, "n_routed_experts") + else config.num_experts[0] + ) + self.n_routed_experts = self.n_routed_experts + zero_expert_num + self.classifier = ReplicatedLinear( + config.hidden_size, + self.n_routed_experts, + bias=config.router_bias, + params_dtype=rounter_params_dtype, + quant_config=None, + prefix=f"{prefix}.classifier", + ) + self.e_score_correction_bias = nn.Parameter( + torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype) + ) + + def forward(self, hidden_states): + logits, _ = self.classifier(hidden_states) + return logits + + +class LongcatMoe(nn.Module): + def __init__( + self, + config: FlashConfig, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.hidden_size = hidden_size + self.zero_expert_num = config.zero_expert_num + self.zero_expert_type = config.zero_expert_type + self.routed_scaling_factor = config.routed_scaling_factor + self.enable_eplb = enable_eplb + # Gate always runs at half / full precision for now. 
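# [Editor's note — illustrative sketch, not part of the patch.] With the
# FlashConfig default router_dtype="float32", the router logits are computed in
# float32 so expert selection stays numerically stable while the rest of the
# model runs in bfloat16. A much-simplified dense reference for the top-k
# selection the fused MoE performs on these logits (zero experts, the
# score-correction bias and routed scaling are omitted, and the exact scoring
# function may differ):
#
#     import torch
#
#     def topk_route_sketch(router_logits: torch.Tensor, top_k: int):
#         # router_logits: (num_tokens, num_experts), typically float32
#         scores = router_logits.softmax(dim=-1)
#         topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
#         return topk_weights, topk_ids  # per-token expert weights and indices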
+ self.rounter_params_dtype = params_dtype + if config.router_dtype == "float32": + self.rounter_params_dtype = torch.float32 + + self.router = LongcatRouter( + config=config, + zero_expert_num=self.zero_expert_num, + rounter_params_dtype=self.rounter_params_dtype, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + params_dtype=params_dtype, + e_score_correction_bias=self.router.e_score_correction_bias, + renormalize=False, + quant_config=quant_config, + prefix=f"{prefix}.experts", + zero_expert_num=self.zero_expert_num, + zero_expert_type=self.zero_expert_type, + enable_eplb=self.enable_eplb, + routed_scaling_factor=config.routed_scaling_factor, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits = self.router(hidden_states.to(self.rounter_params_dtype)) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class FlashDecoderLayer(nn.Module): + """Flash decoder layer with dual attention and MLP structure.""" + + def __init__( + self, + vllm_config: VllmConfig, + config: FlashConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.layer_idx = int(prefix.split(sep=".")[-1]) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + + # Dual attention structure + self.self_attn = nn.ModuleList( + [ + DeepseekV2MLAAttention( + vllm_config=vllm_config, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=None + if "self_attn" in getattr(config, "disable_quant_module", []) + else quant_config, + prefix=f"{prefix}.self_attn.{i}", + ) + for i in range(2) + ] + ) + self.input_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) + self.post_attention_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) + + # Dual MLP structure + self.mlps = nn.ModuleList( + [ + FlashMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=None + if "mlps" in getattr(config, "disable_quant_module", []) + else quant_config, + prefix=f"{prefix}.mlps.{i}", + ) + for i in range(2) + ] + ) + + self.mlp = LongcatMoe( + config=config, + num_experts=config.n_routed_experts + if hasattr(config, "n_routed_experts") + else 
config.num_experts[self.layer_idx], + top_k=config.moe_topk + if hasattr(config, "moe_topk") + else config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + prefix=(f"{prefix}.mlp"), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm[0](hidden_states) + else: + hidden_states, residual = self.input_layernorm[0](hidden_states, residual) + + hidden_states = self.self_attn[0]( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm[0]( + hidden_states, residual + ) + + # moe + hidden_states_copy = hidden_states.clone() + moe_hidden_states = self.mlp(hidden_states_copy) + + # first mlp + hidden_states = self.mlps[0](hidden_states) + + hidden_states, residual = self.input_layernorm[1](hidden_states, residual) + + # second_attn + hidden_states = self.self_attn[1]( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm[1]( + hidden_states, residual + ) + + # second_mlp + hidden_states = self.mlps[1](hidden_states) + + hidden_states = hidden_states + moe_hidden_states + + return hidden_states, residual + + +@support_torch_compile +class FlashModel(nn.Module): + """Flash model.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: FlashDecoderLayer( + vllm_config, + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + 
{"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + """Flash model for causal language modeling.""" + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + config.intermediate_size = ( + config.ffn_hidden_size + if hasattr(config, "ffn_hidden_size") + else config.intermediate_size + ) + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = FlashModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + if hasattr(self.config, "n_routed_experts") + else self.config.num_experts[0], + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + expert_params_mapping = self.get_expert_mapping() + loaded_params: set[str] = set() + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp" in name and "mlps" not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + # Skip mtp + if ".mtp." 
in name: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip mtp + if ".mtp." in name_mapped: + continue + if ( + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") + ) and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name_mapped] + weight_loader = param.weight_loader + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading kv_scale from ckpts towards new design. + if name.endswith(".kv_scale") and name not in params_dict: + continue + # Skip mtp + if ".mtp." in name: + continue + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + for layer_id in range(self.config.num_hidden_layers): + for i in range(2): + if isinstance(self.model.layers[layer_id], PPMissingLayer): + continue + self_attn = self.model.layers[layer_id].self_attn[i] + if hasattr( + self.quant_config, "weight_block_size" + ) and self_attn.kv_b_proj.weight.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + dtype = torch.get_default_dtype() + w = block_dequant( + self_attn.kv_b_proj.weight, + self_attn.kv_b_proj.weight_scale_inv, + weight_block_size, + ).to(dtype) + else: + w = self_attn.kv_b_proj.weight + + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if self.config.mla_scale_q_lora: + self_attn.q_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.q_lora_rank + ) ** 0.5 + if self.config.mla_scale_kv_lora: + self_attn.kv_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.kv_lora_rank + ) ** 0.5 + return loaded_params diff --git a/vllm/model_executor/models/longcat_flash_mtp.py b/vllm/model_executor/models/longcat_flash_mtp.py new file mode 100644 index 0000000000000..55468f354c3a2 --- /dev/null +++ b/vllm/model_executor/models/longcat_flash_mtp.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py +from collections.abc import Iterable +from typing import Optional 
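# [Editor's note — illustrative only, not part of the patch.] This new file adds
# the LongCat-Flash multi-token-prediction (MTP) draft head used for speculative
# decoding. Condensed data flow of LongCatMultiTokenPredictorLayer.forward
# defined below:
#
#     e = enorm(inputs_embeds)                 # embeddings of the draft tokens
#     h = hnorm(previous_hidden_states)        # hidden states from the main model
#     x = eh_proj(torch.cat([e, h], dim=-1))   # 2*hidden_size -> hidden_size
#     x = mtp_block(x)                         # one DeepseekV2 decoder layer
#     return final_layernorm(x)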
+ +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.int8_utils import block_dequant +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.longcat_flash import FlashConfig +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import DeepseekV2DecoderLayer +from .interfaces import SupportsPP +from .utils import maybe_prefix + + +class LongCatMultiTokenPredictorLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + prefix: str, + vllm_config: VllmConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = ReplicatedLinear( + 2 * config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix="eh_proj", + ) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states, _ = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) + + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class LongCatMultiTokenPredictor(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) + vllm_config.model_config.hf_config.intermediate_size = config.intermediate_size + self.mtp_start_layer_idx = config.num_hidden_layers * 2 + self.num_mtp_layers = 1 + self.layers = torch.nn.ModuleDict( + { + str(idx): LongCatMultiTokenPredictorLayer( + config, + prefix=f"{prefix}.layers.{idx}", + vllm_config=vllm_config, + quant_config=quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = spec_step_idx % self.num_mtp_layers + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + 
current_step_idx, + ) + + +class LongCatFlashMTP(nn.Module, SupportsPP): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + # LongCat MTP without MoE layers + vllm_config.model_config.hf_config.n_routed_experts = None + self.config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) + self.quant_config = ( + None + if "mtp" in getattr(self.config, "disable_quant_module", []) + else vllm_config.quant_config + ) + + self.model = LongCatMultiTokenPredictor( + vllm_config=vllm_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(self.config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + new_to_old_names_mapping = { + "model.mtp.embed_tokens.weight": "model.layers.0.embed_tokens.weight", + "model.mtp.layers.0.eh_proj.weight": "eh_proj.weight", + "model.mtp.layers.0.eh_proj.weight_scale_inv": "eh_proj.weight_scale_inv", + "model.mtp.layers.0.enorm.m.weight": "enorm.weight", + "model.mtp.layers.0.hnorm.m.weight": "hnorm.weight", + "model.mtp.layers.0.input_layernorm.weight": "model.layers.0.input_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.post_attention_layernorm.weight": "model.layers.0.post_attention_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_layernorm.weight": "model.layers.0.self_attn.kv_a_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight": "model.layers.0.self_attn.kv_a_proj_with_mqa.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv": "model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.o_proj.weight": "model.layers.0.self_attn.o_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv": "model.layers.0.self_attn.o_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.q_a_layernorm.weight": "model.layers.0.self_attn.q_a_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.q_a_proj.weight": "model.layers.0.self_attn.q_a_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv": "model.layers.0.self_attn.q_a_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.q_b_proj.weight": 
"model.layers.0.self_attn.q_b_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv": "model.layers.0.self_attn.q_b_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight": "model.layers.0.mlp.down_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv": "model.layers.0.mlp.down_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight": "model.layers.0.mlp.gate_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv": "model.layers.0.mlp.gate_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight": "model.layers.0.mlp.up_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv": "model.layers.0.mlp.up_proj.weight_scale_inv", # noqa: E501 + "model.mtp.norm.weight": "final_layernorm.weight", + } + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = self.get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name( + spec_layer, name, new_to_old_names_mapping + ) + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + if (param_name == "fused_qkv_a_proj") and name not in params_dict: + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. 
+ if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + spec_layer_id = self.config.num_hidden_layers * 2 + self_attn = self.model.layers[str(spec_layer_id)].mtp_block.self_attn + if hasattr( + self.quant_config, "weight_block_size" + ) and self_attn.kv_b_proj.weight.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + dtype = torch.get_default_dtype() + w = block_dequant( + self_attn.kv_b_proj.weight, + self_attn.kv_b_proj.weight_scale_inv, + weight_block_size, + ).to(dtype) + else: + w = self_attn.kv_b_proj.weight + else: + w = self_attn.kv_b_proj.weight + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if self.config.mla_scale_q_lora: + self_attn.q_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.q_lora_rank + ) ** 0.5 + if self.config.mla_scale_kv_lora: + self_attn.kv_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.kv_lora_rank + ) ** 0.5 + return loaded_params + + def _rewrite_spec_layer_name( + self, spec_layer: int, name: str, new_to_old_names_mapping: dict + ) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. + """ + if name in new_to_old_names_mapping: + name = new_to_old_names_mapping[name] + spec_layer_weight_names = [ + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", + ] + if ( + name.startswith("enorm") + or name.startswith("hnorm") + or name.startswith("eh_proj") + or name.startswith("final_layernorm") + ): + name = "model.layers." + str(spec_layer) + "." + name + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace( + "model.layers.0.", f"model.layers.{spec_layer}.mtp_block." 
+ ) + elif shared_weight: + # treat shared weights as top level weights + name = name.replace("model.layers.0.", "model.") + return name + + def get_spec_layer_idx_from_weight_name( + self, config: PretrainedConfig, weight_name: str + ) -> Optional[int]: + if "model.mtp" in weight_name: + return config.num_hidden_layers * 2 + return None diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f02499a4f96b5..1638aab137aaf 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch MAMBA model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional import torch from torch import nn from transformers import MambaConfig -from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -16,56 +17,66 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsAttentionFree, + SupportsPP, +) from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) KVCache = tuple[torch.Tensor, torch.Tensor] class MambaDecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, - prefix: str = "") -> None: + def __init__( + self, + config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" self.is_lora_enabled = is_lora_enabled mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None - self.mixer = MambaMixer(hidden_size=config.hidden_size, - 
ssm_state_size=config.state_size, - conv_kernel_size=config.conv_kernel, - intermediate_size=config.intermediate_size, - time_step_rank=config.time_step_rank, - use_conv_bias=config.use_conv_bias, - use_bias=config.use_bias, - use_rms_norm=self.is_falcon_mamba, - rms_norm_has_weight=not self.is_falcon_mamba, - rms_norm_eps=mixer_rms_eps, - activation=config.hidden_act, - is_lora_enabled=self.is_lora_enabled, - model_config=model_config, - cache_config=cache_config, - prefix=f"{prefix}.mixer") + self.mixer = MambaMixer( + hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + use_rms_norm=self.is_falcon_mamba, + rms_norm_has_weight=not self.is_falcon_mamba, + rms_norm_eps=mixer_rms_eps, + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer", + ) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -73,7 +84,6 @@ class MambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -83,13 +93,12 @@ class MambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params) + self.mixer(hidden_states, output) return output, residual @support_torch_compile class MambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -101,8 +110,11 @@ class MambaModel(nn.Module): is_lora_enabled = bool(lora_config) self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -114,19 +126,21 @@ class MambaModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - is_lora_enabled=is_lora_enabled, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) @@ -135,7 +149,6 @@ class MambaModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> 
torch.Tensor: @@ -150,30 +163,19 @@ class MambaModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - - layer_cache_params = None - if mamba_cache_params is not None: - layer_cache_params = mamba_cache_params.at_layer_idx( - i - self.start_layer) - + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=layer_cache_params) + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -186,29 +188,29 @@ class MambaModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Mamba does not support prefix caching" + ) super().__init__() self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - self.backbone = MambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "backbone")) + self.backbone = MambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -222,45 +224,33 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. 
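A small sketch of the itertools.islice pattern that the refactored forward() above uses to walk only this pipeline rank's slice of layers; plain strings stand in for the module list and the partition bounds are made up.

# Illustrative sketch only; the layer list and bounds are stand-ins.
from itertools import islice

layers = [f"layer{i}" for i in range(6)]  # stand-in for an nn.ModuleList
start_layer, end_layer = 2, 5             # hypothetical pipeline-parallel slice

for layer in islice(layers, start_layer, end_layer):
    print(layer)  # layer2, layer3, layer4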
- self.mamba_cache: Optional[MambaCacheManager] = None - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.backbone.make_empty_intermediate_tensors) + self.backbone.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config) - state_dtype = self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_layers, *state_shape, - *state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.backbone( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -269,7 +259,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba1_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -289,22 +278,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): intermediate_size=hf_config.intermediate_size, state_size=hf_config.state_size, conv_kernel=hf_config.conv_kernel, - use_v1=envs.VLLM_USE_V1) + ) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 3432cf29feac6..4491648f3a0ad 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch MAMBA2 model.""" + from 
collections.abc import Iterable from typing import Optional @@ -8,66 +9,67 @@ import torch from torch import nn from transformers import MambaConfig -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) KVCache = tuple[torch.Tensor, torch.Tensor] class Mamba2DecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config - self.mixer = MambaMixer2(hidden_size=config.hidden_size, - ssm_state_size=config.state_size, - conv_kernel_size=config.conv_kernel, - intermediate_size=getattr( - config, "intermediate_size", - config.expand * config.hidden_size), - use_conv_bias=config.use_conv_bias, - use_bias=config.use_bias, - n_groups=config.n_groups, - num_heads=config.num_heads, - head_dim=config.head_dim, - rms_norm_eps=config.layer_norm_epsilon, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mixer = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=getattr( + config, "intermediate_size", config.expand * config.hidden_size + ), + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + 
n_groups=config.n_groups, + num_heads=config.num_heads, + head_dim=config.head_dim, + rms_norm_eps=config.layer_norm_epsilon, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -75,8 +77,6 @@ class Mamba2DecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -86,13 +86,12 @@ class Mamba2DecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, output) return output, residual @support_torch_compile class Mamba2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -105,8 +104,11 @@ class Mamba2Model(nn.Module): assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -118,18 +120,20 @@ class Mamba2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: Mamba2DecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) @@ -138,7 +142,6 @@ class Mamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -153,40 +156,21 @@ class Mamba2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - - for i in range(len(self.layers)): - layer = self.layers[i] - + for i, layer in enumerate(self.layers): hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer) if mamba_cache_params else None, - mamba2_metadata=mamba2_metadata) + 
positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -200,21 +184,18 @@ class Mamba2Model(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): - @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -225,13 +206,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -250,24 +229,21 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): head_dim=hf_config.head_dim, state_size=hf_config.state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config self.vllm_config = vllm_config self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.backbone = Mamba2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "backbone")) + self.backbone = Mamba2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -279,70 +255,48 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - # Used to track and store by the Mamba cache between steps. 
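A minimal sketch of the getattr-with-default dispatch used by load_weights above: parameters that expose a custom weight_loader attribute use it, and everything else falls back to a plain copy. The default_weight_loader and the custom loader here are simplified stand-ins, not vLLM's actual helpers.

# Illustrative sketch only; simplified stand-ins for the real loader helpers.
import torch

def default_weight_loader(param, loaded_weight):
    param.data.copy_(loaded_weight)

plain = torch.nn.Parameter(torch.zeros(2))
getattr(plain, "weight_loader", default_weight_loader)(plain, torch.ones(2))
print(plain.data)  # tensor([1., 1.])

sharded = torch.nn.Parameter(torch.zeros(2))
sharded.weight_loader = lambda p, w: p.data.copy_(w * 2)  # hypothetical custom loader
getattr(sharded, "weight_loader", default_weight_loader)(sharded, torch.ones(2))
print(sharded.data)  # tensor([2., 2.])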
- self.mamba_cache: Optional[MambaCacheManager] = None - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.backbone.make_empty_intermediate_tensors) + self.backbone.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = ( - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba)) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - - hidden_states = self.backbone(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.backbone( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py deleted file mode 100644 index 6b16e3ce7d984..0000000000000 --- a/vllm/model_executor/models/mamba_cache.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch - -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MambaCacheParams: - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return 
MambaCacheParams(self.conv_state[layer_idx], - self.ssm_state[layer_idx], - self.state_indices_tensor) - - -class MambaCacheManager(ConstantSizeCache): - - def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, - conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int], - conv_state_dtype: torch.dtype, - temporal_state_dtype: torch.dtype): - - self.conv_state_dtype = conv_state_dtype - self.temporal_state_dtype = temporal_state_dtype - - # Determine max batch size to set size of MambaCache - max_batch_size = vllm_config.scheduler_config.max_num_seqs - if not vllm_config.model_config.enforce_eager: - max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) - - # Initialize parent class - super().__init__(max_batch_size) - - # assume conv_state = (dim, state_len) - assert conv_state_shape[0] > conv_state_shape[1] - conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - (conv_state_shape[1], conv_state_shape[0]), - dtype=self.conv_state_dtype, - device="cuda").transpose(-1, -2) - temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=self.temporal_state_dtype, - device="cuda") - - self._mamba_cache = (conv_state, temporal_state) - - @property - def cache(self): - return self._mamba_cache - - def _copy_cache(self, from_index: int, to_index: int): - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) - - def current_run_tensors(self, **kwargs) -> MambaCacheParams: - """ - Return the tensors for the current run's conv and ssm state. - """ - cache_tensors, state_indices_tensor = super().current_run_tensors( - **kwargs) - return MambaCacheParams(cache_tensors[0], cache_tensors[1], - state_indices_tensor) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. 
- """ - return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 709a5a993c6f7..7e1d2bf14bb5c 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -2,32 +2,35 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .utils import maybe_prefix class ResidualBlock(nn.Module): - - def __init__(self, config: VllmConfig, hidden_size: int, - num_layers: int) -> None: + def __init__(self, config: VllmConfig, hidden_size: int, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([ - nn.Linear(hidden_size, - hidden_size, - bias=getattr(config, "medusa_fc_bias", False)) - for _ in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + nn.Linear( + hidden_size, + hidden_size, + bias=getattr(config, "medusa_fc_bias", False), + ) + for _ in range(num_layers) + ] + ) self.act = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -39,13 +42,13 @@ class ResidualBlock(nn.Module): class Medusa(nn.Module): """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 Reference implementation: https://github.com/FasterDecoding/Medusa - + Differences from reference implementation: 1. Currently this only supports generating proposals from top-1 tokens. - 2. We have an optional token_map which reduces draft vocab to most - frequently used tokens to give some additional speed-up by reducing - sampling overhead. This is disabled unless the checkpoint file has - explicit token_map tensor and config has an optional attribute + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute truncated_vocab_size < vocab_size. To use this technique, one has to find the top-k most frequent tokens in target dataset and add that as a tensor in the draft checkpoint (using key token_map). 
Also, the draft config @@ -55,12 +58,16 @@ class Medusa(nn.Module): config = vllm_config.speculative_config.draft_model_config.hf_config super().__init__() self.config = config - self.blocks = nn.ModuleList([ - ResidualBlock(config=config, - hidden_size=self.config.hidden_size, - num_layers=self.config.num_hidden_layers) - for _ in range(self.config.num_heads) - ]) + self.blocks = nn.ModuleList( + [ + ResidualBlock( + config=config, + hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers, + ) + for _ in range(self.config.num_heads) + ] + ) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size @@ -71,24 +78,27 @@ class Medusa(nn.Module): config.hidden_size, org_num_embeddings=self.truncated_vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) - self.lm_heads = [ - self.lm_head for _ in range(self.config.num_heads) - ] + self.lm_heads = [self.lm_head for _ in range(self.config.num_heads)] else: - self.lm_heads = nn.ModuleList([ - ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - ) for _ in range(self.config.num_heads) - ]) + self.lm_heads = nn.ModuleList( + [ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, f"lm_heads.{i}"), + ) + for i in range(self.config.num_heads) + ] + ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.truncated_vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale + ) # Token map is a idx to token mapping to reduce the vocab size for # the draft model. 
Using smaller vocab size for draft, containing @@ -102,12 +112,13 @@ class Medusa(nn.Module): return [block(hidden_states) for block in self.blocks] def compute_logits( - self, hidden_states: list[torch.Tensor], - sampling_metadata: SamplingMetadata) -> list[torch.Tensor]: + self, + hidden_states: list[torch.Tensor], + ) -> list[torch.Tensor]: logits_lst: list[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): - _logits = self.logits_processor(lm_head, hs, sampling_metadata) + _logits = self.logits_processor(lm_head, hs) if _logits is None: # _logits should only be None on rank > 0, in which case @@ -118,68 +129,20 @@ class Medusa(nn.Module): if self.token_map is None: logits_lst.append(_logits) else: - logits_lst.append(-torch.inf * torch.ones( - size=(*_logits.shape[:-1], self.orig_vocab_size), - device=_logits.device, - dtype=_logits.dtype)) + logits_lst.append( + -torch.inf + * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype, + ) + ) logits_lst[-1][..., self.token_map] = _logits return logits_lst - def sample( - self, - logits: list[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - logits = torch.stack(logits, dim=0).float() - logprobs = torch.log_softmax(logits, dim=-1) - token_ids = logits.argmax(-1) # support only top-1 for now - probs = torch.softmax(logits, dim=-1) - - token_id_list = [] - token_prob_list = [] - token_logprob_list = [] - - for idx, seq_group in enumerate(sampling_metadata.seq_groups): - token_id_list.append(token_ids[:, seq_group.sample_indices]) - token_prob_list.append(probs[:, seq_group.sample_indices]) - token_logprob_list.append(logprobs[:, seq_group.sample_indices]) - - outputs: list[Optional[SamplerOutput]] = [] - for idx in range(len(sampling_metadata.seq_groups)): - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_prob_list[idx].squeeze(1), - logprobs=token_logprob_list[idx].squeeze(1), - sampled_token_ids=token_id_list[idx].squeeze(1), - )) - - return outputs - - def generate_proposals( - self, - previous_hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - # During preemption, we may receive an empty tensor (batch_size=0) - if previous_hidden_states.size(0) == 0: - # Return None to signal the Top1Proposer that no proposals - # were generated for this batch, allowing it to handle this - # special case appropriately - return None - - return self.sample( - logits=self.compute_logits( - hidden_states=self.forward(previous_hidden_states), - sampling_metadata=sampling_metadata, - ), - sampling_metadata=sampling_metadata, - ) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -190,30 +153,33 @@ class Medusa(nn.Module): if name == "token_map": if self.truncated_vocab_size < self.orig_vocab_size: - self.token_map = nn.Parameter(loaded_weight, - requires_grad=False) + self.token_map = nn.Parameter(loaded_weight, requires_grad=False) elif name in params_dict: weights_map[name] = loaded_weight - elif (getattr(self.config, "original_lm_head", False) - and name == "lm_heads.0.weight"): + elif ( + getattr(self.config, "original_lm_head", False) + and name == "lm_heads.0.weight" + ): weights_map["lm_head.weight"] = loaded_weight for name, loaded_weight in weights_map.items(): 
- if "lm_head" in name and self.token_map is not None and\ - loaded_weight.shape[0] > self.token_map.shape[0]: - + if ( + "lm_head" in name + and self.token_map is not None + and loaded_weight.shape[0] > self.token_map.shape[0] + ): loaded_weight = loaded_weight[self.token_map] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) if self.token_map is not None: self.token_map.to(device=self.lm_heads[0].weight.device) - assert (self.truncated_vocab_size - == self.orig_vocab_size) or (self.token_map is not None) + assert (self.truncated_vocab_size == self.orig_vocab_size) or ( + self.token_map is not None + ) return loaded_params diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py new file mode 100644 index 0000000000000..47839a2c6b03f --- /dev/null +++ b/vllm/model_executor/models/midashenglm.py @@ -0,0 +1,853 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiDashengLM model compatible with HuggingFace weights.""" + +import collections +import collections.abc +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, TypedDict, Union, cast + +import numpy as np +import torch +import torch.nn as nn +import torchaudio.functional as F +from torch.nn.functional import scaled_dot_product_attention +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.midashenglm import DashengConfig + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix + +_Tuple2 = Union[int, tuple[int, int], Sequence[int]] + + +def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]: + if isinstance(x, collections.abc.Sequence): + assert len(x) == 2, ( + f"Expected a sequence of length 2, got {x} with length {len(x)}" + ) + return cast(tuple[int, int], tuple(x)) + return (x, x) + + +def calculate_mel_frames_dasheng( + audio_length_samples: int, + n_fft: int = 512, + hop_size: int = 160, + dasheng_subsampling: int = 4, + center=True, + model_subsampling: int = 5, +) -> int: + """Calculate the number of Mel-spectrogram frames.""" + if center: + audio_length_samples = audio_length_samples + n_fft + + return ( + int(1 + ((audio_length_samples - n_fft) / hop_size)) + // dasheng_subsampling + // model_subsampling + ) + + +class AudioPatchEmbed(nn.Module): + def __init__( + self, + input_size: _Tuple2 = 64, + patch_size: _Tuple2 = 16, + patch_stride: _Tuple2 = 16, + in_chans: int = 1, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = False, + ): + super().__init__() + self.input_size = _resolve_tuple2(input_size) + self.patch_size = _resolve_tuple2(patch_size) + self.patch_stride = _resolve_tuple2(patch_stride) + self.grid_size = ( + self.input_size[0] // self.patch_stride[0], + self.input_size[1] // self.patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_stride, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + if self.flatten: + x = torch.permute( + torch.flatten(x, 2, 3), (0, 2, 1) + ) # rearrange(x, "b c f t -> b (f t) c") + x = self.norm(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = 
nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class DashengMlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ColumnParallelLinear( + input_size=in_features, + output_size=hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.act = get_act_fn("gelu") + self.fc2 = RowParallelLinear( + input_size=hidden_features, + output_size=out_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.act(x) + x, _ = self.fc2(x) + return x + + +class DashengAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.embed_dim = dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
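A self-contained sketch of the qkv reshape/unbind and the padding-mask broadcast that DashengAttention.forward performs just below, on tiny made-up shapes and without tensor parallelism.

# Illustrative sketch only; shapes are made up and no tensor parallelism is used.
import torch
from torch.nn.functional import scaled_dot_product_attention

B, N, H, D = 2, 5, 4, 8            # batch, tokens, heads, head dim
C = H * D
qkv = torch.randn(B, N, 3 * C)     # output of a fused qkv projection

qkv = qkv.reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)  # (3, B, H, N, D)
q, k, v = qkv.unbind(0)

mask = torch.tensor([[1, 1, 1, 0, 0],   # batch 0: last two positions are padding
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
out = scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])
print(out.transpose(1, 2).reshape(B, N, C).shape)  # torch.Size([2, 5, 32])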
+ assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.proj = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + B, N, C = x.shape + + qkv, _ = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + x = scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask[:, None, None, :] if mask is not None else None, + ) + + x = x.transpose(1, 2).reshape(B, N, C) + x, _ = self.proj(x) + return x + + +class DashengBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + init_values: Optional[float] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = DashengAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + self.mlp = DashengMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + + # Kwargs usually has a mask parameter that is passed to Attention + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.ls1(self.attn(self.norm1(x), mask)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class DashengFrontend(nn.Module): + def __init__(self, config: DashengConfig): + super().__init__() + self.config = config + + spectrogram_window = torch.hann_window(self.config.win_length) + self.register_buffer( + "spectrogram_window", + spectrogram_window, + persistent=False, + ) + self.spectrogram_window: torch.Tensor + + melscale_fbanks = F.melscale_fbanks( + n_freqs=self.config.n_fft // 2 + 1, + f_min=self.config.f_min, + f_max=self.config.f_max, + n_mels=self.config.n_mels, + sample_rate=self.config.sample_rate, + ) + self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False) + self.melscale_fbanks: torch.Tensor + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + spectrogram = F.spectrogram( + waveform=waveform.to(torch.float32), + pad=0, + window=self.spectrogram_window, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + win_length=self.config.win_length, + power=2, + normalized=False, + center=self.config.center, + ) + mel_spectrogram = (spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT + # x has shape [batch, freq, time]. 
+ # F.amplitude_to_DB accepts inputs shaped as: + # - [freq, time] + # - [channel, freq, time] + # - [..., channel, freq, time] + # Here we insert a channel dimension of size 1 before calling it, + # then remove that extra dimension afterward. + log_mel_spectrogram = F.amplitude_to_DB( + mel_spectrogram.unsqueeze(1), + multiplier=10, + amin=1e-10, + db_multiplier=0, + top_db=120, + ).squeeze(1) + return log_mel_spectrogram.to(waveform.dtype) + + +class DashengAudioTransformer(nn.Module): + def __init__( + self, + config: DashengConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.target_length = config.target_length + self.hop_length = config.hop_length + + self.front_end = DashengFrontend(config) + + self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01) + + self.patch_embed = AudioPatchEmbed( + input_size=(config.n_mels, config.target_length), + embed_dim=config.embed_dim, + in_chans=config.input_channels, + patch_size=config.patch_size, + flatten=False, + patch_stride=config.patch_stride, + ) + + self.time_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) + ) + self.freq_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1) + ) + self.blocks = nn.ModuleList( + DashengBlock( + dim=config.embed_dim, + num_heads=config.num_heads, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + init_values=config.init_values, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(config.depth) + ) + self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + + def forward_features( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + t = x.shape[-1] + x = x + self.time_pos_embed[:, :, :, :t] + x = ( + x + self.freq_pos_embed[:, :, :, :] + ) # Just to support __getitem__ in posembed + x = torch.permute( + torch.flatten(x, 2, 3), (0, 2, 1) + ) # rearrange(x, "b c f t -> b (f t) c") + for block in self.blocks: + x = block(x, mask) + x = self.norm(x) + return x + + def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor: + batch_size = len(lengths) + idx = torch.arange(max_length, device=lengths.device) + idx = idx.repeat(batch_size).view(batch_size, max_length) + mask = (idx < lengths.unsqueeze(-1)).bool() + return mask + + def forward( + self, + x: torch.Tensor, + x_length: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + x = self.front_end(x) + x = x.to(self.time_pos_embed.dtype) + target_length_in_patches = self.target_length // 4 + x = x.unsqueeze(1) + x = torch.permute(x, (0, 2, 1, 3)) + x = self.init_bn(x) + x = torch.permute(x, (0, 2, 1, 3)) + + x = self.patch_embed(x) + t = x.shape[-1] + + input_splits = x.split(target_length_in_patches, dim=-1) + + if x_length is not None: + assert len(x_length) == len(x), ( + "batchsizes of input x and x_length need to be same" + ) + assert x_length.ndim == 1, "Lengths are of size (B,)" + scaled_lengths = (x_length / (self.hop_length * 4)).long() + mask = self._to_mask(max_length=t, lengths=scaled_lengths) + split_masks = mask.split(target_length_in_patches, dim=-1) + else: + mask = None + split_masks = [None] * len(input_splits) + + outputs = [] + + for split_x, split_mask in zip(input_splits, split_masks): + forward_kwargs = {} + forward_kwargs["mask"] = split_mask + split_x = self.forward_features(split_x, **forward_kwargs) + outputs.append(split_x) + x = torch.cat(outputs, dim=1) + return 
x, mask + + +class AudioProjectorSubsample(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + downsample_rate=5, + dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.k = downsample_rate + self.net = nn.Sequential( + ColumnParallelLinear( + input_size=in_dim * self.k, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.0", + return_bias=False, + ), + get_act_fn("gelu"), + RowParallelLinear( + input_size=out_dim, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.2", + return_bias=False, + ), + ) + + def forward(self, x, mask=None): + batch_size, seq_len, dim = x.shape + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + if mask is not None: + mask = mask[:, :-num_frames_to_discard] + if mask is None: + mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device) + x = x.reshape( + batch_size, -1, self.k * dim + ) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k) + for layer in self.net: + x = layer(x) + mask = mask.reshape( + batch_size, -1, self.k + ) # rearrange(mask, "b (s k) -> b s k", k=self.k) + mask = mask.any(dim=-1).long() + return x, mask + + +# === Audio Inputs === # +class MiDashengLMAudioInputs(TypedDict): + input_values: torch.Tensor + """Shape: `(num_audios, num_sampling_points)`""" + audio_length: torch.Tensor + """Shape: `(num_audios, 1)`""" + + +class MiDashengLMProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_feature_extractor(self): + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + + def get_min_audio_len(self): + return 3200 + + def get_max_audio_len(self): + return 160000 + + +class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + audio_bos_token = hf_processor.audio_bos_token + audio_eos_token = hf_processor.audio_eos_token + + single_audio_text = f"{audio_bos_token}{audio_token}{audio_eos_token}" + return single_audio_text * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + + audio_overrides = mm_options.get("audio") if mm_options else None + + return { + "audio": self._get_dummy_audios( + length=self.info.get_max_audio_len(), + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +class MiDashengLMMultiModalProcessor( + BaseMultiModalProcessor[MiDashengLMProcessingInfo] +): + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + audios = mm_data.pop("audios", []) + + # + Padding + min_audio_len = self.info.get_min_audio_len() + processed_audios = [ + np.pad( + audio, + (0, min_audio_len - audio.shape[-1]), + 
mode="constant", + constant_values=0, + ) + if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len + else audio + for audio in audios + ] + + if processed_audios: + mm_data["audio"] = processed_audios + + if not mm_data.get("audio", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_kwargs = dict( + **mm_kwargs, + ) + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_values=MultiModalFieldConfig.batched("audio"), + audio_length=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_length = out_mm_data.get("audio_length") + if audio_length is None: + audio_output_lengths = [] + else: + audio_length_np = ( + audio_length.cpu().numpy() + if isinstance(audio_length, torch.Tensor) + else audio_length + ) + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame + for length in audio_length_np + ] + + def get_replacement_midashenglm(item_idx: int): + num_features = audio_output_lengths[item_idx] + audio_tokens = [audio_token_id] * num_features + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_midashenglm, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + MiDashengLMMultiModalProcessor, + info=MiDashengLMProcessingInfo, + dummy_inputs=MiDashengLMDummyInputsBuilder, +) +class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("audio"): + return "<|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + # Initialize audio components + self.audio_encoder = DashengAudioTransformer( + config.audio_encoder_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_encoder"), + ) + self.audio_projector = AudioProjectorSubsample( + in_dim=config.audio_encoder_config.embed_dim, + out_dim=config.text_config.hidden_size, + downsample_rate=config.subsample_factor, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_projector"), + ) + + # Initialize language model (decoder) + self.decoder = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, 
"decoder"), + architectures=["Qwen2ForCausalLM"], + ) + + self.quant_config = quant_config + self.make_empty_intermediate_tensors = ( + self.decoder.make_empty_intermediate_tensors + ) + + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return mm_input.reshape(-1, *mm_input.shape[2:]) + + if name == "input_values": + max_length = max(tensor.shape[1] for tensor in mm_input) + padded_mm_input = [ + torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1])) + if tensor.shape[1] < max_length + else tensor + for tensor in mm_input + ] + return torch.concat(padded_mm_input) + + return torch.concat(mm_input) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> Optional[MiDashengLMAudioInputs]: + input_values = kwargs.pop("input_values", None) + audio_length = kwargs.pop("audio_length", None) + + if input_values is None: + return None + input_values = self._validate_and_reshape_mm_tensor( + input_values, "input_values" + ) + audio_length = self._validate_and_reshape_mm_tensor( + audio_length, "audio_length" + ) + if not isinstance(input_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_values)}" + ) + + return MiDashengLMAudioInputs( + input_values=input_values, + audio_length=audio_length, + ) + + def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: + # Process audio through encoder and projector + input_values = audio_input["input_values"] + audio_length = audio_input["audio_length"] + + encoder_out, encoder_atts = self.audio_encoder(input_values, audio_length) + audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts) + audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype) + batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape + + audio_length_np = ( + audio_length.cpu().numpy() + if isinstance(audio_length, torch.Tensor) + else audio_length + ) + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame + for length in audio_length_np + ] + audio_output_lengths = torch.tensor(audio_output_lengths).to( + audio_embeddings.device + ) + + audio_feature_mask = torch.arange( + max_audio_tokens, device=audio_embeddings.device + ).unsqueeze(0).expand( + batch_size, max_audio_tokens + ) < audio_output_lengths.unsqueeze(1) + + masked_audio_features = audio_embeddings[audio_feature_mask].view(-1, embed_dim) + + return torch.split(masked_audio_features, audio_output_lengths.tolist()) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + if audio_input is None: + return [] + return self._process_audio_input(audio_input) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings( + input_ids, + multimodal_embeddings, + 
is_multimodal=input_ids == self.config.audio_token_id, + ) + input_ids = None + + return self.decoder.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 5b497dd9d89f5..e01e064218420 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -25,7 +25,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiMo model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -38,9 +40,10 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix @@ -54,9 +57,9 @@ logger = init_logger(__name__) "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class MiMoModel(Qwen2Model): - def forward( self, input_ids: torch.Tensor, @@ -74,22 +77,20 @@ class MiMoModel(Qwen2Model): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = hidden_states + residual return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -104,18 +105,19 @@ class MiMoModel(Qwen2Model): continue if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for 
(param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -139,15 +141,13 @@ class MiMoModel(Qwen2Model): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config @@ -159,32 +159,33 @@ class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): self.quant_config = quant_config - self.model = MiMoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiMoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: hidden_states = self.model.norm(hidden_states) - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 5a2079bf5121a..b678a06b7f20f 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
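The `stacked_params_mapping` loop in the MiMo `load_weights` hunk above fuses separate q/k/v (and gate/up) checkpoint tensors into vLLM's packed parameters by rewriting their names before handing them to each parameter's `weight_loader`. A minimal, self-contained sketch of just the name-rewriting step, with illustrative weight names (the shard copy itself is done by the loader in the real code):

```python
# Sketch of the name rewriting behind a stacked-parameter mapping.  The triples
# mirror the ones in MiMoModel.load_weights; copying the shard into the fused
# tensor is left to the parameter's weight_loader in the real code.
stacked_params_mapping = [
    # (fused_param_name, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name: str):
    """Return (fused_param_name, shard_id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


print(remap("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)
```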
"""Inference-only MiMo-MTP model.""" + from collections.abc import Iterable from typing import Optional @@ -31,17 +32,17 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix class MiMoMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -52,19 +53,18 @@ class MiMoMultiTokenPredictorLayer(nn.Module): ) -> None: super().__init__() - self.token_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.hidden_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.input_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.mtp_block = Qwen2DecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_proj = nn.Linear( + config.hidden_size * 2, config.hidden_size, bias=False + ) + self.mtp_block = Qwen2DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -80,17 +80,17 @@ class MiMoMultiTokenPredictorLayer(nn.Module): previous_hidden_states = self.hidden_layernorm(previous_hidden_states) hidden_states = self.input_proj( - torch.cat([previous_hidden_states, inputs_embeds], dim=-1)) + torch.cat([previous_hidden_states, inputs_embeds], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return self.final_layernorm(hidden_states) class MiMoMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -103,21 +103,27 @@ class MiMoMultiTokenPredictor(nn.Module): config.hidden_size, ) - self.mtp_layers = torch.nn.ModuleDict({ - str(idx): - MiMoMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.mtp_layers = torch.nn.ModuleDict( + { + str(idx): MiMoMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.logits_processor = LogitsProcessor(config.vocab_size) + def get_input_embeddings(self, input_ids: torch.Tensor) -> 
torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -126,7 +132,6 @@ class MiMoMultiTokenPredictor(nn.Module): inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) return self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]( @@ -140,25 +145,28 @@ class MiMoMultiTokenPredictor(nn.Module): self, hidden_states: torch.Tensor, lm_head: ParallelLMHead, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)] - logits = self.logits_processor(lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(lm_head, hidden_states) return logits class MiMoMTP(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = MiMoMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.model = MiMoMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -170,21 +178,19 @@ class MiMoMTP(nn.Module): spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "mimo_mtp only support predict one token now" - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, self.lm_head, - sampling_metadata, spec_step_idx) + return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -196,12 +202,11 @@ class MiMoMTP(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: continue name = self.map_model_name_to_mtp_param_name(name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -213,7 +218,7 @@ class MiMoMTP(nn.Module): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -228,29 +233,41 @@ class MiMoMTP(nn.Module): # Skip loading extra bias for GPTQ models. 
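The multi-token predictor keeps its layers in a `ModuleDict` keyed by absolute layer index, so the layer for a speculative step is looked up as `str(mtp_start_layer_idx + spec_step_idx)`. A toy sketch of that indexing, with `nn.Identity` standing in for `MiMoMultiTokenPredictorLayer` and assumed layer counts:

```python
import torch.nn as nn

# Toy stand-in for MiMoMultiTokenPredictor's layer bookkeeping: the MTP layers
# are appended after the target model's layers and addressed by string key.
num_hidden_layers = 4  # assumed target-model depth (illustrative)
num_mtp_layers = 1     # MiMo-MTP currently predicts a single extra token

mtp_start_layer_idx = num_hidden_layers
mtp_layers = nn.ModuleDict(
    {
        str(idx): nn.Identity()  # placeholder for MiMoMultiTokenPredictorLayer
        for idx in range(mtp_start_layer_idx, mtp_start_layer_idx + num_mtp_layers)
    }
)

spec_step_idx = 0
layer = mtp_layers[str(mtp_start_layer_idx + spec_step_idx)]
print(list(mtp_layers.keys()), type(layer).__name__)  # ['4'] Identity
```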
if name.endswith(".bias") and name not in params_dict: continue - if "mtp_layers" not in name and ("embed_tokens" not in name - and "lm_head" not in name): + if "mtp_layers" not in name and ( + "embed_tokens" not in name and "lm_head" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params def map_model_name_to_mtp_param_name(self, name: str) -> str: import regex as re + + # append mtp_start_layer_idx + pattern = r"(model\.mtp_layers\.)(\d+)(\.)" + match = re.match(pattern, name) + if match: + original_num = int(match.group(2)) + new_num = original_num + self.config.num_hidden_layers + name = name.replace(match.group(), f"{match.group(1)}{new_num}.") + # check for early turn name_without_prefix = [ - "token_layernorm", "hidden_layernorm", "input_proj", - "final_layernorm" + "token_layernorm", + "hidden_layernorm", + "input_proj", + "final_layernorm", ] for sub_name in name_without_prefix: if sub_name in name: return name - pattern = r"model.mtp_layers.(\d+)." - group = re.match(pattern, name) - if group is not None: - name = name.replace(group.group(), group.group() + "mtp_block.") + # add mtp_block + pattern = r"(model\.mtp_layers\.\d+\.)" + match = re.match(pattern, name) + if match: + name = name.replace(match.group(), match.group() + "mtp_block.") return name def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: @@ -259,7 +276,11 @@ class MiMoMTP(nn.Module): Add .mtp_block for modules in transformer layer block for spec layer """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] spec_layer_weight = False for weight_name in spec_layer_weight_names: @@ -268,6 +289,7 @@ class MiMoMTP(nn.Module): break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) return name diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index d398a5d12bbcd..06cb6bc615767 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -23,8 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -34,31 +36,42 @@ from transformers import PretrainedConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class MiniCPMMoE(nn.Module): @@ -90,34 +103,53 @@ class MiniCPMMoE(nn.Module): params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=None) + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None, + ) self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) - set_weight_attrs(self.ws, { - 
"weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -125,8 +157,9 @@ class MiniCPMMoE(nn.Module): if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] @@ -135,23 +168,22 @@ class MiniCPMMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=True + ) + + final_hidden_states = fused_experts( + hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=True + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) class MiniCPMMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -162,20 +194,20 @@ class MiniCPMMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act == "silu": self.act_fn = SiluAndMul() elif hidden_act == "fatrelu": self.act_fn = FatreluAndMul(threshold=hidden_act_param) else: - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu and fatrelu are supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu and fatrelu are supported for now." 
+ ) def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -185,7 +217,6 @@ class MiniCPMMLP(nn.Module): class MiniCPMAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -244,13 +275,15 @@ class MiniCPMAttention(nn.Module): rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -269,7 +302,6 @@ class MiniCPMAttention(nn.Module): class MiniCPMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -284,15 +316,15 @@ class MiniCPMDecoderLayer(nn.Module): self.hidden_size = config.hidden_size self.rope_theta = getattr(config, "rope_theta", 10000) self.rope_scaling = getattr(config, "rope_scaling", None) - self.max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.prefix = prefix self._init_attn_block() self._init_ffn_block() def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, num_heads=self.config.num_attention_heads, @@ -306,15 +338,16 @@ class MiniCPMDecoderLayer(nn.Module): ) def _init_ffn_block(self): - self.post_attention_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, intermediate_size=self.config.intermediate_size, hidden_act=self.config.hidden_act, - hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + hidden_act_param=getattr(self.config, "hidden_act_param", 0.0), quant_config=self.quant_config, ) else: @@ -322,7 +355,8 @@ class MiniCPMDecoderLayer(nn.Module): num_experts=self.config.num_experts, top_k=self.config.num_experts_per_tok, hidden_size=self.config.hidden_size, - intermediate_size=self.config.intermediate_size) + intermediate_size=self.config.intermediate_size, + ) def forward( self, @@ -337,22 +371,23 @@ class MiniCPMDecoderLayer(nn.Module): positions=positions, hidden_states=hidden_states, ) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) return hidden_states, None @support_torch_compile class MiniCPMModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -364,8 +399,11 @@ class MiniCPMModel(nn.Module): self.config = config 
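In the `MiniCPMMoE.forward` hunk above, routing is now done in two steps: `fused_topk` picks the experts and renormalizes their weights, then `fused_experts` runs the selected experts and mixes their outputs. Those are vLLM's fused kernels; the plain-PyTorch reference below, with tiny made-up sizes and simple linear experts, only illustrates the math the two steps compute:

```python
import torch

torch.manual_seed(0)
num_tokens, hidden, n_experts, top_k = 4, 8, 4, 2

hidden_states = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, n_experts)

# Step 1 -- roughly what fused_topk returns: per-token expert ids plus
# softmaxed, renormalized routing weights (renormalize=True).
probs = router_logits.softmax(dim=-1)
topk_weights, topk_ids = probs.topk(top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# Step 2 -- the combination fused_experts performs, written out token by token.
# Experts here are plain Linear layers; the real layer uses gated w1/w3 -> w2 MLPs.
experts = [torch.nn.Linear(hidden, hidden) for _ in range(n_experts)]
out = torch.zeros_like(hidden_states)
with torch.no_grad():
    for token in range(num_tokens):
        for slot in range(top_k):
            e = topk_ids[token, slot].item()
            out[token] += topk_weights[token, slot] * experts[e](hidden_states[token])

print(out.shape)  # torch.Size([4, 8])
```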
self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -376,9 +414,12 @@ class MiniCPMModel(nn.Module): self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], self.config.hidden_size)) + + self.aux_hidden_state_layers = tuple[int, ...]() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) def _init_layers( self, @@ -390,8 +431,10 @@ class MiniCPMModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MiniCPMDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -403,7 +446,9 @@ class MiniCPMModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -414,22 +459,32 @@ class MiniCPMModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + aux_hidden_states = [] + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append( + hidden_states + residual if residual is not None else hidden_states + ) hidden_states, residual = layer( positions, hidden_states, residual, ) + if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states = self.norm(hidden_states) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -440,8 +495,11 @@ class MiniCPMModel(nn.Module): ] expert_params_mapping = [ # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.num_experts) for weight_name in ["w1", "w2", "w3"] ] @@ -451,12 +509,11 @@ class MiniCPMModel(nn.Module): 
for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -478,10 +535,9 @@ class MiniCPMModel(nn.Module): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. @@ -490,14 +546,15 @@ class MiniCPMModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -531,8 +588,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.cache_config = cache_config self.quant_config = quant_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) unpadded_vocab_size = config.vocab_size if lora_config: @@ -544,17 +602,19 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) @@ -562,31 +622,49 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states 
= self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) / self.scale_width - return hidden_states + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + if isinstance(model_output, tuple) and len(model_output) == 2: + # Aux hidden states are present. + hidden_states, aux_hidden_states = model_output + hidden_states = hidden_states / self.scale_width + return hidden_states, aux_hidden_states + else: + # Only hidden states or IntermediateTensors + if isinstance(model_output, IntermediateTensors): + return model_output + else: + hidden_states = model_output / self.scale_width + return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 92c13e81bf3e4..35f02a1538e87 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
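Two MiniCPM-specific details show up in the forward path above: hidden states at the layers listed in `aux_hidden_state_layers` are recorded for an EAGLE3 drafter, and the final hidden states are divided by `scale_width = hidden_size / dim_model_base` (with each block's output scaled by `scale_depth / sqrt(num_hidden_layers)`). The sketch below mimics that flow with placeholder layers and assumed config numbers; the real values come from the checkpoint's config:

```python
import math

import torch
import torch.nn as nn

# Assumed, illustrative numbers; the real ones come from the HF config.
hidden_size, dim_model_base, scale_depth, num_layers = 64, 8, 1.4, 8
scale_width = hidden_size / dim_model_base            # final hidden states divided by this
residual_scale = scale_depth / math.sqrt(num_layers)  # each block's output scaled by this

layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
aux_hidden_state_layers = (2, num_layers // 2, num_layers - 3)  # EAGLE3 default pattern

hidden_states = torch.randn(4, hidden_size)
aux_hidden_states = []
with torch.no_grad():
    for idx, layer in enumerate(layers):
        if idx in aux_hidden_state_layers:
            aux_hidden_states.append(hidden_states)  # residual handling simplified
        hidden_states = hidden_states + layer(hidden_states) * residual_scale

hidden_states = hidden_states / scale_width
print(len(aux_hidden_states), hidden_states.shape)  # 3 torch.Size([4, 64])
```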
"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" + from typing import Any, Optional import torch @@ -34,20 +35,23 @@ from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, - MiniCPMForCausalLM, - MiniCPMModel) +from vllm.model_executor.models.minicpm import ( + MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel, +) from .utils import make_layers class MiniCPM3Attention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -83,33 +87,37 @@ class MiniCPM3Attention(nn.Module): self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config) + self.q_a_proj = ReplicatedLinear( + self.hidden_size, self.q_lora_rank, bias=False, quant_config=quant_config + ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) - self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, - self.kv_lora_rank + - self.qk_rope_head_dim, - bias=False, - quant_config=quant_config) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, - quant_config=quant_config) + quant_config=quant_config, + ) # O projection. 
- self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) self.rotary_emb = get_rope( self.qk_rope_head_dim, @@ -118,13 +126,15 @@ class MiniCPM3Attention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -135,55 +145,52 @@ class MiniCPM3Attention(nn.Module): q = self.q_a_layernorm(q) q, _ = self.q_b_proj(q) q = q.view(-1, self.num_local_heads, self.qk_head_dim) - _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) kv, _ = self.kv_b_proj(kv_a) - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb( positions, q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim), - k_pe.reshape(-1, self.qk_rope_head_dim)) + k_pe.reshape(-1, self.qk_rope_head_dim), + ) q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim) k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe q = q.reshape(-1, self.num_local_heads * self.qk_head_dim) k = k.view(-1, self.num_local_heads * self.qk_head_dim) v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): - def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = MiniCPM3Attention( config=self.config, hidden_size=self.hidden_size, @@ -203,7 
+210,6 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): class MiniCPM3Model(MiniCPMModel): - def _init_layers( self, prefix: str, @@ -214,8 +220,10 @@ class MiniCPM3Model(MiniCPMModel): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MiniCPM3DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) class MiniCPM3ForCausalLM(MiniCPMForCausalLM): diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 06c2eb4e80afb..6c635b2481093 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only EagleMiniCPM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable from typing import Optional, Union @@ -37,21 +38,26 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .minicpm import MiniCPMAttention as EagleMiniCPMAttention from .minicpm import MiniCPMMLP as EagleMiniCPMMLP from .minicpm import MiniCPMMoE as EagleMiniCPMMoE -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) class EagleMiniCPMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -66,15 +72,15 @@ class EagleMiniCPMDecoderLayer(nn.Module): self.hidden_size = config.hidden_size self.rope_theta = getattr(config, "rope_theta", 10000) self.rope_scaling = getattr(config, "rope_scaling", None) - self.max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.prefix = prefix self._init_attn_block() self._init_ffn_block() def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = EagleMiniCPMAttention( hidden_size=self.hidden_size, num_heads=self.config.num_attention_heads, @@ -88,15 +94,16 @@ class EagleMiniCPMDecoderLayer(nn.Module): ) def _init_ffn_block(self): - self.post_attention_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = EagleMiniCPMMLP( hidden_size=self.hidden_size, intermediate_size=self.config.intermediate_size, hidden_act=self.config.hidden_act, - 
hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + hidden_act_param=getattr(self.config, "hidden_act_param", 0.0), quant_config=self.quant_config, ) else: @@ -104,7 +111,8 @@ class EagleMiniCPMDecoderLayer(nn.Module): num_experts=self.config.num_experts, top_k=self.config.num_experts_per_tok, hidden_size=self.config.hidden_size, - intermediate_size=self.config.intermediate_size) + intermediate_size=self.config.intermediate_size, + ) def forward( self, @@ -119,27 +127,26 @@ class EagleMiniCPMDecoderLayer(nn.Module): positions=positions, hidden_states=hidden_states, ) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.mup_denominator) + ) # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.mup_denominator) + ) return hidden_states, None @support_torch_compile class EagleMiniCPMModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer: int = 0): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0 + ): super().__init__() config = vllm_config.speculative_config.draft_model_config.hf_config @@ -150,13 +157,16 @@ class EagleMiniCPMModel(nn.Module): self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) self.input_norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.embed_tokens = VocabParallelEmbedding( @@ -165,12 +175,11 @@ class EagleMiniCPMModel(nn.Module): org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) - self._init_layers(prefix, config, cache_config, quant_config, - start_layer) + self._init_layers(prefix, config, cache_config, quant_config, start_layer) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) def _init_layers( self, @@ -180,14 +189,17 @@ class EagleMiniCPMModel(nn.Module): quant_config: Optional[QuantizationConfig], start_layer: int, ): - self.eagle_layers = nn.ModuleList([ - EagleMiniCPMDecoderLayer( - config, - cache_config, - quant_config, - f"{prefix}.eagle_layers.{i + start_layer}", - ) for i in range(self.config.num_hidden_layers) - ]) + self.eagle_layers = nn.ModuleList( + [ + EagleMiniCPMDecoderLayer( + config, + cache_config, + quant_config, + 
f"{prefix}.eagle_layers.{i + start_layer}", + ) + for i in range(self.config.num_hidden_layers) + ] + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -203,8 +215,7 @@ class EagleMiniCPMModel(nn.Module): input_embeds = self.input_norm1(input_embeds) hidden_states = self.input_norm2(hidden_states) - hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1)) residual = None for layer in self.eagle_layers: hidden_states, residual = layer( @@ -215,8 +226,7 @@ class EagleMiniCPMModel(nn.Module): return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -227,8 +237,11 @@ class EagleMiniCPMModel(nn.Module): ] expert_params_mapping = [ # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.num_experts) for weight_name in ["w1", "w2", "w3"] ] @@ -238,12 +251,11 @@ class EagleMiniCPMModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -265,10 +277,9 @@ class EagleMiniCPMModel(nn.Module): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -277,8 +288,9 @@ class EagleMiniCPMModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -320,11 +332,14 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - start_layer=target_layer_num) + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + start_layer=target_layer_num, + ) unpadded_vocab_size = config.vocab_size if lora_config: @@ -336,26 +351,26 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) - def _init_model(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer: int = 0): - return EagleMiniCPMModel(vllm_config=vllm_config, - prefix=prefix, - start_layer=start_layer) + def _init_model( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0 + ): + return EagleMiniCPMModel( + vllm_config=vllm_config, prefix=prefix, start_layer=start_layer + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -366,8 +381,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): positions: torch.Tensor, hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states, hidden_states2 = self.model(input_ids, positions, - hidden_states) + hidden_states, hidden_states2 = self.model(input_ids, positions, hidden_states) hidden_states = hidden_states / self.scale_width hidden_states2 = hidden_states2 / self.scale_width return hidden_states, hidden_states2 @@ -375,17 +389,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff 
--git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 225668d87facb..34f05122abe3a 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -23,41 +23,55 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Callable, Literal, Optional, Union import torch from torch import nn -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import (ACT2FN, - WhisperAttention, - WhisperConfig, - WhisperEncoder) +from transformers.models.whisper.modeling_whisper import ( + ACT2FN, + WhisperAttention, + WhisperConfig, + WhisperEncoder, +) from vllm.config import VllmConfig -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) +from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) -from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, - DictEmbeddingItems, ModalityData, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioItem, + AudioProcessorItems, + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, - MiniCPMVDummyInputsBuilder, - MiniCPMVMultiModalDataParser, - MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, - _minicpmv_field_config) -from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, - maybe_prefix) +from .minicpmv import ( + _MAX_FRAMES_PER_VIDEO, + MiniCPMV2_6, + MiniCPMVDummyInputsBuilder, + MiniCPMVMultiModalDataParser, + MiniCPMVMultiModalProcessor, + MiniCPMVProcessingInfo, + _minicpmv_field_config, +) +from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -71,6 +85,7 @@ class MiniCPMOAudioFeatureInputs(TensorSchema): - l: Length - s: Number of slices """ + type: Literal["audio_features"] = "audio_features" audio_features: Annotated[ @@ -99,9 +114,10 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema): - bn: Batch size * number of audios - s: Number of slices - h: Hidden size (must match language model backbone) - + Length of each slice may vary, so pass it as a list. 
""" + type: Literal["audio_embeds"] = "audio_embeds" audio_embeds: Annotated[ @@ -110,8 +126,7 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema): ] -MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, - MiniCPMOAudioEmbeddingInputs] +MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioEmbeddingInputs] def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): @@ -128,7 +143,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -146,7 +160,6 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): - def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], @@ -218,18 +231,17 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): max_image_tokens = self.get_max_image_tokens() * max_images max_audio_tokens = self.get_max_audio_tokens() * max_audios - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens - - max_audio_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self.get_max_video_frames( + seq_len - max_image_tokens - max_audio_tokens + ) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) -class MiniCPMODummyInputsBuilder( - MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): - +class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -241,28 +253,33 @@ class MiniCPMODummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) - audio_len = self.info.get_max_audio_chunks_with_most_features() * \ - self.info.get_default_audio_sampling_rate() + audio_len = ( + self.info.get_max_audio_chunks_with_most_features() + * self.info.get_default_audio_sampling_rate() + ) + + audio_overrides = mm_options.get("audio") if mm_options else None audio_mm_data = { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } return { - **super().get_dummy_mm_data(seq_len, mm_counts), + **super().get_dummy_mm_data(seq_len, mm_counts, mm_options), **audio_mm_data, } -class MiniCPMOMultiModalProcessor( - MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): - +class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMOMultiModalDataParser( - target_sr=self.info.get_default_audio_sampling_rate()) + target_sr=self.info.get_default_audio_sampling_rate() + ) def get_audio_prompt_texts( self, @@ -285,10 +302,11 @@ class MiniCPMOMultiModalProcessor( if (audios := mm_data.get("audios")) is None: return {} - parsed_audios = (self._get_data_parser().parse_mm_data({ - "audio": audios - }).get_items("audio", - (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))) + parsed_audios = ( + self._get_data_parser() + .parse_mm_data({"audio": audios}) + .get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) + ) if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): audio_inputs = {} @@ 
-296,9 +314,7 @@ class MiniCPMOMultiModalProcessor( audio_inputs = self._base_call_hf_processor( prompts=[self.info.audio_pattern] * len(parsed_audios), mm_data={"audios": [[audio] for audio in parsed_audios]}, - mm_kwargs={ - **mm_kwargs, "chunk_input": True - }, + mm_kwargs={**mm_kwargs, "chunk_input": True}, tok_kwargs=tok_kwargs, out_keys={"audio_features", "audio_feature_lens"}, ) @@ -306,7 +322,8 @@ class MiniCPMOMultiModalProcessor( # Avoid padding since we need the output for each audio to be # independent of other audios for the cache to work correctly unpadded_audio_features = [ - feat[:, :feature_len] for feat, feature_len in zip( + feat[:, :feature_len] + for feat, feature_len in zip( audio_inputs["audio_features"], audio_inputs["audio_feature_lens"], ) @@ -346,12 +363,14 @@ class MiniCPMOMultiModalProcessor( def get_audio_replacement(item_idx: int): audios = mm_items.get_items( - "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) + "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems) + ) if isinstance(audios, MiniCPMOAudioEmbeddingItems): single_audio_embeds = audios.get(item_idx)["audio_embeds"] audio_len = self.info.get_audio_len_by_num_chunks( - sum(map(len, single_audio_embeds))) + sum(map(len, single_audio_embeds)) + ) else: audio_len = audios.get_audio_length(item_idx) @@ -362,9 +381,11 @@ class MiniCPMOMultiModalProcessor( return [ *base_updates, - PromptReplacement(modality="audio", - target=audio_placeholder, - replacement=get_audio_replacement), + PromptReplacement( + modality="audio", + target=audio_placeholder, + replacement=get_audio_replacement, + ), ] def _get_mm_fields_config( @@ -376,16 +397,11 @@ class MiniCPMOMultiModalProcessor( class MultiModalProjector(nn.Module): - def __init__(self, in_dim: int, out_dim: int): super().__init__() - self.linear1 = nn.Linear(in_features=in_dim, - out_features=out_dim, - bias=True) + self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) self.relu = nn.ReLU() - self.linear2 = nn.Linear(in_features=out_dim, - out_features=out_dim, - bias=True) + self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: hidden_states = self.relu(self.linear1(audio_features)) @@ -394,7 +410,6 @@ class MultiModalProjector(nn.Module): class MiniCPMWhisperEncoderLayer(nn.Module): - def __init__(self, config: WhisperConfig, layer_idx: int): super().__init__() self.embed_dim = config.d_model @@ -426,39 +441,40 @@ class MiniCPMWhisperEncoderLayer(nn.Module): attention_mask=attention_mask, past_key_value=past_key_values, ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) hidden_states = residual + hidden_states if hidden_states.dtype == 
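The "avoid padding" comment in the hunk above refers to trimming each audio's features back to its true length so that one item's output never depends on how much padding its batch neighbors needed. A hedged sketch of that trimming step, with illustrative shapes:

```python
import torch

# Per-audio lengths are used to cut away batch padding from the padded
# feature tensor, keeping each item's features independent of the batch.
audio_features = torch.randn(3, 80, 3000)             # (num_audios, mel_bins, padded_len)
audio_feature_lens = torch.tensor([1200, 3000, 640])  # true frame count per audio

unpadded = [
    feat[:, :int(length)]
    for feat, length in zip(audio_features, audio_feature_lens)
]
print([tuple(t.shape) for t in unpadded])  # [(80, 1200), (80, 3000), (80, 640)]
```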
torch.float16: hidden_states = cast_overflow_tensors(hidden_states) - outputs = (hidden_states, ) + outputs = (hidden_states,) return outputs class MiniCPMWhisperEncoder(WhisperEncoder): - def __init__(self, config: WhisperConfig): super().__init__(config) - self.layers = nn.ModuleList([ - MiniCPMWhisperEncoderLayer(config, layer_idx=i) - for i in range(config.encoder_layers) - ]) + self.layers = nn.ModuleList( + [ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) + for i in range(config.encoder_layers) + ] + ) def forward( self, @@ -466,8 +482,9 @@ class MiniCPMWhisperEncoder(WhisperEncoder): attention_mask: Optional[torch.Tensor] = None, ) -> BaseModelOutputWithPast: # Ignore copy - input_features = input_features.to(dtype=self.conv1.weight.dtype, - device=self.conv1.weight.device) + input_features = input_features.to( + dtype=self.conv1.weight.dtype, device=self.conv1.weight.device + ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) @@ -476,17 +493,17 @@ class MiniCPMWhisperEncoder(WhisperEncoder): embed_pos = self.embed_positions.weight - embed_pos = embed_pos[:inputs_embeds.shape[1], :] + embed_pos = embed_pos[: inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) encoder_states = () for idx, encoder_layer in enumerate(self.layers): - encoder_states = encoder_states + (hidden_states, ) + encoder_states = encoder_states + (hidden_states,) to_drop = False if self.training: dropout_probability = torch.rand([]) @@ -505,7 +522,7 @@ class MiniCPMWhisperEncoder(WhisperEncoder): hidden_states = layer_outputs[0] hidden_states = self.layer_norm(hidden_states) - encoder_states = encoder_states + (hidden_states, ) + encoder_states = encoder_states + (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -516,7 +533,8 @@ class MiniCPMWhisperEncoder(WhisperEncoder): @MULTIMODAL_REGISTRY.register_processor( MiniCPMOMultiModalProcessor, info=MiniCPMOProcessingInfo, - dummy_inputs=MiniCPMODummyInputsBuilder) + dummy_inputs=MiniCPMODummyInputsBuilder, +) class MiniCPMO(MiniCPMV2_6): packed_modules_mapping = { "qkv_proj": [ @@ -543,56 +561,27 @@ class MiniCPMO(MiniCPMV2_6): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.apm = self.init_audio_module(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "apm")) + self.apm = self.init_audio_module( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") + ) self.audio_token_id = None - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid vision encoder sections for some models. - # See: https://huggingface.co/openbmb/MiniCPM-o-2_6-int4 - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config - - def init_vision_module( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> nn.Module: - # MiniCPMO GPTQ model leave vpm unquantized. 
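The encoder above mirrors Hugging Face's `WhisperEncoder` but collects the hidden state before every layer plus the final normalized output, so a caller can later select a specific layer (this file uses `audio_encoder_layer = -1`). A toy sketch of that collection pattern; the layers here are plain linear blocks standing in for Whisper layers:

```python
import torch
from torch import nn

# Run a layer stack, record the hidden state before each layer and the final
# normalized output, then select one entry (index -1 below).
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
layer_norm = nn.LayerNorm(16)

hidden_states = torch.randn(2, 5, 16)
encoder_states: tuple[torch.Tensor, ...] = ()
for layer in layers:
    encoder_states = encoder_states + (hidden_states,)
    hidden_states = torch.relu(layer(hidden_states))
encoder_states = encoder_states + (layer_norm(hidden_states),)

selected = encoder_states[-1]  # analogous to hidden_states[self.audio_encoder_layer]
print(len(encoder_states), selected.shape)  # 5 torch.Size([2, 5, 16])
```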
- quant_config = self._maybe_ignore_quant_config(quant_config) - return super().init_vision_module(config, quant_config, prefix) - - def init_resampler( - self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> nn.Module: - # MiniCPMO GPTQ model leave resampler unquantized. - quant_config = self._maybe_ignore_quant_config(quant_config) - return super().init_resampler(embed_dim, vision_dim, quant_config, - prefix) - def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): # Do not use parameters temporarily audio_config = self.config.audio_config model = MiniCPMWhisperEncoder(audio_config) audio_output_dim = int(audio_config.encoder_ffn_dim // 4) - self.audio_avg_pooler = \ - nn.AvgPool1d(self.config.audio_pool_step, - stride=self.config.audio_pool_step) - self.audio_projection_layer = \ - MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim) + self.audio_avg_pooler = nn.AvgPool1d( + self.config.audio_pool_step, stride=self.config.audio_pool_step + ) + self.audio_projection_layer = MultiModalProjector( + in_dim=audio_output_dim, out_dim=self.embed_dim + ) self.audio_encoder_layer = -1 return model - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) return loader.load_weights(weights) @@ -613,14 +602,13 @@ class MiniCPMO(MiniCPMV2_6): start_indices = torch.zeros_like(row_indices) else: # Compute start indices vectorially - start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, - min=0) + start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0) start_indices = start_chunk_indices * chunk_size # Compute ending indices vectorially end_chunk_indices = chunk_indices + 1 - end_indices = torch.clamp(end_chunk_indices * chunk_size + - num_lookhead, - max=size) + end_indices = torch.clamp( + end_chunk_indices * chunk_size + num_lookhead, max=size + ) # Create column indices for broadcasting col_indices = torch.arange(size, device=device).unsqueeze(0) start_indices = start_indices.unsqueeze(1) @@ -629,19 +617,18 @@ class MiniCPMO(MiniCPMV2_6): ret = (col_indices >= start_indices) & (col_indices < end_indices) return ret - def _get_feat_extract_output_lengths(self, - input_lengths: torch.LongTensor): + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( - input_lengths_after_cnn - - self.config.audio_pool_step) // self.config.audio_pool_step + 1 - input_lengths_after_pooling = input_lengths_after_pooling.to( - dtype=torch.int32) + input_lengths_after_cnn - self.config.audio_pool_step + ) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_audio_hidden_states( - self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]: + self, data: MiniCPMOAudioFeatureInputs + ) -> list[torch.Tensor]: chunk_length = self.config.audio_chunk_length # (bs, 80, frames) or [], multi audios need filled in advance @@ -670,23 +657,26 @@ class MiniCPMO(MiniCPMV2_6): max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = (torch.arange( - 0, - max_seq_len, - dtype=audio_feature_lens.dtype, - 
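`_get_feat_extract_output_lengths` above tracks how many frames survive the audio front end: the stride-2 convolution roughly halves the mel frame count, and the average pooler divides it again by `audio_pool_step`. A small standalone sketch of that arithmetic; the pool step of 2 is illustrative, not a model constant:

```python
import torch

# Frames remaining after a stride-2 conv, then after average pooling with the
# given step (illustrative step value).
def feat_extract_output_lengths(input_lengths: torch.Tensor, pool_step: int = 2):
    after_cnn = (input_lengths - 1) // 2 + 1
    after_pooling = (after_cnn - pool_step) // pool_step + 1
    return after_cnn, after_pooling.to(torch.int32)

print(feat_extract_output_lengths(torch.tensor([3000, 1500, 640])))
# (tensor([1500,  750,  320]), tensor([750, 375, 160], dtype=torch.int32))
```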
device=audio_feature_lens.device).unsqueeze(0).expand( - batch_size, max_seq_len)) - lengths_expand = audio_feature_lens.unsqueeze(1).expand( - batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values - audio_attention_mask_ = padding_mask.view( - batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, - max_seq_len) + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, - device=self.apm.conv1.weight.device) + dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device + ) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) @@ -697,20 +687,22 @@ class MiniCPMO(MiniCPMV2_6): device=audio_attention_mask_.device, ) audio_attention_mask_ = torch.logical_or( - audio_attention_mask_, torch.logical_not(chunk_mask)) + audio_attention_mask_, torch.logical_not(chunk_mask) + ) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( - wavforms, attention_mask=audio_attention_mask).hidden_states[ - self.audio_encoder_layer] + wavforms, attention_mask=audio_attention_mask + ).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) - _, feature_lens_after_pooling = \ - self._get_feat_extract_output_lengths(audio_feature_lens) + _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( + audio_feature_lens + ) num_audio_tokens = feature_lens_after_pooling @@ -720,7 +712,8 @@ class MiniCPMO(MiniCPMV2_6): target_audio_embeds_lst = list[torch.Tensor]() for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds_lst.append( - audio_embeds[idx, :num_audio_tokens[idx], :]) + audio_embeds[idx, : num_audio_tokens[idx], :] + ) idx += 1 final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) @@ -728,7 +721,8 @@ class MiniCPMO(MiniCPMV2_6): return final_audio_embeds def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]: + self, **kwargs: object + ) -> Optional[MiniCPMOAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) @@ -742,8 +736,9 @@ class MiniCPMO(MiniCPMV2_6): if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_embeds. " - f"Got type: {type(audio_embeds)}") + raise ValueError( + f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}" + ) audio_embeds_flat = flatten_bn(audio_embeds) @@ -753,13 +748,16 @@ class MiniCPMO(MiniCPMV2_6): ) if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_features. " - f"Got type: {type(audio_features)}") + raise ValueError( + f"Incorrect type of audio_features. Got type: {type(audio_features)}" + ) audio_feature_lens = kwargs.pop("audio_feature_lens") if not isinstance(audio_feature_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_feature_lens. 
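The `seq_range >= lengths_expand` comparison above is the usual broadcasting trick for turning per-item lengths into a padding mask. A compact sketch:

```python
import torch

# Positions at or beyond an item's true length are marked True (padding).
lengths = torch.tensor([5, 3, 8])
max_len = int(lengths.max())

seq_range = torch.arange(max_len).unsqueeze(0).expand(len(lengths), max_len)
padding_mask = seq_range >= lengths.unsqueeze(1)
print(padding_mask.int())
# tensor([[0, 0, 0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1, 1],
#         [0, 0, 0, 0, 0, 0, 0, 0]])
```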
" - f"Got type: {type(audio_feature_lens)}") + raise ValueError( + "Incorrect type of audio_feature_lens. " + f"Got type: {type(audio_feature_lens)}" + ) audio_features_flat = flatten_bn(audio_features) audio_feature_lens_flat = flatten_bn(audio_feature_lens) @@ -776,10 +774,11 @@ class MiniCPMO(MiniCPMV2_6): # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("audio_features", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("audio_features", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a2a71bdd12b36..09f973e98db99 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,42 +23,66 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" + import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import chain from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig -from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, - get_2d_sincos_pos_embed) +from vllm.model_executor.layers.resampler import ( + BaseResampler, + Resampler2, + get_2d_sincos_pos_embed, +) from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, - ImageProcessorItems, ImageSize, - ModalityData, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser, - VideoItem, VideoProcessorItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageItem, + ImageProcessorItems, + ImageSize, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, + VideoItem, + 
VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + ResolvedPromptUpdate, + _seq2text, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -66,10 +90,13 @@ from vllm.utils import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -118,45 +145,48 @@ class MiniCPMVImageEmbeddingInputs(TensorSchema): ] -MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, - MiniCPMVImageEmbeddingInputs] +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) class Resampler2_5(BaseResampler): - - def __init__(self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: tuple[int, int] = (70, 70), - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(num_queries, - embed_dim, - num_heads, - kv_dim, - norm_layer, - quant_config=quant_config, - prefix=prefix) + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix, + ) self.max_size = max_size self._set_2d_pos_cache(self.max_size) - def _set_2d_pos_cache(self, - max_size: tuple[int, int], - device: torch.types.Device = "cpu") -> None: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - max_size, - version=(2, 5)) + def _set_2d_pos_cache( + self, max_size: tuple[int, int], device: torch.types.Device = "cpu" + ) -> None: + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, max_size, version=(2, 5) + ) pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, - device: torch.types.Device) -> None: + def _adjust_pos_cache( + self, tgt_sizes: torch.Tensor, device: torch.types.Device + ) -> None: max_h = tgt_sizes[:, 0].max().item() max_w = tgt_sizes[:, 1].max().item() assert isinstance(max_h, int) and isinstance(max_w, int) @@ -168,8 +198,7 @@ class Resampler2_5(BaseResampler): ) self._set_2d_pos_cache(self.max_size, device) - def forward(self, x: torch.Tensor, - tgt_sizes: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -183,21 +212,20 @@ class Resampler2_5(BaseResampler): max_patch_len = patch_len.max().item() assert isinstance(max_patch_len, int) - key_padding_mask = torch.zeros((bs, max_patch_len), - 
dtype=torch.bool, - device=device) + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) pos_embed = [] for i in range(bs): tgt_h, tgt_w = tgt_sizes[i].tolist() - pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( - (tgt_h * tgt_w, -1)).to(dtype)) # patches * D - key_padding_mask[i, patch_len[i]:] = True - pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, - batch_first=True, - padding_value=0.0).permute( - 1, 0, - 2) # BLD => L * B * D + pos_embed.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True + pos_embed = torch.nn.utils.rnn.pad_sequence( + pos_embed, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # BLD => L * B * D x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D @@ -217,6 +245,200 @@ class Resampler2_5(BaseResampler): return x +class Resampler4_5(Resampler2_5): + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix, + ) + + trunc_normal_(self.query, std=0.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size( + self, embed_dim: int, pos: np.ndarray + ): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache( + self, max_temporal_size: int, device: torch.types.Device = "cpu" + ) -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = ( + torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size + ) + ) + .float() + .to(device) + ) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache( + self, max_temporal_size: int, device: torch.types.Device = "cpu" + ): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None, + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + 
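The temporal position table introduced by `Resampler4_5` above is the standard 1-D sin/cos encoding. A standalone re-derivation of the same math, outside the class:

```python
import numpy as np

# For position m and frequency omega_d = 1 / 10000**(2d/D), the embedding is
# the concatenation of sin(m * omega) and cos(m * omega).
def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega                           # (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)   # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

print(sincos_1d(8, np.arange(4, dtype=np.float32)).shape)  # (4, 8)
```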
self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q * D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, device=device) + ) + else: + pos_embed_temporal.append( + self.temporal_pos_embed[temporal_ids_flatten[i]].to(dtype) + ) # D + + pos_embed_2d.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append( + k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim) + ) + merge_v.append( + v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim) + ) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1) + ) + + start = end + + k = torch.nn.utils.rnn.pad_sequence( + merge_k, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence( + merge_v, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, padding_value=True + ).squeeze(-1) + + out = self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -252,7 +474,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -274,7 +495,6 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -299,7 +519,6 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): class MiniCPMVMultiModalDataParser(MultiModalDataParser): - def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -353,9 +572,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = 
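When temporal ids group several frames per video, the forward pass above concatenates each group's keys and values and pads them into one batch with `pad_sequence`, padding the key mask with `True` so attention ignores the filler. A minimal sketch of that padding step in isolation:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length per-group key tensors are padded into one batch; padded
# positions in the boolean mask are True so attention skips them.
groups = [torch.randn(n, 8) for n in (3, 5, 2)]            # (len_i, D) per group
masks = [torch.zeros(n, dtype=torch.bool) for n in (3, 5, 2)]

keys = pad_sequence(groups, batch_first=True, padding_value=0.0)        # (B, L, D)
key_padding_mask = pad_sequence(masks, batch_first=True, padding_value=True)

print(keys.shape, key_padding_mask.int())
```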
{"image": None} - if self.get_model_version() == (2, - 6) or self.get_model_version() == (4, - 0): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -470,21 +687,18 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) -_I = TypeVar("_I", - bound=MiniCPMVProcessingInfo, - default=MiniCPMVProcessingInfo) +_I = TypeVar("_I", bound=MiniCPMVProcessingInfo, default=MiniCPMVProcessingInfo) class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -498,51 +712,59 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - image_width, image_height = \ - self.info.get_image_size_with_most_features() - video_width, video_height = \ - self.info.get_video_frame_size_with_most_features() - num_video_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + image_width, image_height = self.info.get_image_size_with_most_features() + video_width, video_height = self.info.get_video_frame_size_with_most_features() + num_video_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=image_width, - height=image_height, - num_images=num_images), + "image": self._get_dummy_images( + width=image_width, + height=image_height, + num_images=num_images, + overrides=image_overrides, + ), "video": [ - self._get_dummy_images(width=video_width, - height=video_height, - num_images=num_video_frames) - ] * num_videos, + self._get_dummy_images( + width=video_width, + height=video_height, + num_images=num_video_frames, + overrides=video_overrides, + ) + ] + * num_videos, } class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): - def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMVMultiModalDataParser() - def get_image_prompt_texts(self, - image_size: ImageSize, - image_idx: int = 0) -> str: + def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: return self.info.get_slice_image_placeholder( image_size, image_idx=image_idx, ) - def get_video_prompt_texts(self, image_size: ImageSize, - num_frames: int) -> str: - return self.info.get_slice_image_placeholder( - image_size=image_size, - image_idx=0, - max_slice_nums=self.info.get_video_max_slice_num(), - use_image_id=False, - ) * num_frames + def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: + return ( + self.info.get_slice_image_placeholder( + image_size=image_size, + image_idx=0, + max_slice_nums=self.info.get_video_max_slice_num(), + use_image_id=False, + ) + * num_frames + ) 
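For the video limit above, the processing info converts the token budget left after images into a per-video frame cap, clamped by `_MAX_FRAMES_PER_VIDEO` and floored at one frame. A hedged sketch with illustrative numbers; the real tokens-per-frame value comes from the HF processor, not from a constant like this:

```python
# Illustrative constants; the real values come from the processor config.
_MAX_FRAMES_PER_VIDEO = 16

def frames_per_video(seq_len: int, max_image_tokens: int, num_videos: int,
                     tokens_per_frame: int = 64) -> int:
    max_total_frames = max((seq_len - max_image_tokens) // tokens_per_frame, 0)
    per_video = min(max_total_frames // max(num_videos, 1), _MAX_FRAMES_PER_VIDEO)
    return max(per_video, 1)

print(frames_per_video(seq_len=8192, max_image_tokens=2048, num_videos=2))  # 16
```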
def process_images( self, @@ -553,10 +775,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): if (images := mm_data.get("images")) is None: return {} - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": images - }).get_items("image", - (MiniCPMVImageEmbeddingItems, ImageProcessorItems))) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)) + ) if isinstance(parsed_images, MiniCPMVImageEmbeddingItems): image_inputs = {} @@ -584,24 +807,23 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): if (videos := mm_data.get("videos")) is None: return {} - parsed_videos = (self._get_data_parser().parse_mm_data({ - "video": videos - }).get_items("video", - (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))) + parsed_videos = ( + self._get_data_parser() + .parse_mm_data({"video": videos}) + .get_items("video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)) + ) if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems): video_inputs = {} else: video_inputs = self._base_call_hf_processor( prompts=[ - self.info.image_pattern * len(video) - for video in parsed_videos + self.info.image_pattern * len(video) for video in parsed_videos ], mm_data={"images": list(parsed_videos)}, mm_kwargs={ **mm_kwargs, - "max_slice_nums": - self.info.get_video_max_slice_num(), + "max_slice_nums": self.info.get_video_max_slice_num(), }, tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, @@ -636,8 +858,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == ( - 2, 6) or self.info.get_model_version() == (4, 0): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -650,10 +871,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for i, prompt in enumerate(prompts): inputs_one = super()._call_hf_processor( prompt=prompt, - mm_data={ - k: v[i] - for k, v in mm_data.items() - }, + mm_data={k: v[i] for k, v in mm_data.items()}, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) @@ -676,10 +894,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)]) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs) - return BatchFeature({ - "input_ids": input_ids, - **mm_inputs, - }) + return BatchFeature( + { + "input_ids": input_ids, + **mm_inputs, + } + ) def _hf_processor_applies_updates( self, @@ -696,22 +916,26 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - placeholders = [("image", self.info.image_pattern), - ("video", self.info.video_pattern)] + placeholders = [ + ("image", self.info.image_pattern), + ("video", self.info.video_pattern), + ] # hard code for inconsistency of encode-decode image_pattern additional_placeholders = [] tokenizer = self.info.get_tokenizer() for modality, pattern in placeholders: sub_pattern = tokenizer.decode( - tokenizer.encode(pattern, add_special_tokens=False)) + tokenizer.encode(pattern, add_special_tokens=False) + ) if sub_pattern != pattern: additional_placeholders.append((modality, sub_pattern)) placeholders += 
additional_placeholders def get_image_replacement(item_idx: int): images = mm_items.get_items( - "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)) + "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems) + ) image_size = images.get_image_size(item_idx) @@ -722,7 +946,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): def get_video_replacement(item_idx: int): videos = mm_items.get_items( - "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)) + "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems) + ) frame_size = videos.get_frame_size(item_idx) num_frames = videos.get_num_frames(item_idx) @@ -738,12 +963,50 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): } return [ - PromptReplacement(modality=modality, - target=pattern, - replacement=get_replacement[modality]) + PromptReplacement( + modality=modality, target=pattern, replacement=get_replacement[modality] + ) for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", + 1, + ), + "<unk>", + ) + ) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -758,6 +1021,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): instantiated. 
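The cache-recompute hook above only needs to renumber the image index embedded between the id markers; the rest of the cached replacement text is reusable as-is. A small sketch of that substitution, using placeholder marker strings rather than the tokenizer's real `im_start`/`im_end` tokens:

```python
# Illustrative marker strings; the real ones come from the image processor
# (im_start_token/im_end_token or im_id_start/im_id_end, depending on version).
IM_START, IM_END = "<image_id>", "</image_id>"

def renumber(cached_text: str, prev_idx: int, new_idx: int) -> str:
    return cached_text.replace(
        f"{IM_START}{prev_idx}{IM_END}", f"{IM_START}{new_idx}{IM_END}", 1
    )

print(renumber("<image_id>0</image_id><unk><unk>", prev_idx=0, new_idx=3))
# <image_id>3</image_id><unk><unk>
```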
""" + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -771,6 +1036,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config quant_config = vllm_config.quant_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot @@ -778,27 +1044,30 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): # and config class self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.version = get_version_by_config(self.config) - self.llm = self.init_llm(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "llm")) - self.vpm = self.init_vision_module(config, - quant_config, - prefix=maybe_prefix(prefix, "vpm")) - self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else - self.vpm.embeddings.embed_dim) + self.llm = self.init_llm( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm") + ) + self.vpm = self.init_vision_module( + config, quant_config, prefix=maybe_prefix(prefix, "vpm") + ) + self.vision_dim = ( + self.vpm.embed_dim + if self.version == (2, 0) + else self.vpm.embeddings.embed_dim + ) self.embed_dim = self.config.hidden_size - self.resampler = self.init_resampler(self.embed_dim, - self.vision_dim, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "resampler")) + self.resampler = self.init_resampler( + self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "resampler"), + ) self.mm_token_ids = set[int]() - self.make_empty_intermediate_tensors = ( - self.llm.make_empty_intermediate_tensors) + self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors def _parse_and_validate_vision_input( self, @@ -820,7 +1089,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError( f"Incorrect type of image_embeds for {modality=}. " - f"Got type: {type(image_embeds)}") + f"Got type: {type(image_embeds)}" + ) image_embeds_flat = flatten_bn(image_embeds) @@ -832,12 +1102,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError( f"Incorrect type of pixel_values for {modality=}. " - f"Got type: {type(pixel_values)}") + f"Got type: {type(pixel_values)}" + ) tgt_sizes = kwargs.pop("tgt_sizes") if not isinstance(tgt_sizes, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. " - f"Got type: {type(tgt_sizes)}") + raise ValueError( + f"Incorrect type of tgt_sizes for {modality=}. " + f"Got type: {type(tgt_sizes)}" + ) num_slices = [[len(p) for p in ps] for ps in pixel_values] num_slices_flat = flatten_bn(torch.tensor(num_slices)) @@ -858,12 +1131,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): modalities["images"] = self._parse_and_validate_vision_input( - "images", **kwargs) - if input_key in ("video_pixel_values", - "video_embeds") and "videos" not in modalities: + "images", **kwargs + ) + if ( + input_key in ("video_pixel_values", "video_embeds") + and "videos" not in modalities + ): def _image_key(video_key: str): if video_key == "video_token_id": @@ -872,10 +1150,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): return video_key.removeprefix("video_") modalities["videos"] = self._parse_and_validate_vision_input( - "videos", **{ - _image_key(k): v - for k, v in kwargs.items() - }) + "videos", **{_image_key(k): v for k, v in kwargs.items()} + ) return modalities @@ -889,14 +1165,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): image_features_flat = self.get_vision_hidden_states(image_input) num_slices = image_input["num_slices"] - return [ - e.flatten(0, 1) - for e in image_features_flat.split(num_slices.tolist()) - ] + return [e.flatten(0, 1) for e in image_features_flat.split(num_slices.tolist())] def _process_multimodal_inputs(self, modalities: dict): # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -916,31 +1189,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def get_language_model(self) -> torch.nn.Module: return self.llm - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] return self._process_multimodal_inputs(modalities) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert len(self.mm_token_ids) > 0 - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - list(self.mm_token_ids), - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -952,16 +1207,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
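The embedding path above relies on the slice counts recorded during input validation: the vision tower returns one flat batch of slice features, which is split back per image and flattened into a single embedding per item. A self-contained sketch of that bookkeeping with illustrative shapes:

```python
import torch

# Three images cut into 4, 1 and 2 slices; the vision tower returns a flat
# (total_slices, Q, D) tensor that must be regrouped per image.
Q, D = 4, 8
num_slices = torch.tensor([4, 1, 2])
features_flat = torch.randn(int(num_slices.sum()), Q, D)

per_image = [
    e.flatten(0, 1)                        # (slices, Q, D) -> (slices * Q, D)
    for e in features_flat.split(num_slices.tolist())
]
print([tuple(t.shape) for t in per_image])  # [(16, 8), (4, 8), (8, 8)]
```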
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - hidden_states = self.llm.model( input_ids=input_ids, positions=positions, @@ -973,12 +1218,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.llm.compute_logits(hidden_states, sampling_metadata) + return self.llm.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -986,9 +1229,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field(language_model="llm", - connector="resampler", - tower_model="vpm") + return MultiModelKeys.from_string_field( + language_model="llm", connector="resampler", tower_model="vpm" + ) def init_llm( self, @@ -1005,19 +1248,21 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ) -> nn.Module: raise NotImplementedError - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: raise NotImplementedError - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: raise NotImplementedError class MiniCPMV2_0(MiniCPMVBaseModel): + supports_encoder_tp_data = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1053,8 +1298,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel): model = model.to(dtype=torch.get_default_dtype()) - if (isinstance(model, timm.models.VisionTransformer) - and model.attn_pool is not None): + if ( + isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None + ): model.attn_pool = torch.nn.Identity() if self.config.drop_vision_last_layer: @@ -1062,27 +1309,30 @@ class MiniCPMV2_0(MiniCPMVBaseModel): return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2(embed_dim=embed_dim, - num_heads=embed_dim // 128, - grid_size=int( - math.sqrt(self.config.query_num)), - kv_dim=vision_dim, - adaptive=False, - do_post_projection=True, - quant_config=quant_config, - prefix=prefix) + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + quant_config=quant_config, + prefix=prefix, + ) - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) + return resampler.to( + device=current_platform.device_type, 
dtype=torch.get_default_dtype() + ) - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] P_h, P_w = self.vpm.patch_embed.patch_size @@ -1094,7 +1344,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel): H, W = pixel_value[0].shape[-2:] tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w)) vision_embedding = self.vpm.forward_features( - pixel_value.unsqueeze(0).type(dtype)) + pixel_value.unsqueeze(0).type(dtype) + ) if num_prefix_tokens > 0: vision_embedding = vision_embedding[:, num_prefix_tokens:] @@ -1133,31 +1384,38 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1167,9 +1425,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1178,9 +1434,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1223,32 +1477,39 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> nn.Module: - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + 
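The `get_vision_hidden_states` implementations above all share the same batching trick: slices with different patch counts are copied into one zero-padded buffer, and a boolean mask tells the vision encoder which positions hold real patches. A hedged sketch with a simplified tensor layout; the real code derives patch counts from `tgt_sizes`:

```python
import torch

# Variable-length slice tensors (C, P, L_i) packed into one zero-padded buffer
# with a per-row boolean mask over the valid patch positions.
items = [torch.randn(3, 196, n) for n in (6, 10, 4)]
B, L = len(items), max(t.shape[-1] for t in items)

all_pixel_values = torch.zeros((B, 3, 196, L))
for i, item in enumerate(items):
    all_pixel_values[i, ..., : item.shape[-1]] = item

num_patches = torch.tensor([t.shape[-1] for t in items])
patch_attn_mask = torch.zeros((B, int(num_patches.max())), dtype=torch.bool)
for i, n in enumerate(num_patches):
    patch_attn_mask[i, :n] = True

print(all_pixel_values.shape, patch_attn_mask.int())
```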
use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. - resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1258,9 +1519,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1269,9 +1528,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1283,10 +1540,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): return self.resampler(vision_embedding, tgt_sizes) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm.", "audio", "tts"]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) return loader.load_weights(weights) @@ -1330,7 +1585,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): config.vision_config, quant_config=quant_config, prefix=prefix, - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1345,18 +1601,20 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): quant_config = self._maybe_ignore_quant_config(quant_config) with set_default_torch_dtype(torch.float16): # The resampler in 4.0 remains consistent with the one in 2.5/2.6. 
- resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1366,9 +1624,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1377,9 +1633,7 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1391,10 +1645,117 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): return self.resampler(vision_embedding, tgt_sizes) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm.", "audio", "tts"]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel, + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with 
set_default_torch_dtype(torch.float16): + # The resampler in 4.0 remains consistent with the one in 2.5/2.6. + resampler = Resampler4_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get("temporal_ids", None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) + all_temporal_ids = ( + None if temporal_ids is None else flatten_2d_lists(temporal_ids) + ) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) return loader.load_weights(weights) @@ -1403,13 +1764,15 @@ _SUPPORT_VERSION = { (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } @MULTIMODAL_REGISTRY.register_processor( MiniCPMVMultiModalProcessor, info=MiniCPMVProcessingInfo, - dummy_inputs=MiniCPMVDummyInputsBuilder) + dummy_inputs=MiniCPMVDummyInputsBuilder, +) class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs, @@ -1431,9 +1794,12 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): instance_cls = _SUPPORT_VERSION.get(version) if instance_cls is None: supported_versions = ", ".join( - [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())]) - raise ValueError(f"Currently, MiniCPMV only supports versions " - f"{supported_versions}. Got version: {version}") + [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())] + ) + raise ValueError( + f"Currently, MiniCPMV only supports versions " + f"{supported_versions}. 
Got version: {version}" + ) # quant_config references base class members, # so update values before init is called diff --git a/vllm/model_executor/models/minimax_cache.py b/vllm/model_executor/models/minimax_cache.py deleted file mode 100644 index 9164ac06a3b0a..0000000000000 --- a/vllm/model_executor/models/minimax_cache.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass - -import torch - -from vllm.model_executor.models.constant_size_cache import ConstantSizeCache - - -@dataclass -class MinimaxCacheParams: - minimax_cache: torch.Tensor = torch.Tensor() - state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self, layer_idx): - return MinimaxCacheParams(self.minimax_cache[layer_idx, ...], - self.state_indices_tensor) - - -class MinimaxCacheManager(ConstantSizeCache): - - def __init__(self, dtype, cache_shape): - super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1] - self._minimax_cache = torch.empty(size=cache_shape, - dtype=dtype, - device="cuda") - - @property - def cache(self): - return self._minimax_cache - - def _copy_cache(self, from_index: int, to_index: int): - assert len(self.cache) > 0 - for cache_t in self.cache: - cache_t[:, to_index].copy_(cache_t[:, from_index], - non_blocking=True) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 82e96844cd5f6..e6e0952f71dd6 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,77 +1,75 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy -import math + from collections.abc import Iterable -from typing import Optional, Union +from itertools import islice +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + pass import regex as re import torch import torch.distributed -import torch.nn.functional as F -from einops import rearrange from torch import nn from transformers import MiniMaxConfig -from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import ( - get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import get_forward_context -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.lightning_attn import ( - lightning_attention, linear_decode_forward_triton) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import 
LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from .interfaces import HasInnerState, IsHybrid -from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -def replace_weight_name(name: str, - key: str = None, - to: str = None, - count: int = None, - prefix: str = None) -> str: - name = name.replace(key, to) if count is None else \ - name.replace(key, to, count) +def replace_weight_name( + name: str, key: str = None, to: str = None, count: int = None, prefix: str = None +) -> str: + name = name.replace(key, to) if count is None else name.replace(key, to, count) return name def weight_loader_with_alias(alias: str): - def wrapper(func: callable): - - def inner_func(param: torch.Tensor, - loaded_weight: torch.Tensor, - *args, - prefix: str = None, - **kwargs): + def inner_func( + param: torch.Tensor, + loaded_weight: torch.Tensor, + *args, + prefix: str = None, + **kwargs, + ): value = func(param, loaded_weight, *args, **kwargs) return value @@ -80,123 +78,7 @@ def weight_loader_with_alias(alias: str): return wrapper -class MiniMaxText01RMSNormTP(CustomOp): - name = "MiniMaxText01RMSNormTP" - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.tp_world = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.weight = nn.Parameter(torch.ones(int(hidden_size / - self.tp_world))) - - self.weight.weight_loader = self.weight_loader - self.variance_epsilon = eps - return - - @staticmethod - def weight_loader( - param: nn.Parameter, - loaded_weight: torch.Tensor, - ) -> None: - tp_world = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - shard_size = loaded_weight.shape[0] // tp_world - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - param.data.copy_(loaded_weight[shard]) - return - - def _forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - orig_dtype = x.dtype - x = x.to(torch.float32) - variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) - if self.tp_world > 1: - variance = tensor_model_parallel_all_reduce( - variance) / self.tp_world - x = x * torch.rsqrt(variance + self.variance_epsilon) - - weight = self.weight - if x.size(-1) != self.weight.size(0): - if self.weight.size(0) < x.size(-1): - repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) - full_weight = self.weight.repeat(repeat_count) - weight = 
full_weight[:x.size(-1)] - else: - weight = self.weight[:x.size(-1)] - - x = x.to(orig_dtype) * weight - return x - - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - assert residual is None, "RMSNorm does not support residual connection." - return self._forward(x) - - -class MiniMaxText01RotaryEmbedding(CustomOp): - name = "MiniMaxText01RotaryEmbedding" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position: int, - base: float, - is_neox_style: bool, - cache_dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position - self.base = base - self.is_neox_style = is_neox_style - self.cache_dtype = cache_dtype - cache = self._compute_cos_sin_cache().to(cache_dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: float) -> torch.Tensor: - """Compute the inverse frequency.""" - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) - query_cast = query.to(self.cache_dtype) - key_cast = key.to(self.cache_dtype) - ops.rotary_embedding(positions, query_cast, key_cast, self.head_size, - self.cos_sin_cache, self.is_neox_style) - query = query_cast.to(query.dtype) - key = key_cast.to(key.dtype) - return query, key - - class MiniMaxText01MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -226,7 +108,6 @@ class MiniMaxText01MLP(nn.Module): return def forward(self, x: torch.Tensor) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -234,7 +115,6 @@ class MiniMaxText01MLP(nn.Module): class MiniMaxText01MoE(nn.Module): - def __init__( self, num_experts: int, @@ -285,8 +165,7 @@ class MiniMaxText01MoE(nn.Module): return @staticmethod - def gate_weight_loader(param: nn.Parameter, - loaded_weight: torch.Tensor) -> None: + def gate_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight.to(torch.float32)) return @@ -296,291 +175,13 @@ class MiniMaxText01MoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32)) final_hidden_states = self.experts( - hidden_states, router_logits_fp32.to(hidden_states.dtype)) + hidden_states, router_logits_fp32.to(hidden_states.dtype) + ) final_hidden = final_hidden_states.view(num_tokens, hidden_size) return final_hidden -class MiniMaxText01LinearKernel: - - @staticmethod - def jit_linear_forward_prefix(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_caches: torch.Tensor, - slope_rate: torch.Tensor, - block_size: int, - layer_idx: int = None, - **kwargs) -> torch.Tensor: - - slope_rate = slope_rate.to(torch.float32) - 
should_pad_dim = q.dim() == 3 - if should_pad_dim: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - b, h, n, d = q.shape - e = d - kv_history = kv_caches.reshape(1, h, d, e).contiguous() - output, kv_history = lightning_attention(q, - k, - v, - slope_rate, - block_size=block_size, - kv_history=kv_history) - kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) - assert output.shape[0] == 1, "batch size must be 1" - return rearrange(output.squeeze(0), "h n d -> n (h d)") - - -class MiniMaxText01LinearAttention(nn.Module, MambaBase): - - @property - def mamba_type(self) -> str: - return "linear_attention" - - def get_state_dtype(self) -> tuple[torch.dtype]: - return MambaStateDtypeCalculator.linear_attention_state_dtype( - self.model_config.dtype, - self.cache_config.mamba_cache_dtype, - ) - - def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return MambaStateShapeCalculator.linear_attention_state_shape( - num_heads=self.num_heads, - tp_size=self.tp_size, - head_dim=self.head_dim) - - def __init__( - self, - hidden_size: int, - hidden_inner_size: int, - num_heads: int, - head_dim: int, - max_position: int, - block_size: int, - num_hidden_layer: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - layer_idx: int = 0, - linear_layer_idx: int = 0, - prefix: str = "linear_attn", - ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.BLOCK = block_size - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = head_dim - self.total_num_heads = num_heads - self.hidden_inner_size = hidden_inner_size - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size - self.qkv_size = self.num_heads * self.head_dim - self.tp_hidden = self.head_dim * self.tp_heads - self.model_config = model_config - self.cache_config = cache_config - self.prefix = prefix - - self.qkv_proj = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size * 3, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.output_gate = ColumnParallelLinear( - hidden_size, - self.hidden_inner_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.output_gate", - ) - self.out_proj = RowParallelLinear( - self.hidden_inner_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.norm = MiniMaxText01RMSNormTP( - self.hidden_inner_size, - eps=1e-5, - ) - - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.num_heads) - if num_hidden_layer <= 1: - self.slope_rate = slope_rate * (1 + 1e-5) - else: - self.slope_rate = slope_rate * (1 - layer_idx / - (num_hidden_layer - 1) + 1e-5) - self.tp_slope = self.slope_rate[self.tp_rank * - self.tp_heads:(self.tp_rank + 1) * - self.tp_heads].contiguous() - - if envs.VLLM_USE_V1: - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - @staticmethod - def weight_direct_load(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - return - - @staticmethod - def _build_slope_tensor(n_attention_heads: 
int): - - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.tensor(get_slopes(n_attention_heads), - dtype=torch.float32).reshape( - n_attention_heads, 1, 1) - return slopes - - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - hidden = [] - for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx >= len(attn_metadata.query_start_loc): - break - if _prefill_idx >= len(state_indices_tensor): - break - # prefills are packed at end of batch in V1 - offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0 - _start = attn_metadata.query_start_loc[offset + _prefill_idx] - _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] - slot_id = state_indices_tensor[offset + _prefill_idx] - qs = q[_start:_end].transpose(0, 1).contiguous() - ks = k[_start:_end].transpose(0, 1).contiguous() - vs = v[_start:_end].transpose(0, 1).contiguous() - slice_layer_cache = kv_cache[slot_id, ...] - - out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( - qs, - ks, - vs, - slice_layer_cache, - self.tp_slope, - self.BLOCK, - layer_idx=self.layer_idx) - hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden_decode = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - if envs.VLLM_USE_V1: - hidden.insert(0, hidden_decode) - else: - hidden.append(hidden_decode) - - if not hidden: - return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) - - hidden = torch.concat(hidden, dim=0).contiguous() - return hidden - - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - if not envs.VLLM_USE_V1: - q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous() - num_prefills = getattr(attn_metadata, "num_prefills", 0) - slot_id = state_indices_tensor[num_prefills:] - else: - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decodes] - hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, - slot_id, 32) - return hidden - - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - qkv32 = qkv.to(torch.float32) - qkvact = torch.nn.functional.silu(qkv32) - qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) - q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if envs.VLLM_USE_V1: - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) - kv_cache = self.kv_cache[forward_context.virtual_engine][0] - state_indices_tensor = attn_metadata.state_indices_tensor - - 
num_prefills = getattr(attn_metadata, "num_prefills", 0) - if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, - "num_decode_tokens", 0) - for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx] - q_end = attn_metadata.query_start_loc[num_decode_tokens - + prefill_idx + - 1] - query_len = q_end - q_start - context_len = attn_metadata.seq_lens[ - num_decode_tokens + prefill_idx] - query_len - if context_len == 0: - block_to_clear = state_indices_tensor[ - num_decode_tokens + prefill_idx] - kv_cache[block_to_clear, ...] = 0 - else: - kv_cache = kv_caches.minimax_cache - state_indices_tensor = kv_caches.state_indices_tensor - - decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if attn_metadata is None: - hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), - device=q.device, - dtype=q.dtype) - else: - if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) - - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) - hidden = F.sigmoid(gate) * hidden - hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden - - class MiniMaxText01Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -644,27 +245,31 @@ class MiniMaxText01Attention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.attn", ) + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position, + base=int(rope_theta), + is_neox_style=True, + dtype=torch.float32, + ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + **kwargs, + ) -> None: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if envs.VLLM_USE_V1: - if attn_metadata is not None: - q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb( - positions, q, k) - else: - q, k = attn_metadata.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): - def __init__( self, config: MiniMaxConfig, @@ -689,14 +294,17 @@ class MiniMaxText01DecoderLayer(nn.Module): head_dim = getattr(config, "head_dim", None) if head_dim is None: head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) + if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): + max_position_embeddings = min( + config.max_position_embeddings, config.max_model_len + ) if config.attention_type == 0: use_headxdim = True - hidden_inner = (head_dim * config.num_attention_heads - if use_headxdim else config.hidden_size) + hidden_inner = ( + head_dim * config.num_attention_heads + if use_headxdim + else config.hidden_size + ) self.self_attn = MiniMaxText01LinearAttention( hidden_size=self.hidden_size, hidden_inner_size=hidden_inner, @@ -710,14 +318,16 @@ class 
MiniMaxText01DecoderLayer(nn.Module): quant_config=quant_config, layer_idx=self._ilayer, linear_layer_idx=linear_layer_id, - prefix=prefix) + prefix=prefix, + ) elif config.attention_type == 1: self.self_attn = MiniMaxText01Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, head_dim=head_dim, rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, + if hasattr(config, "rotary_dim") + else head_dim, num_kv_heads=config.num_key_value_heads, max_position=max_position_embeddings, rope_theta=rope_theta, @@ -725,10 +335,12 @@ class MiniMaxText01DecoderLayer(nn.Module): quant_config=quant_config, layer_idx=self._ilayer, cache_config=cache_config, - prefix=prefix) + prefix=prefix, + ) else: raise ValueError( - f"Unsupported attention type: {self.config.attention_type}") + f"Unsupported attention type: {self.config.attention_type}" + ) if expert_num == 1: self.mlp = MiniMaxText01MLP( @@ -736,7 +348,8 @@ class MiniMaxText01DecoderLayer(nn.Module): intermediate_size=config.intermediate_size, quant_config=quant_config, layer_idx=self._ilayer, - prefix=prefix) + prefix=prefix, + ) else: self.block_sparse_moe = MiniMaxText01MoE( num_experts=expert_num, @@ -745,39 +358,51 @@ class MiniMaxText01DecoderLayer(nn.Module): intermediate_size=config.intermediate_size, layer_idx=self._ilayer, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) if config.attention_type == 0: self.layernorm_attention_alpha = getattr( - config, 'layernorm_linear_attention_alpha', - getattr(config, 'linear_attn_alpha_factor', 1)) + config, + "layernorm_linear_attention_alpha", + getattr(config, "linear_attn_alpha_factor", 1), + ) self.layernorm_attention_beta = getattr( - config, 'layernorm_linear_attention_beta', - getattr(config, 'linear_attn_beta_factor', 1)) + config, + "layernorm_linear_attention_beta", + getattr(config, "linear_attn_beta_factor", 1), + ) else: self.layernorm_attention_alpha = getattr( - config, 'layernorm_full_attention_alpha', - getattr(config, 'full_attn_alpha_factor', 1)) + config, + "layernorm_full_attention_alpha", + getattr(config, "full_attn_alpha_factor", 1), + ) self.layernorm_attention_beta = getattr( - config, 'layernorm_full_attention_beta', - getattr(config, 'full_attn_beta_factor', 1)) + config, + "layernorm_full_attention_beta", + getattr(config, "full_attn_beta_factor", 1), + ) self.layernorm_mlp_alpha = getattr( - config, 'layernorm_mlp_alpha', - getattr(config, 'mlp_alpha_factor', 1)) + config, "layernorm_mlp_alpha", getattr(config, "mlp_alpha_factor", 1) + ) self.layernorm_mlp_beta = getattr( - config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor', - 1)) - self.postnorm = getattr(config, 'postnorm', False) + config, "layernorm_mlp_beta", getattr(config, "mlp_beta_factor", 1) + ) + self.postnorm = getattr(config, "postnorm", False) self.shared_moe = False - shared_intermediate = getattr(config, 'shared_intermediate_size', 0) + shared_intermediate = getattr(config, "shared_intermediate_size", 0) if isinstance(shared_intermediate, list): - shared_intermediate = shared_intermediate[ - layer_id] if layer_id < len(shared_intermediate) else 0 + shared_intermediate = ( + 
shared_intermediate[layer_id] + if layer_id < len(shared_intermediate) + else 0 + ) if shared_intermediate > 0: self.shared_moe = True self.shared_mlp = MiniMaxText01MLP( @@ -785,7 +410,8 @@ class MiniMaxText01DecoderLayer(nn.Module): intermediate_size=shared_intermediate, quant_config=quant_config, layer_idx=self._ilayer, - prefix=prefix) + prefix=prefix, + ) self.coefficient = ReplicatedLinear( self.hidden_size, 1, @@ -793,36 +419,31 @@ class MiniMaxText01DecoderLayer(nn.Module): quant_config=quant_config, params_dtype=torch.float32, ) - self.coefficient.weight.weight_loader = ( - self.shared_moe_coefficient_loader) - self.shared_moe_mode = getattr(config, 'shared_moe_mode', - 'softmax') + self.coefficient.weight.weight_loader = self.shared_moe_coefficient_loader + self.shared_moe_mode = getattr(config, "shared_moe_mode", "softmax") return - def forward(self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - kv_caches: Union[list[dict], Optional[torch.Tensor]], - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - is_warmup: bool = False, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + is_warmup: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha - self_attention_output = (self_attention_output * - self.layernorm_attention_beta) + self_attention_output = self_attention_output * self.layernorm_attention_beta layernorm_input = residual + self_attention_output layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -831,24 +452,21 @@ class MiniMaxText01DecoderLayer(nn.Module): if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) - output_mlp = self.shared_mlp(layernorm_output).to( - torch.float32) + output_mlp = self.shared_mlp(layernorm_output).to(torch.float32) coef, _ = self.coefficient(layernorm_output.to(torch.float32)) - if self.shared_moe_mode == 'softmax': + if self.shared_moe_mode == "softmax": coef = torch.nn.functional.softmax(coef, dim=-1) - hidden_states = moe_hidden_fp32 * ( - 1 - coef) + output_mlp * coef - elif self.shared_moe_mode == 'sigmoid': + hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef + elif self.shared_moe_mode == "sigmoid": coef = torch.nn.functional.sigmoid(coef) - hidden_states = moe_hidden_fp32 * ( - 1 - coef) + output_mlp * coef + hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef hidden_states = hidden_states.to(before_moe_dtype) else: @@ -862,33 +480,31 @@ class MiniMaxText01DecoderLayer(nn.Module): return hidden_states, None 
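# --- Illustrative sketch (not part of the patch): the shared-expert blend performed in
# --- MiniMaxText01DecoderLayer.forward above, reproduced on dummy tensors. The names
# --- moe_out, shared_out and coef_logit are hypothetical stand-ins for moe_hidden_fp32,
# --- output_mlp and the output of self.coefficient (a hidden_size -> 1 projection);
# --- only the "sigmoid" shared_moe_mode branch is shown.
import torch

num_tokens, hidden_size = 4, 8
moe_out = torch.randn(num_tokens, hidden_size, dtype=torch.float32)     # routed-expert output
shared_out = torch.randn(num_tokens, hidden_size, dtype=torch.float32)  # always-on shared MLP output
coef_logit = torch.randn(num_tokens, 1, dtype=torch.float32)            # per-token gate logit

coef = torch.sigmoid(coef_logit)                    # gate in (0, 1), broadcast over hidden_size
blended = moe_out * (1 - coef) + shared_out * coef  # same mixing rule as in the layer above
assert blended.shape == (num_tokens, hidden_size)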
@staticmethod - def shared_moe_coefficient_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: + def shared_moe_coefficient_loader( + param: torch.Tensor, loaded_weight: torch.Tensor + ) -> None: assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight.to(torch.float32)) return +@support_torch_compile class MiniMaxText01Model(nn.Module): - - def __init__( - self, - config: MiniMaxConfig, - model_config: Optional[ModelConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.decoder_attention_types = getattr( - config, "attn_type_list", False) or getattr( - config, "decoder_attention_types", False) + config, "attn_type_list", False + ) or getattr(config, "decoder_attention_types", False) # The HF format uses "layer_types" instead of "attn_type_list" # where "linear_attention" is 0 and "full_attention" is 1 if not self.decoder_attention_types and hasattr(config, "layer_types"): @@ -916,76 +532,61 @@ class MiniMaxText01Model(nn.Module): self.embed_tokens = PPMissingLayer() def layer_fn(prefix): - layer_idx = int(prefix.split('.')[-1]) + layer_idx = int(prefix.split(".")[-1]) layer_config = config - layer_config.attention_type = self.decoder_attention_types[ - layer_idx] + layer_config.attention_type = self.decoder_attention_types[layer_idx] layer_config.layer_idx = layer_idx decoder_kwargs = { "quant_config": quant_config, "layer_id": layer_idx, "model_config": model_config, - "cache_config": cache_config + "cache_config": cache_config, } if layer_config.attention_type == 0: decoder_kwargs["linear_layer_id"] = sum( - 1 for i in range(layer_idx) - if self.decoder_attention_types[i] == 0) + 1 for i in range(layer_idx) if self.decoder_attention_types[i] == 0 + ) else: decoder_kwargs["linear_layer_id"] = None if hasattr(config, "num_local_experts") and isinstance( - config.num_local_experts, list): - decoder_kwargs["expert_num"] = config.num_local_experts[ - layer_idx] + config.num_local_experts, list + ): + decoder_kwargs["expert_num"] = config.num_local_experts[layer_idx] elif hasattr(config, "num_local_experts") and isinstance( - config.num_local_experts, int): + config.num_local_experts, int + ): decoder_kwargs["expert_num"] = config.num_local_experts else: decoder_kwargs["expert_num"] = 1 - return MiniMaxText01DecoderLayer(layer_config, - **decoder_kwargs, - prefix=prefix) + return MiniMaxText01DecoderLayer( + layer_config, **decoder_kwargs, prefix=prefix + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") + config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers" + ) - linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) - if self.decoder_attention_types[i] == 0) + linear_layer_nums = sum( + 1 + for i in range(config.num_hidden_layers) + if self.decoder_attention_types[i] == 0 + ) max_slots_number = scheduler_config.max_num_seqs - self.cache_shape = (linear_layer_nums, max_slots_number, - config.num_attention_heads // - 
get_tensor_model_parallel_world_size(), - config.head_dim, config.head_dim) + self.cache_shape = ( + linear_layer_nums, + max_slots_number, + config.num_attention_heads // get_tensor_model_parallel_world_size(), + config.head_dim, + config.head_dim, + ) _dummy = torch.zeros(1) self._dtype = _dummy.dtype del _dummy - if not envs.VLLM_USE_V1: - self.minimax_cache = MinimaxCacheManager( - dtype=torch.float32, cache_shape=self.cache_shape) - - rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", None) - if head_dim is None: - head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) - self.rotary_emb = MiniMaxText01RotaryEmbedding( - head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - is_neox_style=True, - cache_dtype=torch.float32, - ) - norm_kwargs = {} if hasattr(config, "rms_norm_eps"): norm_kwargs["eps"] = config.rms_norm_eps @@ -996,12 +597,12 @@ class MiniMaxText01Model(nn.Module): self.embed_scale = 1.0 return - def _clear_prefill_cache(self, attn_metadata, - minimax_cache_tensors: torch.Tensor, **kwargs): + def _clear_prefill_cache( + self, attn_metadata, minimax_cache_tensors: torch.Tensor, **kwargs + ): seq_to_slot_maps = {} seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), []) - for _, seq_to_slot_map in ( - self.minimax_cache.cache_indices_mapping.items()): + for _, seq_to_slot_map in self.minimax_cache.cache_indices_mapping.items(): seq_to_slot_maps.update(seq_to_slot_map) slots_to_clear = [] @@ -1009,50 +610,31 @@ class MiniMaxText01Model(nn.Module): if _prefill_id >= len(seq_id_map): break seq_id = seq_id_map[_prefill_id] - if attn_metadata.context_lens_tensor[ - _prefill_id] == 0 and seq_id in seq_to_slot_maps: + if ( + attn_metadata.context_lens_tensor[_prefill_id] == 0 + and seq_id in seq_to_slot_maps + ): slots_to_clear.append(seq_to_slot_maps[seq_id]) if slots_to_clear: - slots_tensor = torch.tensor(slots_to_clear, - device=minimax_cache_tensors.device, - dtype=torch.long) + slots_tensor = torch.tensor( + slots_to_clear, device=minimax_cache_tensors.device, dtype=torch.long + ) minimax_cache_tensors[:, slots_tensor, ...] 
= 0 - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) - def forward(self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> Union[torch.Tensor, IntermediateTensors]: + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - if not envs.VLLM_USE_V1 and attn_metadata is None: - return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - - if not envs.VLLM_USE_V1: - ( - minimax_cache_tensors, - state_indices_tensor, - ) = self.minimax_cache.current_run_tensors(**kwargs) - if getattr(attn_metadata, "num_prefills", 0) > 0: - self._clear_prefill_cache(attn_metadata, minimax_cache_tensors, - **kwargs) - - minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors, - state_indices_tensor) - else: - minimax_cache_params = None if get_pp_group().is_first_rank: if inputs_embeds is None: @@ -1065,39 +647,17 @@ class MiniMaxText01Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - minimax_cache_index = 0 - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - if attn_metadata is not None: - # TODO (tdoublep): this whole thing with the rotary_emb is - # weird. we shouldn't be passing it via attn_metadata imo. 
- if envs.VLLM_USE_V1: - if isinstance(layer.self_attn, MiniMaxText01Attention): - attn_metadata[layer.prefix + - ".attn"].rotary_emb = self.rotary_emb - else: - attn_metadata.rotary_emb = self.rotary_emb - - _caches = None - if not envs.VLLM_USE_V1 and isinstance( - layer.self_attn, MiniMaxText01LinearAttention): - current_state_layer = minimax_cache_index - _caches = minimax_cache_params.at_layer_idx( - current_state_layer) - minimax_cache_index += 1 + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( hidden_states=hidden_states, positions=positions, - kv_caches=_caches, attn_metadata=attn_metadata, residual=residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) if residual is not None: hidden_states, _ = self.norm(hidden_states, residual) else: @@ -1107,12 +667,9 @@ class MiniMaxText01Model(nn.Module): class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1126,80 +683,75 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len self.model = MiniMaxText01Model( - self.config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=quant_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, self.config.hidden_size, org_num_embeddings=self.config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size + ) else: self.lm_head = PPMissingLayer() self.lm_head.float() flash_layer_count = sum( - 1 for attn_type in self.model.decoder_attention_types - if attn_type == 1) + 1 for attn_type in self.model.decoder_attention_types if attn_type == 1 + ) self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] return def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.model.minimax_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + input_buffers, **kwargs + ) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( - batch_size) + return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, 
intermediate_tensors, - inputs_embeds, **kwargs) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states.float(), - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states.float()) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1211,7 +763,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def is_linear_attn_layer(layer_idx: int) -> bool: if layer_idx is None or layer_idx >= len( - self.model.decoder_attention_types): + self.model.decoder_attention_types + ): return False return self.model.decoder_attention_types[layer_idx] == 0 @@ -1219,39 +772,48 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): return "block_sparse_moe" in name and not name.endswith(".bias") def get_expert_id(param_name): - pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.' + pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\." 
match = re.search(pattern, param_name) if match: return match.group(1) return None - def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_sparse_moe_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if isinstance(self.config.num_local_experts, list): expert_params_mapping = [ - ("w13_weight" - if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(max(self.config.num_local_experts)) for weight_name in ["w1", "w2", "w3"] ] else: expert_params_mapping = [ - ("w13_scale" if weight_name in ["w1", "w3"] else - "w2_scale", f"{expert_id}.{weight_name}.weight_scale", - expert_id, weight_name) + ( + "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"{expert_id}.{weight_name}.weight_scale", + expert_id, + weight_name, + ) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] - ] + [("w13_weight" if weight_name in ["w1", "w3"] else - "w2_weight", f"{expert_id}.{weight_name}.weight", - expert_id, weight_name) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"]] - for (param_name, weight_name, expert_id, - shard_id) in expert_params_mapping: + ] + [ + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"{expert_id}.{weight_name}.weight", + expert_id, + weight_name, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: name_expert_id = get_expert_id(name) - if name_expert_id is not None and int(name_expert_id) != int( - expert_id): + if name_expert_id is not None and int(name_expert_id) != int(expert_id): continue if weight_name not in name: continue @@ -1261,19 +823,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): param = params_dict[name] weight_loader = param.weight_loader weight_loader = weight_loader_with_alias(name)(weight_loader) - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id, - shard_id=shard_id) + weight_loader( + param, + loaded_weight, + weight_name, + expert_id=expert_id, + shard_id=shard_id, + ) loaded_params.add(name) break else: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -1282,8 +845,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def is_shared_mlp_weight(name: str) -> bool: return "shared_mlp" in name and not name.endswith(".bias") - def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_shared_mlp_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if not self.CONCAT_FFN: if "gate_proj" in name: name = name.replace("gate_proj", "w1", 1) @@ -1301,8 +865,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = 
weight_loader_with_alias(name)(weight_loader) if not self.CONCAT_FFN: weight_loader(param, loaded_weight) @@ -1312,31 +875,31 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): elif "down_proj" in name: weight_loader(param, loaded_weight) else: - raise AssertionError( - "MLP weight not in [gate_up_proj, down_proj]") + raise AssertionError("MLP weight not in [gate_up_proj, down_proj]") loaded_params.add(name) return def is_mha_weight(name: str) -> bool: return "self_attn" in name and not name.endswith(".bias") - def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_linear_attn_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] weight_loader = getattr( - param, "weight_loader", - MiniMaxText01LinearAttention.weight_direct_load) + param, "weight_loader", MiniMaxText01LinearAttention.weight_direct_load + ) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return - def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: - + def load_flash_attn_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: flash_mha_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -1344,16 +907,14 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - for (param_name, weight_name, - shard_id) in flash_mha_params_mapping: + for param_name, weight_name, shard_id in flash_mha_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) @@ -1363,36 +924,32 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return def is_layer_norm_weight(name: str) -> bool: - return "norm" in name and not name.endswith( - ".bias") and name in params_dict + return "norm" in name and not name.endswith(".bias") and name in params_dict - def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_layer_norm_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return - def load_basic_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_basic_weight(name: str, loaded_weight: torch.Tensor, self) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = 
getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -1401,7 +958,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): for name, loaded_weight in weights: weight_at_layer = which_layer(name) if weight_at_layer and weight_at_layer >= len( - self.model.decoder_attention_types): + self.model.decoder_attention_types + ): continue if is_layer_norm_weight(name): @@ -1431,7 +989,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.linear_attention_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -1441,13 +998,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, ...], ...]: """Calculate shape for MiniMaxText01LinearAttention cache. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index cc7db849a28bf..a25a7097a6ece 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,35 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Annotated, Literal, Optional, Union, cast +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn from transformers import BatchFeature, PretrainedConfig from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence import IntermediateTensors -from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder, - init_vision_tower_for_llava) +from .llava import ( + BaseLlavaMultiModalProcessor, + LlavaDummyInputsBuilder, + init_vision_tower_for_llava, +) from .llava_next import LlavaNextProcessingInfo from .pixtral import PixtralHFVisionModel from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) class MiniMaxVL01ImagePixelInputs(TensorSchema): @@ -44,10 +49,12 @@ class MiniMaxVL01ImagePixelInputs(TensorSchema): Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a 
batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"}), + ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] # This should be in `(height, width)` format. @@ -60,36 +67,43 @@ class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, - MiniMaxVL01ImageEmbeddingInputs] +MiniMaxVL01ImageInputs = Union[ + MiniMaxVL01ImagePixelInputs, MiniMaxVL01ImageEmbeddingInputs +] class MiniMaxVL01MultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -103,15 +117,13 @@ class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder): class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self): # Need to override the config type return self.ctx.get_hf_config(PretrainedConfig) def get_hf_processor(self, **kwargs: object): hf_processor = self.ctx.get_hf_processor(**kwargs) image_processor = hf_processor.image_processor - image_processor.anyres_preprocess = ( - image_processor.anyres_for_vllm_preprocess) + image_processor.anyres_preprocess = image_processor.anyres_for_vllm_preprocess return hf_processor @@ -120,8 +132,8 @@ class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo): class MiniMaxVL01MultiModalProcessor( - BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]): - + BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -164,13 +176,14 @@ class MiniMaxVL01MultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( MiniMaxVL01MultiModalProcessor, info=MiniMaxVL01ProcessingInfo, - dummy_inputs=MiniMaxVL01DummyInputsBuilder) -class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=MiniMaxVL01DummyInputsBuilder, +) +class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True 
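+    # NOTE: merge_by_field_config means multimodal kwargs reach the model already
+    # flattened along the batch dimension (the "bn" leading dim in the schemas
+    # above), so _parse_and_validate_image_input can pass them through as-is.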
packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -195,16 +208,17 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = MiniMaxVL01MultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=True, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, @@ -217,104 +231,71 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, self.pad_token_id = self.config.pad_token_id self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds + self.language_model.make_empty_intermediate_tensors + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - - raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, - PixtralHFVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = tuple(vision_tower(p) for p in pixel_values) - - def select_features(leaf: torch.Tensor): - return self._select_image_features( - leaf, - strategy=self.config.vision_feature_select_strategy, - ) - - return cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features, image_features), + feature_select_strategy = self.config.vision_feature_select_strategy + return tuple( + vision_tower(p, feature_select_strategy=feature_select_strategy) + for p in pixel_values ) # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 - def pack_image_features(self, image_features: list[torch.Tensor], - image_sizes: torch.Tensor): + def pack_image_features( + self, image_features: list[torch.Tensor], image_sizes: torch.Tensor + ): new_image_features = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: 
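+                # anyres path: index 0 is the base (low-resolution) image feature;
+                # the remaining entries are high-resolution tiles that get reshaped
+                # into a grid, unpadded, and concatenated back with the base below.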
base_image_feature = image_feature[0] image_feature = image_feature[1:] - height = width = (self.config.vision_config.image_size // - self.config.vision_config.patch_size) + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) if height * width != base_image_feature.shape[0]: raise ValueError( - "The number of patches is not consistent with " - "the image size.") + "The number of patches is not consistent with the image size." + ) num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) - image_feature = image_feature.view(num_patch_height, - num_patch_width, height, - width, -1) - image_feature = image_feature.permute(4, 0, 2, 1, - 3).contiguous() + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, - image_sizes[image_idx]) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) image_feature = torch.cat( ( image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to( - image_feature.dtype), + self.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.dtype), ), dim=-1, ) image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), - dim=0) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] image_feature = torch.cat( - (image_feature, - self.image_newline[None].to(image_feature)), - dim=0) + (image_feature, self.image_newline[None].to(image_feature)), dim=0 + ) new_image_features.append(image_feature) return new_image_features @@ -340,9 +321,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, if isinstance(image_features, torch.Tensor): return self.multi_modal_projector(image_features) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) @@ -350,7 +329,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, return self.pack_image_features(image_embeds, image_sizes) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: + self, **kwargs: object + ) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -359,34 +339,21 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None and image_sizes is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. 
" - f"Got type: {type(image_sizes)}") - return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return MiniMaxVL01ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -401,31 +368,29 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 438513433d3b2..8e74425c5dbdd 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,43 +3,58 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig, - PretrainedConfig) +from transformers import ( + BatchFeature, + Mistral3Config, + PixtralVisionConfig, + PretrainedConfig, +) from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info @@ -67,38 +82,43 @@ class Mistral3PatchMerger(nn.Module): Learned merging of spatial_merge_size ** 2 patches """ - def __init__(self, vision_hidden_size: int, spatial_merge_size: int, - patch_size: int): + def __init__( + self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int + ): super().__init__() self.vision_hidden_size = vision_hidden_size self.spatial_merge_size = spatial_merge_size self.patch_size = patch_size - self.merging_layer = nn.Linear(vision_hidden_size * - self.spatial_merge_size**2, - vision_hidden_size, - bias=False) + self.merging_layer = nn.Linear( + vision_hidden_size * self.spatial_merge_size**2, + vision_hidden_size, + bias=False, + ) - def forward(self, image_features: torch.Tensor, - image_sizes: torch.Tensor) -> torch.Tensor: - image_sizes = [(image_size[0] // self.patch_size, - image_size[1] // self.patch_size) - for image_size in image_sizes] + def forward( + self, image_features: torch.Tensor, image_sizes: torch.Tensor + ) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) + for image_size in image_sizes + ] tokens_per_image = [h * w for h, w in image_sizes] d = image_features.shape[-1] permuted_tensor = [] for image_index, image_tokens in enumerate( - image_features.split(tokens_per_image)): + image_features.split(tokens_per_image) + ): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] - image_grid = image_tokens.view(h, w, d).permute(2, 0, - 1).unsqueeze(0) + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) grid = torch.nn.functional.unfold( image_grid, kernel_size=self.spatial_merge_size, - stride=self.spatial_merge_size) + stride=self.spatial_merge_size, + ) grid = grid.view(d * self.spatial_merge_size**2, -1).t() permuted_tensor.append(grid) @@ -108,38 +128,45 @@ 
class Mistral3PatchMerger(nn.Module): class Mistral3MultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - spatial_merge_size: int, - patch_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + spatial_merge_size: int, + patch_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.norm = RMSNorm(vision_hidden_size, eps=1e-5) self.patch_merger = Mistral3PatchMerger( vision_hidden_size=vision_hidden_size, spatial_merge_size=spatial_merge_size, - patch_size=patch_size) + patch_size=patch_size, + ) - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) - def forward(self, image_features: torch.Tensor, - image_sizes: torch.Tensor) -> torch.Tensor: + def forward( + self, image_features: torch.Tensor, image_sizes: torch.Tensor + ) -> torch.Tensor: image_features = self.norm(image_features) image_features = self.patch_merger(image_features, image_sizes) hidden_states, _ = self.linear_1(image_features) @@ -160,7 +187,6 @@ class LlavaLikeProcessor(Protocol): class BaseLlavaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(Mistral3Config) @@ -196,7 +222,6 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -209,29 +234,30 @@ class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) -class Mistral3MultiModalProcessor( - BaseMultiModalProcessor[Mistral3ProcessingInfo]): - +class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -248,7 +274,6 @@ class 
Mistral3MultiModalProcessor( pixel_values = processed_outputs.get("pixel_values") if pixel_values is not None: - # Avoid padding since we need the output for each image to be # independent of other images for the cache to work correctly image_sizes = processed_outputs["image_sizes"] @@ -312,7 +337,8 @@ class Mistral3MultiModalProcessor( def _build_mistral3_info( - ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + ctx: InputProcessingContext, +) -> BaseLlavaProcessingInfo: hf_config = ctx.get_hf_config(Mistral3Config) assert isinstance(hf_config.vision_config, PixtralVisionConfig) return Mistral3ProcessingInfo(ctx) @@ -322,7 +348,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( @@ -335,7 +361,7 @@ def _build_mistral3_processor( def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: """Determine the number of hidden layers to initialize up to in the visual encoder. - + Args: hf_config: Model config with vision feature layer(s). """ @@ -346,10 +372,10 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest one elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -392,13 +418,16 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_processor( _build_mistral3_processor, info=_build_mistral3_info, - dummy_inputs=Mistral3DummyInputsBuilder) -class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, - SupportsMultiModal, SupportsPP): + dummy_inputs=Mistral3DummyInputsBuilder, +) +class Mistral3ForConditionalGeneration( + nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP +): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( @@ -408,7 +437,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -429,11 +459,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, # NOTE: These are special cases for Pixtral-12B in the HF-format # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa - if (config.text_config.architectures is None - and config.text_config.model_type == "mistral"): + if ( + config.text_config.architectures is None + and config.text_config.model_type == "mistral" + ): config.text_config.architectures = ["MistralForCausalLM"] - if (config.projector_hidden_act is None - and config.vision_config.hidden_act == "gelu"): + if ( + config.projector_hidden_act is None + and 
config.vision_config.hidden_act == "gelu" + ): config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. @@ -442,7 +476,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = Mistral3MultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -451,7 +486,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, patch_size=config.vision_config.patch_size, multimodal_projector_bias=config.multimodal_projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) else: self.vision_tower = None self.multi_modal_projector = None @@ -463,24 +499,21 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - assert pixel_values is not None - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - return Mistral3ImagePixelInputs( type="pixel_values_pixtral", - pixel_values=flatten_bn(pixel_values), + pixel_values=pixel_values, ) def _process_image_input( @@ -490,8 +523,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, if image_input["type"] == "image_embeds": return image_input["data"] - image_sizes = [(img.shape[-2], img.shape[-1]) - for img in image_input["pixel_values"]] + image_sizes = [ + (img.shape[-2], img.shape[-1]) for img in image_input["pixel_values"] + ] image_features = self.vision_tower(image_input["pixel_values"]) @@ -503,19 +537,19 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, for image_feature in image_features ] - image_embeds = self.multi_modal_projector(torch.cat(image_features), - image_sizes) + image_embeds = self.multi_modal_projector( + torch.cat(image_features), image_sizes + ) if len(feature_sizes) > 1: image_embeds = torch.split(image_embeds, feature_sizes) else: - image_embeds = (image_embeds, ) + image_embeds = (image_embeds,) return image_embeds def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -524,22 +558,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - 
multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -578,39 +596,29 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: - [Mistral3ImagePixelInputs][] + [`Mistral3ImagePixelInputs`][vllm.model_executor.models.mistral3.Mistral3ImagePixelInputs] """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.vision_tower is None and self.multi_modal_projector is None: skip_prefixes = ["vision_tower.", "multi_modal_projector."] @@ -625,4 +633,5 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 30de83da49e0e..37b49349ec12c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,7 +23,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from collections.abc import Iterable + +import typing +from collections.abc import Callable, Iterable +from itertools import islice from typing import Optional, Union import torch @@ -32,27 +35,42 @@ from transformers import MixtralConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class MixtralMoE(nn.Module): @@ -64,39 +82,67 @@ class MixtralMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - dp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + enable_eplb: bool = False, + ): super().__init__() self.hidden_size = hidden_size + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + + # Expert Parallelism Load balancing settings. 
+ vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_routed_experts = num_experts + self.n_logical_experts = num_experts + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - dp_size=dp_size, - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + dp_size=dp_size, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -109,7 +155,6 @@ class MixtralMoE(nn.Module): class MixtralAttention(nn.Module): - def __init__( self, config: MixtralConfig, @@ -170,13 +215,15 @@ class MixtralAttention(nn.Module): base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -192,13 +239,13 @@ class MixtralAttention(nn.Module): class MixtralDecoderLayer(nn.Module): - def __init__( self, config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -213,18 +260,21 @@ class MixtralDecoderLayer(nn.Module): rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + prefix=f"{prefix}.block_sparse_moe", + enable_eplb=enable_eplb, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -237,23 +287,20 @@ class MixtralDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.block_sparse_moe(hidden_states) return hidden_states, residual @support_torch_compile class MixtralModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -261,11 +308,15 @@ class MixtralModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -275,17 +326,25 @@ class MixtralModel(nn.Module): org_num_embeddings=config.vocab_size, ) + self.enable_eplb = parallel_config.enable_eplb + self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix + config, + cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=self.enable_eplb, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -307,13 +366,12 @@ class MixtralModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -324,10 +382,11 @@ class MixtralModel(nn.Module): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> 
set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -339,25 +398,27 @@ class MixtralModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -372,29 +433,47 @@ class MixtralModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - name = name.replace(weight_name, param_name) + + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + if is_pp_missing_parameter(name_mapped, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + + if ( + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") + ) and name_mapped not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + + param = params_dict[name_mapped] + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break else: + if is_expert_weight: + continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -405,14 +484,15 @@ class MixtralModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -439,8 +519,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MixtralModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -451,15 +532,81 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + + self.expert_weights = [] + self.moe_layers: list[FusedMoE] = [] + example_moe = None + + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + assert isinstance(layer, MixtralDecoderLayer) + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE + ): + example_moe = layer.block_sparse_moe + self.moe_layers.append(layer.block_sparse_moe.experts) + + self.num_moe_layers = len(self.moe_layers) + + if example_moe is None: + raise RuntimeError("No MixtralMoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
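+            # (They are consumed by the EPLB rebalancer when experts are
+            # redistributed across EP ranks.)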
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE + ): + moe = layer.block_sparse_moe + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -471,21 +618,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py deleted file mode 100644 index c8ad358c622d2..0000000000000 --- a/vllm/model_executor/models/mixtral_quant.py +++ /dev/null @@ -1,453 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Mixtral model.""" -from collections.abc import Iterable -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from transformers import MixtralConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indices = np.array_split(range(self.num_total_experts), - self.tp_size)[self.rank].tolist() - if not self.expert_indices: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indices else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None) - - def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indices: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__( - self, - config: MixtralConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
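The routing math in `MixtralMoE.forward` above reduces to a softmax over router logits, a top-k selection, a renormalization of the selected weights, and a weighted sum of per-expert outputs. Below is a minimal, dependency-free sketch of that pattern; the gate weight, the list of expert callables, and the shapes are illustrative assumptions, not vLLM APIs.

```python
# Minimal sketch (not vLLM code) of the softmax -> top-k -> renormalize routing
# pattern used in MixtralMoE.forward above. Names and shapes are illustrative.
import torch
import torch.nn.functional as F


def route_tokens(hidden_states: torch.Tensor,
                 gate_weight: torch.Tensor,
                 experts: list,
                 top_k: int) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts)
    router_logits = hidden_states @ gate_weight.t()
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # Renormalize so each token's selected expert weights sum to 1.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    final = torch.zeros_like(hidden_states, dtype=torch.float)
    for expert_idx, expert in enumerate(experts):
        # Zero weight for tokens that did not select this expert.
        expert_weight = (routing_weights *
                         (selected_experts == expert_idx)).sum(dim=-1, keepdim=True)
        final += expert(hidden_states).float() * expert_weight
    return final.to(hidden_states.dtype)


# Example: 3 tokens, hidden size 8, 4 identity "experts", top-2 routing.
tokens = torch.randn(3, 8)
gate = torch.randn(4, 8)
out = route_tokens(tokens, gate, [lambda x: x] * 4, top_k=2)
assert out.shape == tokens.shape
```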
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", None) - if self.head_dim is None: - self.head_dim = self.hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class MixtralDecoderLayer(nn.Module): - - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.block_sparse_moe = MixtralMoE(config=config, - quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class MixtralModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - 
config.num_hidden_layers, - lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers") - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class MixtralForCausalLM(nn.Module, SupportsPP): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py deleted file mode 100644 index 2a60450de4141..0000000000000 --- a/vllm/model_executor/models/mllama.py +++ /dev/null @@ -1,1698 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
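The `load_weights` implementations above (and in the Mllama code that follows) rely on a stacked-parameter mapping: checkpoint tensors named `q_proj`/`k_proj`/`v_proj` are redirected onto the fused `qkv_proj` parameter together with a shard id. A minimal sketch of that name remapping follows; the helper name is an illustrative assumption.

```python
# Minimal sketch (not vLLM code) of the stacked-parameter name remapping used by
# the load_weights methods in this diff. The helper name is illustrative.
STACKED_PARAMS_MAPPING = [
    # (fused_param_fragment, checkpoint_fragment, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def remap_checkpoint_name(name: str) -> tuple[str, str | None]:
    """Return (remapped_name, shard_id); shard_id is None for unfused params."""
    for fused, original, shard_id in STACKED_PARAMS_MAPPING:
        if original in name:
            return name.replace(original, fused), shard_id
    return name, None


assert remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight", "k")
assert remap_checkpoint_name("model.embed_tokens.weight") == (
    "model.embed_tokens.weight", None)
```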
-"""PyTorch Mllama model.""" -import math -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -import transformers.models.mllama.configuration_mllama as config_mllama -from PIL.Image import Image -from torch import nn -from transformers import BatchFeature, MllamaConfig -from transformers.modeling_outputs import (BaseModelOutput, - CausalLMOutputWithPast) -from transformers.models.mllama.image_processing_mllama import ( - get_optimal_tiled_canvas) -from transformers.models.mllama.processing_mllama import ( - MllamaProcessor, get_cross_attention_token_mask) - -import vllm.distributed.parallel_state as ps -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.selector import _Backend -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group, get_tp_group -from vllm.forward_context import get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .clip import CLIPMLP -from .interfaces import SupportsMultiModal, SupportsV0Only -from .llama import LlamaDecoderLayer, LlamaMLP -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix - -logger = init_logger(__name__) - - -class MllamaImagePixelInputs(TensorSchema): - """ - Dimensions: - - batch_size: Batch size - - max_num_image: Max number of images - - max_num_chunk: Max number of chunks - - max_num_tiles: Max number of tiles per image - - num_channel: Number of channels - - height: Height - - width: Width - """ - - type: Literal["pixel_values"] = "pixel_values" - - data: Annotated[torch.Tensor, - TensorShape("batch_size", "max_num_image", "max_num_chunk", - "num_channel", "height", "width")] - - aspect_ratio_ids: Annotated[torch.Tensor, - TensorShape("batch_size", "max_num_image")] - - aspect_ratio_mask: Annotated[ - torch.Tensor, - TensorShape("batch_size", "max_num_image", "max_num_tiles")] - - -# TODO: support LlamaImageEmbeddingInputs - - -def calc_token_per_chunk(image_size: int) -> int: - assert image_size % 14 == 0, "chunk size should be multiple of 14" - token_per_chunk = (image_size // 14)**2 + 1 - return token_per_chunk - - -class 
MllamaProcessingInfo(BaseProcessingInfo): - - def get_hf_config(self) -> MllamaConfig: - return self.ctx.get_hf_config(MllamaConfig) - - def get_hf_processor(self, **kwargs: object) -> MllamaProcessor: - return self.ctx.get_hf_processor(MllamaProcessor, **kwargs) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - def get_token_per_chunk_from_config(self) -> int: - image_size = self.get_hf_config().vision_config.image_size - return calc_token_per_chunk(image_size) - - def get_num_tiles_per_image(self, image_height: int, - image_width: int) -> int: - vision_config = self.get_hf_config().vision_config - max_num_tiles = vision_config.max_num_tiles - image_size = vision_config.image_size - tiled_height, tiled_width = get_optimal_tiled_canvas( - image_height, - image_width, - max_num_tiles, - tile_size=image_size, - ) - num_tiles_height = tiled_height // image_size - num_tiles_width = tiled_width // image_size - return num_tiles_height * num_tiles_width - - def get_image_size_with_most_features(self) -> ImageSize: - vision_config = self.get_hf_config().vision_config - image_size = vision_config.image_size - max_num_tiles = vision_config.max_num_tiles - # Result in the max possible feature size (h:w = 16:1) - return ImageSize(height=max_num_tiles * image_size, width=image_size) - - -class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.image_token - - return image_token * num_images - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = \ - self.info.get_image_size_with_most_features() - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - -class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] - ): - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - ) -> MultiModalEncDecInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs) - - image_token_id = self.info.get_hf_config().image_token_index - # Check that the number of image tokens in the decoder prompt matches - # the number of images provided in mm_data - num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id) - image_data = mm_data.get("image", []) - num_images = 1 if isinstance(image_data, Image) else len(image_data) - if num_image_tokens != num_images: - raise ValueError( - f"The number of image tokens ({num_image_tokens}) must be" - f" the same as the number of images ({num_images})") - - # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501 - # P0 & P1 do cross attention with placeholder of <IMG0> - # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2> - # Example input to encoder and decoder: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128256, 128256, ..., 128256], - # 'prompt': '<|image|><|image|>...<|image|>', - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 
'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 - # }, - # } - - if mm_data: - hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.image_token - - # Since only the last group of consecutive images - # are attended by the decoded tokens, we only need to - # get the number of tokens for those images. - token_per_chunk = self.info.get_token_per_chunk_from_config() - num_decode_images = self._get_num_image_in_last_group( - mm_inputs["prompt_token_ids"]) - num_encode_images = num_images - num_decode_images - - # Set encoder prompt length based on the number of tiles. - # This tells the block manager to allocate correct number - # of slots for encoder tokens. - num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"] - decode_tiles = num_tiles[num_encode_images:num_images].sum().item() - num_tokens = decode_tiles * token_per_chunk - mm_inputs["encoder_prompt_token_ids"] = [image_token_id - ] * num_tokens - mm_inputs["encoder_prompt"] = image_token * num_tokens - - return mm_inputs - - def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int: - num_images = 0 - for token_id in prompt_token_ids[::-1]: - if token_id == self.info.get_hf_config().image_token_index: - num_images += 1 - elif num_images > 0: - break - return num_images - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - tokenizer = self.info.get_tokenizer() - if mm_data: - num_tiles = [ - self.info.get_num_tiles_per_image(img.height, img.width) - for img in mm_data["images"] - ] - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) - processed_outputs["num_tiles"] = torch.tensor(num_tiles) - for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): - processed_outputs[k] = processed_outputs[k].squeeze(0) - - processed_token_ids = processed_outputs.pop("input_ids") - start_idx, end_idx = 0, processed_token_ids.size(1) - processed_prompt_text = tokenizer.decode(processed_token_ids[0]) - - hf_processor = self.info.get_hf_processor() - bos_token = hf_processor.bos_token - # Remove the bos_token from the start of prompt, - # because we all know there would be image_token. - if processed_prompt_text.startswith(bos_token): - start_idx += 1 - # Remove the bos_token from the end of prompt, - # because text is empty in this case. 
- if processed_prompt_text.endswith(bos_token): - end_idx -= 1 - processed_outputs[ - "input_ids"] = processed_token_ids[:, start_idx:end_idx] - else: - processed_outputs = tokenizer(prompt, - add_special_tokens=False, - return_tensors="pt") - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - aspect_ratio_ids=MultiModalFieldConfig.batched("image"), - aspect_ratio_mask=MultiModalFieldConfig.batched("image"), - num_tiles=MultiModalFieldConfig.batched("image"), - ) - - def create_encoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - data = mm_data.get("image", []) - num_images = 1 if isinstance(data, Image) else len(data) - image_token_id = self.info.get_hf_config().image_token_index - return [image_token_id] * num_images - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - token_per_chunk = self.info.get_token_per_chunk_from_config() - image_token_id = self.info.get_hf_config().image_token_index - - def get_replacement_mllama(item_idx): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - num_tile = self.info.get_num_tiles_per_image( - image_height=image_size.height, - image_width=image_size.width, - ) - num_tokens = num_tile * token_per_chunk - return [image_token_id] * num_tokens - - return [ - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_replacement_mllama, - ) - ] - - -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, - 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) - attention_mask = attention_mask.reshape(batch_size, - max_num_tiles * target_length, 1) - attention_mask = attention_mask @ attention_mask.transpose( - -1, -2) * torch.finfo(dtype).min - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - -class ColumnParallelConv2dPatch(torch.nn.Module): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. 
- Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int, int]], - stride: Union[int, tuple[int, int]], - bias: bool = False, - ) -> None: - super().__init__() - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) - self._linear = ColumnParallelLinear( - in_channels * kernel_size[0] * kernel_size[1], - out_channels, - bias=bias, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self._unfold(x) - x = x.permute(0, 2, 1) - x, _ = self._linear(x) - return x - - -class MllamaPrecomputedAspectRatioEmbedding(nn.Module): - - def __init__(self, - config: config_mllama.MllamaVisionConfig, - is_gated: bool = True): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.is_gated = is_gated - - self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, - self.max_num_tiles * self.hidden_size) - if is_gated: - self.gate = nn.Parameter(torch.zeros(1)) - - def forward(self, hidden_state: torch.Tensor, - aspect_ratio_ids: torch.Tensor) -> torch.Tensor: - embeddings = self.embedding(aspect_ratio_ids) - embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, - self.hidden_size) - - if self.is_gated: - embeddings = embeddings * self.gate.tanh() - - hidden_state = hidden_state + embeddings - return hidden_state - - -class MllamaPrecomputedPositionEmbedding(nn.Module): - - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.num_patches = (config.image_size // config.patch_size)**2 + 1 - self.hidden_size = config.hidden_size - self.scale = config.hidden_size**-0.5 - - self.gate = nn.Parameter(torch.zeros(1)) - - # position embedding - position_embedding = torch.randn(self.num_patches, self.hidden_size) - self.embedding = nn.Parameter(self.scale * position_embedding) - - # tile position embedding - self.tile_embedding = nn.Embedding( - self.max_aspect_ratio_id + 1, - self.max_num_tiles * self.num_patches * self.hidden_size) - - def forward(self, hidden_state: torch.Tensor, - aspect_ratio_ids: torch.Tensor) -> torch.Tensor: - # position embeddings - gated_position_embedding = (1 - self.gate.tanh()) * self.embedding - hidden_state = hidden_state + gated_position_embedding.view( - 1, 1, self.num_patches, self.hidden_size) - - # precomputed tile position embeddings - tile_position_embedding = self.tile_embedding(aspect_ratio_ids) - batch_size = hidden_state.shape[0] - tile_position_embedding = tile_position_embedding.reshape( - batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) - gated_tile_position_embedding = self.gate.tanh( - ) * tile_position_embedding - hidden_state = hidden_state + gated_tile_position_embedding - - return hidden_state - - -# TODO: support other attention backends for attention in vision model -class MllamaVisionSdpaAttention(nn.Module): - - def __init__(self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__() - - tensor_parallel_size = get_tp_group().world_size - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // 
config.attention_heads - self.num_local_heads = self.num_heads // tensor_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view(q.shape[0], q.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - k = k.view(k.shape[0], k.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - v = v.view(v.shape[0], v.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention(q, - k, - v, - attn_mask=attention_mask, - dropout_p=0.0) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(attn_output.shape[0], - attn_output.shape[1], -1) - output, _ = self.o_proj(attn_output) - return output - - -class MllamaVisionEncoderLayer(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - is_gated: bool = False, - ) -> None: - super().__init__() - - self.hidden_size = config.hidden_size - self.num_attention_heads = config.attention_heads - self.is_gated = is_gated - self.intermediate_size = config.intermediate_size - - self.self_attn = MllamaVisionSdpaAttention( - config, quant_config=quant_config, prefix=f"{prefix}.self_attn") - self.mlp = CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - - self.input_layernorm = nn.LayerNorm(self.hidden_size, - eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, - eps=config.norm_eps) - - # there used to be an if else here, no code path - if is_gated: - self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) - self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - # Self Attention - residual = hidden_state - hidden_state = self.input_layernorm(hidden_state) - hidden_state = self.self_attn(hidden_state, - attention_mask=attention_mask) - gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() - hidden_state = residual + gate_attn * hidden_state - - # Feed forward - residual = hidden_state - hidden_state = self.post_attention_layernorm(hidden_state) - hidden_state = self.mlp(hidden_state) - gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() - hidden_state = residual + gate_ffn * hidden_state - - return hidden_state - - -class MllamaVisionEncoder(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - num_layers: int = 32, - is_gated: bool = False, - output_hidden_states=None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.layers = nn.ModuleList([ - MllamaVisionEncoderLayer(config, - quant_config=quant_config, - is_gated=is_gated, - prefix=f"{prefix}.layers.{layer_idx}") - for 
layer_idx in range(num_layers) - ]) - self.output_hidden_states = output_hidden_states or [] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> Union[BaseModelOutput]: - encoder_states = () - - for i, encoder_layer in enumerate(self.layers): - if i in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - hidden_states = encoder_layer( - hidden_states, - attention_mask, - ) - - if len(self.layers) - 1 in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - return hidden_states, encoder_states - - -class MllamaVisionModel(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: - super().__init__() - - self.image_size = config.image_size - self.patch_size = config.patch_size - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.in_channels = config.num_channels - self.intermediate_layers_indices = config.intermediate_layers_indices - - self.num_patches = (self.image_size // self.patch_size)**2 + 1 - self.scale = config.hidden_size**-0.5 - - self.patch_embedding = ColumnParallelConv2dPatch( - in_channels=config.num_channels, - out_channels=self.hidden_size, - kernel_size=self.patch_size, - stride=self.patch_size, - bias=False, - ) - - self.class_embedding = nn.Parameter(self.scale * - torch.randn(self.hidden_size)) - self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( - config) - - self.pre_tile_positional_embedding = \ - MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) - self.post_tile_positional_embedding = \ - MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) - - # layer norms - self.layernorm_pre = nn.LayerNorm(self.hidden_size) - self.layernorm_post = nn.LayerNorm(self.hidden_size) - - # encoders - self.transformer = MllamaVisionEncoder( - config, - quant_config, - config.num_hidden_layers, - is_gated=False, - output_hidden_states=config.intermediate_layers_indices, - prefix=f"{prefix}.transformer", - ) - self.global_transformer = MllamaVisionEncoder( - config, - quant_config, - config.num_global_layers, - is_gated=True, - prefix=f"{prefix}.global_transformer", - ) - - def apply_class_embedding(self, - hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, - hidden_size) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - def forward(self, pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, \ - height, width = pixel_values.shape - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, - height, width) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1) - - # patch embedding - patch_embeds = self.patch_embedding( - pixel_values.to(self.layernorm_pre.weight.dtype)) - hidden_state = patch_embeds - hidden_state = ps.get_tp_group().all_gather(hidden_state) - - # tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, -1, dim) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - - # apply cls token - hidden_state = 
hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # apply position embeddings - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, num_patches, dim) - hidden_state = self.gated_positional_embedding(hidden_state, - aspect_ratio_ids) - - # apply encoder - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, 0, 0, num_padding_patches - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - attention_mask = aspect_ratio_mask.reshape( - batch_size * num_concurrent_media, -1) - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.layernorm_pre.weight.dtype, - ) - - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, - dim) - output = self.transformer( - hidden_state, - attention_mask=attention_mask, - ) - hidden_state, intermediate_hidden_states = output[0], output[1] - intermediate_hidden_states = torch.stack(intermediate_hidden_states, - dim=-1) - - # apply global encoder - hidden_state = self.layernorm_post(hidden_state) - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), dim) - hidden_state = self.global_transformer( - hidden_state, attention_mask=attention_mask)[0] - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = hidden_state[:, :, :slice_index] - - # adding intermediate layer outputs - hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, - num_tiles, num_patches, dim) - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, num_tiles, - num_patches + num_padding_patches, -1) - intermediate_hidden_states = intermediate_hidden_states[:, :, : - slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1) - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], - dim=-1) - return hidden_state - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() - for name, loaded_weight in weights: - if 'patch_embedding._linear.weight' in name: - loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - updated_params.add(name) - weight_loader = param.weight_loader - 
weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict.pop(name) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - updated_params.add(name) - return updated_params - - -class MllamaTextRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - MllamaTextRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MllamaTextCrossAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: Optional[config_mllama.MllamaTextConfig] = None, - layer_idx: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.pipeline_parallel_rank = get_pp_group().rank_in_group - self.tensor_parallel_size = get_tp_group().world_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - - self.num_local_heads = self.num_heads // self.tensor_parallel_size - self.num_local_key_value_heads = \ - self.num_key_value_heads // self.tensor_parallel_size - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - - self.layer_idx = layer_idx - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.q_local_size = self.num_local_heads * self.head_dim - self.kv_local_size = self.num_local_key_value_heads * self.head_dim - - self.qkv_proj = QKVCrossParallelLinear( - self.hidden_size, - self.head_dim, - self.num_heads, - self.num_key_value_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, - # use huggingface's instead - self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.scaling = self.head_dim**-0.5 - - self.attn = Attention( - self.num_local_heads, - self.head_dim, - self.scaling, - self.num_local_key_value_heads, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_DECODER, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - cross_attention_states: Optional[torch.Tensor], - ) -> torch.Tensor: - q, k, v = self.qkv_proj(hidden_states, cross_attention_states) - if cross_attention_states is not None: - k = k.view(-1, self.num_local_key_value_heads, self.head_dim) - v = v.view(-1, self.num_local_key_value_heads, self.head_dim) - k = self.k_norm(k) - - q = q.view(-1, self.num_local_heads, self.head_dim) - q = self.q_norm(q) - - if attention_mask is not None: - output = 
self._attention_with_mask(q, k, v, attention_mask, - kv_range_for_decode) - else: - output = self.attn( - q.view(-1, self.num_local_heads * self.head_dim), k, v) - out, _ = self.o_proj(output) - return out - - def _attention_with_mask( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_mask: torch.Tensor, - kv_range_for_decode: list[tuple[int, int]], - ) -> torch.Tensor: - kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - # Skip writing kv-cache for the initial profiling run. - # TODO (NickLucche) replace with custom attn bias and use standard attn - if len(kv_cache.shape) > 1: - i = torch.ones(1, dtype=torch.float32) - if self.attn.backend in (_Backend.FLASH_ATTN, - _Backend.FLASH_ATTN_VLLM_V1): - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - torch.ops._C_cache_ops.reshape_and_cache_flash( - cached_k, - cached_v, - kv_cache[0], - kv_cache[1], - attn_metadata. - cross_slot_mapping, # type: ignore[union-attr] - "auto", - i, - i, - ) - elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH, - _Backend.TORCH_SDPA): - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_local_key_value_heads, self.head_dim) - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - PagedAttention.write_to_paged_cache( - cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", i, i) - else: - raise ValueError( - f"Unsupported Attention backend {self.attn.backend} " - "enum found. Expected the Attention backend to be " - "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " - "XFORMERS or TORCH_SDPA.") - - # We have to call torch.sdpa for prefill when using a - # custom cross-attention mask. Because the mask is not a - # standard causal mask, neither a block diagonal mask which - # can be optimized by xformers.BlockDiagonalMask. - # The mask is specially calculated for supporting multi - # images and interleaved images. 
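As the comment above notes, prefill with a custom cross-attention mask falls back to `torch.nn.functional.scaled_dot_product_attention` rather than an optimized attention backend. Below is a minimal sketch of that pattern, with grouped-query key/value replication and an arbitrary non-causal boolean mask; all shapes and mask values are illustrative assumptions.

```python
# Minimal sketch (not vLLM code) of calling torch SDPA with a custom, non-causal
# mask and grouped-query kv replication, as done in the method body below.
import torch
import torch.nn.functional as F

num_kv_heads, group_size, head_dim = 2, 4, 8   # 8 query heads share 2 kv heads
q_len, kv_len = 5, 12

q = torch.randn(num_kv_heads, group_size, q_len, head_dim)
k = torch.randn(num_kv_heads, kv_len, head_dim)
v = torch.randn(num_kv_heads, kv_len, head_dim)

# Replicate each kv head across its query-head group.
k = k[:, None].expand(num_kv_heads, group_size, kv_len, head_dim).contiguous()
v = v[:, None].expand(num_kv_heads, group_size, kv_len, head_dim).contiguous()

# A boolean mask (True = may attend) that is neither causal nor block-diagonal:
# here every query token may attend only to the first half of the kv positions.
mask = torch.zeros(1, 1, q_len, kv_len, dtype=torch.bool)
mask[..., :kv_len // 2] = True

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
assert out.shape == (num_kv_heads, group_size, q_len, head_dim)
```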
- q_len = q.shape[0] - kv_len = k.shape[0] - q = q.transpose(0, 1).view(self.num_local_key_value_heads, - self.num_key_value_groups, q_len, - self.head_dim).contiguous() - k = k.transpose(0, - 1)[:, - None, :, :].expand(self.num_local_key_value_heads, - self.num_key_value_groups, - kv_len, - self.head_dim).contiguous() - v = v.transpose(0, - 1)[:, - None, :, :].expand(self.num_local_key_value_heads, - self.num_key_value_groups, - kv_len, - self.head_dim).contiguous() - attention_mask = attention_mask.view(1, 1, q_len, kv_len) - output = F.scaled_dot_product_attention(q, - k, - v, - attn_mask=attention_mask, - is_causal=False) - output = output.permute(2, 0, 1, 3).reshape( - q_len, self.num_local_heads * self.head_dim) - return output - - -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention - and feedforward.""" - - def __init__( - self, - config: config_mllama.MllamaTextConfig, - layer_idx: int, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.cross_attn = MllamaTextCrossAttention( - config=config, - layer_idx=layer_idx, - quant_config=quant_config, - prefix=f"{prefix}.cross_attn", - ) - - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) - - self.mlp = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.cross_attn( - hidden_states=hidden_states, - attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - cross_attention_states=cross_attention_states, - ) - hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_attn_gate.tanh( - ) * hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_mlp_gate.tanh( - ) * hidden_states - return hidden_states - - -class MllamaTextModel(nn.Module): - config_class = config_mllama.MllamaTextConfig - base_model_prefix = "model" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config.text_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, - config.hidden_size) - self.cross_attention_layers = config.cross_attention_layers - - layers = [] - for layer_idx in range(config.num_hidden_layers): - if layer_idx in self.cross_attention_layers: - layers.append( - MllamaCrossAttentionDecoderLayer( - config, - layer_idx, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - 
)) - else: - # TODO: force LlamaDecoderLayer to config.attention_bias=False - layers.append( - LlamaDecoderLayer( - config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - )) - - self.layers = nn.ModuleList(layers) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - cross_attention_states: Optional[torch.LongTensor], - cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, - torch.Tensor]], - skip_cross_attention: bool, - ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - if idx in self.cross_attention_layers: - if not skip_cross_attention: - hidden_states = decoder_layer( - hidden_states=hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask= - full_text_row_masked_out_mask, - ) - else: - hidden_states, residual = decoder_layer( - positions=positions, - hidden_states=hidden_states, - residual=None, - ) - hidden_states = hidden_states + residual - hidden_states = self.norm(hidden_states) - return hidden_states - - -class MllamaForCausalLM(nn.Module): - config_class = config_mllama.MllamaTextConfig - base_model_prefix = "language_model" - _no_split_modules = [ - "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config.text_config - quant_config = vllm_config.quant_config - self.quant_config = quant_config - - self.vocab_size = config.vocab_size - self.model = MllamaTextModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - prefix=f"{prefix}.lm_head", - ) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - cross_attention_states: Optional[torch.LongTensor], - cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, - torch.Tensor]], - skip_cross_attention: bool, - ) -> torch.Tensor: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - skip_cross_attention=skip_cross_attention, - ) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() - for name, loaded_weight in weights: - if 'patch_embedding.weight' in name: - name = name.replace('patch_embedding.weight', - 'patch_embedding._linear.weight') - 
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - updated_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - updated_params.add(name) - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - orig_name = name - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - logger.debug("Missing name %s, orig name %s", name, - orig_name) - continue - - param = params_dict.pop(name) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - updated_params.add(name) - return updated_params - - -@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor, - info=MllamaProcessingInfo, - dummy_inputs=MllamaDummyInputsBuilder) -class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsV0Only): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # mapping for new names in checkpoint saved after transformers v4.52 - "model.vision_model.": "vision_model.", - "model.multi_modal_projector.": "multi_modal_projector.", - "model.language_model.": "language_model.model.", - "lm_head.": "language_model.lm_head.", - }, - orig_to_new_suffix={ - "patch_embedding.weight": "patch_embedding._linear.weight", - }, - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return "<|image|>" - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config: MllamaConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.vocab_size = config.text_config.vocab_size - self.hidden_size = config.text_config.hidden_size - self.max_num_tiles = config.vision_config.max_num_tiles - self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = \ - config.pad_token_id if config.pad_token_id is not None else -1 - self.image_size = config.vision_config.image_size - self.image_token_id = config.image_token_index - - self.vision_model = MllamaVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_model")) - self.language_model = MllamaForCausalLM( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - self.multi_modal_projector = ColumnParallelLinear( - config.vision_config.vision_output_dim, - config.text_config.hidden_size, - bias=True, - quant_config=quant_config, - gather_output=True, - prefix=maybe_prefix(prefix, "multi_modal_projector"), - ) - self.logits_processor = LogitsProcessor(config.output_hidden_states, - config.text_config.vocab_size) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: 
- logits = self.logits_processor(self.language_model.lm_head, - hidden_states, sampling_metadata) - return logits - - def unpack_data(self, - image_data: Union[list[torch.Tensor], torch.Tensor], - padding_value=0) -> torch.Tensor: - if isinstance(image_data, torch.Tensor): - # torch.Tensor - return image_data - else: - assert isinstance( - image_data[0], - torch.Tensor), "Image data is not properly batched." - # list[torch.Tensor] - bsz = len(image_data) - max_length = max(t.size(0) for t in image_data) - trailing_dims = image_data[0].shape[1:] - for data in image_data: - cur_trailing_dims = data.shape[1:] - assert cur_trailing_dims == trailing_dims - output_tensor = torch.full((bsz, max_length, *trailing_dims), - padding_value, - dtype=image_data[0].dtype, - device=image_data[0].device) - for i, t in enumerate(image_data): - output_tensor[i, :t.size(0)] = t - return output_tensor - - def _parse_and_validate_image_input(self, **kwargs: object): - # tensor with the same shape will be batched together by - # MultiModalKwargs.batch, so pixel_values here can be: - # - list[torch.Tensor]: - # with shape (num_image, num_tiles, 3, image_res, image_res) - # - torch.Tensor: - # with shape (bs, num_image, num_tiles, 3, image_res, image_res) - pixel_values: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "pixel_values", None) - image_embeds: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "image_embeds", None) - aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "aspect_ratio_ids", None) - aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "aspect_ratio_mask", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - - if pixel_values is not None: - assert aspect_ratio_ids is not None - assert aspect_ratio_mask is not None - - return MllamaImagePixelInputs( - type="pixel_values", - data=self.unpack_data(pixel_values), - aspect_ratio_ids=self.unpack_data(aspect_ratio_ids), - aspect_ratio_mask=self.unpack_data(aspect_ratio_mask)) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def _get_and_validate_encoder_lens( - self, - encoder_seq_lens: list[int], - num_tiles: list[list[int]], - num_tokens_per_tile: int, - ) -> list[int]: - # Get the actual number of encoder tokens for each sample. - # Because attn_metadata.encoder_seq_lens only counts the last - # group of images for each sample, which is used to cheat the - # block manager to allocate blocks for those images only. - # See MllamaMultiModalProcessor for more details. 
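The comment above describes the bookkeeping this method performs: the true number of encoder (vision) tokens per sample is its total tile count times the per-tile token count, where a tile of side `image_size` contributes `(image_size // 14) ** 2` patch tokens plus one class token, as in `calc_token_per_chunk` earlier in this file. A minimal sketch with illustrative values:

```python
# Minimal sketch (not vLLM code) of the encoder-token accounting described above.
def tokens_per_tile(image_size: int) -> int:
    assert image_size % 14 == 0, "tile size should be a multiple of 14"
    return (image_size // 14) ** 2 + 1


def actual_encoder_seq_lens(num_tiles: list[list[int]],
                            image_size: int) -> list[int]:
    per_tile = tokens_per_tile(image_size)
    return [sum(tiles) * per_tile for tiles in num_tiles]


# Two samples: one with a 4-tile and a 2-tile image, one with a single 1-tile image.
assert tokens_per_tile(560) == 1601
assert actual_encoder_seq_lens([[4, 2], [1]], image_size=560) == [9606, 1601]
```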
- actual_encoder_seq_lens = [ - sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles - ] - - # remove 0 encoder len entries for text-only requests for these - # assertions - attn_metadata_lens = [x for x in encoder_seq_lens if x > 0] - assert len(actual_encoder_seq_lens) == len(attn_metadata_lens) - for actual_len, last_group_len in zip(actual_encoder_seq_lens, - attn_metadata_lens): - assert actual_len >= last_group_len - - return actual_encoder_seq_lens - - def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: list[int]): - - cross_attention_states_flat = torch.zeros( - sum(actual_encoder_seq_lens), - cross_attention_states.shape[-1], - device=cross_attention_states.device, - dtype=cross_attention_states.dtype) - start_pos = 0 - for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens, - cross_attention_states): - end_pos = start_pos + seq_len - cross_attention_states_flat[ - start_pos:end_pos] = vision_token_in_batch[:seq_len] - start_pos = end_pos - cross_attention_states = cross_attention_states_flat - return cross_attention_states - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_cross_attention_states( - self, - image_inputs: MllamaImagePixelInputs, - attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: list[int], - ) -> tuple[torch.Tensor]: - # NOTE: llama's reference implementation runs vision model on CPU - pixel_values = image_inputs['data'] - aspect_ratio_ids = image_inputs['aspect_ratio_ids'] - aspect_ratio_mask = image_inputs['aspect_ratio_mask'] - cross_attention_states = self.vision_model(pixel_values, - aspect_ratio_ids, - aspect_ratio_mask) - cross_attention_states, _ = self.multi_modal_projector( - cross_attention_states) - - bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) - cross_attention_states = cross_attention_states.view( - bsz, -1, image_token_dim) - - cross_attention_states = self.flat_encoder_result( - cross_attention_states, attn_metadata, actual_encoder_seq_lens) - - return cross_attention_states - - def get_cross_attention_mask( - self, - input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, - num_tiles: list[list[int]], - num_tokens_per_tile: int, - dtype: torch.dtype, - ) -> tuple[torch.Tensor, torch.Tensor]: - token_ids = input_ids.tolist() - start = 0 - batch_token_ids = [] - for seq_len in attn_metadata.seq_lens: - batch_token_ids.append(token_ids[start:start + seq_len]) - start += seq_len - sparse_mask = [ - get_cross_attention_token_mask(t, self.image_token_id) - for t in batch_token_ids - ] - - # Skip generating cross-attention mask if all samples - # are text-only or have only 1 leading image. 
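As a quick standalone illustration of the skip condition in the comment above (it mirrors the `skip_attention_mask` helper defined further below; the token ranges are made up):

# Each sample's sparse mask is a list of [start, end] ranges, one per image.
skippable = [
    [],          # text-only sample
    [[0, -1]],   # exactly one image, placed at the start of the prompt
]
not_skippable = [
    [[5, -1]],            # single image, but not at the leading position
    [[0, -1], [7, -1]],   # more than one image in the sample
]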
- if skip_attention_mask(sparse_mask): - return None, None - - dense_mask, tile_range_for_decode = \ - convert_sparse_cross_attention_mask_to_dense( - sparse_mask, num_tiles, attn_metadata.seq_lens) - cross_attention_mask = \ - convert_dense_cross_attention_mask_to_tensor( - dense_mask, num_tokens_per_tile, input_ids.device, dtype) - kv_range_for_decode = [[ - t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile - ] for t in tile_range_for_decode] - - return cross_attention_mask, kv_range_for_decode - - def get_full_text_row_masked_out_mask( - self, - attn_metadata: AttentionMetadata, - device: torch.device, - ) -> torch.Tensor: - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) - start_pos = 0 - for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens, - attn_metadata.encoder_seq_lens): - if encoder_seq_len == 0: - full_text_row_masked_out_mask[start_pos:start_pos + - seq_len] = False - start_pos += seq_len - full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( - device) - return full_text_row_masked_out_mask - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - **kwargs: object, - ) -> Union[CausalLMOutputWithPast]: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - raise ValueError("Chunk prefill not supported") - image_inputs = self._parse_and_validate_image_input(**kwargs) - cross_attention_states = None - cross_attention_mask = None - kv_range_for_decode = None - - # For 1) text-only prefill and decode, 2) image-present decode. - if image_inputs is None: - full_text_row_masked_out_mask = ( - attn_metadata.encoder_seq_lens_tensor - != 0).reshape(-1, 1).to(input_ids.device) - skip_cross_attention = attn_metadata.max_encoder_seq_len == 0 - - # For image-present prefill. - else: - skip_cross_attention = False - - num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")] - num_tokens_per_tile = calc_token_per_chunk(self.image_size) - - actual_encoder_seq_lens = self._get_and_validate_encoder_lens( - attn_metadata.encoder_seq_lens, - num_tiles, - num_tokens_per_tile, - ) - - cross_attention_states = self.get_cross_attention_states( - image_inputs, attn_metadata, actual_encoder_seq_lens) - - full_text_row_masked_out_mask = \ - self.get_full_text_row_masked_out_mask( - attn_metadata, input_ids.device) - - cross_attention_mask, kv_range_for_decode = \ - self.get_cross_attention_mask( - input_ids, attn_metadata, num_tiles, - num_tokens_per_tile, cross_attention_states.dtype) - - outputs = self.language_model( - input_ids=input_ids, - positions=positions, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - skip_cross_attention=skip_cross_attention, - ) - - return outputs - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="language_model", - connector="multi_modal_projector", - tower_model="vision_model") - - -def skip_attention_mask(sparse_mask: list[list[int]]) -> bool: - for mask in sparse_mask: - # Skip text-only samples. 
- if len(mask) == 0: - continue - # If the sample contains more than 1 images, - # we can't skip mask. - if len(mask) != 1: - return False - # If the sample contains only 1 image, - # but the image is not the leading one, - # we can't skip mask. - if mask[0][0] != 0 or mask[0][1] != -1: - return False - return True - - -def convert_sparse_cross_attention_mask_to_dense( - sparse_mask: list[list[list[int]]], - num_tiles: list[list[int]], - lengths: list[int], -) -> tuple[np.ndarray, list[tuple[int, int]]]: - total_length = sum(lengths) - total_tiles = sum([sum(tiles) for tiles in num_tiles]) - dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) - # A list of ranges, range[i] = [start, end] means that the i-th image will - # use tiles[start, end] for cross-attention decoding. - tile_range_for_decode = [] - - seq_start = 0 - tile_start = 0 - - # sparse_mask has an [] entry for each sequence that does not have images, - # but num_tiles does not have these entries... - num_tiles_idx = 0 - for masks, length in zip(sparse_mask, lengths): - if len(masks) == 0: - # Text only - continue - - tiles = num_tiles[num_tiles_idx] - num_tiles_idx += 1 - ts, td = -1, 0 - for mask, tile in zip(masks, tiles): - if len(mask) != 2: - continue - start, end = mask - end = min(end, length) - if end == -1: - end = length - if end == length: - if ts == -1: - ts = tile_start - td += tile - dense_mask[seq_start + start:seq_start + end, - tile_start:tile_start + tile] = 1 - tile_start += tile - assert ts != -1 - assert td != 0 - tile_range_for_decode.append((ts, ts + td)) - seq_start += length - assert num_tiles_idx == len(num_tiles) - - return dense_mask, tile_range_for_decode - - -def convert_dense_cross_attention_mask_to_tensor( - cross_attention_token_mask: np.ndarray, - num_tokens_per_tile: int, - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device) - mask = mask.repeat_interleave(num_tokens_per_tile, dim=1) - - mask = 1.0 - mask - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min) - - ninf = torch.finfo(dtype).min - full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None]) - mask *= full_text_mask - # (num_prompt_tokens, num_encoder_tokens) - return mask diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 595bdd17cf2c2..b624a6200ab3d 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -19,7 +19,7 @@ import math from collections.abc import Iterable, Mapping from itertools import tee -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -27,66 +27,87 @@ from transformers import BatchFeature, Llama4Config, Llama4VisionConfig from transformers.image_utils import SizeDict from transformers.models.llama4 import Llama4Processor from transformers.models.llama4.image_processing_llama4_fast import ( - find_supported_resolutions, get_best_fit) + find_supported_resolutions, + get_best_fit, +) from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import InputProcessingContext -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from 
vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsMultiModal, + SupportsPP, +) from .llama4 import Llama4ForCausalLM -from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix +from .vision import run_dp_sharded_vision_model -class Llama4ImagePatchInputs(TypedDict): - type: Literal["pixel_values"] - flat_data: torch.Tensor +class Llama4ImagePatchInputs(TensorSchema): """ - Shape: - `(batch_size * num_chunks, num_channels, image size, image size)` + Dimensions: + - batch_size: Batch size + - total_num_chunks: Batch size * number of chunks + - num_channels: Number of channels + - image_size: Size of each image """ - patches_per_image: torch.Tensor + + type: Literal["pixel_values"] = "pixel_values" + + flat_data: Annotated[ + torch.Tensor, + TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"), + ] + + patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")] """ The number of total patches for each image in the batch. - + This is used to split the embeddings which has the first two dimensions flattened just like `flat_data`. """ - aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] + aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] """ A list of aspect ratios corresponding to the number of tiles in each dimension that each image in the batch corresponds to. - - Shape: - `(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)` + Each aspect ratio is a pair (ratio_h, ratio_w). 
""" class Llama4VisionMLP(nn.Module): - def __init__( self, input_size: int, @@ -99,22 +120,21 @@ class Llama4VisionMLP(nn.Module): use_data_parallel: bool = False, ): super().__init__() - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1( + self.fc1 = ColumnParallelLinear( input_size=input_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) - cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear - self.fc2 = cls_fc2( + self.fc2 = RowParallelLinear( input_size=intermediate_size, output_size=output_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) self.activation_fn = nn.GELU() self.output_activation = output_activation @@ -129,7 +149,6 @@ class Llama4VisionMLP(nn.Module): class Llama4MultiModalProjector(nn.Module): - def __init__( self, config, @@ -159,9 +178,9 @@ def pixel_shuffle(input_tensor, shuffle_ratio): input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() - reshaped_tensor = input_tensor.view(batch_size, height, - int(width * shuffle_ratio), - int(channels / shuffle_ratio)) + reshaped_tensor = input_tensor.view( + batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) + ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() reshaped_tensor = reshaped_tensor.view( @@ -172,13 +191,11 @@ def pixel_shuffle(input_tensor, shuffle_ratio): ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() - output_tensor = reshaped_tensor.view(batch_size, -1, - reshaped_tensor.shape[-1]) + output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) return output_tensor class Llama4VisionPixelShuffleMLP(nn.Module): - def __init__( self, config, @@ -188,8 +205,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module): ): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio - self.inner_dim = int(config.projector_input_dim // - (self.pixel_shuffle_ratio**2)) + self.inner_dim = int( + config.projector_input_dim // (self.pixel_shuffle_ratio**2) + ) self.output_dim = config.projector_output_dim self.mlp = Llama4VisionMLP( input_size=config.intermediate_size, @@ -203,13 +221,11 @@ class Llama4VisionPixelShuffleMLP(nn.Module): ) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: - encoded_patches = pixel_shuffle(encoded_patches, - self.pixel_shuffle_ratio) + encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) return self.mlp(encoded_patches) class Llama4VisionAttention(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -219,8 +235,9 @@ class Llama4VisionAttention(nn.Module): ): super().__init__() self.config = config - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -231,8 +248,9 @@ class Llama4VisionAttention(nn.Module): self.attention_dropout = config.attention_dropout self.scaling = self.head_dim**-0.5 - self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim, - self.scaling) + self.attn = MultiHeadAttention( + self.num_local_heads, self.head_dim, self.scaling + ) if use_data_parallel: self.qkv_proj = ReplicatedLinear( @@ -271,7 
+289,7 @@ class Llama4VisionAttention(nn.Module): head_size=self.head_dim, rotary_dim=config.hidden_size // config.num_attention_heads // 2, # number of image patches - max_position=(config.image_size // config.patch_size)**2, + max_position=(config.image_size // config.patch_size) ** 2, base=config.rope_theta, rope_scaling={"rope_type": "mllama4"}, is_neox_style=False, @@ -302,7 +320,6 @@ class Llama4VisionAttention(nn.Module): class Llama4VisionEncoderLayer(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -351,12 +368,11 @@ class Llama4VisionEncoderLayer(nn.Module): hidden_state = self.mlp(hidden_state) hidden_state = residual + hidden_state - outputs = (hidden_state, ) + outputs = (hidden_state,) return outputs class Llama4VisionEncoder(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -366,14 +382,17 @@ class Llama4VisionEncoder(nn.Module): ): super().__init__() self.config = config - self.layers = nn.ModuleList([ - Llama4VisionEncoderLayer( - config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - use_data_parallel=use_data_parallel, - ) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Llama4VisionEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -381,11 +400,10 @@ class Llama4VisionEncoder(nn.Module): ) -> torch.Tensor: r""" Args: - inputs_embeds (`torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded representation. This is useful if you - want more control over how to convert `input_ids` indices into + hidden_states: Input tensor of shape + (batch_size, sequence_length, hidden_size). + Hidden states from the model embeddings, representing + the input tokens. associated vectors than the model's internal embedding lookup matrix. 
""" @@ -398,7 +416,6 @@ class Llama4VisionEncoder(nn.Module): class Llama4UnfoldConvolution(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -410,22 +427,16 @@ class Llama4UnfoldConvolution(nn.Module): kernel_size = config.patch_size if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) - self.unfold = torch.nn.Unfold(kernel_size=kernel_size, - stride=config.patch_size) - params = { - "input_size": - config.num_channels * kernel_size[0] * kernel_size[1], - "output_size": config.hidden_size, - "bias": False, - "quant_config": quant_config, - "prefix": f"{prefix}.linear", - } - if use_data_parallel: - cls = ReplicatedLinear - else: - cls = ColumnParallelLinear - params["gather_output"] = True - self.linear = cls(**params) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) + self.linear = ColumnParallelLinear( + input_size=config.num_channels * kernel_size[0] * kernel_size[1], + output_size=config.hidden_size, + bias=False, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.linear", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) @@ -435,7 +446,6 @@ class Llama4UnfoldConvolution(nn.Module): class Llama4VisionModel(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -450,7 +460,7 @@ class Llama4VisionModel(nn.Module): self.hidden_size = config.hidden_size self.num_channels = config.num_channels - self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 self.patch_embedding = Llama4UnfoldConvolution( @@ -460,10 +470,10 @@ class Llama4VisionModel(nn.Module): use_data_parallel=use_data_parallel, ) - self.class_embedding = nn.Parameter(self.scale * - torch.randn(self.hidden_size)) + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) self.positional_embedding_vlm = nn.Parameter( - self.scale * torch.randn(self.num_patches, self.hidden_size)) + self.scale * torch.randn(self.num_patches, self.hidden_size) + ) # layer norms self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) @@ -492,8 +502,9 @@ class Llama4VisionModel(nn.Module): num_tiles, num_patches, hidden_dim = hidden_state.shape # Add cls token - class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, - hidden_state.shape[-1]) + class_embedding = self.class_embedding.expand( + hidden_state.shape[0], 1, hidden_state.shape[-1] + ) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) num_patches += 1 @@ -505,7 +516,8 @@ class Llama4VisionModel(nn.Module): hidden_dim, ) positional_embedding = self.positional_embedding_vlm.to( - dtype=hidden_state.dtype, device=hidden_state.device) + dtype=hidden_state.dtype, device=hidden_state.device + ) hidden_state = hidden_state + positional_embedding hidden_state = self.layernorm_pre(hidden_state) hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) @@ -524,7 +536,6 @@ class Llama4VisionModel(nn.Module): class Mllama4ProcessingInfo(BaseProcessingInfo): - def __init__(self, ctx: InputProcessingContext) -> None: super().__init__(ctx) @@ -532,9 +543,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): return self.ctx.get_hf_config(Llama4Config) def get_hf_processor(self, **kwargs: object) -> Llama4Processor: - return self.ctx.get_hf_processor(Llama4Processor, - use_fast=kwargs.pop("use_fast", True), - **kwargs) + return 
self.ctx.get_hf_processor( + Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: # Although vLLM can support more images from an infra capability @@ -546,13 +557,13 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): image_size = vision_config.image_size patch_size = vision_config.patch_size - assert ( - image_size % - patch_size == 0), f"chunk size {image_size} should be multiple of " + assert image_size % patch_size == 0, ( + f"chunk size {image_size} should be multiple of " + ) f"patch_size {patch_size}" ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2))) - return (image_size // patch_size)**2 // ds_ratio + return (image_size // patch_size) ** 2 // ds_ratio def get_max_num_tiles(self) -> int: image_processor = self.get_hf_processor().image_processor @@ -562,13 +573,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): vision_config = self.get_hf_config().vision_config image_size = vision_config.image_size # Result in the max possible feature size (h:w = 16:1) - return ImageSize(height=self.get_max_num_tiles() * image_size, - width=image_size) + return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size) -class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] - ): - +class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -592,15 +600,16 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] vision_config = self.info.get_hf_config().vision_config if processed_outputs.get("pixel_values") is not None: - assert ( - "images" in mm_data - ), "images expected to be in mm_data when pixel_values is present" + assert "images" in mm_data, ( + "images expected to be in mm_data when pixel_values is present" + ) images = mm_data["images"] - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) tile_size = vision_config.image_size possible_resolutions = find_supported_resolutions( @@ -612,20 +621,20 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] (image.size[1], image.size[0]), torch.tensor(possible_resolutions), resize_to_max_canvas=image_processor.resize_to_max_canvas, - ) for image in parsed_images + ) + for image in parsed_images ] # TODO tile height/width do not necessarily need to match - aspect_ratios = [(image_size[0] // tile_size, - image_size[1] // tile_size) - for image_size in best_fit_sizes] + aspect_ratios = [ + (image_size[0] // tile_size, image_size[1] // tile_size) + for image_size in best_fit_sizes + ] patches_per_image = [ - 1 if r_h * r_w == 1 else 1 + r_h * r_w - for (r_h, r_w) in aspect_ratios + 1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios ] - processed_outputs["aspect_ratios"] = aspect_ratios - processed_outputs["patches_per_image"] = torch.tensor( - patches_per_image) + processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios) + processed_outputs["patches_per_image"] = torch.tensor(patches_per_image) return processed_outputs @@ -637,7 +646,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", 
patches_per_image), + "image", patches_per_image + ), patches_per_image=MultiModalFieldConfig.batched("image"), aspect_ratios=MultiModalFieldConfig.batched("image"), ) @@ -677,7 +687,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -690,17 +699,21 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - (target_width, - target_height) = self.info.get_image_size_with_most_features() + (target_width, target_height) = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } @@ -709,13 +722,16 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): info=Mllama4ProcessingInfo, dummy_inputs=Mllama4DummyInputsBuilder, ) -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class Llama4ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 +): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -741,24 +757,42 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, use_data_parallel=self.use_data_parallel, ) self.multi_modal_projector = Llama4MultiModalProjector( - self.config, - None, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector") + ) else: self.vision_model = None self.multi_modal_projector = None self.language_model = initialize_model( - vllm_config=vllm_config.with_hf_config(config.text_config, - ["LlamaForCausalLM"]), + vllm_config=vllm_config.with_hf_config( + config.text_config, ["LlamaForCausalLM"] + ), prefix=maybe_prefix(prefix, "language_model"), model_class=Llama4ForCausalLM, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for EAGLE3.""" + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, "set_aux_hidden_state_layers") + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get the layer indices for auxiliary hidden state outputs. + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. 
+ """ + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") + return self.language_model.get_eagle3_aux_hidden_state_layers() def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: + self, **kwargs: object + ) -> Optional[Llama4ImagePatchInputs]: # num_images, 1, num_chunks, channel, image_size, image_size pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: @@ -768,11 +802,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # TODO: confirm handling for variable lengths flat_pixel_values = flatten_bn(pixel_values, concat=True) patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) - - aspect_ratios = kwargs.pop("aspect_ratios", None) - if not isinstance(aspect_ratios, (torch.Tensor, list)): - raise ValueError("Incorrect type of aspect_ratios. " - f"Got type: {type(aspect_ratios)}") + aspect_ratios = kwargs.pop("aspect_ratios") + if aspect_ratios.ndim == 3: + aspect_ratios = aspect_ratios.squeeze(1) return Llama4ImagePatchInputs( type="pixel_values", @@ -782,8 +814,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ) def _process_image_input( - self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: - + self, image_input: Llama4ImagePatchInputs + ) -> MultiModalEmbeddings: assert self.vision_model and self.multi_modal_projector flat_data = image_input["flat_data"] patches_per_image = image_input["patches_per_image"].tolist() @@ -791,12 +823,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # shard image input if self.use_data_parallel: vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model) + flat_data, self.vision_model + ) else: vision_embeddings_flat = self.vision_model(flat_data) - vision_embeddings_flat = self.multi_modal_projector( - vision_embeddings_flat) + vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat) return [ img.flatten(0, 1) @@ -813,24 +845,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -842,31 +856,21 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, - # this condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - return self.language_model(input_ids, positions, intermediate_tensors, - inputs_embeds) + return self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def separate_weights( self, weights: Iterable[tuple[str, torch.Tensor]], prefix: str, - ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[ - str, torch.Tensor]]]: + ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]: weights1, weights2 = tee(weights, 2) def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]: @@ -908,31 +912,33 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str: """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM format.""" - if name.startswith("model.") or name.startswith( - "language_model.model."): - renamed = name.replace("model.", "language_model.model.", - 1) if name.startswith("model.") else name + if name.startswith("model.") or name.startswith("language_model.model."): + renamed = ( + name.replace("model.", "language_model.model.", 1) + if name.startswith("model.") + else name + ) # Handle expert scale parameters with flat naming - if "feed_forward.experts." in name and ("_input_scale" in name or - "_weight_scale" in name): + if "feed_forward.experts." in name and ( + "_input_scale" in name or "_weight_scale" in name + ): # Map checkpoint naming to vLLM's expected naming if "down_proj_input_scale" in renamed: - return renamed.replace("down_proj_input_scale", - "w2_input_scale") + return renamed.replace("down_proj_input_scale", "w2_input_scale") elif "down_proj_weight_scale" in renamed: - return renamed.replace("down_proj_weight_scale", - "w2_weight_scale") + return renamed.replace("down_proj_weight_scale", "w2_weight_scale") elif "gate_up_proj_input_scale" in renamed: - return renamed.replace("gate_up_proj_input_scale", - "w13_input_scale") + return renamed.replace( + "gate_up_proj_input_scale", "w13_input_scale" + ) elif "gate_up_proj_weight_scale" in renamed: - return renamed.replace("gate_up_proj_weight_scale", - "w13_weight_scale") + return renamed.replace( + "gate_up_proj_weight_scale", "w13_weight_scale" + ) return renamed # Handle attention scale parameters - elif "self_attn." in name and (".k_scale" in name - or ".v_scale" in name): + elif "self_attn." 
in name and (".k_scale" in name or ".v_scale" in name): if ".k_proj.k_scale" in renamed: return renamed.replace(".k_proj.k_scale", ".attn.k_scale") elif ".v_proj.v_scale" in renamed: @@ -943,8 +949,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return renamed elif name.startswith("lm_head.weight"): - return name.replace("lm_head.weight", - "language_model.lm_head.weight") + return name.replace("lm_head.weight", "language_model.lm_head.weight") return name @@ -967,7 +972,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return language_model_weights, other_weights def _handle_expert_scale_broadcasting( - self, weights: list[tuple[str, torch.Tensor]], params_dict: dict + self, weights: list[tuple[str, torch.Tensor]], params_dict: dict ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]: """Handle expert scale parameters that need broadcasting. @@ -980,12 +985,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, for name, weight in weights: # Check if this is an expert scale parameter that needs broadcasting - if ("feed_forward.experts." in name and "scale" in name - and ".shared_expert" not in name): + if ( + "feed_forward.experts." in name + and "scale" in name + and ".shared_expert" not in name + ): if name in params_dict: param = params_dict[name] - if (hasattr(param, 'data') and param.data.numel() > 1 - and weight.numel() == 1): + if ( + hasattr(param, "data") + and param.data.numel() > 1 + and weight.numel() == 1 + ): # Broadcast single value to all experts param.data.fill_(weight.item()) updated_params.add(name) @@ -997,10 +1008,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, return regular_weights, expert_scale_weights, updated_params - def _load_other_weights(self, other_weights: Iterable[tuple[str, - torch.Tensor]], - params_dict: dict, - stacked_params_mapping: list) -> set[str]: + def _load_other_weights( + self, + other_weights: Iterable[tuple[str, torch.Tensor]], + params_dict: dict, + stacked_params_mapping: list, + ) -> set[str]: """Load non-language-model weights with stacking support.""" updated_params = set() @@ -1021,16 +1034,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, else: # Use regular weight loading param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) updated_params.add(name) return updated_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), @@ -1047,8 +1057,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, updated_params: set[str] = set() # Separate and rename weights - language_model_weights, other_weights = ( - self._separate_and_rename_weights(weights)) + language_model_weights, other_weights = self._separate_and_rename_weights( + weights + ) # Skip loading vision model and projector if they're not initialized. 
if self.vision_model is None and self.multi_modal_projector is None: @@ -1056,8 +1067,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, # Handle expert scale parameters regular_weights, expert_scale_weights, updated_params_from_experts = ( - self._handle_expert_scale_broadcasting(language_model_weights, - params_dict)) + self._handle_expert_scale_broadcasting(language_model_weights, params_dict) + ) updated_params.update(updated_params_from_experts) loader = AutoWeightsLoader(self) @@ -1066,13 +1077,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, updated_params.update(loaded_language_model_params) if expert_scale_weights: - loaded_expert_scale_params = loader.load_weights( - expert_scale_weights) + loaded_expert_scale_params = loader.load_weights(expert_scale_weights) if loaded_expert_scale_params: updated_params.update(loaded_expert_scale_params) updated_params.update( - self._load_other_weights(other_weights, params_dict, - stacked_params_mapping)) + self._load_other_weights(other_weights, params_dict, stacked_params_mapping) + ) return updated_params diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index c6a97388dc188..4901ac74fb28b 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -8,13 +8,15 @@ import torch import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from .utils import maybe_prefix + SQRT2 = 2**0.5 @@ -74,8 +76,7 @@ class MLPSpeculator(nn.Module): self.n_predict = config.n_predict self.vocab_size = config.vocab_size self.emb_dim = config.emb_dim - self.inner_dim = config.inner_dim if config.inner_dim != 0 \ - else config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim self.max_speculative_tokens = config.num_lookahead_tokens @@ -83,124 +84,153 @@ class MLPSpeculator(nn.Module): self.scale_input = config.scale_input if self.tie_weights: - assert ( - self.n_predict > 1 - ), "You cannot tie weights between stages when only 1 exists" + assert self.n_predict > 1, ( + "You cannot tie weights between stages when only 1 exists" + ) embedding = VocabParallelEmbedding( - config.vocab_size, - self.inner_dim, - org_num_embeddings=config.vocab_size) + config.vocab_size, self.inner_dim, org_num_embeddings=config.vocab_size + ) self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens) # the initial projection from the base model may # have a different size, so that stays separate. 
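A minimal standalone sketch (illustrative sizes only, not from this patch) of the weight-tying pattern used here: repeating the same module object inside an `nn.ModuleList` makes every stage reuse a single set of parameters.

import torch.nn as nn

proj_tied = nn.Linear(8, 8, bias=False)
stages = nn.ModuleList([proj_tied] * 3)

# All entries are the same object, so they share (and co-train) one weight matrix.
assert stages[0] is stages[2]
assert stages[0].weight.data_ptr() == stages[1].weight.data_ptr()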
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False) proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False) - self.proj = nn.ModuleList([proj_first] + [proj_tied] * - (self.max_speculative_tokens - 1)) + self.proj = nn.ModuleList( + [proj_first] + [proj_tied] * (self.max_speculative_tokens - 1) + ) - head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) - self.head = nn.ModuleList([head] * self.max_speculative_tokens) + self.head = nn.ModuleList( + [ + ParallelLMHead( + self.vocab_size, + self.inner_dim, + bias=False, + prefix=maybe_prefix(prefix, f"head.{i}"), + ) + for i in range(self.max_speculative_tokens) + ] + ) - ln = MLPSpeculatorLayerNorm(self.inner_dim, - elementwise_scale_and_shift=True) + ln = MLPSpeculatorLayerNorm( + self.inner_dim, elementwise_scale_and_shift=True + ) self.ln = nn.ModuleList([ln] * self.max_speculative_tokens) else: - self.emb = nn.ModuleList([ - VocabParallelEmbedding(config.vocab_size, - self.inner_dim, - org_num_embeddings=config.vocab_size) - for _ in range(self.max_speculative_tokens) - ]) + self.emb = nn.ModuleList( + [ + VocabParallelEmbedding( + config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size, + ) + for _ in range(self.max_speculative_tokens) + ] + ) - self.proj = nn.ModuleList([ - nn.Linear((self.emb_dim if i == 0 else self.inner_dim), - self.inner_dim, - bias=False) - for i in range(self.max_speculative_tokens) - ]) + self.proj = nn.ModuleList( + [ + nn.Linear( + (self.emb_dim if i == 0 else self.inner_dim), + self.inner_dim, + bias=False, + ) + for i in range(self.max_speculative_tokens) + ] + ) - self.head = nn.ModuleList([ - ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) - for _ in range(self.max_speculative_tokens) - ]) - self.ln = nn.ModuleList([ - MLPSpeculatorLayerNorm(self.inner_dim, - elementwise_scale_and_shift=True) - for _ in range(self.max_speculative_tokens) - ]) + self.head = nn.ModuleList( + [ + ParallelLMHead( + self.vocab_size, + self.inner_dim, + bias=False, + prefix=maybe_prefix(prefix, f"head.{i}"), + ) + for i in range(self.max_speculative_tokens) + ] + ) + self.ln = nn.ModuleList( + [ + MLPSpeculatorLayerNorm( + self.inner_dim, elementwise_scale_and_shift=True + ) + for _ in range(self.max_speculative_tokens) + ] + ) if self.scale_input: self.ln0 = MLPSpeculatorLayerNorm( - self.emb_dim, elementwise_scale_and_shift=False) + self.emb_dim, elementwise_scale_and_shift=False + ) - self.state_weight = 0.5**(0.5 / config.n_predict) - self.emb_weight = math.sqrt( - (1 - self.state_weight**2) * (self.inner_dim / 2)) + self.state_weight = 0.5 ** (0.5 / config.n_predict) + self.emb_weight = math.sqrt((1 - self.state_weight**2) * (self.inner_dim / 2)) self.activation = nn.GELU() self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - config.vocab_size, 1.0) - self.sampler = get_sampler() + self.logits_processor = LogitsProcessor( + config.vocab_size, config.vocab_size, 1.0 + ) - def generate_proposals( - self, - input_ids: torch.Tensor, - previous_hidden_states: torch.Tensor, - num_predict_tokens: int, - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - if num_predict_tokens > self.max_speculative_tokens: - raise ValueError(f"Max speculative tokens for model is " - f"{self.max_speculative_tokens}, but " - f"{num_predict_tokens} were requested") + # NOTE(woosuk): This method is commented out because it is old code + # using V0. We should either port it to V1 or remove it. 
- # b x 1 x d - previous_hidden_states = previous_hidden_states.unsqueeze(1) + # def generate_proposals( + # self, + # input_ids: torch.Tensor, + # previous_hidden_states: torch.Tensor, + # num_predict_tokens: int, + # sampling_metadata: SamplingMetadata, + # ) -> list[SamplerOutput]: + # if num_predict_tokens > self.max_speculative_tokens: + # raise ValueError(f"Max speculative tokens for model is " + # f"{self.max_speculative_tokens}, but " + # f"{num_predict_tokens} were requested") - if self.scale_input: - previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 + # # b x 1 x d + # previous_hidden_states = previous_hidden_states.unsqueeze(1) - # b x 1 - last_tokens = input_ids.unsqueeze(1) + # if self.scale_input: + # previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 - next_tokens = [] + # # b x 1 + # last_tokens = input_ids.unsqueeze(1) - for head_index in range(num_predict_tokens): + # next_tokens = [] - # Project and predict - z = self.emb[head_index](last_tokens) # b k d - states = self.proj[head_index](previous_hidden_states) + # for head_index in range(num_predict_tokens): - # Weighted add of state_weight*state and emb_weight*z - # Let subsequent LN take care of denominator - # state_weight is close to 1, so shouldn't be any precision issues - states.add_(z, alpha=self.emb_weight / self.state_weight) + # # Project and predict + # z = self.emb[head_index](last_tokens) # b k d + # states = self.proj[head_index](previous_hidden_states) - states = self.activation(self.ln[head_index](states)) # b k d - previous_hidden_states = states - # TODO: not yet supporting top_k_tokens_per_head - states = states.flatten(0, 1) + # # Weighted add of state_weight*state and emb_weight*z + # # Let subsequent LN take care of denominator + # # state_weight is close to 1, so shouldn't be any precision issues + # states.add_(z, alpha=self.emb_weight / self.state_weight) - logits = self.logits_processor(self.head[head_index], states, - sampling_metadata) + # states = self.activation(self.ln[head_index](states)) # b k d + # previous_hidden_states = states + # # TODO: not yet supporting top_k_tokens_per_head + # states = states.flatten(0, 1) - output = self.sampler(logits, sampling_metadata) - last_tokens = output.sampled_token_ids - next_tokens.append(output) + # logits = self.logits_processor(self.head[head_index], states, + # sampling_metadata) - return next_tokens + # output = self.sampler(logits, sampling_metadata) + # last_tokens = output.sampled_token_ids + # next_tokens.append(output) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + # return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: name = name.replace("speculator.", "") param = params_dict.get(name) if param is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 72290bf2ee29f..58e2acb8ce922 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -6,41 +6,46 @@ from typing import Optional, Union import torch from torch import nn from transformers import ModernBertConfig +from 
transformers.activations import ACT2FN from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, - PoolingParamsUpdate, - PoolingType) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingParamsUpdate, + PoolingType, +) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata -from .interfaces import SupportsCrossEncoding, default_pooling_type -from .utils import WeightsMapper, maybe_prefix +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix class ModernBertEmbeddings(nn.Module): - def __init__(self, config: ModernBertConfig): - super().__init__() self.config = config - self.tok_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps, - bias=config.norm_bias) + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, bias=config.norm_bias + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) def forward( self, @@ -56,24 +61,20 @@ class ModernBertEmbeddings(nn.Module): class ModernBertRotaryEmbedding(RotaryEmbedding): - - def __init__(self, config: ModernBertConfig, head_size: int, dim: int, - base: float): + def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float): super().__init__( head_size=head_size, rotary_dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, is_neox_style=True, - dtype=torch.float16) + dtype=torch.float16, + ) self.config = config class ModernBertAttention(nn.Module): - - def __init__(self, - config: ModernBertConfig, - layer_id: Optional[int] = None): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -95,24 +96,27 @@ class ModernBertAttention(nn.Module): sliding_window = None if layer_id % config.global_attn_every_n_layers != 0: sliding_window = config.local_attention // 2 - rope_theta = config.local_rope_theta if config.local_rope_theta \ - is not None else config.global_rope_theta + rope_theta = ( + config.local_rope_theta + if config.local_rope_theta is not None + else config.global_rope_theta + ) else: rope_theta = config.global_rope_theta - self.rotary_emb = ModernBertRotaryEmbedding(config=config, - head_size=self.head_dim, - 
dim=self.head_dim, - base=rope_theta) + self.rotary_emb = ModernBertRotaryEmbedding( + config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta + ) self.attn = EncoderOnlyAttention( self.num_heads, self.head_dim, self.scaling, prefix=f"{layer_id}.attn", - per_layer_sliding_window=sliding_window) - self.Wo = RowParallelLinear(config.hidden_size, - config.hidden_size, - bias=config.attention_bias) + per_layer_sliding_window=sliding_window, + ) + self.Wo = RowParallelLinear( + config.hidden_size, config.hidden_size, bias=config.attention_bias + ) def forward( self, @@ -129,17 +133,16 @@ class ModernBertAttention(nn.Module): class ModernBertMLP(nn.Module): - def __init__(self, config: ModernBertConfig): super().__init__() self.config = config - self.Wi = nn.Linear(config.hidden_size, - int(config.intermediate_size) * 2, - bias=config.mlp_bias) + self.Wi = nn.Linear( + config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias + ) self.act = nn.GELU() - self.Wo = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=config.mlp_bias) + self.Wo = RowParallelLinear( + config.intermediate_size, config.hidden_size, bias=config.mlp_bias + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) @@ -147,23 +150,21 @@ class ModernBertMLP(nn.Module): class ModernBertLayer(nn.Module): - - def __init__(self, - config: ModernBertConfig, - prefix: str = "", - layer_id: Optional[int] = None): + def __init__( + self, config: ModernBertConfig, prefix: str = "", layer_id: Optional[int] = None + ): super().__init__() self.config = config if layer_id == 0: self.attn_norm = nn.Identity() else: - self.attn_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.attn_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) self.attn = ModernBertAttention(config=config, layer_id=layer_id) - self.mlp_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.mlp_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) self.mlp = ModernBertMLP(config) def forward( @@ -171,8 +172,9 @@ class ModernBertLayer(nn.Module): hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> torch.Tensor: - attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states), - position_ids=position_ids) + attn_outputs = self.attn( + hidden_states=self.attn_norm(hidden_states), position_ids=position_ids + ) hidden_states = hidden_states + attn_outputs mlp_output = self.mlp(self.mlp_norm(hidden_states)) hidden_states = hidden_states + mlp_output @@ -180,14 +182,15 @@ class ModernBertLayer(nn.Module): class ModernBertEncoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.layers = nn.ModuleList([ - ModernBertLayer(config=config, layer_id=layer_id) - for layer_id in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + ModernBertLayer(config=config, layer_id=layer_id) + for layer_id in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -203,7 +206,8 @@ class ModernBertEncoderLayer(nn.Module): @default_pooling_type("CLS") class ModernBertModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={"layers.": "encoder_layer.layers."}) + orig_to_new_prefix={"layers.": "encoder_layer.layers."} + ) def __init__( self, @@ -215,12 +219,14 @@ 
class ModernBertModel(nn.Module): self.config = config self.embeddings = ModernBertEmbeddings(config) self.encoder_layer = ModernBertEncoderLayer(vllm_config) - self.final_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.final_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.get_input_embeddings(input_ids) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -228,8 +234,7 @@ class ModernBertModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -244,8 +249,9 @@ class ModernBertModel(nn.Module): if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - inputs_embeds=inputs_embeds) + hidden_states = self.embeddings( + input_ids=input_ids, inputs_embeds=inputs_embeds + ) outputs = self.encoder_layer( hidden_states=hidden_states, @@ -256,18 +262,18 @@ class ModernBertModel(nn.Module): class ModernBertPooler(Pooler): - def __init__(self, config: ModernBertConfig): super().__init__() pooling_type = PoolingType[config.classifier_pooling.upper()] self.pooling = PoolingMethod.from_pooling_type(pooling_type) - self.dense = nn.Linear(config.hidden_size, config.hidden_size, - config.classifier_bias) + self.dense = nn.Linear( + config.hidden_size, config.hidden_size, config.classifier_bias + ) self.act = nn.GELU() - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) def get_supported_tasks(self) -> Set[PoolingTask]: return self.pooling.get_supported_tasks() @@ -296,48 +302,55 @@ class ModernBertPooler(Pooler): @default_pooling_type("CLS") class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): - is_pooling_model = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - self.model = ModernBertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "modernbert")) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.model = ModernBertModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") + ) + self.classifier = nn.Linear( + config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype, + ) self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - 
vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=self.pooling, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.pooling, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - self_weights = [] def weight_filter(): for name, weight in weights: if name.startswith("model."): - yield name[len("model."):], weight + yield name[len("model.") :], weight else: self_weights.append((name, weight)) @@ -348,13 +361,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): for name, loaded_weight in self_weights: if name.startswith("classifier"): param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) if name.startswith("head"): - param = params_dict["pooling." + name[len("head") + 1:]] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + param = params_dict["pooling." + name[len("head") + 1 :]] + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) def forward( @@ -369,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): inputs_embeds=inputs_embeds, positions=positions, ) + + +class ModernBertPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.dense = nn.Linear( + config.hidden_size, config.hidden_size, bias=config.classifier_bias + ) + self.act = ACT2FN[config.classifier_activation] + self.norm = nn.LayerNorm( + config.hidden_size, + eps=getattr(config, "norm_eps", 1e-5), + bias=getattr(config, "norm_bias", True), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(self.act(self.dense(hidden_states))) + + +@default_pooling_type("ALL") +class ModernBertForTokenClassification(nn.Module): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.head_dtype = vllm_config.model_config.head_dtype + self.num_labels = config.num_labels + self.model = ModernBertModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") + ) + self.head = ModernBertPredictionHead(config) + self.classifier = nn.Linear( + config.hidden_size, config.num_labels, dtype=self.head_dtype + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + } + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self, skip_prefixes=["drop"]) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = 
None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) + hidden_states = self.head(hidden_states) + hidden_states = hidden_states.to(self.head_dtype) + return self.classifier(hidden_states) diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 11a2a384c165e..666796d835a36 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -54,19 +54,22 @@ class MultiModelKeys(ModelKeys): generator: list[str] = field(default_factory=list) @staticmethod - def from_string_field(language_model: Union[str, list[str]] = None, - connector: Union[str, list[str]] = None, - tower_model: Union[str, list[str]] = None, - generator: Union[str, list[str]] = None, - **kwargs) -> 'MultiModelKeys': - + def from_string_field( + language_model: Union[str, list[str]] = None, + connector: Union[str, list[str]] = None, + tower_model: Union[str, list[str]] = None, + generator: Union[str, list[str]] = None, + **kwargs, + ) -> "MultiModelKeys": def to_list(value): if value is None: return [] return [value] if isinstance(value, str) else list(value) - return MultiModelKeys(language_model=to_list(language_model), - connector=to_list(connector), - tower_model=to_list(tower_model), - generator=to_list(generator), - **kwargs) + return MultiModelKeys( + language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs, + ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 5fc28ed0e493e..f1dd06f3a0650 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -5,6 +5,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass from functools import cached_property, partial +from itertools import islice from typing import Annotated, Optional, Union import numpy as np @@ -12,8 +13,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, - TensorType) +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -21,44 +21,66 @@ from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, - SiluAndMul) +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - 
MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] @@ -75,20 +97,28 @@ class MolmoImageInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - - nc: Number of crops + - nc: Number of crops (dynamic) - np: Number of patches + - tp: Token sequence positions - pd: Patch dimension """ - images: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np", "pd")] - image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], - TensorShape("bn", "nc", "np")] + images: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}), + ] + # Number of crops may vary per batch and image, so pass it as a list. - feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np")] - # A boolean mask indicating which image features correspond to patch tokens. 
+ image_masks: Annotated[ + Optional[Union[torch.Tensor, list[torch.Tensor]]], + TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), + ] + image_input_idx: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), + ] + # An index tensor that maps image features to their corresponding patch tokens. num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -107,8 +137,7 @@ class VisionBackboneConfig: image_norm_eps: float = 1e-5 def __post_init__(self): - self.image_default_input_size = tuple( - self.image_default_input_size) # type: ignore[assignment] + self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] @property def image_num_patch(self): @@ -204,15 +233,13 @@ class MultiHeadDotProductAttention(nn.Module): ) self.scale = self.head_dim**-0.5 - self.attn = MultiHeadAttention(self.num_heads, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads) - - def forward(self, - inputs_q: torch.Tensor, - inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + self.attn = MultiHeadAttention( + self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads + ) + def forward( + self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None + ) -> torch.Tensor: if inputs_kv is not None: inputs_k = inputs_kv inputs_v = inputs_kv @@ -239,8 +266,7 @@ class ResidualAttentionBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.attention = MultiHeadDotProductAttention( - config, quant_config=quant_config) + self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config) self.feed_forward = ViTMLP(config, quant_config) self.attention_norm = nn.LayerNorm( config.image_emb_dim, @@ -266,10 +292,12 @@ class BlockCollection(nn.Module): quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock(config, quant_config) - for _ in range(config.image_num_layers) - ]) + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ] + ) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: hidden_states = [] @@ -294,19 +322,18 @@ class VisionTransformer(nn.Module): super().__init__() scale = config.image_emb_dim**-0.5 self.patch_num = config.image_num_patch - self.class_embedding = nn.Parameter( - torch.randn(config.image_emb_dim) * scale) + self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale) self.num_prefix_tokens: int = NUM_PREFIX_TOKENS self.positional_embedding = nn.Parameter( - torch.randn(config.image_num_pos, config.image_emb_dim) * scale) + torch.randn(config.image_num_pos, config.image_emb_dim) * scale + ) image_patch_size = config.image_patch_size self.patch_embedding = nn.Linear( image_patch_size * image_patch_size * 3, config.image_emb_dim, bias=False, ) - self.pre_ln = nn.LayerNorm(config.image_emb_dim, - eps=config.image_norm_eps) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps) self.transformer = BlockCollection(config, quant_config) def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: @@ -314,8 +341,12 @@ class VisionTransformer(nn.Module): pos_emb = self.positional_embedding[1:] pos_emb = pos_emb.reshape( - (int(math.sqrt(pos_emb.shape[0])), - int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) + ( + int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), + pos_emb.shape[1], + ) + ) 
(patch_num_0, patch_num_1) = patch_num @@ -332,13 +363,12 @@ class VisionTransformer(nn.Module): pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) - x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], - dim=1).to(x.dtype) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) return x - def forward(self, - x: torch.Tensor, - patch_num: Optional[int] = None) -> list[torch.Tensor]: + def forward( + self, x: torch.Tensor, patch_num: Optional[int] = None + ) -> list[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -350,8 +380,8 @@ class VisionTransformer(nn.Module): # class embeddings and positional embeddings x = torch.cat( - [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], - dim=1) + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1 + ) x = self.add_pos_emb(x, patch_num) x = self.pre_ln(x) @@ -379,8 +409,7 @@ class MolmoAttention(nn.Module): assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = config.num_key_value_heads \ - or self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: @@ -408,10 +437,10 @@ class MolmoAttention(nn.Module): self.q_norm: Optional[nn.Module] = None if config.attention_layer_norm: self.tp_rank = get_tensor_model_parallel_rank() - self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, - eps=config.layer_norm_eps) - self.q_norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps + ) + self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) # Rotary embeddings. self.rotary_emb = get_rope( @@ -421,13 +450,15 @@ class MolmoAttention(nn.Module): base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) # Attention output projection. 
self.o_proj = RowParallelLinear( @@ -437,16 +468,16 @@ class MolmoAttention(nn.Module): quant_config=quant_config, ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -469,10 +500,12 @@ class MolmoAttention(nn.Module): class LanguageModelMLP(nn.Module): """Molmo's LLM mlp.""" - def __init__(self, - config: PretrainedConfig, - input_dim: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size // 2 @@ -544,7 +577,6 @@ class ImageProjectorMLP(nn.Module): class MolmoDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -554,20 +586,19 @@ class MolmoDecoderLayer(nn.Module): ) -> None: super().__init__() # Attention block. - self.self_attn = MolmoAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = MolmoAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) # MLP block. self.mlp = LanguageModelMLP(config, quant_config=quant_config) # LayerNorm assert config.layer_norm_type == "rms" - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward( self, @@ -580,21 +611,18 @@ class MolmoDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): - def forward( self, positions: torch.Tensor, @@ -635,16 +663,14 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): (self.image_num_patch[0] + 1) // POOLING_SIZE, (self.image_num_patch[1] + 1) // POOLING_SIZE, ) - self.image_vit = VisionTransformer(vision_config, - quant_config=quant_config) + self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) self.num_prefix_tokens = self.image_vit.num_prefix_tokens - assert self.num_prefix_tokens in { - 0, 1 - }, "Only 0 or 1 prefix tokens are supported" + assert self.num_prefix_tokens in {0, 1}, ( + "Only 0 or 1 prefix tokens are supported" + ) self.image_pooling_2d = 
MultiHeadDotProductAttention( - vision_config, - nlayers=len(self.vit_layers), - quant_config=quant_config) + vision_config, nlayers=len(self.vit_layers), quant_config=quant_config + ) self.image_projector = ImageProjectorMLP( config, input_dim=vision_config.image_emb_dim, @@ -668,8 +694,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): """ B, T, N, D = images.shape - mask = ~torch.all( - images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) images = images.view(B * T, N, D) image_features = self.image_vit(images) @@ -704,21 +729,22 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): assert image_masks is not None pad_embed = self.pad_embed[:, None, None, None, :] all_pad = image_masks == 0 - partial_pad = torch.logical_and( - image_masks < 1, - torch.logical_not(all_pad)).to(dtype=torch.float32) + partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to( + dtype=torch.float32 + ) all_pad = all_pad.to(dtype=torch.float32) - image_features = image_features + pad_embed[0] * torch.unsqueeze( - all_pad, -1) + image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1) image_features = image_features + pad_embed[1] * torch.unsqueeze( - partial_pad, -1) + partial_pad, -1 + ) image_features = image_features.to(og_dtype) image_features = image_features.reshape( - (batch_size, num_image) + self.image_num_patch + (-1, ), ) + (batch_size, num_image) + self.image_num_patch + (-1,), + ) - if (missing_w := self.image_num_patch[0] % POOLING_SIZE): + if missing_w := self.image_num_patch[0] % POOLING_SIZE: # Padding for image pooling (see below) image_features = F.pad( image_features, @@ -728,7 +754,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): # image pooling image_features = rearrange( image_features, - 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + "b n (h dh) (w dw) c -> (b n h w) (dh dw) c", dh=POOLING_SIZE, dw=POOLING_SIZE, ) @@ -744,8 +770,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): # image_features: (batch_size, num_image, num_patch, d_model) return image_features - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("merged_linear", "gate_proj", 0), @@ -755,7 +780,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -774,8 +799,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -783,7 +807,6 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): @support_torch_compile class MolmoModel(nn.Module, SupportsQuant): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -801,26 +824,25 @@ class MolmoModel(nn.Module, SupportsQuant): quant_config=quant_config, ) - decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ - else 
MolmoDecoderLayer + decoder_layer = ( + MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: decoder_layer( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) assert config.layer_norm_type == "rms" self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( @@ -842,25 +864,23 @@ class MolmoModel(nn.Module, SupportsQuant): residual = intermediate_tensors["residual"] # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) if residual is not None: hidden_states, _ = self.norm(hidden_states, residual) else: hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -871,8 +891,7 @@ class MolmoModel(nn.Module, SupportsQuant): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -939,8 +958,12 @@ def get_patches_grid_size( def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: - tilings = [(i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) if i * j <= max_num] + tilings = [ + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num + ] return sorted(tilings, key=lambda x: x[0] * x[1]) @@ -1128,7 +1151,8 @@ class MolmoProcessorWrapper: **kwargs, ) -> BatchFeature: outputs = self.processor.process( # type: ignore - text, images, **kwargs) + text, images, **kwargs + ) if images is None: images = [] @@ -1146,13 +1170,14 @@ class MolmoProcessorWrapper: self.select_tiling( image_width=image.size[0], image_height=image.size[1], - ) for image in images + ) + for image in images ] # For each image: tiling_h * tiling_w + extra num_crops = torch.tensor(tilings).prod(-1) + 1 assert num_crops.sum() == len(feat_is_patch) - outputs["feat_is_patch"] = feat_is_patch + outputs["image_input_idx"] = image_input_idx outputs["num_crops"] = num_crops outputs["img_patch_id"] = self.image_patch_id @@ -1160,7 +1185,6 @@ class MolmoProcessorWrapper: class MolmoProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: processor = self.ctx.get_hf_processor(**kwargs) return 
MolmoProcessorWrapper(processor) @@ -1187,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo): image_token_length_w = processor.image_token_length_w image_token_length_h = processor.image_token_length_h - extra = image_token_length_w * image_token_length_h - joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size) + # Calculate total tokens: 2 for start/end + (w+1)*h for column separators + extra = 2 + (image_token_length_w + 1) * image_token_length_h + joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size) return extra + joint @@ -1209,8 +1234,7 @@ class MolmoProcessingInfo(BaseProcessingInfo): ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -1219,7 +1243,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -1227,21 +1250,24 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): - def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], @@ -1259,7 +1285,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): processor, # type: ignore dict(tokens=tokens), ) - prompt_ids, = processed_data.pop("input_ids").tolist() + (prompt_ids,) = processed_data.pop("input_ids").tolist() return prompt_ids @@ -1273,10 +1299,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): return dict( images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), - image_masks=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), - feat_is_patch=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), + image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1299,8 +1323,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): img_end_id = processor.im_end_id extra_row = [img_patch_id] * image_token_length_w + [img_col_id] - extra_joint = ([img_start_id] + extra_row * image_token_length_h + - [img_end_id]) + extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id] def get_insertion_molmo(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) @@ -1311,10 +1334,12 @@ class 
MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): image_height=image_size.height, ) - joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) + - [img_col_id]) - joint = ([img_start_id] + joint_row * - ((nrows + 1) // pooling_size) + [img_end_id]) + joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id] + joint = ( + [img_start_id] + + joint_row * ((nrows + 1) // pooling_size) + + [img_end_id] + ) return PromptUpdateDetails.select_token_id( extra_joint + joint, @@ -1330,11 +1355,14 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor, - info=MolmoProcessingInfo, - dummy_inputs=MolmoDummyInputsBuilder) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + MolmoMultiModalProcessor, + info=MolmoProcessingInfo, + dummy_inputs=MolmoDummyInputsBuilder, +) +class MolmoForCausalLM( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping @@ -1366,7 +1394,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, packed_modules_mapping = { "qkv_proj": ["qkv_proj"], "gate_up_proj": ["gate_up_proj"], # language model - "merged_linear": ["gate_proj", "up_proj"] # image_projector + "merged_linear": ["gate_proj", "up_proj"], # image_projector } @classmethod @@ -1387,10 +1415,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, self.lora_config = lora_config vision_config = VisionBackboneConfig() - self.vision_backbone = MolmoVisionBackbone(config, vision_config, - quant_config) - self.model = MolmoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) + self.model = MolmoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.img_patch_id = None if self.config.weight_tying: @@ -1400,13 +1428,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, config.embedding_size or config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(config.embedding_size - or config.vocab_size) + self.logits_processor = LogitsProcessor( + config.embedding_size or config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( self, @@ -1414,27 +1445,29 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) - feat_is_patch = kwargs.pop("feat_is_patch", None) + image_input_idx = kwargs.pop("image_input_idx", None) num_crops = kwargs.pop("num_crops", None) if images is None: return None if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_crops. " - f"Got type: {type(num_crops)}") + raise ValueError( + f"Incorrect type of num_crops. Got type: {type(num_crops)}" + ) num_crops = flatten_bn(num_crops, concat=True) img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): - raise ValueError("Incorrect type of img_patch_id. 
" - f"Got type: {type(img_patch_id)}") + raise ValueError( + f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}" + ) self.img_patch_id = img_patch_id.flatten().unique().item() return MolmoImageInputs( images=images, image_masks=image_masks, - feat_is_patch=feat_is_patch, + image_input_idx=image_input_idx, num_crops=num_crops, ) @@ -1444,58 +1477,47 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ) -> list[torch.Tensor]: images = image_input["images"] image_masks = image_input["image_masks"] - feat_is_patch = image_input["feat_is_patch"] + image_input_idx = image_input["image_input_idx"] num_crops = image_input["num_crops"] # Call the vision backbone on the whole batch at once images_flat = flatten_bn(images, concat=True) - image_masks_flat = (None if image_masks is None else flatten_bn( - image_masks, concat=True)) - feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) + image_masks_flat = ( + None if image_masks is None else flatten_bn(image_masks, concat=True) + ) + image_input_idx_flat = flatten_bn(image_input_idx, concat=True) image_features_flat = self.vision_backbone( images=images_flat.unsqueeze(0), - image_masks=(None if image_masks_flat is None else - image_masks_flat.unsqueeze(0)), + image_masks=( + None if image_masks_flat is None else image_masks_flat.unsqueeze(0) + ), ).squeeze(0) # Only the features corresponding to patch tokens are relevant - return [ - feats[f_is_patch] for feats, f_is_patch in zip( - image_features_flat.split(num_crops.tolist()), - feat_is_patch_flat.split(num_crops.tolist()), - ) - ] + # Re-order the features using the image_input_idx tensor + results = [] + num_crops_list = num_crops.tolist() + for feats, img_idx in zip( + image_features_flat.split(num_crops_list), + image_input_idx_flat.split(num_crops_list), + ): + is_valid = img_idx >= 0 + valid_img_idx = img_idx[is_valid] + order = torch.argsort(valid_img_idx) + results.append(feats[is_valid][order]) + return results def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_patch_id is not None - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_patch_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.LongTensor, @@ -1504,33 +1526,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: - if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1547,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, def _get_weights_with_merged_embedding( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: embedding_weights = {} for name, weight in weights: diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index d0fdab13ef0c9..3bf8fce0de0d4 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -42,7 +42,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import math from collections.abc import Sequence from copy import deepcopy from functools import cached_property @@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh from transformers.modeling_utils import PreTrainedModel from transformers.utils import is_flash_attn_2_available +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.utils import maybe_prefix from vllm.transformers_utils.configs.moonvit import MoonViTConfig if is_flash_attn_2_available(): @@ -69,11 +70,15 @@ def multihead_attention( v: torch.Tensor, q_cu_seqlens: Optional[torch.Tensor] = None, k_cu_seqlens: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """Multi-head attention using flash attention 2. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. The first element should be 0 and the last element should be q.shape[0]. 
@@ -86,10 +91,10 @@ def multihead_attention( """ # Unified format legal check assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims" - assert q_cu_seqlens[-1] == q.shape[ - 0], "q_cu_seqlens must sum to q.shape[0]" - assert (k_cu_seqlens[-1] == k.shape[0] == - v.shape[0]), "k_cu_seqlens must sum to k.shape[0]" + assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]" + assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], ( + "k_cu_seqlens must sum to k.shape[0]" + ) assert q.dtype in [ torch.bfloat16, torch.float16, @@ -122,27 +127,29 @@ def sdpa_attention( """SDPA attention. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. + k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens: Optional cumulative sequence lengths of q. + k_cu_seqlens: Optional cumulative sequence lengths of k. """ seq_length = q.shape[0] - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) + attention_mask = torch.zeros( + [1, seq_length, seq_length], device=q.device, dtype=torch.bool + ) for i in range(1, len(q_cu_seqlens)): attention_mask[ ..., - q_cu_seqlens[i - 1]:q_cu_seqlens[i], - q_cu_seqlens[i - 1]:q_cu_seqlens[i], + q_cu_seqlens[i - 1] : q_cu_seqlens[i], + q_cu_seqlens[i - 1] : q_cu_seqlens[i], ] = True q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) return attn_output @@ -161,8 +168,9 @@ def _apply_rope_input_validation(x, freqs_cis): assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype -def apply_rope(xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def apply_rope( + xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """ Args: (The leading dimensions of all inputs should be the same) xq: query, tensor of shape (..., num_heads, head_dim) @@ -178,20 +186,15 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, # ..., num_heads, head_dim/2 xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten( - -2) # ..., num_heads, head_dim - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten( - -2) # ..., num_heads, head_dim + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim return xq_out.type_as(xq), xk_out.type_as(xk) class Learnable2DInterpPosEmb(nn.Module): - - def __init__(self, - height: int, - width: int, - dim: int, - interpolation_mode: str = "bicubic") -> None: + def __init__( + self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic" + ) -> None: super().__init__() self.height = height self.width = width @@ -213,13 +216,16 @@ class Learnable2DInterpPosEmb(nn.Module): self.weight.permute((2, 0, 1)).unsqueeze(0), size=shape, 
mode=self.interpolation_mode, - ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)) + ) + .squeeze(0) + .permute((1, 2, 0)) + .flatten(end_dim=1) + ) out = x + torch.cat(pos_embs) return out class MoonVisionPatchEmbed(nn.Module): - def __init__( self, out_dim: int, @@ -229,23 +235,23 @@ class MoonVisionPatchEmbed(nn.Module): pos_emb_width: int = 14, ): super().__init__() - assert isinstance( - patch_size, - (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}" + assert isinstance(patch_size, (int, Sequence)), ( + f"Invalid patch_size type: {type(patch_size)}" + ) if isinstance(patch_size, int): patch_size = (patch_size, patch_size) - assert (len(patch_size) == 2 - ), f"Expected patch_size to be a tuple of 2, got {patch_size}" + assert len(patch_size) == 2, ( + f"Expected patch_size to be a tuple of 2, got {patch_size}" + ) self.patch_size = patch_size - self.proj = nn.Conv2d(in_dim, - out_dim, - kernel_size=patch_size, - stride=patch_size) + self.proj = nn.Conv2d( + in_dim, out_dim, kernel_size=patch_size, stride=patch_size + ) - self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height, - width=pos_emb_width, - dim=out_dim) + self.pos_emb = Learnable2DInterpPosEmb( + height=pos_emb_height, width=pos_emb_width, dim=out_dim + ) def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: """ @@ -284,12 +290,9 @@ class Rope2DPosEmb(nn.Module): device (str): the device to store the precomputed cis """ - def __init__(self, - dim: int, - max_height: int, - max_width: int, - theta_base=10000, - device="cuda"): + def __init__( + self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda" + ): super().__init__() self.dim = dim assert self.dim % 4 == 0, "dim must be divisible by 4" @@ -314,18 +317,18 @@ class Rope2DPosEmb(nn.Module): flat_pos = torch.arange(0, N).float().to(self.device) x_pos = flat_pos % self.max_width y_pos = flat_pos // self.max_width - dim_range = (torch.arange(0, self.dim, - 4)[:(self.dim // 4)].float().to(self.device) - ) # C/4 - freqs = 1.0 / (self.theta_base**(dim_range / self.dim)) + dim_range = ( + torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device) + ) # C/4 + freqs = 1.0 / (self.theta_base ** (dim_range / self.dim)) x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 # N, C/4, 2 freqs_cis = torch.cat( - [x_cis.unsqueeze(dim=-1), - y_cis.unsqueeze(dim=-1)], dim=-1) + [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 + ) # max_height, max_width, C/2 freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) return freqs_cis @@ -338,12 +341,13 @@ class Rope2DPosEmb(nn.Module): freqs_cis: tensor of shape (sum(t * height * width), dim//2) """ shapes = grid_hws.tolist() - assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width - for h, w in shapes), ( - shapes, - self.max_height, - self.max_width, - ) + assert all( + 1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes + ), ( + shapes, + self.max_height, + self.max_width, + ) freqs_cis = torch.cat( [ self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2) @@ -353,8 +357,9 @@ class Rope2DPosEmb(nn.Module): ) return freqs_cis - def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor, - pos_idx_mask: torch.Tensor) -> torch.Tensor: + def get_freqs_cis_by_idx( + self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor + ) -> 
torch.Tensor: """ Args: pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token. @@ -363,16 +368,20 @@ class Rope2DPosEmb(nn.Module): Return: freqs_cis: tensor of shape (..., dim//2) """ - assert (pos_idx.shape[:-1] == pos_idx_mask.shape - and pos_idx.shape[-1] == 2 and pos_idx.ndim - == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape) + assert ( + pos_idx.shape[:-1] == pos_idx_mask.shape + and pos_idx.shape[-1] == 2 + and pos_idx.ndim == pos_idx_mask.ndim + 1 + ), (pos_idx.shape, pos_idx_mask.shape) assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype - shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2 - freqs_cis = torch.ones(shp, dtype=torch.complex64, - device=self.device) # ..., head_dim/2 - freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[ - ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]] + shp = pos_idx_mask.shape + (self.dim // 2,) # ..., head_dim/2 + freqs_cis = torch.ones( + shp, dtype=torch.complex64, device=self.device + ) # ..., head_dim/2 + freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[ + pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask] + ] return freqs_cis @@ -383,30 +392,40 @@ class MLP2(nn.Module): bias: whether to use bias in linear layer. """ - def __init__(self, dims: list[int], activation, bias=True): + def __init__( + self, + dims: list[int], + activation, + bias: bool = True, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() assert len(dims) == 3 - self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) - self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.use_data_parallel = use_data_parallel + self.fc0 = ReplicatedLinear( + dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0") + ) + self.fc1 = ReplicatedLinear( + dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1") + ) self.activation = activation - for m in [self.fc0, self.fc1]: - nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) - if m.bias is not None: - nn.init.zeros_(m.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc0(x) + x, _ = self.fc0(x) x = self.activation(x) - return self.fc1(x) + x, _ = self.fc1(x) + return x class MoonVitEncoderLayer(nn.Module): - def __init__( self, num_heads: int, hidden_dim: int, mlp_dim: int, + prefix: str = "", + use_data_parallel: bool = False, *, attn_implementation: str = "sdpa", activation=F.gelu, @@ -423,9 +442,19 @@ class MoonVitEncoderLayer(nn.Module): self.norm0 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim) - self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) - self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) - self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + self.use_data_parallel = use_data_parallel + self.mlp = MLP2( + [hidden_dim, mlp_dim, hidden_dim], + activation, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.wqkv = ReplicatedLinear( + hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv" + ) + self.wo = ReplicatedLinear( + hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo" + ) def attention_qkvpacked( self, @@ -438,7 +467,7 @@ class MoonVitEncoderLayer(nn.Module): x (torch.Tensor): (batch_size, seqlen, hidden_dim) cu_seqlens (torch.Tensor): """ - xqkv = self.wqkv(x) + xqkv, _ = self.wqkv(x) qkv_shape = xqkv.size()[:-1] + ( 3, @@ -452,13 +481,10 @@ class MoonVitEncoderLayer(nn.Module): xq, xk = apply_rope(xq, xk, rope_freqs_cis) attn_func = 
VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] - attn_out = attn_func(xq, - xk, - xv, - q_cu_seqlens=cu_seqlens, - k_cu_seqlens=cu_seqlens) - - attn_out = self.wo(attn_out) + attn_out = attn_func( + xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens + ) + attn_out, _ = self.wo(attn_out) return attn_out def forward( @@ -476,9 +502,9 @@ class MoonVitEncoderLayer(nn.Module): """ residual = hidden_states hidden_states = self.norm0(hidden_states) - attn_out = self.attention_qkvpacked(hidden_states, - cu_seqlens, - rope_freqs_cis=rope_freqs_cis) + attn_out = self.attention_qkvpacked( + hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis + ) hidden_states = residual + attn_out residual = hidden_states @@ -488,36 +514,48 @@ class MoonVitEncoderLayer(nn.Module): class MoonVitEncoder(nn.Module): - def __init__( self, hidden_dim: int, num_layers: int, block_cfg: dict, + prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.rope_2d = Rope2DPosEmb( - block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) + block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512 + ) self.blocks = nn.ModuleList( - [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]) + [ + MoonVitEncoderLayer( + use_data_parallel=use_data_parallel, + prefix=f"{prefix}.blocks.{layer_idx}", + **block_cfg, + ) + for layer_idx in range(num_layers) + ] + ) self.final_layernorm = nn.LayerNorm(hidden_dim) - def forward(self, hidden_states: torch.Tensor, - grid_hw: torch.Tensor) -> torch.Tensor: - rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( - grid_hws=grid_hw) + def forward( + self, hidden_states: torch.Tensor, grid_hw: torch.Tensor + ) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw) - lengths = torch.cat(( - torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), - grid_hw[:, 0] * grid_hw[:, 1], - )) + lengths = torch.cat( + ( + torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device), + ) + ) cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) for _, block in enumerate(self.blocks): - hidden_states = block(hidden_states, - cu_seqlens, - rope_freqs_cis=rope_freqs_cis) + hidden_states = block( + hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis + ) hidden_states = self.final_layernorm(hidden_states) @@ -525,9 +563,9 @@ class MoonVitEncoder(nn.Module): def patch_merger( - x: torch.Tensor, - grid_hw: torch.Tensor, - merge_kernel_size: list[int, int] = (2, 2), + x: torch.Tensor, + grid_hw: torch.Tensor, + merge_kernel_size: list[int, int] = (2, 2), ) -> list[torch.Tensor]: d_model = x.size(-1) @@ -536,15 +574,17 @@ def patch_merger( for x_shape in grid_hw.tolist(): height, width = x_shape[0], x_shape[1] # Get the current sequence - seq = x[pre_sum:pre_sum + height * width] + seq = x[pre_sum : pre_sum + height * width] # Reshape along self.merge_kernel_size and concat to the last dimension kernel_height, kernel_width = merge_kernel_size new_height, new_width = height // kernel_height, width // kernel_width - reshaped_seq = seq.view(new_height, kernel_height, new_width, - kernel_width, d_model) + reshaped_seq = seq.view( + new_height, kernel_height, new_width, kernel_width, d_model + ) reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous() - padded_seq = reshaped_seq.view(new_height * new_width, - kernel_height * kernel_width, -1) + padded_seq = reshaped_seq.view( + new_height * new_width, kernel_height * kernel_width, -1 + ) 
outputs.append(padded_seq) pre_sum += height * width @@ -552,7 +592,6 @@ def patch_merger( class MoonVitVLProjector(nn.Module): - def __init__( self, in_channels: int, @@ -562,13 +601,10 @@ class MoonVitVLProjector(nn.Module): out_dim: int = 4096, ): super().__init__() - self.hidden_size = in_channels * merge_kernel_size[ - 0] * merge_kernel_size[1] + self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1] self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps) - self.linear_1 = nn.Linear(self.hidden_size, - self.hidden_size, - bias=True) + self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) self.act = ACT2FN[hidden_act] self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True) @@ -587,11 +623,21 @@ class MoonVitPretrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + def __init__( + self, + config: MoonViTConfig, + use_data_parallel: bool = False, + prefix: str = "", + *inputs, + **kwargs, + ): super().__init__(config, *inputs, **kwargs) config = deepcopy(config) + self.use_data_parallel = use_data_parallel self.merge_kernel_size = config.merge_kernel_size + self.hidden_size = config.hidden_size self.patch_size = config.patch_size + self.vit_processing_type = "rope_2d" self.patch_embed = MoonVisionPatchEmbed( out_dim=config.hidden_size, patch_size=config.patch_size, @@ -610,10 +656,12 @@ class MoonVitPretrainedModel(PreTrainedModel): "attn_bias": True, "attn_implementation": config._attn_implementation, }, + prefix=f"{prefix}.encoder", ) - def forward(self, pixel_values: torch.Tensor, - grid_hw: torch.Tensor) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, grid_hw: torch.Tensor + ) -> torch.Tensor: """ Args: pixel_values (torch.Tensor): The input pixel values. 
@@ -624,7 +672,7 @@ class MoonVitPretrainedModel(PreTrainedModel): """ hidden_states = self.patch_embed(pixel_values, grid_hw) hidden_states = self.encoder(hidden_states, grid_hw) - hidden_states = patch_merger(hidden_states, - grid_hw, - merge_kernel_size=self.merge_kernel_size) + hidden_states = patch_merger( + hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size + ) return hidden_states diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8db52a69924c9..3f1f2bbcb0267 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -4,6 +4,7 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -13,31 +14,38 @@ from transformers import MptConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes( total_num_heads: int, alibi_bias_max: int, ) -> torch.Tensor: - next_power_of_2 = 2**math.ceil(math.log2(total_num_heads)) + next_power_of_2 = 2 ** math.ceil(math.log2(total_num_heads)) m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) m = m.mul(alibi_bias_max / next_power_of_2) slopes = 1.0 / torch.pow(2, m) @@ -47,7 +55,6 @@ def _get_alibi_slopes( class MPTAttention(nn.Module): - def __init__( self, config: MptConfig, @@ -107,20 +114,21 @@ class MPTAttention(nn.Module): tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(self.total_num_heads, - self.alibi_bias_max) + alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max) alibi_slopes = alibi_slopes[head_start:head_end].tolist() self.head_dim = self.d_model // self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = 
Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -141,7 +149,6 @@ class MPTAttention(nn.Module): class MPTMLP(nn.Module): - def __init__( self, config: MptConfig, @@ -173,7 +180,6 @@ class MPTMLP(nn.Module): class MPTBlock(nn.Module): - def __init__( self, config: MptConfig, @@ -184,10 +190,9 @@ class MPTBlock(nn.Module): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = MPTAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -210,7 +215,6 @@ class MPTBlock(nn.Module): @support_torch_compile class MPTModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -227,19 +231,18 @@ class MPTModel(nn.Module): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: MPTBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.blocks") + lambda prefix: MPTBlock(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.blocks", + ) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): - if hasattr(module, "bias") and isinstance( - module.bias, nn.Parameter): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. module.register_parameter("bias", None) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.d_model)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.d_model + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -260,15 +263,14 @@ class MPTModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for block in self.blocks[self.start_layer:self.end_layer]: + for block in islice(self.blocks, self.start_layer, self.end_layer): hidden_states = block(position_ids, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -278,15 +280,13 @@ class MPTModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class MPTForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -295,12 +295,14 @@ class MPTForCausalLM(nn.Module, SupportsPP): assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = 
MPTModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "transformer")) + self.transformer = MPTModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -312,20 +314,18 @@ class MPTForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py new file mode 100644 index 0000000000000..91dfa67355341 --- /dev/null +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -0,0 +1,1609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# -------------------------------------------------------- +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py +# under Apache-2.0 License +# LICENSE is in root directory. 
+# -------------------------------------------------------- + +import copy +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union + +import numpy.typing as npt +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import BatchFeature, PretrainedConfig, TensorType + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + MultiModalEmbeddings, + SupportsMultiModal, + SupportsMultiModalPruning, +) +from vllm.model_executor.models.internvl import ( + calculate_internvl_targets, + get_internvl_target_ratios, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM +from vllm.model_executor.models.radio import RadioModel +from vllm.model_executor.models.utils import ( + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.evs import ( + compute_retained_tokens_count, + compute_retention_mask, +) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + _seq2tokens, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.radio import RadioConfig +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + cached_tokenizer_from_config, + encode_tokens, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .utils import _merge_multimodal_embeddings + +# Configure PIL to handle large images without warnings +# This prevents DecompressionBombWarning for legitimate large images +Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely +# Alternative: Set a specific higher limit +# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels + +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<image>" + +# Profiling +MAX_FRAMES = 16 +DEFAULT_NUM_TILES = 12 + + +class NanoNemotronVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class NanoNemotronVLImageEmbeddinInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +NanoNemotronVLImageInputs = Union[ + NanoNemotronVLImagePixelInputs, NanoNemotronVLImageEmbeddinInputs +] + + +class NanoNemotronVLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bvf: Batch size * number of videos * num_frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each video frame + - w: Width of each video frame + """ + + type: Literal["pixel_values_videos"] + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Total video feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ + + type: Literal["video_embeds"] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] + + +NanoNemotronVLVideoInputs = Union[ + NanoNemotronVLVideoPixelInputs, NanoNemotronVLVideoEmbeddingInputs +] + + +def dynamic_preprocess( + image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0 +): + orig_width, orig_height = image.size + + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + processed_images = [ + img.convert("RGB") if img.mode != "RGB" else img for img in processed_images + ] + processed_images = [ + T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)( + img + ) + for img in processed_images + ] + processed_images = [T.ToTensor()(img) for img in processed_images] + return processed_images + + +def image_to_pixel_values( + image: Image.Image, + *, + input_size: int, + max_num: int, + use_thumbnail: bool, + idx: int, +) -> torch.Tensor: + images = dynamic_preprocess( + image, + image_size=input_size, + max_num_tiles=max_num, + use_thumbnail=use_thumbnail, + idx=idx, + ) + + pixel_values = torch.stack(images) + return pixel_values + + +def video_to_pixel_values( + video: npt.NDArray, + *, + input_size: int, + max_num_tiles: int = 1, + use_thumbnail: bool, +) -> torch.Tensor: + assert max_num_tiles == 1, "Video modality always uses one tile" + + # Convert each frame to a single resized tile tensor consistent + # with image path + frames_tensors: list[torch.Tensor] = [] + for frame in video: + pil_frame = dynamic_preprocess( + Image.fromarray(frame, mode="RGB"), + image_size=input_size, + max_num_tiles=max_num_tiles, + use_thumbnail=use_thumbnail, + idx=0, + ) + # dynamic_preprocess returns tensors already; take the single tile + assert len(pil_frame) >= 1 + frames_tensors.append(pil_frame[-1]) + + return torch.stack(frames_tensors) + + +class BaseNanoNemotronVLProcessor(ABC): + """ + This model 
doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *args, + max_num_tiles: Optional[int] = None, + **kwargs, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES + image_size: int = config.force_image_size + patch_size: int = config.patch_size + + self.num_image_token = int( + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) + self.image_size = image_size + self.use_thumbnail: bool = config.use_thumbnail + self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) + self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1) + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + ) -> int: + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + num_patches, _, _ = calculate_internvl_targets( + orig_width=image_width, + orig_height=image_height, + target_ratios=target_ratios, + image_size=self.image_size, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + max_num_tiles: int, + ) -> list[torch.Tensor]: + return [ + image_to_pixel_values( + image, + input_size=self.image_size, + max_num=max_num_tiles, + use_thumbnail=self.use_thumbnail, + idx=idx, + ) + for idx, image in enumerate(images) + ] + + def _preprocess_image( + self, + text: list[str], + images: list[Image.Image], + max_num_tiles: int, + ) -> tuple[list[str], dict[str, torch.Tensor]]: + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) + image_inputs = { + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), + } + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + image_repl = self.get_image_repl(feature_size, num_patches) + text = [t.replace("<image>", image_repl.full, 1) for t in text] + return text, image_inputs + + def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): + if input_item is None: + input_item = [] + if not isinstance(input_item, list): + input_item = [input_item] + return input_item + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + ) -> BatchFeature: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = self.max_num_tiles + + text, images = [self._make_batch_input(x) for x in (text, images)] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + combined_outputs = {**text_inputs, **image_inputs} + + return 
BatchFeature(combined_outputs, tensor_type=return_tensors) + + +class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): + """ + HF Processor with extended video processing logic. + Code for video processing is adapted from video example: + https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + max_num_tiles: Optional[int] = None, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + video_token: Optional[str] = None, + video_pruning_rate: Optional[float] = None, + ) -> None: + super().__init__( + config=config, + tokenizer=tokenizer, + max_num_tiles=max_num_tiles, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + # add extra video token for video processing + self.video_token = video_token + self.video_pruning_rate = video_pruning_rate + + @property + def supports_video(self) -> bool: + return self.video_token_id is not None + + @property + def video_token_id(self) -> Optional[int]: + if self.video_token is None: + return None + return self.tokenizer.get_vocab().get(self.video_token, None) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT) + + def _videos_to_pixel_values_lst( + self, + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + return [ + video_to_pixel_values( + video, + input_size=self.image_size, + max_num_tiles=max_num_tiles, + use_thumbnail=self.use_thumbnail, + ) + for video in videos + ] + + def _preprocess_video( + self, + text: list[str], + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ): + if len(videos) == 0 or not self.supports_video: + video_inputs = {} + else: + pixel_values_lst_video = self._videos_to_pixel_values_lst( + videos, + max_num_tiles=max_num_tiles, + dynamic_image_size=dynamic_image_size, + ) + + video_inputs = { + "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "video_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst_video] + ), + } + + image_size: int = self.config.force_image_size + patch_size: int = self.config.patch_size + downsample_ratio = self.config.downsample_ratio + tokens_in_single_frame = int( + (image_size * image_size // patch_size**2) * (downsample_ratio**2) + ) + + for pixel_values in pixel_values_lst_video: + num_frames = pixel_values.shape[0] + + if ( + self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ): + # Start of EVS-specific code + num_tokens = compute_retained_tokens_count( + tokens_per_frame=tokens_in_single_frame, + num_frames=num_frames, + q=self.video_pruning_rate, + ) + + # Here we just need placeholders that won't actually be replaced - + # we just need to make sure the total number of tokens is correct + # assign all tokens to the first frame + tokens_per_frame = [num_tokens] + [0] * (num_frames - 1) + + # End of EVS-specific code + else: + tokens_per_frame = [tokens_in_single_frame] * num_frames + + video_repl = self.get_video_repl(tokens_per_frame, self.video_token) + + text = [t.replace("<video>", video_repl.full, 1) for t in text] + return text, video_inputs + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + videos: 
Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> BatchFeature: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = self.max_num_tiles + + text, images, videos = [ + self._make_batch_input(x) for x in (text, images, videos) + ] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text, video_inputs = self._preprocess_video( + text=text, + videos=videos, + max_num_tiles=1, + dynamic_image_size=dynamic_image_size, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + combined_outputs = {**text_inputs, **image_inputs, **video_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + + @classmethod + def get_video_repl( + cls, + tokens_per_frame: list[int], + video_context_token: str = IMG_CONTEXT, + ) -> PromptUpdateDetails[str]: + """ + Build prompt replacement for a video. + The replacement returned is not actually used to replace the placeholder + tokens - it's just used to make sure we allocate the correct number + of tokens. + Actual replacement is done in get_multimodal_embeddings of + NemotronH_Nano_VL_V2 + (specifically in _process_video_input -> _create_final_video_embeddings). + There, we create the final embeddings with text embeddings for indicator tokens + and video embeddings for video tokens. + This is a single function that handles all cases - non EVS, EVS dummy, EVS real. + The differentiation is done via tokens_per_frame parameter. + - non EVS case - constant value same value across all frames + - EVS dummy - Doesn't matter how tokens are distributed between frames - just + make sure the total number of tokens is correct. 
+ - EVS real (called from get_real_video_repl_for_evs) - different value per frame + Args: + tokens_per_frame (list[int]): number of tokens per frame + video_context_token (str): the token to use for the video context + """ + repl_full = "".join( + [ + f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}" + for i, num_tokens in enumerate(tokens_per_frame) + ] + ) + + return PromptUpdateDetails.from_seq(repl_full) + + +class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): + """Basic image-only ProcessingInfo for InternVL-style models.""" + + @abstractmethod + def get_hf_processor( + self, + **kwargs: object, + ) -> BaseNanoNemotronVLProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + processor: Optional[BaseNanoNemotronVLProcessor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + max_num_tiles=max_num_tiles, + ) + + def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize: + processor = self.get_hf_processor() + + base_size = processor.image_size + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in target_ratios: + width, height = base_size * wr, base_size * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + def get_max_image_tokens(self) -> int: + processor = self.get_hf_processor() + # Use default max_num_tiles for max tokens calculation + max_num_tiles = processor.max_num_tiles + target_width, target_height = self.get_image_size_with_most_features( + max_num_tiles + ) + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + + +_I = TypeVar("_I", bound=BaseNanoNemotronVLProcessingInfo) + + +class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): + """ProcessingInfo extended for video processing""" + + @property + def supports_video(self): + return self.get_hf_processor().supports_video + + def get_supported_mm_limits(self): + video_limit = {"video": None} if self.supports_video else {} + return {**super().get_supported_mm_limits(), **video_limit} + + def get_video_token(self) -> Optional[str]: + return IMG_CONTEXT + + def get_video_pruning_rate(self) -> Optional[float]: + return self.ctx.get_mm_config().video_pruning_rate + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + processor = self.get_hf_processor() # we get the CustomProcessor here + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token + max_frames_per_video = max_total_frames // max(max_videos, 1) + + max_frames_per_video = 
min(max_frames_per_video, MAX_FRAMES) + return max(max_frames_per_video, 1) + + def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: + return self.ctx.init_processor( + NanoNemotronVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + video_token=self.get_video_token(), + video_pruning_rate=self.get_video_pruning_rate(), + **kwargs, + ) + + +class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): + """Basic image-only MultiModalProcessor for InternVL-style models.""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_token_id = hf_processor.image_token_id + + # Since there may be extra tokens in the feature placeholders, + # we need to pass the image token ID to the model to select the + # tokens to merge from the vision encoder outputs + processed_outputs["image_token_id"] = torch.tensor(image_token_id) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) + num_images = len(image_num_patches) + + return dict( + pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_patches + ), + image_num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + def get_replacement_custom(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + # Extract max_num_tiles from kwargs, default to 12 + max_num_tiles = hf_processor_mm_kwargs.get( + "max_num_tiles", hf_processor.max_num_tiles + ) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + max_num_tiles=max_num_tiles, + processor=hf_processor, + ) + + num_patches = None + local_image_num_patches = image_num_patches + if isinstance(local_image_num_patches, torch.Tensor): + local_image_num_patches = local_image_num_patches.tolist() + if isinstance(local_image_num_patches, (list, tuple)) and item_idx < len( + local_image_num_patches + ): + num_patches = int(local_image_num_patches[item_idx]) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="<image>", 
+ replacement=get_replacement_custom, + ) + ] + + +class NanoNemotronVLMultiModalProcessor( + NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo] +): + """MultiModalProcessor extended for video support""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + if ( + self.info.supports_video + and (video_token_id := hf_processor.video_token_id) is not None + ): + processed_outputs["video_token_id"] = torch.tensor(video_token_id) + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) + if self.info.supports_video: + video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) + num_videos = len(video_num_patches) + video_fields = dict( + pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches + ), + video_num_patches=MultiModalFieldConfig.batched("video"), + video_token_id=MultiModalFieldConfig.shared("video", num_videos), + ) + else: + video_fields = {} + + return image_fields | video_fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] + assert isinstance(video_num_patches, torch.Tensor) + video_num_patches = video_num_patches.tolist() + else: + video_num_patches = [] + + def get_video_replacement_internvl(item_idx: int): + feature_size = hf_processor.num_image_token + num_patches = video_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + # Start of EVS-specific code + num_tokens = compute_retained_tokens_count( + tokens_per_frame=feature_size, + num_frames=num_patches, + q=video_pruning_rate, + ) + # Here we just need placeholders that won't actually be replaced - + # we just need to make sure the total number of tokens is correct + # assign all tokens to the first frame + tokens_per_frame = [num_tokens] + [0] * (num_patches - 1) + + # End of EVS-specific code + else: + tokens_per_frame = [feature_size] * num_patches + + return hf_processor.get_video_repl( + tokens_per_frame, + video_context_token=hf_processor.video_token, + ) + + if self.info.supports_video: + prompt_repl = [ + *prompt_repl, + PromptReplacement( + modality="video", + target="<video>", + replacement=get_video_replacement_internvl, + ), + ] + + return prompt_repl + + +class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + """Basic image-only DummyInputsBuilder for InternVL-style models.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) 
+ + return "<image>" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + # Use default max_num_tiles for dummy data generation + max_num_tiles = 12 + target_width, target_height = self.info.get_image_size_with_most_features( + max_num_tiles + ) + num_images = mm_counts.get("image", 0) + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class NanoNemotronVLDummyInputsBuilder( + NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo] +): + """DummyInputsBuilder extended for video support""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_videos = mm_counts.get("video", 0) + + return super().get_dummy_text(mm_counts) + "<video>" * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + dummy_image = super().get_dummy_mm_data( + seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options + ) + if self.info.supports_video: + config = self.info.get_hf_config() + image_size: int = config.force_image_size + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + num_videos = mm_counts.get("video", 0) + video_overrides = mm_options.get("video") if mm_options else None + dummy_video = { + "video": self._get_dummy_videos( + width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ) + } + else: + dummy_video = {} + return {**dummy_image, **dummy_video} + + +@MULTIMODAL_REGISTRY.register_processor( + NanoNemotronVLMultiModalProcessor, + info=NanoNemotronVLProcessingInfo, + dummy_inputs=NanoNemotronVLDummyInputsBuilder, +) +class NemotronH_Nano_VL_V2( + nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning +): + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<image>" + if modality.startswith("video"): + return "<video>" + return None + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + image_size = config.force_image_size + patch_size = config.patch_size + self.patch_size = patch_size + self.template = config.template + self.num_image_token = int( + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + self.image_tag_type = config.image_tag_type + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.vision_model = self.get_vit_model_from_radio_config(config).to( + self.language_model.config.torch_dtype + ) + + # Construct the vision projection. 
+ vit_hidden_size = config.vit_hidden_size + vision_projection_hidden_size = config.projector_hidden_size + llm_hidden_size = config.text_config.hidden_size + + self.mlp1 = nn.Sequential( + RMSNorm( + hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + eps=1e-5, + ), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + vision_projection_hidden_size, + bias=False, + ), + ReLUSquaredActivation(), + nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), + ) + self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) + + self.img_context_token_id = None + self.video_context_token_id = None + self.config = config + self.model_config = vllm_config.model_config + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view( + n, + w, + int(h * scale_factor), + int(c / scale_factor), + ) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": + warnings.warn( + "In ps_version 'v1', the height and width have not " + "been swapped back, which results in a transposed image.", + stacklevel=2, + ) + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + vit_embeds = self.vision_model(pixel_values) + vit_embeds = vit_embeds.to(dtype=torch.bfloat16) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[NanoNemotronVLImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) + + return NanoNemotronVLImageEmbeddinInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}" + ) + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image_num_patches. 
" + f"Got type: {type(image_num_patches)}" + ) + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + + return NanoNemotronVLImagePixelInputs( + type="pixel_values", + pixel_values_flat=pixel_values_flat, + num_patches=image_num_patches, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: NanoNemotronVLImageInputs + ) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return (image_embeds.view(-1, self.config.text_config.hidden_size),) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _process_video_input( + self, video_input: NanoNemotronVLVideoPixelInputs + ) -> tuple[torch.Tensor, ...]: + """Process video input and create final embeddings with video content + and indicator tokens.""" + # Get video embeddings using the same processing as images + video_embeddings = self._process_image_input(video_input) + + final_video_embeddings: tuple[torch.Tensor, ...] = () + + image_rows = image_cols = self.config.force_image_size + downsample_ratio = self.config.downsample_ratio + patch_size = self.config.patch_size + rows = int(image_rows * downsample_ratio // patch_size) + cols = int(image_cols * downsample_ratio // patch_size) + video_pruning_rate = self.video_pruning_rate + + # Calculate video feature dimensions (number of frames and + # their feature size (AKA tokens per frame)) + # TODO: Maybe this can be optimized to avoid the loop? + for i, single_video_embeddings in enumerate(video_embeddings): + num_frames = video_input["num_patches"][i].item() + assert single_video_embeddings.shape[0] % num_frames == 0 + + if video_pruning_rate is not None and video_pruning_rate > 0.0: + # Start of EVS-specific code + retention_mask = compute_retention_mask( + single_video_embeddings, + video_size_thw=(num_frames, rows, cols), + spatial_merge_size=1, + q=video_pruning_rate, + ) + + # apply retention mask + single_video_embeddings = single_video_embeddings[retention_mask] + + # calculate the actual number of retained tokens per frame + retention_mask_thw = retention_mask.reshape(num_frames, rows, cols) + num_tokens_per_frame = ( + retention_mask_thw.sum(dim=(1, 2)).long().tolist() + ) + # End of EVS-specific code + else: + feature_size = single_video_embeddings.shape[0] // num_frames + num_tokens_per_frame = [feature_size] * num_frames + + final_video_embeddings += ( + self._create_final_video_embeddings( + single_video_embeddings, + num_tokens_per_frame, + ), + ) + + return final_video_embeddings + + def _create_final_video_embeddings( + self, + video_embeddings: torch.Tensor, + num_tokens_per_frame: list[int], + ) -> torch.Tensor: + """Create final embeddings that combine video embeddings with + text embeddings of indicator tokens. 
+ + These final embeddings contain: + - Actual video embeddings in positions corresponding to video content + - Text embeddings for indicator tokens (<img>, </img>, and + frame separation text) in their respective positions + + These embeddings will replace the placeholder embeddings to create + input_embeds for the LLM. + """ + device = video_embeddings.device + + # Generate video replacement text and convert to token IDs + video_repl_text = NanoNemotronVLProcessor.get_video_repl( + num_tokens_per_frame, + IMG_CONTEXT, + ).full + + tokenizer = cached_tokenizer_from_config(self.model_config) + repl_token_ids = torch.tensor( + _seq2tokens(tokenizer, video_repl_text), device=device + ) + + # Get embedding token IDs for image context + embed_token_ids = torch.tensor( + encode_tokens(tokenizer, IMG_CONTEXT), device=device + ) + + # Create mask for video embedding positions + is_video_embed = torch.isin(repl_token_ids, embed_token_ids) + + # Create final video embeddings, merging text embeddings for indicator + # tokens with video embeddings + text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids) + final_video_embeddings = _merge_multimodal_embeddings( + inputs_embeds=text_embeddings, + multimodal_embeddings=video_embeddings, + is_multimodal=is_video_embed, + ) + + return final_video_embeddings + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[NanoNemotronVLVideoPixelInputs]: + pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) + video_num_patches = kwargs.pop("video_num_patches", None) + video_embeds = kwargs.pop("video_embeds", None) + + if pixel_values_flat_video is None and video_embeds is None: + return None + + if video_embeds is not None: + return NanoNemotronVLVideoEmbeddingInputs( + type="video_embeds", + data=flatten_bn(video_embeds), + ) + + video_token_id = kwargs["video_token_id"] + assert isinstance(video_token_id, torch.Tensor) + self.video_context_token_id = video_token_id.flatten().unique().item() + + if pixel_values_flat_video is not None: + if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat_video)}" + ) + + if not isinstance(video_num_patches, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image_num_patches. " + f"Got type: {type(video_num_patches)}" + ) + + pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True) + video_num_patches = flatten_bn(video_num_patches, concat=True) + expected_h = expected_w = self.config.force_image_size + resolve_bindings = {"h": expected_h, "w": expected_w} + + return NanoNemotronVLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_flat=pixel_values_flat_video, + num_patches=video_num_patches, + resolve_bindings=resolve_bindings, + ) + + raise AssertionError("This line should be unreachable.") + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_flat_video",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + # Validate the multimodal input keyword arguments + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities is None: + return [] + + # # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp1", + tower_model="vision_model", + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + adapter_dict = dict(self.mlp1.named_parameters()) + + def is_llm(name: str) -> bool: + return name.startswith("language_model") + + def is_adapter_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("mlp1") + + def is_vision_weights(name: str) -> bool: + return name.startswith("vision_model.radio_model.") + + # Separate weights by component + llm_weights = [] + vision_weights = [] + + for name, w in weights: + if is_llm(name): + # Strip 'language_model.' prefix for LLM weights + llm_weights.append((".".join(name.split(".")[1:]), w)) + elif is_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_vision_weights(name): + # Convert: vision_model.radio_model.* → radio_model.* + hf_key = name[len("vision_model.") :] # Remove "vision_model." 
prefix + vision_weights.append((hf_key, w)) + + self.language_model.load_weights(llm_weights) + self.vision_model.load_weights(vision_weights) + + def print_architecture(self, detailed: bool = True, save_to_file: str = None): + """ + Print model architecture with parameter names, shapes, and sizes. + + Args: + detailed: If True, show detailed parameter breakdown + save_to_file: If provided, save output to this file path + """ + import sys + from io import StringIO + + # Capture output if saving to file + original_stdout = sys.stdout + if save_to_file: + sys.stdout = StringIO() + + try: + print("=" * 100) + print("NemotronH_Nano_VL_V2 Model Architecture") + print("=" * 100) + + total_params = 0 + param_groups = { + "language_model": [], + "vision_model": [], + "mlp1": [], + "other": [], + } + + for name, param in self.named_parameters(): + param_size = param.numel() + total_params += param_size + + # Group parameters by main component + if name.startswith("language_model"): + param_groups["language_model"].append( + (name, param.shape, param_size, param.dtype) + ) + elif name.startswith("vision_model"): + param_groups["vision_model"].append( + (name, param.shape, param_size, param.dtype) + ) + elif name.startswith("mlp1"): + param_groups["mlp1"].append( + (name, param.shape, param_size, param.dtype) + ) + else: + param_groups["other"].append( + (name, param.shape, param_size, param.dtype) + ) + + if detailed: + print( + f"{name:<70} | Shape: {str(param.shape):<25} | " + f"Size: {param_size:>12,} | Dtype: {param.dtype}" + ) + + print("=" * 100) + print("Summary by Component:") + print("-" * 60) + + for component, params in param_groups.items(): + if params: # Only show components that have parameters + component_total = sum(size for _, _, size, _ in params) + percentage = ( + (component_total / total_params) * 100 + if total_params > 0 + else 0 + ) + print( + f"{component:<20} | Parameters: {len(params):>4} | " + f"Total Size: {component_total:>15,} | " + f"{percentage:>6.2f}%" + ) + + print("-" * 60) + print(f"{'Total Parameters':<20} | {total_params:>15,}") + + # Estimate memory usage (assuming bfloat16 = 2 bytes per parameter) + memory_mb = total_params * 2 / (1024**2) + memory_gb = memory_mb / 1024 + print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}") + print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}") + print("=" * 100) + + # Save to file if requested + if save_to_file: + output = sys.stdout.getvalue() + sys.stdout = original_stdout + with open(save_to_file, "w") as f: + f.write(output) + print(f"Architecture saved to: {save_to_file}") + print(output) # Also print to console + + finally: + if save_to_file and sys.stdout != original_stdout: + sys.stdout = original_stdout + + def get_model_info(self): + """ + Get basic model information as a dictionary. 
+ """ + total_params = sum(p.numel() for p in self.parameters()) + + component_info = {} + for name, param in self.named_parameters(): + component = name.split(".")[0] + if component not in component_info: + component_info[component] = {"params": 0, "size": 0} + component_info[component]["params"] += 1 + component_info[component]["size"] += param.numel() + + return { + "model_name": "NemotronH_Nano_VL_V2", + "total_parameters": total_params, + "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16 + "components": component_info, + "config": { + "image_size": getattr(self.config, "force_image_size", None), + "patch_size": getattr(self.config, "patch_size", None), + "num_image_token": self.num_image_token, + "downsample_ratio": self.downsample_ratio, + }, + } + + def get_vit_model_from_radio_config(self, hf_config): + hf_config_vision = hf_config.vision_config + model_name = hf_config_vision.args.get("model") + if model_name is None: + raise ValueError(f"Unsupported vit model type: {model_name}") + + preferred_resolution = getattr(hf_config_vision, "preferred_resolution", None) + image_size = preferred_resolution[0] if preferred_resolution else 224 + patch_size = getattr(hf_config_vision, "patch_size", 16) + + radio_config = RadioConfig( + model_name=model_name, + image_size=image_size, + patch_size=patch_size, + norm_mean=hf_config.norm_mean, + norm_std=hf_config.norm_std, + reg_tokens=( + hf_config_vision.args.get("register_multiple") + if hasattr(hf_config_vision, "args") + and isinstance(hf_config_vision.args, dict) + else None + ), + ) + + return RadioModel(config=radio_config) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs + ) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.language_model.mamba_cache.get_seqlen_agnostic_capture_inputs( + batch_size + ) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_shape_from_config(temp_vllm_config) + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index eabf47b1aede4..8f07a2cf12f7a 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Nemotron model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -34,24 +36,35 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -65,20 +78,21 @@ def _cast_if_autocast_enabled(*args): return args else: return torch.amp.autocast_mode._cast( - args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype()) + args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype() + ) class NemotronLayerNorm1P(nn.LayerNorm): - - def __init__(self, - normalized_shape: Union[int, list[int], torch.Size], - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None): - super().__init__(normalized_shape, eps, elementwise_affine, bias, - device, dtype) + def __init__( + self, + normalized_shape: Union[int, list[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) def forward( self, @@ -88,15 +102,15 @@ class NemotronLayerNorm1P(nn.LayerNorm): if residual is not None: x = x + residual residual = x - args = _cast_if_autocast_enabled(x, self.normalized_shape, - self.weight + 1, self.bias, self.eps) + args = _cast_if_autocast_enabled( + x, self.normalized_shape, self.weight + 1, self.bias, self.eps + ) with torch.amp.autocast("cuda", enabled=False): x = torch.nn.functional.layer_norm(*args) return x if residual is None else (x, residual) class NemotronMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -107,16 +121,20 @@ class NemotronMLP(nn.Module): prefix: str = "", ) -> None: super().__init__() - self.up_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=bias, - quant_config=quant_config, - 
prefix=f"{prefix}.up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_fn = get_act_fn(hidden_act) def forward(self, x): @@ -127,7 +145,6 @@ class NemotronMLP(nn.Module): class NemotronAttention(nn.Module): - def __init__( self, config: NemotronConfig, @@ -194,13 +211,15 @@ class NemotronAttention(nn.Module): rope_scaling=rope_scaling, partial_rotary_factor=self.partial_rotary_factor, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -216,7 +235,6 @@ class NemotronAttention(nn.Module): class NemotronDecoderLayer(nn.Module): - def __init__( self, config: NemotronConfig, @@ -229,21 +247,24 @@ class NemotronDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = NemotronAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -260,10 +281,12 @@ class NemotronDecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, - eps=config.norm_eps) + self.input_layernorm = NemotronLayerNorm1P( + config.hidden_size, eps=config.norm_eps + ) self.post_attention_layernorm = NemotronLayerNorm1P( - config.hidden_size, eps=config.norm_eps) + config.hidden_size, eps=config.norm_eps + ) def forward( self, @@ -276,23 +299,20 @@ class NemotronDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = 
self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class NemotronModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -303,12 +323,16 @@ class NemotronModel(nn.Module): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -318,19 +342,21 @@ class NemotronModel(nn.Module): self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: NemotronDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: NemotronDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: - self.norm = NemotronLayerNorm1P(config.hidden_size, - eps=config.norm_eps) + self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -353,20 +379,18 @@ class NemotronModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -376,18 +400,19 @@ class NemotronModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, 
"weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -416,8 +441,7 @@ class NemotronModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -450,8 +474,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = NemotronModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = NemotronModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -463,21 +488,24 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -489,20 +517,18 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 07cd5a4c6e24f..0a05c63a31ea2 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ 
b/vllm/model_executor/models/nemotron_h.py @@ -17,50 +17,60 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only NemotronH model.""" + from collections.abc import Iterable from typing import Optional import torch from torch import nn -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsLoRA, SupportsPP, - SupportsQuant) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + SupportsLoRA, + SupportsPP, + SupportsQuant, +) from vllm.model_executor.models.utils import ( - AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata + AutoWeightsLoader, + WeightsMapper, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig -from vllm.utils import LayerBlockType class NemotronHMLP(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -72,7 +82,7 @@ class NemotronHMLP(nn.Module): super().__init__() hybrid_override_pattern = config.hybrid_override_pattern - mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1 + mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 if isinstance(config.intermediate_size, list): if len(config.intermediate_size) == 1: intermediate_size = config.intermediate_size[0] @@ -105,7 +115,6 @@ class NemotronHMLP(nn.Module): class NemotronHMLPDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -145,7 +154,6 @@ class NemotronHMLPDecoderLayer(nn.Module): class NemotronHMambaDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -181,8 +189,6 @@ class 
NemotronHMambaDecoderLayer(nn.Module): self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -192,12 +198,11 @@ class NemotronHMambaDecoderLayer(nn.Module): hidden_states, residual = self.norm(hidden_states, residual) output = torch.empty_like(hidden_states) - self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + self.mixer(hidden_states, output) return output, residual class NemotronHAttention(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -270,7 +275,6 @@ class NemotronHAttention(nn.Module): class NemotronHAttentionDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -319,7 +323,6 @@ ALL_DECODER_LAYER_TYPES = { @support_torch_compile class NemotronHModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -330,8 +333,11 @@ class NemotronHModel(nn.Module): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -344,7 +350,8 @@ class NemotronHModel(nn.Module): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ - config.hybrid_override_pattern[layer_idx]] + config.hybrid_override_pattern[layer_idx] + ] return layer_class( config, layer_idx, @@ -355,11 +362,11 @@ class NemotronHModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( - len(config.hybrid_override_pattern), - get_layer, - prefix=f"{prefix}.layers") + len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers" + ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size) + ["hidden_states", "residual"], config.hidden_size + ) self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -370,22 +377,9 @@ class NemotronHModel(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -398,80 +392,69 @@ class NemotronHModel(nn.Module): residual = intermediate_tensors["residual"] residual = None - num_non_mamba_layers = 0 - for i in range(len(self.layers)): - layer = self.layers[i] - layer_mamba_cache_params = None - if isinstance(layer, - NemotronHMambaDecoderLayer) and mamba_cache_params: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - i - num_non_mamba_layers) - else: - num_non_mamba_layers += 1 - + for i, layer in enumerate(self.layers): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank:
- return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - attb_params_mapping = { - "q_proj": "q", - "k_proj": "k", - "v_proj": "v", - } + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "embeddings" in name: - name = name.replace("embeddings", "embed_tokens") + if "scale" in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue - if "A_log" in name: - name = name.replace("A_log", "A") - loaded_weight = loaded_weight.to(torch.float32) + # load stacked params + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue - if "D" in name: - loaded_weight = loaded_weight.to(torch.float32) - - if "dt_bias" in name: - loaded_weight = loaded_weight.to(torch.float32) - - # load attn params - if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]): - weight_name = next(proj - for proj in ["q_proj", "k_proj", "v_proj"] - if proj in name) - name = name.replace(weight_name, "qkv_proj") param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, - attb_params_mapping[weight_name]) + weight_loader(param, loaded_weight, shard_id) + break + # load other params else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class NemotronHForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"backbone": "model"}, + orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"}, + ) + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -492,7 +475,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -503,13 +485,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. 
Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -528,26 +508,23 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, head_dim=hf_config.mamba_head_dim, state_size=hf_config.ssm_state_size, conv_kernel=hf_config.conv_kernel, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "NemotronH currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = NemotronHModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = NemotronHModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -558,75 +535,41 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) - self.make_empty_intermediate_tensors = (self.model.make_empty_intermediate_tensors) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - - num_mamba_layers = \ - self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, - LayerBlockType.mamba - ) - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( input_buffers, **kwargs) - - def 
get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - # update name in weights before passing to loader - updated_weights = [] - for name, loaded_weight in weights: - name = name.replace("backbone", "model") - updated_weights.append((name, loaded_weight)) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(updated_weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index a766ed9476a65..ddd623b5de237 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only deci model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -39,17 +41,26 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasNoOps, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: @@ -66,7 +77,6 @@ def _find_multiple(n: int, k: int) -> int: class DeciLMAttention(LlamaAttention): - def __init__( self, config: LlamaConfig, @@ -83,18 +93,34 @@ class DeciLMAttention(LlamaAttention): prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: - super().__init__(config, hidden_size, num_heads, num_kv_heads, - rope_theta, rope_scaling, max_position_embeddings, - quant_config, bias, bias_o_proj, cache_config, prefix, - attn_type) + super().__init__( + config, + hidden_size, + num_heads, + num_kv_heads, + rope_theta, + rope_scaling, + max_position_embeddings, + quant_config, + bias, + bias_o_proj, + cache_config, + prefix, + attn_type, + ) - def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: + def 
_init_rotary_emb( + self, + config, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: # Enables YARN for Mistral and LLaMA4 derivatives. is_neox_style = True if hasattr(config, "position_embedding_type"): is_neox_style = config.position_embedding_type not in [ - "mistral_yarn", "rope_llama4" + "mistral_yarn", + "rope_llama4", ] self.rotary_emb = get_rope( @@ -104,11 +130,11 @@ class DeciLMAttention(LlamaAttention): base=self.rope_theta, rope_scaling=rope_scaling, is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor) + partial_rotary_factor=self.partial_rotary_factor, + ) class DeciLMDecoderLayer(nn.Module): - def __init__( self, config: LlamaConfig, @@ -126,23 +152,26 @@ class DeciLMDecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias if not self._is_no_op_attention: - num_kv_heads = (config.num_attention_heads // - block_config.attention.n_heads_in_group) + num_kv_heads = ( + config.num_attention_heads // block_config.attention.n_heads_in_group + ) self.self_attn = DeciLMAttention( config=config, hidden_size=self.hidden_size, @@ -157,13 +186,13 @@ class DeciLMDecoderLayer(nn.Module): cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if not self._is_no_op_ffn: ffn_mult = block_config.ffn.ffn_mult intermediate_size = _ffn_mult_to_intermediate_size( - ffn_mult, config.hidden_size) + ffn_mult, config.hidden_size + ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -173,8 +202,9 @@ class DeciLMDecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -187,12 +217,11 @@ class DeciLMDecoderLayer(nn.Module): if self._is_no_op_attention: pass else: - if (residual is None): + if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -201,14 +230,14 @@ class DeciLMDecoderLayer(nn.Module): # Fully Connected if not self._is_no_op_ffn: hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = 
self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class DeciModel(nn.Module): - def __init__( self, *, @@ -226,12 +255,16 @@ class DeciModel(nn.Module): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -261,9 +294,9 @@ class DeciModel(nn.Module): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -287,27 +320,22 @@ class DeciModel(nn.Module): residual = intermediate_tensors["residual"] kv_cache_index = 0 - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): if not layer._is_no_op_attention: - hidden_states, residual = layer(positions, hidden_states, - residual) + hidden_states, residual = layer(positions, hidden_states, residual) kv_cache_index += 1 else: - hidden_states, residual = layer(positions, hidden_states, - residual) + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -321,19 +349,19 @@ class DeciModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -366,8 +394,7 @@ class DeciModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -414,8 +441,9 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -429,24 +457,25 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return DeciModel(vllm_config=vllm_config, prefix=prefix) @@ -461,24 +490,21 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if 
self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index a9c7d8044e10c..268644bc92499 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -22,37 +22,45 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.internvl import ( - BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs, - InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor) + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + InternVLImageEmbeddingInputs, + InternVLImageInputs, + InternVLImagePixelInputs, + InternVLProcessor, +) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import ( - cached_image_processor_from_config) +from vllm.transformers_utils.processor import cached_image_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<image>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<image>" def build_transform(input_size: int): - return T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - ]) + return T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + ] + ) # adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1 @@ -64,15 +72,16 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_factor = float('-inf') + best_factor = float("-inf") best_ratio = (1, 1) area = width * height for rw, rh in target_ratios: target_aspect_ratio = rw / rh size_factor = min((rw * rh * image_size * image_size) / area, 0.6) - ratio_closeness = min(target_aspect_ratio / aspect_ratio, - aspect_ratio / target_aspect_ratio) + ratio_closeness = min( + target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio + ) factor = size_factor * ratio_closeness if factor > best_factor: @@ -135,10 +144,12 @@ def dynamic_preprocess_nemotron_vl( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // 
(target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -156,10 +167,13 @@ def get_nemotron_vl_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -187,7 +201,6 @@ def image_to_pixel_values_nemotron_vl( class NemotronVLProcessor(InternVLProcessor): - def __init__( self, config: PretrainedConfig, @@ -218,7 +231,8 @@ class NemotronVLProcessor(InternVLProcessor): assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -270,7 +284,8 @@ class NemotronVLProcessor(InternVLProcessor): min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def _preprocess_image( @@ -290,11 +305,11 @@ class NemotronVLProcessor(InternVLProcessor): max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + image_inputs = { + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -302,10 +317,9 @@ class NemotronVLProcessor(InternVLProcessor): feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) NVL_IMAGE_CONTEXT = image_repl.full.replace( - "<image>", "<NVL_IMG_CONTEXT>") - text = [ - t.replace('<image>', NVL_IMAGE_CONTEXT, 1) for t in text - ] + "<image>", "<NVL_IMG_CONTEXT>" + ) + text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text] text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text] return text, image_inputs @@ -342,9 +356,10 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): @MULTIMODAL_REGISTRY.register_processor( BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo], info=NemotronVLProcessingInfo, - dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo]) -class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): + dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo], +) +class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -368,7 +383,8 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * 
(config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version @@ -391,18 +407,20 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -412,20 +430,22 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, *, prefix: str, ): - return AutoModel.from_config(config.vision_config, - trust_remote_code=True) + return AutoModel.from_config(config.vision_config, trust_remote_code=True) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vit_hidden_size vision_projection_hidden_size = config.projector_hidden_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, - bias=True), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - vision_projection_hidden_size, - bias=True), + nn.LayerNorm( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, bias=True + ), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + vision_projection_hidden_size, + bias=True, + ), nn.GELU(), nn.Linear(vision_projection_hidden_size, llm_hidden_size), ) @@ -436,9 +456,13 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -449,17 +473,16 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, vit_embeds = self.vision_model(x=pixel_values).features vit_embeds = vit_embeds.to(dtype=torch.bfloat16) - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, 
vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternVLImageInputs]: + self, **kwargs: object + ) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -468,13 +491,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return InternVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -482,24 +501,13 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return InternVLImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, resolve_bindings={ "h": self.config.force_image_size, - "w": self.config.force_image_size + "w": self.config.force_image_size, }, ) @@ -520,14 +528,12 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -539,10 +545,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values_flat", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) return modalities @@ -552,15 +559,13 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image). + # tensor corresponding to a multimodal data item (image). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -577,20 +582,23 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - context_token_ids = [self.img_context_token_id] - assert len(context_token_ids) >= 1 + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - context_token_ids, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -600,19 +608,10 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - forward_kwargs = { "input_ids": input_ids, "positions": positions, @@ -622,8 +621,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -632,13 +630,10 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ## Ignore registered_buffers ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501 skip_substrs = ["norm_mean", "norm_std"] @@ -652,4 +647,5 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return MultiModelKeys.from_string_field( language_model="language_model", connector="mlp1", - tower_model="vision_model") + tower_model="vision_model", + ) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 3bbf4c67604c7..f17bf3b09d5be 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -14,25 +14,34 @@ import torch import torch.nn as nn from transformers import PretrainedConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from .intern_vit import InternVisionModel -from .internvl import (BaseInternVLDummyInputsBuilder, - BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel) +from .internvl import ( + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + BaseInternVLProcessor, + InternVLChatModel, +) IMG_PAD = "<|vision_pad|>" class NVLMProcessor(BaseInternVLProcessor): - @property def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_PAD] @@ -50,8 +59,9 @@ class NVLMProcessor(BaseInternVLProcessor): tile_pos_identifiers += ["<tile_global_thumbnail>"] context_size = feature_size // num_patches - features = "".join(identifier + IMG_PAD * context_size - for identifier in tile_pos_identifiers) + features = "".join( + identifier + IMG_PAD * context_size for identifier in tile_pos_identifiers + ) # We include the start and end as well because "<Image><tile" 
is # tokenized as ["<Image", "><", "tile"], resulting in assertion error @@ -62,7 +72,6 @@ class NVLMProcessor(BaseInternVLProcessor): class NVLMProcessingInfo(BaseInternVLProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> NVLMProcessor: return self.ctx.init_processor( NVLMProcessor, @@ -72,9 +81,7 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo): ) -class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo] - ): - +class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -86,22 +93,24 @@ class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo] self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class NVLMMultiModalProcessor( - BaseInternVLMultiModalProcessor[NVLMProcessingInfo]): - +class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -124,7 +133,8 @@ class NVLMMultiModalProcessor( def get_replacement_nvlm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -154,21 +164,24 @@ class NVLMMultiModalProcessor( ] -@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor, - info=NVLMProcessingInfo, - dummy_inputs=NVLMDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + NVLMMultiModalProcessor, + info=NVLMProcessingInfo, + dummy_inputs=NVLMDummyInputsBuilder, +) class NVLM_D_Model(InternVLChatModel): - - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_intermediate_size = config.text_config.intermediate_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - llm_intermediate_size, - bias=False), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + llm_intermediate_size, + bias=False, + ), nn.GELU(), nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False), ) @@ -184,8 +197,9 @@ class NVLM_D_Model(InternVLChatModel): if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 diff --git a/vllm/model_executor/models/olmo.py 
b/vllm/model_executor/models/olmo.py index 01639d398126f..f334bbf9feeb5 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -35,22 +37,29 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OlmoAttention(nn.Module): @@ -70,15 +79,13 @@ class OlmoAttention(nn.Module): super().__init__() self.config = config self.hidden_size = config.hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta @@ -102,12 +109,14 @@ class OlmoAttention(nn.Module): base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) # Attention output projection. self.o_proj = RowParallelLinear( @@ -189,28 +198,29 @@ class OlmoDecoderLayer(nn.Module): (plus another skip connection). 
""" - def __init__(self, - config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = OlmoAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) # MLP block. self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp") # LayerNorm - self.input_layernorm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) def forward( self, @@ -233,7 +243,6 @@ class OlmoDecoderLayer(nn.Module): @support_torch_compile class OlmoModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -243,19 +252,22 @@ class OlmoModel(nn.Module): self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OlmoDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -280,7 +292,7 @@ class OlmoModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. 
- for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -291,8 +303,7 @@ class OlmoModel(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -304,7 +315,7 @@ class OlmoModel(nn.Module): params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -324,8 +335,7 @@ class OlmoModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -335,6 +345,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -352,8 +363,9 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = OlmoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = OlmoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -363,10 +375,12 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -389,17 +403,15 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 66a0f9115585a..79234cc4dd8de 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -26,6 +26,7 @@ from collections.abc import Iterable from 
functools import partial +from itertools import islice from typing import Optional, Union import torch @@ -41,20 +42,29 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.distributed.utils import split_tensor_along_last_dim from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config class Olmo2Attention(nn.Module): @@ -67,7 +77,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -77,8 +87,9 @@ class Olmo2Attention(nn.Module): assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = (self.config.num_key_value_heads - or self.total_num_heads) + self.total_num_kv_heads = ( + self.config.num_key_value_heads or self.total_num_heads + ) if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: @@ -107,17 +118,17 @@ class Olmo2Attention(nn.Module): self.total_num_kv_heads * self.head_dim, eps=self.config.rms_norm_eps, ) - self.q_norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = None + if ( + layer_types := getattr(self.config, "layer_types", None) + ) is not None and layer_types[layer_idx] == "sliding_attention": + sliding_window = self.config.sliding_window + self.attn = Attention( self.num_heads, self.head_dim, @@ -125,7 +136,19 @@ class Olmo2Attention(nn.Module): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. 
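# Illustrative sketch with invented config values: how the per-layer choice
# above can be read off a layer_types list, giving sliding layers a finite
# window and reserving rope scaling for full-attention layers (mirroring the
# rope_scaling assignment that follows). layer_index_from_prefix is a
# hypothetical stand-in for the real helper.
layer_types = ["sliding_attention", "full_attention"] * 2
sliding_window = 4096
rope_scaling = {"rope_type": "yarn", "factor": 4.0}

def layer_index_from_prefix(prefix: str) -> int:
    # e.g. "model.layers.3.self_attn" -> 3
    return int(next(part for part in prefix.split(".") if part.isdigit()))

for idx in range(len(layer_types)):
    prefix = f"model.layers.{idx}.self_attn"
    sliding = layer_types[layer_index_from_prefix(prefix)] == "sliding_attention"
    window = sliding_window if sliding else None
    scaling = None if sliding else rope_scaling
    print(prefix, window, scaling)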
+ self.rope_scaling = self.config.rope_scaling if sliding_window is None else None + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. @@ -137,16 +160,16 @@ class Olmo2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -175,7 +198,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -220,20 +243,23 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. - self.self_attn = Olmo2Attention(vllm_config=vllm_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Olmo2Attention( + vllm_config=vllm_config, prefix=f"{prefix}.self_attn" + ) # MLP block. 
self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") # LayerNorm - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -256,11 +282,10 @@ class Olmo2DecoderLayer(nn.Module): @support_torch_compile class Olmo2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -269,17 +294,19 @@ class Olmo2Model(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, - lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, - prefix=prefix), + lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -305,7 +332,7 @@ class Olmo2Model(nn.Module): assert isinstance(hidden_states, torch.Tensor) # Apply blocks one-by-one. - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) hidden_states = layer(positions, hidden_states) @@ -317,8 +344,7 @@ class Olmo2Model(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -349,8 +375,7 @@ class Olmo2Model(nn.Module): if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -360,6 +385,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. 
""" + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -375,10 +401,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config - self.model = Olmo2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Olmo2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -392,7 +419,11 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -412,16 +443,15 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index a47c3bd416459..0e4b408775f5f 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -13,40 +13,51 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OLMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable from functools import partial -from typing import Any, Optional, Union +from itertools import islice +from typing import Optional, Union import torch from torch import nn -from transformers import OlmoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.distributed.utils import split_tensor_along_last_dim from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -60,33 +71,36 @@ class OlmoeMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. 
- self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - quant_config=None) + self.gate = ReplicatedLinear( + hidden_size, num_experts, bias=False, quant_config=None + ) - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - reduce_results=True, - renormalize=False, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -95,27 +109,28 @@ class OlmoeMoE(nn.Module): hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) return final_hidden_states.view(orig_shape) class OlmoeAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 4096, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() - self.hidden_size = hidden_size + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) + + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -130,7 +145,7 @@ class OlmoeAttention(nn.Module): # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -138,7 +153,7 @@ class OlmoeAttention(nn.Module): self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, + self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, @@ -148,11 +163,10 @@ class OlmoeAttention(nn.Module): self.tp_size = tp_size self.tp_rank = get_tensor_model_parallel_rank() self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5) - self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, - eps=1e-5) + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, - hidden_size, + self.hidden_size, bias=False, quant_config=quant_config, ) @@ -165,24 +179,26 @@ class OlmoeAttention(nn.Module): rope_scaling=rope_scaling, is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -202,30 +218,15 @@ class OlmoeAttention(nn.Module): class OlmoeDecoderLayer(nn.Module): - - def __init__( - self, - config: OlmoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) self.self_attn = OlmoeAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.self_attn", ) @@ -251,8 +252,7 @@ class OlmoeDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) 
hidden_states = self.self_attn( positions=positions, @@ -260,21 +260,23 @@ class OlmoeDecoderLayer(nn.Module): ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class OlmoeModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = OlmoeDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size self.config = config @@ -284,14 +286,14 @@ class OlmoeModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoeDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=1e-5) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -314,7 +316,7 @@ class OlmoeModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -322,12 +324,14 @@ class OlmoeModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) - hidden_states, _ = self.norm(hidden_states, residual) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -337,10 +341,10 @@ class OlmoeModel(nn.Module): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -354,7 +358,7 @@ class OlmoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
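# Illustrative sketch with an invented checkpoint name: how the
# stacked_params_mapping entries drive this loop, folding separate q/k/v
# checkpoint weights into the fused qkv_proj parameter while the shard id
# tells the weight loader which slice to fill.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

ckpt_name = "model.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(fused_name, shard_id)  # model.layers.0.self_attn.qkv_proj.weight k
        break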
if weight_name not in name: continue @@ -391,11 +395,13 @@ class OlmoeModel(nn.Module): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -407,7 +413,8 @@ class OlmoeModel(nn.Module): # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 @@ -419,8 +426,9 @@ class OlmoeModel(nn.Module): name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -439,21 +447,34 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ], } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = OlmoeDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OlmoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.model = OlmoeModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -465,18 +486,16 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 9eaac1e28dcd8..eadfea6084e5e 100644 --- 
a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -19,7 +19,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -31,26 +33,33 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OPTLearnedPositionalEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int): # OPT is set up so that if padding_idx is specified then offset the # embedding ids by 2 and adjust num_embeddings appropriately. 
Other @@ -63,7 +72,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -75,8 +83,7 @@ class OPTAttention(nn.Module): ) -> None: super().__init__() self.embed_dim = embed_dim - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() total_num_heads = num_heads assert num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size @@ -98,12 +105,14 @@ class OPTAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -117,7 +126,6 @@ class OPTAttention(nn.Module): class OPTDecoderLayer(nn.Module): - def __init__( self, config: OPTConfig, @@ -139,8 +147,8 @@ class OPTDecoderLayer(nn.Module): self.do_layer_norm_before = config.do_layer_norm_before self.self_attn_layer_norm = nn.LayerNorm( - self.embed_dim, - elementwise_affine=config.layer_norm_elementwise_affine) + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) self.fc1 = ColumnParallelLinear( self.embed_dim, config.ffn_dim, @@ -157,8 +165,8 @@ class OPTDecoderLayer(nn.Module): prefix=f"{prefix}.fc2", ) self.final_layer_norm = nn.LayerNorm( - self.embed_dim, - elementwise_affine=config.layer_norm_elementwise_affine) + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) def forward( self, @@ -191,7 +199,6 @@ class OPTDecoderLayer(nn.Module): class OPTDecoder(nn.Module): - def __init__( self, config: OPTConfig, @@ -210,24 +217,29 @@ class OPTDecoder(nn.Module): ) # Positional embeddings are replicated (not sharded). self.embed_positions = OPTLearnedPositionalEmbedding( - config.max_position_embeddings, config.hidden_size) + config.max_position_embeddings, config.hidden_size + ) # Project out & in will be replicated if they exist. 
if config.word_embed_proj_dim != config.hidden_size: - self.project_out = ReplicatedLinear(config.hidden_size, - config.word_embed_proj_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.project_out") + self.project_out = ReplicatedLinear( + config.hidden_size, + config.word_embed_proj_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_out", + ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: - self.project_in = ReplicatedLinear(config.word_embed_proj_dim, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.project_in") + self.project_in = ReplicatedLinear( + config.word_embed_proj_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_in", + ) else: self.project_in = None @@ -238,15 +250,18 @@ class OPTDecoder(nn.Module): if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm( config.hidden_size, - elementwise_affine=config.layer_norm_elementwise_affine) + elementwise_affine=config.layer_norm_elementwise_affine, + ) else: self.final_layer_norm = None self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OPTDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -269,7 +284,7 @@ class OPTDecoder(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(hidden_states) if not get_pp_group().is_last_rank: @@ -283,7 +298,6 @@ class OPTDecoder(nn.Module): @support_torch_compile class OPTModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -291,13 +305,12 @@ class OPTModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.decoder = OPTDecoder(config, - cache_config, - quant_config, - prefix=f"{prefix}.decoder") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.decoder = OPTDecoder( + config, cache_config, quant_config, prefix=f"{prefix}.decoder" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.decoder.get_input_embeddings(input_ids) @@ -309,13 +322,11 @@ class OPTModel(nn.Module): intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - return self.decoder(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.decoder( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,7 +336,7 @@ class OPTModel(nn.Module): params_dict = 
dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -345,22 +356,22 @@ class OPTModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class OPTForCausalLM(nn.Module, SupportsPP): +class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "decoder.": "model.decoder.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -368,16 +379,21 @@ class OPTForCausalLM(nn.Module, SupportsPP): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OPTModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = OPTModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if self.config.tie_word_embeddings: self.lm_head = self.model.decoder.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.word_embed_proj_dim) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.word_embed_proj_dim, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -389,24 +405,23 @@ class OPTForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index d121188ba5d4a..0ce1729389553 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -6,7 +6,9 @@ # Copyright (c) 
OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -18,26 +20,32 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OrionMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -47,16 +55,15 @@ class OrionMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -67,7 +74,6 @@ class OrionMLP(nn.Module): class OrionAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -125,13 +131,15 @@ class OrionAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -147,7 +155,6 @@ class OrionAttention(nn.Module): class OrionDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -159,8 +166,7 @@ class OrionDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = OrionAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -179,10 +185,10 @@ class OrionDecoderLayer(nn.Module): quant_config=quant_config, ) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -209,7 +215,6 @@ class OrionDecoderLayer(nn.Module): @support_torch_compile class OrionModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -226,13 +231,17 @@ class OrionModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OrionDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory([ + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + [ "hidden_states", - ], config.hidden_size)) + ], + config.hidden_size, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -252,17 +261,18 @@ class OrionModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -274,7 
+284,7 @@ class OrionModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -294,31 +304,34 @@ class OrionModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class OrionForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OrionModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.model = OrionModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -330,20 +343,18 @@ class OrionForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 5b3ad7cbd07ad..12ed7b4c2ed03 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -16,10 +16,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch Ovis model.""" +"""PyTorch Ovis model.""" + import math from collections.abc import Iterable, Mapping -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -28,30 +29,35 @@ from torch.nn.functional import gumbel_softmax, pad, softmax from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.aimv2 import AIMv2Model from vllm.model_executor.models.siglip import SiglipVisionModel -from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, - init_vllm_registered_model, - maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis import OvisProcessor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. 
IMAGE_TOKEN = "<image>" @@ -78,7 +84,6 @@ def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax class VisualTokenizer(torch.nn.Module): - def __init__( self, config: PretrainedConfig, @@ -96,12 +101,15 @@ class VisualTokenizer(torch.nn.Module): head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) self.head = torch.nn.Sequential( ReplicatedLinear( - config.backbone_config.hidden_size * config.hidden_stride * - config.hidden_stride, + config.backbone_config.hidden_size + * config.hidden_stride + * config.hidden_stride, head_dim, bias=False, return_bias=False, - ), torch.nn.LayerNorm(head_dim)) + ), + torch.nn.LayerNorm(head_dim), + ) def _init_backbone( self, @@ -124,8 +132,7 @@ class VisualTokenizer(torch.nn.Module): quant_config=quant_config, prefix=prefix, ) - raise ValueError( - f"Unsupported visual tokenizer model_type: {model_type}") + raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @property def dtype(self) -> torch.dtype: @@ -136,16 +143,17 @@ class VisualTokenizer(torch.nn.Module): return next(self.head.parameters()).device def tokenize(self, logits: torch.Tensor) -> torch.Tensor: - if self.config.tokenize_function == 'softmax': + if self.config.tokenize_function == "softmax": tokens = softmax(logits, dim=-1) - elif self.config.tokenize_function == 'gumbel_argmax': + elif self.config.tokenize_function == "gumbel_argmax": tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) - elif self.config.tokenize_function == 'st_argmax': + elif self.config.tokenize_function == "st_argmax": tokens = st_argmax(logits, dim=-1) else: raise ValueError( - 'Invalid `max_type`, expected softmax or gumbel_argmax ' - f'or st_argmax, but got {self.config.tokenize_function}') + "Invalid `max_type`, expected softmax or gumbel_argmax " + f"or st_argmax, but got {self.config.tokenize_function}" + ) return tokens def encode(self, pixel_values: torch.Tensor) -> torch.Tensor: @@ -162,25 +170,30 @@ class VisualTokenizer(torch.nn.Module): n, L, d = features.shape sqrt_l = int(L**0.5) assert sqrt_l**2 == L, ( - "The token sequence length should be a perfect square.") + "The token sequence length should be a perfect square." 
+ ) features = features.reshape(n, sqrt_l, sqrt_l, d) - pl = (self.config.hidden_stride - - (sqrt_l % - self.config.hidden_stride)) % self.config.hidden_stride + pl = ( + self.config.hidden_stride - (sqrt_l % self.config.hidden_stride) + ) % self.config.hidden_stride features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) sqrt_l += pl - features = features.reshape(n, sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, - sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, d) + features = features.reshape( + n, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + d, + ) # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] features = features.permute(0, 1, 3, 2, 4, 5) # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] features = features.flatten(3) # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] features = features.reshape( - n, -1, - self.config.hidden_stride * self.config.hidden_stride * d) + n, -1, self.config.hidden_stride * self.config.hidden_stride * d + ) return features @@ -201,35 +214,34 @@ class VisualTokenizer(torch.nn.Module): return tokens -class OvisImagePatchInputs(TypedDict): +class OvisImagePatchInputs(TensorSchema): + """ + Dimensions: + - bnp: Batch size * number of images * number of patches + - h: Height of each patch + - w: Width of each patch + - patch_indicators: Batch size * (number of patches + 1) + - bn: Batch size * number of images + """ + type: Literal["image_patches"] - flat_data: torch.Tensor - """ - Shape: - `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` - """ - - inducator_tokens: torch.Tensor - """ - Shape: - `(batch_size * (num_patches + 1))` - """ - - patches_per_image: list[int] - """ - List of number of total patches for each image in the batch. - This is used to restore the first two dimensions of `flat_data`. - """ + flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] + indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] + patches_per_image: Annotated[list[int], TensorShape("bn")] + # This is used to restore the first two dimensions of `flat_data`. 
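(Illustration, not part of the patch.) A minimal sketch of how the TensorSchema-based `OvisImagePatchInputs` introduced above is built and read back, mirroring the `_parse_and_validate_image_input` and `_process_image_input` methods later in this file. The dummy shapes and the class import path are assumptions chosen only to satisfy the declared `TensorShape` annotations, and shape checking at construction time is assumed from `TensorSchema`:

    import torch

    from vllm.model_executor.models.ovis import OvisImagePatchInputs

    # Two images with 4 and 3 patches -> "bnp" = 7 flattened patches (dummy shapes).
    flat_data = torch.randn(7, 3, 14, 14)
    # One indicator per patch plus one trailing indicator per image (4+1 and 3+1)
    # is assumed here purely to give "patch_indicators" a plausible length of 9.
    indicator_tokens = torch.zeros(9, dtype=torch.long)

    image_input = OvisImagePatchInputs(
        type="image_patches",
        flat_data=flat_data,
        indicator_tokens=indicator_tokens,
        patches_per_image=[4, 3],  # "bn" = 2; used to re-split flat_data per image
    )

    # Fields are read back with mapping-style access, as in _process_image_input.
    per_image = image_input["flat_data"].split(image_input["patches_per_image"], dim=0)
    assert [t.shape[0] for t in per_image] == [4, 3]

Compared with the removed TypedDict, the shape contract now lives in the field annotations, so malformed multimodal kwargs are intended to fail while the input is parsed rather than deep inside the vision tower.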
class VisualEmbedding(torch.nn.Embedding): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, visual_tokens: Tensor) -> Tensor: if visual_tokens.dtype in [ - torch.int8, torch.int16, torch.int32, torch.int64, torch.long + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.long, ]: return super().forward(visual_tokens) return torch.matmul(visual_tokens, self.weight) @@ -244,7 +256,6 @@ class VisualEmbedding(torch.nn.Embedding): class OvisProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor( OvisProcessor, @@ -261,9 +272,10 @@ class OvisProcessingInfo(BaseProcessingInfo): patch_grid_length = math.ceil(image_size / patch_size) assert patch_grid_length % hidden_stride == 0, ( f"patch_grid_length {patch_grid_length} is not divisible by " - f"hidden_stride {hidden_stride}") + f"hidden_stride {hidden_stride}" + ) # minus 1 for presented image token - return (patch_grid_length // hidden_stride)**2 - 1 + return (patch_grid_length // hidden_stride) ** 2 - 1 def get_image_pad_token(self) -> str: hf_text_config = self.get_hf_config().get_text_config() @@ -282,7 +294,6 @@ class OvisProcessingInfo(BaseProcessingInfo): class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return IMAGE_TOKEN * num_images @@ -291,29 +302,32 @@ class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } return mm_data class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): - def image_indicators_to_visual_tokens( self, image_indicators: list[int], ) -> list[int]: """ - Filter image indicators placeholders and convert them to corresponding + Filter image indicators placeholders and convert them to corresponding tokens in visual tokenizer. 
For example, [-301, -300, -302, -300, -303, -300, -304, -300, -305] should return [vocab_size-1, vocab_size-2, ..., vocab_size-5] @@ -352,14 +366,13 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): self.image_indicators_to_visual_tokens(indicator) for indicator in image_indicators ] - processed_outputs["indicator_tokens"] = indicator_tokens + processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens) return processed_outputs def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: - return prompt_tokens def _get_mm_fields_config( @@ -367,9 +380,11 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - grids=MultiModalFieldConfig.batched("image"), - indicator_tokens=MultiModalFieldConfig.batched("image")) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + ) def _get_prompt_updates( self, @@ -377,7 +392,6 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx: int): out_item = out_mm_kwargs["image"][item_idx] grid = out_item["grids"].data @@ -394,10 +408,13 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, - info=OvisProcessingInfo, - dummy_inputs=OvisDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder, +) class Ovis(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -419,30 +436,24 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): self.visual_tokenizer = VisualTokenizer( config=config.visual_tokenizer_config, - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", ) self.vte = VisualEmbedding( - self.config.visual_tokenizer_config.vocab_size, - self.config.hidden_size) + self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size + ) text_model_type = self.config.get_text_config().model_type self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.make_empty_intermediate_tensors = ( - self.get_language_model().make_empty_intermediate_tensors) - - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid vision encoder sections for some models. 
- # See: https://huggingface.co/AIDC-AI/Ovis2-2B-GPTQ-Int4 - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config + self.get_language_model().make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + self, **kwargs: object + ) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -451,59 +462,59 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): if pixel_values is not None and indicator_tokens is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) if not isinstance(indicator_tokens, (torch.Tensor, list)): - raise ValueError("Incorrect type of indicator_tokens. " - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of indicator_tokens. " + f"Got type: {type(pixel_values)}" + ) return OvisImagePatchInputs( type="image_patches", - flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), - patches_per_image=[ - x.shape[0] for x in flatten_bn(pixel_values) - ], - indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), - concat=True), + flat_data=flatten_bn(pixel_values, concat=True), + patches_per_image=[x.shape[0] for x in pixel_values], + indicator_tokens=flatten_bn(indicator_tokens, concat=True), ) raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: OvisImagePatchInputs + ) -> MultiModalEmbeddings: image_patches_flat = image_input["flat_data"] patches_per_image = image_input["patches_per_image"] indicator_tokens = image_input["indicator_tokens"] indicator_per_image = list( - map(lambda x: x + 1 if x > 1 else x + 2, patches_per_image)) + map(lambda x: x + 1 if x > 1 else x + 2, patches_per_image) + ) target_dtype = self.visual_tokenizer.dtype - visual_tokens = self.visual_tokenizer( - image_patches_flat.to(target_dtype)) + visual_tokens = self.visual_tokenizer(image_patches_flat.to(target_dtype)) visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. 
indicator_embeds = self.vte(indicator_tokens) - indicator_embeds_per_image = indicator_embeds.split( - indicator_per_image) + indicator_embeds_per_image = indicator_embeds.split(indicator_per_image) visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) vision_embeddings = [] - for indicator, visual in zip(indicator_embeds_per_image, - visual_embeds_per_image): + for indicator, visual in zip( + indicator_embeds_per_image, visual_embeds_per_image + ): vision_embeddings_per_image = [] for i in range(visual.shape[0]): vision_embeddings_per_image.append( - torch.cat([indicator[i:i + 1], visual[i]], dim=0)) - vision_embeddings_per_image.append(indicator[i + 1:]) - vision_embeddings.append( - torch.cat(vision_embeddings_per_image, dim=0)) + torch.cat([indicator[i : i + 1], visual[i]], dim=0) + ) + vision_embeddings_per_image.append(indicator[i + 1 :]) + vision_embeddings.append(torch.cat(vision_embeddings_per_image, dim=0)) return tuple(vision_embeddings) - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -512,19 +523,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): return image_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_pad_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -536,15 +534,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - # up until here we have a inputs_embeds 100% numerical identity + # up until here we have an inputs_embeds 100% numerical identity # between the OG HF Transformers implementation and ours hidden_states = self.llm( input_ids=input_ids, @@ -557,13 +547,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.llm.compute_logits(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 58a14072443cb..bb4fb1d17c151 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -1,34 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" PyTorch Ovis model.""" +"""PyTorch Ovis model.""" + from collections.abc import Iterable, Mapping from functools import partial -from typing import Optional, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.models.ovis import (OvisImagePatchInputs, - VisualEmbedding) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.ovis import VisualEmbedding from vllm.model_executor.models.siglip2navit import Siglip2NavitModel -from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, - init_vllm_registered_model, - maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -50,13 +59,38 @@ IMAGE_PAD_TOKEN_ID_MAP = { } -def _ovis2_5_field_config(): - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - 
grids=MultiModalFieldConfig.batched("image"), - indicator_tokens=MultiModalFieldConfig.batched("image"), - video_pixel_values=MultiModalFieldConfig.batched("video"), - video_indicator_tokens=MultiModalFieldConfig.batched("video"), - video_grids=MultiModalFieldConfig.batched("video")) +class Ovis2_5ImagePatchInputs(TensorSchema): + """ + Dimensions: + - bnp: Batch size * number of images * number of patches + - patch_size: patch_size_x * patch_size_y * num_channels + - patch_indicators: Batch size * (number of patches + 1) + - bn: Batch size * number of images + """ + + type: Literal["image_patches"] + flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")] + indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] + patches_per_item: Annotated[list[int], TensorShape("bn")] + grids: Annotated[torch.Tensor, TensorShape("bn", 3)] + # This is used to restore the first two dimensions of `flat_data`. + + +class Ovis2_5VideoPatchInputs(TensorSchema): + """ + Dimensions: + - bnp: Batch size * number of videos * number of patches + - patch_size: patch_size_x * patch_size_y * num_channels + - patch_indicators: Batch size * (number of patches + 1) + - bn: Batch size * number of videos + """ + + type: Literal["video_patches"] + flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")] + indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] + patches_per_item: Annotated[list[int], TensorShape("bn")] + grids: Annotated[torch.Tensor, TensorShape("bn", 3)] + # This is used to restore the first two dimensions of `flat_data`. class VisualTokenizer(torch.nn.Module): @@ -88,7 +122,9 @@ class VisualTokenizer(torch.nn.Module): head_dim, bias=False, return_bias=False, - ), torch.nn.LayerNorm(head_dim)) + ), + torch.nn.LayerNorm(head_dim), + ) def _init_backbone( self, @@ -99,12 +135,13 @@ class VisualTokenizer(torch.nn.Module): ): model_type = config.model_type if model_type == "siglip2_navit": - return Siglip2NavitModel(config=config, - quant_config=quant_config, - prefix=prefix, - use_data_parallel=use_data_parallel) - raise ValueError( - f"Unsupported visual tokenizer model_type: {model_type}") + return Siglip2NavitModel( + config=config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=use_data_parallel, + ) + raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @property def dtype(self) -> torch.dtype: @@ -115,22 +152,22 @@ class VisualTokenizer(torch.nn.Module): return next(self.head.parameters()).device def tokenize(self, logits: torch.Tensor) -> torch.Tensor: - tokens = torch.softmax(logits, dim=-1, - dtype=torch.float32).to(logits.dtype) + tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype) return tokens - def encode(self, pixel_values: torch.Tensor, - grid_thws: torch.Tensor) -> torch.Tensor: + def encode( + self, pixel_values: torch.Tensor, grid_thws: torch.Tensor + ) -> torch.Tensor: features = self.vit(pixel_values, grid_thws) # refer to qwen2.5-vl patchmerger seq_len, _ = features.shape - features = features.reshape(seq_len // (self.config.hidden_stride**2), - -1) + features = features.reshape(seq_len // (self.config.hidden_stride**2), -1) return features - def forward(self, pixel_values: torch.Tensor, - grid_thws: torch.Tensor) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, grid_thws: torch.Tensor + ) -> torch.Tensor: features = self.encode(pixel_values, grid_thws) logits = self.head(features) tokens = self.tokenize(logits) @@ -147,7 +184,6 @@ class 
VisualTokenizer(torch.nn.Module): class Ovis2_5ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -200,8 +236,9 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() - return self.get_num_image_tokens(image_width=target_width, - image_height=target_height) + return self.get_num_image_tokens( + image_width=target_width, image_height=target_height + ) def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() @@ -227,8 +264,7 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -240,9 +276,9 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): num_frames: int, image_processor: Optional[BaseImageProcessor], ) -> int: - num_video_tokens = self.get_num_image_tokens(image_width=image_width, - image_height=image_height, - num_frames=num_frames) + num_video_tokens = self.get_num_image_tokens( + image_width=image_width, image_height=image_height, num_frames=num_frames + ) return num_video_tokens def get_max_video_tokens( @@ -254,14 +290,12 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -271,46 +305,52 @@ class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, - ) + overrides=video_overrides, + ), } return mm_data -class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] - ): - +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]): def 
visual_indicators_to_visual_tokens( self, visual_indicators: list[int], ) -> list[int]: """ - Filter image indicators placeholders and convert them to corresponding + Filter image indicators placeholders and convert them to corresponding tokens in visual tokenizer. """ hf_config = self.info.get_hf_config() vte_vocab_size = hf_config.visual_vocab_size return [ vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 - for x in visual_indicators if x < -300 + for x in visual_indicators + if x < -300 ] def _call_hf_processor( @@ -343,7 +383,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] self.visual_indicators_to_visual_tokens(indicator) for indicator in visual_indicators ] - processed_outputs["video_indicator_tokens"] = indicator_tokens + processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens) if "images" in mm_data: visual_indicators = [ hf_processor.construct_visual_indicators((1, 1, 1), False) @@ -354,14 +394,13 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] for indicator in visual_indicators ] - processed_outputs["indicator_tokens"] = indicator_tokens + processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens) return processed_outputs def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: - return prompt_tokens def _get_mm_fields_config( @@ -369,7 +408,14 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _ovis2_5_field_config() + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video"), + ) def _get_prompt_updates( self, @@ -377,7 +423,6 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx, modality: str): if modality == "image": out_item = out_mm_kwargs["image"][item_idx] @@ -386,21 +431,27 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] out_item = out_mm_kwargs["video"][item_idx] grid = out_item["video_grids"].data hf_processor = self.info.get_hf_processor() - return hf_processor.construct_visual_placeholders(grid[0], ) + return hf_processor.construct_visual_placeholders( + grid[0], + ) return [ PromptReplacement( modality=modality, target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, replacement=partial(get_replacement_ovis, modality=modality), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] -@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, - info=Ovis2_5ProcessingInfo, - dummy_inputs=Ovis2_5DummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder, +) class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -420,112 +471,159 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): 
prefix=f"{prefix}.visual_tokenizer", ) - self.vte = VisualEmbedding(config.visual_vocab_size, - config.hidden_size) + self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) text_model_type = self.config.get_text_config().model_type self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.make_empty_intermediate_tensors = ( - self.get_language_model().make_empty_intermediate_tensors) + self.get_language_model().make_empty_intermediate_tensors + ) - def _parse_and_validate_visual_input( - self, is_video, - **kwargs: object) -> Optional[OvisImagePatchInputs]: - if is_video: - pixel_values = kwargs.pop("video_pixel_values", None) - indicator_tokens = kwargs.pop("video_indicator_tokens", None) - grids = kwargs.pop("video_grids", None) - else: - pixel_values = kwargs.pop("pixel_values", None) - indicator_tokens = kwargs.pop("indicator_tokens", None) - grids = kwargs.pop("grids", None) + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Ovis2_5ImagePatchInputs]: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) if pixel_values is None and indicator_tokens is None: return None if pixel_values is not None and indicator_tokens is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) if not isinstance(indicator_tokens, (torch.Tensor, list)): - raise ValueError("Incorrect type of indicator_tokens. " - f"Got type: {type(indicator_tokens)}") + raise ValueError( + "Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}" + ) - return OvisImagePatchInputs( + return Ovis2_5ImagePatchInputs( type="image_patches", - flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), - patches_per_image=[ + flat_data=flatten_bn(pixel_values, concat=True), + patches_per_item=[ x.shape[0] // (self.config.vit_config.hidden_stride**2) - for x in flatten_bn(pixel_values) + for x in pixel_values ], - indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), - concat=True), - grids=flatten_bn(flatten_bn(grids), concat=True), + indicator_tokens=flatten_bn(indicator_tokens, concat=True), + grids=flatten_bn(grids, concat=True), ) raise AssertionError("This line should be unreachable.") - def _process_image_input( - self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: - image_patches_flat = image_input["flat_data"] - patches_per_image = image_input["patches_per_image"] - indicator_tokens = image_input["indicator_tokens"] - grid_thws = image_input["grids"] + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[Ovis2_5VideoPatchInputs]: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of indicator_tokens. 
" + f"Got type: {type(indicator_tokens)}" + ) + + return Ovis2_5VideoPatchInputs( + type="video_patches", + flat_data=flatten_bn(pixel_values, concat=True), + patches_per_item=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in pixel_values + ], + indicator_tokens=flatten_bn(indicator_tokens, concat=True), + grids=flatten_bn(grids, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_visual_input( + self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs] + ) -> MultiModalEmbeddings: + image_patches_flat = visual_input["flat_data"] + patches_per_image = visual_input["patches_per_item"] + indicator_tokens = visual_input["indicator_tokens"] + grid_thws = visual_input["grids"] indicator_per_image = list( - map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + map(lambda x: 2 if x > 1 else x + 2, patches_per_image) + ) target_dtype = self.visual_tokenizer.dtype visual_tokens = self.visual_tokenizer( - image_patches_flat.to(target_dtype), grid_thws) + image_patches_flat.to(target_dtype), grid_thws + ) visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. indicator_embeds = self.vte(indicator_tokens) visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) - indicator_embeds_per_image = indicator_embeds.split( - indicator_per_image) + indicator_embeds_per_image = indicator_embeds.split(indicator_per_image) vision_embeddings = [] - for indicator, visual in zip(indicator_embeds_per_image, - visual_embeds_per_image): + for indicator, visual in zip( + indicator_embeds_per_image, visual_embeds_per_image + ): vision_embeddings_per_image = [] visual = visual.unsqueeze(0) for i in range(visual.shape[0]): vision_embeddings_per_image.append( - torch.cat([indicator[i:i + 1], visual[i]], dim=0)) - vision_embeddings_per_image.append(indicator[i + 1:]) - vision_embeddings.append( - torch.cat(vision_embeddings_per_image, dim=0)) + torch.cat([indicator[i : i + 1], visual[i]], dim=0) + ) + vision_embeddings_per_image.append(indicator[i + 1 :]) + vision_embeddings.append(torch.cat(vision_embeddings_per_image, dim=0)) return tuple(vision_embeddings) - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - embeddings = [] + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} - # NOTE: _parse_and_validate_visual_input has side-effects and pops - # keys from kwargs. We process images first, then videos. - image_input = self._parse_and_validate_visual_input(False, **kwargs) - if image_input: - embeddings.extend(self._process_image_input(image_input)) + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if ( + input_key in ("pixel_values", "indicator_tokens", "grids") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key + in ("video_pixel_values", "video_indicator_tokens", "video_grids") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) - video_input = self._parse_and_validate_visual_input(True, **kwargs) - if video_input: - embeddings.extend(self._process_image_input(video_input)) + return modalities - return tuple(embeddings) if embeddings else None + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.llm.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - tmp = torch.concat(multimodal_embeddings, dim=0) - inputs_embeds[input_ids == self.image_pad_token_id] = tmp - return inputs_embeds + multimodal_embeddings: tuple[torch.Tensor, ...] = () + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_visual_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_visual_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings def forward( self, @@ -538,15 +636,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - # up until here we have a inputs_embeds 100% numerical identity # between the OG HF Transformers implementation and ours hidden_states = self.llm( @@ -560,13 +649,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.llm.compute_logits(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 7d6a6207c7c89..7bddfc5ee855b 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,58 +1,83 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn from transformers import BatchFeature, PaliGemmaConfig from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - - -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, 
hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class PaliGemmaImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, - PaliGemmaImageEmbeddingInputs] + +class PaliGemmaImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] + + +PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs] class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, vision_hidden_size: int, projection_dim: int): super().__init__() @@ -64,7 +89,6 @@ class PaliGemmaMultiModalProjector(nn.Module): class PaliGemmaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(PaliGemmaConfig) @@ -88,9 +112,7 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo): ) -class PaliGemmaDummyInputsBuilder( - BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - +class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -98,6 +120,7 @@ class PaliGemmaDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ -105,17 +128,19 @@ class PaliGemmaDummyInputsBuilder( num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } -class PaliGemmaMultiModalProcessor( - BaseMultiModalProcessor[PaliGemmaProcessingInfo]): - +class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -158,7 +183,8 @@ class PaliGemmaMultiModalProcessor( def get_insertion(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -183,7 +209,8 @@ class PaliGemmaMultiModalProcessor( PromptInsertion( modality="image", target=PromptIndexTargets.prefix( - [bos_token_id] if tokenizer.add_bos_token else []), + [bos_token_id] if tokenizer.add_bos_token else [] + ), insertion=get_insertion, ) ] @@ -194,9 +221,15 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: - mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs) + mm_inputs = super().apply( + prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) prompt_token_ids = mm_inputs["prompt_token_ids"] 
tokenizer = self.info.get_tokenizer() @@ -207,7 +240,6 @@ class PaliGemmaMultiModalProcessor( if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: prompt_token_ids.append(newline_token_id) mm_inputs["prompt_token_ids"] = prompt_token_ids - mm_inputs["prompt"] += newline_prompt return mm_inputs @@ -215,9 +247,9 @@ class PaliGemmaMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( PaliGemmaMultiModalProcessor, info=PaliGemmaProcessingInfo, - dummy_inputs=PaliGemmaDummyInputsBuilder) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=PaliGemmaDummyInputsBuilder, +) +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -237,7 +269,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -254,13 +287,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, - projection_dim=config.vision_config.projection_dim) + projection_dim=config.vision_config.projection_dim, + ) self.quant_config = quant_config @@ -277,23 +312,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: + self, **kwargs: object + ) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -301,22 +325,16 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) + h = w = self.config.vision_config.image_size return PaliGemmaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=pixel_values, + resolve_bindings={"h": h, "w": w}, ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( @@ -331,7 +349,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype image_features = vision_tower(pixel_values.to(dtype=target_dtype)) @@ -341,7 +358,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, self, image_input: PaliGemmaImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": return image_input["data"] @@ -357,8 +373,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -367,52 +382,29 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) return vision_embeddings - def get_input_embeddings( + def forward( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index f8db99eb92ba8..d3df5f9a59b58 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -22,7 +22,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -34,36 +36,42 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PersimmonMLP(nn.Module): - - def __init__(self, - config: PersimmonConfig, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, config: PersimmonConfig, quant_config: Optional[QuantizationConfig] = None + ): super().__init__() - self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, - config.hidden_size, - quant_config=quant_config) + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, config.intermediate_size, quant_config=quant_config + ) + self.dense_4h_to_h = 
RowParallelLinear( + config.intermediate_size, config.hidden_size, quant_config=quant_config + ) self.act = get_act_fn(config.hidden_act) def forward(self, hidden_states) -> torch.Tensor: @@ -74,12 +82,13 @@ class PersimmonMLP(nn.Module): class PersimmonAttention(nn.Module): - - def __init__(self, - config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config tensor_parallel_world_size = get_tensor_model_parallel_world_size() @@ -123,12 +132,14 @@ class PersimmonAttention(nn.Module): partial_rotary_factor=self.partial_rotary_factor, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def _split_heads(self, x: torch.Tensor) -> torch.Tensor: # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] @@ -167,23 +178,28 @@ class PersimmonAttention(nn.Module): class PersimmonDecoderLayer(nn.Module): - - def __init__(self, - config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = PersimmonAttention(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = PersimmonAttention( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.mlp = PersimmonMLP(config, quant_config=quant_config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward( self, @@ -214,7 +230,6 @@ class PersimmonDecoderLayer(nn.Module): @support_torch_compile class PersimmonModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -224,18 +239,22 @@ class PersimmonModel(nn.Module): self.vocab_size = config.vocab_size self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: PersimmonDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + 
prefix=f"{prefix}.layers", + ) + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -255,15 +274,14 @@ class PersimmonModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -282,34 +300,38 @@ class PersimmonModel(nn.Module): if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class PersimmonForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config self.vocab_size = config.vocab_size - self.model = PersimmonModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=False) + self.model = PersimmonModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -332,13 +354,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git 
a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 21d517b3a490f..779b391008bb5 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -37,7 +37,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -49,41 +51,47 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PhiAttention(nn.Module): - - def __init__(self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size # pylint: disable=C0103 self.qkv_proj = QKVParallelLinear( @@ -100,28 +108,31 @@ class PhiAttention(nn.Module): ) scaling = self.head_size**-0.5 - rotary_dim = int(config.partial_rotary_factor * - (config.hidden_size // config.num_attention_heads)) + rotary_dim = int( + config.partial_rotary_factor + * (config.hidden_size // config.num_attention_heads) + ) assert rotary_dim % 2 == 0 # pylint: disable=C0301 # Refer to: # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 rope_theta = getattr(config, "rope_theta", 10000.0) - max_position_embeddings = getattr(config, "max_position_embeddings", - 2048) + max_position_embeddings = 
getattr(config, "max_position_embeddings", 2048) self.rotary_emb = get_rope( self.head_size, rotary_dim=rotary_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -137,10 +148,9 @@ class PhiAttention(nn.Module): class PhiMLP(nn.Module): - - def __init__(self, - config: PhiConfig, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None + ): super().__init__() n_inner = getattr(config, "n_inner", None) @@ -166,19 +176,20 @@ class PhiMLP(nn.Module): class PhiLayer(nn.Module): - - def __init__(self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.self_attn = PhiAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) self.mlp = PhiMLP(config, quant_config) def forward( @@ -199,7 +210,6 @@ class PhiLayer(nn.Module): @support_torch_compile class PhiModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -209,18 +219,20 @@ class PhiModel(nn.Module): self.config = config self.quant_config = quant_config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + lambda prefix: PhiLayer(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -240,7 +252,7 @@ class PhiModel(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: @@ -250,13 +262,12 @@ class PhiModel(nn.Module): return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def 
load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v") + ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -265,7 +276,7 @@ class PhiModel(nn.Module): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -287,8 +298,7 @@ class PhiModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -315,16 +325,21 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config - self.model = PhiModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = PhiModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=True, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -336,21 +351,19 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) + logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index f4e870c530309..56c8755123d3d 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -8,7 +8,6 @@ from vllm.model_executor.models.llama import LlamaForCausalLM class Phi3ForCausalLM(LlamaForCausalLM): - packed_modules_mapping = { "qkv_proj": [ "qkv_proj", diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 078251ee2bf4d..d972604db9cd2 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -21,66 +21,90 @@ from typing import Annotated, Any, Literal, 
Optional, Union import regex as re import torch import torch.nn as nn -from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, - ProcessorMixin) +from transformers import ( + BatchFeature, + CLIPVisionConfig, + PretrainedConfig, + ProcessorMixin, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + ResolvedPromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) logger = init_logger(__name__) # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 32044 -CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, - hidden_act="quick_gelu", - hidden_size=1024, - image_size=336, - intermediate_size=4096, - num_attention_heads=16, - num_channels=3, - num_hidden_layers=24, - patch_size=14, - projection_dim=768) +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, +) -def _init_img_processor(hf_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "") -> CLIPVisionModel: +def _init_img_processor( + hf_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", +) -> CLIPVisionModel: clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - layer_idx = hf_config.img_processor.get('layer_idx', -2) + layer_idx = hf_config.img_processor.get("layer_idx", -2) # Initialize the CLIP only up to the required feature layer if layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - layer_idx + 1 + num_hidden_layers = clip_config.num_hidden_layers + layer_idx + 1 else: num_hidden_layers = layer_idx + 1 @@ -107,10 +131,11 @@ class Phi3VImagePixelInputs(TensorSchema): type: Literal["pixel_values", "image_embeds"] = "pixel_values" # Supports either a stacked tensor or a list of (p, 3, h, w) tensors - data: Annotated[ + pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # 'p' may vary across items + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # 'p' may vary across items ] # Stacked tensor with height and width for each image @@ -125,6 +150,7 @@ class Phi3VImageEmbeddingInputs(TensorSchema): - f: Image feature size (e.g., number of tokens per image) - h: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[ Union[torch.Tensor, list[torch.Tensor]], @@ -136,15 +162,13 @@ Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs] class Phi3ImageEmbeddingBase(nn.Module): - def __init__(self) -> None: super().__init__() self.layer_idx: int self.type_feature: str self.img_processor: CLIPVisionModel - def get_img_features(self, - img_embeds: torch.FloatTensor) -> torch.FloatTensor: + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: TYPE_FEATURE = self.type_feature # NOTE: we skip the step to select the vision feature layer since @@ -165,52 +189,51 @@ class Phi3ImageEmbeddingBase(nn.Module): class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """Phi3 Image embedding with HD transform.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "") -> None: + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr( - config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size self.img_processor = _init_img_processor( - config, quant_config, prefix=f"{prefix}.img_processor") + config, quant_config, prefix=f"{prefix}.img_processor" + ) - image_dim_out = config.img_processor['image_dim_out'] - self.num_img_tokens = config.img_processor['num_img_tokens'] + image_dim_out = 
config.img_processor["image_dim_out"] + self.num_img_tokens = config.img_processor["num_img_tokens"] self.image_dim_out = image_dim_out # global_gn and sub_gn for hd transform, serves as line separator - self.use_hd_transform = config.embd_layer.get('use_hd_transform', - False) + self.use_hd_transform = config.embd_layer.get("use_hd_transform", False) self.with_learnable_separator = config.embd_layer.get( - 'with_learnable_separator', False) - self.hd_transform_order = config.embd_layer.get( - 'hd_transform_order', 'glb_sub') + "with_learnable_separator", False + ) + self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub") # with_hd_transform and with_learnable_separator should have same value assert self.use_hd_transform and self.with_learnable_separator # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) - self.sub_GN = nn.Parameter( - torch.empty([1, 1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4])) dim_projection = hidden_size depth = 2 layers = [nn.Linear(image_dim_out * 4, dim_projection)] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) - self.type_feature = config.img_processor.get('type_feature', 'patch') + self.type_feature = config.img_processor.get("type_feature", "patch") - def forward(self, pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor) -> torch.FloatTensor: + def forward( + self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor + ) -> torch.FloatTensor: """ process image and return vision embeddings. 
@@ -220,19 +243,19 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): num_images, num_crops, c, h, w = pixel_values.shape pixel_values = pixel_values.flatten(0, 1) img_features = self.get_img_features(pixel_values) - img_features = img_features.reshape(num_images, num_crops, -1, - self.image_dim_out) - image_features_proj = self.hd_feature_transform( - img_features, image_sizes) + img_features = img_features.reshape( + num_images, num_crops, -1, self.image_dim_out + ) + image_features_proj = self.hd_feature_transform(img_features, image_sizes) return image_features_proj def hd_feature_transform(self, image_features, image_sizes): """ image_features: (num_images, num_crops+1, 24*24, 1024) """ - assert ( - self.hd_transform_order == 'sub_glb' - ), f'hd_transform_order `{self.hd_transform_order}` not implemented' + assert self.hd_transform_order == "sub_glb", ( + f"hd_transform_order `{self.hd_transform_order}` not implemented" + ) if isinstance(self.img_projection, nn.Sequential): target_device = self.img_projection[0].bias.device target_dtype = self.img_projection[0].bias.dtype @@ -240,13 +263,14 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): target_device = self.img_projection.bias.device target_dtype = self.img_projection.bias.dtype - global_image_features = image_features[:, - 0] # (num_images, 24*24, 1024) + global_image_features = image_features[:, 0] # (num_images, 24*24, 1024) # global feature can be viewed as a special HD case with num_crops 1x1 global_image_features_hd = self.reshape_hd_patches_2x2merge( - global_image_features, 1, 1) + global_image_features, 1, 1 + ) global_image_features_hd_newline = self.add_image_newline( - global_image_features_hd) + global_image_features_hd + ) batch_image_features_proj = [] # need a for loop to process each image because of different image sizes @@ -259,21 +283,27 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): # NOTE: real num_crops is padded # (num_crops, 24*24, 1024) - sub_image_features = image_features[i, 1:1 + num_crops] + sub_image_features = image_features[i, 1 : 1 + num_crops] sub_image_features_hd = self.reshape_hd_patches_2x2merge( - sub_image_features, h_crop, w_crop) + sub_image_features, h_crop, w_crop + ) sub_image_features_hd_newline = self.add_image_newline( - sub_image_features_hd) + sub_image_features_hd + ) # [sub features, separator, global features] - image_embeddings = torch.cat([ - sub_image_features_hd_newline.squeeze( - 0), # (h_crop*12*(w_crop*12+1), 4096) - self.glb_GN.squeeze(0), - global_image_features_hd_newline[i], - ]) + image_embeddings = torch.cat( + [ + sub_image_features_hd_newline.squeeze( + 0 + ), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ] + ) img_proj = self.img_projection( - image_embeddings.to(target_device, target_dtype)) + image_embeddings.to(target_device, target_dtype) + ) batch_image_features_proj.append(img_proj) return batch_image_features_proj @@ -293,11 +323,13 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 .reshape(N, -1, 4 * C) # N, 144, 4096 - .reshape(num_images, h_crop, w_crop, H // 2, H // 2, - -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .reshape( + num_images, h_crop, w_crop, H // 2, H // 2, -1 + ) # n_img, h_crop, w_crop, 12, 12, 4096 .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 - .reshape(num_images, h_crop * H // 2, w_crop * H // 2, - 4 * C) # n_img, h_crop*12, 
w_crop*12, 4096 + .reshape( + num_images, h_crop * H // 2, w_crop * H // 2, 4 * C + ) # n_img, h_crop*12, w_crop*12, 4096 ) return image_features_hd @@ -308,16 +340,16 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """ num_images, h, w, hid_dim = image_features_hd.shape # add the newline token to the HD image feature patches - newline_embeddings = self.sub_GN.expand(num_images, h, -1, - -1) # (n_img, h, 1, hid_dim) + newline_embeddings = self.sub_GN.expand( + num_images, h, -1, -1 + ) # (n_img, h, 1, hid_dim) image_features_hd_newline = torch.cat( - [image_features_hd, newline_embeddings], - dim=2).reshape(num_images, -1, hid_dim) + [image_features_hd, newline_embeddings], dim=2 + ).reshape(num_images, -1, hid_dim) return image_features_hd_newline class Phi3VProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -342,7 +374,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo): class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -355,22 +386,25 @@ class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -417,7 +451,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): def get_replacement_phi3v(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -431,24 +466,38 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): return [_IMAGE_TOKEN_ID] * num_image_tokens - num_images = mm_items.get_count("image", strict=False) - return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:num_images] + ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + mm_prompt_updates: 
MultiModalPromptUpdates, + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images - if len(mm_item_counts): + if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() # to decode token_ids to the original text, we need to # 1. remove the first bos token @@ -467,8 +516,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407 pattern = r"<\|image_\d+\|>" prompt_chunks = [ - tokenizer(chunk).input_ids - for chunk in re.split(pattern, text) + tokenizer(chunk).input_ids for chunk in re.split(pattern, text) ] image_tags = [ tokenizer(chunk, add_special_tokens=False).input_ids @@ -477,19 +525,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): if len(prompt_chunks) > len(image_tags): image_tags.append([]) token_ids = [ - e for sublist in zip(prompt_chunks, image_tags) - for ele in sublist for e in ele + e + for sublist in zip(prompt_chunks, image_tags) + for ele in sublist + for e in ele ] - token_ids, text, placeholders = super()._apply_prompt_updates( + token_ids, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, - mm_item_counts=mm_item_counts, ) # Keep the behavior in line with HF processor - if text.startswith("<s> <|image|>"): - text = text.replace("<s> <|image|>", "<s><|image|>", 1) + if len(mm_prompt_updates) and ( + token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False) + ): token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ @@ -499,26 +549,29 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): start_idx=p.start_idx - 1, tokens=p.tokens, is_embed=p.is_embed, - ) for p in ps + ) + for p in ps ] for modality, ps in placeholders.items() } - return token_ids, text, placeholders + return token_ids, placeholders -@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, - info=Phi3VProcessingInfo, - dummy_inputs=Phi3VDummyInputsBuilder) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + Phi3VMultiModalProcessor, + info=Phi3VProcessingInfo, + dummy_inputs=Phi3VDummyInputsBuilder, +) +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.": "vision_embed_tokens.", "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -547,7 +600,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self.vision_embed_tokens = Phi3HDImageEmbedding( config, self.quant_config, - prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) + prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -561,10 +615,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Phi3VImageInputs]: + self, **kwargs: object + ) -> Optional[Phi3VImageInputs]: pixel_values = 
kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -575,12 +631,13 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, if pixel_values is not None: return Phi3VImagePixelInputs( type="pixel_values", - data=flatten_bn(pixel_values), + pixel_values=flatten_bn(pixel_values), image_sizes=flatten_bn(image_sizes, concat=True), resolve_bindings={ "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, - "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size - }) + "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, + }, + ) if image_embeds is not None: return Phi3VImageEmbeddingInputs( @@ -594,7 +651,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self, image_input: Phi3VImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": image_data = image_input["data"] if is_list_of(image_data, torch.Tensor): @@ -609,16 +665,16 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) assert self.vision_embed_tokens is not None - image_embeds = self.vision_embed_tokens(image_input["data"], - image_input["image_sizes"]) + image_embeds = self.vision_embed_tokens( + image_input["pixel_values"], image_input["image_sizes"] + ) return image_embeds def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -629,54 +685,59 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.image_token_id) - return inputs_embeds + inputs_embeds = self._get_text_embeddings( + input_ids, + self.embed_tokens, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object): + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." 
+ ) + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + return self.language_model.compute_logits(hidden_states) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) # The HF config doesn't specify whether these are tied, # so we detect it this way diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index ee8b71caf336b..002233d0677b0 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -2,61 +2,87 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from transformers import (BatchFeature, Phi4MultimodalAudioConfig, - Phi4MultimodalConfig, Phi4MultimodalFeatureExtractor, - Phi4MultimodalImageProcessorFast) +from transformers import ( + BatchFeature, + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalFeatureExtractor, + Phi4MultimodalImageProcessorFast, +) from transformers import Phi4MultimodalProcessor as Phi4MMProcessor from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( - Phi4MultimodalAudioConvModule, Phi4MultimodalAudioNemoConvSubsampling, - Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor) + Phi4MultimodalAudioConvModule, + Phi4MultimodalAudioNemoConvSubsampling, + Phi4MultimodalAudioRelativeAttentionBias, + adaptive_enc_mask, + unfold_tensor, +) from vllm.config import VllmConfig -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn -from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, - ImageProcessorItems, ImageSize, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) - -# <|endoftext10|> (see vocab.json in hf model) -_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 -# <|endoftext11|> -_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 -def _get_padding_size(orig_width: int, orig_height: int, target_height: int, - target_width: int): +def _get_padding_size( + orig_width: int, orig_height: int, target_height: int, target_width: int +): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -70,7 +96,6 @@ def _get_padding_size(orig_width: int, orig_height: int, target_height: int, class Phi4MMProjector(nn.Module): - def __init__(self, input_size: int, hidden_size: int): super().__init__() self.up = ColumnParallelLinear(input_size, hidden_size) @@ -94,41 +119,44 @@ class Phi4MMImageEmbedding(nn.Module): self.crop_size = config.vision_config.crop_size self.image_dim_out = config.vision_config.hidden_size - n_patches = (config.vision_config.image_size // - config.vision_config.patch_size) + n_patches = config.vision_config.image_size // config.vision_config.patch_size if n_patches % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) n_patches += 1 - self.num_img_tokens = (n_patches // 2)**2 + self.num_img_tokens = (n_patches // 2) ** 2 - num_hidden_layers = (config.vision_config.num_hidden_layers + - self.layer_idx + - 1 if self.layer_idx < 0 else self.layer_idx + 1) + num_hidden_layers = ( + config.vision_config.num_hidden_layers + self.layer_idx + 1 + if self.layer_idx < 0 + else self.layer_idx + 1 + ) 
self.img_processor = Idefics2VisionTransformer( config.vision_config, require_post_norm=False, - num_hidden_layers_override=num_hidden_layers) + num_hidden_layers_override=num_hidden_layers, + ) self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) - self.img_projection = Phi4MMProjector(self.image_dim_out, - config.hidden_size) + self.img_projection = Phi4MMProjector(self.image_dim_out, config.hidden_size) self.global_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, self.image_dim_out])) + torch.zeros([1, 1, self.image_dim_out]) + ) self.sub_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, 1, self.image_dim_out])) + torch.zeros([1, 1, 1, self.image_dim_out]) + ) def get_img_features( self, img_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: - img_feature = self.img_processor(img_embeds, - patch_attention_mask=attention_mask) + img_feature = self.img_processor( + img_embeds, patch_attention_mask=attention_mask + ) patch_feature = img_feature # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, - patch_feature.size(-1)) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: @@ -137,9 +165,8 @@ class Phi4MMImageEmbedding(nn.Module): # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( - -1, - patch_feature.size(1) * patch_feature.size(2), - patch_feature.size(-1)) + -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) + ) return patch_feature def forward( @@ -149,7 +176,8 @@ class Phi4MMImageEmbedding(nn.Module): image_attention_mask: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: image_pixel_values = image_pixel_values.to( - self.img_processor.embeddings.patch_embedding.weight.dtype) + self.img_processor.embeddings.patch_embedding.weight.dtype + ) target_device = self.img_projection.up.bias.device target_dtype = self.img_projection.up.bias.dtype @@ -159,11 +187,13 @@ class Phi4MMImageEmbedding(nn.Module): img_features = self.get_img_features( image_pixel_values.flatten(0, 1), attention_mask=image_attention_mask.flatten(0, 1).to( - dtype=bool, device=target_device), + dtype=bool, device=target_device + ), ) base_feat_size = int(np.sqrt(img_features.shape[1])) - img_features = img_features.view(batch_size, -1, base_feat_size**2, - self.image_dim_out) + img_features = img_features.view( + batch_size, -1, base_feat_size**2, self.image_dim_out + ) image_sizes = image_sizes.view(-1, 2) output_imgs = [] @@ -174,58 +204,70 @@ class Phi4MMImageEmbedding(nn.Module): area_ratio = height_ratio * width_ratio global_img = img_features[idx, :1] - global_img = global_img.reshape(1, base_feat_size, base_feat_size, - self.image_dim_out).contiguous() + global_img = global_img.reshape( + 1, base_feat_size, base_feat_size, self.image_dim_out + ).contiguous() temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, base_feat_size, 1, 1) - global_img = torch.cat([global_img, temporary_extensor], - dim=2).reshape(1, -1, self.image_dim_out) + 1, base_feat_size, 1, 1 + ) + global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape( + 1, -1, self.image_dim_out + ) sub_img = img_features[idx, 1:] sub_img = sub_img[:area_ratio] - sub_img = (sub_img.reshape( - height_ratio, width_ratio, base_feat_size, base_feat_size, - 
self.image_dim_out).transpose(1, 2).reshape( - 1, height_ratio * base_feat_size, + sub_img = ( + sub_img.reshape( + height_ratio, + width_ratio, + base_feat_size, + base_feat_size, + self.image_dim_out, + ) + .transpose(1, 2) + .reshape( + 1, + height_ratio * base_feat_size, width_ratio * base_feat_size, - self.image_dim_out).contiguous()) + self.image_dim_out, + ) + .contiguous() + ) if image_attention_mask is not None: reshaped_image_attention_mask = ( - image_attention_mask[idx, 1:area_ratio + 1, - 0::2, 0::2].reshape( - height_ratio, width_ratio, - base_feat_size, - base_feat_size).transpose( - 1, 2).reshape( - 1, height_ratio * - base_feat_size, - width_ratio * - base_feat_size)) - useful_height = int( - reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int( - reshaped_image_attention_mask[0, 0, :].sum().item()) + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape( + 1, height_ratio * base_feat_size, width_ratio * base_feat_size + ) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, useful_height, 1, 1) + 1, useful_height, 1, 1 + ) else: temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, height_ratio * base_feat_size, 1, 1) + 1, height_ratio * base_feat_size, 1, 1 + ) - sub_img = torch.cat([sub_img, temporary_extensor], - dim=2).reshape(1, -1, self.image_dim_out) + sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape( + 1, -1, self.image_dim_out + ) # Merge global and sub output_imgs.append( torch.cat( - [sub_img, self.global_img_feature_extensor, global_img], - dim=1)) + [sub_img, self.global_img_feature_extensor, global_img], dim=1 + ) + ) img_set_tensor = [] for output_img in output_imgs: - output_img = output_img.to(device=target_device, - dtype=target_dtype) + output_img = output_img.to(device=target_device, dtype=target_dtype) img_feature_proj = self.img_projection(output_img) img_set_tensor.append(img_feature_proj.flatten(0, 1)) @@ -233,7 +275,6 @@ class Phi4MMImageEmbedding(nn.Module): class Phi4MultimodalAudioMLP(nn.Module): - def __init__( self, config: Phi4MultimodalAudioConfig, @@ -244,15 +285,19 @@ class Phi4MultimodalAudioMLP(nn.Module): self.layer_norm = nn.LayerNorm(config.hidden_size) self.act_fn = MulAndSilu() self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [config.intermediate_size] * 2, + config.hidden_size, + [config.intermediate_size] * 2, bias=True, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.layer_norm(hidden_states) @@ -263,7 +308,6 @@ class Phi4MultimodalAudioMLP(nn.Module): class Phi4MultimodalAudioAttention(nn.Module): - def __init__( self, config: Phi4MultimodalAudioConfig, @@ -279,7 +323,8 @@ class Phi4MultimodalAudioAttention(nn.Module): raise ValueError( "embed_dim must be divisible by num_heads " f"(got 
`embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( @@ -336,7 +381,6 @@ class Phi4MultimodalAudioAttention(nn.Module): class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() @@ -355,11 +399,9 @@ class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) hidden_states = self.layer_norm_att(residual) - hidden_states = residual + self.self_attn(hidden_states, - attention_mask) + hidden_states = residual + self.self_attn(hidden_states, attention_mask) hidden_states = hidden_states + self.conv(hidden_states) - hidden_states = hidden_states + 0.5 * self.feed_forward_out( - hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) out = self.layer_norm(hidden_states) @@ -373,8 +415,8 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module): Typically used as a very first layer in a model. Args: - input_size: int - layer input size. + config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) + object containing model parameters. """ def __init__(self, config: Phi4MultimodalAudioConfig): @@ -393,19 +435,21 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module): class Phi4MultimodalAudioModel(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config) self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) - self.relative_attention_bias_layer = ( - Phi4MultimodalAudioRelativeAttentionBias(config)) - self.encoders = nn.ModuleList([ - Phi4MultimodalAudioConformerEncoderLayer(config) - for _ in range(config.num_blocks) - ]) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias( + config + ) + self.encoders = nn.ModuleList( + [ + Phi4MultimodalAudioConformerEncoderLayer(config) + for _ in range(config.num_blocks) + ] + ) def _streaming_mask( self, @@ -418,9 +462,11 @@ class Phi4MultimodalAudioModel(nn.Module): # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size) - enc_streaming_mask = (adaptive_enc_mask( - seq_len, chunk_start_idx, - left_window=left_chunk).unsqueeze(0).expand([batch_size, -1, -1])) + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) return enc_streaming_mask def forward_embeddings( @@ -429,18 +475,18 @@ class Phi4MultimodalAudioModel(nn.Module): masks: torch.Tensor, ): """Forwarding the inputs through the top embedding layers""" - seq_len = math.ceil(hidden_states.shape[1] / - self.config.time_reduction) + seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) if seq_len <= 0: raise ValueError( f"Sequence length after time reduction is invalid: {seq_len}." - "Your input feature is too short.") + "Your input feature is too short." 
+ ) batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask(seq_len, batch_size, - self.config.chunk_size, - self.config.left_chunk) + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.config.chunk_size, self.config.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) hidden_states, masks = self.embed(hidden_states, masks) @@ -455,13 +501,14 @@ class Phi4MultimodalAudioModel(nn.Module): return hidden_states, hs_mask, masks - def calculate_hs_mask(self, hidden_states: torch.Tensor, - device: torch.device, mask: torch.Tensor): + def calculate_hs_mask( + self, hidden_states: torch.Tensor, device: torch.device, mask: torch.Tensor + ): max_audio_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, - self.config.chunk_size, - self.config.left_chunk) + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask @@ -469,17 +516,15 @@ class Phi4MultimodalAudioModel(nn.Module): feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = torch.arange(0, max_audio_length, device=device).expand( - padding_length.size(0), -1) < padding_length.unsqueeze(1) + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask - def forward(self, - hidden_states: torch.Tensor, - mask: Optional[torch.Tensor] = None): + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None): hidden_states = self.encoder_embedding(hidden_states) - hidden_states, hs_mask, mask = self.forward_embeddings( - hidden_states, mask) + hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) unfolded = False bs, seq_len, _ = hidden_states.shape @@ -495,9 +540,9 @@ class Phi4MultimodalAudioModel(nn.Module): else: chunk_pad_size = 0 if chunk_pad_size > 0: - hidden_states_pad = F.pad(hidden_states, - (0, 0, 0, chunk_pad_size), - "constant", 0) + hidden_states_pad = F.pad( + hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0 + ) hidden_states = hidden_states_pad.to(hidden_states.device) hidden_states = unfold_tensor(hidden_states, max_seq_len) @@ -505,24 +550,24 @@ class Phi4MultimodalAudioModel(nn.Module): if mask is not None: # revise hs_mask here because the previous calculated hs_mask # did not consider extra pad - subsampled_pad_mask = mask.squeeze( - 1) # [bz, subsampled_unmask_seq_len] + subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", - False) # extra padding to the pad mask + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = ( - extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()) + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze( - -1).bool() # unfold op does not support bool tensor + -1 + ).bool() # unfold op does not support bool tensor hs_mask = self.calculate_hs_mask( hidden_states, hidden_states.device, masks_unfold ) # calculate hs_mask based on the unfolded pad mask - 
relative_attention_bias = self.relative_attention_bias_layer( - hidden_states) + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: @@ -539,7 +584,6 @@ class Phi4MultimodalAudioModel(nn.Module): class Phi4MMAudioEmbedding(nn.Module): - def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config @@ -548,12 +592,11 @@ class Phi4MMAudioEmbedding(nn.Module): self.encoder = Phi4MultimodalAudioModel(config.audio_config) audio_config = config.audio_config - proj_input_size = (audio_config.hidden_size * - audio_config.downsample_rate) + proj_input_size = audio_config.hidden_size * audio_config.downsample_rate self.vision_speech_projection = Phi4MMProjector( - proj_input_size, config.hidden_size) - self.speech_projection = Phi4MMProjector(proj_input_size, - config.hidden_size) + proj_input_size, config.hidden_size + ) + self.speech_projection = Phi4MMProjector(proj_input_size, config.hidden_size) def get_projection( self, @@ -571,23 +614,23 @@ class Phi4MMAudioEmbedding(nn.Module): audio_attention_mask=None, audio_projection_mode="speech", ) -> torch.FloatTensor: - audio_projection = self.get_projection(audio_projection_mode) target_device = audio_projection.up.bias.device target_dtype = audio_projection.up.bias.dtype - audio_input_features = audio_input_features.to(device=target_device, - dtype=target_dtype) + audio_input_features = audio_input_features.to( + device=target_device, dtype=target_dtype + ) - audio_encoder_hidden_states = self.encoder(audio_input_features, - audio_attention_mask) + audio_encoder_hidden_states = self.encoder( + audio_input_features, audio_attention_mask + ) audio_embeds = audio_projection(audio_encoder_hidden_states) return audio_embeds.flatten(0, 1) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -608,57 +651,97 @@ class Phi4MMAudioEmbedding(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. 
+ data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match language model backbone) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. - """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", "h"), + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - f: Number of Mel filterbank bins (80) + - t: Time frames (M) """ - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ + type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] @@ -670,9 +753,9 @@ def cat_with_pad(tensors, dim, padding_value=0): cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + assert all(t.dim() == ndim for t in tensors[1:]), ( + "All tensors must have the same number of dimensions" + ) out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) @@ -692,15 +775,13 @@ def cat_with_pad(tensors, dim, padding_value=0): class Phi4MMProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Phi4MultimodalConfig: return self.ctx.get_hf_config(Phi4MultimodalConfig) def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor: return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs) - def get_feature_extractor( - self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: return self.get_hf_processor(**kwargs).audio_processor def get_image_processor( @@ -734,9 +815,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio - 
target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) + target_ratios = set( + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num + ) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target @@ -746,6 +830,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): target_ratios, orig_width, orig_height, + image_size, ) # calculate the target width and height @@ -768,49 +853,56 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ): """ compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing + the image encoder architecture and exclude output features containing only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be + for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by vit_patch_size") - assert (vit_image_size // vit_patch_size % - token_compression_factor == 0), ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor") + "vit_image_size must be divisible by vit_patch_size" + ) + assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor" + ) target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio(orig_width, - orig_height, - vit_image_size, - dynamic_hd_size, - min_num=1)) + self._find_target_aspect_ratio( + orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 + ) + ) assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + ) assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + ) + assert ( + target_height % vit_image_size == 0 and target_width % vit_image_size == 0 + ) padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width) - assert padding_width == 0 or padding_height == 0, \ + orig_width, orig_height, target_height, target_width + ) + assert padding_width == 0 or padding_height == 0, ( "padding_width or padding_height must be 0" + ) target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) + padding_width / vit_patch_size + ) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) + padding_height / vit_patch_size + ) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch @@ -827,15 +919,17 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): num_hd_patch_tokens = feat_width * 
feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // - token_compression_factor)**2 + num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor + num_global_image_newline_tokens = vit_feature_size // token_compression_factor - return (num_global_image_tokens + num_sep_tokens + - num_hd_patch_tokens + num_hd_newline_tokens + - num_global_image_newline_tokens) + return ( + num_global_image_tokens + + num_sep_tokens + + num_hd_patch_tokens + + num_hd_newline_tokens + + num_global_image_newline_tokens + ) def get_num_image_tokens( self, @@ -930,7 +1024,6 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -945,28 +1038,34 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "audio": - self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios, + overrides=audio_overrides, + ), } return mm_data class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -985,29 +1084,29 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): audio_data = mm_data.pop("audios", []) if audio_data: - mm_data['audio'] = audio_data + mm_data["audio"] = audio_data - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) if "image_pixel_values" in processed_outputs: num_img_tokens = [ - self.info.get_num_image_tokens(image_width=img_size[0], - image_height=img_size[1]) + self.info.get_num_image_tokens( + image_width=img_size[0], image_height=img_size[1] + ) for img_size in processed_outputs["image_sizes"] ] processed_outputs["num_img_tokens"] = num_img_tokens if audio_data: - audio_features = processed_outputs['audio_input_features'] + audio_features = processed_outputs["audio_input_features"] sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) - for audio in audio_data + self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data ] - 
processed_outputs['audio_input_features'] = [ - audio_features[idx, :size] - for idx, size in enumerate(feature_sizes) + processed_outputs["audio_input_features"] = [ + audio_features[idx, :size] for idx, size in enumerate(feature_sizes) ] return processed_outputs @@ -1032,16 +1131,16 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.vocab[tokenizer.image_token] - audio_token_id = tokenizer.vocab[tokenizer.audio_token] + image_token_id: int = tokenizer.vocab[tokenizer.image_token] + audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - audio_processor = self.info.get_feature_extractor( - **hf_processor_mm_kwargs) + audio_processor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) def get_image_replacement_phi4mm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -1053,22 +1152,18 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [image_token_id] * num_image_tokens - - return image_tokens + return [image_token_id] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) # TODO(Isotr0py): support embedding inputs audio_len = audios.get_audio_length(item_idx) audio_frames = self.info.get_audio_num_frames( - audio_len, audio_processor.sampling_rate) - audio_embed_size = self.info._compute_audio_embed_size( - audio_frames) + audio_len, audio_processor.sampling_rate + ) + audio_embed_size = self.info._compute_audio_embed_size(audio_frames) - audio_tokens = [audio_token_id] * audio_embed_size - - return audio_tokens + return [audio_token_id] * audio_embed_size return [ PromptReplacement( @@ -1093,6 +1188,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1152,12 +1248,14 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMAudioInputs]: """ - Parse and validate the audio input to the model. This handles both + Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. @@ -1174,25 +1272,18 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. 
" - f"Got type: {type(audio_features)}") - - return Phi4MMAudioFeatureInputs(type="audio_features", - data=flatten_bn(audio_features)) + return Phi4MMAudioFeatureInputs( + type="audio_features", data=flatten_bn(audio_features) + ) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, audio_input: Phi4MMAudioInputs, - audio_projection_mode: str) -> NestedTensors: + def _process_audio_input( + self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str + ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. The audio input is @@ -1216,12 +1307,14 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): self.audio_embed( features.unsqueeze(0).to(dtype), audio_projection_mode=audio_projection_mode, - ) for features in audio_features + ) + for features in audio_features ] return audio_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMImagePixelInputs]: image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") if image_pixel_values is None: return None @@ -1229,12 +1322,16 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") - assert image_sizes is not None and image_attention_mask is not None\ - and num_img_tokens is not None, "Missing image inputs" + assert ( + image_sizes is not None + and image_attention_mask is not None + and num_img_tokens is not None + ), "Missing image inputs" if is_list_of(image_pixel_values, torch.Tensor): - assert all(p.dim() == 5 - for p in image_pixel_values), "Incorrect image inputs" + assert all(p.dim() == 5 for p in image_pixel_values), ( + "Incorrect image inputs" + ) # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. @@ -1263,17 +1360,16 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ - n for num_tensor in num_img_tokens - for n in num_tensor.tolist() + n for num_tensor in num_img_tokens for n in num_tensor.tolist() ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", @@ -1289,44 +1385,46 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("image_pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("audio_input_features", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("image_pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("audio_input_features", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities def _process_image_input( - self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: + self, image_input: Phi4MMImagePixelInputs + ) -> list[torch.Tensor]: if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input['data'].to(dtype) - image_sizes = image_input['image_sizes'] - image_attention_mask = image_input['image_attention_mask'] - image_embeds = self.image_embed(pixel_values, image_sizes, - image_attention_mask) + pixel_values = image_input["data"].to(dtype) + image_sizes = image_input["image_sizes"] + image_attention_mask = image_input["image_attention_mask"] + image_embeds = self.image_embed( + pixel_values, image_sizes, image_attention_mask + ) return image_embeds - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: - return None + return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- audio_projection_mode = 'speech' + audio_projection_mode = "speech" for modality in modalities: # make sure process images first if modality == "images": @@ -1337,52 +1435,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) + audio_input, audio_projection_mode=audio_projection_mode + ) multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Phi4MMImagePixelInputs] = None, - audio_input: Optional[Phi4MMAudioFeatureInputs] = None, - ) -> torch.Tensor: - audio_projection_mode = 'speech' - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID, - ) - audio_projection_mode = 'vision' - - if audio_input is not None: - audio_embeds = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - audio_embeds, - placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1394,22 +1452,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - audio_input = self._parse_and_validate_audio_input(**kwargs) - - if image_input is None and audio_input is None: - inputs_embeds = None - else: - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - audio_input=audio_input) - input_ids = None - hidden_states = self.language_model( input_ids, positions, @@ -1422,13 +1464,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1439,8 +1478,9 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return MultiModelKeys.from_string_field( language_model="language_model.", connector=[ - "img_projection", "vision_speech_projection", - "speech_projection" + "img_projection", + "vision_speech_projection", + "speech_projection", ], tower_model=["image_embed", "audio_embed"], ) diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py deleted file mode 100644 index fcdfcb7bc1603..0000000000000 --- a/vllm/model_executor/models/phi4flash.py +++ /dev/null @@ -1,737 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable -from typing import Optional, Union - -import torch -import torch.nn as nn -from transformers.activations import ACT2FN - -import vllm.envs as envs -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.selector import _Backend -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsV0Only) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .utils import make_layers, maybe_prefix - -logger = init_logger(__name__) - - -class SwiGLUActivation(nn.Module): - - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return x1 * nn.functional.silu(x2) - - -class SambaYMLP(nn.Module): - """Gated Linear Unit. - - Reference: - Language Modeling with Gated Convolutional Networks. 
- https://arxiv.org/pdf/1612.08083v3.pdf. - - """ - - def __init__(self, config): - super().__init__() - - self.config = config - self.fc1 = nn.Linear(config.hidden_size, - 2 * config.intermediate_size, - bias=False) - self.fc2 = nn.Linear(config.intermediate_size, - config.hidden_size, - bias=False) - - self.activation_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - y = self.fc1(hidden_states) - gate, y = y.chunk(2, dim=-1) - y = y * self.activation_fn(gate) - return self.fc2(y) - - -def get_virtual_engine(): - forward_context: ForwardContext = get_forward_context() - return forward_context.virtual_engine - - -class SambaYAttention(nn.Module): - - def __init__(self, - config, - layer_idx: Optional[int] = None, - yoco_cross: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): - super().__init__() - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing " - "a `layer_idx` is not recommended and will lead to errors " - "during the forward call if caching is used. Please make " - "sure to provide a `layer_idx` when creating this class.") - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.yoco_cross = yoco_cross - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError("hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads}).") - - op_size = self.num_heads * self.head_dim + 2 * ( - self.num_key_value_heads * self.head_dim) - self.out_proj = nn.Linear(self.num_heads * self.head_dim, - self.hidden_size, - bias=True) - if yoco_cross: - self.Wqkv = nn.Linear(self.hidden_size, - self.num_heads * self.head_dim, - bias=True) - else: - self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) - - # disable sliding window for the second half of the model - is_sliding = config.layer_types[layer_idx] == "sliding_attention" - sliding_window = config.sliding_window if is_sliding else None - - assert self.num_heads % 2 == 0, 'num_heads should be even' - assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' - - self.lambda_init = self.lambda_init_fn(layer_idx) - self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, - std=0.1)) - self.subln = nn.RMSNorm(2 * self.head_dim, - eps=1e-5, - elementwise_affine=True) - - params = { - 'differential_flash_attention_config': { - 'lambda_init': self.lambda_init, - 'lambda_q1': self.lambda_q1, - 'lambda_k1': self.lambda_k1, - 'lambda_q2': self.lambda_q2, - 'lambda_k2': self.lambda_k2, - "subln": self.subln, - } - } - - if yoco_cross: - kv_shared_layer_index = config.num_hidden_layers // 2 + 1 - kv_sharing_target_layer_name = \ - f"model.layers.{kv_shared_layer_index}.self_attn.attn" - else: - kv_sharing_target_layer_name = None - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.head_dim**-0.5, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - per_layer_sliding_window=sliding_window, - 
prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER, - kv_sharing_target_layer_name=kv_sharing_target_layer_name, - **params) - assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ - "DIFFERENTIAL_FLASH_ATTN required" - - def lambda_init_fn(self, depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) - - def forward( - self, - hidden_states: torch.Tensor, - ): - - if not self.yoco_cross: # need to generate kv-cache - qkv = self.Wqkv(hidden_states) - q, k, v = qkv.split([ - self.hidden_size, self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim - ], - dim=-1) - attn_output = self.attn(q, k, v) - else: # reuse the kv cache, full attention - q = self.Wqkv(hidden_states) - attn_output = self.attn(q, None, None) - attn_output = attn_output.view(-1, self.num_heads * self.head_dim) - return self.out_proj(attn_output) - - -class Phi4Mamba(nn.Module): - - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", # difference - dt_scale=1.0, # difference - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - yoco_cross=False, - yoco_kv=False, - ): - factory_kwargs = {"params_dtype": dtype} # difference - super().__init__() - self.yoco_cross = yoco_cross - self.yoco_kv = yoco_kv - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / - 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.swiGluActivation = SwiGLUActivation() - if self.yoco_cross: - self.in_proj = MergedColumnParallelLinear(self.d_model, - [self.d_inner], - bias=bias, - **factory_kwargs) - self.out_proj = RowParallelLinear(self.d_inner, - self.d_model, - bias=bias, - **factory_kwargs) - return - self.conv1d = ColumnParallelLinear( - input_size=d_conv, - output_size=self.d_inner, - bias=conv_bias, - params_dtype=dtype, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear( - self.d_model, - [self.d_inner] * 2, - bias=bias, - params_dtype=dtype, - ) - - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.d_inner, - self.dt_rank + self.d_state * 2, - bias=False, - params_dtype=dtype, - ) - - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. 
- self.dt_proj = ColumnParallelLinear( - self.dt_rank, - self.d_inner, - bias=True, - skip_bias_add=True, - params_dtype=dtype, - ) - - # # D "skip" parameter - # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 - self.A = nn.Parameter( - torch.empty( - self.d_inner, - self.d_state, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) - - self.out_proj = RowParallelLinear( - self.d_inner, - self.d_model, - bias=bias, - input_is_parallel=True, - params_dtype=dtype, - ) - self.activation = "silu" - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - yoco_key_values=None) -> torch.Tensor: - - if self.yoco_cross: - out = self.in_proj(hidden_states)[0] - out = self.swiGluActivation(yoco_key_values, out) - out = self.out_proj(out) - return out[0], yoco_key_values - - # 1. Gated MLP's linear projection - # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - projected_states = self.in_proj( - hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.dt_rank, self.d_state, self.d_state], - dim=-1, - ) - - # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
- - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - # z, - None if self.yoco_kv else gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - # z - # gate.transpose(0, 1), - None if self.yoco_kv else gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - if self.yoco_kv: - # gate = gate.transpose(-1,-2).contiguous() - yoco_key_values = scan_outputs.transpose(-2, -1) - scan_outputs = self.swiGluActivation(scan_outputs, gate) - - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - - return contextualized_states, yoco_key_values - - -class SambaYDecoderLayer(nn.Module): - - def __init__( - self, - config, - layer_idx, - cache_config, - prefix: str = "", - ) -> None: - super().__init__() - - self.config = config - self.layer_idx = layer_idx - - self.mlp = SambaYMLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - self.yoco_mb = False - self.yoco_cross = False - if layer_idx >= config.num_hidden_layers // 2: - self.yoco_mb = True - self.yoco_cross = (layer_idx - >= (config.num_hidden_layers // 2 + 2)) - self.use_mamba = config.mb_per_layer > 0 and \ - layer_idx % config.mb_per_layer == 0 - if self.use_mamba: - factory_kwargs = {"dtype": None} - self.attn = Phi4Mamba(config.hidden_size, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - yoco_kv=self.yoco_mb, - **factory_kwargs) - else: - self.attn = SambaYAttention(config, - layer_idx=layer_idx, - yoco_cross=self.yoco_cross, - cache_config=cache_config, - prefix=f"{prefix}.self_attn") - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - ssm_output: Optional[torch.LongTensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.use_mamba: - assert mamba_cache_params is not None - else: - assert mamba_cache_params is None - - residual = hidden_states - hidden_states = self.input_layernorm( - hidden_states.to(dtype=self.input_layernorm.weight.dtype)) - - if self.use_mamba: - attn_outputs, ssm_output = self.attn(hidden_states, - attn_metadata, - mamba_cache_params, - yoco_key_values=ssm_output) - residual = residual.to(torch.float32) - else: - attn_outputs = self.attn(hidden_states, ) - hidden_states = residual + attn_outputs - residual = hidden_states - hidden_states = self.post_attention_layernorm( - hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) - 
hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, ssm_output - - -class SambaYModel(nn.Module): - - def __init__(self, - config, - cache_config=None, - quant_config=None, - lora_config=None, - prefix: str = "") -> None: - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - # Pipeline parallel is not supported since the second half of - # the layers share the kv cache. - if get_pp_group().world_size != 1: - raise ValueError("Pipeline Parallel not supported") - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: SambaYDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - - mamba_state_idx = 0 - ssm_output = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - if i == self.config.num_hidden_layers // 2 + 2: - # profile run - kv_cache_idx = self.config.num_hidden_layers // 2 + 1 - cache_layer = self.layers[kv_cache_idx] - kv_cache = cache_layer.attn.attn.kv_cache - if kv_cache[0].numel() == 0: - break - - # Starting from this layer, we do not need to calculate - # the kv cache since we reuse the kv cache from last layer. - # If in prefill phase, we can prune - # the hidden state to save computation cost. 
- if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1: - selected_token_indices = torch.cumsum( - attn_metadata.seq_lens_tensor, dim=0) - 1 - hidden_states = hidden_states.index_select( - 0, selected_token_indices) - ssm_output = ssm_output.index_select( - 0, selected_token_indices) - - if layer.use_mamba: - if i < self.config.num_hidden_layers // 2 or \ - not layer.yoco_cross: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx) - mamba_state_idx += 1 - else: - mamba_cache = mamba_cache_params.at_layer_idx( - mamba_state_idx - 1) - - hidden_states, ssm_output = layer(hidden_states, - positions, - attn_metadata, - mamba_cache, - ssm_output=ssm_output) - else: - hidden_states, ssm_output = layer( - hidden_states, - positions, - attn_metadata, - None, # mamba_cache_params - ssm_output=ssm_output) - - hidden_states = self.final_layernorm( - hidden_states.to(dtype=self.final_layernorm.weight.dtype)) - return hidden_states - - -class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config - quant_config = vllm_config.quant_config - scheduler_config = vllm_config.scheduler_config - self.compilation_config = vllm_config.compilation_config - self.vllm_config = vllm_config - # Prefix caching and chunked prefill is not supported for this model. - assert not cache_config.enable_prefix_caching, \ - "Phi4flash currently does not support prefix caching" - assert not scheduler_config.chunked_prefill_enabled, \ - "Phi4Flash currently does not support prefix caching" - super().__init__() - self.config = config - self.model_config = vllm_config.model_config - self.scheduler_config = scheduler_config - self.model = SambaYModel(config, - cache_config=cache_config, - prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), - quant_config=quant_config, - ) - self.embedding_bias = None - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logits_as_input=False) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers \ - // 2 // self.config.mb_per_layer + 1 - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - attn_metadata = get_forward_context().attn_metadata - # input_ids and hidden_states isn't a one-to-one mapping in prefill - # stage due to YOCO optimization. 
- hidden_states = self.model(input_ids, positions, attn_metadata, - mamba_cache_params, intermediate_tensors, - inputs_embeds) - return hidden_states - - def _get_mamba_cache_shape( - self - ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - mamba_expand = self.config.mamba_expand # 2 - mamba_d_conv = self.config.mamba_d_conv # 4 - mamba_d_state = self.config.mamba_d_state # 16 - conv_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_conv - 1, - ) - temporal_state_shape = ( - mamba_expand * hidden_size // world_size, - mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - # If the shape is the same, it means that we have already - # prune hidden states manually. - prune_hidden_states = hidden_states.size( - 0) != sampling_metadata.selected_token_indices.size(0) - processed_logits = self.logits_processor( - self.lm_head, - hidden_states, - sampling_metadata, - self.embedding_bias, - prune_hidden_states=prune_hidden_states) - return processed_logits - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ): - weights = {name: weight for name, weight in weights} - adjusted_weights = {} - for name, weight in weights.items(): - if "A_log" in name: - name = name.replace("A_log", "A") - weight = -torch.exp(weight.float()) - if "inner_cross_attn." 
in name: - name = name.replace("inner_cross_attn.", "") - adjusted_weights[name] = weight - adjusted_weights["lm_head.weight"] = weights[ - "model.embed_tokens.weight"] - loaded_params: set[str] = set() - for name, param in self.named_parameters(): - weight = adjusted_weights.get(name) - if weight is not None and weight.shape != param.shape: - logger.warning("Shape mismatch: %s %s %s", name, weight.shape, - param.shape) - loaded_params.add(name) - missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, - strict=False) - assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" - assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" - return loaded_params diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index b4aed11b86898..981f9b37846fe 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -2,41 +2,61 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch import torch.nn as nn -from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, - SequenceFeatureExtractor, SiglipVisionConfig) +from transformers import ( + BatchFeature, + PretrainedConfig, + ProcessorMixin, + SequenceFeatureExtractor, + SiglipVisionConfig, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_pp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, - ImageProcessorItems, ImageSize, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + ResolvedPromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, - merge_multimodal_embeddings) +from .utils 
import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -47,16 +67,17 @@ _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 SIGLIP_NAME = "siglip-so400m-patch14-448" VISION_ENCODER_TO_PROCESSING_CONFIG = { - 'siglip-so400m-patch14-448': { - 'vit_image_size': 448, - 'vit_patch_size': 14, - 'token_compression_factor': 2, + "siglip-so400m-patch14-448": { + "vit_image_size": 448, + "vit_patch_size": 14, + "token_compression_factor": 2, }, } -def _get_padding_size(orig_width: int, orig_height: int, target_height: int, - target_width: int): +def _get_padding_size( + orig_width: int, orig_height: int, target_height: int, target_width: int +): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -82,8 +103,7 @@ def get_navit_vision_model(layer_idx: int = -1, **kwargs): model_config = SiglipVisionConfig(**vision_config, **kwargs) if layer_idx < 0: - num_hidden_layers = model_config.num_hidden_layers \ - + layer_idx + 1 + num_hidden_layers = model_config.num_hidden_layers + layer_idx + 1 else: num_hidden_layers = layer_idx + 1 @@ -99,38 +119,38 @@ def get_navit_vision_model(layer_idx: int = -1, **kwargs): class Phi4MMImageEncoder(nn.Module): """Image embedding.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - model_dir: str = "") -> None: + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + model_dir: str = "", + ) -> None: super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr( - config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size # layer_idx to output the img features if isinstance(config.img_processor, dict): - self.layer_idx = config.img_processor.get('layer_idx', -2) - self.type_feature = config.img_processor.get( - 'type_feature', 'patch') + self.layer_idx = config.img_processor.get("layer_idx", -2) + self.type_feature = config.img_processor.get("type_feature", "patch") else: self.layer_idx = -2 - self.type_feature = 'patch' + self.type_feature = "patch" self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx) pe_weight = self.img_processor.embeddings.position_embedding.weight L, D = pe_weight.size() H = int(math.sqrt(L)) - assert H**2 == L, f'position embedding size {L} is not square' + assert H**2 == L, f"position embedding size {L} is not square" if H % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) H += 1 image_dim_out = D # ((448/14)//2)**2 - self.num_img_tokens = (H // 2)**2 + self.num_img_tokens = (H // 2) ** 2 self.base_feat_height_target = H self.image_dim_out = image_dim_out @@ -145,37 +165,35 @@ class Phi4MMImageEncoder(nn.Module): self.crop_size = 448 # image token compression - self.image_token_compression_cls = 'avg_pool_2d' + self.image_token_compression_cls = "avg_pool_2d" self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.base_feat_height_reduction = 1 self.base_feat_height_target = self.base_feat_height_target // 2 # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform == self.with_learnable_separator, \ - 'use_hd_transform and with_learnable_separator should have same value' - assert self.use_hd_transform, \ - 'learnable separator is only for hd transform' + assert self.use_hd_transform == self.with_learnable_separator, 
( + "use_hd_transform and with_learnable_separator should have same value" + ) + assert self.use_hd_transform, "learnable separator is only for hd transform" # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter( - torch.zeros([ - 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 - ])) + torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2]) + ) self.sub_GN = nn.Parameter( - torch.zeros([ - 1, 1, 1, - self.image_dim_out * self.base_feat_height_reduction**2 - ])) + torch.zeros( + [1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2] + ) + ) dim_projection = hidden_size depth = 2 layers = [ - nn.Linear(image_dim_out * self.base_feat_height_reduction**2, - dim_projection) + nn.Linear( + image_dim_out * self.base_feat_height_reduction**2, dim_projection + ) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) self.vocab_size = config.vocab_size @@ -183,24 +201,24 @@ class Phi4MMImageEncoder(nn.Module): self.use_out_place_operations = False - def get_img_features(self, - img_embeds: torch.FloatTensor, - attention_mask=None) -> torch.FloatTensor: - - img_feature = self.img_processor(img_embeds, - patch_attention_mask=attention_mask) + def get_img_features( + self, img_embeds: torch.FloatTensor, attention_mask=None + ) -> torch.FloatTensor: + img_feature = self.img_processor( + img_embeds, patch_attention_mask=attention_mask + ) if self.type_feature == "patch": patch_feature = img_feature use_token_compression = self.image_token_compression is not None - use_padding = getattr(self, 'img_processor_padding', - None) is not None + use_padding = getattr(self, "img_processor_padding", None) is not None if use_token_compression or use_padding: # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, - patch_feature.size(-1)) + patch_feature = patch_feature.view( + -1, width, width, patch_feature.size(-1) + ) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) @@ -214,15 +232,19 @@ class Phi4MMImageEncoder(nn.Module): patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), - patch_feature.size(-1)) + patch_feature.size(-1), + ) return patch_feature raise NotImplementedError - def forward(self, pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor, - image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]: + def forward( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + image_attention_mask: torch.Tensor, + ) -> list[torch.FloatTensor]: """ process image and return vision embeddings. 
@@ -251,25 +273,27 @@ class Phi4MMImageEncoder(nn.Module): img_features = self.get_img_features( pixel_values, - image_attention_mask.type(torch.BoolTensor).flatten( - 0, 1).to(target_device)) + image_attention_mask.type(torch.BoolTensor).flatten(0, 1).to(target_device), + ) base_feat_height_target = self.base_feat_height_target base_resolution = self.crop_size base_feat_height_reduction = self.base_feat_height_reduction - base_feat_height = base_feat_width = int(np.sqrt( - img_features.shape[1])) - assert base_feat_height == base_feat_height_target \ - and base_feat_width == base_feat_height_target, \ - (f"base_feat_height: {base_feat_height}, " - f"base_feat_width: {base_feat_width}, " - f"expect {base_feat_height_target} features for hd transform") + base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) + assert ( + base_feat_height == base_feat_height_target + and base_feat_width == base_feat_height_target + ), ( + f"base_feat_height: {base_feat_height}, " + f"base_feat_width: {base_feat_width}, " + f"expect {base_feat_height_target} features for hd transform" + ) # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, - base_feat_height * base_feat_width, - self.image_dim_out) + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out + ) C = self.image_dim_out H = base_feat_height @@ -288,22 +312,32 @@ class Phi4MMImageEncoder(nn.Module): global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // base_feat_height_reduction, base_feat_height_reduction, - H // base_feat_height_reduction, base_feat_height_reduction, - C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( - 1, H // base_feat_height_reduction, + glb_img = ( + global_img_feature.reshape(1, H, H, C) + .reshape( + 1, H // base_feat_height_reduction, - base_feat_height_reduction * base_feat_height_reduction * - C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, - H // base_feat_height_reduction, - 1, 1) + base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + H // base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * C, + ) + .contiguous() + ) + temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1) # 1 x 156 x 4096 glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( - 1, -1, - base_feat_height_reduction * base_feat_height_reduction * C) + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) # (max_num_crops-1) x (12x12) x C sub_img = img_features[_bs, 1:] @@ -313,119 +347,178 @@ class Phi4MMImageEncoder(nn.Module): # (num_crops, 12, 2, 12, 2, 1024) -> # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // base_feat_height_reduction, - base_feat_height_reduction, H // base_feat_height_reduction, - base_feat_height_reduction, - C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( - B_, -1, base_feat_height_reduction * - base_feat_height_reduction * C).contiguous() - sub_img = sub_img.reshape( - 1, h, w, base_feat_height // base_feat_height_reduction, - base_feat_width // base_feat_height_reduction, - -1).permute(0, 1, 3, 2, 4, 5).reshape( - 1, h * base_feat_height // base_feat_height_reduction, + sub_img = ( + sub_img.reshape(B_, H, H, C) + .reshape( + B_, + H // 
base_feat_height_reduction, + base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape( + B_, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) + .contiguous() + ) + sub_img = ( + sub_img.reshape( + 1, + h, + w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1, + ) + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, w * base_feat_width // base_feat_height_reduction, - base_feat_height_reduction * base_feat_height_reduction * - C) + base_feat_height_reduction * base_feat_height_reduction * C, + ) + ) - if image_attention_mask is not None and len( - image_attention_mask) > 0: - reshaped_image_attention_mask = image_attention_mask[ - _bs, 1:B_ + 1, 0::2, 0::2].reshape( - 1, h, w, + if image_attention_mask is not None and len(image_attention_mask) > 0: + reshaped_image_attention_mask = ( + image_attention_mask[_bs, 1 : B_ + 1, 0::2, 0::2] + .reshape( + 1, + h, + w, base_feat_height // base_feat_height_reduction, - base_feat_width // base_feat_height_reduction).permute( - 0, 1, 3, 2, 4).reshape( - 1, h * base_feat_height // - base_feat_height_reduction, w * - base_feat_width // base_feat_height_reduction) - useful_height = int( - reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int( - reshaped_image_attention_mask[0, 0, :].sum().item()) + base_feat_width // base_feat_height_reduction, + ) + .permute(0, 1, 3, 2, 4) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + ) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) - temp_len = int( - image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( - )) + (useful_height + - 1) + base_feat_height // base_feat_height_reduction + temp_len = ( + int(image_attention_mask[_bs, : B_ + 1, 0::2, 0::2].sum().item()) + + (useful_height + 1) + + base_feat_height // base_feat_height_reduction + ) else: temp_sub_GN = self.sub_GN.repeat( - 1, h * base_feat_height // base_feat_height_reduction, 1, - 1) - temp_len = int((h * w + 1) * self.num_img_tokens + 1 + - (h + 1) * base_feat_height // - base_feat_height_reduction) + 1, h * base_feat_height // base_feat_height_reduction, 1, 1 + ) + temp_len = int( + (h * w + 1) * self.num_img_tokens + + 1 + + (h + 1) * base_feat_height // base_feat_height_reduction + ) sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( - 1, -1, - base_feat_height_reduction * base_feat_height_reduction * C) + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) # (1, num_img_tokens, 1024*4) # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + if self.hd_transform_order == "glb_sub": + output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == "sub_glb": + output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) else: raise NotImplementedError( f'hd_transform_order = {self.hd_transform_order}, "\ - "not implemented') + "not implemented' + ) - 
#temp_len = int((h*w+1)*144 + 1 + (h+1)*12) - assert temp_len == output_imgs[-1].shape[ - 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ + # temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[1], ( + f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ "{output_imgs[-1].shape[1]}' + ) output_len.append(temp_len) img_set_tensor = [] for _output_img in output_imgs: img_feature_proj = self.img_projection( - _output_img.to(target_device).to(target_dtype)) + _output_img.to(target_device).to(target_dtype) + ) img_set_tensor.append(img_feature_proj.squeeze(0)) return img_set_tensor -class Phi4MMImagePixelInputs(TypedDict): +class Phi4MMImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - p: Number of patches (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch + - nc: Number of crops + - H_mask: Height of attention mask + - W_mask: Width of attention mask + """ + type: Literal["pixel_values"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - Shape: - `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image + ] + + image_sizes: Annotated[ + torch.Tensor, + TensorShape("bn", 2), # (height, width) + ] + + num_img_tokens: Annotated[ + list[int], + TensorShape("bn"), + ] + + image_attention_mask: Annotated[ + torch.Tensor, + TensorShape("bn", "nc", 32, 32), # H_mask, W_mask + ] + + +class Phi4MMAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - t: Time frames (M) """ - image_sizes: torch.Tensor - """ - Shape: `(batch_size * num_images, 2)` - - This should be in `(height, width)` format. 
- """ - - num_img_tokens: list[int] - """Shape: `(batch_size * num_images)`""" - - image_attention_mask: torch.Tensor - """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - - -class Phi4MMAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size * num_audios, 80, M)""" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "t", 80, dynamic_dims={"t"}), + ] -class Phi4MMAudioEmbeddingInputs(TypedDict): +class Phi4MMAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of audios + - f: Audio feature size + - h: Hidden size (must match language model backbone) + """ + type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + data: Annotated[ + NestedTensors, + TensorShape("b", "n", "f", "h"), + ] Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] @@ -436,9 +529,9 @@ def cat_with_pad(tensors, dim, padding_value=0): cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + assert all(t.dim() == ndim for t in tensors[1:]), ( + "All tensors must have the same number of dimensions" + ) out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) @@ -458,14 +551,13 @@ def cat_with_pad(tensors, dim, padding_value=0): class Phi4MMProcessingInfo(BaseProcessingInfo): - @property def image_tokens(self) -> list[str]: - return [f"<|image_{i+1}|>" for i in range(100)] + return [f"<|image_{i + 1}|>" for i in range(100)] @property def audio_tokens(self) -> list[str]: - return [f"<|audio_{i+1}|>" for i in range(100)] + return [f"<|audio_{i + 1}|>" for i in range(100)] def get_dynamic_hd( self, @@ -476,8 +568,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): image_processor = processor.image_processor return image_processor.dynamic_hd - def get_feature_extractor(self, - **kwargs: object) -> SequenceFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> SequenceFeatureExtractor: return self.get_hf_processor(**kwargs).audio_processor def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -497,9 +588,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) + target_ratios = set( + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num + ) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target @@ -532,49 +626,56 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ): """ compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing + the image encoder architecture and exclude output features containing only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be + for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by 
vit_patch_size") - assert (vit_image_size // vit_patch_size % - token_compression_factor == 0), ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor") + "vit_image_size must be divisible by vit_patch_size" + ) + assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor" + ) target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio(orig_width, - orig_height, - vit_image_size, - dynamic_hd_size, - min_num=1)) + self._find_target_aspect_ratio( + orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 + ) + ) assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + ) assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + ) + assert ( + target_height % vit_image_size == 0 and target_width % vit_image_size == 0 + ) padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width) - assert padding_width == 0 or padding_height == 0, \ + orig_width, orig_height, target_height, target_width + ) + assert padding_width == 0 or padding_height == 0, ( "padding_width or padding_height must be 0" + ) target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) + padding_width / vit_patch_size + ) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) + padding_height / vit_patch_size + ) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch @@ -591,15 +692,17 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // - token_compression_factor)**2 + num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor + num_global_image_newline_tokens = vit_feature_size // token_compression_factor - return (num_global_image_tokens + num_sep_tokens + - num_hd_patch_tokens + num_hd_newline_tokens + - num_global_image_newline_tokens) + return ( + num_global_image_tokens + + num_sep_tokens + + num_hd_patch_tokens + + num_hd_newline_tokens + + num_global_image_newline_tokens + ) def get_num_image_tokens( self, @@ -612,11 +715,10 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ - vision_encoder_name] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = 
prepro_config['token_compression_factor'] + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + vit_image_size = prepro_config["vit_image_size"] + vit_patch_size = prepro_config["vit_patch_size"] + token_compression_factor = prepro_config["token_compression_factor"] dynamic_hd_size = self.get_dynamic_hd(processor=processor) @@ -639,9 +741,8 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ - vision_encoder_name] - vit_image_size = prepro_config['vit_image_size'] + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + vit_image_size = prepro_config["vit_image_size"] max_side = vit_image_size * self.get_dynamic_hd(processor=processor) return ImageSize(height=max_side, width=vit_image_size) @@ -687,8 +788,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): compression rate. """ hf_config = self.get_hf_config() - compression_rate = hf_config.embd_layer['audio_embd_layer'][ - 'compression_rate'] + compression_rate = hf_config.embd_layer["audio_embd_layer"]["compression_rate"] # NOTE: this is a hard-coded value but might be configurable # in the future qformer_compression_rate = 1 @@ -706,7 +806,6 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -720,32 +819,39 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "audio": - self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios, + overrides=audio_overrides, + ), } return mm_data class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate, - audio_resample_method="scipy") + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, audio_resample_method="scipy" + ) def _call_hf_processor( self, @@ -760,27 +866,27 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate - if (audio_data := mm_data.get("audios", [])): - mm_data['audios'] = [(data, sr) for data in audio_data] + if audio_data := mm_data.get("audios", []): + mm_data["audios"] = 
[(data, sr) for data in audio_data] - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) num_img_tokens = [ - self.info.get_num_image_tokens(image_width=img_size[0], - image_height=img_size[1]) + self.info.get_num_image_tokens( + image_width=img_size[0], image_height=img_size[1] + ) for img_size in processed_outputs["image_sizes"] ] processed_outputs["num_img_tokens"] = num_img_tokens - audio_features = processed_outputs['input_audio_embeds'] + audio_features = processed_outputs["input_audio_embeds"] feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) - for audio in audio_data + self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data ] - processed_outputs['input_audio_embeds'] = [ - audio_features[idx, :size] - for idx, size in enumerate(feature_sizes) + processed_outputs["input_audio_embeds"] = [ + audio_features[idx, :size] for idx, size in enumerate(feature_sizes) ] return processed_outputs @@ -806,13 +912,13 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ) -> Sequence[PromptUpdate]: image_tokens: list[str] = self.info.image_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore - feature_extractor = self.info.get_feature_extractor( - **hf_processor_mm_kwargs) + feature_extractor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) def get_image_replacement_phi4mm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -824,41 +930,50 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): processor=hf_processor, ) - image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens - - return image_tokens + return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) # TODO(Isotr0py): support embedding inputs audio_len = audios.get_audio_length(item_idx) audio_frames = self.info.get_audio_num_frames( - audio_len, feature_extractor.sampling_rate) - audio_embed_size = self.info._compute_audio_embed_size( - audio_frames) + audio_len, feature_extractor.sampling_rate + ) + audio_embed_size = self.info._compute_audio_embed_size(audio_frames) - audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - return audio_tokens - - num_images = mm_items.get_count("image", strict=False) - num_audios = mm_items.get_count("audio", strict=False) - - image_repl = [ + return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_image_replacement_phi4mm, - ) for image_token in image_tokens[:num_images] - ] - audio_repl = [ + ), PromptReplacement( modality="audio", - target=audio_token, + target=audio_tokens.__getitem__, replacement=get_audio_replacement_phi4mm, - ) for audio_token in audio_tokens[:num_audios] + ), ] - return image_repl + audio_repl + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + 
new_item_idx, + ) + + if cached_update.modality == "image": + image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update @MULTIMODAL_REGISTRY.register_processor( @@ -870,6 +985,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -884,10 +1000,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): "base_layer.": "", }, orig_to_new_prefix={ - "model.embed_tokens_extend.audio_embed.audio_projection.vision.": - "embed_tokens_extend.audio_projection_for_vision.", - "model.embed_tokens_extend.audio_embed.audio_projection.speech.": - "embed_tokens_extend.audio_projection.", + "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.", # noqa: E501 + "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.", # noqa: E501 "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", "model.embed_tokens_extend.image_embed.": "vision_encoder.", }, @@ -916,19 +1030,18 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. - assert get_pp_group( - ).world_size == 1, "pipeline parallel is not supported" + assert get_pp_group().world_size == 1, "pipeline parallel is not supported" self.vision_encoder = Phi4MMImageEncoder( config, quant_config, prefix="model.vision_embed_tokens", - model_dir=config._name_or_path) + model_dir=config._name_or_path, + ) if isinstance(config.embd_layer["audio_embd_layer"], dict): embedding_config = { - "embedding_cls": - config.embd_layer["audio_embd_layer"]["embedding_cls"], + "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"], **config.embd_layer["audio_embd_layer"], } else: @@ -937,8 +1050,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = LlamaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -949,17 +1063,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMAudioInputs]: """ - Parse and validate the audio input to the model. This handles both + Parse and validate the audio input to the model. 
This handles both audio features and audio embeddings, but only the former is used for now. @@ -976,25 +1093,18 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - - return Phi4MMAudioFeatureInputs(type="audio_features", - data=flatten_bn(audio_features)) + return Phi4MMAudioFeatureInputs( + type="audio_features", data=flatten_bn(audio_features) + ) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, audio_input: Phi4MMAudioInputs, - audio_projection_mode: str) -> NestedTensors: + def _process_audio_input( + self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str + ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. The audio input is @@ -1018,12 +1128,14 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): self.embed_tokens_extend( features.to(dtype), audio_projection_mode=audio_projection_mode, - ) for features in audio_features + ) + for features in audio_features ] return audio_embeds - def _parse_and_validate_image_input(self, - **kwargs: object) -> Optional[dict]: + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Phi4MMImagePixelInputs]: input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1031,12 +1143,16 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") - assert image_sizes is not None and image_attention_mask is not None\ - and num_img_tokens is not None, "Missing image inputs" + assert ( + image_sizes is not None + and image_attention_mask is not None + and num_img_tokens is not None + ), "Missing image inputs" if is_list_of(input_image_embeds, torch.Tensor): - assert all(p.dim() == 5 - for p in input_image_embeds), "Incorrect image inputs" + assert all(p.dim() == 5 for p in input_image_embeds), ( + "Incorrect image inputs" + ) # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. 
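The list branch above enforces a per-request shape contract instead of a single stacked tensor, because the number of HD crops can differ across images. A minimal sketch of inputs that would pass that check is shown below; the crop counts are hypothetical and 448 is the SigLIP crop size mentioned in the docstrings of this file, used here only for illustration.

import torch

# Two requests with different crop counts, so the batch arrives as a list.
input_image_embeds = [
    torch.randn(1, 5, 3, 448, 448),    # request 1: 1 image, 4 HD crops + 1 global crop
    torch.randn(2, 13, 3, 448, 448),   # request 2: 2 images, padded to 12 crops + 1 global
]

# Mirrors the assertion in _parse_and_validate_image_input: every entry must be
# (num_img_per_example, num_hd_patches, channels, height, width).
assert all(t.dim() == 5 for t in input_image_embeds)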
@@ -1065,17 +1181,16 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): elif isinstance(image_sizes, torch.Tensor): image_sizes = image_sizes.flatten(0, 1) else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect image_sizes inputs") if isinstance(num_img_tokens, list): num_img_tokens = [ - n for num_tensor in num_img_tokens - for n in num_tensor.tolist() + n for num_tensor in num_img_tokens for n in num_tensor.tolist() ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() else: - raise ValueError("Incorrect image_attention_mask inputs") + raise ValueError("Incorrect num_img_tokens inputs") return Phi4MMImagePixelInputs( type="pixel_values", @@ -1091,43 +1206,43 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("input_image_embeds", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("input_audio_embeds", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("input_image_embeds", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("input_audio_embeds", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities def _process_image_input( - self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: - + self, image_input: Phi4MMImagePixelInputs + ) -> list[torch.Tensor]: dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = image_input['data'].to(dtype) - image_sizes = image_input['image_sizes'] - image_attention_mask = image_input['image_attention_mask'] - image_embeds = self.vision_encoder(pixel_values, image_sizes, - image_attention_mask) + pixel_values = image_input["data"].to(dtype) + image_sizes = image_input["image_sizes"] + image_attention_mask = image_input["image_attention_mask"] + image_embeds = self.vision_encoder( + pixel_values, image_sizes, image_attention_mask + ) return image_embeds - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] - return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- audio_projection_mode = 'speech' + audio_projection_mode = "speech" for modality in modalities: # make sure process images first if modality == "images": @@ -1138,53 +1253,12 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) + audio_input, audio_projection_mode=audio_projection_mode + ) multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.embed_tokens(input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Phi4MMImagePixelInputs] = None, - audio_input: Optional[Phi4MMAudioFeatureInputs] = None, - ) -> torch.Tensor: - audio_projection_mode = 'speech' - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID, - ) - audio_projection_mode = 'vision' - - if audio_input is not None: - audio_embeds = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - audio_embeds, - placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1196,22 +1270,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - audio_input = self._parse_and_validate_audio_input(**kwargs) - - if image_input is None and audio_input is None: - inputs_embeds = None - else: - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - audio_input=audio_input) - input_ids = None - hidden_states = self.model( input_ids, positions, @@ -1224,14 +1282,11 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> None: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: loader = AutoWeightsLoader(self, skip_substrs=["lora"]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index 0b0d66ae771dd..d289e26efa10f 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -7,22 +7,31 @@ #!/usr/bin/env python3 import abc import math -from typing import Literal, Optional +from typing import Any, Literal, Optional, Union import numpy as np import torch import torch.nn.functional as F from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel) + CheckpointWrapper, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from transformers import PretrainedConfig from vllm.model_executor.models.phi4mm_utils import ( - AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, - MultiHeadedAttention, MultiSequential, NemoConvSubsampling, - T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor) + AbsolutePositionalEncoding, + ConvModule, + FeedForward, + MeanVarianceNormLayer, + MultiHeadedAttention, + MultiSequential, + NemoConvSubsampling, + T5RelativeAttentionLogitBias, + adaptive_enc_mask, + get_offset, + unfold_tensor, +) _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> @@ -40,10 +49,10 @@ class ConformerEncoderLayer(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a - channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. + channel_out of the second conv1d layer. + otherwise, it equals to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. @@ -100,7 +109,7 @@ class ConformerEncoderLayer(nn.Module): activation function for glu used in the multihead attention, default "swish". 
activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where + a dictionary of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. @@ -115,14 +124,14 @@ class ConformerEncoderLayer(nn.Module): we recalculate activation in backward. default "". export: bool, optional - if set to True, it remove the padding from convolutional layers + if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional - if set to True, use pytorch's scaled dot product attention + if set to True, use pytorch's scaled dot product attention implementation in training. attn_group_sizes: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attn_group_sizes < attention_heads = Grouped-Query Attention @@ -131,31 +140,31 @@ class ConformerEncoderLayer(nn.Module): def __init__( self, - d_model=512, - ext_pw_out_channel=0, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - n_head=4, - d_ffn=2048, - ext_pw_kernel_size=1, - kernel_size=3, - dropout_rate=0.1, - causal=False, - batch_norm=False, - activation="relu", - chunk_se=0, - chunk_size=18, - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_inner_dim=-1, - attention_glu_type="swish", - activation_checkpointing="", - export=False, - use_pt_scaled_dot_product_attention=False, + d_model: int = 512, + ext_pw_out_channel: int = 0, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + n_head: int = 4, + d_ffn: int = 2048, + ext_pw_kernel_size: int = 1, + kernel_size: int = 3, + dropout_rate: float = 0.1, + causal: bool = False, + batch_norm: bool = False, + activation: str = "relu", + chunk_se: int = 0, + chunk_size: int = 18, + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_inner_dim: int = -1, + attention_glu_type: str = "swish", + activation_checkpointing: str = "", + export: bool = False, + use_pt_scaled_dot_product_attention: bool = False, attn_group_sizes: int = 1, - ): + ) -> None: super().__init__() self.feed_forward_in = FeedForward( @@ -173,8 +182,7 @@ class ConformerEncoderLayer(nn.Module): attention_inner_dim, attention_glu_type, bias_in_glu, - use_pt_scaled_dot_product_attention= - use_pt_scaled_dot_product_attention, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, group_size=attn_group_sizes, ) self.conv = ConvModule( @@ -209,24 +217,21 @@ class ConformerEncoderLayer(nn.Module): def forward( self, - x, - pos_k, - pos_v, - mask, + x: torch.Tensor, + pos_k: torch.Tensor, + pos_v: torch.Tensor, + mask: torch.Tensor, relative_attention_bias: Optional[Tensor] = None, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ConformerEncoder forward. Args: - x: torch.Tensor - input feature of shape (batch, max_time_in, size) - pos_k: torch.Tensor - positional key embedding. - mask: torch.Tensor - mask for x (batch, max_time_in) - relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. relative positions - (1, n_head, time1, time2) + x: input feature of shape (batch, max_time_in, size) + pos_k: positional key embedding. 
+ pos_v: positional value embedding. + mask: mask for x (batch, max_time_in) + relative_attention_bias: bias added to attention logits w.r.t. + relative positions (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) @@ -299,7 +304,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see + additional method-specific arguments can be provided (see transformer_base.py) positional_dropout_rate: float, optional dropout rate after positional encoding. default 0.0 @@ -313,35 +318,34 @@ class TransformerEncoderBase(abc.ABC, nn.Module): supraframe utts in batch. Default: none attention_group_size: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query + 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attention_heads = Multi-Query Attention """ def __init__( self, - input_size, - chunk_size, - left_chunk, - attention_dim=256, - attention_heads=4, - input_layer="nemo_conv", - cnn_out=-1, - cnn_layer_norm=False, - time_reduction=4, - dropout_rate=0.0, - padding_idx=-1, - relative_attention_bias_args=None, - positional_dropout_rate=0.0, - nemo_conv_settings=None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", - True] = "none", - attention_group_size=1, - encoder_embedding_config=None, - ): + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + attention_dim: int = 256, + attention_heads: int = 4, + input_layer: str = "nemo_conv", + cnn_out: int = -1, + cnn_layer_norm: bool = False, + time_reduction: int = 4, + dropout_rate: float = 0.0, + padding_idx: int = -1, + relative_attention_bias_args: Optional[dict[str, Any]] = None, + positional_dropout_rate: float = 0.0, + nemo_conv_settings: Optional[dict[str, Any]] = None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__() self.input_size = input_size self.input_layer = input_layer @@ -369,74 +373,88 @@ class TransformerEncoderBase(abc.ABC, nn.Module): if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" + assert i not in nemo_conv_settings, ( + "{i} should be specified outside of the NeMo dictionary" + ) - self.embed = NemoConvSubsampling(**default_nemo_conv_settings, ) + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) else: raise ValueError("unknown input_layer: " + input_layer) - self.pos_emb = AbsolutePositionalEncoding(attention_dim, - positional_dropout_rate) + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) self.relative_attention_bias_type = ( relative_attention_bias_args.get("type") - if relative_attention_bias_args else None) + if relative_attention_bias_args + else None + ) if self.relative_attention_bias_type == "t5": - assert (self.num_heads % self.attention_group_size == 0 - ), "attention_group_size must divide n_head" + assert 
self.num_heads % self.attention_group_size == 0, ( + "attention_group_size must divide n_head" + ) self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( self.num_heads // self.attention_group_size, max_distance=relative_attention_bias_args.get( - "t5_bias_max_distance", 1000), - symmetric=relative_attention_bias_args.get( - "t5_bias_symmetric", False), + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), ) else: raise NotImplementedError self.encoder_embedding = MeanVarianceNormLayer( - self.encoder_embedding_config["input_size"]) + self.encoder_embedding_config["input_size"] + ) - def compute_lens_change(self, feature_lens): + def compute_lens_change( + self, feature_lens: Union[int, torch.Tensor] + ) -> Union[int, torch.Tensor]: """feature_lens: int return updated feature lens. - This used to return a different lambda function for each case that - computed the right thing. That does not work within Torchscript. + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. If you really need this to be faster, create nn.Module()-s for all the cases and return one of them. Torchscript does support that. """ if self.input_layer == "nemo_conv": # Handle the special causal case subsampling_causal_cond = self.nemo_conv_settings.get( - "subsampling", "dw_striding") in [ - "dw_striding", - "striding", - "striding_conv1d", - ] + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] is_causal = self.nemo_conv_settings.get("is_causal", False) if is_causal and subsampling_causal_cond: - lens_change = (torch.ceil(feature_lens / - self.time_reduction).long() - if isinstance(feature_lens, Tensor) else - math.ceil(feature_lens / self.time_reduction)) + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) feature_lens_remainder = feature_lens % self.time_reduction if isinstance(feature_lens, Tensor): lens_change[feature_lens_remainder != 1] += 1 elif feature_lens_remainder != 1: lens_change += 1 return lens_change - ceil_func = (math.ceil - if isinstance(feature_lens, int) else torch.ceil) + ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod - def forward(self): + def forward(self) -> Any: """Abstract forward method implementation.""" - def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + def _chunk_size_selection( + self, + chunk_size: Optional[Union[int, list[int]]] = None, + left_chunk: Optional[Union[int, list[int]]] = None, + ) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: @@ -446,15 +464,16 @@ class TransformerEncoderBase(abc.ABC, nn.Module): if isinstance(chunk_size, list): # Variable chunk size during training chunk_size_index = int( - torch.randint(low=0, high=len(chunk_size), size=(1, ))) + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) chunk_size_train_eff = chunk_size[chunk_size_index] if not isinstance(left_chunk, list): raise ValueError( - "Since chunk_size is a list, left_chunk must be a list") + "Since chunk_size is a list, left_chunk must be a list" + ) if len(left_chunk) != len(chunk_size): raise ValueError( - "The length of left_chunk must be the same as length of "\ - "chunk_size." 
+ "The length of left_chunk must be the same as length of chunk_size." ) left_chunk_train_eff = left_chunk[chunk_size_index] else: @@ -463,7 +482,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return chunk_size_train_eff, left_chunk_train_eff - def _get_embed_class(self, embed): + def _get_embed_class(self, embed: nn.Module) -> nn.Module: # pylint: disable=protected-access is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) @@ -474,39 +493,72 @@ class TransformerEncoderBase(abc.ABC, nn.Module): embed_class = embed.module return embed_class - def _forward_embeddings_core(self, input_tensor, masks): + def _forward_embeddings_core( + self, input_tensor: torch.Tensor, masks: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) return input_tensor, masks - def _position_embedding(self, input_tensor): + def _position_embedding( + self, input_tensor: torch.Tensor + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: pos_k = None pos_v = None if self.relative_attention_bias_layer is None: input_tensor = self.pos_emb( - input_tensor) # default to add abs sinusoid embedding + input_tensor + ) # default to add abs sinusoid embedding return pos_k, pos_v - def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): - chunk_size_train_eff, left_chunk_train_eff = \ - self._chunk_size_selection(chunk_size, left_chunk) + def _streaming_mask( + self, + seq_len: int, + batch_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + ) -> torch.Tensor: + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) - enc_streaming_mask = (adaptive_enc_mask( - seq_len, chunk_start_idx, - left_window=left_chunk_train_eff).unsqueeze(0).expand( - [batch_size, -1, -1])) + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) return enc_streaming_mask - def forward_embeddings(self, - xs_pad, - masks, - chunk_size_nc=None, - left_chunk_nc=None): + def forward_embeddings( + self, + xs_pad: torch.Tensor, + masks: torch.Tensor, + chunk_size_nc: Optional[Union[int, list[int]]] = None, + left_chunk_nc: Optional[Union[int, list[int]]] = None, + ) -> Union[ + tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ], + tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: """Forwarding the inputs through the top embedding layers Args: @@ -514,7 +566,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): input tensor masks: torch.Tensor input mask - chunk_size_nc: (optional, default is None) chunk size for + chunk_size_nc: (optional, default is None) chunk size for non-causal layers left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers @@ -527,21 +579,21 @@ class TransformerEncoderBase(abc.ABC, nn.Module): f"""The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short. 
Consider filtering out the very short sentence from data - loader""", ) + loader""", + ) batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask(seq_len, batch_size, - self.chunk_size, - self.left_chunk) + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) if xs_pad.is_cuda: enc_streaming_mask = enc_streaming_mask.cuda() xs_pad = xs_pad.cuda() input_tensor = xs_pad - input_tensor, masks = self._forward_embeddings_core( - input_tensor, masks) + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: @@ -553,7 +605,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module): if chunk_size_nc is not None: enc_streaming_mask_nc = self._streaming_mask( - seq_len, batch_size, chunk_size_nc, left_chunk_nc) + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) if xs_pad.is_cuda: enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() if masks is not None: @@ -569,7 +622,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc - def get_offset(self): + def get_offset(self) -> int: """Returns offset used when retaining inputs for decoding. This is essentially, how many additional frames have to be added to @@ -605,11 +658,9 @@ class ConformerEncoder(TransformerEncoderBase): Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] - left_chunk: int - number of chunks used for masking in streaming mode. num_lang: int - This parameter is used to store the number of languages in the - lang_dict, only used for multiseed/multilingual models. + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. default None. attention_dim: int, optional attention dimension. default 256. @@ -686,7 +737,7 @@ class ConformerEncoder(TransformerEncoderBase): only work for glu_in_attention !=0 default "swish". export: bool, optional - if set to True, it remove the padding from convolutional layers + if set to True, it removes the padding from convolutional layers and allow the onnx conversion for inference. default False. activation_checkpointing: str, optional @@ -707,16 +758,16 @@ class ConformerEncoder(TransformerEncoderBase): extra_layer_output_idx: int the layer index to be exposed. relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see + additional method-specific arguments can be provided (see transformer_base.py) time_reduction: int optional time reduction factor default 4 - use_pt_scaled_dot_product_attention: whether to use pytorch scaled + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention in training. Default: False nemo_conv_settings: dict, optional @@ -734,12 +785,12 @@ class ConformerEncoder(TransformerEncoderBase): Add extra padding in conv2d subsampling layers. 
Choices are (feat, feat_time, none, True) Default: none - replication_pad_for_subsample_embedding: For batched-streaming + replication_pad_for_subsample_embedding: For batched-streaming decoding, use "replication" padding for the cache at start of utterance. Default: False attention_group_size: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query @@ -751,46 +802,45 @@ class ConformerEncoder(TransformerEncoderBase): def __init__( # pylint: disable-all self, - input_size, - chunk_size, - left_chunk, - num_lang=None, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - input_layer="nemo_conv", - causal=True, - batch_norm=False, - cnn_out=-1, - cnn_layer_norm=False, - ext_pw_out_channel=0, - ext_pw_kernel_size=1, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - chunk_se=0, - kernel_size=3, - activation="relu", - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_glu_type="swish", - export=False, - extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], # noqa - activation_checkpointing="", - relative_attention_bias_args=None, - time_reduction=4, - use_pt_scaled_dot_product_attention=False, - nemo_conv_settings=None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", - True] = "none", - replication_pad_for_subsample_embedding=False, - attention_group_size=1, - encoder_embedding_config=None, - ): + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + num_lang: Optional[int] = None, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + input_layer: str = "nemo_conv", + causal: bool = True, + batch_norm: bool = False, + cnn_out: int = -1, + cnn_layer_norm: bool = False, + ext_pw_out_channel: int = 0, + ext_pw_kernel_size: int = 1, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + chunk_se: int = 0, + kernel_size: int = 3, + activation: str = "relu", + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_glu_type: str = "swish", + export: bool = False, + extra_layer_output_idx: int = -1, + extra_multi_layer_output_idxs: list[int] = [], # noqa + activation_checkpointing: str = "", + relative_attention_bias_args: Optional[dict[str, Any]] = None, + time_reduction: int = 4, + use_pt_scaled_dot_product_attention: bool = False, + nemo_conv_settings: Optional[dict[str, Any]] = None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + replication_pad_for_subsample_embedding: bool = False, + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__( input_size, chunk_size, @@ -813,71 +863,80 @@ class ConformerEncoder(TransformerEncoderBase): self.num_lang = num_lang self.kernel_size = kernel_size self.replication_pad_for_subsample_embedding: bool = ( - replication_pad_for_subsample_embedding) - assert (self.num_heads % attention_group_size == 0 - ), "attention_group_size must divide n_head" + replication_pad_for_subsample_embedding + ) + assert self.num_heads % attention_group_size == 0, ( + "attention_group_size must divide n_head" + ) self.num_heads_k 
= self.num_heads // attention_group_size - self.encoders = MultiSequential(*[ - ConformerEncoderLayer( - d_model=attention_dim, - ext_pw_out_channel=ext_pw_out_channel, - depthwise_seperable_out_channel=depthwise_seperable_out_channel, - depthwise_multiplier=depthwise_multiplier, - n_head=attention_heads, - d_ffn=linear_units, - ext_pw_kernel_size=ext_pw_kernel_size, - kernel_size=kernel_size, - dropout_rate=dropout_rate, - causal=causal, - batch_norm=batch_norm, - activation=activation, - chunk_se=chunk_se, - chunk_size=chunk_size, - conv_activation=conv_activation, - conv_glu_type=conv_glu_type, - bias_in_glu=bias_in_glu, - linear_glu_in_convm=linear_glu_in_convm, - attention_glu_type=attention_glu_type, - activation_checkpointing=activation_checkpointing, - export=export, - use_pt_scaled_dot_product_attention= - use_pt_scaled_dot_product_attention, - attn_group_sizes=attention_group_size, - ) for _ in range(num_blocks) - ]) + self.encoders = MultiSequential( + *[ + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=activation_checkpointing, + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + for _ in range(num_blocks) + ] + ) self.extra_layer_output_idx = extra_layer_output_idx self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs # Make a zeros scalar we can use in get_initial_state to determine # the device and the needed dtype: self.register_buffer("dev_type", torch.zeros(()), persistent=False) - def init_relative_attention_bias(self, input_tensor): + def init_relative_attention_bias( + self, input_tensor: torch.Tensor + ) -> Optional[torch.Tensor]: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) - def calculate_hs_mask(self, xs_pad, device, mask): + def calculate_hs_mask( + self, xs_pad: torch.Tensor, device: torch.device, mask: Optional[torch.Tensor] + ) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, - self.chunk_size, - self.left_chunk) + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens - pad_mask = (torch.arange(0, max_audio_length, - device=device).expand(padding_length.size(0), - -1) - < padding_length.unsqueeze(1)) + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask @torch.jit.ignore - def forward(self, xs_pad, masks): + def forward( + self, xs_pad: torch.Tensor, masks: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Conformer Forward 
function Args: @@ -888,11 +947,12 @@ class ConformerEncoder(TransformerEncoderBase): """ xs_pad = self.encoder_embedding(xs_pad) input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( - xs_pad, masks) + xs_pad, masks + ) unfolded = False ori_bz, seq_len, D = input_tensor.shape - max_seq_len = 500 #maximum position for absolute positional encoding + max_seq_len = 500 # maximum position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks # of max_seq_len @@ -904,26 +964,29 @@ class ConformerEncoder(TransformerEncoderBase): else: chunk_pad_size = 0 if chunk_pad_size > 0: - input_tensor_pad = F.pad(input_tensor, - (0, 0, 0, chunk_pad_size), "constant", - 0) + input_tensor_pad = F.pad( + input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0 + ) input_tensor = input_tensor_pad.to(input_tensor.device) input_tensor = unfold_tensor(input_tensor, max_seq_len) if masks is not None: # revise hs_mask here because the previous calculated hs_mask # did not consider extra pad subsampled_pad_mask = masks.squeeze( - 1) # [bz, subsampled_unmask_seq_len] + 1 + ) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", - False) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = \ + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = ( extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze( - -1).bool() # unfold op does not support bool tensor + -1 + ).bool() # unfold op does not support bool tensor else: masks_unfold = None hs_mask = self.calculate_hs_mask( @@ -932,15 +995,14 @@ class ConformerEncoder(TransformerEncoderBase): # layer_emb = None - relative_attention_bias = self.init_relative_attention_bias( - input_tensor) + relative_attention_bias = self.init_relative_attention_bias(input_tensor) - _simplified_path = (self.extra_layer_output_idx == -1 - and relative_attention_bias is None) + _simplified_path = ( + self.extra_layer_output_idx == -1 and relative_attention_bias is None + ) if _simplified_path: - input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, - hs_mask) + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) else: for i, layer in enumerate(self.encoders): input_tensor, _, _, _ = layer( @@ -980,24 +1042,33 @@ class WindowQformer(nn.Module): ): super().__init__() - self.decoders = nn.ModuleList([ - nn.TransformerDecoderLayer( - d_model=attention_dim, - nhead=attention_heads, - dim_feedforward=linear_units, - dropout=dropout_rate, - activation="relu", - batch_first=True, - norm_first=normalize_before, # TODO need to verify - ) for _ in range(num_blocks) - ]) + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) - self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12) - if normalize_before else None) + self.after_norm = ( + nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) self.window_size = 
window_size - def forward(self, audio_embed, mask, embed_len=None): + def forward( + self, + audio_embed: torch.Tensor, + mask: Optional[torch.Tensor], + embed_len: Optional[int] = None, + ) -> tuple[torch.Tensor, Optional[int]]: """forward decoder""" # audio_embed: N x T x D => N x D x T @@ -1005,8 +1076,9 @@ class WindowQformer(nn.Module): # audio_embed: N x D x 1 x T => N x DK x T' padding = audio_embed.shape[-1] % self.window_size if padding > 0: - audio_embed = F.pad(audio_embed, (0, self.window_size - padding), - "constant", 0) + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) embed_chunk = F.unfold( audio_embed[..., None, :], @@ -1023,10 +1095,7 @@ class WindowQformer(nn.Module): # NT' x 1 x D q = self.queries.expand(bsz * slen, -1, -1) for layer in self.decoders: - q = layer(tgt=q, - memory=embed_chunk, - tgt_mask=None, - memory_mask=mask) + q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask) if self.after_norm is not None: q = self.after_norm(q) @@ -1042,12 +1111,11 @@ class WindowQformer(nn.Module): class AudioEmbedding(nn.Module): """Image embedding.""" - def __init__(self, config: PretrainedConfig, **kwargs) -> None: + def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM - hidden_size = (config.n_embd - if hasattr(config, "n_embd") else config.hidden_size) + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size # self.wte = nn.Embedding(config.vocab_size, hidden_size) @@ -1056,8 +1124,10 @@ class AudioEmbedding(nn.Module): ) self.layer_idx = -2 - if (isinstance(config.audio_processor, dict) - and config.audio_processor.get("name", None) == "cascades"): + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): encoder_config = config.audio_processor.get("config", None) assert encoder_config is not None self.encoder = ConformerEncoder(**encoder_config) @@ -1067,13 +1137,11 @@ class AudioEmbedding(nn.Module): else: raise NotImplementedError("") - assert (audio_dim_out - is not None), "Remember to set values for audio_dim_out" + assert audio_dim_out is not None, "Remember to set values for audio_dim_out" self.audio_dim_out = audio_dim_out self.audio_dim_in = n_mels - self.freeze_audio_processor = kwargs.get("freeze_audio_processor", - False) + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) self.downsample_rate = kwargs.get("downsample_rate", 1) @@ -1085,8 +1153,9 @@ class AudioEmbedding(nn.Module): self.qformer = None if kwargs.get("use_conv_downsample", False): - assert (self.qformer is None - ), "don't support use qformer and conv downsample together" + assert self.qformer is None, ( + "don't support use qformer and conv downsample together" + ) nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) default_nemo_conv_settings = { "subsampling": "dw_striding", @@ -1102,11 +1171,13 @@ class AudioEmbedding(nn.Module): if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" + assert i not in nemo_conv_settings, ( + "{i} should be specified outside of the NeMo dictionary" + ) - self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, ) + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) else: self.conv_ds = None @@ 
-1118,60 +1189,53 @@ class AudioEmbedding(nn.Module): # (do not use image_projection and image_proj_norm) dim_projection = hidden_size depth = 2 - self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds) - else self.downsample_rate) + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else self.downsample_rate + ) layers = [ - nn.Linear(audio_dim_out * self.linear_downsample_rate, - dim_projection) + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection = nn.Sequential(*layers) # NOTE vision-speech tasks use a separate projection layer layers = [ - nn.Linear(audio_dim_out * self.linear_downsample_rate, - dim_projection) + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection_for_vision = nn.Sequential(*layers) else: raise NotImplementedError( - f"projection_cls = {projection_cls}, not implemented") + f"projection_cls = {projection_cls}, not implemented" + ) # TODO: audio sequence compression - Qformer self.vocab_size = config.vocab_size self.input_embeds = None self.audio_embed_sizes = None - def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + def set_audio_embeds(self, input_embeds: torch.Tensor) -> None: self.input_embeds = input_embeds - def set_audio_embed_sizes(self, - audio_embed_sizes: torch.LongTensor) -> None: + def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None: self.audio_embed_sizes = audio_embed_sizes def get_audio_features( self, - input_embeds: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + input_embeds: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: input_embeds: audio features (B, T, D) B: num audios in a sequence """ if self.freeze_audio_processor: with torch.no_grad(): - audio_features, masks = self.encoder(input_embeds, - audio_attention_mask) + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) else: - audio_features, masks = self.encoder(input_embeds, - audio_attention_mask) + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) if self.qformer is not None: audio_features, _ = self.qformer(audio_features, mask=None) @@ -1200,28 +1264,27 @@ class AudioEmbedding(nn.Module): feat_dim * self.linear_downsample_rate, ) - if audio_projection_mode == 'speech': + if audio_projection_mode == "speech": audio_set_tensor = self.audio_projection(audio_features) - elif audio_projection_mode == 'vision': + elif audio_projection_mode == "vision": audio_set_tensor = self.audio_projection_for_vision(audio_features) else: raise ValueError( - f"audio_projection_mode = {audio_projection_mode} not "\ - "implemented" + f"audio_projection_mode = {audio_projection_mode} not implemented" ) return audio_set_tensor def forward( self, - audio_features: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + audio_features: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: audio_features: audio 
features (T, D) - + returns: audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index c4890d8427e2a..d50547c199ac5 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -16,13 +16,13 @@ from torch import Tensor, nn class BlockBase(nn.Module): """Block abstract module""" - def __init__(self, input_size, output_size): + def __init__(self, input_size: int, output_size: int) -> None: super().__init__() self.input_size = input_size self.output_size = output_size -def get_activation(name="relu"): +def get_activation(name: str = "relu") -> torch.nn.Module: """Select an activation function by name Args: @@ -43,15 +43,17 @@ def get_activation(name="relu"): return nn.Identity() -def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): +def adaptive_enc_mask( + x_len: int, chunk_start_idx: list[int], left_window: int = 0, right_window: int = 0 +) -> torch.Tensor: """ The function is very important for Transformer Transducer Streaming mode Args: - xs_len (int): sequence length - chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + x_len: sequence length + chunk_start_idx: first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] - left_window (int): how many left chunks can be seen - right_window (int): how many right chunks can be seen. It is used for + left_window: how many left chunks can be seen + right_window: how many right chunks can be seen. It is used for chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model @@ -64,21 +66,23 @@ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): [False., True., True., False.], [False., False., True., True.]]) """ - chunk_start_idx = torch.Tensor(chunk_start_idx).long( - ) # first idx of each chunk, such as [0,18,36,48]. + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. 
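A minimal usage sketch of the streaming mask built here (the sizes are made up for illustration; the import path follows the file this hunk modifies):

    import torch
    from vllm.model_executor.models.phi4mm_utils import adaptive_enc_mask

    # Six frames split into chunks starting at 0, 2 and 4, with one chunk of left context.
    mask = adaptive_enc_mask(x_len=6, chunk_start_idx=[0, 2, 4], left_window=1)
    print(mask.shape)  # square x_len-by-x_len mask, as in the docstring example above;
                       # e.g. frame 4 (chunk [4, 5]) can also attend back to frames 2..3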
start_pad = torch.nn.functional.pad( - chunk_start_idx, - (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] end_pad = torch.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] - seq_range = torch.arange(0, - x_len).unsqueeze(-1) # seq_range size: [x_len, 1] - idx = ((seq_range < end_pad) & - (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ + :, 1 + ] # idx size: [x_len] # boundary = end_pad[idx] # boundary size: [x_len] - seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - ) # seq_range_expand size [x_len, x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] @@ -172,13 +176,13 @@ class GLUPointWiseConv(nn.Module): def __init__( self, - input_dim, - output_dim, - kernel_size, - glu_type="sigmoid", - bias_in_glu=True, - causal=False, - ): + input_dim: int, + output_dim: int, + kernel_size: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + causal: bool = False, + ) -> None: super().__init__() self.glu_type = glu_type @@ -216,11 +220,10 @@ class GLUPointWiseConv(nn.Module): self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ # to be consistent with GLULinear, we assume the input always has the # #channel (#dim) in the last dimension of the tensor, so need to @@ -229,18 +232,23 @@ class GLUPointWiseConv(nn.Module): x = self.ext_pw_conv_1d(x) if self.glu_type == "bilinear": if self.bias_in_glu: - x = (x[:, 0:self.output_dim, :] + self.b1) * ( - x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) else: - x = (x[:, 0:self.output_dim, :]) * ( - x[:, self.output_dim:self.output_dim * 2, :]) + x = ( + (x[:, 0 : self.output_dim, :]) + * (x[:, self.output_dim : self.output_dim * 2, :]) + ) else: if self.bias_in_glu: - x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act( - x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) else: - x = (x[:, 0:self.output_dim, :]) * self.glu_act( - x[:, self.output_dim:self.output_dim * 2, :]) + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) x = x.permute([0, 2, 1]) return x @@ -255,10 +263,10 @@ class DepthWiseSeperableConv1d(nn.Module): input_dim: int input channel size. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. + otherwise, it equals to 0, the second conv1d layer is skipped. 
kernel_size: int kernel_size depthwise_multiplier: int @@ -272,12 +280,12 @@ class DepthWiseSeperableConv1d(nn.Module): def __init__( self, - input_dim, - depthwise_seperable_out_channel, - kernel_size, - depthwise_multiplier, - padding=0, - ): + input_dim: int, + depthwise_seperable_out_channel: int, + kernel_size: int, + depthwise_multiplier: int, + padding: int = 0, + ) -> None: super().__init__() self.dw_conv = nn.Conv1d( @@ -301,12 +309,11 @@ class DepthWiseSeperableConv1d(nn.Module): self.pw_conv = nn.Identity() self.depthwise_seperable_out_channel = depthwise_seperable_out_channel - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ x = self.dw_conv(x) if self.depthwise_seperable_out_channel != 0: @@ -326,7 +333,7 @@ class ConvModule(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. @@ -375,23 +382,23 @@ class ConvModule(nn.Module): def __init__( self, - input_dim, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal=False, - batch_norm=False, - chunk_se=0, - chunk_size=18, - activation="relu", - glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - export=False, - ): + input_dim: int, + ext_pw_out_channel: int, + depthwise_seperable_out_channel: int, + ext_pw_kernel_size: int, + kernel_size: int, + depthwise_multiplier: int, + dropout_rate: float, + causal: bool = False, + batch_norm: bool = False, + chunk_se: int = 0, + chunk_size: int = 18, + activation: str = "relu", + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + export: bool = False, + ) -> None: super().__init__() self.layer_norm = nn.LayerNorm(input_dim) self.input_dim = input_dim @@ -430,21 +437,20 @@ class ConvModule(nn.Module): if depthwise_seperable_out_channel != 0: if input_dim != depthwise_seperable_out_channel: - self.ln2 = nn.Linear(depthwise_seperable_out_channel, - input_dim) + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) else: if depthwise_multiplier != 1: - self.ln2 = nn.Linear(input_dim * depthwise_multiplier, - input_dim) + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) - def _add_ext_pw_layer(self): + def _add_ext_pw_layer(self) -> None: """ This function is an extension of __init__ function and dedicated to the convolution module creation of the conformer. """ self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( - nn.Identity()) # jit hacks. + nn.Identity() + ) # jit hacks. self.squeeze_excitation = nn.Identity() # jit. self.apply_ln1 = self.fix_len1 = False # jit. @@ -497,19 +503,18 @@ class ConvModule(nn.Module): self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ConvModule Forward. Args: - x: torch.Tensor - input tensor. + x: input tensor. 
""" x = self.layer_norm(x) if self.ext_pw_out_channel != 0: x = self.glu(x) if self.causal and self.ext_pw_kernel_size > 1: - x = x[:, :-(self.ext_pw_kernel_size - 1), :] + x = x[:, : -(self.ext_pw_kernel_size - 1), :] if self.apply_ln1: x = self.ln1(x) else: @@ -521,7 +526,7 @@ class ConvModule(nn.Module): x = self.dw_sep_conv_1d(x) if self.causal and self.kernel_size > 1: - x = x[:, :, :-(self.kernel_size - 1)] + x = x[:, :, : -(self.kernel_size - 1)] if hasattr(self, "ln2"): x = x.permute([0, 2, 1]) x = self.ln2(x) @@ -533,7 +538,7 @@ class ConvModule(nn.Module): if self.ext_pw_out_channel != 0: x = self.ext_pw_conv_1d(x) if self.fix_len1: - x = x[:, :, :-(self.ext_pw_kernel_size - 1)] + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] if self.apply_ln1: x = x.permute([0, 2, 1]) @@ -567,21 +572,20 @@ class GLULinear(nn.Module): def __init__( self, - input_dim, - output_dim, - glu_type="sigmoid", - bias_in_glu=True, - ): + input_dim: int, + output_dim: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) self.glu_act = GLU(-1, glu_type) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """GLULinear forward Args: - x: torch.Tensor - inpute tensor. + x: input tensor. """ x = self.linear(x) return self.glu_act(x) @@ -609,12 +613,12 @@ class FeedForward(nn.Module): def __init__( self, - d_model, - d_inner, - dropout_rate, - activation="sigmoid", - bias_in_glu=True, - ): + d_model: int, + d_inner: int, + dropout_rate: float, + activation: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.d_model = d_model self.d_inner = d_inner @@ -628,12 +632,11 @@ class FeedForward(nn.Module): nn.Dropout(dropout_rate), ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """FeedForward forward function. Args: - x: torch.Tensor - input tensor. + x: input tensor. """ out = self.net(self.layer_norm(x)) @@ -642,19 +645,19 @@ class FeedForward(nn.Module): #### positional encoding starts here def _pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], +) -> None: """Perform pre-hook in load_state_dict for backward compatibility. Note: We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. """ @@ -665,7 +668,7 @@ def _pre_hook( class T5RelativeAttentionLogitBias(nn.Module): """ - This module implements the relative position bias described in Section + This module implements the relative position bias described in Section 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf The Huggingface implementation is used as a reference @@ -673,18 +676,18 @@ class T5RelativeAttentionLogitBias(nn.Module): transformers/models/t5/modeling_t5.py#L435 Modifies attention as Q*K^T + B, where B is a learned scalar bias based - on relative position of the query and key. It is HxNxN, where H is the + on relative position of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. I've made these modifications to the original T5 bias: - - Skipping of the bucketing step. Original T5 bias converted rel - position distances into logarithmically increasing buckets. 
This is + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is supposed to help with length generalization. - - I just directly use rel position index as bias values, as we don't - need length generalization (40s max is good enough for ASR encoder), + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. - - I've also extended it so that biases can be asymmetric, the default - implementation treats L->R and R->L the same. Asymmetric was found to + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to yield better results in my experiments. Args: @@ -692,26 +695,28 @@ class T5RelativeAttentionLogitBias(nn.Module): Number of attention heads num_buckets: int Number of buckets to use for relative attention bias. This is the - size of the learnable bias parameter. Bucketing is not yet + size of the learnable bias parameter. Bucketing is not yet supported, so this defaults to -1 which means no bucketing is used (max_distance determines size of bias param). max_distance: int - Maximum distance to use for relative attention bias. With - num_buckets=-1, this directly controls the max size of the bias - parameter. When num_buckets > 0 is supported, this will control - the maximum distance for logarithmic bucketing after which all + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all positions are in the same bucket. symmetric: bool Whether to use symmetric or asymmetric biases. symmetric=False uses - 2x number of bias params to distinguish L->R from R->L. This was + 2x number of bias params to distinguish L->R from R->L. This was found to be better for the encoder. 
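The bucket-free, clamped bias described in this docstring can be sketched standalone; the sizes below are arbitrary, and the asymmetric case is shown (bias table doubled, indices shifted by max_distance):

    import torch
    import torch.nn as nn

    num_heads, max_distance, seq_len = 4, 8, 6
    bias_values = nn.Embedding(2 * max_distance, num_heads)     # separate entries for L->R and R->L

    pos = torch.arange(seq_len)
    relative_position = pos[None, :] - pos[:, None]             # memory - context
    relative_position = relative_position.clamp(-max_distance, max_distance - 1)
    bias_idx = relative_position + max_distance                 # shift into [0, 2*max_distance)
    bias = bias_values(bias_idx).permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]
    scores = torch.randn(1, num_heads, seq_len, seq_len) + bias # stands in for Q @ K^T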
""" - def __init__(self, - num_heads, - num_buckets=-1, - max_distance=1000, - symmetric=False): + def __init__( + self, + num_heads: int, + num_buckets: int = -1, + max_distance: int = 1000, + symmetric: bool = False, + ) -> None: super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets @@ -722,27 +727,30 @@ class T5RelativeAttentionLogitBias(nn.Module): self.num_buckets = max_distance else: raise NotImplementedError( - "T5 attention bias with bucketed positions is not yet tested") + "T5 attention bias with bucketed positions is not yet tested" + ) if not self.symmetric: self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: # instantiate bias compatible with shape of x maxpos = x.size(1) - context_position = torch.arange(maxpos, - device=x.device, - dtype=torch.long)[:, None] - memory_position = torch.arange(maxpos, - device=x.device, - dtype=torch.long)[None, :] + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + :, None + ] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + None, : + ] relative_position = memory_position - context_position # clipping to a maximum distance using ops that play well with ONNX # export relative_position = relative_position.masked_fill( - relative_position < -self.max_distance, -self.max_distance) + relative_position < -self.max_distance, -self.max_distance + ) relative_position = relative_position.masked_fill( - relative_position > self.max_distance - 1, self.max_distance - 1) + relative_position > self.max_distance - 1, self.max_distance - 1 + ) # mapping from relative position to index in the bias parameter if self._skip_bucketing: @@ -755,12 +763,11 @@ class T5RelativeAttentionLogitBias(nn.Module): bias_idx += self.num_buckets // 2 t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] - t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( - 0) # [1, H, L, L] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] return t5_rel_att_bias - def _bucket_relative_position(self, relative_position): + def _bucket_relative_position(self, relative_position: Tensor) -> Tensor: # this is a placeholder (isn't tested, likely buggy) using HuggingFace # implem as a reference this also needs to be extended to support # asymmetric +/- ve positions @@ -768,11 +775,13 @@ class T5RelativeAttentionLogitBias(nn.Module): if not self.causal: self.num_buckets //= 2 relative_buckets += (relative_position > 0).to( - torch.long) * self.num_buckets + torch.long + ) * self.num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min(relative_position, - torch.zeros_like(relative_position)) + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -782,16 +791,18 @@ class T5RelativeAttentionLogitBias(nn.Module): # The other half of the buckets are for logarithmically bigger bins in # positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / - math.log(self.max_distance / max_exact) * - (self.num_buckets - max_exact)).to(torch.long) + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (self.num_buckets - max_exact) + ).to(torch.long) relative_position_if_large = torch.min( 
relative_position_if_large, torch.full_like(relative_position_if_large, self.num_buckets - 1), ) - relative_buckets += torch.where(is_small, relative_position, - relative_position_if_large) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) return relative_buckets @@ -810,7 +821,7 @@ class AbsolutePositionalEncoding(nn.Module): """ - def __init__(self, d_model, dropout_rate, max_len=5000): + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model @@ -820,11 +831,11 @@ class AbsolutePositionalEncoding(nn.Module): self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self._register_load_state_dict_pre_hook(_pre_hook) - def extend_pe(self, x): + def extend_pe(self, x: torch.Tensor) -> None: """Reset the positional encodings. Args: - x: torch.Tensor + x: input tensor """ if self.pe is not None and self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: @@ -833,26 +844,26 @@ class AbsolutePositionalEncoding(nn.Module): pe = torch.zeros(x.size(1), self.d_model) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) * - -(math.log(10000.0) / self.d_model)) + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Add positional encoding. Args: - x: torch.Tensor - Input tensor. shape is (batch, time, ...) + x: Input tensor. shape is (batch, time, ...) Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + Encoded tensor. Its shape is (batch, time, ...) """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, :x.size(1)] + x = x * self.xscale + self.pe[:, : x.size(1)] return self.dropout(x) @@ -868,7 +879,7 @@ class MeanVarianceNormLayer(nn.Module): layer input size. """ - def __init__(self, input_size): + def __init__(self, input_size: int) -> None: super().__init__() self.input_size = input_size self.global_mean = nn.Parameter(torch.zeros(input_size)) @@ -878,8 +889,7 @@ class MeanVarianceNormLayer(nn.Module): """MeanVarianceNormLayer Forward Args: - input_: torch.Tensor - input tensor. + input_: input tensor. """ return (input_ - self.global_mean) * self.global_invstd @@ -890,14 +900,14 @@ class CausalConv1D(nn.Conv1d): locations on its right or left All arguments are the same as nn.Conv1d except padding. - If padding is set None, then paddings are set automatically to make it a + If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. - If padding is set as a list (size of 2), then padding[0] would be used as + If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. It would make it possible to control the number of steps to be accessible on the right and left. - This mode is not supported when stride > 1. padding[0]+padding[1] should + This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). 
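A small functional sketch of the padding=None rule spelled out above (left pad = kernel_size - 1, right pad = stride - 1), using a plain Conv1d for clarity:

    import torch
    import torch.nn.functional as F

    channels, time, kernel_size, stride = 8, 16, 3, 1
    x = torch.randn(1, channels, time)
    conv = torch.nn.Conv1d(channels, channels, kernel_size, stride=stride)

    # With all padding on the left, no output step can see frames to its right.
    y = conv(F.pad(x, (kernel_size - 1, stride - 1)))
    assert y.shape[-1] == time  # length preserved when stride == 1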
""" @@ -921,13 +931,15 @@ class CausalConv1D(nn.Conv1d): self._right_padding = stride - 1 else: if stride != 1 and padding != kernel_size - 1: - raise ValueError( - "No striding allowed for non-symmetric convolutions!") + raise ValueError("No striding allowed for non-symmetric convolutions!") if isinstance(padding, int): self._left_padding = padding self._right_padding = padding - elif (isinstance(padding, list) and len(padding) == 2 - and padding[0] + padding[1] == kernel_size - 1): + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): self._left_padding = padding[0] self._right_padding = padding[1] else: @@ -949,7 +961,9 @@ class CausalConv1D(nn.Conv1d): dtype=dtype, ) - def update_cache(self, x, cache=None): + def update_cache( + self, x: Tensor, cache: Optional[Tensor] = None + ) -> tuple[Tensor, Optional[Tensor]]: if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache @@ -957,13 +971,15 @@ class CausalConv1D(nn.Conv1d): new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: - next_cache = new_x[:, :, :-self.cache_drop_size] + next_cache = new_x[:, :, : -self.cache_drop_size] else: next_cache = new_x - next_cache = next_cache[:, :, -cache.size(-1):] + next_cache = next_cache[:, :, -cache.size(-1) :] return new_x, next_cache - def forward(self, x, cache=None): + def forward( + self, x: Tensor, cache: Optional[Tensor] = None + ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]: x, cache = self.update_cache(x, cache=cache) x = super().forward(x) if cache is None: @@ -976,7 +992,7 @@ class CausalConv2D(nn.Conv2d): """ A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down - All arguments are the same as nn.Conv2d except padding which should be + All arguments are the same as nn.Conv2d except padding which should be set as None """ @@ -995,8 +1011,7 @@ class CausalConv2D(nn.Conv2d): dtype=None, ) -> None: if padding is not None: - raise ValueError( - "Argument padding should be set to None for CausalConv2D.") + raise ValueError("Argument padding should be set to None for CausalConv2D.") self._left_padding = kernel_size - 1 self._right_padding = stride - 1 @@ -1017,8 +1032,8 @@ class CausalConv2D(nn.Conv2d): def forward( self, - x, - ): + x: Tensor, + ) -> Tensor: x = F.pad( x, pad=(self._left_padding, self._right_padding, 0, 0), @@ -1032,17 +1047,17 @@ class NemoConvSubsampling(torch.nn.Module): (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) - Striding Subsampling: "Speech-Transformer: A No-Recurrence - Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) - Compared with the EncoderConv2D (`input_layer: custom`), this is a + Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. 
- `Striding` and `dw_striding` are the same except that the latter uses + `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions after the first layer, whereas the former does not. Args: @@ -1050,11 +1065,11 @@ class NemoConvSubsampling(torch.nn.Module): feat_in (int): size of the input features feat_out (int): size of the output features subsampling (str): The subsampling technique, choose from - {"striding", "dw-striding", "striding_conv1d", + {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} - conv_channels (int): Number of channels for the convolution layers, + conv_channels (int): Number of channels for the convolution layers, default is 256. - subsampling_conv_chunking_factor (int): Input chunking factor which + subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 activation (Module): activation function, default is nn.ReLU() is_causal (bool): whether to use causal Conv1/2D, where each step will @@ -1062,16 +1077,16 @@ class NemoConvSubsampling(torch.nn.Module): """ def __init__( - self, - feat_in, - feat_out, - subsampling_factor=4, - subsampling="dw_striding", - conv_channels=256, - subsampling_conv_chunking_factor=1, - activation=nn.ReLU(), # noqa: B008 - is_causal=False, - ): + self, + feat_in: int, + feat_out: int, + subsampling_factor: int = 4, + subsampling: str = "dw_striding", + conv_channels: int = 256, + subsampling_conv_chunking_factor: int = 1, + activation: torch.nn.Module = nn.ReLU(), # noqa: B008 + is_causal: bool = False, + ) -> None: super().__init__() self._subsampling = subsampling self._conv_channels = conv_channels @@ -1089,15 +1104,15 @@ class NemoConvSubsampling(torch.nn.Module): "striding_conv1d", ) - if (subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a "\ - "power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" ) - self.subsampling_conv_chunking_factor = \ - subsampling_conv_chunking_factor + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor in_channels = 1 layers = [] @@ -1125,7 +1140,8 @@ class NemoConvSubsampling(torch.nn.Module): kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1134,7 +1150,8 @@ class NemoConvSubsampling(torch.nn.Module): kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) in_channels = conv_channels layers.append(activation) @@ -1148,7 +1165,8 @@ class NemoConvSubsampling(torch.nn.Module): stride=self._stride, padding=None, groups=in_channels, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1158,7 +1176,8 @@ class NemoConvSubsampling(torch.nn.Module): stride=self._stride, padding=self._left_padding, groups=in_channels, - )) + ) + ) layers.append( torch.nn.Conv2d( @@ -1168,7 +1187,8 @@ class NemoConvSubsampling(torch.nn.Module): stride=1, padding=0, groups=1, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1195,7 +1215,8 @@ class NemoConvSubsampling(torch.nn.Module): kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1204,7 +1225,8 @@ class 
NemoConvSubsampling(torch.nn.Module): kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1229,22 +1251,30 @@ class NemoConvSubsampling(torch.nn.Module): layers.append( CausalConv1D( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 1 else conv_channels), + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv1d( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 1 else conv_channels), + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1259,30 +1289,8 @@ class NemoConvSubsampling(torch.nn.Module): self._right_padding = (self._kernel_size - 1) // 2 # Layer 1 - layers.extend([ - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ), - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == 1 else - conv_channels), - kernel_size=1, - stride=1, - padding=0, - groups=1, - ), - ]) - in_channels = conv_channels - layers.append(activation) - - for i in range(self._sampling_num - 1): - layers.extend([ + layers.extend( + [ torch.nn.Conv1d( in_channels=in_channels, out_channels=in_channels, @@ -1293,14 +1301,44 @@ class NemoConvSubsampling(torch.nn.Module): ), torch.nn.Conv1d( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 2 else conv_channels), + out_channels=( + feat_out if self._sampling_num == 1 else conv_channels + ), kernel_size=1, stride=1, padding=0, groups=1, ), - ]) + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 2 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) layers.append(activation) in_channels = conv_channels @@ -1317,8 +1355,7 @@ class NemoConvSubsampling(torch.nn.Module): ceil_mode=self._ceil_mode, repeat_num=self._sampling_num, ) - self.out = torch.nn.Linear(conv_channels * int(out_length), - feat_out) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) self.conv2d_subsampling = True elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: self.out = None @@ -1328,43 +1365,39 @@ class NemoConvSubsampling(torch.nn.Module): self.conv = torch.nn.Sequential(*layers) - def get_sampling_frames(self): + def get_sampling_frames(self) -> list[int]: return [1, self.subsampling_factor] - def get_streaming_cache_size(self): + def get_streaming_cache_size(self) -> list[int]: return [0, self.subsampling_factor + 1] - def forward(self, x, mask): + def forward( + self, x: Tensor, mask: Optional[Tensor] + ) -> tuple[Tensor, Optional[Tensor]]: """ Forward method for NeMo subsampling. 
Args: - x[Batch, Time, Filters]: torch.Tensor - input tensor - x_mask: torch.Tensor - input mask + x: input tensor + mask: input mask Returns: - x: torch.Tensor - Resulting tensor from subsampling (B, T // + x: Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) - pad_mask: torch.Tensor - tensor of padded hidden state sequences (B, 1, T // + pad_mask: tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) """ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) # split inputs if chunking_factor is set - if (self.subsampling_conv_chunking_factor != -1 - and self.conv2d_subsampling): + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: if self.subsampling_conv_chunking_factor == 1: # if subsampling_conv_chunking_factor is 1, we split only # if needed. # avoiding a bug / feature limiting indexing of tensors # to 2**31. # see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = (2**31 / self._conv_channels * self._stride * - self._stride) + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride need_to_split = torch.numel(x) > x_ceil else: # if subsampling_conv_chunking_factor > 1 we always split @@ -1400,40 +1433,36 @@ class NemoConvSubsampling(torch.nn.Module): feature_lens_remainder = feature_lens % self.subsampling_factor padding_length[feature_lens_remainder != 1] += 1 pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( - padding_length.size(0), -1) < padding_length.unsqueeze(1) + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) return x, pad_mask.unsqueeze(1) - def reset_parameters(self): + def reset_parameters(self) -> None: # initialize weights if self._subsampling == "dw_striding": with torch.no_grad(): # init conv scale = 1.0 / self._kernel_size - dw_max = (self._kernel_size**2)**-0.5 + dw_max = (self._kernel_size**2) ** -0.5 pw_max = self._conv_channels**-0.5 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) for idx in range(2, len(self.conv), 3): - torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, - dw_max) - torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, - dw_max) - torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, - pw_max) - torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, - pw_max) + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ # src/models/conformer_encoder.py#L487 - fc_scale = (self._feat_out * self._feat_in / - self._sampling_num)**-0.5 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) - def conv_split_by_batch(self, x): + def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]: """Tries to split input by batch, run conv and concat results""" b, _, _, _ = x.size() if b == 1: # can't split if batch size is 1 @@ -1453,15 +1482,14 @@ class NemoConvSubsampling(torch.nn.Module): return x, False return ( - torch.cat([ - self.conv(chunk) - for chunk in torch.split(x, new_batch_size, 0) - ]), + torch.cat( + [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)] + 
), True, ) - def conv_split_by_channel(self, x): - """For dw convs, tries to split input by time, run conv and concat + def conv_split_by_channel(self, x: Tensor) -> Tensor: + """For dw convs, tries to split input by time, run conv and concat results""" x = self.conv[0](x) # full conv2D x = self.conv[1](x) # activation @@ -1486,21 +1514,21 @@ class NemoConvSubsampling(torch.nn.Module): if new_t == 0: new_t = 1 - x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, - x) # conv2D, depthwise + x = self.channel_chunked_conv( + self.conv[i * 3 + 2], new_c, x + ) # conv2D, depthwise # splitting pointwise convs by time x = torch.cat( - [ - self.conv[i * 3 + 3](chunk) - for chunk in torch.split(x, new_t, 2) - ], + [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2, ) # conv2D, pointwise x = self.conv[i * 3 + 4](x) # activation return x - def channel_chunked_conv(self, conv, chunk_size, x): + def channel_chunked_conv( + self, conv: torch.nn.Module, chunk_size: int, x: Tensor + ) -> Tensor: """Performs channel chunked convolution""" ind = 0 @@ -1520,8 +1548,8 @@ class NemoConvSubsampling(torch.nn.Module): ) ch_out = nn.functional.conv2d( chunk, - conv.weight[ind:ind + step, :, :, :], - bias=conv.bias[ind:ind + step], + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], stride=self._stride, padding=0, groups=step, @@ -1529,8 +1557,8 @@ class NemoConvSubsampling(torch.nn.Module): else: ch_out = nn.functional.conv2d( chunk, - conv.weight[ind:ind + step, :, :, :], - bias=conv.bias[ind:ind + step], + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], stride=self._stride, padding=self._left_padding, groups=step, @@ -1541,30 +1569,33 @@ class NemoConvSubsampling(torch.nn.Module): return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor( - self, subsampling_conv_chunking_factor: int): - if (subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0): + self, subsampling_conv_chunking_factor: int + ) -> None: + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a "\ - "power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" ) self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor -def calc_length(lengths, - all_paddings, - kernel_size, - stride, - ceil_mode, - repeat_num=1): +def calc_length( + lengths: Tensor, + all_paddings: int, + kernel_size: int, + stride: int, + ceil_mode: bool, + repeat_num: int = 1, +) -> Tensor: """Calculates the output length of a Tensor passed through a convolution or - max pooling layer""" + max pooling layer""" add_pad: float = all_paddings - kernel_size one: float = 1.0 for i in range(repeat_num): - lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + - one) + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) return lengths.to(dtype=torch.int) @@ -1573,11 +1604,11 @@ def calc_length(lengths, class AttModule(nn.Module): """Attention abstraction module""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.export_mode = False - def set_export(self, mode=True): + def set_export(self, mode: bool = True) -> None: """set the export mode""" self.export_mode = mode @@ -1591,14 
+1622,10 @@ class AttModule(nn.Module): """AttModule forward Args: - x: torch.Tensor - input tensor. - memory: torch.Tensor, optional - memory tensor. - pos_emb: torch.Tensor, optional - positional encoder embedding. - att_mask: torch.Tensor, optional - attention mask tensor. + x: input tensor. + memory: memory tensor. + pos_emb: positional encoder embedding. + att_mask: attention mask tensor. """ return x, memory, pos_emb, att_mask @@ -1606,27 +1633,28 @@ class AttModule(nn.Module): class AttBlock(BlockBase, AttModule): """Attention Block module to support both Attention and Block module.""" - def memory_dims(self, max_len=False): + def memory_dims(self, max_len: bool = False) -> tuple[int, int]: """memory dimensions""" return (1, self.input_size) def masked_softmax( - scores, + scores: Tensor, mask: Optional[Tensor], -): +) -> Tensor: if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0) # (batch, head, time1, time2) + mask, 0.0 + ) # (batch, head, time1, time2) else: attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) return attn class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer with optional relative position embedding + """Multi-Head Attention layer with optional relative position embedding and GLU. Args: @@ -1636,22 +1664,18 @@ class MultiHeadedAttention(nn.Module): input size features. dropout_rate: float dropout rate. - use_LN: bool - apply layer norm or not - dropout_at_output: bool - whether to apply dropout at output attention_inner_dim: int, optional the attention dimension used in the class, it can be different from the input dimension n_feat. default: -1 (equal to n_feat). use_pt_scaled_dot_product_attention: bool, optional if set True, use pytorch scaled dot product attention in training. - NOTE: this will NOT be used in ONNX decoding due to a lack of - support. In that case, we use the original attention + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention implementation, which shows no regression. default: False. n_value: int, optional - if set to values other than -1, use a different dimension for + if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. group_size: int, optional. 
must divide `n_head` if group_size > 1: GQA @@ -1666,16 +1690,16 @@ class MultiHeadedAttention(nn.Module): def __init__( self, - n_head, - n_feat, - dropout_rate, - attention_inner_dim=-1, - glu_type="swish", - bias_in_glu=True, - use_pt_scaled_dot_product_attention=False, - n_value=-1, + n_head: int, + n_feat: int, + dropout_rate: float, + attention_inner_dim: int = -1, + glu_type: str = "swish", + bias_in_glu: bool = True, + use_pt_scaled_dot_product_attention: bool = False, + n_value: int = -1, group_size: int = 1, - ): + ) -> None: super().__init__() if n_value == -1: n_value = n_feat @@ -1699,8 +1723,7 @@ class MultiHeadedAttention(nn.Module): self.attn = torch.jit.Attribute(None, Optional[Tensor]) self.dropout = nn.Dropout(p=dropout_rate) self.dropout_rate = dropout_rate - self.use_pt_scaled_dot_product_attention = ( - use_pt_scaled_dot_product_attention) + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention if use_pt_scaled_dot_product_attention and group_size > 1: raise ValueError("Cannot use PT Scaled Attention with GQA") @@ -1718,45 +1741,38 @@ class MultiHeadedAttention(nn.Module): query: Tensor, key: Tensor, value: Tensor, - pos_k: Tensor, - pos_v: Tensor, + pos_k: Optional[Tensor], + pos_v: Optional[Tensor], mask: Optional[Tensor], relative_attention_bias: Optional[Tensor] = None, - ): + ) -> Tensor: """Compute 'Scaled Dot Product Attention'. Args: - query: torch.Tensor - query tensor (batch, time1, size) - key: torch.Tensor - key tensor (batch, time2, size) - value: torch.Tensor - value tensor (batch, time1, size) - pos_k: torch.Tensor - key tensor used for relative positional embedding. - pos_v: torch.Tensor - value tensor used for relative positional embedding. - mask: torch.Tensor - mask tensor (batch, time1, time2) - relative_attention_bias: torch.Tensor - bias added to attention logits w.r.t. relative positions + query: query tensor (batch, time1, size) + key: key tensor (batch, time2, size) + value: value tensor (batch, time1, size) + pos_k: key tensor used for relative positional embedding. + pos_v: value tensor used for relative positional embedding. + mask: mask tensor (batch, time1, time2) + relative_attention_bias: bias added to attention logits w.r.t. 
+ relative positions (1, n_head, time1, time2) """ n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, - self.d_k) # (b, t, d) - k = self.linear_k(key).view(n_batch, -1, self.h_k, - self.d_k) # (b, t, d) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) - q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting() else q.transpose(1, 2) * - self.inv_sqrt_d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) k = k.transpose(1, 2) # (batch, head_k, time2, d_k) v = v.transpose(1, 2) # (batch, head_k, time2, d_k) - if (self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting()): + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): attn_mask = None if mask is not None: mask = mask.unsqueeze(1) @@ -1767,12 +1783,14 @@ class MultiHeadedAttention(nn.Module): if mask.dtype != q.dtype: attn_mask = attn_mask.to(q.dtype) - with torch.nn.attention.sdpa_kernel([ + with torch.nn.attention.sdpa_kernel( + [ torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, torch.nn.attention.SDPBackend.MATH, torch.nn.attention.SDPBackend.CUDNN_ATTENTION, - ]): + ] + ): x = torch.nn.functional.scaled_dot_product_attention( q, k, @@ -1790,14 +1808,17 @@ class MultiHeadedAttention(nn.Module): if self.h != self.h_k: B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) else: - reshape_q = (q.contiguous().view(n_batch * self.h, -1, - self.d_k).transpose(0, 1) - ) # (t1,nh,dk) - B = torch.matmul(reshape_q, - pos_k.transpose(-2, - -1)) # pos_k: (t1,dk,t2) - B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), - pos_k.size(1)) + reshape_q = ( + q.contiguous() + .view(n_batch * self.h, -1, self.d_k) + .transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul( + reshape_q, pos_k.transpose(-2, -1) + ) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view( + n_batch, self.h, pos_k.size(0), pos_k.size(1) + ) scores = A + B else: scores = A @@ -1810,20 +1831,24 @@ class MultiHeadedAttention(nn.Module): self.attn = attn p_attn = self.dropout(attn) - x = torch.matmul(p_attn.to(v.dtype), - v) # (batch, head, time1, d_k) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) if pos_v is not None: - reshape_attn = (p_attn.contiguous().view( - n_batch * self.h, pos_v.size(0), - pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2) + reshape_attn = ( + p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) - attn_v = (torch.matmul(reshape_attn, pos_v).transpose( - 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0), - self.d_k)) + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) x = x + attn_v - x = (x.transpose(1, 2).contiguous().view(n_batch, -1, - self.h_k * self.d_k) - ) # (batch, time1, d_model) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) + ) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) @@ -1832,39 +1857,40 @@ class MultiSequential(torch.nn.Sequential): """Multi-input multi-output torch.nn.Sequential""" @torch.jit.ignore - def forward(self, *args): + def forward(self, *args) -> tuple: """Forward 
method implementation.""" for m in self: args = m(*args) return args -def get_offset(input_layer: str, time_reduction: int): - """Get an offset. We will use the offset for determining #frames of a +def get_offset(input_layer: str, time_reduction: int) -> int: + """Get an offset. We will use the offset for determining #frames of a subsampled feature. Args: - input_layer (str): Type of an input layer - time_reduction (int): time reduction factor for downsampling a feature + input_layer: Type of an input layer + time_reduction: time reduction factor for downsampling a feature Returns: int: offset """ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: return 3 - if input_layer in ("conv2d", ) and time_reduction == 6: + if input_layer in ("conv2d",) and time_reduction == 6: return 1 if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: return 7 return 0 -def unfold_tensor(xs_pad, max_seq_len): +def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor: """ - For a given tensor with shape of (N, T, D), if sequence length T is - longer than max_seq_len, this function unfold it to a + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: - xs_pad: N, T, D + xs_pad: input tensor with shape (N, T, D) + max_seq_len: maximum sequence length """ _, _, D = xs_pad.shape xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index cfe0982204fa9..fee52edfe26c8 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -23,7 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only PhiMoE model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -35,28 +37,36 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PhiMoEConfig(PretrainedConfig): - model_type = "phimoe" keys_to_ignore_at_inference = ["past_key_values"] @@ -129,7 +139,6 @@ class PhiMoEConfig(PretrainedConfig): class mp(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -174,8 +183,9 @@ def sparsemixer(scores, jitter_eps=0.01): # compute mask for sparsity mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) - mask_logits_threshold = ((mask_logits_threshold - scores) / - factor) > (2 * jitter_eps) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) # apply mask masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) @@ -196,24 +206,21 @@ def sparsemixer(scores, jitter_eps=0.01): ) with torch.no_grad(): # compute mask for sparsity - mask_logits_threshold, max_ind = masked_scores.max(dim=-1, - keepdim=True) + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) - mask_logits_threshold = ((mask_logits_threshold - scores) / - factor) > (2 * jitter_eps) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) # apply mask - masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, - float("-inf")) + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf")) selected_experts_top2 = max_ind # compute scores for gradients masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) - multiplier_top2 = masked_gates_top2.gather(dim=-1, - index=selected_experts_top2) + multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2) multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) - 
selected_experts = torch.concat((selected_experts, selected_experts_top2), - dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1) return ( multiplier, @@ -227,8 +234,7 @@ def phimoe_routing_function( topk: int, renormalize: bool, ): - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert topk == 2, "Only top-2 routing is supported" assert renormalize is False, "Renormalization is not supported" @@ -279,7 +285,8 @@ class PhiMoE(nn.Module): quant_config=quant_config, tp_size=tp_size, custom_routing_function=phimoe_routing_function, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -292,7 +299,6 @@ class PhiMoE(nn.Module): class PhiMoEAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -377,7 +383,6 @@ class PhiMoEAttention(nn.Module): class PhiMoEDecoderLayer(nn.Module): - def __init__( self, config: PhiMoEConfig, @@ -394,8 +399,9 @@ class PhiMoEDecoderLayer(nn.Module): num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, - head_dim=getattr(config, "head_dim", - self.hidden_size // config.num_attention_heads), + head_dim=getattr( + config, "head_dim", self.hidden_size // config.num_attention_heads + ), rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, @@ -410,12 +416,12 @@ class PhiMoEDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) def forward( self, @@ -445,7 +451,6 @@ class PhiMoEDecoderLayer(nn.Module): @support_torch_compile class PhiMoEModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -454,8 +459,11 @@ class PhiMoEModel(nn.Module): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.config = config @@ -469,15 +477,17 @@ class PhiMoEModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: PhiMoEDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + 
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -500,7 +510,7 @@ class PhiMoEModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -508,10 +518,9 @@ class PhiMoEModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states @@ -524,8 +533,7 @@ class PhiMoEModel(nn.Module): num_experts=self.config.num_local_experts, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -537,14 +545,15 @@ class PhiMoEModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -595,8 +604,9 @@ class PhiMoEModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -628,8 +638,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = vllm_config.quant_config - self.model = PhiMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = PhiMoEModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -641,15 +652,20 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=None, bias=True, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) self.make_empty_intermediate_tensors = ( - 
self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -661,18 +677,16 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c01074e2122bb..62f642eae4b52 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,59 +5,77 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image -from transformers import PixtralVisionConfig, TensorType +from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( - _num_image_tokens as _get_pixtral_hf_num_image_tokens) + _num_image_tokens as _get_pixtral_hf_num_image_tokens, +) from transformers.models.pixtral.modeling_pixtral import ( - PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) + PixtralRotaryEmbedding, + apply_rotary_pos_emb, + position_ids_in_meshgrid, +) from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalUUIDDict, + NestedTensors, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import (MistralTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import ( + MistralTokenizer, + cached_tokenizer_from_config, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) -from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +from .utils import init_vllm_registered_model, maybe_prefix +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + resolve_visual_encoder_outputs, +) try: from xformers import ops as xops - if (current_platform.is_cuda() - and current_platform.has_device_capability(100)): + + if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: @@ -68,16 +86,24 @@ except ImportError: PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] + class PixtralProcessorAdapter: """ @@ -144,7 +170,8 @@ class PixtralProcessorAdapter: "Make sure to process your input via `mistral_common`'s " "tokenizer or pass a chat completion request. " "For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + "https://github.com/vllm-project/vllm/issues/8411." 
+ ) images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() @@ -157,14 +184,15 @@ class PixtralProcessorAdapter: images_processed.append(image_processed) images_tokens.append(image_tokens) - return { - "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), - "images": images_processed, - } + return BatchFeature( + { + "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), + "images": images_processed, + } + ) class PixtralProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): @@ -201,7 +229,8 @@ class PixtralProcessingInfo(BaseProcessingInfo): processor = self.get_hf_processor() ncols, nrows = processor.image_processor._image_to_num_tokens( - Image.new("RGB", (image_width, image_height))) + Image.new("RGB", (image_width, image_height)) + ) return ncols * nrows @@ -213,7 +242,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -221,48 +249,57 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() dummy_text = self.get_dummy_text(mm_counts) - dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_images = dummy_mm_data.get("image", []) tokenization_kwargs = {"truncation": False} - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=dummy_text), - *(ImageChunk(image=image) for image in dummy_images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=dummy_text), + *(ImageChunk(image=image) for image in dummy_images), + ] + ), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens - return ProcessorInputs(prompt=dummy_tokens, - mm_data=dummy_mm_data, - tokenization_kwargs=tokenization_kwargs) + return ProcessorInputs( + prompt=dummy_tokens, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs, + ) -class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] - ): - +class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], @@ -287,7 +324,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] image_size = images.get_image_size(item_idx) ncols, nrows = processor.image_processor._image_to_num_tokens( - Image.new("RGB", 
(image_size.width, image_size.height))) + Image.new("RGB", (image_size.width, image_size.height)) + ) tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id @@ -308,23 +346,27 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template return prompt_ids, mm_info, True -@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, - info=PixtralProcessingInfo, - dummy_inputs=PixtralDummyInputsBuilder) -class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + PixtralMultiModalProcessor, + info=PixtralProcessingInfo, + dummy_inputs=PixtralDummyInputsBuilder, +) +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -359,8 +401,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, self.vision_encoder = VisionTransformer(self.vision_args) if self.vision_args.add_pre_mm_projector_layer_norm: - self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, - eps=1e-5) + self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5) if self.vision_args.mm_projector_id == PATCH_MERGE: self.patch_merger = PatchMerger( @@ -370,24 +411,23 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.vision_language_adapter = VisionLanguageAdapter( - self.vision_args, dim=config.text_config.hidden_size) + self.vision_args, dim=config.text_config.hidden_size + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: + self, **kwargs: object + ) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. 
" - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", - images=flatten_bn(images), + images=images, ) def _process_image_input( @@ -396,23 +436,24 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ) -> tuple[torch.Tensor, ...]: images = image_input["images"] image_features = self.vision_encoder(images) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_features = torch.cat(image_features) if self.vision_args.add_pre_mm_projector_layer_norm: image_features = self.pre_mm_projector_norm(image_features) if self.vision_args.mm_projector_id == PATCH_MERGE: patch_size = self.vision_args.patch_size spatial_merge_size_square = self.vision_args.spatial_merge_size**2 - img_patch_dims = [(img.shape[1] // patch_size, - img.shape[2] // patch_size) for img in images] + img_patch_dims = [ + (img.shape[1] // patch_size, img.shape[2] // patch_size) + for img in images + ] feature_sizes = [ feature_size // spatial_merge_size_square for feature_size in feature_sizes ] - image_features = self.patch_merger(image_features, - image_sizes=img_patch_dims) + image_features = self.patch_merger( + image_features, image_sizes=img_patch_dims + ) image_embeds = self.vision_language_adapter(image_features) image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds @@ -420,30 +461,13 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.vision_args.image_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -456,31 +480,19 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") @@ -495,38 +507,42 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, # Get references to parameters for direct loading vision_encoder_dict = dict(self.vision_encoder.named_parameters()) - patch_merger_dict = dict(self.patch_merger.named_parameters( - )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict() - pre_mm_projector_norm_dict = dict( - self.pre_mm_projector_norm.named_parameters( - )) if self.vision_args.add_pre_mm_projector_layer_norm else dict() - vision_lang_adapter_dict = dict( - self.vision_language_adapter.named_parameters()) + patch_merger_dict = ( + dict(self.patch_merger.named_parameters()) + if self.vision_args.mm_projector_id == PATCH_MERGE + else dict() + ) + pre_mm_projector_norm_dict = ( + dict(self.pre_mm_projector_norm.named_parameters()) + if self.vision_args.add_pre_mm_projector_layer_norm + else dict() + ) + vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters()) def llm_weights_generator(): # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): # Load vision encoder weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = vision_encoder_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): # Load vision patch merger weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = patch_merger_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): # Load vision pre_mm_projector_norm weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = pre_mm_projector_norm_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): # Load vision-language adapter weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = vision_lang_adapter_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) @@ -557,8 +573,7 @@ class VisionEncoderArgs: mm_projector_id: str = "" -def _reshape_for_broadcast(freqs_cis: torch.Tensor, - x: torch.Tensor) -> torch.Tensor: +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) @@ -569,9 +584,7 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor, freqs_cis.shape, (x.shape[1], x.shape[-1]), ) - shape = [ - d if i == 1 or i == ndim 
- 1 else 1 for i, d in enumerate(x.shape) - ] + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) @@ -586,7 +599,7 @@ def precompute_freqs_cis_2d( to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases - freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) @@ -618,26 +631,18 @@ def apply_rotary_emb_vit( class FeedForward(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None - self.w1 = nn.Linear(args.hidden_size, - args.intermediate_size, - bias=False) - self.w2 = nn.Linear(args.intermediate_size, - args.hidden_size, - bias=False) - self.w3 = nn.Linear(args.hidden_size, - args.intermediate_size, - bias=False) + self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) + self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -671,10 +676,7 @@ class Attention(nn.Module): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - out = nn.functional.scaled_dot_product_attention(q, - k, - v, - attn_mask=mask) + out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) @@ -682,7 +684,6 @@ class Attention(nn.Module): class TransformerBlock(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.attention = Attention(args) @@ -696,9 +697,9 @@ class TransformerBlock(nn.Module): mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(x), - mask=mask, - freqs_cis=freqs_cis) + r = self.attention.forward( + self.attention_norm(x), mask=mask, freqs_cis=freqs_cis + ) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -706,7 +707,6 @@ class TransformerBlock(nn.Module): class Transformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() @@ -724,22 +724,26 @@ class Transformer(nn.Module): return x -def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: - positions = torch.cat([ - torch.stack( - torch.meshgrid( - torch.arange(p.shape[-2]), - torch.arange(p.shape[-1]), - indexing="ij", - ), - dim=-1, - ).reshape(-1, 2) for p in patch_embeds_list - ]) +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) return positions class VisionTransformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -801,9 +805,7 @@ class VisionTransformer(nn.Module): self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] - patch_embeds = [ - p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list - ] + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] 
for p in patch_embeds] # flatten to a single sequence @@ -817,13 +819,16 @@ class VisionTransformer(nn.Module): # pass through Transformer with a block diagonal mask delimiting images if USE_XFORMERS_OPS: mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) else: from transformers.models.pixtral.modeling_pixtral import ( - generate_block_attention_mask) + generate_block_attention_mask, + ) + mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], - patch_embeds) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # squeeze dim 0 and split into separate tensors for each image @@ -831,7 +836,6 @@ class VisionTransformer(nn.Module): class VisionLanguageAdapter(nn.Module): - def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) @@ -871,8 +875,9 @@ class PatchMerger(nn.Module): bias=use_mlp_bias, ) - def forward(self, x: torch.Tensor, - image_sizes: list[tuple[int, int]]) -> torch.Tensor: + def forward( + self, x: torch.Tensor, image_sizes: list[tuple[int, int]] + ) -> torch.Tensor: # image_sizes specified in tokens assert sum([h * w for h, w in image_sizes]) == len(x) @@ -904,15 +909,14 @@ class PatchMerger(nn.Module): """ sub_grids = get_sub_grids( - x=x, - image_sizes=image_sizes, - spatial_merge_size=self.spatial_merge_size + x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size ) # list of [d x sub_grid_size x sub_grid_size x n_patches] permuted_tensor: list[torch.Tensor] = [] for grid in sub_grids: n_patches = grid.shape[-1] - permuted_tensor.append(grid.view(-1, n_patches).t( - )) # n_patches x d * sub_grid_size * sub_grid_size + permuted_tensor.append( + grid.view(-1, n_patches).t() + ) # n_patches x d * sub_grid_size * sub_grid_size return torch.cat( permuted_tensor, dim=0 ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) @@ -932,14 +936,15 @@ def get_sub_grids( for image_index, image_tokens in enumerate(x.split(tokens_per_image)): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] - image_grid = image_tokens.view(h, w, d).permute( - 2, 0, 1)[None, :, :, :] # 1 x d x h x w - sub_grids = torch.nn.functional.unfold(image_grid, - kernel_size=sub_grid_size, - stride=sub_grid_size) + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ + None, :, :, : + ] # 1 x d x h x w + sub_grids = torch.nn.functional.unfold( + image_grid, kernel_size=sub_grid_size, stride=sub_grid_size + ) sub_grids = sub_grids.view( - 1, d, sub_grid_size, sub_grid_size, - -1) # 1 x d x sub_grid_size x sub_grid_size x n_patches + 1, d, sub_grid_size, sub_grid_size, -1 + ) # 1 x d x sub_grid_size x sub_grid_size x n_patches all_img_sub_grids.append(sub_grids[0]) @@ -955,7 +960,6 @@ def get_sub_grids( class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): - def get_num_image_tokens( self, *, @@ -1008,7 +1012,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): class PixtralHFMLP(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1024,12 +1027,15 @@ class PixtralHFMLP(nn.Module): output_sizes=[config.intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=config.intermediate_size, - output_size=config.hidden_size, 
- bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_and_mul = get_act_and_mul_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1040,7 +1046,6 @@ class PixtralHFMLP(nn.Module): class PixtralHFAttention(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1096,14 +1101,12 @@ class PixtralHFAttention(nn.Module): # Transpose q and k back for attention q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() - out = xops.memory_efficient_attention(q, - k, - v, - attn_bias=attention_mask) + out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask) else: v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask) + q, k, v, attn_mask=attention_mask + ) out = out.transpose(1, 2) out = out.view(batch, patches, self.n_heads * self.head_dim) @@ -1113,7 +1116,6 @@ class PixtralHFAttention(nn.Module): class PixtralHFTransformerBlock(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1124,12 +1126,12 @@ class PixtralHFTransformerBlock(nn.Module): super().__init__() self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) - self.attention = PixtralHFAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.feed_forward = PixtralHFMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.feed_forward") + self.attention = PixtralHFAttention( + config, quant_config=quant_config, prefix=f"{prefix}.attention" + ) + self.feed_forward = PixtralHFMLP( + config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" + ) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) def forward( @@ -1138,9 +1140,11 @@ class PixtralHFTransformerBlock(nn.Module): attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: - r, _ = self.attention.forward(self.attention_norm(hidden_states), - attention_mask=attention_mask, - position_embeddings=position_embeddings) + r, _ = self.attention.forward( + self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -1148,7 +1152,6 @@ class PixtralHFTransformerBlock(nn.Module): class PixtralHFTransformer(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1164,12 +1167,16 @@ class PixtralHFTransformer(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - PixtralHFTransformerBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + PixtralHFTransformerBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -1192,7 +1199,6 @@ class PixtralHFTransformer(nn.Module): class PixtralHFVisionModel(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1226,7 +1232,8 @@ class PixtralHFVisionModel(nn.Module): raise ValueError( f"The original encoder only has {num_hidden_layers} " f"layers, but you requested {len(self.transformer.layers)} " - "layers.") + "layers." 
+ ) if require_post_norm is True: msg = "PixtralHFVisionModel does not have post-layernorm" @@ -1234,13 +1241,14 @@ class PixtralHFVisionModel(nn.Module): self.dtype = next(self.parameters()).dtype self.device = next(self.parameters()).device - self.patch_positional_embedding = PixtralRotaryEmbedding( - config, self.device) + self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device) def forward( self, pixel_values: list[torch.Tensor], - feature_sample_layers: Optional[list[int]] = None, + *, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, ) -> tuple[torch.Tensor, ...]: """ Args: @@ -1248,7 +1256,7 @@ class PixtralHFVisionModel(nn.Module): in pixel_values. This means it will be a list of tensors because multiple requests batched can have multiple images, each with their own shape potentially - feature_sample_layers: Layer indices whose features should be + select_layers: Layer indices whose features should be concatenated and used as the visual encoder output. If none are provided, the last layer is used. @@ -1258,13 +1266,10 @@ class PixtralHFVisionModel(nn.Module): """ # pass images through initial convolution independently patch_embeds_list = [ - self.patch_conv(img.unsqueeze(0).to(self.dtype)) - for img in pixel_values + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values ] - patch_embeds = [ - p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list - ] + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence @@ -1274,38 +1279,44 @@ class PixtralHFVisionModel(nn.Module): # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, - max_width=self.config.image_size // self.config.patch_size).to( - self.device) - position_embedding = self.patch_positional_embedding( - patch_embeds, position_ids) + max_width=self.config.image_size // self.config.patch_size, + ).to(self.device) + position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) if USE_XFORMERS_OPS: attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) else: from transformers.models.pixtral.modeling_pixtral import ( - generate_block_attention_mask) - attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], - patch_embeds) + generate_block_attention_mask, + ) + + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) - return_all_hidden_states = feature_sample_layers is not None out = self.transformer( patch_embeds, attention_mask, position_embedding, - return_all_hidden_states=return_all_hidden_states) + return_all_hidden_states=select_layers is not None, + ) - out = resolve_visual_encoder_outputs(out, feature_sample_layers, None, - self.config.num_hidden_layers) + out = resolve_visual_encoder_outputs( + out, + None, + select_layers=select_layers, + max_possible_layers=self.config.num_hidden_layers, + feature_select_strategy=feature_select_strategy, + ) # squeeze dim 0 and split into separate tensors for each image return torch.split(out.squeeze(0), embed_sizes) # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, 
weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1325,7 +1336,7 @@ class PixtralHFVisionModel(nn.Module): if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1335,8 +1346,7 @@ class PixtralHFVisionModel(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e5034b536266a..278957e7cf6ce 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,53 +1,71 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" + from collections.abc import Iterable -from typing import Optional +from itertools import islice +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_state_update) + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined) + mamba_chunk_scan_combined_varlen, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import 
get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsPP, SupportsV0Only) -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) + composed_weight_loader, + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.interfaces import HasInnerState, IsHybrid, SupportsPP from vllm.model_executor.models.utils import ( - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, - make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Only used for type hinting. @@ -72,20 +90,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore vocab_size: int -class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore - - def _init_weights(self, module: torch.nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -98,24 +102,21 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # Adapted from: # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # transformers.models.mamba.modeling_mamba.MambaMixer -class Plamo2MambaMixer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - *, - prefix: str = "", - **kwargs) -> None: +@CustomOp.register(name="plamo2_mamba_mixer") +class Plamo2MambaMixer(MambaBase, CustomOp): + def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None: super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.quant_config = vllm_config.quant_config self.hidden_size = self.config.hidden_size self.ssm_state_size = self.config.mamba_d_state self.conv_kernel_size = self.config.mamba_d_conv - self.intermediate_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) + self.intermediate_size = ( + self.config.mamba_num_heads * self.config.hidden_size_per_head + ) self.tp_size = get_tensor_model_parallel_world_size() - self.intermediate_size_per_tp_worker = \ - self.intermediate_size // self.tp_size self.head_dim = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) @@ -165,17 +166,17 @@ class Plamo2MambaMixer(nn.Module): torch.empty( divide(self.num_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) self.D = 
nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) - self.dt_bias = nn.Parameter( - torch.ones(divide(self.num_heads, self.tp_size))) + self.dt_bias = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + sharded_weight_loader(0), lambda x: -torch.exp(x.float()) + ) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( self.intermediate_size, @@ -189,12 +190,21 @@ class Plamo2MambaMixer(nn.Module): # The activation function is fixed to SiLU. self.activation = "silu" - self.dt_norm = RMSNorm(self.time_step_rank, - eps=self.config.rms_norm_eps) - self.B_norm = RMSNorm(self.ssm_state_size, - eps=self.config.rms_norm_eps) - self.C_norm = RMSNorm(self.ssm_state_size, - eps=self.config.rms_norm_eps) + self.dt_norm = RMSNorm(self.time_step_rank, eps=self.config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + + self.chunk_size = self.config.mamba_chunk_size + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) + assert self.chunk_size != -1, "chunk_size must be set for v1" + + self.prefix = prefix def _project_ssm_parameters(self, hidden_states): ssm_parameters = self.bcdt_proj(hidden_states) @@ -211,65 +221,111 @@ class Plamo2MambaMixer(nn.Module): dt = self.dt_proj(time_step) return B, C, dt + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + **kwargs, + ): + pass + def forward( self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, + output: torch.Tensor, **kwargs, - ) -> torch.Tensor: + ): + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) - # mamba2_metadata contains metadata necessary for the mamba2 triton + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + **kwargs, + ): + forward_context = get_forward_context() + # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + attn_metadata: AttentionMetadata = forward_context.attn_metadata - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count - has_prefill = num_prefills > 0 - has_decode = num_decodes > 0 + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet 
contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states_p + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx_p + query_start_loc_p = attn_metadata.query_start_loc_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) gate, hidden_states = projected_states.chunk(2, dim=-1) # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + if attn_metadata is None: + # profile run + hidden_states = ( + hidden_states.transpose(0, 1).clone().transpose(0, 1) + ).contiguous() + output[:] = self.out_proj(hidden_states) + return + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], dim=0, ) - gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], - dim=0) + gate_d, gate_p = torch.split( + gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0 + ) # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - mamba_cache_params.state_indices_tensor, - [num_prefills, num_decodes], + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], dim=0, ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] - if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( [ num_prefill_tokens + num_decodes, - (self.num_heads // self.tp_size) * self.head_dim + (self.num_heads // self.tp_size) * self.head_dim, ], dtype=hidden_states.dtype, device=hidden_states.device, ) - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( preallocated_ssm_out, - [num_prefill_tokens, num_decodes], + [num_decodes, num_prefill_tokens], dim=0, ) @@ -277,16 +333,19 @@ class Plamo2MambaMixer(nn.Module): if has_prefill: # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" + x = hidden_states_p.transpose(0, 1) # this is the form that causal-conv see hidden_states_p = causal_conv1d_fn( - hidden_states_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=mamba2_metadata.has_initial_states, + conv_states=conv_state, + has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - query_start_loc=query_start_loc_p) + metadata=attn_metadata, + query_start_loc=query_start_loc_p, + ) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] # In some instances, the following `bcdt_proj` op @@ -298,72 +357,77 @@ class Plamo2MambaMixer(nn.Module): # 3. State Space Model sequence transformation initial_states = None - if (mamba2_metadata.has_initial_states is not None - and mamba2_metadata.prep_initial_states): + if has_initial_states_p is not None and prep_initial_states: # making a copy of the states initial_states = torch.where( - mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[state_indices_tensor_p], 0) - varlen_state = mamba_chunk_scan_combined( - hidden_states_p.view(1, num_prefill_tokens, - self.num_heads // self.tp_size, - self.head_dim), - dt.unsqueeze(0), + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], + 0, + ) + + varlen_state = mamba_chunk_scan_combined_varlen( + hidden_states_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), + dt, self.A, - B.view(1, num_prefill_tokens, 1, -1), - C.view(1, num_prefill_tokens, 1, -1), - chunk_size=mamba2_metadata.chunk_size, + B.view(num_prefill_tokens, 1, -1), + C.view(num_prefill_tokens, 1, -1), + chunk_size=chunk_size, D=self.D, - z=gate_p.view(1, num_prefill_tokens, - self.num_heads // self.tp_size, self.head_dim), + z=gate_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), dt_bias=self.dt_bias, - seq_idx=mamba2_metadata.seq_idx, - chunk_indices=mamba2_metadata.chunk_indices, - chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + seq_idx=seq_idx_p, + cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, - return_varlen_states=True, - return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, - self.head_dim), + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, ) # update ssm states # - varlen state is a (batch, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_state # Process decode requests if has_decode: # 2. Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + ) B, C, dt = self._project_ssm_parameters(hidden_states_d) # 3. 
State Space Model sequence transformation - A = self.A[:, None, ...][:, :, - None].expand(-1, self.head_dim, - self.config.mamba_d_state) + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.config.mamba_d_state + ) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.unsqueeze(1) C = C.unsqueeze(1) hidden_states_d = hidden_states_d.view( - -1, self.num_heads // self.tp_size, self.head_dim) + -1, self.num_heads // self.tp_size, self.head_dim + ) # - the hidden is reshaped into (bs, num_heads, head_dim) - # - mamba_cache_params.ssm_state's slots will be selected + # - ssm_state's slots will be selected # using state_indices_tensor_d + + # NOTE: final output is an in-place update of out tensor selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states_d, dt, A, @@ -374,18 +438,69 @@ class Plamo2MambaMixer(nn.Module): dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d, - out=preallocated_ssm_out_d.view(num_decodes, -1, - self.head_dim), + out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - assert self.num_heads % self.tp_size == 0 # 4. Final linear projection - out = self.out_proj(preallocated_ssm_out) - return out + output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out) + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=0, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba2" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend + + return Mamba2AttentionBackend + + +def plamo2_mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, output=output) + + +def plamo2_mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="plamo2_mamba_mixer", + op_func=plamo2_mamba_mixer, + mutates_args=["output"], + fake_impl=plamo2_mamba_mixer_fake, +) class DenseMLP(nn.Module): - def __init__( self, config: Plamo2Config, @@ -404,12 +519,14 @@ class DenseMLP(nn.Module): return_bias=False, ) self.act = SiluAndMul() - self.down_proj = RowParallelLinear(self.intermediate_size, - self.hidden_size, - bias=False, - prefix=f"{prefix}.down_proj", - quant_config=quant_config, - return_bias=False) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + prefix=f"{prefix}.down_proj", + quant_config=quant_config, + return_bias=False, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: h = self.gate_up_proj(hidden_states) @@ -417,14 +534,8 @@ class 
DenseMLP(nn.Module): return self.down_proj(h) -@support_torch_compile class Plamo2AttentionMixer(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -457,20 +568,22 @@ class Plamo2AttentionMixer(nn.Module): bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) - self.rope_theta = config.rope_theta if hasattr(config, - "rope_theta") else 10000 - self.rope_scaling = config.rope_scaling if hasattr( - config, "rope_scaling") else None + self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000 + self.rope_scaling = ( + config.rope_scaling if hasattr(config, "rope_scaling") else None + ) max_position = config.max_position_embeddings if hasattr(vllm_config.model_config, "max_model_len") and isinstance( - vllm_config.model_config.max_model_len, int): - max_position = min(max_position, - vllm_config.model_config.max_model_len) + vllm_config.model_config.max_model_len, int + ): + max_position = min(max_position, vllm_config.model_config.max_model_len) self.rotary_emb = get_rope( self.head_dim, @@ -479,22 +592,24 @@ class Plamo2AttentionMixer(nn.Module): base=self.rope_theta, rope_scaling=self.rope_scaling, ) - self.q_norm = RMSNorm(config.hidden_size_per_head, - eps=config.rms_norm_eps) + self.q_norm = RMSNorm(config.hidden_size_per_head, eps=config.rms_norm_eps) self.q_norm.weight = torch.nn.Parameter( - torch.ones((self.num_heads, config.hidden_size_per_head))) - set_weight_attrs(self.q_norm.weight, - {"weight_loader": sharded_weight_loader(0)}) - self.k_norm = RMSNorm(config.hidden_size_per_head, - eps=config.rms_norm_eps) + torch.ones((self.num_heads, config.hidden_size_per_head)) + ) + set_weight_attrs( + self.q_norm.weight, {"weight_loader": sharded_weight_loader(0)} + ) + self.k_norm = RMSNorm(config.hidden_size_per_head, eps=config.rms_norm_eps) self.k_norm.weight = torch.nn.Parameter( - torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + torch.ones((self.num_kv_heads, config.hidden_size_per_head)) + ) # Tensor-parallelism shards the K norm weights to the tp ranks # in a head-wise manner. This approach does not work if there is only # a single KV head, as is the case for PLaMo 2-1B. 
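        # (Editor's illustrative note, not part of the original change: with a
        # single KV head the head-wise sharding cannot be applied, so the
        # k_norm weight stays replicated on every tensor-parallel rank and the
        # default weight loader is used instead of sharded_weight_loader.)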
if self.total_num_kv_heads != 1: - set_weight_attrs(self.k_norm.weight, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs( + self.k_norm.weight, {"weight_loader": sharded_weight_loader(0)} + ) self.attn = Attention( self.num_heads, @@ -528,58 +643,60 @@ class Plamo2AttentionMixer(nn.Module): class Plamo2DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - layer_idx: int, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, vllm_config: VllmConfig, layer_idx: int, prefix: str = "", **kwargs + ) -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.is_mamba = is_mamba(config, layer_idx) if self.is_mamba: - self.mixer = Plamo2MambaMixer(vllm_config=vllm_config, - prefix=f"{prefix}.mixer") + self.mixer = Plamo2MambaMixer( + vllm_config=vllm_config, prefix=f"{prefix}.mixer" + ) else: - self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config, - prefix=f"{prefix}.mixer") + self.mixer = Plamo2AttentionMixer( + vllm_config=vllm_config, prefix=f"{prefix}.mixer" + ) - self.mlp = DenseMLP(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.pre_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = DenseMLP( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: residual = hidden_states hidden_states = self.pre_mixer_norm(hidden_states) else: - hidden_states, residual = self.pre_mixer_norm( - hidden_states, residual) + hidden_states, residual = self.pre_mixer_norm(hidden_states, residual) + if self.is_mamba: + # Plamo2MambaMixer writes output to this tensor + output = torch.empty_like(hidden_states) + mixer_kwargs = { + "output": output, + } + else: + mixer_kwargs = { + "positions": positions, + } hidden_states = self.mixer( - positions=positions, hidden_states=hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, + **mixer_kwargs, ) + if self.is_mamba: + hidden_states = output hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -589,52 +706,43 @@ class Plamo2DecoderLayer(nn.Module): class Plamo2Decoder(torch.nn.Module): - - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - return Plamo2DecoderLayer(vllm_config=vllm_config, - layer_idx=layer_idx, - prefix=prefix, - **extra_kwargs) + return Plamo2DecoderLayer( + vllm_config=vllm_config, + layer_idx=layer_idx, + 
prefix=prefix, + **extra_kwargs, + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: - mamba_cache_index = 0 - for layer in self.layers[self.start_layer:self.end_layer]: - layer_mamba_cache_params = None - if layer.is_mamba: - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - mamba_cache_index) - mamba_cache_index += 1 - + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return hidden_states, residual -class Plamo2Model(Plamo2PreTrainedModel): - +@support_torch_compile +class Plamo2Model(torch.nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + super().__init__() config = vllm_config.model_config.hf_config @@ -649,12 +757,11 @@ class Plamo2Model(Plamo2PreTrainedModel): org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + self.layers = Plamo2Decoder(vllm_config=vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -663,7 +770,6 @@ class Plamo2Model(Plamo2PreTrainedModel): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -678,30 +784,20 @@ class Plamo2Model(Plamo2PreTrainedModel): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) - hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, - IsHybrid, SupportsV0Only): +class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -711,12 +807,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: 
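        # (Editor's note, an assumption drawn from the surrounding diff: the v1
        # port removes the per-model MambaCacheManager seen in the old code
        # below; conv/ssm state now lives in each mixer's kv_cache, which is
        # registered under the compilation config's static_forward_context and
        # sized via get_mamba_state_shape_from_config.)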
+ super().__init__() config = vllm_config.model_config.hf_config scheduler_config = vllm_config.scheduler_config - assert not vllm_config.cache_config.enable_prefix_caching, \ - "PLaMo2 currently does not support prefix caching" - super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -727,8 +821,9 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, # the case for PLaMo2, as indicated by the FIXME comment. self.config.head_dim = self.config.hidden_size_per_head - self.model = Plamo2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Plamo2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = self.config.vocab_size self.unpadded_vocab_size = self.config.vocab_size num_embeddings = ((self.vocab_size + 15) // 16) * 16 @@ -742,88 +837,77 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size) - self.sampler = get_sampler() + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - # Initialize weights and apply final processing - self.post_init() + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - - hidden_states = self.model(input_ids, positions, mamba_cache_params, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) - conv_state_shape = ( - hidden_size // world_size, - self.config.mamba_d_conv - 1, + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return 
MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.hidden_size_per_head, - self.config.mamba_d_state, + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + Args: + vllm_config: vLLM config + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_num_heads * hf_config.hidden_size_per_head + + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=0, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.hidden_size_per_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # Both tie_word_embeddings=True and lm_head.weight in the safetensor # at the same time causes dict key access error. if name == "lm_head.weight" and self.config.tie_word_embeddings: @@ -855,10 +939,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, # Also, in addition to the quantized weights, # the zero points and scales have to be reshaped as well. # Packing should not be affected by this. 
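            # (Editor's illustrative sketch, not from the original PR: the
            # reshape below assumes the checkpoint lays out in_proj output
            # channels per mamba head as [gate_h | hidden_h] blocks, while the
            # forward pass splits the projection once via
            # projected_states.chunk(2, dim=-1), so the per-head halves are
            # regrouped here into two contiguous blocks. E.g. with 2 heads of
            # block width 4:
            #   w = torch.arange(8).reshape(1, 2, 4)   # [[[0,1,2,3],[4,5,6,7]]]
            #   g, h = w.chunk(2, dim=-1)              # gate/hidden halves per head
            #   torch.cat([g.reshape(1, -1), h.reshape(1, -1)], dim=-1)
            #   # -> [[0, 1, 4, 5, 2, 3, 6, 7]]  (all gates, then all hiddens)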
- if ".mixer.in_proj.weight" in name \ - or "mixer.in_proj.qweight" in name \ - or "mixer.in_proj.scales" in name \ - or "mixer.in_proj.qzeros" in name: + if ( + ".mixer.in_proj.weight" in name + or "mixer.in_proj.qweight" in name + or "mixer.in_proj.scales" in name + or "mixer.in_proj.qzeros" in name + ): if "mixer.in_proj.weight" in name: loaded_weight = loaded_weight.transpose(0, 1) # for weight: @@ -868,14 +954,14 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, # for scales and qzeros: # loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa loaded_weight = loaded_weight.reshape( - loaded_weight.shape[0], self.config.mamba_num_heads, -1) - gate_weight, hidden_states_weight = loaded_weight.chunk(2, - dim=-1) + loaded_weight.shape[0], self.config.mamba_num_heads, -1 + ) + gate_weight, hidden_states_weight = loaded_weight.chunk(2, dim=-1) gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1) hidden_states_weight = hidden_states_weight.reshape( - loaded_weight.shape[0], -1) - loaded_weight = torch.cat([gate_weight, hidden_states_weight], - dim=-1) + loaded_weight.shape[0], -1 + ) + loaded_weight = torch.cat([gate_weight, hidden_states_weight], dim=-1) if "mixer.in_proj.weight" in name: loaded_weight = loaded_weight.transpose(0, 1) @@ -896,6 +982,5 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py deleted file mode 100644 index 59e9f3e8a47b0..0000000000000 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ /dev/null @@ -1,309 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2025 The vLLM team. -# Copyright 2025 IBM. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only IBM/NASA Prithvi Geospatial model.""" - -from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, Union - -import torch -import torch.nn as nn -from transformers import BatchFeature - -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, - default_pooling_type) -from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, - PlaceholderRange) -from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors - - -def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). - - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) - - -class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): - - def _parse_image_data( - self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: - if isinstance(data, dict): - return DictEmbeddingItems( - data, - modality="image", - required_fields={"pixel_values", "location_coords"}, - fields_factory=_prithvi_field_config, - ) - - return super()._parse_image_data(data) - - -class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - -class PrithviGeoSpatialMAEInputBuilder( - BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - # This model input is fixed and is in the form of a torch Tensor. - # The size of pixel_values might change in the cases where we resize - # the input but never exceeds the dimensions below. 
- image_data = { - "pixel_values": torch.full((6, 512, 512), 1.0, - dtype=torch.float16), - "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), - } - - return {"image": image_data} - - -class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): - - def _get_data_parser(self) -> MultiModalDataParser: - return PrithviGeoSpatialMAEMultiModalDataParser() - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return _prithvi_field_config(hf_inputs) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - return [] - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - ) -> MultiModalInputs: - if "image" in mm_data: - image_data = mm_data["image"] - else: - image_data = mm_data - mm_data = {"image": mm_data} - - mm_items = self._to_mm_items(mm_data) - mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs or {}) - mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - - mm_processed_data = BatchFeature(image_data) - - mm_kwargs = MultiModalKwargsItems.from_hf_inputs( - mm_processed_data, - self._get_mm_fields_config(mm_processed_data, - hf_processor_mm_kwargs), - ) - - return MultiModalInputs( - type="multimodal", - prompt=prompt, - prompt_token_ids=[1], - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholders, - ) - - -@default_pooling_type("All") -@MULTIMODAL_REGISTRY.register_processor( - PrithviGeoSpatialMAEMultiModalProcessor, - info=PrithviGeoSpatialMAEProcessingInfo, - dummy_inputs=PrithviGeoSpatialMAEInputBuilder, -) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, - SupportsMultiModalWithRawInput): - """Prithvi Masked Autoencoder""" - - is_pooling_model = True - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def _instantiate_model(self, config: dict) -> Optional[nn.Module]: - # We might be able/need to support different tasks with this same model - if config["task_args"]["task"] == "SemanticSegmentationTask": - from terratorch.cli_tools import SemanticSegmentationTask - - task = SemanticSegmentationTask( - config["model_args"], - config["task_args"]["model_factory"], - loss=config["task_args"]["loss"], - lr=config["task_args"]["lr"], - ignore_index=config["task_args"]["ignore_index"], - optimizer=config["task_args"]["optimizer"], - optimizer_hparams=config["optimizer_params"], - scheduler=config["task_args"]["scheduler"], - scheduler_hparams=config["scheduler_params"], - plot_on_val=config["task_args"]["plot_on_val"], - freeze_decoder=config["task_args"]["freeze_decoder"], - freeze_backbone=config["task_args"]["freeze_backbone"], - ) - - return task.model - else: - return None - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - # the actual model is dynamically instantiated using terratorch - # allowing us to perform changes to the model architecture - # at startup time (e.g., change the model decoder class.) 
- self.model = self._instantiate_model( - vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) - if self.model is None: - raise ValueError( - "Unsupported task. " - "Only SemanticSegmentationTask is supported for now " - "by PrithviGeospatialMAE.") - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) - - def _parse_and_validate_multimodal_data( - self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - pixel_values = kwargs.pop("pixel_values", None) - if not isinstance(pixel_values, torch.Tensor): - raise ValueError(f"Incorrect type of pixel_values. " - f"Got type: {type(pixel_values)}") - - location_coords = kwargs.pop("location_coords", None) - if not isinstance(location_coords, torch.Tensor): - raise ValueError(f"Incorrect type of location_coords. " - f"Got type: {type(location_coords)}") - location_coords = torch.unbind(location_coords, dim=0)[0] - if location_coords.shape == torch.Size([0]): - location_coords = None - - return pixel_values, location_coords - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - # We do not really use any input tokens and therefore no embeddings - # to be calculated. However, due to the mandatory token ids in - # the input prompt we pass one token and the size of the dummy - # embedding tensors must reflect that. - return torch.empty((input_ids.shape[0], 0)) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ): - pixel_values, location_coords = ( - self._parse_and_validate_multimodal_data(**kwargs)) - model_output = self.model(pixel_values, - location_coords=location_coords) - - return model_output.output - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - params_list = [] - model_buffers = dict(self.named_buffers()) - loaded_buffers = [] - for key, value in weights: - if key == "state_dict": - weights_to_parse = value - for name, weight in weights_to_parse.items(): - if "pos_embed" in name: - continue - - if "_timm_module." in name: - name = name.replace("_timm_module.", "") - - # this model requires a couple of buffers to be loaded - # that are not loadable with the AutoWeightsLoader - if name in model_buffers: - if "_timm_module." in name: - name = name.replace("_timm_module.", "") - buffer = model_buffers[name] - weight_loader = getattr(buffer, "weight_loader", - default_weight_loader) - weight_loader(buffer, weight) - loaded_buffers.append(name) - else: - params_list.append((name, weight)) - break - - # Load the remaining model parameters - loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(params_list) - - return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e804f03e014e1..6a12776b7f94b 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,8 +6,10 @@ # Copyright (c) Alibaba Cloud. 
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" + import json from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -20,22 +22,28 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class QWenMLP(nn.Module): @@ -51,16 +59,15 @@ class QWenMLP(nn.Module): ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.c_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.c_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -71,7 +78,6 @@ class QWenMLP(nn.Module): class QWenAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -85,12 +91,10 @@ class QWenAttention(nn.Module): ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads self.c_attn = QKVParallelLinear( hidden_size, @@ -114,12 +118,14 @@ class QWenAttention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -135,7 +141,6 @@ class QWenAttention(nn.Module): class QWenBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -148,20 +153,22 @@ class QWenBlock(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - self.attn = QWenAttention(config.hidden_size, - config.num_attention_heads, - config.max_position_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = QWenAttention( + config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mlp = QWenMLP(config.hidden_size, - config.intermediate_size // 2, - quant_config=quant_config) + self.mlp = QWenMLP( + config.hidden_size, config.intermediate_size // 2, quant_config=quant_config + ) def forward( self, @@ -188,7 +195,6 @@ class QWenBlock(nn.Module): @support_torch_compile class QWenModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -205,13 +211,13 @@ class QWenModel(nn.Module): ) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: QWenBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix: QWenBlock(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -234,23 +240,21 @@ class QWenModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.h[self.start_layer:self.end_layer]: + for layer in 
islice(self.h, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states class QWenBaseModel(nn.Module): - def __init__( self, *, @@ -265,29 +269,30 @@ class QWenBaseModel(nn.Module): self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.transformer = transformer_type(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.transformer = transformer_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), @@ -298,7 +303,7 @@ class QWenBaseModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -320,8 +325,7 @@ class QWenBaseModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -339,14 +343,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if hasattr(config, "visual"): - hf_overrides = { - "architectures": ["QwenVLForConditionalGeneration"] - } + hf_overrides = {"architectures": ["QwenVLForConditionalGeneration"]} raise RuntimeError( "The configuration of this model indicates that it supports " "vision inputs, but you instantiated the text-only version " "of this model. 
Please use the vision model by setting " - f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + f"`--hf-overrides '{json.dumps(hf_overrides)}'`" + ) super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -357,6 +360,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 801741ecaf3b8..c8bc17dbfa0a1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -24,7 +24,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -38,29 +40,38 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Qwen2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -85,8 +96,9 @@ class Qwen2MLP(nn.Module): prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,7 +109,6 @@ class Qwen2MLP(nn.Module): class Qwen2Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -160,8 +171,11 @@ class Qwen2Attention(nn.Module): rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) self.attn = attn_cls( self.num_heads, self.head_dim, @@ -174,7 +188,10 @@ class Qwen2Attention(nn.Module): **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } + if dual_chunk_attention_config + else {}, + ) def forward( self, @@ -190,7 +207,6 @@ class Qwen2Attention(nn.Module): class Qwen2DecoderLayer(nn.Module): - def __init__( self, config: Qwen2Config, @@ -203,9 +219,9 @@ class Qwen2DecoderLayer(nn.Module): # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) # By default, Qwen2 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -236,10 +252,10 @@ class Qwen2DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -252,16 +268,14 @@ class Qwen2DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -274,17 +288,19 @@ class Qwen2DecoderLayer(nn.Module): "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Qwen2Model(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer, + ): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -297,14 +313,16 @@ class Qwen2Model(nn.Module): "to discuss this feature.".format( config.max_window_layers, config.num_hidden_layers, - )) + ) + ) self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and 
get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -318,22 +336,24 @@ class Qwen2Model(nn.Module): decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers: tuple[int] = tuple() + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -358,16 +378,16 @@ class Qwen2Model(nn.Module): aux_hidden_states = [] for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -376,8 +396,7 @@ class Qwen2Model(nn.Module): return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -391,18 +410,19 @@ class Qwen2Model(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -417,8 +437,7 @@ class Qwen2Model(nn.Module): if name is None: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", 
default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: @@ -435,8 +454,7 @@ class Qwen2Model(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -465,33 +483,36 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) @@ -502,24 +523,21 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 664e3f2985a59..0df79fc733f3f 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,51 +25,86 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, Optional, Union 
+from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn +from transformers import PretrainedConfig from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( - Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig) + Qwen2_5OmniConfig, + Qwen2_5OmniThinkerConfig, +) from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( - Qwen2_5OmniAudioEncoder) + Qwen2_5OmniAudioEncoder, +) from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( - Qwen2_5OmniProcessor) + Qwen2_5OmniProcessor, +) from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, - Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, - Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) + Qwen2_5_VisionTransformer, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLProcessingInfo, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioInputs, Qwen2AudioProcessingInfo, - _get_feat_extract_output_lengths) + Qwen2AudioProcessingInfo, + _get_feat_extract_output_lengths, +) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalPromptUpdates, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.transformers_utils.tokenizer import encode_tokens +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, + 
split_list_into_ranges, +) +from .vision import get_llm_pos_ids_for_vision try: import flash_attn @@ -79,53 +114,77 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) -def create_qwen2_5_omni_thinker_field_factory( - spatial_merge_size: int -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, - MultiModalFieldConfig]]: +class Qwen2_5OmniAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + - msl: Maximum sequence length + - tsl: Total sequence length + """ - def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, - torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nmb", "tsl"), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", "msl"), + ] + + +def create_qwen2_5_omni_thinker_field_factory( + spatial_merge_size: int, +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, MultiModalFieldConfig]]: + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get( + "audio_feature_lengths", torch.empty((0,)) + ) image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_pixel_grid_sizes = image_grid_thw.prod(-1) - image_embed_grid_sizes = (image_pixel_grid_sizes // - spatial_merge_size // spatial_merge_size) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // - spatial_merge_size) + video_embed_grid_sizes = ( + video_grid_sizes // spatial_merge_size // spatial_merge_size + ) num_videos = len(video_grid_sizes) return dict( input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), + "audio", audio_feature_lengths, dim=1 + ), feature_attention_mask=MultiModalFieldConfig.batched("audio"), audio_feature_lengths=MultiModalFieldConfig.batched("audio"), pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_pixel_grid_sizes), + "image", image_pixel_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_embed_grid_sizes), + "image", image_embed_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_embed_grid_sizes), + "video", video_embed_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), second_per_grid_ts=MultiModalFieldConfig.batched("video"), - use_audio_in_video=MultiModalFieldConfig.shared( - "video", num_videos), + use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), ) return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): - def __init__(self, spatial_merge_size: int, *args, **kwargs): self._spatial_merge_size = spatial_merge_size super().__init__(self._spatial_merge_size, *args, **kwargs) @@ -138,19 +197,18 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): return DictEmbeddingItems( data, modality="audio", - required_fields={ - "input_audio_features", "audio_feature_lengths" - }, + 
required_fields={"input_audio_features", "audio_feature_lengths"}, fields_factory=create_qwen2_5_omni_thinker_field_factory( - self._spatial_merge_size), + self._spatial_merge_size + ), ) return super()._parse_audio_data(data) -class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, - Qwen2_5_VLProcessingInfo): - +class Qwen2_5OmniThinkerProcessingInfo( + Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo +): def get_hf_config(self): return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config @@ -172,8 +230,8 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, class Qwen2_5OmniThinkerDummyInputsBuilder( - BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]): - + BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -185,13 +243,17 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token - return (audio_token * num_audios + image_token * num_images + - video_token * num_videos) + return ( + audio_token * num_audios + + image_token * num_images + + video_token * num_videos + ) def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -199,42 +261,55 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( feature_extractor = self.info.get_feature_extractor() - target_audio_length = min( - feature_extractor.chunk_length, - 30, - ) * feature_extractor.sampling_rate - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_audio_length = ( + min( + feature_extractor.chunk_length, + 30, + ) + * feature_extractor.sampling_rate + ) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "audio": - self._get_dummy_audios(length=target_audio_length, - num_audios=num_audios), - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos(width=target_width, - height=target_height, - num_frames=target_num_frames, - num_videos=num_videos), + "audio": self._get_dummy_audios( + length=target_audio_length, + num_audios=num_audios, + overrides=audio_overrides, + ), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), } return mm_data class Qwen2_5OmniThinkerMultiModalProcessor( - BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): - + BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo] +): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( - spatial_merge_size=self.info.get_hf_config( 
- ).vision_config.spatial_merge_size, - target_sr=feature_extractor.sampling_rate) + spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size, + target_sr=feature_extractor.sampling_rate, + ) def _call_hf_processor( self, @@ -250,7 +325,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor( if audios: # NOTE: Qwen2.5-Omni processor accept "audio" mm_data["audio"] = audios - mm_kwargs = dict(**mm_kwargs, ) + mm_kwargs = dict( + **mm_kwargs, + ) hf_inputs = super()._call_hf_processor( prompt=prompt, @@ -259,17 +336,19 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tok_kwargs=tok_kwargs, ) - input_features = hf_inputs.pop('input_features', None) - feature_attention_mask = hf_inputs.get('feature_attention_mask', None) - if ('input_audio_features' not in hf_inputs - and input_features is not None): + input_features = hf_inputs.pop("input_features", None) + feature_attention_mask = hf_inputs.get("feature_attention_mask", None) + if "input_audio_features" not in hf_inputs and input_features is not None: if feature_attention_mask is not None: - input_features = input_features.permute( - 0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) - hf_inputs['input_audio_features'] = input_features - if ('audio_feature_lengths' not in hf_inputs - and feature_attention_mask is not None): - hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + hf_inputs["input_audio_features"] = input_features + if ( + "audio_feature_lengths" not in hf_inputs + and feature_attention_mask is not None + ): + hf_inputs["audio_feature_lengths"] = feature_attention_mask.sum(-1) video_second_per_grid = hf_inputs.get("video_second_per_grid", None) if video_second_per_grid is not None: @@ -286,8 +365,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return create_qwen2_5_omni_thinker_field_factory( - self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) def _maybe_apply_prompt_updates( self, @@ -296,49 +375,98 @@ class Qwen2_5OmniThinkerMultiModalProcessor( mm_kwargs: MultiModalKwargsItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
""" mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + self._validate_mm_updates(mm_prompt_updates, mm_item_counts) - use_audio_in_video = (all( - item["use_audio_in_video"].data - for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) + use_audio_in_video = False + if "video" in mm_kwargs: + video_items = [item for item in mm_kwargs["video"] if item is not None] + # only check video items (if there are any) + if video_items: + use_audio_in_video = all( + item["use_audio_in_video"].data for item in video_items + ) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, mm_item_counts, - use_audio_in_video=use_audio_in_video) - - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) + use_audio_in_video=use_audio_in_video, + ) else: - ( - prompt_ids, - prompt, - mm_placeholders, - ) = self._apply_prompt_updates( + prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders( mm_placeholders, mm_item_counts, - use_audio_in_video=use_audio_in_video) + use_audio_in_video=use_audio_in_video, + ) - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) + return prompt_ids, mm_placeholders - return prompt_ids, prompt, mm_placeholders + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates def _get_prompt_updates( self, @@ -348,8 +476,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) vocab = tokenizer.get_vocab() audio_token = processor.audio_token @@ -366,12 +493,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( audio_output_lengths = [] elif audio_feature_lengths is not None: _, audio_output_lens = _get_feat_extract_output_lengths( - audio_feature_lengths) + audio_feature_lengths + ) audio_output_lengths = audio_output_lens.tolist() elif feature_attention_mask is not None: assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lens = _get_feat_extract_output_lengths( - feature_attention_mask.sum(-1)) + feature_attention_mask.sum(-1) + ) audio_output_lengths = audio_output_lens.tolist() # number of audios read from video. 
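# Illustrative sketch (not from the vLLM sources): the chunked interleaving done
# by `omni_get_updates_use_audio_in_video` above is easier to follow with
# concrete numbers. `chunk_temporal_index` is a hypothetical stand-in for
# `split_list_into_ranges`, and every value below is an assumed toy default,
# not a vLLM constant.
def chunk_temporal_index(t_index, tokens_per_chunk):
    # Group temporal token indices into buckets covering `tokens_per_chunk`
    # position units each.
    chunks, current, bucket = [], [], 0
    for t in t_index:
        if t // tokens_per_chunk != bucket:
            if current:
                chunks.append(current)
            current, bucket = [], t // tokens_per_chunk
        current.append(t)
    if current:
        chunks.append(current)
    return chunks


tokens_per_second, seconds_per_chunk = 25, 2  # assumed config defaults
t_ntoken_per_chunk = tokens_per_second * seconds_per_chunk  # 50
grid_t, grid_h, grid_w, merge = 4, 16, 16, 2  # toy video grid
audio_len = 120  # toy number of audio placeholder tokens

t_index = [t * tokens_per_second for t in range(grid_t)]  # [0, 25, 50, 75]
updates, added_audio = ["<aud_bos>"], 0
for chunk in chunk_temporal_index(t_index, t_ntoken_per_chunk):
    # Emit the video placeholders for this temporal chunk, then up to one
    # chunk's worth of audio placeholders.
    updates += ["<vid>"] * (len(chunk) * grid_h * grid_w // merge**2)
    take = min(t_ntoken_per_chunk, audio_len - added_audio)
    updates += ["<aud>"] * take
    added_audio += take
updates += ["<aud>"] * (audio_len - added_audio) + ["<aud_eos>"]
print(len(updates))  # 378 = 2 markers + 2 * 128 video tokens + 120 audio tokens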
@@ -386,7 +515,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor( audio = audios.get(item_idx) raise ValueError( f"The audio {audio} (len={len(audio)}) is too short " - "to be represented inside the model") + "to be represented inside the model" + ) return [audio_token_id] * num_features @@ -398,27 +528,26 @@ class Qwen2_5OmniThinkerMultiModalProcessor( token_id = image_token_id if modality == "image" else video_token_id return [token_id] * (int(grid_thw.prod()) // merge_length) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) thinker_config = self.info.get_hf_config() def get_replacement_qwen2_use_audio_in_video(item_idx: int): nonlocal audio_in_video_item_idx - audio_num_features = audio_output_lengths[audio_in_video_item_idx + - item_idx] + audio_num_features = audio_output_lengths[ + audio_in_video_item_idx + item_idx + ] video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 - second_per_grid_ts = hf_processor_mm_kwargs.get( - "second_per_grid_ts", None) + second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None) if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[item_idx] else: video_second_per_grid_t = 1.0 - return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + return self.omni_get_updates_use_audio_in_video( thinker_config=thinker_config, audio_len=audio_num_features, video_grid_thw=video_grid_thw, @@ -426,8 +555,10 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ) video_replacement_fn = ( - get_replacement_qwen2_use_audio_in_video if use_audio_in_video else - partial(get_replacement_qwen2_vision, modality="video")) + get_replacement_qwen2_use_audio_in_video + if use_audio_in_video + else partial(get_replacement_qwen2_vision, modality="video") + ) return [ PromptReplacement( @@ -438,8 +569,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( PromptReplacement( modality="image", target=image_token, - replacement=partial(get_replacement_qwen2_vision, - modality="image"), + replacement=partial(get_replacement_qwen2_vision, modality="image"), ), PromptReplacement( modality="video", @@ -492,8 +622,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( """ mm_counts = mm_items.get_all_counts() - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) if use_audio_in_video and "video" in mm_counts: assert "audio" in mm_counts mm_counts["audio"] -= mm_counts["video"] @@ -522,37 +651,44 @@ class Qwen2_5OmniThinkerMultiModalProcessor( class Qwen2_5OmniConditionalGenerationMixin: - - def _validate_and_reshape_mm_tensor(self, - mm_input: object, - name: str, - dim: int = 0) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str, dim: int = 0 + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. 
Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): + if dim == 0: + return mm_input.reshape(-1, *mm_input.shape[2:]) return torch.concat(list(mm_input), dim=dim) else: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: - input_audio_features = kwargs.pop('input_audio_features', None) - audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) - feature_attention_mask = kwargs.pop('feature_attention_mask', None) + self, **kwargs: object + ) -> Optional[Qwen2_5OmniAudioFeatureInputs]: + input_audio_features = kwargs.pop("input_audio_features", None) + audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) if input_audio_features is None: return None input_audio_features = self._validate_and_reshape_mm_tensor( - input_audio_features, 'input_audio_features', dim=1) + input_audio_features, "input_audio_features", dim=1 + ) if feature_attention_mask is not None: feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') + feature_attention_mask, "feature_attention_mask" + ) if not isinstance(input_audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_audio_features)}") - return Qwen2AudioInputs(input_features=input_audio_features, - audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}" + ) + return Qwen2_5OmniAudioFeatureInputs( + type="audio_features", + input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask, + ) def _parse_and_validate_image_input( self, @@ -567,31 +703,42 @@ class Qwen2_5OmniConditionalGenerationMixin: if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}" + ) - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}" + ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( self, @@ -606,9 +753,11 @@ class Qwen2_5OmniConditionalGenerationMixin: if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", @@ -618,53 +767,58 @@ class Qwen2_5OmniConditionalGenerationMixin: if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") + raise ValueError( + "Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}" + ) return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_audio_input( self, - audio_input: Qwen2AudioInputs, + audio_input: Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: - input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] if input_features.ndim == 3: assert input_features.shape[0] == 1 input_features = input_features.squeeze(0) if audio_feature_lengths.ndim == 2: - assert audio_feature_lengths.shape[ - 0] == 1 or audio_feature_lengths.shape[1] == 1 + assert ( + audio_feature_lengths.shape[0] == 1 + or audio_feature_lengths.shape[1] == 1 + ) if audio_feature_lengths.shape[0] == 1: audio_feature_lengths = audio_feature_lengths.squeeze(0) else: audio_feature_lengths = audio_feature_lengths.squeeze(1) audio_feat_lengths, audio_output_lengths = ( - self.audio_tower._get_feat_extract_output_lengths( - audio_feature_lengths)) + self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) + ) audio_outputs = self.audio_tower( input_features.to(self.audio_tower.dtype), feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - audio_features = audio_outputs.last_hidden_state - return audio_features.split(audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist()) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["image_embeds"].type(self.visual.dtype) @@ -680,18 +834,18 @@ class Qwen2_5OmniConditionalGenerationMixin: return image_embeds.split(sizes.tolist()) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs, - video_hashes: list[str] = None, - cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None, + ) -> torch.Tensor: if video_input["type"] == "video_embeds": return 
video_input["video_embeds"].type(self.visual.dtype) grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) + pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size @@ -706,14 +860,36 @@ class Qwen2_5OmniConditionalGenerationMixin: dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, - Qwen2_5OmniConditionalGenerationMixin): + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, + SupportsMRoPE, + Qwen2_5OmniConditionalGenerationMixin, +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", "thinker.model.": "language_model.model.", "thinker.": "", - }) + } + ) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "attn.qkv": [ + "attn.q", + "attn.k", + "attn.v", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -729,7 +905,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() thinker_config: Qwen2_5OmniThinkerConfig = ( - vllm_config.model_config.hf_config.thinker_config) + vllm_config.model_config.hf_config.thinker_config + ) quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = thinker_config @@ -745,20 +922,20 @@ class Qwen2_5OmniThinkerForConditionalGeneration( logger.warning( "flash_attn is not available, the model may not yield the " "exactly same result as the transformers implementation " - "in the audio tower part.") + "in the audio tower part." + ) if multimodal_config.get_limit_per_prompt("audio"): - self.audio_tower = Qwen2_5OmniAudioEncoder( - thinker_config.audio_config) + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) else: self.audio_tower = None if multimodal_config.get_limit_per_prompt( - "image") or multimodal_config.get_limit_per_prompt("video"): + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2_5_VisionTransformer( vision_config=thinker_config.vision_config, - norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", - 1e-6), + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), ) @@ -774,7 +951,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -782,33 +960,249 @@ class Qwen2_5OmniThinkerForConditionalGeneration( # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) - if input_key in ("input_audio_features" - ) and "audio" not in mm_input_by_modality: - mm_input_by_modality[ - "audio"] = self._parse_and_validate_audio_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key in ("input_audio_features") + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
+ + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if ( + src_item[idx] == vision_end_token_id + and src_item[idx - 1] == audio_end_token_id + ): + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif ( + src_item[idx] == audio_start_token_id + and src_item[idx - 1] == vision_start_token_id + ): + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + grid_t = 
video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges( + t_index, t_ntoken_per_chunk + ) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + * [audio_token_id] + ) + audio_start_idx = ( + start_idx + if len(audio_llm_pos_ids_list) == 0 + else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange( + min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + ).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id] + ) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + + llm_pos_ids_list[-1].max() + + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = ( + torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + ) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -826,27 +1220,28 @@ class Qwen2_5OmniThinkerForConditionalGeneration( multimodal_embeddings += audio_embeddings return multimodal_embeddings + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1. 
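# Illustrative sketch (not from the vLLM sources): a stripped-down view of the
# M-RoPE bookkeeping in `get_mrope_input_positions` above for a text/image/text
# prompt, ignoring audio and `use_audio_in_video`. Helper and variable names
# here are ad-hoc assumptions, not vLLM APIs.
import torch


def toy_mrope_positions(num_text_before, grid_thw, num_text_after, merge=2):
    # Text tokens share one index across all three axes; vision tokens
    # enumerate the merged (t, h, w) grid, offset by the running maximum.
    t, h, w = grid_thw
    llm_h, llm_w = h // merge, w // merge
    pos = [torch.arange(num_text_before).expand(3, -1)]
    start = num_text_before

    t_idx = torch.arange(t).repeat_interleave(llm_h * llm_w)
    h_idx = torch.arange(llm_h).repeat_interleave(llm_w).repeat(t)
    w_idx = torch.arange(llm_w).repeat(t * llm_h)
    pos.append(torch.stack([t_idx, h_idx, w_idx]) + start)

    # Trailing text resumes after the largest position used so far.
    start = pos[-1].max().item() + 1
    pos.append(torch.arange(num_text_after).expand(3, -1) + start)

    positions = torch.cat(pos, dim=1)
    seq_len = num_text_before + t * llm_h * llm_w + num_text_after
    delta = positions.max().item() + 1 - seq_len  # analogue of mrope_position_delta
    return positions, delta


positions, delta = toy_mrope_positions(4, (1, 8, 8), 3)
print(positions.shape, delta)  # torch.Size([3, 23]) -12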
def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - # TODO (ywang96): support overlapping modalitiy embeddings so that - # `use_audio_in_video` will work on V1. - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - self.config.image_token_index, - self.config.video_token_index, - self.config.audio_token_index - ]) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) - def get_multimodal_embeddings_v0( - self, **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]: audio_input = self._parse_and_validate_audio_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) @@ -867,26 +1262,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( multimodal_embeddings.append((video_embeds, "video")) return multimodal_embeddings - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is None or len(multimodal_embeddings) == 0: - return inputs_embeds - - for embeddings, modality in multimodal_embeddings: - if modality == "audio": - placeholder_token_id = self.config.audio_token_index - if modality == "image": - placeholder_token_id = self.config.image_token_index - if modality == "video": - placeholder_token_id = self.config.video_token_index - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, embeddings, placeholder_token_id) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -898,30 +1273,18 @@ class Qwen2_5OmniThinkerForConditionalGeneration( if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs) - inputs_embeds = self.get_input_embeddings_v0( - input_ids, multimodal_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = ["talker.", "token2wav."] if self.audio_tower is None: skip_prefixes.extend(["audio_tower."]) @@ -932,7 +1295,16 @@ class Qwen2_5OmniThinkerForConditionalGeneration( self, skip_prefixes=skip_prefixes, ) - loaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="merger.", + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 811ecffcc1e49..094fd90aac4e5 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -25,175 +25,248 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping + +from collections.abc import Iterable, Mapping, Sequence from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + Qwen2_5_VLConfig, + Qwen2_5_VLVisionConfig, +) +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - MergedReplicatedLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalFieldConfig -from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model -from vllm.platforms import _Backend +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_pin_memory_available +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsQuant) +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsMultiModalPruning, + SupportsPP, + SupportsQuant, +) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder -from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .qwen2_vl import ( + Qwen2VLMultiModalProcessor, + 
Qwen2VLProcessingInfo, + apply_rotary_pos_emb_vision, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + cast_overflow_tensors, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) # === Vision Inputs === # -class Qwen2_5_VLImagePixelInputs(TypedDict): +class Qwen2_5_VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2_5_VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +Qwen2_5_VLImageInputs = Union[ + Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs +] + + +class Qwen2_5_VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + - second_per_grid_ts: The video time interval (in seconds) for each + grid along the temporal dimension in the 3D position IDs. Returned + when `videos` is not `None`. """ - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. 
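# Illustrative sketch (not from the vLLM sources): how `image_grid_thw` in the
# schemas above relates to the number of `pixel_values` rows and to the number
# of placeholder tokens after spatial merging. `patch_size` and `merge_size`
# are assumed Qwen2.5-VL defaults; the resolution is a toy value.
import torch

patch_size, merge_size = 14, 2
height, width = 336, 476  # toy resolution, already multiples of patch_size

image_grid_thw = torch.tensor([[1, height // patch_size, width // patch_size]])
num_patches = int(image_grid_thw.prod(-1)[0])  # rows of `pixel_values` ("np")
num_placeholders = num_patches // (merge_size**2)  # tokens the language model sees

print(image_grid_thw.tolist(), num_patches, num_placeholders)  # [[1, 24, 34]] 816 204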
- """ - - -Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLImageEmbeddingInputs] - - -class Qwen2_5_VLVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` + + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + second_per_grid_ts: Annotated[ + Optional[torch.Tensor], + TensorShape("nv"), + ] + + +class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - second_per_grid_ts: torch.Tensor - """ - The video time interval (in seconds) for each grid along the temporal - dimension in the 3D position IDs. Returned when `videos` is not `None`. - """ - - -class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. - """ + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. 
- """ + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] -Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, - Qwen2_5_VLVideoEmbeddingInputs] +Qwen2_5_VLVideoInputs = Union[ + Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoEmbeddingInputs +] # === Vision Encoder === # class Qwen2_5_VisionMLP(nn.Module): - - def __init__(self, - in_features: int, - hidden_features: int, - bias: bool = False, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() - cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else - MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up_proj( + self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, + ) - cls_down_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down_proj(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -206,14 +279,14 @@ class Qwen2_5_VisionMLP(nn.Module): def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] - dist.all_gather(gathered_tensors, - local_tensor, - group=parallel_state.get_tp_group().device_group) + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair @@ -223,7 +296,6 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): class Qwen2_5_VisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -232,86 +304,90 @@ class Qwen2_5_VisionAttention(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. 
- self.tp_size = (1 if use_data_parallel else - parallel_state.get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) - if use_data_parallel: - self.qkv = ReplicatedLinear(embed_dim, - self.hidden_size_per_attention_head * - 3 * num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) - else: - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - - cls_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.proj = cls_proj(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") - - # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) - if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA - }: - raise RuntimeError( - f"Qwen2.5-VL does not support {self.attn_backend} backend now." + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) + self.attn_backend = attn_backend + self.use_upstream_fa = use_upstream_fa + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, ) + ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: 
Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -320,33 +396,31 @@ class Qwen2_5_VisionAttention(nn.Module): q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - from flash_attn import flash_attn_varlen_func - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
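The refactored forward() now rotates q and k with a single call by stacking them along the batch dimension and splitting the result back. A minimal sketch of that pattern, with a stand-in elementwise map in place of apply_rotary_pos_emb_vision (the stand-in is an assumption, chosen only to keep the demo self-contained):

import torch

def fake_rotate(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the rotary embedding; any per-position map works for the demo.
    return x * 2.0

q = torch.randn(1, 16, 4, 8)  # [b, s, heads, head_dim]
k = torch.randn(1, 16, 4, 8)

qk_concat = torch.cat([q, k], dim=0)  # [2 * b, s, heads, head_dim]
q_rot, k_rot = torch.chunk(fake_rotate(qk_concat), 2, dim=0)

assert torch.allclose(q_rot, fake_rotate(q))
assert torch.allclose(k_rot, fake_rotate(k))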
outputs = [] @@ -356,34 +430,36 @@ class Qwen2_5_VisionAttention(nn.Module): q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2_5_VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -394,6 +470,8 @@ class Qwen2_5_VisionBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -406,35 +484,41 @@ class Qwen2_5_VisionBlock(nn.Module): projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel) - self.mlp = Qwen2_5_VisionMLP(dim, - mlp_hidden_dim, - act_fn=act_fn, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + attn_backend=attn_backend, + use_upstream_fa=use_upstream_fa, + ) + self.mlp = Qwen2_5_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x_attn = self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + x_attn = self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) return x class Qwen2_5_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -448,22 +532,22 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - 
hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Qwen2_5_VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -480,43 +564,43 @@ class Qwen2_5_VisionPatchMerger(nn.Module): norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.mlp = nn.ModuleList([ - cls_fc1(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + return_bias=False, + disable_tp=use_data_parallel, + ), nn.GELU(), - cls_fc2(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + return_bias=False, + disable_tp=use_data_parallel, + ), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) x = x.view(-1, self.hidden_size) - - mlp_fc1, mlp_act, mlp_fc2 = self.mlp - x_parallel, _ = mlp_fc1(x) - x_parallel = mlp_act(x_parallel) - out, _ = mlp_fc2(x_parallel) + out = self.mlp(x) return out class Qwen2_5_VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta**( - torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim)) + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -525,12 +609,18 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module): if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -540,7 +630,6 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module): class Qwen2_5_VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2_5_VLVisionConfig, @@ -578,18 +667,45 @@ class Qwen2_5_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen2_5_VisionBlock(dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=get_act_and_mul_fn( - vision_config.hidden_act), - norm_layer=norm_layer, - 
quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(depth) - ]) + use_upstream_fa = False + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if ( + self.attn_backend != _Backend.FLASH_ATTN + and self.attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): + self.attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Qwen2.5-VL does not support {self.attn_backend} backend now." + ) + + self.blocks = nn.ModuleList( + [ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2_5_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, @@ -599,7 +715,6 @@ class Qwen2_5_VisionTransformer(nn.Module): prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @property def dtype(self) -> torch.dtype: @@ -612,48 +727,66 @@ class Qwen2_5_VisionTransformer(nn.Module): def rotary_pos_emb_thw(self, t, h, w): hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) max_size = max(h, w) rotary_pos_emb_full = self.rotary_pos_emb(max_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.reshape( rotary_pos_emb.shape[0] // self.spatial_merge_unit, - self.spatial_merge_unit, -1) + self.spatial_merge_unit, + -1, + ) return rotary_pos_emb def get_window_index_thw(self, grid_t, grid_h, grid_w): - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) llm_grid_h = grid_h // self.spatial_merge_size llm_grid_w = grid_w // self.spatial_merge_size index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // 
vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] @@ -665,28 +798,41 @@ class Qwen2_5_VisionTransformer(nn.Module): @lru_cache(maxsize=1024) # noqa: B019 def get_rope_by_thw(self, t, h, w): - window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw( - t, h, w) + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) cu_seqlens_thw = torch.repeat_interleave( - torch.tensor([h * w], dtype=torch.int32), t) - return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, - cu_seqlens_thw) + torch.tensor([h * w], dtype=torch.int32), t + ) + return ( + rotary_pos_emb_thw, + window_index_thw, + cu_seqlens_window_thw, + cu_seqlens_thw, + ) def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens + @staticmethod + def invert_permutation(perm: torch.Tensor) -> torch.Tensor: + # building the inverse permutation in O(n) time + inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) + inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype) + return inv + def forward( self, x: torch.Tensor, @@ -717,10 +863,9 @@ class Qwen2_5_VisionTransformer(nn.Module): ) = self.get_rope_by_thw(t, h, w) window_index.append(window_index_thw + window_index_id) - window_index_id += (t * llm_h * llm_w) + window_index_id += t * llm_h * llm_w - cu_seqlens_window_thw = (cu_seqlens_window_thw + - cu_window_seqlens_last) + cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last cu_window_seqlens_last = cu_seqlens_window_thw[-1] cu_window_seqlens.append(cu_seqlens_window_thw) @@ -730,6 +875,8 @@ class Qwen2_5_VisionTransformer(nn.Module): rotary_pos_emb = torch.cat(rotary_pos_emb) window_index = torch.cat(window_index) + # compute reverse indices + reverse_indices = self.invert_permutation(window_index) cu_window_seqlens = torch.cat(cu_window_seqlens) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) cu_seqlens = torch.cat(cu_seqlens) @@ -738,21 +885,22 @@ class Qwen2_5_VisionTransformer(nn.Module): # transformers # pre-compute seqlens for window/full attn 
to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( - cu_seqlens) + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( - cu_window_seqlens) + cu_window_seqlens + ) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=self.device, - non_blocking=True) - rotary_pos_emb = rotary_pos_emb.to(device=self.device, - non_blocking=True) - window_index = window_index.to(device=hidden_states.device, - non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) + rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True) + window_index = window_index.to(device=hidden_states.device, non_blocking=True) + reverse_indices = reverse_indices.to( + device=hidden_states.device, non_blocking=True + ) hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) @@ -783,12 +931,10 @@ class Qwen2_5_VisionTransformer(nn.Module): # adapter hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -801,7 +947,7 @@ class Qwen2_5_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -811,15 +957,13 @@ class Qwen2_5_VisionTransformer(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2_5_VLConfig) @@ -832,7 +976,6 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -843,14 +986,80 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): second_per_grid_ts=MultiModalFieldConfig.batched("video"), ) + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + 
grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + + # EVS-specific code + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if ( + modality == "video" + and video_pruning_rate is not None + and video_pruning_rate > 0.0 + ): + T, H, W = map(int, grid_thw) + tokens_per_frame = (H // image_processor.merge_size) * ( + W // image_processor.merge_size + ) + num_tokens = compute_retained_tokens_count( + tokens_per_frame, + T, + video_pruning_rate, + ) + # End of EVS-specific code + + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") + ] + @MULTIMODAL_REGISTRY.register_processor( Qwen2_5_VLMultiModalProcessor, info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP, - SupportsQuant): + dummy_inputs=Qwen2_5_VLDummyInputsBuilder, +) +class Qwen2_5_VLForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsLoRA, + SupportsPP, + SupportsQuant, + SupportsEagle3, + SupportsMultiModalPruning, + SupportsMRoPE, +): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -861,7 +1070,136 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) + + supports_encoder_tp_data = True + + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + 
image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -880,14 +1218,18 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) - if multimodal_config.get_limit_per_prompt("image") or \ - multimodal_config.get_limit_per_prompt("video"): + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config( - self.quant_config), + quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, ) @@ -901,33 +1243,37 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid vision encoder sections for some models. 
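The index construction inside get_mrope_input_positions can be checked on a toy grid. The sketch below mirrors the broadcasting shown above but omits the tokens_per_second and second_per_grid_ts scaling applied to the temporal axis for videos; the grid sizes are arbitrary.

import torch

llm_grid_t, llm_grid_h, llm_grid_w = 2, 2, 3

t_index = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
)
h_index = (
    torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
    torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)

# Every vision token gets a (t, h, w) coordinate: shape (3, t * h * w) == (3, 12).
positions = torch.stack([t_index, h_index, w_index])
print(positions.shape)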
- if isinstance(config, (GPTQConfig, GPTQMarlinConfig)): - return None - return config + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.language_model.model.aux_hidden_state_layers = layers - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.language_model.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + raise ValueError( + f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -937,34 +1283,35 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -975,10 +1322,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - + video_grid_thw, "video grid_thw" + ) + if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: + second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -988,22 +1338,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1015,23 +1364,55 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, if self.use_data_parallel: return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values, grid_thw_list) + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) - def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + def _postprocess_image_embeds_evs( + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Append mrope positions for each for images. + This is necessary to recover correct mrope + positions after video pruning + Args: + image_embeds_split: Tuple of image embeddings for + each image item. + image_input: Image input data. + + Returns: + Tuple of image embeddings for each image item. 
+ Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = image_embeds_out + return tuple(image_embeds_split) + + def _process_video_input( + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1042,43 +1423,157 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values_videos = video_input["pixel_values_videos"] if self.use_data_parallel: return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list) + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return video_embeds.split(sizes) + def _postprocess_video_embeds_evs( + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Prunes video embeddings via Efficient Video Sampling (EVS) + and then appends mrope positions for each retained embeddings + + Args: + video_embeds_split: Tuple of video embeddings for each video item. + video_input: Video input data. + + Returns: + Tuple of video embeddings for each video item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. 
+ """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input["second_per_grid_ts"].long() + tokens_per_second = self.config.vision_config.tokens_per_second + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Device + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) + + # Tensors + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] + + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) + + return tuple(mm_embeddings_out), positions, mrope_positions_delta + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1092,51 +1587,20 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, multimodal_input = mm_input_by_modality[modality] if modality == "image": vision_embeddings = self._process_image_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + vision_embeddings = self._postprocess_image_embeds_evs( + vision_embeddings, multimodal_input + ) multimodal_embeddings += vision_embeddings if modality == "video": video_embeddings = self._process_video_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + video_embeddings = self._postprocess_video_embeds_evs( + video_embeddings, multimodal_input + ) multimodal_embeddings += video_embeddings return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Qwen2_5_VLImageInputs] = None, - video_input: Optional[Qwen2_5_VLVideoInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1151,46 +1615,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a - batch. 
- **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL - opensource models), the shape will be `(3, seq_len)`, + batch. **NOTE**: If mrope is enabled (default setting for + Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -1202,14 +1634,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + return self.language_model.compute_logits(hidden_states) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 86c567ca36174..e61a730f97bb6 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,49 +22,95 @@ # See the License for the specific language governing permissions and # limitations under the License. 
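As a small illustration of the positions layout described in the forward() docstring (this snippet is not part of the patch): with mrope every token carries three position ids, so purely textual tokens simply repeat the same id across the temporal, height, and width rows.

import torch

seq_len = 5
plain_positions = torch.arange(seq_len)               # (seq_len,)
mrope_positions = torch.arange(seq_len).repeat(3, 1)  # (3, seq_len); text tokens share
                                                      # one id on all three axes
print(plain_positions.shape, mrope_positions.shape)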
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn from transformers import BatchFeature -from transformers.models.qwen2_audio import (Qwen2AudioConfig, - Qwen2AudioEncoder, - Qwen2AudioProcessor) +from transformers.models.qwen2_audio import ( + Qwen2AudioConfig, + Qwen2AudioEncoder, + Qwen2AudioProcessor, +) from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + AudioItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix # # === Audio Inputs === # -class Qwen2AudioInputs(TypedDict): - input_features: torch.Tensor - """Shape: `(num_audios, num_mel_bins, 3000)`""" +class Qwen2AudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + """ - feature_attention_mask: torch.Tensor - """Shape: `(num_audios, 3000)`""" + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("na", "nmb", 3000), + ] + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", 3000), + ] + + +class Qwen2AudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] + + +Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] # === Audio Encoder === # class Qwen2AudioMultiModalProjector(nn.Module): - def __init__(self, audio_hidden_size: int, text_hidden_size: int): super().__init__() self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True) @@ -82,15 +128,13 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): class Qwen2AudioProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2AudioConfig) def get_hf_processor(self, **kwargs: object) -> 
Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs) - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) @@ -100,9 +144,7 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): return {"audio": None} -class Qwen2AudioDummyInputsBuilder( - BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): - +class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -115,6 +157,7 @@ class Qwen2AudioDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() @@ -122,18 +165,43 @@ class Qwen2AudioDummyInputsBuilder( audio_len = feature_extractor.chunk_length * sampling_rate num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } -class Qwen2AudioMultiModalProcessor( - BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): +def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + ) + +class Qwen2AudioMultiModalDataParser(MultiModalDataParser): + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_qwen2audio_field_config, + ) + + return super()._parse_audio_data(data) + + +class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -173,10 +241,7 @@ class Qwen2AudioMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - input_features=MultiModalFieldConfig.batched("audio"), - feature_attention_mask=MultiModalFieldConfig.batched("audio"), - ) + return _qwen2audio_field_config(hf_inputs) def _get_prompt_updates( self, @@ -190,10 +255,8 @@ class Qwen2AudioMultiModalProcessor( # Use getattr with default to be compatible with transformers<4.48 audio_token = getattr(processor, "audio_token", "<|AUDIO|>") - audio_bos_token = getattr(processor, "audio_bos_token", - "<|audio_bos|>") - audio_eos_token = getattr(processor, "audio_eos_token", - "<|audio_eos|>") + audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>") audio_token_id = vocab[audio_token] 
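A hedged sketch of the placeholder-expansion idea behind _get_prompt_updates: a single audio placeholder in the tokenized prompt is replaced by as many audio tokens as the encoder will emit for that clip. The token ids and the feature count below are made up for the example.

audio_token_id, audio_bos_id, audio_eos_id = 900, 901, 902  # hypothetical ids
num_features = 4

prompt = [1, 2, audio_bos_id, audio_token_id, audio_eos_id, 3]
expanded: list[int] = []
for tok in prompt:
    # The lone placeholder grows to num_features audio tokens; other tokens pass through.
    expanded.extend([audio_token_id] * num_features if tok == audio_token_id else [tok])

print(expanded)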
audio_bos_id = vocab[audio_bos_token] @@ -206,18 +269,27 @@ class Qwen2AudioMultiModalProcessor( else: assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lens = _get_feat_extract_output_lengths( - feature_attention_mask.sum(-1)) + feature_attention_mask.sum(-1) + ) audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - num_features = audio_output_lengths[item_idx] + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) - raise ValueError(f"The audio (len={audio_len}) is too short " - "to be represented inside the model") + raise ValueError( + f"The audio (len={audio_len}) is too short " + "to be represented inside the model" + ) audio_tokens = [audio_token_id] * num_features @@ -238,10 +310,9 @@ class Qwen2AudioMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, info=Qwen2AudioProcessingInfo, - dummy_inputs=Qwen2AudioDummyInputsBuilder) -class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): - + dummy_inputs=Qwen2AudioDummyInputsBuilder, +) +class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("audio"): @@ -259,7 +330,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, self.audio_tower = Qwen2AudioEncoder(config.audio_config) self.multi_modal_projector = Qwen2AudioMultiModalProjector( - config.audio_config.d_model, config.text_config.hidden_size) + config.audio_config.d_model, config.text_config.hidden_size + ) self.quant_config = quant_config @@ -271,106 +343,130 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: - input_features = kwargs.pop('input_features', None) - feature_attention_mask = kwargs.pop('feature_attention_mask', None) - if input_features is None: - return None - input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. 
" - f"Got type: {type(input_features)}") - return Qwen2AudioInputs(input_features=input_features, - feature_attention_mask=feature_attention_mask) + self, **kwargs: object + ) -> Optional[Qwen2AudioInputs]: + input_features = kwargs.pop("input_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) - def _process_audio_input(self, - audio_input: Qwen2AudioInputs) -> torch.Tensor: + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}" + ) + audio_embeds = self._validate_and_reshape_mm_tensor( + audio_embeds, "audio_embeds" + ) + return Qwen2AudioEmbeddingInputs( + type="audio_embeds", audio_embeds=audio_embeds + ) + + if input_features is not None: + input_features = self._validate_and_reshape_mm_tensor( + input_features, "input_features" + ) + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, "feature_attention_mask" + ) + return Qwen2AudioFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: Qwen2AudioInputs + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"] audio_feat_lengths, audio_output_lengths = ( self.audio_tower._get_feat_extract_output_lengths( - feature_attention_mask.sum(-1))) + feature_attention_mask.sum(-1) + ) + ) batch_size, _, max_mel_seq_len = input_features.shape max_seq_len = (max_mel_seq_len - 2) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = (torch.arange( - 0, - max_seq_len, - dtype=audio_feat_lengths.dtype, - device=audio_feat_lengths.device).unsqueeze(0).expand( - batch_size, max_seq_len)) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feat_lengths.dtype, + device=audio_feat_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) lengths_expand = audio_feat_lengths.unsqueeze(-1).expand( - batch_size, max_seq_len) + batch_size, max_seq_len + ) # Create mask padding_mask = seq_range >= lengths_expand - audio_attention_mask_ = padding_mask.view( - batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, - max_seq_len) + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) audio_attention_mask = audio_attention_mask_.to( dtype=self.audio_tower.conv1.weight.dtype, - device=self.audio_tower.conv1.weight.device) + device=self.audio_tower.conv1.weight.device, + ) audio_attention_mask[audio_attention_mask_] = float("-inf") - audio_outputs = self.audio_tower(input_features, - attention_mask=audio_attention_mask) + audio_outputs = self.audio_tower( + input_features, attention_mask=audio_attention_mask + ) selected_audio_feature = audio_outputs.last_hidden_state audio_features = self.multi_modal_projector(selected_audio_feature) num_audios, max_audio_tokens, embed_dim = audio_features.shape audio_output_lengths = audio_output_lengths.unsqueeze(1) - audio_features_mask = torch.arange(max_audio_tokens).expand( - 
num_audios, max_audio_tokens).to( - audio_output_lengths.device) < audio_output_lengths - masked_audio_features = audio_features[audio_features_mask].view( - -1, embed_dim) + audio_features_mask = ( + torch.arange(max_audio_tokens) + .expand(num_audios, max_audio_tokens) + .to(audio_output_lengths.device) + < audio_output_lengths + ) + masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) # Split to tuple of embeddings for individual audio input. - return torch.split(masked_audio_features, - audio_output_lengths.flatten().tolist()) + return torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] masked_audio_features = self._process_audio_input(audio_input) return masked_audio_features - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -379,33 +475,20 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 5c4ad34246d66..7251e7b2eea49 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -24,7 +24,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional, Union import torch @@ -38,32 +40,38 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Qwen2MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -71,31 +79,44 @@ class Qwen2MoeMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + expert_gate: Optional[torch.nn.Linear] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results) + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() + self.expert_gate = expert_gate def forward(self, x): gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x + out = self.act_fn(gate_up) + out, _ = self.down_proj(out) + + if self.expert_gate is not None: + out = F.sigmoid(self.expert_gate(x)) * out + + return out class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__( self, config: Qwen2MoeConfig, @@ -108,63 +129,66 @@ class Qwen2MoeSparseMoeBlock(nn.Module): if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, + expert_gate=self.shared_expert_gate, + prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) + final_hidden_states + ) return final_hidden_states.view(orig_shape) class Qwen2MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -209,6 +233,7 @@ class Qwen2MoeAttention(nn.Module): self.total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( @@ -216,6 +241,7 @@ class Qwen2MoeAttention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( @@ -237,7 +263,10 @@ class Qwen2MoeAttention(nn.Module): **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } + if dual_chunk_attention_config + else {}, + ) def forward( self, @@ -253,7 +282,6 @@ class Qwen2MoeAttention(nn.Module): class Qwen2MoeDecoderLayer(nn.Module): - def __init__( self, config: Qwen2MoeConfig, @@ -265,11 +293,10 @@ class Qwen2MoeDecoderLayer(nn.Module): self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -286,25 +313,27 @@ class Qwen2MoeDecoderLayer(nn.Module): # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. 
layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen2MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen2MoeSparseMoeBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -317,23 +346,20 @@ class Qwen2MoeDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Qwen2MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -350,16 +376,18 @@ class Qwen2MoeModel(nn.Module): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen2MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Qwen2MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -381,27 +409,26 @@ class Qwen2MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return 
hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -415,7 +442,7 @@ class Qwen2MoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -429,8 +456,9 @@ class Qwen2MoeModel(nn.Module): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -453,21 +481,25 @@ class Qwen2MoeModel(nn.Module): if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -475,7 +507,8 @@ class Qwen2MoeModel(nn.Module): # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv_scale is not loaded.", # noqa: E501 @@ -486,26 +519,22 @@ class Qwen2MoeModel(nn.Module): else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - fall_back_to_pt_during_load = False packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -514,16 +543,33 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + # Only perform the following mapping when Qwen2MoeMLP exists + if ( + getattr(config, "mlp_only_layers", []) + or config.shared_expert_intermediate_size > 0 + ): + self.packed_modules_mapping["gate_up_proj"] = ( + [ + "gate_proj", + "up_proj", + ], + ) + + self.model = Qwen2MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -535,21 +581,19 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602a..75ed95477f78f 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -6,6 +6,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. 
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Optional, Union @@ -13,18 +14,17 @@ import torch from torch import nn from vllm.config import VllmConfig -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): - is_pooling_model = True pooler: Pooler @@ -50,22 +50,31 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.head_dtype = vllm_config.model_config.head_dtype self.score = nn.Sequential( - ColumnParallelLinear(config.hidden_size, - config.hidden_size, - quant_config=quant_config, - return_bias=False), + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + params_dtype=self.head_dtype, + return_bias=False, + ), nn.ReLU(), - RowParallelLinear(config.hidden_size, - config.num_labels, - quant_config=quant_config, - return_bias=False), + RowParallelLinear( + config.hidden_size, + config.num_labels, + params_dtype=self.head_dtype, + quant_config=quant_config, + return_bias=False, + ), ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -77,21 +86,20 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + hidden_states = hidden_states.to(self.head_dtype) logits = self.score(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["lm_head."]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) return loader.load_weights(weights) @default_pooling_type("ALL") class Qwen2ForRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -100,12 +108,12 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) @default_pooling_type("STEP") class 
Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -113,5 +121,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}) + self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ae7a8d8d7a5b9..cb1bf3825c74f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,150 +24,216 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import AutoConfig, BatchFeature -from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, - Qwen2VLProcessor) +from transformers import AutoConfig, BatchFeature, PretrainedConfig +from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLConfig, Qwen2VLVisionConfig) + Qwen2VLConfig, + Qwen2VLVisionConfig, +) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize -from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( - Qwen2VLVideoProcessor) +from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) +from vllm.model_executor.layers.rotary_embedding.common import ( + dispatch_rotary_emb_function, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import 
(BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) -from .vision import get_vit_attn_backend +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) # For profile run -_MAX_FRAMES_PER_VIDEO = 16 +_MAX_FRAMES_PER_VIDEO = 14 # === Vision Inputs === # -class Qwen2VLImagePixelInputs(TypedDict): +class Qwen2VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2VLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. 
+ + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] + + +class Qwen2VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each video over each prompt in + the batch + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + - nv: Number of videos + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, - Qwen2VLImageEmbeddingInputs] - - -class Qwen2VLVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` + + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + +class Qwen2VLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. - """ - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. 
- """ + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] -Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, - Qwen2VLVideoEmbeddingInputs] +Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, Qwen2VLVideoEmbeddingInputs] # === Vision Encoder === # class Qwen2VisionMLP(nn.Module): - def __init__( self, in_features: int, @@ -175,17 +241,24 @@ class Qwen2VisionMLP(nn.Module): act_layer: type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() - self.fc1 = ColumnParallelLinear(in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1") + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) self.act = act_layer() - self.fc2 = RowParallelLinear(hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -200,15 +273,14 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) -def apply_rotary_emb_torch(x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False) -> torch.Tensor: +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -216,34 +288,30 @@ def apply_rotary_emb_torch(x: torch.Tensor, ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) -def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) t_ = t.float() cos = freqs.cos() sin = freqs.sin() - apply_rotary_emb = apply_rotary_emb_torch - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - output = apply_rotary_emb(t_, cos, sin).type_as(t) + output = rotary_emb_function(t_, cos, sin).type_as(t) return output class Qwen2VisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -251,36 +319,65 @@ class Qwen2VisionAttention(nn.Module): projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_size = world_size + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size + ) - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.qkv = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype(), + ) + self.use_upstream_fa = False + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) + ) + if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now.") + f"Qwen2-VL does not support {self.attn_backend} backend now." 
+ ) + self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -294,27 +391,31 @@ class Qwen2VisionAttention(nn.Module): # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -322,33 +423,31 @@ class Qwen2VisionAttention(nn.Module): q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - from flash_attn import flash_attn_varlen_func - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] @@ -358,34 +457,36 @@ class Qwen2VisionAttention(nn.Module): q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -395,6 +496,7 @@ class Qwen2VisionBlock(nn.Module): norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -403,24 +505,30 @@ class Qwen2VisionBlock(nn.Module): self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - self.mlp = Qwen2VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.attn = Qwen2VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) + self.mlp = Qwen2VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -435,7 +543,6 @@ class Qwen2VisionBlock(nn.Module): class Qwen2VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -449,22 +556,22 @@ class Qwen2VisionPatchEmbed(nn.Module): self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - embed_dim, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.embed_dim) return x class Qwen2VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -473,25 +580,34 @@ class Qwen2VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), - nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + disable_tp=use_data_parallel, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + disable_tp=use_data_parallel, + ), + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) @@ -505,13 +621,11 @@ class Qwen2VisionPatchMerger(nn.Module): class Qwen2VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -520,12 +634,18 @@ class Qwen2VisionRotaryEmbedding(nn.Module): if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -535,13 +655,13 @@ class Qwen2VisionRotaryEmbedding(nn.Module): class Qwen2VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -555,6 +675,9 @@ class Qwen2VisionTransformer(nn.Module): num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.hidden_size + self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads self.embed_dim = embed_dim @@ -570,23 +693,35 @@ class Qwen2VisionTransformer(nn.Module): head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen2VisionBlock(dim=embed_dim, - 
num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -596,37 +731,47 @@ class Qwen2VisionTransformer(nn.Module): def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: pos_ids = [] + max_grid_size = 0 for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + max_grid_size = max(max_grid_size, h, w) pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor + self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -635,7 +780,7 @@ class Qwen2VisionTransformer(nn.Module): def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: list[list[int]], ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) @@ -645,9 +790,10 @@ class Qwen2VisionTransformer(nn.Module): rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens 
= torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long) + cu_seqlens = torch.repeat_interleave( + grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -669,8 +815,7 @@ class Qwen2VisionTransformer(nn.Module): return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -681,7 +826,7 @@ class Qwen2VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -692,41 +837,45 @@ class Qwen2VisionTransformer(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params def _create_qwen2vl_field_factory( - spatial_merge_size: int + spatial_merge_size: int, ) -> Callable[ [Mapping[str, torch.Tensor]], - Mapping[str, MultiModalFieldConfig], + Mapping[str, MultiModalFieldConfig], ]: - def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_pixel_grid_sizes = image_grid_thw.prod(-1) - image_embed_grid_sizes = (image_pixel_grid_sizes // - spatial_merge_size // spatial_merge_size) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // - spatial_merge_size) + video_embed_grid_sizes = ( + video_grid_sizes // spatial_merge_size // spatial_merge_size + ) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_pixel_grid_sizes), + "image", image_pixel_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_embed_grid_sizes), + "image", image_embed_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_embed_grid_sizes), + "video", video_embed_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), ) @@ -734,7 +883,6 @@ def _create_qwen2vl_field_factory( class Qwen2VLMultiModalDataParser(MultiModalDataParser): - def __init__(self, spatial_merge_size: int, *args, **kwargs): self._spatial_merge_size = spatial_merge_size super().__init__(*args, **kwargs) @@ -748,8 +896,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_create_qwen2vl_field_factory( - self._spatial_merge_size), + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -763,15 +910,13 @@ class 
Qwen2VLMultiModalDataParser(MultiModalDataParser): data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_create_qwen2vl_field_factory( - self._spatial_merge_size), + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_video_data(data) class Qwen2VLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2VLConfig) @@ -823,11 +968,9 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) # NOTE: Frames are padded to be divisible by `temporal_patch_size` # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 @@ -852,6 +995,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, + num_frames=1, image_processor=image_processor, ) return num_image_tokens @@ -876,6 +1020,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, + num_frames=1, image_processor=None, ) return max_image_size @@ -889,10 +1034,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): image_processor=None, ) - def _get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int: target_width, target_height = self.get_image_size_with_most_features() - num_frames = 0 + num_frames = start_num_frames while True: next_num_frames = num_frames + 1 @@ -914,15 +1059,14 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): self, seq_len: int, mm_counts: Mapping[str, int], + max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO, ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self._get_max_video_frames(seq_len) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), max_frames_per_video + ) return max(max_frames_per_video, 1) @@ -936,14 +1080,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -958,36 +1100,41 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - 
self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, - ) + overrides=video_overrides, + ), } -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] - ): - +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return Qwen2VLMultiModalDataParser( - self.info.get_hf_config().vision_config.spatial_merge_size) + self.info.get_hf_config().vision_config.spatial_merge_size + ) def _get_prompt_updates( self, @@ -996,8 +1143,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -1020,9 +1166,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] PromptReplacement( modality=modality, target=[placeholder[modality]], - replacement=partial(get_replacement_qwen2vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1031,16 +1177,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _create_qwen2vl_field_factory( - self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): - +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class Qwen2VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): # To ensure correct weight loading and mapping. 
hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -1050,7 +1198,141 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) + + supports_encoder_tp_data = True + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get M-RoPE input positions for Qwen2-VL model.""" + if image_grid_thw is None: + image_grid_thw = [] + if video_grid_thw is None: + video_grid_thw = [] + if second_per_grid_ts is None: + second_per_grid_ts = [] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < 
len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1067,16 +1349,19 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config - if multimodal_config.get_limit_per_prompt("image") or \ - multimodal_config.get_limit_per_prompt("video"): + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), + quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, ) else: self.visual = None @@ -1088,34 +1373,30 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid vision encoder sections for some models. - # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config - - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + raise ValueError( + f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1125,33 +1406,35 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - - return Qwen2VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return Qwen2VLImageEmbeddingInputs(type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + return Qwen2VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1161,9 +1444,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2VLVideoPixelInputs( type="pixel_values_videos", @@ -1173,52 +1458,72 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. 
" - f"Got type: {type(video_embeds)}") - return Qwen2VLVideoEmbeddingInputs(type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + return Qwen2VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"] else: pixel_values = image_input["pixel_values"] - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() - return image_embeds.split(sizes.tolist()) + return image_embeds.split(sizes) def _process_video_input( - self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"] else: pixel_values_videos = video_input["pixel_values_videos"] - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() - return video_embeds.split(sizes.tolist()) + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} @@ -1226,23 +1531,23 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1265,45 +1570,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Qwen2VLImagePixelInputs] = None, - video_input: Optional[Qwen2VLVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -1321,40 +1587,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -1366,14 +1606,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + return self.language_model.compute_logits(hidden_states) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) @@ -1396,7 +1632,6 @@ class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor): class Tarsier2ImageProcessor(Qwen2VLImageProcessor): - def __init__( self, size: Optional[dict[str, int]] = None, @@ -1406,7 +1641,7 @@ class Tarsier2ImageProcessor(Qwen2VLImageProcessor): # Remap if Tarsier2-specific format is provided remapped_size = { "shortest_edge": size["min_pixels"], - "longest_edge": size["max_pixels"] + "longest_edge": size["max_pixels"], } super().__init__(size=remapped_size, **kwargs) else: @@ -1414,7 +1649,6 @@ class Tarsier2ImageProcessor(Qwen2VLImageProcessor): class Tarsier2Processor(Qwen2VLProcessor): - def __init__( self, vision_config: dict, @@ -1427,11 +1661,11 @@ class Tarsier2Processor(Qwen2VLProcessor): tokenizer=tokenizer, video_processor=Qwen2VLVideoProcessor(**vision_config), chat_template=None, - **kwargs) + **kwargs, + ) class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self) -> Qwen2VLConfig: model_path = self.ctx.model_config.model original_config = AutoConfig.from_pretrained(model_path) @@ -1448,17 +1682,20 @@ class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): ) def get_image_processor(self) -> Tarsier2ImageProcessor: - return Tarsier2ImageProcessor( - **self.ctx.get_hf_image_processor_config()) + return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config()) -@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor, - info=Tarsier2ProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Tarsier2MultiModalProcessor, + info=Tarsier2ProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "vision_tower.": "visual.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "vision_tower.": "visual.", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig @@ -1469,9 +1706,7 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): vllm_config.model_config.hf_config = qwen2vl_config super().__init__(vllm_config=vllm_config, prefix=prefix) - def load_weights(self, weights: Iterable[tuple[str, 
- torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 2060206633702..bcd4968ba5c46 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Any, Optional, Union @@ -35,26 +36,22 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - maybe_prefix) +from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix logger = init_logger(__name__) class Qwen3Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -132,7 +129,9 @@ class Qwen3Attention(nn.Module): **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}, + } + if dual_chunk_attention_config + else {}, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -145,12 +144,10 @@ class Qwen3Attention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Add qk-norm - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) @@ -160,7 +157,6 @@ class Qwen3Attention(nn.Module): class Qwen3DecoderLayer(nn.Module): - def __init__( self, config: Qwen3Config, @@ -173,9 +169,9 @@ class Qwen3DecoderLayer(nn.Module): # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) # By default, Qwen3 uses causal attention as it is a decoder-only model. 
# You can override the HF config with `is_causal=False` to enable @@ -193,8 +189,8 @@ class Qwen3DecoderLayer(nn.Module): num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, rope_scaling=rope_scaling, @@ -209,10 +205,10 @@ class Qwen3DecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -225,16 +221,14 @@ class Qwen3DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -252,13 +246,13 @@ ALL_DECODER_LAYER_TYPES = { "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Qwen3Model(Qwen2Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - decoder_layer_type=Qwen3DecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, decoder_layer_type=Qwen3DecoderLayer + ) class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): @@ -284,30 +278,33 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) @@ -321,24 +318,21 @@ class Qwen3ForCausalLM(nn.Module, 
SupportsLoRA, SupportsPP, SupportsEagle3): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 2812f79a66b70..0769378933d52 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -22,48 +22,63 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable +from itertools import islice from typing import Any, Optional, Union import torch from torch import nn -from transformers import Qwen3MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, 
make_layers, - maybe_prefix) +from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Qwen3MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -75,19 +90,24 @@ class Qwen3MoeMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -98,15 +118,17 @@ class Qwen3MoeMLP(nn.Module): class Qwen3MoeSparseMoeBlock(nn.Module): - def __init__( self, - config: Qwen3MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() + + config = vllm_config.model_config.hf_text_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group @@ -114,60 +136,79 @@ class Qwen3MoeSparseMoeBlock(nn.Module): self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) # Load balancing settings. 
vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) - self.experts = FusedMoE(num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=True, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
- orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] + assert hidden_states.dim() <= 2, ( + "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" + ) + is_input_1d = hidden_states.dim() == 1 + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - return final_hidden_states.view(orig_shape) + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + + # return to 1d if input is 1d + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states class Qwen3MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -208,19 +249,23 @@ class Qwen3MoeAttention(nn.Module): self.max_position_embeddings = max_position_embeddings self.dual_chunk_attention_config = dual_chunk_attention_config - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, @@ -241,7 +286,9 @@ class Qwen3MoeAttention(nn.Module): **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}, + } + if dual_chunk_attention_config + else {}, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -255,13 +302,11 @@ class Qwen3MoeAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Add qk-norm - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) @@ -271,24 +316,20 @@ class Qwen3MoeAttention(nn.Module): class Qwen3MoeDecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen3MoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - enable_eplb: bool = False, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + + config = vllm_config.model_config.hf_text_config + cache_config = vllm_config.cache_config + quant_config = 
vllm_config.quant_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -297,8 +338,8 @@ class Qwen3MoeDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -307,25 +348,27 @@ class Qwen3MoeDecoderLayer(nn.Module): # `mlp_only_layers` in the config. layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb) + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoeSparseMoeBlock( + vllm_config=vllm_config, prefix=f"{prefix}.mlp" + ) else: - self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -338,31 +381,26 @@ class Qwen3MoeDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Qwen3MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config + config = vllm_config.model_config.hf_text_config quant_config = vllm_config.quant_config parallel_config 
= vllm_config.parallel_config - enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts @@ -373,20 +411,19 @@ class Qwen3MoeModel(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - enable_eplb=enable_eplb), + lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + # Track layers for auxiliary hidden state outputs (EAGLE3) + self.aux_hidden_state_layers: tuple[int, ...] = () def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -397,7 +434,9 @@ class Qwen3MoeModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -408,15 +447,29 @@ class Qwen3MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + + aux_hidden_states = [] + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer), + start=self.start_layer, + ): + # Collect auxiliary hidden states if specified + if layer_idx in self.aux_hidden_state_layers: + aux_hidden_state = ( + hidden_states + residual if residual is not None else hidden_states + ) + aux_hidden_states.append(aux_hidden_state) hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) + + # Return auxiliary hidden states if collected + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -427,10 +480,10 @@ class Qwen3MoeModel(nn.Module): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -441,15 +494,24 @@ class Qwen3MoeModel(nn.Module): ] # Skip loading extra parameters for GPTQ/modelopt 
models. - ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", - ".v_scale", "_v_scale", ".weight_scale", - "_weight_scale", ".input_scale", "_input_scale") + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -479,8 +541,7 @@ class Qwen3MoeModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: @@ -505,23 +566,27 @@ class Qwen3MoeModel(nn.Module): continue # Skip loading extra parameters for GPTQ/modelopt models. - if name_mapped.endswith( - ignore_suffixes - ) and name_mapped not in params_dict: + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -533,8 +598,7 @@ class Qwen3MoeModel(nn.Module): continue # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith( - ignore_suffixes) and name not in params_dict: + if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -542,7 +606,8 @@ class Qwen3MoeModel(nn.Module): # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 @@ -553,45 +618,56 @@ class Qwen3MoeModel(nn.Module): else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, - MixtureOfExperts): +class Qwen3MoeForCausalLM( + nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts +): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_text_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Qwen3MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + # Only perform the following mapping when Qwen3MoeMLP exists + if getattr(config, "mlp_only_layers", []): + self.packed_modules_mapping["gate_up_proj"] = ( + [ + "gate_proj", + "up_proj", + ], + ) + self.model = Qwen3MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters self.expert_weights = [] @@ -643,8 +719,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): moe = layer.mlp @@ -653,6 +728,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -663,23 +745,21 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def 
compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() \ No newline at end of file + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py new file mode 100644 index 0000000000000..50629bb2e4a26 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next.py @@ -0,0 +1,1337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next model.""" + +from collections.abc import Iterable +from itertools import islice +from typing import Optional + +import torch +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fla.ops import ( + RMSNormGated, + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, +) +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.triton_utils import tl, triton +from vllm.utils import 
direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + quant_config = vllm_config.quant_config + + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + expert_gate=self.shared_expert_gate, + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
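+        # The gate/experts below operate on a flattened
+        # (num_tokens, hidden_size) view; the original shape is restored
+        # via `orig_shape` before the result is returned.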
+ orig_shape = hidden_states.shape + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_qkvz, + bias=False, + quant_config=quant_config, 
+ prefix=f"{prefix}.in_proj_qkvz", + ) + # ba_proj doesn't support blockwise fp8 quantization. + self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_ba, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_ba", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + ) + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=current_platform.current_device(), + dtype=config.torch_dtype, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. 
+ """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads, + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query, key, value + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + return torch.ops.vllm.gdn_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + + # 1. 
Set up dimensions for reshapes later + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) + projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = map( + lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) + ) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.1: process the mutli-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=attn_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec + ) + + beta = b.sigmoid() + # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) + + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. 
Recurrent attention + + # 3.1: process the mutlti-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype + ) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class Qwen3NextAttention(nn.Module): + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
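+            # e.g. 8 KV heads with tp_size=4 gives 2 KV heads per rank.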
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + partial_rotary_factor=config.partial_rotary_factor, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": self.dual_chunk_attention_config, + } + if self.dual_chunk_attention_config + else {}, + ) + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 + ) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim + ) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim + ) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + model_config=model_config, + 
cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.linear_attn", + ) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 + and (self.layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3NextSparseMoeBlock( + vllm_config=vllm_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.torch_dtype, + ), + ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.torch_dtype, + ), + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1 + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) + else: + assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( + f"shape must be the same {len(hidden_states.shape)}, " + f"{len(self.ffn_layer_scale.shape)}" + ) + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1 + ) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + parallel_config = vllm_config.parallel_config + lora_config = vllm_config.lora_config + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + 
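+        # When LoRA is enabled, reserve lora_extra_vocab_size extra embedding
+        # rows per adapter slot on top of the base vocabulary.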
lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. 
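+                # (with pipeline parallelism each rank only materializes its
+                # own layers, so checkpoint tensors for other ranks are dropped)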
+ if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3NextForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, ( + "Qwen3Next currently does not support prefix caching" + ) + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[SharedFusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + 
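+        # (physical experts = logical experts + redundant EPLB replicas,
+        #  split evenly across the expert-parallel group)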
self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def 
gdn_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="gdn_attention", + op_func=gdn_attention, + mutates_args=["output"], + fake_impl=gdn_attention_fake, +) + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid]( + g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 + ) + return g diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py new file mode 100644 index 0000000000000..828931716c8f9 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next MTP model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3_next import ( + Qwen3NextDecoderLayer, + Qwen3NextRMSNorm, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig + +from .interfaces import SupportsPP +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile +class Qwen3NextMultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config: 
Qwen3NextConfig = model_config.hf_config + + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fc", + ) + + self.layers = torch.nn.ModuleList( + Qwen3NextDecoderLayer( + vllm_config, + layer_type="full_attention", + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = spec_step_idx % self.num_mtp_layers + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, 
weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class Qwen3NextMTP(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + assert not cache_config.enable_prefix_caching, ( + "Qwen3NextMTP currently does not support prefix caching" + ) + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = Qwen3NextMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight 
in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py new file mode 100755 index 0000000000000..6eb9faabd1c7f --- /dev/null +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,1712 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3-Omni-Moe model (thinker part).""" + +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeConfig, + Qwen3OmniMoeThinkerConfig, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder, +) +from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( + Qwen3OmniMoeProcessor, +) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2_audio import ( + Qwen2AudioFeatureInputs, + Qwen2AudioProcessingInfo, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import ( + 
BaseMultiModalProcessor, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.sequence import IntermediateTensors + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) + +# yapf conflicts with isort for this block +# yapf: disable +from .qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) + +# yapf: enable +from .qwen2_5_vl import ( + Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLProcessingInfo, +) +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return feat_lengths, output_lengths + + +class Qwen3_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen3_VisionMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + ) + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + 
act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim + ) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.ln_q(x.view(-1, self.hidden_size)) + else: + x = self.ln_q(x).view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen3Omni_VisionTransformer(nn.Module): + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.num_grid_per_side = self.image_size // self.patch_size + self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + # vit pos embeding, TODO: spatial_patch_size vs patch_size + if self.apply_vit_abs_pos_embed: + self.pos_embed = nn.Embedding(self.num_grid_per_side**2, self.hidden_size) + else: + self.pos_embed = nn.Parameter( + torch.empty([1, self.num_grid_per_side**2, self.hidden_size]) + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + 
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(vision_config.depth) + ] + ) + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + if self.deepstack_visual_indexes is not None: + self.merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger_list.{layer_idx}", + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace( + 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device + ) + w_idxs = torch.linspace( + 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid 
duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 + + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype, device=self.device) + + embeds = self.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 + + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() + repeated = repeated.view( + t, h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + if self.apply_vit_abs_pos_embed: + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + hidden_states_list = [] + deepstack_visual_indexes = self.deepstack_visual_indexes + + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + if ( + deepstack_visual_indexes is not None + and layer_num in deepstack_visual_indexes + ): + hidden_states_list.append(hidden_states) + + hidden_states = self.merger(hidden_states) + + # processing deepstack + if deepstack_visual_indexes is not None: + processed_hidden_states_list = [hidden_states] + for idx, x in enumerate(hidden_states_list): + x = self.merger_list[idx](x) + processed_hidden_states_list.append(x) + # we cat the original visual features and deepstack features + # along the feature dim + hidden_states = torch.cat( + processed_hidden_states_list, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: 
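# Note: in the loop below, checkpoint weights named "attn.q."/"attn.k."/"attn.v."
# are folded into the fused "attn.qkv." parameter through that parameter's
# shard-aware weight_loader; every other weight falls through to the else branch
# and is loaded with its own weight_loader if present, else default_weight_loader.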
+ for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "deepstack_input_embeds": 0, + } +) +class Qwen3MoeLLMModel(Qwen3MoeModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.deepstack_multiscale_layer_start = 1 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in enumerate( + self.layers[self.start_layer : self.end_layer] + ): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +class Qwen3OmniMoeThinkerProcessingInfo( + Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo +): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config + + def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor: + processor = self.ctx.get_hf_processor( + Qwen3OmniMoeProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|audio_pad|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|image_pad|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|video_pad|>" 
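# Note: the three placeholder attributes above are set defensively; presumably
# not every processor version exposes audio_token/image_token/video_token, and
# the literals match the pad tokens used later in get_placeholder_str() and
# _get_raw_input_ids().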
+ return processor + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None, "video": None} + + +Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder + + +class Qwen3OmniMoeThinkerMultiModalProcessor( + Qwen2_5OmniThinkerMultiModalProcessor, +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: + length = x.shape[-1] + if length % hop_length != 0: + pad_length = hop_length - (length % hop_length) + x = np.pad(x, (0, pad_length), mode="constant", constant_values=0) + return x + + # NOTE: WhisperFeatureExtractor cannot handle an empty list of audios + if audios: + # NOTE: the Qwen3-Omni processor accepts "audio" + # To make sure the cache works with padding=True, we pre-pad + # the audio to a multiple of hop_length. + hop_length = self.info.get_feature_extractor().hop_length + mm_data["audio"] = [ + pad_to_hop_length(audio, hop_length) + if isinstance(audio, np.ndarray) + else (pad_to_hop_length(audio[0], hop_length), audio[1]) + for audio in audios + ] + mm_kwargs = dict( + **mm_kwargs, + ) + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if ( + "audio_feature_lengths" in hf_inputs + and "feature_attention_mask" in hf_inputs + and (audios := mm_data.get("audio", [])) + ): + hop_length = self.info.get_feature_extractor().hop_length + audio_num_frames = [] + for audio in audios: + audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio) + num_frame = ( + (audio_length // hop_length) + if audio_length % hop_length == 0 + else (audio_length // hop_length - 1) + ) + audio_num_frames.append(num_frame) + hf_inputs["feature_attention_mask"] = [ + torch.ones(num_frame) for num_frame in audio_num_frames + ] + hf_inputs["audio_feature_lengths"] = torch.tensor(audio_num_frames) + return hf_inputs + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + prompt_ids: list[int], + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
+ """ + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = False + if "video" in mm_kwargs: + for item in mm_kwargs["video"]: + if item and item["use_audio_in_video"].data: + use_audio_in_video = True + else: + use_audio_in_video = False + + if use_audio_in_video and "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + + # Special case with `use_audio_in_video=True` + if use_audio_in_video: + if is_update_applied: + prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video) + ( + prompt_ids, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + # normal case with `use_audio_in_video=False` + elif is_update_applied: + mm_placeholders = self._find_mm_placeholders( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + else: + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + + return prompt_ids, mm_placeholders + + def get_updates_use_audio_in_video( + self, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + shift = 0 + audio_token_id = thinker_config.audio_token_id + video_token_id = thinker_config.video_token_id + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + position_id_per_seconds = thinker_config.position_id_per_seconds + audio_token_indices = np.arange(next(iter([audio_len]))) + curr_video_grid_thw = next(iter([video_grid_thw])) + height = curr_video_grid_thw[1] // spatial_merge_size + width = curr_video_grid_thw[2] // spatial_merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width) + ).reshape(-1) + video_token_indices = ( + (video_token_indices + shift) + * next(iter([video_second_per_grid_t])) + * position_id_per_seconds + ) + video_data_index, audio_data_index = 0, 0 + updates = [audio_start_token_id] + while video_data_index < len(video_token_indices) and audio_data_index < len( + audio_token_indices + ): + if ( + video_token_indices[video_data_index] + <= audio_token_indices[audio_data_index] + ): + updates += [video_token_id] + video_data_index += 1 + else: + updates += [audio_token_id] + audio_data_index += 1 + if video_data_index < len(video_token_indices): + updates += [video_token_id] * (len(video_token_indices) - video_data_index) + if audio_data_index < len(audio_token_indices): + updates += [audio_token_id] * (len(audio_token_indices) - audio_data_index) + updates += [audio_end_token_id] + return updates + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = 
processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = _get_feat_extract_output_lengths( + audio_feature_lengths + ) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1) + ) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + audio_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + nonlocal audio_item_idx + item_idx += audio_in_video_item_idx + + audio_item_idx += 1 + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model" + ) + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + audio_num_features = audio_output_lengths[audio_item_idx + item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return self.get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video + if use_audio_in_video + else partial(get_replacement_qwen2_vision, modality="video") + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + ) -> None: + BaseMultiModalProcessor[ + Qwen2_5OmniThinkerProcessingInfo + ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts) + + def _get_raw_input_ids( + self, + token_ids: list[int], + use_audio_in_video: bool = False, + ) -> list[int]: + tokenizer = 
self.info.get_tokenizer() + vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0] + vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0] + audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0] + audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0] + audio_token = tokenizer.encode("<|audio_pad|>")[0] + image_token = tokenizer.encode("<|image_pad|>")[0] + video_token = tokenizer.encode("<|video_pad|>")[0] + + result = token_ids[:] + if use_audio_in_video: + while True: + start = None + for i in range(len(result) - 1): + if result[i : i + 2] == [vision_bos_token, audio_bos_token]: + start = i + break + if start is not None: + end = None + for i in range(start + 2, len(result) - 1): + if result[i : i + 2] == [audio_eos_token, vision_eos_token]: + end = i + break + if end is not None: + result = ( + result[:start] + + [vision_bos_token, video_token, vision_eos_token] + + result[end + 2 :] + ) + else: + break + + for mm_token in [audio_token, image_token, video_token]: + compressed = [] + for x in result: + if x != mm_token or (not compressed or compressed[-1] != mm_token): + compressed.append(x) + result = compressed + + return result + + +class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin): + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str, dim: int = 0 + ) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") + if name == "feature_attention_mask": + dim = -1 + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + if isinstance(mm_input[0], list): + return torch.concat( + [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))], + dim=dim, + ) + else: + return torch.concat(mm_input, dim=dim) + + def _process_audio_input( + self, + audio_input: Qwen2AudioFeatureInputs, + audio_hashes: list[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + + if not isinstance(audio_feature_lengths, torch.Tensor): + audio_feature_lengths = torch.cat(audio_feature_lengths) + if audio_feature_lengths.ndim == 2: + audio_feature_lengths = audio_feature_lengths.reshape(-1) + + audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( + audio_feature_lengths + ) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + Qwen3OmniMoeConditionalGenerationMixin, +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return 
"<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if modality.startswith("audio"): + return "<|audio_start|><|audio_pad|><|audio_end|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen3OmniMoeThinkerConfig = ( + vllm_config.model_config.hf_config.thinker_config + ) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. + if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part." + ) + + self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) + + self.visual = Qwen3Omni_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config.with_hf_config( + thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"] + ), + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = hasattr( + thinker_config.vision_config, "deepstack_visual_indexes" + ) + self.deepstack_num_level = ( + len(thinker_config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + # register buffer for deepstack + self.deepstack_input_embeds = ( + [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + thinker_config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + if self.use_deepstack + else None + ) + self.visual_dim = thinker_config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ + :num_tokens + ] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: + # set deepstack_input_embeds to buffer + num_tokens = deepstack_input_embeds.size(1) + if num_tokens > self.deepstack_input_embeds[0].size(0): + self.deepstack_input_embeds = [ + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) + for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_( + deepstack_input_embeds[idx] + ) + + def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + 
self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key in ("input_audio_features",) + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The resulting multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image, video or audio). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._get_text_embeddings( + input_ids, + self.language_model.get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + deepstack_input_embeds = None + # TODO (ywang96): support overlapping modality embeddings so that + # `use_audio_in_video` will work on V1.
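# Illustrative shape note (the numbers are examples, not taken from the config):
# with a text hidden size of 2048 and 3 deepstack levels, each vision item
# reaches this point as [n_tokens, 2048 * (1 + 3)] = [n_tokens, 8192]; the block
# below splits it into a main part [n_tokens, 2048] that is merged into
# inputs_embeds and a multi-scale part [n_tokens, 6144] that is stashed through
# _set_deepstack_input_embeds() as [3, n_prompt_tokens, 2048]. Audio embeddings
# are already [m_tokens, 2048], so the shape[-1] check routes them as non-vision.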
+ # split the feat dim to obtain multi-scale visual feature + has_vision_embeddings = [ + embeddings.shape[-1] != self.config.text_config.hidden_size + for embeddings in multimodal_embeddings + ] + if self.visual.deepstack_visual_indexes is not None and any( + has_vision_embeddings + ): + multiscale_len = len(self.visual.deepstack_visual_indexes) + multimodal_embeddings_multiscale = [] + is_vision = torch.zeros_like(is_multimodal) + mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0] + mm_position_idx = 0 + for index, embeddings in enumerate(multimodal_embeddings): + num_tokens = embeddings.shape[0] + current_positions = mm_positions[ + mm_position_idx : mm_position_idx + num_tokens + ] + + # Vision embeddings + if embeddings.shape[-1] != self.config.text_config.hidden_size: + visual_dim = embeddings.shape[-1] // (multiscale_len + 1) + multi_dim = visual_dim * multiscale_len + embeddings_main, embeddings_multiscale = torch.split( + embeddings, [visual_dim, multi_dim], dim=-1 + ) + multimodal_embeddings[index] = embeddings_main + multimodal_embeddings_multiscale.append(embeddings_multiscale) + is_vision[current_positions] = True + + # Audio embeddings + else: + is_vision[current_positions] = False + + mm_position_idx += num_tokens + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1) + ) + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_vision, + ) + deepstack_input_embeds = ( + deepstack_input_embeds.view( + inputs_embeds.shape[0], multiscale_len, visual_dim + ) + .permute(1, 0, 2) + .contiguous() + ) + self._set_deepstack_input_embeds(deepstack_input_embeds) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + if ( + self.use_deepstack + and inputs_embeds is not None + and get_pp_group().is_first_rank + ): + deepstack_input_embeds = self._get_deepstack_input_embeds( + inputs_embeds.size(0) + ) + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + self._clear_deepstack_input_embeds(inputs_embeds.size(0)) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "code2wav."], + ) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + return loaded_weights + + @classmethod + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], 
+ second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + config = hf_config.thinker_config + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + input_ids = torch.tensor(input_tokens) + if input_ids is None or input_ids.ndim != 1: + raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") + + seq_len = input_ids.shape[0] + if audio_feature_lengths is not None and not isinstance( + audio_feature_lengths, torch.Tensor + ): + audio_feature_lengths = torch.as_tensor( + audio_feature_lengths, dtype=torch.long + ) + if second_per_grid_ts is None: + if video_grid_thw is not None and video_grid_thw.numel() > 0: + second_per_grids = torch.ones( + video_grid_thw.shape[0], dtype=torch.float32 + ) + else: + second_per_grids = torch.tensor([], dtype=torch.float32) + else: + second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + audio_token_id = config.audio_token_id + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + else: + vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype) + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + + llm_pos_ids_list: list[torch.Tensor] = [] + st = 0 + image_idx = 0 + video_idx = 0 + audio_idx = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501 + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) # noqa: E501 + + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + + if min_ed == ed_audio_start: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + 
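# Editor's illustration of the length math above, derived from
# _get_feat_extract_output_lengths(): the mel-frame count is roughly divided by 8
# by the conv/pooling stack, plus 13 output embeddings for every full chunk of
# 100 frames, e.g.
#     >>> _get_feat_extract_output_lengths(torch.tensor([80, 180]))
#     (tensor([40, 40]), tensor([10, 23]))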
llm_pos_ids = ( + torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == image_token_id + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = torch.arange(grid_t) * position_id_per_seconds + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == video_token_id + and not use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + elif ( + min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + and use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + bos_block = ( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) 
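# Note on the use_audio_in_video branch below: the prompt carries the vision and
# audio bos/eos markers back to back, which is why bos_block and eos_block are
# each appended twice. The audio and video position blocks are then interleaved
# by comparing their temporal ids (row 0 of the 3xN m-RoPE grid), so placeholder
# tokens are ordered by the time they represent rather than by modality.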
+ llm_pos_ids_list.append(bos_block) + llm_pos_ids_list.append(bos_block) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + audio_llm_pos_ids = ( + torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if ( + video_llm_pos_ids[0][video_data_index] + <= audio_llm_pos_ids[0][audio_data_index] + ): + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_data_index + 1 + ] + ) + video_data_index += 1 + else: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_data_index + 1 + ] + ) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_llm_pos_ids.shape[-1] + ] + ) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_llm_pos_ids.shape[-1] + ] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + eos_block = ( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(eos_block) + llm_pos_ids_list.append(eos_block) + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501 + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + if llm_positions.shape[1] != seq_len: + raise RuntimeError("Position ids length mismatch with input ids length") + + mrope_position_delta = llm_positions.max() + 1 - seq_len + return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py new file mode 100644 index 0000000000000..6a7d2eaeab3b8 --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl.py @@ -0,0 +1,1790 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" + +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from itertools import islice +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import BatchFeature, PretrainedConfig +from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast +from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + smart_resize as image_smart_resize, +) +from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor +from transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLConfig, + Qwen3VLVisionConfig, +) +from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( + smart_resize as video_smart_resize, +) +from transformers.video_utils import VideoMetadata + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItem, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_5_vl import ( + Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) +from .qwen2_vl import Qwen2VLProcessingInfo +from .qwen3 import Qwen3ForCausalLM, Qwen3Model +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model + +logger = init_logger(__name__) + +# Official 
recommended max pixels is 24576 * 32 * 32 +_MAX_FRAMES_PER_VIDEO = 24576 + + +class Qwen3_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen3_VisionMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel, + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend: _Backend = _Backend.TORCH_SDPA, + use_upstream_fa: bool = False, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + attn_backend=attn_backend, + use_upstream_fa=use_upstream_fa, + ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + 
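# Note: with the default spatial_merge_size of 2, hidden_size below becomes
# context_dim * 4, i.e. the merger flattens each 2x2 neighborhood of patch
# embeddings into one vector before the linear_fc1 -> GELU -> linear_fc2
# projection to d_model, reducing the visual token count by a factor of 4.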
self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer(context_dim) + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel, + ) + self.act_fn = nn.GELU() + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + + x_parallel, _ = self.linear_fc1(x) + x_parallel = self.act_fn(x_parallel) + out, _ = self.linear_fc2(x_parallel) + return out + + +class Qwen3_VisionTransformer(nn.Module): + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.use_data_parallel = use_data_parallel + self.num_grid_per_side = int(self.num_position_embeddings**0.5) + + # NOTE: This is used for creating empty tensor for all_gather for + # DP ViT. 
Here out_hidden_size is enlarged due to deepstack + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes) + ) + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype() + ) + use_upstream_fa = False + if ( + self.attn_backend != _Backend.FLASH_ATTN + and self.attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): + self.attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Qwen3-VL does not support {self.attn_backend} backend now." 
+ ) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(vision_config.depth) + ] + ) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + # Support both Tensor and list inputs for DP path + if isinstance(grid_thw, list): + grid_list = grid_thw + max_grid_size = max(max(h, w) for _, h, w in grid_list) + else: + grid_list = grid_thw.tolist() + max_grid_size = int(grid_thw[:, 1:].max().item()) + for t, h, w in grid_list: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace( + 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device + ) + w_idxs = torch.linspace( + 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 + + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], 
dim=0).reshape(4, -1, 1) + weights = weights.to( + dtype=self.dtype, device=self.device, non_blocking=True + ) + + embeds = self.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 + + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() + repeated = repeated.view( + t, h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True) + + grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32) + + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + + hidden_states = hidden_states.unsqueeze(1) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + 
loaded_params.add(name) + return loaded_params + + +class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLConfig) + + def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: + return self.ctx.get_hf_processor( + Qwen3VLProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_tokenizer(self): + return self.ctx.tokenizer + + def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast: + return self.get_hf_processor(**kwargs).image_processor + + def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor: + return self.get_hf_processor(**kwargs).video_processor + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 2, + do_resize: bool = True, + image_processor: Optional[ + Union[Qwen2VLImageProcessorFast, Qwen3VLVideoProcessor] + ], + ) -> tuple[ImageSize, int]: + if image_processor is None and num_frames > 1: + image_processor = self.get_video_processor() + elif image_processor is None: + image_processor = self.get_image_processor() + + is_video = isinstance(image_processor, Qwen3VLVideoProcessor) + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + + if do_resize: + if is_video: + smart_resize = video_smart_resize + extra_kwargs = { + "num_frames": num_frames, + "temporal_factor": temporal_patch_size, + } + else: + smart_resize = image_smart_resize + extra_kwargs = {} + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.size["shortest_edge"], + max_pixels=image_processor.size["longest_edge"], + **extra_kwargs, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, height=image_height) + + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int: + return super()._get_max_video_frames( + max_tokens, start_num_frames=start_num_frames + ) + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + return super().get_num_frames_with_most_features( + seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO + ) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + video_soft_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), + image_processor=None, + ) + + # NOTE: By default in Qwen3-VL, one video token is converted to + # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501 + formatted_video_soft_tokens = video_soft_tokens * 12.5 + return int(formatted_video_soft_tokens) + + def _calculate_timestamps( + self, indices: list[int] | 
torch.Tensor, video_fps: float, merge_size: int + ): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + # don't update metadata's frames_indices directly + indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size) + timestamps = [idx / video_fps for idx in indices] + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size) + ] + return timestamps + + def _get_video_second_idx( + self, + metadata: dict[str, Any], + out_item: MultiModalKwargsItem, + do_sample_frames: Optional[bool] = None, + sampled_fps: Optional[float] = None, + ) -> list[int]: + video_processor = self.get_video_processor() + merge_size = video_processor.merge_size + indices = metadata["frames_indices"] + + # metadata["fps"] refers to the true fps of the input video. + video_fps = metadata["fps"] + if do_sample_frames is None: + do_sample_frames = metadata.get("do_sample_frames", False) + + # If video frames are sampled in HF processor (instead of vLLM + # video loader), we need to re-calculate the indices from original + # metadata. + if do_sample_frames: + # here video_fps is the fps of the sampled video, and + # metadata["fps"] refers to the fps of the original video. + video_fps = sampled_fps if sampled_fps else video_processor.fps + total_num_frames = metadata["total_num_frames"] + num_frames = int(total_num_frames / metadata["fps"] * video_fps) + num_frames = min( + min( + max(num_frames, video_processor.min_frames), + video_processor.max_frames, + ), + total_num_frames, + ) + indices = ( + np.linspace(0, total_num_frames - 1, num_frames) + .round() + .astype(int) + .tolist() + ) + timestamps = self._calculate_timestamps(indices, video_fps, merge_size) + return timestamps + + +class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_token = "<|vision_start|><|image_pad|><|vision_end|>" + video_token = "<|vision_start|><|video_pad|><|vision_end|>" + + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + if video_overrides: + assert isinstance(video_overrides, VideoDummyOptions) + num_frames_override = video_overrides.num_frames + if num_frames_override: + if num_frames_override > target_num_frames: + logger.warning( + "video.num_frames override (%d) exceeds model's " + "maximum number of frames (%d), will be ignored", + num_frames_override, + target_num_frames, + ) + if num_frames_override < 2: + logger.warning( + "video.num_frames override (%d) cannot be less " + "than 2, will be ignored", + num_frames_override, + ) + target_num_frames = min(target_num_frames, num_frames_override) + target_num_frames = max(target_num_frames, 2) + + target_video_size, _ = self.info._get_vision_info( + image_width=target_width, + image_height=target_height, + 
num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
+        # NOTE: we need to do this check here since Qwen3-VL resizes video
+        # frames depending on how many frames there are.
+        width, height = target_video_size.width, target_video_size.height
+        if video_overrides:
+            assert isinstance(video_overrides, VideoDummyOptions)
+            width_override = video_overrides.width
+            if width_override:
+                if width_override > width:
+                    logger.warning(
+                        "video.width override (%d) exceeds model's "
+                        "maximum width (%d), will be ignored",
+                        width_override,
+                        width,
+                    )
+                width = min(width, width_override)
+            height_override = video_overrides.height
+            if height_override:
+                if height_override > height:
+                    logger.warning(
+                        "video.height override (%d) exceeds model's "
+                        "maximum height (%d), will be ignored",
+                        height_override,
+                        height,
+                    )
+                height = min(height, height_override)
+
+        return {
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            ),
+            "video": self._get_dummy_videos(
+                width=width,
+                height=height,
+                num_frames=target_num_frames,
+                num_videos=num_videos,
+            ),
+        }
+
+    def _get_dummy_videos(
+        self,
+        *,
+        width: int,
+        height: int,
+        num_frames: int,
+        num_videos: int,
+    ) -> list[VideoItem]:
+        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
+        video_items = []
+        for i in range(num_videos):
+            video_metadata = {
+                "fps": 2.0,
+                "duration": num_frames / 2.0,
+                "total_num_frames": num_frames,
+                "frames_indices": [i for i in range(num_frames)],
+                "video_backend": "opencv",
+                "do_sample_frames": False,
+            }
+            video_item = (video.copy(), video_metadata)
+            video_items.append(video_item)
+        return video_items
+
+
+class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(video_needs_metadata=True)
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        mm_data = dict(mm_data)
+        processor = self.info.get_hf_processor(**mm_kwargs)
+
+        # Separate video processing from image processing, because the videos
+        # are processed into several image patches
+        if (
+            "videos" in mm_data
+            and isinstance(mm_data["videos"], list)
+            and len(mm_data["videos"]) > 0
+        ):
+            video_grid_thw_lst = []
+            pixel_values_videos_lst = []
+
+            for item_idx, item in enumerate(mm_data.pop("videos", [])):
+                video_array, metadata = item
+
+                # NOTE: @JJJYmmm new attr metadata.frames_indices indicates
+                # the sampled frames indices of pre-sampled videos, which is
+                # used to calculate the timestamps. Make sure that
+                # do_sample_frames in mm_kwargs is false for pre-sampled videos.
+
+                # NOTE: a copy of mm_kwargs is created to update do_sample_frames,
+                # otherwise mm_hash for the object will be incorrect.
+                video_mm_kwargs = dict(**mm_kwargs)
+                if "do_sample_frames" not in video_mm_kwargs:
+                    # qwen_vl_utils already has "do_sample_frames" in
+                    # mm_kwargs, don't overwrite it.
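+                    # Pre-sampled videos fall back to the flag recorded in
+                    # their metadata, defaulting to False.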
+ video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", False + ) + + metadata = VideoMetadata( + **{k: metadata[k] for k in metadata if k != "do_sample_frames"} + ) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt="<|vision_start|><|video_pad|><|vision_end|>", + mm_data=video_mm_data, + mm_kwargs=video_mm_kwargs, + tok_kwargs=tok_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] + prompt = prompt.replace( + "<|vision_start|><|video_pad|><|vision_end|>", + video_placeholder, + 1, + ) + + video_grid_thw_lst.append(video_outputs["video_grid_thw"]) + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) + video_outputs = dict( + pixel_values_videos=torch.cat(pixel_values_videos_lst), + video_grid_thw=torch.cat(video_grid_thw_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes + ), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes + ), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes + ), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + vision_end_token_id = hf_config.vision_end_token_id + + merge_length = image_processor.merge_size**2 + + def get_image_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [hf_processor.image_token_id] * num_tokens + + def get_video_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + video, metadata = mm_items["video"][item_idx] + do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") + sampled_fps = hf_processor_mm_kwargs.get("fps") + if is_list_of(sampled_fps, float): + sampled_fps = sampled_fps[item_idx] + timestamps = self.info._get_video_second_idx( + metadata, 
out_item, do_sample_frames, sampled_fps + ) + + assert len(timestamps) == grid_thw[0], ( + f"The timestamps length({len(timestamps)}) should be equal " + f"video length ({grid_thw[0]})." + ) + + frames_idx_token = [ + tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) + for curr_time in timestamps + ] + num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + placeholder = [] + for frame_idx in frames_idx_token: + placeholder.extend(frame_idx) + placeholder.extend( + [vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id] + ) + return PromptUpdateDetails.select_token_id(placeholder, video_token_id) + + return [ + PromptReplacement( + modality="image", + target=hf_processor.image_token, + replacement=get_image_replacement_qwen3vl, + ), + # NOTE: We match string on purpose since searching sequence of + # token ids takes more time. + PromptReplacement( + modality="video", + target="<|vision_start|><|video_pad|><|vision_end|>", + replacement=get_video_replacement_qwen3vl, + ), + ] + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0, + } +) +class Qwen3LLMModel(Qwen3Model): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes + ), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)" + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + # args for deepstack + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer + ): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3LLMForCausalLM(Qwen3ForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3ForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + if 
config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head", + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): + super().__init__() + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + + self.language_model = Qwen3LLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + # register buffer for deepstack + if self.use_deepstack and self.visual is not None: + self.deepstack_input_embeds = [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + else: + self.deepstack_input_embeds = None + self.visual_dim = config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ + :num_tokens + ] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: + # set deepstack_input_embeds to buffer + 
num_tokens = deepstack_input_embeds.size(1) + if num_tokens > self.deepstack_input_embeds[0].size(0): + self.deepstack_input_embeds = [ + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) + for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_( + deepstack_input_embeds[idx] + ) + + def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError( + f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) + return mm_input.reshape(-1, mm_input.shape[-1]) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values" + ) + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw" + ) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}" + ) + + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds" + ) + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw" + ) + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError( + "Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}" + ) + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values" + ) + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw" + ) + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds" + ) + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw" + ) + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError( + "Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}" + ) + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + # Split concatenated embeddings for each image item. + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() + return image_embeds.split(sizes) + + def _process_video_input( + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype + ) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) + + # Split concatenated embeddings for each video item. 
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() + return video_embeds.split(sizes) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + return mm_input_by_modality + + @classmethod + def get_mrope_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + second_per_grid_ts: Optional[list[float]] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx 
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+        return llm_positions, mrope_position_delta
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
+    def get_multimodal_embeddings(
+        self, **kwargs: object
+    ) -> Optional[MultiModalEmbeddings]:
+        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not mm_input_by_modality:
+            return None
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in mm_input_by_modality:
+            multimodal_input = mm_input_by_modality[modality]
+            if modality == "image":
+                vision_embeddings = self._process_image_input(multimodal_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "video":
+                video_embeddings = self._process_video_input(multimodal_input)
+                multimodal_embeddings += video_embeddings
+        return multimodal_embeddings
+
+    def _compute_deepstack_embeds(
+        self,
+        inputs_embeds: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        is_multimodal: torch.Tensor,
+    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
+        visual_lens = [len(x) for x in multimodal_embeddings]
+        multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
+
+        (
+            multimodal_embeddings_main,
+            multimodal_embeddings_multiscale,
+        ) = torch.split(
+            multimodal_embeddings_cat,
+            [self.visual_dim, self.multiscale_dim],
+            dim=-1,
+        )
+
+        multimodal_embeddings = torch.split(
+            multimodal_embeddings_main, visual_lens, dim=0
+        )
+        multimodal_embeddings_multiscale = torch.split(
+            multimodal_embeddings_multiscale, visual_lens, dim=0
+        )
+
+        deepstack_input_embeds = inputs_embeds.new_zeros(
+            inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1)
+        )
+
+        deepstack_input_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=deepstack_input_embeds,
+            multimodal_embeddings=multimodal_embeddings_multiscale,
+            is_multimodal=is_multimodal,
+        )
+        deepstack_input_embeds = deepstack_input_embeds.view(
+            inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim
+        )
+        deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
+
+        return deepstack_input_embeds, multimodal_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.language_model.get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229."
+            )
+
+        if self.use_deepstack:
+            (
+                deepstack_input_embeds,
+                multimodal_embeddings,
+            ) = self._compute_deepstack_embeds(
+                inputs_embeds=inputs_embeds,
+                multimodal_embeddings=multimodal_embeddings,
+                is_multimodal=is_multimodal,
+            )
+        else:
+            deepstack_input_embeds = None
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
+
+        if deepstack_input_embeds is not None:
+            self._set_deepstack_input_embeds(deepstack_input_embeds)
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        """Run forward pass for Qwen3VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen3VL
+                open-source models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+            intermediate_tensors: Intermediate tensors from previous pipeline
+                stages.
+            inputs_embeds: Pre-computed input embeddings.
+            **kwargs: Additional keyword arguments including:
+                - pixel_values: Pixel values to be fed to a model.
+                    `None` if no images are passed.
+                - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
+                    LLM. `None` if no images are passed.
+                - pixel_values_videos: Pixel values of videos to be fed to a
+                    model. `None` if no videos are passed.
+                - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
+                    LLM. `None` if no videos are passed.
+ """ + + if intermediate_tensors is not None: + inputs_embeds = None + + if ( + self.use_deepstack + and inputs_embeds is not None + and get_pp_group().is_first_rank + ): + deepstack_input_embeds = self._get_deepstack_input_embeds( + inputs_embeds.size(0) + ) + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + self._clear_deepstack_input_embeds(inputs_embeds.size(0)) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="model.visual.merger", + tower_model="model.visual.", + ) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py new file mode 100644 index 0000000000000..db7bcb0436595 --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights.""" + +import typing +from collections.abc import Iterable +from itertools import islice +from typing import Callable, Optional, Union + +import torch +from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, + Qwen3VLProcessingInfo, +) +from .utils import is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) + + +class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLMoeConfig) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0, + } +) +class Qwen3MoeLLMModel(Qwen3MoeModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes + ), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)" + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer + ): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + 
loaded_local_expert = False + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + success = weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + + return loaded_local_expert + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = self.config.num_experts + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. 
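+                    # A remapped name of None means this scale has no
+                    # counterpart in the model and can be skipped.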
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + continue + if is_fused_expert: + loaded_weight = loaded_weight.transpose(-1, -2) # no bias + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + success_w1 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + success_w3 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + success = success_w1 and success_w3 + else: + # down_proj + success = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + else: + # Skip loading extra parameters for GPTQ/modelopt models + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale" + ) + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + self.config = vllm_config.model_config.hf_config.text_config + self.quant_config = vllm_config.quant_config + self.model = Qwen3MoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3VLForConditionalGeneration, self).__init__() + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + # register buffer for deepstack + if self.use_deepstack and self.visual is not None: + self.deepstack_input_embeds = [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + else: + self.deepstack_input_embeds = None + self.visual_dim = config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 2950ca664a98f..1786ea6a6878b 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -11,63 +11,87 @@ import math import unicodedata from collections.abc import Collection, Mapping, Sequence, Set from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, 
Optional, Union import regex as re import torch from torch import nn from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, - TensorType) +from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .qwen import QWenBaseModel, QWenModel -from .utils import flatten_bn, merge_multimodal_embeddings +from .utils import flatten_bn -class QwenImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor +class QwenImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, 3, image_size, image_size)` + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width Note that image_size is the value in the vision config to which we resize the image to in the normalization transform. Currently multi-image support can only be leveraged by passing image embeddings directly. """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class QwenImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, 256, hidden_size)` + +class QwenImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size (256) + - hs: Hidden size `hidden_size` must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one. 
""" + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")] + QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] @@ -90,8 +114,7 @@ class VisualAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim \ - and self.vdim == embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads @@ -102,8 +125,9 @@ class VisualAttention(nn.Module): self.hidden_size_per_partition = embed_dim # Strided linear layer. - assert self._qkv_same_embed_dim, \ - 'Visual Attention implementation only supports self-attention' + assert self._qkv_same_embed_dim, ( + "Visual Attention implementation only supports self-attention" + ) self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) self.out_proj = ReplicatedLinear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -118,50 +142,63 @@ class VisualAttention(nn.Module): mixed_x_layer, _ = self.in_proj(x) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] query_layer, key_layer, value_layer = mixed_x_layer.split( - self.hidden_size_per_attention_head, dim=-1) + self.hidden_size_per_attention_head, dim=-1 + ) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) q_scaled = query_layer / self.norm_factor if attn_mask is not None: - attention_probs = torch.baddbmm(attn_mask, q_scaled, - key_layer.transpose(-2, -1)) + attention_probs = torch.baddbmm( + attn_mask, q_scaled, key_layer.transpose(-2, -1) + ) else: attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) attention_probs = attention_probs.softmax(dim=-1) value_layer = value_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer) # change view [b, np, sq, hn] context_layer = context_layer.view( - b, self.num_attention_heads_per_partition, sq, - self.hidden_size_per_attention_head) + b, + self.num_attention_heads_per_partition, + sq, + self.hidden_size_per_attention_head, + ) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + ( + 
self.hidden_size_per_partition, + ) context_layer = context_layer.view(*new_context_layer_shape) output, _ = self.out_proj(context_layer) @@ -179,10 +216,9 @@ class QwenVLMLP(nn.Module): quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.c_fc = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config) + self.c_fc = ColumnParallelLinear( + hidden_size, intermediate_size, bias=True, quant_config=quant_config + ) self.act_fn = get_act_fn("gelu") self.c_proj = RowParallelLinear( intermediate_size, @@ -199,7 +235,6 @@ class QwenVLMLP(nn.Module): class VisualAttentionBlock(nn.Module): - def __init__( self, d_model: int, @@ -239,7 +274,6 @@ class VisualAttentionBlock(nn.Module): class TransformerBlock(nn.Module): - def __init__( self, width: int, @@ -253,14 +287,18 @@ class TransformerBlock(nn.Module): self.width = width self.layers = layers - self.resblocks = nn.ModuleList([ - VisualAttentionBlock(width, - heads, - mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config) - for _ in range(layers) - ]) + self.resblocks = nn.ModuleList( + [ + VisualAttentionBlock( + width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) + for _ in range(layers) + ] + ) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -268,54 +306,57 @@ class TransformerBlock(nn.Module): def get_cast_device(self) -> torch.device: return self.resblocks[0].mlp.c_fc.weight.device - def forward(self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: for r in self.resblocks: x = r(x, attn_mask=attn_mask) return x class VisionTransformer(nn.Module): - - def __init__(self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - image_start_id: int = 151857, - quant_config: Optional[QuantizationConfig] = None, - **kwargs): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs, + ): super().__init__() image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) - self.grid_size = (image_height // patch_height, - image_width // patch_width) + self.grid_size = (image_height // patch_height, image_width // patch_width) self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False) + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) # class embeddings and positional embeddings scale = width**-0.5 - self.positional_embedding = nn.Parameter(scale * - torch.randn(256, width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_pre = norm_layer(width) - self.transformer = TransformerBlock(width, - layers, - heads, - mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) self.attn_pool = Resampler2( 
grid_size=int(math.sqrt(n_queries)), @@ -332,7 +373,8 @@ class VisionTransformer(nn.Module): self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter( - (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + (output_dim**-0.5) * torch.randn(output_dim, output_dim) + ) self.image_start_id = image_start_id self.image_end_id = image_start_id + 1 @@ -346,12 +388,10 @@ class VisionTransformer(nn.Module): # to patches x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], - -1) # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( - x.size(1)))) + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(x.size(1)))) x = self.ln_pre(x) @@ -367,20 +407,19 @@ class VisionTransformer(nn.Module): class QwenVLModel(QWenModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.visual = VisionTransformer(**config.visual, - quant_config=quant_config) + self.visual = VisionTransformer(**config.visual, quant_config=quant_config) @lru_cache(maxsize=1) def _get_tokenizer_without_image_pad( - tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: + tokenizer: PreTrainedTokenizer, +) -> PreTrainedTokenizer: """ The logic of adding image pad tokens should only be applied in [`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor], @@ -392,7 +431,6 @@ def _get_tokenizer_without_image_pad( new_tokenizer = copy.deepcopy(tokenizer) class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore - def tokenize( self, text: str, @@ -403,7 +441,8 @@ def _get_tokenizer_without_image_pad( text = unicodedata.normalize("NFC", text) return [ - self.decoder[t] for t in self.tokenizer.encode( + self.decoder[t] + for t in self.tokenizer.encode( text, allowed_special=allowed_special, disallowed_special=disallowed_special, @@ -425,8 +464,7 @@ def _get_tokenizer_without_image_pad( errors=errors or self.errors, ) - TokenizerWithoutImagePad.__name__ = \ - f"{tokenizer.__class__.__name__}WithoutImagePad" + TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad" new_tokenizer.__class__ = TokenizerWithoutImagePad return new_tokenizer @@ -457,17 +495,19 @@ class QwenVLProcessor: vision_config = config.visual image_size = vision_config["image_size"] - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) @property def image_start_tag(self) -> str: @@ -514,7 +554,6 @@ class QwenVLProcessor: class QwenVLProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> PreTrainedTokenizer: tokenizer = self.ctx.tokenizer assert isinstance(tokenizer, PreTrainedTokenizer) @@ -543,7 +582,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo): class 
QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -551,13 +589,15 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): img_start = hf_processor.image_start_tag img_end = hf_processor.image_end_tag - return "".join(f"Picture {i}: {img_start}{img_end}\n" - for i in range(1, num_images + 1)) + return "".join( + f"Picture {i}: {img_start}{img_end}\n" for i in range(1, num_images + 1) + ) def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.visual @@ -565,16 +605,19 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -630,8 +673,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore + special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore processor = self.info.get_hf_processor() img_start_id = special_tokens[processor.image_start_tag] @@ -653,11 +695,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(QwenVLMultiModalProcessor, - info=QwenVLProcessingInfo, - dummy_inputs=QwenVLDummyInputsBuilder) -class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, - SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor( + QwenVLMultiModalProcessor, + info=QwenVLProcessingInfo, + dummy_inputs=QwenVLDummyInputsBuilder, +) +class QwenVLForConditionalGeneration( + QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal +): packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -673,7 +718,8 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, return MultiModelKeys.from_string_field( language_model="transformer.h", connector="transformer.visual.attn_pool", - tower_model="transformer.visual.transformer") + tower_model="transformer.visual.transformer", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -697,39 +743,33 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, self.transformer: QwenVLModel - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.visual["image_size"] - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[QwenImageInputs]: + self, **kwargs: object + ) -> Optional[QwenImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) + + expected_h = expected_w = self.config.visual["image_size"] + resolve_bindings = {"h": expected_h, "w": expected_w} return QwenImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + data=flatten_bn(pixel_values, concat=True), + resolve_bindings=resolve_bindings, ) if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) return QwenImageEmbeddingInputs( type="image_embeds", @@ -738,8 +778,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, return None - def _process_image_input(self, - image_input: QwenImageInputs) -> torch.Tensor: + def _process_image_input(self, image_input: QwenImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] @@ -748,8 +787,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, def get_language_model(self) -> torch.nn.Module: return self.transformer - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -757,21 +795,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.transformer.visual.image_pad_id) - - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -783,14 +806,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py new file mode 100644 index 0000000000000..2313b98348b77 --- /dev/null +++ b/vllm/model_executor/models/radio.py @@ -0,0 +1,583 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import math +from collections.abc import Iterable +from itertools import repeat +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import PretrainedConfig + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.intern_vit import InternVisionEncoder + +input_dim_t = Union[int, tuple[int, int]] +norm_t = Union[tuple[float, float, float], torch.Tensor] + + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +class InputConditioner(nn.Module): + def __init__( + self, + input_scale: float, + norm_mean: norm_t, + norm_std: norm_t, + dtype: torch.dtype = None, + ): + super().__init__() + + self.dtype = dtype + + self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) + self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) + + def forward(self, x: torch.Tensor): + y = (x - self.norm_mean) / self.norm_std + if self.dtype is not None: + y = y.to(self.dtype) + return y + + +def _to_tensor(v: norm_t): + return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) + + +class ClsToken(nn.Module): + def __init__( + self, + ndim: int, + num_tokens: int = 1, + enabled: bool = True, + register_multiple: Optional[int] = None, + num_registers: Optional[int] = None, + ): + super().__init__() + + self.ndim = ndim + self.enabled = enabled + self.num_registers = 0 + self.num_tokens = num_tokens + if enabled: + if num_registers: + self.num_registers = num_registers + elif register_multiple: + self.num_registers = register_multiple - ( + num_tokens % register_multiple + ) + + scale = ndim**-0.5 + self.token = nn.Parameter( + torch.randn(num_tokens + self.num_registers, ndim) * scale + ) + + else: + self.token = None + + self.num_patches = self.num_tokens + self.num_registers + + def forward(self, x: torch.Tensor): + if self.token is None: + return x + + token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) + x = torch.cat( + [ + token, + x, + ], + dim=1, + ) + + return x + + 
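+# ViTPatchGenerator (below) turns an image into a sequence of patch tokens:
+# Im2Patches rearranges the pixels into flattened patches, ViTPatchLinear
+# projects each 3 * patch_size**2 patch vector to embed_dim, the absolute
+# positional embedding is window-selected / interpolated to the actual input
+# resolution (CPE mode), and ClsToken prepends the CLS/register tokens.
+#
+# Rough shape walk-through, assuming an illustrative 432x432 input with
+# patch_size=16 (example values only, not taken from a specific RADIO config):
+#   x        : [B, 3, 432, 432]
+#   patches  : [B, 27 * 27, 3 * 16 * 16]    after Im2Patches
+#   embedded : [B, 729, embed_dim]          after ViTPatchLinear (+ pos enc)
+#   tokens   : [B, num_cls_patches + 729, embed_dim]   after ClsToken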
+class ViTPatchGenerator(nn.Module): + def __init__( + self, + # config: PretrainedConfig, + patch_size: int, + embed_dim: int, + input_dims: input_dim_t, + abs_pos: bool = True, + normalize_patches: bool = False, + cls_token: bool = False, + max_input_dims: Optional[input_dim_t] = None, + pos_dropout: float = 0.0, + return_pos_enc: bool = False, + num_cls_tokens: int = 1, + register_multiple: Optional[int] = None, + num_registers: Optional[int] = None, + patch_bias: bool = False, + device=None, + dtype=None, + ): + super().__init__() + if isinstance(input_dims, int): + input_dims = (input_dims, input_dims) + + if max_input_dims is None: + max_input_dims = input_dims + if isinstance(max_input_dims, int): + max_input_dims = (max_input_dims, max_input_dims) + + max_input_dims = tuple( + int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims + ) + + self.cpe_mode = max_input_dims != input_dims + self.pos_dropout = pos_dropout + self.return_pos_enc = return_pos_enc + + factory = dict(device=device, dtype=dtype) + + self.patch_size = patch_size + self.abs_pos = abs_pos + self.embed_dim = embed_dim + + self.num_rows = max_input_dims[0] // patch_size + self.num_cols = max_input_dims[1] // patch_size + self.input_dims = tuple(d // patch_size for d in input_dims) + self.num_patches = self.num_rows * self.num_cols + self.max_input_dims = max_input_dims + + self.im_to_patches = Im2Patches(patch_size) + self.embedder = ViTPatchLinear( + patch_size, embed_dim, bias=patch_bias, **factory + ) + + if abs_pos: + scale = embed_dim**-0.5 + self.pos_embed = nn.Parameter( + torch.randn(1, self.num_patches, embed_dim, **factory) * scale + ) + + self.cls_token = ClsToken( + embed_dim, + num_tokens=num_cls_tokens, + enabled=cls_token, + register_multiple=register_multiple, + num_registers=num_registers, + ) + + self.patch_normalizer = ( + nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + patches = self.embed_patches(x) + patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) + patches = self.cls_token(patches) + patches = self.patch_normalizer(patches) + if self.return_pos_enc: + return patches, pos_enc + return patches + + @property + def apply_cls_token(self): + return self.cls_token.enabled + + @property + def num_cls_tokens(self): + return self.cls_token.num_tokens + + @property + def num_cls_patches(self): + return self.cls_token.num_patches + + @property + def num_registers(self): + return self.cls_token.num_registers + + @property + def num_skip(self): + return self.num_cls_tokens + self.num_registers + + def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter): + if src_embed.shape != targ_embed.shape: + src_size = int(math.sqrt(src_embed.shape[1])) + + assert src_size**2 == src_embed.shape[1], ( + "Unable to interpolate non-square embedding" + ) + + src_embed = rearrange( + src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size + ) + src_embed = F.interpolate( + src_embed, + size=(self.num_rows, self.num_cols), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_embed = rearrange(src_embed, "b c h w -> b (h w) c") + targ_embed.data.copy_(src_embed) + + def _load_projection( + self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor + ): + if src_proj_weight.shape != targ_proj_weight.shape: + src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3)) + + assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], ( + "Unable to interpolate non-square 
patch size" + ) + + src_proj_weight = rearrange( + src_proj_weight, + "b (c h w) -> b c h w", + c=3, + h=src_patch_size, + w=src_patch_size, + ) + src_proj_weight = F.interpolate( + src_proj_weight, + size=(self.patch_size, self.patch_size), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)") + targ_proj_weight.data.copy_(src_proj_weight) + + def embed_patches(self, x: torch.Tensor) -> torch.Tensor: + patches = self.im_to_patches(x) + patches = self.embedder(patches) + return patches + + def apply_pos_enc( + self, + patches: torch.Tensor, + patch_idxs: Optional[torch.Tensor] = None, + input_size: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + if not self.abs_pos: + return patches + + pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size) + + if self.training and self.pos_dropout > 0: + keeps = ( + torch.rand( + patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device + ) + > self.pos_dropout + ) + pos_enc_drop = torch.where(keeps, pos_enc, 0) + else: + pos_enc_drop = pos_enc + + return patches + pos_enc_drop, pos_enc + + def get_pos_enc( + self, + batch_size: int, + patch_idxs: Optional[torch.Tensor] = None, + input_size: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + if input_size is None: + input_dims = self.input_dims + else: + input_dims = tuple(d // self.patch_size for d in input_size) + + pos_embed = self._get_pos_embeddings(batch_size, input_dims) + + if patch_idxs is None: + return pos_embed + + exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]) + + pos_embed = torch.gather( + pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs + ) + return pos_embed + + def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]): + if (self.num_rows, self.num_cols) == input_dims: + return self.pos_embed + + pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute( + 0, 3, 1, 2 + ) + + def window_select(pos_embed): + if input_dims[0] < pos_embed.shape[-2]: + pos_embed = pos_embed[..., : input_dims[0], :] + if input_dims[1] < pos_embed.shape[-1]: + pos_embed = pos_embed[..., :, : input_dims[1]] + return pos_embed + + if self.cpe_mode: + if self.training: + min_scale = math.sqrt(0.1) + scale = ( + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (1 - min_scale) + + min_scale + ) + aspect_min = math.log(3 / 4) + aspect_max = -aspect_min + aspect = torch.exp( + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (aspect_max - aspect_min) + + aspect_min + ) + + scale_x = scale * aspect + scale_y = scale * (1 / aspect) + scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1) + + pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * ( + 1 - scale_xy + ) + + lin_x = torch.linspace( + 0, 1, steps=input_dims[1], device=pos_embed.device + )[None, None].expand(batch_size, input_dims[0], -1) + lin_y = torch.linspace( + 0, 1, steps=input_dims[0], device=pos_embed.device + )[None, :, None].expand(batch_size, -1, input_dims[1]) + + lin_xy = torch.stack([lin_x, lin_y], dim=-1) + + grid_xy = lin_xy * scale_xy + pos_xy + + # Convert to [-1, 1] range + grid_xy.mul_(2).sub_(1) + + pos_embed = F.grid_sample( + pos_embed.float().expand(batch_size, -1, -1, -1), + grid=grid_xy, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ).to(pos_embed.dtype) + else: + max_dim = max(input_dims) + pos_embed = F.interpolate( + pos_embed.float(), + size=(max_dim, 
max_dim), + align_corners=True, + mode="bilinear", + ).to(pos_embed.dtype) + + pos_embed = window_select(pos_embed) + else: + pos_embed = window_select(pos_embed) + + if pos_embed.shape[-2:] != input_dims: + pos_embed = F.interpolate( + pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear" + ).to(pos_embed.dtype) + + pos_embed = pos_embed.flatten(2).permute(0, 2, 1) + + return pos_embed + + +class Im2Patches(nn.Module): + def __init__(self, patch_size: int): + super().__init__() + self.patch_size = patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.patch_size == 1: + patches = x.flatten(2) + patches = patches.permute(0, 2, 1) + return patches + + py = x.shape[-2] // self.patch_size + px = x.shape[-1] // self.patch_size + patches = rearrange( + x, + "b c (py yy) (px xx) -> b (py px) (c yy xx)", + py=py, + yy=self.patch_size, + px=px, + xx=self.patch_size, + ) + return patches + + +class ViTPatchLinear(nn.Linear): + def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory): + super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory) + self.patch_size = patch_size + + +class RadioInternVisionModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig = None, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.img_size, self.grid_size, self.num_patches = self._init_img_size( + to_2tuple(config.patch_size), config.image_size + ) + max_img_size = int( + round(config.max_img_size / config.patch_size) * config.patch_size + ) + self.patch_generator = ViTPatchGenerator( + config.patch_size, + config.hidden_size, + input_dims=self.img_size, + max_input_dims=max_img_size, + cls_token=True, + register_multiple=config.reg_tokens, + ) + + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.encoder", + ) + + def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, int]]): + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def get_input_embeddings(self): + return self.embeddings + + def forward(self, x: torch.Tensor) -> torch.FloatTensor: + assert self.patch_generator is not None + hidden_states = self.patch_generator(x) + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + return encoder_outputs + + +class RadioModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.input_conditioner = InputConditioner( + input_scale=1.0, + norm_mean=config.norm_mean, + norm_std=config.norm_std, + ) + self.model = RadioInternVisionModel( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=prefix, + ) + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: 
Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + x = self.input_conditioner(pixel_values) + y = self.model(x) + return self._extract_final(y) + + def load_weights(self, weights) -> set[str]: + loaded_params: set[str] = set() + params_dict = dict(self.named_parameters()) + + if isinstance(weights, dict): + weights_list = list(weights.items()) + else: + weights_list = list(weights) + + for name, weight in weights_list: + if not name.startswith("radio_model."): + # Skip non-radio weights + continue + + sub = name[len("radio_model.") :] # drop "radio_model." prefix + + # Skip buffers not used in vLLM + if sub in {"summary_idxs"}: + continue + + vllm_key = None + if sub.startswith("model.patch_generator."): + vllm_key = f"model.patch_generator.{sub.split('.', 2)[-1]}" + elif sub.startswith("input_conditioner."): + vllm_key = f"input_conditioner.{sub.split('.', 1)[-1]}" + elif sub.startswith("model.blocks."): + # Encoder blocks: HF 'model.blocks.{i}.' -> + # vLLM 'model.encoder.layers.{i}.' + parts = sub.split(".") + if len(parts) >= 4: + layer_idx = parts[2] + suffix = ".".join(parts[3:]) + # Skip layer-scale entries that vLLM doesn't use + if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")): + continue + vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}" + + if vllm_key and vllm_key in params_dict: + param = params_dict[vllm_key] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight) + loaded_params.add(vllm_key) + + return loaded_params + + def _extract_final(self, y: torch.Tensor): + # Remove CLS + REGISTERS tokens + patch_gen = getattr(self.model, "patch_generator", None) + if patch_gen is not None: + all_feat = y[:, patch_gen.num_skip :] + + return all_feat diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 465c25f094806..a52fcb3eeef3c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -4,7 +4,10 @@ Whenever you add an architecture to this page, please also update `tests/models/registry.py` with example HuggingFace models for it. 
""" + +import hashlib import importlib +import json import os import pickle import subprocess @@ -12,30 +15,48 @@ import sys import tempfile from abc import ABC, abstractmethod from collections.abc import Set -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import lru_cache +from pathlib import Path from typing import Callable, Optional, TypeVar, Union import torch.nn as nn import transformers -from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults, - try_match_architecture_defaults) +from vllm import envs +from vllm.config import ( + ModelConfig, + iter_architecture_defaults, + try_match_architecture_defaults, +) from vllm.logger import init_logger -from vllm.transformers_utils.dynamic_module import ( - try_get_class_from_dynamic_module) +from vllm.logging_utils import logtime +from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module -from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, - is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model +from .interfaces import ( + has_inner_state, + has_noops, + is_attention_free, + is_hybrid, + supports_cross_encoding, + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input_only, + supports_pp, + supports_transcription, + supports_v0_only, +) +from .interfaces_base import ( + get_default_pooling_type, + is_pooling_model, + is_text_generation_model, +) logger = init_logger(__name__) -# yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] + "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), @@ -48,17 +69,20 @@ _TEXT_GENERATION_MODELS = { # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), + "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"), + "CwmForCausalLM": ("llama", "LlamaForCausalLM"), "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), @@ -66,10 +90,12 @@ _TEXT_GENERATION_MODELS = { "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), + "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"), "GemmaForCausalLM": ("gemma", 
"GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), @@ -80,8 +106,8 @@ _TEXT_GENERATION_MODELS = { "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), - "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 - "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 + "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 + "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), @@ -94,19 +120,20 @@ _TEXT_GENERATION_MODELS = { "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), + "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 + "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), - "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), + "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), @@ -115,6 +142,7 @@ _TEXT_GENERATION_MODELS = { "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), @@ -122,7 +150,6 @@ _TEXT_GENERATION_MODELS = { "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), @@ -140,10 +167,6 @@ _TEXT_GENERATION_MODELS = { "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"), "XverseForCausalLM": ("llama", "LlamaForCausalLM"), "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), - # [Encoder-decoder] - 
"BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), - "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -151,6 +174,7 @@ _EMBEDDING_MODELS = { "BertModel": ("bert", "BertEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3TextModel": ("gemma3", "Gemma3Model"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), @@ -161,7 +185,8 @@ _EMBEDDING_MODELS = { "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all - k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() + k: (mod, arch) + for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() if arch == "LlamaForCausalLM" }, "MistralModel": ("llama", "LlamaForCausalLM"), @@ -177,85 +202,178 @@ _EMBEDDING_MODELS = { "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), # [Multimodal] - "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 + "CLIPModel": ("clip", "CLIPEmbeddingModel"), + "LlavaNextForConditionalGeneration": ( + "llava_next", + "LlavaNextForConditionalGeneration", + ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - # Technically PrithviGeoSpatialMAE is a model that works on images, both in + # Technically Terratorch models work on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. 
- "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"), + "Terratorch": ("terratorch", "Terratorch"), } _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), - "RobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), - "XLMRobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), - "ModernBertForSequenceClassification": ("modernbert", - "ModernBertForSequenceClassification"), + "BertForTokenClassification": ("bert", "BertForTokenClassification"), + "GteNewForSequenceClassification": ( + "bert_with_rope", + "GteNewForSequenceClassification", + ), + "ModernBertForSequenceClassification": ( + "modernbert", + "ModernBertForSequenceClassification", + ), + "ModernBertForTokenClassification": ( + "modernbert", + "ModernBertForTokenClassification", + ), + "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ( + "roberta", + "RobertaForSequenceClassification", + ), # [Auto-converted (see adapters.py)] - "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, + "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } _MULTIMODAL_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), - "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501 + "AyaVisionForConditionalGeneration": ( + "aya_vision", + "AyaVisionForConditionalGeneration", + ), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), - "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 - "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 + "ChameleonForConditionalGeneration": ( + "chameleon", + "ChameleonForConditionalGeneration", + ), + "Cohere2VisionForConditionalGeneration": ( + "cohere2_vision", + "Cohere2VisionForConditionalGeneration", + ), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ( + "ernie45_vl", + "Ernie4_5_VLMoeForConditionalGeneration", + ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 - "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 + "Gemma3nForConditionalGeneration": ( + "gemma3n_mm", + "Gemma3nForConditionalGeneration", + ), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501 - "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 + "GraniteSpeechForConditionalGeneration": ( + "granite_speech", + "GraniteSpeechForConditionalGeneration", + ), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), - "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 - "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), - "SmolVLMForConditionalGeneration": 
("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 + "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"), + "InternS1ForConditionalGeneration": ( + "interns1", + "InternS1ForConditionalGeneration", + ), + "InternVLForConditionalGeneration": ( + "interns1", + "InternS1ForConditionalGeneration", + ), + "Idefics3ForConditionalGeneration": ( + "idefics3", + "Idefics3ForConditionalGeneration", + ), + "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "KeyeVL1_5ForConditionalGeneration": ( + "keye_vl1_5", + "KeyeVL1_5ForConditionalGeneration", + ), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), + "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), - "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 - "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 - "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "LlavaNextForConditionalGeneration": ( + "llava_next", + "LlavaNextForConditionalGeneration", + ), + "LlavaNextVideoForConditionalGeneration": ( + "llava_next_video", + "LlavaNextVideoForConditionalGeneration", + ), + "LlavaOnevisionForConditionalGeneration": ( + "llava_onevision", + "LlavaOnevisionForConditionalGeneration", + ), "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 - "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501 + "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"), + "MiniMaxVL01ForConditionalGeneration": ( + "minimax_vl_01", + "MiniMaxVL01ForConditionalGeneration", + ), "MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), - "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 + "Mistral3ForConditionalGeneration": ( + "mistral3", + "Mistral3ForConditionalGeneration", + ), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), - "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 + "PaliGemmaForConditionalGeneration": ( + "paligemma", + "PaliGemmaForConditionalGeneration", + ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 - "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 - "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", 
"Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 - "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 - "UltravoxModel": ("ultravox", "UltravoxModel"), + "Qwen2_5_VLForConditionalGeneration": ( + "qwen2_5_vl", + "Qwen2_5_VLForConditionalGeneration", + ), + "Qwen2AudioForConditionalGeneration": ( + "qwen2_audio", + "Qwen2AudioForConditionalGeneration", + ), + "Qwen2_5OmniModel": ( + "qwen2_5_omni_thinker", + "Qwen2_5OmniThinkerForConditionalGeneration", + ), + "Qwen2_5OmniForConditionalGeneration": ( + "qwen2_5_omni_thinker", + "Qwen2_5OmniThinkerForConditionalGeneration", + ), + "Qwen3OmniMoeForConditionalGeneration": ( + "qwen3_omni_moe_thinker", + "Qwen3OmniMoeThinkerForConditionalGeneration", + ), + "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": ( + "qwen3_vl_moe", + "Qwen3VLMoeForConditionalGeneration", + ), + "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 - "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 + "Tarsier2ForConditionalGeneration": ( + "qwen2_vl", + "Tarsier2ForConditionalGeneration", + ), + "UltravoxModel": ("ultravox", "UltravoxModel"), "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] - "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 - "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 - "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 - "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } @@ -265,13 +383,15 @@ _SPECULATIVE_DECODING_MODELS = { "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), + "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), + "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. 
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), @@ -285,11 +405,30 @@ _TRANSFORMERS_SUPPORTED_MODELS = { } _TRANSFORMERS_BACKEND_MODELS = { - "TransformersModel": ("transformers", "TransformersModel"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), - "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 + "TransformersMoEForMultimodalLM": ( + "transformers_moe", + "TransformersMoEForMultimodalLM", + ), + "TransformersEmbeddingModel": ( + "transformers_pooling", + "TransformersEmbeddingModel", + ), + "TransformersForSequenceClassification": ( + "transformers_pooling", + "TransformersForSequenceClassification", + ), + "TransformersMoEForSequenceClassification": ( + "transformers_pooling", + "TransformersMoEForSequenceClassification", + ), + "TransformersMoEEmbeddingModel": ( + "transformers_pooling", + "TransformersMoEEmbeddingModel", + ), } -# yapf: enable _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, @@ -305,11 +444,21 @@ _VLLM_MODELS = { # can modify this variable to alter the args if needed. e.g. # when we use par format to pack things together, sys.executable # might not be the target we want to run. -_SUBPROCESS_COMMAND = [ - sys.executable, "-m", "vllm.model_executor.models.registry" -] +_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"] -_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"} +_PREVIOUSLY_SUPPORTED_MODELS = { + "MotifForCausalLM": "0.10.2", + "Phi3SmallForCausalLM": "0.9.2", + "Phi4FlashForCausalLM": "0.10.2", + # encoder-decoder models except whisper + # have been removed for V0 deprecation. 
+ "BartModel": "0.10.2", + "BartForConditionalGeneration": "0.10.2", + "DonutForConditionalGeneration": "0.10.2", + "Florence2ForConditionalGeneration": "0.10.2", + "MBartForConditionalGeneration": "0.10.2", + "MllamaForConditionalGeneration": "0.10.2", +} @dataclass(frozen=True) @@ -320,7 +469,8 @@ class _ModelInfo: default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool - supports_multimodal_raw_input: bool + supports_multimodal_raw_input_only: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -339,21 +489,26 @@ class _ModelInfo: default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), - supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( + model + ), + supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data( + model + ), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), supports_transcription=supports_transcription(model), - supports_transcription_only=(supports_transcription(model) and - model.supports_transcription_only), + supports_transcription_only=( + supports_transcription(model) and model.supports_transcription_only + ), supports_v0_only=supports_v0_only(model), has_noops=has_noops(model), ) class _BaseRegisteredModel(ABC): - @abstractmethod def inspect_model_cls(self) -> _ModelInfo: raise NotImplementedError @@ -391,13 +546,104 @@ class _LazyRegisteredModel(_BaseRegisteredModel): """ Represents a model that has not been imported in the main process. """ + module_name: str class_name: str - # Performed in another process to avoid initializing CUDA + @staticmethod + def _get_cache_dir() -> Path: + return Path(envs.VLLM_CACHE_ROOT) / "modelinfos" + + def _get_cache_filename(self) -> str: + cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-") + return f"{cls_name}.json" + + def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None: + try: + try: + modelinfo_path = self._get_cache_dir() / self._get_cache_filename() + with open(modelinfo_path, encoding="utf-8") as file: + mi_dict = json.load(file) + except FileNotFoundError: + logger.debug( + ("Cached model info file for class %s.%s not found"), + self.module_name, + self.class_name, + ) + return None + + if mi_dict["hash"] != module_hash: + logger.debug( + ("Cached model info file for class %s.%s is stale"), + self.module_name, + self.class_name, + ) + return None + + # file not changed, use cached _ModelInfo properties + return _ModelInfo(**mi_dict["modelinfo"]) + except Exception: + logger.exception( + ("Cached model info for class %s.%s error. 
"), + self.module_name, + self.class_name, + ) + return None + + def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None: + """save dictionary json file to cache""" + from vllm.model_executor.model_loader.weight_utils import atomic_writer + + try: + modelinfo_dict = { + "hash": module_hash, + "modelinfo": asdict(mi), + } + cache_dir = self._get_cache_dir() + cache_dir.mkdir(parents=True, exist_ok=True) + modelinfo_path = cache_dir / self._get_cache_filename() + with atomic_writer(modelinfo_path, encoding="utf-8") as f: + json.dump(modelinfo_dict, f, indent=2) + except Exception: + logger.exception("Error saving model info cache.") + + @logtime(logger=logger, msg="Registry inspect model class") def inspect_model_cls(self) -> _ModelInfo: - return _run_in_subprocess( - lambda: _ModelInfo.from_model_cls(self.load_model_cls())) + model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py" + module_hash = None + + if model_path.exists(): + with open(model_path, "rb") as f: + module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest() + + mi = self._load_modelinfo_from_cache(module_hash) + if mi is not None: + logger.debug( + ("Loaded model info for class %s.%s from cache"), + self.module_name, + self.class_name, + ) + return mi + else: + logger.debug( + ("Cache model info for class %s.%s miss. Loading model instead."), + self.module_name, + self.class_name, + ) + + # Performed in another process to avoid initializing CUDA + mi = _run_in_subprocess( + lambda: _ModelInfo.from_model_cls(self.load_model_cls()) + ) + logger.debug( + "Loaded model info for class %s.%s", self.module_name, self.class_name + ) + + # save cache file + if module_hash is not None: + self._save_modelinfo_to_cache(mi, module_hash) + + return mi def load_model_cls(self) -> type[nn.Module]: mod = importlib.import_module(self.module_name) @@ -410,12 +656,12 @@ def _try_load_model_cls( model: _BaseRegisteredModel, ) -> Optional[type[nn.Module]]: from vllm.platforms import current_platform + current_platform.verify_model_arch(model_arch) try: return model.load_model_cls() except Exception: - logger.exception("Error in loading model architecture '%s'", - model_arch) + logger.exception("Error in loading model architecture '%s'", model_arch) return None @@ -427,8 +673,7 @@ def _try_inspect_model_cls( try: return model.inspect_model_cls() except Exception: - logger.exception("Error in inspecting model architecture '%s'", - model_arch) + logger.exception("Error in inspecting model architecture '%s'", model_arch) return None @@ -463,8 +708,10 @@ class _ModelRegistry: if model_arch in self.models: logger.warning( "Model architecture %s is already registered, and will be " - "overwritten by the new model class %s.", model_arch, - model_cls) + "overwritten by the new model class %s.", + model_arch, + model_cls, + ) if isinstance(model_cls, str): split_str = model_cls.split(":") @@ -476,8 +723,10 @@ class _ModelRegistry: elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): model = _RegisteredModel.from_model_cls(model_cls) else: - msg = ("`model_cls` should be a string or PyTorch model class, " - f"not a {type(model_arch)}") + msg = ( + "`model_cls` should be a string or PyTorch model class, " + f"not a {type(model_arch)}" + ) raise TypeError(msg) self.models[model_arch] = model @@ -488,7 +737,8 @@ class _ModelRegistry: if any(arch in all_supported_archs for arch in architectures): raise ValueError( f"Model architectures {architectures} failed " - "to be inspected. 
Please check the logs for more details.") + "to be inspected. Please check the logs for more details." + ) for arch in architectures: if arch in _PREVIOUSLY_SUPPORTED_MODELS: @@ -498,14 +748,15 @@ class _ModelRegistry: f"Model architecture {arch} was supported in vLLM until " f"v{previous_version}, and is not supported anymore. " "Please use an older version of vLLM if you want to " - "use this model architecture.") + "use this model architecture." + ) raise ValueError( f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {all_supported_archs}") + f"Supported architectures: {all_supported_archs}" + ) - def _try_load_model_cls(self, - model_arch: str) -> Optional[type[nn.Module]]: + def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]: if model_arch not in self.models: return None @@ -525,8 +776,9 @@ class _ModelRegistry: if architecture in _TRANSFORMERS_BACKEND_MODELS: return architecture - auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", - None) or dict() + auto_map: dict[str, str] = ( + getattr(model_config.hf_config, "auto_map", None) or dict() + ) # Make sure that config class is always initialized before model class, # otherwise the model class won't be able to access the config class, @@ -560,7 +812,7 @@ class _ModelRegistry: if model_module is not None: break else: - if model_config.model_impl != ModelImpl.TRANSFORMERS: + if model_config.model_impl != "transformers": return None raise ValueError( @@ -568,15 +820,17 @@ class _ModelRegistry: "registered model in the Transformers library (only " "relevant if the model is meant to be in Transformers) " "and 'AutoModel' is not present in the model config's " - "'auto_map' (relevant if the model is custom).") + "'auto_map' (relevant if the model is custom)." + ) if not model_module.is_backend_compatible(): - if model_config.model_impl != ModelImpl.TRANSFORMERS: + if model_config.model_impl != "transformers": return None raise ValueError( f"The Transformers implementation of {architecture!r} " - "is not compatible with vLLM.") + "is not compatible with vLLM." 
+ ) return model_config._get_transformers_backend_cls() @@ -617,20 +871,23 @@ class _ModelRegistry: raise ValueError("No model architectures are specified") # Require transformers impl - if model_config.model_impl == ModelImpl.TRANSFORMERS: - arch = self._try_resolve_transformers(architectures[0], - model_config) + if model_config.model_impl == "transformers": + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: return (model_info, arch) + elif model_config.model_impl == "terratorch": + model_info = self._try_inspect_model_cls("Terratorch") + return (model_info, "Terratorch") # Fallback to transformers impl (after resolving convert_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == ModelImpl.AUTO - and getattr(model_config, "convert_type", "none") == "none"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + and getattr(model_config, "convert_type", "none") == "none" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: @@ -643,10 +900,11 @@ class _ModelRegistry: return (model_info, arch) # Fallback to transformers impl (before resolving runner_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == ModelImpl.AUTO): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: @@ -665,20 +923,25 @@ class _ModelRegistry: raise ValueError("No model architectures are specified") # Require transformers impl - if model_config.model_impl == ModelImpl.TRANSFORMERS: - arch = self._try_resolve_transformers(architectures[0], - model_config) + if model_config.model_impl == "transformers": + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: return (model_cls, arch) + elif model_config.model_impl == "terratorch": + arch = "Terratorch" + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) # Fallback to transformers impl (after resolving convert_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == ModelImpl.AUTO - and getattr(model_config, "convert_type", "none") == "none"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + and getattr(model_config, "convert_type", "none") == "none" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: @@ -691,10 +954,11 @@ class _ModelRegistry: return (model_cls, arch) # Fallback to transformers impl (before resolving runner_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == ModelImpl.AUTO): - arch = self._try_resolve_transformers(architectures[0], - 
model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: @@ -734,13 +998,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.supports_multimodal - def supports_multimodal_raw_input( + def is_multimodal_raw_input_only_model( self, architectures: Union[str, list[str]], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) - return model_cls.supports_multimodal_raw_input + return model_cls.supports_multimodal_raw_input_only def is_pp_supported_model( self, @@ -807,14 +1071,15 @@ class _ModelRegistry: return not model_cls.supports_v0_only -ModelRegistry = _ModelRegistry({ - model_arch: - _LazyRegisteredModel( - module_name=f"vllm.model_executor.models.{mod_relname}", - class_name=cls_name, - ) - for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() -}) +ModelRegistry = _ModelRegistry( + { + model_arch: _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() + } +) _T = TypeVar("_T") @@ -827,21 +1092,23 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: # `cloudpickle` allows pickling lambda functions directly import cloudpickle + input_bytes = cloudpickle.dumps((fn, output_filepath)) # cannot use `sys.executable __file__` here because the script # contains relative imports - returned = subprocess.run(_SUBPROCESS_COMMAND, - input=input_bytes, - capture_output=True) + returned = subprocess.run( + _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True + ) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in subprocess:\n" - f"{returned.stderr.decode()}") from e + raise RuntimeError( + f"Error raised in subprocess:\n{returned.stderr.decode()}" + ) from e with open(output_filepath, "rb") as f: return pickle.load(f) @@ -850,6 +1117,7 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: def _run() -> None: # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() fn, output_file = pickle.loads(sys.stdin.buffer.read()) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 49a37342c67fa..6408cf7937b2f 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -8,39 +8,51 @@ import torch from torch import nn from transformers import RobertaConfig -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, - DispatchPooler, Pooler) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT, - BertEmbeddingModel, BertModel, - _decode_token_type_ids, - _encode_token_type_ids) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - maybe_prefix) +from vllm.config import ModelConfig, VllmConfig +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.bert 
import ( + TOKEN_TYPE_SHIFT, + BertEmbeddingModel, + BertModel, + _decode_token_type_ids, + _encode_token_type_ids, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module): - def __init__(self, config: RobertaConfig): super().__init__() self.size = config.hidden_size - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.padding_idx = config.pad_token_id - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size, - padding_idx=self.padding_idx) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx, + ) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, config.hidden_size + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).unsqueeze(0), @@ -48,18 +60,21 @@ class RobertaEmbedding(nn.Module): self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": - raise ValueError("Only 'absolute' position_embedding_type" + - " is supported") + raise ValueError( + "Only 'absolute' position_embedding_type" + " is supported" + ) def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - token_type_ids = _decode_token_type_ids(input_ids) - inputs_embeds = self.word_embeddings(input_ids) + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) @@ -72,10 +87,14 @@ class RobertaEmbedding(nn.Module): class RobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" - def __init__(self, config: RobertaConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + config = model_config.hf_config + head_dtype = model_config.head_dtype + self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype) + self.out_proj = nn.Linear( + config.hidden_size, config.num_labels, dtype=head_dtype + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # CLSPool has already been applied in `pooling` @@ -89,13 +108,13 @@ class RobertaClassificationHead(nn.Module): class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. 
- Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -108,34 +127,35 @@ class RobertaEmbeddingModel(BertEmbeddingModel): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # Fix Roberta positions here outside of the CUDA graph. # Because we need the to extract the sequences from # input_ids the control flow is data dependent. - replace_roberta_positions(input_ids=input_ids, - position_ids=positions, - padding_idx=self.padding_idx) + replace_roberta_positions( + input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx + ) - return self.model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> Union[BertModel, BertWithRope]: - if (vllm_config.model_config.hf_config.position_embedding_type == - "rotary"): + def _build_model( + self, vllm_config: VllmConfig, prefix: str = "" + ) -> Union[BertModel, BertWithRope]: + if vllm_config.model_config.hf_config.position_embedding_type == "rotary": return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) else: - return BertModel(vllm_config=vllm_config, - prefix=prefix, - embedding_class=RobertaEmbedding) + return BertModel( + vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) has_roberta_prefix = any( - name.startswith("roberta.") for name, _ in weights_list) + name.startswith("roberta.") for name, _ in weights_list + ) if has_roberta_prefix: # For models with the `roberta.` prefix e.g. # `FacebookAI/roberta-base` @@ -153,26 +173,27 @@ class RobertaEmbeddingModel(BertEmbeddingModel): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - roberta: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ is_pooling_model = True jina_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ - 'emb_ln': "embeddings.LayerNorm", - 'layers': "layer", - 'mixer.Wqkv': "attention.self.qkv_proj", - 'mixer.out_proj': "attention.output.dense", - 'norm1': "attention.output.LayerNorm", - 'mlp.fc1': "intermediate.dense", - 'mlp.fc2': "output.dense", - 'norm2': "output.LayerNorm", - }) + "emb_ln": "embeddings.LayerNorm", + "layers": "layer", + "mixer.Wqkv": "attention.self.qkv_proj", + "mixer.out_proj": "attention.output.dense", + "norm1": "attention.output.LayerNorm", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm2": "output.LayerNorm", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -180,37 +201,43 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels - self.roberta = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=RobertaEmbedding) - self.classifier = RobertaClassificationHead(config) + self.roberta = BertModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + ) + self.classifier = RobertaClassificationHead(vllm_config.model_config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.roberta.get_input_embeddings(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], @@ -219,22 +246,24 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - replace_roberta_positions(input_ids=input_ids, - position_ids=positions, - padding_idx=self.padding_idx) + replace_roberta_positions( + input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx + ) if token_type_ids is not None: assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) assert input_ids is not None _encode_token_type_ids(input_ids, token_type_ids) - return self.roberta(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.roberta( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) -def 
replace_roberta_positions(input_ids: torch.Tensor, - position_ids: torch.Tensor, - padding_idx: int) -> None: +def replace_roberta_positions( + input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int +) -> None: # Replace position ids because in RoBERTa models # they have to start at padding_idx + 1 and ignore # existing padding tokens diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py index efdb010046634..89150677f3ce8 100644 --- a/vllm/model_executor/models/rvl.py +++ b/vllm/model_executor/models/rvl.py @@ -2,23 +2,27 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping +from typing import Optional import torch import torch.nn as nn from transformers.activations import GELUActivation from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict -from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, - LlavaNextProcessingInfo) +from .llava_next import ( + LlavaDummyInputsBuilder, + LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo, +) from .llava_onevision import LlavaOnevisionForConditionalGeneration from .utils import WeightsMapper class RVLProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -27,7 +31,6 @@ class RVLProcessingInfo(LlavaNextProcessingInfo): class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) image_token = "<image>" @@ -38,26 +41,28 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } class RVLMultiModalProjector(nn.Module): - def __init__(self, config): super().__init__() - self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, - eps=1e-06) + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06) self.linear_1 = nn.Linear( config.vision_config.hidden_size, config.text_config.hidden_size, @@ -85,7 +90,6 @@ class RVLMultiModalProjector(nn.Module): dummy_inputs=RVLDummyInputsBuilder, ) class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers @@ -95,7 +99,8 @@ class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 
34a87a6a69a39..ca33a694a3b61 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -22,7 +22,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only SeedOss model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -36,29 +38,38 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class SeedOssMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -83,8 +94,9 @@ class SeedOssMLP(nn.Module): prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -95,7 +107,6 @@ class SeedOssMLP(nn.Module): class SeedOssAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -181,7 +192,6 @@ class SeedOssAttention(nn.Module): class SeedOssDecoderLayer(nn.Module): - def __init__( self, config: SeedOssConfig, @@ -224,10 +234,10 @@ class SeedOssDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -240,16 +250,14 @@ class SeedOssDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -260,14 +268,16 @@ class SeedOssDecoderLayer(nn.Module): "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class SeedOssModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -275,8 +285,9 @@ class SeedOssModel(nn.Module): quant_config = vllm_config.quant_config # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): + if cache_config.sliding_window is not None and hasattr( + config, "max_window_layers" + ): assert config.max_window_layers == config.num_hidden_layers, ( "Sliding window for some but all layers is not supported. 
" "This model uses sliding window but `max_window_layers` = {} " @@ -284,14 +295,16 @@ class SeedOssModel(nn.Module): "to discuss this feature.".format( config.max_window_layers, config.num_hidden_layers, - )) + ) + ) self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -305,16 +318,18 @@ class SeedOssModel(nn.Module): decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -340,22 +355,20 @@ class SeedOssModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -369,18 +382,19 @@ class SeedOssModel(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -404,8 +418,7 @@ class SeedOssModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = 
params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -434,25 +447,28 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lora_config = lora_config self.quant_config = quant_config - self.model = SeedOssModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = SeedOssModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -464,24 +480,21 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 3630f59f53e0a..ee21a03c8525d 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -14,26 +14,33 @@ from transformers import SiglipVisionConfig from vllm.attention.layer import MultiHeadAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils 
import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) -from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + resolve_visual_encoder_outputs, +) class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): - def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + return self.get_patch_grid_length() ** 2 def get_image_size(self) -> int: return self.vision_config.image_size @@ -48,7 +55,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config @@ -64,19 +70,20 @@ class SiglipVisionEmbeddings(nn.Module): padding="valid", ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = VocabParallelEmbedding( - self.num_positions, self.embed_dim) + self.num_positions, self.embed_dim + ) self.register_buffer( "position_ids", - torch.arange(self.num_positions, dtype=torch.int64).expand( - (1, -1)), + torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)), persistent=False, ) - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, - width: int) -> torch.Tensor: + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: """ This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs) that allows the model to interpolate @@ -101,8 +108,8 @@ class SiglipVisionEmbeddings(nn.Module): height, width = height + 0.1, width + 0.1 patch_pos_embed = position_embeddings.reshape( - 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), - dim) + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, @@ -113,33 +120,36 @@ class SiglipVisionEmbeddings(nn.Module): mode="bicubic", align_corners=False, ) - if (int(height) != patch_pos_embed.shape[-2] - or int(width) != patch_pos_embed.shape[-1]): - raise ValueError("Width or height does not match with " - "the interpolated position embeddings") + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with " + "the interpolated position embeddings" + ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed - def forward(self, - pixel_values: torch.Tensor, - interpolate_pos_encoding: bool = False) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False + ) -> torch.Tensor: _, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: - embeddings += self.interpolate_pos_encoding( - embeddings, height, width) + embeddings 
+= self.interpolate_pos_encoding(embeddings, height, width) else: embeddings += self.position_embedding(self.position_ids) return embeddings class SiglipAttention(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -153,9 +163,11 @@ class SiglipAttention(nn.Module): self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads (got " - "`embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + raise ValueError( + f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -177,8 +189,9 @@ class SiglipAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward( self, @@ -195,7 +208,6 @@ class SiglipAttention(nn.Module): class SiglipMLP(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -207,15 +219,14 @@ class SiglipMLP(nn.Module): self.config = config self.activation_fn = get_act_fn(config.hidden_act) # Special handling for BNB and torchao quantization - if quant_config and quant_config.get_name() in [ - "bitsandbytes", "torchao" - ]: + if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: quantizable = True else: # For other quantization, we require the hidden size to be a # multiple of 64 - quantizable = (config.hidden_size % 64 == 0 - and config.intermediate_size % 64 == 0) + quantizable = ( + config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0 + ) self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, @@ -237,7 +248,6 @@ class SiglipMLP(nn.Module): class SiglipEncoderLayer(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -253,15 +263,13 @@ class SiglipEncoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -282,7 +290,6 @@ class SiglipEncoderLayer(nn.Module): class SiglipEncoder(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -299,12 +306,16 @@ class SiglipEncoder(nn.Module): else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - SiglipEncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -339,12 +350,12 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention self.attention = 
torch.nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: batch_size = hidden_state.shape[0] @@ -361,7 +372,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): class SiglipVisionTransformer(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -397,13 +407,13 @@ class SiglipVisionTransformer(nn.Module): require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) else: self.post_layernorm = None - self.use_head = (True if not hasattr(config, "vision_use_head") else - config.vision_use_head) + self.use_head = ( + True if not hasattr(config, "vision_use_head") else config.vision_use_head + ) if self.use_head: self.head = SiglipMultiheadAttentionPoolingHead( config=config, @@ -414,28 +424,31 @@ class SiglipVisionTransformer(nn.Module): def forward( self, pixel_values: torch.Tensor, - interpolate_pos_encoding: bool = True, - feature_sample_layers: Optional[list[int]] = None, + *, + interpolate_pos_encoding: bool = False, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, ) -> torch.Tensor: - hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) - return_all_hidden_states = feature_sample_layers is not None - # Produces either the last layer output or all of the hidden states, - # depending on if we have feature_sample_layers or not + # depending on if we have select_layers or not encoder_outputs = self.encoder( inputs_embeds=hidden_states, - return_all_hidden_states=return_all_hidden_states, + return_all_hidden_states=select_layers is not None, ) # Handle post-norm (if applicable) and stacks feature layers if needed encoder_outputs = resolve_visual_encoder_outputs( - encoder_outputs, feature_sample_layers, self.post_layernorm, - self.config.num_hidden_layers) + encoder_outputs, + self.post_layernorm, + select_layers=select_layers, + max_possible_layers=self.config.num_hidden_layers, + feature_select_strategy=feature_select_strategy, + ) # TODO: add this back when pooled_output is used in inference. 
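+        # (the attention-pooling head is still built in __init__ when
+        # `vision_use_head` is set, but its pooled output is not yet
+        # returned from this forward pass)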
# if self.use_head: @@ -470,20 +483,25 @@ class SiglipVisionModel(nn.Module): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + @property + def dtype(self): + return self.get_input_embeddings().weight.dtype + def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, - feature_sample_layers: Optional[list[int]] = None, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, ) -> torch.Tensor: return self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, - feature_sample_layers=feature_sample_layers, + select_layers=select_layers, + feature_select_strategy=feature_select_strategy, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -496,8 +514,10 @@ class SiglipVisionModel(nn.Module): for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if (name.startswith("vision_model.post_layernorm") - and self.vision_model.post_layernorm is None): + if ( + name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None + ): continue # omit layers when num_hidden_layers_override is set @@ -506,7 +526,22 @@ class SiglipVisionModel(nn.Module): if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Check if this is a scale parameter that needs remapping first + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + # Try to remap the scale name first + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is not None and remapped_name in params_dict: + # Successfully remapped, use the remapped name + param = params_dict[remapped_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(remapped_name) + continue + # If remapping failed, continue with normal processing + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -517,8 +552,7 @@ class SiglipVisionModel(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c6244fb3b3e6a..81f7e9887acee 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -13,37 +13,38 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.config import QuantizationConfig +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from 
vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.platforms import _Backend from .vision import get_vit_attn_backend class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) return freqs class Siglip2VisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -57,15 +58,13 @@ class Siglip2VisionEmbeddings(nn.Module): # siglip2 naflex if self.num_patches > 0: self.patch_embedding = ReplicatedLinear( - input_size=config.num_channels * self.patch_size * - self.patch_size, + input_size=config.num_channels * self.patch_size * self.patch_size, output_size=self.embed_dim, return_bias=False, ) if self.preserve_original_pe: self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) else: self.patch_embedding = nn.Conv2d( @@ -76,15 +75,15 @@ class Siglip2VisionEmbeddings(nn.Module): padding="valid", ) if self.preserve_original_pe: - self.num_patches = (self.image_size // self.patch_size)**2 - self.position_embedding_size = (self.image_size // - self.patch_size) - self.position_embedding = nn.Embedding(self.num_patches, - self.embed_dim) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.position_embedding_size = self.image_size // self.patch_size + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) - def forward(self, - pixel_values: torch.FloatTensor, - grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor`): @@ -99,36 +98,48 @@ class Siglip2VisionEmbeddings(nn.Module): # Apply patch embeddings to already patchified pixel values target_dtype = self.patch_embedding.weight.dtype if isinstance(self.patch_embedding, LinearBase): - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) elif isinstance(self.patch_embedding, nn.Conv2d): pixel_values = pixel_values.view( - -1, self.config.num_channels * self.config.temporal_patch_size, - self.patch_size, self.patch_size) - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + -1, + self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) patch_embeds = patch_embeds.reshape(-1, self.embed_dim) if self.preserve_original_pe: assert grid_thws is not None pos_embed_new = torch.zeros_like(patch_embeds) 
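+            # Resize the learned square position-embedding grid to each
+            # image's (h, w) patch grid with bicubic interpolation, repeat it
+            # across the t temporal patches, and reorder it to the
+            # hidden_stride-blocked patch layout before adding it below.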
- positional_embeddings = self.position_embedding.weight.reshape( - self.position_embedding_size, self.position_embedding_size, - -1).unsqueeze(0).permute(0, 3, 1, 2) + positional_embeddings = ( + self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + .unsqueeze(0) + .permute(0, 3, 1, 2) + ) cnt = 0 for t, h, w in grid_thws: volume = t * h * w - pe = F.interpolate(positional_embeddings, - size=(h, w), - mode='bicubic', - align_corners=False) + pe = F.interpolate( + positional_embeddings, + size=(h, w), + mode="bicubic", + align_corners=False, + ) pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) pe = pe[0].repeat(t, 1) - pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, - w // self.hidden_stride, self.hidden_stride, - -1) + pe = pe.reshape( + t, + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + -1, + ) pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) - pos_embed_new[cnt:cnt + volume] = pe + pos_embed_new[cnt : cnt + volume] = pe cnt += volume patch_embeds = patch_embeds + pos_embed_new @@ -142,9 +153,9 @@ def rotate_half(x, interleaved=False): return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) def apply_rotary_emb_torch(x, cos, sin, interleaved=False): @@ -155,15 +166,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False): ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) @@ -180,13 +191,12 @@ def apply_rotary_pos_emb( sin = sin.chunk(2, dim=-1)[0].contiguous() if is_flash_attn_backend: from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb else: apply_rotary_emb_func = apply_rotary_emb_torch - q_embed = apply_rotary_emb_func(q.float(), cos.float(), - sin.float()).type_as(q) - k_embed = apply_rotary_emb_func(k.float(), cos.float(), - sin.float()).type_as(k) + q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed @@ -209,7 +219,8 @@ class Siglip2Attention(nn.Module): raise ValueError( f"embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." 
+ ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False @@ -230,28 +241,41 @@ class Siglip2Attention(nn.Module): prefix=f"{prefix}.out_proj", ) - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.use_rope = config.use_rope # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.head_dim, dtype=torch.get_default_dtype() + ) + self.use_upstream_fa = False + + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) + ) + if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.ROCM_AITER_FA, }: self.attn_backend = _Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - position_embeddings: Optional[tuple[torch.Tensor, - torch.Tensor]] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -260,30 +284,27 @@ class Siglip2Attention(nn.Module): qkv_states, _ = self.qkv_proj(hidden_states) queries, keys, values = qkv_states.chunk(3, dim=-1) - queries = queries.view(seq_length, self.num_heads_per_partition, - self.head_dim) - keys = keys.view(seq_length, self.num_heads_per_partition, - self.head_dim) - values = values.view(seq_length, self.num_heads_per_partition, - self.head_dim) + queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim) + keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim) + values = values.view(seq_length, self.num_heads_per_partition, self.head_dim) if self.use_rope: cos, sin = position_embeddings - queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0), - keys.unsqueeze(0), cos, sin, - self.is_flash_attn_backend) + queries, keys = apply_rotary_pos_emb( + queries.unsqueeze(0), + keys.unsqueeze(0), + cos, + sin, + self.is_flash_attn_backend, + ) queries = queries.squeeze(0) keys = keys.squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - from flash_attn import flash_attn_varlen_func - attn_output = flash_attn_varlen_func( - queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, - max_seqlen).reshape(seq_length, -1) + attn_output = self.flash_attn_varlen_func( + queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + ).reshape(seq_length, -1) elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
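# Illustrative sketch, not part of the patch: the SDPA fallback below walks the
# packed batch one entry at a time. `cu_seqlens` stores cumulative boundaries, so
# entry i occupies rows cu_seqlens[i]:cu_seqlens[i + 1]. Values are made up for
# illustration only.
import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])   # three entries packed into one sequence
packed = torch.randn(12, 4)                # (total_tokens, embed_dim)
batch_size = cu_seqlens.shape[0] - 1       # as in the loop below
pieces = [
    packed[cu_seqlens[i]:cu_seqlens[i + 1]]
    for i in range(batch_size)
]
assert [p.shape[0] for p in pieces] == [3, 4, 5]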
batch_size = cu_seqlens.shape[0] - 1 @@ -302,13 +323,9 @@ class Siglip2Attention(nn.Module): # (1, num_heads, seq_len, head_dim) q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) - output_i = output_i.transpose(1, 2).reshape( - end_idx - start_idx, -1) + output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1) outputs.append(output_i) attn_output = torch.cat(outputs, dim=0) @@ -317,7 +334,6 @@ class Siglip2Attention(nn.Module): class Siglip2MLP(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -351,7 +367,6 @@ class Siglip2MLP(nn.Module): class Siglip2EncoderLayer(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -361,36 +376,41 @@ class Siglip2EncoderLayer(nn.Module): ): super().__init__() self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.self_attn = Siglip2Attention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) - def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> tuple[torch.FloatTensor]: """ Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all - attention layers. See `attentions` under - returned tensors for more detail. + hidden_states: Input tensor of shape (batch, seq_len, embed_dim). + cu_seqlens: Cumulative sequence lengths tensor. + position_embeddings: Position embeddings tensor. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - position_embeddings=position_embeddings) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) hidden_states = residual + hidden_states residual = hidden_states @@ -402,7 +422,7 @@ class Siglip2EncoderLayer(nn.Module): class Siglip2Encoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
Args: @@ -418,16 +438,21 @@ class Siglip2Encoder(nn.Module): ): super().__init__() self.config = config - self.layers = nn.ModuleList([ - Siglip2EncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{idx}", - use_data_parallel=use_data_parallel) - for idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Siglip2EncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel, + ) + for idx in range(config.num_hidden_layers) + ] + ) self.rotary_pos_emb = VisionRotaryEmbedding( - config.hidden_size // config.num_attention_heads // 2) + config.hidden_size // config.num_attention_heads // 2 + ) self.patch_size = config.patch_size self.hidden_stride = config.hidden_stride self.window_size = config.window_size @@ -436,7 +461,7 @@ class Siglip2Encoder(nn.Module): self.fullatt_block_indexes = None else: self.fullatt_block_indexes = [ - int(i) for i in config.fullatt_block_indexes.split('|') + int(i) for i in config.fullatt_block_indexes.split("|") ] # copied from qwen2.5_vl @@ -462,8 +487,7 @@ class Siglip2Encoder(nn.Module): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -475,8 +499,9 @@ class Siglip2Encoder(nn.Module): cu_window_seqlens: list = [0] window_index_id = 0 # patch (after merge) number in each window - vit_merger_window_size = (self.window_size // self.hidden_stride // - self.patch_size) + vit_merger_window_size = ( + self.window_size // self.hidden_stride // self.patch_size + ) for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( @@ -484,7 +509,8 @@ class Siglip2Encoder(nn.Module): grid_w // self.hidden_stride, ) index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size @@ -507,8 +533,9 @@ class Siglip2Encoder(nn.Module): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) @@ -522,19 +549,11 @@ class Siglip2Encoder(nn.Module): ) -> torch.Tensor: r""" Args: - inputs_embeds (`torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded representation. This is useful if - you want more control over how to convert `input_ids` indices - into associated vectors than the model's internal embedding - lookup matrix. - grid_thws (`torch.LongTensor`): - grid shape (num_patches, 3) - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See - `hidden_states` under returned tensors for more detail. 
- return_dict (`bool`, *optional*): + inputs_embeds: Input tensor of shape + (batch_size, sequence_length, hidden_size). + Embedded representation of the input tokens. + grid_thws: Grid tensor of shape (num_patches, 3) + containing grid dimensions. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -549,11 +568,13 @@ class Siglip2Encoder(nn.Module): seq_len, _ = inputs_embeds.size() inputs_embeds = inputs_embeds.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) inputs_embeds = inputs_embeds[window_index, :, :] inputs_embeds = inputs_embeds.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) @@ -571,29 +592,27 @@ class Siglip2Encoder(nn.Module): # for more information dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) reverse_indices = torch.argsort(window_index) hidden_states = inputs_embeds for index, block in enumerate(self.layers): - if (not self.fullatt_block_indexes - or index in self.fullatt_block_indexes): + if not self.fullatt_block_indexes or index in self.fullatt_block_indexes: cu_seqlens_tmp = cu_seqlens else: cu_seqlens_tmp = cu_window_seqlens - hidden_states = block(hidden_states, cu_seqlens_tmp, - position_embeddings) + hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings) hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) return hidden_states class Siglip2VisionTransformer(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -606,12 +625,13 @@ class Siglip2VisionTransformer(nn.Module): embed_dim = config.hidden_size self.embeddings = Siglip2VisionEmbeddings(config) - self.encoder = Siglip2Encoder(config, - quant_config=quant_config, - prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.encoder = Siglip2Encoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -632,7 +652,6 @@ class Siglip2VisionTransformer(nn.Module): class Siglip2NavitModel(torch.nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -646,7 +665,8 @@ class Siglip2NavitModel(torch.nn.Module): config, quant_config=quant_config, prefix=f"{prefix}.vision_model", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + ) def forward( self, @@ -658,8 +678,7 @@ class Siglip2NavitModel(torch.nn.Module): grid_thws=grid_thws, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -670,7 +689,7 @@ class Siglip2NavitModel(torch.nn.Module): 
loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -681,8 +700,7 @@ class Siglip2NavitModel(torch.nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 920f4def69173..f0f6917ddf913 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,83 +8,117 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn import torchvision.transforms as T from PIL import Image -from transformers import BatchEncoding, PretrainedConfig, TensorType +from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.models.intern_vit import (InternVisionModel, - InternVisionPatchModel) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.intern_vit import ( + InternVisionModel, + InternVisionPatchModel, +) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<IMG_CONTEXT>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<IMG_CONTEXT>" IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) -class 
SkyworkR1VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor +class SkyworkR1VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + type: Literal["pixel_values"] = "pixel_values" + + pixel_values_flat: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "h", "w"), + ] + + num_patches: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] -class SkyworkR1VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. +class SkyworkR1VImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - ni: Number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) """ + type: Literal["image_embeds"] = "image_embeds" -SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, - SkyworkR1VImageEmbeddingInputs] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "ifs", "hs"), + ] + + +SkyworkR1VImageInputs = Union[ + SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs +] # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - return T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=MEAN, std=STD) - ]) + return T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ @@ -96,7 +130,7 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_ratio_diff = float('inf') + best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: @@ -131,10 +165,13 @@ def get_skyworkr1v_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -191,10 +228,12 @@ def dynamic_preprocess_skyworkr1v( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // 
image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -270,7 +309,8 @@ class SkyworkR1VProcessor: assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -299,14 +339,18 @@ class SkyworkR1VProcessor: dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_skyworkr1v_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -373,7 +417,8 @@ class SkyworkR1VProcessor: min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def __call__( @@ -384,7 +429,7 @@ class SkyworkR1VProcessor: max_dynamic_patch: Optional[int] = None, dynamic_image_size: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: if text is None: text = [] if not isinstance(text, list): @@ -403,11 +448,11 @@ class SkyworkR1VProcessor: max_dynamic_patch=max_dynamic_patch, dynamic_image_size=dynamic_image_size, ) - image_inputs: dict[str, NestedTensors] = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + image_inputs = { + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -416,18 +461,16 @@ class SkyworkR1VProcessor: image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace('<image>', image_repl.full, 1) for t in text] + text = [t.replace("<image>", image_repl.full, 1) for t in text] text_inputs = self.tokenizer(text) - return { - **BatchEncoding(text_inputs, tensor_type=return_tensors), - **image_inputs, - } + combined_outputs = {**text_inputs, **image_inputs} + + return BatchFeature(combined_outputs, tensor_type=return_tensors) class SkyworkR1VProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor: return self.ctx.init_processor( SkyworkR1VProcessor, @@ -471,8 +514,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo): ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest 
feature size of 0!") @@ -480,9 +522,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo): return largest_feature_pinpoint -class SkyworkR1VDummyInputsBuilder( - BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]): - +class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -492,29 +532,31 @@ class SkyworkR1VDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class SkyworkR1VMultiModalProcessor( - BaseMultiModalProcessor[SkyworkR1VProcessingInfo]): - +class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]): def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: + ) -> BatchFeature: processed_outputs = super()._call_hf_processor( prompt=prompt, mm_data=mm_data, @@ -534,7 +576,7 @@ class SkyworkR1VMultiModalProcessor( def _get_mm_fields_config( self, - hf_inputs: Mapping[str, NestedTensors], + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) @@ -542,7 +584,8 @@ class SkyworkR1VMultiModalProcessor( return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), @@ -570,7 +613,8 @@ class SkyworkR1VMultiModalProcessor( def get_replacement_skyworkr1v(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -600,8 +644,10 @@ class SkyworkR1VMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor( SkyworkR1VMultiModalProcessor, info=SkyworkR1VProcessingInfo, - dummy_inputs=SkyworkR1VDummyInputsBuilder) + dummy_inputs=SkyworkR1VDummyInputsBuilder, +) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -625,12 +671,13 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.llm_arch_name = config.text_config.architectures[0] - self.is_mono = self.llm_arch_name 
== 'SkyworkLM2VEForCausalLM' + self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM" self.vision_model = self._init_vision_model( config, quant_config=quant_config, @@ -649,18 +696,20 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self.img_context_token_id = None self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -674,8 +723,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 @@ -688,20 +738,19 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - ReplicatedLinear(vit_hidden_size * - int(1 / self.downsample_ratio)**2, - llm_hidden_size, - return_bias=False), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + ReplicatedLinear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + llm_hidden_size, + return_bias=False, + ), nn.GELU(), - ReplicatedLinear(llm_hidden_size, - llm_hidden_size, - return_bias=False), + ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False), ) def pixel_shuffle(self, x, scale_factor=0.5): @@ -710,9 +759,13 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -722,37 +775,16 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = 
vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. " - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: + self, **kwargs: object + ) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -761,13 +793,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return SkyworkR1VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) image_token_id = kwargs["image_token_id"] @@ -775,22 +803,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self.img_context_token_id = image_token_id.flatten().unique().item() if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return SkyworkR1VImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }, ) raise AssertionError("This line should be unreachable.") @@ -810,14 +830,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): # Only one image in the current batch if len(num_patches) == 1: - return image_embeds.view( - -1, self.config.text_config.hidden_size).unsqueeze(0) + return image_embeds.view(-1, self.config.text_config.hidden_size).unsqueeze( + 0 + ) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. 
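# Illustrative sketch, not part of the patch: the NOTE above describes cutting the
# flat (total_tokens, hidden_size) embedding tensor into one chunk per image, each
# chunk holding num_patches * feature_size rows. The actual split call sits outside
# this hunk; torch.split with those sizes behaves as sketched. Numbers are made up.
import torch

hidden_size, feature_size = 8, 4
num_patches = [1, 3]                                  # patches per image
image_embeds = torch.randn(sum(num_patches) * feature_size, hidden_size)
image_feature_sizes = [n * feature_size for n in num_patches]
chunks = image_embeds.split(image_feature_sizes)      # one tensor per image
assert [c.shape[0] for c in chunks] == [4, 12]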
feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -825,16 +845,16 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: - self.visual_token_mask = ( - input_ids == self.img_context_token_id).reshape(-1, 1) + self.visual_token_mask = (input_ids == self.img_context_token_id).reshape( + -1, 1 + ) else: self.visual_token_mask = None def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -845,19 +865,23 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - assert self.img_context_token_id is not None + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.img_context_token_id, - ) - return inputs_embeds + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -867,19 +891,10 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - forward_kwargs = { "input_ids": input_ids, "positions": positions, @@ -889,8 +904,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -899,18 +913,23 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [ - "action_embed", "temporal_embed", "track_embed", - "track_embed_decoder", "box_token", "cg_criterion", "cg_model", - "loc_encoder", "loc_decoder", "sam", "temporal_token", - "track_token" + "action_embed", + "temporal_embed", + "track_embed", + "track_embed_decoder", + "box_token", + "cg_criterion", + "cg_model", + "loc_encoder", + "loc_decoder", + "sam", + "temporal_token", + "track_token", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/smolvlm.py b/vllm/model_executor/models/smolvlm.py index 2adfad67152b3..1800330c8235f 100644 --- a/vllm/model_executor/models/smolvlm.py +++ b/vllm/model_executor/models/smolvlm.py @@ -8,22 +8,18 @@ from transformers import SmolVLMProcessor from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf: disable from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder -from .idefics3 import Idefics3ForConditionalGeneration +from .idefics3 import Idefics3ForConditionalGeneration, Idefics3ProcessingInfo from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor -from .idefics3 import Idefics3ProcessingInfo - -# yapf: enable class SmolVLMProcessingInfo(Idefics3ProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor: return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs) def _get_image_token( - self, processor: Optional[SmolVLMProcessor]) -> tuple[str, str]: + self, processor: Optional[SmolVLMProcessor] + ) -> tuple[str, str]: if processor is None: processor = self.get_hf_processor() image_token = processor.image_token @@ -32,11 +28,12 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo): return image_token, fake_image_token, global_image_token -@MULTIMODAL_REGISTRY.register_processor(SmolVLMMultiModalProcessor, - info=SmolVLMProcessingInfo, - dummy_inputs=SmolVLMDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + SmolVLMMultiModalProcessor, + info=SmolVLMProcessingInfo, + dummy_inputs=SmolVLMDummyInputsBuilder, +) class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__( vllm_config=vllm_config, diff --git a/vllm/model_executor/models/solar.py 
b/vllm/model_executor/models/solar.py index 8dd52f1d204a5..5abcb47c6e25f 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -37,27 +37,37 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class SolarMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -83,8 +93,9 @@ class SolarMLP(nn.Module): prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -95,7 +106,6 @@ class SolarMLP(nn.Module): class SolarAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -184,7 +194,6 @@ class SolarAttention(nn.Module): class SolarDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -198,21 +207,24 @@ class SolarDecoderLayer(nn.Module): rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] \ - = config.original_max_position_embeddings - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = SolarAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -229,10 +241,10 @@ class SolarDecoderLayer(nn.Module): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -245,23 +257,20 @@ class SolarDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class SolarModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -272,12 +281,16 @@ class SolarModel(nn.Module): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -300,9 +313,9 @@ class SolarModel(nn.Module): 
else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -329,8 +342,7 @@ class SolarModel(nn.Module): bskcn_h_2 = None bskcn_r_1 = None bskcn_r_2 = None - bskcn_tv = (self.config.bskcn_tv[0] - if self.training else self.config.bskcn_tv[1]) + bskcn_tv = self.config.bskcn_tv[0] if self.training else self.config.bskcn_tv[1] for i in range(self.start_layer, self.end_layer): if i in self.config.bskcn_1: @@ -340,12 +352,10 @@ class SolarModel(nn.Module): bskcn_h_2 = hidden_states.clone() bskcn_r_2 = residual.clone() if i in self.config.bskcn_3: - hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( - 1 - bskcn_tv) + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * (1 - bskcn_tv) residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) if i in self.config.bskcn_4: - hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( - 1 - bskcn_tv) + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * (1 - bskcn_tv) residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) layer = self.layers[i] hidden_states, residual = layer( @@ -355,16 +365,14 @@ class SolarModel(nn.Module): ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -376,14 +384,15 @@ class SolarModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -416,8 +425,7 @@ class SolarModel(nn.Module): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -467,21 +475,27 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, 
"lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def forward( self, @@ -490,17 +504,15 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index d6ec743ce845e..79ed001833444 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -21,7 +21,9 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -32,44 +34,56 @@ from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + 
make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class StablelmMLP(nn.Module): - - def __init__(self, - config: StableLmConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: StableLmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [config.intermediate_size] * 2, + config.hidden_size, + [config.intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -80,12 +94,13 @@ class StablelmMLP(nn.Module): class StablelmAttention(nn.Module): - - def __init__(self, - config: StableLmConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: StableLmConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -102,33 +117,39 @@ class StablelmAttention(nn.Module): # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_key_value_heads == 0 - self.num_key_value_heads = max( - 1, self.total_num_key_value_heads // tp_size) + self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings self.partial_rotary_factor = getattr( - config, "rope_pct", getattr(config, "partial_rotary_factor", 1)) + config, "rope_pct", getattr(config, "partial_rotary_factor", 1) + ) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim self.qkv_bias = getattr(config, "use_qkv_bias", False) if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") + raise ValueError( + f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_key_value_heads, - self.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + self.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -136,13 +157,15 @@ class StablelmAttention(nn.Module): base=self.config.rope_theta, partial_rotary_factor=self.partial_rotary_factor, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -158,7 +181,6 @@ class StablelmAttention(nn.Module): class StablelmDecoderLayer(nn.Module): - def __init__( self, config: StableLmConfig, @@ -167,16 +189,13 @@ class StablelmDecoderLayer(nn.Module): prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = StablelmAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp") - norm_eps = getattr(config, "norm_eps", - getattr(config, "layer_norm_eps", 1e-05)) + norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) def forward( self, @@ -202,7 +221,6 @@ class StablelmDecoderLayer(nn.Module): class StableLMEpochModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -219,15 +237,15 @@ class StableLMEpochModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: StablelmDecoderLayer( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) - norm_eps = getattr(config, "norm_eps", - getattr(config, "layer_norm_eps", 1e-05)) + norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -247,15 +265,14 @@ class StableLMEpochModel(nn.Module): else: assert intermediate_tensors is not None 
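# Illustrative sketch, not part of the patch: the hunk just below replaces
# `self.layers[self.start_layer:self.end_layer]` with
# `islice(self.layers, self.start_layer, self.end_layer)`. Both visit the same
# modules in order; islice simply avoids materializing an intermediate ModuleList.
from itertools import islice

import torch.nn as nn

layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(6))
start_layer, end_layer = 2, 5
assert list(islice(layers, start_layer, end_layer)) == list(layers[start_layer:end_layer])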
hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -267,7 +284,7 @@ class StableLMEpochModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -287,32 +304,34 @@ class StableLMEpochModel(nn.Module): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class StablelmForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head") + self.model = StableLMEpochModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -324,20 +343,18 @@ class StablelmForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git 
a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 9d9a2bff0e43f..ec894140c3bf3 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -19,8 +19,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Starcoder2 model.""" +"""PyTorch Starcoder2 model.""" + from collections.abc import Iterable +from itertools import islice from typing import Optional, Union import torch @@ -32,32 +34,43 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Starcoder2Attention(nn.Module): - - def __init__(self, - config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config @@ -107,13 +120,15 @@ class Starcoder2Attention(nn.Module): base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -129,11 +144,12 @@ class Starcoder2Attention(nn.Module): class Starcoder2MLP(nn.Module): - - def __init__(self, - config: Starcoder2Config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, @@ -159,25 +175,28 @@ class Starcoder2MLP(nn.Module): class 
Starcoder2DecoderLayer(nn.Module): - - def __init__(self, - config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = Starcoder2MLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon) + self.self_attn = Starcoder2Attention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Starcoder2MLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.norm_epsilon + ) def forward( self, @@ -204,7 +223,6 @@ class Starcoder2DecoderLayer(nn.Module): @support_torch_compile class Starcoder2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -219,7 +237,8 @@ class Starcoder2Model(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( @@ -228,9 +247,9 @@ class Starcoder2Model(nn.Module): prefix=f"{prefix}.layers", ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -250,15 +269,14 @@ class Starcoder2Model(nn.Module): else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -269,7 +287,7 @@ class Starcoder2Model(nn.Module): params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -286,22 +304,21 @@ class Starcoder2Model(nn.Module): if is_pp_missing_parameter(name, self): 
continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Starcoder2ForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = Starcoder2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Starcoder2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: @@ -316,10 +333,12 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, prefix=f"{prefix}.lm_head", ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -331,26 +350,25 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
- skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 47d2af5c2a140..2099055e641c4 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" + from collections.abc import Iterable +from itertools import islice from typing import Any, Optional import torch @@ -10,62 +12,77 @@ from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class FusedMoEBlock(nn.Module): - - def __init__(self, - config: ModelConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: ModelConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() if self.tp_size > config.moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.moe_num_experts}.") + f"the number of experts {config.moe_num_experts}." 
+ ) - self.experts = FusedMoE(num_experts=config.moe_num_experts, - top_k=config.moe_top_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_expert_weight, - quant_config=quant_config, - prefix=f"{prefix}.experts") - self.gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.experts = FusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_expert_weight, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -74,17 +91,16 @@ class FusedMoEBlock(nn.Module): router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) class Step3TextMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -95,18 +111,23 @@ class Step3TextMLP(nn.Module): ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() self.hidden_size = hidden_size @@ -118,7 +139,6 @@ class Step3TextMLP(nn.Module): class Step3TextAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -143,8 +163,9 @@ class Step3TextAttention(nn.Module): self.num_heads = self.total_num_heads // tp_size if num_kv_heads != 1: - raise ValueError(f"Step3TextAttention num_kv_heads must be 1, " - f"but got {num_kv_heads}.") + raise ValueError( + f"Step3TextAttention num_kv_heads must be 1, but got {num_kv_heads}." 
+ ) self.num_kv_heads = num_kv_heads self.head_dim = head_dim @@ -174,21 +195,26 @@ class Step3TextAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.wq", ) - self.rotary_emb = get_rope(self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embedding, - base=rope_theta, - rope_scaling=rope_scaling) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embedding, + base=rope_theta, + rope_scaling=rope_scaling, + ) scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - self.num_kv_heads, - cache_config=cache_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) - def forward(self, positions: torch.Tensor, - hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = self.inter_norm(q) @@ -200,12 +226,13 @@ class Step3TextAttention(nn.Module): class Step3TextDecoderLayer(nn.Module): - - def __init__(self, - config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() config = config.hf_config self.hidden_size = config.hidden_size @@ -223,59 +250,61 @@ class Step3TextDecoderLayer(nn.Module): share_q_dim=config.share_q_dim, rope_theta=config.rope_theta, rope_scaling=rope_scaling, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) layer_idx = int(prefix.split("layers.")[1].split(".")[0]) moe_layers_enum = getattr(config, "moe_layers_enum", None) if moe_layers_enum is not None: - moe_layers_idx = [ - int(i) for i in moe_layers_enum.strip().split(',') - ] + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] else: # Default to 1dense. 
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] if layer_idx in moe_layers_idx: - self.moe = FusedMoEBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.moe") + self.moe = FusedMoEBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.moe" + ) self.share_expert = Step3TextMLP( hidden_size=self.hidden_size, intermediate_size=config.share_expert_dim, hidden_act="silu", quant_config=quant_config, - prefix=f"{prefix}.share_expert") + prefix=f"{prefix}.share_expert", + ) self.use_moe = True else: - self.mlp = Step3TextMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act="silu", - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) self.use_moe = False - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( - self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) if self.use_moe: share_output = self.share_expert(hidden_states) @@ -289,7 +318,6 @@ class Step3TextDecoderLayer(nn.Module): @support_torch_compile class Step3TextModel(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -298,8 +326,9 @@ class Step3TextModel(nn.Module): self.vocab_size = config.vocab_size self.config = config - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -309,11 +338,12 @@ class Step3TextModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Step3TextDecoderLayer(config=vllm_config. 
- model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Step3TextDecoderLayer( + config=vllm_config.model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -321,9 +351,9 @@ class Step3TextModel(nn.Module): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -346,22 +376,22 @@ class Step3TextModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Step3TextForCausalLM(nn.Module, SupportsPP): - def __init__( self, *, @@ -385,55 +415,65 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None): - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[tuple[str, - 
torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) - (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / - (self.config.share_q_dim + self.config.head_dim * 2)), - (".qkv_proj", ".k_proj", self.config.share_q_dim / - (self.config.share_q_dim + self.config.head_dim * 2), - (self.config.share_q_dim + self.config.head_dim) / - (self.config.share_q_dim + self.config.head_dim * 2)), - (".qkv_proj", ".v_proj", - (self.config.share_q_dim + self.config.head_dim) / - (self.config.share_q_dim + self.config.head_dim * 2), - (self.config.share_q_dim + self.config.head_dim * 2) / - (self.config.share_q_dim + self.config.head_dim * 2)), + ( + ".qkv_proj", + ".q_proj", + 0, + self.config.share_q_dim + / (self.config.share_q_dim + self.config.head_dim * 2), + ), + ( + ".qkv_proj", + ".k_proj", + self.config.share_q_dim + / (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim) + / (self.config.share_q_dim + self.config.head_dim * 2), + ), + ( + ".qkv_proj", + ".v_proj", + (self.config.share_q_dim + self.config.head_dim) + / (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim * 2) + / (self.config.share_q_dim + self.config.head_dim * 2), + ), ] stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -446,20 +486,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): expert_params_mapping = [ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), ] - disable_moe_stacked_params = [ - data[1] for data in expert_params_mapping - ] + disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if any(disable_moe_stacked_param in name - for disable_moe_stacked_param in - disable_moe_stacked_params): + if any( + disable_moe_stacked_param in name + for disable_moe_stacked_param in disable_moe_stacked_params + ): continue name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): @@ -479,23 +518,30 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader for expert_id in range(loaded_weight.shape[0]): loaded_weight_expert = loaded_weight[expert_id] - weight_loader(param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(name) break else: - for (param_name, weight_name, start_idx, - end_idx) in qkv_params_mapping: + for ( + param_name, + weight_name, + start_idx, + end_idx, + ) in qkv_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -505,8 +551,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): dim = param.shape[param.output_dim] begin_idx = int(start_idx * dim) end_idx = int(end_idx * dim) - param_slice = param.narrow(param.output_dim, begin_idx, - end_idx - begin_idx) + param_slice = param.narrow( + param.output_dim, begin_idx, end_idx - begin_idx + ) param_slice.copy_(loaded_weight) loaded_params.add(name) break @@ -514,8 +561,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f8877b584b198..5ec7845a122f7 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -2,10 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from itertools import product from math import ceil, sqrt -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import numpy as np import torch @@ -16,49 +15,80 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType +from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import 
(BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.multimodal.utils import run_dp_sharded_vision_model from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import run_dp_sharded_vision_model -class Step3VLImagePixelInputs(TypedDict): +class Step3VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + - bnp: Batch size * number of images * number of patches + - hp: Height of patch + - wp: Width of patch + """ + type: Literal["pixel_values"] - pixel_values: torch.Tensor - patch_pixel_values: Optional[torch.Tensor] - num_patches: list[int] + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + patch_pixel_values: Annotated[ + Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp") + ] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class Step3VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor +class Step3VLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - f: Image feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ + + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] -Step3VLImageInputs = Union[Step3VLImagePixelInputs, - Step3VLImageEmbeddingInputs] +Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs] ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None] @@ -66,31 +96,42 @@ MAX_IMAGE_SIZE: int = 3024 class Step3VisionProcessor: - def __init__(self, size, interpolation_mode="bicubic", patch_size=None): mean = [0.48145466, 0.4578275, 0.40821073] std = [0.26862954, 0.26130258, 0.27577711] patch_size = patch_size if patch_size is not None else size - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean, std), - transforms.Resize( - (size, size), - interpolation=InterpolationMode.BICUBIC if interpolation_mode - == "bicubic" else InterpolationMode.BILINEAR, - antialias=True), - ]) + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (size, size), + interpolation=InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR, + antialias=True, + ), + ] + ) - self.patch_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean, std), - transforms.Resize( - (patch_size, patch_size), - interpolation=InterpolationMode.BICUBIC if interpolation_mode - == "bicubic" else InterpolationMode.BILINEAR, - antialias=True), - ]) if patch_size is not None else None + self.patch_transform = ( 
+ transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (patch_size, patch_size), + interpolation=InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR, + antialias=True, + ), + ] + ) + if patch_size is not None + else None + ) def __call__(self, image, is_patch=False): if is_patch: @@ -100,7 +141,6 @@ class Step3VisionProcessor: class ImagePatcher: - def determine_window_size(self, long: int, short: int) -> int: if long <= 728: return short if long / short > 1.5 else 0 @@ -121,14 +161,12 @@ class ImagePatcher: size_w, size_h = size step_w, step_h = step - x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + - 1) + x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + 1) x_start = [step_w * i for i in range(x_num)] if len(x_start) > 1 and x_start[-1] + size_w > width: x_start[-1] = width - size_w - y_num = 1 if height <= size_h else ceil((height - size_h) / - step_h + 1) + y_num = 1 if height <= size_h else ceil((height - size_h) / step_h + 1) y_start = [step_h * i for i in range(y_num)] if len(y_start) > 1 and y_start[-1] + size_h > height: y_start[-1] = height - size_h @@ -138,8 +176,10 @@ class ImagePatcher: windows.append(np.concatenate([start, start + size], axis=1)) windows = np.concatenate(windows, axis=0) - return [(int(box[0]), int(box[1]), int(box[2] - box[0]), - int(box[3] - box[1])) for box in windows], (x_num, y_num) + return [ + (int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1])) + for box in windows + ], (x_num, y_num) def square_pad(self, img: Image.Image) -> Image.Image: w, h = img.size @@ -150,25 +190,27 @@ class ImagePatcher: padded.paste(img, (0, 0)) return padded - def get_image_size_for_padding(self, img_width: int, - img_height: int) -> tuple[int, int]: + def get_image_size_for_padding( + self, img_width: int, img_height: int + ) -> tuple[int, int]: ratio = img_width / img_height if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4): new_size = max(img_height, img_width) return new_size, new_size return img_width, img_height - def get_image_size_for_preprocess(self, img_width: int, - img_height: int) -> tuple[int, int]: - + def get_image_size_for_preprocess( + self, img_width: int, img_height: int + ) -> tuple[int, int]: if max(img_height, img_width) > MAX_IMAGE_SIZE: scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width) img_width = int(img_width * scale_factor) img_height = int(img_height * scale_factor) return img_width, img_height - def get_image_size_for_crop(self, img_width: int, img_height: int, - window_size: int): + def get_image_size_for_crop( + self, img_width: int, img_height: int, window_size: int + ): w_ratio = img_width / window_size h_ratio = img_height / window_size @@ -190,22 +232,26 @@ class ImagePatcher: target = img.crop((j, i, j + tw, i + th)) return target - def get_num_patches(self, img_width: int, - img_height: int) -> tuple[int, int]: - img_width, img_height = self.get_image_size_for_padding( - img_width, img_height) + def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]: + img_width, img_height = self.get_image_size_for_padding(img_width, img_height) img_width, img_height = self.get_image_size_for_preprocess( - img_width, img_height) - window_size = self.determine_window_size(max(img_height, img_width), - min(img_height, img_width)) + img_width, img_height + ) + window_size = self.determine_window_size( + max(img_height, img_width), min(img_height, img_width) 
+ ) if window_size == 0: return 0, 0 else: img_width, img_height = self.get_image_size_for_crop( - img_width, img_height, window_size) + img_width, img_height, window_size + ) center_list, (x_num, y_num) = self.slide_window( - img_width, img_height, [(window_size, window_size)], - [(window_size, window_size)]) + img_width, + img_height, + [(window_size, window_size)], + [(window_size, window_size)], + ) full_rows = (len(center_list) - 1) // x_num + 1 if len(center_list) > 0 and len(center_list) % x_num == 0: full_rows -= 1 @@ -216,39 +262,44 @@ class ImagePatcher: ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]: img_width, img_height = img.size new_img_width, new_img_height = self.get_image_size_for_padding( - img_width, img_height) + img_width, img_height + ) if new_img_width != img_width or new_img_height != img_height: img = self.square_pad(img) img_width, img_height = img.size new_img_width, new_img_height = self.get_image_size_for_preprocess( - img_width, img_height) - img = img.resize((new_img_width, new_img_height), - Image.Resampling.BILINEAR) + img_width, img_height + ) + img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR) window_size = self.determine_window_size( - max(new_img_height, new_img_width), - min(new_img_height, new_img_width)) + max(new_img_height, new_img_width), min(new_img_height, new_img_width) + ) if window_size == 0: return img, [], None else: new_img_width, new_img_height = self.get_image_size_for_crop( - new_img_width, new_img_height, window_size) + new_img_width, new_img_height, window_size + ) if (new_img_width, new_img_height) != (img_width, img_height): - img_for_crop = img.resize((new_img_width, new_img_height), - Image.Resampling.BILINEAR) + img_for_crop = img.resize( + (new_img_width, new_img_height), Image.Resampling.BILINEAR + ) else: img_for_crop = img patches = [] newlines = [] center_list, (x_num, y_num) = self.slide_window( - new_img_width, new_img_height, [(window_size, window_size)], - [(window_size, window_size)]) + new_img_width, + new_img_height, + [(window_size, window_size)], + [(window_size, window_size)], + ) for patch_id, center_lf_point in enumerate(center_list): x, y, patch_w, patch_h = center_lf_point - big_patch = self.patch_crop(img_for_crop, y, x, patch_h, - patch_w) + big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w) patches.append(big_patch) if (patch_id + 1) % x_num == 0: newlines.append(patch_id) @@ -256,12 +307,16 @@ class ImagePatcher: if newlines and newlines[-1] == len(patches) - 1: newlines.pop() - return img, patches, [i in newlines for i in range(len(patches)) - ] if len(patches) > 0 else None + return ( + img, + patches, + [i in newlines for i in range(len(patches))] + if len(patches) > 0 + else None, + ) class Step3VLProcessor: - def __init__( self, config: PretrainedConfig, @@ -274,17 +329,15 @@ class Step3VLProcessor: self.image_size = 728 self.patch_size = 504 - self.image_preprocessor = Step3VisionProcessor(self.image_size, - "bilinear", - self.patch_size) + self.image_preprocessor = Step3VisionProcessor( + self.image_size, "bilinear", self.patch_size + ) self.num_image_feature_size = 169 self.num_patch_feature_size = 81 self.image_token = "<im_patch>" - self.image_feature_placeholder = (self.image_token * - self.num_image_feature_size) - self.patch_feature_placeholder = (self.image_token * - self.num_patch_feature_size) + self.image_feature_placeholder = self.image_token * self.num_image_feature_size + self.patch_feature_placeholder = self.image_token * 
self.num_patch_feature_size self.patcher = ImagePatcher() @@ -293,15 +346,16 @@ class Step3VLProcessor: return self.tokenizer.get_vocab()[self.image_token] def get_num_image_tokens(self, img_width: int, img_height: int) -> int: - num_patches, num_newlines = self.patcher.get_num_patches( - img_width, img_height) + num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height) - return num_patches * ( - self.num_patch_feature_size + - 2) + self.num_image_feature_size + 2 + num_newlines + return ( + num_patches * (self.num_patch_feature_size + 2) + + self.num_image_feature_size + + 2 + + num_newlines + ) - def _split_images(self, - images: list[Image.Image]) -> list[ImageWithPatches]: + def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]: result = [] for img in images: result.append(self.patcher(img)) @@ -328,13 +382,15 @@ class Step3VLProcessor: assert len(patch_newline_mask) == num_patches text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>" token_ids.extend( - [self.tokenizer.convert_tokens_to_ids("<patch_start>")] + - [self.image_token_id] * self.num_patch_feature_size + - [self.tokenizer.convert_tokens_to_ids("<patch_end>")]) + [self.tokenizer.convert_tokens_to_ids("<patch_start>")] + + [self.image_token_id] * self.num_patch_feature_size + + [self.tokenizer.convert_tokens_to_ids("<patch_end>")] + ) if patch_newline_mask and patch_newline_mask[i]: text += "<patch_newline>" token_ids.append( - self.tokenizer.convert_tokens_to_ids("<patch_newline>")) + self.tokenizer.convert_tokens_to_ids("<patch_newline>") + ) return text, token_ids def _get_image_repl( @@ -342,11 +398,11 @@ class Step3VLProcessor: num_images: int, ) -> tuple[str, list[int]]: text = f"<im_start>{self.image_feature_placeholder}<im_end>" - token_ids = [ - self.tokenizer.convert_tokens_to_ids("<im_start>") - ] + [self.image_token_id] * self.num_image_feature_size + [ - self.tokenizer.convert_tokens_to_ids("<im_end>") - ] + token_ids = ( + [self.tokenizer.convert_tokens_to_ids("<im_start>")] + + [self.image_token_id] * self.num_image_feature_size + + [self.tokenizer.convert_tokens_to_ids("<im_end>")] + ) return text * num_images, token_ids * num_images def _get_image_repl_features( @@ -357,15 +413,15 @@ class Step3VLProcessor: ) -> tuple[str, list[int]]: if num_patches > 0: patch_repl, patch_repl_ids = self._get_patch_repl( - num_patches, patch_new_line_idx) + num_patches, patch_new_line_idx + ) else: patch_repl = "" patch_repl_ids = [] image_repl, image_repl_ids = self._get_image_repl(num_images) return patch_repl + image_repl, patch_repl_ids + image_repl_ids - def replace_placeholder(self, text: str, placeholder: str, - repls: list[str]) -> str: + def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str: parts = text.split(placeholder) if len(parts) - 1 != len(repls): @@ -407,17 +463,17 @@ class Step3VLProcessor: image_repl_ids_lst = [] num_patches = [] for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501 - pixel_values_lst.extend( - self._convert_images_to_pixel_values([raw_img])) + pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img])) if len(img_patches) > 0: patch_pixel_values_lst.extend( - self._convert_images_to_pixel_values(img_patches, - is_patch=True)) + self._convert_images_to_pixel_values(img_patches, is_patch=True) + ) num_patches.append(len(img_patches)) image_repl_str, image_repl_ids = self._get_image_repl_features( - 1, len(img_patches), patch_newline_mask) + 1, len(img_patches), 
patch_newline_mask + ) image_repl_str_lst.append(image_repl_str) image_repl_ids_lst.extend(image_repl_ids) @@ -429,15 +485,15 @@ class Step3VLProcessor: "num_patches": num_patches, } if patch_pixel_values_lst: - image_inputs["patch_pixel_values"] = torch.cat( - patch_pixel_values_lst) + image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst) if patch_newline_mask_lst: image_inputs["patch_newline_mask"] = torch.tensor( - patch_newline_mask_lst, dtype=torch.bool) + patch_newline_mask_lst, dtype=torch.bool + ) text = [ - self.replace_placeholder(t, self.image_token, - image_repl_str_lst) for t in text + self.replace_placeholder(t, self.image_token, image_repl_str_lst) + for t in text ] text_inputs = self.tokenizer(text) @@ -451,7 +507,6 @@ class Step3VLProcessor: class Step3VLProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self) -> Step3VLProcessor: return Step3VLProcessor( self.get_hf_config(), @@ -465,7 +520,8 @@ class Step3VLProcessingInfo(BaseProcessingInfo): hf_processor = self.get_hf_processor() return hf_processor.get_num_image_tokens( self.get_image_size_with_most_features().width, - self.get_image_size_with_most_features().height) + self.get_image_size_with_most_features().height, + ) def get_mm_max_tokens_per_item( self, @@ -479,19 +535,19 @@ class Step3VLProcessingInfo(BaseProcessingInfo): def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int: if len(mm_data) != 1 or "image" not in mm_data: - raise ValueError( - "mm_data could only contain one key 'image' for steo1o") + raise ValueError("mm_data could only contain one key 'image' for steo1o") image_data = mm_data["image"] if not isinstance(image_data, (list, tuple)): image_data = [image_data] - return sum(self.get_hf_processor().get_num_image_tokens( - img.width, img.height) for img in image_data) + return sum( + self.get_hf_processor().get_num_image_tokens(img.width, img.height) + for img in image_data + ) class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return "<im_patch>" * num_images @@ -500,22 +556,24 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] - ): - +class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -531,10 +589,10 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] if num_patches > 0: patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor._get_image_repl_features( - 1, num_patches, patch_newline_mask.tolist())[1] + 1, num_patches, patch_newline_mask.tolist() + )[1] else: - image_repl_ids = hf_processor._get_image_repl_features( - 1, 
0, None)[1] + image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1] return PromptUpdateDetails.select_token_id( seq=image_repl_ids, embed_token_id=image_placeholder_token_id, @@ -558,10 +616,12 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] return dict( pixel_values=MultiModalFieldConfig.batched("image"), patch_pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + "image", num_patches + ), num_patches=MultiModalFieldConfig.batched("image"), patch_newline_mask=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + "image", num_patches + ), ) @@ -575,29 +635,29 @@ def get_abs_pos(abs_pos, tgt_size): dtype = abs_pos.dtype if src_size != tgt_size: - old_pos_embed = old_pos_embed.view(1, src_size, src_size, - dim).permute(0, 3, 1, - 2).contiguous() + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) old_pos_embed = old_pos_embed.to(torch.float32) new_pos_embed = F.interpolate( old_pos_embed, size=(tgt_size, tgt_size), - mode='bicubic', + mode="bicubic", antialias=True, align_corners=False, ).to(dtype) new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) - vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, - dim) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) return vision_pos_embed else: return abs_pos class Step3VisionEmbeddings(nn.Module): - def __init__(self, config: Step3VisionEncoderConfig): super().__init__() self.config = config @@ -615,43 +675,51 @@ class Step3VisionEmbeddings(nn.Module): bias=True, ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.pad_tp_size = 4 # hard code for padding # To load the pretrained weights, we still use P+1 as the seqlen - self.position_embedding = torch.nn.Embedding(self.num_patches + 1, - self.embed_dim) - self.register_buffer("position_ids", - torch.arange(self.num_patches + 1).expand( - (1, -1)), - persistent=False) + self.position_embedding = torch.nn.Embedding( + self.num_patches + 1, self.embed_dim + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_patches + 1).expand((1, -1)), + persistent=False, + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] patch_embeds = self.patch_embedding( - pixel_values) # shape = [*, width, grid, grid] + pixel_values + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) # pad class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) embeddings = embeddings + get_abs_pos( - self.position_embedding(self.position_ids), patch_embeds.size(1)) - embeddings = torch.cat([ - embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, - 1), embeddings - ], - dim=1) + self.position_embedding(self.position_ids), patch_embeds.size(1) + ) + embeddings = torch.cat( + [ + embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1), + embeddings, + ], + dim=1, + ) return embeddings class Step3VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config, + 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -660,46 +728,32 @@ class Step3VisionAttention(nn.Module): self.scale = self.head_dim**-0.5 - tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.q_size = self.num_heads * self.head_dim - if use_data_parallel: - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = ReplicatedLinear( - self.total_num_heads * self.head_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, + ) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) def forward( self, @@ -711,19 +765,9 @@ class Step3VisionAttention(nn.Module): # get query proj qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) - k = k.view(bsz, tgt_len, self.num_heads, self.head_dim) - v = v.view(bsz, tgt_len, self.num_heads, self.head_dim) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - scale=self.scale, - is_causal=False) - attn_output = attn_output.transpose(1, 2).reshape( - bsz, tgt_len, self.num_heads * self.head_dim) + + # Use unified MultiHeadAttention with automatic backend selection + attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) @@ -731,29 +775,32 @@ class Step3VisionAttention(nn.Module): class Step3VisionMLP(nn.Module): - - def __init__(self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=prefix) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.fc1 = 
ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -763,12 +810,13 @@ class Step3VisionMLP(nn.Module): class Step3VisionEncoderLayer(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size @@ -776,44 +824,48 @@ class Step3VisionEncoderLayer(nn.Module): config, quant_config, prefix=f"{prefix}.self_attn", - use_data_parallel=self.use_data_parallel) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, - quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=self.use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + use_data_parallel=self.use_data_parallel, + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Step3VisionMLP( + config, + quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, ) -> torch.FloatTensor: - hidden_states = hidden_states + self.layer_norm1( - self.self_attn(hidden_states)) - hidden_states = hidden_states + self.layer_norm2( - self.mlp(hidden_states)) + hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states)) + hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states)) return hidden_states class Step3VisionEncoder(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.use_data_parallel = use_data_parallel - self.layers = nn.ModuleList([ - Step3VisionEncoderLayer(config, - quant_config, - prefix=f"{prefix}.layers.{i}", - use_data_parallel=self.use_data_parallel) - for i in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Step3VisionEncoderLayer( + config, + quant_config, + prefix=f"{prefix}.layers.{i}", + use_data_parallel=self.use_data_parallel, + ) + for i in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -826,12 +878,13 @@ class Step3VisionEncoder(nn.Module): class Step3VisionTransformer(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): 
super().__init__() self.config = config self.use_data_parallel = use_data_parallel @@ -841,7 +894,8 @@ class Step3VisionTransformer(nn.Module): config, quant_config, prefix=f"{prefix}.transformer", - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) def forward( self, @@ -849,23 +903,28 @@ class Step3VisionTransformer(nn.Module): ): hidden_states = self.embeddings(pixel_values) if self.use_data_parallel: - hidden_states = run_dp_sharded_vision_model( - hidden_states, self.transformer) + hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer) else: hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states -@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor, - info=Step3VLProcessingInfo, - dummy_inputs=Step3VLDummyInputsBuilder) -class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + Step3VLMultiModalProcessor, + info=Step3VLProcessingInfo, + dummy_inputs=Step3VLDummyInputsBuilder, +) +class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "model.": "language_model.model.", - "lm_head.": "language_model.lm_head.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + } + ) + + supports_encoder_tp_data = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -889,12 +948,14 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, config.vision_config, None, prefix=maybe_prefix(prefix, "vision_model"), - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) self.vit_downsampler = nn.Conv2d( config.vision_config.hidden_size, config.vision_config.output_hidden_size, kernel_size=2, - stride=config.understand_projector_stride) + stride=config.understand_projector_stride, + ) self.vit_downsampler2 = nn.Conv2d( config.vision_config.output_hidden_size, config.vision_config.output_hidden_size * 2, @@ -916,17 +977,12 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, - prefix=maybe_prefix(prefix, "language_model")) + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() + self.language_model.make_empty_intermediate_tensors + ) @property def device(self): @@ -937,7 +993,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, return next(self.parameters()).dtype def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Step3VLImageInputs]: + self, **kwargs: object + ) -> Optional[Step3VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) patch_pixel_values = kwargs.pop("patch_pixel_values", None) num_patches = kwargs.pop("num_patches", None) @@ -947,42 +1004,24 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - pixel_values = flatten_bn(pixel_values, concat=True) - if pixel_values.dim() >= 3: - pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:]) - if 
patch_pixel_values is not None: - patch_pixel_values = flatten_bn(patch_pixel_values, - concat=True) - patch_pixel_values = patch_pixel_values.view( - -1, *patch_pixel_values.shape[-3:]) - # Handle empty patch_pixel_values by setting to None - if patch_pixel_values.shape[0] == 0: - patch_pixel_values = None - num_patches = flatten_bn(num_patches, concat=True).tolist() - return Step3VLImagePixelInputs( type="pixel_values", - pixel_values=pixel_values.to(self.dtype).to(self.device), - patch_pixel_values=patch_pixel_values.to(self.dtype).to( - self.device) if patch_pixel_values is not None else None, + pixel_values=pixel_values.to(self.dtype), + patch_pixel_values=patch_pixel_values.to(self.dtype) + if patch_pixel_values is not None + else None, num_patches=num_patches, ) if image_embeds is not None: - if image_embeds.dim() == 2 or image_embeds.dim() >= 3: - image_embeds = image_embeds.view(-1, image_embeds.shape[-1]) - else: - raise ValueError( - f"Unexpected shape for image_embeds: {image_embeds.shape}") - return Step3VLImageEmbeddingInputs( type="image_embeds", - image_embeds=image_embeds.to(self.dtype).to(self.device), + image_embeds=image_embeds.to(self.dtype), ) - return None - def _process_image_features(self, - image_features: torch.Tensor) -> torch.Tensor: + raise AssertionError("This line should be unreachable.") + + def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor: B, P = image_features.shape[:2] HW = int(sqrt(P)) image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW) @@ -993,26 +1032,29 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_features = self.vit_large_projector(image_features) return image_features - def _get_vision_model_output(self, - input_tensor: torch.Tensor) -> torch.Tensor: + def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor: return self.vision_model(input_tensor)[:, 4:] def _process_image_input( - self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Step3VLImageInputs + ) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": image_features = image_input["image_embeds"] else: - image_features = self._get_vision_model_output( - image_input["pixel_values"]) - patch_image_features = self._get_vision_model_output( - image_input["patch_pixel_values"] - ) if image_input["patch_pixel_values"] is not None else None + image_features = self._get_vision_model_output(image_input["pixel_values"]) + patch_image_features = ( + self._get_vision_model_output(image_input["patch_pixel_values"]) + if image_input["patch_pixel_values"] is not None + else None + ) num_patches = image_input["num_patches"] image_features = self._process_image_features(image_features) - patch_image_features = self._process_image_features( - patch_image_features) if patch_image_features is not None else None + patch_image_features = ( + self._process_image_features(patch_image_features) + if patch_image_features is not None + else None + ) merged_image_features = [] cur_patch_idx = 0 @@ -1020,20 +1062,23 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, cur_feature = [] if num_patch > 0: patch_slice = patch_image_features[ - cur_patch_idx:cur_patch_idx + num_patch] + cur_patch_idx : cur_patch_idx + num_patch + ] cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1])) - cur_feature.append(image_features[i].view( - -1, image_features.shape[-1])) + cur_feature.append(image_features[i].view(-1, image_features.shape[-1])) 
cur_patch_idx += num_patch merged_image_features.append( - torch.cat(cur_feature) if len(cur_feature) > - 1 else cur_feature[0]) + torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0] + ) return merged_image_features - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: - return None + return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings @@ -1041,24 +1086,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - if multimodal_embeddings is None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - else: - is_text = input_ids != self.config.image_token_id - text_ids = input_ids[is_text] - text_embeds = self.language_model.model.get_input_embeddings( - text_ids) - inputs_embeds = torch.empty(input_ids.shape[0], - text_embeds.shape[-1], - dtype=text_embeds.dtype, - device=text_embeds.device) - inputs_embeds[is_text] = text_embeds - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_id) - return inputs_embeds + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) def forward( self, @@ -1072,44 +1114,35 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_id, + ) input_ids = None - hidden_states = self.language_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - skip_prefixes = [] if self.vision_model is None and self.vit_large_projector is None: skip_prefixes = [ - "vision_model.", "vit_downsampler.", "vit_downsampler2.", - "vit_large_projector." 
+ "vision_model.", + "vit_downsampler.", + "vit_downsampler2.", + "vit_large_projector.", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - loaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 0000000000000..485c008e830a9 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,515 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size + if isinstance(window_size, Iterable) + else (window_size, window_size) + ) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads + ) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter( + relative_position_index, requires_grad=False + ) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ] + 
relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, dim + ).expand( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask_expanded.unsqueeze( + 1 + ).unsqueeze(0) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, dim, dim + ) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0.0, + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + +class SwinSelfOutput(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self = SwinSelfAttention( + config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self", + ) + self.output = SwinSelfOutput( + config, dim, quant_config=quant_config, prefix=f"{prefix}.output" + ) + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, attention_mask, head_mask, output_attentions + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = ColumnParallelLinear( + 
dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention( + config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.intermediate = SwinIntermediate( + config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate" + ) + self.output = SwinOutput( + config, dim, quant_config=quant_config, prefix=f"{prefix}.output" + ) + + +class SwinStage(nn.Module): + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + SwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=nn.LayerNorm + ) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample( + hidden_states_before_downsampling, input_dimensions + ) + else: + output_dimensions = (height, width, height, 
width) + + stage_outputs = ( + hidden_states, + hidden_states_before_downsampling, + output_dimensions, + ) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() + for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu" + ) + ] + self.layers = nn.ModuleList( + [ + SwinStage( + config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=( + grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx), + ), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[ + sum(config.depths[:layer_idx]) : sum( + config.depths[: layer_idx + 1] + ) + ], + downsample=SwinPatchMerging + if (layer_idx < self.num_layers - 1) + else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(self.num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder( + config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + 
loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 0990be8d02b94..6a224fe9288b2 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,58 +3,85 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, - Union, cast) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import BatchFeature, CLIPVisionConfig +from transformers import ( + BatchFeature, + CLIPVisionConfig, + PretrainedConfig, + SiglipVisionConfig, +) from transformers import LlavaConfig as HfLlavaConfig -from transformers import PretrainedConfig, SiglipVisionConfig from transformers.image_utils import ImageInput, get_image_size, to_numpy_array from transformers.models.llava import LlavaProcessor from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.config import VllmConfig -from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix, merge_multimodal_embeddings) -from .vision import VisionEncoderInfo, get_vision_encoder_info +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix +from .vision import ( + VisionEncoderInfo, + get_num_selected_vision_tokens, + get_vision_encoder_info, +) -class TarsierImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class TarsierImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class 
TarsierImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor +class TarsierImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -TarsierImageInputs = Union[TarsierImagePixelInputs, - TarsierImageEmbeddingInputs] +TarsierImageInputs = Union[TarsierImagePixelInputs, TarsierImageEmbeddingInputs] class TarsierHfConfig(Protocol): # Based on the Tarsier's LlavaConfig @@ -79,19 +106,18 @@ class TarsierProcessorKwargs(ProcessingKwargs, total=False): class TarsierProcessor(LlavaProcessor): - def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], - list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, audio=None, videos=None, **kwargs: Unpack[TarsierProcessorKwargs], ) -> BatchFeature: if images is None and text is None: - raise ValueError( - "You have to specify at least one of `images` or `text`.") + raise ValueError("You have to specify at least one of `images` or `text`.") output_kwargs = self._merge_kwargs( TarsierProcessorKwargs, @@ -100,15 +126,17 @@ class TarsierProcessor(LlavaProcessor): ) if images is not None: image_inputs = self.image_processor( - images, **output_kwargs["images_kwargs"]) + images, **output_kwargs["images_kwargs"] + ) else: image_inputs = {} if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string," - " or a list of strings") + raise ValueError( + "Invalid input text. 
Please provide a string, or a list of strings" + ) # try to expand inputs in processing if we have the necessary parts prompt_strings = text @@ -116,51 +144,55 @@ class TarsierProcessor(LlavaProcessor): # Replace the image token with the expanded image token sequence pixel_values = image_inputs["pixel_values"] height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * ( - width // self.patch_size + - 1) + self.num_additional_image_tokens + 1 + num_image_tokens = ( + (height // self.patch_size) * (width // self.patch_size + 1) + + self.num_additional_image_tokens + + 1 + ) if self.vision_feature_select_strategy == "default": num_image_tokens -= 1 prompt_strings = [] for sample in text: - sample = sample.replace(self.image_token, - self.image_token * num_image_tokens) + sample = sample.replace( + self.image_token, self.image_token * num_image_tokens + ) prompt_strings.append(sample) - return_tensors = output_kwargs["text_kwargs"].pop( - "return_tensors", None) - text_inputs = self.tokenizer(prompt_strings, - **output_kwargs["text_kwargs"]) - return BatchFeature(data={ - **text_inputs, - **image_inputs - }, - tensor_type=return_tensors) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return BatchFeature( + data={**text_inputs, **image_inputs}, tensor_type=return_tensors + ) class TarsierMultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -170,7 +202,6 @@ class TarsierMultiModalProjector(nn.Module): class TarsierProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> TarsierHfConfig: return self.ctx.get_hf_config(HfLlavaConfig) @@ -187,18 +218,6 @@ class TarsierProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def _apply_feature_select_strategy( - self, - strategy: str, - encoder_num_image_tokens: int, - ) -> int: - if strategy == "default": - return encoder_num_image_tokens - 1 - if strategy == "full": - return encoder_num_image_tokens - msg = f"Unexpected feature select strategy: {strategy!r}" - raise NotImplementedError(msg) - def get_num_image_tokens( self, *, @@ -207,29 +226,27 @@ 
class TarsierProcessingInfo(BaseProcessingInfo): ) -> int: hf_config = self.get_hf_config() vision_encoder_info = self.get_vision_encoder_info() - num_projected_patches = self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, + num_projected_patches = get_num_selected_vision_tokens( vision_encoder_info.get_num_image_tokens( image_width=image_width, image_height=image_height, ), + hf_config.vision_feature_select_strategy, ) if num_projected_patches <= 0: default_size = self.get_image_size_with_most_features() - num_projected_patches_default = self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, + num_projected_patches_default = get_num_selected_vision_tokens( vision_encoder_info.get_num_image_tokens( image_width=default_size.width, image_height=default_size.height, ), + hf_config.vision_feature_select_strategy, ) if num_projected_patches_default <= 0: - raise ValueError( - "Could not determine a valid number of image patches.") + raise ValueError("Could not determine a valid number of image patches.") num_projected_patches = num_projected_patches_default num_height_patches = int(math.sqrt(num_projected_patches)) - total_image_tokens_for_llm = num_projected_patches \ - + num_height_patches + 1 + total_image_tokens_for_llm = num_projected_patches + num_height_patches + 1 return total_image_tokens_for_llm def get_image_size_with_most_features(self) -> ImageSize: @@ -255,12 +272,10 @@ _I_Tarsier = TypeVar("_I_Tarsier", bound=TarsierProcessingInfo) class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]): - pass class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -282,14 +297,14 @@ class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_projected_patches = images.get_feature_size(item_idx) # This assumes num_projected_patches is a perfect square num_height_patches = int(math.sqrt(num_projected_patches)) - num_final_image_tokens = num_projected_patches \ - + num_height_patches + 1 + num_final_image_tokens = num_projected_patches + num_height_patches + 1 else: image_size = images.get_image_size(item_idx) num_final_image_tokens = self.info.get_num_image_tokens( @@ -308,8 +323,7 @@ class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): ] -def _build_tarsier_hf_info( - ctx: InputProcessingContext) -> TarsierProcessingInfo: +def _build_tarsier_hf_info(ctx: InputProcessingContext) -> TarsierProcessingInfo: return TarsierProcessingInfo(ctx) @@ -317,7 +331,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( @@ -340,22 +354,23 @@ def init_vision_tower_for_tarsier( feature_layers = hf_config.vision_feature_layer base_num_hidden_layers = vision_config.num_hidden_layers - def _get_layer_index(feature_layer_index: int, - num_hidden_layers_total: int) -> int: + def _get_layer_index(feature_layer_index: int, num_hidden_layers_total: int) -> int: if feature_layer_index < 0: return num_hidden_layers_total + feature_layer_index + 1 return 
feature_layer_index if isinstance(feature_layers, int): - num_hidden_layers_to_init = _get_layer_index(feature_layers, - base_num_hidden_layers) + num_hidden_layers_to_init = _get_layer_index( + feature_layers, base_num_hidden_layers + ) elif isinstance(feature_layers, (list, tuple)): num_hidden_layers_to_init = max( - _get_layer_index(idx, base_num_hidden_layers) - for idx in feature_layers) + _get_layer_index(idx, base_num_hidden_layers) for idx in feature_layers + ) else: - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) if isinstance(vision_config, CLIPVisionConfig): return CLIPVisionModel( @@ -378,14 +393,17 @@ def init_vision_tower_for_tarsier( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_tarsier_hf_processor, - info=_build_tarsier_hf_info, - dummy_inputs=TarsierDummyInputsBuilder) -class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + _build_tarsier_hf_processor, + info=_build_tarsier_hf_info, + dummy_inputs=TarsierDummyInputsBuilder, +) +class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -404,7 +422,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) projector_bias = getattr(config, "multimodal_projector_bias", True) self.multi_modal_projector = TarsierMultiModalProjector( @@ -413,39 +432,31 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, - hf_config=config. - text_config, # Use text_config from Tarsier's main config + hf_config=config.text_config, # Use text_config from Tarsier's main config prefix=maybe_prefix(prefix, "language_model"), ) - self.register_buffer('image_newline_idx_tensor', - torch.tensor([config.image_newline_idx], - dtype=torch.long), - persistent=False) - self.register_buffer('image_new_idx_tensor', - torch.tensor([config.image_new_idx], - dtype=torch.long), - persistent=False) + self.register_buffer( + "image_newline_idx_tensor", + torch.tensor([config.image_newline_idx], dtype=torch.long), + persistent=False, + ) + self.register_buffer( + "image_new_idx_tensor", + torch.tensor([config.image_new_idx], dtype=torch.long), + persistent=False, + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) # Assuming 3 channels - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - return data + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[TarsierImageInputs]: + self, **kwargs: object + ) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -453,77 +464,49 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - return TarsierImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=pixel_values, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return TarsierImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def _select_image_features(self, image_features: torch.Tensor, *, - strategy: str) -> torch.Tensor: - if strategy == "default": - return image_features[:, 1:] - elif strategy == "full": - return image_features - raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features( self, vision_tower: Union[CLIPVisionModel, SiglipVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # From vLLM LLaVA, vision tower output handling - image_hidden_states = vision_tower(pixel_values) - if not isinstance(image_hidden_states, torch.Tensor): - raise TypeError( - f"image_hidden_states type: {type(image_hidden_states)}" - " is not supported") - - def select_features_fn(leaf: torch.Tensor): - return self._select_image_features( - leaf, - strategy=self.config.vision_feature_select_strategy, - ) - - selected_features = cast( - Union[torch.Tensor, tuple[torch.Tensor, ...]], - json_map_leaves(select_features_fn, image_hidden_states), + return vision_tower( + pixel_values, + feature_select_strategy=self.config.vision_feature_select_strategy, ) - return selected_features def _add_tarsier_split_tokens( - self, projected_image_features: torch.Tensor) -> torch.Tensor: + self, projected_image_features: torch.Tensor + ) -> torch.Tensor: """ Implements Tarsier's `add_split_tokens` logic. 
""" - num_images, num_projected_patches, embed_dim = \ - projected_image_features.shape + num_images, num_projected_patches, embed_dim = projected_image_features.shape num_height_patches = int(math.sqrt(num_projected_patches)) num_width_patches = num_projected_patches // num_height_patches device = projected_image_features.device embedding_layer = self.language_model.model.embed_tokens image_newline_emb = embedding_layer( - self.image_newline_idx_tensor.to(device)).squeeze(0) - image_new_emb = embedding_layer( - self.image_new_idx_tensor.to(device)).squeeze(0) + self.image_newline_idx_tensor.to(device) + ).squeeze(0) + image_new_emb = embedding_layer(self.image_new_idx_tensor.to(device)).squeeze(0) try: current_image_features_grid = projected_image_features.view( - num_images, num_height_patches, num_width_patches, embed_dim) + num_images, num_height_patches, num_width_patches, embed_dim + ) except RuntimeError as e: raise RuntimeError( "Cannot reshape projected_image_features" @@ -533,22 +516,24 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, "Ensure num_projected_patches is compatible" " with a grid structure. " f"num_projected_patches={num_projected_patches}, " - f"derived num_height_patches={num_height_patches}. ") from e + f"derived num_height_patches={num_height_patches}. " + ) from e image_newline_expanded = image_newline_emb.expand( - (num_images, num_height_patches, 1, embed_dim)) + (num_images, num_height_patches, 1, embed_dim) + ) features_with_newlines = torch.cat( [current_image_features_grid, image_newline_expanded], - dim=2 # Concatenate along width dim + dim=2, # Concatenate along width dim ) - new_num_patches_after_newline = num_projected_patches \ - + num_height_patches + new_num_patches_after_newline = num_projected_patches + num_height_patches features_with_newlines_flat = features_with_newlines.view( - num_images, new_num_patches_after_newline, embed_dim) + num_images, new_num_patches_after_newline, embed_dim + ) image_new_expanded = image_new_emb.expand((num_images, 1, embed_dim)) final_image_features = torch.cat( [features_with_newlines_flat, image_new_expanded], - dim=1 # Concatenate along patch sequence dim + dim=1, # Concatenate along patch sequence dim ) return final_image_features @@ -559,16 +544,17 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, assert self.vision_tower is not None pixel_values = inputs["pixel_values"] image_features_selected = self._image_pixels_to_features( - self.vision_tower, pixel_values) # type: ignore + self.vision_tower, pixel_values + ) # type: ignore if isinstance(image_features_selected, torch.Tensor): - projected_features = self.multi_modal_projector( - image_features_selected) + projected_features = self.multi_modal_projector(image_features_selected) final_features = self._add_tarsier_split_tokens(projected_features) return final_features else: raise TypeError( f"_image_pixels_to_features type:" - f" {type(image_features_selected)} is not supported") + f" {type(image_features_selected)} is not supported" + ) def _process_image_input( self, @@ -579,37 +565,22 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, if isinstance(projected_features, torch.Tensor): return self._add_tarsier_split_tokens(projected_features) else: - raise ValueError("Incorrect type of image_embeds. " - f"Got type: {type(projected_features)}. ") + raise ValueError( + "Incorrect type of image_embeds. " + f"Got type: {type(projected_features)}. 
" + ) assert self.vision_tower is not None return self._process_image_pixels(image_input) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - self.config.image_token_index, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, @@ -622,25 +593,26 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, inputs_embeds = None elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + inputs_embeds = self.get_input_embeddings( + input_ids, + vision_embeddings, + is_multimodal=input_ids == self.config.image_token_index, + ) input_ids = None hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 49a7677151a94..113581d55ff56 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py @@ -30,12 +30,15 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel from .llama import LlamaDecoderLayer -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + is_pp_missing_parameter, +) class TeleChat2Model(LlamaModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): hf_config = vllm_config.model_config.hf_config @@ -43,7 +46,7 @@ class TeleChat2Model(LlamaModel): "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", "intermediate_size": "ffn_hidden_size", - "rms_norm_eps": "layer_norm_epsilon" + "rms_norm_eps": "layer_norm_epsilon", } vllm_config.model_config.hf_config.hidden_act = "silu" @@ -62,11 +65,10 @@ class TeleChat2Model(LlamaModel): layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.skip_bias_add = True - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ - 
('gate_up_proj', 'gate_proj', 0), - ('gate_up_proj', 'up_proj', 1), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -78,9 +80,10 @@ class TeleChat2Model(LlamaModel): v_weight = [] for i in range(total_num_heads): start = i * head_dim * 2 - k_weight.append(loaded_weight[start:start + head_dim, :]) - v_weight.append(loaded_weight[start + head_dim:start + - 2 * head_dim:]) + k_weight.append(loaded_weight[start : start + head_dim, :]) + v_weight.append( + loaded_weight[start + head_dim : start + 2 * head_dim :] + ) k_weight = torch.cat(k_weight, dim=0) v_weight = torch.cat(v_weight, dim=0) name = name.replace("key_value", "qkv_proj") @@ -112,15 +115,15 @@ class TeleChat2Model(LlamaModel): if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class TeleChat2ForCausalLM(LlamaForCausalLM): - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "transformer.": "model.", @@ -134,18 +137,17 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): }, ) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py index 3666f7011a997..4dfeddb0b28e4 100644 --- a/vllm/model_executor/models/teleflm.py +++ b/vllm/model_executor/models/teleflm.py @@ -28,12 +28,14 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM, LlamaModel) +from vllm.model_executor.models.llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) class TeleFLMModel(LlamaModel): - def __init__( self, *, @@ -41,9 +43,7 @@ class TeleFLMModel(LlamaModel): prefix: str = "", layer_type: type[nn.Module] = LlamaDecoderLayer, ): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) """ This implementation is based on the µScaling paper presented at the ICLR 2025 Workshop: @@ -65,7 +65,6 @@ class TeleFLMModel(LlamaModel): class TeleFLMForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # mup @@ -74,6 +73,6 @@ class TeleFLMForCausalLM(LlamaForCausalLM): self.mup_scale_factor = self.config.mup_scale_factor self.output_mult = self.config.output_mult / self.mup_scale_factor logit_scale = self.output_mult - self.logits_processor = 
LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size, logit_scale + ) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py new file mode 100644 index 0000000000000..13d2e8eacc013 --- /dev/null +++ b/vllm/model_executor/models/terratorch.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrapper around `Terratorch` models""" + +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +from terratorch.vllm import ( + DummyDataGenerator, + InferenceRunner, + InputDefinition, + InputTypeEnum, +) +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import MultiModalProcessorOnlyCache +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, + PlaceholderRange, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal +from .interfaces_base import default_pooling_type + +logger = init_logger(__name__) + + +def _terratorch_field_names(pretrained_cfg: dict): + input_definition = InputDefinition(**pretrained_cfg["input"]) + return set(input_definition.data.keys()) + + +def _terratorch_field_factory( + pretrained_cfg: dict, +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: + def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): + input_definition = InputDefinition(**pretrained_cfg["input"]) + fields = {} + for input_name, input in input_definition.data.items(): + if input.type == InputTypeEnum.tensor: + fields[input_name] = "image" + + return { + field_name: MultiModalFieldConfig.batched(modality=field_modality) + for field_name, field_modality in fields.items() + } + + return _terratorch_field_config + + +class TerratorchProcessingInfo(BaseProcessingInfo): + def 
get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + +class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): + def __init__(self, info: TerratorchProcessingInfo): + super().__init__(info) + self.dummy_data_generator = DummyDataGenerator( + self.info.get_hf_config().to_dict()["pretrained_cfg"] + ) + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + # Dummy data is generated based on the 'input' section + # defined in the HF configuration file + + if mm_options: + logger.warning( + "Configurable multimodal profiling " + "options are not supported for Terratorch. " + "They are ignored for now." + ) + + return self.dummy_data_generator.get_dummy_mm_data() + + +class TerratorchMultiModalDataParser(MultiModalDataParser): + def __init__(self, pretrained_cfg: dict, *args, **kwargs): + self._pretrained_cfg = pretrained_cfg + super().__init__(*args, **kwargs) + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + terratorch_fields = _terratorch_field_names(self._pretrained_cfg) + + return DictEmbeddingItems( + data, + modality="image", + required_fields=terratorch_fields, + fields_factory=_terratorch_field_factory(self._pretrained_cfg), + ) + + return super()._parse_image_data(data) + + +class TerratorchMultiModalProcessor(BaseMultiModalProcessor): + def __init__( + self, + info: TerratorchProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", + *, + cache: Optional[MultiModalProcessorOnlyCache] = None, + ) -> None: + self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] + super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) + + def _get_data_parser(self) -> MultiModalDataParser: + return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + return [] + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> MultiModalInputs: + if "image" in mm_data: + image_data = mm_data["image"] + image_data = {k: v.unsqueeze(0) for k, v in image_data.items()} + else: + image_data = mm_data + image_data = {k: v.unsqueeze(0) for k, v in image_data.items()} + + mm_data = {"image": image_data} + + mm_items = self._to_mm_items(mm_data) + tokenization_kwargs = tokenization_kwargs or {} + mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids + ) + mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} + + mm_processed_data = BatchFeature(image_data) + + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), + ) + + 
return MultiModalInputs( + type="multimodal", + prompt_token_ids=[1], + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + +@default_pooling_type("All") +@MULTIMODAL_REGISTRY.register_processor( + TerratorchMultiModalProcessor, + info=TerratorchProcessingInfo, + dummy_inputs=TerratorchInputBuilder, +) +class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): + merge_by_field_config = True + supports_multimodal_raw_input_only = True + is_pooling_model = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"] + + self.inference_runner = InferenceRunner(config) + self.model = self.inference_runner.model + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, + ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. + return torch.empty((input_ids.shape[0], 0)) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + model_output = self.inference_runner.forward(**kwargs) + + return model_output.output + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if isinstance(value, (dict, OrderedDict)): + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + name = f"inference_runner.{name}" + + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." 
in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr( + buffer, "weight_loader", default_weight_loader + ) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + elif isinstance(value, torch.Tensor): + params_list.append((f"inference_runner.model.{key}", value)) + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index fc242d1adafd0..1cfe401b243c7 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,63 +15,99 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models""" + from collections.abc import Iterable, Mapping from contextlib import contextmanager +from pathlib import Path from typing import Literal, Optional, Union import regex as re import torch +import transformers +from packaging.version import Version from torch import nn -from transformers import (AutoModel, BatchFeature, PretrainedConfig, - PreTrainedModel) +from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, VllmConfig) -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + VllmConfig, +) +from vllm.config.multimodal import BaseDummyOptions +from vllm.config.utils import getattr_iter +from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalUUIDDict, + PlaceholderRange, +) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo) +from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import 
is_list_of -from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - flatten_bn, make_empty_intermediate_tensors_factory, - maybe_prefix) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) logger = init_logger(__name__) +def get_feature_request_tip( + model: str, + trust_remote_code: bool, +) -> str: + hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" + gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" + url = hf_url if trust_remote_code else gh_url + prefix = f"Please open {url} to request support for this feature. " + if Path(model).exists(): + prefix = "" + doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" + tip = f"See {doc_url} for instructions on how to add support yourself." + return f"{prefix}{tip}" + + def vllm_flash_attention_forward( - # Transformers args - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - # Transformers kwargs - scaling: Optional[float] = None, - # vLLM kwargs - attention_instances: Optional[dict[Attention]] = None, - **kwargs): + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: Optional[float] = None, + # vLLM kwargs + attention_instances: Optional[dict[Attention]] = None, + **kwargs, +): self_attn = attention_instances[module.layer_idx] if scaling is not None: self_attn.impl.scale = float(scaling) @@ -88,34 +124,52 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): logger.debug("%s: %s -> %s", name, old_module, new_module) +def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: + """ + Callable to be passed to `@support_torch_compile`'s `enable_if` argument. + + Defaults to `True` but is disabled in the following situations: + + - The model uses dynamic rope scaling. + """ + enable = True + text_config = vllm_config.model_config.hf_config.get_text_config() + # Dynamic rope scaling is not compatible with torch.compile + rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} + if rope_scaling.get("rope_type") == "dynamic": + enable = False + return enable + + +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + def replace_linear_class( - linear: nn.Linear, style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig + linear: nn.Linear, + style: Style = "replicate", + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. Args: - linear (nn.Linear): `nn.Linear` to be replaced. - style (str): Tensor parallel style of the new linear, e.g. "colwise". - quant_config (QuantConfig): Quantization config for the new linear. + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. Returns: - Union[ColumnParallelLinear, RowParallelLinear]: The new linear. + The new linear. 
""" if not isinstance(style, str): - raise ValueError( - f"Unsupported parallel style type {type(style)}, expected str") + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") vllm_linear_cls, vllm_linear_kwargs = { "colwise": (ColumnParallelLinear, {}), - "colwise_rep": (ColumnParallelLinear, { - "gather_output": True - }), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), "rowwise": (RowParallelLinear, {}), - "rowwise_rep": (RowParallelLinear, { - "input_is_parallel": False - }), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), "replicate": (ReplicatedLinear, {}), }.get(style, (ReplicatedLinear, {})) @@ -124,11 +178,51 @@ def replace_linear_class( output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, return_bias=False, **vllm_linear_kwargs, ) +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: + """Replace a Transformers RMSNorm with vLLM's RMSNorm. + + This method assumes: + - Weight is stored as `weight`. + - Epsilon is stored as `eps` or `variance_epsilon`. + - `with_scale` indicates whether the layer has a weight (Gemma3n only). + - `var_hidden_size` is only ever used for Intern vision encoder in vLLM + and Transformers doesn't appear to have the same concept. + """ + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + # Update hidden size if weight is available + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + # Check if weight is all zeros, which indicates GemmaRMSNorm + # We must create a new instance because rms_norm is on meta + try: + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." 
+ ) + weight_test = None + if weight_test is not None and torch.all(weight_test == 0): + return GemmaRMSNorm(**kwargs) + # Otherwise assume it's a regular RMSNorm + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["dtype"] = weight_meta.dtype + else: + # No weight, fall back to weightless RMSNorm + kwargs["has_weight"] = False + return RMSNorm(**kwargs) + + # Copied from `accelerate` @contextmanager def init_on_device_without_buffers(device: torch.device): @@ -151,12 +245,12 @@ def init_on_device_without_buffers(device: torch.device): kwargs = module._parameters[name].__dict__ kwargs["requires_grad"] = param.requires_grad module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs) + module._parameters[name].to(device), **kwargs + ) tensor_constructors_to_patch = {} def patch_tensor_constructor(fn): - def wrapper(*args, **kwargs): kwargs["device"] = device return fn(*args, **kwargs) @@ -167,21 +261,21 @@ def init_on_device_without_buffers(device: torch.device): nn.Module.register_parameter = register_empty_parameter for torch_function_name in tensor_constructors_to_patch: setattr( - torch, torch_function_name, - patch_tensor_constructor(getattr(torch, torch_function_name))) + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) yield finally: nn.Module.register_parameter = old_register_parameter - for torch_function_name, old_torch_function in ( - tensor_constructors_to_patch.items()): + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): setattr(torch, torch_function_name, old_torch_function) class MultiModalProcessingInfo(BaseProcessingInfo): - - def get_hf_config(self): - return self.ctx.model_config.hf_config - def get_supported_mm_limits(self): return {"image": None} @@ -191,9 +285,11 @@ class MultiModalProcessingInfo(BaseProcessingInfo): def get_max_image_tokens(self) -> int: width, height = self.get_max_image_size() processor = self.get_hf_processor() - mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {} + multimodal_config = self.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} mm_tokens = processor._get_num_multimodal_tokens( - image_sizes=([height, width], ), **mm_processor_kwargs) + image_sizes=([height, width],), **mm_processor_kwargs + ) image_tokens = mm_tokens["num_image_tokens"][0] return image_tokens @@ -201,9 +297,7 @@ class MultiModalProcessingInfo(BaseProcessingInfo): return 10_000, 10_000 # hardcode for arbitrary very large size -class MultiModalDummyInputsBuilder( - BaseDummyInputsBuilder[MultiModalProcessingInfo]): - +class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -218,21 +312,25 @@ class MultiModalDummyInputsBuilder( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) target_width, target_height = self.info.get_max_image_size() + image_overrides = mm_options.get("image") if mm_options else None + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } class 
MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -256,53 +354,37 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): def _get_mm_fields_config( self, - hf_inputs, - hf_processor_mm_kwargs, - num_image_patches: torch.Tensor = None, - ): + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: # HF Processors always return a mask but vLLM doesn't need it hf_inputs.pop("attention_mask", None) + num_image_patches = hf_inputs.get("num_image_patches") mm_fields = { - key: MultiModalFieldConfig.flat_from_sizes("image", - num_image_patches) + key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) for key in hf_inputs } mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( - "image", num_image_patches) + "image", num_image_patches + ) + + # Keep these as batched, as they always have batch size as first dim + mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") return mm_fields - def _apply_hf_processor_text_mm( + def _get_hf_mm_data( self, - prompt_text: str, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> tuple[list[int], BatchFeature, bool]: + ) -> tuple[Mapping[str, object], Mapping[str, object]]: """ - Apply the HF processor on the prompt text and multi-modal data - together. - - In addition, return whether prompt replacements have been applied. + In contrast to the base class, this method always adds + `return_mm_token_type_ids` to the processor data """ - processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) processor_data["return_mm_token_type_ids"] = True - - processed_data = self._call_hf_processor( - prompt=prompt_text, - mm_data=processor_data, - mm_kwargs=hf_processor_mm_kwargs, - tok_kwargs=tokenization_kwargs, - ) - processed_data.update(passthrough_data) - - prompt_ids, = processed_data.pop("input_ids").tolist() - mm_token_type_ids = processed_data.pop( - "mm_token_type_ids" - ) if "mm_token_type_ids" in processed_data else processed_data.pop( - "token_type_ids") # for gemma3 only - - return prompt_ids, processed_data, mm_token_type_ids + return processor_data, passthrough_data def apply( self, @@ -310,6 +392,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -328,28 +411,40 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): # into string prompt = hf_processor.decode(prompt) - (prompt_ids, processed_data, - mm_token_type_ids) = self._apply_hf_processor_text_mm( - prompt_text=prompt, - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) + # Bypass cached processor and always apply to the full set of mm inputs + # NOTE: we can't just set caching=False because base class method + # transforms outputs to `MultiModalKwargs` which is not going to + # work for Transformers. 
We have a lot of logic tied to + # `mm_tokens_per_modality` below + prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) - # HF processor will return `mm_token_type_ids` from which - # we can infer mm_placeholders. Until then hardcode to make code run - # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1 + # For gemma3 we check `token_type_ids` as the key + token_type_key = ( + "mm_token_type_ids" + if "mm_token_type_ids" in processed_data + else "token_type_ids" + ) + mm_token_type_ids = processed_data.pop(token_type_key) + + # We can infer vLLM style placeholder from token type ids, if we split + # it for each input `mm_data`. mm_positions = torch.where(mm_token_type_ids == 1)[1] images = mm_items.get_items("image", ImageProcessorItems) - mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs - or {}) + multimodal_config = self.info.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} image_sizes = [] for item_idx in range(len(images)): image_size = images.get_image_size(item_idx) image_sizes.append((image_size.height, image_size.width)) mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( - image_sizes=image_sizes, **mm_processor_kwargs) + image_sizes=image_sizes, **mm_processor_kwargs + ) mm_placeholders = {} split_sizes = mm_tokens_per_modality["num_image_tokens"] @@ -361,27 +456,27 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): PlaceholderRange( offset=positions[0].item(), length=positions.shape[0], - is_embed=(mm_tokens == hf_processor.image_token_id).bool()) - for positions, mm_tokens in zip(chunked_mm_positions, - chunked_mm_tokens) + is_embed=(mm_tokens == hf_processor.image_token_id).bool(), + ) + for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) ] mm_placeholders = {"image": ranges} - num_image_patches = torch.tensor( + processed_data["num_image_patches"] = torch.tensor( mm_tokens_per_modality["num_image_patches"] - ) if "num_image_patches" in mm_tokens_per_modality else None - processed_data['num_image_patches'] = num_image_patches + ) mm_kwargs = MultiModalKwargsItems.from_hf_inputs( processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, - num_image_patches), + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + # Use overrides if provided; fallback to data-dependent hashing. 
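The hunk above recovers vLLM-style multimodal placeholders from the HF processor's `mm_token_type_ids` by taking the positions of multimodal tokens and splitting them per image. A minimal standalone sketch of that idea, with made-up token ids and per-image counts (not the vLLM code itself):

import torch

# 1 marks an image token; batch size is 1, as in the processor above.
mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])
num_image_tokens = [3, 2]  # hypothetical per-image token counts from the HF processor

mm_positions = torch.where(mm_token_type_ids == 1)[1]           # tensor([2, 3, 4, 6, 7])
chunked_positions = torch.split(mm_positions, num_image_tokens)
ranges = [(int(pos[0]), pos.shape[0]) for pos in chunked_positions]
print(ranges)  # [(2, 3), (6, 2)] -> (offset, length) per image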
+ mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids ) - mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs) return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, @@ -391,8 +486,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens" - ] # TODO transformers will have a util to get it + embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -404,19 +498,34 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): self.device_config: DeviceConfig = vllm_config.device_config self.model_config: ModelConfig = vllm_config.model_config self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: QuantizationConfig = vllm_config.quant_config + self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config self.pp_group = get_pp_group() - self.pp_size = self.pp_group.world_size - self.pp_rank = self.pp_group.rank_in_group - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() - # To be updated in child classes for use in `load_weights` - self.skip_prefixes: Optional[list[str]] = None + # Weights to skip in `self.load_weights` + self.skip_prefixes: list[str] = [] + """Skip loading weights whose qualname starts with these prefixes.""" + self.skip_substrs: list[str] = [] + """Skip loading weights whose qualname contains these substrings.""" + self.ignore_unexpected_prefixes: list[str] = [] + """Ignore unexpected weights whose qualname starts with these prefixes. + """ + self.ignore_unexpected_suffixes: list[str] = [] + """Ignore unexpected weights whose qualname ends with these suffixes.""" + + if self.quant_config: + quant_method_name = self.quant_config.get_name() + # Check for unsupported quantization methods. + if quant_method_name == "mxfp4": + raise NotImplementedError( + "Transformers backend does not support MXFP4 quantization yet." + ) + # Skip loading extra bias for GPTQ models. + if "gptq" in quant_method_name: + self.ignore_unexpected_suffixes.append(".bias") # Set correct attn and init on "meta" to delay allocating GPU tensors - # TODO: @raushan, use the public `model.set_attn_implementation()` - # method once its checks are fixed in Transformers. 
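The four skip/ignore lists introduced above are later consumed during `load_weights`. A rough standalone sketch of the kind of name filtering they imply (helper names and example weight names here are illustrative only, not vLLM's loader):

def should_skip(name: str, skip_prefixes=(), skip_substrs=()) -> bool:
    # Checkpoint weights matching either rule are intentionally not loaded.
    return any(name.startswith(p) for p in skip_prefixes) or any(
        s in name for s in skip_substrs
    )

def is_ignorable(name: str, ignore_prefixes=(), ignore_suffixes=()) -> bool:
    # Unexpected checkpoint weights matching these rules do not raise an error.
    return any(name.startswith(p) for p in ignore_prefixes) or any(
        name.endswith(s) for s in ignore_suffixes
    )

# e.g. tied embeddings skip "lm_head.", and GPTQ checkpoints may carry extra ".bias"
print(should_skip("lm_head.weight", skip_prefixes=["lm_head."]))           # True
print(is_ignorable("model.layers.0.mlp.bias", ignore_suffixes=[".bias"]))  # True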
self.text_config._attn_implementation = "vllm" with init_on_device_without_buffers("meta"): self.model: PreTrainedModel = AutoModel.from_config( @@ -425,39 +534,52 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): trust_remote_code=self.model_config.trust_remote_code, ) + # Remove layers not on this pipeline parallel rank self.pipeline_parallel() - self.tensor_parallel() + # Substitute remaining layers with vLLM's layers as needed + self.recursive_replace() + # Create attention instances for KV cache allocation + self.attention_instances = self.create_attention_instances() # Input embeddings - if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): + input_embeddings = self.model.get_input_embeddings() + if not isinstance(input_embeddings, PPMissingLayer): + # Some models use embedding scales + self.embed_scale = getattr(input_embeddings, "embed_scale", None) + names = ("embedding_size", "hidden_size") + embedding_dim = getattr_iter(self.text_config, names, None) + assert embedding_dim is not None self.model.set_input_embeddings( VocabParallelEmbedding( self.text_config.vocab_size, - self.text_config.hidden_size, + embedding_dim=embedding_dim, org_num_embeddings=self.text_config.vocab_size, quant_config=self.quant_config, - )) - - # Attention layers - self.attention_instances = self.create_attention_instances() + ) + ) # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states"], self.text_config.hidden_size)) + # Pipeline parallel intermediate tensors + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.text_config.hidden_size + ) def pipeline_parallel(self): """ Apply the model's pipeline parallelization plan. """ - if self.pp_size <= 1: + if self.pp_group.world_size <= 1: return if not self.model.supports_pp_plan: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) raise ValueError( - f"{type(self.model)} does not support pipeline parallel yet!") + f"{type(self.model)} does not support pipeline parallel. {tip}" + ) module_lists = [] module_list_idx = None @@ -470,22 +592,25 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): if len(module_lists) > 1: raise ValueError( "Pipeline parallel of models with multiple `ModuleList`s " - "in the base model are not supported yet!") + "in the base model are not supported yet!" 
+ ) if module_list_idx is None: - raise ValueError( - f"Could not find `ModuleList` in {type(self.model)}") + raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") # Layers before module list for name in pp_plan[:module_list_idx]: if self.pp_group.is_first_rank or ( - self.text_config.tie_word_embeddings - and self.pp_group.is_last_rank): + self.text_config.tie_word_embeddings and self.pp_group.is_last_rank + ): continue setattr(self.model, name, PPMissingLayer()) # Module list start_layer, end_layer = get_pp_indices( - self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) layers_name = pp_plan[module_list_idx] layers = getattr(self.model, layers_name) for i in range(len(layers)): @@ -494,77 +619,83 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): layers[i] = PPMissingLayer() # Layers after module list - for name in pp_plan[module_list_idx + 1:]: + for name in pp_plan[module_list_idx + 1 :]: # Modules that should be on last rank if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) - def tensor_parallel(self): - """ - Apply the model's tensor parallelization plan. - Currently only supports linear layers. - """ - # Look for tp plans in all of the PreTrainedModels found in self.model - is_pretrained_model = lambda m: isinstance(m, PreTrainedModel) - supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None - pretrained_models = filter(is_pretrained_model, self.model.modules()) - models_with_tp_plan = filter(supports_tp_plan, pretrained_models) + def recursive_replace(self): + """Recursively replace modules in the model as needed. - if not any(models_with_tp_plan) and self.tp_size > 1: + Currently, this replaces: + + - `nn.Linear` with vLLM's tensor parallel linear classes + - `*RMSNorm` with vLLM's `RMSNorm` + """ + tp_plan = self.model.tp_plan + + if not tp_plan and self.tp_group.world_size > 1: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) raise ValueError( - f"{type(self.model)} does not support tensor parallel yet!") + f"{type(self.model)} does not support tensor parallel. 
{tip}" + ) - def _tensor_parallel(module: nn.Module, - prefix: str = "", - tp_plan=None): - tp_plan = tp_plan or {} + # Prefix the patterns because we always start from `self.model` + tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} - # If the current module is a PreTrainedModel, set the tp_plan for - # all of its children - if isinstance(module, PreTrainedModel): - tp_plan = module.config.base_model_tp_plan or {} - tp_plan = { - maybe_prefix(prefix, k): v - for k, v in tp_plan.items() - } - - # Some weight loaders expect linear layers to inherit from vLLM's - # LinearBase class, so we set a default style which causes any - # unspecified linear layers to be replaced with ReplicatedLinear + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): + new_module = child_module qual_name = maybe_prefix(prefix, child_name) if isinstance(child_module, nn.Linear): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) + # Some weight loaders expect all linear layers to inherit + # LinearBase, so we set a default style which causes any + # unspecified layers to be replaced with ReplicatedLinear style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class(child_module, style, - self.quant_config) + new_module = replace_linear_class( + child_module, style, self.quant_config, prefix=qual_name + ) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, self.text_config.hidden_size + ) + else: + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not child_module: setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) - else: - _tensor_parallel(child_module, - prefix=qual_name, - tp_plan=tp_plan) - _tensor_parallel(self.model) + _recursive_replace(self.model, prefix="model") - def create_attention_instances(self) -> dict[int, Attention]: + def create_attention_instances( + self, attn_type: AttentionType = AttentionType.DECODER + ) -> dict[int, Attention]: """ Create `Attention` instances to inform KV cache allocation. 
""" - num_heads = self.model_config.get_num_attention_heads( - self.parallel_config) + num_heads = self.model_config.get_num_attention_heads(self.parallel_config) head_size = self.model_config.get_head_size() num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - start, end = get_pp_indices(self.text_config.num_hidden_layers, - self.pp_rank, self.pp_size) + logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None) + start, end = get_pp_indices( + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) attention_instances = {} for i in range(start, end): # Handle interleaved sliding window attention per_layer_sliding_window = None - if (hasattr(self.config, "layer_types") - and self.config.layer_types[i] == "sliding_attention"): + if ( + hasattr(self.config, "layer_types") + and self.config.layer_types[i] == "sliding_attention" + ): per_layer_sliding_window = self.config.sliding_window attention_instances[i] = Attention( @@ -576,11 +707,14 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): num_kv_heads=num_kv_heads, cache_config=self.cache_config, quant_config=self.quant_config, + logits_soft_cap=logits_soft_cap, per_layer_sliding_window=per_layer_sliding_window, - prefix=f"{i}.attn") + prefix=f"{i}.attn", + attn_type=attn_type, + ) return attention_instances - def init_parameters(self, module: nn.Module): + def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None): """ If a `parameter` is on the `meta` device, then its parent `module` is the original module created by: @@ -590,15 +724,22 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): self.model: PreTrainedModel = AutoModel.from_config(...) ``` """ - for name, param in module.named_parameters(recurse=False): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like(param.data, - dtype=self.model_config.dtype, - device=self.device_config.device)) - setattr(module, name, new_param) - for child in module.children(): - self.init_parameters(child) + + def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype or self.model_config.dtype, + device=self.device_config.device, + ) + ) + setattr(module, name, new_param) + for child in module.children(): + _init_parameters(child, dtype) + + _init_parameters(module, dtype) def forward( self, @@ -606,8 +747,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - if not get_pp_group().is_first_rank: + if not self.pp_group.is_first_rank: assert intermediate_tensors is not None input_ids = None inputs_embeds = intermediate_tensors["hidden_states"] @@ -628,43 +770,49 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): use_cache=False, position_ids=position_ids, attention_instances=self.attention_instances, - return_dict=False)[0][0, ...] # we remove batch dimension for now + return_dict=False, + **kwargs, + )[0][0, ...] 
# we remove batch dimension for now - if not get_pp_group().is_last_rank: + if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes) + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=self.skip_prefixes, + skip_substrs=self.skip_substrs, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, + ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - -@support_torch_compile -class TransformersModel(TransformersBase): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # Add `model.` prefix for base model checkpoints - "": "model.", - # Remove `model.` from places it should not be - "model.model.": "model.", - "model.score": "score", - }) + def check_version(self, min_version: str, feature: str): + installed = Version(transformers.__version__) + required = Version(min_version) + if installed < required: + raise ImportError( + f"Transformers backend requires transformers>={required} " + f"for {feature}, but got {installed}" + ) -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # Tell `TransformersBase.load_weights` to skip # `lm_head` if the model has tied word embeddings if self.text_config.tie_word_embeddings: - self.skip_prefixes = ["lm_head."] + self.skip_prefixes.append("lm_head.") - if get_pp_group().is_last_rank: + if self.pp_group.is_last_rank: self.unpadded_vocab_size = self.text_config.vocab_size self.lm_head = ParallelLMHead( self.text_config.vocab_size, @@ -674,48 +822,48 @@ class TransformersForCausalLM(TransformersBase): ) if self.text_config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights( - self.model.get_input_embeddings()) + self.model.get_input_embeddings() + ) logit_scale = getattr(self.text_config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, - logit_scale) + self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + if self.embed_scale is not None: + inputs_embeds *= self.embed_scale + return inputs_embeds + def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits -def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: - """Flatten until a list of tensors can be concatenated then do concat""" - - def _can_concat(x: list[torch.Tensor]): - return len(set(map(lambda _x: _x.shape[1:], x))) == 1 - - if _can_concat(x): - return torch.concat(x) - return flatten_and_concat(flatten_bn(x)) - - @MULTIMODAL_REGISTRY.register_processor( MultiModalProcessor, info=MultiModalProcessingInfo, - dummy_inputs=MultiModalDummyInputsBuilder) + dummy_inputs=MultiModalDummyInputsBuilder, +) 
@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope dynamic_arg_dims={ "input_ids": 0, "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) # set `positions` to last dim to support Qwen-mrope + }, + enable_if=can_enable_torch_compile, +) class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): + supports_multimodal_raw_input_only = True + merge_by_field_config = True # Backwards compatibility for prev released models. State dicts back then # had different formats and cannot be loaded with `AutoModel` mapping as is hf_to_vllm_mapper = WeightsMapper( @@ -737,7 +885,8 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): "model.embed_tokens": "model.language_model.embed_tokens", "model.layers": "model.language_model.layers", "model.norm": "model.language_model.norm", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -752,53 +901,45 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. - if inputs_embeds is None: - multimodal_embeds = self.get_multimodal_embeddings(**kwargs) - if multimodal_embeds is not None: - inputs_embeds = self.get_input_embeddings( - input_ids, multimodal_embeds) - input_ids = None - - model_output = super().forward(input_ids, positions, - intermediate_tensors, inputs_embeds) + # Gemma3 and PaliGemma needs `token_type_ids` to work correctly + # Other models will not have `token_type_ids` in kwargs + kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} + model_output = super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return model_output + def get_language_model(self) -> torch.nn.Module: + """`TransformersForMultimodalLM` does not contain a vLLM language model class. 
+ Therefore, in order to return a language model vLLM class, we use a wrapper to + give `self` the same interface as `TransformersForCausalLM`.""" + + class LanguageModelWrapper(TransformersForCausalLM): + def __init__(self, multimodal_model): + # Don't call super().__init__() to avoid re-initialization + self.__dict__.update(multimodal_model.__dict__) + + model = getattr_iter(self.model, ("language_model", "text_model"), None) + + return LanguageModelWrapper(self) + def get_multimodal_embeddings(self, **kwargs): - pixel_values = kwargs.pop("pixel_values", None) - pixel_values = pixel_values if pixel_values is not None else kwargs.pop( - "image_patches", None) - image_embeds = kwargs.pop("image_embeds", None) + pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None) + image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None) + # Model might use `image_patches` instead of `pixel_values` + if pixel_values is None: + pixel_values = kwargs.pop("image_patches", None) if image_embeds is not None: return image_embeds - if pixel_values is None and image_embeds is None: + if pixel_values is None: return None num_image_patches = kwargs.pop("num_image_patches") + kwargs.pop("token_type_ids", None) # used only in `forward` if pixel_values is not None: - if isinstance(pixel_values, torch.Tensor): - pixel_values = flatten_bn(pixel_values).to(self.dtype) - elif is_list_of(pixel_values, torch.Tensor): - pixel_values = flatten_and_concat(pixel_values).to(self.dtype) - else: - raise ValueError( - f"Unsupported pixel_values type {type(pixel_values)}. " - "Expected `torch.Tensor` or list of `torch.Tensor`.") - - if isinstance(num_image_patches, list): - num_image_patches = torch.cat(num_image_patches) - - vision_embeddings = self.model.get_image_features( - pixel_values, - **{ - k: v.flatten(0, 1) - for k, v in kwargs.items() - }, - ) + vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) if isinstance(vision_embeddings, torch.Tensor): if vision_embeddings.ndim == 2: @@ -808,8 +949,8 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): # but transformers returns concat tensors if each patch # is of different size. 
We split it back to make vLLM happy vision_embeddings = torch.split( - vision_embeddings, - num_image_patches.flatten().tolist()) + vision_embeddings, num_image_patches.flatten().tolist() + ) vision_embeddings = [ embed.flatten(start_dim=0, end_dim=-2) for embed in vision_embeddings @@ -817,18 +958,4 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): return vision_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings=None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if (multimodal_embeddings is not None - and len(multimodal_embeddings) != 0): - mask = (input_ids == self.config.image_token_id) - mask = mask.unsqueeze(-1).expand_as(inputs_embeds) - multimodal_embeddings = torch.cat(multimodal_embeddings) - - inputs_embeds = inputs_embeds.masked_scatter( - mask, multimodal_embeddings) - return inputs_embeds + get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py new file mode 100644 index 0000000000000..5267e447902f0 --- /dev/null +++ b/vllm/model_executor/models/transformers_moe.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
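The `get_multimodal_embeddings` hunk above splits a concatenated patch-embedding tensor back into one tensor per image using `num_image_patches`. A small standalone illustration with made-up shapes:

import torch

vision_embeddings = torch.randn(10, 64)   # 10 patches in total, hidden size 64
num_image_patches = torch.tensor([6, 4])  # patches contributed by each image

per_image = torch.split(vision_embeddings, num_image_patches.flatten().tolist())
assert [e.shape[0] for e in per_image] == [6, 4]
# each per-image chunk is then merged into the prompt at that image's placeholder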
+"""Wrapper around `transformers` MoE models.""" + +from typing import Any + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config.utils import getattr_iter +from vllm.distributed import get_dp_group, get_ep_group +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + +from .interfaces import MixtureOfExperts, SupportsMultiModal +from .transformers import ( + TransformersBase, + TransformersForCausalLM, + TransformersForMultimodalLM, + can_enable_torch_compile, + log_replacement, +) +from .utils import maybe_prefix + + +@CustomOp.register("transformers_fused_moe") +class TransformersFusedMoE(FusedMoE): + """Custom FusedMoE for the Transformers backend.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._topk_ids: torch.Tensor = None + + def custom_routing_function(hidden_states, gating_output, topk, renormalize): + """Return `topk_weights` from `gating_output` and the + `topk_ids` we stored in the layer earlier.""" + topk_weights = gating_output + topk_ids = self._topk_ids + # Handle all gather in expert parallel + if topk_ids.size(0) != hidden_states.size(0): + dp_metadata = get_forward_context().dp_metadata + sizes = dp_metadata.get_chunk_sizes_across_dp_rank() + is_sp = self.is_sequence_parallel + dist_group = get_ep_group() if is_sp else get_dp_group() + assert sizes[dist_group.rank_in_group] == topk_ids.shape[0] + (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes) + return topk_weights, topk_ids + + self.custom_routing_function = custom_routing_function + + def forward( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + """In Transformers `experts.forward` will have this signature. 
+ + We discard any extra kwargs because we cannot use them here.""" + return torch.ops.vllm.transformers_moe_forward( + hidden_states, + topk_ids.to(torch.int32), + topk_weights.to(torch.float32), + self.layer_name, + ) + + +def transformers_moe_forward( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + """Store the `topk_ids` in the layer and call the actual forward.""" + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._topk_ids = topk_ids + # Clone hidden_states because it will be mutated in-place in FusedMoE + return self.forward_impl(hidden_states.clone(), topk_weights) + + +def transformers_moe_forward_fake( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="transformers_moe_forward", + op_func=transformers_moe_forward, + mutates_args=["hidden_states"], + fake_impl=transformers_moe_forward_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order,), +) + + +class TransformersMoEBase(TransformersBase, MixtureOfExperts): + def __init__(self, *, vllm_config, prefix=""): + self.check_version("4.57.0.dev0", "MoE models support") + self.ep_group = get_ep_group() + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ): + for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers): + mlp_layer.experts.set_eplb_state( + moe_layer_idx=moe_layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ): + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for mlp in self.mlp_layers: + mlp.n_local_physical_experts = num_local_physical_experts + mlp.n_physical_experts = num_physical_experts + mlp.n_redundant_experts = self.num_redundant_experts + mlp.experts.update_expert_map() + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """ + Params for weights, fp8 weight scales, fp8 activation scales + (param_name, weight_name, expert_id, shard_id) + """ + ckpt_names = [ + # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) + ("gate_proj", "down_proj", "up_proj"), # Most common MoE style + ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style + ("linear", "linear_1", "linear_v"), # Grok1 style + ] + num_experts = self.model_config.get_num_experts() + num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts + expert_mapping = [] + for gate_proj, down_proj, up_proj in ckpt_names: + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name=gate_proj, + ckpt_down_proj_name=down_proj, + ckpt_up_proj_name=up_proj, + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, + ) + ) + return expert_mapping + + def recursive_replace(self): + """Initialize the MoE layers.""" + text_config = self.text_config + + # Positional 
arguments + num_experts = self.model_config.get_num_experts() + top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"], None) + assert top_k is not None + hidden_size = text_config.hidden_size + intermediate_size = getattr_iter( + text_config, ["moe_intermediate_size", "intermediate_size"], None + ) + assert intermediate_size is not None + + # If there are shared experts, the results are + # reduced after mlp.forward() not inside FusedMoE + num_shared_experts = getattr_iter( + text_config, + [ + "n_shared_experts", # DeepSeek, Docs, GLM + "moe_num_shared_experts", # Aria, Ernie + ], + 0, + ) + reduce_results = num_shared_experts == 0 + + def add_all_reduce(mlp: nn.Module): + """Adds an all-reduce to the output of `mlp.forward()`.""" + + class MLPWithAllReduce(mlp.__class__): + def forward(self, *args, **kwargs): + output = super().forward(*args, **kwargs) + return self.experts.maybe_all_reduce_tensor_model_parallel(output) + + mlp.__class__ = MLPWithAllReduce + + # Unused kwargs since we use custom_routing_function: + # - `scoring_func` and `e_score_correction_bias` only used for grouped + # topk routing inside vLLM and are non-trivial to infer + # and hard code `use_grouped_topk=False` + # - `renormalize` passed anyway because it's easy to infer + # - `num_expert_group` and `topk_group` used for inferring expert + # placement strategy in FusedMoE + # - `apply_router_weight_on_input` is already applied in Transformers + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) + num_expert_group = getattr(text_config, "n_group", None) + topk_group = getattr(text_config, "topk_group", None) + + # MoE activation function + activation = "silu" + wrapped_arch = self.config.architectures[0].lower() + if "gptoss" in wrapped_arch: + activation = "swigluoai" + elif "grok1" in wrapped_arch: + activation = "gelu" + + # Expert mapping for `AutoWeightsLoader` + expert_mapping = self.get_expert_mapping() + + # Expert parallel load balancing kwargs + enable_eplb = self.parallel_config.enable_eplb + num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts + + # MixtureOfExperts mixin settings + ep_size = self.ep_group.world_size + + self.mlp_layers = [] # Used for MixtureOfExperts methods + self.expert_weights = [] + self.num_moe_layers = 0 + self.num_expert_groups = 1 if num_expert_group is None else num_expert_group + self.num_logical_experts = num_experts + self.num_physical_experts = num_experts + num_redundant_experts + self.num_local_physical_experts = self.num_physical_experts // ep_size + self.num_routed_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_redundant_experts = num_redundant_experts + + # Recursively fuse MoE layers + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + if child_name == "experts" and isinstance(child_module, nn.ModuleList): + # Alias for readability + mlp = module + experts = child_module + # Do the experts have biases + has_bias = False + for experts_param_name, _ in experts.named_parameters(): + if "bias" in experts_param_name: + has_bias = True + break + # Double check there are no shared experts + nonlocal reduce_results + if reduce_results: + for mlp_param_name, _ in mlp.named_parameters(): + if "shared_expert" in mlp_param_name: + reduce_results = False + # If the config does not specify num_shared_experts, but + # the model has shared experts, we assume there is one. 
+ self.num_shared_experts = 1 + break + # Replace experts module with FusedMoE + fused_experts = TransformersFusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=reduce_results, + renormalize=renormalize, + # Hard coded because topk happens in Transformers + use_grouped_topk=False, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=self.quant_config, + prefix=qual_name, + activation=activation, + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, + has_bias=has_bias, + expert_mapping=expert_mapping, + ) + mlp.experts = fused_experts + log_replacement(qual_name, experts, fused_experts) + # Update MixtureOfExperts mixin state + self.mlp_layers.append(mlp) + self.expert_weights.append(fused_experts.get_expert_weights()) + self.num_moe_layers += 1 + # If results are not all-reduced in FusedMoE, ensure they + # are all-reduced at the end of mlp.forward() if tensor + # parallel or expert parallel is enabled + if not reduce_results and ( + fused_experts.tp_size > 1 or fused_experts.ep_size > 1 + ): + add_all_reduce(mlp) + else: + _recursive_replace(child_module, prefix=qual_name) + + _recursive_replace(self.model, prefix="model") + # Continue with the replacement of layers in TransformersBase + super().recursive_replace() + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): + pass + + +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }, + enable_if=can_enable_torch_compile, +) +class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM): + get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py new file mode 100644 index 0000000000000..98d2611351c03 --- /dev/null +++ b/vllm/model_executor/models/transformers_pooling.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
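# `transformers_moe_forward` above is exposed through vLLM's
# `direct_register_custom_op` helper together with a fake (meta) implementation
# so that torch.compile can trace the call without running the real MoE kernel.
# A minimal stand-alone sketch of the same register-plus-fake-impl pattern using
# plain `torch.library` (PyTorch 2.4+); the "demo" namespace, op name, and scale
# argument are illustrative only and not part of vLLM:
import torch

@torch.library.custom_op("demo::scaled_identity", mutates_args=())
def scaled_identity(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: executed eagerly on actual tensors.
    return x * scale

@scaled_identity.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake implementation: only describes the output shape/dtype for tracing.
    return torch.empty_like(x)

compiled = torch.compile(lambda t: torch.ops.demo.scaled_identity(t, 2.0))
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.])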
+"""Wrapper around `transformers` models for pooling tasks.""" + +from typing import Optional, Union + +import torch +from transformers import AutoModelForSequenceClassification + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) +from vllm.sequence import IntermediateTensors + +from .interfaces_base import VllmModelForPooling +from .transformers import TransformersBase, can_enable_torch_compile +from .transformers_moe import TransformersMoEBase +from .utils import WeightsMapper + + +class TransformersPoolingBase(TransformersBase, VllmModelForPooling): + hf_to_vllm_mapper = WeightsMapper( + # These are applied in order, so the order matters! + orig_to_new_prefix={ + # Handle BERT-like models + "roberta": "model", + "bert": "model", + # Add `model.` prefix for base model checkpoints + "": "model.", + # Remove `model.` prefix if it was already there + "model.model.": "model.", + # Classifier/scoring heads will be adjacent to `model` + "model.score": "classifier", + "model.classifier": "classifier", + }, + orig_to_new_suffix={ + # Replace legacy suffixes used for norms + ".gamma": ".weight", + ".beta": ".bias", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Skip unsupported/unwanted output embeddings layers + self.skip_prefixes.extend( + [ + "model.lm_head.", + "model.predictions.", + "model.qa_outputs.", + "model.embeddings_project.", + "model.discriminator_predictions.", + ] + ) + + # Some encoder models have the position_ids buffer in the checkpoint. + # vLLM will always pass position_ids as an argument, so we skip loading + # the buffer if it exists + self.skip_substrs.append("position_ids") + + # Some encoder models have the bias of the final classifier layer + # in the checkpoint. vLLM does not use this bias, so we skip loading + # it if it exists + self.skip_substrs.append("score.bias") + + # roberta-like models an extra padding in positions. + # FIXME(Isotr0py): This is quite hacky for roberta edge case, + # we should find a better way to handle this. 
+ self.is_roberta = "roberta" in self.text_config.model_type + self.padding_idx = self.text_config.pad_token_id + + def create_attention_instances( + self, attn_type: AttentionType = AttentionType.DECODER + ) -> dict[int, Attention]: + # TODO(hmellor): Better way to detect encoder models + # In encoder models, the attention layers will have `is_causal=False` + is_encoder = lambda m: not getattr(m, "is_causal", True) + # vLLM does not support encoder-decoder models, so if any encoder layer + # is found, we assume the whole model is an encoder model + if any(is_encoder(m) for m in self.model.modules()): + attn_type = AttentionType.ENCODER_ONLY + + # Check minimum transformers version for encoder models support + if attn_type == AttentionType.ENCODER_ONLY: + self.check_version("4.57.0.dev0", "encoder models support") + + return super().create_attention_instances(attn_type) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.is_roberta: + # RoBERTa-specific positions padding + positions += self.padding_idx + 1 + return super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersEmbeddingModel(TransformersPoolingBase): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersForSequenceClassification(TransformersPoolingBase): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # Certain information about the the model and classifier can only be + # inferred from the `ForSequenceClassification` class. Therefore, we + # instantiate it on the "meta" device to avoid allocating GPU memory. + with torch.device("meta"): + seq_cls_model = AutoModelForSequenceClassification.from_config( + self.config, + torch_dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # When used for sequence classification, some models have their + # pooling layers removed. Make sure this is reflected in vLLM. + for module in seq_cls_model.modules(): + if hasattr(module, "pooler") and module.pooler is None: + self.model.pooler = None + break + if self.model.pooler is not None: + raise ValueError( + "Sequence classification models with pooling layers are not " + "supported yet in the Transformers backend." + ) + + # Unlike `lm_head`, `classifier` is not always `nn.Linear`. + self.classifier = seq_cls_model.classifier + self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) + + class ClassifierWithReshape(self.classifier.__class__): + """CLSPool has already been applied in `pooling`. 
+ Add dim to match expected input shape of `classifier.forward`.""" + + def forward(self, *args, **kwargs): + if len(args) > 0: + args = (args[0].unsqueeze(1), *args[1:]) + return super().forward(*args, **kwargs) + + self.classifier.__class__ = ClassifierWithReshape + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel): + pass + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForSequenceClassification( + TransformersMoEBase, TransformersForSequenceClassification +): + pass diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f91c4ddb6e834..1fc34f48401df 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,8 +3,9 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" + from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch from torch import nn @@ -13,64 +14,92 @@ from transformers import BatchFeature, ProcessorMixin from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper.modeling_whisper import WhisperEncoder -from vllm import envs from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context +from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings, - merge_multimodal_embeddings_from_map) +from .interfaces import ( + MultiModalEmbeddings, 
+ SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 -class UltravoxAudioFeatureInputs(TypedDict): +class UltravoxAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - n: number of chunks + - t: Time frames (M) + - nmb: Number of mel bins + """ + type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] - """Shape: `(batch_size, num_chunks, 80, M)`""" - lens: Union[torch.Tensor, list[torch.Tensor]] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], + TensorShape("bn", "nmb", "t"), + ] + lens: Annotated[torch.Tensor, TensorShape("bn")] """ - Length of the audio frames. Used for attention mask in WhisperEncoder. - Shape: `(batch_size, num_chunks)` - """ - token_len: Union[torch.Tensor, list[torch.Tensor]] - """ - Length of the audio tokens. Used for flattening the audio features. - Shape: `(batch_size, num_chunks)` + Length of the audio frames per chunk. Used for attention mask in WhisperEncoder. """ + token_len: Annotated[torch.Tensor, TensorShape("bn")] + """Length of the audio tokens per chunk. Used for flattening the audio features.""" + num_chunks: Annotated[torch.Tensor, TensorShape("n")] + """Number of chunks per audio. Used for flattening the audio features.""" -class UltravoxAudioEmbeddingInputs(TypedDict): +class UltravoxAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - na: number of audios + - afs: audio feature size + - hs: hidden size + """ + type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("b", "na", "afs", "hs") + ] -UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, - UltravoxAudioEmbeddingInputs] +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioEmbeddingInputs] class UltravoxProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: config = self.ctx.model_config.hf_config hf_processor = self.ctx.get_hf_processor(**kwargs) @@ -83,8 +112,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo): return hf_processor - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) audio_processor = hf_processor.audio_processor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore @@ -95,9 +123,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo): return {"audio": None} -class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] - ): - +class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -107,23 +133,26 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate - audio_len = (feature_extractor.chunk_length * sampling_rate * - _MAX_ENCODER_BATCH_SIZE) + 
audio_len = ( + feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE + ) num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } -class UltravoxMultiModalProcessor( - BaseMultiModalProcessor[UltravoxProcessingInfo]): - +class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -138,7 +167,8 @@ class UltravoxMultiModalProcessor( # Text-only input not supported in composite processor if not mm_data.get("audios", []): prompt_ids = self.info.get_tokenizer().encode( - prompt, add_special_tokens=False) + prompt, add_special_tokens=False + ) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -165,7 +195,7 @@ class UltravoxMultiModalProcessor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - output['audio_features'] = output.pop('audio_values') + output["audio_features"] = output.pop("audio_values") return output @@ -174,17 +204,14 @@ class UltravoxMultiModalProcessor( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0)) + num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0)) return dict( # to handle longer than 30s audio, each audio might be split # into multiple chunks as such, their batch dimension can be # higher than the number of audio samples - audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), - audio_token_len=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), - audio_lens=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), + audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), + audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), + audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), # num_chunks can convert audio_chunked to audio batch dimension audio_num_chunks=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), @@ -205,11 +232,12 @@ class UltravoxMultiModalProcessor( # belonging to the i-th audio. 
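# The hunk below turns the per-audio chunk counts (`audio_num_chunks`) into
# cumulative start offsets so that chunks [starts[i], starts[i+1]) belong to
# the i-th audio. A minimal illustration of that bookkeeping with made-up
# chunk counts:
import torch

num_chunks = torch.tensor([2, 1, 3], dtype=torch.int32)  # chunks per audio
starts = torch.cumsum(num_chunks, dim=0, dtype=torch.int32)
starts = torch.cat([torch.tensor([0], dtype=torch.int32), starts])
print(starts)  # tensor([0, 2, 3, 6], dtype=torch.int32)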
out_mm_data = out_mm_kwargs.get_data() num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0)) - chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, - dim=0, - dtype=torch.int32) + chunks_start_idx: torch.Tensor = torch.cumsum( + num_chunks, dim=0, dtype=torch.int32 + ) chunks_start_idx = torch.cat( - [torch.tensor([0], dtype=torch.int32), chunks_start_idx]) + [torch.tensor([0], dtype=torch.int32), chunks_start_idx] + ) def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] @@ -238,17 +266,16 @@ class StackAudioFrames(nn.Module): def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: B, T, C = audio_embeds.shape - T_pad = (T + self.stack_factor - - 1) // self.stack_factor * self.stack_factor + T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) B, T, C = audio_embeds.shape - audio_embeds = audio_embeds.view(B, T // self.stack_factor, - C * self.stack_factor) + audio_embeds = audio_embeds.view( + B, T // self.stack_factor, C * self.stack_factor + ) return audio_embeds class UltravoxProjector(nn.Module): - def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size @@ -264,7 +291,7 @@ class UltravoxProjector(nn.Module): else: self.act = get_act_fn(config.projector_act) - dim_out = config.text_hidden_size + dim_out = config.text_config.hidden_size self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) # Ultravox v0.4.1 and below use layer_norm after the second linear layer @@ -312,12 +339,15 @@ class ModifiedWhisperEncoder(WhisperEncoder): @property def max_context_length(self): - return (self.config.max_source_positions * self.conv1.stride[0] * - self.conv2.stride[0]) + return ( + self.config.max_source_positions + * self.conv1.stride[0] + * self.conv2.stride[0] + ) - def get_attention_mask_by_audio_len(self, - audio_lens: Optional[torch.Tensor], - hidden_states: torch.Tensor): + def get_attention_mask_by_audio_len( + self, audio_lens: Optional[torch.Tensor], hidden_states: torch.Tensor + ): """ Create attention mask based on audio lengths to mask out padding tokens For each sample in batch: @@ -333,9 +363,9 @@ class ModifiedWhisperEncoder(WhisperEncoder): audio_feature_len = self._get_feat_extract_output_lengths(audio_lens) max_seq_len = hidden_states.shape[1] - attention_mask = torch.arange(max_seq_len, - device=hidden_states.device)[None, :].lt( - audio_feature_len.view(-1, 1)) + attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[ + None, : + ].lt(audio_feature_len.view(-1, 1)) attention_mask = self.get_extended_attention_mask( attention_mask, None, @@ -354,21 +384,21 @@ class ModifiedWhisperEncoder(WhisperEncoder): f"Whisper expects the mel input features to be of length " f"{expected_seq_length} or less, but found " f"{input_features.shape[-1]}. Make sure to pad the input mel " - f"features to {expected_seq_length}.") + f"features to {expected_seq_length}." 
+ ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) - embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)] hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) - attention_mask = self.get_attention_mask_by_audio_len( - audio_lens, hidden_states) + attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) for encoder_layer in self.layers: layer_outputs = encoder_layer( @@ -386,16 +416,19 @@ class ModifiedWhisperEncoder(WhisperEncoder): @MULTIMODAL_REGISTRY.register_processor( UltravoxMultiModalProcessor, info=UltravoxProcessingInfo, - dummy_inputs=UltravoxDummyInputsBuilder) + dummy_inputs=UltravoxDummyInputsBuilder, +) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) + orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."} + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -406,7 +439,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: UltravoxConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multi_modal_config = multimodal_config @@ -422,23 +455,28 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): model_or_path=config.audio_model_id, revision=None, prefix="audio_tower.", - )) + ) + ) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, - hf_config=config.text_config, + hf_config=config.wrapped_model_config, prefix=maybe_prefix(prefix, "language_model"), ) if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot self.secondary_weights.append( - DefaultModelLoader.Source(model_or_path=config.text_model_id, - revision=None, - prefix="language_model.")) + DefaultModelLoader.Source( + model_or_path=config.text_model_id, + revision=None, + prefix="language_model.", + ) + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def get_mm_mapping(self) -> MultiModelKeys: """ @@ -451,8 +489,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ) def _audio_features_to_embeddings( - self, input_features: torch.Tensor, - audio_lens: torch.Tensor) -> torch.Tensor: + self, input_features: torch.Tensor, audio_lens: torch.Tensor + ) -> torch.Tensor: audio_features = input_features.to(self.audio_tower.dtype) batch_size = audio_features.size(0) audio_embeddings = [] @@ -461,8 +499,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): for start in 
range(0, batch_size, _MAX_ENCODER_BATCH_SIZE): end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size) # Process through audio tower - batch_features = self.audio_tower(audio_features[start:end], - audio_lens[start:end]) + batch_features = self.audio_tower( + audio_features[start:end], audio_lens[start:end] + ) batch_features = batch_features.to(self.audio_tower.dtype) # Process through projector @@ -474,38 +513,28 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): return audio_embeddings def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + self, **kwargs: object + ) -> Optional[UltravoxAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) audio_lens = kwargs.pop("audio_lens", None) audio_token_len = kwargs.pop("audio_token_len", None) + audio_num_chunks = kwargs.pop("audio_num_chunks", None) if audio_features is None and audio_embeds is None: return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_lens. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_token_len, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_token_len. " - f"Got type: {type(audio_features)}") - - return UltravoxAudioFeatureInputs(type="audio_features", - data=audio_features, - lens=audio_lens, - token_len=audio_token_len) + return UltravoxAudioFeatureInputs( + type="audio_features", + data=audio_features, + lens=audio_lens, + token_len=audio_token_len, + num_chunks=audio_num_chunks, + ) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. 
" - f"Got type: {type(audio_embeds)}") - - return UltravoxAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") @@ -520,12 +549,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] audio_features = pad_and_concat_to_dim3(audio_input["data"]) - # [B1, B2] -> [B1+B2] - audio_lens = flatten_bn(audio_input['lens'], concat=True) - audio_token_len = flatten_bn(audio_input['token_len'], concat=True) + audio_lens = audio_input["lens"] + audio_token_len = audio_input["token_len"] - embeddings = self._audio_features_to_embeddings( - audio_features, audio_lens) + embeddings = self._audio_features_to_embeddings(audio_features, audio_lens) # We should flatten and concatenate embeddings based on token lengths # For example, with token_len = [4, 2, 3], flattened_embeddings will be @@ -534,23 +561,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): # Create a mask of valid indices based on token lengths max_len = embeddings.shape[1] indices = torch.arange(max_len, device=embeddings.device).expand( - embeddings.shape[0], -1) + embeddings.shape[0], -1 + ) mask = indices < audio_token_len[:, None] # Apply mask and flatten flattened_embeddings = embeddings[mask] # Return one tensor per input audio embed_lens = [ - token_len_item.sum().item() - for token_len_item in audio_input['token_len'] + chunk_lens.sum().item() + for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist()) ] return flattened_embeddings.split(embed_lens) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] @@ -561,34 +588,30 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self, input_ids: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + # Multi-modal token ID may exceed vocab size + handle_oov_mm_token: bool = True, ) -> torch.Tensor: - # The audio token index is not included in the embedding table - # We need to remove it before embedding lookup - safe_input_ids = input_ids.clone() - safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0 - inputs_embeds = self.language_model.get_input_embeddings( - safe_input_ids) - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) - # TODO(ywang96): remove this block after v0 is deprecated. 
- if not envs.VLLM_USE_V1: - attn_metadata = get_forward_context().attn_metadata - merge_multimodal_embeddings_from_map( - inputs_embeds, multimodal_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["audio"]) - else: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index) - return inputs_embeds + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> Union[torch.Tensor, IntermediateTensors]: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox One key thing to understand is the `input_ids` already accounts for the @@ -599,50 +622,36 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): with the `input_ids`. Args: - audio_features: A batch of audio input chunks [B, N, 80, M]. - audio_lens: Length of audio frames for each audio chunk [B]. - audio_token_len: Length of audio tokens for each audio chunk [B']. - Note: batch dim is different from batch dim in audio chunks. + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - elif inputs_embeds is None: - multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - - inputs_embeds = self.get_input_embeddings(input_ids, - multimodal_embeddings) - input_ids = None - language_model = self.language_model if hasattr(language_model, "language_model"): language_model = language_model.language_model - hidden_states = language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["audio_tower."]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def pad_and_concat_to_dim3( - features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] + features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], ) -> torch.Tensor: """ Pad and concatenate a list of tensors. 
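# `pad_and_concat_to_dim3` (continued in the hunk below) flattens a ragged
# batch of mel-spectrogram chunks into a single [N, 80, max_len] tensor by
# right-padding the time dimension. A small illustration of that padding step,
# assuming two groups of chunks with different frame counts:
import torch
import torch.nn.functional as F

a = torch.randn(3, 80, 100)  # 3 chunks, 100 frames each
b = torch.randn(2, 80, 160)  # 2 chunks, 160 frames each
max_len = max(a.shape[-1], b.shape[-1])
# Pad only the last (time) dimension on the right up to max_len
padded = [F.pad(f, (0, max_len - f.shape[-1])) for f in (a, b)]
print(torch.cat(padded).shape)  # torch.Size([5, 80, 160])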
@@ -656,6 +665,7 @@ def pad_and_concat_to_dim3( if features.ndim > 3: # Flatten [B, N, 80, M] -> [B * N, 80, M] features = flatten_bn(features) + return features features = [pad_and_concat_to_dim3(f) for f in features] @@ -663,7 +673,7 @@ def pad_and_concat_to_dim3( max_len = max(f.shape[-1] for f in features) # Ensure all features have dim=3 features = [f.view(-1, *f.shape[-2:]) for f in features] - # Pad and oncatenate: + # Pad and concatenate: # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features] return torch.cat(features) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6c27fedc61b17..bd530be73c2ad 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,21 +4,31 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Optional, Protocol, Union, overload +from typing import Any, Literal, Optional, Protocol, Union, overload import torch import torch.nn as nn from torch.func import functional_call from transformers import PretrainedConfig +from typing_extensions import deprecated import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors +from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors -from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, - is_uva_available) +from vllm.utils import ( + cdiv, + direct_register_custom_op, + get_cuda_view_from_cpu_tensor, + is_pin_memory_available, + is_uva_available, +) logger = init_logger(__name__) @@ -61,12 +71,16 @@ class WeightsMapper: def apply( self, weights: Iterable[tuple[str, torch.Tensor]] ) -> Iterable[tuple[str, torch.Tensor]]: - return ((out_name, data) for name, data in weights - if (out_name := self._map_name(name)) is not None) + return ( + (out_name, data) + for name, data in weights + if (out_name := self._map_name(name)) is not None + ) def apply_list(self, values: list[str]) -> list[str]: return [ - out_name for name in values + out_name + for name in values if (out_name := self._map_name(name)) is not None ] @@ -109,6 +123,7 @@ class AutoWeightsLoader: skip_prefixes: Optional[list[str]] = None, skip_substrs: Optional[list[str]] = None, ignore_unexpected_prefixes: Optional[list[str]] = None, + ignore_unexpected_suffixes: Optional[list[str]] = None, ) -> None: super().__init__() @@ -116,6 +131,7 @@ class AutoWeightsLoader: self.skip_prefixes = skip_prefixes or [] self.skip_substrs = skip_substrs or [] self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or [] # update default skip_substrs self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS @@ -123,17 +139,20 @@ class AutoWeightsLoader: self, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]: - weights_by_parts = ((weight_name.split(".", 1), weight_data) - for weight_name, weight_data in weights) + weights_by_parts = ( + (weight_name.split(".", 1), weight_data) + for weight_name, weight_data in weights + ) - for prefix, group in 
itertools.groupby(weights_by_parts, - key=lambda x: x[0][0]): + for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]): yield ( prefix, # Because maxsplit=1 in weight_name.split(...), # the length of `parts` must either be 1 or 2 - (("" if len(parts) == 1 else parts[1], weights_data) - for parts, weights_data in group), + ( + ("" if len(parts) == 1 else parts[1], weights_data) + for parts, weights_data in group + ), ) def _get_qualname(self, prefix: str, rest: str) -> str: @@ -145,12 +164,14 @@ class AutoWeightsLoader: return ".".join((prefix, rest)) def _can_skip(self, qualname: str) -> bool: - return (any(qualname.startswith(p) for p in self.skip_prefixes) - or any(substr in qualname for substr in self.skip_substrs)) + return any(qualname.startswith(p) for p in self.skip_prefixes) or any( + substr in qualname for substr in self.skip_substrs + ) def _can_ignore_unexpected(self, qualname: str) -> bool: - return any( - qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes) + return any(iup) or any(ius) def _load_param( self, @@ -174,24 +195,26 @@ class AutoWeightsLoader: raise ValueError( f"Attempted to load nested weight '{weight_qualname}' " - f"into a single parameter '{base_prefix}'") + f"into a single parameter '{base_prefix}'" + ) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight_data) - logger.debug("Loaded weight %s with shape %s", weight_qualname, - param.shape) + logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape) yield weight_qualname - def _add_loadable_non_param_tensors(self, module: nn.Module, - child_params: dict[str, torch.Tensor]): + def _add_loadable_non_param_tensors( + self, module: nn.Module, child_params: dict[str, torch.Tensor] + ): """ Add tensor names that are not in the model params that may be in the safetensors, e.g., batch normalization stats. 
""" - if isinstance(module, ( + if isinstance( + module, + ( nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, @@ -199,10 +222,10 @@ class AutoWeightsLoader: nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, nn.SyncBatchNorm, - )): + ), + ): module_state_dict = module.state_dict() - for stat_name in ("running_mean", "running_var", - "num_batches_tracked"): + for stat_name in ("running_mean", "running_var", "num_batches_tracked"): child_params[stat_name] = module_state_dict[stat_name] def _load_module( @@ -222,8 +245,8 @@ class AutoWeightsLoader: loaded_params = module_load_weights(weights) if loaded_params is None: logger.warning( - "Unable to collect loaded parameters " - "for module %s", module) + "Unable to collect loaded parameters for module %s", module + ) else: yield from map( lambda x: self._get_qualname(base_prefix, x), @@ -246,17 +269,18 @@ class AutoWeightsLoader: continue - yield from self._load_module(prefix, - child_modules[child_prefix], - child_weights) + yield from self._load_module( + prefix, child_modules[child_prefix], child_weights + ) elif child_prefix in child_params: if self._can_skip(prefix): logger.debug("Skipping param %s", prefix) continue - yield from self._load_param(prefix, child_params[child_prefix], - child_weights) + yield from self._load_param( + prefix, child_params[child_prefix], child_weights + ) else: can_skip_module = self._can_skip(prefix + ".") can_skip_param = self._can_skip(prefix) @@ -272,8 +296,10 @@ class AutoWeightsLoader: continue - msg = (f"There is no module or parameter named '{prefix}' " - f"in {type(self.module).__name__}") + msg = ( + f"There is no module or parameter named '{prefix}' " + f"in {type(self.module).__name__}" + ) raise ValueError(msg) def load_weights( @@ -285,8 +311,9 @@ class AutoWeightsLoader: if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name - weights = ((name, weight) for name, weight in weights - if not self._can_skip(name)) + weights = ( + (name, weight) for name, weight in weights if not self._can_skip(name) + ) autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights @@ -310,20 +337,17 @@ def init_vllm_registered_model( hf_config = vllm_config.model_config.hf_config if hf_config is not None: - vllm_config = vllm_config.with_hf_config(hf_config, - architectures=architectures) + vllm_config = vllm_config.with_hf_config(hf_config, architectures=architectures) return initialize_model(vllm_config=vllm_config, prefix=prefix) @overload -def flatten_bn(x: torch.Tensor) -> torch.Tensor: - ... +def flatten_bn(x: torch.Tensor) -> torch.Tensor: ... @overload -def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: - ... +def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload @@ -331,8 +355,7 @@ def flatten_bn( x: Union[list[torch.Tensor], torch.Tensor], *, concat: Literal[True], -) -> torch.Tensor: - ... +) -> torch.Tensor: ... @overload @@ -340,8 +363,7 @@ def flatten_bn( x: Union[list[torch.Tensor], torch.Tensor], *, concat: bool = False, -) -> Union[list[torch.Tensor], torch.Tensor]: - ... +) -> Union[list[torch.Tensor], torch.Tensor]: ... 
def flatten_bn( @@ -385,30 +407,21 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: if isinstance(embeddings, torch.Tensor): return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) - return " + ".join( - _embedding_count_expression(inner) for inner in embeddings) + return " + ".join(_embedding_count_expression(inner) for inner in embeddings) -def merge_multimodal_embeddings_from_map( - inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: - """ - Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided - placeholder map . - - Note: - This updates ``inputs_embeds`` in place. - """ - flattened_embeddings = _flatten_embeddings(multimodal_embeddings) - inputs_embeds[placeholder_map.dest] = flattened_embeddings[ - placeholder_map.src].to(dtype=inputs_embeds.dtype) - return inputs_embeds +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, - is_multimodal: torch.Tensor, multimodal_embeddings: NestedTensors, + is_multimodal: torch.Tensor, ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -418,63 +431,43 @@ def _merge_multimodal_embeddings( Note: This updates ``inputs_embeds`` in place. """ - flattened = _flatten_embeddings(multimodal_embeddings) - try: - # This is equivalent to: inputs_embeds[is_multimodal] = flattened. - inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), - flattened.to(dtype=inputs_embeds.dtype)) - except RuntimeError as e: - num_expected_tokens = is_multimodal.sum().item() - assert isinstance(num_expected_tokens, int) + if len(multimodal_embeddings) == 0: + return inputs_embeds - if flattened.shape[0] != num_expected_tokens: + mm_embeds_flat = _flatten_embeddings(multimodal_embeddings) + input_dtype = inputs_embeds.dtype + + try: + # For debugging + # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype) + + # NOTE: This can avoid D2H sync (#22105), but fails to + # raise an error if is_multimodal.sum() < len(mm_embeds_flat) + inputs_embeds.masked_scatter_( + is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype) + ) + except RuntimeError as e: + num_actual_tokens = len(mm_embeds_flat) + num_expected_tokens = is_multimodal.sum().item() + + if num_actual_tokens != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( - f"Attempted to assign {expr} = {flattened.shape[0]} " + f"Attempted to assign {expr} = {num_actual_tokens} " f"multimodal tokens to {num_expected_tokens} placeholders" ) from e - else: - raise ValueError("Error during masked scatter operation") from e + + raise ValueError("Error during masked scatter operation") from e return inputs_embeds -def embed_multimodal( - input_ids: torch.Tensor, - multimodal_token_id: int, - get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - multimodal_embeds: NestedTensors, -) -> torch.Tensor: - """ - Embed token IDs and multimodal inputs and combine their embeddings. - - ``multimodal_token_id`` is used to determine whether a token ID should - be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``. 
- - Compared to ``merge_multimodal_embeddings`, this avoids running - ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]`` - which causes issues when the placeholder token ID exceeds the - vocabulary size of the language model. - """ - is_multimodal = input_ids == multimodal_token_id - is_text = ~is_multimodal - - text_embeds = get_text_embeds(input_ids[is_text]) - merged_embeds = torch.empty( - (input_ids.shape[0], text_embeds.shape[1]), - dtype=text_embeds.dtype, - device=text_embeds.device, - ) - - merged_embeds[is_text] = text_embeds - - return _merge_multimodal_embeddings( - merged_embeds, - is_multimodal, - multimodal_embeds, - ) - - +@deprecated( + "`merge_multimodal_embeddings` has been replaced with " + "`SupportsMultiModal.get_input_embeddings` and will be " + "removed in v0.12." +) def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, @@ -507,25 +500,31 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = torch.tensor(placeholder_token_id, - device=input_ids.device) - return _merge_multimodal_embeddings( - inputs_embeds, - torch.isin(input_ids, placeholder_token_id), - multimodal_embeddings, - ) + is_multimodal = isin_list(input_ids, placeholder_token_id) + else: + is_multimodal = input_ids == placeholder_token_id return _merge_multimodal_embeddings( inputs_embeds, - (input_ids == placeholder_token_id), - multimodal_embeddings, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, ) -class LayerFn(Protocol): +def isin_list( + elements: torch.Tensor, + test_elements_list: list[int], +) -> torch.Tensor: + test_elements = torch.tensor( + test_elements_list, + pin_memory=is_pin_memory_available(), + ).to(device=elements.device, non_blocking=True) - def __call__(self, prefix: str) -> torch.nn.Module: - ... + return torch.isin(elements, test_elements) + + +class LayerFn(Protocol): + def __call__(self, prefix: str) -> torch.nn.Module: ... 
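# The masked-scatter path in `_merge_multimodal_embeddings` above overwrites
# only the placeholder rows of `inputs_embeds`, avoiding data-dependent
# indexing (and the device-to-host sync it would trigger). A toy version of
# the same operation with illustrative shapes:
import torch

inputs_embeds = torch.zeros(6, 4)  # 6 tokens, hidden size 4
is_multimodal = torch.tensor([False, True, True, False, True, False])
mm_embeds_flat = torch.ones(3, 4)  # one embedding row per placeholder token

# Equivalent to `inputs_embeds[is_multimodal] = mm_embeds_flat`, expressed as
# an in-place masked_scatter_ over the broadcast [6, 1] mask.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)
print(inputs_embeds[1])  # tensor([1., 1., 1., 1.])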
class PPMissingLayer(torch.nn.Identity): @@ -568,8 +567,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: uva_available = is_uva_available() if envs.VLLM_USE_V1: - assert uva_available, ("V1 CPU offloading requires" - " uva (pin memory) support") + assert uva_available, "V1 CPU offloading requires uva (pin memory) support" uva_offloading = True else: uva_offloading = False @@ -584,12 +582,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: break # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided(size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device='cpu', - pin_memory=pin_memory) + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) cpu_data.copy_(p.data) if not uva_offloading: p.data = cpu_data @@ -611,10 +611,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: k: v.to(device, non_blocking=True) for k, v in module.state_dict().items() } - output = functional_call(module, - device_state, - args=args, - kwargs=kwargs) + output = functional_call(module, device_state, args=args, kwargs=kwargs) module.forward = forward return output @@ -633,14 +630,18 @@ def make_layers( """ from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices - start_layer, end_layer = get_pp_indices(num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) + + start_layer, end_layer = get_pp_indices( + num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size + ) modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + [ + [PPMissingLayer() for _ in range(start_layer)] + + [ maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) for idx in range(start_layer, end_layer) - ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + ] + + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)] + ) return start_layer, end_layer, modules @@ -660,7 +661,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: # NOTE: the trailing dot is used to match the prefix of the layer. 
# without the dot, we could match a layer that is not missing, # e.g., 'encoder.layer.1' would match 'encoder.layer.11' - missing_layer_names.append(name + '.') + missing_layer_names.append(name + ".") _model_to_pp_missing_layer_names[model_id] = missing_layer_names return missing_layer_names @@ -673,21 +674,22 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: return any( name.startswith(missing_layer_name) - for missing_layer_name in get_pp_missing_layer_names(model)) + for missing_layer_name in get_pp_missing_layer_names(model) + ) def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int): - def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, device: torch.device, ) -> IntermediateTensors: - return IntermediateTensors({ - key: - torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) - for key in keys - }) + return IntermediateTensors( + { + key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) + for key in keys + } + ) return make_empty_intermediate_tensors @@ -705,14 +707,14 @@ def maybe_prefix(prefix: str, name: str) -> str: return name if not prefix else f"{prefix}.{name}" -def extract_layer_index(layer_name: str) -> int: +def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: """ Extract the layer index from the module name. Examples: - "encoder.layers.0" -> 0 - "encoder.layers.1.self_attn" -> 1 - "2.self_attn" -> 2 - - "model.encoder.layers.0.sub.1" -> ValueError + - "model.encoder.layers.0.sub.1" -> ValueError if num_attn_module == 1 """ subnames = layer_name.split(".") int_vals: list[int] = [] @@ -721,9 +723,22 @@ def extract_layer_index(layer_name: str) -> int: int_vals.append(int(subname)) except ValueError: continue - assert len(int_vals) == 1, (f"layer name {layer_name} should" - " only contain one integer") - return int_vals[0] + if num_attn_module == 1 or "attn" not in layer_name: + assert len(int_vals) == 1, ( + f"layer name {layer_name} should only contain one integer" + ) + + return int_vals[0] + else: + assert len(int_vals) <= 2, ( + f"layer name {layer_name} should contain most two integers" + ) + layer_index = ( + int_vals[0] * num_attn_module + int_vals[1] + if len(int_vals) == 2 + else int_vals[0] + ) + return layer_index def cast_overflow_tensors( @@ -736,19 +751,20 @@ def cast_overflow_tensors( return tensors -def fast_topk(values: torch.Tensor, topk: int, - dim: int) -> tuple[torch.Tensor, torch.Tensor]: +def fast_topk( + values: torch.Tensor, topk: int, dim: int +) -> tuple[torch.Tensor, torch.Tensor]: """ Optimized topk implementation that uses torch.max for k=1 case. - + This function provides better performance for the common case of k=1 by using torch.max instead of the more general torch.topk. - + Args: values: Input tensor to find top-k values from topk: Number of top values to return (k). Must be > 0. 
dim: Dimension along which to compute topk - + Returns: Tuple of (values, indices) where values are the top-k values and indices are their corresponding indices in the input tensor @@ -759,3 +775,53 @@ def fast_topk(values: torch.Tensor, topk: int, else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim) + + +def get_model_hidden_size(hf_config: PretrainedConfig) -> int: + if hasattr(hf_config, "hidden_size"): + return hf_config.hidden_size + text_config = hf_config.get_text_config() + return text_config.hidden_size + + +# Chunk x along the num_tokens axis for sequence parallelism +# NOTE: This is wrapped in a torch custom op to work around the following issue: +# The output tensor can have a sequence length 0 at small input sequence lengths +# even though we explicitly pad to avoid this. +def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.sequence_parallel_chunk_impl(x) + + +def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + # all_gather needs the sequence length to be divisible by tp_size + seq_len = x.size(0) + remainder = seq_len % tp_size + if remainder != 0: + pad_len = tp_size - remainder + y = nn.functional.pad(x, (0, 0, 0, pad_len)) + else: + y = x + + chunk = y.shape[0] // tp_size + start = tp_rank * chunk + return torch.narrow(y, 0, start, chunk) + + +def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + seq_len = cdiv(x.size(0), tp_size) + shape = list(x.shape) + shape[0] = seq_len + out = torch.empty(shape, dtype=x.dtype, device=x.device) + return out + + +direct_register_custom_op( + op_name="sequence_parallel_chunk_impl", + op_func=sequence_parallel_chunk_impl, + fake_impl=sequence_parallel_chunk_impl_fake, + tags=(torch.Tag.needs_fixed_stride_order,), +) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index de30509b1ccb4..e517109e94dd6 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,24 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +import math from abc import ABC, abstractmethod -from typing import Final, Generic, Optional, Protocol, TypeVar, Union +from typing import Callable, Final, Generic, Literal, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig -from vllm.attention.selector import get_env_variable_attn_backend +from vllm.attention.backends.registry import _Backend +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform logger = init_logger(__name__) _C = TypeVar("_C", bound=PretrainedConfig) -class VisionEncoderInfo(ABC, Generic[_C]): +class _RootConfig(Protocol[_C]): + vision_config: _C - def __init__(self, hf_config: _C) -> None: + +class VisionEncoderInfo(ABC, Generic[_C]): + def __init__(self, hf_config: _RootConfig[_C]) -> None: super().__init__() self.hf_config = hf_config @@ -50,8 +60,7 @@ class VisionLanguageConfig(Protocol): vision_config: Final[PretrainedConfig] -def get_vision_encoder_info( - hf_config: VisionLanguageConfig) -> VisionEncoderInfo: +def 
get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo: # Avoid circular imports from .clip import CLIPEncoderInfo, CLIPVisionConfig from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig @@ -68,24 +77,76 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(support_fa: bool = False) -> _Backend: +def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: """ Get the available attention backend for Vision Transformer. """ - # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. + # Lazy import to avoid circular dependency + from vllm.attention.selector import get_env_variable_attn_backend selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend - return current_platform.get_vit_attn_backend(support_fa) + return current_platform.get_vit_attn_backend(head_size, dtype) + + +VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] + +VisionFeatureSelectStrategy = Union[ + VisionFeatureSelectStrategyStr, + Callable[[torch.Tensor], torch.Tensor], +] + + +def _get_vision_feature_selector( + strategy: Union[VisionFeatureSelectStrategy, str], +) -> Callable[[torch.Tensor], torch.Tensor]: + if callable(strategy): + return strategy + + # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762 + if strategy == "class": + return lambda feats: feats[:, :1, :] + + # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196 + if strategy == "default": + return lambda feats: feats[:, 1:, :] + + if strategy == "full": + return lambda feats: feats + + raise ValueError(f"Unexpected feature select strategy: {strategy!r}") + + +def get_num_selected_vision_tokens( + num_vision_tokens: int, + strategy: Union[VisionFeatureSelectStrategy, str], +) -> int: + if callable(strategy): + dummy_features = torch.empty(1, num_vision_tokens, 64) # [B, L, D] + dummy_selected_features = strategy(dummy_features) + return dummy_selected_features.shape[1] + + if strategy == "class": + return 1 + + if strategy == "default": + return num_vision_tokens - 1 + + if strategy == "full": + return num_vision_tokens + + raise ValueError(f"Unexpected feature select strategy: {strategy!r}") def resolve_visual_encoder_outputs( encoder_outputs: Union[torch.Tensor, list[torch.Tensor]], - feature_sample_layers: Optional[list[int]], post_layer_norm: Optional[torch.nn.LayerNorm], - max_possible_layers: int, + *, + select_layers: Optional[list[int]] = None, + max_possible_layers: Optional[int] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, ) -> torch.Tensor: """Given the outputs a visual encoder module that may correspond to the output of the last layer, or a list of hidden states to be stacked, @@ -93,17 +154,34 @@ def resolve_visual_encoder_outputs( Args: encoder_outputs: Output of encoder's last layer or all hidden states. - feature_sample_layers: Optional layer indices to grab from the encoder - outputs; if provided, encoder outputs must be a list. post_layer_norm: Post norm to apply to the output of the encoder. + select_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. max_possible_layers: Total layers in the fully loaded visual encoder. 
- + feature_select_strategy: Defines how to select the hidden states + from each layer. """ - if feature_sample_layers is None: + if select_layers is None: + if not isinstance(encoder_outputs, torch.Tensor): + raise ValueError( + "Expected only a single encoder output when " + "`select_layers` is not provided" + ) + + if feature_select_strategy is not None: + select_features = _get_vision_feature_selector(feature_select_strategy) + encoder_outputs = select_features(encoder_outputs) + if post_layer_norm is not None: return post_layer_norm(encoder_outputs) + return encoder_outputs + if max_possible_layers is None: + raise ValueError( + "`max_possible_layers` must be provided alongside `select_layers`" + ) + # Get the hidden states corresponding to the layer indices. # Negative values are relative to the full visual encoder, # so offset them depending on how many layers were loaded. @@ -114,12 +192,347 @@ def resolve_visual_encoder_outputs( offset = max_possible_layers - num_loaded_layers hs_pool = [ encoder_outputs[layer_idx] - if layer_idx >= 0 else encoder_outputs[layer_idx + offset] - for layer_idx in feature_sample_layers + if layer_idx >= 0 + else encoder_outputs[layer_idx + offset] + for layer_idx in select_layers ] + if feature_select_strategy is not None: + select_features = _get_vision_feature_selector(feature_select_strategy) + hs_pool = [select_features(hs) for hs in hs_pool] + # Apply post-norm on the final hidden state if we are using it - uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1) if post_layer_norm is not None and uses_last_layer: - hs_pool[-1] = post_layer_norm(encoder_outputs) + hs_pool[-1] = post_layer_norm(hs_pool[-1]) + return torch.cat(hs_pool, dim=-1) + + +def run_dp_sharded_vision_model( + image_input: torch.Tensor, vision_model: torch.nn.Module +) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function + will shard the input image tensor on the first dimension and run the vision + model + + Args: + image_input (torch.Tensor): Image input tensor. + vision_model (torch.nn.Module): Vision model. + Returns: + torch.Tensor: Output image embeddings + """ + + num_chunks = image_input.shape[0] + mp_world_size = get_tensor_model_parallel_world_size() + num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size + num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks + pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + image_input_padded = torch.nn.functional.pad(image_input, pad) + rank = get_tensor_model_parallel_rank() + image_input_per_rank = image_input_padded[ + rank * num_chunks_per_rank : (rank + 1) * num_chunks_per_rank, ... + ] + + vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0) + vision_embeddings = vision_embeddings[:num_chunks, ...] + return vision_embeddings + + +def get_load_balance_assignment( + sizes: list[int], + num_gpus: int = 2, +) -> tuple[list[int], list[int], list[int]]: + """ + Generate load balancing assignment and metadata + for distributing data across GPUs. + The load is determined by the total image sizes, + not the number of images. 
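For reference, here is a minimal standalone sketch (not part of the patch) of the greedy, size-based balancing described here; the expected outputs assume the implementation that follows and reuse its example sizes.

```python
# Greedy size-based balancing: walk images from largest to smallest and
# assign each one to the GPU with the smallest total size so far.
sizes = [1000, 100, 200, 50]
num_gpus = 2

gpu_assignments = [[] for _ in range(num_gpus)]
gpu_loads = [0] * num_gpus
for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
    target = min(range(num_gpus), key=lambda g: gpu_loads[g])
    gpu_assignments[target].append(idx)
    gpu_loads[target] += sizes[idx]

print(gpu_assignments)  # [[0], [2, 1, 3]]
print(gpu_loads)        # [1000, 350]
```

Flattening `gpu_assignments` gives the `shuffle_indices` ([0, 2, 1, 3]) and the per-GPU list lengths give the `gpu_sample_counts` ([1, 3]) referenced in the comments of the implementation below.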
+ + Args: + sizes: The size of each image + num_gpus: Number of GPUs to balance across + + Returns: + shuffle_indices: + Indices to reorder data for balanced loading + gpu_sample_counts: + Number of samples assigned to each GPU + grouped_sizes_per_gpu: + Total size assigned to each GPU + + Example: + ``` + sizes = [1000, 100, 200, 50] + num_gpus = 2 + ``` + + """ + + n_samples = len(sizes) + + # Handle edge cases + if n_samples == 0: + return [], [0] * num_gpus, [0] * num_gpus + + # Use greedy algorithm - balance by total size, not sample count + gpu_assignments = [list[int]() for _ in range(num_gpus)] + gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count + + # Sort indices by size (largest first for better load balancing) + # sizes = [1000, 100, 200, 50] + # large_to_small_indices = [0, 2, 1, 3] + large_to_small_indices = sorted( + range(n_samples), key=lambda i: sizes[i], reverse=True + ) + + for idx in large_to_small_indices: + # Find GPU with minimum current load (by total size) + min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) + gpu_assignments[min_gpu].append(idx) + gpu_loads[min_gpu] += sizes[idx] + + # Create shuffle indices and counts + shuffle_indices = list[int]() + gpu_sample_counts = list[int]() + for gpu_id in range(num_gpus): + # GPU_0 = [1000] = [0] + # GPU_1 = [200, 100, 50] = [2, 1, 3] + # shuffle_indices = [0, 2, 1, 3] + shuffle_indices.extend(gpu_assignments[gpu_id]) + # GPU_0 = [1] + # GPU_1 = [3] + # gpu_sample_counts = [1, 3] + gpu_sample_counts.append(len(gpu_assignments[gpu_id])) + + return (shuffle_indices, gpu_sample_counts, gpu_loads) + + +def run_dp_sharded_mrope_vision_model( + vision_model: torch.nn.Module, + pixel_values: torch.Tensor, + grid_thw_list: list[list[int]], + *, + rope_type: Literal["rope_3d", "rope_2d"], +) -> tuple[torch.Tensor, ...]: + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the + first dimension and run the vision model. + This function is used to run the vision model with mrope. + + Args: + vision_model (torch.nn.Module): Vision model. + pixel_values (torch.Tensor): Image/Video input tensor. + grid_thw_list: List of grid dimensions for each image + rope_type: Type of rope used in the vision model. + Different rope types have different dimension to do ViT. 
+ "rope_3d" for 3D rope (e.g., Qwen2.5-VL) + "rope_2d" for 2D rope (e.g., Kimi-VL) + Returns: + torch.Tensor: Output image embeddings + + Example: + ``` + vision_model.out_hidden_size = 64 + vision_model.spatial_merge_size = 2 + pixel_values.shape = (1350, channel) + grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] + tp_size = 2 + ``` + + """ + tp_size = get_tensor_model_parallel_world_size() + + # GPU_0 tp_rank_local = 0 + # GPU_1 tp_rank_local = 1 + tp_rank_local = get_tensor_model_parallel_rank() + + # patches_per_image = [1000, 100, 200, 50] + patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] + # patches_per_image = [0, 1000, 1100, 1300, 1350] + cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] + + # Get load balancing assignment with all metadata + # image_to_tp_rank = [0, 2, 1, 3] + # gpu_sample_counts = [1, 3] + # grouped_pixel_values_len = [1000, 350] + (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = ( + get_load_balance_assignment(patches_per_image, tp_size) + ) + + # cu_gpu_sample_counts = [0, 1, 4] + cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] + + # GPU_0 image_idxs_local = [0] + # GPU_1 image_idxs_local = [2, 1, 3] + image_idxs_local = image_to_tp_rank[ + cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1] + ] + + # Get the pixel values for the local images based on the image_idxs_local + if len(image_idxs_local) > 0: + pixel_values_local = torch.cat( + [ + pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]] + for i in image_idxs_local + ] + ) + else: + # Handle case where this rank has no images + pixel_values_local = torch.empty( + (0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + # embed_dim_reduction_factor = 2 * 2 + if rope_type == "rope_2d": + embed_dim_reduction_factor = ( + vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1] + ) + else: + embed_dim_reduction_factor = ( + vision_model.spatial_merge_size * vision_model.spatial_merge_size + ) + + # Find the max length across all ranks + # The output embedding of every DP rank has to be + # padded to this length for tensor_model_parallel_all_gather + # to work + max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor + local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] + + # Run the vision model on the local pixel_values_local + if rope_type == "rope_2d": + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model( + pixel_values_local, torch.tensor(local_grid_thw_list) + ) + if isinstance(image_embeds_local, list): + image_embeds_local = torch.cat(image_embeds_local, dim=0) + else: + out_dim = getattr(vision_model.config, "hidden_size", None) + image_embeds_local = torch.empty( + (0, embed_dim_reduction_factor, out_dim), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + else: + if pixel_values_local.shape[0] > 0: + image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list) + else: + # Handle empty case + image_embeds_local = torch.empty( + (0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) + + # Pad the output based on max_len_per_rank + # for tensor_model_parallel_all_gather to work + current_len = image_embeds_local.shape[0] + if current_len < max_len_per_rank: + padding_size = max_len_per_rank - current_len + if rope_type == "rope_2d": + padding = torch.empty( + ( + padding_size, + 
image_embeds_local.shape[1], + image_embeds_local.shape[2], + ), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) + else: + padding = torch.empty( + (padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0) + else: + image_embeds_local_padded = image_embeds_local + + # Do all_gather to collect embeddings from all ranks + gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0) + + # Remove padding and reconstruct per-rank embeddings + rank_embeddings = list[torch.Tensor]() + for rank in range(tp_size): + start_idx = rank * max_len_per_rank + end_idx = start_idx + ( + grouped_pixel_values_len[rank] // embed_dim_reduction_factor + ) + rank_embeddings.append(gathered_embeds[start_idx:end_idx]) + + patches_per_output_image = [ + (patch_size // embed_dim_reduction_factor) for patch_size in patches_per_image + ] + + # Reconstruct embeddings in the original order + original_order_embeddings = [None] * len(grid_thw_list) + current_idx = 0 + for rank in range(tp_size): + count = gpu_sample_counts[rank] + if count > 0: + # Get images assigned to this rank in shuffled order + # GPU_0 = image_idxs_local [0] + # GPU_1 = image_idxs_local [2, 1, 3] + rank_images = image_to_tp_rank[current_idx : current_idx + count] + + rank_embed = rank_embeddings[rank] + # Split rank embeddings back to individual images + embed_start = 0 + for img_idx in rank_images: + img_patches = patches_per_output_image[img_idx] + original_order_embeddings[img_idx] = rank_embed[ + embed_start : embed_start + img_patches + ] + embed_start += img_patches + current_idx += count + out_embeddings = tuple( + embed for embed in original_order_embeddings if embed is not None + ) + assert len(out_embeddings) == len(original_order_embeddings), ( + "Found unassigned embeddings" + ) + return out_embeddings + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, +) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index_tensor = ( + torch.Tensor(t_index) + .to(llm_grid_h.device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .long() + .flatten() + ) + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 77f11a691e080..f4bfbd26756e1 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -5,48 +5,59 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from math import ceil -from typing import Optional, Union, cast +from typing import Literal, Optional, Union, cast import numpy as np import regex as re import torch import torch.nn as nn from mistral_common.audio import mel_filter_bank -from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, - TextChunk, 
UserMessage) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder -from transformers import TensorType, WhisperConfig +from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import SupportsPP -# yapf: disable +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.whisper import WhisperEncoder -# yapf: enable -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + MultiModalUUIDDict, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import (MistralTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import ( + MistralTokenizer, + cached_tokenizer_from_config, +) -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription) -from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, - merge_multimodal_embeddings) +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription +from .utils import init_vllm_registered_model, maybe_prefix logger = init_logger(__name__) @@ -108,7 +119,8 @@ class VoxtralProcessorAdapter: audio_length: int, ) -> int: pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames( - audio_length, self.sampling_rate) + audio_length, self.sampling_rate + ) return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate)) def __call__( @@ -138,7 +150,8 @@ class VoxtralProcessorAdapter: "Make sure to process your input via `mistral_common`'s " "tokenizer or pass a chat completion request. " "For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + "https://github.com/vllm-project/vllm/issues/8411." 
+ ) audios_tokens = list[torch.Tensor]() audios_processed = list[torch.Tensor]() @@ -149,21 +162,22 @@ class VoxtralProcessorAdapter: # pad if necessary audio = self._audio_processor.pad(audio, self.sampling_rate) - audio_tokens = [ - self.begin_audio_token_id - ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio)) + audio_tokens = [self.begin_audio_token_id] + [ + self.audio_token_id + ] * self.get_num_audio_tokens(len(audio)) audios_tokens.append(torch.tensor(audio_tokens)) audios_processed.append(torch.tensor(audio)) - return { - "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), - "audio_arrays": audios_processed, - } + return BatchFeature( + { + "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": audios_processed, + } + ) class VoxtralProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): @@ -190,11 +204,11 @@ class VoxtralProcessingInfo(BaseProcessingInfo): def get_max_audio_array_len(self) -> int: processor = self.get_hf_processor() return self.get_max_audio_tokens() * int( - processor.sampling_rate // processor.frame_rate) + processor.sampling_rate // processor.frame_rate + ) class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -202,25 +216,30 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) target_length = self.info.get_max_audio_array_len() + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "audio": - self._get_dummy_audios(length=target_length, num_audios=num_audios) + "audio": self._get_dummy_audios( + length=target_length, num_audios=num_audios, overrides=audio_overrides + ) } def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() dummy_text = self.get_dummy_text(mm_counts) - dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_audios = dummy_mm_data.get("audio", []) audio_chunks: list[AudioChunk] = [] @@ -234,9 +253,11 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item)) audio_chunks.append(chunk) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens # whixtral tokenizer adds padding to the audio @@ -246,9 +267,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) -class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] - ): - +class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], @@ -288,12 +307,14 @@ class 
VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template @@ -304,17 +325,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] return MultiModalDataParser(target_sr=sampling_rate) -@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor, - info=VoxtralProcessingInfo, - dummy_inputs=VoxtralDummyInputsBuilder) -class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsTranscription): +@MULTIMODAL_REGISTRY.register_processor( + VoxtralMultiModalProcessor, + info=VoxtralProcessingInfo, + dummy_inputs=VoxtralDummyInputsBuilder, +) +class VoxtralForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription +): + merge_by_field_config = True + supported_languages = ISO639_1_SUPPORTED_LANGS + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) + + # update quant config so that ignored module and target module names + # match the vLLM model names + if hasattr(vllm_config, "quant_config"): + vllm_config.quant_config = self.maybe_update_quant_config( + vllm_config.quant_config + ) + config = vllm_config.model_config.hf_config self.config = config self.downsample_factor = self.config.audio_config.downsample_factor @@ -336,6 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model + def get_mm_mapping(self) -> MultiModelKeys: + """Get module prefix for multimodal models to filter LoRA modules.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="audio_language_adapter", + tower_model=["whisper_encoder"], + ) + def forward( self, input_ids: torch.Tensor, @@ -347,25 +393,15 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility.
- elif inputs_embeds is None: - audio_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - audio_embeddings) - input_ids = None - - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def get_multimodal_embeddings( self, **kwargs - ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], - None]: + ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], None]: audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) if audio_inputs is None: return None @@ -376,50 +412,37 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, seq_len, dim = audio_embedding.shape # Pad such that seq_len is divisible by downsample_factor target_seq_len = self.downsample_factor * math.ceil( - seq_len / self.downsample_factor) + seq_len / self.downsample_factor + ) audio_embedding = torch.nn.functional.pad( audio_embedding, (0, 0, 0, target_seq_len - seq_len), ) audio_embeddings[i] = audio_embedding.reshape( - target_seq_len // self.downsample_factor, - dim * self.downsample_factor) + target_seq_len // self.downsample_factor, dim * self.downsample_factor + ) # Concat, project and resplit audio_embeddings_packed = torch.cat(audio_embeddings, dim=0) - audio_embeddings_packed = self.audio_language_adapter( - audio_embeddings_packed) - audio_embeddings = torch.split(audio_embeddings_packed, - [a.shape[0] for a in audio_embeddings], - dim=0) + audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed) + audio_embeddings = torch.split( + audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0 + ) return audio_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - audio_encoder = self.tokenizer.instruct.audio_encoder - audio_tok_id = audio_encoder.audio_token - - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id) - return inputs_embeds - def _parse_and_validate_audio_arrays( - self, **kwargs: object) -> Union[list[torch.Tensor], None]: + self, **kwargs: object + ) -> Union[list[torch.Tensor], None]: audio_arrays = kwargs.pop("audio_arrays", None) if audio_arrays is None: return None if not isinstance(audio_arrays, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_arrays. " - f"Got type: {type(audio_arrays)}") + raise ValueError( + f"Incorrect type of audio_arrays. 
Got type: {type(audio_arrays)}" + ) - audio_arrays = flatten_bn(audio_arrays) if isinstance(audio_arrays, torch.Tensor): audio_arrays = list(audio_arrays.unbind(0)) return audio_arrays @@ -427,14 +450,13 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) @classmethod - def get_speech_to_text_config(cls, model_config: ModelConfig, - task_type: str) -> SpeechToTextConfig: + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: tokenizer = cached_tokenizer_from_config(model_config) audio_config = tokenizer.instruct.audio_encoder.audio_config max_audio_clip_s = audio_config.chunk_length_s @@ -448,17 +470,23 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, @classmethod # for speech-to-text transcription - def get_generation_prompt(cls, audio: np.ndarray, - model_config: ModelConfig, - stt_config: SpeechToTextConfig, - language: Optional[str], task_type: str, - request_prompt: str) -> PromptType: + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) - audio = Audio(audio, int(stt_config.sample_rate), - format="wav") # lossless - req = TranscriptionRequest(model=model_config.model, - audio=RawAudio.from_audio(audio), - language=language) + audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless + req = TranscriptionRequest( + model=model_config.model, + audio=RawAudio.from_audio(audio), + language=language, + ) tokenized = tokenizer.instruct.encode_transcription(req) audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) @@ -467,35 +495,44 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, return cast(PromptType, prompts_dict) @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: """ - Map from audio duration to number of audio tokens produced by the ASR + Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. This is used for estimating the amount of processing for this audio. 
""" tokenizer = cached_tokenizer_from_config(model_config) adapter = VoxtralProcessorAdapter(tokenizer) return adapter.get_num_audio_tokens( - int(audio_duration_s * stt_config.sample_rate)) + int(audio_duration_s * stt_config.sample_rate) + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - # fmt: off + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: remapping_rules = [ (r"mm_whisper_embeddings\.(.*)", r"\1"), (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"), - (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501 - (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501 + ( + r"audio_language_adapter\.0\.weight", + r"audio_language_adapter.w_in.weight", + ), + ( + r"audio_language_adapter\.2\.weight", + r"audio_language_adapter.w_out.weight", + ), ] - # fmt: on audio_params = dict( - nn.ModuleDict({ - "audio_language_adapter": - self.audio_language_adapter, - }).named_parameters()) + nn.ModuleDict( + { + "audio_language_adapter": self.audio_language_adapter, + } + ).named_parameters() + ) loaded_weights = set() @@ -503,10 +540,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, nonlocal loaded_weights for name, w in weights: is_encoder = ( - name.startswith("mm_whisper_embeddings") and - not name.startswith("mm_whisper_embeddings.tok_embeddings") + name.startswith("mm_whisper_embeddings") + and not name.startswith("mm_whisper_embeddings.tok_embeddings") and not name.startswith( - "mm_whisper_embeddings.audio_language_projection")) + "mm_whisper_embeddings.audio_language_projection" + ) + ) for pattern, repl in remapping_rules: if re.fullmatch(pattern, name): @@ -536,9 +575,97 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, return loaded_weights + def maybe_update_quant_config( + self, quant_config: QuantizationConfig + ) -> QuantizationConfig: + """ + Update quant config to so that ignored module and target module names + match the vLLM model names. + Right now this is specific for compressed-tensors format and + load_format mistral. 
+ """ + remapping_rules = [ + (r"output", r"language_model.lm_head"), + ( + r"layers\.(\d+)\.attention\.wo", + r"language_model.model.layers.\1.self_attn.out_proj", + ), + ( + r"layers\.(\d+)\.attention\.w(.*)", + r"language_model.model.layers.\1.self_attn.\2_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w1", + r"language_model.model.layers.\1.mlp.gate_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w2", + r"language_model.model.layers.\1.mlp.down_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w3", + r"language_model.model.layers.\1.mlp.up_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", + r"whisper_encoder.whisper_encoder.conv1", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", + r"whisper_encoder.whisper_encoder.conv2", + ), + ( + r"mm_whisper_embeddings\.audio_language_projection\.0", + r"audio_language_adapter.w_in", + ), + ( + r"mm_whisper_embeddings\.audio_language_projection\.2", + r"audio_language_adapter.w_out", + ), + ] + + # Update ignore list + if hasattr(quant_config, "ignore"): + mistral_ignore = [] + for name in quant_config.ignore: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + mistral_ignore.append(mistral_name) + quant_config.ignore = mistral_ignore + + # Update target list + if hasattr(quant_config, "config_groups"): + config_groups = quant_config.config_groups + for group_name in config_groups: + if "targets" in config_groups[group_name]: + targets = [] + for name in config_groups[group_name]["targets"]: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + targets.append(mistral_name) + config_groups[group_name]["targets"] = targets + quant_config.config_groups = config_groups + + return quant_config + class AudioLanguageAdapter(nn.Module): - def __init__(self, hidden_size: int, dim: int) -> None: super().__init__() self.w_in = nn.Linear(hidden_size, dim, bias=False) @@ -552,19 +679,44 @@ class AudioLanguageAdapter(nn.Module): class VoxtralEncoderModel(nn.Module): packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - # fmt: off mistral_remapping = [ - (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501 - (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501 
- (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501 + ( + r"whisper_encoder\.conv_layers\.0\.(weight|bias)", + r"whisper_encoder.conv1.\1", + ), + ( + r"whisper_encoder\.conv_layers\.1\.(weight|bias)", + r"whisper_encoder.conv2.\1", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.\2_proj.\3", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.out_proj.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc1.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc2.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", + r"whisper_encoder.layers.\1.final_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.norm\.(weight|bias)", + r"whisper_encoder.layer_norm.\1", + ), ] - # fmt: on def __init__( self, @@ -575,11 +727,11 @@ class VoxtralEncoderModel(nn.Module): super().__init__() self.config = cast(WhisperConfig, vllm_config.model_config.hf_config) self.dtype: torch.dtype = vllm_config.model_config.dtype - self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "whisper_encoder"), - is_standalone_encoder=True, - init_in_fp32=True) + self.whisper_encoder = WhisperEncoder( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "whisper_encoder"), + init_in_fp32=True, + ) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, num_mel_bins=self.config.num_mel_bins, @@ -594,8 +746,7 @@ class VoxtralEncoderModel(nn.Module): audio_waveforms: torch.Tensor, ) -> torch.Tensor: input_dtype = audio_waveforms.dtype - window = torch.hann_window(self.config.window_size).to( - audio_waveforms.device) + window = torch.hann_window(self.config.window_size).to(audio_waveforms.device) stft = torch.stft( audio_waveforms, self.config.window_size, @@ -603,7 +754,7 @@ class VoxtralEncoderModel(nn.Module): window=window, return_complex=True, ) - magnitudes = stft[..., :-1].abs()**2 + magnitudes = stft[..., :-1].abs() ** 2 mel_spec = self.mel_filters.T @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) @@ -612,8 +763,9 @@ class VoxtralEncoderModel(nn.Module): @property def downsample_factor(self) -> int: - return self.whisper_encoder.conv1.stride[ - 0] * self.whisper_encoder.conv2.stride[0] + return ( + self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0] + ) @property def chunk_size(self) -> int: @@ -647,8 +799,7 @@ class VoxtralEncoderModel(nn.Module): input_features = [input_features] # Split long inputs into chunks - input_embeds, chunks_per_example = ( - self.prepare_inputs_for_conv(input_features)) + input_embeds, chunks_per_example = 
self.prepare_inputs_for_conv(input_features) # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size] out = self.whisper_encoder([input_embeds]) @@ -657,7 +808,7 @@ class VoxtralEncoderModel(nn.Module): chunk_idx = 0 results = [] for n_chunks in chunks_per_example: - result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1) + result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1) results.append(result) chunk_idx += n_chunks @@ -677,7 +828,7 @@ class VoxtralEncoderModel(nn.Module): if re.fullmatch(pattern, name): name = re.sub(pattern, repl, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -688,8 +839,7 @@ class VoxtralEncoderModel(nn.Module): break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) return name diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 16bbe2f2010a1..397556cbbcc47 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,47 +4,64 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import numpy as np import torch from torch import nn -from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, - WhisperProcessor) +from transformers import ( + BatchFeature, + WhisperConfig, + WhisperFeatureExtractor, + WhisperProcessor, +) from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention -from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, - VllmConfig) +from vllm.attention.layers.cross_attention import CrossAttention +from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + 
MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptReplacement, PromptUpdate) +from vllm.multimodal.processing import ( + BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - make_layers) +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + cast_overflow_tensors, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -107,17 +124,53 @@ ISO639_1_SUPPORTED_LANGS = { "uk": "Ukrainian", "ur": "Urdu", "vi": "Vietnamese", - "cy": "Welsh" + "cy": "Welsh", } -class WhisperAudioInputs(TypedDict): - input_features: NestedTensors - """Shape: `(batch_size, 128, M)`""" +class WhisperAudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[ + Optional[list[torch.Tensor]], + TensorShape("b", "nmb", "t"), + ] + + +class WhisperEncoderAttention(MultiHeadAttention): + """Multi-headed attention for Whisper encoder with 2D tensor support.""" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """ + Input shape: batch_size x seq_len x hidden_size + or seq_len x hidden_size + """ + is_2d = query.dim() == 2 + if is_2d: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + # Call the parent forward method + out = super().forward(query, key, value) + + if is_2d: + out = out.squeeze(0) + + return out class WhisperPositionalEmbedding(nn.Embedding): - def __init__(self, num_positions: int, embedding_dim: int): super().__init__(num_positions, embedding_dim) @@ -126,7 +179,6 @@ class WhisperPositionalEmbedding(nn.Embedding): class WhisperAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -136,7 +188,6 @@ class WhisperAttention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - standalone_encoder: bool = False, ): super().__init__() self.embed_dim = embed_dim @@ -161,7 +212,8 @@ class WhisperAttention(nn.Module): if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: " - f"{self.embed_dim} and `num_heads`: {num_heads}).") + f"{self.embed_dim} and `num_heads`: {num_heads})." 
+ ) self.scaling = self.head_dim**-0.5 self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) @@ -172,14 +224,25 @@ class WhisperAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - if standalone_encoder: - self.attn = MultiHeadAttention( + if attn_type == AttentionType.ENCODER: + self.attn = WhisperEncoderAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, ) - else: + elif self.attn_type == AttentionType.ENCODER_DECODER: + self.attn = CrossAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + else: # AttentionType.DECODER (regular decoder self-attention) self.attn = Attention( self.num_heads, self.head_dim, @@ -223,7 +286,6 @@ class WhisperAttention(nn.Module): class WhisperCrossAttention(WhisperAttention): - def __init__( self, embed_dim: int, @@ -290,7 +352,6 @@ class WhisperCrossAttention(WhisperAttention): class WhisperMLP(nn.Module): - def __init__( self, embed_dim: int, @@ -323,12 +384,7 @@ class WhisperMLP(nn.Module): class WhisperEncoderLayer(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - is_standalone_encoder: bool = False): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -342,7 +398,6 @@ class WhisperEncoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - standalone_encoder=is_standalone_encoder, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.mlp = WhisperMLP( @@ -373,7 +428,6 @@ class WhisperEncoderLayer(nn.Module): class WhisperDecoderLayer(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -433,52 +487,39 @@ class WhisperDecoderLayer(nn.Module): class WhisperEncoder(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - is_standalone_encoder: bool = False, - init_in_fp32: bool = False): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False + ): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model - self.is_standalone_encoder = is_standalone_encoder self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions - self.embed_scale = (math.sqrt(embed_dim) - if config.scale_embedding else 1.0) + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.conv1 = nn.Conv1d(self.num_mel_bins, - embed_dim, - kernel_size=3, - padding=1) - self.conv2 = nn.Conv1d(embed_dim, - embed_dim, - kernel_size=3, - stride=2, - padding=1) + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, - lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers", - is_standalone_encoder= - is_standalone_encoder), + lambda prefix: WhisperEncoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) - maybe_fp32_init_ctx = set_default_torch_dtype( - torch.float32) if init_in_fp32 else 
nullcontext() + maybe_fp32_init_ctx = ( + set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext() + ) with ( - torch.no_grad(), - maybe_fp32_init_ctx, + torch.no_grad(), + maybe_fp32_init_ctx, ): - self.embed_positions = nn.Embedding(self.max_source_positions, - embed_dim) + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) self.embed_positions.weight.copy_( - sinusoids(*self.embed_positions.weight.shape)) + sinusoids(*self.embed_positions.weight.shape) + ) def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): hidden_states = [] @@ -486,9 +527,9 @@ class WhisperEncoder(nn.Module): embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) embeds = embeds.transpose(-1, -2) - embeds = (embeds + - self.embed_positions.weight[:embeds.size(-2), :]).to( - embeds.dtype) + embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to( + embeds.dtype + ) hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) @@ -500,7 +541,6 @@ class WhisperEncoder(nn.Module): class WhisperDecoder(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -508,17 +548,19 @@ class WhisperDecoder(nn.Module): self.padding_idx = config.pad_token_id self.max_target_positions = config.max_target_positions self.max_source_positions = config.max_source_positions - self.embed_scale = (math.sqrt(config.d_model) - if config.scale_embedding else 1.0) + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, - self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.d_model, self.padding_idx + ) self.embed_positions = WhisperPositionalEmbedding( - self.max_target_positions, config.d_model) + self.max_target_positions, config.d_model + ) self.start_layer, self.end_layer, self.layers = make_layers( config.decoder_layers, - lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers"), + lambda prefix: WhisperDecoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -542,21 +584,19 @@ class WhisperDecoder(nn.Module): hidden_states = self.layer_norm(hidden_states) return hidden_states - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) class WhisperModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.encoder = WhisperEncoder(vllm_config=vllm_config, - prefix=f"{prefix}.encoder") - self.decoder = WhisperDecoder(vllm_config=vllm_config, - prefix=f"{prefix}.decoder") + self.encoder = WhisperEncoder( + vllm_config=vllm_config, prefix=f"{prefix}.encoder" + ) + self.decoder = WhisperDecoder( + vllm_config=vllm_config, prefix=f"{prefix}.decoder" + ) def forward( self, @@ -580,8 +620,7 @@ class WhisperModel(nn.Module): return None return self.encoder(input_features) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), @@ -611,15 +650,13 @@ class WhisperModel(nn.Module): 
continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class WhisperProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> WhisperConfig: return self.ctx.get_hf_config(WhisperConfig) @@ -636,8 +673,7 @@ class WhisperProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1} - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) @@ -648,7 +684,6 @@ class WhisperProcessingInfo(BaseProcessingInfo): class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -658,6 +693,7 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() @@ -665,15 +701,16 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): audio_len = feature_extractor.chunk_length * sampling_rate num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + return { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } -class WhisperMultiModalProcessor( - EncDecMultiModalProcessor[WhisperProcessingInfo]): - +class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -740,11 +777,15 @@ class WhisperMultiModalProcessor( ] -@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, - info=WhisperProcessingInfo, - dummy_inputs=WhisperDummyInputsBuilder) -class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal, SupportsV0Only): +@MULTIMODAL_REGISTRY.register_processor( + WhisperMultiModalProcessor, + info=WhisperProcessingInfo, + dummy_inputs=WhisperDummyInputsBuilder, +) +class WhisperForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + merge_by_field_config = True packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -754,10 +795,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - ".fc1.": ".mlp.fc1.", - ".fc2.": ".mlp.fc2." - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."} + ) # Whisper only supports audio-conditioned generation. supports_transcription_only = True @@ -772,22 +812,26 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, logger.warning( "Defaulting to language='en'. 
If you wish to transcribe " "audio in a different language, pass the `language` field " - "in the TranscriptionRequest.") + "in the TranscriptionRequest." + ) language = "en" return super().validate_language(language) @classmethod def get_generation_prompt( - cls, - audio: np.ndarray, - model_config: ModelConfig, # not needed here - stt_config: SpeechToTextConfig, - language: Optional[str], - task_type: str, - request_prompt: str) -> PromptType: + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: if language is None: raise ValueError( - "Language must be specified when creating the Whisper prompt") + "Language must be specified when creating the Whisper prompt" + ) prompt = { "encoder_prompt": { # Whisper does not support encoder prompt. @@ -796,10 +840,11 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, "audio": (audio, stt_config.sample_rate), }, }, - "decoder_prompt": - ((f"<|prev|>{request_prompt}" if request_prompt else "") + - f"<|startoftranscript|><|{language}|>" + - f"<|{task_type}|><|notimestamps|>") + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), } return cast(PromptType, prompt) @@ -811,8 +856,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, raise ValueError("Only audio modality is supported") @classmethod - def get_speech_to_text_config(cls, model_config: ModelConfig, - task_type: str) -> SpeechToTextConfig: + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: processor = cached_get_processor(model_config.model) return SpeechToTextConfig( @@ -821,9 +867,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ) @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: processor = cached_get_processor(model_config.model) hop_length = processor.feature_extractor.hop_length assert hop_length is not None @@ -831,8 +880,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, # prompts directly at least not to Whisper. # One indicator of the encoder amount of processing # is the log-mel spectogram length. 
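        # Illustrative arithmetic only (assuming the usual Whisper feature
        # extractor settings of sampling_rate=16000 and hop_length=160, which
        # are not spelled out here): 30 s of audio maps to
        # ceil(30 * 16000 / 160) = 3000 log-mel frames, which is the value
        # returned below as the per-request encoder token estimate.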
- return math.ceil(audio_duration_s * stt_config.sample_rate / - hop_length) + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -843,14 +891,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.unpadded_vocab_size = config.vocab_size - self.proj_out = ParallelLMHead(config.vocab_size, - config.d_model, - quant_config=quant_config) - self.proj_out = self.proj_out.tie_weights( - self.model.decoder.embed_tokens) + self.proj_out = ParallelLMHead( + config.vocab_size, + config.d_model, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "proj_out"), + ) + self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) def forward( self, @@ -869,44 +920,36 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def get_language_model(self) -> torch.nn.Module: return self.model.decoder - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - # TODO: This method just returns the decoder sequence embeddings since - # Whisper does not have encoder text tokens. Refactor this once - # encoder/decoder support is implemented in V1. + # This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. return self.model.decoder.get_input_embeddings(input_ids) - def _parse_and_validate_audio_input( - self, **kwargs: object) -> WhisperAudioInputs: + def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs: input_features = kwargs.pop("input_features", None) if input_features is not None: - if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. 
" - f"Got type: {type(input_features)}") - input_features = torch.cat( - [feat.to(self.dtype) for feat in input_features]) + input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features) return WhisperAudioInputs(input_features=input_features) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.proj_out, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) # add fake zeros bias for k_proj to state_dict @@ -915,7 +958,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def _create_fake_bias_for_k_proj( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: """ Create full zeros bias for k_proj weight in self-attn and x-attn layers. diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ed65944c109bd..b69204d020962 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -2,46 +2,47 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch Zamba2 model implementation for vLLM. -This module implements the Zamba2 architecture from -https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer -architectures in a hybrid model optimized for efficient sequence modeling. The +This module implements the Zamba2 architecture from +https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer +architectures in a hybrid model optimized for efficient sequence modeling. The model alternates between state space model layers and attention-based layers. 
""" + from collections.abc import Iterable from itertools import cycle -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch import nn from transformers import Zamba2Config -from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -50,7 +51,7 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix class Zamba2LoRA(nn.Module): """LoRA layer for the Zamba2 model. - + Implements a LoRA layer that is used in shared attention and gated MLP blocks. """ @@ -64,7 +65,7 @@ class Zamba2LoRA(nn.Module): prefix: str = "", ): """Initialize the attention layer. - + Args: input_dim: input dimension rank: LoRA rank @@ -73,20 +74,15 @@ class Zamba2LoRA(nn.Module): """ super().__init__() - self.A = ColumnParallelLinear(input_dim, - rank, - bias=False, - quant_config=quant_config, - gather_output=True) + self.A = ColumnParallelLinear( + input_dim, rank, bias=False, quant_config=quant_config, gather_output=True + ) if isinstance(output_dim, list): B_class = MergedColumnParallelLinear else: B_class = ColumnParallelLinear - self.B = B_class(rank, - output_dim, - bias=False, - quant_config=quant_config) + self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config) def forward( self, @@ -99,8 +95,8 @@ class Zamba2LoRA(nn.Module): class Zamba2Attention(nn.Module): """Multi-head attention mechanism for the Zamba2 model. - - Implements attention with parallel computation, QKV projections, optional + + Implements attention with parallel computation, QKV projections, optional adapters and rotary position embeddings. The attention is computed across distributed blocks for efficient processing. 
""" @@ -115,7 +111,7 @@ class Zamba2Attention(nn.Module): prefix: str = "", ) -> None: """Initialize the attention layer. - + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare attention block @@ -136,15 +132,17 @@ class Zamba2Attention(nn.Module): self.num_attention_heads = config.num_attention_heads // tp_size self.attention_head_dim = config.attention_head_dim self.qkv_size = self.attention_hidden_size // tp_size - self.scale = (self.attention_head_dim / 2)**-0.5 + self.scale = (self.attention_head_dim / 2) ** -0.5 - if (self.attention_head_dim * - self.total_num_attention_heads) != self.attention_hidden_size: + if ( + self.attention_head_dim * self.total_num_attention_heads + ) != self.attention_hidden_size: raise ValueError( f"attention_hidden_size must be divisible by" f" num_attention_heads" f" (got `attention_hidden_size`: {self.attention_hidden_size}" - f" and `num_heads`: {self.num_attention_heads}).") + f" and `num_heads`: {self.num_attention_heads})." + ) self.qkv_proj = QKVParallelLinear( self.attention_hidden_size, @@ -153,10 +151,12 @@ class Zamba2Attention(nn.Module): bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.attention_hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.attention_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) # Even though in Zamba2 weights are shared between attention layers, KV # cache is unique for every attention layer. Hence, we need to define @@ -165,8 +165,11 @@ class Zamba2Attention(nn.Module): # Initialize attention blocks with proper indexing self.dpa_list = nn.ModuleList([]) - j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks - - 1) // config.num_mem_blocks + j = ( + bare_block_idx + * (self.num_hybrid_layers + config.num_mem_blocks - 1) + // config.num_mem_blocks + ) for block_idx in range(self.num_hybrid_layers): if block_idx % config.num_mem_blocks == bare_block_idx: dpa = Attention( @@ -233,18 +236,17 @@ class Zamba2Attention(nn.Module): position_ids: torch.Tensor, ) -> torch.Tensor: """Forward pass through the attention layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] position_ids: Position IDs for positional embeddings block_idx: Current shared transformer block index - + Returns: Output tensor [batch_size, seq_len, hidden_size] """ qkv, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, - dim=-1) + query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1) if self.config.use_shared_attention_adapter: # Apply adapter transformations to Q, K, V if enabled @@ -264,9 +266,9 @@ class Zamba2Attention(nn.Module): value_states = value_states + v_lora_output if self.config.use_mem_rope: - query_states, key_states = self.rotary_emb(position_ids, - query_states, - key_states) + query_states, key_states = self.rotary_emb( + position_ids, query_states, key_states + ) y = self.dpa_list[block_idx](query_states, key_states, value_states) y, _ = self.o_proj(y) @@ -275,9 +277,9 @@ class Zamba2Attention(nn.Module): class Zamba2MLP(nn.Module): """Feed-forward MLP layer for the Zamba2 model. 
- - Implements a gated feed-forward network that projects inputs to a larger - intermediate size, applies GELU activation with gating, then projects back + + Implements a gated feed-forward network that projects inputs to a larger + intermediate size, applies GELU activation with gating, then projects back to the original size. Includes optional adapter layers for model adaptation. """ @@ -290,7 +292,7 @@ class Zamba2MLP(nn.Module): prefix: str = "", ) -> None: """Initialize the MLP layer. - + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block in the model @@ -309,17 +311,22 @@ class Zamba2MLP(nn.Module): self.hidden_size, 2 * [self.intermediate_size], # 2x for gate and input projections bias=self.config.add_bias_linear, - quant_config=quant_config) + quant_config=quant_config, + ) - self.down_proj = RowParallelLinear(self.intermediate_size, - self.hidden_size, - bias=self.config.add_bias_linear, - quant_config=quant_config) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config, + ) # Only allow GELU activations if config.hidden_act != "gelu": - raise ValueError(f"Only GELU activation is supported " - f"(got `hidden_act`: {config.hidden_act})") + raise ValueError( + f"Only GELU activation is supported " + f"(got `hidden_act`: {config.hidden_act})" + ) self.act_fn = GeluAndMul() # Initialize adapter layers @@ -336,14 +343,13 @@ class Zamba2MLP(nn.Module): gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) - def forward(self, hidden_states: torch.Tensor, - block_idx: int) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, block_idx: int) -> torch.Tensor: """Forward pass through the MLP layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] block_idx: Current shared transformer block index - + Returns: Output tensor [batch_size, seq_len, hidden_size] after applying gated feed-forward transformation @@ -367,7 +373,7 @@ class Zamba2MLP(nn.Module): class Zamba2AttentionDecoderLayer(nn.Module): """Single decoder layer combining attention and feed-forward networks. - + This layer implements a standard transformer block with: - Input layer normalization - Multi-head self-attention @@ -385,7 +391,7 @@ class Zamba2AttentionDecoderLayer(nn.Module): prefix: str = "", ) -> None: """Initialize the decoder layer. - + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block @@ -416,11 +422,9 @@ class Zamba2AttentionDecoderLayer(nn.Module): # Initialize layer normalizations # Input normalization operates on concatenated states - self.input_layernorm = RMSNorm(2 * config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(2 * config.hidden_size, eps=config.rms_norm_eps) # Pre-FF normalization operates on attention output - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -430,14 +434,14 @@ class Zamba2AttentionDecoderLayer(nn.Module): positions: torch.Tensor, ) -> torch.Tensor: """Forward pass through the decoder layer. 
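        Note (shape bookkeeping, with a hypothetical ``hidden_size=2560`` for
        illustration): concatenating ``hidden_states`` with
        ``original_hidden_states`` doubles the channel dimension to ``5120``,
        which is why ``input_layernorm`` is constructed with
        ``2 * config.hidden_size``.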
- + Args: hidden_states: Input tensor from previous layer - original_hidden_states: Original input tensor for residual + original_hidden_states: Original input tensor for residual connection block_idx: Current shared transformer block index positions: IDs for positional embeddings - + Returns: Transformed hidden states after attention and feed-forward """ @@ -447,7 +451,8 @@ class Zamba2AttentionDecoderLayer(nn.Module): # The concatenated tensor is then used as input of the pre-attention # RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). hidden_states = torch.concatenate( - [hidden_states, original_hidden_states], dim=-1) + [hidden_states, original_hidden_states], dim=-1 + ) # Layer norm before attention hidden_states = self.input_layernorm(hidden_states) @@ -470,20 +475,22 @@ class Zamba2AttentionDecoderLayer(nn.Module): class Zamba2MambaDecoderLayer(nn.Module): """Single Mamba decoder layer with normalization. - - This implements a Mamba block. It includes input normalization - and can process sequences using either chunked or full + + This implements a Mamba block. It includes input normalization + and can process sequences using either chunked or full computation depending on configuration. """ - def __init__(self, - config: Zamba2Config, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Zamba2Config, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: """Initialize the Mamba decoder layer. - + Args: config: The Zamba2 model configuration quant_config: Configuration for model quantization @@ -492,49 +499,43 @@ class Zamba2MambaDecoderLayer(nn.Module): # Initialize Mamba mixer with expanded intermediate size intermediate_size = config.mamba_expand * config.hidden_size - self.mamba = MambaMixer2(hidden_size=config.hidden_size, - ssm_state_size=config.mamba_d_state, - conv_kernel_size=config.mamba_d_conv, - intermediate_size=intermediate_size, - use_conv_bias=config.use_conv_bias, - use_bias=config.add_bias_linear, - n_groups=config.mamba_ngroups, - num_heads=config.n_mamba_heads, - head_dim=intermediate_size // - config.n_mamba_heads, - rms_norm_eps=config.rms_norm_eps, - activation="silu", - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) # Input normalization - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, transformer_hidden_states: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None, ) -> 
torch.Tensor: """Forward pass through the Mamba decoder layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) - sequence_idx: Index tensor for identifying sequences in batch - Required for proper chunked processing in prefill transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) original_hidden_states: Optional original inputs (unused in Mamba) - + Returns: Transformed hidden states with residual connection applied """ @@ -558,8 +559,6 @@ class Zamba2MambaDecoderLayer(nn.Module): self.mamba( hidden_states, output, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) # residual connection after mamba @@ -570,7 +569,7 @@ class Zamba2MambaDecoderLayer(nn.Module): class Zamba2HybridLayer(nn.Module): """Hybrid layer combining Transformer and Mamba architectures. - + This layer implements the hybrid architecture described in the Zamba paper, where a shared transformer pathway processes input in parallel with a Mamba pathway. The transformer output is projected and added to the Mamba input @@ -588,51 +587,47 @@ class Zamba2HybridLayer(nn.Module): prefix: str = "", ) -> None: """Initialize the hybrid layer. - + Args: shared_transformer: Transformer decoder layer for attention pathway - linear: Linear projection for transformer output before Mamba - mamba: Mamba decoder layer for state space pathway """ super().__init__() self.block_idx = block_idx self.shared_transformer = shared_transformer - self.linear = ReplicatedLinear(config.hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) - self.mamba_decoder = Zamba2MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + self.linear = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.mamba_decoder = Zamba2MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, - mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: """Forward pass through the hybrid layer. - + Processes input through parallel transformer and Mamba paths: 1. Transformer path processes input with attention 2. Transformer output is projected to match hidden size 3. Projected output is added to Mamba path input 4. 
Final output combines both paths' representations - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - original_hidden_states: Original input for transformer residual + original_hidden_states: Original input for transformer residual connection positions: Position IDs for positional embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) - sequence_idx: Indices for identifying sequences in batch, - required for proper chunked processing in prefill - + Returns: Output tensor combining transformer and Mamba representations """ @@ -651,8 +646,6 @@ class Zamba2HybridLayer(nn.Module): layer_outputs = self.mamba_decoder( hidden_states, transformer_hidden_states=transformer_hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) return layer_outputs @@ -661,16 +654,16 @@ class Zamba2HybridLayer(nn.Module): @support_torch_compile class Zamba2Model(nn.Module): """Core Zamba2 model combining transformer and Mamba architectures. - - The model processes input through a sequence of hybrid and Mamba-only + + The model processes input through a sequence of hybrid and Mamba-only layers, using token embeddings and final layer normalization. """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model. - + Args: - vllm_config: Configuration object containing model, cache, + vllm_config: Configuration object containing model, cache, quantization and LoRA settings prefix: Optional prefix for parameter names in state dict """ @@ -685,8 +678,11 @@ class Zamba2Model(nn.Module): assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -704,15 +700,19 @@ class Zamba2Model(nn.Module): } # Create cyclic iterator of transformer blocks - blocks = cycle([ - Zamba2AttentionDecoderLayer(config, - bare_block_idx=idx, - num_hybrid_layers=len(layer2block_map), - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}") - for idx in range(config.num_mem_blocks) - ]) + blocks = cycle( + [ + Zamba2AttentionDecoderLayer( + config, + bare_block_idx=idx, + num_hybrid_layers=len(layer2block_map), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}", + ) + for idx in range(config.num_mem_blocks) + ] + ) # Initialize layers according to block type configuration layers = [] @@ -724,32 +724,37 @@ class Zamba2Model(nn.Module): block = next(blocks) block_idx = layer2block_map[layer_idx] layers.append( - Zamba2HybridLayer(block, - config, - block_idx, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix)) + Zamba2HybridLayer( + block, + config, + block_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + ) else: layers.append( - Zamba2MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix)) + Zamba2MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + ) self.layers = nn.ModuleList(layers) # Final layer normalization - self.final_layernorm = RMSNorm(config.hidden_size, - 
eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. - + Args: input_ids: Tensor of input token IDs - + Returns: Embedded representation of the input tokens """ @@ -759,20 +764,17 @@ class Zamba2Model(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """Forward pass through the model. - + Args: input_ids: Input token IDs positions: Position IDs for embeddings - mamba_cache_params: Parameters for Mamba's state caches - (one for conv, one for ssm) inputs_embeds: Optional pre-computed input embeddings - + Returns: - Either final hidden states or intermediate tensors for pipeline + Either final hidden states or intermediate tensors for pipeline parallelism """ # Handle pipeline parallelism for first rank @@ -780,41 +782,20 @@ class Zamba2Model(nn.Module): inputs_embeds = self.get_input_embeddings(input_ids) hidden_states = inputs_embeds - attn_metadata = get_forward_context().attn_metadata - - if not envs.VLLM_USE_V1: - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) - else: - # v1 get mamba2_metadata from forward_context - mamba2_metadata = None - # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): - - layer_mamba_cache_params = None - if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer)) - and mamba_cache_params): - layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - layer_idx) - layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, positions=positions, - mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata, ) hidden_states = layer_outputs hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -828,8 +809,7 @@ class Zamba2Model(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in chkpt_weight_name: continue - chkpt_weight_name = chkpt_weight_name.replace( - weight_name, param_name) + chkpt_weight_name = chkpt_weight_name.replace(weight_name, param_name) param = params_dict[chkpt_weight_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -838,8 +818,7 @@ class Zamba2Model(nn.Module): if chkpt_weight_name not in params_dict: continue param = params_dict[chkpt_weight_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(chkpt_weight_name) return loaded_params @@ -847,26 +826,28 @@ class Zamba2Model(nn.Module): class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """Zamba2 model with causal language modeling head. 
- + This class wraps the core Zamba2 model and adds: - A language modeling head for next token prediction - Mamba state caching functionality - Support for model parallelism and quantization - Sampling capabilities for text generation """ + # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - "A_log": "A", - "0.weight": "A.weight", - "1.weight": "B.weight", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "A_log": "A", + "0.weight": "A.weight", + "1.weight": "B.weight", + } + ) @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -877,13 +858,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. Args: vllm_config: vLLM config - use_v1: Get shapes for V1 (or V0) Returns: Tuple containing: @@ -903,27 +882,23 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): head_dim=hf_config.mamba_headdim, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. - + Args: vllm_config: Configuration containing model, cache, quantization, LoRA and scheduler settings prefix: Optional prefix for parameter names - + Raises: - AssertionError: If prefix caching is enabled (not supported by - Mamba) + AssertionError: If prefix caching is enabled + (not supported by Mamba) """ config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config @@ -935,8 +910,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): self.unpadded_vocab_size += lora_config.lora_extra_vocab_size # Initialize core model - self.model = Zamba2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Zamba2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) # Initialize language modeling head self.lm_head = ParallelLMHead( @@ -946,17 +922,17 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - # Initialize logits processing and sampling - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. 
@@ -967,96 +943,48 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """ return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: """Forward pass through the model. - + Args: input_ids: Input token IDs positions: Position IDs for embeddings inputs_embeds: Optional pre-computed input embeddings **kwargs: Additional arguments passed to cache manager - + Returns: Output hidden states """ - # Initialize Mamba cache if needed - mamba_cache_params = None - if not envs.VLLM_USE_V1: - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers - mamba_state_shape = \ - self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - mamba_state_dtype = \ - self.get_mamba_state_dtype_from_config( - self.vllm_config) - self.mamba_cache = MambaCacheManager(self.vllm_config, - num_mamba_layers, - *mamba_state_shape, - *mamba_state_dtype) - - # Get cache parameters for current run - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - # Forward pass through model hidden_states = self.model( input_ids, positions, - mamba_cache_params, inputs_embeds, ) return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str, - torch.Tensor], - **kwargs) -> dict[str, torch.Tensor]: - """Copy inputs before CUDA graph capture. - - Args: - input_buffers: Dictionary of input tensors - **kwargs: Additional arguments passed to cache manager - - Returns: - Updated input buffers - """ - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs( - self, batch_size: int) -> dict[str, torch.Tensor]: - """Get inputs for sequence-length-agnostic graph capture. - - Args: - batch_size: Size of batch to capture - Returns: - Dictionary of capture inputs - """ - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: """Compute logits for next token prediction. 
- + Args: hidden_states: Hidden states from model forward pass - sampling_metadata: Metadata for sampling process - + Returns: Logits for next token prediction """ - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 750ee78502688..9341665f1bca2 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,20 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Hashable from fractions import Fraction from typing import Callable, Optional, Union +from weakref import WeakValueDictionary import torch from torch.nn import Parameter -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger -from vllm.model_executor.utils import _make_synced_weight_loader __all__ = [ - "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", - "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" + "BasevLLMParameter", + "PackedvLLMParameter", + "PerTensorScaleParameter", + "ModelWeightParameter", + "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", + "PackedColumnParameter", + "RowvLLMParameter", ] logger = init_logger(__name__) @@ -27,8 +36,7 @@ class BasevLLMParameter(Parameter): into the parameter when the provided weight loader is called. """ - def __new__(cls, data: torch.Tensor, **kwargs): - + def __new__(cls, data: Optional[torch.Tensor], **kwargs): return super().__new__(cls, data=data, requires_grad=False) def __init__(self, data: torch.Tensor, weight_loader: Callable): @@ -50,23 +58,43 @@ class BasevLLMParameter(Parameter): # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. from vllm.platforms import current_platform - if current_platform.is_tpu(): - weight_loader = _make_synced_weight_loader(weight_loader) + + if current_platform.use_sync_weight_loader(): + weight_loader = current_platform.make_synced_weight_loader(weight_loader) self._weight_loader = weight_loader + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() @property - def weight_loader(self): + def weight_loader(self) -> Callable: + # NOTE(@ksayers) some models such as mamba_mixer2 override the + # weight loader to support custom loading. In the future, model-specific + # weight loading should be implemented via Model.load_weights. 
In the + # meantime, support deleting and overriding `weight_loader`` attribute + if self._weight_loader is None: + raise AttributeError( + f"{self.__class__.__name__} weight_loader attribute has been deleted" + ) return self._weight_loader + @weight_loader.setter + def weight_loader(self, value: Callable): + self._weight_loader = value + + @weight_loader.deleter + def weight_loader(self): + self._weight_loader = None # type: ignore[assignment] + def _is_1d_and_scalar(self, loaded_weight: torch.Tensor): cond1 = self.data.ndim == 1 and self.data.numel() == 1 cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1 - return (cond1 and cond2) + return cond1 and cond2 def _assert_and_load(self, loaded_weight: torch.Tensor): - assert (self.data.shape == loaded_weight.shape - or self._is_1d_and_scalar(loaded_weight)) + assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar( + loaded_weight + ) self.data.copy_(loaded_weight) def load_column_parallel_weight(self, loaded_weight: torch.Tensor): @@ -81,14 +109,31 @@ class BasevLLMParameter(Parameter): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): self._assert_and_load(loaded_weight) + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert isinstance(shard_id, str) + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + class _ColumnvLLMParameter(BasevLLMParameter): """ - Private class defining weight loading functionality + Private class defining weight loading functionality (load_merged_column_weight, load_qkv_weight) for parameters being loaded into linear layers with column parallelism. This includes QKV and MLP layers which are - not already fused on disk. Requires an output dimension + not already fused on disk. Requires an output dimension to be defined. Called within the weight loader of each of the column parallel linear layers. 
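    Example (hypothetical shapes, for illustration only): a checkpoint weight
    of shape ``(8192, 4096)`` loaded into a column-parallel parameter with
    ``output_dim=0`` under a tensor-parallel size of 4 leaves each rank with a
    ``(2048, 4096)`` shard; rank ``r`` copies rows ``[2048 * r, 2048 * (r + 1))``
    of the loaded tensor via ``narrow(output_dim, tp_rank * shard_size,
    shard_size)``.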
""" @@ -102,55 +147,56 @@ class _ColumnvLLMParameter(BasevLLMParameter): return self._output_dim def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.output_dim] - loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, self.tp_rank * shard_size, shard_size + ) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): - shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") - if isinstance( - self, - (PackedColumnParameter, - PackedvLLMParameter)) and self.packed_dim == self.output_dim: + + # TODO: move these to PackedColumnParameter and PackedvLLMParameter + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.packed_dim == self.output_dim + ): shard_size, shard_offset = self.adjust_shard_indexes_for_packing( - shard_offset=shard_offset, shard_size=shard_size) + shard_offset=shard_offset, shard_size=shard_size + ) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() - param_data = param_data.narrow(self.output_dim, shard_offset, - shard_size) - loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, self.tp_rank * shard_size, shard_size + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): - shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") - if isinstance( - self, - (PackedColumnParameter, - PackedvLLMParameter)) and self.output_dim == self.packed_dim: + # TODO: move these to PackedColumnParameter and PackedvLLMParameter + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.output_dim == self.packed_dim + ): shard_size, shard_offset = self.adjust_shard_indexes_for_packing( - shard_offset=shard_offset, shard_size=shard_size) + shard_offset=shard_offset, shard_size=shard_size + ) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() - shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads - param_data = param_data.narrow(self.output_dim, shard_offset, - shard_size) - loaded_weight = loaded_weight.narrow(self.output_dim, - shard_id * shard_size, shard_size) + shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -173,10 +219,10 @@ class RowvLLMParameter(BasevLLMParameter): return self._input_dim def load_row_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.input_dim] - loaded_weight = loaded_weight.narrow(self.input_dim, - tp_rank * shard_size, shard_size) + loaded_weight = loaded_weight.narrow( + self.input_dim, self.tp_rank * shard_size, shard_size + ) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) @@ -190,6 +236,7 @@ class 
ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): Parameter class for linear layer weights. Uses both column and row parallelism. """ + pass @@ -198,6 +245,7 @@ class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): Parameter class for weight scales loaded for weights with grouped quantization. Uses both column and row parallelism. """ + pass @@ -206,6 +254,7 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter): Parameter class for weight scales loaded for weights with channel-wise quantization. Equivalent to _ColumnvLLMParameter. """ + pass @@ -216,27 +265,16 @@ class PerTensorScaleParameter(BasevLLMParameter): layers (e.g. for QKV, there are 3 scales loaded from disk). This is relevant to weights with per-tensor quantization. Adds functionality to map the scalers to a shard during - weight loading. + weight loading. - Note: additional parameter manipulation may be handled - for each quantization config specifically, within - process_weights_after_loading + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading """ def __init__(self, **kwargs): - self.qkv_idxs = {"q": 0, "k": 1, "v": 2} super().__init__(**kwargs) - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: - if isinstance(shard_id, int): - return shard_id - - # if not int, assume shard_id for qkv - # map to int and return - assert isinstance(shard_id, str) - assert shard_id in self.qkv_idxs - return self.qkv_idxs[shard_id] - # For row parallel layers, no sharding needed # load weight into parameter as is def load_row_parallel_weight(self, *args, **kwargs): @@ -251,10 +289,11 @@ class PerTensorScaleParameter(BasevLLMParameter): def load_column_parallel_weight(self, *args, **kwargs): super().load_row_parallel_weight(*args, **kwargs) - def _load_into_shard_id(self, loaded_weight: torch.Tensor, - shard_id: Union[str, int], **kwargs): + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs + ): """ - Slice the parameter data based on the shard id for + Slice the parameter data based on the shard id for loading. """ @@ -279,12 +318,14 @@ class PackedColumnParameter(_ColumnvLLMParameter): for more details on the packed properties. """ - def __init__(self, - packed_factor: Union[int, Fraction], - packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, - **kwargs): + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, + **kwargs, + ): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size @@ -313,7 +354,8 @@ class PackedColumnParameter(_ColumnvLLMParameter): shard_offset=shard_offset, packed_factor=self.packed_factor, marlin_tile_size=self.marlin_tile_size, - bitblas_tile_size=self.bitblas_tile_size) + bitblas_tile_size=self.bitblas_tile_size, + ) class PackedvLLMParameter(ModelWeightParameter): @@ -322,17 +364,19 @@ class PackedvLLMParameter(ModelWeightParameter): Example: GPTQ Marlin weights are int4 or int8, packed into int32. Extends the ModelWeightParameter to take in the packed factor, the packed dimension, and optionally, marlin - tile size for marlin kernels. Adjusts the shard_size and + tile size for marlin kernels. 
Adjusts the shard_size and shard_offset for fused linear layers model weight loading by accounting for packing and optionally, marlin tile size. """ - def __init__(self, - packed_factor: Union[int, Fraction], - packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, - **kwargs): + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, + **kwargs, + ): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size @@ -361,7 +405,8 @@ class PackedvLLMParameter(ModelWeightParameter): shard_offset=shard_offset, packed_factor=self.packed_factor, marlin_tile_size=self.marlin_tile_size, - bitblas_tile_size=self.bitblas_tile_size) + bitblas_tile_size=self.bitblas_tile_size, + ) class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): @@ -373,16 +418,155 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): pass -def permute_param_layout_(param: BasevLLMParameter, input_dim: int, - output_dim: int, **kwargs) -> BasevLLMParameter: +class SharedWeightParameter(BasevLLMParameter): """ - Permute a parameter's layout to the specified input and output dimensions, + Parameter for weights with many shared tensors across a model + + For example, when applying transforms to the "gate" and "up" partitions of + `MergedColumnParallelLinear`, the transform weights must stay separate + tensors in order to allow for tensor memory sharing between layers. + """ + + # global registry for sharing tensors based on passed `data_key` + # this dict holds weaksrefs to avoid memory leak after model cleanup + tensors_registry: WeakValueDictionary = WeakValueDictionary() + + # local container for strong references to shared tensors + # this set compensates for the fact that torch.nn.Parameter + # and Parameter subclasses do not hold reliable references to tensors + local_tensors: set[torch.Tensor] + + # dictionary mapping partition indices to associated parameters + partitions: dict[int, Union[ModelWeightParameter, Parameter]] + + def __new__(cls, **kwargs): + return super().__new__(cls, data=None, **kwargs) + + def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): + weight_loader: Callable = kwargs.get("weight_loader") # type: ignore[assignment] + super().__init__(data=None, weight_loader=weight_loader) + + self.local_tensors = set() + self.partitions = {} + self.kwargs = { + "input_dim": input_dim, + "output_dim": output_dim, + "weight_loader": self._fake_weight_loader, + } + + if self.tp_size > 1: + raise NotImplementedError( + f"{self.__class__.__name__} does not " + "currently support tensor parallelism" + ) + + def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): + """ + Add a partition to the weight parameter. 
Partitions whose `data_key` + is the same will share tensor data + + :param index: index of partition to add + :param data_key: hashable key used to key shared tensors + :param *args: arguments for `torch.empty` + :param **kwargs: keyword arguments for `torch.empty` + """ + # load (shared) tensor using `data_key` + if data_key not in self.tensors_registry: + data = torch.empty(*args, **kwargs) + self.tensors_registry[data_key] = data + else: + data = self.tensors_registry[data_key] + + # create associated model parameter + self.partitions[index] = ModelWeightParameter(data=data, **self.kwargs) # type: ignore[arg-type] + + # hold local reference, since ModelWeightParameter does not + # see https://github.com/pytorch/pytorch/issues/75932 + self.local_tensors.add(data) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_column_parallel_weight(partition, loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = kwargs.pop("shard_id") + partition_id = self._shard_id_as_int(partition_id) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + + ModelWeightParameter.load_merged_column_weight( + partition, loaded_weight, shard_offset=shard_offset, shard_size=shard_size + ) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = self._shard_id_as_int(kwargs.pop("shard_id")) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + shard_id = "q" # fake first partition + num_heads = kwargs.get("num_heads") + + ModelWeightParameter.load_qkv_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size, + shard_id=shard_id, + num_heads=num_heads, + ) + + def process_weights_after_loading(self): + for key in self.partitions: + self.partitions[key] = torch.nn.Parameter( + data=self.partitions[key].data, requires_grad=False + ) + + @property + def data(self): + raise ValueError( + "Accessing `data` of a " + "`PartitionedModelWeightParameter` is not allowed. 
" + "Instead, use `get_partition` to get the weight of " + "the particular partition you want to access" + ) + + def _fake_weight_loader( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_weight_shard_id: Optional[Union[str, int]], + ): + raise ValueError( + "When loading partition weights of " + f"{self.__class__.__name__}, use methods provided by " + f"{self.__class__.__name__}, not partition loader" + ) + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, useful for forcing the parameter into a known layout, for example, if I need - a packed (quantized) weight matrix to be in the layout + a packed (quantized) weight matrix to be in the layout {input_dim = 0, output_dim = 1, packed_dim = 0} then I can call: permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - to ensure x is in the correct layout (permuting it to the correct layout if + to ensure x is in the correct layout (permuting it to the correct layout if required, asserting if it cannot get it to the correct layout) """ @@ -390,35 +574,34 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int, curr_output_dim = getattr(param, "output_dim", None) if curr_input_dim is None or curr_output_dim is None: - assert param.data.dim() == 2,\ - "permute_param_layout_ only supports 2D parameters when either "\ + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " "input_dim or output_dim is not set" + ) # if one of the dimensions is not set, set it to the opposite of the other # we can only do this since we asserted the parameter is 2D above if curr_input_dim is None: - assert curr_output_dim is not None,\ - "either input or output dim must be set" + assert curr_output_dim is not None, "either input or output dim must be set" curr_input_dim = (curr_output_dim + 1) % 2 if curr_output_dim is None: - assert curr_input_dim is not None,\ - "either input or output dim must be set" + assert curr_input_dim is not None, "either input or output dim must be set" curr_output_dim = (curr_input_dim + 1) % 2 # create permutation from the current layout to the layout with # self.input_dim at input_dim and self.output_dim at output_dim preserving # other dimensions perm = [ - i for i in range(param.data.dim()) - if i not in [curr_input_dim, curr_output_dim] + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] ] perm.insert(input_dim, curr_input_dim) perm.insert(output_dim, curr_output_dim) if "packed_dim" in kwargs: - assert hasattr(param, "packed_dim") and\ - param.packed_dim == perm[kwargs["packed_dim"]],\ - "permute_param_layout_ currently doesn't support repacking" + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" param.data = param.data.permute(*perm) if hasattr(param, "_input_dim"): @@ -431,29 +614,30 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int, return param -def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, - marlin_tile_size): +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, - bitblas_tile_size): +def _adjust_shard_indexes_for_bitblas(shard_size, 
shard_offset, bitblas_tile_size): return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size -def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, - marlin_tile_size, bitblas_tile_size): +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor, marlin_tile_size, bitblas_tile_size +): shard_size = shard_size // packed_factor shard_offset = shard_offset // packed_factor if marlin_tile_size is not None: return _adjust_shard_indexes_for_marlin( shard_size=shard_size, shard_offset=shard_offset, - marlin_tile_size=marlin_tile_size) + marlin_tile_size=marlin_tile_size, + ) elif bitblas_tile_size is not None: return _adjust_shard_indexes_for_bitblas( shard_size=shard_size, shard_offset=shard_offset, - bitblas_tile_size=bitblas_tile_size) + bitblas_tile_size=bitblas_tile_size, + ) - return shard_size, shard_offset \ No newline at end of file + return shard_size, shard_offset diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py deleted file mode 100644 index 3209879193453..0000000000000 --- a/vllm/model_executor/pooling_metadata.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any, Optional - -import torch - -from vllm.pooling_params import PoolingParams -from vllm.utils import is_pin_memory_available -from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor - - -class PoolingMetadata: - """Metadata for pooling operations in the Pooler layer. - - This class holds the necessary information for pooling operations, - providing context for how to perform pooling and other related operations. - - Attributes: - seq_groups: List of (seq_ids, pooling_params). - seq_data: A mapping of sequence ID to additional sequence data. - prompt_lens: List of the lengths of each prompt. - """ - - def __init__( - self, - seq_groups: list[tuple[list[int], PoolingParams]], - seq_data: dict[int, Any], # Specific data related to sequences - prompt_lens: list[int], - pooling_cursor: Optional[PoolingCursor] = None) -> None: - self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens - self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor - - def __repr__(self) -> str: - return ("PoolingMetadata(" - f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens})") - - def __getitem__(self, indices: slice): - return PoolingMetadata( - seq_groups=self.seq_groups[indices], - seq_data=dict(list(self.seq_data.items())[indices]), - prompt_lens=self.prompt_lens[indices], - pooling_cursor=None - if self.pooling_cursor is None else self.pooling_cursor[indices], - ) - - def build_pooling_cursor(self, num_scheduled_tokens: list[int], - device: torch.device): - prompt_lens = torch.tensor(self.prompt_lens, device="cpu") - self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, - prompt_lens, - device=device) - - -@dataclass -class PoolingTensors: - """Tensors for pooling.""" - - prompt_lens: torch.Tensor - - @classmethod - def from_pooling_metadata( - cls, - pooling_metadata: "PoolingMetadata", - device: torch.device, - ) -> "PoolingTensors": - """ - Create PoolingTensors from PoolingMetadata. - - Args: - pooling_metadata: PoolingMetadata instance to convert. - device: Device to store the tensors. 
- """ - # Convert prompt lengths to tensor - pin_memory = is_pin_memory_available() - - prompt_lens_t = torch.tensor( - pooling_metadata.prompt_lens, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - - return cls(prompt_lens=prompt_lens_t.to(device=device, - non_blocking=True), ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py deleted file mode 100644 index 56f0f0984bfa0..0000000000000 --- a/vllm/model_executor/sampling_metadata.py +++ /dev/null @@ -1,597 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from array import array -from dataclasses import dataclass -from typing import Optional - -import torch - -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) - -_SAMPLING_EPS = 1e-5 - - -@dataclass -class SequenceGroupToSample: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Sequence ids for the sequence group in a previous step. - seq_ids: list[int] - sampling_params: SamplingParams - # seq_id -> sequence data. - seq_data: dict[int, SequenceData] - # The length of the sequence (all tokens seen in the past + new token to - # compute attention) of the sequence group. None if it is in a decode - # stage. - seq_len: Optional[int] - # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked - # prefill is enabled. - query_len: Optional[int] - # A random number generator for sampling. - generator: Optional[torch.Generator] - # True if the sequence group is in prefill stage. False if it is in a - # decode stage. - is_prompt: bool - # Query token indices from logits. to compute prompt logprob. Empty if - # prompt logprob is not required. - prompt_logprob_indices: list[int] - # Sample token indices from logits. Empty if sampling is not required. 
- sample_indices: list[int] - - @property - def do_sample(self): - return len(self.sample_indices) > 0 - - def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.prompt_logprobs is not None - if self.is_prompt: - assert self.seq_len is not None - assert self.query_len is not None - - -def gen_seq_group_to_sample_builder(num_seqs: int): - return lambda: SequenceGroupToSample( - seq_ids=[0] * num_seqs, - sampling_params=None, - seq_data=None, # type: ignore - seq_len=0, - query_len=0, - generator=None, - is_prompt=True, - prompt_logprob_indices=[], - sample_indices=[], - ) - - -class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations""" - - def __init__(self): - self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {} - - def get_cached_seq_group_to_sample(self, num_seqs): - if num_seqs not in self._seq_group_to_sample_cache: - self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( - gen_seq_group_to_sample_builder(num_seqs)) - - obj = self._seq_group_to_sample_cache[num_seqs].get_object() - return obj - - def reset(self): - for cache in self._seq_group_to_sample_cache.values(): - cache.reset() - - -class SamplingMetadata: - """Metadata for input sequences. Used in sampler. - - The usage is as follow; - ``` - hidden_states = execute_model(...) - logits = hidden_states[sampling_metadata.selected_token_indices] - sample(logits) - - def sample(logits): - # Use categorized_sample_indices for sampling.... - ``` - - Args: - seq_groups: List of batched sequence groups. - selected_token_indices: (num_query_tokens_to_logprob). Indices to find - logits from the initial model output hidden states. - categorized_sample_indices: SamplingType -> token indices to sample. - Each token indices is 2D tensor of (num_indices, num_indices) where - the first item means the sample index within the returned logit - (before pruning padding), and the second item means the sample - index after pruning using selected_token_indices. - For example, if the returned logit is [1, 2, 3], and we select - [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, - The first tuple is [1, 2] (sampled index within original logit), - and the second tuple is [0, 1] (sampled index within pruned logit). - num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU - serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling - tensors that are part of the sampler forward pass. Currently, - it is mainly used for multi-step decode. 
- - """ - - def __init__( - self, - seq_groups: list[SequenceGroupToSample], - selected_token_indices: torch.Tensor, - categorized_sample_indices: dict[SamplingType, torch.Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, - ) -> None: - self.seq_groups = seq_groups - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - self.num_prompts = num_prompts - self.skip_sampler_cpu_output = skip_sampler_cpu_output - self.reuse_sampling_tensors = reuse_sampling_tensors - - @staticmethod - def prepare( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - pin_memory: bool, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, - ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators, cache) - selected_token_indices = async_tensor_h2d( - selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory, - ) - categorized_sample_indices = { - t: - async_tensor_h2d( - seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory, - ) - for t, seq_ids in categorized_sample_indices.items() - } - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - num_prompts=num_prompts, - ) - return sampling_metadata - - def __repr__(self) -> str: - return ( - "SamplingMetadata(" - f"seq_groups={self.seq_groups}, " - f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices})") - - -def _prepare_seq_groups( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, -) -> tuple[ - list[SequenceGroupToSample], - list[int], - dict[SamplingType, list[int]], - int, -]: - """Prepare sequence groups and indices for sampling. - - Args: - seq_group_metadata_list: A list of sequence group to batch. - seq_lens: A list of sequence lens per sequence group. - Index of prompt len should match with seq_group_metadata_list. - query_lens: A list of query lengths. Prompt lens include the length - of entire prompt tokens, and it could be shorter. - device: A device to use for random number generators, - `SequenceGroupToSample.generator`. - generators: A store of per-request random number generators used - for seeded requests. - - Returns: - seq_groups: A list of sequence group to sample. - selected_token_indices: See the definition from `SamplingMetadata`. - categorized_sample_indices: See the definition from `SamplingMetadata`. - num_prompts: Total number of prompts from `seq_group_metadata_list`. - """ - # Batched sequence groups for the current model forward stsep. - seq_groups: list[SequenceGroupToSample] = [] - # A list of token indices to sample/compute logprob. It is used to - # prune the outcome logits from the model for the performance. - selected_token_indices: list[int] = [] - # Used for selected_token_indices. 
- model_output_idx = 0 - - # Sampling type -> ( - # indices to sample/prompt logprob within pruned output logits, - # indices to sample within pruned logits) - categorized_sample_indices: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - # Index of logits to compute logprob. Logits include both prompt logprob - # and sample logprob indices. - logit_idx = 0 - # Total number of prompts from given sequence groups. - num_prompts = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = seq_group_metadata.seq_data.keys() - - if cache is not None: - sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) - - for j, seq_id in enumerate(seq_ids): - sample_obj.seq_ids[j] = seq_id - - sample_obj.prompt_logprob_indices.clear() - sample_obj.sample_indices.clear() - - sampling_params = seq_group_metadata.sampling_params - is_prompt = seq_group_metadata.is_prompt - generator: Optional[torch.Generator] = None - # If the current seq group is in decode stage, it is None. - seq_len: Optional[int] = None - query_len: Optional[int] = None - prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices - if cache is not None else []) - sample_indices: list[int] = (sample_obj.sample_indices - if cache is not None else []) - do_sample = seq_group_metadata.do_sample - - if seq_group_metadata.is_prompt: - if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) - if generators is not None: - generators[seq_group_metadata.request_id] = generator - - num_prompts += 1 - num_prefill_sample = len(seq_ids) - assert num_prefill_sample == 1 - assert query_lens is not None and seq_lens is not None - query_len, seq_len = query_lens[i], seq_lens[i] - # If we need sampling, exclude num_prefill_sample tokens from - # prompt logprob. - prompt_logprob_len = (query_len - num_prefill_sample - if do_sample else query_len) - sample_len = num_prefill_sample if do_sample else 0 - else: - # Decode - prompt_logprob_len = 0 - query_len = query_lens[i] if query_lens is not None and len( - query_lens) > 0 else 1 - sample_len = len(seq_ids) * query_len if do_sample else 0 - - if sampling_params.seed is not None and generators is not None: - generator = generators.get(seq_group_metadata.request_id) - - # Update indices to select from the model output. - """ - This blocks computes selected_token_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - """ - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + prompt_logprob_len)) - model_output_idx += prompt_logprob_len - if do_sample: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + sample_len)) - model_output_idx += sample_len - - # We now find indices for logprob computation and sampling. - """ - This block computes categorized_sample_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - def sample(logits): - # Use categorized_sample_indices for sampling. - # prompt_logprob_indices to find prompt logprob indices. - # sample_indices to find sample indices. 
- """ - - if sampling_params.prompt_logprobs is not None: - prompt_logprob_indices.extend( - range(logit_idx, logit_idx + prompt_logprob_len)) - logit_idx += prompt_logprob_len - if do_sample: - sample_indices.extend(range(logit_idx, logit_idx + sample_len)) - categorized_sample_indices[sampling_params.sampling_type].extend( - list(range(logit_idx, logit_idx + sample_len))) - logit_idx += sample_len - - if cache is not None: - sample_obj.sampling_params = sampling_params - sample_obj.seq_data = seq_group_metadata.seq_data - sample_obj.seq_len = seq_len - sample_obj.query_len = query_len - sample_obj.generator = generator - sample_obj.is_prompt = is_prompt - else: - sample_obj = SequenceGroupToSample( - seq_ids=list(seq_ids), - sampling_params=sampling_params, - seq_data=seq_group_metadata.seq_data, - seq_len=seq_len, - query_len=query_len, - generator=generator, - is_prompt=is_prompt, - prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices), - ) - - seq_groups.append(sample_obj) - - if cache is not None: - cache.reset() - - return (seq_groups, selected_token_indices, categorized_sample_indices, - num_prompts) - - -@dataclass -class SamplingTensors: - """Tensors for sampling.""" - - temperatures: torch.Tensor - top_ps: torch.Tensor - top_ks: torch.Tensor - min_ps: torch.Tensor - presence_penalties: torch.Tensor - frequency_penalties: torch.Tensor - repetition_penalties: torch.Tensor - prompt_tokens: torch.Tensor - output_tokens: torch.Tensor - - @classmethod - def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: list[array] = [] - output_tokens: list[array] = [] - top_ks: list[int] = [] - temperatures: list[float] = [] - top_ps: list[float] = [] - min_ps: list[float] = [] - presence_penalties: list[float] = [] - frequency_penalties: list[float] = [] - repetition_penalties: list[float] = [] - do_penalties = False - do_top_p_top_k = False - do_min_p = False - - assert sampling_metadata.seq_groups is not None - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - temperature = sampling_params.temperature - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - top_p = sampling_params.top_p - min_p = sampling_params.min_p - - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k < 1 else top_k - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. 
- temperature = 1.0 - if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS - or top_k != vocab_size): - do_top_p_top_k = True - if not do_min_p and min_p > _SAMPLING_EPS: - do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - - is_prompt = seq_group.is_prompt - if is_prompt and sampling_params.prompt_logprobs is not None: - # For tokens in the prompt that we only need to get - # their logprobs - query_len = seq_group.query_len - assert query_len is not None - prefill_len = len(seq_group.prompt_logprob_indices) - temperatures += [temperature] * prefill_len - top_ps += [top_p] * prefill_len - top_ks += [top_k] * prefill_len - min_ps += [min_p] * prefill_len - presence_penalties += [0] * prefill_len - frequency_penalties += [0] * prefill_len - repetition_penalties += [1] * prefill_len - - if seq_group.do_sample: - sample_lens = len(seq_group.sample_indices) - assert sample_lens >= len(seq_ids) - temperatures += [temperature] * sample_lens - top_ps += [top_p] * sample_lens - top_ks += [top_k] * sample_lens - min_ps += [min_p] * sample_lens - presence_penalties += [p] * sample_lens - frequency_penalties += [f] * sample_lens - repetition_penalties += [r] * sample_lens - - if do_penalties: - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - prefill_len = len(seq_group.prompt_logprob_indices) - prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - if seq_group.do_sample: - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) - - sampling_tensors = SamplingTensors.from_lists( - temperatures, - top_ps, - top_ks, - min_ps, - presence_penalties, - frequency_penalties, - repetition_penalties, - prompt_tokens, - output_tokens, - vocab_size, - device, - dtype, - ) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) - - @classmethod - def from_lists( - cls, - temperatures: list[float], - top_ps: list[float], - top_ks: list[int], - min_ps: list[float], - presence_penalties: list[float], - frequency_penalties: list[float], - repetition_penalties: list[float], - prompt_tokens: list[array], - output_tokens: list[array], - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> "SamplingTensors": - # Note that the performance will be very bad without - # pinned memory. 
- pin_memory = is_pin_memory_available() - - do_penalties = prompt_tokens or output_tokens - - if do_penalties: - prompt_t = make_tensor_with_pad( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - output_t = make_tensor_with_pad( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - else: - empty_tensor = torch.empty(0, device=device, dtype=torch.long) - prompt_t = empty_tensor - output_t = empty_tensor - - temperatures_t = torch.tensor( - temperatures, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ps_t = torch.tensor( - top_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - min_ps_t = torch.tensor( - min_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - presence_penalties_t = torch.tensor( - presence_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - frequency_penalties_t = torch.tensor( - frequency_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - repetition_penalties_t = torch.tensor( - repetition_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ks_t = torch.tensor( - top_ks, - device="cpu", - dtype=torch.int, - pin_memory=pin_memory, - ) - # Because the memory is pinned, we can do non-blocking - # transfer to device. - - return cls( - temperatures=temperatures_t.to(device=device, non_blocking=True), - top_ps=top_ps_t.to(device=device, non_blocking=True), - top_ks=top_ks_t.to(device=device, non_blocking=True), - min_ps=min_ps_t.to(device=device, non_blocking=True), - presence_penalties=presence_penalties_t.to(device=device, - non_blocking=True), - frequency_penalties=frequency_penalties_t.to(device=device, - non_blocking=True), - repetition_penalties=repetition_penalties_t.to(device=device, - non_blocking=True), - prompt_tokens=prompt_t.to(device=device, non_blocking=True), - output_tokens=output_t.to(device=device, non_blocking=True), - ) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 41ed0b09c5a2a..4abd2625f8066 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -30,8 +30,7 @@ def set_weight_attrs( if weight_attrs is None: return for key, value in weight_attrs.items(): - assert not hasattr( - weight, key), f"Overwriting existing tensor attribute: {key}" + assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" # NOTE(woosuk): During weight loading, we often do something like: # narrowed_tensor = param.data.narrow(0, offset, len) @@ -44,22 +43,11 @@ def set_weight_attrs( # TODO(woosuk): Remove this hack once we have a better solution. from vllm.platforms import current_platform - if current_platform.is_tpu() and key == "weight_loader": - value = _make_synced_weight_loader(value) + if current_platform.use_sync_weight_loader() and key == "weight_loader": + value = current_platform.make_synced_weight_loader(value) setattr(weight, key, value) -def _make_synced_weight_loader(original_weight_loader): - - def _synced_weight_loader(param, *args, **kwargs): - original_weight_loader(param, *args, **kwargs) - # torch._sync doesn't support, is not needed for CPU tensors. 
- if param.device != torch.device("cpu"): - torch._sync(param) - - return _synced_weight_loader - - def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: parent_map = getattr(model, "packed_modules_mapping", None) parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} @@ -73,18 +61,19 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: child_map = getattr(child, "packed_modules_mapping", None) child_map = copy.deepcopy(child_map) if child_map is not None else {} - if any((k in parent_map and parent_map[k] != v) - for k, v in child_map.items()): + if any((k in parent_map and parent_map[k] != v) for k, v in child_map.items()): raise ValueError( f"Can't update {type(model).__name__}'s packed_modules_mapping " - f"safely because of conflicts from {type(child).__name__}.") + f"safely because of conflicts from {type(child).__name__}." + ) else: parent_map.update(child_map) return parent_map def get_moe_expert_mapping( - model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]: + model: torch.nn.Module, +) -> list[tuple[str, str, int, str]]: if parent_map := getattr(model, "get_expert_mapping", None): return parent_map() else: diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 74599fa44c88c..1747caf26cef9 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -10,21 +10,25 @@ import torch from tqdm import tqdm import vllm.envs as envs +from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deep_gemm_block_shape) + compute_aligned_M, + deep_gemm_block_shape, +) from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous def _extract_data_from_linear_base_module( - m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + m: torch.nn.Module, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """ Extract weights, weight scales and quantization block sizes from the given LinearBase module. @@ -35,7 +39,7 @@ def _extract_data_from_linear_base_module( assert m.quant_method.quant_config is not None w = m.weight - ws = m.weight_scale_inv + ws = m.weight_scale quant_block_size = m.quant_method.quant_config.weight_block_size assert isinstance(w, torch.Tensor) @@ -45,16 +49,24 @@ def _extract_data_from_linear_base_module( def _extract_data_from_fused_moe_module( - m: torch.nn.Module + m: torch.nn.Module, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: """ Extract weights, weight scales and num_topk from FusedMoE module. 
""" assert isinstance(m, FusedMoE) w13 = m.w13_weight - w13_s = m.w13_weight_scale_inv + w13_s = ( + m.w13_weight_scale_inv + if hasattr(m, "w13_weight_scale_inv") + else m.w13_weight_scale + ) w2 = m.w2_weight - w2_s = m.w2_weight_scale_inv + w2_s = ( + m.w2_weight_scale_inv + if hasattr(m, "w2_weight_scale_inv") + else m.w2_weight_scale + ) num_topk = m.top_k assert isinstance(w13, torch.Tensor) @@ -69,38 +81,48 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: Return True if the input module/layer could be processed with DeepGEMM. """ block_size = deep_gemm_block_shape()[0] - if not (isinstance(module, LinearBase) - and isinstance(module.quant_method, Fp8LinearMethod) - and module.quant_method.block_quant): + if not ( + isinstance(module, LinearBase) + and isinstance(module.quant_method, Fp8LinearMethod) + and module.quant_method.block_quant + ): return False w, _, block_sizes = _extract_data_from_linear_base_module(module) - return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 - and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) + return ( + block_sizes == deep_gemm_block_shape() + and w.ndim == 2 + and w.shape[0] % block_size == 0 + and w.shape[1] % block_size == 0 + ) def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: - if not (isinstance(module, FusedMoE) - and module.moe_config.quant_dtype == torch.float8_e4m3fn - and module.moe_config.block_shape == deep_gemm_block_shape()): + if not isinstance(module, FusedMoE): return False - if not isinstance(module.quant_method.fused_experts, - FusedMoEModularKernel): + moe_quant_config = module.quant_method.get_fused_moe_quant_config(module) + + if ( + moe_quant_config is None + or moe_quant_config.quant_dtype != torch.float8_e4m3fn + or moe_quant_config.block_shape != deep_gemm_block_shape() + ): + return False + + if not isinstance(module.quant_method.fused_experts, FusedMoEModularKernel): # fused_experts could invoke deep_gemm_moe_fp8 return True mk: FusedMoEModularKernel = module.quant_method.fused_experts # Further check if the ModularKernel implementation uses the DeepGemmExperts - return isinstance(mk.fused_experts, - (DeepGemmExperts, TritonOrDeepGemmExperts)) + return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)) FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() -def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, - max_tokens: int): +def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int): if w.size() in FP8_GEMM_NT_WARMUP_CACHE: return @@ -108,20 +130,18 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, block_m = deep_gemm_block_shape()[0] device = w.device - a1q = torch.empty((max_tokens, k), - device=device, - dtype=torch.float8_e4m3fn) - a1q_scales = torch.empty((max_tokens, k // block_m), - device=device, - dtype=torch.float32) + a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty( + (max_tokens, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=max_tokens, - desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") + pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") num_tokens = max_tokens while num_tokens > 0: - fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), - out[:num_tokens]) + fp8_gemm_nt( + (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens] + ) 
pbar.update(1) num_tokens -= 1 @@ -131,57 +151,62 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() -def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int): - if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE - and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int, + max_tokens: int, +): + if ( + w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + ): return - assert w1.size(0) == w2.size(0), ( - "w1 and w2 must have the same number of experts") + assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" block_m = deep_gemm_block_shape()[0] num_experts = w1.size(0) device = w1.device + # Assumes all ranks have the same max_num_batched_tokens + max_tokens_across_dp = get_dp_group().world_size * max_tokens + max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. - MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None) + MAX_M = compute_aligned_M( + max_tokens, num_topk, num_experts, block_m, expert_tokens_meta=None + ) # Distribute expert-ids evenly. MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint(low=0, - high=num_experts, - size=(MAX_BLOCKS, ), - device=device, - dtype=torch.int32) + expert_ids_block = torch.randint( + low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 + ) expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) - a1q_scales = torch.empty((MAX_M, k // block_m), - device=device, - dtype=torch.float32) + a1q_scales = torch.empty( + (MAX_M, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) pbar = tqdm( total=MAX_BLOCKS, - desc= - f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})" + desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})", ) num_tokens = MAX_M while num_tokens > 0: m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), - out[:num_tokens], expert_ids[:num_tokens]) + (a1q[:num_tokens], a1q_scales[:num_tokens]), + (w, w_scale), + out[:num_tokens], + expert_ids[:num_tokens], + ) pbar.update(1) num_tokens = num_tokens - block_m @@ -192,28 +217,29 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): - dg_modules = [ - m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m) - ] + dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)] for dgm in dg_modules: w, ws, _ = _extract_data_from_linear_base_module(dgm) _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) -def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + model: torch.nn.Module, max_tokens: int +): dg_modules = [ - m for m in 
model.modules() - if _fused_moe_grouped_gemm_may_use_deep_gemm(m) + m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) ] for dgm in dg_modules: - w13, w13_scale, w2, w2_scale, num_topk = ( - _extract_data_from_fused_moe_module(dgm)) + w13, w13_scale, w2, w2_scale, num_topk = _extract_data_from_fused_moe_module( + dgm + ) _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w13, w2, w13_scale, w2_scale, num_topk) + w13, w2, w13_scale, w2_scale, num_topk, max_tokens + ) def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): deepgemm_fp8_gemm_nt_warmup(model, max_tokens) - deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 761172e4d3616..23227065ee950 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -5,11 +5,13 @@ Warmup kernels used during model execution. This is useful specifically for JIT'ed kernels as we don't want JIT'ing to happen during model execution. """ + from typing import TYPE_CHECKING import torch import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported @@ -19,21 +21,50 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_worker import Worker +logger = init_logger(__name__) + def kernel_warmup(worker: "Worker"): # Deep GEMM warmup - do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM - and is_deep_gemm_supported() - and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) + do_deep_gemm_warmup = ( + envs.VLLM_USE_DEEP_GEMM + and is_deep_gemm_supported() + and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP + ) if do_deep_gemm_warmup: model = worker.get_model() max_tokens = worker.scheduler_config.max_num_batched_tokens deep_gemm_warmup(model, max_tokens) - # FlashInfer autotune for Blackwell (SM 10.0) GPUs - if has_flashinfer() and current_platform.is_device_capability(100): + # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) + # FlashInfer attention warmup + # Only warmup if the model has FlashInfer attention groups + # and is not a pooling model + def _is_flashinfer_backend(backend): + try: + return backend.get_name() == "FLASHINFER" + except NotImplementedError: + return False + + if not worker.model_runner.is_pooling_model and all( + _is_flashinfer_backend(group.backend) + for groups in worker.model_runner.attn_groups + for group in groups + ): + logger.info("Warming up FlashInfer attention.") + # Warmup with mixed batch containing both prefill and decode tokens + # This is to warm up both prefill and decode attention kernels + worker.model_runner._dummy_run( + num_tokens=16, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_mixed_batch=True, + ) + def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ @@ -52,6 +83,8 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: # When autotuning with number of tokens m, flashinfer will autotune # operations for all number of tokens up to m. # So we only need to run with the max number of tokens. 
- runner._dummy_run(runner.scheduler_config.max_num_batched_tokens, - skip_eplb=True, - is_profile=True) + runner._dummy_run( + runner.scheduler_config.max_num_batched_tokens, + skip_eplb=True, + is_profile=True, + ) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 69eed22741446..b7cbb3bbc67e7 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,11 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .base import MultiModalPlaceholderMap -from .hasher import MultiModalHashDict, MultiModalHasher -from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalKwargs, - MultiModalKwargsItems, MultiModalPlaceholderDict, - NestedTensors) +from .hasher import MultiModalHasher +from .inputs import ( + BatchedTensorInputs, + ModalityData, + MultiModalDataBuiltins, + MultiModalDataDict, + MultiModalKwargs, + MultiModalKwargsItems, + MultiModalPlaceholderDict, + MultiModalUUIDDict, + NestedTensors, +) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -15,7 +21,7 @@ is used by model runners to dispatch data processing according to the target model. Info: - [mm_processing](../../../design/mm_processing.html) + [mm_processing](../../../design/mm_processing.md) """ __all__ = [ @@ -23,12 +29,11 @@ __all__ = [ "ModalityData", "MultiModalDataBuiltins", "MultiModalDataDict", - "MultiModalHashDict", "MultiModalHasher", "MultiModalKwargs", "MultiModalKwargsItems", "MultiModalPlaceholderDict", - "MultiModalPlaceholderMap", + "MultiModalUUIDDict", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f3b273eb41e8f..d81354d9a399e 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -66,23 +66,25 @@ class AudioResampler: orig_sr: float, ) -> npt.NDArray[np.floating]: if self.target_sr is None: - raise RuntimeError("Audio resampling is not supported when " - "`target_sr` is not provided") + raise RuntimeError( + "Audio resampling is not supported when `target_sr` is not provided" + ) if self.method == "librosa": - return resample_audio_librosa(audio, - orig_sr=orig_sr, - target_sr=self.target_sr) + return resample_audio_librosa( + audio, orig_sr=orig_sr, target_sr=self.target_sr + ) elif self.method == "scipy": - return resample_audio_scipy(audio, - orig_sr=orig_sr, - target_sr=self.target_sr) + return resample_audio_scipy( + audio, orig_sr=orig_sr, target_sr=self.target_sr + ) else: - raise ValueError(f"Invalid resampling method: {self.method}. " - "Supported methods are 'librosa' and 'scipy'.") + raise ValueError( + f"Invalid resampling method: {self.method}. " + "Supported methods are 'librosa' and 'scipy'." 
+ ) class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): - def __init__(self, **kwargs) -> None: super().__init__() @@ -106,11 +108,11 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: return librosa.load(filepath, sr=None) - def encode_base64(self, media: tuple[npt.NDArray, float]) -> str: + def encode_base64(self, media: tuple[npt.NDArray, int]) -> str: audio, sr = media with BytesIO() as buffer: soundfile.write(buffer, audio, sr, format="WAV") data = buffer.getvalue() - return base64.b64encode(data).decode('utf-8') + return base64.b64encode(data).decode("utf-8") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ef8f1b2e17b47..fef118a93c6cb 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,206 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar - -if TYPE_CHECKING: - from vllm.sequence import SequenceGroupMetadata - -from .inputs import MultiModalKwargs, PlaceholderRange +from typing import Generic, TypeVar _T = TypeVar("_T") -class MultiModalPlaceholderMap: - """ - Relates multi-modal embeddings to their corresponding placeholders. - - Note: This is only used in V0. - """ - - class IndexMap(NamedTuple): - src: list[int] - dest: list[int] - - src_ranges: list[range] - """ - The indices of the multi-modal embeddings that will replace the - corresponding placeholder embeddings pointed to by ``dest_ranges``. - """ - - src_len: int - """ - The total number of flattened multi-modal embeddings. - """ - - dest_ranges: list[range] - """ - The indices of the placeholder embeddings that will be replaced by the - multimodal embeddings. - """ - - dest_len: int - """ - The total number of embeddings in the destination tensor. - """ - - def __init__(self): - self.src_ranges = [] - self.src_len = 0 - self.dest_ranges = [] - self.dest_len = 0 - - @classmethod - def from_seq_group( - cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: - """ - Returns the multi-modal items that intersect with the portion of a - prompt (``seq_group``) represented by ``positions``, as well as a - ``MultiModalPlaceholderMap`` that relates the multi-modal embedding - vectors to their corresponding placeholders. - - Examples: - - ``` - Prompt: |AAAA BBBB What's in these images?| - Positions: |.................................| - - images = [A, B] - src_ranges = [(0, 4), (4, 8)] - dest_ranges = [(0, 4), (5, 9)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ..... | - - images = [A, B] - src_ranges = [(2, 4), (4, 6)] - dest_ranges = [(0, 2), (3, 5)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ......... 
| - - images = [B] - src_ranges = [(0, 4)] - dest_ranges = [(0, 4)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | .......................| - - images = [] - src_ranges = [] - dest_ranges = [] - ``` - """ - seq_mm_data = seq_group.multi_modal_data - seq_mm_placeholders = seq_group.multi_modal_placeholders - - if not seq_mm_data or not seq_mm_placeholders: - return MultiModalKwargs(), {} - - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() - - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - def append_items_from_seq_group( - self, - positions: range, - multi_modal_items: list[_T], - multi_modal_placeholders: Sequence[PlaceholderRange], - ) -> list[_T]: - """ - Adds the multi-modal items that intersect ```positions`` to this - placeholder map and returns the intersecting items. - """ - intersecting_items = [] - - if len(multi_modal_items) != len(multi_modal_placeholders): - raise ValueError( - "Multi-modal placeholders and items must have the same length." - ) - for placeholder_dict, mm_item in zip(multi_modal_placeholders, - multi_modal_items): - placeholder = range( - placeholder_dict.offset, - placeholder_dict.offset + placeholder_dict.length, - ) - intersection = range( - max(positions.start, placeholder.start), - min(positions.stop, placeholder.stop), - ) - - if not intersection: - # Skip this multi-modal item. - continue - - token_embedding_range = range( - intersection.start - positions.start, - intersection.stop - positions.start, - ) - - multimodal_embedding_range = range( - intersection.start - placeholder.start + self.src_len, - intersection.stop - placeholder.start + self.src_len, - ) - - intersecting_items.append(mm_item) - self.dest_ranges.append(token_embedding_range) - self.src_ranges.append(multimodal_embedding_range) - self.src_len += len(placeholder) - - self.dest_len += len(positions) - return intersecting_items - - def extend(self, other: "MultiModalPlaceholderMap"): - """ - Adds the placeholders from another ``MultiModalPlaceholderMap`` to this - instance based on the source and destination tensors being - concatenated. - """ - - self.src_ranges.extend( - range(self.src_len + r.start, self.src_len + r.stop) - for r in other.src_ranges) - self.src_len += other.src_len - self.dest_ranges.extend( - range(self.dest_len + r.start, self.dest_len + r.stop) - for r in other.dest_ranges) - self.dest_len += other.dest_len - - def index_map(self) -> "IndexMap": - """ - Finalizes the placeholder map into lists of indices that can be used to - index the source and destination tensors. 
- """ - - src_indices = [i for r in self.src_ranges for i in r] - dest_indices = [i for r in self.dest_ranges for i in r] - - if len(src_indices) != len(dest_indices): - raise ValueError( - f"The number of source ({len(src_indices)}) and destination " - f"indices ({len(dest_indices)}) must be the same.") - - return self.IndexMap(src=src_indices, dest=dest_indices) - - class MediaIO(ABC, Generic[_T]): - @abstractmethod def load_bytes(self, data: bytes) -> _T: raise NotImplementedError diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 5cec8e71fb265..8b72bbe56eafd 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,70 +1,126 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import operator import sys -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TypeVar, Union +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from multiprocessing.synchronize import Lock as LockType +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast import torch +from typing_extensions import TypeAlias, override +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, + SingleWriterShmObjectStorage, + SingleWriterShmRingBuffer, +) +from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger -from vllm.utils import GiB_bytes, LRUCache -from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves +from vllm.utils import GiB_bytes, MiB_bytes +from vllm.utils.cache import CacheInfo, LRUCache +from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves -from .inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, MultiModalKwargsItems, - NestedTensors) +from .inputs import ( + MultiModalBatchedField, + MultiModalFeatureSpec, + MultiModalFieldElem, + MultiModalKwargs, + MultiModalKwargsItem, + MultiModalKwargsItems, + NestedTensors, +) + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry logger = init_logger(__name__) -@dataclass -class MultiModalCacheItemMetadata: - size: int +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. - @classmethod - def wraps(cls, value: "MultiModalCacheValue"): - return cls(size=MultiModalCache.get_item_size(value)) + Args: + item: The processed tensor data corresponding to a multi-modal item. + prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). 
+ """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates MultiModalCacheValue = Union[ + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, MultiModalKwargsItems, MultiModalKwargsItem, MultiModalKwargs, Mapping[str, NestedTensors], - MultiModalCacheItemMetadata, ] _V = TypeVar("_V", bound=MultiModalCacheValue) class MultiModalCache: - @classmethod - def get_leaf_size( - cls, - leaf: object, - *, - debug: bool = False, - ) -> int: - if isinstance(leaf, MultiModalFieldElem): - return cls.get_item_size(leaf.data) # type: ignore + def get_leaf_size(cls, leaf: object) -> int: + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size # These are not subclasses of dict - if isinstance(leaf, MultiModalKwargsItems): - return cls.get_item_size(leaf.data) # type: ignore - if isinstance(leaf, MultiModalKwargsItem): - return cls.get_item_size(leaf.data) # type: ignore - if isinstance(leaf, MultiModalKwargs): + if isinstance( + leaf, + ( + MultiModalKwargs, + MultiModalKwargsItems, + MultiModalKwargsItem, + MultiModalFieldElem, + ), + ): return cls.get_item_size(leaf.data) # type: ignore # sys.getsizeof doesn't work for tensors if isinstance(leaf, torch.Tensor): return leaf.nbytes - if isinstance(leaf, MultiModalCacheItemMetadata): - return leaf.size - return sys.getsizeof(leaf) @classmethod @@ -75,17 +131,36 @@ class MultiModalCache: debug: bool = False, ) -> int: size = json_reduce_leaves( - lambda a, b: a + b, - json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug), - value), + operator.add, json_map_leaves(cls.get_leaf_size, value) ) if debug: - logger.debug("Calculated size of %s to be %.2f GiB", type(value), - size / GiB_bytes) + leaf_count = json_count_leaves(value) + logger.debug( + "Calculated size of %s to be %.2f GiB (%d leaves)", + type(value), + size / GiB_bytes, + leaf_count, + ) return size + @classmethod + def get_item_complexity(cls, value: MultiModalCacheValue) -> int: + """ + Get the number of leaf elements in a multi-modal cache value. + + This provides a measure of structural complexity that can be useful + for debugging cache performance and understanding data patterns. + + Args: + value: The multi-modal cache value to analyze. + + Returns: + The number of leaf elements in the nested structure. + """ + return json_count_leaves(value) + @classmethod def get_lru_cache( cls, @@ -98,3 +173,583 @@ class MultiModalCache: GiB_bytes * capacity_gb, getsizeof=lambda x: cls.get_item_size(x, debug=debug), ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. 
+ + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. + """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. + """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = Optional[ + tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] +] + + +MultiModalProcessorCacheOutItem: TypeAlias = tuple[ + Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"] +] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem] +): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. + """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + @abstractmethod + def make_stats(self, *, delta: bool = False) -> CacheInfo: + """ + Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. + """ + raise NotImplementedError + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + + +class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the data in shared memory. 
+ """ + + def __init__(self, vllm_config: "VllmConfig") -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=True, # sender is the writer + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + ) + # cache (prompt_updates, modality) for P0 only + self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def _stat(self, *, delta: bool = False) -> CacheInfo: + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return self._shm_cache.is_cached(mm_hash) + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if self._shm_cache.is_cached(mm_hash): + self._hits += 1 + self._total += 1 + + address, monotonic_id = self._shm_cache.get_cached(mm_hash) + prompt_updates, modality = self._p0_cache[mm_hash] + return self.address_as_item(address, monotonic_id, modality), prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._total += 1 + + try: + address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) + # Try to remove dangling items if p0 cache is too large. + if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): + self.remove_dangling_items() + self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality + address_item = self.address_as_item( + address, monotonic_id, mm_item[0].modality + ) + return address_item, mm_item[1] + except (ValueError, MemoryError) as e: + # put may fail if the object is too large or + # the cache is full. + # In this case we log the error and keep the original mm_input. 
+ logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e) + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + self._p0_cache.clear() + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._stat(delta=delta) + + def remove_dangling_items(self) -> None: + """Remove items that are no longer in the shared memory cache.""" + cached_hashes = self._shm_cache.key_index.keys() + dangling_hashes = set(self._p0_cache.keys()) - cached_hashes + for mm_hash in dangling_hashes: + del self._p0_cache[mm_hash] + + def address_as_item( + self, address: int, monotonic_id: int, modality: str + ) -> MultiModalKwargsItem: + addr_elem = MultiModalFieldElem( + modality=modality, + key="address", + data=address, + field=MultiModalBatchedField(), + ) + id_elem = MultiModalFieldElem( + modality=modality, + key="monotonic_id", + data=monotonic_id, + field=MultiModalBatchedField(), + ) + mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem]) + return mm_item + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = ( + parallel_config._api_process_count == 1 + and parallel_config.data_parallel_size == 1 + ) or parallel_config.data_parallel_external_lb + + return supports_ipc_cache + + +def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool: + """Whether the shared memory based cache should be enabled.""" + + if not _enable_ipc_cache(vllm_config): + return False + + mm_config = vllm_config.model_config.get_multimodal_config() + + return mm_config.mm_processor_cache_type == "shm" + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalProcessorSenderCache(model_config) + return ShmObjectStoreSenderCache(vllm_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + """Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], MultiModalKwargsItem] +): + """The required interface for caches on P1.""" + + def get_and_update_features( + self, + mm_features: list["MultiModalFeatureSpec"], + ) -> list["MultiModalFeatureSpec"]: + """Update multimodal features with cached encoder outputs.""" + for feature in mm_features: + feature.data = self.get_and_update_item(feature.data, feature.identifier) + return mm_features + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. 
+ + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 Worker Process when IPC caching is enabled. + + How to update each item: + + - If the item has an address, replace the input with the cached item. + - If not, return the input. + """ + + def __init__( + self, + vllm_config: "VllmConfig", + shared_worker_lock: LockType, + ) -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=False, # Server is a reader + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=shared_worker_lock, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + assert mm_item is not None, f"Expected an address item for {mm_hash=}" + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + return self._shm_cache.get(address, monotonic_id) + + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + + +def engine_receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """ + This is used in the engine process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="lru". + """ + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalReceiverCache(model_config) + + return None + + +def worker_receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", + shared_worker_lock: LockType, +) -> Optional[BaseMultiModalReceiverCache]: + """ + This is used in the worker process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="shm". 
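+
+    A rough usage sketch (`vllm_config`, `mm_registry`, `worker_lock` and
+    `mm_features` are assumed to be provided by the caller):
+
+        receiver_cache = worker_receiver_cache_from_config(
+            vllm_config, mm_registry, worker_lock
+        )
+        if receiver_cache is not None:
+            # Resolve shared-memory "address items" back into tensor data.
+            mm_features = receiver_cache.get_and_update_features(mm_features)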
+ """ + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + if not _enable_mm_input_shm_cache(vllm_config): + return None + + return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py new file mode 100644 index 0000000000000..36518c6bdb55a --- /dev/null +++ b/vllm/multimodal/evs.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import typing +from typing import Union + +import torch + + +def compute_retained_tokens_count( + tokens_per_frame: int, num_frames: int, q: float +) -> int: + """ + Compute the number of retained tokens for a given video. + Method ensures that we retain all the tokens from the first frame + regardless of the pruning rate. + + Args: + tokens_per_frame: The number of tokens per frame. + num_frames: The total number of frames. + q: The pruning rate. + + Returns: + The number of retained tokens. + """ + total_tokens = tokens_per_frame * num_frames + evs_num_tokens = int(total_tokens * (1 - q)) + min_num_tokens = tokens_per_frame + return max(min_num_tokens, evs_num_tokens) + + +def compute_retention_mask( + video_embeds: torch.Tensor, + video_size_thw: Union[torch.LongTensor, tuple[int, int, int]], + spatial_merge_size: int, + q: float, +) -> torch.Tensor: + """ + Computes the retention mask for input video embeddings. + + Args: + video_embeds (`torch.Tensor`): The input video embeddings + of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)` + video_size_thw (`torch.LongTensor` of shape `(3)`): + The temporal, height and width of video. + spatial_merge_size: Size reduction for rows & cols dimensions. + q: (`float`): Pruning rate factor [0,1) + + Returns: + `torch.Tensor`: The retention mask for the video embeddings of + `(T * H * W // spatial_merge_size ^ 2)` shape. 
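+
+    Example (illustrative sizes; the hidden size of 1152 is arbitrary):
+
+        # 4 frames of 28x28 patches merged 2x2 -> 196 tokens per frame
+        video_embeds = torch.randn(4 * 14 * 14, 1152)
+        mask = compute_retention_mask(
+            video_embeds, (4, 28, 28), spatial_merge_size=2, q=0.75
+        )
+        # The first frame is always fully retained, so at least
+        # 196 tokens survive pruning; here exactly 196 do.
+        assert int(mask.sum()) == 196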
+    """
+    T, H, W = map(int, video_size_thw)
+
+    # Use reshape instead of einops to avoid graph breaks
+    video_embeds = video_embeds.reshape(
+        T,
+        H // spatial_merge_size,
+        W // spatial_merge_size,
+        video_embeds.size(-1),
+    )
+    tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)
+    # Core EVS
+    similarity = torch.nn.functional.cosine_similarity(
+        video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
+    )
+    dissimilarity = 1 - similarity
+
+    # Always ensure we include all tokens from the first frame
+    dissimilarity = torch.cat(
+        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
+    )
+
+    dissimilarity_flat = dissimilarity.view(-1)
+    order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True)
+    retain_num_tokens = compute_retained_tokens_count(
+        tokens_per_frame=tokens_per_frame, num_frames=T, q=q
+    )
+    topk_indices = order[:retain_num_tokens]
+
+    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
+    retention_mask[topk_indices] = True
+    retention_mask = retention_mask.reshape(dissimilarity.size())
+
+    mask = retention_mask.view(-1)  # "T H W -> (T H W)"
+    return mask
+
+
+def compute_mrope_for_media(
+    video_size_thw: torch.LongTensor,
+    spatial_merge_size: int,
+    tokens_per_second: float = 1.0,
+    video_second_per_grid: float = 1.0,
+) -> torch.Tensor:
+    """
+    Computes the mrope positions for video embeddings based on the grid
+    dimensions. The computed mrope positions match the original Qwen 2.5
+    implementation, but the positions are built as if the media were the
+    first element in the sequence.
+
+    Args:
+        video_size_thw: Media size (num frames, rows, cols)
+        spatial_merge_size: Size reduction for rows & cols dimensions.
+        tokens_per_second: Number of tokens per second.
+        video_second_per_grid: Number of seconds per video.
+
+    Returns:
+        Tensor of shape `(T * H * W, 4)` whose first 3 channels hold the
+        mrope positions, while the last channel contains the value of
+        llm_grid_w repeated for all positions.
+    """
+    llm_grid_t = video_size_thw[0]
+    llm_grid_h = video_size_thw[1] // spatial_merge_size
+    llm_grid_w = video_size_thw[2] // spatial_merge_size
+
+    t_index = (
+        (
+            torch.arange(llm_grid_t)
+            .view(-1, 1)
+            .expand(-1, llm_grid_h * llm_grid_w)
+            .mul(tokens_per_second * video_second_per_grid)
+        )
+        .long()
+        .flatten()
+    )
+    h_index = (
+        torch.arange(llm_grid_h)
+        .view(1, -1, 1)
+        .expand(llm_grid_t, -1, llm_grid_w)
+        .flatten()
+    )
+    w_index = (
+        torch.arange(llm_grid_w)
+        .view(1, 1, -1)
+        .expand(llm_grid_t, llm_grid_h, -1)
+        .flatten()
+    )
+    llm_grid_w = (
+        torch.tensor([llm_grid_w])
+        .view(1, 1, 1)
+        .expand(llm_grid_t, llm_grid_h, llm_grid_w)
+        .flatten()
+    )
+
+    positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1)
+    return positions
+
+
+def recompute_mrope_positions(
+    input_ids: torch.LongTensor,
+    multimodal_positions: list[torch.Tensor],
+    mrope_positions: torch.LongTensor,
+    num_computed_tokens: int,
+    vision_start_token_id: int,
+    image_token_id: int,
+    video_token_id: int,
+) -> tuple[torch.LongTensor, int]:
+    """
+    Update part of the input mrope positions.
+    The original mrope_positions become incorrect once media tokens are
+    pruned, so we must reflect the pruning in the mrope positions passed
+    to the LLM.
+
+    This method supports the chunked-prefill approach, where
+    multimodal_embeddings are passed to the LLM in chunks, so the input
+    may contain none, some, or only part of all multimodal_embeddings
+    for a given prompt.
+
+    Each entry in multimodal_positions has 4 channels
+    (the first 3 channels correspond to the original 3 mrope positions,
+    and the extra last channel repeats the maximum width of the media).
+    The provided multimodal_positions do not reflect the location of the
+    media in the sequence - they are computed as if the media were at
+    position 0 in the sequence.
+
+    The method works as follows: it recomputes mrope_positions starting from
+    `num_computed_tokens` for `total_len_of_multimodal_embeddings` tokens and
+    then shifts all text tokens that come after
+    total_len_of_multimodal_embeddings.
+
+    It also handles the case when multimodal_embeddings is partial
+    (e.g. one media item is split across two prefill stages).
+
+    Args:
+        input_ids: (N,) All input tokens of the prompt (entire sequence).
+        multimodal_positions: List of mrope positions for each media item.
+        mrope_positions: Existing mrope positions (4, N) for entire sequence.
+        num_computed_tokens: The number of tokens computed so far.
+        vision_start_token_id: Token indicating start of vision media.
+        image_token_id: Image token id.
+        video_token_id: Video token id.
+
+    Returns:
+        Tuple of (mrope_positions, mrope_position_delta).
+    """
+
+    # Tensors
+    positions: torch.LongTensor = typing.cast(
+        torch.LongTensor, mrope_positions.clone()
+    )  # (3, N)
+    N = input_ids.numel()
+
+    image_mask = input_ids.eq(image_token_id)
+    video_mask = input_ids.eq(video_token_id)
+    media_mask = image_mask | video_mask
+    text_mask = ~media_mask
+
+    # Early exit: no media in this chunk
+    if len(multimodal_positions) == 0:
+        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
+        return positions, delta
+
+    total_mm_tokens = torch.count_nonzero(media_mask)
+    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])
+
+    # Early exit: we've updated positions for all media tokens
+    # (and consequently for all remaining text tokens)
+    if seen_mm_tokens == total_mm_tokens:
+        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
+        return positions, delta
+
+    vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
+        0
+    ]
+
+    for mm_pos in multimodal_positions:
+        # Each mm_pos can be the complete positions for a single media item
+        # or only a part of one (due to chunked prefill)
+
+        # Cases to cover
+        # - Current prefill chunk has no vision start indexes at all
+        # - Vision start token appeared in previous prefill round
+        # - Regular case
+        seen_vision_start_indices = vision_start_indices[
+            vision_start_indices < num_computed_tokens
+        ]
+
+        if len(seen_vision_start_indices):
+            # If we have encountered some vision start indexes,
+            # then we should check the condition:
+            # | --- prefill 1 ------| ---- prefill 2 ----- |
+            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
+            last_vision_start_token = seen_vision_start_indices[-1]
+            seen_mm_tokens_before_last_vision_start = torch.count_nonzero(
+                media_mask[:last_vision_start_token]
+            )
+            in_the_middle_of_media = (
+                seen_mm_tokens > seen_mm_tokens_before_last_vision_start
+            )
+
+            if in_the_middle_of_media:
+                mm_embeddings_seen = (
+                    seen_mm_tokens - seen_mm_tokens_before_last_vision_start
+                )
+                global_mm_start = last_vision_start_token
+            else:
+                # We have completed the previous mm_embedding part and
+                # are ready to start a new one
+                next_vision_start_token = vision_start_indices[
+                    vision_start_indices >= num_computed_tokens
+                ][0]
+                mm_embeddings_seen = 0
+                global_mm_start = next_vision_start_token
+
+        else:
+            # If there were no vision start indexes so far,
+            # let's find the first vision
start index + next_vision_start_token = vision_start_indices[ + vision_start_indices >= num_computed_tokens + ][0] + + mm_embeddings_seen = 0 + global_mm_start = next_vision_start_token + + # Offset right after vision_start_token + base = positions[-1, global_mm_start] + 1 + local_start = global_mm_start + 1 + mm_embeddings_seen + local_end = local_start + mm_pos.shape[1] + positions[:, local_start:local_end] = mm_pos[0:3] + base + + # mm_pos[3, 0] is the max width of the media + offset = mm_pos[3, 0] + base + + text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0) + + positions[:, local_end:N] = text_pos_sum + offset - 1 + + # Include distance to the next vision start token + num_computed_tokens += mm_pos.shape[1] + + mrope_positions_delta = (positions.max() + 1 - N).item() + return positions, mrope_positions_delta diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 210a4ec762879..91d86cd9a1897 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -3,7 +3,7 @@ import pickle import uuid -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from typing import Union import numpy as np @@ -12,67 +12,79 @@ from blake3 import blake3 from PIL import Image from vllm.logger import init_logger -from vllm.multimodal.image import convert_image_mode logger = init_logger(__name__) -MultiModalHashDict = Mapping[str, list[str]] -""" -A dictionary containing hashes for items in each modality. -""" - class MultiModalHasher: - @classmethod - def serialize_item(cls, obj: object) -> Union[bytes, memoryview]: + def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: # Simple cases - if isinstance(obj, str): - return obj.encode("utf-8") if isinstance(obj, (bytes, memoryview)): - return obj + return (obj,) + if isinstance(obj, str): + return (obj.encode("utf-8"),) if isinstance(obj, (int, float)): - return np.array(obj).tobytes() + return (np.array(obj).tobytes(),) if isinstance(obj, Image.Image): exif = obj.getexif() if Image.ExifTags.Base.ImageID in exif and isinstance( - exif[Image.ExifTags.Base.ImageID], uuid.UUID): + exif[Image.ExifTags.Base.ImageID], uuid.UUID + ): # If the image has exif ImageID tag, use that - return exif[Image.ExifTags.Base.ImageID].bytes - return cls.item_to_bytes( - "image", np.asarray(convert_image_mode(obj, "RGBA"))) + return (exif[Image.ExifTags.Base.ImageID].bytes,) + data = {"mode": obj.mode, "data": np.asarray(obj)} + if obj.palette is not None: + data["palette"] = obj.palette.palette + if obj.palette.rawmode is not None: + data["palette_rawmode"] = obj.palette.rawmode + return cls.iter_item_to_bytes("image", data) if isinstance(obj, torch.Tensor): - return cls.item_to_bytes("tensor", obj.cpu().numpy()) + tensor_obj: torch.Tensor = obj.cpu() + tensor_dtype = tensor_obj.dtype + tensor_shape = tensor_obj.shape + + # NumPy does not support bfloat16. 
+ # Workaround: View the tensor as a contiguous 1D array of bytes + if tensor_dtype == torch.bfloat16: + tensor_obj = tensor_obj.contiguous() + tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8) + + return cls.iter_item_to_bytes( + "tensor", + { + "original_dtype": str(tensor_dtype), + "original_shape": tuple(tensor_shape), + "data": tensor_obj.numpy(), + }, + ) + return cls.iter_item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first - arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() - return cls.item_to_bytes("ndarray", { - "dtype": obj.dtype.str, - "shape": obj.shape, - "data": arr_data, - }) - + arr_data = ( + obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes() + ) + return cls.iter_item_to_bytes( + "ndarray", + { + "dtype": obj.dtype.str, + "shape": obj.shape, + "data": arr_data, + }, + ) logger.warning( - "No serialization method found for %s. " - "Falling back to pickle.", type(obj)) + "No serialization method found for %s. Falling back to pickle.", type(obj) + ) - return pickle.dumps(obj) - - @classmethod - def item_to_bytes( - cls, - key: str, - obj: object, - ) -> bytes: - return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj)) + return (pickle.dumps(obj),) @classmethod def iter_item_to_bytes( cls, key: str, obj: object, - ) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]: + ) -> Iterable[Union[bytes, memoryview]]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): @@ -81,17 +93,15 @@ class MultiModalHasher: for k, v in obj.items(): yield from cls.iter_item_to_bytes(f"{key}.{k}", v) else: - key_bytes = key.encode("utf-8") - value_bytes = cls.serialize_item(obj) - yield key_bytes, value_bytes + yield key.encode("utf-8") + yield from cls.serialize_item(obj) @classmethod def hash_kwargs(cls, **kwargs: object) -> str: hasher = blake3() for k, v in kwargs.items(): - for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v): - hasher.update(k_bytes) - hasher.update(v_bytes) + for bytes_ in cls.iter_item_to_bytes(k, v): + hasher.update(bytes_) return hasher.hexdigest() diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 1006c1ce4b241..f50ab1faebbad 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -12,9 +12,9 @@ from PIL import Image from .base import MediaIO -def rescale_image_size(image: Image.Image, - size_factor: float, - transpose: int = -1) -> Image.Image: +def rescale_image_size( + image: Image.Image, size_factor: float, transpose: int = -1 +) -> Image.Image: """Rescale the dimensions of an image by a constant factor.""" new_width = int(image.width * size_factor) new_height = int(image.height * size_factor) @@ -26,7 +26,7 @@ def rescale_image_size(image: Image.Image, def rgba_to_rgb( image: Image.Image, - background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255) + background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255), ) -> Image.Image: """Convert an RGBA image to RGB with filled background color.""" assert image.mode == "RGBA" @@ -45,7 +45,6 @@ def convert_image_mode(image: Image.Image, to_mode: str): class ImageMediaIO(MediaIO[Image.Image]): - def __init__(self, image_mode: str = "RGB", **kwargs) -> None: super().__init__() @@ -59,18 +58,21 @@ class ImageMediaIO(MediaIO[Image.Image]): # Extract RGBA background color from kwargs if provided # Default to white background for backward compatibility - rgba_bg = 
kwargs.get('rgba_background_color', (255, 255, 255)) + rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255)) # Convert list to tuple for consistency if isinstance(rgba_bg, list): rgba_bg = tuple(rgba_bg) # Validate rgba_background_color format - if not (isinstance(rgba_bg, tuple) and len(rgba_bg) == 3 - and all(isinstance(c, int) and 0 <= c <= 255 - for c in rgba_bg)): + if not ( + isinstance(rgba_bg, tuple) + and len(rgba_bg) == 3 + and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg) + ): raise ValueError( "rgba_background_color must be a list or tuple of 3 integers " - "in the range [0, 255].") + "in the range [0, 255]." + ) self.rgba_background_color = rgba_bg def _convert_image_mode(self, image: Image.Image) -> Image.Image: @@ -108,11 +110,10 @@ class ImageMediaIO(MediaIO[Image.Image]): image.save(buffer, image_format) data = buffer.getvalue() - return pybase64.b64encode(data).decode('utf-8') + return pybase64.b64encode(data).decode("utf-8") class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): - def __init__(self) -> None: super().__init__() @@ -127,4 +128,4 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): return torch.load(filepath, weights_only=True) def encode_base64(self, media: torch.Tensor) -> str: - return pybase64.b64encode(media.numpy()).decode('utf-8') + return pybase64.b64encode(media.numpy()).decode("utf-8") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 581f9a109cce6..bec3099a99bc5 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,14 +7,13 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast, final import numpy as np -from typing_extensions import NotRequired, TypeAlias, deprecated +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated from vllm.utils import LazyLoader, full_groupby, is_list_of -from vllm.utils.jsontree import JSONTree, json_map_leaves +from vllm.utils.jsontree import json_map_leaves if TYPE_CHECKING: import torch @@ -22,7 +21,8 @@ if TYPE_CHECKING: from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature - from .hasher import MultiModalHashDict + from .processing import MultiModalHashes + else: torch = LazyLoader("torch", globals(), "torch") @@ -34,8 +34,9 @@ A `transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace `ImageProcessor`. """ -HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor", - list[np.ndarray], list["torch.Tensor"]] +HfVideoItem: TypeAlias = Union[ + list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"] +] """ A `transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace `VideoProcessor`. @@ -57,8 +58,9 @@ which are treated as image embeddings; these are directly passed to the model without HF processing. """ -VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor", - tuple[HfVideoItem, dict[str, Any]]] +VideoItem: TypeAlias = Union[ + HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]] +] """ A `transformers.video_utils.VideoInput` representing a single video item. 
This can be passed to a HuggingFace `VideoProcessor` @@ -69,8 +71,7 @@ which are treated as video embeddings; these are directly passed to the model without HF processing. """ -AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], - "torch.Tensor"] +AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"] """ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. @@ -84,9 +85,10 @@ which are treated as audio embeddings; these are directly passed to the model without HF processing. """ -ModalityData: TypeAlias = Union[_T, list[_T]] +ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None] """ -Either a single data item, or a list of data items. +Either a single data item, or a list of data items. Can only be None if UUID +is provided. The number of data items allowed per modality is restricted by `--limit-mm-per-prompt`. @@ -115,6 +117,16 @@ The built-in modalities are defined by [`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. """ +MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]] +""" +A dictionary containing user-provided UUIDs for items in each modality. +If a UUID for an item is not provided, its entry will be `None` and +MultiModalHasher will compute a hash for the item. + +The UUID will be used to identify the item for all caching purposes +(input processing caching, embedding caching, prefix caching, etc). +""" + @dataclass(frozen=True) class PlaceholderRange: @@ -165,8 +177,12 @@ class PlaceholderRange: return nested_tensors_equal(self.is_embed, other.is_embed) -NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"], - "torch.Tensor", tuple["torch.Tensor", ...]] +NestedTensors: TypeAlias = Union[ + list["NestedTensors"], + list["torch.Tensor"], + "torch.Tensor", + tuple["torch.Tensor", ...], +] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ @@ -181,23 +197,48 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: return isinstance(a, torch.Tensor) and torch.equal(b, a) if isinstance(a, list): - return (isinstance(b, list) - and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + return isinstance(b, list) and all( + nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b) + ) if isinstance(b, list): - return (isinstance(a, list) - and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + return isinstance(a, list) and all( + nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a) + ) # Both a and b are scalars return a == b -BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors] +BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch]. """ +@dataclass +class MultiModalFeatureSpec: + """ + Represents a single multimodal input with its processed data and metadata. + + Used by the V1 engine to track multimodal data through processing and + caching. A request containing multiple multimodal items will have one + MultiModalFeatureSpec per item. 
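+
+    Example (a hypothetical construction; the field values are
+    illustrative only):
+
+        feature = MultiModalFeatureSpec(
+            data=kwargs_item,  # a `MultiModalKwargsItem`, or None
+            modality="image",
+            identifier="mm-hash-or-uuid",
+            mm_position=PlaceholderRange(offset=2, length=336),
+        )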
+ """ + + data: Optional["MultiModalKwargsItem"] + """Multimodal data for this feature""" + + modality: str + """Based on the input, e.g., "image", "audio", "video".""" + + identifier: str + """mm_hash or uuid for caching encoder outputs.""" + + mm_position: PlaceholderRange + """e.g., PlaceholderRange(offset=2, length=336)""" + + @dataclass class MultiModalFieldElem: """ @@ -245,9 +286,11 @@ class MultiModalFieldElem: else: data_equal = nested_tensors_equal(self.data, other.data) - return ((self.modality, self.key) == (other.modality, other.key) - and data_equal - and type(self.field) == type(other.field)) # noqa: E721 + return ( + (self.modality, self.key) == (other.modality, other.key) + and data_equal + and type(self.field) is type(other.field) + ) # noqa: E721 @dataclass(frozen=True) @@ -342,6 +385,7 @@ class MultiModalBatchedField(BaseMultiModalField): pin_memory: bool, ) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + batch = cast(list[torch.Tensor], batch) if len(batch) == 1: # An optimization when `batch` contains only one tensor: # - produce exactly same result as `torch.stack(batch)` @@ -349,10 +393,12 @@ class MultiModalBatchedField(BaseMultiModalField): return batch[0].unsqueeze(0).contiguous() first_shape = batch[0].shape if all(elem.shape == first_shape for elem in batch): - out = torch.empty((len(batch), *batch[0].shape), - dtype=batch[0].dtype, - device=batch[0].device, - pin_memory=pin_memory) + out = torch.empty( + (len(batch), *batch[0].shape), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory, + ) return torch.stack(batch, out=out) return batch @@ -365,6 +411,7 @@ class MultiModalFlatField(BaseMultiModalField): [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat] [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes] """ + slices: Union[Sequence[slice], Sequence[Sequence[slice]]] dim: int = 0 @@ -376,8 +423,9 @@ class MultiModalFlatField(BaseMultiModalField): ) -> Sequence[MultiModalFieldElem]: field_factory = self._field_factory(modality=modality, key=key) if not is_list_of(self.slices, slice, check="all"): - assert isinstance(data, torch.Tensor), \ + assert isinstance(data, torch.Tensor), ( "torch.Tensor is required for multiple slices" + ) return [field_factory(data[cast(slice, s)]) for s in self.slices] def _reduce_data( @@ -387,6 +435,7 @@ class MultiModalFlatField(BaseMultiModalField): pin_memory: bool, ) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + batch = cast(list[torch.Tensor], batch) if len(batch) == 1: # An optimization when `batch` contains only one tensor: # - produce exactly same result as `torch.concat(batch)` @@ -396,17 +445,19 @@ class MultiModalFlatField(BaseMultiModalField): dim = self.dim + (self.dim < 0) * len(batch[0].shape) def _shape_before_after(tensor: torch.Tensor): - return tensor.shape[:dim], tensor.shape[dim + 1:] + return tensor.shape[:dim], tensor.shape[dim + 1 :] first_shape = _shape_before_after(batch[0]) if all(_shape_before_after(elem) == first_shape for elem in batch): shape_before, shape_after = first_shape shape_concat = sum(item.shape[dim] for item in batch) - out = torch.empty((*shape_before, shape_concat, *shape_after), - dtype=batch[0].dtype, - device=batch[0].device, - pin_memory=pin_memory) + out = torch.empty( + (*shape_before, shape_concat, *shape_after), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory, + ) return 
torch.concat(batch, dim=self.dim, out=out) assert self.dim == 0, "dim == 0 is required for nested list" @@ -419,6 +470,7 @@ class MultiModalSharedField(BaseMultiModalField): Info: [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared] """ + batch_size: int def build_elems( @@ -440,7 +492,6 @@ class MultiModalSharedField(BaseMultiModalField): class MultiModalFieldConfig: - @staticmethod def batched(modality: str): """ @@ -471,9 +522,11 @@ class MultiModalFieldConfig: ) @staticmethod - def flat(modality: str, - slices: Union[Sequence[slice], Sequence[Sequence[slice]]], - dim: int = 0): + def flat( + modality: str, + slices: Union[Sequence[slice], Sequence[Sequence[slice]]], + dim: int = 0, + ): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -524,9 +577,7 @@ class MultiModalFieldConfig: ) @staticmethod - def flat_from_sizes(modality: str, - size_per_item: "torch.Tensor", - dim: int = 0): + def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -534,8 +585,8 @@ class MultiModalFieldConfig: Args: modality: The modality of the multi-modal item that uses this keyword argument. - slices: For each multi-modal item, the size of the slice that - is used to extract the data corresponding to it. + size_per_item: For each multi-modal item, the size of the slice + that is used to extract the data corresponding to it. dim: The dimension to slice, default to 0. Example: @@ -555,7 +606,7 @@ class MultiModalFieldConfig: ``` Given: - slices: [3, 4, 2] + size_per_item: [3, 4, 2] dim: 1 Input: @@ -572,13 +623,17 @@ class MultiModalFieldConfig: """ if size_per_item.ndim != 1: - raise ValueError("size_per_item should be a 1-D tensor, " - f"but found shape: {size_per_item.shape}") + raise ValueError( + "size_per_item should be a 1-D tensor, " + f"but found shape: {size_per_item.shape}" + ) slice_idxs = [0, *accumulate(size_per_item)] - slices = [(slice(None, None, None), ) * dim + - (slice(slice_idxs[i], slice_idxs[i + 1]), ) - for i in range(len(size_per_item))] + slices = [ + (slice(None, None, None),) * dim + + (slice(slice_idxs[i], slice_idxs[i + 1]),) + for i in range(len(size_per_item)) + ] return MultiModalFieldConfig.flat(modality, slices, dim=dim) @@ -622,6 +677,9 @@ class MultiModalFieldConfig: self.field = field self.modality = modality + def __repr__(self) -> str: + return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})" + def build_elems( self, key: str, @@ -668,7 +726,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): return {key: elem.data for key, elem in self.items()} -class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ A dictionary of [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s @@ -700,7 +766,8 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): if len(set(batch_sizes.values())) > 1: raise ValueError( f"Cannot merge different batch sizes for {modality=}! 
" - f"Found: {batch_sizes=}") + f"Found: {batch_sizes=}" + ) batch_size = next(iter(batch_sizes.values())) for item_idx in range(batch_size): @@ -714,25 +781,47 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): items_by_modality = full_groupby(items, key=lambda x: x.modality) return MultiModalKwargsItems(items_by_modality) - def __getitem__(self, modality: str): + def __getitem__(self, modality: str) -> Sequence[_I]: if modality not in self: - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {set(self.keys())}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {set(self.keys())}" + ) - return super().__getitem__(modality) + return super().__getitem__(modality) # type: ignore[return-value] + + def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]": + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError(f"Found empty mm_items[{modality}][{i}]") + + return self # type: ignore[return-value] def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for items in self.values(): - for item in items: + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError( + f"Cannot build data from empty mm_items[{modality}][{i}]" + ) + for key, elem in item.items(): elems_by_key[key].append(elem) - return MultiModalKwargs({ - key: - elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() if len(elems) > 0 - }) + return MultiModalKwargs( + { + key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + } + ) + + +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] class MultiModalKwargs(UserDict[str, NestedTensors]): @@ -742,33 +831,36 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): """ @staticmethod - @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and " - "will be removed in v0.13. " - "Please use `MultiModalKwargsItems.from_hf_inputs` and " - "access the tensor data using `.get_data()`.") + @deprecated( + "`MultiModalKwargs.from_hf_inputs` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_hf_inputs` and " + "access the tensor data using `.get_data()`." + ) def from_hf_inputs( hf_inputs: "BatchFeature", config_by_key: Mapping[str, MultiModalFieldConfig], ): - return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \ - .get_data() + return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data() @staticmethod - @deprecated("`MultiModalKwargs.from_items` is deprecated and " - "will be removed in v0.13. " - "Please use `MultiModalKwargsItems.from_seq` and " - "access the tensor data using `.get_data()`.") + @deprecated( + "`MultiModalKwargs.from_items` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_seq` and " + "access the tensor data using `.get_data()`." 
+ ) def from_items( items: Sequence[MultiModalKwargsItem], *, pin_memory: bool = False, ): - return MultiModalKwargsItems.from_seq(items) \ - .get_data(pin_memory=pin_memory) + return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory) @staticmethod - def _try_stack(nested_tensors: NestedTensors, - pin_memory: bool = False) -> NestedTensors: + def _try_stack( + nested_tensors: NestedTensors, pin_memory: bool = False + ) -> NestedTensors: """ Stack the inner dimensions that have the same shape in a nested list of tensors. @@ -785,9 +877,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): if isinstance(nested_tensors, (int, float)): return torch.tensor(nested_tensors) - stacked = [ - MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors - ] + stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): # Only tensors (not lists) can be stacked. return stacked @@ -803,16 +893,19 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): # The tensors have incompatible shapes and can't be stacked. return tensors_ - outputs = torch.empty(len(tensors_), - *tensors_[0].shape, - dtype=tensors_[0].dtype, - device=tensors_[0].device, - pin_memory=pin_memory) + outputs = torch.empty( + len(tensors_), + *tensors_[0].shape, + dtype=tensors_[0].dtype, + device=tensors_[0].device, + pin_memory=pin_memory, + ) return torch.stack(tensors_, out=outputs) @staticmethod - def batch(inputs_list: list["MultiModalKwargs"], - pin_memory: bool = False) -> BatchedTensorInputs: + def batch( + inputs_list: list["MultiModalKwargs"], pin_memory: bool = False + ) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -844,19 +937,17 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): *, device: torch.types.Device, ) -> BatchedTensorInputs: - json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) - - json_mapped = json_map_leaves( + return json_map_leaves( lambda x: x.to(device=device, non_blocking=True), - json_inputs, + batched_inputs, ) - return cast(BatchedTensorInputs, json_mapped) - def __getitem__(self, key: str): if key not in self: - raise KeyError(f"Keyword argument {key!r} not found. " - f"Available keys: {set(self.keys())}") + raise KeyError( + f"Keyword argument {key!r} not found. " + f"Available keys: {set(self.keys())}" + ) return super().__getitem__(key) @@ -889,19 +980,13 @@ class MultiModalInputs(TypedDict): type: Literal["multimodal"] """The type of inputs.""" - prompt: str - """The processed prompt text.""" - prompt_token_ids: list[int] """The processed token IDs which includes placeholder tokens.""" - token_type_ids: NotRequired[list[int]] - """The token type IDs of the prompt.""" - - mm_kwargs: MultiModalKwargsItems + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: "MultiModalHashDict" + mm_hashes: "MultiModalHashes" """The hashes of the multi-modal data.""" mm_placeholders: "MultiModalPlaceholderDict" @@ -923,11 +1008,5 @@ class MultiModalEncDecInputs(MultiModalInputs): ready to be passed to vLLM internals. 
""" - encoder_prompt: str - """The processed encoder prompt text.""" - encoder_prompt_token_ids: list[int] """The processed token IDs of the encoder prompt.""" - - encoder_token_type_ids: NotRequired[list[int]] - """The token type IDs of the encoder prompt.""" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 88bb99529f200..8fdc5cf721d08 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -4,8 +4,16 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, - TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + NamedTuple, + Optional, + TypeVar, + Union, +) import numpy as np import torch @@ -14,9 +22,18 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler -from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, - ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargsItems, VideoItem) +from .inputs import ( + AudioItem, + HfAudioItem, + HfImageItem, + HfVideoItem, + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) _T = TypeVar("_T") _I = TypeVar("_I") @@ -36,12 +53,11 @@ class ModalityDataItems(ABC, Generic[_T, _I]): def __init__(self, data: _T, modality: str) -> None: super().__init__() - self.data = data + self.data: _T = data self.modality = modality def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"len={len(self)})") + return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})" def __len__(self) -> int: return self.get_count() @@ -51,8 +67,7 @@ class ModalityDataItems(ABC, Generic[_T, _I]): if TYPE_CHECKING: # Auto-generated - def __iter__(self) -> Iterator[_I]: - ... + def __iter__(self) -> Iterator[_I]: ... @abstractmethod def get_count(self) -> int: @@ -95,8 +110,9 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): return {} -class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], - torch.Tensor]): +class EmbeddingItems( + ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor] +): """ Base class for data items that are expressed as a batched embedding tensor, or a list of embedding tensors (one per item). @@ -118,8 +134,9 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], return len(self.get(item_idx)) -class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor]]): +class DictEmbeddingItems( + ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]] +): """ Base class for data items that are expressed as a dictionary of tensors. 
@@ -143,8 +160,10 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], missing_required_data_keys = required_fields - data.keys() if missing_required_data_keys: data_keys = set(data.keys()) - msg = (f"The data should contain the fields: {required_fields}, " - f"but only found the following keys: {data_keys}") + msg = ( + f"The data should contain the fields: {required_fields}, " + f"but only found the following keys: {data_keys}" + ) raise ValueError(msg) fields_config = fields_factory(data) @@ -176,8 +195,9 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - - def __init__(self, data: Sequence[HfAudioItem]) -> None: + def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "audio") def get_audio_length(self, item_idx: int) -> int: @@ -186,7 +206,6 @@ class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): class AudioEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "audio") @@ -197,8 +216,9 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - - def __init__(self, data: Sequence[HfImageItem]) -> None: + def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "image") def get_image_size(self, item_idx: int) -> ImageSize: @@ -214,19 +234,20 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): class ImageEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "image") class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): - def __init__( self, - data: Sequence[HfVideoItem], - metadata: Optional[Union[dict[str, Any], - list[Optional[dict[str, Any]]]]] = None, + data: Optional[Sequence[HfVideoItem]], + metadata: Optional[ + Union[dict[str, Any], list[Optional[dict[str, Any]]]] + ] = None, ) -> None: + if data is None: + data = [None] super().__init__(data, "video") self.metadata = metadata @@ -246,7 +267,6 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): class VideoEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "video") @@ -270,8 +290,10 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): if modality not in self: if strict: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) return 0 @@ -292,20 +314,25 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): """ if modality not in self: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) items = self[modality] if not isinstance(items, typ): - raise TypeError(f"Invalid type of data items for {modality=}. " - f"Expected type: {typ}, but " - f"found type: {type(items)}") + raise TypeError( + f"Invalid type of data items for {modality=}. 
" + f"Expected type: {typ}, but " + f"found type: {type(items)}" + ) return items # type: ignore[return-value] -ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], - Optional[ModalityDataItems[Any, Any]]] +ModalityDataParser: TypeAlias = Callable[ + [ModalityData[Any]], Optional[ModalityDataItems[Any, Any]] +] class MultiModalDataParser: @@ -334,7 +361,7 @@ class MultiModalDataParser: self.video_needs_metadata = video_needs_metadata def _is_embeddings( - self, data: object + self, data: object ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]: if isinstance(data, torch.Tensor): return data.ndim == 3 @@ -385,18 +412,24 @@ class MultiModalDataParser: self, data: ModalityData[AudioItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return AudioProcessorItems(None) + # also check single audio item with sampling rate - if self._is_empty(data) or (isinstance(data, tuple) - and self._is_empty(data[0])): + if self._is_empty(data) or ( + isinstance(data, tuple) and self._is_empty(data[0]) + ): return None if self._is_embeddings(data): return AudioEmbeddingItems(data) - if (is_list_of(data, float) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 1 - or isinstance(data, tuple)): + if ( + is_list_of(data, float) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 1 + or isinstance(data, tuple) + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -409,8 +442,7 @@ class MultiModalDataParser: if orig_sr is None: new_audio = audio else: - new_audio = self.audio_resampler.resample(audio, - orig_sr=orig_sr) + new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr) new_audios.append(new_audio) @@ -420,15 +452,20 @@ class MultiModalDataParser: self, data: ModalityData[ImageItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return ImageProcessorItems(None) + if self._is_empty(data): return None if self._is_embeddings(data): return ImageEmbeddingItems(data) - if (isinstance(data, PILImage.Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 3): + if ( + isinstance(data, PILImage.Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 3 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -441,15 +478,20 @@ class MultiModalDataParser: self, data: ModalityData[VideoItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return VideoProcessorItems(None) + if self._is_empty(data): return None if self._is_embeddings(data): return VideoEmbeddingItems(data) - if (is_list_of(data, PILImage.Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 4): + if ( + is_list_of(data, PILImage.Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 4 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -480,8 +522,7 @@ class MultiModalDataParser: "video": self._parse_video_data, } - def parse_mm_data(self, - mm_data: MultiModalDataDict) -> MultiModalDataItems: + def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: subparsers = self._get_subparsers() mm_items = MultiModalDataItems() diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 55fd1479d2de5..5c3739e29d101 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,39 +1,61 @@ # SPDX-License-Identifier: Apache-2.0 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, - Sequence) -from dataclasses import dataclass, field +from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache -from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union, cast) +from typing import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + Optional, + Protocol, + Union, + cast, + overload, +) import regex as re import torch -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never -from vllm.inputs import InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, - encode_tokens) -from vllm.utils import flatten_2d_lists, full_groupby +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens +from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides +from vllm.utils.jsontree import JSONTree, json_map_leaves -from .cache import MultiModalCache from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, - MultiModalKwargsItem, MultiModalKwargsItems, - PlaceholderRange) -from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, - MultiModalDataParser) +from .inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalKwargsOptionalItems, + MultiModalUUIDDict, + PlaceholderRange, +) +from .parse import ( + DictEmbeddingItems, + EmbeddingItems, + MultiModalDataItems, + MultiModalDataParser, +) if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from vllm.config import ModelConfig + + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -44,14 +66,59 @@ PromptSeq = Union[str, list[int]] """A token sequence (list of token IDs) or text.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + return encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens) + + +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + return decode_tokens( + tokenizer, list(token_ids), skip_special_tokens=skip_special_tokens + ) + + +def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: + if isinstance(seq, str): + return seq + + return _cached_decode(tokenizer, tuple(seq)) + + +def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: + if isinstance(seq, str): + return _cached_encode(tokenizer, seq, add_special_tokens=False) + + return seq + + +class _GetMatchIndex(Protocol): + def __call__( + self, + tokenizer: AnyTokenizer, + prompt: PromptSeq, + start_idx: int = 0, + ) -> 
Optional[int]: ... + + @dataclass class PromptIndex: """Resolves to an index in the prompt.""" - get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + + get_match_index: _GetMatchIndex class PromptIndexTargets: - @staticmethod def start() -> PromptIndex: """ @@ -59,7 +126,7 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: 0) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0) @staticmethod def prefix(seq: PromptSeq) -> PromptIndex: @@ -70,7 +137,11 @@ class PromptIndexTargets: def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, + start_idx: int = 0, ) -> Optional[int]: + if start_idx != 0: + return None + prefix = seq if isinstance(prompt, str): @@ -80,9 +151,7 @@ class PromptIndexTargets: else: if isinstance(prefix, str): # Make both `list[int]` - prefix = encode_tokens(tokenizer, - prefix, - add_special_tokens=False) + prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False) match_idx = len(prefix) return match_idx if prompt[:match_idx] == prefix else None @@ -96,14 +165,24 @@ class PromptIndexTargets: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: len(prompt)) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -PromptTarget = Union[PromptSeq, PromptIndex] +UpdateTarget = Union[PromptSeq, PromptIndex] """ The token sequence or text to update. """ +PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +""" +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. +""" + @dataclass class PromptUpdateDetails(Generic[_S]): @@ -112,7 +191,7 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -133,12 +212,12 @@ class PromptUpdateDetails(Generic[_S]): seq: _S, embed_text: str, ) -> "PromptUpdateDetails[_S]": - - def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_text) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + embed_token_ids = encode_tokens(tokenizer, embed_text) + token_ids = _seq2tokens(tokenizer, full) return torch.isin( - torch.tensor(full.token_ids), + torch.tensor(token_ids), torch.tensor(embed_token_ids), ) @@ -149,10 +228,12 @@ class PromptUpdateDetails(Generic[_S]): seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails( - full=seq, - is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, - ) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.tensor(token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -164,8 +245,7 @@ use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to specify which part. 
""" -PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], - PromptUpdateInfo] +PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateInfo] """ Given the index of the processed item within [`modality`][vllm.multimodal.processing.PromptUpdate.modality], @@ -190,7 +270,7 @@ class PromptUpdate(ABC): modality: str """The modality for which the update is made.""" - target: PromptTarget + target: PromptUpdateTarget """The token sequence (or text) to update.""" @property @@ -205,10 +285,35 @@ class PromptUpdate(ABC): """Defines how to update the prompt.""" raise NotImplementedError - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": - return BoundPromptUpdate( - _origin=self, - tokenizer=tokenizer, + def _resolve_target(self, item_idx: int) -> UpdateTarget: + target = self.target + if callable(target): + target = target(item_idx) + + return target + + def _resolve_content(self, item_idx: int) -> PromptUpdateDetails: + content = self.content + if callable(content): + content = content(item_idx) + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return content + + def resolve(self, item_idx: int) -> "ResolvedPromptUpdate": + """ + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output a copy of this object with its lazy attributes resolved. + """ + return ResolvedPromptUpdate( + modality=self.modality, + item_idx=item_idx, + mode=self.mode, + target=self._resolve_target(item_idx), + content=self._resolve_content(item_idx), ) @@ -309,11 +414,13 @@ class PromptReplacement(PromptUpdate): modality="image", target="<image>", replacement=PromptUpdateDetails( - full="".join([ - "<image_bos>", - "<image>" * image_feature_size, - "<image_eos>", - ]), + full="".join( + [ + "<image_bos>", + "<image>" * image_feature_size, + "<image_eos>", + ] + ), features="<image>" * image_feature_size, ), ) @@ -327,8 +434,9 @@ class PromptReplacement(PromptUpdate): modality="image", target=[image_token_id], replacement=PromptUpdateDetails( - full=([image_bos_id] + [image_token_id] * image_feature_size - + [image_eos_id]), + full=( + [image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id] + ), features=[image_token_id] * image_feature_size, ), ) @@ -355,39 +463,13 @@ class PromptReplacement(PromptUpdate): return UpdateMode.REPLACE -@lru_cache(maxsize=2048) -def _cached_encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: Optional[bool] = None, -) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) - - -@lru_cache(maxsize=2048) -def _cached_decode( - tokenizer: AnyTokenizer, - token_ids: tuple[int, ...], - *, - skip_special_tokens: Optional[bool] = None, -) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) - - class _HasModalityAttr(Protocol): modality: str class _HasModalityProp(Protocol): - @property - def modality(self) -> str: - ... + def modality(self) -> str: ... _M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) @@ -399,126 +481,98 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: return full_groupby(values, key=lambda x: x.modality) -@dataclass -class _BoundPromptSequence: - """ - A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound - to a tokenizer to automatically - convert between token sequence and text representations. 
- """ - tokenizer: AnyTokenizer = field(repr=False) +class PromptTargetMatch(NamedTuple): + start_idx: int + end_idx: int - _text: Optional[str] - _token_ids: Optional[list[int]] - @staticmethod - def from_seq( +@dataclass(frozen=True) +class ResolvedPromptUpdate: + """ + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its + lazy attributes resolved, apart from those related to tokenization. + """ + + modality: str + """The modality for which the update is made.""" + + item_idx: int + """The index within `modality` of the item this update pertains to.""" + + mode: UpdateMode + """Defines how to update the prompt.""" + + target: UpdateTarget + """The token sequence (or text) to update.""" + + content: PromptUpdateDetails = field(repr=False) + """The placeholder tokens that are part of the update.""" + + def iter_token_matches( + self, + prompt: list[int], tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) - - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") - - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) - - return self._text - - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) - - return self._token_ids - - -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] - - -@dataclass -class BoundPromptUpdate: - """ - A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound - to a tokenizer to automatically convert - [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of - [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] - between token sequence and text representations. 
- """ - _origin: PromptUpdate - tokenizer: AnyTokenizer = field(repr=False) - - def __post_init__(self) -> None: - self._content_cache = dict[int, _BoundPromptContent]() - - @property - def modality(self) -> str: - return self._origin.modality - - @property - def target(self) -> Union[_BoundPromptSequence, PromptIndex]: - """The token sequence (or text) to update.""" - target = self._origin.target + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target if isinstance(target, PromptIndex): - return target + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) - return _BoundPromptSequence.from_seq(self.tokenizer, target) + return - @property - def content(self) -> PromptUpdateContent: - """The placeholder tokens that are part of the update.""" - return self._origin.content + target_token_ids = _seq2tokens(tokenizer, target) - @property - def mode(self) -> UpdateMode: - """Defines how to update the prompt.""" - return self._origin.mode + for match in iter_token_matches(prompt, target_token_ids, start_idx=start_idx): + yield PromptTargetMatch(match.start_idx, match.end_idx) - def get_content(self, item_idx: int) -> _BoundPromptContent: - """ - Given the index of the processed item within - [`modality`][vllm.multimodal.processing.PromptUpdate.modality], - output the token sequence (or text) to update. - """ - content = self.content - if callable(content): - cache_key = item_idx - if cache_key in self._content_cache: - return self._content_cache[cache_key] + def iter_text_matches( + self, + prompt: str, + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target - content = content(item_idx) - else: - cache_key = None + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) + return + + target_text = _seq2text(tokenizer, target) + + for match in re.finditer(re.escape(target_text), prompt, pos=start_idx): + yield PromptTargetMatch(match.start(), match.end()) + + def iter_matches( + self, + prompt: Union[list[int], str], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + if isinstance(prompt, str): + return self.iter_text_matches(prompt, tokenizer, start_idx=start_idx) + + return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): if not isinstance(content, PromptUpdateDetails): content = PromptUpdateDetails.from_seq(content) - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_content = _BoundPromptContent(full=bound_full, - is_embed=content.is_embed) - - if cache_key is not None: - self._content_cache[cache_key] = bound_content - - return bound_content + return replace(self, content=content) class _TokenMatch(NamedTuple): @@ -529,6 +583,8 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], + *, + start_idx: int = 0, ) -> Generator[_TokenMatch]: """ Yield each occurrence of `match_ids` in `token_ids`. 
@@ -541,7 +597,6 @@ def iter_token_matches( if match_len == 0: return - start_idx = 0 while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len @@ -581,68 +636,6 @@ def replace_token_matches( return flatten_2d_lists(out_seqs) -@dataclass(repr=False) -class PromptTargetMatch(ABC): - _origin: BoundPromptUpdate - - @property - def modality(self) -> str: - return self._origin.modality - - @property - @abstractmethod - def start_idx(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def end_idx(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") - - -@dataclass(repr=False) -class _PromptTargetIndexMatch(PromptTargetMatch): - match_idx: int - - @property - def start_idx(self) -> int: - return self.match_idx - - @property - def end_idx(self) -> int: - return self.match_idx - - -@dataclass(repr=False) -class _PromptTargetTokenMatch(PromptTargetMatch): - match: _TokenMatch - - @property - def start_idx(self) -> int: - return self.match.start_idx - - @property - def end_idx(self) -> int: - return self.match.end_idx - - -@dataclass(repr=False) -class _PromptTargetTextMatch(PromptTargetMatch): - match: re.Match[str] - - @property - def start_idx(self) -> int: - return self.match.start() - - @property - def end_idx(self) -> int: - return self.match.end() - - @dataclass class PlaceholderFeaturesInfo: modality: str @@ -665,163 +658,159 @@ class PlaceholderFeaturesInfo: ) -def find_token_matches( - prompt: list[int], - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTokenMatch(update, match) - for match in iter_token_matches(prompt, target.token_ids) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] +_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] -def find_text_matches( - prompt: str, - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" +def _find_matches( + prompt: _S, + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, + *, + prev_end_idx: int = 0, + current_result: "MultiModalPromptUpdatesApplyResult", +) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: + mode: Optional[UpdateMode] = None + mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() - def get_matches(update: BoundPromptUpdate): - target = update.target + for modality, modality_updates in mm_prompt_updates.items(): + for item_idx, item_updates in enumerate(modality_updates): + if current_result[modality][item_idx] is not None: + continue # Updates have already been applied for this item - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] + for update_idx, update in enumerate(item_updates): + if (modality, item_idx) in mm_matches: + break # Already found a match for this item - return [_PromptTargetIndexMatch(update, match_idx)] + for match in update.iter_matches( + prompt, + tokenizer, + 
start_idx=prev_end_idx, + ): + # All matches should share the same mode + if mode is None: + mode = update.mode + elif mode != update.mode: + continue - return [ - _PromptTargetTextMatch(update, match) - for match in re.finditer(re.escape(target.text), prompt) - ] + mm_matches[(modality, item_idx)] = match, update_idx + break # Get only the first valid match per item - return [ - match for update in prompt_updates for match in get_matches(update) - ] + # Prioritize earlier matches + matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0]) + # To avoid conflicts, only replace one non-empty item at a time + if mode == UpdateMode.REPLACE: + matches_to_apply_ = list[_MatchToApply]() + has_non_empty_matches = False -def _resolve_matches( - prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], -) -> list[PromptTargetMatch]: - """ - Resolve `mm_matches` to ensure that there are no overlapping matches, - and sort them such that earlier matches take priority over later ones. - """ - matches = [m for matches in mm_matches.values() for m in matches] + for item in matches_to_apply: + _, (match, _) = item + if match.start_idx == match.end_idx: + matches_to_apply_.append(item) + elif not has_non_empty_matches: + has_non_empty_matches = True + matches_to_apply_.append(item) - seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) + matches_to_apply = matches_to_apply_ - for match in matches: - for idx in range(match.start_idx, match.end_idx): - if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") - - seen_matches[idx] = match - - return sorted(matches, key=lambda x: x.start_idx) + return mode, matches_to_apply def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[_S]: - """Apply the updates in `mm_matches` to `prompt`.""" + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: + prompt_len = len(prompt) + out_seqs = list[Union[str, list[int]]]() - prev_end_idx = 0 - next_idx_by_modality = defaultdict[str, int](lambda: 0) + out_result: MultiModalPromptUpdatesApplyResult = { + m: [None] * len(items) for m, items in mm_prompt_updates.items() + } - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + start_idx = prev_end_idx = 0 + while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt + found = False - item_start_idx = next_idx_by_modality[modality] - max_item_count = mm_item_counts.get(modality, 0) - if item_start_idx >= max_item_count: - continue + mode, matches_to_apply = _find_matches( + prompt, + mm_prompt_updates, + tokenizer, + prev_end_idx=prev_end_idx, + current_result=out_result, + ) - start_idx = match.start_idx - end_idx = match.end_idx - origin = match._origin - mode = origin.mode + if mode is not None: + for (modality, item_idx), (match, update_idx) in matches_to_apply: + found = True - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = max_item_count - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = max_item_count if start_idx == end_idx else 1 - else: - assert_never(mode) + matched_update = mm_prompt_updates[modality][item_idx][update_idx] + matched_content = matched_update.content.full - item_end_idx = min(item_start_idx + 
num_inserts, max_item_count) + if mode == UpdateMode.INSERT: + end_idx_to_insert = match.end_idx + elif mode == UpdateMode.REPLACE: + end_idx_to_insert = match.start_idx + else: + assert_never(mode) - for item_idx in range(item_start_idx, item_end_idx): - content = origin.get_content(item_idx) - insert_seq = (content.full.text if isinstance(prompt, str) else - content.full.token_ids) + out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) + out_seqs.append( + _seq2text(tokenizer, matched_content) + if isinstance(prompt, str) + else _seq2tokens(tokenizer, matched_content) + ) + out_result[modality][item_idx] = update_idx - out_seqs.append(insert_seq) + # Exclude overlapping matches + start_idx = prev_end_idx = match.end_idx - prev_end_idx = end_idx - next_idx_by_modality[modality] += item_end_idx - item_start_idx + if not found: + start_idx += 1 out_seqs.append(prompt[prev_end_idx:]) - return cast(list[_S], out_seqs) + return cast(list[_S], out_seqs), out_result def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[int]: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return flatten_2d_lists(token_id_seqs) + return flatten_2d_lists(token_id_seqs), result def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> str: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - texts = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return "".join(texts) + return "".join(texts), result def _iter_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. @@ -833,6 +822,8 @@ def _iter_placeholders( Note that empty matches are ignored. 
""" prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = defaultdict[str, int](lambda: 0) start_idx = 0 @@ -844,9 +835,9 @@ def _iter_placeholders( if item_idx >= mm_item_counts.get(modality, 0): continue - for update_info in modality_updates: - content = update_info.get_content(item_idx) - content_tokens_full = content.full.token_ids + for update in modality_updates[item_idx]: + content = update.content + content_tokens_full = _seq2tokens(tokenizer, content.full) content_len_full = len(content_tokens_full) end_idx_full = start_idx + content_len_full @@ -856,7 +847,7 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed(content.full) + content_is_embed = content_is_embed(tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -880,27 +871,231 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) return dict(full_groupby_modality(it)) -class ProcessingCache(MultiModalCache): - - def __init__(self, capacity_gb: float) -> None: - super().__init__() - - self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) - - self.get = self._cache.get - self.put = self._cache.put - self.reset = self._cache.clear +_T = TypeVar("_T") +_C = TypeVar("_C", bound="PretrainedConfig", default="PretrainedConfig") +_P = TypeVar("_P", bound="ProcessorMixin", default="ProcessorMixin") -_CacheItemOrHash = Union[MultiModalKwargsItem, str] +@dataclass(frozen=True) +class InputProcessingContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + @overload + def get_hf_config(self, /) -> "PretrainedConfig": ... + + @overload + def get_hf_config( + self, + typ: Union[type[_C], tuple[type[_C], ...]], + /, + ) -> _C: ... + + def get_hf_config( + self, + typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + /, + ) -> Any: + """ + Get the HuggingFace configuration + (`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the configuration is not of the specified type. + """ + if typ is None: + from transformers.configuration_utils import PretrainedConfig + + typ = PretrainedConfig + + hf_config = self.model_config.hf_config + if not isinstance(hf_config, typ): + raise TypeError( + "Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}" + ) + + return hf_config + + def get_hf_image_processor_config(self) -> dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + return self.model_config.hf_image_processor_config + + def get_mm_config(self): + """ + Get the multimodal config of the model. + + Raises: + RuntimeError: If the model is not a multimodal model. 
+ """ + mm_config = self.model_config.multimodal_config + if mm_config is None: + raise RuntimeError("Not a multimodal model") + + return mm_config + + @overload + def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin": ... + + @overload + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]], + /, + **kwargs: object, + ) -> _P: ... + + def get_hf_processor( + self, + typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + /, + **kwargs: object, + ) -> Any: + """ + Get the HuggingFace processor + (`transformers.ProcessorMixin`) of the model, + additionally checking its type. + + Raises: + TypeError: If the processor is not of the specified type. + """ + if typ is None: + from transformers.processing_utils import ProcessorMixin + + typ = ProcessorMixin + + return cached_processor_from_config( + self.model_config, + processor_cls=typ, + tokenizer=self.tokenizer, + **kwargs, + ) + + def init_processor( + self, + typ: type[_T], + /, + **kwargs: object, + ) -> _T: + """ + Initialize a HuggingFace-like processor class, merging the + keyword arguments with those in the model's configuration. + """ + mm_config = self.model_config.get_multimodal_config() + base_kwargs = mm_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + return typ(**merged_kwargs) + + def _postprocess_output( + self, + output: JSONTree, + ) -> JSONTree: + def _postprocess_one(x: object): + if isinstance(x, torch.Tensor): # noqa: SIM102 + # This mimics the behavior of transformers.BatchFeature + if x.is_floating_point(): + x = x.to(dtype=self.model_config.dtype) + + return x + + return json_map_leaves(_postprocess_one, output) + + def call_hf_processor( + self, + hf_processor: "ProcessorMixin", + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, + *, + num_tries: int = 1, + max_tries: int = 5, + ) -> Union["BatchFeature", JSONTree]: + """ + Call `hf_processor` on the prompt `data` + (text, image, audio...) with configurable options `kwargs`. + """ + assert callable(hf_processor) + + mm_config = self.model_config.get_multimodal_config() + merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs) + + allowed_kwargs = get_allowed_kwarg_only_overrides( + hf_processor, + merged_kwargs, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + try: + output = hf_processor(**data, **allowed_kwargs, return_tensors="pt") + except Exception as exc: + # See https://github.com/huggingface/tokenizers/issues/537 + if ( + isinstance(exc, RuntimeError) + and exc + and exc.args[0] == "Already borrowed" + and num_tries < max_tries + ): + logger.warning( + "Failed to acquire tokenizer in current thread. " + "Retrying (%d/%d)...", + num_tries, + max_tries, + ) + time.sleep(0.5) + return self.call_hf_processor( + hf_processor, + data, + kwargs, + num_tries=num_tries + 1, + max_tries=max_tries, + ) + + msg = ( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={allowed_kwargs}" + ) + + raise ValueError(msg) from exc + + # this emulates output.to(dtype=self.model_config.dtype) + from transformers.feature_extraction_utils import BatchFeature + + if isinstance(output, BatchFeature): + output_ = self._postprocess_output(output.data) + return BatchFeature(output_) + + logger.warning_once( + "%s did not return `BatchFeature`. 
" + "Make sure to match the behaviour of `ProcessorMixin` when " + "implementing custom processors.", + type(hf_processor).__name__, + ) + + return self._postprocess_output(output) class BaseProcessingInfo: @@ -949,8 +1144,11 @@ class BaseProcessingInfo: for modality, supported_limit in supported_mm_limits.items(): user_limit = mm_config.get_limit_per_prompt(modality) - allowed_limits[modality] = (user_limit if supported_limit is None - else min(user_limit, supported_limit)) + allowed_limits[modality] = ( + user_limit + if supported_limit is None + else min(user_limit, supported_limit) + ) return allowed_limits @@ -961,7 +1159,7 @@ class BaseProcessingInfo: ) -> Optional[Mapping[str, int]]: """ Return the maximum number of tokens per item of for each modality. - + When `None` (the default) is returned, vLLM will generate dummy inputs (images/videos) at maximum possible sizes and process them to determine the maximum token count per modality. @@ -972,7 +1170,7 @@ class BaseProcessingInfo: counts, avoiding the need for dummy input generation and processing. Note: - The maximum number of tokens per item of each modality returned + The maximum number of tokens per item of each modality returned from this function should respect the model's maximum sequence length and the maximum number of items of each modality allowed, and agree with dummy inputs (images/videos) at maximum possible @@ -989,15 +1187,23 @@ A collection of hashes with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ -MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]] +MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]] """ A collection of prompt updates with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ +MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +""" +For an item `MultiModalPromptUpdates[k][i]`, +`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the +`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the +`ResolvedPromptUpdate` instances have been applied. +""" + class MultiModalProcessingInfo(NamedTuple): - kwargs: MultiModalKwargsItems + kwargs: MultiModalKwargsOptionalItems hashes: MultiModalHashes prompt_updates: MultiModalPromptUpdates @@ -1009,11 +1215,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. """ - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1039,8 +1247,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): prompt: str, mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1068,8 +1278,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): limit = min(supported_limit, allowed_limit) if num_items > limit: - msg = (f"At most {limit} {modality}(s) may be provided in " - "one prompt.") + msg = f"At most {limit} {modality}(s) may be provided in one prompt." 
if num_items <= supported_limit: msg += " Set `--limit-mm-per-prompt` to increase this limit." @@ -1088,7 +1297,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. """ mm_items = self.data_parser.parse_mm_data(mm_data) - for modality, items in mm_items.items(): self.validate_num_items(modality, len(items)) @@ -1126,14 +1334,61 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ raise NotImplementedError + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + mm_item_counts: Mapping[str, int], + ) -> MultiModalPromptUpdates: + return { + modality: [ + [update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0)) + ] + for modality, updates in full_groupby_modality(prompt_updates) + } + + def _get_mm_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> MultiModalPromptUpdates: + unbound_prompt_updates = self._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates, + mm_items.get_all_counts(), + ) + + for modality, prompt_updates in mm_prompt_updates.items(): + for item_idx, item_prompt_updates in enumerate(prompt_updates): + if len(item_prompt_updates) > 1: + logger.warning_once( + "Detected %d prompt updates for `mm_items[%r][%s]`. " + "Multiple prompt updates per item is now " + "deprecated and may be removed in v0.13. " + "Instead, please specify dynamic update targets " + "in the same prompt update definition by passing " + "a function to `PromptUpdate.target`.", + len(prompt_updates), + modality, + item_idx, + ) + + return mm_prompt_updates + def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + tokenizer = self.info.get_tokenizer() + + return find_mm_placeholders(new_token_ids, mm_prompt_updates, tokenizer) def _get_hf_mm_data( self, @@ -1183,7 +1438,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ return not any( isinstance(items, (EmbeddingItems, DictEmbeddingItems)) - for items in mm_items.values()) + for items in mm_items.values() + ) def _apply_hf_processor_text_mm( self, @@ -1208,7 +1464,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ) processed_data.update(passthrough_data) - prompt_ids, = processed_data.pop("input_ids").tolist() + (prompt_ids,) = processed_data.pop("input_ids").tolist() is_update_applied = self._hf_processor_applies_updates( prompt_text=prompt_text, @@ -1311,8 +1567,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs=tokenization_kwargs, ) - prompt_ids = self._apply_hf_processor_text_only( - prompt, tokenization_kwargs) + prompt_ids = self._apply_hf_processor_text_only(prompt, tokenization_kwargs) else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) @@ -1324,74 +1579,167 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return prompt_ids, mm_processed_data, False - def _get_cache_missing_items( - self, - cache: ProcessingCache, - mm_data_items: MultiModalDataItems, - mm_hashes: MultiModalHashes, - ) -> tuple[dict[str, 
list[_CacheItemOrHash]], MultiModalDataItems]: - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = { - modality: [(h if (v := cache.get(h)) is None else v) - for h in hashes] - for modality, hashes in mm_hashes.items() - } - - mm_missing_idxs = { - modality: [ - idx for idx, item_or_hash in enumerate(items_or_hashes) - if isinstance(item_or_hash, str) - ] - for modality, items_or_hashes in mm_cache_items_or_hashes.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - - return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data) - def _hash_mm_items( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalHashes: - """Create MM hashes to be returned (only used in V1).""" + """Create MM hashes to be returned. + + + Note: When overrides are provided via callers of `apply`, + `_hash_mm_items` will be bypassed and the overrides will be used. + """ model_id = self.info.model_id - return { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs, - **tokenization_kwargs) - for item in items - ] - for modality, items in mm_items.items() + hashes: MultiModalHashes = {} + mm_uuids = mm_uuids or {} + + for modality, items in mm_items.items(): + if modality in mm_uuids: + mm_uuids_per_modality = mm_uuids[modality] + if isinstance(mm_uuids_per_modality, str): + mm_uuids_per_modality = [mm_uuids_per_modality] + + # For None entries, compute a hash; otherwise, use provided ID. + computed: list[str] = [] + for i, item in enumerate(items): + item_uuid = mm_uuids_per_modality[i] + + # NOTE: Even if a item_uuid is provided, we still compute a + # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` + # are provided. This is because the processed multimodal + # inputs can be different depending on the processor kwargs. + if ( + item_uuid is None + or hf_processor_mm_kwargs + or tokenization_kwargs + ): + # NOTE: use provided hash string to hash with kwargs + # if available for better performance. + item = item_uuid if item_uuid is not None else item + computed.append( + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs, + ) + ) + else: + computed.append(item_uuid) + hashes[modality] = computed + else: + hashes[modality] = [ + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs, + ) + for item in items + ] + + return hashes + + def _get_cache_missing_items( + self, + cache: "BaseMultiModalProcessorCache", + mm_data_items: MultiModalDataItems, + mm_hashes: MultiModalHashes, + ) -> MultiModalDataItems: + mm_is_cached = { + modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } + mm_missing_idxs = { + modality: [ + idx + for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached + ] + for modality, items_is_cached in mm_is_cached.items() + } + mm_missing_data = {} + for modality, idxs in mm_missing_idxs.items(): + missing_modality_data = [] + for idx in idxs: + data = mm_data_items[modality][idx] + if data is None: + raise ValueError( + f"Cache miss for {modality} at index {idx} " + f"but data is not provided." 
+ ) + else: + missing_modality_data.append(data) + mm_missing_data[modality] = missing_modality_data + + return self._to_mm_items(mm_missing_data) + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. + """ + return replace(cached_update, item_idx=new_item_idx) + def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, mm_missing_kwargs: MultiModalKwargsItems, - ) -> MultiModalKwargsItems: + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() + } + mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) - for modality, items_or_hashes in mm_cache_items_or_hashes.items(): - for item_or_hash in items_or_hashes: - if isinstance(item_or_hash, str): - kw_item = mm_missing_kwargs[modality][ - mm_missing_next_idx[modality]] - cache.put(item_or_hash, kw_item) + merged_kwargs = defaultdict[str, list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]]( + list + ) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get(modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] + mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - kw_item = item_or_hash + item = None - merged_items[modality].append(kw_item) + kwargs, updates = cache.get_and_update_item(item, item_hash) - return MultiModalKwargsItems(merged_items) + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append( + [ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ] + ) + + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) + + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1399,6 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, @@ -1414,20 +1764,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_processed_data, - self._get_mm_fields_config(mm_processed_data, - hf_processor_mm_kwargs), + self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), ) - mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) + # Use overrides if provided; fallback to data-dependent hashing. 
+ mm_hashes = self._hash_mm_items( + mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) - unbound_prompt_updates = self._get_prompt_updates( + mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, @@ -1443,6 +1795,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, @@ -1457,14 +1811,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) - mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - ( - mm_cache_items_or_hashes, - mm_missing_data_items, - ) = self._get_cache_missing_items( + mm_hashes = self._hash_mm_items( + mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, @@ -1487,23 +1844,23 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_processed_data, - self._get_mm_fields_config(mm_missing_processed_data, - hf_processor_mm_kwargs), + self._get_mm_fields_config( + mm_missing_processed_data, hf_processor_mm_kwargs + ), ) - mm_kwargs = self._merge_mm_kwargs( - cache, - mm_cache_items_or_hashes=mm_cache_items_or_hashes, - mm_missing_kwargs=mm_missing_kwargs, - ) - - unbound_prompt_updates = self._get_prompt_updates( - mm_data_items, + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, hf_processor_mm_kwargs, - mm_kwargs, + mm_missing_kwargs, + ) + + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, @@ -1513,47 +1870,33 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return prompt_ids, mm_info, is_update_applied - def _bind_and_group_updates( - self, - prompt_updates: Sequence[PromptUpdate], - ) -> dict[str, Sequence[BoundPromptUpdate]]: - tokenizer = self.info.get_tokenizer() - - it = (update.bind(tokenizer) for update in prompt_updates) - return dict(full_groupby_modality(it)) - def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - return apply_token_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_token_matches(prompt, mm_prompt_updates, tokenizer) def _apply_text_matches( self, prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> str: - return apply_text_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[str, 
MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_text_matches(prompt, mm_prompt_updates, tokenizer) def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() - mm_token_matches = { - modality: find_token_matches(token_ids, updates) - for modality, updates in mm_prompt_updates.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -1565,53 +1908,43 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # Since it is inefficient to search for all possible tokenizations # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. - if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = self._apply_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, + if not all( + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values() + ): + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, ) - text = decode_tokens(tokenizer, token_ids) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, updates) - for modality, updates in mm_prompt_updates.items() - } - text = self._apply_text_matches( - text, - mm_text_matches, - mm_item_counts, + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } + matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + assert update_idx is not None, ( + "Failed to apply prompt replacement for " + f"mm_items[{modality!r}][{item_idx}]" + ) + + matched_updates[modality].append( + [mm_prompt_updates[modality][item_idx][update_idx]] + ) placeholders = self._find_mm_placeholders( - matched_updates, - token_ids, - mm_item_counts, + new_token_ids, + dict(matched_updates), ) - return token_ids, text, placeholders + return new_token_ids, placeholders def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): @@ -1625,7 +1958,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): "There is likely a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and 
`_get_mm_fields_config`).") + "`_call_hf_processor` and `_get_mm_fields_config`)." + ) + + def _validate_mm_updates( + self, + mm_updates: MultiModalPromptUpdates, + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + placeholders = mm_updates.get(modality, []) + + if len(placeholders) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} prompt updates " + f"corresponding to {item_count} {modality} items, but " + f"instead found {len(placeholders)} prompt updates! " + "This is likely because you forgot to include input " + "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) " + "in the prompt. If the model has a chat template, make " + "sure you have applied it before calling `LLM.generate`." + ) def _validate_mm_placeholders( self, @@ -1636,52 +1989,40 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): placeholders = mm_placeholders.get(modality, []) if len(placeholders) != item_count: - # NOTE: If you are a model developer, this can also arise from - # an inconsistency between `_call_hf_processor` and - # `_get_mm_fields_config` implementations raise RuntimeError( - f"Expected there to be {item_count} prompt updates " + f"Expected there to be {item_count} prompt placeholders " f"corresponding to {item_count} {modality} items, but " - f"instead found {len(placeholders)} prompt updates! " - "This is likely because you forgot to include input " - "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) " - "in the prompt. If the model has a chat template, make " - "sure you have applied it before calling `LLM.generate`.") + f"instead found {len(placeholders)} prompt placeholders! " + "Make sure the implementation of `_call_hf_processor` and " + "`_get_mm_fields_config` are consistent with each other." + ) def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, prompt_ids: list[int], - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + self._validate_mm_updates(mm_prompt_updates, mm_item_counts) if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - - tokenizer = self.info.get_tokenizer() - prompt = decode_tokens(tokenizer, prompt_ids) else: - ( - prompt_ids, - prompt, - mm_placeholders, - ) = self._apply_prompt_updates( + prompt_ids, mm_placeholders = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - return prompt_ids, prompt, mm_placeholders + return prompt_ids, mm_placeholders def apply( self, @@ -1689,6 +2030,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. 
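To make the new `mm_uuids` plumbing concrete: `apply` now accepts an optional `mm_uuids` mapping and forwards it to `_hash_mm_items`, which uses a caller-supplied ID verbatim only when no processor or tokenization kwargs could change the processed output. The sketch below reproduces that per-item decision under stated assumptions; `hash_item` is a hypothetical stand-in for the real hasher, not vLLM's API.

```python
# Rough sketch of the per-item ID-vs-hash decision added in _hash_mm_items.
# `hash_item` is a hypothetical stand-in for MultiModalHasher.hash_kwargs.
import hashlib
import pickle
from typing import Any, Optional


def hash_item(**kwargs: Any) -> str:
    """Hypothetical stand-in hasher; keys are sorted for determinism."""
    return hashlib.sha256(pickle.dumps(sorted(kwargs.items()))).hexdigest()


def resolve_item_id(
    modality: str,
    item: Any,
    item_uuid: Optional[str],
    hf_processor_mm_kwargs: dict[str, Any],
    tokenization_kwargs: dict[str, Any],
) -> str:
    # A caller-supplied ID is used verbatim only when no processor or
    # tokenization kwargs could alter the processed result.
    if item_uuid is not None and not hf_processor_mm_kwargs and not tokenization_kwargs:
        return item_uuid

    # Otherwise hash; if an ID exists it stands in for the raw item, so the
    # (potentially large) media data does not need to be hashed again.
    source = item_uuid if item_uuid is not None else item
    return hash_item(
        **{modality: source},
        **hf_processor_mm_kwargs,
        **tokenization_kwargs,
    )


# A provided ID is returned as-is without extra kwargs, but is re-hashed
# once processor kwargs are involved.
assert resolve_item_id("image", b"<pixels>", "img-123", {}, {}) == "img-123"
assert resolve_item_id("image", b"<pixels>", "img-123", {"do_resize": False}, {}) != "img-123"
```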
@@ -1717,10 +2060,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) # NOTE: tokenization_kwargs are not required to init processor - prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( + prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, prompt_ids=prompt_ids, mm_kwargs=mm_info.kwargs, @@ -1735,7 +2079,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return MultiModalInputs( type="multimodal", - prompt=prompt, prompt_token_ids=prompt_ids, mm_kwargs=mm_info.kwargs, mm_hashes=mm_info.hashes, @@ -1744,7 +2087,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): - @abstractmethod def create_encoder_prompt( self, @@ -1776,23 +2118,19 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): encoder_inputs: MultiModalInputs, ): tokenizer = self.info.get_tokenizer() - decoder_prompt = self.create_decoder_prompt(prompt, mm_data) - if isinstance(decoder_prompt, str): - decoder_prompt_ids = encode_tokens(tokenizer, - decoder_prompt, - add_special_tokens=False) + decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) + if isinstance(decoder_prompt_raw, str): + decoder_prompt_ids = encode_tokens( + tokenizer, decoder_prompt_raw, add_special_tokens=False + ) else: - decoder_prompt_ids = decoder_prompt - decoder_prompt = decode_tokens(tokenizer, decoder_prompt) + decoder_prompt_ids = decoder_prompt_raw mm_inputs = MultiModalEncDecInputs( - encoder_prompt=encoder_inputs["prompt"], encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], - **encoder_inputs) - mm_inputs.update({ - "prompt": decoder_prompt, - "prompt_token_ids": decoder_prompt_ids - }) + **encoder_inputs, + ) + mm_inputs["prompt_token_ids"] = decoder_prompt_ids return mm_inputs def apply( @@ -1801,6 +2139,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. 
@@ -1815,6 +2155,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, + mm_uuids=mm_uuids, ) return self._get_enc_dec_inputs( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2da9b4c72189a..05ba5a2abdd41 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -10,13 +10,26 @@ import numpy.typing as npt from PIL import Image import vllm.envs as envs +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) from vllm.logger import init_logger -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargsItems, - MultiModalPlaceholderDict) -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - EncDecMultiModalProcessor) +from .inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalPlaceholderDict, +) +from .processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + EncDecMultiModalProcessor, +) logger = init_logger(__name__) @@ -27,6 +40,7 @@ class ProcessorInputs: Represents the keyword arguments to [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][]. """ + prompt: Union[str, list[int]] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -73,10 +87,19 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: """ Build the multimodal input which, after processing, results in the maximum possible number of placeholder tokens. + + Args: + seq_len: Sequence length + mm_counts: Count of items per modality + mm_options: Configurable options per modality (optional). + If None, use model defaults for backward compatibility. + If provided, models can use these to customize dummy + data generation. """ raise NotImplementedError @@ -84,28 +107,49 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> ProcessorInputs: """ Build the input which, after processing, results in the maximum possible number of placeholder tokens. 
+ + Args: + seq_len: Sequence length + mm_counts: Count of items per modality + mm_options: Configurable options per modality (optional) """ dummy_text = self.get_dummy_text(mm_counts) - dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + + # Use the unified function for both legacy and configurable cases + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) + tokenization_kwargs = {"truncation": False} - return ProcessorInputs(prompt=dummy_text, - mm_data=dummy_mm_data, - tokenization_kwargs=tokenization_kwargs) + return ProcessorInputs( + prompt=dummy_text, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs, + ) def _get_dummy_audios( self, *, length: int, num_audios: int, + overrides: Optional[AudioDummyOptions] = None, ) -> list[npt.NDArray]: if num_audios == 0: return [] - audio = np.zeros((length, )) + if overrides and overrides.length: + if overrides.length > length: + logger.warning( + "audio.length override (%d) exceeds model's " + "maximum length (%d), will be ignored", + overrides.length, + length, + ) + length = min(length, overrides.length) + audio = np.zeros((length,)) return [audio] * num_audios def _get_dummy_images( @@ -114,9 +158,29 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): width: int, height: int, num_images: int, + overrides: Optional[ImageDummyOptions] = None, ) -> list[Image.Image]: if num_images == 0: return [] + if overrides: + if overrides.width: + if overrides.width > width: + logger.warning( + "image.width override (%d) exceeds model's " + "maximum width (%d), will be ignored", + overrides.width, + width, + ) + width = min(width, overrides.width) + if overrides.height: + if overrides.height > height: + logger.warning( + "image.height override (%d) exceeds model's " + "maximum height (%d), will be ignored", + overrides.height, + height, + ) + height = min(height, overrides.height) image = Image.new("RGB", (width, height), color=255) return [image] * num_images @@ -127,9 +191,38 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): height: int, num_frames: int, num_videos: int, + overrides: Optional[VideoDummyOptions] = None, ) -> list[npt.NDArray]: if num_videos == 0: return [] + if overrides: + if overrides.num_frames: + if overrides.num_frames > num_frames: + logger.warning( + "video.num_frames override (%d) exceeds model's " + "maximum number of frames (%d), will be ignored", + overrides.num_frames, + num_frames, + ) + num_frames = min(num_frames, overrides.num_frames) + if overrides.width: + if overrides.width > width: + logger.warning( + "video.width override (%d) exceeds model's " + "maximum width (%d), will be ignored", + overrides.width, + width, + ) + width = min(width, overrides.width) + if overrides.height: + if overrides.height > height: + logger.warning( + "video.height override (%d) exceeds model's " + "maximum height (%d), will be ignored", + overrides.height, + height, + ) + height = min(height, overrides.height) video = np.full((num_frames, width, height, 3), 255) return [video] * num_videos @@ -162,13 +255,15 @@ class MultiModalProfiler(Generic[_I]): self, seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalInputs: if mm_counts is None: mm_counts = self.get_mm_limits() factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( - seq_len, mm_counts) + seq_len, mm_counts, mm_options + ) return self.processor.apply( prompt=processor_inputs.prompt, @@ -185,9 +280,10 @@ class 
MultiModalProfiler(Generic[_I]): placeholders_by_modality = mm_inputs["mm_placeholders"] return { - modality: - sum(item.get_num_embeds() if mm_embeddings_only else item.length - for item in placeholders) + modality: sum( + item.get_num_embeds() if mm_embeddings_only else item.length + for item in placeholders + ) for modality, placeholders in placeholders_by_modality.items() } @@ -195,8 +291,9 @@ class MultiModalProfiler(Generic[_I]): self, seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> DummyEncoderData: - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) # For encoder-decoder models, use encoder prompt token ids instead of @@ -228,31 +325,19 @@ class MultiModalProfiler(Generic[_I]): self, seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> DummyDecoderData: - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options) prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) - # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - # `max_num_batched_tokens` is defined by `SchedulerConfig` - logger.warning_once( - "The sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501 - "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501 - "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501 - "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501 - seq_len, - total_len, - str(self._get_mm_num_tokens(mm_inputs)), - ) - if total_len < seq_len: prompt_token_ids.extend([0] * (seq_len - total_len)) return DummyDecoderData( prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_data=mm_inputs["mm_kwargs"].require_data(), multi_modal_placeholders=mm_inputs["mm_placeholders"], ) @@ -270,27 +355,10 @@ class MultiModalProfiler(Generic[_I]): mm_counts=mm_counts, ) if max_tokens_per_item is not None: - if mm_counts is None: - total_mm_tokens = sum(max_tokens_per_item.values()) - else: - total_mm_tokens = sum(max_tokens_per_item[k] * mm_counts[k] - for k in max_tokens_per_item.keys() - & mm_counts.keys()) - if total_mm_tokens > seq_len: - logger.warning_once( - "The sequence length (%d) is smaller than the pre-defined" - " worst-case total number of multimodal tokens (%d). " - "This may cause certain multi-modal inputs to fail during " - "inference. To avoid this, you should increase " - "`max_model_len` or reduce `mm_counts`.", - seq_len, - total_mm_tokens, - ) return max_tokens_per_item mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - return self._get_mm_num_tokens(mm_inputs, - mm_embeddings_only=mm_embeddings_only) + return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) def get_mm_max_contiguous_tokens( self, @@ -301,13 +369,11 @@ class MultiModalProfiler(Generic[_I]): Returns the maximum length of the multimodal (image placeholders+text) tokens, including any break/text tokens in-between image embeddings. 
- <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end> + `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>` Returns 9, even when the number of image embeddings is 6. - + This is important to take into account when profiling and initializing the encoder cache size. """ - return self._get_mm_max_tokens(seq_len, - mm_counts, - mm_embeddings_only=False) + return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ded56cca80999..a526eaff715ac 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,21 +2,27 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from functools import lru_cache from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn -from vllm.inputs import InputProcessingContext +from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) -from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, - DummyEncoderData, MultiModalProfiler) +from .cache import BaseMultiModalProcessorCache +from .processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, +) +from .profiling import ( + BaseDummyInputsBuilder, + DummyDecoderData, + DummyEncoderData, + MultiModalProfiler, +) if TYPE_CHECKING: from vllm.config import ModelConfig @@ -38,22 +44,20 @@ class ProcessingInfoFactory(Protocol[_I_co]): def __call__( self, ctx: InputProcessingContext, - ) -> _I_co: - ... + ) -> _I_co: ... -class DummyInputsBuilderFactory(Protocol[_I]): +class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc] """ Constructs a [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder] instance from the context. """ - def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: - ... + def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ... -class MultiModalProcessorFactory(Protocol[_I]): +class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc] """ Constructs a [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor] @@ -65,9 +69,8 @@ class MultiModalProcessorFactory(Protocol[_I]): info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, - ) -> BaseMultiModalProcessor[_I]: - ... + cache: Optional[BaseMultiModalProcessorCache] = None, + ) -> BaseMultiModalProcessor[_I]: ... 
@dataclass(frozen=True) @@ -80,59 +83,47 @@ class _ProcessorFactories(Generic[_I]): self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) -# Make sure a different cache is used for each model config -# NOTE: ModelConfig is not hashable so it cannot be passed directly -@lru_cache(maxsize=1) -def _get_processor_cache(model_id: str, capacity_gb: int): - return ProcessingCache(capacity_gb) if capacity_gb > 0 else None - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. """ def __init__(self) -> None: - self._processor_factories = ClassRegistry[nn.Module, - _ProcessorFactories]() + self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _get_processor_cache(self, model_config: "ModelConfig"): - model_id = model_config.model - capacity_gb = model_config.mm_processor_cache_gb - return _get_processor_cache(model_id, capacity_gb) - - def reset_processor_cache(self, model_config: "ModelConfig") -> bool: - """Reset the multi-modal processing cache.""" - if processor_cache := self._get_processor_cache(model_config): - processor_cache.reset() - - return True # Success - - def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool: - """Whether the multi-modal input cache should be enabled. - NOTE: This is put under MultiModalRegistry on purpose to respect - text-only mode for multimodal models. + def _extract_mm_options( + self, + model_config: "ModelConfig", + ) -> Optional[Mapping[str, BaseDummyOptions]]: """ + Extract multimodal dummy options from model config. - if not self.supports_multimodal_inputs(model_config): - return False + Returns None if no configurable options are found, otherwise returns + a mapping of modality names to their dummy options. + """ + if not model_config.multimodal_config: + return None - mm_config = model_config.get_multimodal_config() + mm_options = { + m: opt + for m in model_config.multimodal_config.limit_per_prompt + if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None + } - return mm_config.mm_processor_cache_gb > 0 + return mm_options if len(mm_options) > 0 else None def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. - Returns True if the model is multimodal with any non-zero supported - modalities, otherwise returns False, effectively running in + Returns True if the model is multimodal with any non-zero supported + modalities, otherwise returns False, effectively running in text-only mode. """ if not model_config.is_multimodal_model: @@ -145,11 +136,13 @@ class MultiModalRegistry: # Check if all supported modalities have limit == 0 if all( - mm_config.get_limit_per_prompt(modality) == 0 - for modality in supported_modalities): + mm_config.get_limit_per_prompt(modality) == 0 + for modality in supported_modalities + ): logger.info_once( "All limits of multimodal modalities supported by the model " - "are set to 0, running in text-only mode.") + "are set to 0, running in text-only mode." 
+ ) return False return True @@ -157,6 +150,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -165,23 +160,22 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) - profiler = MultiModalProfiler(processor) + processor = self.create_processor(model_config, cache=cache) + profiler: MultiModalProfiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, - { - modality: 1 - for modality, limit in mm_limits.items() if limit > 0 - }, + {modality: 1 for modality, limit in mm_limits.items() if limit > 0}, ) def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -192,41 +186,23 @@ class MultiModalRegistry: This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } - def get_max_tokens_by_modality( - self, - model_config: "ModelConfig", - ) -> Mapping[str, int]: - """ - Get the maximum number of tokens from each modality - for profiling the memory usage of a model. - """ - mm_limits = self.get_mm_limits_per_prompt(model_config) - - return { - key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() - } - - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: - """ - Get the maximum number of multi-modal tokens - for profiling the memory usage of a model. - """ - return sum(self.get_max_tokens_by_modality(model_config).values()) - def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -235,8 +211,8 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) - profiler = MultiModalProfiler(processor) + processor = self.create_processor(model_config, cache=cache) + profiler: MultiModalProfiler = MultiModalProfiler(processor) return profiler.get_mm_limits() def register_processor( @@ -259,7 +235,9 @@ class MultiModalRegistry: logger.warning( "Model class %s already has a multi-modal processor " "registered to %s. 
It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._processor_factories[model_cls] = _ProcessorFactories( info=info, @@ -303,7 +281,7 @@ class MultiModalRegistry: model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -311,15 +289,10 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if disable_cache is None: - disable_cache = not model_config.enable_mm_processor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = self._create_processing_ctx(model_config, tokenizer) - cache = None if disable_cache else self._get_processor_cache( - model_config) return factories.build_processor(ctx, cache=cache) @@ -328,22 +301,31 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) - profiler = MultiModalProfiler(processor) - dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) + processor = self.create_processor(model_config, cache=cache) + profiler: MultiModalProfiler = MultiModalProfiler(processor) + + # Extract configurable options from multimodal config. + # Only include modalities that use advanced option types so legacy + # count-only behavior remains unchanged. + mm_options = self._extract_mm_options(model_config) + + dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids if len(token_ids) < seq_len: raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(token_ids)} tokens instead.") + f"but found {len(token_ids)} tokens instead." + ) return dummy_data @@ -352,15 +334,23 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) - profiler = MultiModalProfiler(processor) - dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) + processor = self.create_processor(model_config, cache=cache) + profiler: MultiModalProfiler = MultiModalProfiler(processor) + + # Extract configurable options from multimodal config. + # Only include modalities that use advanced option types so legacy + # count-only behavior remains unchanged. 
+ mm_options = self._extract_mm_options(model_config) + + dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids @@ -372,3 +362,23 @@ class MultiModalRegistry: ) return dummy_data + + def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: + """ + Get the maximum length of the encoder input for encoder-decoder models. + """ + if not model_config.is_encoder_decoder: + return 0 + max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config) + if not max_tokens: + # TODO - this function assumes encoder-decoder models are + # multimodal. This will need to change when adding support for more + # than whisper. + return 0 + assert len(max_tokens) == 1, ( + "Encoder-decoder models are expected \ + to implement the multimodal interface with at most one modality." + ) + + first_modality = next(iter(max_tokens)) + return max_tokens[first_modality] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 834b2189e4bed..c9dc077d0385f 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,8 +3,6 @@ import asyncio import atexit -import itertools -import math from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from itertools import groupby @@ -17,13 +15,10 @@ import numpy as np import numpy.typing as npt import torch from PIL import Image, UnidentifiedImageError -from typing_extensions import deprecated import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.utils.jsontree import json_map_leaves from .audio import AudioMediaIO from .base import MediaIO @@ -33,35 +28,38 @@ from .video import VideoMediaIO _M = TypeVar("_M") if TYPE_CHECKING: - from .inputs import (BatchedTensorInputs, MultiModalKwargs, - MultiModalKwargsItem, MultiModalKwargsItems, - MultiModalPlaceholderDict) + from .inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalPlaceholderDict, + ) else: BatchedTensorInputs = Any - MultiModalKwargs = Any MultiModalKwargsItem = Any MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any global_thread_pool = ThreadPoolExecutor( - max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT) + max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT +) atexit.register(global_thread_pool.shutdown) class MediaConnector: - def __init__( self, media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None, connection: HTTPConnection = global_http_connection, *, allowed_local_media_path: str = "", + allowed_media_domains: Optional[list[str]] = None, ) -> None: """ Args: - media_io_kwargs: Additional args passed to process media - inputs, keyed by modalities. For example, - to set num_frames for video, set + media_io_kwargs: Additional args passed to process media + inputs, keyed by modalities. For example, + to set num_frames for video, set `--media-io-kwargs '{"video":{"num_frames":40}}'` connection: HTTP connection client to download media contents. 
allowed_local_media_path: A local directory to load media files @@ -69,8 +67,9 @@ class MediaConnector: """ super().__init__() - self.media_io_kwargs: dict[str, dict[ - str, Any]] = media_io_kwargs if media_io_kwargs else {} + self.media_io_kwargs: dict[str, dict[str, Any]] = ( + media_io_kwargs if media_io_kwargs else {} + ) self.connection = connection if allowed_local_media_path: @@ -79,21 +78,26 @@ class MediaConnector: if not allowed_local_media_path_.exists(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} does not exist.") + f"{allowed_local_media_path_} does not exist." + ) if not allowed_local_media_path_.is_dir(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} must be a directory.") + f"{allowed_local_media_path_} must be a directory." + ) else: allowed_local_media_path_ = None self.allowed_local_media_path = allowed_local_media_path_ + if allowed_media_domains is None: + allowed_media_domains = [] + self.allowed_media_domains = allowed_media_domains def _load_data_url( self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] data_spec, data = url_spec.path.split(",", 1) media_type, data_type = data_spec.split(";", 1) @@ -107,32 +111,51 @@ class MediaConnector: self, url_spec: ParseResult, media_io: MediaIO[_M], - ) -> _M: + ) -> _M: # type: ignore[type-var] allowed_local_media_path = self.allowed_local_media_path if allowed_local_media_path is None: - raise RuntimeError("Cannot load local files without " - "`--allowed-local-media-path`.") + raise RuntimeError( + "Cannot load local files without `--allowed-local-media-path`." + ) filepath = Path(url2pathname(url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " - f"of `--allowed-local-media-path` {allowed_local_media_path}.") + f"of `--allowed-local-media-path` {allowed_local_media_path}." + ) return media_io.load_file(filepath) + def _assert_url_in_allowed_media_domains(self, url_spec) -> None: + if ( + self.allowed_media_domains + and url_spec.hostname not in self.allowed_media_domains + ): + raise ValueError( + f"The URL must be from one of the allowed domains: " + f"{self.allowed_media_domains}. 
Input URL domain: " + f"{url_spec.hostname}" + ) + def load_from_url( self, url: str, media_io: MediaIO[_M], *, fetch_timeout: Optional[int] = None, - ) -> _M: + ) -> _M: # type: ignore[type-var] url_spec = urlparse(url) if url_spec.scheme.startswith("http"): + self._assert_url_in_allowed_media_domains(url_spec) + connection = self.connection - data = connection.get_bytes(url, timeout=fetch_timeout) + data = connection.get_bytes( + url, + timeout=fetch_timeout, + allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, + ) return media_io.load_bytes(data) @@ -156,22 +179,27 @@ class MediaConnector: loop = asyncio.get_running_loop() if url_spec.scheme.startswith("http"): + self._assert_url_in_allowed_media_domains(url_spec) + connection = self.connection - data = await connection.async_get_bytes(url, timeout=fetch_timeout) - future = loop.run_in_executor(global_thread_pool, - media_io.load_bytes, data) + data = await connection.async_get_bytes( + url, + timeout=fetch_timeout, + allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, + ) + future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data) return await future if url_spec.scheme == "data": - future = loop.run_in_executor(global_thread_pool, - self._load_data_url, url_spec, - media_io) + future = loop.run_in_executor( + global_thread_pool, self._load_data_url, url_spec, media_io + ) return await future if url_spec.scheme == "file": - future = loop.run_in_executor(global_thread_pool, - self._load_file_url, url_spec, - media_io) + future = loop.run_in_executor( + global_thread_pool, self._load_file_url, url_spec, media_io + ) return await future msg = "The URL must be either a HTTP, data or file URL." raise ValueError(msg) @@ -213,12 +241,13 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Load a PIL image from a HTTP or base64 data URL. + Load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) try: return self.load_from_url( @@ -237,12 +266,13 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Asynchronously load a PIL image from a HTTP or base64 data URL. + Asynchronously load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) try: return await self.load_from_url_async( @@ -261,12 +291,12 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Load video from a HTTP or base64 data URL. + Load video from an HTTP or base64 data URL. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) - video_io = VideoMediaIO(image_io, - **self.media_io_kwargs.get("video", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) + video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {})) return self.load_from_url( video_url, @@ -281,14 +311,14 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Asynchronously load video from a HTTP or base64 data URL. + Asynchronously load video from an HTTP or base64 data URL. 
By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) - video_io = VideoMediaIO(image_io, - **self.media_io_kwargs.get("video", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) + video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {})) return await self.load_from_url_async( video_url, @@ -310,7 +340,7 @@ class MediaConnector: def encode_audio_base64( audio: np.ndarray, - sampling_rate: float, + sampling_rate: int, ) -> str: """Encode audio as base64.""" audio_io = AudioMediaIO() @@ -339,7 +369,8 @@ def encode_video_base64(frames: npt.NDArray) -> str: def argsort_mm_positions( - mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]: + mm_positions: MultiModalPlaceholderDict, +) -> list[tuple[str, int]]: """ Given a `MultiModalPlaceholderDict`, output a sequence of keys to sort the dictionary by `offset` (starting index in the input sequence) @@ -349,339 +380,76 @@ def argsort_mm_positions( A list of `(modality, idx)`, which can be used to access an item by `mm_positions[modality][idx]`. """ - flat_items = ((modality, idx, item) - for modality, items in mm_positions.items() - for idx, item in enumerate(items)) + flat_items = ( + (modality, idx, item) + for modality, items in mm_positions.items() + for idx, item in enumerate(items) + ) sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset) return [(modality, idx) for modality, idx, _ in sorted_flat_items] -# Temporary back-compatibility for plugins that define model runner -@deprecated("`group_mm_inputs_by_modality` is superseded by " - "`group_mm_kwargs_by_modality` and will be removed in v0.13. " - "Please use `group_mm_kwargs_by_modality` instead.") -def group_mm_inputs_by_modality( - mm_inputs: list[MultiModalKwargsItems] -) -> list[list[MultiModalKwargsItems]]: - if not mm_inputs: - return [] - - def modality_group_func( - mm_input: MultiModalKwargsItems) -> Union[str, int]: - # If the input has multiple modalities, return a id as the unique key - # for the mm_input input. - if len(mm_input) > 1: - return id(mm_input) - - elif len(mm_input) == 1: - return next(iter(mm_input.keys())) - - # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, - # this is used to make InternVL with legacy pipeline still work with v1. - else: - return "" - - return [ - list(group) for _, group in groupby(mm_inputs, key=modality_group_func) - ] - - def group_mm_kwargs_by_modality( mm_kwargs: list[MultiModalKwargsItem], *, device: torch.types.Device = None, pin_memory: bool = False, + merge_by_field_config: Optional[bool] = None, ) -> Iterable[tuple[str, int, BatchedTensorInputs]]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. Args: - mm_inputs: List of `MultiModalKwargsItem`. + mm_kwargs: List of `MultiModalKwargsItem`. + device: The device to place the grouped tensors on. + pin_memory: Whether to pin memory for faster host-to-device transfer. Yields: A tuple `(modality, num_items, grouped_kwargs)`. """ + if merge_by_field_config is None: + raise RuntimeError( + "`group_mm_kwargs_by_modality` now requires " + "`merge_by_field_config` arg, please update your model runner " + "according to https://github.com/vllm-project/vllm/pull/25676." 
+ ) + from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): items_lst = list(items) - # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \ - # .get_data(pin_memory=pin_memory) + # TODO: Deprecate `merge_by_field_config` once + # we have migrated all in-tree models + if merge_by_field_config: + mm_kwargs_group: BatchedTensorInputs = dict( + MultiModalKwargsItems.from_seq(items_lst).get_data( + pin_memory=pin_memory + ) + ) - # if device is not None: - # mm_kwargs_group = json_map_leaves( - # lambda x: x.to(device=device), - # mm_kwargs_group, - # ) - - # TODO: Once V0 is removed, we can use the merging logic above - # to avoid creating an extra batch dimension (except for fields - # that are meant to be stacked anyway). - # We will also need to update each model to remove `flatten_bn`. - mm_kwargs_group = MultiModalKwargs.as_kwargs( - MultiModalKwargs.batch( - [ - MultiModalKwargsItems.from_seq([item]).get_data() - for item in items_lst - ], - pin_memory=pin_memory, - ), - device=device, - ) + if device is not None: + mm_kwargs_group = json_map_leaves( + lambda x: x.to(device=device), + mm_kwargs_group, + ) + else: + mm_kwargs_group = MultiModalKwargs.as_kwargs( + MultiModalKwargs.batch( + [ + MultiModalKwargsItems.from_seq([item]).get_data() + for item in items_lst + ], + pin_memory=pin_memory, + ), + device=device, + ) yield modality, len(items_lst), mm_kwargs_group -def run_dp_sharded_vision_model(image_input: torch.Tensor, - vision_model: torch.nn.Module) -> torch.Tensor: - """Run a vision model with data parallelism (DP) sharding. The function - will shard the input image tensor on the first dimension and run the vision - model - - Args: - image_input (torch.Tensor): Image input tensor. - vision_model (torch.nn.Module): Vision model. - - Returns: - torch.Tensor: Output image embeddings - """ - - num_chunks = image_input.shape[0] - mp_world_size = get_tensor_model_parallel_world_size() - num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size - num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks - pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) - image_input_padded = torch.nn.functional.pad(image_input, pad) - rank = get_tensor_model_parallel_rank() - image_input_per_rank = image_input_padded[rank * - num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, ...] - - vision_embeddings = vision_model(image_input_per_rank) - # Ensure tensor is contiguous before all_gather - vision_embeddings = vision_embeddings.contiguous() - vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, - dim=0) - vision_embeddings = vision_embeddings[:num_chunks, ...] - return vision_embeddings - - -def get_load_balance_assignment( - sizes: list[int], - num_gpus: int = 2, -) -> tuple[list[int], list[int], list[int]]: - """ - Generate load balancing assignment and metadata - for distributing data across GPUs. - The load is determined by the total image sizes, - not the number of images. 
- - Args: - sizes: The size of each image - num_gpus: Number of GPUs to balance across - - Returns: - shuffle_indices: - Indices to reorder data for balanced loading - gpu_sample_counts: - Number of samples assigned to each GPU - grouped_sizes_per_gpu: - Total size assigned to each GPU - - Example: - ``` - sizes = [1000, 100, 200, 50] - num_gpus=2 - ``` - - """ - - n_samples = len(sizes) - - # Handle edge cases - if n_samples == 0: - return [], [0] * num_gpus, [0] * num_gpus - - # Use greedy algorithm - balance by total size, not sample count - gpu_assignments = [list[int]() for _ in range(num_gpus)] - gpu_loads = [0] * num_gpus # This tracks total SIZE, not sample count - - # Sort indices by size (largest first for better load balancing) - # sizes = [1000, 100, 200, 50] - # large_to_small_indices = [0, 2, 1, 3] - large_to_small_indices = sorted(range(n_samples), - key=lambda i: sizes[i], - reverse=True) - - for idx in large_to_small_indices: - # Find GPU with minimum current load (by total size) - min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i]) - gpu_assignments[min_gpu].append(idx) - gpu_loads[min_gpu] += sizes[idx] - - # Create shuffle indices and counts - shuffle_indices = list[int]() - gpu_sample_counts = list[int]() - for gpu_id in range(num_gpus): - # GPU_0 = [1000] = [0] - # GPU_1 = [200, 100, 50] = [2, 1, 3] - # shuffle_indices = [0, 2, 1, 3] - shuffle_indices.extend(gpu_assignments[gpu_id]) - # GPU_0 = [1] - # GPU_1 = [3] - # gpu_sample_counts = [1, 3] - gpu_sample_counts.append(len(gpu_assignments[gpu_id])) - - return (shuffle_indices, gpu_sample_counts, gpu_loads) - - -def run_dp_sharded_mrope_vision_model( - vision_model: torch.nn.Module, - pixel_values: torch.Tensor, - grid_thw_list: list[list[int]], -) -> tuple[torch.Tensor, ...]: - """Run a vision model with data parallelism (DP) sharding. - The function will shard the input image tensor on the - first dimension and run the vision model. - This function is used to run the vision model with mrope. - - Args: - vision_model (torch.nn.Module): Vision model. - pixel_values (torch.Tensor): Image/Video input tensor. 
- grid_thw_list: List of grid dimensions for each image - Returns: - torch.Tensor: Output image embeddings - - Example: - ``` - vision_model.out_hidden_size = 64 - vision_model.spatial_merge_size = 2 - pixel_values.shape = (1350, channel) - grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] - tp_size=2 - ``` - - """ - tp_size = get_tensor_model_parallel_world_size() - - # GPU_0 tp_rank_local = 0 - # GPU_1 tp_rank_local = 1 - tp_rank_local = get_tensor_model_parallel_rank() - - # patches_per_image = [1000, 100, 200, 50] - patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list] - # patches_per_image = [0, 1000, 1100, 1300, 1350] - cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)] - - # Get load balancing assignment with all metadata - # image_to_tp_rank = [0, 2, 1, 3] - # gpu_sample_counts = [1, 3] - # grouped_pixel_values_len = [1000, 350] - (image_to_tp_rank, gpu_sample_counts, - grouped_pixel_values_len) = get_load_balance_assignment( - patches_per_image, tp_size) - - # cu_gpu_sample_counts = [0, 1, 4] - cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] - - # GPU_0 image_idxs_local = [0] - # GPU_1 image_idxs_local = [2, 1, 3] - image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: - cum_gpu_sample_counts[tp_rank_local + - 1]] - - # Get the pixel values for the local images based on the image_idxs_local - if len(image_idxs_local) > 0: - pixel_values_local = torch.cat([ - pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] - for i in image_idxs_local - ]) - else: - # Handle case where this rank has no images - pixel_values_local = torch.empty((0, pixel_values.shape[1]), - device=pixel_values.device, - dtype=pixel_values.dtype) - # embed_dim_reduction_factor = 2 * 2 - embed_dim_reduction_factor = (vision_model.spatial_merge_size * - vision_model.spatial_merge_size) - - # Find the max length across all ranks - # The output embedding of every DP rank has to be - # padded to this length for tensor_model_parallel_all_gather - # to work - max_len_per_rank = max( - grouped_pixel_values_len) // embed_dim_reduction_factor - local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] - - # Run the vision model on the local pixel_values_local - if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, - local_grid_thw_list) - else: - # Handle empty case - image_embeds_local = torch.empty((0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) - - # Pad the output based on max_len_per_rank - # for tensor_model_parallel_all_gather to work - current_len = image_embeds_local.shape[0] - if current_len < max_len_per_rank: - padding_size = max_len_per_rank - current_len - padding = torch.empty((padding_size, image_embeds_local.shape[1]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - image_embeds_local_padded = torch.cat([image_embeds_local, padding], - dim=0) - else: - image_embeds_local_padded = image_embeds_local - - # Do all_gather to collect embeddings from all ranks - gathered_embeds = tensor_model_parallel_all_gather( - image_embeds_local_padded, dim=0) - - # Remove padding and reconstruct per-rank embeddings - rank_embeddings = list[torch.Tensor]() - for rank in range(tp_size): - start_idx = rank * max_len_per_rank - end_idx = start_idx + (grouped_pixel_values_len[rank] // - embed_dim_reduction_factor) - rank_embeddings.append(gathered_embeds[start_idx:end_idx]) - - patches_per_output_image = 
[(patch_size // embed_dim_reduction_factor) - for patch_size in patches_per_image] - - # Reconstruct embeddings in the original order - original_order_embeddings = [None] * len(grid_thw_list) - current_idx = 0 - for rank in range(tp_size): - count = gpu_sample_counts[rank] - if count > 0: - # Get images assigned to this rank in shuffled order - # GPU_0 = image_idxs_local [0] - # GPU_1 = image_idxs_local [2, 1, 3] - rank_images = image_to_tp_rank[current_idx:current_idx + count] - - rank_embed = rank_embeddings[rank] - # Split rank embeddings back to individual images - embed_start = 0 - for img_idx in rank_images: - img_patches = patches_per_output_image[img_idx] - original_order_embeddings[img_idx] = rank_embed[ - embed_start:embed_start + img_patches] - embed_start += img_patches - current_idx += count - - out_embeddings = tuple(embed for embed in original_order_embeddings - if embed is not None) - assert len(out_embeddings) == len( - original_order_embeddings), "Found unassigned embeddings" - return out_embeddings - - def fetch_audio( audio_url: str, audio_io_kwargs: Optional[dict[str, Any]] = None, @@ -691,9 +459,7 @@ def fetch_audio( audio_url: URL of the audio file to fetch. audio_io_kwargs: Additional kwargs passed to handle audio IO. """ - media_io_kwargs = None if not audio_io_kwargs else { - "audio": audio_io_kwargs - } + media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_audio(audio_url) @@ -707,9 +473,7 @@ def fetch_image( image_url: URL of the image file to fetch. image_io_kwargs: Additional kwargs passed to handle image IO. """ - media_io_kwargs = None if not image_io_kwargs else { - "image": image_io_kwargs - } + media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_image(image_url) @@ -723,8 +487,6 @@ def fetch_video( video_url: URL of the video file to fetch. video_io_kwargs: Additional kwargs passed to handle video IO. 
""" - media_io_kwargs = None if not video_io_kwargs else { - "video": video_io_kwargs - } + media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_video(video_url) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index ef1380bdb614c..400d6a6be9bee 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import base64 +import math from abc import abstractmethod from functools import partial from io import BytesIO from pathlib import Path -from typing import Any +from typing import Any, Union import numpy as np import numpy.typing as npt @@ -21,8 +21,9 @@ from .image import ImageMediaIO def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: num_frames, _, _, channels = frames.shape new_height, new_width = size - resized_frames = np.empty((num_frames, new_height, new_width, channels), - dtype=frames.dtype) + resized_frames = np.empty( + (num_frames, new_height, new_width, channels), dtype=frames.dtype + ) # lazy import cv2 to avoid bothering users who only use text models import cv2 @@ -40,8 +41,7 @@ def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray: return resize_video(frames, (new_height, new_width)) -def sample_frames_from_video(frames: npt.NDArray, - num_frames: int) -> npt.NDArray: +def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray: total_frames = frames.shape[0] if num_frames == -1: return frames @@ -52,23 +52,19 @@ def sample_frames_from_video(frames: npt.NDArray, class VideoLoader: - @classmethod @abstractmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - **kwargs) -> tuple[npt.NDArray, dict[str, Any]]: + def load_bytes( + cls, data: bytes, num_frames: int = -1, **kwargs + ) -> tuple[npt.NDArray, dict[str, Any]]: raise NotImplementedError class VideoLoaderRegistry: - def __init__(self) -> None: self.name2class: dict[str, type] = {} def register(self, name: str): - def wrap(cls_to_register): self.name2class[name] = cls_to_register return cls_to_register @@ -87,7 +83,6 @@ VIDEO_LOADER_REGISTRY = VideoLoaderRegistry() @VIDEO_LOADER_REGISTRY.register("opencv") class OpenCVVideoBackend(VideoLoader): - def get_cv2_video_api(self): import cv2.videoio_registry as vr @@ -104,10 +99,12 @@ class OpenCVVideoBackend(VideoLoader): return api_pref @classmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - **kwargs) -> tuple[npt.NDArray, dict[str, Any]]: + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: import cv2 backend = cls().get_cv2_video_api() @@ -119,15 +116,15 @@ class OpenCVVideoBackend(VideoLoader): original_fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames_num / original_fps if original_fps > 0 else 0 + # resample video to target num_frames full_read = num_frames == -1 or total_frames_num < num_frames if full_read: num_frames = total_frames_num frame_idx = list(range(0, num_frames)) else: - uniform_sampled_frames = np.linspace(0, - total_frames_num - 1, - num_frames, - dtype=int) + uniform_sampled_frames = np.linspace( + 0, total_frames_num - 1, num_frames, dtype=int + ) frame_idx = uniform_sampled_frames.tolist() width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) @@ -145,22 +142,112 @@ class OpenCVVideoBackend(VideoLoader): frames[i] 
= cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) i += 1 - assert i == num_frames, (f"Expected reading {num_frames} frames, " - f"but only loaded {i} frames from video.") + assert i == num_frames, ( + f"Expected reading {num_frames} frames, " + f"but only loaded {i} frames from video." + ) + + # Use transformers.video_utils.VideoMetadata format + # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata + # can cause incorrect timestamp calculation without num_frames=-1. + metadata = { + "total_num_frames": num_frames, + "fps": num_frames / duration, + "duration": duration, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames_num, + } + + return frames, metadata + + +@VIDEO_LOADER_REGISTRY.register("opencv_dynamic") +class OpenCVDynamicVideoBackend(OpenCVVideoBackend): + @classmethod + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + fps: int = 2, + max_duration: int = 300, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: + import cv2 + + backend = cls().get_cv2_video_api() + cap = cv2.VideoCapture(BytesIO(data), backend, []) + if not cap.isOpened(): + raise ValueError("Could not open video stream") + + total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + original_fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames_num / original_fps if original_fps > 0 else 0 + + # resample video to target num_frames + max_frame_idx = total_frames_num - 1 + duration = duration or round(max_frame_idx / original_fps) + 1 + + # Refer to: + # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140 + frame_indices: Union[range, list[int]] + if duration <= max_duration: + n = int(math.floor(duration * fps)) + frame_indices = sorted( + { + min(max_frame_idx, int(math.ceil(i * original_fps / fps))) + for i in range(n) + } + ) + else: + num_samples = int(max_duration * fps) + if num_samples >= total_frames_num: + frame_indices = range(total_frames_num) + else: + target_seconds = np.linspace(0, duration, num_samples, endpoint=True) + frame_indices = sorted( + { + min(max_frame_idx, int(math.ceil(t * original_fps))) + for t in target_seconds + } + ) + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames = np.empty((len(frame_indices), height, width, 3), dtype=np.uint8) + + i = 0 + for idx in range(total_frames_num): + ok = cap.grab() + if not ok: + break + if idx in frame_indices: + ret, frame = cap.retrieve() + if ret: + frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + i += 1 + + assert i == len(frame_indices), ( + f"Expected reading {len(frame_indices)} frames, " + f"but only loaded {i} frames from video."
+ ) # Use transformers transformers.video_utils.VideoMetadata format metadata = { "total_num_frames": total_frames_num, "fps": original_fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv_dynamic", + "frames_indices": list(frame_indices), + "do_sample_frames": False, } return frames, metadata class VideoMediaIO(MediaIO[npt.NDArray]): - def __init__( self, image_io: ImageMediaIO, @@ -181,22 +268,22 @@ class VideoMediaIO(MediaIO[npt.NDArray]): self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: - return self.video_loader.load_bytes(data, - num_frames=self.num_frames, - **self.kwargs) + return self.video_loader.load_bytes( + data, num_frames=self.num_frames, **self.kwargs + ) - def load_base64(self, media_type: str, - data: str) -> tuple[npt.NDArray, dict[str, Any]]: + def load_base64( + self, media_type: str, data: str + ) -> tuple[npt.NDArray, dict[str, Any]]: if media_type.lower() == "video/jpeg": load_frame = partial( self.image_io.load_base64, "image/jpeg", ) - return np.stack([ - np.asarray(load_frame(frame_data)) - for frame_data in data.split(",") - ]), {} + return np.stack( + [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")] + ), {} return self.load_bytes(base64.b64decode(data)) @@ -220,8 +307,7 @@ class VideoMediaIO(MediaIO[npt.NDArray]): image_format=video_format, ) - return ",".join( - encode_frame(Image.fromarray(frame)) for frame in video) + return ",".join(encode_frame(Image.fromarray(frame)) for frame in video) msg = "Only JPEG format is supported for now." raise NotImplementedError(msg) diff --git a/vllm/outputs.py b/vllm/outputs.py index 9784a8894472f..dc183bd8dbe93 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass @@ -11,11 +10,11 @@ import torch from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceGroupBase, SequenceStatus) +from vllm.sequence import RequestMetrics +from vllm.v1.metrics.stats import RequestStateStats logger = init_logger(__name__) @@ -52,13 +51,15 @@ class CompletionOutput: return self.finish_reason is not None def __repr__(self) -> str: - return (f"CompletionOutput(index={self.index}, " - f"text={self.text!r}, " - f"token_ids={self.token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason}, " - f"stop_reason={self.stop_reason})") + return ( + f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"stop_reason={self.stop_reason})" + ) @dataclass @@ -68,14 +69,16 @@ class PoolingOutput: Args: data: The extracted hidden states. 
""" + data: torch.Tensor def __repr__(self) -> str: - return (f"PoolingOutput(data={self.data})") + return f"PoolingOutput(data={self.data})" def __eq__(self, other: object) -> bool: - return (isinstance(other, self.__class__) and bool( - (self.data == other.data).all())) + return isinstance(other, self.__class__) and bool( + (self.data == other.data).all() + ) class RequestOutput: @@ -110,7 +113,7 @@ class RequestOutput: prompt_logprobs: Optional[PromptLogprobs], outputs: list[CompletionOutput], finished: bool, - metrics: Optional[RequestMetrics] = None, + metrics: Optional[Union[RequestMetrics, RequestStateStats]] = None, lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[list[int]] = None, @@ -123,8 +126,9 @@ class RequestOutput: **kwargs: Any, ) -> None: if kwargs: - logger.warning_once("RequestOutput: Ignoring extra arguments: %s", - str(kwargs)) + logger.warning_once( + "RequestOutput: Ignoring extra arguments: %s", str(kwargs) + ) self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -151,16 +155,15 @@ class RequestOutput: if aggregate: # Merge outputs with same index completion.text += next_completion.text - if not isinstance(completion.token_ids, - MutableSequence): + if not isinstance(completion.token_ids, MutableSequence): completion.token_ids = list(completion.token_ids) completion.token_ids.extend(next_completion.token_ids) if next_completion.logprobs: assert completion.logprobs is not None - completion.logprobs.extend( - next_completion.logprobs) + completion.logprobs.extend(next_completion.logprobs) completion.cumulative_logprob = ( - next_completion.cumulative_logprob) + next_completion.cumulative_logprob + ) completion.finish_reason = next_completion.finish_reason completion.stop_reason = next_completion.stop_reason else: @@ -170,183 +173,21 @@ class RequestOutput: else: self.outputs.append(next_completion) - @classmethod - def from_seq_group( - cls, seq_group: SequenceGroup, use_cache: bool, - seq_id_to_seq_group: dict[str, SequenceGroupBase] - ) -> Optional["RequestOutput"]: - finished = seq_group.is_finished() - - if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] - assembled_seq_group = group.maybe_assemble_group(seq_group) - if finished: - group.finish_seq(seq_group) - if assembled_seq_group is None: - return None - - # clear finished seq in seq_id_to_seq_group - if len(group.to_be_finished) == 0: - for sub_request_id in list(group.seq_id_to_index.keys()): - if sub_request_id in seq_id_to_seq_group: - del seq_id_to_seq_group[sub_request_id] - - return cls.from_seq_group(assembled_seq_group, use_cache, - seq_id_to_seq_group) - - sampling_params = seq_group.sampling_params - if sampling_params is None: - raise ValueError( - "Sampling parameters are missing for a CompletionRequest.") - - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( - not finished): - return None - - # Init cache (if needed) - if use_cache and seq_group.cached_request_output is None: - seq_group.cached_request_output = RequestOutput( # type: ignore - request_id="", - prompt=None, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[], - finished=False) - - top_n_seqs = seq_group.get_seqs() - - # Create the outputs. - # NOTE: We need omit logprobs here explicitly because the sequence - # always has the logprobs of the sampled tokens even if the - # logprobs are not requested. 
- include_logprobs = sampling_params.logprobs is not None - text_buffer_length = sampling_params.output_text_buffer_length - delta = sampling_params.output_kind == RequestOutputKind.DELTA - - outputs = [] - include_prompt = True - # num_cached_tokens should be the same for all the sequences - num_cached_tokens = None - for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) - - output_token_ids = seq.get_output_token_ids_to_return(delta) - num_output_tokens = 1 if isinstance(output_token_ids, - int) else len(output_token_ids) - num_cached_tokens = seq.data.get_num_cached_tokens() - - output_logprobs = seq.output_logprobs if include_logprobs else None - - if delta: - # Slice logprobs delta if applicable - if output_logprobs: - # num_output_tokens can be 0 when n > 1 and request finishes - # before the others - if num_output_tokens > 0: - output_logprobs = output_logprobs[-num_output_tokens:] - else: - output_logprobs = None - # Don't include prompt if this is after the first output - # containing decode token ids - if include_prompt and seq.get_output_len() > num_output_tokens: - include_prompt = False - - if use_cache: - # Get cached output object - cached_outputs = seq_group.cached_request_output.outputs # type: ignore - if i >= len(cached_outputs): - cached_outputs.append( - CompletionOutput(index=i, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason=None, - stop_reason=None)) - output = cached_outputs[i] - - # Init cached output object - assert output.index == i - output.text = output_text - - if isinstance(output_token_ids, int): - output.token_ids.clear() - output.token_ids.append(output_token_ids) - else: - output.token_ids = output_token_ids - - output.cumulative_logprob = seq.get_cumulative_logprob() \ - if include_logprobs else None - output.logprobs = output_logprobs - output.finish_reason = SequenceStatus.get_finished_reason( - seq.status) - output.stop_reason = seq.stop_reason - - else: - output = CompletionOutput( - top_n_seqs.index(seq), output_text, [output_token_ids] - if isinstance(output_token_ids, int) else output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - output_logprobs, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) - - outputs.append(output) - - # Every sequence in the sequence group should have the same prompt. 
- if include_prompt: - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - else: - prompt = None - prompt_token_ids = None - encoder_prompt = None - encoder_prompt_token_ids = None - prompt_logprobs = None - finished_time = time.time() if finished else None - seq_group.set_finished_time(finished_time) - - init_kwargs = { - "request_id": seq_group.request_id, - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - "prompt_logprobs": prompt_logprobs, - "outputs": outputs, - "finished": finished, - "metrics": seq_group.metrics, - "lora_request": seq_group.lora_request, - "encoder_prompt": encoder_prompt, - "encoder_prompt_token_ids": encoder_prompt_token_ids, - "num_cached_tokens": num_cached_tokens, - "multi_modal_placeholders": seq_group.multi_modal_placeholders - } - - if use_cache: - request_output = seq_group.cached_request_output - request_output.__init__(**init_kwargs) # type: ignore - else: - request_output = cls(**init_kwargs) # type: ignore - - return request_output - def __repr__(self) -> str: - return (f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"encoder_prompt={self.encoder_prompt!r}, " - f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"outputs={self.outputs}, " - f"finished={self.finished}, " - f"metrics={self.metrics}, " - f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens}, " - f"multi_modal_placeholders={self.multi_modal_placeholders})") + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished}, " + f"metrics={self.metrics}, " + f"lora_request={self.lora_request}, " + f"num_cached_tokens={self.num_cached_tokens}, " + f"multi_modal_placeholders={self.multi_modal_placeholders})" + ) _O = TypeVar("_O", default=PoolingOutput) @@ -363,44 +204,21 @@ class PoolingRequestOutput(Generic[_O]): finished (bool): A flag indicating whether the pooling is completed. 
""" - def __init__(self, request_id: str, outputs: _O, - prompt_token_ids: list[int], finished: bool): + def __init__( + self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool + ): self.request_id = request_id self.prompt_token_ids = prompt_token_ids self.finished = finished self.outputs = outputs - @staticmethod - def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": - pooled_data = seq_group.pooled_data - assert pooled_data is not None - - data = pooled_data.to(dtype=torch.float32, device="cpu") - output = PoolingOutput(data) - prompt_token_ids = seq_group.prompt_token_ids - finished = seq_group.is_finished() - - return PoolingRequestOutput(seq_group.request_id, output, - prompt_token_ids, finished) - def __repr__(self): - return (f"{type(self).__name__}(request_id={self.request_id!r}, " - f"outputs={self.outputs!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"finished={self.finished})") - - -class RequestOutputFactory: - - @staticmethod - def create(seq_group: SequenceGroup, - seq_id_to_seq_group: dict[str, SequenceGroupBase], - use_cache: bool = False): - if seq_group.pooled_data is not None: - return PoolingRequestOutput.from_seq_group(seq_group) - else: - return RequestOutput.from_seq_group(seq_group, use_cache, - seq_id_to_seq_group) + return ( + f"{type(self).__name__}(request_id={self.request_id!r}, " + f"outputs={self.outputs!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})" + ) @dataclass @@ -409,8 +227,9 @@ class EmbeddingOutput: Args: embedding: The embedding vector, which is a list of floats. - Its length depends on the hidden dimension of the model. + Its length depends on the hidden dimension of the model. """ + embedding: list[float] @staticmethod @@ -430,7 +249,6 @@ class EmbeddingOutput: class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return EmbeddingRequestOutput( @@ -447,8 +265,9 @@ class ClassificationOutput: Args: probs: The probability vector, which is a list of floats. - Its length depends on the number of classes. + Its length depends on the number of classes. """ + probs: list[float] @staticmethod @@ -469,7 +288,6 @@ class ClassificationOutput: class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return ClassificationRequestOutput( @@ -487,6 +305,7 @@ class ScoringOutput: Args: score: The similarity score, which is a scalar value. """ + score: float @staticmethod @@ -505,7 +324,6 @@ class ScoringOutput: class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return ScoringRequestOutput( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 56edb8629e45b..d1708ad5c7517 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -9,7 +9,6 @@ from vllm import envs from vllm.plugins import load_plugins_by_group from vllm.utils import resolve_obj_by_qualname, supports_xccl -from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum logger = logging.getLogger(__name__) @@ -20,12 +19,14 @@ def vllm_version_matches_substr(substr: str) -> bool: Check to see if the vLLM version matches a substring. 
""" from importlib.metadata import PackageNotFoundError, version + try: vllm_version = version("vllm") except PackageNotFoundError as e: logger.warning( "The vLLM package was not found, so its version could not be " - "inspected. This may cause platform detection to fail.") + "inspected. This may cause platform detection to fail." + ) raise e return substr in vllm_version @@ -36,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]: # Check for Pathways TPU proxy if envs.VLLM_TPU_USING_PATHWAYS: logger.debug("Confirmed TPU platform is available via Pathways proxy.") - return "tpu_commons.platforms.tpu_jax.TpuPlatform" + return "tpu_inference.platforms.tpu_jax.TpuPlatform" # Check for libtpu installation try: @@ -46,6 +47,7 @@ def tpu_platform_plugin() -> Optional[str]: # has TPUs. import libtpu # noqa: F401 + logger.debug("Confirmed TPU platform is available.") return "vllm.platforms.tpu.TpuPlatform" except Exception as e: @@ -58,6 +60,7 @@ def cuda_platform_plugin() -> Optional[str]: logger.debug("Checking if CUDA platform is available.") try: from vllm.utils import import_pynvml + pynvml = import_pynvml() pynvml.nvmlInit() try: @@ -66,21 +69,22 @@ def cuda_platform_plugin() -> Optional[str]: # we need to check if vllm is built with cpu too. # Otherwise, vllm will always activate cuda plugin # on a GPU machine, even if in a cpu build. - is_cuda = (pynvml.nvmlDeviceGetCount() > 0 - and not vllm_version_matches_substr("cpu")) + is_cuda = ( + pynvml.nvmlDeviceGetCount() > 0 + and not vllm_version_matches_substr("cpu") + ) if pynvml.nvmlDeviceGetCount() <= 0: - logger.debug( - "CUDA platform is not available because no GPU is found.") + logger.debug("CUDA platform is not available because no GPU is found.") if vllm_version_matches_substr("cpu"): - logger.debug("CUDA platform is not available because" - " vLLM is built with CPU.") + logger.debug( + "CUDA platform is not available because vLLM is built with CPU." + ) if is_cuda: logger.debug("Confirmed CUDA platform is available.") finally: pynvml.nvmlShutdown() except Exception as e: - logger.debug("Exception happens when checking CUDA platform: %s", - str(e)) + logger.debug("Exception happens when checking CUDA platform: %s", str(e)) if "nvml" not in e.__class__.__name__.lower(): # If the error is not related to NVML, re-raise it. raise e @@ -89,8 +93,9 @@ def cuda_platform_plugin() -> Optional[str]: import os def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") + return os.path.isfile("/etc/nv_tegra_release") or os.path.exists( + "/sys/class/tegra-firmware" + ) if cuda_is_jetson(): logger.debug("Confirmed CUDA platform is available on Jetson.") @@ -106,14 +111,14 @@ def rocm_platform_plugin() -> Optional[str]: logger.debug("Checking if ROCm platform is available.") try: import amdsmi + amdsmi.amdsmi_init() try: if len(amdsmi.amdsmi_get_processor_handles()) > 0: is_rocm = True logger.debug("Confirmed ROCm platform is available.") else: - logger.debug("ROCm platform is not available because" - " no GPU is found.") + logger.debug("ROCm platform is not available because no GPU is found.") finally: amdsmi.amdsmi_shut_down() except Exception as e: @@ -129,18 +134,19 @@ def xpu_platform_plugin() -> Optional[str]: # installed IPEX if the machine has XPUs. 
import intel_extension_for_pytorch # noqa: F401 import torch + if supports_xccl(): dist_backend = "xccl" else: dist_backend = "ccl" import oneccl_bindings_for_pytorch # noqa: F401 - if hasattr(torch, 'xpu') and torch.xpu.is_available(): + if hasattr(torch, "xpu") and torch.xpu.is_available(): is_xpu = True from vllm.platforms.xpu import XPUPlatform + XPUPlatform.dist_backend = dist_backend - logger.debug("Confirmed %s backend is available.", - XPUPlatform.dist_backend) + logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend) logger.debug("Confirmed XPU platform is available.") except Exception as e: logger.debug("XPU platform is not available because: %s", str(e)) @@ -154,14 +160,17 @@ def cpu_platform_plugin() -> Optional[str]: try: is_cpu = vllm_version_matches_substr("cpu") if is_cpu: - logger.debug("Confirmed CPU platform is available because" - " vLLM is built with CPU.") + logger.debug( + "Confirmed CPU platform is available because vLLM is built with CPU." + ) if not is_cpu: import sys + is_cpu = sys.platform.startswith("darwin") if is_cpu: - logger.debug("Confirmed CPU platform is available" - " because the machine is MacOS.") + logger.debug( + "Confirmed CPU platform is available because the machine is MacOS." + ) except Exception as e: logger.debug("CPU platform is not available because: %s", str(e)) @@ -169,47 +178,21 @@ def cpu_platform_plugin() -> Optional[str]: return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None -def neuron_platform_plugin() -> Optional[str]: - tnx_installed = False - nxd_installed = False - logger.debug("Checking if Neuron platform is available.") - try: - import transformers_neuronx # noqa: F401 - tnx_installed = True - logger.debug("Confirmed Neuron platform is available because" - " transformers_neuronx is found.") - except ImportError: - pass - - try: - import neuronx_distributed_inference # noqa: F401 - nxd_installed = True - logger.debug("Confirmed Neuron platform is available because" - " neuronx_distributed_inference is found.") - except ImportError: - pass - - is_neuron = tnx_installed or nxd_installed - return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None - - builtin_platform_plugins = { - 'tpu': tpu_platform_plugin, - 'cuda': cuda_platform_plugin, - 'rocm': rocm_platform_plugin, - 'xpu': xpu_platform_plugin, - 'cpu': cpu_platform_plugin, - 'neuron': neuron_platform_plugin, + "tpu": tpu_platform_plugin, + "cuda": cuda_platform_plugin, + "rocm": rocm_platform_plugin, + "xpu": xpu_platform_plugin, + "cpu": cpu_platform_plugin, } def resolve_current_platform_cls_qualname() -> str: - platform_plugins = load_plugins_by_group('vllm.platform_plugins') + platform_plugins = load_plugins_by_group("vllm.platform_plugins") activated_plugins = [] - for name, func in chain(builtin_platform_plugins.items(), - platform_plugins.items()): + for name, func in chain(builtin_platform_plugins.items(), platform_plugins.items()): try: assert callable(func) platform_cls_qualname = func() @@ -219,43 +202,41 @@ def resolve_current_platform_cls_qualname() -> str: pass activated_builtin_plugins = list( - set(activated_plugins) & set(builtin_platform_plugins.keys())) - activated_oot_plugins = list( - set(activated_plugins) & set(platform_plugins.keys())) + set(activated_plugins) & set(builtin_platform_plugins.keys()) + ) + activated_oot_plugins = list(set(activated_plugins) & set(platform_plugins.keys())) if len(activated_oot_plugins) >= 2: raise RuntimeError( "Only one platform plugin can be activated, but got: " - 
f"{activated_oot_plugins}") + f"{activated_oot_plugins}" + ) elif len(activated_oot_plugins) == 1: platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]() - logger.info("Platform plugin %s is activated", - activated_oot_plugins[0]) + logger.info("Platform plugin %s is activated", activated_oot_plugins[0]) elif len(activated_builtin_plugins) >= 2: raise RuntimeError( "Only one platform plugin can be activated, but got: " - f"{activated_builtin_plugins}") + f"{activated_builtin_plugins}" + ) elif len(activated_builtin_plugins) == 1: - platform_cls_qualname = builtin_platform_plugins[ - activated_builtin_plugins[0]]() - logger.info("Automatically detected platform %s.", - activated_builtin_plugins[0]) + platform_cls_qualname = builtin_platform_plugins[activated_builtin_plugins[0]]() + logger.info("Automatically detected platform %s.", activated_builtin_plugins[0]) else: platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform" - logger.info( - "No platform detected, vLLM is running on UnspecifiedPlatform") + logger.info("No platform detected, vLLM is running on UnspecifiedPlatform") return platform_cls_qualname _current_platform = None -_init_trace: str = '' +_init_trace: str = "" if TYPE_CHECKING: current_platform: Platform def __getattr__(name: str): - if name == 'current_platform': + if name == "current_platform": # lazy init current_platform. # 1. out-of-tree platform plugins need `from vllm.platforms import # Platform` so that they can inherit `Platform` class. Therefore, @@ -270,19 +251,24 @@ def __getattr__(name: str): global _current_platform if _current_platform is None: platform_cls_qualname = resolve_current_platform_cls_qualname() - _current_platform = resolve_obj_by_qualname( - platform_cls_qualname)() + _current_platform = resolve_obj_by_qualname(platform_cls_qualname)() global _init_trace _init_trace = "".join(traceback.format_stack()) return _current_platform elif name in globals(): return globals()[name] else: - raise AttributeError( - f"No attribute named '{name}' exists in {__name__}.") + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") -__all__ = [ - 'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum', - "_init_trace" -] +def __setattr__(name: str, value): + if name == "current_platform": + global _current_platform + _current_platform = value + elif name in globals(): + globals()[name] = value + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + +__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index c748595a71534..49c953fd36ee0 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -4,6 +4,7 @@ import json import os import platform +import re import subprocess import sys from dataclasses import dataclass @@ -15,20 +16,22 @@ import torch from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend +from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import VllmConfig else: + _Backend = None VllmConfig = None def get_max_threads(pid=0): - if hasattr(os, 'sched_getaffinity'): + if hasattr(os, "sched_getaffinity"): return len(os.sched_getaffinity(pid)) - elif platform.system() == 'Darwin': + elif platform.system() == "Darwin": return os.cpu_count() 
else: raise NotImplementedError("Unsupported OS") @@ -58,7 +61,8 @@ class LogicalCPUInfo: return LogicalCPUInfo( id=LogicalCPUInfo._int(id), physical_core=LogicalCPUInfo._int(physical_core), - numa_node=LogicalCPUInfo._int(numa_node)) + numa_node=LogicalCPUInfo._int(numa_node), + ) else: return obj_dict @@ -69,18 +73,48 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" dist_backend: str = "gloo" + device_control_env_var = "CPU_VISIBLE_MEMORY_NODES" @property def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] - elif sys.platform.startswith( - "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: - # TODO: change this condition to check if the platform support bf16 - # instead of checking the OS. For instance M2 shall supports bf16 - # already. But we need to modify `cpu_extension.cmake` to activate - # the feature in the build. + elif self.get_cpu_architecture() == CpuArchEnum.ARM and sys.platform.startswith( + "darwin" + ): + if ( + subprocess.check_output( + ["sysctl -n hw.optional.arm.FEAT_BF16"], shell=True + ).strip() + == b"1" + ): + return [torch.bfloat16, torch.float16, torch.float32] return [torch.float16, torch.float32] + elif self.get_cpu_architecture() == CpuArchEnum.RISCV: + # Workaround for Issue #25655: RISC-V scheduler bug with float16 + # + # Background: + # - RISC-V currently uses scalar code path + # - There is a latent bug in the vLLM scheduler that provides + # invalid + # physical_block_idx values under certain conditions + # - This bug causes segmentation faults when using float16 + # dtype on RISC-V + # - Testing shows that forcing float32 successfully bypasses + # this issue + # + # Technical details: + # - The bug manifests as out-of-bounds physical_block_idx in + # block_tables + # - Only occurs on RISC-V hardware + # tested on Sophgo SG2044 + # - Does not reproduce on x86 or other architectures + # - Root cause is in Python-level scheduling logic, + # not C++ kernels + # + # This is a temporary workaround until the scheduler bug is fixed. + # See: https://github.com/vllm-project/vllm/issues/25655 + return [torch.float32] # x86/aarch64 CPU has supported both bf16 and fp16 natively. 
return [torch.bfloat16, torch.float16, torch.float32] @@ -89,14 +123,26 @@ class CpuPlatform(Platform): return "cpu" @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: + from vllm.attention.backends.registry import _Backend + if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") + if use_sparse: + raise NotImplementedError("Sparse Attention is not supported on CPU.") logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") @@ -112,7 +158,8 @@ class CpuPlatform(Platform): kv_cache_space = 4 * GiB_bytes # type: ignore logger.warning_once( "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " - "for CPU backend is not set, using 4 by default.") + "for CPU backend is not set, using 4 by default." + ) else: kv_cache_space *= GiB_bytes @@ -125,10 +172,6 @@ class CpuPlatform(Platform): """ torch.cpu.set_device(device) - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def inference_mode(cls): return torch.no_grad() @@ -150,48 +193,66 @@ class CpuPlatform(Platform): if not ipex_available and cache_config.block_size != 16: raise RuntimeError( f"--block-size={cache_config.block_size} requires" - " intel_extension_for_pytorch") + " intel_extension_for_pytorch" + ) scheduler_config = vllm_config.scheduler_config - if ((scheduler_config.chunked_prefill_enabled - or cache_config.enable_prefix_caching) - and cache_config.cache_dtype != "auto"): - raise RuntimeError("Chunked-prefill and prefix-cache on the CPU " - "backend is not compatible with FP8 KV cache.") + if ( + scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching + ) and cache_config.cache_dtype != "auto": + raise RuntimeError( + "Chunked-prefill and prefix-cache on the CPU " + "backend is not compatible with FP8 KV cache." + ) if cache_config.cache_dtype == "fp8_e4m3": cache_config.cache_dtype = "fp8_e5m2" logger.warning( - "CPU backend doesn't support fp8_e4m3 KV cache type, " - "cast to fp8_e5m2.") + "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2." + ) - if (cache_config.cache_dtype != "auto" and model_config is not None - and model_config.dtype == torch.half): - logger.warning("FP8 KV cache on the CPU backend only does not" - " support fp16 for now, cast to bf16.") + if ( + cache_config.cache_dtype != "auto" + and model_config is not None + and model_config.dtype == torch.half + ): + logger.warning( + "FP8 KV cache on the CPU backend only does not" + " support fp16 for now, cast to bf16." 
+ ) model_config.dtype = torch.bfloat16 - cache_config.cpu_kvcache_space_bytes = \ - CpuPlatform.get_device_total_memory() + cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory() parallel_config = vllm_config.parallel_config - if (parallel_config.world_size > 1 - and parallel_config.distributed_executor_backend is not None - and parallel_config.distributed_executor_backend != "mp"): - logger.warning(("%s is not supported on CPU, fallback to mp " - "distributed executor backend."), - parallel_config.distributed_executor_backend) + if ( + parallel_config.world_size > 1 + and parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "mp" + ): + logger.warning( + ( + "%s is not supported on CPU, fallback to mp " + "distributed executor backend." + ), + parallel_config.distributed_executor_backend, + ) parallel_config.distributed_executor_backend = "mp" if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" + # Disable DBO + if parallel_config.enable_dbo: + logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.") + parallel_config.enable_dbo = False # Note: workaround for v1 gpu_model_runner from vllm.config import CompilationLevel + vllm_config.compilation_config.cudagraph_capture_sizes = [] compilation_config = vllm_config.compilation_config if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: - # Note: vLLM V1 is using PIECEWISE level compilation, which will # take time to compile kernels just-in-time with the inductor # backend. For CPU CI tests, most of them are executed fast and @@ -206,16 +267,14 @@ class CpuPlatform(Platform): compilation_config.level = CompilationLevel.DYNAMO_ONCE compilation_config.backend = backend - compilation_config.inductor_compile_config.update({ - "dce": - True, - "size_asserts": - False, - "nan_asserts": - False, - "epilogue_fusion": - True, - }) + compilation_config.inductor_compile_config.update( + { + "dce": True, + "size_asserts": False, + "nan_asserts": False, + "epilogue_fusion": True, + } + ) if compilation_config.use_inductor: compilation_config.custom_ops = ["none"] @@ -245,51 +304,57 @@ class CpuPlatform(Platform): if "libiomp5.so" in ld_prealod_str: # The time(milliseconds) that a thread should wait after # completing the execution of a parallel region, before sleeping. - os.environ['KMP_BLOCKTIME'] = "1" + os.environ["KMP_BLOCKTIME"] = "1" # Prevents the CPU to run into low performance state - os.environ['KMP_TPAUSE'] = "0" + os.environ["KMP_TPAUSE"] = "0" # Provides fine granularity parallelism - os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" + os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist" + os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist" + os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist" # To hint IPEX uses shared memory based AllReduce os.environ["LOCAL_WORLD_SIZE"] = str( - vllm_config.parallel_config.tensor_parallel_size) + vllm_config.parallel_config.tensor_parallel_size + ) if model_config is not None and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." 
+ ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) @classmethod - def get_allowed_cpu_core_node_list( - cls) -> tuple[list[int], list[LogicalCPUInfo]]: + def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" # Init LogicalCPUInfo from lscpu - lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE", - shell=True, - text=True) + lscpu_output = subprocess.check_output( + "lscpu -J -e=CPU,CORE,NODE", shell=True, text=True + ) + lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output) logical_cpu_list: list[LogicalCPUInfo] = json.loads( - lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus'] + lscpu_output, object_hook=LogicalCPUInfo.json_decoder + )["cpus"] # Filter CPUs with invalid attributes logical_cpu_list = [ - x for x in logical_cpu_list + x + for x in logical_cpu_list if -1 not in (x.id, x.physical_core, x.numa_node) ] # Filter allowed CPUs - allowed_cpu_id_list = os.sched_getaffinity(0) - logical_cpu_list = [ - x for x in logical_cpu_list if x.id in allowed_cpu_id_list - ] + if hasattr(os, "sched_getaffinity"): + allowed_cpu_id_list = os.sched_getaffinity(0) + else: + raise NotImplementedError("Unsupported OS") + logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list] # Get allowed NUMA nodes allowed_numa_nodes = set() @@ -297,6 +362,13 @@ class CpuPlatform(Platform): allowed_numa_nodes.add(x.numa_node) # type: ignore allowed_numa_nodes_list = sorted(allowed_numa_nodes) + env_key = CpuPlatform.device_control_env_var + if env_key in os.environ and os.environ[env_key] != "": + visible_nodes = [int(s) for s in os.environ[env_key].split(",")] + allowed_numa_nodes_list = [ + x for x in visible_nodes if x in allowed_cpu_id_list + ] + return allowed_numa_nodes_list, logical_cpu_list @classmethod @@ -320,18 +392,9 @@ class CpuPlatform(Platform): return True @classmethod - def supports_v1(cls, model_config) -> bool: - """Returns whether the current platform can support v1 for the supplied - model configuration. - """ + def opaque_attention_op(cls) -> bool: return True @classmethod - def default_v1(cls, model_config) -> bool: - """Returns whether the current platform can use v1 by default for the - supplied model configuration. 
- """ - arch = cls.get_cpu_architecture() - return (cls.supports_v1(model_config) - and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, - CpuArchEnum.ARM, CpuArchEnum.S390X)) + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 134ba36e5e735..e0f832b431147 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -20,10 +20,13 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless, import_pynvml -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig +else: + _Backend = None logger = init_logger(__name__) @@ -38,7 +41,6 @@ torch.backends.cuda.enable_cudnn_sdp(False) def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: pynvml.nvmlInit() @@ -64,8 +66,7 @@ class CudaPlatformBase(Platform): if self.has_device_capability(80): # Ampere and Hopper or later NVIDIA GPUs. return [torch.bfloat16, torch.float16, torch.float32] - elif (not self.has_device_capability(80) - ) and self.has_device_capability(60): + if self.has_device_capability(60): # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported return [torch.float16, torch.float32] # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, @@ -84,9 +85,7 @@ class CudaPlatformBase(Platform): _ = torch.zeros(1, device=device) @classmethod - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: raise NotImplementedError @classmethod @@ -97,16 +96,6 @@ class CudaPlatformBase(Platform): def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. 
Since, enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def is_fully_connected(cls, device_ids: list[int]) -> bool: raise NotImplementedError @@ -121,17 +110,7 @@ class CudaPlatformBase(Platform): model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Speculative decoding is not supported on vLLM V0.") - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: @@ -139,13 +118,23 @@ class CudaPlatformBase(Platform): # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # for models with both attention and mamba, + # and doesn't need to be reinitialized here + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the # required block_size. use_flashmla = False use_cutlass_mla = False + use_flashinfer_mla = False if envs.VLLM_ATTENTION_BACKEND is None: # Default case @@ -161,136 +150,221 @@ class CudaPlatformBase(Platform): use_flashmla = True else: # Forced case - use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") - use_cutlass_mla = ( - envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" - from vllm.attention.ops.flashmla import is_flashmla_supported - if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + + if ( + use_flashmla + and is_flashmla_dense_supported()[0] + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") + + if use_cutlass_mla and cache_config.block_size % 128 != 0: + cache_config.block_size = 128 + logger.info( + "Forcing kv cache block size to 128 for CUTLASS_MLA backend." + ) + + if ( + use_flashinfer_mla + and cache_config.block_size != 32 + and cache_config.block_size % 64 != 0 + ): cache_config.block_size = 64 logger.info( - "Forcing kv cache block size to 64 for FlashMLA backend.") - - if use_cutlass_mla and cache_config.block_size != 128: - cache_config.block_size = 128 - logger.info("Forcing kv cache block size to 128 for " - "CUTLASS_MLA backend.") + "Forcing kv cache block size to 64 for FlashInferMLA backend." + ) + # TODO(Chen): remove this hacky code + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse backend." 
+ ) # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config - if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" - and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + if ( + envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and parallel_config.data_parallel_size > 1 + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + # TODO: Piecewise Cuda graph might be enabled + # if torch compile cache key issue fixed + # See https://github.com/vllm-project/vllm/pull/25093 logger.info( - "Data Parallel: disabling cudagraphs since DP " - "with DeepEP high-throughput kernels are not CUDA Graph " - "compatible. The DeepEP low-latency kernels are CUDA Graph " - "compatible. Set the all_to_all backend to deepep_low_latency " - "to use those kernels instead.") + "WideEP: Disabling CUDA Graphs since DeepEP high-throughput " + "kernels are optimized for prefill and are incompatible with " + "CUDA Graphs. " + "In order to use CUDA Graphs for decode-optimized workloads, " + "set VLLM_ALL2ALL_BACKEND to another option, such as " + "deepep_low_latency, pplx, or allgather_reducescatter." + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE - if model_config is not None: - model_config.enforce_eager = True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if cls.has_device_capability(80) and support_fa: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend + + # For Blackwell GPUs, force TORCH_SDPA for now. + # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 + if cls.has_device_capability(100): + return _Backend.TORCH_SDPA + + if dtype not in (torch.float16, torch.bfloat16): + return _Backend.XFORMERS + + if cls.has_device_capability(80): + FLASH_ATTN_V1 = ( + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + ) + from vllm.attention.selector import is_attn_backend_supported + + is_default_fa_supported = is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ) + if is_default_fa_supported: return _Backend.FLASH_ATTN - logger.warning_once( - "Current `vllm-flash-attn` has a bug inside vision " - "module, so we use xformers backend instead. 
You can " - "run `pip install flash-attn` to use flash-attention " - "backend.") - # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + else: + # Fallback to XFORMERS + return _Backend.XFORMERS + else: + # Fallback for Volta/Turing GPUs or FA not supported + return _Backend.XFORMERS @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) -> str: + from vllm.attention.backends.registry import _Backend + if use_mla: - # TODO(lucas): refactor to be more concise - # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA or ( - cls.is_device_capability(100) and selected_backend is None - and block_size == 128): - if use_v1: - logger.info_once("Using Cutlass MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "cutlass_mla.CutlassMLABackend") - else: - logger.warning( - "Cutlass MLA backend is only supported on V1 engine") - if selected_backend == _Backend.TRITON_MLA or block_size != 64: - if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - else: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - if not is_flashmla_supported()[0]: - logger.warning( - "FlashMLA backend is not supported due to %s", - is_flashmla_supported()[1]) - elif block_size != 64: + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them." + ) + + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + from vllm.attention.utils.fa_utils import flash_attn_supports_mla + + if use_sparse: + logger.info_once("Using Sparse MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.flashmla_sparse." 
+ "FlashMLASparseBackend" + ) + + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None + and cls.is_device_capability(100) + and block_size % 128 == 0 + ) + use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( + selected_backend is None + and cls.is_device_capability(100) + and (block_size == 32 or block_size % 64 == 0) + ) + use_flashmla = selected_backend == _Backend.FLASHMLA or ( + selected_backend is None and is_flashmla_dense_supported()[0] + ) + use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla() + ) + use_triton = selected_backend == _Backend.TRITON_MLA or ( + selected_backend is None + ) + + if use_cutlassmla: + logger.info_once("Using Cutlass MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" + if use_flashinfermla: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout("HND") + logger.info_once("Using FlashInfer MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) + if use_flashmla: + if block_size % 64 != 0: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", - block_size) + block_size, + ) else: - if use_v1: - logger.info_once( - "Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") - else: - logger.info("Using FlashMLA backend.") - return ("vllm.attention.backends." - "flashmla.FlashMLABackend") + logger.info_once("Using FlashMLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" + if use_flashattn: + logger.info_once("Using FlashAttention MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + ) + if use_triton: + logger.info_once("Using Triton MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + FLEX_ATTENTION_V1 = ( + "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + ) + TRITON_ATTN = ( + "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + ) + FLASH_ATTN_V1 = ( + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + ) TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 + use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( + "fp8" + ) + if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + from vllm.v1.attention.backends.utils import set_kv_cache_layout + set_kv_cache_layout("HND") return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: logger.info_once("Using FlexAttention backend on V1 engine.") return FLEX_ATTENTION_V1 - elif selected_backend == 
_Backend.TRITON_ATTN_VLLM_V1: + elif selected_backend == _Backend.TRITON_ATTN: logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN_VLLM_V1 + return TRITON_ATTN elif selected_backend == _Backend.FLASH_ATTN: logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 elif selected_backend == _Backend.TREE_ATTN: logger.info_once("Using Tree Attention backend on V1 engine.") return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS_VLLM_V1: + elif selected_backend == _Backend.XFORMERS: logger.info_once("Using XFormers backend on V1 engine.") return XFORMERS_V1 @@ -300,13 +374,14 @@ class CudaPlatformBase(Platform): # Prefer FlashInfer for Blackwell GPUs if installed if cls.is_device_capability(100): if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + FLASHINFER_V1, head_size, dtype + ): + from vllm.v1.attention.backends.utils import set_kv_cache_layout logger.info_once( "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs.") + "V1 engine by default for Blackwell (SM 10.0) GPUs." + ) set_kv_cache_layout("HND") return FLASHINFER_V1 @@ -315,18 +390,18 @@ class CudaPlatformBase(Platform): logger.warning_once( "FlashInfer failed to import for V1 engine on " "Blackwell (SM 10.0) GPUs; it is recommended to " - "install FlashInfer for better performance.") + "install FlashInfer for better performance." + ) # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): - if has_sink and not cls.is_device_capability(90): + if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): logger.info_once("Using Triton backend on V1 engine.") - return TRITON_ATTN_VLLM_V1 - if is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, - allow_import_error=False): - logger.info_once("Using Flash Attention backend on " - "V1 engine.") + return TRITON_ATTN + elif is_default_backend_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ): + logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 # FlexAttention is the default for older GPUs @@ -344,83 +419,14 @@ class CudaPlatformBase(Platform): logger.info_once( "Using FlexAttention backend for %s on V1 engine.", - ", ".join(f"{k}={v}" - for k, v in use_flex_attention_reason.items()), + ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), ) return FLEX_ATTENTION_V1 - # Backends for V0 engine - if selected_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: - logger.info("Using DualChunkFlashAttention backend.") - return ("vllm.attention.backends.dual_chunk_flash_attn." - "DualChunkFlashAttentionBackend") - elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: - logger.info("Using DifferentialFlashAttention backend.") - return ("vllm.attention.backends.differential_flash_attn." - "DifferentialFlashAttentionBackend") - elif selected_backend == _Backend.FLASH_ATTN: - pass - elif selected_backend: - raise ValueError( - f"Invalid attention backend for {cls.device_name}, " - f"with use_v1: {use_v1} use_mla: {use_mla}") - - target_backend = _Backend.FLASH_ATTN - if not cls.has_device_capability(80): - # Volta and Turing NVIDIA GPUs. 
- logger.info( - "Cannot use FlashAttention-2 backend for Volta and Turing " - "GPUs.") - target_backend = _Backend.XFORMERS - elif dtype not in (torch.float16, torch.bfloat16): - logger.info( - "Cannot use FlashAttention-2 backend for dtype other than " - "torch.float16 or torch.bfloat16.") - target_backend = _Backend.XFORMERS - elif block_size % 16 != 0: - logger.info( - "Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - target_backend = _Backend.XFORMERS - - # FlashAttn is valid for the model, checking if the package is - # installed. - if target_backend == _Backend.FLASH_ATTN: - try: - import vllm.vllm_flash_attn # noqa: F401 - from vllm.attention.backends.flash_attn import ( # noqa: F401 - FlashAttentionBackend, flash_attn_supports_fp8) - - supported_sizes = \ - FlashAttentionBackend.get_supported_head_sizes() - if head_size not in supported_sizes: - logger.info( - "Cannot use FlashAttention-2 backend for head size %d.", - head_size) - target_backend = _Backend.XFORMERS - fp8_kv_cache = (kv_cache_dtype is not None - and kv_cache_dtype.startswith("fp8")) - if (fp8_kv_cache and not flash_attn_supports_fp8()): - logger.info( - "Cannot use FlashAttention backend for FP8 KV cache.") - target_backend = _Backend.XFORMERS - except ImportError: - logger.info( - "Cannot use FlashAttention-2 backend because the " - "vllm.vllm_flash_attn package is not found. " - "Make sure that vllm_flash_attn was built and installed " - "(on by default).") - target_backend = _Backend.XFORMERS - - if target_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - - logger.info("Using Flash Attention backend.") - return "vllm.attention.backends.flash_attn.FlashAttentionBackend" + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." 
+ ) @classmethod def get_punica_wrapper(cls) -> str: @@ -428,18 +434,20 @@ class CudaPlatformBase(Platform): @classmethod def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod def supports_fp8(cls) -> bool: return cls.has_device_capability(89) @classmethod - def supports_v1(cls, model_config: "ModelConfig") -> bool: + def use_custom_allreduce(cls) -> bool: return True @classmethod - def use_custom_allreduce(cls) -> bool: + def opaque_attention_op(cls) -> bool: return True @classmethod @@ -466,8 +474,9 @@ class CudaPlatformBase(Platform): backend_options = ProcessGroupNCCL.Options() backend_options._timeout = timeout - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) backend_type = ProcessGroup.BackendType.NCCL device = torch.device("cuda") pg._set_default_backend(backend_type) @@ -481,8 +490,9 @@ class CudaPlatformBase(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") attention_backend = envs.VLLM_ATTENTION_BACKEND @@ -496,41 +506,95 @@ class CudaPlatformBase(Platform): else: attention_backend = "FLASHMLA" - # Only FlashMLA supports fp8 - if attention_backend == "FLASHMLA": + # Only FlashMLA and CUTLASS_MLA support fp8 + if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]: supported = True else: - supported = (not fp8_attention) + supported = not fp8_attention else: # Default to FlashAttention if attention_backend is None: - attention_backend = "FLASH_ATTN_VLLM_V1" + attention_backend = "FLASH_ATTN" # All Blackwell backends support fp8 if cls.is_device_capability(100): supported = True - elif attention_backend == "FLASH_ATTN_VLLM_V1": + elif attention_backend == "FLASH_ATTN": if fp8_attention: - from vllm.attention.utils.fa_utils import ( - flash_attn_supports_fp8) + from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 + supported = flash_attn_supports_fp8() else: supported = True + elif attention_backend == "FLASHINFER": + supported = True + elif attention_backend == "TRITON_ATTN": + supported = cls.supports_fp8() return supported + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half." 
+ ) + + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache on GPU.""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from GPU to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + @classmethod + def support_static_graph_mode(cls) -> bool: + return True + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. # the major benefit of using NVML is that it will not initialize CUDA class NvmlCudaPlatform(CudaPlatformBase): - @classmethod @cache @with_nvml_context - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: try: physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) @@ -577,9 +641,7 @@ class NvmlCudaPlatform(CudaPlatformBase): """ query if the set of gpus are fully connected by nvlink (1 hop) """ - handles = [ - pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids - ] + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] for i, handle in enumerate(handles): for j, peer_handle in enumerate(handles): if i < j: @@ -594,7 +656,8 @@ class NvmlCudaPlatform(CudaPlatformBase): except pynvml.NVMLError: logger.exception( "NVLink detection failed. This is normal if" - " your machine has no NVLink equipped.") + " your machine has no NVLink equipped." + ) return False return True @@ -608,11 +671,11 @@ class NvmlCudaPlatform(CudaPlatformBase): def log_warnings(cls): device_ids: int = pynvml.nvmlDeviceGetCount() if device_ids > 1: - device_names = [ - cls._get_physical_device_name(i) for i in range(device_ids) - ] - if (len(set(device_names)) > 1 - and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): + device_names = [cls._get_physical_device_name(i) for i in range(device_ids)] + if ( + len(set(device_names)) > 1 + and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID" + ): logger.warning( "Detected different devices in the system: %s. Please" " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " @@ -622,7 +685,6 @@ class NvmlCudaPlatform(CudaPlatformBase): class NonNvmlCudaPlatform(CudaPlatformBase): - @classmethod @cache def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @@ -642,7 +704,8 @@ class NonNvmlCudaPlatform(CudaPlatformBase): def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: logger.exception( "NVLink detection not possible, as context support was" - " not found. Assuming no NVLink available.") + " not found. Assuming no NVLink available." 
+ ) return False diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 00bc555288e8e..e372ebf0cb3f7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import enum import os import platform @@ -17,12 +18,14 @@ from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import FlexibleArgumentParser else: + _Backend = None ModelConfig = None VllmConfig = None LoRARequest = None @@ -38,41 +41,12 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - FLASH_ATTN_VLLM_V1 = enum.auto() - TRITON_ATTN_VLLM_V1 = enum.auto() - XFORMERS = enum.auto() - ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() # Supported by V1 - ROCM_AITER_MLA_VLLM_V1 = enum.auto() - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_VLLM_V1 = enum.auto() - TRITON_MLA = enum.auto() # Supported by V1 - TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 - CUTLASS_MLA = enum.auto() - PALLAS = enum.auto() - PALLAS_VLLM_V1 = enum.auto() - IPEX = enum.auto() - DUAL_CHUNK_FLASH_ATTN = enum.auto() - DIFFERENTIAL_FLASH_ATTN = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - XFORMERS_VLLM_V1 = enum.auto() - - class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() TPU = enum.auto() XPU = enum.auto() CPU = enum.auto() - NEURON = enum.auto() OOT = enum.auto() UNSPECIFIED = enum.auto() @@ -82,6 +56,7 @@ class CpuArchEnum(enum.Enum): ARM = enum.auto() POWERPC = enum.auto() S390X = enum.auto() + RISCV = enum.auto() OTHER = enum.auto() UNKNOWN = enum.auto() @@ -163,12 +138,12 @@ class Platform: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - def is_neuron(self) -> bool: - return self._enum == PlatformEnum.NEURON - def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT + def is_unspecified(self) -> bool: + return self._enum == PlatformEnum.UNSPECIFIED + def get_max_output_tokens(self, prompt_len: int) -> int: return sys.maxsize @@ -184,8 +159,10 @@ class Platform: # Treat empty device control env var as unset. This is a valid # configuration in Ray setups where the engine is launched in # a CPU-only placement group located on a GPU node. 
- if cls.device_control_env_var in os.environ and os.environ[ - cls.device_control_env_var] != "": + if ( + cls.device_control_env_var in os.environ + and os.environ[cls.device_control_env_var] != "" + ): device_ids = os.environ[cls.device_control_env_var].split(",") physical_device_id = device_ids[device_id] return int(physical_device_id) @@ -193,14 +170,34 @@ class Platform: return device_id @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: + def import_kernels(cls) -> None: + """Import any platform-specific C kernels.""" + try: + import vllm._C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._C: %r", e) + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + + @classmethod + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend + return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: """Get the attention backend class of a device.""" return "" @@ -275,13 +272,6 @@ class Platform: """Get the total memory of a device in bytes.""" raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - """ - Check if the current platform supports async output. - """ - raise NotImplementedError - @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. @@ -313,9 +303,9 @@ class Platform: raise NotImplementedError @classmethod - def pre_register_and_update(cls, - parser: Optional[FlexibleArgumentParser] = None - ) -> None: + def pre_register_and_update( + cls, parser: Optional[FlexibleArgumentParser] = None + ) -> None: """ Do some pre-registration or update action for the current platform. @@ -358,11 +348,10 @@ class Platform: """ Verify whether the quantization is supported by the current platform. """ - if cls.supported_quantization and \ - quant not in cls.supported_quantization: + if cls.supported_quantization and quant not in cls.supported_quantization: raise ValueError( - f"{quant} quantization is currently not supported in " - f"{cls.device_name}.") + f"{quant} quantization is currently not supported in {cls.device_name}." + ) @classmethod def get_cpu_architecture(cls) -> CpuArchEnum: @@ -380,6 +369,8 @@ class Platform: return CpuArchEnum.POWERPC elif machine == "s390x": return CpuArchEnum.S390X + elif machine.startswith("riscv"): + return CpuArchEnum.RISCV return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN @@ -389,15 +380,17 @@ class Platform: if in_wsl(): # Pinning memory in WSL is not supported. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications - logger.warning("Using 'pin_memory=False' as WSL is detected. " - "This may slow down the performance.") + logger.warning( + "Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance." 
+ ) return False return True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: """ Return the memory usage in bytes. """ @@ -484,23 +477,10 @@ class Platform: from vllm.config import get_current_vllm_config parallel_config = get_current_vllm_config().parallel_config - return (envs.VLLM_USE_V1 - or parallel_config.distributed_executor_backend - == "external_launcher") - - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - """Returns whether the current platform can support v1 for the supplied - model configuration. - """ - return False - - @classmethod - def default_v1(cls, model_config: ModelConfig) -> bool: - """ - Returns whether the current platform supports v1 by default. - """ - return cls.supports_v1(model_config) + return ( + envs.VLLM_USE_V1 + or parallel_config.distributed_executor_backend == "external_launcher" + ) @classmethod def use_custom_allreduce(cls) -> bool: @@ -509,6 +489,14 @@ class Platform: """ return False + @classmethod + def opaque_attention_op(cls) -> bool: + """ + Returns True if we register attention as one giant opaque custom op + on the current platform + """ + return False + @classmethod def validate_request( cls, @@ -523,13 +511,16 @@ class Platform: if device is not None and hasattr(device, key): return getattr(device, key) else: - logger.warning("Current platform %s does not have '%s'" \ - " attribute.", self.device_type, key) + logger.warning( + "Current platform %s does not have '%s' attribute.", + self.device_type, + key, + ) return None def get_global_graph_pool(self) -> Any: """ - Return the global graph pool for the this platform. + Return the global graph pool for this platform. """ cls = self.__class__ if cls._global_graph_pool is None: @@ -565,13 +556,73 @@ class Platform: raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ return False + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + """ + Check if the dtype is supported by the current platform. + """ + raise NotImplementedError + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + """ + Returns if the hybrid kv cache is supported by the current platform. + """ + return False + + @classmethod + def support_static_graph_mode(cls) -> bool: + """ + Returns if the graph mode is supported by the current platform. + """ + return False + + @classmethod + def use_sync_weight_loader(cls) -> bool: + """ + Returns if the current platform needs to sync weight loader. + """ + return False + + @classmethod + def make_synced_weight_loader(cls, original_weight_loader): + """ + Wrap the original weight loader to make it synced. 
+ """ + if not cls.use_sync_weight_loader(): + return original_weight_loader + + def _synced_weight_loader(param, *args, **kwargs): + out = original_weight_loader(param, *args, **kwargs) + if param.device != torch.device("cpu"): + torch._sync(param) + return out + + return _synced_weight_loader + + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. + """ + return {} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return None + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py deleted file mode 100644 index cb8ac8db669fe..0000000000000 --- a/vllm/platforms/neuron.py +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum -import os -from functools import lru_cache -from typing import TYPE_CHECKING, Optional - -from vllm import envs -from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS - -from .interface import Platform, PlatformEnum - -if TYPE_CHECKING: - from vllm.config import VllmConfig -else: - VllmConfig = None - -logger = init_logger(__name__) - - -class NeuronFramework(enum.Enum): - TRANSFORMERS_NEURONX = "transformers-neuronx" - NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" - - -class NeuronPlatform(Platform): - _enum = PlatformEnum.NEURON - device_name: str = "neuron" - device_type: str = "neuron" - ray_device_key: str = "neuron_cores" - supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"] - dist_backend: str = "gloo" - device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - return "neuron" - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = \ - "vllm.worker.neuron_worker.NeuronWorker" - - if parallel_config.world_size > 1: - parallel_config.distributed_executor_backend = "uni" - - if vllm_config.cache_config and vllm_config.model_config: - # neuron needs block_size = max_model_len - vllm_config.cache_config.block_size = \ - vllm_config.model_config.max_model_len # type: ignore - - if vllm_config.model_config and vllm_config.model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) - - @classmethod - def is_pin_memory_available(cls) -> bool: - logger.warning("Pin memory is not supported on Neuron.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - if envs.VLLM_USE_V1: - return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa - else: - return Platform.get_device_communicator_cls() - - @classmethod - def use_all_gather(cls) -> bool: - return True - - @classmethod - @lru_cache - 
def is_neuronx_distributed_inference(cls) -> bool: - try: - import neuronx_distributed_inference - except ImportError: - neuronx_distributed_inference = None - return neuronx_distributed_inference is not None - - @classmethod - @lru_cache - def is_transformers_neuronx(cls) -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None - - def get_neuron_framework_to_use(self): - """Return the specified framework if corresponding installations are - available. - - If no framework is specified, use neuronx-distributed-inference by - default. - If that's unavailable, check and switch to transformers-neuronx. - """ - if not self.is_neuron(): - raise AssertionError( - f"Neuron Framework unavailable for platform: {self}") - - tnx_installed = self.is_transformers_neuronx() - nxd_installed = self.is_neuronx_distributed_inference() - - specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if specified_framework == tnx_framework and tnx_installed: - return self.TRANSFORMERS_NEURONX - - if ((specified_framework == nxd_framework and nxd_installed) - or (specified_framework is None and nxd_installed)): - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - if specified_framework is None and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - return None - - def use_neuronx_distributed(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This - is used to select the Neuron model framework and framework-specific - configuration to apply during model compilation. - """ - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - return self.get_neuron_framework_to_use() == nxd_framework - - def use_transformers_neuronx(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used - to select the Neuron model framework and framework-specific - configuration to apply during model compilation. - """ - return self.get_neuron_framework_to_use( - ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 323ec591c50a3..95d3fa74e325d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -14,17 +14,25 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig +else: + _Backend = None logger = init_logger(__name__) try: - from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info, - amdsmi_get_processor_handles, amdsmi_init, - amdsmi_shut_down, amdsmi_topo_get_link_type) + from amdsmi import ( + AmdSmiException, + amdsmi_get_gpu_asic_info, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + amdsmi_topo_get_link_type, + ) except ImportError as e: logger.warning("Failed to import from amdsmi with %r", e) @@ -44,24 +52,24 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = [] # Models partially supported by ROCm. # Architecture -> Reason. 
-_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_SWA_REASON = ( + "Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`" +) _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") + "Qwen2ForCausalLM": _ROCM_SWA_REASON, + "MistralForCausalLM": _ROCM_SWA_REASON, + "MixtralForCausalLM": _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": ( + "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma" + ), + "Phi3VForCausalLM": ( + "ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`" + ), } _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { "0x74a0": "AMD_Instinct_MI300A", @@ -88,7 +96,6 @@ if "HIP_VISIBLE_DEVICES" in os.environ: def with_amdsmi_context(fn): - @wraps(fn) def wrapper(*args, **kwargs): amdsmi_init() @@ -119,17 +126,23 @@ def on_gfx9() -> bool: @cache -def use_rocm_custom_paged_attention( - qtype: torch.dtype, - head_size: int, - block_size: int, - gqa_ratio: int, - max_seq_len: int, - sliding_window: int, - kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None) -> bool: +def on_gfx950() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx950"]) + +@cache +def use_rocm_custom_paged_attention( + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, + sliding_window: int, + kv_cache_dtype: str, + alibi_slopes: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, +) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -137,26 +150,36 @@ def use_rocm_custom_paged_attention( # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. 
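Because the eligibility condition that follows is a long conjunction, a condensed restatement of its gfx9 branch is given here as a sketch; the VLLM_ROCM_* environment toggles are omitted and the helper name is invented, so this is not the shipped check.

import torch

def gfx9_custom_paged_attn_ok(qtype, head_size, block_size, gqa_ratio,
                              max_seq_len, sliding_window, sinks=None,
                              use_v1=True):
    # Mirrors the gfx9 branch below, minus the VLLM_ROCM_* env flags.
    return ((not use_v1 or sliding_window in (0, (-1, -1)))
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024
            and sinks is None)

# A typical GQA decode shape qualifies:
assert gfx9_custom_paged_attn_ok(torch.bfloat16, head_size=128, block_size=16,
                                 gqa_ratio=4, max_seq_len=8192, sliding_window=0)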
if ON_GFX9: - return ((not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_seq_len <= 128 * 1024 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER) and sinks is None) + return ( + (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024 + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER) + and sinks is None + ) else: - return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and head_size == 128 and block_size == 16 - and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 128 * 1024 and alibi_slopes is None - and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) + return ( + ON_GFX11_GFX12 + and ( + not envs.VLLM_USE_V1 + or sliding_window == 0 + or sliding_window == (-1, -1) + ) + and (qtype == torch.half or qtype == torch.bfloat16) + and head_size == 128 + and block_size == 16 + and (gqa_ratio >= 3 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024 + and alibi_slopes is None + and kv_cache_dtype == "auto" + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN + and sinks is None + ) class RocmPlatform(Platform): @@ -170,89 +193,123 @@ class RocmPlatform(Platform): device_control_env_var: str = "CUDA_VISIBLE_DEVICES" supported_quantization: list[str] = [ - "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" + "awq", + "gptq", + "fp8", + "compressed-tensors", + "fbgemm_fp8", + "gguf", + "quark", + "ptpc_fp8", + "mxfp4", + "petit_nvfp4", + "torchao", ] @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if support_fa: - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA - and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA - if on_gfx9(): - return _Backend.FLASH_ATTN + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend + + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + return _Backend.ROCM_AITER_FA + if on_gfx9(): + return _Backend.FLASH_ATTN return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) -> str: + from vllm.attention.backends.registry import _Backend + + if use_sparse: + raise NotImplementedError("Sparse Attention is not supported on ROCm.") if use_mla: - from vllm.attention.backends.rocm_aiter_mla import ( - is_aiter_mla_enabled) + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them." 
+ ) + + from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( + is_aiter_mla_enabled, + ) if selected_backend is None: - selected_backend = (_Backend.ROCM_AITER_MLA if - is_aiter_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA) + selected_backend = ( + _Backend.ROCM_AITER_MLA + if is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA + ) if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - if use_v1: - logger.info_once( - "Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA \ - or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: - if block_size == 1: - if use_v1: - logger.info("Using AITER MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - logger.info("Using AITER MLA backend") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}." - "(currently only supports block size 1)") - else: + logger.info_once("Using Triton MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" raise ValueError( f" The selected backend, {selected_backend.name}," - f"is not MLA type while requested for MLA backend.") - - if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: - selected_backend = _Backend.ROCM_FLASH + f"does not support block size {block_size}." + ) + if selected_backend == _Backend.ROCM_AITER_MLA: + if block_size == 1: + logger.info("Using AITER MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + ) + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}." + "(currently only supports block size 1)" + ) + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend." + ) if envs.VLLM_USE_V1: - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ - and on_gfx9(): - logger.info("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "rocm_aiter_fa.AiterFlashAttentionBackend") - else: - logger.info("Using Triton Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "triton_attn.TritonAttentionBackend") - if selected_backend == _Backend.ROCM_FLASH: - if not cls.has_device_capability(90): - # not Instinct series GPUs. 
- logger.info("flash_attn is not supported on NAVI GPUs.") - else: - logger.info("%s is not supported in AMD GPUs.", selected_backend) - logger.info("Using ROCmFlashAttention backend.") - return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + if selected_backend == _Backend.FLEX_ATTENTION: + logger.info("Using FlexAttention backend on V1 engine.") + return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + ) or selected_backend == _Backend.ROCM_AITER_FA: + logger.info("Using Aiter Flash Attention backend on V1 engine.") + return ( + "vllm.v1.attention.backends." + "rocm_aiter_fa.AiterFlashAttentionBackend" + ) + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + logger.info("Using Aiter Unified Attention backend on V1 engine.") + return ( + "vllm.v1.attention.backends." + "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" + ) + if ( + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + or selected_backend == _Backend.ROCM_ATTN + ): + # rocm specific backend, with aiter and/or + # triton prefix-prefill + logger.info("Using Rocm Attention backend on V1 engine.") + return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + # default case, using triton unified attention + logger.info("Using Triton Attention backend on V1 engine.") + return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." + ) @classmethod def set_device(cls, device: torch.device) -> None: @@ -263,9 +320,7 @@ class RocmPlatform(Platform): @classmethod @lru_cache(maxsize=8) - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) @@ -275,21 +330,17 @@ class RocmPlatform(Platform): """ Query if the set of gpus are fully connected by xgmi (1 hop) """ - handles = [ - amdsmi_get_processor_handles()[i] for i in physical_device_ids - ] + handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] for i, handle in enumerate(handles): for j, peer_handle in enumerate(handles): if i < j: try: - link_type = amdsmi_topo_get_link_type( - handle, peer_handle) + link_type = amdsmi_topo_get_link_type(handle, peer_handle) # type is 2 for XGMI if link_type["hops"] != 1 or link_type["type"] != 2: return False except AmdSmiException as error: - logger.error("AMD 1 hop XGMI detection failed.", - exc_info=error) + logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) return False return True @@ -310,47 +361,48 @@ class RocmPlatform(Platform): device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. 
Since, enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.config.compilation import CUDAGraphMode + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + parallel_config = vllm_config.parallel_config + is_eager_execution = compilation_config == CUDAGraphMode.NONE + + use_v1 = envs.VLLM_USE_V1 + use_aiter_rms_norm = ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM + ) + if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Speculative decoding is not supported on vLLM V0.") - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" + # Aiter rms norm perform best when CUDA Graph capture is enabled. + if ( + use_v1 + and use_aiter_rms_norm + and not is_eager_execution + and "-rms_norm" not in compilation_config.custom_ops + ): + compilation_config.custom_ops.append("+rms_norm") @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError(f"Model architecture '{model_arch}' is not " - "supported by ROCm for now.") + raise ValueError( + f"Model architecture '{model_arch}' is not supported by ROCm for now." + ) if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] logger.warning( - "Model architecture '%s' is partially " - "supported by ROCm: %s", model_arch, msg) + "Model architecture '%s' is partially supported by ROCm: %s", + model_arch, + msg, + ) @classmethod def verify_quantization(cls, quant: str) -> None: @@ -358,7 +410,8 @@ class RocmPlatform(Platform): if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: logger.warning( "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") + " is not set, enabling VLLM_USE_TRITON_AWQ." 
+ ) envs.VLLM_USE_TRITON_AWQ = True @classmethod @@ -366,16 +419,17 @@ class RocmPlatform(Platform): return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info( - device)[0] + return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0] @classmethod def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod def supports_mx(cls) -> bool: @@ -385,12 +439,12 @@ class RocmPlatform(Platform): @classmethod def supports_fp8(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName - return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12']) + return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"]) @classmethod def is_fp8_fnuz(cls) -> bool: # only device 0 is checked, this assumes MI300 platforms are homogeneous - return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName + return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName @classmethod def fp8_dtype(cls) -> torch.dtype: @@ -399,26 +453,24 @@ class RocmPlatform(Platform): else: return torch.float8_e4m3fn - @classmethod - def supports_v1(cls, model_config: "ModelConfig") -> bool: - # V1 support on AMD gpus is experimental - return True - @classmethod def use_custom_allreduce(cls) -> bool: # We only enable custom allreduce for MI300 series gcn_arch = torch.cuda.get_device_properties(0).gcnArchName - supported_archs = ['gfx94', 'gfx95'] + supported_archs = ["gfx94", "gfx95"] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: - return torch.cuda.get_device_properties( - device_id).multi_processor_count + return torch.cuda.get_device_properties(device_id).multi_processor_count @classmethod def is_navi(cls) -> bool: - return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName + return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName @classmethod def get_static_graph_wrapper_cls(cls) -> str: @@ -444,8 +496,9 @@ class RocmPlatform(Platform): backend_options = ProcessGroupNCCL.Options() backend_options._timeout = timeout - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) backend_type = ProcessGroup.BackendType.NCCL device = torch.device("cuda") pg._set_default_backend(backend_type) @@ -459,6 +512,36 @@ class RocmPlatform(Platform): return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: + return True + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not cls.has_device_capability(80): + capability = cls.get_device_capability() + gpu_name = cls.get_device_name() + + if capability is None: + compute_str = "does not have a 
compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs " + "with compute capability of at least 8.0. " + f"Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half." + ) + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + @classmethod + def support_static_graph_mode(cls) -> bool: return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d7468d74b021f..1c323ba8200a2 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import TYPE_CHECKING, Optional, Union, cast import torch @@ -11,20 +12,23 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import Platform, PlatformEnum, _Backend +from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.config import BlockSize, ModelConfig, VllmConfig + from vllm.attention.backends.registry import _Backend + from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams else: BlockSize = None ModelConfig = None VllmConfig = None PoolingParams = None + _Backend = None logger = init_logger(__name__) -USE_TPU_COMMONS = False +USE_TPU_INFERENCE = False class TpuPlatform(Platform): @@ -37,21 +41,34 @@ class TpuPlatform(Platform): device_control_env_var: str = "TPU_VISIBLE_CHIPS" simple_compile_backend: str = "openxla" - supported_quantization: list[str] = [ - "fp8", "tpu_int8", "compressed-tensors" - ] + supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"] - additional_env_vars: list[str] = [ - "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" - ] + additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink) -> str: - if (selected_backend != _Backend.PALLAS - and selected_backend != _Backend.PALLAS_VLLM_V1): + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink, + use_sparse, + ) -> str: + from vllm.attention.backends.registry import _Backend + + if use_sparse: + raise NotImplementedError("Sparse Attention is not supported on TPU.") + if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: @@ -75,10 +92,6 @@ class TpuPlatform(Platform): def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" @@ -111,34 +124,43 @@ class TpuPlatform(Platform): # TPU only supports DYNAMO_ONCE compilation 
level if compilation_config.level != CompilationLevel.DYNAMO_ONCE: - logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and " - "disabling cudagraph.") + logger.info( + "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph." + ) compilation_config.level = CompilationLevel.DYNAMO_ONCE - if compilation_config.cudagraph_mode is None or \ - compilation_config.cudagraph_mode.max_cudagraph_mode() \ - != CUDAGraphMode.NONE: - logger.info("[TPU] CUDA graph is not supported on TPU, " - "disabling cudagraphs.") + if ( + compilation_config.cudagraph_mode is None + or compilation_config.cudagraph_mode.max_cudagraph_mode() + != CUDAGraphMode.NONE + ): + logger.info( + "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs." + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE if compilation_config.backend == "": compilation_config.backend = "openxla" - assert vllm_config.speculative_config is None, \ + assert vllm_config.speculative_config is None, ( "TPU does not support speculative decoding" + ) model_config = vllm_config.model_config - if model_config is not None and model_config.dtype in (torch.float16, - torch.float32): + if model_config is not None and model_config.dtype in ( + torch.float16, + torch.float32, + ): logger.warning( "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", model_config.dtype) + "Using bfloat16 instead.", + model_config.dtype, + ) model_config.dtype = torch.bfloat16 from vllm.v1.attention.backends.pallas import PallasAttentionBackend - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] + + cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config @@ -146,24 +168,31 @@ class TpuPlatform(Platform): parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") + "Speculative decoding is not yet supported for TPU backend" + ) - if scheduler_config.is_multimodal_model and not \ - scheduler_config.disable_chunked_mm_input: - logger.warning("TPU does not support running Multimodal models"\ - " without setting `--disable_chunked_mm_input`. " \ - "Forcing --disable_chunked_mm_input.") + if ( + scheduler_config.is_multimodal_model + and not scheduler_config.disable_chunked_mm_input + ): + logger.warning( + "TPU does not support running Multimodal models" + " without setting `--disable_chunked_mm_input`. " + "Forcing --disable_chunked_mm_input." + ) scheduler_config.disable_chunked_mm_input = True if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." 
+ ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) @classmethod def is_pin_memory_available(cls): @@ -178,11 +207,6 @@ class TpuPlatform(Platform): def use_all_gather(cls) -> bool: return True - @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: - # V1 support on TPU is experimental - return True - @classmethod def validate_request( cls, @@ -191,20 +215,53 @@ class TpuPlatform(Platform): processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" - if (isinstance(params, SamplingParams) - and params.sampling_type == SamplingType.RANDOM_SEED): + if ( + isinstance(params, SamplingParams) + and params.sampling_type == SamplingType.RANDOM_SEED + ): raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: + return True + + @classmethod + @torch.compile(backend="openxla") + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device) + + @classmethod + @torch.compile(backend="openxla") + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """tpu blocks to cpu blocks""" + torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() + + @classmethod + def use_sync_weight_loader(cls) -> bool: return True try: - from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform - TpuPlatform = TpuCommonsPlatform # type: ignore - USE_TPU_COMMONS = True + from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform + + TpuPlatform = TpuInferencePlatform # type: ignore + USE_TPU_INFERENCE = True except ImportError: - logger.info("tpu_commons not found, using vLLM's TpuPlatform") + logger.info("tpu_inference not found, using vLLM's TpuPlatform") pass diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index af24437f649f4..b75b52938839b 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import os from typing import TYPE_CHECKING, Optional @@ -10,13 +11,15 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig else: ModelConfig = None VllmConfig = None + _Backend = None logger = init_logger(__name__) @@ -33,18 +36,72 @@ class XPUPlatform(Platform): device_control_env_var: str = "ZE_AFFINITY_MASK" @classmethod - def get_attn_backend_cls(cls, 
selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: - if selected_backend is not None and selected_backend != _Backend.IPEX: - logger.info("Cannot use %s backend on XPU.", selected_backend) + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse, + ) -> str: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout("NHD") + logger.info( + "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " + "only NHD layout is supported by XPU attention kernels." + ) + + from vllm.attention.backends.registry import _Backend + + if use_sparse: + raise NotImplementedError("Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + if selected_backend == _Backend.TRITON_ATTN: + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return FLASH_ATTN + elif selected_backend: + raise ValueError( + f"Invalid attention backend for {cls.device_name}, " + f"with use_v1: {use_v1} use_mla: {use_mla}" + ) + logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + @classmethod + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: + """ + Check if the kv_cache_dtype is supported. + XPU only support fp8 kv cache with triton backend. + """ + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN" + ): + return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] + + return False + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -74,10 +131,6 @@ class XPUPlatform(Platform): device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True - @classmethod def inference_mode(cls): return torch.no_grad() @@ -90,29 +143,19 @@ class XPUPlatform(Platform): if cache_config and cache_config.block_size is None: cache_config.block_size = 64 - # FIXME: Temporarily forcing eager mode - # remove after t.compile support stabilizes. - if (envs.VLLM_USE_V1 and model_config is not None - and not vllm_config.model_config.enforce_eager): - from vllm.config import CompilationLevel - vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 - - # Instances created using VllmConfig() typically have model_config as - # None by default. The modification involves adding a check to prevent - # potential null exceptions check and update model config. 
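As a usage note for the XPU `is_kv_cache_dtype_supported` gate above: an fp8 KV cache is only accepted when the Triton attention backend is selected through VLLM_ATTENTION_BACKEND. A minimal configuration sketch follows; the model name is a placeholder and this only shows the shape of the setup, not a recommendation.

import os

# Select the Triton attention backend before vLLM is imported, so the XPU
# fp8 KV-cache check above accepts the requested dtype.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")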
- if model_config is not None and model_config.dtype == torch.bfloat16 \ - and not cls.device_support_bf16(): - model_config.dtype = torch.float16 - # lazy import to avoid circular import - from vllm.config import CUDAGraphMode + from vllm.config import CompilationLevel, CUDAGraphMode + compilation_config = vllm_config.compilation_config - if compilation_config.cudagraph_mode is None or \ - compilation_config.cudagraph_mode.max_cudagraph_mode() \ - != CUDAGraphMode.NONE: - logger.info("[XPU] CUDA graph is not supported on XPU, " - "disabling cudagraphs.") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if compilation_config.compile_sizes is None: + compilation_config.compile_sizes = [] + + assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, ( + "CUDA graph mode should be NONE on XPU" + ) + + if vllm_config.lora_config is not None: + compilation_config.level = CompilationLevel.NO_COMPILATION # check and update parallel config parallel_config = vllm_config.parallel_config @@ -130,70 +173,104 @@ class XPUPlatform(Platform): if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" logger.warning( - "Please use spawn as start method if you want to use mp.") - elif (parallel_config.distributed_executor_backend != "ray" - and parallel_config.distributed_executor_backend != "uni" - and parallel_config.distributed_executor_backend - != "external_launcher"): + "Please use spawn as start method if you want to use mp." + ) + elif ( + parallel_config.distributed_executor_backend != "ray" + and parallel_config.distributed_executor_backend != "uni" + and parallel_config.distributed_executor_backend != "external_launcher" + ): logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", - parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend, + ) parallel_config.distributed_executor_backend = "ray" if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." + ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + + @classmethod + def support_static_graph_mode(cls) -> bool: + return False @classmethod def is_pin_memory_available(cls): return True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.xpu.reset_peak_memory_stats(device) return torch.xpu.max_memory_allocated(device) @classmethod - def device_support_bf16(cls) -> bool: - device_name = cls.get_device_name().lower() - if cls.is_client_gpu_a770(): - logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," - " fallback to float16") - return False - else: - logger.info( - "Device name %s supports bfloat16. 
Please file an issue " - "if you encounter any accuracy problems with bfloat16.", - device_name) - return True + def fp8_dtype(cls) -> torch.dtype: + return torch.float8_e5m2 @classmethod def is_data_center_gpu(cls) -> bool: device_name = cls.get_device_name().lower() return device_name.count("data center gpu") > 0 - @classmethod - def is_client_gpu_a770(cls) -> bool: - device_name = cls.get_device_name().lower() - return device_name.count("a770") > 0 - @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: + def device_count(cls) -> int: + return torch.xpu.device_count() + + @classmethod + def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + if torch_dtype == torch.bfloat16: # noqa: SIM102 + device_name = cls.get_device_name().lower() + # client gpu a770 + if device_name.count("a770") > 0: + raise ValueError( + "Intel Arc A770 have bfloat16 accuracy known issue. " + "You can use float16 instead by explicitly setting the " + "`dtype` flag in CLI, for example: --dtype=half." + ) + + @classmethod + def opaque_attention_op(cls) -> bool: return True @classmethod - def device_count(cls) -> int: - return torch.xpu.device_count() + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache on XPU.""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from XPU to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 1a1760df82c03..094bda3f9369e 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -8,18 +8,14 @@ import vllm.envs as envs logger = logging.getLogger(__name__) -DEFAULT_PLUGINS_GROUP = 'vllm.general_plugins' +DEFAULT_PLUGINS_GROUP = "vllm.general_plugins" # make sure one process only loads plugins once plugins_loaded = False def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: - import sys - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points allowed_plugins = envs.VLLM_PLUGINS @@ -29,7 +25,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: return {} # Check if the only discovered plugin is the default one - is_default_group = (group == DEFAULT_PLUGINS_GROUP) + is_default_group = group == DEFAULT_PLUGINS_GROUP # Use INFO for non-default groups and DEBUG for the default group log_level = logger.debug if is_default_group else logger.info @@ -38,8 +34,10 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: log_level("- %s -> %s", plugin.name, plugin.value) if allowed_plugins is None: - log_level("All plugins in this group will be loaded. " - "Set `VLLM_PLUGINS` to control which plugins to load.") + log_level( + "All plugins in this group will be loaded. " + "Set `VLLM_PLUGINS` to control which plugins to load." 
+ ) plugins = dict[str, Callable[[], Any]]() for plugin in discovered_plugins: diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py new file mode 100644 index 0000000000000..7a914442c4ab8 --- /dev/null +++ b/vllm/plugins/io_processors/__init__.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging + +from vllm.config import VllmConfig +from vllm.plugins import load_plugins_by_group +from vllm.plugins.io_processors.interface import IOProcessor +from vllm.utils import resolve_obj_by_qualname + +logger = logging.getLogger(__name__) + + +def get_io_processor( + vllm_config: VllmConfig, plugin_from_init: str | None = None +) -> IOProcessor | None: + # Input.Output processors are loaded as plugins under the + # 'vllm.io_processor_plugins' group. Similar to platform + # plugins, these plugins register a function that returns the class + # name for the processor to install. + + if plugin_from_init: + model_plugin = plugin_from_init + else: + # A plugin can be specified via the model config + # Retrieve the model specific plugin if available + # This is using a custom field in the hf_config for the model + hf_config = vllm_config.model_config.hf_config.to_dict() + config_plugin = hf_config.get("io_processor_plugin") + model_plugin = config_plugin + + if model_plugin is None: + logger.debug("No IOProcessor plugins requested by the model") + return None + + logger.debug("IOProcessor plugin to be loaded %s", model_plugin) + + # Load all installed plugin in the group + multimodal_data_processor_plugins = load_plugins_by_group( + "vllm.io_processor_plugins" + ) + + loadable_plugins = {} + for name, func in multimodal_data_processor_plugins.items(): + try: + assert callable(func) + processor_cls_qualname = func() + if processor_cls_qualname is not None: + loadable_plugins[name] = processor_cls_qualname + except Exception: + logger.warning("Failed to load plugin %s.", name, exc_info=True) + + num_available_plugins = len(loadable_plugins.keys()) + if num_available_plugins == 0: + raise ValueError( + f"No IOProcessor plugins installed but one is required ({model_plugin})." + ) + + if model_plugin not in loadable_plugins: + raise ValueError( + f"The model requires the '{model_plugin}' IO Processor plugin " + "but it is not installed. 
" + f"Available plugins: {list(loadable_plugins.keys())}" + ) + + activated_plugin_cls = loadable_plugins[model_plugin] + + return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config) diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py new file mode 100644 index 0000000000000..84af40d01c439 --- /dev/null +++ b/vllm/plugins/io_processors/interface.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Generic, Optional, TypeVar, Union + +from vllm.config import VllmConfig +from vllm.entrypoints.openai.protocol import IOProcessorResponse +from vllm.inputs.data import PromptType +from vllm.outputs import PoolingRequestOutput + +IOProcessorInput = TypeVar("IOProcessorInput") +IOProcessorOutput = TypeVar("IOProcessorOutput") + + +class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + + @abstractmethod + def pre_process( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + raise NotImplementedError + + async def pre_process_async( + self, + prompt: IOProcessorInput, + request_id: Optional[str] = None, + **kwargs, + ) -> Union[PromptType, Sequence[PromptType]]: + return self.pre_process(prompt, request_id, **kwargs) + + @abstractmethod + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + raise NotImplementedError + + async def post_process_async( + self, + model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: + # We cannot guarantee outputs are returned in the same order they were + # fed to vLLM. 
+ # Let's sort them by id before post_processing + sorted_output = sorted( + [(i, item) async for i, item in model_output], key=lambda output: output[0] + ) + collected_output = [output[1] for output in sorted_output] + return self.post_process(collected_output, request_id, **kwargs) + + @abstractmethod + def parse_request(self, request: Any) -> IOProcessorInput: + raise NotImplementedError + + @abstractmethod + def output_to_response( + self, plugin_output: IOProcessorOutput + ) -> IOProcessorResponse: + raise NotImplementedError diff --git a/vllm/plugins/lora_resolvers/filesystem_resolver.py b/vllm/plugins/lora_resolvers/filesystem_resolver.py index b999d07a6eb74..c3255af457026 100644 --- a/vllm/plugins/lora_resolvers/filesystem_resolver.py +++ b/vllm/plugins/lora_resolvers/filesystem_resolver.py @@ -10,25 +10,29 @@ from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry class FilesystemResolver(LoRAResolver): - def __init__(self, lora_cache_dir: str): self.lora_cache_dir = lora_cache_dir - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: lora_path = os.path.join(self.lora_cache_dir, lora_name) if os.path.exists(lora_path): - adapter_config_path = os.path.join(self.lora_cache_dir, lora_name, - "adapter_config.json") + adapter_config_path = os.path.join( + self.lora_cache_dir, lora_name, "adapter_config.json" + ) if os.path.exists(adapter_config_path): with open(adapter_config_path) as file: adapter_config = json.load(file) - if adapter_config["peft_type"] == "LORA" and adapter_config[ - "base_model_name_or_path"] == base_model_name: - lora_request = LoRARequest(lora_name=lora_name, - lora_int_id=abs( - hash(lora_name)), - lora_path=lora_path) + if ( + adapter_config["peft_type"] == "LORA" + and adapter_config["base_model_name_or_path"] == base_model_name + ): + lora_request = LoRARequest( + lora_name=lora_name, + lora_int_id=abs(hash(lora_name)), + lora_path=lora_path, + ) return lora_request return None @@ -38,13 +42,12 @@ def register_filesystem_resolver(): lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR if lora_cache_dir: - if not os.path.exists(lora_cache_dir) or not os.path.isdir( - lora_cache_dir): + if not os.path.exists(lora_cache_dir) or not os.path.isdir(lora_cache_dir): raise ValueError( "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory \ - for Filesystem Resolver plugin to function") + for Filesystem Resolver plugin to function" + ) fs_resolver = FilesystemResolver(lora_cache_dir) - LoRAResolverRegistry.register_resolver("Filesystem Resolver", - fs_resolver) + LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver) return diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 29f037b4372cd..f7a53503e5841 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Annotated, Any, Optional import msgspec @@ -14,26 +14,39 @@ if TYPE_CHECKING: class PoolingParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, +): # type: ignore[call-arg] """API parameters for pooling models. 
Attributes: + truncate_prompt_tokens: Controls prompt truncation. + Set to -1 to use the model's default truncation size. + Set to k to keep only the last k tokens (left truncation). + Set to None to disable truncation. normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings - if model support matryoshka representation. + if the model supports matryoshka representation. activation: Whether to apply activation function to - the classification outputs. + the classification outputs. softmax: Whether to apply softmax to the reward outputs. """ + # --8<-- [start:common-pooling-params] + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None + # --8<-- [end:common-pooling-params] + ## for embeddings models + # --8<-- [start:embedding-pooling-params] dimensions: Optional[int] = None normalize: Optional[bool] = None + # --8<-- [end:embedding-pooling-params] - ## for classification models + ## for classification, scoring and rerank + # --8<-- [start:classification-pooling-params] activation: Optional[bool] = None + # --8<-- [end:classification-pooling-params] ## for reward models softmax: Optional[bool] = None @@ -54,8 +67,12 @@ class PoolingParams( @property def all_parameters(self) -> list[str]: return [ - "dimensions", "normalize", "activation", "softmax", "step_tag_id", - "returned_token_ids" + "dimensions", + "normalize", + "activation", + "softmax", + "step_tag_id", + "returned_token_ids", ] @property @@ -71,10 +88,9 @@ class PoolingParams( """Returns a deep copy of the PoolingParams instance.""" return deepcopy(self) - def verify(self, - task: PoolingTask, - model_config: Optional["ModelConfig"] = None) -> None: - + def verify( + self, task: PoolingTask, model_config: Optional["ModelConfig"] = None + ) -> None: if self.task is None: self.task = task elif self.task != task: @@ -89,10 +105,9 @@ class PoolingParams( self._set_default_parameters(model_config) self._verify_valid_parameters() - def _merge_default_parameters(self, - model_config: Optional["ModelConfig"] = None - ) -> None: - + def _merge_default_parameters( + self, model_config: Optional["ModelConfig"] = None + ) -> None: if model_config is None: return @@ -119,8 +134,8 @@ class PoolingParams( if not model_config.is_matryoshka: raise ValueError( f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.' + f"support matryoshka representation, " + f"changing output dimensions will lead to poor results." ) mds = model_config.matryoshka_dimensions @@ -128,9 +143,10 @@ class PoolingParams( if self.dimensions not in mds: raise ValueError( f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') + f"only supports {str(mds)} matryoshka dimensions, " + f"using other output dimensions will " + f"lead to poor results."
+ ) elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") @@ -159,20 +175,24 @@ class PoolingParams( raise ValueError( f"Task {self.task} only supports {valid_parameters} " f"parameters, does not support " - f"{invalid_parameters} parameters") + f"{invalid_parameters} parameters" + ) def __repr__(self) -> str: - return (f"PoolingParams(" - f"task={self.task}, " - f"normalize={self.normalize}, " - f"dimensions={self.dimensions}, " - f"activation={self.activation}, " - f"softmax={self.softmax}, " - f"step_tag_id={self.step_tag_id}, " - f"returned_token_ids={self.returned_token_ids}, " - f"requires_token_ids={self.requires_token_ids}, " - f"extra_kwargs={self.extra_kwargs})") + return ( + f"PoolingParams(" + f"task={self.task}, " + f"normalize={self.normalize}, " + f"dimensions={self.dimensions}, " + f"activation={self.activation}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " + f"requires_token_ids={self.requires_token_ids}, " + f"extra_kwargs={self.extra_kwargs})" + ) def __post_init__(self) -> None: - assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ + assert self.output_kind == RequestOutputKind.FINAL_ONLY, ( "For pooling output_kind has to be FINAL_ONLY" + ) diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 2f9ebe531cbb1..fea299b287f98 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -12,21 +12,26 @@ from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent from torch.autograd.profiler import FunctionEvent from torch.profiler import ProfilerActivity, profile -from vllm.profiler.utils import (TablePrinter, event_has_module, - event_is_torch_op, event_module_repr, - event_torch_op_stack_trace, indent_string) +from vllm.profiler.utils import ( + TablePrinter, + event_has_module, + event_is_torch_op, + event_module_repr, + event_torch_op_stack_trace, + indent_string, +) @dataclass class _ModuleTreeNode: event: _ProfilerEvent - parent: Optional['_ModuleTreeNode'] = None - children: list['_ModuleTreeNode'] = field(default_factory=list) + parent: Optional["_ModuleTreeNode"] = None + children: list["_ModuleTreeNode"] = field(default_factory=list) trace: str = "" @property def is_leaf(self): - return (self.event.children is None or len(self.event.children) == 0) + return self.event.children is None or len(self.event.children) == 0 @property def is_torch_op(self): @@ -34,8 +39,10 @@ class _ModuleTreeNode: @property def is_cuda(self): - return (self.event.tag == _EventType.Kineto - and self.event.typed[1].device_type == DeviceType.CUDA) + return ( + self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA + ) @dataclass @@ -68,8 +75,7 @@ class _StatsTreeNode: @dataclass class LayerwiseProfileResults(profile): _kineto_results: _ProfilerResult - _kineto_event_correlation_map: dict[int, - list[_KinetoEvent]] = field(init=False) + _kineto_event_correlation_map: dict[int, list[_KinetoEvent]] = field(init=False) _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False) _module_tree: list[_ModuleTreeNode] = field(init=False) _model_stats_tree: list[_StatsTreeNode] = field(init=False) @@ -84,11 +90,9 @@ class LayerwiseProfileResults(profile): self._build_stats_trees() def print_model_table(self, column_widths: dict[str, int] = None): - _column_widths = dict(name=60, - cpu_time_us=12, - cuda_time_us=12, - pct_cuda_time=12, - trace=60) + 
_column_widths = dict( + name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 + ) if column_widths: _column_widths.update(**column_widths) filtered_model_table = [ @@ -99,78 +103,76 @@ class LayerwiseProfileResults(profile): TablePrinter(ModelStatsEntry, _column_widths).print_table( self._indent_row_names_based_on_depth( filtered_model_table, - indent_style=lambda indent: "|" + "-" * indent + " ")) + indent_style=lambda indent: "|" + "-" * indent + " ", + ) + ) def print_summary_table(self, column_widths: dict[str, int] = None): - _column_widths = dict(name=80, - cuda_time_us=12, - pct_cuda_time=12, - invocations=15) + _column_widths = dict( + name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15 + ) if column_widths: _column_widths.update(**column_widths) - filtered_summary_table = [(depth, row) - for depth, row in self._flatten_stats_tree( - self._summary_stats_tree) - if row.cuda_time_us > 0] + filtered_summary_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._summary_stats_tree) + if row.cuda_time_us > 0 + ] TablePrinter(SummaryStatsEntry, _column_widths).print_table( self._indent_row_names_based_on_depth( filtered_summary_table, - indent_style=lambda indent: "|" + "-" * indent + " ")) + indent_style=lambda indent: "|" + "-" * indent + " ", + ) + ) def export_model_stats_table_csv(self, filename: str): - df = pd.DataFrame([ - asdict(row) - for _, row in self._flatten_stats_tree(self._model_stats_tree) - ]) + df = pd.DataFrame( + [asdict(row) for _, row in self._flatten_stats_tree(self._model_stats_tree)] + ) df.to_csv(filename) def export_summary_stats_table_csv(self, filename: str): - df = pd.DataFrame([ - asdict(row) - for _, row in self._flatten_stats_tree(self._summary_stats_tree) - ]) + df = pd.DataFrame( + [ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ] + ) df.to_csv(filename) def convert_stats_to_dict(self) -> dict[str, Any]: return { - "metadata": { - "num_running_seqs": self.num_running_seqs - }, - "summary_stats": - self._convert_stats_tree_to_dict(self._summary_stats_tree), - "model_stats": - self._convert_stats_tree_to_dict(self._model_stats_tree) + "metadata": {"num_running_seqs": self.num_running_seqs}, + "summary_stats": self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": self._convert_stats_tree_to_dict(self._model_stats_tree), } @staticmethod - def _indent_row_names_based_on_depth(depths_rows: list[tuple[int, - StatsEntry]], - indent_style: Union[Callable[[int], - str], - str] = " "): + def _indent_row_names_based_on_depth( + depths_rows: list[tuple[int, StatsEntry]], + indent_style: Union[Callable[[int], str], str] = " ", + ): indented_rows = [] for depth, row in depths_rows: if row.cuda_time_us == 0: continue indented_row = copy.deepcopy(row) - indented_row.name = indent_string(indented_row.name, depth, - indent_style) + indented_row.name = indent_string(indented_row.name, depth, indent_style) indented_rows.append(indented_row) return indented_rows def _build_correlation_map(self): self._kineto_event_correlation_map = defaultdict(list) for event in self._kineto_results.events(): - self._kineto_event_correlation_map[event.correlation_id()].append( - event) + self._kineto_event_correlation_map[event.correlation_id()].append(event) def _build_module_tree(self): self._module_tree = [] event_tree = self._kineto_results.experimental_event_tree() - def _df_traversal(event: _ProfilerEvent, - curr_node: Optional[_ModuleTreeNode] = None): - + def _df_traversal( + 
event: _ProfilerEvent, curr_node: Optional[_ModuleTreeNode] = None + ): # For the tensor parallel case for now only look at task 1 if event.start_tid != 1: return @@ -183,13 +185,15 @@ class LayerwiseProfileResults(profile): self._module_tree.append(node) curr_node = node - is_leaf = (event.children is None or len(event.children) == 0) + is_leaf = event.children is None or len(event.children) == 0 if is_leaf and curr_node: node = _ModuleTreeNode( event=event, parent=curr_node, trace=event_torch_op_stack_trace( - event, until=lambda x: event_has_module(x))) + event, until=lambda x: event_has_module(x) + ), + ) curr_node.children.append(node) curr_node = node @@ -203,31 +207,31 @@ class LayerwiseProfileResults(profile): if node.event.tag != _EventType.Kineto: return None correlated_kineto_events = self._kineto_event_correlation_map.get( - node.event.correlation_id, []) - iterator = (x for x in correlated_kineto_events - if x.device_type() == DeviceType.CUDA - and x.name() == node.event.name) + node.event.correlation_id, [] + ) + iterator = ( + x + for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA and x.name() == node.event.name + ) return next(iterator, None) def _cumulative_cuda_time(self, node: _ModuleTreeNode): - 'Return cuda time in microseconds' + "Return cuda time in microseconds" def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): - if node.is_leaf and (gpu_kineto_event := - self._get_kineto_gpu_event(node)): + if node.is_leaf and (gpu_kineto_event := self._get_kineto_gpu_event(node)): return gpu_kineto_event.duration_ns() / 1000.0 else: cumulative_cuda_time = 0 for child in node.children: - cumulative_cuda_time += _cumulative_cuda_time_recursive( - child) + cumulative_cuda_time += _cumulative_cuda_time_recursive(child) return cumulative_cuda_time return _cumulative_cuda_time_recursive(node) def _total_cuda_time(self): - return sum( - [self._cumulative_cuda_time(root) for root in self._module_tree]) + return sum([self._cumulative_cuda_time(root) for root in self._module_tree]) def _build_stats_trees(self): summary_dict: dict[str, _StatsTreeNode] = {} @@ -239,38 +243,42 @@ class LayerwiseProfileResults(profile): def build_summary_stats_tree_df( node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None, - summary_trace: tuple[str] = ()): - + summary_trace: tuple[str] = (), + ): if event_has_module(node.event): name = event_module_repr(node.event) cuda_time_us = self._cumulative_cuda_time(node) - elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + elif gpu_kineto_event := self._get_kineto_gpu_event(node): name = gpu_kineto_event.name() cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 else: return None - summary_trace = summary_trace + (name, ) + summary_trace = summary_trace + (name,) if summary_trace in summary_dict: entry = summary_dict[summary_trace].entry entry.cuda_time_us += cuda_time_us entry.invocations += 1 entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) else: - new_node = _StatsTreeNode(entry=SummaryStatsEntry( - name=name, - cuda_time_us=cuda_time_us, - pct_cuda_time=pct_cuda_time(cuda_time_us), - invocations=1), - children=[], - parent=parent) + new_node = _StatsTreeNode( + entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1, + ), + children=[], + parent=parent, + ) if parent: parent.children.append(new_node) summary_dict[summary_trace] = new_node for child in node.children: - build_summary_stats_tree_df(child, summary_dict[summary_trace], 
- summary_trace) + build_summary_stats_tree_df( + child, summary_dict[summary_trace], summary_trace + ) return summary_dict[summary_trace] @@ -278,14 +286,17 @@ class LayerwiseProfileResults(profile): for root in self._module_tree: self._summary_stats_tree.append(build_summary_stats_tree_df(root)) - def build_model_stats_tree_df(node: _ModuleTreeNode, - parent: Optional[_StatsTreeNode] = None): - if event_has_module(node.event, ): + def build_model_stats_tree_df( + node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None + ): + if event_has_module( + node.event, + ): name = event_module_repr(node.event) cuda_time_us = self._cumulative_cuda_time(node) cpu_time_us = node.event.duration_time_ns / 1000 trace = "" - elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + elif gpu_kineto_event := self._get_kineto_gpu_event(node): name = gpu_kineto_event.name() cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 cpu_time_us = 0 @@ -293,14 +304,17 @@ class LayerwiseProfileResults(profile): else: return None - new_node = _StatsTreeNode(entry=ModelStatsEntry( - name=name, - cpu_time_us=cpu_time_us, - cuda_time_us=cuda_time_us, - pct_cuda_time=pct_cuda_time(cuda_time_us), - trace=trace), - parent=parent, - children=[]) + new_node = _StatsTreeNode( + entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace, + ), + parent=parent, + children=[], + ) if parent: parent.children.append(new_node) @@ -314,7 +328,8 @@ class LayerwiseProfileResults(profile): self._model_stats_tree.append(build_model_stats_tree_df(root)) def _flatten_stats_tree( - self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]: + self, tree: list[_StatsTreeNode] + ) -> list[tuple[int, StatsEntry]]: entries: list[tuple[int, StatsEntry]] = [] def df_traversal(node: _StatsTreeNode, depth=0): @@ -327,15 +342,11 @@ class LayerwiseProfileResults(profile): return entries - def _convert_stats_tree_to_dict(self, - tree: list[_StatsTreeNode]) -> list[dict]: + def _convert_stats_tree_to_dict(self, tree: list[_StatsTreeNode]) -> list[dict]: root_dicts: list[dict] = [] def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): - curr_json_list.append({ - "entry": asdict(node.entry), - "children": [] - }) + curr_json_list.append({"entry": asdict(node.entry), "children": []}) for child in node.children: df_traversal(child, curr_json_list[-1]["children"]) @@ -346,22 +357,22 @@ class LayerwiseProfileResults(profile): class layerwise_profile(profile): - def __init__(self, num_running_seqs: Optional[int] = None): """ layerwise profile constructor. Args: num_running_seqs (Optional[int], optional): When given, - num_running_seqs will be passed to LayerProfileResults for metadata - update. Defaults to None. + num_running_seqs will be passed to LayerProfileResults + for metadata update. Defaults to None. 
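A usage sketch for the `layerwise_profile` context manager reformatted above, assuming a PyTorch install with profiler support; the tiny matmul stands in for a real model step, and on a CPU-only machine the printed tables may simply be empty because rows are filtered on nonzero CUDA time:

```python
import torch

from vllm.profiler.layerwise_profile import layerwise_profile


def tiny_workload() -> torch.Tensor:
    # Stand-in for a model forward pass; any torch work under the context is captured.
    a = torch.randn(256, 256)
    return a @ a


with layerwise_profile(num_running_seqs=1) as prof:
    tiny_workload()

# `results` is populated on context exit (see __exit__ above).
prof.results.print_summary_table()
prof.results.export_summary_stats_table_csv("summary_stats.csv")
```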
""" super().__init__( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True, with_modules=True, - experimental_config=_ExperimentalConfig(verbose=True)) + experimental_config=_ExperimentalConfig(verbose=True), + ) self.num_running_seqs = num_running_seqs @@ -371,5 +382,5 @@ class layerwise_profile(profile): def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) self.results = LayerwiseProfileResults( - self.profiler.kineto_results, - num_running_seqs=self.num_running_seqs) + self.profiler.kineto_results, num_running_seqs=self.num_running_seqs + ) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index 9f0f56a15fd53..b3607fbecde78 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -30,9 +30,9 @@ def trim_string_back(string, width): class TablePrinter: - - def __init__(self, row_cls: type[dataclasses.dataclass], - column_widths: dict[str, int]): + def __init__( + self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int] + ): self.row_cls = row_cls self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.column_widths = column_widths @@ -46,16 +46,18 @@ class TablePrinter: def _print_header(self): for i, f in enumerate(self.fieldnames): - last = (i == len(self.fieldnames) - 1) + last = i == len(self.fieldnames) - 1 col_width = self.column_widths[f] - print(trim_string_back(f, col_width).ljust(col_width), - end=" | " if not last else "\n") + print( + trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n", + ) def _print_row(self, row): assert isinstance(row, self.row_cls) for i, f in enumerate(self.fieldnames): - last = (i == len(self.fieldnames) - 1) + last = i == len(self.fieldnames) - 1 col_width = self.column_widths[f] val = getattr(row, f) @@ -75,9 +77,9 @@ class TablePrinter: print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) -def indent_string(string: str, - indent: int, - indent_style: Union[Callable[[int], str], str] = " ") -> str: +def indent_string( + string: str, indent: int, indent_style: Union[Callable[[int], str], str] = " " +) -> str: if indent: if isinstance(indent_style, str): return indent_style * indent + string @@ -111,15 +113,14 @@ def event_arg_repr(arg) -> str: elif isinstance(arg, tuple): return f"({', '.join([event_arg_repr(x) for x in arg])})" else: - assert isinstance(arg, - _TensorMetadata), f"Unsupported type: {type(arg)}" - sizes_str = ', '.join([str(x) for x in arg.sizes]) + assert isinstance(arg, _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ", ".join([str(x) for x in arg.sizes]) return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" def event_torch_op_repr(event: _ProfilerEvent) -> str: assert event.tag == _EventType.TorchOp - args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + args_str = ", ".join([event_arg_repr(x) for x in event.typed[1].inputs]) return f"{event.name}({args_str})".replace("aten::", "") @@ -127,15 +128,17 @@ def event_module_repr(event: _ProfilerEvent) -> str: assert event_has_module(event) module = event.typed[1].module if module.parameters and len(module.parameters) > 0: - args_str = ', '.join( - [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + args_str = ", ".join( + [f"{x[0]}={event_arg_repr(x[1])}" for x in module.parameters] + ) return f"{module.cls_name}({args_str})" else: return module.cls_name -def event_torch_op_stack_trace(curr_event: _ProfilerEvent, - until: Callable[[_ProfilerEvent], bool]) 
-> str: +def event_torch_op_stack_trace( + curr_event: _ProfilerEvent, until: Callable[[_ProfilerEvent], bool] +) -> str: trace = "" curr_event = curr_event.parent while curr_event and not until(curr_event): diff --git a/vllm/ray/lazy_utils.py b/vllm/ray/lazy_utils.py index bb3535579cfdf..64b5f51571a35 100644 --- a/vllm/ray/lazy_utils.py +++ b/vllm/ray/lazy_utils.py @@ -6,6 +6,7 @@ def is_ray_initialized(): """Check if Ray is initialized.""" try: import ray + return ray.is_initialized() except ImportError: return False @@ -16,7 +17,10 @@ def is_in_ray_actor(): try: import ray - return (ray.is_initialized() - and ray.get_runtime_context().get_actor_id() is not None) + + return ( + ray.is_initialized() + and ray.get_runtime_context().get_actor_id() is not None + ) except ImportError: return False diff --git a/vllm/ray/ray_env.py b/vllm/ray/ray_env.py index f6a994bb3c226..a89e55bd7e4b6 100644 --- a/vllm/ray/ray_env.py +++ b/vllm/ray/ray_env.py @@ -14,7 +14,8 @@ CONFIG_HOME = envs.VLLM_CONFIG_ROOT # This file contains a list of env vars that should not be copied # from the driver to the Ray workers. RAY_NON_CARRY_OVER_ENV_VARS_FILE = os.path.join( - CONFIG_HOME, "ray_non_carry_over_env_vars.json") + CONFIG_HOME, "ray_non_carry_over_env_vars.json" +) try: if os.path.exists(RAY_NON_CARRY_OVER_ENV_VARS_FILE): @@ -25,13 +26,16 @@ try: except json.JSONDecodeError: logger.warning( "Failed to parse %s. Using an empty set for non-carry-over env vars.", - RAY_NON_CARRY_OVER_ENV_VARS_FILE) + RAY_NON_CARRY_OVER_ENV_VARS_FILE, + ) RAY_NON_CARRY_OVER_ENV_VARS = set() -def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None, - additional_vars: Optional[set[str]] = None, - destination: Optional[str] = None) -> set[str]: +def get_env_vars_to_copy( + exclude_vars: Optional[set[str]] = None, + additional_vars: Optional[set[str]] = None, + destination: Optional[str] = None, +) -> set[str]: """ Get the environment variables to copy to downstream Ray actors. 
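The Ray helpers above read an optional JSON file of environment variables that must not be carried over to Ray workers and then filter the copy set. A simplified standalone sketch of that flow; the file path and variable names are placeholders, the real file lives under `VLLM_CONFIG_ROOT`, and `get_env_vars_to_copy` additionally honours its `exclude_vars`/`additional_vars` arguments:

```python
import json
import os

# Write the exclusion file this module looks for (illustrative path; the real one
# is ray_non_carry_over_env_vars.json under VLLM_CONFIG_ROOT).
exclude_file = "/tmp/ray_non_carry_over_env_vars.json"
with open(exclude_file, "w") as f:
    json.dump(["MY_LOCAL_DEBUG_FLAG", "HF_HOME"], f)

with open(exclude_file) as f:
    non_carry_over = set(json.load(f))

os.environ.setdefault("VLLM_LOGGING_LEVEL", "INFO")
os.environ.setdefault("MY_LOCAL_DEBUG_FLAG", "1")

# Filter a candidate set the same way the helper does: drop excluded vars and
# anything not actually present in the driver environment.
candidate_vars = {"VLLM_LOGGING_LEVEL", "MY_LOCAL_DEBUG_FLAG", "HF_HOME"}
to_copy = {v for v in candidate_vars if v not in non_carry_over and v in os.environ}
print(sorted(to_copy))  # ['VLLM_LOGGING_LEVEL']
```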
@@ -60,13 +64,17 @@ def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None, to_destination = " to " + destination if destination is not None else "" - logger.info("RAY_NON_CARRY_OVER_ENV_VARS from config: %s", - RAY_NON_CARRY_OVER_ENV_VARS) - logger.info("Copying the following environment variables%s: %s", - to_destination, - [v for v in env_vars_to_copy if v in os.environ]) logger.info( - "If certain env vars should NOT be copied, add them to " - "%s file", RAY_NON_CARRY_OVER_ENV_VARS_FILE) + "RAY_NON_CARRY_OVER_ENV_VARS from config: %s", RAY_NON_CARRY_OVER_ENV_VARS + ) + logger.info( + "Copying the following environment variables%s: %s", + to_destination, + [v for v in env_vars_to_copy if v in os.environ], + ) + logger.info( + "If certain env vars should NOT be copied, add them to %s file", + RAY_NON_CARRY_OVER_ENV_VARS_FILE, + ) return env_vars_to_copy diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index b987adeb6428f..78d3bf35f2a32 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -2,17 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager +from .basic_parsers import BaseThinkingReasoningParser from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser +from .olmo3_reasoning_parser import Olmo3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser +from .seedoss_reasoning_parser import SeedOSSReasoningParser from .step3_reasoning_parser import Step3ReasoningParser __all__ = [ "ReasoningParser", + "BaseThinkingReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser", "GraniteReasoningParser", @@ -20,6 +24,8 @@ __all__ = [ "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", "MistralReasoningParser", + "Olmo3ReasoningParser", "Step3ReasoningParser", "GptOssReasoningParser", + "SeedOSSReasoningParser", ] diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index df9e84163f16c..2d93f0702f721 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -7,15 +7,17 @@ import os from abc import abstractmethod from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Union from vllm.logger import init_logger from vllm.utils import import_from_path, is_list_of if TYPE_CHECKING: - from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ResponsesRequest) + from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, + ) from vllm.transformers_utils.tokenizer import AnyTokenizer else: ChatCompletionRequest = Any @@ -34,7 +36,7 @@ class ReasoningParser: It is used to extract reasoning content from the model output. 
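The registry above is what the per-model parsers in the rest of this patch hook into via the `@ReasoningParserManager.register_module(...)` decorator. A minimal sketch of registering an out-of-tree parser; the name and the trivial method bodies are placeholders, and a real parser would implement the extraction logic shown later in this file set:

```python
from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning import ReasoningParser, ReasoningParserManager


@ReasoningParserManager.register_module("my_custom_parser")  # placeholder name
class MyCustomReasoningParser(ReasoningParser):
    """Placeholder parser that treats the whole output as plain content."""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return True

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids

    def extract_reasoning_content(self, model_output, request):
        return None, model_output

    def extract_reasoning_content_streaming(
        self, previous_text, current_text, delta_text,
        previous_token_ids, current_token_ids, delta_token_ids,
    ):
        return DeltaMessage(content=delta_text)
```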
""" - def __init__(self, tokenizer: AnyTokenizer): + def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): self.model_tokenizer = tokenizer @cached_property @@ -77,7 +79,7 @@ class ReasoningParser: self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest], - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Extract reasoning content from a complete model-generated string. @@ -128,19 +130,19 @@ class ReasoningParserManager: if name in cls.reasoning_parsers: return cls.reasoning_parsers[name] - raise KeyError( - f"reasoning helper: '{name}' not found in reasoning_parsers") + raise KeyError(f"reasoning helper: '{name}' not found in reasoning_parsers") @classmethod def _register_module( cls, module: type, - module_name: Optional[Union[str, list[str]]] = None, + module_name: Union[str, list[str]] | None = None, force: bool = True, ) -> None: if not issubclass(module, ReasoningParser): - raise TypeError("module must be subclass of ReasoningParser, " - f"but got {type(module)}") + raise TypeError( + f"module must be subclass of ReasoningParser, but got {type(module)}" + ) if module_name is None: module_name = module.__name__ if isinstance(module_name, str): @@ -148,14 +150,15 @@ class ReasoningParserManager: for name in module_name: if not force and name in cls.reasoning_parsers: existed_module = cls.reasoning_parsers[name] - raise KeyError(f"{name} is already registered " - f"at {existed_module.__module__}") + raise KeyError( + f"{name} is already registered at {existed_module.__module__}" + ) cls.reasoning_parsers[name] = module @classmethod def register_module( cls, - name: Optional[Union[str, list[str]]] = None, + name: Union[str, list[str]] | None = None, force: bool = True, module: Union[type, None] = None, ) -> Union[type, Callable]: @@ -168,11 +171,11 @@ class ReasoningParserManager: raise TypeError(f"force must be a boolean, but got {type(force)}") # raise the error ahead of time - if not (name is None or isinstance(name, str) - or is_list_of(name, str)): + if not (name is None or isinstance(name, str) or is_list_of(name, str)): raise TypeError( "name must be None, an instance of str, or a sequence of str, " - f"but got {type(name)}") + f"but got {type(name)}" + ) # use it as a normal method: x.register_module(module=SomeClass) if module is not None: @@ -197,6 +200,7 @@ class ReasoningParserManager: try: import_from_path(module_name, plugin_path) except Exception: - logger.exception("Failed to load module '%s' from %s.", - module_name, plugin_path) + logger.exception( + "Failed to load module '%s' from %s.", module_name, plugin_path + ) return diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py new file mode 100644 index 0000000000000..f47ffe6212caf --- /dev/null +++ b/vllm/reasoning/basic_parsers.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from collections.abc import Sequence +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, +) +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class BaseThinkingReasoningParser(ReasoningParser): + """ + Base class for reasoning parsers that use thinking tokens. 
+ + This class provides common functionality for parsers that use start and end + tokens to delimit reasoning content ( + e.g., <think>...</think>, <seed:think>...</seed:think>). + + Subclasses must implement the start and end tokens via abstract + properties. + """ + + @property + @abstractmethod + def start_token(self) -> str: + """The token that starts reasoning content.""" + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + """The token that ends reasoning content.""" + raise NotImplementedError + + def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction." + ) + + if not self.start_token or not self.end_token: + raise ValueError("start_token and end_token must be defined in subclasses") + + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__} reasoning parser could not locate " + "think start/end tokens in the tokenizer!" + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + end_token_id = self.end_token_id + return any(input_id == end_token_id for input_id in reversed(input_ids)) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content after the end tokens + """ + if self.end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. + """ + # Skip single special tokens + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.start_token_id, self.end_token_id] + ): + return None + + # Check if start token is present in previous or delta. + # Keep compatibility with models that don't generate start tokens. 
+ if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: + # start token in previous, end token in delta, + # extract reasoning content + end_index = delta_text.find(self.end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) + elif self.end_token_id in previous_token_ids: + # start token in previous, end token in previous, + # reasoning content continues + return DeltaMessage(content=delta_text) + else: + # start token in previous, no end token in previous or delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + elif self.start_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: + # start token in delta, end token in delta, + # extract reasoning content + start_index = delta_text.find(self.start_token) + end_index = delta_text.find(self.end_token) + reasoning_content = delta_text[ + start_index + len(self.start_token) : end_index + ] + content = delta_text[end_index + len(self.end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) + else: + # start token in delta, no end token in delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + else: + # not find thinking start token + return DeltaMessage(content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest] + ) -> tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from the model output. + + This is the base implementation that works for most models. + Subclasses can override this method for specific behavior. + """ + # Check if the start token is present in the model output, remove it + # if it is present. + model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + + # For models that may not generate start token, + # assume the reasoning content is always at the start. 
+ if self.end_token not in model_output: + return model_output, None + else: + reasoning_content, _, content = model_output.partition(self.end_token) + # If generation stops right after end-of-think, return null content + final_content = content or None + return reasoning_content, final_content diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py index 1a5ca46a60f1d..264da54b48793 100644 --- a/vllm/reasoning/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -2,20 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union +from typing import Union -from transformers import PreTrainedTokenizerBase - -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) -from vllm.logger import init_logger -from vllm.reasoning import ReasoningParser, ReasoningParserManager - -logger = init_logger(__name__) +from vllm.entrypoints.openai.protocol import DeltaMessage +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser @ReasoningParserManager.register_module("deepseek_r1") -class DeepSeekR1ReasoningParser(ReasoningParser): +class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for DeepSeek R1 model. @@ -23,38 +18,15 @@ class DeepSeekR1ReasoningParser(ReasoningParser): text. This parser extracts the reasoning content from the model output. """ - start_token_id: int - end_token_id: int + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "<think>" - start_token: str = "<think>" - end_token: str = "</think>" - - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) - - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") - - self.start_token_id = self.vocab.get(self.start_token) - self.end_token_id = self.vocab.get(self.end_token) - if self.start_token_id is None or self.end_token_id is None: - raise RuntimeError( - "DeepSeek R1 reasoning parser could not locate think start/end " - "tokens in the tokenizer!") - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids - - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.end_token_id not in input_ids[:-1]: - return [] - else: - return input_ids[input_ids.index(self.end_token_id) + 1:] + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "</think>" def extract_reasoning_content_streaming( self, @@ -65,109 +37,34 @@ class DeepSeekR1ReasoningParser(ReasoningParser): current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - """ - Extract reasoning content from a delta message. - Handles streaming output where previous + delta = current. - Uses token IDs for faster processing. - For text <think>abc</think>xyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - """ - # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.start_token_id, self.end_token_id - ]): - return None - - # Check if <think> is present in previous or delta. - # Keep compatibility with models that don't generate <think> tokens. 
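For reference, the non-streaming behaviour shared by `BaseThinkingReasoningParser` and the slimmed-down DeepSeek-R1 parser can be reproduced with plain string operations. A standalone re-implementation with a few worked inputs, assuming the `<think>`/`</think>` delimiters used by the DeepSeek subclass:

```python
from typing import Optional

START, END = "<think>", "</think>"


def split_reasoning(model_output: str) -> tuple[Optional[str], Optional[str]]:
    """Same rules as the base parser above: tolerate a missing start token,
    treat everything before the end token as reasoning, and return None for
    content when generation stops right after the end-of-think token."""
    _, sep, rest = model_output.partition(START)
    model_output = rest if sep else model_output
    if END not in model_output:
        return model_output, None
    reasoning, _, content = model_output.partition(END)
    return reasoning, content or None


assert split_reasoning("<think>abc</think>xyz") == ("abc", "xyz")
assert split_reasoning("abc</think>xyz") == ("abc", "xyz")  # model skipped <think>
assert split_reasoning("abc</think>") == ("abc", None)      # nothing after reasoning
assert split_reasoning("abc") == ("abc", None)              # whole output is reasoning
```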
- if self.start_token_id in previous_token_ids: + ret = super().extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + if ( + ret is not None + and self.start_token_id not in previous_token_ids + and self.start_token_id not in delta_token_ids + ): if self.end_token_id in delta_token_ids: - # <think> in previous, </think> in delta, - # extract reasoning content - end_index = delta_text.find(self.end_token) - reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.end_token):] - return DeltaMessage( - reasoning_content=reasoning_content, - content=content if content else None, - ) - elif self.end_token_id in previous_token_ids: - # <think> in previous, </think> in previous, - # reasoning content continues - return DeltaMessage(content=delta_text) - else: - # <think> in previous, no </think> in previous or delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - elif self.start_token_id in delta_token_ids: - if self.end_token_id in delta_token_ids: - # <think> in delta, </think> in delta, extract reasoning content - start_index = delta_text.find(self.start_token) - end_index = delta_text.find(self.end_token) - reasoning_content = delta_text[start_index + - len(self.start_token):end_index] - content = delta_text[end_index + len(self.end_token):] - return DeltaMessage( - reasoning_content=reasoning_content, - content=content if content else None, - ) - else: - # <think> in delta, no </think> in delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - else: - # No <think> in previous or delta, also need to check for </think>. - # Because the model may have generated </think> without <think> - # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.end_token_id in delta_token_ids: - # </think> in delta with more tokens, + # end token in delta with more tokens, # extract reasoning content and content end_index = delta_text.find(self.end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.end_token):] + content = delta_text[end_index + len(self.end_token) :] return DeltaMessage( reasoning_content=reasoning_content, content=content if content else None, ) elif self.end_token_id in previous_token_ids: - # </think> in previous, thinking content ends + # end token in previous, thinking content ends return DeltaMessage(content=delta_text) else: - # no </think> in previous or delta, reasoning content continues + # no end token in previous or delta, reasoning content continues return DeltaMessage(reasoning_content=delta_text) - def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest - ) -> tuple[Optional[str], Optional[str]]: - """ - Extract reasoning content from the model output. - - For text <think>abc</think>xyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - - Returns: - tuple[Optional[str], Optional[str]]: reasoning content and content - """ - - # Check if the start token is present in the model output, remove it - # if it is present. - model_output_parts = model_output.partition(self.start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] - - # DeepSeek R1 doesn't generate <think> now. - # Thus we assume the reasoning content is always at the start. 
- # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.end_token not in model_output: - return model_output, None - else: - reasoning_content, _, content = model_output.partition( - self.end_token) - # If the end token is not found, return the model output as is. - # It should not happen since we already checked for the presence - # of the end token. - # If generation stops right after end-of-think, return null content - final_content = content or None - return reasoning_content, final_content + return ret diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py index 460e38d2d396b..da98515c7e629 100644 --- a/vllm/reasoning/glm4_moe_reasoning_parser.py +++ b/vllm/reasoning/glm4_moe_reasoning_parser.py @@ -6,8 +6,7 @@ from typing import Optional, Union from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -26,26 +25,43 @@ class Glm4MoeModelReasoningParser(ReasoningParser): from the model's output. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) self.think_start_token = "<think>" self.think_end_token = "</think>" + self.assistant_token = "<|assistant|>" if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) self.think_start_token_id = self.vocab.get(self.think_start_token) self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): + self.assistant_token_id = self.vocab.get(self.assistant_token) + if ( + self.think_start_token_id is None + or self.think_end_token_id is None + or self.assistant_token_id is None + ): raise RuntimeError( "Glm4MoeModel reasoning parser could not locate " - "think start/end tokens in the tokenizer!") + "think start/end or assistant tokens in the tokenizer!" + ) def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids + """ + GLM's chat template has <think></think> tokens after every + <|assistant|> token. Thus, we need to check if </think> is + after the most recent <|assistant|> token (if present). 
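The reversed scan described in the docstring above can be illustrated with plain token ids; the numeric ids below are made up for the example, and only the relative order of `</think>` and `<|assistant|>` matters:

```python
THINK_END = 2  # stands in for the </think> token id
ASSISTANT = 9  # stands in for the <|assistant|> token id


def reasoning_ended(input_ids: list[int]) -> bool:
    """True only if </think> appears after the most recent <|assistant|>."""
    for token_id in reversed(input_ids):
        if token_id == THINK_END:
            return True
        if token_id == ASSISTANT:
            return False
    return False


print(reasoning_ended([9, 1, 2, 5]))        # True: </think> after <|assistant|>
print(reasoning_ended([9, 1, 2, 5, 9, 7]))  # False: new <|assistant|> turn, still thinking
```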
+ """ + for token_id in input_ids[::-1]: + if token_id == self.think_end_token_id: + return True + elif token_id == self.assistant_token_id: + return False + return False def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ @@ -54,7 +70,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser): if self.think_end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.think_end_token_id) + 1 :] def extract_reasoning_content_streaming( self, @@ -74,9 +90,9 @@ class Glm4MoeModelReasoningParser(ReasoningParser): - 'xyz' goes to content """ # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id - ]): + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id] + ): return None if self.think_start_token_id in previous_token_ids: @@ -85,9 +101,11 @@ class Glm4MoeModelReasoningParser(ReasoningParser): # extract reasoning content end_index = delta_text.find(self.think_end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) elif self.think_end_token_id in previous_token_ids: # <think> in previous, </think> in previous, # reasoning content continues @@ -101,12 +119,14 @@ class Glm4MoeModelReasoningParser(ReasoningParser): # <think> in delta, </think> in delta, extract reasoning content start_index = delta_text.find(self.think_start_token) end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + reasoning_content = delta_text[ + start_index + len(self.think_start_token) : end_index + ] + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) else: # <think> in delta, no </think> in delta, # reasoning content continues @@ -116,7 +136,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser): return DeltaMessage(content=delta_text) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. @@ -130,22 +150,24 @@ class Glm4MoeModelReasoningParser(ReasoningParser): """ # Check if the model output contains the <think> and </think> tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + if ( + self.think_start_token not in model_output + or self.think_end_token not in model_output + ): return None, model_output # Check if the <think> is present in the model output, remove it # if it is present. 
model_output_parts = model_output.partition(self.think_start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) # Check if the model output contains the </think> tokens. # If the end token is not found, return the model output as is. if self.think_end_token not in model_output: return None, model_output # Extract reasoning content from the model output. - reasoning_content, _, content = model_output.partition( - self.think_end_token) + reasoning_content, _, content = model_output.partition(self.think_end_token) final_content = content or None return reasoning_content, final_content diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py index 05a72ac23bf2e..738c7b51694a0 100644 --- a/vllm/reasoning/gptoss_reasoning_parser.py +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -6,15 +6,15 @@ from typing import Optional, Union from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.harmony_utils import parse_chat_output +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) -@ReasoningParserManager.register_module("GptOss") +@ReasoningParserManager.register_module("openai_gptoss") class GptOssReasoningParser(ReasoningParser): """ Reasoning parser for GptOss model. @@ -23,10 +23,11 @@ class GptOssReasoningParser(ReasoningParser): is only used for detecting the end of the reasoning content. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) self.reasoning_end_token_ids = self.model_tokenizer.encode( - "<|start|>assistant<|channel|>final<|message|>") + "<|start|>assistant<|channel|>final<|message|>" + ) def is_reasoning_end(self, input_ids: list[int]) -> bool: end_token_ids = self.reasoning_end_token_ids @@ -34,14 +35,15 @@ class GptOssReasoningParser(ReasoningParser): # Check if the end sequence is present in the input_ids. # We search from the end of input_ids to find the last match. for i in range(len(input_ids) - len(end_token_ids), -1, -1): - if input_ids[i:i + len(end_token_ids)] == end_token_ids: + if input_ids[i : i + len(end_token_ids)] == end_token_ids: return True return False def extract_content_ids(self, input_ids: list[int]) -> list[int]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. This " - "function should not be called.") + _, content, _ = parse_chat_output(input_ids) + if content is None: + return [] + return self.model_tokenizer.encode(content) def extract_reasoning_content_streaming( self, @@ -52,13 +54,31 @@ class GptOssReasoningParser(ReasoningParser): current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. 
This " - "function should not be called.") + prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids)) + cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids)) + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + if cur_reasoning.startswith(prev_r): + reasoning_delta = cur_reasoning[len(prev_r) :] or None + else: + reasoning_delta = cur_reasoning + if cur_content is not None: + prev_c = prev_content or "" + if cur_content.startswith(prev_c): + content_delta = cur_content[len(prev_c) :] or None + else: + content_delta = cur_content + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning_content=reasoning_delta, content=content_delta) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, + model_output: str, + request: ChatCompletionRequest, ) -> tuple[Optional[str], Optional[str]]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. This " - "function should not be called.") + raise NotImplementedError( + "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501 + ) diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index 5820001b918f6..543b202989ee9 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -7,8 +7,7 @@ from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -24,8 +23,8 @@ class GraniteReasoningParser(ReasoningParser): and "Here is my response:" to separate its thinking / response outputs. 
""" - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) # NOTE: There have been some observed occurrences of quantized # instances of the current models using "Here's" instead of "Here is", @@ -34,15 +33,14 @@ class GraniteReasoningParser(ReasoningParser): self.response_start_expr = r"(?:Here's|Here is) my response:" self.reasoning_regex = re.compile( - rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL) + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) self.valid_think_starts = [ - "Here's my thought process:", "Here is my thought process:" - ] - self.valid_response_starts = [ - "Here's my response:", "Here is my response:" + "Here's my thought process:", + "Here is my thought process:", ] + self.valid_response_starts = ["Here's my response:", "Here is my response:"] # Substrings to match for sequence boundaries on raw text self.seq_boundary_end = ":" @@ -50,10 +48,11 @@ class GraniteReasoningParser(ReasoningParser): # The longest any thinking / start of response message can be self.longest_think_start = max( - len(think_start) for think_start in self.valid_think_starts) + len(think_start) for think_start in self.valid_think_starts + ) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates @@ -111,24 +110,27 @@ class GraniteReasoningParser(ReasoningParser): DeltaMessage with either reasoning content or content, or None. """ reasoning_content, resp_seq_len, content = self._get_content_sections( - current_text) + current_text + ) # Either we haven't finished the start of the reasoning sequence, # or the model is generating something unexpected. if not reasoning_content: delta_message = self._get_delta_message_with_no_reasoning_bounds( - current_text, delta_text) + current_text, delta_text + ) # We have a start of reasoning message, but have not yet finished # the start of response sequence. elif not content: delta_message = self._get_delta_message_with_no_response_bounds( - current_text, reasoning_content, delta_text) + current_text, reasoning_content, delta_text + ) # We've finished both the start of reasoning and start of response seq. else: # This should never happen since we matched on the response assert resp_seq_len is not None delta_message = self._get_delta_message_with_both_bounds( - delta_text, reasoning_content, content, current_text, - resp_seq_len) + delta_text, reasoning_content, content, current_text, resp_seq_len + ) if not delta_message.content and not delta_message.reasoning_content: return None return delta_message @@ -139,26 +141,27 @@ class GraniteReasoningParser(ReasoningParser): Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible reasoning start seqs match. """ return any( - think_start.startswith(text) - for think_start in self.valid_think_starts) + think_start.startswith(text) for think_start in self.valid_think_starts + ) def _is_response_start_substr(self, text: str) -> bool: """Check if a text matches one of the possible start response seqs. Args: text (str): Text to check for leading substr. 
- + Returns: bool: True if any of the possible response start seqs match. """ return any( response_start.startswith(text) - for response_start in self.valid_response_starts) + for response_start in self.valid_response_starts + ) def _get_delta_message_with_no_reasoning_bounds( self, @@ -177,8 +180,7 @@ class GraniteReasoningParser(ReasoningParser): """ prev_longest_length = len(current_text) - len(delta_text) is_substr = self._is_reasoning_start_substr(current_text) - was_substr = self._is_reasoning_start_substr( - current_text[:prev_longest_length]) + was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length]) # Check if we just generated something NOT in the special token seq; # if so, add everything that we previously skipped with this delta @@ -220,12 +222,13 @@ class GraniteReasoningParser(ReasoningParser): # content and fully parse it out; we should not pass the : back. ends_with_start_response_seq = any( current_text.endswith(response_start) - for response_start in self.valid_response_starts) + for response_start in self.valid_response_starts + ) if reasoning_content is None or ends_with_start_response_seq: return DeltaMessage(reasoning_content=None, content=None) # Consider previous / current text only within context of the reasoning - previous_text = reasoning_content[:-len(delta_text)] + previous_text = reasoning_content[: -len(delta_text)] current_text = reasoning_content # We need to be careful about adding unfinished response sequences; @@ -234,12 +237,21 @@ class GraniteReasoningParser(ReasoningParser): delta_idx = delta_text.rfind(self.seq_boundary_start) # Check the state of potential start of response substring matches. - prev_was_substr = self._is_response_start_substr( - previous_text[prev_idx:]) if prev_idx >= 0 else False - delta_continues_substr = self._is_response_start_substr( - current_text[prev_idx:]) if prev_idx >= 0 else False - delta_new_substr = self._is_response_start_substr( - delta_text[delta_idx:]) if delta_idx >= 0 else False + prev_was_substr = ( + self._is_response_start_substr(previous_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_continues_substr = ( + self._is_response_start_substr(current_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_new_substr = ( + self._is_response_start_substr(delta_text[delta_idx:]) + if delta_idx >= 0 + else False + ) # Delta only contains potential continued response sequence text. if delta_continues_substr: @@ -248,18 +260,17 @@ class GraniteReasoningParser(ReasoningParser): if not prev_was_substr: # Delta may be starting a new response seq but has other text too. if delta_new_substr: - return DeltaMessage(reasoning_content=delta_text[:delta_idx], - content=None) + return DeltaMessage( + reasoning_content=delta_text[:delta_idx], content=None + ) # Normal case for most reasoning text (no potential special seqs). 
return DeltaMessage(reasoning_content=delta_text, content=None) # The substring that previously seemed to be a potential response # seq wasn't one; we need to add the content to the delta message, # and also slice off the potential response sequence elif delta_new_substr: - reasoning_content = previous_text[ - prev_idx:] + delta_text[:delta_idx] - return DeltaMessage(reasoning_content=reasoning_content, - content=None) + reasoning_content = previous_text[prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning_content=reasoning_content, content=None) # No new substring yet, and we broke our old one; take the whole delta return DeltaMessage( reasoning_content=previous_text[prev_idx:] + delta_text, @@ -278,33 +289,31 @@ class GraniteReasoningParser(ReasoningParser): content and normal (response) content. Args: - delta_text (str): Text to consider and parse content from. - reasoning_content (str): reasoning content from current_text. - response_content (str): response content from current_text. - current_text (str): The full previous + delta text. - response_seq_len(str): Len of the complete response sequence used. + delta_text: Text to consider and parse content from. + reasoning_content: reasoning content from current_text. + response_content: response content from current_text. + current_text: The full previous + delta text. + response_seq_len: Len of the complete response sequence used. Returns: DeltaMessage: Message containing the parsed content. """ # Always have content; take length to the end - delta_content = delta_text[-len(response_content):] - reasoning_end_idx = len(delta_text) - (len(response_content) + - response_seq_len) + delta_content = delta_text[-len(response_content) :] + reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) if reasoning_end_idx < 0: delta_reasoning_content = None else: # Get the starting offset - start_reasoning_content_idx = len( - reasoning_content) + response_seq_len + len( - response_content) - 1 + start_reasoning_content_idx = ( + len(reasoning_content) + response_seq_len + len(response_content) - 1 + ) delta_offset = len(current_text) - len(delta_text) start_offset = start_reasoning_content_idx - delta_offset if start_offset < 0: start_offset = 0 - delta_reasoning_content = delta_text[ - start_offset:reasoning_end_idx] + delta_reasoning_content = delta_text[start_offset:reasoning_end_idx] return DeltaMessage( reasoning_content=delta_reasoning_content, @@ -329,7 +338,8 @@ class GraniteReasoningParser(ReasoningParser): start_reasoning_content = None parsed_content = False delimiter_idxs = [ - idx for idx, char in enumerate(current_text) + idx + for idx, char in enumerate(current_text) if char == self.seq_boundary_end ] @@ -346,17 +356,15 @@ class GraniteReasoningParser(ReasoningParser): # Check to see if the start of response seq if complete elif not parsed_content: for response_start in self.valid_response_starts: - if current_chunk[-len(response_start) + - 1:] == response_start[:-1]: + if current_chunk[-len(response_start) + 1 :] == response_start[:-1]: # Mark end of reasoning and start response content # after the start of response sequence. 
- end_reasoning_content = current_chunk_end - len( - response_start) + end_reasoning_content = current_chunk_end - len(response_start) reasoning_content = current_text[ - start_reasoning_content:end_reasoning_content] - response_content = current_text[current_chunk_end + 1:] - return reasoning_content, len( - response_start), response_content + start_reasoning_content:end_reasoning_content + ] + response_content = current_text[current_chunk_end + 1 :] + return reasoning_content, len(response_start), response_content if start_reasoning_content and not parsed_content: return current_text[start_reasoning_content:], None, None diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index b2452b95c1c67..381f1b5f34667 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -7,8 +7,7 @@ from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -22,16 +21,16 @@ class HunyuanA13BReasoningParser(ReasoningParser): HunyuanReasoningParser - This class implements a reasoning parser specifically designed - for the Hunyuan A13B Model. It is responsible for parsing and - extracting structured reasoning and answer segments from model + This class implements a reasoning parser specifically designed + for the Hunyuan A13B Model. It is responsible for parsing and + extracting structured reasoning and answer segments from model outputs that follow a specific pattern. Key Features: - For non-stream output , Recognizes and extracts reasoning ("think") and answer ("answer") sections from text using regular expressions. - - For stream process, it require a token id sequences to change the - reasoning state and other state so it maintains internal state to + - For stream process, it requires a token id sequences to change the + reasoning state and other state so it maintains internal state to manage parsing across multiple token. 
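As a rough, standalone illustration of the non-streaming path this docstring describes (reasoning and answer sections recovered with regular expressions), the extraction can be reduced to a single DOTALL pattern. The sketch below is not the parser's actual implementation; the `<answer>`/`</answer>` delimiters and newline placement are assumed from the token-id comments in the docstring, so treat them as an approximation.

from typing import Optional

import regex as re

# Hypothetical pattern: "<think>\n...\n</think>\n<answer>\n...\n</answer>"
_THINK_ANSWER_RE = re.compile(
    r"<think>\n(?P<reasoning>.*?)\n</think>\n<answer>\n(?P<answer>.*?)\n</answer>",
    re.DOTALL,
)


def split_think_answer(model_output: str) -> tuple[Optional[str], str]:
    """Return (reasoning, content); fall back to plain content on no match."""
    match = _THINK_ANSWER_RE.search(model_output)
    if match is None:
        return None, model_output
    return match.group("reasoning"), match.group("answer")

Streaming output cannot be split this way, which is why the class also keeps the token-id state machine described in the docstring.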
@@ -40,8 +39,8 @@ class HunyuanA13BReasoningParser(ReasoningParser): response ends: "\n</answer>": [524, 9399, 29] """ - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) self.think_start_expr = r"<think>\n" self.think_end_expr = r"\n</think>\n" @@ -50,20 +49,19 @@ class HunyuanA13BReasoningParser(ReasoningParser): self.full_match_reasoning_regex = re.compile( rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}", - re.DOTALL) + re.DOTALL, + ) self.half_match_reasoning_regex = re.compile( - rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL) + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) self.think_start_ids = [14023, 771, 397] self.think_start_ids_fast = [14023, 771, 1363] self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397] self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397] self.response_end_ids = [198, 524, 9399, 29] - self.fast_think_ids = [ - 14023, 771, 1363, 524, 27963, 397, 27, 9399, 397 - ] + self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397] # when state change, send out all the buffered text in last state self.buffered_text = [] @@ -91,7 +89,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): return [] def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates @@ -121,8 +119,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): reasoning_content, response_content = fallback_match[0] if response_content.endswith(self.response_end_expr): - response_content = response_content[:-len(self. - response_end_expr)] + response_content = response_content[: -len(self.response_end_expr)] if len(reasoning_content) == 0: reasoning_content = None @@ -133,8 +130,9 @@ class HunyuanA13BReasoningParser(ReasoningParser): return None, model_output - def _is_strict_increasing_subsequence(self, subsequence: Sequence[int], - sequence: Sequence[int]) -> bool: + def _is_strict_increasing_subsequence( + self, subsequence: Sequence[int], sequence: Sequence[int] + ) -> bool: if not subsequence: return False @@ -159,27 +157,27 @@ class HunyuanA13BReasoningParser(ReasoningParser): response_start_sequence = self.response_start_ids response_end_sequence = self.response_end_ids - assert (len(delta_token_ids) == 1) + assert len(delta_token_ids) == 1 # Process each token in the delta token = delta_token_ids[0] def check_token_with_sequence(token): if self.current_state == "idle" or self.current_state == "think": - return (token == self.expected_sequence[self.sequence_index] - or token == \ - self.expected_sequence_side[self.sequence_index]) + return ( + token == self.expected_sequence[self.sequence_index] + or token == self.expected_sequence_side[self.sequence_index] + ) else: return token == self.expected_sequence[self.sequence_index] def check_last_token(token): if self.current_state == "idle" or self.current_state == "think": # only return true if it's judge using a side sequence. 
- if (self.sequence_index - 1 < len(self.expected_sequence_side) - and token - == self.expected_sequence_side[self.sequence_index - - 1]): - return self.sequence_index == len( - self.expected_sequence_side) + if ( + self.sequence_index - 1 < len(self.expected_sequence_side) + and token == self.expected_sequence_side[self.sequence_index - 1] + ): + return self.sequence_index == len(self.expected_sequence_side) else: return self.sequence_index == len(self.expected_sequence) else: @@ -227,19 +225,19 @@ class HunyuanA13BReasoningParser(ReasoningParser): # Return content based on current state if self.current_state == "think": - return DeltaMessage(reasoning_content=buffered_content, - content=None) + return DeltaMessage( + reasoning_content=buffered_content, content=None + ) else: - return DeltaMessage(reasoning_content=None, - content=buffered_content) + return DeltaMessage( + reasoning_content=None, content=buffered_content + ) else: # No buffered content, send normally if self.current_state == "think": - return DeltaMessage(reasoning_content=delta_text, - content=None) + return DeltaMessage(reasoning_content=delta_text, content=None) else: - return DeltaMessage(reasoning_content=None, - content=delta_text) + return DeltaMessage(reasoning_content=None, content=delta_text) # If no content to send in this delta return None diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py index 6c707a4079fa0..5658c372a264c 100644 --- a/vllm/reasoning/mistral_reasoning_parser.py +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import cached_property + from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.reasoning.deepseek_r1_reasoning_parser import ( - DeepSeekR1ReasoningParser) +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) @@ -19,29 +20,37 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser): text. This parser extracts the reasoning content from the model output. """ - def __init__(self, tokenizer: MistralTokenizer): + def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs): if not isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "The tokenizer must be an instance of MistralTokenizer.") + raise ValueError("The tokenizer must be an instance of MistralTokenizer.") - ReasoningParser.__init__(self, tokenizer) + ReasoningParser.__init__(self, tokenizer, *args, **kwargs) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." 
+ ) - from mistral_common.tokens.tokenizers.base import SpecialTokens - - self.start_token = SpecialTokens.begin_think - self.end_token = SpecialTokens.end_think - - self.start_token_id = tokenizer.tokenizer.get_control_token( - self.start_token) - self.end_token_id = tokenizer.tokenizer.get_control_token( - self.end_token) + self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token) + self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token) if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( "Mistral reasoning parser could not locate think start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) + + @cached_property + def start_token(self) -> str: + """The token that starts reasoning content.""" + from mistral_common.tokens.tokenizers.base import SpecialTokens + + return SpecialTokens.begin_think + + @cached_property + def end_token(self) -> str: + """The token that ends reasoning content.""" + from mistral_common.tokens.tokenizers.base import SpecialTokens + + return SpecialTokens.end_think diff --git a/vllm/reasoning/olmo3_reasoning_parser.py b/vllm/reasoning/olmo3_reasoning_parser.py new file mode 100644 index 0000000000000..b330e8b1fdd5b --- /dev/null +++ b/vllm/reasoning/olmo3_reasoning_parser.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses as dt +import enum +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import regex as re + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, +) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +class Olmo3ReasoningState(enum.Enum): + REASONING = 1 + CONTENT = 2 + + +@dt.dataclass(frozen=True) +class Indices: + start: int + end: int + + def __len__(self): + return self.end - self.start + + +def string_overlap(a: str, b: str) -> tuple[Optional[Indices], Optional[Indices]]: + """ + Find the longest overlap where the end of string a matches the start + of string b. + + Args: + a: First string + b: Second string + + Returns: + Tuple of IndicesTuples representing the overlapping portions in each + string, or a tuple of None if no overlap exists + """ + + # swap so a is always the shorter string + a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True) + + # first check: is a fully contained in b? + if a in b: + ind_a = Indices(0, len(a)) + ind_b = Indices(b.index(a), b.index(a) + len(a)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # second check: does the end of a overlap with the + # beginning of b? + for i in range(len(a) - 1, 0, -1): + if a[-i:] == b[:i]: + ind_a = Indices(len(a) - i, len(a)) + ind_b = Indices(0, i) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # third check: does the beginning of a overlap with + # the end of b? + for i in range(len(a) - 1, 0, -1): + if b[-i:] == a[:i]: + ind_a = Indices(0, i) + ind_b = Indices(len(b) - i, len(b)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + return None, None + + +@dt.dataclass +class Olmo3ReasoningBuffer: + think_start: str = "<think>" + think_end: str = "</think>" + buffer: str = "" + + # we start in reasoning state to support cases where we hardcode + # <think> as the start of the reasoning block. 
+ # In those cases, the only token we will see is </think>, which + # is when we switch to content state. + state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING + + def process_buffer(self) -> Optional[DeltaMessage]: + start_think_idx = self.buffer.find(self.think_start) + + if start_think_idx >= 0: + self.state = Olmo3ReasoningState.REASONING + pretext, self.buffer = ( + self.buffer[:start_think_idx], + self.buffer[start_think_idx + len(self.think_start) :], + ) + if start_think_idx > 0: + # this covers the case there's content before + # the start of the reasoning block + return DeltaMessage(content=pretext) + + end_think_idx = self.buffer.rfind(self.think_end) + + if end_think_idx >= 0: + self.state = Olmo3ReasoningState.CONTENT + pretext, self.buffer = ( + self.buffer[:end_think_idx], + self.buffer[end_think_idx + len(self.think_end) :], + ) + if end_think_idx > 0: + # this covers the case there's content before + # the end of the reasoning block + return DeltaMessage(reasoning_content=pretext) + + if self.state == Olmo3ReasoningState.REASONING: + # we are inside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(reasoning_content=text_buffer) + + if self.state == Olmo3ReasoningState.CONTENT: + # we are outside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(content=text_buffer) + + # nothing to return unless we are in reasoning or content state + return None + + def __len__(self): + # is the length of the text buffer + return len(self.buffer) + + def add_text(self, delta_text: str) -> Optional[DeltaMessage]: + # we start by adding the delta text to the buffer + self.buffer += delta_text + + # setting this to empty before starting + delta_message: Optional[DeltaMessage] = None + + # we start by computing the overlap between the delta_text + # and start/end of think tokens. + _, overlap_think_start = string_overlap(delta_text, self.think_start) + _, overlap_think_end = string_overlap(delta_text, self.think_end) + + partial_overlap_start = overlap_think_start is not None and len( + overlap_think_start + ) < len(self.think_start) + partial_overlap_end = overlap_think_end is not None and len( + overlap_think_end + ) < len(self.think_end) + + if ( + partial_overlap_start + and self.think_start in self.buffer + and not partial_overlap_end + ): + # we can only process the buffer if partial overlap + # is the last part of think token (thus causing + # text_buffer to contain the start of think token) + # and there are no partial overlaps with end think + delta_message = self.process_buffer() + + elif partial_overlap_end and self.think_end in self.buffer: + # same as before (partial overlap only allowed) + # if the buffer contains the end think token, + # but we don't have to check for partial overlap + # with start think token because they are handled + # by the previous condition + delta_message = self.process_buffer() + + elif partial_overlap_start or partial_overlap_end: + # in general, if there are overlaps, we don't + # process the buffer because we want to wait until + # the think token is fully completed. 
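+            # (For example, if the delta ends in "</th", the closing
+            # "</think>" may still be arriving token by token, so we
+            # return None here and let the next delta either complete
+            # or break the potential match.)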
+            return None
+        else:
+            # we process the buffer as normal
+            delta_message = self.process_buffer()
+
+        return delta_message
+
+
+@ReasoningParserManager.register_module("olmo3")
+class Olmo3ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for Olmo 3 model
+
+    Olmo3ReasoningParser
+
+    This class implements a reasoning parser specifically designed for the
+    Olmo 3 family of models. Olmo 3 models do not use special tokens to
+    indicate reasoning; rather, the reasoning trace is wrapped in `<think>`
+    and `</think>`, which are tokenized using standard vocabulary entries.
+    Because of this, the parser operates in string space, accumulating the
+    characters in a buffer until it sees `<think>` or `</think>` tokens
+    to switch modes.
+
+    Key Features:
+    - For non-stream output, it recognizes and extracts reasoning (text
+      bracketed by `<think>` and `</think>`) and content (everything
+      after the first `</think>`).
+    - For stream processing, it uses a buffer to accumulate delta text,
+      and outputs progressive delta messages as soon as thinking starts
+      or ends.
+    - For reliability, some Olmo 3 models may hardcode the first
+      `<think>` token in the input text (similar to Deepseek R1,
+      or reasoning-only Qwen models). To support such variants, the
+      parser can optionally work in cases where the first `<think>`
+      token is missing from generation.
+    """
+
+    def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+
+        self.think_start = r"<think>"
+        self.think_end = r"</think>"
+
+        # notice that the first think is optional; this allows the template
+        # to work in cases where we hardcode a <think> at the beginning of
+        # the reasoning template.
+        reasoning_expr = (
+            rf"^(?:{self.think_start})?(?P<reasoning>.*?)"
+            + rf"{self.think_end}(?P<content>.*)$"
+        )
+        self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)
+
+        self.buffer = Olmo3ReasoningBuffer(
+            think_start=self.think_start, think_end=self.think_end
+        )
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        text = self.model_tokenizer.decode(input_ids)
+        return self.think_end in text
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        # for Olmo 3 streaming reasoning parsing, the stream parser runs
+        # first, and the same token is then passed to is_reasoning_end
+        # and extract_content_ids; this id is not part of content, so
+        # just return [] here.
+        return []
+
+    def extract_reasoning_content(
+        self,
+        model_output: str,
+        request: Union[ChatCompletionRequest, ResponsesRequest],
+    ) -> tuple[Optional[str], Optional[str]]:
+        """Extract the reasoning content & content sections, respectively.
+        If the sequence doesn't match what we expect, i.e., the model generates
+        something else, all content is considered non-reasoning content.
+
+        Args:
+            model_output (str): Output of the model to be parsed.
+            request (ChatCompletionRequest | ResponsesRequest): Request being
+                processed.
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: Tuple pair containing the
+                reasoning content and non-reasoning content.
+ """ + + re_match = self.reasoning_regex.match(model_output) + if re_match: + reasoning_content = re_match.group("reasoning") or None + content = re_match.group("content") or None + return reasoning_content, content + + # no reasoning content + return None, model_output + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract content using token ID sequence state machine""" + + delta_message = self.buffer.add_text(delta_text) + if delta_message is None and self.buffer.think_end in self.buffer.buffer: + # this is a bit hacky, but, because of how the buffer is + # constructed, if the last delta_text contains characters that + # marks the end of thinking tokens, then messages in the buffer + # would never be processed because we get no other turn. To get + # around that, we check if the text buffer contains the end of + # thinking tokens, and, if so, we reprocess the buffer again. + delta_message = self.buffer.process_buffer() + + return delta_message diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py index 61bafc724c17f..160e8633a43fd 100644 --- a/vllm/reasoning/qwen3_reasoning_parser.py +++ b/vllm/reasoning/qwen3_reasoning_parser.py @@ -1,21 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence from typing import Optional, Union -from transformers import PreTrainedTokenizerBase - -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) -from vllm.logger import init_logger -from vllm.reasoning import ReasoningParser, ReasoningParserManager - -logger = init_logger(__name__) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser @ReasoningParserManager.register_module("qwen3") -class Qwen3ReasoningParser(ReasoningParser): +class Qwen3ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for the Qwen3 model. @@ -26,101 +20,25 @@ class Qwen3ReasoningParser(ReasoningParser): output. 
""" - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) - self.think_start_token = "<think>" - self.think_end_token = "</think>" + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "<think>" - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") - - self.think_start_token_id = self.vocab.get(self.think_start_token) - self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None): - raise RuntimeError( - "Qwen3 reasoning parser could not locate think start/end " - "tokens in the tokenizer!") - - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids - - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract the content after the end tokens - """ - if self.think_end_token_id not in input_ids[:-1]: - return [] - else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] - - def extract_reasoning_content_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: - """ - Extract reasoning content from a delta message. - Handles streaming output where previous + delta = current. - Uses token IDs for faster processing. - For text <think>abc</think>xyz: - - 'abc' goes to reasoning_content - - 'xyz' goes to content - """ - # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id - ]): - return None - - if self.think_start_token_id in previous_token_ids: - if self.think_end_token_id in delta_token_ids: - # <think> in previous, </think> in delta, - # extract reasoning content - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - elif self.think_end_token_id in previous_token_ids: - # <think> in previous, </think> in previous, - # reasoning content continues - return DeltaMessage(content=delta_text) - else: - # <think> in previous, no </think> in previous or delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - elif self.think_start_token_id in delta_token_ids: - if self.think_end_token_id in delta_token_ids: - # <think> in delta, </think> in delta, extract reasoning content - start_index = delta_text.find(self.think_start_token) - end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) - else: - # <think> in delta, no </think> in delta, - # reasoning content continues - return DeltaMessage(reasoning_content=delta_text) - else: - # thinking is disabled, just content - return DeltaMessage(content=delta_text) + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "</think>" def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, 
request: Union[ChatCompletionRequest, ResponsesRequest] ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. + Qwen3 has stricter requirements - it needs both start and end tokens + to be present, unlike other models that work with just the end token. + For text <think>abc</think>xyz: - 'abc' goes to reasoning_content - 'xyz' goes to content @@ -129,23 +47,24 @@ class Qwen3ReasoningParser(ReasoningParser): tuple[Optional[str], Optional[str]]: reasoning content and content """ - # Check if the model output contains the <think> and </think> tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + # Check if the model output contains both <think> and </think> tokens. + if self.start_token not in model_output or self.end_token not in model_output: return None, model_output + # Check if the <think> is present in the model output, remove it # if it is present. - model_output_parts = model_output.partition(self.think_start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] + model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + # Check if the model output contains the </think> tokens. # If the end token is not found, return the model output as is. - if self.think_end_token not in model_output: + if self.end_token not in model_output: return None, model_output # Extract reasoning content from the model output. - reasoning_content, _, content = model_output.partition( - self.think_end_token) + reasoning_content, _, content = model_output.partition(self.end_token) final_content = content or None return reasoning_content, final_content diff --git a/vllm/reasoning/seedoss_reasoning_parser.py b/vllm/reasoning/seedoss_reasoning_parser.py new file mode 100644 index 0000000000000..72f8dc54f1b37 --- /dev/null +++ b/vllm/reasoning/seedoss_reasoning_parser.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + + +@ReasoningParserManager.register_module("seed_oss") +class SeedOSSReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for SeedOSS model. + + The SeedOSS model uses <seed:think>...</seed:think> tokens to + denote reasoning content text. This parser extracts + the reasoning content from the model output. + Similar to DeepSeek R1, it supports cases + where the model doesn't generate the start token. 
+ """ + + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "<seed:think>" + + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "</seed:think>" diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py index f642ea977c580..c9f580077b338 100644 --- a/vllm/reasoning/step3_reasoning_parser.py +++ b/vllm/reasoning/step3_reasoning_parser.py @@ -7,8 +7,7 @@ from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -20,27 +19,28 @@ class Step3ReasoningParser(ReasoningParser): """ Reasoning parser for Step3 model. - The Step3 model uses </think> token to denote the end of reasoning + The Step3 model uses </think> token to denote the end of reasoning text. This parser extracts all content before </think> as reasoning content. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase): - super().__init__(tokenizer) + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) self.think_end_token = "</think>" - self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", - re.DOTALL) + self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", re.DOTALL) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) self.think_end_token_id = self.vocab.get(self.think_end_token) if self.think_end_token_id is None: raise RuntimeError( "Step3 reasoning parser could not locate think end " - "token in the tokenizer!") + "token in the tokenizer!" 
+ ) def extract_reasoning_content_streaming( self, @@ -60,17 +60,18 @@ class Step3ReasoningParser(ReasoningParser): - 'xyz' goes to content """ # Skip single special token - if len(delta_token_ids - ) == 1 and delta_token_ids[0] == self.think_end_token_id: + if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id: return None if self.think_end_token_id in delta_token_ids: # </think> in delta, extract reasoning content and remaining content end_index = delta_text.find(self.think_end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) elif self.think_end_token_id in previous_token_ids: # </think> already seen in previous text, everything is content return DeltaMessage(content=delta_text) @@ -79,9 +80,8 @@ class Step3ReasoningParser(ReasoningParser): return DeltaMessage(reasoning_content=delta_text) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: - # Check if the model output contains the </think> token if self.think_end_token not in model_output: # If no </think> token, everything is reasoning content @@ -92,7 +92,7 @@ class Step3ReasoningParser(ReasoningParser): reasoning_content = model_output[:end_index] # Content after </think> token - content = model_output[end_index + len(self.think_end_token):] + content = model_output[end_index + len(self.think_end_token) :] if len(content) == 0: content = None @@ -106,4 +106,4 @@ class Step3ReasoningParser(ReasoningParser): if self.think_end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.think_end_token_id) + 1 :] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index df4cca9ba1147..a1ff4e5ff63b2 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampling parameters for text generation.""" + import copy -from dataclasses import dataclass +import warnings +from dataclasses import field from enum import Enum, IntEnum from functools import cached_property from typing import Annotated, Any, Optional, Union import msgspec -from pydantic import BaseModel +from pydantic.dataclasses import dataclass from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor @@ -28,60 +30,54 @@ class SamplingType(IntEnum): # maybe make msgspec? @dataclass -class GuidedDecodingParams: - """One of these fields will be used to build a logit processor.""" +class StructuredOutputsParams: + # One of these fields will be used to build a logit processor. json: Optional[Union[str, dict]] = None regex: Optional[str] = None choice: Optional[list[str]] = None grammar: Optional[str] = None json_object: Optional[bool] = None - """These are other options that can be set""" - backend: Optional[str] = None - backend_was_auto: bool = False + # These are other options that can be set. 
disable_fallback: bool = False disable_any_whitespace: bool = False disable_additional_properties: bool = False whitespace_pattern: Optional[str] = None structural_tag: Optional[str] = None - @staticmethod - def from_optional( - json: Optional[Union[dict, BaseModel, str]] = None, - regex: Optional[str] = None, - choice: Optional[list[str]] = None, - grammar: Optional[str] = None, - json_object: Optional[bool] = None, - backend: Optional[str] = None, - whitespace_pattern: Optional[str] = None, - structural_tag: Optional[str] = None, - ) -> Optional["GuidedDecodingParams"]: - if all(arg is None for arg in (json, regex, choice, grammar, - json_object, structural_tag)): - return None - # Extract json schemas from pydantic models - if isinstance(json, (BaseModel, type(BaseModel))): - json = json.model_json_schema() - return GuidedDecodingParams( - json=json, - regex=regex, - choice=choice, - grammar=grammar, - json_object=json_object, - backend=backend, - whitespace_pattern=whitespace_pattern, - structural_tag=structural_tag, - ) + _backend: Optional[str] = field(default=None, init=False) + """CAUTION: Should only be set by Processor._validate_structured_output""" + _backend_was_auto: bool = field(default=False, init=False) + """CAUTION: Should only be set by Processor._validate_structured_output""" def __post_init__(self): """Validate that some fields are mutually exclusive.""" - guide_count = sum([ - self.json is not None, self.regex is not None, self.choice - is not None, self.grammar is not None, self.json_object is not None - ]) - if guide_count > 1: + count = sum( + [ + self.json is not None, + self.regex is not None, + self.choice is not None, + self.grammar is not None, + self.json_object is not None, + ] + ) + if count > 1: raise ValueError( - "You can only use one kind of guided decoding but multiple are " - f"specified: {self.__dict__}") + "You can only use one kind of structured outputs constraint " + f"but multiple are specified: {self.__dict__}" + ) + + +@dataclass +class GuidedDecodingParams(StructuredOutputsParams): + def __post_init__(self): + warnings.warn( + "GuidedDecodingParams is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. Please use " + "StructuredOutputsParams instead.", + DeprecationWarning, + stacklevel=2, + ) + return super().__post_init__() class RequestOutputKind(Enum): @@ -94,10 +90,11 @@ class RequestOutputKind(Enum): class SamplingParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): # type: ignore[call-arg] """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -106,7 +103,13 @@ class SamplingParams( """ n: int = 1 - """Number of output sequences to return for the given prompt.""" + """Number of outputs to return for the given prompt request. + + NOTE: + `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs + are generated and streamed cumulatively per request. To see all `n` + outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY` + in `SamplingParams`.""" best_of: Optional[int] = None """Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. 
`best_of` @@ -165,7 +168,8 @@ class SamplingParams( the sampled token, so there may be up to `logprobs+1` elements in the response. When set to -1, return all `vocab_size` log probabilities.""" prompt_logprobs: Optional[int] = None - """Number of log probabilities to return per prompt token.""" + """Number of log probabilities to return per prompt token. + When set to -1, return all `vocab_size` log probabilities.""" # NOTE: This parameter is only exposed at the engine level for now. # It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs. @@ -182,7 +186,7 @@ class SamplingParams( optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False """Whether to include the stop strings in output text.""" - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None """If set to -1, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled.""" @@ -194,9 +198,10 @@ class SamplingParams( _all_stop_token_ids: set[int] = msgspec.field(default_factory=set) # Fields used to construct logits processors + structured_outputs: Optional[StructuredOutputsParams] = None + """Parameters for configuring structured outputs.""" guided_decoding: Optional[GuidedDecodingParams] = None - """If provided, the engine will construct a guided decoding logits - processor from these parameters.""" + """Deprecated alias for structured_outputs.""" logit_bias: Optional[dict[int, float]] = None """If provided, the engine will construct a logits processor that applies these logit biases.""" @@ -240,9 +245,9 @@ class SamplingParams( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[list[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, + structured_outputs: Optional[StructuredOutputsParams] = None, guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, @@ -255,16 +260,25 @@ class SamplingParams( int(token): min(100.0, max(-100.0, bias)) for token, bias in logit_bias.items() } + if guided_decoding is not None: + warnings.warn( + "guided_decoding is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. 
Please use " + "structured_outputs instead.", + DeprecationWarning, + stacklevel=2, + ) + structured_outputs = guided_decoding + guided_decoding = None return SamplingParams( n=1 if n is None else n, best_of=best_of, - presence_penalty=0.0 - if presence_penalty is None else presence_penalty, - frequency_penalty=0.0 - if frequency_penalty is None else frequency_penalty, + presence_penalty=0.0 if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty, repetition_penalty=1.0 - if repetition_penalty is None else repetition_penalty, + if repetition_penalty is None + else repetition_penalty, temperature=1.0 if temperature is None else temperature, top_p=1.0 if top_p is None else top_p, top_k=top_k, @@ -285,7 +299,7 @@ class SamplingParams( logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, output_kind=output_kind, - guided_decoding=guided_decoding, + structured_outputs=structured_outputs, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, extra_args=extra_args, @@ -302,7 +316,8 @@ class SamplingParams( if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not self._real_n: self._real_n = self.n self.n = self.best_of @@ -311,7 +326,10 @@ class SamplingParams( logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. We have maxed it out to %s.", - self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature, + _MAX_TEMP, + _MAX_TEMP, + ) self.temperature = max(self.temperature, _MAX_TEMP) if self.seed == -1: @@ -351,93 +369,122 @@ class SamplingParams( # eos_token_id is added to this by the engine self._all_stop_token_ids.update(self.stop_token_ids) + if self.guided_decoding is not None: + warnings.warn( + "guided_decoding is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. Please use " + "structured_outputs instead.", + DeprecationWarning, + stacklevel=2, + ) + self.structured_outputs = self.guided_decoding + self.guided_decoding = None + def _verify_args(self) -> None: if not isinstance(self.n, int): - raise ValueError(f"n must be an int, but is of " - f"type {type(self.n)}") + raise ValueError(f"n must be an int, but is of type {type(self.n)}") if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if self.best_of is not None: if not isinstance(self.best_of, int): raise ValueError( - f"best_of must be an integer, got {type(self.best_of)}") + f"best_of must be an integer, got {type(self.best_of)}" + ) if self.best_of < 1: - raise ValueError( - f"best_of must be at least 1, got {self.best_of}") + raise ValueError(f"best_of must be at least 1, got {self.best_of}") if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError("presence_penalty must be in [-2, 2], got " - f"{self.presence_penalty}.") + raise ValueError( + f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." + ) if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError("frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}.") + raise ValueError( + f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}." 
+ ) if self.repetition_penalty <= 0.0: raise ValueError( "repetition_penalty must be greater than zero, got " - f"{self.repetition_penalty}.") + f"{self.repetition_penalty}." + ) if self.temperature < 0.0: raise ValueError( - f"temperature must be non-negative, got {self.temperature}.") + f"temperature must be non-negative, got {self.temperature}." + ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") # quietly accept -1 as disabled, but prefer 0 if self.top_k < -1: - raise ValueError(f"top_k must be 0 (disable), or at least 1, " - f"got {self.top_k}.") + raise ValueError( + f"top_k must be 0 (disable), or at least 1, got {self.top_k}." + ) if not isinstance(self.top_k, int): raise TypeError( - f"top_k must be an integer, got {type(self.top_k).__name__}") + f"top_k must be an integer, got {type(self.top_k).__name__}" + ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError("min_p must be in [0, 1], got " - f"{self.min_p}.") + raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.max_tokens is not None and self.max_tokens < 1: - raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") if self.min_tokens < 0: - raise ValueError(f"min_tokens must be greater than or equal to 0, " - f"got {self.min_tokens}.") + raise ValueError( + f"min_tokens must be greater than or equal to 0, got {self.min_tokens}." + ) if self.max_tokens is not None and self.min_tokens > self.max_tokens: raise ValueError( f"min_tokens must be less than or equal to " - f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if (self.logprobs is not None and self.logprobs != -1 - and self.logprobs < 0): + f"max_tokens={self.max_tokens}, got {self.min_tokens}." + ) + if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0: raise ValueError( - f"logprobs must be non-negative or -1, got {self.logprobs}.") - if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") - if (self.truncate_prompt_tokens is not None - and self.truncate_prompt_tokens < 1): - raise ValueError(f"truncate_prompt_tokens must be >= 1, " - f"got {self.truncate_prompt_tokens}") + f"logprobs must be non-negative or -1, got {self.logprobs}." + ) + if ( + self.prompt_logprobs is not None + and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0 + ): + raise ValueError( + f"prompt_logprobs must be non-negative or -1, got " + f"{self.prompt_logprobs}." + ) + if self.truncate_prompt_tokens is not None and ( + self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1 + ): + raise ValueError( + f"truncate_prompt_tokens must be an integer >= 1 or -1, " + f"got {self.truncate_prompt_tokens}" + ) assert isinstance(self.stop_token_ids, list) if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): - raise ValueError(f"stop_token_ids must contain only integers, " - f"got {self.stop_token_ids}.") + raise ValueError( + f"stop_token_ids must contain only integers, got {self.stop_token_ids}." + ) assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " - "Set detokenize=True to use stop.") + "Set detokenize=True to use stop." 
+ ) if self.best_of != self._real_n and self.output_kind == ( - RequestOutputKind.DELTA): + RequestOutputKind.DELTA + ): raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_greedy_sampling(self) -> None: if self.n > 1: - raise ValueError("n must be 1 when using greedy sampling, " - f"got {self.n}.") + raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.") def update_from_generation_config( - self, - generation_config: dict[str, Any], - model_eos_token_id: Optional[int] = None) -> None: + self, + generation_config: dict[str, Any], + model_eos_token_id: Optional[int] = None, + ) -> None: """Update if there are non-default values from generation_config""" if model_eos_token_id is not None: @@ -471,30 +518,33 @@ class SamplingParams( for add_prefix_space in [False, True]: prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode( + text=prompt, add_special_tokens=False + ) # If no space at the beginning # or if prefix space produces a new word token if (not add_prefix_space) or ( - add_prefix_space and prompt_token_ids[0] - != self._bad_words_token_ids[-1][0] - and len(prompt_token_ids) == len( - self._bad_words_token_ids[-1])): + add_prefix_space + and prompt_token_ids[0] != self._bad_words_token_ids[-1][0] + and len(prompt_token_ids) == len(self._bad_words_token_ids[-1]) + ): self._bad_words_token_ids.append(prompt_token_ids) invalid_token_ids = [ - token_id for bad_words_token_ids in self._bad_words_token_ids + token_id + for bad_words_token_ids in self._bad_words_token_ids for token_id in bad_words_token_ids if token_id < 0 or token_id > tokenizer.max_token_id ] if len(invalid_token_ids) > 0: raise ValueError( - f"The model vocabulary size is {tokenizer.max_token_id+1}," + f"The model vocabulary size is {tokenizer.max_token_id + 1}," f" but the following tokens" f" were specified as bad: {invalid_token_ids}." f" All token id values should be integers satisfying:" - f" 0 <= token_id <= {tokenizer.max_token_id}.") + f" 0 <= token_id <= {tokenizer.max_token_id}." + ) @cached_property def sampling_type(self) -> SamplingType: @@ -522,10 +572,14 @@ class SamplingParams( See https://github.com/vllm-project/vllm/issues/3087 """ - logit_processor_refs = None if self.logits_processors is None else { - id(lp): lp.clone() if hasattr(lp, 'clone') else lp - for lp in self.logits_processors - } + logit_processor_refs = ( + None + if self.logits_processors is None + else { + id(lp): lp.clone() if hasattr(lp, "clone") else lp + for lp in self.logits_processors + } + ) return copy.deepcopy(self, memo=logit_processor_refs) def __repr__(self) -> str: @@ -552,16 +606,19 @@ class SamplingParams( "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " - f"guided_decoding={self.guided_decoding}, " - f"extra_args={self.extra_args})") + f"structured_outputs={self.structured_outputs}, " + f"extra_args={self.extra_args})" + ) class BeamSearchParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True, +): # type: ignore[call-arg] """Beam search parameters for text generation.""" + beam_width: int max_tokens: int ignore_eos: bool = False diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 6f11ab8e0300a..fd25d198bf1ab 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -70,20 +70,19 @@ class ScalarType: """ def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" + assert self.mantissa <= 52 and self.exponent <= 11, ( + f"Cannot represent max/min as a double for type {self.__str__()}" + ) max_mantissa = (1 << self.mantissa) - 1 if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: max_mantissa = max_mantissa - 1 max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE: + assert self.exponent < 11, ( + f"Cannot represent max/min as a double for type {self.__str__()}" + ) max_exponent = max_exponent + 1 # adjust the exponent to match that of a double @@ -96,38 +95,39 @@ class ScalarType: exponent_bias = (1 << (self.exponent - 1)) - 1 exponent_bias_double = (1 << 10) - 1 # double e = 11 - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) + max_exponent_double = max_exponent - exponent_bias + exponent_bias_double # shift the mantissa and exponent into the proper positions for an # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) + return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52) def _floating_point_max(self) -> float: double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + return struct.unpack("!d", struct.pack("!Q", double_raw))[0] def _raw_max(self) -> Union[int, float]: if self.is_floating_point(): return self._floating_point_max() else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" + assert self.size_bits < 64 or self.size_bits == 64 and self.is_signed(), ( + "Cannot represent max as an int" + ) return (1 << self.mantissa) - 1 def _raw_min(self) -> Union[int, float]: if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" + assert self.is_signed(), ( + "We currently assume all floating point types are signed" + ) sign_bit_double = 1 << 63 max_raw = self._floating_point_max_int() min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + return struct.unpack("!d", struct.pack("!Q", min_raw))[0] else: - assert (not self.is_signed() or self.size_bits - <= 64), "Cannot represent min as a int64_t" + assert not self.is_signed() or self.size_bits <= 64, ( + "Cannot represent min as a int64_t" + ) if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -158,8 +158,7 @@ class ScalarType: or_and_advance(self._finite_values_only, 1) or_and_advance(self.nan_repr.value, 8) - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" + assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64" _SCALAR_TYPES_ID_MAP[val] = self @@ -215,8 +214,7 @@ class ScalarType: If the type is a floating point type that follows IEEE 754 
conventions """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only + return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only def __str__(self) -> str: """ @@ -232,8 +230,14 @@ class ScalarType: - if bias is not present it means its zero """ if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) + ret = ( + "float" + + str(self.size_bits) + + "_e" + + str(self.exponent) + + "m" + + str(self.mantissa) + ) if not self.is_ieee_754(): if self._finite_values_only: @@ -261,41 +265,43 @@ class ScalarType: # @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": "Create a signed integer scalar type (size_bits includes sign-bit)." ret = cls(0, size_bits - 1, True, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" + def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": + """Create an unsigned integer scalar type.""" ret = cls(0, size_bits, False, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType": """ Create a standard floating point type (i.e. follows IEEE 754 conventions). """ - assert (mantissa > 0 and exponent > 0) + assert mantissa > 0 and exponent > 0 ret = cls(exponent, mantissa, True, 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': + def float_( + cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr + ) -> "ScalarType": """ Create a non-standard floating point type (i.e. does not follow IEEE 754 conventions). 
""" - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( + assert mantissa > 0 and exponent > 0 + assert nan_repr != NanRepr.IEEE_754, ( "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") + "follow IEEE 754 conventions" + ) ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) ret.id # noqa B018: make sure the id is cached return ret @@ -303,8 +309,7 @@ class ScalarType: @classmethod def from_id(cls, scalar_type_id: int): if scalar_type_id not in _SCALAR_TYPES_ID_MAP: - raise ValueError( - f"scalar_type_id {scalar_type_id} doesn't exists.") + raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.") return _SCALAR_TYPES_ID_MAP[scalar_type_id] @@ -327,14 +332,16 @@ class scalar_types: uint8 = ScalarType.uint(8, None) float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float8_e8m0fnu = ScalarType(8, 0, False, 0, True, - NanRepr.EXTD_RANGE_MAX_MIN) + float8_e8m0fnu = ScalarType(8, 0, False, 0, True, NanRepr.EXTD_RANGE_MAX_MIN) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10) # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + # and https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + float6_e2m3f = ScalarType.float_(2, 3, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) diff --git a/vllm/scripts.py b/vllm/scripts.py index 7a7fdccf0a32b..f158860726beb 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -10,6 +10,8 @@ logger = init_logger(__name__) # Backwards compatibility for the move from vllm.scripts to # vllm.entrypoints.cli.main def main(): - logger.warning("vllm.scripts.main() is deprecated. Please re-install " - "vllm or use vllm.entrypoints.cli.main.main() instead.") + logger.warning( + "vllm.scripts.main() is deprecated. Please re-install " + "vllm or use vllm.entrypoints.cli.main.main() instead." 
+ ) vllm_main() diff --git a/vllm/sequence.py b/vllm/sequence.py index 43d5c8beef270..7682b7f58305e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,103 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sequence and its related classes.""" -import copy -import enum -from abc import ABC, abstractmethod -from array import array -from collections import defaultdict -from collections.abc import Mapping -from collections.abc import Sequence as GenericSequence -from dataclasses import dataclass, field -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec import torch -from vllm.inputs import SingletonInputs -from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import RequestOutputKind, SamplingParams - if TYPE_CHECKING: - from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorOutput) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput +else: + KVConnectorOutput = Any VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 -def array_full(token_id: int, count: int): - """[`array`][] equivalent of [numpy.full][].""" - return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - - -# We use dataclass for now because it is used for -# openai server output, and msgspec is not serializable. -# TODO(sang): Fix it. -@dataclass -class Logprob: - """Infos for supporting OpenAI compatible logprobs and token ranks. - - Attributes: - logprob: The logprob of chosen token - rank: The vocab rank of chosen token (>=1) - decoded_token: The decoded chosen token index - """ - logprob: float - rank: Optional[int] = None - decoded_token: Optional[str] = None - - -# {token_id -> logprob} per each sequence group. None if the corresponding -# sequence group doesn't require prompt logprob. -PromptLogprobs = list[Optional[dict[int, Logprob]]] -# {token_id -> logprob} for each sequence group. -SampleLogprobs = list[dict[int, Logprob]] - - -class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered - # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: - return status > SequenceStatus.SWAPPED - - @staticmethod - def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: - if status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - elif status == SequenceStatus.FINISHED_ABORTED: - finish_reason = "abort" - elif status == SequenceStatus.FINISHED_IGNORED: - # The ignored sequences are the sequences whose prompt lengths - # are longer than the model's length cap. Therefore, the stop - # reason should also be "length" as in OpenAI API. - finish_reason = "length" - else: - finish_reason = None - return finish_reason - - -class SequenceStage(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - - @dataclass class RequestMetrics: """Metrics associated with a request. 
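Stepping back to the vllm/sampling_params.py changes above: with `guided_decoding` now a deprecated alias, a JSON-constrained request would be expressed through `structured_outputs` instead. A minimal sketch, where the schema is a made-up example:

from vllm.sampling_params import SamplingParams, StructuredOutputsParams

params = SamplingParams(
    max_tokens=128,
    structured_outputs=StructuredOutputsParams(
        json={"type": "object", "properties": {"answer": {"type": "string"}}},
    ),
)

Passing `guided_decoding=GuidedDecodingParams(...)` still works for now, but it emits a DeprecationWarning and the value is copied over to `structured_outputs` in `__post_init__`.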
@@ -116,6 +36,7 @@ class RequestMetrics: will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. """ + arrival_time: float last_token_time: float first_scheduled_time: Optional[float] @@ -127,1018 +48,18 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta SequenceData to send to workers per step.""" - # A new token to be appended to existing SequenceData. - new_output_token_ids: list[int] - # Overwriting existing `cumulative_logprob` - new_cumulative_logprob: float - # Overwriting existing `num_computed_tokens`. - new_num_computed_tokens: int - # Overwriting existing `stage`. - new_stage: SequenceStage - - -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence. - - Args: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. Set to an empty list if - None. - - Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. - cumulative_logprob: The cumulative log probability of the output. - """ - # NOTE: we cannot use Union[list, array] because msgspec cannot support - # union of 2 list types. - _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - - _prompt_embeds: Optional[torch.Tensor] = None - _output_embeds: Optional[torch.Tensor] = None - - ### The below fields should not be passed as an argument ### - _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: tuple[int, - ...] = msgspec.field(default_factory=tuple) - # The number of tokens that are computed (that run against the model). - _num_computed_tokens: int = 0 - # The number of tokens with prefix cache hit. - _num_cached_tokens: int = 0 - _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) - _cached_all_token_embeds: Optional[torch.Tensor] = None - - # It is used to get delta input. It is reset when `get_delta_and_reset` - # is called. - _new_appended_tokens: list[int] = msgspec.field(default_factory=list) - - # It is used to compute mrope_position_ids. - _mrope_position_delta: Optional[int] = None - - @staticmethod - def from_prompt_token_counts( - *token_counts: tuple[int, int]) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - by concatenating prompt token sequences. - - Each tuple represents one token sequence, expressed in the form - `(token_id, count)`. - """ - if len(token_counts) == 0: - return SequenceData.from_seqs([]) - - prompt_token_ids_arr = reduce( - array.__iadd__, - (array_full(token_id, count) for token_id, count in token_counts), - ) - - return SequenceData(prompt_token_ids_arr) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - *, - prompt_embeds: Optional[torch.Tensor] = None, - ) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - from prompt and output token sequences. 
- """ - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceData(prompt_token_ids_arr, - _prompt_embeds=prompt_embeds) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr, - _prompt_embeds=prompt_embeds) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: tuple[int, ...] = tuple( - self._prompt_token_ids) - self._update_cached_all_tokens() - if self._prompt_embeds is not None: - self._update_cached_all_token_embeds() - - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + - self._output_token_ids) - - def _update_cached_all_token_embeds(self): - assert isinstance(self._prompt_embeds, torch.Tensor) - self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds - if self._output_embeds is not None: - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, self._output_embeds), dim=0) - - @property - def cumulative_logprob(self) -> float: - return self._cumulative_logprob - - @property - def prompt_token_ids(self) -> tuple[int, ...]: - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - - @property - def output_token_ids(self) -> tuple[int, ...]: - return tuple(self._output_token_ids) - - @output_token_ids.setter - def output_token_ids(self, - new_output_token_ids: GenericSequence[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) - self._update_cached_all_tokens() - - @property - def output_embeds(self) -> Optional[torch.Tensor]: - return self._output_embeds - - @output_embeds.setter - def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: - self._output_token_embeds = new_output_token_embeds - self._update_cached_all_token_embeds() - - @property - def output_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. 
- """ - assert isinstance(self._output_token_ids, array) - return self._output_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: - self._prompt_embeds = prompt_embeds - self._update_cached_all_token_embeds() - - @property - def mrope_position_delta(self) -> Optional[int]: - return self._mrope_position_delta - - @mrope_position_delta.setter - def mrope_position_delta(self, new_mrope_position_delta): - self._mrope_position_delta = new_mrope_position_delta - - def append_token_id(self, - token_id: int, - logprob: float, - token_embed: Optional[torch.Tensor] = None) -> None: - self._output_token_ids.append(token_id) - self._new_appended_tokens.append(token_id) - self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob - if token_embed is not None: - # Do not pass in with batch or sequence dimensions - assert token_embed.ndim == 1 - token_embed = token_embed.detach().cpu().unsqueeze(0) - if self._output_embeds is None: - self._output_embeds = token_embed - else: - self._output_embeds = torch.cat( - (self._output_embeds, token_embed), dim=0) - assert self._cached_all_token_embeds is not None - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, - token_embed.to(device=self._cached_all_token_embeds.device)), - dim=0) - - def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_output_len(self) -> int: - return len(self._output_token_ids) - - def get_token_ids(self) -> list[int]: - return self._cached_all_token_ids - - def get_token_embeddings(self) -> Optional[torch.Tensor]: - return self._cached_all_token_embeds - - def get_prefix_token_ids( - self, num_tokens: int - ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - - def get_num_computed_tokens(self) -> int: - """Return the number of prefill tokens that are already computed.""" - return self._num_computed_tokens - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) - # If all tokens are computed, it means it is in decoding phase. - if self.get_num_uncomputed_tokens() == 0: - self._stage = SequenceStage.DECODE - - def get_num_cached_tokens(self) -> int: - """Return the number of tokens with prefix cache hit.""" - return self._num_cached_tokens - - def update_num_cached_tokens(self, num_cached_tokens: int): - """Update the number of tokens with prefix cache hit.""" - self._num_cached_tokens = num_cached_tokens - - def reset_state_for_recompute(self) -> None: - """Reset the number of computed tokens from this sequence. It is - supposed to be called when a sequence needs to be started from - the beginning again (e.g., sequence is preempted). 
- """ - self._num_computed_tokens = 0 - self._stage = SequenceStage.PREFILL - self._new_appended_tokens = [] - - def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefill tokens that are not computed.""" - # we use `get_len()` which includes prompt_len + output_len instead - # of prompt_len here. This is because during recompute we need to - # prefill for both prompt and output. - return self.get_len() - self.get_num_computed_tokens() - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.prompt_token_ids - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.output_token_ids - - def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) - # Reset delta state. - self._new_appended_tokens = [] - return delta - - def apply_delta(self, delta: SequenceDataDelta): - self._num_computed_tokens = delta.new_num_computed_tokens - self._cumulative_logprob = delta.new_cumulative_logprob - self._stage = delta.new_stage - self._output_token_ids.extend(delta.new_output_token_ids) - self._cached_all_token_ids.extend(delta.new_output_token_ids) - - @property - def stage(self) -> SequenceStage: - return self._stage - - def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds.shape=" - f"{getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") - - -class Sequence: - """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the - [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only) - or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - (for encoder-decoder) instance passed in through the `inputs` - constructor argument. - - Args: - seq_id: The ID of the sequence. - inputs: The inputs of the sequence. - block_size: The block size of the sequence. Should be the same as the - block size used by the block manager and cache engine. - eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. - lora_request: LoRA request. 
- """ - - def __init__( - self, - seq_id: int, - inputs: SingletonInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = inputs - self.block_size = block_size - self.eos_token_id = eos_token_id - self.lora_request = lora_request - - self.data = SequenceData.from_seqs( - self.prompt_token_ids, - prompt_embeds=self.inputs["prompt_embeds"] - if self.inputs["type"] == "embeds" else None) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" - - self.status = SequenceStatus.WAITING - self.stop_reason: Union[int, str, None] = None - - # These are used to keep track of delta outputs - self._last_output_token_ids_offset: int = 0 - self._last_output_text_offset: int = 0 - - # Used for incremental detokenization - self.prefix_offset = 0 - self.read_offset = 0 - # Input + output tokens - self.tokens: Optional[list[str]] = None - - @property - def n_blocks(self) -> int: - return (self.get_len() + self.block_size - 1) // self.block_size - - @property - def prompt(self) -> Optional[str]: - if self.inputs["type"] == "embeds": - return None - return self.inputs.get("prompt") - - @property - def prompt_token_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [0] * len(self.inputs["prompt_embeds"]) - return self.inputs["prompt_token_ids"] - - @property - def token_type_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [] - return self.inputs.get("token_type_ids", []) - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"].get_data() - - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_placeholders"] - - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: - """If delta is True, only new text since the last call to - this method is returned""" - - # We return the full output text if the sequence is finished. 
- truncate = buffer_length and not self.is_finished() - if not delta: - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) - length = len(self.output_text) - if truncate: - length -= buffer_length - last_offset = self._last_output_text_offset - if last_offset < length: - self._last_output_text_offset = length - return self.output_text[last_offset:length] - return "" - - def get_output_token_ids_to_return( - self, delta: bool) -> Union[GenericSequence[int], int]: - """If delta is True, only new tokens since the last call to - this method are returned""" - if not delta: - return self.get_output_token_ids() - - output_len = self.get_output_len() - - # Get the number of new tokens - num_new_tokens = output_len - self._last_output_token_ids_offset - self._last_output_token_ids_offset = output_len - - # Return new tokens - if num_new_tokens == 1: - # Optimization for single decode token case - # (which is what we have most of the time) - return self.data._cached_all_token_ids[-1] - - if num_new_tokens == 0: - return [] - - return self.data._cached_all_token_ids[-num_new_tokens:] - - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. - num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) - - def extra_hash(self) -> Optional[int]: - """ - This function computes an extra hash for a sequence, specifically - designed for prefix caching mode. The final sequence hash is determined - by applying token_ids from the sequence's blocks. - """ - if self.lora_int_id == 0: - return None - - # NOTE: If there are additional factors influencing the block aside from - # token_ids, include them as input parameters to the hash. - return hash(self.lora_int_id) - - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size - - def reset_state_for_recompute(self): - """Reset the sequence states for recomputation.""" - self.data.reset_state_for_recompute() - - def append_token_id(self, - token_id: int, - logprobs: dict[int, Logprob], - token_embed: Optional[torch.Tensor] = None) -> None: - assert token_id in logprobs - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob, - token_embed) - - def get_len(self) -> int: - return self.data.get_len() - - def get_prompt_len(self) -> int: - return self.data.get_prompt_len() - - def get_output_len(self) -> int: - return self.data.get_output_len() - - def get_token_ids(self) -> list[int]: - return self.data.get_token_ids() - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.data.get_prompt_token_ids() - - def get_last_token_id(self) -> int: - return self.data.get_last_token_id() - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.data.get_output_token_ids() - - def get_cumulative_logprob(self) -> float: - return self.data.cumulative_logprob - - def is_finished(self) -> bool: - return SequenceStatus.is_finished(self.status) - - def fork(self, new_seq_id: int) -> "Sequence": - new_seq = copy.deepcopy(self) - new_seq.seq_id = new_seq_id - return new_seq - - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. 
- - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 - return self.data.get_num_uncomputed_tokens() - - def get_num_computed_tokens(self) -> int: - return self.data.get_num_computed_tokens() - - def is_prefill(self) -> bool: - return self.data.stage == SequenceStage.PREFILL - - def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={self.n_blocks})") - - -class SequenceGroupState(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Mutable state tied to a specific sequence group""" - - # for multi-step decoding - num_steps: int = 1 - current_step: int = 0 - - @property - def remaining_steps(self) -> int: - return self.num_steps - self.current_step - - -class SequenceGroup: - """A group of sequences that are generated from the same prompt. - - Args: - request_id: The ID of the request. - seqs: The list of sequences. - sampling_params: The sampling parameters used to generate the outputs. - arrival_time: The arrival time of the request. - lora_request: LoRA request. - pooling_params: The parameters used to generate the pooler - for a pooling model. - pooled_data: The extracted hidden states from a pooling model. - encoder_seq: Optional, the single encoder sequence. Should be None - unless you are working with an encoder/decoder model. - trace_headers: OpenTelemetry trace headers. - priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target - model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than - that for multi-draft SD (currently not supported). - """ - - def __init__(self, - request_id: str, - seqs: list[Sequence], - arrival_time: float, - sampling_params: Optional[SamplingParams] = None, - lora_request: Optional[LoRARequest] = None, - pooling_params: Optional[PoolingParams] = None, - pooled_data: Optional[torch.Tensor] = None, - encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - draft_size: int = 1) -> None: - self.request_id = request_id - self.seqs = seqs - self.first_seq = seqs[0] - self.arrival_time = arrival_time - self.is_single_seq = len(seqs) == 1 - self.seqs_dict = {seq.seq_id: seq for seq in seqs} - - self.sampling_params = sampling_params - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) - self.last_token_latency = 0.0 - self.lora_request = lora_request - self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() - self.pooling_params = pooling_params - self.pooled_data = pooled_data - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority - - self.cached_request_output = None - - @property - def prompt(self) -> Optional[str]: - return self.first_seq.prompt - - @property - def prompt_token_ids(self) -> list[int]: - return self.first_seq.prompt_token_ids - - @property - def encoder_prompt(self) -> Optional[str]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt is distinct - # from the decoder's. 
- return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) - - @property - def encoder_prompt_token_ids(self) -> Optional[list[int]]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt token ids are - # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) - - @property - def token_type_ids(self) -> Optional[list[int]]: - return self.first_seq.token_type_ids - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_data - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_data - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_placeholders - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_placeholders - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def set_last_token_time(self, now: float) -> None: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, assertion fails. - assert not self.is_prefill(), ( - "seq_group.set_last_token_time() should not be called " - "if the seq_group is in prefill phase.") - self.last_token_latency = now - self.metrics.last_token_time - self.metrics.last_token_time = now - - def get_last_token_latency(self) -> float: - """Returns the latency of the last token.""" - assert not self.is_prefill(), ( - "seq_group.get_last_token_latency() should not be called " - "if the seq_group is in prefill phase.") - return self.last_token_latency - - def maybe_set_first_token_time(self, time: float) -> None: - """Sets the first token time for Request level timings.""" - # Note: in a case where a sequence_group is swapped and - # recomputed, the time between iterations is counted - # in TPOT, rather than recalculating TTFT (since from the ) - # POV of the user, there is simply a long generation delay. 
- if (self.metrics.first_token_time is None - and self.first_seq.get_output_len() == 1): - self.metrics.first_token_time = time - - def maybe_set_first_scheduled_time(self, time: float) -> None: - """Sets the first scheduled time and time in queue for Request - level timings.""" - if self.metrics.first_scheduled_time is None: - self.metrics.first_scheduled_time = time - self.metrics.time_in_queue = time - self.metrics.arrival_time - - def set_finished_time(self, time: Optional[float]) -> None: - """Sets the finished time for Request level timings.""" - self.metrics.finished_time = time - - def get_max_num_running_seqs(self) -> int: - """The maximum number of sequences running in parallel in the remaining - lifetime of the request.""" - if self.is_single_seq: - return 0 if self.first_seq.is_finished() else 1 - return self.num_seqs() - self.num_finished_seqs() - - def get_seqs( - self, - status: Optional[SequenceStatus] = None, - ) -> list[Sequence]: - if status is None: - return self.seqs - - if self.is_single_seq: - return self.seqs if self.first_seq.status == status else [] - - return [seq for seq in self.seqs if seq.status == status] - - def is_encoder_decoder(self) -> bool: - return self.encoder_seq is not None - - def get_encoder_seq(self) -> Optional[Sequence]: - return self.encoder_seq - - def get_finished_seqs(self) -> list[Sequence]: - if self.is_single_seq: - return self.seqs if self.first_seq.is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) - - def get_num_uncomputed_tokens(self) -> int: - num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() - return num_uncomputed_tokens - - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: - # Optimization. We don't need to call get_seqs if we don't need to - # filter by states. - if status is None: - return len(self.seqs) - - if self.is_single_seq: - return 1 if self.seqs[0].status == status else 0 - - return len(self.get_seqs(status)) - - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - return len(self.get_finished_seqs()) - - def is_finished(self) -> bool: - if self.is_single_seq: - return self.first_seq.is_finished() - return all(seq.is_finished() for seq in self.seqs) - - def is_prefill(self) -> bool: - return self.first_seq.is_prefill() - - def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") - - def uses_prompt_embeds(self) -> bool: - """Returns True if the sequence group uses input embeds.""" - return any(seq.data.prompt_embeds is not None for seq in self.seqs) - - -class SequenceGroupMetadataDelta( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta of SequenceGroupMetadata. - - After sending the first SequenceGroupMetadata, vLLM scheduler - only sends delta to reduce the data payload size. 
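# --- Illustrative sketch (not part of the diff): the removed delta protocol -
# The docstring above describes the idea: the V0 scheduler serialized the full
# SequenceGroupMetadata once, then shipped only small per-step deltas to the
# workers. A simplified stand-in (plain dataclasses instead of msgspec.Struct,
# keeping a subset of the fields shown in the removed SequenceDataDelta):
from dataclasses import dataclass, field

@dataclass
class ToyDelta:
    new_output_token_ids: list[int]
    new_cumulative_logprob: float
    new_num_computed_tokens: int

@dataclass
class ToySequenceData:
    output_token_ids: list[int] = field(default_factory=list)
    cumulative_logprob: float = 0.0
    num_computed_tokens: int = 0

    def apply_delta(self, delta: ToyDelta) -> None:
        # Mirrors the removed SequenceData.apply_delta: counters are
        # overwritten, newly generated token ids are appended.
        self.num_computed_tokens = delta.new_num_computed_tokens
        self.cumulative_logprob = delta.new_cumulative_logprob
        self.output_token_ids.extend(delta.new_output_token_ids)

worker_copy = ToySequenceData()
worker_copy.apply_delta(ToyDelta([42], -0.7, 17))
assert worker_copy.output_token_ids == [42]
# ---------------------------------------------------------------------------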
- """ - seq_data_delta: dict[int, SequenceDataDelta] - request_id: str - block_tables: dict[int, list[int]] - is_prompt: bool - do_sample: bool = True - token_chunk_size: Optional[int] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - - -class SequenceGroupMetadata( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Metadata for a sequence group. Used to create `AttentionMetadata`. - - Args: - request_id: The ID of the request. - is_prompt: Whether the request is at prompt stage. - seq_data: The sequence data. (Seq id -> sequence data) - sampling_params: The sampling parameters used to generate the outputs. - block_tables: The block tables. (Seq id -> list of physical block - numbers) - do_sample: True if sampling is required. Sampling is not required when - e.g., prefill is chunked, and the current iteration only computes - query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. - state: Internal state tied to this sequence group. - multi_modal_data: Multi modal data. - mm_processor_kwargs: Multimodal input processor / mapper overrides. - encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - cross_block_table: Optional cross-attention block table associated - with the encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - """ - - request_id: str - is_prompt: bool - seq_data: dict[int, SequenceData] - sampling_params: Optional[SamplingParams] - block_tables: dict[int, list[int]] - do_sample: bool = True - pooling_params: Optional[PoolingParams] = None - lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - token_type_ids: Optional[list[int]] = None - multi_modal_data: Optional[MultiModalKwargs] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[list[int]] = None - token_chunk_size: Optional[int] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - num_speculative_tokens: Optional[int] = None - - def __post_init__(self): - if self.seq_data is not None and self.token_chunk_size is None: - if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() - else: - self.token_chunk_size = 1 - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - # Multi-Step Chunked-Prefill property - @property - def is_single_step_prompt(self) -> bool: - # do_sample is true, only when the token_chunk_size matches the - # num_uncomputed_tokens of the sequence. 
This indicates that - # the prompt will finish processing in a single `execute_model` - # step. - return self.is_prompt and self.do_sample - - def get_first_seq_id(self) -> int: - # This is an efficient way of fetching the seq_id when - # we know this SequenceGroup has only one sequence. - return next(iter(self.seq_data)) - - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) - assert self.request_id == sequence_group_metadata_delta.request_id - self.block_tables = sequence_group_metadata_delta.block_tables - self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size - self.do_sample = sequence_group_metadata_delta.do_sample - self.is_prompt = sequence_group_metadata_delta.is_prompt - - def finish_step(self) -> None: - assert self.state is not None - assert self.state.current_step < self.state.num_steps, \ - f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa - self.state.current_step += 1 - - -class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a sequence. - - Args: - parent_seq_id: The ID of the parent sequence (for forking in beam - search). - output_token: The output token ID. - logprobs: The logprobs of the output token. - (Token id -> logP(x_i+1 | x_0, ..., x_i)) - """ - parent_seq_id: int - output_token: int - logprobs: dict[int, Logprob] - output_embed: Optional[torch.Tensor] = None - - def __repr__(self) -> str: - output_embed_shape = \ - self.output_embed.shape if self.output_embed is not None else None - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}, " - f"logprobs={self.logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): - raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) - log_probs_equal = other.logprobs == self.logprobs - return equal and log_probs_equal - - -class SequenceGroupOutput(ABC): - """The base class for model outputs associated with a sequence group.""" - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - -class CompletionSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a completion sequence group.""" - __metaclass__ = SequenceGroupOutput - samples: list[SequenceOutput] - # Prompt logprob for each prompt query token. 
- prompt_logprobs: Optional[PromptLogprobs] - step_index: Optional[int] = 0 - - def __repr__(self) -> str: - return (f"CompletionSequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CompletionSequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) - - -class PoolingSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] -): - """The model output associated with a pooling sequence group.""" - __metaclass__ = SequenceGroupOutput - # Annotated as Any to be compatible with msgspec - # The actual type is in SequenceGroup.pooled_data - data: Any - - def get_data_nbytes(self) -> int: - data: torch.Tensor = self.data - return data.nbytes - - def __repr__(self) -> str: - return f"PoolingSequenceGroupOutput(data={self.data}" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, PoolingSequenceGroupOutput): - raise NotImplementedError() - return self.data == other.data - - # cannot use msgspec.Struct here because Dynamo does not support it @dataclass class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. - + Each stage also needs to handle its own kv_connector_output. """ tensors: dict[str, torch.Tensor] - kv_connector_output: Optional["KVConnectorOutput"] + kv_connector_output: Optional[KVConnectorOutput] def __init__(self, tensors): # manually define this function, so that @@ -1167,337 +88,16 @@ class IntermediateTensors: return False if self.tensors.keys() != other.tensors.keys(): return False - return all( - torch.equal(self.tensors[k], other.tensors[k]) - for k in self.tensors) + return all(torch.equal(self.tensors[k], other.tensors[k]) for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" -class PoolerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The output from a pooling operation in the pooling model.""" - outputs: list[PoolingSequenceGroupOutput] - - def get_data_nbytes(self) -> int: - return sum(o.get_data_nbytes() for o in self.outputs) - - def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput: - return self.outputs[idx] - - def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - -def get_all_seq_ids( - seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - -def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: list[SequenceGroupMetadata] -) -> tuple[list[int], dict[str, set[int]]]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. 
- """ - seq_ids: list[int] = [] - request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set) - for sg in seq_group_metadata_list: - for seq_id in sg.seq_data: - seq_ids.append(seq_id) - request_id_seq_ids_mapping[sg.request_id].add(seq_id) - return seq_ids, request_id_seq_ids_mapping - - -class HiddenStates(msgspec.Struct, array_like=True, - omit_defaults=True): # type: ignore[call-arg] - """Hidden states corresponding to in-progress sequences. - Used in speculative decoding to pass hidden states from - the target model to the proposer model. - - seq_ids are the sequence ids of each entry of the batch - dimension of the hidden_states tensor""" - # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it use used for last accepted tokens. - hidden_states: torch.Tensor - # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None - # Scorer hidden states of the 2nd last token proposed by the proposer ( - # irrespective of whether it was accepted or not). Only used for cases when - # last proposed token is accepted (i.e., in case of bonus tokens). For the - # case of no bonus tokens, these are ignored. - second_last_token_hidden_states: Optional[torch.Tensor] = None - - _seq_ids: list[int] = msgspec.field(default_factory=list) - - def __post_init__(self): - if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - - @property - def seq_ids(self) -> list[int]: - return self._seq_ids - - def update(self, - hidden_states: torch.Tensor, - seq_group_metadata_list: list[SequenceGroupMetadata], - second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation. Only used for - decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) - self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - - if self.second_last_token_hidden_states is not None: - # Adding dummy hidden_states to this to maintain same shape - self.second_last_token_hidden_states = torch.cat([ - self.second_last_token_hidden_states, - torch.zeros_like(hidden_states) - if second_last_token_hidden_states is None else - second_last_token_hidden_states - ]) - - def prune(self, - seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids. Only used for decode steps. - """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence - # may be "paused" then "resumed" later. This should only prune sequences - # which are confirmed to be aborted. - seq_ids = get_all_seq_ids(seq_group_metadata_list) - # Only keep sequence IDs that exist in self._seq_ids - seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids] - if seq_ids != self._seq_ids: - # Batch contents changed - prune removed sequences. 
- index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] - self._seq_ids = seq_ids - - def expand_with_bonus_tokens( - self, seq_with_bonus_token_in_last_step: set) -> None: - """Expand hidden states for sequences with bonus tokens. This is in - alignment with `MultiStepWorker._expand_execute_model_request`.""" - if self.second_last_token_hidden_states is None \ - or not seq_with_bonus_token_in_last_step: - return - - index = [] - for seq_id in self._seq_ids: - i = self._seq_ids.index(seq_id) - if seq_id in seq_with_bonus_token_in_last_step: - index.append(i + len(self._seq_ids)) - index.append(i) - - self.hidden_states = torch.cat( - [self.hidden_states, self.second_last_token_hidden_states])[index] - - class ExecuteModelRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """The model execution request, containing CPU metadata only. The LLM - engine should create an instance of this class for each request batch.""" - # The sequence group metadata list. - seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to copy. Source to dest block. - blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) - # Virtual engine ID for pipeline parallel. - virtual_engine: int = 0 - # The number of slots for lookahead decoding. - num_lookahead_slots: int = 0 - # The number of requests in the running queue. - running_queue_size: int = 0 - # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None - # The number of forward steps to run. - num_steps: int = 1 - # Finished request ids since last step. - finished_requests_ids: list[str] = msgspec.field(default_factory=list) - # The last sampled token ids for multi step decoding. 
- last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None - - @property - def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 - - @property - def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state - assert state is not None - return state.current_step - - def clone( - self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - ) -> "ExecuteModelRequest": - """Clone the request with a new sequence group metadata list.""" - return ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=self.blocks_to_swap_in.copy(), - blocks_to_swap_out=self.blocks_to_swap_out.copy(), - blocks_to_copy=self.blocks_to_copy.copy(), - virtual_engine=self.virtual_engine, - num_lookahead_slots=self.num_lookahead_slots, - running_queue_size=self.running_queue_size, - previous_hidden_states=self.previous_hidden_states, - num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids, - last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) - - -@dataclass -class SequenceGroupBase: - group_id: str # the original request id before splitting - - assembled_seq_group: Optional[SequenceGroup] = None - - # seq id to a unique index inside this group - seq_id_to_index: dict[str, int] = field(default_factory=dict) - - # seq ids to be finished - to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) - - # seq id to finished sequences - finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) - - streaming: bool = False - - output_produced: bool = False - - @staticmethod - def add_request(request_id: str, engine, params, *args, **kwargs): - """When we are ready to add a request with request_id and params - into the engine, we can split the request into multiple requests. - """ - raise NotImplementedError - - def finish_seq(self, seq: SequenceGroup): - """The sequence `seq` finishes, we should record the information. - """ - del self.to_be_finished[seq.request_id] - self.finished_reqs[seq.request_id] = seq - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - """Assemble the sequence group, for producing the final - output, or adding request in the engine again. 
- """ - raise NotImplementedError - - -class ParallelSampleSequenceGroup(SequenceGroupBase): - - @staticmethod - def add_request(request_id: str, engine, params, **kwargs): - original_params = params - group = ParallelSampleSequenceGroup(request_id) - seqs = [] - for i in range(original_params.n): - request_id_i = f"{request_id}_parallel_sample_{i}" - group.seq_id_to_index[request_id_i] = i - params = original_params.clone() - params.n = 1 - if params.seed is not None: - params.seed += i - seq_group = engine._add_processed_request( - request_id_i, - params=params, - **kwargs, - ) # type: ignore - assert seq_group is not None - engine.seq_id_to_seq_group[request_id_i] = group - group.to_be_finished[request_id_i] = seq_group - seqs.append(seq_group.seqs[0]) - - # for parallel sampling, the `assembled_seq_group` is always - # available, since we have all the sequences ready, and they - # will not change. - group.assembled_seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - arrival_time=seq_group.arrival_time, - sampling_params=original_params, - lora_request=seq_group.lora_request, - pooling_params=seq_group.pooling_params, - pooled_data=seq_group.pooled_data, - encoder_seq=seq_group.encoder_seq, - trace_headers=seq_group.trace_headers, - priority=seq_group.priority, - ) - - group.streaming = params.output_kind == RequestOutputKind.DELTA - group.output_produced = False - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - - # in the streaming mode, we will return the assembled sequence - # for the first remaining sequence, and then return None for the - # rest of sequences - if self.streaming: - first_remaining_id = next(iter(self.to_be_finished)) - if seq_group.request_id == first_remaining_id: - return self.assembled_seq_group - return None - - # in the non-streaming mode, we will return the assembled sequence - # when the last sequences finishes, and then return None for the - # rest of the time - if (len(self.to_be_finished) == 1 - and seq_group.request_id in self.to_be_finished - and seq_group.is_finished()): - assert self.assembled_seq_group is not None - params = self.assembled_seq_group.sampling_params - assert isinstance(params, SamplingParams) - if not self.output_produced: - self.output_produced = True - if params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group - if self.output_produced: - return None - return None + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, +): # type: ignore[call-arg] + # Placeholder. Remove. 
+ pass diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 23679b8228d6f..91dcc2fd84e17 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -36,7 +36,6 @@ MODELS_ON_S3 = [ "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", # "meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Meta-Llama-3-8B", diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py index d215e5d8bf657..6aabbc217dd03 100644 --- a/vllm/third_party/pynvml.py +++ b/vllm/third_party/pynvml.py @@ -1022,7 +1022,7 @@ def _extractNVMLErrorsAsClasses(): Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate exceptions more easily. - NVMLError is a parent class. Each NVML_ERROR_* gets it's own subclass. + NVMLError is a parent class. Each NVML_ERROR_* gets its own subclass. e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized ''' this_module = sys.modules[__name__] @@ -3533,7 +3533,7 @@ def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): return [] elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): # typical case - # oversize the array incase more processes are created + # oversize the array in case more processes are created c_count.value = c_count.value * 2 + 5 proc_array = c_nvmlProcessInfo_v3_t * c_count.value c_procs = proc_array() diff --git a/vllm/tracing.py b/vllm/tracing.py index 6a287d82be5ff..c9b595999fc78 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -17,12 +17,15 @@ otel_import_error_traceback: Optional[str] = None try: from opentelemetry.context.context import Context from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, + ) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator) + TraceContextTextMapPropagator, + ) + _is_otel_imported = True except ImportError: # Capture and format traceback to provide detailed context for the import @@ -30,6 +33,7 @@ except ImportError: # memory leaks. # See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458 import traceback + otel_import_error_traceback = traceback.format_exc() class Context: # type: ignore @@ -49,13 +53,15 @@ def is_otel_available() -> bool: return _is_otel_imported -def init_tracer(instrumenting_module_name: str, - otlp_traces_endpoint: str) -> Optional[Tracer]: +def init_tracer( + instrumenting_module_name: str, otlp_traces_endpoint: str +) -> Optional[Tracer]: if not is_otel_available(): raise ValueError( "OpenTelemetry is not available. Unable to initialize " "a tracer. Ensure OpenTelemetry packages are installed. 
" - f"Original error:\n{otel_import_error_traceback}") + f"Original error:\n{otel_import_error_traceback}" + ) trace_provider = TracerProvider() span_exporter = get_span_exporter(otlp_traces_endpoint) @@ -70,19 +76,19 @@ def get_span_exporter(endpoint): protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") if protocol == "grpc": from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter, + ) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) # type: ignore + OTLPSpanExporter, # type: ignore + ) else: - raise ValueError( - f"Unsupported OTLP protocol '{protocol}' is configured") + raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured") return OTLPSpanExporter(endpoint=endpoint) -def extract_trace_context( - headers: Optional[Mapping[str, str]]) -> Optional[Context]: +def extract_trace_context(headers: Optional[Mapping[str, str]]) -> Optional[Context]: if is_otel_available(): headers = headers or {} return TraceContextTextMapPropagator().extract(headers) @@ -91,7 +97,6 @@ def extract_trace_context( def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: - return {h: headers[h] for h in TRACE_HEADERS if h in headers} @@ -113,12 +118,13 @@ class SpanAttributes: GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e" GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" # Time taken in the forward pass for this across all workers - GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = ( - "gen_ai.latency.time_in_model_forward") + GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward" # Time taken in the model execute function. This will include model # forward, block/sync across workers, cpu-gpu sync time and sampling time. - GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = ( - "gen_ai.latency.time_in_model_execute") + GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute" + GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill" + GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" + GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference" def contains_trace_headers(headers: Mapping[str, str]) -> bool: @@ -127,5 +133,4 @@ def contains_trace_headers(headers: Mapping[str, str]) -> bool: @run_once def log_tracing_disabled_warning() -> None: - logger.warning( - "Received a request with trace context but tracing is disabled") + logger.warning("Received a request with trace context but tracing is disabled") diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py index 6d4231baca50b..649df9a4f0222 100644 --- a/vllm/transformers_utils/__init__.py +++ b/vllm/transformers_utils/__init__.py @@ -10,10 +10,11 @@ if envs.VLLM_USE_MODELSCOPE: from packaging import version # patch_hub begins from modelscope>=1.18.1 - if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + if version.parse(modelscope.__version__) <= version.parse("1.18.0"): raise ImportError( - 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' - 'install by `pip install modelscope -U`') + "Using vLLM with ModelScope needs modelscope>=1.18.1, please " + "install by `pip install modelscope -U`" + ) from modelscope.utils.hf_util import patch_hub # Patch hub to download models from modelscope to speed up. 
@@ -21,4 +22,5 @@ if envs.VLLM_USE_MODELSCOPE: except ImportError as err: raise ImportError( "Please install modelscope>=1.18.1 via " - "`pip install modelscope>=1.18.1` to use ModelScope.") from err + "`pip install modelscope>=1.18.1` to use ModelScope." + ) from err diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999d47..b8d0cd8d2f208 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -12,25 +12,32 @@ CHAT_TEMPLATES_DIR = Path(__file__).parent ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]] -def _get_qwen_chat_template_fallback( - tokenizer_name_or_path: str) -> Optional[Path]: +def _get_qwen_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: if tokenizer_name_or_path.endswith("-Chat"): return CHAT_TEMPLATES_DIR / "template_chatml.jinja" return CHAT_TEMPLATES_DIR / "template_basic.jinja" -# yapf: disable +def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", + "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", - "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } -# yapf: enable def register_chat_template_fallback_path( @@ -40,8 +47,10 @@ def register_chat_template_fallback_path( if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: logger.warning( "Model type %s already has a chat template registered. 
" - "It will be overwritten by the new chat template %s.", model_type, - chat_template) + "It will be overwritten by the new chat template %s.", + model_type, + chat_template, + ) _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 0000000000000..661ebd1cf5c17 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '</think>' in message.content %} + {%- set content = message.content.split('</think>')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n</tool_call>' }} + {%- endfor %} + {%- 
endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '<think>\n\n</think>\n\n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '<think>\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fe345bd8f0a2e..4a8bb8f8b41de 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,46 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum import json import os import time +from dataclasses import asdict from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub -from huggingface_hub import get_safetensors_metadata, hf_hub_download +from huggingface_hub import ( + get_safetensors_metadata, + hf_hub_download, + try_to_load_from_cache, +) from huggingface_hub import list_repo_files as hf_list_repo_files -from huggingface_hub import try_to_load_from_cache -from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - LocalEntryNotFoundError, - RepositoryNotFoundError, - RevisionNotFoundError) +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) from transformers import GenerationConfig, PretrainedConfig -from transformers.models.auto.image_processing_auto import ( - get_image_processor_config) -from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.models.auto.image_processing_auto import get_image_processor_config +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, - EAGLEConfig, JAISConfig, - KimiVLConfig, MedusaConfig, - MLPSpeculatorConfig, - Nemotron_Nano_VL_Config, - NemotronConfig, OvisConfig, - RWConfig, SpeculatorsConfig, - Step3TextConfig, Step3VLConfig, - UltravoxConfig) -# yapf: enable -from vllm.transformers_utils.configs.mistral import adapt_config_dict -from vllm.transformers_utils.utils import check_gguf_file +from vllm.transformers_utils.config_parser_base import ConfigParserBase +from vllm.transformers_utils.utils import ( + check_gguf_file, + parse_safetensors_file_metadata, +) if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -56,59 +51,233 @@ def _get_hf_token() -> Optional[str]: """ Get the HuggingFace token from environment variable. 
- Returns None if the token is not set, is an empty string, + Returns None if the token is not set, is an empty string, or contains only whitespace. This follows the same pattern as huggingface_hub library which treats empty string tokens as None to avoid authentication errors. """ - token = os.getenv('HF_TOKEN') + token = os.getenv("HF_TOKEN") if token and token.strip(): return token return None -_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { - "chatglm": ChatGLMConfig, - "deepseek_vl_v2": DeepseekVLV2Config, - "kimi_vl": KimiVLConfig, - "Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config, - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": JAISConfig, - "mlp_speculator": MLPSpeculatorConfig, - "medusa": MedusaConfig, - "eagle": EAGLEConfig, - "speculators": SpeculatorsConfig, - "nemotron": NemotronConfig, - "ovis": OvisConfig, - "ultravox": UltravoxConfig, - "step3_vl": Step3VLConfig, - "step3_text": Step3TextConfig, -} +class LazyConfigDict(dict): + def __getitem__(self, key): + import vllm.transformers_utils.configs as configs + + return getattr(configs, super().__getitem__(key)) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + chatglm="ChatGLMConfig", + deepseek_vl_v2="DeepseekVLV2Config", + deepseek_v3="DeepseekV3Config", + deepseek_v32="DeepseekV3Config", + flex_olmo="FlexOlmoConfig", + kimi_vl="KimiVLConfig", + Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", + RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) + RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) + jais="JAISConfig", + mlp_speculator="MLPSpeculatorConfig", + medusa="MedusaConfig", + midashenglm="MiDashengLMConfig", + eagle="EAGLEConfig", + speculators="SpeculatorsConfig", + nemotron="NemotronConfig", + olmo3="Olmo3Config", + ovis="OvisConfig", + ultravox="UltravoxConfig", + step3_vl="Step3VLConfig", + step3_text="Step3TextConfig", + qwen3_next="Qwen3NextConfig", + lfm2_moe="Lfm2MoeConfig", +) _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", } _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { - "internvl_chat": { - "has_no_defaults_at_init": True - }, - # transformers regards mllama as is_encoder_decoder=False - # vllm needs is_encoder_decoder=True to enable cross-attention - "mllama": { - "is_encoder_decoder": True - }, - "NVLM_D": { - "has_no_defaults_at_init": True - }, + "internvl_chat": {"has_no_defaults_at_init": True}, + "NVLM_D": {"has_no_defaults_at_init": True}, } -class ConfigFormat(str, enum.Enum): - AUTO = "auto" - HF = "hf" - MISTRAL = "mistral" +class HFConfigParser(ConfigParserBase): + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type is None: + model_type = ( + "speculators" + if config_dict.get("speculators_config") is not None + else model_type + ) + + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + 
token=_get_hf_token(), + **kwargs, + ) + else: + try: + kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type) + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + except ValueError as e: + if ( + not trust_remote_code + and "requires you to execute the configuration file" in str(e) + ): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + else: + raise e + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + +class MistralConfigParser(ConfigParserBase): + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if ( + max_position_embeddings := config_dict.get("max_position_embeddings") + ) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs + ) + config_dict["max_position_embeddings"] = max_position_embeddings + + from vllm.transformers_utils.configs.mistral import adapt_config_dict + + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. Convert it + # to int and add the layer_types list[str] to make it HF compatible + if (sliding_window := getattr(config, "sliding_window", None)) and isinstance( + sliding_window, list + ): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config + + +_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { + "hf": HFConfigParser, + "mistral": MistralConfigParser, +} + +ConfigFormat = Literal[ + "auto", + "hf", + "mistral", +] + + +def get_config_parser(config_format: str) -> ConfigParserBase: + """Get the config parser for a given config format.""" + if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER: + raise ValueError(f"Unknown config format `{config_format}`.") + return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]() + + +def register_config_parser(config_format: str): + """Register a customized vllm config parser. + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... def parse( + ... self, + ... model: Union[str, Path], + ... trust_remote_code: bool, + ... revision: Optional[str] = None, + ... code_revision: Optional[str] = None, + ... **kwargs, + ... ) -> tuple[dict, PretrainedConfig]: + ... 
raise NotImplementedError + >>> + >>> type(get_config_parser("custom_config_parser")) + <class 'CustomConfigParser'> + """ # noqa: E501 + + def _wrapper(config_parser_cls): + if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER: + logger.warning( + "Config format `%s` is already registered, and will be " + "overwritten by the new parser class `%s`.", + config_format, + config_parser_cls, + ) + if not issubclass(config_parser_cls, ConfigParserBase): + raise ValueError( + "The config parser must be a subclass of `ConfigParserBase`." + ) + _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls + logger.info( + "Registered config parser `%s` with config format `%s`", + config_parser_cls, + config_format, + ) + return config_parser_cls + + return _wrapper _R = TypeVar("_R") @@ -127,8 +296,9 @@ def with_retry( if attempt == max_retries - 1: logger.error("%s: %s", log_msg, e) raise - logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, - max_retries) + logger.error( + "%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries + ) time.sleep(retry_delay) retry_delay *= 2 @@ -144,28 +314,27 @@ def list_repo_files( repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> list[str]: - def lookup_files() -> list[str]: # directly list files if model is local if (local_path := Path(repo_id)).exists(): return [ str(file.relative_to(local_path)) - for file in local_path.rglob('*') if file.is_file() + for file in local_path.rglob("*") + if file.is_file() ] # if model is remote, use hf_hub api to list files try: if envs.VLLM_USE_MODELSCOPE: - from vllm.transformers_utils.utils import ( - modelscope_list_repo_files) - return modelscope_list_repo_files(repo_id, - revision=revision, - token=os.getenv( - "MODELSCOPE_API_TOKEN", - None)) - return hf_list_repo_files(repo_id, - revision=revision, - repo_type=repo_type, - token=token) + from vllm.transformers_utils.utils import modelscope_list_repo_files + + return modelscope_list_repo_files( + repo_id, + revision=revision, + token=os.getenv("MODELSCOPE_API_TOKEN", None), + ) + return hf_list_repo_files( + repo_id, revision=revision, repo_type=repo_type, token=token + ) except huggingface_hub.errors.OfflineModeIsEnabled: # Don't raise in offline mode, # all we know is that we don't have this @@ -183,23 +352,23 @@ def file_exists( revision: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: - file_list = list_repo_files(repo_id, - repo_type=repo_type, - revision=revision, - token=token) + file_list = list_repo_files( + repo_id, repo_type=repo_type, revision=revision, token=token + ) return file_name in file_list # In offline mode the result can be a false negative -def file_or_path_exists(model: Union[str, Path], config_name: str, - revision: Optional[str]) -> bool: +def file_or_path_exists( + model: Union[str, Path], config_name: str, revision: Optional[str] +) -> bool: if (local_path := Path(model)).exists(): return (local_path / config_name).is_file() # Offline mode support: Check if config file is cached already - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=config_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=config_name, revision=revision + ) if isinstance(cached_filepath, str): # The config file exists in cache- we can continue trying to load return True @@ -208,10 +377,9 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # hf_hub. This will fail in offline mode. 
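To make the new registry concrete, here is a minimal sketch (not part of this diff) of how a downstream package could plug in its own parser through the `register_config_parser` hook added above; the `raw_json` format name and the trivial `parse()` body are made up for illustration.

```python
# Sketch only: a hypothetical "raw_json" parser registered through the new
# hooks above. The format name and the parse() body are illustrative.
import json
from pathlib import Path
from typing import Optional, Union

from transformers import PretrainedConfig

from vllm.transformers_utils.config import get_config, register_config_parser
from vllm.transformers_utils.config_parser_base import ConfigParserBase


@register_config_parser("raw_json")
class RawJsonConfigParser(ConfigParserBase):
    def parse(
        self,
        model: Union[str, Path],
        trust_remote_code: bool,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        **kwargs,
    ) -> tuple[dict, PretrainedConfig]:
        # Read a plain config.json from a local checkpoint directory and wrap
        # it in a generic PretrainedConfig.
        config_dict = json.loads((Path(model) / "config.json").read_text())
        return config_dict, PretrainedConfig(**config_dict)


# get_config() now dispatches on the registered format name rather than the
# old ConfigFormat enum, so the parser is selected with a plain string:
# config = get_config("/path/to/checkpoint", trust_remote_code=False,
#                     config_format="raw_json")
```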
# Call HF to check if the file exists - return file_exists(str(model), - config_name, - revision=revision, - token=_get_hf_token()) + return file_exists( + str(model), config_name, revision=revision, token=_get_hf_token() + ) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -233,7 +401,8 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None: raise ValueError( f"Found conflicts between 'rope_type={rope_type}' (modern " f"field) and 'type={rope_type_legacy}' (legacy field). " - "You should only specify one of them.") + "You should only specify one of them." + ) if "rope_type" not in rope_scaling and "type" in rope_scaling: rope_scaling["rope_type"] = rope_scaling["type"] @@ -261,8 +430,11 @@ def _uses_mrope(config: PretrainedConfig) -> bool: def uses_mrope(config: PretrainedConfig) -> bool: """Detect if the model with this config uses M-ROPE.""" - return _uses_mrope(config) or _uses_mrope( - config.get_text_config()) or thinker_uses_mrope(config) + return ( + _uses_mrope(config) + or _uses_mrope(config.get_text_config()) + or thinker_uses_mrope(config) + ) def thinker_uses_mrope(config: PretrainedConfig) -> bool: @@ -284,8 +456,7 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool: def _is_encoder_decoder(config: PretrainedConfig) -> bool: return getattr(config, "is_encoder_decoder", False) - return (_is_encoder_decoder(config) - or _is_encoder_decoder(config.get_text_config())) + return _is_encoder_decoder(config) or _is_encoder_decoder(config.get_text_config()) def is_interleaved(config: PretrainedConfig) -> bool: @@ -314,20 +485,33 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: if hasattr(config, old_attr): if not hasattr(config, new_attr): config.update({new_attr: getattr(config, old_attr)}) - logger.debug("Remapped config attribute '%s' to '%s'", old_attr, - new_attr) + logger.debug("Remapped config attribute '%s' to '%s'", old_attr, new_attr) return config -def maybe_override_with_speculators_target_model( +def maybe_override_with_speculators( model: str, tokenizer: str, trust_remote_code: bool, revision: Optional[str] = None, + vllm_speculative_config: Optional[dict[str, Any]] = None, **kwargs, -) -> tuple[str, str]: +) -> tuple[str, str, Optional[dict[str, Any]]]: """ - If running a speculators config, override running model with target model + Resolve model configuration when speculators are detected. + + Checks if the provided model is a speculators model and if so, extracts + the target model configuration and builds the speculative config. 
+ + Args: + model: Model name or path + tokenizer: Tokenizer name or path + trust_remote_code: Whether to trust remote code + revision: Model revision + vllm_speculative_config: Existing vLLM speculative config + + Returns: + Tuple of (resolved_model, resolved_tokenizer, speculative_config) """ is_gguf = check_gguf_file(model) if is_gguf: @@ -343,11 +527,27 @@ def maybe_override_with_speculators_target_model( token=_get_hf_token(), **kwargs, ) - spec_config = config_dict.get("speculators_config", None) - # Return the target model - if spec_config is not None: - model = tokenizer = spec_config["verifier"]["name_or_path"] - return model, tokenizer + speculators_config = config_dict.get("speculators_config") + + if speculators_config is None: + # No speculators config found, return original values + return model, tokenizer, vllm_speculative_config + + # Speculators format detected - process overrides + from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig + + speculative_config = SpeculatorsConfig.extract_vllm_speculative_config( + config_dict=config_dict + ) + + # Set the draft model to the speculators model + speculative_config["model"] = model + + # Override model and tokenizer with the verifier model from config + verifier_model = speculators_config["verifier"]["name_or_path"] + model = tokenizer = verifier_model + + return model, tokenizer, speculative_config def get_config( @@ -355,10 +555,9 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, code_revision: Optional[str] = None, - config_format: ConfigFormat = ConfigFormat.AUTO, + config_format: Union[str, ConfigFormat] = "auto", hf_overrides_kw: Optional[dict[str, Any]] = None, - hf_overrides_fn: Optional[Callable[[PretrainedConfig], - PretrainedConfig]] = None, + hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None, **kwargs, ) -> PretrainedConfig: # Separate model folder from file path for GGUF models @@ -368,20 +567,20 @@ def get_config( kwargs["gguf_file"] = Path(model).name model = Path(model).parent - if config_format == ConfigFormat.AUTO: + if config_format == "auto": try: - if is_gguf or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.HF - elif file_or_path_exists(model, - MISTRAL_CONFIG_NAME, - revision=revision): - config_format = ConfigFormat.MISTRAL + if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): + config_format = "hf" + elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): + config_format = "mistral" else: raise ValueError( "Could not detect config format for no config file found. " - "Ensure your model has either config.json (HF format) " - "or params.json (Mistral format).") + "With config_format 'auto', ensure your model has either " + "config.json (HF format) or params.json (Mistral format). " + "Otherwise please specify your_custom_config_format " + "in engine args for customized config parser." + ) except Exception as e: error_message = ( @@ -396,99 +595,23 @@ def get_config( "'params.json'.\n" "3. 
For GGUF: pass the local path of the GGUF checkpoint.\n" " Loading GGUF from a remote repo directly is not yet " - "supported.\n").format(model=model) + "supported.\n" + ).format(model=model) raise ValueError(error_message) from e - if config_format == ConfigFormat.HF: - kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE - config_dict, _ = PretrainedConfig.get_config_dict( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type") - if model_type is None: - model_type = "speculators" if config_dict.get( - "speculators_config") is not None else model_type - - if model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - config = config_class.from_pretrained( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - else: - try: - kwargs = _maybe_update_auto_config_kwargs( - kwargs, model_type=model_type) - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): - err_msg = ( - "Failed to load the model config. If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - config = _maybe_remap_hf_config_attrs(config) - - elif config_format == ConfigFormat.MISTRAL: - # This function loads a params.json config which - # should be used when loading models in mistral format - config_dict = _download_mistral_config_file(model, revision) - if (max_position_embeddings := - config_dict.get("max_position_embeddings")) is None: - max_position_embeddings = _maybe_retrieve_max_pos_from_hf( - model, revision, **kwargs) - config_dict["max_position_embeddings"] = max_position_embeddings - - config = adapt_config_dict(config_dict) - - # Mistral configs may define sliding_window as list[int]. Convert it - # to int and add the layer_types list[str] to make it HF compatible - if ((sliding_window := getattr(config, "sliding_window", None)) - and isinstance(sliding_window, list)): - pattern_repeats = config.num_hidden_layers // len(sliding_window) - layer_types = sliding_window * pattern_repeats - config.layer_types = [ - "full_attention" if layer_type is None else "sliding_attention" - for layer_type in layer_types - ] - config.sliding_window = next(filter(None, sliding_window), None) - else: - supported_formats = [ - fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO - ] - raise ValueError( - f"Unsupported config format: {config_format}. " - f"Supported formats are: {', '.join(supported_formats)}. 
" - f"Ensure your model uses one of these configuration formats " - f"or specify the correct format explicitly.") - + config_parser = get_config_parser(config_format) + config_dict, config = config_parser.parse( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError( - f"Can't get gguf config for {config.model_type}.") + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) @@ -498,13 +621,37 @@ def get_config( # ModelOpt 0.29.0 and before saves the quantization config in a separate # "hf_quant_config.json" in the same directory as the model config file. - if quantization_config is None \ - and file_or_path_exists(model, "hf_quant_config.json", revision): - quantization_config = get_hf_file_to_dict("hf_quant_config.json", - model, revision) + if quantization_config is None and file_or_path_exists( + model, "hf_quant_config.json", revision + ): + quantization_config = get_hf_file_to_dict( + "hf_quant_config.json", model, revision + ) if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0",): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "1" + logger.info_once( + ( + "Detected quantization_config.scale_fmt=%s; " + "enabling UE8M0 for DeepGEMM." + ), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.warning_once( + ( + "Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0=0 is set; " + "UE8M0 for DeepGEMM disabled." + ), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) @@ -521,17 +668,17 @@ def get_config( return config -def try_get_local_file(model: Union[str, Path], - file_name: str, - revision: Optional[str] = 'main') -> Optional[Path]: +def try_get_local_file( + model: Union[str, Path], file_name: str, revision: Optional[str] = "main" +) -> Optional[Path]: file_path = Path(model) / file_name if file_path.is_file(): return file_path else: try: - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=file_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=file_name, revision=revision + ) if isinstance(cached_filepath, str): return Path(cached_filepath) except ValueError: @@ -539,9 +686,9 @@ def try_get_local_file(model: Union[str, Path], return None -def get_hf_file_to_dict(file_name: str, - model: Union[str, Path], - revision: Optional[str] = 'main'): +def get_hf_file_to_dict( + file_name: str, model: Union[str, Path], revision: Optional[str] = "main" +): """ Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. @@ -556,25 +703,27 @@ def get_hf_file_to_dict(file_name: str, the contents of the downloaded file. 
""" - file_path = try_get_local_file(model=model, - file_name=file_name, - revision=revision) + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) if file_path is None: try: hf_hub_file = hf_hub_download(model, file_name, revision=revision) except huggingface_hub.errors.OfflineModeIsEnabled: return None - except (RepositoryNotFoundError, RevisionNotFoundError, - EntryNotFoundError, LocalEntryNotFoundError) as e: + except ( + RepositoryNotFoundError, + RevisionNotFoundError, + EntryNotFoundError, + LocalEntryNotFoundError, + ) as e: logger.debug("File or repository not found in hf_hub_download", e) return None except HfHubHTTPError as e: logger.warning( - "Cannot connect to Hugging Face Hub. Skipping file " - "download for '%s':", + "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, - exc_info=e) + exc_info=e, + ) return None file_path = Path(hf_hub_file) @@ -586,28 +735,28 @@ def get_hf_file_to_dict(file_name: str, @cache -def get_pooling_config(model: str, revision: Optional[str] = 'main'): +def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional[dict]: """ This function gets the pooling and normalize config from the model - only applies to sentence-transformers models. Args: - model (str): The name of the Hugging Face model. - revision (str, optional): The specific version - of the model to use. Defaults to 'main'. + model: The name of the Hugging Face model. + revision: The specific version of the model to use. + Defaults to 'main'. Returns: - dict: A dictionary containing the pooling - type and whether normalization is used. + A dictionary containing the pooling type and whether + normalization is used, or None if no pooling configuration is found. """ modules_file_name = "modules.json" modules_dict = None - if file_or_path_exists(model=model, - config_name=modules_file_name, - revision=revision): + if file_or_path_exists( + model=model, config_name=modules_file_name, revision=revision + ): modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: @@ -615,20 +764,31 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): logger.info("Found sentence-transformers modules configuration.") - pooling = next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Pooling"), - None) + pooling = next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Pooling" + ), + None, + ) normalize = bool( - next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Normalize"), - False)) + next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Normalize" + ), + False, + ) + ) if pooling: - pooling_file_name = "{}/config.json".format(pooling["path"]) pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) pooling_type_name = next( - (item for item, val in pooling_dict.items() if val is True), None) + (item for item, val in pooling_dict.items() if val is True), None + ) if pooling_type_name is not None: pooling_type_name = get_pooling_config_name(pooling_type_name) @@ -649,20 +809,19 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: if "lasttoken" in pooling_name: pooling_name = "last" - supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] + supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"] pooling_type_name = pooling_name.upper() if pooling_type_name in 
supported_pooling_types: return pooling_type_name - raise NotImplementedError( - f"Pooling type {pooling_type_name} not supported") + raise NotImplementedError(f"Pooling type {pooling_type_name} not supported") @cache -def get_sentence_transformer_tokenizer_config(model: Union[str, Path], - revision: Optional[str] = 'main' - ): +def get_sentence_transformer_tokenizer_config( + model: Union[str, Path], revision: Optional[str] = "main" +): """ Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. @@ -689,9 +848,10 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], encoder_dict = None for config_file in sentence_transformer_config_files: - if try_get_local_file(model=model, - file_name=config_file, - revision=revision) is not None: + if ( + try_get_local_file(model=model, file_name=config_file, revision=revision) + is not None + ): encoder_dict = get_hf_file_to_dict(config_file, model, revision) if encoder_dict: break @@ -699,16 +859,15 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], if not encoder_dict and not Path(model).is_absolute(): try: # If model is on HuggingfaceHub, get the repo files - repo_files = list_repo_files(model, - revision=revision, - token=_get_hf_token()) + repo_files = list_repo_files( + model, revision=revision, token=_get_hf_token() + ) except Exception: repo_files = [] for config_name in sentence_transformer_config_files: if config_name in repo_files: - encoder_dict = get_hf_file_to_dict(config_name, model, - revision) + encoder_dict = get_hf_file_to_dict(config_name, model, revision) if encoder_dict: break @@ -725,34 +884,39 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], def maybe_register_config_serialize_by_value() -> None: """Try to register HF model configuration class to serialize by value - If trust_remote_code is set, and the model's config file specifies an - `AutoConfig` class, then the config class is typically an instance of - a custom class imported from the HF modules cache. + If trust_remote_code is set, and the model's config file specifies an + `AutoConfig` class, then the config class is typically an instance of + a custom class imported from the HF modules cache. - Examples: + Examples: - >>> from transformers import AutoConfig - >>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True) - >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig - >>> import transformers_modules # error, not initialized - >>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True) - >>> import transformers_modules # success, initialized - >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config + >>> from transformers import AutoConfig + >>> klass = AutoConfig.from_pretrained( + ... "meta-llama/Meta-Llama-3-8B", trust_remote_code=True + ... ) + >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig + >>> import transformers_modules # error, not initialized + >>> klass = AutoConfig.from_pretrained( + ... "deepseek-ai/DeepSeek-V2.5", trust_remote_code=True + ... 
) + >>> import transformers_modules # success, initialized + >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config - In the DeepSeek example, the config class is an instance of a custom - class that is not serializable by default. This class will not be - importable in spawned workers, and won't exist at all on - other nodes, which breaks serialization of the config. + In the DeepSeek example, the config class is an instance of a custom + class that is not serializable by default. This class will not be + importable in spawned workers, and won't exist at all on + other nodes, which breaks serialization of the config. - In this function we tell the cloudpickle serialization library to pass - instances of these generated classes by value instead of by reference, - i.e. the class definition is serialized along with its data so that the - class module does not need to be importable on the receiving end. + In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. - See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs - """ # noqa + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ # noqa try: import transformers_modules + transformers_modules_available = True except ImportError: transformers_modules_available = False @@ -769,7 +933,7 @@ def maybe_register_config_serialize_by_value() -> None: # serialization of VllmConfig objects that may contain custom configs # from transformers_modules def _reduce_config(config: VllmConfig): - return (pickle.loads, (cloudpickle.dumps(config), )) + return (pickle.loads, (cloudpickle.dumps(config),)) multiprocessing.reducer.register(VllmConfig, _reduce_config) @@ -779,6 +943,7 @@ def maybe_register_config_serialize_by_value() -> None: # ray vendors its own version of cloudpickle from vllm.executor.ray_utils import ray + if ray: ray.cloudpickle.register_pickle_by_value(transformers_modules) @@ -788,7 +953,8 @@ def maybe_register_config_serialize_by_value() -> None: " trust_remote_code with by-value serialization. This may" " lead to a later error. 
If remote code is not needed" " remove `--trust-remote-code`", - exc_info=e) + exc_info=e, + ) def get_hf_image_processor_config( @@ -803,10 +969,9 @@ def get_hf_image_processor_config( # Separate model folder from file path for GGUF models if check_gguf_file(model): model = Path(model).parent - return get_image_processor_config(model, - token=hf_token, - revision=revision, - **kwargs) + return get_image_processor_config( + model, token=hf_token, revision=revision, **kwargs + ) def get_hf_text_config(config: PretrainedConfig): @@ -828,6 +993,7 @@ def try_get_generation_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, + config_format: Union[str, ConfigFormat] = "auto", ) -> Optional[GenerationConfig]: try: return GenerationConfig.from_pretrained( @@ -840,6 +1006,7 @@ def try_get_generation_config( model, trust_remote_code=trust_remote_code, revision=revision, + config_format=config_format, ) return GenerationConfig.from_model_config(config) except OSError: # Not found @@ -859,8 +1026,9 @@ def try_get_safetensors_metadata( ) try: - return with_retry(get_safetensors_metadata_partial, - "Error retrieving safetensors") + return with_retry( + get_safetensors_metadata_partial, "Error retrieving safetensors" + ) except Exception: return None @@ -880,6 +1048,34 @@ def try_get_tokenizer_config( return None +def get_safetensors_params_metadata( + model: str, + *, + revision: Optional[str] = None, +) -> dict[str, Any]: + """ + Get the safetensors metadata for remote model repository. + """ + full_metadata = {} + if (model_path := Path(model)).exists(): + safetensors_to_check = model_path.glob("*.safetensors") + full_metadata = { + param_name: info + for file_path in safetensors_to_check + if file_path.is_file() + for param_name, info in parse_safetensors_file_metadata(file_path).items() + } + else: + repo_mt = try_get_safetensors_metadata(model, revision=revision) + if repo_mt and (files_mt := repo_mt.files_metadata): + full_metadata = { + param_name: asdict(info) + for file_mt in files_mt.values() + for param_name, info in file_mt.tensors.items() + } + return full_metadata + + def _download_mistral_config_file(model, revision) -> dict: config_file_name = "params.json" config_dict = get_hf_file_to_dict(config_file_name, model, revision) @@ -887,7 +1083,8 @@ def _download_mistral_config_file(model, revision) -> dict: raise ValueError( f"Failed to load mistral '{config_file_name}' config for model " f"{model}. Please check if the model is a mistral-format model " - f"and if the config file exists.") + f"and if the config file exists." + ) assert isinstance(config_dict, dict) return config_dict @@ -896,10 +1093,12 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: max_position_embeddings = 128_000 try: trust_remote_code_val = kwargs.get("trust_remote_code", False) - hf_config = get_config(model=model, - trust_remote_code=trust_remote_code_val, - revision=revision, - config_format=ConfigFormat.HF) + hf_config = get_config( + model=model, + trust_remote_code=trust_remote_code_val, + revision=revision, + config_format="hf", + ) if hf_value := hf_config.get_text_config().max_position_embeddings: max_position_embeddings = hf_value except Exception as e: @@ -907,7 +1106,8 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: "The params.json file is missing 'max_position_embeddings'" " and could not get a value from the HF config." 
" Defaulting to 128000", - exc_info=e) + exc_info=e, + ) return max_position_embeddings @@ -923,7 +1123,28 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None): if envs.VLLM_USE_MODELSCOPE: from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes( + file_name: str, model: Union[str, Path], revision: Optional[str] = "main" +) -> Optional[bytes]: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download( + model, file_name, revision=revision, token=_get_hf_token() + ) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, "rb") as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py new file mode 100644 index 0000000000000..0e1c49b428b07 --- /dev/null +++ b/vllm/transformers_utils/config_parser_base.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class ConfigParserBase(ABC): + @abstractmethod + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8339c55bcf808..befe9cdae76a1 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -8,44 +8,63 @@ Model configs may be defined in this directory for the following reasons: """ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig + # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig +from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.midashenglm import MiDashengLMConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config +from vllm.transformers_utils.configs.olmo3 import Olmo3Config from vllm.transformers_utils.configs.ovis import OvisConfig +from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig +from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig -from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, - Step3VisionEncoderConfig, - Step3VLConfig) +from vllm.transformers_utils.configs.step3_vl import ( + Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig, +) from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "DeepseekV3Config", + "DotsOCRConfig", "EAGLEConfig", + "FlexOlmoConfig", "RWConfig", "JAISConfig", + "Lfm2MoeConfig", "MedusaConfig", + "MiDashengLMConfig", "MLPSpeculatorConfig", "MoonViTConfig", "KimiVLConfig", "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", + "Olmo3Config", "OvisConfig", + "RadioConfig", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Qwen3NextConfig", ] diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py index a789b93b5edff..1707e15285c89 100644 --- a/vllm/transformers_utils/configs/arctic.py +++ b/vllm/transformers_utils/configs/arctic.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # Copied from # https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py -""" Arctic model configuration""" +"""Arctic model configuration""" from dataclasses import asdict, dataclass from typing import Any diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index 176d2b8f63fe4..1d795b55c8bc7 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -13,33 +13,35 @@ class ChatGLMConfig(PretrainedConfig): "n_head_kv": "multi_query_group_num", } - def __init__(self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - interleaved_qkv=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - 
fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs): + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size @@ -55,7 +57,8 @@ class ChatGLMConfig(PretrainedConfig): self.layernorm_epsilon = layernorm_epsilon self.rmsnorm = rmsnorm self.apply_residual_connection_post_layernorm = ( - apply_residual_connection_post_layernorm) + apply_residual_connection_post_layernorm + ) self.post_layer_norm = post_layer_norm self.add_bias_linear = add_bias_linear self.add_qkv_bias = add_qkv_bias diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py new file mode 100644 index 0000000000000..91fbed79dd021 --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class DeepseekV3Config(PretrainedConfig): + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="noaux_tc", + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func="sigmoid", + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + 
self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py index 957d638318410..7abfe62298422 100644 --- a/vllm/transformers_utils/configs/deepseek_vl2.py +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -25,20 +25,22 @@ class VisionEncoderConfig(PretrainedConfig): deterministic: bool = False num_recomputing_layers: int = 0 - def __init__(self, - model_name: str = "vit_so400m_patch14_siglip_384.webli", - image_size: int = 384, - patch_size: int = 16, - width: int = 1024, - layers: int = 24, - heads: int = 16, - mlp_ratio: int = 4, - global_pool: str = "map", - ignore_head: bool = True, - class_token: bool = False, - num_classes: int = 0, - use_checkpoint: bool = False, - **kwargs): + def __init__( + self, + model_name: str = "vit_so400m_patch14_siglip_384.webli", + image_size: int = 384, + patch_size: int = 16, + width: int = 1024, + layers: int = 24, + heads: int = 16, + mlp_ratio: int = 4, + global_pool: str = "map", + ignore_head: bool = True, + class_token: bool = False, + num_classes: int = 0, + use_checkpoint: bool = False, + **kwargs, + ): self.model_name = model_name self.image_size = image_size self.patch_size = patch_size @@ -65,14 +67,16 @@ class MlpProjectorConfig(PretrainedConfig): downsample_ratio: int = 2 token_pooling: bool = False - def __init__(self, - projector_type: str = "downsample_mlp_gelu", - input_dim: int = 1152, - n_embed: int = 2048, - depth: int = 2, - mlp_ratio: int = 1, - downsample_ratio: int = 2, - **kwargs): + def __init__( + self, + projector_type: str = "downsample_mlp_gelu", + input_dim: int = 1152, + n_embed: int = 2048, + depth: int = 2, + mlp_ratio: int = 1, + downsample_ratio: int = 2, + **kwargs, + ): self.projector_type = projector_type self.input_dim = input_dim self.n_embed = n_embed @@ -84,7 +88,6 @@ class MlpProjectorConfig(PretrainedConfig): class DeepseekV2Config(PretrainedConfig): - model_type = "deepseek_v2" keys_to_ignore_at_inference = ["past_key_values"] @@ -106,14 +109,14 @@ class DeepseekV2Config(PretrainedConfig): qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, - topk_method='gready', + topk_method="gready", n_group=None, topk_group=None, num_experts_per_tok=None, moe_layer_freq=1, first_k_dense_replace=0, norm_topk_prob=False, - scoring_func='softmax', + scoring_func="softmax", aux_loss_alpha=0.001, seq_aux=True, hidden_act="silu", @@ -191,14 +194,15 @@ class DeepseekVLV2Config(PretrainedConfig): tile_tag: str = "2D" global_view_pos: str = "head" - candidate_resolutions: tuple[tuple[int, int]] = ((384, 384), ) + candidate_resolutions: tuple[tuple[int, int]] = 
((384, 384),) - def __init__(self, - tile_tag: str = "tile_tag", - global_view_pos: str = "head", - candidate_resolutions: tuple[tuple[int, - int]] = ((384, 384), ), - **kwargs): + def __init__( + self, + tile_tag: str = "tile_tag", + global_view_pos: str = "head", + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),), + **kwargs, + ): super().__init__(**kwargs) vision_config = kwargs.get("vision_config", {}) diff --git a/vllm/transformers_utils/configs/dotsocr.py b/vllm/transformers_utils/configs/dotsocr.py new file mode 100644 index 0000000000000..446693b9a32eb --- /dev/null +++ b/vllm/transformers_utils/configs/dotsocr.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.qwen2 import Qwen2Config + + +class DotsVisionConfig(PretrainedConfig): + model_type: str = "dots_vit" + + def __init__( + self, + embed_dim: int = 1536, # vision encoder embed size + hidden_size: int = 1536, # after merger hidden size + intermediate_size: int = 4224, + num_hidden_layers: int = 42, + num_attention_heads: int = 12, + num_channels: int = 3, + patch_size: int = 14, + spatial_merge_size: int = 2, + temporal_patch_size: int = 1, + rms_norm_eps: float = 1e-5, + use_bias: bool = False, + attn_implementation="flash_attention_2", + initializer_range=0.02, + init_merger_std=0.02, + is_causal=False, # ve causal forward + post_norm=True, + gradient_checkpointing=False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.rms_norm_eps = rms_norm_eps + self.use_bias = use_bias + self.attn_implementation = attn_implementation + self.initializer_range = initializer_range + self.init_merger_std = init_merger_std + self.is_causal = is_causal + self.post_norm = post_norm + self.gradient_checkpointing = gradient_checkpointing + + +class DotsOCRConfig(Qwen2Config): + model_type = "dots_ocr" + + def __init__( + self, + image_token_id=151665, + video_token_id=151656, + vision_config: Optional[dict] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_config = DotsVisionConfig(**(vision_config or {})) + + def save_pretrained(self, save_directory, **kwargs): + self._auto_class = None + super().save_pretrained(save_directory, **kwargs) diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 6aabf9e5262e6..6e18513d12340 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -12,12 +12,13 @@ from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config class EAGLEConfig(PretrainedConfig): model_type = "eagle" - def __init__(self, - model: Union[PretrainedConfig, dict, None] = None, - truncated_vocab_size: Optional[int] = None, - method: Optional[str] = 'eagle', - **kwargs): - + def __init__( + self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + method: Optional[str] = "eagle", + **kwargs, + ): 
model_config: Union[PretrainedConfig, DeepseekV2Config, None] if isinstance(model, dict): archs = model.get("architectures", []) @@ -31,8 +32,7 @@ class EAGLEConfig(PretrainedConfig): model_config = model for k, v in kwargs.items(): - if k != "architectures" and k != "model_type" and hasattr( - model_config, k): + if k != "architectures" and k != "model_type" and hasattr(model_config, k): setattr(model_config, k, v) self.model = model_config @@ -40,29 +40,39 @@ class EAGLEConfig(PretrainedConfig): if self.model is None: self.truncated_vocab_size = None else: - self.truncated_vocab_size = self.model.vocab_size if \ - truncated_vocab_size is None else truncated_vocab_size + self.truncated_vocab_size = ( + self.model.vocab_size + if truncated_vocab_size is None + else truncated_vocab_size + ) # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM # LlamaForCausalLM -> Eagle3LlamaForCausalLM + # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3 if method == "eagle": - assert self.model is not None, \ + assert self.model is not None, ( "model should not be None when method is eagle" + ) kwargs["architectures"] = [ - f"Eagle{arch}" if not arch.startswith("Eagle") \ - else arch for arch in self.model.architectures + f"Eagle{arch}" if not arch.startswith("Eagle") else arch + for arch in self.model.architectures ] + elif method == "eagle3": - assert self.model is not None, \ + assert self.model is not None, ( "model should not be None when method is eagle3" + ) kwargs["architectures"] = [ - arch if arch.startswith("Eagle3") or arch.endswith("Eagle3") - else f"Eagle3{arch}" for arch in self.model.architectures + arch + if arch.startswith("Eagle3") or arch.endswith("Eagle3") + else f"Eagle3{arch}" + for arch in self.model.architectures ] else: - raise ValueError(f"Invalid method {method}. " - "Supported methods are eagle and eagle3.") + raise ValueError( + f"Invalid method {method}. Supported methods are eagle and eagle3." + ) super().__init__(**kwargs) @@ -78,5 +88,6 @@ class EAGLEConfig(PretrainedConfig): **kwargs, ) -> "EAGLEConfig": config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) + pretrained_model_name_or_path, **kwargs + ) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py index 2f5400463d91a..c646d241d4eb0 100644 --- a/vllm/transformers_utils/configs/falcon.py +++ b/vllm/transformers_utils/configs/falcon.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
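As a small, self-contained illustration of the architecture-renaming convention restated in the `eagle.py` hunk above (the target config below is a made-up minimal example, not taken from this diff):

```python
# Illustrative only: exercising the EAGLE/EAGLE3 architecture-renaming
# convention from eagle.py with a made-up minimal target config.
from transformers import PretrainedConfig

from vllm.transformers_utils.configs import EAGLEConfig

# Hypothetical verifier config; only the fields EAGLEConfig reads are set.
target = PretrainedConfig(architectures=["LlamaForCausalLM"], vocab_size=128256)

print(EAGLEConfig(model=target, method="eagle").architectures)
# expected: ['EagleLlamaForCausalLM']
print(EAGLEConfig(model=target, method="eagle3").architectures)
# expected: ['Eagle3LlamaForCausalLM']
```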
"""Falcon configuration""" + from transformers.configuration_utils import PretrainedConfig @@ -77,9 +78,7 @@ class RWConfig(PretrainedConfig): # Hack for falcon-40b self.new_decoder_architecture = True - super().__init__(bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @property def head_dim(self): diff --git a/vllm/transformers_utils/configs/flex_olmo.py b/vllm/transformers_utils/configs/flex_olmo.py new file mode 100644 index 0000000000000..1f2f4d446288b --- /dev/null +++ b/vllm/transformers_utils/configs/flex_olmo.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class FlexOlmoConfig(PretrainedConfig): + model_type = "flex_olmo" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=100352, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=100277, + bos_token_id=None, + eos_token_id=100257, + tie_word_embeddings=False, + rope_theta=500000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + num_experts_per_tok=5, + num_experts=7, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=False, + **kwargs, + ): + if "architectures" not in kwargs: + kwargs["architectures"] = ["FlexOlmoForCausalLM"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 767c4ddae870d..6b581bf187755 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -74,10 +74,9 @@ class JAISConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). 
- scale_attn_by_inverse_layer_idx (`bool`, *optional*, - defaults to `False`): - Whether to additionally scale attention weights by - `1 / layer_idx + 1`. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, default `True`): + Whether to additionally scale attention weights + by `1 / layer_idx + 1`. reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): Whether to scale keys (K) prior to computing attention (dot-product) @@ -210,29 +209,35 @@ class JAISConfig(PretrainedConfig): if self.alibi_scaling is None: return - if (not isinstance(self.alibi_scaling, dict) - or len(self.alibi_scaling) != 2): + if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2: raise ValueError( "`alibi_scaling` must be a dictionary with two fields, " "`type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}") + f"got {self.alibi_scaling}" + ) alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) if alibi_scaling_type is None or alibi_scaling_type != "linear": - raise ValueError(f"`alibi_scaling`'s type field must be 'linear', " - f"got {alibi_scaling_type}") - if (alibi_scaling_factor is not None - and not isinstance(alibi_scaling_factor, float) - or (alibi_scaling_factor is not None - and alibi_scaling_factor <= 1.0)): + raise ValueError( + f"`alibi_scaling`'s type field must be 'linear', " + f"got {alibi_scaling_type}" + ) + if ( + alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or (alibi_scaling_factor is not None and alibi_scaling_factor <= 1.0) + ): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0, " - f"got {alibi_scaling_factor}") - if (alibi_dynamic_scaling is not None - and not isinstance(alibi_dynamic_scaling, int) - or (alibi_dynamic_scaling is not None - and alibi_dynamic_scaling <= 1)): + f"got {alibi_scaling_factor}" + ) + if ( + alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or (alibi_dynamic_scaling is not None and alibi_dynamic_scaling <= 1) + ): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an " - f"integer > 1, got {alibi_dynamic_scaling}") + f"integer > 1, got {alibi_dynamic_scaling}" + ) diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py index ae8dac0f381d6..89a8878465b6d 100644 --- a/vllm/transformers_utils/configs/kimi_vl.py +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -12,13 +12,15 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig class KimiVLConfig(PretrainedConfig): model_type = "kimi_vl" - def __init__(self, - vision_config: Optional[Union[dict, MoonViTConfig]] = None, - text_config: Optional[Union[dict, DeepseekV2Config]] = None, - ignore_index: int = -100, - media_placeholder_token_id: int = 163605, - pad_token_id: int = 0, - **kwargs): + def __init__( + self, + vision_config: Optional[Union[dict, MoonViTConfig]] = None, + text_config: Optional[Union[dict, DeepseekV2Config]] = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + **kwargs, + ): if vision_config is None: vision_config = MoonViTConfig() elif isinstance(vision_config, dict): diff --git a/vllm/transformers_utils/configs/lfm2_moe.py b/vllm/transformers_utils/configs/lfm2_moe.py new file mode 100644 index 0000000000000..7d17c2b4f74c5 --- /dev/null +++ 
b/vllm/transformers_utils/configs/lfm2_moe.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class Lfm2MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Lfm2MoeModel`]. It is used to instantiate a LFM2 Moe + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LFM2-8B-A1B model. + e.g. [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Lfm2Model`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 7168): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1792): + Intermediate size of the routed expert. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the conv layers. + conv_L_cache (`int`, *optional*, defaults to 3): + L_cache dim in the conv layers. 
+ num_dense_layers (`int`, *optional*, defaults to 2): + Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 32): + Number of routed experts. + use_expert_bias (`bool`, *optional*, defaults to `True`): + Whether to use the expert bias on the routing weights. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts in MoE models. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + layer_types (`Optional`, *optional*): + Type of each layers. + + ```python + >>> from transformers import Lfm2MoeModel, Lfm2MoeConfig + + >>> # Initializing a LFM2 Moe model + >>> configuration = Lfm2MoeConfig() + + >>> # Initializing a model from the LFM2-8B-A1B style configuration + >>> model = Lfm2MoeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" # noqa: E501 + + model_type = "lfm2_moe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 65536, + hidden_size: int = 2048, + intermediate_size: int = 7168, + moe_intermediate_size: int = 1792, + num_hidden_layers: int = 32, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = True, + rope_theta: float = 1000000.0, + max_position_embeddings: int = 128_000, + use_cache: bool = True, + norm_eps: float = 0.00001, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + conv_bias: bool = False, + conv_L_cache: int = 3, + num_dense_layers: int = 2, + num_experts_per_tok: int = 4, + num_experts: int = 32, + use_expert_bias: bool = True, + routed_scaling_factor: float = 1.0, + norm_topk_prob: bool = True, + layer_types: Optional[list[str]] = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.norm_eps = norm_eps + + # attn operator config + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + # custom operator config + self.conv_bias = conv_bias + self.conv_L_cache = conv_L_cache + + # moe config + self.num_dense_layers = num_dense_layers + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.use_expert_bias = use_expert_bias + self.routed_scaling_factor = routed_scaling_factor + self.norm_topk_prob = norm_topk_prob + self.layer_types = layer_types + + tie_word_embeddings = kwargs.get( + "tie_embedding", tie_word_embeddings + ) # to fit original config keys + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Lfm2MoeConfig"] diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py index 9ba52956a8e8e..7dcfd0cf26aef 100644 --- a/vllm/transformers_utils/configs/medusa.py +++ b/vllm/transformers_utils/configs/medusa.py @@ -10,16 +10,17 @@ from transformers import PretrainedConfig class MedusaConfig(PretrainedConfig): model_type = "medusa" - def __init__(self, - hidden_size: int = 
4096, - vocab_size: int = 32001, - num_heads: int = 5, - num_hidden_layers: int = 1, - max_paths: int = 64, - topk: int = 10, - truncated_vocab_size: Optional[int] = None, - **kwargs): - + def __init__( + self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs, + ): self.hidden_size = hidden_size self.vocab_size = vocab_size self.num_heads = num_heads @@ -27,8 +28,9 @@ class MedusaConfig(PretrainedConfig): self.max_paths = max_paths self.topk = topk self.max_seq_len = int(2**20) - self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ - else truncated_vocab_size + self.truncated_vocab_size = ( + vocab_size if truncated_vocab_size is None else truncated_vocab_size + ) if "architectures" not in kwargs: kwargs["architectures"] = ["MedusaModel"] @@ -41,12 +43,13 @@ class MedusaConfig(PretrainedConfig): **kwargs, ) -> "MedusaConfig": config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) + pretrained_model_name_or_path, **kwargs + ) for k in list(config_dict.keys()): - if 'num' in k: - if 'heads' in k: + if "num" in k: + if "heads" in k: config_dict["num_heads"] = config_dict.pop(k) - elif 'layers' in k: + elif "layers" in k: config_dict["num_hidden_layers"] = config_dict.pop(k) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py new file mode 100644 index 0000000000000..5c9e72be8ebff --- /dev/null +++ b/vllm/transformers_utils/configs/midashenglm.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
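As a hedged illustration of the `MedusaConfig.from_pretrained` key remapping reformatted above: any key containing both "num" and "heads" is renamed to `num_heads`, and "num" plus "layers" to `num_hidden_layers`. The input key names below are assumed example checkpoint fields, not taken from the diff:

```python
# Hypothetical checkpoint-style config dict.
config_dict = {"medusa_num_heads": 5, "medusa_num_layers": 1, "hidden_size": 4096}

# Same remapping loop as MedusaConfig.from_pretrained.
for k in list(config_dict.keys()):
    if "num" in k:
        if "heads" in k:
            config_dict["num_heads"] = config_dict.pop(k)
        elif "layers" in k:
            config_dict["num_hidden_layers"] = config_dict.pop(k)

print(config_dict)
# {'hidden_size': 4096, 'num_heads': 5, 'num_hidden_layers': 1}
```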
+from typing import Optional, Union + +from transformers import PretrainedConfig +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTextConfig, +) + + +class DashengConfig(PretrainedConfig): + model_type = "midashenglm_dasheng_encoder" + + def __init__( + self, + embed_dim: int = 768, + outputdim: int = 527, + patch_size: Union[int, tuple[int, int]] = 16, + patch_stride: Union[int, tuple[int, int]] = 16, + input_channels: int = 1, + target_length: int = 1012, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + init_values: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + f_min: float = 0.0, + f_max: float = 8000.0, + center: bool = True, + win_length: int = 512, + hop_length: int = 160, + sample_rate: int = 16000, + n_fft: int = 512, + n_mels: int = 64, + **kwargs, + ): + self.embed_dim = embed_dim + self.outputdim = outputdim + self.patch_size = patch_size + self.patch_stride = patch_stride + self.input_channels = input_channels + self.target_length = target_length + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.init_values = init_values + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.f_min = f_min + self.f_max = f_max + self.center = center + self.win_length = win_length + self.hop_length = hop_length + self.sample_rate = sample_rate + self.n_fft = n_fft + self.n_mels = n_mels + super().__init__(**kwargs) + + +class MiDashengLMConfig(PretrainedConfig): + model_type = "midashenglm" + + def __init__( + self, + audio_encoder_config: Optional[dict] = None, + subsample_factor: int = 5, + text_config: Optional[dict] = None, + audio_token_id: Optional[int] = None, + **kwargs, + ): + self.audio_encoder_config = DashengConfig(**(audio_encoder_config or {})) + self.subsample_factor = subsample_factor + self.text_config = ( + Qwen2_5OmniTextConfig(**text_config) + if text_config + else Qwen2_5OmniTextConfig() + ) + self.text_config.rope_scaling = None # uses_mrope is false + self.audio_token_id = audio_token_id + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 8a9c660b882fd..d5bf79e01f954 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -9,8 +9,7 @@ from vllm.logger import init_logger logger = init_logger(__name__) -def adapt_config_dict(config_dict: dict[str, Any], - **kwargs) -> PretrainedConfig: +def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig: config_dict.update(kwargs) config_dict = _remap_general_mistral_args(config_dict) @@ -25,15 +24,16 @@ def adapt_config_dict(config_dict: dict[str, Any], if bool(config_dict.get("yarn")): config_dict = _remap_mistral_yarn_args(config_dict) - is_vision = ((config_dict.get("multimodal") - or {}).get("vision_encoder_args") - or config_dict.get("vision_encoder")) + is_vision = (config_dict.get("multimodal") or {}).get( + "vision_encoder_args" + ) or config_dict.get("vision_encoder") is_audio = bool( - ((config_dict.get("multimodal") or {}).get("whisper_model_args") - or {}).get("encoder_args")) + ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get( + "encoder_args" + ) + ) - assert not (is_vision and is_audio), \ - "Vision and audio are mutually exclusive" + assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" if is_vision: 
config_dict = _remap_mistral_vision_args(config_dict) @@ -77,7 +77,7 @@ def _remap_mistral_yarn_args(config: dict) -> dict: config["rope_scaling"] = { "rope_type": "yarn", "mscale_all_dim": 1, # We hardcoded this to 1 - **renamed_yarn_config + **renamed_yarn_config, } return config @@ -105,8 +105,7 @@ def _remap_general_mistral_args(config: dict) -> dict: if key in config: config[new_key] = config.pop(key) - for new_key, (key, - default_value) in top_level_mapping_with_default.items(): + for new_key, (key, default_value) in top_level_mapping_with_default.items(): config[new_key] = config.pop(key, default_value) return config @@ -116,16 +115,12 @@ def _remap_mistral_quantization_args(config: dict) -> dict: quantization = config.get("quantization", {}) if quantization.get("qformat_weight") == "fp8_e4m3": # This maps to the FP8 static per-tensor quantization scheme - quantization_config = { - "quant_method": "fp8", - "activation_scheme": "static" - } + quantization_config = {"quant_method": "fp8", "activation_scheme": "static"} elif quantization.get("quant_method") == "compressed-tensors": # Pass through the quantization config to compressed-tensors quantization_config = quantization else: - raise ValueError( - f"Found unknown quantization='{quantization}' in config") + raise ValueError(f"Found unknown quantization='{quantization}' in config") config["quantization_config"] = quantization_config @@ -139,13 +134,10 @@ def _remap_mistral_audio_args(config: dict) -> dict: quant_config = config.get("quantization_config") config = { - "model_type": - "whixtral", + "model_type": "whixtral", "architectures": ["VoxtralForConditionalGeneration"], - "text_config": - PretrainedConfig.from_dict(config), - "audio_config": - WhisperConfig( + "text_config": PretrainedConfig.from_dict(config), + "audio_config": WhisperConfig( num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], window_size=encoder_args["audio_encoding_args"]["window_size"], sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], @@ -157,7 +149,8 @@ def _remap_mistral_audio_args(config: dict) -> dict: encoder_attention_heads=encoder_args["n_heads"], vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], - ) + is_encoder_decoder=False, # Override WhisperConfig default + ), } if quant_config: config["quantization_config"] = quant_config diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py index 2fa284e5c9e8f..45d76a8fdf264 100644 --- a/vllm/transformers_utils/configs/mlp_speculator.py +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -13,16 +13,18 @@ class MLPSpeculatorConfig(PretrainedConfig): "hidden_size": "emb_dim", } - def __init__(self, - vocab_size: int = 32000, - emb_dim: int = 4096, - inner_dim: int = 0, - n_predict: int = 3, - top_k_tokens_per_head: Optional[list[int]] = None, - n_candidates: int = 5, - tie_weights: bool = False, - scale_input: bool = False, - **kwargs): + def __init__( + self, + vocab_size: int = 32000, + emb_dim: int = 4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: Optional[list[int]] = None, + n_candidates: int = 5, + tie_weights: bool = False, + scale_input: bool = False, + **kwargs, + ): """ Initialize an MLPSpeculatorConfig diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py index a6f712f3d6005..6e9b2897f4cc7 100644 --- a/vllm/transformers_utils/configs/moonvit.py +++ 
b/vllm/transformers_utils/configs/moonvit.py @@ -8,16 +8,16 @@ class MoonViTConfig(PretrainedConfig): model_type = "moonvit" def __init__( - self, - patch_size: int = 14, - init_pos_emb_height: int = 64, - init_pos_emb_width: int = 64, - num_attention_heads: int = 16, - num_hidden_layers: int = 27, - hidden_size: int = 1152, - intermediate_size: int = 4304, - merge_kernel_size: tuple[int, int] = (2, 2), - **kwargs, + self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + **kwargs, ): super().__init__(**kwargs) self.patch_size = patch_size diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py index 9a7243b1262c0..60eed549561fb 100644 --- a/vllm/transformers_utils/configs/nemotron.py +++ b/vllm/transformers_utils/configs/nemotron.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) class NemotronConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a - [`NemotronModel`]. It is used to instantiate an Nemotron model + [`NemotronModel`]. It is used to instantiate a Nemotron model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Nemotron-8B. @@ -62,7 +62,7 @@ class NemotronConfig(PretrainedConfig): (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original - heads within that group. For more details checkout + heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. 
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): @@ -147,8 +147,9 @@ class NemotronConfig(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads head_dim = head_dim or kwargs.get("kv_channels") - self.head_dim = head_dim if head_dim is not None else ( - hidden_size // num_attention_heads) + self.head_dim = ( + head_dim if head_dim is not None else (hidden_size // num_attention_heads) + ) # for backward compatibility if num_key_value_heads is None: @@ -162,8 +163,11 @@ class NemotronConfig(PretrainedConfig): self.rope_theta = rope_theta self.rope_scaling = rope_scaling # for backward compatibility - partial_rotary_factor = kwargs.get("rope_percent") or kwargs.get( - "rope_percentage") or partial_rotary_factor + partial_rotary_factor = ( + kwargs.get("rope_percent") + or kwargs.get("rope_percentage") + or partial_rotary_factor + ) self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() self.attention_bias = attention_bias @@ -185,21 +189,24 @@ class NemotronConfig(PretrainedConfig): if self.rope_scaling is None: return - if not isinstance(self.rope_scaling, dict) or len( - self.rope_scaling) != 2: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with two fields, " - f"`type` and `factor`, got {self.rope_scaling}") + f"`type` and `factor`, got {self.rope_scaling}" + ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in [ - "linear", "dynamic" - ]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( "`rope_scaling`'s type field must be one of ['linear', " - f"'dynamic'], got {rope_scaling_type}") - if rope_scaling_factor is None or not isinstance( - rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + f"'dynamic'], got {rope_scaling_type}" + ) + if ( + rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0 + ): raise ValueError( "`rope_scaling`'s factor field must be a float > 1, got " - f"{rope_scaling_factor}") \ No newline at end of file + f"{rope_scaling_factor}" + ) diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 027f2911543f5..c8b6784d6a8ef 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -38,7 +38,7 @@ class NemotronHConfig(PretrainedConfig): passed when calling [`NemotronHModel`] tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be - tied. Note that this is only relevant if the model has a output + tied. Note that this is only relevant if the model has an output word embedding layer. hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. 
@@ -203,11 +203,11 @@ class NemotronHConfig(PretrainedConfig): # Validate hybrid_override_pattern # M: Mamba2, *: Attention, -: MLP assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( - "hybrid_override_pattern must have same length as " - "num_hidden_layers") + "hybrid_override_pattern must have same length as num_hidden_layers" + ) assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( - "hybrid_override_pattern must only contain characters " - "'M', '*', or '-'") + "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + ) # for backward compatibility if num_key_value_heads is None: @@ -253,7 +253,10 @@ class NemotronHConfig(PretrainedConfig): @property def layers_block_type(self): return [ - "mamba" if self.hybrid_override_pattern[i] == "M" else - "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + "mamba" + if self.hybrid_override_pattern[i] == "M" + else "attention" + if self.hybrid_override_pattern[i] == "*" + else "mlp" for i in range(self.num_hidden_layers) ] diff --git a/vllm/transformers_utils/configs/nemotron_vl.py b/vllm/transformers_utils/configs/nemotron_vl.py index 6a642f26b82a2..6f98fbafbed5f 100644 --- a/vllm/transformers_utils/configs/nemotron_vl.py +++ b/vllm/transformers_utils/configs/nemotron_vl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # Adapted from # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/configuration.py @@ -16,7 +15,7 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module class Nemotron_Nano_VL_Config(PretrainedConfig): - model_type = 'Llama_Nemotron_Nano_VL' + model_type = "Llama_Nemotron_Nano_VL" is_composition = True def __init__( @@ -26,17 +25,22 @@ class Nemotron_Nano_VL_Config(PretrainedConfig): force_image_size=None, downsample_ratio=0.5, template=None, - ps_version='v1', + ps_version="v1", image_tag_type="internvl", projector_hidden_size=4096, vit_hidden_size=1280, - **kwargs + **kwargs, ): super().__init__(**kwargs) if vision_config is not None: - assert "auto_map" in vision_config and "AutoConfig" in vision_config["auto_map"] - vision_auto_config = get_class_from_dynamic_module(*vision_config["auto_map"]["AutoConfig"].split("--")[::-1]) + assert ( + "auto_map" in vision_config + and "AutoConfig" in vision_config["auto_map"] + ) + vision_auto_config = get_class_from_dynamic_module( + *vision_config["auto_map"]["AutoConfig"].split("--")[::-1] + ) self.vision_config = vision_auto_config(**vision_config) else: self.vision_config = PretrainedConfig() @@ -51,6 +55,6 @@ class Nemotron_Nano_VL_Config(PretrainedConfig): self.downsample_ratio = downsample_ratio self.template = template # TODO move out of here and into the tokenizer self.ps_version = ps_version # Pixel shuffle version - self.image_tag_type = image_tag_type # TODO: into the tokenizer too? + self.image_tag_type = image_tag_type # TODO: into the tokenizer too? 
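A small sketch of the `hybrid_override_pattern` to `layers_block_type` mapping used by `NemotronHConfig.layers_block_type` above ("M" = Mamba2, "*" = attention, "-" = MLP); the pattern string is a made-up example:

```python
# Assumed example pattern; one character per decoder layer.
pattern = "M*-M*-"

layers_block_type = [
    "mamba" if c == "M" else "attention" if c == "*" else "mlp"
    for c in pattern
]
print(layers_block_type)
# ['mamba', 'attention', 'mlp', 'mamba', 'attention', 'mlp']
```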
self.projector_hidden_size = projector_hidden_size self.vit_hidden_size = vit_hidden_size diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py new file mode 100644 index 0000000000000..f5a9a7cd36bdb --- /dev/null +++ b/vllm/transformers_utils/configs/olmo3.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class Olmo3Config(PretrainedConfig): + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in vLLM. + if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ] diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index 550f5e15dbcc2..404fa700a26c0 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py # and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py @@ -70,34 +69,37 @@ class AIMv2Config(PretrainedConfig): # Visual Tokenizer Configuration # ---------------------------------------------------------------------- class BaseVisualTokenizerConfig(PretrainedConfig): - - def __init__(self, - vocab_size=16384, - tokenize_function="softmax", - tau=1.0, - depths=None, - drop_cls_token=False, - backbone_config: Optional[Union[PretrainedConfig, - dict]] = None, 
- hidden_stride: int = 1, - **kwargs): + def __init__( + self, + vocab_size=16384, + tokenize_function="softmax", + tau=1.0, + depths=None, + drop_cls_token=False, + backbone_config: Optional[Union[PretrainedConfig, dict]] = None, + hidden_stride: int = 1, + **kwargs, + ): super().__init__(**kwargs) self.vocab_size = vocab_size self.tokenize_function = tokenize_function self.tau = tau if isinstance(depths, str): - depths = [int(x) for x in depths.split('|')] + depths = [int(x) for x in depths.split("|")] self.depths = depths self.backbone_kwargs = dict[str, Any]() self.drop_cls_token = drop_cls_token if backbone_config is not None: - assert isinstance(backbone_config, (PretrainedConfig, dict)), \ + assert isinstance(backbone_config, (PretrainedConfig, dict)), ( f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + ) if not isinstance(backbone_config, PretrainedConfig): - model_type = backbone_config['model_type'] + model_type = backbone_config["model_type"] if model_type != "aimv2": - backbone_config.pop('model_type') - backbone_config = AutoConfig.for_model(model_type, **backbone_config) + backbone_config.pop("model_type") + backbone_config = AutoConfig.for_model( + model_type, **backbone_config + ) else: backbone_config = AIMv2Config(**backbone_config) self.backbone_config = backbone_config @@ -113,7 +115,7 @@ class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): self.drop_cls_token = False if self.depths: assert len(self.depths) == 1 - self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + self.backbone_kwargs["num_hidden_layers"] = self.depths[0] class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): @@ -125,7 +127,7 @@ class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): self.drop_cls_token = False if self.depths: assert len(self.depths) == 1 - self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + self.backbone_kwargs["num_hidden_layers"] = self.depths[0] AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) @@ -138,35 +140,39 @@ AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) class OvisConfig(PretrainedConfig): model_type = "ovis" - def __init__(self, - llm_config: Optional[Union[PretrainedConfig, dict]] = None, - visual_tokenizer_config: Optional[Union[PretrainedConfig, - dict]] = None, - multimodal_max_length=8192, - hidden_size=None, - conversation_formatter_class=None, - llm_attn_implementation=None, - disable_tie_weight=False, - **kwargs): + def __init__( + self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs, + ): super().__init__(**kwargs) if llm_config is not None: - assert isinstance(llm_config, (PretrainedConfig, dict)), \ + assert isinstance(llm_config, (PretrainedConfig, dict)), ( f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type" + ) if not isinstance(llm_config, PretrainedConfig): - model_type = llm_config['model_type'] - llm_config.pop('model_type') + model_type = llm_config["model_type"] + llm_config.pop("model_type") llm_config = AutoConfig.for_model(model_type, **llm_config) # map llm_config to text_config self.text_config = llm_config if visual_tokenizer_config is not None: - assert isinstance(visual_tokenizer_config, (PretrainedConfig, 
dict)), \ + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), ( f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + ) if not isinstance(visual_tokenizer_config, PretrainedConfig): - model_type = visual_tokenizer_config['model_type'] - visual_tokenizer_config.pop('model_type') + model_type = visual_tokenizer_config["model_type"] + visual_tokenizer_config.pop("model_type") visual_tokenizer_config = AutoConfig.for_model( - model_type, **visual_tokenizer_config) + model_type, **visual_tokenizer_config + ) self.visual_tokenizer_config = visual_tokenizer_config self.multimodal_max_length = multimodal_max_length diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py new file mode 100644 index 0000000000000..21750bde2f878 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3-Next model configuration""" + +from transformers.configuration_utils import PretrainedConfig, layer_type_validation +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + partial_rotary_factor (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 10): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 512): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). 
+ + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ # noqa: E501 + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + layer_types=None, + **kwargs, + ): + if mlp_only_layers is None: + mlp_only_layers = [] + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "linear_attention" if bool((i + 1) % 4) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + # MoE 
arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +__all__ = ["Qwen3NextConfig"] diff --git a/vllm/transformers_utils/configs/radio.py b/vllm/transformers_utils/configs/radio.py new file mode 100644 index 0000000000000..f13598034bae8 --- /dev/null +++ b/vllm/transformers_utils/configs/radio.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Radio vision model configuration""" + +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = { + "vit_small_patch16_224": (384, 12, 6, 1536), + "vit_base_patch16_224": (768, 12, 12, 3072), + "vit_large_patch16_224": (1024, 24, 16, 4096), + "vit_huge_patch16_224": (1280, 32, 16, 5120), +} + +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + + +class RadioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a Radio + vision model. It is used to instantiate a Radio model according to the + specified arguments, defining the model architecture. + + Args: + model_name: Name of the vision transformer model + (e.g., "vit_base_patch16_224"). Used to determine architecture + dimensions from `VIT_TIMM_DIM_BY_NAME`. + image_size: The size (resolution) of each image. + patch_size: The size (resolution) of each patch. + qkv_bias: Whether to add a bias to the queries, keys and values. + qk_normalization: Whether to apply normalization to queries and keys. + norm_type: The normalization type to use. + layer_norm_eps: The epsilon used by the layer normalization layers. + initializer_factor: A factor for initializing all weight matrices. + hidden_act: The non-linear activation function in the encoder. + max_img_size: Maximum image size for position embeddings. + norm_mean: Mean values for image normalization (RGB channels). + Defaults to (0.48145466, 0.4578275, 0.40821073)). + norm_std: Standard deviation values for image normalization + (RGB channels). Defaults to (0.26862954, 0.26130258, 0.27577711)). + reg_tokens: Number of register tokens to use. 
+ """ + + model_type = "radio" + + def __init__( + self, + model_name: str, + image_size: int = 224, + patch_size: int = 16, + qkv_bias: bool = True, + qk_normalization: bool = False, + norm_type: str = "layer_norm", + layer_norm_eps: float = 1e-6, + initializer_factor: float = 1.0, + hidden_act: str = "gelu", + max_img_size: int = 2048, + norm_mean: Union[tuple[float, float, float], list] = OPENAI_CLIP_MEAN, + norm_std: Union[tuple[float, float, float], list] = OPENAI_CLIP_STD, + reg_tokens: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + ( + self.hidden_size, + self.num_hidden_layers, + self.num_attention_heads, + self.intermediate_size, + ) = VIT_TIMM_DIM_BY_NAME[model_name] + self.image_size = image_size + self.patch_size = patch_size + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.norm_type = norm_type + self.layer_norm_eps = layer_norm_eps + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.max_img_size = max_img_size + self.norm_mean = ( + list(norm_mean) if isinstance(norm_mean, (tuple, list)) else norm_mean + ) + self.norm_std = ( + list(norm_std) if isinstance(norm_std, (tuple, list)) else norm_std + ) + self.reg_tokens = reg_tokens + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index efc87b6bcf26f..88bce3d4f79e9 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -5,7 +5,6 @@ SUPPORTED_SPECULATORS_TYPES = {} def register_speculator(name): - def decorator(fn): SUPPORTED_SPECULATORS_TYPES[name] = fn return fn @@ -17,16 +16,23 @@ def register_speculator(name): def update_eagle3(config_dict: dict, vllm_config: dict) -> None: """ Apply Eagle-3 specific configuration transformations. - + Eagle-3 specific fields: - draft_vocab_size: Size of the draft model's vocabulary - target_hidden_size: Hidden size of the target model - norm_before_residual: Whether to apply norm before residual connection + - eagle_aux_hidden_state_layer_ids: List of layer indices from the base + model to use as auxiliary inputs for the Eagle3 drafter. These layers + provide intermediate hidden states that help the drafter make better + predictions. This is the standard field used in Eagle3 checkpoints. 
""" vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") if config_dict.get("target_hidden_size") is not None: vllm_config["target_hidden_size"] = config_dict["target_hidden_size"] - vllm_config["norm_before_residual"] = config_dict.get( - "norm_before_residual", True) + vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True) vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] + if config_dict.get("eagle_aux_hidden_state_layer_ids"): + vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ + "eagle_aux_hidden_state_layer_ids" + ] diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index d7c16e180c709..1c415a43360ea 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -6,7 +6,8 @@ from typing import Any, Union from transformers import PretrainedConfig from vllm.transformers_utils.configs.speculators.algos import ( - SUPPORTED_SPECULATORS_TYPES) + SUPPORTED_SPECULATORS_TYPES, +) __all__ = ["SpeculatorsConfig"] @@ -21,24 +22,31 @@ class SpeculatorsConfig(PretrainedConfig): **kwargs, ) -> "SpeculatorsConfig": """Load speculators Eagle config and convert to vLLM format.""" - config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, - **kwargs) + config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + vllm_config = cls.extract_vllm_speculative_config(config_dict) + return cls(**vllm_config) + + @classmethod + def extract_vllm_speculative_config( + cls, config_dict: dict[str, Any] + ) -> dict[str, Any]: speculators_model_type = config_dict.get("speculators_model_type") if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: raise ValueError( f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. " - "Please ensure you're loading a speculators-format model.") + "Please ensure you're loading a speculators-format model." + ) # validate fields # TODO: @dsikka - use speculators pydantic model to validate cls.validate_speculators_config(config_dict=config_dict) # Convert from speculators config -> format that can be ingested by vLLM - vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict) + vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict) # Apply anything specific to the supported algorithm algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] algo_updater(config_dict=config_dict, vllm_config=vllm_config) - return cls(**vllm_config) + return vllm_config @classmethod def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: @@ -57,35 +65,50 @@ class SpeculatorsConfig(PretrainedConfig): if not isinstance(config_dict["transformer_layer_config"], dict): raise TypeError( - "'transformer_layer_config' must be a dictionary if provided") + "'transformer_layer_config' must be a dictionary if provided" + ) @classmethod - def convert_speculators_to_vllm( - cls, config_dict: dict[str, Any]) -> dict[str, Any]: + def build_vllm_speculative_config( + cls, config_dict: dict[str, Any] + ) -> dict[str, Any]: """ - Convert speculators config format to vLLM format. - - This method handles the translation of field names and structure - between speculators and vLLM formats. - + Build vLLM-compatible speculative configuration from speculators format. + + This method extracts and transforms speculative configuration from the + speculators format into the structure expected by vLLM. 
+ + Args: + config_dict: Configuration dictionary in speculators format + Returns: - Dictionary with vLLM-compatible configuration + Dictionary with vLLM-compatible speculative configuration """ - # Currently we only support one proposal method + # Extract speculators configuration spec_config = config_dict["speculators_config"] - first_method = spec_config.get("proposal_methods")[0] - num_lookahead_tokens = first_method.get("speculative_tokens") - if num_lookahead_tokens is None: + # Currently we only support one proposal method + proposal_methods = spec_config.get("proposal_methods") + if not proposal_methods: + raise ValueError("No proposal methods found in speculators config") + + first_method = proposal_methods[0] + num_speculative_tokens = first_method.get("speculative_tokens") + + if num_speculative_tokens is None: raise ValueError( - "Missing 'speculative_tokens' in proposal method. " - f"Got: {first_method}") + f"Missing 'speculative_tokens' in proposal method. Got: {first_method}" + ) - # Build base vLLM config + # Build base vLLM speculative configuration vllm_config = { "method": config_dict.get("speculators_model_type"), - "num_lookahead_tokens": num_lookahead_tokens, - "target_model": spec_config.get("verifier")["name_or_path"] + "num_speculative_tokens": num_speculative_tokens, + "target_model": spec_config.get("verifier")["name_or_path"], } - vllm_config.update(config_dict["transformer_layer_config"]) + + # Merge transformer layer configuration if present + transformer_config = config_dict.get("transformer_layer_config", {}) + vllm_config.update(transformer_config) + return vllm_config diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py index fe3c72de69d28..36d39e828a93b 100644 --- a/vllm/transformers_utils/configs/step3_vl.py +++ b/vllm/transformers_utils/configs/step3_vl.py @@ -59,13 +59,64 @@ class Step3TextConfig(PretrainedConfig): share_q_dim: int = 2048, head_dim: int = 256, norm_expert_weight: bool = False, - moe_layers_enum: tuple[int, - ...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59), + moe_layers_enum: tuple[int, ...] = ( + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + ), **kwargs, ) -> None: self.hidden_size = hidden_size diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 87064cc12deda..ac22304e91250 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig): Args: audio_config (`Union[AutoConfig, dict]`, *optional*): - Custom audio config or dict + Custom audio config or dict. text_config (`Union[AutoConfig, dict]`, *optional*): - The config object of the text backbone. Can be any of `LlamaConfig` - or `MistralConfig`. + The config object of the text backbone. + audio_model_id (`str`, *optional*): + The model ID of the audio backbone. + text_model_id (`str`, *optional*): + The model ID of the text backbone. 
ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. audio_token_index (`int`, *optional*, defaults to 32000): @@ -34,16 +37,13 @@ class UltravoxConfig(transformers.PretrainedConfig): The initialization value for the layer normalization. projector_act (`str`, *optional*, defaults to `"swiglu"`): The activation function used by the multimodal projector. - text_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the text model. - audio_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the audio model. projector_ln_mid (`bool`, *optional*, defaults to `False`): Whether to apply layer normalization at the middle of the projector or at the end. Versions v0.4.1 and below use `False`, but v0.5 and above use `True`. """ + wrapped_model_config: transformers.PretrainedConfig model_type = "ultravox" audio_token = "<|audio|>" is_composition = False @@ -60,15 +60,10 @@ class UltravoxConfig(transformers.PretrainedConfig): stack_factor: int = 8, norm_init: float = 0.4, projector_act: str = "swiglu", - text_model_lora_config: Optional[dict[str, Any]] = None, - audio_model_lora_config: Optional[dict[str, Any]] = None, projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index - - self.audio_model_id = audio_model_id - self.text_model_id = text_model_id self.audio_token_index = audio_token_index self.hidden_size = hidden_size @@ -77,36 +72,47 @@ class UltravoxConfig(transformers.PretrainedConfig): self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid - if text_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - text_config_obj = get_config(text_model_id, - trust_remote_code=False) - else: + # N.B. May set the wrapped_model_config below. + self.text_model_id = text_model_id + if text_model_id is None: text_config = text_config or {} - text_config_obj = transformers.CONFIG_MAPPING[text_config.get( - "model_type", "llama")](**text_config) + self.wrapped_model_config = transformers.CONFIG_MAPPING[ + text_config.get("model_type", "llama") + ](**text_config) - inner_text_config = text_config_obj.get_text_config() - - if audio_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - audio_config = get_config(audio_model_id, trust_remote_code=False) - else: + # N.B. May set the audio_config below. + self.audio_model_id = audio_model_id + if audio_model_id is None: + self.audio_model_id = None audio_config = audio_config or {} - audio_config = transformers.CONFIG_MAPPING[audio_config.get( - "model_type", "whisper")](**audio_config) - - self.text_config = text_config_obj - self.audio_config = audio_config - self.text_model_lora_config = text_model_lora_config or {} - self.audio_model_lora_config = audio_model_lora_config or {} - - self.vocab_size = inner_text_config.vocab_size - self.initializer_range = inner_text_config.initializer_range - self.text_hidden_size = inner_text_config.hidden_size + self.audio_config = transformers.CONFIG_MAPPING[ + audio_config.get("model_type", "whisper") + ](**audio_config) super().__init__(**kwargs) + + def __setattr__(self, key, value): + # Since --hf-overrides are applied _after_ the UltravoxConfig is + # instantiated, load the configs implicitly when assigning text_model_id + # or audio_model_id. This allows: + # + # --hf-overrides.text_model_id=<quantized variant> + # + # to behave as intended. 
+ if key == "text_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.wrapped_model_config = get_config(value, trust_remote_code=False) + elif key == "audio_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(value, trust_remote_code=False) + + return super().__setattr__(key, value) + + @property + def text_config(self) -> transformers.PretrainedConfig: + # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate + # the full model, but the text config is the text config of the inner + # model. + return self.wrapped_model_config.get_text_config() diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py deleted file mode 100644 index 380c62a141f0f..0000000000000 --- a/vllm/transformers_utils/detokenizer.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, - Sequence, SequenceGroup) - -from .detokenizer_utils import (convert_prompt_ids_to_tokens, - detokenize_incrementally) -from .tokenizer import AnyTokenizer -from .tokenizer_group import TokenizerGroup - - -class Detokenizer: - """Provides methods to decode the output of a model into text.""" - - def __init__(self, tokenizer_group: TokenizerGroup): - self.tokenizer_group = tokenizer_group - - def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: - """Returns the HF tokenizer to use for a given sequence.""" - return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) - - def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, - prompt_logprobs: list[Optional[dict[ - int, Logprob]]], - position_offset: int) -> None: - """Decodes the logprobs for the prompt of a sequence group. - - Args: - seq_group: The sequence group to decode. - prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs - relative to the start of the sequence (for chunked prefill). - - Returns: - The prompt logprobs with the decoded tokens. - """ - prms = seq_group.sampling_params - assert prms is not None - - # We can pick any sequence for the prompt. - seq = seq_group.get_seqs()[0] - # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() - prompt_token_ids = all_token_ids[:-1] - tokenizer = self.get_tokenizer_for_seq(seq) - prefix_offset = 0 - read_offset = 0 - next_iter_prefix_offset = 0 - next_iter_read_offset = 0 - next_iter_tokens: list[str] = [] - prev_tokens = None - - for token_position_in_logprob, prompt_logprobs_for_token in enumerate( - prompt_logprobs): - - # Absolute token position equals the index in the logprobs - # list plus the offset of the entire logprobs list relative - # to the start of the sequence. 
- token_position = token_position_in_logprob + position_offset - if not prompt_logprobs_for_token: - continue - for token_id, sample_logprob in prompt_logprobs_for_token.items(): - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - prompt_token_ids_with_token = ( - prompt_token_ids[:token_position] + [token_id]) - (new_tokens, new_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=prompt_token_ids_with_token, - prev_tokens=prev_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - - sample_logprob.decoded_token = new_text - - # Use the offsets & prev tokens corresponding to - # real tokens to ensure detokenization is consistent - # actual with prompt. - if token_id == all_token_ids[token_position]: - next_iter_prefix_offset = new_prefix_offset - next_iter_read_offset = new_read_offset - next_iter_tokens = new_tokens - - # Advance to the next token position. - prefix_offset = next_iter_prefix_offset - read_offset = next_iter_read_offset - if prev_tokens is None: - prev_tokens = next_iter_tokens.copy() - else: - prev_tokens.extend(next_iter_tokens) - - def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> int: - """Decodes the new token for a sequence. In-place operation. - - Args: - seq: The sequence to decode. - prms: The sampling parameters used to generate the sequence. - - Returns: - The number of characters added to the output text. - """ - all_input_ids = seq.get_token_ids() - token_id_generated_this_iteration = all_input_ids[-1] - tokenizer = self.get_tokenizer_for_seq(seq) - - # Convert prompt token IDs to tokens if necessary. - # Do it here so that we don't have to repeat this - # computation for each logprob. - if seq.tokens is None: - (seq.tokens, seq.prefix_offset, - seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=all_input_ids[:-1], - skip_special_tokens=prms.skip_special_tokens, - ) - - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=all_input_ids, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - - # Decode logprobs - logprobs = seq.output_logprobs[-1] - if logprobs: - previous_tokens = all_input_ids[:-1] - for token_id, sample_logprob in logprobs.items(): - # If the token was generated this iteration, - # use the provided text. - if token_id == token_id_generated_this_iteration: - sample_logprob.decoded_token = new_decoded_token_text - continue - - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - all_input_ids_with_logprob = previous_tokens + [token_id] - (_, new_text, _, _) = detokenize_incrementally( - tokenizer=tokenizer, - all_input_ids=all_input_ids_with_logprob, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. 
- spaces_between_special_tokens, - ) - sample_logprob.decoded_token = new_text - - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_decoded_token_text - - return len(new_decoded_token_text) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 101f31d39cc1f..60742ae97d5d1 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -30,8 +30,9 @@ def _convert_tokens_to_string_with_added_encoders( current_sub_text: list[str] = [] convert_tokens_to_string = tokenizer.convert_tokens_to_string added_vocab_set = set(tokenizer.get_added_vocab()) - all_special_tokens = set( - tokenizer.all_special_tokens) if skip_special_tokens else () + all_special_tokens = ( + set(tokenizer.all_special_tokens) if skip_special_tokens else () + ) for token in output_tokens: # Use precomputed set for skip-special check @@ -70,11 +71,11 @@ def convert_prompt_ids_to_tokens( # We do not need to convert the whole prompt to tokens. # Offset a little more in case we have special tokens. new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], - skip_special_tokens=skip_special_tokens) + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2 :], + skip_special_tokens=skip_special_tokens, + ) read_offset = len(new_tokens) - prefix_offset = max( - read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + prefix_offset = max(read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) # This is required to guard against out-of-vocab prompt token ids _replace_none_with_empty(new_tokens) # type: ignore[arg-type] return new_tokens, prefix_offset, read_offset @@ -92,7 +93,7 @@ def convert_ids_list_to_tokens( Returns: Python list of token string representations - + """ token_str_lst = [] for token_id in token_ids: @@ -144,18 +145,17 @@ def detokenize_incrementally( # This is the first iteration for this sequence is_first_iter = prev_tokens is None if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) + (prev_tokens, prefix_offset, read_offset) = convert_prompt_ids_to_tokens( + tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens + ) assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) + [new_token_id], skip_special_tokens=skip_special_tokens + ) if isinstance(new_tokens, str): new_tokens = [new_tokens] else: @@ -171,9 +171,9 @@ def detokenize_incrementally( # surrounding ids. 
if tokenizer.is_fast or not tokenizer.get_added_vocab(): prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) + output_tokens[prefix_offset:read_offset] + ) + new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:]) else: prefix_text = _convert_tokens_to_string_with_added_encoders( tokenizer, @@ -195,5 +195,5 @@ def detokenize_incrementally( # by the model return new_tokens, "", prefix_offset, read_offset - new_text = new_text[len(prefix_text):] + new_text = new_text[len(prefix_text) :] return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/dynamic_module.py b/vllm/transformers_utils/dynamic_module.py index 05191f95216ce..3c273ad41da00 100644 --- a/vllm/transformers_utils/dynamic_module.py +++ b/vllm/transformers_utils/dynamic_module.py @@ -27,7 +27,7 @@ def try_get_class_from_dynamic_module( **kwargs, ) -> Optional[type]: """ - As [transformers.dynamic_module_utils.get_class_from_dynamic_module][], + As `transformers.dynamic_module_utils.get_class_from_dynamic_module`, but ignoring any errors. """ try: diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index a630d940b2578..81f9b76b5ef7a 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -4,11 +4,16 @@ from functools import lru_cache from typing import TYPE_CHECKING, Any, Optional, Union, cast -from transformers import (AutoFeatureExtractor, AutoImageProcessor, - AutoProcessor) +from transformers import ( + AutoFeatureExtractor, + AutoImageProcessor, + AutoProcessor, + AutoVideoProcessor, +) from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.image_processing_utils import BaseImageProcessor from transformers.processing_utils import ProcessorMixin +from transformers.video_processing_utils import BaseVideoProcessor from typing_extensions import TypeVar from vllm.utils import get_allowed_kwarg_only_overrides @@ -17,6 +22,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) +_V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor) class HashableDict(dict): @@ -119,15 +125,18 @@ def get_processor( "a custom processor not yet available in the HuggingFace " "transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e if not isinstance(processor, processor_cls): - raise TypeError("Invalid type of HuggingFace processor. " - f"Expected type: {processor_cls}, but " - f"found type: {type(processor)}") + raise TypeError( + "Invalid type of HuggingFace processor. 
" + f"Expected type: {processor_cls}, but " + f"found type: {type(processor)}" + ) return processor @@ -156,7 +165,7 @@ def get_feature_extractor( trust_remote_code: bool = False, **kwargs: Any, ): - """Load an audio feature extractor for the given model name + """Load an audio feature extractor for the given model name via HuggingFace.""" try: feature_extractor = AutoFeatureExtractor.from_pretrained( @@ -164,7 +173,8 @@ def get_feature_extractor( *args, revision=revision, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. @@ -175,7 +185,8 @@ def get_feature_extractor( "extractor is a custom extractor not yet available in the " "HuggingFace transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -211,7 +222,8 @@ def get_image_processor( *args, revision=revision, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. @@ -222,7 +234,8 @@ def get_image_processor( "a custom processor not yet available in the HuggingFace " "transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -243,3 +256,57 @@ def cached_image_processor_from_config( trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs), ) + + +def get_video_processor( + processor_name: str, + *args: Any, + revision: Optional[str] = None, + trust_remote_code: bool = False, + processor_cls_overrides: Optional[type[_V]] = None, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + try: + processor_cls = processor_cls_overrides or AutoVideoProcessor + processor = processor_cls.from_pretrained( + processor_name, + *args, + revision=revision, + trust_remote_code=trust_remote_code, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoVideoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the video processor. If the video processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI." 
+ ) + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseVideoProcessor, processor) + + +cached_get_video_processor = lru_cache(get_video_processor) + + +def cached_video_processor_from_config( + model_config: "ModelConfig", + processor_cls: Optional[type[_V]] = None, + **kwargs: Any, +): + return cached_get_video_processor( + model_config.model, + revision=model_config.revision, + trust_remote_code=model_config.trust_remote_code, + processor_cls_overrides=processor_cls, # type: ignore[arg-type] + **_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs), + ) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 8a1ad226d99f0..76b6d3dc9c99a 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -8,8 +8,7 @@ reasons: - There is a need to override the existing processor to support vLLM. """ -from vllm.transformers_utils.processors.deepseek_vl2 import ( - DeepseekVLV2Processor) +from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 5896bde312657..5ef258b9be298 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py @@ -25,6 +24,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import math +from typing import Any import torch import torchvision.transforms as T @@ -34,11 +34,12 @@ from transformers.processing_utils import ProcessorMixin class ImageTransform: - - def __init__(self, - mean: tuple[float, float, float] = (0.5, 0.5, 0.5), - std: tuple[float, float, float] = (0.5, 0.5, 0.5), - normalize: bool = True): + def __init__( + self, + mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + ): self.mean = mean self.std = std self.normalize = normalize @@ -76,7 +77,6 @@ class DeepseekVLV2Processor(ProcessorMixin): ignore_id: int = -100, **kwargs, ): - self.candidate_resolutions = candidate_resolutions self.image_size = candidate_resolutions[0][0] self.patch_size = patch_size @@ -85,13 +85,15 @@ class DeepseekVLV2Processor(ProcessorMixin): self.normalize = normalize self.downsample_ratio = downsample_ratio - self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize) + self.image_transform = ImageTransform( + mean=image_mean, std=image_std, normalize=normalize + ) self.tokenizer = tokenizer - self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference + self.tokenizer.padding_side = "left" # must set this,padding side with make a difference in batch inference # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id' if tokenizer.pad_token is None: - self.tokenizer.add_special_tokens({'pad_token': pad_token}) + self.tokenizer.add_special_tokens({"pad_token": pad_token}) # add image token image_token_id = self.tokenizer.vocab.get(image_token) @@ -103,7 +105,7 @@ class DeepseekVLV2Processor(ProcessorMixin): # add five special tokens for grounding-related tasks # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|> - special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>'] + special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"] special_tokens_dict = {"additional_special_tokens": special_tokens} self.tokenizer.add_special_tokens(special_tokens_dict) @@ -133,15 +135,19 @@ class DeepseekVLV2Processor(ProcessorMixin): for width, height in self.candidate_resolutions: scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int( - original_width * scale), int(original_height * scale) - effective_resolution = min(downscaled_width * downscaled_height, - original_width * original_height) + downscaled_width, downscaled_height = ( + int(original_width * scale), + int(original_height * scale), + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( - effective_resolution == max_effective_resolution - and wasted_resolution < min_wasted_resolution): + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) @@ -178,17 +184,15 @@ class DeepseekVLV2Processor(ProcessorMixin): prompt: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ Args: prompt (str): the formatted prompt; - conversations (list[dict]): conversations with a list of messages; images (list[ImageType]): the list of images; inference_mode (bool): if True, then remove 
the last eos token; - system_prompt (str): the system prompt; - **kwargs: + **kwargs: Additional keyword arguments. Returns: outputs (BaseProcessorOutput): the output of the processor, @@ -199,12 +203,20 @@ class DeepseekVLV2Processor(ProcessorMixin): - num_image_tokens (list[int]): the number of image tokens """ - assert (prompt is not None and images is not None - ), "prompt and images must be used at the same time." + assert prompt is not None and images is not None, ( + "prompt and images must be used at the same time." + ) sft_format = prompt - tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images( - sft_format, images, bos=True, eos=True, cropping=len(images) <= 2) + ( + tokenized_str, + images_list, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + ) = self.tokenize_with_images( + sft_format, images, bos=True, eos=True, cropping=len(images) <= 2 + ) masked_tokenized_str = [] for token_index in tokenized_str: if token_index != self.image_token_id: @@ -212,17 +224,21 @@ class DeepseekVLV2Processor(ProcessorMixin): else: masked_tokenized_str.append(self.ignore_id) - assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \ - (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " - f"imags_seq_mask's length {len(images_seq_mask)}, are not equal") + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal" + ) input_ids = torch.LongTensor(tokenized_str) target_ids = torch.LongTensor(masked_tokenized_str) images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) # set input_ids < 0 | input_ids == self.image_token_id as ignore_id - target_ids[(input_ids < 0) | - (input_ids == self.image_token_id)] = self.ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) input_ids[input_ids < 0] = self.pad_id if inference_mode: @@ -259,7 +275,7 @@ class DeepseekVLV2Processor(ProcessorMixin): text: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ @@ -312,30 +328,50 @@ class DeepseekVLV2Processor(ProcessorMixin): best_width, best_height = self.image_size, self.image_size """process the global view""" - global_view = ImageOps.pad(image, (self.image_size, self.image_size), - color=tuple(int(x * 255) for x in self.image_transform.mean)) + global_view = ImageOps.pad( + image, + (self.image_size, self.image_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) images_list.append(self.image_transform(global_view)) """process the local views""" - local_view = ImageOps.pad(image, (best_width, best_height), - color=tuple(int(x * 255) for x in self.image_transform.mean)) + local_view = ImageOps.pad( + image, + (best_width, best_height), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) for i in range(0, best_height, self.image_size): for j in range(0, best_width, self.image_size): images_list.append( - self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size)))) + self.image_transform( + local_view.crop( + (j, i, j + self.image_size, i + self.image_size) + ) + ) + ) """record height / width crop num""" - num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size + num_width_tiles, 
num_height_tiles = ( + best_width // self.image_size, + best_height // self.image_size, + ) images_spatial_crop.append([num_width_tiles, num_height_tiles]) """add image tokens""" - h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) + h = w = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) # global views tokens h * (w + 1), 1 is for line separator tokenized_image = [self.image_token_id] * h * (w + 1) # add a separator between global and local views tokenized_image += [self.image_token_id] # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) - tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1) + tokenized_image += ( + [self.image_token_id] + * (num_height_tiles * h) + * (num_width_tiles * w + 1) + ) tokenized_str += tokenized_image images_seq_mask += [True] * len(tokenized_image) @@ -354,10 +390,17 @@ class DeepseekVLV2Processor(ProcessorMixin): tokenized_str = tokenized_str + [self.eos_id] images_seq_mask = images_seq_mask + [False] - assert len(tokenized_str) == len( - images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + assert len(tokenized_str) == len(images_seq_mask), ( + f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + ) - return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens + return ( + tokenized_str, + images_list, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + ) AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor) diff --git a/vllm/transformers_utils/processors/ovis.py b/vllm/transformers_utils/processors/ovis.py index 557d251c45f3b..58c1b1a91658b 100644 --- a/vllm/transformers_utils/processors/ovis.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # adapted from https://github.com/AIDC-AI/Ovis/blob/35ab51a1a1e3542fa6db260a1084cefbc8f164bb/ovis/vllm/processing_ovis.py @@ -30,32 +29,32 @@ import PIL import torch from transformers import AutoProcessor, BatchFeature from transformers.image_utils import ImageInput -from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, - Unpack) +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.multimodal.image import convert_image_mode -__all__ = ['OvisProcessor'] +__all__ = ["OvisProcessor"] IGNORE_ID = -100 -class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] + +class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] _defaults = { "text_kwargs": { "padding": False, }, "images_kwargs": { - 'max_partition':9, - 'covering_threshold':0.9, - 'convert_to_rgb':True, - 'return_tensors':'pt'}, + "max_partition": 9, + "covering_threshold": 0.9, + "convert_to_rgb": True, + "return_tensors": "pt", + }, } - class OvisProcessor(ProcessorMixin): r""" - Constructs a Ovis processor which wraps a Ovis image processor and a Qwen2 tokenizer into a single processor. + Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor. 
[`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. Args: @@ -98,14 +97,16 @@ class OvisProcessor(ProcessorMixin): "image_col_sep": -303, "image_row_sep": -304, "image_end": -305, - 'image_pad': image_pad_token_id, + "image_pad": image_pad_token_id, } return extra_special_tokens def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, **kwargs: Unpack[OvisProcessorKwargs], ) -> BatchFeature: """ @@ -170,7 +171,6 @@ class OvisProcessor(ProcessorMixin): # Process text input if text is not None: - if not isinstance(text, list): text = [text] @@ -179,7 +179,10 @@ class OvisProcessor(ProcessorMixin): replaced_ids_list = [] idx = 0 for ids_tensor in tokenized_batched_text: - if image_token_id in ids_tensor and "image_placeholders" in image_features: + if ( + image_token_id in ids_tensor + and "image_placeholders" in image_features + ): if idx < len(image_features["image_placeholders"]): # Converts in list for ease of use ids_list = ids_tensor.tolist() @@ -189,7 +192,9 @@ class OvisProcessor(ProcessorMixin): # replace placeholders for i, token_id in enumerate(ids_list): if token_id == image_token_id: - placeholder_ids = image_features["image_placeholders"][idx] + placeholder_ids = image_features["image_placeholders"][ + idx + ] new_ids.extend(placeholder_ids) idx += 1 else: @@ -199,7 +204,8 @@ class OvisProcessor(ProcessorMixin): ids_tensor = torch.tensor(new_ids, dtype=torch.long) else: raise RuntimeError( - 'Mismatch between the images you provided and the number of placeholder present in the text') + "Mismatch between the images you provided and the number of placeholder present in the text" + ) replaced_ids_list.append(ids_tensor) @@ -218,7 +224,7 @@ class OvisProcessor(ProcessorMixin): # Add image features if present if image_features: output["pixel_values"] = processed_images - output['grids'] = grids + output["grids"] = grids return output @@ -228,8 +234,10 @@ class OvisProcessor(ProcessorMixin): def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor: batch_token_ids = [] for text in text_list: - text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in - text.split(self.image_token)] + text_chunks = [ + self.tokenizer(chunk, add_special_tokens=False).input_ids + for chunk in text.split(self.image_token) + ] token_ids = [] num_chuck = len(text_chunks) for i, chunk in enumerate(text_chunks): @@ -241,50 +249,60 @@ class OvisProcessor(ProcessorMixin): def get_image_size(self): size = self.image_processor.size - if 'shortest_edge' in size: - width = height = size['shortest_edge'] + if "shortest_edge" in size: + width = height = size["shortest_edge"] elif "height" in size and "width" in size: - width = size['width'] - height = size['height'] + width = size["width"] + height = size["height"] else: - raise ValueError( "Can't parse image size from image_processor config.") + raise ValueError("Can't parse image size from image_processor config.") return height, width def get_token_value(self, tok): return self.extra_special_tokens[tok] def construct_image_indicators(self, grid): - image_placeholders = [self.get_token_value('image_start'), - self.get_token_value('image_atom'), - self.get_token_value('image_prefix')] + 
image_placeholders = [ + self.get_token_value("image_start"), + self.get_token_value("image_atom"), + self.get_token_value("image_prefix"), + ] if grid[0] * grid[1] > 1: for r in range(grid[0]): for c in range(grid[1]): - image_placeholders.append(self.get_token_value('image_atom') ) + image_placeholders.append(self.get_token_value("image_atom")) if c < grid[1] - 1: - image_placeholders.append(self.get_token_value('image_col_sep')) + image_placeholders.append(self.get_token_value("image_col_sep")) if r < grid[0] - 1: - image_placeholders.append(self.get_token_value('image_row_sep')) - image_placeholders.append(self.get_token_value('image_end')) + image_placeholders.append(self.get_token_value("image_row_sep")) + image_placeholders.append(self.get_token_value("image_end")) return image_placeholders def construct_image_placeholders(self, grid): - image_placeholders = self.construct_image_indicators(grid) - image_atom_token_id = self.get_token_value('image_atom') + image_atom_token_id = self.get_token_value("image_atom") # Extract the padding token ID from tokenizer - image_padding_token_id = self.get_token_value('image_pad') + image_padding_token_id = self.get_token_value("image_pad") # Create a new list with padding tokens inserted padded_placeholder_tokens = [] for token in image_placeholders: padded_placeholder_tokens.append(image_padding_token_id) if token == image_atom_token_id: - padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len) + padded_placeholder_tokens.extend( + [image_padding_token_id] * self.image_segment_len + ) return padded_placeholder_tokens - def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors): + def preprocess_image( + self, + image: PIL.Image.Image, + max_partition, + covering_threshold, + convert_to_rgb, + return_tensors, + ): def _preprocess(img: PIL.Image.Image, side): # first resize and preprocess w, h = img.size @@ -297,19 +315,27 @@ class OvisProcessor(ProcessorMixin): new_height = side new_width = int(w / h * new_height) new_size = dict(height=new_height, width=new_width) - pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values'] + pixel_values = self.image_processor.preprocess( + img, size=new_size, return_tensors=return_tensors + )["pixel_values"] # then pad to square - square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device) + square_values = torch.zeros( + [1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device + ) new_height, new_width = pixel_values.shape[2:] if new_height == new_width: square_values[:, :, :, :] = pixel_values elif new_height > new_width: from_index = (side - new_width) // 2 - square_values[:, :, :, from_index:from_index + new_width] = pixel_values + square_values[:, :, :, from_index : from_index + new_width] = ( + pixel_values + ) else: from_index = (side - new_height) // 2 - square_values[:, :, from_index:from_index + new_height, :] = pixel_values + square_values[:, :, from_index : from_index + new_height, :] = ( + pixel_values + ) return square_values @@ -351,7 +377,9 @@ class OvisProcessor(ProcessorMixin): good_grids = [] for grid in candidate_grids: partition = _partition(img, grid) - covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area + covering_ratio = ( + sum([_covering_area(*p, side) for p in partition]) / img_area + ) assert covering_ratio <= 1.0 all_grids.append((grid, covering_ratio)) if 
covering_ratio > covering_threshold: @@ -359,18 +387,19 @@ class OvisProcessor(ProcessorMixin): if len(good_grids) > 0: # pick the good partition with minimum #sub_images and break the tie using covering_ratio - return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0] + return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][ + 0 + ] else: # pick the partition with maximum covering_ratio and break the tie using #sub_images return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] if convert_to_rgb: - image = convert_image_mode(image, 'RGB') - + image = convert_image_mode(image, "RGB") sides = self.get_image_size() if sides[0] != sides[1]: - raise ValueError('get_image_size() returns non-square size') + raise ValueError("get_image_size() returns non-square size") side = sides[0] grid = _get_best_grid(image, side) partition = _partition(image, grid) @@ -379,7 +408,7 @@ class OvisProcessor(ProcessorMixin): crops.insert(0, image) pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0) image_placeholders = self.construct_image_placeholders(grid) - return pixel_values, image_placeholders, grid + return torch.tensor(pixel_values), image_placeholders, torch.tensor(grid) def batch_decode(self, *args, **kwargs): """ @@ -406,14 +435,18 @@ class OvisProcessor(ProcessorMixin): `list[str]`: The decoded text. """ return self.tokenizer.batch_decode( - generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, ) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names - names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + names_from_processor = list( + dict.fromkeys(tokenizer_input_names + image_processor_input_names) + ) return names_from_processor + ["second_per_grid_ts"] diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py index d3273257ff8c2..bacc58c78b3f6 100644 --- a/vllm/transformers_utils/processors/ovis2_5.py +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -9,42 +9,40 @@ import PIL import torch from transformers import AutoProcessor, BatchFeature from transformers.image_utils import ImageInput -from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, - Unpack) +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -__all__ = ['Ovis2_5Processor'] +__all__ = ["Ovis2_5Processor"] IMAGE_TOKEN = "<image>" VIDEO_TOKEN = "<video>" MIN_PIXELS = 448 * 448 MAX_PIXELS = 1792 * 1792 -class Ovis2_5ProcessorKwargs(ProcessingKwargs, - total=False): # type: ignore[call-arg] +class Ovis2_5ProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] _defaults = { "text_kwargs": { "padding": False, }, "images_kwargs": { - 'convert_to_rgb': True, - 'min_pixels': MIN_PIXELS, - 'max_pixels': MAX_PIXELS, + "convert_to_rgb": True, + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, }, "videos_kwargs": { - 'convert_to_rgb': True, - 'min_pixels': MIN_PIXELS, - 'max_pixels': MAX_PIXELS, - } + "convert_to_rgb": True, + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, + }, } class Ovis2_5Processor(ProcessorMixin): r""" - Constructs a Ovis processor which wraps a Ovis image processor + Constructs an 
Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor. - [`OvisProcessor`] offers all the functionalities of - [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. + [`OvisProcessor`] offers all the functionalities of + [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. Args: @@ -81,9 +79,7 @@ class Ovis2_5Processor(ProcessorMixin): self.patch_size = patch_size self.hidden_stride = hidden_stride self.temporal_patch_size = temporal_patch_size - super().__init__(image_processor, - tokenizer, - chat_template=chat_template) + super().__init__(image_processor, tokenizer, chat_template=chat_template) @cached_property def extra_special_tokens(self): @@ -96,7 +92,7 @@ class Ovis2_5Processor(ProcessorMixin): "image_end": -302, "video_start": -303, "video_end": -304, - 'image_pad': image_pad_token_id, + "image_pad": image_pad_token_id, } return extra_special_tokens @@ -104,8 +100,9 @@ class Ovis2_5Processor(ProcessorMixin): self, images: ImageInput = None, videos: Union[np.ndarray, list[ImageInput]] = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], - list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, **kwargs: Unpack[Ovis2_5ProcessorKwargs], ) -> BatchFeature: """ @@ -148,9 +145,9 @@ class Ovis2_5Processor(ProcessorMixin): [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- list of token ids to be fed to a model. Returned when `text` is not `None`. - - **attention_mask** -- list of indices specifying which tokens + - **attention_mask** -- list of indices specifying which tokens should be attended to by the model (when - `return_attention_mask=True` or if *"attention_mask"* + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
@@ -177,9 +174,9 @@ class Ovis2_5Processor(ProcessorMixin): grids = [] # Process each image for image in images if isinstance(images, list) else [images]: - pixel_values, image_placeholders, grid = ( - self.preprocess_multidata( - images=image, **output_kwargs["images_kwargs"])) + pixel_values, image_placeholders, grid = self.preprocess_multidata( + images=image, **output_kwargs["images_kwargs"] + ) processed_images.append(pixel_values) image_placeholders_list.append(image_placeholders) grids.append(grid) @@ -196,16 +193,15 @@ class Ovis2_5Processor(ProcessorMixin): grids = [] # Process each video for video in videos if isinstance(videos, list) else [videos]: - pixel_values, video_placeholders, grid = ( - self.preprocess_multidata( - video=video, **output_kwargs["videos_kwargs"])) + pixel_values, video_placeholders, grid = self.preprocess_multidata( + video=video, **output_kwargs["videos_kwargs"] + ) processed_videos.append(pixel_values) videos_placeholders_list.append(video_placeholders) grids.append(grid) # assign all processed videos if processed_videos: - visual_features[ - "video_placeholders"] = videos_placeholders_list + visual_features["video_placeholders"] = videos_placeholders_list output["video_pixel_values"] = processed_videos output["video_grids"] = grids @@ -220,14 +216,16 @@ class Ovis2_5Processor(ProcessorMixin): image_idx = 0 video_idx = 0 for ids_tensor in tokenized_batched_text: - has_image_tokens = (image_token_id in ids_tensor - and "image_placeholders" in visual_features - and image_idx < len( - visual_features["image_placeholders"])) - has_video_tokens = (video_token_id in ids_tensor - and "video_placeholders" in visual_features - and video_idx < len( - visual_features["video_placeholders"])) + has_image_tokens = ( + image_token_id in ids_tensor + and "image_placeholders" in visual_features + and image_idx < len(visual_features["image_placeholders"]) + ) + has_video_tokens = ( + video_token_id in ids_tensor + and "video_placeholders" in visual_features + and video_idx < len(visual_features["video_placeholders"]) + ) if has_image_tokens or has_video_tokens: # Convert to list for easier manipulation ids_list = ids_tensor.tolist() @@ -237,13 +235,13 @@ class Ovis2_5Processor(ProcessorMixin): for token_id in ids_list: if token_id == image_token_id: new_ids.extend( - visual_features["image_placeholders"] - [image_idx]) + visual_features["image_placeholders"][image_idx] + ) image_idx += 1 elif token_id == video_token_id: new_ids.extend( - visual_features["video_placeholders"] - [video_idx]) + visual_features["video_placeholders"][video_idx] + ) video_idx += 1 else: new_ids.append(token_id) @@ -260,8 +258,7 @@ class Ovis2_5Processor(ProcessorMixin): # If only images were provided return BatchFeature(data=visual_features) - def _tokenize_with_visual_symbol(self, - text_list: list[str]) -> torch.LongTensor: + def _tokenize_with_visual_symbol(self, text_list: list[str]) -> torch.LongTensor: batch_token_ids = [] for text in text_list: token_ids = [] @@ -288,21 +285,24 @@ class Ovis2_5Processor(ProcessorMixin): return torch.tensor(batch_token_ids, dtype=torch.long) # Copied from qwen2_vl - def smart_resize(self, - height: int, - width: int, - factor: int = 28, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS): + def smart_resize( + self, + height: int, + width: int, + factor: int = 28, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + ): """Rescales the image so that the following conditions are met: 1. 
Both dimensions (height and width) are divisible by 'factor'. - 2. The total number of pixels is within the range + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 3. The aspect ratio of the image is maintained as closely as possible. """ if height < factor or width < factor: - print(f"height:{height} or width:{width} must be " - f"larger than factor:{factor}") + print( + f"height:{height} or width:{width} must be larger than factor:{factor}" + ) if height < width: width = round(factor / height * width) height = factor @@ -311,8 +311,10 @@ class Ovis2_5Processor(ProcessorMixin): width = factor elif max(height, width) / min(height, width) > 200: - print(f"absolute aspect ratio must be smaller than 200, " - f"got {max(height, width) / min(height, width)}") + print( + f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}" + ) if height > width: height = 200 * width else: @@ -335,29 +337,27 @@ class Ovis2_5Processor(ProcessorMixin): def construct_visual_indicators(self, grid, is_video: bool = False): if is_video: - start_token = self.get_token_value('video_start') - end_token = self.get_token_value('video_end') + start_token = self.get_token_value("video_start") + end_token = self.get_token_value("video_end") else: - start_token = self.get_token_value('image_start') - end_token = self.get_token_value('image_end') + start_token = self.get_token_value("image_start") + end_token = self.get_token_value("image_end") - image_placeholders = [start_token, self.get_token_value('visual_atom')] + image_placeholders = [start_token, self.get_token_value("visual_atom")] if grid[0] * grid[1] > 1: for r in range(grid[0]): for c in range(grid[1]): - image_placeholders.append( - self.get_token_value('visual_atom')) + image_placeholders.append(self.get_token_value("visual_atom")) image_placeholders.append(end_token) return image_placeholders def construct_visual_placeholders(self, grid, is_video: bool = False): - visual_placeholders = self.construct_visual_indicators((1, 1), - is_video) + visual_placeholders = self.construct_visual_indicators((1, 1), is_video) - image_atom_token_id = self.get_token_value('visual_atom') + image_atom_token_id = self.get_token_value("visual_atom") # Extract the padding token ID from tokenizer - image_padding_token_id = self.get_token_value('image_pad') + image_padding_token_id = self.get_token_value("image_pad") num_image_atoms = grid[0] * grid[1] * grid[2] num_image_atoms //= self.hidden_stride**2 @@ -367,8 +367,9 @@ class Ovis2_5Processor(ProcessorMixin): padded_placeholder_tokens = [] for token in visual_placeholders: if token == image_atom_token_id: - padded_placeholder_tokens.extend([image_padding_token_id] * - num_image_atoms) + padded_placeholder_tokens.extend( + [image_padding_token_id] * num_image_atoms + ) else: padded_placeholder_tokens.append(image_padding_token_id) return padded_placeholder_tokens @@ -380,7 +381,7 @@ class Ovis2_5Processor(ProcessorMixin): convert_to_rgb: Optional[bool] = True, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS, - return_tensors: Optional[str] = 'pt', + return_tensors: Optional[str] = "pt", ): is_video = False if images is not None: @@ -396,11 +397,14 @@ class Ovis2_5Processor(ProcessorMixin): images.append(image) elif isinstance(video, list): images = video - min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS, - min_pixels if min_pixels is not None else MIN_PIXELS) + else: + raise ValueError("Either images or video should be 
provided.") + min_pixels = min( + max_pixels if max_pixels is not None else MAX_PIXELS, + min_pixels if min_pixels is not None else MIN_PIXELS, + ) images = [ - image.convert("RGB") - if convert_to_rgb and image.mode != 'RGB' else image + image.convert("RGB") if convert_to_rgb and image.mode != "RGB" else image for image in images ] @@ -417,14 +421,16 @@ class Ovis2_5Processor(ProcessorMixin): ) new_size = dict(height=resized_height, width=resized_width) image_pt = self.image_processor.preprocess( - image, size=new_size, return_tensors="np")['pixel_values'][0] + image, size=new_size, return_tensors="np" + )["pixel_values"][0] processed_images.append(image_pt) patches = np.array(processed_images) if patches.shape[0] % self.temporal_patch_size != 0: - num_to_pad = self.temporal_patch_size - (patches.shape[0] % - self.temporal_patch_size) + num_to_pad = self.temporal_patch_size - ( + patches.shape[0] % self.temporal_patch_size + ) repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0) patches = np.concatenate([patches, repeats], axis=0) channel = patches.shape[1] @@ -445,14 +451,18 @@ class Ovis2_5Processor(ProcessorMixin): ) patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, channel * self.temporal_patch_size * - self.patch_size * self.patch_size) + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) visual_placeholders = self.construct_visual_placeholders( - [grid_t, grid_h, grid_w], is_video) - return torch.tensor( - flatten_patches), visual_placeholders, torch.tensor( - [[grid_t, grid_h, grid_w]]) + [grid_t, grid_h, grid_w], is_video + ) + return ( + torch.tensor(flatten_patches), + visual_placeholders, + torch.tensor([[grid_t, grid_h, grid_w]]), + ) AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor) diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py new file mode 100644 index 0000000000000..ec60d66e5cff2 --- /dev/null +++ b/vllm/transformers_utils/runai_utils.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import os +import shutil +import signal +from typing import Optional + +from vllm import envs +from vllm.assets.base import get_cache_dir +from vllm.logger import init_logger +from vllm.utils import PlaceholderModule + +logger = init_logger(__name__) + +SUPPORTED_SCHEMES = ["s3://", "gs://"] + +try: + from runai_model_streamer import list_safetensors as runai_list_safetensors + from runai_model_streamer import pull_files as runai_pull_files +except (ImportError, OSError): + # see https://github.com/run-ai/runai-model-streamer/issues/26 + # OSError will be raised on arm64 platform + runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] + runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") + runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors") + + +def list_safetensors(path: str = "") -> list[str]: + """ + List full file names from object path and filter by allow pattern. + + Args: + path: The object storage path to list from. 
+ + Returns: + list[str]: List of full object storage paths allowed by the pattern + """ + return runai_list_safetensors(path) + + +def is_runai_obj_uri(model_or_path: str) -> bool: + return model_or_path.lower().startswith(tuple(SUPPORTED_SCHEMES)) + + +class ObjectStorageModel: + """ + A class representing an ObjectStorage model mirrored into a + temporary directory. + + Attributes: + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from object storage to the temporary directory. + """ + + def __init__(self, url: str) -> None: + if envs.VLLM_ASSETS_CACHE_MODEL_CLEAN: + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self._close_by_signal(existing_handler)) + + dir_name = os.path.join( + get_cache_dir(), + "model_streamer", + hashlib.sha256(str(url).encode()).hexdigest()[:8], + ) + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + os.makedirs(dir_name) + self.dir = dir_name + logger.debug("Init object storage, model cache path is: %s", dir_name) + + def _close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def _close_by_signal(self, existing_handler=None): + def new_handler(signum, frame): + self._close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files( + self, + model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, + ) -> None: + """ + Pull files from object storage into the temporary directory. + + Args: + model_path: The object storage path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + """ + if not model_path.endswith("/"): + model_path = model_path + "/" + runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern) diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index f95aae7815e0b..ef30efd80b1f7 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -2,15 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import fnmatch -import os -import shutil -import signal -import tempfile -from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.utils import PlaceholderModule +if TYPE_CHECKING: + from botocore.client import BaseClient + try: import boto3 except ImportError: @@ -19,21 +17,25 @@ except ImportError: def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: return [ - path for path in paths if any( - fnmatch.fnmatch(path, pattern) for pattern in patterns) + path + for path in paths + if any(fnmatch.fnmatch(path, pattern) for pattern in patterns) ] def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: return [ - path for path in paths + path + for path in paths if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) ] -def glob(s3=None, - path: str = "", - allow_pattern: Optional[list[str]] = None) -> list[str]: +def glob( + s3: Optional["BaseClient"] = None, + path: str = "", + allow_pattern: Optional[list[str]] = None, +) -> list[str]: """ List full file names from S3 path and filter by allow pattern. 
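# --- Illustrative usage sketch (editor's addition, not part of this patch) ---
# How the new ObjectStorageModel / is_runai_obj_uri helpers added above in
# vllm/transformers_utils/runai_utils.py might be exercised. The bucket URI
# and the allow_pattern below are made-up placeholders, not values from vLLM.
from vllm.transformers_utils.runai_utils import (
    ObjectStorageModel,
    is_runai_obj_uri,
    list_safetensors,
)

model_uri = "s3://example-bucket/my-model/"  # hypothetical s3:// or gs:// path

if is_runai_obj_uri(model_uri):  # True only for the supported s3:// / gs:// schemes
    mirror = ObjectStorageModel(model_uri)
    # Pull small config-style files into the local cache directory; weight
    # files can instead be enumerated for streaming via list_safetensors().
    mirror.pull_files(model_uri, allow_pattern=["*.json", "*.txt"])
    weight_files = list_safetensors(model_uri)
    print("local mirror:", mirror.dir, "| safetensors found:", len(weight_files))
# --- end sketch ---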
@@ -49,17 +51,15 @@ def glob(s3=None, s3 = boto3.client("s3") if not path.endswith("/"): path = path + "/" - bucket_name, _, paths = list_files(s3, - path=path, - allow_pattern=allow_pattern) + bucket_name, _, paths = list_files(s3, path=path, allow_pattern=allow_pattern) return [f"s3://{bucket_name}/{path}" for path in paths] def list_files( - s3, - path: str, - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None + s3: "BaseClient", + path: str, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, ) -> tuple[str, str, list[str]]: """ List files from S3 path and filter by pattern. @@ -73,17 +73,17 @@ def list_files( Returns: tuple[str, str, list[str]]: A tuple where: - The first element is the bucket name - - The second element is string represent the bucket + - The second element is string represent the bucket and the prefix as a dir like string - - The third element is a list of files allowed or + - The third element is a list of files allowed or disallowed by pattern """ - parts = path.removeprefix('s3://').split('/') - prefix = '/'.join(parts[1:]) + parts = path.removeprefix("s3://").split("/") + prefix = "/".join(parts[1:]) bucket_name = parts[0] objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) - paths = [obj['Key'] for obj in objects.get('Contents', [])] + paths = [obj["Key"] for obj in objects.get("Contents", [])] paths = _filter_ignore(paths, ["*/"]) if allow_pattern is not None: @@ -93,70 +93,3 @@ def list_files( paths = _filter_ignore(paths, ignore_pattern) return bucket_name, prefix, paths - - -class S3Model: - """ - A class representing a S3 model mirrored into a temporary directory. - - Attributes: - s3: S3 client. - dir: The temporary created directory. - - Methods: - pull_files(): Pull model from S3 to the temporary directory. - """ - - def __init__(self) -> None: - self.s3 = boto3.client('s3') - for sig in (signal.SIGINT, signal.SIGTERM): - existing_handler = signal.getsignal(sig) - signal.signal(sig, self._close_by_signal(existing_handler)) - - self.dir = tempfile.mkdtemp() - - def __del__(self): - self._close() - - def _close(self) -> None: - if os.path.exists(self.dir): - shutil.rmtree(self.dir) - - def _close_by_signal(self, existing_handler=None): - - def new_handler(signum, frame): - self._close() - if existing_handler: - existing_handler(signum, frame) - - return new_handler - - def pull_files(self, - s3_model_path: str = "", - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None) -> None: - """ - Pull files from S3 storage into the temporary directory. - - Args: - s3_model_path: The S3 path of the model. - allow_pattern: A list of patterns of which files to pull. - ignore_pattern: A list of patterns of which files not to pull. 
- - """ - if not s3_model_path.endswith("/"): - s3_model_path = s3_model_path + "/" - - bucket_name, base_dir, files = list_files(self.s3, s3_model_path, - allow_pattern, - ignore_pattern) - if len(files) == 0: - return - - for file in files: - destination_file = os.path.join( - self.dir, - file.removeprefix(base_dir).lstrip("/")) - local_dir = Path(destination_file).parent - os.makedirs(local_dir, exist_ok=True) - self.s3.download_file(bucket_name, file, destination_file) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index b3f1977f26cf4..9537295c6dcd2 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -10,16 +10,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import assert_never from vllm import envs from vllm.logger import init_logger -from vllm.transformers_utils.config import ( - get_sentence_transformer_tokenizer_config) +from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import make_async if TYPE_CHECKING: from vllm.config import ModelConfig @@ -32,8 +30,7 @@ else: logger = init_logger(__name__) -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - TokenizerBase] +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase] def decode_tokens( @@ -50,8 +47,7 @@ def decode_tokens( settings. """ if skip_special_tokens is not None: - return tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) return tokenizer.decode(token_ids) @@ -95,8 +91,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: tokenizer_all_special_ids = tokenizer.all_special_ids tokenizer_all_special_tokens = tokenizer.all_special_tokens - tokenizer_all_special_tokens_extended = ( - tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended tokenizer_vocab = tokenizer.get_vocab() tokenizer_len = len(tokenizer) @@ -110,7 +105,6 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: max_token_id = max(max_token_id, tokenizer.vocab_size) class CachedTokenizer(tokenizer.__class__): # type: ignore - @property def all_special_ids(self) -> list[int]: return tokenizer_all_special_ids @@ -134,7 +128,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: return tokenizer_len def __reduce__(self): - return get_cached_tokenizer, (tokenizer, ) + return get_cached_tokenizer, (tokenizer,) CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" @@ -151,8 +145,7 @@ def get_tokenizer( download_dir: Optional[str] = None, **kwargs, ) -> AnyTokenizer: - """Gets a tokenizer for the given model name via HuggingFace or ModelScope. - """ + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. 
@@ -173,13 +166,13 @@ def get_tokenizer( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) tokenizer_name = tokenizer_path if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if "truncation_side" not in kwargs: @@ -195,23 +188,28 @@ def get_tokenizer( is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( - 'It is strongly recommended to run mistral models with ' + "It is strongly recommended to run mistral models with " '`--tokenizer-mode "mistral"` to ensure correct ' - 'encoding and decoding.', + "encoding and decoding.", FutureWarning, - stacklevel=2) + stacklevel=2, + ) tokenizer: AnyTokenizer if tokenizer_mode == "mistral": - tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), - revision=revision) + tokenizer = MistralTokenizer.from_pretrained( + str(tokenizer_name), revision=revision + ) elif tokenizer_mode == "custom": from vllm.transformers_utils.tokenizer_base import TokenizerRegistry - tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name), - *args, - revision=revision, - download_dir=download_dir, - **kwargs) + + tokenizer = TokenizerRegistry.get_tokenizer( + str(tokenizer_name), + *args, + revision=revision, + download_dir=download_dir, + **kwargs, + ) else: try: tokenizer = AutoTokenizer.from_pretrained( @@ -226,13 +224,16 @@ def get_tokenizer( # currently being imported, # suggest using the --trust-remote-code flag. if not trust_remote_code and ( - "does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e)): - err_msg = ("Failed to load the tokenizer. If the tokenizer " - "is a custom tokenizer not yet available in the " - "HuggingFace transformers library, consider " - "setting `trust_remote_code=True` in LLM or using " - "the `--trust-remote-code` flag in the CLI.") + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -240,19 +241,21 @@ def get_tokenizer( # The special_tokens in tokenizer should also be # controlled by do_lower_case in encoder_config encoder_config = get_sentence_transformer_tokenizer_config( - tokenizer_name, revision) + tokenizer_name, revision + ) if isinstance(encoder_config, dict) and encoder_config.get( - "do_lower_case", False): + "do_lower_case", False + ): special_tokens_map = { - k: v.lower() - for k, v in tokenizer.special_tokens_map.items() + k: v.lower() for k, v in tokenizer.special_tokens_map.items() } tokenizer.add_special_tokens(special_tokens_map) if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead.") + "slowdown. Consider using a fast tokenizer instead." 
+ ) tokenizer = get_cached_tokenizer(tokenizer) return tokenizer @@ -274,20 +277,19 @@ def cached_tokenizer_from_config( ) -def get_lora_tokenizer(lora_request: LoRARequest, *args, - **kwargs) -> Optional[AnyTokenizer]: - if lora_request is None: - return None - try: - tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs) - except Exception as e: - # No tokenizer was found in the LoRA folder, - # use base model tokenizer - logger.warning( - "No tokenizer found in %s, using base model tokenizer instead. " - "(Exception: %s)", lora_request.lora_path, e) - tokenizer = None - return tokenizer +def init_tokenizer_from_configs(model_config: ModelConfig): + runner_type = model_config.runner_type + if runner_type == "generate" or runner_type == "draft": + truncation_side = "left" + elif runner_type == "pooling": + truncation_side = "right" + else: + assert_never(runner_type) - -get_lora_tokenizer_async = make_async(get_lora_tokenizer) + return get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=truncation_side, + ) diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index 20e5fea714e70..2d64265abbf21 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: class TokenizerBase(ABC): - @property @abstractmethod def all_special_tokens_extended(self) -> list[str]: @@ -61,6 +60,11 @@ class TokenizerBase(ABC): def max_token_id(self) -> int: raise NotImplementedError() + @property + @abstractmethod + def truncation_side(self) -> str: + raise NotImplementedError() + def __len__(self) -> int: return self.vocab_size @@ -93,18 +97,22 @@ class TokenizerBase(ABC): raise NotImplementedError() @abstractmethod - def encode(self, - text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: + def encode( + self, + text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, + add_special_tokens: Optional[bool] = None, + ) -> list[int]: raise NotImplementedError() @abstractmethod - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: raise NotImplementedError() @abstractmethod @@ -112,9 +120,9 @@ class TokenizerBase(ABC): raise NotImplementedError() @abstractmethod - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: raise NotImplementedError() @abstractmethod diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py deleted file mode 100644 index a8bb0398dfdb1..0000000000000 --- a/vllm/transformers_utils/tokenizer_group.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from typing_extensions import assert_never - -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import 
(AnyTokenizer, encode_tokens, - get_lora_tokenizer, - get_lora_tokenizer_async, - get_tokenizer) -from vllm.utils import LRUCache - - -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - max_loras = tokenizer_config.get("max_loras", 0) - self.lora_tokenizers = LRUCache[int, AnyTokenizer]( - capacity=max(max_loras, max_num_seqs) if enable_lora else 0) - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self.max_input_length - - def _raise_if_input_too_long(self, - encoded_tokens: list[int], - lora_request: Optional[LoRARequest] = None): - input_length = len(encoded_tokens) - if lora_request: - max_input_length = (lora_request.long_lora_max_len - or self.max_input_length) - else: - max_input_length = self.max_input_length - if max_input_length is not None and input_length > max_input_length: - raise ValueError("Input too long.", input_length, max_input_length) - - def encode(self, - prompt: str, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: - - tokenizer = self.get_lora_tokenizer(lora_request) - ret = encode_tokens(tokenizer, - prompt, - max_length=max_length, - truncation=truncation, - add_special_tokens=add_special_tokens) - self._raise_if_input_too_long(ret, lora_request) - return ret - - async def encode_async( - self, - prompt: str, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = encode_tokens(tokenizer, - prompt, - max_length=max_length, - truncation=truncation, - add_special_tokens=add_special_tokens) - self._raise_if_input_too_long(ret, lora_request) - return ret - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (get_lora_tokenizer( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers[lora_request.lora_int_id] - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (await get_lora_tokenizer_async( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers[lora_request.lora_int_id] - - -def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig]): - runner_type = model_config.runner_type - if runner_type == "generate" or runner_type == "draft": 
- truncation_side = "left" - elif runner_type == "pooling": - truncation_side = "right" - else: - assert_never(runner_type) - - return TokenizerGroup( - tokenizer_id=model_config.tokenizer, - enable_lora=bool(lora_config), - max_num_seqs=scheduler_config.max_num_seqs, - max_loras=lora_config.max_loras if lora_config else 0, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=truncation_side) diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 941156c4bf50e..b63cb26af46dd 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, - truncate_tool_call_ids, validate_request_params) +from .mistral import ( + MistralTokenizer, + maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params, +) __all__ = [ - "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids", - "validate_request_params" + "MistralTokenizer", + "maybe_serialize_tool_calls", + "truncate_tool_call_ids", + "validate_request_params", ] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 4dd8b2439b3f5..eae067fcfa344 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -1,33 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union, cast -import huggingface_hub -import regex as re -from huggingface_hub import HfApi, hf_hub_download -from transformers.tokenization_utils_base import BatchEncoding - from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_base import TokenizerBase -from vllm.utils import is_list_of if TYPE_CHECKING: - # make sure `mistral_common` is lazy imported, - # so that users who only use non-mistral models - # will not be bothered by the dependency. 
- from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer) + from mistral_common.protocol.instruct.request import ( + ChatCompletionRequest as MistralChatCompletionRequest, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, + ) from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from vllm.entrypoints.openai.protocol import ChatCompletionRequest logger = init_logger(__name__) -def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): +def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"): # SEE: https://github.com/vllm-project/vllm/pull/9951 # Credits go to: @gcalmettes # NOTE: There is currently a bug in pydantic where attributes @@ -51,7 +45,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): # - https://github.com/pydantic/pydantic/issues/9541 # TODO: remove when pydantic v2.11 is released for i, message in enumerate(request.messages): - if message.get("role") == 'assistant': + if message.get("role") == "assistant": tool_calls_validator = message.get("tool_calls", ().__iter__()) validated_tool_calls = [] while True: @@ -64,10 +58,10 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): request.messages[i]["tool_calls"] = validated_tool_calls -def truncate_tool_call_ids(request: "ChatCompletionRequest"): +def truncate_tool_call_ids(request: "MistralChatCompletionRequest"): """Truncates tool call IDs for Mistral's ID requirements.""" for i, message in enumerate(request.messages): - if message.get("role") == 'assistant': + if message.get("role") == "assistant": tool_calls = message.get("tool_calls", []) for tool_call in tool_calls: if len(tool_call["id"]) > 9: @@ -94,74 +88,34 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): request.messages[i]["tool_call_id"] = tool_call_id -def validate_request_params(request: "ChatCompletionRequest"): - if (request.skip_special_tokens is not None - and not request.skip_special_tokens): - raise ValueError("skip_special_tokens=False is not supported " - "for Mistral tokenizers.") +def _prepare_apply_chat_template_tools_and_messages( + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + continue_final_message: bool = False, + add_generation_prompt: bool = False, +) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]: + if add_generation_prompt and continue_final_message: + raise ValueError( + "Cannot set both `add_generation_prompt` and " + "`continue_final_message` to True." 
+ ) - -def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]: - repo_cache = os.path.join( - huggingface_hub.constants.HF_HUB_CACHE, - huggingface_hub.constants.REPO_ID_SEPARATOR.join( - ["models", *repo_id.split("/")])) - - if revision is None: - revision_file = os.path.join(repo_cache, "refs", "main") - if os.path.isfile(revision_file): - with open(revision_file) as file: - revision = file.read() - - if revision: - revision_dir = os.path.join(repo_cache, "snapshots", revision) - if os.path.isdir(revision_dir): - return os.listdir(revision_dir) - - return [] - - -def find_tokenizer_file(files: list[str]): - file_pattern = re.compile( - r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") - - matched_files = [file for file in files if file_pattern.match(file)] - if len(matched_files) > 1: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure only one Mistral " - f"tokenizer is present in {files}.") - elif len(matched_files) == 0: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}.") - - return matched_files[0] - - -def _aggregate_content(content: list) -> list[dict[str, Any]]: - aggregated_content: list[dict[str, Any]] = [] - for chunk in content: - if chunk.get("type" - ) == "text" and aggregated_content and aggregated_content[ - -1].get("type") == "text": - aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") - else: - aggregated_content.append(chunk) - if len(aggregated_content) == 1 and aggregated_content[0].get( - "type") == "text": - content = aggregated_content[0]["text"] - return content - - -def make_mistral_chat_completion_request( - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, - Any]]] = None) -> "ChatCompletionRequest": last_message = cast(dict[str, Any], messages[-1]) - if last_message["role"] == "assistant": - last_message["prefix"] = True + # add_generation_prompt is directly handled by the tokenizer but we + # check if the user is trying to use it with a final assistant message + # which is probably not what they want. + # If add_generation_prompt is False, we don't need to check anything. + if add_generation_prompt and last_message["role"] == "assistant": + raise ValueError( + "Cannot set `add_generation_prompt` to True when " + "the last message is from the assistant. Consider " + "using `continue_final_message` instead." + ) + if continue_final_message and last_message["role"] != "assistant": + raise ValueError( + "Cannot set `continue_final_message` to True when " + "the last message is not from the assistant." + ) # mistral-common requires AssistantMessage content to be string [1]. # @@ -170,136 +124,125 @@ def make_mistral_chat_completion_request( # Remove reasoning_content as unsupported by Mistral _ = message.pop("reasoning_content", None) # type: ignore - # Convert list text content to string - if message.get("role") in ("assistant", "tool"): - content: Any = message.get("content") - if isinstance(content, list): - content = _aggregate_content(content) - message["content"] = content - # The Mistral client, in comparison to the OpenAI client, requires the # "parameters" dict and the "description" string to be present # even if they are empty. 
if tools: for function in [ - tool["function"] for tool in tools - if tool["type"] == "function" + tool["function"] for tool in tools if tool["type"] == "function" ]: if function.get("parameters") is None: function["parameters"] = {} if function.get("description") is None: function["description"] = "" - from mistral_common.protocol.instruct.request import ChatCompletionRequest - return ChatCompletionRequest(messages=messages, - tools=tools) # type: ignore[type-var] + return messages, tools + + +def validate_request_params(request: "ChatCompletionRequest"): + if request.chat_template is not None or request.chat_template_kwargs is not None: + raise ValueError("chat_template is not supported for Mistral tokenizers.") + + +def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int: + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + assert isinstance(tokenizer, Tekkenizer), type(tokenizer) + + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t + shift = tokenizer.num_special_tokens + try: + return shift + tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + t_str = t_bytes.decode("utf-8") + if t_str in tokenizer._special_tokens_reverse_vocab: + return tokenizer._special_tokens_reverse_vocab[t_str] + logger.warning( + "Failed to convert token %s to id, replacing with <unk>", t_bytes + ) + return tokenizer.unk_id class MistralTokenizer(TokenizerBase): + def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer - def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: - self.mistral = tokenizer - self.instruct = tokenizer.instruct_tokenizer - _mistral_version_str = self.instruct.tokenizer.version.value + self.transformers_tokenizer = tokenizer + self.mistral = tokenizer.tokenizer + self.instruct = self.mistral.instruct_tokenizer + self.tokenizer = self.instruct.tokenizer + + _mistral_version_str = str(self.tokenizer.version.value) self.version: int = int(_mistral_version_str.split("v")[-1]) - tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.tekken import ( - SpecialTokenPolicy, Tekkenizer) - self.is_tekken = isinstance(tokenizer_, Tekkenizer) - from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer) - self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - if self.is_tekken: - # Make sure special tokens will not raise - tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE - elif self.is_spm: - pass - else: - raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + self.is_tekken = isinstance(self.tokenizer, Tekkenizer) + self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer) + if not (self.is_tekken or self.is_spm): + raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}") - self._vocab = tokenizer_.vocab() - # Convert to a dict[str, int] to match protocol, but this is a lossy - # conversion. There may be multiple token ids that decode to the same - # string due to partial UTF-8 byte sequences being converted to � + # Reverse order to ensure that the lowest token id is kept. 
self._vocab_dict = { - token: idx - for idx, token in enumerate(self._vocab) + self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i + for i in range(self.vocab_size - 1, -1, -1) } - self.tokenizer = tokenizer_ + # Sort the dict for convenience + self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) + + # Vocab sorted by token id. + self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 @classmethod - def from_pretrained(cls, - path_or_repo_id: str, - *, - revision: Optional[str] = None) -> "MistralTokenizer": - if not Path(path_or_repo_id).exists(): - assert len(path_or_repo_id.split("/")) == 2, ( - "You have either provided a non-existent path: " - "{path_or_repo_id} or an invalid HF Hub repo id.") - tokenizer_file = cls._download_mistral_tokenizer_from_hf( - path_or_repo_id, revision) - elif Path(path_or_repo_id).is_dir(): - tokenizer_file_name = find_tokenizer_file( - os.listdir(path_or_repo_id)) - tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) - else: - assert Path( - path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" - tokenizer_file = str(Path(path_or_repo_id)) + def from_pretrained( + cls, path_or_repo_id: str, *, revision: Optional[str] = None + ) -> "MistralTokenizer": + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, + ) - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer) - mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) - return cls(mistral_tokenizer) - - @staticmethod - def _download_mistral_tokenizer_from_hf(tokenizer_name: str, - revision: Optional[str]) -> str: - try: - hf_api = HfApi() - files = hf_api.list_repo_files(repo_id=tokenizer_name, - revision=revision) - except ConnectionError as exc: - files = list_local_repo_files(repo_id=tokenizer_name, - revision=revision) - - if len(files) == 0: - raise exc - - filename = find_tokenizer_file(files) - - tokenizer_file = hf_hub_download(tokenizer_name, - filename=filename, - revision=revision) - return tokenizer_file + str_revision = "main" if revision is None else revision + return cls( + TransformersMistralTokenizer.from_pretrained( + path_or_repo_id, revision=str_revision + ) + ) # the following attributes are set to fit vLLM's design and are used - # by the guided structured output backends. + # by the structured output backends. 
@property def all_special_tokens_extended(self) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - # tekken defines its own extended special tokens list - if hasattr(self.tokenizer, "SPECIAL_TOKENS"): - special_tokens = self.tokenizer.SPECIAL_TOKENS - else: - special_tokens = list(SpecialTokens) - return [ - s.value if isinstance(s, SpecialTokens) else s - for s in special_tokens - ] + return self.all_special_tokens @property def all_special_tokens(self) -> list[str]: - return self.all_special_tokens_extended + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + return [ + self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) + for i in self.all_special_ids + ] @property def all_special_ids(self) -> list[int]: - return [ - self.all_special_tokens.index(t) for t in self.all_special_tokens - ] + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens} + elif self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + special_ids = self.tokenizer._control_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + return sorted(special_ids) @property def bos_token_id(self) -> int: @@ -315,7 +258,7 @@ class MistralTokenizer(TokenizerBase): @property def pad_token(self) -> str: - raise NotImplementedError() + return self.transformers_tokenizer.pad_token @property def is_fast(self) -> bool: @@ -323,12 +266,33 @@ class MistralTokenizer(TokenizerBase): @property def vocab_size(self) -> int: - return len(self._vocab) + return self.transformers_tokenizer.vocab_size @property def max_token_id(self) -> int: return self._max_token_id + @property + def truncation_side(self) -> str: + raise NotImplementedError() + + def _is_special_token_id(self, token_id: int) -> bool: + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + return token_id in self.tokenizer._control_tokens + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + return token_id < self.tokenizer.num_special_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + def __len__(self) -> int: return self.vocab_size @@ -340,25 +304,19 @@ class MistralTokenizer(TokenizerBase): truncation: bool = False, max_length: Optional[int] = None, ): - input_ids: Union[list[int], list[list[int]]] - # For list[str], original prompt text - if is_list_of(text, str): - input_ids_: list[list[int]] = [] - for p in text: - each_input_ids = self.encode_one(p, truncation, max_length) - input_ids_.append(each_input_ids) - input_ids = input_ids_ - # For list[int], apply chat template output, already tokens. 
- elif is_list_of(text, int): - input_ids = text - # For str, single prompt text - else: - input_ids = self.encode_one(text, truncation, max_length) - return BatchEncoding({"input_ids": input_ids}) + return self.transformers_tokenizer( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + truncation=truncation, + max_length=max_length, + ) + + @property + def vocab(self) -> list[str]: + return self._vocab def get_vocab(self) -> dict[str, int]: - # NB: the dictionary form of the vocabulary collapses token ids that map - # to the same string but have different bytes return self._vocab_dict def get_added_vocab(self) -> dict[str, int]: @@ -372,79 +330,113 @@ class MistralTokenizer(TokenizerBase): max_length: Optional[int] = None, ) -> list[int]: # Mistral Tokenizers should not add special tokens - input_ids = self.encode(text) + return self.transformers_tokenizer.encode( + text, add_special_tokens=False, truncation=truncation, max_length=max_length + ) - if truncation: - input_ids = input_ids[:max_length] - return input_ids - - def encode(self, - text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: - # `encode` should only be used for prompt completion - # it should never be used for chat_completion. - # For chat completion use `apply_chat_template` + def encode( + self, + text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, + add_special_tokens: Optional[bool] = None, + ) -> list[int]: if add_special_tokens is not None: - return self.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) + return self.transformers_tokenizer.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) else: - return self.tokenizer.encode(text, bos=True, eos=False) + encoded = self.tokenizer.encode(text, bos=True, eos=False) - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: + if truncation is not False and max_length is not None: + return encoded[:max_length] + else: + return encoded - request = make_mistral_chat_completion_request(messages, tools) - encoded = self.mistral.encode_chat_completion(request) + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: + add_generation_prompt = kwargs.pop("add_generation_prompt", False) + continue_final_message = kwargs.get("continue_final_message", False) + padding = kwargs.get("padding", False) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") - # encode-decode to get clean prompt - return encoded.tokens + messages, tools = _prepare_apply_chat_template_tools_and_messages( + messages, tools, continue_final_message, add_generation_prompt + ) + + return self.transformers_tokenizer.apply_chat_template( + conversation=messages, + tools=tools, + continue_final_message=continue_final_message, + tokenize=True, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=None, + return_dict=False, + ) + + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: + return self.transformers_tokenizer.decode( + ids, skip_special_tokens=skip_special_tokens + ) def convert_tokens_to_string(self, tokens: list[str]) -> str: - from mistral_common.tokens.tokenizers.base 
import SpecialTokens + from mistral_common.tokens.tokenizers.base import ( + SpecialTokenPolicy, + SpecialTokens, + ) + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + to_decode_special_tokens = {SpecialTokens.tool_calls} if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) tokens = [ - t for t in tokens - if (t is SpecialTokens.tool_calls - or t not in self.tokenizer._all_special_tokens) + t + for t in tokens + if (t in to_decode_special_tokens or t not in self.all_special_tokens) ] if any(isinstance(t, bytes) for t in tokens): # we need to encode and decode all tokens again - shift = self.tokenizer.num_special_tokens - - def _token_to_id(t: str): - t_bytes = t.encode("utf-8") \ - if not isinstance(t, bytes) else t - try: - return shift + \ - self.tokenizer._tekken_token2id_nospecial[t_bytes] - except KeyError: - logger.warning( - "Failed to convert token %s to id," - " replacing with <unk>", t_bytes) - return self.tokenizer.unk_id - - ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids) + ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens] + # We filtered unwanted special tokens before + # so we can decode the rest. + decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP) else: decoded = "".join(tokens) else: # make sure certain special tokens like Tool calls are # not decoded - special_tokens = {SpecialTokens.tool_calls} + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + regular_tokens: list[str] = [] - decoded_list = [] + decoded_list: list[str] = [] + decoded = "" for token in tokens: - if token in special_tokens: + if token in to_decode_special_tokens: if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) + self.tokenizer.decode( + regular_tokens, SpecialTokenPolicy.IGNORE + ) + ) regular_tokens = [] decoded_list.append(token) else: @@ -452,65 +444,56 @@ class MistralTokenizer(TokenizerBase): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens)) # type: ignore - - decoded = ''.join(decoded_list) + self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE) + ) + decoded = "".join(decoded_list) return decoded - # WARN: Outlines logits processors can overwrite this method. - # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer - # for more. - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: - assert ( - skip_special_tokens - ), "skip_special_tokens=False is not supported for Mistral tokenizers." - - if isinstance(ids, int): - ids = [ids] - return self.tokenizer.decode(ids) - def convert_ids_to_tokens( self, ids: list[int], skip_special_tokens: bool = True, ) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - from mistral_common.tokens.tokenizers.instruct import ( - InstructTokenizerV13) + from mistral_common.tokens.tokenizers.base import ( + SpecialTokenPolicy, + SpecialTokens, + ) + from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 - # TODO(Patrick) - potentially allow special tokens to not be skipped - assert ( - skip_special_tokens - ), "skip_special_tokens=False is not supported for Mistral tokenizers." 
+ if not skip_special_tokens: + return [self.tokenizer.id_to_piece(token_id) for token_id in ids] - assert self.is_tekken or self.is_spm, type(self.tokenizer) + non_skip_special_tokens_ids = { + self.tokenizer.get_control_token(SpecialTokens.tool_calls), + } + if isinstance(self.instruct, InstructTokenizerV13): + if self.instruct.BEGIN_THINK: + non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK) + if self.instruct.END_THINK: + non_skip_special_tokens_ids.add(self.instruct.END_THINK) - if self.is_tekken: - # skip special tokens except tool call and think tokens - non_skip_special_tokens = { - self.tokenizer.get_control_token(SpecialTokens.tool_calls) - } - if isinstance(self.instruct, InstructTokenizerV13): - if self.instruct.BEGIN_THINK: - non_skip_special_tokens.add(self.instruct.BEGIN_THINK) - if self.instruct.END_THINK: - non_skip_special_tokens.add(self.instruct.END_THINK) - ids = [ - i for i in ids if i > self.tokenizer.num_special_tokens - or i in non_skip_special_tokens - ] + ids_kept = [ + i + for i in ids + if i in non_skip_special_tokens_ids or not self._is_special_token_id(i) + ] - tokens = [self.tokenizer.id_to_piece(id) for id in ids] + # We filtered unwanted special tokens so we can decode the rest. + tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept] if any("�" in t for t in tokens) and self.is_tekken: # if a decoded token contains the replacement character, then the # token has an incomplete UTF-8 character so we must use bytes # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 - # if underlying tokenizeir is sentencepiece, we just add "�" - tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + # if underlying tokenizer is sentencepiece, we just add "�". + # We filtered unwanted special tokens so we can decode the rest. 
+ tokens = [ + self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) + if token_id not in self.all_special_ids + else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) + for token_id in ids_kept + ] return tokens diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 66c8fb797adcd..8952a0b197d69 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import struct from functools import cache from os import PathLike from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union from vllm.envs import VLLM_MODEL_REDIRECT_PATH from vllm.logger import init_logger @@ -14,7 +15,7 @@ logger = init_logger(__name__) def is_s3(model_or_path: str) -> bool: - return model_or_path.lower().startswith('s3://') + return model_or_path.lower().startswith("s3://") def check_gguf_file(model: Union[str, PathLike]) -> bool: @@ -42,13 +43,16 @@ def modelscope_list_repo_files( ) -> list[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi + api = HubApi() api.login(token) # same as huggingface_hub.list_repo_files files = [ - file['Path'] for file in api.get_model_files( - model_id=repo_id, revision=revision, recursive=True) - if file['Type'] == 'blob' + file["Path"] + for file in api.get_model_files( + model_id=repo_id, revision=revision, recursive=True + ) + if file["Type"] == "blob" ] return files @@ -90,10 +94,18 @@ def maybe_model_redirect(model: str) -> str: if not Path(model_redirect_path).exists(): return model - redirect_dict = (_maybe_json_dict(model_redirect_path) - or _maybe_space_split_dict(model_redirect_path)) - if (redirect_model := redirect_dict.get(model)): + redirect_dict = _maybe_json_dict(model_redirect_path) or _maybe_space_split_dict( + model_redirect_path + ) + if redirect_model := redirect_dict.get(model): logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model) return redirect_model return model + + +def parse_safetensors_file_metadata(path: Union[str, PathLike]) -> dict[str, Any]: + with open(path, "rb") as f: + length_of_metadata = struct.unpack("<Q", f.read(8))[0] + metadata = json.loads(f.read(length_of_metadata).decode("utf-8")) + return metadata diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 0fcf5d15afd1d..a475d0fa406bf 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,14 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder, - TritonPlaceholder) +from vllm.triton_utils.importing import ( + HAS_TRITON, + TritonLanguagePlaceholder, + TritonPlaceholder, +) if HAS_TRITON: import triton import triton.language as tl + import triton.language.extra.libdevice as tldevice else: triton = TritonPlaceholder() tl = TritonLanguagePlaceholder() + tldevice = TritonLanguagePlaceholder() -__all__ = ["HAS_TRITON", "triton", "tl"] +__all__ = ["HAS_TRITON", "triton", "tl", "tldevice"] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 372200027bf95..e1a509a303c53 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -21,15 +21,15 @@ if HAS_TRITON: # an is_active method. # The `x.driver and` check adds a small layer of safety. 
active_drivers = [ - x.driver for x in backends.values() - if x.driver and x.driver.is_active() + x.driver for x in backends.values() if x.driver and x.driver.is_active() ] # Check if we're in a distributed environment where CUDA_VISIBLE_DEVICES # might be temporarily empty (e.g., Ray sets it to "" during actor init) cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") - is_distributed_env = (cuda_visible_devices is not None - and len(cuda_visible_devices.strip()) == 0) + is_distributed_env = ( + cuda_visible_devices is not None and len(cuda_visible_devices.strip()) == 0 + ) # Apply lenient driver check for distributed environments if is_distributed_env and len(active_drivers) == 0: @@ -37,38 +37,44 @@ if HAS_TRITON: # active later when CUDA context is properly initialized logger.debug( "Triton found 0 active drivers in distributed environment. " - "This is expected during initialization.") + "This is expected during initialization." + ) elif not is_distributed_env and len(active_drivers) != 1: # Strict check for non-distributed environments logger.info( "Triton is installed but %d active driver(s) found " "(expected 1). Disabling Triton to prevent runtime errors.", - len(active_drivers)) + len(active_drivers), + ) HAS_TRITON = False except ImportError: # This can occur if Triton is partially installed or triton.backends # is missing. logger.warning( "Triton is installed, but `triton.backends` could not be imported. " - "Disabling Triton.") + "Disabling Triton." + ) HAS_TRITON = False except Exception as e: # Catch any other unexpected errors during the check. logger.warning( "An unexpected error occurred while checking Triton active drivers:" - " %s. Disabling Triton.", e) + " %s. Disabling Triton.", + e, + ) HAS_TRITON = False if not HAS_TRITON: - logger.info("Triton not installed or not compatible; certain GPU-related" - " functions will not be available.") + logger.info( + "Triton not installed or not compatible; certain GPU-related" + " functions will not be available." 
+ ) class TritonPlaceholder(types.ModuleType): - def __init__(self): super().__init__("triton") - self.__version__ = "3.3.0" + self.__version__ = "3.4.0" self.jit = self._dummy_decorator("jit") self.autotune = self._dummy_decorator("autotune") self.heuristics = self._dummy_decorator("heuristics") @@ -76,7 +82,6 @@ class TritonPlaceholder(types.ModuleType): self.language = TritonLanguagePlaceholder() def _dummy_decorator(self, name): - def decorator(*args, **kwargs): if args and callable(args[0]): return args[0] @@ -86,10 +91,10 @@ class TritonPlaceholder(types.ModuleType): class TritonLanguagePlaceholder(types.ModuleType): - def __init__(self): super().__init__("triton.language") self.constexpr = None self.dtype = None self.int64 = None self.int32 = None + self.tensor = None diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 92245498de657..ed470ebe88929 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -68,8 +68,7 @@ def is_usage_stats_enabled(): no_usage_stats = envs.VLLM_NO_USAGE_STATS do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) - _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats - or do_not_track_file) + _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats or do_not_track_file) return _USAGE_STATS_ENABLED @@ -80,9 +79,11 @@ def _get_current_timestamp_ns() -> int: def _detect_cloud_provider() -> str: # Try detecting through vendor file vendor_files = [ - "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_version", + "/sys/class/dmi/id/bios_vendor", "/sys/class/dmi/id/product_name", - "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + "/sys/class/dmi/id/chassis_asset_tag", + "/sys/class/dmi/id/sys_vendor", ] # Mapping of identifiable strings to cloud providers cloud_identifiers = { @@ -152,39 +153,53 @@ class UsageMessage: self.log_time: Optional[int] = None self.source: Optional[str] = None - def report_usage(self, - model_architecture: str, - usage_context: UsageContext, - extra_kvs: Optional[dict[str, Any]] = None) -> None: - t = Thread(target=self._report_usage_worker, - args=(model_architecture, usage_context, extra_kvs or {}), - daemon=True) + def report_usage( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[dict[str, Any]] = None, + ) -> None: + t = Thread( + target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True, + ) t.start() - def _report_usage_worker(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_worker( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continuous_usage() - def _report_usage_once(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_once( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: # Platform information from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): self.gpu_count = cuda_device_count_stateless() - self.gpu_type, self.gpu_memory_per_device = ( - cuda_get_device_properties(0, ("name", "total_memory"))) + self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties( + 0, ("name", "total_memory") + ) if current_platform.is_cuda(): 
self.cuda_runtime = torch.version.cuda if current_platform.is_tpu(): try: import torch_xla + self.gpu_count = torch_xla.runtime.world_size() self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = ( - torch_xla.core.xla_model.get_memory_info()["bytes_limit"]) + self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ + "bytes_limit" + ] except Exception: logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() @@ -195,11 +210,13 @@ class UsageMessage: info = cpuinfo.get_cpu_info() self.num_cpu = info.get("count", None) self.cpu_type = info.get("brand_raw", "") - self.cpu_family_model_stepping = ",".join([ - str(info.get("family", "")), - str(info.get("model", "")), - str(info.get("stepping", "")) - ]) + self.cpu_family_model_stepping = ",".join( + [ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")), + ] + ) # vLLM information self.context = usage_context.value @@ -207,10 +224,9 @@ class UsageMessage: self.model_architecture = model_architecture # Environment variables - self.env_var_json = json.dumps({ - env_var: getattr(envs, env_var) - for env_var in _USAGE_ENV_VARS_TO_COLLECT - }) + self.env_var_json = json.dumps( + {env_var: getattr(envs, env_var) for env_var in _USAGE_ENV_VARS_TO_COLLECT} + ) # Metadata self.log_time = _get_current_timestamp_ns() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 7c34a858c0a21..22c2a4b5362c2 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -33,25 +33,45 @@ import types import uuid import warnings import weakref -from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError, RawDescriptionHelpFormatter, - _ArgumentGroup) +from argparse import ( + Action, + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + RawDescriptionHelpFormatter, + _ArgumentGroup, +) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, - Hashable, Iterable, Iterator, KeysView, Mapping, - Sequence) +from collections.abc import ( + AsyncGenerator, + Awaitable, + Collection, + Generator, + Hashable, + Iterable, + Iterator, + Mapping, + Sequence, +) from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps -from types import MappingProxyType -from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TextIO, TypeVar, Union, cast, overload) +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TextIO, + TypeVar, + Union, +) from urllib.parse import urlparse from uuid import uuid4 -import cachetools import cbor2 import cloudpickle import numpy as np @@ -78,6 +98,7 @@ if TYPE_CHECKING: from argparse import Namespace from vllm.config import ModelConfig, VllmConfig + from vllm.sequence import IntermediateTensors logger = init_logger(__name__) @@ -87,64 +108,6 @@ DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 -# Exception strings for non-implemented encoder/decoder scenarios - -# Reminder: Please update docs/features/compatibility_matrix.md -# If the feature combo become valid - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for 
encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ - "Chunked prefill for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( - "Models with logits_soft_cap " - "require FlashInfer backend, which is " - "currently not supported for encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently " - "supported with encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " - "currently supported with " - "encoder/decoder models.") - -STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " - "supported with encoder/decoder " - "models.") - -STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " - "currently supported with encoder/" - "decoder models.") - -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " - "backends currently supported with encoder/" - "decoder models.") - -# Efficiently import all enc/dec error strings -# rather than having to import all of the above -STR_NOT_IMPL_ENC_DEC_ERR_STRS = { - "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, - "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, - "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, - "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, - "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, - "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, - "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, -} - # Constants related to forcing the attention backend selection # String name of register which may be set in order to @@ -156,12 +119,16 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" # register, corresponding to possible backends STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + GB_bytes = 1_000_000_000 """The number of bytes in one gigabyte (GB).""" @@ -169,8 +136,8 @@ GiB_bytes = 1 << 30 """The number of bytes in one gibibyte (GiB).""" # ANSI color codes -CYAN = '\033[1;36m' -RESET = '\033[0;0m' +CYAN = "\033[1;36m" +RESET = "\033[0;0m" STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, @@ -182,6 +149,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { "fp8_e5m2": torch.uint8, "int8": torch.int8, "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { @@ -203,20 +171,12 @@ def set_default_torch_num_threads(num_threads: int): torch.set_num_threads(old_num_threads) -P = ParamSpec('P') +P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") _K = TypeVar("_K", bound=Hashable) _V = TypeVar("_V") -_T = TypeVar("_T") - - -class _Sentinel: - ... 
- - -ALL_PINNED_SENTINEL = _Sentinel() class Device(enum.Enum): @@ -230,7 +190,6 @@ class LayerBlockType(enum.Enum): class Counter: - def __init__(self, start: int = 0) -> None: self.counter = start @@ -243,261 +202,12 @@ class Counter: self.counter = 0 -class _MappingOrderCacheView(UserDict[_K, _V]): - - def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]): - super().__init__(data) - self.ordered_keys = ordered_keys - - def __iter__(self) -> Iterator[_K]: - return iter(self.ordered_keys) - - def keys(self) -> KeysView[_K]: - return KeysView(self.ordered_keys) - - -class CacheInfo(NamedTuple): - hits: int - total: int - - @property - def hit_ratio(self) -> float: - if self.total == 0: - return 0 - - return self.hits / self.total - - def __sub__(self, other: CacheInfo): - return CacheInfo( - hits=self.hits - other.hits, - total=self.total - other.total, - ) - - -class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): - - def __init__(self, - capacity: float, - getsizeof: Optional[Callable[[_V], float]] = None): - super().__init__(capacity, getsizeof) - - self.pinned_items = set[_K]() - - self._hits = 0 - self._total = 0 - self._last_info = CacheInfo(hits=0, total=0) - - def __getitem__(self, key: _K, *, update_info: bool = True) -> _V: - value = super().__getitem__(key) - - if update_info: - self._hits += 1 - self._total += 1 - - return value - - def __delitem__(self, key: _K) -> None: - run_on_remove = key in self - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] - super().__delitem__(key) - if key in self.pinned_items: - # Todo: add warning to inform that del pinned item - self._unpin(key) - if run_on_remove: - self._on_remove(key, value) - - @property - def cache(self) -> Mapping[_K, _V]: - """Return the internal cache dictionary in order (read-only).""" - return _MappingOrderCacheView( - self._Cache__data, # type: ignore - self.order) - - @property - def order(self) -> Mapping[_K, None]: - """Return the internal order dictionary (read-only).""" - return MappingProxyType(self._LRUCache__order) # type: ignore - - @property - def capacity(self) -> float: - return self.maxsize - - @property - def usage(self) -> float: - if self.maxsize == 0: - return 0 - - return self.currsize / self.maxsize - - def stat(self, *, delta: bool = False) -> CacheInfo: - """ - Gets the cumulative number of hits and queries against this cache. - - If `delta=True`, instead gets these statistics - since the last call that also passed `delta=True`. - """ - info = CacheInfo(hits=self._hits, total=self._total) - - if delta: - info_delta = info - self._last_info - self._last_info = info - info = info_delta - - return info - - def touch(self, key: _K) -> None: - try: - self._LRUCache__order.move_to_end(key) # type: ignore - except KeyError: - self._LRUCache__order[key] = None # type: ignore - - @overload - def get(self, key: _K, /) -> Optional[_V]: - ... - - @overload - def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: - ... - - def get(self, - key: _K, - /, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] - if key in self: - value = self.__getitem__( - key, update_info=False) # type: ignore[call-arg] - - self._hits += 1 - else: - value = default - - self._total += 1 - return value - - @overload - def pop(self, key: _K) -> _V: - ... - - @overload - def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: - ... 
- - def pop(self, - key: _K, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] - if key not in self: - return default - - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] - self.__delitem__(key) - return value - - def put(self, key: _K, value: _V) -> None: - self.__setitem__(key, value) - - def pin(self, key: _K) -> None: - """ - Pins a key in the cache preventing it from being - evicted in the LRU order. - """ - if key not in self: - raise ValueError(f"Cannot pin key: {key} not in cache.") - self.pinned_items.add(key) - - def _unpin(self, key: _K) -> None: - """ - Unpins a key in the cache allowing it to be - evicted in the LRU order. - """ - self.pinned_items.remove(key) - - def _on_remove(self, key: _K, value: Optional[_V]) -> None: - pass - - def remove_oldest(self, *, remove_pinned: bool = False) -> None: - if len(self) == 0: - return - - self.popitem(remove_pinned=remove_pinned) - - def _remove_old_if_needed(self) -> None: - while self.currsize > self.capacity: - self.remove_oldest() - - def popitem(self, remove_pinned: bool = False): - """Remove and return the `(key, value)` pair least recently used.""" - if not remove_pinned: - # pop the oldest item in the cache that is not pinned - lru_key = next( - (key for key in self.order if key not in self.pinned_items), - ALL_PINNED_SENTINEL) - if lru_key is ALL_PINNED_SENTINEL: - raise RuntimeError("All items are pinned, " - "cannot remove oldest from the cache.") - else: - lru_key = next(iter(self.order)) - value = self.pop(cast(_K, lru_key)) - return (lru_key, value) - - def clear(self) -> None: - while len(self) > 0: - self.remove_oldest(remove_pinned=True) - - self._hits = 0 - self._total = 0 - self._last_info = CacheInfo(hits=0, total=0) - - -class PyObjectCache: - """Used to cache python objects to avoid object allocations - across scheduler iterations. - """ - - def __init__(self, obj_builder): - self._obj_builder = obj_builder - self._index = 0 - - self._obj_cache = [] - for _ in range(128): - self._obj_cache.append(self._obj_builder()) - - def _grow_cache(self): - # Double the size of the cache - num_objs = len(self._obj_cache) - for _ in range(num_objs): - self._obj_cache.append(self._obj_builder()) - - def get_object(self): - """Returns a pre-allocated cached object. If there is not enough - objects, then the cache size will double. - """ - if self._index >= len(self._obj_cache): - self._grow_cache() - assert self._index < len(self._obj_cache) - - obj = self._obj_cache[self._index] - self._index += 1 - - return obj - - def reset(self): - """Makes all cached-objects available for the next scheduler iteration. 
- """ - self._index = 0 - - @cache def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" from vllm import _custom_ops as ops - max_shared_mem = ( - ops.get_max_shared_memory_per_block_device_attribute(gpu)) + + max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero" @@ -532,11 +242,14 @@ class AsyncMicrobatchTokenizer: self.batch_wait_timeout_s = batch_wait_timeout_s self._loop = asyncio.get_running_loop() - self._queues: dict[tuple, - asyncio.Queue[Union[tuple[str, dict, - asyncio.Future], - tuple[list[int], - asyncio.Future]]]] = {} + self._queues: dict[ + tuple, + asyncio.Queue[ + Union[ + tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future] + ] + ], + ] = {} self._batcher_tasks: list[asyncio.Task] = [] # Single-thread executor for blocking tokenizer calls. @@ -560,8 +273,9 @@ class AsyncMicrobatchTokenizer: # === Internal helpers === def _get_queue( self, loop: asyncio.AbstractEventLoop, key: tuple - ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[ - list[int], asyncio.Future]]]: + ) -> asyncio.Queue[ + Union[tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]] + ]: """Get the request queue for the given operation key, creating a new queue and batcher task if needed.""" queue = self._queues.get(key) @@ -571,8 +285,7 @@ class AsyncMicrobatchTokenizer: can_batch = key[1] != "other" coro = self._batch_encode_loop(queue, can_batch) else: - assert key[0] == "decode", \ - f"Unknown operation type: {key[0]}." + assert key[0] == "decode", f"Unknown operation type: {key[0]}." coro = self._batch_decode_loop(queue) self._batcher_tasks.append(loop.create_task(coro)) return queue @@ -592,7 +305,8 @@ class AsyncMicrobatchTokenizer: break try: prompt, kwargs, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) prompts.append(prompt) result_futures.append(result_future) if not can_batch: @@ -604,9 +318,10 @@ class AsyncMicrobatchTokenizer: # If every request uses identical kwargs we can run a single # batched tokenizer call for a big speed-up. 
if can_batch and len(prompts) > 1: - encode_fn = partial(self.tokenizer, prompts, **kwargs) + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, batch_encode_fn + ) for i, fut in enumerate(result_futures): if not fut.done(): @@ -614,11 +329,11 @@ class AsyncMicrobatchTokenizer: fut.set_result(BatchEncoding(data)) else: encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ - self.tokenizer(p, **kw) - for p, kw in zip(prompts, kwargs) + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) ] results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, encode_fn + ) for fut, res in zip(result_futures, results): if not fut.done(): @@ -642,7 +357,8 @@ class AsyncMicrobatchTokenizer: break try: token_ids, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) token_ids_list.append(token_ids) result_futures.append(result_future) except asyncio.TimeoutError: @@ -651,8 +367,8 @@ class AsyncMicrobatchTokenizer: try: # Perform a single batched decode call for all requests results = await self._loop.run_in_executor( - self._executor, self.tokenizer.batch_decode, - token_ids_list) + self._executor, self.tokenizer.batch_decode, token_ids_list + ) for fut, res in zip(result_futures, results): if not fut.done(): fut.set_result(res) @@ -681,7 +397,7 @@ class AsyncMicrobatchTokenizer: """ if op == "decode": - return ("decode", ) + return ("decode",) add_special_tokens = kwargs.get("add_special_tokens", True) truncation = kwargs.get("truncation", False) @@ -691,16 +407,17 @@ class AsyncMicrobatchTokenizer: return "encode", add_special_tokens, False, None model_max = getattr(self.tokenizer, "model_max_length", None) - if max_length is None or (model_max is not None - and max_length == model_max): + if max_length is None or (model_max is not None and max_length == model_max): return "encode", add_special_tokens, True, "model_max" return "encode", "other" def __del__(self): - if ((tasks := getattr(self, "_batcher_tasks", None)) - and (loop := getattr(self, "_loop", None)) - and not loop.is_closed()): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): def cancel_tasks(): for task in tasks: @@ -735,8 +452,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: def make_async( - func: Callable[P, T], - executor: Optional[concurrent.futures.Executor] = None + func: Callable[P, T], executor: concurrent.futures.Executor | None = None ) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. @@ -753,15 +469,9 @@ def make_async( return _async_wrapper -def _next_task(iterator: AsyncGenerator[T, None], - loop: AbstractEventLoop) -> Task: - # Can use anext() in python >= 3.10 - return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] - - async def merge_async_iterators( - *iterators: AsyncGenerator[T, - None], ) -> AsyncGenerator[tuple[int, T], None]: + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. 
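As a usage sketch for the helper refactored in the next hunk (toy generators; it assumes `merge_async_iterators` stays importable from `vllm.utils`, where this file defines it): items are yielded as soon as any source produces them, each tagged with the index of the iterator it came from.

```python
import asyncio

from vllm.utils import merge_async_iterators  # defined in this file


async def ticker(name: str, delay: float, n: int):
    """Toy async generator that yields n labelled items, one every `delay` seconds."""
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"{name}-{i}"


async def main() -> None:
    # Items arrive in completion order, tagged with their source index,
    # e.g. (0, 'fast-0'), (0, 'fast-1'), (1, 'slow-0'), (0, 'fast-2'), (1, 'slow-1').
    async for src_idx, item in merge_async_iterators(
        ticker("fast", 0.01, 3),
        ticker("slow", 0.03, 2),
    ):
        print(src_idx, item)


asyncio.run(main())
```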
@@ -776,17 +486,16 @@ async def merge_async_iterators( loop = asyncio.get_running_loop() - awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} + awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} try: while awaits: - done, _ = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED) + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: item = await d i, it = pair - awaits[_next_task(it, loop)] = pair + awaits[loop.create_task(anext(it))] = pair yield i, item except StopAsyncIteration: pass @@ -798,8 +507,7 @@ async def merge_async_iterators( await it.aclose() -async def collect_from_async_generator( - iterator: AsyncGenerator[T, None]) -> list[T]: +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: """Collect all items from an async generator into a list.""" items = [] async for item in iterator: @@ -815,7 +523,8 @@ def get_ip() -> str: " it is often used by Docker and other software to" " interact with the container's network stack. Please " "use VLLM_HOST_IP instead to set the IP address for vLLM processes" - " to communicate with each other.") + " to communicate with each other." + ) if host_ip: return host_ip @@ -843,7 +552,8 @@ def get_ip() -> str: "Failed to get the IP address, using 0.0.0.0 by default." "The value can be set by the environment variable" " VLLM_HOST_IP or HOST_IP.", - stacklevel=2) + stacklevel=2, + ) return "0.0.0.0" @@ -871,7 +581,8 @@ def get_loopback_ip() -> str: else: raise RuntimeError( "Neither 127.0.0.1 nor ::1 are bound to a local interface. " - "Set the VLLM_LOOPBACK_IP environment variable explicitly.") + "Set the VLLM_LOOPBACK_IP environment variable explicitly." + ) def is_valid_ipv6_address(address: str) -> bool: @@ -884,13 +595,13 @@ def is_valid_ipv6_address(address: str) -> bool: def split_host_port(host_port: str) -> tuple[str, int]: # ipv6 - if host_port.startswith('['): - host, port = host_port.rsplit(']', 1) + if host_port.startswith("["): + host, port = host_port.rsplit("]", 1) host = host[1:] - port = port.split(':')[1] + port = port.split(":")[1] return host, int(port) else: - host, port = host_port.split(':') + host, port = host_port.split(":") return host, int(port) @@ -942,7 +653,7 @@ def get_open_port() -> int: def get_open_ports_list(count: int = 5) -> list[int]: """Get a list of open ports.""" - ports = set() + ports = set[int]() while len(ports) < count: ports.add(get_open_port()) return list(ports) @@ -958,8 +669,7 @@ def _get_open_port() -> int: return port except OSError: port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", - port - 1, port) + logger.info("Port %d is already in use, trying port %d", port - 1, port) # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -972,7 +682,7 @@ def _get_open_port() -> int: return s.getsockname()[1] -def find_process_using_port(port: int) -> Optional[psutil.Process]: +def find_process_using_port(port: int) -> psutil.Process | None: # TODO: We can not check for running processes with network # port on macOS. Therefore, we can not have a full graceful shutdown # of vLLM. For now, let's not look for processes in this case. 
@@ -980,8 +690,9 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]: if sys.platform.startswith("darwin"): return None + our_pid = os.getpid() for conn in psutil.net_connections(): - if conn.laddr.port == port: + if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): try: return psutil.Process(conn.pid) except psutil.NoSuchProcess: @@ -993,15 +704,18 @@ def update_environment_variables(envs: dict[str, str]): for k, v in envs.items(): if k in os.environ and os.environ[k] != v: logger.warning( - "Overwriting environment variable %s " - "from '%s' to '%s'", k, os.environ[k], v) + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) os.environ[k] = v def chunk_list(lst: list[T], chunk_size: int): """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size): - yield lst[i:i + chunk_size] + yield lst[i : i + chunk_size] def cdiv(a: int, b: int) -> int: @@ -1045,6 +759,7 @@ def _generate_random_fp8( # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) ops.convert_fp8(tensor, tensor_tmp) @@ -1052,12 +767,12 @@ def _generate_random_fp8( def get_kv_cache_torch_dtype( - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, +) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": - if isinstance(model_dtype, - str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype @@ -1080,39 +795,37 @@ def create_kv_caches_with_random_flash( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", - cache_layout: Optional[str] = "NHD", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, - 4) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] - for i in stride_order) + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=kv_cache_allocation_shape, - dtype=torch_dtype, - device=device).permute(*stride_order) + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=torch_dtype, device=device + ).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) - elif 
cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_value_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_value_cache[:, 0]) value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches @@ -1124,16 +837,17 @@ def create_kv_caches_with_random( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" ) from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -1143,31 +857,27 @@ def create_kv_caches_with_random( key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=torch_dtype, - device=device) + key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=torch_dtype, - device=device) + value_cache = torch.empty( + size=value_cache_shape, dtype=torch_dtype, device=device + ) if cache_dtype in ["auto", "half", "bfloat16", "float"]: value_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(value_cache, -scale, scale) else: - raise ValueError( - f"Does not support value cache of type {cache_dtype}") + raise ValueError(f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches @@ -1175,6 +885,7 @@ def create_kv_caches_with_random( @cache def is_pin_memory_available() -> bool: from vllm.platforms import current_platform + return current_platform.is_pin_memory_available() @@ -1187,13 +898,13 @@ def is_uva_available() -> bool: class DeviceMemoryProfiler: - - def __init__(self, device: Optional[torch.types.Device] = None): + def __init__(self, device: torch.types.Device | None = None): self.device = device def current_memory_usage(self) -> float: # Return the memory usage in bytes. from vllm.platforms import current_platform + gc.collect() return current_platform.get_current_memory_usage(self.device) @@ -1215,7 +926,7 @@ def make_ndarray_with_pad( pad: T, dtype: npt.DTypeLike, *, - max_len: Optional[int] = None, + max_len: int | None = None, ) -> npt.NDArray: """ Make a padded array from 2D inputs. 
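For context on the `make_ndarray_with_pad` hunks on either side of this point: the helper packs ragged rows into one right-padded array. A small self-contained equivalent in plain NumPy (hypothetical name `pad_2d`), mirroring the logic shown in the next hunk:

```python
import numpy as np


def pad_2d(x: list[list[int]], pad: int, dtype=np.int64, max_len: int | None = None) -> np.ndarray:
    """Pack ragged rows into a single (len(x), max_len) array, padded on the right."""
    if max_len is None:
        max_len = max((len(row) for row in x), default=0)
    out = np.full((len(x), max_len), pad, dtype=dtype)
    for i, row in enumerate(x):
        assert len(row) <= max_len
        out[i, : len(row)] = row
    return out


print(pad_2d([[1, 2, 3], [4], [5, 6]], pad=0))
# [[1 2 3]
#  [4 0 0]
#  [5 6 0]]
```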
@@ -1230,7 +941,7 @@ def make_ndarray_with_pad( padded_x = np.full((len(x), max_len), pad, dtype=dtype) for ind, blocktb in enumerate(x): assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb + padded_x[ind, : len(blocktb)] = blocktb return padded_x @@ -1240,8 +951,8 @@ def make_tensor_with_pad( pad: T, dtype: torch.dtype, *, - max_len: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, + max_len: int | None = None, + device: Union[str, torch.device] | None = None, pin_memory: bool = False, ) -> torch.Tensor: """ @@ -1279,8 +990,7 @@ def get_dtype_size(dtype: torch.dtype) -> int: # bool = 0, int = 1, float = 2, complex = 3 def _get_precision_level(dtype: torch.dtype) -> int: # NOTE: Complex dtypes return `is_floating_point=False` - return ((dtype != torch.bool) + dtype.is_floating_point + - dtype.is_complex * 2) + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): @@ -1308,8 +1018,11 @@ def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): # Compare floating-point types src_info = torch.finfo(src_dtype) tgt_info = torch.finfo(tgt_dtype) - return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): @@ -1328,6 +1041,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: return maybe_list if isinstance(maybe_list, list) else list(maybe_list) +def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: + if isinstance(obj, str) or not isinstance(obj, Iterable): + return [obj] # type: ignore[list-item] + return obj + + # `collections` helpers def is_list_of( value: object, @@ -1371,6 +1090,7 @@ def init_cached_hf_modules() -> None: Lazy initialization of the Hugging Face modules. """ from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() @@ -1414,8 +1134,8 @@ def find_nccl_library() -> str: # manually load the nccl library if so_file: logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) else: if torch.version.cuda is not None: so_file = "libnccl.so.2" @@ -1427,6 +1147,39 @@ def find_nccl_library() -> str: return so_file +def find_nccl_include_paths() -> list[str] | None: + """ + We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` + environment variable, or we find the library file brought by + nvidia-nccl-cuXX. 
load_inline by default uses + torch.utils.cpp_extension.include_paths + """ + paths: list[str] = [] + inc = envs.VLLM_NCCL_INCLUDE_PATH + if inc and os.path.isdir(inc): + paths.append(inc) + + try: + import importlib.util + + spec = importlib.util.find_spec("nvidia.nccl") + if spec and getattr(spec, "submodule_search_locations", None): + for loc in spec.submodule_search_locations: + inc_dir = os.path.join(loc, "include") + if os.path.exists(os.path.join(inc_dir, "nccl.h")): + paths.append(inc_dir) + except Exception: + pass + + seen = set() + out: list[str] = [] + for p in paths: + if p and p not in seen: + out.append(p) + seen.add(p) + return out or None + + prev_set_stream = torch.cuda.set_stream _current_stream_tls = threading.local() @@ -1441,7 +1194,6 @@ torch.cuda.set_stream = _patched_set_stream class _StreamPlaceholder: - def __init__(self): self.synchronize = lambda: None @@ -1458,15 +1210,16 @@ def current_stream() -> torch.cuda.Stream: from C/C++ code. """ from vllm.platforms import current_platform - if not hasattr(_current_stream_tls, - "value") or _current_stream_tls.value is None: + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process if current_platform.is_rocm(): - _current_stream_tls.value = torch.cuda.Stream() + # torch.cuda.set_stream here is the alias of _pathed_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu(): _current_stream_tls.value = _StreamPlaceholder() else: @@ -1476,7 +1229,8 @@ def current_stream() -> torch.cuda.Stream: else: raise ValueError( "Fail to set current stream, current platform " - "may not support current_stream with torch API") + "may not support current_stream with torch API" + ) return _current_stream_tls.value @@ -1489,12 +1243,14 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: tmp_dir = tempfile.gettempdir() # add username to tmp_dir to avoid permission issues tmp_dir = os.path.join(tmp_dir, getpass.getuser()) - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", - f"vllm-instance-{vllm_config.instance_id}", - filename) + filename = ( + f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log" + ).replace(" ", "_") + log_path = os.path.join( + tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename + ) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) @@ -1505,36 +1261,34 @@ def identity(value: T, **kwargs) -> T: return value -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) def deprecate_args( start_index: int, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - params = inspect.signature(fn).parameters pos_types = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ) - pos_kws = [ - kw for kw, param in params.items() if param.kind in pos_types - ] + pos_kws = [kw for kw, param 
in params.items() if param.kind in pos_types] @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): - deprecated_args = pos_kws[start_index:len(args)] + deprecated_args = pos_kws[start_index : len(args)] if deprecated_args: msg = ( f"The positional arguments {deprecated_args} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." + ) if additional_message is not None: msg += f" {additional_message}" @@ -1553,7 +1307,7 @@ def deprecate_args( def deprecate_kwargs( *kws: str, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: deprecated_kws = set(kws) @@ -1561,7 +1315,6 @@ def deprecate_kwargs( is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): @@ -1569,7 +1322,8 @@ def deprecate_kwargs( if deprecated_kwargs: msg = ( f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." + ) if additional_message is not None: msg += f" {additional_message}" @@ -1586,8 +1340,7 @@ def deprecate_kwargs( @lru_cache(maxsize=8) -def _cuda_device_count_stateless( - cuda_visible_devices: Optional[str] = None) -> int: +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. @@ -1599,13 +1352,17 @@ def _cuda_device_count_stateless( import torch.version from vllm.platforms import current_platform + if not torch.cuda._is_compiled(): return 0 if current_platform.is_rocm(): # ROCm uses amdsmi instead of nvml for stateless device count # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = torch.cuda._device_count_amdsmi() if (hasattr( - torch.cuda, "_device_count_amdsmi")) else -1 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) else: raw_count = torch.cuda._device_count_nvml() r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count @@ -1639,9 +1396,9 @@ def xpu_is_initialized() -> bool: return torch.xpu.is_initialized() -def cuda_get_device_properties(device, - names: Sequence[str], - init_cuda=False) -> tuple[Any, ...]: +def cuda_get_device_properties( + device, names: Sequence[str], init_cuda=False +) -> tuple[Any, ...]: """Get specified CUDA device property values without initializing CUDA in the current process.""" if init_cuda or cuda_is_initialized(): @@ -1651,11 +1408,12 @@ def cuda_get_device_properties(device, # Run in subprocess to avoid initializing CUDA as a side effect. 
mp_ctx = multiprocessing.get_context("fork") with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: - return executor.submit(cuda_get_device_properties, device, names, - True).result() + return executor.submit(cuda_get_device_properties, device, names, True).result() -def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: +def weak_bind( + bound_method: Callable[..., Any], +) -> Callable[..., None]: """Make an instance method that weakly references its associated instance and no-ops once that instance is collected.""" @@ -1670,7 +1428,6 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: if wrapper.has_run: # type: ignore[attr-defined] return @@ -1686,19 +1443,18 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]: class StoreBoolean(Action): - def __call__(self, parser, namespace, values, option_string=None): if values.lower() == "true": setattr(namespace, self.dest, True) elif values.lower() == "false": setattr(namespace, self.dest, False) else: - raise ValueError(f"Invalid boolean value: {values}. " - "Expected 'true' or 'false'.") + raise ValueError( + f"Invalid boolean value: {values}. Expected 'true' or 'false'." + ) -class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, - RawDescriptionHelpFormatter): +class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): """SortedHelpFormatter that sorts arguments by their option strings.""" def _split_lines(self, text, width): @@ -1710,7 +1466,7 @@ class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, # The patterns also include whitespace after the newline single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*") multiple_newlines = re.compile(r"\n{2,}\s*") - text = single_newline.sub(' ', text) + text = single_newline.sub(" ", text) lines = re.split(multiple_newlines, text) return sum([textwrap.wrap(line, width) for line in lines], []) @@ -1730,7 +1486,9 @@ class FlexibleArgumentParser(ArgumentParser): " --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n" "Additionally, list elements can be passed individually using +:\n" ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' - " --json-arg.key4+ value3 --json-arg.key4+=\'value4,value5\'\n\n") + " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n" + ) + _search_keyword: str | None = None def __init__(self, *args, **kwargs): # Set the default "formatter_class" to SortedHelpFormatter @@ -1750,11 +1508,14 @@ class FlexibleArgumentParser(ArgumentParser): logger.warning_once( "argument '--disable-log-requests' is deprecated and " "replaced with '--enable-log-requests'. This will be " - "removed in v0.12.0.") + "removed in v0.12.0." 
+ ) namespace, args = super().parse_known_args(args, namespace) for action in FlexibleArgumentParser._deprecated: - if (hasattr(namespace, dest := action.dest) - and getattr(namespace, dest) != action.default): + if ( + hasattr(namespace, dest := action.dest) + and getattr(namespace, dest) != action.default + ): logger.warning_once("argument '%s' is deprecated", dest) return namespace, args @@ -1766,7 +1527,6 @@ class FlexibleArgumentParser(ArgumentParser): return action class _FlexibleArgumentGroup(_ArgumentGroup): - def add_argument(self, *args, **kwargs): deprecated = kwargs.pop("deprecated", False) action = super().add_argument(*args, **kwargs) @@ -1779,13 +1539,79 @@ class FlexibleArgumentParser(ArgumentParser): self._action_groups.append(group) return group - def format_help(self) -> str: - # Add tip about JSON arguments to the epilog - epilog = self.epilog or "" - if (self.add_json_tip - and not epilog.startswith(FlexibleArgumentParser._json_tip)): - self.epilog = FlexibleArgumentParser._json_tip + epilog - return super().format_help() + def format_help(self): + # Only use custom help formatting for bottom level parsers + if self._subparsers is not None: + return super().format_help() + + formatter = self._get_formatter() + + # Handle keyword search of the args + if (search_keyword := self._search_keyword) is not None: + # Normalise the search keyword + search_keyword = search_keyword.lower().replace("_", "-") + # Return full help if searching for 'all' + if search_keyword == "all": + self.epilog = self._json_tip + return super().format_help() + + # Return group help if searching for a group title + for group in self._action_groups: + if group.title and group.title.lower() == search_keyword: + formatter.start_section(group.title) + formatter.add_text(group.description) + formatter.add_arguments(group._group_actions) + formatter.end_section() + formatter.add_text(self._json_tip) + return formatter.format_help() + + # Return matched args if searching for an arg name + matched_actions = [] + for group in self._action_groups: + for action in group._group_actions: + # search option name + if any( + search_keyword in opt.lower() for opt in action.option_strings + ): + matched_actions.append(action) + if matched_actions: + formatter.start_section(f"Arguments matching '{search_keyword}'") + formatter.add_arguments(matched_actions) + formatter.end_section() + formatter.add_text(self._json_tip) + return formatter.format_help() + + # No match found + formatter.add_text( + f"No group or arguments matching '{search_keyword}'.\n" + "Use '--help' to see available groups or " + "'--help=all' to see all available parameters." 
+ ) + return formatter.format_help() + + # usage + formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups) + + # description + formatter.add_text(self.description) + + # positionals, optionals and user-defined groups + formatter.start_section("Config Groups") + config_groups = "" + for group in self._action_groups: + if not group._group_actions: + continue + title = group.title + description = group.description or "" + config_groups += f"{title: <24}{description}\n" + formatter.add_text(config_groups) + formatter.end_section() + + # epilog + formatter.add_text(self.epilog) + + # determine help from format above + return formatter.format_help() def parse_args( # type: ignore[override] self, @@ -1797,15 +1623,42 @@ class FlexibleArgumentParser(ArgumentParser): # Check for --model in command line arguments first if args and args[0] == "serve": - model_in_cli_args = any(arg == '--model' for arg in args) - - if model_in_cli_args: - raise ValueError( + try: + model_idx = next( + i + for i, arg in enumerate(args) + if arg == "--model" or arg.startswith("--model=") + ) + logger.warning( "With `vllm serve`, you should provide the model as a " "positional argument or in a config file instead of via " - "the `--model` option.") + "the `--model` option. " + "The `--model` option will be removed in v0.13." + ) - if '--config' in args: + if args[model_idx] == "--model": + model_tag = args[model_idx + 1] + rest_start_idx = model_idx + 2 + else: + model_tag = args[model_idx].removeprefix("--model=") + rest_start_idx = model_idx + 1 + + # Move <model> to the front, e,g: + # [Before] + # vllm serve -tp 2 --model <model> --enforce-eager --port 8001 + # [After] + # vllm serve <model> -tp 2 --enforce-eager --port 8001 + args = [ + "serve", + model_tag, + *args[1:model_idx], + *args[rest_start_idx:], + ] + print("args", args) + except StopIteration: + pass + + if "--config" in args: args = self._pull_args_from_config(args) def repl(match: re.Match) -> str: @@ -1818,25 +1671,30 @@ class FlexibleArgumentParser(ArgumentParser): # Convert underscores to dashes and vice versa in argument names processed_args = list[str]() for i, arg in enumerate(args): - if arg.startswith('--'): - if '=' in arg: - key, value = arg.split('=', 1) + if arg.startswith("--help="): + FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower() + processed_args.append("--help") + elif arg.startswith("--"): + if "=" in arg: + key, value = arg.split("=", 1) key = pattern.sub(repl, key, count=1) - processed_args.append(f'{key}={value}') + processed_args.append(f"{key}={value}") else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": # allow -O flag to be used without space, e.g. 
-O3 or -Odecode # -O.<...> handled later # also handle -O=<level> here - level = arg[3:] if arg[2] == '=' else arg[2:] - processed_args.append(f'-O.level={level}') - elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { - "0", "1", "2", "3" - }: + level = arg[3:] if arg[2] == "=" else arg[2:] + processed_args.append(f"-O.level={level}") + elif ( + arg == "-O" + and i + 1 < len(args) + and args[i + 1] in {"0", "1", "2", "3"} + ): # Convert -O <n> to -O.level <n> - processed_args.append('-O.level') + processed_args.append("-O.level") else: processed_args.append(arg) @@ -1900,14 +1758,11 @@ class FlexibleArgumentParser(ArgumentParser): # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - arg_duplicates = recursive_dict_update(dict_args[key], - arg_dict) - duplicates |= {f'{key}.{d}' for d in arg_duplicates} + arg_duplicates = recursive_dict_update(dict_args[key], arg_dict) + duplicates |= {f"{key}.{d}" for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None - processed_args = [ - a for i, a in enumerate(processed_args) if i not in delete - ] + processed_args = [a for i, a in enumerate(processed_args) if i not in delete] if duplicates: logger.warning("Found duplicate keys %s", ", ".join(duplicates)) @@ -1964,48 +1819,56 @@ class FlexibleArgumentParser(ArgumentParser): this way the order of priorities is maintained when these are args parsed by super(). """ - assert args.count( - '--config') <= 1, "More than one config file specified!" + assert args.count("--config") <= 1, "More than one config file specified!" - index = args.index('--config') + index = args.index("--config") if index == len(args) - 1: - raise ValueError("No config file specified! \ - Please check your command-line arguments.") + raise ValueError( + "No config file specified! \ + Please check your command-line arguments." + ) file_path = args[index + 1] - config_args = self._load_config_file(file_path) + config_args = self.load_config_file(file_path) - # 0th index is for {serve,chat,complete} + # 0th index might be the sub command {serve,chat,complete,...} # optionally followed by model_tag (only for serve) # followed by config args # followed by rest of cli args. # maintaining this order will enforce the precedence # of cli > config > defaults - if args[0] == "serve": - model_in_cli = len(args) > 1 and not args[1].startswith('-') - model_in_config = any(arg == '--model' for arg in config_args) + if args[0].startswith("-"): + # No sub command (e.g., api_server entry point) + args = config_args + args[0:index] + args[index + 2 :] + elif args[0] == "serve": + model_in_cli = len(args) > 1 and not args[1].startswith("-") + model_in_config = any(arg == "--model" for arg in config_args) if not model_in_cli and not model_in_config: raise ValueError( "No model specified! Please specify model either " - "as a positional argument or in a config file.") + "as a positional argument or in a config file." 
+ ) if model_in_cli: # Model specified as positional arg, keep CLI version - args = [args[0]] + [ - args[1] - ] + config_args + args[2:index] + args[index + 2:] + args = ( + [args[0]] + + [args[1]] + + config_args + + args[2:index] + + args[index + 2 :] + ) else: # No model in CLI, use config if available - args = [args[0] - ] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] else: - args = [args[0]] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] return args - def _load_config_file(self, file_path: str) -> list[str]: + def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -2018,11 +1881,13 @@ class FlexibleArgumentParser(ArgumentParser): '--tensor-parallel-size': '4' ] """ - extension: str = file_path.split('.')[-1] - if extension not in ('yaml', 'yml'): + extension: str = file_path.split(".")[-1] + if extension not in ("yaml", "yml"): raise ValueError( "Config file must be of a yaml/yml type.\ - %s supplied", extension) + %s supplied", + extension, + ) # only expecting a flat dictionary of atomic types processed_args: list[str] = [] @@ -2034,32 +1899,38 @@ class FlexibleArgumentParser(ArgumentParser): except Exception as ex: logger.error( "Unable to read the config file at %s. \ - Make sure path is correct", file_path) + Make sure path is correct", + file_path, + ) raise ex store_boolean_arguments = [ - action.dest for action in self._actions - if isinstance(action, StoreBoolean) + action.dest for action in self._actions if isinstance(action, StoreBoolean) ] for key, value in config.items(): if isinstance(value, bool) and key not in store_boolean_arguments: if value: - processed_args.append('--' + key) + processed_args.append("--" + key) + elif isinstance(value, list): + if value: + processed_args.append("--" + key) + for item in value: + processed_args.append(str(item)) else: - processed_args.append('--' + key) + processed_args.append("--" + key) processed_args.append(str(value)) return processed_args -async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, - **kwargs): +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) +@lru_cache def supports_kw( callable: Callable[..., object], kw_name: str, @@ -2077,19 +1948,26 @@ def supports_kw( param_val = params.get(kw_name) # Types where the it may be valid, i.e., explicitly defined & nonvariadic - passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY)) + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) if param_val: is_sig_param = param_val.kind in passable_kw_types # We want kwargs only, but this is passable as a positional arg - if (requires_kw_only and is_sig_param - and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): return False - if ((requires_kw_only - and param_val.kind == inspect.Parameter.KEYWORD_ONLY) - or (not requires_kw_only and is_sig_param)): + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not 
requires_kw_only and is_sig_param + ): return True # If we're okay with var-kwargs, it's supported as long as @@ -2099,15 +1977,17 @@ def supports_kw( # mapping, but it wraps an ordered dict, and they appear in order. # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters last_param = params[next(reversed(params))] # type: ignore - return (last_param.kind == inspect.Parameter.VAR_KEYWORD - and last_param.name != kw_name) + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) return False def get_allowed_kwarg_only_overrides( callable: Callable[..., object], - overrides: Optional[Mapping[str, object]], + overrides: Mapping[str, object] | None, *, requires_kw_only: bool = True, allow_var_kwargs: bool = False, @@ -2139,10 +2019,12 @@ def get_allowed_kwarg_only_overrides( filtered_overrides = { kwarg_name: val for kwarg_name, val in overrides.items() - if supports_kw(callable, - kwarg_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs) + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) } # If anything is dropped, log a warning @@ -2151,11 +2033,15 @@ def get_allowed_kwarg_only_overrides( if requires_kw_only: logger.warning( "The following intended overrides are not keyword-only args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) else: logger.warning( "The following intended overrides are not keyword args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) return filtered_overrides @@ -2170,8 +2056,9 @@ def supports_dynamo() -> bool: # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform def supports_xccl() -> bool: - return is_torch_equal_or_newer( - "2.8.0.dev") and torch.distributed.is_xccl_available() + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) # Some backends use pytorch version < 2.4.0 which doesn't @@ -2207,7 +2094,6 @@ class AtomicCounter: # Adapted from: https://stackoverflow.com/a/47212782/5082708 class LazyDict(Mapping[str, T], Generic[T]): - def __init__(self, factory: dict[str, Callable[[], T]]): self._factory = factory self._dict: dict[str, T] = {} @@ -2230,7 +2116,6 @@ class LazyDict(Mapping[str, T], Generic[T]): class ClassRegistry(UserDict[type[T], _V]): - def __getitem__(self, key: type[T]) -> _V: for cls in key.mro(): if cls in self.data: @@ -2264,7 +2149,9 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] + tensors: Union[ + torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], IntermediateTensors + ], ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -2276,6 +2163,15 @@ def weak_ref_tensors( return [weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) + return ret raise ValueError("Invalid type for tensors") @@ -2314,7 +2210,8 @@ def get_vllm_optional_dependencies(): return { extra: [ - re.split(r";|>=|<=|==", req)[0] for req in requirements + re.split(r";|>=|<=|==", req)[0] + for 
req in requirements if req.endswith(f'extra == "{extra}"') ] for extra in extras @@ -2507,12 +2404,13 @@ class PlaceholderModule(_PlaceholderBase): raise exc - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) class _PlaceholderModuleAttr(_PlaceholderBase): - def __init__(self, module: PlaceholderModule, attr_path: str) -> None: super().__init__() @@ -2521,14 +2419,15 @@ class _PlaceholderModuleAttr(_PlaceholderBase): self.__attr_path = attr_path def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self.__module, - f"{self.__attr_path}.{attr_path}") + return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") def __getattr__(self, key: str): getattr(self.__module, f"{self.__attr_path}.{key}") - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) # create a library to hold the custom op @@ -2536,13 +2435,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: list[str], - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: str = "CUDA", - tags: tuple[torch.Tag, ...] = (), + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it @@ -2561,21 +2460,32 @@ def direct_register_custom_op( """ if not supports_custom_op(): from vllm.platforms import current_platform + assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " - "the required dependencies.") + "the required dependencies." + ) return + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + + dispatch_key = current_platform.dispatch_key + import torch.library + if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, - mutates_args=mutates_args) + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) @@ -2621,6 +2531,7 @@ def kill_process_tree(pid: int): @dataclass class MemorySnapshot: """Memory snapshot.""" + torch_peak: int = 0 free_memory: int = 0 total_memory: int = 0 @@ -2635,15 +2546,34 @@ class MemorySnapshot: self.measure() def measure(self): + from vllm.platforms import current_platform + # we measure the torch peak memory usage via allocated_bytes, # rather than `torch.cuda.memory_reserved()` . # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink # when we call `torch.cuda.empty_cache()` or OOM happens. 
- self.torch_peak = torch.cuda.memory_stats().get( - "allocated_bytes.all.peak", 0) + self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) self.free_memory, self.total_memory = torch.cuda.mem_get_info() + shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark + if ( + current_platform.is_cuda() + and current_platform.get_device_capability() in shared_sysmem_device_mem_sms + ): + # On UMA (Orin, Thor and Spark) platform, + # where both CPU and GPU rely on system memory, + # the cudaMemGetInfo function shows the amount of free system memory + # rather than what’s actually available. + # In the case, + # torch.cuda.mem_get_info() only reports "free" memory, + # which can be lower than what is actually + # available due to not including cache memory. + # There’s also a comprehensive reference page + # that explains how you can compute the proper value yourself. + # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device + self.free_memory = psutil.virtual_memory().available + self.cuda_memory = self.total_memory - self.free_memory # torch.cuda.memory_reserved() is how many bytes @@ -2669,8 +2599,8 @@ class MemorySnapshot: @dataclass class MemoryProfilingResult: - """Memory profiling result. All numbers are in bytes. - """ + """Memory profiling result. All numbers are in bytes.""" + non_kv_cache_memory: int = 0 torch_peak_increase: int = 0 non_torch_increase: int = 0 @@ -2681,20 +2611,22 @@ class MemoryProfilingResult: profile_time: float = 0.0 def __repr__(self) -> str: - return (f"Memory profiling takes {self.profile_time:.2f} seconds. " - f"Total non KV cache memory: " - f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " - f"torch peak memory increase: " - f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " - f"non-torch forward increase memory: " - f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " - f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.") + return ( + f"Memory profiling takes {self.profile_time:.2f} seconds. " + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." + ) @contextlib.contextmanager def memory_profiling( - baseline_snapshot: MemorySnapshot, - weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + baseline_snapshot: MemorySnapshot, weights_memory: int +) -> Generator[MemoryProfilingResult, None, None]: """Memory profiling context manager. baseline_snapshot: the memory snapshot before the current vLLM instance. weights_memory: memory used by PyTorch when loading the model weights. 
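For reference, a hedged sketch of how `MemorySnapshot` and the `memory_profiling` context manager above are meant to be combined; the dummy workload and the 2 GiB weights figure are placeholders, a CUDA device is assumed, and the imports assume these names remain exported from `vllm.utils`.

```python
import torch

from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling

baseline = MemorySnapshot()          # measures on construction by default
weights_memory = int(2 * GiB_bytes)  # placeholder for the real weight size

with memory_profiling(baseline, weights_memory=weights_memory) as result:
    # ... run a dummy forward pass here so activation peaks are captured ...
    torch.cuda.synchronize()

# Fields are populated when the context manager exits.
print(result)  # human-readable summary, all figures in GiB
print(f"non-KV-cache memory: {result.non_kv_cache_memory / GiB_bytes:.2f} GiB")
```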
@@ -2765,29 +2697,37 @@ def memory_profiling( result.torch_peak_increase = diff_profile.torch_peak result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp - result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = ( + non_torch_memory + peak_activation_memory + result.weights_memory + ) # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): logger.info("Windows detected, skipping ulimit adjustment.") return import resource + resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) if current_soft < target_soft_limit: try: - resource.setrlimit(resource_type, - (target_soft_limit, current_hard)) + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: logger.warning( "Found ulimit of %s and failed to automatically increase " "with error %s. This can cause fd limit errors like " "`OSError: [Errno 24] Too many open files`. Consider " - "increasing with ulimit -n", current_soft, e) + "increasing with ulimit -n", + current_soft, + e, + ) # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501 @@ -2818,7 +2758,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]: return scheme, host, port -def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: """Make a ZMQ path from its parts. Args: @@ -2841,9 +2781,9 @@ def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, socket_type: Any, - bind: Optional[bool] = None, - identity: Optional[bytes] = None, - linger: Optional[int] = None, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2857,10 +2797,7 @@ def make_zmq_socket( # - Set a large 0.5GB buffer to improve throughput # For systems with less memory: # - Use system default (-1) to avoid excessive memory consumption - if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) # 0.5GB in bytes - else: - buf_size = -1 # Use system default buffer size + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 if bind is None: bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) @@ -2900,19 +2837,15 @@ def make_zmq_socket( def zmq_socket_ctx( path: str, socket_type: Any, - bind: Optional[bool] = None, + bind: bool | None = None, linger: int = 0, - identity: Optional[bytes] = None, + identity: bytes | None = None, ) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, - path, - socket_type, - bind=bind, - identity=identity) + yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") @@ -2933,6 +2866,7 @@ def _maybe_force_spawn(): # to the subprocess so that it knows how to connect to the ray cluster. 
# env vars are inherited by subprocesses, even if we use spawn. import ray + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address reasons.append("In a Ray actor and can only be spawned") @@ -2947,7 +2881,9 @@ def _maybe_force_spawn(): "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " "See https://docs.vllm.ai/en/latest/usage/" "troubleshooting.html#python-multiprocessing " - "for more information. Reasons: %s", "; ".join(reasons)) + "for more information. Reasons: %s", + "; ".join(reasons), + ) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -2966,7 +2902,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: Optional[dict[str, str]] = None + shared_kv_cache_layers: dict[str, str] | None = None, ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2984,33 +2920,40 @@ def bind_kv_cache( shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache = [ - layer_name for layer_name in ctx - if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ - and ctx[layer_name].kv_sharing_target_layer_name is None + layer_name + for layer_name in ctx + if ( + hasattr(ctx[layer_name], "attn_type") + and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER) + ) + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( - set( - extract_layer_index(layer_name) - for layer_name in layer_need_kv_cache)) + set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache) + ) for layer_name in layer_need_kv_cache: - kv_cache_idx = layer_index_sorted.index( - extract_layer_index(layer_name)) + kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] if shared_kv_cache_layers is not None: for layer_name, target_layer_name in shared_kv_cache_layers.items(): - assert extract_layer_index(target_layer_name) < \ - extract_layer_index(layer_name), \ - "v0 doesn't support interleaving kv sharing" + assert extract_layer_index(target_layer_name) < extract_layer_index( + layer_name + ), "v0 doesn't support interleaving kv sharing" ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache -def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], - kwargs: dict[str, Any]) -> Any: +def run_method( + obj: Any, + method: Union[str, bytes, Callable], + args: tuple[Any], + kwargs: dict[str, Any], +) -> Any: """ Run a method of an object with the given arguments and keyword arguments. If the method is string, it will be converted to a method using getattr. @@ -3024,8 +2967,9 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], try: func = getattr(obj, method) except AttributeError: - raise NotImplementedError(f"Method {method!r} is not" - " implemented.") from None + raise NotImplementedError( + f"Method {method!r} is not implemented." + ) from None else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) @@ -3059,6 +3003,7 @@ def import_pynvml(): module to our codebase, and use it directly. 
""" import vllm.third_party.pynvml as pynvml + return pynvml @@ -3078,7 +3023,7 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: unimplemented_methods = [] for attr_name in dir(self): # bypass inner method - if attr_name.startswith('_'): + if attr_name.startswith("_"): continue try: @@ -3092,8 +3037,8 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: - method_names = ','.join(unimplemented_methods) - msg = (f"Methods {method_names} not implemented in {self}") + method_names = ",".join(unimplemented_methods) + msg = f"Methods {method_names} not implemented in {self}" logger.debug(msg) @wraps(original_init) @@ -3101,7 +3046,7 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: original_init(self, *args, **kwargs) find_unimplemented_methods(self) - type.__setattr__(cls, '__init__', wrapped_init) + type.__setattr__(cls, "__init__", wrapped_init) return cls @@ -3173,12 +3118,12 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: @contextlib.contextmanager -def cprofile_context(save_file: Optional[str] = None): +def cprofile_context(save_file: str | None = None): """Run a cprofile Args: save_file: path to save the profile result. "1" or - None will result in printing to stdout. + None will result in printing to stdout. """ import cProfile @@ -3195,7 +3140,7 @@ def cprofile_context(save_file: Optional[str] = None): prof.print_stats(sort="cumtime") -def cprofile(save_file: Optional[str] = None, enabled: bool = True): +def cprofile(save_file: str | None = None, enabled: bool = True): """Decorator to profile a Python method using cProfile. Args: @@ -3205,7 +3150,6 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): """ def decorator(func: Callable): - @wraps(func) def wrapper(*args, **kwargs): if not enabled: @@ -3223,19 +3167,29 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): # Only relevant for models using ALiBi (e.g, MPT) def check_use_alibi(model_config: ModelConfig) -> bool: cfg = model_config.hf_text_config - return (getattr(cfg, "alibi", False) # Falcon - or ("BloomForCausalLM" in getattr(model_config.hf_config, - "architectures", [])) # Bloom - or getattr(cfg, "position_encoding_type", "") == - "alibi" # codellm_1b_alibi - or (hasattr(cfg, "attn_config") # MPT - and ((isinstance(cfg.attn_config, dict) - and cfg.attn_config.get("alibi", False)) or - (not isinstance(cfg.attn_config, dict) - and getattr(cfg.attn_config, "alibi", False))))) + return ( + getattr(cfg, "alibi", False) # Falcon + or ( + "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", []) + ) # Bloom + or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi + or ( + hasattr(cfg, "attn_config") # MPT + and ( + ( + isinstance(cfg.attn_config, dict) + and cfg.attn_config.get("alibi", False) + ) + or ( + not isinstance(cfg.attn_config, dict) + and getattr(cfg.attn_config, "alibi", False) + ) + ) + ) + ) -def sha256(input) -> int: +def sha256(input: Any) -> bytes: """Hash any picklable Python object using SHA-256. The input is serialized using pickle before hashing, which allows @@ -3246,16 +3200,15 @@ def sha256(input) -> int: input: Any picklable Python object. Returns: - An integer representing the SHA-256 hash of the serialized input. + Bytes representing the SHA-256 hash of the serialized input. 
""" input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - return int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") + return hashlib.sha256(input_bytes).digest() -def sha256_cbor_64bit(input) -> int: +def sha256_cbor(input: Any) -> bytes: """ - Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. + Hash objects using CBOR serialization and SHA-256. This option is useful for non-Python-dependent serialization and hashing. @@ -3266,17 +3219,13 @@ def sha256_cbor_64bit(input) -> int: Custom classes must implement CBOR serialization methods. Returns: - An integer in the range [0, 2^64-1] representing the lower 64 bits - of the SHA-256 hash of the CBOR serialized input. + Bytes representing the SHA-256 hash of the CBOR serialized input. """ input_bytes = cbor2.dumps(input, canonical=True) - full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") - - return full_hash & ((1 << 64) - 1) + return hashlib.sha256(input_bytes).digest() -def get_hash_fn_by_name(hash_fn_name: str) -> Callable: +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. Args: @@ -3286,10 +3235,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable: """ if hash_fn_name == "sha256": return sha256 - if hash_fn_name == "sha256_cbor_64bit": - return sha256_cbor_64bit - if hash_fn_name == "builtin": - return hash + if hash_fn_name == "sha256_cbor": + return sha256_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") @@ -3307,7 +3254,7 @@ def is_torch_equal_or_newer(target: str) -> bool: return _is_torch_equal_or_newer(str(torch.__version__), target) except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version('torch')) >= Version(target) + return Version(importlib.metadata.version("torch")) >= Version(target) # Helper function used in testing. @@ -3350,9 +3297,15 @@ def has_triton_kernels() -> bool: return _has_module("triton_kernels") -def set_process_title(name: str, - suffix: str = "", - append: bool = False) -> None: +def has_tilelang() -> bool: + """Whether the optional `tilelang` package is available.""" + + return _has_module("tilelang") + + +def set_process_title( + name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX +) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3360,15 +3313,11 @@ def set_process_title(name: str, Args: name: The title to assign to the current process. suffix: An optional suffix to append to the base name. - append: Whether to append to the existing process title. + prefix: A prefix to prepend to the front separated by `::`. 
""" if suffix: name = f"{name}_{suffix}" - if append: - name = f"{setproctitle.getproctitle()}_{name}" - else: - name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}" - setproctitle.setproctitle(name) + setproctitle.setproctitle(f"{prefix}::{name}") def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: @@ -3383,7 +3332,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: if file.start_new_line: # type: ignore[attr-defined] file_write(prefix) idx = 0 - while (next_idx := s.find('\n', idx)) != -1: + while (next_idx := s.find("\n", idx)) != -1: next_idx += 1 file_write(s[idx:next_idx]) if next_idx == len(s): @@ -3398,7 +3347,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: file.write = write_with_prefix # type: ignore[method-assign] -def decorate_logs(process_name: Optional[str] = None) -> None: +def decorate_logs(process_name: str | None = None) -> None: """ Adds a process-specific prefix to each line of output written to stdout and stderr. @@ -3418,3 +3367,60 @@ def decorate_logs(process_name: Optional[str] = None) -> None: pid = os.getpid() _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) + + +def length_from_prompt_token_ids_or_embeds( + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, +) -> int: + """Calculate the request length (in number of tokens) give either + prompt_token_ids or prompt_embeds. + """ + prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids) + prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds) + + if prompt_token_len is None: + if prompt_embeds_len is None: + raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.") + return prompt_embeds_len + else: + if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len: + raise ValueError( + "Prompt token ids and prompt embeds had different lengths" + f" prompt_token_ids={prompt_token_len}" + f" prompt_embeds={prompt_embeds_len}" + ) + return prompt_token_len + + +@contextlib.contextmanager +def set_env_var(key, value): + old = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if old is None: + del os.environ[key] + else: + os.environ[key] = old + + +def unique_filepath(fn: Callable[[int], Path]) -> Path: + """ + unique_filepath returns a unique path by trying + to include an integer in increasing order. + + fn should be a callable that returns a path that + includes the passed int at a fixed location. + + Note: This function has a TOCTOU race condition. + Caller should use atomic operations (e.g., open with 'x' mode) + when creating the file to ensure thread safety. + """ + i = 0 + while True: + p = fn(i) + if not p.exists(): + return p + i += 1 diff --git a/vllm/utils/cache.py b/vllm/utils/cache.py new file mode 100644 index 0000000000000..a57ef9b70ccc8 --- /dev/null +++ b/vllm/utils/cache.py @@ -0,0 +1,220 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from collections import UserDict +from collections.abc import Hashable, Iterator, KeysView, Mapping +from types import MappingProxyType +from typing import Callable, Generic, NamedTuple, TypeVar, Union, cast, overload + +import cachetools + +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") +_T = TypeVar("_T") + + +class _Sentinel: ... 
+ + +ALL_PINNED_SENTINEL = _Sentinel() + + +class _MappingOrderCacheView(UserDict[_K, _V]): + def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]): + super().__init__(data) + self.ordered_keys = ordered_keys + + def __iter__(self) -> Iterator[_K]: + return iter(self.ordered_keys) + + def keys(self) -> KeysView[_K]: + return KeysView(self.ordered_keys) + + +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + def __sub__(self, other: CacheInfo): + return CacheInfo( + hits=self.hits - other.hits, + total=self.total - other.total, + ) + + +class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): + def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None): + super().__init__(capacity, getsizeof) + + self.pinned_items = set[_K]() + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def __getitem__(self, key: _K, *, update_info: bool = True) -> _V: + value = super().__getitem__(key) + + if update_info: + self._hits += 1 + self._total += 1 + + return value + + def __delitem__(self, key: _K) -> None: + run_on_remove = key in self + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] + super().__delitem__(key) + if key in self.pinned_items: + # Todo: add warning to inform that del pinned item + self._unpin(key) + if run_on_remove: + self._on_remove(key, value) + + @property + def cache(self) -> Mapping[_K, _V]: + """Return the internal cache dictionary in order (read-only).""" + return _MappingOrderCacheView( + self._Cache__data, # type: ignore + self.order, + ) + + @property + def order(self) -> Mapping[_K, None]: + """Return the internal order dictionary (read-only).""" + return MappingProxyType(self._LRUCache__order) # type: ignore + + @property + def capacity(self) -> float: + return self.maxsize + + @property + def usage(self) -> float: + if self.maxsize == 0: + return 0 + + return self.currsize / self.maxsize + + def stat(self, *, delta: bool = False) -> CacheInfo: + """ + Gets the cumulative number of hits and queries against this cache. + + If `delta=True`, instead gets these statistics + since the last call that also passed `delta=True`. + """ + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info + + def touch(self, key: _K) -> None: + try: + self._LRUCache__order.move_to_end(key) # type: ignore + except KeyError: + self._LRUCache__order[key] = None # type: ignore + + @overload + def get(self, key: _K, /) -> _V | None: ... + + @overload + def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... + + def get( + self, key: _K, /, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None + if key in self: + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] + + self._hits += 1 + else: + value = default + + self._total += 1 + return value + + @overload + def pop(self, key: _K) -> _V: ... + + @overload + def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... 
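A short usage sketch for the `LRUCache` in this new `vllm/utils/cache.py` module, showing how pinning interacts with eviction and how `stat()` reports hit ratios; the keys and the capacity are arbitrary, and the import path assumes the module lands as written.

```python
from vllm.utils.cache import LRUCache

cache = LRUCache[str, int](capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.pin("a")      # "a" is skipped by LRU eviction until unpinned

cache.put("c", 3)   # over capacity: evicts "b", the oldest unpinned key
assert cache.get("a") == 1
assert cache.get("b") is None

info = cache.stat()                 # cumulative hits / total lookups
print(f"hit ratio: {info.hit_ratio:.2f}")
print(f"usage: {cache.usage:.2f}")  # currsize / maxsize
```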
+ + def pop( + self, key: _K, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None + if key not in self: + return default + + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] + self.__delitem__(key) + return value + + def put(self, key: _K, value: _V) -> None: + self.__setitem__(key, value) + + def pin(self, key: _K) -> None: + """ + Pins a key in the cache preventing it from being + evicted in the LRU order. + """ + if key not in self: + raise ValueError(f"Cannot pin key: {key} not in cache.") + self.pinned_items.add(key) + + def _unpin(self, key: _K) -> None: + """ + Unpins a key in the cache allowing it to be + evicted in the LRU order. + """ + self.pinned_items.remove(key) + + def _on_remove(self, key: _K, value: _V | None) -> None: + pass + + def remove_oldest(self, *, remove_pinned: bool = False) -> None: + if len(self) == 0: + return + + self.popitem(remove_pinned=remove_pinned) + + def _remove_old_if_needed(self) -> None: + while self.currsize > self.capacity: + self.remove_oldest() + + def popitem(self, remove_pinned: bool = False): + """Remove and return the `(key, value)` pair least recently used.""" + if not remove_pinned: + # pop the oldest item in the cache that is not pinned + lru_key = next( + (key for key in self.order if key not in self.pinned_items), + ALL_PINNED_SENTINEL, + ) + if lru_key is ALL_PINNED_SENTINEL: + raise RuntimeError( + "All items are pinned, cannot remove oldest from the cache." + ) + else: + lru_key = next(iter(self.order)) + value = self.pop(cast(_K, lru_key)) + return (lru_key, value) + + def clear(self) -> None: + while len(self) > 0: + self.remove_oldest(remove_pinned=True) + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index c0a4ed077e660..8f8f25f1302d6 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import functools @@ -26,111 +27,124 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100)) - return has_deep_gemm() and is_supported_arch + or current_platform.is_device_capability(100) + ) + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @functools.cache -def is_blackwell_deep_gemm_e8m0_used() -> bool: +def is_deep_gemm_e8m0_used() -> bool: """Return ``True`` if vLLM is configured to use DeepGEMM " - "E8M0 scale on a Blackwell-class GPU. + "E8M0 scale on a Hopper or Blackwell-class GPU. """ - if not (envs.VLLM_USE_DEEP_GEMM): - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.") - return False - - if not has_deep_gemm(): - logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.") - return False - - if not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") + if not is_deep_gemm_supported(): + logger.debug_once( + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system." 
+ ) return False _lazy_init() if _fp8_gemm_nt_impl is None: - logger.debug_once( - "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - enabled = (current_platform.is_cuda() - and current_platform.has_device_capability(100)) - if enabled: - logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") - else: - logger.debug_once( - "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") - return enabled + if envs.VLLM_USE_FLASHINFER_MOE_FP8: + logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.") + return False + + if envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on current platform.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: """Placeholder for unavailable DeepGEMM backend.""" raise RuntimeError( - "DeepGEMM backend is not available. Please install the `deep_gemm` " - "package to enable FP8 kernels.") - - -def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: - """Return the *new* symbol if it exists, otherwise the *old* one.""" - if hasattr(module, new): - return getattr(module, new) - if hasattr(module, old): - # TODO(wentao): deprecate old symbol in the future. - logger.warning_once( - "Found legacy DeepGEMM symbol `%s`. Please upgrade the `deep_gemm` " - "package so that `%s` is available. Support for the legacy symbol " - "will be removed in a future vLLM release.", - old, - new, - ) - return getattr(module, old) - return None + "DeepGEMM backend is not available or outdated. Please install or " + "update the `deep_gemm` to a newer version to enable FP8 kernels." + ) _fp8_gemm_nt_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None +_fp8_mqa_logits_impl: Callable[..., Any] | None = None +_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None +_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None +_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None def _lazy_init() -> None: """Import deep_gemm and resolve symbols on first use.""" global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl + global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl + global _get_paged_mqa_logits_metadata_impl + global _get_mn_major_tma_aligned_tensor_impl # fast path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): return # Set up deep_gemm cache path - DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR' + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( - envs.VLLM_CACHE_ROOT, "deep_gemm") + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) _dg = importlib.import_module("deep_gemm") - _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt", - "gemm_fp8_fp8_bf16_nt") - _grouped_impl = _resolve_symbol( - _dg, "m_grouped_fp8_gemm_nt_contiguous", - "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous") - _grouped_masked_impl = _resolve_symbol( - _dg, "fp8_m_grouped_gemm_nt_masked", - 
"m_grouped_gemm_fp8_fp8_bf16_nt_masked") + _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) + _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) + _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) + _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) + _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) + _get_paged_mqa_logits_metadata_impl = getattr( + _dg, "get_paged_mqa_logits_metadata", None + ) + _get_mn_major_tma_aligned_tensor_impl = getattr( + _dg, "get_mn_major_tma_aligned_tensor", None + ) + + +def get_num_sms() -> int: + _lazy_init() + _dg = importlib.import_module("deep_gemm") + return int(_dg.get_num_sms()) + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor""" + _lazy_init() + if _get_mn_major_tma_aligned_tensor_impl is None: + return _missing() + return _get_mn_major_tma_aligned_tensor_impl(x) def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + if "is_deep_gemm_e8m0_used" in kwargs: + use_ue8m0 = kwargs["is_deep_gemm_e8m0_used"] + del kwargs["is_deep_gemm_e8m0_used"] + else: + use_ue8m0 = is_deep_gemm_e8m0_used() + return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): @@ -138,9 +152,8 @@ def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): if _grouped_impl is None: return _missing(*args, **kwargs) return _grouped_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -148,9 +161,104 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) + + +def fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + _lazy_init() + if _fp8_mqa_logits_impl is None: + return _missing() + return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: + """Build scheduling metadata for paged MQA logits. + + Args: + context_lens: Tensor of shape [B], dtype int32; effective context length + per batch element. + block_size: KV-cache block size in tokens (e.g., 64). 
+ num_sms: Number of SMs available. 132 for Hopper + + Returns: + Backend-specific tensor consumed by `fp8_paged_mqa_logits` to + schedule work across SMs. + """ + _lazy_init() + if _get_paged_mqa_logits_metadata_impl is None: + return _missing() + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) + + +def fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. + """ + _lazy_init() + if _fp8_paged_mqa_logits_impl is None: + return _missing() + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) def _ceil_to_ue8m0(x: torch.Tensor): @@ -165,17 +273,16 @@ DEFAULT_BLOCK_SIZE = [128, 128] # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38 -# TODO(wentao): optimize this function, using triton or cuda kernel +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def per_block_cast_to_fp8( - x: torch.Tensor, - block_size: list[int] = DEFAULT_BLOCK_SIZE, - use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape block_m, block_n = block_size - x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -183,7 +290,8 @@ def per_block_cast_to_fp8( sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( - x_view.size(0), x_view.size(2)) + x_view.size(0), x_view.size(2) + ) def calc_diff(x: torch.Tensor, y: torch.Tensor): @@ -202,10 +310,19 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, - weight: torch.Tensor): - return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) +def should_use_deepgemm_for_fp8_linear( + output_dtype: torch.dtype, + 
weight: torch.Tensor, + supports_deep_gemm: bool | None = None, +): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0 + ) __all__ = [ @@ -213,8 +330,13 @@ __all__ = [ "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", + "fp8_mqa_logits", + "fp8_paged_mqa_logits", + "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_e8m0_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "get_num_sms", "should_use_deepgemm_for_fp8_linear", + "get_col_major_tma_aligned_tensor", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5dd239c50f637..ad8295f8f6893 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import contextlib @@ -11,7 +12,8 @@ import functools import importlib import importlib.util import os -from typing import Any, Callable, NoReturn, Optional +import shutil +from typing import Any, Callable, NoReturn import requests import torch @@ -36,7 +38,14 @@ def has_flashinfer() -> bool: """Return ``True`` if FlashInfer is available.""" # Use find_spec to check if the module exists without importing it # This avoids potential CUDA initialization side effects - return importlib.util.find_spec("flashinfer") is not None + if importlib.util.find_spec("flashinfer") is None: + logger.debug_once("FlashInfer unavailable since package was not found") + return False + # Also check if nvcc is available since it's required to JIT compile flashinfer + if shutil.which("nvcc") is None: + logger.debug_once("FlashInfer unavailable since nvcc was not found") + return False + return True def _missing(*_: Any, **__: Any) -> NoReturn: @@ -44,7 +53,8 @@ def _missing(*_: Any, **__: Any) -> NoReturn: raise RuntimeError( "FlashInfer backend is not available. 
Please install the package " "to enable FlashInfer kernels: " - "https://github.com/flashinfer-ai/flashinfer") + "https://github.com/flashinfer-ai/flashinfer" + ) def _get_submodule(module_name: str) -> Any | None: @@ -56,9 +66,9 @@ def _get_submodule(module_name: str) -> Any | None: # General lazy import wrapper -def _lazy_import_wrapper(module_name: str, - attr_name: str, - fallback_fn: Callable[..., Any] = _missing): +def _lazy_import_wrapper( + module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing +): """Create a lazy import wrapper for a specific function.""" @functools.cache @@ -79,29 +89,64 @@ def _lazy_import_wrapper(module_name: str, # Create lazy wrappers for each function flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe") + "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe" +) flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe") -flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", - "cutlass_fused_moe") -fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") + "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe" +) +flashinfer_cutlass_fused_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "cutlass_fused_moe" +) +flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_block_scale_interleave = _lazy_import_wrapper( - "flashinfer", "nvfp4_block_scale_interleave") + "flashinfer", "nvfp4_block_scale_interleave" +) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( - "flashinfer", "trtllm_fp4_block_scale_moe") + "flashinfer", "trtllm_fp4_block_scale_moe" +) # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", "autotune", - fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), +) + + +@functools.cache +def has_flashinfer_comm() -> bool: + """Return ``True`` if FlashInfer comm module is available.""" + return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None + + +@functools.cache +def has_flashinfer_all2all() -> bool: + """Return ``True`` if FlashInfer mnnvl all2all is available.""" + if not has_flashinfer_comm(): + return False + + # Check if all required functions are available + required_functions = [ + ("flashinfer.comm", "Mapping"), + ("flashinfer.comm.mnnvl", "MnnvlMemory"), + ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"), + ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True @functools.cache def has_flashinfer_moe() -> bool: """Return ``True`` if FlashInfer MoE module is available.""" - return has_flashinfer() and importlib.util.find_spec( - "flashinfer.fused_moe") is not None + return ( + has_flashinfer() + and importlib.util.find_spec("flashinfer.fused_moe") is not None + ) @functools.cache @@ -146,7 +191,8 @@ def has_nvidia_artifactory() -> bool: else: logger.warning_once( "NVIDIA artifactory returned failed status code: %d", - response.status_code) + response.status_code, + ) return accessible except Exception as e: logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e) @@ -154,28 +200,38 @@ def has_nvidia_artifactory() -> bool: @functools.cache -def supports_trtllm_attention() -> tuple[bool, 
Optional[str]]: - """Cache result which only depends on the environment""" - # This is a lambda, call it once - env_value = envs.VLLM_USE_TRTLLM_ATTENTION - +def supports_trtllm_attention() -> bool: + """ + TRTLLM attention is supported if the platform is SM100 and + NVIDIA artifactory is accessible + """ # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - if not (current_platform.is_device_capability(100) - and has_nvidia_artifactory()): - return False, env_value + return current_platform.is_device_capability(100) and has_nvidia_artifactory() + +@functools.cache +def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: + """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) - # Environment variable is set - respect it - # Making the conditional check for zero because - # the path is automatically enabled if the batch size condition - # is satisfied. - use_trtllm = (env_value == "1") - if use_trtllm: - logger.info_once("Using TRTLLM attention.") - return use_trtllm, env_value + return env_value - return True, None + +def force_use_trtllm_attention() -> bool | None: + """ + Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, + return ``True`` if TRTLLM attention is forced to be used, + return ``False`` if TRTLLM attention is forced to be not used. + """ + return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) + + +def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: + """Check if the current configuration supports TRTLLM attention.""" + if force_use_trtllm_attention() is False: + return False + has_trtllm = supports_trtllm_attention() + return has_trtllm and (num_qo_heads % num_kv_heads == 0) def use_trtllm_attention( @@ -187,40 +243,67 @@ def use_trtllm_attention( q_dtype: torch.dtype, is_prefill: bool, has_sinks: bool = False, + has_spec: bool = False, ) -> bool: - use_trtllm, env_value = supports_trtllm_attention() - if not use_trtllm: + """Return ``True`` if TRTLLM attention is used.""" + force_use_trtllm = force_use_trtllm_attention() + + # Environment variable is set to 0 - respect it + if force_use_trtllm is not None and not force_use_trtllm: return False - if num_qo_heads % num_kv_heads != 0: + # The platform is not supported + if not supports_trtllm_attention(): + if force_use_trtllm: + logger.warning_once( + "TRTLLM attention is not supported on this platform, " + "but VLLM_USE_TRTLLM_ATTENTION is set to 1" + ) return False + # The combination of query and key heads is not supported + if num_qo_heads % num_kv_heads != 0: + if force_use_trtllm: + logger.warning_once( + "TRTLLM attention is not supported for this combination of " + "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1" + ) + return False + + if has_spec and not is_prefill: + # Speculative decoding requires TRTLLM attention for decodes + logger.info_once("Using TRTLLM attention (enabled for speculative decoding).") + return True + # Must use TRTLLM attention if query is FP8 quantized if q_dtype == current_platform.fp8_dtype(): logger.info_once("Using TRTLLM attention (query is quantized).") return True - # TRTLLM prefill attention does not support FP8 kv cache with - # non-quantized query - if is_prefill and kv_cache_dtype.startswith("fp8"): - return False - # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: - logger.info_once( - "Using TRTLLM attention (required for 
attention sinks).") + logger.info_once("Using TRTLLM attention (required for attention sinks).") return True - if env_value is None: + if force_use_trtllm is None: # Environment variable not set - use auto-detection - use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 - and kv_cache_dtype == "auto") - if use_trtllm: - logger.warning_once("Using TRTLLM attention (auto-detected).") + if is_prefill: + # Prefill auto-detection + use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto" + if use_trtllm: + logger.warning_once("Using TRTLLM prefill attention (auto-detected).") + else: + # Decode auto-detection + use_trtllm = ( + num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" + ) + if use_trtllm: + logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm # Environment variable is set to 1 - respect it + logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") return True @@ -241,16 +324,14 @@ if has_flashinfer(): backend: str, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ - return flashinfer_mm_fp4_(A, - B, - A_scale, - B_scale, - g_scale, - dtype, - block_size=16, - backend=backend) - @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + return flashinfer_mm_fp4_( + A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + ) + + @torch.library.register_fake( + "vllm::flashinfer_mm_fp4", + ) def flashinfer_mm_fp4_fake( A: torch.Tensor, B: torch.Tensor, @@ -260,23 +341,54 @@ if has_flashinfer(): dtype: torch.dtype, backend: str, ) -> torch.Tensor: - return torch.empty(A.shape[0], - B.shape[1], - dtype=dtype, - device=A.device) + return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) + + @torch.library.custom_op( + "vllm::bmm_fp8", + mutates_args=[], + device_types="cuda", + ) + def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + from flashinfer import bmm_fp8 as bmm_fp8_ + + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) + + @torch.library.register_fake( + "vllm::bmm_fp8", + ) + def bmm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + backend: str, + ) -> torch.Tensor: + return torch.empty( + A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device + ) -def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, alpha: torch.Tensor, - out_dtype: torch.dtype, - backend: str) -> torch.Tensor: +def flashinfer_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] - assert block_scale_a.shape[1] == a.shape[1] // 8 - assert block_scale_b.shape[1] == b.shape[1] // 8 if backend == "cutlass": block_scale_a = block_scale_a.view(torch.uint8) @@ -293,18 +405,59 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, ) +def flashinfer_scaled_fp8_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + assert 
a.shape[1] == b.shape[0] + assert scale_a.numel() == 1 and scale_b.numel() == 1 + assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn + assert a.device.type == "cuda" and b.device.type == "cuda" + assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32 + assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda" + + output = bmm_fp8( + a.unsqueeze(0), + b.unsqueeze(0), + scale_a, + scale_b, + out_dtype, + "auto", + ).view(a.shape[0], b.shape[1]) + + if bias is not None: + output = output + bias + return output + + +@functools.cache +def flashinfer_disable_q_quantization() -> bool: + """Cache result which only depends on the environment""" + return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_cutlass_fused_moe", - "fp4_quantize", + "flashinfer_fp4_quantize", "nvfp4_block_scale_interleave", "trtllm_fp4_block_scale_moe", "autotune", "has_flashinfer_moe", + "has_flashinfer_comm", + "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", "supports_trtllm_attention", + "can_use_trtllm_attention", "use_trtllm_attention", + "flashinfer_disable_q_quantization", "flashinfer_scaled_fp4_mm", + "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py new file mode 100644 index 0000000000000..e3b5b61dd3643 --- /dev/null +++ b/vllm/utils/gc_utils.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import json +import time +from collections import Counter +from contextlib import suppress +from typing import Any, Optional + +from vllm.envs import VLLM_GC_DEBUG +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class GCDebugConfig: + """ + Config for GC Debugger. + - 0: disable GC debugger + - 1: enable GC debugger with gc.collect elpased times + - '{"top_objects":5}': enable GC debugger with top 5 collected objects + """ + + def __init__(self, gc_debug_conf: Optional[str] = None) -> None: + self.enabled: bool = False + self.top_objects: int = -1 + + if not gc_debug_conf or gc_debug_conf == "0": + pass + elif gc_debug_conf == "1": + self.enabled = True + else: + try: + json_conf = json.loads(gc_debug_conf) + self.enabled = True + self.top_objects = json_conf.get("top_objects", -1) + except Exception: + self.enabled = False + logger.error("Failed to parse VLLM_GC_DEBUG(%s)", VLLM_GC_DEBUG) + logger.info("GC Debug Config. %s", str(self)) + + def __repr__(self) -> str: + return f"enabled:{self.enabled},top_objects:{self.top_objects}" + + +class GCDebugger: + """ + Debugger for GC which logs helpful information for GC understanding. + To enable, you should call maybe_attach_gc_debug_callback in the process. + """ + + def __init__(self, config: GCDebugConfig) -> None: + self.config = config + # Start time in micro second of this GC cycle + self.start_time_ns: int = time.monotonic_ns() + # If config.top_objects is positive, + # compute top collected objects by object types + self.gc_top_collected_objects: str = "" + + def handle(self, phase: str, info: dict[str, int]) -> None: + """ + Handles a GC event (e.g. 
GC start or GC finish) + """ + generation = info.get("generation") + if generation is None: + return + if phase == "start": + # Before GC started, record GC start time + # and top collected objects + self.start_time_ns = time.monotonic_ns() + self.gc_top_collected_objects = _compute_top_gc_collected_objects( + gc.get_objects(generation), self.config.top_objects + ) + elif phase == "stop": + # After GC finished, Record GC elapsed time and + # optionally top collected objects + elpased_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6 + logger.info( + "GC took %.3fms to complete. " + "Collected %s objects in GC generation %d.%s", + elpased_ms, + str(info.get("collected", "?")), + generation, + ( + f" Top collected objects: \n{self.gc_top_collected_objects}" + if self.gc_top_collected_objects + else "" + ), + ) + + +def maybe_attach_gc_debug_callback() -> None: + """ + Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. + """ + config = GCDebugConfig(VLLM_GC_DEBUG) + if config.enabled: + debugger: GCDebugger = GCDebugger(config) + + def gc_callback(phase: str, info: dict[str, int]) -> None: + debugger.handle(phase, info) + + gc.callbacks.append(gc_callback) + + +def _compute_detailed_type(o: Any) -> str: + """ + Detailed object type. + + TODO(Jialin): Further enhance the detailed type with element types for + easier debugging. We tried but occasionally it would run into signals + which kills the engine. + """ + size_str: str = "" + # Object doesn't support len() - this can happen with type objects + # or other objects that don't implement __len__ properly + with suppress(Exception): + size_str = f"(size:{len(o)})" + return f"{str(type(o))}{size_str}" + + +def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str: + """ + Group collected objects by types. + """ + if top <= 0: + return "" + object_types = [_compute_detailed_type(o) for o in objects] + return "\n".join( + f"{count:>5}:{object_type}" + for object_type, count in Counter(object_types).most_common(top) + ) diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index 4cbe0f76e0067..dcdc6ccb4c638 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -1,17 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helper functions to work with nested JSON structures.""" + from collections.abc import Iterable from functools import reduce -from typing import Callable, TypeVar, Union, overload +from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload + +if TYPE_CHECKING: + import torch + + from vllm.multimodal.inputs import BatchedTensorInputs _T = TypeVar("_T") _U = TypeVar("_U") -JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], - tuple["JSONTree[_T]", ...], _T] +JSONTree = Union[ + dict[str, "JSONTree[_T]"], + list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], + _T, +] """A nested JSON structure where the leaves need not be JSON-serializable.""" +_JSONTree = Union[ + dict[str, "JSONTree[_T]"], + list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], + dict[str, _T], + list[_T], + tuple[_T, ...], + _T, +] +""" +Same as `JSONTree` but with additional `Union` members to satisfy overloads. 
+""" + def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: """Iterate through each leaf in a nested JSON structure.""" @@ -25,13 +48,51 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: yield value +@overload +def json_map_leaves( + func: Callable[["torch.Tensor"], "torch.Tensor"], + value: "BatchedTensorInputs", +) -> "BatchedTensorInputs": ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, dict[str, _T]], +) -> Union[_U, dict[str, _U]]: ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, list[_T]], +) -> Union[_U, list[_U]]: ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: Union[_T, tuple[_T, ...]], +) -> Union[_U, tuple[_U, ...]]: ... + + +@overload def json_map_leaves( func: Callable[[_T], _U], value: JSONTree[_T], -) -> JSONTree[_U]: +) -> JSONTree[_U]: ... + + +def json_map_leaves( + func: Callable[[_T], _U], + value: Union["BatchedTensorInputs", _JSONTree[_T]], +) -> Union["BatchedTensorInputs", _JSONTree[_U]]: """Apply a function to each leaf in a nested JSON structure.""" if isinstance(value, dict): - return {k: json_map_leaves(func, v) for k, v in value.items()} + return { + k: json_map_leaves(func, v) # type: ignore[arg-type] + for k, v in value.items() + } elif isinstance(value, list): return [json_map_leaves(func, v) for v in value] elif isinstance(value, tuple): @@ -40,13 +101,36 @@ def json_map_leaves( return func(value) +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, dict[str, _T]], + /, +) -> _T: ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, list[_T]], + /, +) -> _T: ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: Union[_T, tuple[_T, ...]], + /, +) -> _T: ... + + @overload def json_reduce_leaves( func: Callable[[_T, _T], _T], value: JSONTree[_T], /, -) -> _T: - ... +) -> _T: ... @overload @@ -55,14 +139,13 @@ def json_reduce_leaves( value: JSONTree[_T], initial: _U, /, -) -> _U: - ... +) -> _U: ... 
def json_reduce_leaves( func: Callable[..., Union[_T, _U]], - value: JSONTree[_T], - initial: _U = ..., # type: ignore[assignment] + value: _JSONTree[_T], + initial: _U = cast(_U, ...), # noqa: B008 /, ) -> Union[_T, _U]: """ @@ -78,3 +161,8 @@ def json_reduce_leaves( json_iter_leaves(value), initial, ) + + +def json_count_leaves(value: JSONTree[_T]) -> int: + """Count the number of leaves in a nested JSON structure.""" + return sum(1 for _ in json_iter_leaves(value)) diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 21d3249fe1547..e17676ccf7ef2 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (Annotated, Any, Optional, Union, get_args, get_origin, - get_type_hints) +from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints import torch @@ -11,7 +10,6 @@ logger = init_logger(__name__) class TensorShape: - def __init__( self, *dims: Union[int, str], @@ -22,9 +20,8 @@ class TensorShape: self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() - def resolve(self, **bindings: dict[str, - int]) -> tuple[Union[int, str], ...]: - resolved = [] + def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]: + resolved = list[Union[int, str]]() for dim in self.dims: if isinstance(dim, str) and dim in bindings: resolved.append(bindings[dim]) @@ -38,8 +35,7 @@ class TensorShape: for dim in self.dims: if isinstance(dim, str): if dim in self.dynamic_dims: - dim_strs.append( - f"{dim}*") # Mark dynamic dimensions with * + dim_strs.append(f"{dim}*") # Mark dynamic dimensions with * else: dim_strs.append(dim) else: @@ -48,7 +44,6 @@ class TensorShape: class TensorSchema: - def __init__( self, *, @@ -95,34 +90,66 @@ class TensorSchema: return False return True - def _validate_nested_tensors( + def _fmt_indexer(self, idxs: tuple[int, ...]) -> str: + if not idxs: + return "" + + return str(list(idxs)) + + def _validate_field( self, - value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + value: object, field_name: str, expected_shape: tuple[Union[int, str], ...], dynamic_dims: set[str], + leading_idxs: tuple[int, ...] = (), ) -> tuple[int, ...]: - """Validate a list/tuple of tensors and return the actual shape.""" + """Validate a field and return the actual shape.""" + if isinstance(value, (int, float)): + return () # Scalar + if isinstance(value, torch.Tensor): + return value.shape + + if not isinstance(value, (list, tuple)): + raise TypeError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is not " + f"one of the expected types: int, float, Tensor, list, tuple. 
" + f"Got: {type(value)}" + ) + + if len(value) == 0: + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence" + ) + # Ensure all tensors in the list have the same # shape, besides dynamic dimensions - first = value[0] for i, v in enumerate(value): - if not isinstance(v, torch.Tensor): - raise ValueError(f"{field_name}[{i}] is not a " - f"torch.Tensor") - if not self._match_shape_with_dynamic( - v.shape, - first.shape, - expected_shape, - dynamic_dims, + shape = self._validate_field( + v, + field_name, + expected_shape[1:], + dynamic_dims, + leading_idxs=leading_idxs + (i,), + ) + + if i == 0: + first_shape = shape + elif not self._match_shape_with_dynamic( + shape, + first_shape, + expected_shape, + dynamic_dims, ): - raise ValueError(f"{field_name} contains inconsistent " - f"shapes: {first.shape} vs {v.shape} " - f"at index {i}") + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} " + f"contains inconsistent shapes: {first_shape} " + f"(index 0) vs {shape} (index {i})" + ) # Treat the list as a stacked tensor: # shape = (len(list), *tensor.shape) - return (len(value), ) + first.shape + return (len(value),) + first_shape def _validate_tensor_shape_expected( self, @@ -135,36 +162,46 @@ class TensorSchema: """Validate that the actual tensor shape matches the expected shape.""" if len(actual_shape) != len(expected_shape): - raise ValueError(f"{field_name} has rank {len(actual_shape)} " - f"but expected {len(expected_shape)}") + raise ValueError( + f"{field_name} has rank {len(actual_shape)} " + f"but expected {len(expected_shape)}. " + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) for i, dim in enumerate(expected_shape): if dim in dynamic_dims: continue elif isinstance(dim, int): if actual_shape[i] != dim: - raise ValueError(f"{field_name} dim[{i}] expected " - f"{dim}, got {actual_shape[i]}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"{dim}, got {actual_shape[i]}. " + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) elif isinstance(dim, str): if dim in shape_env: if actual_shape[i] != shape_env[dim]: - raise ValueError(f"{field_name} dim[{i}] expected " - f"'{dim}'={shape_env[dim]}, got " - f"{actual_shape[i]}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"'{dim}'={shape_env[dim]}, got " + f"{actual_shape[i]}" + ) else: shape_env[dim] = actual_shape[i] else: - raise TypeError(f"{field_name} dim[{i}] has unsupported " - f"type: {type(dim)}") + raise TypeError( + f"{field_name} dim[{i}] has unsupported type: {type(dim)}" + ) def validate(self) -> None: type_hints = get_type_hints(self.__class__, include_extras=True) - shape_env = {} + shape_env = dict[str, int]() for field_name, field_type in type_hints.items(): # Check if field is missing - if (not hasattr(self, field_name) - or getattr(self, field_name) is None): + if not hasattr(self, field_name) or getattr(self, field_name) is None: # Check if field is marked as optional actual_type = field_type if get_origin(field_type) is Annotated: @@ -188,40 +225,20 @@ class TensorSchema: for arg in args: if isinstance(arg, TensorShape): expected_shape = arg.resolve(**self._resolve_bindings) - if isinstance(value, (list, tuple)): - # list/tuple of Tensors → shape = (len(value), ...) 
- if value and isinstance(value[0], torch.Tensor): - actual_shape = self._validate_nested_tensors( - value, field_name, expected_shape, - arg.dynamic_dims) - elif value: - # list/tuple of scalars → shape = (len(value),) - actual_shape = (len(value), ) - else: - raise ValueError( - f"{field_name} is an empty list") - - # Tensor → shape = tensor.shape - elif isinstance(value, torch.Tensor): - actual_shape = value.shape - - # Otherwise, it's an unsupported type - else: - type_names = [] - for arg in args: - if hasattr(arg, "__name__"): - type_names.append(str(arg.__name__)) - else: - type_names.append(str(arg)) - - expected_types = ", ".join(type_names) - raise ValueError( - f"{field_name} is not one of the expected " - f"types: {expected_types}") + actual_shape = self._validate_field( + value, + field_name, + expected_shape, + arg.dynamic_dims, + ) self._validate_tensor_shape_expected( - actual_shape, expected_shape, field_name, - shape_env, arg.dynamic_dims) + actual_shape, + expected_shape, + field_name, + shape_env, + arg.dynamic_dims, + ) def print_shapes(self) -> None: """Print TensorShape annotations for debugging.""" diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 973979fdf7dfd..6e27e93c91153 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -7,21 +7,26 @@ import numpy as np import torch from torch.nn.functional import scaled_dot_product_attention -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.gpu_input_batch import InputBatch try: import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True # AttributeError is to handle a bug in ipex # https://github.com/intel/intel-extension-for-pytorch/pull/813 @@ -43,19 +48,19 @@ class TorchSDPABackend(AttentionBackend): @classmethod def validate_head_size(cls, head_size: int) -> None: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size( - head_size) + is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) if not is_valid: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: - return "TORCH_SDPA_VLLM_V1" + return "TORCH_SDPA" @staticmethod def get_impl_cls() -> type["TorchSDPABackendImpl"]: @@ -65,10 +70,6 @@ class TorchSDPABackend(AttentionBackend): def get_metadata_cls() -> type["AttentionMetadata"]: return TorchSDPAMetadata - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]: return TorchSDPAMetadataBuilderV1 @@ -79,9 +80,11 @@ class TorchSDPABackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return _get_paged_attn_impl().get_kv_cache_shape( - num_blocks, block_size, num_kv_heads, head_size) + num_blocks, block_size, num_kv_heads, head_size + ) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -90,19 +93,33 @@ class TorchSDPABackend(AttentionBackend): @dataclass class TorchSDPAMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - seq_lens_tensor: Optional[torch.Tensor] + decode_seq_lens_tensor: Optional[torch.Tensor] # Maximum sequence length in the batch. 0 if it is prefill-only batch. - max_decode_seq_len: int + decode_max_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. - block_tables: Optional[torch.Tensor] + decode_block_tables: Optional[torch.Tensor] """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -112,9 +129,9 @@ class TorchSDPAMetadata(AttentionMetadata): # For chunked prefill only max_query_len: Optional[int] = None - max_kv_len: Optional[int] = None + prefill_max_seq_len: Optional[int] = None prefill_query_start_loc: Optional[torch.Tensor] = None - kv_start_loc: Optional[torch.Tensor] = None + prefill_seq_start_loc: Optional[torch.Tensor] = None prefill_block_tables: Optional[torch.Tensor] = None # For V1 logits index only @@ -148,23 +165,27 @@ class TorchSDPAMetadata(AttentionMetadata): @property def is_all_encoder_attn_metadata_set(self): - ''' + """ All attention metadata required for encoder attention is set. - ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + """ + return ( + (self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None) + ) @property def is_all_cross_attn_metadata_set(self): - ''' + """ All attention metadata required for enc/dec cross-attention is set. 
Superset of encoder attention required metadata. - ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + """ + return ( + self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None) + ) @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: @@ -182,7 +203,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ): - ''' + """ Extract appropriate sequence lengths from attention metadata according to attention type. @@ -195,10 +216,12 @@ class TorchSDPAMetadata(AttentionMetadata): Returns: * Appropriate sequence lengths tensor for query * Appropriate sequence lengths tensor for key & value - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): seq_lens_q = self.seq_lens seq_lens_kv = self.seq_lens elif attn_type == AttentionType.ENCODER: @@ -215,7 +238,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ) -> Optional[list[torch.Tensor]]: - ''' + """ Extract appropriate attention bias from attention metadata according to attention type. @@ -227,10 +250,12 @@ class TorchSDPAMetadata(AttentionMetadata): Returns: * Appropriate attention bias value given the attention type - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): return self.attn_bias elif attn_type == AttentionType.ENCODER: return self.encoder_attn_bias @@ -244,7 +269,7 @@ class TorchSDPAMetadata(AttentionMetadata): attn_bias: list[torch.Tensor], attn_type: str, ) -> None: - ''' + """ Update appropriate attention bias field of attention metadata, according to attention type. @@ -254,10 +279,12 @@ class TorchSDPAMetadata(AttentionMetadata): * attn_bias: The desired attention bias value * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): self.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: self.encoder_attn_bias = attn_bias @@ -270,7 +297,7 @@ class TorchSDPAMetadata(AttentionMetadata): self, attn_type: str, ) -> tuple: - ''' + """ The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation. 
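# A minimal, self-contained sketch (hypothetical sizes) of the decode-first
# batch layout that the rewritten TorchSDPA metadata and builder below rely
# on: after split_decodes_and_prefills(), decode requests (one token each)
# come first and prefill tokens follow, so the forward pass can slice
# output[:num_decode_tokens] for the paged decode kernel and
# output[num_decode_tokens:] for the prefill path.
num_decodes, num_prefills = 2, 1              # two decode reqs, one prefill req
decode_query_lens, prefill_query_lens = [1, 1], [3]
num_decode_tokens = sum(decode_query_lens)    # equals num_decodes (1 token each)
num_prefill_tokens = sum(prefill_query_lens)
token_indices = list(range(num_decode_tokens + num_prefill_tokens))
assert token_indices[:num_decode_tokens] == [0, 1]     # decode slice
assert token_indices[num_decode_tokens:] == [2, 3, 4]  # prefill slice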
@@ -292,41 +319,48 @@ class TorchSDPAMetadata(AttentionMetadata): * Appropriate sequence-lengths tensor * Appropriate max sequence-length scalar * Appropriate block tables (or None) - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run - return (self.seq_lens_tensor, self.max_decode_seq_len, - self.block_tables) + return ( + self.decode_seq_lens_tensor, + self.decode_max_seq_len, + self.decode_block_tables, + ) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - self.cross_block_tables) + return ( + self.encoder_seq_lens_tensor, + self.max_encoder_seq_len, + self.cross_block_tables, + ) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - None) + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): + reorder_batch_threshold: int = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ) -> None: + super().__init__(kv_cache_spec, layer_names, vllm_config, device) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device) -> None: - self.kv_cache_spec = kv_cache_spec - self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config - - # For reorder - self.reorder_prompt_req_index_list = np.empty( - vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) - self.reorder_decode_req_index_list = np.empty( - vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) - self.num_prompt_req: int = 0 + self._init_reorder_batch_threshold(1, False) self.seq_start_loc_cpu = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -335,105 +369,70 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ) self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - prompt_list_idx = 0 - decode_list_idx = 0 - for req_index in range(input_batch.num_reqs): - if input_batch.num_computed_tokens_cpu[ - req_index] < input_batch.num_prompt_tokens[req_index]: - # prompt stage - self.reorder_prompt_req_index_list[prompt_list_idx] = req_index - prompt_list_idx += 1 - else: - # decode stage - self.reorder_decode_req_index_list[decode_list_idx] = req_index - decode_list_idx += 1 - assert decode_list_idx + prompt_list_idx == input_batch.num_reqs - - # Update prompt requests number - self.num_prompt_req = prompt_list_idx - - reorder_req_num = 0 - for req_index in range(decode_list_idx): - if self.reorder_decode_req_index_list[req_index] < prompt_list_idx: - reorder_req_num += 1 - else: - break - - if reorder_req_num == 0: - return False - - reorder_prompt_list = ( - self.reorder_prompt_req_index_list[:prompt_list_idx] - [-reorder_req_num:]) - reorder_decode_list = ( - self.reorder_decode_req_index_list[:decode_list_idx] - [:reorder_req_num]) - assert 
reorder_decode_list.size == reorder_prompt_list.size - - for idx in range(reorder_req_num): - prompt_req_index = reorder_prompt_list[idx].item() - decode_req_index = reorder_decode_list[idx].item() - input_batch.swap_states(prompt_req_index, decode_req_index) - - return True - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TorchSDPAMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TorchSDPAMetadata: num_reqs = common_attn_metadata.num_reqs max_query_len = common_attn_metadata.max_query_len seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_np = seq_lens_cpu.numpy() - num_prompt_req = self.num_prompt_req - max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item( - ) if num_prompt_req > 0 else 0 - max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item( - ) if num_prompt_req < num_reqs else 0 - self.seq_start_loc_np[0] = 0 - np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item()) - num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() - - num_prefill_tokens) + query_start_loc_np = query_start_loc_cpu.numpy() + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) + + max_prefill_seq_len = ( + seq_lens_np[num_decodes:num_reqs].max().item() if num_prefills > 0 else 0 + ) + max_decode_seq_len = ( + seq_lens_np[:num_decodes].max().item() if num_prefills < num_reqs else 0 + ) + self.seq_start_loc_np[0] = 0 + np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1 : num_reqs + 1]) slot_mapping = common_attn_metadata.slot_mapping.long() block_table_tensor = common_attn_metadata.block_table_tensor + query_start_loc_np = query_start_loc_cpu.numpy() + query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens attn_metadata = TorchSDPAMetadata( - num_prefills=num_prompt_req, + num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, # to ensure inference when chunked_prefill is disabled seq_lens=seq_lens_cpu.tolist(), - seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode - max_decode_seq_len=max_decode_seq_len, # decode - block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode + decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode + decode_max_seq_len=max_decode_seq_len, # decode + decode_block_tables=block_table_tensor[:num_decodes], # decode chunked_prefill=self.scheduler_config.chunked_prefill_enabled, max_query_len=max_query_len, - max_kv_len=max_prefill_seq_len, - prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req + - 1], # prefill - kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + - 1], # prefill - prefill_block_tables=block_table_tensor[: - num_prompt_req], # prefill - query_start_loc=query_start_loc_cpu[:num_reqs + - 1], # for logits index - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, + prefill_max_seq_len=max_prefill_seq_len, + prefill_query_start_loc=query_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + prefill_seq_start_loc=self.seq_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + 
prefill_block_tables=block_table_tensor[num_decodes:num_reqs], # prefill + query_start_loc=query_start_loc_cpu[: num_reqs + 1], # for logits index ) return attn_metadata class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): - def __init__( self, num_heads: int, @@ -450,8 +449,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if logits_soft_cap is not None: - logger.warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") + logger.warning_once( + "Torch SDPA does not support logits soft cap. " + "Outputs may be slightly off." + ) self.paged_attn_impl = _get_paged_attn_impl() self.num_heads = num_heads self.head_size = head_size @@ -464,13 +465,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): self.kv_cache_dtype = kv_cache_dtype self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.need_mask = ( + self.alibi_slopes is not None or self.sliding_window is not None + ) if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: raise NotImplementedError( "Torch SDPA backend FP8 KV cache requires " - "intel_extension_for_pytorch support.") + "intel_extension_for_pytorch support." + ) self.attn_type = attn_type def forward( @@ -491,7 +494,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache: shape = + [2, num_blocks, block_size * num_kv_heads * head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. @@ -501,22 +505,28 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for TorchSDPABackendImpl") + " for TorchSDPABackendImpl" + ) # For warming-up if attn_metadata is None: return query attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") + if attn_type == AttentionType.ENCODER and ( + not attn_metadata.is_all_encoder_attn_metadata_set + ): + raise AttributeError( + "Encoder attention requires setting encoder metadata attributes." + ) + elif attn_type == AttentionType.ENCODER_DECODER and ( + not attn_metadata.is_all_cross_attn_metadata_set + ): + raise AttributeError( + "Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes." + ) # Reshape the query, key, and value tensors. 
query = query.view(-1, self.num_heads, self.head_size) @@ -527,7 +537,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): else: assert value is None - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -536,7 +546,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): # we still need to break out key_cache and value_cache # i.e. for later use by paged attention key_cache, value_cache = self.paged_attn_impl.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) if (key is not None) and (value is not None): if attn_type == AttentionType.ENCODER_DECODER: @@ -549,8 +560,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): updated_slot_mapping = attn_metadata.slot_mapping self.paged_attn_impl.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) + key, + value, + key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. @@ -576,26 +594,24 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if prefill_meta := attn_metadata.prefill_metadata: if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore assert attn_metadata.seq_lens is not None - self._run_sdpa_forward(output, - query, - key, - value, - prefill_meta, - attn_type=attn_type) + self._run_sdpa_forward( + output, query, key, value, prefill_meta, attn_type=attn_type + ) else: # prefix-enabled attention assert not self.need_mask import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) ipex_modules.PagedAttention.flash_attn_varlen_func( - output[:prefill_meta.num_prefill_tokens, :, :], - query[:prefill_meta.num_prefill_tokens, :, :], + output[prefill_meta.num_decode_tokens :, :, :], + query[prefill_meta.num_decode_tokens :, :, :], key_cache, value_cache, prefill_meta.prefill_query_start_loc, - prefill_meta.kv_start_loc, + prefill_meta.prefill_seq_start_loc, prefill_meta.max_query_len, - prefill_meta.max_kv_len, + prefill_meta.prefill_max_seq_len, self.scale, True, prefill_meta.prefill_block_tables, @@ -604,7 +620,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") + "Encoder-only models should not have decode metadata." + ) # Decoding run. 
( seq_lens_arg, @@ -613,8 +630,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ) = decode_meta.get_seq_len_block_table_args(attn_type) self.paged_attn_impl.forward_decode( - output[attn_metadata.num_prefill_tokens:, :, :], - query[attn_metadata.num_prefill_tokens:, :, :], + output[: attn_metadata.num_decode_tokens, :, :], + query[: attn_metadata.num_decode_tokens, :, :], key_cache, value_cache, block_tables_arg, @@ -640,21 +657,19 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, attn_type: str = AttentionType.DECODER, ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: attn_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, # type: ignore + ) elif self.sliding_window is not None: assert attn_metadata.seq_lens is not None attn_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore + attn_metadata.seq_lens, self.sliding_window, query.dtype + ) else: seq_lens, _ = attn_metadata.get_seq_lens(attn_type) attn_masks = [None] * len(seq_lens) @@ -664,22 +679,30 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) - causal_attn = (attn_type == AttentionType.DECODER) + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) + value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) + + causal_attn = attn_type == AttentionType.DECODER seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) start_q, start_kv = 0, 0 - for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, - attn_masks): + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv - sub_out = scaled_dot_product_attention( - query[None, :, start_q:end_q, :], - key[None, :, start_kv:end_kv, :], - value[None, :, start_kv:end_kv, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=causal_attn and mask is None, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + sub_out = ( + scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv @@ -702,9 +725,11 @@ def _make_alibi_bias( num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + inf_mask = ( + torch.empty((1, seq_len, seq_len), dtype=bias.dtype) + .fill_(-torch.inf) + .triu_(diagonal=1) + ) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases @@ -733,7 +758,6 @@ def _make_sliding_window_bias( class _PagedAttention: - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] @@ -760,8 +784,7 @@ class 
_PagedAttention: num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @@ -835,19 +858,8 @@ class _PagedAttention: blocksparse_head_sliding_step, ) - @staticmethod - def copy_blocks( - kv_caches: list[torch.Tensor], - src_to_dists: torch.Tensor, - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - class _IPEXPagedAttention(_PagedAttention): - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: return True, [] @@ -880,8 +892,8 @@ class _IPEXPagedAttention(_PagedAttention): *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, - slot_mapping.flatten().int()) + key, value, key_cache, value_cache, slot_mapping.flatten().int() + ) @staticmethod def forward_decode( @@ -901,17 +913,30 @@ class _IPEXPagedAttention(_PagedAttention): *args, ) -> None: block_size = value_cache.shape[2] - head_mapping = torch.arange( - 0, - num_kv_heads, - device="cpu", - dtype=torch.int32, - ).view(num_kv_heads, - 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + head_mapping = ( + torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ) + .view(num_kv_heads, 1) + .repeat_interleave(query.size(1) // num_kv_heads) + .flatten() + ) ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query.contiguous(), key_cache, value_cache, head_mapping, - scale, block_tables, context_lens, block_size, max_context_len, - alibi_slopes) + output, + query.contiguous(), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) def _get_paged_attn_impl(): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 6e7096de924ca..a71e51471905a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -1,45 +1,54 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import numpy as np import torch -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version, - is_flash_attn_varlen_func_available) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, - get_scheduler_metadata, - reshape_and_cache_flash) + from vllm.attention.utils.fa_utils 
import ( + flash_attn_varlen_func, + get_scheduler_metadata, + reshape_and_cache_flash, + ) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) -# NOTE(woosuk): This is an arbitrary number. Tune it if needed. -_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 - class FlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True + supports_quant_query_input: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: @@ -49,6 +58,10 @@ class FlashAttentionBackend(AttentionBackend): def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -58,11 +71,12 @@ class FlashAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: - return "FLASH_ATTN_VLLM_V1" + return "FLASH_ATTN" @staticmethod def get_impl_cls() -> type["FlashAttentionImpl"]: @@ -82,6 +96,7 @@ class FlashAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") @@ -142,7 +157,8 @@ class FlashAttentionMetadata: def _get_sliding_window_configs( - vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + vllm_config: VllmConfig, +) -> set[Optional[tuple[int, int]]]: """Get the set of all sliding window configs used in the model.""" sliding_window_configs: set[Optional[tuple[int, int]]] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) @@ -152,8 +168,7 @@ def _get_sliding_window_configs( return sliding_window_configs -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]): # FA3: # Supports full cudagraphs for all cases. # @@ -167,46 +182,53 @@ class FlashAttentionMetadataBuilder( # work for mixed prefill-decode and uniform-decode. But for non-spec decodes # the graphs would not work for mixed prefill-decode; sorta the inverse # of UNIFORM_SINGLE_TOKEN_DECODE. - # Theres probably a better way to describe this using `AttentionCGSupport` + # There's probably a better way to describe this using `AttentionCGSupport` # but for now just set it to `UNIFORM_BATCH` to get use to drop down # to FULL_AND_PIECEWISE. 
# TODO(luka, lucas): audit FA2 as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support = AttentionCGSupport.ALWAYS \ - if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH + cudagraph_support = ( + AttentionCGSupport.ALWAYS + if get_flash_attn_version() == 3 + else AttentionCGSupport.UNIFORM_BATCH + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.kv_cache_dtype = kv_cache_spec.dtype self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = (get_flash_attn_version() == 3) + self.aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.max_cudagraph_size = self.compilation_config.max_capture_size if self.use_full_cuda_graph and self.aot_schedule: - self.max_cudagraph_size = self.compilation_config.max_capture_size - if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -216,18 +238,20 @@ class FlashAttentionMetadataBuilder( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: """ - fast_build disables AOT scheduling, used when there will be few + fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode """ num_reqs = common_attn_metadata.num_reqs @@ -251,8 +275,7 @@ class FlashAttentionMetadataBuilder( # build() call so the layers are constructed (cannot populate) # in __init__. 
if aot_schedule: - sliding_window_configs = _get_sliding_window_configs( - self.vllm_config) + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: @@ -261,12 +284,22 @@ class FlashAttentionMetadataBuilder( self.aot_schedule = False aot_schedule = False - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype) + cache_dtype + ) else: qkv_dtype = self.kv_cache_dtype if aot_schedule: @@ -283,48 +316,52 @@ class FlashAttentionMetadataBuilder( page_size=self.block_size, causal=causal, window_size=self.aot_sliding_window, - num_splits=self.max_num_splits, + num_splits=max_num_splits, ) return None use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, max_query_len=num_actual_tokens, seqlens=prefix_kv_lens, max_seq_len=common_prefix_len, - causal=False) - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - - common_prefix_len, - causal=True) + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None prefix_scheduler_metadata = None - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) # For FA3 + full cudagraph - max_num_splits = 0 if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata @@ -335,13 +372,6 @@ class FlashAttentionMetadataBuilder( self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - if num_actual_tokens <= self.max_cudagraph_size: 
- # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits - attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, @@ -358,7 +388,8 @@ class FlashAttentionMetadataBuilder( suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + ) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -366,7 +397,6 @@ class FlashAttentionMetadataBuilder( class FlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -407,18 +437,20 @@ class FlashAttentionImpl(AttentionImpl): self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() - if is_quantized_kv_cache(self.kv_cache_dtype) \ - and not flash_attn_supports_fp8(): + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device.") + "FlashAttention does not support fp8 kv-cache on this device." + ) self.sinks = sinks if self.sinks is not None: assert self.vllm_flash_attn_version == 3, ( - "Sinks are only supported in FlashAttention 3") + "Sinks are only supported in FlashAttention 3" + ) assert self.sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " - "heads in the layer") + "heads in the layer" + ) def forward( self, @@ -438,7 +470,8 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -450,8 +483,8 @@ class FlashAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -474,11 +507,14 @@ class FlashAttentionImpl(AttentionImpl): if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention(query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, layer) + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) @@ -486,8 +522,11 @@ class FlashAttentionImpl(AttentionImpl): # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. 
- if (self.kv_sharing_target_layer_name is None and key is not None - and value is not None): + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. # NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -507,16 +546,12 @@ class FlashAttentionImpl(AttentionImpl): ) if self.kv_cache_dtype.startswith("fp8"): + # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype) + self.kv_cache_dtype + ) key_cache = key_cache.view(dtype) value_cache = value_cache.view(dtype) - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) if not attn_metadata.use_cascade: cu_seqlens_q = attn_metadata.query_start_loc @@ -577,6 +612,7 @@ class FlashAttentionImpl(AttentionImpl): q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, + s_aux=self.sinks, ) return output @@ -602,7 +638,8 @@ class FlashAttentionImpl(AttentionImpl): # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError( - "quantization is not supported for encoder attention") + "quantization is not supported for encoder attention" + ) # Use encoder-specific metadata for sequence information cu_seqlens_q = attn_metadata.query_start_loc @@ -612,7 +649,8 @@ class FlashAttentionImpl(AttentionImpl): descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads) + self.num_kv_heads, + ) # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( @@ -675,8 +713,12 @@ def use_cascade_attention( num_queries_per_kv = num_query_heads // num_kv_heads # The criteria for using FlashDecoding can be found in the following link: # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) if not use_flash_decoding: # Use cascade attention. return True @@ -698,8 +740,9 @@ def use_cascade_attention( cascade_waves = cdiv(cascade_ctas, num_sms) cascade_time = cascade_waves * num_prefix_tiles - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) flash_decoding_ctas *= num_prefix_tiles flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) @@ -730,11 +773,13 @@ def cascade_attention( q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, + s_aux: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window.") + "Cascade attention does not support sliding window." 
+ ) num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -760,12 +805,12 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + # s_aux is incorporated into prefix_lse inside the GPU kernel, + # enabling its effect during the final attention merge. + s_aux=s_aux, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -787,14 +832,10 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, ) # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 50819bb2bb943..15fd48ca54aa1 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,41 +1,54 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashInfer.""" + from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar, Union +import numpy as np import torch -from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - MultiLevelCascadeAttentionWrapper) +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, +) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, +) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (supports_trtllm_attention, - use_trtllm_attention) -from 
vllm.v1.attention.backends.flash_attn import use_cascade_attention -# yapf conflicts with isort for this block -# yapf: disable -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) -# yapf: enable +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + use_trtllm_attention, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 @@ -45,9 +58,103 @@ FP4_DTYPE = torch.uint8 logger = init_logger(__name__) +trtllm_gen_workspace_buffer = None + + +def _get_trtllm_gen_workspace_buffer(): + global trtllm_gen_workspace_buffer + if trtllm_gen_workspace_buffer is None: + trtllm_gen_workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) + return trtllm_gen_workspace_buffer + + +@triton.jit +def _trtllm_prefill_attn_kvfp8_dequant( + kv_cache_ptr, + block_tables_prefill_ptr, + block_table_stride, + mock_kv_cache_ptr, + k_scale_ptr, + v_scale_ptr, + K_CACHE_STRIDE: tl.constexpr, + KV_CACHE_STRIDE: tl.constexpr, +): + batch_idx = tl.program_id(0).to(tl.int64) + mock_block_table_idx = tl.program_id(1).to(tl.int64) + orig_page_num = tl.load( + block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx + ).to(tl.int64) + if orig_page_num <= 0: + return + dequant_dtype = mock_kv_cache_ptr.dtype.element_ty + + # Dequantize K + k_scale_val = tl.load(k_scale_ptr) + offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val + mock_cache_offset = ( + batch_idx * block_table_stride + mock_block_table_idx + 1 + ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + # Dequantize V + v_scale_val = tl.load(v_scale_ptr) + offset = ( + orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + ) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val + mock_cache_offset = ( + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE + + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE) + ) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + +def trtllm_prefill_attn_kvfp8_dequant( + kv_cache: torch.Tensor, + block_tables_prefill: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + dequant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_of_page_per_token = block_tables_prefill.shape + s = kv_cache.shape + assert s[1] == 2 + assert dequant_dtype in (torch.bfloat16, torch.float16) + k_cache_stride = s[2] * s[3] * s[4] + kv_cache_stride = k_cache_stride * s[1] + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) + # mock kv cache contains just the pages needed by this prefill + mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) + # we simply 
sequentially index the pages needed by this prefill + mock_block_table = torch.arange( + start=1, + end=batch_size * num_of_page_per_token + 1, + dtype=torch.int32, + device=block_tables_prefill.device, + ).reshape(batch_size, num_of_page_per_token) + grid = (batch_size, num_of_page_per_token) + _trtllm_prefill_attn_kvfp8_dequant[grid]( + kv_cache, + block_tables_prefill, + num_of_page_per_token, + mock_kv_cache, + k_scale, + v_scale, + k_cache_stride, + kv_cache_stride, + ) + return mock_kv_cache, mock_block_table + class FlashInferBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -59,6 +166,13 @@ class FlashInferBackend(AttentionBackend): # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 + return [16, 32, 64] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -68,11 +182,12 @@ class FlashInferBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: - return "FLASHINFER_VLLM_V1" + return "FLASHINFER" @staticmethod def get_impl_cls() -> type[FlashInferImpl]: @@ -92,6 +207,7 @@ class FlashInferBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) @@ -120,36 +236,16 @@ class FlashInferBackend(AttentionBackend): @dataclass class FlashInferMetadata: - num_actual_tokens: int # Number of tokens excluding padding. - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - qo_indptr_cpu: torch.Tensor - # An example for paged_kv_indices, paged_kv_indptr: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) - paged_kv_indptr_cpu: torch.Tensor - # The page indices of the paged kv cache (on device for plan) - paged_kv_indices: torch.Tensor - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] (CPU for plan) - paged_kv_last_page_len_cpu: torch.Tensor # The data type of the query q_data_type: torch.dtype - seq_lens_cpu: torch.Tensor slot_mapping: torch.Tensor # For flashinfer trtllm batch decode max_q_len: int + max_q_len_prefill: int max_seq_len: int seq_lens: torch.Tensor block_table_tensor: torch.Tensor @@ -164,138 +260,162 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). 
use_cascade: bool - shared_qo_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None - shared_kv_page_indices_cpu: Optional[torch.Tensor] = None - shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None + decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None - qo_indptr_gpu: Optional[torch.Tensor] = None - paged_kv_indptr_gpu: Optional[torch.Tensor] = None + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - self.device = device - self.vllm_config = vllm_config + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config - self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(self.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL + speculative_config = vllm_config.speculative_config + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) + self.enable_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + ) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
self._decode_wrappers_cudagraph: dict[ - int, BatchDecodeWithPagedKVCacheWrapper] = {} + int, BatchDecodeWithPagedKVCacheWrapper + ] = {} self._decode_cudagraph_max_bs = min( - max_num_reqs, self.compilation_config.max_capture_size) + (1 + num_spec_tokens) * max_num_reqs, + self.compilation_config.max_capture_size, + ) self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) + self.vllm_config.parallel_config + ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size - self.enable_fusion = ( - self.compilation_config.pass_config.enable_attn_fusion) - self.q_data_type = self.model_config.dtype self.cache_dtype = self.cache_config.cache_dtype if self.cache_dtype.startswith("fp8"): - self.kv_cache_dtype = ( - FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.cache_dtype)) - # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled - if self.enable_fusion: - self.q_data_type = self.kv_cache_dtype + self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype + ) else: + assert self.kv_cache_spec.dtype == self.model_config.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype + # Use model dtype as q dtype when TRTLLM attn is not supported, or + # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to + # use fp8 q if kv cache is fp8, and will fall back to model dtype + # if TRTLLM attention kernel is not used when building attn metadata + can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + if can_use_trtllm and not flashinfer_disable_q_quantization(): + self.q_data_type = self.kv_cache_dtype + else: + self.q_data_type = self.model_config.dtype + + self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm) + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) - + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl) + ) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks + if self.has_sinks and not can_use_trtllm: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention on " + "earlier GPUs." 
+ ) # Preparing persistent buffers (device-side) - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=self.device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) self.paged_kv_indices = torch.zeros( max_num_pages, # max num pages possible dtype=torch.int32, - device=self.device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=self.device) + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) # host-side buffer pin_memory = is_pin_memory_available() - self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_indices_cpu = torch.zeros(max_num_pages, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - - self.block_table_arange = torch.arange(max_num_pages_per_req, - dtype=torch.int32, - device=self.device) + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() + self.paged_kv_indptr_buffer = torch.zeros_like( + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device + ) return self._workspace_buffer def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout()) + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper - def _get_decode_wrapper(self, - batch_size: int, - use_cudagraph: bool = False): + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): if use_cudagraph: - decode_wrapper = self._decode_wrappers_cudagraph.get( - batch_size, None) + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) else: decode_wrapper = self._decode_wrapper if decode_wrapper is None: if use_cudagraph: - paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[: - batch_size] + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -308,7 +428,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indices_buffer=paged_kv_indices, paged_kv_last_page_len_buffer=paged_kv_last_page_len, # Tensor cores are enabled by default because the perf would be - # atleast as good as cuda cores for all attention ops in latest + # at least as good as cuda cores for all attention ops in latest # gpus. 
use_tensor_cores=True, ) @@ -324,154 +444,35 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def _get_cascade_wrapper(self): if self._cascade_wrapper is None: self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), get_kv_cache_layout()) + 2, self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): - if attn_metadata.use_cascade: - attn_metadata.cascade_wrapper = self._get_cascade_wrapper() - attn_metadata.cascade_wrapper.plan( - [ - attn_metadata.shared_qo_indptr_cpu, - attn_metadata.qo_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indptr_cpu, - attn_metadata.paged_kv_indptr_cpu - ], - [ - attn_metadata.shared_kv_page_indices_cpu, - attn_metadata.paged_kv_indices - ], - [ - attn_metadata.shared_kv_last_page_len_cpu, - attn_metadata.paged_kv_last_page_len_cpu - ], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - ) - else: - # Regular attention (common case). - # Decodes are at the front and prefills are at the back, - # according to reorder_batch() - num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decodes - if num_prefills > 0: - # Decodes are first so prefills start after the last decode - prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len_cpu[ - prefill_start:].shape[0] == num_prefills - # Since prefill_wrapper.run() will be called with - # query[num_decode_tokens:] we need to adjust the qo_indptr - # to be relative to the start of the prefill queries. - qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ - prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] - paged_kv_indptr_cpu = attn_metadata.paged_kv_indptr_cpu[ - prefill_start:] - if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - attn_metadata.paged_kv_indices, - attn_metadata. - paged_kv_last_page_len_cpu[prefill_start:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - ) - else: - attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) - attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device) - - if num_decodes > 0: - pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) - if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decodes)) - # Carefully fulfill the padding region with reasonable value - # on cpu. - # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - attn_metadata. 
- paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) - - else: - num_input_tokens = num_decodes - - attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) - if not attn_metadata.decode_use_trtllm: - # Use the persistent buffer with padding length, - # instead of the same address but chunked version - # in atten_metadata when using cudagraph. - fast_plan_decode( - attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], - attn_metadata.paged_kv_indices, - self.paged_kv_last_page_len_cpu[:num_input_tokens], - attn_metadata.seq_lens_cpu[:num_input_tokens], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - ) - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashInferMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size use_cascade = common_prefix_len > 0 if use_cascade: @@ -480,83 +481,117 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_common_kv_blocks = common_prefix_len // page_size # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device='cpu') - shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device='cpu') - shared_kv_page_indices_cpu = block_table_tensor[ - 0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor([page_size], - dtype=torch.int32, - device='cpu') + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds_cpu -= num_common_kv_blocks + num_blocks_np -= num_common_kv_blocks else: shared_qo_indptr_cpu = None shared_kv_page_indptr_cpu = None shared_kv_page_indices_cpu = None shared_kv_last_page_len_cpu = None - max_num_blocks = block_table_bounds_cpu.max().item() - block_table_bounds = block_table_bounds_cpu.to(self.device, - non_blocking=True) - mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) - < block_table_bounds.unsqueeze(1)) - # write self.paged_kv_indices inplace - num_actual_pages = torch.sum(mask) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, - out=paged_kv_indices) - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - torch.cumsum(block_table_bounds_cpu, - dim=0, - dtype=torch.int32, - out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1 : num_reqs + 1], + ) + # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified + # after this line (e.g., for cuda graphs), we need to copy the data to + # self.paged_kv_indptr_buffer to avoid race condition. + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) + + # write self.paged_kv_indices inplace + num_actual_pages = self.paged_kv_indptr_np[num_reqs] + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + _copy_page_indices_kernel[(num_reqs,)]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) - paged_kv_last_page_len_cpu = seq_lens_cpu % page_size # write self.paged_kv_last_page_len_cpu inplace - torch.where(paged_kv_last_page_len_cpu == 0, - torch.tensor(page_size), - paged_kv_last_page_len_cpu, - out=self.paged_kv_last_page_len_cpu[:num_reqs]) + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) - # Check if any layer uses sinks (requires TRTLLM attention) - has_sinks = self.global_hyperparameters.has_sinks + uses_spec_reorder = self.reorder_batch_threshold > 1 + prefill_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) + decode_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) - prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_prefill_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=True, - has_sinks=has_sinks) - decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_decode_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=False, - has_sinks=has_sinks) + if not (prefill_use_trtllm and decode_use_trtllm): + if self.has_sinks: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention " + "on 
earlier GPUs." + ) + + if not self.global_hyperparameters.has_same_window_lefts: + raise ValueError( + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) + + assert self.global_hyperparameters.has_same_all_params, ( + "FlashInfer backend currently only supports models in which " + "all layers share the same values for the following " + "hyperparameters: `window_left`, `logits_soft_cap`, " + "`sm_scale`." + ) + + # The q quantization is not supported for non-trtllm attention, + # fall back to model dtype. + self.q_data_type = self.model_config.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, - paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs], - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len_cpu=self. - paged_kv_last_page_len_cpu[:num_reqs], q_data_type=self.q_data_type, - seq_lens_cpu=seq_lens_cpu, slot_mapping=common_attn_metadata.slot_mapping, max_q_len=max_q_len, + max_q_len_prefill=max_q_len, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table_tensor=block_table_tensor, @@ -567,55 +602,163 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - shared_qo_indptr_cpu=shared_qo_indptr_cpu, - shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, - shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, - shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, ) - self._plan(attn_metadata) + qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] + paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [shared_qo_indptr_cpu, qo_indptr_cpu], + [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], + [shared_kv_page_indices_cpu, paged_kv_indices], + [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + # Regular attention (common case). + # Decodes are at the front and prefills are at the back. + num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decodes + if num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) + paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] + + # Recompute max_q_len for the slice of requests we are using + # for prefills. 
This can be different from max_q_len when + # we have a non-uniform batch with some short decodes offloaded + # to the prefill pathway + query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) + + if not attn_metadata.prefill_use_trtllm: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( + self.device, non_blocking=True + ) + attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( + self.device, non_blocking=True + ) + + if num_decodes > 0: + pure_decode = num_prefills == 0 + # possible required padding for cudagraph replay + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decode_tokens <= self._decode_cudagraph_max_bs + ) + if use_cudagraph: + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_decode_tokens + ) + # Carefully fulfill the padding region with reasonable value + # on cpu. + # Make sure paged_kv_indptr_cpu is not decreasing + self.paged_kv_indptr_cpu[ + 1 + num_decodes : 1 + num_input_tokens + ].fill_(paged_kv_indptr_cpu[-1]) + # Fill the remaining paged_kv_last_page_len_cpu with 1. + # This is because flashinfer treats 0 as a full page + # instead of empty. + self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( + 1 + ) + + else: + num_input_tokens = num_decode_tokens + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens, use_cudagraph + ) + if not attn_metadata.decode_use_trtllm: + # Use the persistent buffer with padding length, + # instead of the same address but chunked version + # in atten_metadata when using cudagraph. + fast_plan_decode( + attn_metadata.decode_wrapper, + self.paged_kv_indptr_cpu[: num_input_tokens + 1], + paged_kv_indices, + self.paged_kv_last_page_len_cpu[:num_input_tokens], + seq_lens_cpu[:num_input_tokens], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) return attn_metadata - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with FlashInfer. - """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "FlashInfer only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. 
return False - return use_cascade_attention(*args, **kwargs) + # TODO: Cascade attention doesn't work, disable it for now + # return use_cascade_attention(*args, **kwargs) + return False class FlashInferImpl(AttentionImpl): - def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -628,8 +771,9 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) - self.window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -637,30 +781,34 @@ class FlashInferImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl" + ) - self.sinks: Optional[torch.Tensor] = None + self.sinks: torch.Tensor | None = None if sinks is not None: if sinks.shape[0] != num_heads: raise ValueError( "Sinks must have the same number of heads as the number of " f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}.") + f"{sinks.shape[0]}." + ) self.sinks = sinks - self.support_trtllm_attn = (supports_trtllm_attention() - and num_heads % num_kv_heads == 0) - self.bmm1_scale: Optional[float] = None - self.bmm2_scale: Optional[float] = None - self.o_sf_scale: Optional[float] = None + self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None def fused_output_quant_supported(self, quant_key: QuantKey): - return (self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) def forward( self, @@ -670,9 +818,9 @@ class FlashInferImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashInfer. 
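For reference, the paged-KV metadata assembled in the builder above (paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len) keeps the layout the removed FlashInferMetadata docstring described: indices is the concatenation of every request's page list, indptr is the running total used to slice it, and last_page_len records how full each request's final page is. A minimal NumPy sketch with made-up page size, sequence lengths and per-request page lists (toy values, not the builder's actual tensors):

import numpy as np

page_size = 16
seq_lens = np.array([40, 32, 7], dtype=np.int32)       # toy sequence lengths
pages_per_req = [[0, 5, 8], [1, 6], [3]]               # toy block-table rows

num_blocks = (seq_lens + page_size - 1) // page_size   # [3, 2, 1]
paged_kv_indptr = np.concatenate(([0], np.cumsum(num_blocks, dtype=np.int32)))
# [0, 3, 5, 6]; request i owns paged_kv_indices[indptr[i]:indptr[i + 1]]
paged_kv_indices = np.array(
    [blk for req, n in zip(pages_per_req, num_blocks) for blk in req[:n]],
    dtype=np.int32,
)                                                       # [0, 5, 8, 1, 6, 3]
last = seq_lens % page_size
paged_kv_last_page_len = np.where(last == 0, page_size, last)  # [8, 16, 7]

The 0 -> page_size substitution matches the cudagraph-padding comment above: FlashInfer treats a last-page length of 0 as a full page, not an empty one.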
@@ -680,11 +828,9 @@ class FlashInferImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape - - # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - - + kv_cache: KV cache tensor with different possible shapes: + - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -696,34 +842,36 @@ class FlashInferImpl(AttentionImpl): return output if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float # The attn+quant fusion happens when output_scale is provided. if output_scale is None: - assert attn_metadata.q_data_type != FP8_DTYPE, \ - "Query can only be FP8 if output fusion happened." - assert output_block_scale is None, "output_block_scale "\ - "is not supported when fusion has not happened" + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) else: - assert attn_metadata.q_data_type == FP8_DTYPE, \ + assert attn_metadata.q_data_type == FP8_DTYPE, ( "Query must be FP8 when attn+quant fusion happened." - assert (attn_metadata.prefill_use_trtllm and - attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" if output.dtype == FP8_DTYPE: - assert output_block_scale is None, \ + assert output_block_scale is None, ( "output_block_scale should not be provided for fp8 output" + ) elif output.dtype == FP4_DTYPE: - assert output_block_scale is not None, \ + assert output_block_scale is not None, ( "output_block_scale is required for nvfp4 output" + ) else: raise ValueError(f"Unsupported output dtype: {output.dtype}") - # TRTLLM attn kernel requires o scale to pass as a host scalar, + # TRTLLM attn kernel requires to scale to pass as a host scalar, # store the o scale as a host scalar in warmup run with cuda graph # not enabled if layer._o_scale_float is None: @@ -733,12 +881,13 @@ class FlashInferImpl(AttentionImpl): elif output.dtype == FP4_DTYPE: self.o_sf_scale = layer._o_scale_float - # Insert FP8 quant for query + # Insert FP8 quant for query + if attn_metadata.q_data_type == FP8_DTYPE: num_tokens, num_heads, head_size = query.shape query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) # IMPORTANT! 
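To make the scale folding in the hunk above concrete: with per-tensor FP8 quantization, q_fp8 is roughly q / q_scale and k_fp8 is roughly k / k_scale, so the softmax logits q @ k^T * sm_scale can be recovered from the quantized tensors with the single combined factor that bmm1_scale = q_scale * k_scale * sm_scale expresses (bmm2_scale = v_scale plays the same role on the value side). A standalone toy sketch, assuming PyTorch >= 2.1 for the float8_e4m3fn dtype and deliberately using plain tensor ops instead of vLLM's scaled_fp8_quant or the TRT-LLM kernels:

import torch

torch.manual_seed(0)
q = torch.randn(4, 64)
k = torch.randn(4, 64)
sm_scale = 64 ** -0.5

# Per-tensor scales chosen so the quantized values span the e4m3 range (~448).
q_scale = q.abs().max() / 448.0
k_scale = k.abs().max() / 448.0
q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)

bmm1_scale = float(q_scale) * float(k_scale) * sm_scale
logits_ref = (q @ k.T) * sm_scale
logits_fp8 = (q_fp8.float() @ k_fp8.float().T) * bmm1_scale
# Difference is on the order of FP8 rounding error.
print((logits_ref - logits_fp8).abs().max())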
@@ -775,7 +924,8 @@ class FlashInferImpl(AttentionImpl): # to process the cache when the kv_cache_dtype is fp8 if self.kv_cache_dtype.startswith("fp8"): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype) + self.kv_cache_dtype + ) kv_cache = kv_cache.view(torch_dtype) # Inputs and outputs may be padded for CUDA graphs @@ -789,14 +939,16 @@ class FlashInferImpl(AttentionImpl): output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output + # When using spec decoding, num_decodes can be < num_decode_tokens + # because some decode requests may have more than one query token. + num_decodes = attn_metadata.num_decodes num_decode_tokens = attn_metadata.num_decode_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens stride_order = FlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) # Regular attention (common case). - # Decodes are at the front and prefills are at the back, - # according to reorder_batch() + # Decodes are at the front and prefills are at the back. if num_prefill_tokens > 0: prefill_wrapper = attn_metadata.prefill_wrapper prefill_query = query[num_decode_tokens:] @@ -806,8 +958,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, @@ -819,10 +970,9 @@ class FlashInferImpl(AttentionImpl): else: # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() - workspace_buffer = prefill_wrapper._float_workspace_buffer - block_tables_prefill = attn_metadata.block_table_tensor[ - num_decode_tokens:] - seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] + workspace_buffer = _get_trtllm_gen_workspace_buffer() + block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] + seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" @@ -834,21 +984,42 @@ class FlashInferImpl(AttentionImpl): if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[num_decode_tokens:], - scale=output_block_scale, - scale_start_index=num_decode_tokens, - original_shape=prefill_query.shape) + out = FP4Tensor( + data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape, + ) else: assert self.o_sf_scale is None out = output[num_decode_tokens:] + if ( + attn_metadata.q_data_type != FP8_DTYPE + and self.kv_cache_dtype.startswith("fp8") + ): + # TRTLLM prefill attention does not support BF16 Q + # and fp8 kv cache. 
So to enable prefill attention + # with fp8 kv cache, we can construct a mock block + # and mock kv cache with BF16 KV involved in the prefill + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + ) + else: + mock_kv_cache = kv_cache_permute + mock_block_table = block_tables_prefill + trtllm_batch_context_with_kv_cache( query=prefill_query, - kv_cache=kv_cache_permute, + kv_cache=mock_kv_cache, workspace_buffer=workspace_buffer, - block_tables=block_tables_prefill, + block_tables=mock_block_table, seq_lens=seq_lens_prefill, - max_q_len=attn_metadata.max_q_len, + max_q_len=attn_metadata.max_q_len_prefill, max_kv_len=attn_metadata.max_seq_len, bmm1_scale=self.bmm1_scale, bmm2_scale=self.bmm2_scale, @@ -869,8 +1040,7 @@ class FlashInferImpl(AttentionImpl): if not attn_metadata.decode_use_trtllm: assert decode_wrapper._window_left == self.window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale decode_wrapper.run( decode_query, @@ -882,9 +1052,10 @@ class FlashInferImpl(AttentionImpl): else: # decode_query may be non-contiguous decode_query = decode_query.contiguous() - workspace_buffer = decode_wrapper._float_workspace_buffer - block_tables_decode = attn_metadata.\ - block_table_tensor[:num_decode_tokens] + workspace_buffer = _get_trtllm_gen_workspace_buffer() + block_tables_decode = attn_metadata.block_table_tensor[ + :num_decode_tokens + ] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -897,14 +1068,23 @@ class FlashInferImpl(AttentionImpl): if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[:num_decode_tokens], - scale=output_block_scale, - scale_start_index=0, - original_shape=decode_query.shape) + out = FP4Tensor( + data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape, + ) else: assert self.o_sf_scale is None out = output[:num_decode_tokens] + if num_decode_tokens % attn_metadata.num_decodes != 0: + # This gets triggered when the dummy_run forces + # attention to be initialized with q_len = 0 + q_len_per_req = 1 + else: + q_len_per_req = num_decode_tokens // attn_metadata.num_decodes + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -918,6 +1098,7 @@ class FlashInferImpl(AttentionImpl): sinks=self.sinks, o_sf_scale=self.o_sf_scale, out=out, + q_len_per_req=q_len_per_req, ) return output_padded @@ -934,13 +1115,13 @@ def fast_plan_decode( page_size: int, pos_encoding_mode: str = "NONE", window_left: int = -1, - logits_soft_cap: Optional[float] = None, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, + logits_soft_cap: float | None = None, + q_data_type: Union[str, torch.dtype] | None = "float16", + kv_data_type: Union[str, torch.dtype] | None = None, + data_type: Union[str, torch.dtype] | None = None, + sm_scale: float | None = None, + rope_scale: float | None = None, + rope_theta: float | None = None, non_blocking: bool = True, ) -> None: """ @@ -959,8 
+1140,7 @@ def fast_plan_decode( # Warm up with the original plan if it is first call, and always run the # original plan if we run for dynamic shape. For fixed shape (cudagraph), # this warm up is to generate the _cached_module for the decode wrapper. - if not self.is_cuda_graph_enabled or \ - getattr(self, "vllm_first_call", True): + if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): self.plan( indptr_cpu, indices, @@ -1000,31 +1180,33 @@ def fast_plan_decode( if kv_data_type is None: kv_data_type = q_data_type - q_data_type = getattr(torch, q_data_type) if isinstance( - q_data_type, str) else q_data_type - kv_data_type = getattr(torch, kv_data_type) if isinstance( - kv_data_type, str) else kv_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size)) + "initialization {}".format(batch_size, self._fixed_batch_size) + ) if len(indices) > len(self._paged_kv_indices_buf): raise ValueError( - "The size of indices should be less than or equal to the " - "allocated buffer") + "The size of indices should be less than or equal to the allocated buffer" + ) # host-to-device copy for the indptr buffer self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, - non_blocking=True) + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") try: - # Make sure we pass exactly 15 arguments for tensor core version + # Make sure we pass exactly 18 arguments for tensor core version self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1041,6 +1223,9 @@ def fast_plan_decode( head_dim, head_dim, False, # causal + window_left, + -1, # fixed_split_size + False, # disable_split_kv ) except Exception as e: raise RuntimeError(f"Error in tensor core plan: {e}") from e @@ -1051,3 +1236,27 @@ def fast_plan_decode( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index f4aa54660a078..7775445ae773e 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,31 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict 
+"""Attention layer with FlexAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - _score_mod_signature, - create_block_mask, - flex_attention) +import torch._dynamo.decorators +import torch.nn.functional as F +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + _score_mod_signature, + and_masks, + create_block_mask, + flex_attention, +) -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) +from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) -create_block_mask_compiled = torch.compile(create_block_mask, - fullgraph=True, - mode="reduce-overhead") +create_block_mask_compiled = torch.compile( + create_block_mask, fullgraph=True, mode="reduce-overhead" +) flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) @@ -33,7 +48,25 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts) + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) + + +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i == dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) class FlexAttentionBackend(AttentionBackend): @@ -65,6 +98,7 @@ class FlexAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -79,8 +113,11 @@ class FlexAttentionBackend(AttentionBackend): # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( - block_table: torch.Tensor, - total_blocks: Optional[int] = None) -> torch.Tensor: + block_table: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + total_blocks: int, +) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -114,35 +151,121 @@ def physical_to_logical_mapping( If a physical block is not mapped to by any logical block, its value in the result will be -1. + IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). 
For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. Args: block_table: Tensor of shape [max_reqs, max_num_blocks] - mapping logical blocks to physical locations + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available Returns: - A tensor of shape [max_reqs, max_physical_block] + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. """ max_reqs, max_num_blocks = block_table.shape device = block_table.device - physical_to_logical = torch.full((max_reqs, total_blocks), - -1, - dtype=torch.long, - device=device) + physical_to_logical = torch.full( + (max_reqs, total_blocks), -1, dtype=torch.long, device=device + ) - logical_indices = (torch.arange(max_num_blocks, - device=device).unsqueeze(0).expand( - max_reqs, -1)) + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = ( + torch.arange(max_num_blocks, device=device)[None, :] + < num_blocks_per_seq[:, None] + ) - physical_to_logical.scatter_(-1, block_table.to(torch.int64), - logical_indices) - # TODO Confirm - Seems like block 0 is always empty so we reset it manually + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, torch.arange(max_num_blocks, device=device)[None, :], 0 + ) + + physical_to_logical.scatter_( + -1, valid_block_table.to(torch.int64), valid_logical_indices + ) + # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical -def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, - kv_idx: torch.Tensor): +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory. 
+ + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # “∞” + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + +def causal_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +): return q_idx >= kv_idx @@ -170,6 +293,7 @@ class FlexAttentionMetadata: num_reqs: int physical_to_logical: torch.Tensor decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -179,6 +303,46 @@ class FlexAttentionMetadata: block_mask: Optional[BlockMask] = None score_mod: Optional[_score_mod_signature] = None logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: Optional[torch.Tensor] = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: Optional[_score_mod_signature] = None + sliding_window: Optional[int] = None + + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? 
+ + Returns: + tuple of (is_valid, logical_q_idx, logical_kv_idx) + """ + # Map query indices to corresponding request indices + q_req = request_lookup[q_idx] + + # Convert physical KV indices to logical indices + physical_kv_block = physical_kv_idx // self.block_size + physical_kv_offset = physical_kv_idx % self.block_size + logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] + logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset + + # Determine valid kv indices + live_block = logical_block_idx >= 0 + within_upper_bound = logical_kv_idx < self.seq_lens[q_req] + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + # Convert physical query indices to logical indices + local_q_idx = q_idx - self.query_start_loc[q_req] + logical_q_idx = local_q_idx + self.decode_offset[q_req] + + return is_valid, logical_q_idx, logical_kv_idx def get_causal_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. @@ -191,11 +355,8 @@ class FlexAttentionMetadata: With this info we create the "logical" indices that are passed to mask_mod functions. This allows mask mod functions to be agnostic to layout of the query and key/value tensors. - - TODO is_within_lower_bound: do sequences start on block_boundaries? """ - # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + assert self.doc_ids is not None def final_mask_mod( b: torch.Tensor, @@ -203,27 +364,9 @@ class FlexAttentionMetadata: q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - # Map query indices to corresponding request indices - q_req = request_lookup[q_idx] - - # Convert physical KV indices to logical indices - physical_kv_block = physical_kv_idx // self.block_size - physical_kv_offset = physical_kv_idx % self.block_size - logical_block_idx = self.physical_to_logical[q_req, - physical_kv_block] - logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501 - - # Determine valid kv indices - live_block = logical_block_idx >= 0 - within_upper_bound = logical_kv_idx < self.seq_lens[q_req] - within_lower_bound = logical_kv_idx >= 0 - - is_valid = live_block & within_upper_bound & within_lower_bound - - # Convert physical query indices to logical indices - local_q_idx = q_idx - self.query_start_loc[q_req] - logical_q_idx = local_q_idx + self.decode_offset[q_req] - + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -236,7 +379,7 @@ class FlexAttentionMetadata: def get_bidirectional_mask_mod(self) -> _mask_mod_signature: """Creates the encoder mask_mod function for FlexAttention. - Since the encoder bidirectional attention doesn't run with + Since the encoder bidirectional attention doesn't run with KV cache, this function creates a mask based on the packed query sequences. """ @@ -253,13 +396,152 @@ class FlexAttentionMetadata: return final_mask_mod - def build_block_mask(self) -> BlockMask: + def get_sliding_window_mask_mod(self) -> _mask_mod_signature: + """Creates the sliding window mask_mod function for FlexAttention. + + Note that the sliding window mask here is bidirectional, we need + to mask it with the bidirectional/causal mask for encoder/decoder. 
+ """ + + if self.sliding_window is None: + raise ValueError("sliding_window must be set for sliding window attention") + + def sliding_window_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): + return torch.abs(q_idx - kv_idx) < self.sliding_window + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) + return torch.where( + is_valid, + sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod if self.causal else sliding_window_mask_mod + + def get_mask_mod(self): + # Stage-1: initialize the base mask_mod + # (causal mask for decoder or bidirectional mask for encoder) if self.causal: mask_mod = self.get_causal_mask_mod() - kv_len = self.total_cache_tokens else: mask_mod = self.get_bidirectional_mask_mod() - kv_len = self.num_actual_tokens + # stage-2: add external mask_mod for special attention during + # forwarding runtime to create the combined mask_mod. + if self.sliding_window is not None: + # Add sliding window mask for sliding window attention + sliding_window_mask_mod = self.get_sliding_window_mask_mod() + mask_mod = and_masks(mask_mod, sliding_window_mask_mod) + return mask_mod + + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx + ) + ) + + return torch.where( + is_valid, + user_score_mod( + score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx + ), + -float("inf"), + ) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. For each query token, fetch blocks from block_table using max_seq_len + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. 
+ + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). Please check your model's " + f"configuration." + ) + + used_pages = self.block_table[ + self.doc_ids, : cdiv(self.max_seq_len, self.block_size) + ] + used_pages_padded = pad_to_multiple( + used_pages, multiple=self.q_block_size, dim=0 + ) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1 + ) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted( + (used_pages_padded.long()), M=self.num_blocks + ).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + + def build_block_mask(self) -> BlockMask: + mask_mod = self.get_mask_mod() + kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens return create_block_mask_compiled( mask_mod, None, @@ -267,6 +549,7 @@ class FlexAttentionMetadata: self.num_actual_tokens, kv_len, device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), ) def __post_init__(self): @@ -275,32 +558,50 @@ class FlexAttentionMetadata: assert self.cu_prefix_query_lens is None, "Not implemented yet." assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." 
+ # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - self.block_mask = self.build_block_mask() + + self.mask_mod = self.get_mask_mod() + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() -class FlexAttentionMetadataBuilder( - AttentionMetadataBuilder[FlexAttentionMetadata]): +class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - vllm_config.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.device = device + self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") + self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlexAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -310,6 +611,7 @@ class FlexAttentionMetadataBuilder( seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -320,14 +622,20 @@ class FlexAttentionMetadataBuilder( block_size = self.kv_cache_spec.block_size max_possible_seq_len = self.model_config.max_model_len - total_cache_tokens = self.cache_config.num_gpu_blocks * block_size + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, ( + "FlexAttention requires num_gpu_blocks to be set" + ) + total_cache_tokens = num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.cache_config.num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks + ) - # Get the original offset tensor offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) out = FlexAttentionMetadata( causal=common_attn_metadata.causal, @@ -349,12 +657,19 @@ class FlexAttentionMetadataBuilder( physical_to_logical=inverse_block_table, total_cache_tokens=total_cache_tokens, 
decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + direct_build=self.direct_build, + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, ) return out + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class FlexAttentionImpl(AttentionImpl): - sliding_window: Optional[tuple[int, int]] + sliding_window: Optional[int] alibi_slopes: Optional[torch.Tensor] logits_soft_cap: Optional[float] @@ -370,6 +685,7 @@ class FlexAttentionImpl(AttentionImpl): logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -377,38 +693,38 @@ class FlexAttentionImpl(AttentionImpl): self.num_kv_heads = num_kv_heads self.attn_type = attn_type - if attn_type not in (AttentionType.ENCODER_ONLY, - AttentionType.DECODER): + if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER): raise NotImplementedError( - f"FlexAttention does not support {attn_type} attention") + f"FlexAttention does not support {attn_type} attention" + ) if alibi_slopes is not None: raise NotImplementedError( - "FlexAttention does not support alibi slopes yet.") + "FlexAttention does not support alibi slopes yet." + ) else: self.alibi_slopes = None - if sliding_window is not None: - raise NotImplementedError( - "FlexAttention does not support sliding window yet.") - else: - self.sliding_window = (-1, -1) + + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: raise NotImplementedError( - "FlexAttention does not support logits soft cap yet.") + "FlexAttention does not support logits soft cap yet." + ) + assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: - raise NotImplementedError( - "FlexAttention does not support kv sharing yet.") + raise NotImplementedError("FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlexAttention does not support quantized kv-cache. Yet") + "FlexAttention does not support quantized kv-cache. Yet" + ) @staticmethod def view_as_4d(tensor: torch.Tensor) -> torch.Tensor: @@ -436,7 +752,8 @@ class FlexAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -444,8 +761,8 @@ class FlexAttentionImpl(AttentionImpl): assert output is not None, "Output tensor must be provided." 
if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlexAttentionImpl") + "fused output quantization is not yet supported for FlexAttentionImpl" + ) enable_gqa = self.num_kv_heads != self.num_heads @@ -457,6 +774,21 @@ class FlexAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens + if attn_metadata.sliding_window != self.sliding_window: + attn_metadata.sliding_window = self.sliding_window + if attn_metadata.direct_build: + # TODO: Support skipping the computation of sliding window + # in direct block mask building code path. + logger.warning_once( + "Using direct block mask building with sliding window, " + "which is suboptimal now. Performance may be degraded." + ) + # update mask mod in attention metadata + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + attn_metadata.block_mask = attn_metadata._build_block_mask_direct() + else: + attn_metadata.block_mask = attn_metadata.build_block_mask() + if not attn_metadata.causal: assert self.attn_type == AttentionType.ENCODER_ONLY @@ -465,6 +797,16 @@ class FlexAttentionImpl(AttentionImpl): (query, key, value), ) + query = query[:, :, :num_actual_tokens, :] + if (key_tensor.size(-2) > num_actual_tokens) or ( + value_tensor.size(-2) > num_actual_tokens + ): + # In the encoder-only model with torch.compile, + # qkv might be padded, which might cause exception. + # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 + key_tensor = key_tensor[:, :, :num_actual_tokens, :] + value_tensor = value_tensor[:, :, :num_actual_tokens, :] + else: assert self.attn_type == AttentionType.DECODER key_cache, value_cache = kv_cache.unbind(0) @@ -482,46 +824,67 @@ class FlexAttentionImpl(AttentionImpl): # View out the block_size dim key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, - self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) query, key_tensor, value_tensor = map( lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), (query, key_cache, value_cache), ) - query = query[:, :, :num_actual_tokens, :] + query = query[:, :, :num_actual_tokens, :] + # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on some GPUs - # TODO: Explicit configs for each GPU? 
- # Not sure how to calculate the shared memory requirement - extra_kernel_options = defaultdict[str, int](lambda: 64) - if query.dtype == torch.float32: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 - if current_platform.is_cuda(): - device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin - if max_shared_memory < 144 * 1024: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + kernel_options = get_kernel_options( + query, block_m, block_n, attn_metadata.direct_build + ) out = flex_attention_compiled( query, key_tensor, value_tensor, - attn_metadata.score_mod, + attn_metadata.transformed_score_mod, attn_metadata.block_mask, self.scale, enable_gqa=enable_gqa, - kernel_options={ - "FORCE_USE_FLEX_ATTENTION": True, - **extra_kernel_options - }, + kernel_options=kernel_options, ) # Flex doesn't have an out variant today, rely on epilogue fusion out = out.permute(0, 2, 1, 3).squeeze(0) output[:num_actual_tokens, :, :].copy_(out) return output + + +def get_kernel_options( + query, block_m, block_n, use_direct_build: bool +) -> dict[str, Union[int, bool]]: + kernel_options: dict[str, Union[int, bool]] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + if vllm_kernel_override_batch_invariant(): + kernel_options["BLOCK_M"] = 16 + kernel_options["BLOCK_N"] = 16 + kernel_options["IS_DIVISIBLE"] = False + return kernel_options + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + kernel_options["BLOCK_M"] = 64 + kernel_options["BLOCK_N"] = 64 + if query.dtype == torch.float32: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + # if current_platform.is_cuda(): + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2 + kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2 + + return kernel_options diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000000000..21fc2ab72768c --- /dev/null +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Backend for GatedDeltaNet attention.""" + +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class GDNAttentionBackend(AttentionBackend): + @staticmethod + def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: + return GDNAttentionMetadataBuilder + + +@dataclass +class GDNAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_spec_decodes: int + num_spec_decode_tokens: int + num_actual_tokens: int + + has_initial_state: Optional[torch.Tensor] = None + + 
spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [num_spec_decodes + 1,] + ) + non_spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes + 1,] + ) + + spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes,] + ) + spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] + spec_token_masks: Optional[torch.Tensor] = ( + None # shape: [num_prefill_tokens + num_decode_tokens,] + ) + num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + + # The following attributes are for triton implementation of causal_conv1d + nums_dict: Optional[dict] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None + + +class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): + cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: int = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self._init_reorder_batch_threshold(1, self.use_spec_decode) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), + self.compilation_config.max_capture_size, + ) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.bool, + device=device, + ) + self.spec_token_masks = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), + dtype=torch.bool, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1,), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + + def build( # type: ignore[override] + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: Optional[torch.Tensor] = None, + num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None, + fast_build: bool = False, + ) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + + if ( + not self.use_spec_decode + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0] + .sum() + .item() + == 0 + ): + spec_sequence_masks = None + 
num_spec_decodes = 0 + else: + spec_sequence_masks = num_decode_draft_tokens_cpu >= 0 + num_spec_decodes = spec_sequence_masks.sum().item() + if num_spec_decodes == 0: + spec_sequence_masks = None + else: + spec_sequence_masks = spec_sequence_masks.to( + query_start_loc.device, non_blocking=True + ) + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1) + ) + num_spec_decode_tokens = 0 + spec_token_masks = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens + + if num_prefills == 0 and num_decodes == 0: + spec_token_masks = torch.ones( + ( + min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + ), + dtype=torch.bool, + device=query_start_loc.device, + ) + spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens + ) + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, : self.num_spec + 1 + ] + non_spec_state_indices_tensor = m.block_table_tensor[ + ~spec_sequence_masks, 0 + ] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + ) + non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:], + ) + + num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(non_spec_query_start_loc) + ) + else: + has_initial_state = None + num_actual_tokens = ( + num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens + ) + + # prepare tensors for cudagraph + # + # With speculative decoding, the xgrammar backend may rollback tokens + # and causing some sequences has less draft tokens than self.num_spec. + # + # In above cases, the max possible batch size for n tokens, can be + # min(n, cudagraph_max_bs). 
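+ # Example: with num_spec == 2 a full verification step contributes
+ # 3 tokens per request, but after a rollback a request may contribute
+ # only 1 or 2, so a padded budget of n tokens can be spread over up to
+ # min(n, decode_cudagraph_max_bs) requests; that is the batch size the
+ # persistent buffers below are sliced to.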
+ if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True + ) + spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True + ) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert spec_token_masks is not None + self.spec_token_masks[: spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True + ) + spec_token_masks = self.spec_token_masks[:num_actual_tokens] + spec_token_masks[spec_token_masks.size(0) :].fill_(False) + + self.spec_query_start_loc[: num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True + ) + spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1] + spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True + ) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) + batch_size = num_actual_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True + ) + non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ + :batch_size + ] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[: num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True + ) + non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index] + non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] + non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) + + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + num_actual_tokens=num_actual_tokens, + has_initial_state=has_initial_state, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + spec_sequence_masks=spec_sequence_masks, + spec_token_masks=spec_token_masks, + num_accepted_tokens=num_accepted_tokens, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, + ) + return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. 
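+
+ For capture, the dummy batch is described as spec-decode requests:
+ num_accepted_tokens is derived from the query lengths and each
+ request's draft-token count is set to its query length minus one
+ (build() falls back to the plain decode path if all counts are zero).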
+ """ + m = common_attn_metadata + + assert ( + m.num_reqs <= self.decode_cudagraph_max_bs + and m.num_actual_tokens <= self.decode_cudagraph_max_bs + ), ( + f"GDN only supports decode-only full CUDAGraph capture. " + f"Make sure batch size ({m.num_reqs}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " + f"and number of tokens ({m.num_actual_tokens}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})." + ) + + num_accepted_tokens = torch.diff(m.query_start_loc) + num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() + m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() + + return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index f08b6d7f177c7..1900c50849eca 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -1,20 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class LinearAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder @@ -32,28 +32,35 @@ class LinearAttentionMetadata: state_indices_tensor: torch.Tensor # shape: [batch,] -class LinearAttentionMetadataBuilder( - AttentionMetadataBuilder[LinearAttentionMetadata]): +class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]): + reorder_batch_threshold: int = 1 - reorder_batch_threshold: ClassVar[int] = 1 - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> LinearAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> LinearAttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 97a1aa86dda0d..e305cb2d87029 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ 
b/vllm/v1/attention/backends/mamba1_attn.py @@ -8,14 +8,14 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + split_decodes_and_prefills, +) class Mamba1AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: return Mamba1AttentionMetadataBuilder @@ -35,8 +35,8 @@ class Mamba1AttentionMetadata: class Mamba1AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): - + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] +): def build( self, common_prefix_len: int, @@ -47,23 +47,30 @@ class Mamba1AttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - query_start_loc.device) + query_start_loc.device + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) has_initial_states = None padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 - elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph): + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): state_indices_for_decode = state_indices_tensor[:num_decodes] padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( - state_indices_for_decode, non_blocking=True) + state_indices_for_decode, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ed30884fdbc94..10f09442d82e2 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -1,59 +1,94 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math +import itertools from dataclasses import dataclass from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.utils import cdiv +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): +def compute_varlen_chunk_metadata( + 
query_start_loc: torch.Tensor, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels. - cu_seqlens = query_start_loc[1:] # remove prepended 0 + Given per-sequence cumulative token starts `query_start_loc` of shape [B+1] + and a physical `chunk_size`, returns three tensors on the same device: + - cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of + logical-chunk lengths (each logical chunk never crosses a sequence or + physical-chunk boundary). + - last_chunk_indices: (B,) int32 index of the last logical chunk + for each sequence (=-1 for empty sequences). + - seq_idx_chunks: (nchunks,) int32 sequence index for each logical + chunk in order. - # outputs will have length expansion of chunks that do not divide - # chunk_size - N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, - dtype=torch.int, - device=query_start_loc.device) - chunk_offsets = torch.zeros((N, ), - dtype=torch.int, - device=query_start_loc.device) + This is intentionally lightweight and CPU-side; it mirrors the metadata + produced by the V1 Mamba2 meta-data builder and is exported so tests + (and other callers) can avoid duplicating the logic. + """ + assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]" + assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0" + device = query_start_loc.device - p = 0 # num of insertions - for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + qsl64 = query_start_loc.to(torch.int64) + starts = qsl64[:-1].tolist() + ends = qsl64[1:].tolist() + total = int(qsl64[-1].item()) - # if does not divide chunk_size, then there is one chunk insertion - p += (s % chunk_size > 0) + chunk_lens: list[int] = [] + seq_idx_chunks: list[int] = [] + last_chunk_indices: list[int] = [-1] * len(starts) - # get the dimensions - # - the + 1 for _e is to shift the boundary by one chunk - # - this shifting is not needed if chunk_size divides e - _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size - > 0) + for b, (s, e) in enumerate(zip(starts, ends)): + if e <= s: + # empty sequence + continue + pos = s + while pos < e: + # split at both sequence boundaries and physical chunk boundaries + room = chunk_size - (pos % chunk_size) + take = min(room, e - pos) + chunk_lens.append(int(take)) + seq_idx_chunks.append(b) + last_chunk_indices[b] = len(chunk_lens) - 1 + pos += take - # adjust indices and offsets - chunk_indices[_s:_e] -= p - chunk_offsets[_s] = s % chunk_size + # Exclusive prefix sum over logical-chunk lengths + if chunk_lens: + cu_chunk_seqlens = torch.tensor( + [0] + list(itertools.accumulate(chunk_lens)), + device=device, + dtype=torch.int32, + ) + # Final boundary must equal total tokens + assert int(cu_chunk_seqlens[-1].item()) == total + else: + cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32) - return chunk_indices, chunk_offsets + last_chunk_indices_t = ( + torch.tensor(last_chunk_indices, device=device, dtype=torch.int32) + if len(starts) > 0 + else torch.empty((0,), device=device, dtype=torch.int32) + ) + seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32) + return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t class Mamba2AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder @@ -65,7 +100,7 @@ class 
Mamba2AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int - query_start_loc: torch.Tensor + query_start_loc_p: torch.Tensor seq_lens: torch.Tensor prep_initial_states: bool @@ -75,100 +110,274 @@ class Mamba2AttentionMetadata: # the batch has no prefill request. has_initial_states_p: Optional[torch.Tensor] seq_idx_p: Optional[torch.Tensor] - chunk_indices_p: Optional[torch.Tensor] - chunk_offsets_p: Optional[torch.Tensor] + + # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for + # each chunk, its offests into the varlen sequence dimension. It is defined + # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to + # cu_chunk_seqlen_p[i+1]. + cu_chunk_seqlen_p: Optional[torch.Tensor] + + # last_chunk_indices_p is a tensor of shape (batch,) that contains the + # index of the last chunk for every sequence in the (prefill) batch. + last_chunk_indices_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] + block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] + block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] + block_idx_last_computed_token: torch.Tensor # shape: [batch,] + num_computed_tokens_p: torch.Tensor # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class Mamba2AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] +): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( - "chunk_size needs to be set in the model config for Mamba2 models") + "chunk_size needs to be set in the model config for Mamba2 models" + ) + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor = torch.empty( + ( + self.decode_cudagraph_max_bs, + cdiv( + vllm_config.model_config.max_model_len, kv_cache_spec.block_size + ), + ), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> Mamba2AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens + query_start_loc_p = None seq_idx_p = None - chunk_indices_p, chunk_offsets_p = None, None + cu_chunk_seqlen_p = None + last_chunk_indices_p = None + # Need flags to indicate if there are initial states - # currently we really only support the 
FlashAttention backend has_initial_states_p = None prep_initial_states = False - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + + num_computed_tokens, num_computed_tokens_p = None, None + block_idx_first_scheduled_token = None + block_idx_first_scheduled_token_p = None + + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + # Additional cache-related varaiables: + mamba_block_size = self.kv_cache_spec.block_size + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + # Block index of the last computed token + block_idx_last_computed_token = ( + cdiv(num_computed_tokens, mamba_block_size) - 1 + ) + # which is <= block index for the first scheduled token + block_idx_first_scheduled_token = ( + cdiv(num_computed_tokens + 1, mamba_block_size) - 1 + ) + # which is <= block index of the last scheduled token + block_idx_last_scheduled_token = ( + cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 + ) + # -1 in case it's non-computed and causes later issues with indexing + block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # Additional cache-related varaiables: + block_idx_last_scheduled_token = None + block_idx_last_computed_token = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + # Compute seq_idx for prefill only if num_prefills > 0: - #[batch,] + # [batch,] has_initial_states_cpu = ( - common_attn_metadata. - num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states_p = has_initial_states_cpu.to( - query_start_loc.device) + common_attn_metadata.query_start_loc.device + ) - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) - seq_idx_p = torch.repeat_interleave(torch.arange( - num_prefills, - dtype=torch.int32, - device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=num_prefill_tokens) - seq_idx_p.unsqueeze_(0) + if self.vllm_config.cache_config.enable_prefix_caching: + assert num_computed_tokens is not None + num_computed_tokens_p = num_computed_tokens[ + num_reqs - num_prefills : num_reqs + ] + assert block_idx_first_scheduled_token is not None + block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ + num_reqs - num_prefills : num_reqs + ] + num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + query_start_loc_p_cpu = ( + common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) - # We compute metadata for chunked prefill once at the top level - # model forward and reuse them in mamba layers. 
If not needed, - # they will be ignored inside mamba kernels. - if prep_initial_states: - chunk_indices_p, chunk_offsets_p = ( - _query_start_loc_to_chunk_indices_offsets( - query_start_loc_p, self.chunk_size, - num_prefill_tokens)) + # The code below carefully constructs the chunks such that: + # 1. Chunks contain tokens from a *single* sequence only. + # 2. For every sequence, we are guaranteed that we can + # retrieve the mamba state *every* chunk_size tokens. + # Constraint (1) dramatically simplifies the mamba2 kernels. + # Constraint (2) dramatically simplifies the implementation + # of prefix caching for mamba2 (wip). We need to take care + # of the interaction with chunked prefill in order to + # satisfy constraint (2). + # TODO (tdoublep): This code could probably be optimized. + cu_chunk_seqlen = [] + seq_idx = [] + last_chunk_indices = [] + seqlen_pos = 0 + for req_idx in range(num_prefills): + this_num_computed = num_computed_tokens_p_cpu[req_idx].item() + this_new_tokens = ( + query_start_loc_p_cpu[req_idx + 1].item() + - query_start_loc_p_cpu[req_idx].item() + ) - elif num_decodes <= self.decode_cudagraph_max_bs: + # if computed tokens are not chunk-aligned, use the first + # chunk to finish it off + if this_num_computed % self.chunk_size != 0: + seq_idx.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + # how many tokens to finish the chunk? + chunk_len = ( + cdiv(this_num_computed, self.chunk_size) * self.chunk_size + - this_num_computed + ) + # we can only use at most this_new_tokens + chunk_len = min(chunk_len, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + n_chunks = cdiv(this_new_tokens, self.chunk_size) + for chunk in range(n_chunks): + seq_idx.append(req_idx) + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = min(self.chunk_size, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + assert this_new_tokens == 0 + last_chunk_indices.append(len(cu_chunk_seqlen) - 1) + + cu_chunk_seqlen.append(seqlen_pos) + + seq_idx_p = torch.as_tensor( + seq_idx, device=query_start_loc_p.device, dtype=torch.int32 + ) + cu_chunk_seqlen_p = torch.as_tensor( + cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32 + ) + last_chunk_indices_p = torch.as_tensor( + last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32 + ) + + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(query_start_loc_p) + ) + + elif ( + num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) - self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, - non_blocking=True) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + if self.vllm_config.cache_config.enable_prefix_caching: + self.block_idx_last_scheduled_token[:num_decodes].copy_( + block_idx_last_scheduled_token, non_blocking=True + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + :num_input_tokens + ] + block_idx_last_scheduled_token[num_decodes:] = 0 + + self.block_idx_last_computed_token[:num_decodes].copy_( + block_idx_last_computed_token, non_blocking=True + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + :num_input_tokens + ] + block_idx_last_computed_token[num_decodes:] = 0 + 
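A standalone sketch (not part of the patch) of the chunk bookkeeping built in the prefill branch above, with toy numbers: chunk_size = 4 and two prefill requests that have 2 and 0 tokens already computed.

chunk_size = 4
num_computed = [2, 0]   # tokens already computed per prefill request
new_tokens = [7, 5]     # tokens scheduled in this step

cu_chunk_seqlen, seq_idx, last_chunk_indices, pos = [], [], [], 0
for req, (computed, todo) in enumerate(zip(num_computed, new_tokens)):
    if computed % chunk_size:           # finish the partially filled chunk first
        take = min(chunk_size - computed % chunk_size, todo)
        seq_idx.append(req); cu_chunk_seqlen.append(pos)
        pos += take; todo -= take
    while todo:                         # then full (or trailing partial) chunks
        take = min(chunk_size, todo)
        seq_idx.append(req); cu_chunk_seqlen.append(pos)
        pos += take; todo -= take
    last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(pos)

print(cu_chunk_seqlen)     # [0, 2, 6, 7, 11, 12]
print(seq_idx)             # [0, 0, 0, 1, 1]
print(last_chunk_indices)  # [2, 4]

No chunk crosses a sequence boundary (constraint 1), and every interior chunk boundary of a request falls on a multiple of chunk_size of its absolute position, computed plus new tokens, which is what constraint 2 relies on.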
attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, - query_start_loc=query_start_loc, + query_start_loc_p=query_start_loc_p, seq_lens=seq_lens, prep_initial_states=prep_initial_states, chunk_size=self.chunk_size, has_initial_states_p=has_initial_states_p, seq_idx_p=seq_idx_p, - chunk_indices_p=chunk_indices_p, - chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, + cu_chunk_seqlen_p=cu_chunk_seqlen_p, + last_chunk_indices_p=last_chunk_indices_p, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token=block_idx_last_computed_token, + num_computed_tokens_p=num_computed_tokens_p, ) return attn_metadata diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 07ef7cb69a160..5aafb9813df06 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,49 +7,57 @@ from typing import ClassVar, TypeVar import torch from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec M = TypeVar("M") class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): - reorder_batch_threshold: ClassVar[int] = 1 - cudagraph_support: ClassVar[AttentionCGSupport] = \ + reorder_batch_threshold: int = 1 + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names - self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) + self.compilation_config.max_capture_size, + ) self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs == m.num_actual_tokens, ( + "Mamba only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." 
+ ) m.max_query_len = 1 # decode-only - return self.build(0, m) \ No newline at end of file + return self.build(0, m) diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py deleted file mode 100644 index fb1844508211b..0000000000000 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionBackend) - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba1": - return Mamba1AttentionBackend - if mamba_type == "mamba2": - return Mamba2AttentionBackend - if mamba_type == "linear_attention": - return LinearAttentionBackend - if mamba_type == "short_conv": - return ShortConvAttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ce45b34f64355..af396c2b41035 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -193,32 +193,43 @@ from dataclasses import dataclass, field from typing import ClassVar, Generic, Optional, TypeVar, Union import torch +from tqdm import tqdm import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, +) from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.utils.flashinfer import has_nvidia_artifactory -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -228,24 +239,48 @@ except ImportError: try: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper - from flashinfer.prefill import ( # noqa: F401 - 
cudnn_batch_prefill_with_kv_cache) + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401 + flashinfer_available = True except ImportError: flashinfer_available = False + +def is_rocm_aiter_fp8bmm_enabled() -> bool: + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_FP8BMM + and envs.VLLM_ROCM_USE_AITER + ) + + +if is_rocm_aiter_fp8bmm_enabled(): + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 + ) + + def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn + ): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 class MLACommonBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod def get_name() -> str: - return "TRITON_MLA_VLLM_V1" + return "TRITON_MLA" @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: @@ -261,6 +296,7 @@ class MLACommonBackend(AttentionBackend): block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) @@ -281,12 +317,13 @@ class MLACommonBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @dataclass class MLACommonPrefillMetadata: - """ Prefill Specific Metadata """ + """Prefill Specific Metadata""" @dataclass class ChunkedContextMetadata: @@ -299,6 +336,13 @@ class MLACommonPrefillMetadata: seq_lens: torch.Tensor workspace: torch.Tensor + # for mla DCP + cp_chunk_seq_lens: Optional[list[list[int]]] = None + origin_context_lens: Optional[list[int]] = None + cp_cu_seq_lens: Optional[torch.Tensor] = None + chunk_size: Optional[int] = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -307,16 +351,15 @@ class MLACommonPrefillMetadata: @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): - prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None - prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( - default_factory=list) + prefill_main: Optional["BatchPrefillWithRaggedKVCacheWrapper"] = None + prefill_chunks: list["BatchPrefillWithRaggedKVCacheWrapper"] = field( + default_factory=list + ) @dataclass class CudnnPrefillMetadata(MLACommonPrefillMetadata): - - class ChunkedContextMetadata( - MLACommonPrefillMetadata.ChunkedContextMetadata): + class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor query_seq_lens: Optional[torch.Tensor] = None @@ -327,6 +370,7 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata): class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor + dcp_tot_seq_lens: Optional[torch.Tensor] D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -339,6 +383,7 @@ class MLACommonMetadata(Generic[D]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -349,6 +394,7 @@ class MLACommonMetadata(Generic[D]): num_reqs: int max_query_len: int + max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor @@ -364,9 +410,9 @@ class MLACommonMetadata(Generic[D]): head_dim: Optional[int] = None decode: Optional[D] = None - prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, - CudnnPrefillMetadata]] = None + prefill: Optional[ + Union[MLACommonPrefillMetadata, FlashInferPrefillMetadata, CudnnPrefillMetadata] + ] = None def __post_init__(self): if self.head_dim is not None: @@ -374,19 +420,27 @@ class MLACommonMetadata(Generic[D]): M = TypeVar("M", bound=MLACommonMetadata) +A = TypeVar("A") def use_flashinfer_prefill() -> bool: - # For blackwell default to flashinfer prefill if its available since + # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. - return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100)) + return ( + not envs.VLLM_DISABLE_FLASHINFER_PREFILL + and flashinfer_available + and not envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + ) def use_cudnn_prefill() -> bool: - return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100) - and has_nvidia_artifactory()) + return ( + flashinfer_available + and envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + and has_nvidia_artifactory() + ) # Currently 394MB, this can be tuned based on GEMM sizes used. 
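
# A minimal standalone sketch (not vLLM API) of the prefill-path selection encoded by
# use_flashinfer_prefill() / use_cudnn_prefill() above: FlashInfer is preferred on
# SM100 (Blackwell) unless VLLM_DISABLE_FLASHINFER_PREFILL is set, cuDNN prefill needs
# an explicit VLLM_USE_CUDNN_PREFILL opt-in (plus the NVIDIA artifactory wheels), and
# everything else falls back to FlashAttention. The helper and its boolean parameters
# are illustrative only; the real checks read vllm.envs and current_platform.
def select_prefill_path(
    flashinfer_available: bool,
    is_sm100: bool,
    disable_flashinfer_prefill: bool,  # mirrors VLLM_DISABLE_FLASHINFER_PREFILL
    use_cudnn_prefill_env: bool,       # mirrors VLLM_USE_CUDNN_PREFILL
    has_artifactory: bool,             # mirrors has_nvidia_artifactory()
) -> str:
    if (flashinfer_available and is_sm100
            and not disable_flashinfer_prefill and not use_cudnn_prefill_env):
        return "flashinfer"
    if (flashinfer_available and is_sm100
            and use_cudnn_prefill_env and has_artifactory):
        return "cudnn"
    return "flash_attn"

# e.g. select_prefill_path(True, True, False, False, True) == "flashinfer"
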
@@ -400,36 +454,35 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ - reorder_batch_threshold: ClassVar[int] = 1 - def __init__(self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[type[M]] = None): - self.metadata_cls = metadata_cls \ - if metadata_cls is not None else MLACommonMetadata - self.kv_cache_spec = kv_cache_spec - self.device = device + # Whether the backend supports reordering the batch such that + # short sequences (i.e. verification for speculative decoding) are + # classified as decode requests. + # If True, this will increase `reorder_batch_threshold` (below) when + # speculative decoding is enabled, and set `require_uniform=True` when + # when reordering the batch. Non-uniform decode requests will + # fall back to prefill in this case. + supports_uniform_spec_as_decode: ClassVar[bool] = False + + # The threshold for reordering the batch into decode and prefill requests. + # If > 1, the batch will be reordered such that requests with + # query length <= threshold are classified as decode requests. + # Use `supports_uniform_spec_as_decode` (above) to set this automatically + # when speculative decoding is enabled. + reorder_batch_threshold: int = 1 + + @staticmethod + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: scheduler_config = vllm_config.scheduler_config - self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config - parallel_config = vllm_config.parallel_config - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) - self.mla_dims = get_mla_dims(self.model_config) - self.aot_schedule = current_platform.is_cuda() + model_config = vllm_config.model_config - # Dont try to access the runner on AMD - if self.aot_schedule: - self.page_size = self.kv_cache_spec.block_size - - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max(8 * self.model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * cache_config.block_size), + chunked_prefill_workspace_size = min( + # Try for 8 full length request or at least 4 pages per-request + max( + 8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size, + ), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, # which would result in the workspace being: @@ -438,37 +491,101 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, + 64 * 1024, ) + # Enforce that we enough for at least 1 page per request + chunked_prefill_workspace_size = max( + chunked_prefill_workspace_size, + scheduler_config.max_num_seqs * cache_config.block_size, + ) + + return chunked_prefill_workspace_size + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[type[M]] = 
None, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) + self.kv_cache_spec = kv_cache_spec + scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + self.device = device + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + # Don't try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.kv_cache_spec.block_size + + self.chunked_prefill_workspace_size = ( + self.determine_chunked_prefill_workspace_size(vllm_config) + ) + + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. + assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), + dtype=self.model_config.dtype, + device=device, + ) + self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata - if self._use_fi_prefill else CudnnPrefillMetadata - if self._use_cudnn_prefill else MLACommonPrefillMetadata) + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) - self._fi_prefill_main: Optional[ - BatchPrefillWithRaggedKVCacheWrapper] = None - self._fi_prefill_chunks: list[ - BatchPrefillWithRaggedKVCacheWrapper] = [] + self._fi_prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, - MLACommonImpl)) + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( @@ -477,6 +594,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): device=device, ) + supports_spec_as_decode = self.supports_uniform_spec_as_decode + self._init_reorder_batch_threshold( + self.reorder_batch_threshold, supports_spec_as_decode + ) + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -487,7 +609,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._fi_prefill_main is None: self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( - 
self._workspace_buffer, "NHD", backend="cutlass") + self._workspace_buffer, "NHD", backend="cutlass" + ) if has_context: num_chunks = chunked_context.cu_seq_lens.shape[0] @@ -496,7 +619,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): for _ in range(len(self._fi_prefill_chunks), num_chunks): self._fi_prefill_chunks.append( BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass")) + self._workspace_buffer, "NHD", backend="cutlass" + ) + ) assert num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads @@ -507,8 +632,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): assert self.kv_cache_spec.num_kv_heads == 1 # Get non-latent head_dim_qk and head_dim_vo - head_dim_qk = (self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim) + head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim head_dim_vo = self.mla_dims.v_head_dim # For main run, qo_indptr == kv_indptr @@ -527,7 +651,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) # Prepare context prefills @@ -545,44 +668,56 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): causal=False, # This is context run sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, - logits_soft_cap=self._global_hyperparameters. - logits_soft_cap, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ - "MLA only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + "MLA only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." 
+ ) - assert m.max_query_len == 1 # decode-only + assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because @@ -594,14 +729,29 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - - query_seq_lens_cpu) + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=self.supports_uniform_spec_as_decode, + ) + ) + + # Note(hc): update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + assert dcp_local_seq_lens is not None + dcp_local_seq_lens[:num_decodes] = seq_lens[ + :num_decodes + ] // self.dcp_world_size + ( + self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -611,10 +761,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + # Note(hc): The context lengths in the perspective of dcp rank0. 
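
# Sketch of the decode/prefill split driven by `reorder_batch_threshold` and the
# `split_decodes_and_prefills(...)` call above (this helper is illustrative, not the
# real function): assuming the batch has already been reordered so that decode-like
# requests (query_len <= threshold) come first, the split point is the first request
# whose query length exceeds the threshold. With `require_uniform=True` the real code
# additionally requires all decode query lengths to match, falling back to prefill
# otherwise.
def split_by_threshold(query_lens: list[int], threshold: int = 1):
    num_decodes = 0
    for qlen in query_lens:
        if qlen > threshold:
            break
        num_decodes += 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    return num_decodes, len(query_lens) - num_decodes, num_decode_tokens

# e.g. with speculative decoding raising the threshold to 3:
# split_by_threshold([2, 3, 3, 17, 40], threshold=3) == (3, 2, 8)
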
+ cp_context_lens_cpu = torch.ceil( + context_lens_cpu.float() / self.dcp_world_size + ).int() + origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -626,16 +782,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_and_maybe_dequant_cache` kernel # cannot handle `context_chunk_starts` that are not aligned # to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -646,43 +802,90 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] # Note(simon): this is done in CPU because of downstream's # of `to_list`. - chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - - chunked_context_metadata_cls = \ - CudnnPrefillMetadata.ChunkedContextMetadata \ - if self._use_cudnn_prefill else \ - MLACommonPrefillMetadata.ChunkedContextMetadata - - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) + + if self.dcp_world_size > 1: + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. 
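
# A small numeric sketch of the chunked-context bookkeeping above (the values are
# made up; `max_context_chunk` is normally derived from the workspace size and
# rounded down to the page size). With two prefill requests whose context lengths
# are [1000, 300] and max_context_chunk = 256:
import torch

context_lens_cpu = torch.tensor([1000, 300], dtype=torch.int32)
max_context_chunk = 256
num_chunks = (1000 + max_context_chunk - 1) // max_context_chunk  # cdiv -> 4

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, 2) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# chunk_starts   -> [[0, 0], [256, 256], [512, 512], [768, 768]]
# chunk_seq_lens -> [[256, 256], [256, 44], [256, 0], [232, 0]]
# Each iteration of the chunked-context loop gathers at most one chunk per request,
# and requests whose context is exhausted simply contribute a length of 0.
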
+ assert max_context_chunk % self.dcp_world_size == 0 + cp_max_context_chunk = max_context_chunk // self.dcp_world_size + cp_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) + * cp_max_context_chunk + ) + cp_chunk_ends = torch.min( + cp_context_lens_cpu.unsqueeze(0), + cp_chunk_starts + cp_max_context_chunk, + ) + cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + CudnnPrefillMetadata.ChunkedContextMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata.ChunkedContextMetadata + ) + if self.dcp_world_size > 1: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=cp_chunk_starts.to(device, non_blocking=True), + seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), + origin_context_lens=origin_context_lens, + cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), + chunk_size=max_context_chunk, + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + ) + else: + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens - assert max(chunked_context_metadata.max_seq_lens) <= \ - self.chunked_prefill_workspace_size + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], @@ -693,20 +896,31 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if self._use_cudnn_prefill: assert isinstance(prefill_metadata, CudnnPrefillMetadata) - prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ - - prefill_query_start_loc[:-1] + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) prefill_metadata.cudnn_workspace = self.cudnn_workspace decode_metadata = None if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens=seq_lens[:num_decodes], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=dcp_local_seq_lens[:num_decodes] + if self.dcp_world_size > 1 and dcp_local_seq_lens is not None + else seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], + num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=seq_lens[:num_decodes] + if self.dcp_world_size > 1 + else None, ) attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -726,7 +940,78 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return 
attn_metadata -class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + cp_chunk_seq_lens_lst: list[int], + origin_context_lens: list[int], + cp_world_size: int, + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg kvcache after cp local gather to tp layout for attn kernel. + + Args: + cp_chunk_seq_lens_lst: chunk context lengths under CP. + origin_context_lens: origin full context lengths under CP. + cp_world_size: CP size. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: equals to max_context_chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. + """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for cp_chunk_seq_len, origin_context_len in zip( + cp_chunk_seq_lens_lst, origin_context_lens + ): + chunk_context_len = chunk_size + if cp_chunk_seq_len != 0: + chunk_context_len = min( + chunk_context_len, origin_context_len - chunk_size * chunk_idx + ) + cp_target_rank = (chunk_context_len - 1) % cp_world_size + cur_seq_len = 0 + for rank in range(cp_world_size): + if rank > cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + else: + real_cp_chunk_seq_len = cp_chunk_seq_len + if real_cp_chunk_seq_len: + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += real_cp_chunk_seq_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += cp_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe + + +# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, +# and MLACommonImpl -> MLACommonDenseImpl or somthing like that +class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -752,6 +1037,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + indexer=None, + q_pad_num_heads: Optional[int] = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -769,6 +1056,140 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj + self.indexer = indexer + self.q_pad_num_heads = q_pad_num_heads + + def process_weights_after_loading(self, act_dtype: torch.dtype): + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype() + ) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype() + ) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. 
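
# Why the identity-matrix trick in get_and_maybe_dequant_weights above recovers the
# weights: applying a linear layer to the identity gives I @ W.T = W.T, i.e. the
# (dequantized) weight in (input, output) layout, which is then transposed back to
# (output, input). A tiny sketch with a plain nn.Linear standing in for a quantized
# layer (the real path goes through layer.quant_method.apply and is only meant for
# offline weight post-processing, since it is O(N^3)):
import torch

layer = torch.nn.Linear(8, 4, bias=False)
eye = torch.eye(layer.in_features)
recovered = layer(eye).T  # (output, input), identical to layer.weight
torch.testing.assert_close(recovered, layer.weight)
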
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + # Copy result + out.copy_(x) + else: + # Convert from (B, N * V) to (N, B, V) + out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + + # Convert from (N, B, V) to (B, N * V) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result + + +class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) if use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") @@ -777,8 +1198,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self._pad_v = False elif use_cudnn_prefill(): logger.debug_once("Using CUDNN prefill for MLA") - self._run_prefill_context_chunk = \ - self._run_prefill_context_chunk_cudnn + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention @@ -793,9 +1213,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.flash_attn_varlen_func = flash_attn_varlen_func self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) + self.flash_attn_varlen_func = functools.partial( + flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + ) # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do @@ -803,19 +1223,25 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # We don't need to pad V if we are on a hopper system with FA3 self._pad_v = self.vllm_flash_attn_version is None or not ( self.vllm_flash_attn_version == 3 - and 
current_platform.get_device_capability()[0] == 9) + and current_platform.get_device_capability()[0] == 9 + ) - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + self.dcp_world_size: Optional[int] = None + + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse @@ -843,8 +1269,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return attn_out, lse return attn_out - def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -858,19 +1285,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return_softmax_lse=return_softmax_lse, ) - def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None - return prefill.prefill_main.run( + ret = prefill.prefill_main.run( q=q, k=k, v=v, return_lse=return_softmax_lse, ) - def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, - q, k, v, return_softmax_lse): + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None output, lse = cudnn_batch_prefill_with_kv_cache( @@ -884,16 +1318,18 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), causal=True, - return_lse=True, # do not support False for now - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + # Do not support False for now + return_lse=True, + # Indicates actual_seq_lens are on GPU or CPU. 
+ is_cuda_graph_compatible=True, ) if return_softmax_lse: return output, lse return output - def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, @@ -908,19 +1344,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return_softmax_lse=True, ) - def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, FlashInferPrefillMetadata) - return prefill.prefill_chunks[chunk_idx].run( + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, k=k, v=v, return_lse=True, ) + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() - def _run_prefill_context_chunk_cudnn(self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None @@ -934,59 +1373,53 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. - view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( + -1, 1, 1, 1 + ), causal=False, return_lse=True, - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + # Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, ) - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -994,12 +1427,55 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype() + ) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype() + ) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. 
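
# A tiny numeric sketch of what dynamic_per_batched_tensor_quant above does for
# W_K / W_V (the values are made up, and this requires a PyTorch build with
# float8_e4m3fn support): the tensor is scaled so that its absolute maximum maps to
# the fp8 dtype's finite max, and the stored scale is the reciprocal, so that
# dequantization is just `x_q.float() * scale`.
import torch

x = torch.tensor([[0.5, -2.0], [1.0, 4.0]])
dtype = torch.float8_e4m3fn
DTYPE_MAX = torch.finfo(dtype).max        # 448.0 for e4m3fn
amax = x.abs().amax().clamp(min=1e-10)    # 4.0
scale = DTYPE_MAX / amax                  # 112.0
x_q = (x * scale).clamp(-DTYPE_MAX, DTYPE_MAX).to(dtype)
recip = scale.float().reciprocal()        # what gets stored as W_*_scale
torch.testing.assert_close(x_q.float() * recip, x)
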
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) def _compute_prefill_context( self, @@ -1030,18 +1506,118 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): seq_starts=prefill_metadata.chunked_context.starts[i], ) - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) + kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP not support scaled kvcache now." 
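
# Worked sketch (assumed numbers) of the layout problem reorg_kvcache above solves:
# after cp_gather_cache + all_gather, the workspace is rank-major,
#   [rank0: req0-slice, req1-slice, ... | rank1: req0-slice, req1-slice, ... | ...],
# while the prefill kernel wants it request-major,
#   [req0: rank0-slice, rank1-slice, ... | req1: rank0-slice, rank1-slice, ... | ...].
# Because cp_context_lens use ceil division, ranks "after" the one holding the last
# context token of a chunk carry one padding token, which reorg_kvcache drops via
# real_cp_chunk_seq_len. The helper below is illustrative only and mirrors just that
# per-rank trimming (it assumes cp_chunk_seq_len > 0):
def real_lens_per_rank(origin_context_len: int, chunk_size: int, chunk_idx: int,
                       cp_chunk_seq_len: int, cp_world_size: int) -> list[int]:
    chunk_context_len = min(chunk_size,
                            origin_context_len - chunk_size * chunk_idx)
    cp_target_rank = (chunk_context_len - 1) % cp_world_size
    return [cp_chunk_seq_len - 1 if rank > cp_target_rank else cp_chunk_seq_len
            for rank in range(cp_world_size)]

# e.g. a request with 5 context tokens, chunk_size=256, chunk_idx=0, dcp_world_size=2:
# ceil(5/2) = 3 tokens were gathered per rank, but only [3, 2] of them are real;
# real_lens_per_rank(5, 256, 0, 3, 2) == [3, 2], which sums back to 5.
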
+ assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.origin_context_lens is not None + assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None + assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + i + ], + origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, + cp_world_size=dcp_world_size, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks, + ) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1079,13 +1655,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, ) -> torch.Tensor: + # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None + assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, 
k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) @@ -1099,8 +1677,20 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + if self.dcp_world_size > 1: + context_output, context_lse = ( + self._context_parallel_compute_prefill_context( + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) + else: + context_output, context_lse = self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1113,19 +1703,18 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # unpad if necessary if self._pad_v: - output = output[..., :v.shape[-1]] + output = output[..., : v.shape[-1]] return output.flatten(start_dim=-2) @abstractmethod def _forward_decode( self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError def forward( @@ -1144,15 +1733,31 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. return output.fill_(0) + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + fp8_attention = self.kv_cache_dtype.startswith("fp8") num_actual_toks = attn_metadata.num_actual_tokens @@ -1164,9 +1769,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] 
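
# merge_attn_states above combines two partial attention results computed over
# disjoint key sets (context chunks vs. new tokens) using their log-sum-exp values.
# A reference sketch of the formula it is based on (not the vLLM kernel; assumes
# natural-log LSEs as returned by the flash-attention style prefill kernels):
import torch

def merge_two(o1: torch.Tensor, lse1: torch.Tensor,
              o2: torch.Tensor, lse2: torch.Tensor):
    # o*: (..., d) partial outputs, lse*: (...) log-sum-exp of that chunk's logits
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m)
    w2 = torch.exp(lse2 - m)
    out = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    lse = m + torch.log(w1 + w2)
    return out, lse

# Chaining merge_two over the context chunks and then over the new-token output gives
# the same result as a single attention pass over all keys.
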
- assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -1194,35 +1801,90 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_prefill: output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + ) if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) + + # Pads the head_dim if necessary (for the underlying kernel) + if self.q_pad_num_heads is not None: + B, N, L = decode_q_pe.shape + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) + decode_pe_padded.resize_((B, N, L)) + decode_pe_padded.copy_(decode_q_pe) + decode_q_pe = decode_pe_padded + + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + decode_ql_nope = aiter_triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) + else: + # Pads the head_dim if necessary (for the underlying kernel) + N, B, P = decode_q_nope.shape + _, _, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: + decode_ql_nope = decode_q_nope.new_empty( + (self.q_pad_num_heads, B, L) + ) + decode_ql_nope.resize_((N, B, L)) + + else: + decode_ql_nope = decode_q_nope.new_empty((N, B, L)) + + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) if fp8_attention: ql_nope_shape = decode_ql_nope.shape decode_ql_nope, _ = ops.scaled_fp8_quant( - decode_ql_nope.reshape([ - ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] - ]), layer._q_scale) + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape( - [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), - layer._q_scale) + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) decode_q_pe = decode_q_pe.reshape(q_pe_shape) - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. 
+ decode_q = get_dcp_group().all_gather(decode_q, dim=1) + # call decode attn + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer + ) + + # recorect dcp attn_out with lse. + if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + + # v_up projection + self._v_up_proj(attn_out, out=output[:num_decode_tokens]) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 8a17d3a492783..11e06cc6daac7 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,18 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) @@ -21,12 +27,12 @@ logger = init_logger(__name__) class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): # enable full CUDA Graph support for decode-only capture - cudagraph_support: ClassVar[ - AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) class CutlassMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -39,13 +45,16 @@ class CutlassMLABackend(MLACommonBackend): def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [128] + class SM100Workspace: - def __init__(self, initial_workspace_size): - self._workspace_buf = torch.empty(initial_workspace_size, - device="cuda", - dtype=torch.uint8) + self._workspace_buf = torch.empty( + initial_workspace_size, device="cuda", dtype=torch.uint8 + ) self._block_size = 128 # Forced to 128 @@ -57,8 +66,7 @@ class SM100Workspace: def get_buf(self): return self._workspace_buf - def ensure_size(self, attn_metadata: MLACommonMetadata, - num_kv_splits: int): + def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int): batch_size = attn_metadata.num_reqs max_seq_len = attn_metadata.max_query_len @@ -66,7 +74,8 @@ class SM100Workspace: max_seq_len * self._block_size, batch_size, self._sm_count, - num_kv_splits=num_kv_splits) + num_kv_splits=num_kv_splits, + ) if self._workspace_buf.shape[0] < workspace_size: self._workspace_buf.resize_(workspace_size) @@ -74,57 +83,63 @@ class SM100Workspace: g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB +MAX_HEADS = 128 + class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - 
kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + q_pad_num_heads=MAX_HEADS, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "CutlassMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "CutlassMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "CutlassMLA V1 with FP8 KV cache not yet supported") - - self._use_old_cutlass_mla = False - force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) - if force_old_cutlass: - logger.warning_once("Forcing old cutlass mla kernel") - self._use_old_cutlass_mla = True + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CutlassMLAImpl" + ) # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging # issues. 
In case the code hangs, use: # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning_once("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -142,15 +157,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): workspace: torch.Tensor, sm_scale: float, num_kv_splits: int, - ) -> torch.Tensor: - assert (q_nope.ndim == 3 - ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" - assert ( - q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}" - assert ( - kv_c_and_k_pe_cache.ndim == 3 - ), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( - kv_c_and_k_pe_cache.ndim) + ) -> tuple[torch.Tensor, torch.Tensor]: + assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}" + assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}" + assert kv_c_and_k_pe_cache.ndim == 3, ( + "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( + kv_c_and_k_pe_cache.ndim + ) + ) B_q, H, D_q_nope = q_nope.shape B_q_2, H_2, D_q_pe = q_pe.shape @@ -166,39 +180,39 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): MAX_HEADS = 128 assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" - if H < MAX_HEADS: - q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope)) - q_nope_padded[:, :H] = q_nope - q_nope = q_nope_padded - - q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe)) - q_pe_padded[:, :H] = q_pe - q_pe = q_pe_padded assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape assert B_block_table == B_q - assert (block_num - > 0), f"block num must be greater than 0, got {block_num}" + assert block_num > 0, f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - # TODO(kaixih@nvidia): support fp8 - assert q_nope.dtype in ( - torch.float16, - torch.bfloat16, - ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." + assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}." + ) assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype - assert ( - seq_lens.dtype == torch.int32 - ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." - assert ( - page_table.dtype == torch.int32 - ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." + assert seq_lens.dtype == torch.int32, ( + f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." + ) + assert page_table.dtype == torch.int32, ( + f"page_table.dtype needs to be int32 but got {page_table.dtype}." 
+ ) - out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) + dtype = ( + torch.bfloat16 + if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype + ) + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) + lse = ( + torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode + else torch.Tensor() + ) ops.sm100_cutlass_mla_decode( out, + lse, q_nope, q_pe, kv_c_and_k_pe_cache, @@ -208,83 +222,44 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - return out[:, :H].contiguous() - def _sm100_forward_decode( + if H < MAX_HEADS: + # Extract the subsets of the outputs + lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse + out = out[:, :H] + + return out, lse + + def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. - # TODO: Check if we really need it - q_nope = q_nope.clone() - q_pe = q_pe.clone() + o, lse = self._sm100_cutlass_mla_decode( + q_nope, + q_pe, + kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, + self._num_kv_splits, + ) - o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, - self._workspace.get_buf(), - self.scale, self._num_kv_splits) - - return self._v_up_proj(o) - - # TODO: Currently we leave it here only for backup in case something is - # wrong with the new SM100 CUTLASS MLA kernel - def _old_forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - assert attn_metadata.decode is not None - - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") - - B = q_nope.shape[0] - - o = torch.empty((B, self.num_heads, self.kv_lora_rank), - dtype=q_nope.dtype, - device=q_nope.device) - - # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. 
- q_nope = q_nope.clone() - q_pe = q_pe.clone() - - ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, self.scale) - - return self._v_up_proj(o) - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - layer: AttentionLayer, - ) -> torch.Tensor: - if self._use_old_cutlass_mla: - # TODO: Remove the old cutlass MLA kernel after more extensive - # testing - return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) - - return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) + return o, (lse if self.need_to_return_lse_for_decode else None) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py new file mode 100644 index 0000000000000..c043990ffcc61 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,298 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + get_flash_attn_version, +) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata + +logger = init_logger(__name__) + + +class FlashAttnMLABackend(MLACommonBackend): + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_MLA" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int + scheduler_metadata: Optional[torch.Tensor] = None + max_num_splits: int = 0 + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: int = 512 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata + ) + self.max_num_splits = 0 # No upper bound on the number of splits. 
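Several of the MLA decode implementations touched in this diff (CUTLASS MLA, FlashAttn MLA, FlashMLA) now set `can_return_lse_for_decode` and hand back an `(output, lse)` pair so that decode-context-parallel (DCP) ranks can recombine their shard-local attention results. The sketch below is not vLLM's `cp_lse_ag_out_rs`; it is a minimal, hedged PyTorch illustration of the standard log-sum-exp merge identity such a recombination relies on, with hypothetical helper names.

```python
# Hedged sketch: merge attention outputs computed over disjoint KV shards
# using each shard's log-sum-exp (LSE). Illustrative only, not a vLLM API.
import torch


def merge_partial_attention(outs: list[torch.Tensor],
                            lses: list[torch.Tensor]) -> torch.Tensor:
    # outs[i]: [B, H, D] attention output computed over KV shard i only
    # lses[i]: [B, H]    log-sum-exp of that shard's scaled attention logits
    lse_stack = torch.stack(lses)                  # [S, B, H]
    lse_total = torch.logsumexp(lse_stack, dim=0)  # [B, H]
    weights = torch.exp(lse_stack - lse_total)     # [S, B, H]
    out_stack = torch.stack(outs)                  # [S, B, H, D]
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)


# Sanity check against unsharded single-query attention.
torch.manual_seed(0)
B, H, D, T = 2, 4, 8, 16
q = torch.randn(B, H, 1, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
scale = D ** -0.5
ref = ((q @ k.transpose(-1, -2)) * scale).softmax(-1) @ v   # [B, H, 1, D]

outs, lses = [], []
for ks, vs in ((k[:, :, :T // 2], v[:, :, :T // 2]),
               (k[:, :, T // 2:], v[:, :, T // 2:])):
    s = (q @ ks.transpose(-1, -2)) * scale                   # [B, H, 1, T/2]
    lses.append(torch.logsumexp(s, dim=-1).squeeze(-1))      # [B, H]
    outs.append((s.softmax(-1) @ vs).squeeze(2))              # [B, H, D]

assert torch.allclose(merge_partial_attention(outs, lses), ref.squeeze(2), atol=1e-5)
```

Weighting each partial output by `exp(lse_i - lse_total)` rescales the shard-local softmax denominators to the global one, which is why the per-decode LSE has to be returned by the backend kernels when `dcp_world_size > 1`.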
+ self.fa_aot_schedule = get_flash_attn_version() == 3 + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + + if self.use_full_cuda_graph and self.fa_aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + + if self.max_cudagraph_size > 992: + # This condition derives from FA3's internal heuristic. + # TODO(woosuk): Support larger cudagraph sizes. + raise ValueError( + "Capture size larger than 992 is not supported for full cuda graph." + ) + + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + def _schedule_decode( + self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads * self.dcp_world_size, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + cache_seqlens=seqlens, + qkv_dtype=self.kv_cache_spec.dtype, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + num_splits=self.max_num_splits, + ) + return None + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_device.max().item() + + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + # Ensure the persistent buffer is large enough + assert n <= self.scheduler_metadata.shape[0], ( + f"Scheduler metadata size {n} exceeds buffer size " + + f"{self.scheduler_metadata.shape[0]}" + ) + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + if num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + + return FlashAttnMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, + max_num_splits=max_num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device" + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashAttnMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap" + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl" + ) + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttnMLA V1 with FP8 KV cache not yet supported" + ) + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashAttnMLAMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashAttention MLA not yet supported") + + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] + + # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the + # kernel uses this to calculate grid dimensions. Ensure it's at least 1 + # to prevent invalid grid configuration during graph capture. 
+ max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + + attn_out = flash_attn_varlen_func( + q=q_pe, + k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=max_seqlen_q, + cu_seqlens_q=attn_metadata.decode.query_start_loc, + max_seqlen_k=attn_metadata.decode.max_seq_len, + seqused_k=attn_metadata.decode.seq_lens, + block_table=attn_metadata.decode.block_table, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=self.need_to_return_lse_for_decode, + fa_version=3, # only version 3 is supported + scheduler_metadata=attn_metadata.decode.scheduler_metadata, + num_splits=attn_metadata.decode.max_num_splits, + cp_world_size=self.dcp_world_size, + cp_rank=self.dcp_rank, + cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, + ) + + if self.need_to_return_lse_for_decode: + o, lse = attn_out + # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] + return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] + else: + o = attn_out + return o, None diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py new file mode 100644 index 0000000000000..206f96ea366a4 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import ClassVar, Optional, Union + +import torch +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla + +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) +from vllm.v1.attention.backends.utils import AttentionCGSupport + +logger = init_logger(__name__) + +FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): + # enable spec-as-decode optimization + supports_uniform_spec_as_decode: ClassVar[bool] = True + + # enable full CUDA Graph support for decode-only capture + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + +class FlashInferMLABackend(MLACommonBackend): + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + @staticmethod + def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: + return FlashInferMLAMetadataBuilder + + +g_fi_workspace = torch.zeros( + FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device="cuda", +) + + +class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + 
"alibi_slopes, sliding_window, logits_soft_cap" + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl" + ) + + self._workspace_buffer = g_fi_workspace + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if isinstance(q, tuple): + q_nope, q_pe = q + q = torch.cat([q_nope, q_pe], dim=-1) + + # trtllm API requires extra dimension q_len_per_request for MTP + if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0: + logger.warning_once( + """FlashInferMLAImpl got a query of uneven length. + This usually indicates an issue in batch reordering + or incorrect setup in dummy_run.""" + ) + q = q.unsqueeze(1) + else: + q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1]) + + if self.bmm1_scale is None: + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + o = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + workspace_buffer=self._workspace_buffer, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=attn_metadata.decode.block_table, + seq_lens=attn_metadata.decode.seq_lens, + max_seq_len=attn_metadata.max_seq_len, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + ) + + # Flatten the output for consistent shape + o = o.view(-1, o.shape[-2], o.shape[-1]) + + # TODO: Return LSE pending support from Flashinfer API: + # https://github.com/flashinfer-ai/flashinfer/pull/1566 + return o, None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1c50144d47900..e0f4a7f0382b3 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,21 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_dense_supported, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -24,10 +28,9 @@ logger = init_logger(__name__) class FlashMLABackend(MLACommonBackend): - @staticmethod def 
get_name() -> str: - return "FLASHMLA_VLLM_V1" + return "FLASHMLA" @staticmethod def get_metadata_cls() -> type["FlashMLAMetadata"]: @@ -41,6 +44,10 @@ class FlashMLABackend(MLACommonBackend): def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [64] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): @@ -54,17 +61,22 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) - self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None @@ -83,15 +95,23 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): self.cg_buf_num_splits = torch.empty( (vllm_config.scheduler_config.max_num_seqs + 1), device=self.device, - dtype=torch.int32) + dtype=torch.int32, + ) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: - tile_scheduler_metadata, num_splits = \ - get_mla_metadata( - seq_lens, + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = get_mla_metadata( + seq_lens_device, self.num_q_heads, - 1, # MQA for the decode path + 1, # MQA for the decode path ) # TODO: we can disambiguate between decode and mixed-prefill decode here @@ -104,8 +124,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): sm_parts = tile_scheduler_metadata.size(0) # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) - tile_scheduler_metadata_view = \ - self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[ + :sm_parts + ] tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) tile_scheduler_metadata = tile_scheduler_metadata_view @@ -123,70 +144,85 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): return FlashMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: 
Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) - assert is_flashmla_supported(), \ - "FlashMLA is not supported on this device" + is_supported, reason = is_flashmla_dense_supported() + assert is_supported, reason unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl" + ) def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # TODO: (zyongye) decode function for mla here assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) + if type(q) is tuple: + q = torch.cat(q, dim=-1) - o, _ = flash_mla_with_kvcache( - q=q, + assert isinstance(q, torch.Tensor) + o, lse = flash_mla_with_kvcache( + q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode. 
- tile_scheduler_metadata, + tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, @@ -194,4 +230,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): descale_k=layer._k_scale.reshape(1), ) - return self._v_up_proj(o) + return o, lse diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py new file mode 100644 index 0000000000000..144e46d5e9537 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -0,0 +1,539 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, +) +from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.flashmla import ( + flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata, +) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer + +logger = init_logger(__name__) +""" +NOTE: FlashMLA Sparse uses an fp8 cache with the following format + +In the "FP8 with scale" format, each token's KV cache is 656 Bytes, +structured as: +- **First 512 bytes:** The "quantized NoPE" part, containing 512 + `float8_e4m3` values. +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. + The first `float32` is the scale for the first 128 `float8_e4m3` values, + the second for the next 128, and so on. +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This + part is not quantized for accuracy. +""" + + +class FlashMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + +@dataclass +class FlashMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + block_size: int = 64 + topk_tokens: int = 2048 + + @dataclass + class FP8KernelMetadata: + scheduler_metadata: Optional[torch.Tensor] + num_splits: torch.Tensor + dummy_block_table: torch.Tensor + cache_lens: torch.Tensor + + fp8_extra_metadata: Optional[FP8KernelMetadata] = None + + +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + ) + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + +@dataclass +class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + cache_config = vllm_config.cache_config + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 + ) + self.max_model_len_tensor = torch.tensor( + [self.model_config.max_model_len], device=device, dtype=torch.int32 + ) + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty( + (1, 1), dtype=torch.int32, device=self.device + ) + + # Equation taken from FlashMLA/csrc/pybind.cpp + h_q, h_k = self.num_heads, 1 + s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest + max_num_sm_parts = int( + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) + ) + if current_platform.is_device_capability(100): + max_num_sm_parts *= 2 + self.tile_scheduler_metadata_buffer = torch.empty( + # TileSchedulerMetaDataSize = 8 + # see: FlashMLA/csrc/params.h + (max_num_sm_parts, 8), + dtype=torch.int32, + device=device, + ) + self.num_splits_buffer = torch.empty( + # We pack all the tokens into one batch for sparse attention. + # Otherwise, we can exceed the sm of `get_mla_metadata`. 
+ (2,), + dtype=torch.int32, + device=device, + ) + self.req_id_per_token_buffer = torch.empty( + (vllm_config.scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: + num_tokens = common_attn_metadata.num_actual_tokens + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata = None + if self.use_fp8_kv_cache: + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor, + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + # Copy to persistent buffer for full-CG support + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + self.num_splits_buffer.copy_(num_splits) + + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=self.num_splits_buffer, + # cache_lens and block_table are basically unused in sparse case + # but the decode kernel will treat -1 and indices >= cache_lens + # as invalid so we make sure cache_lens is large enough to not + # accidentally mark indices invalid, we will use -1 exclusively + # to mark invalid indices + cache_lens=self.max_model_len_tensor, + dummy_block_table=self.dummy_block_table, + ) + + metadata = FlashMLASparseMetadata( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + block_table=common_attn_metadata.block_table_tensor, + req_id_per_token=req_id_per_token, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + fp8_extra_metadata=fp8_extra_metadata, + ) + return metadata + + +class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer = indexer.topk_indices_buffer + self.padding = 128 if current_platform.is_device_capability(100) else 64 + + def _forward_bf16_kv( + self, + q: 
torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) + + # NOTE(Chen): kernel requires num_local_head to be a multiple of + # 64 on hopper and 128 on blackwell + if self.num_heads % self.padding != 0: + assert self.padding % self.num_heads == 0 + logger.warning_once( + f"padding num_heads to {self.padding} \ + due to sparse attn kernel requirement" + ) + q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) + q_padded[:, : self.num_heads, :] = q + q = q_padded + + topk_indices = topk_indices.view(num_tokens, 1, -1) + output = flash_mla_sparse_prefill( + q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale + )[0] + output = output[:, : self.num_heads, :] + return output + + def _forward_fp8_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + assert attn_metadata.fp8_extra_metadata is not None + extra_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = flash_mla_with_kvcache( + q=q.unsqueeze(0), # unsqueeze to add batch_dim + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=extra_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=extra_metadata.cache_lens, + tile_scheduler_metadata=extra_metadata.scheduler_metadata, + num_splits=extra_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim + softmax_scale=self.softmax_scale, + ) + + return _attn_out + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode + + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLACommonImpl" + ) + + if attn_metadata is None: + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
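For reference, the mapping implemented by the `_convert_req_index_to_global_index_kernel` Triton kernel above can be expressed in a few lines of plain PyTorch. This is a hedged, unoptimized sketch meant only to spell out the semantics (per-request top-k token indices mapped to global paged-KV slot indices, with `-1` propagated for invalid entries); it is not a drop-in replacement for the Triton path.

```python
# Hedged reference implementation of the index conversion, assuming the same
# shapes and dtypes as triton_convert_req_index_to_global_index.
import torch


def convert_req_index_to_global_index_ref(
    req_id: torch.Tensor,         # int32 [num_tokens]
    block_table: torch.Tensor,    # int32 [num_requests, max_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, topk]
    block_size: int = 64,
) -> torch.Tensor:
    tok = token_indices.long()
    tok_safe = tok.clamp(min=0)            # keep arithmetic well-defined for -1 entries
    block_id = tok_safe // block_size
    in_block = tok_safe % block_size
    max_blocks = block_table.shape[1]
    valid = (tok >= 0) & (block_id < max_blocks)
    block_id = block_id.clamp(max=max_blocks - 1)   # avoid OOB gather; masked below
    base = block_table.long()[req_id.long().unsqueeze(1), block_id]
    out = base * block_size + in_block
    return torch.where(valid, out, torch.full_like(out, -1)).int()
```

In other words, `out[t, j] == block_table[req_id[t], idx // block_size] * block_size + idx % block_size` for a valid index `idx = token_indices[t, j]`, matching the docstring of `triton_convert_req_index_to_global_index`.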
+ + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + # TODO: handle index / kv_cache correctly + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + + q = torch.cat([ql_nope, q_pe], dim=-1) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if self.kv_cache_dtype != "fp8_ds_mla": + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) + else: + attn_out = self._forward_fp8_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) + + self._v_up_proj(attn_out, out=output[:num_actual_toks]) + return output diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py new file mode 100644 index 0000000000000..b8a232c8447bb --- /dev/null +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +import torch + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + MultipleOf, +) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) + +logger = init_logger(__name__) + + +class DeepseekV32IndexerBackend(AttentionBackend): + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return DeepseekV32IndexerMetadata + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 128] + + @staticmethod + def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: + return DeepseekV32IndexerMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + assert num_kv_heads == 1 + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + return (0, 1, 2) + + @classmethod + def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]: + return [64] + + +@dataclass +class DeepseekV32IndexerPrefillChunkMetadata: + block_table: torch.Tensor + cu_seqlen_ks: torch.Tensor + cu_seqlen_ke: torch.Tensor + cu_seq_lens: torch.Tensor + total_seq_lens: int + token_start: int + token_end: int + num_reqs: int + + +@dataclass +class DeepseekV32IndexerPrefillMetadata: + chunks: list[DeepseekV32IndexerPrefillChunkMetadata] + + +@dataclass +class DeepSeekV32IndexerDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + decode_lens: torch.Tensor + requires_padding: bool + schedule_metadata: torch.Tensor + + +@dataclass 
+class DeepseekV32IndexerMetadata: + # FIXME (zyongye) + # hacky way to access the data now, need to be in chunked meta + seq_lens: torch.Tensor + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + # The dimension of the attention heads + head_dim: int + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None + prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + + +# TODO (zyongye) optimize this, this is now vibe coded +def kv_spans_from_batches( + start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + start_seq_loc: 1D long tensor [B+1], cumulative counts of + selected tokens per batch. + Example: [0, 2, 4, 7] -> + batch sizes (selected) [2, 2, 3], N=7 tokens total. + seq_len_per_batch: 1D long tensor [B], + full sequence length (KV length) of each batch. + Example: [5, 9, 4]. + + Returns: + start_tensor: 1D long tensor [N], start offset in the + concatenated KV cache for each token's batch. + end_location: 1D long tensor [N], + **exclusive** end = start + token's local position. + (So the attended KV slice is kv[start:end].) + + Assumes each batch contributes its full `seq_len_per_batch[i]` + keys to the KV cache, andthe selected tokens within a batch + are the **last** `counts[i]` positions of that sequence. + """ + q = start_seq_loc.to(dtype=torch.long) + L = seq_len_per_batch.to(dtype=torch.long) + assert q.dim() == 1 and L.dim() == 1 + assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" + + # Selected tokens per batch and totals + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens + B = L.numel() + + if N == 0: + return ( + torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) + + # KV start offsets per batch in the concatenated KV cache + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + + # For each selected token, which batch does it belong to? + batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N] + + # Map batch KV start to each token + start_tensor = kv_starts_per_batch[batch_id] # [N] + + # End-align local positions inside each batch: + # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] + # position within the selected block: 1..counts[b] + pos_within = ( + torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1 + ) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end + + return start_tensor.int().to(device), end_location.int().to(device) + + +def get_max_prefill_buffer_size(vllm_config: VllmConfig): + max_model_len = vllm_config.model_config.max_model_len + # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. + # May be tuned later. 
+ return max_model_len * 2 + + +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int +) -> list[tuple[int, int]]: + """ + Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) + such that the total sequence length of each chunk is less than the + maximum prefill buffer size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests. + max_prefill_buffer_size: The maximum prefill buffer size. + reqs_start: The start index of the prefill requests. + + Returns: + A list of tuples of (reqs_start, reqs_end). + """ + chunk_seq_ids = [] + total_seq_lens = 0 + for i in range(reqs_start, len(seq_lens_cpu)): + cur_seq_len = seq_lens_cpu[i].item() + assert cur_seq_len <= max_prefill_buffer_size + total_seq_lens += cur_seq_len + if total_seq_lens > max_prefill_buffer_size: + chunk_seq_ids.append((reqs_start, i)) + reqs_start = i + total_seq_lens = cur_seq_len + if total_seq_lens > 0: + chunk_seq_ids.append((reqs_start, len(seq_lens_cpu))) + return chunk_seq_ids + + +class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + + reorder_batch_threshold: int = 1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + scheduler_config = self.vllm_config.scheduler_config + # NOTE(Chen):an estimated max size of flattened_kv. Need to double check. + self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) + self.num_speculative_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config + else 0 + ) + # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 + self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) + + props = torch.cuda.get_device_properties(self.device) + sm_count = props.multi_processor_count + self.num_sms = sm_count + + self.decode_lens_buffer = torch.empty( + (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device + ) + + # See: DeepGMM/csrc/apis/attention.hpp + self.scheduler_metadata_buffer = torch.empty( + (self.num_sms + 1, 2), dtype=torch.int32, device=self.device + ) + + def build_one_prefill_chunk( + self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table + ): + prefill_query_start_loc = ( + query_start_loc_cpu[reqs_start : reqs_end + 1] + - query_start_loc_cpu[reqs_start] + ) + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device + ) + token_start = query_start_loc_cpu[reqs_start].item() + token_end = query_start_loc_cpu[reqs_end].item() + total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() + assert total_seq_lens <= self.max_prefill_buffer_size + cu_seq_lens = ( + torch.cat( + [ + torch.zeros(1, dtype=torch.int32), + seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0), + ] + ) + .to(torch.int32) + .to(self.device) + ) + return DeepseekV32IndexerPrefillChunkMetadata( + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + cu_seq_lens=cu_seq_lens, + total_seq_lens=total_seq_lens, + block_table=block_table[reqs_start:reqs_end], + token_start=token_start, + token_end=token_end, + num_reqs=reqs_end - reqs_start, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> DeepseekV32IndexerMetadata: + num_reqs = common_attn_metadata.num_reqs + num_tokens = 
common_attn_metadata.num_actual_tokens + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + chunk_seq_ids = split_prefill_chunks( + common_attn_metadata.seq_lens_cpu, + self.max_prefill_buffer_size, + num_decodes, + ) + chunks = [ + self.build_one_prefill_chunk( + reqs_start, + reqs_end, + query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu, + common_attn_metadata.block_table_tensor, + ) + for reqs_start, reqs_end in chunk_seq_ids + ] + prefill_metadata = DeepseekV32IndexerPrefillMetadata( + chunks=chunks, + ) + + decode_metadata = None + if num_decodes > 0: + torch.diff( + common_attn_metadata.query_start_loc[: num_decodes + 1], + out=self.decode_lens_buffer[:num_decodes], + ) + decode_lens = self.decode_lens_buffer[:num_decodes] + decode_lens_cpu = torch.diff( + common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] + ) + + # Use CPU to avoid GPU sync; breaking async scheduling + requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() + + seq_lens = common_attn_metadata.seq_lens[:num_decodes] + + self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( + seq_lens, self.kv_cache_spec.block_size, self.num_sms + ) + decode_metadata = DeepSeekV32IndexerDecodeMetadata( + block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], + seq_lens=common_attn_metadata.seq_lens[:num_decodes], + decode_lens=decode_lens, + requires_padding=requires_padding, + schedule_metadata=self.scheduler_metadata_buffer, + ) + + attn_metadata = DeepseekV32IndexerMetadata( + seq_lens=common_attn_metadata.seq_lens, + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + head_dim=128, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + # if get_tensor_model_parallel_rank() == 0: + # logger.info(f"attn_metadata: {attn_metadata}") + return attn_metadata diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 870cc600388e7..195b05e0a301f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -11,29 +11,25 @@ from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv -# yapf conflicts with isort for this docstring -# yapf: disable -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, 
+ MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec -# yapf: enable - def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA + return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA class AiterMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: - return "ROCM_AITER_MLA_VLLM_V1" + return "ROCM_AITER_MLA" @staticmethod def get_impl_cls() -> type["AiterMLAImpl"]: @@ -68,19 +64,28 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - AiterMLAMetadata) - assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ - "only supports block size 1." + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata + ) + assert self.kv_cache_spec.block_size == 1, ( + "AITER MLA only supports block size 1." + ) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req @@ -89,122 +94,141 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # so we can only use the persistent buffer if a cudagraph is actually # being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=device) - self.paged_kv_indices = torch.zeros(max_num_pages, - dtype=torch.int32, - device=device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=device + ) - self.qo_indptr = torch.arange(0, - max_num_reqs + 1, - dtype=torch.int32, - device=device) + self.qo_indptr = torch.arange( + 0, max_num_reqs + 1, dtype=torch.int32, device=device + ) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device - num_reqs = seq_lens.size(0) + num_reqs = seq_lens_device.size(0) - mask = (torch.arange(block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) + mask = torch.arange( + block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device + ).unsqueeze(0) < block_table_bounds.unsqueeze(1) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) + paged_kv_last_page_len = seq_lens_device % page_size + paged_kv_last_page_len = torch.where( + paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len + ) - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, device=device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32) - ]) + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32), + ] + ) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - num_actual_pages = paged_kv_indices.size(0) - self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, - non_blocking=True) + self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, - non_blocking=True) - self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) - paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs] + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True + ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] self.paged_kv_last_page_len[:num_reqs].copy_( - paged_kv_last_page_len, non_blocking=True) + paged_kv_last_page_len, non_blocking=True + ) self.paged_kv_last_page_len[num_reqs:].fill_(1) paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] 
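# A minimal, self-contained sketch (not part of the patch; values are made up)
# of the paged-KV bookkeeping derived above: from a block table and per-request
# sequence lengths, the builder produces paged_kv_indices, paged_kv_indptr and
# paged_kv_last_page_len. AITER MLA itself asserts page_size (block_size) == 1.
import torch

page_size = 1
seq_lens = torch.tensor([3, 5])                    # tokens per request
block_table = torch.tensor([[7, 8, 9, 0, 0],       # page ids per request,
                            [2, 3, 4, 5, 6]])      # padded with unused slots
bounds = (seq_lens + page_size - 1) // page_size   # pages actually used per request
mask = torch.arange(block_table.size(1)).unsqueeze(0) < bounds.unsqueeze(1)
paged_kv_indices = block_table[mask]               # tensor([7, 8, 9, 2, 3, 4, 5, 6])
paged_kv_indptr = torch.cat(
    [torch.zeros(1, dtype=bounds.dtype), bounds.cumsum(0)]
)                                                  # tensor([0, 3, 8])
last = seq_lens % page_size
paged_kv_last_page_len = torch.where(last == 0, page_size, last)  # tensor([1, 1])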
- qo_indptr = self.qo_indptr[:1 + num_reqs] + qo_indptr = self.qo_indptr[: 1 + num_reqs] else: - qo_indptr = torch.arange(0, - num_reqs + 1, - step=1, - dtype=torch.int32, - device=device) + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - qo_indptr=qo_indptr) + qo_indptr=qo_indptr, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + ) return attn_metadata class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - assert (num_heads == 16 or num_heads == 128), ( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + assert num_heads == 16 or num_heads == 128, ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" - "Try adjusting tensor_parallel_size value.") + "Try adjusting tensor_parallel_size value." 
+ ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): output = self.flash_attn_varlen_func( q=q, k=k, @@ -218,33 +242,38 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + assert isinstance(q, torch.Tensor) + B = q.shape[0] + o = torch.zeros( + B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.decode.qo_indptr, max_seqlen_qo, - attn_metadata.decode.paged_kv_indptr, - attn_metadata.decode.paged_kv_indices, - attn_metadata.decode.paged_kv_last_page_len) + aiter_mla_decode_fwd( + q, + kv_buffer, + o, + self.scale, + attn_metadata.decode.qo_indptr, + max_seqlen_qo, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len, + ) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f2974ed668d99..3b6718c48d09a 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,30 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) class TritonMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: - return "TRITON_MLA_VLLM_V1" + return "TRITON_MLA" @staticmethod def get_impl_cls() -> type["TritonMLAImpl"]: @@ -32,56 
+36,67 @@ class TritonMLABackend(MLACommonBackend): class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported") + "TritonMLA V1 with FP8 KV cache not yet supported" + ) self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.triton_fa_func = triton_attention if HAS_TRITON else None - def _flash_attn_varlen_diff_headdims_rocm(self, - q, - k, - v, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims_rocm( + self, q, k, v, softmax_scale=None, **kwargs + ): assert self.triton_fa_func is not None # Triton Attention requires a padded V - padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) # The output of triton_attention is a tuple of # [output_tensor, encoded_softmax] where encoded_softmax is always None output_tensor, _ = self.triton_fa_func( @@ -100,18 +115,17 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): return output_tensor - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): - if current_platform.is_rocm() \ - and self.use_triton_flash_attn \ - and not return_softmax_lse: + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + if ( + current_platform.is_rocm() + and self.use_triton_flash_attn + and not return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims_rocm( - q, k, v, softmax_scale=softmax_scale, **kwargs) + q, k, v, softmax_scale=softmax_scale, **kwargs + ) else: return super()._flash_attn_varlen_diff_headdims( q, @@ -119,38 +133,39 @@ class 
TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): v, return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, - **kwargs) + **kwargs, + ) def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + if type(q) is tuple: + q = torch.cat(q, dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] + q_num_heads = q.shape[1] + o = torch.zeros( + B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) + lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) num_kv_splits = 4 # TODO: heuristic # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( ( B, - self.num_heads, + q_num_heads, num_kv_splits, # NOTE(lucas) idk why the +1 is here but sglang has it so we # just mirror that @@ -162,13 +177,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) + decode_attention_fwd( + q, + kv_c_and_k_pe_cache, + kv_c_cache, + o, + lse, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, + attn_logits, + num_kv_splits, + self.scale, + PAGE_SIZE, + ) - return self._v_up_proj(o) + return o, lse diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index fd97db0abb84f..1622f852a9522 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -6,9 +6,12 @@ from typing import Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, AttentionType) -from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionType, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, next_power_of_2 @@ -32,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { } try: - import tpu_commons # noqa: F401 + import tpu_inference # noqa: F401 except ImportError: # Lazy import torch_xla import torch_xla.core.xla_builder as xb @@ -42,52 +45,65 @@ except ImportError: from torch_xla.experimental.custom_kernel import XLA_LIB @requires_jax - def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, num_slices_per_block: int): + def kv_cache_update_op_impl( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ): from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + 
new_kv_cache = xb.call_jax( kv_cache_update, - (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) + (kv, slot_mapping, kv_cache, num_kv_update_slices), + {"page_size": page_size, "num_slices_per_block": num_slices_per_block}, + ) return new_kv_cache - XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ - "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ - "int num_slices_per_block)" \ - "-> Tensor", ) + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," + "int num_slices_per_block)" + "-> Tensor", + ) @impl(XLA_LIB, "kv_cache_update_op", "XLA") - def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) + def kv_cache_update_op_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl( + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_block, + ) return new_kv_cache @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") - def kv_cache_update_op_non_xla(kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: + def kv_cache_update_op_non_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: return kv_cache class PallasAttentionBackend(AttentionBackend): - @staticmethod def get_name() -> str: - return "PALLAS_VLLM_V1" + return "PALLAS" @staticmethod def get_impl_cls() -> type["PallasAttentionBackendImpl"]: @@ -97,19 +113,17 @@ class PallasAttentionBackend(AttentionBackend): def get_metadata_cls() -> type["PallasMetadata"]: return PallasMetadata - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) @staticmethod @@ -126,10 +140,12 @@ class PallasAttentionBackend(AttentionBackend): # we simply make sure that the size is smaller than half of SMEM capacity. 
@staticmethod def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = (1024 * 1024 // 2 // - vllm_config.scheduler_config.max_num_seqs // 4) - min_page_size = cdiv(vllm_config.model_config.max_model_len, - max_num_page_per_req) + max_num_page_per_req = ( + 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4 + ) + min_page_size = cdiv( + vllm_config.model_config.max_model_len, max_num_page_per_req + ) min_page_size = 1 << (min_page_size - 1).bit_length() return min_page_size @@ -150,8 +166,7 @@ class PallasAttentionBackend(AttentionBackend): # handle VREG spills. if vllm_config.model_config.max_model_len > 8192: return 16 - page_size = next_power_of_2( - vllm_config.model_config.max_model_len) // 16 + page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16 if page_size <= 16: return 16 if page_size >= 256: @@ -180,7 +195,6 @@ class PallasMetadata: class PallasAttentionBackendImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -207,15 +221,18 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("Alibi slopes is not supported.") if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl" + ) self.kv_cache_quantized_dtype = None if kv_cache_dtype != "auto": self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( - kv_cache_dtype.lower().strip()) + kv_cache_dtype.lower().strip() + ) def forward( self, @@ -235,7 +252,8 @@ class PallasAttentionBackendImpl(AttentionImpl): query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads * 2, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -243,7 +261,8 @@ class PallasAttentionBackendImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl") + " for PallasAttentionBackendImpl" + ) # For determine_available_memory case. if kv_cache.numel() == 0: @@ -256,15 +275,18 @@ class PallasAttentionBackendImpl(AttentionImpl): key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = cdiv( - self.head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0) + query, (0, padded_head_size - self.head_size), value=0.0 + ) key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0) + key, (0, padded_head_size - self.head_size), value=0.0 + ) value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0) + value, (0, padded_head_size - self.head_size), value=0.0 + ) if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: # Write input keys and values to the KV cache. 
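# A small sketch (illustrative values only, not taken from the diff) of the
# head-size padding used in the Pallas backend above: head sizes are rounded up
# to a multiple of TPU_HEAD_SIZE_ALIGNMENT (assumed to be 128 here) and q/k/v
# are zero-padded on the last dimension before the kernel call.
import torch

def cdiv(a: int, b: int) -> int:
    # mirrors vllm.utils.cdiv
    return -(-a // b)

TPU_HEAD_SIZE_ALIGNMENT = 128   # assumption for this sketch
head_size = 80
padded_head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

query = torch.randn(16, 8, head_size)  # [num_tokens, num_heads, head_size]
query = torch.nn.functional.pad(query, (0, padded_head_size - head_size), value=0.0)
assert query.shape == (16, 8, 128)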
@@ -283,9 +305,9 @@ class PallasAttentionBackendImpl(AttentionImpl): ) if self.kv_cache_quantized_dtype is not None and ( - layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0): - raise ValueError( - "k_scale_float and v_scale_float must be non-zero") + layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0 + ): + raise ValueError("k_scale_float and v_scale_float must be non-zero") output = torch.ops.xla.ragged_paged_attention( query, kv_cache, @@ -308,7 +330,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, :self.head_size] + output = output[:, :, : self.head_size] return output.reshape(num_tokens, hidden_size) @@ -324,17 +346,16 @@ def write_to_kv_cache( k_scale: float = 1.0, v_scale: float = 1.0, ) -> None: - """ Write the key and values to the KV cache. + """Write the key and values to the KV cache. Args: key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] + kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size] num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT if kv_cache_quantized_dtype is not None: dtype_info = torch.finfo(kv_cache_quantized_dtype) @@ -346,15 +367,19 @@ def write_to_kv_cache( value = torch.clamp(value, dtype_info.min, dtype_info.max) value = value.to(kv_cache_quantized_dtype) - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, - head_size) + kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, slot_mapping, kv_cache, num_kv_update_slices, page_size, - num_slices_per_kv_cache_update_block) + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_kv_cache_update_block, + ) # NOTE: the in-place copy will be optimized away by XLA compiler. 
kv_cache.copy_(new_kv_cache) @@ -392,15 +417,18 @@ def get_dtype_packing(dtype): if 32 % bits != 0: raise ValueError( f"The bit width must be divisible by 32, but got bits={bits}, " - "dtype={dtype}") + "dtype={dtype}" + ) return 32 // bits -def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype) -> int: +def get_page_size_bytes( + block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype +) -> int: """Returns the size in bytes of one page of the KV cache.""" - padded_head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) num_combined_kv_heads = num_kv_heads * 2 # NOTE: for the implicit padding in XLA @@ -408,5 +436,6 @@ def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) - return (block_size * num_combined_kv_heads * padded_head_size * - kv_cache_dtype_bits // 8) + return ( + block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8 + ) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 403ad8e88a958..82505f6281c0a 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1,19 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 @@ -43,55 +51,63 @@ if current_platform.is_rocm(): batch_idx = tl.program_id(0) block_idx = tl.program_id(1) - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2)) batch_query_start, batch_query_end = tl.split(batch_query_indexes) query_len = batch_query_end - batch_query_start if query_len <= 1: return - batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) batch_token_start, batch_token_end = tl.split(batch_token_indexes) seq_len = batch_token_end - batch_token_start if block_idx * BLOCK_SIZE < seq_len: - block_mask = (block_idx * BLOCK_SIZE + - tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len + block_mask = ( + block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None] + ) < seq_len - kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 + - block_idx).to(tl.int64) + kv_idx = tl.load( + block_table + batch_idx * block_table_stride_0 + block_idx + ).to(tl.int64) - kv_buffer_off = 
kv_idx * BLOCK_SIZE * E_DIM + tl.arange( - 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :] - k_vals = tl.load(k_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + kv_buffer_off = ( + kv_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) + k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if k_vals.dtype.is_fp8(): - k_vals = (k_vals.to(tl.float32) * - tl.load(k_scale)).to(output_dtype) + k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype) else: k_vals = k_vals.to(output_dtype) - v_vals = tl.load(v_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if v_vals.dtype.is_fp8(): - v_vals = (v_vals.to(tl.float32) * - tl.load(v_scale)).to(output_dtype) + v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype) else: v_vals = v_vals.to(output_dtype) - kv_values_off = batch_token_start * E_DIM + \ - block_idx * BLOCK_SIZE * E_DIM + \ - tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \ - tl.arange(0, E_DIM)[None, :] + kv_values_off = ( + batch_token_start * E_DIM + + block_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask) tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask) - def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, - k_cache, v_cache, max_seq_len, k_scale, v_scale, - output_dtype, total_tokens): + def vllm_layout_trans( + b_query_lens_loc, + b_seq_lens_loc, + block_table, + k_cache, + v_cache, + max_seq_len, + k_scale, + v_scale, + output_dtype, + total_tokens, + ): H_KV = v_cache.shape[2] D = v_cache.shape[3] BLOCK_SIZE = v_cache.shape[1] @@ -107,8 +123,7 @@ if current_platform.is_rocm(): device=v_cache.device, ) - grid = (block_table.shape[0], - (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) + grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) if output_dtype == torch.float16: output_dtype = tl.float16 @@ -117,19 +132,21 @@ if current_platform.is_rocm(): else: raise ValueError(f"Unsupported output dtype: {output_dtype}") - _vllm_layout_trans_kernel[grid](k_cache, - v_cache, - k_values, - v_values, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table.stride(0), - k_scale, - v_scale, - output_dtype=output_dtype, - E_DIM=H_KV * D, - BLOCK_SIZE=BLOCK_SIZE) + _vllm_layout_trans_kernel[grid]( + k_cache, + v_cache, + k_values, + v_values, + b_query_lens_loc, + b_seq_lens_loc, + block_table, + block_table.stride(0), + k_scale, + v_scale, + output_dtype=output_dtype, + E_DIM=H_KV * D, + BLOCK_SIZE=BLOCK_SIZE, + ) return k_values, v_values @@ -152,9 +169,18 @@ if current_platform.is_rocm(): ) -> torch.Tensor: if total_tokens == 0: total_tokens = int(cu_seqlens_k[-1].item()) - k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table, - k_cache, v_cache, max_seqlen_k, k_scale, - v_scale, q.dtype, total_tokens) + k, v = vllm_layout_trans( + cu_seqlens_q, + cu_seqlens_k, + block_table, + k_cache, + v_cache, + max_seqlen_k, + k_scale, + v_scale, + q.dtype, + total_tokens, + ) output = aiter.flash_attn_varlen_func( q=q, @@ -190,16 +216,17 @@ if current_platform.is_rocm(): v_scale: torch.Tensor, total_tokens: int = 0, ) -> torch.Tensor: - return torch.empty(q.shape[0], - q.shape[1], - v_cache.shape[-2], - dtype=q.dtype, - device=q.device) + return torch.empty( + q.shape[0], q.shape[1], 
v_cache.shape[-2], dtype=q.dtype, device=q.device + ) - direct_register_custom_op("flash_attn_varlen_func", - flash_attn_varlen_func_impl, ["out"], - flash_attn_varlen_func_fake, - dispatch_key=current_platform.dispatch_key) + direct_register_custom_op( + "flash_attn_varlen_func", + flash_attn_varlen_func_impl, + ["out"], + flash_attn_varlen_func_fake, + dispatch_key=current_platform.dispatch_key, + ) logger = init_logger(__name__) @@ -231,43 +258,51 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - cudagraph_support = AttentionCGSupport.ALWAYS + AttentionMetadataBuilder[AiterFlashAttentionMetadata] +): + cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - self.total_tokens = self.model_config.max_model_len \ + self, common_attn_metadata: CommonAttentionMetadata + ): + self.total_tokens = ( + self.model_config.max_model_len * self.vllm_config.scheduler_config.max_num_partial_prefills - res = self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + ) + res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) self.total_tokens = 0 return res - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> "AiterFlashAttentionMetadata": num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len @@ -278,20 +313,18 @@ class AiterFlashAttentionMetadataBuilder( if max_query_len > 1: # We pre-compute cumulative seq len needed for prefill attention # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, - dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) + cu_seq_lens = torch.zeros( + seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device + ) + torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) num_actual_kv_tokens = int(cu_seq_lens[-1].item()) else: cu_seq_lens = None num_actual_kv_tokens = 0 - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): return None use_cascade = common_prefix_len > 0 @@ -317,7 +350,6 @@ class AiterFlashAttentionMetadataBuilder( class AiterFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -328,6 +360,10 @@ class AiterFlashAttentionBackend(AttentionBackend): def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -337,11 +373,12 @@ class AiterFlashAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: - return "FLASH_ATTN_VLLM_V1" + return "FLASH_ATTN" @staticmethod def get_impl_cls() -> type["AiterFlashAttentionImpl"]: @@ -361,6 +398,7 @@ class AiterFlashAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") @@ -368,7 +406,6 @@ class AiterFlashAttentionBackend(AttentionBackend): class AiterFlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -396,7 +433,7 @@ class AiterFlashAttentionImpl(AttentionImpl): self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0. + logits_soft_cap = 0.0 self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -406,10 +443,12 @@ class AiterFlashAttentionImpl(AttentionImpl): AiterFlashAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl" + ) def forward( self, @@ -429,7 +468,8 @@ class AiterFlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -441,8 +481,8 @@ class AiterFlashAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. 
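# A brief sketch (made-up lengths) of the cumulative-length metadata that the
# AiterFlashAttention builder above computes for prefill: cu_seq_lens is a
# prefix sum of seq_lens, so request i's KV tokens occupy the packed range
# [cu_seq_lens[i], cu_seq_lens[i + 1]) of the flattened token dimension.
import torch

seq_lens = torch.tensor([4, 1, 7], dtype=torch.int32)
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
# cu_seq_lens is now tensor([0, 4, 5, 12], dtype=torch.int32); the last entry
# is the total number of KV tokens, i.e. num_actual_kv_tokens in the builder.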
@@ -479,8 +519,8 @@ class AiterFlashAttentionImpl(AttentionImpl): ) if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) if not attn_metadata.use_cascade: cu_seqlens_q = attn_metadata.query_start_loc @@ -511,13 +551,14 @@ class AiterFlashAttentionImpl(AttentionImpl): _, num_heads, head_size = query.shape nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM + max_num_partitions = ( + max_seqlen_k + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, + (num_seqs * num_heads * max_num_partitions * head_size) + * nbytes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, dtype=torch.uint8, device=output.device, ) @@ -545,4 +586,5 @@ class AiterFlashAttentionImpl(AttentionImpl): return output else: raise NotImplementedError( - "Cascade attention is not implemented for ROCM AITER") + "Cascade attention is not implemented for ROCM AITER" + ) diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py new file mode 100644 index 0000000000000..235ea1c376ef4 --- /dev/null +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with PagedAttention and Triton prefix prefill.""" + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, +) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.rocm_attn import ( + RocmAttentionBackend, + RocmAttentionImpl, + RocmAttentionMetadata, + RocmAttentionMetadataBuilder, +) + +logger = init_logger(__name__) + + +class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_UNIFIED_ATTN" + + @staticmethod + def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]: + return RocmAiterUnifiedAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return RocmAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: + return RocmAttentionMetadataBuilder + + +class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + + def __init__( + self, + num_heads: int, 
+ head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + sinks, + ) + logger.info_once( + "Using aiter unified attention for RocmAiterUnifiedAttentionImpl" + ) + from aiter.ops.triton.unified_attention import unified_attention + + self.unified_attention = unified_attention + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_block_scale is not None: + raise NotImplementedError( + "fused block_scale output quantization is not yet supported" + " for RocmAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + key_cache, value_cache = kv_cache.unbind(0) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + assert layer._q_scale_float == 1.0, ( + "A non 1.0 q_scale is not currently supported." 
+ ) + + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + self.unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, + output_scale=output_scale, + ) + + return output diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py new file mode 100644 index 0000000000000..10dd01f0a5aa4 --- /dev/null +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with PagedAttention and Triton prefix prefill.""" + +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +logger = init_logger(__name__) + + +@dataclass +class RocmAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. 
+ use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + +class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + self.block_size = kv_cache_spec.block_size + + model_config = vllm_config.model_config + self.num_heads_q = model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) + self.headdim = model_config.get_head_size() + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> RocmAttentionMetadata: + attn_metadata = self.build(0, common_attn_metadata) + # When doing full graph capture, setting seq_lens to + # max_model_len will cause graph capture to be extremely + # slow, so here we set it to 1. + attn_metadata.seq_lens.fill_(1) + + # Here we set the query start locs to 0. This is to + # cover up an invalid memory access in the prefix_prefil kernel + # that we run into during graph capture (#25985) + common_attn_metadata.query_start_loc.zero_() + common_attn_metadata.query_start_loc_cpu.zero_() + + return attn_metadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> RocmAttentionMetadata: + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len + suffix_kv_lens = suffix_kv_lens.to(self.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + attn_metadata = RocmAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + prefix_scheduler_metadata=prefix_scheduler_metadata, + ) + return attn_metadata + + +class RocmAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def 
validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes." + ) + + @staticmethod + def get_name() -> str: + return "ROCM_ATTN" + + @staticmethod + def get_impl_cls() -> type["RocmAttentionImpl"]: + return RocmAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return RocmAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: + return RocmAttentionMetadataBuilder + + +class RocmAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + RocmAttentionBackend.validate_head_size(head_size) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "RocmAttentionImpl" + ) + + self.fp8_dtype = current_platform.fp8_dtype() + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}." + ) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. 
+ + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_block_scale is not None: + raise NotImplementedError( + "fused block_scale output quantization is not yet supported" + " for RocmAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size + ) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + assert layer._q_scale_float == 1.0, ( + "A non 1.0 q_scale is not currently supported." + ) + + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + # Compute attention and update output up to `num_actual_tokens`. 
+ chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + output_scale=output_scale, + sinks=self.sinks, + ) + + return output diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index d80ced8ec876a..74cfecca764e6 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -1,20 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) class ShortConvAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder @@ -28,54 +29,78 @@ class ShortConvAttentionMetadata: num_decode_tokens: int query_start_loc: torch.Tensor - has_initial_states: torch.Tensor - state_indices_tensor: torch.Tensor # shape: [batch,] + state_indices_tensor: torch.Tensor + has_initial_states_p: Optional[torch.Tensor] # For causal_conv1d nums_dict: Optional[dict] = None - cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class ShortConvAttentionMetadataBuilder( - AttentionMetadataBuilder[ShortConvAttentionMetadata]): - - reorder_batch_threshold: ClassVar[int] = 1 - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> ShortConvAttentionMetadata: + BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata] +): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # for causal_conv1d + nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) - has_initial_states = None + 
split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) + + has_initial_states_p = None if num_prefills > 0: - #[batch,] has_initial_states_cpu = ( - common_attn_metadata. - num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) - has_initial_states = has_initial_states_cpu.to( - query_start_loc.device) + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device) + + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + + nums_dict, batch_ptr, token_chunk_offset_ptr = ( + compute_causal_conv1d_metadata(query_start_loc_p) + ) + + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) + state_indices_tensor = self.state_indices_tensor[:num_input_tokens] + state_indices_tensor[num_decodes:] = PAD_SLOT_ID attn_metadata = ShortConvAttentionMetadata( + query_start_loc=query_start_loc, + state_indices_tensor=state_indices_tensor, + has_initial_states_p=has_initial_states_p, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, - query_start_loc=query_start_loc, - has_initial_states=has_initial_states, - state_indices_tensor=state_indices_tensor, + nums_dict=nums_dict, + batch_ptr=batch_ptr, + token_chunk_offset_ptr=token_chunk_offset_ptr, ) - return attn_metadata \ No newline at end of file + return attn_metadata diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index c93223a340839..669dbe31810b6 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,31 +4,32 @@ import ast from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional, Union import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - -from vllm import _custom_ops as ops - logger = init_logger(__name__) class TreeAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -39,6 +40,10 @@ class TreeAttentionBackend(AttentionBackend): def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: 
supported_head_sizes = cls.get_supported_head_sizes() @@ -48,11 +53,12 @@ class TreeAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: - return "TREE_ATTN_VLLM_V1" + return "TREE_ATTN" @staticmethod def get_impl_cls() -> type["TreeAttentionImpl"]: @@ -68,6 +74,7 @@ class TreeAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") @@ -113,9 +120,9 @@ class TreeAttentionMetadata: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -123,8 +130,8 @@ class TreeAttentionMetadata: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -138,9 +145,9 @@ class TreeAttentionMetadata: # metadata structure return self._cached_decode_metadata - q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_start_loc = self.query_start_loc[: self.num_decodes + 1] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[:self.num_decodes] + kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, @@ -148,16 +155,14 @@ class TreeAttentionMetadata: query_start_loc=q_start_loc, max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], tree_attn_bias=self.tree_attn_bias, ) return self._cached_decode_metadata -class TreeAttentionMetadataBuilder( - AttentionMetadataBuilder[TreeAttentionMetadata]): - +class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]): def __init__( self, kv_cache_spec: AttentionSpec, @@ -165,15 +170,15 @@ class TreeAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size spec_config = vllm_config.speculative_config spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) + tree_choices: list[tuple[int, ...]] = ( + ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)] + ) # 
Construct the tree attention bias. depth_counts = _get_depth_counts(tree_choices) self.tree_attn_bias = _prepare_tree_attn_bias( @@ -183,12 +188,7 @@ class TreeAttentionMetadataBuilder( device=device, ) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.tree_attn_bias.shape[0]) + self.reorder_batch_threshold = self.tree_attn_bias.shape[0] def build( self, @@ -198,8 +198,10 @@ class TreeAttentionMetadataBuilder( ) -> TreeAttentionMetadata: decode_threshold = self.tree_attn_bias.shape[0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=decode_threshold)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=decode_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -239,8 +241,7 @@ class TreeAttentionMetadataBuilder( # Slice the tree attention bias for drafting. Exclude # the root level. start, end = 1, 1 + common_attn_metadata.max_query_len - self.tree_attn_bias = self.tree_attn_bias[start:end, - start:end].contiguous() + self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous() # Build attention bias. attn_metadata = self.build(0, common_attn_metadata, fast_build=True) @@ -271,10 +272,9 @@ def _prepare_tree_attn_bias( ) -> torch.Tensor: # +1 comes from the additional root node. tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) + tree_attn_mask = torch.full( + (tree_len, tree_len), -torch.inf, device=device, dtype=dtype + ) # Set diagonal to all zeros. Each token should # attend to itself. @@ -296,14 +296,14 @@ def _prepare_tree_attn_bias( ancestor_idx = [] for c in range(len(cur_tree_choice) - 1): ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1 + ) tree_attn_mask[j + start + 1, ancestor_idx] = mask_val start += depth_counts[i] return tree_attn_mask class TreeAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -339,10 +339,12 @@ class TreeAttentionImpl(AttentionImpl): TreeAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TreeAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TreeAttentionImpl." + ) def forward( self, @@ -362,7 +364,8 @@ class TreeAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. 
Returns: shape = [num_tokens, num_heads * head_size] @@ -371,8 +374,8 @@ class TreeAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for TreeAttentionImpl") + "fused output quantization is not yet supported for TreeAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -401,8 +404,7 @@ class TreeAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens - descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1]) if prefill_meta := attn_metadata.prefill_metadata: unified_attention( q=query[num_decode_tokens:num_actual_tokens], diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b12036c599799..878634c7f521d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,28 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with PagedAttention and Triton prefix prefill.""" +"""High-Performance Triton-only Attention layer.""" + from dataclasses import dataclass -from functools import cache -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) -from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, +) +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + logger = init_logger(__name__) @@ -56,21 +70,25 @@ class TritonAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): +class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - self.device = device + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + 
device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -83,10 +101,12 @@ class TritonAttentionMetadataBuilder( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TritonAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -99,14 +119,13 @@ class TritonAttentionMetadataBuilder( use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -133,31 +152,30 @@ class TritonAttentionMetadataBuilder( class TritonAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + return [torch.float16, torch.bfloat16, torch.float32] - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") + # Triton Attention supports any head size above 32 + if head_size < 32: raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Head size {head_size} is not supported by TritonAttention." + f"Head sizes need to be larger or equal 32 for this backend. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: - return "TRITON_ATTN_VLLM_V1" + return "TRITON_ATTN" @staticmethod def get_impl_cls() -> type["TritonAttentionImpl"]: @@ -173,10 +191,11 @@ class TritonAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -187,16 +206,9 @@ class TritonAttentionBackend(AttentionBackend): return TritonAttentionMetadataBuilder -@cache -def use_aiter_unified_attention() -> bool: - """Check if aiter unified attention should be used.""" - # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set - # to 1 as default - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_USE_AITER_UNIFIED_ATTENTION - - class TritonAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym def __init__( self, @@ -235,37 +247,22 @@ class TritonAttentionImpl(AttentionImpl): TritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - - if not self.force_prefill_decode_attn: - # If not using prefill decode attention, we use the Triton - # unified attention implementation. - if use_aiter_unified_attention(): - logger.info_once( - "Using aiter unified attention for TritonAttentionImpl") - from aiter.ops.triton.unified_attention import ( - unified_attention) - self.unified_attention = unified_attention - else: - logger.info_once( - "Using vllm unified attention for TritonAttentionImpl") - from vllm.attention.ops.triton_unified_attention import ( - unified_attention) - self.unified_attention = unified_attention self.sinks = sinks if sinks is not None: assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -274,28 +271,30 @@ class TritonAttentionImpl(AttentionImpl): key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: TritonAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with Paged Attention impl. in Triton. Args: query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [num_blocks, 2, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None or output_block_scale is not None: + if output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for TritonAttentionImpl") + "fused block_scale output quantization is not yet supported" + " for TritonAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -312,54 +311,45 @@ class TritonAttentionImpl(AttentionImpl): # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - use_prefill_decode_attn = self.force_prefill_decode_attn num_actual_tokens = attn_metadata.num_actual_tokens - - if use_prefill_decode_attn: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - else: - key_cache, value_cache = kv_cache.unbind(0) + key_cache, value_cache = kv_cache.unbind(1) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - if use_prefill_decode_attn: - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + # triton kernel does not support uint8 kv_cache + # (because some explicit casts (e.g. float8_e4m3fnuz) + # are not supported) + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(self.fp8_dtype) - value_cache = value_cache.view(self.fp8_dtype) + if key_cache.dtype != self.fp8_dtype: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. + ) + if current_platform.is_cuda(): + # Skip Q quantization on ROCm and XPU, enable this on cuda + # only, since dequantizing back to f32 in the attention kernel + # is not supported. query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc @@ -368,51 +358,28 @@ class TritonAttentionImpl(AttentionImpl): max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - if use_prefill_decode_attn: - # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode( - query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale, - sinks=self.sinks, - ) + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - else: - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - self.unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - sinks=self.sinks, - ) + unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, + output_scale=output_scale, + ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 39bdbe125635b..7c6940d9b15d5 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,11 +4,23 @@ import abc import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar +from dataclasses import dataclass, field, fields, make_dataclass +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, + get_args, +) import numpy as np import torch +from typing_extensions import runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -19,15 +31,24 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layer import Attention +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) -_KV_CACHE_LAYOUT_OVERRIDE = None +KVCacheLayoutType = Literal["NHD", "HND"] +_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None + +PAD_SLOT_ID = -1 + + +def 
is_valid_kv_cache_layout(value: str) -> bool: + return value in get_args(KVCacheLayoutType) @dataclass @@ -35,7 +56,7 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - + For many of the tensors we keep both GPU and CPU versions. """ @@ -65,11 +86,15 @@ class CommonAttentionMetadata: causal: bool = True + # Needed by FastPrefillAttentionBuilder + logits_indices_padded: Optional[torch.Tensor] = None + num_logits_indices: Optional[int] = None -@dataclass -class UbatchSlice: - request_slice: slice - token_slice: slice + # Needed by CrossAttentionBuilder + encoder_seq_lens: Optional[np.ndarray] = None + + dcp_local_seq_lens: Optional[torch.Tensor] = None + """Sequence lengths of the local rank in decode context parallelism world""" def slice_query_start_locs( @@ -77,46 +102,92 @@ def slice_query_start_locs( request_slice: slice, ) -> torch.Tensor: """ - Creates a new query_start_loc that corresponds to the requests in + Creates a new query_start_loc that corresponds to the requests in request_slice. Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. """ - return query_start_loc[request_slice.start: request_slice.stop + 1] -\ - query_start_loc[request_slice.start] + return ( + query_start_loc[request_slice.start : request_slice.stop + 1] + - query_start_loc[request_slice.start] + ) def _make_metadata_with_slice( - ubatch_slice: UbatchSlice, - attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata +) -> CommonAttentionMetadata: """ - This function creates a new CommonAttentionMetadata that corresponds to + This function creates a new CommonAttentionMetadata that corresponds to the requests included in ubatch_slice """ + assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" + request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice - query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, - request_slice) - assert len(query_start_loc) >= 2, ( - f"query_start_loc must have at least 2 elements, " - f"got {len(query_start_loc)}") - query_start_loc_cpu = slice_query_start_locs( - attn_metadata.query_start_loc_cpu, request_slice) + start_locs = attn_metadata.query_start_loc_cpu + first_req = request_slice.start + first_tok = token_slice.start + last_req = request_slice.stop - 1 + last_tok = token_slice.stop - 1 + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( + "Token slice start outside of first request" + ) + assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], ( + "Token slice end outside of last request" + ) + + # If the "middle" request has tokens in both ubatches, we have to split it. + # If ubatch_slice is the first ubatch then we will be splitting the last + # request. 
If it's the second microbatch, then we will be splitting the + # first request + splits_first_request = first_tok > start_locs[first_req] + splits_last_request = last_tok < start_locs[last_req + 1] - 1 + + query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) + query_start_loc = slice_query_start_locs( + attn_metadata.query_start_loc, request_slice + ) + + assert len(query_start_loc) >= 2, ( + f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" + ) + + if splits_first_request: + tokens_skipped = first_tok - start_locs[first_req] + query_start_loc[1:] -= tokens_skipped + query_start_loc_cpu[1:] -= tokens_skipped seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + + if splits_last_request: + tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + query_start_loc[-1] -= tokens_skipped + query_start_loc_cpu[-1] -= tokens_skipped + + # Make sure we don't modify the seq_lens tensors + # (not cudagraph compatible) + seq_lens = seq_lens.clone() + seq_lens_cpu = seq_lens_cpu.clone() + seq_lens[-1] -= tokens_skipped + seq_lens_cpu[-1] -= tokens_skipped + max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ - request_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( - torch.max(torch.abs(query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1])).item()) + torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() + ) + + # This is to account for the case where we are in a dummy + # run and query_start_loc_cpu is full of 0s + if max_query_len == 0: + max_query_len = attn_metadata.max_query_len block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] @@ -137,19 +208,19 @@ def _make_metadata_with_slice( def split_attn_metadata( - ubatch_slices: list[UbatchSlice], + ubatch_slices: list[UBatchSlice], common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ - Creates a new CommonAttentionMetadata instance that corresponds to the - requests for each UbatchSlice in ubatch_slices. + Creates a new CommonAttentionMetadata instance that corresponds to the + requests for each UBatchSlice in ubatch_slices. Note: This function does not modify common_attn_metadata """ results = [] for ubatch_slice in ubatch_slices: - results.append( - _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + return results @@ -157,7 +228,7 @@ M = TypeVar("M") class AttentionCGSupport(enum.Enum): - """ Constants for the cudagraph support of the attention backend + """Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" @@ -165,7 +236,7 @@ class AttentionCGSupport(enum.Enum): """Cudagraph always supported; supports mixed-prefill-decode""" UNIFORM_BATCH = 2 """Cudagraph supported for batches the only contain query lengths that are - the same, this can be used for spec-decode + the same, this can be used for spec-decode i.e. 
"decodes" are 1 + num_speculative_tokens""" UNIFORM_SINGLE_TOKEN_DECODE = 1 """Cudagraph supported for batches the only contain query_len==1 decodes""" @@ -175,27 +246,54 @@ class AttentionCGSupport(enum.Enum): class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention (default: no). - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. - reorder_batch_threshold: ClassVar[Optional[int]] = None + reorder_batch_threshold: Optional[int] = None @abstractmethod - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + self.vllm_config = vllm_config + self.device = device + + def _init_reorder_batch_threshold( + self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False + ) -> None: + self.reorder_batch_threshold = reorder_batch_threshold + if self.reorder_batch_threshold is not None and supports_spec_as_decode: + # If the backend supports spec-as-decode kernels, then we can set + # the reorder_batch_threshold based on the number of speculative + # tokens from the config. + speculative_config = self.vllm_config.speculative_config + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = max( + self.reorder_batch_threshold, + 1 + speculative_config.num_speculative_tokens, + ) @abstractmethod - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. - + Args: common_prefix_len: The length of the common prefix of the batch. common_attn_metadata: The common attention metadata. @@ -206,14 +304,16 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): raise NotImplementedError def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) def build_for_drafting( self, @@ -222,7 +322,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ) -> M: """ Build attention metadata for draft model. Uses build by default. - + Args: common_attn_metadata: The common attention metadata. draft_index: The index of the current draft operation. @@ -231,9 +331,11 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens. 
""" - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True) + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) def use_cascade_attention( self, @@ -256,8 +358,11 @@ def get_kv_cache_layout(): if _KV_CACHE_LAYOUT_OVERRIDE is not None: cache_layout = _KV_CACHE_LAYOUT_OVERRIDE - logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \ - "Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " + "Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout # Format specified by the user. @@ -266,12 +371,16 @@ def get_kv_cache_layout(): if cache_layout is None: cache_layout = get_kv_connector_cache_layout() else: - logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ - "detected. Setting KV cache layout to %s.", cache_layout) + assert is_valid_kv_cache_layout(cache_layout) + logger.info_once( + "`VLLM_KV_CACHE_LAYOUT` environment variable " + "detected. Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout -def set_kv_cache_layout(cache_layout: str): +def set_kv_cache_layout(cache_layout: KVCacheLayoutType): global _KV_CACHE_LAYOUT_OVERRIDE _KV_CACHE_LAYOUT_OVERRIDE = cache_layout @@ -289,17 +398,20 @@ class PerLayerParameters: logits_soft_cap: Optional[float] sm_scale: float has_sinks: bool = False + # has same params for all layers + has_same_window_lefts: Optional[bool] = field(default=None, compare=False) + has_same_all_params: Optional[bool] = field(default=None, compare=False) def get_per_layer_parameters( - vllm_config: VllmConfig, layer_names: list[str], - cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"] +) -> dict[str, PerLayerParameters]: """ Scan layers in `layer_names` and determine some hyperparameters to use during `plan`. """ - layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names) + layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): @@ -313,17 +425,18 @@ def get_per_layer_parameters( sm_scale = impl.scale has_sinks = getattr(impl, "sinks", None) is not None - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale, - has_sinks) + per_layer_params[key] = PerLayerParameters( + window_left, logits_soft_cap, sm_scale, has_sinks + ) return per_layer_params def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + per_layer_params: dict[str, PerLayerParameters], +) -> PerLayerParameters: """ - Currently, FlashInfer backend other than trtllm-gen + Currently, FlashInfer backend other than trtllm-gen only support models in which all layers share the same values for the following hyperparameters: - `window_left` @@ -339,18 +452,12 @@ def infer_global_hyperparameters( param_sets = list(per_layer_params.values()) global_params = param_sets[0] - # trtllm attention doesn't need global hyper params so disable the check - if not envs.VLLM_USE_TRTLLM_ATTENTION: - for params in param_sets: - if params.window_left != global_params.window_left: - raise ValueError( - "Window left is not the same for all layers. 
" \ - "One potential fix is to set disable_sliding_window=True") - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all" - "layers share the same values " - "for the following hyperparameters:" - "`window_left`, `logits_soft_cap`, `sm_scale`.") + global_params.has_same_window_lefts = all( + params.window_left == global_params.window_left for params in param_sets + ) + global_params.has_same_all_params = all( + params == global_params for params in param_sets + ) return global_params @@ -432,11 +539,10 @@ def make_local_attention_virtual_batches( # new_tokens_in_first_block = [2, 1, 4] # local_blocks = [2, 4, 2] q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) # Once we know the number of local blocks we can compute the request spans # for each batch idx, we can figure out the number of "virtual" requests we @@ -457,14 +563,13 @@ def make_local_attention_virtual_batches( rarange = np.repeat(local_blocks, local_blocks) - arange - 1 # Then we can compute the seqlens_q_local, handling the fact that the # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) # set the first block since this may be a partial block seqlens_q_local[arange == 0] = q_tokens_in_first_block # set the remaining blocks seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] # convert from q_seqlens to cu_seqlens_q cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) @@ -476,22 +581,20 @@ def make_local_attention_virtual_batches( # batch # For our example this will be: # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block num_computed_tokens_local = seqlens_k_local - seqlens_q_local - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" + assert attn_chunk_size % block_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}" + ) pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks @@ -512,14 +615,24 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) 
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices = (block_starts[:, None] + - np.arange(pages_per_local_batch, dtype=np.int32)) - block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) - block_table_local = block_table[batch_indices, block_indices]\ - .view(virtual_batches, -1) + block_indices = block_starts[:, None] + np.arange( + pages_per_local_batch, dtype=np.int32 + ) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) + + # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance + # regression when using numpy arrays (batch and block indices) to index into + # torch tensor (block_table). As a workaround, convert numpy arrays to torch + # tensor first, which recovers perf. + batch_indices_torch = torch.from_numpy(batch_indices) + block_indices_torch = torch.from_numpy(block_indices) + block_table_local = block_table[batch_indices_torch, block_indices_torch].view( + virtual_batches, -1 + ) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -527,8 +640,7 @@ def make_local_attention_virtual_batches( return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, - query_start_loc=query_start_loc_cpu.to(device=device, - non_blocking=True), + query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True), seq_lens_cpu=seq_lens_cpu, seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), @@ -542,22 +654,85 @@ def make_local_attention_virtual_batches( ) +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + assert common_attn_metadata.logits_indices_padded is not None + assert common_attn_metadata.num_logits_indices is not None + + logits_indices_padded = common_attn_metadata.logits_indices_padded + num_logits_indices = common_attn_metadata.num_logits_indices + # Get rid of CUDAGraph padding, if any + logits_indices = logits_indices_padded[:num_logits_indices] + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty( + num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype + ) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens 
= int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=True, + ) + return common_attn_metadata + + def subclass_attention_backend( - name_prefix: str, attention_backend_cls: type[AttentionBackend], - builder_cls: type[AttentionMetadataBuilder[M]] + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]], ) -> type[AttentionBackend]: """ Return a new subclass where `get_builder_cls` returns `builder_cls`. """ name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - return type(name, (attention_backend_cls, ), - {"get_builder_cls": lambda: builder_cls}) + return type( + name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} + ) def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, + require_uniform: bool = False, ) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode @@ -567,6 +742,9 @@ def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata object containing the batch metadata. decode_threshold: The maximum query length to be considered a decode. + require_uniform: If True, requires that all decode requests have the + same query length. When set, some queries may be considered prefills + even if they are <= decode_threshold, in order to ensure uniformity. Returns: num_decodes: The number of decode requests. @@ -579,16 +757,25 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold: + if max_query_len <= decode_threshold and ( + not require_uniform or decode_threshold <= 1 + ): return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] - is_prefill = query_lens > decode_threshold + if query_lens[0].item() > decode_threshold: + # first request is not decode, so no decode requests + return 0, num_reqs, 0, num_tokens + + if require_uniform: + is_prefill = query_lens != query_lens[0] + else: + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes @@ -605,7 +792,7 @@ def reorder_batch_to_split_decodes_and_prefills( """ Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch. - + Returns: True if the batch was modified, False otherwise. 
""" @@ -622,10 +809,6 @@ def reorder_batch_to_split_decodes_and_prefills( for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the TritonMLA._forward_decode only supports - # num_tokens = 1 if num_tokens <= decode_threshold: decodes.append(i) num_decode_tokens += num_tokens @@ -660,9 +843,38 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch +def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor: + """ + Reshapes the query tensor for the specified batch size, so that + it has shape (batch_size, seq_len, num_heads, head_dim). + """ + assert query.dim() == 3, f"query must be 3D, got {query.dim()}D" + total_tokens = query.shape[0] + num_heads = query.shape[1] + head_dim = query.shape[2] + assert total_tokens % batch_size == 0, ( + f"{total_tokens=} is not divisible by {batch_size=}" + ) + seq_len = total_tokens // batch_size + return query.view(batch_size, seq_len, num_heads, head_dim) + + +def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor: + """ + Reshapes the attention output tensor, so that + the batch_size and seq_len dimensions are combined. + """ + if attn_output.dim() == 3: + # Already in the correct shape + return attn_output + assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D" + total_tokens = attn_output.shape[0] * attn_output.shape[1] + return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) + + KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ('logits_indices_padded', Optional[torch.Tensor], None), - ('num_logits_indices', int, 0), + ("logits_indices_padded", Optional[torch.Tensor], None), + ("num_logits_indices", int, 0), ] @@ -675,17 +887,109 @@ def subclass_attention_metadata( Return a new subclass of `metadata_cls` with additional fields """ name: str = name_prefix + metadata_cls.__name__ # type: ignore - Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) + Wrapped = make_dataclass(name, fields, bases=(metadata_cls,)) return Wrapped -def make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls: Any, ) -> Any: - """ - Return a new subclass of `metadata_cls` for fast prefill - """ - return subclass_attention_metadata( - name_prefix="KVSharingFastPrefill", - metadata_cls=metadata_cls, - fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, +@runtime_checkable +class KVSharingFastPrefillMetadata(Protocol): + logits_indices_padded: torch.Tensor + num_logits_indices: int + + +def create_fast_prefill_custom_backend( + prefix: str, + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: + underlying_builder = underlying_attn_backend.get_builder_cls() + + class FastPrefillAttentionBuilder(underlying_builder): # type: ignore + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = ( + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + ) + metadata = super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) + + class KVSharingFastPrefillAttentionMetadata( + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata, + ): + def __init__(self, metadata, common_attn_metadata): + # Shallow copy all fields in metadata cls + for _field in 
fields(metadata.__class__): + setattr(self, _field.name, getattr(metadata, _field.name)) + + # Set additional fields that will be used in model code + assert ( + common_attn_metadata.logits_indices_padded is not None + and common_attn_metadata.num_logits_indices is not None + ) + self.logits_indices_padded = ( + common_attn_metadata.logits_indices_padded + ) + self.num_logits_indices = common_attn_metadata.num_logits_indices + + return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=FastPrefillAttentionBuilder, ) + + return attn_backend + + +def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): + # Needed for causal_conv1d + seqlens = query_start_loc_p.diff().to("cpu") + nums_dict = {} # type: ignore + batch_ptr = None + token_chunk_offset_ptr = None + device = query_start_loc_p.device + for BLOCK_M in [8]: # cover all BLOCK_M values + nums = -(-seqlens // BLOCK_M) + nums_dict[BLOCK_M] = {} + nums_dict[BLOCK_M]["nums"] = nums + nums_dict[BLOCK_M]["tot"] = nums.sum().item() + mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) + nums_dict[BLOCK_M]["mlist"] = mlist + mlist_len = len(nums_dict[BLOCK_M]["mlist"]) + nums_dict[BLOCK_M]["mlist_len"] = mlist_len + MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 + offsetlist = [] # type: ignore + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) + offsetlist = torch.tensor(offsetlist, dtype=torch.int32) + nums_dict[BLOCK_M]["offsetlist"] = offsetlist + + if batch_ptr is None: + # Update default value after class definition + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + else: + if batch_ptr.nelement() < MAX_NUM_PROGRAMS: + batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + token_chunk_offset_ptr.resize_( # type: ignore + MAX_NUM_PROGRAMS + ).fill_(PAD_SLOT_ID) + + batch_ptr[0:mlist_len].copy_(mlist) + token_chunk_offset_ptr[ # type: ignore + 0:mlist_len + ].copy_(offsetlist) + nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr + nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore + + return nums_dict, batch_ptr, token_chunk_offset_ptr diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index e0eb7d8be9746..eb1fcc2c024d2 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,40 +3,44 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional, Union import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from xformers 
import ops as xops from xformers.ops.fmha.attn_bias import ( - AttentionBias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask) + AttentionBias, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ) XFORMERS_AVAILABLE = True except ImportError: XFORMERS_AVAILABLE = False -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm import _custom_ops as ops logger = init_logger(__name__) class XFormersAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -77,6 +81,10 @@ class XFormersAttentionBackend(AttentionBackend): 256, ] + @staticmethod + def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -86,11 +94,12 @@ class XFormersAttentionBackend(AttentionBackend): f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: - return "XFORMERS_VLLM_V1" + return "XFORMERS" @staticmethod def get_impl_cls() -> type["XFormersAttentionImpl"]: @@ -106,6 +115,7 @@ class XFormersAttentionBackend(AttentionBackend): block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") @@ -152,9 +162,9 @@ class XFormersAttentionMetadata: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -162,8 +172,8 @@ class XFormersAttentionMetadata: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -179,23 +189,25 @@ class XFormersAttentionMetadata: q_start_loc = self.query_start_loc q_seqlens = torch.diff(q_start_loc) - decode_kv_seqlens = self.seq_lens[:self.num_decodes] + decode_kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens[:self.num_decodes].max().item()), - query_start_loc=q_start_loc[:self.num_decodes + 1], + max_query_len=int(q_seqlens[: self.num_decodes].max().item()), + query_start_loc=q_start_loc[: self.num_decodes + 1], max_seq_len=int(decode_kv_seqlens.max().item()), seq_lens=decode_kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], 
attn_bias=self.attn_bias, ) return self._cached_decode_metadata class XFormersAttentionMetadataBuilder( - AttentionMetadataBuilder[XFormersAttentionMetadata]): + AttentionMetadataBuilder[XFormersAttentionMetadata] +): + reorder_batch_threshold: int = 1 def __init__( self, @@ -204,18 +216,13 @@ class XFormersAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert XFORMERS_AVAILABLE - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size self._num_decodes = 0 self._num_decode_tokens = 0 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) - def build( self, common_prefix_len: int, @@ -223,8 +230,10 @@ class XFormersAttentionMetadataBuilder( fast_build: bool = False, ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -240,14 +249,13 @@ class XFormersAttentionMetadataBuilder( # Construct the decoder bias. decode_q_seqlens = q_seqlens[:num_decodes] decode_kv_seqlens = kv_seqlens[:num_decodes] - bias = ( - PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( - q_seqlen=decode_q_seqlens.tolist(), - kv_seqlen=decode_kv_seqlens.tolist(), - page_size=self.block_size, - block_tables=block_table[:num_decodes], - device=block_table.device, - )) + bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=decode_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table[:num_decodes], + device=block_table.device, + ) return XFormersAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -266,7 +274,6 @@ class XFormersAttentionMetadataBuilder( class XFormersAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -283,8 +290,7 @@ class XFormersAttentionImpl(AttentionImpl): if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if alibi_slopes is not None: - raise NotImplementedError( - "XFormers does not support alibi slopes yet.") + raise NotImplementedError("XFormers does not support alibi slopes yet.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -307,10 +313,12 @@ class XFormersAttentionImpl(AttentionImpl): XFormersAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "XFormersAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "XFormersAttentionImpl." 
+ ) def forward( self, @@ -330,7 +338,8 @@ class XFormersAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -340,7 +349,8 @@ class XFormersAttentionImpl(AttentionImpl): if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for XFormersAttentionImpl") + " for XFormersAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -370,8 +380,7 @@ class XFormersAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens if prefill_meta := attn_metadata.prefill_metadata: - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1]) unified_attention( q=query[num_decode_tokens:num_actual_tokens], k=key_cache, @@ -396,36 +405,38 @@ class XFormersAttentionImpl(AttentionImpl): # Query for decode. KV is not needed because it is already cached. decode_query = query[:num_decode_tokens] # Reshape query to [1, B_T, G, H, D]. - q = decode_query.view(1, -1, self.num_kv_heads, - self.num_queries_per_kv, self.head_size) + q = decode_query.view( + 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size + ) # Reshape the k and v caches to [1, Bkv_T, G, H, D] - cache_k = key_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) - cache_v = value_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) + cache_k = key_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) + cache_v = value_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) attn_bias = decode_meta.attn_bias - output[: - num_decode_tokens] = xops.memory_efficient_attention_forward( - q, - cache_k, - cache_v, - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - ).view(decode_query.shape) + output[:num_decode_tokens] = xops.memory_efficient_attention_forward( + q, + cache_k, + cache_v, + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + ).view(decode_query.shape) # Reshape the output tensor. 
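# --- Illustrative sketch (not part of the patch) -----------------------------
# Plain-torch sketch of the GQA reshape/expand used in the decode path above,
# with made-up sizes and no xformers dependency: the decode query becomes
# [1, B_T, G, H, D] and each KV-cache head is broadcast across the H query
# heads that share it before calling the memory-efficient attention kernel.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads          # H = 4 query heads per KV head

decode_query = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.randn(16, num_kv_heads, head_size)    # flattened paged cache slots

q = decode_query.view(1, -1, num_kv_heads, num_queries_per_kv, head_size)
cache_k = key_cache.view(1, -1, num_kv_heads, 1, head_size).expand(
    1, -1, num_kv_heads, num_queries_per_kv, head_size
)

print(q.shape)        # torch.Size([1, 4, 2, 4, 64])
print(cache_k.shape)  # torch.Size([1, 16, 2, 4, 64])
# -----------------------------------------------------------------------------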
return output diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index fdd96c3e9557d..ddfd94322737f 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,19 +1,127 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable -from typing import Optional +from typing import Any, Optional, Union -from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, - BlockStored, KVCacheEvent) +from vllm.distributed.kv_events import ( + MEDIUM_GPU, + AllBlocksCleared, + BlockRemoved, + BlockStored, + KVCacheEvent, +) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + ExternalBlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash, +) from vllm.v1.request import Request logger = init_logger(__name__) +class BlockHashToBlockMap: + """ + Cache of blocks that are used for prefix caching. It caches blocks + from hash directly to a block or multiple blocks + (i.e. {block_hash: KVCacheBlocks}) + - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks + would simply be a KVCacheBlock. + - Otherwise, KVCacheBlocks is a dict from {block_id: KVCacheBlock} + + A cached block is a full block with a block hash that can be used + for prefix caching. + The cached block may be used by running requests or in the + free_block_queue that could potentially be evicted. + + NOTE #1: We currently don't de-duplicate the blocks in the cache, + meaning that if a block becomes full and is cached, we don't check + if there is already an identical block in the cache. This is because + we want to make sure the allocated block IDs won't change so that + block tables are append-only. + NOTE #2: The union type is introduced in order to reduce GC costs + from the inner dict. + """ + + def __init__(self): + self._cache: dict[ + BlockHashWithGroupId, Union[KVCacheBlock, dict[int, KVCacheBlock]] + ] = {} + + def get_one_block(self, key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + """ + Gets any block with the given block hash key. 
+ """ + blocks = self._cache.get(key) + if blocks is not None: + if isinstance(blocks, KVCacheBlock): + return blocks + if isinstance(blocks, dict): + return next(iter(blocks.values())) + self._unexpected_blocks_type(blocks) + return None + + def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: + """ + Inserts the KVCacheBlock to the cache + """ + blocks = self._cache.get(key) + if blocks is None: + # When key is not found, attach a single block to the key + self._cache[key] = block + elif isinstance(blocks, KVCacheBlock): + # If there's a block with the same key, merge the original block + # and the new block into a dict + self._cache[key] = {blocks.block_id: blocks, block.block_id: block} + elif isinstance(blocks, dict): + # If it's already a dict, simply insert the block + blocks[block.block_id] = block + else: + self._unexpected_blocks_type(blocks) + + def pop(self, key: BlockHashWithGroupId, block_id: int) -> Optional[KVCacheBlock]: + """ + Checks if block_hash exists and pop block_id from the cache + """ + blocks = self._cache.pop(key, None) + if blocks is None: + # block_hash not found in the cache + return None + # TODO(Jialin): If key is found, block_id should always present + # in blocks. We currently keep the original behaviour for safety. + # + # Will add block_id == blocks.block_id assertion and + # use del blocks[block_id] instead as followup. + if isinstance(blocks, KVCacheBlock): + if blocks.block_id == block_id: + return blocks + # If the single block ID doesn't match, we should put the + # block back (it should happen rarely) + self._cache[key] = blocks + return None + if isinstance(blocks, dict): + # Try to pop block_id from the block dict, and if dict still + # contain blocks, put back to the cache. + block = blocks.pop(block_id, None) + if len(blocks) > 0: + self._cache[key] = blocks + return block + self._unexpected_blocks_type(blocks) + return None + + def __len__(self) -> int: + return len(self._cache) + + def _unexpected_blocks_type(self, blocks: Any) -> None: + raise AssertionError(f"Invalid KV cache block type {type(blocks)}") + + class BlockPool: """BlockPool that manages KVCacheBlocks. It provides methods to allocate, free and cache the kv cache blocks. The @@ -46,17 +154,8 @@ class BlockPool: # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. - # The cached block may be used by running requests or in the - # free_block_queue that could potentially be evicted. - # NOTE: We currently don't de-duplicate the blocks in the cache, - # meaning that if a block becomes full and is cached, we don't check - # if there is already an identical block in the cache. This is because - # we want to make sure the allocated block IDs won't change so that - # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[ - int, KVCacheBlock]] = defaultdict(dict) + # Cache for block lookup + self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap() # To represent a placeholder block with block_id=0. 
# The ref_cnt of null_block is not maintained, needs special care to @@ -68,9 +167,9 @@ class BlockPool: self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block( - self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: - """Get the cached block by the block hash for each group in + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> Optional[list[KVCacheBlock]]: + """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -83,12 +182,15 @@ class BlockPool: """ cached_blocks = [] for group_id in kv_cache_group_ids: - cached_blocks_one_group = self.cached_block_hash_to_block.get( - BlockHashWithGroupId(block_hash, group_id)) - if not cached_blocks_one_group: + block_hash_with_group_id = make_block_hash_with_group_id( + block_hash, group_id + ) + block = self.cached_block_hash_to_block.get_one_block( + block_hash_with_group_id + ) + if not block: return None - first_block = next(iter(cached_blocks_one_group.values())) - cached_blocks.append(first_block) + cached_blocks.append(block) return cached_blocks def cache_full_blocks( @@ -117,46 +219,50 @@ class BlockPool: block_size: Number of tokens in each block. kv_cache_group_id: The id of the KV cache group. """ - if num_cached_blocks == num_full_blocks: + if num_cached_blocks >= num_full_blocks: return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(request.block_hashes) >= num_full_blocks new_block_hashes = request.block_hashes[num_cached_blocks:] - new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events - else None) + new_hashes: Optional[list[ExternalBlockHash]] = ( + [] if self.enable_kv_cache_events else None + ) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. - block_hash_with_group_id = BlockHashWithGroupId( - block_hash, kv_cache_group_id) + block_hash_with_group_id = make_block_hash_with_group_id( + block_hash, kv_cache_group_id + ) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block[block_hash_with_group_id][ - blk.block_id] = blk + self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: - new_hashes.append(block_hash.hash_value) + new_hashes.append(maybe_convert_block_hash(block_hash)) if self.enable_kv_cache_events: if num_cached_blocks == 0: - parent_block_hash = None + parent_block_hash: Optional[ExternalBlockHash] = None else: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None - parent_block_hash = parent_block.block_hash.get_hash_value() + parent_block_hash = maybe_convert_block_hash( + get_block_hash(parent_block.block_hash) + ) self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, parent_block_hash=parent_block_hash, - token_ids=request. - all_token_ids[num_cached_blocks * - block_size:num_full_blocks * block_size], + token_ids=request.all_token_ids[ + num_cached_blocks * block_size : num_full_blocks * block_size + ], block_size=block_size, - lora_id=request.lora_request.id - if request.lora_request else None, - )) + lora_id=request.lora_request.id if request.lora_request else None, + medium=MEDIUM_GPU, + ) + ) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -170,8 +276,7 @@ class BlockPool: A list of new block. 
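# --- Illustrative sketch (not part of the patch) -----------------------------
# Worked example of the index arithmetic in cache_full_blocks above: only
# blocks [num_cached_blocks, num_full_blocks) are newly cached, and the
# BlockStored event carries exactly the token ids covered by those blocks.
block_size = 4
num_cached_blocks, num_full_blocks = 1, 3
all_token_ids = list(range(14))                       # 14 tokens: 3 full blocks + remainder

new_block_range = range(num_cached_blocks, num_full_blocks)      # blocks 1 and 2
event_token_ids = all_token_ids[num_cached_blocks * block_size:
                                num_full_blocks * block_size]    # tokens 4..11

print(list(new_block_range))   # [1, 2]
print(event_token_ids)         # [4, 5, 6, 7, 8, 9, 10, 11]
# -----------------------------------------------------------------------------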
""" if num_blocks > self.get_num_free_blocks(): - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") + raise ValueError(f"Cannot get {num_blocks} free blocks from the pool") ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks) @@ -202,15 +307,13 @@ class BlockPool: if block_hash is None: # The block doesn't have hash, eviction is not needed return False - blocks_by_id = self.cached_block_hash_to_block.get(block_hash) - if blocks_by_id is None: - # block_hash not found in cached_block_hash_to_block, + + if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None: + # block not found in cached_block_hash_to_block, # eviction is not needed return False + block.reset_hash() - blocks_by_id.pop(block.block_id, None) - if len(blocks_by_id) == 0: - del self.cached_block_hash_to_block[block_hash] if self.enable_kv_cache_events: # FIXME (Chen): Not sure whether we should return `hash_value` @@ -218,7 +321,11 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + BlockRemoved( + block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))], + medium=MEDIUM_GPU, + ) + ) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -249,10 +356,9 @@ class BlockPool: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n([ - block for block in blocks_list - if block.ref_cnt == 0 and not block.is_null - ]) + self.free_block_queue.append_n( + [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] + ) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -267,11 +373,13 @@ class BlockPool: if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks - 1) + "blocks (%d) are not freed yet", + num_used_blocks - 1, + ) return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = BlockHashToBlockMap() # Remove all hashes from all blocks. for block in self.blocks: @@ -307,7 +415,7 @@ class BlockPool: def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. - + Returns: A list of KV cache events. """ diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 0b9da60c67dee..c70025992e70c 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import OrderedDict from collections.abc import Mapping from typing import TYPE_CHECKING @@ -31,34 +33,52 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Note that no caching is shared between requests at this time. If the same - input is used across multiple requests, it will be reprocessed for each - request. - + Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free + space for new embeddings. 
+ Oldest cached embeddings with no request referenced will be first evicted. + Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens - num_free_slots: Current available cache capacity in encoder tokens - cached: Mapping from request_id to set of cached input_ids for that - request - freed: List of (request_id, input_id) pairs that were recently freed. - This is cleared after every call to get_freed_ids(). + cache_size: Total cache capacity in encoder tokens. + num_free_slots: Current available cache capacity in encoder tokens. + num_freeable_slots: Capacity that can be immediately reclaimed by + evicting entries with zero references (in encoder tokens). + cached: Mapping from mm_hash to a set of request IDs that currently + reference the cached entry. If the set is empty, the entry exists + but is not referenced by any request and is eligible for + reclamation. + freeable: List of tuples (mm_hash, num_tokens) representing entries + whose no current running request is needed and that can be freed to + make space when needed. + freed: List of mm_hash strings that were actually evicted since the + last call to get_freed_mm_hashes(). This list is cleared on return. """ def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size - # req_id -> cached input ids - self.cached: dict[str, set[int]] = {} - # list of [req_id, input_id] - self.freed: list[tuple[str, int]] = [] + self.num_freeable_slots = cache_size - def has_cache(self, request: Request, input_id: int) -> bool: + # mm_hash of mm_data => ids of requests that reference the mm_data + self.cached: dict[str, set[str]] = {} + + # mm_hash of mm_data => num_encoder_tokens of the mm_data + self.freeable: OrderedDict[str, int] = OrderedDict() + self.freed: list[str] = [] + + def check_and_update_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. + If the encoder output is cached, update `cached` to add the request id + to the set of request ids that reference the cached encoder output. + If the encoder output was previously not referenced by any request, + update `freeable` and `num_freeable_slots` accordingly. + Args: request: The request containing the multimodal input input_id: Index of the multimodal input within the request @@ -66,103 +86,163 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - req_id = request.request_id - return req_id in self.cached and input_id in self.cached[req_id] + mm_hash = request.mm_features[input_id].identifier + # Not cached at all + if mm_hash not in self.cached: + return False - def can_allocate(self, request: Request, input_id: int) -> bool: + # Cached but currently not referenced by any request + if not self.cached[mm_hash]: + num_tokens = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_tokens + + self.cached[mm_hash].add(request.request_id) + return True + + def can_allocate( + self, + request: Request, + input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int, + ) -> bool: """Check if there's sufficient cache space for a multimodal input. + If there is, return True and update EncoderCacheManager state. 
+ + If there is not enough free space in `num_free_slots` but there is + enough reclaimable space in `num_freeable_slots`, entries will be + evicted from `freeable` (their mm_hash appended to `freed`) until + enough space is available, and then this method returns True. + Older entries are evicted first. + + Returns False only if the requested number of tokens exceeds both + the free and reclaimable capacities combined. Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + request: The request containing the multimodal input. + input_id: Index of the multimodal input within the request. + encoder_compute_budget: Number of encoder tokens allowed to be + computed when this method is invoked. + num_tokens_to_schedule: Number of tokens already scheduled to be + allocated with cache space when this method is invoked. Returns: - True if there's enough free cache space to store the encoder output - for this multimodal input + True if there's enough capacity to hold the encoder output for this + input (possibly after reclaiming `freeable` entries); otherwise + False. + + Note: This method does not allocate physical memory for the encoder + output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) - return num_tokens <= self.num_free_slots + + # Not enough compute budget + if num_tokens > encoder_compute_budget: + return False + + num_tokens += num_tokens_to_schedule + + # Enough free slots + if num_tokens <= self.num_free_slots: + return True + + # Not enough reclaimable slots + if num_tokens > self.num_freeable_slots: + return False + + # Not enough free slots but enough reclaimable slots + # NOTE: Eviction takes place here, but physical memory is not freed + # until model runner is notified by the scheduler output. + while num_tokens > self.num_free_slots: + mm_hash, num_free_token = self.freeable.popitem(last=False) + del self.cached[mm_hash] + self.freed.append(mm_hash) + self.num_free_slots += num_free_token + return True def allocate(self, request: Request, input_id: int) -> None: """Allocate cache space for a multimodal input's encoder output. - This method reserves cache space for storing the encoder output of - the specified multimodal input. The actual encoder output storage - happens in the model runner, but this method ensures the cache - manager tracks the allocation. - - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + This reserves cache space for storing the encoder output of the + specified multimodal input. The actual encoder output storage happens in + the model runner; this method updates the manager's bookkeeping. Note: - This method assumes can_allocate() returned True for the same - request and input_id. It will reduce available cache space. + This method assumes can_allocate() returned True for the same input. """ - req_id = request.request_id - if req_id not in self.cached: - self.cached[req_id] = set() - self.cached[req_id].add(input_id) - self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + mm_hash = request.mm_features[input_id].identifier + request_id = request.request_id + if mm_hash not in self.cached: + self.cached[mm_hash] = set() + + num_encoder_tokens = request.get_num_encoder_tokens(input_id) + + # NOTE: Encoder cache should always have enough space for encoder inputs + # that are scheduled since eviction takes place at can_allocate(). 
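# --- Illustrative sketch (not part of the patch) -----------------------------
# Simplified standalone model of the can_allocate() decision above: reject if
# the item exceeds the compute budget, accept if it fits in the free slots,
# otherwise evict the oldest unreferenced entries until it fits, and reject
# only if even the reclaimable capacity is insufficient. Module-level globals
# stand in for the manager's attributes.
from collections import OrderedDict

num_free_slots = 100
num_freeable_slots = 250
freeable = OrderedDict([("img_a", 80), ("img_b", 70)])   # oldest entry first
freed: list[str] = []


def can_allocate(num_tokens: int, compute_budget: int, already_scheduled: int) -> bool:
    global num_free_slots
    if num_tokens > compute_budget:
        return False
    needed = num_tokens + already_scheduled
    if needed <= num_free_slots:
        return True
    if needed > num_freeable_slots:
        return False
    while needed > num_free_slots:            # reclaim oldest unreferenced entries
        mm_hash, size = freeable.popitem(last=False)
        freed.append(mm_hash)
        num_free_slots += size
    return True


print(can_allocate(num_tokens=150, compute_budget=512, already_scheduled=0))  # True
print(freed)   # ['img_a'] -- one eviction was enough (100 + 80 >= 150)
# -----------------------------------------------------------------------------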
+ assert self.num_free_slots >= num_encoder_tokens + assert self.num_freeable_slots >= num_encoder_tokens + + self.cached[mm_hash].add(request_id) + self.num_free_slots -= num_encoder_tokens + self.num_freeable_slots -= num_encoder_tokens def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. - Args: - request: The request to query - - Returns: - Set of input_ids that have cached encoder outputs for this request. - Returns empty set if no inputs are cached for this request. + Returns the set of input IDs whose `mm_hash` exists in the cache map. + This includes entries that are currently unreferenced (and thus present + in `freeable`); for such entries, freeing for this request will be a + no-op. """ - return self.cached.get(request.request_id, set()) + return { + input_id + for input_id in range(len(request.mm_features)) + if request.mm_features[input_id].identifier in self.cached + } def free_encoder_input(self, request: Request, input_id: int) -> None: - """Free cache space for a single multimodal input's encoder output. + """Free the request's reference to the encoder input (`mm_data`) - This method is called when: - - The encoder output has been fully consumed by the decoder and is - no longer needed (e.g., in vision-language models after image - tokens are processed) - - A request is being cancelled or aborted + When the reference set for the corresponding `mm_hash` becomes empty, + the entry is appended to `freeable` and `num_freeable_slots` is + increased by the number of encoder tokens for that input. - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input to free from cache + The entry is NOT physically freed until capacity is needed (e.g., by + `can_allocate`). """ req_id = request.request_id - if req_id not in self.cached: + mm_hash = request.mm_features[input_id].identifier + # The mm_hash not in cache or the req_id set is empty + if not self.cached.get(mm_hash, None): return - - self.cached[req_id].discard(input_id) - if len(self.cached[req_id]) == 0: - del self.cached[req_id] - self.num_free_slots += request.get_num_encoder_tokens(input_id) - self.freed.append((req_id, input_id)) + self.cached[mm_hash].discard(req_id) + if not self.cached[mm_hash]: + num_tokens = request.get_num_encoder_tokens(input_id) + self.freeable[mm_hash] = num_tokens + self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: - """Free all cached encoder outputs for a request. + """Free all encoder input cache reference held by *request*. - This method is typically called when a request is finished, cancelled, - or aborted, and all its encoder outputs should be freed from cache. + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future + attempt allocation called by 'can_allocate'. - Args: - request: The request whose encoder outputs should be freed + Typically called when a request is finished, cancelled, or aborted. """ input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> list[tuple[str, int]]: + def get_freed_mm_hashes(self) -> list[str]: """Get and clear the list of recently freed encoder cache entries. - This method returns all encoder cache entries that were freed since - the last call to this method. 
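# --- Illustrative sketch (not part of the patch) -----------------------------
# Minimal model of the reference bookkeeping described above: each mm_hash
# tracks the set of request ids using it, and when the last reference is
# dropped the entry is only parked in `freeable` (not evicted), so a later
# request with the same hash can still reuse the cached embedding.
from collections import OrderedDict

cached: dict[str, set[str]] = {}
freeable: OrderedDict[str, int] = OrderedDict()


def allocate(mm_hash: str, req_id: str) -> None:
    cached.setdefault(mm_hash, set()).add(req_id)


def free_encoder_input(mm_hash: str, req_id: str, num_tokens: int) -> None:
    refs = cached.get(mm_hash)
    if not refs:
        return
    refs.discard(req_id)
    if not refs:                      # last reference gone -> reclaimable, not freed
        freeable[mm_hash] = num_tokens


allocate("img_a", "req-1")
allocate("img_a", "req-2")
free_encoder_input("img_a", "req-1", num_tokens=576)
free_encoder_input("img_a", "req-2", num_tokens=576)
print(cached["img_a"], dict(freeable))   # set() {'img_a': 576}
# -----------------------------------------------------------------------------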
It's used by the scheduler to notify - workers about which encoder outputs can be removed from their caches. - Returns: - List of (request_id, input_id) tuples that were freed since the - last call. The internal freed list is cleared after this call. + List of mm_hash strings that were actually evicted since the last + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -174,23 +254,19 @@ def compute_encoder_budget( scheduler_config: "SchedulerConfig", mm_registry: MultiModalRegistry, ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations. - Args: - model_config: Model configuration. - scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. - Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ if mm_registry.supports_multimodal_inputs(model_config): - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) + ) return compute_mm_encoder_budget( scheduler_config, @@ -200,18 +276,17 @@ def compute_encoder_budget( return compute_text_encoder_budget(scheduler_config) -def compute_text_encoder_budget( - scheduler_config: "SchedulerConfig") -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler +def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler configurations for a text-only model. Args: scheduler_config: Scheduler configuration. Returns: - - Compute budget for encoder execution, in unit of number of tokens + - Compute budget for encoder execution, in unit of number of tokens in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens + - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ # Currently text-only encoder-decoder models are not supported @@ -222,7 +297,7 @@ def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: @@ -231,32 +306,38 @@ def compute_mm_encoder_budget( non-text modality. Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. 
""" if not max_tokens_by_modality: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " - "not be initialized.") + "not be initialized." + ) return 0, 0 max_tokens_per_mm_item = max(max_tokens_by_modality.values()) - if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item - > scheduler_config.max_num_batched_tokens): + if ( + scheduler_config.disable_chunked_mm_input + and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens + ): raise ValueError( "Chunked MM input disabled but max_tokens_per_mm_item " f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens" f" ({scheduler_config.max_num_batched_tokens}). Please increase " - "max_num_batched_tokens.") + "max_num_batched_tokens." + ) - encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, - max_tokens_per_mm_item) - encoder_cache_size = max(scheduler_config.encoder_cache_size, - max_tokens_per_mm_item) + encoder_compute_budget = max( + scheduler_config.max_num_encoder_input_tokens, max_tokens_per_mm_item + ) + encoder_cache_size = max( + scheduler_config.encoder_cache_size, max_tokens_per_mm_item + ) return encoder_compute_budget, encoder_cache_size diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index a0ea4d96015a2..ef6da9adeea70 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,9 +6,11 @@ from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + CrossAttentionManager, + FullAttentionManager, + get_manager_for_kv_cache_spec, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.request import Request @@ -24,13 +26,15 @@ class KVCacheCoordinator(ABC): use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, + dcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - enable_kv_cache_events) + self.block_pool = BlockPool( + kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events + ) # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle @@ -39,34 +43,50 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, - ) for i, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups)) + dcp_world_size=dcp_world_size, + ) + for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) + ) def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[list[KVCacheBlock], ...], + num_encoder_tokens: int, + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). 
new_computed_blocks: The new computed blocks just hitting the prefix caching. + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The number of blocks. """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_encoder_tokens, [] + ) + else: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i] + ) return num_blocks_to_allocate def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None: + self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] + ) -> None: """ Add the new computed blocks to the request. @@ -76,26 +96,34 @@ class KVCacheCoordinator(ABC): prefix cache. """ for i, manager in enumerate(self.single_type_managers): - manager.save_new_computed_blocks(request_id, - new_computed_blocks[i]) + manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 + ) -> tuple[list[KVCacheBlock], ...]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). + num_encoder_tokens: The number of encoder tokens for allocating + blocks for cross-attention. Returns: The new allocated blocks. """ return tuple( - manager.allocate_new_blocks(request_id, num_tokens) - for manager in self.single_type_managers) + manager.allocate_new_blocks( + request_id, + num_encoder_tokens + if isinstance(manager, CrossAttentionManager) + else num_tokens, + ) + for manager in self.single_type_managers + ) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ @@ -103,7 +131,8 @@ class KVCacheCoordinator(ABC): Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_computed_tokens: The total number of tokens + that need to be cached (including tokens that are already cached). """ for manager in self.single_type_managers: @@ -119,32 +148,26 @@ class KVCacheCoordinator(ABC): for manager in self.single_type_managers: manager.free(request_id) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]: """ - Get the number of common prefix blocks for all requests in the RUNNING - state for each kv cache group. + Get the number of common prefix blocks for all requests with allocated + KV cache for each kv cache group. Args: - request_id: The request ID. - num_running_requests: The total number of requests in the RUNNING - state. + running_request_id: The request ID of any running request, used to + identify the common prefix blocks. 
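# --- Illustrative sketch (not part of the patch) -----------------------------
# Simplified sketch of the per-manager dispatch described above: the
# cross-attention manager sizes its one-shot allocation from the encoder token
# count, while every other manager uses the decoder token count. The classes
# below are dummies standing in for the single-type KV cache managers.
class Manager:
    def blocks_needed(self, num_tokens: int, block_size: int = 16) -> int:
        return -(-num_tokens // block_size)          # ceil division


class CrossAttentionManager(Manager):
    pass


def num_blocks_to_allocate(managers, num_tokens: int, num_encoder_tokens: int) -> int:
    total = 0
    for manager in managers:
        if isinstance(manager, CrossAttentionManager):
            total += manager.blocks_needed(num_encoder_tokens)
        else:
            total += manager.blocks_needed(num_tokens)
    return total


managers = [Manager(), CrossAttentionManager()]
print(num_blocks_to_allocate(managers, num_tokens=100, num_encoder_tokens=1500))
# 7 decoder blocks + 94 encoder blocks = 101
# -----------------------------------------------------------------------------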
Returns: - list[int]: The number of common prefix blocks for all requests in - the RUNNING state for each kv cache group. + list[int]: The number of common prefix blocks for each kv cache group. """ - num_blocks_per_group = [ - manager.get_num_common_prefix_blocks(request_id, - num_running_requests) + return [ + manager.get_num_common_prefix_blocks(running_request_id) for manager in self.single_type_managers ] - return num_blocks_per_group - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and replace + Remove the blocks that are no longer needed from `blocks` and replace the removed blocks with null_block. Args: @@ -160,7 +183,8 @@ class KVCacheCoordinator(ABC): """ return tuple( manager.req_to_blocks.get(request_id) or [] - for manager in self.single_type_managers) + for manager in self.single_type_managers + ) @abstractmethod def find_longest_cache_hit( @@ -179,14 +203,25 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): Does not implement any features related to prefix caching. """ - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, False, - enable_kv_cache_events) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) self.num_single_type_manager = len(self.single_type_managers) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]: return [0] * self.num_single_type_manager def find_longest_cache_hit( @@ -195,7 +230,8 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(self.num_single_type_manager)) + [] for _ in range(self.num_single_type_manager) + ) return blocks, 0 @@ -206,16 +242,31 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): full attention or all attention layers use sliding window attention. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) - self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if dcp_world_size > 1: + self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "UnitaryKVCacheCoordinator assumes only one kv cache group") + "UnitaryKVCacheCoordinator assumes only one kv cache group" + ) def find_longest_cache_hit( self, @@ -229,6 +280,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, + dcp_world_size=self.dcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -237,21 +289,34 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and thus multiple kv cache groups. - To simplify `find_longest_cache_hit`, it only supports the combination of + To simplify `find_longest_cache_hit`, it only supports the combination of two types of KV cache groups, and one of them must be full attention. May extend to more general cases in the future. """ - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: """ - Verifies that the model has exactly two types of KV cache groups, and + Verifies that the model has exactly two types of KV cache groups, and one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ @@ -266,7 +331,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): else: assert full_attention_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes exactly one type of " - "full attention groups now.") + "full attention groups now." + ) self.full_attention_group_ids.append(i) else: if other_spec is None: @@ -274,19 +340,22 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): else: assert other_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes " - "exactly one other type of groups now.") + "exactly one other type of groups now." 
+ ) self.other_group_ids.append(i) assert full_attention_spec is not None, ( "HybridKVCacheCoordinator assumes exactly one type of full " - "attention groups now.") + "attention groups now." + ) assert other_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of other " - "groups now.") + "HybridKVCacheCoordinator assumes exactly one type of other groups now." + ) self.full_attention_manager_cls = FullAttentionManager self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0]].__class__ + self.other_group_ids[0] + ].__class__ self.full_attention_spec = full_attention_spec self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size @@ -297,7 +366,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): divisible = self.other_block_size % self.full_attention_block_size assert divisible == 0, ( "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now.") + "attention layers is divisible by other layers now." + ) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -310,7 +380,8 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): "do not interleave, either full attention group ids " "are before other attention group ids or vice versa." "This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks.") + "hit_blocks_other_attn to hit_blocks." + ) def find_longest_cache_hit( self, @@ -330,29 +401,26 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. - hit_blocks_full_attn = ( - self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.full_attention_spec, - use_eagle=self.use_eagle, - )) - hit_length = len( - hit_blocks_full_attn[0]) * self.full_attention_block_size + hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + ) + hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. - hit_blocks_other_attn = ( - self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.other_spec, - use_eagle=self.use_eagle, - )) + hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.other_spec, + use_eagle=self.use_eagle, + ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size # NOTE: the prefix cache hit length must be a multiple of block_size as @@ -367,7 +435,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): # Truncate the full attention cache hit to the length of the # cache hit of the other attention. 
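# --- Illustrative sketch (not part of the patch) -----------------------------
# Worked example, with made-up block sizes, of the two-stage hybrid cache hit
# above and the truncation performed just below: the full-attention hit bounds
# the search window for the other attention type, and the other type's hit
# length then determines how many full-attention blocks are kept.
full_attn_block_size = 16
other_block_size = 32      # e.g. a sliding-window group with a larger block

full_attn_hit_blocks = 10                                   # 160 tokens hit
full_attn_hit_tokens = full_attn_hit_blocks * full_attn_block_size

other_hit_blocks = 4                                        # found within those 160 tokens
hit_tokens = other_hit_blocks * other_block_size            # 128 tokens: final hit length

kept_full_attn_blocks = hit_tokens // full_attn_block_size  # truncate 10 -> 8 blocks
print(hit_tokens, kept_full_attn_blocks)                    # 128 8
# -----------------------------------------------------------------------------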
for group_hit_blocks in hit_blocks_full_attn: - del group_hit_blocks[hit_length // self.full_attention_block_size:] + del group_hit_blocks[hit_length // self.full_attention_block_size :] # Merge the hit blocks of full attention and other attention. if self.full_attn_first: @@ -378,16 +446,35 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def get_kv_cache_coordinator( - kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool) -> KVCacheCoordinator: + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, +) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, - enable_kv_cache_events) + return KVCacheCoordinatorNoPrefixCache( + kv_cache_config, + max_model_len, + use_eagle, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - enable_kv_cache_events) - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + return UnitaryKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + return HybridKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index fd0bdb2c80fc5..b74ccd30b97b3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import Request logger = init_logger(__name__) @@ -22,46 +22,47 @@ class KVCacheBlocks: Scheduler and KVCacheManager, to hide KVCacheManager's internal data structure from the Scheduler. """ + blocks: tuple[list[KVCacheBlock], ...] """ - blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. - We don't use block of tokens as the outer dimension because it assumes all - kv_cache_groups have the same number of blocks, which is true for now but - will be broken if we want to give different block_size to different + `blocks[i][j]` refers to the i-th kv_cache_group + and the j-th block of tokens.We don't use block of + tokens as the outer dimension because it assumes all + kv_cache_groups have the same number of blocks, which is true for now but + will be broken if we want to give different block_size to different kv_cache_groups in the future. """ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - tuple(blk1 + blk2 - for blk1, blk2 in zip(self.blocks, other.blocks))) + tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) + ) @overload def get_block_ids( self, allow_none: Literal[False] = False, - ) -> tuple[list[int], ...]: - ... + ) -> tuple[list[int], ...]: ... @overload def get_block_ids( self, allow_none: Literal[True] = True, - ) -> Optional[tuple[list[int], ...]]: - ... 
+ ) -> Optional[tuple[list[int], ...]]: ... def get_block_ids( self, allow_none: bool = False, - ): + ) -> Optional[tuple[list[int], ...]]: """ Converts the KVCacheBlocks instance to block_ids. - + Returns: - tuple[list[int], ...]: A tuple of lists where - * the outer tuple corresponds to KV cache groups - * each inner list contains the block_ids of the blocks in that group + tuple[list[int], ...]: A tuple of lists where: + - the outer tuple corresponds to KV cache groups + - each inner list contains the block_ids of the blocks in that + group """ if allow_none and all(len(group) == 0 for group in self.blocks): return None @@ -70,10 +71,7 @@ class KVCacheBlocks: def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" assert len(self.blocks) == 1, "Only one group is supported" - return [ - block.block_id for block in self.blocks[0] - if block.block_hash is None - ] + return [block.block_id for block in self.blocks[0] if block.block_hash is None] def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" @@ -81,7 +79,6 @@ class KVCacheBlocks: class KVCacheManager: - def __init__( self, kv_cache_config: KVCacheConfig, @@ -90,6 +87,7 @@ class KVCacheManager: use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, + dcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -101,12 +99,25 @@ class KVCacheManager: self.block_size: Optional[int] = None if self.enable_caching: - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" + assert ( + len( + set( + g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups + ) + ) + == 1 + ), "Only one block size is supported for now" self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + 0 + ].kv_cache_spec.block_size + + if dcp_world_size > 1: + assert len(kv_cache_config.kv_cache_groups) == 1 + # Note(hc): need revisit. When both DCP and any future + # PCP are enabled, the block_size may need to be scaled + # by a factor of dcp_size × pcp_size? + self.block_size *= dcp_world_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -114,6 +125,7 @@ class KVCacheManager: use_eagle=self.use_eagle, enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, + dcp_world_size=dcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool @@ -140,8 +152,7 @@ class KVCacheManager: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, - request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -155,9 +166,10 @@ class KVCacheManager: """ # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. 
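# --- Editor's illustrative sketch (not part of the patch) -------------------
# The KVCacheBlocks layout documented above: blocks[i][j] is the j-th block of
# the i-th kv_cache_group, and get_block_ids() keeps that per-group structure
# while extracting ids. The tiny _Blk type is a stand-in for KVCacheBlock.
from dataclasses import dataclass

@dataclass
class _Blk:
    block_id: int

_groups = ([_Blk(0), _Blk(3)], [_Blk(1)])               # two kv_cache_groups
_block_ids = tuple([b.block_id for b in g] for g in _groups)
assert _block_ids == ([0, 3], [1])
# ----------------------------------------------------------------------------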
- if (not self.enable_caching - or (request.sampling_params is not None - and request.sampling_params.prompt_logprobs is not None)): + if not self.enable_caching or ( + request.sampling_params is not None + and request.sampling_params.prompt_logprobs is not None + ): return self.create_empty_block_list(), 0 # NOTE: When all tokens hit the cache, we must recompute the last token @@ -168,14 +180,23 @@ class KVCacheManager: # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request.block_hashes, - max_cache_hit_length)) + self.coordinator.find_longest_cache_hit( + request.block_hashes, max_cache_hit_length + ) + ) if self.log_stats: assert self.prefix_cache_stats is not None - self.prefix_cache_stats.requests += 1 - self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_new_computed_tokens + if request.num_preemptions > 0: + # Previously preempted request + self.prefix_cache_stats.preempted_requests += 1 + self.prefix_cache_stats.preempted_queries += request.num_tokens + self.prefix_cache_stats.preempted_hits += num_new_computed_tokens + else: + # New request + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += request.num_tokens + self.prefix_cache_stats.hits += num_new_computed_tokens return KVCacheBlocks(computed_blocks), num_new_computed_tokens @@ -187,6 +208,7 @@ class KVCacheManager: new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, + num_encoder_tokens: int = 0, ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. @@ -197,10 +219,10 @@ class KVCacheManager: already been computed locally (i.e. new_computed_blocks). num_new_computed_tokens: The number of new computed tokens just hitting the prefix caching, excluding external tokens. - new_computed_blocks: The cached blocks for the above new computed + new_computed_blocks: The cached blocks for the above new computed tokens. num_lookahead_tokens: The number of speculative tokens to allocate. - This is used by spec decode proposers with kv-cache such + This is used by spec decode proposers with kv-cache such as eagle. delay_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer @@ -230,7 +252,8 @@ class KVCacheManager: new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = tuple( - [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -238,21 +261,23 @@ class KVCacheManager: # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
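# --- Editor's illustrative sketch (not part of the patch) -------------------
# How the prefix-cache stats update above splits first-time and previously
# preempted requests. `_Stats` is a stand-in for PrefixCacheStats.
from dataclasses import dataclass

@dataclass
class _Stats:
    requests: int = 0
    queries: int = 0
    hits: int = 0
    preempted_requests: int = 0
    preempted_queries: int = 0
    preempted_hits: int = 0

def _record(stats: _Stats, num_preemptions: int, num_tokens: int, num_hits: int) -> None:
    if num_preemptions > 0:            # request resumed after preemption
        stats.preempted_requests += 1
        stats.preempted_queries += num_tokens
        stats.preempted_hits += num_hits
    else:                              # brand-new request
        stats.requests += 1
        stats.queries += num_tokens
        stats.hits += num_hits
# ----------------------------------------------------------------------------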
- self.coordinator.remove_skipped_blocks(request.request_id, - request.num_computed_tokens) + self.coordinator.remove_skipped_blocks( + request.request_id, request.num_computed_tokens + ) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) + num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, - self.max_model_len) + self.max_model_len, + ) num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, + num_encoder_tokens=num_encoder_tokens, ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -264,16 +289,18 @@ class KVCacheManager: self.block_pool.touch(new_computed_block_list) else: assert not any(new_computed_block_list), ( - "Computed blocks should be empty when " - "prefix caching is disabled") + "Computed blocks should be empty when prefix caching is disabled" + ) # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_block_list) + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot) + request.request_id, num_tokens_need_slot, num_encoder_tokens + ) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -284,15 +311,16 @@ class KVCacheManager: # num_new_tokens, but must exclude "non-committable" tokens (e.g., # draft tokens that could be rejected). Therefore, we cap the number # at `request.num_tokens`, ensuring only "finalized" tokens are cached. - num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, - request.num_tokens) + num_tokens_to_cache = min( + num_computed_tokens + num_new_tokens, request.num_tokens + ) self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. - We free the blocks in reverse order so that he tail blocks are evicted + We free the blocks in reverse order so that the tail blocks are evicted first when caching is enabled. Args: @@ -316,48 +344,39 @@ class KVCacheManager: self.prefix_cache_stats.reset = True return True - def get_num_common_prefix_blocks( - self, - request: Request, - num_running_requests: int, - ) -> list[int]: - """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state for each kv cache group. + def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]: + """Calculate the number of common prefix blocks for each kv cache group. - The function determines this by selecting any request and iterating - through its blocks. A block is considered a common prefix block if its - `ref_cnt` equals the total number of requests in the RUNNING state. + The function selects a running request and iterates through its blocks. + A block is considered a common prefix block if ALL requests with + allocated KV cache share it (i.e., ref_cnt equals the number of entries + in req_to_blocks). 
- NOTE(woosuk): The number of requests in the RUNNING state is **greater + NOTE(woosuk): The number of requests with allocated KV cache is **greater than or equal to** the number of requests scheduled in the current step. - This is because the RUNNING state only indicates that: + This is because having allocated KV cache only indicates that: 1. The request has not yet finished, and 2. The request holds its blocks unfreed. - While all scheduled requests must be in the RUNNING state, the inverse - is not necessarily true. There may be RUNNING requests that are not - scheduled in the current step. + While all scheduled requests must have allocated KV cache, the inverse + is not necessarily true. There may be requests with allocated KV cache + that are not scheduled in the current step. This can result in an edge case where the number of common prefix blocks is 0, even though all scheduled requests share a common prefix. This - occurs because there may be unscheduled RUNNING requests that do not - share the common prefix. Currently, this case cannot be easily detected, - so the function returns 0 in such cases. + occurs because there may be unscheduled requests that do not share the + common prefix. Currently, this case cannot be easily detected, so the + function returns 0 in such cases. Args: - request: Any request in the RUNNING state, used to identify the - common prefix blocks. - num_running_requests: The total number of requests in the RUNNING - state. This can be different from the number of scheduled - requests in the current step. + running_request_id: The request ID of any running request, used to + identify the common prefix blocks. Returns: - list[int]: The number of common prefix blocks for each kv cache + list[int]: The number of common prefix blocks for each kv cache group. """ - assert request.status == RequestStatus.RUNNING - return self.coordinator.get_num_common_prefix_blocks( - request.request_id, num_running_requests) + return self.coordinator.get_num_common_prefix_blocks(running_request_id) def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. 
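# --- Editor's illustrative sketch (not part of the patch) -------------------
# The counting rule from the docstring above: walking one running request's
# blocks from the start, a block is a common prefix block iff its ref_cnt
# equals the number of requests that currently hold allocated KV cache.
def _num_common_prefix_blocks(ref_cnts: list[int], num_requests_with_kv: int) -> int:
    count = 0
    for ref_cnt in ref_cnts:
        if ref_cnt != num_requests_with_kv:
            break
        count += 1
    return count

# e.g. ref_cnts [3, 3, 2, 1] with 3 live requests -> 2 common prefix blocks
assert _num_common_prefix_blocks([3, 3, 2, 1], 3) == 2
# ----------------------------------------------------------------------------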
@@ -382,5 +401,4 @@ class KVCacheManager: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] - for _ in range(self.num_kv_cache_groups))) + return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6a62c55fb2d5f..7a602b9936855 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -2,148 +2,114 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """KV-Cache Utilities.""" +import copy import os -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Iterable, Sequence -from dataclasses import astuple, dataclass -from typing import Any, Callable, NamedTuple, Optional +from dataclasses import dataclass +from typing import Any, Callable, NewType, Optional, Union +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) -from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.utils import GiB_bytes, cdiv, sha256_cbor +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.request import Request -logger = init_logger(__name__) +# BlockHash represents the hash of a single KV-cache block used for +# prefix caching. Treating it as a distinct type from ``bytes`` helps +# catch accidental misuse when passing around raw byte strings. +BlockHash = NewType("BlockHash", bytes) + +# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID. +# It is represented as raw bytes for compactness and efficiency. The helper +# functions below pack/unpack the ``BlockHash`` and group id into/from the key. +BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) + +# ExternalBlockHash is used for reproducible prefix-cache block hashing. +# It's a union of ``bytes`` and ``int`` to keep backward compatibility +# after we default block hashing to use sha256 bytes. +ExternalBlockHash = Union[bytes, int] -class BlockHash(NamedTuple): - """Hash value of a block (int), the token IDs in the block, and extra keys. - We keep a tuple of token IDs and extra keys to reduce the likelihood of - hash collisions when the hash value is the same. By using SHA256 however, - hash collisions are practically impossible. +def make_block_hash_with_group_id( + block_hash: BlockHash, group_id: int +) -> BlockHashWithGroupId: + """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. + + The group id is encoded using 4 bytes in big-endian order and appended to + the block hash bytes. This representation avoids creating tuples while + still allowing us to recover both components when needed. """ - # Hash value of the block in an integer. - hash_value: int - # Token IDs in the block. - token_ids: tuple[int, ...] - # Extra keys for the block. - extra_keys: Optional[Any] = None + return BlockHashWithGroupId(block_hash + group_id.to_bytes(4, "big", signed=False)) -class BlockHashWithGroupId(NamedTuple): - # The hash value for the contents (e.g., token_ids) of a block without group - # ID. 
The value is the same for blocks representing the same tokens but for - # different groups. - block_hash: BlockHash - # The KV cache group ID. - group_id: int +def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: + """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``.""" + return BlockHash(key[:-4]) - def get_hash_value(self) -> int: - return self.block_hash.hash_value +def get_group_id(key: BlockHashWithGroupId) -> int: + """Extract the group id from a ``BlockHashWithGroupId``.""" + return int.from_bytes(key[-4:], "big", signed=False) + + +def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: + if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: + return hash_bytes + return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1) + + +logger = init_logger(__name__) # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment -# variable if set such that processes can share the seed if needed. -# This aligns with the behavior of Python's hash() function, which also uses -# a random seed if PYTHONHASHSEED is not set. +# variable if set such that processes can share the seed if needed. This aligns +# with the behavior of Python's hash() function, which also uses a random seed +# if PYTHONHASHSEED is not set. # # The function `init_none_hash` initializes this variable globally. -NONE_HASH: int +NONE_HASH: BlockHash -def init_none_hash(hash_fn: Callable): +def init_none_hash(hash_fn: Callable[[Any], bytes]): global NONE_HASH hash_seed = os.getenv("PYTHONHASHSEED") - if hash_seed is None and hash_fn is sha256_cbor_64bit: + if hash_seed is None and hash_fn is sha256_cbor: logger.warning( "PYTHONHASHSEED is not set. This will lead to non-reproducible " - "block-hashes when using sha256_cbor_64bit as the hash function." + "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility.") + "reproducibility." + ) - NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") - if hash_seed is None else hash_fn(hash_seed)) - - -class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the max recent N requests. - - Args: - max_recent_requests: The number of the max recent requests to aggregate. - Defaults to 1000. - """ - - def __init__(self, max_recent_requests: int = 1000): - self.max_recent_requests = max_recent_requests - # The current aggregated values. - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - # A deque of (requests, queries, hits) for the most recent requests. - self.query_queue: deque[tuple[int, int, int]] = deque() - - def observe(self, stats: PrefixCacheStats): - """Observe the prefix caching for a set of requests. - - This function is called with information gathered when new requests - are being scheduled and are looking for computed blocks. - - When there are more than `interval` requests, the oldest set of - requests are removed from the metrics. - - Args: - stats: The prefix cache stats. - """ - # reset_prefix_cache was invoked before the current update. - # Reset the metrics before aggregating the current stats. - if stats.reset: - self.reset() - - # Update the metrics. 
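# --- Editor's illustrative sketch (not part of the patch) -------------------
# Round trip of the byte-packing scheme above: the group id is appended to the
# block hash as 4 big-endian bytes, and the two helpers recover each part.
import hashlib

_block_hash = hashlib.sha256(b"example block").digest()
_key = _block_hash + (3).to_bytes(4, "big", signed=False)   # make_block_hash_with_group_id
assert _key[:-4] == _block_hash                             # get_block_hash
assert int.from_bytes(_key[-4:], "big", signed=False) == 3  # get_group_id
# 64-bit external form (as in maybe_convert_block_hash when int hashes are on):
assert int.from_bytes(_block_hash, "big") & ((1 << 64) - 1) < 2**64
# ----------------------------------------------------------------------------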
- self.query_queue.append((stats.requests, stats.queries, stats.hits)) - self.aggregated_requests += stats.requests - self.aggregated_query_total += stats.queries - self.aggregated_query_hit += stats.hits - - # Remove the oldest stats if the number of requests exceeds. - if self.aggregated_requests > self.max_recent_requests: - old_requests, old_queries, old_hits = self.query_queue.popleft() - self.aggregated_requests -= old_requests - self.aggregated_query_total -= old_queries - self.aggregated_query_hit -= old_hits - - def reset(self): - """Reset the metrics.""" - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - self.query_queue.clear() - - @property - def hit_rate(self) -> float: - """Calculate the hit rate for the past N requests.""" - if self.aggregated_query_total == 0: - return 0.0 - return self.aggregated_query_hit / self.aggregated_query_total + if hash_seed is None: + NONE_HASH = BlockHash(os.urandom(32)) + else: + NONE_HASH = BlockHash(hash_fn(hash_seed)) @dataclass class KVCacheBlock: """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int # Reference count. ref_cnt: int = 0 - # The hash of the block composed of (block hash, tuple of token IDs). - # It is only available when the block is full. + # The hash key (block hash + group id) of the block, only available + # when the block is full and cached. _block_hash: Optional[BlockHashWithGroupId] = None # Used to construct a doubly linked list for free blocks. @@ -161,7 +127,8 @@ class KVCacheBlock: @block_hash.setter def block_hash(self, block_hash: BlockHashWithGroupId): assert self.block_hash is None, ( - "The block already has a hash. This should not happen.") + "The block already has a hash. This should not happen." + ) self._block_hash = block_hash def reset_hash(self): @@ -171,15 +138,15 @@ class KVCacheBlock: def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. - prev_block_id = (self.prev_free_block.block_id - if self.prev_free_block else None) - next_block_id = (self.next_free_block.block_id - if self.next_free_block else None) - return (f"KVCacheBlock(block_id={self.block_id}, " - f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash}, " - f"prev_free_block={prev_block_id}, " - f"next_free_block={next_block_id})") + prev_block_id = self.prev_free_block.block_id if self.prev_free_block else None + next_block_id = self.next_free_block.block_id if self.next_free_block else None + return ( + f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash!r}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})" + ) class FreeKVCacheBlockQueue: @@ -217,7 +184,7 @@ class FreeKVCacheBlockQueue: # Create a fake head and a tail block for the doubly linked list to # reduce branching in the code # - # The implementation garenteed that the fake head and tail + # The implementation guaranteed that the fake head and tail # are NEVER got popped, so we could safely assume each real blocks # in the queue has prev and next blocks. self.fake_free_list_head = KVCacheBlock(block_id=-1) @@ -240,12 +207,14 @@ class FreeKVCacheBlockQueue: Returns: The first free block. 
""" - if (self.fake_free_list_head.next_free_block - is self.fake_free_list_tail - or self.fake_free_list_head.next_free_block is None): + if ( + self.fake_free_list_head.next_free_block is self.fake_free_list_tail + or self.fake_free_list_head.next_free_block is None + ): assert self.num_free_blocks == 0, ( f"num_free_blocks ({self.num_free_blocks}) is out of sync " - "with the free list.") + "with the free list." + ) raise ValueError("No free blocks available") first_block: KVCacheBlock = self.fake_free_list_head.next_free_block @@ -253,8 +222,10 @@ class FreeKVCacheBlockQueue: if first_block.next_free_block is None: # This should not happen if the block is from the free list. # It indicates a bug in the caller's logic. - raise RuntimeError("Invalid block found in popleft() " - "which doesn't have a valid next_free_block") + raise RuntimeError( + "Invalid block found in popleft() " + "which doesn't have a valid next_free_block" + ) # Connect fake_head and the next block of first_block (i.e. second block # or fake tail). @@ -329,7 +300,8 @@ class FreeKVCacheBlockQueue: """ if self.fake_free_list_tail.prev_free_block is None: raise RuntimeError( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block # Connect the new block after the last block. @@ -350,11 +322,11 @@ class FreeKVCacheBlockQueue: """ if len(blocks) == 0: return - self.num_free_blocks += len(blocks) last_block = self.fake_free_list_tail.prev_free_block assert last_block is not None, ( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) # Add inter-connections between consecutive blocks for block in blocks: block.prev_free_block = last_block @@ -365,6 +337,8 @@ class FreeKVCacheBlockQueue: last_block.next_free_block = self.fake_free_list_tail self.fake_free_list_tail.prev_free_block = last_block + self.num_free_blocks += len(blocks) + def get_all_free_blocks(self) -> list[KVCacheBlock]: """Get all free blocks in the free list. Mainly used for testing. @@ -374,7 +348,8 @@ class FreeKVCacheBlockQueue: ret = [] if self.fake_free_list_head.next_free_block is None: raise RuntimeError( - "next_free_block of fake_free_list_head should always exist") + "next_free_block of fake_free_list_head should always exist" + ) # Start from the first block curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block # As long as next_free_block is available, we haven't reached to @@ -398,14 +373,16 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_hashes) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return ( + bool(request.mm_features) + or (request.lora_request is not None) + or (request.cache_salt is not None) + ) -def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, - end_token_idx: int, - start_mm_idx: int) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. 
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -422,32 +399,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, """ extra_keys: list[Any] = [] - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if not mm_positions: + mm_features = request.mm_features + if not mm_features: return extra_keys, start_mm_idx - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match. This " - "is likely because you did not enable MM hashing. " - "Please set `mm_processor_cache_gb > 0`.") - - # Note that we assume mm_positions is sorted by offset. + # Note that we assume mm_features are sorted by mm_position.offset. # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. - if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx: + last_pos = mm_features[-1].mm_position + if last_pos.offset + last_pos.length < start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. if start_mm_idx < 0: - assert -start_mm_idx <= len(mm_positions) - start_mm_idx = len(mm_positions) + start_mm_idx + assert -start_mm_idx <= len(mm_features) + start_mm_idx = len(mm_features) + start_mm_idx curr_mm_idx = start_mm_idx - while mm_positions and curr_mm_idx < len(mm_positions): - assert mm_hashes[curr_mm_idx] is not None - offset = mm_positions[curr_mm_idx].offset - length = mm_positions[curr_mm_idx].length + while mm_features and curr_mm_idx < len(mm_features): + mm_feature = mm_features[curr_mm_idx] + assert mm_feature.identifier is not None + offset = mm_feature.mm_position.offset + length = mm_feature.mm_position.length if end_token_idx > offset: if start_token_idx > offset + length: # This block has passed the current mm input. @@ -455,7 +428,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, continue # The block contains the current mm input. - extra_keys.append(mm_hashes[curr_mm_idx]) + extra_keys.append(mm_feature.identifier) if end_token_idx >= offset + length: # If this block contains the end of the current mm input, @@ -487,8 +460,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). 
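# --- Editor's illustrative sketch (not part of the patch) -------------------
# Which multi-modal inputs contribute extra keys for a block of tokens
# [start, end): roughly, every mm feature whose placeholder range overlaps the
# block adds its identifier. This simplified helper ignores the start_mm_idx
# bookkeeping and exact boundary rules of the real implementation.
def _mm_extra_keys(features: list[tuple[str, int, int]], start: int, end: int) -> list[str]:
    # features: (identifier, offset, length), sorted by offset
    keys = []
    for identifier, offset, length in features:
        if offset + length <= start:      # ends before this block
            continue
        if offset >= end:                 # starts after this block
            break
        keys.append(identifier)
    return keys

assert _mm_extra_keys([("img-0", 4, 8), ("img-1", 20, 8)], 0, 16) == ["img-0"]
# ----------------------------------------------------------------------------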
@@ -503,10 +476,12 @@ def generate_block_hash_extra_keys( """ mm_extra_keys: list[Any] mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( - request, start_token_idx, end_token_idx, start_mm_idx) + request, start_token_idx, end_token_idx, start_mm_idx + ) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = [request.cache_salt] if ( - start_token_idx == 0 and request.cache_salt) else [] + cache_salt_keys: list[str] = ( + [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] + ) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -517,22 +492,22 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, +) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - Args: + hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. extra_keys: Extra keys for the block. - Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. @@ -542,32 +517,26 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( - hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys)) + ) def get_request_block_hasher( block_size: int, - caching_hash_fn: Callable[[Any], - int]) -> Callable[[Request], list[BlockHash]]: + caching_hash_fn: Callable[[Any], bytes], +) -> Callable[[Request], list[BlockHash]]: """ Returns a function which computes the list of un-computed block hashes - of a request. - - Each request holds a list of its block hashes (request.block_hashes). - When a request is created, it calls the below function to compute - the hashes of all full blocks of the request's initial tokens. - The hashes are then stored in request.block_hashes. - Later, whenever new tokens are appended to the request, it calls - the below function again to compute any new full blocks of tokens. - The returned new hashes are appended to request.block_hashes. - """ + of a request.""" def request_block_hasher(request: Request) -> list[BlockHash]: start_token_idx = len(request.block_hashes) * block_size num_tokens = request.num_tokens + if start_token_idx + block_size > num_tokens: + # Early stop when there no new full blocks created. + return [] + curr_mm_idx = 0 if start_token_idx > 0: # Set curr_mm_idx = -1 to indicate the last mm input. @@ -576,8 +545,9 @@ def get_request_block_hasher( # last mm input. 
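# --- Editor's illustrative sketch (not part of the patch) -------------------
# The chained block hashing described above, in miniature: each block hash
# covers the parent hash, the block's token ids, and any extra keys, so two
# requests share block hashes exactly for their common full-block prefix.
# `_hash` below is a deterministic stand-in for the real hash function.
import hashlib
import pickle

def _hash(obj) -> bytes:
    return hashlib.sha256(pickle.dumps(obj)).digest()

def _chain(token_ids: list[int], block_size: int) -> list[bytes]:
    hashes: list[bytes] = []
    parent = None
    num_full = len(token_ids) // block_size * block_size
    for i in range(0, num_full, block_size):
        parent = _hash((parent, tuple(token_ids[i:i + block_size]), None))
        hashes.append(parent)
    return hashes

_a = _chain([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
_b = _chain([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
assert _a[0] == _b[0] and _a[1] != _b[1]   # shared prefix -> shared first hash
# ----------------------------------------------------------------------------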
curr_mm_idx = -1 - prev_block_hash_value = request.block_hashes[-1].hash_value \ - if request.block_hashes else None + prev_block_hash_value = ( + request.block_hashes[-1] if request.block_hashes else None + ) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -587,35 +557,38 @@ def get_request_block_hasher( # MM and LoRA requests need extra keys for block-hash computation. extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, curr_mm_idx) + request, start_token_idx, end_token_idx, curr_mm_idx + ) # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] - block_hash = hash_block_tokens(caching_hash_fn, - prev_block_hash_value, block_tokens, - extra_keys) + block_hash = hash_block_tokens( + caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys + ) new_block_hashes.append(block_hash) start_token_idx += block_size - prev_block_hash_value = block_hash.hash_value + prev_block_hash_value = block_hash return new_block_hashes return request_block_hasher -def max_memory_usage_bytes(vllm_config: VllmConfig, - kv_cache_specs: Iterable[KVCacheSpec]) -> int: +def max_memory_usage_bytes( + vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec] +) -> int: """ Get the maximum memory usage in bytes for the given KV cache specs. """ - return sum( - spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) + return sum(spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) -def estimate_max_model_len(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> int: +def estimate_max_model_len( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -634,8 +607,7 @@ def estimate_max_model_len(vllm_config: VllmConfig, # Modify the max_model_len for this calculation vllm_config.model_config.max_model_len = model_len # Calculate memory needed for the given model length - memory_needed = max_memory_usage_bytes(vllm_config, - kv_cache_spec.values()) + memory_needed = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) return memory_needed <= available_memory # Binary search for the maximum model length @@ -658,9 +630,11 @@ def estimate_max_model_len(vllm_config: VllmConfig, return result -def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int): +def check_enough_kv_cache_memory( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -679,36 +653,41 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, return if available_memory <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." 
+ ) max_model_len = vllm_config.model_config.max_model_len needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) if needed_memory > available_memory: # Estimate the maximum model length that can fit in the available memory - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - available_memory) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, available_memory + ) estimated_msg = "" if estimated_max_len > 0: estimated_msg = ( "Based on the available memory, " - f"the estimated maximum model length is {estimated_max_len}.") + f"the estimated maximum model length is {estimated_max_len}." + ) raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB). " + f"memory ({available_memory / GiB_bytes:.2f} GiB). " f"{estimated_msg} " f"Try increasing `gpu_memory_utilization` or decreasing " - f"`max_model_len` when initializing the engine.") + f"`max_model_len` when initializing the engine." + ) def create_kv_cache_group_specs( - kv_cache_spec: dict[str, KVCacheSpec], - grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]] +) -> list[KVCacheGroupSpec]: """ Create KVCacheGroupSpec object for each kv cache group layer. The layers in the same group should share the same @@ -731,11 +710,12 @@ def create_kv_cache_group_specs( ] merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) + ) return kv_cache_groups -def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same KV cache spec. Note that we regard FullAttentionSpec with and without sliding window as @@ -748,6 +728,10 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: True if all layers have the same type, False otherwise. """ + if not kv_cache_spec: + # Encoder-only models do not have KV cache, kv_cache_type can be + # regarded as uniform. + return True try: kv_cache_spec_values = list(kv_cache_spec.values()) _ = kv_cache_spec_values[0].merge(kv_cache_spec_values) @@ -757,25 +741,45 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def get_max_concurrency_for_kv_cache_config( - vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> float: """ Get the maximum concurrency for the given KV cache configuration. 
""" num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups) + len(group.layer_names) for group in kv_cache_config.kv_cache_groups + ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( - vllm_config, - (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)) - memory_per_block = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes * num_layer_per_group - num_block_per_request = cdiv(max_memory_usage_per_request, - memory_per_block) + vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) + ) + memory_per_block = ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes + * num_layer_per_group + ) + num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block) max_concurrency = kv_cache_config.num_blocks / num_block_per_request return max_concurrency -def get_num_blocks(vllm_config: VllmConfig, num_layers: int, - available_memory: int, page_size: int) -> int: +def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: + """ + Override the number of kv cache blocks if `num_gpu_blocks_override` is set. + """ + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + num_blocks, + num_gpu_blocks_override, + ) + num_blocks = num_gpu_blocks_override + + return num_blocks + + +def get_num_blocks( + vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int +) -> int: """ Get the number of kv cache blocks. @@ -787,13 +791,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int, """ num_blocks = int(available_memory // page_size // num_layers) num_blocks = max(num_blocks, 0) - if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) - num_blocks = num_gpu_blocks_override + num_blocks = may_override_num_blocks(vllm_config, num_blocks) return num_blocks @@ -806,57 +804,41 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: return page_sizes.pop() -def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_spec( + kv_cache_specs: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model with one type of KV cache. - Divide the available memory equally among all layers. + Generates the KV cache configuration for a model with the same KV cache + spec for all layers. Args: - vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. + kv_cache_specs: The kv cache spec of each attention layer in the model Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ - page_size = get_uniform_page_size(kv_cache_spec) - num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), - available_memory, page_size) - - per_layer_size = page_size * num_blocks - # All layers have the same KV cache spec, so we create one kv cache group - # for all layers. 
- grouped_layer_names = [list(kv_cache_spec.keys())] - - # Each layer uses a separate Tensor to store its KV cache. - kv_cache_tensors = [ - KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) - for layer_name in kv_cache_spec - ] - - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=kv_cache_tensors, - kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, - grouped_layer_names), - ) - - num_tokens = num_blocks * vllm_config.cache_config.block_size - num_tokens_str = f"{num_tokens:,}" - logger.info("GPU KV cache size: %s tokens", num_tokens_str) - max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" - max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) - return kv_cache_config + return create_kv_cache_group_specs(kv_cache_specs, [list(kv_cache_specs.keys())]) -def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def _get_kv_cache_groups_uniform_type( + spec: UniformTypeKVCacheSpecs, +) -> list[KVCacheGroupSpec]: + """ + Generates the KV cache configuration for a model with one type of KV cache + but different hidden sizes. All layers are merged into one group. + + Args: + spec: The UniformTypeKVCacheSpecs of the model + + Returns: + The generated KVCacheGroupSpecs + """ + + return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)] + + +def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. Args: @@ -870,79 +852,75 @@ def is_kv_cache_page_size_uniform( return len(page_sizes) == 1 -def is_kv_cache_type_attention_free( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: - +def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec -def _get_kv_cache_config_uniform_page_size( - vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_page_size( + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for hybrid models with multiple - attention types but still with a uniform page size (physical memory per + Generates the KV cache groups for hybrid models with multiple + attention types but still with a uniform page size (physical memory per block per layer) for all layers. Detailed explanation about kv cache management of hybrid models: The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the + in the pattern, and repeats each of them 10 times to generate the block_table for the 30 layers in the model. Therefore, we can group the layers in the model into 3 kv_cache_groups, each of which contains 10 layers in the model. The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer + kv_cache spec, and the model runner applies the block table to each layer in the group. 
For example: - 1. A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. It is already handled by + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. It is already handled by `_get_kv_cache_config_uniform_type`. - 2. A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + 2. A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so there are 3 kv_cache_groups, each of which represents 10 layers. To simplify the implementation, we make the following assumptions: - 1. Physical memory per block: Must be the same across all KV cache groups. + 1. Physical memory per block: Must be the same across all KV cache groups. Breaking this assumption is non-trivial due to memory fragmentation concerns when allocating blocks of different sizes. - 2. Tokens per block (block_size): Currently, we directly use - `CacheConfig.block_size` for all layers. It can be extended to vary by KV - cache group, but within each KV cache group, all layers must share the same + 2. Tokens per block (block_size): Currently, we directly use + `CacheConfig.block_size` for all layers. It can be extended to vary by KV + cache group, but within each KV cache group, all layers must share the same block size. - 3. Physical memory per token per layer: This property is decided by model - config. Currently we only support models that have the same physical memory - per token per layer for all layers. Can be relaxed with a simple extension, + 3. Physical memory per token per layer: This property is decided by model + config. Currently we only support models that have the same physical memory + per token per layer for all layers. Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. - 4. Number of layers per group: Currently assumed the same for all layers. - Can be relaxed with a simple extension, but still need to keep physical + 4. Number of layers per group: Currently assumed the same for all layers. + Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. 5. Attention type within groups: All layers in a group must share the same - attention type. One exception is that, when - `--disable-hybrid-kv-cache-manager` is true, the single group for full - attention layers may also include attention layers using sliding window or + attention type. One exception is that, when + `--disable-hybrid-kv-cache-manager` is true, the single group for full + attention layers may also include attention layers using sliding window or LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details. - 6. Support for multiple attention types: The design for most components is - general to an arbitrary number of attention types. But - `find_longest_cache_hit` only supports one attention type or two + 6. Support for multiple attention types: The design for most components is + general to an arbitrary number of attention types. But + `find_longest_cache_hit` only supports one attention type or two types of full-attention plus exactly one another type. 
The general - implementation of this function is feasible but we don't know how to + implementation of this function is feasible but we don't know how to implement it cleanly yet. - As we assume tokens per block, physical memory per token per layer, and - number of layers per group are the same now, we can ensure that physical + As we assume tokens per block, physical memory per token per layer, and + number of layers per group are the same now, we can ensure that physical memory per block is the same for all groups. Args: - vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ # Group all layers by kv_cache_spec. # E.g., 2 full attention layers and 3 sliding window attention layers, @@ -955,7 +933,7 @@ def _get_kv_cache_config_uniform_page_size( # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) # split to 3 groups with 2 layers each: - # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + # (full.0, full.1), (sw.0, sw.2), (sw.1, padding). # FIXME(Chen): At the moment of writing this code (2025-06-02), all # open-source hybrid model follows a n:1 pattern between different attention # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and @@ -973,55 +951,101 @@ def _get_kv_cache_config_uniform_page_size( num_padding_layers, num_padding_layers / len(layers) * 100, ) - for i in range(0, len(layers), group_size): - grouped_layers.append(layers[i:i + group_size]) - kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec, - grouped_layers) + num_groups = cdiv(len(layers), group_size) + # In PP case, say if we have + # - stage 0: full.0, sw.0, sw.1 + # - stage 1: full.1, sw.2, sw.3 + # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) + # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because + # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) + # and it will be padded to (full.0, padding), (sw.0, sw.1), + # (padding, padding) to ensure the number of layers in each group is + # the same and will cause memory waste. + # To avoid this, we assign layers[i::num_groups] to the i-th group + # instead of layers[i * group_size: (i + 1) * group_size] + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) + + +def get_kv_cache_config_from_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: + """ + Generate the KV cache configuration from the KV cache groups and spec + of each layer. + + Args: + vllm_config: The global VllmConfig + kv_cache_groups: The KV cache groups + kv_cache_specs: The KV cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes + Returns: + The generated KVCacheConfig + """ + if len(kv_cache_groups) == 0: + # Attention free models do not have KV cache. + # Return num_blocks=1 as BlockPool always needs a null_block. + return KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) # Determine how model runners should initialize the KV cache tensors. - # We will have group_size memory pools, each is shared by one layer from - # each group. 
As layers of different groups have different block table, - # they will use different parts of the shared Tensor. - # The memory layout in the example will be: - # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 - # full.1, sw.1: share another Tensor with size=available_memory//2 - page_size = get_uniform_page_size(kv_cache_spec) - num_blocks = get_num_blocks(vllm_config, group_size, available_memory, - page_size) - per_memory_pool_size = page_size * num_blocks - kv_cache_tensors = [] - for i in range(group_size): - shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(grouped_layers[j]): - shared_by.append(grouped_layers[j][i]) - kv_cache_tensors.append( - KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): + # Special case: all layers have the same type of KV cache but with + # different hidden size. Allocate different amount of memory for each + # layer based on its hidden size. + num_blocks = ( + available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes + ) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs + kv_cache_tensors = [ + KVCacheTensor( + size=per_layer_specs[layer_name].page_size_bytes * num_blocks, + shared_by=[layer_name], + ) + for layer_name in kv_cache_groups[0].layer_names + ] + else: + # General case: + # We will have group_size memory pools, each is shared by one layer from + # each group. As layers of different groups have different block table, + # they will use different parts of the shared Tensor. + # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), + # (sw.1, padding) will be: (group_size = 2) + # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 + # full.1, sw.2: share another Tensor with size=available_memory//2 + group_size = max(len(group.layer_names) for group in kv_cache_groups) - kv_cache_config = KVCacheConfig( + page_size = get_uniform_page_size(kv_cache_specs) + assert group_size > 0, "group_size must be greater than 0" + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size + ) + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) + kv_cache_tensors.append( + KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + ) + + return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors, kv_cache_groups=kv_cache_groups, ) - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_groups]) - - # Print the KV cache size and maximum concurrency. 
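# --- Editor's illustrative sketch (not part of the patch) -------------------
# The two mechanisms above in one toy example: (1) layers of one attention
# type are assigned to groups with a stride (layers[i::num_groups]) so PP
# stages stay balanced, and (2) the i-th layer of every group shares the i-th
# KV cache tensor.
_full = ["full.0", "full.1"]
_sw = ["sw.0", "sw.1", "sw.2", "sw.3"]
_num_sw_groups = 2                                  # cdiv(len(_sw), group_size=2)

_groups = [_full] + [_sw[i::_num_sw_groups] for i in range(_num_sw_groups)]
assert _groups == [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1", "sw.3"]]

_group_size = max(len(g) for g in _groups)
_shared_by = [[g[i] for g in _groups if i < len(g)] for i in range(_group_size)]
assert _shared_by == [["full.0", "sw.0", "sw.1"], ["full.1", "sw.2", "sw.3"]]
# ----------------------------------------------------------------------------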
- num_tokens = num_blocks // len(grouped_layers) * min_block_size - num_tokens_str = f"{num_tokens:,}" - logger.info("GPU KV cache size: %s tokens", num_tokens_str) - max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" - max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) - return kv_cache_config - - -def _get_kv_cache_config_attention_free() -> KVCacheConfig: - return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[]) - def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ @@ -1033,24 +1057,28 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): kv_cache_spec: The kv cache spec of each attention layer in the model """ - if is_kv_cache_type_uniform(kv_cache_spec): + if is_kv_cache_spec_uniform( + kv_cache_spec + ) or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec): return logger.warning( "Hybrid KV cache manager is disabled for this hybrid model, " "This means we do not enable any optimizations for saving KV cache " "memory (e.g., dropping the KV cache outside the sliding window). " - "The compute of layers like sliding window is still saved.") + "The compute of layers like sliding window is still saved." + ) has_full_attention = any( - isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values() + ) has_sliding_window = any( - isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values() + ) has_chunked_local_attention = any( - isinstance(spec, ChunkedLocalAttentionSpec) - for spec in kv_cache_spec.values()) - if has_full_attention and (has_sliding_window - or has_chunked_local_attention): + isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values() + ) + if has_full_attention and (has_sliding_window or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -1058,7 +1086,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, sliding_window=spec.sliding_window, ) elif isinstance(spec, ChunkedLocalAttentionSpec): @@ -1067,88 +1094,217 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, attention_chunk_size=spec.attention_chunk_size, ) - if not is_kv_cache_type_uniform(kv_cache_spec): - raise ValueError("Hybrid KV cache manager is disabled but failed to " - "convert the KV cache specs to one unified type.") + if not ( + is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec) + ): + raise ValueError( + "Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type." + ) -def get_kv_cache_config( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> KVCacheConfig: +def get_kv_cache_groups( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model. 
+ Split the layers in the model into groups with the same KV cache spec. Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfigs + The generated KVCacheGroups """ - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_attention_free(kv_cache_spec): - # This returns a kv_cache config with 0 kv_cache groups and 1 block - # to allow for the KVCache manager to handle attention free models. - return _get_kv_cache_config_attention_free() - elif is_kv_cache_type_uniform(kv_cache_spec): + # This returns an empty list to allow for the KVCacheManager to handle + # attention free models. + return [] + elif is_kv_cache_spec_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. - return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_spec(kv_cache_spec) + elif uniform_spec := UniformTypeKVCacheSpecs.from_specs(kv_cache_spec): + # All layers need the same number of token slots (e.g., all layers are + # full attention, or all layers are sliding window attention with the + # same window size). Put all layers into one group. + return _get_kv_cache_groups_uniform_type(uniform_spec) elif is_kv_cache_page_size_uniform(kv_cache_spec): # Model contains multiple attention types, but KV cache of all layers # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. - return _get_kv_cache_config_uniform_page_size(vllm_config, - kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) raise NotImplementedError -def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): +def generate_scheduler_kv_cache_config( + kv_cache_configs: list[KVCacheConfig], +) -> KVCacheConfig: """ - Make the KV cache configurations for each worker consistent, so that all - workers can be controlled by the same KVCacheManager. - This function verifies that the layer group of each worker are the same, - and changes the num_blocks of each worker to the smallest among all workers. + Generate the KV cache configuration for the scheduler. + """ + assert all( + [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] + ) + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to initialize the scheduler. + cfg = copy.deepcopy(kv_cache_configs[0]) + for group in cfg.kv_cache_groups: + if isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs): + # All layers in the UniformTypeKVCacheSpecs have the same type, + # so use an arbitrary one to initialize the scheduler. + group.kv_cache_spec = next( + iter(group.kv_cache_spec.kv_cache_specs.values()) + ) + return cfg + + +def _report_kv_cache_config( + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> None: + """ + Log resolved KV cache configuration. Args: - kv_cache_configs: The KV cache configurations for each worker. Will be - in-place modified to make them consistent. 
+ vllm_config: The global VllmConfig + kv_cache_config: The resolved KV cache configuration + """ + min_block_size = min( + [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups] + ) + + # Log the KV cache size and maximum concurrency. + num_tokens = ( + kv_cache_config.num_blocks + // len(kv_cache_config.kv_cache_groups) + * min_block_size + ) + if vllm_config.parallel_config.decode_context_parallel_size > 1: + num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + logger.info( + "Multiplying the GPU KV cache size by the dcp_world_size %d.", + vllm_config.parallel_config.decode_context_parallel_size, + ) + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = get_max_concurrency_for_kv_cache_config( + vllm_config, kv_cache_config + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) + + +def get_kv_cache_configs( + vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int], +) -> list[KVCacheConfig]: + """ + Generates the KV cache configurations for a model. + Since we use a shared centralized controller for all workers, we need the + `kv_cache_config` to be consistent across all workers to make sure + the KV cache allocation can be applied to all workers. However, different + workers may have different memory available, and different type of layers + (when pipeline parallel is enabled). To handle the difference between + workers, the current implementation is: + 1. Merge the KV cache specs of all workers to get the KVCacheSpecs for + the whole model. + 2. Generate the KV cache groups based on the layer ratio of the whole model. + 3. Generate the KV cache configs for each worker based on the KV cache + grouping strategy. (This is reasonable because the layer ratio of + different PP stages are similar.) + 4. Change the num_blocks of each worker to the smallest among all workers + and shrink tensor sizes proportionally to avoid allocating unused memory. + + Args: + vllm_config: The global VllmConfig + kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. + available_memory: Memory available for KV cache in bytes for each + worker. + + Returns: + The generated KVCacheConfigs for each worker. """ - # Sort the kv cache groups by their KV cache spec. - # This can avoid the inconsistency caused by the order of groups. - for kv_cache_config in kv_cache_configs: - kv_cache_config.kv_cache_groups.sort(key=lambda x: (type( - x.kv_cache_spec).__name__, astuple(x.kv_cache_spec))) + # Check if the available memory is enough for each worker. + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory + ): + check_enough_kv_cache_memory( + vllm_config, kv_cache_spec_one_worker, available_memory_one_worker + ) - # Verify that the groups of each rank are the same. - for kv_cache_config in kv_cache_configs[1:]: - for group_rank_0, group_rank_i in zip( - kv_cache_configs[0].kv_cache_groups, - kv_cache_config.kv_cache_groups): - assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec + # Merge the KV cache specs of all workers. Different PP stages may have + # different layer names, and different TP ranks of the same PP stage should + # have the same KV cache spec. 
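    # Illustrative sketch with hypothetical layer names and a hypothetical
    # PP=2 / TP=2 layout (not taken from this change): both TP ranks of PP
    # stage 0 report specs for layers 0..15 and both ranks of stage 1 for
    # layers 16..31, e.g.
    #   kv_cache_specs = [
    #       {"layers.0": full_spec, ..., "layers.15": full_spec},   # PP0 TP0
    #       {"layers.0": full_spec, ..., "layers.15": full_spec},   # PP0 TP1
    #       {"layers.16": full_spec, ..., "layers.31": full_spec},  # PP1 TP0
    #       {"layers.16": full_spec, ..., "layers.31": full_spec},  # PP1 TP1
    #   ]
    # The merged dict then covers all 32 layers, and the assert below only
    # fires if two workers disagree on the spec of a layer they both report.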
+ merged_kv_cache_specs: dict[str, KVCacheSpec] = {} + for kv_cache_spec_one_worker in kv_cache_specs: + for layer_name, layer_spec in kv_cache_spec_one_worker.items(): + if layer_name not in merged_kv_cache_specs: + merged_kv_cache_specs[layer_name] = layer_spec + else: + assert merged_kv_cache_specs[layer_name] == layer_spec, ( + "The KV cache specs for the same layer are different " + "across workers. This is not supported yet." + ) + global_kv_cache_groups = get_kv_cache_groups(vllm_config, merged_kv_cache_specs) - # Change the num_blocks of each rank to the smallest among all ranks. We - # do not need to shrink the tensor size because it is valid to only use the - # first `num_blocks` blocks of the tensor. - min_num_blocks = min(kv_cache_config.num_blocks - for kv_cache_config in kv_cache_configs) + kv_cache_configs: list[KVCacheConfig] = [] + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory + ): + kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] + for group in global_kv_cache_groups: + group_layer_names_one_worker = [ + layer_name + for layer_name in group.layer_names + if layer_name in kv_cache_spec_one_worker + ] + kv_cache_groups_one_worker.append( + KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec) + ) + assert sum( + len(group.layer_names) for group in kv_cache_groups_one_worker + ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." + kv_cache_configs.append( + get_kv_cache_config_from_groups( + vllm_config, + kv_cache_groups_one_worker, + kv_cache_spec_one_worker, + available_memory_one_worker, + ) + ) + + # Change the num_blocks of each rank to the smallest among all ranks. + # We also need to shrink the tensor size proportionally to avoid + # allocating unused memory. + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) for kv_cache_config in kv_cache_configs: + num_blocks_old = kv_cache_config.num_blocks kv_cache_config.num_blocks = min_num_blocks + # Shrink tensor size proportionally + for tensor in kv_cache_config.kv_cache_tensors: + assert tensor.size % num_blocks_old == 0 + tensor.size = tensor.size // num_blocks_old * min_num_blocks + + if len(kv_cache_config.kv_cache_groups) > 0: + _report_kv_cache_config(vllm_config, kv_cache_config) + return kv_cache_configs diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 74ff6261732c7..968b4db530bfe 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -12,7 +12,6 @@ logger = init_logger(__name__) class AsyncScheduler(Scheduler): - def _update_after_schedule( self, scheduler_output: SchedulerOutput, @@ -20,8 +19,10 @@ class AsyncScheduler(Scheduler): super()._update_after_schedule(scheduler_output) for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] - if (request.num_computed_tokens == request.num_tokens + - request.num_output_placeholders): + if ( + request.num_computed_tokens + == request.num_tokens + request.num_output_placeholders + ): # The request will generate a new token in this scheduling step. # TODO(woosuk): Support speculative decoding. request.num_output_placeholders += 1 @@ -33,7 +34,8 @@ class AsyncScheduler(Scheduler): ) -> tuple[list[int], bool]: status_before_update = request.status new_token_ids, stopped = super()._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Update the number of output placeholders. 
request.num_output_placeholders -= len(new_token_ids) @@ -42,6 +44,6 @@ class AsyncScheduler(Scheduler): # Cache the new tokens. Preempted requests should be skipped. if status_before_update == RequestStatus.RUNNING: self.kv_cache_manager.cache_blocks( - request, - request.num_computed_tokens - request.num_output_placeholders) + request, request.num_computed_tokens - request.num_output_placeholders + ) return new_token_ids, stopped diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 5b1de3a66ceb4..b92ef395e9b71 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -14,7 +14,6 @@ if TYPE_CHECKING: class SchedulerInterface(ABC): - @abstractmethod def schedule(self) -> "SchedulerOutput": """Schedule the requests to process in this scheduling step. @@ -72,7 +71,7 @@ class SchedulerInterface(ABC): @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the scheduler's internal queue. - + Args: request: The new request being added. """ @@ -91,7 +90,7 @@ class SchedulerInterface(ABC): 1. When the request is aborted by the client. 2. When the frontend process detects a stop string of the request after de-tokenizing its generated tokens. - + Args: request_ids: A single or a list of request IDs. finished_status: The finished status of the given requests. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 9ba7ec9d96932..981c5e9c76361 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -4,34 +4,35 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING + +from vllm._bc_linter import bc_linter_include if TYPE_CHECKING: import numpy as np import numpy.typing as npt + import torch - from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorMetadata) + from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request +@bc_linter_include @dataclass class NewRequestData: - req_id: str - prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_hashes: list[str] - mm_positions: list[PlaceholderRange] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] + prompt_token_ids: list[int] | None + mm_features: list[MultiModalFeatureSpec] + sampling_params: SamplingParams | None + pooling_params: PoolingParams | None block_ids: tuple[list[int], ...] 
num_computed_tokens: int - lora_request: Optional[LoRARequest] + lora_request: LoRARequest | None + prompt_embeds: torch.Tensor | None = None @classmethod def from_request( @@ -42,47 +43,53 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - mm_kwargs=request.mm_kwargs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, block_ids=block_ids, num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, + prompt_embeds=request.prompt_embeds, ) - def __repr__(self): - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids={self.prompt_token_ids}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}" - ")") + def __repr__(self) -> str: + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) # Version of __repr__ with the prompt data obfuscated - def anon_repr(self): - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids_len={len(self.prompt_token_ids)}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}" - ")") + def anon_repr(self) -> str: + prompt_token_ids_len = ( + len(self.prompt_token_ids) if self.prompt_token_ids is not None else None + ) + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={prompt_token_ids_len}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) +@bc_linter_include @dataclass class CachedRequestData: - req_ids: list[str] # If resumed_from_preemption is False, new_block_ids will be appended to # the request's block IDs. If True, new_block_ids will be used as the @@ -91,8 +98,12 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[Optional[tuple[list[int], ...]]] + # If resumed_from_preemption is True, propogate the token ids to the + # connector, otherwise will be empty. + resumed_req_token_ids: list[list[int] | None] + new_block_ids: list[tuple[list[int], ...] 
| None] num_computed_tokens: list[int] + num_output_tokens: list[int] @property def num_reqs(self) -> int: @@ -104,14 +115,16 @@ class CachedRequestData: req_ids=[], resumed_from_preemption=[], new_token_ids=[], + resumed_req_token_ids=[], new_block_ids=[], num_computed_tokens=[], + num_output_tokens=[], ) +@bc_linter_include @dataclass class SchedulerOutput: - # list of the requests that are scheduled for the first time. # We cache the request's data in each worker process, so that we don't # need to re-send it every scheduling step. @@ -143,15 +156,15 @@ class SchedulerOutput: # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of (req_id, encoder_input_index) tuples. - # Used to free the encoder cache. - free_encoder_input_ids: list[tuple[str, int]] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. + free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask structured_output_request_ids: dict[str, int] # the bitmask for the whole batch - grammar_bitmask: Optional[npt.NDArray[np.int32]] + grammar_bitmask: npt.NDArray[np.int32] | None # KV Cache Connector metadata. - kv_connector_metadata: Optional[KVConnectorMetadata] = None + kv_connector_metadata: KVConnectorMetadata | None = None diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index fc2bc30b9a5fd..33e5ec72ebd78 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -14,6 +14,7 @@ from vllm.v1.request import Request class SchedulingPolicy(Enum): """Enum for scheduling policies.""" + FCFS = "fcfs" PRIORITY = "priority" @@ -111,9 +112,7 @@ class FCFSRequestQueue(deque[Request], RequestQueue): def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - filtered_requests = [ - req for req in self if req not in requests_to_remove - ] + filtered_requests = [req for req in self if req not in requests_to_remove] # deque does not support in-place filtering, so we need to clear # and extend self.clear() @@ -150,8 +149,7 @@ class PriorityRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, - (request.priority, request.arrival_time, request)) + heapq.heappush(self._heap, (request.priority, request.arrival_time, request)) def pop_request(self) -> Request: """Pop a request from the queue according to priority policy.""" @@ -169,15 +167,15 @@ class PriorityRequestQueue(RequestQueue): def prepend_request(self, request: Request) -> None: """Add a request to the queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. Requests are ordered by (priority, arrival_time).""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: """Add all requests from another queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. 
Requests are ordered by (priority, arrival_time).""" for request in requests: self.add_request(request) @@ -190,8 +188,9 @@ class PriorityRequestQueue(RequestQueue): def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - self._heap = [(p, t, r) for p, t, r in self._heap - if r not in requests_to_remove] + self._heap = [ + (p, t, r) for p, t, r in self._heap if r not in requests_to_remove + ] heapq.heapify(self._heap) def __bool__(self) -> bool: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dff624ac6b21d..875220588b3de 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,27 +7,28 @@ import itertools import time from collections import defaultdict from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -39,12 +40,12 @@ logger = init_logger(__name__) class Scheduler(SchedulerInterface): - def __init__( self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, + block_size: int, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, @@ -58,22 +59,24 @@ class Scheduler(SchedulerInterface): self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). 
This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + self.finished_req_ids_dict: dict[int, set[str]] | None = ( + defaultdict(set) if include_finished_set else None + ) # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -82,9 +85,14 @@ class Scheduler(SchedulerInterface): if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " - "with KV connectors") + "with KV connectors" + ) + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, role=KVConnectorRole.SCHEDULER + ) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -94,7 +102,8 @@ class Scheduler(SchedulerInterface): num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 - self.block_size = self.cache_config.block_size + self.block_size = block_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # req_id -> Request self.requests: dict[str, Request] = {} @@ -105,7 +114,8 @@ class Scheduler(SchedulerInterface): self.policy = SchedulingPolicy.FCFS else: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) # Priority queues for requests. self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -118,6 +128,7 @@ class Scheduler(SchedulerInterface): # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() + self.failed_recving_kv_req_ids: set[str] = set() # Encoder-related. # Calculate encoder cache size if applicable @@ -131,14 +142,13 @@ class Scheduler(SchedulerInterface): ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed). Currently, we assume that the encoder also - # has the Transformer architecture (e.g., ViT). + # projector if needed) for MM models as well as encoder-decoder + # transformers. self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. 
- self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -157,6 +167,7 @@ class Scheduler(SchedulerInterface): use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 @@ -182,7 +193,7 @@ class Scheduler(SchedulerInterface): token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} - encoder_budget = self.max_num_encoder_input_tokens + encoder_compute_budget = self.max_num_encoder_input_tokens # Spec decode-related. scheduled_spec_decode_tokens: dict[str, list[int]] = {} @@ -194,29 +205,35 @@ class Scheduler(SchedulerInterface): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens + ) # Schedule encoder inputs. encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -234,43 +251,50 @@ class Scheduler(SchedulerInterface): req_index += 1 continue + # Schedule newly needed KV blocks for the request. while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) - if new_blocks is None: - # The request cannot be scheduled. - # Preempt the lowest-priority request. 
- if self.policy == SchedulingPolicy.PRIORITY: - preempted_req = max( - self.running, - key=lambda r: (r.priority, r.arrival_time), - ) - self.running.remove(preempted_req) - else: - preempted_req = self.running.pop() + num_lookahead_tokens=self.num_lookahead_tokens, + ) - self.kv_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp) - - self.waiting.prepend_request(preempted_req) - preempted_reqs.append(preempted_req) - if preempted_req == request: - # No more request to preempt. - can_schedule = False - break - else: + if new_blocks is not None: # The request can be scheduled. - can_schedule = True break - if not can_schedule: + + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break + + if new_blocks is None: + # Cannot schedule this request. break - assert new_blocks is not None # Schedule the request. scheduled_running_reqs.append(request) @@ -281,30 +305,34 @@ class Scheduler(SchedulerInterface): # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -327,7 +355,8 @@ class Scheduler(SchedulerInterface): else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -345,9 +374,14 @@ class Scheduler(SchedulerInterface): # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -359,29 +393,41 @@ class Scheduler(SchedulerInterface): # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) + + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + self.kv_cache_manager.create_empty_block_list() + ) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + new_encoder_compute_budget = encoder_compute_budget # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -393,15 +439,21 @@ class Scheduler(SchedulerInterface): # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
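            # Worked example (hypothetical numbers): a resumed request with
            # 100 prompt tokens and 20 already-generated output tokens has
            # request.num_tokens == 120; if 96 of those tokens are already
            # computed (prefix cache and/or KV connector), then
            # num_new_tokens == 120 - 96 == 24 below.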
num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + self.scheduler_config.long_prefill_token_threshold + ) # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -411,11 +463,16 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -425,9 +482,21 @@ class Scheduler(SchedulerInterface): # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = ( + self.scheduler_config.max_num_encoder_input_tokens + ) + else: + num_encoder_tokens = 0 new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -436,6 +505,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, ) if new_blocks is None: @@ -466,20 +536,21 @@ class Scheduler(SchedulerInterface): req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -490,11 +561,12 @@ class Scheduler(SchedulerInterface): # Encoder-related. 
if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) - encoder_budget = new_encoder_budget + encoder_compute_budget = new_encoder_compute_budget # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: @@ -508,23 +580,26 @@ class Scheduler(SchedulerInterface): # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) + any_request.request_id + ) + ) # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + req, req_to_new_blocks[req.request_id].get_block_ids() + ) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -534,9 +609,12 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens, req_to_new_blocks, ) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(self.running, - scheduled_spec_decode_tokens)) + scheduled_requests = ( + scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs + ) + structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( + scheduled_requests, scheduled_spec_decode_tokens + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -550,7 +628,7 @@ class Scheduler(SchedulerInterface): # It contains the request IDs that are finished in between # the previous and the current steps. 
finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -563,7 +641,19 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events if events: batch = KVEventBatch(ts=time.time(), events=events) self.kv_event_publisher.publish(batch) @@ -612,43 +702,53 @@ class Scheduler(SchedulerInterface): ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + new_block_ids: list[tuple[list[int], ...] | None] = [] + resumed_req_token_ids: list[list[int] | None] = [] num_computed_tokens: list[int] = [] + num_output_tokens: list[int] = [] - use_connector = self.connector is not None - for req in itertools.chain(running_reqs, resumed_reqs): + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) - elif use_connector: - # When using a KVConnector, we add a placeholder to avoid index - # out of bounds errors. TODO: Remove this once the KVConnector - # is updated to handle token IDs properly. - new_token_ids.append([]) + resumed_token_ids = None + if resumed_from_preemption[idx]: + resumed_token_ids = req.all_token_ids[ + : req.num_computed_tokens + num_tokens + ] + resumed_req_token_ids.append(resumed_token_ids) new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) - # Because resumed_reqs is usually empty, it is more efficient to do - # in-place appending so that we don't need to allocate a new list. 
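                # Worked example (hypothetical numbers): a request with
                # num_computed_tokens == 50 that is scheduled for 4 tokens,
                # 1 of which is a draft token, has num_tokens == 3 here, so
                # the slice below covers all_token_ids[50:53], which is what
                # gets sent back when pipeline parallelism is used.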
- resumed_from_preemption = [False] * len(running_reqs) - resumed_from_preemption += [True] * len(resumed_reqs) + num_output_tokens.append( + req.num_output_tokens + req.num_output_placeholders + ) return CachedRequestData( req_ids=req_ids, resumed_from_preemption=resumed_from_preemption, new_token_ids=new_token_ids, + resumed_req_token_ids=resumed_req_token_ids, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, + num_output_tokens=num_output_tokens, ) def _try_schedule_encoder_inputs( @@ -656,7 +756,7 @@ class Scheduler(SchedulerInterface): request: Request, num_computed_tokens: int, num_new_tokens: int, - encoder_budget: int, + encoder_compute_budget: int, ) -> tuple[list[int], int, int]: """ Determine which encoder inputs need to be scheduled in the current step, @@ -678,14 +778,20 @@ class Scheduler(SchedulerInterface): blocks and externally cached blocks (via KVConnector). """ if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_budget + return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] - mm_positions = request.mm_positions - assert mm_positions is not None - assert len(mm_positions) > 0 - for i, pos_info in enumerate(mm_positions): - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -693,27 +799,55 @@ class Scheduler(SchedulerInterface): if start_pos >= num_computed_tokens + num_new_tokens: # The encoder input is not needed in this step. break - if start_pos + num_encoder_tokens <= num_computed_tokens: + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used." + ) + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: # The encoder input is already computed and stored # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): - # The encoder input is already computed and cached. - continue + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. 
+ continue + + if self.encoder_cache_manager.check_and_update_cache(request, i): + # The encoder input is already computed and cached from a + # previous step. + continue # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) - or num_encoder_tokens > encoder_budget): + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -730,9 +864,16 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break - encoder_budget -= num_encoder_tokens + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_features[i].identifier) encoder_inputs_to_schedule.append(i) - return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) def get_grammar_bitmask( self, @@ -742,7 +883,7 @@ class Scheduler(SchedulerInterface): # NOTE: structured_output_request_ids maps # a request's (request that uses structured output) # request_id to its index in the batch. - # This will helps us determine to slice the grammar bitmask + # This will help us determine to slice the grammar bitmask # and only applies valid mask for requests that # uses structured decoding. structured_output_request_ids: dict[str, int] = {} @@ -775,9 +916,22 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) + + failed_kv_load_req_ids = None + if kv_connector_output and kv_connector_output.invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + failed_kv_load_req_ids = self._handle_invalid_blocks( + kv_connector_output.invalid_block_ids + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. 
We should do our best @@ -787,6 +941,9 @@ class Scheduler(SchedulerInterface): for req_index, req_id in enumerate(model_runner_output.req_ids): num_tokens_scheduled = num_scheduled_tokens[req_id] assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue request = self.requests.get(req_id) if request is None: # The request is already finished. This can happen if the @@ -794,25 +951,28 @@ class Scheduler(SchedulerInterface): # in pipeline parallelism). continue - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted # num_computed_tokens represents the number of tokens # processed in the current step, considering scheduled # tokens and rejections. If some tokens are rejected, # num_computed_tokens is decreased by the number of rejected - # tokens, where is given by: - # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - - len(generated_token_ids)) - request.num_computed_tokens -= num_tokens_rejected + # tokens. + request.num_computed_tokens -= num_rejected spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, - num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -823,14 +983,14 @@ class Scheduler(SchedulerInterface): # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -840,28 +1000,29 @@ class Scheduler(SchedulerInterface): stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. 
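            # Worked example for the speculative-decoding accounting above
            # (hypothetical numbers): with 3 scheduled draft tokens and 3
            # generated tokens (2 accepted drafts plus 1 token sampled from
            # the target model), num_accepted == 2, num_rejected == 1, and
            # num_computed_tokens is rolled back by exactly 1 position.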
new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): # NOTE: structured_output_request # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning + # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + req_id, new_token_ids + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -874,9 +1035,10 @@ class Scheduler(SchedulerInterface): stop_reason=request.stop_reason, events=request.take_events(), kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) - + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -889,9 +1051,8 @@ class Scheduler(SchedulerInterface): self.waiting.remove_requests(stopped_preempted_reqs) # KV Connector: update state for finished KV Transfers. - if model_runner_output.kv_connector_output: - self._update_from_kv_xfer_finished( - model_runner_output.kv_connector_output) + if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. @@ -910,10 +1071,13 @@ class Scheduler(SchedulerInterface): eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -944,8 +1108,9 @@ class Scheduler(SchedulerInterface): return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -953,22 +1118,26 @@ class Scheduler(SchedulerInterface): # Here, we use list(set) to avoid modifying the set while iterating # over it. for input_id in list(cached_encoder_input_ids): - mm_positions = request.mm_positions[input_id] - start_pos = mm_positions.offset - num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. 
Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input(request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -982,7 +1151,8 @@ class Scheduler(SchedulerInterface): elif self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1008,7 +1178,7 @@ class Scheduler(SchedulerInterface): """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1019,7 +1189,7 @@ class Scheduler(SchedulerInterface): # First pass: collect requests to remove from queues for req_id in request_ids: request = self.requests.get(req_id) - if request is None: + if request is None or request.is_finished(): # Invalid request ID. continue @@ -1040,7 +1210,7 @@ class Scheduler(SchedulerInterface): request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + def _free_request(self, request: Request) -> dict[str, Any] | None: assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) @@ -1071,8 +1241,9 @@ class Scheduler(SchedulerInterface): def make_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats] = None, - ) -> Optional[SchedulerStats]: + spec_decoding_stats: SpecDecodingStats | None = None, + kv_connector_stats: KVConnectorStats | None = None, + ) -> SchedulerStats | None: if not self.log_stats: return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() @@ -1083,38 +1254,41 @@ class Scheduler(SchedulerInterface): kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), + num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, ) def make_spec_decoding_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats], + spec_decoding_stats: SpecDecodingStats | None, num_draft_tokens: int, num_accepted_tokens: int, - ) -> Optional[SpecDecodingStats]: + ) -> SpecDecodingStats | None: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() 
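A minimal worked sketch of the spec-decode accounting used in `update_from_output` above (the `rejected_draft_tokens` helper below is hypothetical and for illustration only, not vLLM API):

```python
# Hypothetical helper illustrating the accounting above: the sampler output
# always includes one bonus token on top of the accepted draft tokens, so
# num_accepted = len(generated_token_ids) - 1 and the request's
# num_computed_tokens is rolled back by the number of rejected drafts.
def rejected_draft_tokens(num_scheduled_spec: int, num_generated: int) -> int:
    num_accepted = num_generated - 1
    return num_scheduled_spec - num_accepted

# 4 draft tokens scheduled, 3 tokens emitted (2 accepted drafts + 1 bonus)
# -> num_computed_tokens decreases by 2.
assert rejected_draft_tokens(4, 3) == 2
```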
######################################################################## # KV Connector Related Methods ######################################################################## - def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + def get_kv_connector(self) -> KVConnectorBase_V1 | None: return self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, dict[str, Any] | None]: """ Invoke the KV connector request_finished() method if applicable. @@ -1124,7 +1298,7 @@ class Scheduler(SchedulerInterface): if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1143,25 +1317,37 @@ class Scheduler(SchedulerInterface): if request.request_id not in self.finished_recving_kv_req_ids: return False - # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) - num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less then one block. - num_computed_tokens = min(num_computed_tokens, request.num_tokens) - if num_computed_tokens == request.num_tokens: - num_computed_tokens -= 1 - # This will cache the blocks iff caching is enabled. - self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + if request.request_id in self.failed_recving_kv_req_ids: + # Request had KV load failures; num_computed_tokens was already + # updated in _update_requests_with_invalid_blocks + if request.num_computed_tokens: + # Cache any valid computed tokens. + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) + else: + # No valid computed tokens, release allocated blocks. + # There may be a local cache hit on retry. + self.kv_cache_manager.free(request) - # Update the request state for scheduling. - request.num_computed_tokens = num_computed_tokens + self.failed_recving_kv_req_ids.remove(request.request_id) + else: + # Now that the blocks are ready, actually cache them. + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + # Handle the case where num request tokens less than one block. + num_computed_tokens = min(num_computed_tokens, request.num_tokens) + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + # This will cache the blocks iff caching is enabled. + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens # Return that we are ready. self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1169,16 +1355,156 @@ class Scheduler(SchedulerInterface): finished_sending reqs to the output. * if finished_sending: free the blocks # if finished_recving: add to state so we can - scheduler the request during the next step. + schedule the request during the next step. 
""" if self.connector is not None: self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) + assert req_id in self.requests self._free_blocks(self.requests[req_id]) + + def _update_requests_with_invalid_blocks( + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: + """ + Identify and update requests affected by invalid KV cache blocks. + + This method scans the given requests, detects those with invalid blocks + and adjusts their `num_computed_tokens` to the longest valid prefix. + For observability, it also accumulates the total number of tokens that + will need to be recomputed across all affected requests. + + Args: + requests: The set of requests to scan for invalid blocks. + invalid_block_ids: IDs of invalid blocks. + + Returns: + tuple: + - affected_req_ids (set[str]): IDs of requests impacted by + invalid blocks. + - total_affected_tokens (int): Total number of tokens that must + be recomputed across all affected requests (for observability). + """ + affected_req_ids: set[str] = set() + total_affected_tokens = 0 + # If a block is invalid and shared by multiple requests in the batch, + # these requests must be rescheduled, but only the first will recompute + # it. This set tracks blocks already marked for recomputation. + marked_invalid_block_ids: set[int] = set() + for request in requests: + is_affected = False + marked_invalid_block = False + req_id = request.request_id + # TODO (davidb): add support for hybrid memory allocator + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) + # We iterate only over blocks that may contain externally computed + # tokens + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + # Async loading. If num_computed_tokens is set it implies we + # already processed some block failures for it in a prior step + req_num_computed_tokens = ( + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) + else: + # Sync loading. num_computed_tokens includes new tokens + req_num_computed_tokens = request.num_cached_tokens + + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): + if block_id not in invalid_block_ids: + continue + + is_affected = True + + if block_id in marked_invalid_block_ids: + # This invalid block is shared with a previous request + # and was already marked for recomputation. + # This means this request can still consider this block + # as computed when rescheduled. + # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + continue + + marked_invalid_block_ids.add(block_id) + + if marked_invalid_block: + # This request has already marked an invalid block for + # recomputation and updated its num_computed_tokens. 
+ continue + + marked_invalid_block = True + # Truncate the computed tokens at the first failed block + request.num_computed_tokens = idx * self.block_size + total_affected_tokens += ( + req_num_computed_tokens - request.num_computed_tokens + ) + + if is_affected: + if not marked_invalid_block: + # All invalid blocks of this request are shared with + # previous requests and will be recomputed by them. + # Revert to considering only cached tokens as computed. + # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) + request.num_computed_tokens = request.num_cached_tokens + + affected_req_ids.add(request.request_id) + + return affected_req_ids, total_affected_tokens + + def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: + total_requests_to_reschedule = 0 + total_tokens_to_reschedule = 0 + + # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- + async_load_reqs = ( + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) + async_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) + + total_requests_to_reschedule += len(async_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + # Mark requests with async KV load failures; they will be rescheduled + # once loading completes + self.failed_recving_kv_req_ids |= async_affected_req_ids + + # --- Handle sync KV loads (running requests) --- + sync_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) + + total_requests_to_reschedule += len(sync_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + if total_requests_to_reschedule: + logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) + + # Return the IDs of affected running requests to skip in + # update_from_output. + return sync_affected_req_ids diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 42d3e5c68b4c8..5906a73382a2d 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -10,19 +10,19 @@ from vllm.v1.request import Request, RequestStatus def remove_all(lst: list, items_to_remove: set) -> list: """Remove all items from a list that are in the items_to_remove set. - + This method optimizes for the common case of removing a single item, falling back to list comprehension for multiple items. - + Args: lst: The list to remove items from items_to_remove: Set of items to remove - + Returns: Either the modified original list (for single item removal) or a new list (for multiple item removal). Callers should use the returned value. - + Note: For single item removal, this modifies the original list in-place and returns it. For multiple items, it creates and returns a new list. 
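A minimal usage sketch of the `remove_all` contract documented above (illustrative request IDs only): single-item removal mutates and returns the same list, multi-item removal returns a new list, so callers should always use the returned value.

```python
# Sketch of the documented remove_all contract; the helper itself lives in
# vllm/v1/core/sched/utils.py.
from vllm.v1.core.sched.utils import remove_all

queue = ["req-0", "req-1", "req-2"]
queue = remove_all(queue, {"req-1"})           # single item: in-place fast path -> ["req-0", "req-2"]
queue = remove_all(queue, {"req-0", "req-2"})  # multiple items: new list -> []
```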
@@ -40,11 +40,13 @@ def remove_all(lst: list, items_to_remove: set) -> list: return [item for item in lst if item not in items_to_remove] -def check_stop(request: Request, - max_model_len: int, - pooler_output: Optional[torch.Tensor] = None) -> bool: - if (request.num_tokens >= max_model_len - or request.num_output_tokens >= request.max_tokens): +def check_stop( + request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None +) -> bool: + if ( + request.num_tokens >= max_model_len + or request.num_output_tokens >= request.max_tokens + ): request.status = RequestStatus.FINISHED_LENGTH_CAPPED return True @@ -56,9 +58,12 @@ def check_stop(request: Request, sampling_params = request.sampling_params assert sampling_params is not None + + if request.num_output_tokens < sampling_params.min_tokens: + return False + last_token_id = request.output_token_ids[-1] - if (not sampling_params.ignore_eos - and last_token_id == request.eos_token_id): + if not sampling_params.ignore_eos and last_token_id == request.eos_token_id: request.status = RequestStatus.FINISHED_STOPPED return True diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 82e0292522b9a..7984a6ce29df7 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,15 +7,21 @@ from collections import defaultdict from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) from vllm.v1.request import Request class SingleTypeKVCacheManager(ABC): """ - An abstract base class for a manager that handle the kv cache management + An abstract base class for a manager that handles the kv cache management logic of one specific type of attention layer. """ @@ -24,6 +30,7 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, + dcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -32,35 +39,36 @@ class SingleTypeKVCacheManager(ABC): block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. """ self.block_size = kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if self.dcp_world_size > 1: + self.block_size *= dcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. + # data for preempted ones.
self.num_cached_block: dict[str, int] = {} self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. @@ -70,20 +78,23 @@ class SingleTypeKVCacheManager(ABC): """ num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = (num_required_blocks - len(new_computed_blocks) - - len(self.req_to_blocks[request_id])) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count # it as needed to be allocated. num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks) + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: """ Add the new computed blocks to the request. @@ -102,15 +113,16 @@ class SingleTypeKVCacheManager(ABC): # A running request. Should not have new computed blocks. assert len(new_computed_blocks) == 0 - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). Returns: @@ -132,12 +144,15 @@ class SingleTypeKVCacheManager(ABC): Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ num_cached_blocks = self.num_cached_block[request.request_id] num_full_blocks = num_tokens // self.block_size + if num_cached_blocks >= num_full_blocks: + return + self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], @@ -167,20 +182,17 @@ class SingleTypeKVCacheManager(ABC): self.num_cached_block.pop(request_id, None) @abstractmethod - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ - Get the number of common prefix blocks for all requests in the RUNNING - state. + Get the number of common prefix blocks for all requests with allocated + KV cache. Args: - request_id: The request ID. 
- num_running_requests: The total number of requests in the RUNNING - state. + running_request_id: The request ID. Returns: - The number of common prefix blocks for all requests in the RUNNING - state. + The number of common prefix blocks for all requests with allocated + KV cache. """ raise NotImplementedError @@ -195,14 +207,15 @@ class SingleTypeKVCacheManager(ABC): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ - Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. The prefix should be a common prefix hit for all the - kv cache groups in `kv_cache_group_ids`. If no cache hit is found, - return an empty list. - If eagle is enabled, drop the last matched block to force recompute the - last block to get the required hidden states for eagle drafting head. + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. Args: @@ -227,10 +240,9 @@ class SingleTypeKVCacheManager(ABC): raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and free the + Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. @@ -242,7 +254,6 @@ class SingleTypeKVCacheManager(ABC): class FullAttentionManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -252,20 +263,28 @@ class FullAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) - ), "FullAttentionManager can only be used for full attention " \ + ), ( + "FullAttentionManager can only be used for full attention " "and chunked local attention groups" + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) - max_num_blocks = max_length // kv_cache_spec.block_size + [] for _ in range(len(kv_cache_group_ids)) + ) + block_size = kv_cache_spec.block_size + if dcp_world_size > 1: + block_size *= dcp_world_size + max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: @@ -275,17 +294,15 @@ class FullAttentionManager(SingleTypeKVCacheManager): computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # No need to remove blocks for full attention. pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: - blocks = self.req_to_blocks[request_id] + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] num_common_blocks = 0 for block in blocks: - if block.ref_cnt == num_running_requests: + if block.ref_cnt == len(self.req_to_blocks): num_common_blocks += 1 else: break @@ -293,9 +310,9 @@ class FullAttentionManager(SingleTypeKVCacheManager): class SlidingWindowManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - **kwargs) -> None: + def __init__( + self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window self._null_block = block_pool.null_block @@ -309,14 +326,18 @@ class SlidingWindowManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( - "SlidingWindowManager can only be used for sliding window groups") + "SlidingWindowManager can only be used for sliding window groups" + ) + assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window sliding_window_contiguous_blocks = cdiv( - kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size + ) if use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of @@ -330,14 +351,17 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size - computed_blocks = tuple([block_pool.null_block] * max_num_blocks - for _ in range(len(kv_cache_group_ids))) + computed_blocks = tuple( + [block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids)) + ) num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached num_contiguous_blocks += 1 @@ -346,7 +370,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. 
for computed in computed_blocks: - del computed[i + num_contiguous_blocks:] + del computed[i + num_contiguous_blocks :] match_found = True break else: @@ -361,8 +385,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 @@ -379,21 +402,20 @@ class SlidingWindowManager(SingleTypeKVCacheManager): blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ NOTE(Chen): The prefix blocks are null blocks for sliding window layers. - So it's not correct to count ref_cnt like FullAttentionManager. Return - 0 here for correctness. Need to support cascade attention + sliding + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding window in the future. """ return 0 class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, - block_pool: BlockPool, **kwargs) -> None: + def __init__( + self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size self._null_block = block_pool.null_block @@ -407,25 +429,26 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the kv cache groups in `kv_cache_group_ids`. If no cache hit is found, return an empty list. - note we mark as computed if the whole block is outside of the local + note we mark as computed if the whole block is outside of the local window, and set the block as null. Examples: 1. Attention chunk size of 8, block size of 4, max length of 15 - for next token at 15th (zero-indexed), 8th - 14th tokens are in - the window(needs lookup), 0th - 7th are not in the window, - so they are already marked as computed. We check the complete - block3 (8th - 11th tokens), Assume block 3 is hit, we will return + for next token at 15th (zero-indexed), 8th - 14th tokens are in + the window(needs lookup), 0th - 7th are not in the window, + so they are already marked as computed. We check the complete + block3 (8th - 11th tokens), Assume block 3 is hit, we will return [null, null, block 3], otherwise, we return [null, null] 2. Attention chunk size of 8, block size of 4, max length of 16 - for next token at 16th (zero-indexed), 0th - 15th tokens are not - in the window, so they are already marked as computed. + for next token at 16th (zero-indexed), 0th - 15th tokens are not + in the window, so they are already marked as computed. 
we return 4 blocks[null, null, null, null] Args: @@ -440,38 +463,45 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): A list of cached blocks """ assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), ( - "ChunkedLocalAttentionManager can only be used for " + - "chunked local attention groups") - assert use_eagle is False, ("Hybrid KV cache is not supported for " + - "eagle + chunked local attention.") + "ChunkedLocalAttentionManager can only be used for " + + "chunked local attention groups" + ) + assert use_eagle is False, ( + "Hybrid KV cache is not supported for " + "eagle + chunked local attention." + ) + assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: - local_attention_start_idx = (max_length // - kv_cache_spec.attention_chunk_size * - kv_cache_spec.attention_chunk_size) + local_attention_start_idx = ( + max_length + // kv_cache_spec.attention_chunk_size + * kv_cache_spec.attention_chunk_size + ) else: local_attention_start_idx = 0 # we marked blocks out of window as computed # with null blocks, and blocks inside window based on cache lookup # result [null] [null] ... [null] [hit block 1 (1st block contain # last window)] [hit block 2] ... [hit block x] - local_attention_start_block_idx = (local_attention_start_idx // - kv_cache_spec.block_size) + local_attention_start_block_idx = ( + local_attention_start_idx // kv_cache_spec.block_size + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [block_pool.null_block] * local_attention_start_block_idx - for _ in range(len(kv_cache_group_ids))) + for _ in range(len(kv_cache_group_ids)) + ) for i in range(local_attention_start_block_idx, max_num_blocks): block_hash = block_hashes[i] if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: break return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the chunked attention # window and skipped during the attention computation. @@ -483,13 +513,14 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): # is 1024. for 1023, it will be 0. num_cached_block = self.num_cached_block.get(request_id, 0) local_attention_start_idx = ( - num_computed_tokens - ) // self.attention_chunk_size * self.attention_chunk_size + (num_computed_tokens) + // self.attention_chunk_size + * self.attention_chunk_size + ) first_useful_block_idx = local_attention_start_idx // self.block_size if num_cached_block > 0: # Make sure we don't delete the last cached block - first_useful_block_idx = min(first_useful_block_idx, - num_cached_block - 1) + first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1) # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> # block 8, 372 (= 128 * 2 + 116) -> block 2 blocks = self.req_to_blocks[request_id] @@ -505,8 +536,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: """ cascade attention is not supported by chunked local attention. 
""" @@ -514,6 +544,102 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): class MambaManager(SingleTypeKVCacheManager): + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + dcp_world_size: int = 1, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, MambaSpec), ( + "MambaManager can only be used for mamba groups" + ) + assert dcp_world_size == 1, "DCP not support mamba now." + computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( + [] for _ in range(len(kv_cache_group_ids)) + ) + + max_num_blocks = max_length // kv_cache_spec.block_size + # Search from right to left and early stop when a match is found. + for i in range(max_num_blocks - 1, -1, -1): + if cached_block := block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids + ): + for computed, cached in zip(computed_blocks, cached_block): + # the hit length logic later assumes: + # hit_length = len(hit_blocks_other_attn[0]) + # * self.other_block_size + # so we insert dummy blocks at the beginning: + computed.extend([block_pool.null_block] * i) + computed.append(cached) + break # we just need the last match - early stopping + + return computed_blocks + + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: + # Here unused blocks may be freed up for running requests. + # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 + # (for which find_longest_cache_hit returns block_pool.null_block) + pass + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + """ + cascade attention is not supported by mamba + """ + return 0 + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks + ) + + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().allocate_new_blocks(request_id, num_tokens) + + +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so `new_computed_blocks` should always be empty. + assert len(new_computed_blocks) == 0 + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. 
+ raise ValueError("Should not be called as prefix caching is disabled.") + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 @classmethod def find_longest_cache_hit( @@ -524,44 +650,38 @@ class MambaManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance( - kv_cache_spec, - MambaSpec), ("MambaManager can only be used for mamba groups") - # Prefix caching is not supported for mamba now. Always return empty - # list. - computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) - return computed_blocks + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. Encoder states are computed once per request, not incrementally + # 3. No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + raise NotImplementedError("CrossAttentionManager does not support caching") - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Each request will always have 1 block at this moment, so no need to - # remove blocks. + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: - return 0 - - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks - spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } -def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, - **kwargs) -> SingleTypeKVCacheManager: +def get_manager_for_kv_cache_spec( + kv_cache_spec: KVCacheSpec, **kwargs +) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 02e65820b7c00..ce47147028696 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -2,29 +2,27 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor -from vllm.logger import init_logger - -logger = init_logger(__name__) class CudagraphDispatcher: """ - Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs. 
+ Runtime cudagraph dispatcher to dispatch keys for multiple set of + cudagraphs. The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one - for FULL cudagraph runtime mode. The keys are initialized depending on - attention support and what cudagraph mode is set in CompilationConfig. The + for FULL cudagraph runtime mode. The keys are initialized depending on + attention support and what cudagraph mode is set in CompilationConfig. The keys stored in dispatcher are the only source of truth for valid cudagraphs that can be dispatched at runtime. - At runtime, the dispatch method generates the runtime cudagraph mode (FULL, + At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) - based on the input key. After dispatching (commuicate via forward context), - the cudagraph wrappers will trust the dispatch key to do either capturing - or replaying (if mode matched), or pass through to the underlying runnable - without cudagraph (if mode no match or mode is NONE). + based on the input key. After dispatching (communicated via forward + context), the cudagraph wrappers will trust the dispatch key to either + capture or replay (if the mode matches), or pass through to the underlying + runnable without cudagraph (if the mode does not match or mode is NONE). """ def __init__(self, vllm_config: VllmConfig): @@ -38,78 +36,93 @@ class CudagraphDispatcher: CUDAGraphMode.FULL: set(), } - assert not self.cudagraph_mode.requires_piecewise_compilation() or \ - (self.compilation_config.level == CompilationLevel.PIECEWISE and - self.compilation_config.splitting_ops_contain_attention()), \ - "Compilation level should be CompilationLevel.PIECEWISE when "\ - "cudagraph_mode piecewise cudagraphs is used, "\ - f"cudagraph_mode={self.cudagraph_mode}, "\ - f"compilation_level={self.compilation_config.level}, "\ + not_use_piecewise_compilation = ( + not self.cudagraph_mode.requires_piecewise_compilation() + ) + + assert ( + not_use_piecewise_compilation + or self.compilation_config.is_attention_compiled_piecewise() + ), ( + "Compilation level should be CompilationLevel.PIECEWISE when " + "cudagraph_mode piecewise cudagraphs is used, " + "and attention should be in splitting_ops or " + "inductor splitting should be used. " + f"cudagraph_mode={self.cudagraph_mode}, " + f"compilation_level={self.compilation_config.level}, " f"splitting_ops={self.compilation_config.splitting_ops}" + ) self.keys_initialized = False - def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor): - assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ - f"Invalid cudagraph runtime mode: {runtime_mode}" + def add_cudagraph_key( + self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor + ): + assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( + f"Invalid cudagraph runtime mode for keys: {runtime_mode}" + ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) - def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, - uniform_decode_query_len: int): + def initialize_cudagraph_keys( + self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + ): # This should be called only after attention backend is initialized. - # Note: we create all valid keys possible for cudagraph but do not - # guarantee all keys would be used. 
For example, we create keys for - # piecewise cudagraphs when it is piecewise compilation, which is always - # valid, but for attention backend support unified routine, we may not - # trigger capturing/replaying the piecewise cudagraphs depending on - # CompilationConfig.cudagraph_mode. In addition, if we allow lazy + # Note: we create all valid keys for cudagraph here but do not + # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: for bs in self.compilation_config.cudagraph_capture_sizes: self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor(num_tokens=bs, uniform_decode=False)) + BatchDescriptor(num_tokens=bs, uniform_decode=False), + ) # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ - and cudagraph_mode.separate_routine(): - max_num_tokens = uniform_decode_query_len * \ - self.vllm_config.scheduler_config.max_num_seqs + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + uniform_decode_query_len + * self.vllm_config.scheduler_config.max_num_seqs + ) cudagraph_capture_sizes_for_decode = [ - x for x in self.compilation_config.cudagraph_capture_sizes + x + for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] for bs in cudagraph_capture_sizes_for_decode: self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True)) + BatchDescriptor(num_tokens=bs, uniform_decode=True), + ) self.keys_initialized = True def dispatch( - self, batch_descriptor: BatchDescriptor + self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: """ - Given a batch descriptor, dispatch to a cudagraph mode. - A new batch descriptor is returned as we might dispatch a uniform batch + Given conditions(e.g.,batch descriptor and if using cascade attention), + dispatch to a cudagraph runtime mode and the valid batch descriptor. + A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). """ # if not initialized, just skip dispatching. if not self.keys_initialized: - logger.warning_once("cudagraph dispatching keys are not " - "initialized. 
No cudagraph will be used.") return CUDAGraphMode.NONE, None - # check if key exists for full cudagraph - if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, batch_descriptor - - # otherwise, check if non-uniform key exists non_uniform_key = batch_descriptor.non_uniform - if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, non_uniform_key + # if a batch use cascade attention, bypass checking full cudagraphs + if not use_cascade_attn: + # check if key exists for full cudagraph + if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, batch_descriptor + + # otherwise, check if non-uniform key exists + if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]: + return CUDAGraphMode.FULL, non_uniform_key # also check if non-uniform key exists for more "general" # piecewise cudagraph diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index f7ec982db41b4..163c050e559e0 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,14 +3,14 @@ import enum import time -from collections.abc import Sequence +from collections.abc import Mapping from typing import Any, Optional, Union import msgspec import torch from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats @@ -32,6 +32,7 @@ class FinishReason(enum.IntEnum): abort - aborted for another reason """ + STOP = 0 LENGTH = 1 ABORT = 2 @@ -41,16 +42,14 @@ class FinishReason(enum.IntEnum): class EngineCoreRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str - prompt_token_ids: list[int] - mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]] - mm_hashes: Optional[list[str]] - mm_placeholders: Optional[list[PlaceholderRange]] + prompt_token_ids: Optional[list[int]] + mm_features: Optional[list[MultiModalFeatureSpec]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] eos_token_id: Optional[int] @@ -58,6 +57,7 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] data_parallel_rank: Optional[int] + prompt_embeds: Optional[torch.Tensor] = None # Index of the client, used to ensure outputs are sent back to the same # client for this request when scaling out the front-end. @@ -69,9 +69,12 @@ class EngineCoreRequest( current_wave: int = 0 priority: int = 0 + trace_headers: Optional[Mapping[str, str]] = None + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" + QUEUED = 1 SCHEDULED = 2 PREEMPTED = 3 @@ -84,23 +87,24 @@ class EngineCoreEvent(msgspec.Struct): frontend to calculate intervals between engine core events. These timestamps should not be compared with timestamps from other processes. 
""" + type: EngineCoreEventType timestamp: float @classmethod - def new_event(cls, - event_type: EngineCoreEventType, - timestamp: Optional[float] = None) -> "EngineCoreEvent": + def new_event( + cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None + ) -> "EngineCoreEvent": timestamp = time.monotonic() if timestamp is None else timestamp return cls(event_type, timestamp) class EngineCoreOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str new_token_ids: list[int] @@ -114,6 +118,7 @@ class EngineCoreOutput( events: Optional[list[EngineCoreEvent]] = None kv_transfer_params: Optional[dict[str, Any]] = None + trace_headers: Optional[Mapping[str, str]] = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 @@ -130,10 +135,10 @@ class UtilityResult: class UtilityOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] call_id: int # Non-None implies the call failed, result should be None. @@ -142,12 +147,12 @@ class UtilityOutput( class EngineCoreOutputs( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - - #NOTE(Nick): We could consider ways to make this more compact, + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] + # NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout engine_index: int = 0 @@ -177,12 +182,13 @@ class EngineCoreRequestType(enum.Enum): Request types defined as hex byte strings, so it can be sent over sockets without separate encoding step. """ - ADD = b'\x00' - ABORT = b'\x01' - START_DP_WAVE = b'\x02' - UTILITY = b'\x03' + + ADD = b"\x00" + ABORT = b"\x01" + START_DP_WAVE = b"\x02" + UTILITY = b"\x03" # Sentinel used within EngineCoreProc. - EXECUTOR_FAILED = b'\x04' + EXECUTOR_FAILED = b"\x04" class ReconfigureDistributedRequest(msgspec.Struct): @@ -197,5 +203,6 @@ class ReconfigureRankType(enum.IntEnum): """ Rank type for reconfiguring distributed request. 
""" + KEEP_CURRENT_RANK = -1 SHUTDOWN_CURRENT_RANK = -2 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 342d7b24f8e98..112ec92b3af8e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -12,31 +12,29 @@ import numpy as np import torch import vllm.envs as envs -from vllm.config import ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient +from vllm.entrypoints.utils import _validate_truncation_size from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.tracing import init_tracer +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, - deprecate_kwargs) +from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError -from vllm.v1.engine.output_processor import (OutputProcessor, - RequestOutputCollector) +from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -48,7 +46,6 @@ logger = init_logger(__name__) class AsyncLLM(EngineClient): - def __init__( self, vllm_config: VllmConfig, @@ -89,35 +86,44 @@ class AsyncLLM(EngineClient): "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Ensure we can serialize custom transformer configs maybe_register_config_serialize_by_value() self.model_config = vllm_config.model_config self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.log_requests = log_requests - self.log_stats = log_stats + + self.log_stats = log_stats or (stat_loggers is not None) + if not log_stats and stat_loggers is not None: + logger.info( + "AsyncLLM created with log_stats=False and non-empty custom " + "logger list; enabling logging without default stat loggers" + ) if self.model_config.skip_tokenizer_init: - self.tokenizer = None + tokenizer = None else: - # Tokenizer (+ ensure liveness if running in another process). 
- self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + tokenizer = init_tokenizer_from_configs(self.model_config) - # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor( - vllm_config=vllm_config, - tokenizer=self.tokenizer, - mm_registry=mm_registry, + self.processor = Processor(self.vllm_config, tokenizer) + self.io_processor = get_io_processor( + self.vllm_config, + self.model_config.io_processor_plugin, ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) + self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_async_mp_client( @@ -136,6 +142,8 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + client_count=client_count, ) self.logger_manager.log_engine_initialized() @@ -150,7 +158,8 @@ class AsyncLLM(EngineClient): if envs.VLLM_TORCH_PROFILER_DIR: logger.info( "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR) + envs.VLLM_TORCH_PROFILER_DIR, + ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" self.profiler = torch.profiler.profile( activities=[ @@ -158,40 +167,39 @@ class AsyncLLM(EngineClient): ], with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, - worker_name=worker_name, - use_gzip=True)) - else: - logger.info( - "Torch profiler disabled. AsyncLLM CPU traces will not be collected." # noqa: E501 + envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True + ), ) + else: self.profiler = None @classmethod @deprecate_kwargs( "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), + additional_message=( + "This argument will have no effect. Use `enable_log_requests` instead." + ), ) def from_vllm_config( - cls, - vllm_config: VllmConfig, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. 
As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Create the LLMEngine. return cls( @@ -251,7 +259,7 @@ class AsyncLLM(EngineClient): async def add_request( self, request_id: str, - prompt: PromptType, + prompt: Union[EngineCoreRequest, PromptType], params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -259,6 +267,7 @@ class AsyncLLM(EngineClient): trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, + prompt_text: Optional[str] = None, ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" @@ -271,33 +280,58 @@ class AsyncLLM(EngineClient): queue = RequestOutputCollector(output_kind=params.output_kind) # Convert Input --> Request. - prompt_str, request = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, priority, data_parallel_rank) + if isinstance(prompt, EngineCoreRequest): + request = prompt + else: + assert prompt_text is None + logger.warning_once( + "Processor has been moved under OpenAIServing and will " + "be removed from AsyncLLM in v0.13." + ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + data_parallel_rank, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if is_pooling or params.n == 1: - await self._add_request(request, prompt_str, None, 0, queue) + await self._add_request(request, prompt_text, None, 0, queue) return queue + # Get the updated SamplingParams from the request, which + # were cloned/updated in processor.process_inputs above. + parent_params = request.sampling_params + assert parent_params is not None + # Fan out child requests (for n>1). - parent_request = ParentRequest(request_id, params) - for idx in range(params.n): - request_id, params = parent_request.get_child_info(idx) - child_request = request if idx == params.n - 1 else copy(request) + parent_request = ParentRequest(request_id, parent_params) + for idx in range(parent_params.n): + request_id, child_params = parent_request.get_child_info(idx) + child_request = request if idx == parent_params.n - 1 else copy(request) child_request.request_id = request_id - child_request.sampling_params = params - await self._add_request(child_request, prompt_str, parent_request, - idx, queue) + child_request.sampling_params = child_params + await self._add_request( + child_request, prompt_text, parent_request, idx, queue + ) return queue - async def _add_request(self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], index: int, - queue: RequestOutputCollector): - + async def _add_request( + self, + request: EngineCoreRequest, + prompt: Optional[str], + parent_req: Optional[ParentRequest], + index: int, + queue: RequestOutputCollector, + ): # Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, prompt, parent_req, index, - queue) + self.output_processor.add_request(request, prompt, parent_req, index, queue) # Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -312,10 +346,13 @@ class AsyncLLM(EngineClient): # re-multiplexed in the API server anyhow. 
async def generate( self, - prompt: PromptType, + prompt: Union[EngineCoreRequest, PromptType], sampling_params: SamplingParams, request_id: str, + *, + prompt_text: Optional[str] = None, lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, @@ -335,20 +372,42 @@ class AsyncLLM(EngineClient): returning the RequestOutput back to the caller. """ + if ( + self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs + ): + raise ValueError( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, please disable it when the requests need " + "prompt logprobs" + ) + try: # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. self._run_output_handler() + if tokenization_kwargs is None: + tokenization_kwargs = {} + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, sampling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, data_parallel_rank=data_parallel_rank, + prompt_text=prompt_text, ) # The output_handler task pushes items into the queue. @@ -404,6 +463,7 @@ class AsyncLLM(EngineClient): output_processor = self.output_processor log_stats = self.log_stats logger_manager = self.logger_manager + processor = self.processor async def output_handler(): try: @@ -412,23 +472,26 @@ class AsyncLLM(EngineClient): outputs = await engine_core.get_output_async() num_outputs = len(outputs.outputs) - iteration_stats = IterationStats() if ( - log_stats and num_outputs) else None + iteration_stats = ( + IterationStats() if (log_stats and num_outputs) else None + ) # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) + slices = (outputs.outputs,) else: slices = np.array_split( outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE), + ) for i, outputs_slice in enumerate(slices): # 2) Process EngineCoreOutputs. processed_outputs = output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) + outputs_slice, outputs.timestamp, iteration_stats + ) # NOTE: RequestOutputs are pushed to their queues. assert not processed_outputs.request_outputs @@ -438,7 +501,8 @@ class AsyncLLM(EngineClient): # 3) Abort any reqs that finished due to stop strings. await engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) + processed_outputs.reqs_to_abort + ) # 4) Logging. 
# TODO(rob): make into a coroutine and launch it in @@ -448,6 +512,7 @@ class AsyncLLM(EngineClient): engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, + mm_cache_stats=processor.stat_mm_cache(), ) except Exception as e: logger.exception("AsyncLLM output_handler failed.") @@ -458,8 +523,9 @@ class AsyncLLM(EngineClient): async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = (request_id, ) if isinstance( - request_id, str) else as_list(request_id) + request_ids = ( + (request_id,) if isinstance(request_id, str) else as_list(request_id) + ) all_request_ids = self.output_processor.abort_requests(request_ids) await self.engine_core.abort_requests_async(all_request_ids) @@ -474,6 +540,7 @@ class AsyncLLM(EngineClient): lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + truncate_prompt_tokens: Optional[int] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ @@ -496,14 +563,22 @@ class AsyncLLM(EngineClient): # to handle startup failure gracefully in the OpenAI server. self._run_output_handler() + if tokenization_kwargs is None: + tokenization_kwargs = {} + _validate_truncation_size( + self.model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs, + ) + q = await self.add_request( request_id, prompt, pooling_params, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, trace_headers=trace_headers, priority=priority, - tokenization_kwargs=tokenization_kwargs, ) # The output_handler task pushes items into the queue. @@ -546,36 +621,26 @@ class AsyncLLM(EngineClient): logger.info("Request %s failed.", request_id) raise EngineGenerateError() from e - async def get_vllm_config(self) -> VllmConfig: - return self.vllm_config + @property + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.processor.tokenizer - async def get_model_config(self) -> ModelConfig: - return self.model_config + @tokenizer.setter + def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + self.processor.tokenizer = tokenizer - async def get_decoding_config(self): - raise ValueError("Not Supported on V1 yet.") - - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.processor.input_preprocessor - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: + async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) - return self.tokenizer.get_lora_tokenizer(lora_request) + return self.tokenizer async def is_tracing_enabled(self) -> bool: - return False + return self.observability_config.otlp_traces_endpoint is not None - async def do_log_stats( - self, - scheduler_outputs=None, - model_output=None, - ) -> None: + async def do_log_stats(self) -> None: if self.logger_manager: self.logger_manager.log() @@ -597,12 +662,10 @@ class AsyncLLM(EngineClient): await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, - device: 
Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: if device == Device.CPU: raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() @@ -633,16 +696,19 @@ class AsyncLLM(EngineClient): """Prevent an adapter from being evicted.""" return await self.engine_core.pin_lora_async(lora_id) - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): + async def collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): """ Perform a collective RPC call to the given path. """ return await self.engine_core.collective_rpc_async( - method, timeout, args, kwargs) + method, timeout, args, kwargs + ) async def wait_for_requests_to_drain(self, drain_timeout: int = 300): """Wait for all requests to be drained.""" @@ -652,16 +718,17 @@ class AsyncLLM(EngineClient): logger.info("Engines are idle, requests have been drained") return - logger.info( - "Engines are still running, waiting for requests to drain...") + logger.info("Engines are still running, waiting for requests to drain...") await asyncio.sleep(1) # Wait 1 second before checking again - raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " - "waiting for requests to drain.") + raise TimeoutError( + f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain." + ) - async def scale_elastic_ep(self, - new_data_parallel_size: int, - drain_timeout: int = 300): + async def scale_elastic_ep( + self, new_data_parallel_size: int, drain_timeout: int = 300 + ): """ Scale up or down the data parallel size by adding or removing engine cores. 
@@ -670,22 +737,24 @@ class AsyncLLM(EngineClient): drain_timeout: Maximum time to wait for requests to drain (seconds) """ - old_data_parallel_size = \ - self.vllm_config.parallel_config.data_parallel_size + old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if old_data_parallel_size == new_data_parallel_size: - logger.info("Data parallel size is already %s, skipping scale", - new_data_parallel_size) + logger.info( + "Data parallel size is already %s, skipping scale", + new_data_parallel_size, + ) return logger.info( - "Waiting for requests to drain before " - "scaling up to %s engines...", new_data_parallel_size) + "Waiting for requests to drain before scaling up to %s engines...", + new_data_parallel_size, + ) await self.wait_for_requests_to_drain(drain_timeout) logger.info( - "Requests have been drained, proceeding with scale " - "to %s engines", new_data_parallel_size) + "Requests have been drained, proceeding with scale to %s engines", + new_data_parallel_size, + ) await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 596edfdbe24f8..9bb08e6db7bec 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -56,7 +56,6 @@ class DPCoordinator: """ def __init__(self, parallel_config: ParallelConfig): - dp_size = parallel_config.data_parallel_size assert dp_size > 1, "Coordinator only used for data parallel" @@ -68,7 +67,8 @@ class DPCoordinator: # either external or hybrid DP LB mode. local_only = not (external_lb or hybrid_lb) front_publish_address = get_engine_client_zmq_addr( - local_only=local_only, host=host) + local_only=local_only, host=host + ) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -84,7 +84,8 @@ class DPCoordinator: "back_output_address": back_output_address, "back_publish_address": back_publish_address, }, - daemon=True) + daemon=True, + ) self.proc.start() self.stats_publish_address = front_publish_address @@ -104,16 +105,12 @@ class DPCoordinator: class EngineState: - def __init__(self): self.request_counts = [0, 0] # [waiting, running] class DPCoordinatorProc: - - def __init__(self, - engine_count: int, - min_stats_update_interval_ms: int = 100): + def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100): set_process_title("DPCoordinator") self.ctx = zmq.Context() @@ -131,7 +128,8 @@ class DPCoordinatorProc: ): coordinator = DPCoordinatorProc( engine_count=engine_count, - min_stats_update_interval_ms=min_stats_update_interval_ms) + min_stats_update_interval_ms=min_stats_update_interval_ms, + ) try: coordinator.process_input_socket( front_publish_address, @@ -141,10 +139,12 @@ class DPCoordinatorProc: except KeyboardInterrupt: logger.info("DP Coordinator process exiting") - def process_input_socket(self, front_publish_address: str, - back_output_address: str, - back_publish_address: str): - + def process_input_socket( + self, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): decoder = MsgpackDecoder(EngineCoreOutputs) # For tracking request wave progression. 
@@ -157,29 +157,33 @@ class DPCoordinatorProc: last_stats_wave = -1 last_step_counts: Optional[list[list[int]]] = None - with make_zmq_socket( + with ( + make_zmq_socket( path=front_publish_address, # IPC ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_front, make_zmq_socket( + ) as publish_front, + make_zmq_socket( path=back_output_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.PULL, bind=True, - ) as output_back, make_zmq_socket( + ) as output_back, + make_zmq_socket( path=back_publish_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_back: - + ) as publish_back, + ): # Wait until all engines subscribe. for _ in self.engines: - if publish_back.recv() != b'\x01': + if publish_back.recv() != b"\x01": logger.error( "DP Coordinator received unexpected message while " - "waiting for engines to subscribe") + "waiting for engines to subscribe" + ) return # Send ready message to engines. publish_back.send(b"READY") @@ -194,15 +198,13 @@ class DPCoordinatorProc: elapsed = int(time.time() * 1000) - last_publish_time # Send at stats_update_interval_ms interval if the stats have # changed, or otherwise every 5 seconds. - wait_for = (self.stats_update_interval_ms - if stats_changed else 5000) + wait_for = self.stats_update_interval_ms if stats_changed else 5000 # Wait at least 50ms to ensure we've received all stats for # the current step. min_timeout = 50 if last_step_counts is None else 0 - events = poller.poll(timeout=max(min_timeout, wait_for - - elapsed)) + events = poller.poll(timeout=max(min_timeout, wait_for - elapsed)) if not events: # Poller timeout - publish current stats to front-ends. if last_step_counts is not None: @@ -212,8 +214,7 @@ class DPCoordinatorProc: engine_req_counts_list = self._get_engine_counts() stats_changed = False - to_publish = (engine_req_counts_list, current_wave, - engines_running) + to_publish = (engine_req_counts_list, current_wave, engines_running) publish_front.send(msgspec.msgpack.encode(to_publish)) last_publish_time = int(time.time() * 1000) continue @@ -223,13 +224,16 @@ class DPCoordinatorProc: if publish_front in events: buffer = publish_front.recv() - if buffer in (b'\x01', b'\x00'): + if buffer in (b"\x01", b"\x00"): # Ignore subscription messages. 
continue decoded = msgspec.msgpack.decode(buffer) - if isinstance(decoded, (list, tuple)) and len( - decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Handle scale up notification new_engine_count = decoded[1] current_count = len(self.engines) @@ -248,13 +252,17 @@ class DPCoordinatorProc: # engine engines_running = False logger.info( - "DPCoordinator scaled up from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled up from %s to %s engines", + current_count, + new_engine_count, + ) else: self.engines = self.engines[:new_engine_count] logger.info( - "DPCoordinator scaled down from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled down from %s to %s engines", + current_count, + new_engine_count, + ) continue # Skip normal engine notification processing # We received a message on the front-end XPUB socket, @@ -270,8 +278,9 @@ class DPCoordinatorProc: engines_running = True wave_state_changed = True - self._send_start_wave(publish_back, current_wave, - engine_to_exclude) + self._send_start_wave( + publish_back, current_wave, engine_to_exclude + ) if output_back in events: # We received a message from one of the engines. @@ -290,21 +299,28 @@ class DPCoordinatorProc: stats = self.engines[eng_index].request_counts stats_step = scheduler_stats.step_counter stats_wave = scheduler_stats.current_wave - if (stats_wave > last_stats_wave - or stats_wave == last_stats_wave - and stats_step > last_stats_step): + if ( + stats_wave > last_stats_wave + or stats_wave == last_stats_wave + and stats_step > last_stats_step + ): if stats_changed: - last_step_counts = self._get_engine_counts( - do_copy=True) + last_step_counts = self._get_engine_counts(do_copy=True) last_stats_step = stats_step last_stats_wave = stats_wave elif stats_wave != last_stats_wave or ( - stats_step != last_stats_step): + stats_step != last_stats_step + ): logger.warning( "Received stats for out-of-order " "step (%d, %d) from engine %d (expected " - "> (%d, %d))", stats_wave, stats_step, - eng_index, last_stats_wave, last_stats_step) + "> (%d, %d))", + stats_wave, + stats_step, + eng_index, + last_stats_wave, + last_stats_step, + ) stats[0] = scheduler_stats.num_waiting_reqs stats[1] = scheduler_stats.num_running_reqs stats_changed = True @@ -315,20 +331,24 @@ class DPCoordinatorProc: # (engines_running==False). if current_wave <= wave: new_wave = wave + 1 - logger.debug("Moving DP wave from %d to %d.", - current_wave, new_wave) + logger.debug( + "Moving DP wave from %d to %d.", current_wave, new_wave + ) current_wave = new_wave engines_running = False wave_state_changed = True elif (wave := outputs.start_wave) is not None and ( - wave > current_wave or - (wave == current_wave and not engines_running)): + wave > current_wave + or (wave == current_wave and not engines_running) + ): # 3. The engine received request for a non-current wave # so we must ensure that other engines progress to the # next wave (race condition handling). 
logger.debug( "Starting wave %d after notification of " - "stale wave request from engine.", wave) + "stale wave request from engine.", + wave, + ) current_wave = wave engines_running = True wave_state_changed = True @@ -339,16 +359,16 @@ class DPCoordinatorProc: publish_front.send(msgspec.msgpack.encode(message)) @staticmethod - def _send_start_wave(socket: zmq.Socket, wave: int, - exclude_engine_index: Optional[int]): + def _send_start_wave( + socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int] + ): """Broadcast the START_DP_WAVE message to all the engines. It includes the current wave number and index of engine which has already received a request with this wave number and so doesn't require additional notification. """ wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) - socket.send_multipart( - (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) def _get_engine_counts(self, do_copy=False) -> list[list[int]]: """Return list of [waiting, running] count lists for each engine.""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 32765cda6482f..e6474d91ffedb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc import os import queue import signal @@ -22,24 +23,41 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, - resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, - get_request_block_hasher, - init_none_hash, - unify_kv_cache_configs) +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.utils import ( + decorate_logs, + get_hash_fn_by_name, + make_zmq_socket, + resolve_obj_by_qualname, + set_process_title, +) +from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_request_block_hasher, + init_none_hash, +) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + ReconfigureRankType, + UtilityOutput, + UtilityResult, +) +from vllm.v1.engine.utils import ( + EngineHandshakeMetadata, + EngineZmqAddresses, + get_device_indices, +) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ 
-54,51 +72,56 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 HANDSHAKE_TIMEOUT_MINS = 5 -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - executor_fail_callback: Optional[Callable] = None): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None, + ): # plugins need to be loaded at the engine/scheduler level too from vllm.plugins import load_general_plugins + load_general_plugins() self.vllm_config = vllm_config - logger.info("Initializing a V1 LLM engine (v%s) with config: %s", - VLLM_VERSION, vllm_config) + logger.info( + "Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, + vllm_config, + ) self.log_stats = log_stats # Setup Model. self.model_executor = executor_class(vllm_config) if executor_fail_callback is not None: - self.model_executor.register_failure_callback( - executor_fail_callback) + self.model_executor.register_failure_callback(executor_fail_callback) self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ - self._initialize_kv_caches(vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( + vllm_config + ) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks - self.collective_rpc("initialize_cache", - args=(num_gpu_blocks, num_cpu_blocks)) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): Scheduler = resolve_obj_by_qualname( - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls + ) else: Scheduler = vllm_config.scheduler_config.scheduler_cls @@ -110,7 +133,8 @@ class EngineCore: "Using configured V1 scheduler class %s. 
" "This scheduler interface is not public and " "compatibility may not be maintained.", - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls, + ) if len(kv_cache_config.kv_cache_groups) == 0: # Encoder models without KV cache don't support @@ -118,46 +142,63 @@ class EngineCore: logger.info("Disabling chunked prefill for model without KVCache") vllm_config.scheduler_config.chunked_prefill_enabled = False + scheduler_block_size = ( + vllm_config.cache_config.block_size + * vllm_config.parallel_config.decode_context_parallel_size + ) + self.scheduler: SchedulerInterface = Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, structured_output_manager=self.structured_output_manager, - include_finished_set=vllm_config.parallel_config.data_parallel_size - > 1, + include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, log_stats=self.log_stats, + block_size=scheduler_block_size, ) self.use_spec_decode = vllm_config.speculative_config is not None + if self.scheduler.connector is not None: # type: ignore + self.model_executor.init_kv_output_aggregator( + self.scheduler.connector.get_finished_count() # type: ignore + ) - self.mm_input_cache_server = MultiModalInputCacheServer( - vllm_config.model_config, MULTIMODAL_REGISTRY) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = engine_receiver_cache_from_config( + vllm_config, mm_registry + ) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput], - SchedulerOutput]]] = None + self.batch_queue: Optional[ + deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] + ] = None if self.batch_queue_size > 1: - logger.info("Batch queue is enabled with size %d", - self.batch_queue_size) - self.batch_queue = queue.Queue(self.batch_queue_size) + logger.info("Batch queue is enabled with size %d", self.batch_queue_size) + self.batch_queue = deque(maxlen=self.batch_queue_size) - self.request_block_hasher: Optional[Callable[[Request], - list[BlockHash]]] = None - if (self.vllm_config.cache_config.enable_prefix_caching - or self.scheduler.get_kv_connector() is not None): - - block_size = vllm_config.cache_config.block_size + self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None + if ( + self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None + ): caching_hash_fn = get_hash_fn_by_name( - vllm_config.cache_config.prefix_caching_hash_algo) + vllm_config.cache_config.prefix_caching_hash_algo + ) init_none_hash(caching_hash_fn) self.request_block_hasher = get_request_block_hasher( - block_size, caching_hash_fn) + scheduler_block_size, caching_hash_fn + ) + + self.step_fn = ( + self.step if self.batch_queue is None else self.step_with_batch_queue + ) def _initialize_kv_caches( - self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + self, vllm_config: VllmConfig + ) -> tuple[int, int, KVCacheConfig]: start = time.time() # Get all kv cache needed by the model @@ -168,52 +209,38 @@ class EngineCore: if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - self.available_gpu_memory_for_kv_cache = \ + 
self.available_gpu_memory_for_kv_cache = ( ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) - available_gpu_memory = [ - self.available_gpu_memory_for_kv_cache - ] * len(kv_cache_specs) + ) + available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( + kv_cache_specs + ) else: # Profiles the peak memory usage of the model to determine how # much memory can be allocated for kv cache. - available_gpu_memory = ( - self.model_executor.determine_available_memory()) - self.available_gpu_memory_for_kv_cache = \ - available_gpu_memory[0] + available_gpu_memory = self.model_executor.determine_available_memory() + self.available_gpu_memory_for_kv_cache = available_gpu_memory[0] else: # Attention free models don't need memory for kv cache available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) - # Get the kv cache tensor size - kv_cache_configs = [ - get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, - available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) - ] - # Since we use a shared centralized controller, we need the - # `kv_cache_config` to be consistent across all workers to make sure - # all the memory operators can be applied to all workers. - unify_kv_cache_configs(kv_cache_configs) - - # All workers have the same kv_cache_config except layer names, so use - # an arbitrary one to initialize the scheduler. - assert all([ - cfg.num_blocks == kv_cache_configs[0].num_blocks - for cfg in kv_cache_configs - ]) - num_gpu_blocks = kv_cache_configs[0].num_blocks + kv_cache_configs = get_kv_cache_configs( + vllm_config, kv_cache_specs, available_gpu_memory + ) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + num_gpu_blocks = scheduler_kv_cache_config.num_blocks num_cpu_blocks = 0 - scheduler_kv_cache_config = kv_cache_configs[0] # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start - logger.info(("init engine (profile, create kv cache, " - "warmup model) took %.2f seconds"), elapsed) + logger.info( + ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), + elapsed, + ) return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config def get_supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -221,29 +248,34 @@ class EngineCore: def add_request(self, request: Request, request_wave: int = 0): """Add request to the scheduler. - + `request_wave`: indicate which wave of requests this is expected to belong to in DP case """ # Validate the request_id type. if not isinstance(request.request_id, str): raise TypeError( - f"request_id must be a string, got {type(request.request_id)}") + f"request_id must be a string, got {type(request.request_id)}" + ) if pooling_params := request.pooling_params: supported_pooling_tasks = [ - task for task in self.get_supported_tasks() - if task in POOLING_TASKS + task for task in self.get_supported_tasks() if task in POOLING_TASKS ] if pooling_params.task not in supported_pooling_tasks: - raise ValueError(f"Unsupported task: {pooling_params.task!r} " - f"Supported tasks: {supported_pooling_tasks}") + raise ValueError( + f"Unsupported task: {pooling_params.task!r} " + f"Supported tasks: {supported_pooling_tasks}" + ) if request.kv_transfer_params is not None and ( - not self.scheduler.get_kv_connector()): - logger.warning("Got kv_transfer_params, but no KVConnector found. 
" - "Disabling KVTransfer for this request.") + not self.scheduler.get_kv_connector() + ): + logger.warning( + "Got kv_transfer_params, but no KVConnector found. " + "Disabling KVTransfer for this request." + ) self.scheduler.add_request(request) @@ -253,8 +285,7 @@ class EngineCore: # TODO: The scheduler doesn't really need to know the # specific finish reason, TBD whether we propagate that # (i.e. client-aborted vs stop criteria met). - self.scheduler.finish_requests(request_ids, - RequestStatus.FINISHED_ABORTED) + self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) def execute_model_with_error_logging( self, @@ -270,8 +301,9 @@ class EngineCore: # error from execute_model itself. # NOTE: This method is exception-free - dump_engine_exception(self.vllm_config, scheduler_output, - self.scheduler.make_stats()) + dump_engine_exception( + self.vllm_config, scheduler_output, self.scheduler.make_stats() + ) raise err def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: @@ -288,12 +320,13 @@ class EngineCore: scheduler_output = self.scheduler.schedule() model_output = self.execute_model_with_error_logging( self.model_executor.execute_model, # type: ignore - scheduler_output) + scheduler_output, + ) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) # type: ignore + scheduler_output, model_output + ) - return (engine_core_outputs, - scheduler_output.total_num_scheduled_tokens > 0) + return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) def post_step(self, model_executed: bool) -> None: if self.use_spec_decode and model_executed: @@ -303,7 +336,8 @@ class EngineCore: self.scheduler.update_draft_token_ids(draft_token_ids) def step_with_batch_queue( - self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: + self, + ) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -317,41 +351,47 @@ class EngineCore: batch in the job queue is finished. 3. Update the scheduler from the output. """ - assert self.batch_queue is not None + batch_queue = self.batch_queue + assert batch_queue is not None - engine_core_outputs = None - scheduler_output = None # Try to schedule a new batch if the batch queue is not full, but # the scheduler may return an empty batch if all requests are scheduled. # Note that this is not blocking. - if not self.batch_queue.full(): + assert len(batch_queue) < self.batch_queue_size + + model_executed = False + if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - if scheduler_output.total_num_scheduled_tokens > 0: - future = self.model_executor.execute_model(scheduler_output) - self.batch_queue.put_nowait( - (future, scheduler_output)) # type: ignore + future = self.model_executor.execute_model(scheduler_output, non_block=True) + batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type] - scheduled_batch = (scheduler_output is not None - and scheduler_output.total_num_scheduled_tokens > 0) + model_executed = scheduler_output.total_num_scheduled_tokens > 0 + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. 
+ return None, True - # If no more requests can be scheduled and the job queue is not empty, - # block until the first batch in the job queue is finished. - # TODO(comaniac): Ideally we should peek the first batch in the - # job queue to check if it's finished before scheduling a new batch, - # but peeking the first element in a queue is not thread-safe, - # so we need more work. - if not scheduled_batch and not self.batch_queue.empty(): - future, scheduler_output = self.batch_queue.get_nowait() + elif not batch_queue: + # Queue is empty. We should not reach here since this method should + # only be called when the scheduler contains requests or the queue + # is non-empty. + return None, False - # Blocking until the first result is available. - model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output) + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + model_output = self.execute_model_with_error_logging( + lambda _: future.result(), scheduler_output + ) - self.batch_queue.task_done() - engine_core_outputs = (self.scheduler.update_from_output( - scheduler_output, model_output)) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) - return engine_core_outputs, scheduled_batch + return engine_core_outputs, model_executed def shutdown(self): self.structured_output_manager.clear_backend() @@ -365,12 +405,18 @@ class EngineCore: def reset_mm_cache(self): # NOTE: Since this is mainly for debugging, we don't attempt to - # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) + # re-sync the internal caches (P0 sender, P1 receiver) if self.scheduler.has_unfinished_requests(): - logger.warning("Resetting the multi-modal cache when requests are " - "in progress may lead to desynced internal caches.") + logger.warning( + "Resetting the multi-modal cache when requests are " + "in progress may lead to desynced internal caches." 
+ ) - self.mm_input_cache_server.reset() + # The cache either exists in EngineCore or WorkerWrapperBase + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() + + self.model_executor.reset_mm_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -385,7 +431,7 @@ class EngineCore: return self.model_executor.is_sleeping def execute_dummy_batch(self): - self.model_executor.collective_rpc("execute_dummy_batch") + self.model_executor.execute_dummy_batch() def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) @@ -405,43 +451,42 @@ class EngineCore: pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.model_executor.save_sharded_state(path=path, - pattern=pattern, - max_size=max_size) + self.model_executor.save_sharded_state( + path=path, pattern=pattern, max_size=max_size + ) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.model_executor.collective_rpc(method, timeout, args, - kwargs) + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, kwargs) def save_tensorized_model( self, tensorizer_config, ) -> None: self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) - def preprocess_add_request( - self, request: EngineCoreRequest) -> tuple[Request, int]: + def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. - + This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ - if request.mm_hashes is not None: - assert request.mm_kwargs is not None + # Note on thread safety: no race condition. + # `mm_receiver_cache` is reset at the end of LLMEngine init, + # and will only be accessed in the input processing thread afterwards. + if self.mm_receiver_cache is not None and request.mm_features: + request.mm_features = self.mm_receiver_cache.get_and_update_features( + request.mm_features + ) - # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, - # and will only accessed in the input processing thread afterwards. - request.mm_kwargs = self.mm_input_cache_server.get_and_update( - request.mm_kwargs, request.mm_hashes) - - req = Request.from_engine_core_request(request, - self.request_block_hasher) + req = Request.from_engine_core_request(request, self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. 
For @@ -455,7 +500,7 @@ class EngineCore: class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" - ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" def __init__( self, @@ -468,37 +513,46 @@ class EngineCoreProc(EngineCore): engine_index: int = 0, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() - self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], - bytes]]() + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]() executor_fail_callback = lambda: self.input_queue.put_nowait( - (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + (EngineCoreRequestType.EXECUTOR_FAILED, b"") + ) self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False - with self._perform_handshakes(handshake_address, identity, - local_client, vllm_config, - client_handshake_address) as addresses: + with self._perform_handshakes( + handshake_address, + identity, + local_client, + vllm_config, + client_handshake_address, + ) as addresses: self.client_count = len(addresses.outputs) # Set up data parallel environment. self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( - addresses.frontend_stats_publish_address) - logger.debug("Has DP Coordinator: %s, stats publish address: %s", - self.has_coordinator, - self.frontend_stats_publish_address) + addresses.frontend_stats_publish_address + ) + logger.debug( + "Has DP Coordinator: %s, stats publish address: %s", + self.has_coordinator, + self.frontend_stats_publish_address, + ) # Only publish request queue stats to coordinator for "internal" # and "hybrid" LB modes . self.publish_dp_lb_stats = ( self.has_coordinator - and not vllm_config.parallel_config.data_parallel_external_lb) + and not vllm_config.parallel_config.data_parallel_external_lb + ) self._init_data_parallel(vllm_config) - super().__init__(vllm_config, executor_class, log_stats, - executor_fail_callback) + super().__init__( + vllm_config, executor_class, log_stats, executor_fail_callback + ) # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, @@ -506,31 +560,44 @@ class EngineCoreProc(EngineCore): # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. ready_event = threading.Event() - input_thread = threading.Thread(target=self.process_input_sockets, - args=(addresses.inputs, - addresses.coordinator_input, - identity, ready_event), - daemon=True) + input_thread = threading.Thread( + target=self.process_input_sockets, + args=( + addresses.inputs, + addresses.coordinator_input, + identity, + ready_event, + ), + daemon=True, + ) input_thread.start() self.output_thread = threading.Thread( target=self.process_output_sockets, - args=(addresses.outputs, addresses.coordinator_output, - self.engine_index), - daemon=True) + args=( + addresses.outputs, + addresses.coordinator_output, + self.engine_index, + ), + daemon=True, + ) self.output_thread.start() # Don't complete handshake until DP coordinator ready message is # received. 
             while not ready_event.wait(timeout=10):
                 if not input_thread.is_alive():
-                    raise RuntimeError(
-                        "Input socket thread died during startup")
+                    raise RuntimeError("Input socket thread died during startup")
                 assert addresses.coordinator_input is not None
                 logger.info("Waiting for READY message from DP Coordinator...")
 
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        gc.collect()
+        gc.freeze()
+
+        # If enabled, attach GC debugger after static variable freeze.
+        maybe_attach_gc_debug_callback()
 
     @contextmanager
     def _perform_handshakes(
@@ -566,18 +633,23 @@ class EngineCoreProc(EngineCore):
         input_ctx = zmq.Context()
         is_local = local_client and client_handshake_address is None
         headless = not local_client
-        handshake = self._perform_handshake(input_ctx, handshake_address,
-                                            identity, is_local, headless,
-                                            vllm_config,
-                                            vllm_config.parallel_config)
+        handshake = self._perform_handshake(
+            input_ctx,
+            handshake_address,
+            identity,
+            is_local,
+            headless,
+            vllm_config,
+            vllm_config.parallel_config,
+        )
         if client_handshake_address is None:
             with handshake as addresses:
                 yield addresses
         else:
             assert local_client
             local_handshake = self._perform_handshake(
-                input_ctx, client_handshake_address, identity, True, False,
-                vllm_config)
+                input_ctx, client_handshake_address, identity, True, False, vllm_config
+            )
             with handshake as addresses, local_handshake as client_addresses:
                 addresses.inputs = client_addresses.inputs
                 addresses.outputs = client_addresses.outputs
@@ -597,16 +669,18 @@ class EngineCoreProc(EngineCore):
         vllm_config: VllmConfig,
         parallel_config_to_update: Optional[ParallelConfig] = None,
     ) -> Generator[EngineZmqAddresses, None, None]:
-        with make_zmq_socket(ctx,
-                             handshake_address,
-                             zmq.DEALER,
-                             identity=identity,
-                             linger=5000,
-                             bind=False) as handshake_socket:
+        with make_zmq_socket(
+            ctx,
+            handshake_address,
+            zmq.DEALER,
+            identity=identity,
+            linger=5000,
+            bind=False,
+        ) as handshake_socket:
             # Register engine with front-end.
-            addresses = self.startup_handshake(handshake_socket, local_client,
-                                               headless,
-                                               parallel_config_to_update)
+            addresses = self.startup_handshake(
+                handshake_socket, local_client, headless, parallel_config_to_update
+            )
             yield addresses
 
             # Send ready message.
@@ -615,14 +689,21 @@ class EngineCoreProc(EngineCore):
             # external LB case for our colocated front-end to use (coordinator
             # only runs with rank 0).
             dp_stats_address = self.frontend_stats_publish_address
-            handshake_socket.send(
-                msgspec.msgpack.encode({
-                    "status": "READY",
-                    "local": local_client,
-                    "headless": headless,
-                    "num_gpu_blocks": num_gpu_blocks,
-                    "dp_stats_address": dp_stats_address,
-                }))
+
+            # Include config hash for DP configuration validation
+            ready_msg = {
+                "status": "READY",
+                "local": local_client,
+                "headless": headless,
+                "num_gpu_blocks": num_gpu_blocks,
+                "dp_stats_address": dp_stats_address,
+            }
+            if vllm_config.parallel_config.data_parallel_size > 1:
+                ready_msg["parallel_config_hash"] = (
+                    vllm_config.parallel_config.compute_hash()
+                )
+
+            handshake_socket.send(msgspec.msgpack.encode(ready_msg))
 
     @staticmethod
     def startup_handshake(
@@ -631,24 +712,29 @@ class EngineCoreProc(EngineCore):
         headless: bool,
         parallel_config: Optional[ParallelConfig] = None,
     ) -> EngineZmqAddresses:
-        # Send registration message.
handshake_socket.send( - msgspec.msgpack.encode({ - "status": "HELLO", - "local": local_client, - "headless": headless, - })) + msgspec.msgpack.encode( + { + "status": "HELLO", + "local": local_client, + "headless": headless, + } + ) + ) # Receive initialization message. logger.info("Waiting for init message from front-end.") if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): - raise RuntimeError("Did not receive response from front-end " - f"process within {HANDSHAKE_TIMEOUT_MINS} " - f"minutes") + raise RuntimeError( + "Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes" + ) init_bytes = handshake_socket.recv() init_message: EngineHandshakeMetadata = msgspec.msgpack.decode( - init_bytes, type=EngineHandshakeMetadata) + init_bytes, type=EngineHandshakeMetadata + ) logger.debug("Received init message: %s", init_message) if parallel_config is not None: @@ -658,10 +744,7 @@ class EngineCoreProc(EngineCore): return init_message.addresses @staticmethod - def run_engine_core(*args, - dp_rank: int = 0, - local_dp_rank: int = 0, - **kwargs): + def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -684,10 +767,9 @@ class EngineCoreProc(EngineCore): engine_core: Optional[EngineCoreProc] = None try: - parallel_config: ParallelConfig = kwargs[ - "vllm_config"].parallel_config + parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: - set_process_title("DPEngineCore", str(dp_rank)) + set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank @@ -731,7 +813,11 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not self.scheduler.has_requests(): + while ( + not self.engines_running + and not self.scheduler.has_requests() + and not self.batch_queue + ): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -752,15 +838,16 @@ class EngineCoreProc(EngineCore): # Step the engine core. outputs, model_executed = self.step_fn() # Put EngineCoreOutputs into the output queue. - for output in (outputs.items() if outputs else ()): + for output in outputs.items() if outputs else (): self.output_queue.put_nowait(output) # Post-step hook. 
self.post_step(model_executed) return model_executed - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: """Dispatch request from client.""" if request_type == EngineCoreRequestType.ADD: @@ -777,29 +864,35 @@ class EngineCoreProc(EngineCore): output.result = UtilityResult(result) except BaseException as e: logger.exception("Invocation of %s method failed", method_name) - output.failure_message = (f"Call to {method_name} method" - f" failed: {str(e)}") + output.failure_message = ( + f"Call to {method_name} method failed: {str(e)}" + ) self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output))) + (client_idx, EngineCoreOutputs(utility_output=output)) + ) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: - logger.error("Unrecognized input request type encountered: %s", - request_type) + logger.error( + "Unrecognized input request type encountered: %s", request_type + ) @staticmethod def _convert_msgspec_args(method, args): """If a provided arg type doesn't match corresponding target method - arg type, try converting to msgspec object.""" + arg type, try converting to msgspec object.""" if not args: return args arg_types = signature(method).parameters.values() assert len(args) <= len(arg_types) return tuple( - msgspec.convert(v, type=p.annotation) if isclass(p.annotation) + msgspec.convert(v, type=p.annotation) + if isclass(p.annotation) and issubclass(p.annotation, msgspec.Struct) - and not isinstance(v, p.annotation) else v - for v, p in zip(args, arg_types)) + and not isinstance(v, p.annotation) + else v + for v, p in zip(args, arg_types) + ) def _send_engine_dead(self): """Send EngineDead status to the EngineCoreClient.""" @@ -810,12 +903,18 @@ class EngineCoreProc(EngineCore): # Wait until msg sent by the daemon before shutdown. self.output_thread.join(timeout=5.0) if self.output_thread.is_alive(): - logger.fatal("vLLM shutdown signal from EngineCore failed " - "to send. Please report this issue.") + logger.fatal( + "vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue." + ) - def process_input_sockets(self, input_addresses: list[str], - coord_input_address: Optional[str], - identity: bytes, ready_event: threading.Event): + def process_input_sockets( + self, + input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes, + ready_event: threading.Event, + ): """Input socket IO thread.""" # Msgpack serialization decoding. @@ -825,24 +924,26 @@ class EngineCoreProc(EngineCore): with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ stack.enter_context( - make_zmq_socket(ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, input_address, zmq.DEALER, identity=identity, bind=False + ) + ) for input_address in input_addresses ] if coord_input_address is None: coord_socket = None else: coord_socket = stack.enter_context( - make_zmq_socket(ctx, - coord_input_address, - zmq.XSUB, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False, + ) + ) # Send subscription message to coordinator. - coord_socket.send(b'\x01') + coord_socket.send(b"\x01") # Register sockets with poller. 
poller = zmq.Poller() @@ -850,7 +951,7 @@ class EngineCoreProc(EngineCore): # Send initial message to each input socket - this is required # before the front-end ROUTER socket can send input messages # back to us. - input_socket.send(b'') + input_socket.send(b"") poller.register(input_socket, zmq.POLLIN) if coord_socket is not None: @@ -863,10 +964,8 @@ class EngineCoreProc(EngineCore): while True: for input_socket, _ in poller.poll(): # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart( - copy=False) - request_type = EngineCoreRequestType( - bytes(type_frame.buffer)) + type_frame, *data_frames = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. if request_type == EngineCoreRequestType.ADD: @@ -878,9 +977,12 @@ class EngineCoreProc(EngineCore): # Push to input queue for core busy loop. self.input_queue.put_nowait((request_type, request)) - def process_output_sockets(self, output_paths: list[str], - coord_output_path: Optional[str], - engine_index: int): + def process_output_sockets( + self, + output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int, + ): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -897,13 +999,19 @@ class EngineCoreProc(EngineCore): with ExitStack() as stack, zmq.Context() as ctx: sockets = [ stack.enter_context( - make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000) + ) for output_path in output_paths ] - coord_socket = stack.enter_context( - make_zmq_socket( - ctx, coord_output_path, zmq.PUSH, bind=False, - linger=4000)) if coord_output_path is not None else None + coord_socket = ( + stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000 + ) + ) + if coord_output_path is not None + else None + ) max_reuse_bufs = len(sockets) + 1 while True: @@ -929,9 +1037,9 @@ class EngineCoreProc(EngineCore): buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = sockets[client_index].send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart( + buffers, copy=False, track=True + ) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) @@ -961,12 +1069,17 @@ class DPEngineCoreProc(EngineCoreProc): # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, local_client, handshake_address, - executor_class, log_stats, client_handshake_address, - dp_rank) + super().__init__( + vllm_config, + local_client, + handshake_address, + executor_class, + log_stats, + client_handshake_address, + dp_rank, + ) def _init_data_parallel(self, vllm_config: VllmConfig): - # Configure GPUs and stateless process group for data parallel. 
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_size = vllm_config.parallel_config.data_parallel_size @@ -981,8 +1094,10 @@ class DPEngineCoreProc(EngineCoreProc): vllm_config.kv_transfer_config.engine_id = ( f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}" ) - logger.debug("Setting kv_transfer_config.engine_id to %s", - vllm_config.kv_transfer_config.engine_id) + logger.debug( + "Setting kv_transfer_config.engine_id to %s", + vllm_config.kv_transfer_config.engine_id, + ) self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() @@ -1000,20 +1115,22 @@ class DPEngineCoreProc(EngineCoreProc): # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - (-1, EngineCoreOutputs(start_wave=self.current_wave))) + (-1, EngineCoreOutputs(start_wave=self.current_wave)) + ) super().add_request(request, request_wave) - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: new_wave, exclude_eng_index = request if exclude_eng_index != self.engine_index and ( - new_wave >= self.current_wave): + new_wave >= self.current_wave + ): self.current_wave = new_wave if not self.engines_running: - logger.debug("EngineCore starting idle loop for wave %d.", - new_wave) + logger.debug("EngineCore starting idle loop for wave %d.", new_wave) self.engines_running = True else: super()._handle_client_request(request_type, request) @@ -1026,11 +1143,10 @@ class DPEngineCoreProc(EngineCoreProc): counts = self.scheduler.get_request_counts() if counts != self.last_counts: self.last_counts = counts - stats = SchedulerStats(*counts, - step_counter=self.step_counter, - current_wave=self.current_wave) - self.output_queue.put_nowait( - (-1, EngineCoreOutputs(scheduler_stats=stats))) + stats = SchedulerStats( + *counts, step_counter=self.step_counter, current_wave=self.current_wave + ) + self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats))) def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -1056,58 +1172,65 @@ class DPEngineCoreProc(EngineCoreProc): # 3) All-reduce operation to determine global unfinished reqs. self.engines_running = self._has_global_unfinished_reqs( - local_unfinished_reqs) + local_unfinished_reqs + ) if not self.engines_running: if self.dp_rank == 0 or not self.has_coordinator: # Notify client that we are pausing the loop. - logger.debug("Wave %d finished, pausing engine loop.", - self.current_wave) + logger.debug( + "Wave %d finished, pausing engine loop.", self.current_wave + ) # In the coordinator case, dp rank 0 sends updates to the # coordinator. Otherwise (offline spmd case), each rank # sends the update to its colocated front-end process. client_index = -1 if self.has_coordinator else 0 self.output_queue.put_nowait( - (client_index, - EngineCoreOutputs(wave_complete=self.current_wave))) + ( + client_index, + EngineCoreOutputs(wave_complete=self.current_wave), + ) + ) # Increment wave count and reset step counter. self.current_wave += 1 self.step_counter = 0 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 32 steps. 
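`_maybe_publish_request_counts` above only enqueues a stats update when the (waiting, running) counts changed since the previous step. A tiny illustration of that change-detection idea follows; the class and field names are stand-ins, not vLLM APIs.

```python
from collections import deque


class CountPublisher:
    """Stand-in for the de-duplicated stats publishing in DPEngineCoreProc."""

    def __init__(self) -> None:
        self.last_counts = None
        self.queue: deque = deque()

    def maybe_publish(self, counts: tuple[int, int]) -> None:
        if counts == self.last_counts:
            return  # nothing changed since the last step; stay quiet
        self.last_counts = counts
        self.queue.append((-1, {"waiting": counts[0], "running": counts[1]}))


pub = CountPublisher()
for counts in [(3, 1), (3, 1), (2, 2)]:
    pub.maybe_publish(counts)
print(len(pub.queue))  # 2 -- the duplicate (3, 1) update was suppressed
```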
self.step_counter += 1 if self.step_counter % 32 != 0: return True - return ParallelConfig.has_unfinished_dp(self.dp_group, - local_unfinished) + return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: stateless_destroy_torch_distributed_process_group(self.dp_group) self.shutdown() parallel_config = self.vllm_config.parallel_config old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank # local rank specifies device visibility, it should not be changed - assert reconfig_request.new_data_parallel_rank_local == \ - ReconfigureRankType.KEEP_CURRENT_RANK - parallel_config.data_parallel_master_ip = \ + assert ( + reconfig_request.new_data_parallel_rank_local + == ReconfigureRankType.KEEP_CURRENT_RANK + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) if reconfig_request.new_data_parallel_rank != -2: self.dp_rank = parallel_config.data_parallel_rank self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port = ( parallel_config.data_parallel_master_port + ) self.model_executor.reinitialize_distributed(reconfig_request) if reconfig_request.new_data_parallel_size > old_dp_size: @@ -1116,17 +1239,21 @@ class DPEngineCoreProc(EngineCoreProc): # engine-cores to new engine-cores so they can directly # use it in _initialize_kv_caches() rather than profiling. 
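The `step_counter % 32` check above throttles the data-parallel finish-sync: between syncs the loop optimistically reports that work remains, and the collective only runs every 32nd step. A self-contained sketch, with a plain callable standing in for `ParallelConfig.has_unfinished_dp`:

```python
class FinishSync:
    def __init__(self, dp_all_reduce) -> None:
        # dp_all_reduce is a stand-in for the real DP all-reduce.
        self.step_counter = 0
        self.dp_all_reduce = dp_all_reduce

    def has_global_unfinished(self, local_unfinished: bool) -> bool:
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            # Skip the collective and keep the busy loop stepping.
            return True
        return self.dp_all_reduce(local_unfinished)


sync = FinishSync(dp_all_reduce=lambda local: local)
results = [sync.has_global_unfinished(False) for _ in range(32)]
print(results[:3], results[-1])  # [True, True, True] False
```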
ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache) + self.dp_group, self.available_gpu_memory_for_kv_cache + ) # NOTE(yongji): newly joined workers require dummy_run even # CUDA graph is not used self.model_executor.collective_rpc("compile_or_warm_up_model") - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) else: - logger.info("Distributed environment reinitialized for DP rank %s", - self.dp_rank) + logger.info( + "Distributed environment reinitialized for DP rank %s", self.dp_rank + ) class DPEngineCoreActor(DPEngineCoreProc): @@ -1146,8 +1273,7 @@ class DPEngineCoreActor(DPEngineCoreProc): ): self.addresses = addresses vllm_config.parallel_config.data_parallel_rank = dp_rank - vllm_config.parallel_config.data_parallel_rank_local = \ - local_dp_rank + vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time, @@ -1166,33 +1292,48 @@ class DPEngineCoreActor(DPEngineCoreProc): # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) - super().__init__(vllm_config, local_client, "", executor_class, - log_stats) + super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices( + vllm_config, local_dp_rank, device_control_env_var + ) + + def _set_cuda_visible_devices( + self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str + ): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices( + device_control_env_var, local_dp_rank, world_size + ) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " - f"base value: \"{os.getenv(device_control_env_var)}\"") from e + f'base value: "{os.getenv(device_control_env_var)}"' + ) from e @contextmanager - def _perform_handshakes(self, handshake_address: str, identity: bytes, - local_client: bool, vllm_config: VllmConfig, - client_handshake_address: Optional[str]): + def _perform_handshakes( + self, + handshake_address: str, + identity: bytes, + local_client: bool, + vllm_config: VllmConfig, + client_handshake_address: Optional[str], + ): """ For Ray, we don't need to actually perform handshake. 
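`_set_cuda_visible_devices` above delegates the index computation to a `get_device_indices` helper. The following is a hedged sketch of how such a per-rank slice can be derived; `device_indices_for_rank` is illustrative and may differ from the real helper (for example, it simply assumes 8 devices when the base variable is unset).

```python
import os
from typing import Optional


def device_indices_for_rank(
    local_dp_rank: int, world_size: int, base_value: Optional[str]
) -> str:
    # Each local DP rank owns a contiguous slice of the visible devices:
    # [rank * world_size, (rank + 1) * world_size).
    physical = base_value.split(",") if base_value else [str(i) for i in range(8)]
    start, end = local_dp_rank * world_size, (local_dp_rank + 1) * world_size
    if end > len(physical):
        raise IndexError(f"rank {local_dp_rank} needs devices [{start}, {end})")
    return ",".join(physical[start:end])


# With 8 assumed GPUs and a per-rank world size of 2, rank 1 gets "2,3".
print(device_indices_for_rank(1, 2, os.getenv("CUDA_VISIBLE_DEVICES")))
```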
All addresses information is known before the actor creation. diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 079dd9a7d38d1..27283411eada9 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,17 +23,29 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path, - in_loop, make_zmq_socket) -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput) +from vllm.utils import ( + close_sockets, + get_open_port, + get_open_zmq_inproc_path, + in_loop, + make_zmq_socket, +) +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + ReconfigureRankType, + UtilityOutput, +) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager, launch_core_engines) +from vllm.v1.engine.utils import ( + CoreEngineActorManager, + CoreEngineProcManager, + launch_core_engines, +) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr @@ -41,14 +53,14 @@ logger = init_logger(__name__) AnyFuture = Union[asyncio.Future[Any], Future[Any]] -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc EngineIdentity = bytes class EngineCoreClient(ABC): """ - EngineCoreClient: subclasses handle different methods for pushing + EngineCoreClient: subclasses handle different methods for pushing and pulling from the EngineCore for asyncio / multiprocessing. Subclasses: @@ -65,16 +77,17 @@ class EngineCoreClient(ABC): executor_class: type[Executor], log_stats: bool, ) -> "EngineCoreClient": - # TODO: support this for debugging purposes. if asyncio_mode and not multiprocess_mode: raise NotImplementedError( "Running EngineCore in asyncio without multiprocessing " - "is not currently supported.") + "is not currently supported." + ) if multiprocess_mode and asyncio_mode: return EngineCoreClient.make_async_mp_client( - vllm_config, executor_class, log_stats) + vllm_config, executor_class, log_stats + ) if multiprocess_mode and not asyncio_mode: return SyncMPClient(vllm_config, executor_class, log_stats) @@ -91,8 +104,14 @@ class EngineCoreClient(ABC): client_index: int = 0, ) -> "MPClient": parallel_config = vllm_config.parallel_config - client_args = (vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + client_args = ( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_external_lb: # External load balancer - client per DP rank. @@ -102,8 +121,7 @@ class EngineCoreClient(ABC): return AsyncMPClient(*client_args) @abstractmethod - def shutdown(self): - ... + def shutdown(self): ... 
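The `make_client` / `make_async_mp_client` pair above selects a client implementation from the run mode. A condensed sketch of that dispatch is shown below; the data-parallel branch is summarized as a single label rather than reproduced exactly.

```python
def select_client(multiprocess_mode: bool, asyncio_mode: bool, dp_size: int) -> str:
    if asyncio_mode and not multiprocess_mode:
        raise NotImplementedError(
            "Running EngineCore in asyncio without multiprocessing "
            "is not currently supported."
        )
    if multiprocess_mode and asyncio_mode:
        # With data parallelism a DP-aware async client variant is chosen;
        # otherwise the plain AsyncMPClient is used.
        return "DP-aware AsyncMPClient" if dp_size > 1 else "AsyncMPClient"
    if multiprocess_mode:
        return "SyncMPClient"
    return "InprocClient"


print(select_client(multiprocess_mode=True, asyncio_mode=True, dp_size=2))
print(select_client(multiprocess_mode=False, asyncio_mode=False, dp_size=1))
```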
def get_output(self) -> EngineCoreOutputs: raise NotImplementedError @@ -153,17 +171,18 @@ class EngineCoreClient(ABC): def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError def dp_engines_running(self) -> bool: @@ -216,24 +235,24 @@ class EngineCoreClient(ABC): async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError class InprocClient(EngineCoreClient): """ - InprocClient: client for in-process EngineCore. Intended + InprocClient: client for in-process EngineCore. Intended for use in LLMEngine for V0-style add_request() and step() EngineCore setup in this process (no busy loop). @@ -245,8 +264,8 @@ class InprocClient(EngineCoreClient): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - outputs, _ = self.engine_core.step() - return outputs.get(0) or EngineCoreOutputs() + outputs, _ = self.engine_core.step_fn() + return outputs and outputs.get(0) or EngineCoreOutputs() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.engine_core.get_supported_tasks() @@ -295,17 +314,18 @@ class InprocClient(EngineCoreClient): def pin_lora(self, lora_id: int) -> bool: return self.engine_core.pin_lora(lora_id) - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.engine_core.save_sharded_state(path, pattern, max_size) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) def dp_engines_running(self) -> bool: @@ -320,8 +340,9 @@ class BackgroundResources: ctx: zmq.Context # If CoreEngineProcManager, it manages local engines; # if CoreEngineActorManager, it manages all engines. 
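The `InprocClient.get_output` change above adds an `outputs and ...` guard so a missing output map no longer raises before the default is returned. A tiny illustration of what that expression evaluates to for the edge cases; `pick_client_outputs` is just a throwaway name.

```python
def pick_client_outputs(outputs, default="EMPTY"):
    # Same shape as `outputs and outputs.get(0) or EngineCoreOutputs()`:
    # handles outputs being None, empty, or missing client index 0.
    return outputs and outputs.get(0) or default


print(pick_client_outputs(None))         # EMPTY
print(pick_client_outputs({}))           # EMPTY
print(pick_client_outputs({1: "x"}))     # EMPTY
print(pick_client_outputs({0: "outs"}))  # outs
```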
- engine_manager: Optional[Union[CoreEngineProcManager, - CoreEngineActorManager]] = None + engine_manager: Optional[Union[CoreEngineProcManager, CoreEngineActorManager]] = ( + None + ) coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None @@ -347,11 +368,15 @@ class BackgroundResources: if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. - loop = self.output_socket._get_loop() - asyncio.get_running_loop() - sockets = (self.output_socket, self.input_socket, - self.first_req_send_socket, self.first_req_rcv_socket, - self.stats_update_socket) + loop = self.output_queue_task._loop if self.output_queue_task else None + + sockets = ( + self.output_socket, + self.input_socket, + self.first_req_send_socket, + self.first_req_rcv_socket, + self.stats_update_socket, + ) tasks = (self.output_queue_task, self.stats_update_task) @@ -359,11 +384,12 @@ class BackgroundResources: close_sockets(sockets) for task in tasks: if task is not None and not task.done(): - task.cancel() + with contextlib.suppress(Exception): + task.cancel() if in_loop(loop): close_sockets_and_tasks() - elif not loop.is_closed(): + elif loop and not loop.is_closed(): loop.call_soon_threadsafe(close_sockets_and_tasks) else: # Loop has been closed, try to clean up directly. @@ -385,11 +411,10 @@ class BackgroundResources: with self.ctx.socket(zmq.PAIR) as shutdown_sender: shutdown_sender.connect(self.shutdown_path) # Send shutdown signal. - shutdown_sender.send(b'') + shutdown_sender.send(b"") def validate_alive(self, frames: Sequence[zmq.Frame]): - if len(frames) == 1 and (frames[0].buffer - == EngineCoreProc.ENGINE_CORE_DEAD): + if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD): self.engine_dead = True raise EngineDeadError() @@ -402,7 +427,7 @@ class MPClient(EngineCoreClient): * pushes EngineCoreRequests via input_socket * pulls EngineCoreOutputs via output_socket - + * AsyncMPClient subclass for AsyncLLM usage * SyncMPClient subclass for LLM usage """ @@ -435,34 +460,36 @@ class MPClient(EngineCoreClient): self.engines_running = False self.stats_update_address: Optional[str] = None - if client_addresses is not None: + if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] - self.stats_update_address = client_addresses.get( - "stats_update_address") + self.stats_update_address = client_addresses.get("stats_update_address") else: # Engines are managed by this client. - with launch_core_engines(vllm_config, executor_class, - log_stats) as (engine_manager, - coordinator, - addresses): + with launch_core_engines(vllm_config, executor_class, log_stats) as ( + engine_manager, + coordinator, + addresses, + ): self.resources.coordinator = coordinator self.resources.engine_manager = engine_manager - (input_address, ) = addresses.inputs - (output_address, ) = addresses.outputs - self.stats_update_address = ( - addresses.frontend_stats_publish_address) + (input_address,) = addresses.inputs + (output_address,) = addresses.outputs + self.stats_update_address = addresses.frontend_stats_publish_address if coordinator is not None: assert self.stats_update_address == ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) # Create input and output sockets. 
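`validate_alive` above treats a single frame equal to `EngineCoreProc.ENGINE_CORE_DEAD` as the death notice emitted by `_send_engine_dead`. A standalone sketch of that sentinel check; the sentinel value and exception class below are assumptions, not the real constants.

```python
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"  # assumed sentinel value


class EngineDeadErrorStub(RuntimeError):
    pass


def validate_alive(frames: list[bytes]) -> None:
    # Normal outputs arrive as multi-frame msgpack payloads; a lone frame
    # matching the sentinel means the engine core is shutting down.
    if len(frames) == 1 and bytes(frames[0]) == ENGINE_CORE_DEAD:
        raise EngineDeadErrorStub("engine core terminated")


validate_alive([b"\x92", b"payload"])  # a regular message passes silently
try:
    validate_alive([ENGINE_CORE_DEAD])
except EngineDeadErrorStub as e:
    print("detected:", e)
```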
self.input_socket = self.resources.input_socket = make_zmq_socket( - self.ctx, input_address, zmq.ROUTER, bind=True) + self.ctx, input_address, zmq.ROUTER, bind=True + ) self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.PULL) + self.ctx, output_address, zmq.PULL + ) parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size @@ -471,19 +498,22 @@ class MPClient(EngineCoreClient): offline_mode = parallel_config.data_parallel_rank_local is not None # Client manages local+remote EngineCores in pure internal LB case. # Client manages local EngineCores in hybrid and external LB case. - local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) num_ranks = dp_local_size if local_engines_only else dp_size - self.engine_ranks_managed = [dp_rank] if offline_mode else list( - range(dp_rank, dp_rank + num_ranks)) + self.engine_ranks_managed = ( + [dp_rank] if offline_mode else list(range(dp_rank, dp_rank + num_ranks)) + ) assert parallel_config.data_parallel_size_local <= len( - self.engine_ranks_managed) + self.engine_ranks_managed + ) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - rank.to_bytes(2, "little") - for rank in self.engine_ranks_managed + rank.to_bytes(2, "little") for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. @@ -491,8 +521,10 @@ class MPClient(EngineCoreClient): sync_input_socket = zmq.Socket.shadow(self.input_socket) while identities: if not sync_input_socket.poll(timeout=600_000): - raise TimeoutError("Timed out waiting for engines to send" - "initial message on input socket.") + raise TimeoutError( + "Timed out waiting for engines to send" + "initial message on input socket." 
+ ) identity, _ = sync_input_socket.recv_multipart() identities.remove(identity) @@ -518,8 +550,9 @@ class MPClient(EngineCoreClient): def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" - return EngineDeadError( - suppress_context=True) if self.resources.engine_dead else e + return ( + EngineDeadError(suppress_context=True) if self.resources.engine_dead else e + ) def ensure_alive(self): if self.resources.engine_dead: @@ -539,8 +572,11 @@ class MPClient(EngineCoreClient): def start_engine_core_monitor(self): """Start a monitor thread for engine core processes.""" engine_manager = self.resources.engine_manager - if (engine_manager is None or not hasattr(engine_manager, 'processes') - or not engine_manager.processes): + if ( + engine_manager is None + or not hasattr(engine_manager, "processes") + or not engine_manager.processes + ): # No engine processes to monitor return @@ -557,23 +593,26 @@ class MPClient(EngineCoreClient): if not _self or _self.resources.engine_dead: return _self.resources.engine_dead = True - proc_name = next(proc.name for proc in engine_processes - if proc.sentinel == died[0]) + proc_name = next( + proc.name for proc in engine_processes if proc.sentinel == died[0] + ) logger.error( - "Engine core proc %s died unexpectedly, " - "shutting down client.", proc_name) + "Engine core proc %s died unexpectedly, shutting down client.", + proc_name, + ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will # cause subsequent operations to raise EngineDeadError - Thread(target=monitor_engine_cores, - daemon=True, - name="MPClientEngineMonitor").start() + Thread( + target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor" + ).start() -def _process_utility_output(output: UtilityOutput, - utility_results: dict[int, AnyFuture]): +def _process_utility_output( + output: UtilityOutput, utility_results: dict[int, AnyFuture] +): """Set the result from a utility method in the waiting future.""" future = utility_results.pop(output.call_id) failure_message = output.failure_message @@ -588,15 +627,17 @@ def _process_utility_output(output: UtilityOutput, # original calling task being cancelled. if failure_message is not None: logger.error( - "Cancelled call to utility method failed " - "with error: %s", failure_message) + "Cancelled call to utility method failed with error: %s", + failure_message, + ) class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__( + self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool + ): super().__init__( asyncio_mode=False, vllm_config=vllm_config, @@ -639,8 +680,7 @@ class SyncMPClient(MPClient): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) except Exception as e: @@ -651,9 +691,11 @@ class SyncMPClient(MPClient): out_socket.close(linger=0) # Process outputs from engine in separate thread. 
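`start_engine_core_monitor` above blocks on the engine processes' sentinels and marks the client dead as soon as one exits. A self-contained sketch of that pattern using only the standard library; the worker and names here are placeholders.

```python
import multiprocessing as mp
import time
from multiprocessing.connection import wait
from threading import Thread


def worker() -> None:
    time.sleep(0.2)  # stand-in for an engine core that later exits


if __name__ == "__main__":
    procs = [mp.Process(target=worker, name=f"EngineCore_{i}") for i in range(2)]
    for p in procs:
        p.start()

    def monitor() -> None:
        # Blocks until at least one process sentinel becomes ready.
        died = wait([p.sentinel for p in procs])
        name = next(p.name for p in procs if p.sentinel == died[0])
        print(f"{name} died unexpectedly, shutting down client")

    Thread(target=monitor, daemon=True, name="MPClientEngineMonitor").start()
    for p in procs:
        p.join()
    time.sleep(0.1)  # let the daemon thread report before exit
```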
- self.output_queue_thread = Thread(target=process_outputs_socket, - name="EngineCoreOutputQueueThread", - daemon=True) + self.output_queue_thread = Thread( + target=process_outputs_socket, + name="EngineCoreOutputQueueThread", + daemon=True, + ) self.output_queue_thread.start() # The thread takes on responsibility for closing the socket. @@ -674,8 +716,7 @@ class SyncMPClient(MPClient): self.ensure_alive() self.free_pending_messages() # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine, request_type.value, - *self.encoder.encode(request)) + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. @@ -689,8 +730,7 @@ class SyncMPClient(MPClient): call_id = uuid.uuid1().int >> 64 future: Future[Any] = Future() self.utility_results[call_id] = future - self._send_input(EngineCoreRequestType.UTILITY, - (0, call_id, method, args)) + self._send_input(EngineCoreRequestType.UTILITY, (0, call_id, method, args)) return future.result() @@ -739,31 +779,33 @@ class SyncMPClient(MPClient): def execute_dummy_batch(self) -> None: self.call_utility("execute_dummy_batch") - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.call_utility("collective_rpc", method, timeout, args, - kwargs) + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.call_utility("collective_rpc", method, timeout, args, kwargs) - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.call_utility("save_sharded_state", path, pattern, max_size) class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): super().__init__( asyncio_mode=True, vllm_config=vllm_config, @@ -772,9 +814,9 @@ class AsyncMPClient(MPClient): client_addresses=client_addresses, ) + self.client_count = client_count self.client_index = client_index - self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, - Exception]]() + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: # If we are running in an asyncio event loop, start the queue task. # Otherwise, it will be started lazily. 
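`call_utility` above correlates a request with its eventual result through a 64-bit call id taken from the top bits of a UUID1, stored in a map of pending futures. A minimal sketch of that correlation, with the socket send elided:

```python
import uuid
from concurrent.futures import Future

utility_results: dict[int, Future] = {}


def start_utility_call(method: str) -> tuple[int, Future]:
    call_id = uuid.uuid1().int >> 64  # keep the high 64 bits
    future: Future = Future()
    utility_results[call_id] = future
    # ... (client_index, call_id, method, args) would be sent on the
    # input socket here ...
    return call_id, future


def on_utility_output(call_id: int, result) -> None:
    # The output thread pops the waiting future and completes it.
    utility_results.pop(call_id).set_result(result)


cid, fut = start_utility_call("is_sleeping")
on_utility_output(cid, False)
print(fut.result())  # False
```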
If it is not started here, @@ -795,10 +837,9 @@ class AsyncMPClient(MPClient): decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], - Awaitable[None]]] = getattr( - self.__class__, - "process_engine_outputs", None) + output_handler: Optional[ + Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] + ] = getattr(self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_socket = resources.output_socket assert output_socket is not None @@ -810,8 +851,7 @@ class AsyncMPClient(MPClient): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) continue if output_handler is not None: @@ -830,7 +870,8 @@ class AsyncMPClient(MPClient): outputs_queue.put_nowait(EngineDeadError()) resources.output_queue_task = asyncio.create_task( - process_outputs_socket(), name="EngineCoreOutputQueueTask") + process_outputs_socket(), name="EngineCoreOutputQueueTask" + ) async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() @@ -843,19 +884,21 @@ class AsyncMPClient(MPClient): raise self._format_exception(outputs) from None return outputs - def _send_input(self, - request_type: EngineCoreRequestType, - request: Any, - engine: Optional[EngineIdentity] = None) -> Awaitable[Any]: + def _send_input( + self, + request_type: EngineCoreRequestType, + request: Any, + engine: Optional[EngineIdentity] = None, + ) -> Awaitable[Any]: if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) - def _send_input_message(self, message: tuple[bytestr, - ...], engine: EngineIdentity, - objects: Any) -> Awaitable[Any]: + def _send_input_message( + self, message: tuple[bytestr, ...], engine: EngineIdentity, objects: Any + ) -> Awaitable[Any]: """ objects is a reference to retain until zmq is finished with the buffers, in case they were extracted from tensors in the request. @@ -863,7 +906,7 @@ class AsyncMPClient(MPClient): self.ensure_alive() self.free_pending_messages() - msg = (engine, ) + message + msg = (engine,) + message if not objects or len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
return self.input_socket.send_multipart(msg, copy=False) @@ -879,17 +922,18 @@ class AsyncMPClient(MPClient): return future async def call_utility_async(self, method: str, *args) -> Any: - return await self._call_utility_async(method, - *args, - engine=self.core_engine) + return await self._call_utility_async(method, *args, engine=self.core_engine) - async def _call_utility_async(self, method: str, *args, - engine: EngineIdentity) -> Any: + async def _call_utility_async( + self, method: str, *args, engine: EngineIdentity + ) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, method, args))) + message = ( + EngineCoreRequestType.UTILITY.value, + *self.encoder.encode((self.client_index, call_id, method, args)), + ) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future @@ -939,38 +983,46 @@ class AsyncMPClient(MPClient): async def pin_lora_async(self, lora_id: int) -> bool: return await self.call_utility_async("pin_lora", lora_id) - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: - await self.call_utility_async("save_sharded_state", path, pattern, - max_size) + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: + await self.call_utility_async("save_sharded_state", path, pattern, max_size) async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return await self.call_utility_async("collective_rpc", method, timeout, - args, kwargs) + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return await self.call_utility_async( + "collective_rpc", method, timeout, args, kwargs + ) class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Assumes external load-balancing by default.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.current_wave = 0 - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) # List of [waiting, running] pair per engine. # Used only by DPLBAsyncMPClient subclass. @@ -978,10 +1030,8 @@ class DPAsyncMPClient(AsyncMPClient): self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=True)) + make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) + ) try: # If we are running in an asyncio event loop, start the stats task. # Otherwise, it will be started lazily. 
@@ -1000,25 +1050,25 @@ class DPAsyncMPClient(AsyncMPClient): # NOTE: running and waiting counts are all global from # the Coordinator include all global EngineCores. This # slice includes just the cores managed by this client. - count_slice = slice(self.engine_ranks_managed[0], - self.engine_ranks_managed[-1] + 1) + count_slice = slice( + self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1 + ) async def run_engine_stats_update_task(): - with (make_zmq_socket(self.ctx, - self.stats_update_address, - zmq.XSUB, - linger=0) as socket, - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=False, - linger=0) as first_req_rcv_socket): + with ( + make_zmq_socket( + self.ctx, self.stats_update_address, zmq.XSUB, linger=0 + ) as socket, + make_zmq_socket( + self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 + ) as first_req_rcv_socket, + ): assert isinstance(socket, zmq.asyncio.Socket) assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) self.resources.stats_update_socket = socket self.resources.first_req_rcv_socket = first_req_rcv_socket # Send subscription message. - await socket.send(b'\x01') + await socket.send(b"\x01") poller = zmq.asyncio.Poller() poller.register(socket, zmq.POLLIN) @@ -1026,23 +1076,27 @@ class DPAsyncMPClient(AsyncMPClient): while True: events = await poller.poll() - if not self.engines_running and len(events) == 2 or ( - events[0][0] == first_req_rcv_socket): + if ( + not self.engines_running + and len(events) == 2 + or (events[0][0] == first_req_rcv_socket) + ): # Check if this is a regular request notification or # scale up notification - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() + buf = first_req_rcv_socket.recv(flags=zmq.NOBLOCK).result() decoded = msgspec.msgpack.decode(buf) - if isinstance( - decoded, - (list, tuple)) and len(decoded) == 2 and decoded[ - 0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Extract new engine count from the decoded message new_engine_count = decoded[1] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_engine_count)) + ("SCALE_ELASTIC_EP", new_engine_count) + ) await socket.send(scale_msg) continue @@ -1053,14 +1107,14 @@ class DPAsyncMPClient(AsyncMPClient): target_eng_index = decoded[1] self.engines_running = True msg = msgspec.msgpack.encode( - (target_eng_index, self.current_wave)) + (target_eng_index, self.current_wave) + ) await socket.send(msg) buf = None while True: # Drain all stats events (we only care about latest). 
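The coordinator control messages above are plain msgpack tuples; after decoding they come back as lists, which is why the code checks `isinstance(decoded, (list, tuple))` rather than comparing against a tuple directly. A quick round trip:

```python
import msgspec

scale_msg = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", 4))
decoded = msgspec.msgpack.decode(scale_msg)
print(decoded)  # ['SCALE_ELASTIC_EP', 4] -- msgpack arrays decode to lists

if (
    isinstance(decoded, (list, tuple))
    and len(decoded) == 2
    and decoded[0] == "SCALE_ELASTIC_EP"
):
    new_engine_count = decoded[1]
    print("scale to", new_engine_count)  # scale to 4
```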
- future: asyncio.Future[bytes] = socket.recv( - flags=zmq.NOBLOCK) + future: asyncio.Future[bytes] = socket.recv(flags=zmq.NOBLOCK) if isinstance(future.exception(), zmq.Again): break buf = future.result() @@ -1074,11 +1128,13 @@ class DPAsyncMPClient(AsyncMPClient): if counts is not None: sliced_counts = counts[count_slice] self.lb_engines = sliced_counts - logger.debug("Received counts: %s (%s)", sliced_counts, - count_slice) + logger.debug( + "Received counts: %s (%s)", sliced_counts, count_slice + ) resources.stats_update_task = asyncio.create_task( - run_engine_stats_update_task()) + run_engine_stats_update_task() + ) async def add_request_async(self, request: EngineCoreRequest) -> None: self._ensure_stats_update_task() @@ -1087,8 +1143,7 @@ class DPAsyncMPClient(AsyncMPClient): request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request(request) - to_await = self._send_input(EngineCoreRequestType.ADD, request, - chosen_engine) + to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: # Notify coordinator that we're sending a request req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine)) @@ -1106,29 +1161,36 @@ class DPLBAsyncMPClient(DPAsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Load-balances between multiple engine processes.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.client_count = client_count # To route aborts to the correct engine. self.reqs_in_flight: dict[str, EngineIdentity] = {} - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) assert len(self.core_engines) > 1 - self.eng_start_index = (len(self.core_engines) * - self.client_index) // client_count + self.eng_start_index = ( + len(self.core_engines) * self.client_index + ) // client_count - def get_core_engine_for_request( - self, request: EngineCoreRequest) -> EngineIdentity: + def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. if (eng_index := request.data_parallel_rank) is None: current_counts = self.lb_engines @@ -1156,14 +1218,19 @@ class DPLBAsyncMPClient(DPAsyncMPClient): async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. 
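The stats task above drains its socket with non-blocking receives and keeps only the last payload, since stale count updates are useless. The diff does this on an asyncio socket by inspecting the returned future for `zmq.Again`; the blocking-socket analogue of the same idiom looks like this (addresses are hypothetical).

```python
import zmq

ctx = zmq.Context.instance()
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://stats-demo")
push = ctx.socket(zmq.PUSH)
push.connect("inproc://stats-demo")

for i in range(5):  # several stale updates pile up before we look
    push.send(b"counts-%d" % i)

latest = None
while True:
    try:
        latest = pull.recv(flags=zmq.NOBLOCK)
    except zmq.Again:
        break  # queue drained; keep only the newest message
print(latest)  # b'counts-4'
```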
- return (await asyncio.gather(*[ - self._call_utility_async(method, *args, engine=engine) - for engine in self.core_engines - ]))[0] + return ( + await asyncio.gather( + *[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ] + ) + )[0] @staticmethod - async def process_engine_outputs(self: "DPLBAsyncMPClient", - outputs: EngineCoreOutputs): + async def process_engine_outputs( + self: "DPLBAsyncMPClient", outputs: EngineCoreOutputs + ): if outputs.finished_requests and self.reqs_in_flight: for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) @@ -1185,25 +1252,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): for engine, req_ids in by_engine.items(): await self._abort_requests(req_ids, engine) - async def _abort_requests(self, request_ids: list[str], - engine: EngineIdentity) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) - - async def _send_reconfig_message( - self, reconfig_request: ReconfigureDistributedRequest, - engine: EngineIdentity) -> asyncio.Future: - """Send reconfiguration message and return the result future without - waiting for completion.""" - call_id = uuid.uuid1().int >> 64 - future = asyncio.get_running_loop().create_future() - self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, "reinitialize_distributed", - (reconfig_request, )))) - await self._send_input_message(message, engine, reconfig_request) - self._ensure_output_queue_task() - return future + async def _abort_requests( + self, request_ids: list[str], engine: EngineIdentity + ) -> None: + await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" @@ -1211,22 +1263,27 @@ class DPLBAsyncMPClient(DPAsyncMPClient): assert new_data_parallel_size != cur_data_parallel_size, ( f"new_data_parallel_size {new_data_parallel_size} must be " - f"different from cur_data_parallel_size {cur_data_parallel_size}") + f"different from cur_data_parallel_size {cur_data_parallel_size}" + ) - assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", ("Only ray DP backend supports scaling elastic EP") + assert self.vllm_config.parallel_config.data_parallel_backend == "ray", ( + "Only ray DP backend supports scaling elastic EP" + ) scale_up = new_data_parallel_size > cur_data_parallel_size if scale_up: - await self._scale_up_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_up_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) else: - await self._scale_down_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_down_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) - async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_up_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale up the data parallel size by creating new engine cores and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) @@ -1234,31 +1291,29 @@ class DPLBAsyncMPClient(DPAsyncMPClient): # Phase 1: Send reconfigure messages to all existing engines and wait # for them to be sent reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + 
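`call_utility_async` above fans a utility call out to every engine and returns only the first result, on the assumption that all engines answer identically. A standalone sketch of that gather-then-index pattern, with the zmq round trip replaced by a stub coroutine:

```python
import asyncio


async def call_one(engine: str, method: str) -> str:
    await asyncio.sleep(0)  # stand-in for the zmq round trip
    return f"{method}@{engine}"


async def call_all(method: str, engines: list[str]) -> str:
    results = await asyncio.gather(*[call_one(e, method) for e in engines])
    return results[0]  # every engine returns the same value; keep the first


print(asyncio.run(call_all("is_sleeping", ["engine0", "engine1"])))
```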
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. - data_parallel_master_port) - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) + reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") # Phase 2: Create new engines now that reconfig messages have been sent # self.resources.engine_manager is guaranteed to be # CoreEngineActorManager for RayDPClient - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_up_elastic_ep( - self.vllm_config, new_data_parallel_size) + self.vllm_config, new_data_parallel_size + ) # Create new CoreEngine objects for the new engines new_engine_identities = set() @@ -1273,7 +1328,8 @@ class DPLBAsyncMPClient(DPAsyncMPClient): if not sync_input_socket.poll(timeout=600_000): raise TimeoutError( "Timed out waiting for new engines to send initial " - "message on input socket.") + "message on input socket." + ) identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) @@ -1285,60 +1341,62 @@ class DPLBAsyncMPClient(DPAsyncMPClient): # stats_update_task connection self._ensure_stats_update_task() scale_up_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_up_marker) # Update the parallel config - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) - async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale down the data parallel size by shutting down and reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. 
- data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. - data_parallel_master_port) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) if cur_dp_rank >= new_data_parallel_size: - reconfig_request.new_data_parallel_rank = \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + reconfig_request.new_data_parallel_rank = ( + ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) + reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_down_elastic_ep( - cur_data_parallel_size, new_data_parallel_size) + cur_data_parallel_size, new_data_parallel_size + ) self._ensure_stats_update_task() scale_down_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_down_marker) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale down completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 04ad51aae0a8c..5efde9e2ff878 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -9,25 +9,26 @@ from tokenizers import Tokenizer from tokenizers.decoders import DecodeStream from transformers import PreTrainedTokenizerFast -from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) + AnyTokenizer, + convert_prompt_ids_to_tokens, + detokenize_incrementally, +) +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) # Only tokenizers >= 0.21.1 supports DecodeStream used for # FastIncrementalDetokenizer. 
-USE_FAST_DETOKENIZER = version.parse( - tokenizers.__version__) >= version.parse("0.21.1") +USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1") # Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042 INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered" class IncrementalDetokenizer: - def __init__(self): self.token_ids: list[int] = [] @@ -35,8 +36,7 @@ class IncrementalDetokenizer: def output_token_ids(self) -> list[int]: return self.token_ids - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: self.token_ids.extend(new_token_ids) return None @@ -49,15 +49,13 @@ class IncrementalDetokenizer: tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "IncrementalDetokenizer": - assert request.sampling_params is not None if tokenizer is None: # No tokenizer => skipping detokenization. return IncrementalDetokenizer() - if USE_FAST_DETOKENIZER and isinstance(tokenizer, - PreTrainedTokenizerFast): + if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast): # Fast tokenizer => use tokenizers library DecodeStream. return FastIncrementalDetokenizer(tokenizer, request) @@ -66,7 +64,6 @@ class IncrementalDetokenizer: class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): - def __init__(self, request: EngineCoreRequest): super().__init__() @@ -88,8 +85,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # Generation data self.output_text = "" - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. @@ -117,21 +113,17 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): self.token_ids.append(new_token_id) self.output_text += self.decode_next(new_token_id) # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014 - if self.min_tokens and len( - self.output_token_ids) <= self.min_tokens: + if self.min_tokens and len(self.output_token_ids) <= self.min_tokens: stop_check_offset = len(self.output_text) - if stop_terminated: - if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization. - self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check. - return None + if skipped_stop_token_id is not None: + # Cleanup after skipping detokenization. + self.token_ids.append(skipped_stop_token_id) # 2) Evaluate stop strings. stop_string = None if self.stop and len(self.output_token_ids) > self.min_tokens: - stop = StopChecker.check_stop_strings( + stop = check_stop_strings( output_text=self.output_text, new_char_count=len(self.output_text) - stop_check_offset, stop=self.stop, @@ -155,8 +147,11 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # We return the full output text if the sequence is finished. 
buffer_length = 0 if finished else self.stop_buffer_length if not delta: - return self.output_text[:-buffer_length] if buffer_length else ( - self.output_text) + return ( + self.output_text[:-buffer_length] + if buffer_length + else (self.output_text) + ) length = len(self.output_text) - buffer_length last_offset = self._last_output_text_offset if last_offset < length: @@ -166,9 +161,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): - - def __init__(self, tokenizer: PreTrainedTokenizerFast, - request: EngineCoreRequest): + def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest): super().__init__(request) sampling_params = request.sampling_params @@ -176,18 +169,18 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): self.request_id = request.request_id self.skip_special_tokens = sampling_params.skip_special_tokens - self.stream = DecodeStream( - skip_special_tokens=self.skip_special_tokens) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) self.tokenizer: Tokenizer = tokenizer._tokenizer # Find a safe place to start. - prompt_suffix = request.prompt_token_ids + prompt_token_ids = request.prompt_token_ids or [] + prompt_suffix = prompt_token_ids prompt_len = len(prompt_suffix) if prompt_len > 4: for i in range(4, min(prompt_len + 1, 24)): - suffix = request.prompt_token_ids[-i:] - if '�' not in self.tokenizer.decode(suffix): + suffix = prompt_token_ids[-i:] + if "�" not in self.tokenizer.decode(suffix): prompt_suffix = suffix break @@ -197,17 +190,18 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): self.spaces_between_special_tokens = ( sampling_params.skip_special_tokens - or sampling_params.spaces_between_special_tokens) + or sampling_params.spaces_between_special_tokens + ) if not self.spaces_between_special_tokens: # Store dict of added token ids so that we can suppress # the spaces between them. - if (added_token_ids := getattr(self.tokenizer, "added_token_ids", - None)) is None: + if ( + added_token_ids := getattr(self.tokenizer, "added_token_ids", None) + ) is None: self.tokenizer.added_token_ids = added_token_ids = { tid: tok.content - for tid, tok in - self.tokenizer.get_added_tokens_decoder().items() + for tid, tok in self.tokenizer.get_added_tokens_decoder().items() } if added_token_ids: @@ -233,8 +227,13 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): def _protected_step(self, next_token_id: int) -> Optional[str]: try: token = self.stream.step(self.tokenizer, next_token_id) + except (OverflowError, TypeError): + # Handle rare observed overflow, still to be diagnosed. + # See https://github.com/vllm-project/vllm/issues/21951. + logger.exception("Encountered invalid token id: %r", next_token_id) + token = None except Exception as e: - if str(e) != INVALID_PREFIX_ERR_MSG: + if not str(e).startswith(INVALID_PREFIX_ERR_MSG): raise e # Recover from edge case where tokenizer can produce non-monotonic, # invalid UTF-8 output, which breaks the internal state of @@ -242,14 +241,15 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): # See https://github.com/vllm-project/vllm/issues/17448. 
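The fast path above drives a `tokenizers` `DecodeStream` one token id at a time, and probes candidate prompt suffixes for the replacement character to avoid starting the stream on a broken UTF-8 boundary. A sketch of both ideas; it assumes network access to fetch an example tokenizer ("gpt2" is only an illustration).

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("gpt2")  # example tokenizer, fetched from the hub
ids = tok.encode("Incremental detokenization, one token at a time.").ids

# Same step() API as FastIncrementalDetokenizer uses.
stream = DecodeStream(skip_special_tokens=True)
pieces = [stream.step(tok, tid) for tid in ids]
print("".join(p for p in pieces if p is not None))

# The safe-suffix probe: a suffix that decodes with U+FFFD would start the
# stream mid-codepoint, so such suffixes are rejected.
print("�" in tok.decode(ids[-4:]))  # False for plain ASCII text
```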
logger.warning( "Encountered invalid prefix detokenization error" - " for request %s, resetting decode stream.", self.request_id) - self.stream = DecodeStream(self.skip_special_tokens) + " for request %s, resetting decode stream.", + self.request_id, + ) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) token = self.stream.step(self.tokenizer, next_token_id) return token class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): super().__init__(request) @@ -257,41 +257,89 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): params = request.sampling_params assert params is not None - # Metadata for incremental detokenization. - self.tokens, self.prefix_offset, self.read_offset = ( - convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=request.prompt_token_ids, - skip_special_tokens=params.skip_special_tokens, - )) + self.prompt_len = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds + ) - self.token_ids.extend(request.prompt_token_ids) - self.prompt_len = len(request.prompt_token_ids) + # Metadata for incremental detokenization. + if request.prompt_token_ids is not None: + self.tokens, self.prefix_offset, self.read_offset = ( + convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=params.skip_special_tokens, + ) + ) + else: + # Prompt embedding requests cannot be detokenized, in general. + self.tokens = [""] * self.prompt_len + self.prefix_offset = 0 + self.read_offest = 0 + + self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len) self.skip_special_tokens = params.skip_special_tokens - self.spaces_between_special_tokens = ( - params.spaces_between_special_tokens) + self.spaces_between_special_tokens = params.spaces_between_special_tokens @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return ( + self.token_ids + if not self.prompt_len + else (self.token_ids[self.prompt_len :]) + ) def decode_next(self, next_token_id: int) -> str: - new_tokens, decoded_text, prefix_offset, read_offset = ( - detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. - spaces_between_special_tokens, - )) + new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + ) self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset self.read_offset = read_offset return decoded_text + + +def check_stop_strings( + output_text: str, + new_char_count: int, + stop: list[str], + include_in_output: bool, +) -> Optional[tuple[str, int]]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. 
+ """ + if not new_char_count or not stop: + return None + + for stop_str in stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len) + if stop_index == -1: + continue + + if include_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(output_text): + # No truncation required. + return stop_str, -1 + + # Truncate the output text to either the beginning + # or end of the stop string. + return stop_str, stop_index + return None diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py index 692ba9dc840f8..d9f79a019e2df 100644 --- a/vllm/v1/engine/exceptions.py +++ b/vllm/v1/engine/exceptions.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project class EngineGenerateError(Exception): """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5a00a930951cc..b2261855d125c 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,37 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from collections.abc import Mapping from copy import copy from typing import Any, Callable, Optional, Union +import torch.nn as nn from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.distributed.parallel_state import get_dp_group from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.transformers_utils.tokenizer_group import ( - TokenizerGroup, init_tokenizer_from_configs) +from vllm.tracing import init_tracer +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.utils import Device +from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase, - StatLoggerFactory) +from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -57,48 +62,60 @@ class LLMEngine: "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "LLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." 
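# Usage sketch for the check_stop_strings() helper defined above; the values
# follow directly from its body (an offset of -1 means "no truncation needed").
text = "Hello world! STOP extra"          # the last 10 chars were just generated
assert check_stop_strings(text, new_char_count=10,
                          stop=["STOP"], include_in_output=False) == ("STOP", 13)
# Truncate to index 13, i.e. drop "STOP extra" from the returned text.
assert check_stop_strings(text, new_char_count=10,
                          stop=["STOP"], include_in_output=True) == ("STOP", 17)
# Keep the stop string itself, drop only the trailing " extra".
assert check_stop_strings("abc STOP", new_char_count=8,
                          stop=["STOP"], include_in_output=True) == ("STOP", -1)
# Stop string already ends the text, so nothing needs to be cut.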
+ ) if stat_loggers is not None: raise NotImplementedError( "Passing StatLoggers to LLMEngine in V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") + "Set VLLM_USE_V1=0 and file and issue on Github." + ) self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.log_stats = log_stats - self.stat_logger: Optional[StatLoggerBase] = None - if self.log_stats: - self.stat_logger = PrometheusStatLogger(vllm_config) + executor_backend = self.vllm_config.parallel_config.distributed_executor_backend + parallel_config = vllm_config.parallel_config + self.external_launcher_dp = ( + parallel_config.data_parallel_size > 1 + and executor_backend == "external_launcher" + ) # important: init dp group before init the engine_core # In the decoupled engine case this is handled in EngineCoreProc. - parallel_config = vllm_config.parallel_config - if not multiprocess_mode and parallel_config.data_parallel_size > 1: + if ( + not multiprocess_mode + and parallel_config.data_parallel_size > 1 + and not self.external_launcher_dp + ): self.dp_group = parallel_config.stateless_init_dp_group() else: self.dp_group = None self.should_execute_dummy_batch = False if self.model_config.skip_tokenizer_init: - self.tokenizer = None + tokenizer = None else: - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + tokenizer = init_tokenizer_from_configs(self.model_config) - # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config=vllm_config, - tokenizer=self.tokenizer, - mm_registry=mm_registry) + self.processor = Processor(self.vllm_config, tokenizer) + self.io_processor = get_io_processor( + self.vllm_config, + self.model_config.io_processor_plugin, + ) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) + self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( @@ -109,10 +126,24 @@ class LLMEngine: log_stats=self.log_stats, ) + self.logger_manager: Optional[StatLoggerManager] = None + if self.log_stats: + self.logger_manager = StatLoggerManager( + vllm_config=vllm_config, + custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + ) + self.logger_manager.log_engine_initialized() + if not multiprocess_mode: # for v0 compatibility self.model_executor = self.engine_core.engine_core.model_executor # type: ignore + if self.external_launcher_dp: + # If we use DP in external launcher mode, we reuse the + # existing DP group used for data communication. 
+ self.dp_group = get_dp_group().cpu_group + # Don't keep the dummy data in memory self.reset_mm_cache() @@ -124,12 +155,14 @@ class LLMEngine: stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - return cls(vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - log_stats=(not disable_log_stats), - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING) + return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING, + ) @classmethod def from_engine_args( @@ -150,12 +183,14 @@ class LLMEngine: enable_multiprocessing = True # Create the LLMEngine. - return cls(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=enable_multiprocessing) + return cls( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing, + ) def get_num_unfinished_requests(self) -> int: return self.output_processor.get_num_unfinished_requests() @@ -168,7 +203,8 @@ class LLMEngine: def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: aggregated_has_unfinished = ParallelConfig.has_unfinished_dp( - self.dp_group, has_unfinished) + self.dp_group, has_unfinished + ) if not has_unfinished and aggregated_has_unfinished: self.should_execute_dummy_batch = True return aggregated_has_unfinished @@ -189,29 +225,45 @@ class LLMEngine: def add_request( self, request_id: str, - prompt: PromptType, + prompt: Union[EngineCoreRequest, PromptType], params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + prompt_text: Optional[str] = None, ) -> None: # Validate the request_id type. if not isinstance(request_id, str): - raise TypeError( - f"request_id must be a string, got {type(request_id)}") + raise TypeError(f"request_id must be a string, got {type(request_id)}") # Process raw inputs into the request. - prompt_str, request = self.processor.process_inputs( - request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, priority) + if isinstance(prompt, EngineCoreRequest): + request = prompt + else: + assert prompt_text is None + logger.warning_once( + "Processor has been moved under LLM and will " + "be removed from LLMEngine in v0.13." + ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") n = params.n if isinstance(params, SamplingParams) else 1 if n == 1: # Make a new RequestState and queue. - self.output_processor.add_request(request, prompt_str, None, 0) + self.output_processor.add_request(request, prompt_text, None, 0) # Add the request to EngineCore. self.engine_core.add_request(request) return @@ -225,13 +277,13 @@ class LLMEngine: child_request.sampling_params = params # Make a new RequestState and queue. 
- self.output_processor.add_request(child_request, prompt_str, - parent_req, idx) + self.output_processor.add_request( + child_request, prompt_text, parent_req, idx + ) # Add the request to EngineCore. self.engine_core.add_request(child_request) def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: - if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False self.engine_core.execute_dummy_batch() @@ -245,25 +297,25 @@ class LLMEngine: processed_outputs = self.output_processor.process_outputs( outputs.outputs, engine_core_timestamp=outputs.timestamp, - iteration_stats=iteration_stats) + iteration_stats=iteration_stats, + ) # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) # 4) Record stats - if self.stat_logger is not None: + if self.logger_manager is not None: assert outputs.scheduler_stats is not None - self.stat_logger.record(scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats) + + self.logger_manager.record( + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + mm_cache_stats=self.processor.stat_mm_cache(), + ) + self.do_log_stats_with_interval() return processed_outputs.request_outputs - def get_vllm_config(self): - return self.vllm_config - - def get_model_config(self): - return self.model_config - def start_profile(self): self.engine_core.profile(True) @@ -271,8 +323,7 @@ class LLMEngine: self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_mm_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): @@ -291,13 +342,36 @@ class LLMEngine: assert self.log_stats, "Stat logging disabled" return get_metrics_snapshot() - def get_tokenizer_group(self) -> TokenizerGroup: + @property + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.processor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + self.processor.tokenizer = tokenizer + + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) return self.tokenizer + def do_log_stats(self) -> None: + """Log stats if logging is enabled.""" + if self.logger_manager: + self.logger_manager.log() + + def do_log_stats_with_interval(self) -> None: + """Log stats when the time interval has passed.""" + now = time.time() + if not hasattr(self, "_last_log_time"): + self._last_log_time = now + if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL: + self.do_log_stats() + self._last_log_time = now + def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" return self.engine_core.add_lora(lora_request) @@ -314,13 +388,21 @@ class LLMEngine: """Prevent an adapter from being evicted.""" return self.engine_core.pin_lora(lora_id) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[[WorkerBase], _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return 
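# A minimal sketch of the time-interval gate used by do_log_stats_with_interval()
# above, with the interval passed in explicitly (the real code reads
# envs.VLLM_LOG_STATS_INTERVAL); class and argument names are illustrative.
import time
from typing import Callable, Optional

class IntervalGate:
    def __init__(self, interval_s: float, log_fn: Callable[[], None]) -> None:
        self.interval_s = interval_s
        self.log_fn = log_fn
        self._last: Optional[float] = None

    def maybe_log(self) -> None:
        now = time.time()
        if self._last is None:
            # First call only records the timestamp, matching the hasattr check.
            self._last = now
        if now - self._last >= self.interval_s:
            self.log_fn()
            self._last = now

# gate = IntervalGate(10.0, engine.do_log_stats)  # engine: an LLMEngine instance
# gate.maybe_log()  # called once per step(); logs at most every 10 seconds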
self.engine_core.collective_rpc(method, timeout, args, kwargs) + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + return self.collective_rpc("apply_model", args=(func,)) + def __del__(self): - if dp_group := getattr(self, "dp_group", None): + if ( + dp_group := getattr(self, "dp_group", None) + and not self.external_launcher_dp + ): stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 3de7fa6889e55..ab0e44fce1558 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -7,9 +7,11 @@ from dataclasses import dataclass from typing import Optional from vllm.logger import init_logger -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, convert_ids_list_to_tokens) + AnyTokenizer, + convert_ids_list_to_tokens, +) from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -20,7 +22,6 @@ NONES = itertools.repeat(None) @dataclass class LogprobsProcessor: - # Tokenizer for this request, # None if detokenization is disabled. tokenizer: Optional[AnyTokenizer] @@ -43,7 +44,7 @@ class LogprobsProcessor: num_prompt_logprobs = request.sampling_params.prompt_logprobs return cls( tokenizer=tokenizer, - cumulative_logprob=(None if num_logprobs is None else 0.), + cumulative_logprob=(None if num_logprobs is None else 0.0), logprobs=(None if num_logprobs is None else []), # NOTE: logprob of first prompt token is None. prompt_logprobs=(None if num_prompt_logprobs is None else [None]), @@ -68,12 +69,13 @@ class LogprobsProcessor: token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, - token_ids_lst): - + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): # Detokenize (non-incrementally). - decoded_tokens = NONES if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, token_ids)) + decoded_tokens = ( + NONES + if self.tokenizer is None + else (convert_ids_list_to_tokens(self.tokenizer, token_ids)) + ) # Sampler puts the sampled logprob in first. sampled_token_logprob = logprobs[0] @@ -87,7 +89,8 @@ class LogprobsProcessor: decoded_tokens, rank, self.num_logprobs, - )) + ) + ) def _update_prompt_logprobs( self, @@ -109,9 +112,13 @@ class LogprobsProcessor: # Detokenize non-incrementally. # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - decoded_tokens = None if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, - token_ids.flatten().tolist())) + decoded_tokens = ( + None + if self.tokenizer is None + else ( + convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist()) + ) + ) # Recover shapes. num_prompt_tokens, num_logprobs = logprobs.shape @@ -126,15 +133,20 @@ class LogprobsProcessor: # Handle flattening. offset = pos * num_logprobs offset_end = offset + num_logprobs - decoded_tokens_for_pos = NONES \ - if decoded_tokens is None else decoded_tokens[offset:offset_end] + decoded_tokens_for_pos = ( + NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + ) # Update with the Logprob dictionary for this pos. 
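# Worked example of the flattening arithmetic above: prompt logprob token ids are
# decoded as one flat list of num_prompt_tokens * num_logprobs strings, and the
# entries for position `pos` are recovered with a stride of num_logprobs.
num_prompt_tokens, num_logprobs = 3, 2
decoded_tokens = ["a0", "a1", "b0", "b1", "c0", "c1"]  # flat, 3 * 2 entries
pos = 1
offset = pos * num_logprobs
offset_end = offset + num_logprobs
assert decoded_tokens[offset:offset_end] == ["b0", "b1"]  # strings for token 1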
self.prompt_logprobs.append( - self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], - decoded_tokens_for_pos, - prompt_token_ranks[pos], - self.num_prompt_logprobs)) + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs, + ) + ) def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: """Pop and return all request prompt logprobs @@ -182,7 +194,7 @@ class LogprobsProcessor: # being in the topk, since inserting duplicated data # into a dictionary twice is the same as doing it once. topk_ranks = range(1, num_logprobs + 1) - ranks = itertools.chain((rank, ), topk_ranks) + ranks = itertools.chain((rank,), topk_ranks) return { token_id: Logprob( @@ -191,7 +203,8 @@ class LogprobsProcessor: decoded_token=token, ) for token_id, logprob, rank, token in zip( - logprob_token_ids, logprobs, ranks, decoded_tokens) + logprob_token_ids, logprobs, ranks, decoded_tokens + ) } def update_from_output(self, output: EngineCoreOutput) -> None: diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index aa7dc62fd4acb..0000000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional - -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import MultiModalKwargsItem -from vllm.utils import is_list_of - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -# The idea of multimodal input caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- P0: -# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of -# each input multi-modal item (e.g. image), -# - BaseMultiModalProcessor processes the input items into `mm_kwargs`, -# which are MultiModalKwargsItem instances that each correspond to an -# input multi-modal item. -# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding -# `mm_hash` for each item. It stores the `mm_hash` as keys and the size -# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking -# up additional memory in P0. -# - The `mm_hash` is always sent to P1. -# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached -# in MultiModalInputCacheServer. -# -# -- P1: -# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0), -# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`. -# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0), -# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`. -# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to -# the engine for model execution. -# -# Both Client and Server must perform cache update and eviction based on the -# same item size. This ensures that the keys of MultiModalInputCacheClient -# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 -# whether a key is cached in MultiModalInputCacheServer by querying -# MultiModalInputCacheClient without having to communicate with P1. 
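# A simplified sketch of the mirrored-cache handshake described in the comment
# block above, using plain dicts instead of size-bounded LRU caches: the client
# (P0) replaces already-sent items with None, the server (P1) fills the Nones
# back in. Both sides must evict identically for the mirroring to hold, which
# this sketch does not model.
from typing import Optional

class ClientSideCache:
    def __init__(self) -> None:
        self.seen_hashes: set[str] = set()

    def get_and_update(
        self, items: list[object], hashes: list[str]
    ) -> list[Optional[object]]:
        out: list[Optional[object]] = []
        for item, mm_hash in zip(items, hashes):
            out.append(None if mm_hash in self.seen_hashes else item)
            self.seen_hashes.add(mm_hash)
        return out

class ServerSideCache:
    def __init__(self) -> None:
        self.store: dict[str, object] = {}

    def get_and_update(
        self, items: list[Optional[object]], hashes: list[str]
    ) -> list[object]:
        out: list[object] = []
        for item, mm_hash in zip(items, hashes):
            if item is None:
                item = self.store[mm_hash]  # previously sent; recover locally
            else:
                self.store[mm_hash] = item  # first sighting; remember it
            out.append(item)
        return out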
- - -class MultiModalInputCacheClient: - """Used by P0 to check whether multi-modal kwargs are cached in P1.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalCacheItemMetadata, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[Optional[MultiModalKwargsItem]]: - if not self.enabled: - return list(mm_kwargs) - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[Optional[MultiModalKwargsItem]]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - out_mm_items.append(None) - else: - self.mm_cache[mm_hash] = \ - MultiModalCacheItemMetadata.wraps(mm_item) - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() - - -class MultiModalInputCacheServer: - """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalKwargsItem, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[Optional[MultiModalKwargsItem]], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - mm_kwargs_lst = list(mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem) - return mm_kwargs_lst - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if mm_item is None: - out_mm_items.append(self.mm_cache[mm_hash]) - else: - self.mm_cache[mm_hash] = mm_item - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2ee55b585da6c..eb65b68969e35 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,17 +8,21 @@ from typing import Any, Optional, Union, cast import torch -from vllm.outputs import (CompletionOutput, PoolingOutput, - PoolingRequestOutput, RequestOutput) +from vllm.outputs import ( + CompletionOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.sampling_params import RequestOutputKind +from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats) +from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats class RequestOutputCollector: @@ -32,12 +36,14 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: 
Optional[Union[RequestOutput, PoolingRequestOutput, - Exception]] = None + self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = ( + None + ) self.ready = asyncio.Event() - def put(self, output: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: + def put( + self, output: Union[RequestOutput, PoolingRequestOutput, Exception] + ) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): self.output = output @@ -57,8 +63,7 @@ class RequestOutputCollector: raise output return output - def get_nowait( - self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: + def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: """Non-blocking get operation.""" output = self.output if output is not None: @@ -71,13 +76,11 @@ class RequestOutputCollector: @dataclass class OutputProcessorOutput: - request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] reqs_to_abort: list[str] class RequestState: - def __init__( self, request_id: str, @@ -86,13 +89,17 @@ class RequestState: lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], - prompt_token_ids: list[int], + prompt_token_ids: Optional[list[int]], + prompt_embeds: Optional[torch.Tensor], logprobs_processor: Optional[LogprobsProcessor], detokenizer: Optional[IncrementalDetokenizer], max_tokens_param: Optional[int], arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + top_p: Optional[float] = None, + n: Optional[int] = None, + temperature: Optional[float] = None, ): self.request_id = request_id self.parent_req = parent_req @@ -101,16 +108,21 @@ class RequestState: self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids - self.prompt_len = len(prompt_token_ids) + self.prompt_embeds = prompt_embeds + self.prompt_len = length_from_prompt_token_ids_or_embeds( + self.prompt_token_ids, self.prompt_embeds + ) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param + self.top_p = top_p + self.n = n + self.temperature = temperature self.is_prefilling = True self.queue = queue self.num_cached_tokens = 0 - self.stats = RequestStateStats( - arrival_time=arrival_time) if log_stats else None + self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None @classmethod def from_new_request( @@ -123,7 +135,6 @@ class RequestState: queue: Optional[RequestOutputCollector], log_stats: bool, ) -> "RequestState": - if sampling_params := request.sampling_params: if not sampling_params.detokenize: tokenizer = None @@ -137,10 +148,16 @@ class RequestState: request=request, ) max_tokens_param = sampling_params.max_tokens + top_p = sampling_params.top_p + n = sampling_params.n + temperature = sampling_params.temperature else: logprobs_processor = None detokenizer = None max_tokens_param = None + top_p = None + n = None + temperature = None assert request.pooling_params is not None output_kind = request.pooling_params.output_kind @@ -148,14 +165,19 @@ class RequestState: request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=(request.lora_request.name - if request.lora_request is not None else None), + lora_name=( + request.lora_request.name if request.lora_request is not None else None + ), output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, + prompt_embeds=request.prompt_embeds, logprobs_processor=logprobs_processor, 
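# A minimal asyncio sketch of the collector pattern above: a single-slot mailbox
# where the producer overwrites or merges and the consumer awaits an Event.
# String concatenation stands in for the real merging of queued DELTA outputs,
# and exceptions always win so the consumer re-raises them.
import asyncio
from typing import Optional, Union

class SimpleOutputCollector:
    def __init__(self) -> None:
        self.output: Optional[Union[str, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[str, Exception]) -> None:
        # Non-blocking producer side.
        if self.output is None or isinstance(output, Exception):
            self.output = output
        elif isinstance(self.output, str):
            self.output += output
        self.ready.set()

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
            self.ready.clear()
        self.output = None
        if isinstance(output, Exception):
            raise output
        return output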
detokenizer=detokenizer, max_tokens_param=max_tokens_param, + top_p=top_p, + n=n, + temperature=temperature, arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -169,7 +191,6 @@ class RequestState: stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: - finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -180,22 +201,23 @@ class RequestState: request_id = self.request_id if pooling_output is not None: return self._new_request_output( - request_id, [self._new_pooling_output(pooling_output)], - finished) + request_id, [self._new_pooling_output(pooling_output)], finished + ) - output = self._new_completion_output(new_token_ids, finish_reason, - stop_reason) + output = self._new_completion_output(new_token_ids, finish_reason, stop_reason) if self.parent_req is None: outputs = [output] else: request_id, outputs, finished = self.parent_req.get_outputs( - request_id, output) + request_id, output + ) if not outputs: return None - return self._new_request_output(request_id, outputs, finished, - kv_transfer_params) + return self._new_request_output( + request_id, outputs, finished, kv_transfer_params + ) def _new_request_output( self, @@ -204,10 +226,11 @@ class RequestState: finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Union[RequestOutput, PoolingRequestOutput]: - first_output = outputs[0] if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 + # Prompt embeddings are currently not supported by pooling requests. + assert self.prompt_token_ids is not None return PoolingRequestOutput( request_id=request_id, outputs=first_output, @@ -221,15 +244,21 @@ class RequestState: else: prompt_logprobs = self.logprobs_processor.prompt_logprobs + # If prompt embeds were used, put placeholder prompt token ids + prompt_token_ids = self.prompt_token_ids + if prompt_token_ids is None and self.prompt_embeds is not None: + prompt_token_ids = [0] * len(self.prompt_embeds) + return RequestOutput( request_id=request_id, prompt=self.prompt, - prompt_token_ids=self.prompt_token_ids, + prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs, outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, num_cached_tokens=self.num_cached_tokens, + metrics=self.stats, ) def _new_completion_output( @@ -238,7 +267,6 @@ class RequestState: finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], ) -> CompletionOutput: - assert self.detokenizer is not None assert self.logprobs_processor is not None finished = finish_reason is not None @@ -252,7 +280,7 @@ class RequestState: # Prepare logprobs, based on delta mode logprobs = self.logprobs_processor.logprobs if delta and logprobs: - logprobs = logprobs[-len(token_ids):] + logprobs = logprobs[-len(token_ids) :] return CompletionOutput( index=self.request_index, @@ -261,29 +289,26 @@ class RequestState: logprobs=logprobs, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - stop_reason=stop_reason if finished else None) + stop_reason=stop_reason if finished else None, + ) def _new_pooling_output( self, pooling_output: torch.Tensor, ) -> PoolingOutput: - return PoolingOutput(data=pooling_output) class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" - def __init__( - self, - tokenizer: TokenizerGroup, - log_stats: bool, - 
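# Sketch of how prompt-embeds requests are handled above: their length comes from
# the embedding tensor (there are no prompt token ids), and RequestOutput reports
# placeholder ids of the same length. prompt_length() here only mirrors what the
# length_from_prompt_token_ids_or_embeds() helper provides; the real
# implementation in vllm.utils may differ in its validation.
from typing import Optional
import torch

def prompt_length(
    prompt_token_ids: Optional[list[int]],
    prompt_embeds: Optional[torch.Tensor],
) -> int:
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    assert prompt_embeds is not None, "need token ids or embeddings"
    return len(prompt_embeds)  # first dimension is the sequence length

prompt_embeds = torch.zeros(7, 4096)        # 7 "tokens" worth of embeddings
assert prompt_length(None, prompt_embeds) == 7
placeholder_ids = [0] * len(prompt_embeds)  # what the returned RequestOutput carries
assert placeholder_ids == [0, 0, 0, 0, 0, 0, 0]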
): + def __init__(self, tokenizer: AnyTokenizer, log_stats: bool): self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() + self.tracer: Optional[Tracer] = None def get_num_unfinished_requests(self): return len(self.request_states) @@ -310,8 +335,18 @@ class OutputProcessor: request_ids_to_abort.append(request_id) # Produce final abort output. if req_state.queue is not None and ( - request_output := req_state.make_request_output( - [], None, FinishReason.ABORT, None, None)): + request_output := req_state.make_request_output( + new_token_ids=[], + # Set pooling_output is not None to + # correctly enter the abort pooling branch + pooling_output=torch.randn(0, device="cpu") + if req_state.detokenizer is None + else None, + finish_reason=FinishReason.ABORT, + stop_reason=None, + kv_transfer_params=None, + ) + ): req_state.queue.put(request_output) elif parent := self.parent_requests.get(request_id): # Abort children prior to removing the parent. @@ -334,16 +369,15 @@ class OutputProcessor: if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - tokenizer = None if not self.tokenizer else \ - self.tokenizer.get_lora_tokenizer(request.lora_request) - - req_state = RequestState.from_new_request(tokenizer=tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request( + tokenizer=self.tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats, + ) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: @@ -360,25 +394,24 @@ class OutputProcessor: 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. NOTE FOR DEVELOPERS vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from within the loop below. """ - request_outputs: Union[list[RequestOutput], - list[PoolingRequestOutput]] = [] + request_outputs: Union[list[RequestOutput], list[PoolingRequestOutput]] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id @@ -388,9 +421,9 @@ class OutputProcessor: continue # 1) Compute stats for this iteration. - self._update_stats_from_output(req_state, engine_core_output, - engine_core_timestamp, - iteration_stats) + self._update_stats_from_output( + req_state, engine_core_output, engine_core_timestamp, iteration_stats + ) new_token_ids = engine_core_output.new_token_ids pooling_output = engine_core_output.pooling_output @@ -405,20 +438,24 @@ class OutputProcessor: assert req_state.logprobs_processor is not None # 2) Detokenize the token ids into text and perform stop checks. 
stop_string = req_state.detokenizer.update( - new_token_ids, finish_reason == FinishReason.STOP) + new_token_ids, finish_reason == FinishReason.STOP + ) if stop_string: finish_reason = FinishReason.STOP stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. - req_state.logprobs_processor.update_from_output( - engine_core_output) + req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, pooling_output, finish_reason, stop_reason, - kv_transfer_params): + new_token_ids, + pooling_output, + finish_reason, + stop_reason, + kv_transfer_params, + ): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) @@ -439,9 +476,11 @@ class OutputProcessor: reqs_to_abort.append(req_id) # Track per-request stats - self._update_stats_from_finished(req_state, finish_reason, - iteration_stats) - + self._update_stats_from_finished( + req_state, finish_reason, iteration_stats + ) + if self.tracer: + self.do_tracing(engine_core_output, req_state, iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( @@ -449,10 +488,76 @@ class OutputProcessor: reqs_to_abort=reqs_to_abort, ) - def _update_stats_from_output(self, req_state: RequestState, - engine_core_output: EngineCoreOutput, - engine_core_timestamp: Optional[float], - iteration_stats: Optional[IterationStats]): + def do_tracing( + self, + engine_core_output: EngineCoreOutput, + req_state: RequestState, + iteration_stats: Optional[IterationStats], + ) -> None: + assert req_state.stats is not None + assert iteration_stats is not None + assert self.tracer is not None + + arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) + trace_context = extract_trace_context(engine_core_output.trace_headers) + prompt_length = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds + ) + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds, + ) as span: + metrics = req_state.stats + e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time + queued_time = metrics.scheduled_ts - metrics.queued_ts + prefill_time = metrics.first_token_ts - metrics.scheduled_ts + decode_time = metrics.last_token_ts - metrics.first_token_ts + inference_time = metrics.last_token_ts - metrics.scheduled_ts + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, + metrics.first_token_latency, + ) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length) + span.set_attribute( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + metrics.num_generation_tokens, + ) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time + ) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time + ) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time + ) + + # meta + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id) + if req_state.top_p: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p) + if req_state.max_tokens_param: + span.set_attribute( + 
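# Sketch of the latency breakdown computed in do_tracing() above, assuming the
# per-request timestamps collected in RequestStateStats (all in seconds); the
# dataclass and field names here are illustrative stand-ins.
from dataclasses import dataclass

@dataclass
class RequestTimestamps:
    arrival: float        # request received by the frontend
    queued: float         # request handed to the scheduler queue
    scheduled: float      # first scheduled onto the engine
    first_token: float    # first output token produced
    last_token: float     # last output token produced
    iteration: float      # timestamp of the iteration that finished the request

def latency_breakdown(t: RequestTimestamps) -> dict[str, float]:
    return {
        "e2e": t.iteration - t.arrival,
        "queued": t.scheduled - t.queued,
        "prefill": t.first_token - t.scheduled,
        "decode": t.last_token - t.first_token,
        "inference": t.last_token - t.scheduled,
    }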
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param + ) + if req_state.temperature: + span.set_attribute( + SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature + ) + if req_state.n: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n) + + def _update_stats_from_output( + self, + req_state: RequestState, + engine_core_output: EngineCoreOutput, + engine_core_timestamp: Optional[float], + iteration_stats: Optional[IterationStats], + ): if iteration_stats is None: return @@ -460,15 +565,21 @@ class OutputProcessor: assert engine_core_timestamp is not None assert req_state.stats is not None - iteration_stats.update_from_output(engine_core_output, - engine_core_timestamp, - req_state.is_prefilling, - req_state.prompt_len, - req_state.stats, lora_stats) + iteration_stats.update_from_output( + engine_core_output, + engine_core_timestamp, + req_state.is_prefilling, + req_state.prompt_len, + req_state.stats, + lora_stats, + ) - def _update_stats_from_finished(self, req_state: RequestState, - finish_reason: Optional[FinishReason], - iteration_stats: Optional[IterationStats]): + def _update_stats_from_finished( + self, + req_state: RequestState, + finish_reason: Optional[FinishReason], + iteration_stats: Optional[IterationStats], + ): if iteration_stats is None: return @@ -476,11 +587,14 @@ class OutputProcessor: assert req_state.stats is not None iteration_stats.update_from_finished_request( finish_reason=finish_reason, - num_prompt_tokens=len(req_state.prompt_token_ids), + num_prompt_tokens=length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds + ), max_tokens_param=req_state.max_tokens_param, - req_stats=req_state.stats) + req_stats=req_state.stats, + ) self.lora_states.finish_request(req_state) ParentRequest.observe_finished_request( - req_state.parent_req, iteration_stats, - req_state.stats.num_generation_tokens) + req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens + ) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1e9911152c6df..daf115c0325ff 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -31,15 +31,16 @@ class ParentRequest: # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - def __init__(self, request_id: str, - sampling_params: SamplingParams) -> None: + def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params self.child_requests = set() - self.output_aggregator = [None] * sampling_params.n if ( - sampling_params.output_kind - == RequestOutputKind.FINAL_ONLY) else [] + self.output_aggregator = ( + [None] * sampling_params.n + if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) + else [] + ) self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @@ -49,7 +50,7 @@ class ParentRequest: ) -> SamplingParams: """Efficiently obtain child `sampling_params` - If `sampling_params.seed` is not `None` then + If `sampling_params.seed` is not `None` then each child request requires a unique clone of parent `sampling_params` with a unique seed. @@ -76,10 +77,10 @@ class ParentRequest: def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. - + Args: index: index within `n` child requests. 
- + Returns: (request ID, sampling_params) tuple """ @@ -111,23 +112,25 @@ class ParentRequest: return self.request_id, outputs, finished def observe_num_generation_tokens(self, num_generation_tokens: int): - self.max_num_generation_tokens = max(num_generation_tokens, - self.max_num_generation_tokens) + self.max_num_generation_tokens = max( + num_generation_tokens, self.max_num_generation_tokens + ) return self.max_num_generation_tokens @staticmethod - def observe_finished_request(parent_req: Optional['ParentRequest'], - iteration_stats: IterationStats, - num_generation_tokens: int): - + def observe_finished_request( + parent_req: Optional["ParentRequest"], + iteration_stats: IterationStats, + num_generation_tokens: int, + ): n_param = parent_req.n if parent_req is not None else 1 if parent_req is not None: num_generation_tokens = parent_req.observe_num_generation_tokens( - num_generation_tokens) + num_generation_tokens + ) # Child requests finished, we can now record to iteration stats if parent_req is None or not parent_req.child_requests: - iteration_stats.max_num_generation_tokens_iter.append( - num_generation_tokens) + iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens) iteration_stats.n_params_iter.append(n_param) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 69f8e531e01b1..d106783d6dc12 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -9,52 +9,63 @@ from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +from vllm.multimodal.cache import processor_cache_from_config +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient -from vllm.v1.structured_output.backend_guidance import ( - validate_guidance_grammar) +from vllm.v1.metrics.stats import MultiModalCacheStats +from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer, +) from vllm.v1.structured_output.backend_outlines import ( - validate_structured_output_request_outlines) -from vllm.v1.structured_output.backend_xgrammar import ( - validate_xgrammar_grammar) + validate_structured_output_request_outlines, +) +from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar + +logger = init_logger(__name__) class Processor: - def __init__( self, vllm_config: VllmConfig, - tokenizer: TokenizerGroup, + tokenizer: Optional[AnyTokenizer], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - + ) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config 
self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config - self.decoding_config = vllm_config.decoding_config - self.tokenizer = tokenizer + self.structured_outputs_config = vllm_config.structured_outputs_config - self.generation_config_fields = ( - self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.generation_config_fields = self.model_config.try_get_generation_config() - self.mm_input_cache_client = MultiModalInputCacheClient( - self.model_config, mm_registry) + self.mm_registry = mm_registry + self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) + + self.input_preprocessor = InputPreprocessor( + self.model_config, + tokenizer, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) @property - def mm_registry(self): - return self.input_preprocessor.mm_registry + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.input_preprocessor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + self.input_preprocessor.tokenizer = tokenizer def _validate_logprobs( self, @@ -62,24 +73,33 @@ class Processor: ) -> None: max_logprobs = self.model_config.max_logprobs if max_logprobs == -1: - return + max_logprobs = self.model_config.get_vocab_size() + # Validate sample logprobs. - if params.logprobs and (params.logprobs == -1 - or params.logprobs > max_logprobs): - raise ValueError( - f"Requested sample logprobs of {params.logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.logprobs: + num_logprobs = params.logprobs + if num_logprobs == -1: + num_logprobs = self.model_config.get_vocab_size() + if num_logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {num_logprobs}, " + f"which is greater than max allowed: {max_logprobs}" + ) # Validate prompt logprobs. - if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: - raise ValueError( - f"Requested prompt logprobs of {params.prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.prompt_logprobs: + num_prompt_logprobs = params.prompt_logprobs + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.model_config.get_vocab_size() + if num_prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {num_prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}" + ) def _validate_sampling_params( self, params: SamplingParams, - lora_request: Optional[LoRARequest], ) -> None: self._validate_structured_output(params) self._validate_logit_bias(params) @@ -92,11 +112,9 @@ class Processor: # When skip_tokenizer_init=True, we can't validate token IDs # Skip validation and let the model handle invalid tokens return - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) - vocab_size = len(tokenizer) + vocab_size = len(self.tokenizer) if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): - raise ValueError( - "allowed_token_ids contains out-of-vocab token id!") + raise ValueError("allowed_token_ids contains out-of-vocab token id!") def _validate_logit_bias( self, @@ -116,7 +134,8 @@ class Processor: if invalid_token_ids: raise ValueError( f"token_id(s) {invalid_token_ids} in logit_bias contain " - f"out-of-vocab token ids. Vocabulary size: {vocab_size}") + f"out-of-vocab token ids. 
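# Sketch of the logprob-count rule implemented in _validate_logprobs() above:
# a value of -1, either as the engine-wide limit or as the per-request count,
# expands to the model's vocabulary size before the comparison; the function
# name here is illustrative.
def resolve_num_logprobs(requested: int, max_logprobs: int, vocab_size: int) -> int:
    limit = vocab_size if max_logprobs == -1 else max_logprobs
    num = vocab_size if requested == -1 else requested
    if num > limit:
        raise ValueError(
            f"Requested sample logprobs of {num}, "
            f"which is greater than max allowed: {limit}"
        )
    return num

assert resolve_num_logprobs(-1, -1, 32000) == 32000  # "all logprobs" is allowed
assert resolve_num_logprobs(20, 20, 32000) == 20     # at the limit is allowed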
Vocabulary size: {vocab_size}" + ) def _validate_supported_sampling_params( self, @@ -127,13 +146,13 @@ class Processor: raise ValueError("vLLM V1 does not yet support best_of.") # Logits processors not supported. if params.logits_processors: - raise ValueError("vLLM V1 does not support per request " - "user provided logits processors.") + raise ValueError( + "vLLM V1 does not support per request user provided logits processors." + ) def _validate_params( self, params: Union[SamplingParams, PoolingParams], - lora_request: Optional[LoRARequest], ): """ Validate supported SamplingParam. @@ -144,80 +163,180 @@ class Processor: return self._validate_logprobs(params) - self._validate_sampling_params(params, lora_request) + self._validate_sampling_params(params) self._validate_supported_sampling_params(params) - def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: + """ + Validate that user-provided multi_modal_uuids align with + multi_modal_data in the incoming request prompt(s). + Only checks lengths; `None` entries are allowed and will be + auto-hashed downstream. + """ - def _validate_structured_output(self, params: SamplingParams) -> None: - if not params.guided_decoding or not self.decoding_config: + def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: + if not isinstance(single_prompt, dict): + return + mm_data = single_prompt.get("multi_modal_data") + mm_uuids = single_prompt.get("multi_modal_uuids") + if not mm_data or not mm_uuids: + return + + for modality, items in mm_data.items(): + if modality in mm_uuids: + data_len = len(items) if isinstance(items, list) else 1 + uuid_len = ( + len(mm_uuids[modality]) + if isinstance(mm_uuids[modality], list) + else 1 + ) + if uuid_len != data_len: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' " + "must have same length as data: got " + f"{uuid_len} uuids vs " + f"{data_len} items." + ) + else: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' must " + "be provided if multi_modal_data is provided." + ) + + # Handle explicit encoder/decoder prompts or singleton prompt + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + enc = prompt.get("encoder_prompt") + dec = prompt.get("decoder_prompt") + if enc is not None: + _validate_single_prompt(enc) + if dec is not None: + _validate_single_prompt(dec) + else: + _validate_single_prompt(prompt) # type: ignore[arg-type] + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is None: return - if self.model_config.skip_tokenizer_init and params.guided_decoding: + # LoRA request passed in while LoRA is not enabled + if not self.lora_config: + raise ValueError( + f"Got lora_request {lora_request} but LoRA is not enabled!" + ) + + if self.tokenizer is not None: + logger.warning_once( + "vLLM has deprecated support for supporting different " + "tokenizers for different LoRAs. By default, vLLM uses base " + "model's tokenizer. If you are using a LoRA " + "with its own tokenizer, consider specifying `--tokenizer " + "[lora_path]` to use the LoRA tokenizer." 
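# Example of the alignment rule enforced by _validate_multi_modal_uuids() above:
# each modality's uuid list must match the number of data items, while None
# entries are allowed and get hashed downstream. The image objects are stand-ins
# for real PIL images.
img0 = img1 = object()  # stand-ins for PIL.Image.Image instances

prompt_ok = {
    "prompt": "Describe these two images.",
    "multi_modal_data": {"image": [img0, img1]},      # two items
    "multi_modal_uuids": {"image": ["img-a", None]},  # two entries -> accepted
}

prompt_bad = {
    "prompt": "Describe these two images.",
    "multi_modal_data": {"image": [img0, img1]},      # two items
    "multi_modal_uuids": {"image": ["img-a"]},        # one entry -> ValueError
}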
+ ) + + def _validate_structured_output(self, params: SamplingParams) -> None: + if not params.structured_outputs or not self.structured_outputs_config: + return + + if self.model_config.skip_tokenizer_init and params.structured_outputs: raise ValueError( "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 ) - engine_level_backend = self.decoding_config.backend - if params.guided_decoding.backend: - # Request-level backend selection is not supported in V1. + backend = self.structured_outputs_config.backend + if _backend := params.structured_outputs._backend: + # Request-level backend selection is not supported. # The values may differ if `params` is reused and was set # to a specific backend based on `auto` behavior in a previous # request. We remember that it was set as a result of `auto` - # using the `_auto` option set on the backend in the params. - if (params.guided_decoding.backend != engine_level_backend - and not (engine_level_backend == "auto" - and params.guided_decoding.backend_was_auto)): + # using the `_backend_was_auto` field set in the params. + if backend != _backend and not ( + backend == "auto" and params.structured_outputs._backend_was_auto + ): raise ValueError( - "Request-level structured output backend selection is no " - "longer supported. The request specified " - f"'{params.guided_decoding.backend}', but vLLM was " - f"initialised with '{engine_level_backend}'. This error " - "can be resolved by removing backend selection from the " - "request.") + "Request-level structured output backend selection is not " + f"supported. The request specified '{_backend}', but vLLM " + f"was initialised with '{backend}'. This error can be " + "resolved by removing '_backend' from the request." + ) else: - params.guided_decoding.backend = engine_level_backend + params.structured_outputs._backend = backend # Request content validation - if (isinstance(params.guided_decoding.choice, list) - and not params.guided_decoding.choice): + if ( + isinstance(params.structured_outputs.choice, list) + and not params.structured_outputs.choice + ): # It is invalid for choice to be an empty list - raise ValueError(f"Choice '{params.guided_decoding.choice}' " - "cannot be an empty list") + raise ValueError( + f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 + ) - if engine_level_backend.startswith("xgrammar"): + if backend.startswith("xgrammar"): # xgrammar with no fallback validate_xgrammar_grammar(params) - elif engine_level_backend.startswith("guidance"): + elif backend.startswith("guidance"): # TODO: ideally we would have the LLTokenizer here as Lark syntax # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. validate_guidance_grammar(params, tokenizer=None) - elif engine_level_backend == "outlines": + elif backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) + elif backend == "lm-format-enforcer": + # lm format enforcer backend + validate_structured_output_request_lm_format_enforcer(params) else: - # NOTE: engine_level_backend must be "auto" here, because we have + # NOTE: backend must be "auto" here, because we have # checked supported_backends above. - # "auto" is an opt-in to opinionated behavior where we try to - # choose a backend based on request contents. 
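# Sketch of the backend consistency rule enforced above: a request may only
# carry the backend the engine was initialised with, unless that backend was
# picked automatically for it earlier under "auto" mode; names are illustrative.
from typing import Optional

def check_backend(
    engine_backend: str,
    request_backend: Optional[str],
    backend_was_auto: bool,
) -> None:
    if request_backend is None:
        return
    if request_backend != engine_backend and not (
        engine_backend == "auto" and backend_was_auto
    ):
        raise ValueError(
            f"Request specified '{request_backend}', but vLLM was "
            f"initialised with '{engine_backend}'."
        )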
This is not the - # default as it is less predictable and subject to change - # between releases as feature support changes. + # In this mode, we set opinionated defaults based on what we think + # will satisfy the most use cases without having to worry about + # this setting. We include fallback behavior here, but not with any + # other setting where a specific backend was specified. try: validate_xgrammar_grammar(params) - params.guided_decoding.backend = "xgrammar" + params.structured_outputs._backend = "xgrammar" except ValueError: # The request either failed validation # or includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. validate_guidance_grammar(params, tokenizer=None) - params.guided_decoding.backend = "guidance" + params.structured_outputs._backend = "guidance" # Remember that this backend was set automatically - params.guided_decoding.backend_was_auto = True + params.structured_outputs._backend_was_auto = True + + def _maybe_build_mm_uuids( + self, + request_id: str, + prompt: PromptType, + ) -> Optional[MultiModalUUIDDict]: + """Build per-item multimodal hash overrides when enabled. In this case, + multimodal data items are identified by their request id, modality and + index rather than their content. + + Returns a dictionary of modality -> list[str] of overrides, or None if + disabled or no multimodal data is present. + """ + + def _extract_mm_data(p: PromptType): + if isinstance(p, dict) and "encoder_prompt" in p: + enc = p.get("encoder_prompt") + if isinstance(enc, dict): + return enc.get("multi_modal_data") + return None + if isinstance(p, dict): + return p.get("multi_modal_data") + return None + + mm_data = _extract_mm_data(prompt) + if not mm_data: + return None + + mm_uuids: MultiModalUUIDDict = {} + for modality, data in mm_data.items(): + n = len(data) if isinstance(data, list) else 1 + mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] + return mm_uuids def process_inputs( self, @@ -230,24 +349,45 @@ class Processor: trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, - ) -> tuple[Optional[str], EngineCoreRequest]: - - # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models. + ) -> EngineCoreRequest: self._validate_lora(lora_request) - self._validate_params(params, lora_request) - if trace_headers is not None: - raise ValueError("V1 does not support tracing yet.") + self._validate_params(params) data_parallel_size = self.vllm_config.parallel_config.data_parallel_size - if data_parallel_rank is not None and not (0 <= data_parallel_rank < - data_parallel_size): - raise ValueError(f"data_parallel_rank {data_parallel_rank} " - f"is out of range [0, {data_parallel_size}).") + if data_parallel_rank is not None and not ( + 0 <= data_parallel_rank < data_parallel_size + ): + raise ValueError( + f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size})." + ) if arrival_time is None: arrival_time = time.time() + # Optionally generate multimodal hash overrides to avoid hashing + # multimodal data items by their content as their identifiers. + + # NOTE: when users explicitly turn off BOTH prefix caching and input + # processing caching, no multimodal features or embeddings will be + # reused across requests, therefore identifying multimodal data items + # by their content is no longer necessary, and we create uuids with + # request id-modality-index as multimodal hash overrides. 
+ if ( + self.model_config.multimodal_config + and self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching + ): + mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) + else: + # Otherwise, use user-provided uuids as multimodal hash overrides + # if provided. + self._validate_multi_modal_uuids(prompt) + if isinstance(prompt, dict): + mm_uuids = prompt.get("multi_modal_uuids") + else: + mm_uuids = None + # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess @@ -255,23 +395,35 @@ class Processor: processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, + mm_uuids=mm_uuids, ) from vllm.platforms import current_platform + current_platform.validate_request( prompt=prompt, params=params, processed_inputs=processed_inputs, ) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - self._validate_model_inputs(processed_inputs, lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id() encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + self._validate_model_inputs(encoder_inputs, decoder_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError + # Mypy does not always properly infer the types of some elements of + # discriminated unions of TypedDicts, because of how it handles + # inheritance of TypedDict. If we explicitly extract the items we want + # we can avoid type errors from using `dict.get` later in the method. + prompt_token_ids = ( + decoder_inputs["prompt_token_ids"] + if decoder_inputs["type"] != "embeds" + else None + ) + prompt_embeds = ( + decoder_inputs["prompt_embeds"] + if decoder_inputs["type"] == "embeds" + else None + ) sampling_params = None pooling_params = None @@ -280,21 +432,21 @@ class Processor: sampling_params = params.clone() # If unset max tokens, then generate up to the max_model_len. if sampling_params.max_tokens is None: - sampling_params.max_tokens = ( - self.model_config.max_model_len - - len(decoder_inputs["prompt_token_ids"])) + seq_len = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds + ) + sampling_params.max_tokens = self.model_config.max_model_len - seq_len sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) + self.generation_config_fields, eos_token_id + ) if self.tokenizer is not None: - sampling_params.update_from_tokenizer( - self.tokenizer.get_lora_tokenizer(lora_request)) + sampling_params.update_from_tokenizer(self.tokenizer) else: pooling_params = params.clone() # Multimodal related. - sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None - sorted_mm_positions: Optional[list[PlaceholderRange]] = None - sorted_mm_hashes: Optional[list[str]] = None + mm_features: Optional[list[MultiModalFeatureSpec]] = None + if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_positions = decoder_inputs["mm_placeholders"] @@ -305,30 +457,22 @@ class Processor: # in the input sequence. 
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - orig_sorted_mm_inputs = [ - decoder_mm_inputs[modality][idx] - for modality, idx in sorted_mm_idxs - ] - sorted_mm_positions = [ - decoder_mm_positions[modality][idx] - for modality, idx in sorted_mm_idxs - ] - sorted_mm_hashes = [ - decoder_mm_hashes[modality][idx] - for modality, idx in sorted_mm_idxs - ] + mm_features = [] + for modality, idx in sorted_mm_idxs: + mm_features.append( + MultiModalFeatureSpec( + data=decoder_mm_inputs[modality][idx], + modality=modality, + identifier=decoder_mm_hashes[modality][idx], + mm_position=decoder_mm_positions[modality][idx], + ) + ) - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - orig_sorted_mm_inputs, - sorted_mm_hashes, - ) - - return decoder_inputs.get("prompt"), EngineCoreRequest( + return EngineCoreRequest( request_id=request_id, - prompt_token_ids=decoder_inputs["prompt_token_ids"], - mm_kwargs=sorted_mm_inputs, - mm_hashes=sorted_mm_hashes, - mm_placeholders=sorted_mm_positions, + prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, + mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, eos_token_id=eos_token_id, @@ -337,49 +481,65 @@ class Processor: cache_salt=decoder_inputs.get("cache_salt"), priority=priority, data_parallel_rank=data_parallel_rank, + trace_headers=trace_headers, ) - def _validate_model_inputs(self, - inputs: ProcessorInputs, - lora_request: Optional[LoRARequest] = None): - encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) - + def _validate_model_inputs( + self, encoder_inputs: Optional[SingletonInputs], decoder_inputs: SingletonInputs + ): if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, - lora_request, - prompt_type="encoder") + self._validate_model_input(encoder_inputs, prompt_type="encoder") - self._validate_model_input(decoder_inputs, - lora_request, - prompt_type="decoder") + self._validate_model_input(decoder_inputs, prompt_type="decoder") def _validate_model_input( self, prompt_inputs: SingletonInputs, - lora_request: Optional[LoRARequest], *, prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = ( + None + if prompt_inputs["type"] == "embeds" + else prompt_inputs["prompt_token_ids"] + ) + prompt_embeds = ( + prompt_inputs["prompt_embeds"] + if prompt_inputs["type"] == "embeds" + else None + ) + prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass # Prompt embeds should not have prompt_ids. else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") + tokenizer = self.tokenizer + if tokenizer is not None: + max_input_id = max(prompt_ids or [], default=0) + + # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while + # self.model_config.get_vocab_size() is the model’s vocab size. 
+ # For Qwen3 models, the language model has extra tokens that do + # not exist in the tokenizer, and vice versa for multimodal + # placeholder tokens in some multimodal models. + # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501 + # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501 + + # Here we take the max of the two to determine if a token id is + # truly out-of-vocabulary. + if max_input_id > max( + tokenizer.max_token_id, self.model_config.get_vocab_size() - 1 + ): + raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) > max_prompt_len: + if prompt_len > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( @@ -396,17 +556,26 @@ class Processor: "Make sure that `max_model_len` is no smaller than the " "number of text tokens plus multimodal tokens. For image " "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") + "of images, and possibly their aspect ratios as well." + ) else: suggestion = ( "Make sure that `max_model_len` is no smaller than the " - "number of text tokens.") + "number of text tokens." + ) raise ValueError( - f"The {prompt_type} prompt (length {len(prompt_ids)}) is " + f"The {prompt_type} prompt (length {prompt_len}) is " f"longer than the maximum model length of {max_prompt_len}. " - f"{suggestion}") + f"{suggestion}" + ) # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def stat_mm_cache(self) -> Optional[MultiModalCacheStats]: + return self.input_preprocessor.stat_mm_cache() + + def clear_mm_cache(self) -> None: + self.input_preprocessor.clear_mm_cache() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 62f229e286931..ac2a6b997e9fe 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -15,6 +15,7 @@ from unittest.mock import patch import msgspec import zmq +from vllm import envs from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -70,8 +71,10 @@ class EngineHandshakeMetadata: including addresses of the front-end ZMQ queues that they should connect to. """ + addresses: EngineZmqAddresses parallel_config: dict[str, Union[int, str, list[int]]] + parallel_config_hash: Optional[str] = None class CoreEngineProcManager: @@ -103,8 +106,7 @@ class CoreEngineProcManager: } if client_handshake_address: - common_kwargs[ - "client_handshake_address"] = client_handshake_address + common_kwargs["client_handshake_address"] = client_handshake_address self.processes: list[BaseProcess] = [] local_dp_ranks = [] @@ -115,21 +117,27 @@ class CoreEngineProcManager: # Start EngineCore in background process. 
local_dp_ranks.append(local_index) self.processes.append( - context.Process(target=target_fn, - name=f"EngineCore_{global_index}", - kwargs=common_kwargs | { - "dp_rank": global_index, - "local_dp_rank": local_index, - })) + context.Process( + target=target_fn, + name=f"EngineCore_DP{global_index}", + kwargs=common_kwargs + | { + "dp_rank": global_index, + "local_dp_rank": local_index, + }, + ) + ) self._finalizer = weakref.finalize(self, shutdown, self.processes) data_parallel = vllm_config.parallel_config.data_parallel_size > 1 try: for proc, local_dp_rank in zip(self.processes, local_dp_ranks): - with set_device_control_env_var( - vllm_config, local_dp_rank) if ( - data_parallel) else contextlib.nullcontext(): + with ( + set_device_control_env_var(vllm_config, local_dp_rank) + if (data_parallel) + else contextlib.nullcontext() + ): proc.start() finally: # Kill other procs if not all are running. @@ -151,32 +159,51 @@ class CoreEngineProcManager: """Returns dict of proc name -> exit code for any finished procs.""" return { proc.name: proc.exitcode - for proc in self.processes if proc.exitcode is not None + for proc in self.processes + if proc.exitcode is not None } @contextlib.contextmanager -def set_device_control_env_var(vllm_config: VllmConfig, - local_dp_rank: int) -> Iterator[None]: +def set_device_control_env_var( + vllm_config: VllmConfig, local_dp_rank: int +) -> Iterator[None]: """ Temporarily set CUDA_VISIBLE_DEVICES or equivalent for engine subprocess. """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value),)): + yield + + +def get_device_indices( + device_control_env_var: str, local_dp_rank: int, world_size: int +): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. 
+ """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * - world_size)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size) + ) except IndexError as e: - raise Exception(f"Error setting {evar}: " - f"local range: [{local_dp_rank * world_size}, " - f"{(local_dp_rank + 1) * world_size}) " - "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + raise Exception( + f"Error setting {device_control_env_var}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + "base value: " + f'"{os.getenv(device_control_env_var)}"' + ) from e + return value class CoreEngineActorManager: @@ -201,8 +228,7 @@ class CoreEngineActorManager: import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor @@ -211,8 +237,7 @@ class CoreEngineActorManager: env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") self.env_vars_dict = { - name: os.environ[name] - for name in env_vars_list if name in os.environ + name: os.environ[name] for name in env_vars_list if name in os.environ } runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) @@ -220,52 +245,74 @@ class CoreEngineActorManager: self.executor_class = executor_class self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local + local_engine_count = vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size if ray.is_initialized(): - logger.info( - "Ray is already initialized. Skipping Ray initialization.") + logger.info("Ray is already initialized. 
Skipping Ray initialization.") else: ray.init() if placement_groups is not None: assert local_dp_ranks is not None, ( - "local_dp_ranks must be provided if " - "placement_groups is provided") + "local_dp_ranks must be provided if placement_groups is provided" + ) assert len(placement_groups) == len(local_dp_ranks), ( - "placement_groups and local_dp_ranks must " - "have the same length") + "placement_groups and local_dp_ranks must have the same length" + ) logger.info("Using provided placement groups") # TODO(rui): validate passed-in placement groups self.created_placement_groups = [] else: - placement_groups, local_dp_ranks = \ + placement_groups, local_dp_ranks = ( CoreEngineActorManager.create_dp_placement_groups(vllm_config) + ) self.created_placement_groups = placement_groups assert len(placement_groups) == dp_size, ( - "Number of placement groups must match data parallel size") + "Number of placement groups must match data parallel size" + ) self.placement_group_is_local = [] refs = [] - for index, local_index, pg in zip(range(dp_size), local_dp_ranks, - placement_groups): + for index, local_index, pg in zip( + range(dp_size), local_dp_ranks, placement_groups + ): dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, - executor_class=executor_class, - log_stats=log_stats, - local_client=local_client, - addresses=addresses, - dp_rank=index, - local_dp_rank=local_index) + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices( + device_evar, local_index, world_size + ) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( + vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + local_client=local_client, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index, + ) + ) if local_client: self.local_engine_actors.append(actor) else: @@ -280,7 +327,7 @@ class CoreEngineActorManager: @staticmethod def create_dp_placement_groups( - vllm_config: VllmConfig + vllm_config: VllmConfig, ) -> tuple[list["PlacementGroup"], list[int]]: """ Create placement groups for data parallel. 
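# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the bundle shape
# requested per DP rank in the hunk below. Each data-parallel rank asks Ray for
# `world_size` accelerator bundles plus one CPU bundle, packed onto one node
# (STRICT_PACK); bundles for the DP master node additionally pin its node IP.
# Names below are assumptions for this sketch only.
from typing import Optional


def _sketch_dp_rank_bundles(
    world_size: int, device_str: str = "GPU", master_ip: Optional[str] = None
) -> list:
    device_bundle = {device_str: 1.0}
    if master_ip is not None:
        device_bundle["node:" + master_ip] = 0.001
    return [dict(device_bundle) for _ in range(world_size)] + [{"CPU": 1.0}]


# Example: world_size (TP x PP) = 2 on master node 10.0.0.1 ->
# [{'GPU': 1.0, 'node:10.0.0.1': 0.001},
#  {'GPU': 1.0, 'node:10.0.0.1': 0.001},
#  {'CPU': 1.0}]
# ---------------------------------------------------------------------------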
@@ -288,73 +335,109 @@ class CoreEngineActorManager: import ray from ray._private.state import available_resources_per_node - from ray.util.state import list_nodes logger.info("Creating placement groups for data parallel") - dp_master_ip = \ - vllm_config.parallel_config.data_parallel_master_ip - num_pg_to_create = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local - - nodes = sorted(list_nodes(filters=[("state", "=", "ALIVE")]), - key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The head node is missing or dead") - assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") + dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip + dp_size = vllm_config.parallel_config.data_parallel_size + dp_size_local = vllm_config.parallel_config.data_parallel_size_local available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] + dp_master_ip_key = f"node:{dp_master_ip}" + nodes = sorted( + available_resources.values(), key=lambda x: dp_master_ip_key not in x + ) + assert len(nodes) > 0, "No nodes with resources found in Ray cluster." + assert dp_master_ip_key in nodes[0], ( + "The DP master node (ip: %s) is missing or dead", + dp_master_ip, + ) + device_str = current_platform.ray_device_key + + if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and ( + envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" + ): + raise ValueError( + "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) " + "to be on the same node, but VLLM_RAY_DP_PACK_STRATEGY=fill " + "does not guarantee that. " + "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead." 
+ ) + logger.info( + "Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY", + envs.VLLM_RAY_DP_PACK_STRATEGY, + ) + strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict" + + for node_resources in nodes: + node_ip_keys = [ + key + for key in node_resources + if key != "node:__internal_head__" and key.startswith("node:") + ] + assert len(node_ip_keys) == 1, ( + "Zero or multiple node IP keys found in node resources: %s", + node_ip_keys, + ) + node_ip_key = node_ip_keys[0] + node_ip = node_ip_key.split(":")[1] - for node in nodes: - node_ip = node.node_ip - node_resources = available_resources[node.node_id] - if "GPU" not in node_resources: - continue # For now, each DP rank can only be assigned to one node # TODO(rui): support allocating a single DP rank # to multiple nodes - available_engine_count = int(node_resources["GPU"]) // world_size + dp_size_available = ( + int(node_resources[device_str]) // world_size + if device_str in node_resources + else 0 + ) + if node_ip == dp_master_ip: - assert available_engine_count >= local_engine_count, ( - "Not enough resources to allocate DP ranks " - f"on DP master node {node_ip}") - for i in range(local_engine_count): - bundles = [{ - "GPU": 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, + if dp_size_available < dp_size_local: + raise ValueError( + "Not enough resources to allocate %s DP ranks " + "on DP master node %s, possible to fit %s DP ranks", + dp_size_local, + dp_master_ip, + dp_size_available, ) - placement_groups.append(pg) - local_dp_ranks.append(i) + dp_size_to_allocate = dp_size_local + elif strict_local_size: + if dp_size_available < dp_size_local: + logger.info( + "Skipping node %s as %s DP ranks could not fit, " + "possible to fit %s DP ranks", + node_ip, + dp_size_local, + dp_size_available, + ) + continue + dp_size_to_allocate = dp_size_local else: - for i in range(available_engine_count): - if len(placement_groups) == num_pg_to_create: - break - bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, - ) - placement_groups.append(pg) - local_dp_ranks.append(i) - if len(placement_groups) < num_pg_to_create: + dp_size_to_allocate = dp_size_available + + for i in range(dp_size_to_allocate): + bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [ + {"CPU": 1.0} + ] + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + + if len(placement_groups) < dp_size: raise ValueError( - f"Not enough resources to allocate {num_pg_to_create} " + f"Not enough resources to allocate {dp_size} " "placement groups, only created " f"{len(placement_groups)} placement groups. " "Available resources: " - f"{available_resources}") + f"{available_resources}" + ) return placement_groups, local_dp_ranks @staticmethod @@ -365,8 +448,10 @@ class CoreEngineActorManager: Add placement groups for new data parallel size. 
""" import ray - from ray._private.state import (available_resources_per_node, - total_resources_per_node) + from ray._private.state import ( + available_resources_per_node, + total_resources_per_node, + ) from ray.util.state import list_nodes old_dp_size = old_vllm_config.parallel_config.data_parallel_size @@ -380,10 +465,10 @@ class CoreEngineActorManager: nodes = list_nodes() nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The first node must be the head node") + assert nodes[0].node_ip == dp_master_ip, "The first node must be the head node" assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") + "There can only be one head node" + ) available_resources = available_resources_per_node() total_resources = total_resources_per_node() @@ -392,17 +477,18 @@ class CoreEngineActorManager: local_dp_ranks = [] num_pg_created = 0 + device_str = current_platform.ray_device_key for node in nodes: if num_pg_created >= num_pg_to_create: break node_ip = node.node_ip node_id = node.node_id - available_gpus = int(available_resources[node_id]["GPU"]) + available_gpus = int(available_resources[node_id][device_str]) # Get total GPUs on this node from the node's resources # Ray stores node resources with node ID as key - total_gpus = int(total_resources[node_id]["GPU"]) + total_gpus = int(total_resources[node_id][device_str]) # Calculate used GPUs and used engines on this node used_gpus = max(0, total_gpus - available_gpus) @@ -420,14 +506,11 @@ class CoreEngineActorManager: # Create bundles with node constraint for master node if node_ip == dp_master_ip: - bundles = [{ - "GPU": 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] + bundles = [ + {device_str: 1.0, "node:" + dp_master_ip: 0.001} + ] * world_size + [{"CPU": 1.0}] else: - bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( name=f"dp_rank_{rank}", @@ -444,69 +527,76 @@ class CoreEngineActorManager: return placement_groups, local_dp_ranks - def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig, - new_data_parallel_size: int) -> None: + def scale_up_elastic_ep( + self, cur_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> None: import copy import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor - cur_data_parallel_size = len(self.local_engine_actors) + \ - len(self.remote_engine_actors) + cur_data_parallel_size = len(self.local_engine_actors) + len( + self.remote_engine_actors + ) assert new_data_parallel_size > cur_data_parallel_size, ( f"New data parallel size {new_data_parallel_size} must be greater " f"than current data parallel size {cur_data_parallel_size} " - "for scale up") + "for scale up" + ) - placement_groups, local_dp_ranks = \ - self.add_dp_placement_groups( - cur_vllm_config, new_data_parallel_size) + placement_groups, local_dp_ranks = self.add_dp_placement_groups( + cur_vllm_config, new_data_parallel_size + ) world_size = cur_vllm_config.parallel_config.world_size dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip new_local_engines = 0 - runtime_env = RuntimeEnv(env_vars=self.env_vars_dict - | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}) - for i, (pg, - local_rank) in 
enumerate(zip(placement_groups, - local_dp_ranks)): + runtime_env = RuntimeEnv( + env_vars=self.env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"} + ) + for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): rank = cur_data_parallel_size + i dp_vllm_config = copy.deepcopy(cur_vllm_config) - dp_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + dp_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size dp_vllm_config.parallel_config.placement_group = pg # Check if this placement group is on the head node local_client = any( - bundle.get("node:" + dp_master_ip, 0) > 0 - for bundle in pg.bundle_specs) + bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs + ) if local_client: new_local_engines += 1 # Update data_parallel_size_local dp_vllm_config.parallel_config.data_parallel_size_local = ( - cur_vllm_config.parallel_config.data_parallel_size_local + - new_local_engines) + cur_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines + ) - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote( + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( vllm_config=dp_vllm_config, executor_class=self.executor_class, log_stats=self.log_stats, local_client=local_client, addresses=self.addresses, dp_rank=rank, - local_dp_rank=local_rank) + local_dp_rank=local_rank, + ) + ) if local_client: self.local_engine_actors.append(actor) @@ -515,37 +605,47 @@ class CoreEngineActorManager: self.created_placement_groups.append(pg) self.placement_group_is_local.append(local_client) - ray.get([ - actor.wait_for_init.remote() - for actor in (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] - ]) + ray.get( + [ + actor.wait_for_init.remote() + for actor in ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + + self.remote_engine_actors[ + -(len(placement_groups) - new_local_engines) : + ] + ] + ) - actors = (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + \ - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] + actors = ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :] for actor in actors: self.run_refs.append(actor.run.remote()) - cur_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Update old_vllm_config with new data_parallel_size_local if any new # local engines were added if new_local_engines > 0: - cur_vllm_config.parallel_config.data_parallel_size_local += \ + cur_vllm_config.parallel_config.data_parallel_size_local += ( new_local_engines + ) - def scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + def scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: import ray + assert cur_data_parallel_size > new_data_parallel_size, ( f"cur_data_parallel_size 
{cur_data_parallel_size} must be greater " f"than new_data_parallel_size {new_data_parallel_size} " - "for scale down") + "for scale down" + ) for _ in range(cur_data_parallel_size - new_data_parallel_size): pg = self.created_placement_groups.pop() is_local = self.placement_group_is_local.pop() @@ -560,6 +660,7 @@ class CoreEngineActorManager: def close(self): import ray + for actor in self.local_engine_actors + self.remote_engine_actors: ray.kill(actor) for pg in self.created_placement_groups: @@ -572,11 +673,13 @@ def launch_core_engines( executor_class: type[Executor], log_stats: bool, num_api_servers: int = 1, -) -> Iterator[tuple[ +) -> Iterator[ + tuple[ Optional[Union[CoreEngineProcManager, CoreEngineActorManager]], Optional[DPCoordinator], EngineZmqAddresses, -]]: + ] +]: """Launch engine and DP coordinator processes as needed.""" parallel_config = vllm_config.parallel_config @@ -585,8 +688,10 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -595,8 +700,9 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = (offline_mode or local_engines_only - or (local_engine_count == dp_size)) + client_local_only = ( + offline_mode or local_engines_only or (local_engine_count == dp_size) + ) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -618,12 +724,13 @@ def launch_core_engines( coordinator = DPCoordinator(parallel_config) addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) + coordinator.get_engine_socket_addresses() + ) addresses.frontend_stats_publish_address = ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) - logger.info("Started DP Coordinator process (PID: %d)", - coordinator.proc.pid) + logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid) else: coordinator = None @@ -649,14 +756,14 @@ def launch_core_engines( # Note this also covers the case where we have zero local engines # and rank 0 is headless. engines_to_handshake = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) + CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] else: # Rank > 0 handshakes with just the local cores it is managing. assert local_engines_only, ( "Attempting to launch core_engines from dp_rank > 0, but " - "found internal DPLB, which is incompatible.") + "found internal DPLB, which is incompatible." 
+ ) engines_to_handshake = [ CoreEngine(index=i, local=True) for i in range(dp_rank, dp_rank + local_engine_count) @@ -669,7 +776,8 @@ def launch_core_engines( handshake_local_only = offline_mode or local_engine_count == dp_size handshake_address = get_engine_client_zmq_addr( - handshake_local_only, host, parallel_config.data_parallel_rpc_port) + handshake_local_only, host, parallel_config.data_parallel_rpc_port + ) if local_engines_only and dp_rank > 0: assert not handshake_local_only @@ -679,9 +787,9 @@ def launch_core_engines( local_handshake_address = handshake_address client_handshake_address = None - with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, - bind=True) as handshake_socket: - + with zmq_socket_ctx( + local_handshake_address, zmq.ROUTER, bind=True + ) as handshake_socket: from vllm.v1.engine.core import EngineCoreProc # Start local engines. @@ -696,7 +804,8 @@ def launch_core_engines( local_client=True, local_engine_count=local_engine_count, start_index=dp_rank, - local_start_index=local_start_index or 0) + local_start_index=local_start_index or 0, + ) else: local_engine_manager = None @@ -731,8 +840,10 @@ def wait_for_engine_startup( poller = zmq.Poller() poller.register(handshake_socket, zmq.POLLIN) - remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \ + remote_should_be_headless = ( + not parallel_config.data_parallel_hybrid_lb and not parallel_config.data_parallel_external_lb + ) if proc_manager is not None: for sentinel in proc_manager.sentinels(): @@ -744,67 +855,80 @@ def wait_for_engine_startup( if not events: if any(conn_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) + "Waiting for %d local, %d remote core engine proc(s) to connect.", + *conn_pending, + ) if any(start_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) + "Waiting for %d local, %d remote core engine proc(s) to start.", + *start_pending, + ) continue if len(events) > 1 or events[0][0] != handshake_socket: # One of the local core processes exited. finished = proc_manager.finished_procs() if proc_manager else {} if coord_process is not None and coord_process.exitcode is not None: finished[coord_process.name] = coord_process.exitcode - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") + raise RuntimeError( + "Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}" + ) # Receive HELLO and READY messages from the input socket. 
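# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the engine
# startup handshake handled below, reduced to its state transitions. Status
# and state names come from the surrounding hunk; the ZMQ transport and
# msgpack encoding are deliberately elided.
def _sketch_handshake_step(engine_state: str, status: str) -> str:
    if status == "HELLO" and engine_state == "NEW":
        # Frontend replies with EngineHandshakeMetadata: ZMQ addresses, the
        # DP parallel config, and (when data_parallel_size > 1) a config hash.
        return "CONNECTED"
    if status == "READY" and engine_state == "CONNECTED":
        # READY carries dp_stats_address and the worker's parallel_config_hash,
        # which must match the hash computed by the frontend.
        return "READY"
    raise RuntimeError(f"Unexpected {status} message in state {engine_state}")


# _sketch_handshake_step("NEW", "HELLO")        -> "CONNECTED"
# _sketch_handshake_step("CONNECTED", "READY")  -> "READY"
# ---------------------------------------------------------------------------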
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, "little") - engine = next((e for e in core_engines if e.identity == eng_identity), - None) + engine = next((e for e in core_engines if e.identity == eng_identity), None) if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") + raise RuntimeError( + f"Message from engine with unexpected data parallel rank: {eng_index}" + ) msg = msgspec.msgpack.decode(ready_msg_bytes) status, local, headless = msg["status"], msg["local"], msg["headless"] if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") + raise RuntimeError( + f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}" + ) # Remote engines must be headless iff we aren't in hybrid dp lb mode. if not local and headless != remote_should_be_headless: if headless: - raise RuntimeError(f"Remote engine {eng_index} must not use " - f"--headless in external or hybrid dp lb " - f"mode") + raise RuntimeError( + f"Remote engine {eng_index} must not use " + f"--headless in external or hybrid dp lb " + f"mode" + ) else: - raise RuntimeError(f"Remote engine {eng_index} must use " - f"--headless unless in external or hybrid " - f"dp lb mode") + raise RuntimeError( + f"Remote engine {eng_index} must use " + f"--headless unless in external or hybrid " + f"dp lb mode" + ) if status == "HELLO" and engine.state == CoreEngineState.NEW: - - # Send init message with DP config info. + # Send init message with DP config info and config hash. + # The config hash ensures all DP workers have compatible configs. init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( addresses=addresses, parallel_config={ - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "_data_parallel_master_port_list": - parallel_config._data_parallel_master_port_list, - "data_parallel_size": - parallel_config.data_parallel_size, - })) - handshake_socket.send_multipart((eng_identity, init_message), - copy=False) + k: getattr(parallel_config, k) + for k in ( + "data_parallel_master_ip", + "data_parallel_master_port", + "_data_parallel_master_port_list", + "data_parallel_size", + ) + }, + parallel_config_hash=parallel_config.compute_hash() + if parallel_config.data_parallel_size > 1 + else None, + ) + ) + handshake_socket.send_multipart((eng_identity, init_message), copy=False) conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 engine.state = CoreEngineState.CONNECTED @@ -820,15 +944,37 @@ def wait_for_engine_startup( # one of the engine handshakes, and passed to the local # front-end process in the response from the other. if addresses.frontend_stats_publish_address is None: - addresses.frontend_stats_publish_address = msg.get( - "dp_stats_address") + addresses.frontend_stats_publish_address = msg.get("dp_stats_address") + + # Validate config hash consistency across DP workers + if parallel_config.data_parallel_size > 1: + worker_config_hash = msg.get("parallel_config_hash") + expected_hash = parallel_config.compute_hash() + if worker_config_hash != expected_hash: + raise RuntimeError( + f"Configuration mismatch detected for engine " + f"{eng_index}. 
All DP workers must have identical " + f"configurations for parameters that affect collective " + f"communication (e.g., enable_eplb, " + f"eplb_config.log_balancedness). " + f"Worker hash: {worker_config_hash}, " + f"Expected hash: {expected_hash}. " + f"Please ensure all workers are started with the same " + f"command-line arguments." + ) start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") + raise RuntimeError( + f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state." + ) - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + logger.debug( + "%s from %s core engine process %s.", + status, + "local" if local else "remote", + eng_index, + ) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 4be2f74177b1f..064e4b2bbf181 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -10,10 +10,11 @@ import torch.distributed as dist from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) -from vllm.executor.uniproc_executor import ( # noqa - UniProcExecutor as UniProcExecutorV0) + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0, +) +from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa from vllm.utils import resolve_obj_by_qualname +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput @@ -29,21 +30,24 @@ class Executor(ExecutorBase): def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class: type[Executor] parallel_config = vllm_config.parallel_config - distributed_executor_backend = ( - parallel_config.distributed_executor_backend) + distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): if not issubclass(distributed_executor_backend, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") + f"ExecutorBase. Got {distributed_executor_backend}." 
+ ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": from vllm.v1.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor) + RayDistributedExecutor, + ) + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": executor_class = UniProcExecutor @@ -52,25 +56,24 @@ class Executor(ExecutorBase): # to support external launcher executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): - executor_class = resolve_obj_by_qualname( - distributed_executor_backend) + executor_class = resolve_obj_by_qualname(distributed_executor_backend) if not issubclass(executor_class, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {executor_class}.") + f"ExecutorBase. Got {executor_class}." + ) else: - raise ValueError("Unknown distributed executor backend: " - f"{distributed_executor_backend}") + raise ValueError( + f"Unknown distributed executor backend: {distributed_executor_backend}" + ) return executor_class - def initialize_from_config(self, - kv_cache_configs: list[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - self.collective_rpc("initialize_from_config", - args=(kv_cache_configs, )) + self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") def register_failure_callback(self, callback: FailureCallback): @@ -81,21 +84,34 @@ class Executor(ExecutorBase): pass def determine_available_memory(self) -> list[int]: # in bytes - output = self.collective_rpc("determine_available_memory") - return output + return self.collective_rpc("determine_available_memory") def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: - output = self.collective_rpc("get_kv_cache_spec") - return output + return self.collective_rpc("get_kv_cache_spec") + + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + ) -> list[Any]: + raise NotImplementedError def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - output = self.collective_rpc("execute_model", - args=(scheduler_output, )) + output = self.collective_rpc( + "execute_model", args=(scheduler_output,), non_block=non_block + ) return output[0] + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch") + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: output = self.collective_rpc("take_draft_token_ids") return output[0] @@ -105,7 +121,7 @@ class Executor(ExecutorBase): return 1 def profile(self, is_start: bool = True): - self.collective_rpc("profile", args=(is_start, )) + self.collective_rpc("profile", args=(is_start,)) class UniProcExecutor(UniProcExecutorV0, Executor): @@ -113,12 +129,12 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> list[int]: # in bytes # same as determine_num_available_blocks in v0, # we need to get the min across all 
ranks. memory = super().determine_available_memory() from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 15b88a2128994..d92c8f38571e9 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -3,6 +3,7 @@ import multiprocessing import os import pickle +import queue import signal import threading import time @@ -11,36 +12,45 @@ import weakref from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto -from functools import partial +from functools import cached_property, partial from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Lock as LockType from threading import Thread from typing import Any, Callable, Optional, Union, cast import cloudpickle +import torch import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) -from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.executor.multiproc_worker_utils import ( - set_multiprocessing_worker_envs) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel +from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue +from vllm.distributed.parallel_state import ( + get_dp_group, + get_ep_group, + get_pp_group, + get_tp_group, +) from vllm.logger import init_logger -from vllm.utils import (decorate_logs, get_distributed_init_method, - get_loopback_ip, get_mp_context, get_open_port, - set_process_title) +from vllm.utils import ( + _maybe_force_spawn, + decorate_logs, + get_distributed_init_method, + get_loopback_ip, + get_mp_context, + get_open_port, + set_process_title, +) +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput +from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) class MultiprocExecutor(Executor): - supports_pp: bool = True def _init_executor(self) -> None: @@ -58,26 +68,30 @@ class MultiprocExecutor(Executor): assert self.world_size == tensor_parallel_size * pp_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). ") + f"_parallel_size ({pp_parallel_size}). " + ) - # Set multiprocessing envs that are common to V0 and V1 - set_multiprocessing_worker_envs(self.parallel_config) + # Set multiprocessing envs + set_multiprocessing_worker_envs() # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # get_loopback_ip() for communication. 
distributed_init_method = get_distributed_init_method( - get_loopback_ip(), get_open_port()) + get_loopback_ip(), get_open_port() + ) # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue(self.world_size, - self.world_size, - max_chunk_bytes=max_chunk_bytes) + self.rpc_broadcast_mq = MessageQueue( + self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes + ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers + context = get_mp_context() + shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: @@ -89,7 +103,9 @@ class MultiprocExecutor(Executor): rank=rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, - )) + shared_worker_lock=shared_worker_lock, + ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -110,8 +126,7 @@ class MultiprocExecutor(Executor): for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() - self._ensure_worker_termination( - [uw.proc for uw in unready_workers]) + self._ensure_worker_termination([uw.proc for uw in unready_workers]) # For pipeline parallel, we use a thread pool for asynchronous # execute_model. @@ -120,12 +135,11 @@ class MultiprocExecutor(Executor): # from the response queue # _async_aggregate_workers_output also assumes a single IO thread self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") + max_workers=1, thread_name_prefix="mp_exec_io" + ) self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator( - self.parallel_config.world_size) def start_worker_monitor(self): workers = self.workers @@ -138,23 +152,22 @@ class MultiprocExecutor(Executor): sentinels = [h.proc.sentinel for h in workers] died = multiprocessing.connection.wait(sentinels) _self = self_ref() - if not _self or getattr(_self, 'shutting_down', False): + if not _self or getattr(_self, "shutting_down", False): return _self.is_failed = True - proc_name = next(h.proc.name for h in workers - if h.proc.sentinel == died[0]) + proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) logger.error( - "Worker proc %s died unexpectedly, " - "shutting down executor.", proc_name) + "Worker proc %s died unexpectedly, shutting down executor.", proc_name + ) _self.shutdown() callback = _self.failure_callback if callback is not None: _self.failure_callback = None callback() - Thread(target=monitor_workers, - daemon=True, - name="MultiprocWorkerMonitor").start() + Thread( + target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" + ).start() def register_failure_callback(self, callback: FailureCallback): if self.is_failed: @@ -164,46 +177,52 @@ class MultiprocExecutor(Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - non_block = self.max_concurrent_batches > 1 - if not self.has_connector: # get output only from a single worker (output_rank) - (output, ) = self.collective_rpc( + (output,) = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), unique_reply_rank=self.output_rank, non_block=non_block, - 
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) return output # get output from all workers outputs = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) # aggregate all workers output to a single output if non_block: - return self.kv_output_aggregator.async_aggregate( - outputs, self.output_rank) + return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + def execute_dummy_batch(self) -> None: + self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: # OPTIMIZATION: Get output only from a single worker (output_rank) - outputs = self.collective_rpc("take_draft_token_ids", - unique_reply_rank=self.output_rank) + outputs = self.collective_rpc( + "take_draft_token_ids", unique_reply_rank=self.output_rank + ) return outputs[0] - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - unique_reply_rank: Optional[int] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: if self.is_failed: raise RuntimeError("Executor failed.") @@ -218,36 +237,53 @@ class MultiprocExecutor(Executor): send_method = method else: send_method = cloudpickle.dumps( - method, protocol=pickle.HIGHEST_PROTOCOL) + method, protocol=pickle.HIGHEST_PROTOCOL + ) self.rpc_broadcast_mq.enqueue( - (send_method, args, kwargs, unique_reply_rank)) + (send_method, args, kwargs, unique_reply_rank) + ) - workers = (self.workers[unique_reply_rank], - ) if unique_reply_rank is not None else self.workers + workers = ( + (self.workers[unique_reply_rank],) + if unique_reply_rank is not None + else self.workers + ) responses = [] - def get_response(w: WorkerProcHandle, - dequeue_timeout: Optional[float] = None, - cancel_event: Optional[threading.Event] = None): + def get_response( + w: WorkerProcHandle, + dequeue_timeout: Optional[float] = None, + cancel_event: Optional[threading.Event] = None, + ): status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=cancel_event) + timeout=dequeue_timeout, cancel=cancel_event + ) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( f"Worker failed with error '{result}', please check the" - " stack trace above for the root cause") + " stack trace above for the root cause" + ) return result for w in workers: - dequeue_timeout = None if deadline is None else ( - deadline - time.monotonic()) + dequeue_timeout = ( + None if deadline is None else (deadline - time.monotonic()) + ) - if non_block: + if self.io_thread_pool is not None: + # We must consume worker_response_mq from a single thread. 
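# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the single
# IO-thread pattern used above, in isolation. A one-worker ThreadPoolExecutor
# serializes every dequeue from the response queue, so results keep FIFO order
# whether the caller blocks immediately (non_block=False) or keeps the Future
# (non_block=True). Names below are assumptions for this sketch only.
import queue
from concurrent.futures import Future, ThreadPoolExecutor

_responses: "queue.Queue[str]" = queue.Queue()
_io_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mp_exec_io")


def _sketch_rpc(non_block: bool = False):
    # Only the single pool thread ever consumes from the queue.
    future: Future = _io_pool.submit(_responses.get)
    return future if non_block else future.result()


# _responses.put("out-0"); _responses.put("out-1")
# _sketch_rpc()               -> "out-0"  (blocking call)
# _sketch_rpc(non_block=True) -> Future resolving to "out-1"
# ---------------------------------------------------------------------------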
result = self.io_thread_pool.submit( # type: ignore - get_response, w, dequeue_timeout, self.shutdown_event) + get_response, w, dequeue_timeout, self.shutdown_event + ) + if not non_block: + result = result.result() + elif not non_block: + result = get_response(w, dequeue_timeout, self.shutdown_event) else: - result = get_response(w, dequeue_timeout) - + raise RuntimeError( + "non_block can only be used when max_concurrent_batches > 1" + ) responses.append(result) return responses @@ -284,15 +320,11 @@ class MultiprocExecutor(Executor): def shutdown(self): """Properly shut down the executor and its workers""" - if not getattr(self, 'shutting_down', False): + if not getattr(self, "shutting_down", False): self.shutting_down = True - self.shutdown_event.set() - if self.io_thread_pool is not None: - self.io_thread_pool.shutdown(wait=False, cancel_futures=True) - self.io_thread_pool = None - - if workers := getattr(self, 'workers', None): + # Make sure all the worker processes are terminated first. + if workers := getattr(self, "workers", None): for w in workers: # Close death_writer to signal child processes to exit if w.death_writer is not None: @@ -301,13 +333,18 @@ class MultiprocExecutor(Executor): w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) + self.shutdown_event.set() + if self.io_thread_pool is not None: + self.io_thread_pool.shutdown(wait=False, cancel_futures=True) + del self.io_thread_pool + self.rpc_broadcast_mq = None def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return - @property + @cached_property def max_concurrent_batches(self) -> int: if self.scheduler_config.async_scheduling: return 2 @@ -329,6 +366,7 @@ class MultiprocExecutor(Executor): @dataclass class UnreadyWorkerProcHandle: """WorkerProcess handle before READY.""" + proc: BaseProcess rank: int ready_pipe: Connection @@ -344,8 +382,8 @@ class WorkerProcHandle: @classmethod def from_unready_handle( - cls, unready_handle: UnreadyWorkerProcHandle, - worker_response_mq: MessageQueue) -> "WorkerProcHandle": + cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue + ) -> "WorkerProcHandle": return cls( proc=unready_handle.proc, rank=unready_handle.rank, @@ -366,6 +404,7 @@ class WorkerProc: rank: int, distributed_init_method: str, input_shm_handle: Handle, + shared_worker_lock: LockType, ): self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) @@ -373,47 +412,56 @@ class WorkerProc: all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] - is_driver_worker = ( - rank % vllm_config.parallel_config.tensor_parallel_size == 0) + is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0 all_kwargs[rank] = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, "is_driver_worker": is_driver_worker, + "shared_worker_lock": shared_worker_lock, } wrapper.init_worker(all_kwargs) self.worker = wrapper - pp_size = vllm_config.parallel_config.pipeline_parallel_size - tp_size = vllm_config.parallel_config.tensor_parallel_size - pp_str = f"PP{rank // tp_size}" if pp_size > 1 else "" - tp_str = f"TP{rank % tp_size}" if tp_size > 1 else "" - suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}" - process_name = "VllmWorker" - if suffix: - set_process_title(suffix, append=True) - process_name = f"{process_name} {suffix}" - decorate_logs(process_name) - # Initialize MessageQueue for 
receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( - input_shm_handle, self.worker.rank) + input_shm_handle, self.worker.rank + ) # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) - # Initialize device and loads weights + scheduler_config = vllm_config.scheduler_config + self.use_async_scheduling = scheduler_config.async_scheduling + if self.use_async_scheduling: + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy", + ) + self.async_output_copy_thread.start() + + # Initialize device self.worker.init_device() + + # Set process title and log prefix + self.setup_proc_title_and_log_prefix( + enable_ep=vllm_config.parallel_config.enable_expert_parallel + ) + + # Load model self.worker.load_model() @staticmethod def make_worker_process( - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - input_shm_handle, # Receive SchedulerOutput + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + shared_worker_lock: LockType, ) -> UnreadyWorkerProcHandle: context = get_mp_context() # (reader, writer) @@ -430,12 +478,15 @@ class WorkerProc: "input_shm_handle": input_shm_handle, "ready_pipe": (reader, writer), "death_pipe": death_reader, + "shared_worker_lock": shared_worker_lock, } # Run EngineCore busy loop in background process. - proc = context.Process(target=WorkerProc.worker_main, - kwargs=process_kwargs, - name=f"VllmWorker-{rank}", - daemon=True) + proc = context.Process( + target=WorkerProc.worker_main, + kwargs=process_kwargs, + name=f"VllmWorker-{rank}", + daemon=True, + ) proc.start() writer.close() @@ -445,16 +496,18 @@ class WorkerProc: @staticmethod def wait_for_ready( - unready_proc_handles: list[UnreadyWorkerProcHandle] + unready_proc_handles: list[UnreadyWorkerProcHandle], ) -> list[WorkerProcHandle]: - - e = Exception("WorkerProc initialization failed due to " - "an exception in a background process. " - "See stack trace for root cause.") + e = Exception( + "WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause." + ) pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} - ready_proc_handles: list[Optional[WorkerProcHandle]] = ( - [None] * len(unready_proc_handles)) + ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len( + unready_proc_handles + ) while pipes: ready = multiprocessing.connection.wait(pipes.keys()) for pipe in ready: @@ -468,10 +521,13 @@ class WorkerProc: # Extract the message queue handle. worker_response_mq = MessageQueue.create_from_handle( - response["handle"], 0) + response["handle"], 0 + ) ready_proc_handles[unready_proc_handle.rank] = ( WorkerProcHandle.from_unready_handle( - unready_proc_handle, worker_response_mq)) + unready_proc_handle, worker_response_mq + ) + ) except EOFError: e.__suppress_context__ = True @@ -484,6 +540,7 @@ class WorkerProc: return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() @@ -491,8 +548,8 @@ class WorkerProc: @staticmethod def worker_main(*args, **kwargs): - """ Worker initialization and execution loops. 
- This runs a background process """ + """Worker initialization and execution loops. + This runs a background process""" # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker @@ -513,7 +570,7 @@ class WorkerProc: # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - + shutdown_event = threading.Event() # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -525,13 +582,13 @@ class WorkerProc: # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown - os.kill(os.getpid(), signal.SIGTERM) + shutdown_event.set() except Exception as e: logger.warning("Death monitoring error: %s", e) - death_monitor = Thread(target=monitor_parent_death, - daemon=True, - name="WorkerDeathMonitor") + death_monitor = Thread( + target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" + ) death_monitor.start() try: @@ -539,12 +596,12 @@ class WorkerProc: worker = WorkerProc(*args, **kwargs) # Send READY once we know everything is loaded - ready_writer.send({ - "status": - WorkerProc.READY_STR, - "handle": - worker.worker_response_mq.export_handle(), - }) + ready_writer.send( + { + "status": WorkerProc.READY_STR, + "handle": worker.worker_response_mq.export_handle(), + } + ) # Ensure message queues are ready. Will deadlock if re-ordered. # Must be kept consistent with the Executor @@ -553,7 +610,7 @@ class WorkerProc: ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -563,6 +620,8 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") + elif shutdown_event.is_set(): + logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -584,16 +643,49 @@ class WorkerProc: SUCCESS = auto() FAILURE = auto() - def worker_busy_loop(self): + def enqueue_output(self, output: Any): + """Prepares output from the worker and enqueues it to the + worker_response_mq. If the output is an Exception, it is + converted to a FAILURE response. + """ + if isinstance(output, AsyncModelRunnerOutput): + output = output.get_output() + + if isinstance(output, Exception): + result = (WorkerProc.ResponseStatus.FAILURE, str(output)) + else: + result = (WorkerProc.ResponseStatus.SUCCESS, output) + if (response_mq := self.worker_response_mq) is not None: + response_mq.enqueue(result) + + def handle_output(self, output: Any): + """Handles output from the worker. If async scheduling is enabled, + it is passed to the async_output_busy_loop thread. Otherwise, it is + enqueued directly to the worker_response_mq. 
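The `handle_output`/`async_output_busy_loop` pair described here hands results to a dedicated copy thread so the worker busy loop never blocks on enqueueing. A toy sketch of that hand-off, with a plain list standing in for `worker_response_mq` and a `None` sentinel added purely for this example:

```python
import queue
import threading

output_queue: "queue.Queue[object]" = queue.Queue()
response_sink: list[object] = []  # stands in for worker_response_mq


def output_busy_loop() -> None:
    # Single consumer thread: preserves output order and keeps the main
    # busy loop from ever blocking on the response enqueue.
    while True:
        item = output_queue.get()
        if item is None:  # shutdown sentinel, only for this toy example
            break
        response_sink.append(("SUCCESS", item))


copier = threading.Thread(target=output_busy_loop, daemon=True)
copier.start()

for step in range(3):
    output_queue.put(f"model-output-{step}")  # handle_output() equivalent
output_queue.put(None)
copier.join()
print(response_sink)
```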
+ """ + if self.use_async_scheduling: + self.async_output_queue.put(output) + else: + self.enqueue_output(output) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + output = self.async_output_queue.get() + self.enqueue_output(output) + + def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() - + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( + cancel=cancel, indefinite=True + ) try: if isinstance(method, str): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) + output = func(*args, **kwargs) except Exception as e: # Notes have been introduced in python 3.11 @@ -603,10 +695,58 @@ class WorkerProc: # exception might not be serializable, so we convert it to # string, only for logging purpose. if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.FAILURE, str(e))) + self.handle_output(e) continue if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + self.handle_output(output) + + @staticmethod + def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: + dp_size = get_dp_group().world_size + dp_rank = get_dp_group().rank_in_group + pp_size = get_pp_group().world_size + pp_rank = get_pp_group().rank_in_group + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + process_name = "Worker" + if dp_size > 1: + process_name += f"_DP{dp_rank}" + if pp_size > 1: + process_name += f"_PP{pp_rank}" + if tp_size > 1: + process_name += f"_TP{tp_rank}" + if enable_ep: + ep_rank = get_ep_group().rank_in_group + process_name += f"_EP{ep_rank}" + set_process_title(name=process_name) + decorate_logs(process_name) + + +def set_multiprocessing_worker_envs(): + """Set up environment variables that should be used when there are workers + in a multiprocessing environment. This should be called by the parent + process before worker processes are created""" + + _maybe_force_spawn() + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if ( + "OMP_NUM_THREADS" not in os.environ + and (current_parallelism := torch.get_num_threads()) > default_omp_num_threads + ): + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. 
Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, + default_omp_num_threads, + ) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index c05ad1966d611..e2c2bfd45d7bd 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -6,8 +6,10 @@ from typing import Optional, Union from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor as RayDistributedExecutorV0) + RayDistributedExecutor as RayDistributedExecutorV0, +) from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -17,10 +19,10 @@ logger = init_logger(__name__) class FutureWrapper(Future): """A wrapper around Ray output reference to meet the interface - of .execute_model(): The top level (core busy loop) expects .result() api + of .execute_model(): The top level (core busy loop) expects .result() api to block and return a single output. - - If aggregator is provided, the outputs from all workers are aggregated upon + + If aggregator is provided, the outputs from all workers are aggregated upon the result() call. If not only the first worker's output is returned. """ @@ -50,8 +52,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator( - self.parallel_config.world_size) @property def max_concurrent_batches(self) -> int: @@ -64,12 +64,14 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: """Execute the model on the Ray workers. Args: scheduler_output: The scheduler output to execute. + non_block: If True, the method will return a Future. Returns: The model runner output. @@ -83,7 +85,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): if not self.has_connector: # Get output only from a single worker (output_rank) # When PP is not used, we block here until the result is available. 
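The `FutureWrapper` used in the Ray path defers the blocking `ref.get()` (and the optional KV-output aggregation) until `.result()` is called. A generic sketch of the same idea, with `resolvers` standing in for Ray object refs (illustrative only, not the Ray API):

```python
from concurrent.futures import Future
from typing import Any, Callable, Optional


class LazyFuture(Future):
    """Future whose result() lazily resolves worker refs on first call."""

    def __init__(
        self,
        resolvers: list[Callable[[], Any]],
        aggregate: Optional[Callable[[list[Any]], Any]] = None,
    ):
        super().__init__()
        self._resolvers = resolvers
        self._aggregate = aggregate

    def result(self, timeout=None) -> Any:
        outputs = [resolve() for resolve in self._resolvers]
        # Without an aggregator, only the first worker's output matters.
        return self._aggregate(outputs) if self._aggregate else outputs[0]


fut = LazyFuture([lambda: {"rank": 0}, lambda: {"rank": 1}], aggregate=list)
print(fut.result())  # [{'rank': 0}, {'rank': 1}]
```

The real wrapper resolves Ray object refs and runs the KV output aggregator; the sketch only mirrors the deferred-result shape.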
- if self.max_concurrent_batches == 1: + if not non_block: return refs[0].get() # When PP is used, we return a FutureWrapper immediately so that @@ -91,7 +93,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): return FutureWrapper(refs) # Get output from all workers when connector is present - if self.max_concurrent_batches == 1: + if not non_block: # Block and get results from all workers outputs = [ref.get() for ref in refs] return self.kv_output_aggregator.aggregate(outputs) @@ -100,9 +102,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): return FutureWrapper(refs, self.kv_output_aggregator) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: self._run_workers("reinitialize_distributed", reconfig_request) - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() - return \ No newline at end of file diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index ed8e0bf798988..9c28eb92c17a9 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -50,7 +50,8 @@ class KVCacheSpec: Merge a list of KVCacheSpec objects into a single KVCacheSpec object. """ assert all(spec == specs[0] for spec in specs[1:]), ( - "All layers in the same KV cache group must be the same.") + "All layers in the same KV cache group must be the same." + ) return copy.deepcopy(specs[0]) @@ -59,14 +60,16 @@ class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype - use_mla: bool @property def page_size_bytes(self) -> int: - # For MLA we only store a single latent vector - coef = 1 if self.use_mla else 2 - return coef * self.block_size * self.num_kv_heads * self.head_size \ - * get_dtype_size(self.dtype) + return ( + 2 + * self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) @dataclass(frozen=True) @@ -85,6 +88,11 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod @@ -96,28 +104,35 @@ class FullAttentionSpec(AttentionSpec): else: raise ValueError( "All attention layers in the same KV cache group must have the " - "same window size.") + "same window size." + ) @classmethod def merge(cls, specs: list[Self]) -> Self: """ - Merge a list of FullAttentionSpec objects into a single + Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object. """ assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "FullAttentionSpec.") + "All attention layers in the same KV cache group must be FullAttentionSpec." 
+ ) - sliding_window = set(spec.sliding_window for spec in specs - if spec.sliding_window is not None) - attention_chunk_size = set(spec.attention_chunk_size for spec in specs - if spec.attention_chunk_size is not None) + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) merged_spec = cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, - use_mla=specs[0].use_mla, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -125,30 +140,69 @@ class FullAttentionSpec(AttentionSpec): for f in fields(AttentionSpec): assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( "All attention layers in the same KV cache group must have " - "the same attention spec.") - assert ( - (merged_spec.sliding_window is not None) + - (merged_spec.attention_chunk_size is not None) <= 1 - ), ("Model with both sliding window layers and chunked local attention " - "layers is not supported.") + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." + ) return merged_spec +@dataclass(frozen=True) +class MLAAttentionSpec(FullAttentionSpec): + # TODO(Lucas/Chen): less hacky way to do this + cache_dtype_str: Optional[str] = None + + @property + def page_size_bytes(self) -> int: + if self.cache_dtype_str == "fp8_ds_mla": + # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` + # for details. + return self.block_size * 656 + return ( + self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be MLAAttentionSpec." + ) + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + assert len(cache_dtype_str_set) == 1, ( + "All attention layers in the same KV cache group must use the same " + "quantization method." + ) + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + cache_dtype_str=cache_dtype_str_set.pop(), + ) + + @dataclass(frozen=True) class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for at most # `self.attention_chunk_size` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. 
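To make the page-size formulas in this hunk concrete, here is a worked example with made-up dimensions (block size 16, 8 KV heads, head size 128, 2-byte dtype); the figures are illustrative and do not describe any particular model:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ExamplePageSize:
    block_size: int = 16
    num_kv_heads: int = 8
    head_size: int = 128
    dtype_size: int = 2  # fp16/bf16 bytes per element

    @property
    def full_attention_bytes(self) -> int:
        # Key and value caches -> factor of 2.
        return (
            2 * self.block_size * self.num_kv_heads * self.head_size * self.dtype_size
        )

    @property
    def mla_bytes(self) -> int:
        # MLA stores a single latent vector per token, so no factor of 2.
        return self.block_size * self.num_kv_heads * self.head_size * self.dtype_size


spec = ExamplePageSize()
print(spec.full_attention_bytes)  # 65536 bytes per block
print(spec.mla_bytes)             # 32768 bytes per block
print(16 * 656)                   # fp8_ds_mla layout: 10496 bytes per block
```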
- num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.attention_chunk_size + max_num_batched_tokens, max_model_len + ) return cdiv(num_tokens, self.block_size) * self.page_size_bytes @@ -157,20 +211,20 @@ class ChunkedLocalAttentionSpec(AttentionSpec): class SlidingWindowSpec(AttentionSpec): sliding_window: int - def __post_init__(self): - assert not self.use_mla, "MLA is not supported for sliding window" - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( + "DCP not support sliding window." + ) max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for the last # `self.sliding_window-1` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. - num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.sliding_window - 1 + max_num_batched_tokens, max_model_len + ) # +1 here because the sliding window may not start from the beginning # of the block. For example, if the block size is 4 and num_token @@ -185,37 +239,127 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" + num_speculative_blocks: int = 0 @property def page_size_bytes(self) -> int: page_size = sum( prod(shape) * get_dtype_size(dtype) - for (shape, dtype) in zip(self.shapes, self.dtypes)) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # We allocate 1 block for each request now, so max_memory_usage_bytes is - # the same as page_size_bytes. - # Need to update this when supporting prefix caching. - return self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes @dataclass(frozen=True) class EncoderOnlyAttentionSpec(AttentionSpec): - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # Encoder-only layers do not need KV cache return 0 +@dataclass(frozen=True) +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Get encoder length (e.g., 1500 for Whisper). + max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + +@dataclass(frozen=True) +class UniformTypeKVCacheSpecs(KVCacheSpec): + """ + A KV cache spec for multiple layers with the same type of attention. Here, + same types means always need the same number of token slots. For example, + sliding window attentions with different window sizes are not the same type + and should not be merged into one UniformTypeKVCacheSpecs. 
+ """ + + kv_cache_specs: dict[str, KVCacheSpec] + + @property + def page_size_bytes(self) -> int: + return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values()) + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_num_pages = max( + cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes) + for spec in self.kv_cache_specs.values() + ) + return max_num_pages * self.page_size_bytes + + @classmethod + def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers have the same type of KV cache spec. + """ + block_sizes = set(spec.block_size for spec in kv_cache_specs.values()) + if len(block_sizes) > 1: + # Different block sizes, not uniform. + return False + one_spec = next(iter(kv_cache_specs.values())) + if isinstance(one_spec, FullAttentionSpec): + return all( + isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() + ) + elif isinstance(one_spec, CrossAttentionSpec): + return all( + isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values() + ) + elif isinstance(one_spec, SlidingWindowSpec): + return all( + isinstance(spec, SlidingWindowSpec) + and spec.sliding_window == one_spec.sliding_window + for spec in kv_cache_specs.values() + ) + elif isinstance(one_spec, ChunkedLocalAttentionSpec): + return all( + isinstance(spec, ChunkedLocalAttentionSpec) + and spec.attention_chunk_size == one_spec.attention_chunk_size + for spec in kv_cache_specs.values() + ) + elif isinstance(one_spec, MambaSpec): + return all( + isinstance(spec, MambaSpec) + and spec.num_speculative_blocks == one_spec.num_speculative_blocks + for spec in kv_cache_specs.values() + ) + else: + # NOTE(Chen): Please add new branches for new KV cache spec types. + raise NotImplementedError( + f"Unsupported KV cache spec type: {type(one_spec)}" + ) + + @classmethod + def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Optional[Self]: + """ + Return a SameTypeKVCacheSpecs object if all layers have the same type + of KV cache spec. Return None if not. + """ + if cls.is_uniform_type(kv_cache_specs): + block_size = next(iter(kv_cache_specs.values())).block_size + return cls(block_size=block_size, kv_cache_specs=kv_cache_specs) + else: + return None + + @dataclass class KVCacheTensor: """ A class for specifying how the workers should initialize the KV cache. """ + size: int # size of the KV cache tensor in bytes shared_by: list[str] # layer names that share the same KV cache tensor @@ -226,6 +370,7 @@ class KVCacheGroupSpec: Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager. """ + # The names of model layers in this group layer_names: list[str] # The KV cache spec of this manager layer @@ -237,6 +382,7 @@ class KVCacheConfig: """ The KV cache configuration of a model. 
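A quick sanity check of the `UniformTypeKVCacheSpecs` accounting above: the group's page size sums the member layers' page sizes, while the block budget comes from whichever layer needs the most pages. Toy per-layer numbers, purely illustrative:

```python
import math

# Hypothetical per-layer figures: (page_size_bytes, max_memory_usage_bytes)
layers = {
    "layer.0": (65_536, 8_388_608),
    "layer.1": (65_536, 4_194_304),
}

group_page_size = sum(page for page, _ in layers.values())
max_num_pages = max(
    math.ceil(max_bytes / page) for page, max_bytes in layers.values()
)
print(group_page_size)                  # 131072 bytes per grouped page
print(max_num_pages)                    # 128 pages
print(max_num_pages * group_page_size)  # 16777216 bytes upper bound for the group
```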
""" + """The number of KV cache blocks""" num_blocks: int """How should model runner initialize the KV cache tensors for each layer""" diff --git a/vllm/core/__init__.py b/vllm/v1/kv_offload/__init__.py similarity index 100% rename from vllm/core/__init__.py rename to vllm/v1/kv_offload/__init__.py diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py new file mode 100644 index 0000000000000..ce2d0dffc0ff6 --- /dev/null +++ b/vllm/v1/kv_offload/abstract.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +OffloadingManager class for managing KV data offloading in vLLM v1 + +This class runs in the scheduler, tracks which blocks are offloaded +and their address. + +The class provides the following primitives: + lookup() - find the length of the maximal series of blocks, + starting from the first one, that are all offloaded. + prepare_load() - prepare given blocks to be read. + The given blocks will be protected from eviction. + This function returns a LoadSpec which encapsulates + information required for performing the load. + touch() - marks the give blocks as recently used. Can be used + to track block's LRU. This function is separated from the + prepare_load function to allow setting block recency even + for blocks which do not need reading from the cache, such as + blocks that are cached by the GPU prefix cache. + complete_load() - mark blocks which were previously prepared to be + loaded as done loading. This is to re-allow their eviction. + prepare_store() - prepare the given blocks to be written. + Returns a StoreSpec encapsulating offloading information, + as well as a list of blocks that were evicted as a result. + complete_store() - marks a previous store as completed. + Following this call, the given blocks will become loadable. +""" + +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional + +from vllm.v1.core.kv_cache_utils import BlockHash + + +class LoadStoreSpec(ABC): + """ + Abstract metadata that encapsulates information allowing a worker + to load, and optionally also to store, blocks of KV data. + """ + + @staticmethod + @abstractmethod + def medium() -> str: + """ + Returns a string representation of the medium type + this store/load targets. + """ + pass + + +@dataclass +class PrepareStoreOutput: + block_hashes_to_store: list[BlockHash] + store_spec: LoadStoreSpec + block_hashes_evicted: list[BlockHash] + + +@dataclass +class OffloadingEvent: + block_hashes: list[BlockHash] + block_size: int + medium: str + # True if blocks are removed, False if stored + removed: bool + + +class OffloadingManager(ABC): + @abstractmethod + def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + """ + Finds the length of the maximal series of blocks, starting from the + first one, that are all offloaded. + + Args: + block_hashes: the hashes identifying the blocks to lookup. + + Returns: + An integer representing the maximal number of blocks that + are currently offloaded. + """ + pass + + @abstractmethod + def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: + """ + Prepare the given blocks to be read. + The given blocks will be protected from eviction until + complete_load is called. + It assumes all given blocks are offloaded. + + Args: + block_hashes: the hashes identifying the blocks. 
+ + Returns: + A LoadStoreSpec that can be used by a worker to locate and load + the actual offloaded KV data. + """ + pass + + def touch(self, block_hashes: Iterable[BlockHash]): + """ + Mark the given blocks as recently used. + This could in practice mean moving them to the end of an LRU list. + + Args: + block_hashes: the hashes identifying the blocks. + """ + return + + def complete_load(self, block_hashes: Iterable[BlockHash]): + """ + Marks previous blocks that were prepared to load as done loading. + + Args: + block_hashes: the hashes identifying the blocks. + """ + return + + @abstractmethod + def prepare_store( + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: + """ + Prepare the given blocks to be offloaded. + The given blocks will be protected from eviction until + complete_store is called. + + Args: + block_hashes: the hashes identifying the blocks. + + Returns: + A PrepareStoreOutput indicating which blocks need storing, + where to store them (LoadStoreSpec), and list of blocks that + were evicted as a result. + None is returned if the blocks cannot be stored. + """ + pass + + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): + """ + Marks blocks which were previously prepared to be stored, as stored. + Following this call, the blocks become loadable. + If if_success is False, blocks that were not marked as stored will be + removed. + + Args: + block_hashes: the hashes identifying the blocks. + success: whether the blocks were stored successfully. + """ + return + + def take_events(self) -> Iterable[OffloadingEvent]: + """ + Take the offloading events from the manager. + + Yields: + New OffloadingEvents collected since the last call. + """ + return () diff --git a/vllm/v1/kv_offload/backend.py b/vllm/v1/kv_offload/backend.py new file mode 100644 index 0000000000000..538f7bf0584b5 --- /dev/null +++ b/vllm/v1/kv_offload/backend.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ctypes +from abc import ABC, abstractmethod +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import LoadStoreSpec + + +class BlockStatus(ctypes.Structure): + """ + Offloading status for a single block of KV data. + Holds the following information: + + ref_cnt - the current number of transfers using this block as a source. + A value of -1 indicates the block is not yet ready to be read. + load_store_spec - backend-specific information on how to actually + read/write the block. + """ + + _fields_ = [("ref_cnt", ctypes.c_int32)] + + def __init__(self): + super().__init__() + # initialize block as "not ready" (ref_cnt = -1) + self.ref_cnt = -1 + + @property + def is_ready(self) -> bool: + """ + Returns whether the block is ready to be read. + """ + return self.ref_cnt >= 0 + + +class Backend(ABC): + """ + An abstract class for allocating and returning specs for writing + KV blocks to some backend. + """ + + def __init__(self, block_size: int, medium: str): + self.block_size = block_size + self.medium = medium + + @abstractmethod + def get_num_free_blocks(self): + """ + Returns the number of current number of blocks that can be allocated. + """ + pass + + @abstractmethod + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: + """ + Allocate space for writing blocks. + This method assumes there is enough space for allocation. 
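Putting the `OffloadingManager` primitives together, a scheduler-side caller would roughly follow this flow (a hedged sketch against the abstract interface; the worker hand-off in the comments is elided):

```python
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import OffloadingManager


def load_prefix(manager: OffloadingManager, block_hashes: list[BlockHash]):
    """Illustrative scheduler-side read path for offloaded blocks."""
    num_hits = manager.lookup(block_hashes)       # longest offloaded prefix
    hit_hashes = block_hashes[:num_hits]
    manager.touch(hit_hashes)                     # refresh recency (e.g. LRU)
    load_spec = manager.prepare_load(hit_hashes)  # protected from eviction
    # ... hand load_spec to a worker-side OffloadingHandler here ...
    manager.complete_load(hit_hashes)             # re-allow eviction
    return load_spec


def store_blocks(manager: OffloadingManager, block_hashes: list[BlockHash]):
    """Illustrative scheduler-side write path."""
    prep = manager.prepare_store(block_hashes)
    if prep is None:
        return  # nothing could be stored (no space, nothing evictable)
    # ... hand prep.store_spec to a worker, wait for completion ...
    manager.complete_store(prep.block_hashes_to_store, success=True)
```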
+ It is unsafe to use without checking get_num_free_blocks beforehand. + + Args: + block_hashes: the hashes identifying the blocks to be written. + + Returns: + A list of BlockStatus for the allocated blocks. + The ref_cnt of each returned item will be -1, meaning the block + is not yet ready to be read. + """ + pass + + @abstractmethod + def free(self, block: BlockStatus): + """ + Free a previously allocated block. + You should only call this function with blocks returned by + allocate_blocks, and only once per each block. + + Args: + block: The block to be freed. + """ + pass + + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: + """ + Get backend-specific information on how to read/write blocks. + + Args: + block_hashes: the list of block hashes identifying the blocks. + blocks: the list of blocks. + + Returns: + A LoadStoreSpec that can be used by a worker + to read/write the blocks. + """ + raise NotImplementedError diff --git a/vllm/core/block/__init__.py b/vllm/v1/kv_offload/backends/__init__.py similarity index 100% rename from vllm/core/block/__init__.py rename to vllm/v1/kv_offload/backends/__init__.py diff --git a/vllm/v1/kv_offload/backends/cpu.py b/vllm/v1/kv_offload/backends/cpu.py new file mode 100644 index 0000000000000..736cf37853cdc --- /dev/null +++ b/vllm/v1/kv_offload/backends/cpu.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ctypes +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import LoadStoreSpec +from vllm.v1.kv_offload.backend import Backend, BlockStatus +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec + + +class CPUBlockStatus(BlockStatus): + _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)] # type: ignore + + def __init__(self, block_id: int): + super().__init__() + self.block_id = block_id + + +class CPUBackend(Backend): + def __init__(self, block_size: int, num_blocks: int): + super().__init__(block_size=block_size, medium=CPULoadStoreSpec.medium()) + + self.num_blocks: int = num_blocks + self.num_allocated_blocks: int = 0 + self.allocated_blocks_free_list: list[int] = [] + + def get_num_free_blocks(self): + return ( + len(self.allocated_blocks_free_list) + + self.num_blocks + - self.num_allocated_blocks + ) + + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: + num_fresh_blocks = min( + len(block_hashes), self.num_blocks - self.num_allocated_blocks + ) + num_reused_blocks = len(block_hashes) - num_fresh_blocks + assert len(self.allocated_blocks_free_list) >= num_reused_blocks + + # allocate fresh blocks + blocks: list[BlockStatus] = [] + for _ in range(num_fresh_blocks): + blocks.append(CPUBlockStatus(self.num_allocated_blocks)) + self.num_allocated_blocks += 1 + + # allocate reused blocks + for _ in range(num_reused_blocks): + block_id = self.allocated_blocks_free_list.pop() + blocks.append(CPUBlockStatus(block_id)) + + return blocks + + def free(self, block: BlockStatus): + assert isinstance(block, CPUBlockStatus) + self.allocated_blocks_free_list.append(block.block_id) + + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: + return CPULoadStoreSpec([block.block_id for block in blocks]) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py new file mode 100644 index 
0000000000000..0c1cf64a237cb --- /dev/null +++ b/vllm/v1/kv_offload/cpu.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from typing import Optional + +import torch + +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform +from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.kv_offload.backends.cpu import CPUBackend +from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.worker import OffloadingHandler + + +class CPUOffloadingSpec(OffloadingSpec): + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + + num_cpu_blocks = self.extra_config.get("num_cpu_blocks") + if not num_cpu_blocks: + raise Exception( + "num_cpu_blocks must be specified in kv_connector_extra_config" + ) + self.num_cpu_blocks: int = num_cpu_blocks + + # scheduler-side + self._manager: Optional[OffloadingManager] = None + + # worker-side + self._handler: Optional[OffloadingHandler] = None + + def get_manager(self) -> OffloadingManager: + if not self._manager: + kv_events_config = self.vllm_config.kv_events_config + enable_events = ( + kv_events_config is not None and kv_events_config.enable_kv_cache_events + ) + self._manager = LRUOffloadingManager( + CPUBackend( + block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks + ), + enable_events=enable_events, + ) + return self._manager + + def get_handlers( + self, kv_caches: dict[str, torch.Tensor] + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: + if not self._handler: + if not current_platform.is_cuda(): + raise Exception( + "CPU Offloading is currently only supported on CUDA GPUs" + ) + + layer_names = list(kv_caches.keys()) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) + attn_backends = { + layer_name: layers[layer_name].get_attn_backend() + for layer_name in layer_names + } + + self._handler = CpuGpuOffloadingHandler( + attn_backends=attn_backends, + gpu_block_size=self.gpu_block_size, + cpu_block_size=self.offloaded_block_size, + num_cpu_blocks=self.num_cpu_blocks, + gpu_caches=kv_caches, + ) + + assert self._handler is not None + yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler + yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py new file mode 100644 index 0000000000000..e0a53460e840d --- /dev/null +++ b/vllm/v1/kv_offload/factory.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +from typing import TYPE_CHECKING, Callable + +from vllm.logger import init_logger +from vllm.v1.kv_offload.spec import OffloadingSpec + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpecFactory: + _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {} + + @classmethod + def register_spec(cls, name: str, module_path: str, class_name: str) -> None: + """Register a spec with a lazy-loading module and class 
name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[OffloadingSpec]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_spec( + cls, + config: "VllmConfig", + ) -> OffloadingSpec: + kv_transfer_config = config.kv_transfer_config + assert kv_transfer_config is not None + extra_config = kv_transfer_config.kv_connector_extra_config + spec_name = extra_config.get("spec_name", "CPUOffloadingSpec") + if spec_name in cls._registry: + spec_cls = cls._registry[spec_name]() + else: + spec_module_path = extra_config.get("spec_module_path") + if spec_module_path is None: + raise ValueError(f"Unsupported spec type: {spec_name}") + spec_module = importlib.import_module(spec_module_path) + spec_cls = getattr(spec_module, spec_name) + assert issubclass(spec_cls, OffloadingSpec) + logger.info("Creating offloading spec with name: %s", spec_name) + return spec_cls(config) + + +# Register various specs here. +OffloadingSpecFactory.register_spec( + "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu", "CPUOffloadingSpec" +) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py new file mode 100644 index 0000000000000..36f5eb4a0abdd --- /dev/null +++ b/vllm/v1/kv_offload/lru_manager.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Iterable +from typing import Optional + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) +from vllm.v1.kv_offload.backend import Backend, BlockStatus + + +class LRUOffloadingManager(OffloadingManager): + """ + An OffloadingManager with a pluggable backend, which evicts blocks by LRU. 
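The manager below keeps blocks in an `OrderedDict`, refreshes recency with `move_to_end`, and scans from the front for eviction candidates whose `ref_cnt` is zero. The ordering mechanics in isolation (toy keys, not real block hashes):

```python
from collections import OrderedDict

lru: OrderedDict[str, int] = OrderedDict()
for name in ("blk-a", "blk-b", "blk-c"):
    lru[name] = 0  # value plays the role of ref_cnt; 0 == evictable

# touch(): mark "blk-a" as most recently used.
lru.move_to_end("blk-a")

# Eviction walks from the least recently used end and skips anything
# with a positive ref_cnt (an in-flight transfer).
lru["blk-b"] = 1
victim = next(name for name, ref_cnt in lru.items() if ref_cnt == 0)
print(victim)     # blk-c
print(list(lru))  # ['blk-b', 'blk-c', 'blk-a']
```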
+ """ + + def __init__(self, backend: Backend, enable_events: bool = False): + self.backend: Backend = backend + # block_hash -> BlockStatus + self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() + self.events: Optional[list[OffloadingEvent]] = [] if enable_events else None + + def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + hit_count = 0 + for block_hash in block_hashes: + block = self.blocks.get(block_hash) + if block is None or not block.is_ready: + break + hit_count += 1 + return hit_count + + def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: + blocks = [] + for block_hash in block_hashes: + block = self.blocks[block_hash] + assert block.is_ready + block.ref_cnt += 1 + blocks.append(block) + + return self.backend.get_load_store_spec(block_hashes, blocks) + + def touch(self, block_hashes: Iterable[BlockHash]): + for block_hash in reversed(list(block_hashes)): + if self.blocks.get(block_hash): + self.blocks.move_to_end(block_hash) + + def complete_load(self, block_hashes: Iterable[BlockHash]): + for block_hash in block_hashes: + block = self.blocks[block_hash] + assert block.ref_cnt > 0 + block.ref_cnt -= 1 + + def prepare_store( + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: + # filter out blocks that are already stored + block_hashes_to_store = [ + block_hash for block_hash in block_hashes if block_hash not in self.blocks + ] + + num_blocks_to_evict = ( + len(block_hashes_to_store) - self.backend.get_num_free_blocks() + ) + + # build list of blocks to evict + to_evict = [] + if num_blocks_to_evict > 0: + for block_hash, block in self.blocks.items(): + if block.ref_cnt == 0: + to_evict.append(block_hash) + num_blocks_to_evict -= 1 + if num_blocks_to_evict == 0: + break + else: + # we could not evict enough blocks + return None + + # evict blocks + for block_hash in to_evict: + self.backend.free(self.blocks.pop(block_hash)) + + if to_evict and self.events is not None: + self.events.append( + OffloadingEvent( + block_hashes=to_evict, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=True, + ) + ) + + blocks = self.backend.allocate_blocks(block_hashes_to_store) + assert len(blocks) == len(block_hashes_to_store) + + for block_hash, block in zip(block_hashes_to_store, blocks): + self.blocks[block_hash] = block + + # build store specs for allocated blocks + store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks) + + return PrepareStoreOutput( + block_hashes_to_store=block_hashes_to_store, + store_spec=store_spec, + block_hashes_evicted=to_evict, + ) + + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): + stored_block_hashes: list[BlockHash] = [] + if success: + for block_hash in block_hashes: + block = self.blocks[block_hash] + if not block.is_ready: + block.ref_cnt = 0 + stored_block_hashes.append(block_hash) + else: + for block_hash in block_hashes: + block = self.blocks[block_hash] + if not block.is_ready: + self.backend.free(block) + del self.blocks[block_hash] + + if stored_block_hashes and self.events is not None: + self.events.append( + OffloadingEvent( + block_hashes=stored_block_hashes, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=False, + ) + ) + + def take_events(self) -> Iterable[OffloadingEvent]: + if self.events is not None: + yield from self.events + self.events.clear() diff --git a/vllm/v1/kv_offload/mediums.py b/vllm/v1/kv_offload/mediums.py new file mode 100644 index 
0000000000000..8962819178459 --- /dev/null +++ b/vllm/v1/kv_offload/mediums.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC + +import numpy as np + +from vllm.v1.kv_offload.abstract import LoadStoreSpec + + +class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC): + """ + Spec for loading/storing KV blocks from given block numbers. + """ + + def __init__(self, block_ids: list[int]): + self.block_ids = np.array(block_ids, dtype=np.int64) + + def __repr__(self) -> str: + return repr(self.block_ids) + + +class GPULoadStoreSpec(BlockIDsLoadStoreSpec): + """ + Spec for loading/storing a KV block to GPU memory. + """ + + @staticmethod + def medium() -> str: + return "GPU" + + +class CPULoadStoreSpec(BlockIDsLoadStoreSpec): + """ + Spec for loading/storing a KV block to CPU memory. + """ + + @staticmethod + def medium() -> str: + return "CPU" diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py new file mode 100644 index 0000000000000..a3c539a47d458 --- /dev/null +++ b/vllm/v1/kv_offload/spec.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.kv_offload.worker.worker import OffloadingHandler + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class OffloadingSpec(ABC): + """Spec for an offloading connector""" + + def __init__(self, vllm_config: "VllmConfig"): + logger.warning( + "Initializing OffloadingSpec. This API is experimental and " + "subject to change in the future as we iterate the design." + ) + self.vllm_config = vllm_config + + kv_transfer_config = vllm_config.kv_transfer_config + assert kv_transfer_config is not None + self.extra_config = kv_transfer_config.kv_connector_extra_config + + self.gpu_block_size = vllm_config.cache_config.block_size + self.offloaded_block_size = int( + self.extra_config.get("block_size", self.gpu_block_size) + ) + + assert self.offloaded_block_size % self.gpu_block_size == 0 + + @abstractmethod + def get_manager(self) -> OffloadingManager: + """ + Get an OffloadingManager that will be used + by the scheduler-side offloading connector to track + offloaded blocks and manage evictions. + """ + pass + + @abstractmethod + def get_handlers( + self, kv_caches: dict[str, torch.Tensor] + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: + """ + Get offloading handlers along with their respective src and dst types. + + Args: + kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor. + + Yields: + Tuples of (src_type, dst_type, offloading_handler). 
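The mediums defined above are thin wrappers around an `int64` array of block ids plus a medium tag; the `(src_medium, dst_medium)` pair is what transfers are routed by on the worker side. For instance:

```python
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec

src = GPULoadStoreSpec([4, 5, 6])        # GPU block ids to read from
dst = CPULoadStoreSpec([120, 121, 122])  # CPU block ids to write to

print((src.medium(), dst.medium()))  # ('GPU', 'CPU') -> selects the GPU->CPU handler
print(src.block_ids.dtype, src.block_ids)  # int64 [4 5 6]
```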
+ """ + pass diff --git a/vllm/engine/output_processor/__init__.py b/vllm/v1/kv_offload/worker/__init__.py similarity index 100% rename from vllm/engine/output_processor/__init__.py rename to vllm/v1/kv_offload/worker/__init__.py diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py new file mode 100644 index 0000000000000..eb7117a400b90 --- /dev/null +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.attention import AttentionBackend +from vllm.logger import init_logger +from vllm.utils import is_pin_memory_available +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + TransferResult, + TransferSpec, +) + +logger = init_logger(__name__) + + +def expand_block_ids( + block_ids: np.ndarray, + block_size_factor: int, + output: np.ndarray, + skip_count: int = 0, +): + """ + Convert a list of block IDs to a list of matching block ids, + assuming each block is composed of actual block_size_factor blocks. + Outputs to output tensor. + The first skip_count blocks will be skipped. + Note that skip_count must be less than block_size_factor. + + For example, if block_ids = [0, 1, 3] and block_size_factor = 4, + then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] + since 0 maps to [0, 1, 2, 3] + 1 maps to [4, 5, 6, 7] + and 3 maps to [12, 13, 14, 15] + """ + assert skip_count < block_size_factor + + first_range = np.arange(skip_count, block_size_factor) + full_range = np.arange(0, block_size_factor) + + output_idx = 0 + for i, block_id in enumerate(block_ids): + base_block_id = block_id * block_size_factor + indices = first_range if i == 0 else full_range + output_end_idx = output_idx + len(indices) + output[output_idx:output_end_idx] = base_block_id + indices + output_idx = output_end_idx + + +class CpuGpuOffloadingHandler(OffloadingHandler): + def __init__( + self, + gpu_block_size: int, + cpu_block_size: int, + num_cpu_blocks: int, + gpu_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], + ): + assert cpu_block_size % gpu_block_size == 0 + self.block_size_factor = cpu_block_size // gpu_block_size + + # cuda streams for gpu->cpu and cpu->gpu + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + # job_id -> transfer cuda event + self.transfer_events: dict[int, torch.cuda.Event] = {} + # list of cuda events available for re-use + self.events_pool: list[torch.cuda.Event] = [] + + pin_memory = is_pin_memory_available() + + # allocate cpu tensors + logger.info("Allocating %d CPU tensors...", len(gpu_caches)) + self.gpu_tensors: list[torch.Tensor] = [] + self.cpu_tensors: list[torch.Tensor] = [] + self.kv_dim_before_num_blocks: list[bool] = [] + for layer_name, gpu_tensor in gpu_caches.items(): + self.gpu_tensors.append(gpu_tensor) + + gpu_shape = gpu_tensor.shape + test_shape = attn_backends[layer_name].get_kv_cache_shape( + num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 + ) + if test_shape[0] == 1234: + # shape is (num_blocks, ...) + num_blocks_idx = 0 + self.kv_dim_before_num_blocks.append(False) + else: + # shape should be (2, num_blocks, ...) 
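The docstring example for `expand_block_ids` can be checked directly; note that the caller pre-allocates the output array to the expanded size minus any skipped leading sub-blocks:

```python
import numpy as np

from vllm.v1.kv_offload.worker.cpu_gpu import expand_block_ids

block_ids = np.array([0, 1, 3], dtype=np.int64)
factor = 4

out = np.empty(block_ids.size * factor, dtype=np.int64)
expand_block_ids(block_ids, factor, out)
print(out.tolist())
# [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]

# With skip_count=2 the first block contributes only its tail sub-blocks.
out = np.empty(block_ids.size * factor - 2, dtype=np.int64)
expand_block_ids(block_ids, factor, out, skip_count=2)
print(out.tolist())
# [2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
```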
+ assert test_shape[0] == 2 + assert test_shape[1] == 1234 + assert gpu_shape[0] == 2 + + num_blocks_idx = 1 + self.kv_dim_before_num_blocks.append(True) + + cpu_shape = list(gpu_shape) + cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor + + logger.debug("Allocating CPU tensor of shape %r", cpu_shape) + self.cpu_tensors.append( + torch.zeros( + cpu_shape, + dtype=gpu_tensor.dtype, + device="cpu", + pin_memory=pin_memory, + ) + ) + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + src_spec, dst_spec = spec + if isinstance(src_spec, CPULoadStoreSpec): + assert isinstance(dst_spec, GPULoadStoreSpec) + stream = self.h2d_stream + src_tensors = self.cpu_tensors + dst_tensors = self.gpu_tensors + src_block_size_factor = self.block_size_factor + dst_block_size_factor = 1 + else: + assert isinstance(src_spec, GPULoadStoreSpec) + assert isinstance(dst_spec, CPULoadStoreSpec) + stream = self.d2h_stream + src_tensors = self.gpu_tensors + dst_tensors = self.cpu_tensors + src_block_size_factor = 1 + dst_block_size_factor = self.block_size_factor + + src_blocks = src_spec.block_ids + dst_blocks = dst_spec.block_ids + assert src_blocks.ndim == 1 + assert dst_blocks.ndim == 1 + + dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor + src_sub_block_count = src_blocks.size * src_block_size_factor + + assert ( + src_sub_block_count + == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip + ) + + src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64) + expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0]) + expand_block_ids( + dst_blocks, + dst_block_size_factor, + src_to_dst[:, 1], + skip_count=dst_sub_blocks_to_skip, + ) + src_to_dst_tensor = torch.from_numpy(src_to_dst) + + event = self.events_pool.pop() if self.events_pool else torch.cuda.Event() + with torch.cuda.stream(stream): + for src_tensor, dst_tensor, kv_dim in zip( + src_tensors, dst_tensors, self.kv_dim_before_num_blocks + ): + if kv_dim: + src_key_cache = src_tensor[0] + dst_key_cache = dst_tensor[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) + src_value_cache = src_tensor[1] + dst_value_cache = dst_tensor[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) + else: + ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) + event.record(stream) + + self.transfer_events[job_id] = event + + # success + return True + + def get_finished(self) -> list[TransferResult]: + results: list[TransferResult] = [] + for job_id, event in self.transfer_events.items(): + if event.query(): + results.append((job_id, True)) + self.events_pool.append(event) + for job_id, _ in results: + del self.transfer_events[job_id] + return results diff --git a/vllm/v1/kv_offload/worker/worker.py b/vllm/v1/kv_offload/worker/worker.py new file mode 100644 index 0000000000000..58ba082497fa8 --- /dev/null +++ b/vllm/v1/kv_offload/worker/worker.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +from vllm.logger import init_logger +from vllm.v1.kv_offload.abstract import LoadStoreSpec + +# a single transfer spec (src_blocks_spec, dst_blocks_spec) +TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec] +# transfers are forwarded to workers by (src_medium, dst_medium) +TransferType = tuple[str, str] +# transfer result (job_id, success) +TransferResult = tuple[int, bool] + +logger = init_logger(__name__) + + +class OffloadingHandler(ABC): 
+ """ + OffloadingHandler class for managing asynchronous KV data transfers + + This class runs in the worker. + It kicks off async KV data transfer requests, and allows + collecting back completion statuses. + + The class provides the following primitives: + transfer_async() - kicks off a new transfer job + get_finished() - returns a list of newly finished job IDs. + """ + + @abstractmethod + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + """ + Initiates an asynchronous transfer of KV data. + + Args: + job_id: a unique ID that will be used when notifying back on + transfer completion. + spec: the (src, dst) spec of the KV data transfer. + + Returns: + True if transfer was submitted successfully. + """ + pass + + @abstractmethod + def get_finished(self) -> list[TransferResult]: + """ + Get transfers finished since last call. + + Returns: + A list of (job_id, success) of transfers. + """ + pass + + +class OffloadingWorker: + """ + OffloadingWorker class for managing asynchronous KV data transfers + using multiple OffloadingHandlers + + This class runs in the worker. + It kicks off async KV data transfer requests, by delegating + to one of its registered OffloadingHandlers, based on the transfer type. + + The class provides the following primitives: + register_handler() - registers a new handler to handle + a specific transfer type + transfer_async() - kicks off a new transfer job + using one of the registered handlers. + get_finished() - returns a list of newly finished job IDs + from all handlers. + """ + + def __init__(self): + self.handlers: set[OffloadingHandler] = set() + self.transfer_type_to_handler: dict[TransferType, OffloadingHandler] = {} + + def register_handler( + self, + src_cls: type[LoadStoreSpec], + dst_cls: type[LoadStoreSpec], + handler: OffloadingHandler, + ) -> None: + """ + Registers a new handler. + + Args: + src_cls: the source type of transfers handled by this handler. + dst_cls: the destination type of transfers handled by this handler. + handler: the handler that will handle transfers. + """ + transfer_type = (src_cls.medium(), dst_cls.medium()) + assert transfer_type not in self.transfer_type_to_handler + self.handlers.add(handler) + self.transfer_type_to_handler[transfer_type] = handler + + def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: + """ + Initiates an asynchronous transfer of KV data. + + Args: + job_id: a unique ID that will be used when notifying back on + transfer completion. + spec: the (src, dst) spec of the KV data transfer. + + Returns: + True if transfer was submitted successfully. + """ + src, dst = spec + transfer_type = (src.medium(), dst.medium()) + handler = self.transfer_type_to_handler.get(transfer_type) + assert handler is not None + + try: + success = handler.transfer_async(job_id, spec) + except Exception as e: + logger.warning( + "Exception in %r transfer %d: %r", + transfer_type, + job_id, + e, + exc_info=True, + ) + return False + + if not success: + logger.warning("Failed to submit %r transfer %d", transfer_type, job_id) + else: + logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec) + + return success + + def get_finished(self) -> list[TransferResult]: + """ + Get transfers finished since last call. + + Returns: + A list of (job_id, success) of transfers. 
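To illustrate how `OffloadingWorker` routes transfers by `(src_medium, dst_medium)`, here is a toy handler that completes every transfer immediately; it stands in for a real backend such as the CPU/GPU handler above and is not part of the patch:

```python
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (
    OffloadingHandler,
    OffloadingWorker,
    TransferResult,
    TransferSpec,
)


class NoopHandler(OffloadingHandler):
    """Toy handler that 'completes' every transfer synchronously."""

    def __init__(self):
        self.finished: list[TransferResult] = []

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        self.finished.append((job_id, True))
        return True

    def get_finished(self) -> list[TransferResult]:
        done, self.finished = self.finished, []
        return done


worker = OffloadingWorker()
handler = NoopHandler()
worker.register_handler(GPULoadStoreSpec, CPULoadStoreSpec, handler)

spec = (GPULoadStoreSpec([0, 1]), CPULoadStoreSpec([7, 8]))
assert worker.transfer_async(job_id=1, spec=spec)
print(worker.get_finished())  # [(1, True)]
```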
+ """ + finished = [] + for handler in self.handlers: + finished.extend(handler.get_finished()) + return finished diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3b0616952babf..32d2ed2961dee 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,11 +9,16 @@ from typing import Callable, Optional, Union import prometheus_client from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.prometheus import unregister_vllm_metrics -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import ( + CachingMetrics, + IterationStats, + MultiModalCacheStats, + SchedulerStats, +) from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) @@ -30,35 +35,40 @@ class StatLoggerBase(ABC): """ @abstractmethod - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - ... + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): - ... + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, + engine_idx: int = 0, + ): ... @abstractmethod - def log_engine_initialized(self): - ... + def log_engine_initialized(self): ... def log(self): # noqa pass class LoggingStatLogger(StatLoggerBase): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self.vllm_config = vllm_config self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() - # Prefix cache metrics. This cannot be reset. + + # Caching metrics. This cannot be reset. # TODO: Make the interval configurable. 
- self.prefix_caching_metrics = PrefixCachingMetrics() + self.prefix_caching_metrics = CachingMetrics() + self.mm_caching_metrics = CachingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() + kv_tranfer_config = self.vllm_config.kv_transfer_config + self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 @@ -81,30 +91,33 @@ class LoggingStatLogger(StatLoggerBase): return 0.0 return float(tracked_stats / delta_time) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, + engine_idx: int = 0, + ): """Log Stats to standard output.""" - if iteration_stats: self._track_iteration_stats(iteration_stats) if scheduler_stats is not None: - self.prefix_caching_metrics.observe( - scheduler_stats.prefix_cache_stats) + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) - + self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) + if kv_connector_stats := scheduler_stats.kv_connector_stats: + self.kv_connector_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats + if mm_cache_stats: + self.mm_caching_metrics.observe(mm_cache_stats) + def log(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) - generation_throughput = self._get_throughput( - self.num_generation_tokens, now) + generation_throughput = self._get_throughput(self.num_generation_tokens, now) self._reset(now) @@ -112,37 +125,56 @@ class LoggingStatLogger(StatLoggerBase): log_fn = logger.info if not any( - (prompt_throughput, generation_throughput, - self.last_prompt_throughput, self.last_generation_throughput)): + ( + prompt_throughput, + generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput, + ) + ): # Avoid log noise on an idle production system log_fn = logger.debug self.last_generation_throughput = generation_throughput self.last_prompt_throughput = prompt_throughput # Format and print output. 
- log_fn( - "Engine %03d: " - "Avg prompt throughput: %.1f tokens/s, " - "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Waiting: %d reqs, " - "GPU KV cache usage: %.1f%%, " + log_parts = [ + "Avg prompt throughput: %.1f tokens/s", + "Avg generation throughput: %.1f tokens/s", + "Running: %d reqs", + "Waiting: %d reqs", + "GPU KV cache usage: %.1f%%", "Prefix cache hit rate: %.1f%%", - self.engine_index, + ] + log_args = [ prompt_throughput, generation_throughput, scheduler_stats.num_running_reqs, scheduler_stats.num_waiting_reqs, scheduler_stats.kv_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, + ] + if not self.mm_caching_metrics.empty: + log_parts.append("MM cache hit rate: %.1f%%") + log_args.append(self.mm_caching_metrics.hit_rate * 100) + + log_fn( + "Engine %03d: " + ", ".join(log_parts), + self.engine_index, + *log_args, ) + self.spec_decoding_logging.log(log_fn=log_fn) + self.kv_connector_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: logger.info( "Engine %03d: vllm cache_config_info with initialization " - "after num_gpu_blocks is: %d", self.engine_index, - self.vllm_config.cache_config.num_gpu_blocks) + "after num_gpu_blocks is: %d", + self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks, + ) class PrometheusStatLogger(StatLoggerBase): @@ -151,9 +183,9 @@ class PrometheusStatLogger(StatLoggerBase): _histogram_cls = prometheus_client.Histogram _spec_decoding_cls = SpecDecodingProm - def __init__(self, - vllm_config: VllmConfig, - engine_indexes: Optional[list[int]] = None): + def __init__( + self, vllm_config: VllmConfig, engine_indexes: Optional[list[int]] = None + ): if engine_indexes is None: engine_indexes = [0] self.engine_indexes = engine_indexes @@ -162,25 +194,19 @@ class PrometheusStatLogger(StatLoggerBase): self.vllm_config = vllm_config # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future - self.show_hidden_metrics = \ - vllm_config.observability_config.show_hidden_metrics + self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics labelnames = ["model_name", "engine"] model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - if (len(self.engine_indexes) > 1 - and vllm_config.speculative_config is not None): - raise NotImplementedError("Prometheus metrics with Spec Decoding " - "with >1 EngineCore per AsyncLLM is not " - "supported yet.") - spec_decode_labelvalues = [ - vllm_config.model_config.served_model_name, - str(self.engine_indexes[0]) - ] + spec_decode_labelvalues: dict[int, list[str]] = { + idx: [model_name, str(idx)] for idx in engine_indexes + } + self.spec_decoding_prom = self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, - spec_decode_labelvalues) + vllm_config.speculative_config, labelnames, spec_decode_labelvalues + ) # # Scheduler state @@ -189,80 +215,128 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_running = make_per_engine( + gauge_scheduler_running, engine_indexes, model_name + ) gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of 
requests waiting to be processed.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_waiting = make_per_engine( + gauge_scheduler_waiting, engine_indexes, model_name + ) # # GPU cache # - # Deprecated in 0.9 - Renamed as vllm:kv_cache_usage_perc - # TODO: in 0.10, only enable if show_hidden_metrics=True - gauge_gpu_cache_usage = self._gauge_cls( - name="vllm:gpu_cache_usage_perc", - documentation=( - "GPU KV-cache usage. 1 means 100 percent usage." - "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), - multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_gpu_cache_usage = make_per_engine(gauge_gpu_cache_usage, - engine_indexes, - model_name) + # Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + gauge_gpu_cache_usage = self._gauge_cls( + name="vllm:gpu_cache_usage_perc", + documentation=( + "GPU KV-cache usage. 1 means 100 percent usage." + "DEPRECATED: Use vllm:kv_cache_usage_perc instead." + ), + multiprocess_mode="mostrecent", + labelnames=labelnames, + ) + self.gauge_gpu_cache_usage = make_per_engine( + gauge_gpu_cache_usage, engine_indexes, model_name + ) - # Deprecated in 0.9 - Renamed as vllm:prefix_cache_queries - # TODO: in 0.10, only enable if show_hidden_metrics=True - counter_gpu_prefix_cache_queries = self._counter_cls( - name="vllm:gpu_prefix_cache_queries", - documentation=( - "GPU prefix cache queries, in terms of number of queried" - "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."), - labelnames=labelnames) - self.counter_gpu_prefix_cache_queries = make_per_engine( - counter_gpu_prefix_cache_queries, engine_indexes, model_name) + # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + counter_gpu_prefix_cache_queries = self._counter_cls( + name="vllm:gpu_prefix_cache_queries", + documentation=( + "GPU prefix cache queries, in terms of number of queried" + "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead." + ), + labelnames=labelnames, + ) + self.counter_gpu_prefix_cache_queries = make_per_engine( + counter_gpu_prefix_cache_queries, engine_indexes, model_name + ) - # Deprecated in 0.9 - Renamed as vllm:prefix_cache_hits - # TODO: in 0.10, only enable if show_hidden_metrics=True - counter_gpu_prefix_cache_hits = self._counter_cls( - name="vllm:gpu_prefix_cache_hits", - documentation=( - "GPU prefix cache hits, in terms of number of cached " - "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), - labelnames=labelnames) - self.counter_gpu_prefix_cache_hits = make_per_engine( - counter_gpu_prefix_cache_hits, engine_indexes, model_name) + # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + counter_gpu_prefix_cache_hits = self._counter_cls( + name="vllm:gpu_prefix_cache_hits", + documentation=( + "GPU prefix cache hits, in terms of number of cached " + "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead." 
+ ), + labelnames=labelnames, + ) + self.counter_gpu_prefix_cache_hits = make_per_engine( + counter_gpu_prefix_cache_hits, engine_indexes, model_name + ) gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", documentation="KV-cache usage. 1 means 100 percent usage.", - labelnames=labelnames) - self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.gauge_kv_cache_usage = make_per_engine( + gauge_kv_cache_usage, engine_indexes, model_name + ) counter_prefix_cache_queries = self._counter_cls( name="vllm:prefix_cache_queries", documentation=( - "Prefix cache queries, in terms of number of queried tokens."), - labelnames=labelnames) + "Prefix cache queries, in terms of number of queried tokens." + ), + labelnames=labelnames, + ) self.counter_prefix_cache_queries = make_per_engine( - counter_prefix_cache_queries, engine_indexes, model_name) + counter_prefix_cache_queries, engine_indexes, model_name + ) counter_prefix_cache_hits = self._counter_cls( name="vllm:prefix_cache_hits", - documentation=( - "Prefix cache hits, in terms of number of cached tokens."), - labelnames=labelnames) + documentation=("Prefix cache hits, in terms of number of cached tokens."), + labelnames=labelnames, + ) self.counter_prefix_cache_hits = make_per_engine( - counter_prefix_cache_hits, engine_indexes, model_name) + counter_prefix_cache_hits, engine_indexes, model_name + ) + + # + # Multi-modal cache + # + + counter_mm_cache_queries = self._counter_cls( + name="vllm:mm_cache_queries", + documentation=( + "Multi-modal cache queries, in terms of number of queried items." + ), + labelnames=labelnames, + ) + self.counter_mm_cache_queries = make_per_engine( + counter_mm_cache_queries, engine_indexes, model_name + ) + + counter_mm_cache_hits = self._counter_cls( + name="vllm:mm_cache_hits", + documentation=( + "Multi-modal cache hits, in terms of number of cached items." 
+ ), + labelnames=labelnames, + ) + self.counter_mm_cache_hits = make_per_engine( + counter_mm_cache_hits, engine_indexes, model_name + ) # # Counters @@ -270,36 +344,43 @@ class PrometheusStatLogger(StatLoggerBase): counter_num_preempted_reqs = self._counter_cls( name="vllm:num_preemptions", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_num_preempted_reqs = make_per_engine( - counter_num_preempted_reqs, engine_indexes, model_name) + counter_num_preempted_reqs, engine_indexes, model_name + ) counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens", documentation="Number of prefill tokens processed.", - labelnames=labelnames) - self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.counter_prompt_tokens = make_per_engine( + counter_prompt_tokens, engine_indexes, model_name + ) counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens", documentation="Number of generation tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_generation_tokens = make_per_engine( - counter_generation_tokens, engine_indexes, model_name) + counter_generation_tokens, engine_indexes, model_name + ) - self.counter_request_success: dict[FinishReason, dict[ - int, prometheus_client.Counter]] = {} + self.counter_request_success: dict[ + FinishReason, dict[int, prometheus_client.Counter] + ] = {} counter_request_success_base = self._counter_cls( name="vllm:request_success", documentation="Count of successfully processed requests.", - labelnames=labelnames + ["finished_reason"]) + labelnames=labelnames + ["finished_reason"], + ) for reason in FinishReason: self.counter_request_success[reason] = { - idx: - counter_request_success_base.labels(model_name, str(idx), - str(reason)) + idx: counter_request_success_base.labels( + model_name, str(idx), str(reason) + ) for idx in engine_indexes } @@ -310,18 +391,21 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_prompt_tokens_request = make_per_engine( - histogram_num_prompt_tokens_request, engine_indexes, model_name) + histogram_num_prompt_tokens_request, engine_indexes, model_name + ) histogram_num_generation_tokens_request = self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_generation_tokens_request = make_per_engine( - histogram_num_generation_tokens_request, engine_indexes, - model_name) + histogram_num_generation_tokens_request, engine_indexes, model_name + ) # TODO: This metric might be incorrect in case of using multiple # api_server counts which uses prometheus mp. 
@@ -329,38 +413,42 @@ class PrometheusStatLogger(StatLoggerBase): histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 - ], - labelnames=labelnames) + buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + labelnames=labelnames, + ) self.histogram_iteration_tokens = make_per_engine( - histogram_iteration_tokens, engine_indexes, model_name) + histogram_iteration_tokens, engine_indexes, model_name + ) histogram_max_num_generation_tokens_request = self._histogram_cls( name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", + documentation="Histogram of maximum number of requested generation tokens.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_num_generation_tokens_request = make_per_engine( - histogram_max_num_generation_tokens_request, engine_indexes, - model_name) + histogram_max_num_generation_tokens_request, engine_indexes, model_name + ) histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", buckets=[1, 2, 5, 10, 20], - labelnames=labelnames) - self.histogram_n_request = make_per_engine(histogram_n_request, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.histogram_n_request = make_per_engine( + histogram_n_request, engine_indexes, model_name + ) histogram_max_tokens_request = self._histogram_cls( name="vllm:request_params_max_tokens", documentation="Histogram of the max_tokens request parameter.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_tokens_request = make_per_engine( - histogram_max_tokens_request, engine_indexes, model_name) + histogram_max_tokens_request, engine_indexes, model_name + ) # # Histogram of timing intervals @@ -369,72 +457,202 @@ class PrometheusStatLogger(StatLoggerBase): name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, - 2560.0 + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 640.0, + 2560.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_to_first_token = make_per_engine( - histogram_time_to_first_token, engine_indexes, model_name) + histogram_time_to_first_token, engine_indexes, model_name + ) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." + "DEPRECATED: Use vllm:inter_token_latency_seconds instead." 
+ ), buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_per_output_token = make_per_engine( - histogram_time_per_output_token, engine_indexes, model_name) + histogram_time_per_output_token, engine_indexes, model_name + ) + + histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter-token latency in seconds.", + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + labelnames=labelnames, + ) + self.histogram_inter_token_latency = make_per_engine( + histogram_inter_token_latency, engine_indexes, model_name + ) + + histogram_request_time_per_output_token = self._histogram_cls( + name="vllm:request_time_per_output_token_seconds", + documentation="Histogram of time_per_output_token_seconds per request.", + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + labelnames=labelnames, + ) + self.histogram_request_time_per_output_token = make_per_engine( + histogram_request_time_per_output_token, engine_indexes, model_name + ) request_latency_buckets = [ - 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, - 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 120.0, + 240.0, + 480.0, + 960.0, + 1920.0, + 7680.0, ] histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of e2e request latency in seconds.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_e2e_time_request = make_per_engine( - histogram_e2e_time_request, engine_indexes, model_name) + histogram_e2e_time_request, engine_indexes, model_name + ) histogram_queue_time_request = self._histogram_cls( name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", + documentation="Histogram of time spent in WAITING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_queue_time_request = make_per_engine( - histogram_queue_time_request, engine_indexes, model_name) + histogram_queue_time_request, engine_indexes, model_name + ) histogram_inference_time_request = self._histogram_cls( name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", + documentation="Histogram of time spent in RUNNING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_inference_time_request = make_per_engine( - histogram_inference_time_request, engine_indexes, model_name) + histogram_inference_time_request, engine_indexes, model_name + ) histogram_prefill_time_request = self._histogram_cls( name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", + documentation="Histogram of time spent in PREFILL phase for request.", 
buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_prefill_time_request = make_per_engine( - histogram_prefill_time_request, engine_indexes, model_name) + histogram_prefill_time_request, engine_indexes, model_name + ) histogram_decode_time_request = self._histogram_cls( name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", + documentation="Histogram of time spent in DECODE phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_decode_time_request = make_per_engine( - histogram_decode_time_request, engine_indexes, model_name) + histogram_decode_time_request, engine_indexes, model_name + ) # # LoRA metrics @@ -445,23 +663,21 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: if len(self.engine_indexes) > 1: - raise NotImplementedError( - "LoRA in DP mode is not supported yet.") + raise NotImplementedError("LoRA in DP mode is not supported yet.") self.labelname_max_lora = "max_lora" self.labelname_waiting_lora_adapters = "waiting_lora_adapters" self.labelname_running_lora_adapters = "running_lora_adapters" self.max_lora = vllm_config.lora_config.max_loras - self.gauge_lora_info = \ - self._gauge_cls( - name="vllm:lora_requests_info", - documentation="Running stats on lora requests.", - multiprocess_mode="sum", - labelnames=[ - self.labelname_max_lora, - self.labelname_waiting_lora_adapters, - self.labelname_running_lora_adapters, - ], - ) + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + multiprocess_mode="sum", + labelnames=[ + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + self.labelname_running_lora_adapters, + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): metrics_info = config_obj.metrics_info() @@ -487,92 +703,124 @@ class PrometheusStatLogger(StatLoggerBase): metrics_info["engine"] = str(engine_index) info_gauge.labels(**metrics_info).set(1) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, + engine_idx: int = 0, + ): """Log to prometheus.""" if scheduler_stats is not None: self.gauge_scheduler_running[engine_idx].set( - scheduler_stats.num_running_reqs) + scheduler_stats.num_running_reqs + ) self.gauge_scheduler_waiting[engine_idx].set( - scheduler_stats.num_waiting_reqs) + scheduler_stats.num_waiting_reqs + ) - self.gauge_gpu_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) - self.gauge_kv_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) + if self.show_hidden_metrics: + self.gauge_gpu_cache_usage[engine_idx].set( + scheduler_stats.kv_cache_usage + ) + self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage) - self.counter_gpu_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + if self.show_hidden_metrics: + self.counter_gpu_prefix_cache_queries[engine_idx].inc( + scheduler_stats.prefix_cache_stats.queries + ) + 
self.counter_gpu_prefix_cache_hits[engine_idx].inc( + scheduler_stats.prefix_cache_stats.hits + ) self.counter_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) + scheduler_stats.prefix_cache_stats.queries + ) self.counter_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + scheduler_stats.prefix_cache_stats.hits + ) if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + scheduler_stats.spec_decoding_stats, engine_idx + ) + + if mm_cache_stats is not None: + self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) + self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) if iteration_stats is None: return self.counter_num_preempted_reqs[engine_idx].inc( - iteration_stats.num_preempted_reqs) - self.counter_prompt_tokens[engine_idx].inc( - iteration_stats.num_prompt_tokens) + iteration_stats.num_preempted_reqs + ) + self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens[engine_idx].inc( - iteration_stats.num_generation_tokens) + iteration_stats.num_generation_tokens + ) self.histogram_iteration_tokens[engine_idx].observe( - iteration_stats.num_prompt_tokens + \ - iteration_stats.num_generation_tokens) + iteration_stats.num_prompt_tokens + iteration_stats.num_generation_tokens + ) for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: - self.histogram_max_num_generation_tokens_request[ - engine_idx].observe(max_gen_tokens) + self.histogram_max_num_generation_tokens_request[engine_idx].observe( + max_gen_tokens + ) for n_param in iteration_stats.n_params_iter: self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: self.histogram_time_to_first_token[engine_idx].observe(ttft) - for tpot in iteration_stats.time_per_output_tokens_iter: - self.histogram_time_per_output_token[engine_idx].observe(tpot) + for itl in iteration_stats.inter_token_latencies_iter: + self.histogram_inter_token_latency[engine_idx].observe(itl) + self.histogram_time_per_output_token[engine_idx].observe(itl) for finished_request in iteration_stats.finished_requests: - self.counter_request_success[ - finished_request.finish_reason][engine_idx].inc() + self.counter_request_success[finished_request.finish_reason][ + engine_idx + ].inc() self.histogram_e2e_time_request[engine_idx].observe( - finished_request.e2e_latency) + finished_request.e2e_latency + ) self.histogram_queue_time_request[engine_idx].observe( - finished_request.queued_time) + finished_request.queued_time + ) self.histogram_prefill_time_request[engine_idx].observe( - finished_request.prefill_time) + finished_request.prefill_time + ) self.histogram_inference_time_request[engine_idx].observe( - finished_request.inference_time) + finished_request.inference_time + ) self.histogram_decode_time_request[engine_idx].observe( - finished_request.decode_time) + finished_request.decode_time + ) self.histogram_num_prompt_tokens_request[engine_idx].observe( - finished_request.num_prompt_tokens) + finished_request.num_prompt_tokens + ) self.histogram_num_generation_tokens_request[engine_idx].observe( - finished_request.num_generation_tokens) + finished_request.num_generation_tokens + ) + self.histogram_request_time_per_output_token[engine_idx].observe( + finished_request.mean_time_per_output_token + ) if finished_request.max_tokens_param: self.histogram_max_tokens_request[engine_idx].observe( - 
finished_request.max_tokens_param) + finished_request.max_tokens_param + ) if self.gauge_lora_info is not None: - running_lora_adapters = \ - ",".join(iteration_stats.running_lora_adapters.keys()) - waiting_lora_adapters = \ - ",".join(iteration_stats.waiting_lora_adapters.keys()) + running_lora_adapters = ",".join( + iteration_stats.running_lora_adapters.keys() + ) + waiting_lora_adapters = ",".join( + iteration_stats.waiting_lora_adapters.keys() + ) lora_info_labels = { self.labelname_running_lora_adapters: running_lora_adapters, self.labelname_waiting_lora_adapters: waiting_lora_adapters, self.labelname_max_lora: self.max_lora, } - self.gauge_lora_info.labels(**lora_info_labels)\ - .set_to_current_time() + self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) @@ -585,8 +833,9 @@ PromMetric = Union[ ] -def make_per_engine(metric: PromMetric, engine_idxs: list[int], - model_name: str) -> dict[int, PromMetric]: +def make_per_engine( + metric: PromMetric, engine_idxs: list[int], model_name: str +) -> dict[int, PromMetric]: return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs} @@ -635,15 +884,22 @@ class StatLoggerManager: vllm_config: VllmConfig, engine_idxs: Optional[list[int]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_default_loggers: bool = True, + client_count: int = 1, ): self.engine_idxs = engine_idxs if engine_idxs else [0] - factories: list[StatLoggerFactory] + factories: list[StatLoggerFactory] = [] if custom_stat_loggers is not None: - factories = custom_stat_loggers - else: - factories = [] - if logger.isEnabledFor(logging.INFO): + factories.extend(custom_stat_loggers) + + if enable_default_loggers and logger.isEnabledFor(logging.INFO): + if client_count > 1: + logger.warning( + "AsyncLLM created with api_server_count more than 1; " + "disabling stats logging to avoid incomplete stats." + ) + else: factories.append(LoggingStatLogger) # engine_idx: StatLogger @@ -654,12 +910,12 @@ class StatLoggerManager: for logger_factory in factories: # If we get a custom prometheus logger, use that # instead. This is typically used for the ray case. - if (isinstance(logger_factory, type) - and issubclass(logger_factory, PrometheusStatLogger)): + if isinstance(logger_factory, type) and issubclass( + logger_factory, PrometheusStatLogger + ): prometheus_factory = logger_factory continue - loggers.append(logger_factory(vllm_config, - engine_idx)) # type: ignore + loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore self.per_engine_logger_dict[engine_idx] = loggers # For Prometheus, need to share the metrics between EngineCores. 
@@ -670,6 +926,7 @@ class StatLoggerManager: self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + mm_cache_stats: Optional[MultiModalCacheStats] = None, engine_idx: Optional[int] = None, ): if engine_idx is None: @@ -677,10 +934,19 @@ class StatLoggerManager: per_engine_loggers = self.per_engine_logger_dict[engine_idx] for logger in per_engine_loggers: - logger.record(scheduler_stats, iteration_stats, engine_idx) + logger.record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) - self.prometheus_logger.record(scheduler_stats, iteration_stats, - engine_idx) + self.prometheus_logger.record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) def log(self): for per_engine_loggers in self.per_engine_logger_dict.values(): diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index 61ba5d66cb31a..5823737968f9a 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -16,9 +16,7 @@ _prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None def setup_multiprocess_prometheus(): - """Set up prometheus multiprocessing directory if not already configured. - - """ + """Set up prometheus multiprocessing directory if not already configured.""" global _prometheus_multiproc_dir if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: @@ -27,19 +25,22 @@ def setup_multiprocess_prometheus(): # cleaned up upon exit. _prometheus_multiproc_dir = tempfile.TemporaryDirectory() os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name - logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", - _prometheus_multiproc_dir.name) + logger.debug( + "Created PROMETHEUS_MULTIPROC_DIR at %s", _prometheus_multiproc_dir.name + ) else: - logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup." + ) -def get_prometheus_registry(): - """Get the appropriate prometheus registry based on multiprocessing +def get_prometheus_registry() -> CollectorRegistry: + """Get the appropriate prometheus registry based on multiprocessing configuration. - + Returns: Registry: A prometheus registry """ @@ -54,11 +55,11 @@ def get_prometheus_registry(): def unregister_vllm_metrics(): """Unregister any existing vLLM collectors from the prometheus registry. - + This is useful for testing and CI/CD where metrics may be registered multiple times across test runs. - - Also, in case of multiprocess, we need to unregister the metrics from the + + Also, in case of multiprocess, we need to unregister the metrics from the global registry. 
""" registry = REGISTRY diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index ae8f9447e9c8b..a6fe2062f70cf 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -11,14 +11,13 @@ try: from ray.util.metrics import Metric except ImportError: ray_metrics = None +import regex as re class RayPrometheusMetric: - def __init__(self): if ray_metrics is None: - raise ImportError( - "RayPrometheusMetric requires Ray to be installed.") + raise ImportError("RayPrometheusMetric requires Ray to be installed.") self.metric: Metric = None @@ -37,30 +36,46 @@ class RayPrometheusMetric: f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" ) - self.metric.set_default_tags( - dict(zip(self.metric._tag_keys, labels))) + self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) return self + @staticmethod + def _get_sanitized_opentelemetry_name(name: str) -> str: + """ + For compatibility with Ray + OpenTelemetry, the metric name must be + sanitized. In particular, this replaces disallowed character (e.g., ':') + with '_' in the metric name. + Allowed characters: a-z, A-Z, 0-9, _ + + # ruff: noqa: E501 + Ref: https://github.com/open-telemetry/opentelemetry-cpp/blob/main/sdk/src/metrics/instrument_metadata_validator.cc#L22-L23 + Ref: https://github.com/ray-project/ray/blob/master/src/ray/stats/metric.cc#L107 + """ + + return re.sub(r"[^a-zA-Z0-9_]", "_", name) + class RayGaugeWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - multiprocess_mode: Optional[str] = ""): - + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + multiprocess_mode: Optional[str] = "", + ): # All Ray metrics are keyed by WorkerId, so multiprocess modes like # "mostrecent", "all", "sum" do not apply. This logic can be manually # implemented at the observability layer (Prometheus/Grafana). 
del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None - self.metric = ray_metrics.Gauge(name=name, - description=documentation, - tag_keys=labelnames_tuple) + name = self._get_sanitized_opentelemetry_name(name) + self.metric = ray_metrics.Gauge( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def set(self, value: Union[int, float]): return self.metric.set(value) @@ -74,14 +89,17 @@ class RayCounterWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Counter to provide same API as prometheus_client.Counter""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None - self.metric = ray_metrics.Counter(name=name, - description=documentation, - tag_keys=labelnames_tuple) + name = self._get_sanitized_opentelemetry_name(name) + self.metric = ray_metrics.Counter( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def inc(self, value: Union[int, float] = 1.0): if value == 0: @@ -93,17 +111,22 @@ class RayHistogramWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Histogram to provide same API as prometheus_client.Histogram""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - buckets: Optional[list[float]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None + name = self._get_sanitized_opentelemetry_name(name) boundaries = buckets if buckets else [] - self.metric = ray_metrics.Histogram(name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries) + self.metric = ray_metrics.Histogram( + name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries, + ) def observe(self, value: Union[int, float]): return self.metric.observe(value) diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py index 4d6e599841541..5d50fa9461d0c 100644 --- a/vllm/v1/metrics/reader.py +++ b/vllm/v1/metrics/reader.py @@ -17,6 +17,7 @@ class Metric: in some cases a single vLLM instance may have multiple metrics with the same name but different sets of labels. """ + name: str labels: dict[str, str] @@ -24,6 +25,7 @@ class Metric: @dataclass class Counter(Metric): """A monotonically increasing integer counter.""" + value: int @@ -34,12 +36,14 @@ class Vector(Metric): This type - which doesn't exist in Prometheus - models one very specific metric, vllm:spec_decode_num_accepted_tokens_per_pos. """ + values: list[int] @dataclass class Gauge(Metric): """A numerical value that can go up or down.""" + value: float @@ -58,6 +62,7 @@ class Histogram(Metric): The sum property is the total sum of all observed values. 
""" + count: int sum: float buckets: dict[str, int] @@ -87,7 +92,8 @@ def get_metrics_snapshot() -> list[Metric]: samples = _get_samples(metric) for s in samples: collected.append( - Gauge(name=metric.name, labels=s.labels, value=s.value)) + Gauge(name=metric.name, labels=s.labels, value=s.value) + ) elif metric.type == "counter": samples = _get_samples(metric, "_total") if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": @@ -99,16 +105,15 @@ def get_metrics_snapshot() -> list[Metric]: # accepted tokens using a Counter labeled with 'position'. # We convert these into a vector of integer values. # - for labels, values in _digest_num_accepted_by_pos_samples( - samples): + for labels, values in _digest_num_accepted_by_pos_samples(samples): collected.append( - Vector(name=metric.name, labels=labels, values=values)) + Vector(name=metric.name, labels=labels, values=values) + ) else: for s in samples: collected.append( - Counter(name=metric.name, - labels=s.labels, - value=int(s.value))) + Counter(name=metric.name, labels=s.labels, value=int(s.value)) + ) elif metric.type == "histogram": # @@ -122,21 +127,24 @@ def get_metrics_snapshot() -> list[Metric]: count_samples = _get_samples(metric, "_count") sum_samples = _get_samples(metric, "_sum") for labels, buckets, count_value, sum_value in _digest_histogram( - bucket_samples, count_samples, sum_samples): + bucket_samples, count_samples, sum_samples + ): collected.append( - Histogram(name=metric.name, - labels=labels, - buckets=buckets, - count=count_value, - sum=sum_value)) + Histogram( + name=metric.name, + labels=labels, + buckets=buckets, + count=count_value, + sum=sum_value, + ) + ) else: raise AssertionError(f"Unknown metric type {metric.type}") return collected -def _get_samples(metric: PromMetric, - suffix: Optional[str] = None) -> list[Sample]: +def _get_samples(metric: PromMetric, suffix: Optional[str] = None) -> list[Sample]: name = (metric.name + suffix) if suffix is not None else metric.name return [s for s in metric.samples if s.name == name] @@ -148,8 +156,7 @@ def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]: def _digest_histogram( - bucket_samples: list[Sample], count_samples: list[Sample], - sum_samples: list[Sample] + bucket_samples: list[Sample], count_samples: list[Sample], sum_samples: list[Sample] ) -> list[tuple[dict[str, str], dict[str, int], int, float]]: # # In the case of DP, we have an indigestable @@ -192,20 +199,25 @@ def _digest_histogram( labels_key = frozenset(s.labels.items()) sums_by_labels[labels_key] = s.value - assert set(buckets_by_labels.keys()) == set( - counts_by_labels.keys()) == set(sums_by_labels.keys()) + assert ( + set(buckets_by_labels.keys()) + == set(counts_by_labels.keys()) + == set(sums_by_labels.keys()) + ) output = [] label_keys = list(buckets_by_labels.keys()) for k in label_keys: labels = dict(k) - output.append((labels, buckets_by_labels[k], counts_by_labels[k], - sums_by_labels[k])) + output.append( + (labels, buckets_by_labels[k], counts_by_labels[k], sums_by_labels[k]) + ) return output def _digest_num_accepted_by_pos_samples( - samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]: + samples: list[Sample], +) -> list[tuple[dict[str, str], list[int]]]: # # In the case of DP, we have an indigestable # per-position-per-engine count as a list of diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 9a80460261e02..8d21efca87f44 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,8 +2,9 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from collections import deque from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -13,17 +14,127 @@ if TYPE_CHECKING: @dataclass -class PrefixCacheStats: - """Stores prefix cache hit statistics.""" - # Whether reset_prefix_cache was invoked. +class BaseCacheStats: + """Stores cache hit statistics.""" + reset: bool = False - # The number of requests in this update. + """Whether the cache was reset.""" + requests: int = 0 - # The number of queries in these requests. Note that "queries" here - # means the number of tokens that were queried from the cache. + """The number of requests in this update.""" + queries: int = 0 - # The number of hits in these requests. + """The number of queries in these requests.""" + hits: int = 0 + """The number of hits in these requests.""" + + +class CachingMetrics: + """Metrics for caching with a hit rate of the most recent N requests. + Args: + interval: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, max_recent_requests: int = 1000) -> None: + super().__init__() + + self.max_recent_requests = max_recent_requests + # The current aggregated values. + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue = deque[tuple[int, int, int]]() + + def observe(self, stats: BaseCacheStats): + """Observe the prefix caching for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. + + When there are more than `max_recent_requests` requests, the oldest set + of requests are removed from the metrics. + + Args: + stats: The prefix cache stats. + """ + # reset_prefix_cache was invoked before the current update. + # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # DO NOT appending empty stats to avoid helpful info get kicked out + # due to sliding window. + if stats.requests == 0: + return + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats until number of requests does not exceed + # the limit. + # NOTE: We preserve the latest added stats regardless. 
+ while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def empty(self) -> bool: + """Return true if no requests have been observed.""" + return self.aggregated_requests == 0 + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + +@dataclass +class PrefixCacheStats(BaseCacheStats): + """ + Stores prefix cache hit statistics. + - `reset`: Whether `reset_prefix_cache` was invoked. + - `queries`: Refers to the number of tokens that were queried. + """ + + preempted_requests: int = 0 + """The number of previously preempted requests in this update.""" + + preempted_queries: int = 0 + """The `queries` number for preempted requests.""" + + preempted_hits: int = 0 + """The `hits` number for preempted requests.""" + + +@dataclass +class MultiModalCacheStats(BaseCacheStats): + """ + Stores multi-modal cache hit statistics. + - `reset`: Whether `reset_mm_cache` was invoked. + - `queries`: Refers to the number of multi-modal data items + that were queried. + """ @dataclass @@ -39,10 +150,10 @@ class SchedulerStats: kv_cache_usage: float = 0.0 - prefix_cache_stats: PrefixCacheStats = field( - default_factory=PrefixCacheStats) + prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats: Optional[dict[str, Any]] = None num_corrupted_reqs: int = 0 @@ -59,7 +170,7 @@ class RequestStateStats: num_generation_tokens: int = 0 - # This is a engine frontend timestamp (wall-clock) + # This is an engine frontend timestamp (wall-clock) arrival_time: float = 0.0 # These are engine core timestamps (monotonic) @@ -68,6 +179,9 @@ class RequestStateStats: first_token_ts: float = 0.0 last_token_ts: float = 0.0 + # first token latency + first_token_latency: float = 0.0 + @dataclass class FinishedRequestStats: @@ -82,6 +196,7 @@ class FinishedRequestStats: prefill_time: float = 0.0 inference_time: float = 0.0 decode_time: float = 0.0 + mean_time_per_output_token: float = 0.0 class IterationStats: @@ -96,18 +211,27 @@ class IterationStats: self.max_num_generation_tokens_iter: list[int] = [] self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] - self.time_per_output_tokens_iter: list[float] = [] + self.inter_token_latencies_iter: list[float] = [] self.waiting_lora_adapters: dict[str, int] = {} self.running_lora_adapters: dict[str, int] = {} + def __repr__(self) -> str: + field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items()) + return f"{self.__class__.__name__}({field_to_value_str})" + def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" return self.iteration_timestamp - start - def update_from_output(self, output: "EngineCoreOutput", - engine_core_timestamp: float, is_prefilling: bool, - prompt_len: int, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_output( + self, + output: "EngineCoreOutput", + 
engine_core_timestamp: float, + is_prefilling: bool, + prompt_len: int, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): num_new_generation_tokens = len(output.new_token_ids) self.num_generation_tokens += num_new_generation_tokens @@ -116,28 +240,36 @@ class IterationStats: first_token_latency = self._time_since(req_stats.arrival_time) self.time_to_first_tokens_iter.append(first_token_latency) + req_stats.first_token_latency = first_token_latency req_stats.num_generation_tokens += num_new_generation_tokens # Process request-level engine core events if output.events is not None: - self.update_from_events(output.request_id, output.events, - is_prefilling, req_stats, lora_stats) + self.update_from_events( + output.request_id, output.events, is_prefilling, req_stats, lora_stats + ) # Process the batch-level "new tokens" engine core event if is_prefilling: req_stats.first_token_ts = engine_core_timestamp else: - tpot = engine_core_timestamp - req_stats.last_token_ts - self.time_per_output_tokens_iter.append(tpot) + itl = engine_core_timestamp - req_stats.last_token_ts + self.inter_token_latencies_iter.append(itl) req_stats.last_token_ts = engine_core_timestamp - def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], - is_prefilling: bool, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_events( + self, + req_id: str, + events: list["EngineCoreEvent"], + is_prefilling: bool, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType + for event in events: if event.type == EngineCoreEventType.QUEUED: req_stats.queued_ts = event.timestamp @@ -151,10 +283,13 @@ class IterationStats: self.num_preempted_reqs += 1 LoRARequestStates.preempted_request(lora_stats, req_id) - def update_from_finished_request(self, finish_reason: "FinishReason", - num_prompt_tokens: int, - max_tokens_param: Optional[int], - req_stats: RequestStateStats): + def update_from_finished_request( + self, + finish_reason: "FinishReason", + num_prompt_tokens: int, + max_tokens_param: Optional[int], + req_stats: RequestStateStats, + ): e2e_latency = self._time_since(req_stats.arrival_time) # Queued interval is from first QUEUED event to first SCHEDULED @@ -172,16 +307,25 @@ class IterationStats: # Any preemptions during prefill or decode are included inference_time = req_stats.last_token_ts - req_stats.scheduled_ts - finished_req = \ - FinishedRequestStats(finish_reason=finish_reason, - e2e_latency=e2e_latency, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=req_stats.num_generation_tokens, - max_tokens_param=max_tokens_param, - queued_time=queued_time, - prefill_time=prefill_time, - inference_time=inference_time, - decode_time=decode_time) + # Do not count the token generated by the prefill phase + mean_time_per_output_token = ( + decode_time / (req_stats.num_generation_tokens - 1) + if req_stats.num_generation_tokens - 1 > 0 + else 0 + ) + + finished_req = FinishedRequestStats( + finish_reason=finish_reason, + e2e_latency=e2e_latency, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, + queued_time=queued_time, + prefill_time=prefill_time, + inference_time=inference_time, + decode_time=decode_time, + mean_time_per_output_token=mean_time_per_output_token, + ) self.finished_requests.append(finished_req) @@ -191,24 +335,24 @@ class LoRARequestStates: def __init__(self): 
self.lora_name_to_stats: dict[str, LoRAStats] = {} - def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: + def get_stats(self, req_state: "RequestState") -> Optional[LoRAStats]: if req_state.lora_name is None: return None if req_state.lora_name not in self.lora_name_to_stats: self.lora_name_to_stats[req_state.lora_name] = LoRAStats() return self.lora_name_to_stats[req_state.lora_name] - def add_request(self, req_state: 'RequestState'): + def add_request(self, req_state: "RequestState"): if (lora_stats := self.get_stats(req_state)) is not None: lora_stats.waiting_requests.add(req_state.request_id) - def finish_request(self, req_state: 'RequestState'): + def finish_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] lora_stats.running_requests.remove(req_state.request_id) - def abort_request(self, req_state: 'RequestState'): + def abort_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] @@ -231,14 +375,15 @@ class LoRARequestStates: lora_stats.running_requests.remove(request_id) lora_stats.waiting_requests.add(request_id) - def update_iteration_stats(self, - iteration_stats: Optional[IterationStats]): + def update_iteration_stats(self, iteration_stats: Optional[IterationStats]): if iteration_stats is None: return for lora_name, stats in self.lora_name_to_stats.items(): if stats.waiting_requests: - iteration_stats.waiting_lora_adapters[lora_name] = \ - len(stats.waiting_requests) + iteration_stats.waiting_lora_adapters[lora_name] = len( + stats.waiting_requests + ) if stats.running_requests: - iteration_stats.running_lora_adapters[lora_name] = \ - len(stats.running_requests) + iteration_stats.running_lora_adapters[lora_name] = len( + stats.running_requests + ) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 965df89dda7f2..d647b207575cf 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import NamedTuple, Optional +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, NamedTuple, Optional, Union import torch +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats + class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] @@ -25,7 +28,6 @@ class LogprobsLists(NamedTuple): class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] @@ -41,18 +43,18 @@ class LogprobsTensors(NamedTuple): ) @staticmethod - def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> "LogprobsTensors": + def empty_cpu( + num_positions: int, num_tokens_per_position: int + ) -> "LogprobsTensors": """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( - (num_positions, num_tokens_per_position), - dtype=torch.int32, - device="cpu") + (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu" + ) logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) - selected_token_ranks = torch.empty(num_positions, - dtype=torch.int32, - device="cpu") + selected_token_ranks = torch.empty( + num_positions, dtype=torch.int32, device="cpu" + 
) return LogprobsTensors( logprob_token_ids=logprob_token_ids, logprobs=logprobs, @@ -60,9 +62,13 @@ class LogprobsTensors(NamedTuple): ) +# [num_reqs, <dynamic>] +# The shape of each element depends on the pooler used +PoolerOutput = Union[torch.Tensor, list[torch.Tensor]] + + @dataclass class SamplerOutput: - # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. @@ -76,15 +82,28 @@ class KVConnectorOutput: # [req_ids] finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None + kv_connector_stats: Optional["KVConnectorStats"] = None + # IDs of externally computed KV blocks that failed to load. + # Requests referencing these blocks should be rescheduled to recompute them. + invalid_block_ids: set[int] = field(default_factory=set) + + def is_empty(self): + return ( + not self.finished_sending + and not self.finished_recving + and not self.kv_connector_stats + and not self.invalid_block_ids + ) # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass class ModelRunnerOutput: - # [num_reqs] req_ids: list[str] + # req_id -> index + req_id_to_index: dict[str, int] # num_reqs x num_generated_tokens # num_generated_tokens is the number of tokens @@ -112,18 +131,33 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +# ModelRunnerOutput wrapper for async scheduling. +class AsyncModelRunnerOutput(ABC): + @abstractmethod + def get_output(self) -> ModelRunnerOutput: + """Get the ModelRunnerOutput for this async output. + + This is a blocking call that waits until the results are ready, which + might involve copying device tensors to the host. + This method should only be called once per AsyncModelRunnerOutput. 
+ """ + pass + + @dataclass class DraftTokenIds: - # [num_reqs] req_ids: list[str] # num_reqs x num_draft_tokens draft_token_ids: list[list[int]] -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - num_nans_in_logits=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + num_nans_in_logits=None, +) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 46506d272e90a..36ae5b40a3138 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -29,13 +29,13 @@ class PoolingCursor: ) def is_partial_prefill(self): - return not torch.all( - self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] @@ -44,34 +44,40 @@ class PoolingMetadata: def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], - prompt_token_ids=None if self.prompt_token_ids is None else - self.prompt_token_ids[indices], + prompt_token_ids=None + if self.prompt_token_ids is None + else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], pooling_cursor=None - if self.pooling_cursor is None else self.pooling_cursor[indices], + if self.pooling_cursor is None + else self.pooling_cursor[indices], ) - def build_pooling_cursor(self, num_scheduled_tokens: list[int], - device: torch.device): - self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, - self.prompt_lens, device) + def build_pooling_cursor( + self, num_scheduled_tokens: list[int], device: torch.device + ): + self.pooling_cursor = build_pooling_cursor( + num_scheduled_tokens, self.prompt_lens, device + ) -def build_pooling_cursor(num_scheduled_tokens: list[int], - prompt_lens: torch.Tensor, device: torch.device): +def build_pooling_cursor( + num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device +): assert len(prompt_lens) == len(num_scheduled_tokens) n_seq = len(num_scheduled_tokens) index = list(range(n_seq)) num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") - cumsum = torch.zeros(n_seq + 1, - dtype=torch.int64, - pin_memory=pin_memory, - device="cpu") + cumsum = torch.zeros( + n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" + ) torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) cumsum = cumsum.to(device, non_blocking=True) - return PoolingCursor(index=index, - first_token_indices_gpu=cumsum[:n_seq], - last_token_indices_gpu=cumsum[1:] - 1, - prompt_lens_cpu=prompt_lens, - num_scheduled_tokens_cpu=num_scheduled_tokens) + return PoolingCursor( + index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens, + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 4e99a9ccef46e..ac6e583099bc6 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,15 +3,22 @@ import enum import time +from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange +import torch + +from 
vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.utils import is_list_of -from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, - EngineCoreRequest, FinishReason) +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.v1.engine import ( + EngineCoreEvent, + EngineCoreEventType, + EngineCoreRequest, + FinishReason, +) from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList @@ -21,25 +28,23 @@ if TYPE_CHECKING: class Request: - def __init__( self, request_id: str, - prompt_token_ids: list[int], - multi_modal_kwargs: Optional[list[MultiModalKwargsItem]], - multi_modal_hashes: Optional[list[str]], - multi_modal_placeholders: Optional[list[PlaceholderRange]], + prompt_token_ids: Optional[list[int]], sampling_params: Optional[SamplingParams], pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, arrival_time: Optional[float] = None, + prompt_embeds: Optional[torch.Tensor] = None, + mm_features: Optional[list[MultiModalFeatureSpec]] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, - block_hasher: Optional[Callable[["Request"], - list["BlockHash"]]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -50,8 +55,7 @@ class Request: self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request - self.arrival_time = arrival_time if arrival_time is not None else \ - time.time() + self.arrival_time = arrival_time if arrival_time is not None else time.time() self.status = RequestStatus.WAITING self.use_structured_output = False @@ -68,44 +72,45 @@ class Request: # Generative models. assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - if sampling_params.guided_decoding is not None: + if sampling_params.structured_outputs is not None: self.status = RequestStatus.WAITING_FOR_FSM self.use_structured_output = True if sampling_params.extra_args is not None: - self.kv_transfer_params = \ - sampling_params.extra_args.get("kv_transfer_params") + self.kv_transfer_params = sampling_params.extra_args.get( + "kv_transfer_params" + ) else: - raise ValueError( - "sampling_params and pooling_params can't both be unset") + raise ValueError("sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids - self.num_prompt_tokens = len(self.prompt_token_ids) + self.prompt_embeds = prompt_embeds + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds + ) self._output_token_ids: list[int] = [] - self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self._all_token_ids: list[int] = ( + self.prompt_token_ids.copy() + if self.prompt_token_ids is not None + else [0] * self.num_prompt_tokens + ) self.num_output_placeholders = 0 # Used in async scheduling. 
self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 self.cache_salt: Optional[str] = cache_salt # Multi-modal related - self.mm_positions = multi_modal_placeholders or [] - self.mm_kwargs = multi_modal_kwargs or [] - self.mm_hashes: list[str] = multi_modal_hashes or [] - self.num_encoder_inputs = len(self.mm_kwargs) + self.mm_features = mm_features or [] + self.num_encoder_inputs = len(self.mm_features) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # Sanity check - assert len(self.mm_kwargs) == len(self.mm_positions) - if self.mm_hashes: - assert len(self.mm_kwargs) == len(self.mm_hashes) - # Read-only views # Prevent directly appending to these lists since # they should also be updated simultaneously. self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) - + # trace_headers + self.trace_headers = trace_headers # State # The number of tokens with prefix cache hits. self.num_cached_tokens = -1 @@ -114,42 +119,40 @@ class Request: # indicates that the output is corrupted self.num_nans_in_logits = 0 + # The number of requests being preempted by the scheduler + self.num_preemptions = 0 + self.block_hashes: list[BlockHash] = [] - self.get_hash_new_full_blocks: Optional[Callable[ - [], list[BlockHash]]] = None + self.get_hash_new_full_blocks: Optional[Callable[[], list[BlockHash]]] = None if block_hasher is not None: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() @classmethod def from_engine_core_request( - cls, request: EngineCoreRequest, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + cls, + request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]], ) -> "Request": - if request.mm_kwargs is not None: - mm_kwargs_lst = list(request.mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), ( - "mm_kwargs was not updated in EngineCore.add_request") - else: - mm_kwargs_lst = None - return cls( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, - multi_modal_kwargs=mm_kwargs_lst, - multi_modal_hashes=request.mm_hashes, - multi_modal_placeholders=request.mm_placeholders, + prompt_embeds=request.prompt_embeds, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( - sampling_params=request.sampling_params) \ - if request.sampling_params else None, + sampling_params=request.sampling_params + ) + if request.sampling_params + else None, cache_salt=request.cache_salt, priority=request.priority, + trace_headers=request.trace_headers, block_hasher=block_hasher, ) @@ -190,8 +193,8 @@ class Request: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_tokens(self, input_id: int) -> int: - assert input_id < len(self.mm_positions) - num_tokens = self.mm_positions[input_id].length + assert input_id < len(self.mm_features) + num_tokens = self.mm_features[input_id].mm_position.length return num_tokens def record_event( @@ -210,6 +213,7 @@ class Request: class RequestStatus(enum.IntEnum): """Status of a request.""" + WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() WAITING_FOR_REMOTE_KVS = enum.auto() @@ -230,8 +234,7 @@ class RequestStatus(enum.IntEnum): return status > RequestStatus.PREEMPTED 
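# --- Illustrative aside (not part of the diff): the `is_finished` check above
# relies on enum declaration order: every status declared after PREEMPTED is a
# terminal state, so a single integer comparison suffices. A minimal sketch
# with hypothetical member names:
import enum

class _Status(enum.IntEnum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Terminal states must be declared after PREEMPTED.
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()

def _is_finished(status: _Status) -> bool:
    return status > _Status.PREEMPTED

assert not _is_finished(_Status.RUNNING)
assert _is_finished(_Status.FINISHED_ABORTED)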
@staticmethod - def get_finished_reason( - status: "RequestStatus") -> Union[FinishReason, None]: + def get_finished_reason(status: "RequestStatus") -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 8220269162951..06b9e4b12d7b6 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -1,21 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +import inspect import itertools +from abc import abstractmethod from collections.abc import Sequence +from functools import partial from typing import TYPE_CHECKING, Optional, Union import torch from vllm.logger import init_logger -from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor) -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) -from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, - LogitsProcessors) +from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor.builtin import ( + LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + process_dict_updates, +) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) +from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors if TYPE_CHECKING: from vllm.config import VllmConfig @@ -24,10 +33,17 @@ logger = init_logger(__name__) # Error message when the user tries to initialize vLLM with a pooling model # and custom logitsproces -STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" " logits processors.") +STR_POOLING_REJECTS_LOGITSPROCS = ( + "Pooling models do not support custom logits processors." +) -LOGITSPROCS_GROUP = 'vllm.logits_processors' +# Error message when the user tries to initialize vLLM with speculative +# decoding enabled and custom logitsprocs +STR_SPEC_DEC_REJECTS_LOGITSPROCS = ( + "Custom logits processors are not supported when speculative decoding is enabled."
+) + +LOGITSPROCS_GROUP = "vllm.logits_processors" BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ MinTokensLogitsProcessor, @@ -39,36 +55,33 @@ BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: """Load all installed logit processor plugins""" - import sys - - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) if len(installed_logitsprocs_plugins) == 0: - logger.debug("No logitsprocs plugins installed (group %s).", - LOGITSPROCS_GROUP) + logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP) return [] # Load logitsprocs plugins - logger.debug("Loading installed logitsprocs plugins (group %s):", - LOGITSPROCS_GROUP) + logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP) classes: list[type[LogitsProcessor]] = [] for entrypoint in installed_logitsprocs_plugins: try: - logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", - entrypoint.name, entrypoint.value) + logger.debug( + "- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, + entrypoint.value, + ) classes.append(entrypoint.load()) except Exception as e: raise RuntimeError( - f"Failed to load LogitsProcessor plugin {entrypoint}") from e + f"Failed to load LogitsProcessor plugin {entrypoint}" + ) from e return classes def _load_logitsprocs_by_fqcns( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], ) -> list[type[LogitsProcessor]]: """Load logit processor types, identifying them by fully-qualified class names (FQCNs). 
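# --- Illustrative aside (not part of the diff): `_load_logitsprocs_plugins`
# above discovers third-party logits processors through the
# "vllm.logits_processors" entry-point group and loads each entry point into a
# class object. A minimal standalone sketch of that discovery pattern (the
# function name below is hypothetical):
from importlib.metadata import entry_points

def _discover_logitsprocs(group: str = "vllm.logits_processors") -> list[type]:
    # Iterate the installed entry points registered under `group` and load
    # each one; a plugin package would declare such an entry point in its
    # packaging metadata.
    return [ep.load() for ep in entry_points(group=group)]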
@@ -93,13 +106,14 @@ def _load_logitsprocs_by_fqcns( logger.debug( "%s additional custom logits processors specified, checking whether " - "they need to be loaded.", len(logits_processors)) + "they need to be loaded.", + len(logits_processors), + ) classes: list[type[LogitsProcessor]] = [] for ldx, logitproc in enumerate(logits_processors): if isinstance(logitproc, type): - logger.debug(" - Already-loaded logit processor: %s", - logitproc.__name__) + logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__) if not issubclass(logitproc, LogitsProcessor): raise ValueError( f"{logitproc.__name__} is not a subclass of LogitsProcessor" @@ -125,8 +139,7 @@ def _load_logitsprocs_by_fqcns( if not isinstance(obj, type): raise ValueError("Loaded logit processor must be a type.") if not issubclass(obj, LogitsProcessor): - raise ValueError( - f"{obj.__name__} must be a subclass of LogitsProcessor") + raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor") classes.append(obj) return classes @@ -149,13 +162,13 @@ def _load_custom_logitsprocs( A list of all loaded logitproc types """ from vllm.platforms import current_platform + if current_platform.is_tpu(): # No logitsprocs specified by caller # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs return [] - return (_load_logitsprocs_plugins() + - _load_logitsprocs_by_fqcns(logits_processors)) + return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors) def build_logitsprocs( @@ -168,18 +181,148 @@ def build_logitsprocs( if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) - logger.debug("Skipping logits processor loading because pooling models" - " do not support logits processors.") + logger.debug( + "Skipping logits processor loading because pooling models" + " do not support logits processors." + ) return LogitsProcessors() + + # Check if speculative decoding is enabled. + if vllm_config.speculative_config: + if custom_logitsprocs: + raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS) + logger.warning( + "min_p, logit_bias, and min_tokens parameters won't currently work " + "with speculative decoding enabled." + ) + return LogitsProcessors() + custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) return LogitsProcessors( - ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + ctor(vllm_config, device, is_pin_memory) + for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes + ) + ) + + +class AdapterLogitsProcessor(LogitsProcessor): + """Wrapper for per-request logits processors + + To wrap a specific per-request logits processor, + * Subclass `AdapterLogitsProcessor` + * Implement `self.is_argmax_invariant()` base-class method + * Implement `self.new_req_logits_processor(params)` + + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be + overridden in general. However, to implement custom constructor behavior - + especially any logic which operates on or stores `vllm_config`, `device`, + or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)` + must be overridden and the override must call + `super().__init__(vllm_config, device, is_pin_memory)` + """ + + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): + """Subclass must invoke + `super().__init__(vllm_config, device, is_pin_memory)`. 
+ + Subclass constructor may find it useful to utilize the `vllm_config`, + `device` and `is_pin_memory` argument. However regardless of whether + these arguments are used, the vLLM logits processor interface requires + all three arguments to be present. + """ + + # Map req index -> logits processor state + # + # State representation is a partial[Tensor] comprising a request-level + # logits processor with the output token ids argument and (if required) + # the prompt token ids argument pre-populated + # + # Note that the partial carries a *reference* to output token ids, and + # will thus always operate on the list as it is currently, not as it + # was when the partial was created. + self.req_info: dict[int, partial[torch.Tensor]] = {} + + @abstractmethod + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """Consume request info; return a per-request logits processor. + + Return None if logits processor does not need to be applied to request + + Args: + params: request sampling params + + Returns: + None if logits processor should not be applied to request; otherwise + returns a `RequestLogitsProcessor` instance + + """ + raise NotImplementedError + + def _new_state( + self, + params: SamplingParams, + prompt_ids: Optional[list[int]], + output_ids: list[int], + ) -> Optional[partial[torch.Tensor]]: + """Return state representation for new request + + Returns None if logits processor is not applicable to request + + Args: + params: request sampling params + prompt_ids: request prompt token ids + output_ids: decoded tokens so far for this request + + Returns: + logits processor partial[Tensor] or None + + """ + if req_lp := self.new_req_logits_processor(params): + args = ( + [prompt_ids, output_ids] + if (len(inspect.signature(req_lp).parameters) == 3) + else [output_ids] + ) + return partial(req_lp, *args) + return None + + def update_state(self, batch_update: Optional[BatchUpdate]): + process_dict_updates( + self.req_info, + batch_update, + self._new_state, + ) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.req_info: + # Apply per-request logits processors to corresponding rows of + # logits tensor + for req_idx, req_lp in self.req_info.items(): + req_logits = logits[req_idx] + new_logits = req_lp(req_logits) + if new_logits is not req_logits: + # Modify logits tensor row in-place if necessary + logits[req_idx] = new_logits + return logits __all__ = [ - "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", - "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", - "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", - "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP" + "LogitsProcessor", + "LogitBiasLogitsProcessor", + "MinPLogitsProcessor", + "MinTokensLogitsProcessor", + "BatchUpdate", + "BatchUpdateBuilder", + "MoveDirectionality", + "LogitsProcessors", + "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", + "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor", ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 00dd757489ca0..3c3ddda7fb3e4 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,38 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional, TypeVar 
import torch -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) +from vllm import SamplingParams +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) if TYPE_CHECKING: from vllm.config import VllmConfig +T = TypeVar("T") + class MinPLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 - self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=is_pin_memory) + self.min_p_cpu_tensor = torch.zeros( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) else: self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor @@ -90,8 +94,7 @@ class MinPLogitsProcessor(LogitsProcessor): if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] if self.use_double_tensor: - self.min_p.copy_(self.min_p_cpu_tensor[:size], - non_blocking=True) + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -101,28 +104,27 @@ class MinPLogitsProcessor(LogitsProcessor): # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Adjust min_p adjusted_min_p = max_probabilities.mul_(self.min_p) # Identify valid tokens using threshold comparison invalid_token_mask = probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') + logits[invalid_token_mask] = -float("inf") return logits class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device self.pin_memory = is_pin_memory self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) + self.logits_slice = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """Logit bias can rebalance token probabilities and change the @@ -130,64 +132,30 @@ class LogitBiasLogitsProcessor(LogitsProcessor): return False def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - needs_update: bool = False - # Process added requests. 
- for index, params, _, _ in batch_update.added: - if lb := params.logit_bias: - self.biases[index] = lb - needs_update = True - else: - # Drop biases metadata at batch index - if self.biases.pop(index, None) is not None: - # If a new request replaces an old request which - # specified biases, we should update processor tensors - needs_update = True - - if self.biases: - # Process removed requests. - for index in batch_update.removed: - if self.biases.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and swap (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.biases.pop(a_index, None)) is None: - if self.biases.pop(b_index, None) is not None: - needs_update = True - else: - self.biases[b_index] = a_entry - needs_update = True - else: - a_entry = self.biases.pop(a_index, None) - if (b_entry := self.biases.pop(b_index, None)) is not None: - self.biases[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.biases[b_index] = a_entry - needs_update = True + needs_update = process_dict_updates( + self.biases, batch_update, lambda params, _, __: params.logit_bias or None + ) # Update tensors if needed. if needs_update: - reqs, tok_ids, biases = [], [], [] + reqs: list[int] = [] + tok_ids: list[int] = [] + biases: list[float] = [] for req, lb in self.biases.items(): reqs.extend([req] * len(lb)) tok_ids.extend(lb.keys()) biases.extend(lb.values()) self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.biases: @@ -196,77 +164,45 @@ class LogitBiasLogitsProcessor(LogitsProcessor): class MinTokensLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): # index -> (min_toks, output_token_ids, stop_token_ids) self.device = device self.pin_memory = is_pin_memory self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) - self.logits_slice: tuple[torch.Tensor, - torch.Tensor] = (self._device_tensor( - [], torch.int32), - self._device_tensor( - [], torch.int32)) + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome of the argmax operation in greedy sampling.""" return False + @staticmethod + def add_request( + params: SamplingParams, _: Optional[list[int]], output_tok_ids: list[int] + ) -> Optional[tuple[int, Sequence[int], set[int]]]: + min_tokens = params.min_tokens + if not min_tokens or len(output_tok_ids) >= min_tokens: + return None + return min_tokens, output_tok_ids, params.all_stop_token_ids + def update_state(self, batch_update: 
Optional[BatchUpdate]): - needs_update = False - - if batch_update: - # Process added requests. - for index, params, _, output_tok_ids in batch_update.added: - if ((min_tokens := params.min_tokens) - and len(output_tok_ids) < min_tokens): - # Replace request metadata at batch index - self.min_toks[index] = (min_tokens, output_tok_ids, - params.all_stop_token_ids) - needs_update = True - else: - # Drop min_toks metadata at batch index - if self.min_toks.pop(index, None) is not None: - # If a new request replaces an old request which - # specified min_toks, we should update processor tensors - needs_update = True - - if self.min_toks: - # Process removed requests. - for index in batch_update.removed: - if self.min_toks.pop(index, None): - needs_update = True - - # Process moved requests, unidirectional (a->b) and - # swapped (a<->b) - for a_index, b_index, direct in batch_update.moved: - if direct == MoveDirectionality.UNIDIRECTIONAL: - if (a_entry := self.min_toks.pop(a_index, - None)) is None: - if self.min_toks.pop(b_index, None) is not None: - needs_update = True - else: - self.min_toks[b_index] = a_entry - needs_update = True - else: - a_entry = self.min_toks.pop(a_index, None) - if (b_entry := self.min_toks.pop(b_index, - None)) is not None: - self.min_toks[a_index] = b_entry - needs_update = True - if a_entry is not None: - self.min_toks[b_index] = a_entry - needs_update = True - + needs_update = process_dict_updates( + self.min_toks, batch_update, self.add_request + ) if self.min_toks: # Check for any requests that have attained their min tokens. - to_remove = tuple(index for index, (min_toks, out_tok_ids, - _) in self.min_toks.items() - if len(out_tok_ids) >= min_toks) + to_remove = tuple( + index + for index, (min_toks, out_tok_ids, _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks + ) if to_remove: needs_update = True for index in to_remove: @@ -280,18 +216,59 @@ class MinTokensLogitsProcessor(LogitsProcessor): reqs.extend([req] * len(stop_tok_ids)) tok_ids.extend(stop_tok_ids) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: # Inhibit EOS token for requests which have not reached min length logits[self.logits_slice] = -float("inf") return logits + + +def process_dict_updates( + req_entries: dict[int, T], + batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], Optional[T]], +) -> bool: + """Utility function to update dict state for sparse LogitsProcessors.""" + + if not batch_update: + # Nothing to do. + return False + + updated = False + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None: + req_entries[index] = state + updated = True + elif req_entries.pop(index, None) is not None: + updated = True + + if req_entries: + # Process removed requests. 
+ for index in batch_update.removed: + if req_entries.pop(index, None): + updated = True + + # Process moved requests, unidirectional (a->b) and + # swapped (a<->b) + for a_index, b_index, direct in batch_update.moved: + a_entry = req_entries.pop(a_index, None) + b_entry = req_entries.pop(b_index, None) + if a_entry is not None: + req_entries[b_index] = a_entry + updated = True + if b_entry is not None: + updated = True + if direct == MoveDirectionality.SWAP: + req_entries[a_index] = b_entry + + return updated diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 12b4db24bff88..713bd21d38554 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -21,21 +21,22 @@ class MoveDirectionality(Enum): SWAP = auto() +# Batch indices of any removed requests. +RemovedRequest = int + # (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. -AddedRequest = tuple[int, SamplingParams, list[int], list[int]] +AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]] # (index 1, index 2, directionality) tuples representing # one-way moves or two-way swaps of requests in batch MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - @dataclass(frozen=True) class BatchUpdate: """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch # Metadata for requests added to, removed from, and moved @@ -44,21 +45,32 @@ class BatchUpdate: # Key assumption: the `output_tok_ids` list (which is an element of each # tuple in `added`) is a reference to the request's running output tokens # list; via this reference, the logits processors always see the latest - # list of generated output tokens + # list of generated output tokens. + # + # NOTE: + # * Added or moved requests may replace existing requests with the same + # index. + # * Operations should be processed in the following order: + # - removed, added, moved removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] added: Sequence[AddedRequest] + moved: Sequence[MovedRequest] class LogitsProcessor(ABC): - @abstractmethod - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool) -> None: + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ) -> None: raise NotImplementedError @abstractmethod def apply(self, logits: torch.Tensor) -> torch.Tensor: + """Apply LogitsProcessor to batch logits tensor. + + The updated tensor must be returned but may be + modified in-place. + """ raise NotImplementedError @abstractmethod @@ -80,7 +92,7 @@ class LogitsProcessor(ABC): to each forward pass. Args: - batch_update is non-None iff there have been - changes to the batch makeup. + batch_update: Non-None iff there have been changes + to the batch makeup. 
""" raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 31cece58c7db5..a601f66415818 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -4,10 +4,12 @@ from collections.abc import Iterator from itertools import chain from typing import TYPE_CHECKING, Optional -from vllm.v1.sample.logits_processor.interface import (AddedRequest, - BatchUpdate, - MovedRequest, - RemovedRequest) +from vllm.v1.sample.logits_processor.interface import ( + AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest, +) if TYPE_CHECKING: from vllm.v1.sample.logits_processor.interface import LogitsProcessor @@ -36,18 +38,18 @@ class BatchUpdateBuilder: _removed: list[RemovedRequest] _is_removed_sorted: bool - moved: list[MovedRequest] added: list[AddedRequest] + moved: list[MovedRequest] def __init__( self, removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, added: Optional[list[AddedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, ) -> None: self._removed = removed or [] - self.moved = moved or [] self.added = added or [] + self.moved = moved or [] self._is_removed_sorted = False # Used to track changes in the pooling case @@ -81,8 +83,9 @@ class BatchUpdateBuilder: index: request index """ if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") + raise RuntimeError( + "Cannot register new removed request after self.removed has been read." + ) self._removed.append(index) self.batch_changed = True @@ -107,8 +110,8 @@ class BatchUpdateBuilder: """Returns True if there were any changes to the batch.""" self._is_removed_sorted = False self._removed.clear() - self.moved.clear() self.added.clear() + self.moved.clear() batch_changed = self.batch_changed self.batch_changed = False return batch_changed @@ -116,7 +119,7 @@ class BatchUpdateBuilder: def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """Generate a logitsprocs batch update data structure and reset internal batch update builder state. 
- + Args: batch_size: current persistent batch size @@ -146,14 +149,17 @@ class LogitsProcessors: """Encapsulates initialized logitsproc objects.""" def __init__( - self, - logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self, logitsprocs: Optional[Iterator["LogitsProcessor"]] = None + ) -> None: self.argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = [] if logitsprocs: for logitproc in logitsprocs: - (self.argmax_invariant if logitproc.is_argmax_invariant() else - self.non_argmax_invariant).append(logitproc) + ( + self.argmax_invariant + if logitproc.is_argmax_invariant() + else self.non_argmax_invariant + ).append(logitproc) @property def all(self) -> Iterator["LogitsProcessor"]: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9d6a87cea3d07..e252ace97d27e 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -11,7 +11,6 @@ from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] all_greedy: bool all_random: bool @@ -41,3 +40,6 @@ class SamplingMetadata: # Loaded logits processors logitsprocs: LogitsProcessors + + # Speculative token ids + spec_token_ids: Optional[list[list[int]]] = None diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py index 1b699565f26f2..8e2c798dd35ff 100644 --- a/vllm/v1/sample/ops/bad_words.py +++ b/vllm/v1/sample/ops/bad_words.py @@ -17,10 +17,7 @@ def _apply_bad_words_single_batch( prefix_length = len(bad_word_ids) - 1 last_token_id = bad_word_ids[-1] - if prefix_length > 0: - actual_prefix = past_tokens_ids[-prefix_length:] - else: - actual_prefix = [] + actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else [] expected_prefix = bad_word_ids[:prefix_length] assert len(actual_prefix) == len(expected_prefix) @@ -35,5 +32,21 @@ def apply_bad_words( past_tokens_ids: list[list[int]], ) -> None: for i, bad_words_ids in bad_words_token_ids.items(): - _apply_bad_words_single_batch(logits[i], bad_words_ids, - past_tokens_ids[i]) + _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i]) + + +def apply_bad_words_with_drafts( + logits: torch.Tensor, + bad_words_token_ids: dict[int, list[list[int]]], + past_tokens_ids: list[list[int]], + num_draft_tokens: list[int], +) -> None: + start_idx = 0 + for i, bad_words_ids in bad_words_token_ids.items(): + for draft_idx in range(num_draft_tokens[i]): + _apply_bad_words_single_batch( + logits[start_idx + draft_idx], + bad_words_ids, + past_tokens_ids[start_idx + draft_idx], + ) + start_idx += num_draft_tokens[i] diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py index 82875b7c84522..cf36d46e13fda 100644 --- a/vllm/v1/sample/ops/logprobs.py +++ b/vllm/v1/sample/ops/logprobs.py @@ -8,8 +8,7 @@ from vllm.platforms import current_platform @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def batched_count_greater_than(x: torch.Tensor, - values: torch.Tensor) -> torch.Tensor: +def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor: """ Counts elements in each row of x that are greater than the corresponding value in values. 
Use torch.compile to generate an optimized kernel for diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 5d54f6679a1a9..e49b8db47800d 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -19,15 +19,20 @@ def apply_all_penalties( Applies presence, frequency and repetition penalties to the logits. """ _, vocab_size = logits.shape - output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, - logits.device) - return apply_penalties(logits, prompt_token_ids, output_tokens_t, - presence_penalties, frequency_penalties, - repetition_penalties) + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) + return apply_penalties( + logits, + prompt_token_ids, + output_tokens_t, + presence_penalties, + frequency_penalties, + repetition_penalties, + ) -def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, - device: torch.device) -> torch.Tensor: +def _convert_to_tensors( + output_token_ids: list[list[int]], vocab_size: int, device: torch.device +) -> torch.Tensor: """ Convert the different list data structures to tensors. """ diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 7bd4a5a380ac0..5fa7a9ad44cd4 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -8,7 +8,7 @@ import torch.nn as nn from packaging import version from vllm import envs -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -16,6 +16,7 @@ logger = init_logger(__name__) try: import flashinfer.sampling + is_flashinfer_available = True except ImportError: is_flashinfer_available = False @@ -29,22 +30,22 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__( - self, - logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: + def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: super().__init__() self.logprobs_mode = logprobs_mode # flashinfer optimization does not apply if intermediate # logprobs/logits after top_k/top_p need to be returned - if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, - LogprobsMode.PROCESSED_LOGPROBS - ) and current_platform.is_cuda(): + if ( + logprobs_mode not in ("processed_logits", "processed_logprobs") + and current_platform.is_cuda() + ): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): logger.warning_once( "FlashInfer version >= 0.2.3 required. " - "Falling back to default sampling implementation.") + "Falling back to default sampling implementation." + ) self.forward = self.forward_native elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for @@ -55,28 +56,29 @@ class TopKTopPSampler(nn.Module): # None means False, while in V1, None means True. This is # why we use the condition # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - logger.info_once( - "Using FlashInfer for top-p & top-k sampling.") + logger.info_once("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: logger.warning_once( "FlashInfer is available, but it is not enabled. " "Falling back to the PyTorch-native implementation of " "top-p & top-k sampling. 
For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + "please set VLLM_USE_FLASHINFER_SAMPLER=1." + ) self.forward = self.forward_native else: logger.warning_once( "FlashInfer is not available. Falling back to the PyTorch-" "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer.") + "best performance, please install FlashInfer." + ) self.forward = self.forward_native + elif current_platform.is_cpu(): + self.forward = self.forward_cpu else: self.forward = self.forward_native - if current_platform.is_tpu(): - self.apply_top_k_top_p = apply_top_k_top_p_tpu - else: - self.apply_top_k_top_p = apply_top_k_top_p + + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -92,9 +94,9 @@ class TopKTopPSampler(nn.Module): """ logits = self.apply_top_k_top_p(logits, k, p) logits_to_return = None - if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + if self.logprobs_mode == "processed_logits": logits_to_return = logits - elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators), logits_to_return @@ -112,64 +114,58 @@ class TopKTopPSampler(nn.Module): # CPU-GPU synchronization while `flashinfer_sample` does. if (k is None and p is None) or generators: if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + logger.debug_once( + "FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation." + ) return self.forward_native(logits, generators, k, p) - assert self.logprobs_mode not in ( - LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS - ), "FlashInfer does not support returning logits/logprobs" + assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), ( + "FlashInfer does not support returning logits/logprobs" + ) # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. return flashinfer_sample(logits.contiguous(), k, p, generators), None + def forward_cpu( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + PyTorch-native implementation of top-k and top-p sampling for CPU. -def apply_top_k_top_p_tpu( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - """ - Apply top-k and top-p optimized for TPU. + The logits tensor may be updated in-place. + """ + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == "processed_logits": + logits_to_return = logits + elif self.logprobs_mode == "processed_logprobs": + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) - This algorithm avoids using torch.scatter which is extremely slow on TPU. - This is achieved by finding a "cut-off" element in the original logit, and - after thresholding the logit using this cut-off, the remaining elements - shall constitute the top-p set. 
+ # Note: this is a workaround for + # https://github.com/pytorch/pytorch/pull/151218 + @torch.compile(dynamic=True) + def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: + probs = logits.softmax(dim=-1, dtype=torch.float32) + q = torch.empty_like(probs) + q.exponential_() + return probs.div(q).argmax(dim=-1).view(-1) - Note: in the case of tie (i.e. multipple cut-off elements present in the - logit), all tie elements are included in the top-p set. In other words, - this function does not break ties. Instead, these tie tokens have equal - chance of being chosen during final sampling, so we can consider the tie - being broken then. - """ - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) + if len(generators) != logits.shape[0]: + return compiled_random_sample(logits), logits_to_return + else: + probs = logits.softmax(dim=-1, dtype=torch.float32) + q = torch.empty_like(probs) + q.exponential_() + for i, generator in generators.items(): + q[i].exponential_(generator=generator) - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. - no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits + return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return def apply_top_k_top_p( @@ -289,15 +285,18 @@ def flashinfer_sample( # Top-p only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( - probs, p, deterministic=True) + probs, p, deterministic=True + ) elif p is None: # Top-k only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( - probs, k, deterministic=True) + probs, k, deterministic=True + ) else: # Both top-k and top-p. next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits( - logits, k, p, deterministic=True) + logits, k, p, deterministic=True + ) return next_token_ids.view(-1) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index b2354c53302ad..72cee8c73969a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -8,6 +8,8 @@ import torch.nn as nn from vllm.logger import init_logger from vllm.triton_utils import tl, triton from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts +from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -17,7 +19,7 @@ PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1 # Maximum number of speculative draft tokens allowed per request in a single # step. This value is chosen to be large enough to handle typical use cases. 
-MAX_SPEC_LEN = 32 +MAX_SPEC_LEN = 128 class RejectionSampler(nn.Module): @@ -54,7 +56,7 @@ class RejectionSampler(nn.Module): bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - ''' + """ Args: metadata: Metadata for spec decoding. @@ -68,7 +70,7 @@ class RejectionSampler(nn.Module): different requests are flattened into a single tensor because this is the shape of the output logits. NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): + bonus_token_ids (torch.Tensor): A tensor containing bonus tokens. Shape is [batch_size, 1]. Bonus tokens are added to the end of the sequence if all proposed tokens are accepted. We generate the bonus tokens @@ -81,8 +83,16 @@ class RejectionSampler(nn.Module): Returns: output_token_ids (torch.Tensor): A tensor containing the final output token IDs. - ''' + """ assert metadata.max_spec_len <= MAX_SPEC_LEN + + # Use float32 for the target_logits. + target_logits = target_logits.to(torch.float32) + + target_logits = self.apply_logits_processors( + target_logits, sampling_metadata, metadata + ) + # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the # `compute_probs` function. @@ -123,14 +133,103 @@ class RejectionSampler(nn.Module): """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. - valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & - (output_token_ids_np < vocab_size)) + valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( + output_token_ids_np < vocab_size + ) outputs = [ - row[valid_mask[i]].tolist() - for i, row in enumerate(output_token_ids_np) + row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) ] return outputs + def apply_logits_processors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + metadata: SpecDecodeMetadata, + ) -> torch.Tensor: + has_penalties = not sampling_metadata.no_penalties + any_penalties_or_bad_words = ( + sampling_metadata.bad_words_token_ids or has_penalties + ) + + output_token_ids = sampling_metadata.output_token_ids + if any_penalties_or_bad_words: + output_token_ids = self._combine_outputs_with_spec_tokens( + output_token_ids, + sampling_metadata.spec_token_ids, + ) + + # Calculate indices of target logits. + if sampling_metadata.allowed_token_ids_mask is not None or has_penalties: + num_requests = len(sampling_metadata.output_token_ids) + num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") + original_indices = torch.arange(num_requests, device="cpu") + repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens) + repeat_indices = repeat_indices_cpu.to( + device=logits.device, non_blocking=True + ) + logits = self.apply_penalties( + logits, sampling_metadata, metadata, repeat_indices, output_token_ids + ) + + # Apply allowed token ids. + if sampling_metadata.allowed_token_ids_mask is not None: + token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices] + logits.masked_fill_(token_mask, float("-inf")) + + # Apply bad words exclusion. 
+ if bad_words_token_ids := sampling_metadata.bad_words_token_ids: + apply_bad_words_with_drafts( + logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens + ) + + return logits + + @staticmethod + def apply_penalties( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + metadata: SpecDecodeMetadata, + repeat_indices: torch.Tensor, + output_token_ids: list[list[int]], + ) -> torch.Tensor: + if sampling_metadata.no_penalties: + return logits + + assert sampling_metadata.prompt_token_ids is not None + + prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices] + presence_penalties = sampling_metadata.presence_penalties[repeat_indices] + frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices] + repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices] + + logits = apply_all_penalties( + logits, + prompt_token_ids, + presence_penalties, + frequency_penalties, + repetition_penalties, + output_token_ids, + ) + return logits + + @staticmethod + def _combine_outputs_with_spec_tokens( + output_token_ids: list[list[int]], + spec_token_ids: Optional[list[list[int]]] = None, + ) -> list[list[int]]: + if spec_token_ids is None: + return output_token_ids + + result = [] + for out, spec in zip(output_token_ids, spec_token_ids): + if len(spec) == 0: + continue + result.append(out) + for i in range(len(spec) - 1): + result.append([*result[-1], spec[i]]) + return result + def rejection_sample( # [num_tokens] @@ -164,12 +263,12 @@ def rejection_sample( assert target_probs.shape == (num_tokens, vocab_size) # Create output buffer. - output_token_ids = torch.empty( + output_token_ids = torch.full( (batch_size, max_spec_len + 1), + PLACEHOLDER_TOKEN_ID, dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. device=device, ) - output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) if sampling_metadata.all_greedy: is_greedy = None @@ -178,7 +277,7 @@ def rejection_sample( if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) - rejection_greedy_sample_kernel[(batch_size, )]( + rejection_greedy_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -186,7 +285,6 @@ def rejection_sample( bonus_token_ids, is_greedy, max_spec_len, - num_warps=1, ) if sampling_metadata.all_greedy: return output_token_ids @@ -214,7 +312,7 @@ def rejection_sample( ) # Rejection sampling for random sampling requests. - rejection_random_sample_kernel[(batch_size, )]( + rejection_random_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -227,7 +325,6 @@ def rejection_sample( max_spec_len, vocab_size, NO_DRAFT_PROBS=draft_probs is None, - num_warps=1, ) return output_token_ids @@ -322,14 +419,13 @@ def expand_batch_to_tokens( batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size expanded_x = x.new_empty(num_tokens) - expand_kernel[(batch_size, )]( + expand_kernel[(batch_size,)]( expanded_x, x, cu_num_tokens, replace_from, replace_to, MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. - num_warps=1, ) return expanded_x @@ -351,23 +447,28 @@ def generate_uniform_probs( without a seed. Args: - num_tokens : int + num_tokens: int Total number of tokens. - num_draft_tokens : List[List[int]] + num_draft_tokens: List[List[int]] Number of draft tokens per request. 
- generators : Optional[Dict[int, torch.Generator]] + generators: Optional[Dict[int, torch.Generator]] A dictionary mapping indices in the batch to `torch.Generator` objects. - device : torch.device + device: torch.device The device on which to allocate the tensor. Returns: - uniform_rand : torch.Tensor + uniform_rand: torch.Tensor A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ + # NOTE(woosuk): We deliberately use float64 instead of float32 here + # because when using float32, there's a non-negligible chance that + # uniform_prob is sampled to be exact 0.0 as reported in + # https://github.com/pytorch/pytorch/issues/16706. Using float64 + # mitigates the issue. uniform_probs = torch.rand( - (num_tokens, ), - dtype=torch.float32, + (num_tokens,), + dtype=torch.float64, device=device, ) start_idx = 0 @@ -442,18 +543,12 @@ def rejection_greedy_sample_kernel( req_idx = tl.program_id(0) # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, # re-compilation may happen during runtime when is_greedy_ptr is None. - if is_greedy_ptr is None: - is_greedy = True - else: - is_greedy = tl.load(is_greedy_ptr + req_idx) + is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx) if not is_greedy: # Early exit for non-greedy sampling requests. return - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -462,8 +557,10 @@ def rejection_greedy_sample_kernel( if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - target_argmax_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id, + ) if draft_token_id != target_argmax_id: # Reject. rejected = True @@ -472,8 +569,9 @@ def rejection_greedy_sample_kernel( # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @@ -498,10 +596,7 @@ def rejection_random_sample_kernel( # Early exit for greedy sampling requests. return - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -512,12 +607,12 @@ def rejection_random_sample_kernel( if NO_DRAFT_PROBS: draft_prob = 1 else: - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) # NOTE(woosuk): While the draft probability should never be 0, # we check it to avoid NaNs. 
If it happens to be 0, we reject. @@ -528,15 +623,17 @@ def rejection_random_sample_kernel( # Reject. Use recovered token. rejected = True token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - token_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id + ) if not rejected: # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @@ -560,9 +657,7 @@ def expand_kernel( src_val = tl.load(input_ptr + req_idx) src_val = tl.where(src_val == replace_from, replace_to, src_val) offset = tl.arange(0, MAX_NUM_TOKENS) - tl.store(output_ptr + start_idx + offset, - src_val, - mask=offset < num_tokens) + tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens) @triton.jit @@ -578,10 +673,7 @@ def sample_recovered_tokens_kernel( NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -593,39 +685,30 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - draft_token_id) - # Temporarily zero out the probability of the draft token. - # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - 0) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)), + other=0, + ) else: - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) prob = tl.maximum(target_prob - draft_prob, 0) # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # `tl.argmax` will select the maximum value. - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf"), + ) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. 
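# A minimal single-request reference sketch of the accept/reject rule that the
# Triton kernels above implement: the draft token at position i is accepted iff
# p_target(token) / p_draft(token) >= u_i; on the first rejection a "recovered"
# token is drawn from max(p_target - p_draft, 0) and generation stops; if every
# draft token is accepted, the bonus token is appended. Upstream, the uniforms
# are drawn in float64 to avoid exact zeros. Purely illustrative PyTorch; it
# assumes the target and draft distributions differ at the rejected position.
import torch

def reference_rejection_sample(draft_ids, draft_probs, target_probs, uniforms, bonus_id):
    out = []
    for i, tok in enumerate(draft_ids.tolist()):
        p_draft, p_target = draft_probs[i, tok], target_probs[i, tok]
        if p_draft > 0 and p_target / p_draft >= uniforms[i]:
            out.append(tok)                     # accept the draft token
            continue
        recovered = torch.clamp(target_probs[i] - draft_probs[i], min=0)
        out.append(int(torch.multinomial(recovered / recovered.sum(), 1)))
        return out                              # stop at the first rejection
    out.append(int(bonus_id))                   # all accepted: add bonus token
    return out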
- tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - orig_prob) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 546531a91610f..2e076ca8e3c84 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -7,7 +7,7 @@ from typing import Optional import torch import torch.nn as nn -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode from vllm.utils import is_pin_memory_available from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -24,44 +24,43 @@ class Sampler(nn.Module): A layer that samples the next tokens from the model's outputs with the following steps in order: - 1. If logprobs are requested: + 1. If logprobs are requested: a) If `logprobs_mode` is `raw_logprobs`, compute logprobs - as the final logprobs to return. + as the final logprobs to return. b) If `logprobs_mode` is `raw_logits`, clone the logits - as the final logprobs to return. - 2. Convert logits to float32. - 3. Apply allowed token ids whitelist. - 4. Apply bad words exclusion. + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. 5. Apply logit processors which are not argmax-invariant, - i.e. that can impact greedy sampling. - a) Min tokens processor - b) Logit bias processor - 6. Apply penalties - a) Repetition penalty - b) Frequency penalty - c) Presence penalty - 7. Sample the next tokens. `sample` method performs the following steps: + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: a) If not `all_random`, perform greedy sampling. If `all_greedy`, - return the greedily sampled tokens and final logprobs if requested. - b) Apply temperature. + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. c) Apply logit processors which are argmax-invariant, by default - the min_p processor. - d) Apply top_k and/or top_p. - e) Sample the next tokens with the probability distribution. + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. f) If `all_random` or temperature >= epsilon (1e-5), return the randomly sampled tokens and final logprobs if requested. Else, - return the greedily sampled tokens and logprobs if requested. + return the greedily sampled tokens and logprobs if requested. 8. Gather the logprobs of the top `max_num_logprobs` and sampled token (if requested). Note that if the sampled token is within the top `max_num_logprobs`, the logprob will be eventually merged in `LogprobsProcessor` during output processing. Therefore, the final output may contain either `max_num_logprobs + 1` or - `max_num_logprobs` logprobs. + `max_num_logprobs` logprobs. 9. Return the final `SamplerOutput`. 
""" - def __init__(self, - logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): + def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): super().__init__() self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() @@ -71,6 +70,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, + predict_bonus_token: bool = False, ) -> SamplerOutput: # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. @@ -78,25 +78,17 @@ class Sampler(nn.Module): # is used for sampling (after penalties and temperature scaling). num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: + if self.logprobs_mode == "raw_logprobs": raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: + elif self.logprobs_mode == "raw_logits": raw_logprobs = logits.clone() # Use float32 for the logits. logits = logits.to(torch.float32) - # Apply allowed token ids. - logits = self.apply_allowed_token_ids(logits, sampling_metadata) - # Apply bad words exclusion. - logits = self.apply_bad_words(logits, sampling_metadata) - - # Apply logits processors which can impact greedy sampling - for processor in sampling_metadata.logitsprocs.non_argmax_invariant: - logits = processor.apply(logits) - - # Apply penalties (e.g., min_tokens, freq_penalties). - logits = self.apply_penalties(logits, sampling_metadata) + logits = self.apply_logits_processors( + logits, sampling_metadata, predict_bonus_token + ) # Sample the next token. sampled, processed_logprobs = self.sample(logits, sampling_metadata) if processed_logprobs is not None: @@ -109,8 +101,11 @@ class Sampler(nn.Module): # Gather the logprobs of the topk and sampled token (if requested). # Get logprobs and rank tensors (if requested) - logprobs_tensors = None if num_logprobs is None else \ - self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + logprobs_tensors = ( + None + if num_logprobs is None + else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + ) # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) @@ -125,15 +120,20 @@ class Sampler(nn.Module): ) return sampler_output + @staticmethod def apply_temperature( - self, logits: torch.Tensor, temp: torch.Tensor, + all_random: bool, ) -> torch.Tensor: # Use in-place division to avoid creating a new tensor. + # Avoid division by zero if there are greedy requests. + if not all_random: + temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) return logits.div_(temp.unsqueeze(dim=1)) - def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + @staticmethod + def greedy_sample(logits: torch.Tensor) -> torch.Tensor: return logits.argmax(dim=-1).view(-1) def sample( @@ -147,8 +147,7 @@ class Sampler(nn.Module): may update the logits tensor in-place. 
""" - assert not (sampling_metadata.all_greedy - and sampling_metadata.all_random) + assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) if sampling_metadata.all_random: greedy_sampled = None else: @@ -156,16 +155,18 @@ class Sampler(nn.Module): if sampling_metadata.all_greedy: processed_logprobs = None if sampling_metadata.max_num_logprobs is not None: - if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + if self.logprobs_mode == "processed_logits": processed_logprobs = logits - elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + elif self.logprobs_mode == "processed_logprobs": processed_logprobs = self.compute_logprobs(logits) return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None # Apply temperature. - logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_temperature( + logits, sampling_metadata.temperature, sampling_metadata.all_random + ) # Apply logits processors that only apply to random sampling # (argmax invariant) @@ -191,11 +192,12 @@ class Sampler(nn.Module): ) return sampled, processed_logprobs - def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + @staticmethod + def compute_logprobs(logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) + @staticmethod def gather_logprobs( - self, logprobs: torch.Tensor, num_logprobs: int, token_ids: torch.Tensor, @@ -220,9 +222,7 @@ class Sampler(nn.Module): """ assert token_ids.dtype == torch.int64 # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) + topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. token_ids = token_ids.unsqueeze(-1) @@ -240,42 +240,70 @@ class Sampler(nn.Module): return LogprobsTensors(indices, logprobs, token_ranks) - def apply_penalties( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - if not sampling_metadata.no_penalties: - assert sampling_metadata.prompt_token_ids is not None - logits = apply_all_penalties( - logits, - sampling_metadata.prompt_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids, - ) - return logits + @staticmethod + def _combine_outputs_with_spec_tokens( + output_token_ids: list[list[int]], + spec_token_ids: Optional[list[list[int]]] = None, + ) -> list[list[int]]: + if spec_token_ids is None: + return output_token_ids - def apply_allowed_token_ids( + return [ + [*out, *spec] if spec else out + for out, spec in zip(output_token_ids, spec_token_ids) + ] + + def apply_logits_processors( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, + predict_bonus_token: bool, ) -> torch.Tensor: + bad_words_token_ids = sampling_metadata.bad_words_token_ids + any_penalties_or_bad_words = ( + bool(bad_words_token_ids) or not sampling_metadata.no_penalties + ) + + output_token_ids = sampling_metadata.output_token_ids + if predict_bonus_token and any_penalties_or_bad_words: + # Combine base outputs with spec tokens when speculative decoding + # is enabled. + output_token_ids = self._combine_outputs_with_spec_tokens( + output_token_ids, + sampling_metadata.spec_token_ids, + ) + + # Apply allowed token ids. 
if sampling_metadata.allowed_token_ids_mask is not None: - logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, - float("-inf")) + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) + + # Apply bad words exclusion. + if bad_words_token_ids: + apply_bad_words(logits, bad_words_token_ids, output_token_ids) + + # Apply logits processors which can impact greedy sampling. + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: + logits = processor.apply(logits) + + # Apply penalties (e.g., freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata, output_token_ids) return logits - def apply_bad_words( - self, + @staticmethod + def apply_penalties( logits: torch.Tensor, sampling_metadata: SamplingMetadata, + output_token_ids: list[list[int]], ) -> torch.Tensor: - if sampling_metadata.bad_words_token_ids: - apply_bad_words( - logits, - sampling_metadata.bad_words_token_ids, - sampling_metadata.output_token_ids, - ) - return logits + if sampling_metadata.no_penalties: + return logits + + assert sampling_metadata.prompt_token_ids is not None + return apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + output_token_ids, + ) diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 6491c84f60762..b58a94d0bf7dc 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -48,15 +48,13 @@ class TPUSupportedSamplingMetadata: min_tokens = None # impl is not vectorized - logit_bias: list[Optional[dict[int, float]]] = field( - default_factory=lambda: list()) + logit_bias: list[Optional[dict[int, float]]] = field(default_factory=lambda: list()) allowed_token_ids_mask = None bad_words_token_ids = None # Generator not supported by xla - _generators: dict[int, - torch.Generator] = field(default_factory=lambda: dict()) + _generators: dict[int, torch.Generator] = field(default_factory=lambda: dict()) @property def generators(self) -> dict[int, torch.Generator]: @@ -69,13 +67,13 @@ class TPUSupportedSamplingMetadata: input_batch: InputBatch, padded_num_reqs: int, xla_device: torch.device, - generate_params_if_all_greedy: bool = False + generate_params_if_all_greedy: bool = False, ) -> "TPUSupportedSamplingMetadata": """ Copy sampling tensors slices from `input_batch` to on device tensors. - `InputBatch._make_sampling_metadata` causes recompilation on XLA as it - slices dynamic shapes on device tensors. This impl moves the dynamic + `InputBatch._make_sampling_metadata` causes recompilation on XLA as it + slices dynamic shapes on device tensors. This impl moves the dynamic ops to CPU and produces tensors of fixed `padded_num_reqs` size. Args: @@ -87,11 +85,11 @@ class TPUSupportedSamplingMetadata: we want to pre-compile a graph with sampling parameters, even if they are not strictly needed for greedy decoding. 
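# An illustrative sketch of the XLA-friendly padding pattern used by
# TPUSupportedSamplingMetadata.from_input_batch above: per-request CPU tensors
# are padded out to a fixed padded_num_reqs with default values before the
# slice is copied to the device, so compiled graph shapes never change. The
# sizes and the 1.0 pad default are assumptions for the example.
import torch

num_reqs, padded_num_reqs = 3, 8
temperature_cpu = torch.empty(padded_num_reqs)
temperature_cpu[:num_reqs] = torch.tensor([0.7, 1.0, 0.0])  # live requests
temperature_cpu[num_reqs:] = 1.0                            # assumed pad default
# Only this fixed-size slice ever reaches the device (.to(xla_device) on TPU):
temperature_padded = temperature_cpu[:padded_num_reqs]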
""" - needs_logprobs = input_batch.max_num_logprobs>0 if \ - input_batch.max_num_logprobs else False + needs_logprobs = ( + input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False + ) # Early return to avoid unnecessary cpu to tpu copy - if (input_batch.all_greedy is True - and generate_params_if_all_greedy is False): + if input_batch.all_greedy is True and generate_params_if_all_greedy is False: return cls(all_greedy=True, logprobs=needs_logprobs) num_reqs = input_batch.num_reqs @@ -100,25 +98,22 @@ class TPUSupportedSamplingMetadata: # Pad value is the default one. cpu_tensor[num_reqs:padded_num_reqs] = fill_val - fill_slice(input_batch.temperature_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["temperature"]) - fill_slice(input_batch.min_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["min_p"]) - fill_slice(input_batch.top_k_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_k"]) - fill_slice(input_batch.top_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_p"]) + fill_slice( + input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"] + ) + fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( - temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. - to(xla_device), + temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to( + xla_device + ), all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values - top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( - xla_device), - min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - logprobs=needs_logprobs) + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device), + min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device), + logprobs=needs_logprobs, + ) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 04545d587e4a9..ccef283a81829 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -2,22 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampler layer implementing TPU supported operations.""" +from typing import Optional + import torch import torch.nn as nn from vllm.v1.outputs import LogprobsTensors, SamplerOutput -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): - def __init__(self): # TODO(houseroad): Add support for logprobs_mode. super().__init__() - self.topk_topp_sampler = TopKTopPSampler() def forward( self, @@ -35,7 +34,8 @@ class Sampler(nn.Module): # [num_requests, 1], where each row represents one generated # token per request. sampled_token_ids=sampled.unsqueeze(-1), - logprobs_tensors=None) + logprobs_tensors=None, + ) return sampler_output def apply_temperature( @@ -65,15 +65,21 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. 
-        random_sampled, _ = self.topk_topp_sampler(
+        logits = apply_top_k_top_p(
             logits,
-            sampling_metadata.generators,
             sampling_metadata.top_k,
             sampling_metadata.top_p,
         )
-        sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
-                              greedy_sampled, random_sampled)
+        # Random sample.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        random_sampled = self.random_sample(probs, sampling_metadata.generators)
+
+        sampled = torch.where(
+            sampling_metadata.temperature < _SAMPLING_EPS,
+            greedy_sampled,
+            random_sampled,
+        )
         return sampled
 
     def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
@@ -89,7 +95,7 @@ class Sampler(nn.Module):
         Gather logprobs for topk and sampled/prompt token.
 
         Args:
-            logits: (num tokens) x (vocab) tensor
+            logprobs: (num tokens) x (vocab) tensor
             num_logprobs: minimum number of logprobs to retain per token
             token_ids: prompt tokens (if prompt logprobs)
 
@@ -103,9 +109,7 @@ class Sampler(nn.Module):
             Sampled token rank tensor, (num tokens)
         """
         # Find the topK values.
-        topk_logprobs, topk_indices = torch.topk(logprobs,
-                                                 num_logprobs,
-                                                 dim=-1)
+        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
 
         # Get with the logprob of the prompt or sampled token.
         token_ids = token_ids.unsqueeze(-1)
@@ -134,9 +138,7 @@ class Sampler(nn.Module):
         # Convert logits to probability distribution
         probability_values = torch.nn.functional.softmax(logits, dim=-1)
         # Calculate maximum probabilities per sequence
-        max_probabilities = torch.amax(probability_values,
-                                       dim=-1,
-                                       keepdim=True)
+        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
         # Reshape min_p for broadcasting
         adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
         # Identify valid tokens using threshold comparison
@@ -144,3 +146,66 @@ class Sampler(nn.Module):
         # Apply mask using boolean indexing (xla friendly)
         logits.masked_fill_(~valid_token_mask, -float("inf"))
         return logits
+
+    def random_sample(
+        self,
+        probs: torch.Tensor,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        q = torch.empty_like(probs)
+        # NOTE(woosuk): To batch-process the requests without their own seeds,
+        # which is the common case, we first assume that every request does
+        # not have its own seed. Then, we overwrite the values for the requests
+        # that have their own seeds.
+        q.exponential_()
+        if generators:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+        return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter which is extremely slow on TPU.
+    This is achieved by finding a "cut-off" element in the original logit, and
+    after thresholding the logit using this cut-off, the remaining elements
+    shall constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logit), all tie elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, these tie tokens have equal
+    chance of being chosen during final sampling, so we can consider the tie
+    being broken then.
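# A tiny numeric illustration of the scatter-free cut-off selection described
# in the docstring above and implemented by apply_top_k_top_p below (top-p part
# only, CPU PyTorch, made-up values):
import torch

probs = torch.tensor([[0.05, 0.40, 0.15, 0.30, 0.10]])
p = torch.tensor([0.7])
probs_sort, _ = probs.sort(dim=-1, descending=False)  # [0.05, 0.10, 0.15, 0.30, 0.40]
cumprob = torch.cumsum(probs_sort, dim=-1)            # [0.05, 0.15, 0.30, 0.60, 1.00]
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)        # smallest-mass prefix to drop
top_p_mask[:, -1] = False                             # always keep at least one token
top_p_count = top_p_mask.sum(dim=-1, keepdim=True)    # 3 tokens discarded
top_p_cutoff = probs_sort.gather(-1, top_p_count)     # 0.30 = first kept probability
kept = probs >= top_p_cutoff                          # keeps 0.40 and 0.30 (mass 0.70)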
+ """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index c8375d6f15517..f4e1cbd2e0243 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -7,7 +7,7 @@ import pickle from collections.abc import Sequence from inspect import isclass from types import FunctionType -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import cloudpickle import msgspec @@ -18,15 +18,18 @@ from msgspec import msgpack from vllm import envs from vllm.logger import init_logger -# yapf: disable -from vllm.multimodal.inputs import (BaseMultiModalField, - MultiModalBatchedField, - MultiModalFieldConfig, MultiModalFieldElem, - MultiModalFlatField, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, NestedTensors) -# yapf: enable +from vllm.multimodal.inputs import ( + BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, + MultiModalFieldElem, + MultiModalFlatField, + MultiModalKwargs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, + NestedTensors, +) from vllm.v1.engine import UtilityResult logger = init_logger(__name__) @@ -48,8 +51,10 @@ bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] def _log_insecure_serialization_warning(): - logger.warning_once("Allowing insecure serialization using pickle due to " - "VLLM_ALLOW_INSECURE_SERIALIZATION=1") + logger.warning_once( + "Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1" + ) def _typestr(val: Any) -> Optional[tuple[str, str]]: @@ -59,13 +64,50 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]: return t.__module__, t.__qualname__ +def _encode_type_info_recursive(obj: Any) -> Any: + """Recursively encode type information for nested structures of + lists/dicts.""" + if obj is None: + return None + if type(obj) is list: + return [_encode_type_info_recursive(item) for item in obj] + if type(obj) is dict: + return {k: _encode_type_info_recursive(v) for k, v in obj.items()} + return _typestr(obj) + + +def _decode_type_info_recursive( + type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any] +) -> Any: + """Recursively decode type information for nested structures of + lists/dicts.""" + if type_info is None: + return data + if isinstance(type_info, dict): + assert isinstance(data, dict) + return { + k: _decode_type_info_recursive(type_info[k], data[k], convert_fn) + for k in type_info + } + if isinstance(type_info, list) and ( + # Exclude serialized tensors/numpy arrays. 
+ len(type_info) != 2 or not isinstance(type_info[0], str) + ): + assert isinstance(data, list) + return [ + _decode_type_info_recursive(ti, d, convert_fn) + for ti, d in zip(type_info, data) + ] + return convert_fn(type_info, data) + + class MsgpackEncoder: """Encoder with custom torch tensor and numpy array serialization. Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - By default, arrays below 256B are serialized inline Larger will get sent + By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. """ @@ -83,7 +125,7 @@ class MsgpackEncoder: def encode(self, obj: Any) -> Sequence[bytestr]: try: - self.aux_buffers = bufs = [b''] + self.aux_buffers = bufs = [b""] bufs[0] = self.encoder.encode(obj) # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the @@ -107,14 +149,15 @@ class MsgpackEncoder: return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. - if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): + if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"): return self._encode_ndarray(obj) if isinstance(obj, slice): # We are assuming only int-based values will be used here. return tuple( int(v) if v is not None else None - for v in (obj.start, obj.stop, obj.step)) + for v in (obj.start, obj.stop, obj.step) + ) if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -129,25 +172,26 @@ class MsgpackEncoder: result = obj.result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: return None, result - # Since utility results are not strongly typed, we also encode - # the type (or a list of types in the case it's a list) to - # help with correct msgspec deserialization. - return _typestr(result) if type(result) is not list else [ - _typestr(v) for v in result - ], result + # Since utility results are not strongly typed, we recursively + # encode type information for nested structures of lists/dicts + # to help with correct msgspec deserialization. + return _encode_type_info_recursive(result), result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError(f"Object of type {type(obj)} is not serializable" - "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " - "fallback to pickle-based serialization.") + raise TypeError( + f"Object of type {type(obj)} is not serializable" + "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " + "fallback to pickle-based serialization." + ) if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. 
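# A self-contained toy version of the recursive type-tagging idea behind
# _encode_type_info_recursive / _decode_type_info_recursive above: mirror the
# list/dict structure, record a (module, qualname) tag per leaf, and walk the
# two structures together when decoding. Simplified illustration only, not the
# serializer's exact behavior.
from typing import Any

def tag_types(obj: Any) -> Any:
    if obj is None:
        return None
    if type(obj) is list:
        return [tag_types(v) for v in obj]
    if type(obj) is dict:
        return {k: tag_types(v) for k, v in obj.items()}
    t = type(obj)
    return (t.__module__, t.__qualname__)

print(tag_types({"ids": [1, 2], "note": None}))
# {'ids': [('builtins', 'int'), ('builtins', 'int')], 'note': None}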
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) - return msgpack.Ext(CUSTOM_TYPE_PICKLE, - pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + return msgpack.Ext( + CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + ) def _encode_ndarray( self, obj: np.ndarray @@ -191,27 +235,22 @@ class MsgpackEncoder: for modality, itemlist in items.items() } - def _encode_mm_item(self, - item: MultiModalKwargsItem) -> list[dict[str, Any]]: + def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: return [self._encode_mm_field_elem(elem) for elem in item.values()] - def _encode_mm_field_elem(self, - elem: MultiModalFieldElem) -> dict[str, Any]: + def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]: return { - "modality": - elem.modality, - "key": - elem.key, - "data": (None if elem.data is None else - self._encode_nested_tensors(elem.data)), - "field": - self._encode_mm_field(elem.field), + "modality": elem.modality, + "key": elem.key, + "data": ( + None if elem.data is None else self._encode_nested_tensors(elem.data) + ), + "field": self._encode_mm_field(elem.field), } def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: return { - modality: self._encode_nested_tensors(data) - for modality, data in kw.items() + modality: self._encode_nested_tensors(data) for modality, data in kw.items() } def _encode_nested_tensors(self, nt: NestedTensors) -> Any: @@ -230,8 +269,7 @@ class MsgpackEncoder: raise TypeError(f"Unsupported field type: {field.__class__}") # We just need to copy all of the field values in order # which will be then used to reconstruct the field. - field_values = (getattr(field, f.name) - for f in dataclasses.fields(field)) + field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) return name, *field_values @@ -243,18 +281,16 @@ class MsgpackDecoder: """ def __init__(self, t: Optional[Any] = None): - args = () if t is None else (t, ) - self.decoder = msgpack.Decoder(*args, - ext_hook=self.ext_hook, - dec_hook=self.dec_hook) + args = () if t is None else (t,) + self.decoder = msgpack.Decoder( + *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook + ) self.aux_buffers: Sequence[bytestr] = () if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: - if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): - # TODO - This check can become `isinstance(bufs, bytestr)` - # as of Python 3.10. 
+ if isinstance(bufs, bytestr): # type: ignore return self.decoder.decode(bufs) self.aux_buffers = bufs @@ -286,17 +322,14 @@ class MsgpackDecoder: result_type, result = obj if result_type is not None: if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must " - "be set to use custom utility result types") - assert isinstance(result_type, list) - if len(result_type) == 2 and isinstance(result_type[0], str): - result = self._convert_result(result_type, result) - else: - assert isinstance(result, list) - result = [ - self._convert_result(rt, r) - for rt, r in zip(result_type, result) - ] + raise TypeError( + "VLLM_ALLOW_INSECURE_SERIALIZATION must " + "be set to use custom utility result types" + ) + # Use recursive decoding to handle nested structures + result = _decode_type_info_recursive( + result_type, result, self._convert_result + ) return UtilityResult(result) def _convert_result(self, result_type: Sequence[str], result: Any) -> Any: @@ -319,8 +352,7 @@ class MsgpackDecoder: # Copy from inline representation, to decouple the memory storage # of the message from the original buffer. And also make Torch # not complain about a readonly memoryview. - buffer = self.aux_buffers[data] if isinstance(data, int) \ - else bytearray(data) + buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) if not buffer: # torch.frombuffer doesn't like empty buffers @@ -332,17 +364,19 @@ class MsgpackDecoder: return arr.view(torch_dtype).view(shape) def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: - return MultiModalKwargsItems({ - modality: [self._decode_mm_item(item) for item in itemlist] - for modality, itemlist in obj.items() - }) + return MultiModalKwargsItems( + { + modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + } + ) def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: return MultiModalKwargsItem.from_elems( - [self._decode_mm_field_elem(v) for v in obj]) + [self._decode_mm_field_elem(v) for v in obj] + ) - def _decode_mm_field_elem(self, obj: dict[str, - Any]) -> MultiModalFieldElem: + def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem: if obj["data"] is not None: obj["data"] = self._decode_nested_tensors(obj["data"]) @@ -359,10 +393,12 @@ class MsgpackDecoder: return MultiModalFieldElem(**obj) def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: - return MultiModalKwargs({ - modality: self._decode_nested_tensors(data) - for modality, data in obj.items() - }) + return MultiModalKwargs( + { + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + } + ) def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): @@ -391,5 +427,4 @@ class MsgpackDecoder: if code == CUSTOM_TYPE_CLOUDPICKLE: return cloudpickle.loads(data) - raise NotImplementedError( - f"Extension type code {code} is not supported") + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0a0e9fed725cb..393a4d964ee3e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,49 +3,51 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional, Protocol +from typing import Optional import numpy as np import torch import 
torch.nn as nn -from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, +) from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import ( + TreeAttentionMetadata, + TreeAttentionMetadataBuilder, +) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch logger = init_logger(__name__) PADDING_SLOT_ID = -1 -class EagleAttentionMetadata(Protocol): - # Required attributes - num_actual_tokens: int - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - class EagleProposer: - def __init__( self, vllm_config: VllmConfig, @@ -54,77 +56,114 @@ class EagleProposer: ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config + assert self.speculative_config is not None self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method self.runner = runner + self.device = device self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() - self.is_multimodal_model = vllm_config.model_config \ - .is_multimodal_model + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + vllm_config.model_config + ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) + self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None + self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None + self.attn_layer_names: list[str] = [] + self.indexer_layer_names: list[str] = [] + + self.use_cuda_graph = False + + compilation_config = self.vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE: + cudagraph_mode = compilation_config.cudagraph_mode + if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( + CUDAGraphMode.PIECEWISE + ): + logger.warning( + "Currently the eagle proposer only supports cudagraph_mode " + "PIECEWISE, if you want the drafter to use cuda graphs, " + "please set compilation_config.cudagraph_mode to PIECEWISE " + "or FULL_AND_PIECEWISE" + ) + self.use_cuda_graph = ( + cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE) + and not self.speculative_config.enforce_eager + ) + + self.cudagraph_batch_sizes = ( + list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_cuda_graph + else [] + ) # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.input_ids = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=device + ) + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64, device=device + ) + else: + # RoPE need (max_num_tokens,) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_batch_size + 1, - device=device, - dtype=torch.int32, + max_num_slots_for_arange, device=device, dtype=torch.int32 ) self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) + + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] 
+ self.allowed_attn_types: Optional[tuple] = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) + AiterFlashAttentionMetadata, + ) + rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) - else: - self.allowed_attn_types = (FlashAttentionMetadata, - TreeAttentionMetadata) # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) tree_depth = len(self.tree_choices[-1]) # Precompute per-level properties of the tree. num_drafts_per_level = [0] * tree_depth @@ -133,10 +172,12 @@ class EagleProposer: self.cu_drafts_per_level = [num_drafts_per_level[0]] self.child_drafts_per_level = [num_drafts_per_level[0]] for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) # Precompute draft position offsets in flattened tree. self.tree_draft_pos_offsets = torch.arange( 1, @@ -145,87 +186,139 @@ class EagleProposer: dtype=torch.int32, ).repeat(max_batch_size, 1) + def _get_positions(self, num_tokens: int): + if self.uses_mrope: + return self.mrope_positions[:, :num_tokens] + return self.positions[:num_tokens] + + def _set_positions(self, num_tokens: int, positions: torch.Tensor): + if self.uses_mrope: + self.mrope_positions[:, :num_tokens] = positions + else: + self.positions[:num_tokens] = positions + def propose( self, # [num_tokens] target_token_ids: torch.Tensor, - # [num_tokens] + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embeds: Optional[list[torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states + ) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. 
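# A concrete sketch of the draft-model input construction described in the
# comments here: shift the verified target token ids left by one, then
# overwrite each request's last position with the freshly sampled next token.
# The numeric ids stand in for the a1/b1/... tokens of the example comments.
import torch

target_token_ids = torch.tensor([101, 201, 202, 301, 302, 303])  # a1 b1 b2 c1 c2 c3
next_token_ids = torch.tensor([102, 203, 304])                   # a2 b3 c4
query_start_loc = torch.tensor([0, 1, 3, 6])
last_token_indices = query_start_loc[1:] - 1                     # [0, 2, 5]

input_ids = torch.empty_like(target_token_ids)
input_ids[:-1] = target_token_ids[1:]
input_ids[-1] = target_token_ids[-1]        # stale last slot, overwritten next
input_ids[last_token_indices] = next_token_ids
# input_ids -> [102, 202, 203, 302, 303, 304], i.e. [a2, b2, b3, c2, c3, c4]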
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + # FIXME: support hybrid kv for draft model (remove separate indexer) + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = ( + self.draft_indexer_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + ) + else: + draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + + cudagraph_runtime_mode = CUDAGraphMode.NONE + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions + self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states - if self.is_multimodal_model: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( - input_ids, - multimodal_embeddings=mm_embeds or None, - ) - self.inputs_embeds[:num_tokens] = inputs_embeds - inputs_embeds = self.inputs_embeds[:num_input_tokens] - input_ids = None - else: - inputs_embeds = None - input_ids = self.input_ids[:num_input_tokens] - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + ): ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=self._get_positions(num_input_tokens), hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp"): + if self.method == "mtp": last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - positions = 
target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) + return draft_token_ids.view(-1, 1) + + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] + if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): + hidden_states = self.hidden_states[last_token_indices] + else: + hidden_states = hidden_states[last_token_indices] if isinstance(attn_metadata, TreeAttentionMetadata): # Draft using tree attention. @@ -241,95 +334,139 @@ class EagleProposer: draft_token_ids = logits.argmax(dim=-1) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - assert isinstance(attn_metadata, self.allowed_attn_types) + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding with num_speculative_tokens > 1: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}" + ) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + cudagraph_runtime_mode = CUDAGraphMode.NONE + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[: batch_size + 1] + ).clone() + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. 
+ # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + common_attn_metadata.seq_lens += 1 + common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + common_attn_metadata.num_computed_tokens_cpu = ( + common_attn_metadata.seq_lens_cpu - 1 + ) # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( - dim=1, index=block_numbers.view(-1, 1)) + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size + block_ids = common_attn_metadata.block_table_tensor.gather( + dim=1, index=block_numbers.view(-1, 1) + ) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + if self.uses_mrope: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) + else: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + clamped_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
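# A standalone sketch of the out-of-range handling described above: draft
# positions that run past max_model_len are clamped to 0 before RoPE, and their
# slot mappings are redirected to a padding slot so the KV cache is not
# corrupted; the tokens drafted for them are ignored. Sizes and values are
# illustrative assumptions.
import torch

PADDING_SLOT_ID = -1
max_model_len = 8
positions = torch.tensor([6, 7, 8, 9])        # the last two exceed the limit
slot_mapping = torch.tensor([22, 23, 24, 25])

exceeds = positions >= max_model_len
clamped_positions = torch.where(exceeds, 0, positions)             # [6, 7, 0, 0]
slot_mapping = slot_mapping.masked_fill(exceeds, PADDING_SLOT_ID)  # [22, 23, -1, -1]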
- attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID + ) + + # Rebuild attention metadata + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 + ) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions + self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states - if self.is_multimodal_model: - inputs_embeds = self.model.get_input_embeddings(input_ids) - self.inputs_embeds[:batch_size] = inputs_embeds - inputs_embeds = self.inputs_embeds[:input_batch_size] + if self.supports_mm_inputs: + self.inputs_embeds[:batch_size] = self.model.get_input_embeddings( + input_ids + ) + input_ids = None + inputs_embeds = self.inputs_embeds[:input_batch_size] else: - inputs_embeds = None input_ids = self.input_ids[:input_batch_size] + inputs_embeds = None # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=cudagraph_runtime_mode, + ): + ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:input_batch_size], + positions=self._get_positions(input_batch_size), hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) + logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -337,6 +474,166 @@ class EagleProposer: draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_next_token_ids_cpu( + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor( + next_token_ids, dtype=torch.int32, device=self.input_ids.device + ) + return next_token_ids + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ] + ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1 + ) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. 
+ """ + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, + ) + + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu + ) + + return spec_common_attn_metadata, token_indices, token_indices_to_sample + def propose_tree( self, batch_size: int, @@ -348,10 +645,10 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builder - assert isinstance(tree_attn_metadata_builder, - TreeAttentionMetadataBuilder) + tree_attn_metadata_builder = self.runner.attn_groups[0][ + 0 + ].get_metadata_builder() + assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts @@ -360,31 +657,31 @@ class EagleProposer: if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view(batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list = [draft_token_ids] draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. - tree_input_ids = torch.empty(0, - device=self.input_ids.device, - dtype=self.input_ids.dtype) - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) - tree_hidden_states = torch.empty(0, - device=self.hidden_states.device, - dtype=self.hidden_states.dtype) + tree_input_ids = torch.empty( + 0, device=self.input_ids.device, dtype=self.input_ids.dtype + ) + tree_positions = torch.empty( + 0, device=self.positions.device, dtype=self.positions.dtype + ) + tree_hidden_states = torch.empty( + 0, device=self.hidden_states.device, dtype=self.hidden_states.dtype + ) # Precompute the draft token positions. 
flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :] + ) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): # Get draft positions for RoPE. draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + - total_num_drafts) >= self.max_model_len + exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. draft_positions = torch.where( @@ -396,27 +693,28 @@ class EagleProposer: if level_num_drafts > 1: # Repeat the positions for each draft at this level. draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + level_num_drafts, dim=1 + ) if num_children > 1: # Repeat draft hidden states for each child. draft_hidden_states = draft_hidden_states.repeat_interleave( - num_children, dim=1) + num_children, dim=1 + ) # Concatenate the draft tokens, positions, and hidden states. - tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], - dim=1) - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], dim=1) tree_hidden_states = torch.cat( - [tree_hidden_states, draft_hidden_states], dim=1) + [tree_hidden_states, draft_hidden_states], dim=1 + ) # Build new attention metadata for the next level of drafts. # This is necessary to support tree attention. query_len = total_num_drafts common_attn_metadata = replace( common_attn_metadata, - query_start_loc=query_len * self.arange[:batch_size + 1], + query_start_loc=query_len * self.arange[: batch_size + 1], seq_lens=common_attn_metadata.seq_lens + level_num_drafts, num_actual_tokens=batch_size * query_len, max_query_len=query_len, @@ -432,20 +730,20 @@ class EagleProposer: per_layer_attn_metadata[layer_name] = attn_metadata # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + attn_metadata.max_seq_len = min( + attn_metadata.max_seq_len, self.max_model_len + ) # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_draft_positions[:, level : level + query_len] block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) + slot_mapping = ( + block_ids * self.block_size + query_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
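# Illustrative, self-contained sketch with assumed toy values (block size,
# block table, and positions are made up; this is not vLLM code): it shows the
# paged-KV slot-mapping arithmetic used in the drafting loops above, where a
# logical token position is mapped to a physical cache slot through the
# block table.
import torch

block_size = 16                             # assumed toy block size
block_table = torch.tensor([[7, 3, 9]])     # assumed physical block ids (1 request)
positions = torch.tensor([[0, 17, 35]])     # logical draft token positions

block_numbers = positions // block_size                      # logical block index
block_ids = block_table.gather(dim=1, index=block_numbers)   # physical block id
slot_mapping = block_ids * block_size + positions % block_size
print(slot_mapping)  # tensor([[112,  49, 147]])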
@@ -457,19 +755,21 @@ class EagleProposer: input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids self.positions[:num_tokens] = tree_positions.view(-1) - self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1) + self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens) + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: num_input_tokens = num_tokens + cudagraph_runtime_mode = CUDAGraphMode.NONE # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -479,15 +779,15 @@ class EagleProposer: # Get the output hidden states for the draft tokens. draft_hidden_states = hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] draft_last_hidden_states = last_hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] # Get the output logits for the draft tokens. logits = self.model.compute_logits( - draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1), - None, + draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1) ) # Sample a draft token for each child at the next tree level. @@ -495,25 +795,24 @@ class EagleProposer: if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view( - batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list.append(draft_token_ids) # Update the # drafts counters for the next tree level. - level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts + level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, - # [batch_size] - num_rejected_tokens: torch.Tensor + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. 
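# Illustrative numpy-only sketch with assumed toy values (the concrete
# query_start_loc and rejection counts are made up; this is not vLLM code):
# it mirrors the np.cumsum / np.repeat construction used in prepare_inputs
# below to pick, for each request, the first (q_i - n_i) token indices out of
# the old flattened token layout.
import numpy as np

query_start_loc = np.array([0, 4, 10, 14])  # assumed [0, q1, q1+q2, q1+q2+q3]
num_rejected = np.array([2, 2, 1])          # assumed [n1, n2, n3]

query_len = np.diff(query_start_loc)        # [q1, q2, q3] = [4, 6, 4]
new_len = query_len - num_rejected          # tokens kept per request
new_start = np.zeros_like(query_start_loc)
np.cumsum(new_len, out=new_start[1:])       # [0, 2, 6, 9]

# Per-request offsets 0..new_len-1, shifted by each request's old start.
offsets = np.arange(new_start[-1]) - np.repeat(new_start[:-1], new_len)
token_indices = offsets + np.repeat(query_start_loc[:-1], new_len)
print(token_indices)  # [ 0  1  4  5  6  7 10 11 12]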
@@ -534,14 +833,18 @@ class EagleProposer: # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) + device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -551,7 +854,8 @@ class EagleProposer: new_query_start_loc_cpu = torch.zeros( query_start_loc_cpu.shape, dtype=torch.int32, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) new_query_start_loc_np = new_query_start_loc_cpu.numpy() np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) @@ -561,36 +865,36 @@ class EagleProposer: # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat( + new_query_start_loc_np[:-1], new_num_tokens_per_req_np + ) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded + token_offests = ( + self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded + ) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np + ) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -598,65 +902,166 @@ class EagleProposer: block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) return spec_common_attn_metadata, token_indices + def get_model_name(self, model: nn.Module) -> str: + if hasattr(model, "module"): # multi-GPU + model = model.module + return model.__class__.__name__ + def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() + ) + # FIXME: support hybrid kv for draft model + target_indexer_layer_names = set( + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ).keys() + ) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, model_config=draft_model_config + ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() + - target_attn_layer_names + ) + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ) + draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names self.attn_layer_names = list(draft_attn_layer_names) + self.indexer_layer_names = list(draft_indexer_layer_names) + + if self.indexer_layer_names: + first_layer = self.indexer_layer_names[0] + self.draft_indexer_metadata_builder = ( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( + indexer_layers[first_layer].get_kv_cache_spec(), + self.indexer_layer_names, + self.vllm_config, + self.device, + ) + ) + else: + self.draft_indexer_metadata_builder = None + + if self.supports_mm_inputs: + # Even if the target model is multimodal, we can also use + # text-only draft models + try: + dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device) + self.model.get_input_embeddings( + dummy_input_ids, multimodal_embeddings=None + ) + except (NotImplementedError, AttributeError, TypeError): + logger.warning( + "Draft model does not support multimodal inputs, " + "falling back to text-only mode" + ) + self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + if ( + self.get_model_name(target_model) + == "Qwen2_5_VLForConditionalGeneration" + ): + self.model.config.image_token_index = target_model.config.image_token_id + else: + self.model.config.image_token_index = ( + target_model.config.image_token_index + ) target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == 
target_language_model.model.embed_tokens.weight.shape: - logger.info( - "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") - del self.model.model.embed_tokens - self.model.model.embed_tokens = ( - target_language_model.model.embed_tokens) + if get_pp_group().world_size == 1: + if hasattr(target_language_model.model, "embed_tokens"): + target_embed_tokens = target_language_model.model.embed_tokens + elif hasattr(target_language_model.model, "embedding"): + target_embed_tokens = target_language_model.model.embedding + else: + raise AttributeError( + "Target model does not have 'embed_tokens' or 'embedding' attribute" + ) + + # Check if shapes match and we found the embedding + eagle_shape = self.model.model.embed_tokens.weight.shape + target_shape = target_embed_tokens.weight.shape + if eagle_shape == target_shape: + logger.info( + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model." + ) + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_embed_tokens + else: + logger.info( + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model." + ) else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ - hasattr(target_language_model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") - self.model.lm_head = target_language_model.lm_head + if self.vllm_config.speculative_config.method != "eagle3": + if hasattr(target_language_model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_language_model.lm_head + else: + if ( + hasattr(self.model, "lm_head") + and hasattr(target_language_model, "lm_head") + and self.model.lm_head.weight.shape + == target_language_model.lm_head.weight.shape + ): + logger.info( + "Assuming the EAGLE head shares the same lm_head" + " with the target model." + ) + del self.model.lm_head + self.model.lm_head = target_language_model.lm_head + else: + logger.info( + "The EAGLE head's lm_head will be loaded separately" + " from the target model." + ) @torch.inference_mode() def dummy_run( self, num_tokens: int, + use_cudagraphs=True, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): - if self.is_multimodal_model: + if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE + if use_cudagraphs + else CUDAGraphMode.NONE, + ): + if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: @@ -665,13 +1070,37 @@ class EagleProposer: self.model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=self._get_positions(num_tokens), hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) - def validate_same_kv_cache_group(self, - kv_cache_config: KVCacheConfig) -> None: + def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: + """Find and return the attention metadata builders for EAGLE layers. 
+ + Returns: + The metadata builders for EAGLE layers. + + Raises: + AssertionError: If no metadata builders are found for EAGLE layers. + """ + builder = None + chosen_layer = self.attn_layer_names[0] + + for kv_cache_group in self.runner.attn_groups: + for attn_group in kv_cache_group: + if chosen_layer in attn_group.layer_names: + builder = attn_group.get_metadata_builder() + break + if builder is not None: + break + + assert builder is not None, ( + "Failed to find attention metadata builder for EAGLE layers." + ) + return builder + + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ Validate that all eagle layers belong to the same KVCacheGroup. Need this assumption to ensure all eagle layers can use the @@ -682,12 +1111,17 @@ class EagleProposer: for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( - set([ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + assert ( + len( + set( + [ + kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + ] + ) + ) + == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 3e90179e78d99..150dde177ce8d 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -27,10 +27,9 @@ class MedusaProposer: # Save config parameters self.vllm_config = vllm_config self.device = device - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) - self.hidden_size = vllm_config.speculative_config.\ - draft_model_config.get_hidden_size( + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.hidden_size = ( + vllm_config.speculative_config.draft_model_config.get_hidden_size() ) self.dtype = vllm_config.model_config.dtype @@ -41,7 +40,7 @@ class MedusaProposer: ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) - logits = self.model.compute_logits(blocks, None) + logits = self.model.compute_logits(blocks) # Get draft tokens and transpose the result # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU @@ -51,16 +50,19 @@ class MedusaProposer: def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag + with set_model_tag("medusa_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=self.vllm_config. 
- speculative_config.draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.speculative_config.draft_model_config, + ) @torch.inference_mode() def dummy_run(self, num_tokens: int) -> None: - hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device, + ) + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model(hidden_states) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index b1efb40612d54..d0695244cb164 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -8,7 +8,6 @@ import torch @dataclass class SpecDecodeMetadata: - # [num_tokens] draft_token_ids: torch.Tensor # [batch_size] @@ -36,22 +35,19 @@ class SpecDecodeMetadata: flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) - draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, - dtype=torch.int32, - device=device) + draft_token_ids_tensor = torch.tensor( + flattened_draft_token_ids, dtype=torch.int32, device=device + ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) - cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( - device) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) - target_logits_indices = torch.zeros(num_tokens, - dtype=torch.int32, - device=device) - bonus_logits_indices = torch.zeros(batch_size, - dtype=torch.int32, - device=device) - logits_indices = torch.zeros(num_tokens + batch_size, - dtype=torch.int32, - device=device) + target_logits_indices = torch.zeros( + num_tokens, dtype=torch.int32, device=device + ) + bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device) + logits_indices = torch.zeros( + num_tokens + batch_size, dtype=torch.int32, device=device + ) return cls( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index b4bc3058c570a..89a8a11a3d560 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from dataclasses import dataclass, field from typing import Optional @@ -30,8 +31,10 @@ class SpecDecodingStats: @classmethod def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": - return cls(num_spec_tokens=num_spec_tokens, - num_accepted_tokens_per_pos=[0] * num_spec_tokens) + return cls( + num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens, + ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_drafts += 1 @@ -58,14 +61,15 @@ class SpecDecodingLogging: self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] self.accepted_tokens_per_pos_lists: list[list[int]] = [] + self.last_log_time = time.monotonic() def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) - self.num_accepted_tokens.append( - spec_decoding_stats.num_accepted_tokens) + 
self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens) self.accepted_tokens_per_pos_lists.append( - spec_decoding_stats.num_accepted_tokens_per_pos) + spec_decoding_stats.num_accepted_tokens_per_pos + ) def log(self, log_fn=logger.info): if not self.num_drafts: @@ -73,9 +77,19 @@ class SpecDecodingLogging: num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) + draft_throughput = 0 + accepted_throughput = 0 - draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * - 100 if num_draft_tokens > 0 else float("nan")) + elapsed_time = time.monotonic() - self.last_log_time + if elapsed_time > 0: + draft_throughput = num_draft_tokens / elapsed_time + accepted_throughput = num_accepted_tokens / elapsed_time + + draft_acceptance_rate = ( + num_accepted_tokens / num_draft_tokens * 100 + if num_draft_tokens > 0 + else float("nan") + ) # Conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) @@ -86,16 +100,20 @@ class SpecDecodingLogging: log_fn( "SpecDecoding metrics: " - "Draft acceptance rate: %.1f%%, " "Mean acceptance length: %.2f, " + "Accepted throughput: %.2f tokens/s, " + "Drafted throughput: %.2f tokens/s, " "Accepted: %d tokens, " "Drafted: %d tokens, " - "Per-position acceptance rate: %s", - draft_acceptance_rate, + "Per-position acceptance rate: %s, " + "Avg Draft acceptance rate: %.1f%%", mean_acceptance_length, + accepted_throughput, + draft_throughput, num_accepted_tokens, num_draft_tokens, rates_str, + draft_acceptance_rate, ) self.reset() @@ -127,52 +145,81 @@ class SpecDecodingProm: self, speculative_config: Optional[SpeculativeConfig], labelnames: list[str], - labelvalues: list[str], + per_engine_labelvalues: dict[int, list[str]], ): self.spec_decoding_enabled = speculative_config is not None if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts = \ - self._counter_cls( - name="vllm:spec_decode_num_drafts", - documentation="Number of spec decoding drafts.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_draft_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_draft_tokens", - documentation="Number of draft tokens.", - labelnames=labelnames,).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_accepted_tokens", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) + counter_drafts = self._counter_cls( + name="vllm:spec_decode_num_drafts", + documentation="Number of spec decoding drafts.", + labelnames=labelnames, + ) + self.counter_spec_decode_num_drafts = make_per_engine( + counter_drafts, per_engine_labelvalues + ) + + counter_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens", + documentation="Number of draft tokens.", + labelnames=labelnames, + ) + self.counter_spec_decode_num_draft_tokens = make_per_engine( + counter_draft_tokens, per_engine_labelvalues + ) + + counter_accepted_tokens = self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens", + documentation="Number of accepted tokens.", + labelnames=labelnames, + ) + self.counter_spec_decode_num_accepted_tokens = make_per_engine( + counter_accepted_tokens, per_engine_labelvalues + ) assert speculative_config is not None - num_spec_tokens = (speculative_config.num_speculative_tokens - if self.spec_decoding_enabled else 0) + num_spec_tokens 
= ( + speculative_config.num_speculative_tokens + if self.spec_decoding_enabled + else 0 + ) pos_labelnames = labelnames + ["position"] base_counter = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens_per_pos", documentation="Accepted tokens per draft position.", labelnames=pos_labelnames, ) - self.counter_spec_decode_num_accepted_tokens_per_pos: list[ - prometheus_client.Counter] = [] - for pos in range(num_spec_tokens): - pos_labelvalues = labelvalues + [str(pos)] - self.counter_spec_decode_num_accepted_tokens_per_pos.append( - base_counter.labels(*pos_labelvalues)) + self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ + int, list[prometheus_client.Counter] + ] = { + idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)] + for idx, lv in per_engine_labelvalues.items() + } - def observe(self, spec_decoding_stats: SpecDecodingStats): + def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0): if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts) - self.counter_spec_decode_num_draft_tokens.inc( - spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( - spec_decoding_stats.num_accepted_tokens) + self.counter_spec_decode_num_drafts[engine_idx].inc( + spec_decoding_stats.num_drafts + ) + self.counter_spec_decode_num_draft_tokens[engine_idx].inc( + spec_decoding_stats.num_draft_tokens + ) + self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( + spec_decoding_stats.num_accepted_tokens + ) for pos, counter in enumerate( - self.counter_spec_decode_num_accepted_tokens_per_pos): + self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] + ): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) + + +def make_per_engine( + counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]] +): + """Create a counter for each label value.""" + return { + idx: counter.labels(*labelvalues) + for idx, labelvalues in per_engine_labelvalues.items() + } diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index fbcf2cb50d371..e2f83cb24aa90 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +import os import numpy as np -from numba import jit +from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig class NgramProposer: - def __init__(self, vllm_config: VllmConfig): assert vllm_config.speculative_config is not None assert vllm_config.speculative_config.prompt_lookup_min is not None @@ -26,55 +25,190 @@ class NgramProposer: # Maximum length of the model. self.max_model_len = vllm_config.model_config.max_model_len + # Pre-allocate buffers for numba batch propose. + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32) + self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) + + # Threshold of total number of tokens in the batch to enable + # multi-threading in numba batch propose. + self.num_tokens_threshold = 8192 + tp_size = vllm_config.parallel_config.tensor_parallel_size + cpu_count = os.cpu_count() + # Max number of threads for numba parallel processing. 
+ if cpu_count: + # Divide by 2 to use physical cores + # and not logical cores (hyper-threading). + # Cap the number of threads to 8 to avoid using too many threads + # since other components like frontend (incl tokenization) + # and Structured Outputs also use multiple threads. + # TODO(ekagra-ranjan): bump up the cap from 1 to 8 + # when TP parallelization for ngram is implemented. + self.num_numba_thread_available = min(1, (cpu_count // 2)) + # Divide by tp_size to ensure each tensor parallel rank + # has some threads since all ranks will run this. + self.num_numba_thread_available //= tp_size + else: + self.num_numba_thread_available = 1 + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose(np.zeros(1024, dtype=np.int32)) + self.propose( + [[]] * 1024, + [""] * 1024, + np.zeros(1024, dtype=np.int32), + np.zeros((1024, self.max_model_len), dtype=np.int32), + set(), + ) + + def batch_propose( + self, + num_requests: int, + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + ) -> list[list[int]]: + """Batch version of ngram proposer using numba for acceleration. + + Args: + valid_ngram_requests: + Set of indices of requests that need ngram proposals. + num_tokens_no_spec: + Numpy array of shape (batch_size,) representing the number + of tokens without speculative tokens for each request. + token_ids_cpu: + Numpy array of shape (batch_size, max_model_len) + representing the token IDs for each request. + + Returns: + list[list[int]]: + A list where each element is a list of proposed + token IDs for the corresponding request. + """ + draft_token_ids: list[list[int]] = [] + + # Only run batch propose if there are requests needing ngram proposals. + # avoid calling numba function with empty list which causes error + # ValueError: cannot compute fingerprint of empty list + if num_ngram_requests := len(valid_ngram_requests): + original_num_numba_threads = get_num_threads() + # Ensure we use at least one thread. + # If total tokens is small, using multiple threads + # may slow down due to overhead. + total_tokens = np.sum(num_tokens_no_spec) + if total_tokens >= self.num_tokens_threshold: + final_num_threads = max( + 1, min(self.num_numba_thread_available, num_ngram_requests) + ) + set_num_threads(final_num_threads) + else: + set_num_threads(1) + + batch_propose_numba( + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + self.min_n, + self.max_n, + self.max_model_len, + self.k, + self.valid_ngram_draft, + self.valid_ngram_num_drafts, + ) + + # Restore original number of threads. + set_num_threads(original_num_numba_threads) + + for i in range(num_requests): + if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0: + draft_token_ids.append( + self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist() + ) + else: + draft_token_ids.append([]) + + return draft_token_ids def propose( self, - context_token_ids: np.ndarray, - ) -> Optional[np.ndarray]: - """Proposes the next sequence of tokens based on n-gram pattern - matching in the context. The function finds matches of the last n - tokens in the previous context, and returns k tokens that followed - that match. - - Args: - context_token_ids: Numpy array of token IDs representing the - context sequence. 
+ sampled_token_ids: list[list[int]], + req_ids: list[str], + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + spec_decode_unsupported_reqs: set, + ) -> list[list[int]]: + # find which requests need ngram proposals + valid_ngram_requests = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + continue - Returns: - np.ndarray: The sequence of tokens that followed - the matched n-gram in the context. - None: If no matching n-gram pattern is found. + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in spec_decode_unsupported_reqs: + continue - Example: - If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and - k = 4: - - The last 3 (= max_n) tokens [4,2,3] cannot find a match. - - The last 2 tokens [2,3] will be matched against the previous - 4 tokens [1,2,3,4]. - - Finding a match of [2,3] would return the tokens that - followed that pattern. Here we will return [4,2,3] because - we only have three tokens after the match. - """ - # TODO(woosuk): Optimize this. - return _find_longest_matched_ngram_and_propose_tokens( - origin_tokens=context_token_ids, - min_ngram=self.min_n, - max_ngram=self.max_n, - max_model_len=self.max_model_len, - k=self.k) + num_tokens = num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + continue + + valid_ngram_requests.append(i) + + draft_token_ids = self.batch_propose( + len(sampled_token_ids), + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + ) + + return draft_token_ids def load_model(self, *args, **kwargs): # No model to load. pass +@njit(parallel=True) +def batch_propose_numba( + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + min_n: int, + max_n: int, + max_model_len: int, + k: int, + valid_ngram_draft: np.ndarray, + valid_ngram_num_drafts: np.ndarray, +): + for i in prange(len(valid_ngram_requests)): + idx = valid_ngram_requests[i] + num_tokens = num_tokens_no_spec[idx] + context_token_ids = token_ids_cpu[idx, :num_tokens] + drafter_output = _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=context_token_ids, + min_ngram=min_n, + max_ngram=max_n, + max_model_len=max_model_len, + k=k, + ) + + valid_ngram_num_drafts[i] = drafter_output.shape[0] + if len(drafter_output): + valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output + + @jit(nopython=True) def _find_longest_matched_ngram_and_propose_tokens( - origin_tokens: np.ndarray, min_ngram: int, max_ngram: int, - max_model_len: int, k: int) -> Optional[np.ndarray]: + origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + max_model_len: int, + k: int, +) -> np.ndarray: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). @@ -84,12 +218,12 @@ def _find_longest_matched_ngram_and_propose_tokens( # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: - return None + return np.empty((0,), dtype=origin_tokens.dtype) # Do not generate draft tokens beyond the max model length. 
k = min(k, max_model_len - total_token) if k <= 0: - return None + return np.empty((0,), dtype=origin_tokens.dtype) # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with @@ -107,7 +241,7 @@ def _find_longest_matched_ngram_and_propose_tokens( longest_ngram = 0 position = 0 - # lps[0] always equal to 0, we starts with index 1 + # lps[0] always equal to 0, we start with index 1 prev_lps = 0 i = 1 while i < total_token: @@ -146,7 +280,7 @@ def _find_longest_matched_ngram_and_propose_tokens( if longest_ngram < min_ngram: # No valid ngram is found - return None + return np.empty((0,), dtype=origin_tokens.dtype) # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] @@ -154,4 +288,4 @@ def _find_longest_matched_ngram_and_propose_tokens( # total_token-1-position+longest_ngram start_position = total_token - 1 - position + longest_ngram k = min(k, total_token - start_position) - return origin_tokens[start_position:start_position + k] + return origin_tokens[start_position : start_position + k] diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 1116179dc5b61..1901c6fc9f14f 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -7,8 +7,10 @@ _SAMPLING_EPS = 1e-5 def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: """True if request is incompatible with speculative decoding""" - return (sampling_params.frequency_penalty != 0.0 - or sampling_params.presence_penalty != 0.0 - or sampling_params.repetition_penalty != 1.0 - or sampling_params.min_p > _SAMPLING_EPS - or sampling_params.logprobs is not None) + return ( + sampling_params.frequency_penalty != 0.0 + or sampling_params.presence_penalty != 0.0 + or sampling_params.repetition_penalty != 1.0 + or sampling_params.min_p > _SAMPLING_EPS + or sampling_params.logprobs is not None + ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3bafa61044abc..1b5e75313d89d 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -4,16 +4,18 @@ from __future__ import annotations import multiprocessing from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, +) from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: @@ -33,11 +35,11 @@ class StructuredOutputManager: """Engine-level manager for structured output requests.""" def __init__(self, vllm_config: VllmConfig): - self.backend: Optional[StructuredOutputBackend] = None - self.reasoner: Optional[ReasoningParser] = None + self.backend: StructuredOutputBackend | None = None + self.reasoner: ReasoningParser | None = None self.vllm_config = vllm_config - self._grammar_bitmask: Optional[torch.Tensor] = None + self._grammar_bitmask: 
torch.Tensor | None = None self._full_mask = torch.tensor(-1, dtype=torch.int32) max_batch_size = self.vllm_config.scheduler_config.max_num_seqs @@ -48,8 +50,7 @@ class StructuredOutputManager: # - at least 1 CPU # - at most half the number of CPUs or 8, whichever is less max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8)) - self.executor_for_fillmask = ThreadPoolExecutor( - max_workers=max_workers) + self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers) if not self.vllm_config.model_config.skip_tokenizer_init: # The default max_workers if not specified is the number of @@ -60,15 +61,15 @@ class StructuredOutputManager: max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) - reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend - if reasoning_backend: + model_config=self.vllm_config.model_config + ) + reasoning_parser = ( + self.vllm_config.structured_outputs_config.reasoning_parser + ) + if reasoning_parser: reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) + reasoning_parser + ) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: @@ -76,16 +77,19 @@ class StructuredOutputManager: return if TYPE_CHECKING: - assert request.sampling_params is not None and \ - request.sampling_params.guided_decoding is not None + assert ( + request.sampling_params is not None + and request.sampling_params.structured_outputs is not None + ) # Initialize the backend the first time it is needed. # # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). 
+ # _backend is set in Processor._validate_structured_output if self.backend is None: assert request.sampling_params is not None - backend = request.sampling_params.guided_decoding.backend + backend = request.sampling_params.structured_outputs._backend vocab_size = self.vllm_config.model_config.get_vocab_size() if backend == "xgrammar": self.backend = XgrammarBackend( @@ -100,17 +104,25 @@ class StructuredOutputManager: vocab_size=vocab_size, ) elif backend == "outlines": - from vllm.v1.structured_output.backend_outlines import ( - OutlinesBackend) + from vllm.v1.structured_output.backend_outlines import OutlinesBackend self.backend = OutlinesBackend( self.vllm_config, tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "lm-format-enforcer": + from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 + LMFormatEnforcerBackend, + ) + + self.backend = LMFormatEnforcerBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: - raise ValueError( - f"Unsupported structured output backend: {backend}") + raise ValueError(f"Unsupported structured output backend: {backend}") grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] @@ -156,15 +168,16 @@ class StructuredOutputManager: requests: dict[str, Request], structured_output_request_ids: dict[str, int], scheduled_spec_decode_tokens: dict[str, list[int]], - ) -> Optional[npt.NDArray[np.int32]]: + ) -> npt.NDArray[np.int32] | None: # Prepare the structured output bitmask for this batch. if not structured_output_request_ids: return None max_num_spec_tokens = 0 if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = \ + max_num_spec_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) if self._grammar_bitmask is None: assert self.backend is not None @@ -173,22 +186,23 @@ class StructuredOutputManager: # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the # bonus token / non-speculative token. - self._grammar_bitmask = \ - self.backend.allocate_token_bitmask( - max_batch_size * (1 + max_num_spec_tokens)) + self._grammar_bitmask = self.backend.allocate_token_bitmask( + max_batch_size * (1 + max_num_spec_tokens) + ) # Generate a batched bitmask for all structured output requests. # When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. # These are stored inline in the tensor and unpacked by the gpu runner. 
cumulative_index = 0 - ordered_seq = sorted(structured_output_request_ids.items(), - key=lambda x: x[1]) + ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) # Optimized parallel filling of bitmasks for # non-spec, large-batch-size cases - if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \ - max_num_spec_tokens == 0: + if ( + len(ordered_seq) > self.fill_bitmask_parallel_threshold + and max_num_spec_tokens == 0 + ): promises = [] batch = [] for req_id, _ in ordered_seq: @@ -199,8 +213,9 @@ class StructuredOutputManager: assert structured_output_request.grammar is not None apply_bitmask = self.should_fill_bitmask(request) - batch.append((structured_output_request.grammar, - cumulative_index, apply_bitmask)) + batch.append( + (structured_output_request.grammar, cumulative_index, apply_bitmask) + ) if len(batch) == self.fill_bitmask_parallel_batch_size: promises.append(self._async_submit_fill_bitmask(batch)) batch = [] @@ -226,18 +241,28 @@ class StructuredOutputManager: state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): - self._fill_bitmasks([(structured_output_request.grammar, - cumulative_index, apply_bitmask)]) + self._fill_bitmasks( + [ + ( + structured_output_request.grammar, + cumulative_index, + apply_bitmask, + ) + ] + ) - if apply_bitmask and token is not None and \ - not structured_output_request.grammar.is_terminated(): + if ( + apply_bitmask + and token is not None + and not structured_output_request.grammar.is_terminated() + ): assert structured_output_request.grammar.accept_tokens( - req_id, [token]) + req_id, [token] + ) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: - structured_output_request.grammar.rollback( - state_advancements) + structured_output_request.grammar.rollback(state_advancements) bitmask_tensor = self._grammar_bitmask if cumulative_index < bitmask_tensor.shape[0]: @@ -252,8 +277,9 @@ class StructuredOutputManager: if self.reasoner is not None: assert request.structured_output_request is not None if request.structured_output_request.reasoning_ended is None: - request.structured_output_request.reasoning_ended = \ + request.structured_output_request.reasoning_ended = ( self.reasoner.is_reasoning_end(request.prompt_token_ids) + ) return request.structured_output_request.reasoning_ended return True diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 02e7fc33f517d..081cdfdc9932b 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -7,16 +7,18 @@ import copy import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) from vllm.v1.structured_output.request import get_structured_output_key if TYPE_CHECKING: @@ -26,8 +28,7 @@ if TYPE_CHECKING: else: llguidance = LazyLoader("llguidance", globals(), "llguidance") llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") - llguidance_torch = 
LazyLoader("llguidance.torch", globals(), - "llguidance.torch") + llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch") logger = init_logger(__name__) @@ -36,16 +37,18 @@ def _walk_json_for_additional_properties(data: object): if isinstance(data, dict): for value in data.values(): _walk_json_for_additional_properties(value) - if 'additionalProperties' not in data and \ - ('properties' in data or 'patternProperties' in data): - data['additionalProperties'] = False + if "additionalProperties" not in data and ( + "properties" in data or "patternProperties" in data + ): + data["additionalProperties"] = False elif isinstance(data, list): for item in data: _walk_json_for_additional_properties(item) def process_for_additional_properties( - guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + guide_json: Union[str, dict[str, Any]], +) -> dict[str, Any]: if isinstance(guide_json, str): guide_json_obj = json.loads(guide_json) else: @@ -57,21 +60,27 @@ def process_for_additional_properties( @dataclass class GuidanceBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ - self.vllm_config.decoding_config.disable_any_whitespace - self.disable_additional_properties = \ - self.vllm_config.decoding_config.disable_additional_properties + self.disable_any_whitespace = ( + self.vllm_config.structured_outputs_config.disable_any_whitespace + ) + self.disable_additional_properties = ( + self.vllm_config.structured_outputs_config.disable_additional_properties + ) self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, grammar_spec, self.disable_any_whitespace, - self.disable_additional_properties) + request_type, + grammar_spec, + self.disable_any_whitespace, + self.disable_additional_properties, + ) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -90,7 +99,8 @@ class GuidanceBackend(StructuredOutputBackend): def allocate_token_bitmask(self, max_num_seqs: int): return llguidance_torch.allocate_token_bitmask( - max_num_seqs, self.ll_tokenizer.vocab_size) + max_num_seqs, self.ll_tokenizer.vocab_size + ) def destroy(self): pass @@ -178,15 +188,17 @@ def serialize_guidance_grammar( disable_any_whitespace: bool = False, disable_additional_properties: bool = False, ) -> str: - - def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: + def _process_schema( + grammar_spec: Union[str, dict[str, Any]], + ) -> str: if disable_additional_properties: grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) if request_type == StructuredOutputOptions.JSON: return _process_schema(grammar_spec) @@ -195,7 +207,8 @@ def serialize_guidance_grammar( '{"type": "object"}', defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) else: if request_type == StructuredOutputOptions.REGEX: tp = "regex" @@ -215,29 +228,32 @@ def serialize_guidance_grammar( trig = next((t for t in triggers if begin.startswith(t)), None) if trig is None: raise ValueError( - f"Trigger {begin} not found in triggers {triggers}") + 
f"Trigger {begin} not found in triggers {triggers}" + ) tags.append( llguidance.StructTag( trigger=trig, begin=s["begin"], grammar=_process_schema(s["schema"]), end=s["end"], - )) + ) + ) if not tags: - raise ValueError( - "No structural tags found in the grammar spec.") + raise ValueError("No structural tags found in the grammar spec.") return llguidance.StructTag.to_grammar(tags) else: - logger.error("Validation should have already occurred. " - "Please file an issue.") - raise ValueError("grammar is not of valid supported types. " - f"({request_type!s})") + logger.error( + "Validation should have already occurred. Please file an issue." + ) + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})" + ) return llguidance.grammar_from(tp, grammar_spec) def validate_guidance_grammar( - sampling_params: SamplingParams, - tokenizer: Optional[llguidance.LLTokenizer] = None) -> None: + sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None +) -> None: tp, grm = get_structured_output_key(sampling_params) guidance_grm = serialize_guidance_grammar(tp, grm) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py new file mode 100644 index 0000000000000..d9e484092d6ab --- /dev/null +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from functools import lru_cache +from typing import TYPE_CHECKING + +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) + +if TYPE_CHECKING: + import lmformatenforcer + import lmformatenforcer.integrations.vllm as lmfe_vllm +else: + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), "lmformatenforcer") + lmfe_vllm = LazyLoader( + "lmformatenforcer.integrations.vllm", + globals(), + "lmformatenforcer.integrations.vllm", + ) + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase, vocab_size: int +) -> lmfe_vllm.TokenEnforcerTokenizerData: + return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( + tokenizer, use_bitmask=True, vocab_size=vocab_size + ) + + +@dataclass +class LMFormatEnforcerGrammar(StructuredOutputGrammar): + token_enforcer: lmformatenforcer.TokenEnforcer + current_tokens_prefix: list[int] = field(default_factory=list) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + original_len = len(self.current_tokens_prefix) + for token in tokens: + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + ).is_token_allowed(token): + # Rollback partial updates to ensure atomicity. 
+ del self.current_tokens_prefix[original_len:] + return False + self.current_tokens_prefix.append(token) + return True + + def validate_tokens(self, tokens: list[int]) -> list[int]: + for prefix_length in range(len(tokens)): + prefix = tokens[:prefix_length] + next_token = tokens[prefix_length] + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + prefix + ).is_token_allowed(next_token): + break + else: + return tokens + + return tokens[:prefix_length] + + def rollback(self, num_tokens: int) -> None: + self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens] + + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + allowed_tokens = self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + ) + bitmask[batch_index] = allowed_tokens.allowed_tokens + + def is_terminated(self) -> bool: + # We are considered terminated if the prefix ends with eos_token_id + return_value = ( + len(self.current_tokens_prefix) > 0 + and self.current_tokens_prefix[-1] == self.token_enforcer.eos_token_id + ) + return return_value + + def reset(self): + self.current_tokens_prefix = [] + + +@dataclass +class LMFormatEnforcerBackend(StructuredOutputBackend): + def __post_init__(self): + self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + self.tokenizer, self.vocab_size + ) + + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: + character_level_parser: lmformatenforcer.CharacterLevelParser + if request_type == StructuredOutputOptions.JSON: + spec_dict = json.loads(grammar_spec) + character_level_parser = lmformatenforcer.JsonSchemaParser(spec_dict) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + character_level_parser = lmformatenforcer.JsonSchemaParser(None) + elif request_type == StructuredOutputOptions.REGEX: + character_level_parser = lmformatenforcer.RegexParser(grammar_spec) + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + character_level_parser = lmformatenforcer.UnionParser( + [lmformatenforcer.StringParser(choice) for choice in choices] + ) + else: + raise ValueError( + f"Invalid request type for LM Format Enforcer backend({request_type!s})" + ) + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None + else 0 + ) + + if max_rollback_tokens > 0: + raise ValueError( + "LM Format Enforcer backend does not support speculative tokens" + ) + + token_enforcer = lmformatenforcer.TokenEnforcer( + tokenizer_data=self.tokenizer_data, + parser=character_level_parser, + ) + return LMFormatEnforcerGrammar(token_enforcer) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +def validate_structured_output_request_lm_format_enforcer(params: SamplingParams): + if params.structured_outputs is None: + return + + so_params = params.structured_outputs + + if so_params.regex: + return + elif so_params.json: + if isinstance(so_params.json, str): + try: + # make sure schema is valid json + json.loads(so_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + json.dumps(so_params.json) + except Exception as e: + raise ValueError( + f"Error serializing structured 
outputs jsonschema: {e}" + ) from e + return + elif so_params.choice: + return + elif so_params.grammar: + raise ValueError( + "LM Format Enforcer structured outputs backend " + "does not support grammar specifications" + ) diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py index 572e4984480fa..c9875337179ef 100644 --- a/vllm/v1/structured_output/backend_outlines.py +++ b/vllm/v1/structured_output/backend_outlines.py @@ -15,20 +15,23 @@ from regex import escape as regex_escape from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (OutlinesVocabulary, - get_outlines_cache, - get_outlines_vocabulary) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + OutlinesVocabulary, + get_outlines_cache, + get_outlines_vocabulary, +) if TYPE_CHECKING: import outlines_core as oc import outlines_core.json_schema as json_schema else: oc = LazyLoader("oc", globals(), "outlines_core") - json_schema = LazyLoader("json_schema", globals(), - "outlines_core.json_schema") + json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema") # Python 3.11+ sre_parse and sre_constants # are deprecated, so we must import them from re @@ -46,13 +49,13 @@ else: @dataclass class OutlinesBackend(StructuredOutputBackend): - def __post_init__(self): self.vocabulary = get_outlines_vocabulary(self.tokenizer) self.cache = get_outlines_cache() - def _compile_index(self, regex_string: str, - vocabulary: OutlinesVocabulary) -> oc.Index: + def _compile_index( + self, regex_string: str, vocabulary: OutlinesVocabulary + ) -> oc.Index: cache_key = f"{vocabulary._hash}_{regex_string}" if cache_key in self.cache: return self.cache[cache_key] @@ -62,8 +65,9 @@ class OutlinesBackend(StructuredOutputBackend): return index - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: regex = json_schema.build_regex_from_schema(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -79,10 +83,13 @@ class OutlinesBackend(StructuredOutputBackend): index = self._compile_index(regex, self.vocabulary) max_rollback_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config is not None else 0) - return OutlinesGrammar(vocab_size=self.vocab_size, - guide=oc.Guide( - index, max_rollback=max_rollback_tokens)) + if self.vllm_config.speculative_config is not None + else 0 + ) + return OutlinesGrammar( + vocab_size=self.vocab_size, + guide=oc.Guide(index, max_rollback=max_rollback_tokens), + ) def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: return torch.full( @@ -98,20 +105,15 @@ class OutlinesBackend(StructuredOutputBackend): @dataclass class OutlinesGrammar(StructuredOutputGrammar): - vocab_size: int guide: oc.Guide = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) # 
outlines_core signals done on DFA accept; vLLM expects done after EOS. # We delay the finished flag by one step so EOS can still be emitted. - _prev_finished: bool = field(default=False, - init=False, - repr=False, - hash=False) + _prev_finished: bool = field(default=False, init=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """Accepts a list of tokens and advances the FSM. @@ -142,8 +144,7 @@ class OutlinesGrammar(StructuredOutputGrammar): def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: mask = bitmask[idx] - self.guide.write_mask_into(mask.data_ptr(), mask.numel(), - mask.element_size()) + self.guide.write_mask_into(mask.data_ptr(), mask.numel(), mask.element_size()) def is_terminated(self) -> bool: curr = self.guide.is_finished() @@ -158,37 +159,39 @@ class OutlinesGrammar(StructuredOutputGrammar): def validate_structured_output_request_outlines(params: SamplingParams): - if params.guided_decoding is None: + if params.structured_outputs is None: return - gd_params = params.guided_decoding + so_params = params.structured_outputs - if gd_params.regex: - validate_regex_is_buildable(gd_params.regex) - elif gd_params.json: - if isinstance(gd_params.json, str): + if so_params.regex: + validate_regex_is_buildable(so_params.regex) + elif so_params.json: + if isinstance(so_params.json, str): try: # make sure schema is valid json - json.loads(gd_params.json) - schema = gd_params.json + json.loads(so_params.json) + schema = so_params.json except json.JSONDecodeError as e: raise ValueError("Invalid JSON grammar specification.") from e else: try: - schema = json.dumps(gd_params.json) + schema = json.dumps(so_params.json) except Exception as e: raise ValueError( - f"Error serializing guided decoding jsonschema: {e}" + f"Error serializing structured outputs jsonschema: {e}" ) from e pattern = json_schema.build_regex_from_schema(schema) validate_regex_is_buildable(pattern) - elif gd_params.choice: - choices = [regex_escape(str(choice)) for choice in gd_params.choice] + elif so_params.choice: + choices = [regex_escape(str(choice)) for choice in so_params.choice] regex = "(" + "|".join(choices) + ")" validate_regex_is_buildable(regex) - elif gd_params.grammar: - raise ValueError("Outlines guided decoding backend " - "does not support grammar specifications") + elif so_params.grammar: + raise ValueError( + "Outlines structured outputs backend " + "does not support grammar specifications" + ) def _prefix_needs_context(parsed) -> bool: @@ -196,7 +199,7 @@ def _prefix_needs_context(parsed) -> bool: def subpattern_consumes(parsed) -> bool: """Return True if subpattern can consume at least one character.""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # literal, character class, or dot always consumes if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): @@ -212,17 +215,18 @@ def _prefix_needs_context(parsed) -> bool: if any(subpattern_consumes(br) for br in branches): return True # grouped subpattern: recurse into its contents - elif ttype == sre_parse.SUBPATTERN and subpattern_consumes( - tval[3]): + elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]): return True # No consumers, return False return False - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # Direct anchors or look-around - if ttype == sre_parse.AT 
or ttype in (sre_constants.ASSERT, - sre_constants.ASSERT_NOT): + if ttype == sre_parse.AT or ttype in ( + sre_constants.ASSERT, + sre_constants.ASSERT_NOT, + ): return True # Nested subpattern: check @@ -261,9 +265,8 @@ def _prefix_needs_context(parsed) -> bool: def _check_unsupported(parsed) -> None: """Check for regex features unsupported by regex-automata""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: - # backreference if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS): raise ValueError("Backreferences are unsupported.") @@ -274,8 +277,7 @@ def _check_unsupported(parsed) -> None: # unicode word boundaries elif ttype == sre_parse.AT: - if tval in (sre_constants.AT_BOUNDARY, - sre_constants.AT_NON_BOUNDARY): + if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY): raise ValueError("Unicode word boundaries are unsupported.") elif ttype == sre_parse.BRANCH: @@ -306,15 +308,17 @@ def validate_regex_is_buildable(pattern: str) -> None: _check_unsupported(parsed) except ValueError as e: raise ValueError( - f"Regex uses unsupported feature for guided decoding: {e}. " + f"Regex uses unsupported feature for structured outputs: {e}. " "Only basic matching constructs are supported—lookarounds, " - "backreferences, and unicode boundaries are not.") from e + "backreferences, and unicode boundaries are not." + ) from e if _prefix_needs_context(parsed): raise ValueError( "Regex does not have a anchored universal start state" "This means that the Regex uses anchors (^) or look-arounds " "in a way which requires context before any token is matched." - "Guided decoding needs regexes that can match without needing " + "structured outputs needs regexes that can match without needing " "that context. Try rewriting the pattern without using these " - f"constructs. Pattern:\n{pattern}") + f"constructs. Pattern:\n{pattern}" + ) diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index d500783aa4b30..2051b336e5bf1 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -103,14 +103,15 @@ class StructuredOutputBackend(ABC): vocab_size: int @abstractmethod - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: """ Compiles a grammar specification into a structured output grammar. Args: request_type (StructuredOutputOptions): The type of structured - output request. + output request. grammar_spec (str): The grammar specification to compile. Returns: @@ -124,7 +125,7 @@ class StructuredOutputBackend(ABC): Args: max_num_seqs (int): The maximum number of sequences for which - to allocate the bitmask. + to allocate the bitmask. 
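+
+        Returns:
+            A token bitmask tensor with one row per sequence and 32
+            vocabulary tokens packed into each int32 element.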
""" @abstractmethod diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 5e00f63804162..4b21b2591c589 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -14,12 +14,16 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (choice_as_grammar, - convert_lark_to_ebnf, - grammar_is_likely_lark) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark, +) if TYPE_CHECKING: import xgrammar as xgr @@ -31,40 +35,25 @@ logger = init_logger(__name__) @dataclass class XgrammarBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ - self.vllm_config.decoding_config.disable_any_whitespace + self.disable_any_whitespace = ( + self.vllm_config.structured_outputs_config.disable_any_whitespace + ) if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 - try: - if self.tokenizer.is_tekken: - encoded_vocab = self.tokenizer._vocab - else: - encoded_vocab = [ - token for token, _ in sorted( - self.tokenizer.get_vocab().items(), - key=lambda x: x[1], - ) - ] - stop_token_ids = None - if (hasattr( - self.tokenizer, - "eos_token_id", - ) and self.tokenizer.eos_token_id is not None): - stop_token_ids = [self.tokenizer.eos_token_id] - except AttributeError as e: - raise ValueError( - f"Cannot get the vocabulary of the tokenizer " - f"{type(self.tokenizer)}. The tokenizer should have a " - "get_vocab method.") from e + stop_token_ids = [self.tokenizer.eos_token_id] + + # not self.tokenizer.vocab_size as self.tokenizer.vocab + # collapses all decoded errors into a single token. 
+ self.vocab_size = len(self.tokenizer.vocab) tokenizer_info = xgr.TokenizerInfo( # type: ignore - encoded_vocab=encoded_vocab, + encoded_vocab=self.tokenizer.vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken + else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, @@ -83,18 +72,21 @@ class XgrammarBackend(StructuredOutputBackend): self.num_speculative_tokens = 0 if self.vllm_config.speculative_config is not None: - self.num_speculative_tokens = \ + self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: ctx = self.compiler.compile_json_schema( - grammar_spec, any_whitespace=not self.disable_any_whitespace) + grammar_spec, any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.JSON_OBJECT: ctx = self.compiler.compile_json_schema( - '{"type": "object"}', - any_whitespace=not self.disable_any_whitespace) + '{"type": "object"}', any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.GRAMMAR: ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -106,15 +98,20 @@ class XgrammarBackend(StructuredOutputBackend): begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] - ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) + structural_tag = xgr.StructuralTag.from_legacy_structural_tag( + tags, s_tag["triggers"] + ) + ctx = self.compiler.compile_structural_tag(structural_tag) else: logger.error( "Validation should have already occurred. Please file an issue." ) raise ValueError( - f"grammar is not of valid supported types. ({request_type!s})") + f"grammar is not of valid supported types. ({request_type!s})" + ) return XgrammarGrammar( matcher=xgr.GrammarMatcher( @@ -144,10 +141,9 @@ class XgrammarGrammar(StructuredOutputGrammar): vocab_size: int matcher: xgr.GrammarMatcher = field(hash=False) ctx: xgr.CompiledGrammar = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) _is_terminated: bool = field(default=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: @@ -162,7 +158,10 @@ class XgrammarGrammar(StructuredOutputGrammar): if not self.matcher.accept_token(token): logger.error( "Failed to advance FSM for request %s " - "for tokens %s. Please file an issue.", request_id, token) + "for tokens %s. 
Please file an issue.", + request_id, + token, + ) return False self.num_processed_tokens += 1 self._is_terminated = self.matcher.is_terminated() @@ -214,8 +213,9 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Check for array unsupported keywords if obj.get("type") == "array" and any( - key in obj for key in ("uniqueItems", "contains", - "minContains", "maxContains")): + key in obj + for key in ("uniqueItems", "contains", "minContains", "maxContains") + ): return True # Unsupported keywords for strings @@ -224,8 +224,14 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Unsupported keywords for objects if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", "maxProperties", - "propertyNames", "patternProperties")): + key in obj + for key in ( + "minProperties", + "maxProperties", + "propertyNames", + "patternProperties", + ) + ): return True # Recursively check all nested objects and arrays @@ -248,76 +254,85 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: Raises ValueError if the request is not supported. """ - if sampling_params.guided_decoding is None: + if sampling_params.structured_outputs is None: return - gd_params = sampling_params.guided_decoding + so_params = sampling_params.structured_outputs - if gd_params.regex: + if so_params.regex: try: - xgr.Grammar.from_regex(gd_params.regex) + xgr.Grammar.from_regex(so_params.regex) except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform regex into a grammar: {err}" + ) from err - if gd_params.choice: - choice_grammar = choice_as_grammar(gd_params.choice) + if so_params.choice: + choice_grammar = choice_as_grammar(so_params.choice) try: xgr.Grammar.from_ebnf(choice_grammar) except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err - gd_params.choice = None - gd_params.grammar = choice_grammar + raise ValueError( + "Failed to transform choices into a grammar: {err}" + ) from err + so_params.choice = None + so_params.grammar = choice_grammar return - if gd_params.json: - if isinstance(gd_params.json, str): + if so_params.json: + if isinstance(so_params.json, str): try: - schema = json.loads(gd_params.json) + schema = json.loads(so_params.json) except json.JSONDecodeError as e: raise ValueError("Invalid JSON grammar specification.") from e else: - schema = gd_params.json + schema = so_params.json try: xgr.Grammar.from_json_schema(schema) except Exception as err: - raise ValueError("Failed to transform json schema into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform json schema into a grammar: {err}" + ) from err if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") + raise ValueError( + "The provided JSON schema contains features not supported by xgrammar." + ) return - if gd_params.grammar: - if grammar_is_likely_lark(gd_params.grammar): + if so_params.grammar: + if grammar_is_likely_lark(so_params.grammar): # xgrammar supports EBNF grammars only try: - gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + so_params.grammar = convert_lark_to_ebnf(so_params.grammar) except ValueError as e: raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e + "Failed to convert the grammar from Lark to EBNF. 
" + ) from e # Test parsing EBNF grammar, possibly already converted from Lark try: # parse the grammar, but we aren't compiling it. - xgr.Grammar.from_ebnf(gd_params.grammar) + xgr.Grammar.from_ebnf(so_params.grammar) except Exception as e: raise ValueError("Invalid grammar specification.") from e return - if gd_params.structural_tag: + if so_params.structural_tag: try: - s_tag = json.loads(gd_params.structural_tag) + s_tag = json.loads(so_params.structural_tag) tags = [ xgr.StructuralTagItem( begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] - xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + structural_tag = xgr.StructuralTag.from_legacy_structural_tag( + tags, s_tag["triggers"] + ) + xgr.Grammar.from_structural_tag(structural_tag) except Exception as e: raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index fc365f12573fc..233c7c1e7805d 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -10,18 +10,20 @@ from concurrent.futures._base import TimeoutError from typing import Optional, Union, cast from vllm.sampling_params import SamplingParams -from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, - StructuredOutputKey, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputGrammar, + StructuredOutputKey, + StructuredOutputOptions, +) @dataclasses.dataclass class StructuredOutputRequest: - sampling_params: SamplingParams - _grammar: Optional[Union[Future[StructuredOutputGrammar], - StructuredOutputGrammar]] = None - reasoning_ended: Optional[bool] = None + _grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = ( + None + ) + reasoning_ended: bool | None = None def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports @@ -41,15 +43,17 @@ class StructuredOutputRequest: return self._check_grammar_completion() @property - def grammar(self) -> Optional[StructuredOutputGrammar]: + def grammar(self) -> StructuredOutputGrammar | None: completed = self._check_grammar_completion() - return cast(Optional[StructuredOutputGrammar], - self._grammar) if completed else None + return ( + cast(Optional[StructuredOutputGrammar], self._grammar) + if completed + else None + ) @grammar.setter def grammar( - self, grammar: Union[StructuredOutputGrammar, - Future[StructuredOutputGrammar]] + self, grammar: Union[StructuredOutputGrammar, Future[StructuredOutputGrammar]] ) -> None: self._grammar = grammar @@ -58,9 +62,8 @@ class StructuredOutputRequest: return get_structured_output_key(self.sampling_params) -def get_structured_output_key( - sampling_params: SamplingParams) -> StructuredOutputKey: - params = sampling_params.guided_decoding +def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey: + params = sampling_params.structured_outputs assert params is not None, "params can't be None." 
if params.json is not None: if not isinstance(params.json, str): diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 95319831d5121..b7326847d016d 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -8,7 +8,9 @@ import importlib.metadata import os from typing import TYPE_CHECKING +import numpy as np import regex as re +import torch from cachetools import LRUCache from diskcache import Cache @@ -20,9 +22,13 @@ if TYPE_CHECKING: import outlines_core as oc import transformers.file_utils as file_utils import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2 + import xgrammar as xgr from vllm.transformers_utils.tokenizer import AnyTokenizer + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch else: + xgr = LazyLoader("xgr", globals(), "xgrammar") oc = LazyLoader("oc", globals(), "outlines_core") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") tokenization_gpt2 = LazyLoader( @@ -36,6 +42,85 @@ logger = init_logger(__name__) CACHE = None +def apply_grammar_bitmask( + scheduler_output: SchedulerOutput, + input_batch: InputBatch, + logits: torch.Tensor, + device: torch.device, +) -> None: + """ + Apply grammar bitmask to output logits of the model with xgrammar function. + + Args: + scheduler_output (SchedulerOutput): The result of engine scheduling. + input_batch (InputBatch): The input of model runner. + logits (torch.Tensor): The output logits of model forward. + device (torch.device): The device that model runner running on. + """ + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) + if req_id in scheduler_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. 
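+    # Example: for a batch [reqA (2 spec tokens, no grammar), reqB (grammar,
+    # no spec tokens)], reqB's logits sit at index 3 (batch index 1 plus
+    # reqA's 2 spec-token offsets), so its single bitmask row is copied to
+    # sorted_bitmask[3] below.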
+ sorted_bitmask = np.full( + shape=(logits.shape[0], grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype, + ) + cumulative_index = 0 + seq = sorted( + scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1] + ) + for req_id, _ in seq: + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) + if req_id in struct_out_req_batch_indices: + logit_index = struct_out_req_batch_indices[req_id] + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # If the length of out indices and the logits have the same shape + # we don't need to pass indices to the kernel, + # since the bitmask is already aligned with the logits. + skip_out_indices = len(out_indices) == logits.shape[0] + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() + + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(device, non_blocking=True), + indices=out_indices if not skip_out_indices else None, + ) + + class OutlinesVocabulary: """ Wrapper class for `outlines_core.Vocabulary`, @@ -47,8 +132,7 @@ class OutlinesVocabulary: self.inner = vocabulary # Have to do abs(hash()) because python hashes can # be negative, and we are using hash as a cache key. - hex_str = hashlib.sha256( - vocabulary.__repr__().encode('utf-8')).hexdigest() + hex_str = hashlib.sha256(vocabulary.__repr__().encode("utf-8")).hexdigest() hash_int = int(hex_str, 16) self._hash = hash_int @@ -65,9 +149,9 @@ def get_outlines_cache_path() -> str: elif xdg_cache_home: return os.path.join(xdg_cache_home, ".cache", "outlines") # If homedir is "/", we may be inside a container, and thus writing to - # root would be problematic, so we fallback to using a tempfile. + # root would be problematic, so we fall back to using a tempfile. # Also validate the path exists, since os.path.expanduser does - # not garuntee existence. + # not guarantee existence. elif os.path.isdir(home_dir) and home_dir != "/": # Default Unix fallback: ~/.cache/outlines return os.path.join(home_dir, ".cache", "outlines") @@ -84,16 +168,18 @@ def get_outlines_cache(): cache_dir = get_outlines_cache_path() if envs.VLLM_V1_USE_OUTLINES_CACHE: - logger.warning("Enabling outlines cache. This is an unbounded on-disk " - "cache. It may consume a lot of disk space and should " - "not be used with untrusted clients.") + logger.warning( + "Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients." 
+ ) cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) outlines_version = importlib.metadata.version("outlines_core") - cached_version = cache.get('__version__', None) + cached_version = cache.get("__version__", None) if cached_version != outlines_version: cache.clear() - cache.set('__version__', outlines_version) + cache.set("__version__", outlines_version) return cache else: return LRUCache(maxsize=128) @@ -113,19 +199,17 @@ def _reduced_vocabulary( A Dict of token string -> equivalent token ids """ - unicode_to_bytes = { - v: k - for k, v in tokenization_gpt2.bytes_to_unicode().items() - } + unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: - string = tokenizer.convert_tokens_to_string([token]) # A hack to handle missing spaces to HF's Llama tokenizers - if (type(token) is str - and token.startswith(file_utils.SPIECE_UNDERLINE) - or token == "<0x20>"): + if ( + type(token) is str + and token.startswith(file_utils.SPIECE_UNDERLINE) + or token == "<0x20>" + ): return " " + string return string @@ -145,8 +229,7 @@ def _reduced_vocabulary( # by this point. token_bytes = bytes(token_str) # type: ignore[arg-type] - elif "\ufffd" in token_str and not re_replacement_seq.match( - token_str): + elif "\ufffd" in token_str and not re_replacement_seq.match(token_str): # Handle tokens with invalid UTF-8 sequences. if re_llama_byte_token.match(token): # Llama-like tokenizers use <0xXX> for incomplete sequences. @@ -157,12 +240,13 @@ def _reduced_vocabulary( if None in byte_vals: raise RuntimeError( f"Cannot convert token `{token}`" - f" ({token_idx}) to bytes: {token_str}") + f" ({token_idx}) to bytes: {token_str}" + ) # safe to ignore, since if None in byte_vals, # an error is thrown. token_bytes = bytes(byte_vals) # type: ignore[arg-type] else: - token_bytes = token_str.encode('utf-8') + token_bytes = token_str.encode("utf-8") if token_idx != eos_token_id: vocabulary.setdefault(token_bytes, []).append(token_idx) @@ -173,16 +257,18 @@ def _reduced_vocabulary( def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: - """Get the `Vocabulary` object for a given tokenizer. - """ + """Get the `Vocabulary` object for a given tokenizer.""" if hasattr(tokenizer, "_outlines_vocabulary"): return tokenizer._outlines_vocabulary # type: ignore try: - if hasattr( + if ( + hasattr( tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: + ) + and tokenizer.eos_token_id is not None + ): eos_token_id = tokenizer.eos_token_id else: raise ValueError( @@ -191,17 +277,18 @@ def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: reduced_vocab = _reduced_vocabulary( tokenizer, - eos_token_id #type: ignore + eos_token_id, # type: ignore ) - vocabulary = OutlinesVocabulary( - oc.Vocabulary(eos_token_id, reduced_vocab)) + vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab)) tokenizer._outlines_vocabulary = vocabulary # type: ignore return vocabulary except AttributeError as e: - raise ValueError(f"Cannot get the vocabulary of the tokenizer " - f"({type(tokenizer)}). The tokenizer should have a " - "get_vocab method.") from e + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). The tokenizer should have a " + "get_vocab method." 
+ ) from e def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -223,14 +310,14 @@ def grammar_is_likely_lark(grammar_str: str) -> bool: if not grammar_str or not isinstance(grammar_str, str): return False - for line in grammar_str.split('\n'): + for line in grammar_str.split("\n"): # Remove both comment styles - line = re.sub(r'(#|//).*$', '', line).strip() + line = re.sub(r"(#|//).*$", "", line).strip() if not line: continue # Look for EBNF rule definition - if '::=' in line: + if "::=" in line: return False return True @@ -267,40 +354,41 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: def clean_line(line: str) -> str: """Remove comments and whitespace from line.""" - return re.sub(r'(#|//).*$', '', line).strip() + return re.sub(r"(#|//).*$", "", line).strip() def check_quotes(text: str, rule_name: str, line_num: int) -> None: """Validate quote matching in text.""" if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: - raise ValueError( - f"Mismatched quotes in {rule_name} on line {line_num}") + raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}") def extract_references(text: str) -> set[str]: """Extract rule references from text.""" # Remove quoted strings and special characters - text = re.sub(r'"[^"]*"', '', text) - text = re.sub(r'[+*?()|\[\]{}]', ' ', text) - return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + text = re.sub(r'"[^"]*"', "", text) + text = re.sub(r"[+*?()|\[\]{}]", " ", text) + return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text)) # First pass: Find root rule and validate rule definitions - lines = [clean_line(line) for line in grammar_str.split('\n')] + lines = [clean_line(line) for line in grammar_str.split("\n")] first_rule = None for line_num, line in enumerate(lines, 1): - if not line or line.startswith('|'): + if not line or line.startswith("|"): continue - if ':' in line: + if ":" in line: try: - name = line.split(':', 1)[0].strip().strip('?') + name = line.split(":", 1)[0].strip().strip("?") defined_rules.add(name) if first_rule is None: first_rule = name - if name == 'start': - first_rule = 'start' + if name == "start": + first_rule = "start" except IndexError as e: - raise ValueError(f"Invalid rule format on line {line_num}. " - "Expected 'rule_name: definition'") from e + raise ValueError( + f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'" + ) from e if not defined_rules: raise ValueError("No valid rules found in grammar") @@ -317,29 +405,33 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: continue try: - if ':' in line and not line.startswith('|'): + if ":" in line and not line.startswith("|"): # Save previous rule if exists if current_rule: output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + f"{current_rule} ::= {' | '.join(current_definition)}" + ) # Process new rule - name, definition = line.split(':', 1) - current_rule = name.strip().strip('?') + name, definition = line.split(":", 1) + current_rule = name.strip().strip("?") check_quotes(definition, f"rule '{current_rule}'", line_num) definition = re.sub(r"'([^']*)'", r'"\1"', definition) referenced_rules.update(extract_references(definition)) current_definition = [definition.strip()] - elif line.startswith('|'): + elif line.startswith("|"): if not current_rule: - raise ValueError(f"Alternative '|' on line {line_num} " - "without a preceding rule definition") + raise ValueError( + f"Alternative '|' on line {line_num} " + "without a preceding rule definition" + ) alt_def = line[1:].strip() - check_quotes(alt_def, f"alternative for rule '{current_rule}'", - line_num) + check_quotes( + alt_def, f"alternative for rule '{current_rule}'", line_num + ) alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) referenced_rules.update(extract_references(alt_def)) current_definition.append(alt_def) @@ -349,25 +441,24 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: # Add final rule if exists if current_rule: - output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}") # Validate all rules are defined - undefined_rules = referenced_rules - defined_rules - {'root'} + undefined_rules = referenced_rules - defined_rules - {"root"} if undefined_rules: - raise ValueError("Referenced rules are not defined: " - f"{', '.join(sorted(undefined_rules))}") + raise ValueError( + f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}" + ) - return '\n'.join(output_lines) + return "\n".join(output_lines) def choice_as_grammar(choice: list[str]) -> str: - def escape_ebnf_string(s: str) -> str: """Escape special characters in a EBNF string.""" # Escape double quotes and backslashes - return re.sub(r'(["\\])', r'\\\1', s) + return re.sub(r'(["\\])', r"\\\1", s) escaped_choices = (escape_ebnf_string(c) for c in choice) - grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + grammar = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices) return grammar diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b5750c82db023..9259432628949 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,27 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import contextlib import multiprocessing import time import weakref from collections.abc import Sequence +from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + TypeVar, + Union, + overload, +) import torch +from torch.autograd.profiler import record_function +import vllm.envs as envs from vllm.logger import init_logger 
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, - kill_process_tree) +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import ( + get_open_port, + get_open_zmq_ipc_path, + get_tcp_uri, + kill_process_tree, +) if TYPE_CHECKING: + import numpy as np + from vllm.v1.engine.coordinator import DPCoordinator - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager logger = init_logger(__name__) @@ -29,7 +45,6 @@ T = TypeVar("T") class ConstantList(Generic[T], Sequence): - def __init__(self, x: list[T]) -> None: self._x = x @@ -51,31 +66,23 @@ class ConstantList(Generic[T], Sequence): def clear(self): raise TypeError("Cannot clear a constant list") - def index(self, - item: T, - start: int = 0, - stop: Optional[int] = None) -> int: - return self._x.index(item, start, - stop if stop is not None else len(self._x)) + def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int: + return self._x.index(item, start, stop if stop is not None else len(self._x)) @overload - def __getitem__(self, item: int) -> T: - ... + def __getitem__(self, item: int) -> T: ... @overload - def __getitem__(self, s: slice, /) -> list[T]: - ... + def __getitem__(self, s: slice, /) -> list[T]: ... def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: return self._x[item] @overload - def __setitem__(self, item: int, value: T): - ... + def __setitem__(self, item: int, value: T): ... @overload - def __setitem__(self, s: slice, value: T, /): - ... + def __setitem__(self, s: slice, value: T, /): ... def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): raise TypeError("Cannot set item in a constant list") @@ -96,9 +103,45 @@ class ConstantList(Generic[T], Sequence): return f"ConstantList({self._x})" -def get_engine_client_zmq_addr(local_only: bool, - host: str, - port: int = 0) -> str: +class CpuGpuBuffer: + """Buffer to easily copy tensors between CPU and GPU.""" + + def __init__( + self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + with_numpy: bool = True, + ) -> None: + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory) + self.gpu = torch.zeros_like(self.cpu, device=device) + self.np: np.ndarray + # To keep type hints simple (avoiding generics and subclasses), we + # only conditionally create the numpy array attribute. This can cause + # AttributeError if `self.np` is accessed when `with_numpy=False`. 
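+        # numpy has no native bfloat16 dtype, hence the explicit check below.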
+ if with_numpy: + if dtype == torch.bfloat16: + raise ValueError( + "Bfloat16 torch tensors cannot be directly cast to a " + "numpy array, so call CpuGpuBuffer with with_numpy=False" + ) + self.np = self.cpu.numpy() + + def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + if n is None: + return self.gpu.copy_(self.cpu, non_blocking=True) + return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + + def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + """NOTE: Because this method is non-blocking, explicit synchronization + is needed to ensure the data is copied to CPU.""" + if n is None: + return self.cpu.copy_(self.gpu, non_blocking=True) + return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) + + +def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: """Assign a new ZMQ socket address. If local_only is True, participants are colocated and so a unique IPC @@ -107,13 +150,16 @@ def get_engine_client_zmq_addr(local_only: bool, Otherwise, the provided host and port will be used to construct a TCP address (port == 0 means assign an available port).""" - return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( - host, port or get_open_port())) + return ( + get_open_zmq_ipc_path() + if local_only + else (get_tcp_uri(host, port or get_open_port())) + ) class APIServerProcessManager: """Manages a group of API server processes. - + Handles creation, monitoring, and termination of API server worker processes. Also monitors extra processes to check if they are healthy. """ @@ -130,7 +176,7 @@ class APIServerProcessManager: stats_update_address: Optional[str] = None, ): """Initialize and start API server worker processes. - + Args: target_server_fn: Function to call for each API server process listen_address: Address to listen for client connections @@ -139,7 +185,7 @@ class APIServerProcessManager: num_servers: Number of API server processes to start input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server - stats_update_address: Optional stats update address + stats_update_address: Optional stats update address """ self.listen_address = listen_address self.sock = sock @@ -149,21 +195,23 @@ class APIServerProcessManager: spawn_context = multiprocessing.get_context("spawn") self.processes: list[BaseProcess] = [] - for i, in_addr, out_addr in zip(range(num_servers), input_addresses, - output_addresses): + for i, in_addr, out_addr in zip( + range(num_servers), input_addresses, output_addresses + ): client_config = { "input_address": in_addr, "output_address": out_addr, "client_count": num_servers, - "client_index": i + "client_index": i, } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address - proc = spawn_context.Process(target=target_server_fn, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) + proc = spawn_context.Process( + target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, client_config), + ) self.processes.append(proc) proc.start() @@ -178,12 +226,14 @@ class APIServerProcessManager: def wait_for_completion_or_failure( - api_server_manager: APIServerProcessManager, - engine_manager: Optional[Union["CoreEngineProcManager", - "CoreEngineActorManager"]] = None, - coordinator: Optional["DPCoordinator"] = None) -> None: + api_server_manager: APIServerProcessManager, + engine_manager: Optional[ + Union["CoreEngineProcManager", "CoreEngineActorManager"] + ] = None, + 
coordinator: Optional["DPCoordinator"] = None, +) -> None: """Wait for all processes to complete or detect if any fail. - + Raises an exception if any process exits with a non-zero status. Args: @@ -194,16 +244,14 @@ def wait_for_completion_or_failure( coordinator: The coordinator for data parallel. """ - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager try: logger.info("Waiting for API servers to complete ...") # Create a mapping of sentinels to their corresponding processes # for efficient lookup sentinel_to_proc: dict[Any, BaseProcess] = { - proc.sentinel: proc - for proc in api_server_manager.processes + proc.sentinel: proc for proc in api_server_manager.processes } if coordinator: @@ -219,8 +267,7 @@ def wait_for_completion_or_failure( # Check if any process terminates while sentinel_to_proc or actor_run_refs: # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, - timeout=5) + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) # Process any terminated processes for sentinel in ready_sentinels: @@ -230,17 +277,18 @@ def wait_for_completion_or_failure( if proc.exitcode != 0: raise RuntimeError( f"Process {proc.name} (PID: {proc.pid}) " - f"died with exit code {proc.exitcode}") + f"died with exit code {proc.exitcode}" + ) if actor_run_refs: import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) except KeyboardInterrupt: logger.info("Received KeyboardInterrupt, shutting down API servers...") except Exception as e: - logger.exception("Exception occurred while running API servers: %s", - str(e)) + logger.exception("Exception occurred while running API servers: %s", str(e)) raise finally: logger.info("Terminating remaining processes ...") @@ -273,8 +321,9 @@ def shutdown(procs: list[BaseProcess]): kill_process_tree(pid) -def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, - length: int) -> torch.Tensor: +def copy_slice( + from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int +) -> torch.Tensor: """ Copy the first length elements of a tensor into another tensor in a non-blocking manner. 
@@ -287,8 +336,8 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def report_usage_stats( - vllm_config, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None: + vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT +) -> None: """Report usage statistics if enabled.""" if not is_usage_stats_enabled(): @@ -296,33 +345,47 @@ def report_usage_stats( from vllm.model_executor.model_loader import get_architecture_class_name + parallel_config = vllm_config.parallel_config + usage_message.report_usage( get_architecture_class_name(vllm_config.model_config), usage_context, extra_kvs={ # Common configuration - "dtype": - str(vllm_config.model_config.dtype), - "tensor_parallel_size": - vllm_config.parallel_config.tensor_parallel_size, - "block_size": - vllm_config.cache_config.block_size, - "gpu_memory_utilization": - vllm_config.cache_config.gpu_memory_utilization, - + "dtype": str(vllm_config.model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": vllm_config.cache_config.block_size, + "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, + "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes, # Quantization - "quantization": - vllm_config.model_config.quantization, - "kv_cache_dtype": - str(vllm_config.cache_config.cache_dtype), - + "quantization": vllm_config.model_config.quantization, + "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype), # Feature flags - "enable_lora": - bool(vllm_config.lora_config), - "enable_prefix_caching": - vllm_config.cache_config.enable_prefix_caching, - "enforce_eager": - vllm_config.model_config.enforce_eager, - "disable_custom_all_reduce": - vllm_config.parallel_config.disable_custom_all_reduce, - }) + "enable_lora": bool(vllm_config.lora_config), + "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) + + +_PROFILER_FUNC = None + + +def record_function_or_nullcontext(name: str) -> AbstractContextManager: + global _PROFILER_FUNC + + # fast path assume it is set + if _PROFILER_FUNC is not None: + return _PROFILER_FUNC(name) + + func = contextlib.nullcontext + if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: + func = record_function + elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: + import nvtx + + func = nvtx.annotate + + _PROFILER_FUNC = func + return func(name) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5662fc350e198..0c44834b55056 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,17 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union import numpy as np import torch +from vllm.distributed import get_dcp_group from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.utils import CpuGpuBuffer logger = init_logger(__name__) class BlockTable: - def __init__( self, block_size: int, @@ -20,36 +22,71 @@ class BlockTable: max_num_batched_tokens: int, pin_memory: bool, device: torch.device, + kernel_block_size: int, ): - self.block_size = block_size + """ + Args: + block_size: Block size used for KV cache memory allocation + max_num_reqs: Maximum number of concurrent requests supported. + max_num_blocks_per_req: Maximum number of blocks per request. 
+ max_num_batched_tokens: Maximum number of tokens in a batch. + pin_memory: Whether to pin memory for faster GPU transfers. + device: Target device for the block table. + kernel_block_size: The block_size of underlying attention kernel. + Will be the same as `block_size` if `block_size` is supported + by the attention kernel. + """ self.max_num_reqs = max_num_reqs - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device - self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32, + if kernel_block_size == block_size: + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping + self.block_size = block_size + self.blocks_per_kv_block = 1 + self.use_hybrid_blocks = False + else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks + if block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"kv_manager_block_size size {block_size} evenly" + ) + + self.block_size = kernel_block_size + self.blocks_per_kv_block = block_size // kernel_block_size + self.use_hybrid_blocks = True + + self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block + + self.block_table = self._make_buffer( + self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 ) - self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() - self.slot_mapping = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, - device=self.device) + self.slot_mapping = self._make_buffer( + self.max_num_batched_tokens, dtype=torch.int64 + ) + + if self.use_hybrid_blocks: + self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape( + 1, -1 + ) + else: + self._kernel_block_arange = None + + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 def append_row( self, @@ -58,10 +95,14 @@ class BlockTable: ) -> None: if not block_ids: return + + if self.use_hybrid_blocks: + block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks - self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.block_table.np[row_idx, start : start + num_blocks] = block_ids def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -69,74 +110,165 @@ class BlockTable: def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] - self.block_table_np[tgt, :num_blocks] = self.block_table_np[ - src, :num_blocks] + block_table_np = self.block_table.np + block_table_np[tgt, :num_blocks] = block_table_np[src, 
:num_blocks] self.num_blocks_per_row[tgt] = num_blocks def swap_row(self, src: int, tgt: int) -> None: - num_blocks_src = self.num_blocks_per_row[src] - num_blocks_tgt = self.num_blocks_per_row[tgt] - self.num_blocks_per_row[src] = num_blocks_tgt - self.num_blocks_per_row[tgt] = num_blocks_src + src_tgt, tgt_src = [src, tgt], [tgt, src] + self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src] + self.block_table.np[src_tgt] = self.block_table.np[tgt_src] - self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] - - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) - block_numbers = self.block_table_np.ravel()[block_table_indices] - block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:req_indices.shape[0]]) + if self.dcp_world_size > 1: + # Note(hc): The DCP implementation stores the kvcache in an interleaved + # style: the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % dcp_world_size. + + # Use a "virtual block" whose size equals dcp_world_size * block_size + # for the block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + + positions // virtual_block_size + ) + + block_numbers = self.block_table.np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens.
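+ # (Illustration: with block_size=16 and dcp_world_size=2, virtual_block_size + # is 32; a token at position 37 falls in virtual block 37 // 32 = 1 with + # virtual offset 37 % 32 = 5, so it is local to dcp_rank 5 % 2 = 1 at local + # offset 5 // 2 = 2 within that rank's block, while every other rank + # writes -1 for this slot.)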
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calculate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calculate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping.np[: req_indices.shape[0]] = np.where( + mask, slot_mapping, -1 + ) + else: + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + positions // self.block_size + ) + + block_numbers = self.block_table.np.ravel()[block_table_indices] + block_offsets = positions % self.block_size + np.add( + block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping.np[: req_indices.shape[0]], + ) def commit_block_table(self, num_reqs: int) -> None: - self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], - non_blocking=True) + self.block_table.copy_to_gpu(num_reqs) def commit_slot_mapping(self, num_tokens: int) -> None: - self.slot_mapping[:num_tokens].copy_( - self.slot_mapping_cpu[:num_tokens], non_blocking=True) + self.slot_mapping.copy_to_gpu(num_tokens) def clear(self) -> None: - self.block_table.fill_(0) - self.block_table_cpu.fill_(0) + self.block_table.gpu.fill_(0) + self.block_table.cpu.fill_(0) - def get_device_tensor(self) -> torch.Tensor: - """Ruturns the device tensor of the block table.""" - return self.block_table + def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + """Convert kv_manager block IDs to kernel block IDs. + + Example: + # kv_manager block size: 32 tokens, + # Kernel block size: 16 tokens + # blocks_per_kv_block = 2 + >>> kv_manager_block_ids = np.array([0, 1, 2]) + >>> Result: [0, 1, 2, 3, 4, 5] + + # Each kv_manager_block_id maps to 2 kernel block ids: + # kv_manager_block_id 0 → kernel block ids [0, 1] + # kv_manager_block_id 1 → kernel block ids [2, 3] + # kv_manager_block_id 2 → kernel block ids [4, 5] + """ + if not self.use_hybrid_blocks: + return kv_manager_block_ids + + kernel_block_ids = ( + kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block + + self._kernel_block_arange + ) + + return kernel_block_ids.reshape(-1) + + def get_device_tensor(self, num_reqs: int) -> torch.Tensor: + """Returns the device tensor of the block table.""" + return self.block_table.gpu[:num_reqs] def get_cpu_tensor(self) -> torch.Tensor: """Returns the CPU tensor of the block table.""" - return self.block_table_cpu + return self.block_table.cpu def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" - return self.block_table_np + return self.block_table.np + + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory + ) class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, block_sizes: list[int]) -> None: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + kernel_block_sizes: list[int], + num_speculative_tokens: int = 0, + ) -> None: + # Note(hc): each dcp rank only stores + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size used to calculate max_num_blocks_per_req + # must be
multiplied by dcp_world_size. + try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + + if len(kernel_block_sizes) != len(block_sizes): + raise ValueError( + f"kernel_block_sizes length ({len(kernel_block_sizes)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) + self.block_tables = [ - BlockTable(block_size, max_num_reqs, cdiv(max_model_len, - block_size), - max_num_batched_tokens, pin_memory, device) - for block_size in block_sizes + BlockTable( + block_size, + max_num_reqs, + max( + cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens, + ), + max_num_batched_tokens, + pin_memory, + device, + kernel_block_size, + ) + for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] - def append_row(self, block_ids: tuple[list[int], ...], - row_idx: int) -> None: + def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) @@ -152,8 +284,9 @@ class MultiGroupBlockTable: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: for block_table in self.block_tables: block_table.compute_slot_mapping(req_indices, positions) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index a7180afbd64b5..299567427027e 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch import torch.nn as nn @@ -9,7 +9,7 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner if TYPE_CHECKING: @@ -19,9 +19,9 @@ logger = init_logger(__name__) class CPUModelRunner(GPUModelRunner): - def __init__(self, vllm_config: VllmConfig, device: torch.device): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) assert device == torch.device("cpu") assert self.speculative_config is None, "spec decode is not supported." @@ -31,38 +31,18 @@ class CPUModelRunner(GPUModelRunner): self._postprocess_tensors() + # Note: Remove the override after new attention backend finished def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: - """ - Update the order of requests in the batch based on the attention - backend's needs. For example, some attention backends (namely MLA) may - want to separate requests based on if the attention computation will be - compute-bound or memory-bound. - - Args: - scheduler_output: The scheduler output. - """ - # Attention free models have zero kv_cache_goups, however models - # like Mamba are also attention free but use the kv_cache for - # keeping its internal state. This is why we check the number - # of kv_cache groups instead of solely checking - # for self.model_config.is_attention_free. 
- if len(self.kv_cache_config.kv_cache_groups) == 0: - return - if len(self.kv_cache_config.kv_cache_groups) > 1: - raise ValueError("Multiple KVCacheGroups is not" - "currently supported with CPU model runner.") - - assert type(self.attn_groups[0] - [0].metadata_builder) is TorchSDPAMetadataBuilderV1 - - self.attn_groups[0][0].metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + raise ValueError( + "Multiple KVCacheGroups is not " + "currently supported with CPU model runner." + ) + super()._may_reorder_batch(scheduler_output) def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors - def replace_tensor(obj: Any, cpu_attr_name: str, - device_attr_name) -> None: + def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: cpu_tensor = getattr(obj, cpu_attr_name, None) device_tensor = getattr(obj, device_attr_name, None) if cpu_tensor is not None and device_tensor is not None: @@ -70,27 +50,25 @@ class CPUModelRunner(GPUModelRunner): assert isinstance(device_tensor, torch.Tensor) setattr(obj, device_attr_name, cpu_tensor) - for k, v in vars(self).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(self, k, k[:-4]) + for v in vars(self).values(): + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu for k, v in vars(self.input_batch).items(): if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor): replace_tensor(self.input_batch, k, k[:-11]) for block_table in self.input_batch.block_table.block_tables: - for k, v in vars(block_table).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(block_table, k, k[:-4]) + for v in vars(block_table).values(): + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) if self.lora_config: - self.model = self.load_lora_model(self.model, self.model_config, - self.scheduler_config, - self.lora_config, self.device) + self.model = self.load_lora_model(self.model, self.vllm_config, self.device) def get_model(self) -> nn.Module: return self.model @@ -99,7 +77,13 @@ class CPUModelRunner(GPUModelRunner): logger.info("Warming up model for the compilation...") # Only generate graph for the generic shape with _set_global_compilation_settings(self.vllm_config): - self._dummy_run(max(16, self.max_num_reqs)) + self._dummy_run( + min( + max(16, self.max_num_reqs), + self.scheduler_config.max_num_batched_tokens, + ) + ) + logger.info("Warming up done.") def _init_device_properties(self) -> None: @@ -108,17 +92,46 @@ class CPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: pass + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + return sampled_token_ids.tolist() + + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + # Note: For CPU backend, dp padding is not required for now.
+ return 0, None + + +@contextmanager +def _torch_cuda_wrapper(): + class _EventPlaceholder: + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + class _StreamPlaceholder: + def __init__(self, *args, **kwargs) -> None: + pass + + cuda_event = torch.cuda.Event + cuda_stream = torch.cuda.Stream + try: + torch.cuda.Event = _EventPlaceholder + torch.cuda.Stream = _StreamPlaceholder + yield + finally: + torch.cuda.Event = cuda_event + torch.cuda.Stream = cuda_stream + @contextmanager def _set_global_compilation_settings(config: VllmConfig): - import torch._inductor.config + import torch._inductor.config as torch_inductor_config inductor_config = config.compilation_config.inductor_compile_config + # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. + freezing_value = torch_inductor_config.freezing try: - # Note: The MKLDNN and CPPGEMM backend requires freezing parameters. - freezing_value = torch._inductor.config.freezing if inductor_config.get("max_autotune", False): - torch._inductor.config.freezing = True + torch_inductor_config.freezing = True yield finally: - torch._inductor.config.freezing = freezing_value + torch_inductor_config.freezing = freezing_value diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index be78597926e09..ee865ec8e6493 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -8,34 +8,32 @@ import torch from vllm import envs from vllm.config import VllmConfig -from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.model_executor.utils import set_random_seed from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo -from vllm.sequence import IntermediateTensors -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.cpu_model_runner import CPUModelRunner -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment logger = init_logger(__name__) class CPUWorker(Worker): - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False): - super().__init__(vllm_config, - local_rank, - rank, - distributed_init_method, - is_driver_worker=is_driver_worker) + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__( + vllm_config, + local_rank, + rank, + distributed_init_method, + is_driver_worker=is_driver_worker, + ) self.parallel_config.disable_custom_all_reduce = True @@ -47,15 +45,24 @@ class CPUWorker(Worker): if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) + lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4] + ) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: # For x86 SMT-2, use 1 CPU per core self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: cpus[-1:]) + lambda cpus: cpus[-1:] + ) else: self.local_omp_cpuid = "all" else: - self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] + local_dp_rank = self.parallel_config.data_parallel_rank_local + omp_cpuids = omp_cpuids.split("|") + if 
local_dp_rank is not None: + world_size = self.parallel_config.world_size + omp_cpuids = omp_cpuids[ + local_dp_rank * world_size : (local_dp_rank + 1) * world_size + ] + self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) @@ -63,19 +70,22 @@ class CPUWorker(Worker): logger.info(ret) # Note: unique identifier for creating allreduce shared memory - os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( - ":")[-1] + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1] # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # Set random seed. set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner: CPUModelRunner = CPUModelRunner( - self.vllm_config, torch.device("cpu")) + self.vllm_config, torch.device("cpu") + ) def sleep(self, level: int = 1) -> None: logger.warning("sleep mode is not supported on CPU, ignore it.") @@ -94,55 +104,32 @@ class CPUWorker(Worker): set_random_seed(self.model_config.seed) self.model_runner.warming_up_model() - @torch.inference_mode() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) - - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) - - if not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return None - - assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None - def _get_autobind_cpu_ids( - self, cpu_selector: Callable[[list[LogicalCPUInfo]], - list[LogicalCPUInfo]] + self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]] ) -> str: """ - Return CPU ids to bind based on NUMA nodes. - Currently for rank N, only CPU ids on the N-th node in available NUMA + Return CPU ids to bind based on NUMA nodes. + Currently for rank N, only CPU ids on the N-th node in available NUMA node list will be selected. Args: - cpu_selector: a callable object to select CPUs from a CPU list + cpu_selector: a callable object to select CPUs from a CPU list of a physical core. The input is a LogicalCPUInfo list, sorted by - the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be + the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be returned. """ - allowed_numa_nodes, logical_cpu_list = \ + allowed_numa_nodes, logical_cpu_list = ( CpuPlatform.get_allowed_cpu_core_node_list() + ) assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " f"Allowed NUMA nodes are {allowed_numa_nodes}. " - "Please try to bind threads manually.") + "Please try to bind threads manually." 
+ ) # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`` - selected_numa_node = allowed_numa_nodes[ - self.local_rank] # type: ignore + selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore logical_cpu_list = [ x for x in logical_cpu_list if x.numa_node == selected_numa_node ] @@ -162,13 +149,20 @@ class CPUWorker(Worker): # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0 + need_reserve = ( + self.parallel_config.world_size > 1 + or self.parallel_config.data_parallel_size_local > 1 + ) + reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " - f"should less than {len(logical_cpu_list)}.") + f"should be less than {len(logical_cpu_list)}." + ) if reserve_cpu_num != 0: logical_cpu_list = logical_cpu_list[:-reserve_cpu_num] - logger.info("auto thread-binding list (id, physical core): %s", - [(x.id, x.physical_core) for x in logical_cpu_list]) + logger.info( + "auto thread-binding list (id, physical core): %s", + [(x.id, x.physical_core) for x in logical_cpu_list], + ) return ",".join([str(x.id) for x in logical_cpu_list]) diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py new file mode 100644 index 0000000000000..1bb6a6f4d05f7 --- /dev/null +++ b/vllm/v1/worker/dp_utils.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist + +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.v1.worker.ubatch_utils import ( + UBatchSlices, + check_ubatch_thresholds, + create_ubatch_slices, + is_second_ubatch_empty, +) + +logger = init_logger(__name__) + + +def _get_device_and_group(parallel_config: ParallelConfig): + device = current_platform.device_type + group = get_dp_group().device_group + + # Transferring this tensor from GPU to CPU will introduce a GPU sync + # point that could adversely affect performance of vllm with async + # scheduling. This environment variable exists to quickly disable + # this optimization if we run into this case. + if parallel_config.disable_nccl_for_dp_synchronization: + logger.info_once("Using CPU all reduce to synchronize DP padding between ranks.") + device = "cpu" + group = get_dp_group().cpu_group + return device, group + + +def _run_ar( + should_ubatch: bool, + should_dp_pad: bool, + orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, + parallel_config: ParallelConfig, +) -> torch.Tensor: + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + device, group = _get_device_and_group(parallel_config) + tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32) + tensor[0][dp_rank] = orig_num_tokens_per_ubatch + tensor[1][dp_rank] = padded_num_tokens_per_ubatch + tensor[2][dp_rank] = 1 if should_ubatch else 0 + tensor[3][dp_rank] = 1 if should_dp_pad else 0 + dist.all_reduce(tensor, group=group) + return tensor + + +def _post_process_ubatch(tensor: torch.Tensor) -> bool: + orig_num_tokens_tensor = tensor[0, :] + padded_num_tokens_tensor = tensor[1, :] + + # First determine if we are going to be ubatching.
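+ # (Row layout of `tensor`, as filled in _run_ar: row 0 holds each DP rank's + # unpadded token count, row 1 its padded token count, row 2 its microbatching + # vote, and row 3 its DP-padding vote; the all_reduce makes every column + # visible to all ranks.)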
+ should_ubatch: bool = bool(torch.all(tensor[2] == 1).item()) + if not should_ubatch: + return False + # If the DP ranks are planning to ubatch, make sure that + # there are no "empty" second ubatches + orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) + padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) + if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): + logger.debug( + "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens + ) + should_ubatch = False + return should_ubatch + + +def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor: + num_tokens_across_dp = tensor[1, :] + if should_dp_pad: + # If DP padding is enabled, ensure that each rank is processing the same number + # of tokens + max_num_tokens = int(num_tokens_across_dp.max().item()) + return torch.tensor( + [max_num_tokens] * len(num_tokens_across_dp), + device="cpu", + dtype=torch.int32, + ) + else: + return num_tokens_across_dp.cpu() + + +def _synchronize_dp_ranks( + num_tokens_unpadded: int, + num_tokens_padded: int, + should_attempt_ubatching: bool, + should_attempt_dp_padding: bool, + parallel_config: ParallelConfig, +) -> tuple[bool, Optional[torch.Tensor]]: + """ + 1. Decides if each DP rank is going to microbatch. Either all ranks + run with microbatching or none of them do. + + 2. Determines the total number of tokens that each rank will run. + When running microbatched or if should_attempt_dp_padding is True, all + ranks will be padded out so that the run with the same number of tokens + + Returns: tuple[ + should_ubatch: Are all DP ranks going to microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including any DP padding. + ] + + """ + assert num_tokens_padded >= num_tokens_unpadded + + # Coordinate between the DP ranks via an All Reduce + # to determine the total number of tokens that each rank + # will run and if we are using ubatching or not. + tensor = _run_ar( + should_ubatch=should_attempt_ubatching, + should_dp_pad=should_attempt_dp_padding, + orig_num_tokens_per_ubatch=num_tokens_unpadded, + padded_num_tokens_per_ubatch=num_tokens_padded, + parallel_config=parallel_config, + ) + + should_dp_pad = bool(torch.all(tensor[3] == 1).item()) + + # DP ranks should all have the same value for should_attempt_dp_padding. + assert should_attempt_dp_padding == should_dp_pad + + # Check conditions for microbatching + should_ubatch = _post_process_ubatch(tensor) + + if should_ubatch and not should_dp_pad: + if is_global_first_rank(): + logger.debug( + "Microbatching has been triggered and requires DP padding. " + "Enabling DP padding even though it has been explicitly " + "disabled." + ) + should_dp_pad = True + + # Pad all DP ranks up to the maximum token count across ranks if + # should_dp_pad is True + num_tokens_after_padding = _post_process_dp_padding( + tensor, + should_dp_pad, + ) + + return should_ubatch, num_tokens_after_padding + + +def coordinate_batch_across_dp( + num_tokens_unpadded: int, + allow_microbatching: bool, + allow_dp_padding: bool, + parallel_config: ParallelConfig, + num_tokens_padded: Optional[int] = None, + uniform_decode: Optional[bool] = None, + num_scheduled_tokens_per_request: Optional[np.ndarray] = None, +) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: + """ + Coordinates amongst all DP ranks to determine if and how the full batch + should be split into microbatches. 
+ + Args: + num_tokens_unpadded: Number of tokens without accounting for padding + allow_microbatching: If microbatching should be attempted + allow_dp_padding: If all DP ranks should be padded up to the same value + parallel_config: The parallel config + num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs, + TP, etc) + uniform_decode: Only used if allow_microbatching is True. True if the batch + only contains single token decodes + num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The + number of tokens per request. + + Returns: tuple[ + ubatch_slices: if this is set then all DP ranks have agreed to + microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. Will be + padded up to the max value across all DP ranks when allow_dp_padding + is True. + ] + + """ + if parallel_config.data_parallel_size == 1: + # Early exit. + return None, None + + # If the caller has explicitly enabled microbatching. + should_attempt_ubatching = False + if allow_microbatching: + # Check preconditions for microbatching + assert uniform_decode is not None + should_attempt_ubatching = check_ubatch_thresholds( + parallel_config, + num_tokens_unpadded, + uniform_decode=uniform_decode, + ) + + if num_tokens_padded is None: + num_tokens_padded = num_tokens_unpadded + + (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks( + num_tokens_unpadded, + num_tokens_padded, + should_attempt_ubatching, + allow_dp_padding, + parallel_config, + ) + + # Don't microbatch unless every other DP worker is also microbatching + if not should_ubatch: + return (None, num_tokens_after_padding) + + # This doesn't actually pad the ubatch slices. It just initializes the + # split point to the padded value so that padding can be applied + # to the second ubatch in pad_out_ubatch_slice after attention + # metadata creation + assert num_tokens_after_padding is not None + token_split_point = int(num_tokens_after_padding[0].item()) // 2 + + assert num_scheduled_tokens_per_request is not None + ubatch_slices = create_ubatch_slices( + num_scheduled_tokens_per_request, token_split_point + ) + + return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index f48c9de2f4e1a..0ced400bcb663 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -7,19 +7,19 @@ from typing import Optional, cast import numpy as np import torch -from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalKwargsItem, - MultiModalKwargsItems, PlaceholderRange) +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - LogitsProcessors, - MoveDirectionality) +from vllm.v1.sample.logits_processor import ( + BatchUpdateBuilder, + LogitsProcessors, + MoveDirectionality, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -28,11 +28,9 @@ from vllm.v1.worker.block_table 
import MultiGroupBlockTable @dataclass class CachedRequestState: - req_id: str - prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_positions: list[PlaceholderRange] + prompt_token_ids: Optional[list[int]] + mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -45,31 +43,31 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None + prompt_embeds: Optional[torch.Tensor] = None def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + self.prompt_token_ids, self.prompt_embeds + ) @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) - # Temporary back-compatibility for plugins that define model runner - @property - @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " - "removed in v0.13. Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargsItems]: - return [ - MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs - ] - def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: + if self.prompt_token_ids is None: + raise ValueError( + f"Tried to access token index {idx}, but that token was " + "provided via prompt_embeds, and its ID is unknown." + ) return self.prompt_token_ids[idx] - return self.output_token_ids[idx - self.num_prompt_tokens] + if idx - self.num_prompt_tokens < len(self.output_token_ids): + return self.output_token_ids[idx - self.num_prompt_tokens] + return -1 class InputBatch: - def __init__( self, max_num_reqs: int, @@ -79,9 +77,12 @@ class InputBatch: pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], logitsprocs: Optional[LogitsProcessors] = None, + logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, is_pooling_model: bool = False, + num_speculative_tokens: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -106,17 +107,23 @@ class InputBatch: pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.is_token_ids = torch.zeros( + (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False + ) + # Store prompt embeddings per request to avoid OOM from large upfront + # allocation if max_model_len is big. + # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) + self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -126,37 +133,32 @@ class InputBatch: pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, + num_speculative_tokens=num_speculative_tokens, ) # Sampling-related. 
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() @@ -164,46 +166,43 @@ class InputBatch: self.spec_decode_unsupported_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = 
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones( + (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory + ) + self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() + # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -235,20 +234,32 @@ class InputBatch: # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} - self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, - dtype=bool) + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool) self.req_output_token_ids: list[Optional[list[int]]] = [] # Store provided logitsprocs. If none are provided, initialize empty # data structure self.logitsprocs = logitsprocs or LogitsProcessors() + self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids + + # Store last speculative tokens for sampler. + self.spec_token_ids: list[Optional[list[int]]] = [] # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + # These are used to update output_token_ids with real sampled + # ids from prior step, if required by current sampling params + # (e.g. penalties). + self.sampled_token_ids_cpu: Optional[torch.Tensor] = None + self.async_copy_ready_event: Optional[torch.cuda.Event] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -271,8 +282,13 @@ class InputBatch: # Detailed added request metadata is only required for non-pooling # models, to support logitsprocs. self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, - request.prompt_token_ids, request.output_token_ids)) + ( + new_req_index, + request.sampling_params, + request.prompt_token_ids, + request.output_token_ids, + ) + ) return new_req_index @@ -286,22 +302,31 @@ class InputBatch: if req_index == len(self._req_ids): self._req_ids.append(req_id) self.req_output_token_ids.append(request.output_token_ids) + self.spec_token_ids.append([]) else: self._req_ids[req_index] = req_id self.req_output_token_ids[req_index] = request.output_token_ids + self.spec_token_ids[req_index] = [] self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds + ) self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - # Number of token ids in token_ids_cpu. 
+ if request.prompt_token_ids is not None: + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids + self.is_token_ids[req_index, :num_prompt_tokens] = True + else: + self.is_token_ids[req_index, :num_prompt_tokens] = False + if request.prompt_embeds is not None: + self.req_prompt_embeds[req_index] = request.prompt_embeds + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids + self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens # Number of tokens without spec decode tokens. @@ -311,12 +336,11 @@ class InputBatch: self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): + if self.is_spec_decode and is_spec_decode_unsupported(sampling_params): self.spec_decode_unsupported_reqs.add(req_id) if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 + # Should avoid division by zero later when apply_temperature. + self.temperature_cpu[req_index] = 0.0 self.greedy_reqs.add(req_id) else: self.temperature_cpu[req_index] = sampling_params.temperature @@ -331,16 +355,15 @@ class InputBatch: else: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = ( + sampling_params.repetition_penalty + ) if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) @@ -350,12 +373,17 @@ class InputBatch: self.generators[req_index] = request.generator if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = (self.vocab_size - if sampling_params.logprobs == -1 - else sampling_params.logprobs) + self.num_logprobs[req_id] = ( + self.vocab_size + if sampling_params.logprobs == -1 + else sampling_params.logprobs + ) if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[ - req_id] = sampling_params.prompt_logprobs + self.num_prompt_logprobs[req_id] = ( + self.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -366,27 +394,35 @@ class InputBatch: self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device=self.device) + device=self.device, + ) self.allowed_token_ids_mask_cpu_tensor = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device="cpu", + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. 
self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = ( + sampling_params.bad_words_token_ids + ) elif pooling_params := request.pooling_params: self.pooling_params[req_id] = pooling_params self.logits_processing_needs_token_ids[req_index] = ( - pooling_params.requires_token_ids) + pooling_params.requires_token_ids + ) else: raise NotImplementedError("Unrecognized request type") + # Speculative decoding: by default 1 token is generated. + self.num_accepted_tokens_cpu[req_index] = 1 + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -419,6 +455,7 @@ class InputBatch: self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + self.spec_token_ids[req_index] = None # LoRA lora_id = self.request_lora_mapping[req_index] @@ -457,21 +494,36 @@ class InputBatch: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) + self.spec_token_ids[i1], self.spec_token_ids[i2] = ( + self.spec_token_ids[i2], + self.spec_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -482,10 +534,26 @@ class InputBatch: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] 
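+ # (The list-index swap above is safe in place because advanced indexing on + # the right-hand side yields a copy, unlike the row-view assignments used + # for token_ids_cpu, which need the temporary buffer.)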
+ + # Swap prompt embeddings if they exist + embeds_i1 = self.req_prompt_embeds.get(i1) + embeds_i2 = self.req_prompt_embeds.get(i2) + if embeds_i1 is not None: + self.req_prompt_embeds[i2] = embeds_i1 + else: + self.req_prompt_embeds.pop(i2, None) + if embeds_i2 is not None: + self.req_prompt_embeds[i1] = embeds_i2 + else: + self.req_prompt_embeds.pop(i1, None) + self.block_table.swap_row(i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) if self.is_pooling_model: # Sampling and logits parameters don't apply to pooling models. @@ -493,30 +561,42 @@ class InputBatch: # For autoregressive models, track detailed request reordering info # to support logitsprocs. - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) + self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP)) - self.temperature_cpu[i1], self.temperature_cpu[i2] = \ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] = \ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] = \ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = ( + self.num_accepted_tokens_cpu[i2], + self.num_accepted_tokens_cpu[i1], + ) swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -524,9 +604,6 @@ class InputBatch: Any consecutive empty indices at the very end of the list are not filled. - Args: - empty_req_indices: empty indices which may be filled. 
- Returns: swaps: list of (from,to) swap tuples for moved requests empty_req_indices: indices not filled by condensation @@ -541,6 +618,7 @@ class InputBatch: # The batched states are empty. self._req_ids.clear() self.req_output_token_ids.clear() + self.spec_token_ids.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices @@ -569,54 +647,72 @@ class InputBatch: self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index + spec_token_ids = self.spec_token_ids[last_req_index] + self.spec_token_ids[empty_index] = spec_token_ids + self.spec_token_ids[last_req_index] = None + num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] + self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ + last_req_index, :num_tokens + ] + if last_req_index in self.req_prompt_embeds: + self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop( + last_req_index + ) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] if self.is_pooling_model: last_req_index -= 1 - # Samping state not used by pooling models. + # Sampling state not used by pooling models. 
continue # Autoregressive models require detailed tracking of condense # operations to support logitsprocs self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] + self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[ + last_req_index + ] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids @@ -626,6 +722,7 @@ class InputBatch: # Trim lists to the batch size. del self._req_ids[num_reqs:] del self.req_output_token_ids[num_reqs:] + del self.spec_token_ids[num_reqs:] def refresh_metadata(self): """Apply any batch updates to sampling metadata.""" @@ -648,8 +745,9 @@ class InputBatch: def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs if not self.all_greedy: - temperature = copy_slice(self.temperature_cpu_tensor, - self.temperature, num_reqs) + temperature = copy_slice( + self.temperature_cpu_tensor, self.temperature, num_reqs + ) else: temperature = None if not self.no_top_p: @@ -661,30 +759,51 @@ class InputBatch: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. - copy_slice(self.frequency_penalties_cpu_tensor, - self.frequency_penalties, num_reqs) - copy_slice(self.presence_penalties_cpu_tensor, - self.presence_penalties, num_reqs) - copy_slice(self.repetition_penalties_cpu_tensor, - self.repetition_penalties, num_reqs) + copy_slice( + self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs + ) + copy_slice( + self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs + ) + copy_slice( + self.repetition_penalties_cpu_tensor, + self.repetition_penalties, + num_reqs, + ) needs_prompt_token_ids = ( not self.no_penalties - or self.logits_processing_needs_token_ids[:num_reqs].any()) - if needs_prompt_token_ids: - # The prompt tokens are used only for applying penalties or - # step pooling during the sampling/pooling process. 
- # Hence copy these tensors only when there are requests which - # need penalties/step_pooler to be applied. - prompt_token_ids = self._make_prompt_token_ids_tensor() - else: - prompt_token_ids = None + or self.logits_processing_needs_token_ids[:num_reqs].any() + ) + # The prompt tokens are used only for applying penalties or + # step pooling during the sampling/pooling process. + # Hence copy these tensors only when there are requests which + # need penalties/step_pooler to be applied. + prompt_token_ids = ( + self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None + ) + + # Only set output_token_ids if required by the current requests' + # sampling parameters. + needs_output_token_ids = ( + not self.no_penalties + or bool(self.bad_words_token_ids) + or self.logitsprocs_need_output_token_ids + ) + output_token_ids = ( + cast(list[list[int]], self.req_output_token_ids) + if needs_output_token_ids + else [] + ) allowed_token_ids_mask: Optional[torch.Tensor] = None if not self.no_allowed_token_ids: assert self.allowed_token_ids_mask is not None - copy_slice(self.allowed_token_ids_mask_cpu_tensor, - self.allowed_token_ids_mask, num_reqs) + copy_slice( + self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, + num_reqs, + ) allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] return SamplingMetadata( @@ -699,28 +818,23 @@ class InputBatch: frequency_penalties=self.frequency_penalties[:num_reqs], presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], - output_token_ids=cast(list[list[int]], self.req_output_token_ids), + output_token_ids=output_token_ids, + spec_token_ids=cast(list[list[int]], self.spec_token_ids), no_penalties=self.no_penalties, allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=self.bad_words_token_ids, logitsprocs=self.logitsprocs, ) - @property - def pooling_metadata(self) -> PoolingMetadata: - if len(self.pooling_params) == 0: - pooling_params = [] - else: - # Note, for now this assumes that all request in the batch - # are either sampling or pooling requests - assert len(self.req_ids) == len(self.pooling_params) - pooling_params = [ - self.pooling_params[req_id] for req_id in self.req_ids - ] + def get_pooling_params(self) -> list[PoolingParams]: + assert len(self.req_ids) == len(self.pooling_params) + return [self.pooling_params[req_id] for req_id in self.req_ids] + + def get_pooling_metadata(self) -> PoolingMetadata: + pooling_params = self.get_pooling_params() return PoolingMetadata( - prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]), + prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) @@ -739,9 +853,8 @@ class InputBatch: # Use the value of vocab_size as a pad since we don't have a # token_id of this value. for i in range(num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -757,15 +870,61 @@ class InputBatch: 3. lora_requests: Set of relevant LoRA requests. 
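The `_make_sampling_metadata` changes above only sync a sampling parameter to the GPU when some live request actually needs it, and only for the first `num_reqs` slots of the persistent buffers. The helper below mirrors how `copy_slice(...)` is used in the hunk; its definition and the buffer setup are assumptions based on that usage, not vLLM's implementation.

```python
# Sketch of the copy_slice(...) pattern above, assuming each sampling
# parameter lives in a pinned CPU tensor mirrored by a same-sized GPU tensor.
# Only the first `length` entries (the live requests) are copied, without
# forcing a host/device sync, and only when the batch needs the parameter.
import torch


def copy_slice(cpu: torch.Tensor, gpu: torch.Tensor, length: int) -> torch.Tensor:
    """Copy cpu[:length] into gpu[:length] asynchronously and return the slice."""
    gpu[:length].copy_(cpu[:length], non_blocking=True)
    return gpu[:length]


device = "cuda" if torch.cuda.is_available() else "cpu"
pin = device == "cuda"
max_reqs, num_reqs = 256, 3

temperature_cpu = torch.zeros(max_reqs, dtype=torch.float32, pin_memory=pin)
temperature_gpu = torch.zeros(max_reqs, dtype=torch.float32, device=device)

all_greedy = False
# Mirrors the gating above: skip the copy entirely when every request is
# greedy and temperature would go unused.
temperature = None if all_greedy else copy_slice(temperature_cpu, temperature_gpu, num_reqs)
```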
""" - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests + def set_async_sampled_token_ids( + self, + sampled_token_ids_cpu: torch.Tensor, + async_copy_ready_event: torch.cuda.Event, + ) -> None: + """ + In async scheduling case, store ref to sampled_token_ids_cpu + tensor and corresponding copy-ready event. Used to repair + output_token_ids prior to sampling, if needed by logits processors. + """ + if self.sampling_metadata.output_token_ids: + self.sampled_token_ids_cpu = sampled_token_ids_cpu + self.async_copy_ready_event = async_copy_ready_event + else: + self.sampled_token_ids_cpu = None + self.async_copy_ready_event = None + + def update_async_output_token_ids(self) -> None: + """ + In async scheduling case, update output_token_ids in sampling metadata + from prior steps sampled token ids once they've finished copying to CPU. + This is called right before they are needed by the logits processors. + """ + output_token_ids = self.sampling_metadata.output_token_ids + if self.sampled_token_ids_cpu is None or not output_token_ids: + # Output token ids not needed or not async scheduling. + return + + assert self.prev_req_id_to_index is not None + sampled_token_ids = None + for index, req_id in enumerate(self.req_ids): + prev_index = self.prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + req_output_token_ids = output_token_ids[index] + if not req_output_token_ids or req_output_token_ids[-1] != -1: + # Final output id is not a placeholder, some tokens must have + # been discarded after a kv-load failure. + continue + if sampled_token_ids is None: + assert self.async_copy_ready_event is not None + self.async_copy_ready_event.synchronize() + sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist() + # Replace placeholder token id with actual sampled id. 
+ req_output_token_ids[-1] = sampled_token_ids[prev_index] + @property def num_reqs(self) -> int: return len(self.req_id_to_index) @@ -788,9 +947,11 @@ class InputBatch: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 72c38af45b70e..ec824f6d6bf5e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import gc import itertools import time @@ -9,66 +8,119 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast import numpy as np import torch import torch.distributed import torch.nn as nn from tqdm import tqdm +from typing_extensions import TypeAlias import vllm.envs as envs from vllm.attention import Attention, AttentionType -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, MultipleOf +from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, +) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import (is_mixture_of_experts, - supports_eagle3, - supports_transcription) +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache +from vllm.model_executor.models.interfaces import ( + 
SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, +) from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, - supports_dynamo) -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + DeviceMemoryProfiler, + GiB_bytes, + cdiv, + check_use_alibi, + get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, + round_up, + supports_dynamo, +) +from vllm.utils.jsontree import json_map_leaves +from vllm.v1.attention.backends.flash_attn import AttentionMetadata +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - make_kv_sharing_fast_prefill_attention_metadata, - reorder_batch_to_split_decodes_and_prefills) + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + create_fast_prefill_custom_backend, + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - LogprobsTensors, ModelRunnerOutput) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, +) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -78,32 +130,89 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from 
vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, KVConnectorOutput) +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + UBatchSlices, + check_ubatch_thresholds, +) +from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, - gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from .utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) if TYPE_CHECKING: - import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - xgr_torch_compile = LazyLoader( - "xgr_torch_compile", globals(), - "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") logger = init_logger(__name__) +AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] +# list when ubatching is enabled +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] + + +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.cuda.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self.async_copy_ready_event = torch.cuda.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self.sampled_token_ids_cpu = self._sampled_token_ids.to( + "cpu", non_blocking=True + ) + self.async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. 
+ """ + self.async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -121,8 +230,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + + init_batch_invariance() model_config = self.model_config cache_config = self.cache_config @@ -134,21 +246,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self.is_pooling_model = model_config.pooler_config is not None - self.is_multimodal_raw_input_supported = ( - model_config.is_multimodal_raw_input_supported) + self.is_pooling_model = model_config.runner_type == "pooling" + self.enable_prompt_embeds = model_config.enable_prompt_embeds + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model + ) + # This will be overridden in load_model() + self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + self.broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) + # Model-related. - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn @@ -156,7 +282,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) + + if self.model_config.is_encoder_decoder: + # Maximum length of the encoder input, only for encoder-decoder + # models. 
+ self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens + else: + self.max_encoder_len = 0 # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -176,8 +310,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -188,21 +322,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) # type: ignore else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) self.rejection_sampler = RejectionSampler() # Request states. self.requests: dict[str, CachedRequestState] = {} + self.comm_stream = torch.cuda.Stream() # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside @@ -213,53 +349,85 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. + custom_logitsprocs = model_config.logits_processors self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + custom_logitsprocs, + ), + # We currently don't know whether a particular custom logits processor + # uses output token ids so we set this conservatively. + logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, ) + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = ( + torch.cuda.Stream() if self.use_async_scheduling else None + ) + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. 
- if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes) + ) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.query_start_loc = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device=self.device) - self.seq_lens = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + if self.dcp_world_size > 1: + self.dcp_local_seq_lens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + # Because inputs_embeds may be bfloat16 and we don't need a numpy + # version of this tensor, avoid a RuntimeError by not creating a + # numpy buffer. + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + self.num_discarded_requests = 0 - # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -273,52 +441,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # identical position IDs, making M-RoPE functionally equivalent to # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 - self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), - dtype=torch.int64, - device=self.device) - self.mrope_positions_cpu = torch.zeros( - (3, self.max_num_tokens + 1), - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.mrope_positions_np = self.mrope_positions_cpu.numpy() + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) - # Only relevant for models using ALiBi (e.g, MPT) - self.use_alibi = check_use_alibi(model_config) + # CUDA event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: Optional[torch.cuda.Event] = None + if self.use_async_scheduling: + self.prepare_inputs_event = torch.cuda.Event() + # Start in a completed state. 
+ self.prepare_inputs_event.record(torch.cuda.default_stream()) - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) - # NOTE(woosuk): These tensors are "stateless", i.e., they are literally - # a faster version of creating a new tensor every time. Thus, we should - # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) - self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.seq_lens_np = self.seq_lens_cpu.numpy() + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -330,19 +473,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = (MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None) + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) self.reorder_batch_threshold: Optional[int] = None @@ -352,34 +503,62 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.runner_only_attn_layers: set[str] = set() # Cached outputs. 
- self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None + self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None + self.transfer_event = torch.cuda.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_model_len, 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + + def _get_positions(self, num_tokens: Any): + if isinstance(num_tokens, int): + if self.uses_mrope: + return self.mrope_positions.gpu[:, :num_tokens] + return self.positions.gpu[:num_tokens] + else: + if self.uses_mrope: + return self.mrope_positions.gpu[:, num_tokens] + return self.positions.gpu[num_tokens] + + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() - num_reqs = self.input_batch.num_reqs - num_pooling_reqs = len(self.input_batch.pooling_params) - - if num_pooling_reqs == 0: + if not self.is_pooling_model: return model_kwargs - # This does nontrivial work. - pooling_params = self.input_batch.pooling_metadata.pooling_params - - assert num_pooling_reqs == num_reqs + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens[:num_reqs] + seq_lens = self.seq_lens.gpu[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -388,7 +567,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -410,15 +590,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return if self.reorder_batch_threshold is not None: + # NOTE(lucas): currently no backend supports the custom masking + # required for DCP with q_len > 1, so we assert here. Remove this + # assert once the custom mask is support is added to FA3. + if ( + self.dcp_world_size > 1 + and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA" + ): + assert self.reorder_batch_threshold == 1, ( + "DCP not support reorder_batch_threshold > 1 now." + ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. 
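The hunks above replace the separate `*_cpu` / device tensor pairs with `self._make_buffer(...)`, which hands back a `CpuGpuBuffer` exposing a pinned CPU tensor, a matching GPU tensor, an optional NumPy view, and a `copy_to_gpu()` helper. The sketch below shows one plausible shape for such a paired buffer, assuming only the `.np` / `.cpu` / `.gpu` / `copy_to_gpu()` surface visible in the diff; it is not vLLM's `CpuGpuBuffer` implementation.

```python
# A minimal paired host/device buffer in the spirit of _make_buffer(...):
# write through the NumPy/CPU view, then push a prefix to the GPU with one
# non-blocking copy. Field names follow the diff's usage, but the internals
# here are an assumption.
from typing import Optional

import numpy as np
import torch


class PairedBuffer:
    def __init__(self, *size: int, dtype: torch.dtype, device: str,
                 pin_memory: bool, with_numpy: bool = True):
        self.cpu = torch.zeros(*size, dtype=dtype, pin_memory=pin_memory)
        self.gpu = torch.zeros(*size, dtype=dtype, device=device)
        # Some dtypes (e.g. bfloat16) have no NumPy equivalent, so the view
        # is optional, as the inputs_embeds comment above notes.
        self.np: Optional[np.ndarray] = self.cpu.numpy() if with_numpy else None

    def copy_to_gpu(self, length: Optional[int] = None) -> torch.Tensor:
        if length is None:
            self.gpu.copy_(self.cpu, non_blocking=True)
            return self.gpu
        self.gpu[:length].copy_(self.cpu[:length], non_blocking=True)
        return self.gpu[:length]


device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = PairedBuffer(1024, dtype=torch.int32, device=device,
                         pin_memory=device == "cuda")
input_ids.np[:3] = [101, 7592, 102]   # stage token ids on the host
input_ids.copy_to_gpu(3)              # push only the live prefix
```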
def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -439,7 +629,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -450,12 +639,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.remove_request(req_id) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -479,14 +664,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: generator = None - if pooling_params: + if self.is_pooling_model: + assert pooling_params is not None task = pooling_params.task assert task is not None, "You did not set `task` in the API" @@ -497,8 +685,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -523,9 +711,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + num_output_tokens = req_data.num_output_tokens[i] # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + req_index = self.input_batch.req_id_to_index.get(req_id) if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, @@ -534,29 +725,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. 
req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + # Some output tokens were discarded due to a sync-KV-load + # failure. Align the cached state. + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: + assert req_index is None assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - req_index = self.input_batch.req_id_to_index.get(req_id) + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. + resumed_token_ids = req_data.resumed_req_token_ids[i] + assert resumed_token_ids is not None + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not @@ -565,11 +772,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -578,23 +783,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. 
self.input_batch.num_tokens[req_index] += num_spec_tokens + self.input_batch.spec_token_ids[req_index] = spec_token_ids # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -608,13 +815,54 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor + ) -> None: + """Update the cached states after model execution. + + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] audio_feature_lengths = [] use_audio_in_video = False - for mm_item in req_state.mm_kwargs: + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue mm_input = mm_item.get_data() if (t := mm_input.get("image_grid_thw")) is not None: image_grid_thw.append(t.tolist()) @@ -627,8 +875,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if mm_input.get("use_audio_in_video") is True: use_audio_in_video = True - req_state.mrope_positions, req_state.mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." 
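`_update_states_after_model_execute` above derives the number of accepted tokens per sequence by appending a sentinel `-1` column to the sampler output and taking the first `argmax` of the equality mask. A worked example of that trick follows, with made-up token ids; it only illustrates the counting, not the hybrid-model state shifting it feeds.

```python
# Worked example of the "append a -1 column, compare, argmax" trick used
# above: rejected draft positions are already marked with -1, and the
# appended column guarantees argmax finds a -1 even when every token of a
# row was accepted.
import torch

# 3 sequences, 1 bonus + 3 draft slots each; -1 marks rejected/unused slots.
output_token_ids = torch.tensor(
    [
        [11, 12, 13, 14],   # all 4 tokens accepted
        [21, 22, -1, -1],   # 2 accepted
        [31, -1, -1, -1],   # only the bonus token accepted
    ]
)

sentinel = torch.full((output_token_ids.size(0), 1), -1,
                      device=output_token_ids.device)
num_accepted = (
    (torch.cat([output_token_ids, sentinel], dim=1) == -1).int().argmax(-1)
)
assert num_accepted.tolist() == [4, 2, 1]
```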
+ + req_state.mrope_positions, req_state.mrope_position_delta = ( + self.model.get_mrope_input_positions( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, image_grid_thw=image_grid_thw, @@ -637,32 +887,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", ) -> BatchedTensorInputs: - if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102 + if not scheduler_output or not self.is_multimodal_raw_input_only_model: return {} mm_kwargs = list[MultiModalKwargsItem]() for req in scheduler_output.scheduled_new_reqs: - mm_kwargs.extend(req.mm_kwargs) + for feature in req.mm_features: + if feature.data is not None: + mm_kwargs.append(feature.data) # Input all modalities at once + model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) return mm_kwargs_combined def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if not self.is_multimodal_raw_input_supported: + if not self.is_multimodal_raw_input_only_model: return {} + mm_budget = self.mm_budget assert mm_budget is not None @@ -689,15 +945,117 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return cu_num_tokens, arange - def _prepare_inputs( + def _prepare_input_ids( + self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + ) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the GPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= prev_index == flattened_index + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. 
+ self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids.cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids.gpu[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) + if self.enable_prompt_embeds: + self.is_token_ids.gpu[:num_commmon_tokens] = True + return + # Upload the index tensors asynchronously so the scatter can be non-blocking. + input_ids_index_tensor = torch.tensor( + flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0 + ], + ) + + def _get_encoder_seq_lens( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata], int]: + kv_cache_spec: KVCacheSpec, + num_reqs: int, + ) -> Optional[np.ndarray]: + if not isinstance(kv_cache_spec, CrossAttentionSpec): + return None + + # Build encoder_seq_lens array mapping request indices to + # encoder lengths for inputs scheduled in this batch + encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) + for req_id in scheduler_output.scheduled_encoder_inputs: + req_index = self.input_batch.req_id_to_index[req_id] + encoder_seq_lens[req_index] = self.max_encoder_len + + return encoder_seq_lens + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> tuple[ + PerLayerAttnMetadata, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + int, + Optional[UBatchSlices], + Optional[torch.Tensor], + bool, + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, - logits_indices, spec_decode_metadata + logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens, use_cascade_attn ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -717,19 +1075,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. 
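The `_prepare_input_ids` scatter above handles the async-scheduling case where the previous step's sampled tokens only exist on the GPU: for each request carried over from the last iteration, the flattened position of its final input token is computed and the previously sampled id is scattered into `input_ids` directly on the device. A small sketch of that scatter, with hypothetical shapes and indices:

```python
# Sketch of the device-side scatter in _prepare_input_ids: requests present
# in both steps feed last step's sampled id, still resident on the GPU, so
# it is scattered straight into the flattened input_ids buffer instead of
# round-tripping through the CPU.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Flattened input ids for this step (3 requests scheduling 1 token each).
input_ids = torch.zeros(3, dtype=torch.int64, device=device)
# Last step's sampled token ids, shape [num_prev_reqs, 1], still on device.
prev_sampled = torch.tensor([[111], [222], [333]], device=device)

# For each request common to both steps: its previous batch index and the
# flattened index of its last scheduled token this step (cu_num_tokens - 1).
prev_common_req_indices = [0, 2]   # request order changed between steps
flattened_indices = [1, 2]

index = torch.tensor(flattened_indices, dtype=torch.int64,
                     pin_memory=device == "cuda").to(device, non_blocking=True)
prev_idx = torch.tensor(prev_common_req_indices, dtype=torch.int64,
                        pin_memory=device == "cuda").to(device, non_blocking=True)

input_ids.scatter_(dim=0, index=index, src=prev_sampled[prev_idx, 0])
```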
- positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -740,56 +1098,138 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) + token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) + if self.enable_prompt_embeds: + is_token_ids = self.input_batch.is_token_ids.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. + if self.input_batch.req_prompt_embeds: + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) + + output_idx += num_sched + + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. 
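The hunk above gathers this step's input token ids by flattening the `[max_num_reqs, max_model_len]` CPU token table and indexing it with `position + req_index * max_model_len`, using `torch.index_select` rather than `np.take` for speed on large tensors. A worked example with toy sizes:

```python
# Worked example of the flattened gather above: token_ids_cpu is a 2-D table
# [num_reqs, max_model_len]; adding req_index * max_model_len to each position
# turns (request, position) pairs into indices of the flattened table.
import numpy as np
import torch

max_model_len = 8
token_ids_cpu = torch.arange(3 * max_model_len, dtype=torch.int32).reshape(3, max_model_len)

# Request 0 schedules 2 tokens starting at position 3; request 2 schedules
# 1 token at position 5 (request 1 is idle this step).
req_indices = np.array([0, 0, 2])
positions = np.array([3, 4, 5])
token_indices = positions + req_indices * max_model_len

out = torch.empty(len(token_indices), dtype=torch.int32)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
assert out.tolist() == [3, 4, 21]
```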
- self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1]) - self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True) - query_start_loc = self.query_start_loc[:num_reqs + 1] + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + + # Disable DP padding when running eager to avoid excessive padding when + # running prefills. This lets us set enforce_eager on the prefiller in + # a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens_unpadded, + parallel_config=self.parallel_config, + allow_microbatching=True, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, + ) + + self.seq_lens.np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. - self.seq_lens_np[num_reqs:].fill(0) - self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True) - seq_lens = self.seq_lens[:num_reqs] - max_seq_len = self.seq_lens_np[:num_reqs].max().item() + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) # Copy the tensors to the GPU. 
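The new bookkeeping above marks requests whose sequence length after this step still falls short of their total token count (e.g. chunked prefills that have not yet reached their last token), so their sampled outputs can be cleared later. A small NumPy example of that mask, with made-up lengths:

```python
# Worked example of the discard mask above: a request is excluded from
# sampling when the sequence length reached this step is still shorter than
# the request's total token count (a partial / chunked prefill).
import numpy as np

num_computed = np.array([0, 16, 40])     # tokens already processed
num_scheduled = np.array([8, 16, 1])     # tokens scheduled this step
num_tokens = np.array([32, 32, 41])      # full length of each request

seq_lens = num_computed + num_scheduled  # [8, 32, 41]
discard_mask = seq_lens < num_tokens     # [True, False, False]
discard_indices = np.nonzero(discard_mask)[0]
assert discard_indices.tolist() == [0]
```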
- self.input_ids[:total_num_scheduled_tokens].copy_( - self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) else: # Common case (1D positions) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) + self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -797,87 +1237,99 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 + num_draft_tokens = None spec_decode_metadata = None else: # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + # For chunked prefills, use -1 as mask rather than 0, as guided + # decoding may rollback speculative tokens. + num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices + # For DECODE only cuda graph of some attention backends (e.g., GDN). + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[num_reqs:].fill(-1) + self.num_decode_draft_tokens.copy_to_gpu() + logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: - assert self.kv_sharing_fast_prefill_logits_indices is not None - num_logits = logits_indices.shape[0] - assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) - # There might have leftover indices in logits_indices[num_logits:] - # from previous iterations, whose values may be greater than the - # batch size in the current iteration. To ensure indices are always - # valid, we fill the padded indices with the last index. - self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. 
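Without speculative tokens, the hunk above picks one logit position per request: the index of its last scheduled token, obtained as `query_start_loc[1:] - 1` from the cumulative token counts. A toy example:

```python
# Worked example of logits_indices = query_start_loc[1:] - 1: with cumulative
# token counts [0, 2, 7, 10] for three requests scheduling [2, 5, 3] tokens,
# the last token of each request sits at flattened positions 1, 6 and 9.
import torch

num_scheduled_tokens = torch.tensor([2, 5, 3])
query_start_loc = torch.zeros(len(num_scheduled_tokens) + 1, dtype=torch.int64)
query_start_loc[1:] = torch.cumsum(num_scheduled_tokens, dim=0)

logits_indices = query_start_loc[1:] - 1
assert logits_indices.tolist() == [1, 6, 9]
```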
- num_logits_padded = self.vllm_config.pad_for_cudagraph( - num_logits) - else: - num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded] + logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices ) - attn_metadata: dict[str, Any] = {} + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + use_cascade_attn = False # Used in the below loop. - query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1] - seq_lens_cpu = self.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] spec_decode_common_attn_metadata = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): + encoder_seq_lens = self._get_encoder_seq_lens( + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs + ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( (num_reqs, 1), dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, non_blocking=True) - slot_mapping = torch.zeros((total_num_scheduled_tokens, ), - dtype=torch.int32, - pin_memory=self.pin_memory, - device="cpu").to(self.device, - non_blocking=True) + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens,), + dtype=torch.int64, + device=self.device, + ) num_common_prefix_blocks = 0 else: blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[: - total_num_scheduled_tokens] + blk_table_tensor = blk_table.get_device_tensor(num_reqs) + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) - num_common_prefix_blocks = ( - scheduler_output. 
- num_common_prefix_blocks[kv_cache_group_id]) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ + kv_cache_group_id + ] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -891,61 +1343,92 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), causal=True, + encoder_seq_lens=encoder_seq_lens, + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 + else None, ) - if self.speculative_config and \ - spec_decode_common_attn_metadata is None: - spec_decode_common_attn_metadata = common_attn_metadata + if self.speculative_config and spec_decode_common_attn_metadata is None: + if isinstance(self.drafter, EagleProposer): + if ( + self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names + ): + spec_decode_common_attn_metadata = common_attn_metadata + else: + spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, num_common_prefix_blocks, - kv_cache_group_spec.kv_cache_spec, + attn_group.kv_cache_spec, builder, ) - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) - - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], ) - for layer_name in attn_group.layer_names: - if (self.cache_config.kv_sharing_fast_prefill - and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata + ) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list + ): + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert isinstance(attn_metadata, dict) + 
attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + # disable cascade attention when DBO + if ubatch_slices is not None: + use_cascade_attn = False # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + return ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + max_num_scheduled_tokens, + ubatch_slices, + num_tokens_across_dp, + use_cascade_attn, + ) def _compute_cascade_attn_prefix_len( self, @@ -1017,18 +1500,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # this case. num_reqs = len(num_scheduled_tokens) common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() + ) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1048,17 +1533,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] - num_prompt_tokens = len(req.prompt_token_ids) + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1072,9 +1555,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): src_start = num_computed_tokens src_end 
= num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - req.mrope_positions[:,src_start:src_end] - + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1083,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dst_end = mrope_pos_ptr + completion_part_len MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions_np, + out=self.mrope_positions.np, out_offset=dst_start, mrope_position_delta=req.mrope_position_delta, context_len=num_computed_tokens + prompt_part_len, @@ -1114,10 +1597,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1128,26 +1613,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] - draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] metadata = SpecDecodeMetadata( @@ -1160,21 +1651,76 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return metadata - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _prepare_kv_sharing_fast_prefill( + self, + logits_indices: torch.Tensor, + ) -> torch.Tensor: + assert self.kv_sharing_fast_prefill_logits_indices is not None + num_logits = logits_indices.shape[0] + assert num_logits > 0 + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) + # There might have leftover indices in logits_indices[num_logits:] + # from previous iterations, whose values may be greater than the + # batch size in the current iteration. To ensure indices are always + # valid, we fill the padded indices with the last index. 
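# Illustrative sketch (not from the patch): the cumsum-and-arange trick in
# _calc_spec_decode_metadata can be checked directly in NumPy. The inputs
# below are inferred from the worked-example comments above
# (num_sampled_tokens = [4, 1, 3, 1, 2], per-request scheduled-token
# cumsum = [4, 104, 107, 207, 209]); the three steps reproduce the
# annotated intermediate values.
import numpy as np

num_sampled_tokens = np.array([4, 1, 3, 1, 2], dtype=np.int32)
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)

# Step 1: cumulative sampled-token counts plus a per-request local arange.
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
arange = np.concatenate(
    [np.arange(n, dtype=np.int32) for n in num_sampled_tokens]
)
print(cu_num_sampled_tokens.tolist())  # [4, 5, 8, 9, 11]
print(arange.tolist())                 # [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]

# Step 2: repeat each request's first sampled-token offset once per sample.
logits_indices = np.repeat(
    cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
)
print(logits_indices.tolist())  # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]

# Step 3: add the local arange to get absolute positions in the batch.
logits_indices += arange
print(logits_indices.tolist())  # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]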
+ self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] + return logits_indices_padded + + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. + + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return - + return [], [] # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output + ) + + if not mm_kwargs: + return # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1183,39 +1729,54 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the same batch while still being able to benefit from batching # multimodal inputs. The proper solution should be reordering the # encoder outputs. + model = cast(SupportsMultiModal, self.model) encoder_outputs = [] - for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, + for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. A list or tuple (length: num_items) of tensors, each of shape - # (feature_size, hidden_size) in case the feature size is dynamic - # depending on the input multimodal items. 
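# Illustrative sketch (not from the patch): per the comment above, encoder
# inputs are batched per *consecutive* run of the same modality so that
# item order is preserved for mixed-modality requests. The real
# group_mm_kwargs_by_modality also collates the tensors and moves them to
# the device; this toy version only shows the grouping idea.
import itertools

mm_items = [
    ("image", "img-0"),
    ("image", "img-1"),
    ("video", "vid-0"),
    ("image", "img-2"),
]
for modality, group in itertools.groupby(mm_items, key=lambda item: item[0]):
    batch = [payload for _, payload in group]
    # Each run becomes one encoder call: image x2, then video x1, then image x1.
    print(modality, batch)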
- curr_group_outputs = self.model.get_multimodal_embeddings( - **mm_kwargs_group) + # (ekhvedchenia): Temporary hack to limit peak memory usage when + # processing multimodal data.This solves the issue with scheduler + # putting too many video samples into a single batch. Scheduler + # uses pruned vision tokens count to compare it versus compute + # budget which is incorrect (Either input media size or non-pruned + # output vision tokens count should be considered) + curr_group_outputs = [] + + if self.is_multimodal_pruning_enabled and modality == "video": + micro_batch_size = 1 + for i in range(0, num_items, micro_batch_size): + micro_batch_mm_inputs = dict( + (k, v[i : i + micro_batch_size]) + for k, v in mm_kwargs_group.items() + ) + + micro_batch_outputs = model.get_multimodal_embeddings( + **micro_batch_mm_inputs + ) + + curr_group_outputs.extend(micro_batch_outputs) + else: + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, + # each of shape (feature_size, hidden_size) in case the feature + # size is dynamic depending on the input multimodal items. + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, ) + encoder_outputs.extend(curr_group_outputs) - for output in curr_group_outputs: - encoder_outputs.append(output) - - # Cache the encoder outputs. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) @@ -1224,16 +1785,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self, scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + mm_embeds = list[torch.Tensor]() + is_mm_embed = self.is_mm_embed.cpu + is_mm_embed[:total_num_scheduled_tokens] = False + + req_start_idx = 0 + should_sync_mrope_positions = False + for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + mm_embeds_req: list[torch.Tensor] = [] + + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens - mm_positions = req_state.mm_positions - for i, pos_info in enumerate(mm_positions): + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens + + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1252,25 +1822,87 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx - assert req_id in 
self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] + + mm_hash = mm_feature.identifier + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) + mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], is_embed=is_embed, ) - mm_embeds.append(mm_embeds_item) - return mm_embeds + mm_embeds_req.append(mm_embeds_item) + + if self.is_multimodal_pruning_enabled and self.uses_mrope: + assert req_state.mrope_positions is not None + should_sync_mrope_positions = True + mm_embeds_req, new_mrope_positions, new_delta = ( + self.model.recompute_mrope_positions( + input_ids=req_state.prompt_token_ids, + multimodal_embeddings=mm_embeds_req, + mrope_positions=req_state.mrope_positions, + num_computed_tokens=req_state.num_computed_tokens, + ) + ) + req_state.mrope_positions.copy_(new_mrope_positions) + req_state.mrope_position_delta = new_delta + + mm_embeds.extend(mm_embeds_req) + req_start_idx += num_scheduled_tokens + + is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) + + if should_sync_mrope_positions: + self._calc_mrope_positions(scheduler_output) + self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) + + return mm_embeds, is_mm_embed + + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models. + + This method extracts multimodal input features from scheduled encoder + inputs and formats them for the encoder-decoder model forward pass. + """ + # Batch the multi-modal inputs using the helper method. + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + + if not mm_kwargs: + return {} + + # Group MM kwargs by modality and extract features + model = cast(SupportsMultiModal, self.model) + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ): + # Add the grouped features to encoder_features dict + # This allows the model to receive them as kwargs (e.g., + # input_features=...) + encoder_features.update(mm_kwargs_group) + + return encoder_features def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. - if isinstance(self.model, CUDAGraphWrapper): + if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): return self.model.unwrap() return self.model @@ -1296,14 +1928,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): + if ( + self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks + ): supported_tasks.remove("encode") - logger.info_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. 
" + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it." + ) + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1317,111 +1959,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return tuple(tasks) - def apply_grammar_bitmask( - self, - scheduler_output: "SchedulerOutput", - logits: torch.Tensor, - ): - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return - - # We receive the structured output bitmask from the scheduler, - # compacted to contain bitmasks only for structured output requests. - # The order of the requests in the bitmask is not guaranteed to be the - # same as the order of the requests in the gpu runner's batch. We need - # to sort the bitmask to match the order of the requests used here. - - # Get the batch indices of the structured output requests. - # Keep track of the number of speculative tokens scheduled for every - # request in the batch, as the logit indices are offset by this amount. - struct_out_req_batch_indices: dict[str, int] = {} - cumulative_offset = 0 - seq = sorted(self.input_batch.req_id_to_index.items(), - key=lambda x: x[1]) - for req_id, batch_index in seq: - logit_index = batch_index + cumulative_offset - cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: - struct_out_req_batch_indices[req_id] = logit_index - - out_indices = [] - - # Reorder the bitmask to match the order of the requests in the batch. - sorted_bitmask = np.full(shape=(logits.shape[0], - grammar_bitmask.shape[1]), - fill_value=-1, - dtype=grammar_bitmask.dtype) - cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) - for req_id, _ in seq: - logit_index = struct_out_req_batch_indices[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] - out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens - grammar_bitmask = sorted_bitmask - - # If the length of out indices and the logits have the same shape - # we don't need to pass indices to the kernel, - # since the bitmask is already aligned with the logits. - skip_out_indices = len(out_indices) == logits.shape[0] - - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. - grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() - - # Force use of the torch.compile implementation from xgrammar to work - # around issues with the Triton kernel in concurrent structured output - # scenarios. See PR #19565 and issues #19493, #18376 for details. 
- xgr_torch_compile.apply_token_bitmask_inplace_torch_compile( - logits, - grammar_bitmask.to(self.device, non_blocking=True), - indices=out_indices if not skip_out_indices else None, - ) - def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.compilation_config.pass_config. \ - enable_sequence_parallelism - if enabled_sp: - # When sequence parallelism is enabled, we always pad num_tokens - # to be a multiple of tensor_parallel_size (tp) earlier - assert num_tokens % tp == 0 - is_residual_scattered = tp > 1 and enabled_sp \ - and num_tokens % tp == 0 + is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens) # When sequence parallelism is enabled, the "residual" tensor is sharded # across tensor parallel ranks, so each rank only needs its own slice. if sync_self: assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): - is_scattered = k == "residual" and is_residual_scattered - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + is_scattered = k == "residual" and is_rs + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) + v[:copy_len], non_blocking=True + ) - return IntermediateTensors({ - k: - v[:num_tokens // tp] - if k == "residual" and is_residual_scattered else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1438,287 +2007,277 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. - # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + # This is where the second ubatch is adjusted to account for the padding. + # Should be called after attention metadata creation. 
This just pads + # the second ubatch slice out to the total number of tokens + # (num_tokens + padding) + @staticmethod + def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, - kv_connector_output: Optional[KVConnectorOutput], ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] - pooling_metadata = self.input_batch.pooling_metadata - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs] + pooling_metadata = self.input_batch.get_pooling_metadata() + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), device=hidden_states.device + ) + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] - # Pooling models D2H & synchronize occurs in pooler.py:build_output - raw_pooler_output = self.model.pooler( - hidden_states=hidden_states, pooling_metadata=pooling_metadata) + model = cast(VllmModelForPooling, self.model) + raw_pooler_output: PoolerOutput = model.pooler( + hidden_states=hidden_states, + pooling_metadata=pooling_metadata, + ) + raw_pooler_output = json_map_leaves( + lambda x: x.to("cpu", non_blocking=True), + raw_pooler_output, + ) + self._sync_device() pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - - output = raw_output.data if seq_len == prompt_len else None + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): + output = raw_output if seq_len == prompt_len else None pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, - kv_connector_output=kv_connector_output, ) - @torch.inference_mode() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) - - # Prepare the decoder inputs. 
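# Illustrative sketch (not from the patch): both the removed get_dp_padding
# above and the new allow_dp_padding path introduced earlier in this diff
# revolve around the same rule: every data-parallel rank pads its token
# count up to the maximum across ranks, so all ranks run matching CUDA
# graph / collective shapes. The real code gathers the counts with a
# collective; here they are simply given.
def dp_padding(num_tokens_per_rank: list[int]) -> list[int]:
    """How many pad tokens each DP rank adds so every rank matches the max."""
    max_tokens = max(num_tokens_per_rank)
    return [max_tokens - n for n in num_tokens_per_rank]

# One rank is running a long prefill; the decode-heavy ranks pad up to it,
# which is exactly the waste the eager-mode opt-out tries to avoid.
print(dp_padding([8, 8, 2048, 16]))  # [2040, 2040, 0, 2032]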
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = self._prepare_inputs(scheduler_output) - - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): # Use CUDA graphs. # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_input_tokens = round_up(num_scheduled_tokens, tp_size) - else: - num_input_tokens = num_scheduled_tokens + return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): + return round_up(num_scheduled_tokens, tp_size) + return num_scheduled_tokens + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + num_input_tokens: int, # Padded + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> tuple[ + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + ]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + is_first_rank = get_pp_group().is_first_rank # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.supports_mm_inputs: + if ( + self.supports_mm_inputs + and is_first_rank + and not self.model_config.is_encoder_decoder + ): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) - if self.supports_mm_inputs and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds_scheduled = self.model.get_input_embeddings( - input_ids=self.input_ids[:num_scheduled_tokens], - multimodal_embeddings=mm_embeds or None, + self.input_ids.gpu[:num_scheduled_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) # TODO(woosuk): Avoid the copy. Optimize. 
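# Illustrative sketch (not from the patch): _get_num_input_tokens above
# picks between two padding rules. Assuming pad_for_cudagraph rounds up to
# the nearest captured batch size (an assumption of this sketch), the
# decision looks roughly like this; the capture sizes and TP size below
# are made up.
import bisect

def pick_num_input_tokens(
    num_scheduled_tokens: int,
    cudagraph_batch_sizes: list[int],  # sorted capture sizes (hypothetical)
    use_cudagraphs: bool,
    sp_enabled: bool,
    tp_size: int,
) -> int:
    if (
        use_cudagraphs
        and cudagraph_batch_sizes
        and num_scheduled_tokens <= cudagraph_batch_sizes[-1]
    ):
        # Pad up to the smallest captured size that fits this batch.
        idx = bisect.bisect_left(cudagraph_batch_sizes, num_scheduled_tokens)
        return cudagraph_batch_sizes[idx]
    if sp_enabled and tp_size > 1:
        # Sequence parallelism shards tokens evenly across TP ranks.
        return (num_scheduled_tokens + tp_size - 1) // tp_size * tp_size
    return num_scheduled_tokens

print(pick_num_input_tokens(13, [1, 2, 4, 8, 16, 32], True, False, 1))  # 16
print(pick_num_input_tokens(100, [1, 2, 4, 8, 16, 32], True, True, 8))  # 104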
- self.inputs_embeds[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] model_kwargs = { **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } + elif self.enable_prompt_embeds and is_first_rank: + # Get the input embeddings for the tokens that are not input embeds, + # then put them into the appropriate positions. + # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the CUDA graph all the time. The v0 + # engine avoids this by "double compiling" the CUDA graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the CUDA graph will be more performant (like in the else case + # below). + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) + .squeeze(1) + ) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids.gpu[token_ids_idx] + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs(num_input_tokens) + input_ids = None else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] + positions = self.mrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions[:num_input_tokens] + positions = self.positions.gpu[:num_input_tokens] - if get_pp_group().is_first_rank: + if is_first_rank: intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) - - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) - - # Run the model. - # Use persistent buffers for CUDA graphs. 
- with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ), self.maybe_get_kv_connector_output( - scheduler_output) as kv_connector_output: - - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, + num_input_tokens, intermediate_tensors, True ) - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - assert isinstance(hidden_states, IntermediateTensors) - if not broadcast_pp_output: - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - get_pp_group().send_tensor_dict(hidden_states.tensors, - all_gather_group=get_tp_group()) - logits = None - else: - if self.input_batch.pooling_params: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, kv_connector_output) - - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + return ( + num_scheduled_tokens, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) + def _sample( + self, + logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - sampler_output = self.sampler( + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() + return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) - else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. 
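# Illustrative sketch (not from the patch): the comment above leans on a
# PyTorch behavior worth spelling out: indexing with an index tensor
# (advanced indexing) materializes a new tensor, so in-place updates to
# bonus_logits / target_logits cannot corrupt the original logits, while
# basic slicing would return a view that shares storage.
import torch

logits = torch.zeros(4, 5)
picked = logits[torch.tensor([0, 2])]  # advanced indexing -> copy
picked.fill_(1.0)
print(logits.sum().item())  # 0.0 -> the original logits are untouched

view = logits[0:2]  # basic slicing -> view
view.fill_(1.0)
print(logits.sum().item())  # 10.0 -> the slice shares storage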
- assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + predict_bonus_token=True, + ) + bonus_token_ids = sampler_output.sampled_token_ids + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + self._update_states_after_model_execute(output_token_ids) + return sampler_output + + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. - # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. 
+ req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -1726,21 +2285,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): scheduler_output.num_scheduled_tokens, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() + invalid_req_indices = [] + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. @@ -1748,7 +2325,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. req_ids = self.input_batch.req_ids - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1757,33 +2338,319 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert end_idx <= self.max_model_len, ( "Sampled token IDs exceed the max model length. 
" f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}") + f"{self.max_model_len}" + ) - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if self.speculative_config: - assert spec_decode_common_attn_metadata is not None - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + @contextmanager + def synchronize_input_prep(self): + if self.prepare_inputs_event is None: + yield + return + + # Ensure prior step has finished with reused CPU tensors. + # This is required in the async scheduling case because + # the CPU->GPU transfer happens async. + self.prepare_inputs_event.synchronize() + try: + yield + finally: + self.prepare_inputs_event.record() + + def _model_forward( + self, + input_ids: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any], + ) -> Any: + """Helper method to call the model forward pass. + + This method can be overridden by subclasses for model execution. + Motivation: We can inspect only this method versus + the whole execute_model, which has additional logic. + + Args: + input_ids: Input token IDs + positions: Token positions + intermediate_tensors: Tensors from previous pipeline stages + inputs_embeds: Input embeddings (alternative to input_ids) + **model_kwargs: Additional model arguments + + Returns: + Model output tensor + """ + return self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with record_function_or_nullcontext("Preprocess"): + with self.synchronize_input_prep(): + # Update persistent batch states. + self._update_states(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward( + scheduler_output, self.vllm_config + ) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" + ) + + # Prepare the decoder inputs. 
+ ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_across_dp, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) + + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + num_input_tokens = self._get_num_input_tokens( + scheduler_output.total_num_scheduled_tokens + ) + + ( + num_scheduled_tokens, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) = self._preprocess( + scheduler_output, num_input_tokens, intermediate_tensors ) - self.eplb_step() + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, + # Set cudagraph mode to none if calc_kv_scales is true. + if attn_metadata is not None: + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) + if any( + getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list + ): + cudagraph_runtime_mode = CUDAGraphMode.NONE + + # Run the model. + # Use persistent buffers for CUDA graphs. + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None + + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) + output.kv_connector_output = kv_connector_output + return output + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + else: + # Rare case. 
+ assert not self.is_pooling_model + + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) + + with record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + def propose_draft_token_ids(sampled_token_ids): + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + effective_drafter_max_model_len = self.max_model_len + if effective_drafter_max_model_len is None: + effective_drafter_max_model_len = self.model_config.max_model_len + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): + effective_drafter_max_model_len = ( + self.speculative_config.draft_model_config.max_model_len + ) + input_fits_in_drafter = spec_decode_common_attn_metadata and ( + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) + if use_padded_batch_for_eagle and input_fits_in_drafter: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampler_output.sampled_token_ids) + + with record_function_or_nullcontext("Bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) + + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
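# Illustrative sketch (not from the patch): the structured-output step that
# execute_model now delegates to apply_grammar_bitmask boils down to
# masking grammar-illegal vocabulary entries to -inf before sampling. The
# real bitmask from xgrammar is bit-packed int32 and reordered to match
# the batch; this toy uses a dense boolean mask instead.
import torch

logits = torch.randn(2, 6)  # (num_sampled_positions, vocab_size)
allowed = torch.tensor([
    [True, False, True, False, False, False],  # grammar allows tokens 0, 2
    [False, True, True, True, False, False],   # grammar allows tokens 1-3
])
masked = logits.masked_fill(~allowed, float("-inf"))
# Softmax over the masked logits only puts mass on grammar-legal tokens.
print(torch.softmax(masked, dim=-1))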
+ propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("EPLB"): + self.eplb_step() + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -1792,6 +2659,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + if not self.use_async_scheduling: + return output + + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None @@ -1806,30 +2692,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], + sampled_token_ids: Union[torch.Tensor, list[list[int]]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[torch.Tensor], + aux_hidden_states: Optional[list[torch.Tensor]], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) - draft_token_ids = self.propose_ngram_draft_token_ids( - sampled_token_ids) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs, + ) elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states else: indices = [] offset = 0 + assert spec_decode_metadata is not None for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -1841,115 +2736,108 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - req_ids = self.input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. 
- req_id = req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" + "padded-batch is disabled." + ) + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" + "padded-batch is enabled." + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) if spec_decode_metadata is None: + token_indices_to_sample = None # input_ids can be None for multimodal models. - target_token_ids = self.input_ids[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[:num_scheduled_tokens] + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) + else: + common_attn_metadata, token_indices, token_indices_to_sample = ( + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count, + ) + ) - target_token_ids = self.input_ids[token_indices] - # TODO(woosuk): Support M-RoPE. 
- target_positions = self.positions[token_indices] + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[token_indices] - mm_embeds = None + if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + mm_embed_inputs = self._gather_mm_embeddings( + scheduler_output, + shift_computed_tokens=1, + ) + else: + mm_embed_inputs = None draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, + mm_embed_inputs=mm_embed_inputs, ) - return draft_token_ids - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - # TODO(woosuk): Optimize. - req_ids = self.input_batch.req_ids - draft_token_ids: list[list[int]] = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. - draft_token_ids.append([]) - continue - - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :num_tokens]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1962,26 +2850,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state() num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + num_local_physical_experts * new_ep_size - num_logical_experts + ) + assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 + old_ep_size = ( + old_global_expert_indices.shape[1] // num_local_physical_experts + ) rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) + old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } else: global_expert_load = None @@ -1993,35 +2879,50 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) if self.lora_config: - self.model = self.load_lora_model(self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device) + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: + if not supports_eagle3(self.get_model()): raise RuntimeError( "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") + "aux_hidden_state_outputs was requested" + ) + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + + self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and 
%.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) prepare_communication_buffer_for_model(self.model) - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.get_model()) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) + + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, @@ -2032,40 +2933,72 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.dynamo_as_is_count += 1 - self.model.compile( - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) + self.model.compile(fullgraph=True, backend=backend) return # for other compilation levels, cudagraph behavior is controlled by # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) + elif self.parallel_config.enable_dbo: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) + else: + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) + + def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: + """Extract Eagle3 auxiliary layer indices from speculative config. + + These indices specify which hidden states from the base model should + be used as auxiliary inputs for the Eagle3 drafter model during + speculative decoding. + + Returns: + Tuple of layer indices if found in draft model config, + None otherwise. + """ + if not (self.speculative_config and self.speculative_config.draft_model_config): + return None + + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + + return None def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
+ ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model = self.get_model() - model_loader.load_weights(model, model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", ) -> None: - model = self.get_model() TensorizerLoader.save_model( - model, + self.get_model(), tensorizer_config=tensorizer_config, model_config=self.model_config, ) @@ -2090,9 +3023,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get metadata for this request. request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2100,7 +3038,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. @@ -2129,28 +3068,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # If this is a partial request (i.e. chunked prefill), # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] - offset = self.query_start_loc_np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] - logits = self.model.compute_logits(prompt_hidden_states, None) + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. 
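The prompt-logprobs hunk above walks the prompt in chunks: compute logits for a slice of hidden states, gather the logprob of each "next" token, and copy that slice into pre-allocated CPU tensors with non-blocking transfers. A minimal standalone sketch of that chunked gather-and-copy pattern (plain PyTorch; the tensor names, shapes, and chunking policy are illustrative, not the runner's actual buffers):

import torch


# Chunked logprob gathering with async GPU->CPU copies (illustrative only).
def gather_target_logprobs_chunked(logits: torch.Tensor,
                                   target_ids: torch.Tensor,
                                   chunk_size: int) -> torch.Tensor:
    num_tokens = logits.shape[0]
    # Pinned CPU memory is what keeps the non-blocking copies asynchronous.
    out = torch.empty(num_tokens, dtype=torch.float32,
                      device="cpu", pin_memory=logits.is_cuda)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        logprobs = torch.log_softmax(logits[start:end].float(), dim=-1)
        picked = logprobs.gather(1, target_ids[start:end].unsqueeze(1)).squeeze(1)
        out[start:end].copy_(picked, non_blocking=True)  # GPU -> CPU, async
    if logits.is_cuda:
        torch.cuda.synchronize()  # make the copies visible before reading out
    return out


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = torch.randn(10, 32, device=device)
    targets = torch.randint(0, 32, (10,), device=device)
    print(gather_target_logprobs_chunked(logits, targets, chunk_size=4))

The caller must synchronize (or wait on an event) before reading the CPU-side results, which is what the runner's later bookkeeping step takes care of.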
@@ -2178,8 +3119,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @@ -2202,14 +3144,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @functools.cache def rand_input_ids() -> torch.Tensor: return torch.randint_like( - self.input_ids, + self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) @@ -2219,10 +3161,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=self.max_model_len, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data @@ -2230,22 +3175,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_items = [dummy_mm_item] * max_items_per_batch - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - )) + model = cast(SupportsMultiModal, self.model) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) @torch.inference_mode() def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, + allow_microbatching: bool = True, skip_eplb: bool = False, is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -2254,26 +3206,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_tokens: Number of tokens to run the dummy forward pass. cudagraph_runtime_mode: used to control the behavior. + - if not set will determine the cudagraph mode based on using + the self.cudagraph_dispatcher. - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.FULL: Full cudagraph, attention metadata is needed. - force_attention: If True, always create attention metadata. Used to + force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. 
+ create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. + remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode in { - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL - } - - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens += num_pad + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and - # cudagraph_mode.seperate_routine(). This means that we are using + # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. # uniform decode batches. A uniform decode batch means that all # requests have identical query length, except a potential virtual @@ -2285,18 +3239,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. - max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - if uniform_decode: - num_reqs = cdiv(num_tokens, max_query_len) - assert num_reqs <= max_num_reqs, \ - "Do not capture num_reqs > max_num_reqs for uniform batch" + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2) + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + assert not create_mixed_batch + num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len @@ -2308,65 +3272,127 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) - attn_metadata: Optional[dict[str, Any]] = None + # Disable DP padding when running eager + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + + # We currently only microbatch if the number of tokens is + # over a certain threshold. 
+ ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=total_num_scheduled_tokens, + parallel_config=self.vllm_config.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=total_num_scheduled_tokens, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, + ) + num_tokens_after_padding = num_tokens + if num_tokens_across_dp is not None: + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) + + attn_metadata: Optional[PerLayerAttnMetadata] = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] - # Make sure max_model_len is used at the graph capture time. - self.seq_lens_np[:num_reqs] = self.max_model_len - self.seq_lens_np[num_reqs:] = 0 - self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True) + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + seq_lens = max_query_len + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + - 1], - seq_lens=self.seq_lens[:num_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], + seq_lens=self.seq_lens.gpu[:num_reqs], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, block_table_tensor=self.input_batch.block_table[ - kv_cache_group_id].get_device_tensor()[:num_reqs], - slot_mapping=self.input_batch. 
- block_table[kv_cache_group_id].slot_mapping[:num_tokens], - causal=True) - + kv_cache_group_id + ].get_device_tensor(num_reqs), + slot_mapping=self.input_batch.block_table[ + kv_cache_group_id + ].slot_mapping.gpu[:num_tokens], + causal=True, + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 + else None, + ) for attn_group in self.attn_groups[kv_cache_group_id]: - attn_metadata_i = attn_group.metadata_builder\ - .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata + ) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list + ): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) + for layer_name in attn_group.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert type(attn_metadata) is dict + metadata_builder = attn_group.get_metadata_builder() + attn_metadata_i = metadata_builder.build_for_cudagraph_capture( + common_attn_metadata + ) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens): - if self.supports_mm_inputs: + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): + # Make sure padding doesn't exceed max_num_tokens + assert num_tokens_after_padding <= self.max_num_tokens + model_kwargs = self._init_model_kwargs(num_tokens_after_padding) + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] model_kwargs = { - **self._init_model_kwargs(num_tokens), + **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding] + model_kwargs = self._init_model_kwargs(num_tokens_after_padding) else: - input_ids = self.input_ids[:num_tokens] + input_ids = self.input_ids.gpu[:num_tokens_after_padding] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] + positions = self.mrope_positions.gpu[:, :num_tokens_after_padding] else: - positions = self.positions[:num_tokens] + positions = self.positions.gpu[:num_tokens_after_padding] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -2376,30 +3402,58 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) - if cudagraph_runtime_mode == CUDAGraphMode.NONE: - batch_descriptor = None - else: - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) - # sanity check - assert cudagraph_runtime_mode == _cg_mode, ( - f"Cudagraph runtime mode mismatch at dummy_run. 
" - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + num_tokens_after_padding, None, False + ) - with self.maybe_randomize_inputs(input_ids), set_forward_context( + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + if not is_profile + else (CUDAGraphMode.NONE, None) + ) + if cudagraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for cudagraph capture + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( + f"Cudagraph runtime mode mismatch at dummy_run. " + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." + ) + else: + cudagraph_runtime_mode = _cg_mode + + if ubatch_slices is not None: + # Adjust values to reflect a single ubatch. + # TODO(sage,lucas): this is cruft that should be addressed in + # the padding refactor. + num_tokens_after_padding = ubatch_slices[0].num_tokens + if num_tokens_across_dp is not None: + num_tokens_across_dp[:] = num_tokens_after_padding + + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_tokens, + num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -2415,7 +3469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - self.drafter.dummy_run(num_tokens) + use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs) # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real @@ -2440,11 +3495,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # To avoid breaking the sampler, we use a random tensor here instead. hidden_states = torch.rand_like(hidden_states) - logits = self.model.compute_logits(hidden_states, None) + logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -2460,42 +3514,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), output_token_ids=[[] for _ in range(num_reqs)], + spec_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, logitsprocs=LogitsProcessors(), ) try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." 
+ ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) + target_logits = torch.randn( + num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype + ) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. - bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + bonus_token_ids = torch.zeros( + num_reqs, device=self.device, dtype=torch.int32 + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -2525,12 +3582,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) + dummy_pooling_params.verify(task=task, model_config=self.model_config) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -2540,19 +3598,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pooling_params=[dummy_pooling_params] * num_reqs, ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, device=hidden_states.device + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e @@ -2566,7 +3627,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for task in self.get_supported_pooling_tasks(): # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = output.get_data_nbytes() + output_size[task] = sum(o.nbytes for o in output) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0] @@ -2578,19 +3639,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None - # TODO: handle encoder-decoder models once we support them. if (encoder_budget := mm_budget.get_encoder_budget()) > 0: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. 
dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -2608,22 +3670,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, expected_num_items=max_mm_items_per_batch, ) + # NOTE: This happens when encoder cache needs to store + # the embeddings that encoder outputs are scattered onto. + # In this case we create dummy embeddings of size + # (encode_budget, hidden_size) and scatter encoder + # output into it. + encoder_output_shape = dummy_encoder_outputs[0].shape + if encoder_output_shape[0] < encoder_budget: + expanded_outputs = [] + for output in dummy_encoder_outputs: + expanded = output.new_zeros( + (encoder_budget, encoder_output_shape[-1]) + ) + num_tokens = output.shape[0] + expanded[:num_tokens].copy_(output) + expanded_outputs.append(expanded) + + dummy_encoder_outputs = expanded_outputs + # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2636,19 +3716,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.clear() gc.collect() - def capture_model(self) -> None: + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") - return + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) + return 0 else: self.initialize_cudagraph_capture() compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] @contextmanager def freeze_gc(): @@ -2664,13 +3744,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): finally: if should_freeze: gc.unfreeze() + gc.collect() # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
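The capture loop that follows honors the comment above by walking batch sizes from largest to smallest so later captures can reuse allocator blocks from the earlier ones. A generic sketch of that pattern using only the public PyTorch CUDA-graph API (plain torch.cuda.CUDAGraph plus a shared pool handle, not vLLM's CUDAGraphWrapper or dispatcher; the model, sizes, and dict bookkeeping are hypothetical):

import torch


# Capture CUDA graphs for several batch sizes, largest first, all sharing one
# memory pool. Requires a CUDA device; the __main__ guard skips it otherwise.
@torch.no_grad()
def capture_graphs(model: torch.nn.Module, batch_sizes: list[int], hidden: int):
    # Warm up on a side stream first, per the PyTorch CUDA-graphs recipe.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for bs in batch_sizes:
            for _ in range(3):
                model(torch.zeros(bs, hidden, device="cuda"))
    torch.cuda.current_stream().wait_stream(side)

    pool = torch.cuda.graph_pool_handle()
    graphs = {}
    for bs in sorted(batch_sizes, reverse=True):  # largest shapes first
        static_x = torch.zeros(bs, hidden, device="cuda")
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):  # shared memory pool
            static_y = model(static_x)
        graphs[bs] = (graph, static_x, static_y)
    return graphs


if __name__ == "__main__" and torch.cuda.is_available():
    net = torch.nn.Linear(16, 16).cuda()
    graphs = capture_graphs(net, [8, 4, 1], hidden=16)
    graph, static_x, static_y = graphs[4]
    static_x.copy_(torch.randn(4, 16, device="cuda"))
    graph.replay()  # recomputes static_y in place for the new input
    print(static_y)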
set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): + start_free_gpu_memory = torch.cuda.mem_get_info()[0] cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() @@ -2678,46 +3761,61 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) - # Capture full cudagraph for uniform decode batches if we have - # dont already have full mixed prefill-decode cudagraphs - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + # Capture full cudagraph for uniform decode batches if we + # don't already have full mixed prefill-decode cudagraphs. + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if max_num_tokens >= x >= self.uniform_decode_query_len ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) + + torch.cuda.synchronize() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because - # we may doing lazy capturing in future that still allows capturing + # we may do lazy capturing in future that still allows capturing # after here. set_cudagraph_capturing_enabled(False) end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) + return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode in [CUDAGraphMode.FULL, - CUDAGraphMode.PIECEWISE] + def _capture_cudagraphs( + self, + compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -2726,163 +3824,244 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) + # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: + # We currently only capture ubatched graphs when its a FULL + # cudagraph, a uniform decode batch, and the number of tokens + # is above the threshold. Otherwise we just capture a non-ubatched + # version of the graph + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode + and check_ubatch_thresholds( + config=self.vllm_config.parallel_config, + num_tokens=num_tokens, + uniform_decode=uniform_decode, + ) + ) + for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. # But be careful, warm up with `NONE`is orthogonal to # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - skip_eplb=True) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - skip_eplb=True) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + assert len(self.attn_groups) == 0, "Attention backends are already initialized" - def get_attn_backends_for_layers( - layer_names: list[str] - ) -> dict[type[AttentionBackend], list[str]]: + class AttentionGroupKey(NamedTuple): + attn_backend: type[AttentionBackend] + kv_cache_spec: KVCacheSpec + + def get_attn_backends_for_group( + kv_cache_group_spec: KVCacheGroupSpec, + ) -> dict[AttentionGroupKey, list[str]]: + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) - # Dedupe based on full class name; this is a bit safer than using + # Dedupe based on full class name; this is a bit safer than # using the class itself as the key because when we create dynamic # attention backend subclasses (e.g. ChunkedLocalAttention) unless # they are cached correctly, there will be different objects per # layer. - for layer_name in layer_names: - attn_backend = attn_layers[layer_name].get_attn_backend() - key = attn_backend.full_cls_name() - attn_backends[key] = attn_backend + for layer_name in kv_cache_group_spec.layer_names: + attn_backend = layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + + full_cls_name = attn_backend.full_cls_name() + layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] + key = (full_cls_name, layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return {attn_backends[k]: v for k, v in attn_backend_layers.items()} def create_attn_groups( - attn_backends_map: dict[AttentionBackend, list[str]], - kv_cache_spec: KVCacheSpec, + attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for attn_backend, layer_names in attn_backends_map.items(): - attn_metadata_builder_i = attn_backend.get_builder_cls()( - kv_cache_spec, + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): + attn_group = AttentionGroup.create_with_metadata_builders( + attn_backend, layer_names, + kv_cache_spec, self.vllm_config, self.device, + num_metadata_builders=1 + if not self.parallel_config.enable_dbo + else 2, ) - attn_group = AttentionGroup(attn_backend, - attn_metadata_builder_i, - layer_names) + attn_groups.append(attn_group) return attn_groups for kv_cache_group_spec in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backends = get_attn_backends_for_layers( - kv_cache_group_spec.layer_names) - # TODO(lucas): move `get_mamba_attn_backend` into the mamba - # layers like above - elif isinstance(kv_cache_spec, MambaSpec): - attn_backends = { - get_mamba_attn_backend(kv_cache_spec.mamba_type): - kv_cache_group_spec.layer_names - } - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") + attn_backends = get_attn_backends_for_group(kv_cache_group_spec) + 
self.attn_groups.append(create_attn_groups(attn_backends)) - self.attn_groups.append( - create_attn_groups(attn_backends, kv_cache_spec)) - - # Calculate reorder batch threshold (if neeeded) + # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() def initialize_cudagraph_capture(self) -> None: + """ + Resolve the cudagraph_mode when there are multiple attention + backends with potential conflicting CUDA graph support. + Then initialize the cudagraph_dispatcher based on the resolved + cudagraph_mode. + """ min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None for attn_group in self._attn_group_iterator(): - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if builder.cudagraph_support.value < min_cg_support.value: min_cg_support = builder.cudagraph_support min_cg_builder_name = builder.__class__.__name__ - # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. - msg += "; please try cudagraph_mode=PIECEWISE, and "\ + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " "make sure compilation level is piecewise" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) + logger.warning(msg) + + # check that if we are doing decode full-cudagraphs it is supported + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " + "attention is compiled piecewise" + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.PIECEWISE + ) + else: + msg += ( + "; setting cudagraph_mode=NONE because " + "attention is not compiled piecewise" + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < 
AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise" + ) # Trigger cudagraph dispatching keys initialization here (after # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ @@ -2890,26 +4069,104 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is compatible (e.g., decode threshold is the same) """ for group in self._attn_group_iterator(): - attn_metadata_builder_i = group.metadata_builder + attn_metadata_builder_i = group.get_metadata_builder() # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def _find_compatible_block_sizes( + self, + kv_manager_block_size: int, + backend_cls: type[AttentionBackend], + return_all: bool = False, + ) -> list[int]: + """ + Find compatible block sizes for a backend. 
+ + Args: + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + return_all: Return all compatible sizes if True, max size if False + + Returns: + Compatible block size(s) based on return_all parameter + + Raises: + ValueError: If no compatible block size found + """ + supported_block_size = backend_cls.get_supported_kernel_block_size() + compatible_sizes = [] + + for block_size in supported_block_size: + if isinstance(block_size, int): + if kv_manager_block_size % block_size == 0: + compatible_sizes.append(block_size) + elif ( + isinstance(block_size, MultipleOf) + and kv_manager_block_size % block_size.base == 0 + ): + compatible_sizes.append(kv_manager_block_size) + + if not compatible_sizes: + raise ValueError(f"No compatible block size for {kv_manager_block_size}") + + return compatible_sizes if return_all else [max(compatible_sizes)] + + def _select_common_block_size( + self, kv_manager_block_size: int, attn_groups: list[AttentionGroup] + ) -> int: + """ + Select common block size for all backends. + + Args: + kv_manager_block_size: Block size of KV cache + attn_groups: List of attention groups + + Returns: + Block size supported by all backends, + prioritizing cache_config.block_size + + Raises: + ValueError: If no common block size found + """ + all_backend_supports = [] + + for attn_group in attn_groups: + compatible_sizes = self._find_compatible_block_sizes( + kv_manager_block_size, attn_group.backend, return_all=True + ) + supported_sizes = sorted(list(set(compatible_sizes)), reverse=True) + all_backend_supports.append(set(supported_sizes)) + + common_supported_sizes = set.intersection(*all_backend_supports) + + if not common_supported_sizes: + error_msg = f"No common block size for {kv_manager_block_size}. " + for i, attn_group in enumerate(attn_groups): + supported = all_backend_supports[i] + error_msg += ( + f"Backend {attn_group.backend} supports: {sorted(supported)}. " + ) + raise ValueError(error_msg) + + if self.cache_config.block_size in common_supported_sizes: + return self.cache_config.block_size + + return max(common_supported_sizes) + + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -2921,27 +4178,43 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] - if block_sizes != [self.cache_config.block_size]: + + # Generate kernel_block_sizes that matches each block_size + kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + + if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ + self.cache_config.block_size + ]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." 
+ ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, + logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config + else 0 + ), ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -2951,12 +4224,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. - """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -2966,21 +4239,64 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) - def _kv_cache_spec_attn_group_iterator( - self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: if not self.kv_cache_config.kv_cache_groups: return - for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): - for attn_group in attn_groups: - yield self.kv_cache_config.kv_cache_groups[ - kv_cache_spec_id].kv_cache_spec, attn_group + for attn_groups in self.attn_groups: + yield from attn_groups + + def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]: + """ + Generate kernel_block_sizes that matches each block_size. + + For attention backends that support virtual block splitting, + use the supported block sizes from the backend. + For other backends (like Mamba), use the same block size (no splitting). + + Args: + kv_cache_config: The KV cache configuration. + + Returns: + list[int]: List of kernel block sizes for each cache group. + """ + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups + ): + kv_cache_spec = kv_cache_group.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + # All layers in the UniformTypeKVCacheSpecs have the same type, + # Pick an arbitrary one to dispatch. 
+ kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # all backends in the group. + attn_groups = self.attn_groups[kv_cache_group_id] + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size + selected_kernel_size = self._select_common_block_size( + kv_manager_block_size, attn_groups + ) + kernel_block_sizes.append(selected_kernel_size) + elif isinstance(kv_cache_spec, MambaSpec): + # This is likely Mamba or other non-attention cache, + # no splitting. + kernel_block_sizes.append(kv_cache_spec.block_size) + else: + raise NotImplementedError( + f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" + ) + return kernel_block_sizes def _reshape_kv_cache_tensors( self, @@ -2993,61 +4309,74 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: kv_cache_config: The KV cache config kv_cache_raw_tensors: The KV cache buffer of each layer, with - correct size but uninitialized shape. + correct size but uninitialized shape. Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False - for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True + kv_manager_block_size = kv_cache_spec.block_size + kernel_size_list = self._find_compatible_block_sizes( + kv_manager_block_size, attn_backend, return_all=False + ) + kernel_size = kernel_size_list[0] + num_blocks_per_kv_block = kv_manager_block_size // kernel_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + kernel_num_blocks, + kernel_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501 + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. 
inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -3066,43 +4395,40 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): raise NotImplementedError if has_attn and has_mamba: - self._verify_hybrid_attention_mamba_layout(kv_cache_config, - kv_cache_raw_tensors) + self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches - def _verify_hybrid_attention_mamba_layout( - self, kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ - Verify that the KV cache memory layout is compatible for - models with both attention and mamba KV cache groups. + Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...). Args: - kv_cache_config: The KV cache config - kv_cache_raw_tensors: The KV cache buffer of each layer. + kv_caches: The KV cache buffer of each layer. """ - for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: - raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, AttentionSpec): - - kv_cache_shape = group.backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - if kv_cache_shape[0] != num_blocks or kv_cache_shape[ - 1] != 2: - raise ValueError( - "Hybrid models in V1 require an attention " - "backend with kv_cache_shape=" - "(num_blocks, 2, ...). Please try setting " - "VLLM_ATTENTION_BACKEND=FLASHINFER") + kv_cache = kv_caches[layer_name] + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " + f"a tensor of shape {kv_cache.shape}" + ) + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. 
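To make the kernel block-size reconciliation above easier to follow, here is a minimal, self-contained sketch of the same idea: the KV-cache manager's block size must be evenly divisible by a kernel block size that every attention backend in a group supports, preferring the configured `cache_config.block_size` and otherwise falling back to the largest common size. The `MultipleOf` marker and helper names below mirror the spirit of the patch but are illustrative stand-ins, not the exact vLLM APIs.

```python
# Sketch (assumed names, not the vLLM API) of picking a kernel block size
# compatible with a KV-cache-manager block size across several backends.
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class MultipleOf:
    """Stand-in marker meaning 'any multiple of `base` is supported'."""
    base: int


def compatible_sizes(kv_manager_block_size: int,
                     supported: list[Union[int, MultipleOf]]) -> set[int]:
    """Kernel block sizes that evenly divide the manager's block size."""
    sizes: set[int] = set()
    for entry in supported:
        if isinstance(entry, int):
            if kv_manager_block_size % entry == 0:
                sizes.add(entry)
        elif kv_manager_block_size % entry.base == 0:
            # Any multiple of `base` works, so the manager size itself does too.
            sizes.add(kv_manager_block_size)
    return sizes


def select_common_size(kv_manager_block_size: int,
                       per_backend_support: list[list[Union[int, MultipleOf]]],
                       preferred: int) -> int:
    """Pick one kernel block size usable by every backend in the group."""
    common = set.intersection(
        *(compatible_sizes(kv_manager_block_size, s) for s in per_backend_support)
    )
    if not common:
        raise ValueError(
            f"No common kernel block size for {kv_manager_block_size}")
    return preferred if preferred in common else max(common)


# A 128-token manager block shared by two int-only backends: the preferred
# size is not jointly supported, so the largest common size (64) wins.
assert select_common_size(128, [[16, 32, 64], [16, 64]], preferred=128) == 64
# A MultipleOf(16) backend accepts the manager size itself, so 128 is kept.
assert select_common_size(128, [[64, 128], [MultipleOf(16)]], preferred=128) == 128
```

Each physical manager block is then viewed as `kv_manager_block_size // kernel_size` kernel blocks when the cache tensor is reshaped, which is why the selected size must divide the manager block size exactly.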
@@ -3115,35 +4441,54 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - self.attn_groups, - self.runner_only_attn_layers, - ) - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) - # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) + return kv_caches + + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - return kv_caches - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
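For readers skimming the cross-layer KV-sharing path above, the following self-contained sketch shows the effect of the mapping step in `initialize_kv_cache_tensors`: layers that declare a sharing target never receive their own buffer, and after allocation their entry in the cache dict simply aliases the target layer's tensor. Layer names and tensor shapes below are hypothetical.

```python
import torch

# Hypothetical layer names: layers 2 and 3 reuse the KV caches of layers 0 and 1.
shared_kv_cache_layers = {
    "model.layers.2.attn": "model.layers.0.attn",
    "model.layers.3.attn": "model.layers.1.attn",
}

# Buffers are allocated only for layers that own a cache
# (example shape: (num_blocks, K/V, block_size, num_kv_heads, head_dim)).
kv_caches = {
    "model.layers.0.attn": torch.zeros(4, 2, 16, 8, 64),
    "model.layers.1.attn": torch.zeros(4, 2, 16, 8, 64),
}

# The mapping step: the reusing layer's entry aliases the target's tensor.
for layer_name, target_layer_name in shared_kv_cache_layers.items():
    kv_caches[layer_name] = kv_caches[target_layer_name]

# Both names now resolve to the same storage, so no extra memory is spent and
# a write through one name is visible through the other.
kv_caches["model.layers.0.attn"].fill_(1.0)
assert (kv_caches["model.layers.2.attn"].data_ptr()
        == kv_caches["model.layers.0.attn"].data_ptr())
assert bool(torch.all(kv_caches["model.layers.2.attn"] == 1.0))
```

This is what lets YOCO-style models fit longer contexts in the same memory budget: only the owning layers contribute to the KV-cache size, while the sharing layers read from the aliased buffers at runtime.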
@@ -3153,9 +4498,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): @@ -3165,34 +4512,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.drafter.validate_same_kv_cache_group(kv_cache_config) if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) + kv_transfer_group = get_kv_transfer_group() + kv_transfer_group.register_kv_caches(kv_caches) + kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + + if self.dcp_world_size > 1: + layer_names = self.attn_groups[0][0].layer_names + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) + for layer in layers.values(): + assert layer.impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer.impl.__class__.__name__} " + "does not return the softmax lse for decode." + ) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. """ block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_spec = EncoderOnlyAttentionSpec( + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -3205,80 +4566,121 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla + cache_dtype_str = self.vllm_config.cache_config.cache_dtype kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache 
and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue + if isinstance(attn_module, Attention): + if ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue - # TODO: Support other attention modules, e.g., cross-attention - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( + # TODO(lucas): move the attention specs into the model layers like + # the attention backends + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + assert not use_mla, "MLA is not supported for slidingwindow" + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): + kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + attention_chunk_size=self.attention_chunk_size, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): - kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla) + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): + # encoder-only attention does not need KV cache. 
+ continue else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if self.vllm_config.speculative_config is not None: - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len + elif isinstance(attn_module, MLAAttention): + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str, + ) - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) - - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. - for layer_name, mamba_module in mamba_layers.items(): + elif isinstance(attn_module, MambaBase): + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet." + ) + mamba_block_size = self.vllm_config.cache_config.mamba_block_size + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, + shapes=attn_module.get_state_shape(), + dtypes=attn_module.get_state_dtype(), + block_size=mamba_block_size, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type) + mamba_type=attn_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ), + ) + + ds_indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ) + for layer_name, ds_indexer_module in ds_indexer_layers.items(): + kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() return kv_cache_spec + + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. 
+ pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py new file mode 100644 index 0000000000000..fb63fe8d25430 --- /dev/null +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -0,0 +1,465 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import threading +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import torch + +import vllm.envs as envs +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed import get_ep_group +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id +from vllm.forward_context import ( + DPMetadata, + create_forward_context, + get_forward_context, + override_forward_context, +) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils import has_deep_gemm +from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts + +logger = init_logger(__name__) + + +@dataclass +class UbatchMetadata: + context: UBatchContext + input_ids: torch.Tensor + positions: torch.Tensor + inputs_embeds: Optional[torch.Tensor] + intermediate_tensors: Optional[IntermediateTensors] + num_tokens: int + + +@dataclass +class CUDAGraphMetaData: + cudagraph: torch.cuda.CUDAGraph + ubatch_metadata: UbatchMetadata + outputs: Optional[Any] = None + + +class SMControlContextManager: + def __init__( + self, + comm_sms: int, + set_comm_sms: Callable[[int], None], + set_compute_sms: Callable[[int], None], + ): + """ + Context manager for controlling SM (Streaming Multiprocessor) + allocation. Upon entering the context, it sets the number of SMs + allocated for communication and computation to comm_sms and + total_sms - comm_sms respectively. Upon exiting, it restores the + allocation to use all available SMs (i.e. total_sms). + + Args: + comm_sms (int): The number of SMs to allocate for communication. + (The remainder will be used for computation.) + set_comm_sms (Callable[[int], None]): + A function that sets the number of SMs for communication. + set_compute_sms (Callable[[int], None]): + A function that sets the number of SMs for computation. 
+ """ + + assert current_platform.is_cuda(), ( + "SM control is currently only supported on CUDA" + ) + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + total_sms = props.multi_processor_count + + assert comm_sms < total_sms + self.total_sms = total_sms + self.compute_sms = total_sms - comm_sms + self.comm_sms = comm_sms + self.set_comm_sms = set_comm_sms + self.set_compute_sms = set_compute_sms + + def __enter__(self): + self.set_comm_sms(self.comm_sms) + self.set_compute_sms(self.compute_sms) + + def __exit__(self, exc_type, exc_value, traceback): + self.set_comm_sms(self.total_sms) + self.set_compute_sms(self.total_sms) + + +class UBatchWrapper: + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + device: torch.cuda.device, + ): + self.runnable = runnable + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.comm_stream = torch.cuda.Stream(device=device) + # Two ubatch threads plus the main thread + self.ready_barrier = threading.Barrier(3) + + self.cudagraphs: dict[int, CUDAGraphMetaData] = {} + + self.cudagraph_wrapper = None + self.graph_pool = None + if runtime_mode is not CUDAGraphMode.NONE: + self.cudagraph_wrapper = CUDAGraphWrapper( + runnable, vllm_config, runtime_mode=runtime_mode + ) + self.graph_pool = current_platform.get_global_graph_pool() + + self.sm_control = self._create_sm_control_context(vllm_config) + self.device = device + + @staticmethod + def _create_sm_control_context(vllm_config: VllmConfig): + comm_sms = envs.VLLM_DBO_COMM_SMS + + set_comm_sms = lambda sms: None + if vllm_config.parallel_config.enable_expert_parallel: + # Currently only DeepEP highthroughput supports SM control so this + # only affects that case. + all2all_manager = get_ep_group().device_communicator.all2all_manager + + if all2all_manager.max_sms_used() is not None: + comm_sms = min(comm_sms, all2all_manager.max_sms_used()) + + if comm_sms > 0: + set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms) + + # TODO(lucas): support other kernels besides DeepGEMM + set_compute_sms = lambda sms: None + if has_deep_gemm() and comm_sms > 0: + import deep_gemm as dg + + set_compute_sms = lambda sms: dg.set_num_sms(sms) + + return SMControlContextManager( + comm_sms=comm_sms, + set_comm_sms=set_comm_sms, + set_compute_sms=set_compute_sms, + ) + + def __getattr__(self, key: str): + # allow accessing the attributes of the runnable. + if hasattr(self.runnable, key): + return getattr(self.runnable, key) + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}" + ) + + def unwrap(self) -> Callable: + # in case we need to access the original runnable. + return self.runnable + + def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor: + """ + Capture a cudagraph for a microbatched run. + + The logic here is somewhat complicated because we need to make sure that + each of the ubatch threads initialize the cuda context before we start + the graph capture. + + The flow is as follows: + 1. The main thread starts up each ubatch thread. Each thread will + initialize its cuda context (torch.cuda.current_blas_handle()) + before going to sleep upon entering the ubatch_context. + + 2. The main thread starts the graph capture and wakes up the first + ubatch thread. + + 3. Each ubatch thread runs the model to completion and returns the + completed output tensors back to the main thread. + + 4. 
The main thread stores the captured cudagraph along with its metadata + and returns + """ + + @torch.inference_mode() + def _capture_ubatch_thread(results, ubatch_metadata): + torch.cuda.set_device(self.device) + ubatch_context = ubatch_metadata.context + with torch.cuda.stream(ubatch_context.compute_stream): + _ = torch.cuda.current_blas_handle() + with torch.cuda.stream(ubatch_context.comm_stream): + _ = torch.cuda.current_blas_handle() + with ubatch_context: + model_output = model( + input_ids=ubatch_metadata.input_ids, + positions=ubatch_metadata.positions, + intermediate_tensors=ubatch_metadata.intermediate_tensors, + inputs_embeds=ubatch_metadata.inputs_embeds, + ) + + results.append((ubatch_metadata.context.id, model_output)) + + results: list[tuple[int, torch.Tensor]] = [] + compute_stream = ubatch_metadata[0].context.compute_stream + num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens + + # Ubatches will manually manage the forward context, so we override + # it to None here so we can have it restored correctly later + with override_forward_context(None): + ubatch_threads = [] + for metadata in ubatch_metadata: + thread = threading.Thread( + target=_capture_ubatch_thread, + args=( + results, + metadata, + ), + ) + ubatch_threads.append(thread) + thread.start() + self.ready_barrier.wait() # Wait for both threads to be ready + + # Capture the cudagraph + cudagraph_metadata = CUDAGraphMetaData( + cudagraph=torch.cuda.CUDAGraph(), + ubatch_metadata=ubatch_metadata, + ) + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) + with torch.cuda.graph( + cudagraph_metadata.cudagraph, + stream=compute_stream, + pool=self.graph_pool, + ): + ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + cudagraph_metadata.outputs = result + self.cudagraphs[num_tokens] = cudagraph_metadata + return cudagraph_metadata.outputs + + def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor: + @torch.inference_mode() + def _ubatch_thread(results, model, ubatch_metadata): + with ubatch_metadata.context: + model_output = model( + input_ids=ubatch_metadata.input_ids, + positions=ubatch_metadata.positions, + intermediate_tensors=ubatch_metadata.intermediate_tensors, + inputs_embeds=ubatch_metadata.inputs_embeds, + ) + results.append((ubatch_metadata.context.id, model_output)) + + results: list[tuple[int, torch.Tensor]] = [] + + # Ubatch threads will manually manage the forward context, so we + # override it to None here so we can have it restored correctly + # after both threads have finished + with override_forward_context(None): + ubatch_threads = [] + for metadata in ubatch_metadata: + thread = threading.Thread( + target=_ubatch_thread, + args=( + results, + model, + metadata, + ), + ) + ubatch_threads.append(thread) + thread.start() + self.ready_barrier.wait() # Wait for both threads to be ready + ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + return result + + def _make_ubatch_metadata( + self, + ubatch_slices, + attn_metadata, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + compute_stream, + dp_metadata, + batch_descriptor, + cudagraph_runtime_mode, + ) -> 
list[UbatchMetadata]: + # Create one forward context per ubatch + forward_contexts = [] + for i, ubatch_slice in enumerate(ubatch_slices): + forward_contexts.append( + create_forward_context( + attn_metadata[i] if attn_metadata is not None else None, + self.vllm_config, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=cudagraph_runtime_mode, + ) + ) + + ubatch_ctxs = make_ubatch_contexts( + num_micro_batches=len(ubatch_slices), + comm_stream=self.comm_stream, + compute_stream=compute_stream, + forward_contexts=forward_contexts, + ready_barrier=self.ready_barrier, + ) + + ubatch_metadata: list[UbatchMetadata] = [] + for i, ubatch_slice in enumerate(ubatch_slices): + ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) = self._slice_model_inputs( + ubatch_slice.token_slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ) + ubatch_metadata.append( + UbatchMetadata( + context=ubatch_ctxs[i], + input_ids=sliced_input_ids, + positions=sliced_positions, + inputs_embeds=sliced_inputs_embeds, + intermediate_tensors=sliced_intermediate_tensors, + num_tokens=ubatch_slice.token_slice.stop + - ubatch_slice.token_slice.start, + ) + ) + + return ubatch_metadata + + def _slice_model_inputs( + self, + tokens_slice: slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ): + sliced_input_ids = input_ids[tokens_slice] + # if we are using mrope. Mrope adds an additional dimension to the + # positions tensor + if positions.ndim == 2: + sliced_positions = positions[:, tokens_slice] + else: + sliced_positions = positions[tokens_slice] + sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None + sliced_intermediate_tensors = ( + intermediate_tensors[tokens_slice] if intermediate_tensors else None + ) + + return ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + ubatch_slices = forward_context.ubatch_slices + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + # If there's no ubatching, just run the runnable object + if ubatch_slices is None: + # This is to account for the case where ubatching was aborted. + # When we capture full graphs we only capture one graph per shape, + # meaning that if we have a ubatched cudagraph for the current + # num_tokens, we don't have a non-ubatched one. Without this + # check, the cudagraph wrapper will try to capture a cudagraph + # for this shape during a normal run. 
+ if cudagraph_runtime_mode is CUDAGraphMode.FULL: + assert batch_descriptor is not None + if batch_descriptor.num_tokens in self.cudagraphs: + cudagraph_runtime_mode = CUDAGraphMode.NONE + + if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE): + return self.runnable(*args, **kwargs) + else: + assert self.cudagraph_wrapper is not None + return self.cudagraph_wrapper(*args, **kwargs) + + attn_metadata = forward_context.attn_metadata + num_tokens = ( + ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start + ) * 2 + input_ids = kwargs["input_ids"] + positions = kwargs["positions"] + intermediate_tensors = kwargs["intermediate_tensors"] + inputs_embeds = kwargs["inputs_embeds"] + compute_stream = torch.cuda.current_stream() + + dp_metadata = forward_context.dp_metadata + + # We shouldn't be here unless we are running with multiple DP ranks + assert dp_metadata is not None + num_tokens_per_ubatch = ( + ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start + ) + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata = DPMetadata.make( + self.vllm_config.parallel_config, + num_tokens_per_ubatch, + ubatch_num_tokens_across_dp, + ) + + if ( + num_tokens not in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): + ubatch_metadata = self._make_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + dp_metadata=ubatch_dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) + with self.sm_control: + return self._capture_ubatches(ubatch_metadata, self.model) + elif ( + num_tokens in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): + cudagraph_metadata = self.cudagraphs[num_tokens] + cudagraph_metadata.cudagraph.replay() + return cudagraph_metadata.outputs + else: + ubatch_metadata = self._make_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) + with self.sm_control: + return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f83a4f4faeb5e..4f4da73fba6e6 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" + import copy import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed @@ -13,9 +14,11 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce, +) from 
vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -28,10 +31,15 @@ from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + ModelRunnerOutput, +) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -42,7 +50,6 @@ if TYPE_CHECKING: class Worker(WorkerBase): - def __init__( self, vllm_config: VllmConfig, @@ -51,16 +58,18 @@ class Worker(WorkerBase): distributed_init_method: str, is_driver_worker: bool = False, ): - - super().__init__(vllm_config=vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker) + super().__init__( + vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Buffers saved before sleep @@ -70,8 +79,10 @@ class Worker(WorkerBase): # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -90,7 +101,9 @@ class Worker(WorkerBase): with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -103,20 +116,20 @@ class Worker(WorkerBase): if level == 2: model = self.model_runner.model self._sleep_saved_buffers = { - name: buffer.cpu().clone() - for name, buffer in model.named_buffers() + name: buffer.cpu().clone() for name, buffer in model.named_buffers() } allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
logger.info( - "Sleep mode freed %.2f GiB memory, " - "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, - used_bytes / GiB_bytes) + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) def wake_up(self, tags: Optional[list[str]] = None) -> None: from vllm.device_allocator.cumem import CuMemAllocator @@ -132,49 +145,58 @@ class Worker(WorkerBase): buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} - def _maybe_get_memory_pool_context(self, - tag: str) -> AbstractContextManager: + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() if tag == "weights": assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") + "Sleep mode can only be used for one instance per process." + ) context = allocator.use_memory_pool(tag=tag) else: context = nullcontext() return context - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks def init_device(self): if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) - _check_if_gpu_supports_dtype(self.model_config.dtype) + current_platform.check_if_supports_dtype(self.model_config.dtype) + + # Initialize the distributed environment BEFORE taking + # memory snapshot + # This ensures NCCL buffers are allocated before we measure + # available memory + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) + + # Set random seed. + set_random_seed(self.model_config.seed) + + # Now take memory snapshot after NCCL is initialized gc.collect() torch.cuda.empty_cache() # take current memory snapshot self.init_snapshot = MemorySnapshot() - self.requested_memory = (self.init_snapshot.total_memory * - self.cache_config.gpu_memory_utilization) + self.requested_memory = ( + self.init_snapshot.total_memory + * self.cache_config.gpu_memory_utilization + ) if self.init_snapshot.free_memory < self.requested_memory: GiB = lambda b: round(b / GiB_bytes, 2) raise ValueError( @@ -187,19 +209,12 @@ class Worker(WorkerBase): f"utilization or reduce GPU memory used by other processes." ) else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) - # Set random seed. 
- set_random_seed(self.model_config.seed) + raise RuntimeError(f"Not support device type: {self.device_config.device}") # Construct the model runner self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device) + self.vllm_config, self.device + ) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -216,8 +231,7 @@ class Worker(WorkerBase): self.model_runner.update_config(overrides) def reload_weights(self) -> None: - with self._maybe_get_memory_pool_context(tag="weights"): - self.model_runner.reload_weights() + self.model_runner.reload_weights() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -225,25 +239,48 @@ class Worker(WorkerBase): memory can be used for KV cache without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the free memory that can be used for KV cache in + Then, it calculates the free memory that can be used for KV cache in bytes. Tip: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ + GiB = lambda b: b / GiB_bytes + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # still need a profile run which compiles the model for + # max_num_batched_tokens + self.model_runner.profile_run() + + msg = ( + f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes config and " + "skipped memory profiling. This does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference of initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes is suggested and update it " + "correspondingly." + ) + logger.info(msg) + return kv_cache_memory_bytes + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - GiB = lambda b: b / GiB_bytes # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( - self.init_snapshot, - weights_memory=int( - self.model_runner.model_memory_usage)) as profile_result: + self.init_snapshot, + weights_memory=int(self.model_runner.model_memory_usage), + ) as profile_result: self.model_runner.profile_run() + self.non_torch_memory = profile_result.non_torch_increase + self.peak_activation_memory = profile_result.torch_peak_increase + free_gpu_memory = profile_result.after_profile.free_memory # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. @@ -254,15 +291,15 @@ class Worker(WorkerBase): "This happens when other processes sharing the same container " "release GPU memory while vLLM is profiling during initialization. " "To fix this, ensure consistent GPU memory allocation or " - "isolate vLLM in its own container.") - available_kv_cache_memory = self.requested_memory \ - - profile_result.non_kv_cache_memory + "isolate vLLM in its own container." 
+ ) + self.available_kv_cache_memory_bytes = ( + self.requested_memory - profile_result.non_kv_cache_memory + ) - unrequested_memory = self.init_snapshot.free_memory \ - - self.requested_memory + unrequested_memory = self.init_snapshot.free_memory - self.requested_memory logger.debug( - "Initial free memory: %.2f GiB; " - "Requested memory: %.2f (util), %.2f GiB", + "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", GiB(self.init_snapshot.free_memory), self.cache_config.gpu_memory_utilization, GiB(self.requested_memory), @@ -274,11 +311,13 @@ class Worker(WorkerBase): GiB(free_gpu_memory - unrequested_memory), ) logger.debug(profile_result) - logger.info("Available KV cache memory: %.2f GiB", - GiB(available_kv_cache_memory)) + logger.info( + "Available KV cache memory: %.2f GiB", + GiB(self.available_kv_cache_memory_bytes), + ) gc.collect() - return int(available_kv_cache_memory) + return int(self.available_kv_cache_memory_bytes) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() @@ -303,16 +342,80 @@ class Worker(WorkerBase): warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes + x + for x in warmup_sizes + if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, skip_eplb=True) + self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) + self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) + # Warmup and tune the kernels used during model execution before + # cuda graph capture. + kernel_warmup(self) + + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: - self.model_runner.capture_model() + cuda_graph_memory_bytes = self.model_runner.capture_model() + + if self.cache_config.kv_cache_memory_bytes is None and hasattr( + self, "peak_activation_memory" + ): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. + # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + non_kv_cache_memory = ( + self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes + ) + kv_cache_memory_bytes_to_gpu_limit = ( + self.init_snapshot.free_memory + - non_kv_cache_memory + - redundancy_buffer_memory + ) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) + - non_kv_cache_memory + - redundancy_buffer_memory + ) + + msg = ( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). 
" + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weight, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` " + f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit " + f"into requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` " + f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " + f"utilize gpu memory. Current kv cache memory in use is " + f"{GiB(self.available_kv_cache_memory_bytes)} GiB." + ) + + logger.debug(msg) # Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory @@ -320,28 +423,28 @@ class Worker(WorkerBase): # NOTE: This is called after `capture_model` on purpose to prevent # memory buffers from being cleared by `torch.cuda.empty_cache`. if get_pp_group().is_last_rank: - max_num_reqs = min(self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens) + max_num_reqs = min( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + ) # We skip EPLB here since we don't want to record dummy metrics - hidden_states, last_hidden_states = \ - self.model_runner._dummy_run( - num_tokens=max_num_reqs, - skip_eplb=True, - ) + hidden_states, last_hidden_states = self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) else: - self.model_runner._dummy_sampler_run( - hidden_states=last_hidden_states) - - # Warmup kernels used during model execution - kernel_warmup(self) + self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -352,38 +455,52 @@ class Worker(WorkerBase): def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None - if not get_pp_group().is_first_rank: + forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) + } + if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + ) - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) - - parallel_config = self.vllm_config.parallel_config - if parallel_config.distributed_executor_backend != "external_launcher" \ - and not get_pp_group().is_last_rank: - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output + output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output - assert isinstance(output, ModelRunnerOutput) + assert isinstance(output, IntermediateTensors) + parallel_config = self.vllm_config.parallel_config + assert ( + parallel_config.distributed_executor_backend != ("external_launcher") + and not get_pp_group().is_last_rank + ) + + get_pp_group().send_tensor_dict( + output.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + + kv_connector_output = output.kv_connector_output + if not kv_connector_output: + return None + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output return output def take_draft_token_ids(self) -> Optional[DraftTokenIds]: @@ -396,11 +513,14 @@ class Worker(WorkerBase): self.profiler.start() else: self.profiler.stop() - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) + # only print profiler results on rank 0 + if self.local_rank == 0: + print( + self.profiler.key_averages().table(sort_by="self_cuda_time_total") + ) def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1) + self.model_runner._dummy_run(1, uniform_decode=True) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) @@ -418,68 +538,79 @@ class Worker(WorkerBase): # worker will always be healthy as long as it's running. 
return - def _eplb_before_scale_down(self, old_ep_size: int, - new_ep_size: int) -> None: + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "before scaling down...") + logger.info( + "[Elastic EP] Starting expert resharding before scaling down..." + ) rank_mapping = { old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 for old_ep_rank in range(old_ep_size) } assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange(self.model_runner.model, - execute_shuffle=True, - global_expert_load=None, - rank_mapping=rank_mapping) - torch.cuda.synchronize() - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _eplb_after_scale_up( - self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor]) -> None: - from vllm.distributed.parallel_state import get_ep_group - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "after scaling up...") - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping, + ) + torch.cuda.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + def _eplb_after_scale_up( + self, + old_ep_size: int, + new_ep_size: int, + global_expert_load: Optional[torch.Tensor], + ) -> None: + from vllm.distributed.parallel_state import get_ep_group + + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding after scaling up...") + rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} + assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=True, global_expert_load=global_expert_load, - rank_mapping=rank_mapping) + rank_mapping=rank_mapping, + ) if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") def _reconfigure_parallel_config( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: """ Update parallel config with provided reconfig_request """ parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - if reconfig_request.new_data_parallel_rank_local != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank_local = \ + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( reconfig_request.new_data_parallel_rank_local - parallel_config.data_parallel_master_ip = \ + ) + parallel_config.data_parallel_master_ip = ( 
reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) - def _reconfigure_moe(self, old_ep_size: int, - new_ep_size: int) -> Optional[torch.Tensor]: + def _reconfigure_moe( + self, old_ep_size: int, new_ep_size: int + ) -> Optional[torch.Tensor]: """ Reconfigure MoE modules with provided reconfig_request @@ -487,19 +618,26 @@ class Worker(WorkerBase): otherwise None """ from vllm.distributed.parallel_state import ( - get_dp_group, get_ep_group, prepare_communication_buffer_for_model) - from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoEParallelConfig) + get_dp_group, + get_ep_group, + prepare_communication_buffer_for_model, + ) + from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig parallel_config = self.vllm_config.parallel_config moe_modules = [ - module for module in self.model_runner.model.modules() - if module.__class__.__name__ == "FusedMoE" + module + for module in self.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) ] num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all(module.moe_config.num_local_experts == num_local_experts - for module in moe_modules), ( - "All MoE modules must have the same number of experts") + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" for module in moe_modules: module.moe_config.num_experts = num_local_experts * new_ep_size module.global_num_experts = module.moe_config.num_experts @@ -512,49 +650,62 @@ class Worker(WorkerBase): if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None - new_physical_experts = \ + new_physical_experts = ( self.model_runner.eplb_state.physical_to_logical_map.shape[1] + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1]) + new_physical_experts + - self.model_runner.eplb_state.logical_replica_count.shape[1] + ) global_expert_load = None else: - num_local_physical_experts = torch.tensor([num_local_experts], - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + num_local_physical_experts = torch.tensor( + [num_local_experts], dtype=torch.int32, device="cpu" + ) + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = num_local_physical_experts.item() new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False) + self.model_runner.model, execute_shuffle=False + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1]) + new_physical_experts - global_expert_load.shape[1] + ) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( num_physical_experts=new_physical_experts, - num_local_physical_experts=num_local_physical_experts) + num_local_physical_experts=num_local_physical_experts, + ) return global_expert_load def 
reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: from vllm.config import set_current_vllm_config from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_ep_group) + cleanup_dist_env_and_memory, + get_ep_group, + ) old_ep_size = get_ep_group().world_size old_ep_rank = get_ep_group().rank - new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( - ).world_size * get_pp_group().world_size + new_ep_size = ( + reconfig_request.new_data_parallel_size + * get_tp_group().world_size + * get_pp_group().world_size + ) if new_ep_size < old_ep_size: self._eplb_before_scale_down(old_ep_size, new_ep_size) cleanup_dist_env_and_memory() - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): assert old_ep_rank >= new_ep_size # shutdown return @@ -562,16 +713,18 @@ class Worker(WorkerBase): self._reconfigure_parallel_config(reconfig_request) with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + ) global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) if new_ep_size > old_ep_size: assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, - global_expert_load) + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) def save_sharded_state( self, @@ -580,6 +733,7 @@ class Worker(WorkerBase): max_size: Optional[int] = None, ) -> None: from vllm.model_executor.model_loader import ShardedStateLoader + ShardedStateLoader.save_model( self.model_runner.model, path, @@ -592,7 +746,12 @@ class Worker(WorkerBase): tensorizer_config: "TensorizerConfig", ) -> None: self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) + + def shutdown(self) -> None: + if runner := getattr(self, "model_runner", None): + runner.ensure_kv_transfer_shutdown() def init_worker_distributed_environment( @@ -606,30 +765,14 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, backend) + init_distributed_environment( + parallel_config.world_size, rank, distributed_init_method, local_rank, backend + ) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size, + ) ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. 
- if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a03ebe35d8e0a..473982bebb127 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -3,19 +3,30 @@ """ Define KV connector functionality mixin for model runners. """ + import copy +from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Generator # noqa: UP035 -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, # noqa: UP035 + Optional, +) from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, +) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, - ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -25,7 +36,6 @@ logger = init_logger(__name__) # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) class KVConnectorModelRunnerMixin: - @staticmethod def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): # Update KVConnector with the KVConnector metadata forward(). @@ -33,8 +43,7 @@ class KVConnectorModelRunnerMixin: kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. # These transfers are designed to be async and the requests @@ -42,6 +51,12 @@ class KVConnectorModelRunnerMixin: # Do this here to save a collective_rpc. kv_connector.start_load_kv(get_forward_context()) + @staticmethod + def ensure_kv_transfer_shutdown() -> None: + # has_kv_transfer_group can be None during interpreter shutdown. 
+ if has_kv_transfer_group and has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): @@ -53,21 +68,24 @@ class KVConnectorModelRunnerMixin: ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): return get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids) + scheduler_output.finished_req_ids + ) return None, None @staticmethod - def kv_connector_no_forward(scheduler_output: "SchedulerOutput", - vllm_config: VllmConfig) -> ModelRunnerOutput: + def kv_connector_no_forward( + scheduler_output: "SchedulerOutput", vllm_config: VllmConfig + ) -> ModelRunnerOutput: # KV send/recv even if no work to do. - with set_forward_context( - None, vllm_config - ), KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output, wait_for_save=False) as kv_connector_output: + with ( + set_forward_context(None, vllm_config), + KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, wait_for_save=False + ) as kv_connector_output, + ): pass - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): + if kv_connector_output.is_empty(): return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) @@ -76,18 +94,20 @@ class KVConnectorModelRunnerMixin: @staticmethod def maybe_get_kv_connector_output( - scheduler_output: "SchedulerOutput" + scheduler_output: "SchedulerOutput", ) -> AbstractContextManager[Optional[KVConnectorOutput]]: - return KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output) if has_kv_transfer_group() else nullcontext() + return ( + KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) + if has_kv_transfer_group() + else nullcontext() + ) # This context manager must be used within an active forward context. - # It encapsulates the entire KV conector lifecycle within execute_model + # It encapsulates the entire KV connector lifecycle within execute_model @staticmethod @contextmanager def _get_kv_connector_output( - scheduler_output: "SchedulerOutput", - wait_for_save: bool = True + scheduler_output: "SchedulerOutput", wait_for_save: bool = True ) -> Generator[KVConnectorOutput, None, None]: output = KVConnectorOutput() @@ -95,8 +115,7 @@ class KVConnectorModelRunnerMixin: kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. 
# These transfers are designed to be async and the requests @@ -110,6 +129,17 @@ class KVConnectorModelRunnerMixin: kv_connector.wait_for_save() output.finished_sending, output.finished_recving = ( - kv_connector.get_finished(scheduler_output.finished_req_ids)) + kv_connector.get_finished(scheduler_output.finished_req_ids) + ) + output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() + output.kv_connector_stats = ( + KVConnectorModelRunnerMixin.get_kv_connector_stats() + ) kv_connector.clear_connector_metadata() + + @staticmethod + def get_kv_connector_stats() -> Optional[KVConnectorStats]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_kv_connector_stats() + return None diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 84ed46989ea97..45b7a548d1843 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -5,13 +5,14 @@ Define LoRA functionality mixin for model runners. """ from contextlib import contextmanager -from typing import Union +from typing import Optional, Union import numpy as np import torch import torch.nn as nn -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.config import VllmConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -27,67 +28,65 @@ logger = init_logger(__name__) # Defined as a mixin for GPUModelRunner class LoRAModelRunnerMixin: - - LORA_WARMUP_RANK = 8 - - def load_lora_model(self, model: nn.Module, model_config: ModelConfig, - scheduler_config: SchedulerConfig, - lora_config: LoRAConfig, - device: torch.device) -> nn.Module: - + def load_lora_model( + self, model: nn.Module, vllm_config: VllmConfig, device: torch.device + ) -> nn.Module: if not supports_lora(model): - raise ValueError( - f"{model.__class__.__name__} does not support LoRA yet.") + raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.") if supports_multimodal(model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") - - # Use get_text_config() in case of multimodal models - text_config = model_config.hf_config.get_text_config() + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model." + ) # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( - scheduler_config.max_num_seqs, - scheduler_config.max_num_batched_tokens, - model_config.get_vocab_size(), - lora_config, + vllm_config, device, model.embedding_modules, model.embedding_padding_modules, - max_position_embeddings=text_config.max_position_embeddings, ) return self.lora_manager.create_lora_manager(model) - def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], - token_lora_mapping: tuple[int, ...], - lora_requests: set[LoRARequest]) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + def _set_active_loras( + self, + prompt_lora_mapping: tuple[int, ...], + token_lora_mapping: tuple[int, ...], + lora_requests: set[LoRARequest], + ) -> None: + self._ensure_lora_enabled() # Set is_prefill to True, so we always use the SGMV kernels on # non-cuda platforms. # On cuda platforms we use the same kernels for prefill and # decode and this flag is generally ignored. 
- lora_mapping = LoRAMapping(token_lora_mapping, - prompt_lora_mapping, - is_prefill=True) + lora_mapping = LoRAMapping( + token_lora_mapping, prompt_lora_mapping, is_prefill=True + ) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - def set_active_loras(self, input_batch: InputBatch, - num_scheduled_tokens: np.ndarray) -> None: + def _ensure_lora_enabled(self) -> None: + if not hasattr(self, "lora_manager"): + raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") + def set_active_loras( + self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray + ) -> None: prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs - token_lora_mapping: tuple[int, - ...] # of size np.sum(num_scheduled_tokens) + token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) lora_requests: set[LoRARequest] - prompt_lora_mapping, token_lora_mapping, lora_requests = \ - input_batch.make_lora_inputs(num_scheduled_tokens) - return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + prompt_lora_mapping, token_lora_mapping, lora_requests = ( + input_batch.make_lora_inputs(num_scheduled_tokens) + ) + return self._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) @contextmanager - def maybe_setup_dummy_loras(self, lora_config): + def maybe_setup_dummy_loras( + self, lora_config: Optional[LoRAConfig], remove_lora: bool = True + ): if lora_config is None: yield else: @@ -95,12 +94,16 @@ class LoRAModelRunnerMixin: assert self.lora_manager is not None, "LoRA is not enabled" num_loras = lora_config.max_loras - + lora_warmup_rank = ( + lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8 + ) # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } @@ -108,17 +111,18 @@ class LoRAModelRunnerMixin: # Add the dummy LoRAs here so _set_active_loras doesn't try to # load from disk. for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank) yield # __exit__ code - self.lora_manager.remove_all_adapters() + if remove_lora: + self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): + def maybe_select_dummy_loras( + self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray + ): if lora_config is None: yield else: @@ -130,50 +134,57 @@ class LoRAModelRunnerMixin: # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
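Editorial note: the mapping constructed just below assigns dummy LoRA IDs round-robin so that every request exercises a different adapter, which is the worst case for the LoRA kernels during warmup. A toy illustration with hypothetical sizes (4 requests, 3 dummy LoRAs):

```python
import numpy as np

# Hypothetical sizes: 4 requests, 3 dummy LoRAs, a few tokens per request.
num_reqs, num_loras = 4, 3
num_scheduled_tokens = np.array([5, 2, 3, 1])

prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

print(prompt_lora_mapping)  # [1 2 3 1] -> adapters cycle across requests
print(token_lora_mapping)   # [1 1 1 1 1 2 2 3 3 3 1]
```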
- prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % - num_loras) + 1 + prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1 # Make token lora mapping - token_lora_mapping = np.repeat(prompt_lora_mapping, - num_scheduled_tokens) + token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), lora_requests) + self._set_active_loras( + tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests + ) yield @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): - with self.maybe_setup_dummy_loras( - lora_config), self.maybe_select_dummy_loras( - lora_config, num_scheduled_tokens): + def maybe_dummy_run_with_lora( + self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + remove_lora: bool = True, + ): + with ( + self.maybe_setup_dummy_loras(lora_config, remove_lora), + self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens), + ): yield + def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]): + if lora_config is None: + return + self.lora_manager.remove_all_adapters() + def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.list_adapters() diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 81c798685cb3a..ef115ade09ab8 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -9,7 +9,7 @@ import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingType -from vllm.utils import swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState @@ -18,16 +18,16 @@ _SAMPLING_EPS = 1e-5 class InputBatch: - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -54,13 +54,12 @@ class 
InputBatch: self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -70,94 +69,76 @@ class InputBatch: pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, ) # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() - self.min_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.min_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.min_p_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - 
dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() # req_index -> (min_tokens, stop_token_ids) self.min_tokens: dict[int, tuple[int, set[int]]] = {} # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -174,8 +155,7 @@ class InputBatch: # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - self.logit_bias: list[Optional[dict[int, - float]]] = [None] * max_num_reqs + self.logit_bias: list[Optional[dict[int, float]]] = [None] * max_num_reqs self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. @@ -213,14 +193,15 @@ class InputBatch: self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds + ) + # TODO: copy prompt_embeds self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids # Number of token ids in token_ids_cpu. # NOTE(woosuk): This may include spec decode tokens. 
self.num_tokens[req_index] = request.num_tokens @@ -250,23 +231,22 @@ class InputBatch: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k self.min_p_cpu[req_index] = sampling_params.min_p - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.min_p > _SAMPLING_EPS: self.min_p_reqs.add(req_id) if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) + self.min_tokens[req_index] = ( + sampling_params.min_tokens, + sampling_params.all_stop_token_ids, + ) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. @@ -285,23 +265,23 @@ class InputBatch: if self.allowed_token_ids_mask_cpu_tensor is None: # Lazy allocation for this tensor, which can be large. # False means we don't fill with -inf. - self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.allowed_token_ids_mask = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device=self.device, + ) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu" + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. 
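Editorial note: the allowed-token-ids mask allocated above starts all-True (everything disallowed), and the assignment just below flips the allowed positions back to False; at sampling time such a mask is applied with `masked_fill_`, as the earlier NOTE describes. A toy demonstration with a hypothetical vocab and token ids:

```python
import torch

# Hypothetical vocab of 8 tokens; ids 2 and 5 are the allowed ones.
vocab_size = 8
logits = torch.zeros(vocab_size)

mask = torch.ones(vocab_size, dtype=torch.bool)  # start fully disallowed (True)
mask[[2, 5]] = False                             # allowed ids -> False

logits.masked_fill_(mask, float("-inf"))
print(logits)  # only positions 2 and 5 stay finite
```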
self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids # Add request lora ID if request.lora_request: @@ -359,40 +339,56 @@ class InputBatch: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.min_p_cpu[i1], self.min_p_cpu[i2] =\ - self.min_p_cpu[i2], self.min_p_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], 
self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] - # instead, we need to temporiarily copy the data for one of the indices + # instead, we need to temporarily copy the data for one of the indices # TODO(lucas): optimize this by only copying valid indices tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] @@ -402,21 +398,28 @@ class InputBatch: swap_dict_values(self.min_tokens, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - self.logit_bias[i1], self.logit_bias[i2] =\ - self.logit_bias[i2], self.logit_bias[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) + self.logit_bias[i1], self.logit_bias[i2] = ( + self.logit_bias[i2], + self.logit_bias[i1], + ) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) self.block_table.swap_row(i1, i2) def condense(self, empty_req_indices: list[int]) -> None: """Move non-empty requests down into lower, empty indices. - + Args: empty_req_indices: empty batch indices, sorted descending. 
""" @@ -452,25 +455,29 @@ class InputBatch: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: @@ -481,28 +488,28 @@ class InputBatch: self.min_tokens[empty_index] = min_token self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] self.logit_bias[empty_index] = self.logit_bias[last_req_index] if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids # Decrement last_req_index since it is now empty. last_req_index -= 1 # Trim lists to the batch size. - del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[self.num_reqs :] + del self.req_output_token_ids[self.num_reqs :] def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -510,14 +517,12 @@ class InputBatch: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. 
for i in range(self.num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -533,12 +538,12 @@ class InputBatch: 3. lora_requests: Set of relevant LoRA requests. """ - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests @@ -568,9 +573,11 @@ class InputBatch: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 559af13d23303..f9e1fcedc8903 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,13 +3,15 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np import torch import torch.nn as nn + # TPU XLA related +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr @@ -17,47 +19,76 @@ import torch_xla.runtime as xr import vllm.envs as envs from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (ParallelConfig, VllmConfig, - get_layers_from_vllm_config, update_config) -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.config import ( + ParallelConfig, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader -from vllm.model_executor.models.interfaces import supports_transcription +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + supports_transcription, +) from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) + is_pooling_model, + 
is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, - prev_power_of_2) -from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE, - PallasAttentionBackend, - PallasMetadata, - get_page_size_bytes) -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, - LogprobsTensors, ModelRunnerOutput) +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available, prev_power_of_2 +from vllm.v1.attention.backends.pallas import ( + TPU_STR_DTYPE_TO_TORCH_DTYPE, + PallasAttentionBackend, + PallasMetadata, + get_page_size_bytes, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, +) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, KVConnectorOutput) + KVConnectorModelRunnerMixin, + KVConnectorOutput, +) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, bind_kv_cache, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs) +from .utils import ( + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + sanity_check_mm_encoder_outputs, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -105,7 +136,6 @@ MIN_NUM_SEQS = 8 # branch predictions are included as subgraph inputs to facilitate # pre-compilation. 
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -137,7 +167,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) self.enforce_eager = model_config.enforce_eager @@ -153,8 +183,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: self.kv_cache_dtype = model_dtype else: - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self._hidden_states_dtype = self.dtype self.sliding_window = model_config.get_sliding_window() @@ -162,25 +191,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_model_len = model_config.max_model_len self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.num_blocks_per_most_len_req = cdiv( - self.most_model_len, - self.block_size) if self.most_model_len is not None else None + self.num_blocks_per_most_len_req = ( + cdiv(self.most_model_len, self.block_size) + if self.most_model_len is not None + else None + ) # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) self.num_tokens_paddings = _get_token_paddings( min_token_size=16, max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP, + ) # In case `max_num_tokens < max(num_tokens_paddings)` use the actual # padded max value to pre-allocate data structures and pre-compile. self.max_num_tokens = self.num_tokens_paddings[-1] # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + parallel_config, LayerBlockType.attention + ) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() @@ -193,23 +225,27 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) # TODO: Support M-RoPE (e.g, Qwen2-VL) assert not self.uses_mrope, "TPU does not support M-RoPE yet." 
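Editorial note: the hunk below only reformats the call that derives `_num_slices_per_kv_cache_update_block` from the KV-cache page size. As a rough back-of-envelope estimate (an assumption on my part; the actual `get_page_size_bytes` may additionally pad the head size for TPU alignment), a page holds both K and V for one block:

```python
import torch

# Hypothetical model dimensions; ignores any TPU head-size padding.
block_size = 32
num_kv_heads = 8
head_size = 128
kv_cache_dtype = torch.bfloat16

dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()  # 2 bytes
page_size_bytes = block_size * 2 * num_kv_heads * head_size * dtype_size
print(page_size_bytes)  # 131072 bytes = 128 KiB per KV-cache block (K and V)
```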
- self._num_slices_per_kv_cache_update_block = \ - _get_num_slices_per_kv_cache_update_block(get_page_size_bytes( - block_size=self.block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - kv_cache_dtype=self.kv_cache_dtype, - )) + self._num_slices_per_kv_cache_update_block = ( + _get_num_slices_per_kv_cache_update_block( + get_page_size_bytes( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + kv_cache_dtype=self.kv_cache_dtype, + ) + ) + ) # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -223,50 +259,74 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + kernel_block_sizes=[self.cache_config.block_size], ) # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.input_ids_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.positions_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) self.positions_np = self.positions_cpu.numpy() self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), dtype=torch.int32, - device="cpu") + device="cpu", + ) # adjust num_reqs to avoid SMEM OOM. - self.num_reqs_most_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.most_model_len, - self.block_size), - self.max_num_reqs) if self.most_model_len is not None else None + self.num_reqs_most_model_len = ( + min( + PallasAttentionBackend.get_max_num_seqs( + self.most_model_len, self.block_size + ), + self.max_num_reqs, + ) + if self.most_model_len is not None + else None + ) self.num_reqs_max_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.max_model_len, - self.block_size), - self.max_num_reqs) - self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + PallasAttentionBackend.get_max_num_seqs( + self.max_model_len, self.block_size + ), + self.max_num_reqs, + ) + self.query_start_loc_cpu = torch.zeros( + self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.seq_lens_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Only relevant for multimodal models + if self.supports_mm_inputs: + self.is_mm_embed_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory, + ) + # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -279,30 +339,42 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (self.max_num_reqs, cdiv(self.vocab_size, 32)), dtype=torch.int32, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.require_structured_out_cpu = torch.zeros( (self.max_num_reqs, 1), dtype=torch.bool, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory) + 0, 32, device="cpu", pin_memory=self.pin_memory + ) - self.mm_budget = (MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None) + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) if not self.use_spmd: self.sample_from_logits_func = torch.compile( self.sample_from_logits, backend="openxla", fullgraph=True, - dynamic=False) + dynamic=False, + ) else: self.sample_from_logits_func = self.sample_from_logits + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -313,8 +385,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if new_compiled_graphs == 0: return - logger.info("Add new %d compiled XLA graphs due to %s", - new_compiled_graphs, case_str) + logger.info( + "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str + ) self.num_xla_graphs += new_compiled_graphs def _verify_num_xla_graphs(self, case_str): @@ -326,7 +399,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert self.num_xla_graphs == curr_cached_graph, ( "Recompilation after warm up is detected during {}." " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph)) + case_str, self.num_xla_graphs, curr_cached_graph + ) + ) def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: """Update the cached states and the persistent batch with the scheduler @@ -342,7 +417,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and @@ -357,12 +431,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): removed_req_indices.append(req_index) # Free the cached encoder outputs. 
- for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -384,16 +454,17 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: - assert new_req_data.sampling_params is not None,\ + assert new_req_data.sampling_params is not None, ( "Pooling is not supported in TPU yet" + ) req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -418,8 +489,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -436,23 +506,17 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. removed_req_indices = sorted(removed_req_indices, reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None + # Fill the empty index or append to the end + req_index = removed_req_indices.pop() if removed_req_indices else None self.input_batch.add_request(req_state, req_index) # Condense the batched states if there are empty indices. @@ -505,58 +569,77 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): format. Layers that do not need KV cache are not included. """ - layers = get_layers_from_vllm_config(self.vllm_config, Attention) + layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) block_size = self.vllm_config.cache_config.block_size + cache_dtype_str = self.vllm_config.cache_config.cache_dtype + kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. 
We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue + # Classic Attention path + if isinstance(attn_module, Attention): + if ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue - if attn_module.attn_type == AttentionType.DECODER: - if isinstance(attn_module, ChunkedLocalAttention): - logger.warning_once( - "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context.") - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=False, - ) + if attn_module.attn_type == AttentionType.DECODER: + if isinstance(attn_module, ChunkedLocalAttention): + logger.warning_once( + "Using irope in Pallas is not supported yet, it " + "will fall back to global attention for long context." + ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=False, - ) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. 
- continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + # MLAAttention path + elif isinstance(attn_module, MLAAttention): + if layer_name in kv_cache_spec: + continue + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str, + ) else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + continue return kv_cache_spec - def _get_slot_mapping_metadata(self, num_reqs, - num_scheduled_tokens_per_req): + def _get_slot_mapping_metadata( + self, num_reqs, num_scheduled_tokens_per_req + ) -> np.ndarray: """ Computes metadata for mapping slots to blocks in the key-value (KV) cache for a batch of requests. @@ -569,26 +652,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: num_reqs (int): Number of requests in the current batch. num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens - to be scheduled for each request. + to be scheduled for each request. Returns: np.ndarray: A 2D array of shape (total_block_len, 3), where each row - contains: + contains: - kv_cache_start_index (int): The starting index in the KV cache - for the corresponding slice. + for the corresponding slice. - new_kv_start_index (int): The starting index in the new KV - cache for the corresponding slice. + cache for the corresponding slice. - slice_len (int): The length of the slice. """ slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] - slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ - num_scheduled_tokens_per_req + slices_end = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) local_block_start_idx = slices_start // self.block_size local_block_end_idx = (slices_end - 1) // self.block_size no_repeat_req_indices = self.arange_np[:num_reqs] global_block_start_idx = ( - no_repeat_req_indices * self.max_num_blocks_per_req + - local_block_start_idx) + no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx + ) block_lens = local_block_end_idx - local_block_start_idx + 1 global_block_start_idx = np.repeat(global_block_start_idx, block_lens) slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) @@ -596,30 +681,31 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() total_block_len = np.sum(block_lens) - slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], - dtype=np.int32), - total_block_len, - axis=0) + slot_mapping_slices = np.repeat( + np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0 + ) cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) np.cumsum(block_lens, out=cu_block_lens[1:]) for req_idx in range(num_reqs): - slot_mapping_slices[cu_block_lens[req_idx]][ - 0] = slices_start[req_idx] % self.block_size - slot_mapping_slices[ - cu_block_lens[req_idx + 1] - - 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slot_mapping_slices[cu_block_lens[req_idx]][0] = ( + slices_start[req_idx] % self.block_size + ) + slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = ( + slices_end[req_idx] - 1 + ) % self.block_size + 1 slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] 
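# Illustrative sketch (not part of the patch): the slot-mapping metadata assembled
# above can be reproduced with a plain per-request loop. The block table and the
# computed/scheduled token counts below are invented toy values standing in for the
# batch state that _get_slot_mapping_metadata reads; the vectorized numpy version
# in the hunk is what actually runs.
import numpy as np

block_size = 4
block_tables = [[7, 2], [5]]  # physical block ids assigned to each request
num_computed = [3, 0]         # tokens already written to the KV cache
num_scheduled = [3, 2]        # new tokens to write in this step

rows, new_kv_start = [], 0    # rows: (kv_cache_start, new_kv_start, slice_len)
for r, blocks in enumerate(block_tables):
    start, end = num_computed[r], num_computed[r] + num_scheduled[r]
    for blk in range(start // block_size, (end - 1) // block_size + 1):
        lo = max(start, blk * block_size)      # first new token in this block
        hi = min(end, (blk + 1) * block_size)  # one past the last new token
        kv_start = blocks[blk] * block_size + lo % block_size
        rows.append((kv_start, new_kv_start, hi - lo))
        new_kv_start += hi - lo

print(np.array(rows))  # -> rows (31, 0, 1), (8, 1, 2), (20, 3, 2) for these toy values
# End of illustrative sketch.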
cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) np.cumsum(slice_lens, out=cu_slices_lens[1:]) - kv_cache_start_indices = slot_mapping_slices[:, 0] + \ - (block_numbers * self.block_size) + kv_cache_start_indices = slot_mapping_slices[:, 0] + ( + block_numbers * self.block_size + ) new_kv_start_indices = cu_slices_lens[:-1] slot_mapping_metadata = np.stack( - [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1 + ) return slot_mapping_metadata - def _prepare_inputs(self, scheduler_output: "SchedulerOutput", - start_index: int): + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 @@ -641,22 +727,24 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens_per_req.append(num_tokens) if use_max_model_len: if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_max_model_len] + num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ + : self.num_reqs_max_model_len + ] end_index = start_index + self.num_reqs_max_model_len else: end_index = num_reqs else: - if len(num_scheduled_tokens_per_req - ) > self.num_reqs_most_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_most_model_len] + if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: + num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ + : self.num_reqs_most_model_len + ] end_index = start_index + self.num_reqs_most_model_len else: end_index = num_reqs max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) - num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, - dtype=np.int32) + num_scheduled_tokens_per_req = np.array( + num_scheduled_tokens_per_req, dtype=np.int32 + ) total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) assert max_num_scheduled_tokens_all_reqs > 0 @@ -665,121 +753,130 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_per_req) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # For each scheduled token, what is its position in corresponding req. arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + [self.arange_np[:n] for n in num_scheduled_tokens_per_req] + ) # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. 
- token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 - np.cumsum(num_scheduled_tokens_per_req, - out=self.query_start_loc_np[1:num_reqs + 1]) - self.query_start_loc_np[num_reqs + 1:] = 1 + np.cumsum( + num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1] + ) + self.query_start_loc_np[num_reqs + 1 :] = 1 self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens_per_req) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) # Do the padding and copy the tensors to the TPU. padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) + self.num_tokens_paddings, total_num_scheduled_tokens + ) # Zero out to avoid spurious values from prev iteration (last cp chunk) self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 - self.input_ids = self.input_ids_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) - self.position_ids = self.positions_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) + total_num_scheduled_tokens:padded_total_num_scheduled_tokens + ] = 0 + self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) + self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) if use_max_model_len: - block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : - self.max_num_blocks_per_req] - block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) - query_start_loc = self.query_start_loc_cpu[:self. - num_reqs_max_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_max_model_len, : self.max_num_blocks_per_req + ] + block_tables[:num_reqs, : self.max_num_blocks_per_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_max_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) else: - block_tables = self.block_table_cpu[:self. - num_reqs_most_model_len, :self. - num_blocks_per_most_len_req] - block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = ( - self.input_batch.block_table[0].get_cpu_tensor() - [:num_reqs, :self.num_blocks_per_most_len_req]) - query_start_loc = self.query_start_loc_cpu[:self. 
- num_reqs_most_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req + ] + block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[ + :num_reqs, : self.num_blocks_per_most_len_req + ] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_most_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device) block_tables = block_tables.to(self.device) # Calculate the slot mapping slot_mapping_metadata = self._get_slot_mapping_metadata( - num_reqs, num_scheduled_tokens_per_req) + num_reqs, num_scheduled_tokens_per_req + ) num_kv_update_slices = slot_mapping_metadata.shape[0] padded_num_slices = _get_padded_num_kv_cache_update_slices( - padded_total_num_scheduled_tokens, self.max_num_reqs, - self.block_size) + padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size + ) slot_mapping_metadata = np.pad( slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], - constant_values=0) + constant_values=0, + ) slot_mapping_metadata = np.transpose(slot_mapping_metadata) - slot_mapping_metadata = torch.tensor(slot_mapping_metadata, - device=self.device) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], - dtype=torch.int32, - device=self.device), - num_kv_update_slices=torch.tensor([num_kv_update_slices], - dtype=torch.int32, - device=self.device), - num_slices_per_kv_cache_update_block=self. - _num_slices_per_kv_cache_update_block, + num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + num_kv_update_slices=torch.tensor( + [num_kv_update_slices], dtype=torch.int32, device=self.device + ), + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -787,10 +884,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # token from the partial request. # TODO: Support prompt logprobs. padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs) + num_reqs, self.max_num_reqs + ) # Indices at which we sample (positions of last token in the sequence). # Padded to avoid recompiling when `num_reqs` varies. 
- logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 + logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) if self.lora_config is not None: @@ -798,45 +896,23 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ - num_reqs, end_index - - def _scatter_placeholders( - self, - embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return embeds - - placeholders = embeds.new_full( - (is_embed.shape[0], embeds.shape[-1]), - fill_value=torch.nan, + return ( + per_layer_attn_metadata, + logits_indices, + padded_num_reqs, + num_reqs, + end_index, ) - placeholders[is_embed] = embeds - return placeholders - - def _gather_placeholders( - self, - placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return placeholders - - return placeholders[is_embed] def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs @@ -845,14 +921,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # List of tuple (mm_hash, pos_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -861,11 +939,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the same batch while still being able to benefit from batching # multimodal inputs. The proper solution should be reordering the # encoder outputs. + model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Run the encoder. 
# `curr_group_outputs` is either of the following: @@ -874,10 +954,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. - xm.mark_step() - curr_group_outputs = self.model.get_multimodal_embeddings( - **mm_kwargs_group) - xm.mark_step() + torch_xla.sync(wait=False) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) + torch_xla.sync(wait=False) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -895,32 +974,37 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE (NickLucche) here we diverge from logic in other runners, as we # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - assert pos_info.is_embed is None, "Expected all positions to be"\ - " contiguous and embeddings." - self.encoder_cache[req_id][input_id] = output + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." + ) + self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] + ) -> tuple[list[torch.Tensor], torch.Tensor]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + padded_total_num_scheduled_tokens = _get_padded_token_len( + self.num_tokens_paddings, total_num_scheduled_tokens + ) + + is_mm_embed = self.is_mm_embed_cpu + is_mm_embed[:padded_total_num_scheduled_tokens] = False + mm_embeds = list[torch.Tensor]() + req_start_idx = 0 + for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions + # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for i, pos_info in enumerate(mm_positions): + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -936,24 +1020,50 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # in the decoder's KV cache. continue - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - assert pos_info.is_embed is None, "Expected all positions to"\ - " be contiguous and embeddings." 
- encoder_output = self.encoder_cache[req_id][i] - mm_embeds.append(encoder_output) - return mm_embeds + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx - def _get_model_inputs(self, input_ids: torch.Tensor, - mm_embeds: list[torch.Tensor]): + mm_hash = mm_feature.identifier + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." + + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." + ) + + req_start_pos = req_start_idx + start_pos - num_computed_tokens + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True + + # Only whole mm items are processed + mm_embeds.append(encoder_output) + + req_start_idx += num_scheduled_tokens + + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device) + + return mm_embeds, is_mm_embed + + def _get_model_inputs( + self, + input_ids: torch.Tensor, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]], + ): if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. inputs_embeds = self.model.get_input_embeddings( - input_ids=input_ids, + input_ids, multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, ) + return None, inputs_embeds else: # For text-only models, we use token ids as input. @@ -975,16 +1085,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) + mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) else: - mm_embeds = [] - xm.mark_step() + mm_embed_inputs = None + + torch_xla.sync(wait=False) # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. 
start_index = 0 @@ -997,41 +1107,48 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) while start_index < self.input_batch.num_reqs: - attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ - end_index = self._prepare_inputs(scheduler_output, start_index) + attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = ( + self._prepare_inputs(scheduler_output, start_index) + ) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embeds) - xm.mark_step() + self.input_ids, mm_embed_inputs + ) + torch_xla.sync(wait=False) # Run the decoder with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens): + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens, + ): hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, inputs_embeds=inputs_embeds, ) - hidden_states = self.select_hidden_states(hidden_states, - logits_indices) + hidden_states = self.select_hidden_states(hidden_states, logits_indices) logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, padded_num_reqs, self.device) + tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, padded_num_reqs, self.device + ) if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - self.prepare_structured_decoding_input(logits, - scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) + require_struct_decoding, grammar_bitmask_padded, arange = ( + self.prepare_structured_decoding_input(logits, scheduler_output) + ) + logits = self.structured_decode( + require_struct_decoding, grammar_bitmask_padded, logits, arange + ) selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata) + logits, tpu_sampling_metadata + ) # NOTE (NickLucche) Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. We can't enforce it # due to recompilations outside torch.compiled code, so just make # sure `sample_from_logits` does not modify the logits in-place. - logprobs = self.gather_logprobs(logits, selected_token_ids) \ - if tpu_sampling_metadata.logprobs else None + logprobs = ( + self.gather_logprobs(logits, selected_token_ids) + if tpu_sampling_metadata.logprobs + else None + ) # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] @@ -1047,8 +1164,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # should be called right after each single forward pass, # instead of the forwards of the entire input batch. 
self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + finished_sending, finished_recving = self.get_finished_kv_transfers( + scheduler_output + ) selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: @@ -1059,16 +1177,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): result.extend(input_list) return result - logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists( - [lp.logprob_token_ids for lp in combined_logprobs]), - logprobs=concat_lists([ - lp.logprobs - for lp in combined_logprobs - ]), - sampled_token_ranks=concat_lists([ - lp.sampled_token_ranks - for lp in combined_logprobs - ])) + logprobs_lists = LogprobsLists( + logprob_token_ids=concat_lists( + [lp.logprob_token_ids for lp in combined_logprobs] + ), + logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]), + sampled_token_ranks=concat_lists( + [lp.sampled_token_ranks for lp in combined_logprobs] + ), + ) else: logprobs_lists = None @@ -1080,8 +1197,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id] + ) if seq_len >= req_state.num_tokens: request_seq_lens.append((i, req_state, seq_len)) else: @@ -1097,8 +1216,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): discard_sampled_tokens_req_indices.append(i) assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_id is not None for req_id in self.input_batch.req_ids[:num_reqs] + ), "req_ids contains None" req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} @@ -1126,25 +1245,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ - seq.tolist() - for seq in selected_token_ids[valid_mask].split(gen_lens) + seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) ] self.input_batch.num_tokens[:num_reqs] += gen_lens for i, req_state, seq_len in request_seq_lens: target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] + self.input_batch.token_ids_cpu[i, target_slice] = ( + valid_sampled_token_ids[i] + ) req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - kv_connector_output = None if ( - finished_sending is None - and finished_recving is None) else KVConnectorOutput( + kv_connector_output = ( + None + if (finished_sending is None and finished_recving is None) + else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, ) + ) model_runner_output = ModelRunnerOutput( req_ids=req_ids, + req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -1163,9 +1285,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 allowed_config_names = {"load_config", 
"model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. " f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1184,83 +1307,84 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the embedding weights. xm_tp_rank = xr.global_ordinal() with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank, + ): try: if self.use_spmd: tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config) + load_config=self.vllm_config.load_config + ) model = tpu_loader.load_model( vllm_config=self.vllm_config, model_config=self.vllm_config.model_config, - mesh=self.mesh) + mesh=self.mesh, + ) else: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) except RuntimeError as e: raise RuntimeError( f"Unable to load model, a likely reason is the model is " "too large for the current device's HBM memory. " "Consider switching to a smaller model " "or sharding the weights on more chips. " - f"See the detailed error: {e}") from e + f"See the detailed error: {e}" + ) from e if self.lora_config is not None: - model = self.load_lora_model(model, self.model_config, - self.scheduler_config, - self.lora_config, self.device) + model = self.load_lora_model(model, self.vllm_config, self.device) replace_set_lora(model) # Sync all pending XLA execution during model initialization and weight # loading. - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() if not hasattr(self, "model"): self.model = model self.sampler = TPUSampler() def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
+ ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config) @torch.no_grad() - def _dummy_run(self, num_tokens: int, num_reqs: int, - num_blocks: int) -> None: + def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: if self.supports_mm_inputs: input_ids = None - inputs_embeds = torch.zeros((num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + inputs_embeds = torch.zeros( + (num_tokens, self.hidden_size), dtype=self.dtype, device=self.device + ) else: - input_ids = torch.zeros((num_tokens), - dtype=torch.int32).to(self.device) + input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) inputs_embeds = None actual_num_reqs = min(num_tokens, num_reqs) - position_ids = torch.zeros(num_tokens, - dtype=torch.int32).to(self.device) + position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size) - num_kv_update_slices = torch.tensor([padded_num_slices], - dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros((3, padded_num_slices), - dtype=torch.int32).to(self.device) - block_tables = torch.zeros((num_reqs, num_blocks), - dtype=torch.int32).to(self.device) + num_tokens, self.max_num_reqs, self.block_size + ) + num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to( + self.device + ) + slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to( + self.device + ) + block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to( + self.device + ) query_lens = [1] * num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32).to(self.device) - context_lens = torch.ones((num_reqs, ), - dtype=torch.int32).to(self.device) - num_seqs = torch.tensor([actual_num_reqs], - dtype=torch.int32).to(self.device) + query_start_loc = torch.cumsum( + torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ).to(self.device) + context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device) + num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -1268,8 +1392,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc=query_start_loc, num_seqs=num_seqs, num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=self. 
- _num_slices_per_kv_cache_update_block, + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) if self.supports_mm_inputs: @@ -1282,28 +1405,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - with self.maybe_select_dummy_loras( - self.lora_config, - np.array([num_tokens], dtype=np.int32)), set_forward_context( - per_layer_attn_metadata, self.vllm_config, 0): - out = self.model(input_ids=input_ids, - positions=position_ids, - inputs_embeds=inputs_embeds) + with ( + self.maybe_select_dummy_loras( + self.lora_config, np.array([num_tokens], dtype=np.int32) + ), + set_forward_context(per_layer_attn_metadata, self.vllm_config, 0), + ): + out = self.model( + input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds + ) self._hidden_states_dtype = out.dtype - def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, - lora_requests) -> None: - xm.mark_step() # Captures input updates - super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) - xm.mark_step() # Captures metadata updates + def _set_active_loras( + self, prompt_lora_mapping, token_lora_mapping, lora_requests + ) -> None: + torch_xla.sync(wait=False) # Captures input updates + super()._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) + torch_xla.sync(wait=False) # Captures metadata updates def _precompile_mm_encoder(self) -> None: if not self.supports_mm_inputs: @@ -1319,8 +1444,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): logger.info( - "Compiling Multimodal %s Encoder with different input" - " shapes.", mode) + "Compiling Multimodal %s Encoder with different input shapes.", mode + ) start = time.perf_counter() # No padding for MM encoder just yet. for num_items in range(1, max_items_per_seq + 1): @@ -1330,10 +1455,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_items, ) # Run multimodal encoder. - xm.mark_step() + torch_xla.sync(wait=False) mm_embeds = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() + **batched_dummy_mm_inputs + ) + torch_xla.sync(wait=False) num_patches = mm_embeds[0].shape[0] items_size = num_patches * num_items @@ -1346,47 +1472,61 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # XLA Workaround: if torch.zeros(..device) is used, XLA # compiles a scalar+expansion op, which won't match # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") + placeholders_ids = torch.zeros( + num_tokens, dtype=torch.int32, device="cpu" + ) # Align placeholders and actual num mm_embeddings. 
-                        placeholders_ids[:items_size] = \
-                            hf_config.image_token_index
+                        placeholders_ids[:items_size] = hf_config.image_token_index
                         placeholders_ids = placeholders_ids.to(self.device)
+
+                        mm_mask = torch.tensor([False] * num_tokens)
+                        mm_mask[:items_size] = True
+                        mm_mask = mm_mask.to(self.device)
                         # Assign outputs or the graph will be cut short.
-                        a, b = self._get_model_inputs(placeholders_ids,
-                                                      [mm_embeds])
+                        a, b = self._get_model_inputs(
+                            placeholders_ids,
+                            mm_embed_inputs=([mm_embeds], mm_mask),
+                        )
                         assert a is None
-                        xm.mark_step()
+                        torch_xla.sync(wait=False)

             # Pre-compile `get_input_embeddings` when mm_embeddings are not
             # present. Chunk is only made of text, no mm_placeholders.
             for num_tokens in self.num_tokens_paddings:
-                placeholders_ids = torch.zeros(num_tokens,
-                                               dtype=torch.int32,
-                                               device="cpu")
+                placeholders_ids = torch.zeros(
+                    num_tokens, dtype=torch.int32, device="cpu"
+                )
                 placeholders_ids = placeholders_ids.to(self.device)
-                a, b = self._get_model_inputs(placeholders_ids, [])
+                a, b = self._get_model_inputs(
+                    placeholders_ids,
+                    mm_embed_inputs=None,
+                )
                 assert a is None
-                xm.mark_step()
+                torch_xla.sync(wait=False)

             xm.wait_device_ops()
             end = time.perf_counter()
             logger.info(
-                "Multimodal %s Encoder compilation finished in in %.2f "
-                "[secs].", mode, end - start)
+                "Multimodal %s Encoder compilation finished in %.2f [secs].",
+                mode,
+                end - start,
+            )

     def _precompile_backbone(self) -> None:
         logger.info("Compiling the model with different input shapes.")
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
-            self._dummy_run(num_tokens, self.num_reqs_max_model_len,
-                            self.max_num_blocks_per_req)
+            self._dummy_run(
+                num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
+            )
             if self.most_model_len is not None:
-                self._dummy_run(num_tokens, self.num_reqs_most_model_len,
-                                self.num_blocks_per_most_len_req)
+                self._dummy_run(
+                    num_tokens,
+                    self.num_reqs_most_model_len,
+                    self.num_blocks_per_most_len_req,
+                )
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
@@ -1395,23 +1535,19 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _precompile_select_hidden_states(self) -> None:
         # Compile hidden state selection function for bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        logger.info(
-            "Compiling select_hidden_states with different input shapes.")
+        logger.info("Compiling select_hidden_states with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
         for num_tokens in self.num_tokens_paddings:
-            dummy_hidden = torch.zeros((num_tokens, hsize),
-                                       device=self.device,
-                                       dtype=self._hidden_states_dtype)
+            dummy_hidden = torch.zeros(
+                (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype
+            )
             torch._dynamo.mark_dynamic(dummy_hidden, 0)
             for num_reqs in self.num_reqs_paddings:
-                indices = torch.zeros(num_reqs,
-                                      dtype=torch.int32,
-                                      device=self.device)
+                indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
                 torch._dynamo.mark_dynamic(indices, 0)
                 self.select_hidden_states(dummy_hidden, indices)
-                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
-                            num_reqs)
+                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs)
                 # Requests can't be more than tokens. But do compile for the
                 # next bigger value in case num_tokens uses bucketed padding.
if num_reqs >= min(num_tokens, self.max_num_reqs): @@ -1426,9 +1562,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros((num_reqs, hsize), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_hidden = torch.zeros( + (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype + ) torch._dynamo.mark_dynamic(dummy_hidden, 0) self.compute_logits(dummy_hidden) logger.info(" -- num_seqs: %d", num_reqs) @@ -1438,23 +1574,28 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._update_num_xla_graphs("compute_logits") def _precompile_structured_decoding(self) -> None: - logger.info( - "Compiling structured_decoding with different input shapes.") + logger.info("Compiling structured_decoding with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_require_struct_decoding = \ - self.require_structured_out_cpu[:num_reqs].to(self.device) - dummy_grammar_bitmask = \ - self.grammar_bitmask_cpu[:num_reqs].to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_require_struct_decoding = self.require_structured_out_cpu[ + :num_reqs + ].to(self.device) + dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device) # The first dimension of the above 3 dummy tensors cannot be # mark_dynamic because some operations in structured_decode require # them to be static. arange = self.structured_decode_arange.to(self.device) - self.structured_decode(dummy_require_struct_decoding, - dummy_grammar_bitmask, dummy_logits, arange) + self.structured_decode( + dummy_require_struct_decoding, + dummy_grammar_bitmask, + dummy_logits, + arange, + ) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1462,30 +1603,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._update_num_xla_graphs("structured_decoding") def _precompile_sample_from_logits(self) -> None: - logger.info( - "Compiling sample_from_logits with different input shapes.") + logger.info("Compiling sample_from_logits with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) # The first dimension of dummy_logits cannot be mark_dynamic # because some operations in the sampler require it to be static. 
for all_greedy in [False, True]: generate_params_if_all_greedy = not all_greedy - sampling_metadata = ( - TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - )) + sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, + num_reqs, + self.device, + generate_params_if_all_greedy, + ) sampling_metadata.all_greedy = all_greedy with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], - dtype=np.int32)): - self.sample_from_logits_func(dummy_logits, - sampling_metadata) + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): + self.sample_from_logits_func(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1496,13 +1636,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Compiling gather_logprobs with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_tokens = torch.zeros((num_reqs, 1), - dtype=torch.int64).to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32)): + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): self.gather_logprobs(dummy_logits, dummy_tokens) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() @@ -1532,7 +1674,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -1543,8 +1686,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -1565,16 +1709,17 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Isolate encoder graph from post-processing to minimize # impact of recompilation until it's fixed. start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() + torch_xla.sync(wait=False) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) + torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() logger.info( "Multimodal Encoder profiling finished in %.2f [secs].", - end - start) + end - start, + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -1582,21 +1727,46 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. 
- self._dummy_run(num_tokens, self.num_reqs_max_model_len, - self.max_num_blocks_per_req) + self._dummy_run( + num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req + ) if self.most_model_len is not None: - self._dummy_run(num_tokens, self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req) + self._dummy_run( + num_tokens, + self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req, + ) - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() self.encoder_cache.clear() gc.collect() + def maybe_setup_cross_layer_kv_sharing( + self, + kv_caches: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1606,11 +1776,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") + "Hybrid models with more than one KV cache type are not supported yet." + ) - if kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size != self.block_size: + if ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + != self.block_size + ): self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -1621,16 +1793,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) # Verify dtype compatibility between block_table_cpu and input_batch - assert self.block_table_cpu.dtype == self.input_batch.block_table[ - 0].get_cpu_tensor().dtype + assert ( + self.block_table_cpu.dtype + == self.input_batch.block_table[0].get_cpu_tensor().dtype + ) kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "TPU.") + "KV cache tensor shared by multiple layers is not supported in TPU." + ) kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_caches: dict[str, torch.Tensor] = {} @@ -1644,49 +1821,47 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.use_spmd: num_kv_heads = kv_cache_spec.num_kv_heads assert self.original_parallel_config is not None - tp_size = \ - self.original_parallel_config.tensor_parallel_size + tp_size = self.original_parallel_config.tensor_parallel_size # TODO: Handle kv cache duplication under SPMD mode. 
assert num_kv_heads % tp_size == 0, ( f"num_kv_heads {num_kv_heads} must be divisible by " - f"tp_size {tp_size} under SPMD mode") + f"tp_size {tp_size} under SPMD mode" + ) kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) dtype = kv_cache_spec.dtype - tpu_kv_cache = torch.zeros(kv_cache_shape, - dtype=dtype).to(self.device) + tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to( + self.device + ) kv_caches[layer_name] = tpu_kv_cache else: raise NotImplementedError - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing if needed + self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + self.kv_caches, + ) if self.use_spmd: # Shard KV Cache for cache in self.kv_caches: - xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) + xs.mark_sharding(cache, self.mesh, (None, "x", None, None)) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) def reset_dynamo_cache(self): - # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` # since the compiled model object of the language backbone of a # multimodal model needs to be extracted via `get_language_model`. @@ -1697,7 +1872,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): logger.info("Clear dynamo cache and cached dynamo bytecode.") torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object) + compiled_model.original_code_object + ) compiled_model.compiled_codes.clear() @torch.compile(backend="openxla", fullgraph=True, dynamic=False) @@ -1705,30 +1881,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return hidden_states[indices_do_sample] @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, - sample_hidden_states: torch.Tensor) -> torch.Tensor: - return self.model.compute_logits(sample_hidden_states, None) + def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(sample_hidden_states) # TODO: Under SPMD mode, sample_from_logits has correctness issue. # Re-enable the torch.compile once the issue is fixed in torchxla. # @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def sample_from_logits( - self, logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: + self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata + ) -> torch.Tensor: """ - Sample with xla-friendly function. This function is to be traced + Sample with xla-friendly function. This function is to be traced separately from `forward` for lighter compilation overhead. 
""" if sampling_metadata.all_greedy: out_tokens = torch.argmax(logits, dim=-1, keepdim=True) else: - out_tokens = self.sampler(logits, - sampling_metadata).sampled_token_ids + out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids return out_tokens @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs(self, logits: torch.Tensor, - sampled_tokens: torch.Tensor) -> LogprobsTensors: + def gather_logprobs( + self, logits: torch.Tensor, sampled_tokens: torch.Tensor + ) -> LogprobsTensors: """ Gather the top_logprobs with corresponding tokens. Use a fixed number of logprobs as an alternative to having multiple pre-compiled graphs. @@ -1738,28 +1913,37 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return self.sampler.gather_logprobs( logprobs, self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1)) + token_ids=sampled_tokens.squeeze(-1), + ) @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode(self, require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, logits: torch.Tensor, - arange: torch.Tensor) -> torch.Tensor: + def structured_decode( + self, + require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, + logits: torch.Tensor, + arange: torch.Tensor, + ) -> torch.Tensor: return torch.where( require_struct_decoding, self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits) + logits, + ) - def apply_grammar_bitmask(self, logits: torch.Tensor, - grammar_bitmask: torch.Tensor, - arange: torch.Tensor): - assert (logits.shape[0] == grammar_bitmask.shape[0]) + def apply_grammar_bitmask( + self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor + ): + assert logits.shape[0] == grammar_bitmask.shape[0] logits_cloned = logits.clone() for i in range(logits.shape[0]): - unpacked_bitmask = (torch.bitwise_right_shift( - grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + unpacked_bitmask = ( + torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :]) + & 1 + ) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size] logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf")) + unpacked_bitmask, -float("inf") + ) return logits_cloned def get_multimodal_embeddings(self, *args, **kwargs): @@ -1779,31 +1963,29 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. 
We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: + sorted_struct_requests = sorted( + scheduler_output.structured_output_request_ids.items(), + key=lambda item: item[1], + ) + cumulative_mask_idx = 0 + for req_id, _ in sorted_struct_requests: + if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True - return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ - self.structured_decode_arange.to(logits.device) + self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( + grammar_bitmask[cumulative_mask_idx] + ) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + self.require_structured_out_cpu[batch_index] = True + cumulative_mask_idx += 1 + + return ( + self.require_structured_out_cpu[:num_reqs].to(logits.device), + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), + self.structured_decode_arange.to(logits.device), + ) def _get_mm_dummy_batch( self, @@ -1811,10 +1993,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=self.max_model_len, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data @@ -1822,12 +2007,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_items = [dummy_mm_item] * max_items_per_batch - return next(grouped_mm_kwargs - for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - )) + model = cast(SupportsMultiModal, self.model) + return next( + grouped_mm_kwargs + for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: @@ -1848,9 +2037,10 @@ def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: return min(res, upper_limit) -def _get_token_paddings(min_token_size: int, max_token_size: int, - padding_gap: int) -> list[int]: - """Generate a list of padding size, starting from min_token_size, +def _get_token_paddings( + min_token_size: int, max_token_size: int, padding_gap: int +) -> list[int]: + """Generate a list of padding size, starting from min_token_size, ending with a number that 
can cover max_token_size If padding_gap == 0 then: @@ -1888,84 +2078,15 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x. - """ + """Return the first element in paddings list greater or equal to x.""" index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] -def _make_src_and_dst_indices( - src_block_ids: list[int], - dst_block_ids: list[int], - src_device: Union[torch.device, str], - dst_device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor]: - src_indices = torch.tensor(src_block_ids, - device=src_device, - dtype=torch.int64) - dst_indices = torch.tensor(dst_block_ids, - device=dst_device, - dtype=torch.int64) - return src_indices, dst_indices - - -@torch.compile(backend="openxla") -def _insert_blocks_to_tpu( - cpu_cache: torch.Tensor, - tpu_cache: torch.Tensor, - cpu_block_indices: torch.Tensor, - tpu_block_indices: torch.Tensor, -) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( - tpu_cache.device) - - -@torch.compile(backend="openxla") -def _swap_out_tpu_blocks( - tpu_cache: torch.Tensor, - cpu_cache: torch.Tensor, - tpu_block_indices: torch.Tensor, - cpu_block_indices: torch.Tensor, -) -> None: - """ tpu blocks to cpu blocks""" - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() - - -def copy_kv_blocks( - src_kv_caches: dict[str, torch.Tensor], - dst_kv_caches: dict[str, torch.Tensor], - src_block_ids: list[int], - dst_block_ids: list[int], - direction: Literal["h2d", "d2h"], -) -> None: - """Copy kv blocks between different buffers.""" - if not src_kv_caches or not dst_kv_caches or \ - not src_block_ids or not dst_block_ids or \ - len(src_block_ids) != len(dst_block_ids): - return - - src_device = next(iter(src_kv_caches.values())).device - dst_device = next(iter(dst_kv_caches.values())).device - - src_indices, dst_indices = _make_src_and_dst_indices( - src_block_ids=src_block_ids, - dst_block_ids=dst_block_ids, - src_device=src_device, - dst_device=dst_device) - - _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \ - _swap_out_tpu_blocks - for layer_name in src_kv_caches: - src_tensor = src_kv_caches[layer_name] - dst_tensor = dst_kv_caches[layer_name] - _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) - - -def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, - page_size: int) -> int: +def _get_padded_num_kv_cache_update_slices( + num_tokens: int, max_num_reqs: int, page_size: int +) -> int: """Calculates the padded number of KV cache update slices to avoid recompilation.""" # NOTE(chengjiyao): let's say R_i is the token num for i-th request, @@ -2001,7 +2122,6 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: def replace_set_lora(model): - def _tpu_set_lora( self, index: int, @@ -2014,16 +2134,15 @@ def replace_set_lora(model): # to a tensor doesn't seem to work anymore. This might be fixed with a # later release of torch_xla. 
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) - xm.mark_step() + torch_xla.sync(wait=False) def _tpu_reset_lora(self, index: int): self._original_reset_lora(index) - xm.mark_step() + torch_xla.sync(wait=False) for _, module in model.named_modules(): if isinstance(module, BaseLayerWithLoRA): module._original_set_lora = module.set_lora module._original_reset_lora = module.reset_lora module.set_lora = _tpu_set_lora.__get__(module, module.__class__) - module.reset_lora = _tpu_reset_lora.__get__( - module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9adf8a14213f3..b64cec318f6c6 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -3,7 +3,7 @@ """A TPU worker class.""" import os -from typing import Any, Optional +from typing import Any, Callable, Optional, TypeVar import torch import torch.distributed @@ -11,28 +11,33 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + has_kv_transfer_group, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.platforms.tpu import USE_TPU_COMMONS +from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) -if not USE_TPU_COMMONS: - logger.info("tpu_commons not found, using vLLM's TPUWorker.") +_R = TypeVar("_R") + +if not USE_TPU_INFERENCE: + logger.info("tpu_inference not found, using vLLM's TPUWorker.") import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp import torch_xla.runtime as xr @@ -42,7 +47,6 @@ if not USE_TPU_COMMONS: class TPUWorker: - def __init__( self, vllm_config: VllmConfig, @@ -80,12 +84,12 @@ class TPUWorker: if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Delay profiler initialization to the start of the profiling. @@ -98,14 +102,14 @@ class TPUWorker: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - self.profile_dir) + logger.info( + "Profiling enabled. 
Traces will be saved to: %s", self.profile_dir + ) if self.model_config.seed is None: self.model_config.seed = 0 - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -116,9 +120,10 @@ class TPUWorker: # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to # fix this. It will be removed after the bug in XLA compiler is fixed. os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + - " --xla_tpu_force_1d_allreduce_at_chunk_count=1" - " --xla_jf_conv_input_fusion=False") + os.environ.get("LIBTPU_INIT_ARGS", "") + + " --xla_tpu_force_1d_allreduce_at_chunk_count=1" + " --xla_jf_conv_input_fusion=False" + ) # --xla_jf_conv_input_fusion=False is used to improve the perf of # quantized matmul. torch.set_grad_enabled(False) @@ -126,8 +131,8 @@ class TPUWorker: # Initialize the distributed environment. self._init_tpu_worker_distributed_environment( - self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank) + self.vllm_config, self.rank, self.distributed_init_method, self.local_rank + ) # Device initialization should happen after initializing # the distributed runtime. @@ -156,14 +161,15 @@ class TPUWorker: # cache during development is recommended.We can disable it by # `export VLLM_XLA_CACHE_PATH=` if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") + per_rank_path = os.path.join( + envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}" + ) xr.initialize_cache(per_rank_path, readonly=False) # Init ModelRunner here, so that we have access to self.device. - self.model_runner = \ - TPUModelRunner(self.vllm_config, self.device, - self.original_parallel_config) + self.model_runner = TPUModelRunner( + self.vllm_config, self.device, self.original_parallel_config + ) if rank == 0: # If usage stat is enabled, collect relevant info. @@ -182,13 +188,15 @@ class TPUWorker: kv_caches[layer_name] = tpu_kv_cache else: raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'") + f"Unsupported KV cache spec '{type(layer_spec)}'" + ) runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches) + runner_kv_caches, + ) # `max_num_tokens >= max_num_batched_tokens` due to padding. with self.model_runner.maybe_setup_dummy_loras(self.lora_config): @@ -213,6 +221,7 @@ class TPUWorker: # TODO: use xm.get_memory_info for SPMD once it's supported in # PyTorch/XLA. import tpu_info + chip_type, _ = tpu_info.device.get_local_chips() device_usage = tpu_info.metrics.get_chip_usage(chip_type) total_memory_size = device_usage[0].total_memory @@ -229,20 +238,20 @@ class TPUWorker: profiled = current_mem * 1.02 # Calculate the TPU KV cache size based on profiling. 
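The arithmetic in the hunk below is easy to miss on a first read: the KV-cache budget is the allowed fraction of device memory minus the profiled peak, and it is then scaled down whenever the head size must be padded up to the TPU alignment, since padding inflates every cache entry. A hedged sketch of that calculation (the alignment value 128 is an assumption standing in for `TPU_HEAD_SIZE_ALIGNMENT`):

def kv_cache_budget_bytes(
    total_memory: int,
    profiled_peak: int,
    gpu_memory_utilization: float,
    head_size: int,
    head_size_alignment: int = 128,  # assumed stand-in for TPU_HEAD_SIZE_ALIGNMENT
) -> int:
    usable = int(total_memory * gpu_memory_utilization)
    budget = max(usable - profiled_peak, 0)
    padded_head = -(-head_size // head_size_alignment) * head_size_alignment
    if padded_head != head_size:
        # Every KV entry grows by padded_head / head_size, so shrink the
        # byte budget by the inverse ratio to stay within physical memory.
        budget = budget * head_size // padded_head
    return int(budget)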
- usable_memory_size = int(total_memory_size * - self.cache_config.gpu_memory_utilization) + usable_memory_size = int( + total_memory_size * self.cache_config.gpu_memory_utilization + ) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) head_size = self.model_config.get_head_size() if head_size > 0: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) if padded_head_size != head_size: - logger.warning_once("head size is padded to %d", - padded_head_size) + logger.warning_once("head size is padded to %d", padded_head_size) # We adjust the usable memory size for the KV cache to prevent OOM # errors, even after padding the head_size. - tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size // - padded_head_size) + tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) def execute_model( @@ -250,9 +259,8 @@ class TPUWorker: scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - # every worker's output is needed when kv_transfer_group is setup - return output if self.is_driver_worker or has_kv_transfer_group( - ) else None + # every worker's output is needed when kv_transfer_group is set up + return output if self.is_driver_worker or has_kv_transfer_group() else None def profile(self, is_start: bool = True): if self.rank < 1: @@ -285,6 +293,9 @@ class TPUWorker: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -325,13 +336,20 @@ class TPUWorker: backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) ensure_kv_transfer_initialized(vllm_config) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() -if USE_TPU_COMMONS: - from tpu_commons.worker import TPUWorker as TPUCommonsWorker + def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: + """Apply a function on the model inside this worker.""" + return fn(self.get_model()) - TPUWorker = TPUCommonsWorker # type: ignore + +if USE_TPU_INFERENCE: + from tpu_inference.worker import TPUWorker as TpuInferenceWorker + + TPUWorker = TpuInferenceWorker # type: ignore diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py new file mode 100644 index 0000000000000..ef22977e094b2 --- /dev/null +++ b/vllm/v1/worker/ubatch_utils.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import numpy as np +from typing_extensions import TypeAlias + +from vllm.config import ParallelConfig + + +@dataclass +class UBatchSlice: + request_slice: slice + token_slice: slice + + def is_empty(self) -> bool: + return ( + self.request_slice.start == self.request_slice.stop + or self.token_slice.start == self.token_slice.stop + ) + + @property + def num_tokens(self) -> int: + return self.token_slice.stop - self.token_slice.start + + +UBatchSlices: TypeAlias = list[UBatchSlice] + + +def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: + return (padded_num_tokens // 2) >= 
orig_num_tokens + + +def check_ubatch_thresholds( + config: ParallelConfig, num_tokens: int, uniform_decode: bool +) -> bool: + if not config.enable_dbo: + return False + if uniform_decode: + return num_tokens >= config.dbo_decode_token_threshold + else: + return num_tokens >= config.dbo_prefill_token_threshold + + +def create_ubatch_slices( + num_scheduled_tokens: np.ndarray, split_point: int +) -> UBatchSlices: + # TODO(lucas): Refactor the gpu_model_runner.py so we can pass + # in cu_num_tokens directly (i.e. query_start_loc) + cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) + np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) + + first_ubatch_token_slice = slice(0, split_point) + second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + + # Determine request slices using exclusive stop semantics + # First ubatch includes requests whose tokens overlap [0, split_point) + first_ubatch_req_stop = int( + np.searchsorted(cu_num_tokens, split_point, side="left") + ) + first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + + # Second ubatch starts at the request that contains the split_point + # or the request starting exactly at split_point (if on boundary) + second_ubatch_req_start = int( + np.searchsorted(cu_num_tokens, split_point, side="right") - 1 + ) + second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) + + return [ + UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), + ] diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py new file mode 100644 index 0000000000000..867ce2b930369 --- /dev/null +++ b/vllm/v1/worker/ubatching.py @@ -0,0 +1,222 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import Optional + +import torch + +from vllm import forward_context +from vllm.forward_context import ForwardContext +from vllm.utils import current_stream + +_THREAD_ID_TO_CONTEXT: dict = {} +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] + + +class UBatchContext: + """ + Context manager for micro-batching synchronization using threading events. 
+ """ + + def __init__( + self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + forward_context: ForwardContext, + ready_barrier: threading.Barrier, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default", + ): + self.id = id + self.comm_stream = comm_stream + self.compute_stream = compute_stream + self.forward_context = forward_context + self.ready_barrier = ready_barrier + self.cpu_wait_event = cpu_wait_event + self.cpu_signal_event = cpu_signal_event + self.current_stream = compute_stream + self.gpu_comm_done_event = gpu_comm_done_event + self.gpu_compute_done_event = gpu_compute_done_event + self.schedule = schedule + self.recv_hook = None + + def __enter__(self): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id + _CURRENT_CONTEXTS[self.id] = self + self.ready_barrier.wait() + + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + # Assume we want to start on the compute stream + self.update_stream(self.compute_stream) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _CURRENT_CONTEXTS[self.id] = None + del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + self.maybe_run_recv_hook() + self.cpu_signal_event.set() + self.cpu_wait_event.clear() + return False + + def _restore_context(self): + forward_context._forward_context = self.forward_context + + def update_stream(self, stream): + self.current_stream = stream + if current_stream() != self.current_stream: + torch.cuda.set_stream(self.current_stream) + + def _signal_comm_done(self): + self.gpu_comm_done_event.record(self.comm_stream) + + def _signal_compute_done(self): + self.gpu_compute_done_event.record(self.compute_stream) + + def _wait_compute_done(self): + self.comm_stream.wait_event(self.gpu_compute_done_event) + + def _wait_comm_done(self): + self.compute_stream.wait_event(self.gpu_comm_done_event) + + def _cpu_yield(self): + # It is critical for correctness that only one thread is running + # at a time. 
These asserts just make sure that this is the only + # thread running before waking the other one up and going to sleep + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() + + self.cpu_signal_event.set() + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + + def switch_to_comm(self): + self.update_stream(self.comm_stream) + + def switch_to_compute(self): + self.update_stream(self.compute_stream) + + def switch_to_comm_sync(self): + self._signal_compute_done() + self.update_stream(self.comm_stream) + self._wait_compute_done() + + def switch_to_compute_sync(self): + self._signal_comm_done() + self.update_stream(self.compute_stream) + self._wait_comm_done() + + def maybe_run_recv_hook(self): + if self.recv_hook is not None: + self.recv_hook() + self.recv_hook = None + + def yield_(self): + self.current_stream = current_stream() + self._cpu_yield() + self.update_stream(self.current_stream) + + def yield_and_switch_from_compute_to_comm(self): + assert current_stream() == self.compute_stream + self._signal_compute_done() + self._cpu_yield() + assert self.current_stream == self.compute_stream + self.update_stream(self.comm_stream) + self._wait_compute_done() + + def yield_and_switch_from_comm_to_compute(self): + assert current_stream() == self.comm_stream + self._signal_comm_done() + self._cpu_yield() + assert self.current_stream == self.comm_stream + self.update_stream(self.compute_stream) + self._wait_comm_done() + + +def dbo_enabled() -> bool: + return len(_THREAD_ID_TO_CONTEXT) > 0 + + +def dbo_current_ubatch_id() -> int: + if len(_THREAD_ID_TO_CONTEXT) == 0: + return 0 + return _THREAD_ID_TO_CONTEXT[threading.get_ident()] + + +def _register_ubatch_function(func): + def wrapper(*args, **kwargs): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + func(ctx, *args, **kwargs) + + return wrapper + + +dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) +dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( + UBatchContext.yield_and_switch_from_compute_to_comm +) +dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( + UBatchContext.yield_and_switch_from_comm_to_compute +) +dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm) +dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute) +dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute_sync = _register_ubatch_function( + UBatchContext.switch_to_compute_sync +) + + +def dbo_register_recv_hook(recv_hook): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx.recv_hook = recv_hook + + +def make_ubatch_contexts( + num_micro_batches: int, + compute_stream: torch.cuda.Stream, + comm_stream: torch.cuda.Stream, + forward_contexts: list[ForwardContext], + ready_barrier: threading.Barrier, + schedule: str = "default", +) -> list[UBatchContext]: + assert num_micro_batches == 2, "only been tested with 2 micro-batches" + """ + Create a context manager for micro-batching synchronization. 
+ """ + cpu_events = [threading.Event() for _ in range(num_micro_batches)] + gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] + gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] + + assert len(forward_contexts) == 2 + + ctxs = [] + for i in range(num_micro_batches): + ctx = UBatchContext( + id=i, + compute_stream=compute_stream, + comm_stream=comm_stream, + forward_context=forward_contexts[i], + ready_barrier=ready_barrier, + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule, + ) + ctxs.append(ctx) + + return ctxs diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index ffc1a11bc3ba1..6657a2a8db828 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -7,13 +7,15 @@ from typing import TYPE_CHECKING, Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import ModelConfig, SchedulerConfig +from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -33,14 +35,18 @@ class MultiModalBudget: self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config(model_config, mm_registry) self.max_model_len = model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality( + model_config, cache=cache + ) + ) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, @@ -120,12 +126,39 @@ class MultiModalBudget: return max_items_per_prompt, max_items_per_batch + def reset_cache(self) -> None: + if self.cache is not None: + self.cache.clear_cache() + @dataclass class AttentionGroup: backend: type[AttentionBackend] - metadata_builder: AttentionMetadataBuilder + # When ubatching is enabled we will have a metadata builder for each ubatch + # so that if they use internal persistant buffers for cudagraphs, and they + # won't have to worry about conflicting with the other ubatches. 
+ metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] + kv_cache_spec: KVCacheSpec + + @staticmethod + def create_with_metadata_builders( + backend: type[AttentionBackend], + layer_names: list[str], + kv_cache_spec: KVCacheSpec, + vllm_config: VllmConfig, + device: torch.device, + num_metadata_builders: int = 1, + ) -> "AttentionGroup": + metadata_builders = [ + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) + for _ in range(num_metadata_builders) + ] + return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec) + + def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: + assert len(self.metadata_builders) > ubatch_id + return self.metadata_builders[ubatch_id] def sanity_check_mm_encoder_outputs( @@ -140,19 +173,22 @@ def sanity_check_mm_encoder_outputs( "Expected multimodal embeddings to be a list/tuple of 2D tensors, " f"or a single 3D tensor, but got {type(mm_embeddings)} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert len(mm_embeddings) == expected_num_items, ( "Expected number of multimodal embeddings to match number of " f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert all(e.ndim == 2 for e in mm_embeddings), ( "Expected multimodal embeddings to be a sequence of 2D tensors, " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) def scatter_mm_placeholders( @@ -167,10 +203,10 @@ def scatter_mm_placeholders( Args: embeds: The multimodal embeddings. - Shape: `(num_embeds, embed_dim)` + Shape: `(num_embeds, embed_dim)` is_embed: A boolean mask indicating which positions in the placeholder - tokens need to be filled with multimodal embeddings. - Shape: `(num_placeholders, num_embeds)` + tokens need to be filled with multimodal embeddings. + Shape: `(num_placeholders, num_embeds)` """ if is_embed is None: return embeds @@ -190,7 +226,8 @@ def gather_mm_placeholders( """ Reconstructs the embeddings from the placeholder tokens. - This is the operation of [scatter_mm_placeholders][]. + This is the operation of [`scatter_mm_placeholders`] + [vllm.v1.worker.utils.scatter_mm_placeholders]. """ if is_embed is None: return placeholders @@ -198,12 +235,9 @@ def gather_mm_placeholders( return placeholders[is_embed] -def initialize_kv_cache_for_kv_sharing( +def add_kv_sharing_layers_to_kv_cache_groups( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], - kv_caches: dict[str, torch.Tensor], - # Optional for now to avoid breaking TPU - attn_groups: Optional[list[list[AttentionGroup]]] = None, runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ @@ -218,38 +252,15 @@ def initialize_kv_cache_for_kv_sharing( means this layer will perform attention using the keys and values from the KV cache of `shared_kv_cache_layers[layer_name]`. kv_cache_groups: The KV cache groups of the model. - kv_caches: The allocated kv_caches with layer names as keys. 
- Note that layers in shared_kv_cache_layers.keys() are not - originally included as it only contains layers which have its own - KV cache allocation. - attn_groups: Optional list of attention groups. Layers in the same KV - cache group may be placed in different attention groups if they - have different attention backends. Currently only provided by - GPU model runner. """ - # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx) - layer_to_attn_group_idx: dict[str, tuple[int, int]] = {} - if attn_groups: - for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups): - for attn_group_idx, attn_group in enumerate(kv_attn_groups): - for layer_name in attn_group.layer_names: - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, - attn_group_idx) - else: - for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - # attn group idx default to 0 if not provided - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0) + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: + for layer_name in kv_cache_group.layer_names: + layer_to_kv_cache_group[layer_name] = kv_cache_group for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] - kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0] - kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name) - - if attn_groups: - attn_group_idx = layer_to_attn_group_idx[target_layer_name][1] - attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( - layer_name) + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) if runner_only_attn_layers is not None: runner_only_attn_layers.add(layer_name) @@ -259,6 +270,7 @@ def bind_kv_cache( kv_caches: dict[str, torch.Tensor], forward_context: dict[str, "Attention"], runner_kv_caches: list[torch.Tensor], + num_attn_module: Optional[int] = 1, ) -> None: """ Bind the allocated KV cache to both ModelRunner and forward context so @@ -273,7 +285,7 @@ def bind_kv_cache( Args: kv_caches: The allocated kv_caches with layer names as keys. forward_context: The global forward context containing all Attention - layers with layer names as keys. + layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. """ # Bind kv_caches to ModelRunner @@ -282,7 +294,7 @@ def bind_kv_cache( # Convert kv_caches dict to a list of tensors in the order of layer_index. index2name = defaultdict(list) for layer_name in kv_caches: - index2name[extract_layer_index(layer_name)].append(layer_name) + index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] @@ -290,7 +302,17 @@ def bind_kv_cache( # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer # has different layer_name but the same layer_index. - raise NotImplementedError + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if current_platform.is_cuda() or current_platform.is_xpu(): + # We know that the GPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. 
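To make the encoder-decoder edge case below concrete: two attention modules in the same decoder block map to the same layer index, so grouping by `extract_layer_index` yields more than one layer name per index and `runner_kv_caches` can only keep one of them. A toy illustration (the layer names and the simplified index extractor are assumptions; the real helper lives in vllm.model_executor.models.utils):

from collections import defaultdict

def extract_layer_index(layer_name: str) -> int:
    # Simplified stand-in: take the first numeric path component.
    return next(int(p) for p in layer_name.split(".") if p.isdigit())

# Hypothetical encoder-decoder layer names sharing the same layer index.
kv_caches = {
    "decoder.layers.0.self_attn.attn": "self_attn_cache",
    "decoder.layers.0.encoder_attn.attn": "cross_attn_cache",
}

index2name = defaultdict(list)
for name in kv_caches:
    index2name[extract_layer_index(name)].append(name)

print(dict(index2name))
# {0: ['decoder.layers.0.self_attn.attn', 'decoder.layers.0.encoder_attn.attn']}
# On CUDA/XPU the runner keeps only the first name per index for
# runner_kv_caches; other platforms still raise NotImplementedError.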
+ pass + else: + raise NotImplementedError layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) @@ -298,3 +320,28 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. forward_context[layer_name].kv_cache = [kv_cache] + + +def is_residual_scattered_for_sp( + vllm_config: VllmConfig, num_input_tokens: int +) -> bool: + """Check if the residual tensor is scattered for sequence parallelism. + + The residual tensor is scattered across tensor parallel ranks when sequence + parallelism and tensor parallelism is enabled, and the number of + input tokens is one of the compilation sizes. + """ + if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism: + return False + + tp = vllm_config.parallel_config.tensor_parallel_size + + if tp == 1: + return False + + # When sequence parallelism is enabled, we always pad num_input_tokens + # to be a multiple of tensor_parallel_size (tp) earlier. + assert num_input_tokens % tp == 0 + + # Currently, SP is only enabled for static size fx graphs. + return num_input_tokens in vllm_config.compilation_config.compile_sizes diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 9c93754f93f81..8ee3b240904ca 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -1,23 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union import torch import torch.nn as nn -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config +from vllm.utils import ( + enable_trace_function_call_for_thread, + resolve_obj_by_qualname, + run_method, + update_environment_variables, + warn_for_unimplemented_methods, +) from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) +_R = TypeVar("_R") -class WorkerBase(WorkerBaseV0): - """ - Abstract class for v1 worker, mainly define some methods for v1. - For methods shared by v0 and v1, define them in v0 WorkerBase + +@warn_for_unimplemented_methods +class WorkerBase: + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. """ def __init__( @@ -27,20 +46,34 @@ class WorkerBase(WorkerBaseV0): rank: int, distributed_init_method: str, is_driver_worker: bool = False, - ): + ) -> None: """ Initialize common worker components. 
- + Args: vllm_config: Complete vLLM configuration local_rank: Local device index rank: Global rank in distributed setup distributed_init_method: Distributed initialization method - is_driver_worker: Whether this worker handles driver - responsibilities + is_driver_worker: Whether this worker handles driver + responsibilities """ - # Configuration storage - super().__init__(vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config + self.compilation_config = vllm_config.compilation_config + + from vllm.platforms import current_platform + + self.current_platform = current_platform self.parallel_config.rank = rank self.local_rank = local_rank @@ -49,8 +82,8 @@ class WorkerBase(WorkerBaseV0): self.is_driver_worker = is_driver_worker # Device and model state - self.device: Optional[torch.device] = None - self.model_runner: Optional[nn.Module] = None + self.device: torch.device | None = None + self.model_runner: nn.Module | None = None def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """Get specifications for KV cache implementation.""" @@ -63,3 +96,286 @@ class WorkerBase(WorkerBaseV0): def check_health(self) -> None: """Basic health check (override for device-specific checks).""" return + + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks.""" + raise NotImplementedError + + def reset_mm_cache(self) -> None: + reset_fn = getattr(self.model_runner, "reset_mm_cache", None) + if callable(reset_fn): + reset_fn() + + def get_model(self) -> nn.Module: + raise NotImplementedError + + def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: + """Apply a function on the model inside this worker.""" + return fn(self.get_model()) + + def load_model(self) -> None: + """Load model onto target device.""" + raise NotImplementedError + + def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput: + raise NotImplementedError + + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + raise NotImplementedError("Dead V0 code") + + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + def get_cache_block_size_bytes(self) -> int: + """Return the size of a single cache block, in bytes. 
Used in + speculative decoding. + """ + raise NotImplementedError + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + def list_loras(self) -> set[int]: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + """Get vocabulary size from model configuration.""" + return self.model_config.get_vocab_size() + + def shutdown(self) -> None: + """Clean up resources held by the worker.""" + return + + +class WorkerWrapperBase: + """ + This class represents one process in an executor/engine. It is responsible + for lazily initializing the worker and handling the worker's lifecycle. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__( + self, + vllm_config: VllmConfig, + rpc_rank: int = 0, + ) -> None: + """ + Initialize the worker wrapper with the given vllm_config and rpc_rank. + Note: rpc_rank is the rank of the worker in the executor. In most cases, + it is also the rank of the worker in the distributed group. However, + when multiple executors work together, they can be different. + e.g. in the case of SPMD-style offline inference with TP=2, + users can launch 2 engines/executors, each with only 1 worker. + All workers have rpc_rank=0, but they have different ranks in the TP + group. + """ + self.rpc_rank = rpc_rank + self.worker: WorkerBase | None = None + self.vllm_config: VllmConfig | None = None + # do not store this `vllm_config`, `init_worker` will set the final + # one. TODO: investigate if we can remove this field in + # `WorkerWrapperBase`, `init_cached_hf_modules` should be + # unnecessary now. + if vllm_config.model_config is not None: + # it can be None in tests + trust_remote_code = vllm_config.model_config.trust_remote_code + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + + def shutdown(self) -> None: + if self.worker is not None: + self.worker.shutdown() + + def adjust_rank(self, rank_mapping: dict[int, int]) -> None: + """ + Adjust the rpc_rank based on the given mapping. + It is only used during the initialization of the executor, + to adjust the rpc_rank of workers after we create all workers. + """ + if self.rpc_rank in rank_mapping: + self.rpc_rank = rank_mapping[self.rpc_rank] + + def update_environment_variables( + self, + envs_list: list[dict[str, str]], + ) -> None: + envs = envs_list[self.rpc_rank] + key = "CUDA_VISIBLE_DEVICES" + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. 
+ """ + kwargs = all_kwargs[self.rpc_rank] + self.vllm_config = kwargs.get("vllm_config") + assert self.vllm_config is not None, ( + "vllm_config is required to initialize the worker" + ) + enable_trace_function_call_for_thread(self.vllm_config) + + from vllm.plugins import load_general_plugins + + load_general_plugins() + + if isinstance(self.vllm_config.parallel_config.worker_cls, str): + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls + ) + else: + raise ValueError( + "passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501 + ) + if self.vllm_config.parallel_config.worker_extension_cls: + worker_extension_cls = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_extension_cls + ) + extended_calls = [] + if worker_extension_cls not in worker_class.__bases__: + # check any conflicts between worker and worker_extension_cls + for attr in dir(worker_extension_cls): + if attr.startswith("__"): + continue + assert not hasattr(worker_class, attr), ( + f"Worker class {worker_class} already has an attribute" + f" {attr}, which conflicts with the worker" + f" extension class {worker_extension_cls}." + ) + if callable(getattr(worker_extension_cls, attr)): + extended_calls.append(attr) + # dynamically inherit the worker extension class + worker_class.__bases__ = worker_class.__bases__ + ( + worker_extension_cls, + ) + logger.info( + "Injected %s into %s for extended collective_rpc calls %s", + worker_extension_cls, + worker_class, + extended_calls, + ) + + shared_worker_lock = kwargs.pop("shared_worker_lock", None) + if shared_worker_lock is None: + msg = ( + "Missing `shared_worker_lock` argument from executor. " + "This argument is needed for mm_processor_cache_type='shm'." + ) + + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_processor_cache_type == "shm": + raise ValueError(msg) + else: + logger.warning_once(msg) + + self.mm_receiver_cache = None + else: + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, + MULTIMODAL_REGISTRY, + shared_worker_lock, + ) + + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during worker initialization + self.worker = worker_class(**kwargs) + assert self.worker is not None + + def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: + kv_cache_config = kv_cache_configs[self.rpc_rank] + with set_current_vllm_config(self.vllm_config): + self.worker.initialize_from_config(kv_cache_config) # type: ignore + + def init_device(self): + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during device initialization + self.worker.init_device() # type: ignore + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + try: + # method resolution order: + # if a method is defined in this class, it will be called directly. + # otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker. + return run_method(self, method, args, kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = ( + f"Error executing method {method!r}. 
" + "This might cause deadlock in distributed execution." + ) + logger.exception(msg) + raise e + + def __getattr__(self, attr: str): + return getattr(self.worker, attr) + + def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None: + mm_cache = self.mm_receiver_cache + if mm_cache is None: + return + + for req_data in scheduler_output.scheduled_new_reqs: + req_data.mm_features = mm_cache.get_and_update_features( + req_data.mm_features + ) + + def execute_model( + self, + scheduler_output: SchedulerOutput, + *args, + **kwargs, + ) -> ModelRunnerOutput: + self._apply_mm_cache(scheduler_output) + + assert self.worker is not None + return self.worker.execute_model(scheduler_output, *args, **kwargs) + + def reset_mm_cache(self) -> None: + mm_receiver_cache = self.mm_receiver_cache + if mm_receiver_cache is not None: + mm_receiver_cache.clear_cache() + + assert self.worker is not None + self.worker.reset_mm_cache() diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 59f8d0fcf5bd9..4f82c18da73aa 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner): vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False @@ -31,3 +33,23 @@ class XPUModelRunner(GPUModelRunner): def _sync_device(self) -> None: torch.xpu.synchronize() + + +@contextmanager +def _torch_cuda_wrapper(): + class _EventPlaceholder: + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + # replace cuda APIs with xpu APIs, this should work by default + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.default_stream = torch.xpu.current_stream + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.stream = torch.xpu.stream + yield + finally: + # if anything goes wrong, just patch it with a placeholder + torch.cuda.Event = _EventPlaceholder diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 134d839252653..a1e54628d9ed1 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -11,8 +11,7 @@ from vllm.distributed import get_world_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) @@ -29,8 +28,9 @@ class XPUWorker(Worker): distributed_init_method: str, is_driver_worker: bool = False, ): - super().__init__(vllm_config, local_rank, rank, - distributed_init_method, is_driver_worker) + super().__init__( + vllm_config, local_rank, rank, distributed_init_method, is_driver_worker + ) device_config = self.device_config assert device_config.device_type == "xpu" assert current_platform.is_xpu() @@ -39,8 +39,10 @@ class XPUWorker(Worker): # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = 
envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -59,7 +61,9 @@ class XPUWorker(Worker): with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -75,8 +79,7 @@ class XPUWorker(Worker): # and we don't have any API to get it. so we mark it as 128MB. used_memory = torch.xpu.memory_allocated() non_torch_allocations = 128 * 1024 * 1024 - free_gpu_memory = total_gpu_memory - (used_memory + - non_torch_allocations) + free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations) return free_gpu_memory, total_gpu_memory @torch.inference_mode() @@ -84,7 +87,7 @@ class XPUWorker(Worker): """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks + Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. .. tip:: You may limit the usage of GPU memory @@ -97,10 +100,12 @@ class XPUWorker(Worker): free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() current_allocated_bytes = torch.xpu.memory_allocated() - msg = ("Before memory profiling run, " - f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " - f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + msg = ( + "Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -113,66 +118,73 @@ class XPUWorker(Worker): "Error in memory profiling. " f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + "not properly cleaned up before initializing the vLLM instance." 
+ ) # Get the peak memory allocation recorded by torch peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] torch.xpu.empty_cache() - torch_allocated_bytes = torch.xpu.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = self.xpu_get_mem_info( - )[1] - self.xpu_get_mem_info()[0] + torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info()[1] - self.xpu_get_mem_info()[0] non_torch_allocations = total_allocated_bytes - torch_allocated_bytes if non_torch_allocations > 0: peak_memory += non_torch_allocations available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) + total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory + ) - msg = ("After memory profiling run, " - f"peak memory usage is {peak_memory / 1024**2:.2f} MB," - f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " - f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + msg = ( + "After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB, " + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) return int(available_kv_cache_memory) def init_device(self): - if self.device_config.device.type == "xpu" and current_platform.is_xpu( - ): + if self.device_config.device.type == "xpu" and current_platform.is_xpu(): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) + current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( - self.local_rank).total_memory + self.local_rank + ).total_memory else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") + raise RuntimeError(f"Unsupported device type: {self.device_config.device}") ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd") ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") - ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", - str(self.parallel_config.world_size)) + ENV_LOCAL_WORLD_SIZE = os.getenv( + "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size) + ) os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE os.environ["LOCAL_RANK"] = str(self.local_rank) - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # global all_reduce needed for overall oneccl warm up - torch.distributed.all_reduce(torch.zeros(1).xpu(), - group=get_world_group().device_group) + torch.distributed.all_reduce( + torch.zeros(1).xpu(), group=get_world_group().device_group + ) # Set random seed. 
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner = XPUModelRunner( # type: ignore - self.vllm_config, self.device) + self.vllm_config, self.device + ) diff --git a/vllm/version.py b/vllm/version.py index 6c88b1b5a3bf4..63095f8bce1ea 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -6,9 +6,7 @@ try: except Exception as e: import warnings - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) __version__ = "dev" __version_tuple__ = (0, 0, __version__) diff --git a/vllm/worker/__init__.py b/vllm/worker/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py deleted file mode 100644 index 530907012f704..0000000000000 --- a/vllm/worker/cache_engine.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""CacheEngine class for managing the KV cache.""" -from typing import List - -import torch - -from vllm.attention import get_attn_backend -from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig -from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - get_dtype_size, is_pin_memory_available) - -logger = init_logger(__name__) - - -class CacheEngine: - """Manages the KV cache. - - This class is responsible for initializing and managing the GPU and CPU KV - caches. It also provides methods for performing KV cache operations, such - as swapping and copying. - """ - - def __init__( - self, - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig, - ) -> None: - self.cache_config = cache_config - self.model_config = model_config - self.parallel_config = parallel_config - self.device_config = device_config - - self.head_size = model_config.get_head_size() - # Models like Jamba, have mixed typed layers, E.g Mamba - self.num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - - self.block_size = cache_config.block_size - self.num_gpu_blocks = cache_config.num_gpu_blocks - if self.num_gpu_blocks: - self.num_gpu_blocks //= parallel_config.pipeline_parallel_size - self.num_cpu_blocks = cache_config.num_cpu_blocks - if self.num_cpu_blocks: - self.num_cpu_blocks //= parallel_config.pipeline_parallel_size - - if cache_config.cache_dtype == "auto": - self.dtype = model_config.dtype - else: - self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # Get attention backend. - self.attn_backend = get_attn_backend(self.head_size, - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - model_config.is_attention_free, - use_mla=model_config.use_mla) - - # Initialize the cache. 
- self.gpu_cache = self._allocate_kv_cache( - self.num_gpu_blocks, self.device_config.device_type) - self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - - def _allocate_kv_cache( - self, - num_blocks: int, - device: str, - ) -> List[torch.Tensor]: - """Allocates KV cache on the specified device.""" - kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - pin_memory = is_pin_memory_available() if device == "cpu" else False - kv_cache: List[torch.Tensor] = [] - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - ) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape))) - - # The allocation respects the backend-defined stride order to ensure - # the semantic remains consistent for each backend. We first obtain the - # generic kv cache shape and then permute it according to the stride - # order which could result in a non-contiguous tensor. - kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i] - for i in kv_cache_stride_order) - - for _ in range(self.num_attention_layers): - # null block in CpuGpuBlockAllocator requires at least that - # block to be zeroed-out. - # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros( - kv_cache_allocation_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device).permute(*kv_cache_stride_order) - - # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases - # when entry_shape is higher than 1D - kv_cache.append(layer_kv_cache) - return kv_cache - - def swap_in(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_attention_layers): - self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], - src_to_dst) - - def swap_out(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_attention_layers): - self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], - src_to_dst) - - def copy(self, src_to_dsts: torch.Tensor) -> None: - self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) - - @staticmethod - def get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - ) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - key_cache_entry = num_heads * head_size - - # For MLA there is no value cache, since the latent vector - # is joint keys and values. 
- value_cache_entry = key_cache_entry if not model_config.use_mla else 0 - total = num_attention_layers * cache_config.block_size * \ - (key_cache_entry + value_cache_entry) - - dtype_size = get_dtype_size(dtype) - return dtype_size * total diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py deleted file mode 100644 index cb5d5664ab5c0..0000000000000 --- a/vllm/worker/enc_dec_model_runner.py +++ /dev/null @@ -1,554 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import itertools -from typing import Any, Dict, List, Optional, Tuple, Type, cast - -import torch -import torch.distributed - -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.attention.selector import (get_env_variable_attn_backend, - get_global_forced_attn_backend) -from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - MultiModalRegistry) -from vllm.platforms import _Backend -from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, - SequenceGroupMetadata) -from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad -from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) -from vllm.worker.model_runner_base import ( - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict) -from vllm.worker.utils import assert_enc_dec_mr_supported_scenario - -logger = init_logger(__name__) -LORA_WARMUP_RANK = 8 - - -@dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): - """ - Used by the EncoderDecoderModelRunner. 
- """ - encoder_input_tokens: Optional[torch.Tensor] = None - encoder_input_positions: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "encoder_input_tokens": self.encoder_input_tokens, - "encoder_input_positions": self.encoder_input_positions, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInput": - return cast( - EncoderDecoderModelInput, - super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) - - -class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): - _model_input_cls: Type[EncoderDecoderModelInput] = ( - EncoderDecoderModelInput) - _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - ''' - EncoderDecoderModelRunner constructor. - - `lora_config` is unused (since these features are not yet supported - for encoder/decoder models) but these arguments are present here for - compatibility with the base-class constructor. - ''' - self._maybe_force_supported_attention_backend() - - super().__init__( - vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - input_registry=input_registry, - mm_registry=mm_registry, - ) - - # Crash for unsupported encoder/scenarios - assert_enc_dec_mr_supported_scenario(self) - - def _maybe_force_supported_attention_backend(self): - ''' - Force vLLM to use the XFormers attention backend, - which is currently the only supported option. - ''' - - def raise_backend_err(): - # The user has specified an attention backend override - # which is invalid for encoder/decoder models - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) - - maybe_env_var_forced_backend = get_env_variable_attn_backend() - maybe_global_forced_backend = get_global_forced_attn_backend() - is_forced_by_global = maybe_global_forced_backend is not None - is_forced_by_env_var = maybe_env_var_forced_backend is not None - if is_forced_by_global: # noqa: SIM102 - # Backend override enforced by global variable takes - # precedence over vLLM backend environment variable. 
- if maybe_global_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: - raise_backend_err() - elif is_forced_by_env_var: # noqa: SIM102 - # Backend override enforced by vLLM backend - # environment variable - if maybe_env_var_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: - raise_backend_err() - - def _list_to_int32_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.int32, device=self.device) - - def _list_to_long_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.long, device=self.device) - - def _empty_int32_tensor(self) -> torch.Tensor: - return self._list_to_int32_tensor([]) - - def _empty_long_tensor(self) -> torch.Tensor: - return self._list_to_long_tensor([]) - - @torch.inference_mode() - def execute_model( - self, - model_input: EncoderDecoderModelInput, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[PoolerOutput]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in " - "EncoderDecoderModelRunner") - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - if (model_input.attn_metadata is not None - and model_input.attn_metadata.prefill_metadata is None - and model_input.attn_metadata.decode_metadata.use_cuda_graph): - if model_input.inputs_embeds is None: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, False)]) - else: - graph_batch_size = model_input.inputs_embeds.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, True)]) - else: - model_executable = self.model - - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata, self.vllm_config, - model_input.virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **seqlen_agnostic_kwargs, - ) - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if not self.is_driver_worker: - return [] - - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. 
- output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: - return EncoderDecoderModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> EncoderDecoderModelInput: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - Since chunked prefill is not supported for encoder/decoder models, - `input_tokens` is assumed to be either entirely prefill tokens or - entirely decode tokens. - - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input)) - # Inject attn_metadata encoder/cross-attention fields & - # encoder input tokens/positions into model_input. - # Frozen dataclass fields cannot be modified, so use - # dataclasses.replace to construct a new model input - # instance. - model_input = dataclasses.replace( - model_input, - attn_metadata=attn_metadata, - encoder_input_tokens=encoder_input_tokens_tensor, - encoder_input_positions=encoder_input_positions_tensor, - ) - - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory, - generators=generators) - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - - # This represents the maximum number of different requests - # that will have unique loras, and therefore the max amount of - # memory consumption. Create dummy lora request copies from the - # lora request passed in, which contains a lora from the lora - # warmup path. - dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - dummy_lora_requests = self._add_dummy_loras( - self.lora_config.max_loras) - assert len(dummy_lora_requests) == self.lora_config.max_loras - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. 
- seqs: List[SequenceGroupMetadata] = [] - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - logger.info("Starting profile run for multi-modal models.") - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - decoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=False) - encoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) - - # Having more tokens is over-conservative but otherwise fine - assert len( - decoder_dummy_data.seq_data.prompt_token_ids - ) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" - ) - - assert decoder_dummy_data.multi_modal_data is None or \ - encoder_dummy_data.multi_modal_data is None, ( - "Multi-modal data can't be provided in both encoder and decoder" - ) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: decoder_dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - encoder_seq_data=encoder_dummy_data.seq_data, - cross_block_table=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=decoder_dummy_data.multi_modal_data - or encoder_dummy_data.multi_modal_data, - multi_modal_placeholders=decoder_dummy_data. - multi_modal_placeholders - or encoder_dummy_data.multi_modal_placeholders) - seqs.append(seq) - - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - self.execute_model(model_input, None, intermediate_tensors) - torch.cuda.synchronize() - return - - def _prepare_encoder_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput, - ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], - Optional[torch.Tensor]]: - """Helper method to prepare the encoder- and cross-attn-related - model inputs based on a given sequence group. These additional inputs - are used to augment an already-computed `EncoderDecoderModelInput` - data structure which already has decoder-related model inputs - populated. - - Sets the following attn_metadata fields: - * `num_encoder_tokens` - * `encoder_seq_lens` - * `encoder_seq_lens_tensor` - * `max_encoder_seq_len` - * `cross_slot_mapping` - * `cross_block_tables` - - Constructs a new model inputs data structure, based on - (1) the existing fields in the `model_inputs` argument, - and (2) the following additional fields which are - computed (or in the case of `attn_metadata`, updated) - by this function: - * attn_metadata - * encoder_input_tokens - * encoder_input_positions - - Arguments: - - * seq_group_metadata_list: list of sequence groups for which to - compute inputs - * model_inputs: model inputs data structure with decoder-oriented - fields already computed. 
- - Return: - - * Updated model inputs data structure - """ - - if len(seq_group_metadata_list) == 0: - return (model_input.attn_metadata, None, None) - - # Since we are not supporting chunked prefill either the entire - # batch is prefill or it is decode - is_prompt = seq_group_metadata_list[0].is_prompt - - # Build encoder inputs - encoder_seq_lens: List[int] = [] - if is_prompt: - # Prefill phase. - cross_block_tables = self._empty_int32_tensor().view( - len(seq_group_metadata_list), -1) - - # Extract input tokens/positions, cross-attention slot-mapping, - # & seq len from each sequence group metadata - ( - encoder_input_tokens, - encoder_input_positions, - cross_slot_mapping, - ) = ( - [], - [], - [], - ) - for seq_group_metadata in seq_group_metadata_list: - # Build seq lens - seq_len = seq_group_metadata.encoder_seq_data.get_len() - token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() - encoder_seq_lens.append(seq_len) - - # Build slot mapping - is_profile_run = (seq_group_metadata.block_tables is None) - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) - else: - for i in range(0, seq_len): - block_number = seq_group_metadata.cross_block_table[ - i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - cross_slot_mapping.append(slot) - - # Build encoder input tokens - encoder_input_tokens.extend(token_ids) - encoder_input_positions.extend(list(range(0, seq_len))) - - # Convert tokens/positions & cross-attention - # slot-mapping to encoder input tensors - encoder_input_tokens_tensor = self._list_to_long_tensor( - encoder_input_tokens) - encoder_input_positions_tensor = self._list_to_long_tensor( - encoder_input_positions) - cross_slot_mapping_tensor = self._list_to_long_tensor( - cross_slot_mapping) - - else: - # Decode phase. - encoder_input_tokens_tensor = self._empty_long_tensor() - encoder_input_positions_tensor = self._empty_long_tensor() - cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & - # seq len from each sequence group metadata. - # Cross-attention block tables are empty - # during vLLM memory profiling. - cross_block_tables = [] - for seq_group_metadata in seq_group_metadata_list: - for _ in range(len(seq_group_metadata.seq_data)): - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) - - if (model_input.attn_metadata is not None - and model_input.attn_metadata.use_cuda_graph): - # We will be using CUDA graph replay for this decode. - max_len_of_block_table = self.get_max_block_per_batch() - batch_size = len(encoder_seq_lens) - graph_batch_size = self.vllm_config.pad_for_cudagraph( - batch_size) - assert graph_batch_size >= batch_size - cuda_graph_pad_size = graph_batch_size - batch_size - # extend the cross_block_tables and encoder_seq_lens to match - # the graph_batch_size. 
- cross_block_tables.extend([[] - for _ in range(cuda_graph_pad_size) - ]) - encoder_seq_lens.extend( - itertools.repeat(1, cuda_graph_pad_size)) - - else: - max_len_of_block_table = max( - len(block_table) for block_table in cross_block_tables) - - cross_block_tables = make_tensor_with_pad( - cross_block_tables, - max_len=max_len_of_block_table, - pad=0, - dtype=torch.int32, - device=self.device, - ) - - # Compute encoder sequence lengths & encoder - # sequence starting offset tensors - max_encoder_seq_len = max(encoder_seq_lens, default=0) - encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + - 1, - dtype=torch.int32, - device=self.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - # Update attention metadata with encoder-oriented attributes - attn_metadata = model_input.attn_metadata - assert attn_metadata is not None - ( - attn_metadata.num_encoder_tokens, - attn_metadata.encoder_seq_lens, - attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.cross_slot_mapping, - attn_metadata.cross_block_tables, - ) = ( - sum(encoder_seq_lens), - encoder_seq_lens, - encoder_seq_lens_tensor, - max_encoder_seq_len, - encoder_seq_start_loc, - cross_slot_mapping_tensor, - cross_block_tables, - ) - - return (attn_metadata, encoder_input_tokens_tensor, - encoder_input_positions_tensor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py deleted file mode 100644 index a1c08fa814db4..0000000000000 --- a/vllm/worker/model_runner.py +++ /dev/null @@ -1,2043 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import gc -import inspect -import itertools -import time -import weakref -from contextlib import contextmanager -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -from tqdm.auto import tqdm - -import vllm.envs as envs -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState -from vllm.compilation.counter import compilation_counter -from vllm.config import CompilationLevel, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.distributed.kv_transfer import get_kv_transfer_group -from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, - graph_capture) -from vllm.forward_context import get_forward_context, set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata, SamplingMetadataCache -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, - get_sampler) -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from 
vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap, - MultiModalRegistry) -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, - async_tensor_h2d, flatten_2d_lists, - is_pin_memory_available, supports_dynamo, - weak_ref_tensor) -from vllm.worker.model_runner_base import ( - InputProcessingError, ModelRunnerBase, ModelRunnerInputBase, - ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -LORA_WARMUP_RANK = 8 - -_NUM_WARMUP_ITERS = 2 - -TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") - -# For now, bump up cache limits for recompilations during CUDA graph warmups. -torch._dynamo.config.cache_size_limit = 128 -torch._dynamo.config.accumulated_cache_size_limit = 128 - - -@dataclass(frozen=True) -class ModelInputForGPU(ModelRunnerInputBase): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. - """ - input_tokens: Optional[torch.Tensor] = None - inputs_embeds: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - token_types: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None - finished_requests_ids: Optional[List[str]] = None - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - previous_hidden_states: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForGPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForGPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - # Exclude `async_callback` to be able to pickle this object - def __getstate__(self): - state = self.__dict__.copy() - del state["async_callback"] - return state - - # TODO: What happens when we depickle this object? 
- # How can we update this callback to properly pass it to the engine? - def __setstate__(self, state): - self.__dict__.update(state) - self.__dict__.update({'async_callback': None}) - - -@dataclass(frozen=True) -class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - # Used for speculative decoding. We do not broadcast it because it is only - # used by the driver worker. - is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForGPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): - """Build ModelInputForGPU from SequenceGroupMetadata.""" - - # Note: ideally we would be using a dataclass(kw_only=True) - # here, so that this can be subclassed easily, - # but kw_only is not supported in python<3.10. - class InterDataForSeqGroup: - """Intermediate data for the current sequence group.""" - - def simple_reinit(self): - self.input_tokens[0].clear() # type: ignore - self.inputs_embeds = None # type: ignore - self.input_positions[0].clear() # type: ignore - self.token_types[0].clear() # type: ignore - self.mrope_input_positions = None # type: ignore - self.seq_lens[0] = 0 # type: ignore - self.orig_seq_lens[0] = 0 # type: ignore - self.prompt_lens[0] = 0 # type: ignore - self.query_lens[0] = 0 # type: ignore - self.context_lens[0] = 0 # type: ignore - self.curr_sliding_window_blocks[0] = 0 # type: ignore - self.lora_index_mapping.clear() # type: ignore - self.lora_prompt_mapping.clear() # type: ignore - self.lora_requests.clear() # type: ignore - - def __init__( - self, - *, - # From sequence group metadata. - request_id: str, - seq_ids: List[int], - is_prompt: bool, - block_tables: Optional[Dict[int, List[int]]], - computed_block_nums: List[int], - n_seqs: int = 0, - - # Input tokens and positions. - input_tokens: Optional[List[List[int]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - input_positions: Optional[List[List[int]]] = None, - token_types: Optional[List[List[int]]] = None, - mrope_input_positions: Optional[List[List[List[int]]]] = None, - - # The sequence length (may be capped to the sliding window). - seq_lens: Optional[List[int]] = None, - # The original sequence length (before applying sliding window). - # This is used to compute slot mapping. - orig_seq_lens: Optional[List[int]] = None, - # This is used in the dual-chunk flash attention backend. - prompt_lens: Optional[List[int]] = None, - # The query length. 
- query_lens: Optional[List[int]] = None, - # The number of tokens that are already computed. - context_lens: Optional[List[int]] = None, - # The current sliding window block. - curr_sliding_window_blocks: Optional[List[int]] = None, - - # LoRA inputs. - lora_index_mapping: Optional[List[List[int]]] = None, - lora_prompt_mapping: Optional[List[List[int]]] = None, - lora_requests: Optional[Set[LoRARequest]] = None, - - # Multi-modal inputs. - multi_modal_kwargs: Optional[MultiModalKwargs] = None, - multi_modal_placeholder_maps: Optional[Dict[ - str, MultiModalPlaceholderMap]] = None, - - # Whether the prefix cache is hit (prefill only). - prefix_cache_hit: bool = False, - reinit: bool = False, - reinit_use_defaults: bool = False, - encoder_seq_len: int = 0, - ): - if reinit: - assert len(self.seq_ids) == len(seq_ids) # type: ignore - for i, seq_id in enumerate(seq_ids): - self.seq_ids[i] = seq_id # type: ignore - else: - self.seq_ids = seq_ids - - self.request_id = request_id - self.is_prompt = is_prompt - self.block_tables = block_tables - self.computed_block_nums = computed_block_nums - self.n_seqs = n_seqs - self.encoder_seq_len = encoder_seq_len - - if reinit: - if len(self.seq_ids) == 1 and reinit_use_defaults: - self.simple_reinit() - else: - if input_tokens: - self.input_tokens = input_tokens - else: - for seq_id in range(len(self.seq_ids)): - self.input_tokens[seq_id].clear() - - self.inputs_embeds = inputs_embeds - - if input_positions: - self.input_positions = input_positions - else: - for seq_id in range(len(self.seq_ids)): - self.input_positions[seq_id].clear() - - if token_types: - self.token_types = token_types - else: - for seq_id in range(len(self.seq_ids)): - self.token_types[seq_id].clear() - - self.mrope_input_positions = None - - if seq_lens: - self.seq_lens = seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.seq_lens[seq_id] = 0 - - if orig_seq_lens: - self.orig_seq_lens = orig_seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.orig_seq_lens[seq_id] = 0 - - if prompt_lens: - self.prompt_lens = prompt_lens - else: - for seq_id in range(len(self.seq_ids)): - self.prompt_lens[seq_id] = 0 - - if query_lens: - self.query_lens = query_lens - else: - for seq_id in range(len(self.seq_ids)): - self.query_lens[seq_id] = 0 - - if context_lens: - self.context_lens = context_lens - else: - for seq_id in range(len(self.seq_ids)): - self.context_lens[seq_id] = 0 - - if curr_sliding_window_blocks: - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks - else: - for seq_id in range(len(self.seq_ids)): - self.curr_sliding_window_blocks[seq_id] = 0 - - if lora_index_mapping: - self.lora_index_mapping = lora_index_mapping - else: - self.lora_index_mapping.clear() - - if lora_prompt_mapping: - self.lora_prompt_mapping = lora_prompt_mapping - else: - self.lora_prompt_mapping.clear() - - if lora_requests: - self.lora_requests = lora_requests - else: - self.lora_requests.clear() - - else: - self.input_tokens = input_tokens or [] - self.inputs_embeds = inputs_embeds - self.input_positions = input_positions or [] - self.token_types = token_types or [] - self.mrope_input_positions = mrope_input_positions or None - self.seq_lens = seq_lens or [] - self.orig_seq_lens = orig_seq_lens or [] - self.prompt_lens = prompt_lens or [] - self.query_lens = query_lens or [] - self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks or [] - - self.lora_index_mapping = lora_index_mapping or [] - self.lora_prompt_mapping 
= lora_prompt_mapping or [] - self.lora_requests = lora_requests or set() - - self.multi_modal_kwargs = multi_modal_kwargs - self.multi_modal_placeholder_maps = multi_modal_placeholder_maps - self.prefix_cache_hit = prefix_cache_hit - - self.n_seqs = len(self.seq_ids) - - if not reinit: - self.__post_init__() - - def __post_init__(self): - self.n_seqs = len(self.seq_ids) - - self.input_tokens = [[] for _ in range(self.n_seqs)] - self.input_positions = [[] for _ in range(self.n_seqs)] - self.token_types = [[] for _ in range(self.n_seqs)] - self.mrope_input_positions = None - self.seq_lens = [0] * self.n_seqs - self.orig_seq_lens = [0] * self.n_seqs - self.prompt_lens = [0] * self.n_seqs - self.query_lens = [0] * self.n_seqs - self.context_lens = [0] * self.n_seqs - self.curr_sliding_window_blocks = [0] * self.n_seqs - - self.lora_index_mapping = [] - self.lora_prompt_mapping = [] - - def __repr__(self) -> str: - return (f"InterDataForSeqGroup(" - f"request_id={self.request_id}, " - f"seq_ids={self.seq_ids}, " - f"is_prompt={self.is_prompt}, " - f"block_tables={self.block_tables}, " - f"computed_block_nums={self.computed_block_nums}, " - f"n_seqs={self.n_seqs}, " - f"input_tokens={self.input_tokens}, " - f"inputs_embeds.shape=" - f"{getattr(self.inputs_embeds, 'shape', None)}, " - f"input_positions={self.input_positions}, " - f"token_types={self.token_types}, " - f"mrope_input_positions={self.mrope_input_positions}, " - f"seq_lens={self.seq_lens}, " - f"orig_seq_lens={self.orig_seq_lens}, " - f"query_lens={self.query_lens}, " - f"context_lens={self.context_lens}, " - f"multi_modal_kwargs={self.multi_modal_kwargs}") - - def gen_inter_data_builder(self, num_seqs: int): - return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( - request_id="", - seq_ids=[0] * num_seqs, - is_prompt=True, - block_tables=None, - computed_block_nums=[]) - - def init_cached_inter_data(self, *args, **kwargs): - assert len(args) == 0 - assert "seq_ids" in kwargs - seq_ids = kwargs["seq_ids"] - num_seqs = len(seq_ids) - - # The inter-data cache is per model_runner - inter_data_cache = self.runner.inter_data_cache - if num_seqs not in inter_data_cache: - inter_data_cache[num_seqs] = PyObjectCache( - self.gen_inter_data_builder(num_seqs)) - - obj = inter_data_cache[num_seqs].get_object() - obj.__init__(*args, **kwargs) - return obj - - def reset_cached_inter_data(self): - for cache in self.runner.inter_data_cache.values(): - cache.reset() - - def __init__(self, - runner: "GPUModelRunnerBase", - finished_requests_ids: Optional[List[str]] = None): - super().__init__() - # Compute functions for each sequence in a sequence group. - # WARNING: The order of the functions matters! - self.per_seq_compute_fns = [ - self._compute_lens, - self._compute_for_prefix_cache_hit, - self._compute_for_sliding_window, - self._compute_lora_input, - ] - # Compute functions for each sequence group. - # WARNING: The order of the functions matters! - self.per_seq_group_compute_fns = [ - self._compute_multi_modal_input, - ] - - self.runner = runner - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.scheduler_config = self.runner.scheduler_config - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.enable_lora = self.runner.lora_config is not None - - # Attention metadata inputs. - if self.attn_backend is not None: - # spec decode (e.g. 
Medusa) does not have atten backend - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) - - # Engine/Model configurations. - self.chunked_prefill_enabled = ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) - if self.sliding_window is not None: - self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ - self.sliding_window_blocks * self.block_size - - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.finished_requests_ids = finished_requests_ids - - # if the current batch is decode-only. - # will be set to False if there is any non-decode request. - self.decode_only = True - - # Intermediate data (data in CPU before going to GPU) for - # the current sequence group. - self.inter_data_list: List[ - ModelInputForGPUBuilder.InterDataForSeqGroup] = [] - - self.attn_metadata_builder.prepare() - - def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Compute context length, sequence length and tokens - for the given sequence data. - """ - seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] - token_chunk_size = seq_group_metadata.token_chunk_size - - # Compute context length (the number of tokens that are - # already computed) and sequence length (total number of tokens). - - seq_len = seq_data.get_len() - if inter_data.is_prompt: - context_len = seq_data.get_num_computed_tokens() - seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.model_config.is_encoder_decoder: - context_len = seq_len - 1 - else: - context_len = seq_data.get_num_computed_tokens() - - # Compute tokens. - if seq_data.prompt_embeds is None: - tokens = seq_data.get_token_ids()[context_len:seq_len] - prompt_embeds = None - else: - tokens = [0] * (seq_len - context_len) - prompt_embeds = seq_data.get_token_embeddings( - )[context_len:seq_len] - - token_types = seq_group_metadata.token_type_ids - - inter_data.seq_lens[seq_idx] = seq_len - inter_data.orig_seq_lens[seq_idx] = seq_len - inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() - inter_data.context_lens[seq_idx] = context_len - inter_data.input_tokens[seq_idx].extend(tokens) - inter_data.inputs_embeds = prompt_embeds - inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) - inter_data.token_types[seq_idx].extend( - token_types if token_types else []) - inter_data.query_lens[seq_idx] = seq_len - context_len - - if seq_data.mrope_position_delta is not None: - if inter_data.mrope_input_positions is None: - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - - inter_data.mrope_input_positions[ - seq_idx] = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - - def _compute_for_prefix_cache_hit( - self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Check if hit prefix cache (i.e., some blocks are already computed). - If hit, update input tokens and positions to only compute the - remaining blocks. - """ - computed_block_nums = inter_data.computed_block_nums - - # Note that prefix caching does not support sliding window. 
- prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and inter_data.is_prompt) - inter_data.prefix_cache_hit = prefix_cache_hit - - if not prefix_cache_hit: - return - - assert computed_block_nums is not None - # The cache hit prompt tokens in this sequence. Note that - # this may be larger than the sequence length if chunked - # prefill is enabled. - prefix_cache_len = len(computed_block_nums) * self.block_size - seq_group_metadata.seq_data[inter_data.seq_ids[ - seq_idx]].update_num_cached_tokens(prefix_cache_len) - - # The number of so far computed prompt tokens in this sequence. - context_len = inter_data.context_lens[seq_idx] - # The total number of prompt tokens in this sequence. - # When chunked prefill is enabled, this is the token number of - # computed chunks + current chunk. - seq_len = inter_data.seq_lens[seq_idx] - if prefix_cache_len <= context_len: - # We already passed the cache hit region, - # so do normal computation. - pass - elif context_len < prefix_cache_len < seq_len: - # Partial hit. Compute the missing part. - uncomputed_start = prefix_cache_len - context_len - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][uncomputed_start:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][uncomputed_start:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - uncomputed_start:] - context_len = prefix_cache_len - - inter_data.context_lens[seq_idx] = context_len - inter_data.query_lens[ - seq_idx] = inter_data.seq_lens[seq_idx] - context_len - elif seq_len <= prefix_cache_len: - # Full hit. Only compute the last token to avoid - # erroneous behavior. FIXME: Ideally we should directly - # mark all tokens as computed in the scheduler and do not - # schedule this sequence, so this case should not happen. - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][-1:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][-1:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - -1:] - inter_data.query_lens[seq_idx] = 1 - inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 - - def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Update seq_len and curr_sliding_window_block for the given - sequence data (only required by decoding) if sliding window is enabled. - """ - curr_sliding_window_block = 0 - sliding_seq_len = inter_data.seq_lens[seq_idx] - if not inter_data.is_prompt and self.sliding_window is not None: - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. 
- curr_sliding_window_block = self.sliding_window_blocks - # number of elements in last block - suff_len = inter_data.seq_lens[seq_idx] % self.block_size - sliding_seq_len = min(inter_data.seq_lens[seq_idx], - self.block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_block += 1 - - inter_data.curr_sliding_window_blocks[ - seq_idx] = curr_sliding_window_block - inter_data.seq_lens[seq_idx] = sliding_seq_len - - def _compute_lora_input(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """If LoRA is enabled, compute LoRA index and prompt mapping.""" - if not self.enable_lora: - return - - lora_id = seq_group_metadata.lora_int_id - if lora_id > 0: - inter_data.lora_requests.add(seq_group_metadata.lora_request) - query_len = inter_data.query_lens[seq_idx] - inter_data.lora_index_mapping.append([lora_id] * query_len) - sampling_params = seq_group_metadata.sampling_params - if sampling_params and sampling_params.prompt_logprobs is not None: - inter_data.lora_prompt_mapping.append([lora_id] * query_len) - elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample: - inter_data.lora_prompt_mapping.append([lora_id]) - else: - inter_data.lora_prompt_mapping.append([]) - - def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If multi-modal data is given, add it to the input.""" - # NOTE: mm_kwargs only includes the subset of multi-modal items that - # intersect with the current prefill positions. - positions = inter_data.input_positions[0] - mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, - range(positions[0], positions[0] + len(positions))) - - # M-RoPE requires mrope_positions even for plain text; return early - # when mm_kwargs is empty only if inter_data.is_prompt is False. - if not mm_kwargs and not inter_data.is_prompt: - return - - inter_data.multi_modal_kwargs = mm_kwargs - inter_data.multi_modal_placeholder_maps = placeholder_maps - - # special processing for mrope position deltas. 
- if self.runner.model_config.uses_mrope: - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", - None) - - second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) - use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) - hf_config = self.runner.model_config.hf_config - - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - for seq_idx in range(inter_data.n_seqs): - seq_data = seq_group_metadata.seq_data[ - inter_data.seq_ids[seq_idx]] - token_ids = seq_data.get_token_ids() - - mrope_input_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=inter_data.context_lens[seq_idx], - seq_len=inter_data.seq_lens[seq_idx], - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - seq_data.mrope_position_delta = mrope_position_delta - inter_data.mrope_input_positions[ - seq_idx] = mrope_input_positions - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - """Add a sequence group to the builder.""" - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - - encoder_seq_len = 0 - - if self.runner.model_config.is_encoder_decoder: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) - - def _use_captured_graph(self, - batch_size: int, - decode_only: bool, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> bool: - return (decode_only and not self.runner.model_config.enforce_eager - and max_decode_seq_len <= self.runner.max_seq_len_to_capture - and max_encoder_seq_len <= self.runner.max_seq_len_to_capture - and batch_size <= self.runner.max_batchsize_to_capture) - - def _get_cuda_graph_pad_size(self, - num_seqs: int, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> int: - """ - Determine the number of padding sequences required for running in - CUDA graph mode. Returns -1 if CUDA graphs cannot be used. - - In the multi-step + chunked-prefill case, only the first step - has Prefills (if any). The rest of the steps are guaranteed to be all - decodes. In this case, we set up the padding as if all the sequences - are decodes so we may run all steps except the first step in CUDA graph - mode. - - Args: - num_seqs (int): Number of sequences scheduled to run. - max_decode_seq_len (int): Greatest of all the decode sequence - lengths. Used only in checking the viablility of using - CUDA graphs. - max_encoder_seq_len (int, optional): Greatest of all the encode - sequence lengths. Defaults to 0. Used only in checking the - viability of using CUDA graphs. 
- Returns: - int: Returns the determined number of padding sequences. If - CUDA graphs is not viable, returns -1. - """ - decode_only = self.decode_only - if not decode_only: - # Early exit so we can treat num_seqs as the batch_size below. - return -1 - - # batch_size out of this function refers to the number of input - # tokens being scheduled. This conflation of num_seqs as batch_size - # is valid as this is a decode-only case. - batch_size = num_seqs - if not self._use_captured_graph(batch_size, decode_only, - max_decode_seq_len, - max_encoder_seq_len): - return -1 - - graph_batch_size = self.runner.vllm_config.pad_for_cudagraph( - batch_size) - assert graph_batch_size >= batch_size - return graph_batch_size - batch_size - - def build(self) -> ModelInputForGPU: - """Finalize the builder intermediate data and - create on-device tensors. - """ - # Combine and flatten intermediate data. - input_tokens = list[int]() - inputs_embeds_list = list[torch.Tensor]() - token_types = list[int]() - for inter_data in self.inter_data_list: - for cur_input_tokens in inter_data.input_tokens: - input_tokens.extend(cur_input_tokens) - for cur_token_types in inter_data.token_types: - token_types.extend(cur_token_types) - if inter_data.inputs_embeds is not None: - inputs_embeds_list.append( - inter_data.inputs_embeds.to( - dtype=self.runner.model_config.dtype, - device=self.runner.device)) - inputs_embeds: Optional[torch.Tensor] - if len(inputs_embeds_list) == 0: - inputs_embeds = None - else: - inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to( - dtype=self.runner.model_config.dtype, - device=self.runner.device) - assert len(inputs_embeds) == len(input_tokens) - - if not input_tokens and inputs_embeds is None: - # This may happen when all prefill requests hit - # prefix caching and there is no decode request. - return self.model_input_cls() - - mrope_input_positions: Optional[List[List[int]]] = None - if any(inter_data.mrope_input_positions is not None - for inter_data in self.inter_data_list): - mrope_input_positions = [[] for _ in range(3)] - for idx in range(3): - for inter_data in self.inter_data_list: - msections = inter_data.mrope_input_positions - if msections is None: - for _seq_input_positions in inter_data.input_positions: - mrope_input_positions[idx].extend( - _seq_input_positions) - else: - for _seq_mrope_input_positions in msections: - mrope_input_positions[idx].extend( - _seq_mrope_input_positions[idx]) - input_positions = None - else: - input_positions = [] - for inter_data in self.inter_data_list: - for cur_input_positions in inter_data.input_positions: - input_positions.extend(cur_input_positions) - - seq_lens = [] - query_lens = [] - max_decode_seq_len = 0 - max_encoder_seq_len = 0 - for inter_data in self.inter_data_list: - seq_lens.extend(inter_data.seq_lens) - query_lens.extend(inter_data.query_lens) - if not inter_data.is_prompt: - max_decode_seq_len = max(max_decode_seq_len, - max(inter_data.seq_lens)) - if self.runner.model_config.is_encoder_decoder: - max_encoder_seq_len = max(max_encoder_seq_len, - inter_data.encoder_seq_len) - - # Mapping from request IDs to sequence IDs. Used for Jamba models - # that manages the cache by itself. 
- request_ids_to_seq_ids = { - data.request_id: data.seq_ids - for data in self.inter_data_list - } - - cuda_graph_pad_size = self._get_cuda_graph_pad_size( - num_seqs=len(seq_lens), - max_decode_seq_len=max_decode_seq_len, - max_encoder_seq_len=max_encoder_seq_len) - - batch_size = len(input_tokens) - if cuda_graph_pad_size != -1: - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - batch_size += cuda_graph_pad_size - - # Tokens and positions. - if cuda_graph_pad_size: - input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) - assert self.runner.device is not None - input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, - self.runner.device, - self.runner.pin_memory) - - token_types_tensor = async_tensor_h2d(token_types, torch.long, - self.runner.device, - self.runner.pin_memory) \ - if token_types else None - - if mrope_input_positions is not None: - for idx in range(3): - mrope_input_positions[idx].extend( - itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(mrope_input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) - else: - input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) - input_positions_tensor = async_tensor_h2d(input_positions, - torch.long, - self.runner.device, - self.runner.pin_memory) - # Sequence and query lengths. - if cuda_graph_pad_size: - seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) - - # Attention metadata. - attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, cuda_graph_pad_size, batch_size) - - # LoRA data. - lora_requests = set() - lora_mapping = None - if self.enable_lora: - lora_requests = set(r for data in self.inter_data_list - for r in data.lora_requests) - lora_index_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_index_mapping) - for inter_data in self.inter_data_list - ]) - if cuda_graph_pad_size: - lora_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - lora_prompt_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_prompt_mapping) - for inter_data in self.inter_data_list - ]) - - lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=not self.decode_only)) - - # Multi-modal data. - multi_modal_kwargs_list = [ - data.multi_modal_kwargs for data in self.inter_data_list - if data.multi_modal_kwargs is not None - ] - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return self.model_input_cls( - input_tokens=input_tokens_tensor, - inputs_embeds=inputs_embeds, - input_positions=input_positions_tensor, - token_types=token_types_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=self.finished_requests_ids) - - -class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): - """ - Helper class for shared methods between GPU model runners. 
- """ - _model_input_cls: Type[TModelInputForGPU] - _builder_cls: Type[ModelInputForGPUBuilder] - builder: ModelInputForGPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - return_hidden_states: bool = False, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - - ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - cache_config = self.cache_config - - self.is_driver_worker = is_driver_worker - self.return_hidden_states = return_hidden_states - - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = \ - self.vllm_config.compilation_config.max_capture_size - - # - self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ - {} for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.graph_memory_pool: Optional[Tuple[ - int, int]] = None # Set during graph capture. - - self.has_inner_state = model_config.has_inner_state - - self.in_profile_run = False - - # When using CUDA graph, the input block tables must be padded to - # max_seq_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max seq len to capture / block size). - self.graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) - - self.cross_layer_shared_graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) - - # Attention-free but stateful models like Mamba need a placeholder attn - # backend, as the attention metadata is needed to manage internal state. - # However we must bypass attention selection altogether for some models - # used for speculative decoding to avoid a divide-by-zero in - # model_config.get_head_size() - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - needs_attn_backend = (num_attn_heads != 0 - or self.model_config.is_attention_free) - - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) if needs_attn_backend else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) - - # Multi-modal data support - self.input_registry = input_registry - self.mm_registry = mm_registry - - # Lazy initialization - self.model: nn.Module # Set after load_model - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self.sampler = get_sampler() - - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - - # Used to cache python objects - self.inter_data_cache: Dict[int, PyObjectCache] = {} - - # Using the PythonizationCache in Pipeline-Parallel clobbers the - # SequenceGroupToSample object. 
In Pipeline-Parallel, we have - # more than 1 Scheduler, resulting in a potential back-to-back - # prepare_model_inputs() call. This clobbers the cached - # SequenceGroupToSample objects, as we reset the cache during - # every prepare_model_inputs() call. - self.sampling_metadata_cache: SamplingMetadataCache = \ - SamplingMetadataCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None - - if hasattr(self, "_builder_cls"): - # multi-step model runner does not have `_builder_cls` - self.builder = self._builder_cls(weakref.proxy(self)) - - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler(self.device) as m: - time_before_load = time.perf_counter() - self.model = get_model(vllm_config=self.vllm_config) - if self.lora_config: - assert supports_lora( - self.model - ), f"{self.model.__class__.__name__} does not support LoRA yet." - - if supports_multimodal(self.model): - logger.warning( - "Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") - - # Use get_text_config() in case of multimodal models - text_config = self.model_config.hf_config.get_text_config() - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=text_config. - max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - time_after_load = time.perf_counter() - - self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) - - - if self.vllm_config.compilation_config.level ==\ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) - compilation_counter.dynamo_as_is_count += 1 - self.model = torch.compile( - self.model, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - - def get_model(self) -> nn.Module: - return self.model - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - from vllm.model_executor.model_loader import ShardedStateLoader - ShardedStateLoader.save_model( - self.model, - path, - pattern=pattern, - max_size=max_size, - ) - - def save_tensorized_model( - self, - tensorizer_config: TensorizerConfig, - ) -> None: - from vllm.model_executor.model_loader import TensorizerLoader - TensorizerLoader.save_model( - self.model, - tensorizer_config=tensorizer_config, - model_config=self.model_config, - ) - - def get_max_block_per_batch(self) -> int: - block_size = self.block_size - return (self.max_seq_len_to_capture + block_size - 1) // block_size - - def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForGPU: - """Helper method to prepare the model input based on a given sequence - group. Prepares metadata needed for the base model forward pass but not - metadata for possible additional steps, e.g., sampling. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. 
For example,
-
-        - input_tokens[:num_prefill_tokens] contains prefill tokens.
-        - input_tokens[num_prefill_tokens:] contains decode tokens.
-
-        If cuda graph is required, this API automatically pads inputs.
-        """
-        self.builder.prepare(finished_requests_ids)
-        for seq_group_metadata in seq_group_metadata_list:
-            try:
-                self.builder.add_seq_group(seq_group_metadata)
-            except Exception as e:
-                # Raise an exception that tracks the ID of the bad request
-                raise InputProcessingError(seq_group_metadata.request_id,
-                                           str(e)) from e
-
-        self.builder.reset_cached_inter_data()
-
-        return self.builder.build()  # type: ignore
-
-    @contextmanager
-    def set_in_profile_run(self):
-        self.in_profile_run = True
-        try:
-            yield
-        finally:
-            self.in_profile_run = False
-
-    @torch.inference_mode()
-    def profile_run(self) -> None:
-        max_num_batched_tokens = \
-            self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        self._dummy_run(max_num_batched_tokens, max_num_seqs)
-
-    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
-        assert num_loras > 0
-        assert self.lora_manager is not None
-
-        dummy_lora_requests: list[LoRARequest] = []
-        with self.lora_manager.dummy_lora_cache():
-            for idx in range(num_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-        return dummy_lora_requests
-
-    def _remove_dummy_loras(self):
-        # Remove dummy loras.
-        assert self.lora_manager is not None
-        self.remove_all_loras()
-
-    def _dummy_run(self,
-                   max_num_batched_tokens: int,
-                   max_num_seqs: int = 1) -> None:
-        with self.set_in_profile_run():
-            # Enable top-k sampling to reflect the accurate memory usage.
-            sampling_params = \
-                SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-
-            # This represents the maximum number of different requests
-            # that will have unique loras, and therefore the max amount of
-            # memory consumption. Create dummy lora request copies from the
-            # lora request passed in, which contains a lora from the lora
-            # warmup path.
-            dummy_lora_requests: List[LoRARequest] = []
-            dummy_lora_requests_per_seq: List[LoRARequest] = []
-            if self.lora_config:
-                dummy_lora_requests = self._add_dummy_loras(
-                    self.lora_config.max_loras)
-                assert len(dummy_lora_requests) == self.lora_config.max_loras
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
-
-            # Profile memory usage with max_num_seqs sequences and the
-            # total number of tokens equal to max_num_batched_tokens.
-            seqs: List[SequenceGroupMetadata] = []
-            # Additional GPU memory may be needed for multi-modal encoding,
-            # which needs to be accounted for when calculating the GPU blocks
-            # for the vLLM block manager.
-            # To exercise the worst-case scenario for GPU memory consumption,
-            # the number of seqs (batch_size) is chosen to maximize the number
-            # of images processed.
-
-            max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-                self.model_config)
-            if max_mm_tokens > 0:
-                max_num_seqs_orig = max_num_seqs
-                max_num_seqs = min(max_num_seqs,
-                                   max_num_batched_tokens // max_mm_tokens)
-                if max_num_seqs < 1:
-                    expr = (f"min({max_num_seqs_orig}, "
-                            f"{max_num_batched_tokens} // {max_mm_tokens})")
-                    logger.warning(
-                        "Computed max_num_seqs (%s) to be less than 1. 
" - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data. - multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = \ - self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - # Disable KV Scale Calculation for dummy data during profile run - if model_input.attn_metadata is not None: - model_input.attn_metadata.enable_kv_scales_calculation = False - - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() - if self.lora_config: - self._remove_dummy_loras() - - return - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - @torch.inference_mode() - def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: - """Cuda graph capture a model. - - Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. 
And since CUDA graph - requires fixed sized tensors, supporting large/variable batch - size requires high GPU memory overhead. Thus, vLLM only captures - decoding requests. Mixed batch (chunked prefill + decoding) or - prefill requests are not captured. - - Since it is used for decoding-only, it assumes there's only 1 token - per sequence in the batch. - """ - assert not self.model_config.enforce_eager - logger.info("Capturing cudagraphs for decoding. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI. " - "If out-of-memory error occurs during cudagraph capture," - " consider decreasing `gpu_memory_utilization` or " - "switching to eager mode. You can also reduce the " - "`max_num_seqs` as needed to decrease memory usage.") - start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = self.max_batchsize_to_capture - input_tokens = torch.zeros(max_batch_size, - dtype=torch.long, - device=self.device) - input_positions = torch.zeros(max_batch_size, - dtype=torch.long, - device=self.device) - inputs_embeds = torch.zeros( - (max_batch_size, self.model_config.get_hidden_size()), - dtype=self.model_config.dtype, - device=self.device) - if self.model_config.uses_mrope: - input_positions = torch.tile(input_positions, - (3, 1)).cuda(device=self.device) - # Prepare dummy previous_hidden_states only if needed by the model. - # This is used by draft models such as EAGLE. - previous_hidden_states = None - if "previous_hidden_states" in inspect.signature( - self.model.forward).parameters: - previous_hidden_states = torch.empty( - [max_batch_size, - self.model_config.get_hidden_size()], - dtype=self.model_config.dtype, - device=self.device) - - intermediate_inputs = None - if not get_pp_group().is_first_rank: - intermediate_inputs = self.model.make_empty_intermediate_tensors( - batch_size=max_batch_size, - dtype=self.model_config.dtype, - device=self.device) - - dummy_lora_id: Optional[int] = None - dummy_lora_request: LoRARequest = [] - if self.lora_config: - # The goal is to capture the LoRA kernels in cuda graphs. - # for this purpose, as single dummy lora is sufficient. - dummy_lora_requests = self._add_dummy_loras(num_loras=1) - assert len(dummy_lora_requests) == 1 - dummy_lora_request = dummy_lora_requests[0] - dummy_lora_id = dummy_lora_request.lora_int_id - - with self.attn_state.graph_capture(max_batch_size), graph_capture( - self.device) as graph_capture_context: - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. - for virtual_engine in range( - self.parallel_config.pipeline_parallel_size): - # We need to not only iterate over batch sizes, but also whether - # to use inputs_embeds or not, hence we use the cartesian - # product. 
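-                # For example, capture sizes [8, 4, 2, 1] with
-                # enable_prompt_embeds=True yield the cases
-                # (8, True), (8, False), (4, True), ..., (1, False).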
- cudagraph_capture_sizes = self.vllm_config.compilation_config\ - .cudagraph_capture_sizes - cudagraph_inputs_embeds = (( - True, False) if self.model_config.enable_prompt_embeds else - (False, )) - compilation_cases = itertools.product( - cudagraph_capture_sizes, - cudagraph_inputs_embeds, - ) - # Only rank 0 should print progress bar during capture - if get_tensor_model_parallel_rank() == 0: - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") - for batch_size, use_inputs_embeds in compilation_cases: - attn_metadata = ( - self.attn_state.graph_capture_get_metadata_for_batch( - batch_size, - is_encoder_decoder_model=self.model_config. - is_encoder_decoder)) - # Disable KV Scale Calculation for graph capture - attn_metadata.enable_kv_scales_calculation = False - if self.lora_config: - lora_mapping = LoRAMapping( - **dict(index_mapping=[dummy_lora_id] * batch_size, - prompt_mapping=[dummy_lora_id] * batch_size, - is_prefill=False)) - self.set_active_loras(set([dummy_lora_request]), - lora_mapping) - - graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder) - - capture_inputs = { - "input_ids": - input_tokens[:batch_size], - "inputs_embeds": - inputs_embeds[:batch_size] - if use_inputs_embeds else None, - "positions": - input_positions[..., :batch_size], - "intermediate_inputs": - intermediate_inputs[:batch_size] - if intermediate_inputs is not None else None, - "kv_caches": - kv_caches[virtual_engine], - "attn_metadata": - attn_metadata, - "memory_pool": - self.graph_memory_pool, - "stream": - graph_capture_context.stream - } - if previous_hidden_states is not None: - capture_inputs[ - "previous_hidden_states"] = previous_hidden_states[: - batch_size] - - if self.has_inner_state: - # Only used by Mamba-based models CUDA graph atm (Jamba) - capture_inputs.update({ - "seqlen_agnostic_capture_inputs": - self.model.get_seqlen_agnostic_capture_inputs( - batch_size) - }) - if self.model_config.is_encoder_decoder: - # add the additional inputs to capture for - # encoder-decoder models. - self._update_inputs_to_capture_for_enc_dec_model( - capture_inputs) - - with set_forward_context(attn_metadata, self.vllm_config, - virtual_engine): - graph_runner.capture(**capture_inputs) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][( - batch_size, use_inputs_embeds)] = graph_runner - - if self.lora_config: - self._remove_dummy_loras() - - end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] - elapsed_time = end_time - start_time - cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory - # This usually takes < 10 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / GiB_bytes) - - def _update_inputs_to_capture_for_enc_dec_model(self, - capture_inputs: Dict[str, - Any]): - """ - Updates the set of input tensors needed for CUDA graph capture in an - encoder-decoder model. - - This method modifies the provided `capture_inputs` dictionary by - adding tensors specific to encoder-decoder specific models that - need to be captured for CUDA Graph replay. - """ - # During the decode phase encoder_input_ids and encoder_positions are - # unset. Do the same thing for graph capture. 
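-        # Empty tensors match the decode-phase shapes and serve as the fixed
-        # input buffers that CUDAGraphRunner.forward() copies into on replay.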
- capture_inputs["encoder_input_ids"] = torch.tensor([], - dtype=torch.long, - device=self.device) - capture_inputs["encoder_positions"] = torch.tensor([], - dtype=torch.long, - device=self.device) - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - -class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - """ - GPU model runner with sampling step. - """ - _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( - ModelInputForGPUWithSamplingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForGPUWithSamplingMetadata: - model_input = \ - ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - if get_pp_group().is_last_rank: - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory, - generators, self.sampling_metadata_cache) - else: - sampling_metadata = None - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in ModelRunner") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - self.attn_state.begin_forward(model_input) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - # TODO(andoorve): We can remove this once all - # virtual engines share the same kv cache. 
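-        # Each virtual engine corresponds to one pipeline-parallel scheduler
-        # and has its own dict of captured graph runners (self.graph_runners
-        # is indexed by virtual engine).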
- virtual_engine = model_input.virtual_engine - previous_hidden_states = kwargs.get("previous_hidden_states") - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - use_inputs_embeds = model_input.inputs_embeds is not None - model_executable = self.graph_runners[virtual_engine][( - graph_batch_size, use_inputs_embeds)] - if previous_hidden_states is not None: - previous_hidden_states = torch.cat([ - previous_hidden_states, - torch.empty([ - graph_batch_size - previous_hidden_states.shape[0], - *previous_hidden_states.shape[1:] - ], - dtype=previous_hidden_states.dtype, - device=previous_hidden_states.device) - ]) - else: - model_executable = self.model - - # Receive KV cache in distributed KV cache transfer setting - # In disagg prefill setting, it will also recv hidden states and bypass - # model forwarding - # In KV cache database setting, it will change the model input so that - # we can skip prefilling on tokens that successfully received KV caches - # NOTE: The receive operation is blocking - bypass_model_exec = False - if self.need_recv_kv(model_input, kv_caches): - hidden_or_intermediate_states, bypass_model_exec, model_input = \ - get_kv_transfer_group().recv_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can receive KV for only those - # layers. - model_executable, - model_input, - kv_caches=kv_caches - ) - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - model_kwargs = {} - if previous_hidden_states is not None: - model_kwargs["previous_hidden_states"] = previous_hidden_states - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.cuda.Event(enable_timing=True) - model_forward_end = torch.cuda.Event(enable_timing=True) - model_forward_start.record() - - if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **seqlen_agnostic_kwargs, - **model_kwargs, - ) - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Sending KV cache in distributed KV cache transfer setting - # NOTE: the send operation is non-blocking - if self.need_send_kv(model_input, kv_caches): - get_kv_transfer_group().send_kv_caches_and_hidden_states( - # model_executable is used to know which layer the current - # worker is working on, so that we can send KV for only those - # layers. - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - - # Compute the logits in the last pipeline stage. 
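-        # Non-last pipeline ranks return the intermediate tensors from the
-        # forward pass so the next stage can continue; only the last rank
-        # computes logits and samples.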
- if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if self.is_driver_worker: - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - assert isinstance(self.sampler, Sampler) - orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor - if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = True - - output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the - # latency from the start time of the driver worker to the end - # time of the driver worker. The model forward time will then - # end up covering the communication time as well. 
- output.model_forward_time = (orig_model_forward_time + - model_forward_time) - - if model_input.inputs_embeds is not None: - if self.is_driver_worker: - sampled_token_ids = [] - valid_outputs = [] - for sequence_group_output in output.outputs: - if len(sequence_group_output.samples) == 0: - continue - assert len(sequence_group_output.samples) == 1 - valid_outputs.append(sequence_group_output) - sampled_token_ids.append( - sequence_group_output.samples[0].output_token) - sampled_token_ids = torch.tensor(sampled_token_ids).to( - self.device) - sampled_token_ids = broadcast_tensor_dict( - {"sampled_token_ids": - sampled_token_ids})["sampled_token_ids"] - else: - sampled_token_ids = broadcast_tensor_dict( - )["sampled_token_ids"] - if len(sampled_token_ids) > 0: - sampled_token_embeds = \ - self.model.get_input_embeddings(sampled_token_ids) - if self.is_driver_worker: - self.sampler.include_gpu_probs_tensor = \ - orig_include_gpu_probs - for i, sequence_group_output in enumerate(valid_outputs): - sequence_group_output.samples[0].output_embed = \ - sampled_token_embeds[i] - - if not self.is_driver_worker: - return [] - - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - indices = model_input.sampling_metadata.selected_token_indices - if model_input.is_prompt: - hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) - output.prefill_hidden_states = hidden_or_intermediate_states - elif decode_meta.use_cuda_graph: - hidden_states = hidden_or_intermediate_states[:len(indices)] - else: - hidden_states = hidden_or_intermediate_states - - output.hidden_states = hidden_states - - return [output] - - def need_recv_kv(self, model_input, kv_caches) -> bool: - """Check if we need to receive kv-cache from the other worker. - We need to receive KV when - 1. current vLLM instance is KV cache consumer/decode vLLM instance - 2. this batch is not a profiling run - 3. this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_consumer and ( - not is_profile_run) and is_prefill_run - - def need_send_kv(self, model_input, kv_caches) -> bool: - """Check if we need to send kv-cache to the other worker. - We need to send KV when - 1. current vLLM instance is KV cache producer/prefill vLLM instance - 2. this batch is not a profiling run - 3. 
this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_producer and ( - not is_profile_run) and is_prefill_run - - -# NOTE: this is nn.Module so the profiler can properly capture/group -# kernels calls made within the graph -class CUDAGraphRunner(nn.Module): - - def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState, is_encoder_decoder_model: bool): - super().__init__() - self.model = model - self.backend_name = backend_name - self.attn_state = attn_state - - self.input_buffers: Dict[str, torch.Tensor] = {} - self.output_buffers: Dict[str, torch.Tensor] = {} - - self._graph: Optional[torch.cuda.CUDAGraph] = None - self._is_encoder_decoder_model = is_encoder_decoder_model - - @property - def graph(self): - assert self._graph is not None - return self._graph - - def capture( - self, - input_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_inputs: Optional[IntermediateTensors], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - memory_pool: Optional[Tuple[int, int]], - stream: torch.cuda.Stream, - **kwargs, - ): - assert self._graph is None - # Run the model a few times without capturing the graph. - # This is to make sure that the captured graph does not include the - # kernel launches for initial benchmarking (e.g., Triton autotune). - # Note one iteration is not enough for torch.compile - for _ in range(_NUM_WARMUP_ITERS): - self.model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - positions=positions, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - # Wait for the warm up operations to finish before proceeding with - # Graph Capture. - torch.cuda.synchronize() - # Capture the graph. - self._graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - output_hidden_or_intermediate_states = self.model( - input_ids=input_ids, - **({ - "inputs_embeds": inputs_embeds, - } if inputs_embeds is not None else {}), - positions=positions, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - - if isinstance(output_hidden_or_intermediate_states, torch.Tensor): - hidden_or_intermediate_states = weak_ref_tensor( - output_hidden_or_intermediate_states) - elif isinstance(output_hidden_or_intermediate_states, - IntermediateTensors): - hidden_or_intermediate_states = IntermediateTensors( - tensors={ - key: weak_ref_tensor(value) - for key, value in - output_hidden_or_intermediate_states.tensors.items() - }) - - del output_hidden_or_intermediate_states - # make sure `output_hidden_or_intermediate_states` is deleted - # in the graph's memory pool - gc.collect() - torch.cuda.synchronize() - - # Save the input and output buffers. 
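-        # These are the fixed tensors baked into the captured graph;
-        # forward() overwrites them in place before every replay.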
- self.input_buffers = { - "input_ids": - input_ids, - **({ - "inputs_embeds": inputs_embeds, - } if inputs_embeds is not None else {}), - "positions": - positions, - "kv_caches": - kv_caches, - **self.attn_state.get_graph_input_buffers( - attn_metadata, self._is_encoder_decoder_model), - **kwargs, - } - if intermediate_inputs is not None: - self.input_buffers.update(intermediate_inputs.tensors) - if get_pp_group().is_last_rank: - self.output_buffers = { - "hidden_states": hidden_or_intermediate_states - } - else: - self.output_buffers = hidden_or_intermediate_states - - def forward( - self, - input_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - **kwargs, - ) -> torch.Tensor: - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - # Copy the input tensors to the input buffers. - self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) - if positions is not None: - # in some case like MLA, it will reuse positions in metadata - # but truncate them to the original size - # so the shape is not padded, we need to copy partial only - self.input_buffers["positions"][:positions.shape[0]].copy_( - positions, non_blocking=True) - if inputs_embeds is not None: - self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_( - inputs_embeds, non_blocking=True) - - if self.backend_name != "NO_ATTENTION": - self.input_buffers["slot_mapping"].copy_( - attn_metadata.slot_mapping, non_blocking=True) - - self.attn_state.prepare_graph_input_buffers( - self.input_buffers, attn_metadata, self._is_encoder_decoder_model) - - if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_inputs_before_cuda_graphs(self.input_buffers, - **kwargs) - - if "previous_hidden_states" in self.input_buffers: - self.input_buffers["previous_hidden_states"].copy_( - kwargs["previous_hidden_states"], non_blocking=True) - - if intermediate_tensors is not None: - for key in intermediate_tensors.tensors: - if key != "model_execute_time" and key != "model_forward_time": - self.input_buffers[key].copy_(intermediate_tensors[key], - non_blocking=True) - if self._is_encoder_decoder_model: - self.input_buffers["encoder_input_ids"].copy_( - kwargs['encoder_input_ids'], non_blocking=True) - self.input_buffers["encoder_positions"].copy_( - kwargs['encoder_positions'], non_blocking=True) - - # Run the graph. - self.graph.replay() - # Return the output tensor. 
- if get_pp_group().is_last_rank: - return self.output_buffers["hidden_states"] - - return self.output_buffers diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py deleted file mode 100644 index 7b8fe2f802d68..0000000000000 --- a/vllm/worker/model_runner_base.py +++ /dev/null @@ -1,317 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) - -import torch -import torch.nn as nn - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.models.interfaces import supports_transcription -from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask - -if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.model_executor import SamplingMetadata - -logger = init_logger(__name__) - -T = TypeVar('T', bound="BroadcastableModelInput") - - -def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - attn_metadata: Optional["AttentionMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - AttentionMetadata fields. - """ - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - -def _init_attn_metadata_from_tensor_dict( - attn_backend: "AttentionBackend", - tensor_dict: Dict[str, Any], -) -> Dict[str, Any]: - """ - Helper method to initialize AttentionMetadata based on an - AttentionBackend and broadcastable AttentionMetadata fields. - """ - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict: - if field.name == "input_positions": - valid_attn_kwargs[field.name] = tensor_dict[field.name] - else: - valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - tensor_dict["attn_metadata"] = attn_metadata - return tensor_dict - - -def _init_sampling_metadata_from_tensor_dict( # type: ignore - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize SamplingMetadata based on broadcastable - SamplingMetadata fields. - """ - from vllm.model_executor import SamplingMetadata - - selected_token_indices = tensor_dict.pop("selected_token_indices", None) - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - if selected_token_indices is not None: - tensor_dict["sampling_metadata"] = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - return tensor_dict - - -def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - SamplingMetadata fields. 
- """ - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - - -def _init_frozen_model_input_from_tensor_dict( - frozen_model_input_cls: Type["ModelRunnerInputBase"], - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize a frozen ModelInput based on broadcastable - """ - valid_tensor_kwargs = {} - for field in dataclasses.fields(frozen_model_input_cls): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_tensor_kwargs[field.name] = val - - frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) - tensor_dict["frozen_model_input"] = frozen_model_input - return tensor_dict - - -class BroadcastableModelInput(ABC): - - @abstractmethod - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def from_broadcasted_tensor_dict( - cls: Type[T], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> T: - """ - Pop fields from the given tensor_dict and populate a new instance of - BroadcastableModelInput. - """ - raise NotImplementedError - - -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(BroadcastableModelInput): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. - """ - pass - - -class ModelRunnerInputBuilderBase(ABC, Generic[T]): - """A builder to create ModelRunnerInputBase objects. - """ - - @abstractmethod - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - raise NotImplementedError - - @abstractmethod - def add_seq_group(self, seq_group_metadata): - """TBA""" - raise NotImplementedError - - @abstractmethod - def build(self, *args, **kwargs) -> T: - """Build metadata with on-device tensors.""" - raise NotImplementedError - - -class ModelRunnerBase(ABC, Generic[T]): - """ - Model runner interface that abstracts a particular hardware and/or type of - model. Model execution may communicate data with model runners in other - processes, but it should not include control plane metadata communication. - - Each ModelRunnerBase subclass should define a corresponding - ModelRunnerInputBase subclass. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - - # Map of request_id -> generator used for seeded random sampling - generators: Dict[str, torch.Generator] = {} - - @abstractmethod - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> T: - """ - Make an instance of a ModelRunnerInputBase from the broadcasted tensor - dict. - """ - raise NotImplementedError - - @abstractmethod - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> T: - """ - Prepare the inputs to ModelRunnerBase.execute_model from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError - - @abstractmethod - def get_model(self) -> nn.Module: - raise NotImplementedError - - def get_supported_generation_tasks(self) -> list[GenerationTask]: - model = self.get_model() - supported_tasks = list[GenerationTask]() - - if is_text_generation_model(model): - supported_tasks.append("generate") - - if supports_transcription(model): - if model.supports_transcription_only: - return ["transcription"] - - supported_tasks.append("transcription") - - return supported_tasks - - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - return list(model.pooler.get_supported_tasks()) - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - tasks = list[SupportedTask]() - - if self.model_config.runner_type == "generate": - tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - tasks.extend(self.get_supported_pooling_tasks()) - - return tuple(tasks) - - def execute_model( - self, - model_input: T, - kv_caches: Optional[List[torch.Tensor]], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[List[SamplerOutput]]: - """ - Execute the model on the given input. - """ - raise NotImplementedError - - def get_generators(self, finished_request_ids: Optional[List[str]] = None): - """ - Return dict of per-request generators used for random sampling. - """ - - # Clean up generators from completed requests - if finished_request_ids: - for request_id in finished_request_ids: - self.generators.pop(request_id, None) - - return self.generators - - -class ModelRunnerWrapperBase: - """ - The whole point of this class is to lazily initialize the model_runner. - """ - - def __init__( - self, - model_runner: ModelRunnerBase, - ) -> None: - self.model_runner: ModelRunnerBase = model_runner - - def __getattr__(self, attr): - return getattr(self.model_runner, attr) - - -class InputProcessingError(Exception): - """This exception is raised when an error occurs preparing the inputs for - a single sequence group. 
- This allows the engine to gracefully handle errors with a single sequence - group without having to fail the entire batch. - """ - - def __init__(self, request_id, message): - """request_id is the id of the offending sequence group""" - self.request_id = request_id - self.message = message - super().__init__(self.message) - - def __str__(self): - return "Failed to prepare inputs for sequence group with request id: " \ - f"{self.request_id}, Error: {self.message}" diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py deleted file mode 100644 index 8317b9abff0cd..0000000000000 --- a/vllm/worker/neuron_model_runner.py +++ /dev/null @@ -1,455 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union - -import torch -from torch import nn - -from vllm.config import DeviceConfig, VllmConfig -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs -from vllm.platforms import current_platform -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - - -@dataclass(frozen=True) -class ModelInputForNeuron(ModelRunnerInputBase): - """ - Used by the NeuronModelRunner. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: SamplingMetadata = None - multi_modal_kwargs: BatchedTensorInputs = None - adapter_ids: Optional[str] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - return { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "input_block_ids": self.input_block_ids, - "sampling_metadata": self.sampling_metadata, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForNeuron": - return ModelInputForNeuron( - input_tokens=tensor_dict["input_tokens"], - input_positions=tensor_dict["input_positions"], - input_block_ids=tensor_dict["input_block_ids"], - sampling_metadata=tensor_dict["sampling_metadata"], - multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], - ) - - -class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): - """A model runner for AWS Neuron hardware""" - - # NEURON has an upper limit on the top_k - _MAX_NEURON_SAMPLING_TOP_K = 256 - - def __init__( - self, - vllm_config: VllmConfig, - ): - ModelRunnerBase.__init__(self, vllm_config) - - if (self.model_config is not None - and self.model_config.get_sliding_window()): - logger.warning("Sliding window is not supported on Neuron. 
" - "The model will run without sliding window.") - self.device_config = (self.device_config if self.device_config - is not None else DeviceConfig()) - self.lora_config = vllm_config.lora_config - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - # Lazy initialization. - self.model: nn.Module # initialize after load_model. - - # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value, - # turn off on-device sampling. - self._on_device_sampling_disabled = int( - os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) - - # NEURON needs to update sampling parameters when request IDs change - # across batches. This variable stores the previous batch's request IDs - # to determine if an update is needed. - self._previous_batch_request_ids: List[str] = [] - - if not self._on_device_sampling_disabled: - self._init_neuron_sampling() - - def _init_neuron_sampling(self) -> None: - if current_platform.use_transformers_neuronx(): - from transformers_neuronx.config import GenerationConfig - else: - from transformers import GenerationConfig - logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) - - def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - def get_model(self) -> nn.Module: - return self.model - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - seq_len = len(prompt_tokens) - seq_lens.append(seq_len) - - input_tokens.append(prompt_tokens) - input_positions.append(list(range(seq_len))) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - mm_kwargs = seq_group_metadata.multi_modal_data - if mm_kwargs: - mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs) - multi_modal_kwargs_list.append(mm_kwargs) - - max_seq_len = max(seq_lens) - assert max_seq_len > 0 - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_block_ids = 
torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - context_lens: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - context_lens.append(seq_len) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - input_block_ids = torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - return input_tokens, input_positions, input_block_ids - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: - return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. 
- seq_lens, - self.device, - self.pin_memory, - generators=self.get_generators(finished_requests_ids)) - - if current_platform.use_transformers_neuronx( - ) and not self._on_device_sampling_disabled: - # Once the request IDs are changed in current iteration, we will - # update the on-device sampling parameters. - current_batch_request_ids = [ - seq_group_meta_data.request_id - for seq_group_meta_data in seq_group_metadata_list - ] - if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(seq_group_metadata_list) - self._previous_batch_request_ids = current_batch_request_ids - - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs) - - def _update_neuron_sampling_params( - self, seq_group_metadata_list: List[SequenceGroupMetadata]): - # Update Neuron sampling parameters (GenerationConfig in Neuron) - current_sampling_params = self.model_config.neuron_sampling_params - assert current_sampling_params is not None, ( - f"Failed to update sampling_params, " - f"current sampling params is {current_sampling_params}") - - is_update_needed = False - - top_k = current_sampling_params.top_k - top_p = current_sampling_params.top_p - temperature = current_sampling_params.temperature - - # The index of a sequence's sampling parameters in neuron is equal to - # its index in `input_block_ids`. - for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - - seq_group_top_k = sampling_params.top_k - seq_group_top_p = sampling_params.top_p - seq_group_temperature = sampling_params.temperature - - for seq_id in seq_ids: - index = seq_group_metadata.block_tables[seq_id][0] - if (top_k[index] != seq_group_top_k - or top_p[index] != seq_group_top_p - or temperature[index] != seq_group_temperature): - is_update_needed = True - - top_k[index] = seq_group_top_k - top_p[index] = seq_group_top_p - temperature[index] = seq_group_temperature - - # update_generation_config is only available in transformers-neuronx - if is_update_needed and current_platform.use_transformers_neuronx(): - self.model.model.update_generation_config(current_sampling_params) - - def _convert_to_neuron_sampling_params( - self, sampling_params: SamplingParams) -> Tuple[int, float, float]: - # Returns the top_k, top_p and temperature parameters for neuron. 
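For reference, a minimal standalone sketch of the conversion rule that the method body below implements, assuming the top-k ceiling of 256 set by _MAX_NEURON_SAMPLING_TOP_K earlier in this file; the helper name and the asserts are illustrative and not part of the deleted file:

MAX_NEURON_SAMPLING_TOP_K = 256  # mirrors _MAX_NEURON_SAMPLING_TOP_K above

def to_neuron_sampling_params(top_k: int, top_p: float,
                              temperature: float) -> tuple[int, float, float]:
    # Zero temperature means greedy decoding: force top_k=1 and neutral top_p.
    if temperature == 0.0:
        return (1, 1.0, 1.0)
    # An out-of-range top_k (e.g. -1, meaning "disabled") is clamped to the
    # Neuron ceiling rather than rejected.
    if top_k < 1 or top_k > MAX_NEURON_SAMPLING_TOP_K:
        top_k = MAX_NEURON_SAMPLING_TOP_K
    return (top_k, top_p, temperature)

assert to_neuron_sampling_params(-1, 0.9, 0.7) == (MAX_NEURON_SAMPLING_TOP_K, 0.9, 0.7)
assert to_neuron_sampling_params(40, 0.9, 0.0) == (1, 1.0, 1.0)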
- top_k = sampling_params.top_k - top_p = sampling_params.top_p - temperature = sampling_params.temperature - - if temperature == 0.0: - # Enable greedy sampling on zero temperature - return (1, 1.0, 1.0) - if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - top_k = self._MAX_NEURON_SAMPLING_TOP_K - - return (top_k, top_p, temperature) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - # extract top_k, top_p and temperature from model_input for neuron - # forward call - sampling_params = (torch.tensor([[ - seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature - ] for seq_group in model_input.sampling_metadata.seq_groups])) - - if current_platform.use_neuronx_distributed(): - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - sampling_params=sampling_params, - adapter_ids=model_input.adapter_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - elif current_platform.use_transformers_neuronx(): - # [TODO] validate on-device sampling - # The model signature may need change for on-device sampling - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - - # Compute the logits only if the on-device sampling is turned off as - # on-device sampling outputs the token ids. - if self._on_device_sampling_disabled: - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - else: - logits = hidden_states - - # Sample the next token. 
- output = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - return [output] - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - def process_multi_modal_data_neuron(self, mm_data): - # this is a no-op for NeuronModelRunner - return mm_data - - def remove_all_loras(self): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def add_lora(self, lora_request: LoRARequest): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py deleted file mode 100644 index 3e4512a639083..0000000000000 --- a/vllm/worker/neuron_worker.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A Neuron worker class.""" -import os -from typing import List, Optional, Set, Tuple - -import torch.distributed - -from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sequence import ExecuteModelRequest -from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class NeuronWorker(LocalOrDistributedWorkerBase): - """A worker class that executes the model on a group of neuron cores. - """ - - model_runner: NeuronModelRunner - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - self.lora_config = vllm_config.lora_config - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - neuron_framework = current_platform.get_neuron_framework_to_use() - if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: - self.model_runner = self.get_tnx_model_runner(vllm_config) - elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: - self.model_runner = self.get_neuronx_distributed_model_runner( - vllm_config) - else: - raise NotImplementedError( - "Specified framework" + - f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + - " is either not installed or not supported." 
+ - " Supported frameworks: " + - "[transformers-neuronx, neuronx-distributed-inference]") - - def get_tnx_model_runner(self, vllm_config): - assert (self.lora_config - is None), ("LoRA is not supported for TransformersNeuronX " - "framework.") - if self.speculative_config is not None: - raise NotImplementedError( - "Speculative decoding is not supported for TransformersNeuronX" - ) - return NeuronModelRunner(vllm_config=vllm_config) - - def get_neuronx_distributed_model_runner(self, vllm_config): - from vllm.worker.neuronx_distributed_model_runner import ( - NeuronxDistributedModelRunner) - if self.speculative_config is not None: - assert (self.lora_config is None), ( - "LoRA is not supported for Speculative Decoding") - raise NotImplementedError( - "Speculative decoding is not supported for NeuronxDistributed") - return NeuronxDistributedModelRunner(vllm_config=vllm_config) - - def init_device(self) -> None: - self.init_distributed_environment() - - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - Swapping is not yet supported, so always return num_cpu_blocks=0. - - We configure num_gpu_blocks to be equal to max_num_seqs. - """ - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - num_gpu_blocks = self.scheduler_config.max_num_seqs + 1 - - # Swap not yet supported with Neuron backend. - num_cpu_blocks = 0 - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache. - """ - - # Different values are not tested. - assert num_cpu_blocks == 0 - assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1 - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - @property - def do_metadata_broadcast(self) -> bool: - return False - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return None - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - return WorkerInput(num_seq_groups=len( - execute_model_req.seq_group_metadata_list), ) - - def execute_worker(self, worker_input: WorkerInput) -> None: - pass - - def get_cache_block_size_bytes(self) -> int: - """Determine the size in bytes of a cache block. - - This is required for speculative decoding; it is not yet implemented. - """ - raise NotImplementedError - - def init_distributed_environment(self): - """Neuron uses transformers-neuronx for tensor parallelism. 
- - vLLM still needs the environment initialized when TP/PP > 1 - """ - init_distributed_environment( - world_size=1, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend=current_platform.dist_backend, - ) - - ensure_model_parallel_initialized( - 1, - 1, - ) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.list_loras() diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py deleted file mode 100644 index 2a0f4e77c99e5..0000000000000 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ /dev/null @@ -1,294 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set - -import torch -from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import ( - get_all_supported_aspect_ratios) -from neuronx_distributed_inference.modules.generation.sampling import ( - prepare_sampling_params) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraCheckpoint, LoraServingConfig) - -from vllm.config import VllmConfig -from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuronx_distributed import ( - _get_model_architecture, get_neuron_model) -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import (ModelInputForNeuron, - NeuronModelRunner) - -logger = init_logger(__name__) - - -class NeuronxDistributedModelRunner(NeuronModelRunner): - - def __init__( - self, - vllm_config: VllmConfig, - ): - super().__init__(vllm_config) - self.lora_checkpoint = None - self.model = None - self.lora_serving_config = None - - @staticmethod - def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]): - if not lora_modules: - return None - return {_.get("name"): _.get("path") for _ in lora_modules} - - def _get_nxdi_lora_config(self): - override_neuron_config = self.model_config.override_neuron_config - lora_modules = override_neuron_config.pop("lora_modules", None) - target_modules = override_neuron_config.pop("target_modules", None) - lora_ckpt_paths = self._get_lora_paths_strings(lora_modules) - if 
self.lora_config.max_loras < len(lora_ckpt_paths): - raise ValueError( - "Number of LoRAs (%s) exceeds maximum " - "allowed (%s)", len(lora_ckpt_paths), - self.lora_config.max_loras) - - return LoraServingConfig( - max_loras=self.lora_config.max_loras, - max_lora_rank=self.lora_config.max_lora_rank, - target_modules=target_modules, - lora_ckpt_paths=lora_ckpt_paths, - ) - - def load_model(self) -> None: - # Update LoRA config - if self.lora_config is not None: - self.lora_serving_config = self._get_nxdi_lora_config() - self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config) - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - lora_serving_config=self.lora_serving_config) - - def get_nxd_sampling_params(self, sampling_metadata): - if self.model.config.neuron_config.on_device_sampling_config: - max_topk = (self.model.config.neuron_config. - on_device_sampling_config.global_topk) - else: - max_topk = self.model.config.vocab_size - - top_k = [1] * self.scheduler_config.max_num_seqs - top_p = [1.0] * self.scheduler_config.max_num_seqs - temperature = [1.0] * self.scheduler_config.max_num_seqs - - for index, sequenceGroupToSample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = (sequenceGroupToSample.sampling_params.top_k - if sequenceGroupToSample.sampling_params.top_k > 0 - else max_topk) - top_p[index] = sequenceGroupToSample.sampling_params.top_p - temperature[index] = ( - sequenceGroupToSample.sampling_params.temperature) - - sampling_params = prepare_sampling_params( - batch_size=self.scheduler_config.max_num_seqs, - top_k=top_k, - top_p=top_p, - temperature=temperature) - return sampling_params - - def get_multi_modal_data_neuron(self, input_images): - raise NotImplementedError("need to restore multi-modal support") - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - if _get_model_architecture( - self.model.config) != "MllamaForConditionalGeneration": - return super().execute_model(model_input, kv_caches, - intermediate_tensors, num_steps) - - sampling_params = self.get_nxd_sampling_params( - model_input.sampling_metadata) - - if model_input.multi_modal_kwargs.get('pixel_values') is not None: - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - pixel_values=model_input.multi_modal_kwargs.get( - 'pixel_values'), - aspect_ratios=model_input.multi_modal_kwargs.get( - 'aspect_ratios'), - sampling_params=sampling_params, - num_chunks=model_input.multi_modal_kwargs.get('num_chunks'), - has_image=model_input.multi_modal_kwargs.get( - 'has_image').squeeze(1), - ) - else: - bs = model_input.input_tokens.shape[0] if (model_input.input_tokens - is not None) else 1 - empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560], - dtype=torch.bfloat16) - empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64) - num_chunks = torch.zeros((bs, 1), dtype=torch.int32) - has_image = torch.zeros([bs], dtype=torch.int32) - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - 
pixel_values=empty_pixel_values, - aspect_ratios=empty_aspect_ratios, - sampling_params=sampling_params, - num_chunks=num_chunks, - has_image=has_image, - ) - - output = self.model.sample( - hidden_states=hidden_states, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def process_multi_modal_data_neuron(self, mm_data): - # Neuron uses aspect_ratios instead of aspect_ratio_ids - all_supported_aspect_ratios = get_all_supported_aspect_ratios( - self.model.config.vision_config.max_num_tiles) - aspect_ratio_ids = mm_data.get("aspect_ratio_ids") - mm_data["aspect_ratios"] = torch.tensor( - all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0) - - # Neuron's num_chunks is HF's num_tiles - mm_data["num_chunks"] = mm_data.get("num_tiles") - - # Input has an image if it has pixel_values - bs = mm_data["num_chunks"].shape[0] - pixel_values = mm_data.get("pixel_values") - if pixel_values is not None and not torch.all(pixel_values == 0): - mm_data["has_image"] = torch.ones(bs) - - else: - mm_data["has_image"] = torch.zeros(bs) - return mm_data - - def _get_lora_adapter_ids(self, seq_group_metadata_list): - # set LoRA adapter IDs for multi-lora serving - batch_size = len(seq_group_metadata_list) - if self.lora_checkpoint is not None: - # "0" indicates NxDI to use the base model for inference - adapter_ids = ["0"] * batch_size - for idx, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.lora_request is not None: - adapter_ids[ - idx] = seq_group_metadata.lora_request.lora_name - - # convert adapter_ids from strings to integers - adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices( - adapter_ids, batch_size) - else: - adapter_ids = torch.zeros((batch_size), dtype=torch.int32) - - return adapter_ids - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. 
- seq_lens, - self.device, - self.pin_memory, - generators=self.get_generators(finished_requests_ids)) - - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - adapter_ids=lora_adapter_ids) - - def remove_all_loras(self): - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def add_lora(self, lora_request: LoRARequest): - logger.warning( - "Adding LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config. If you supplied " - "the parameter, you can ignore this warning. Ignoring" - "lora request: ", lora_request) - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py deleted file mode 100644 index 8d8d9b4d0503f..0000000000000 --- a/vllm/worker/pooling_model_runner.py +++ /dev/null @@ -1,221 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast - -import torch - -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.model_executor.models.interfaces_base import VllmModelForPooling -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import MultiModalKwargs -from vllm.pooling_params import PoolingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, - ModelInputForGPUBuilder) - -logger = init_logger(__name__) - - -@dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): - """ - Used by the PoolingModelRunner. 
- """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -class PoolingModelRunner( - GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( - ModelInputForGPUWithPoolingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - ): - super().__init__(vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithPoolingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError( - "PoolingModelRunner does not support multi-step execution.") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - virtual_engine = model_input.virtual_engine - # Pooling models are (ab-)used also to integrate non text models that - # are not autoregressive (PrithviGeosaptialMAE). - # These model might not use attention and do not really have a prefill - # and decode phase. The model input is processed in one shot and both - # decode_metadata and prefill_metadata would be None for such models. - # See the PlaceholderAttentionMetadata class. - # TODO: Figure out if cuda_graph is of any use for these models and - # explore how to leverage it. 
- if (prefill_meta is None and decode_meta is not None - and decode_meta.use_cuda_graph): - if model_input.inputs_embeds is None: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, False)]) - else: - graph_batch_size = model_input.inputs_embeds.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, True)]) - else: - model_executable = self.model - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.cuda.Event(enable_timing=True) - model_forward_end = torch.cuda.Event(enable_timing=True) - model_forward_start.record() - - cross_enc_kwargs = {} - if model_input.token_types is not None: - cross_enc_kwargs["token_type_ids"] = model_input.token_types - - with set_forward_context(model_input.attn_metadata, self.vllm_config, - virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **cross_enc_kwargs, - **seqlen_agnostic_kwargs, - ) - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Only perform pooling in the last pipeline stage. - if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - # Only perform pooling in the driver worker. 
- if not self.is_driver_worker: - return [] - - pooling_metadata = model_input.pooling_metadata - assert pooling_metadata is not None - - pooling_metadata.build_pooling_cursor( - num_scheduled_tokens=pooling_metadata.prompt_lens, - device=hidden_or_intermediate_states.device) - - return [ - self.model.pooler(hidden_states=hidden_or_intermediate_states, - pooling_metadata=pooling_metadata) - ] - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> ModelInputForGPUWithPoolingMetadata: - return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithPoolingMetadata: - assert seq_group_metadata_list is not None - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) - - return dataclasses.replace(model_input, - pooling_metadata=pooling_metadata) - - def _prepare_pooling( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> PoolingMetadata: - """Prepare PoolingMetadata for the sequence group metadata list.""" - seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - - pooling_params = seq_group_metadata.pooling_params - assert pooling_params is not None - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") - - model = cast(VllmModelForPooling, self.model) - to_update = model.pooler.get_pooling_updates(task) - to_update.apply(pooling_params) - - seq_groups.append((seq_ids, pooling_params)) - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - pooling_metadata = PoolingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - ) - - return pooling_metadata diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py deleted file mode 100644 index 512a1dca73701..0000000000000 --- a/vllm/worker/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -''' -Worker-related helper functions. -''' - -from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS -from vllm.worker.model_runner import GPUModelRunnerBase - - -def assert_enc_dec_mr_supported_scenario( - enc_dec_mr: GPUModelRunnerBase) -> None: - ''' - Asserted that the provided encoder/decoder model runner instance reflects - a supported scenario. 
- ''' - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - - if enc_dec_mr.cache_config.enable_prefix_caching: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE']) - - if enc_dec_mr.sliding_window is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA']) - - if enc_dec_mr.scheduler_config.chunked_prefill_enabled: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ - 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL']) - - if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping', - None) is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP'] - ) - - if enc_dec_mr.lora_config is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA']) - - if enc_dec_mr.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - - if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py deleted file mode 100644 index fc24d95b80f2c..0000000000000 --- a/vllm/worker/worker.py +++ /dev/null @@ -1,587 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A GPU worker class.""" -import gc -import os -from contextlib import nullcontext -from typing import Dict, List, Optional, Set, Tuple, Type, Union - -import torch -import torch.distributed - -import vllm.envs as envs -from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, - memory_profiling) -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.pooling_model_runner import PoolingModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class Worker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - ) -> None: - WorkerBase.__init__(self, vllm_config) - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_config = self.speculative_config - model_config = self.model_config - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.hf_config.model_type == - model_config.hf_config.model_type) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ("medusa", - "mlp_speculator", - "eagle", - "deepseek_mtp", - "glm4_moe_mtp", - "mimo_mtp", - "ernie_mtp")) \ - else {"return_hidden_states": True} - - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_config.runner_type == "pooling": - ModelRunnerClass = PoolingModelRunner - elif self.model_config.is_encoder_decoder: - ModelRunnerClass = EncoderDecoderModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - vllm_config=self.vllm_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - **speculative_args, - ) - if model_runner_cls is not None: - self.model_runner = model_runner_cls(self.model_runner) - - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] - # Initialize gpu_cache as pooling models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} - - # Buffers saved before sleep - self._sleep_saved_buffers: Dict[str, torch.Tensor] = {} - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. 
Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - print( - self.profiler.key_averages().table(sort_by="self_cuda_time_total")) - - def sleep(self, level: int = 1) -> None: - free_bytes_before_sleep = torch.cuda.mem_get_info()[0] - - # Save the buffers before level 2 sleep - if level == 2: - model = self.model_runner.model - self._sleep_saved_buffers = { - name: buffer.cpu().clone() - for name, buffer in model.named_buffers() - } - - allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) - free_bytes_after_sleep, total = torch.cuda.mem_get_info() - freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep - used_bytes = total - free_bytes_after_sleep - assert freed_bytes >= 0, "Memory usage increased after sleeping." - logger.info( - "Sleep mode freed %.2f GiB memory, " - "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, - used_bytes / GiB_bytes) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - allocator = CuMemAllocator.get_instance() - allocator.wake_up(tags=tags) - - # Restore the buffers after level 2 sleep - if len(self._sleep_saved_buffers): - model = self.model_runner.model - for name, buffer in model.named_buffers(): - if name in self._sleep_saved_buffers: - buffer.data.copy_(self._sleep_saved_buffers[name].data) - self._sleep_saved_buffers = {} - - def init_device(self) -> None: - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # This env var set by Ray causes exceptions with graph building. - os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) - - _check_if_gpu_supports_dtype(self.model_config.dtype) - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - self.baseline_snapshot = MemorySnapshot() - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank) - # Set random seed. 
- set_random_seed(self.model_config.seed) - - def load_model(self): - if self.vllm_config.model_config.enable_sleep_mode: - allocator = CuMemAllocator.get_instance() - assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") - context = allocator.use_memory_pool(tag="weights") - else: - context = nullcontext() - with context: - self.model_runner.load_model() - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - self.model_runner.save_sharded_state( - path, - pattern=pattern, - max_size=max_size, - ) - - def save_tensorized_model( - self, - tensorizer_config: TensorizerConfig, - ) -> None: - self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - Tip: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - with memory_profiling( - self.baseline_snapshot, - weights_memory=self.model_runner.model_memory_usage) as result: - self.model_runner.profile_run() - - self._assert_memory_footprint_increased_during_profiling() - - memory_for_current_instance = total_gpu_memory * \ - self.cache_config.gpu_memory_utilization - available_kv_cache_memory = (memory_for_current_instance - - result.non_kv_cache_memory) - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. 
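To make the arithmetic concrete, here is a worked example of the block-count calculation that follows; the 80 GiB device, 0.9 utilization, 2 MiB cache block and 4 GiB swap space are purely hypothetical numbers, not measured values from this diff:

GiB = 1 << 30
total_gpu_memory = 80 * GiB
gpu_memory_utilization = 0.9
non_kv_cache_memory = 18 * GiB      # weights + non-torch + activation peak
cache_block_size = 2 * (1 << 20)    # hypothetical bytes per KV cache block
swap_space_bytes = 4 * GiB

memory_for_current_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - non_kv_cache_memory
num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)
num_cpu_blocks = max(int(swap_space_bytes // cache_block_size), 0)
assert (num_gpu_blocks, num_cpu_blocks) == (27648, 2048)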
- cache_block_size = self.get_cache_block_size_bytes() - if cache_block_size == 0: - num_gpu_blocks = 0 - num_cpu_blocks = 0 - else: - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - - msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" - "the current vLLM instance can use " - "total_gpu_memory " - f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" - " x gpu_memory_utilization " - f"({self.cache_config.gpu_memory_utilization:.2f})" - f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" - "model weights take " - f"{(result.weights_memory / GiB_bytes):.2f}GiB;" - " non_torch_memory takes " - f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" - " PyTorch activation peak memory takes " - f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" - " the rest of the memory reserved for KV Cache is " - f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") - - logger.info(msg) - # Final cleanup - gc.collect() - - return num_gpu_blocks, num_cpu_blocks - - def _assert_memory_footprint_increased_during_profiling(self): - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - free_gpu_memory, total = torch.cuda.mem_get_info() - cuda_memory = total - free_gpu_memory - assert self.baseline_snapshot.cuda_memory < cuda_memory, ( - "Error in memory profiling. " - f"Initial used memory {self.baseline_snapshot.cuda_memory}, " - f"currently used memory {cuda_memory}. " - f"This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks. - - This also warms up the model, which may record CUDA graphs. - """ - raise_if_cache_size_invalid( - num_gpu_blocks, self.cache_config.block_size, - self.cache_config.is_attention_free, - self.model_config.max_model_len, - self.parallel_config.pipeline_parallel_size) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - if self.vllm_config.model_config.enable_sleep_mode: - allocator = CuMemAllocator.get_instance() - context = allocator.use_memory_pool(tag="kv_cache") - else: - context = nullcontext() - with context: - self._init_cache_engine() - self._warm_up_model() - - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.gpu_cache = [ - self.cache_engine[ve].gpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - - # Layer pairings for cross-layer KV sharing. - # If an Attention layer `layer_name` is in the keys of this dict, it - # means this layer will perform attention using the keys and values - # from the KV cache of `shared_kv_cache_layers[layer_name]`. - shared_kv_cache_layers: dict[str, str] = {} - - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - - for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. 
We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - shared_kv_cache_layers[layer_name] = kv_tgt_layer - - bind_kv_cache(self.compilation_config.static_forward_context, - self.gpu_cache, shared_kv_cache_layers) - - def _warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: - warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes - ] - for size in sorted(warmup_sizes, reverse=True): - logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - @property - def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return self.gpu_cache - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - virtual_engine = execute_model_req.virtual_engine - num_steps = execute_model_req.num_steps - num_seq_groups = len(execute_model_req.seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) - - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - num_steps=num_steps, - ) - - @torch.inference_mode() - def execute_worker(self, worker_input: WorkerInput) -> None: - virtual_engine = worker_input.virtual_engine - # Issue cache operations. 
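For orientation, the swap and copy tensors built in prepare_worker_input above are (N, 2) tables of block-id pairs that the cache operations below consume; a tiny illustrative sketch of that layout follows, with made-up block ids (the pair ordering as (source, target) matches the comment above, but the concrete values are hypothetical):

import torch

# Two copy operations, expressed as (source_block, target_block) pairs.
blocks_to_copy = torch.tensor([[0, 4], [1, 5]], dtype=torch.int64).view(-1, 2)
assert blocks_to_copy.shape == (2, 2)
assert blocks_to_copy[0].tolist() == [0, 4]   # copy block 0 into block 4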
- if (worker_input.blocks_to_swap_in is not None - and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine[virtual_engine].swap_in( - worker_input.blocks_to_swap_in) - if (worker_input.blocks_to_swap_out is not None - and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine[virtual_engine].swap_out( - worker_input.blocks_to_swap_out) - if (worker_input.blocks_to_copy is not None - and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) - - def _get_cached_seq_group_metadata( - self, - seq_group_metadata_list: List[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]], - finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: - """Return a list of cached Sequence Group Metadata after updating its - state. - - It is used because scheduler only sends delta to workers to reduce - the data payload size. The function also cleans up cache based on - a given `finished_request_ids`. - """ - new_seq_group_metadata_list = [] - for metadata_or_delta in seq_group_metadata_list: - request_id = metadata_or_delta.request_id - if request_id not in self._seq_group_metadata_cache: - # The first prefill. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[request_id] = metadata_or_delta - else: - # The first prefill is already cached. - if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): - self._seq_group_metadata_cache[request_id].apply_delta( - metadata_or_delta) - else: - # If metadata snapshot is sent again, it is - # preempted. Reset the cache because we need to start - # from scratch. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[ - request_id] = metadata_or_delta - - new_seq_group_metadata_list.append( - self._seq_group_metadata_cache[request_id]) - - # Clean up finished ids - for finished_id in finished_request_ids: - del self._seq_group_metadata_cache[finished_id] - - return new_seq_group_metadata_list - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Optional[List[SamplerOutput]]: - if execute_model_req is not None: - new_seq_group_metadata_list = self._get_cached_seq_group_metadata( - execute_model_req.seq_group_metadata_list, - execute_model_req.finished_requests_ids) - - execute_model_req.seq_group_metadata_list = ( - new_seq_group_metadata_list) - output = super()._execute_model_spmd(execute_model_req, - intermediate_tensors) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() - - @property - def max_model_len(self) -> int: - return self.model_config.max_model_len - - @property - def vocab_size(self) -> int: - return self.model_runner.vocab_size - - def get_cache_block_size_bytes(self) -> int: - """Get the size of the KV cache block size in bytes. 
- """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - -def init_worker_distributed_environment( - vllm_config: VllmConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - parallel_config = vllm_config.parallel_config - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, - current_platform.dist_backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, - max_model_len, pipeline_parallel_size) -> None: - if is_attention_free and num_gpu_blocks != 0: - raise ValueError("No memory should be allocated for the cache blocks " - f"for an attention-free model, but {num_gpu_blocks} " - "blocks are allocated.") - if not is_attention_free and num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) - if not is_attention_free and max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py deleted file mode 100644 index a1fa7f2cf7a2e..0000000000000 --- a/vllm/worker/worker_base.py +++ /dev/null @@ -1,643 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import os -import time -from abc import abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union - -import cloudpickle -import torch -import torch.nn as nn - -from vllm.config import (ObservabilityConfig, VllmConfig, - set_current_vllm_config) -from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (enable_trace_function_call_for_thread, - resolve_obj_by_qualname, run_method, - update_environment_variables, - warn_for_unimplemented_methods) -from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) - -logger = init_logger(__name__) - - -@warn_for_unimplemented_methods -class WorkerBase: - """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. Also abstracts control plane communication, e.g., to - communicate request metadata to other workers. - """ - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self.kv_transfer_config = vllm_config.kv_transfer_config - self.compilation_config = vllm_config.compilation_config - from vllm.platforms import current_platform - self.current_platform = current_platform - - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device - memory allocations. - """ - raise NotImplementedError - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. - """ - raise NotImplementedError - - def get_model(self) -> nn.Module: - raise NotImplementedError - - def load_model(self) -> None: - """Load model onto target device.""" - raise NotImplementedError - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - raise NotImplementedError - - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. - - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. 
- """ - with self.current_platform.inference_mode(): - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - def get_cache_block_size_bytes(self) -> int: - """Return the size of a single cache block, in bytes. Used in - speculative decoding. - """ - raise NotImplementedError - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - def list_loras(self) -> Set[int]: - raise NotImplementedError - - @property - def vocab_size(self) -> int: - """Get vocabulary size from model configuration.""" - return self.model_config.get_vocab_size() - - -class DelegateWorkerBase(WorkerBase): - """ - A class that delegates all methods to another WorkerBase instance. This is - useful for creating a WorkerBase that wraps another WorkerBase instance, - e.g. speculative decoding. - """ - worker: WorkerBase - - def __init__( - self, - *args, - **kwargs, - ) -> None: - vllm_config: VllmConfig = kwargs.get("vllm_config") - cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls) - self.worker = cls(*args, **kwargs) - - def init_device(self) -> None: - self.worker.init_device() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - return self.worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def load_model(self) -> None: - """Load model onto target device.""" - self.worker.load_model() - - def get_model(self) -> nn.Module: - return self.worker.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - return self.worker.execute_model(execute_model_req) - - def get_cache_block_size_bytes(self) -> int: - return self.worker.get_cache_block_size_bytes() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.worker.list_loras() - - def __getattr__(self, attr): - return getattr(self.worker, attr) - - -class LoRANotSupportedWorkerBase(WorkerBase): - """Partial implementation of WorkerBase that raises exceptions when LoRA - methods are invoked. 
- """ - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def remove_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def pin_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def list_loras(self) -> Set[int]: - raise ValueError(f"{type(self)} does not support LoRA") - - -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. These - fields should be broadcastable to other workers. - """ - - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - virtual_engine: int = 0 - num_steps: int = 1 - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], - tensor_dict: Dict[str, Any], - ) -> "WorkerInput": - """ - Pop fields from the given tensor_dict and populate a new instance of - WorkerInput. - """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - ) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. - """ - tensor_dict = { - "num_seq_groups": self.num_seq_groups, - "blocks_to_swap_in": self.blocks_to_swap_in, - "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, - } - - return tensor_dict - - -class LocalOrDistributedWorkerBase(WorkerBase): - """ - Partial implementation of WorkerBase that has a default `execute_model` - definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should use model runners that inherit - from ModelRunnerBase, and should only need to implement worker-local logic. - If custom control plane logic is needed to transfer metadata, or if the - model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. - """ - is_driver_worker: bool - model_runner: ModelRunnerBase - observability_config: Optional[ObservabilityConfig] = None - - @property - @abstractmethod - def do_metadata_broadcast(self) -> bool: - """ - Used by the default `execute_model` to check whether broadcast is - needed to transfer request inputs from the driver worker to other - workers in the TP group. If WorkerBase subclass only supports - single-worker execution, then this method should return False. - """ - raise NotImplementedError - - @property - @abstractmethod - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - """ - Gets the list of kv caches to pass to the worker's model runner. Each - element in the list is a kv cache corresponding to a particular virtual - engine (PP stream). Used by the default `execute_model`. If the worker's - model runner does not follow the ModelRunnerBase interface, then inherit - from WorkerBase instead. - """ - raise NotImplementedError - - @abstractmethod - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - """ - Prepare the inputs to WorkerBase.execute_worker from an execution - request. 
This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError - - @abstractmethod - def execute_worker(self, worker_input: WorkerInput) -> None: - """ - Process an execution request. - """ - raise NotImplementedError - - def _get_worker_input_from_broadcast( - self - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ Get the worker input from the broadcasted tensor dict. """ - assert self.do_metadata_broadcast - assert not self.is_driver_worker - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) - - kwargs = extract_previous_hidden_states(broadcast_data) - - return model_input, worker_input, kwargs - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: - """ Get the driver input and broadcast it to other workers. """ - assert self.is_driver_worker - - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - kwargs = extract_previous_hidden_states(execute_model_req) - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update(model_input.as_broadcastable_tensor_dict()) - broadcast_data.update(kwargs) - broadcast_tensor_dict(broadcast_data, src=0) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( # type: ignore - model_input, - async_callback=execute_model_req.async_callback) - - return model_input, worker_input, kwargs - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ - Prepare the inputs to ModelRunner and workers. - """ - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - return self._get_driver_input_and_broadcast(execute_model_req) - else: - return self._get_worker_input_from_broadcast() - - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" - start_time = time.perf_counter() - - inputs = self.prepare_input(execute_model_req) - if inputs is None: - return None - - model_input, worker_input, kwargs = inputs - num_steps = worker_input.num_steps - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. 
- if worker_input.num_seq_groups == 0: - return [] - - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() - - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) - - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - assert isinstance(output, IntermediateTensors) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) - - # output is List[SamplerOutput] - return output - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None - ) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, ( - "_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - kwargs = extract_previous_hidden_states(execute_model_req) - - return self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - **kwargs, - ) - - -class WorkerWrapperBase: - """ - This class represents one process in an executor/engine. It is responsible - for lazily initializing the worker and handling the worker's lifecycle. - We first instantiate the WorkerWrapper, which remembers the worker module - and class name. Then, when we call `update_environment_variables`, and the - real initialization happens in `init_worker`. - """ - - def __init__( - self, - vllm_config: VllmConfig, - rpc_rank: int = 0, - ) -> None: - """ - Initialize the worker wrapper with the given vllm_config and rpc_rank. - Note: rpc_rank is the rank of the worker in the executor. In most cases, - it is also the rank of the worker in the distributed group. However, - when multiple executors work together, they can be different. - e.g. in the case of SPMD-style offline inference with TP=2, - users can launch 2 engines/executors, each with only 1 worker. 
- All workers have rpc_rank=0, but they have different ranks in the TP - group. - """ - self.rpc_rank = rpc_rank - self.worker: Optional[WorkerBase] = None - self.vllm_config: Optional[VllmConfig] = None - # do not store this `vllm_config`, `init_worker` will set the final - # one. TODO: investigate if we can remove this field in - # `WorkerWrapperBase`, `init_cached_hf_modules` should be - # unnecessary now. - if vllm_config.model_config is not None: - # it can be None in tests - trust_remote_code = vllm_config.model_config.trust_remote_code - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: - """ - Adjust the rpc_rank based on the given mapping. - It is only used during the initialization of the executor, - to adjust the rpc_rank of workers after we create all workers. - """ - if self.rpc_rank in rank_mapping: - self.rpc_rank = rank_mapping[self.rpc_rank] - - def update_environment_variables(self, envs_list: List[Dict[str, - str]]) -> None: - envs = envs_list[self.rpc_rank] - key = 'CUDA_VISIBLE_DEVICES' - if key in envs and key in os.environ: - # overwriting CUDA_VISIBLE_DEVICES is desired behavior - # suppress the warning in `update_environment_variables` - del os.environ[key] - update_environment_variables(envs) - - def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: - """ - Here we inject some common logic before initializing the worker. - Arguments are passed to the worker class constructor. - """ - kwargs = all_kwargs[self.rpc_rank] - self.vllm_config = kwargs.get("vllm_config") - assert self.vllm_config is not None, ( - "vllm_config is required to initialize the worker") - enable_trace_function_call_for_thread(self.vllm_config) - - from vllm.plugins import load_general_plugins - load_general_plugins() - - if isinstance(self.vllm_config.parallel_config.worker_cls, str): - worker_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_cls) - else: - logger.warning( - "passing worker_cls as a class object is strongly deprecated," - " as the serialization of class objects can be tricky and" - " error-prone. To be safe, please keep the class in a separate" - " module and pass the qualified name of the class as a string." 
- ) - assert isinstance(self.vllm_config.parallel_config.worker_cls, - bytes) - worker_class = cloudpickle.loads( - self.vllm_config.parallel_config.worker_cls) - if self.vllm_config.parallel_config.worker_extension_cls: - worker_extension_cls = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_extension_cls) - extended_calls = [] - if worker_extension_cls not in worker_class.__bases__: - # check any conflicts between worker and worker_extension_cls - for attr in dir(worker_extension_cls): - if attr.startswith("__"): - continue - assert not hasattr(worker_class, attr), ( - f"Worker class {worker_class} already has an attribute" - f" {attr}, which conflicts with the worker" - f" extension class {worker_extension_cls}.") - if callable(getattr(worker_extension_cls, attr)): - extended_calls.append(attr) - # dynamically inherit the worker extension class - worker_class.__bases__ = worker_class.__bases__ + ( - worker_extension_cls, ) - logger.info( - "Injected %s into %s for extended collective_rpc calls %s", - worker_extension_cls, worker_class, extended_calls) - with set_current_vllm_config(self.vllm_config): - # To make vLLM config available during worker initialization - self.worker = worker_class(**kwargs) - assert self.worker is not None - - def initialize_from_config(self, kv_cache_configs: List[Any]) -> None: - kv_cache_config = kv_cache_configs[self.rpc_rank] - with set_current_vllm_config(self.vllm_config): - self.worker.initialize_from_config(kv_cache_config) # type: ignore - - def init_device(self): - with set_current_vllm_config(self.vllm_config): - # To make vLLM config available during device initialization - self.worker.init_device() # type: ignore - - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - try: - # method resolution order: - # if a method is defined in this class, it will be called directly. - # otherwise, since we define `__getattr__` and redirect attribute - # query to `self.worker`, the method will be called on the worker. - return run_method(self, method, args, kwargs) - except Exception as e: - # if the driver worker also execute methods, - # exceptions in the rest worker may cause deadlock in rpc like ray - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method!r}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - - def __getattr__(self, attr): - return getattr(self.worker, attr) - - -def extract_previous_hidden_states( - data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ - Dict[str, torch.Tensor]: - """If data contains previous_hidden_states, extract it. This returns a dict - which can be used directly as additional kwargs in any following - execute_model calls. This is used in draft models like EAGLE.""" - output = {} - - # When called from non-driver worker, data is dict but when called from - # driver worker, data is ExecuteModelRequest. - if isinstance(data, dict): - if "previous_hidden_states" in data: - output["previous_hidden_states"] = data["previous_hidden_states"] - elif data.previous_hidden_states is not None: - output["previous_hidden_states"] = data.previous_hidden_states\ - .hidden_states - - return output
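# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): the deleted control plane stopped
# follower workers by broadcasting an empty dict. Non-driver ranks spin in
# start_worker_execution_loop -> execute_model(None) -> prepare_input(), and
# an empty broadcast makes prepare_input() return None, which ends the loop.
# The queue below is a local stand-in for broadcast_tensor_dict, not the
# vLLM API; it only illustrates the handshake described in the docstrings.
# ---------------------------------------------------------------------------
import queue
import threading

broadcast_channel: "queue.Queue[dict]" = queue.Queue()


def follower_loop(results: list) -> None:
    """Mimic start_worker_execution_loop on a non-driver rank."""
    while True:
        broadcast_data = broadcast_channel.get()  # blocking "broadcast"
        if not broadcast_data:  # empty dict == stop signal from the driver
            results.append("stopped")
            return
        results.append(
            f"ran step with {broadcast_data['num_seq_groups']} seq groups")


if __name__ == "__main__":
    results: list = []
    t = threading.Thread(target=follower_loop, args=(results, ))
    t.start()
    broadcast_channel.put({"num_seq_groups": 2})  # driver broadcasts work
    broadcast_channel.put({})                     # driver broadcasts stop
    t.join()
    print(results)  # ['ran step with 2 seq groups', 'stopped']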
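# ---------------------------------------------------------------------------
# Editorial sketch (not part of the diff): the deleted
# WorkerWrapperBase.init_worker grafted `worker_extension_cls` onto the
# worker class by appending it to `__bases__` after a conflict check, so the
# extension's methods become reachable through generic method dispatch
# (e.g. collective_rpc). The snippet reproduces only that pattern in
# isolation; MiniWorkerBase, MiniWorker and StatsExtension are hypothetical
# stand-ins, not vLLM classes.
# ---------------------------------------------------------------------------
class MiniWorkerBase:

    def execute_model(self) -> str:
        return "model output"


class MiniWorker(MiniWorkerBase):
    pass


class StatsExtension:

    def report_type(self) -> str:
        # Runs with full access to the worker instance through `self`.
        return f"extension attached to {type(self).__name__}"


def inject_extension(worker_cls: type, extension_cls: type) -> list:
    """Append `extension_cls` to `worker_cls.__bases__`, mirroring the
    conflict check and dynamic inheritance done in the deleted init_worker."""
    extended_calls: list = []
    if extension_cls in worker_cls.__bases__:
        return extended_calls
    for attr in dir(extension_cls):
        if attr.startswith("__"):
            continue
        # Refuse to shadow anything the worker class already defines.
        assert not hasattr(worker_cls, attr), (
            f"{worker_cls.__name__} already defines {attr!r}")
        if callable(getattr(extension_cls, attr)):
            extended_calls.append(attr)
    worker_cls.__bases__ = worker_cls.__bases__ + (extension_cls, )
    return extended_calls


if __name__ == "__main__":
    print(inject_extension(MiniWorker, StatsExtension))  # ['report_type']
    print(MiniWorker().report_type())  # extension attached to MiniWorker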